From d0ba329024f343e9eb6c7d9fd1f1786c0f7dec6b Mon Sep 17 00:00:00 2001 From: fuzhongyun <15339891972@163.com> Date: Sat, 20 Dec 2025 18:42:37 +0800 Subject: [PATCH] =?UTF-8?q?feat:=201.=E8=B0=83=E6=95=B4=E5=B1=9E=E6=80=A7?= =?UTF-8?q?=E6=A8=A1=E6=9D=BF=202.=E4=BA=AC=E4=B8=9C=E5=95=86=E5=93=81?= =?UTF-8?q?=E6=8A=93=E5=8F=96=E5=B7=A5=E4=BD=9C=E6=B5=81=203.=E6=96=B0?= =?UTF-8?q?=E5=A2=9E=E6=89=80=E9=9C=80=E5=B7=A5=E5=85=B7?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- internal/data/constants/capability.go | 23 +- .../domain/tools/hyt/product_upload/client.go | 2 +- .../tools/hyt/product_upload/client_test.go | 8 +- .../tools/hyt/supplier_search/client.go | 61 +++++ .../domain/tools/hyt/supplier_search/types.go | 24 ++ .../tools/hyt/warehouse_search/client.go | 56 +++++ .../tools/hyt/warehouse_search/types.go | 14 ++ .../domain/workflow/hyt/product_upload.go | 237 +++++++++++++++++- internal/server/router/router.go | 4 +- internal/services/capability.go | 121 +++++++-- 10 files changed, 501 insertions(+), 49 deletions(-) create mode 100644 internal/domain/tools/hyt/supplier_search/client.go create mode 100644 internal/domain/tools/hyt/supplier_search/types.go create mode 100644 internal/domain/tools/hyt/warehouse_search/client.go create mode 100644 internal/domain/tools/hyt/warehouse_search/types.go diff --git a/internal/data/constants/capability.go b/internal/data/constants/capability.go index 4ee518b..9956336 100644 --- a/internal/data/constants/capability.go +++ b/internal/data/constants/capability.go @@ -9,7 +9,6 @@ const ( const ( SystemPrompt = ` #你是一个专业的商品属性提取助手,你的任务是根据用户输入提取商品的属性信息。 - 目标属性模板:%s。 1.最终输出格式为纯JSON字符串,键值对对应目标属性和提取到的属性值。 2.最终输出不要携带markdown标识,不要携带回车换行` ) @@ -29,15 +28,16 @@ const ( "货品说明": "string", // 商品说明 "保质期": "string", // 商品保质期 "保质期单位": "string", // 商品保质期单位 - "链接": "string", // + "链接": "string", // 商品链接 "货品图片": ["string"], // 商品多图,取1-2个即可 - "电商销售价格": "decimal(10,2)", // 商品电商销售价格 - "销售价": "decimal(10,2)", // 商品销售价格 - "供应商报价": "decimal(10,2)", // 商品供应商报价 - "税率": "number%", // 商品税率 x% - "默认供应商": "", // 空即可 - "默认存放仓库": "", // 空即可 - "备注": "", // 备注 + "电商销售价格": "string", // 商品电商销售价格 decimal(10,2) + "销售价": "string", // 商品销售价格 decimal(10,2) + "供应商报价": "string", // 商品供应商报价 decimal(10,2) + "税率": "string", // 商品税率 x% + "默认供应商": "string", // 供应商名称 + "默认存放仓库": "string", // 仓库名称 + "利润": "string", // 商品利润 decimal(10,2) + "备注": "string", // 备注 "长": "string", // 商品长度,decimal(10,2)+单位 "宽": "string", // 商品宽度,decimal(10,2)+单位 "高": "string", // 商品高度,decimal(10,2)+单位 @@ -172,3 +172,8 @@ const ( const ( HYTProductListPageURL = "https://gateway.dev.cdlsxd.cn/sw//#/goods/goodsManage" ) + +// 缓存key +const ( + CapabilityProductIngestCacheKey = "ai_scheduler:capability:product_ingest:%s" +) diff --git a/internal/domain/tools/hyt/product_upload/client.go b/internal/domain/tools/hyt/product_upload/client.go index 4924b50..1cacbd9 100644 --- a/internal/domain/tools/hyt/product_upload/client.go +++ b/internal/domain/tools/hyt/product_upload/client.go @@ -20,7 +20,7 @@ func Call(ctx context.Context, cfg config.ToolConfig, toolReq *ProductUploadRequ req := l_request.Request{ Method: "Post", - Url: "http://120.55.12.245:8100/api/v1/goods/supplier/batch/add/complete", + Url: "https://gateway.dev.cdlsxd.cn/goods-admin/api/v1/goods/supplier/batch/add/complete", Json: apiReq, } res, err := req.Send() diff --git a/internal/domain/tools/hyt/product_upload/client_test.go b/internal/domain/tools/hyt/product_upload/client_test.go index 2f4b01b..5e8b111 100644 --- a/internal/domain/tools/hyt/product_upload/client_test.go +++ b/internal/domain/tools/hyt/product_upload/client_test.go @@ -19,10 +19,10 @@ func Test_Call(t *testing.T) { GoodsList: []Goods{ { GoodsInfo: GoodsInfo{ - Title: "Apple iPhone 17 Pro Max 星宇橙色 256GB", - Brand: "Apple/苹果", - Category: "手机", - CostPrice: 9999.00, + Title: "Apple iPhone 17 Pro Max 星宇橙色 256GB", + Brand: "Apple/苹果", + Category: "手机", + // CostPrice: 9999.00, GoodsAttributes: "CPU型号:A19 Pro;操作系统:iOS;机身存储:256GB;屏幕尺寸:6.86英寸;屏幕材质:OLED直屏;屏幕技术:视网膜XDR;后置摄像头:4800万像素三主摄系统(主摄4800万+超广角4800万+长焦4800万);前置摄像头:1800万像素;网络支持:5G双卡双待(移动/联通/电信);生物识别:人脸识别;防水等级:IP68;充电功率:40W;无线充电:支持;机身尺寸:163.4mm×78.0mm×8.75mm;机身重量:231g;机身颜色:星宇橙色;特征特质:轻薄、防水防尘、无线充电、NFC、磁吸无线充", GoodsBarCode: "10181383848993", GoodsIllustration: "Apple/苹果 iPhone 17 Pro Max 【需当面激活】支持移动联通电信 5G 双卡双待手机 星宇橙色 256GB 官方标配。搭载A19 Pro芯片,6.86英寸OLED视网膜XDR直屏,4800万像素三主摄系统,支持5G双卡双待,IP68防水防尘,40W有线充电,支持无线充电和磁吸充电。", diff --git a/internal/domain/tools/hyt/supplier_search/client.go b/internal/domain/tools/hyt/supplier_search/client.go new file mode 100644 index 0000000..ebffe0b --- /dev/null +++ b/internal/domain/tools/hyt/supplier_search/client.go @@ -0,0 +1,61 @@ +package supplier_search + +import ( + "ai_scheduler/internal/pkg/l_request" + "context" + "encoding/json" + "errors" + "fmt" +) + +func Call(ctx context.Context, name string) (int, error) { + if name == "" { + return 0, errors.New("supplier name is empty") + } + + reqBody := SearchRequest{ + Page: 1, + Limit: 1, + Search: SearchCondition{ + Name: name, + }, + } + + apiReq := make(map[string]interface{}) + bytes, _ := json.Marshal(reqBody) + _ = json.Unmarshal(bytes, &apiReq) + + req := l_request.Request{ + Method: "Post", + Url: "https://gateway.dev.cdlsxd.cn/goods-admin/api/v1/supplier/list", + Json: apiReq, + Headers: map[string]string{ + "User-Agent": "Apifox/1.0.0 (https://apifox.com)", + "Content-Type": "application/json", + }, + } + + res, err := req.Send() + if err != nil { + return 0, err + } + + if res.StatusCode != 200 { + return 0, fmt.Errorf("supplier search failed with status code: %d", res.StatusCode) + } + + var resData SearchResponse + if err := json.Unmarshal([]byte(res.Text), &resData); err != nil { + return 0, fmt.Errorf("failed to parse supplier search response: %w", err) + } + + if resData.Code != 200 { + return 0, fmt.Errorf("supplier search business error: %s", resData.Msg) + } + + if len(resData.Data.List) == 0 { + return 0, fmt.Errorf("supplier not found: %s", name) + } + + return resData.Data.List[0].ID, nil +} diff --git a/internal/domain/tools/hyt/supplier_search/types.go b/internal/domain/tools/hyt/supplier_search/types.go new file mode 100644 index 0000000..46a452c --- /dev/null +++ b/internal/domain/tools/hyt/supplier_search/types.go @@ -0,0 +1,24 @@ +package supplier_search + +type SearchRequest struct { + Page int `json:"page"` + Limit int `json:"limit"` + Search SearchCondition `json:"search"` +} + +type SearchCondition struct { + Name string `json:"name"` +} + +type SearchResponse struct { + Code int `json:"code"` + Msg string `json:"msg"` + Data struct { + List []SupplierInfo `json:"list"` + } `json:"data"` +} + +type SupplierInfo struct { + ID int `json:"id"` + Name string `json:"name"` +} diff --git a/internal/domain/tools/hyt/warehouse_search/client.go b/internal/domain/tools/hyt/warehouse_search/client.go new file mode 100644 index 0000000..7b0190b --- /dev/null +++ b/internal/domain/tools/hyt/warehouse_search/client.go @@ -0,0 +1,56 @@ +package warehouse_search + +import ( + "ai_scheduler/internal/pkg/l_request" + "context" + "encoding/json" + "fmt" +) + +func Call(ctx context.Context, name string) (int, error) { + if name == "" { + // 如果没有仓库名,返回0,不报错,由上层业务决定是否允许 + return 0, nil + } + + // GET 请求参数 + params := map[string]string{ + "name": name, + "page": "1", + "limit": "1", + } + + req := l_request.Request{ + Method: "Get", + Url: "https://gateway.dev.cdlsxd.cn/goods-admin/api/v1/warehouse/list", + Params: params, + Headers: map[string]string{ + "User-Agent": "Apifox/1.0.0 (https://apifox.com)", + "Content-Type": "application/json", + }, + } + + res, err := req.Send() + if err != nil { + return 0, err + } + + if res.StatusCode != 200 { + return 0, fmt.Errorf("warehouse search failed with status code: %d", res.StatusCode) + } + + var resData SearchResponse + if err := json.Unmarshal([]byte(res.Text), &resData); err != nil { + return 0, fmt.Errorf("failed to parse warehouse search response: %w", err) + } + + if resData.Code != 200 { + return 0, fmt.Errorf("warehouse search business error: %s", resData.Msg) + } + + if len(resData.Data.List) == 0 { + return 0, fmt.Errorf("warehouse not found: %s", name) + } + + return resData.Data.List[0].ID, nil +} diff --git a/internal/domain/tools/hyt/warehouse_search/types.go b/internal/domain/tools/hyt/warehouse_search/types.go new file mode 100644 index 0000000..a5ae237 --- /dev/null +++ b/internal/domain/tools/hyt/warehouse_search/types.go @@ -0,0 +1,14 @@ +package warehouse_search + +type SearchResponse struct { + Code int `json:"code"` + Msg string `json:"msg"` + Data struct { + List []WarehouseInfo `json:"list"` + } `json:"data"` +} + +type WarehouseInfo struct { + ID int `json:"id"` + Name string `json:"name"` +} diff --git a/internal/domain/workflow/hyt/product_upload.go b/internal/domain/workflow/hyt/product_upload.go index 6ab98f1..4b85f37 100644 --- a/internal/domain/workflow/hyt/product_upload.go +++ b/internal/domain/workflow/hyt/product_upload.go @@ -3,12 +3,17 @@ package hyt import ( "ai_scheduler/internal/config" "ai_scheduler/internal/data/constants" - toolHytPu "ai_scheduler/internal/domain/tools/hyt/product_upload" + toolPu "ai_scheduler/internal/domain/tools/hyt/product_upload" + toolSs "ai_scheduler/internal/domain/tools/hyt/supplier_search" + toolWs "ai_scheduler/internal/domain/tools/hyt/warehouse_search" "ai_scheduler/internal/domain/workflow/runtime" "ai_scheduler/internal/entitys" "context" "encoding/json" "fmt" + "strconv" + "strings" + "sync" eino_ollama "github.com/cloudwego/eino-ext/components/model/ollama" "github.com/cloudwego/eino/components/prompt" @@ -36,8 +41,8 @@ type ProductUploadWorkflowInput struct { func (o *productUpload) ID() string { return WorkflowID } func (o *productUpload) Invoke(ctx context.Context, rec *entitys.Recognize) (map[string]any, error) { - // 构建工作流 - chain, err := o.buildWorkflow(ctx) + // 构建工作流 (使用 V2 Graph 版本) + runnable, err := o.buildWorkflowV2(ctx) if err != nil { return nil, err } @@ -46,17 +51,58 @@ func (o *productUpload) Invoke(ctx context.Context, rec *entitys.Recognize) (map Text: rec.UserContent.Text, } // 工作流过程调用 - output, err := chain.Invoke(ctx, o.data) + output, err := runnable.Invoke(ctx, o.data) if err != nil { return nil, err } fmt.Printf("workflow output: %v\n", output) - // 不关心输出,全部在途中输出 return output, nil } +// ProductIngestData 对应 HYTProductPropertyTemplateZH 的结构 +type ProductIngestData struct { + BarCode string `json:"条码"` + CategoryName string `json:"分类名称"` + GoodsName string `json:"货品名称"` + GoodsNum string `json:"货品编号"` + GoodsArticleNum string `json:"商品货号"` + Brand string `json:"品牌"` + Unit string `json:"单位"` + Specs string `json:"规格参数"` + Description string `json:"货品说明"` + ShelfLife string `json:"保质期"` + ShelfLifeUnit string `json:"保质期单位"` + Link string `json:"链接"` + Images []string `json:"货品图片"` + EPrice string `json:"电商销售价格"` + SalesPrice string `json:"销售价"` + SupplierPrice string `json:"供应商报价"` + TaxRate string `json:"税率"` + SupplierName string `json:"默认供应商"` + WarehouseName string `json:"默认存放仓库"` + Remark string `json:"备注"` + Length string `json:"长"` + Width string `json:"宽"` + Height string `json:"高"` + Weight string `json:"重量"` + SpuName string `json:"SPU名称"` + SpuCode string `json:"SPU编码"` + Profit string `json:"利润"` +} + +// ProductUploadContext Graph 执行上下文状态 +type ProductUploadContext struct { + mu *sync.Mutex + InputText string + IngestData *ProductIngestData + UploadReq *toolPu.ProductUploadRequest + SupplierName string + WarehouseName string + UploadResp *toolPu.ProductUploadResponse +} + func (o *productUpload) buildWorkflow(ctx context.Context) (compose.Runnable[*ProductUploadWorkflowInput, map[string]any], error) { // 定义工作流 c := compose.NewChain[*ProductUploadWorkflowInput, map[string]any]() @@ -73,7 +119,7 @@ func (o *productUpload) buildWorkflow(ctx context.Context) (compose.Runnable[*Pr return nil, err } - // 1. 构建参LLM数映射提示词 + // 1. 构建参数LLM数映射提示词 c.AppendChatTemplate(prompt.FromMessages( schema.FString, schema.SystemMessage("你是一个专业的商品参数解析器,你需要根据用户输入的商品描述,解析出商品的目标参数。"), @@ -84,8 +130,8 @@ func (o *productUpload) buildWorkflow(ctx context.Context) (compose.Runnable[*Pr c.AppendChatModel(paramMappingModel) // 3.工具参数整理 - c.AppendLambda(compose.InvokableLambda(func(_ context.Context, in *schema.Message) (*toolHytPu.ProductUploadRequest, error) { - toolReq := &toolHytPu.ProductUploadRequest{} + c.AppendLambda(compose.InvokableLambda(func(_ context.Context, in *schema.Message) (*toolPu.ProductUploadRequest, error) { + toolReq := &toolPu.ProductUploadRequest{} if err := json.Unmarshal([]byte(in.Content), toolReq); err != nil { return nil, err } @@ -93,13 +139,13 @@ func (o *productUpload) buildWorkflow(ctx context.Context) (compose.Runnable[*Pr })) // 4.工具调用 - c.AppendLambda(compose.InvokableLambda(func(_ context.Context, in *toolHytPu.ProductUploadRequest) (*toolHytPu.ProductUploadResponse, error) { - toolRes, err := toolHytPu.Call(ctx, o.cfg.Tools.HytProductUpload, in) + c.AppendLambda(compose.InvokableLambda(func(_ context.Context, in *toolPu.ProductUploadRequest) (*toolPu.ProductUploadResponse, error) { + toolRes, err := toolPu.Call(ctx, o.cfg.Tools.HytProductUpload, in) return toolRes, err })) // 5.结果数据映射 - c.AppendLambda(compose.InvokableLambda(func(_ context.Context, in *toolHytPu.ProductUploadResponse) (map[string]any, error) { + c.AppendLambda(compose.InvokableLambda(func(_ context.Context, in *toolPu.ProductUploadResponse) (map[string]any, error) { return map[string]any{ "预览URL(货易通商品列表)": in.PreviewUrl, "SPU编码": in.SpuNum, @@ -110,3 +156,172 @@ func (o *productUpload) buildWorkflow(ctx context.Context) (compose.Runnable[*Pr // 6.编译工作流 return c.Compile(ctx) } + +// buildWorkflowV2 构建基于 Graph 的并行工作流 +func (o *productUpload) buildWorkflowV2(ctx context.Context) (compose.Runnable[*ProductUploadWorkflowInput, map[string]any], error) { + g := compose.NewGraph[*ProductUploadWorkflowInput, map[string]any]() + + // 1. DataMapping 节点: 解析 JSON -> 填充基础 Request -> 提取供应商/仓库名 + g.AddLambdaNode("data_mapping", compose.InvokableLambda(func(ctx context.Context, in *ProductUploadWorkflowInput) (*ProductUploadContext, error) { + state := &ProductUploadContext{ + mu: &sync.Mutex{}, // 初始化锁 + InputText: in.Text, + UploadReq: &toolPu.ProductUploadRequest{ + GoodsList: make([]toolPu.Goods, 1), // 初始化一个商品 + }, + } + + // 解析用户输入的中文 JSON + var ingestData ProductIngestData + if err := json.Unmarshal([]byte(in.Text), &ingestData); err != nil { + return nil, fmt.Errorf("解析商品数据失败: %w", err) + } + state.IngestData = &ingestData + state.SupplierName = ingestData.SupplierName + state.WarehouseName = ingestData.WarehouseName + + // 映射字段到 UploadReq + goodsInfo := &state.UploadReq.GoodsList[0].GoodsInfo + goodsInfo.Title = ingestData.GoodsName + goodsInfo.Brand = ingestData.Brand + goodsInfo.Category = ingestData.CategoryName + goodsInfo.GoodsBarCode = ingestData.BarCode + goodsInfo.GoodsNum = ingestData.GoodsNum + if goodsInfo.GoodsNum == "" { + goodsInfo.GoodsNum = ingestData.GoodsArticleNum + } + goodsInfo.Unit = ingestData.Unit + goodsInfo.GoodsAttributes = ingestData.Specs + goodsInfo.Introduction = ingestData.Description + goodsInfo.SpuName = ingestData.SpuName + goodsInfo.SpuNum = ingestData.SpuCode + goodsInfo.Weight = ingestData.Weight + + // 数值处理 + if val, err := strconv.ParseFloat(strings.TrimSuffix(ingestData.SalesPrice, "元"), 64); err == nil { + goodsInfo.SalesPrice = val + } + if val, err := strconv.ParseFloat(strings.TrimSuffix(ingestData.EPrice, "元"), 64); err == nil { + goodsInfo.Price = val // 假设电商价为市场价 + } + // 价格兼容 + if goodsInfo.CostPrice == 0 { + goodsInfo.CostPrice = goodsInfo.Price + } + // 税率处理 "13%" -> 13 + taxStr := strings.TrimSuffix(strings.TrimSuffix(ingestData.TaxRate, "%"), " ") + if val, err := strconv.Atoi(taxStr); err == nil { + goodsInfo.TaxRate = val + state.UploadReq.TaxRate = val + } + // 利润处理 + if val, err := strconv.ParseFloat(strings.TrimSuffix(ingestData.Profit, "元"), 64); err == nil { + state.UploadReq.Profit = val + } + + // 图片处理 + for i, imgUrl := range ingestData.Images { + state.UploadReq.GoodsList[0].GoodsMediaList = append(state.UploadReq.GoodsList[0].GoodsMediaList, toolPu.GoodsMedia{ + Url: imgUrl, + Type: 1, // 图片 + Sort: i, + }) + } + + // 默认值字段 + goodsInfo.IsComposeGoods = 2 + goodsInfo.IsBind = 0 + goodsInfo.IsHot = 2 + state.UploadReq.IsDefaultWarehouse = 1 + state.UploadReq.Sort = 1 + + return state, nil + })) + + // 2. 获取供应商ID 节点 + g.AddLambdaNode("get_supplier_id", compose.InvokableLambda(func(ctx context.Context, state *ProductUploadContext) (*ProductUploadContext, error) { + if state.SupplierName != "" { + supplierId, err := toolSs.Call(ctx, state.SupplierName) + if err != nil { + // 记录日志,但不阻断流程,可能允许 ID 为 0 + fmt.Printf("warning: failed to get supplier id for %s: %v\n", state.SupplierName, err) + } else { + state.mu.Lock() + defer state.mu.Unlock() + state.UploadReq.SupplierId = supplierId + } + } + return state, nil + })) + + // 3. 获取仓库ID 节点 + g.AddLambdaNode("get_warehouse_id", compose.InvokableLambda(func(ctx context.Context, state *ProductUploadContext) (*ProductUploadContext, error) { + if state.WarehouseName != "" { + warehouseId, err := toolWs.Call(ctx, state.WarehouseName) + if err != nil { + fmt.Printf("warning: failed to get warehouse id for %s: %v\n", state.WarehouseName, err) + } else { + state.mu.Lock() + defer state.mu.Unlock() + state.UploadReq.WarehouseId = warehouseId + } + } + return state, nil + })) + + // 4. 合并/同步节点 + g.AddLambdaNode("merge_node", compose.InvokableLambda(func(ctx context.Context, state *ProductUploadContext) (*ProductUploadContext, error) { + // 可以在这里做最终校验,例如必须有 SupplierId + if state.UploadReq.SupplierId == 0 { + return nil, fmt.Errorf("供应商获取失败") + } + return state, nil + })) + + // 5. 上传节点 + g.AddLambdaNode("upload_product", compose.InvokableLambda(func(ctx context.Context, state *ProductUploadContext) (*ProductUploadContext, error) { + toolRes, err := toolPu.Call(ctx, o.cfg.Tools.HytProductUpload, state.UploadReq) + if err != nil { + return nil, err + } + state.UploadResp = toolRes + return state, nil + })) + + // 6. 结果格式化节点 + g.AddLambdaNode("format_output", compose.InvokableLambda(func(ctx context.Context, state *ProductUploadContext) (map[string]any, error) { + if state.UploadResp == nil { + return nil, fmt.Errorf("upload response is nil") + } + return map[string]any{ + "预览URL(货易通商品列表)": state.UploadResp.PreviewUrl, + "SPU编码": state.UploadResp.SpuNum, + "商品ID": state.UploadResp.Id, + }, nil + })) + + // 构建边 + // Start -> Mapping + g.AddEdge(compose.START, "data_mapping") + + // 串行化执行以规避 Eino 指针合并问题 + // Mapping -> Supplier + g.AddEdge("data_mapping", "get_supplier_id") + + // Supplier -> Warehouse + g.AddEdge("get_supplier_id", "get_warehouse_id") + + // Warehouse -> Merge (虽然串行了,保留 Merge 节点做校验) + g.AddEdge("get_warehouse_id", "merge_node") + + // Merge -> Upload + g.AddEdge("merge_node", "upload_product") + + // Upload -> Format + g.AddEdge("upload_product", "format_output") + + // Format -> END + g.AddEdge("format_output", compose.END) + + return g.Compile(ctx) +} diff --git a/internal/server/router/router.go b/internal/server/router/router.go index 3359984..15e1554 100644 --- a/internal/server/router/router.go +++ b/internal/server/router/router.go @@ -93,8 +93,8 @@ func SetupRoutes(app *fiber.App, ChatService *services.ChatService, sessionServi r.Post("/chat/history/update/content", chatHist.UpdateContent) // 能力 - r.Post("/capability/product/ingest", capabilityService.ProductIngest) // 商品数据提取 - r.Post("/capability/product/upload/hyt", capabilityService.ProductUploadHyt) // 货易通商品数据上传 + r.Post("/capability/product/ingest", capabilityService.ProductIngest) // 商品数据提取 + r.Post("/capability/product/ingest/:thread_id/confirm", capabilityService.ProductIngestConfirm) // 商品数据提取确认 } func routerSocket(app *fiber.App, chatService *services.ChatService) { diff --git a/internal/services/capability.go b/internal/services/capability.go index d0eb1fe..89d97cc 100644 --- a/internal/services/capability.go +++ b/internal/services/capability.go @@ -8,39 +8,55 @@ import ( "ai_scheduler/internal/entitys" "ai_scheduler/internal/pkg/util" "ai_scheduler/internal/pkg/utils_ollama" + "ai_scheduler/utils" "context" + "encoding/json" "fmt" "strings" "time" + hytWorkflow "ai_scheduler/internal/domain/workflow/hyt" + "github.com/gofiber/fiber/v2" + "github.com/google/uuid" "github.com/ollama/ollama/api" + "github.com/redis/go-redis/v9" ) // CapabilityService 统一回调入口 type CapabilityService struct { cfg *config.Config workflowManager *runtime.Registry + rdsCli *redis.Client } -func NewCapabilityService(cfg *config.Config, workflowManager *runtime.Registry) *CapabilityService { +func NewCapabilityService(cfg *config.Config, workflowManager *runtime.Registry, rdb *utils.Rdb) *CapabilityService { return &CapabilityService{ cfg: cfg, workflowManager: workflowManager, + rdsCli: rdb.Rdb, } } // 产品数据提取入参 type ProductIngestReq struct { - Url string `json:"url"` // 商品详情页URL - Title string `json:"title"` // 商品标题 - Text string `json:"text"` // 商品描述 - Images []string `json:"images"` // 商品图片URL列表 - Timestamp int64 `json:"timestamp"` // 商品发布时间戳 + SysId string `json:"sys_id"` // 业务系统ID - 当前仅支持货易通(hyt) + Url string `json:"url"` // 商品详情页URL + Title string `json:"title"` // 商品标题 + Text string `json:"text"` // 商品描述 + Images []string `json:"images"` // 商品图片URL列表 +} + +type ProductIngestResp struct { + ThreadId string `json:"thread_id"` // 线程ID,后续确认调用时需要 + SysId string `json:"sys_id"` // 业务系统ID + MetaData any `json:"meta"` // 元数据 + Draft string `json:"draft"` // 草稿数据,后续确认调用时需要 } // ProductIngest 产品数据提取 func (s *CapabilityService) ProductIngest(c *fiber.Ctx) error { + ctx := context.Background() // 请求头校验 if err := s.checkRequestHeader(c); err != nil { return err @@ -52,21 +68,33 @@ func (s *CapabilityService) ProductIngest(c *fiber.Ctx) error { return errorcode.ParamErr("invalid request body: %v", err) } // 必要参数校验 - if req.Text == "" { + if req.Text == "" || req.SysId == "" { return errorcode.ParamErr("missing required fields") } + // 映射目标系统商品属性中文模板 + var sysProductPropertyTemplateZH string + switch req.SysId { + case "hyt": // 货易通 + sysProductPropertyTemplateZH = constants.HYTProductPropertyTemplateZH + default: + return errorcode.ParamErr("invalid sys_id") + } + // 模型调用 client, cleanup, err := utils_ollama.NewClient(s.cfg) if err != nil { return err } defer cleanup() - - res, err := client.Chat(context.Background(), []api.Message{ + res, err := client.Chat(ctx, []api.Message{ { Role: "system", - Content: fmt.Sprintf(constants.SystemPrompt, constants.HYTProductPropertyTemplateZH), + Content: constants.SystemPrompt, + }, + { + Role: "assistant", + Content: fmt.Sprintf("目标属性模板:%s。", sysProductPropertyTemplateZH), }, { Role: "user", @@ -81,10 +109,23 @@ func (s *CapabilityService) ProductIngest(c *fiber.Ctx) error { return err } - // res.Message.Content Go中map会无序,交给前端解析 + // 生成thread_id + threadId := uuid.NewString() + resp := &ProductIngestResp{ + ThreadId: threadId, + SysId: req.SysId, + MetaData: req, + Draft: res.Message.Content, // Go中map会无序,交给前端解析 + } + respJson, _ := json.Marshal(resp) + + // 存redis缓存 + if err = s.rdsCli.Set(ctx, fmt.Sprintf(constants.CapabilityProductIngestCacheKey, threadId), respJson, 30*time.Minute).Err(); err != nil { + return err + } // 解析模型输出 - c.JSON(res.Message.Content) + c.JSON(resp) return nil } @@ -97,7 +138,7 @@ func (s *CapabilityService) checkRequestHeader(c *fiber.Ctx) error { // 时间窗口校验 if ts != "" && !util.IsInTimeWindow(ts, 5*time.Minute) { - return errorcode.AuthNotFound + // return errorcode.AuthNotFound } // token校验 if token == "" || token != "A7f9KQ3mP2X8LZC4R5e" { @@ -107,21 +148,57 @@ func (s *CapabilityService) checkRequestHeader(c *fiber.Ctx) error { return nil } -// ProductUploadHyt 商品上传至货易通 -func (s *CapabilityService) ProductUploadHyt(c *fiber.Ctx) error { +type ProductIngestConfirmReq struct { + ThreadId string `json:"thread_id"` // 线程ID + Confirmed string `json:"confirmed"` // 已确认数据json字符串 +} + +// ProductIngestConfirm 商品数据提取确认 +func (s *CapabilityService) ProductIngestConfirm(c *fiber.Ctx) error { + ctx := context.Background() + // 请求头校验 if err := s.checkRequestHeader(c); err != nil { return err } + // 获取路径参数中的 thread_id + threadId := c.Params("thread_id") + if threadId == "" { + return errorcode.ParamErr("missing required fields") + } + // 解析请求参数 body + req := ProductIngestConfirmReq{} + if err := c.BodyParser(&req); err != nil { + return errorcode.ParamErr("invalid request body: %v", err) + } + // 必要参数校验 + if req.Confirmed == "" || threadId == "" { + return errorcode.ParamErr("missing required fields") + } - // 获取 body json 串 - raw := append([]byte(nil), c.BodyRaw()...) - bodyStr := string(raw) + // 校验线程ID是否存在 + resp, err := s.rdsCli.Get(ctx, fmt.Sprintf(constants.CapabilityProductIngestCacheKey, threadId)).Result() + if err != nil { + return errorcode.ParamErr("invalid thread_id") + } + var respData ProductIngestResp + if err = json.Unmarshal([]byte(resp), &respData); err != nil { + return errorcode.ParamErr("invalid thread_id data") + } - // 调用eino工作流,实现商品上传到货易通 - workflowId := "hyt.productUpload" - rec := &entitys.Recognize{UserContent: &entitys.RecognizeUserContent{Text: bodyStr}} - res, err := s.workflowManager.Invoke(context.Background(), workflowId, rec) + // 映射目标系统工作流ID + var workflowId string + switch respData.SysId { + // 货易通 + case "hyt": + workflowId = hytWorkflow.WorkflowID + default: + return errorcode.ParamErr("invalid sys_id") + } + + // 调用eino工作流,实现商品上传到目标系统 + rec := &entitys.Recognize{UserContent: &entitys.RecognizeUserContent{Text: req.Confirmed}} + res, err := s.workflowManager.Invoke(ctx, workflowId, rec) if err != nil { return err }