109 lines
2.5 KiB
Go
109 lines
2.5 KiB
Go
package services
|
|
|
|
import (
	"context"
	"crypto/subtle"
	"encoding/json"
	"strings"
	"time"

	"ai_scheduler/internal/config"
	"ai_scheduler/internal/data/constants"
	errorcode "ai_scheduler/internal/data/error"
	"ai_scheduler/internal/pkg/util"
	"ai_scheduler/internal/pkg/utils_vllm"

	"github.com/cloudwego/eino/schema"
	"github.com/gofiber/fiber/v2"
)
|
|
|
|
type SupportService struct {
|
|
cfg *config.Config
|
|
}
|
|
|
|
func NewSupportService(cfg *config.Config) *SupportService {
|
|
return &SupportService{
|
|
cfg: cfg,
|
|
}
|
|
}
|
|
|
|
// AddressIngestHytReq is the JSON request body accepted by AddressIngestHyt.
type AddressIngestHytReq struct {
	// Text is the free-form text to extract address information from.
	Text string `json:"text"`
}
|
|
|
|
// AddressIngestHytResp holds the fields extracted by AddressIngestHyt. It is
// decoded from the model's JSON reply and written back to the caller as JSON.
type AddressIngestHytResp struct {
	// Recipient is the consignee's name.
	Recipient string `json:"recipient"`
	// Phone is the contact phone number.
	Phone string `json:"phone"`
	// Address is the shipping address.
	Address string `json:"address"`
}
|
|
|
|
// AddressIngestHyt 货易通收获地址提取
|
|
func (s *SupportService) AddressIngestHyt(c *fiber.Ctx) error {
|
|
ctx := context.Background()
|
|
|
|
// 请求头校验
|
|
if err := s.checkRequestHeader(c); err != nil {
|
|
return err
|
|
}
|
|
// 解析请求参数 body
|
|
req := AddressIngestHytReq{}
|
|
if err := c.BodyParser(&req); err != nil {
|
|
return errorcode.ParamErrf("invalid request body: %v", err)
|
|
}
|
|
// 必要参数校验
|
|
if req.Text == "" {
|
|
return errorcode.ParamErrf("missing required fields")
|
|
}
|
|
|
|
// 模型调用
|
|
client, cleanup, err := utils_vllm.NewClient(s.cfg)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer cleanup()
|
|
res, err := client.Chat(ctx, []*schema.Message{
|
|
{
|
|
Role: "system",
|
|
Content: constants.SystemPromptAddressIngestHyt,
|
|
},
|
|
{
|
|
Role: "user",
|
|
Content: req.Text,
|
|
},
|
|
})
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// 解析模型返回结果
|
|
var addr AddressIngestHytResp
|
|
if err := json.Unmarshal([]byte(res.Content), &addr); err != nil {
|
|
// 修复json字符串
|
|
res.Content, err = util.JSONRepair(res.Content)
|
|
if err != nil {
|
|
return errorcode.ParamErrf("invalid response body: %v", err)
|
|
}
|
|
if err := json.Unmarshal([]byte(res.Content), &addr); err != nil {
|
|
return errorcode.ParamErrf("invalid response body: %v", err)
|
|
}
|
|
}
|
|
|
|
return c.JSON(addr)
|
|
}
|
|
|
|
// checkRequestHeader 校验请求头
|
|
func (s *SupportService) checkRequestHeader(c *fiber.Ctx) error {
|
|
// 读取头
|
|
token := strings.TrimSpace(c.Get("X-Source-Key"))
|
|
ts := strings.TrimSpace(c.Get("X-Timestamp"))
|
|
|
|
// 时间窗口校验
|
|
if ts != "" && !util.IsInTimeWindow(ts, 5*time.Minute) {
|
|
return errorcode.AuthNotFound
|
|
}
|
|
// token校验
|
|
if token == "" || token != constants.TokenAddressIngestHyt {
|
|
return errorcode.KeyNotFound
|
|
}
|
|
|
|
return nil
|
|
}
|