ai_scheduler/internal/services/support.go

109 lines
2.5 KiB
Go

package services
import (
"ai_scheduler/internal/config"
"ai_scheduler/internal/data/constants"
errorcode "ai_scheduler/internal/data/error"
"ai_scheduler/internal/pkg/util"
"ai_scheduler/internal/pkg/utils_vllm"
"context"
"encoding/json"
"strings"
"time"
"github.com/cloudwego/eino/schema"
"github.com/gofiber/fiber/v2"
)
// SupportService groups the support HTTP handlers (for example
// AddressIngestHyt). It holds the application configuration, which its
// handlers pass on when constructing the model client.
type SupportService struct {
	cfg *config.Config
}

// NewSupportService returns a SupportService that uses cfg.
func NewSupportService(cfg *config.Config) *SupportService {
	svc := new(SupportService)
	svc.cfg = cfg
	return svc
}
// Payload types for the AddressIngestHyt (HuoYiTong shipping-address
// extraction) endpoint.
type (
	// AddressIngestHytReq is the JSON request body carrying the free-form
	// text from which shipping details should be extracted.
	AddressIngestHytReq struct {
		Text string `json:"text"` // text to extract from
	}

	// AddressIngestHytResp holds the extracted shipping details; it is both
	// the target the model's JSON reply is decoded into and the JSON
	// response body returned to the caller.
	AddressIngestHytResp struct {
		Recipient string `json:"recipient"` // recipient (consignee) name
		Phone     string `json:"phone"`     // contact phone number
		Address   string `json:"address"`   // shipping address
	}
)
// AddressIngestHyt is the HuoYiTong (货易通) shipping-address extraction
// handler. It validates the caller's headers via checkRequestHeader, parses a
// JSON body of the form {"text": "..."}, asks the vLLM-backed chat model to
// extract the recipient, phone number and address from that text, and writes
// the result back as JSON (see AddressIngestHytResp).
//
// Errors:
//   - header validation failures are returned as produced by checkRequestHeader;
//   - an unparsable body, or a "text" that is empty or whitespace-only, yields
//     a parameter error;
//   - failures creating the model client or calling the model are returned
//     unchanged;
//   - a model reply that is not valid JSON, even after util.JSONRepair, yields
//     a parameter error.
func (s *SupportService) AddressIngestHyt(c *fiber.Ctx) error {
	// Use the request-scoped context rather than context.Background() so that
	// deadlines, cancellation and values installed by middleware reach the
	// model call. Fiber returns a non-nil empty context when none was set.
	var ctx context.Context = c.UserContext()

	if err := s.checkRequestHeader(c); err != nil {
		return err
	}

	req := AddressIngestHytReq{}
	if err := c.BodyParser(&req); err != nil {
		return errorcode.ParamErrf("invalid request body: %v", err)
	}
	// Whitespace-only text has nothing to extract; reject it before spending
	// a model call on it.
	if strings.TrimSpace(req.Text) == "" {
		return errorcode.ParamErrf("missing required fields")
	}

	client, cleanup, err := utils_vllm.NewClient(s.cfg)
	if err != nil {
		return err
	}
	defer cleanup()

	res, err := client.Chat(ctx, []*schema.Message{
		{Role: "system", Content: constants.SystemPromptAddressIngestHyt},
		{Role: "user", Content: req.Text},
	})
	if err != nil {
		return err
	}

	// The reply is expected to be a JSON object matching AddressIngestHytResp.
	// If it does not decode as-is, try one repair pass before giving up.
	// NOTE(review): these failures come from the model output, not from the
	// client's request; ParamErrf may be the wrong error category — confirm
	// against the errorcode package.
	var addr AddressIngestHytResp
	if err := json.Unmarshal([]byte(res.Content), &addr); err != nil {
		repaired, repairErr := util.JSONRepair(res.Content)
		if repairErr != nil {
			return errorcode.ParamErrf("invalid response body: %v", repairErr)
		}
		if err := json.Unmarshal([]byte(repaired), &addr); err != nil {
			return errorcode.ParamErrf("invalid response body: %v", err)
		}
	}
	return c.JSON(addr)
}
// checkRequestHeader validates the authentication headers of a support request.
//
// X-Timestamp is optional. When it is present (after trimming whitespace), it
// must satisfy util.IsInTimeWindow with a 5-minute window; otherwise
// errorcode.AuthNotFound is returned. X-Source-Key is mandatory. After
// trimming, it must be non-empty and equal to constants.TokenAddressIngestHyt;
// otherwise errorcode.KeyNotFound is returned. The timestamp is checked before
// the key. It returns nil when both checks pass.
//
// NOTE(review): because X-Timestamp may be omitted, the time window alone does
// not prevent replays. The key is compared with ordinary string equality,
// which is not constant-time (crypto/subtle.ConstantTimeCompare would be).
// Confirm whether either matters for this endpoint.
func (s *SupportService) checkRequestHeader(c *fiber.Ctx) error {
	if stamp := strings.TrimSpace(c.Get("X-Timestamp")); stamp != "" {
		if !util.IsInTimeWindow(stamp, 5*time.Minute) {
			return errorcode.AuthNotFound
		}
	}

	key := strings.TrimSpace(c.Get("X-Source-Key"))
	if key != "" && key == constants.TokenAddressIngestHyt {
		return nil
	}
	return errorcode.KeyNotFound
}