208 lines
5.5 KiB
Go
208 lines
5.5 KiB
Go
package services
|
||
|
||
import (
|
||
"ai_scheduler/internal/config"
|
||
"ai_scheduler/internal/data/constants"
|
||
errorcode "ai_scheduler/internal/data/error"
|
||
"ai_scheduler/internal/domain/workflow/runtime"
|
||
"ai_scheduler/internal/entitys"
|
||
"ai_scheduler/internal/pkg/util"
|
||
"ai_scheduler/internal/pkg/utils_ollama"
|
||
"ai_scheduler/utils"
|
||
"context"
|
||
"encoding/json"
|
||
"fmt"
|
||
"strings"
|
||
"time"
|
||
|
||
hytWorkflow "ai_scheduler/internal/domain/workflow/hyt"
|
||
|
||
"github.com/gofiber/fiber/v2"
|
||
"github.com/google/uuid"
|
||
"github.com/ollama/ollama/api"
|
||
"github.com/redis/go-redis/v9"
|
||
)
|
||
|
||
// CapabilityService 统一回调入口
|
||
type CapabilityService struct {
|
||
cfg *config.Config
|
||
workflowManager *runtime.Registry
|
||
rdsCli *redis.Client
|
||
}
|
||
|
||
func NewCapabilityService(cfg *config.Config, workflowManager *runtime.Registry, rdb *utils.Rdb) *CapabilityService {
|
||
return &CapabilityService{
|
||
cfg: cfg,
|
||
workflowManager: workflowManager,
|
||
rdsCli: rdb.Rdb,
|
||
}
|
||
}
|
||
|
||
// ProductIngestReq is the JSON request body accepted by ProductIngest.
type ProductIngestReq struct {
	SysId  string   `json:"sys_id"` // target business system ID; only "hyt" (HYT) is currently supported
	Url    string   `json:"url"`    // product detail page URL
	Title  string   `json:"title"`  // product title
	Text   string   `json:"text"`   // product description (required by ProductIngest)
	Images []string `json:"images"` // product image URLs
}
|
||
|
||
// ProductIngestResp is returned by ProductIngest and is also the value cached
// in Redis (as JSON) under the thread ID for ProductIngestConfirm to read back.
type ProductIngestResp struct {
	ThreadId string `json:"thread_id"` // thread ID; the confirm call must pass it back
	SysId    string `json:"sys_id"`    // business system ID the draft was mapped for
	MetaData any    `json:"meta"`      // metadata; ProductIngest stores the original request here
	Draft    string `json:"draft"`     // raw model output; presumably edited by the client into the confirm payload — confirm
}
|
||
|
||
// ProductIngest 产品数据提取
|
||
func (s *CapabilityService) ProductIngest(c *fiber.Ctx) error {
|
||
ctx := context.Background()
|
||
// 请求头校验
|
||
if err := s.checkRequestHeader(c); err != nil {
|
||
return err
|
||
}
|
||
|
||
// 解析请求参数
|
||
req := ProductIngestReq{}
|
||
if err := c.BodyParser(&req); err != nil {
|
||
return errorcode.ParamErrf("invalid request body: %v", err)
|
||
}
|
||
// 必要参数校验
|
||
if req.Text == "" || req.SysId == "" {
|
||
return errorcode.ParamErrf("missing required fields")
|
||
}
|
||
|
||
// 映射目标系统商品属性中文模板
|
||
var sysProductPropertyTemplateZH string
|
||
switch req.SysId {
|
||
case "hyt": // 货易通
|
||
sysProductPropertyTemplateZH = constants.HYTGoodsAddPropertyTemplateZH
|
||
default:
|
||
return errorcode.ParamErrf("invalid sys_id")
|
||
}
|
||
|
||
// 模型调用
|
||
client, cleanup, err := utils_ollama.NewClient(s.cfg)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
defer cleanup()
|
||
res, err := client.Chat(ctx, s.cfg.Ollama.MappingModel, []api.Message{
|
||
{
|
||
Role: "system",
|
||
Content: constants.SystemPrompt,
|
||
},
|
||
{
|
||
Role: "assistant",
|
||
Content: fmt.Sprintf("目标属性模板:%s。", sysProductPropertyTemplateZH),
|
||
},
|
||
{
|
||
Role: "user",
|
||
Content: req.Text,
|
||
},
|
||
{
|
||
Role: "user",
|
||
Content: "商品图片URL列表:" + strings.Join(req.Images, ","),
|
||
},
|
||
})
|
||
if err != nil {
|
||
return err
|
||
}
|
||
|
||
// 生成thread_id
|
||
threadId := uuid.NewString()
|
||
resp := &ProductIngestResp{
|
||
ThreadId: threadId,
|
||
SysId: req.SysId,
|
||
MetaData: req,
|
||
Draft: res.Message.Content, // Go中map会无序,交给前端解析
|
||
}
|
||
respJson, _ := json.Marshal(resp)
|
||
|
||
// 存redis缓存
|
||
if err = s.rdsCli.Set(ctx, fmt.Sprintf(constants.CapabilityProductIngestCacheKey, threadId), respJson, 30*time.Minute).Err(); err != nil {
|
||
return err
|
||
}
|
||
|
||
// 解析模型输出
|
||
c.JSON(resp)
|
||
|
||
return nil
|
||
}
|
||
|
||
// checkRequestHeader 校验请求头
|
||
func (s *CapabilityService) checkRequestHeader(c *fiber.Ctx) error {
|
||
// 读取头
|
||
token := strings.TrimSpace(c.Get("X-Source-Key"))
|
||
ts := strings.TrimSpace(c.Get("X-Timestamp"))
|
||
|
||
// 时间窗口校验
|
||
if ts != "" && !util.IsInTimeWindow(ts, 5*time.Minute) {
|
||
return errorcode.AuthNotFound
|
||
}
|
||
// token校验
|
||
if token == "" || token != constants.CapabilityProductIngestToken {
|
||
return errorcode.KeyNotFound
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// ProductIngestConfirmReq is the JSON body accepted by ProductIngestConfirm.
// ProductIngestConfirm takes the thread ID from the route parameter, not from
// this struct.
type ProductIngestConfirmReq struct {
	ThreadId  string `json:"thread_id"` // thread ID
	Confirmed string `json:"confirmed"` // confirmed product data, as a JSON string (required)
}
|
||
|
||
// ProductIngestConfirm 商品数据提取确认
|
||
func (s *CapabilityService) ProductIngestConfirm(c *fiber.Ctx) error {
|
||
ctx := context.Background()
|
||
|
||
// 请求头校验
|
||
if err := s.checkRequestHeader(c); err != nil {
|
||
return err
|
||
}
|
||
// 获取路径参数中的 thread_id
|
||
threadId := c.Params("thread_id")
|
||
if threadId == "" {
|
||
return errorcode.ParamErrf("missing required fields")
|
||
}
|
||
// 解析请求参数 body
|
||
req := ProductIngestConfirmReq{}
|
||
if err := c.BodyParser(&req); err != nil {
|
||
return errorcode.ParamErrf("invalid request body: %v", err)
|
||
}
|
||
// 必要参数校验
|
||
if req.Confirmed == "" || threadId == "" {
|
||
return errorcode.ParamErr("missing required fields")
|
||
}
|
||
|
||
// 校验线程ID是否存在
|
||
resp, err := s.rdsCli.Get(ctx, fmt.Sprintf(constants.CapabilityProductIngestCacheKey, threadId)).Result()
|
||
if err != nil {
|
||
return errorcode.ParamErr("invalid thread_id")
|
||
}
|
||
var respData ProductIngestResp
|
||
if err = json.Unmarshal([]byte(resp), &respData); err != nil {
|
||
return errorcode.ParamErr("invalid thread_id data")
|
||
}
|
||
|
||
// 映射目标系统工作流ID
|
||
var workflowId string
|
||
switch respData.SysId {
|
||
// 货易通
|
||
case "hyt":
|
||
workflowId = hytWorkflow.WorkflowIDGoodsAdd
|
||
default:
|
||
return errorcode.ParamErr("invalid sys_id")
|
||
}
|
||
|
||
// 调用eino工作流,实现商品上传到目标系统
|
||
rec := &entitys.Recognize{UserContent: &entitys.RecognizeUserContent{Text: req.Confirmed}}
|
||
res, err := s.workflowManager.Invoke(ctx, workflowId, rec)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
|
||
return c.JSON(res)
|
||
}
|