fix: 策略模式实现地址提取
This commit is contained in:
parent
8d046df04e
commit
e981115cc1
|
|
@ -3,6 +3,7 @@ package biz
|
|||
import (
|
||||
"ai_scheduler/internal/biz/do"
|
||||
"ai_scheduler/internal/biz/llm_service"
|
||||
"ai_scheduler/internal/biz/support"
|
||||
|
||||
"github.com/google/wire"
|
||||
)
|
||||
|
|
@ -22,4 +23,5 @@ var ProviderSetBiz = wire.NewSet(
|
|||
NewQywxAppBiz,
|
||||
NewGroupConfigBiz,
|
||||
do.NewMacro,
|
||||
support.NewHytAddressIngester,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,81 @@
|
|||
package support
|
||||
|
||||
import (
|
||||
"ai_scheduler/internal/config"
|
||||
"ai_scheduler/internal/data/constants"
|
||||
errorcode "ai_scheduler/internal/data/error"
|
||||
"ai_scheduler/internal/pkg/util"
|
||||
"ai_scheduler/internal/pkg/utils_vllm"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/cloudwego/eino/schema"
|
||||
"github.com/gofiber/fiber/v2"
|
||||
)
|
||||
|
||||
// HytAddressIngester 货易通地址提取实现
|
||||
type HytAddressIngester struct {
|
||||
cfg *config.Config
|
||||
}
|
||||
|
||||
func NewHytAddressIngester(cfg *config.Config) *HytAddressIngester {
|
||||
return &HytAddressIngester{cfg: cfg}
|
||||
}
|
||||
|
||||
// Auth 鉴权逻辑
|
||||
func (s *HytAddressIngester) Auth(c *fiber.Ctx) error {
|
||||
// 读取头
|
||||
token := strings.TrimSpace(c.Get("X-Source-Key"))
|
||||
ts := strings.TrimSpace(c.Get("X-Timestamp"))
|
||||
|
||||
// 时间窗口校验
|
||||
if ts != "" && !util.IsInTimeWindow(ts, 5*time.Minute) {
|
||||
return errorcode.AuthNotFound
|
||||
}
|
||||
// token校验
|
||||
if token == "" || token != constants.TokenAddressIngestHyt {
|
||||
return errorcode.KeyNotFound
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Ingest 执行提取逻辑
|
||||
func (s *HytAddressIngester) Ingest(ctx context.Context, text string) (*AddressIngestResp, error) {
|
||||
// 模型调用
|
||||
client, cleanup, err := utils_vllm.NewClient(s.cfg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer cleanup()
|
||||
|
||||
res, err := client.Chat(ctx, []*schema.Message{
|
||||
{
|
||||
Role: "system",
|
||||
Content: constants.SystemPromptAddressIngestHyt,
|
||||
},
|
||||
{
|
||||
Role: "user",
|
||||
Content: text,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 解析模型返回结果
|
||||
var addr AddressIngestResp
|
||||
// 尝试直接解析
|
||||
if err := json.Unmarshal([]byte(res.Content), &addr); err != nil {
|
||||
// 修复json字符串
|
||||
repairedContent, err := util.JSONRepair(res.Content)
|
||||
if err != nil {
|
||||
return nil, errorcode.ParamErrf("invalid response body: %v", err)
|
||||
}
|
||||
if err := json.Unmarshal([]byte(repairedContent), &addr); err != nil {
|
||||
return nil, errorcode.ParamErrf("invalid response body: %v", err)
|
||||
}
|
||||
}
|
||||
return &addr, nil
|
||||
}
|
||||
|
|
@ -0,0 +1,22 @@
|
|||
package support
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
)
|
||||
|
||||
// AddressIngestResp 通用地址提取响应
|
||||
type AddressIngestResp struct {
|
||||
Recipient string `json:"recipient"` // 收货人
|
||||
Phone string `json:"phone"` // 联系电话
|
||||
Address string `json:"address"` // 收货地址
|
||||
}
|
||||
|
||||
// AddressIngester 地址提取接口
|
||||
type AddressIngester interface {
|
||||
// Auth 鉴权逻辑
|
||||
Auth(c *fiber.Ctx) error
|
||||
// Ingest 执行提取逻辑
|
||||
Ingest(ctx context.Context, text string) (*AddressIngestResp, error)
|
||||
}
|
||||
|
|
@ -100,7 +100,7 @@ func SetupRoutes(app *fiber.App, ChatService *services.ChatService, sessionServi
|
|||
r.Post("/capability/product/ingest/:thread_id/confirm", capabilityService.ProductIngestConfirm) // 商品数据提取确认
|
||||
|
||||
// 外部系统支持
|
||||
r.Post("/support/address/ingest/hyt", supportService.AddressIngestHyt) // 货易通收获地址提取
|
||||
r.Post("/support/address/ingest/:platform", supportService.AddressIngest) // 通用收获地址提取
|
||||
}
|
||||
|
||||
func routerSocket(app *fiber.App, chatService *services.ChatService) {
|
||||
|
|
|
|||
|
|
@ -1,50 +1,48 @@
|
|||
package services
|
||||
|
||||
import (
|
||||
"ai_scheduler/internal/biz/support"
|
||||
"ai_scheduler/internal/config"
|
||||
"ai_scheduler/internal/data/constants"
|
||||
errorcode "ai_scheduler/internal/data/error"
|
||||
"ai_scheduler/internal/pkg/util"
|
||||
"ai_scheduler/internal/pkg/utils_vllm"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/cloudwego/eino/schema"
|
||||
"github.com/gofiber/fiber/v2"
|
||||
)
|
||||
|
||||
type SupportService struct {
|
||||
cfg *config.Config
|
||||
cfg *config.Config
|
||||
addressIngester map[string]support.AddressIngester
|
||||
addressIngestHyt *support.HytAddressIngester
|
||||
}
|
||||
|
||||
func NewSupportService(cfg *config.Config) *SupportService {
|
||||
return &SupportService{
|
||||
func NewSupportService(cfg *config.Config, addressIngestHyt *support.HytAddressIngester) *SupportService {
|
||||
s := &SupportService{
|
||||
cfg: cfg,
|
||||
addressIngester: map[string]support.AddressIngester{
|
||||
"hyt": addressIngestHyt,
|
||||
},
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
type AddressIngestHytReq struct {
|
||||
type AddressIngestReq struct {
|
||||
Text string `json:"text"` // 待提取文本
|
||||
}
|
||||
|
||||
type AddressIngestHytResp struct {
|
||||
Recipient string `json:"recipient"` // 收货人
|
||||
Phone string `json:"phone"` // 联系电话
|
||||
Address string `json:"address"` // 收货地址
|
||||
}
|
||||
// AddressIngest 收获地址提取
|
||||
func (s *SupportService) AddressIngest(c *fiber.Ctx) error {
|
||||
platform := c.Params("platform")
|
||||
ingester, ok := s.addressIngester[platform]
|
||||
if !ok {
|
||||
return errorcode.ParamErrf("unsupported platform: %s", platform)
|
||||
}
|
||||
|
||||
// AddressIngestHyt 货易通收获地址提取
|
||||
func (s *SupportService) AddressIngestHyt(c *fiber.Ctx) error {
|
||||
ctx := context.Background()
|
||||
|
||||
// 请求头校验
|
||||
if err := s.checkRequestHeader(c); err != nil {
|
||||
// 鉴权
|
||||
if err := ingester.Auth(c); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 解析请求参数 body
|
||||
req := AddressIngestHytReq{}
|
||||
req := AddressIngestReq{}
|
||||
if err := c.BodyParser(&req); err != nil {
|
||||
return errorcode.ParamErrf("invalid request body: %v", err)
|
||||
}
|
||||
|
|
@ -53,56 +51,11 @@ func (s *SupportService) AddressIngestHyt(c *fiber.Ctx) error {
|
|||
return errorcode.ParamErrf("missing required fields")
|
||||
}
|
||||
|
||||
// 模型调用
|
||||
client, cleanup, err := utils_vllm.NewClient(s.cfg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer cleanup()
|
||||
res, err := client.Chat(ctx, []*schema.Message{
|
||||
{
|
||||
Role: "system",
|
||||
Content: constants.SystemPromptAddressIngestHyt,
|
||||
},
|
||||
{
|
||||
Role: "user",
|
||||
Content: req.Text,
|
||||
},
|
||||
})
|
||||
// 执行提取
|
||||
res, err := ingester.Ingest(c.Context(), req.Text)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 解析模型返回结果
|
||||
var addr AddressIngestHytResp
|
||||
if err := json.Unmarshal([]byte(res.Content), &addr); err != nil {
|
||||
// 修复json字符串
|
||||
res.Content, err = util.JSONRepair(res.Content)
|
||||
if err != nil {
|
||||
return errorcode.ParamErrf("invalid response body: %v", err)
|
||||
}
|
||||
if err := json.Unmarshal([]byte(res.Content), &addr); err != nil {
|
||||
return errorcode.ParamErrf("invalid response body: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
return c.JSON(addr)
|
||||
}
|
||||
|
||||
// checkRequestHeader 校验请求头
|
||||
func (s *SupportService) checkRequestHeader(c *fiber.Ctx) error {
|
||||
// 读取头
|
||||
token := strings.TrimSpace(c.Get("X-Source-Key"))
|
||||
ts := strings.TrimSpace(c.Get("X-Timestamp"))
|
||||
|
||||
// 时间窗口校验
|
||||
if ts != "" && !util.IsInTimeWindow(ts, 5*time.Minute) {
|
||||
return errorcode.AuthNotFound
|
||||
}
|
||||
// token校验
|
||||
if token == "" || token != constants.TokenAddressIngestHyt {
|
||||
return errorcode.KeyNotFound
|
||||
}
|
||||
|
||||
return nil
|
||||
return c.JSON(res)
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue