From 284624bcba2a6c1e5d9dc9b916d05070a57089d5 Mon Sep 17 00:00:00 2001 From: fuzhongyun <15339891972@163.com> Date: Thu, 18 Dec 2025 18:19:46 +0800 Subject: [PATCH] =?UTF-8?q?feat:=201.=20=E6=96=B0=E5=A2=9E=20ollamaClient?= =?UTF-8?q?=20chat=E6=96=B9=E6=B3=95=202.=20=E5=A2=9E=E5=8A=A0=E4=BA=A7?= =?UTF-8?q?=E5=93=81=E6=95=B0=E6=8D=AE=E6=8F=90=E5=8F=96=E8=83=BD=E5=8A=9B?= =?UTF-8?q?=E6=8E=A5=E5=8F=A3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- internal/data/error/error_code.go | 4 + internal/domain/llm/options.go | 21 +-- .../domain/llm/provider/ollama/adapter.go | 9 +- internal/pkg/util/time.go | 41 ++++++ internal/pkg/utils_ollama/client.go | 19 +++ internal/server/http.go | 15 +- internal/server/router/router.go | 15 +- internal/services/callback.go | 69 +++++----- internal/services/capability.go | 129 ++++++++++++++++++ internal/services/provider_set.go | 4 +- 10 files changed, 266 insertions(+), 60 deletions(-) create mode 100644 internal/pkg/util/time.go create mode 100644 internal/services/capability.go diff --git a/internal/data/error/error_code.go b/internal/data/error/error_code.go index 9c28865..9d907e7 100644 --- a/internal/data/error/error_code.go +++ b/internal/data/error/error_code.go @@ -54,3 +54,7 @@ func ParamErr(message string, arg ...any) *BusinessErr { func (e *BusinessErr) Wrap(err error) *BusinessErr { return NewBusinessErr(e.code, err.Error()) } + +func KeyErr() *BusinessErr { + return &BusinessErr{code: KeyNotFound.code, message: KeyNotFound.message} +} diff --git a/internal/domain/llm/options.go b/internal/domain/llm/options.go index c427153..19f8983 100644 --- a/internal/domain/llm/options.go +++ b/internal/domain/llm/options.go @@ -3,14 +3,15 @@ package llm import "time" type Options struct { - Temperature float32 - MaxTokens int - Stream bool - Timeout time.Duration - Modalities []string - SystemPrompt string - Model string - TopP float32 - Stop []string - Endpoint string + Temperature float32 + MaxTokens int + Stream bool + Timeout time.Duration + Modalities []string + SystemPrompt string + Model string + TopP float32 + Stop []string + Endpoint string + Thinking bool } diff --git a/internal/domain/llm/provider/ollama/adapter.go b/internal/domain/llm/provider/ollama/adapter.go index 1c26bab..554bcf0 100644 --- a/internal/domain/llm/provider/ollama/adapter.go +++ b/internal/domain/llm/provider/ollama/adapter.go @@ -15,10 +15,11 @@ func New() *Adapter { return &Adapter{} } func (a *Adapter) Generate(ctx context.Context, input []*schema.Message, opts llm.Options) (*schema.Message, error) { cm, err := eino_ollama.NewChatModel(ctx, &eino_ollama.ChatModelConfig{ - BaseURL: opts.Endpoint, - Timeout: opts.Timeout, - Model: opts.Model, - Options: &eino_ollama.Options{Temperature: opts.Temperature, NumPredict: opts.MaxTokens}, + BaseURL: opts.Endpoint, + Timeout: opts.Timeout, + Model: opts.Model, + Options: &eino_ollama.Options{Temperature: opts.Temperature, NumPredict: opts.MaxTokens}, + Thinking: &eino_ollama.ThinkValue{Value: opts.Thinking}, }) if err != nil { return nil, err diff --git a/internal/pkg/util/time.go b/internal/pkg/util/time.go new file mode 100644 index 0000000..1c1bf3e --- /dev/null +++ b/internal/pkg/util/time.go @@ -0,0 +1,41 @@ +package util + +import "time" + +// 判断当前时间是否在时间窗口内 +// ts 时间戳字符串,支持秒级或毫秒级 +// window 时间窗口,例如 10 * time.Minute +func IsInTimeWindow(ts string, window time.Duration) bool { + // 期望毫秒时间戳或秒级,简单容错 + // 尝试解析为整数 + var n int64 + for _, base := range []int64{1, 1000} { // 秒或毫秒 + if v, ok := parseInt64(ts); ok { + n = v + // 归一为毫秒 + if base == 1 && len(ts) <= 10 { + n = n * 1000 + } + now := time.Now().UnixMilli() + diff := now - n + if diff < 0 { + diff = -diff + } + if diff <= window.Milliseconds() { + return true + } + } + } + return false +} + +func parseInt64(s string) (int64, bool) { + var n int64 + for _, ch := range s { + if ch < '0' || ch > '9' { + return 0, false + } + n = n*10 + int64(ch-'0') + } + return n, true +} diff --git a/internal/pkg/utils_ollama/client.go b/internal/pkg/utils_ollama/client.go index 91640f0..1f67774 100644 --- a/internal/pkg/utils_ollama/client.go +++ b/internal/pkg/utils_ollama/client.go @@ -90,6 +90,25 @@ func (c *Client) ChatStream(ctx context.Context, ch chan entitys.Response, messa return } +func (c *Client) Chat(ctx context.Context, messages []api.Message) (res api.ChatResponse, err error) { + // 构建聊天请求 + req := &api.ChatRequest{ + Model: c.config.Model, + Messages: messages, + Stream: new(bool), // 设置为false,不使用流式响应 + Think: &api.ThinkValue{Value: true}, + } + err = c.client.Chat(ctx, req, func(resp api.ChatResponse) error { + res = resp + return nil + }) + if err != nil { + return + } + + return +} + func (c *Client) Generation(ctx context.Context, generateRequest *api.GenerateRequest) (result api.GenerateResponse, err error) { err = c.client.Generate(ctx, generateRequest, func(resp api.GenerateResponse) error { result = resp diff --git a/internal/server/http.go b/internal/server/http.go index fd7e49e..53446c8 100644 --- a/internal/server/http.go +++ b/internal/server/http.go @@ -11,11 +11,13 @@ import ( ) type HTTPServer struct { - app *fiber.App - service *services.ChatService - session *services.SessionService - gateway *gateway.Gateway - callback *services.CallbackService + app *fiber.App + service *services.ChatService + session *services.SessionService + gateway *gateway.Gateway + callback *services.CallbackService + chatHis *services.HistoryService + capabilityService *services.CapabilityService } func NewHTTPServer( @@ -25,10 +27,11 @@ func NewHTTPServer( gateway *gateway.Gateway, callback *services.CallbackService, chatHis *services.HistoryService, + capabilityService *services.CapabilityService, ) *fiber.App { //构建 server app := initRoute() - router.SetupRoutes(app, service, session, task, gateway, callback, chatHis) + router.SetupRoutes(app, service, session, task, gateway, callback, chatHis, capabilityService) return app } diff --git a/internal/server/router/router.go b/internal/server/router/router.go index e2645bb..be861aa 100644 --- a/internal/server/router/router.go +++ b/internal/server/router/router.go @@ -15,16 +15,18 @@ import ( ) type RouterServer struct { - app *fiber.App - service *services.ChatService - session *services.SessionService - gateway *gateway.Gateway - chatHist *services.HistoryService + app *fiber.App + service *services.ChatService + session *services.SessionService + gateway *gateway.Gateway + chatHist *services.HistoryService + capabilityService *services.CapabilityService } // SetupRoutes 设置路由 func SetupRoutes(app *fiber.App, ChatService *services.ChatService, sessionService *services.SessionService, task *services.TaskService, gateway *gateway.Gateway, callbackService *services.CallbackService, chatHist *services.HistoryService, + capabilityService *services.CapabilityService, ) { app.Use(func(c *fiber.Ctx) error { // 设置 CORS 头 @@ -84,6 +86,9 @@ func SetupRoutes(app *fiber.App, ChatService *services.ChatService, sessionServi // 会话历史 r.Post("/chat/history/list", chatHist.List) r.Post("/chat/history/update/content", chatHist.UpdateContent) + + // 能力 + r.Post("/capability/product/ingest", capabilityService.ProductIngest) // 商品数据提取 } func routerSocket(app *fiber.App, chatService *services.ChatService) { diff --git a/internal/services/callback.go b/internal/services/callback.go index c903bea..cc32660 100644 --- a/internal/services/callback.go +++ b/internal/services/callback.go @@ -77,7 +77,8 @@ func (s *CallbackService) Callback(c *fiber.Ctx) error { ts := strings.TrimSpace(c.Get("X-Timestamp")) // 时间窗口(如果提供了 ts 则校验,否则跳过),窗口 5 分钟 - if ts != "" && !validateTimestamp(ts, 5*time.Minute) { + // if ts != "" && !validateTimestamp(ts, 5*time.Minute) { + if ts != "" && !util.IsInTimeWindow(ts, 5*time.Minute) { return errorcode.AuthNotFound } @@ -101,40 +102,40 @@ func (s *CallbackService) Callback(c *fiber.Ctx) error { } } -func validateTimestamp(ts string, window time.Duration) bool { - // 期望毫秒时间戳或秒级,简单容错 - // 尝试解析为整数 - var n int64 - for _, base := range []int64{1, 1000} { // 秒或毫秒 - if v, ok := parseInt64(ts); ok { - n = v - // 归一为毫秒 - if base == 1 && len(ts) <= 10 { - n = n * 1000 - } - now := time.Now().UnixMilli() - diff := now - n - if diff < 0 { - diff = -diff - } - if diff <= window.Milliseconds() { - return true - } - } - } - return false -} +// func validateTimestamp(ts string, window time.Duration) bool { +// // 期望毫秒时间戳或秒级,简单容错 +// // 尝试解析为整数 +// var n int64 +// for _, base := range []int64{1, 1000} { // 秒或毫秒 +// if v, ok := parseInt64(ts); ok { +// n = v +// // 归一为毫秒 +// if base == 1 && len(ts) <= 10 { +// n = n * 1000 +// } +// now := time.Now().UnixMilli() +// diff := now - n +// if diff < 0 { +// diff = -diff +// } +// if diff <= window.Milliseconds() { +// return true +// } +// } +// } +// return false +// } -func parseInt64(s string) (int64, bool) { - var n int64 - for _, ch := range s { - if ch < '0' || ch > '9' { - return 0, false - } - n = n*10 + int64(ch-'0') - } - return n, true -} +// func parseInt64(s string) (int64, bool) { +// var n int64 +// for _, ch := range s { +// if ch < '0' || ch > '9' { +// return 0, false +// } +// n = n*10 + int64(ch-'0') +// } +// return n, true +// } func (s *CallbackService) handleDingTalkCallback(c *fiber.Ctx, env Envelope) error { // 校验taskId diff --git a/internal/services/capability.go b/internal/services/capability.go new file mode 100644 index 0000000..e0adea0 --- /dev/null +++ b/internal/services/capability.go @@ -0,0 +1,129 @@ +package services + +import ( + "ai_scheduler/internal/config" + errorcode "ai_scheduler/internal/data/error" + "ai_scheduler/internal/pkg/util" + "ai_scheduler/internal/pkg/utils_ollama" + "context" + "fmt" + "strings" + "time" + + "github.com/gofiber/fiber/v2" + "github.com/ollama/ollama/api" +) + +// CapabilityService 统一回调入口 +type CapabilityService struct { + cfg *config.Config +} + +func NewCapabilityService(cfg *config.Config) *CapabilityService { + return &CapabilityService{ + cfg: cfg, + } +} + +// 产品数据提取入参 +type ProductIngestReq struct { + Url string `json:"url"` // 商品详情页URL + Title string `json:"title"` // 商品标题 + Text string `json:"text"` // 商品描述 + Images []string `json:"images"` // 商品图片URL列表 + Timestamp int64 `json:"timestamp"` // 商品发布时间戳 +} + +const ( + // 货易通商品属性模板-中文 + HYTProductPropertyTemplateZH = `{ + "条码": "string", // 商品编号 + "分类名称": "string", // 商品分类 + "货品名称": "string", // 商品名称 + "货品编号": "string", // 商品编号 + "商品货号": "string", // 商品编号 + "品牌": "string", // 商品品牌 + "单位": "string", // 商品单位,若无则使用'个' + "规格参数": "string", // 商品规格参数 + "货品说明": "string", // 商品说明 + "保质期": "string", // 商品保质期 + "保质期单位": "string", // 商品保质期单位 + "链接": "string", // + "货品图片": ["string"], // 商品多图,取1-2个即可 + "电商销售价格": "decimal(10,2)", // 商品电商销售价格 + "销售价": "decimal(10,2)", // 商品销售价格 + "供应商报价": "decimal(10,2)", // 商品供应商报价 + "税率": "number%", // 商品税率 x% + "默认供应商": "", // 空即可 + "默认存放仓库": "", // 空即可 + "备注": "", // 备注 + "长": "string", // 商品长度,decimal(10,2)+单位 + "宽": "string", // 商品宽度,decimal(10,2)+单位 + "高": "string", // 商品高度,decimal(10,2)+单位 + "重量": "string", // 商品重量(kg) + "SPU名称": "string", // 商品SPU名称 + "SPU编码": "string" // 编码串,jd_{timestamp}_rand(1000-999) + }` + + SystemPrompt = `你是一个专业的商品属性提取助手,你的任务是根据用户输入提取商品的属性信息。 + 目标属性模板:%s。 + 最终输出格式为纯JSON字符串,键值对对应目标属性和提取到的属性值。 + 最终输出不要携带markdown标识,不要携带回车换行` +) + +// ProductIngest 产品数据提取 +func (s *CapabilityService) ProductIngest(c *fiber.Ctx) error { + // 读取头 + token := strings.TrimSpace(c.Get("X-Source-Key")) + ts := strings.TrimSpace(c.Get("X-Timestamp")) + + // 时间窗口校验 + if ts != "" && !util.IsInTimeWindow(ts, 5*time.Minute) { + // return errorcode.AuthNotFound + } + // token校验 + if token == "" || token != "A7f9KQ3mP2X8LZC4R5e" { + // return errorcode.KeyErr() + } + + // 解析请求参数 + req := ProductIngestReq{} + if err := c.BodyParser(&req); err != nil { + return errorcode.ParamErr("invalid request body: %v", err) + } + + // 必要参数校验 + if req.Text == "" { + return errorcode.ParamErr("missing required fields") + } + + // 模型调用 + client, cleanup, err := utils_ollama.NewClient(s.cfg) + if err != nil { + return err + } + defer cleanup() + + res, err := client.Chat(context.Background(), []api.Message{ + { + Role: "system", + Content: fmt.Sprintf(SystemPrompt, HYTProductPropertyTemplateZH), + }, + { + Role: "user", + Content: req.Text, + }, + { + Role: "user", + Content: "商品图片URL列表:" + strings.Join(req.Images, ","), + }, + }) + if err != nil { + return err + } + + // 解析模型输出 + c.JSON(res.Message.Content) + + return nil +} diff --git a/internal/services/provider_set.go b/internal/services/provider_set.go index 867eb11..375a886 100644 --- a/internal/services/provider_set.go +++ b/internal/services/provider_set.go @@ -12,4 +12,6 @@ var ProviderSetServices = wire.NewSet( NewTaskService, NewCallbackService, NewDingBotService, - NewHistoryService) + NewHistoryService, + NewCapabilityService, +)