ai_scheduler/internal/services/capability.go

208 lines
5.4 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package services
import (
	"context"
	"crypto/subtle"
	"encoding/json"
	"fmt"
	"strings"
	"time"

	"ai_scheduler/internal/config"
	"ai_scheduler/internal/data/constants"
	errorcode "ai_scheduler/internal/data/error"
	hytWorkflow "ai_scheduler/internal/domain/workflow/hyt"
	"ai_scheduler/internal/domain/workflow/runtime"
	"ai_scheduler/internal/entitys"
	"ai_scheduler/internal/pkg/util"
	"ai_scheduler/internal/pkg/utils_ollama"
	"ai_scheduler/utils"

	"github.com/gofiber/fiber/v2"
	"github.com/google/uuid"
	"github.com/ollama/ollama/api"
	"github.com/redis/go-redis/v9"
)
// CapabilityService is the unified callback entry point for capability
// endpoints. It bundles the application configuration, the workflow registry
// used to dispatch to target-system workflows, and a Redis client used to
// cache intermediate drafts between calls.
type CapabilityService struct {
	cfg             *config.Config    // application configuration
	workflowManager *runtime.Registry // registry used to invoke workflows by ID
	rdsCli          *redis.Client     // Redis client for draft caching
}
// NewCapabilityService wires a CapabilityService from its dependencies.
// Only the underlying *redis.Client held in rdb.Rdb is retained, so rdb
// itself must be non-nil.
func NewCapabilityService(cfg *config.Config, workflowManager *runtime.Registry, rdb *utils.Rdb) *CapabilityService {
	svc := new(CapabilityService)
	svc.cfg = cfg
	svc.workflowManager = workflowManager
	svc.rdsCli = rdb.Rdb
	return svc
}
// ProductIngestReq is the request body accepted by ProductIngest.
type ProductIngestReq struct {
	// SysId identifies the target business system; only "hyt" (HuoYiTong) is
	// currently supported.
	SysId string `json:"sys_id"`
	// Url is the product detail page URL.
	Url string `json:"url"`
	// Title is the product title.
	Title string `json:"title"`
	// Text is the free-form product description.
	Text string `json:"text"`
	// Images lists product image URLs.
	Images []string `json:"images"`
}
// ProductIngestResp is the response returned by ProductIngest and the value
// cached for a later ProductIngestConfirm call.
type ProductIngestResp struct {
	// ThreadId identifies this ingest session; the confirm call needs it.
	ThreadId string `json:"thread_id"`
	// SysId is the target business system ID.
	SysId string `json:"sys_id"`
	// MetaData carries metadata about the original request.
	MetaData any `json:"meta"`
	// Draft is the extracted draft data; the confirm call needs it.
	Draft string `json:"draft"`
}
// ProductIngest extracts structured product data from free-form text (and
// optional image URLs) using the configured LLM.
//
// The request body is a ProductIngestReq; text and sys_id are required and
// only sys_id "hyt" is currently accepted. The model output is returned
// verbatim as the draft. Go maps are unordered, so the draft is left for the
// front end to parse. The draft is also cached in Redis for 30 minutes under a
// freshly generated thread_id, which the caller must pass to
// ProductIngestConfirm.
func (s *CapabilityService) ProductIngest(c *fiber.Ctx) error {
	// Use the request-scoped context so values and deadlines attached by
	// middleware reach the model and Redis calls.
	var ctx context.Context = c.UserContext()

	// Authenticate the caller.
	if err := s.checkRequestHeader(c); err != nil {
		return err
	}

	// Parse and validate the request body.
	req := ProductIngestReq{}
	if err := c.BodyParser(&req); err != nil {
		return errorcode.ParamErrf("invalid request body: %v", err)
	}
	if req.Text == "" || req.SysId == "" {
		return errorcode.ParamErrf("missing required fields")
	}

	// Map the target system to its Chinese product-property template.
	var sysProductPropertyTemplateZH string
	switch req.SysId {
	case "hyt": // HuoYiTong
		sysProductPropertyTemplateZH = constants.HYTProductPropertyTemplateZH
	default:
		return errorcode.ParamErrf("invalid sys_id")
	}

	// Build the prompt.
	// NOTE(review): the property template is sent with the "assistant" role;
	// confirm whether "system" or "user" was intended.
	messages := []api.Message{
		{Role: "system", Content: constants.SystemPrompt},
		{Role: "assistant", Content: fmt.Sprintf("目标属性模板:%s。", sysProductPropertyTemplateZH)},
		{Role: "user", Content: req.Text},
	}
	// Only mention images when some were supplied. Otherwise the model would
	// receive a bare label with nothing after it.
	if len(req.Images) > 0 {
		messages = append(messages, api.Message{
			Role:    "user",
			Content: "商品图片URL列表" + strings.Join(req.Images, ","),
		})
	}

	// Call the model.
	client, cleanup, err := utils_ollama.NewClient(s.cfg)
	if err != nil {
		return err
	}
	defer cleanup()

	res, err := client.Chat(ctx, messages)
	if err != nil {
		return err
	}

	// Assemble the response under a new thread ID.
	threadId := uuid.NewString()
	resp := &ProductIngestResp{
		ThreadId: threadId,
		SysId:    req.SysId,
		MetaData: req,
		Draft:    res.Message.Content,
	}

	respJson, err := json.Marshal(resp)
	if err != nil {
		return fmt.Errorf("marshal product ingest response: %w", err)
	}

	// Cache the draft so ProductIngestConfirm can look it up by thread ID.
	if err = s.rdsCli.Set(ctx, fmt.Sprintf(constants.CapabilityProductIngestCacheKey, threadId), respJson, 30*time.Minute).Err(); err != nil {
		return err
	}

	// Propagate serialization/write failures instead of silently returning nil.
	return c.JSON(resp)
}
// checkRequestHeader authenticates a capability request from its headers.
//
//   - X-Timestamp (optional): when present, it must pass util.IsInTimeWindow
//     with a 5-minute window; otherwise errorcode.AuthNotFound is returned.
//   - X-Source-Key (required): must equal constants.CapabilityProductIngestToken;
//     otherwise errorcode.KeyNotFound is returned.
//
// Both header values are whitespace-trimmed before checking.
//
// NOTE(review): because X-Timestamp is optional, a caller can skip the time
// window by omitting the header. Confirm this is intended.
func (s *CapabilityService) checkRequestHeader(c *fiber.Ctx) error {
	token := strings.TrimSpace(c.Get("X-Source-Key"))
	ts := strings.TrimSpace(c.Get("X-Timestamp"))

	// Time-window check (only when a timestamp is supplied).
	if ts != "" && !util.IsInTimeWindow(ts, 5*time.Minute) {
		return errorcode.AuthNotFound
	}

	// Compare the shared secret in constant time so response latency does not
	// reveal how many leading bytes matched.
	if token == "" || subtle.ConstantTimeCompare([]byte(token), []byte(constants.CapabilityProductIngestToken)) != 1 {
		return errorcode.KeyNotFound
	}
	return nil
}
// ProductIngestConfirmReq is the request body accepted by ProductIngestConfirm.
type ProductIngestConfirmReq struct {
	// ThreadId is the ingest thread ID.
	// NOTE(review): ProductIngestConfirm reads the thread ID from the URL path,
	// not from this field.
	ThreadId string `json:"thread_id"`
	// Confirmed is the user-confirmed product data, encoded as a JSON string.
	Confirmed string `json:"confirmed"`
}
// ProductIngestConfirm finalizes a product ingest started by ProductIngest.
//
// It takes the thread ID from the "thread_id" path parameter and loads the
// cached ProductIngestResp from Redis. It selects the workflow for the cached
// sys_id (only "hyt" is supported) and invokes it with the confirmed JSON
// string from the body as the recognized user text. The workflow result is
// returned as JSON.
//
// An unknown or expired thread yields a parameter error. Redis failures are
// returned as-is so they are not mistaken for client errors.
func (s *CapabilityService) ProductIngestConfirm(c *fiber.Ctx) error {
	// Use the request-scoped context so values and deadlines attached by
	// middleware reach Redis and the workflow.
	var ctx context.Context = c.UserContext()

	// Authenticate the caller.
	if err := s.checkRequestHeader(c); err != nil {
		return err
	}

	// The thread ID comes from the path, not the body.
	threadId := c.Params("thread_id")
	if threadId == "" {
		return errorcode.ParamErrf("missing required fields")
	}

	// Parse and validate the body.
	req := ProductIngestConfirmReq{}
	if err := c.BodyParser(&req); err != nil {
		return errorcode.ParamErrf("invalid request body: %v", err)
	}
	if req.Confirmed == "" {
		return errorcode.ParamErr("missing required fields")
	}

	// Load the cached ingest session. go-redis reports a missing key as the
	// sentinel redis.Nil; anything else is an infrastructure failure.
	cached, err := s.rdsCli.Get(ctx, fmt.Sprintf(constants.CapabilityProductIngestCacheKey, threadId)).Result()
	if err == redis.Nil {
		return errorcode.ParamErr("invalid thread_id")
	}
	if err != nil {
		return err
	}

	var respData ProductIngestResp
	if err = json.Unmarshal([]byte(cached), &respData); err != nil {
		return errorcode.ParamErr("invalid thread_id data")
	}

	// Map the target system to its workflow ID.
	var workflowId string
	switch respData.SysId {
	case "hyt": // HuoYiTong
		workflowId = hytWorkflow.WorkflowID
	default:
		return errorcode.ParamErr("invalid sys_id")
	}

	// Run the eino workflow that uploads the product to the target system.
	rec := &entitys.Recognize{UserContent: &entitys.RecognizeUserContent{Text: req.Confirmed}}
	res, err := s.workflowManager.Invoke(ctx, workflowId, rec)
	if err != nil {
		return err
	}
	return c.JSON(res)
}