fix: 策略模式实现地址提取
This commit is contained in:
parent
8d046df04e
commit
e981115cc1
|
|
@ -3,6 +3,7 @@ package biz
|
||||||
import (
|
import (
|
||||||
"ai_scheduler/internal/biz/do"
|
"ai_scheduler/internal/biz/do"
|
||||||
"ai_scheduler/internal/biz/llm_service"
|
"ai_scheduler/internal/biz/llm_service"
|
||||||
|
"ai_scheduler/internal/biz/support"
|
||||||
|
|
||||||
"github.com/google/wire"
|
"github.com/google/wire"
|
||||||
)
|
)
|
||||||
|
|
@ -22,4 +23,5 @@ var ProviderSetBiz = wire.NewSet(
|
||||||
NewQywxAppBiz,
|
NewQywxAppBiz,
|
||||||
NewGroupConfigBiz,
|
NewGroupConfigBiz,
|
||||||
do.NewMacro,
|
do.NewMacro,
|
||||||
|
support.NewHytAddressIngester,
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,81 @@
|
||||||
|
package support
|
||||||
|
|
||||||
|
import (
|
||||||
|
"ai_scheduler/internal/config"
|
||||||
|
"ai_scheduler/internal/data/constants"
|
||||||
|
errorcode "ai_scheduler/internal/data/error"
|
||||||
|
"ai_scheduler/internal/pkg/util"
|
||||||
|
"ai_scheduler/internal/pkg/utils_vllm"
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/cloudwego/eino/schema"
|
||||||
|
"github.com/gofiber/fiber/v2"
|
||||||
|
)
|
||||||
|
|
||||||
|
// HytAddressIngester 货易通地址提取实现
|
||||||
|
type HytAddressIngester struct {
|
||||||
|
cfg *config.Config
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewHytAddressIngester(cfg *config.Config) *HytAddressIngester {
|
||||||
|
return &HytAddressIngester{cfg: cfg}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Auth 鉴权逻辑
|
||||||
|
func (s *HytAddressIngester) Auth(c *fiber.Ctx) error {
|
||||||
|
// 读取头
|
||||||
|
token := strings.TrimSpace(c.Get("X-Source-Key"))
|
||||||
|
ts := strings.TrimSpace(c.Get("X-Timestamp"))
|
||||||
|
|
||||||
|
// 时间窗口校验
|
||||||
|
if ts != "" && !util.IsInTimeWindow(ts, 5*time.Minute) {
|
||||||
|
return errorcode.AuthNotFound
|
||||||
|
}
|
||||||
|
// token校验
|
||||||
|
if token == "" || token != constants.TokenAddressIngestHyt {
|
||||||
|
return errorcode.KeyNotFound
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ingest 执行提取逻辑
|
||||||
|
func (s *HytAddressIngester) Ingest(ctx context.Context, text string) (*AddressIngestResp, error) {
|
||||||
|
// 模型调用
|
||||||
|
client, cleanup, err := utils_vllm.NewClient(s.cfg)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
res, err := client.Chat(ctx, []*schema.Message{
|
||||||
|
{
|
||||||
|
Role: "system",
|
||||||
|
Content: constants.SystemPromptAddressIngestHyt,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Role: "user",
|
||||||
|
Content: text,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 解析模型返回结果
|
||||||
|
var addr AddressIngestResp
|
||||||
|
// 尝试直接解析
|
||||||
|
if err := json.Unmarshal([]byte(res.Content), &addr); err != nil {
|
||||||
|
// 修复json字符串
|
||||||
|
repairedContent, err := util.JSONRepair(res.Content)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errorcode.ParamErrf("invalid response body: %v", err)
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal([]byte(repairedContent), &addr); err != nil {
|
||||||
|
return nil, errorcode.ParamErrf("invalid response body: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return &addr, nil
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,22 @@
|
||||||
|
package support
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
|
"github.com/gofiber/fiber/v2"
|
||||||
|
)
|
||||||
|
|
||||||
|
// AddressIngestResp 通用地址提取响应
|
||||||
|
type AddressIngestResp struct {
|
||||||
|
Recipient string `json:"recipient"` // 收货人
|
||||||
|
Phone string `json:"phone"` // 联系电话
|
||||||
|
Address string `json:"address"` // 收货地址
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddressIngester 地址提取接口
|
||||||
|
type AddressIngester interface {
|
||||||
|
// Auth 鉴权逻辑
|
||||||
|
Auth(c *fiber.Ctx) error
|
||||||
|
// Ingest 执行提取逻辑
|
||||||
|
Ingest(ctx context.Context, text string) (*AddressIngestResp, error)
|
||||||
|
}
|
||||||
|
|
@ -100,7 +100,7 @@ func SetupRoutes(app *fiber.App, ChatService *services.ChatService, sessionServi
|
||||||
r.Post("/capability/product/ingest/:thread_id/confirm", capabilityService.ProductIngestConfirm) // 商品数据提取确认
|
r.Post("/capability/product/ingest/:thread_id/confirm", capabilityService.ProductIngestConfirm) // 商品数据提取确认
|
||||||
|
|
||||||
// 外部系统支持
|
// 外部系统支持
|
||||||
r.Post("/support/address/ingest/hyt", supportService.AddressIngestHyt) // 货易通收获地址提取
|
r.Post("/support/address/ingest/:platform", supportService.AddressIngest) // 通用收获地址提取
|
||||||
}
|
}
|
||||||
|
|
||||||
func routerSocket(app *fiber.App, chatService *services.ChatService) {
|
func routerSocket(app *fiber.App, chatService *services.ChatService) {
|
||||||
|
|
|
||||||
|
|
@ -1,50 +1,48 @@
|
||||||
package services
|
package services
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"ai_scheduler/internal/biz/support"
|
||||||
"ai_scheduler/internal/config"
|
"ai_scheduler/internal/config"
|
||||||
"ai_scheduler/internal/data/constants"
|
|
||||||
errorcode "ai_scheduler/internal/data/error"
|
errorcode "ai_scheduler/internal/data/error"
|
||||||
"ai_scheduler/internal/pkg/util"
|
|
||||||
"ai_scheduler/internal/pkg/utils_vllm"
|
|
||||||
"context"
|
|
||||||
"encoding/json"
|
|
||||||
"strings"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/cloudwego/eino/schema"
|
|
||||||
"github.com/gofiber/fiber/v2"
|
"github.com/gofiber/fiber/v2"
|
||||||
)
|
)
|
||||||
|
|
||||||
type SupportService struct {
|
type SupportService struct {
|
||||||
cfg *config.Config
|
cfg *config.Config
|
||||||
|
addressIngester map[string]support.AddressIngester
|
||||||
|
addressIngestHyt *support.HytAddressIngester
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewSupportService(cfg *config.Config) *SupportService {
|
func NewSupportService(cfg *config.Config, addressIngestHyt *support.HytAddressIngester) *SupportService {
|
||||||
return &SupportService{
|
s := &SupportService{
|
||||||
cfg: cfg,
|
cfg: cfg,
|
||||||
|
addressIngester: map[string]support.AddressIngester{
|
||||||
|
"hyt": addressIngestHyt,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
return s
|
||||||
}
|
}
|
||||||
|
|
||||||
type AddressIngestHytReq struct {
|
type AddressIngestReq struct {
|
||||||
Text string `json:"text"` // 待提取文本
|
Text string `json:"text"` // 待提取文本
|
||||||
}
|
}
|
||||||
|
|
||||||
type AddressIngestHytResp struct {
|
// AddressIngest 收获地址提取
|
||||||
Recipient string `json:"recipient"` // 收货人
|
func (s *SupportService) AddressIngest(c *fiber.Ctx) error {
|
||||||
Phone string `json:"phone"` // 联系电话
|
platform := c.Params("platform")
|
||||||
Address string `json:"address"` // 收货地址
|
ingester, ok := s.addressIngester[platform]
|
||||||
|
if !ok {
|
||||||
|
return errorcode.ParamErrf("unsupported platform: %s", platform)
|
||||||
}
|
}
|
||||||
|
|
||||||
// AddressIngestHyt 货易通收获地址提取
|
// 鉴权
|
||||||
func (s *SupportService) AddressIngestHyt(c *fiber.Ctx) error {
|
if err := ingester.Auth(c); err != nil {
|
||||||
ctx := context.Background()
|
|
||||||
|
|
||||||
// 请求头校验
|
|
||||||
if err := s.checkRequestHeader(c); err != nil {
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// 解析请求参数 body
|
// 解析请求参数 body
|
||||||
req := AddressIngestHytReq{}
|
req := AddressIngestReq{}
|
||||||
if err := c.BodyParser(&req); err != nil {
|
if err := c.BodyParser(&req); err != nil {
|
||||||
return errorcode.ParamErrf("invalid request body: %v", err)
|
return errorcode.ParamErrf("invalid request body: %v", err)
|
||||||
}
|
}
|
||||||
|
|
@ -53,56 +51,11 @@ func (s *SupportService) AddressIngestHyt(c *fiber.Ctx) error {
|
||||||
return errorcode.ParamErrf("missing required fields")
|
return errorcode.ParamErrf("missing required fields")
|
||||||
}
|
}
|
||||||
|
|
||||||
// 模型调用
|
// 执行提取
|
||||||
client, cleanup, err := utils_vllm.NewClient(s.cfg)
|
res, err := ingester.Ingest(c.Context(), req.Text)
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
defer cleanup()
|
|
||||||
res, err := client.Chat(ctx, []*schema.Message{
|
|
||||||
{
|
|
||||||
Role: "system",
|
|
||||||
Content: constants.SystemPromptAddressIngestHyt,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Role: "user",
|
|
||||||
Content: req.Text,
|
|
||||||
},
|
|
||||||
})
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// 解析模型返回结果
|
return c.JSON(res)
|
||||||
var addr AddressIngestHytResp
|
|
||||||
if err := json.Unmarshal([]byte(res.Content), &addr); err != nil {
|
|
||||||
// 修复json字符串
|
|
||||||
res.Content, err = util.JSONRepair(res.Content)
|
|
||||||
if err != nil {
|
|
||||||
return errorcode.ParamErrf("invalid response body: %v", err)
|
|
||||||
}
|
|
||||||
if err := json.Unmarshal([]byte(res.Content), &addr); err != nil {
|
|
||||||
return errorcode.ParamErrf("invalid response body: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return c.JSON(addr)
|
|
||||||
}
|
|
||||||
|
|
||||||
// checkRequestHeader 校验请求头
|
|
||||||
func (s *SupportService) checkRequestHeader(c *fiber.Ctx) error {
|
|
||||||
// 读取头
|
|
||||||
token := strings.TrimSpace(c.Get("X-Source-Key"))
|
|
||||||
ts := strings.TrimSpace(c.Get("X-Timestamp"))
|
|
||||||
|
|
||||||
// 时间窗口校验
|
|
||||||
if ts != "" && !util.IsInTimeWindow(ts, 5*time.Minute) {
|
|
||||||
return errorcode.AuthNotFound
|
|
||||||
}
|
|
||||||
// token校验
|
|
||||||
if token == "" || token != constants.TokenAddressIngestHyt {
|
|
||||||
return errorcode.KeyNotFound
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue