fix: 策略模式实现地址提取

This commit is contained in:
fuzhongyun 2026-03-02 17:30:49 +08:00
parent 8d046df04e
commit e981115cc1
5 changed files with 131 additions and 73 deletions

View File

@ -3,6 +3,7 @@ package biz
import (
"ai_scheduler/internal/biz/do"
"ai_scheduler/internal/biz/llm_service"
"ai_scheduler/internal/biz/support"
"github.com/google/wire"
)
@ -22,4 +23,5 @@ var ProviderSetBiz = wire.NewSet(
NewQywxAppBiz,
NewGroupConfigBiz,
do.NewMacro,
support.NewHytAddressIngester,
)

View File

@ -0,0 +1,81 @@
package support
import (
"ai_scheduler/internal/config"
"ai_scheduler/internal/data/constants"
errorcode "ai_scheduler/internal/data/error"
"ai_scheduler/internal/pkg/util"
"ai_scheduler/internal/pkg/utils_vllm"
"context"
"encoding/json"
"strings"
"time"
"github.com/cloudwego/eino/schema"
"github.com/gofiber/fiber/v2"
)
// HytAddressIngester 货易通地址提取实现
type HytAddressIngester struct {
cfg *config.Config
}
func NewHytAddressIngester(cfg *config.Config) *HytAddressIngester {
return &HytAddressIngester{cfg: cfg}
}
// Auth 鉴权逻辑
func (s *HytAddressIngester) Auth(c *fiber.Ctx) error {
// 读取头
token := strings.TrimSpace(c.Get("X-Source-Key"))
ts := strings.TrimSpace(c.Get("X-Timestamp"))
// 时间窗口校验
if ts != "" && !util.IsInTimeWindow(ts, 5*time.Minute) {
return errorcode.AuthNotFound
}
// token校验
if token == "" || token != constants.TokenAddressIngestHyt {
return errorcode.KeyNotFound
}
return nil
}
// Ingest 执行提取逻辑
func (s *HytAddressIngester) Ingest(ctx context.Context, text string) (*AddressIngestResp, error) {
// 模型调用
client, cleanup, err := utils_vllm.NewClient(s.cfg)
if err != nil {
return nil, err
}
defer cleanup()
res, err := client.Chat(ctx, []*schema.Message{
{
Role: "system",
Content: constants.SystemPromptAddressIngestHyt,
},
{
Role: "user",
Content: text,
},
})
if err != nil {
return nil, err
}
// 解析模型返回结果
var addr AddressIngestResp
// 尝试直接解析
if err := json.Unmarshal([]byte(res.Content), &addr); err != nil {
// 修复json字符串
repairedContent, err := util.JSONRepair(res.Content)
if err != nil {
return nil, errorcode.ParamErrf("invalid response body: %v", err)
}
if err := json.Unmarshal([]byte(repairedContent), &addr); err != nil {
return nil, errorcode.ParamErrf("invalid response body: %v", err)
}
}
return &addr, nil
}

View File

@ -0,0 +1,22 @@
package support
import (
"context"
"github.com/gofiber/fiber/v2"
)
// AddressIngestResp 通用地址提取响应
type AddressIngestResp struct {
Recipient string `json:"recipient"` // 收货人
Phone string `json:"phone"` // 联系电话
Address string `json:"address"` // 收货地址
}
// AddressIngester 地址提取接口
type AddressIngester interface {
// Auth 鉴权逻辑
Auth(c *fiber.Ctx) error
// Ingest 执行提取逻辑
Ingest(ctx context.Context, text string) (*AddressIngestResp, error)
}

View File

@ -100,7 +100,7 @@ func SetupRoutes(app *fiber.App, ChatService *services.ChatService, sessionServi
r.Post("/capability/product/ingest/:thread_id/confirm", capabilityService.ProductIngestConfirm) // 商品数据提取确认
// 外部系统支持
r.Post("/support/address/ingest/hyt", supportService.AddressIngestHyt) // 货易通收获地址提取
r.Post("/support/address/ingest/:platform", supportService.AddressIngest) // 收获地址提取
}
func routerSocket(app *fiber.App, chatService *services.ChatService) {

View File

@ -1,50 +1,48 @@
package services
import (
"ai_scheduler/internal/biz/support"
"ai_scheduler/internal/config"
"ai_scheduler/internal/data/constants"
errorcode "ai_scheduler/internal/data/error"
"ai_scheduler/internal/pkg/util"
"ai_scheduler/internal/pkg/utils_vllm"
"context"
"encoding/json"
"strings"
"time"
"github.com/cloudwego/eino/schema"
"github.com/gofiber/fiber/v2"
)
type SupportService struct {
cfg *config.Config
cfg *config.Config
addressIngester map[string]support.AddressIngester
addressIngestHyt *support.HytAddressIngester
}
func NewSupportService(cfg *config.Config) *SupportService {
return &SupportService{
func NewSupportService(cfg *config.Config, addressIngestHyt *support.HytAddressIngester) *SupportService {
s := &SupportService{
cfg: cfg,
addressIngester: map[string]support.AddressIngester{
"hyt": addressIngestHyt,
},
}
return s
}
type AddressIngestHytReq struct {
type AddressIngestReq struct {
Text string `json:"text"` // 待提取文本
}
type AddressIngestHytResp struct {
Recipient string `json:"recipient"` // 收货人
Phone string `json:"phone"` // 联系电话
Address string `json:"address"` // 收货地址
}
// AddressIngest 收获地址提取
func (s *SupportService) AddressIngest(c *fiber.Ctx) error {
platform := c.Params("platform")
ingester, ok := s.addressIngester[platform]
if !ok {
return errorcode.ParamErrf("unsupported platform: %s", platform)
}
// AddressIngestHyt 货易通收获地址提取
func (s *SupportService) AddressIngestHyt(c *fiber.Ctx) error {
ctx := context.Background()
// 请求头校验
if err := s.checkRequestHeader(c); err != nil {
// 鉴权
if err := ingester.Auth(c); err != nil {
return err
}
// 解析请求参数 body
req := AddressIngestHytReq{}
req := AddressIngestReq{}
if err := c.BodyParser(&req); err != nil {
return errorcode.ParamErrf("invalid request body: %v", err)
}
@ -53,56 +51,11 @@ func (s *SupportService) AddressIngestHyt(c *fiber.Ctx) error {
return errorcode.ParamErrf("missing required fields")
}
// 模型调用
client, cleanup, err := utils_vllm.NewClient(s.cfg)
if err != nil {
return err
}
defer cleanup()
res, err := client.Chat(ctx, []*schema.Message{
{
Role: "system",
Content: constants.SystemPromptAddressIngestHyt,
},
{
Role: "user",
Content: req.Text,
},
})
// 执行提取
res, err := ingester.Ingest(c.Context(), req.Text)
if err != nil {
return err
}
// 解析模型返回结果
var addr AddressIngestHytResp
if err := json.Unmarshal([]byte(res.Content), &addr); err != nil {
// 修复json字符串
res.Content, err = util.JSONRepair(res.Content)
if err != nil {
return errorcode.ParamErrf("invalid response body: %v", err)
}
if err := json.Unmarshal([]byte(res.Content), &addr); err != nil {
return errorcode.ParamErrf("invalid response body: %v", err)
}
}
return c.JSON(addr)
}
// checkRequestHeader 校验请求头
func (s *SupportService) checkRequestHeader(c *fiber.Ctx) error {
// 读取头
token := strings.TrimSpace(c.Get("X-Source-Key"))
ts := strings.TrimSpace(c.Get("X-Timestamp"))
// 时间窗口校验
if ts != "" && !util.IsInTimeWindow(ts, 5*time.Minute) {
return errorcode.AuthNotFound
}
// token校验
if token == "" || token != constants.TokenAddressIngestHyt {
return errorcode.KeyNotFound
}
return nil
return c.JSON(res)
}