From 4339b6eee80e5bacaa33061c7b4e5b0a7e70a3bd Mon Sep 17 00:00:00 2001 From: renzhiyuan <465386466@qq.com> Date: Thu, 4 Dec 2025 10:14:57 +0800 Subject: [PATCH 01/66] =?UTF-8?q?feat:=20=E4=BC=98=E5=8C=96=E4=BB=BB?= =?UTF-8?q?=E5=8A=A1=E6=A8=A1=E5=9E=8B=E4=B8=8E=E6=9D=83=E9=99=90=E6=8E=A7?= =?UTF-8?q?=E5=88=B6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- config/config_test.yaml | 4 +- internal/biz/llm_service/ollama.go | 1 + internal/config/config.go | 9 ++-- internal/gateway/client.go | 74 ++++++++++++++++++++++-------- internal/gateway/gateway.go | 25 ++++++++-- internal/pkg/func.go | 10 ++++ internal/server/router/router.go | 5 +- internal/services/chat.go | 48 ++++++++----------- 8 files changed, 114 insertions(+), 62 deletions(-) diff --git a/config/config_test.yaml b/config/config_test.yaml index 8275102..2929dd7 100644 --- a/config/config_test.yaml +++ b/config/config_test.yaml @@ -3,11 +3,12 @@ server: port: 8090 host: "0.0.0.0" + ollama: base_url: "http://127.0.0.1:11434" model: "qwen3-coder:480b-cloud" generate_model: "qwen3-coder:480b-cloud" - vl_model: "qwen2.5vl:7b" + vl_model: "gemini-3-pro-preview" timeout: "120s" level: "info" format: "json" @@ -19,6 +20,7 @@ sys: channel_pool_len: 100 channel_pool_size: 32 llm_pool_len: 5 + heartbeat_interval: 30 redis: host: 47.97.27.195:6379 type: node diff --git a/internal/biz/llm_service/ollama.go b/internal/biz/llm_service/ollama.go index 60d9c78..6527a3b 100644 --- a/internal/biz/llm_service/ollama.go +++ b/internal/biz/llm_service/ollama.go @@ -139,6 +139,7 @@ func (r *OllamaService) RecognizeWithImg(ctx context.Context, requireData *entit Prompt: r.config.DefaultPrompt.ImgRecognize.UserPrompt, Images: requireData.ImgByte, KeepAlive: &api.Duration{Duration: 3600 * time.Second}, + //Think: &api.ThinkValue{Value: false}, }) if err != nil { return diff --git a/internal/config/config.go b/internal/config/config.go index 5c9fa21..da25c39 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -36,10 +36,11 @@ type LLM struct { // SysConfig 系统配置 type SysConfig struct { - SessionLen int `mapstructure:"session_len"` - ChannelPoolLen int `mapstructure:"channel_pool_len"` - ChannelPoolSize int `mapstructure:"channel_pool_size"` - LlmPoolLen int `mapstructure:"llm_pool_len"` + SessionLen int `mapstructure:"session_len"` + ChannelPoolLen int `mapstructure:"channel_pool_len"` + ChannelPoolSize int `mapstructure:"channel_pool_size"` + LlmPoolLen int `mapstructure:"llm_pool_len"` + HeartbeatInterval int `mapstructure:"heartbeat_interval"` } // ServerConfig 服务器配置 diff --git a/internal/gateway/client.go b/internal/gateway/client.go index b49bf45..f90678c 100644 --- a/internal/gateway/client.go +++ b/internal/gateway/client.go @@ -3,33 +3,44 @@ package gateway import ( errors "ai_scheduler/internal/data/error" "ai_scheduler/internal/data/model" - "encoding/hex" - "fmt" - "github.com/gofiber/websocket/v2" + "ai_scheduler/internal/pkg" + "context" + "encoding/binary" + "log" "math/rand" "time" + + "github.com/gofiber/websocket/v2" ) var ( ErrConnClosed = errors.SysErr("连接不存在或已关闭") + rng = rand.New(rand.NewSource(time.Now().UnixNano())) + idBuf = make([]byte, 20) ) type Client struct { - id string // 客户端唯一ID - conn *websocket.Conn // WebSocket 连接 - session string // 会话ID - key string // 应用密钥 - auth string // 用户凭证token - codes []string // 用户权限code - sysInfo *model.AiSy // 系统信息 - tasks []model.AiTask // 任务列表 - sysCode string // 系统编码 + id string // 客户端唯一ID + conn *websocket.Conn // WebSocket 连接 + session string // 会话ID + key string // 应用密钥 + auth string // 用户凭证token + codes []string // 用户权限code + sysInfo *model.AiSy // 系统信息 + tasks []model.AiTask // 任务列表 + sysCode string // 系统编码 + Ctx context.Context + Cancel context.CancelFunc + LastActive time.Time } -func NewClient(conn *websocket.Conn) *Client { +func NewClient(conn *websocket.Conn, ctx context.Context, cancel context.CancelFunc) *Client { + return &Client{ - id: generateClientID(), - conn: conn, + id: generateClientID(), + conn: conn, + Ctx: ctx, + Cancel: cancel, } } @@ -103,12 +114,16 @@ func (c *Client) SendFunc(msg []byte) error { // 生成唯一的客户端ID func generateClientID() string { - // 使用时间戳+随机数确保唯一性 + // 1. 时间戳 timestamp := time.Now().UnixNano() - randomBytes := make([]byte, 4) - rand.Read(randomBytes) - randomStr := hex.EncodeToString(randomBytes) - return fmt.Sprintf("%d%s", timestamp, randomStr) + binary.BigEndian.PutUint64(idBuf[:8], uint64(timestamp)) + + // 2. 随机数(4字节) + binary.BigEndian.PutUint32(idBuf[8:12], rng.Uint32()) + + // 3. 十六进制编码 + n := pkg.HexEncode(idBuf[:12], idBuf[12:]) + return string(idBuf[12 : 12+n]) } // 连接数据验证和收集 @@ -136,3 +151,22 @@ func (c *Client) DataAuth() (err error) { } return } + +func (c *Client) InitHeartbeat(timeoutSecond time.Duration) { + ticker := time.NewTicker(timeoutSecond * time.Second) + defer ticker.Stop() + for { + select { + case <-ticker.C: + //*2是防止丢包,连续丢包两次,再加5s网络延迟容错 + if time.Since(c.LastActive) > (timeoutSecond*2*time.Second + 5) { // 5秒容错 + log.Println("Heartbeat timeout", "id", c.id) + c.conn.WriteMessage(websocket.CloseMessage, []byte("Heartbeat timeout")) + c.conn.Close() + return + } + case <-c.Ctx.Done(): + return + } + } +} diff --git a/internal/gateway/gateway.go b/internal/gateway/gateway.go index 0f0e6f3..03ce961 100644 --- a/internal/gateway/gateway.go +++ b/internal/gateway/gateway.go @@ -2,7 +2,9 @@ package gateway import ( "errors" + "log" "sync" + "time" ) type Gateway struct { @@ -20,14 +22,27 @@ func NewGateway() *Gateway { func (g *Gateway) AddClient(c *Client) { g.mu.Lock() - defer g.mu.Unlock() + defer func() { + g.mu.Unlock() + //心跳开始计时 + c.LastActive = time.Now() + log.Println("client connected:", c.GetID()) + log.Println("客户端已连接") + }() g.clients[c.GetID()] = c } -func (g *Gateway) RemoveClient(clientID string) { +func (g *Gateway) Cleanup(clientID string) { g.mu.Lock() - defer g.mu.Unlock() - delete(g.clients, clientID) + defer func() { + if c, ex := g.clients[clientID]; ex { + delete(g.clients, clientID) + _ = c.conn.Close() + c.Cancel() + } + g.mu.Unlock() + log.Println("client disconnected:", clientID) + }() for uid, list := range g.uidMap { newList := []string{} for _, cid := range list { @@ -37,6 +52,7 @@ func (g *Gateway) RemoveClient(clientID string) { } g.uidMap[uid] = newList } + } func (g *Gateway) SendToAll(msg []byte) { @@ -63,6 +79,7 @@ func (g *Gateway) BindUid(clientID, uid string) error { return errors.New("client not found") } g.uidMap[uid] = append(g.uidMap[uid], clientID) + log.Printf("bind %s -> uid:%s\n", clientID, uid) return nil } diff --git a/internal/pkg/func.go b/internal/pkg/func.go index 4e6481a..f6006ac 100644 --- a/internal/pkg/func.go +++ b/internal/pkg/func.go @@ -55,3 +55,13 @@ func ValidateImageURL(rawURL string) error { return nil } + +// hexEncode 将 src 的二进制数据编码为十六进制字符串,写入 dst,返回写入长度 +func HexEncode(src, dst []byte) int { + const hextable = "0123456789abcdef" + for i := 0; i < len(src); i++ { + dst[i*2] = hextable[src[i]>>4] + dst[i*2+1] = hextable[src[i]&0xf] + } + return len(src) * 2 +} diff --git a/internal/server/router/router.go b/internal/server/router/router.go index 5c935d7..1a6fd8b 100644 --- a/internal/server/router/router.go +++ b/internal/server/router/router.go @@ -82,10 +82,7 @@ func SetupRoutes(app *fiber.App, ChatService *services.ChatService, sessionServi func routerSocket(app *fiber.App, chatService *services.ChatService) { ws := app.Group("ws/v1/") // WebSocket 路由配置 - ws.Get("/chat", websocket.New(func(c *websocket.Conn) { - // 可以在这里添加握手前的中间件逻辑(如头校验) - chatService.Chat(c) // 调用实际的 Chat 处理函数 - }, websocket.Config{ + ws.Get("/chat", websocket.New(chatService.Chat, websocket.Config{ // 可选配置:跨域检查、最大负载大小等 HandshakeTimeout: 10 * time.Second, //Subprotocols: []string{"json", "msgpack"}, diff --git a/internal/services/chat.go b/internal/services/chat.go index ff75bee..a01ba73 100644 --- a/internal/services/chat.go +++ b/internal/services/chat.go @@ -6,9 +6,11 @@ import ( "ai_scheduler/internal/data/constants" "ai_scheduler/internal/entitys" "ai_scheduler/internal/gateway" + "context" "encoding/json" "log" "sync" + "time" "github.com/gofiber/fiber/v2" "github.com/gofiber/websocket/v2" @@ -38,20 +40,6 @@ func NewChatService( } } -// ToolCallResponse 工具调用响应 -type ToolCallResponse struct { - ID string `json:"id" example:"call_1"` - Type string `json:"type" example:"function"` - Function FunctionCallResponse `json:"function"` - Result interface{} `json:"result,omitempty"` -} - -// FunctionCallResponse 函数调用响应 -type FunctionCallResponse struct { - Name string `json:"name" example:"get_weather"` - Arguments interface{} `json:"arguments"` -} - func (h *ChatService) ChatFail(c *websocket.Conn, content string) { err := c.WriteMessage(websocket.TextMessage, []byte(content)) if err != nil { @@ -63,15 +51,18 @@ func (h *ChatService) ChatFail(c *websocket.Conn, content string) { // Chat 处理WebSocket聊天连接 // 这是WebSocket处理的主入口函数 func (h *ChatService) Chat(c *websocket.Conn) { - // 创建新的客户端实例 - h.mu.Lock() - client := gateway.NewClient(c) - h.mu.Unlock() + ctx, cancel := context.WithCancel(context.Background()) + // 创建新的客户端实例 + client := gateway.NewClient(c, ctx, cancel) + // 心跳检测 + go client.InitHeartbeat(time.Duration(h.cfg.Sys.HeartbeatInterval)) // 将客户端添加到网关管理 h.Gw.AddClient(client) - log.Println("client connected:", client.GetID()) - log.Println("客户端已连接") + // 确保在函数返回时移除客户端并关闭连接 + defer func() { + h.Gw.Cleanup(client.GetID()) + }() // 绑定会话ID uid := c.Query("x-session") @@ -79,7 +70,6 @@ func (h *ChatService) Chat(c *websocket.Conn) { if err := h.Gw.BindUid(client.GetID(), uid); err != nil { log.Println("绑定UID错误:", err) } - log.Printf("bind %s -> uid:%s\n", client.GetID(), uid) } // 验证并收集连接数据,后续对话中会使用 @@ -89,13 +79,6 @@ func (h *ChatService) Chat(c *websocket.Conn) { return } - // 确保在函数返回时移除客户端并关闭连接 - defer func() { - h.Gw.RemoveClient(client.GetID()) - _ = c.Close() - log.Println("client disconnected:", client.GetID()) - }() - // 循环读取客户端消息 for { // 读取消息 @@ -104,7 +87,14 @@ func (h *ChatService) Chat(c *websocket.Conn) { log.Println("读取错误:", err) break } - + //if string(message) == `{"type":"ping"}` { + // client.LastActive = time.Now() + // if err := c.WriteMessage(websocket.TextMessage, []byte(`{"type":"pong"}`)); err != nil { + // log.Println("Heartbeat response failed", "id", client.GetID(), "err", err) + // return + // } + // continue + //} // 处理消息 msg, chatType := h.handleMessageToString(c, messageType, message) if chatType == constants.ConnStatusClosed { From 9c19e426000e3d97415fe51d7c649299a56dc3be Mon Sep 17 00:00:00 2001 From: fuzhongyun <15339891972@163.com> Date: Thu, 4 Dec 2025 17:11:53 +0800 Subject: [PATCH 02/66] =?UTF-8?q?feat:=20=E5=A2=9E=E5=8A=A0=E4=B8=80?= =?UTF-8?q?=E5=A5=97eino=E5=AE=9E=E7=8E=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- cmd/server/wire.go | 1 + config/config_test.yaml | 49 ++++++++ go.mod | 36 +++++- go.sum | 112 ++++++++++++++++-- internal/config/config.go | 27 ++++- internal/domain/common/mapper.go | 11 ++ internal/domain/common/types.go | 4 + internal/domain/llm/api.go | 14 +++ internal/domain/llm/capability/router.go | 48 ++++++++ internal/domain/llm/capability/types.go | 15 +++ internal/domain/llm/capability/validator.go | 22 ++++ internal/domain/llm/errors.go | 9 ++ internal/domain/llm/options.go | 16 +++ internal/domain/llm/pipeline/chat.go | 38 ++++++ internal/domain/llm/pipeline/intent.go | 26 ++++ internal/domain/llm/pipeline/vision.go | 26 ++++ internal/domain/llm/prompt/templates.go | 6 + .../domain/llm/provider/ollama/adapter.go | 72 +++++++++++ internal/domain/llm/provider/registry.go | 25 ++++ internal/domain/llm/provider/vllm/adapter.go | 70 +++++++++++ internal/domain/llm/service/chat_service.go | 22 ++++ internal/domain/llm/service/intent_service.go | 22 ++++ internal/domain/llm/service/service.go | 96 +++++++++++++++ internal/domain/llm/service/vision_service.go | 22 ++++ internal/domain/tools/registry.go | 16 +++ 25 files changed, 788 insertions(+), 17 deletions(-) create mode 100644 internal/domain/common/mapper.go create mode 100644 internal/domain/common/types.go create mode 100644 internal/domain/llm/api.go create mode 100644 internal/domain/llm/capability/router.go create mode 100644 internal/domain/llm/capability/types.go create mode 100644 internal/domain/llm/capability/validator.go create mode 100644 internal/domain/llm/errors.go create mode 100644 internal/domain/llm/options.go create mode 100644 internal/domain/llm/pipeline/chat.go create mode 100644 internal/domain/llm/pipeline/intent.go create mode 100644 internal/domain/llm/pipeline/vision.go create mode 100644 internal/domain/llm/prompt/templates.go create mode 100644 internal/domain/llm/provider/ollama/adapter.go create mode 100644 internal/domain/llm/provider/registry.go create mode 100644 internal/domain/llm/provider/vllm/adapter.go create mode 100644 internal/domain/llm/service/chat_service.go create mode 100644 internal/domain/llm/service/intent_service.go create mode 100644 internal/domain/llm/service/service.go create mode 100644 internal/domain/llm/service/vision_service.go create mode 100644 internal/domain/tools/registry.go diff --git a/cmd/server/wire.go b/cmd/server/wire.go index f134ef9..bcd1d2c 100644 --- a/cmd/server/wire.go +++ b/cmd/server/wire.go @@ -22,6 +22,7 @@ import ( func InitializeApp(*config.Config, log.AllLogger) (*server.Servers, func(), error) { panic(wire.Build( server.ProviderSetServer, + llm.ProviderSet, tools.ProviderSetTools, pkg.ProviderSetClient, services.ProviderSetServices, diff --git a/config/config_test.yaml b/config/config_test.yaml index 8275102..90c971d 100644 --- a/config/config_test.yaml +++ b/config/config_test.yaml @@ -79,3 +79,52 @@ default_prompt: # 权限配置 permissionConfig: permission_url: "http://api.test.user.1688sup.cn:8001/v1/menu/myCodes?systemId=" + +# llm 服务配置 +llm: + providers: + ollama: + endpoint: http://127.0.0.1:11434 + timeout: 60s + models: + - id: qwen3-coder:480b-cloud + name: qwen3-coder:480b-cloud + streaming: true + modalities: [text] + max_tokens: 4096 + + vllm: + endpoint: http://117.175.169.61:16001 + timeout: 60s + models: + - id: models/Qwen2.5-VL-3B-Instruct-AWQ + name: qwen2.5-vl-3b + streaming: true + modalities: [text, image] + max_tokens: 4096 + + # 每个能力只绑定一个 provider+model,不做自动回退 + capabilities: + intent: + provider: vllm + model: qwen2.5-vl-3b + parameters: + temperature: 0.2 + max_tokens: 4096 + stream: false + + vision: + provider: vllm + model: qwen2.5-vl-3b + parameters: + temperature: 0.5 + max_tokens: 4096 + stream: true + + chat: + provider: ollama + model: qwen3-coder:480b-cloud + parameters: + temperature: 0.7 + max_tokens: 4096 + stream: true diff --git a/go.mod b/go.mod index 6b08e4a..087b3ee 100644 --- a/go.mod +++ b/go.mod @@ -10,6 +10,9 @@ require ( github.com/alibabacloud-go/dingtalk v1.6.96 github.com/alibabacloud-go/tea v1.2.2 github.com/alibabacloud-go/tea-utils/v2 v2.0.6 + github.com/cloudwego/eino v0.7.6 + github.com/cloudwego/eino-ext/components/model/ollama v0.1.6 + github.com/cloudwego/eino-ext/components/model/openai v0.1.5 github.com/emirpasic/gods v1.18.1 github.com/faabiosr/cachego v0.26.0 github.com/fastwego/dingding v1.0.0-beta.4 @@ -23,8 +26,6 @@ require ( github.com/spf13/viper v1.17.0 github.com/tmc/langchaingo v0.1.13 google.golang.org/grpc v1.64.0 - google.golang.org/protobuf v1.34.1 - gopkg.in/yaml.v3 v3.0.1 gorm.io/driver/mysql v1.6.0 gorm.io/gorm v1.31.0 xorm.io/builder v0.3.13 @@ -39,31 +40,50 @@ require ( github.com/alibabacloud-go/tea-xml v1.1.3 // indirect github.com/aliyun/credentials-go v1.4.6 // indirect github.com/andybalholm/brotli v1.1.0 // indirect + github.com/bahlo/generic-list-go v0.2.0 // indirect + github.com/buger/jsonparser v1.1.1 // indirect + github.com/bytedance/gopkg v0.1.3 // indirect + github.com/bytedance/sonic v1.14.1 // indirect + github.com/bytedance/sonic/loader v0.3.0 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/clbanning/mxj/v2 v2.5.5 // indirect + github.com/cloudwego/base64x v0.1.6 // indirect + github.com/cloudwego/eino-ext/libs/acl/openai v0.1.2 // indirect github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect github.com/dlclark/regexp2 v1.11.4 // indirect + github.com/dustin/go-humanize v1.0.1 // indirect + github.com/eino-contrib/jsonschema v1.0.3 // indirect + github.com/eino-contrib/ollama v0.1.0 // indirect + github.com/evanphx/json-patch v0.5.2 // indirect github.com/fasthttp/websocket v1.5.3 // indirect github.com/fsnotify/fsnotify v1.6.0 // indirect github.com/go-sql-driver/mysql v1.8.1 // indirect + github.com/goph/emperror v0.17.2 // indirect github.com/hashicorp/hcl v1.0.0 // indirect github.com/jinzhu/inflection v1.0.0 // indirect github.com/jinzhu/now v1.1.5 // indirect github.com/json-iterator/go v1.1.12 // indirect github.com/klauspost/compress v1.17.9 // indirect + github.com/klauspost/cpuid/v2 v2.2.9 // indirect github.com/magiconair/properties v1.8.7 // indirect + github.com/mailru/easyjson v0.7.7 // indirect github.com/mattn/go-colorable v0.1.13 // indirect github.com/mattn/go-isatty v0.0.20 // indirect github.com/mattn/go-runewidth v0.0.16 // indirect + github.com/meguminnnnnnnnn/go-openai v0.1.0 // indirect github.com/mitchellh/mapstructure v1.5.0 // indirect github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/reflect2 v1.0.2 // indirect + github.com/nikolalohinski/gonja v1.5.3 // indirect github.com/pelletier/go-toml/v2 v2.2.2 // indirect + github.com/pkg/errors v0.9.1 // indirect github.com/pkoukk/tiktoken-go v0.1.6 // indirect github.com/rivo/uniseg v0.2.0 // indirect github.com/sagikazarmark/locafero v0.3.0 // indirect github.com/sagikazarmark/slog-shim v0.1.0 // indirect github.com/savsgio/gotils v0.0.0-20230208104028-c358bd845dee // indirect + github.com/sirupsen/logrus v1.9.3 // indirect + github.com/slongfield/pyfmt v0.0.0-20220222012616-ea85ff4c361f // indirect github.com/sourcegraph/conc v0.3.0 // indirect github.com/spf13/afero v1.10.0 // indirect github.com/spf13/cast v1.5.1 // indirect @@ -71,16 +91,22 @@ require ( github.com/stretchr/testify v1.11.1 // indirect github.com/subosito/gotenv v1.6.0 // indirect github.com/tjfoc/gmsm v1.4.1 // indirect + github.com/twitchyliquid64/golang-asm v0.15.1 // indirect github.com/valyala/bytebufferpool v1.0.0 // indirect github.com/valyala/fasthttp v1.51.0 // indirect github.com/valyala/tcplisten v1.0.0 // indirect + github.com/wk8/go-ordered-map/v2 v2.1.8 // indirect + github.com/yargevad/filepathx v1.0.0 // indirect go.uber.org/atomic v1.9.0 // indirect go.uber.org/multierr v1.9.0 // indirect - golang.org/x/crypto v0.36.0 // indirect + golang.org/x/arch v0.11.0 // indirect + golang.org/x/crypto v0.39.0 // indirect golang.org/x/exp v0.0.0-20250218142911-aa4b98e5adaa // indirect golang.org/x/net v0.38.0 // indirect - golang.org/x/sys v0.31.0 // indirect - golang.org/x/text v0.23.0 // indirect + golang.org/x/sys v0.33.0 // indirect + golang.org/x/text v0.26.0 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20240604185151-ef581f913117 // indirect + google.golang.org/protobuf v1.34.1 // indirect gopkg.in/ini.v1 v1.67.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 29b4ace..7141ab9 100644 --- a/go.sum +++ b/go.sum @@ -44,6 +44,7 @@ gitea.com/xorm/sqlfiddle v0.0.0-20180821085327-62ce714f951a h1:lSA0F4e9A2NcQSqGq gitea.com/xorm/sqlfiddle v0.0.0-20180821085327-62ce714f951a/go.mod h1:EXuID2Zs0pAQhH8yz+DNjUbjppKQzKFAn28TMYPB6IU= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo= +github.com/airbrake/gobrake v3.6.1+incompatible/go.mod h1:wM4gu3Cn0W0K7GUuVWnlXZU11AGBXMILnrdOU8Kn00o= github.com/alibabacloud-go/alibabacloud-gateway-pop v0.0.6 h1:eIf+iGJxdU4U9ypaUfbtOWCsZSbTb8AUHvyPrxu6mAA= github.com/alibabacloud-go/alibabacloud-gateway-pop v0.0.6/go.mod h1:4EUIoxs/do24zMOGGqYVWgw0s9NtiylnJglOeEB5UJo= github.com/alibabacloud-go/alibabacloud-gateway-spi v0.0.4/go.mod h1:sCavSAvdzOjul4cEqeVtvlSaSScfNsTQ+46HwlTL1hc= @@ -96,12 +97,29 @@ github.com/aliyun/credentials-go v1.4.6 h1:CG8rc/nxCNKfXbZWpWDzI9GjF4Tuu3Es14qT8 github.com/aliyun/credentials-go v1.4.6/go.mod h1:Jm6d+xIgwJVLVWT561vy67ZRP4lPTQxMbEYRuT2Ti1U= github.com/andybalholm/brotli v1.1.0 h1:eLKJA0d02Lf0mVpIDgYnqXcUn0GqVmEFny3VuID1U3M= github.com/andybalholm/brotli v1.1.0/go.mod h1:sms7XGricyQI9K10gOSf56VKKWS4oLer58Q+mhRPtnY= +github.com/bahlo/generic-list-go v0.2.0 h1:5sz/EEAK+ls5wF+NeqDpk5+iNdMDXrh3z3nPnH1Wvgk= +github.com/bahlo/generic-list-go v0.2.0/go.mod h1:2KvAjgMlE5NNynlg/5iLrrCCZ2+5xWbdbCW3pNTGyYg= +github.com/bitly/go-simplejson v0.5.0/go.mod h1:cXHtHw4XUPsvGaxgjIAn8PhEWG9NfngEKAMDJEczWVA= +github.com/bmizerany/assert v0.0.0-20160611221934-b7ed37b82869/go.mod h1:Ekp36dRnpXw/yCqJaO+ZrUyxD+3VXMFFr56k5XYrpB4= github.com/bradfitz/gomemcache v0.0.0-20170208213004-1952afaa557d/go.mod h1:PmM6Mmwb0LSuEubjR8N7PtNe1KxZLtOUHtbeikc5h60= github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs= github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c= github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA= github.com/bsm/gomega v1.27.10/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0H+O0= +github.com/buger/jsonparser v1.1.1 h1:2PnMjfWD7wBILjqQbt530v576A/cAbQvEW9gGIpYMUs= +github.com/buger/jsonparser v1.1.1/go.mod h1:6RYKKt7H4d4+iWqouImQ9R2FZql3VbhNgx27UK13J/0= +github.com/bugsnag/bugsnag-go v1.4.0/go.mod h1:2oa8nejYd4cQ/b0hMIopN0lCRxU0bueqREvZLWFrtK8= +github.com/bugsnag/panicwrap v1.2.0/go.mod h1:D/8v3kj0zr8ZAKg1AQ6crr+5VwKN5eIywRkfhyM/+dE= +github.com/bytedance/gopkg v0.1.3 h1:TPBSwH8RsouGCBcMBktLt1AymVo2TVsBVCY4b6TnZ/M= +github.com/bytedance/gopkg v0.1.3/go.mod h1:576VvJ+eJgyCzdjS+c4+77QF3p7ubbtiKARP3TxducM= +github.com/bytedance/mockey v1.2.14 h1:KZaFgPdiUwW+jOWFieo3Lr7INM1P+6adO3hxZhDswY8= +github.com/bytedance/mockey v1.2.14/go.mod h1:1BPHF9sol5R1ud/+0VEHGQq/+i2lN+GTsr3O2Q9IENY= +github.com/bytedance/sonic v1.14.1 h1:FBMC0zVz5XUmE4z9wF4Jey0An5FueFvOsTKKKtwIl7w= +github.com/bytedance/sonic v1.14.1/go.mod h1:gi6uhQLMbTdeP0muCnrjHLeCUPyb70ujhnNlhOylAFc= +github.com/bytedance/sonic/loader v0.3.0 h1:dskwH8edlzNMctoruo8FPTJDF3vLtDT0sXZwvZJyqeA= +github.com/bytedance/sonic/loader v0.3.0/go.mod h1:N8A3vUdtUebEY2/VQC0MyhYeKUFosQU6FxH2JmUe6VI= github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= +github.com/certifi/gocertifi v0.0.0-20190105021004-abcd57078448/go.mod h1:GJKEexRPVJrBSOjoqN5VNOIKJ5Q3RViH6eu3puDRwx4= github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/chzyer/logex v1.1.10/go.mod h1:+Ywpsq7O8HXn0nuIou7OrIPyXbp3wmkHB+jjWRnGsAI= @@ -110,6 +128,16 @@ github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod h1:Q3SI9o4m/ZMn github.com/clbanning/mxj/v2 v2.5.5 h1:oT81vUeEiQQ/DcHbzSytRngP6Ky9O+L+0Bw0zSJag9E= github.com/clbanning/mxj/v2 v2.5.5/go.mod h1:hNiWqW14h+kc+MdF9C6/YoRfjEJoR3ou6tn/Qo+ve2s= github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= +github.com/cloudwego/base64x v0.1.6 h1:t11wG9AECkCDk5fMSoxmufanudBtJ+/HemLstXDLI2M= +github.com/cloudwego/base64x v0.1.6/go.mod h1:OFcloc187FXDaYHvrNIjxSe8ncn0OOM8gEHfghB2IPU= +github.com/cloudwego/eino v0.7.6 h1:9KGY1IZ/5kCf2viMDrPF3ck3tqd92bOhVOoSKTFRwY0= +github.com/cloudwego/eino v0.7.6/go.mod h1:nA8Vacmuqv3pqKBQbTWENBLQ8MmGmPt/WqiyLeB8ohQ= +github.com/cloudwego/eino-ext/components/model/ollama v0.1.6 h1:ZbrhV91uE0hGIOYXhb2i3G6tQJ/rK2SLYtoYrmocZXM= +github.com/cloudwego/eino-ext/components/model/ollama v0.1.6/go.mod h1:GDXrvorGdRNV6g2mK5jdla2D8Xc/hh7XDrTeGDteLLo= +github.com/cloudwego/eino-ext/components/model/openai v0.1.5 h1:+yvGbTPw93li9GSmdm6Rix88Yy8AXg5NNBcRbWx3CQU= +github.com/cloudwego/eino-ext/components/model/openai v0.1.5/go.mod h1:IPVYMFoZcuHeVEsDTGN6SZjvue0xr1iZFhdpq1SBWdQ= +github.com/cloudwego/eino-ext/libs/acl/openai v0.1.2 h1:r9Id2wzJ05PoHl+Km7jQgNMgciaZI93TVnUYso89esM= +github.com/cloudwego/eino-ext/libs/acl/openai v0.1.2/go.mod h1:S4OkvglPY9hsm9tXeShODrf/WN1Cgu4bqu4nn/CnIic= github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc= github.com/cncf/udpa/go v0.0.0-20200629203442-efcf912fb354/go.mod h1:WmhPx2Nbnhtbo57+VJT5O0JRkEi1Wbu0z5j0R8u5Hbk= github.com/cncf/udpa/go v0.0.0-20201120205902-5459f2c99403/go.mod h1:WmhPx2Nbnhtbo57+VJT5O0JRkEi1Wbu0z5j0R8u5Hbk= @@ -121,6 +149,12 @@ github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/r github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= github.com/dlclark/regexp2 v1.11.4 h1:rPYF9/LECdNymJufQKmri9gV604RvvABwgOA8un7yAo= github.com/dlclark/regexp2 v1.11.4/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8= +github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= +github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= +github.com/eino-contrib/jsonschema v1.0.3 h1:2Kfsm1xlMV0ssY2nuxshS4AwbLFuqmPmzIjLVJ1Fsp0= +github.com/eino-contrib/jsonschema v1.0.3/go.mod h1:cpnX4SyKjWjGC7iN2EbhxaTdLqGjCi0e9DxpLYxddD4= +github.com/eino-contrib/ollama v0.1.0 h1:z1NaMdKW6X1ftP8g5xGGR5zDRPUtuTKFq35vBQgxsN4= +github.com/eino-contrib/ollama v0.1.0/go.mod h1:mYsQ7b3DeqY8bHPuD3MZJYTqkgyL6LoemxoP/B7ZNhA= github.com/emirpasic/gods v1.18.1 h1:FXtiHYKDGKCW2KzwZKx0iC0PQmdlorYgdFG9jPXJ1Bc= github.com/emirpasic/gods v1.18.1/go.mod h1:8tpGGwCnJ5H4r6BWwaV6OrWmMoPhUl5jm/FMNAnJvWQ= github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= @@ -129,6 +163,8 @@ github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1m github.com/envoyproxy/go-control-plane v0.9.7/go.mod h1:cwu0lG7PUMfa9snN8LXBig5ynNVH9qI8YYLbd1fK2po= github.com/envoyproxy/go-control-plane v0.9.9-0.20201210154907-fd9021fe5dad/go.mod h1:cXg6YxExXjJnVBQHBLXeUAgxn2UodCpnH306RInaBQk= github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= +github.com/evanphx/json-patch v0.5.2 h1:xVCHIVMUu1wtM/VkR9jVZ45N3FhZfYMMYGorLCR8P3k= +github.com/evanphx/json-patch v0.5.2/go.mod h1:ZWS5hhDbVDyob71nXKNL0+PWn6ToqBHMikGIFbs31qQ= github.com/faabiosr/cachego v0.15.0/go.mod h1:L2EomlU3/rUWjzFavY9Fwm8B4zZmX2X6u8kTMkETrwI= github.com/faabiosr/cachego v0.26.0 h1:EDDv2y9T0XJ4Cx3tUhbKSUayGWxCGkkZUivNLceHRWY= github.com/faabiosr/cachego v0.26.0/go.mod h1:p54WXVzeB1CctH1ix/rjqv1EotNzD0Xoxk2IsR1PQX8= @@ -143,6 +179,9 @@ github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4 github.com/fsnotify/fsnotify v1.6.0 h1:n+5WquG0fcWoWp6xPWfHdbskMCQaFnG6PfBrh1Ky4HY= github.com/fsnotify/fsnotify v1.6.0/go.mod h1:sl3t1tCWJFWoRz9R8WJCbQihKKwmorjAbSClcnxKAGw= github.com/garyburd/redigo v1.6.0/go.mod h1:NR3MbYisc3/PwhQ00EMzDiPmrwpPxAn5GI05/YaO1SY= +github.com/getsentry/raven-go v0.2.0/go.mod h1:KungGk8q33+aIAZUIVWZDr2OfAEBsO49PX4NzFV5kcQ= +github.com/go-check/check v0.0.0-20180628173108-788fd7840127 h1:0gkP6mzaMqkmpcJYCFOLkIBwI7xFExG03bbkOkCvUPI= +github.com/go-check/check v0.0.0-20180628173108-788fd7840127/go.mod h1:9ES+weclKsC9YodN5RgxqK/VD9HM9JsCSh7rNhMZE98= github.com/go-gl/glfw v0.0.0-20190409004039-e6da0acd62b1/go.mod h1:vR7hzQXu2zJy9AVAgeJqvqgH9Q5CA+iKCZ2gyEVpxRU= github.com/go-gl/glfw/v3.3/glfw v0.0.0-20191125211704-12ad95a8df72/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8= github.com/go-gl/glfw/v3.3/glfw v0.0.0-20200222043503-6f7a984d4dc4/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8= @@ -154,6 +193,7 @@ github.com/gofiber/fiber/v2 v2.52.9 h1:YjKl5DOiyP3j0mO61u3NTmK7or8GzzWzCFzkboyP5 github.com/gofiber/fiber/v2 v2.52.9/go.mod h1:YEcBbO/FB+5M1IZNBP9FO3J9281zgPAreiI1oqg8nDw= github.com/gofiber/websocket/v2 v2.2.1 h1:C9cjxvloojayOp9AovmpQrk8VqvVnT8Oao3+IUygH7w= github.com/gofiber/websocket/v2 v2.2.1/go.mod h1:Ao/+nyNnX5u/hIFPuHl28a+NIkrqK7PRimyKaj4JxVU= +github.com/gofrs/uuid v3.2.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= github.com/golang/groupcache v0.0.0-20190702054246-869f871628b6/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= github.com/golang/groupcache v0.0.0-20191227052852-215e87163ea7/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= @@ -215,8 +255,12 @@ github.com/google/wire v0.7.0/go.mod h1:n6YbUQD9cPKTnHXEBN2DXlOp/mVADhVErcMFb0v3 github.com/googleapis/gax-go/v2 v2.0.4/go.mod h1:0Wqv26UfaUD9n4G6kQubkQ+KchISgw+vpHVxEJEs9eg= github.com/googleapis/gax-go/v2 v2.0.5/go.mod h1:DWXyrwAJ9X0FpwwEdw+IPEYBICEFu5mhpdKc/us6bOk= github.com/googleapis/google-cloud-go-testing v0.0.0-20200911160855-bcd43fbb19e8/go.mod h1:dvDLG8qkwmyD9a/MJJN3XJcT3xFxOKAvTZGvuZmac9g= +github.com/goph/emperror v0.17.2 h1:yLapQcmEsO0ipe9p5TaN22djm3OFV/TfM/fcYP0/J18= +github.com/goph/emperror v0.17.2/go.mod h1:+ZbQ+fUNO/6FNiUo0ujtMjhgad9Xa6fQL9KhH4LNHic= github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY= github.com/gopherjs/gopherjs v0.0.0-20200217142428-fce0ec30dd00/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY= +github.com/gopherjs/gopherjs v1.17.2 h1:fQnZVsXk8uxXIStYb0N4bGk7jeyTalG/wsZjQ25dO0g= +github.com/gopherjs/gopherjs v1.17.2/go.mod h1:pRRIvn/QzFLrKfvEz3qUuEhtE/zLCWfreZ6J5gM2i+k= github.com/hashicorp/golang-lru v0.5.0/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= github.com/hashicorp/golang-lru v0.5.1/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= github.com/hashicorp/hcl v1.0.0 h1:0Anlzjpi4vEasTeNFn2mLJgTSwt0+6sfsiTG8qcWGx4= @@ -224,19 +268,26 @@ github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= github.com/ianlancetaylor/demangle v0.0.0-20181102032728-5e5cf60278f6/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc= github.com/ianlancetaylor/demangle v0.0.0-20200824232613-28f6c0f3b639/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc= +github.com/jessevdk/go-flags v1.4.0/go.mod h1:4FA24M0QyGHXBuZZK/XkWh8h0e1EYbRYJSGM75WSRxI= github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ= github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= +github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y= github.com/json-iterator/go v1.1.10/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= github.com/jstemmer/go-junit-report v0.0.0-20190106144839-af01ea7f8024/go.mod h1:6v2b51hI/fHJwM22ozAgKL4VKDeJcHhJFhtBdhmNjmU= github.com/jstemmer/go-junit-report v0.9.1/go.mod h1:Brl9GWCQeLvo8nXZwPNNblvFj/XSXhF0NWZEnDohbsk= +github.com/jtolds/gls v4.20.0+incompatible h1:xdiiI2gbIgH/gLH7ADydsJ1uDOEzR8yvV7C0MuV77Wo= github.com/jtolds/gls v4.20.0+incompatible/go.mod h1:QJZ7F/aHp+rZTRtaJ1ow/lLfFfVYBRgL+9YlvaHOwJU= +github.com/kardianos/osext v0.0.0-20190222173326-2bc1f35cddc0/go.mod h1:1NbS8ALrpOvjt0rHPNLyCIeMtbizbir8U//inJ+zuB8= github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= github.com/klauspost/compress v1.17.9 h1:6KIumPrER1LHsvBVuDa0r5xaG0Es51mhhB9BQB2qeMA= github.com/klauspost/compress v1.17.9/go.mod h1:Di0epgTjJY877eYKx5yC51cX2A2Vl2ibi7bDH9ttBbw= +github.com/klauspost/cpuid/v2 v2.2.9 h1:66ze0taIn2H33fBvCkXuv9BmCwDfafmiIVpKV9kKGuY= +github.com/klauspost/cpuid/v2 v2.2.9/go.mod h1:rqkxqrZ1EhYM9G+hXH7YdowN5R5RGN6NK4QwQ3WMXF8= +github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= github.com/kr/fs v0.1.0/go.mod h1:FFnZGqtBN9Gxj7eW1uZ42v5BccTP0vu6NEaFoC2HwRg= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= @@ -247,6 +298,8 @@ github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/magiconair/properties v1.8.7 h1:IeQXZAiQcpL9mgcAe1Nu6cX9LLw6ExEHKjN0VQdvPDY= github.com/magiconair/properties v1.8.7/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0= +github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0= +github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc= github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= @@ -255,6 +308,10 @@ github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D github.com/mattn/go-runewidth v0.0.16 h1:E5ScNMtiwvlvB5paMFdw9p4kSQzbXFikJ5SQO6TULQc= github.com/mattn/go-runewidth v0.0.16/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w= github.com/mattn/go-sqlite3 v1.6.0/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc= +github.com/meguminnnnnnnnn/go-openai v0.1.0 h1:BGzB1PlS2Epq0mBB2TGLwzMihbR7BANrlMH3w4ZnY88= +github.com/meguminnnnnnnnn/go-openai v0.1.0/go.mod h1:qs96ysDmxhE4BZoU45I43zcyfnaYxU3X+aRzLko/htY= +github.com/mgutz/ansi v0.0.0-20170206155736-9520e82c474b h1:j7+1HpAFS1zy5+Q4qx1fWh90gTKwiN4QCGoY9TWyyO4= +github.com/mgutz/ansi v0.0.0-20170206155736-9520e82c474b/go.mod h1:01TrycV0kFyexm33Z7vhZRXopbI8J3TDReVlkTgMUxE= github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY= github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= @@ -265,16 +322,22 @@ github.com/modern-go/reflect2 v1.0.1/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3Rllmb github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M= github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno= +github.com/nikolalohinski/gonja v1.5.3 h1:GsA+EEaZDZPGJ8JtpeGN78jidhOlxeJROpqMT9fTj9c= +github.com/nikolalohinski/gonja v1.5.3/go.mod h1:RmjwxNiXAEqcq1HeK5SSMmqFJvKOfTfXhkJv6YBtPa4= github.com/nxadm/tail v1.4.4/go.mod h1:kenIhsEOeOJmVchQTgglprH7qJGnHDVpk1VPCcaMI8A= github.com/ollama/ollama v0.12.7 h1:dxokli1UyO/a0Aun5sE4+0Gg+A9oMUAPiFQhxrXOIXA= github.com/ollama/ollama v0.12.7/go.mod h1:9+1//yWPsDE2u+l1a5mpaKrYw4VdnSsRU3ioq5BvMms= github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= +github.com/onsi/ginkgo v1.8.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= github.com/onsi/ginkgo v1.12.1/go.mod h1:zj2OWP4+oCPe1qIXoGWkgMRwljMUYCdkwsT2108oapk= github.com/onsi/ginkgo v1.13.0/go.mod h1:+REjRxOmWfHCjfv9TTWB1jD1Frx4XydAD3zm1lskyM0= +github.com/onsi/gomega v1.5.0/go.mod h1:ex+gbHU/CVuBBDIJjb2X0qEXbFg53c61hWP/1CpauHY= github.com/onsi/gomega v1.7.1/go.mod h1:XdKZgCCFLUoM/7CFJVPcG8C1xQ1AJ0vpAezJrB7JYyY= github.com/onsi/gomega v1.10.1/go.mod h1:iN09h71vgCQne3DLsj+A5owkum+a2tYe+TOCB1ybHNo= github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM= github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs= +github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/sftp v1.13.1/go.mod h1:3HaPG6Dq1ILlpPZRO0HVMrsydcdLt6HRDccSgb87qRg= github.com/pkoukk/tiktoken-go v0.1.6 h1:JF0TlJzhTbrI30wCvFuiw6FzP2+/bR+FIxUdgEAcUsw= @@ -290,6 +353,7 @@ github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJ github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= github.com/rogpeppe/go-internal v1.11.0 h1:cWPaGQEPrBb5/AsnsZesgZZ9yb1OQ+GOISoDNXVBh4M= github.com/rogpeppe/go-internal v1.11.0/go.mod h1:ddIwULY96R17DhadqLgMfk9H9tvdUzkipdSkR5nkCZA= +github.com/rollbar/rollbar-go v1.0.2/go.mod h1:AcFs5f0I+c71bpHlXNNDbOWJiKwjFDtISeXco0L5PKQ= github.com/sagikazarmark/locafero v0.3.0 h1:zT7VEGWC2DTflmccN/5T1etyKvxSxpHsjb9cJvm4SvQ= github.com/sagikazarmark/locafero v0.3.0/go.mod h1:w+v7UsPNFwzF1cHuOajOOzoq4U7v/ig1mpRjqV+Bu1U= github.com/sagikazarmark/slog-shim v0.1.0 h1:diDBnUNK9N/354PgrxMywXnAwEr1QZcOr6gto+ugjYE= @@ -297,9 +361,18 @@ github.com/sagikazarmark/slog-shim v0.1.0/go.mod h1:SrcSrq8aKtyuqEI1uvTDTK1arOWR github.com/savsgio/gotils v0.0.0-20230208104028-c358bd845dee h1:8Iv5m6xEo1NR1AvpV+7XmhI4r39LGNzwUL4YpMuL5vk= github.com/savsgio/gotils v0.0.0-20230208104028-c358bd845dee/go.mod h1:qwtSXrKuJh/zsFQ12yEE89xfCrGKK63Rr7ctU/uCo4g= github.com/sclevine/agouti v3.0.0+incompatible/go.mod h1:b4WX9W9L1sfQKXeJf1mUTLZKJ48R1S7H23Ji7oFO5Bw= +github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo= +github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= +github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= +github.com/slongfield/pyfmt v0.0.0-20220222012616-ea85ff4c361f h1:Z2cODYsUxQPofhpYRMQVwWz4yUVpHF+vPi+eUdruUYI= +github.com/slongfield/pyfmt v0.0.0-20220222012616-ea85ff4c361f/go.mod h1:JqzWyvTuI2X4+9wOHmKSQCYxybB/8j6Ko43qVmXDuZg= +github.com/smarty/assertions v1.15.0 h1:cR//PqUBUiQRakZWqBiFFQ9wb8emQGDb0HeGdqGByCY= +github.com/smarty/assertions v1.15.0/go.mod h1:yABtdzeQs6l1brC900WlRNwj6ZR55d7B+E8C6HtKdec= github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d/go.mod h1:OnSkiWE9lh6wB0YB77sQom3nweQdgAjqCqsofrRNTgc= github.com/smartystreets/assertions v1.1.0/go.mod h1:tcbTF8ujkAEcZ8TElKY+i30BzYlVhC/LOxJk7iOWnoo= github.com/smartystreets/goconvey v1.6.4/go.mod h1:syvi0/a8iFYH4r/RixwvyeAJjdLS9QV7WQ/tjFTllLA= +github.com/smartystreets/goconvey v1.8.1 h1:qGjIddxOk4grTu9JPOU31tVfq3cNdBlNa5sSznIX1xY= +github.com/smartystreets/goconvey v1.8.1/go.mod h1:+/u4qLyY6x1jReYOp7GOM2FSt8aP9CzCZL03bI28W60= github.com/sourcegraph/conc v0.3.0 h1:OQTbbt6P72L20UqAkXXuLOj79LfEanQ+YQFNpLA9ySo= github.com/sourcegraph/conc v0.3.0/go.mod h1:Sdozi7LEKbFPqYX2/J+iBAM6HpqSLTASQIKqDmF7Mt0= github.com/spf13/afero v1.10.0 h1:EaGW2JJh15aKOejeuJ+wpFSHnbd7GE6Wvp3TsNhb6LY= @@ -311,16 +384,19 @@ github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An github.com/spf13/viper v1.17.0 h1:I5txKw7MJasPL/BrfkbA0Jyo/oELqVmux4pR/UxOMfI= github.com/spf13/viper v1.17.0/go.mod h1:BmMMMLQXSbcHK6KAOiFLz0l5JHrU89OdIRHvsk0+yVI= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.2.0/go.mod h1:qt09Ya8vawLte6SNmTgCsAVtYtaKzEcn8ATUoHMkEqE= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= +github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= @@ -332,12 +408,20 @@ github.com/tjfoc/gmsm v1.4.1 h1:aMe1GlZb+0bLjn+cKTPEvvn9oUEBlJitaZiiBwsbgho= github.com/tjfoc/gmsm v1.4.1/go.mod h1:j4INPkHWMrhJb38G+J6W4Tw0AbuN8Thu3PbdVYhVcTE= github.com/tmc/langchaingo v0.1.13 h1:rcpMWBIi2y3B90XxfE4Ao8dhCQPVDMaNPnN5cGB1CaA= github.com/tmc/langchaingo v0.1.13/go.mod h1:vpQ5NOIhpzxDfTZK9B6tf2GM/MoaHewPWM5KXXGh7hg= +github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI= +github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08= github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= github.com/valyala/fasthttp v1.51.0 h1:8b30A5JlZ6C7AS81RsWjYMQmrZG6feChmgAolCl1SqA= github.com/valyala/fasthttp v1.51.0/go.mod h1:oI2XroL+lI7vdXyYoQk03bXBThfFl2cVdIA3Xl7cH8g= github.com/valyala/tcplisten v1.0.0 h1:rBHj/Xf+E1tRGZyWIWwJDiRY0zc1Js+CV5DqwacVSA8= github.com/valyala/tcplisten v1.0.0/go.mod h1:T0xQ8SeCZGxckz9qRXTfG43PvQ/mcWh7FwZEA7Ioqkc= +github.com/wk8/go-ordered-map/v2 v2.1.8 h1:5h/BUHu93oj4gIdvHHHGsScSTMijfx5PeYkE/fJgbpc= +github.com/wk8/go-ordered-map/v2 v2.1.8/go.mod h1:5nJHM5DyteebpVlHnWMV0rPz6Zp7+xBAnxjb1X5vnTw= +github.com/x-cray/logrus-prefixed-formatter v0.5.2 h1:00txxvfBM9muc0jiLIEAkAcIMJzfthRT6usrui8uGmg= +github.com/x-cray/logrus-prefixed-formatter v0.5.2/go.mod h1:2duySbKsL6M18s5GU7VPsoEPHyzalCE06qoARUCeBBE= +github.com/yargevad/filepathx v1.0.0 h1:SYcT+N3tYGi+NvazubCNlvgIPbzAk7i7y2dwg3I5FYc= +github.com/yargevad/filepathx v1.0.0/go.mod h1:BprfX/gpYNJHJfc35GjRRpVcwWXS89gGulUIU5tK3tA= github.com/yuin/goldmark v1.1.25/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.1.30/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= @@ -353,8 +437,13 @@ go.opencensus.io v0.22.4/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw= go.opencensus.io v0.22.5/go.mod h1:5pWMHQbX5EPX2/62yrJeAkowc+lfs/XD7Uxpq3pI6kk= go.uber.org/atomic v1.9.0 h1:ECmE8Bn/WFTYwEW/bpKD3M8VtR/zQVbavAoalC1PYyE= go.uber.org/atomic v1.9.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= +go.uber.org/mock v0.4.0 h1:VcM4ZOtdbR4f6VXfiOpwpVJDL6lCReaZ6mw31wqh7KU= +go.uber.org/mock v0.4.0/go.mod h1:a6FSlNadKUHUa9IP5Vyt1zh4fC7uAwxMutEAscFbkZc= go.uber.org/multierr v1.9.0 h1:7fIwc/ZtS0q++VgcfqFDxSBZVv/Xo49/SYnDFupUwlI= go.uber.org/multierr v1.9.0/go.mod h1:X2jQV1h+kxSjClGpnseKVIxpmcjrj7MNnI0bnlfKTVQ= +golang.org/x/arch v0.11.0 h1:KXV8WWKCXm6tRpLirl2szsO5j/oOODwZf4hATmGVNs4= +golang.org/x/arch v0.11.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys= +golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20190605123033-f99c8df09eb5/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= @@ -370,8 +459,8 @@ golang.org/x/crypto v0.14.0/go.mod h1:MVFd36DqK4CsrnJYDkBA3VC4m2GkXAM0PvzMCn4JQf golang.org/x/crypto v0.18.0/go.mod h1:R0j02AL6hcrfOiy9T4ZYp/rcWeMxM3L6QYxlOuEG1mg= golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU= golang.org/x/crypto v0.21.0/go.mod h1:0BP7YvVV9gBbVKyeTG0Gyn+gZm94bibOW5BjDEYAOMs= -golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34= -golang.org/x/crypto v0.36.0/go.mod h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc= +golang.org/x/crypto v0.39.0 h1:SHs+kF4LP+f+p14esP5jAoDpHU8Gu/v9lFRK6IT5imM= +golang.org/x/crypto v0.39.0/go.mod h1:L+Xg3Wf6HoL4Bn4238Z6ft6KfEpN0tJGo53AAPC632U= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8= @@ -475,9 +564,10 @@ golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.12.0 h1:MHc5BpPuC30uJk597Ri8TV3CNZcTLu6B6z4lJy+g6Jw= -golang.org/x/sync v0.12.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= +golang.org/x/sync v0.15.0 h1:KWH3jNZsfyT6xfAfKiz6MRNmd46ByHDYaZ7KSkCtdW8= +golang.org/x/sync v0.15.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= @@ -519,6 +609,7 @@ golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20210423185535-09eb48e85fd7/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220908164124-27713097b956/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= @@ -529,8 +620,8 @@ golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.16.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.18.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/sys v0.31.0 h1:ioabZlmFYtWhL+TRYpcnNlLwhyxaM9kWTDEmfnprqik= -golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= +golang.org/x/sys v0.33.0 h1:q3i8TbbEz+JRD9ywIRlyRAQbM0qF7hu24q3teo2hbuw= +golang.org/x/sys v0.33.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= @@ -539,8 +630,8 @@ golang.org/x/term v0.13.0/go.mod h1:LTmsnFJwVN6bCy1rVCoS+qHT1HhALEFxKncY3WNNh4U= golang.org/x/term v0.16.0/go.mod h1:yn7UURbUtPyrVJPGPq404EukNFxcm/foM+bV/bfcDsY= golang.org/x/term v0.17.0/go.mod h1:lLRBjIVuehSbZlaOtGMbcMncT+aqLLLmKrsjNrUguwk= golang.org/x/term v0.18.0/go.mod h1:ILwASektA3OnRv7amZ1xhE/KTR+u50pbXfZ03+6Nx58= -golang.org/x/term v0.30.0 h1:PQ39fJZ+mfadBm0y5WlL4vlM7Sx1Hgf13sMIY2+QS9Y= -golang.org/x/term v0.30.0/go.mod h1:NYYFdzHoI5wRh/h5tDMdMqCqPJZEuNqVR5xJLd/n67g= +golang.org/x/term v0.32.0 h1:DR4lr0TjUs3epypdhTOkMmuF5CDFJ/8pOnbzMZPQ7bg= +golang.org/x/term v0.32.0/go.mod h1:uZG1FhGx848Sqfsq4/DlJr3xGGsYMu/L5GW4abiaEPQ= golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= @@ -553,8 +644,8 @@ golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= -golang.org/x/text v0.23.0 h1:D71I7dUrlY+VX0gQShAThNGHFxZ13dGLBHQLVl1mJlY= -golang.org/x/text v0.23.0/go.mod h1:/BLNzu4aZCJ1+kcD0DNRotWKage4q2rGVAg4o22unh4= +golang.org/x/text v0.26.0 h1:P42AVeLghgTYr4+xUnTRKDMqpar+PtX7KWuNQL21L8M= +golang.org/x/text v0.26.0/go.mod h1:QK15LZJUUQVJxhz7wXgxSy/CJaTFjd0G+YLonydOVQA= golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= @@ -721,6 +812,7 @@ gopkg.in/ini.v1 v1.67.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k= gopkg.in/mgo.v2 v2.0.0-20160818020120-3f83fa500528/go.mod h1:yeKp02qBN3iKW1OzL3MGk2IdtZzaj7SFntXj72NppTA= gopkg.in/redis.v4 v4.2.4/go.mod h1:8KREHdypkCEojGKQcjMqAODMICIVwZAONWq8RowTITA= gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw= +gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.4/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= diff --git a/internal/config/config.go b/internal/config/config.go index 5c9fa21..60a397b 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -18,7 +18,7 @@ type Config struct { DB DB `mapstructure:"db"` DefaultPrompt SysPrompt `mapstructure:"default_prompt"` PermissionConfig PermissionConfig `mapstructure:"permissionConfig"` - // LLM *LLM `mapstructure:"llm"` + LLM LLM `mapstructure:"llm"` } type SysPrompt struct { @@ -31,7 +31,30 @@ type DefaultPrompt struct { } type LLM struct { - Model string `mapstructure:"model"` + Providers map[string]LLMProviderConfig `mapstructure:"providers"` + Capabilities map[string]LLMCapabilityConfig `mapstructure:"capabilities"` +} +type LLMProviderConfig struct { + Endpoint string `mapstructure:"endpoint"` + Timeout string `mapstructure:"timeout"` + Models []LLMModle `mapstructure:"models"` +} +type LLMModle struct { + ID string `mapstructure:"id"` + Name string `mapstructure:"name"` + Streaming bool `mapstructure:"streaming"` + Modalities []string `mapstructure:"modalities"` + MaxTokens int `mapstructure:"max_tokens"` +} +type LLMParameters struct { + Temperature float64 `mapstructure:"temperature"` + MaxTokens int `mapstructure:"max_tokens"` + Stream bool `mapstructure:"stream"` +} +type LLMCapabilityConfig struct { + Provider string `mapstructure:"provider"` + Model string `mapstructure:"model"` + Parameters LLMParameters `mapstructure:"parameters"` } // SysConfig 系统配置 diff --git a/internal/domain/common/mapper.go b/internal/domain/common/mapper.go new file mode 100644 index 0000000..2b42e28 --- /dev/null +++ b/internal/domain/common/mapper.go @@ -0,0 +1,11 @@ +package common + +import ( + "ai_scheduler/internal/config" + "ai_scheduler/internal/domain/llm" +) + +func OptionsFromLLMParameters(p config.LLMParameters) llm.Options { + return llm.Options{Temperature: float32(p.Temperature), MaxTokens: p.MaxTokens, Stream: p.Stream} +} + diff --git a/internal/domain/common/types.go b/internal/domain/common/types.go new file mode 100644 index 0000000..ebd62ed --- /dev/null +++ b/internal/domain/common/types.go @@ -0,0 +1,4 @@ +package common + +type KV map[string]any + diff --git a/internal/domain/llm/api.go b/internal/domain/llm/api.go new file mode 100644 index 0000000..8b4b87f --- /dev/null +++ b/internal/domain/llm/api.go @@ -0,0 +1,14 @@ +package llm + +import ( + "context" + "github.com/cloudwego/eino/schema" +) + +type Service interface { + Chat(ctx context.Context, input []*schema.Message, opts Options) (*schema.Message, error) + ChatStream(ctx context.Context, input []*schema.Message, opts Options) (*schema.StreamReader[*schema.Message], error) + Vision(ctx context.Context, input []*schema.Message, opts Options) (*schema.Message, error) + Intent(ctx context.Context, input []*schema.Message, opts Options) (*schema.Message, error) +} + diff --git a/internal/domain/llm/capability/router.go b/internal/domain/llm/capability/router.go new file mode 100644 index 0000000..5798ed8 --- /dev/null +++ b/internal/domain/llm/capability/router.go @@ -0,0 +1,48 @@ +package capability + +import ( + "ai_scheduler/internal/config" + "ai_scheduler/internal/domain/llm" + "strings" + "time" +) + +func Route(cfg *config.Config, ability Ability) (ProviderChoice, llm.Options, error) { + cap, ok := cfg.LLM.Capabilities[string(ability)] + if !ok { + return ProviderChoice{}, llm.Options{}, llm.ErrInvalidCapability + } + prov, ok := cfg.LLM.Providers[cap.Provider] + if !ok { + return ProviderChoice{}, llm.Options{}, llm.ErrProviderNotFound + } + var modelConf config.LLMModle + found := false + for _, m := range prov.Models { + if m.Name == cap.Model || m.ID == cap.Model { + modelConf = m + found = true + break + } + } + if !found { + return ProviderChoice{}, llm.Options{}, llm.ErrModelNotFound + } + to := llm.Options{} + to.Model = modelConf.Name + to.Stream = cap.Parameters.Stream || modelConf.Streaming + if to.Stream && !modelConf.Streaming { + to.Stream = false + } + to.MaxTokens = modelConf.MaxTokens + if cap.Parameters.MaxTokens > 0 && cap.Parameters.MaxTokens <= modelConf.MaxTokens { + to.MaxTokens = cap.Parameters.MaxTokens + } + to.Temperature = float32(cap.Parameters.Temperature) + to.Modalities = append([]string{}, modelConf.Modalities...) + d, _ := time.ParseDuration(strings.TrimSpace(prov.Timeout)) + to.Timeout = d + to.Endpoint = prov.Endpoint + choice := ProviderChoice{Provider: cap.Provider, Model: to.Model} + return choice, to, nil +} diff --git a/internal/domain/llm/capability/types.go b/internal/domain/llm/capability/types.go new file mode 100644 index 0000000..618728f --- /dev/null +++ b/internal/domain/llm/capability/types.go @@ -0,0 +1,15 @@ +package capability + +type Ability string + +const ( + Intent Ability = "intent" + Vision Ability = "vision" + Chat Ability = "chat" +) + +type ProviderChoice struct { + Provider string + Model string +} + diff --git a/internal/domain/llm/capability/validator.go b/internal/domain/llm/capability/validator.go new file mode 100644 index 0000000..c2edfd2 --- /dev/null +++ b/internal/domain/llm/capability/validator.go @@ -0,0 +1,22 @@ +package capability + +import ( + "ai_scheduler/internal/domain/llm" +) + +func Validate(ability Ability, opts llm.Options) error { + if ability == Vision { + has := false + for _, m := range opts.Modalities { + if m == "image" { + has = true + break + } + } + if !has { + return llm.ErrModalityMismatch + } + } + return nil +} + diff --git a/internal/domain/llm/errors.go b/internal/domain/llm/errors.go new file mode 100644 index 0000000..59d76e6 --- /dev/null +++ b/internal/domain/llm/errors.go @@ -0,0 +1,9 @@ +package llm + +import "errors" + +var ErrInvalidCapability = errors.New("invalid capability") +var ErrProviderNotFound = errors.New("provider not found") +var ErrModelNotFound = errors.New("model not found") +var ErrModalityMismatch = errors.New("modality mismatch") +var ErrNotImplemented = errors.New("not implemented") diff --git a/internal/domain/llm/options.go b/internal/domain/llm/options.go new file mode 100644 index 0000000..c427153 --- /dev/null +++ b/internal/domain/llm/options.go @@ -0,0 +1,16 @@ +package llm + +import "time" + +type Options struct { + Temperature float32 + MaxTokens int + Stream bool + Timeout time.Duration + Modalities []string + SystemPrompt string + Model string + TopP float32 + Stop []string + Endpoint string +} diff --git a/internal/domain/llm/pipeline/chat.go b/internal/domain/llm/pipeline/chat.go new file mode 100644 index 0000000..50cc405 --- /dev/null +++ b/internal/domain/llm/pipeline/chat.go @@ -0,0 +1,38 @@ +package pipeline + +import ( + "context" + "github.com/cloudwego/eino/compose" + "github.com/cloudwego/eino/schema" + "ai_scheduler/internal/config" + "ai_scheduler/internal/domain/llm" + "ai_scheduler/internal/domain/llm/capability" + "ai_scheduler/internal/domain/llm/provider" + "ai_scheduler/internal/domain/llm/provider/ollama" + "ai_scheduler/internal/domain/llm/provider/vllm" +) + +func init() { + provider.Register("ollama", func() provider.Adapter { return ollama.New() }) + provider.Register("vllm", func() provider.Adapter { return vllm.New() }) +} + +func BuildChat(ctx context.Context, cfg *config.Config) (compose.Runnable[[]*schema.Message, *schema.Message], error) { + choice, opts, err := capability.Route(cfg, capability.Chat) + if err != nil { + return nil, err + } + if err = capability.Validate(capability.Chat, opts); err != nil { + return nil, err + } + f := provider.Get(choice.Provider) + if f == nil { + return nil, llm.ErrProviderNotFound + } + ad := f() + c := compose.NewChain[[]*schema.Message, *schema.Message]() + c.AppendLambda(compose.InvokableLambda(func(ctx context.Context, in []*schema.Message) (*schema.Message, error) { + return ad.Generate(ctx, in, opts) + })) + return c.Compile(ctx) +} diff --git a/internal/domain/llm/pipeline/intent.go b/internal/domain/llm/pipeline/intent.go new file mode 100644 index 0000000..d867933 --- /dev/null +++ b/internal/domain/llm/pipeline/intent.go @@ -0,0 +1,26 @@ +package pipeline + +import ( + "context" + "github.com/cloudwego/eino/compose" + "github.com/cloudwego/eino/schema" + "ai_scheduler/internal/config" + "ai_scheduler/internal/domain/llm" + "ai_scheduler/internal/domain/llm/capability" + "ai_scheduler/internal/domain/llm/provider" +) + +func BuildIntent(ctx context.Context, cfg *config.Config) (compose.Runnable[[]*schema.Message, *schema.Message], error) { + choice, opts, err := capability.Route(cfg, capability.Intent) + if err != nil { return nil, err } + if err = capability.Validate(capability.Intent, opts); err != nil { return nil, err } + f := provider.Get(choice.Provider) + if f == nil { return nil, llm.ErrProviderNotFound } + ad := f() + c := compose.NewChain[[]*schema.Message, *schema.Message]() + c.AppendLambda(compose.InvokableLambda(func(ctx context.Context, in []*schema.Message) (*schema.Message, error) { + return ad.Generate(ctx, in, opts) + })) + return c.Compile(ctx) +} + diff --git a/internal/domain/llm/pipeline/vision.go b/internal/domain/llm/pipeline/vision.go new file mode 100644 index 0000000..3563048 --- /dev/null +++ b/internal/domain/llm/pipeline/vision.go @@ -0,0 +1,26 @@ +package pipeline + +import ( + "context" + "github.com/cloudwego/eino/compose" + "github.com/cloudwego/eino/schema" + "ai_scheduler/internal/config" + "ai_scheduler/internal/domain/llm" + "ai_scheduler/internal/domain/llm/capability" + "ai_scheduler/internal/domain/llm/provider" +) + +func BuildVision(ctx context.Context, cfg *config.Config) (compose.Runnable[[]*schema.Message, *schema.Message], error) { + choice, opts, err := capability.Route(cfg, capability.Vision) + if err != nil { return nil, err } + if err = capability.Validate(capability.Vision, opts); err != nil { return nil, err } + f := provider.Get(choice.Provider) + if f == nil { return nil, llm.ErrProviderNotFound } + ad := f() + c := compose.NewChain[[]*schema.Message, *schema.Message]() + c.AppendLambda(compose.InvokableLambda(func(ctx context.Context, in []*schema.Message) (*schema.Message, error) { + return ad.Generate(ctx, in, opts) + })) + return c.Compile(ctx) +} + diff --git a/internal/domain/llm/prompt/templates.go b/internal/domain/llm/prompt/templates.go new file mode 100644 index 0000000..47da00a --- /dev/null +++ b/internal/domain/llm/prompt/templates.go @@ -0,0 +1,6 @@ +package prompt + +func SystemForChat() string { return "You are a helpful assistant." } +func SystemForVision() string { return "You are a vision assistant." } +func SystemForIntent() string { return "You classify user intent." } + diff --git a/internal/domain/llm/provider/ollama/adapter.go b/internal/domain/llm/provider/ollama/adapter.go new file mode 100644 index 0000000..1c26bab --- /dev/null +++ b/internal/domain/llm/provider/ollama/adapter.go @@ -0,0 +1,72 @@ +package ollama + +import ( + "ai_scheduler/internal/domain/llm" + "context" + + eino_ollama "github.com/cloudwego/eino-ext/components/model/ollama" + eino_model "github.com/cloudwego/eino/components/model" + "github.com/cloudwego/eino/schema" +) + +type Adapter struct{} + +func New() *Adapter { return &Adapter{} } + +func (a *Adapter) Generate(ctx context.Context, input []*schema.Message, opts llm.Options) (*schema.Message, error) { + cm, err := eino_ollama.NewChatModel(ctx, &eino_ollama.ChatModelConfig{ + BaseURL: opts.Endpoint, + Timeout: opts.Timeout, + Model: opts.Model, + Options: &eino_ollama.Options{Temperature: opts.Temperature, NumPredict: opts.MaxTokens}, + }) + if err != nil { + return nil, err + } + var mopts []eino_model.Option + if opts.Temperature != 0 { + mopts = append(mopts, eino_model.WithTemperature(opts.Temperature)) + } + if opts.MaxTokens > 0 { + mopts = append(mopts, eino_model.WithMaxTokens(opts.MaxTokens)) + } + if opts.Model != "" { + mopts = append(mopts, eino_model.WithModel(opts.Model)) + } + if opts.TopP != 0 { + mopts = append(mopts, eino_model.WithTopP(opts.TopP)) + } + if len(opts.Stop) > 0 { + mopts = append(mopts, eino_model.WithStop(opts.Stop)) + } + return cm.Generate(ctx, input, mopts...) +} + +func (a *Adapter) Stream(ctx context.Context, input []*schema.Message, opts llm.Options) (*schema.StreamReader[*schema.Message], error) { + cm, err := eino_ollama.NewChatModel(ctx, &eino_ollama.ChatModelConfig{ + BaseURL: opts.Endpoint, + Timeout: opts.Timeout, + Model: opts.Model, + Options: &eino_ollama.Options{Temperature: opts.Temperature, NumPredict: opts.MaxTokens}, + }) + if err != nil { + return nil, err + } + var mopts []eino_model.Option + if opts.Temperature != 0 { + mopts = append(mopts, eino_model.WithTemperature(opts.Temperature)) + } + if opts.MaxTokens > 0 { + mopts = append(mopts, eino_model.WithMaxTokens(opts.MaxTokens)) + } + if opts.Model != "" { + mopts = append(mopts, eino_model.WithModel(opts.Model)) + } + if opts.TopP != 0 { + mopts = append(mopts, eino_model.WithTopP(opts.TopP)) + } + if len(opts.Stop) > 0 { + mopts = append(mopts, eino_model.WithStop(opts.Stop)) + } + return cm.Stream(ctx, input, mopts...) +} diff --git a/internal/domain/llm/provider/registry.go b/internal/domain/llm/provider/registry.go new file mode 100644 index 0000000..9dc7ff5 --- /dev/null +++ b/internal/domain/llm/provider/registry.go @@ -0,0 +1,25 @@ +package provider + +import ( + "context" + "github.com/cloudwego/eino/schema" + "ai_scheduler/internal/domain/llm" +) + +type Adapter interface { + Generate(ctx context.Context, input []*schema.Message, opts llm.Options) (*schema.Message, error) + Stream(ctx context.Context, input []*schema.Message, opts llm.Options) (*schema.StreamReader[*schema.Message], error) +} + +type Factory func() Adapter + +var registry = map[string]Factory{} + +func Register(name string, f Factory) { + registry[name] = f +} + +func Get(name string) Factory { + return registry[name] +} + diff --git a/internal/domain/llm/provider/vllm/adapter.go b/internal/domain/llm/provider/vllm/adapter.go new file mode 100644 index 0000000..b3828db --- /dev/null +++ b/internal/domain/llm/provider/vllm/adapter.go @@ -0,0 +1,70 @@ +package vllm + +import ( + "ai_scheduler/internal/domain/llm" + "context" + + eino_openai "github.com/cloudwego/eino-ext/components/model/openai" + eino_model "github.com/cloudwego/eino/components/model" + "github.com/cloudwego/eino/schema" +) + +type Adapter struct{} + +func New() *Adapter { return &Adapter{} } + +func (a *Adapter) Generate(ctx context.Context, input []*schema.Message, opts llm.Options) (*schema.Message, error) { + cm, err := eino_openai.NewChatModel(ctx, &eino_openai.ChatModelConfig{ + BaseURL: opts.Endpoint, + Timeout: opts.Timeout, + Model: opts.Model, + }) + if err != nil { + return nil, err + } + var mopts []eino_model.Option + if opts.Temperature != 0 { + mopts = append(mopts, eino_model.WithTemperature(opts.Temperature)) + } + if opts.MaxTokens > 0 { + mopts = append(mopts, eino_model.WithMaxTokens(opts.MaxTokens)) + } + if opts.Model != "" { + mopts = append(mopts, eino_model.WithModel(opts.Model)) + } + if opts.TopP != 0 { + mopts = append(mopts, eino_model.WithTopP(opts.TopP)) + } + if len(opts.Stop) > 0 { + mopts = append(mopts, eino_model.WithStop(opts.Stop)) + } + return cm.Generate(ctx, input, mopts...) +} + +func (a *Adapter) Stream(ctx context.Context, input []*schema.Message, opts llm.Options) (*schema.StreamReader[*schema.Message], error) { + cm, err := eino_openai.NewChatModel(ctx, &eino_openai.ChatModelConfig{ + BaseURL: opts.Endpoint, + Timeout: opts.Timeout, + Model: opts.Model, + }) + if err != nil { + return nil, err + } + var mopts []eino_model.Option + if opts.Temperature != 0 { + mopts = append(mopts, eino_model.WithTemperature(opts.Temperature)) + } + if opts.MaxTokens > 0 { + mopts = append(mopts, eino_model.WithMaxTokens(opts.MaxTokens)) + } + if opts.Model != "" { + mopts = append(mopts, eino_model.WithModel(opts.Model)) + } + if opts.TopP != 0 { + mopts = append(mopts, eino_model.WithTopP(opts.TopP)) + } + if len(opts.Stop) > 0 { + mopts = append(mopts, eino_model.WithStop(opts.Stop)) + } + return cm.Stream(ctx, input, mopts...) +} diff --git a/internal/domain/llm/service/chat_service.go b/internal/domain/llm/service/chat_service.go new file mode 100644 index 0000000..f642cf4 --- /dev/null +++ b/internal/domain/llm/service/chat_service.go @@ -0,0 +1,22 @@ +package service + +import ( + "context" + "github.com/cloudwego/eino/compose" + "github.com/cloudwego/eino/schema" + "ai_scheduler/internal/config" + "ai_scheduler/internal/domain/llm/pipeline" +) + +type ChatService struct{ run compose.Runnable[[]*schema.Message, *schema.Message] } + +func NewChatService(ctx context.Context, cfg *config.Config) (*ChatService, error) { + r, err := pipeline.BuildChat(ctx, cfg) + if err != nil { return nil, err } + return &ChatService{run: r}, nil +} + +func (s *ChatService) Invoke(ctx context.Context, msgs []*schema.Message) (*schema.Message, error) { + return s.run.Invoke(ctx, msgs) +} + diff --git a/internal/domain/llm/service/intent_service.go b/internal/domain/llm/service/intent_service.go new file mode 100644 index 0000000..25929ae --- /dev/null +++ b/internal/domain/llm/service/intent_service.go @@ -0,0 +1,22 @@ +package service + +import ( + "context" + "github.com/cloudwego/eino/compose" + "github.com/cloudwego/eino/schema" + "ai_scheduler/internal/config" + "ai_scheduler/internal/domain/llm/pipeline" +) + +type IntentService struct{ run compose.Runnable[[]*schema.Message, *schema.Message] } + +func NewIntentService(ctx context.Context, cfg *config.Config) (*IntentService, error) { + r, err := pipeline.BuildIntent(ctx, cfg) + if err != nil { return nil, err } + return &IntentService{run: r}, nil +} + +func (s *IntentService) Invoke(ctx context.Context, msgs []*schema.Message) (*schema.Message, error) { + return s.run.Invoke(ctx, msgs) +} + diff --git a/internal/domain/llm/service/service.go b/internal/domain/llm/service/service.go new file mode 100644 index 0000000..a8e2e21 --- /dev/null +++ b/internal/domain/llm/service/service.go @@ -0,0 +1,96 @@ +package service + +import ( + "ai_scheduler/internal/config" + "ai_scheduler/internal/domain/llm" + "ai_scheduler/internal/domain/llm/capability" + "ai_scheduler/internal/domain/llm/provider" + "context" + + "github.com/cloudwego/eino/schema" +) + +type LLMService struct{ cfg *config.Config } + +func NewLLMService(cfg *config.Config) *LLMService { return &LLMService{cfg: cfg} } + +func (s *LLMService) Chat(ctx context.Context, input []*schema.Message, opts llm.Options) (*schema.Message, error) { + choice, routeOpts, err := capability.Route(s.cfg, capability.Chat) + if err != nil { + return nil, err + } + mergeOptions(&routeOpts, opts) + f := provider.Get(choice.Provider) + if f == nil { + return nil, llm.ErrProviderNotFound + } + return f().Generate(ctx, input, routeOpts) +} + +func (s *LLMService) ChatStream(ctx context.Context, input []*schema.Message, opts llm.Options) (*schema.StreamReader[*schema.Message], error) { + choice, routeOpts, err := capability.Route(s.cfg, capability.Chat) + if err != nil { + return nil, err + } + routeOpts.Stream = true + mergeOptions(&routeOpts, opts) + f := provider.Get(choice.Provider) + if f == nil { + return nil, llm.ErrProviderNotFound + } + return f().Stream(ctx, input, routeOpts) +} + +func (s *LLMService) Vision(ctx context.Context, input []*schema.Message, opts llm.Options) (*schema.Message, error) { + choice, routeOpts, err := capability.Route(s.cfg, capability.Vision) + if err != nil { + return nil, err + } + mergeOptions(&routeOpts, opts) + f := provider.Get(choice.Provider) + if f == nil { + return nil, llm.ErrProviderNotFound + } + return f().Generate(ctx, input, routeOpts) +} + +func (s *LLMService) Intent(ctx context.Context, input []*schema.Message, opts llm.Options) (*schema.Message, error) { + choice, routeOpts, err := capability.Route(s.cfg, capability.Intent) + if err != nil { + return nil, err + } + mergeOptions(&routeOpts, opts) + f := provider.Get(choice.Provider) + if f == nil { + return nil, llm.ErrProviderNotFound + } + return f().Generate(ctx, input, routeOpts) +} + +func mergeOptions(base *llm.Options, override llm.Options) { + if override.Model != "" { + base.Model = override.Model + } + if override.MaxTokens > 0 { + base.MaxTokens = override.MaxTokens + } + if override.Temperature != 0 { + base.Temperature = override.Temperature + } + if override.Timeout > 0 { + base.Timeout = override.Timeout + } + if len(override.Modalities) > 0 { + base.Modalities = override.Modalities + } + if override.SystemPrompt != "" { + base.SystemPrompt = override.SystemPrompt + } + if override.TopP != 0 { + base.TopP = override.TopP + } + if len(override.Stop) > 0 { + base.Stop = override.Stop + } + base.Stream = base.Stream || override.Stream +} diff --git a/internal/domain/llm/service/vision_service.go b/internal/domain/llm/service/vision_service.go new file mode 100644 index 0000000..be37be2 --- /dev/null +++ b/internal/domain/llm/service/vision_service.go @@ -0,0 +1,22 @@ +package service + +import ( + "context" + "github.com/cloudwego/eino/compose" + "github.com/cloudwego/eino/schema" + "ai_scheduler/internal/config" + "ai_scheduler/internal/domain/llm/pipeline" +) + +type VisionService struct{ run compose.Runnable[[]*schema.Message, *schema.Message] } + +func NewVisionService(ctx context.Context, cfg *config.Config) (*VisionService, error) { + r, err := pipeline.BuildVision(ctx, cfg) + if err != nil { return nil, err } + return &VisionService{run: r}, nil +} + +func (s *VisionService) Invoke(ctx context.Context, msgs []*schema.Message) (*schema.Message, error) { + return s.run.Invoke(ctx, msgs) +} + diff --git a/internal/domain/tools/registry.go b/internal/domain/tools/registry.go new file mode 100644 index 0000000..dc6a67e --- /dev/null +++ b/internal/domain/tools/registry.go @@ -0,0 +1,16 @@ +package tools + +type Tool interface{ + Name() string +} + +var registry = map[string]Tool{} + +func Register(t Tool){ + registry[t.Name()] = t +} + +func Get(name string) Tool{ + return registry[name] +} + From f240e1fac4e9ad599518e31a929ccb6dede177f8 Mon Sep 17 00:00:00 2001 From: fuzhongyun <15339891972@163.com> Date: Fri, 5 Dec 2025 09:11:38 +0800 Subject: [PATCH 03/66] =?UTF-8?q?feat:=20=E5=A2=9E=E5=8A=A0=E5=A4=9A?= =?UTF-8?q?=E6=A8=A1=E6=80=81=E5=8F=82=E6=95=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- internal/biz/llm_service/ollama.go | 56 +++++++++++++++ internal/domain/common/vision_builder.go | 37 ++++++++++ internal/domain/llm/errors.go | 8 +-- internal/domain/llm/pipeline/vision.go | 88 +++++++++++++++++++----- internal/domain/llm/prompt/templates.go | 13 ++-- 5 files changed, 176 insertions(+), 26 deletions(-) create mode 100644 internal/domain/common/vision_builder.go diff --git a/internal/biz/llm_service/ollama.go b/internal/biz/llm_service/ollama.go index 60d9c78..98100fa 100644 --- a/internal/biz/llm_service/ollama.go +++ b/internal/biz/llm_service/ollama.go @@ -13,6 +13,8 @@ import ( "strings" "time" + "github.com/cloudwego/eino-ext/components/model/openai" + "github.com/cloudwego/eino/schema" "github.com/ollama/ollama/api" "xorm.io/builder" ) @@ -147,6 +149,60 @@ func (r *OllamaService) RecognizeWithImg(ctx context.Context, requireData *entit return } +func (r *OllamaService) RecognizeWithImgVllm(ctx context.Context, requireData *entitys.RequireData) (desc api.GenerateResponse, err error) { + if requireData.ImgByte == nil { + return + } + entitys.ResLog(requireData.Ch, "recognize_img_start", "图片识别中...") + + chatModel, err := openai.NewChatModel(ctx, &openai.ChatModelConfig{}) + if err != nil { + return api.GenerateResponse{}, err + } + in := []*schema.Message{ + { + Role: "system", + Content: r.config.DefaultPrompt.ImgRecognize.SystemPrompt, + }, + { + Role: "user", + Content: r.config.DefaultPrompt.ImgRecognize.UserPrompt, + UserInputMultiContent: []schema.MessageInputPart{ + { + Type: schema.ChatMessagePartTypeText, + Text: r.config.DefaultPrompt.ImgRecognize.UserPrompt, + }, + }, + }, + } + + for _, imgUrl := range requireData.ImgUrls { + imgTmp := imgUrl + + in[1].UserInputMultiContent = append(in[1].UserInputMultiContent, schema.MessageInputPart{ + Type: schema.ChatMessagePartTypeImageURL, + Image: &schema.MessageInputImage{ + MessagePartCommon: schema.MessagePartCommon{ + URL: &imgTmp, + }, + Detail: schema.ImageURLDetailHigh, + }, + }) + } + + outMsg, err := chatModel.Generate(ctx, in) + if err != nil { + return api.GenerateResponse{}, err + } + + desc = api.GenerateResponse{ + Response: outMsg.Content, + } + + entitys.ResLog(requireData.Ch, "recognize_img_end", "图片识别完成,识别内容:"+desc.Response) + return +} + func (r *OllamaService) registerToolsOllama(tasks []model.AiTask) []api.Tool { taskPrompt := make([]api.Tool, 0) for _, task := range tasks { diff --git a/internal/domain/common/vision_builder.go b/internal/domain/common/vision_builder.go new file mode 100644 index 0000000..66153bf --- /dev/null +++ b/internal/domain/common/vision_builder.go @@ -0,0 +1,37 @@ +package common + +import ( + "errors" + "strings" + + "github.com/cloudwego/eino/schema" +) + +type ImageInput struct { + URLs []string +} + +func BuildVisionMessages(systemPrompt string, userText string, images ImageInput) ([]*schema.Message, error) { + if len(images.URLs) == 0 { + return nil, errors.New("vision requires at least one image url") + } + parts := make([]schema.MessageInputPart, 0, 1+len(images.URLs)) + if strings.TrimSpace(userText) != "" { + parts = append(parts, schema.MessageInputPart{Type: schema.ChatMessagePartTypeText, Text: userText}) + } + for _, u := range images.URLs { + if u == "" { + continue + } + if !strings.HasPrefix(u, "http://") && !strings.HasPrefix(u, "https://") { + continue + } + parts = append(parts, schema.MessageInputPart{Type: schema.ChatMessagePartTypeImageURL, Image: &schema.MessageInputImage{MessagePartCommon: schema.MessagePartCommon{URL: &u}, Detail: schema.ImageURLDetailHigh}}) + } + if len(parts) == 0 { + return nil, errors.New("vision inputs invalid: no text or valid image urls") + } + msgs := []*schema.Message{schema.SystemMessage(systemPrompt)} + msgs = append(msgs, &schema.Message{Role: schema.User, UserInputMultiContent: parts}) + return msgs, nil +} diff --git a/internal/domain/llm/errors.go b/internal/domain/llm/errors.go index 59d76e6..366cf33 100644 --- a/internal/domain/llm/errors.go +++ b/internal/domain/llm/errors.go @@ -2,8 +2,8 @@ package llm import "errors" -var ErrInvalidCapability = errors.New("invalid capability") -var ErrProviderNotFound = errors.New("provider not found") -var ErrModelNotFound = errors.New("model not found") -var ErrModalityMismatch = errors.New("modality mismatch") +var ErrInvalidCapability = errors.New("能力未配置或无效") +var ErrProviderNotFound = errors.New("提供者未找到或未注册") +var ErrModelNotFound = errors.New("模型未找到或未配置") +var ErrModalityMismatch = errors.New("模态不匹配:视觉能力需要包含 image") var ErrNotImplemented = errors.New("not implemented") diff --git a/internal/domain/llm/pipeline/vision.go b/internal/domain/llm/pipeline/vision.go index 3563048..12f7981 100644 --- a/internal/domain/llm/pipeline/vision.go +++ b/internal/domain/llm/pipeline/vision.go @@ -1,26 +1,78 @@ package pipeline import ( - "context" - "github.com/cloudwego/eino/compose" - "github.com/cloudwego/eino/schema" - "ai_scheduler/internal/config" - "ai_scheduler/internal/domain/llm" - "ai_scheduler/internal/domain/llm/capability" - "ai_scheduler/internal/domain/llm/provider" + "ai_scheduler/internal/config" + "ai_scheduler/internal/domain/common" + "ai_scheduler/internal/domain/llm" + "ai_scheduler/internal/domain/llm/capability" + "ai_scheduler/internal/domain/llm/provider" + "context" + + "github.com/cloudwego/eino/compose" + "github.com/cloudwego/eino/schema" ) func BuildVision(ctx context.Context, cfg *config.Config) (compose.Runnable[[]*schema.Message, *schema.Message], error) { - choice, opts, err := capability.Route(cfg, capability.Vision) - if err != nil { return nil, err } - if err = capability.Validate(capability.Vision, opts); err != nil { return nil, err } - f := provider.Get(choice.Provider) - if f == nil { return nil, llm.ErrProviderNotFound } - ad := f() - c := compose.NewChain[[]*schema.Message, *schema.Message]() - c.AppendLambda(compose.InvokableLambda(func(ctx context.Context, in []*schema.Message) (*schema.Message, error) { - return ad.Generate(ctx, in, opts) - })) - return c.Compile(ctx) + choice, opts, err := capability.Route(cfg, capability.Vision) + if err != nil { + return nil, err + } + if err = capability.Validate(capability.Vision, opts); err != nil { + return nil, err + } + f := provider.Get(choice.Provider) + if f == nil { + return nil, llm.ErrProviderNotFound + } + ad := f() + c := compose.NewChain[[]*schema.Message, *schema.Message]() + c.AppendLambda(compose.InvokableLambda(func(ctx context.Context, in []*schema.Message) (*schema.Message, error) { + if len(in) == 0 { + msgs, err := common.BuildVisionMessages("你是一名视觉助手,请根据图片与描述进行理解与回答。", "", common.ImageInput{URLs: []string{}}) + if err != nil { + return nil, err + } + return ad.Generate(ctx, msgs, opts) + } + if len(in[0].MultiContent) == 0 { + urls := []string{} + for _, tok := range splitBySpace(in[0].Content) { + if hasHTTPPrefix(tok) { + urls = append(urls, tok) + } + } + msgs, err := common.BuildVisionMessages("你是一名视觉助手,请根据图片与描述进行理解与回答。", "", common.ImageInput{URLs: urls}) + if err != nil { + return nil, err + } + return ad.Generate(ctx, msgs, opts) + } + return ad.Generate(ctx, in, opts) + })) + return c.Compile(ctx) } +func splitBySpace(s string) []string { + res := []string{} + start := -1 + for i, r := range s { + if r == ' ' || r == '\n' || r == '\t' || r == '\r' { + if start >= 0 { + res = append(res, s[start:i]) + start = -1 + } + } else { + if start < 0 { + start = i + } + } + } + if start >= 0 { + res = append(res, s[start:]) + } + return res +} + +func hasHTTPPrefix(s string) bool { + return len(s) >= 7 && (s[:7] == "http://" || (len(s) >= 8 && s[:8] == "https://")) +} diff --git a/internal/domain/llm/prompt/templates.go b/internal/domain/llm/prompt/templates.go index 47da00a..47674a5 100644 --- a/internal/domain/llm/prompt/templates.go +++ b/internal/domain/llm/prompt/templates.go @@ -1,6 +1,11 @@ package prompt -func SystemForChat() string { return "You are a helpful assistant." } -func SystemForVision() string { return "You are a vision assistant." } -func SystemForIntent() string { return "You classify user intent." } - +func SystemForChat() string { + return "你是一名有用的助手,请用清晰、简洁的中文回答。" +} +func SystemForVision() string { + return "你是一名视觉助手,请根据图片与描述进行中文理解与回答。" +} +func SystemForIntent() string { + return "你负责意图识别,请用中文给出明确的意图类别与理由。" +} From f2638b32b55f374307e2cf7960ef9085d55c3f4f Mon Sep 17 00:00:00 2001 From: wolter <11@gmail> Date: Fri, 5 Dec 2025 09:12:06 +0800 Subject: [PATCH 04/66] =?UTF-8?q?feat:=20=E5=BF=83=E8=B7=B3=E6=A3=80?= =?UTF-8?q?=E6=B5=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- internal/biz/do/ctx.go | 16 +++-- internal/biz/router.go | 4 +- internal/entitys/response.go | 5 +- internal/gateway/client.go | 77 +++++++++++++++++------ internal/gateway/gateway.go | 6 +- internal/services/chat.go | 117 +++++++++++++++++++++-------------- 6 files changed, 144 insertions(+), 81 deletions(-) diff --git a/internal/biz/do/ctx.go b/internal/biz/do/ctx.go index ed8749d..eb2c5a5 100644 --- a/internal/biz/do/ctx.go +++ b/internal/biz/do/ctx.go @@ -19,8 +19,6 @@ import ( "gitea.cdlsxd.cn/self-tools/l_request" "github.com/gofiber/fiber/v2/log" - "github.com/gofiber/websocket/v2" - "xorm.io/builder" ) @@ -141,10 +139,10 @@ func (d *Do) loadChatHistory(ctx context.Context, requireData *entitys.RequireDa return nil } -func (d *Do) MakeCh(c *websocket.Conn, requireData *entitys.RequireData) (ctx context.Context, deferFunc func()) { +func (d *Do) MakeCh(client *gateway.Client, requireData *entitys.RequireData) (ctx context.Context, deferFunc func()) { requireData.Ch = make(chan entitys.Response) ctx, cancel := context.WithCancel(context.Background()) - done := d.startMessageHandler(ctx, c, requireData) + done := d.startMessageHandler(ctx, client, requireData) return ctx, func() { close(requireData.Ch) //关闭主通道 <-done // 等待消息处理完成 @@ -235,7 +233,7 @@ func (d *Do) getTasks(sysId int32) (tasks []model.AiTask, err error) { // startMessageHandler 启动独立的消息处理协程 func (d *Do) startMessageHandler( ctx context.Context, - c *websocket.Conn, + client *gateway.Client, requireData *entitys.RequireData, ) <-chan struct{} { done := make(chan struct{}) @@ -259,7 +257,7 @@ func (d *Do) startMessageHandler( hisLog.HisId = AiRes.HisID } - _ = entitys.MsgSend(c, entitys.Response{ + _ = entitys.MsgSend(client, entitys.Response{ Content: pkg.JsonStringIgonErr(hisLog), Type: entitys.ResponseEnd, }) @@ -267,7 +265,7 @@ func (d *Do) startMessageHandler( }() for v := range requireData.Ch { // 自动检测通道关闭 - if err := sendWithTimeout(c, v, 2*time.Second); err != nil { + if err := sendWithTimeout(client, v, 10*time.Second); err != nil { log.Errorf("Send error: %v", err) return } @@ -281,7 +279,7 @@ func (d *Do) startMessageHandler( } // 辅助函数:带超时的 WebSocket 发送 -func sendWithTimeout(c *websocket.Conn, data entitys.Response, timeout time.Duration) error { +func sendWithTimeout(client *gateway.Client, data entitys.Response, timeout time.Duration) error { sendCtx, cancel := context.WithTimeout(context.Background(), timeout) defer cancel() @@ -294,7 +292,7 @@ func sendWithTimeout(c *websocket.Conn, data entitys.Response, timeout time.Dura close(done) }() // 如果 MsgSend 阻塞,这里会卡住 - err := entitys.MsgSend(c, data) + err := entitys.MsgSend(client, data) done <- err }() diff --git a/internal/biz/router.go b/internal/biz/router.go index 6dcc233..975a487 100644 --- a/internal/biz/router.go +++ b/internal/biz/router.go @@ -39,11 +39,9 @@ func (r *AiRouterBiz) RouteWithSocket(client *gateway.Client, req *entitys.ChatS requireData := &entitys.RequireData{ Req: req, } - // 获取WebSocket连接 - conn := client.GetConn() //初始化通道/上下文 - ctx, clearFunc := r.do.MakeCh(conn, requireData) + ctx, clearFunc := r.do.MakeCh(client, requireData) defer func() { if err != nil { entitys.ResError(requireData.Ch, "", err.Error()) diff --git a/internal/entitys/response.go b/internal/entitys/response.go index cdadc98..44e053b 100644 --- a/internal/entitys/response.go +++ b/internal/entitys/response.go @@ -1,6 +1,7 @@ package entitys import ( + "ai_scheduler/internal/gateway" "encoding/json" "github.com/gofiber/websocket/v2" ) @@ -100,13 +101,13 @@ func MsgSet(msgType ResponseType, msg string, done bool) []byte { return jsonByte } -func MsgSend(c *websocket.Conn, msg Response) error { +func MsgSend(client *gateway.Client, msg Response) error { // 检查上下文是否已取消 if msg.Type == ResponseText { } jsonByte, _ := json.Marshal(msg) - return c.WriteMessage(websocket.TextMessage, jsonByte) + return client.SendFunc(jsonByte) } func MsgSendByte(c *websocket.Conn, msg []byte) { diff --git a/internal/gateway/client.go b/internal/gateway/client.go index f90678c..a293daa 100644 --- a/internal/gateway/client.go +++ b/internal/gateway/client.go @@ -3,11 +3,11 @@ package gateway import ( errors "ai_scheduler/internal/data/error" "ai_scheduler/internal/data/model" - "ai_scheduler/internal/pkg" "context" - "encoding/binary" + "github.com/google/uuid" "log" "math/rand" + "sync" "time" "github.com/gofiber/websocket/v2" @@ -32,6 +32,7 @@ type Client struct { Ctx context.Context Cancel context.CancelFunc LastActive time.Time + mu sync.Mutex } func NewClient(conn *websocket.Conn, ctx context.Context, cancel context.CancelFunc) *Client { @@ -41,6 +42,7 @@ func NewClient(conn *websocket.Conn, ctx context.Context, cancel context.CancelF conn: conn, Ctx: ctx, Cancel: cancel, + mu: sync.Mutex{}, } } @@ -74,7 +76,7 @@ func (c *Client) GetCodes() []string { return c.codes } -// GetSysCode 获取系统编码 +// 获取系统编码 func (c *Client) GetSysCode() string { return c.sysCode } @@ -104,26 +106,51 @@ func (c *Client) SetCodes(codes []string) { c.codes = codes } +// Close 关闭客户端连接 +func (c *Client) Close() { + //c.mu.Lock() + //defer c.mu.Unlock() + if c.conn != nil { + _ = c.conn.Close() + c.conn = nil + } +} + // SendFunc 发送消息到客户端 func (c *Client) SendFunc(msg []byte) error { - if c.conn != nil { - return c.conn.WriteMessage(websocket.TextMessage, msg) + return c.SendMessage(websocket.TextMessage, msg) +} + +// 在Client结构体中添加更详细的日志 +func (c *Client) SendMessage(msgType int, msg []byte) error { + c.mu.Lock() + defer c.mu.Unlock() + + if c.conn == nil { + return ErrConnClosed } - return ErrConnClosed + + err := c.conn.WriteMessage(msgType, msg) + if err != nil { + log.Printf("发送消息失败: %v, 客户端ID: %s, 消息类型: %d", + err, c.id, msgType) + } + return err } // 生成唯一的客户端ID func generateClientID() string { - // 1. 时间戳 - timestamp := time.Now().UnixNano() - binary.BigEndian.PutUint64(idBuf[:8], uint64(timestamp)) - - // 2. 随机数(4字节) - binary.BigEndian.PutUint32(idBuf[8:12], rng.Uint32()) - - // 3. 十六进制编码 - n := pkg.HexEncode(idBuf[:12], idBuf[12:]) - return string(idBuf[12 : 12+n]) + return uuid.New().String() + //// 1. 时间戳 + //timestamp := time.Now().UnixNano() + //binary.BigEndian.PutUint64(idBuf[:8], uint64(timestamp)) + // + //// 2. 随机数(4字节) + //binary.BigEndian.PutUint32(idBuf[8:12], rng.Uint32()) + // + //// 3. 十六进制编码 + //n := pkg.HexEncode(idBuf[:12], idBuf[12:]) + //return string(idBuf[12 : 12+n]) } // 连接数据验证和收集 @@ -152,6 +179,7 @@ func (c *Client) DataAuth() (err error) { return } +// 总结:目前绝大多数浏览器不支持直接发送WebSocket Ping帧,因此在实际开发中,应该实现应用层ping机制作为主要心跳检测方案,todo 同时保留对未来可能的原生支持的兼容检测。 func (c *Client) InitHeartbeat(timeoutSecond time.Duration) { ticker := time.NewTicker(timeoutSecond * time.Second) defer ticker.Stop() @@ -160,9 +188,12 @@ func (c *Client) InitHeartbeat(timeoutSecond time.Duration) { case <-ticker.C: //*2是防止丢包,连续丢包两次,再加5s网络延迟容错 if time.Since(c.LastActive) > (timeoutSecond*2*time.Second + 5) { // 5秒容错 - log.Println("Heartbeat timeout", "id", c.id) - c.conn.WriteMessage(websocket.CloseMessage, []byte("Heartbeat timeout")) - c.conn.Close() + log.Println("心跳超时", "clientId", c.id) + err := c.SendMessage(websocket.CloseMessage, []byte("Heartbeat timeout")) + if err != nil { + log.Println("发送心跳超时消息失败", err) + } + c.Close() return } case <-c.Ctx.Done(): @@ -170,3 +201,11 @@ func (c *Client) InitHeartbeat(timeoutSecond time.Duration) { } } } + +// 在Client结构体中添加ReadMessage方法 +func (c *Client) ReadMessage() (messageType int, message []byte, err error) { + if c.conn == nil { + return 0, nil, ErrConnClosed + } + return c.conn.ReadMessage() +} diff --git a/internal/gateway/gateway.go b/internal/gateway/gateway.go index 03ce961..6e09bdc 100644 --- a/internal/gateway/gateway.go +++ b/internal/gateway/gateway.go @@ -34,15 +34,17 @@ func (g *Gateway) AddClient(c *Client) { func (g *Gateway) Cleanup(clientID string) { g.mu.Lock() + // 从网关管理中移除客户端 defer func() { if c, ex := g.clients[clientID]; ex { delete(g.clients, clientID) - _ = c.conn.Close() + c.Close() c.Cancel() } g.mu.Unlock() log.Println("client disconnected:", clientID) }() + // 从所有绑定的UID列表中移除该客户端 for uid, list := range g.uidMap { newList := []string{} for _, cid := range list { @@ -79,7 +81,7 @@ func (g *Gateway) BindUid(clientID, uid string) error { return errors.New("client not found") } g.uidMap[uid] = append(g.uidMap[uid], clientID) - log.Printf("bind %s -> uid:%s\n", clientID, uid) + log.Printf("绑定 clientId %s -> uid:%s\n", clientID, uid) return nil } diff --git a/internal/services/chat.go b/internal/services/chat.go index a01ba73..11dda38 100644 --- a/internal/services/chat.go +++ b/internal/services/chat.go @@ -55,82 +55,109 @@ func (h *ChatService) Chat(c *websocket.Conn) { // 创建新的客户端实例 client := gateway.NewClient(c, ctx, cancel) - // 心跳检测 - go client.InitHeartbeat(time.Duration(h.cfg.Sys.HeartbeatInterval)) - // 将客户端添加到网关管理 + + // 验证并收集连接数据,后续对话中会使用 + if err := client.DataAuth(); err != nil { + log.Println("数据验证错误:", err) + _ = client.SendFunc([]byte(err.Error())) + client.Close() + return + } + + // 验证通过后,将客户端添加到网关管理 h.Gw.AddClient(client) + + // 使用信号量限制并发处理的消息数量 + semaphore := make(chan struct{}, 1) // 最多1个并发消息处理 + // 用于等待所有goroutine完成的wait group + var wg sync.WaitGroup // 确保在函数返回时移除客户端并关闭连接 defer func() { h.Gw.Cleanup(client.GetID()) + close(semaphore) // 关闭信号量通道 + wg.Wait() // 等待所有消息处理goroutine完成 }() - // 绑定会话ID - uid := c.Query("x-session") - if uid != "" { + // 绑定会话ID, sessionId 为空时, 则不绑定 + if uid := client.GetSession(); uid != "" { if err := h.Gw.BindUid(client.GetID(), uid); err != nil { log.Println("绑定UID错误:", err) } } - // 验证并收集连接数据,后续对话中会使用 - if err := client.DataAuth(); err != nil { - log.Println("数据验证错误:", err) - h.ChatFail(c, err.Error()) - return - } + // 开启心跳检测 + go client.InitHeartbeat(time.Duration(h.cfg.Sys.HeartbeatInterval)) // 循环读取客户端消息 for { - // 读取消息 - messageType, message, err := c.ReadMessage() + messageType, message, err := client.ReadMessage() if err != nil { - log.Println("读取错误:", err) + log.Printf("读取错误: %v, 客户端ID: %s", err, client.GetID()) break } - //if string(message) == `{"type":"ping"}` { - // client.LastActive = time.Now() - // if err := c.WriteMessage(websocket.TextMessage, []byte(`{"type":"pong"}`)); err != nil { - // log.Println("Heartbeat response failed", "id", client.GetID(), "err", err) - // return - // } - // continue - //} - // 处理消息 - msg, chatType := h.handleMessageToString(c, messageType, message) - if chatType == constants.ConnStatusClosed { - break - } - if chatType == constants.ConnStatusIgnore { + + // 处理心跳消息 + if messageType == websocket.PingMessage || string(message) == "PING" { + client.LastActive = time.Now() + msgType := websocket.TextMessage + if messageType == websocket.PingMessage { + msgType = websocket.PongMessage + } + if err = client.SendMessage(msgType, []byte(`PONG`)); err != nil { + log.Printf("发送pong消息失败: %v", err) + } continue } - log.Printf("收到消息: %s", string(msg)) + // 使用信号量限制并发 + semaphore <- struct{}{} + wg.Add(1) + go func(msgType int, msg []byte) { + defer func() { + <-semaphore + wg.Done() + // 恢复panic + if r := recover(); r != nil { + log.Printf("消息处理goroutine发生panic: %v", r) + } + }() - // 解析请求 - var req entitys.ChatSockRequest - if err = json.Unmarshal(msg, &req); err != nil { - log.Println("JSON parse error:", err) - continue - } + // 消息处理逻辑 + h.processMessage(client, msgType, msg) + }(messageType, message) + } +} - // 路由处理请求 - err = h.routerBiz.RouteWithSocket(client, &req) - if err != nil { - log.Println("处理失败:", err) - } +// 将消息处理逻辑提取到单独的方法 +func (h *ChatService) processMessage(client *gateway.Client, msgType int, msg []byte) { + // 处理消息 + processedMsg, _ := h.handleMessageToString(client, msgType, msg) + log.Printf("收到消息:消息类型 %d, 内容 %s, 客户端ID: %s", + msgType, string(processedMsg), client.GetID()) + + // 解析请求 + var req entitys.ChatSockRequest + if err := json.Unmarshal(processedMsg, &req); err != nil { + log.Printf("JSON解析错误: %v, 客户端ID: %s", err, client.GetID()) + return + } + + // 路由处理请求 + if err := h.routerBiz.RouteWithSocket(client, &req); err != nil { + log.Printf("处理失败: %v, 客户端ID: %s", err, client.GetID()) } } // handleMessageToString 处理不同类型的WebSocket消息 // 参数: -// - c: WebSocket连接 +// - client: 客户端对象 // - msgType: 消息类型 // - msg: 消息内容 // // 返回: // - text: 处理后的文本内容 // - chatType: 连接状态 -func (h *ChatService) handleMessageToString(c *websocket.Conn, msgType int, msg any) (text []byte, chatType constants.ConnStatus) { +func (h *ChatService) handleMessageToString(client *gateway.Client, msgType int, msg any) (text []byte, chatType constants.ConnStatus) { switch msgType { case websocket.TextMessage: return msg.([]byte), constants.ConnStatusNormal @@ -140,15 +167,13 @@ func (h *ChatService) handleMessageToString(c *websocket.Conn, msgType int, msg return nil, constants.ConnStatusClosed case websocket.PingMessage: - // 可选:回复 Pong - c.WriteMessage(websocket.PongMessage, nil) return nil, constants.ConnStatusIgnore case websocket.PongMessage: return nil, constants.ConnStatusIgnore default: + log.Printf("未知的消息类型: %d", msgType) return nil, constants.ConnStatusIgnore } - return msg.([]byte), constants.ConnStatusIgnore } func (s *ChatService) Useful(c *fiber.Ctx) error { From 67d09e9c91f791d4a57c9328e97e94c001dbef38 Mon Sep 17 00:00:00 2001 From: fuzhongyun <15339891972@163.com> Date: Fri, 5 Dec 2025 14:16:23 +0800 Subject: [PATCH 05/66] chore --- config/config_test.yaml | 1 + 1 file changed, 1 insertion(+) diff --git a/config/config_test.yaml b/config/config_test.yaml index 90c971d..f6a7be0 100644 --- a/config/config_test.yaml +++ b/config/config_test.yaml @@ -128,3 +128,4 @@ llm: temperature: 0.7 max_tokens: 4096 stream: true + From 325ea8d78c9425dc7c8e63bbf3845f1914296da1 Mon Sep 17 00:00:00 2001 From: wolter <11@gmail> Date: Fri, 5 Dec 2025 14:48:49 +0800 Subject: [PATCH 06/66] feat: Dockerfile --- Dockerfile | 26 ++++++++++++++++++++++++-- deploy.sh | 14 +++++++------- go.mod | 6 ++---- 3 files changed, 33 insertions(+), 13 deletions(-) diff --git a/Dockerfile b/Dockerfile index 4b5b1ec..d1cf2c1 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,3 +1,25 @@ +## 使用官方Go镜像作为构建环境 +FROM golang:1.24.0-alpine AS builder + +# 设置工作目录 +WORKDIR /app + +# 使用国内镜像源加速APK包下载 +RUN sed -i 's/dl-cdn.alpinelinux.org/mirrors.aliyun.com/g' /etc/apk/repositories + +# 使用国内镜像源加速依赖下载 +ENV GOPROXY=https://goproxy.cn,direct + +# 复制项目源码 +COPY . . + +# 复制go模块依赖文件 +COPY go.mod go.sum ./ +RUN go mod download + +# 编译Go应用程序,生成静态链接的二进制文件 +RUN go build -ldflags="-s -w" -o server ./cmd/server + # 创建最终镜像,用于运行编译后的Go程序 FROM alpine @@ -12,10 +34,10 @@ RUN echo 'http://mirrors.ustc.edu.cn/alpine/v3.5/main' > /etc/apk/repositories \ WORKDIR /app # 将编译好的二进制文件从构建阶段复制到运行阶段 -COPY ./ /app +COPY --from=builder /app/server ./server ENV TZ=Asia/Shanghai # 设置容器启动时运行的命令 -CMD ["./bin/server"] +CMD ["./server"] diff --git a/deploy.sh b/deploy.sh index e813827..ab694f2 100644 --- a/deploy.sh +++ b/deploy.sh @@ -1,9 +1,9 @@ -export GO111MODULE=on -export GOPROXY=https://goproxy.cn,direct -export GOPATH=/root/go -export GOCACHE=/root/.cache/go-build +#export GO111MODULE=on +#export GOPROXY=https://goproxy.cn,direct +#export GOPATH=/root/go +#export GOCACHE=/root/.cache/go-build export CONTAINER_NAME=ai_scheduler -export CGO_ENABLED='0' +#export CGO_ENABLED='0' MODE="$1" @@ -22,8 +22,8 @@ fi git fetch origin git checkout "$BRANCH" git pull origin "$BRANCH" -go mod tidy -make build +#go mod tidy +#make build docker build -t ${CONTAINER_NAME} . docker stop ${CONTAINER_NAME} docker rm -f ${CONTAINER_NAME} diff --git a/go.mod b/go.mod index 6b08e4a..cee04ec 100644 --- a/go.mod +++ b/go.mod @@ -2,8 +2,6 @@ module ai_scheduler go 1.24.0 -toolchain go1.24.7 - require ( gitea.cdlsxd.cn/self-tools/l_request v1.0.8 github.com/alibabacloud-go/darabonba-openapi/v2 v2.0.12 @@ -23,8 +21,6 @@ require ( github.com/spf13/viper v1.17.0 github.com/tmc/langchaingo v0.1.13 google.golang.org/grpc v1.64.0 - google.golang.org/protobuf v1.34.1 - gopkg.in/yaml.v3 v3.0.1 gorm.io/driver/mysql v1.6.0 gorm.io/gorm v1.31.0 xorm.io/builder v0.3.13 @@ -82,5 +78,7 @@ require ( golang.org/x/sys v0.31.0 // indirect golang.org/x/text v0.23.0 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20240604185151-ef581f913117 // indirect + google.golang.org/protobuf v1.34.1 // indirect gopkg.in/ini.v1 v1.67.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect ) From 0f9ac06f73aff5bdda6a26ffc38aa1248fef67bf Mon Sep 17 00:00:00 2001 From: wolter <11@gmail> Date: Fri, 5 Dec 2025 15:08:28 +0800 Subject: [PATCH 07/66] feat: Dockerfile1 --- Dockerfile | 6 ++++++ deploy.sh | 2 +- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/Dockerfile b/Dockerfile index d1cf2c1..2c0e527 100644 --- a/Dockerfile +++ b/Dockerfile @@ -17,6 +17,9 @@ COPY . . COPY go.mod go.sum ./ RUN go mod download +RUN go install github.com/google/wire/cmd/wire@latest +RUN wire ./cmd/server + # 编译Go应用程序,生成静态链接的二进制文件 RUN go build -ldflags="-s -w" -o server ./cmd/server @@ -35,6 +38,9 @@ WORKDIR /app # 将编译好的二进制文件从构建阶段复制到运行阶段 COPY --from=builder /app/server ./server +# 复制配置文件夹 +COPY --from=builder /app/config ./config + ENV TZ=Asia/Shanghai # 设置容器启动时运行的命令 diff --git a/deploy.sh b/deploy.sh index ab694f2..f02e162 100644 --- a/deploy.sh +++ b/deploy.sh @@ -33,6 +33,6 @@ docker run -itd \ -e "OLLAMA_BASE_URL=${OLLAMA_BASE_URL:-http://host.docker.internal:11434}" \ -e "MODE=${MODE}" \ -p 8090:8090 \ - "${CONTAINER_NAME}" ./bin/server --config "./${CONFIG_FILE}" + "${CONTAINER_NAME}" ./server --config "./${CONFIG_FILE}" docker logs -f ${CONTAINER_NAME} \ No newline at end of file From 881310c0ccc568e9b671f3c4e9124b1e8fd3f1c3 Mon Sep 17 00:00:00 2001 From: fuzhongyun <15339891972@163.com> Date: Fri, 5 Dec 2025 17:51:41 +0800 Subject: [PATCH 08/66] =?UTF-8?q?feat:=201.=E6=96=B0=E5=A2=9Evllm=E5=AE=A2?= =?UTF-8?q?=E6=88=B7=E7=AB=AF=E5=8F=8A=E7=9B=B8=E5=85=B3=E6=96=B9=E6=B3=95?= =?UTF-8?q?=20=202.=E5=9B=BE=E7=89=87=E8=AF=86=E5=88=AB=E6=9C=8D=E5=8A=A1?= =?UTF-8?q?=E5=88=87=E6=8D=A2=E5=88=B0vllm?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- config/config_env.yaml | 86 +++++++++++++++++++++++++ config/config_test.yaml | 5 ++ go.mod | 28 +++++++- go.sum | 88 ++++++++++++++++++++++++++ internal/biz/llm_service/ollama.go | 42 ++++++++++-- internal/config/config.go | 38 ++++++----- internal/pkg/provider_set.go | 2 + internal/pkg/utils_vllm/client.go | 60 ++++++++++++++++++ internal/pkg/utils_vllm/client_test.go | 66 +++++++++++++++++++ 9 files changed, 391 insertions(+), 24 deletions(-) create mode 100644 config/config_env.yaml create mode 100644 internal/pkg/utils_vllm/client.go create mode 100644 internal/pkg/utils_vllm/client_test.go diff --git a/config/config_env.yaml b/config/config_env.yaml new file mode 100644 index 0000000..b6fa262 --- /dev/null +++ b/config/config_env.yaml @@ -0,0 +1,86 @@ +# 服务器配置 +server: + port: 8090 + host: "0.0.0.0" + +ollama: + base_url: "http://192.168.6.109:11434" + model: "qwen3-coder:480b-cloud" + generate_model: "qwen3-coder:480b-cloud" + vl_model: "qwen2.5vl:7b" + timeout: "120s" + level: "info" + format: "json" + +vllm: + base_url: "http://117.175.169.61:16001/v1" + vl_model: "models/Qwen2.5-VL-3B-Instruct-AWQ" + timeout: "120s" + level: "info" + + +sys: + session_len: 6 + channel_pool_len: 100 + channel_pool_size: 32 + llm_pool_len: 5 +redis: + host: 47.97.27.195:6379 + type: node + pass: lansexiongdi@666 + key: report-api-test + pollSize: 5 #连接池大小,不配置,或配置为0表示不启用连接池 + minIdleConns: 2 #最小空闲连接数 + maxIdleTime: 30 #每个连接最大空闲时间,如果超过了这个时间会被关闭 + tls: 30 + db: +db: + driver: mysql + source: root:SD###sdf323r343@tcp(121.199.38.107:3306)/sys_ai_test?charset=utf8mb4&parseTime=true&loc=Asia%2FShanghai + +tools: + zltxOrderDetail: + enabled: true + base_url: "https://gateway.dev.cdlsxd.cn/zltx_api/admin/direct/ai/%s" + add_url: "https://gateway.dev.cdlsxd.cn/zltx_api/admin/direct/log/%s/%s" + api_key: "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJ1c2VyQ2VudGVyIiwiZXhwIjoxNzU4MDkxOTU4LCJuYmYiOjE3NTgwOTAxNTgsImp0aSI6IjEiLCJQaG9uZSI6IjE4MDAwMDAwMDAwIiwiVXNlck5hbWUiOiJsc3hkIiwiUmVhbE5hbWUiOiLotoXnuqfnrqHnkIblkZgiLCJBY2NvdW50VHlwZSI6MSwiR3JvdXBDb2RlcyI6IlZDTF9DQVNISUVSLFZDTF9PUEVSQVRFLFZDTF9BRE1JTixWQ0xfQUFBLFZDTF9WQ0xfT1BFUkFULFZDTF9JTlZPSUNFLENSTV9BRE1JTixMSUFOTElBTl9BRE1JTixNQVJLRVRNQUcyX0FETUlOLFBIT05FQklMTF9BRE1JTixRSUFOWkhVX1NVUFBFUl9BRE0sTUFSS0VUSU5HU0FBU19TVVBFUkFETUlOLENBUkRfQ09ERSxDQVJEX1BST0NVUkVNRU5ULE1BUktFVElOR1NZU1RFTV9TVVBFUixTVEFUSVNUSUNBTFNZU1RFTV9BRE1JTixaTFRYX0FETUlOLFpMVFhfT1BFUkFURSIsIkRpbmdVc2VySWQiOiIxNjIwMjYxMjMwMjg5MzM4MzQifQ.Bjsx9f8yfcrV9EWxb0n6POwnXVOq9XPRD78JFZnnf1_VAVMN78W4W570SZL27PWuDnkD7E4oUg6RzeZwZgl7BZrNpNr-a-QpNC5qCptqrqXeNfVStmX7pxWA8GqnzI8ybkZgbhQ58Gje7DzdJtBq_8zte_LDaYhTYXdIc5EAG0AbCzAk22nPTl47nkMeHtmisXQVLEsdibl1hW3ViFJlXwfXvUrOENItmL1_mRYkggUB0MaTu2nHJOYM6PaOVGLHx-74eepnmK2rm6konFEb6ed-Ukc6gVR-nM9yWZaYLYNGNKJLwZoCX3tRuerq74n4kzQgWmUEJeaVI1yIGSw1zw" + zltxProduct: + enabled: true + base_url: "https://gateway.dev.cdlsxd.cn/zltx_api/admin/oursProduct" + add_url: "https://gateway.dev.cdlsxd.cn/zltx_api/admin/platformProduct/getProductsByOfficialProductId" + api_key: "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJ1c2VyQ2VudGVyIiwiZXhwIjoxNzU2MTgyNTM1LCJuYmYiOjE3NTYxODA3MzUsImp0aSI6IjEiLCJQaG9uZSI6IjE4MDAwMDAwMDAwIiwiVXNlck5hbWUiOiJsc3hkIiwiUmVhbE5hbWUiOiLotoXnuqfnrqHnkIblkZgiLCJBY2NvdW50VHlwZSI6MSwiR3JvdXBDb2RlcyI6IlZDTF9DQVNISUVSLFZDTF9PUEVSQVRFLFZDTF9BRE1JTixWQ0xfQUFBLFZDTF9WQ0xfT1BFUkFULFZDTF9JTlZPSUNFLENSTV9BRE1JTixMSUFOTElBTl9BRE1JTixNQVJLRVRNQUcyX0FETUlOLFBIT05FQklMTF9BRE1JTixRSUFOWkhVX1NVUFBFUl9BRE0sTUFSS0VUSU5HU0FBU19TVVBFUkFETUlOLENBUkRfQ09ERSxDQVJEX1BST0NVUkVNRU5ULE1BUktFVElOR1NZU1RFTV9TVVBFUixTVEFUSVNUSUNBTFNZU1RFTV9BRE1JTixaTFRYX0FETUlOLFpMVFhfT1BFUkFURSIsIkRpbmdVc2VySWQiOiIxNjIwMjYxMjMwMjg5MzM4MzQifQ.N1xv1PYbcO8_jR5adaczc16YzGsr4z101gwEZdulkRaREBJNYTOnFrvRxTFx3RJTooXsqTqroE1MR84v_1WPX6BS6kKonA-kC1Jgot6yrt5rFWhGNGb2Cpr9rKIFCCQYmiGd3AUgDazEeaQ0_sodv3E-EXg9VfE1SX8nMcck9Yjnc8NCy7RTWaBIaSeOdZcEl-JfCD0S6GSx3oErp_hk-U9FKGwf60wAuDGTY1R0BP4BYpcEqS-C2LSnsSGyURi54Cuk5xH8r1WuF0Dm5bwAj5d7Hvs77-N_sUF-C5ONqyZJRAEhYLgcmN9RX_WQZfizdQJxizlTczdpzYfy-v-1eQ" + zltxOrderStatistics: + base_url: "https://gateway.dev.cdlsxd.cn/zltx_api/admin/direct/ai/search/" + enabled: true + api_key: "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJ1c2VyQ2VudGVyIiwiZXhwIjoxNzU2MTgyNTM1LCJuYmYiOjE3NTYxODA3MzUsImp0aSI6IjEiLCJQaG9uZSI6IjE4MDAwMDAwMDAwIiwiVXNlck5hbWUiOiJsc3hkIiwiUmVhbE5hbWUiOiLotoXnuqfnrqHnkIblkZgiLCJBY2NvdW50VHlwZSI6MSwiR3JvdXBDb2RlcyI6IlZDTF9DQVNISUVSLFZDTF9PUEVSQVRFLFZDTF9BRE1JTixWQ0xfQUFBLFZDTF9WQ0xfT1BFUkFULFZDTF9JTlZPSUNFLENSTV9BRE1JTixMSUFOTElBTl9BRE1JTixNQVJLRVRNQUcyX0FETUlOLFBIT05FQklMTF9BRE1JTixRSUFOWkhVX1NVUFBFUl9BRE0sTUFSS0VUSU5HU0FBU19TVVBFUkFETUlOLENBUkRfQ09ERSxDQVJEX1BST0NVUkVNRU5ULE1BUktFVElOR1NZU1RFTV9TVVBFUixTVEFUSVNUSUNBTFNZU1RFTV9BRE1JTixaTFRYX0FETUlOLFpMVFhfT1BFUkFURSIsIkRpbmdVc2VySWQiOiIxNjIwMjYxMjMwMjg5MzM4MzQifQ.N1xv1PYbcO8_jR5adaczc16YzGsr4z101gwEZdulkRaREBJNYTOnFrvRxTFx3RJTooXsqTqroE1MR84v_1WPX6BS6kKonA-kC1Jgot6yrt5rFWhGNGb2Cpr9rKIFCCQYmiGd3AUgDazEeaQ0_sodv3E-EXg9VfE1SX8nMcck9Yjnc8NCy7RTWaBIaSeOdZcEl-JfCD0S6GSx3oErp_hk-U9FKGwf60wAuDGTY1R0BP4BYpcEqS-C2LSnsSGyURi54Cuk5xH8r1WuF0Dm5bwAj5d7Hvs77-N_sUF-C5ONqyZJRAEhYLgcmN9RX_WQZfizdQJxizlTczdpzYfy-v-1eQ" + knowledge: + base_url: "http://117.175.169.61:10000" + enabled: true + DingTalkBot: + enabled: true + api_key: "dingsbbntrkeiyazcfdg" + api_secret: "ObqxwyR20r9rVNhju0sCPQyQA98_FZSc32W4vgxnGFH_b02HZr1BPCJsOAF816nu" + zltxOrderAfterSaleSupplier: + enabled: true + base_url: "https://gateway.dev.cdlsxd.cn/zltx_api/admin/afterSales/directs" + zltxOrderAfterSaleReseller: + enabled: true + base_url: "https://gateway.dev.cdlsxd.cn/zltx_api/admin/afterSales/reseller_pre_ai" + zltxOrderAfterSaleResellerBatch: + enabled: true + base_url: "https://gateway.dev.cdlsxd.cn/zltx_api/admin/afterSales/reseller_pre_ai" + + + +default_prompt: + img_recognize: + system_prompt: + '你是一个具备图像理解与用户意图分析能力的智能助手。当用户提供一张图片时,请完成以下任务: + 1. 关键信息提取: + 提取出图片中对用户可能有用的关键信息(例如金额、日期、标题、编号、联系信息、商品名称等)。 + 若图片为文档类(如合同、发票、收据),请结构化输出关键字段(如客户名称、金额、开票日期等)。 + ' + user_prompt: '识别图片内容' +# 权限配置 +permissionConfig: + permission_url: "http://api.test.user.1688sup.cn:8001/v1/menu/myCodes?systemId=" diff --git a/config/config_test.yaml b/config/config_test.yaml index 8275102..0958db3 100644 --- a/config/config_test.yaml +++ b/config/config_test.yaml @@ -12,6 +12,11 @@ ollama: level: "info" format: "json" +vllm: + base_url: "http://127.0.0.1:8001/v1" + vl_model: "models/Qwen2.5-VL-3B-Instruct-AWQ" + timeout: "120s" + level: "info" sys: diff --git a/go.mod b/go.mod index 6b08e4a..f49df60 100644 --- a/go.mod +++ b/go.mod @@ -10,6 +10,8 @@ require ( github.com/alibabacloud-go/dingtalk v1.6.96 github.com/alibabacloud-go/tea v1.2.2 github.com/alibabacloud-go/tea-utils/v2 v2.0.6 + github.com/cloudwego/eino v0.7.7 + github.com/cloudwego/eino-ext/components/model/openai v0.1.5 github.com/emirpasic/gods v1.18.1 github.com/faabiosr/cachego v0.26.0 github.com/fastwego/dingding v1.0.0-beta.4 @@ -23,8 +25,6 @@ require ( github.com/spf13/viper v1.17.0 github.com/tmc/langchaingo v0.1.13 google.golang.org/grpc v1.64.0 - google.golang.org/protobuf v1.34.1 - gopkg.in/yaml.v3 v3.0.1 gorm.io/driver/mysql v1.6.0 gorm.io/gorm v1.31.0 xorm.io/builder v0.3.13 @@ -39,31 +39,49 @@ require ( github.com/alibabacloud-go/tea-xml v1.1.3 // indirect github.com/aliyun/credentials-go v1.4.6 // indirect github.com/andybalholm/brotli v1.1.0 // indirect + github.com/bahlo/generic-list-go v0.2.0 // indirect + github.com/buger/jsonparser v1.1.1 // indirect + github.com/bytedance/gopkg v0.1.3 // indirect + github.com/bytedance/sonic v1.14.1 // indirect + github.com/bytedance/sonic/loader v0.3.0 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/clbanning/mxj/v2 v2.5.5 // indirect + github.com/cloudwego/base64x v0.1.6 // indirect + github.com/cloudwego/eino-ext/libs/acl/openai v0.1.2 // indirect github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect github.com/dlclark/regexp2 v1.11.4 // indirect + github.com/dustin/go-humanize v1.0.1 // indirect + github.com/eino-contrib/jsonschema v1.0.3 // indirect + github.com/evanphx/json-patch v0.5.2 // indirect github.com/fasthttp/websocket v1.5.3 // indirect github.com/fsnotify/fsnotify v1.6.0 // indirect github.com/go-sql-driver/mysql v1.8.1 // indirect + github.com/goph/emperror v0.17.2 // indirect github.com/hashicorp/hcl v1.0.0 // indirect github.com/jinzhu/inflection v1.0.0 // indirect github.com/jinzhu/now v1.1.5 // indirect github.com/json-iterator/go v1.1.12 // indirect github.com/klauspost/compress v1.17.9 // indirect + github.com/klauspost/cpuid/v2 v2.2.9 // indirect github.com/magiconair/properties v1.8.7 // indirect + github.com/mailru/easyjson v0.7.7 // indirect github.com/mattn/go-colorable v0.1.13 // indirect github.com/mattn/go-isatty v0.0.20 // indirect github.com/mattn/go-runewidth v0.0.16 // indirect + github.com/meguminnnnnnnnn/go-openai v0.1.0 // indirect github.com/mitchellh/mapstructure v1.5.0 // indirect github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/reflect2 v1.0.2 // indirect + github.com/nikolalohinski/gonja v1.5.3 // indirect github.com/pelletier/go-toml/v2 v2.2.2 // indirect + github.com/pkg/errors v0.9.1 // indirect github.com/pkoukk/tiktoken-go v0.1.6 // indirect github.com/rivo/uniseg v0.2.0 // indirect github.com/sagikazarmark/locafero v0.3.0 // indirect github.com/sagikazarmark/slog-shim v0.1.0 // indirect github.com/savsgio/gotils v0.0.0-20230208104028-c358bd845dee // indirect + github.com/sirupsen/logrus v1.9.3 // indirect + github.com/slongfield/pyfmt v0.0.0-20220222012616-ea85ff4c361f // indirect github.com/sourcegraph/conc v0.3.0 // indirect github.com/spf13/afero v1.10.0 // indirect github.com/spf13/cast v1.5.1 // indirect @@ -71,16 +89,22 @@ require ( github.com/stretchr/testify v1.11.1 // indirect github.com/subosito/gotenv v1.6.0 // indirect github.com/tjfoc/gmsm v1.4.1 // indirect + github.com/twitchyliquid64/golang-asm v0.15.1 // indirect github.com/valyala/bytebufferpool v1.0.0 // indirect github.com/valyala/fasthttp v1.51.0 // indirect github.com/valyala/tcplisten v1.0.0 // indirect + github.com/wk8/go-ordered-map/v2 v2.1.8 // indirect + github.com/yargevad/filepathx v1.0.0 // indirect go.uber.org/atomic v1.9.0 // indirect go.uber.org/multierr v1.9.0 // indirect + golang.org/x/arch v0.11.0 // indirect golang.org/x/crypto v0.36.0 // indirect golang.org/x/exp v0.0.0-20250218142911-aa4b98e5adaa // indirect golang.org/x/net v0.38.0 // indirect golang.org/x/sys v0.31.0 // indirect golang.org/x/text v0.23.0 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20240604185151-ef581f913117 // indirect + google.golang.org/protobuf v1.34.1 // indirect gopkg.in/ini.v1 v1.67.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 29b4ace..5dc58f0 100644 --- a/go.sum +++ b/go.sum @@ -44,6 +44,7 @@ gitea.com/xorm/sqlfiddle v0.0.0-20180821085327-62ce714f951a h1:lSA0F4e9A2NcQSqGq gitea.com/xorm/sqlfiddle v0.0.0-20180821085327-62ce714f951a/go.mod h1:EXuID2Zs0pAQhH8yz+DNjUbjppKQzKFAn28TMYPB6IU= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo= +github.com/airbrake/gobrake v3.6.1+incompatible/go.mod h1:wM4gu3Cn0W0K7GUuVWnlXZU11AGBXMILnrdOU8Kn00o= github.com/alibabacloud-go/alibabacloud-gateway-pop v0.0.6 h1:eIf+iGJxdU4U9ypaUfbtOWCsZSbTb8AUHvyPrxu6mAA= github.com/alibabacloud-go/alibabacloud-gateway-pop v0.0.6/go.mod h1:4EUIoxs/do24zMOGGqYVWgw0s9NtiylnJglOeEB5UJo= github.com/alibabacloud-go/alibabacloud-gateway-spi v0.0.4/go.mod h1:sCavSAvdzOjul4cEqeVtvlSaSScfNsTQ+46HwlTL1hc= @@ -96,12 +97,29 @@ github.com/aliyun/credentials-go v1.4.6 h1:CG8rc/nxCNKfXbZWpWDzI9GjF4Tuu3Es14qT8 github.com/aliyun/credentials-go v1.4.6/go.mod h1:Jm6d+xIgwJVLVWT561vy67ZRP4lPTQxMbEYRuT2Ti1U= github.com/andybalholm/brotli v1.1.0 h1:eLKJA0d02Lf0mVpIDgYnqXcUn0GqVmEFny3VuID1U3M= github.com/andybalholm/brotli v1.1.0/go.mod h1:sms7XGricyQI9K10gOSf56VKKWS4oLer58Q+mhRPtnY= +github.com/bahlo/generic-list-go v0.2.0 h1:5sz/EEAK+ls5wF+NeqDpk5+iNdMDXrh3z3nPnH1Wvgk= +github.com/bahlo/generic-list-go v0.2.0/go.mod h1:2KvAjgMlE5NNynlg/5iLrrCCZ2+5xWbdbCW3pNTGyYg= +github.com/bitly/go-simplejson v0.5.0/go.mod h1:cXHtHw4XUPsvGaxgjIAn8PhEWG9NfngEKAMDJEczWVA= +github.com/bmizerany/assert v0.0.0-20160611221934-b7ed37b82869/go.mod h1:Ekp36dRnpXw/yCqJaO+ZrUyxD+3VXMFFr56k5XYrpB4= github.com/bradfitz/gomemcache v0.0.0-20170208213004-1952afaa557d/go.mod h1:PmM6Mmwb0LSuEubjR8N7PtNe1KxZLtOUHtbeikc5h60= github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs= github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c= github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA= github.com/bsm/gomega v1.27.10/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0H+O0= +github.com/buger/jsonparser v1.1.1 h1:2PnMjfWD7wBILjqQbt530v576A/cAbQvEW9gGIpYMUs= +github.com/buger/jsonparser v1.1.1/go.mod h1:6RYKKt7H4d4+iWqouImQ9R2FZql3VbhNgx27UK13J/0= +github.com/bugsnag/bugsnag-go v1.4.0/go.mod h1:2oa8nejYd4cQ/b0hMIopN0lCRxU0bueqREvZLWFrtK8= +github.com/bugsnag/panicwrap v1.2.0/go.mod h1:D/8v3kj0zr8ZAKg1AQ6crr+5VwKN5eIywRkfhyM/+dE= +github.com/bytedance/gopkg v0.1.3 h1:TPBSwH8RsouGCBcMBktLt1AymVo2TVsBVCY4b6TnZ/M= +github.com/bytedance/gopkg v0.1.3/go.mod h1:576VvJ+eJgyCzdjS+c4+77QF3p7ubbtiKARP3TxducM= +github.com/bytedance/mockey v1.2.14 h1:KZaFgPdiUwW+jOWFieo3Lr7INM1P+6adO3hxZhDswY8= +github.com/bytedance/mockey v1.2.14/go.mod h1:1BPHF9sol5R1ud/+0VEHGQq/+i2lN+GTsr3O2Q9IENY= +github.com/bytedance/sonic v1.14.1 h1:FBMC0zVz5XUmE4z9wF4Jey0An5FueFvOsTKKKtwIl7w= +github.com/bytedance/sonic v1.14.1/go.mod h1:gi6uhQLMbTdeP0muCnrjHLeCUPyb70ujhnNlhOylAFc= +github.com/bytedance/sonic/loader v0.3.0 h1:dskwH8edlzNMctoruo8FPTJDF3vLtDT0sXZwvZJyqeA= +github.com/bytedance/sonic/loader v0.3.0/go.mod h1:N8A3vUdtUebEY2/VQC0MyhYeKUFosQU6FxH2JmUe6VI= github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= +github.com/certifi/gocertifi v0.0.0-20190105021004-abcd57078448/go.mod h1:GJKEexRPVJrBSOjoqN5VNOIKJ5Q3RViH6eu3puDRwx4= github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/chzyer/logex v1.1.10/go.mod h1:+Ywpsq7O8HXn0nuIou7OrIPyXbp3wmkHB+jjWRnGsAI= @@ -110,6 +128,14 @@ github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod h1:Q3SI9o4m/ZMn github.com/clbanning/mxj/v2 v2.5.5 h1:oT81vUeEiQQ/DcHbzSytRngP6Ky9O+L+0Bw0zSJag9E= github.com/clbanning/mxj/v2 v2.5.5/go.mod h1:hNiWqW14h+kc+MdF9C6/YoRfjEJoR3ou6tn/Qo+ve2s= github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= +github.com/cloudwego/base64x v0.1.6 h1:t11wG9AECkCDk5fMSoxmufanudBtJ+/HemLstXDLI2M= +github.com/cloudwego/base64x v0.1.6/go.mod h1:OFcloc187FXDaYHvrNIjxSe8ncn0OOM8gEHfghB2IPU= +github.com/cloudwego/eino v0.7.7 h1:WhP0SMWWPgLdOH03HrKUxtP9/Q96NhziMZNEQl9lxpU= +github.com/cloudwego/eino v0.7.7/go.mod h1:nA8Vacmuqv3pqKBQbTWENBLQ8MmGmPt/WqiyLeB8ohQ= +github.com/cloudwego/eino-ext/components/model/openai v0.1.5 h1:+yvGbTPw93li9GSmdm6Rix88Yy8AXg5NNBcRbWx3CQU= +github.com/cloudwego/eino-ext/components/model/openai v0.1.5/go.mod h1:IPVYMFoZcuHeVEsDTGN6SZjvue0xr1iZFhdpq1SBWdQ= +github.com/cloudwego/eino-ext/libs/acl/openai v0.1.2 h1:r9Id2wzJ05PoHl+Km7jQgNMgciaZI93TVnUYso89esM= +github.com/cloudwego/eino-ext/libs/acl/openai v0.1.2/go.mod h1:S4OkvglPY9hsm9tXeShODrf/WN1Cgu4bqu4nn/CnIic= github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc= github.com/cncf/udpa/go v0.0.0-20200629203442-efcf912fb354/go.mod h1:WmhPx2Nbnhtbo57+VJT5O0JRkEi1Wbu0z5j0R8u5Hbk= github.com/cncf/udpa/go v0.0.0-20201120205902-5459f2c99403/go.mod h1:WmhPx2Nbnhtbo57+VJT5O0JRkEi1Wbu0z5j0R8u5Hbk= @@ -121,6 +147,10 @@ github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/r github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= github.com/dlclark/regexp2 v1.11.4 h1:rPYF9/LECdNymJufQKmri9gV604RvvABwgOA8un7yAo= github.com/dlclark/regexp2 v1.11.4/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8= +github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= +github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= +github.com/eino-contrib/jsonschema v1.0.3 h1:2Kfsm1xlMV0ssY2nuxshS4AwbLFuqmPmzIjLVJ1Fsp0= +github.com/eino-contrib/jsonschema v1.0.3/go.mod h1:cpnX4SyKjWjGC7iN2EbhxaTdLqGjCi0e9DxpLYxddD4= github.com/emirpasic/gods v1.18.1 h1:FXtiHYKDGKCW2KzwZKx0iC0PQmdlorYgdFG9jPXJ1Bc= github.com/emirpasic/gods v1.18.1/go.mod h1:8tpGGwCnJ5H4r6BWwaV6OrWmMoPhUl5jm/FMNAnJvWQ= github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= @@ -129,6 +159,8 @@ github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1m github.com/envoyproxy/go-control-plane v0.9.7/go.mod h1:cwu0lG7PUMfa9snN8LXBig5ynNVH9qI8YYLbd1fK2po= github.com/envoyproxy/go-control-plane v0.9.9-0.20201210154907-fd9021fe5dad/go.mod h1:cXg6YxExXjJnVBQHBLXeUAgxn2UodCpnH306RInaBQk= github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= +github.com/evanphx/json-patch v0.5.2 h1:xVCHIVMUu1wtM/VkR9jVZ45N3FhZfYMMYGorLCR8P3k= +github.com/evanphx/json-patch v0.5.2/go.mod h1:ZWS5hhDbVDyob71nXKNL0+PWn6ToqBHMikGIFbs31qQ= github.com/faabiosr/cachego v0.15.0/go.mod h1:L2EomlU3/rUWjzFavY9Fwm8B4zZmX2X6u8kTMkETrwI= github.com/faabiosr/cachego v0.26.0 h1:EDDv2y9T0XJ4Cx3tUhbKSUayGWxCGkkZUivNLceHRWY= github.com/faabiosr/cachego v0.26.0/go.mod h1:p54WXVzeB1CctH1ix/rjqv1EotNzD0Xoxk2IsR1PQX8= @@ -143,6 +175,9 @@ github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4 github.com/fsnotify/fsnotify v1.6.0 h1:n+5WquG0fcWoWp6xPWfHdbskMCQaFnG6PfBrh1Ky4HY= github.com/fsnotify/fsnotify v1.6.0/go.mod h1:sl3t1tCWJFWoRz9R8WJCbQihKKwmorjAbSClcnxKAGw= github.com/garyburd/redigo v1.6.0/go.mod h1:NR3MbYisc3/PwhQ00EMzDiPmrwpPxAn5GI05/YaO1SY= +github.com/getsentry/raven-go v0.2.0/go.mod h1:KungGk8q33+aIAZUIVWZDr2OfAEBsO49PX4NzFV5kcQ= +github.com/go-check/check v0.0.0-20180628173108-788fd7840127 h1:0gkP6mzaMqkmpcJYCFOLkIBwI7xFExG03bbkOkCvUPI= +github.com/go-check/check v0.0.0-20180628173108-788fd7840127/go.mod h1:9ES+weclKsC9YodN5RgxqK/VD9HM9JsCSh7rNhMZE98= github.com/go-gl/glfw v0.0.0-20190409004039-e6da0acd62b1/go.mod h1:vR7hzQXu2zJy9AVAgeJqvqgH9Q5CA+iKCZ2gyEVpxRU= github.com/go-gl/glfw/v3.3/glfw v0.0.0-20191125211704-12ad95a8df72/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8= github.com/go-gl/glfw/v3.3/glfw v0.0.0-20200222043503-6f7a984d4dc4/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8= @@ -154,6 +189,7 @@ github.com/gofiber/fiber/v2 v2.52.9 h1:YjKl5DOiyP3j0mO61u3NTmK7or8GzzWzCFzkboyP5 github.com/gofiber/fiber/v2 v2.52.9/go.mod h1:YEcBbO/FB+5M1IZNBP9FO3J9281zgPAreiI1oqg8nDw= github.com/gofiber/websocket/v2 v2.2.1 h1:C9cjxvloojayOp9AovmpQrk8VqvVnT8Oao3+IUygH7w= github.com/gofiber/websocket/v2 v2.2.1/go.mod h1:Ao/+nyNnX5u/hIFPuHl28a+NIkrqK7PRimyKaj4JxVU= +github.com/gofrs/uuid v3.2.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= github.com/golang/groupcache v0.0.0-20190702054246-869f871628b6/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= github.com/golang/groupcache v0.0.0-20191227052852-215e87163ea7/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= @@ -215,8 +251,12 @@ github.com/google/wire v0.7.0/go.mod h1:n6YbUQD9cPKTnHXEBN2DXlOp/mVADhVErcMFb0v3 github.com/googleapis/gax-go/v2 v2.0.4/go.mod h1:0Wqv26UfaUD9n4G6kQubkQ+KchISgw+vpHVxEJEs9eg= github.com/googleapis/gax-go/v2 v2.0.5/go.mod h1:DWXyrwAJ9X0FpwwEdw+IPEYBICEFu5mhpdKc/us6bOk= github.com/googleapis/google-cloud-go-testing v0.0.0-20200911160855-bcd43fbb19e8/go.mod h1:dvDLG8qkwmyD9a/MJJN3XJcT3xFxOKAvTZGvuZmac9g= +github.com/goph/emperror v0.17.2 h1:yLapQcmEsO0ipe9p5TaN22djm3OFV/TfM/fcYP0/J18= +github.com/goph/emperror v0.17.2/go.mod h1:+ZbQ+fUNO/6FNiUo0ujtMjhgad9Xa6fQL9KhH4LNHic= github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY= github.com/gopherjs/gopherjs v0.0.0-20200217142428-fce0ec30dd00/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY= +github.com/gopherjs/gopherjs v1.17.2 h1:fQnZVsXk8uxXIStYb0N4bGk7jeyTalG/wsZjQ25dO0g= +github.com/gopherjs/gopherjs v1.17.2/go.mod h1:pRRIvn/QzFLrKfvEz3qUuEhtE/zLCWfreZ6J5gM2i+k= github.com/hashicorp/golang-lru v0.5.0/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= github.com/hashicorp/golang-lru v0.5.1/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= github.com/hashicorp/hcl v1.0.0 h1:0Anlzjpi4vEasTeNFn2mLJgTSwt0+6sfsiTG8qcWGx4= @@ -224,19 +264,26 @@ github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= github.com/ianlancetaylor/demangle v0.0.0-20181102032728-5e5cf60278f6/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc= github.com/ianlancetaylor/demangle v0.0.0-20200824232613-28f6c0f3b639/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc= +github.com/jessevdk/go-flags v1.4.0/go.mod h1:4FA24M0QyGHXBuZZK/XkWh8h0e1EYbRYJSGM75WSRxI= github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ= github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= +github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y= github.com/json-iterator/go v1.1.10/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= github.com/jstemmer/go-junit-report v0.0.0-20190106144839-af01ea7f8024/go.mod h1:6v2b51hI/fHJwM22ozAgKL4VKDeJcHhJFhtBdhmNjmU= github.com/jstemmer/go-junit-report v0.9.1/go.mod h1:Brl9GWCQeLvo8nXZwPNNblvFj/XSXhF0NWZEnDohbsk= +github.com/jtolds/gls v4.20.0+incompatible h1:xdiiI2gbIgH/gLH7ADydsJ1uDOEzR8yvV7C0MuV77Wo= github.com/jtolds/gls v4.20.0+incompatible/go.mod h1:QJZ7F/aHp+rZTRtaJ1ow/lLfFfVYBRgL+9YlvaHOwJU= +github.com/kardianos/osext v0.0.0-20190222173326-2bc1f35cddc0/go.mod h1:1NbS8ALrpOvjt0rHPNLyCIeMtbizbir8U//inJ+zuB8= github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= github.com/klauspost/compress v1.17.9 h1:6KIumPrER1LHsvBVuDa0r5xaG0Es51mhhB9BQB2qeMA= github.com/klauspost/compress v1.17.9/go.mod h1:Di0epgTjJY877eYKx5yC51cX2A2Vl2ibi7bDH9ttBbw= +github.com/klauspost/cpuid/v2 v2.2.9 h1:66ze0taIn2H33fBvCkXuv9BmCwDfafmiIVpKV9kKGuY= +github.com/klauspost/cpuid/v2 v2.2.9/go.mod h1:rqkxqrZ1EhYM9G+hXH7YdowN5R5RGN6NK4QwQ3WMXF8= +github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= github.com/kr/fs v0.1.0/go.mod h1:FFnZGqtBN9Gxj7eW1uZ42v5BccTP0vu6NEaFoC2HwRg= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= @@ -247,6 +294,8 @@ github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/magiconair/properties v1.8.7 h1:IeQXZAiQcpL9mgcAe1Nu6cX9LLw6ExEHKjN0VQdvPDY= github.com/magiconair/properties v1.8.7/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0= +github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0= +github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc= github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= @@ -255,6 +304,10 @@ github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D github.com/mattn/go-runewidth v0.0.16 h1:E5ScNMtiwvlvB5paMFdw9p4kSQzbXFikJ5SQO6TULQc= github.com/mattn/go-runewidth v0.0.16/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w= github.com/mattn/go-sqlite3 v1.6.0/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc= +github.com/meguminnnnnnnnn/go-openai v0.1.0 h1:BGzB1PlS2Epq0mBB2TGLwzMihbR7BANrlMH3w4ZnY88= +github.com/meguminnnnnnnnn/go-openai v0.1.0/go.mod h1:qs96ysDmxhE4BZoU45I43zcyfnaYxU3X+aRzLko/htY= +github.com/mgutz/ansi v0.0.0-20170206155736-9520e82c474b h1:j7+1HpAFS1zy5+Q4qx1fWh90gTKwiN4QCGoY9TWyyO4= +github.com/mgutz/ansi v0.0.0-20170206155736-9520e82c474b/go.mod h1:01TrycV0kFyexm33Z7vhZRXopbI8J3TDReVlkTgMUxE= github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY= github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= @@ -265,16 +318,22 @@ github.com/modern-go/reflect2 v1.0.1/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3Rllmb github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M= github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno= +github.com/nikolalohinski/gonja v1.5.3 h1:GsA+EEaZDZPGJ8JtpeGN78jidhOlxeJROpqMT9fTj9c= +github.com/nikolalohinski/gonja v1.5.3/go.mod h1:RmjwxNiXAEqcq1HeK5SSMmqFJvKOfTfXhkJv6YBtPa4= github.com/nxadm/tail v1.4.4/go.mod h1:kenIhsEOeOJmVchQTgglprH7qJGnHDVpk1VPCcaMI8A= github.com/ollama/ollama v0.12.7 h1:dxokli1UyO/a0Aun5sE4+0Gg+A9oMUAPiFQhxrXOIXA= github.com/ollama/ollama v0.12.7/go.mod h1:9+1//yWPsDE2u+l1a5mpaKrYw4VdnSsRU3ioq5BvMms= github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= +github.com/onsi/ginkgo v1.8.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= github.com/onsi/ginkgo v1.12.1/go.mod h1:zj2OWP4+oCPe1qIXoGWkgMRwljMUYCdkwsT2108oapk= github.com/onsi/ginkgo v1.13.0/go.mod h1:+REjRxOmWfHCjfv9TTWB1jD1Frx4XydAD3zm1lskyM0= +github.com/onsi/gomega v1.5.0/go.mod h1:ex+gbHU/CVuBBDIJjb2X0qEXbFg53c61hWP/1CpauHY= github.com/onsi/gomega v1.7.1/go.mod h1:XdKZgCCFLUoM/7CFJVPcG8C1xQ1AJ0vpAezJrB7JYyY= github.com/onsi/gomega v1.10.1/go.mod h1:iN09h71vgCQne3DLsj+A5owkum+a2tYe+TOCB1ybHNo= github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM= github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs= +github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/sftp v1.13.1/go.mod h1:3HaPG6Dq1ILlpPZRO0HVMrsydcdLt6HRDccSgb87qRg= github.com/pkoukk/tiktoken-go v0.1.6 h1:JF0TlJzhTbrI30wCvFuiw6FzP2+/bR+FIxUdgEAcUsw= @@ -290,6 +349,7 @@ github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJ github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= github.com/rogpeppe/go-internal v1.11.0 h1:cWPaGQEPrBb5/AsnsZesgZZ9yb1OQ+GOISoDNXVBh4M= github.com/rogpeppe/go-internal v1.11.0/go.mod h1:ddIwULY96R17DhadqLgMfk9H9tvdUzkipdSkR5nkCZA= +github.com/rollbar/rollbar-go v1.0.2/go.mod h1:AcFs5f0I+c71bpHlXNNDbOWJiKwjFDtISeXco0L5PKQ= github.com/sagikazarmark/locafero v0.3.0 h1:zT7VEGWC2DTflmccN/5T1etyKvxSxpHsjb9cJvm4SvQ= github.com/sagikazarmark/locafero v0.3.0/go.mod h1:w+v7UsPNFwzF1cHuOajOOzoq4U7v/ig1mpRjqV+Bu1U= github.com/sagikazarmark/slog-shim v0.1.0 h1:diDBnUNK9N/354PgrxMywXnAwEr1QZcOr6gto+ugjYE= @@ -297,9 +357,18 @@ github.com/sagikazarmark/slog-shim v0.1.0/go.mod h1:SrcSrq8aKtyuqEI1uvTDTK1arOWR github.com/savsgio/gotils v0.0.0-20230208104028-c358bd845dee h1:8Iv5m6xEo1NR1AvpV+7XmhI4r39LGNzwUL4YpMuL5vk= github.com/savsgio/gotils v0.0.0-20230208104028-c358bd845dee/go.mod h1:qwtSXrKuJh/zsFQ12yEE89xfCrGKK63Rr7ctU/uCo4g= github.com/sclevine/agouti v3.0.0+incompatible/go.mod h1:b4WX9W9L1sfQKXeJf1mUTLZKJ48R1S7H23Ji7oFO5Bw= +github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo= +github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= +github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= +github.com/slongfield/pyfmt v0.0.0-20220222012616-ea85ff4c361f h1:Z2cODYsUxQPofhpYRMQVwWz4yUVpHF+vPi+eUdruUYI= +github.com/slongfield/pyfmt v0.0.0-20220222012616-ea85ff4c361f/go.mod h1:JqzWyvTuI2X4+9wOHmKSQCYxybB/8j6Ko43qVmXDuZg= +github.com/smarty/assertions v1.15.0 h1:cR//PqUBUiQRakZWqBiFFQ9wb8emQGDb0HeGdqGByCY= +github.com/smarty/assertions v1.15.0/go.mod h1:yABtdzeQs6l1brC900WlRNwj6ZR55d7B+E8C6HtKdec= github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d/go.mod h1:OnSkiWE9lh6wB0YB77sQom3nweQdgAjqCqsofrRNTgc= github.com/smartystreets/assertions v1.1.0/go.mod h1:tcbTF8ujkAEcZ8TElKY+i30BzYlVhC/LOxJk7iOWnoo= github.com/smartystreets/goconvey v1.6.4/go.mod h1:syvi0/a8iFYH4r/RixwvyeAJjdLS9QV7WQ/tjFTllLA= +github.com/smartystreets/goconvey v1.8.1 h1:qGjIddxOk4grTu9JPOU31tVfq3cNdBlNa5sSznIX1xY= +github.com/smartystreets/goconvey v1.8.1/go.mod h1:+/u4qLyY6x1jReYOp7GOM2FSt8aP9CzCZL03bI28W60= github.com/sourcegraph/conc v0.3.0 h1:OQTbbt6P72L20UqAkXXuLOj79LfEanQ+YQFNpLA9ySo= github.com/sourcegraph/conc v0.3.0/go.mod h1:Sdozi7LEKbFPqYX2/J+iBAM6HpqSLTASQIKqDmF7Mt0= github.com/spf13/afero v1.10.0 h1:EaGW2JJh15aKOejeuJ+wpFSHnbd7GE6Wvp3TsNhb6LY= @@ -311,16 +380,19 @@ github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An github.com/spf13/viper v1.17.0 h1:I5txKw7MJasPL/BrfkbA0Jyo/oELqVmux4pR/UxOMfI= github.com/spf13/viper v1.17.0/go.mod h1:BmMMMLQXSbcHK6KAOiFLz0l5JHrU89OdIRHvsk0+yVI= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.2.0/go.mod h1:qt09Ya8vawLte6SNmTgCsAVtYtaKzEcn8ATUoHMkEqE= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= +github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= @@ -332,12 +404,20 @@ github.com/tjfoc/gmsm v1.4.1 h1:aMe1GlZb+0bLjn+cKTPEvvn9oUEBlJitaZiiBwsbgho= github.com/tjfoc/gmsm v1.4.1/go.mod h1:j4INPkHWMrhJb38G+J6W4Tw0AbuN8Thu3PbdVYhVcTE= github.com/tmc/langchaingo v0.1.13 h1:rcpMWBIi2y3B90XxfE4Ao8dhCQPVDMaNPnN5cGB1CaA= github.com/tmc/langchaingo v0.1.13/go.mod h1:vpQ5NOIhpzxDfTZK9B6tf2GM/MoaHewPWM5KXXGh7hg= +github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI= +github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08= github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= github.com/valyala/fasthttp v1.51.0 h1:8b30A5JlZ6C7AS81RsWjYMQmrZG6feChmgAolCl1SqA= github.com/valyala/fasthttp v1.51.0/go.mod h1:oI2XroL+lI7vdXyYoQk03bXBThfFl2cVdIA3Xl7cH8g= github.com/valyala/tcplisten v1.0.0 h1:rBHj/Xf+E1tRGZyWIWwJDiRY0zc1Js+CV5DqwacVSA8= github.com/valyala/tcplisten v1.0.0/go.mod h1:T0xQ8SeCZGxckz9qRXTfG43PvQ/mcWh7FwZEA7Ioqkc= +github.com/wk8/go-ordered-map/v2 v2.1.8 h1:5h/BUHu93oj4gIdvHHHGsScSTMijfx5PeYkE/fJgbpc= +github.com/wk8/go-ordered-map/v2 v2.1.8/go.mod h1:5nJHM5DyteebpVlHnWMV0rPz6Zp7+xBAnxjb1X5vnTw= +github.com/x-cray/logrus-prefixed-formatter v0.5.2 h1:00txxvfBM9muc0jiLIEAkAcIMJzfthRT6usrui8uGmg= +github.com/x-cray/logrus-prefixed-formatter v0.5.2/go.mod h1:2duySbKsL6M18s5GU7VPsoEPHyzalCE06qoARUCeBBE= +github.com/yargevad/filepathx v1.0.0 h1:SYcT+N3tYGi+NvazubCNlvgIPbzAk7i7y2dwg3I5FYc= +github.com/yargevad/filepathx v1.0.0/go.mod h1:BprfX/gpYNJHJfc35GjRRpVcwWXS89gGulUIU5tK3tA= github.com/yuin/goldmark v1.1.25/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.1.30/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= @@ -353,8 +433,13 @@ go.opencensus.io v0.22.4/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw= go.opencensus.io v0.22.5/go.mod h1:5pWMHQbX5EPX2/62yrJeAkowc+lfs/XD7Uxpq3pI6kk= go.uber.org/atomic v1.9.0 h1:ECmE8Bn/WFTYwEW/bpKD3M8VtR/zQVbavAoalC1PYyE= go.uber.org/atomic v1.9.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= +go.uber.org/mock v0.4.0 h1:VcM4ZOtdbR4f6VXfiOpwpVJDL6lCReaZ6mw31wqh7KU= +go.uber.org/mock v0.4.0/go.mod h1:a6FSlNadKUHUa9IP5Vyt1zh4fC7uAwxMutEAscFbkZc= go.uber.org/multierr v1.9.0 h1:7fIwc/ZtS0q++VgcfqFDxSBZVv/Xo49/SYnDFupUwlI= go.uber.org/multierr v1.9.0/go.mod h1:X2jQV1h+kxSjClGpnseKVIxpmcjrj7MNnI0bnlfKTVQ= +golang.org/x/arch v0.11.0 h1:KXV8WWKCXm6tRpLirl2szsO5j/oOODwZf4hATmGVNs4= +golang.org/x/arch v0.11.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys= +golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20190605123033-f99c8df09eb5/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= @@ -478,6 +563,7 @@ golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.12.0 h1:MHc5BpPuC30uJk597Ri8TV3CNZcTLu6B6z4lJy+g6Jw= golang.org/x/sync v0.12.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= @@ -519,6 +605,7 @@ golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20210423185535-09eb48e85fd7/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220908164124-27713097b956/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= @@ -721,6 +808,7 @@ gopkg.in/ini.v1 v1.67.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k= gopkg.in/mgo.v2 v2.0.0-20160818020120-3f83fa500528/go.mod h1:yeKp02qBN3iKW1OzL3MGk2IdtZzaj7SFntXj72NppTA= gopkg.in/redis.v4 v4.2.4/go.mod h1:8KREHdypkCEojGKQcjMqAODMICIVwZAONWq8RowTITA= gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw= +gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.4/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= diff --git a/internal/biz/llm_service/ollama.go b/internal/biz/llm_service/ollama.go index 60d9c78..f850367 100644 --- a/internal/biz/llm_service/ollama.go +++ b/internal/biz/llm_service/ollama.go @@ -7,6 +7,7 @@ import ( "ai_scheduler/internal/entitys" "ai_scheduler/internal/pkg" "ai_scheduler/internal/pkg/utils_ollama" + "ai_scheduler/internal/pkg/utils_vllm" "context" "encoding/json" "errors" @@ -18,20 +19,23 @@ import ( ) type OllamaService struct { - client *utils_ollama.Client - config *config.Config - chatHis *impl.ChatImpl + client *utils_ollama.Client + vllmClient *utils_vllm.Client + config *config.Config + chatHis *impl.ChatImpl } func NewOllamaGenerate( client *utils_ollama.Client, + vllmClient *utils_vllm.Client, config *config.Config, chatHis *impl.ChatImpl, ) *OllamaService { return &OllamaService{ - client: client, - config: config, - chatHis: chatHis, + client: client, + vllmClient: vllmClient, + config: config, + chatHis: chatHis, } } @@ -103,7 +107,8 @@ func (r *OllamaService) getUserContent(ctx context.Context, requireData *entitys } if len(requireData.ImgByte) > 0 { - desc, err := r.RecognizeWithImg(ctx, requireData) + // desc, err := r.RecognizeWithImg(ctx, requireData) + desc, err := r.RecognizeWithImgVllm(ctx, requireData) if err != nil { return "", err } @@ -147,6 +152,29 @@ func (r *OllamaService) RecognizeWithImg(ctx context.Context, requireData *entit return } +func (r *OllamaService) RecognizeWithImgVllm(ctx context.Context, requireData *entitys.RequireData) (desc api.GenerateResponse, err error) { + if requireData.ImgByte == nil { + return + } + entitys.ResLog(requireData.Ch, "recognize_img_start", "图片识别中...") + + outMsg, err := r.vllmClient.RecognizeWithImg(ctx, + r.config.DefaultPrompt.ImgRecognize.SystemPrompt, + r.config.DefaultPrompt.ImgRecognize.UserPrompt, + requireData.ImgUrls, + ) + if err != nil { + return api.GenerateResponse{}, err + } + + desc = api.GenerateResponse{ + Response: outMsg.Content, + } + + entitys.ResLog(requireData.Ch, "recognize_img_end", "图片识别完成,识别内容:"+desc.Response) + return +} + func (r *OllamaService) registerToolsOllama(tasks []model.AiTask) []api.Tool { taskPrompt := make([]api.Tool, 0) for _, task := range tasks { diff --git a/internal/config/config.go b/internal/config/config.go index 5c9fa21..27738bc 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -9,16 +9,17 @@ import ( // Config 应用配置 type Config struct { - Server ServerConfig `mapstructure:"server"` - Ollama OllamaConfig `mapstructure:"ollama"` - Sys SysConfig `mapstructure:"sys"` - Tools ToolsConfig `mapstructure:"tools"` - Logging LoggingConfig `mapstructure:"logging"` - Redis Redis `mapstructure:"redis"` - DB DB `mapstructure:"db"` - DefaultPrompt SysPrompt `mapstructure:"default_prompt"` - PermissionConfig PermissionConfig `mapstructure:"permissionConfig"` - // LLM *LLM `mapstructure:"llm"` + Server ServerConfig `mapstructure:"server"` + Ollama OllamaConfig `mapstructure:"ollama"` + Vllm VllmConfig `mapstructure:"vllm"` + Sys SysConfig `mapstructure:"sys"` + Tools ToolsConfig `mapstructure:"tools"` + Logging LoggingConfig `mapstructure:"logging"` + Redis Redis `mapstructure:"redis"` + DB DB `mapstructure:"db"` + DefaultPrompt SysPrompt `mapstructure:"default_prompt"` + PermissionConfig PermissionConfig `mapstructure:"permissionConfig"` + // LLM *LLM `mapstructure:"llm"` } type SysPrompt struct { @@ -50,11 +51,18 @@ type ServerConfig struct { // OllamaConfig Ollama配置 type OllamaConfig struct { - BaseURL string `mapstructure:"base_url"` - Model string `mapstructure:"model"` - GenerateModel string `mapstructure:"generate_model"` - VlModel string `mapstructure:"vl_model"` - Timeout time.Duration `mapstructure:"timeout"` + BaseURL string `mapstructure:"base_url"` + Model string `mapstructure:"model"` + GenerateModel string `mapstructure:"generate_model"` + VlModel string `mapstructure:"vl_model"` + Timeout time.Duration `mapstructure:"timeout"` +} + +type VllmConfig struct { + BaseURL string `mapstructure:"base_url"` + VlModel string `mapstructure:"vl_model"` + Timeout time.Duration `mapstructure:"timeout"` + Level string `mapstructure:"level"` } type Redis struct { diff --git a/internal/pkg/provider_set.go b/internal/pkg/provider_set.go index 93d9180..f8fadac 100644 --- a/internal/pkg/provider_set.go +++ b/internal/pkg/provider_set.go @@ -4,6 +4,7 @@ import ( "ai_scheduler/internal/pkg/dingtalk" "ai_scheduler/internal/pkg/utils_langchain" "ai_scheduler/internal/pkg/utils_ollama" + "ai_scheduler/internal/pkg/utils_vllm" "github.com/google/wire" ) @@ -13,6 +14,7 @@ var ProviderSetClient = wire.NewSet( NewGormDb, utils_langchain.NewUtilLangChain, utils_ollama.NewClient, + utils_vllm.NewClient, NewSafeChannelPool, dingtalk.NewOldClient, dingtalk.NewContactClient, diff --git a/internal/pkg/utils_vllm/client.go b/internal/pkg/utils_vllm/client.go new file mode 100644 index 0000000..c333350 --- /dev/null +++ b/internal/pkg/utils_vllm/client.go @@ -0,0 +1,60 @@ +package utils_vllm + +import ( + "ai_scheduler/internal/config" + "context" + + "github.com/cloudwego/eino-ext/components/model/openai" + "github.com/cloudwego/eino/schema" +) + +type Client struct { + model *openai.ChatModel + config *config.Config +} + +func NewClient(config *config.Config) (*Client, func(), error) { + m, err := openai.NewChatModel(context.Background(), &openai.ChatModelConfig{ + BaseURL: config.Vllm.BaseURL, + Model: config.Vllm.VlModel, + Timeout: config.Vllm.Timeout, + }) + if err != nil { + return nil, nil, err + } + c := &Client{model: m, config: config} + cleanup := func() {} + return c, cleanup, nil +} + +func (c *Client) Chat(ctx context.Context, msgs []*schema.Message) (*schema.Message, error) { + return c.model.Generate(ctx, msgs) +} + +func (c *Client) RecognizeWithImg(ctx context.Context, systemPrompt, userPrompt string, imgURLs []string) (*schema.Message, error) { + in := []*schema.Message{ + { + Role: schema.System, + Content: systemPrompt, + }, + { + Role: schema.User, + }, + } + parts := []schema.MessageInputPart{ + {Type: schema.ChatMessagePartTypeText, Text: userPrompt}, + } + for i := range imgURLs { + u := imgURLs[i] + parts = append(parts, schema.MessageInputPart{ + Type: schema.ChatMessagePartTypeImageURL, + Image: &schema.MessageInputImage{ + MessagePartCommon: schema.MessagePartCommon{URL: &u}, + Detail: schema.ImageURLDetailHigh, + }, + }) + } + + in[1].UserInputMultiContent = parts + return c.model.Generate(ctx, in) +} diff --git a/internal/pkg/utils_vllm/client_test.go b/internal/pkg/utils_vllm/client_test.go new file mode 100644 index 0000000..e71390d --- /dev/null +++ b/internal/pkg/utils_vllm/client_test.go @@ -0,0 +1,66 @@ +package utils_vllm + +import ( + "ai_scheduler/internal/config" + "context" + "net/http" + "net/http/httptest" + "testing" + + "github.com/cloudwego/eino/schema" +) + +func newMockServer() *httptest.Server { + h := http.NewServeMux() + h.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"id":"cmpl-1","object":"chat.completion","created":173,"model":"x","choices":[{"index":0,"message":{"role":"assistant","content":"ok"},"finish_reason":"stop"}],"usage":{"prompt_tokens":1,"completion_tokens":1,"total_tokens":2}}`)) + }) + return httptest.NewServer(h) +} + +func Test_Vllm_Chat_Generate(t *testing.T) { + cfg, err := config.LoadConfig("../../../config/config_test.yaml") + if err != nil { + t.Fatalf("load config: %v", err) + } + + ctx := context.Background() + client, _, err := NewClient(cfg) + if err != nil { + t.Fatalf("new client: %v", err) + } + + msgs := []*schema.Message{{Role: schema.User, Content: "hi"}} + out, err := client.Chat(ctx, msgs) + if err != nil { + t.Fatalf("chat generate: %v", err) + } + if out == nil || out.Content != "ok" { + t.Fatalf("unexpected content: %v", out) + } + + t.Logf("结果: %v", out) +} + +func Test_Vllm_RecognizeWithImg(t *testing.T) { + cfg, err := config.LoadConfig("../../../config/config_test.yaml") + if err != nil { + t.Fatalf("load config: %v", err) + } + + ctx := context.Background() + client, _, err := NewClient(cfg) + if err != nil { + t.Fatalf("new client: %v", err) + } + + out, err := client.RecognizeWithImg(ctx, "sys", "user", []string{"https://img0.baidu.com/it/u=910428455,194434251&fm=253&app=138&f=JPEG?w=1122&h=800"}) + if err != nil { + t.Fatalf("recognize with img: %v", err) + } + if out == nil || out.Content != "ok" { + t.Fatalf("unexpected content: %v", out) + } +} From 057cf707d30f062a010111866d39dadfce71b33d Mon Sep 17 00:00:00 2001 From: wolter <11@gmail> Date: Fri, 5 Dec 2025 17:59:27 +0800 Subject: [PATCH 09/66] =?UTF-8?q?feat:=20=E4=BC=9A=E8=AF=9D=E5=8E=86?= =?UTF-8?q?=E5=8F=B2?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- internal/biz/chat_history.go | 50 ++++++++++++++++---------- internal/biz/do/ctx.go | 5 +-- internal/biz/do/handle.go | 1 + internal/biz/llm_service/ollama.go | 4 +-- internal/biz/session.go | 14 ++++---- internal/data/impl/chat_history.go | 12 +++---- internal/data/impl/provider_set.go | 2 +- internal/data/model/ai_chat_his.gen.go | 1 + internal/entitys/chat_history.go | 6 ++++ internal/entitys/types.go | 1 + internal/server/http.go | 29 +++++++-------- internal/server/router/router.go | 16 ++++++--- internal/services/chat.go | 4 +-- internal/services/chat_history.go | 44 +++++++++++++++++++++++ internal/services/provider_set.go | 2 +- 15 files changed, 133 insertions(+), 58 deletions(-) create mode 100644 internal/services/chat_history.go diff --git a/internal/biz/chat_history.go b/internal/biz/chat_history.go index 48dcd0d..eeb74e5 100644 --- a/internal/biz/chat_history.go +++ b/internal/biz/chat_history.go @@ -10,31 +10,45 @@ import ( ) type ChatHistoryBiz struct { - chatRepo *impl.ChatImpl + chatHiRepo *impl.ChatHisImpl } -func NewChatHistoryBiz(chatRepo *impl.ChatImpl) *ChatHistoryBiz { +func NewChatHistoryBiz(chatHiRepo *impl.ChatHisImpl) *ChatHistoryBiz { s := &ChatHistoryBiz{ - chatRepo: chatRepo, + chatHiRepo: chatHiRepo, } //go s.AsyncProcess(context.Background()) return s } -//func (s *ChatHistoryBiz) create(ctx context.Context, sessionID, role, content string) error { -// chat := model.AiChatHi{ -// SessionID: sessionID, -// Role: role, -// Content: content, -// } -// -// return s.chatRepo.Create(&chat) -//} -// -//// 添加会话历史 -//func (s *ChatHistoryBiz) Create(ctx context.Context, chat entitys.ChatHistory) error { -// return s.create(ctx, chat.SessionID, chat.Role.String(), chat.Content) -//} +// 查询会话历史 +func (s *ChatHistoryBiz) List(ctx context.Context, query *entitys.ChatHistQuery) ([]model.AiChatHi, error) { + chats, err := s.chatHiRepo.FindAll( + s.chatHiRepo.WithSessionId(query.SessionID), + s.chatHiRepo.PaginateScope(query.Page, query.PageSize), + ) + if err != nil { + return nil, err + } + return chats, nil +} + +// 添加会话历史 +func (s *ChatHistoryBiz) Create(ctx context.Context, chat entitys.ChatHistory) error { + return s.chatHiRepo.Create(&model.AiChatHi{ + SessionID: chat.SessionID, + Ques: chat.Role.String(), + Ans: chat.Content, + }) +} + +// 更新会话历史内容 +func (s *ChatHistoryBiz) UpdateContent(ctx context.Context, chat *entitys.UseFulRequest) error { + cond := builder.NewCond() + cond = cond.And(builder.Eq{"his_id": chat.HisId}) + + return s.chatHiRepo.UpdateByCond(&cond, &model.AiChatHi{HisID: chat.HisId, Useful: chat.Useful}) +} // 异步添加会话历史 //func (s *ChatHistoryBiz) AsyncCreate(ctx context.Context, chat entitys.ChatHistory) { @@ -53,5 +67,5 @@ func NewChatHistoryBiz(chatRepo *impl.ChatImpl) *ChatHistoryBiz { func (s *ChatHistoryBiz) Update(ctx context.Context, chat *entitys.UseFulRequest) error { cond := builder.NewCond() cond = cond.And(builder.Eq{"his_id": chat.HisId}) - return s.chatRepo.UpdateByCond(&cond, &model.AiChatHi{HisID: chat.HisId, Useful: chat.Useful}) + return s.chatHiRepo.UpdateByCond(&cond, &model.AiChatHi{HisID: chat.HisId, Useful: chat.Useful}) } diff --git a/internal/biz/do/ctx.go b/internal/biz/do/ctx.go index eb2c5a5..43f1e1a 100644 --- a/internal/biz/do/ctx.go +++ b/internal/biz/do/ctx.go @@ -27,14 +27,14 @@ type Do struct { sessionImpl *impl.SessionImpl sysImpl *impl.SysImpl taskImpl *impl.TaskImpl - hisImpl *impl.ChatImpl + hisImpl *impl.ChatHisImpl conf *config.Config } func NewDo( sysImpl *impl.SysImpl, taskImpl *impl.TaskImpl, - hisImpl *impl.ChatImpl, + hisImpl *impl.ChatHisImpl, conf *config.Config, ) *Do { return &Do{ @@ -252,6 +252,7 @@ func (d *Do) startMessageHandler( Ques: requireData.Req.Text, Ans: strings.Join(chat, ""), Files: requireData.Req.Img, + TaskID: requireData.Task.TaskID, } d.hisImpl.AddWithData(AiRes) hisLog.HisId = AiRes.HisID diff --git a/internal/biz/do/handle.go b/internal/biz/do/handle.go index e014d25..14695e6 100644 --- a/internal/biz/do/handle.go +++ b/internal/biz/do/handle.go @@ -85,6 +85,7 @@ func (r *Handle) HandleMatch(ctx context.Context, client *gateway.Client, requir for _, task := range requireData.Tasks { if task.Index == requireData.Match.Index { pointTask = &task + requireData.Task = task break } } diff --git a/internal/biz/llm_service/ollama.go b/internal/biz/llm_service/ollama.go index 6527a3b..56887a2 100644 --- a/internal/biz/llm_service/ollama.go +++ b/internal/biz/llm_service/ollama.go @@ -20,13 +20,13 @@ import ( type OllamaService struct { client *utils_ollama.Client config *config.Config - chatHis *impl.ChatImpl + chatHis *impl.ChatHisImpl } func NewOllamaGenerate( client *utils_ollama.Client, config *config.Config, - chatHis *impl.ChatImpl, + chatHis *impl.ChatHisImpl, ) *OllamaService { return &OllamaService{ client: client, diff --git a/internal/biz/session.go b/internal/biz/session.go index d014148..a42e9c0 100644 --- a/internal/biz/session.go +++ b/internal/biz/session.go @@ -17,16 +17,16 @@ import ( type SessionBiz struct { sessionRepo *impl.SessionImpl sysRepo *impl.SysImpl - chatRepo *impl.ChatImpl + chatHisRepo *impl.ChatHisImpl conf *config.Config } -func NewSessionBiz(conf *config.Config, sessionImpl *impl.SessionImpl, sysImpl *impl.SysImpl, chatImpl *impl.ChatImpl) *SessionBiz { +func NewSessionBiz(conf *config.Config, sessionImpl *impl.SessionImpl, sysImpl *impl.SysImpl, chatImpl *impl.ChatHisImpl) *SessionBiz { return &SessionBiz{ sessionRepo: sessionImpl, sysRepo: sysImpl, - chatRepo: chatImpl, + chatHisRepo: chatImpl, conf: conf, } } @@ -91,10 +91,10 @@ func (s *SessionBiz) SessionInit(ctx context.Context, req *entitys.SessionInitRe result.Prologue = sysConfig.Prologue // 存在,返回会话历史 var chatList []model.AiChatHi - chatList, err = s.chatRepo.FindAll( - s.chatRepo.WithSessionId(session.SessionID), // 条件:会话ID - s.chatRepo.OrderByDesc("create_at"), // 排序:按创建时间降序 - s.chatRepo.WithLimit(constants.ChatHistoryLimit), // 限制返回条数 + chatList, err = s.chatHisRepo.FindAll( + s.chatHisRepo.WithSessionId(session.SessionID), // 条件:会话ID + s.chatHisRepo.OrderByDesc("create_at"), // 排序:按创建时间降序 + s.chatHisRepo.WithLimit(constants.ChatHistoryLimit), // 限制返回条数 ) if err != nil { return diff --git a/internal/data/impl/chat_history.go b/internal/data/impl/chat_history.go index 6f6027d..f1c0b42 100644 --- a/internal/data/impl/chat_history.go +++ b/internal/data/impl/chat_history.go @@ -11,14 +11,14 @@ import ( "gorm.io/gorm" ) -type ChatImpl struct { +type ChatHisImpl struct { dataTemp.DataTemp BaseRepository[model.AiChatHi] chatChannel chan model.AiChatHi } -func NewChatImpl(db *utils.Db) *ChatImpl { - return &ChatImpl{ +func NewChatHisImpl(db *utils.Db) *ChatHisImpl { + return &ChatHisImpl{ DataTemp: *dataTemp.NewDataTemp(db, new(model.AiChatHi)), BaseRepository: NewBaseModel[model.AiChatHi](db.Client), chatChannel: make(chan model.AiChatHi, 100), @@ -26,19 +26,19 @@ func NewChatImpl(db *utils.Db) *ChatImpl { } // WithSessionId 条件:会话ID -func (impl *ChatImpl) WithSessionId(sessionId interface{}) CondFunc { +func (impl *ChatHisImpl) WithSessionId(sessionId interface{}) CondFunc { return func(db *gorm.DB) *gorm.DB { return db.Where("session_id = ?", sessionId) } } // 异步添加会话历史 -func (impl *ChatImpl) AsyncCreate(ctx context.Context, chat model.AiChatHi) { +func (impl *ChatHisImpl) AsyncCreate(ctx context.Context, chat model.AiChatHi) { impl.chatChannel <- chat } // 异步处理会话历史 -func (impl *ChatImpl) AsyncProcess(ctx context.Context) { +func (impl *ChatHisImpl) AsyncProcess(ctx context.Context) { for { select { case chat := <-impl.chatChannel: diff --git a/internal/data/impl/provider_set.go b/internal/data/impl/provider_set.go index f970e84..1284f11 100644 --- a/internal/data/impl/provider_set.go +++ b/internal/data/impl/provider_set.go @@ -4,4 +4,4 @@ import ( "github.com/google/wire" ) -var ProviderImpl = wire.NewSet(NewSessionImpl, NewSysImpl, NewTaskImpl, NewChatImpl) +var ProviderImpl = wire.NewSet(NewSessionImpl, NewSysImpl, NewTaskImpl, NewChatHisImpl) diff --git a/internal/data/model/ai_chat_his.gen.go b/internal/data/model/ai_chat_his.gen.go index d143972..e3b5b58 100644 --- a/internal/data/model/ai_chat_his.gen.go +++ b/internal/data/model/ai_chat_his.gen.go @@ -20,6 +20,7 @@ type AiChatHi struct { Useful int32 `gorm:"column:useful;not null;comment:0不评价,1有用,其他为无用" json:"useful"` // 0不评价,1有用,其他为无用 CreateAt time.Time `gorm:"column:create_at;default:CURRENT_TIMESTAMP" json:"create_at"` UpdatedAt time.Time `gorm:"column:updated_at;default:CURRENT_TIMESTAMP" json:"updated_at"` + TaskID int32 `gorm:"column:task_id;not null" json:"task_id"` // 任务ID } // TableName AiChatHi's table name diff --git a/internal/entitys/chat_history.go b/internal/entitys/chat_history.go index a26fa46..34e6c5e 100644 --- a/internal/entitys/chat_history.go +++ b/internal/entitys/chat_history.go @@ -14,3 +14,9 @@ type ChatHistory struct { type ChatHisLog struct { HisId int64 `json:"his_id"` } + +type ChatHistQuery struct { + SessionID string `json:"session_id"` + Page int `json:"page"` + PageSize int `json:"page_size"` +} diff --git a/internal/entitys/types.go b/internal/entitys/types.go index ece5ceb..a9dc72e 100644 --- a/internal/entitys/types.go +++ b/internal/entitys/types.go @@ -150,6 +150,7 @@ type RequireData struct { Histories []model.AiChatHi SessionInfo model.AiSession Tasks []model.AiTask + Task model.AiTask Match *Match Req *ChatSockRequest Auth string diff --git a/internal/server/http.go b/internal/server/http.go index 4cdc393..fd7e49e 100644 --- a/internal/server/http.go +++ b/internal/server/http.go @@ -11,24 +11,25 @@ import ( ) type HTTPServer struct { - app *fiber.App - service *services.ChatService - session *services.SessionService - gateway *gateway.Gateway - callback *services.CallbackService + app *fiber.App + service *services.ChatService + session *services.SessionService + gateway *gateway.Gateway + callback *services.CallbackService } func NewHTTPServer( - service *services.ChatService, - session *services.SessionService, - task *services.TaskService, - gateway *gateway.Gateway, - callback *services.CallbackService, + service *services.ChatService, + session *services.SessionService, + task *services.TaskService, + gateway *gateway.Gateway, + callback *services.CallbackService, + chatHis *services.HistoryService, ) *fiber.App { - //构建 server - app := initRoute() - router.SetupRoutes(app, service, session, task, gateway, callback) - return app + //构建 server + app := initRoute() + router.SetupRoutes(app, service, session, task, gateway, callback, chatHis) + return app } func initRoute() *fiber.App { diff --git a/internal/server/router/router.go b/internal/server/router/router.go index 1a6fd8b..4f5579c 100644 --- a/internal/server/router/router.go +++ b/internal/server/router/router.go @@ -15,14 +15,17 @@ import ( ) type RouterServer struct { - app *fiber.App - service *services.ChatService - session *services.SessionService - gateway *gateway.Gateway + app *fiber.App + service *services.ChatService + session *services.SessionService + gateway *gateway.Gateway + chatHist *services.HistoryService } // SetupRoutes 设置路由 -func SetupRoutes(app *fiber.App, ChatService *services.ChatService, sessionService *services.SessionService, task *services.TaskService, gateway *gateway.Gateway, callbackService *services.CallbackService) { +func SetupRoutes(app *fiber.App, ChatService *services.ChatService, sessionService *services.SessionService, task *services.TaskService, + gateway *gateway.Gateway, callbackService *services.CallbackService, chatHist *services.HistoryService, +) { app.Use(func(c *fiber.Ctx) error { // 设置 CORS 头 c.Set("Access-Control-Allow-Origin", "*") @@ -77,6 +80,9 @@ func SetupRoutes(app *fiber.App, ChatService *services.ChatService, sessionServi return ctx.Status(400).SendString("unknown action") } }) + + // 会话历史 + r.Post("/chat/history", chatHist.GetHistory) } func routerSocket(app *fiber.App, chatService *services.ChatService) { diff --git a/internal/services/chat.go b/internal/services/chat.go index 11dda38..eb47809 100644 --- a/internal/services/chat.go +++ b/internal/services/chat.go @@ -73,9 +73,9 @@ func (h *ChatService) Chat(c *websocket.Conn) { var wg sync.WaitGroup // 确保在函数返回时移除客户端并关闭连接 defer func() { - h.Gw.Cleanup(client.GetID()) - close(semaphore) // 关闭信号量通道 wg.Wait() // 等待所有消息处理goroutine完成 + close(semaphore) // 关闭信号量通道 + h.Gw.Cleanup(client.GetID()) }() // 绑定会话ID, sessionId 为空时, 则不绑定 diff --git a/internal/services/chat_history.go b/internal/services/chat_history.go new file mode 100644 index 0000000..f49d238 --- /dev/null +++ b/internal/services/chat_history.go @@ -0,0 +1,44 @@ +package services + +import ( + "ai_scheduler/internal/biz" + errors "ai_scheduler/internal/data/error" + "ai_scheduler/internal/entitys" + "github.com/gofiber/fiber/v2" +) + +type HistoryService struct { + chatRepo *biz.ChatHistoryBiz +} + +func NewHistoryService(chatRepo *biz.ChatHistoryBiz) *HistoryService { + return &HistoryService{ + chatRepo: chatRepo, + } +} + +// GetHistoryService 获取会话历史 +func (h *HistoryService) GetHistory(c *fiber.Ctx) error { + var query entitys.ChatHistQuery + if err := c.BodyParser(&query); err != nil { + return err + } + // 校验参数 + if query.SessionID == "" { + return errors.SessionNotFound + } + if query.Page <= 0 { + query.Page = 1 + } + if query.PageSize <= 0 { + query.PageSize = 10 + } + + // 查询历史 + history, err := h.chatRepo.List(c.Context(), &query) + if err != nil { + return err + } + + return c.JSON(history) +} diff --git a/internal/services/provider_set.go b/internal/services/provider_set.go index 0e1284b..668c7fe 100644 --- a/internal/services/provider_set.go +++ b/internal/services/provider_set.go @@ -6,4 +6,4 @@ import ( "github.com/google/wire" ) -var ProviderSetServices = wire.NewSet(NewChatService, NewSessionService, gateway.NewGateway, NewTaskService, NewCallbackService) +var ProviderSetServices = wire.NewSet(NewChatService, NewSessionService, gateway.NewGateway, NewTaskService, NewCallbackService, NewHistoryService) From 5a44253d6d7f6622a3f4acd3cfd3733053f072f0 Mon Sep 17 00:00:00 2001 From: wolter <11@gmail> Date: Fri, 5 Dec 2025 18:11:02 +0800 Subject: [PATCH 10/66] fix: Dockerfile --- Dockerfile | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Dockerfile b/Dockerfile index 2c0e527..2bd6192 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,5 +1,5 @@ ## 使用官方Go镜像作为构建环境 -FROM golang:1.24.0-alpine AS builder +FROM golang:1.24.1-alpine AS builder # 设置工作目录 WORKDIR /app @@ -15,7 +15,7 @@ COPY . . # 复制go模块依赖文件 COPY go.mod go.sum ./ -RUN go mod download +RUN go mod tidy RUN go install github.com/google/wire/cmd/wire@latest RUN wire ./cmd/server From 5d82a4c0b9d56dfec9d9620e4debe9ba569acab2 Mon Sep 17 00:00:00 2001 From: fuzhongyun <15339891972@163.com> Date: Fri, 5 Dec 2025 18:28:34 +0800 Subject: [PATCH 11/66] =?UTF-8?q?fix:=20=E4=BF=AE=E6=94=B9vllm=E5=9C=B0?= =?UTF-8?q?=E5=9D=80?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- config/config_test.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/config/config_test.yaml b/config/config_test.yaml index 2d099ac..0fd9f40 100644 --- a/config/config_test.yaml +++ b/config/config_test.yaml @@ -14,7 +14,7 @@ ollama: format: "json" vllm: - base_url: "http://127.0.0.1:8001/v1" + base_url: "http://host.docker.internal:8001/v1" vl_model: "models/Qwen2.5-VL-3B-Instruct-AWQ" timeout: "120s" level: "info" From 7ac61893f5ef324b29ec8cb7a4acda131e879ac2 Mon Sep 17 00:00:00 2001 From: wolter <11@gmail> Date: Mon, 8 Dec 2025 13:55:56 +0800 Subject: [PATCH 12/66] =?UTF-8?q?feat:=E4=BC=9A=E8=AF=9D=E5=8E=86=E5=8F=B2?= =?UTF-8?q?=E5=92=8C=E5=89=8D=E7=AB=AF=E5=9B=9E=E4=BC=A0=E6=95=B0=E6=8D=AE?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- go.mod | 5 ++ go.sum | 12 ++++ internal/biz/chat_history.go | 88 ++++++++++++++++++++------ internal/data/impl/base.go | 16 +++++ internal/data/impl/chat_history.go | 7 ++ internal/data/impl/task_impl.go | 6 +- internal/data/model/ai_chat_his.gen.go | 1 + internal/entitys/chat_history.go | 42 ++++++++++++ internal/pkg/util/string.go | 10 +++ internal/pkg/validate/validate.go | 45 +++++++++++++ internal/server/router/router.go | 3 +- internal/services/chat_history.go | 21 +++++- 12 files changed, 235 insertions(+), 21 deletions(-) create mode 100644 internal/pkg/validate/validate.go diff --git a/go.mod b/go.mod index 49b7751..b6e99c9 100644 --- a/go.mod +++ b/go.mod @@ -14,6 +14,9 @@ require ( github.com/faabiosr/cachego v0.26.0 github.com/fastwego/dingding v1.0.0-beta.4 github.com/go-kratos/kratos/v2 v2.9.1 + github.com/go-playground/locales v0.14.1 + github.com/go-playground/universal-translator v0.18.1 + github.com/go-playground/validator/v10 v10.20.0 github.com/gofiber/fiber/v2 v2.52.9 github.com/gofiber/websocket/v2 v2.2.1 github.com/google/uuid v1.6.0 @@ -53,6 +56,7 @@ require ( github.com/evanphx/json-patch v0.5.2 // indirect github.com/fasthttp/websocket v1.5.3 // indirect github.com/fsnotify/fsnotify v1.6.0 // indirect + github.com/gabriel-vasile/mimetype v1.4.3 // indirect github.com/go-sql-driver/mysql v1.8.1 // indirect github.com/goph/emperror v0.17.2 // indirect github.com/hashicorp/hcl v1.0.0 // indirect @@ -61,6 +65,7 @@ require ( github.com/json-iterator/go v1.1.12 // indirect github.com/klauspost/compress v1.17.9 // indirect github.com/klauspost/cpuid/v2 v2.2.9 // indirect + github.com/leodido/go-urn v1.4.0 // indirect github.com/magiconair/properties v1.8.7 // indirect github.com/mailru/easyjson v0.7.7 // indirect github.com/mattn/go-colorable v0.1.13 // indirect diff --git a/go.sum b/go.sum index 5dc58f0..a6689ca 100644 --- a/go.sum +++ b/go.sum @@ -174,6 +174,8 @@ github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMo github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ= github.com/fsnotify/fsnotify v1.6.0 h1:n+5WquG0fcWoWp6xPWfHdbskMCQaFnG6PfBrh1Ky4HY= github.com/fsnotify/fsnotify v1.6.0/go.mod h1:sl3t1tCWJFWoRz9R8WJCbQihKKwmorjAbSClcnxKAGw= +github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0= +github.com/gabriel-vasile/mimetype v1.4.3/go.mod h1:d8uq/6HKRL6CGdk+aubisF/M5GcPfT7nKyLpA0lbSSk= github.com/garyburd/redigo v1.6.0/go.mod h1:NR3MbYisc3/PwhQ00EMzDiPmrwpPxAn5GI05/YaO1SY= github.com/getsentry/raven-go v0.2.0/go.mod h1:KungGk8q33+aIAZUIVWZDr2OfAEBsO49PX4NzFV5kcQ= github.com/go-check/check v0.0.0-20180628173108-788fd7840127 h1:0gkP6mzaMqkmpcJYCFOLkIBwI7xFExG03bbkOkCvUPI= @@ -183,6 +185,14 @@ github.com/go-gl/glfw/v3.3/glfw v0.0.0-20191125211704-12ad95a8df72/go.mod h1:tQ2 github.com/go-gl/glfw/v3.3/glfw v0.0.0-20200222043503-6f7a984d4dc4/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8= github.com/go-kratos/kratos/v2 v2.9.1 h1:EGif6/S/aK/RCR5clIbyhioTNyoSrii3FC118jG40Z0= github.com/go-kratos/kratos/v2 v2.9.1/go.mod h1:a1MQLjMhIh7R0kcJS9SzJYR43BRI7EPzzN0J1Ksu2bA= +github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s= +github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4= +github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA= +github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY= +github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY= +github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY= +github.com/go-playground/validator/v10 v10.20.0 h1:K9ISHbSaI0lyB2eWMPJo+kOS/FBExVwjEviJTixqxL8= +github.com/go-playground/validator/v10 v10.20.0/go.mod h1:dbuPbCMFw/DrkbEynArYaCwl3amGuJotoKCe95atGMM= github.com/go-sql-driver/mysql v1.8.1 h1:LedoTUt/eveggdHS9qUFC1EFSa8bU2+1pZjSRpvNJ1Y= github.com/go-sql-driver/mysql v1.8.1/go.mod h1:wEBSXgmK//2ZFJyE+qWnIsVGmvmEKlqwuVSjsCm7DZg= github.com/gofiber/fiber/v2 v2.52.9 h1:YjKl5DOiyP3j0mO61u3NTmK7or8GzzWzCFzkboyP5cw= @@ -292,6 +302,8 @@ github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ= +github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI= github.com/magiconair/properties v1.8.7 h1:IeQXZAiQcpL9mgcAe1Nu6cX9LLw6ExEHKjN0VQdvPDY= github.com/magiconair/properties v1.8.7/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0= github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0= diff --git a/internal/biz/chat_history.go b/internal/biz/chat_history.go index eeb74e5..58266e8 100644 --- a/internal/biz/chat_history.go +++ b/internal/biz/chat_history.go @@ -1,53 +1,105 @@ package biz import ( + errors "ai_scheduler/internal/data/error" "ai_scheduler/internal/data/impl" "ai_scheduler/internal/data/model" "ai_scheduler/internal/entitys" + "ai_scheduler/internal/pkg/util" "context" - + "encoding/json" "xorm.io/builder" ) type ChatHistoryBiz struct { chatHiRepo *impl.ChatHisImpl + taskRepo *impl.TaskImpl } -func NewChatHistoryBiz(chatHiRepo *impl.ChatHisImpl) *ChatHistoryBiz { +func NewChatHistoryBiz(chatHiRepo *impl.ChatHisImpl, taskRepo *impl.TaskImpl) *ChatHistoryBiz { s := &ChatHistoryBiz{ chatHiRepo: chatHiRepo, + taskRepo: taskRepo, } - //go s.AsyncProcess(context.Background()) return s } // 查询会话历史 -func (s *ChatHistoryBiz) List(ctx context.Context, query *entitys.ChatHistQuery) ([]model.AiChatHi, error) { +func (s *ChatHistoryBiz) List(ctx context.Context, query *entitys.ChatHistQuery) ([]entitys.ChatHisQueryResponse, error) { chats, err := s.chatHiRepo.FindAll( s.chatHiRepo.WithSessionId(query.SessionID), s.chatHiRepo.PaginateScope(query.Page, query.PageSize), + s.chatHiRepo.OrderByDesc("his_id"), ) if err != nil { return nil, err } - return chats, nil + + taskIds := make([]int32, 0, len(chats)) + for _, chat := range chats { + // 去重任务ID + if !util.Contains(taskIds, chat.TaskID) { + taskIds = append(taskIds, chat.TaskID) + } + } + + // 查询任务名称 + tasks, err := s.taskRepo.FindAll(s.taskRepo.In("task_id", taskIds)) + if err != nil { + return nil, err + } + taskMap := make(map[int32]model.AiTask) + for _, task := range tasks { + taskMap[task.TaskID] = task + } + + // 构建结果 + result := make([]entitys.ChatHisQueryResponse, 0, len(chats)) + for _, chat := range chats { + item := entitys.ChatHisQueryResponse{} + item.FromModel(chat, taskMap[chat.TaskID]) + result = append(result, item) + } + + return result, nil } -// 添加会话历史 -func (s *ChatHistoryBiz) Create(ctx context.Context, chat entitys.ChatHistory) error { - return s.chatHiRepo.Create(&model.AiChatHi{ - SessionID: chat.SessionID, - Ques: chat.Role.String(), - Ans: chat.Content, - }) -} +//// 添加会话历史 +//func (s *ChatHistoryBiz) Create(ctx context.Context, chat entitys.ChatHistory) error { +// return s.chatHiRepo.Create(&model.AiChatHi{ +// SessionID: chat.SessionID, +// Ques: chat.Role.String(), +// Ans: chat.Content, +// }) +//} -// 更新会话历史内容 -func (s *ChatHistoryBiz) UpdateContent(ctx context.Context, chat *entitys.UseFulRequest) error { - cond := builder.NewCond() - cond = cond.And(builder.Eq{"his_id": chat.HisId}) +// 更新会话历史内容, 追加内容, 不覆盖原有内容, content 使用json格式存储 +func (c *ChatHistoryBiz) UpdateContent(ctx context.Context, chat *entitys.UpdateContentRequest) error { + var contents []string + chatHi, has, err := c.chatHiRepo.FindOne(c.chatHiRepo.WithHisId(chat.HisID)) + if err != nil { + return err + } else if !has { + return errors.NewBusinessErr(errors.InvalidParamCode, "chat history not found") + } - return s.chatHiRepo.UpdateByCond(&cond, &model.AiChatHi{HisID: chat.HisId, Useful: chat.Useful}) + if "" != chatHi.Content { + // 解析历史内容 + err = json.Unmarshal([]byte(chatHi.Content), &contents) + if err != nil { + return err + } + } + contents = append(contents, chat.Content) + + b, err := json.Marshal(contents) + if err != nil { + return err + } + chatHi.Content = string(b) + return c.chatHiRepo.Update(&chatHi, + c.chatHiRepo.Select("content"), + c.chatHiRepo.WithHisId(chatHi.HisID)) } // 异步添加会话历史 diff --git a/internal/data/impl/base.go b/internal/data/impl/base.go index 5ab2afa..7aee0b7 100644 --- a/internal/data/impl/base.go +++ b/internal/data/impl/base.go @@ -54,6 +54,8 @@ type BaseRepository[P PO] interface { WithStatus(status int) CondFunc // 查询status GetDb() *gorm.DB // 获取数据库连接 WithLimit(limit int) CondFunc // 限制返回条数 + In(field string, values interface{}) CondFunc // 查询字段是否在列表中 + Select(fields ...string) CondFunc // 选择字段 } // PaginationResult 分页查询结果 @@ -215,3 +217,17 @@ func (this *BaseModel[P]) WithLimit(limit int) CondFunc { return db.Limit(limit) } } + +// 查询字段是否在列表中 +func (this *BaseModel[P]) In(field string, values interface{}) CondFunc { + return func(db *gorm.DB) *gorm.DB { + return db.Where(fmt.Sprintf("%s IN ?", field), values) + } +} + +// select 字段 +func (this *BaseModel[P]) Select(fields ...string) CondFunc { + return func(db *gorm.DB) *gorm.DB { + return db.Select(fields) + } +} diff --git a/internal/data/impl/chat_history.go b/internal/data/impl/chat_history.go index f1c0b42..08a7fa8 100644 --- a/internal/data/impl/chat_history.go +++ b/internal/data/impl/chat_history.go @@ -55,3 +55,10 @@ func (impl *ChatHisImpl) AsyncProcess(ctx context.Context) { } } } + +// his_id 条件:历史ID +func (impl *ChatHisImpl) WithHisId(hisId interface{}) CondFunc { + return func(db *gorm.DB) *gorm.DB { + return db.Where("his_id = ?", hisId) + } +} diff --git a/internal/data/impl/task_impl.go b/internal/data/impl/task_impl.go index 8b3f246..3c76680 100644 --- a/internal/data/impl/task_impl.go +++ b/internal/data/impl/task_impl.go @@ -8,8 +8,12 @@ import ( type TaskImpl struct { dataTemp.DataTemp + BaseRepository[model.AiTask] } func NewTaskImpl(db *utils.Db) *TaskImpl { - return &TaskImpl{*dataTemp.NewDataTemp(db, new(model.AiTask))} + return &TaskImpl{ + DataTemp: *dataTemp.NewDataTemp(db, new(model.AiTask)), + BaseRepository: NewBaseModel[model.AiTask](db.Client), + } } diff --git a/internal/data/model/ai_chat_his.gen.go b/internal/data/model/ai_chat_his.gen.go index e3b5b58..595b4c4 100644 --- a/internal/data/model/ai_chat_his.gen.go +++ b/internal/data/model/ai_chat_his.gen.go @@ -21,6 +21,7 @@ type AiChatHi struct { CreateAt time.Time `gorm:"column:create_at;default:CURRENT_TIMESTAMP" json:"create_at"` UpdatedAt time.Time `gorm:"column:updated_at;default:CURRENT_TIMESTAMP" json:"updated_at"` TaskID int32 `gorm:"column:task_id;not null" json:"task_id"` // 任务ID + Content string `gorm:"column:content" json:"content"` // 前端回传数据 } // TableName AiChatHi's table name diff --git a/internal/entitys/chat_history.go b/internal/entitys/chat_history.go index 34e6c5e..6991a53 100644 --- a/internal/entitys/chat_history.go +++ b/internal/entitys/chat_history.go @@ -2,6 +2,8 @@ package entitys import ( "ai_scheduler/internal/data/constants" + "ai_scheduler/internal/data/model" + "encoding/json" ) type ChatHistory struct { @@ -20,3 +22,43 @@ type ChatHistQuery struct { Page int `json:"page"` PageSize int `json:"page_size"` } + +type ChatHisQueryResponse struct { + HisID int64 `gorm:"column:his_id;primaryKey;autoIncrement:true" json:"his_id"` + SessionID string `gorm:"column:session_id;not null" json:"session_id"` + Ques string `gorm:"column:ques;not null" json:"ques"` + Ans string `gorm:"column:ans;not null" json:"ans"` + Files string `gorm:"column:files;not null" json:"files"` + Useful int32 `gorm:"column:useful;not null;comment:0不评价,1有用,其他为无用" json:"useful"` // 0不评价,1有用,其他为无用 + CreateAt string `gorm:"column:create_at;default:CURRENT_TIMESTAMP" json:"create_at"` + TaskID int32 `gorm:"column:task_id;not null" json:"task_id"` // 任务ID + TaskName string `gorm:"column:task_name;not null" json:"task_name"` // 任务名称 + Content []string `gorm:"column:content" json:"content"` // 前端回传数据 +} + +func (c *ChatHisQueryResponse) FromModel(chat model.AiChatHi, task model.AiTask) { + c.HisID = chat.HisID + c.SessionID = chat.SessionID + c.Ques = chat.Ques + c.Ans = chat.Ans + c.Files = chat.Files + c.Useful = chat.Useful + c.CreateAt = chat.CreateAt.Format("2006-01-02 15:04:05") + c.TaskID = chat.TaskID + c.TaskName = task.Name + c.Content = make([]string, 0) + + // 解析Content + if "" != chat.Content { + var contents []string + if err := json.Unmarshal([]byte(chat.Content), &contents); err != nil { + c.Content = contents + } + c.Content = append(c.Content, chat.Content) + } +} + +type UpdateContentRequest struct { + HisID int64 `json:"his_id" validate:"required"` + Content string `json:"content" validate:"required"` +} diff --git a/internal/pkg/util/string.go b/internal/pkg/util/string.go index 1fc898a..9dd4056 100644 --- a/internal/pkg/util/string.go +++ b/internal/pkg/util/string.go @@ -32,3 +32,13 @@ func StringToFloat64(s string) float64 { i, _ := strconv.ParseFloat(s, 64) return i } + +// 是否包含在数组中 +func Contains[T comparable](strings []T, str T) bool { + for _, s := range strings { + if s == str { + return true + } + } + return false +} diff --git a/internal/pkg/validate/validate.go b/internal/pkg/validate/validate.go new file mode 100644 index 0000000..de58134 --- /dev/null +++ b/internal/pkg/validate/validate.go @@ -0,0 +1,45 @@ +package validate + +import ( + "fmt" + "github.com/go-playground/locales/zh" + ut "github.com/go-playground/universal-translator" + "github.com/go-playground/validator/v10" + zh_translations "github.com/go-playground/validator/v10/translations/zh" + "reflect" +) + +func Struct(s interface{}) (errMsg []string, err error) { + // 创建验证器实例 + validate := validator.New() + + // 创建中文翻译器 + zh_ch := zh.New() + uni := ut.New(zh_ch, zh_ch) + trans, _ := uni.GetTranslator("zh") + + //注册一个函数,获取struct tag里自定义的label作为字段名 + validate.RegisterTagNameFunc(func(fld reflect.StructField) string { + name := fld.Tag.Get("label") + return name + }) + + // 注册中文翻译器到验证器 + _ = zh_translations.RegisterDefaultTranslations(validate, trans) + + // 验证结构体 + err = validate.Struct(s) + if err != nil { + // 处理验证错误 + if _, ok := err.(*validator.InvalidValidationError); ok { + fmt.Println("处理验证错误error:", err) + errMsg = append(errMsg, err.Error()) + } else { + for _, v := range err.(validator.ValidationErrors) { + errMsg = append(errMsg, v.Translate(trans)) + } + } + } + + return +} diff --git a/internal/server/router/router.go b/internal/server/router/router.go index 4f5579c..e2645bb 100644 --- a/internal/server/router/router.go +++ b/internal/server/router/router.go @@ -82,7 +82,8 @@ func SetupRoutes(app *fiber.App, ChatService *services.ChatService, sessionServi }) // 会话历史 - r.Post("/chat/history", chatHist.GetHistory) + r.Post("/chat/history/list", chatHist.List) + r.Post("/chat/history/update/content", chatHist.UpdateContent) } func routerSocket(app *fiber.App, chatService *services.ChatService) { diff --git a/internal/services/chat_history.go b/internal/services/chat_history.go index f49d238..7bd4d75 100644 --- a/internal/services/chat_history.go +++ b/internal/services/chat_history.go @@ -4,7 +4,10 @@ import ( "ai_scheduler/internal/biz" errors "ai_scheduler/internal/data/error" "ai_scheduler/internal/entitys" + "ai_scheduler/internal/pkg/validate" "github.com/gofiber/fiber/v2" + "github.com/gofiber/fiber/v2/log" + "strings" ) type HistoryService struct { @@ -18,7 +21,7 @@ func NewHistoryService(chatRepo *biz.ChatHistoryBiz) *HistoryService { } // GetHistoryService 获取会话历史 -func (h *HistoryService) GetHistory(c *fiber.Ctx) error { +func (h *HistoryService) List(c *fiber.Ctx) error { var query entitys.ChatHistQuery if err := c.BodyParser(&query); err != nil { return err @@ -42,3 +45,19 @@ func (h *HistoryService) GetHistory(c *fiber.Ctx) error { return c.JSON(history) } + +func (h *HistoryService) UpdateContent(c *fiber.Ctx) error { + var req entitys.UpdateContentRequest + if err := c.BodyParser(&req); err != nil { + return err + } + // 校验参数 + msg, err := validate.Struct(req) + if err != nil { + log.Error(c.UserContext(), "参数错误 error: ", err) + return errors.NewBusinessErr(errors.InvalidParamCode, strings.Join(msg, ";")) + } + + // 更新历史 + return h.chatRepo.UpdateContent(c.Context(), &req) +} From 2c9874a1801e3e5d0c80ad5a72a9b86e1445e34d Mon Sep 17 00:00:00 2001 From: wolter <11@gmail> Date: Mon, 8 Dec 2025 14:33:32 +0800 Subject: [PATCH 13/66] =?UTF-8?q?fix:=E4=BC=9A=E8=AF=9D=E5=8E=86=E5=8F=B2?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- internal/entitys/chat_history.go | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/internal/entitys/chat_history.go b/internal/entitys/chat_history.go index 6991a53..019a8c6 100644 --- a/internal/entitys/chat_history.go +++ b/internal/entitys/chat_history.go @@ -4,6 +4,7 @@ import ( "ai_scheduler/internal/data/constants" "ai_scheduler/internal/data/model" "encoding/json" + "log" ) type ChatHistory struct { @@ -33,7 +34,7 @@ type ChatHisQueryResponse struct { CreateAt string `gorm:"column:create_at;default:CURRENT_TIMESTAMP" json:"create_at"` TaskID int32 `gorm:"column:task_id;not null" json:"task_id"` // 任务ID TaskName string `gorm:"column:task_name;not null" json:"task_name"` // 任务名称 - Content []string `gorm:"column:content" json:"content"` // 前端回传数据 + Contents []string `gorm:"column:contents" json:"contents"` // 前端回传数据 } func (c *ChatHisQueryResponse) FromModel(chat model.AiChatHi, task model.AiTask) { @@ -46,15 +47,15 @@ func (c *ChatHisQueryResponse) FromModel(chat model.AiChatHi, task model.AiTask) c.CreateAt = chat.CreateAt.Format("2006-01-02 15:04:05") c.TaskID = chat.TaskID c.TaskName = task.Name - c.Content = make([]string, 0) + c.Contents = make([]string, 0) // 解析Content if "" != chat.Content { - var contents []string - if err := json.Unmarshal([]byte(chat.Content), &contents); err != nil { - c.Content = contents + err := json.Unmarshal([]byte(chat.Content), &c.Contents) + if err != nil { + c.Contents = append(c.Contents, chat.Content) + log.Println("解析Content失败 error: ", err) } - c.Content = append(c.Content, chat.Content) } } From 8ab6cbe3f4d315900d617fabfae276ce01abfc11 Mon Sep 17 00:00:00 2001 From: fuzhongyun <15339891972@163.com> Date: Tue, 9 Dec 2025 14:22:13 +0800 Subject: [PATCH 14/66] =?UTF-8?q?fix:=20=E8=B0=83=E6=95=B4api=E7=9B=B4?= =?UTF-8?q?=E8=BF=9E?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- internal/biz/do/handle.go | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/internal/biz/do/handle.go b/internal/biz/do/handle.go index 14695e6..ebebf83 100644 --- a/internal/biz/do/handle.go +++ b/internal/biz/do/handle.go @@ -17,8 +17,9 @@ import ( "context" "encoding/json" "fmt" - "gorm.io/gorm/utils" "strings" + + "gorm.io/gorm/utils" ) type Handle struct { @@ -237,9 +238,19 @@ func (r *Handle) handleApiTask(ctx context.Context, requireData *entitys.Require if err != nil { return } - request.Url = strings.ReplaceAll(task.Config, "${authorization}", requireData.Auth) + // request.Url = strings.ReplaceAll(task.Config, "${authorization}", requireData.Auth) + task.Config = strings.ReplaceAll(task.Config, "${authorization}", requireData.Auth) for k, v := range requestParam { - task.Config = strings.ReplaceAll(task.Config, "${"+k+"}", fmt.Sprintf("%v", v)) + if vStr, ok := v.(string); ok { + task.Config = strings.ReplaceAll(task.Config, "${"+k+"}", vStr) + } else { + var jsonStr []byte + jsonStr, err = json.Marshal(v) + if err != nil { + return errors.NewBusinessErr(422, "请求参数解析失败") + } + task.Config = strings.ReplaceAll(task.Config, "\"${"+k+"}\"", string(jsonStr)) + } } var configData entitys.ConfigDataHttp err = json.Unmarshal([]byte(task.Config), &configData) From b184dce3eacc5ad6febfff28a2fc7e402348dfd7 Mon Sep 17 00:00:00 2001 From: fuzhongyun <15339891972@163.com> Date: Tue, 9 Dec 2025 14:22:18 +0800 Subject: [PATCH 15/66] =?UTF-8?q?fix:=20=E8=B0=83=E6=95=B4api=E7=9B=B4?= =?UTF-8?q?=E8=BF=9E?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- config/config_env.yaml | 1 + 1 file changed, 1 insertion(+) diff --git a/config/config_env.yaml b/config/config_env.yaml index b6fa262..d2cd2a4 100644 --- a/config/config_env.yaml +++ b/config/config_env.yaml @@ -24,6 +24,7 @@ sys: channel_pool_len: 100 channel_pool_size: 32 llm_pool_len: 5 + heartbeat_interval: 30 redis: host: 47.97.27.195:6379 type: node From 945f3ff4fc6d1b18c6a01e2d17706c366b5ec44c Mon Sep 17 00:00:00 2001 From: wolter <11@gmail> Date: Tue, 9 Dec 2025 14:25:13 +0800 Subject: [PATCH 16/66] =?UTF-8?q?feat:=E4=BC=9A=E8=AF=9D=E6=9F=A5=E8=AF=A2?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- internal/biz/chat_history.go | 10 +++++++++- internal/entitys/chat_history.go | 1 + 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/internal/biz/chat_history.go b/internal/biz/chat_history.go index 58266e8..c3912a4 100644 --- a/internal/biz/chat_history.go +++ b/internal/biz/chat_history.go @@ -26,10 +26,18 @@ func NewChatHistoryBiz(chatHiRepo *impl.ChatHisImpl, taskRepo *impl.TaskImpl) *C // 查询会话历史 func (s *ChatHistoryBiz) List(ctx context.Context, query *entitys.ChatHistQuery) ([]entitys.ChatHisQueryResponse, error) { - chats, err := s.chatHiRepo.FindAll( + + con := []impl.CondFunc{ s.chatHiRepo.WithSessionId(query.SessionID), s.chatHiRepo.PaginateScope(query.Page, query.PageSize), s.chatHiRepo.OrderByDesc("his_id"), + } + if query.HisID > 0 { + con = append(con, s.chatHiRepo.WithHisId(query.HisID)) + } + + chats, err := s.chatHiRepo.FindAll( + con..., ) if err != nil { return nil, err diff --git a/internal/entitys/chat_history.go b/internal/entitys/chat_history.go index 019a8c6..b50148e 100644 --- a/internal/entitys/chat_history.go +++ b/internal/entitys/chat_history.go @@ -19,6 +19,7 @@ type ChatHisLog struct { } type ChatHistQuery struct { + HisID int64 `json:"his_id"` SessionID string `json:"session_id"` Page int `json:"page"` PageSize int `json:"page_size"` From d509a18d442b3bd8346605458ca833aceb820dde Mon Sep 17 00:00:00 2001 From: renzhiyuan <465386466@qq.com> Date: Tue, 9 Dec 2025 16:32:13 +0800 Subject: [PATCH 17/66] feat: add DingTalk bot integration --- cmd/server/main.go | 5 +- config/config_test.yaml | 6 + go.mod | 11 +- go.sum | 6 + internal/biz/ding_talk_bot.go | 42 ++++++ internal/biz/do/ctx.go | 31 +++- internal/biz/do/handle.go | 20 +-- internal/biz/do/prompt.go | 70 +++++++++ internal/biz/handle/file.go | 14 ++ internal/biz/handle/handle.go | 22 --- internal/biz/llm_service/common.go | 47 ------ internal/biz/llm_service/langchain.go | 159 ++++++++++----------- internal/biz/llm_service/ollama.go | 79 ++-------- internal/biz/provider_set.go | 6 +- internal/biz/router.go | 16 ++- internal/config/config.go | 25 ++-- internal/data/constants/bot.go | 22 ++- internal/data/constants/file.go | 75 ++++++++++ internal/data/impl/bot_chat_history.go | 15 ++ internal/data/model/ai_bot_chat_his.gen.go | 27 ++++ internal/entitys/bot.go | 23 ++- internal/entitys/ollama.go | 10 ++ internal/entitys/recognize.go | 31 ++++ internal/entitys/response.go | 35 +++-- internal/entitys/types.go | 3 +- internal/server/ding_talk_bot.go | 65 +++++++++ internal/server/provider_set.go | 11 +- internal/server/server.go | 22 ++- internal/services/callback.go | 6 +- internal/services/dtalk_bot.go | 135 +++++++++++++++++ internal/services/provider_set.go | 9 +- internal/tools_bot/dtalk_bot.go | 15 +- 32 files changed, 765 insertions(+), 298 deletions(-) create mode 100644 internal/biz/ding_talk_bot.go create mode 100644 internal/biz/do/prompt.go create mode 100644 internal/biz/handle/file.go delete mode 100644 internal/biz/handle/handle.go create mode 100644 internal/data/constants/file.go create mode 100644 internal/data/impl/bot_chat_history.go create mode 100644 internal/data/model/ai_bot_chat_his.gen.go create mode 100644 internal/entitys/ollama.go create mode 100644 internal/entitys/recognize.go create mode 100644 internal/server/ding_talk_bot.go create mode 100644 internal/services/dtalk_bot.go diff --git a/cmd/server/main.go b/cmd/server/main.go index f0b0214..a6607f5 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -2,7 +2,6 @@ package main import ( "ai_scheduler/internal/config" - "flag" "fmt" @@ -23,8 +22,8 @@ func main() { } defer func() { cleanup() - }() - + //app.DingBotServer.Run(context.Background()) + //app.DingBotServer.RunBots(app.DingBotServer.BotServices) log.Fatal(app.HttpServer.Listen(fmt.Sprintf(":%d", bc.Server.Port))) } diff --git a/config/config_test.yaml b/config/config_test.yaml index 2929dd7..eed1fa5 100644 --- a/config/config_test.yaml +++ b/config/config_test.yaml @@ -81,3 +81,9 @@ default_prompt: # 权限配置 permissionConfig: permission_url: "http://api.test.user.1688sup.cn:8001/v1/menu/myCodes?systemId=" + + +ding_talk_bots: + public: + client_id: "dingchg59zwwvmuuvldx", + client_secret: "ZwetAnRiTQobNFVlNrshRagSMAJIFpBAepWkWI7on7Tt_o617KHtTjBLp8fQfplz", \ No newline at end of file diff --git a/go.mod b/go.mod index 6b08e4a..c78edda 100644 --- a/go.mod +++ b/go.mod @@ -1,8 +1,6 @@ module ai_scheduler -go 1.24.0 - -toolchain go1.24.7 +go 1.24.1 require ( gitea.cdlsxd.cn/self-tools/l_request v1.0.8 @@ -13,18 +11,18 @@ require ( github.com/emirpasic/gods v1.18.1 github.com/faabiosr/cachego v0.26.0 github.com/fastwego/dingding v1.0.0-beta.4 + github.com/gabriel-vasile/mimetype v1.4.11 github.com/go-kratos/kratos/v2 v2.9.1 github.com/gofiber/fiber/v2 v2.52.9 github.com/gofiber/websocket/v2 v2.2.1 github.com/google/uuid v1.6.0 github.com/google/wire v0.7.0 github.com/ollama/ollama v0.12.7 + github.com/open-dingtalk/dingtalk-stream-sdk-go v0.9.1 github.com/redis/go-redis/v9 v9.16.0 github.com/spf13/viper v1.17.0 github.com/tmc/langchaingo v0.1.13 google.golang.org/grpc v1.64.0 - google.golang.org/protobuf v1.34.1 - gopkg.in/yaml.v3 v3.0.1 gorm.io/driver/mysql v1.6.0 gorm.io/gorm v1.31.0 xorm.io/builder v0.3.13 @@ -46,6 +44,7 @@ require ( github.com/fasthttp/websocket v1.5.3 // indirect github.com/fsnotify/fsnotify v1.6.0 // indirect github.com/go-sql-driver/mysql v1.8.1 // indirect + github.com/gorilla/websocket v1.5.0 // indirect github.com/hashicorp/hcl v1.0.0 // indirect github.com/jinzhu/inflection v1.0.0 // indirect github.com/jinzhu/now v1.1.5 // indirect @@ -82,5 +81,7 @@ require ( golang.org/x/sys v0.31.0 // indirect golang.org/x/text v0.23.0 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20240604185151-ef581f913117 // indirect + google.golang.org/protobuf v1.34.1 // indirect gopkg.in/ini.v1 v1.67.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 29b4ace..e1306fe 100644 --- a/go.sum +++ b/go.sum @@ -142,6 +142,8 @@ github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMo github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ= github.com/fsnotify/fsnotify v1.6.0 h1:n+5WquG0fcWoWp6xPWfHdbskMCQaFnG6PfBrh1Ky4HY= github.com/fsnotify/fsnotify v1.6.0/go.mod h1:sl3t1tCWJFWoRz9R8WJCbQihKKwmorjAbSClcnxKAGw= +github.com/gabriel-vasile/mimetype v1.4.11 h1:AQvxbp830wPhHTqc1u7nzoLT+ZFxGY7emj5DR5DYFik= +github.com/gabriel-vasile/mimetype v1.4.11/go.mod h1:d+9Oxyo1wTzWdyVUPMmXFvp4F9tea18J8ufA774AB3s= github.com/garyburd/redigo v1.6.0/go.mod h1:NR3MbYisc3/PwhQ00EMzDiPmrwpPxAn5GI05/YaO1SY= github.com/go-gl/glfw v0.0.0-20190409004039-e6da0acd62b1/go.mod h1:vR7hzQXu2zJy9AVAgeJqvqgH9Q5CA+iKCZ2gyEVpxRU= github.com/go-gl/glfw/v3.3/glfw v0.0.0-20191125211704-12ad95a8df72/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8= @@ -217,6 +219,8 @@ github.com/googleapis/gax-go/v2 v2.0.5/go.mod h1:DWXyrwAJ9X0FpwwEdw+IPEYBICEFu5m github.com/googleapis/google-cloud-go-testing v0.0.0-20200911160855-bcd43fbb19e8/go.mod h1:dvDLG8qkwmyD9a/MJJN3XJcT3xFxOKAvTZGvuZmac9g= github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY= github.com/gopherjs/gopherjs v0.0.0-20200217142428-fce0ec30dd00/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY= +github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc= +github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/hashicorp/golang-lru v0.5.0/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= github.com/hashicorp/golang-lru v0.5.1/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= github.com/hashicorp/hcl v1.0.0 h1:0Anlzjpi4vEasTeNFn2mLJgTSwt0+6sfsiTG8qcWGx4= @@ -273,6 +277,8 @@ github.com/onsi/ginkgo v1.12.1/go.mod h1:zj2OWP4+oCPe1qIXoGWkgMRwljMUYCdkwsT2108 github.com/onsi/ginkgo v1.13.0/go.mod h1:+REjRxOmWfHCjfv9TTWB1jD1Frx4XydAD3zm1lskyM0= github.com/onsi/gomega v1.7.1/go.mod h1:XdKZgCCFLUoM/7CFJVPcG8C1xQ1AJ0vpAezJrB7JYyY= github.com/onsi/gomega v1.10.1/go.mod h1:iN09h71vgCQne3DLsj+A5owkum+a2tYe+TOCB1ybHNo= +github.com/open-dingtalk/dingtalk-stream-sdk-go v0.9.1 h1:Lb/Uzkiw2Ugt2Xf03J5wmv81PdkYOiWbI8CNBi1boC8= +github.com/open-dingtalk/dingtalk-stream-sdk-go v0.9.1/go.mod h1:ln3IqPYYocZbYvl9TAOrG/cxGR9xcn4pnZRLdCTEGEU= github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM= github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= diff --git a/internal/biz/ding_talk_bot.go b/internal/biz/ding_talk_bot.go new file mode 100644 index 0000000..09de5d4 --- /dev/null +++ b/internal/biz/ding_talk_bot.go @@ -0,0 +1,42 @@ +package biz + +import ( + "ai_scheduler/internal/biz/do" + "ai_scheduler/internal/entitys" + "context" + + "github.com/open-dingtalk/dingtalk-stream-sdk-go/chatbot" +) + +// AiRouterBiz 智能路由服务 +type DingTalkBotBiz struct { + do *do.Do + handle *do.Handle +} + +// NewDingTalkBotBiz +func NewDingTalkBotBiz( + do *do.Do, + handle *do.Handle, +) *DingTalkBotBiz { + return &DingTalkBotBiz{ + do: do, + handle: handle, + } +} + +func (d *DingTalkBotBiz) InitRequire(ctx context.Context, data *chatbot.BotCallbackDataModel) (requireData *entitys.RequireDataDingTalkBot, err error) { + requireData = &entitys.RequireDataDingTalkBot{ + Req: data, + Ch: make(chan entitys.Response, 2), + } + entitys.ResLog(requireData.Ch, "recognize_start", "收到消息,正在处理中,请稍等") + + requireData.Sys, err = d.do.GetSysInfoForDingTalkBot(requireData) + requireData.Tasks, err = d.do.GetTasks(requireData.Sys.SysID) + return +} + +func (d *DingTalkBotBiz) Recognize(ctx context.Context, rec *entitys.Recognize, ch chan entitys.Response) (match *entitys.Match, err error) { + return d.handle.RecognizeBot(ctx, rec, ch) +} diff --git a/internal/biz/do/ctx.go b/internal/biz/do/ctx.go index ed8749d..eeefbeb 100644 --- a/internal/biz/do/ctx.go +++ b/internal/biz/do/ctx.go @@ -81,6 +81,20 @@ func (d *Do) DataAuth(ctx context.Context, client *gateway.Client, requireData * return nil } +func (d *Do) DataAuthForBot(ctx context.Context, client *gateway.Client, requireData *entitys.RequireData) (err error) { + + // 2. 加载系统信息 + if err = d.loadSystemInfo(ctx, client, requireData); err != nil { + return fmt.Errorf("获取系统信息失败: %w", err) + } + + // 3. 加载任务列表 + if err = d.loadTaskList(ctx, client, requireData); err != nil { + return fmt.Errorf("获取任务列表失败: %w", err) + } + return nil +} + // 提取数据验证为单独函数 func (d *Do) validateClientData(client *gateway.Client, requireData *entitys.RequireData) error { requireData.Session = client.GetSession() @@ -104,7 +118,7 @@ func (d *Do) validateClientData(client *gateway.Client, requireData *entitys.Req // 获取系统信息的辅助函数 func (d *Do) loadSystemInfo(ctx context.Context, client *gateway.Client, requireData *entitys.RequireData) error { if sysInfo := client.GetSysInfo(); sysInfo == nil { - sys, err := d.getSysInfo(requireData) + sys, err := d.GetSysInfo(requireData) if err != nil { return err } @@ -119,7 +133,7 @@ func (d *Do) loadSystemInfo(ctx context.Context, client *gateway.Client, require // 获取任务列表的辅助函数 func (d *Do) loadTaskList(ctx context.Context, client *gateway.Client, requireData *entitys.RequireData) error { if taskInfo := client.GetTasks(); len(taskInfo) == 0 { - tasks, err := d.getTasks(requireData.Sys.SysID) + tasks, err := d.GetTasks(requireData.Sys.SysID) if err != nil { return err } @@ -202,7 +216,16 @@ func (d *Do) getRequireData() (err error) { return } -func (d *Do) getSysInfo(requireData *entitys.RequireData) (sysInfo model.AiSy, err error) { +func (d *Do) GetSysInfo(requireData *entitys.RequireData) (sysInfo model.AiSy, err error) { + cond := builder.NewCond() + cond = cond.And(builder.Eq{"app_key": requireData.Key}) + cond = cond.And(builder.IsNull{"delete_at"}) + cond = cond.And(builder.Eq{"status": 1}) + err = d.sysImpl.GetOneBySearchToStrut(&cond, &sysInfo) + return +} + +func (d *Do) GetSysInfoForDingTalkBot(requireData *entitys.RequireDataDingTalkBot) (sysInfo model.AiSy, err error) { cond := builder.NewCond() cond = cond.And(builder.Eq{"app_key": requireData.Key}) cond = cond.And(builder.IsNull{"delete_at"}) @@ -221,7 +244,7 @@ func (d *Do) getSessionChatHis(requireData *entitys.RequireData) (his []model.Ai return } -func (d *Do) getTasks(sysId int32) (tasks []model.AiTask, err error) { +func (d *Do) GetTasks(sysId int32) (tasks []model.AiTask, err error) { cond := builder.NewCond() cond = cond.And(builder.Eq{"sys_id": sysId}) diff --git a/internal/biz/do/handle.go b/internal/biz/do/handle.go index e014d25..162005f 100644 --- a/internal/biz/do/handle.go +++ b/internal/biz/do/handle.go @@ -17,8 +17,9 @@ import ( "context" "encoding/json" "fmt" - "gorm.io/gorm/utils" "strings" + + "gorm.io/gorm/utils" ) type Handle struct { @@ -45,23 +46,24 @@ func NewHandle( } } -func (r *Handle) Recognize(ctx context.Context, requireData *entitys.RequireData) (err error) { - entitys.ResLog(requireData.Ch, "recognize_start", "准备意图识别") - +func (r *Handle) Recognize(ctx context.Context, rec *entitys.Recognize, promptProcessor PromptOption) (match entitys.Match, err error) { + entitys.ResLog(rec.Ch, "recognize_start", "准备意图识别") + prompt, err := promptProcessor.CreatePrompt(ctx, rec) //意图识别 - recognizeMsg, err := r.Ollama.IntentRecognize(ctx, requireData) + recognizeMsg, err := r.Ollama.IntentRecognize(ctx, &entitys.ToolSelect{ + Prompt: prompt, + Tools: rec.Tasks, + }) if err != nil { return } - entitys.ResLog(requireData.Ch, "recognize", recognizeMsg) - entitys.ResLog(requireData.Ch, "recognize_end", "意图识别结束") + entitys.ResLog(rec.Ch, "recognize", recognizeMsg) + entitys.ResLog(rec.Ch, "recognize_end", "意图识别结束") - var match entitys.Match if err = json.Unmarshal([]byte(recognizeMsg), &match); err != nil { err = errors.SysErr("数据结构错误:%v", err.Error()) return } - requireData.Match = &match return } diff --git a/internal/biz/do/prompt.go b/internal/biz/do/prompt.go new file mode 100644 index 0000000..cb1ca18 --- /dev/null +++ b/internal/biz/do/prompt.go @@ -0,0 +1,70 @@ +package do + +import ( + "ai_scheduler/internal/entitys" + "ai_scheduler/internal/pkg" + "context" + "strings" + + "github.com/ollama/ollama/api" +) + +type PromptOption interface { + CreatePrompt(ctx context.Context, rec *entitys.Recognize) (mes []api.Message, err error) +} + +type WithSys struct { +} + +func (f *WithSys) CreatePrompt(ctx context.Context, rec *entitys.Recognize) (mes []api.Message, err error) { + var ( + prompt = make([]api.Message, 0) // 初始化一个空的api.Message切片 + ) + // 获取用户内容,如果出错则直接返回错误 + content, err := f.getUserContent(ctx, rec) + if err != nil { + return nil, err + } + // 构建提示消息列表,包含系统提示、助手回复和用户内容 + mes = append(prompt, api.Message{ + Role: "system", // 系统角色 + Content: rec.SystemPrompt, // 系统提示内容 + }, api.Message{ + Role: "assistant", // 助手角色 + Content: "### 聊天记录:" + pkg.JsonStringIgonErr(rec.ChatHis), // 助手回复内容 + }, api.Message{ + Role: "user", // 用户角色 + Content: content.String(), // 用户输入内容 + }) + + return +} + +func (f *WithSys) getUserContent(ctx context.Context, rec *entitys.Recognize) (content strings.Builder, err error) { + var hasFile bool + if len(rec.UserContent.FileUrl) > 0 || rec.UserContent.File != nil { + hasFile = true + } + content.WriteString(rec.UserContent.Text) + if hasFile { + content.WriteString("\n") + } + + if len(rec.UserContent.Tag) > 0 { + content.WriteString("\n") + content.WriteString("### 工具必须使用:") + content.WriteString(rec.UserContent.Tag) + } + + if len(rec.ChatHis.Messages) > 0 { + content.WriteString("### 引用历史聊天记录:\n") + content.WriteString(pkg.JsonStringIgonErr(rec.ChatHis)) + } + + if hasFile { + content.WriteString("\n") + content.WriteString("### 文件内容:\n") + hand.WriteString(rec.UserContent.FileUrl, rec.UserContent.FileUrl) + } + return +} diff --git a/internal/biz/handle/file.go b/internal/biz/handle/file.go new file mode 100644 index 0000000..3c9d79b --- /dev/null +++ b/internal/biz/handle/file.go @@ -0,0 +1,14 @@ +package handle + +import ( + "ai_scheduler/internal/entitys" +) + +// HandleRecognizeFile 这里的目的是无论将什么类型的file都转为二进制格式 +// 判断文件大小 +// 判断文件类型 +// 判断文件是否合法 +func HandleRecognizeFile(file *entitys.RecognizeFile) { + //Todo 仲云 + return +} diff --git a/internal/biz/handle/handle.go b/internal/biz/handle/handle.go deleted file mode 100644 index 391b6eb..0000000 --- a/internal/biz/handle/handle.go +++ /dev/null @@ -1,22 +0,0 @@ -package handle - -import ( - "ai_scheduler/internal/config" - "ai_scheduler/internal/tools" -) - -type Handle struct { - toolManager *tools.Manager - conf *config.Config -} - -func NewHandle( - toolManager *tools.Manager, - conf *config.Config, - -) *Handle { - return &Handle{ - toolManager: toolManager, - conf: conf, - } -} diff --git a/internal/biz/llm_service/common.go b/internal/biz/llm_service/common.go index 1d62ed7..c3ffe46 100644 --- a/internal/biz/llm_service/common.go +++ b/internal/biz/llm_service/common.go @@ -1,9 +1,7 @@ package llm_service import ( - "ai_scheduler/internal/data/constants" "ai_scheduler/internal/data/model" - "ai_scheduler/internal/entitys" "context" "time" ) @@ -20,48 +18,3 @@ func buildSystemPrompt(prompt string) string { return prompt } - -func buildAssistant(his []model.AiChatHi) (chatHis entitys.ChatHis) { - for _, item := range his { - if len(chatHis.SessionId) == 0 { - chatHis.SessionId = item.SessionID - } - chatHis.Messages = append(chatHis.Messages, []entitys.HisMessage{ - { - Role: constants.RoleUser, - Content: item.Ques, - Timestamp: item.CreateAt.Format(time.DateTime), - }, - { - Role: constants.RoleAssistant, - Content: item.Ans, - Timestamp: item.CreateAt.Format(time.DateTime), - }, - }...) - } - chatHis.Context = entitys.HisContext{ - UserLanguage: "zh-CN", - SystemMode: "technical_support", - } - return -} - -func BuildChatHisMessage(his []model.AiChatHi) (chatHis []entitys.HisMessage) { - for _, item := range his { - - chatHis = append(chatHis, []entitys.HisMessage{ - { - Role: constants.RoleUser, - Content: item.Ques, - Timestamp: item.CreateAt.Format(time.DateTime), - }, - { - Role: constants.RoleAssistant, - Content: item.Ans, - Timestamp: item.CreateAt.Format(time.DateTime), - }, - }...) - } - - return -} diff --git a/internal/biz/llm_service/langchain.go b/internal/biz/llm_service/langchain.go index 63f8815..a3216e8 100644 --- a/internal/biz/llm_service/langchain.go +++ b/internal/biz/llm_service/langchain.go @@ -1,87 +1,76 @@ package llm_service -import ( - "ai_scheduler/internal/data/model" - "ai_scheduler/internal/entitys" - "ai_scheduler/internal/pkg" - "ai_scheduler/internal/pkg/utils_langchain" - "context" - "encoding/json" - - "github.com/tmc/langchaingo/llms" -) - -type LangChainService struct { - client *utils_langchain.UtilLangChain -} - -func NewLangChainGenerate( - client *utils_langchain.UtilLangChain, -) *LangChainService { - - return &LangChainService{ - client: client, - } -} - -func (r *LangChainService) IntentRecognize(ctx context.Context, sysInfo model.AiSy, history []model.AiChatHi, userInput string, tasks []model.AiTask) (msg string, err error) { - prompt := r.getPrompt(sysInfo, history, userInput, tasks) - AgentClient := r.client.Get() - defer r.client.Put(AgentClient) - match, err := AgentClient.Llm.GenerateContent( - ctx, // 使用可取消的上下文 - prompt, - llms.WithJSONMode(), - ) - msg = match.Choices[0].Content - return -} - -func (r *LangChainService) getPrompt(sysInfo model.AiSy, history []model.AiChatHi, reqInput string, tasks []model.AiTask) []llms.MessageContent { - var ( - prompt = make([]llms.MessageContent, 0) - ) - prompt = append(prompt, llms.MessageContent{ - Role: llms.ChatMessageTypeSystem, - Parts: []llms.ContentPart{ - llms.TextPart(buildSystemPrompt(sysInfo.SysPrompt)), - }, - }, llms.MessageContent{ - Role: llms.ChatMessageTypeTool, - Parts: []llms.ContentPart{ - llms.TextPart(pkg.JsonStringIgonErr(buildAssistant(history))), - }, - }, llms.MessageContent{ - Role: llms.ChatMessageTypeTool, - Parts: []llms.ContentPart{ - llms.TextPart(pkg.JsonStringIgonErr(r.registerTools(tasks))), - }, - }, llms.MessageContent{ - Role: llms.ChatMessageTypeHuman, - Parts: []llms.ContentPart{ - llms.TextPart(reqInput), - }, - }) - return prompt -} - -func (r *LangChainService) registerTools(tasks []model.AiTask) []llms.Tool { - taskPrompt := make([]llms.Tool, 0) - for _, task := range tasks { - var taskConfig entitys.TaskConfig - err := json.Unmarshal([]byte(task.Config), &taskConfig) - if err != nil { - continue - } - taskPrompt = append(taskPrompt, llms.Tool{ - Type: "function", - Function: &llms.FunctionDefinition{ - Name: task.Index, - Description: task.Desc, - Parameters: taskConfig.Param, - }, - }) - - } - return taskPrompt -} +//type LangChainService struct { +// client *utils_langchain.UtilLangChain +//} +// +//func NewLangChainGenerate( +// client *utils_langchain.UtilLangChain, +//) *LangChainService { +// +// return &LangChainService{ +// client: client, +// } +//} +// +//func (r *LangChainService) IntentRecognize(ctx context.Context, sysInfo model.AiSy, history []model.AiChatHi, userInput string, tasks []model.AiTask) (msg string, err error) { +// prompt := r.getPrompt(sysInfo, history, userInput, tasks) +// AgentClient := r.client.Get() +// defer r.client.Put(AgentClient) +// match, err := AgentClient.Llm.GenerateContent( +// ctx, // 使用可取消的上下文 +// prompt, +// llms.WithJSONMode(), +// ) +// msg = match.Choices[0].Content +// return +//} +// +//func (r *LangChainService) getPrompt(sysInfo model.AiSy, history []model.AiChatHi, reqInput string, tasks []model.AiTask) []llms.MessageContent { +// var ( +// prompt = make([]llms.MessageContent, 0) +// ) +// prompt = append(prompt, llms.MessageContent{ +// Role: llms.ChatMessageTypeSystem, +// Parts: []llms.ContentPart{ +// llms.TextPart(buildSystemPrompt(sysInfo.SysPrompt)), +// }, +// }, llms.MessageContent{ +// Role: llms.ChatMessageTypeTool, +// Parts: []llms.ContentPart{ +// llms.TextPart(pkg.JsonStringIgonErr(buildAssistant(history))), +// }, +// }, llms.MessageContent{ +// Role: llms.ChatMessageTypeTool, +// Parts: []llms.ContentPart{ +// llms.TextPart(pkg.JsonStringIgonErr(r.registerTools(tasks))), +// }, +// }, llms.MessageContent{ +// Role: llms.ChatMessageTypeHuman, +// Parts: []llms.ContentPart{ +// llms.TextPart(reqInput), +// }, +// }) +// return prompt +//} +// +//func (r *LangChainService) registerTools(tasks []model.AiTask) []llms.Tool { +// taskPrompt := make([]llms.Tool, 0) +// for _, task := range tasks { +// var taskConfig entitys.TaskConfig +// err := json.Unmarshal([]byte(task.Config), &taskConfig) +// if err != nil { +// continue +// } +// taskPrompt = append(taskPrompt, llms.Tool{ +// Type: "function", +// Function: &llms.FunctionDefinition{ +// Name: task.Index, +// Description: task.Desc, +// Parameters: taskConfig.Param, +// }, +// }) +// +// } +// return taskPrompt +//} diff --git a/internal/biz/llm_service/ollama.go b/internal/biz/llm_service/ollama.go index 6527a3b..238e002 100644 --- a/internal/biz/llm_service/ollama.go +++ b/internal/biz/llm_service/ollama.go @@ -35,14 +35,14 @@ func NewOllamaGenerate( } } -func (r *OllamaService) IntentRecognize(ctx context.Context, requireData *entitys.RequireData) (msg string, err error) { - prompt, err := r.getPrompt(ctx, requireData) +func (r *OllamaService) IntentRecognize(ctx context.Context, req *entitys.ToolSelect) (msg string, err error) { + if err != nil { return } - toolDefinitions := r.registerToolsOllama(requireData.Tasks) + toolDefinitions := r.registerToolsOllama(req.Tools) - match, err := r.client.ToolSelect(ctx, prompt, toolDefinitions) + match, err := r.client.ToolSelect(ctx, rec.Prompt, toolDefinitions) if err != nil { return } @@ -64,87 +64,28 @@ func (r *OllamaService) IntentRecognize(ctx context.Context, requireData *entity msg = match.Message.Content return + } -func (r *OllamaService) getPrompt(ctx context.Context, requireData *entitys.RequireData) ([]api.Message, error) { - - var ( - prompt = make([]api.Message, 0) - ) - content, err := r.getUserContent(ctx, requireData) - if err != nil { - return nil, err - } - prompt = append(prompt, api.Message{ - Role: "system", - Content: buildSystemPrompt(requireData.Sys.SysPrompt), - }, api.Message{ - Role: "assistant", - Content: "### 聊天记录:" + pkg.JsonStringIgonErr(buildAssistant(requireData.Histories)), - }, api.Message{ - Role: "user", - Content: content, - }) - - return prompt, nil -} - -func (r *OllamaService) getUserContent(ctx context.Context, requireData *entitys.RequireData) (string, error) { - var content strings.Builder - content.WriteString(requireData.Req.Text) - if len(requireData.ImgByte) > 0 { - content.WriteString("\n") - } - - if len(requireData.Req.Tags) > 0 { - content.WriteString("\n") - content.WriteString("### 工具必须使用:") - content.WriteString(requireData.Req.Tags) - } - - if len(requireData.ImgByte) > 0 { - desc, err := r.RecognizeWithImg(ctx, requireData) - if err != nil { - return "", err - } - content.WriteString("### 上传图片解析内容:\n") - content.WriteString(requireData.Req.Tags) - content.WriteString(desc.Response) - } - - if requireData.Req.MarkHis > 0 { - var his model.AiChatHi - cond := builder.NewCond() - cond = cond.And(builder.Eq{"his_id": requireData.Req.MarkHis}) - err := r.chatHis.GetOneBySearchToStrut(&cond, &his) - if err != nil { - return "", err - } - content.WriteString("### 引用历史聊天记录:\n") - content.WriteString(pkg.JsonStringIgonErr(BuildChatHisMessage([]model.AiChatHi{his}))) - } - return content.String(), nil -} - -func (r *OllamaService) RecognizeWithImg(ctx context.Context, requireData *entitys.RequireData) (desc api.GenerateResponse, err error) { - if requireData.ImgByte == nil { +func (r *OllamaService) RecognizeWithImg(ctx context.Context, imgByte []api.ImageData, ch chan entitys.Response) (desc api.GenerateResponse, err error) { + if imgByte == nil { return } - entitys.ResLog(requireData.Ch, "recognize_img_start", "图片识别中...") + entitys.ResLog(ch, "recognize_img_start", "图片识别中...") desc, err = r.client.Generation(ctx, &api.GenerateRequest{ Model: r.config.Ollama.VlModel, Stream: new(bool), System: r.config.DefaultPrompt.ImgRecognize.SystemPrompt, Prompt: r.config.DefaultPrompt.ImgRecognize.UserPrompt, - Images: requireData.ImgByte, + Images: imgByte, KeepAlive: &api.Duration{Duration: 3600 * time.Second}, //Think: &api.ThinkValue{Value: false}, }) if err != nil { return } - entitys.ResLog(requireData.Ch, "recognize_img_end", "图片识别完成,识别内容:"+desc.Response) + entitys.ResLog(ch, "recognize_img_end", "图片识别完成,识别内容:"+desc.Response) return } diff --git a/internal/biz/provider_set.go b/internal/biz/provider_set.go index cc3de0a..aefa3ce 100644 --- a/internal/biz/provider_set.go +++ b/internal/biz/provider_set.go @@ -2,7 +2,6 @@ package biz import ( "ai_scheduler/internal/biz/do" - "ai_scheduler/internal/biz/handle" "ai_scheduler/internal/biz/llm_service" "github.com/google/wire" @@ -12,10 +11,11 @@ var ProviderSetBiz = wire.NewSet( NewAiRouterBiz, NewSessionBiz, NewChatHistoryBiz, - llm_service.NewLangChainGenerate, + //llm_service.NewLangChainGenerate, llm_service.NewOllamaGenerate, - handle.NewHandle, + //handle.NewHandle, do.NewDo, do.NewHandle, NewTaskBiz, + NewDingTalkBotBiz, ) diff --git a/internal/biz/router.go b/internal/biz/router.go index 6dcc233..335ddcc 100644 --- a/internal/biz/router.go +++ b/internal/biz/router.go @@ -3,6 +3,7 @@ package biz import ( "ai_scheduler/internal/biz/do" "ai_scheduler/internal/gateway" + "context" "ai_scheduler/internal/entitys" @@ -56,9 +57,15 @@ func (r *AiRouterBiz) RouteWithSocket(client *gateway.Client, req *entitys.ChatS log.Errorf("数据验证和收集失败: %s", err.Error()) return } - + //组装意图识别 + rec, err := r.SetRec(ctx, requireData) + if err != nil { + log.Errorf("组装意图识别失败: %s", err.Error()) + return + } //意图识别 - if err = r.handle.Recognize(ctx, requireData); err != nil { + requireData.Match, err = r.handle.Recognize(ctx, &rec, &do.WithSys{}) + if err != nil { log.Errorf("意图识别失败: %s", err.Error()) return } @@ -70,3 +77,8 @@ func (r *AiRouterBiz) RouteWithSocket(client *gateway.Client, req *entitys.ChatS } return } + +func (r *AiRouterBiz) SetRec(ctx context.Context, requireData *entitys.RequireData) (match entitys.Recognize, err error) { + //TODO 叙平 + return +} diff --git a/internal/config/config.go b/internal/config/config.go index da25c39..815dca0 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -9,16 +9,21 @@ import ( // Config 应用配置 type Config struct { - Server ServerConfig `mapstructure:"server"` - Ollama OllamaConfig `mapstructure:"ollama"` - Sys SysConfig `mapstructure:"sys"` - Tools ToolsConfig `mapstructure:"tools"` - Logging LoggingConfig `mapstructure:"logging"` - Redis Redis `mapstructure:"redis"` - DB DB `mapstructure:"db"` - DefaultPrompt SysPrompt `mapstructure:"default_prompt"` - PermissionConfig PermissionConfig `mapstructure:"permissionConfig"` - // LLM *LLM `mapstructure:"llm"` + Server ServerConfig `mapstructure:"server"` + Ollama OllamaConfig `mapstructure:"ollama"` + Sys SysConfig `mapstructure:"sys"` + Tools ToolsConfig `mapstructure:"tools"` + Logging LoggingConfig `mapstructure:"logging"` + Redis Redis `mapstructure:"redis"` + DB DB `mapstructure:"db"` + DefaultPrompt SysPrompt `mapstructure:"default_prompt"` + PermissionConfig PermissionConfig `mapstructure:"permissionConfig"` + DingTalkBots map[string]*DingTalkBot `mapstructure:"ding_talk_bots"` +} + +type DingTalkBot struct { + ClientId string `mapstructure:"client_id"` + ClientSecret string `mapstructure:"client_secret"` } type SysPrompt struct { diff --git a/internal/data/constants/bot.go b/internal/data/constants/bot.go index 6dc6680..bfe004c 100644 --- a/internal/data/constants/bot.go +++ b/internal/data/constants/bot.go @@ -3,5 +3,25 @@ package constants type BotTools string const ( - BotToolsBugOptimizationSubmit = "bug_optimization_submit" // 系统的bug/优化建议 + BotToolsBugOptimizationSubmit BotTools = "bug_optimization_submit" // 系统的bug/优化建议 ) + +type ChatStyle int + +const ( + ChatStyleNormal ChatStyle = 1 //正常 + ChatStyleSerious ChatStyle = 2 //严肃 + ChatStyleGentle ChatStyle = 3 //温柔 + ChatStyleArrogance ChatStyle = 4 //傲慢 + ChatStyleCute ChatStyle = 5 //可爱 + ChatStyleAngry ChatStyle = 6 //愤怒 +) + +var ChatStyleMap = map[ChatStyle]string{ + ChatStyleNormal: "正常", + ChatStyleSerious: "严肃", + ChatStyleGentle: "温柔", + ChatStyleArrogance: "傲慢", + ChatStyleCute: "可爱", + ChatStyleAngry: "愤怒", +} diff --git a/internal/data/constants/file.go b/internal/data/constants/file.go new file mode 100644 index 0000000..a418959 --- /dev/null +++ b/internal/data/constants/file.go @@ -0,0 +1,75 @@ +package constants + +import ( + "github.com/gabriel-vasile/mimetype" + "io" + "strings" +) + +type FileType string + +const ( + FileTypeUnknown FileType = "unknown" + FileTypeImage FileType = "image" + //FileTypeVideo FileType = "video" + FileTypeExcel FileType = "excel" + FileTypeWord FileType = "word" + FileTypeTxt FileType = "txt" + FileTypePDF FileType = "pdf" +) + +var FileTypeMappings = map[FileType][]string{ + FileTypeImage: { + "image/jpeg", "image/png", "image/gif", "image/webp", "image/svg+xml", + ".jpg", ".jpeg", ".png", ".gif", ".webp", ".svg", + }, + FileTypeExcel: { + "application/vnd.ms-excel", + "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", + ".xls", ".xlsx", + }, + FileTypeWord: { + "application/msword", + "application/vnd.openxmlformats-officedocument.wordprocessingml.document", + ".doc", ".docx", + }, + FileTypePDF: { + "application/pdf", + ".pdf", + }, + FileTypeTxt: { + "text/plain", + ".txt", + }, +} + +func DetectFileType(file io.ReadSeeker) (FileType, error) { + // 读取文件头(512字节足够检测大多数类型) + buffer := make([]byte, 512) + n, err := file.Read(buffer) + if err != nil && err != io.EOF { + return FileTypeUnknown, err + } + + // 重置读取位置 + if _, err := file.Seek(0, io.SeekStart); err != nil { + return FileTypeUnknown, err + } + + // 检测 MIME 类型 + detectedMIME := mimetype.Detect(buffer[:n]).String() + + // 遍历映射表,匹配 MIME 或扩展名 + for fileType, mimesOrExts := range FileTypeMappings { + for _, item := range mimesOrExts { + if strings.HasPrefix(item, ".") { + continue // 跳过扩展名(仅 MIME 检测) + } + if item == detectedMIME { + return fileType, nil + } + } + } + + return FileTypeUnknown, nil +} diff --git a/internal/data/impl/bot_chat_history.go b/internal/data/impl/bot_chat_history.go new file mode 100644 index 0000000..3da5184 --- /dev/null +++ b/internal/data/impl/bot_chat_history.go @@ -0,0 +1,15 @@ +package impl + +import ( + "ai_scheduler/internal/data/model" + "ai_scheduler/tmpl/dataTemp" + "ai_scheduler/utils" +) + +type BotChatHisImpl struct { + dataTemp.DataTemp +} + +func NewBotChatHisImpl(db *utils.Db) *TaskImpl { + return &TaskImpl{*dataTemp.NewDataTemp(db, new(model.AiBotChatHi))} +} diff --git a/internal/data/model/ai_bot_chat_his.gen.go b/internal/data/model/ai_bot_chat_his.gen.go new file mode 100644 index 0000000..1e4bfff --- /dev/null +++ b/internal/data/model/ai_bot_chat_his.gen.go @@ -0,0 +1,27 @@ +// Code generated by gorm.io/gen. DO NOT EDIT. +// Code generated by gorm.io/gen. DO NOT EDIT. +// Code generated by gorm.io/gen. DO NOT EDIT. + +package model + +import ( + "time" +) + +const TableNameAiBotChatHi = "ai_bot_chat_his" + +// AiBotChatHi mapped from table +type AiBotChatHi struct { + HisID int64 `gorm:"column:his_id;primaryKey;autoIncrement:true" json:"his_id"` + SessionID string `gorm:"column:session_id;not null" json:"session_id"` + Role string `gorm:"column:role;not null;comment:system系统输出,assistant助手输出,user用户输入" json:"role"` // system系统输出,assistant助手输出,user用户输入 + Content string `gorm:"column:content;not null" json:"content"` + Files string `gorm:"column:files;not null" json:"files"` + CreateAt time.Time `gorm:"column:create_at;default:CURRENT_TIMESTAMP" json:"create_at"` + UpdatedAt time.Time `gorm:"column:updated_at;default:CURRENT_TIMESTAMP" json:"updated_at"` +} + +// TableName AiBotChatHi's table name +func (*AiBotChatHi) TableName() string { + return TableNameAiBotChatHi +} diff --git a/internal/entitys/bot.go b/internal/entitys/bot.go index 226fbc8..0a99113 100644 --- a/internal/entitys/bot.go +++ b/internal/entitys/bot.go @@ -1,7 +1,24 @@ package entitys -type BotType int +import ( + "ai_scheduler/internal/data/model" -const ( - BugAndQuesDingTalk BotType = iota + 1 + "github.com/ollama/ollama/api" + "github.com/open-dingtalk/dingtalk-stream-sdk-go/chatbot" ) + +type RequireDataDingTalkBot struct { + Session string + Key string + Sys model.AiSy + Histories []model.AiChatHi + SessionInfo model.AiSession + Tasks []model.AiTask + Match *Match + Req *chatbot.BotCallbackDataModel + Auth string + Ch chan Response + KnowledgeConf KnowledgeBaseRequest + ImgByte []api.ImageData + ImgUrls []string +} diff --git a/internal/entitys/ollama.go b/internal/entitys/ollama.go new file mode 100644 index 0000000..7297dab --- /dev/null +++ b/internal/entitys/ollama.go @@ -0,0 +1,10 @@ +package entitys + +import ( + "github.com/ollama/ollama/api" +) + +type ToolSelect struct { + Prompt []api.Message + Tools []RegistrationTask +} diff --git a/internal/entitys/recognize.go b/internal/entitys/recognize.go new file mode 100644 index 0000000..ae32b89 --- /dev/null +++ b/internal/entitys/recognize.go @@ -0,0 +1,31 @@ +package entitys + +import ( + "ai_scheduler/internal/data/constants" +) + +type Recognize struct { + SystemPrompt string + UserContent *RecognizeUserContent + ChatHis ChatHis + Tasks []RegistrationTask + Ch chan Response +} + +type RegistrationTask struct { +} + +type RecognizeUserContent struct { + Text string + File *RecognizeFile + ActionCardUrl string + Tag string +} + +type FileData []byte + +type RecognizeFile struct { + File []FileData // 文件数据(二进制格式) + FileUrl string // 文件下载链接 + FileType constants.Caller // 文件类型(文件类型,能填最好填,可以跳过一层判断) +} diff --git a/internal/entitys/response.go b/internal/entitys/response.go index cdadc98..d99127f 100644 --- a/internal/entitys/response.go +++ b/internal/entitys/response.go @@ -2,22 +2,25 @@ package entitys import ( "encoding/json" + "github.com/gofiber/websocket/v2" ) type ResponseType string const ( - ResponseJson ResponseType = "json" - ResponseLoading ResponseType = "loading" - ResponseEnd ResponseType = "end" - ResponseStream ResponseType = "stream" - ResponseText ResponseType = "txt" - ResponseImg ResponseType = "img" - ResponseFile ResponseType = "file" - ResponseErr ResponseType = "error" - ResponseLog ResponseType = "log" - ResponseAuth ResponseType = "auth" + ResponseJson ResponseType = "json" + ResponseLoading ResponseType = "loading" + ResponseEnd ResponseType = "end" + ResponseStream ResponseType = "stream" + ResponseText ResponseType = "txt" + ResponseImg ResponseType = "img" + ResponseFile ResponseType = "file" + ResponseErr ResponseType = "error" + ResponseLog ResponseType = "log" + ResponseAuth ResponseType = "auth" + ResponseMarkdown ResponseType = "markdown" + ResponseActionCard ResponseType = "actionCard" ) func ResLog(ch chan Response, index string, content string) { @@ -45,6 +48,9 @@ func ResJson(ch chan Response, index string, content string) { } func ResEnd(ch chan Response, index string, content string) { + if ch == nil { + return + } ch <- Response{ Index: index, Content: content, @@ -53,6 +59,9 @@ func ResEnd(ch chan Response, index string, content string) { } func ResText(ch chan Response, index string, content string) { + if ch == nil { + return + } ch <- Response{ Index: index, Content: content, @@ -61,6 +70,9 @@ func ResText(ch chan Response, index string, content string) { } func ResLoading(ch chan Response, index string, content string) { + if ch == nil { + return + } ch <- Response{ Index: index, Content: content, @@ -68,6 +80,9 @@ func ResLoading(ch chan Response, index string, content string) { } } func ResError(ch chan Response, index string, content string) { + if ch == nil { + return + } ch <- Response{ Index: index, Content: content, diff --git a/internal/entitys/types.go b/internal/entitys/types.go index ece5ceb..cdc7986 100644 --- a/internal/entitys/types.go +++ b/internal/entitys/types.go @@ -3,7 +3,6 @@ package entitys import ( "ai_scheduler/internal/data/constants" "ai_scheduler/internal/data/model" - "context" "encoding/json" @@ -150,7 +149,7 @@ type RequireData struct { Histories []model.AiChatHi SessionInfo model.AiSession Tasks []model.AiTask - Match *Match + Match Match Req *ChatSockRequest Auth string Ch chan Response diff --git a/internal/server/ding_talk_bot.go b/internal/server/ding_talk_bot.go new file mode 100644 index 0000000..cda0e1c --- /dev/null +++ b/internal/server/ding_talk_bot.go @@ -0,0 +1,65 @@ +package server + +import ( + "ai_scheduler/internal/config" + "ai_scheduler/internal/services" + "context" + + "github.com/go-kratos/kratos/v2/log" + "github.com/open-dingtalk/dingtalk-stream-sdk-go/chatbot" + "github.com/open-dingtalk/dingtalk-stream-sdk-go/client" +) + +type DingBotServiceInterface interface { + GetServiceCfg(cfg map[string]*config.DingTalkBot) (*config.DingTalkBot, string) + OnChatBotMessageReceived(ctx context.Context, data *chatbot.BotCallbackDataModel) (content []byte, err error) +} + +type DingTalkBotServer struct { + Clients []*client.StreamClient +} + +func NewDingTalkBotServer( + cfg *config.Config, + services []DingBotServiceInterface, +) *DingTalkBotServer { + clients := make([]*client.StreamClient, 0) + for _, service := range services { + serviceConf, index := service.GetServiceCfg(cfg.DingTalkBots) + if serviceConf == nil { + log.Info("未找到%s配置", index) + continue + } + cli := DingBotServerInit(serviceConf.ClientId, serviceConf.ClientSecret, service) + if cli == nil { + log.Info("%s客户端初始失败", index) + continue + } + clients = append(clients, cli) + } + return &DingTalkBotServer{ + Clients: clients, + } +} + +func ProvideAllDingBotServices( + dingBotSvc *services.DingBotService, +) []DingBotServiceInterface { + return []DingBotServiceInterface{dingBotSvc} +} + +func (d *DingTalkBotServer) Run(ctx context.Context) { + for name, cli := range d.Clients { + err := cli.Start(ctx) + if err != nil { + log.Info("%s启动失败", name) + continue + } + log.Info("%s启动成功", name) + } +} +func DingBotServerInit(clientId string, clientSecret string, service DingBotServiceInterface) (cli *client.StreamClient) { + cli = client.NewStreamClient(client.WithAppCredential(client.NewAppCredentialConfig(clientId, clientSecret))) + cli.RegisterChatBotCallbackRouter(service.OnChatBotMessageReceived) + return +} diff --git a/internal/server/provider_set.go b/internal/server/provider_set.go index dd6b4b4..d5cef3d 100644 --- a/internal/server/provider_set.go +++ b/internal/server/provider_set.go @@ -1,5 +1,12 @@ package server -import "github.com/google/wire" +import ( + "github.com/google/wire" +) -var ProviderSetServer = wire.NewSet(NewServers, NewHTTPServer) +var ProviderSetServer = wire.NewSet( + NewServers, + NewHTTPServer, + ProvideAllDingBotServices, + NewDingTalkBotServer, +) diff --git a/internal/server/server.go b/internal/server/server.go index 488ef38..02c8f84 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -1,13 +1,27 @@ package server -import "github.com/gofiber/fiber/v2" +import ( + "ai_scheduler/internal/config" + + "github.com/gofiber/fiber/v2" +) type Servers struct { - HttpServer *fiber.App + cfg *config.Config + HttpServer *fiber.App + DingBotServer *DingTalkBotServer } -func NewServers(fiber *fiber.App) *Servers { +func NewServers(cfg *config.Config, fiber *fiber.App, DingBotServer *DingTalkBotServer) *Servers { return &Servers{ - HttpServer: fiber, + HttpServer: fiber, + cfg: cfg, + DingBotServer: DingBotServer, } } + +//func DingBotServerInit(clientId string, clientSecret string, cfg *config.Config, handler *do.Handle, do *do.Do) (cli *client.StreamClient) { +// cli = client.NewStreamClient(client.WithAppCredential(client.NewAppCredentialConfig(clientId, clientSecret))) +// cli.RegisterChatBotCallbackRouter(services.NewDingBotService(cfg, handler, do).OnChatBotMessageReceived) +// return +//} diff --git a/internal/services/callback.go b/internal/services/callback.go index 415f7cb..b0697c0 100644 --- a/internal/services/callback.go +++ b/internal/services/callback.go @@ -212,7 +212,7 @@ func (s *CallbackService) sendStreamLog(sessionID string, content string) { } streamLog := entitys.Response{ - Index: constants.BotToolsBugOptimizationSubmit, + Index: string(constants.BotToolsBugOptimizationSubmit), Content: content, Type: entitys.ResponseLog, } @@ -227,7 +227,7 @@ func (s *CallbackService) sendStreamTxt(sessionID string, content string) { } streamLog := entitys.Response{ - Index: constants.BotToolsBugOptimizationSubmit, + Index: string(constants.BotToolsBugOptimizationSubmit), Content: content, Type: entitys.ResponseText, } @@ -242,7 +242,7 @@ func (s *CallbackService) sendStreamLoading(sessionID string, content string) { } streamLog := entitys.Response{ - Index: constants.BotToolsBugOptimizationSubmit, + Index: string(constants.BotToolsBugOptimizationSubmit), Content: content, Type: entitys.ResponseLoading, } diff --git a/internal/services/dtalk_bot.go b/internal/services/dtalk_bot.go new file mode 100644 index 0000000..2855471 --- /dev/null +++ b/internal/services/dtalk_bot.go @@ -0,0 +1,135 @@ +package services + +import ( + "ai_scheduler/internal/biz" + "ai_scheduler/internal/config" + "ai_scheduler/internal/entitys" + "context" + "fmt" + + "github.com/open-dingtalk/dingtalk-stream-sdk-go/chatbot" +) + +type DingBotService struct { + config *config.Config + replier *chatbot.ChatbotReplier + env string + DingTalkBotBiz *biz.DingTalkBotBiz +} + +func NewDingBotService(config *config.Config, DingTalkBotBiz *biz.DingTalkBotBiz) *DingBotService { + return &DingBotService{config: config, replier: chatbot.NewChatbotReplier(), env: "public", DingTalkBotBiz: DingTalkBotBiz} +} + +func (d *DingBotService) GetServiceCfg(cfg map[string]*config.DingTalkBot) (*config.DingTalkBot, string) { + return cfg[d.env], d.env +} + +func (d *DingBotService) OnChatBotMessageReceived(ctx context.Context, data *chatbot.BotCallbackDataModel) (content []byte, err error) { + requireData, err := d.DingTalkBotBiz.InitRequire(ctx, data) + if err != nil { + return + } + go func() { + defer close(requireData.Ch) + + if recognizeErr := d.DingTalkBotBiz.Recognize(ctx, requireData); recognizeErr != nil { + requireData.Ch <- entitys.Response{ + Type: entitys.ResponseEnd, + Content: fmt.Sprintf("处理消息时出错: %v", recognizeErr), + } + } + //向下传递 + if err = d.handle.HandleMatch(ctx, nil, requireData); err != nil { + requireData.Ch <- entitys.Response{ + Type: entitys.ResponseEnd, + Content: fmt.Sprintf("匹配失败: %v", err), + } + } + }() + for { + select { + case <-ctx.Done(): + return nil, ctx.Err() + case resp, ok := <-requireData.Ch: + if !ok { + return []byte("success"), nil // 通道关闭,处理完成 + } + + if resp.Type == entitys.ResponseLog { + return + } + if err := d.handleRes(ctx, data, resp); err != nil { + return nil, fmt.Errorf("回复失败: %w", err) + } + } + } + return +} + +func (d *DingBotService) handleRes(ctx context.Context, data *chatbot.BotCallbackDataModel, resp entitys.Response) error { + switch resp.Type { + case entitys.ResponseText: + return d.replyText(ctx, data.SessionWebhook, resp.Content) + case entitys.ResponseStream: + return d.replySteam(ctx, data.SessionWebhook, resp.Content) + case entitys.ResponseImg: + return d.replyImg(ctx, data.SessionWebhook, resp.Content) + case entitys.ResponseFile: + return d.replyFile(ctx, data.SessionWebhook, resp.Content) + case entitys.ResponseMarkdown: + return d.replyMarkdown(ctx, data.SessionWebhook, resp.Content) + case entitys.ResponseActionCard: + return d.replyActionCard(ctx, data.SessionWebhook, resp.Content) + default: + return nil + } +} + +func (d *DingBotService) replyText(ctx context.Context, SessionWebhook string, content string, arg ...string) error { + msg := content + if len(arg) > 0 { + msg = fmt.Sprintf(content, arg) + } + return d.replier.SimpleReplyText(ctx, SessionWebhook, []byte(msg)) +} + +func (d *DingBotService) replySteam(ctx context.Context, SessionWebhook string, content string, arg ...string) error { + msg := content + if len(arg) > 0 { + msg = fmt.Sprintf(content, arg) + } + return d.replier.SimpleReplyText(ctx, SessionWebhook, []byte(msg)) +} + +func (d *DingBotService) replyImg(ctx context.Context, SessionWebhook string, content string, arg ...string) error { + msg := content + if len(arg) > 0 { + msg = fmt.Sprintf(content, arg) + } + return d.replier.SimpleReplyText(ctx, SessionWebhook, []byte(msg)) +} + +func (d *DingBotService) replyFile(ctx context.Context, SessionWebhook string, content string, arg ...string) error { + msg := content + if len(arg) > 0 { + msg = fmt.Sprintf(content, arg) + } + return d.replier.SimpleReplyText(ctx, SessionWebhook, []byte(msg)) +} + +func (d *DingBotService) replyMarkdown(ctx context.Context, SessionWebhook string, content string, arg ...string) error { + msg := content + if len(arg) > 0 { + msg = fmt.Sprintf(content, arg) + } + return d.replier.SimpleReplyText(ctx, SessionWebhook, []byte(msg)) +} + +func (d *DingBotService) replyActionCard(ctx context.Context, SessionWebhook string, content string, arg ...string) error { + msg := content + if len(arg) > 0 { + msg = fmt.Sprintf(content, arg) + } + return d.replier.SimpleReplyText(ctx, SessionWebhook, []byte(msg)) +} diff --git a/internal/services/provider_set.go b/internal/services/provider_set.go index 0e1284b..9610ed6 100644 --- a/internal/services/provider_set.go +++ b/internal/services/provider_set.go @@ -6,4 +6,11 @@ import ( "github.com/google/wire" ) -var ProviderSetServices = wire.NewSet(NewChatService, NewSessionService, gateway.NewGateway, NewTaskService, NewCallbackService) +var ProviderSetServices = wire.NewSet( + NewChatService, + NewSessionService, + gateway.NewGateway, + NewTaskService, + NewCallbackService, + NewDingBotService, +) diff --git a/internal/tools_bot/dtalk_bot.go b/internal/tools_bot/dtalk_bot.go index 0e9525a..2eae1d5 100644 --- a/internal/tools_bot/dtalk_bot.go +++ b/internal/tools_bot/dtalk_bot.go @@ -2,22 +2,18 @@ package tools_bot import ( "ai_scheduler/internal/config" - "ai_scheduler/internal/data/constants" - errors "ai_scheduler/internal/data/error" "ai_scheduler/internal/data/impl" "ai_scheduler/internal/entitys" "ai_scheduler/internal/pkg/utils_ollama" - "context" - "github.com/gofiber/fiber/v2/log" + "context" ) type BotTool struct { config *config.Config llm *utils_ollama.Client sessionImpl *impl.SessionImpl - taskMap map[string]string // task_id -> session_id - // zltxOrderAfterSaleTool tools.ZltxOrderAfterSaleTool + taskMap map[string]string } // NewBotTool 创建直连天下订单详情工具 @@ -27,12 +23,5 @@ func NewBotTool(config *config.Config, llm *utils_ollama.Client, sessionImpl *im // Execute 执行直连天下订单详情查询 func (w *BotTool) Execute(ctx context.Context, toolName string, requireData *entitys.RequireData) (err error) { - switch toolName { - case constants.BotToolsBugOptimizationSubmit: - err = w.BugOptimizationSubmit(ctx, requireData) - default: - log.Errorf("未知的工具类型:%s", toolName) - err = errors.ParamErr("未知的工具类型:%s", toolName) - } return } From f4980f1f22277baa18e2307bf4e94bd22b1d8d2b Mon Sep 17 00:00:00 2001 From: renzhiyuan <465386466@qq.com> Date: Tue, 9 Dec 2025 16:45:38 +0800 Subject: [PATCH 18/66] feat: add DingTalk bot integration --- internal/biz/handle/file.go | 70 +++++++++++++++++++++++++++++++++ internal/data/constants/file.go | 37 ----------------- 2 files changed, 70 insertions(+), 37 deletions(-) diff --git a/internal/biz/handle/file.go b/internal/biz/handle/file.go index 3c9d79b..9fd4fae 100644 --- a/internal/biz/handle/file.go +++ b/internal/biz/handle/file.go @@ -1,7 +1,17 @@ package handle import ( + "ai_scheduler/internal/data/constants" "ai_scheduler/internal/entitys" + "ai_scheduler/internal/pkg/l_request" + "errors" + "fmt" + "io" + "net/http" + "path/filepath" + "strings" + + "github.com/gabriel-vasile/mimetype" ) // HandleRecognizeFile 这里的目的是无论将什么类型的file都转为二进制格式 @@ -12,3 +22,63 @@ func HandleRecognizeFile(file *entitys.RecognizeFile) { //Todo 仲云 return } + +// 下载文件并返回二进制数据、MIME 类型 +func downloadFile(fileUrl string) (fileBytes []byte, contentType string, err error) { + if len(fileUrl) == 0 { + return + } + req := l_request.Request{ + Method: "GET", + Url: fileUrl, + Headers: map[string]string{ + "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36", + "Accept": "image/webp,image/apng,image/*,*/*;q=0.8", + }, + } + res, err := req.Send() + if err != nil { + return + } + var ex bool + if contentType, ex = res.Headers["Content-Type"]; !ex { + err = errors.New("Content-Type不存在") + return + } + + if res.StatusCode != http.StatusOK { + err = fmt.Errorf("server returned non-200 status: %d", res.StatusCode) + } + fileBytes = res.Content + + return fileBytes, contentType, nil +} + +// detectFileType 判断文件类型 +func detectFileType(file io.ReadSeeker, filename string) constants.FileType { + // 1. 读取文件头检测 MIME + buffer := make([]byte, 512) + n, _ := file.Read(buffer) + file.Seek(0, io.SeekStart) // 重置读取位置 + + detectedMIME := mimetype.Detect(buffer[:n]).String() + for fileType, items := range constants.FileTypeMappings { + for _, item := range items { + if !strings.HasPrefix(item, ".") && item == detectedMIME { + return fileType + } + } + } + + // 2. 备用:通过扩展名检测 + ext := strings.ToLower(filepath.Ext(filename)) + for fileType, items := range constants.FileTypeMappings { + for _, item := range items { + if strings.HasPrefix(item, ".") && item == ext { + return fileType + } + } + } + + return constants.FileTypeUnknown +} diff --git a/internal/data/constants/file.go b/internal/data/constants/file.go index a418959..132ad0e 100644 --- a/internal/data/constants/file.go +++ b/internal/data/constants/file.go @@ -1,11 +1,5 @@ package constants -import ( - "github.com/gabriel-vasile/mimetype" - "io" - "strings" -) - type FileType string const ( @@ -42,34 +36,3 @@ var FileTypeMappings = map[FileType][]string{ ".txt", }, } - -func DetectFileType(file io.ReadSeeker) (FileType, error) { - // 读取文件头(512字节足够检测大多数类型) - buffer := make([]byte, 512) - n, err := file.Read(buffer) - if err != nil && err != io.EOF { - return FileTypeUnknown, err - } - - // 重置读取位置 - if _, err := file.Seek(0, io.SeekStart); err != nil { - return FileTypeUnknown, err - } - - // 检测 MIME 类型 - detectedMIME := mimetype.Detect(buffer[:n]).String() - - // 遍历映射表,匹配 MIME 或扩展名 - for fileType, mimesOrExts := range FileTypeMappings { - for _, item := range mimesOrExts { - if strings.HasPrefix(item, ".") { - continue // 跳过扩展名(仅 MIME 检测) - } - if item == detectedMIME { - return fileType, nil - } - } - } - - return FileTypeUnknown, nil -} From 586ad6a124417ade86ba02608b41f1a4b49a8c31 Mon Sep 17 00:00:00 2001 From: renzhiyuan <465386466@qq.com> Date: Tue, 9 Dec 2025 17:24:41 +0800 Subject: [PATCH 19/66] refactor: optimize ollama service and update dependencies --- go.mod | 1 - go.sum | 2 - internal/biz/llm_service/ollama.go | 177 ++++++++--------------------- internal/entitys/recognize.go | 3 + 4 files changed, 53 insertions(+), 130 deletions(-) diff --git a/go.mod b/go.mod index 48d0741..073c4eb 100644 --- a/go.mod +++ b/go.mod @@ -58,7 +58,6 @@ require ( github.com/evanphx/json-patch v0.5.2 // indirect github.com/fasthttp/websocket v1.5.3 // indirect github.com/fsnotify/fsnotify v1.6.0 // indirect - github.com/gabriel-vasile/mimetype v1.4.3 // indirect github.com/go-sql-driver/mysql v1.8.1 // indirect github.com/goph/emperror v0.17.2 // indirect github.com/gorilla/websocket v1.5.0 // indirect diff --git a/go.sum b/go.sum index bbb5c7a..c1b06f6 100644 --- a/go.sum +++ b/go.sum @@ -174,8 +174,6 @@ github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMo github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ= github.com/fsnotify/fsnotify v1.6.0 h1:n+5WquG0fcWoWp6xPWfHdbskMCQaFnG6PfBrh1Ky4HY= github.com/fsnotify/fsnotify v1.6.0/go.mod h1:sl3t1tCWJFWoRz9R8WJCbQihKKwmorjAbSClcnxKAGw= -github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0= -github.com/gabriel-vasile/mimetype v1.4.3/go.mod h1:d8uq/6HKRL6CGdk+aubisF/M5GcPfT7nKyLpA0lbSSk= github.com/gabriel-vasile/mimetype v1.4.11 h1:AQvxbp830wPhHTqc1u7nzoLT+ZFxGY7emj5DR5DYFik= github.com/gabriel-vasile/mimetype v1.4.11/go.mod h1:d+9Oxyo1wTzWdyVUPMmXFvp4F9tea18J8ufA774AB3s= github.com/garyburd/redigo v1.6.0/go.mod h1:NR3MbYisc3/PwhQ00EMzDiPmrwpPxAn5GI05/YaO1SY= diff --git a/internal/biz/llm_service/ollama.go b/internal/biz/llm_service/ollama.go index 8c127cd..7333de5 100644 --- a/internal/biz/llm_service/ollama.go +++ b/internal/biz/llm_service/ollama.go @@ -3,19 +3,15 @@ package llm_service import ( "ai_scheduler/internal/config" "ai_scheduler/internal/data/impl" - "ai_scheduler/internal/data/model" "ai_scheduler/internal/entitys" "ai_scheduler/internal/pkg" "ai_scheduler/internal/pkg/utils_ollama" "ai_scheduler/internal/pkg/utils_vllm" "context" - "encoding/json" + "errors" - "strings" - "time" "github.com/ollama/ollama/api" - "xorm.io/builder" ) type OllamaService struct { @@ -41,12 +37,8 @@ func NewOllamaGenerate( func (r *OllamaService) IntentRecognize(ctx context.Context, req *entitys.ToolSelect) (msg string, err error) { - if err != nil { - return - } toolDefinitions := r.registerToolsOllama(req.Tools) - - match, err := r.client.ToolSelect(ctx, rec.Prompt, toolDefinitions) + match, err := r.client.ToolSelect(ctx, req.Prompt, toolDefinitions) if err != nil { return } @@ -70,132 +62,63 @@ func (r *OllamaService) IntentRecognize(ctx context.Context, req *entitys.ToolSe return } -func (r *OllamaService) getPrompt(ctx context.Context, requireData *entitys.RequireData) ([]api.Message, error) { +//func (r *OllamaService) RecognizeWithImg(ctx context.Context, imgByte []api.ImageData, ch chan entitys.Response) (desc api.GenerateResponse, err error) { +// if imgByte == nil { +// return +// } +// entitys.ResLog(requireData.Ch, "recognize_img_start", "图片识别中...") +// +// desc, err = r.client.Generation(ctx, &api.GenerateRequest{ +// Model: r.config.Ollama.VlModel, +// Stream: new(bool), +// System: r.config.DefaultPrompt.ImgRecognize.SystemPrompt, +// Prompt: r.config.DefaultPrompt.ImgRecognize.UserPrompt, +// Images: requireData.ImgByte, +// KeepAlive: &api.Duration{Duration: 3600 * time.Second}, +// //Think: &api.ThinkValue{Value: false}, +// }) +// if err != nil { +// return +// } +// entitys.ResLog(requireData.Ch, "recognize_img_end", "图片识别完成,识别内容:"+desc.Response) +// return +//} - var ( - prompt = make([]api.Message, 0) - ) - content, err := r.getUserContent(ctx, requireData) - if err != nil { - return nil, err - } - prompt = append(prompt, api.Message{ - Role: "system", - Content: buildSystemPrompt(requireData.Sys.SysPrompt), - }, api.Message{ - Role: "assistant", - Content: "### 聊天记录:" + pkg.JsonStringIgonErr(buildAssistant(requireData.Histories)), - }, api.Message{ - Role: "user", - Content: content, - }) +//func (r *OllamaService) RecognizeWithImgVllm(ctx context.Context, requireData *entitys.RequireData) (desc api.GenerateResponse, err error) { +// if requireData.ImgByte == nil { +// return +// } +// entitys.ResLog(requireData.Ch, "recognize_img_start", "图片识别中...") +// +// outMsg, err := r.vllmClient.RecognizeWithImg(ctx, +// r.config.DefaultPrompt.ImgRecognize.SystemPrompt, +// r.config.DefaultPrompt.ImgRecognize.UserPrompt, +// requireData.ImgUrls, +// ) +// if err != nil { +// return api.GenerateResponse{}, err +// } +// +// desc = api.GenerateResponse{ +// Response: outMsg.Content, +// } +// +// entitys.ResLog(requireData.Ch, "recognize_img_end", "图片识别完成,识别内容:"+desc.Response) +// return +//} - return prompt, nil -} - -func (r *OllamaService) getUserContent(ctx context.Context, requireData *entitys.RequireData) (string, error) { - var content strings.Builder - content.WriteString(requireData.Req.Text) - if len(requireData.ImgByte) > 0 { - content.WriteString("\n") - } - - if len(requireData.Req.Tags) > 0 { - content.WriteString("\n") - content.WriteString("### 工具必须使用:") - content.WriteString(requireData.Req.Tags) - } - - if len(requireData.ImgByte) > 0 { - // desc, err := r.RecognizeWithImg(ctx, requireData) - desc, err := r.RecognizeWithImgVllm(ctx, requireData) - if err != nil { - return "", err - } - content.WriteString("### 上传图片解析内容:\n") - content.WriteString(requireData.Req.Tags) - content.WriteString(desc.Response) - } - - if requireData.Req.MarkHis > 0 { - var his model.AiChatHi - cond := builder.NewCond() - cond = cond.And(builder.Eq{"his_id": requireData.Req.MarkHis}) - err := r.chatHis.GetOneBySearchToStrut(&cond, &his) - if err != nil { - return "", err - } - content.WriteString("### 引用历史聊天记录:\n") - content.WriteString(pkg.JsonStringIgonErr(BuildChatHisMessage([]model.AiChatHi{his}))) - } - return content.String(), nil -} - -func (r *OllamaService) RecognizeWithImg(ctx context.Context, requireData *entitys.RequireData) (desc api.GenerateResponse, err error) { - if requireData.ImgByte == nil { -func (r *OllamaService) RecognizeWithImg(ctx context.Context, imgByte []api.ImageData, ch chan entitys.Response) (desc api.GenerateResponse, err error) { - if imgByte == nil { - return - } - entitys.ResLog(requireData.Ch, "recognize_img_start", "图片识别中...") - - desc, err = r.client.Generation(ctx, &api.GenerateRequest{ - Model: r.config.Ollama.VlModel, - Stream: new(bool), - System: r.config.DefaultPrompt.ImgRecognize.SystemPrompt, - Prompt: r.config.DefaultPrompt.ImgRecognize.UserPrompt, - Images: requireData.ImgByte, - KeepAlive: &api.Duration{Duration: 3600 * time.Second}, - //Think: &api.ThinkValue{Value: false}, - }) - if err != nil { - return - } - entitys.ResLog(requireData.Ch, "recognize_img_end", "图片识别完成,识别内容:"+desc.Response) - return -} - -func (r *OllamaService) RecognizeWithImgVllm(ctx context.Context, requireData *entitys.RequireData) (desc api.GenerateResponse, err error) { - if requireData.ImgByte == nil { - return - } - entitys.ResLog(requireData.Ch, "recognize_img_start", "图片识别中...") - - outMsg, err := r.vllmClient.RecognizeWithImg(ctx, - r.config.DefaultPrompt.ImgRecognize.SystemPrompt, - r.config.DefaultPrompt.ImgRecognize.UserPrompt, - requireData.ImgUrls, - ) - if err != nil { - return api.GenerateResponse{}, err - } - - desc = api.GenerateResponse{ - Response: outMsg.Content, - } - - entitys.ResLog(requireData.Ch, "recognize_img_end", "图片识别完成,识别内容:"+desc.Response) - return -} - -func (r *OllamaService) registerToolsOllama(tasks []model.AiTask) []api.Tool { +func (r *OllamaService) registerToolsOllama(tasks []entitys.RegistrationTask) []api.Tool { taskPrompt := make([]api.Tool, 0) for _, task := range tasks { - var taskConfig entitys.TaskConfigDetail - err := json.Unmarshal([]byte(task.Config), &taskConfig) - if err != nil { - continue - } - taskPrompt = append(taskPrompt, api.Tool{ Type: "function", Function: api.ToolFunction{ - Name: task.Index, + Name: task.Name, Description: task.Desc, Parameters: api.ToolFunctionParameters{ - Type: taskConfig.Param.Type, - Required: taskConfig.Param.Required, - Properties: taskConfig.Param.Properties, + Type: task.TaskConfigDetail.Param.Type, + Required: task.TaskConfigDetail.Param.Required, + Properties: task.TaskConfigDetail.Param.Properties, }, }, }) diff --git a/internal/entitys/recognize.go b/internal/entitys/recognize.go index ae32b89..6d692eb 100644 --- a/internal/entitys/recognize.go +++ b/internal/entitys/recognize.go @@ -13,6 +13,9 @@ type Recognize struct { } type RegistrationTask struct { + Name string + Desc string + TaskConfigDetail TaskConfigDetail } type RecognizeUserContent struct { From e099f821193e04957545a6651ce6c0cef929544f Mon Sep 17 00:00:00 2001 From: fuzhongyun <15339891972@163.com> Date: Tue, 9 Dec 2025 18:36:09 +0800 Subject: [PATCH 20/66] =?UTF-8?q?feat=EF=BC=9A=E5=B7=A5=E4=BD=9C=E6=B5=81?= =?UTF-8?q?=E3=80=81=E5=B7=A5=E5=85=B7=E6=9E=B6=E6=9E=84=E8=B0=83=E6=95=B4?= =?UTF-8?q?=20=E4=B8=8E=20demo=20=E7=BC=96=E5=86=99?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- go.mod | 4 +- go.sum | 4 ++ .../zltx/order_after_reseller_batch/client.go | 47 +++++++++++++++++++ .../order_after_reseller_batch/invokable.go | 24 ++++++++++ .../zltx/order_after_reseller_batch/types.go | 32 +++++++++++++ .../domain/workflow/zltx/crontab_supplier.go | 1 + .../zltx/order_after_reseller_batch.go | 47 +++++++++++++++++++ internal/pkg/util/ctx.go | 25 ++++++++++ 8 files changed, 183 insertions(+), 1 deletion(-) create mode 100644 internal/domain/tools/zltx/order_after_reseller_batch/client.go create mode 100644 internal/domain/tools/zltx/order_after_reseller_batch/invokable.go create mode 100644 internal/domain/tools/zltx/order_after_reseller_batch/types.go create mode 100644 internal/domain/workflow/zltx/crontab_supplier.go create mode 100644 internal/domain/workflow/zltx/order_after_reseller_batch.go create mode 100644 internal/pkg/util/ctx.go diff --git a/go.mod b/go.mod index 2978eaa..366d790 100644 --- a/go.mod +++ b/go.mod @@ -9,6 +9,7 @@ require ( github.com/alibabacloud-go/tea v1.2.2 github.com/alibabacloud-go/tea-utils/v2 v2.0.6 github.com/cloudwego/eino v0.7.7 + github.com/cloudwego/eino-ext/components/model/ollama v0.1.6 github.com/cloudwego/eino-ext/components/model/openai v0.1.5 github.com/emirpasic/gods v1.18.1 github.com/faabiosr/cachego v0.26.0 @@ -53,6 +54,7 @@ require ( github.com/dlclark/regexp2 v1.11.4 // indirect github.com/dustin/go-humanize v1.0.1 // indirect github.com/eino-contrib/jsonschema v1.0.3 // indirect + github.com/eino-contrib/ollama v0.1.0 // indirect github.com/evanphx/json-patch v0.5.2 // indirect github.com/fasthttp/websocket v1.5.3 // indirect github.com/fsnotify/fsnotify v1.6.0 // indirect @@ -101,7 +103,7 @@ require ( go.uber.org/atomic v1.9.0 // indirect go.uber.org/multierr v1.9.0 // indirect golang.org/x/arch v0.11.0 // indirect - golang.org/x/crypto v0.36.0 // indirect + golang.org/x/crypto v0.39.0 // indirect golang.org/x/exp v0.0.0-20250218142911-aa4b98e5adaa // indirect golang.org/x/net v0.38.0 // indirect golang.org/x/sys v0.33.0 // indirect diff --git a/go.sum b/go.sum index 4b4e61f..fdce0ca 100644 --- a/go.sum +++ b/go.sum @@ -132,6 +132,8 @@ github.com/cloudwego/base64x v0.1.6 h1:t11wG9AECkCDk5fMSoxmufanudBtJ+/HemLstXDLI github.com/cloudwego/base64x v0.1.6/go.mod h1:OFcloc187FXDaYHvrNIjxSe8ncn0OOM8gEHfghB2IPU= github.com/cloudwego/eino v0.7.7 h1:WhP0SMWWPgLdOH03HrKUxtP9/Q96NhziMZNEQl9lxpU= github.com/cloudwego/eino v0.7.7/go.mod h1:nA8Vacmuqv3pqKBQbTWENBLQ8MmGmPt/WqiyLeB8ohQ= +github.com/cloudwego/eino-ext/components/model/ollama v0.1.6 h1:ZbrhV91uE0hGIOYXhb2i3G6tQJ/rK2SLYtoYrmocZXM= +github.com/cloudwego/eino-ext/components/model/ollama v0.1.6/go.mod h1:GDXrvorGdRNV6g2mK5jdla2D8Xc/hh7XDrTeGDteLLo= github.com/cloudwego/eino-ext/components/model/openai v0.1.5 h1:+yvGbTPw93li9GSmdm6Rix88Yy8AXg5NNBcRbWx3CQU= github.com/cloudwego/eino-ext/components/model/openai v0.1.5/go.mod h1:IPVYMFoZcuHeVEsDTGN6SZjvue0xr1iZFhdpq1SBWdQ= github.com/cloudwego/eino-ext/libs/acl/openai v0.1.2 h1:r9Id2wzJ05PoHl+Km7jQgNMgciaZI93TVnUYso89esM= @@ -151,6 +153,8 @@ github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkp github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= github.com/eino-contrib/jsonschema v1.0.3 h1:2Kfsm1xlMV0ssY2nuxshS4AwbLFuqmPmzIjLVJ1Fsp0= github.com/eino-contrib/jsonschema v1.0.3/go.mod h1:cpnX4SyKjWjGC7iN2EbhxaTdLqGjCi0e9DxpLYxddD4= +github.com/eino-contrib/ollama v0.1.0 h1:z1NaMdKW6X1ftP8g5xGGR5zDRPUtuTKFq35vBQgxsN4= +github.com/eino-contrib/ollama v0.1.0/go.mod h1:mYsQ7b3DeqY8bHPuD3MZJYTqkgyL6LoemxoP/B7ZNhA= github.com/emirpasic/gods v1.18.1 h1:FXtiHYKDGKCW2KzwZKx0iC0PQmdlorYgdFG9jPXJ1Bc= github.com/emirpasic/gods v1.18.1/go.mod h1:8tpGGwCnJ5H4r6BWwaV6OrWmMoPhUl5jm/FMNAnJvWQ= github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= diff --git a/internal/domain/tools/zltx/order_after_reseller_batch/client.go b/internal/domain/tools/zltx/order_after_reseller_batch/client.go new file mode 100644 index 0000000..90da543 --- /dev/null +++ b/internal/domain/tools/zltx/order_after_reseller_batch/client.go @@ -0,0 +1,47 @@ +package order_after_reseller_batch + +import ( + "ai_scheduler/internal/config" + "ai_scheduler/internal/pkg/l_request" + "ai_scheduler/internal/pkg/util" + "context" + "encoding/json" + "errors" + "fmt" +) + +func Call(ctx context.Context, cfg config.ToolConfig, orderNumbers []string) (OrderAfterSaleResellerBatchData, error) { + if len(orderNumbers) == 0 { + return OrderAfterSaleResellerBatchData{}, errors.New("批充订单号不能为空") + } + token := util.GetTokenFromContext(ctx) + if token == "" { + return OrderAfterSaleResellerBatchData{}, errors.New("token 未注入") + } + r := l_request.Request{ + Url: cfg.BaseURL, + Headers: map[string]string{ + "Authorization": fmt.Sprintf("Bearer %s", token), + }, + Method: "POST", + Json: map[string]any{ + "order_numbers": orderNumbers, + "order_type": 2, + }, + } + res, err := r.Send() + if err != nil { + return OrderAfterSaleResellerBatchData{}, err + } + var response OrderAfterSaleResellerBatchResponse + if err = json.Unmarshal(res.Content, &response); err != nil { + return OrderAfterSaleResellerBatchData{}, err + } + if response.Code != 200 { + return OrderAfterSaleResellerBatchData{}, fmt.Errorf("售后订单查询异常: %s", response.Error) + } + if len(response.Data.Data) == 0 { + return OrderAfterSaleResellerBatchData{}, errors.New("未查询到相应售后订单,请核实订单号是否正确") + } + return response.Data, nil +} diff --git a/internal/domain/tools/zltx/order_after_reseller_batch/invokable.go b/internal/domain/tools/zltx/order_after_reseller_batch/invokable.go new file mode 100644 index 0000000..34da7e1 --- /dev/null +++ b/internal/domain/tools/zltx/order_after_reseller_batch/invokable.go @@ -0,0 +1,24 @@ +package order_after_reseller_batch + +import ( + "ai_scheduler/internal/config" + "context" + + "github.com/cloudwego/eino/components/tool" + toolutils "github.com/cloudwego/eino/components/tool/utils" +) + +type Args struct { + OrderNumber []string `json:"orderNumber"` +} + +func NewInvokable(cfg config.ToolConfig) tool.InvokableTool { + run := func(ctx context.Context, in Args) (OrderAfterSaleResellerBatchData, error) { + return Call(ctx, cfg, in.OrderNumber) + } + t, err := toolutils.InferTool("zltxOrderAfterSaleResellerBatch", "直连天下下游分销商批充订单售后工具", run) + if err != nil { + panic(err) + } + return t +} diff --git a/internal/domain/tools/zltx/order_after_reseller_batch/types.go b/internal/domain/tools/zltx/order_after_reseller_batch/types.go new file mode 100644 index 0000000..9d2a071 --- /dev/null +++ b/internal/domain/tools/zltx/order_after_reseller_batch/types.go @@ -0,0 +1,32 @@ +package order_after_reseller_batch + +type OrderAfterSaleResellerBatchResponse struct { + Code int `json:"code"` + Error string `json:"error"` + Data OrderAfterSaleResellerBatchData `json:"data"` +} + +type OrderAfterSaleResellerBatchData struct { + Data []OrderAfterSaleResellerBatchBase `json:"data"` + ExtData map[string]OrderAfterSaleResellerBatchExtItem `json:"extraData"` +} + +type OrderAfterSaleResellerBatchBase struct { + OrderType int `json:"orderType"` + OrderNumber string `json:"orderNumber"` + OrderAmount float64 `json:"orderAmount"` + OrderPrice float64 `json:"orderPrice"` + SignCompany int `json:"signCompany"` + OrderQuantity int `json:"orderQuantity"` + ResellerID int `json:"resellerId"` + ResellerName string `json:"resellerName"` + OurProductID int `json:"ourProductId"` + OurProductTitle string `json:"ourProductTitle"` + Account []string `json:"account"` + Platforms map[int]string `json:"platforms"` +} + +type OrderAfterSaleResellerBatchExtItem struct { + IsExistsAfterSale bool `json:"isExistsAfterSale"` + SerialCreateTime int `json:"createTime"` +} diff --git a/internal/domain/workflow/zltx/crontab_supplier.go b/internal/domain/workflow/zltx/crontab_supplier.go new file mode 100644 index 0000000..29ee896 --- /dev/null +++ b/internal/domain/workflow/zltx/crontab_supplier.go @@ -0,0 +1 @@ +package zltx diff --git a/internal/domain/workflow/zltx/order_after_reseller_batch.go b/internal/domain/workflow/zltx/order_after_reseller_batch.go new file mode 100644 index 0000000..af41488 --- /dev/null +++ b/internal/domain/workflow/zltx/order_after_reseller_batch.go @@ -0,0 +1,47 @@ +package zltx + +import ( + "ai_scheduler/internal/config" + toolZoarb "ai_scheduler/internal/domain/tools/zltx/order_after_reseller_batch" + "context" + "errors" + + "github.com/cloudwego/eino/compose" +) + +type OrderAfterSaleResellerBatchWorkflowInput struct { + Token string `json:"token"` + OrderNumber []string `json:"orderNumber"` +} + +var ErrInvalidOrderNumbers = errors.New("orderNumber 不能为空") +var ErrMissingToken = errors.New("token 不能为空") + +func BuildOrderAfterResellerBatchWorkflow(ctx context.Context, cfg config.ToolConfig) (compose.Runnable[OrderAfterSaleResellerBatchWorkflowInput, toolZoarb.OrderAfterSaleResellerBatchData], error) { + // 定义工作流、出入参 + c := compose.NewChain[OrderAfterSaleResellerBatchWorkflowInput, toolZoarb.OrderAfterSaleResellerBatchData]() + + // 1.入参解析与校验 + c.AppendLambda(compose.InvokableLambda(func(_ context.Context, in OrderAfterSaleResellerBatchWorkflowInput) (OrderAfterSaleResellerBatchWorkflowInput, error) { + if len(in.OrderNumber) == 0 { + return OrderAfterSaleResellerBatchWorkflowInput{}, ErrInvalidOrderNumbers + } + if in.Token == "" { + return OrderAfterSaleResellerBatchWorkflowInput{}, ErrMissingToken + } + return in, nil + })) + + // 2.调用工具 + c.AppendLambda(compose.InvokableLambda(func(ctx context.Context, in OrderAfterSaleResellerBatchWorkflowInput) (toolZoarb.OrderAfterSaleResellerBatchData, error) { + return toolZoarb.Call(ctx, cfg, in.OrderNumber) + })) + + // 3.结果映射与整形 + c.AppendLambda(compose.InvokableLambda(func(_ context.Context, in toolZoarb.OrderAfterSaleResellerBatchData) (toolZoarb.OrderAfterSaleResellerBatchData, error) { + return in, nil + })) + + // 编译工作流 + return c.Compile(ctx) +} diff --git a/internal/pkg/util/ctx.go b/internal/pkg/util/ctx.go new file mode 100644 index 0000000..1998805 --- /dev/null +++ b/internal/pkg/util/ctx.go @@ -0,0 +1,25 @@ +package util + +import ( + "context" +) + +type ContextKey string + +const ( + ContextKeyToken ContextKey = "token" +) + +// token 写入上下文 +func SetTokenToContext(ctx context.Context, token string) context.Context { + return context.WithValue(ctx, ContextKeyToken, token) +} + +// 从上下文获取token +func GetTokenFromContext(ctx context.Context) string { + token, ok := ctx.Value(ContextKeyToken).(string) + if !ok { + return "" + } + return token +} From e6c142f3a125f5619de592486f8026c2a0f39bff Mon Sep 17 00:00:00 2001 From: renzhiyuan <465386466@qq.com> Date: Wed, 10 Dec 2025 11:29:54 +0800 Subject: [PATCH 21/66] =?UTF-8?q?fix:=E4=BF=AE=E5=A4=8D=E6=8A=A5=E9=94=99?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- config/config_test.yaml | 8 +- internal/biz/ding_talk_bot.go | 112 +++++++++++++++++- internal/biz/do/prompt.go | 63 +++++++++- internal/config/config.go | 45 ++++--- internal/data/constants/bot.go | 6 + internal/data/impl/base.go | 2 +- internal/data/impl/bot_chat_history.go | 4 +- internal/data/impl/bot_config.go | 17 +++ internal/data/impl/bot_impl.go | 28 ----- internal/data/impl/provider_set.go | 8 +- .../{ai_bot.gen.go => ai_bot_config.gen.go} | 14 +-- internal/entitys/bot.go | 5 + internal/server/ding_talk_bot.go | 31 +++-- internal/services/dtalk_bot.go | 112 ++++-------------- internal/services/provider_set.go | 8 +- 15 files changed, 283 insertions(+), 180 deletions(-) create mode 100644 internal/data/impl/bot_config.go delete mode 100644 internal/data/impl/bot_impl.go rename internal/data/model/{ai_bot.gen.go => ai_bot_config.gen.go} (72%) diff --git a/config/config_test.yaml b/config/config_test.yaml index baedada..4ebccb8 100644 --- a/config/config_test.yaml +++ b/config/config_test.yaml @@ -88,7 +88,7 @@ permissionConfig: permission_url: "http://api.test.user.1688sup.cn:8001/v1/menu/myCodes?systemId=" -ding_talk_bots: - public: - client_id: "dingchg59zwwvmuuvldx", - client_secret: "ZwetAnRiTQobNFVlNrshRagSMAJIFpBAepWkWI7on7Tt_o617KHtTjBLp8fQfplz", \ No newline at end of file +#ding_talk_bots: +# public: +# client_id: "dingchg59zwwvmuuvldx", +# client_secret: "ZwetAnRiTQobNFVlNrshRagSMAJIFpBAepWkWI7on7Tt_o617KHtTjBLp8fQfplz", \ No newline at end of file diff --git a/internal/biz/ding_talk_bot.go b/internal/biz/ding_talk_bot.go index 09de5d4..f2abea1 100644 --- a/internal/biz/ding_talk_bot.go +++ b/internal/biz/ding_talk_bot.go @@ -2,29 +2,61 @@ package biz import ( "ai_scheduler/internal/biz/do" + "ai_scheduler/internal/data/constants" + "ai_scheduler/internal/data/impl" + "ai_scheduler/internal/data/model" "ai_scheduler/internal/entitys" + "ai_scheduler/internal/pkg/mapstructure" "context" + "fmt" + "github.com/gofiber/fiber/v2/log" "github.com/open-dingtalk/dingtalk-stream-sdk-go/chatbot" + "xorm.io/builder" ) // AiRouterBiz 智能路由服务 type DingTalkBotBiz struct { - do *do.Do - handle *do.Handle + do *do.Do + handle *do.Handle + botConfigImpl *impl.BotConfigImpl + replier *chatbot.ChatbotReplier + log log.Logger } // NewDingTalkBotBiz func NewDingTalkBotBiz( do *do.Do, handle *do.Handle, + botConfigImpl *impl.BotConfigImpl, ) *DingTalkBotBiz { return &DingTalkBotBiz{ - do: do, - handle: handle, + do: do, + handle: handle, + botConfigImpl: botConfigImpl, + replier: chatbot.NewChatbotReplier(), } } +func (d *DingTalkBotBiz) GetDingTalkBotCfgList() (dingBotList []entitys.DingTalkBot, err error) { + botConfig := make([]model.AiBotConfig, 0) + + cond := builder.NewCond() + cond = cond.And(builder.Eq{"status": constants.Enable}) + cond = cond.And(builder.Eq{"bot_type": constants.BotTypeDingTalk}) + err = d.botConfigImpl.GetRangeToMapStruct(&cond, &botConfig) + for _, v := range botConfig { + var config entitys.DingTalkBot + err = mapstructure.Decode(v, &config) + if err != nil { + d.log.Info("初始化“%s”失败:%s", v.BotName, err.Error()) + } + dingBotList = append(dingBotList, config) + } + return + +} + func (d *DingTalkBotBiz) InitRequire(ctx context.Context, data *chatbot.BotCallbackDataModel) (requireData *entitys.RequireDataDingTalkBot, err error) { requireData = &entitys.RequireDataDingTalkBot{ Req: data, @@ -37,6 +69,74 @@ func (d *DingTalkBotBiz) InitRequire(ctx context.Context, data *chatbot.BotCallb return } -func (d *DingTalkBotBiz) Recognize(ctx context.Context, rec *entitys.Recognize, ch chan entitys.Response) (match *entitys.Match, err error) { - return d.handle.RecognizeBot(ctx, rec, ch) +func (d *DingTalkBotBiz) Recognize(ctx context.Context, bot *chatbot.BotCallbackDataModel) (match entitys.Match, err error) { + + return d.handle.Recognize(ctx, nil, &do.WithDingTalkBot{}) +} + +func (d *DingTalkBotBiz) HandleRes(ctx context.Context, data *chatbot.BotCallbackDataModel, resp entitys.Response) error { + switch resp.Type { + case entitys.ResponseText: + return d.replyText(ctx, data.SessionWebhook, resp.Content) + case entitys.ResponseStream: + return d.replySteam(ctx, data.SessionWebhook, resp.Content) + case entitys.ResponseImg: + return d.replyImg(ctx, data.SessionWebhook, resp.Content) + case entitys.ResponseFile: + return d.replyFile(ctx, data.SessionWebhook, resp.Content) + case entitys.ResponseMarkdown: + return d.replyMarkdown(ctx, data.SessionWebhook, resp.Content) + case entitys.ResponseActionCard: + return d.replyActionCard(ctx, data.SessionWebhook, resp.Content) + default: + return nil + } +} + +func (d *DingTalkBotBiz) replySteam(ctx context.Context, SessionWebhook string, content string, arg ...string) error { + msg := content + if len(arg) > 0 { + msg = fmt.Sprintf(content, arg) + } + return d.replier.SimpleReplyText(ctx, SessionWebhook, []byte(msg)) +} + +func (d *DingTalkBotBiz) replyText(ctx context.Context, SessionWebhook string, content string, arg ...string) error { + msg := content + if len(arg) > 0 { + msg = fmt.Sprintf(content, arg) + } + return d.replier.SimpleReplyText(ctx, SessionWebhook, []byte(msg)) +} + +func (d *DingTalkBotBiz) replyImg(ctx context.Context, SessionWebhook string, content string, arg ...string) error { + msg := content + if len(arg) > 0 { + msg = fmt.Sprintf(content, arg) + } + return d.replier.SimpleReplyText(ctx, SessionWebhook, []byte(msg)) +} + +func (d *DingTalkBotBiz) replyFile(ctx context.Context, SessionWebhook string, content string, arg ...string) error { + msg := content + if len(arg) > 0 { + msg = fmt.Sprintf(content, arg) + } + return d.replier.SimpleReplyText(ctx, SessionWebhook, []byte(msg)) +} + +func (d *DingTalkBotBiz) replyMarkdown(ctx context.Context, SessionWebhook string, content string, arg ...string) error { + msg := content + if len(arg) > 0 { + msg = fmt.Sprintf(content, arg) + } + return d.replier.SimpleReplyText(ctx, SessionWebhook, []byte(msg)) +} + +func (d *DingTalkBotBiz) replyActionCard(ctx context.Context, SessionWebhook string, content string, arg ...string) error { + msg := content + if len(arg) > 0 { + msg = fmt.Sprintf(content, arg) + } + return d.replier.SimpleReplyText(ctx, SessionWebhook, []byte(msg)) } diff --git a/internal/biz/do/prompt.go b/internal/biz/do/prompt.go index cb1ca18..f25cf95 100644 --- a/internal/biz/do/prompt.go +++ b/internal/biz/do/prompt.go @@ -1,6 +1,7 @@ package do import ( + "ai_scheduler/internal/biz/handle" "ai_scheduler/internal/entitys" "ai_scheduler/internal/pkg" "context" @@ -42,7 +43,7 @@ func (f *WithSys) CreatePrompt(ctx context.Context, rec *entitys.Recognize) (mes func (f *WithSys) getUserContent(ctx context.Context, rec *entitys.Recognize) (content strings.Builder, err error) { var hasFile bool - if len(rec.UserContent.FileUrl) > 0 || rec.UserContent.File != nil { + if rec.UserContent.File != nil && (len(rec.UserContent.File.FileUrl) > 0 || rec.UserContent.File.File != nil) { hasFile = true } content.WriteString(rec.UserContent.Text) @@ -64,7 +65,65 @@ func (f *WithSys) getUserContent(ctx context.Context, rec *entitys.Recognize) (c if hasFile { content.WriteString("\n") content.WriteString("### 文件内容:\n") - hand.WriteString(rec.UserContent.FileUrl, rec.UserContent.FileUrl) + handle.HandleRecognizeFile(rec.UserContent.File) + //...do something with file + } + return +} + +type WithDingTalkBot struct { +} + +func (f *WithDingTalkBot) CreatePrompt(ctx context.Context, rec *entitys.Recognize) (mes []api.Message, err error) { + var ( + prompt = make([]api.Message, 0) // 初始化一个空的api.Message切片 + ) + // 获取用户内容,如果出错则直接返回错误 + content, err := f.getUserContent(ctx, rec) + if err != nil { + return nil, err + } + // 构建提示消息列表,包含系统提示、助手回复和用户内容 + mes = append(prompt, api.Message{ + Role: "system", // 系统角色 + Content: rec.SystemPrompt, // 系统提示内容 + }, api.Message{ + Role: "assistant", // 助手角色 + Content: "### 聊天记录:" + pkg.JsonStringIgonErr(rec.ChatHis), // 助手回复内容 + }, api.Message{ + Role: "user", // 用户角色 + Content: content.String(), // 用户输入内容 + }) + + return +} + +func (f *WithDingTalkBot) getUserContent(ctx context.Context, rec *entitys.Recognize) (content strings.Builder, err error) { + var hasFile bool + if rec.UserContent.File != nil && (len(rec.UserContent.File.FileUrl) > 0 || rec.UserContent.File.File != nil) { + hasFile = true + } + content.WriteString(rec.UserContent.Text) + if hasFile { + content.WriteString("\n") + } + + if len(rec.UserContent.Tag) > 0 { + content.WriteString("\n") + content.WriteString("### 工具必须使用:") + content.WriteString(rec.UserContent.Tag) + } + + if len(rec.ChatHis.Messages) > 0 { + content.WriteString("### 引用历史聊天记录:\n") + content.WriteString(pkg.JsonStringIgonErr(rec.ChatHis)) + } + + if hasFile { + content.WriteString("\n") + content.WriteString("### 文件内容:\n") + handle.HandleRecognizeFile(rec.UserContent.File) + //...do something with file } return } diff --git a/internal/config/config.go b/internal/config/config.go index 85b8be8..160275f 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -9,22 +9,17 @@ import ( // Config 应用配置 type Config struct { - Server ServerConfig `mapstructure:"server"` - Ollama OllamaConfig `mapstructure:"ollama"` - Vllm VllmConfig `mapstructure:"vllm"` - Sys SysConfig `mapstructure:"sys"` - Tools ToolsConfig `mapstructure:"tools"` - Logging LoggingConfig `mapstructure:"logging"` - Redis Redis `mapstructure:"redis"` - DB DB `mapstructure:"db"` - DefaultPrompt SysPrompt `mapstructure:"default_prompt"` - PermissionConfig PermissionConfig `mapstructure:"permissionConfig"` - DingTalkBots map[string]*DingTalkBot `mapstructure:"ding_talk_bots"` -} - -type DingTalkBot struct { - ClientId string `mapstructure:"client_id"` - ClientSecret string `mapstructure:"client_secret"` + Server ServerConfig `mapstructure:"server"` + Ollama OllamaConfig `mapstructure:"ollama"` + Vllm VllmConfig `mapstructure:"vllm"` + Sys SysConfig `mapstructure:"sys"` + Tools ToolsConfig `mapstructure:"tools"` + Logging LoggingConfig `mapstructure:"logging"` + Redis Redis `mapstructure:"redis"` + DB DB `mapstructure:"db"` + DefaultPrompt SysPrompt `mapstructure:"default_prompt"` + PermissionConfig PermissionConfig `mapstructure:"permissionConfig"` + // DingTalkBots map[string]*DingTalkBot `mapstructure:"ding_talk_bots"` } type SysPrompt struct { @@ -57,18 +52,18 @@ type ServerConfig struct { // OllamaConfig Ollama配置 type OllamaConfig struct { - BaseURL string `mapstructure:"base_url"` - Model string `mapstructure:"model"` - GenerateModel string `mapstructure:"generate_model"` - VlModel string `mapstructure:"vl_model"` - Timeout time.Duration `mapstructure:"timeout"` + BaseURL string `mapstructure:"base_url"` + Model string `mapstructure:"model"` + GenerateModel string `mapstructure:"generate_model"` + VlModel string `mapstructure:"vl_model"` + Timeout time.Duration `mapstructure:"timeout"` } type VllmConfig struct { - BaseURL string `mapstructure:"base_url"` - VlModel string `mapstructure:"vl_model"` - Timeout time.Duration `mapstructure:"timeout"` - Level string `mapstructure:"level"` + BaseURL string `mapstructure:"base_url"` + VlModel string `mapstructure:"vl_model"` + Timeout time.Duration `mapstructure:"timeout"` + Level string `mapstructure:"level"` } type Redis struct { diff --git a/internal/data/constants/bot.go b/internal/data/constants/bot.go index bfe004c..2b24cfe 100644 --- a/internal/data/constants/bot.go +++ b/internal/data/constants/bot.go @@ -25,3 +25,9 @@ var ChatStyleMap = map[ChatStyle]string{ ChatStyleCute: "可爱", ChatStyleAngry: "愤怒", } + +type BotType int + +const ( + BotTypeDingTalk BotType = 1 // 系统的bug/优化建议 +) diff --git a/internal/data/impl/base.go b/internal/data/impl/base.go index 7aee0b7..b7abf1a 100644 --- a/internal/data/impl/base.go +++ b/internal/data/impl/base.go @@ -22,7 +22,7 @@ BaseModel 是一个泛型结构体,用于封装GORM数据库通用操作。 // 定义受支持的PO类型集合(可根据需要扩展), 只有包含表结构才能使用BaseModel,避免使用出现问题 type PO interface { model.AiChatHi | - model.AiSy | model.AiSession | model.AiTask | model.AiBot + model.AiSy | model.AiSession | model.AiTask | model.AiBotConfig } type BaseModel[P PO] struct { diff --git a/internal/data/impl/bot_chat_history.go b/internal/data/impl/bot_chat_history.go index 3da5184..7e4de3c 100644 --- a/internal/data/impl/bot_chat_history.go +++ b/internal/data/impl/bot_chat_history.go @@ -10,6 +10,6 @@ type BotChatHisImpl struct { dataTemp.DataTemp } -func NewBotChatHisImpl(db *utils.Db) *TaskImpl { - return &TaskImpl{*dataTemp.NewDataTemp(db, new(model.AiBotChatHi))} +func NewBotChatHisImpl(db *utils.Db) *BotChatHisImpl { + return &BotChatHisImpl{*dataTemp.NewDataTemp(db, new(model.AiBotChatHi))} } diff --git a/internal/data/impl/bot_config.go b/internal/data/impl/bot_config.go new file mode 100644 index 0000000..2c98ffb --- /dev/null +++ b/internal/data/impl/bot_config.go @@ -0,0 +1,17 @@ +package impl + +import ( + "ai_scheduler/internal/data/model" + "ai_scheduler/tmpl/dataTemp" + "ai_scheduler/utils" +) + +type BotConfigImpl struct { + dataTemp.DataTemp +} + +func NewBotConfigImpl(db *utils.Db) *BotConfigImpl { + return &BotConfigImpl{ + DataTemp: *dataTemp.NewDataTemp(db, new(model.AiBotConfig)), + } +} diff --git a/internal/data/impl/bot_impl.go b/internal/data/impl/bot_impl.go deleted file mode 100644 index e76e8de..0000000 --- a/internal/data/impl/bot_impl.go +++ /dev/null @@ -1,28 +0,0 @@ -package impl - -import ( - "ai_scheduler/internal/data/model" - "ai_scheduler/tmpl/dataTemp" - "ai_scheduler/utils" - - "gorm.io/gorm" -) - -type BotImpl struct { - dataTemp.DataTemp - BaseRepository[model.AiBot] -} - -func NewBotImpl(db *utils.Db) *BotImpl { - return &BotImpl{ - DataTemp: *dataTemp.NewDataTemp(db, new(model.AiBot)), - BaseRepository: NewBaseModel[model.AiBot](db.Client), - } -} - -// WithSysId 系统id -func (s *BotImpl) WithSysId(sysId interface{}) CondFunc { - return func(db *gorm.DB) *gorm.DB { - return db.Where("sys_id = ?", sysId) - } -} diff --git a/internal/data/impl/provider_set.go b/internal/data/impl/provider_set.go index 1284f11..53d389c 100644 --- a/internal/data/impl/provider_set.go +++ b/internal/data/impl/provider_set.go @@ -4,4 +4,10 @@ import ( "github.com/google/wire" ) -var ProviderImpl = wire.NewSet(NewSessionImpl, NewSysImpl, NewTaskImpl, NewChatHisImpl) +var ProviderImpl = wire.NewSet( + NewSessionImpl, + NewSysImpl, + NewTaskImpl, + NewChatHisImpl, + NewBotConfigImpl, +) diff --git a/internal/data/model/ai_bot.gen.go b/internal/data/model/ai_bot_config.gen.go similarity index 72% rename from internal/data/model/ai_bot.gen.go rename to internal/data/model/ai_bot_config.gen.go index ffe73c2..5c4f27a 100644 --- a/internal/data/model/ai_bot.gen.go +++ b/internal/data/model/ai_bot_config.gen.go @@ -8,13 +8,13 @@ import ( "time" ) -const TableNameAiBot = "ai_bot" +const TableNameAiBotConfig = "ai_bot_config" -// AiBot mapped from table -type AiBot struct { +// AiBotConfig mapped from table +type AiBotConfig struct { BotID int32 `gorm:"column:bot_id;primaryKey;autoIncrement:true" json:"bot_id"` SysID int32 `gorm:"column:sys_id;not null" json:"sys_id"` - BotType int32 `gorm:"column:bot_type" json:"bot_type"` + BotType int32 `gorm:"column:bot_type;not null;default:1" json:"bot_type"` BotName string `gorm:"column:bot_name;not null" json:"bot_name"` BotConfig string `gorm:"column:bot_config;not null" json:"bot_config"` CreateAt time.Time `gorm:"column:create_at;default:CURRENT_TIMESTAMP" json:"create_at"` @@ -23,7 +23,7 @@ type AiBot struct { DeleteAt time.Time `gorm:"column:delete_at" json:"delete_at"` } -// TableName AiBot's table name -func (*AiBot) TableName() string { - return TableNameAiBot +// TableName AiBotConfig's table name +func (*AiBotConfig) TableName() string { + return TableNameAiBotConfig } diff --git a/internal/entitys/bot.go b/internal/entitys/bot.go index 0a99113..fc24764 100644 --- a/internal/entitys/bot.go +++ b/internal/entitys/bot.go @@ -22,3 +22,8 @@ type RequireDataDingTalkBot struct { ImgByte []api.ImageData ImgUrls []string } + +type DingTalkBot struct { + ClientId string `mapstructure:"client_id"` + ClientSecret string `mapstructure:"client_secret"` +} diff --git a/internal/server/ding_talk_bot.go b/internal/server/ding_talk_bot.go index cda0e1c..1d6c4a4 100644 --- a/internal/server/ding_talk_bot.go +++ b/internal/server/ding_talk_bot.go @@ -1,7 +1,7 @@ package server import ( - "ai_scheduler/internal/config" + "ai_scheduler/internal/entitys" "ai_scheduler/internal/services" "context" @@ -11,7 +11,7 @@ import ( ) type DingBotServiceInterface interface { - GetServiceCfg(cfg map[string]*config.DingTalkBot) (*config.DingTalkBot, string) + GetServiceCfg() ([]entitys.DingTalkBot, error) OnChatBotMessageReceived(ctx context.Context, data *chatbot.BotCallbackDataModel) (content []byte, err error) } @@ -19,23 +19,28 @@ type DingTalkBotServer struct { Clients []*client.StreamClient } +// NewDingTalkBotServer 批量注册钉钉客户端cli +// 这里支持两种方式,一种是完全独立service,一种是直接用现成的service +// 独立的service,在本页的ProvideAllDingBotServices方法进行注册 +// 现成的service参考services->dtalk_bot.go +// 具体使用请根据实际业务需求 func NewDingTalkBotServer( - cfg *config.Config, services []DingBotServiceInterface, ) *DingTalkBotServer { clients := make([]*client.StreamClient, 0) for _, service := range services { - serviceConf, index := service.GetServiceCfg(cfg.DingTalkBots) - if serviceConf == nil { - log.Info("未找到%s配置", index) - continue + serviceConfigs, index := service.GetServiceCfg() + for _, serviceConf := range serviceConfigs { + if serviceConf.ClientId == "" || serviceConf.ClientSecret == "" { + continue + } + cli := DingBotServerInit(serviceConf.ClientId, serviceConf.ClientSecret, service) + if cli == nil { + log.Info("%s客户端初始失败", index) + continue + } + clients = append(clients, cli) } - cli := DingBotServerInit(serviceConf.ClientId, serviceConf.ClientSecret, service) - if cli == nil { - log.Info("%s客户端初始失败", index) - continue - } - clients = append(clients, cli) } return &DingTalkBotServer{ Clients: clients, diff --git a/internal/services/dtalk_bot.go b/internal/services/dtalk_bot.go index 2855471..864f5d7 100644 --- a/internal/services/dtalk_bot.go +++ b/internal/services/dtalk_bot.go @@ -3,6 +3,7 @@ package services import ( "ai_scheduler/internal/biz" "ai_scheduler/internal/config" + "ai_scheduler/internal/entitys" "context" "fmt" @@ -11,41 +12,39 @@ import ( ) type DingBotService struct { - config *config.Config - replier *chatbot.ChatbotReplier - env string - DingTalkBotBiz *biz.DingTalkBotBiz + config *config.Config + + dingTalkBotBiz *biz.DingTalkBotBiz } func NewDingBotService(config *config.Config, DingTalkBotBiz *biz.DingTalkBotBiz) *DingBotService { - return &DingBotService{config: config, replier: chatbot.NewChatbotReplier(), env: "public", DingTalkBotBiz: DingTalkBotBiz} + return &DingBotService{config: config, dingTalkBotBiz: DingTalkBotBiz} } -func (d *DingBotService) GetServiceCfg(cfg map[string]*config.DingTalkBot) (*config.DingTalkBot, string) { - return cfg[d.env], d.env +func (d *DingBotService) GetServiceCfg() ([]entitys.DingTalkBot, error) { + return d.dingTalkBotBiz.GetDingTalkBotCfgList() } func (d *DingBotService) OnChatBotMessageReceived(ctx context.Context, data *chatbot.BotCallbackDataModel) (content []byte, err error) { - requireData, err := d.DingTalkBotBiz.InitRequire(ctx, data) + requireData, err := d.dingTalkBotBiz.InitRequire(ctx, data) if err != nil { return } go func() { defer close(requireData.Ch) - - if recognizeErr := d.DingTalkBotBiz.Recognize(ctx, requireData); recognizeErr != nil { - requireData.Ch <- entitys.Response{ - Type: entitys.ResponseEnd, - Content: fmt.Sprintf("处理消息时出错: %v", recognizeErr), - } - } - //向下传递 - if err = d.handle.HandleMatch(ctx, nil, requireData); err != nil { - requireData.Ch <- entitys.Response{ - Type: entitys.ResponseEnd, - Content: fmt.Sprintf("匹配失败: %v", err), - } - } + //if match, _err := d.dingTalkBotBiz.Recognize(ctx, data); _err != nil { + // requireData.Ch <- entitys.Response{ + // Type: entitys.ResponseEnd, + // Content: fmt.Sprintf("处理消息时出错: %s", _err.Error()), + // } + //} + ////向下传递 + //if err = d.dingTalkBotBiz.HandleMatch(ctx, nil, requireData); err != nil { + // requireData.Ch <- entitys.Response{ + // Type: entitys.ResponseEnd, + // Content: fmt.Sprintf("匹配失败: %v", err), + // } + //} }() for { select { @@ -59,77 +58,10 @@ func (d *DingBotService) OnChatBotMessageReceived(ctx context.Context, data *cha if resp.Type == entitys.ResponseLog { return } - if err := d.handleRes(ctx, data, resp); err != nil { + if err := d.dingTalkBotBiz.HandleRes(ctx, data, resp); err != nil { return nil, fmt.Errorf("回复失败: %w", err) } } } return } - -func (d *DingBotService) handleRes(ctx context.Context, data *chatbot.BotCallbackDataModel, resp entitys.Response) error { - switch resp.Type { - case entitys.ResponseText: - return d.replyText(ctx, data.SessionWebhook, resp.Content) - case entitys.ResponseStream: - return d.replySteam(ctx, data.SessionWebhook, resp.Content) - case entitys.ResponseImg: - return d.replyImg(ctx, data.SessionWebhook, resp.Content) - case entitys.ResponseFile: - return d.replyFile(ctx, data.SessionWebhook, resp.Content) - case entitys.ResponseMarkdown: - return d.replyMarkdown(ctx, data.SessionWebhook, resp.Content) - case entitys.ResponseActionCard: - return d.replyActionCard(ctx, data.SessionWebhook, resp.Content) - default: - return nil - } -} - -func (d *DingBotService) replyText(ctx context.Context, SessionWebhook string, content string, arg ...string) error { - msg := content - if len(arg) > 0 { - msg = fmt.Sprintf(content, arg) - } - return d.replier.SimpleReplyText(ctx, SessionWebhook, []byte(msg)) -} - -func (d *DingBotService) replySteam(ctx context.Context, SessionWebhook string, content string, arg ...string) error { - msg := content - if len(arg) > 0 { - msg = fmt.Sprintf(content, arg) - } - return d.replier.SimpleReplyText(ctx, SessionWebhook, []byte(msg)) -} - -func (d *DingBotService) replyImg(ctx context.Context, SessionWebhook string, content string, arg ...string) error { - msg := content - if len(arg) > 0 { - msg = fmt.Sprintf(content, arg) - } - return d.replier.SimpleReplyText(ctx, SessionWebhook, []byte(msg)) -} - -func (d *DingBotService) replyFile(ctx context.Context, SessionWebhook string, content string, arg ...string) error { - msg := content - if len(arg) > 0 { - msg = fmt.Sprintf(content, arg) - } - return d.replier.SimpleReplyText(ctx, SessionWebhook, []byte(msg)) -} - -func (d *DingBotService) replyMarkdown(ctx context.Context, SessionWebhook string, content string, arg ...string) error { - msg := content - if len(arg) > 0 { - msg = fmt.Sprintf(content, arg) - } - return d.replier.SimpleReplyText(ctx, SessionWebhook, []byte(msg)) -} - -func (d *DingBotService) replyActionCard(ctx context.Context, SessionWebhook string, content string, arg ...string) error { - msg := content - if len(arg) > 0 { - msg = fmt.Sprintf(content, arg) - } - return d.replier.SimpleReplyText(ctx, SessionWebhook, []byte(msg)) -} diff --git a/internal/services/provider_set.go b/internal/services/provider_set.go index 668c7fe..867eb11 100644 --- a/internal/services/provider_set.go +++ b/internal/services/provider_set.go @@ -6,4 +6,10 @@ import ( "github.com/google/wire" ) -var ProviderSetServices = wire.NewSet(NewChatService, NewSessionService, gateway.NewGateway, NewTaskService, NewCallbackService, NewHistoryService) +var ProviderSetServices = wire.NewSet( + NewChatService, + NewSessionService, gateway.NewGateway, + NewTaskService, + NewCallbackService, + NewDingBotService, + NewHistoryService) From 567874848bb256106d9c590b13c9fbe5370b1640 Mon Sep 17 00:00:00 2001 From: renzhiyuan <465386466@qq.com> Date: Wed, 10 Dec 2025 11:47:39 +0800 Subject: [PATCH 22/66] =?UTF-8?q?fix:=E4=BF=AE=E5=A4=8D=E6=8A=A5=E9=94=99?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- internal/entitys/recognize.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/internal/entitys/recognize.go b/internal/entitys/recognize.go index 6d692eb..fc8b55f 100644 --- a/internal/entitys/recognize.go +++ b/internal/entitys/recognize.go @@ -20,7 +20,7 @@ type RegistrationTask struct { type RecognizeUserContent struct { Text string - File *RecognizeFile + File []*RecognizeFile ActionCardUrl string Tag string } From d7b39f4d60a8db83c8d20c4fcc0371bc8effda5e Mon Sep 17 00:00:00 2001 From: renzhiyuan <465386466@qq.com> Date: Wed, 10 Dec 2025 14:19:35 +0800 Subject: [PATCH 23/66] =?UTF-8?q?fix:=E4=BF=AE=E5=A4=8D=E6=8A=A5=E9=94=99?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- cmd/server/main.go | 6 ++++-- deploy.sh | 4 +++- internal/biz/do/prompt.go | 15 ++++++--------- internal/biz/handle/file.go | 2 +- internal/entitys/bot.go | 5 +++-- internal/server/ding_talk_bot.go | 17 +++++++++++------ 6 files changed, 28 insertions(+), 21 deletions(-) diff --git a/cmd/server/main.go b/cmd/server/main.go index a6607f5..d765735 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -2,6 +2,7 @@ package main import ( "ai_scheduler/internal/config" + "context" "flag" "fmt" @@ -10,6 +11,7 @@ import ( func main() { configPath := flag.String("config", "./config/config_test.yaml", "Path to configuration file") + onBot := flag.String("bot", "", "bot start") flag.Parse() bc, err := config.LoadConfig(*configPath) if err != nil { @@ -23,7 +25,7 @@ func main() { defer func() { cleanup() }() - //app.DingBotServer.Run(context.Background()) - //app.DingBotServer.RunBots(app.DingBotServer.BotServices) + app.DingBotServer.Run(context.Background(), *onBot) + log.Fatal(app.HttpServer.Listen(fmt.Sprintf(":%d", bc.Server.Port))) } diff --git a/deploy.sh b/deploy.sh index f02e162..2766253 100644 --- a/deploy.sh +++ b/deploy.sh @@ -14,8 +14,10 @@ fi CONFIG_FILE="config/config.yaml" BRANCH="master" +BOT="ALL" if [ "$MODE" = "dev" ]; then CONFIG_FILE="config/config_test.yaml" + BOT="zltx" BRANCH="test" fi @@ -33,6 +35,6 @@ docker run -itd \ -e "OLLAMA_BASE_URL=${OLLAMA_BASE_URL:-http://host.docker.internal:11434}" \ -e "MODE=${MODE}" \ -p 8090:8090 \ - "${CONTAINER_NAME}" ./server --config "./${CONFIG_FILE}" + "${CONTAINER_NAME}" ./server --config "./${CONFIG_FILE}" --bot "./${BOT}" docker logs -f ${CONTAINER_NAME} \ No newline at end of file diff --git a/internal/biz/do/prompt.go b/internal/biz/do/prompt.go index f25cf95..ea82f84 100644 --- a/internal/biz/do/prompt.go +++ b/internal/biz/do/prompt.go @@ -43,7 +43,7 @@ func (f *WithSys) CreatePrompt(ctx context.Context, rec *entitys.Recognize) (mes func (f *WithSys) getUserContent(ctx context.Context, rec *entitys.Recognize) (content strings.Builder, err error) { var hasFile bool - if rec.UserContent.File != nil && (len(rec.UserContent.File.FileUrl) > 0 || rec.UserContent.File.File != nil) { + if rec.UserContent.File != nil && len(rec.UserContent.File) > 0 { hasFile = true } content.WriteString(rec.UserContent.Text) @@ -65,7 +65,10 @@ func (f *WithSys) getUserContent(ctx context.Context, rec *entitys.Recognize) (c if hasFile { content.WriteString("\n") content.WriteString("### 文件内容:\n") - handle.HandleRecognizeFile(rec.UserContent.File) + for _, file := range rec.UserContent.File { + handle.HandleRecognizeFile(file) + } + //...do something with file } return @@ -100,7 +103,7 @@ func (f *WithDingTalkBot) CreatePrompt(ctx context.Context, rec *entitys.Recogni func (f *WithDingTalkBot) getUserContent(ctx context.Context, rec *entitys.Recognize) (content strings.Builder, err error) { var hasFile bool - if rec.UserContent.File != nil && (len(rec.UserContent.File.FileUrl) > 0 || rec.UserContent.File.File != nil) { + if rec.UserContent.File != nil && len(rec.UserContent.File) > 0 { hasFile = true } content.WriteString(rec.UserContent.Text) @@ -119,11 +122,5 @@ func (f *WithDingTalkBot) getUserContent(ctx context.Context, rec *entitys.Recog content.WriteString(pkg.JsonStringIgonErr(rec.ChatHis)) } - if hasFile { - content.WriteString("\n") - content.WriteString("### 文件内容:\n") - handle.HandleRecognizeFile(rec.UserContent.File) - //...do something with file - } return } diff --git a/internal/biz/handle/file.go b/internal/biz/handle/file.go index 9fd4fae..e9332aa 100644 --- a/internal/biz/handle/file.go +++ b/internal/biz/handle/file.go @@ -18,7 +18,7 @@ import ( // 判断文件大小 // 判断文件类型 // 判断文件是否合法 -func HandleRecognizeFile(file *entitys.RecognizeFile) { +func HandleRecognizeFile(files *entitys.RecognizeFile) { //Todo 仲云 return } diff --git a/internal/entitys/bot.go b/internal/entitys/bot.go index fc24764..568822c 100644 --- a/internal/entitys/bot.go +++ b/internal/entitys/bot.go @@ -24,6 +24,7 @@ type RequireDataDingTalkBot struct { } type DingTalkBot struct { - ClientId string `mapstructure:"client_id"` - ClientSecret string `mapstructure:"client_secret"` + BotIndex string + ClientId string + ClientSecret string } diff --git a/internal/server/ding_talk_bot.go b/internal/server/ding_talk_bot.go index 1d6c4a4..9a63812 100644 --- a/internal/server/ding_talk_bot.go +++ b/internal/server/ding_talk_bot.go @@ -16,7 +16,7 @@ type DingBotServiceInterface interface { } type DingTalkBotServer struct { - Clients []*client.StreamClient + Clients map[string]*client.StreamClient } // NewDingTalkBotServer 批量注册钉钉客户端cli @@ -27,19 +27,19 @@ type DingTalkBotServer struct { func NewDingTalkBotServer( services []DingBotServiceInterface, ) *DingTalkBotServer { - clients := make([]*client.StreamClient, 0) + clients := make(map[string]*client.StreamClient) for _, service := range services { - serviceConfigs, index := service.GetServiceCfg() + serviceConfigs, err := service.GetServiceCfg() for _, serviceConf := range serviceConfigs { if serviceConf.ClientId == "" || serviceConf.ClientSecret == "" { continue } cli := DingBotServerInit(serviceConf.ClientId, serviceConf.ClientSecret, service) if cli == nil { - log.Info("%s客户端初始失败", index) + log.Info("%s客户端初始失败:%s", serviceConf.BotIndex, err.Error()) continue } - clients = append(clients, cli) + clients[serviceConf.BotIndex] = cli } } return &DingTalkBotServer{ @@ -53,8 +53,13 @@ func ProvideAllDingBotServices( return []DingBotServiceInterface{dingBotSvc} } -func (d *DingTalkBotServer) Run(ctx context.Context) { +func (d *DingTalkBotServer) Run(ctx context.Context, botIndex string) { for name, cli := range d.Clients { + if botIndex != "All" { + if name != botIndex { + continue + } + } err := cli.Start(ctx) if err != nil { log.Info("%s启动失败", name) From 14a5fe5974dd1df767edf4f77374de340a827872 Mon Sep 17 00:00:00 2001 From: fuzhongyun <15339891972@163.com> Date: Wed, 10 Dec 2025 14:39:30 +0800 Subject: [PATCH 24/66] =?UTF-8?q?fix:=201.=E8=B0=83=E6=95=B4api=E7=9B=B4?= =?UTF-8?q?=E8=BF=9E=E8=BE=93=E5=87=BA=202.=E5=A2=9E=E5=8A=A0=E7=9F=A5?= =?UTF-8?q?=E8=AF=86=E5=BA=93loading=E8=BE=93=E5=87=BA?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- internal/biz/do/handle.go | 3 +-- internal/tools/konwledge_base.go | 1 + 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/internal/biz/do/handle.go b/internal/biz/do/handle.go index ebebf83..16af5f8 100644 --- a/internal/biz/do/handle.go +++ b/internal/biz/do/handle.go @@ -9,7 +9,6 @@ import ( "ai_scheduler/internal/data/model" "ai_scheduler/internal/entitys" "ai_scheduler/internal/gateway" - "ai_scheduler/internal/pkg" "ai_scheduler/internal/pkg/l_request" "ai_scheduler/internal/pkg/mapstructure" "ai_scheduler/internal/tools" @@ -269,7 +268,7 @@ func (r *Handle) handleApiTask(ctx context.Context, requireData *entitys.Require if err != nil { return } - entitys.ResJson(requireData.Ch, "", pkg.JsonStringIgonErr(res.Text)) + entitys.ResJson(requireData.Ch, "", res.Text) return } diff --git a/internal/tools/konwledge_base.go b/internal/tools/konwledge_base.go index fb5a7f1..dbbb5ce 100644 --- a/internal/tools/konwledge_base.go +++ b/internal/tools/konwledge_base.go @@ -59,6 +59,7 @@ func (k *KnowledgeBaseTool) Definition() entitys.ToolDefinition { // Execute 执行知识库查询 func (k *KnowledgeBaseTool) Execute(ctx context.Context, requireData *entitys.RequireData) error { + entitys.ResLoading(requireData.Ch, k.Name(), "正在为您搜索相关信息") return k.chat(requireData) From 7c973750dddca84353bd7e8ec2ea84de8e4ce599 Mon Sep 17 00:00:00 2001 From: wolter <11@gmail> Date: Wed, 10 Dec 2025 15:23:02 +0800 Subject: [PATCH 25/66] =?UTF-8?q?feat:=E6=84=8F=E5=9B=BE=E8=AF=86=E5=88=AB?= =?UTF-8?q?=E6=8B=86=E5=88=86?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- internal/biz/router.go | 98 ++++++++++++++++++++++++++++++- internal/data/constants/caller.go | 8 +++ internal/entitys/recognize.go | 14 ++--- internal/entitys/types.go | 4 +- 4 files changed, 114 insertions(+), 10 deletions(-) diff --git a/internal/biz/router.go b/internal/biz/router.go index 183ee89..c164c31 100644 --- a/internal/biz/router.go +++ b/internal/biz/router.go @@ -2,8 +2,11 @@ package biz import ( "ai_scheduler/internal/biz/do" + "ai_scheduler/internal/data/constants" "ai_scheduler/internal/gateway" "context" + "encoding/json" + "time" "ai_scheduler/internal/entitys" @@ -77,6 +80,99 @@ func (r *AiRouterBiz) RouteWithSocket(client *gateway.Client, req *entitys.ChatS } func (r *AiRouterBiz) SetRec(ctx context.Context, requireData *entitys.RequireData) (match entitys.Recognize, err error) { - //TODO 叙平 + // 参数空值检查 + if requireData == nil || requireData.Req == nil { + return match, err + } + + // 1. 系统提示词 + match.SystemPrompt = requireData.Sys.SysPrompt + + // 2. 用户输入和文件处理 + match.UserContent, err = r.buildUserContent(requireData) + if err != nil { + log.Errorf("构建用户内容失败: %s", err.Error()) + return + } + + // 3. 聊天记录 - 只有在有历史记录时才构建 + if len(requireData.Histories) > 0 { + match.ChatHis = r.buildChatHistory(requireData) + } + + // 4. 任务列表 - 预分配切片容量 + if len(requireData.Tasks) > 0 { + match.Tasks = make([]entitys.RegistrationTask, 0, len(requireData.Tasks)) + for _, task := range requireData.Tasks { + taskConfig := entitys.TaskConfigDetail{} + if err = json.Unmarshal([]byte(task.Config), &taskConfig); err != nil { + log.Errorf("解析任务配置失败: %s, 任务ID: %s", err.Error(), task.Index) + continue // 解析失败时跳过该任务,而不是直接返回错误 + } + + match.Tasks = append(match.Tasks, entitys.RegistrationTask{ + Name: task.Index, + Desc: task.Desc, + TaskConfigDetail: taskConfig, // 直接使用解析后的配置,避免重复构建 + }) + } + } + + match.Ch = requireData.Ch return } + +// buildUserContent 构建用户内容 +func (r *AiRouterBiz) buildUserContent(requireData *entitys.RequireData) (*entitys.RecognizeUserContent, error) { + // 预分配文件切片容量(最多2个文件:File和Img) + files := make([]*entitys.RecognizeFile, 0, 2) + + // 处理文件和图片 + fileUrls := []string{requireData.Req.File, requireData.Req.Img} + for _, url := range fileUrls { + if url != "" { + files = append(files, &entitys.RecognizeFile{FileUrl: url}) + } + } + + // 构建并返回用户内容 + return &entitys.RecognizeUserContent{ + Text: requireData.Req.Text, + File: files, + ActionCardUrl: "", // TODO: 后续实现操作卡片功能 + Tag: requireData.Req.Tags, + }, nil +} + +// buildChatHistory 构建聊天历史 +func (r *AiRouterBiz) buildChatHistory(requireData *entitys.RequireData) entitys.ChatHis { + // 预分配消息切片容量(每个历史记录生成2条消息) + messages := make([]entitys.HisMessage, 0, len(requireData.Histories)*2) + + // 构建聊天记录 + for _, h := range requireData.Histories { + // 用户消息 + messages = append(messages, entitys.HisMessage{ + Role: constants.RoleUser, // 用户角色 + Content: h.Ans, // 用户输入内容 + Timestamp: h.CreateAt.Format(time.DateTime), + }) + + // 助手消息 + messages = append(messages, entitys.HisMessage{ + Role: constants.RoleAssistant, // 助手角色 + Content: h.Ques, // 助手回复内容 + Timestamp: h.CreateAt.Format(time.DateTime), + }) + } + + // 构建聊天历史上下文 + return entitys.ChatHis{ + SessionId: requireData.Session, + Messages: messages, + Context: entitys.HisContext{ + UserLanguage: constants.LangZhCN, // 默认中文 + SystemMode: constants.SystemModeTechnicalSupport, // 默认技术支持模式 + }, + } +} diff --git a/internal/data/constants/caller.go b/internal/data/constants/caller.go index 74493f3..0df9b33 100644 --- a/internal/data/constants/caller.go +++ b/internal/data/constants/caller.go @@ -13,6 +13,14 @@ const ( // 分页默认条数 ChatHistoryLimit = 10 + + // 语言 + LangZhCN = "zh-CN" // 中文 + + // 系统模式 + SystemModeDefault = "default" // 默认模式 + // 系统模式 "technical_support", // 技术支持模式 + SystemModeTechnicalSupport = "technical_support" // 技术支持模式 ) func (c Caller) String() string { diff --git a/internal/entitys/recognize.go b/internal/entitys/recognize.go index fc8b55f..45d9fcc 100644 --- a/internal/entitys/recognize.go +++ b/internal/entitys/recognize.go @@ -5,9 +5,9 @@ import ( ) type Recognize struct { - SystemPrompt string - UserContent *RecognizeUserContent - ChatHis ChatHis + SystemPrompt string // 系统提示内容 + UserContent *RecognizeUserContent // 用户输入内容 + ChatHis ChatHis // 会话历史记录 Tasks []RegistrationTask Ch chan Response } @@ -19,10 +19,10 @@ type RegistrationTask struct { } type RecognizeUserContent struct { - Text string - File []*RecognizeFile - ActionCardUrl string - Tag string + Text string // 用户输入的文本内容 + File []*RecognizeFile // 文件内容 + ActionCardUrl string // 操作卡片链接 + Tag string // 工具标签 } type FileData []byte diff --git a/internal/entitys/types.go b/internal/entitys/types.go index ef229a7..72ecd6c 100644 --- a/internal/entitys/types.go +++ b/internal/entitys/types.go @@ -138,8 +138,8 @@ type HisMessage struct { } type HisContext struct { - UserLanguage string `json:"user_language"` - SystemMode string `json:"system_mode"` + UserLanguage string `json:"user_language"` // 用户语言 + SystemMode string `json:"system_mode"` // 系统模式, } type RequireData struct { From af84e74c2c6ea8afafb61935c469ccb576efc02e Mon Sep 17 00:00:00 2001 From: wolter <11@gmail> Date: Wed, 10 Dec 2025 16:08:59 +0800 Subject: [PATCH 26/66] =?UTF-8?q?feat:=E9=85=8D=E7=BD=AE=E4=B8=8D=E5=90=8C?= =?UTF-8?q?=E7=9A=84=E7=B3=BB=E7=BB=9F=E6=8F=90=E7=A4=BA=E8=AF=8D?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- internal/biz/router.go | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/internal/biz/router.go b/internal/biz/router.go index c164c31..a711445 100644 --- a/internal/biz/router.go +++ b/internal/biz/router.go @@ -3,6 +3,7 @@ package biz import ( "ai_scheduler/internal/biz/do" "ai_scheduler/internal/data/constants" + errors "ai_scheduler/internal/data/error" "ai_scheduler/internal/gateway" "context" "encoding/json" @@ -59,13 +60,13 @@ func (r *AiRouterBiz) RouteWithSocket(client *gateway.Client, req *entitys.ChatS return } //组装意图识别 - rec, err := r.SetRec(ctx, requireData) + rec, sys, err := r.SetRec(ctx, requireData) if err != nil { log.Errorf("组装意图识别失败: %s", err.Error()) return } //意图识别 - requireData.Match, err = r.handle.Recognize(ctx, &rec, &do.WithSys{}) + requireData.Match, err = r.handle.Recognize(ctx, &rec, sys) if err != nil { log.Errorf("意图识别失败: %s", err.Error()) return @@ -79,10 +80,16 @@ func (r *AiRouterBiz) RouteWithSocket(client *gateway.Client, req *entitys.ChatS return } -func (r *AiRouterBiz) SetRec(ctx context.Context, requireData *entitys.RequireData) (match entitys.Recognize, err error) { +func (r *AiRouterBiz) SetRec(ctx context.Context, requireData *entitys.RequireData) (match entitys.Recognize, sys do.PromptOption, err error) { // 参数空值检查 if requireData == nil || requireData.Req == nil { - return match, err + return match, sys, errors.NewBusinessErr(500, "请求参数为空") + } + + // 对应不同的appKey, 配置不同的系统提示词 + switch requireData.Sys.AppKey { + default: + sys = &do.WithSys{} } // 1. 系统提示词 From f0a4008896fa22039ba7c6e01319127f812982e3 Mon Sep 17 00:00:00 2001 From: fuzhongyun <15339891972@163.com> Date: Wed, 10 Dec 2025 16:09:16 +0800 Subject: [PATCH 27/66] =?UTF-8?q?fix=EF=BC=9A=20=E8=B0=83=E6=95=B4go?= =?UTF-8?q?=E7=89=88=E6=9C=AC?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- go.mod | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/go.mod b/go.mod index 073c4eb..0f01662 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module ai_scheduler -go 1.24.0 +go 1.24.7 require ( gitea.cdlsxd.cn/self-tools/l_request v1.0.8 From 15c3febe76c681f605903cc7ece074f12324ba6c Mon Sep 17 00:00:00 2001 From: fuzhongyun <15339891972@163.com> Date: Thu, 11 Dec 2025 16:41:46 +0800 Subject: [PATCH 28/66] =?UTF-8?q?chore:=20=E6=9A=82=E5=AD=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- internal/biz/do/handle.go | 26 ++++++++++- internal/data/constants/const.go | 9 ++-- internal/domain/workflow/manager.go | 1 + .../zltx/order_after_reseller_batch.go | 5 -- .../zltx/order_after_reseller_batch_test.go | 46 +++++++++++++++++++ 5 files changed, 77 insertions(+), 10 deletions(-) create mode 100644 internal/domain/workflow/manager.go create mode 100644 internal/domain/workflow/zltx/order_after_reseller_batch_test.go diff --git a/internal/biz/do/handle.go b/internal/biz/do/handle.go index 2e2a313..fa7b3c5 100644 --- a/internal/biz/do/handle.go +++ b/internal/biz/do/handle.go @@ -7,10 +7,12 @@ import ( errors "ai_scheduler/internal/data/error" "ai_scheduler/internal/data/impl" "ai_scheduler/internal/data/model" + "ai_scheduler/internal/domain/workflow/zltx" "ai_scheduler/internal/entitys" "ai_scheduler/internal/gateway" "ai_scheduler/internal/pkg/l_request" "ai_scheduler/internal/pkg/mapstructure" + "ai_scheduler/internal/pkg/util" "ai_scheduler/internal/tools" "ai_scheduler/internal/tools_bot" "context" @@ -110,6 +112,8 @@ func (r *Handle) HandleMatch(ctx context.Context, client *gateway.Client, requir return r.handleKnowle(ctx, requireData, pointTask) case constants.TaskTypeBot: return r.handleBot(ctx, requireData, pointTask) + case constants.TaskTypeEinoWorkflow: + return r.handleEinoWorkflow(ctx, requireData, pointTask) default: return r.handleOtherTask(ctx, requireData) } @@ -265,11 +269,31 @@ func (r *Handle) handleApiTask(ctx context.Context, requireData *entitys.Require err = errors.NewBusinessErr(422, "api地址获取失败") return } + + entitys.ResLoading(requireData.Ch, requireData.Task.Index, "正在请求数据") + res, err := request.Send() if err != nil { return } - entitys.ResJson(requireData.Ch, "", res.Text) + entitys.ResJson(requireData.Ch, requireData.Task.Index, res.Text) + + return +} + +// eino 工作流 +func (r *Handle) handleEinoWorkflow(ctx context.Context, requireData *entitys.RequireData, task *model.AiTask) (err error) { + // token 写入ctx + ctx = util.SetTokenToContext(ctx, requireData.Auth) + + // 构建工作流 - todo 示例,后续抽象出来 + zltxWorkflow, err := zltx.BuildOrderAfterResellerBatchWorkflow(ctx, r.conf.Tools.ZltxOrderAfterSaleResellerBatch) + if err != nil { + return + } + + // 工作流执行 + _, err = zltxWorkflow.Invoke(ctx, zltx.OrderAfterSaleResellerBatchWorkflowInput{}) return } diff --git a/internal/data/constants/const.go b/internal/data/constants/const.go index f06e906..151d259 100644 --- a/internal/data/constants/const.go +++ b/internal/data/constants/const.go @@ -11,10 +11,11 @@ const ( type TaskType int32 const ( - TaskTypeApi TaskType = 1 - TaskTypeKnowle TaskType = 2 - TaskTypeFunc TaskType = 3 - TaskTypeBot TaskType = 4 + TaskTypeApi TaskType = 1 + TaskTypeKnowle TaskType = 2 + TaskTypeFunc TaskType = 3 + TaskTypeBot TaskType = 4 + TaskTypeEinoWorkflow TaskType = 5 // eino 工作流 ) type UseFul int32 diff --git a/internal/domain/workflow/manager.go b/internal/domain/workflow/manager.go new file mode 100644 index 0000000..0e59ea2 --- /dev/null +++ b/internal/domain/workflow/manager.go @@ -0,0 +1 @@ +package workflow diff --git a/internal/domain/workflow/zltx/order_after_reseller_batch.go b/internal/domain/workflow/zltx/order_after_reseller_batch.go index af41488..5c3081a 100644 --- a/internal/domain/workflow/zltx/order_after_reseller_batch.go +++ b/internal/domain/workflow/zltx/order_after_reseller_batch.go @@ -10,12 +10,10 @@ import ( ) type OrderAfterSaleResellerBatchWorkflowInput struct { - Token string `json:"token"` OrderNumber []string `json:"orderNumber"` } var ErrInvalidOrderNumbers = errors.New("orderNumber 不能为空") -var ErrMissingToken = errors.New("token 不能为空") func BuildOrderAfterResellerBatchWorkflow(ctx context.Context, cfg config.ToolConfig) (compose.Runnable[OrderAfterSaleResellerBatchWorkflowInput, toolZoarb.OrderAfterSaleResellerBatchData], error) { // 定义工作流、出入参 @@ -26,9 +24,6 @@ func BuildOrderAfterResellerBatchWorkflow(ctx context.Context, cfg config.ToolCo if len(in.OrderNumber) == 0 { return OrderAfterSaleResellerBatchWorkflowInput{}, ErrInvalidOrderNumbers } - if in.Token == "" { - return OrderAfterSaleResellerBatchWorkflowInput{}, ErrMissingToken - } return in, nil })) diff --git a/internal/domain/workflow/zltx/order_after_reseller_batch_test.go b/internal/domain/workflow/zltx/order_after_reseller_batch_test.go new file mode 100644 index 0000000..be56786 --- /dev/null +++ b/internal/domain/workflow/zltx/order_after_reseller_batch_test.go @@ -0,0 +1,46 @@ +package zltx + +import ( + "ai_scheduler/internal/config" + "context" + "errors" + "strings" + "testing" +) + +func TestOrderAfterResellerBatch_InvalidOrderNumbers(t *testing.T) { + ctx := context.Background() + cfg := config.ToolConfig{} + + run, err := BuildOrderAfterResellerBatchWorkflow(ctx, cfg) + if err != nil { + t.Fatalf("build workflow error: %v", err) + } + + _, err = run.Invoke(ctx, OrderAfterSaleResellerBatchWorkflowInput{}) + if err == nil { + t.Fatalf("expected error, got nil") + } + if !errors.Is(err, ErrInvalidOrderNumbers) { + t.Fatalf("expected ErrInvalidOrderNumbers, got %v", err) + } +} + +func TestOrderAfterResellerBatch_ContextTokenRequired(t *testing.T) { + ctx := context.Background() + cfg := config.ToolConfig{} + + run, err := BuildOrderAfterResellerBatchWorkflow(ctx, cfg) + if err != nil { + t.Fatalf("build workflow error: %v", err) + } + + in := OrderAfterSaleResellerBatchWorkflowInput{OrderNumber: []string{"123"}} + _, err = run.Invoke(ctx, in) + if err == nil { + t.Fatalf("expected error, got nil") + } + if !strings.Contains(err.Error(), "token 未注入") { + t.Fatalf("expected contains 'token 未注入', got %v", err) + } +} From c73327e80579bc4aecbc4ff3663dce785d2e1d32 Mon Sep 17 00:00:00 2001 From: wolter <11@gmail> Date: Thu, 11 Dec 2025 17:03:10 +0800 Subject: [PATCH 29/66] =?UTF-8?q?feat:=E5=A4=A9=E6=B0=94=E5=B7=A5=E5=85=B7?= =?UTF-8?q?,=E8=B0=83=E6=95=B4tools?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- config/config_test.yaml | 4 + internal/biz/do/ctx.go | 10 +- internal/biz/do/handle.go | 5 +- internal/tools/calculator.go | 121 ------ internal/tools/manager.go | 18 +- internal/tools/{ => public}/konwledge_base.go | 2 +- .../tools/{ => public}/konwledge_base_test.go | 2 +- internal/tools/{ => public}/normal_chat.go | 2 +- internal/tools/public/weather.go | 345 ++++++++++++++++++ internal/tools/weather.go | 139 ------- .../tools/{ => zltx}/zltx_after_direct.go | 2 +- internal/tools/{ => zltx}/zltx_after_pre.go | 2 +- .../tools/{ => zltx}/zltx_order_detail.go | 2 +- .../tools/{ => zltx}/zltx_order_direct_log.go | 2 +- internal/tools/{ => zltx}/zltx_product.go | 2 +- internal/tools/{ => zltx}/zltx_statistics.go | 2 +- 16 files changed, 379 insertions(+), 281 deletions(-) delete mode 100644 internal/tools/calculator.go rename internal/tools/{ => public}/konwledge_base.go (99%) rename internal/tools/{ => public}/konwledge_base_test.go (97%) rename internal/tools/{ => public}/normal_chat.go (99%) create mode 100644 internal/tools/public/weather.go delete mode 100644 internal/tools/weather.go rename internal/tools/{ => zltx}/zltx_after_direct.go (99%) rename internal/tools/{ => zltx}/zltx_after_pre.go (99%) rename internal/tools/{ => zltx}/zltx_order_detail.go (99%) rename internal/tools/{ => zltx}/zltx_order_direct_log.go (99%) rename internal/tools/{ => zltx}/zltx_product.go (99%) rename internal/tools/{ => zltx}/zltx_statistics.go (99%) diff --git a/config/config_test.yaml b/config/config_test.yaml index 4ebccb8..85a7eb1 100644 --- a/config/config_test.yaml +++ b/config/config_test.yaml @@ -71,6 +71,10 @@ tools: zltxOrderAfterSaleResellerBatch: enabled: true base_url: "https://gateway.dev.cdlsxd.cn/zltx_api/admin/afterSales/reseller_pre_ai" + weather: + enabled: true + base_url: "https://restapi.amap.com/v3/weather/weatherInfo" + api_key: "12afbde5ab78cb7e575ff76bd0bdef2b" diff --git a/internal/biz/do/ctx.go b/internal/biz/do/ctx.go index d252c47..1c577ed 100644 --- a/internal/biz/do/ctx.go +++ b/internal/biz/do/ctx.go @@ -131,7 +131,8 @@ func (d *Do) loadSystemInfo(ctx context.Context, client *gateway.Client, require // 获取任务列表的辅助函数 func (d *Do) loadTaskList(ctx context.Context, client *gateway.Client, requireData *entitys.RequireData) error { if taskInfo := client.GetTasks(); len(taskInfo) == 0 { - tasks, err := d.GetTasks(requireData.Sys.SysID) + // 从数据库获取任务列表, 0 表示获取公共的任务 + tasks, err := d.GetTasks(requireData.Sys.SysID, 0) if err != nil { return err } @@ -140,6 +141,7 @@ func (d *Do) loadTaskList(ctx context.Context, client *gateway.Client, requireDa } else { requireData.Tasks = taskInfo } + return nil } @@ -242,12 +244,12 @@ func (d *Do) getSessionChatHis(requireData *entitys.RequireData) (his []model.Ai return } -func (d *Do) GetTasks(sysId int32) (tasks []model.AiTask, err error) { +func (d *Do) GetTasks(sysId ...int32) (tasks []model.AiTask, err error) { cond := builder.NewCond() - cond = cond.And(builder.Eq{"sys_id": sysId}) + //cond = cond.And(builder.Eq{"sys_id": sysId}) cond = cond.And(builder.IsNull{"delete_at"}) - cond = cond.And(builder.Eq{"status": 1}) + cond = cond.And(builder.Eq{"status": 1}.And(builder.In("sys_id", sysId))) _, err = d.taskImpl.GetListToStruct(&cond, nil, &tasks, "") return diff --git a/internal/biz/do/handle.go b/internal/biz/do/handle.go index 2e2a313..65eeeea 100644 --- a/internal/biz/do/handle.go +++ b/internal/biz/do/handle.go @@ -12,6 +12,7 @@ import ( "ai_scheduler/internal/pkg/l_request" "ai_scheduler/internal/pkg/mapstructure" "ai_scheduler/internal/tools" + "ai_scheduler/internal/tools/public" "ai_scheduler/internal/tools_bot" "context" "encoding/json" @@ -182,7 +183,7 @@ func (r *Handle) handleKnowle(ctx context.Context, requireData *entitys.RequireD return fmt.Errorf("tool not found: %s", configData.Tool) } - if knowledgeTool, ok := tool.(*tools.KnowledgeBaseTool); !ok { + if knowledgeTool, ok := tool.(*public.KnowledgeBaseTool); !ok { return fmt.Errorf("未找到知识库Tool: %s", configData.Tool) } else { host = knowledgeTool.GetConfig().BaseURL @@ -193,7 +194,7 @@ func (r *Handle) handleKnowle(ctx context.Context, requireData *entitys.RequireD // 知识库的session为空,请求知识库获取, 并绑定 if requireData.SessionInfo.KnowlegeSessionID == "" { // 请求知识库 - if sessionIdKnowledge, err = tools.GetKnowledgeBaseSession(host, requireData.Sys.KnowlegeBaseID, requireData.Sys.KnowlegeTenantKey); err != nil { + if sessionIdKnowledge, err = public.GetKnowledgeBaseSession(host, requireData.Sys.KnowlegeBaseID, requireData.Sys.KnowlegeTenantKey); err != nil { return } diff --git a/internal/tools/calculator.go b/internal/tools/calculator.go deleted file mode 100644 index fecde64..0000000 --- a/internal/tools/calculator.go +++ /dev/null @@ -1,121 +0,0 @@ -package tools - -import ( - "ai_scheduler/internal/entitys" - - "context" - "encoding/json" - "fmt" - "math" -) - -// CalculatorTool 计算器工具 -type CalculatorTool struct{} - -// NewCalculatorTool 创建计算器工具 -func NewCalculatorTool() *CalculatorTool { - return &CalculatorTool{} -} - -// Name 返回工具名称 -func (c *CalculatorTool) Name() string { - return "calculate" -} - -// Description 返回工具描述 -func (c *CalculatorTool) Description() string { - return "执行基本的数学运算,支持加减乘除和幂运算" -} - -// Definition 返回工具定义 -func (c *CalculatorTool) Definition() entitys.ToolDefinition { - return entitys.ToolDefinition{ - Type: "function", - Function: entitys.FunctionDef{ - Name: c.Name(), - Description: c.Description(), - Parameters: map[string]interface{}{ - "type": "object", - "properties": map[string]interface{}{ - "operation": map[string]interface{}{ - "type": "string", - "description": "运算类型", - "enum": []string{"add", "subtract", "multiply", "divide", "power"}, - }, - "a": map[string]interface{}{ - "type": "number", - "description": "第一个数字", - }, - "b": map[string]interface{}{ - "type": "number", - "description": "第二个数字", - }, - }, - "required": []string{"operation", "a", "b"}, - }, - }, - } -} - -// CalculateRequest 计算请求参数 -type CalculateRequest struct { - Operation string `json:"operation"` - A float64 `json:"a"` - B float64 `json:"b"` -} - -// CalculateResponse 计算响应 -type CalculateResponse struct { - Operation string `json:"operation"` - A float64 `json:"a"` - B float64 `json:"b"` - Result float64 `json:"result"` - Expression string `json:"expression"` -} - -// Execute 执行计算 -func (c *CalculatorTool) Execute(ctx context.Context, args json.RawMessage) (interface{}, error) { - var req CalculateRequest - if err := json.Unmarshal(args, &req); err != nil { - return nil, fmt.Errorf("invalid calculate request: %w", err) - } - - var result float64 - var expression string - - switch req.Operation { - case "add": - result = req.A + req.B - expression = fmt.Sprintf("%.2f + %.2f = %.2f", req.A, req.B, result) - case "subtract": - result = req.A - req.B - expression = fmt.Sprintf("%.2f - %.2f = %.2f", req.A, req.B, result) - case "multiply": - result = req.A * req.B - expression = fmt.Sprintf("%.2f × %.2f = %.2f", req.A, req.B, result) - case "divide": - if req.B == 0 { - return nil, fmt.Errorf("division by zero is not allowed") - } - result = req.A / req.B - expression = fmt.Sprintf("%.2f ÷ %.2f = %.2f", req.A, req.B, result) - case "power": - result = math.Pow(req.A, req.B) - expression = fmt.Sprintf("%.2f ^ %.2f = %.2f", req.A, req.B, result) - default: - return nil, fmt.Errorf("unsupported operation: %s", req.Operation) - } - - // 检查结果是否有效 - if math.IsNaN(result) || math.IsInf(result, 0) { - return nil, fmt.Errorf("calculation resulted in invalid number") - } - - return &CalculateResponse{ - Operation: req.Operation, - A: req.A, - B: req.B, - Result: result, - Expression: expression, - }, nil -} diff --git a/internal/tools/manager.go b/internal/tools/manager.go index 2ebbb69..87212b7 100644 --- a/internal/tools/manager.go +++ b/internal/tools/manager.go @@ -5,6 +5,7 @@ import ( "ai_scheduler/internal/data/constants" "ai_scheduler/internal/entitys" "ai_scheduler/internal/pkg/utils_ollama" + "ai_scheduler/internal/tools/public" zltxtool "ai_scheduler/internal/tools/zltx" "context" @@ -44,30 +45,30 @@ func NewManager(config *config.Config, llm *utils_ollama.Client) *Manager { // 注册直连天下订单详情工具 if config.Tools.ZltxOrderDetail.Enabled { - zltxOrderDetailTool := NewZltxOrderDetailTool(config.Tools.ZltxOrderDetail, m.llm) + zltxOrderDetailTool := zltxtool.NewZltxOrderDetailTool(config.Tools.ZltxOrderDetail, m.llm) m.tools[zltxOrderDetailTool.Name()] = zltxOrderDetailTool } //注册直连天下订单日志工具 if config.Tools.ZltxOrderDirectLog.Enabled { - zltxOrderLogTool := NewZltxOrderLogTool(config.Tools.ZltxOrderDirectLog) + zltxOrderLogTool := zltxtool.NewZltxOrderLogTool(config.Tools.ZltxOrderDirectLog) m.tools[zltxOrderLogTool.Name()] = zltxOrderLogTool } //注册直连天下商品工具 if config.Tools.ZltxProduct.Enabled { - zltxProductTool := NewZltxProductTool(config.Tools.ZltxProduct) + zltxProductTool := zltxtool.NewZltxProductTool(config.Tools.ZltxProduct) m.tools[zltxProductTool.Name()] = zltxProductTool } //注册直连天下订单统计工具 if config.Tools.ZltxOrderStatistics.Enabled { - zltxOrderStatisticsTool := NewZltxOrderStatisticsTool(config.Tools.ZltxOrderStatistics) + zltxOrderStatisticsTool := zltxtool.NewZltxOrderStatisticsTool(config.Tools.ZltxOrderStatistics) m.tools[zltxOrderStatisticsTool.Name()] = zltxOrderStatisticsTool } // 注册知识库工具 if config.Tools.Knowledge.Enabled { - knowledgeTool := NewKnowledgeBaseTool(config.Tools.Knowledge) + knowledgeTool := public.NewKnowledgeBaseTool(config.Tools.Knowledge) m.tools[knowledgeTool.Name()] = knowledgeTool } @@ -86,9 +87,14 @@ func NewManager(config *config.Config, llm *utils_ollama.Client) *Manager { zltxOrderAfterSaleResellerBatchTool := zltxtool.NewOrderAfterSaleResellerBatchTool(config.Tools.ZltxOrderAfterSaleResellerBatch) m.tools[zltxOrderAfterSaleResellerBatchTool.Name()] = zltxOrderAfterSaleResellerBatchTool } + // 注册天气工具 + if config.Tools.Weather.Enabled { + weatherTool := public.NewWeatherTool(config.Tools.Weather) + m.tools[weatherTool.Name()] = weatherTool + } // 普通对话 - chat := NewNormalChatTool(m.llm, config) + chat := public.NewNormalChatTool(m.llm, config) m.tools[chat.Name()] = chat return m diff --git a/internal/tools/konwledge_base.go b/internal/tools/public/konwledge_base.go similarity index 99% rename from internal/tools/konwledge_base.go rename to internal/tools/public/konwledge_base.go index dbbb5ce..505a349 100644 --- a/internal/tools/konwledge_base.go +++ b/internal/tools/public/konwledge_base.go @@ -1,4 +1,4 @@ -package tools +package public import ( "ai_scheduler/internal/config" diff --git a/internal/tools/konwledge_base_test.go b/internal/tools/public/konwledge_base_test.go similarity index 97% rename from internal/tools/konwledge_base_test.go rename to internal/tools/public/konwledge_base_test.go index de1ab3f..0838587 100644 --- a/internal/tools/konwledge_base_test.go +++ b/internal/tools/public/konwledge_base_test.go @@ -1,4 +1,4 @@ -package tools +package public import ( "testing" diff --git a/internal/tools/normal_chat.go b/internal/tools/public/normal_chat.go similarity index 99% rename from internal/tools/normal_chat.go rename to internal/tools/public/normal_chat.go index 56010fc..a4c96e0 100644 --- a/internal/tools/normal_chat.go +++ b/internal/tools/public/normal_chat.go @@ -1,4 +1,4 @@ -package tools +package public import ( "ai_scheduler/internal/config" diff --git a/internal/tools/public/weather.go b/internal/tools/public/weather.go new file mode 100644 index 0000000..0292fa4 --- /dev/null +++ b/internal/tools/public/weather.go @@ -0,0 +1,345 @@ +package public + +import ( + "ai_scheduler/internal/config" + "ai_scheduler/internal/entitys" + "ai_scheduler/internal/pkg/l_request" + "log" + "strconv" + + "context" + "encoding/json" + "fmt" + "time" +) + +// WeatherTool 天气查询工具 +type WeatherTool struct { + mockData bool + config config.ToolConfig +} + +// NewWeatherTool 创建天气工具 +func NewWeatherTool(config config.ToolConfig) *WeatherTool { + return &WeatherTool{config: config} +} + +// Name 返回工具名称 +func (w *WeatherTool) Name() string { + return "get_weather" +} + +// Description 返回工具描述 +func (w *WeatherTool) Description() string { + return "获取指定城市的天气信息" +} + +// Definition 返回工具定义 +func (w *WeatherTool) Definition() entitys.ToolDefinition { + return entitys.ToolDefinition{ + Type: "function", + Function: entitys.FunctionDef{ + Name: w.Name(), + Description: w.Description(), + Parameters: map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "city": map[string]interface{}{ + "type": "string", + "description": "城市名称,如:北京、上海、广州", + }, + "unit": map[string]interface{}{ + "type": "string", + "description": "温度单位,celsius(摄氏度)或fahrenheit(华氏度)", + "enum": []string{"celsius", "fahrenheit"}, + "default": "celsius", + }, + "extensions": map[string]interface{}{ + "type": "string", + "description": "扩展参数,base/all base:返回实况天气 all:返回预报天气", + "enum": []string{"base", "all"}, + "default": "base", + }, + }, + "required": []string{"city"}, + }, + }, + } +} + +// WeatherRequest 天气请求参数 +type WeatherRequest struct { + City string `json:"city"` + Extensions string `json:"extensions"` // 扩展参数,base/all base:返回实况天气 all:返回预报天气 + Unit string `json:"unit,omitempty"` +} + +// WeatherResponse 天气响应 +type WeatherResponse struct { + City string `json:"city"` + Unit string `json:"unit"` + Timestamp string `json:"timestamp"` + LiveWeather *LiveWeather `json:"live_weather,omitempty"` // 实时天气 + Forecasts []ForecastWeather `json:"forecasts,omitempty"` // 预报天气 +} + +// ForecastWeather 预报天气 +type ForecastWeather struct { + Date string `json:"date"` + Week string `json:"week"` + DayWeather string `json:"day_weather"` + NightWeather string `json:"night_weather"` + DayTemp float64 `json:"day_temp"` + NightTemp float64 `json:"night_temp"` + DayWind string `json:"day_wind"` + NightWind string `json:"night_wind"` + DayWindPower string `json:"day_wind_power"` + NightWindPower string `json:"night_wind_power"` +} + +// LiveWeather 实时天气 +type LiveWeather struct { + Temperature float64 `json:"temperature"` + Condition string `json:"condition"` + Humidity int `json:"humidity"` + WindSpeed float64 `json:"wind_speed"` + WindDirection string `json:"wind_direction"` +} + +// Execute 执行天气查询 +func (w *WeatherTool) Execute(ctx context.Context, requireData *entitys.RequireData) error { + var req WeatherRequest + if err := json.Unmarshal([]byte(requireData.Match.Parameters), &req); err != nil { + return fmt.Errorf("invalid weather request: %w", err) + } + + if req.City == "" { + return fmt.Errorf("city is required") + } + + if req.Unit == "" { + req.Unit = "celsius" + } + + // 设置默认获取实时天气信息 + if req.Extensions == "" { + req.Extensions = "base" + } + + // 这里可以集成真实的天气API + responseMsg, err := w.getRealWeather(req) + if err != nil { + return fmt.Errorf("failed to get real weather: %w", err) + } + + // 根据 extensions 参数返回不同的天气信息 + if req.Extensions == "base" { + entitys.ResText(requireData.Ch, "", fmt.Sprintf("%s实时天气:%s,温度:%.1f℃,湿度:%d%%,风速:%.1fkm/h,风向:%s", + req.City, + responseMsg.LiveWeather.Condition, + responseMsg.LiveWeather.Temperature, + responseMsg.LiveWeather.Humidity, + responseMsg.LiveWeather.WindSpeed, + responseMsg.LiveWeather.WindDirection)) + } else { + + rspStr := fmt.Sprintf("%s天气预报:\n", req.City) + for _, forecast := range responseMsg.Forecasts { + rspStr += fmt.Sprintf("%s 温度:%.1f℃/%.1f℃ 风力:%s %s\n", + forecast.Date, forecast.DayTemp, forecast.NightTemp, forecast.DayWind, forecast.NightWind) + } + + entitys.ResText(requireData.Ch, "", rspStr) + } + return nil +} + +//// getMockWeather 获取模拟天气数据 +//func (w *WeatherTool) getMockWeather(city, unit string) *WeatherResponse { +// rand.Seed(time.Now().UnixNano()) +// +// // 模拟不同城市的基础温度 +// baseTemp := map[string]float64{ +// "北京": 15.0, +// "上海": 18.0, +// "广州": 25.0, +// "深圳": 26.0, +// "杭州": 17.0, +// "成都": 16.0, +// } +// +// temp := baseTemp[city] +// if temp == 0 { +// temp = 20.0 // 默认温度 +// } +// +// // 添加随机变化 +// temp += (rand.Float64() - 0.5) * 10 +// +// // 转换温度单位 +// if unit == "fahrenheit" { +// temp = temp*9/5 + 32 +// } +// +// conditions := []string{"晴朗", "多云", "阴天", "小雨", "中雨"} +// condition := conditions[rand.Intn(len(conditions))] +// +// return &WeatherResponse{ +// City: city, +// Temperature: float64(int(temp*10)) / 10, // 保留一位小数 +// Unit: unit, +// Condition: condition, +// Humidity: rand.Intn(40) + 40, // 40-80% +// WindSpeed: float64(rand.Intn(20)) + 1.0, +// Timestamp: time.Now().Format("2006-01-02 15:04:05"), +// } +//} + +// getRealWeather 调用高德天气API +func (w *WeatherTool) getRealWeather(request WeatherRequest) (*WeatherResponse, error) { + // 构建请求URL + req := l_request.Request{ + Url: w.config.BaseURL, + Headers: map[string]string{}, + Params: map[string]string{ + "city": request.City, // 城市名称 + "key": w.config.APIKey, // API密钥 + // extensions: 基础天气数据 可选值:base/all base:返回实况天气 all:返回预报天气 + "extensions": request.Extensions, // 基础天气数据 + "output": "JSON", // JSON格式返回 + }, + Method: "GET", + } + res, err := req.Send() + if err != nil { + return nil, err + } + + // 解析API响应 + var apiResp struct { + Status string `json:"status"` + Count string `json:"count"` + Info string `json:"info"` + Infocode string `json:"infocode"` + + // 预报天气信息数据 + Forecasts []struct { + City string `json:"city"` + Adcode string `json:"adcode"` + Province string `json:"province"` + Reporttime string `json:"reporttime"` + Casts []struct { + Date string `json:"date"` + Week string `json:"week"` + Dayweather string `json:"dayweather"` + Nightweather string `json:"nightweather"` + Daytemp string `json:"daytemp"` + Nighttemp string `json:"nighttemp"` + Daywind string `json:"daywind"` + Nightwind string `json:"nightwind"` + Daypower string `json:"daypower"` + Nightpower string `json:"nightpower"` + DaytempFloat string `json:"daytemp_float"` + NighttempFloat string `json:"nighttemp_float"` + } `json:"casts"` + } `json:"forecasts"` + // 实况天气信息数据 + Lives []struct { + Province string `json:"province"` + City string `json:"city"` + Adcode string `json:"adcode"` + Weather string `json:"weather"` + Temperature string `json:"temperature"` + Winddirection string `json:"winddirection"` + Windpower string `json:"windpower"` + Humidity string `json:"humidity"` + Reporttime string `json:"reporttime"` + TemperatureFloat string `json:"temperature_float"` + HumidityFloat string `json:"humidity_float"` + } `json:"lives"` + } + + log.Printf("weather API response: %s", string(res.Content)) + + if err = json.Unmarshal(res.Content, &apiResp); err != nil { + return nil, fmt.Errorf("parse weather API response failed: %w", err) + } + + // 检查API返回状态 + if apiResp.Status != "1" { + return nil, fmt.Errorf("weather API returned error: %s, info: %s", apiResp.Status, apiResp.Info) + } + + // 获取城市名称 + cityName := "" + if len(apiResp.Lives) > 0 { + cityName = apiResp.Lives[0].City + } else if len(apiResp.Forecasts) > 0 { + cityName = apiResp.Forecasts[0].City + } else { + return nil, fmt.Errorf("no weather data found") + } + + // 构建响应 + response := &WeatherResponse{ + City: cityName, + Unit: request.Unit, + Timestamp: time.Now().Format("2006-01-02 15:04:05"), + } + // 处理实时天气 + if len(apiResp.Lives) > 0 { + liveData := apiResp.Lives[0] + + // 转换温度 + temp, _ := strconv.ParseFloat(liveData.Temperature, 64) + if request.Unit == "fahrenheit" { + temp = temp*9/5 + 32 + } + + // 转换湿度和风速 + humidity, _ := strconv.Atoi(liveData.Humidity) + windSpeed, _ := strconv.ParseFloat(liveData.Windpower, 64) + + response.LiveWeather = &LiveWeather{ + Temperature: temp, + Condition: liveData.Weather, + Humidity: humidity, + WindSpeed: windSpeed, + WindDirection: liveData.Winddirection, + } + } + + // 处理预报天气 + if len(apiResp.Forecasts) > 0 && len(apiResp.Forecasts[0].Casts) > 0 { + response.Forecasts = make([]ForecastWeather, 0, len(apiResp.Forecasts[0].Casts)) + + for _, cast := range apiResp.Forecasts[0].Casts { + // 转换温度 + dayTemp, _ := strconv.ParseFloat(cast.Daytemp, 64) + nightTemp, _ := strconv.ParseFloat(cast.Nighttemp, 64) + + if request.Unit == "fahrenheit" { + dayTemp = dayTemp*9/5 + 32 + nightTemp = nightTemp*9/5 + 32 + } + + forecast := ForecastWeather{ + Date: cast.Date, + Week: cast.Week, + DayWeather: cast.Dayweather, + NightWeather: cast.Nightweather, + DayTemp: dayTemp, + NightTemp: nightTemp, + DayWind: cast.Daywind, + NightWind: cast.Nightwind, + DayWindPower: cast.Daypower, + NightWindPower: cast.Nightpower, + } + + response.Forecasts = append(response.Forecasts, forecast) + } + } + + return response, nil + +} diff --git a/internal/tools/weather.go b/internal/tools/weather.go deleted file mode 100644 index 744216f..0000000 --- a/internal/tools/weather.go +++ /dev/null @@ -1,139 +0,0 @@ -package tools - -import ( - "ai_scheduler/internal/entitys" - - "context" - "encoding/json" - "fmt" - "math/rand" - "time" -) - -// WeatherTool 天气查询工具 -type WeatherTool struct { - mockData bool -} - -// NewWeatherTool 创建天气工具 -func NewWeatherTool() *WeatherTool { - return &WeatherTool{} -} - -// Name 返回工具名称 -func (w *WeatherTool) Name() string { - return "get_weather" -} - -// Description 返回工具描述 -func (w *WeatherTool) Description() string { - return "获取指定城市的天气信息" -} - -// Definition 返回工具定义 -func (w *WeatherTool) Definition() entitys.ToolDefinition { - return entitys.ToolDefinition{ - Type: "function", - Function: entitys.FunctionDef{ - Name: w.Name(), - Description: w.Description(), - Parameters: map[string]interface{}{ - "type": "object", - "properties": map[string]interface{}{ - "city": map[string]interface{}{ - "type": "string", - "description": "城市名称,如:北京、上海、广州", - }, - "unit": map[string]interface{}{ - "type": "string", - "description": "温度单位,celsius(摄氏度)或fahrenheit(华氏度)", - "enum": []string{"celsius", "fahrenheit"}, - "default": "celsius", - }, - }, - "required": []string{"city"}, - }, - }, - } -} - -// WeatherRequest 天气请求参数 -type WeatherRequest struct { - City string `json:"city"` - Unit string `json:"unit,omitempty"` -} - -// WeatherResponse 天气响应 -type WeatherResponse struct { - City string `json:"city"` - Temperature float64 `json:"temperature"` - Unit string `json:"unit"` - Condition string `json:"condition"` - Humidity int `json:"humidity"` - WindSpeed float64 `json:"wind_speed"` - Timestamp string `json:"timestamp"` -} - -// Execute 执行天气查询 -func (w *WeatherTool) Execute(ctx context.Context, args json.RawMessage) (interface{}, error) { - var req WeatherRequest - if err := json.Unmarshal(args, &req); err != nil { - return nil, fmt.Errorf("invalid weather request: %w", err) - } - - if req.City == "" { - return nil, fmt.Errorf("city is required") - } - - if req.Unit == "" { - req.Unit = "celsius" - } - - if w.mockData { - return w.getMockWeather(req.City, req.Unit), nil - } - - // 这里可以集成真实的天气API - return w.getMockWeather(req.City, req.Unit), nil -} - -// getMockWeather 获取模拟天气数据 -func (w *WeatherTool) getMockWeather(city, unit string) *WeatherResponse { - rand.Seed(time.Now().UnixNano()) - - // 模拟不同城市的基础温度 - baseTemp := map[string]float64{ - "北京": 15.0, - "上海": 18.0, - "广州": 25.0, - "深圳": 26.0, - "杭州": 17.0, - "成都": 16.0, - } - - temp := baseTemp[city] - if temp == 0 { - temp = 20.0 // 默认温度 - } - - // 添加随机变化 - temp += (rand.Float64() - 0.5) * 10 - - // 转换温度单位 - if unit == "fahrenheit" { - temp = temp*9/5 + 32 - } - - conditions := []string{"晴朗", "多云", "阴天", "小雨", "中雨"} - condition := conditions[rand.Intn(len(conditions))] - - return &WeatherResponse{ - City: city, - Temperature: float64(int(temp*10)) / 10, // 保留一位小数 - Unit: unit, - Condition: condition, - Humidity: rand.Intn(40) + 40, // 40-80% - WindSpeed: float64(rand.Intn(20)) + 1.0, - Timestamp: time.Now().Format("2006-01-02 15:04:05"), - } -} diff --git a/internal/tools/zltx_after_direct.go b/internal/tools/zltx/zltx_after_direct.go similarity index 99% rename from internal/tools/zltx_after_direct.go rename to internal/tools/zltx/zltx_after_direct.go index 5f1e20e..01a6282 100644 --- a/internal/tools/zltx_after_direct.go +++ b/internal/tools/zltx/zltx_after_direct.go @@ -1,4 +1,4 @@ -package tools +package zltx import ( "ai_scheduler/internal/config" diff --git a/internal/tools/zltx_after_pre.go b/internal/tools/zltx/zltx_after_pre.go similarity index 99% rename from internal/tools/zltx_after_pre.go rename to internal/tools/zltx/zltx_after_pre.go index b44812b..5bbf809 100644 --- a/internal/tools/zltx_after_pre.go +++ b/internal/tools/zltx/zltx_after_pre.go @@ -1,4 +1,4 @@ -package tools +package zltx import ( "ai_scheduler/internal/config" diff --git a/internal/tools/zltx_order_detail.go b/internal/tools/zltx/zltx_order_detail.go similarity index 99% rename from internal/tools/zltx_order_detail.go rename to internal/tools/zltx/zltx_order_detail.go index f253f59..7dae9f1 100644 --- a/internal/tools/zltx_order_detail.go +++ b/internal/tools/zltx/zltx_order_detail.go @@ -1,4 +1,4 @@ -package tools +package zltx import ( "ai_scheduler/internal/config" diff --git a/internal/tools/zltx_order_direct_log.go b/internal/tools/zltx/zltx_order_direct_log.go similarity index 99% rename from internal/tools/zltx_order_direct_log.go rename to internal/tools/zltx/zltx_order_direct_log.go index bf82408..b1b1483 100644 --- a/internal/tools/zltx_order_direct_log.go +++ b/internal/tools/zltx/zltx_order_direct_log.go @@ -1,4 +1,4 @@ -package tools +package zltx import ( "ai_scheduler/internal/config" diff --git a/internal/tools/zltx_product.go b/internal/tools/zltx/zltx_product.go similarity index 99% rename from internal/tools/zltx_product.go rename to internal/tools/zltx/zltx_product.go index 61b8bb1..de24236 100644 --- a/internal/tools/zltx_product.go +++ b/internal/tools/zltx/zltx_product.go @@ -1,4 +1,4 @@ -package tools +package zltx import ( "ai_scheduler/internal/config" diff --git a/internal/tools/zltx_statistics.go b/internal/tools/zltx/zltx_statistics.go similarity index 99% rename from internal/tools/zltx_statistics.go rename to internal/tools/zltx/zltx_statistics.go index f53d4b1..4a051a8 100644 --- a/internal/tools/zltx_statistics.go +++ b/internal/tools/zltx/zltx_statistics.go @@ -1,4 +1,4 @@ -package tools +package zltx import ( "ai_scheduler/internal/config" From 25b55ec1f8cfb071c67eba32c7fd1def53a57ffb Mon Sep 17 00:00:00 2001 From: renzhiyuan <465386466@qq.com> Date: Thu, 11 Dec 2025 17:12:18 +0800 Subject: [PATCH 30/66] =?UTF-8?q?fix:=E4=BF=AE=E5=A4=8D=E6=8A=A5=E9=94=99?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- config/config.yaml | 2 +- internal/biz/ding_talk_bot.go | 5 +- internal/biz/do/handle.go | 1 + internal/biz/handle/dingtalk/auth.go | 69 +++++++++++ internal/biz/handle/dingtalk/dept.go | 35 ++++++ internal/biz/handle/dingtalk/option.go | 21 ++++ internal/biz/handle/dingtalk/types.go | 59 ++++++++++ internal/biz/handle/dingtalk/user.go | 140 +++++++++++++++++++++++ internal/biz/router.go | 1 + internal/data/constants/bot.go | 2 + internal/data/constants/const.go | 2 + internal/data/constants/dingtalk.go | 38 ++++++ internal/data/constants/user_test.go | 91 +++++++++++++++ internal/data/impl/bot_user.go | 27 +++++ internal/data/model/ai_bot_config.gen.go | 7 +- internal/data/model/ai_bot_user.gen.go | 33 ++++++ internal/entitys/bot.go | 9 +- internal/entitys/dingtalk.go | 17 +++ internal/pkg/func.go | 29 +++++ internal/server/ding_talk_bot.go | 4 +- internal/services/dtalk_bot.go | 2 +- 21 files changed, 579 insertions(+), 15 deletions(-) create mode 100644 internal/biz/handle/dingtalk/auth.go create mode 100644 internal/biz/handle/dingtalk/dept.go create mode 100644 internal/biz/handle/dingtalk/option.go create mode 100644 internal/biz/handle/dingtalk/types.go create mode 100644 internal/biz/handle/dingtalk/user.go create mode 100644 internal/data/constants/dingtalk.go create mode 100644 internal/data/constants/user_test.go create mode 100644 internal/data/impl/bot_user.go create mode 100644 internal/data/model/ai_bot_user.gen.go create mode 100644 internal/entitys/dingtalk.go diff --git a/config/config.yaml b/config/config.yaml index c1e8c08..4b568b4 100644 --- a/config/config.yaml +++ b/config/config.yaml @@ -23,7 +23,7 @@ redis: host: 47.97.27.195:6379 type: node pass: lansexiongdi@666 - key: report-api-test + key: report-api pollSize: 5 #连接池大小,不配置,或配置为0表示不启用连接池 minIdleConns: 2 #最小空闲连接数 maxIdleTime: 30 #每个连接最大空闲时间,如果超过了这个时间会被关闭 diff --git a/internal/biz/ding_talk_bot.go b/internal/biz/ding_talk_bot.go index f2abea1..74f2a47 100644 --- a/internal/biz/ding_talk_bot.go +++ b/internal/biz/ding_talk_bot.go @@ -6,8 +6,8 @@ import ( "ai_scheduler/internal/data/impl" "ai_scheduler/internal/data/model" "ai_scheduler/internal/entitys" - "ai_scheduler/internal/pkg/mapstructure" "context" + "encoding/json" "fmt" "github.com/gofiber/fiber/v2/log" @@ -47,10 +47,11 @@ func (d *DingTalkBotBiz) GetDingTalkBotCfgList() (dingBotList []entitys.DingTalk err = d.botConfigImpl.GetRangeToMapStruct(&cond, &botConfig) for _, v := range botConfig { var config entitys.DingTalkBot - err = mapstructure.Decode(v, &config) + err = json.Unmarshal([]byte(v.BotConfig), &config) if err != nil { d.log.Info("初始化“%s”失败:%s", v.BotName, err.Error()) } + config.BotIndex = v.BotIndex dingBotList = append(dingBotList, config) } return diff --git a/internal/biz/do/handle.go b/internal/biz/do/handle.go index 3a5ffac..169f025 100644 --- a/internal/biz/do/handle.go +++ b/internal/biz/do/handle.go @@ -48,6 +48,7 @@ func NewHandle( func (r *Handle) Recognize(ctx context.Context, rec *entitys.Recognize, promptProcessor PromptOption) (match entitys.Match, err error) { entitys.ResLog(rec.Ch, "recognize_start", "准备意图识别") + prompt, err := promptProcessor.CreatePrompt(ctx, rec) //意图识别 recognizeMsg, err := r.Ollama.IntentRecognize(ctx, &entitys.ToolSelect{ diff --git a/internal/biz/handle/dingtalk/auth.go b/internal/biz/handle/dingtalk/auth.go new file mode 100644 index 0000000..8ddd0c1 --- /dev/null +++ b/internal/biz/handle/dingtalk/auth.go @@ -0,0 +1,69 @@ +package dingtalk + +import ( + "ai_scheduler/internal/config" + "ai_scheduler/internal/data/constants" + "ai_scheduler/internal/pkg/l_request" + "context" + "encoding/json" + "errors" + "net/http" + "time" + + "github.com/redis/go-redis/v9" +) + +type Auth struct { + redis *redis.Client + cfg *config.Config +} + +func NewAuth(cfg *config.Config, redis *redis.Client) *Auth { + return &Auth{ + redis: redis, + cfg: cfg, + } +} + +func (a *Auth) GetAccessToken(ctx context.Context, clientId string, clientSecret string) (accessToken string, err error) { + if clientId == "" { + return "", errors.New("clientId is empty") + } + token := a.redis.Get(ctx, a.getKey(clientId)).String() + if token == "" { + authInfo, _err := a.getNewAccessToken(ctx, clientId, clientSecret) + if _err != nil { + return "", _err + } + a.redis.SetEx(ctx, a.getKey(clientId), authInfo.AccessToken, time.Duration(authInfo.ExpiresIn-3600)*time.Second) + accessToken = authInfo.AccessToken + } + return +} + +func (a *Auth) getKey(clientId string) string { + return a.cfg.Redis.Key + ":" + constants.DingTalkAuthBaseKeyPrefix + ":" + clientId +} + +func (a *Auth) getNewAccessToken(ctx context.Context, clientId string, clientSecret string) (auth AuthInfo, err error) { + if clientId == "" || clientSecret == "" { + err = errors.New("clientId or clientSecret is empty") + return + } + + req := l_request.Request{ + Method: http.MethodPost, + Url: constants.GetDingTalkRequestUrl(constants.RequestUrlGetAccessToken, nil), + Data: map[string]string{ + "appkey": clientId, + "appsecret": clientSecret, + }, + } + res, err := req.Send() + if err != nil { + return + } + err = json.Unmarshal(res.Content, &auth) + + return +} diff --git a/internal/biz/handle/dingtalk/dept.go b/internal/biz/handle/dingtalk/dept.go new file mode 100644 index 0000000..14327c9 --- /dev/null +++ b/internal/biz/handle/dingtalk/dept.go @@ -0,0 +1,35 @@ +package dingtalk + +import ( + "ai_scheduler/internal/data/impl" + "ai_scheduler/internal/entitys" + "database/sql" + "errors" +) + +type Dept struct { + dingUserImpl *impl.BotUserImpl +} + +func NewDept(dingUserImpl *impl.BotUserImpl) *User { + return &User{ + dingUserImpl: dingUserImpl, + } +} + +func (u *User) GetDeptInfo(userId string) (userInfo *entitys.DingTalkUserInfo, err error) { + if len(userId) == 0 { + return + } + user, err := u.dingUserImpl.GetByStaffId(userId) + if err != nil { + if !errors.Is(err, sql.ErrNoRows) { + return + } + } + //如果没有找到,则新增 + if user == nil { + + } + return +} diff --git a/internal/biz/handle/dingtalk/option.go b/internal/biz/handle/dingtalk/option.go new file mode 100644 index 0000000..fb473c7 --- /dev/null +++ b/internal/biz/handle/dingtalk/option.go @@ -0,0 +1,21 @@ +package dingtalk + +import "ai_scheduler/internal/data/model" + +type Bot struct { + id int + botConfig *model.AiBotConfig +} +type BotOption func(*Bot) + +func WithId(id int) BotOption { + return func(b *Bot) { + b.id = id + } +} + +func WithBootConfig(BotConfig *model.AiBotConfig) BotOption { + return func(bot *Bot) { + bot.botConfig = BotConfig + } +} diff --git a/internal/biz/handle/dingtalk/types.go b/internal/biz/handle/dingtalk/types.go new file mode 100644 index 0000000..1e9fe76 --- /dev/null +++ b/internal/biz/handle/dingtalk/types.go @@ -0,0 +1,59 @@ +package dingtalk + +type AuthInfo struct { + AccessToken string `json:"accessToken"` + ExpiresIn int64 `json:"expiresIn"` +} + +type UserInfoResResult struct { + Errcode string `json:"errcode"` + Result Result `json:"result"` + Errmsg string `json:"errmsg"` +} +type Result struct { + Extension string `json:"extension"` + Unionid string `json:"unionid"` + Boss string `json:"boss"` + RoleList struct { + GroupName string `json:"group_name"` + Name string `json:"name"` + Id string `json:"id"` + } `json:"role_list"` + ExclusiveAccount bool `json:"exclusive_account"` + ManagerUserid string `json:"manager_userid"` + Admin string `json:"admin"` + Remark string `json:"remark"` + Title string `json:"title"` + HiredDate int `json:"hired_date"` + Userid string `json:"userid"` + WorkPlace string `json:"work_place"` + DeptOrderList struct { + DeptId string `json:"dept_id"` + Order string `json:"order"` + } `json:"dept_order_list"` + RealAuthed string `json:"real_authed"` + DeptIdList string `json:"dept_id_list"` + JobNumber string `json:"job_number"` + Email string `json:"email"` + LeaderInDept struct { + Leader string `json:"leader"` + DeptId string `json:"dept_id"` + } `json:"leader_in_dept"` + Mobile string `json:"mobile"` + Active string `json:"active"` + OrgEmail string `json:"org_email"` + Telephone string `json:"telephone"` + Avatar string `json:"avatar"` + HideMobile string `json:"hide_mobile"` + Senior string `json:"senior"` + Name string `json:"name"` + UnionEmpExt struct { + UnionEmpMapList struct { + Userid string `json:"userid"` + CorpId string `json:"corp_id"` + } `json:"union_emp_map_list"` + Userid string `json:"userid"` + CorpId string `json:"corp_id"` + } `json:"union_emp_ext"` + StateCode string `json:"state_code"` +} diff --git a/internal/biz/handle/dingtalk/user.go b/internal/biz/handle/dingtalk/user.go new file mode 100644 index 0000000..7d213c2 --- /dev/null +++ b/internal/biz/handle/dingtalk/user.go @@ -0,0 +1,140 @@ +package dingtalk + +import ( + "ai_scheduler/internal/data/constants" + "ai_scheduler/internal/data/impl" + "ai_scheduler/internal/data/model" + "ai_scheduler/internal/entitys" + "ai_scheduler/internal/pkg" + "ai_scheduler/internal/pkg/l_request" + "context" + "database/sql" + "encoding/json" + "errors" + "fmt" + "net/http" + "strings" + "time" + + "github.com/gofiber/fiber/v2/log" + "xorm.io/builder" +) + +type User struct { + dingUserImpl *impl.BotUserImpl + botConfigImpl *impl.BotConfigImpl + auth *Auth + logger log.Logger +} + +func NewUser( + dingUserImpl *impl.BotUserImpl, + botConfig *impl.BotConfigImpl, + auth *Auth, + logger log.Logger, +) *User { + return &User{ + dingUserImpl: dingUserImpl, + botConfigImpl: botConfig, + logger: logger, + auth: auth, + } +} + +func (u *User) GetUserInfo(ctx context.Context, staffId string, botOption ...BotOption) (userInfo *entitys.DingTalkUserInfo, err error) { + if len(staffId) == 0 { + return + } + user, err := u.dingUserImpl.GetByStaffId(staffId) + if err != nil { + if !errors.Is(err, sql.ErrNoRows) { + return + } + } + //如果没有找到,则新增 + if user == nil { + DingUserInfo, _err := u.GetUserInfoFromDingTalkWithBot(ctx, staffId, botOption...) + if _err != nil { + return nil, _err + } + dingUserDo := &model.AiBotUser{ + StaffID: DingUserInfo.Userid, + Name: DingUserInfo.Name, + Title: DingUserInfo.Title, + Extension: DingUserInfo.Extension, + DeptIDList: DingUserInfo.DeptIdList, + IsBoss: int32(pkg.Ter(DingUserInfo.Boss == "true", constants.IsBossTrue, constants.IsBossFalse)), + IsSenior: int32(pkg.Ter(DingUserInfo.Boss == "true", constants.IsSeniorTrue, constants.IsSeniorFalse)), + HiredDate: time.Unix(int64(DingUserInfo.HiredDate), 0), + } + //deptIdList, _err := pkg.StringToSlice(DingUserInfo.DeptIdList) + //if err != nil { + // return nil, _err + //} + dingUserDo.DeptIDList = strings.Trim(dingUserDo.DeptIDList, "[]") + } + return +} + +func (u *User) GetUserInfoFromDingTalkWithBot(ctx context.Context, staffId string, botOption ...BotOption) (userInfo Result, err error) { + botInfo := &Bot{} + for _, option := range botOption { + option(botInfo) + } + + if botInfo.id == 0 && botInfo.botConfig == nil { + err = errors.New("botInfo is nil") + return + } + if botInfo.botConfig == nil { + cond := builder.NewCond() + cond = cond.And(builder.Eq{"bot_id": botInfo.id}) + err = u.botConfigImpl.GetOneBySearchToStrut(&cond, botInfo.botConfig) + if err != nil { + return + } + + } + var config entitys.DingTalkBot + err = json.Unmarshal([]byte(botInfo.botConfig.BotConfig), &config) + if err != nil { + log.Infof("初始化“%s”失败:%s", botInfo.botConfig.BotName, err.Error()) + return + } + token, err := u.auth.GetAccessToken(ctx, config.ClientId, config.ClientSecret) + if err != nil { + return + } + + return u.GetUserInfoFromDingTalk(ctx, token, staffId) +} + +func (u *User) GetUserInfoFromDingTalk(ctx context.Context, token string, staffId string) (user Result, err error) { + if token == "" && staffId == "" { + err = errors.New("获取钉钉用户信息的必要参数不足") + return + } + + req := l_request.Request{ + Method: http.MethodPost, + Url: constants.GetDingTalkRequestUrl(constants.RequestUrlGetUserGet, map[string]string{ + "access_token": token, + }), + Data: map[string]string{ + "userid": staffId, + }, + } + res, err := req.Send() + if err != nil { + return + } + var userInfoResResult UserInfoResResult + err = json.Unmarshal(res.Content, &userInfoResResult) + if err != nil { + return + } + if userInfoResResult.Errcode != "0" { + fmt.Errorf("钉钉请求报错:%s", userInfoResResult.Errmsg) + } + return userInfoResResult.Result, err +} diff --git a/internal/biz/router.go b/internal/biz/router.go index c164c31..b5add80 100644 --- a/internal/biz/router.go +++ b/internal/biz/router.go @@ -65,6 +65,7 @@ func (r *AiRouterBiz) RouteWithSocket(client *gateway.Client, req *entitys.ChatS return } //意图识别 + requireData.Match, err = r.handle.Recognize(ctx, &rec, &do.WithSys{}) if err != nil { log.Errorf("意图识别失败: %s", err.Error()) diff --git a/internal/data/constants/bot.go b/internal/data/constants/bot.go index 2b24cfe..446519a 100644 --- a/internal/data/constants/bot.go +++ b/internal/data/constants/bot.go @@ -31,3 +31,5 @@ type BotType int const ( BotTypeDingTalk BotType = 1 // 系统的bug/优化建议 ) + +const DingTalkAuthBaseKeyPrefix = "dingTalk_auth" diff --git a/internal/data/constants/const.go b/internal/data/constants/const.go index f06e906..b0731ad 100644 --- a/internal/data/constants/const.go +++ b/internal/data/constants/const.go @@ -30,3 +30,5 @@ var UseFulMap = map[UseFul]string{ UseFulNotUnclear: "回答不明确", UseFulNotError: "理解错误", } + +type BaseBool int32 diff --git a/internal/data/constants/dingtalk.go b/internal/data/constants/dingtalk.go new file mode 100644 index 0000000..dd730cc --- /dev/null +++ b/internal/data/constants/dingtalk.go @@ -0,0 +1,38 @@ +package constants + +import "net/url" + +const DingTalkBseUrl = "https://oapi.dingtalk.com" + +type RequestUrl string + +const ( + RequestUrlGetAccessToken RequestUrl = "/v1.0/oauth2/accessToken" + RequestUrlGetUserGet RequestUrl = "/topapi/v2/user/get" +) + +func GetDingTalkRequestUrl(path RequestUrl, query map[string]string) string { + u, _ := url.Parse(DingTalkBseUrl + string(path)) + q := u.Query() + for key, val := range query { + q.Add(key, val) + } + u.RawQuery = q.Encode() + return u.String() +} + +// IsBoss 是否是老板 +type IsBoss int + +const ( + IsBossTrue IsBoss = 1 + IsBossFalse IsBoss = 0 +) + +// IsSenior 是否是老板 +type IsSenior int + +const ( + IsSeniorTrue IsSenior = 1 + IsSeniorFalse IsSenior = 0 +) diff --git a/internal/data/constants/user_test.go b/internal/data/constants/user_test.go new file mode 100644 index 0000000..eee1b17 --- /dev/null +++ b/internal/data/constants/user_test.go @@ -0,0 +1,91 @@ +package constants + +import ( + "ai_scheduler/internal/biz" + "ai_scheduler/internal/biz/do" + "ai_scheduler/internal/biz/llm_service" + "ai_scheduler/internal/data/impl" + "ai_scheduler/internal/entitys" + "ai_scheduler/internal/gateway" + "ai_scheduler/internal/pkg/dingtalk" + "ai_scheduler/internal/pkg/utils_ollama" + "ai_scheduler/internal/pkg/utils_vllm" + "ai_scheduler/internal/server" + "ai_scheduler/internal/services" + "ai_scheduler/internal/tools" + "ai_scheduler/internal/tools_bot" + "ai_scheduler/utils" + "os" + "testing" +) + +const + +func TestMain(m *testing.M) { + bootstrap := initialize.LoadConfigWithTest() + businessLogger := log2.NewBusinessLogger(bootstrap.Logs, id, Name, Version) + helper := pkg.NewLogHelper(businessLogger, bootstrap) + + db, cleanup := utils.NewGormDb(configConfig) + sysImpl := impl.NewSysImpl(db) + taskImpl := impl.NewTaskImpl(db) + chatHisImpl := impl.NewChatHisImpl(db) + doDo := do.NewDo(sysImpl, taskImpl, chatHisImpl, configConfig) + client, cleanup2, err := utils_ollama.NewClient(configConfig) + if err != nil { + cleanup() + return nil, nil, err + } + utils_vllmClient, cleanup3, err := utils_vllm.NewClient(configConfig) + if err != nil { + cleanup2() + cleanup() + return nil, nil, err + } + ollamaService := llm_service.NewOllamaGenerate(client, utils_vllmClient, configConfig, chatHisImpl) + manager := tools.NewManager(configConfig, client) + sessionImpl := impl.NewSessionImpl(db) + botTool := tools_bot.NewBotTool(configConfig, client, sessionImpl) + handle := do.NewHandle(ollamaService, manager, configConfig, sessionImpl, botTool) + aiRouterBiz := biz.NewAiRouterBiz(doDo, handle) + chatHistoryBiz := biz.NewChatHistoryBiz(chatHisImpl, taskImpl) + gatewayGateway := gateway.NewGateway() + chatService := services.NewChatService(aiRouterBiz, chatHistoryBiz, gatewayGateway, configConfig) + sessionBiz := biz.NewSessionBiz(configConfig, sessionImpl, sysImpl, chatHisImpl) + sessionService := services.NewSessionService(sessionBiz, chatHistoryBiz) + taskBiz := biz.NewTaskBiz(configConfig, taskImpl) + taskService := services.NewTaskService(sessionBiz, taskBiz) + oldClient := dingtalk.NewOldClient(configConfig) + contactClient, err := dingtalk.NewContactClient(configConfig) + if err != nil { + cleanup3() + cleanup2() + cleanup() + return nil, nil, err + } + notableClient, err := dingtalk.NewNotableClient(configConfig) + if err != nil { + cleanup3() + cleanup2() + cleanup() + return nil, nil, err + } + callbackService := services.NewCallbackService(configConfig, gatewayGateway, oldClient, contactClient, notableClient, botTool) + historyService := services.NewHistoryService(chatHistoryBiz) + app := server.NewHTTPServer(chatService, sessionService, taskService, gatewayGateway, callbackService, historyService) + botConfigImpl := impl.NewBotConfigImpl(db) + dingTalkBotBiz := biz.NewDingTalkBotBiz(doDo, handle, botConfigImpl) + dingBotService := services.NewDingBotService(configConfig, dingTalkBotBiz) + v := server.ProvideAllDingBotServices(dingBotService) + dingTalkBotServer := server.NewDingTalkBotServer(v) + servers := server.NewServers(configConfig, app, dingTalkBotServer) + code := m.Run() + os.Exit(code) +} + +func Test_GetUserInfo(t *testing.T) { + var c entitys.TaskConfig + config := `{"param": {"type": "object", "required": ["number"], "properties": {"number": {"type": "string", "description": "订单编号/流水号"}}}, "request": {"url": "http://www.baidu.com/${number}", "headers": {"Authorization": "${authorization}"}, "method": "GET"}}` + err := json.Unmarshal([]byte(config), &c) + t.Log(err) +} diff --git a/internal/data/impl/bot_user.go b/internal/data/impl/bot_user.go new file mode 100644 index 0000000..25d0d16 --- /dev/null +++ b/internal/data/impl/bot_user.go @@ -0,0 +1,27 @@ +package impl + +import ( + "ai_scheduler/internal/data/model" + "ai_scheduler/tmpl/dataTemp" + "ai_scheduler/utils" + "database/sql" +) + +type BotUserImpl struct { + dataTemp.DataTemp +} + +func NewBotUserImpl(db *utils.Db) *BotUserImpl { + return &BotUserImpl{ + DataTemp: *dataTemp.NewDataTemp(db, new(model.AiBotUser)), + } +} + +func (k BotUserImpl) GetByStaffId(staffId string) (data *model.AiBotUser, err error) { + data = &model.AiBotUser{} + err = k.Db.Model(k.Model).Where("staff_id = ?", staffId).Find(data).Error + if data == nil { + err = sql.ErrNoRows + } + return +} diff --git a/internal/data/model/ai_bot_config.gen.go b/internal/data/model/ai_bot_config.gen.go index 5c4f27a..6885e81 100644 --- a/internal/data/model/ai_bot_config.gen.go +++ b/internal/data/model/ai_bot_config.gen.go @@ -14,9 +14,10 @@ const TableNameAiBotConfig = "ai_bot_config" type AiBotConfig struct { BotID int32 `gorm:"column:bot_id;primaryKey;autoIncrement:true" json:"bot_id"` SysID int32 `gorm:"column:sys_id;not null" json:"sys_id"` - BotType int32 `gorm:"column:bot_type;not null;default:1" json:"bot_type"` - BotName string `gorm:"column:bot_name;not null" json:"bot_name"` - BotConfig string `gorm:"column:bot_config;not null" json:"bot_config"` + BotType int32 `gorm:"column:bot_type;not null;default:1;comment:类型,1为钉钉机器人" json:"bot_type"` // 类型,1为钉钉机器人 + BotName string `gorm:"column:bot_name;not null;comment:名字" json:"bot_name"` // 名字 + BotConfig string `gorm:"column:bot_config;not null;comment:配置" json:"bot_config"` // 配置 + BotIndex string `gorm:"column:bot_index;not null;comment:索引" json:"bot_index"` // 索引 CreateAt time.Time `gorm:"column:create_at;default:CURRENT_TIMESTAMP" json:"create_at"` UpdatedAt time.Time `gorm:"column:updated_at;default:CURRENT_TIMESTAMP" json:"updated_at"` Status int32 `gorm:"column:status;not null" json:"status"` diff --git a/internal/data/model/ai_bot_user.gen.go b/internal/data/model/ai_bot_user.gen.go new file mode 100644 index 0000000..bad185a --- /dev/null +++ b/internal/data/model/ai_bot_user.gen.go @@ -0,0 +1,33 @@ +// Code generated by gorm.io/gen. DO NOT EDIT. +// Code generated by gorm.io/gen. DO NOT EDIT. +// Code generated by gorm.io/gen. DO NOT EDIT. + +package model + +import ( + "time" +) + +const TableNameAiBotUser = "ai_bot_user" + +// AiBotUser mapped from table +type AiBotUser struct { + UserID int32 `gorm:"column:user_id;primaryKey" json:"user_id"` + StaffID string `gorm:"column:staff_id;not null;comment:标记用户用的唯一id,钉钉:钉钉侧提供的user_id" json:"staff_id"` // 标记用户用的唯一id,钉钉:钉钉侧提供的user_id + Name string `gorm:"column:name;not null;comment:用户名称" json:"name"` // 用户名称 + Title string `gorm:"column:title;not null;comment:职位" json:"title"` // 职位 + Extension string `gorm:"column:extension;not null;default:1;comment:信息面板" json:"extension"` // 信息面板 + RoleList string `gorm:"column:role_list;not null;comment:角色列表。" json:"role_list"` // 角色列表。 + DeptIDList string `gorm:"column:dept_id_list;not null;comment:所在部门id列表" json:"dept_id_list"` // 所在部门id列表 + IsBoss int32 `gorm:"column:is_boss;not null;comment:是否是老板" json:"is_boss"` // 是否是老板 + IsSenior int32 `gorm:"column:is_senior;not null;comment:是否是高管" json:"is_senior"` // 是否是高管 + HiredDate time.Time `gorm:"column:hired_date;not null;default:CURRENT_TIMESTAMP;comment:入职时间" json:"hired_date"` // 入职时间 + Status int32 `gorm:"column:status;not null" json:"status"` + DeleteAt time.Time `gorm:"column:delete_at" json:"delete_at"` + CreateAt time.Time `gorm:"column:create_at;default:CURRENT_TIMESTAMP" json:"create_at"` +} + +// TableName AiBotUser's table name +func (*AiBotUser) TableName() string { + return TableNameAiBotUser +} diff --git a/internal/entitys/bot.go b/internal/entitys/bot.go index 568822c..3e59b0d 100644 --- a/internal/entitys/bot.go +++ b/internal/entitys/bot.go @@ -8,9 +8,6 @@ import ( ) type RequireDataDingTalkBot struct { - Session string - Key string - Sys model.AiSy Histories []model.AiChatHi SessionInfo model.AiSession Tasks []model.AiTask @@ -24,7 +21,7 @@ type RequireDataDingTalkBot struct { } type DingTalkBot struct { - BotIndex string - ClientId string - ClientSecret string + BotIndex string `json:"bot_index"` + ClientId string `json:"client_id"` + ClientSecret string `json:"client_secret"` } diff --git a/internal/entitys/dingtalk.go b/internal/entitys/dingtalk.go new file mode 100644 index 0000000..bddf9b4 --- /dev/null +++ b/internal/entitys/dingtalk.go @@ -0,0 +1,17 @@ +package entitys + +type DingTalkUserInfo struct { + UserId string `json:"user_id"` + StaffId string `json:"staff_id"` + Name string `json:"name"` + Dept []Dept `json:"dept"` + IsBoss int `json:"is_boss"` + IsSenior int `json:"is_senior"` + HiredDate int `json:"hired_date"` + Extension string `json:"extension"` +} + +type Dept struct { + DeptName string `json:"dept_name"` + DeptId int `json:"dept_id"` +} diff --git a/internal/pkg/func.go b/internal/pkg/func.go index f6006ac..2197726 100644 --- a/internal/pkg/func.go +++ b/internal/pkg/func.go @@ -6,6 +6,7 @@ import ( "errors" "fmt" "net/url" + "strconv" "strings" ) @@ -65,3 +66,31 @@ func HexEncode(src, dst []byte) int { } return len(src) * 2 } + +// Ter 三目运算 Ter(true, 1, 2) +func Ter[T any](cond bool, a, b T) T { + if cond { + return a + } + return b +} + +// StringToSlice [num,num]转slice +func StringToSlice(s string) ([]int, error) { + // 1. 去掉两端的方括号 + trimmed := strings.Trim(s, "[]") + + // 2. 按逗号分割 + parts := strings.Split(trimmed, ",") + + // 3. 转换为 []int + result := make([]int, 0, len(parts)) + for _, part := range parts { + num, err := strconv.Atoi(strings.TrimSpace(part)) + if err != nil { + return nil, err + } + result = append(result, num) + } + return result, nil +} diff --git a/internal/server/ding_talk_bot.go b/internal/server/ding_talk_bot.go index 9a63812..f68f543 100644 --- a/internal/server/ding_talk_bot.go +++ b/internal/server/ding_talk_bot.go @@ -62,10 +62,10 @@ func (d *DingTalkBotServer) Run(ctx context.Context, botIndex string) { } err := cli.Start(ctx) if err != nil { - log.Info("%s启动失败", name) + log.Infof("%s启动失败", name) continue } - log.Info("%s启动成功", name) + log.Infof("%s启动成功", name) } } func DingBotServerInit(clientId string, clientSecret string, service DingBotServiceInterface) (cli *client.StreamClient) { diff --git a/internal/services/dtalk_bot.go b/internal/services/dtalk_bot.go index 864f5d7..e4e9e8d 100644 --- a/internal/services/dtalk_bot.go +++ b/internal/services/dtalk_bot.go @@ -31,7 +31,7 @@ func (d *DingBotService) OnChatBotMessageReceived(ctx context.Context, data *cha return } go func() { - defer close(requireData.Ch) + //defer close(requireData.Ch) //if match, _err := d.dingTalkBotBiz.Recognize(ctx, data); _err != nil { // requireData.Ch <- entitys.Response{ // Type: entitys.ResponseEnd, From d310bf8104791346fdef0978c86e29b0a2f45571 Mon Sep 17 00:00:00 2001 From: fuzhongyun <15339891972@163.com> Date: Thu, 11 Dec 2025 18:35:48 +0800 Subject: [PATCH 31/66] =?UTF-8?q?fix:=201.eino=20=E5=B7=A5=E4=BD=9C?= =?UTF-8?q?=E6=B5=81=E6=96=B9=E6=A1=88=E8=B0=83=E6=95=B4=202.=E4=B8=8B?= =?UTF-8?q?=E6=B8=B8=E6=89=B9=E9=87=8F=E8=AE=A2=E5=8D=95=E5=94=AE=E5=90=8E?= =?UTF-8?q?=E6=8E=A5=E5=85=A5=E5=B7=A5=E4=BD=9C=E6=B5=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- cmd/server/wire.go | 42 ++++---- internal/biz/do/handle.go | 52 ++++++---- internal/biz/provider_set.go | 10 +- internal/domain/workflow/manager.go | 1 - internal/domain/workflow/provider_set.go | 22 +++++ internal/domain/workflow/register_all_gen.go | 7 ++ internal/domain/workflow/registry.go | 12 +++ internal/domain/workflow/runtime/registry.go | 95 +++++++++++++++++++ .../domain/workflow/zltx/crontab_supplier.go | 1 - .../zltx/order_after_reseller_batch.go | 51 +++++++++- .../zltx/order_after_reseller_batch_test.go | 46 --------- internal/tools/manager.go | 31 +++--- 12 files changed, 258 insertions(+), 112 deletions(-) delete mode 100644 internal/domain/workflow/manager.go create mode 100644 internal/domain/workflow/provider_set.go create mode 100644 internal/domain/workflow/register_all_gen.go create mode 100644 internal/domain/workflow/registry.go create mode 100644 internal/domain/workflow/runtime/registry.go delete mode 100644 internal/domain/workflow/zltx/crontab_supplier.go delete mode 100644 internal/domain/workflow/zltx/order_after_reseller_batch_test.go diff --git a/cmd/server/wire.go b/cmd/server/wire.go index bcd1d2c..8a29212 100644 --- a/cmd/server/wire.go +++ b/cmd/server/wire.go @@ -4,15 +4,16 @@ package main import ( - "ai_scheduler/internal/biz" - "ai_scheduler/internal/config" - "ai_scheduler/internal/data/impl" - "ai_scheduler/internal/pkg" - "ai_scheduler/internal/server" - "ai_scheduler/internal/services" - "ai_scheduler/internal/tools" - "ai_scheduler/internal/tools_bot" - "ai_scheduler/utils" + "ai_scheduler/internal/biz" + "ai_scheduler/internal/config" + "ai_scheduler/internal/data/impl" + "ai_scheduler/internal/pkg" + "ai_scheduler/internal/server" + "ai_scheduler/internal/services" + "ai_scheduler/internal/domain/workflow" + "ai_scheduler/internal/tools" + "ai_scheduler/internal/tools_bot" + "ai_scheduler/utils" "github.com/gofiber/fiber/v2/log" "github.com/google/wire" @@ -20,16 +21,17 @@ import ( // InitializeApp 初始化应用程序 func InitializeApp(*config.Config, log.AllLogger) (*server.Servers, func(), error) { - panic(wire.Build( - server.ProviderSetServer, - llm.ProviderSet, - tools.ProviderSetTools, - pkg.ProviderSetClient, - services.ProviderSetServices, - biz.ProviderSetBiz, - impl.ProviderImpl, - utils.ProviderUtils, - tools_bot.ProviderSetBotTools, - )) + panic(wire.Build( + server.ProviderSetServer, + llm.ProviderSet, + workflow.ProviderSetWorkflow, + tools.ProviderSetTools, + pkg.ProviderSetClient, + services.ProviderSetServices, + biz.ProviderSetBiz, + impl.ProviderImpl, + utils.ProviderUtils, + tools_bot.ProviderSetBotTools, + )) } diff --git a/internal/biz/do/handle.go b/internal/biz/do/handle.go index fa7b3c5..2028f54 100644 --- a/internal/biz/do/handle.go +++ b/internal/biz/do/handle.go @@ -7,7 +7,7 @@ import ( errors "ai_scheduler/internal/data/error" "ai_scheduler/internal/data/impl" "ai_scheduler/internal/data/model" - "ai_scheduler/internal/domain/workflow/zltx" + "ai_scheduler/internal/domain/workflow/runtime" "ai_scheduler/internal/entitys" "ai_scheduler/internal/gateway" "ai_scheduler/internal/pkg/l_request" @@ -24,11 +24,12 @@ import ( ) type Handle struct { - Ollama *llm_service.OllamaService - toolManager *tools.Manager - Bot *tools_bot.BotTool - conf *config.Config - sessionImpl *impl.SessionImpl + Ollama *llm_service.OllamaService + toolManager *tools.Manager + Bot *tools_bot.BotTool + conf *config.Config + sessionImpl *impl.SessionImpl + workflowManager *runtime.Registry } func NewHandle( @@ -37,13 +38,15 @@ func NewHandle( conf *config.Config, sessionImpl *impl.SessionImpl, dTalkBot *tools_bot.BotTool, + workflowManager *runtime.Registry, ) *Handle { return &Handle{ - Ollama: Ollama, - toolManager: toolManager, - conf: conf, - sessionImpl: sessionImpl, - Bot: dTalkBot, + Ollama: Ollama, + toolManager: toolManager, + conf: conf, + sessionImpl: sessionImpl, + Bot: dTalkBot, + workflowManager: workflowManager, } } @@ -285,17 +288,24 @@ func (r *Handle) handleApiTask(ctx context.Context, requireData *entitys.Require func (r *Handle) handleEinoWorkflow(ctx context.Context, requireData *entitys.RequireData, task *model.AiTask) (err error) { // token 写入ctx ctx = util.SetTokenToContext(ctx, requireData.Auth) - - // 构建工作流 - todo 示例,后续抽象出来 - zltxWorkflow, err := zltx.BuildOrderAfterResellerBatchWorkflow(ctx, r.conf.Tools.ZltxOrderAfterSaleResellerBatch) - if err != nil { - return + // 解析入参:workflow_id 与 input + var params map[string]any + if len(requireData.Match.Parameters) > 0 { + _ = json.Unmarshal([]byte(requireData.Match.Parameters), ¶ms) } - - // 工作流执行 - _, err = zltxWorkflow.Invoke(ctx, zltx.OrderAfterSaleResellerBatchWorkflowInput{}) - - return + wfID, _ := params["workflow_id"].(string) + input, _ := params["input"].(map[string]any) + if wfID == "" { + return fmt.Errorf("workflow_id 不能为空") + } + entitys.ResLoading(requireData.Ch, requireData.Task.Index, "正在执行工作流") + res, err := r.workflowManager.Invoke(ctx, wfID, input) + if err != nil { + return err + } + b, _ := json.Marshal(res) + entitys.ResJson(requireData.Ch, requireData.Task.Index, string(b)) + return nil } // 权限验证 diff --git a/internal/biz/provider_set.go b/internal/biz/provider_set.go index aefa3ce..1bdc0f7 100644 --- a/internal/biz/provider_set.go +++ b/internal/biz/provider_set.go @@ -1,10 +1,10 @@ package biz import ( - "ai_scheduler/internal/biz/do" - "ai_scheduler/internal/biz/llm_service" - - "github.com/google/wire" + "ai_scheduler/internal/biz/do" + "ai_scheduler/internal/biz/llm_service" + + "github.com/google/wire" ) var ProviderSetBiz = wire.NewSet( @@ -15,7 +15,7 @@ var ProviderSetBiz = wire.NewSet( llm_service.NewOllamaGenerate, //handle.NewHandle, do.NewDo, - do.NewHandle, + do.NewHandle, NewTaskBiz, NewDingTalkBotBiz, ) diff --git a/internal/domain/workflow/manager.go b/internal/domain/workflow/manager.go deleted file mode 100644 index 0e59ea2..0000000 --- a/internal/domain/workflow/manager.go +++ /dev/null @@ -1 +0,0 @@ -package workflow diff --git a/internal/domain/workflow/provider_set.go b/internal/domain/workflow/provider_set.go new file mode 100644 index 0000000..c728b44 --- /dev/null +++ b/internal/domain/workflow/provider_set.go @@ -0,0 +1,22 @@ +package workflow + +import ( + "ai_scheduler/internal/config" + "ai_scheduler/internal/domain/workflow/runtime" + "ai_scheduler/internal/pkg/utils_ollama" + + "github.com/google/wire" +) + +var ProviderSetWorkflow = wire.NewSet(NewRegistry) + +// NewRegistry 注入共享依赖并注册默认 Registry,确保自注册工作流可被发现 +func NewRegistry(conf *config.Config, llm *utils_ollama.Client) *runtime.Registry { + // 步骤1:设置运行时依赖(配置与LLM客户端),供工作流工厂在首次实例化时使用;必须在任何调用 Invoke 之前完成,否则会触发 "deps not set" + runtime.SetDeps(&runtime.Deps{Conf: conf, LLM: llm}) + // 步骤2:创建新的工作流注册表;注册表负责按工作流ID惰性实例化并缓存单例实例,保障并发访问下的安全 + r := runtime.NewRegistry() + // 步骤3:将该注册表设置为全局默认,便于通过 runtime.Default() 获取;自注册的工作流可通过默认注册表被发现并调用 + runtime.SetDefault(r) + return r +} diff --git a/internal/domain/workflow/register_all_gen.go b/internal/domain/workflow/register_all_gen.go new file mode 100644 index 0000000..d30ba64 --- /dev/null +++ b/internal/domain/workflow/register_all_gen.go @@ -0,0 +1,7 @@ +package workflow + +import ( + // 手工维护:在此空导入工作流包以触发其 init() 自注册 + // 新增工作流时,只需在这里添加一行 `_ ""` + _ "ai_scheduler/internal/domain/workflow/zltx" +) diff --git a/internal/domain/workflow/registry.go b/internal/domain/workflow/registry.go new file mode 100644 index 0000000..cbde3b6 --- /dev/null +++ b/internal/domain/workflow/registry.go @@ -0,0 +1,12 @@ +package workflow + +import ( + "ai_scheduler/internal/config" + "ai_scheduler/internal/pkg/utils_ollama" +) + +// 仅声明依赖结构,避免在 workflow 包内实现注册中心逻辑导致循环依赖 +type Deps struct { + Conf *config.Config + LLM *utils_ollama.Client +} diff --git a/internal/domain/workflow/runtime/registry.go b/internal/domain/workflow/runtime/registry.go new file mode 100644 index 0000000..1260173 --- /dev/null +++ b/internal/domain/workflow/runtime/registry.go @@ -0,0 +1,95 @@ +package runtime + +import ( + "ai_scheduler/internal/config" + "ai_scheduler/internal/pkg/utils_ollama" + "context" + "errors" + "sync" +) + +type Workflow interface { + ID() string + Schema() map[string]any + Invoke(ctx context.Context, input map[string]any) (map[string]any, error) +} + +type Deps struct { + Conf *config.Config + LLM *utils_ollama.Client +} + +type Factory func(deps *Deps) (Workflow, error) + +var ( + regMu sync.RWMutex + factories = map[string]Factory{} + deps *Deps + defaultReg *Registry +) + +func Register(id string, f Factory) { + regMu.Lock() + factories[id] = f + regMu.Unlock() +} + +func SetDeps(d *Deps) { + regMu.Lock() + deps = d + regMu.Unlock() +} + +type Registry struct { + mu sync.RWMutex + instances map[string]Workflow +} + +func NewRegistry() *Registry { + return &Registry{instances: make(map[string]Workflow)} +} + +func SetDefault(r *Registry) { + regMu.Lock() + defaultReg = r + regMu.Unlock() +} + +func Default() *Registry { + regMu.RLock() + r := defaultReg + regMu.RUnlock() + return r +} + +func (r *Registry) Invoke(ctx context.Context, id string, input map[string]any) (map[string]any, error) { + if input == nil { + input = map[string]any{} + } + regMu.RLock() + f, ok := factories[id] + regMu.RUnlock() + if !ok { + return nil, errors.New("workflow not found: " + id) + } + + r.mu.RLock() + w, exists := r.instances[id] + r.mu.RUnlock() + + if !exists { + if deps == nil { + return nil, errors.New("deps not set") + } + nw, err := f(deps) + if err != nil { + return nil, err + } + r.mu.Lock() + r.instances[id] = nw + w = nw + r.mu.Unlock() + } + + return w.Invoke(ctx, input) +} diff --git a/internal/domain/workflow/zltx/crontab_supplier.go b/internal/domain/workflow/zltx/crontab_supplier.go deleted file mode 100644 index 29ee896..0000000 --- a/internal/domain/workflow/zltx/crontab_supplier.go +++ /dev/null @@ -1 +0,0 @@ -package zltx diff --git a/internal/domain/workflow/zltx/order_after_reseller_batch.go b/internal/domain/workflow/zltx/order_after_reseller_batch.go index 5c3081a..e00b726 100644 --- a/internal/domain/workflow/zltx/order_after_reseller_batch.go +++ b/internal/domain/workflow/zltx/order_after_reseller_batch.go @@ -3,19 +3,66 @@ package zltx import ( "ai_scheduler/internal/config" toolZoarb "ai_scheduler/internal/domain/tools/zltx/order_after_reseller_batch" + "ai_scheduler/internal/domain/workflow/runtime" "context" "errors" "github.com/cloudwego/eino/compose" ) +func init() { + runtime.Register("zltx.orderAfterSaleResellerBatch", func(d *runtime.Deps) (runtime.Workflow, error) { + return &orderAfterSaleResellerBatch{cfg: d.Conf.Tools.ZltxOrderAfterSaleResellerBatch}, nil + }) +} + +type orderAfterSaleResellerBatch struct { + cfg config.ToolConfig +} + +// ID 返回工作流唯一标识 +func (o *orderAfterSaleResellerBatch) ID() string { return "zltx.orderAfterSaleResellerBatch" } + +// Schema 返回入参约束(用于校验/表单生成) +func (o *orderAfterSaleResellerBatch) Schema() map[string]any { + return map[string]any{ + "type": "object", + "properties": map[string]any{"orderNumber": map[string]any{"type": "array", "items": map[string]any{"type": "string"}}}, + "required": []string{"orderNumber"}, + } +} + +// Invoke 调用原有编排工作流并规范化输出 +func (o *orderAfterSaleResellerBatch) Invoke(ctx context.Context, input map[string]any) (map[string]any, error) { + // 构建工作流 + chain, err := o.buildWorkflow(ctx) + if err != nil { + return nil, err + } + + var in OrderAfterSaleResellerBatchWorkflowInput + if v, ok := input["orderNumber"].([]string); ok { + in.OrderNumber = v + } + _, err = chain.Invoke(ctx, in) + if err != nil { + return nil, err + } + + // 工作流 callback + + // 不关心输出,全部在途中输出 + return nil, nil +} + type OrderAfterSaleResellerBatchWorkflowInput struct { OrderNumber []string `json:"orderNumber"` } var ErrInvalidOrderNumbers = errors.New("orderNumber 不能为空") -func BuildOrderAfterResellerBatchWorkflow(ctx context.Context, cfg config.ToolConfig) (compose.Runnable[OrderAfterSaleResellerBatchWorkflowInput, toolZoarb.OrderAfterSaleResellerBatchData], error) { +// buildWorkflow 构建工作流 +func (o *orderAfterSaleResellerBatch) buildWorkflow(ctx context.Context) (compose.Runnable[OrderAfterSaleResellerBatchWorkflowInput, toolZoarb.OrderAfterSaleResellerBatchData], error) { // 定义工作流、出入参 c := compose.NewChain[OrderAfterSaleResellerBatchWorkflowInput, toolZoarb.OrderAfterSaleResellerBatchData]() @@ -29,7 +76,7 @@ func BuildOrderAfterResellerBatchWorkflow(ctx context.Context, cfg config.ToolCo // 2.调用工具 c.AppendLambda(compose.InvokableLambda(func(ctx context.Context, in OrderAfterSaleResellerBatchWorkflowInput) (toolZoarb.OrderAfterSaleResellerBatchData, error) { - return toolZoarb.Call(ctx, cfg, in.OrderNumber) + return toolZoarb.Call(ctx, o.cfg, in.OrderNumber) })) // 3.结果映射与整形 diff --git a/internal/domain/workflow/zltx/order_after_reseller_batch_test.go b/internal/domain/workflow/zltx/order_after_reseller_batch_test.go deleted file mode 100644 index be56786..0000000 --- a/internal/domain/workflow/zltx/order_after_reseller_batch_test.go +++ /dev/null @@ -1,46 +0,0 @@ -package zltx - -import ( - "ai_scheduler/internal/config" - "context" - "errors" - "strings" - "testing" -) - -func TestOrderAfterResellerBatch_InvalidOrderNumbers(t *testing.T) { - ctx := context.Background() - cfg := config.ToolConfig{} - - run, err := BuildOrderAfterResellerBatchWorkflow(ctx, cfg) - if err != nil { - t.Fatalf("build workflow error: %v", err) - } - - _, err = run.Invoke(ctx, OrderAfterSaleResellerBatchWorkflowInput{}) - if err == nil { - t.Fatalf("expected error, got nil") - } - if !errors.Is(err, ErrInvalidOrderNumbers) { - t.Fatalf("expected ErrInvalidOrderNumbers, got %v", err) - } -} - -func TestOrderAfterResellerBatch_ContextTokenRequired(t *testing.T) { - ctx := context.Background() - cfg := config.ToolConfig{} - - run, err := BuildOrderAfterResellerBatchWorkflow(ctx, cfg) - if err != nil { - t.Fatalf("build workflow error: %v", err) - } - - in := OrderAfterSaleResellerBatchWorkflowInput{OrderNumber: []string{"123"}} - _, err = run.Invoke(ctx, in) - if err == nil { - t.Fatalf("expected error, got nil") - } - if !strings.Contains(err.Error(), "token 未注入") { - t.Fatalf("expected contains 'token 未注入', got %v", err) - } -} diff --git a/internal/tools/manager.go b/internal/tools/manager.go index 2ebbb69..406d11c 100644 --- a/internal/tools/manager.go +++ b/internal/tools/manager.go @@ -2,7 +2,6 @@ package tools import ( "ai_scheduler/internal/config" - "ai_scheduler/internal/data/constants" "ai_scheduler/internal/entitys" "ai_scheduler/internal/pkg/utils_ollama" zltxtool "ai_scheduler/internal/tools/zltx" @@ -101,23 +100,23 @@ func (m *Manager) GetTool(name string) (entitys.Tool, bool) { } // GetAllTools 获取所有工具 -func (m *Manager) GetAllTools() []entitys.Tool { - tools := make([]entitys.Tool, 0, len(m.tools)) - for _, tool := range m.tools { - tools = append(tools, tool) - } - return tools -} +// func (m *Manager) GetAllTools() []entitys.Tool { +// tools := make([]entitys.Tool, 0, len(m.tools)) +// for _, tool := range m.tools { +// tools = append(tools, tool) +// } +// return tools +// } -// GetToolDefinitions 获取所有工具定义 -func (m *Manager) GetToolDefinitions(caller constants.Caller) []entitys.ToolDefinition { - definitions := make([]entitys.ToolDefinition, 0, len(m.tools)) - for _, tool := range m.tools { - definitions = append(definitions, tool.Definition()) - } +// // GetToolDefinitions 获取所有工具定义 +// func (m *Manager) GetToolDefinitions(caller constants.Caller) []entitys.ToolDefinition { +// definitions := make([]entitys.ToolDefinition, 0, len(m.tools)) +// for _, tool := range m.tools { +// definitions = append(definitions, tool.Definition()) +// } - return definitions -} +// return definitions +// } // ExecuteTool 执行工具 func (m *Manager) ExecuteTool(ctx context.Context, name string, requireData *entitys.RequireData) error { From afb24987c5ae0f4a9f3212206de20f5bfaff7a73 Mon Sep 17 00:00:00 2001 From: renzhiyuan <465386466@qq.com> Date: Fri, 12 Dec 2025 10:56:42 +0800 Subject: [PATCH 32/66] refactor: optimize DingTalk integration and add new utilities --- cmd/server/wire.go | 2 + internal/biz/ding_talk_bot.go | 9 +- internal/biz/do/ctx.go | 2 +- internal/biz/handle/dingtalk/auth.go | 68 +++++++++-- internal/biz/handle/dingtalk/dept.go | 104 ++++++++++++++--- internal/biz/handle/dingtalk/provider_set.go | 11 ++ internal/biz/handle/dingtalk/types.go | 117 +++++++++++-------- internal/biz/handle/dingtalk/user.go | 106 ++++++++--------- internal/data/constants/dingtalk.go | 4 +- internal/data/impl/bot_dept.go | 17 +++ internal/data/impl/bot_user.go | 2 +- internal/data/impl/provider_set.go | 3 + internal/entitys/bot.go | 2 +- internal/entitys/dingtalk.go | 25 ++-- internal/entitys/recognize.go | 6 +- internal/pkg/func.go | 40 +++++++ utils/rds.go | 4 +- 17 files changed, 369 insertions(+), 153 deletions(-) create mode 100644 internal/biz/handle/dingtalk/provider_set.go create mode 100644 internal/data/impl/bot_dept.go diff --git a/cmd/server/wire.go b/cmd/server/wire.go index f134ef9..b20df54 100644 --- a/cmd/server/wire.go +++ b/cmd/server/wire.go @@ -5,6 +5,7 @@ package main import ( "ai_scheduler/internal/biz" + "ai_scheduler/internal/biz/handle/dingtalk" "ai_scheduler/internal/config" "ai_scheduler/internal/data/impl" "ai_scheduler/internal/pkg" @@ -29,6 +30,7 @@ func InitializeApp(*config.Config, log.AllLogger) (*server.Servers, func(), erro impl.ProviderImpl, utils.ProviderUtils, tools_bot.ProviderSetBotTools, + dingtalk.ProviderSetDingTalk, )) } diff --git a/internal/biz/ding_talk_bot.go b/internal/biz/ding_talk_bot.go index 74f2a47..d39eaf0 100644 --- a/internal/biz/ding_talk_bot.go +++ b/internal/biz/ding_talk_bot.go @@ -2,6 +2,7 @@ package biz import ( "ai_scheduler/internal/biz/do" + "ai_scheduler/internal/biz/handle/dingtalk" "ai_scheduler/internal/data/constants" "ai_scheduler/internal/data/impl" "ai_scheduler/internal/data/model" @@ -22,6 +23,7 @@ type DingTalkBotBiz struct { botConfigImpl *impl.BotConfigImpl replier *chatbot.ChatbotReplier log log.Logger + dingTalkUser *dingtalk.User } // NewDingTalkBotBiz @@ -29,12 +31,15 @@ func NewDingTalkBotBiz( do *do.Do, handle *do.Handle, botConfigImpl *impl.BotConfigImpl, + dingTalkUser *dingtalk.User, + ) *DingTalkBotBiz { return &DingTalkBotBiz{ do: do, handle: handle, botConfigImpl: botConfigImpl, replier: chatbot.NewChatbotReplier(), + dingTalkUser: dingTalkUser, } } @@ -65,8 +70,8 @@ func (d *DingTalkBotBiz) InitRequire(ctx context.Context, data *chatbot.BotCallb } entitys.ResLog(requireData.Ch, "recognize_start", "收到消息,正在处理中,请稍等") - requireData.Sys, err = d.do.GetSysInfoForDingTalkBot(requireData) - requireData.Tasks, err = d.do.GetTasks(requireData.Sys.SysID) + requireData.UserInfo, err = d.dingTalkUser.GetUserInfo(ctx, data.SenderStaffId, dingtalk.WithId(1)) + return } diff --git a/internal/biz/do/ctx.go b/internal/biz/do/ctx.go index d252c47..da0d7ae 100644 --- a/internal/biz/do/ctx.go +++ b/internal/biz/do/ctx.go @@ -225,7 +225,7 @@ func (d *Do) GetSysInfo(requireData *entitys.RequireData) (sysInfo model.AiSy, e func (d *Do) GetSysInfoForDingTalkBot(requireData *entitys.RequireDataDingTalkBot) (sysInfo model.AiSy, err error) { cond := builder.NewCond() - cond = cond.And(builder.Eq{"app_key": requireData.Key}) + cond = cond.And(builder.Eq{"app_key": requireData.Auth}) cond = cond.And(builder.IsNull{"delete_at"}) cond = cond.And(builder.Eq{"status": 1}) err = d.sysImpl.GetOneBySearchToStrut(&cond, &sysInfo) diff --git a/internal/biz/handle/dingtalk/auth.go b/internal/biz/handle/dingtalk/auth.go index 8ddd0c1..45fa566 100644 --- a/internal/biz/handle/dingtalk/auth.go +++ b/internal/biz/handle/dingtalk/auth.go @@ -3,25 +3,33 @@ package dingtalk import ( "ai_scheduler/internal/config" "ai_scheduler/internal/data/constants" + "ai_scheduler/internal/data/impl" + "ai_scheduler/internal/data/model" + "ai_scheduler/internal/entitys" "ai_scheduler/internal/pkg/l_request" + "ai_scheduler/utils" "context" "encoding/json" "errors" "net/http" "time" + "github.com/gofiber/fiber/v2/log" "github.com/redis/go-redis/v9" + "xorm.io/builder" ) type Auth struct { - redis *redis.Client - cfg *config.Config + redis *redis.Client + cfg *config.Config + botConfigImpl *impl.BotConfigImpl } -func NewAuth(cfg *config.Config, redis *redis.Client) *Auth { +func NewAuth(cfg *config.Config, redis *utils.Rdb, botConfigImpl *impl.BotConfigImpl) *Auth { return &Auth{ - redis: redis, - cfg: cfg, + redis: redis.Rdb, + cfg: cfg, + botConfigImpl: botConfigImpl, } } @@ -29,13 +37,16 @@ func (a *Auth) GetAccessToken(ctx context.Context, clientId string, clientSecret if clientId == "" { return "", errors.New("clientId is empty") } - token := a.redis.Get(ctx, a.getKey(clientId)).String() - if token == "" { + accessToken = a.redis.Get(ctx, a.getKey(clientId)).Val() + if accessToken == "" { authInfo, _err := a.getNewAccessToken(ctx, clientId, clientSecret) if _err != nil { return "", _err } - a.redis.SetEx(ctx, a.getKey(clientId), authInfo.AccessToken, time.Duration(authInfo.ExpiresIn-3600)*time.Second) + err = a.redis.SetEx(ctx, a.getKey(clientId), authInfo.AccessToken, time.Duration(authInfo.ExpireIn-3600)*time.Second).Err() + if err != nil { + return + } accessToken = authInfo.AccessToken } return @@ -53,10 +64,10 @@ func (a *Auth) getNewAccessToken(ctx context.Context, clientId string, clientSec req := l_request.Request{ Method: http.MethodPost, - Url: constants.GetDingTalkRequestUrl(constants.RequestUrlGetAccessToken, nil), - Data: map[string]string{ - "appkey": clientId, - "appsecret": clientSecret, + Url: "https://api.dingtalk.com/v1.0/oauth2/accessToken", + Json: map[string]interface{}{ + "appKey": clientId, + "appSecret": clientSecret, }, } res, err := req.Send() @@ -67,3 +78,36 @@ func (a *Auth) getNewAccessToken(ctx context.Context, clientId string, clientSec return } + +func (a *Auth) GetTokenFromBotOption(ctx context.Context, botOption ...BotOption) (token string, err error) { + botInfo := &Bot{} + for _, option := range botOption { + option(botInfo) + } + + if botInfo.id == 0 && botInfo.botConfig == nil { + err = errors.New("botInfo is nil") + return + } + if botInfo.botConfig == nil { + var botConfigDo model.AiBotConfig + cond := builder.NewCond() + cond = cond.And(builder.Eq{"bot_id": botInfo.id}) + err = a.botConfigImpl.GetOneBySearchToStrut(&cond, &botConfigDo) + if err != nil { + return + } + if botConfigDo.BotID == 0 { + return "", errors.New("未找到机器人服务配置") + } + botInfo.botConfig = &botConfigDo + } + var botConfig entitys.DingTalkBot + err = json.Unmarshal([]byte(botInfo.botConfig.BotConfig), &botConfig) + if err != nil { + log.Infof("初始化“%s”失败:%s", botInfo.botConfig.BotName, err.Error()) + return + } + return a.GetAccessToken(ctx, botConfig.ClientId, botConfig.ClientSecret) + +} diff --git a/internal/biz/handle/dingtalk/dept.go b/internal/biz/handle/dingtalk/dept.go index 14327c9..3ae1a36 100644 --- a/internal/biz/handle/dingtalk/dept.go +++ b/internal/biz/handle/dingtalk/dept.go @@ -1,35 +1,111 @@ package dingtalk import ( + "ai_scheduler/internal/data/constants" "ai_scheduler/internal/data/impl" + "ai_scheduler/internal/data/model" "ai_scheduler/internal/entitys" - "database/sql" - "errors" + "ai_scheduler/internal/pkg" + "ai_scheduler/internal/pkg/l_request" + "context" + "encoding/json" + "fmt" + "net/http" + + "xorm.io/builder" ) type Dept struct { - dingUserImpl *impl.BotUserImpl + dingDeptImpl *impl.BotDeptImpl + auth *Auth } -func NewDept(dingUserImpl *impl.BotUserImpl) *User { - return &User{ - dingUserImpl: dingUserImpl, +func NewDept(dingDeptImpl *impl.BotDeptImpl, auth *Auth) *Dept { + return &Dept{ + dingDeptImpl: dingDeptImpl, + auth: auth, } } -func (u *User) GetDeptInfo(userId string) (userInfo *entitys.DingTalkUserInfo, err error) { - if len(userId) == 0 { +func (d *Dept) GetDeptInfoByDeptIds(ctx context.Context, deptIds []int) (depts []*entitys.Dept, err error) { + if len(deptIds) == 0 { return } - user, err := u.dingUserImpl.GetByStaffId(userId) + deptsInfo := make([]model.AiBotDept, 0) + cond := builder.NewCond() + cond = cond.And(builder.Eq{"dingtalk_dept_id": deptIds}) + err = d.dingDeptImpl.GetRangeToMapStruct(&cond, deptsInfo) if err != nil { - if !errors.Is(err, sql.ErrNoRows) { - return + return + } + var existDept = make([]int, len(deptsInfo), 0) + for _, dept := range deptsInfo { + depts = append(depts, &entitys.Dept{ + DeptId: int(dept.DeptID), + Name: dept.Name, + }) + existDept = append(existDept, int(dept.DeptID)) + } + diff := pkg.Difference(deptIds, existDept) + if len(diff) > 0 { + deptDo := make([]model.AiBotDept, 0) + for _, deptId := range diff { + deptInfo, err := d.GetDeptInfoFromDingTalk(ctx, deptId) + if err != nil { + return nil, err + } + depts = append(depts, &entitys.Dept{ + DeptId: deptInfo.DeptId, + Name: deptInfo.Name, + }) + deptDo = append(deptDo, model.AiBotDept{ + DingtalkDeptID: int32(deptInfo.DeptId), + Name: deptInfo.Name, + }) + } + if len(deptDo) > 0 { + _, err = d.dingDeptImpl.Add(deptDo) + if err != nil { + return nil, err + } } } - //如果没有找到,则新增 - if user == nil { - } return } + +func (d *Dept) GetDeptInfoFromDingTalk(ctx context.Context, deptId int, botOption ...BotOption) (depts DeptResResult, err error) { + if deptId == 0 { + return + } + token, err := d.auth.GetTokenFromBotOption(ctx, botOption...) + if err != nil { + return + } + + req := l_request.Request{ + Method: http.MethodPost, + Url: constants.GetDingTalkRequestUrl(constants.RequestUrlGetDeptGet, map[string]string{ + "access_token": token, + }), + Json: map[string]interface{}{ + "dept_id": deptId, + }, + } + res, _err := req.Send() + if _err != nil { + err = _err + return + } + var deptInfo DeptRes + + err = json.Unmarshal(res.Content, &deptInfo) + if err != nil { + return + } + if deptInfo.Errcode != 0 { + fmt.Errorf("钉钉请求报错:%s", deptInfo.Errmsg) + } + return deptInfo.DeptResResult, err + +} diff --git a/internal/biz/handle/dingtalk/provider_set.go b/internal/biz/handle/dingtalk/provider_set.go new file mode 100644 index 0000000..579d464 --- /dev/null +++ b/internal/biz/handle/dingtalk/provider_set.go @@ -0,0 +1,11 @@ +package dingtalk + +import ( + "github.com/google/wire" +) + +var ProviderSetDingTalk = wire.NewSet( + NewUser, + NewAuth, + NewDept, +) diff --git a/internal/biz/handle/dingtalk/types.go b/internal/biz/handle/dingtalk/types.go index 1e9fe76..275010c 100644 --- a/internal/biz/handle/dingtalk/types.go +++ b/internal/biz/handle/dingtalk/types.go @@ -1,59 +1,78 @@ package dingtalk +import "time" + type AuthInfo struct { AccessToken string `json:"accessToken"` - ExpiresIn int64 `json:"expiresIn"` + ExpireIn int64 `json:"expireIn"` +} + +type UserInfoRes struct { + Errcode int `json:"errcode"` + Errmsg string `json:"errmsg"` + Result UserInfoResResult `json:"result"` + RequestId string `json:"request_id"` } type UserInfoResResult struct { - Errcode string `json:"errcode"` - Result Result `json:"result"` - Errmsg string `json:"errmsg"` -} -type Result struct { - Extension string `json:"extension"` - Unionid string `json:"unionid"` - Boss string `json:"boss"` - RoleList struct { - GroupName string `json:"group_name"` - Name string `json:"name"` - Id string `json:"id"` - } `json:"role_list"` - ExclusiveAccount bool `json:"exclusive_account"` - ManagerUserid string `json:"manager_userid"` - Admin string `json:"admin"` - Remark string `json:"remark"` - Title string `json:"title"` - HiredDate int `json:"hired_date"` - Userid string `json:"userid"` - WorkPlace string `json:"work_place"` - DeptOrderList struct { - DeptId string `json:"dept_id"` - Order string `json:"order"` + Active bool `json:"active"` + Admin bool `json:"admin"` + Avatar string `json:"avatar"` + Boss bool `json:"boss"` + CreateTime time.Time `json:"create_time"` + DeptIdList []int `json:"dept_id_list"` + DeptOrderList []struct { + DeptId int `json:"dept_id"` + Order int64 `json:"order"` } `json:"dept_order_list"` - RealAuthed string `json:"real_authed"` - DeptIdList string `json:"dept_id_list"` - JobNumber string `json:"job_number"` - Email string `json:"email"` - LeaderInDept struct { - Leader string `json:"leader"` - DeptId string `json:"dept_id"` + ExclusiveAccount bool `json:"exclusive_account"` + HideMobile bool `json:"hide_mobile"` + HiredDate int64 `json:"hired_date"` + JobNumber string `json:"job_number"` + LeaderInDept []struct { + DeptId int `json:"dept_id"` + Leader bool `json:"leader"` } `json:"leader_in_dept"` - Mobile string `json:"mobile"` - Active string `json:"active"` - OrgEmail string `json:"org_email"` - Telephone string `json:"telephone"` - Avatar string `json:"avatar"` - HideMobile string `json:"hide_mobile"` - Senior string `json:"senior"` - Name string `json:"name"` - UnionEmpExt struct { - UnionEmpMapList struct { - Userid string `json:"userid"` - CorpId string `json:"corp_id"` - } `json:"union_emp_map_list"` - Userid string `json:"userid"` - CorpId string `json:"corp_id"` - } `json:"union_emp_ext"` - StateCode string `json:"state_code"` + ManagerUserid string `json:"manager_userid"` + Name string `json:"name"` + RealAuthed bool `json:"real_authed"` + RoleList []struct { + GroupName string `json:"group_name"` + Id int `json:"id"` + Name string `json:"name"` + } `json:"role_list"` + Senior bool `json:"senior"` + Title string `json:"title"` + Unionid string `json:"unionid"` + Userid string `json:"userid"` +} + +type DeptRes struct { + Errcode int `json:"errcode"` + Errmsg string `json:"errmsg"` + DeptResResult DeptResResult `json:"result"` + RequestId string `json:"request_id"` +} + +type DeptResResult struct { + DeptPermits []int `json:"dept_permits"` + OuterPermitUsers []string `json:"outer_permit_users"` + DeptManagerUseridList []string `json:"dept_manager_userid_list"` + OrgDeptOwner string `json:"org_dept_owner"` + OuterDept bool `json:"outer_dept"` + DeptGroupChatId string `json:"dept_group_chat_id"` + GroupContainSubDept bool `json:"group_contain_sub_dept"` + AutoAddUser bool `json:"auto_add_user"` + HideDept bool `json:"hide_dept"` + Name string `json:"name"` + OuterPermitDepts []int `json:"outer_permit_depts"` + UserPermits []interface{} `json:"user_permits"` + DeptId int `json:"dept_id"` + CreateDeptGroup bool `json:"create_dept_group"` + Order int `json:"order"` + Code string `json:"code"` + UnionDeptExt struct { + CorpId string `json:"corp_id"` + DeptId int `json:"dept_id"` + } `json:"union_dept_ext"` } diff --git a/internal/biz/handle/dingtalk/user.go b/internal/biz/handle/dingtalk/user.go index 7d213c2..cf860f1 100644 --- a/internal/biz/handle/dingtalk/user.go +++ b/internal/biz/handle/dingtalk/user.go @@ -15,29 +15,24 @@ import ( "net/http" "strings" "time" - - "github.com/gofiber/fiber/v2/log" - "xorm.io/builder" ) type User struct { dingUserImpl *impl.BotUserImpl botConfigImpl *impl.BotConfigImpl auth *Auth - logger log.Logger + dept *Dept } func NewUser( dingUserImpl *impl.BotUserImpl, - botConfig *impl.BotConfigImpl, auth *Auth, - logger log.Logger, + dept *Dept, ) *User { return &User{ - dingUserImpl: dingUserImpl, - botConfigImpl: botConfig, - logger: logger, - auth: auth, + dingUserImpl: dingUserImpl, + auth: auth, + dept: dept, } } @@ -53,63 +48,58 @@ func (u *User) GetUserInfo(ctx context.Context, staffId string, botOption ...Bot } //如果没有找到,则新增 if user == nil { - DingUserInfo, _err := u.GetUserInfoFromDingTalkWithBot(ctx, staffId, botOption...) + DingUserInfo, _err := u.getUserInfoFromDingTalkWithBot(ctx, staffId, botOption...) if _err != nil { return nil, _err } - dingUserDo := &model.AiBotUser{ - StaffID: DingUserInfo.Userid, - Name: DingUserInfo.Name, - Title: DingUserInfo.Title, - Extension: DingUserInfo.Extension, - DeptIDList: DingUserInfo.DeptIdList, - IsBoss: int32(pkg.Ter(DingUserInfo.Boss == "true", constants.IsBossTrue, constants.IsBossFalse)), - IsSenior: int32(pkg.Ter(DingUserInfo.Boss == "true", constants.IsSeniorTrue, constants.IsSeniorFalse)), - HiredDate: time.Unix(int64(DingUserInfo.HiredDate), 0), + user = &model.AiBotUser{ + StaffID: DingUserInfo.Userid, + Name: DingUserInfo.Name, + Title: DingUserInfo.Title, + //Extension: DingUserInfo.Extension, + DeptIDList: strings.Join(pkg.SliceIntToString(DingUserInfo.DeptIdList), ","), + IsBoss: int32(pkg.Ter(DingUserInfo.Boss, constants.IsBossTrue, constants.IsBossFalse)), + IsSenior: int32(pkg.Ter(DingUserInfo.Senior, constants.IsSeniorTrue, constants.IsSeniorFalse)), + HiredDate: time.Unix(DingUserInfo.HiredDate, 0), } - //deptIdList, _err := pkg.StringToSlice(DingUserInfo.DeptIdList) - //if err != nil { - // return nil, _err - //} - dingUserDo.DeptIDList = strings.Trim(dingUserDo.DeptIDList, "[]") - } - return -} -func (u *User) GetUserInfoFromDingTalkWithBot(ctx context.Context, staffId string, botOption ...BotOption) (userInfo Result, err error) { - botInfo := &Bot{} - for _, option := range botOption { - option(botInfo) - } - - if botInfo.id == 0 && botInfo.botConfig == nil { - err = errors.New("botInfo is nil") - return - } - if botInfo.botConfig == nil { - cond := builder.NewCond() - cond = cond.And(builder.Eq{"bot_id": botInfo.id}) - err = u.botConfigImpl.GetOneBySearchToStrut(&cond, botInfo.botConfig) + _, err = u.dingUserImpl.Add(user) if err != nil { return } - } - var config entitys.DingTalkBot - err = json.Unmarshal([]byte(botInfo.botConfig.BotConfig), &config) - if err != nil { - log.Infof("初始化“%s”失败:%s", botInfo.botConfig.BotName, err.Error()) - return + userInfo = &entitys.DingTalkUserInfo{ + UserId: int(user.UserID), + StaffId: user.StaffID, + Name: user.Name, + IsBoss: constants.IsBoss(user.IsBoss), + IsSenior: constants.IsSenior(user.IsSenior), + HiredDate: user.HiredDate, + Extension: user.Extension, } - token, err := u.auth.GetAccessToken(ctx, config.ClientId, config.ClientSecret) - if err != nil { - return + if len(user.DeptIDList) > 0 { + deptIdList := pkg.SliceStringToInt(strings.Split(user.DeptIDList, ",")) + depts, _err := u.dept.GetDeptInfoByDeptIds(ctx, deptIdList) + if _err != nil { + return nil, err + } + for _, dept := range depts { + userInfo.Dept = append(userInfo.Dept, dept) + } } - return u.GetUserInfoFromDingTalk(ctx, token, staffId) + return userInfo, nil } -func (u *User) GetUserInfoFromDingTalk(ctx context.Context, token string, staffId string) (user Result, err error) { +func (u *User) getUserInfoFromDingTalkWithBot(ctx context.Context, staffId string, botOption ...BotOption) (userInfo UserInfoResResult, err error) { + token, err := u.auth.GetTokenFromBotOption(ctx, botOption...) + if err != nil { + return + } + return u.getUserInfoFromDingTalk(ctx, token, staffId) +} + +func (u *User) getUserInfoFromDingTalk(ctx context.Context, token string, staffId string) (user UserInfoResResult, err error) { if token == "" && staffId == "" { err = errors.New("获取钉钉用户信息的必要参数不足") return @@ -128,13 +118,13 @@ func (u *User) GetUserInfoFromDingTalk(ctx context.Context, token string, staffI if err != nil { return } - var userInfoResResult UserInfoResResult - err = json.Unmarshal(res.Content, &userInfoResResult) + var userInfoRes UserInfoRes + err = json.Unmarshal(res.Content, &userInfoRes) if err != nil { return } - if userInfoResResult.Errcode != "0" { - fmt.Errorf("钉钉请求报错:%s", userInfoResResult.Errmsg) + if userInfoRes.Errcode != 0 { + fmt.Errorf("钉钉请求报错:%s", userInfoRes.Errmsg) } - return userInfoResResult.Result, err + return userInfoRes.Result, err } diff --git a/internal/data/constants/dingtalk.go b/internal/data/constants/dingtalk.go index dd730cc..5c9de89 100644 --- a/internal/data/constants/dingtalk.go +++ b/internal/data/constants/dingtalk.go @@ -7,8 +7,8 @@ const DingTalkBseUrl = "https://oapi.dingtalk.com" type RequestUrl string const ( - RequestUrlGetAccessToken RequestUrl = "/v1.0/oauth2/accessToken" - RequestUrlGetUserGet RequestUrl = "/topapi/v2/user/get" + RequestUrlGetUserGet RequestUrl = "/topapi/v2/user/get" + RequestUrlGetDeptGet RequestUrl = "/topapi/v2/department/get" ) func GetDingTalkRequestUrl(path RequestUrl, query map[string]string) string { diff --git a/internal/data/impl/bot_dept.go b/internal/data/impl/bot_dept.go new file mode 100644 index 0000000..8ae0e4b --- /dev/null +++ b/internal/data/impl/bot_dept.go @@ -0,0 +1,17 @@ +package impl + +import ( + "ai_scheduler/internal/data/model" + "ai_scheduler/tmpl/dataTemp" + "ai_scheduler/utils" +) + +type BotDeptImpl struct { + dataTemp.DataTemp +} + +func NewBotDeptImpl(db *utils.Db) *BotDeptImpl { + return &BotDeptImpl{ + DataTemp: *dataTemp.NewDataTemp(db, new(model.AiBotDept)), + } +} diff --git a/internal/data/impl/bot_user.go b/internal/data/impl/bot_user.go index 25d0d16..46d8442 100644 --- a/internal/data/impl/bot_user.go +++ b/internal/data/impl/bot_user.go @@ -18,7 +18,7 @@ func NewBotUserImpl(db *utils.Db) *BotUserImpl { } func (k BotUserImpl) GetByStaffId(staffId string) (data *model.AiBotUser, err error) { - data = &model.AiBotUser{} + err = k.Db.Model(k.Model).Where("staff_id = ?", staffId).Find(data).Error if data == nil { err = sql.ErrNoRows diff --git a/internal/data/impl/provider_set.go b/internal/data/impl/provider_set.go index 53d389c..1563258 100644 --- a/internal/data/impl/provider_set.go +++ b/internal/data/impl/provider_set.go @@ -10,4 +10,7 @@ var ProviderImpl = wire.NewSet( NewTaskImpl, NewChatHisImpl, NewBotConfigImpl, + NewBotDeptImpl, + NewBotUserImpl, + NewBotChatHisImpl, ) diff --git a/internal/entitys/bot.go b/internal/entitys/bot.go index 3e59b0d..af92ba8 100644 --- a/internal/entitys/bot.go +++ b/internal/entitys/bot.go @@ -9,7 +9,7 @@ import ( type RequireDataDingTalkBot struct { Histories []model.AiChatHi - SessionInfo model.AiSession + UserInfo *DingTalkUserInfo Tasks []model.AiTask Match *Match Req *chatbot.BotCallbackDataModel diff --git a/internal/entitys/dingtalk.go b/internal/entitys/dingtalk.go index bddf9b4..3595bf3 100644 --- a/internal/entitys/dingtalk.go +++ b/internal/entitys/dingtalk.go @@ -1,17 +1,22 @@ package entitys +import ( + "ai_scheduler/internal/data/constants" + "time" +) + type DingTalkUserInfo struct { - UserId string `json:"user_id"` - StaffId string `json:"staff_id"` - Name string `json:"name"` - Dept []Dept `json:"dept"` - IsBoss int `json:"is_boss"` - IsSenior int `json:"is_senior"` - HiredDate int `json:"hired_date"` - Extension string `json:"extension"` + UserId int `json:"user_id"` + StaffId string `json:"staff_id"` + Name string `json:"name"` + Dept []*Dept `json:"dept"` + IsBoss constants.IsBoss `json:"is_boss"` + IsSenior constants.IsSenior `json:"is_senior"` + HiredDate time.Time `json:"hired_date"` + Extension string `json:"extension"` } type Dept struct { - DeptName string `json:"dept_name"` - DeptId int `json:"dept_id"` + Name string `json:"name"` + DeptId int `json:"dept_id"` } diff --git a/internal/entitys/recognize.go b/internal/entitys/recognize.go index 45d9fcc..46aef01 100644 --- a/internal/entitys/recognize.go +++ b/internal/entitys/recognize.go @@ -28,7 +28,11 @@ type RecognizeUserContent struct { type FileData []byte type RecognizeFile struct { + File Files + FileUrl []string // 文件下载链接 +} + +type Files struct { File []FileData // 文件数据(二进制格式) - FileUrl string // 文件下载链接 FileType constants.Caller // 文件类型(文件类型,能填最好填,可以跳过一层判断) } diff --git a/internal/pkg/func.go b/internal/pkg/func.go index 2197726..3e6e5b7 100644 --- a/internal/pkg/func.go +++ b/internal/pkg/func.go @@ -94,3 +94,43 @@ func StringToSlice(s string) ([]int, error) { } return result, nil } + +// Difference 差集 +func Difference[T comparable](a, b []T) []T { + // 创建 b 的映射(T 必须是可比较的类型) + bMap := make(map[T]struct{}, len(b)) + for _, item := range b { + bMap[item] = struct{}{} + } + + var diff []T // 修正为 []T 而非 []int + for _, item := range a { + if _, found := bMap[item]; !found { + diff = append(diff, item) + } + } + return diff +} + +// SliceStringToInt []string=>[]int +func SliceStringToInt(strSlice []string) []int { + numSlice := make([]int, len(strSlice), 0) + for _, str := range strSlice { + num, err := strconv.Atoi(str) + if err != nil { + return nil + } + numSlice = append(numSlice, num) + } + return numSlice +} + +// SliceIntToString []int=>[]string +func SliceIntToString(slice []int) []string { + numSlice := make([]string, len(slice), 0) + for _, str := range slice { + + numSlice = append(numSlice, strconv.Itoa(str)) + } + return numSlice +} diff --git a/utils/rds.go b/utils/rds.go index 57b89f1..874eb3e 100644 --- a/utils/rds.go +++ b/utils/rds.go @@ -13,10 +13,10 @@ type Rdb struct { var rdb *Rdb -func NewRdb(c *config.Redis) *Rdb { +func NewRdb(c *config.Config) *Rdb { if rdb == nil { //构建 redis - rdbBuild := buildRdb(c) + rdbBuild := buildRdb(&c.Redis) //退出时清理资源 rdb = &Rdb{Rdb: rdbBuild} } From e0fd700768c68512f09c1a349dab00ee7682802f Mon Sep 17 00:00:00 2001 From: renzhiyuan <465386466@qq.com> Date: Fri, 12 Dec 2025 10:59:43 +0800 Subject: [PATCH 33/66] refactor: optimize DingTalk integration and add new utilities --- internal/entitys/recognize.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/internal/entitys/recognize.go b/internal/entitys/recognize.go index 46aef01..679b86f 100644 --- a/internal/entitys/recognize.go +++ b/internal/entitys/recognize.go @@ -28,11 +28,11 @@ type RecognizeUserContent struct { type FileData []byte type RecognizeFile struct { - File Files + File []File FileUrl []string // 文件下载链接 } -type Files struct { - File []FileData // 文件数据(二进制格式) +type File struct { + FileData FileData // 文件数据(二进制格式) FileType constants.Caller // 文件类型(文件类型,能填最好填,可以跳过一层判断) } From e3b958e1655330866e3118fa530354af7ae06b3d Mon Sep 17 00:00:00 2001 From: wolter <11@gmail> Date: Fri, 12 Dec 2025 11:03:30 +0800 Subject: [PATCH 34/66] feat: RecognizeFile --- internal/entitys/recognize.go | 10 ++-------- internal/entitys/types.go | 2 +- 2 files changed, 3 insertions(+), 9 deletions(-) diff --git a/internal/entitys/recognize.go b/internal/entitys/recognize.go index 679b86f..3d68be7 100644 --- a/internal/entitys/recognize.go +++ b/internal/entitys/recognize.go @@ -1,8 +1,6 @@ package entitys -import ( - "ai_scheduler/internal/data/constants" -) +import "ai_scheduler/internal/data/constants" type Recognize struct { SystemPrompt string // 系统提示内容 @@ -28,11 +26,7 @@ type RecognizeUserContent struct { type FileData []byte type RecognizeFile struct { - File []File - FileUrl []string // 文件下载链接 -} - -type File struct { FileData FileData // 文件数据(二进制格式) FileType constants.Caller // 文件类型(文件类型,能填最好填,可以跳过一层判断) + FileUrl string // 文件下载链接 } diff --git a/internal/entitys/types.go b/internal/entitys/types.go index 72ecd6c..b8f8caa 100644 --- a/internal/entitys/types.go +++ b/internal/entitys/types.go @@ -30,7 +30,7 @@ type FirstSockRequest struct { type ChatSockRequest struct { Text string `json:"text" binding:"required"` - Img string `json:"img" binding:"required"` + Img string `json:"img" binding:"required"` // 多图片使用 英文, 分割 File string `json:"file" binding:"required"` Tags string `json:"tags" binding:"required"` MarkHis int64 `json:"mark_his" ` From 92bff0e4b9d1e1b0cdf390308cb65eb1273b4459 Mon Sep 17 00:00:00 2001 From: wolter <11@gmail> Date: Fri, 12 Dec 2025 11:06:09 +0800 Subject: [PATCH 35/66] =?UTF-8?q?feat:=20=E5=A4=84=E7=90=86=E7=94=A8?= =?UTF-8?q?=E6=88=B7=E8=BE=93=E5=85=A5=E6=96=87=E4=BB=B6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- internal/biz/router.go | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/internal/biz/router.go b/internal/biz/router.go index c78127f..5127a54 100644 --- a/internal/biz/router.go +++ b/internal/biz/router.go @@ -7,6 +7,7 @@ import ( "ai_scheduler/internal/gateway" "context" "encoding/json" + "strings" "time" "ai_scheduler/internal/entitys" @@ -137,9 +138,13 @@ func (r *AiRouterBiz) buildUserContent(requireData *entitys.RequireData) (*entit // 处理文件和图片 fileUrls := []string{requireData.Req.File, requireData.Req.Img} - for _, url := range fileUrls { - if url != "" { - files = append(files, &entitys.RecognizeFile{FileUrl: url}) + for _, item := range fileUrls { + // 处理逗号分隔的多个URL + urlList := strings.Split(item, ",") + for _, url := range urlList { + if url != "" { + files = append(files, &entitys.RecognizeFile{FileUrl: url}) + } } } From 9c3f0300ad21a26e2f476e5051e619948b5564b9 Mon Sep 17 00:00:00 2001 From: renzhiyuan <465386466@qq.com> Date: Fri, 12 Dec 2025 15:05:50 +0800 Subject: [PATCH 36/66] refactor: optimize DingTalk integration and add new utilities --- internal/biz/ding_talk_bot.go | 2 +- internal/biz/handle/dingtalk/auth.go | 27 +++++++++++++++----------- internal/biz/handle/dingtalk/dept.go | 22 +++++++++------------ internal/biz/handle/dingtalk/types.go | 8 +++++++- internal/biz/handle/dingtalk/user.go | 20 ++++++++----------- internal/data/model/ai_bot_dept.gen.go | 26 +++++++++++++++++++++++++ internal/data/model/ai_bot_user.gen.go | 26 ++++++++++++------------- internal/pkg/func.go | 15 +++++++------- 8 files changed, 87 insertions(+), 59 deletions(-) create mode 100644 internal/data/model/ai_bot_dept.gen.go diff --git a/internal/biz/ding_talk_bot.go b/internal/biz/ding_talk_bot.go index d39eaf0..c67c8e5 100644 --- a/internal/biz/ding_talk_bot.go +++ b/internal/biz/ding_talk_bot.go @@ -70,7 +70,7 @@ func (d *DingTalkBotBiz) InitRequire(ctx context.Context, data *chatbot.BotCallb } entitys.ResLog(requireData.Ch, "recognize_start", "收到消息,正在处理中,请稍等") - requireData.UserInfo, err = d.dingTalkUser.GetUserInfo(ctx, data.SenderStaffId, dingtalk.WithId(1)) + requireData.UserInfo, err = d.dingTalkUser.GetUserInfoFromBot(ctx, data.SenderStaffId, dingtalk.WithId(1)) return } diff --git a/internal/biz/handle/dingtalk/auth.go b/internal/biz/handle/dingtalk/auth.go index 45fa566..2b359a9 100644 --- a/internal/biz/handle/dingtalk/auth.go +++ b/internal/biz/handle/dingtalk/auth.go @@ -33,30 +33,34 @@ func NewAuth(cfg *config.Config, redis *utils.Rdb, botConfigImpl *impl.BotConfig } } -func (a *Auth) GetAccessToken(ctx context.Context, clientId string, clientSecret string) (accessToken string, err error) { +func (a *Auth) GetAccessToken(ctx context.Context, clientId string, clientSecret string) (authInfo *AuthInfo, err error) { if clientId == "" { - return "", errors.New("clientId is empty") + return nil, errors.New("clientId is empty") } - accessToken = a.redis.Get(ctx, a.getKey(clientId)).Val() + accessToken := a.redis.Get(ctx, a.getKey(clientId)).Val() if accessToken == "" { - authInfo, _err := a.getNewAccessToken(ctx, clientId, clientSecret) + dingTalkAuthRes, _err := a.getNewAccessToken(ctx, clientId, clientSecret) if _err != nil { - return "", _err + return nil, _err } - err = a.redis.SetEx(ctx, a.getKey(clientId), authInfo.AccessToken, time.Duration(authInfo.ExpireIn-3600)*time.Second).Err() + err = a.redis.SetEx(ctx, a.getKey(clientId), dingTalkAuthRes.AccessToken, time.Duration(dingTalkAuthRes.ExpireIn-3600)*time.Second).Err() if err != nil { return } - accessToken = authInfo.AccessToken + accessToken = dingTalkAuthRes.AccessToken } - return + return &AuthInfo{ + ClientId: clientId, + ClientSecret: clientSecret, + AccessToken: accessToken, + }, nil } func (a *Auth) getKey(clientId string) string { return a.cfg.Redis.Key + ":" + constants.DingTalkAuthBaseKeyPrefix + ":" + clientId } -func (a *Auth) getNewAccessToken(ctx context.Context, clientId string, clientSecret string) (auth AuthInfo, err error) { +func (a *Auth) getNewAccessToken(ctx context.Context, clientId string, clientSecret string) (auth DingTalkAuthIRes, err error) { if clientId == "" || clientSecret == "" { err = errors.New("clientId or clientSecret is empty") return @@ -79,7 +83,7 @@ func (a *Auth) getNewAccessToken(ctx context.Context, clientId string, clientSec return } -func (a *Auth) GetTokenFromBotOption(ctx context.Context, botOption ...BotOption) (token string, err error) { +func (a *Auth) GetTokenFromBotOption(ctx context.Context, botOption ...BotOption) (token *AuthInfo, err error) { botInfo := &Bot{} for _, option := range botOption { option(botInfo) @@ -98,7 +102,8 @@ func (a *Auth) GetTokenFromBotOption(ctx context.Context, botOption ...BotOption return } if botConfigDo.BotID == 0 { - return "", errors.New("未找到机器人服务配置") + err = errors.New("未找到机器人服务配置") + return } botInfo.botConfig = &botConfigDo } diff --git a/internal/biz/handle/dingtalk/dept.go b/internal/biz/handle/dingtalk/dept.go index 3ae1a36..0a6bd15 100644 --- a/internal/biz/handle/dingtalk/dept.go +++ b/internal/biz/handle/dingtalk/dept.go @@ -27,14 +27,14 @@ func NewDept(dingDeptImpl *impl.BotDeptImpl, auth *Auth) *Dept { } } -func (d *Dept) GetDeptInfoByDeptIds(ctx context.Context, deptIds []int) (depts []*entitys.Dept, err error) { - if len(deptIds) == 0 { +func (d *Dept) GetDeptInfoByDeptIds(ctx context.Context, deptIds []int, authInfo *AuthInfo) (depts []*entitys.Dept, err error) { + if len(deptIds) == 0 || authInfo == nil { return } - deptsInfo := make([]model.AiBotDept, 0) + var deptsInfo []model.AiBotDept cond := builder.NewCond() cond = cond.And(builder.Eq{"dingtalk_dept_id": deptIds}) - err = d.dingDeptImpl.GetRangeToMapStruct(&cond, deptsInfo) + err = d.dingDeptImpl.GetRangeToMapStruct(&cond, &deptsInfo) if err != nil { return } @@ -50,9 +50,9 @@ func (d *Dept) GetDeptInfoByDeptIds(ctx context.Context, deptIds []int) (depts [ if len(diff) > 0 { deptDo := make([]model.AiBotDept, 0) for _, deptId := range diff { - deptInfo, err := d.GetDeptInfoFromDingTalk(ctx, deptId) - if err != nil { - return nil, err + deptInfo, _err := d.GetDeptInfoFromDingTalk(ctx, deptId, authInfo.AccessToken) + if _err != nil { + return nil, _err } depts = append(depts, &entitys.Dept{ DeptId: deptInfo.DeptId, @@ -74,12 +74,8 @@ func (d *Dept) GetDeptInfoByDeptIds(ctx context.Context, deptIds []int) (depts [ return } -func (d *Dept) GetDeptInfoFromDingTalk(ctx context.Context, deptId int, botOption ...BotOption) (depts DeptResResult, err error) { - if deptId == 0 { - return - } - token, err := d.auth.GetTokenFromBotOption(ctx, botOption...) - if err != nil { +func (d *Dept) GetDeptInfoFromDingTalk(ctx context.Context, deptId int, token string) (depts DeptResResult, err error) { + if deptId == 0 || len(token) == 0 { return } diff --git a/internal/biz/handle/dingtalk/types.go b/internal/biz/handle/dingtalk/types.go index 275010c..a36ea76 100644 --- a/internal/biz/handle/dingtalk/types.go +++ b/internal/biz/handle/dingtalk/types.go @@ -2,7 +2,7 @@ package dingtalk import "time" -type AuthInfo struct { +type DingTalkAuthIRes struct { AccessToken string `json:"accessToken"` ExpireIn int64 `json:"expireIn"` } @@ -76,3 +76,9 @@ type DeptResResult struct { DeptId int `json:"dept_id"` } `json:"union_dept_ext"` } + +type AuthInfo struct { + ClientId string `json:"clientId"` + ClientSecret string `json:"clientSecret"` + AccessToken string `json:"accessToken"` +} diff --git a/internal/biz/handle/dingtalk/user.go b/internal/biz/handle/dingtalk/user.go index cf860f1..900a492 100644 --- a/internal/biz/handle/dingtalk/user.go +++ b/internal/biz/handle/dingtalk/user.go @@ -36,7 +36,7 @@ func NewUser( } } -func (u *User) GetUserInfo(ctx context.Context, staffId string, botOption ...BotOption) (userInfo *entitys.DingTalkUserInfo, err error) { +func (u *User) GetUserInfoFromBot(ctx context.Context, staffId string, botOption ...BotOption) (userInfo *entitys.DingTalkUserInfo, err error) { if len(staffId) == 0 { return } @@ -46,9 +46,13 @@ func (u *User) GetUserInfo(ctx context.Context, staffId string, botOption ...Bot return } } + authInfo, err := u.auth.GetTokenFromBotOption(ctx, botOption...) + if err != nil || authInfo == nil { + return + } //如果没有找到,则新增 if user == nil { - DingUserInfo, _err := u.getUserInfoFromDingTalkWithBot(ctx, staffId, botOption...) + DingUserInfo, _err := u.getUserInfoFromDingTalk(ctx, authInfo.AccessToken, staffId) if _err != nil { return nil, _err } @@ -60,7 +64,7 @@ func (u *User) GetUserInfo(ctx context.Context, staffId string, botOption ...Bot DeptIDList: strings.Join(pkg.SliceIntToString(DingUserInfo.DeptIdList), ","), IsBoss: int32(pkg.Ter(DingUserInfo.Boss, constants.IsBossTrue, constants.IsBossFalse)), IsSenior: int32(pkg.Ter(DingUserInfo.Senior, constants.IsSeniorTrue, constants.IsSeniorFalse)), - HiredDate: time.Unix(DingUserInfo.HiredDate, 0), + HiredDate: time.UnixMilli(DingUserInfo.HiredDate), } _, err = u.dingUserImpl.Add(user) @@ -79,7 +83,7 @@ func (u *User) GetUserInfo(ctx context.Context, staffId string, botOption ...Bot } if len(user.DeptIDList) > 0 { deptIdList := pkg.SliceStringToInt(strings.Split(user.DeptIDList, ",")) - depts, _err := u.dept.GetDeptInfoByDeptIds(ctx, deptIdList) + depts, _err := u.dept.GetDeptInfoByDeptIds(ctx, deptIdList, authInfo) if _err != nil { return nil, err } @@ -91,14 +95,6 @@ func (u *User) GetUserInfo(ctx context.Context, staffId string, botOption ...Bot return userInfo, nil } -func (u *User) getUserInfoFromDingTalkWithBot(ctx context.Context, staffId string, botOption ...BotOption) (userInfo UserInfoResResult, err error) { - token, err := u.auth.GetTokenFromBotOption(ctx, botOption...) - if err != nil { - return - } - return u.getUserInfoFromDingTalk(ctx, token, staffId) -} - func (u *User) getUserInfoFromDingTalk(ctx context.Context, token string, staffId string) (user UserInfoResResult, err error) { if token == "" && staffId == "" { err = errors.New("获取钉钉用户信息的必要参数不足") diff --git a/internal/data/model/ai_bot_dept.gen.go b/internal/data/model/ai_bot_dept.gen.go new file mode 100644 index 0000000..db8ba60 --- /dev/null +++ b/internal/data/model/ai_bot_dept.gen.go @@ -0,0 +1,26 @@ +// Code generated by gorm.io/gen. DO NOT EDIT. +// Code generated by gorm.io/gen. DO NOT EDIT. +// Code generated by gorm.io/gen. DO NOT EDIT. + +package model + +import ( + "time" +) + +const TableNameAiBotDept = "ai_bot_dept" + +// AiBotDept mapped from table +type AiBotDept struct { + DeptID int32 `gorm:"column:dept_id;primaryKey" json:"dept_id"` + DingtalkDeptID int32 `gorm:"column:dingtalk_dept_id;not null;comment:标记部门的唯一id,钉钉:钉钉侧提供的dept_id" json:"dingtalk_dept_id"` // 标记部门的唯一id,钉钉:钉钉侧提供的dept_id + Name string `gorm:"column:name;not null;comment:用户名称" json:"name"` // 用户名称 + Status int32 `gorm:"column:status;not null" json:"status"` + DeleteAt time.Time `gorm:"column:delete_at" json:"delete_at"` + CreateAt time.Time `gorm:"column:create_at;default:CURRENT_TIMESTAMP" json:"create_at"` +} + +// TableName AiBotDept's table name +func (*AiBotDept) TableName() string { + return TableNameAiBotDept +} diff --git a/internal/data/model/ai_bot_user.gen.go b/internal/data/model/ai_bot_user.gen.go index bad185a..5783e51 100644 --- a/internal/data/model/ai_bot_user.gen.go +++ b/internal/data/model/ai_bot_user.gen.go @@ -12,19 +12,19 @@ const TableNameAiBotUser = "ai_bot_user" // AiBotUser mapped from table type AiBotUser struct { - UserID int32 `gorm:"column:user_id;primaryKey" json:"user_id"` - StaffID string `gorm:"column:staff_id;not null;comment:标记用户用的唯一id,钉钉:钉钉侧提供的user_id" json:"staff_id"` // 标记用户用的唯一id,钉钉:钉钉侧提供的user_id - Name string `gorm:"column:name;not null;comment:用户名称" json:"name"` // 用户名称 - Title string `gorm:"column:title;not null;comment:职位" json:"title"` // 职位 - Extension string `gorm:"column:extension;not null;default:1;comment:信息面板" json:"extension"` // 信息面板 - RoleList string `gorm:"column:role_list;not null;comment:角色列表。" json:"role_list"` // 角色列表。 - DeptIDList string `gorm:"column:dept_id_list;not null;comment:所在部门id列表" json:"dept_id_list"` // 所在部门id列表 - IsBoss int32 `gorm:"column:is_boss;not null;comment:是否是老板" json:"is_boss"` // 是否是老板 - IsSenior int32 `gorm:"column:is_senior;not null;comment:是否是高管" json:"is_senior"` // 是否是高管 - HiredDate time.Time `gorm:"column:hired_date;not null;default:CURRENT_TIMESTAMP;comment:入职时间" json:"hired_date"` // 入职时间 - Status int32 `gorm:"column:status;not null" json:"status"` - DeleteAt time.Time `gorm:"column:delete_at" json:"delete_at"` - CreateAt time.Time `gorm:"column:create_at;default:CURRENT_TIMESTAMP" json:"create_at"` + UserID int32 `gorm:"column:user_id;primaryKey" json:"user_id"` + StaffID string `gorm:"column:staff_id;not null;comment:标记用户用的唯一id,钉钉:钉钉侧提供的user_id" json:"staff_id"` // 标记用户用的唯一id,钉钉:钉钉侧提供的user_id + Name string `gorm:"column:name;not null;comment:用户名称" json:"name"` // 用户名称 + Title string `gorm:"column:title;not null;comment:职位" json:"title"` // 职位 + Extension string `gorm:"column:extension;not null;default:1;comment:信息面板" json:"extension"` // 信息面板 + RoleList string `gorm:"column:role_list;not null;comment:角色列表。" json:"role_list"` // 角色列表。 + DeptIDList string `gorm:"column:dept_id_list;not null;comment:所在部门id列表" json:"dept_id_list"` // 所在部门id列表 + IsBoss int32 `gorm:"column:is_boss;not null;comment:是否是老板" json:"is_boss"` // 是否是老板 + IsSenior int32 `gorm:"column:is_senior;not null;comment:是否是高管" json:"is_senior"` // 是否是高管 + HiredDate time.Time `gorm:"column:hired_date;not null;default:CURRENT_TIMESTAMP;comment:入职时间" json:"hired_date"` // 入职时间 + Status int32 `gorm:"column:status;not null" json:"status"` + DeleteAt *time.Time `gorm:"column:delete_at" json:"delete_at"` + CreateAt time.Time `gorm:"column:create_at;default:CURRENT_TIMESTAMP" json:"create_at"` } // TableName AiBotUser's table name diff --git a/internal/pkg/func.go b/internal/pkg/func.go index 3e6e5b7..57321bd 100644 --- a/internal/pkg/func.go +++ b/internal/pkg/func.go @@ -114,23 +114,22 @@ func Difference[T comparable](a, b []T) []T { // SliceStringToInt []string=>[]int func SliceStringToInt(strSlice []string) []int { - numSlice := make([]int, len(strSlice), 0) - for _, str := range strSlice { + numSlice := make([]int, len(strSlice)) + for i, str := range strSlice { num, err := strconv.Atoi(str) if err != nil { return nil } - numSlice = append(numSlice, num) + numSlice[i] = num } return numSlice } // SliceIntToString []int=>[]string func SliceIntToString(slice []int) []string { - numSlice := make([]string, len(slice), 0) - for _, str := range slice { - - numSlice = append(numSlice, strconv.Itoa(str)) + strSlice := make([]string, len(slice)) // len=cap=len(slice) + for i, num := range slice { + strSlice[i] = strconv.Itoa(num) // 直接赋值,无 append } - return numSlice + return strSlice } From 54321bacf848a834541416ba5fb1f25b1e6b05c0 Mon Sep 17 00:00:00 2001 From: fuzhongyun <15339891972@163.com> Date: Fri, 12 Dec 2025 18:00:20 +0800 Subject: [PATCH 37/66] =?UTF-8?q?fix:=20eino=20=E5=B7=A5=E4=BD=9C=E6=B5=81?= =?UTF-8?q?=20=E7=B1=BB=E5=9E=8B=E4=BB=BB=E5=8A=A1=E6=8E=A5=E5=85=A5?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- cmd/server/wire.go | 43 ++-- internal/biz/do/handle.go | 19 +- .../zltx/order_after_reseller_batch/client.go | 19 +- .../order_after_reseller_batch/invokable.go | 2 +- .../zltx/order_after_reseller_batch/types.go | 10 +- internal/domain/workflow/runtime/registry.go | 10 +- .../zltx/order_after_reseller_batch.go | 185 +++++++++++++++--- 7 files changed, 210 insertions(+), 78 deletions(-) diff --git a/cmd/server/wire.go b/cmd/server/wire.go index 8a29212..e34c611 100644 --- a/cmd/server/wire.go +++ b/cmd/server/wire.go @@ -4,16 +4,16 @@ package main import ( - "ai_scheduler/internal/biz" - "ai_scheduler/internal/config" - "ai_scheduler/internal/data/impl" - "ai_scheduler/internal/pkg" - "ai_scheduler/internal/server" - "ai_scheduler/internal/services" - "ai_scheduler/internal/domain/workflow" - "ai_scheduler/internal/tools" - "ai_scheduler/internal/tools_bot" - "ai_scheduler/utils" + "ai_scheduler/internal/biz" + "ai_scheduler/internal/config" + "ai_scheduler/internal/data/impl" + "ai_scheduler/internal/domain/workflow" + "ai_scheduler/internal/pkg" + "ai_scheduler/internal/server" + "ai_scheduler/internal/services" + "ai_scheduler/internal/tools" + "ai_scheduler/internal/tools_bot" + "ai_scheduler/utils" "github.com/gofiber/fiber/v2/log" "github.com/google/wire" @@ -21,17 +21,16 @@ import ( // InitializeApp 初始化应用程序 func InitializeApp(*config.Config, log.AllLogger) (*server.Servers, func(), error) { - panic(wire.Build( - server.ProviderSetServer, - llm.ProviderSet, - workflow.ProviderSetWorkflow, - tools.ProviderSetTools, - pkg.ProviderSetClient, - services.ProviderSetServices, - biz.ProviderSetBiz, - impl.ProviderImpl, - utils.ProviderUtils, - tools_bot.ProviderSetBotTools, - )) + panic(wire.Build( + server.ProviderSetServer, + workflow.ProviderSetWorkflow, + tools.ProviderSetTools, + pkg.ProviderSetClient, + services.ProviderSetServices, + biz.ProviderSetBiz, + impl.ProviderImpl, + utils.ProviderUtils, + tools_bot.ProviderSetBotTools, + )) } diff --git a/internal/biz/do/handle.go b/internal/biz/do/handle.go index 2028f54..e5b30b4 100644 --- a/internal/biz/do/handle.go +++ b/internal/biz/do/handle.go @@ -288,23 +288,16 @@ func (r *Handle) handleApiTask(ctx context.Context, requireData *entitys.Require func (r *Handle) handleEinoWorkflow(ctx context.Context, requireData *entitys.RequireData, task *model.AiTask) (err error) { // token 写入ctx ctx = util.SetTokenToContext(ctx, requireData.Auth) - // 解析入参:workflow_id 与 input - var params map[string]any - if len(requireData.Match.Parameters) > 0 { - _ = json.Unmarshal([]byte(requireData.Match.Parameters), ¶ms) - } - wfID, _ := params["workflow_id"].(string) - input, _ := params["input"].(map[string]any) - if wfID == "" { - return fmt.Errorf("workflow_id 不能为空") - } + entitys.ResLoading(requireData.Ch, requireData.Task.Index, "正在执行工作流") - res, err := r.workflowManager.Invoke(ctx, wfID, input) + + // 工作流内部输出 + workflowId := task.Index + _, err = r.workflowManager.Invoke(ctx, workflowId, requireData) if err != nil { return err } - b, _ := json.Marshal(res) - entitys.ResJson(requireData.Ch, requireData.Task.Index, string(b)) + return nil } diff --git a/internal/domain/tools/zltx/order_after_reseller_batch/client.go b/internal/domain/tools/zltx/order_after_reseller_batch/client.go index 90da543..efcf29c 100644 --- a/internal/domain/tools/zltx/order_after_reseller_batch/client.go +++ b/internal/domain/tools/zltx/order_after_reseller_batch/client.go @@ -10,13 +10,13 @@ import ( "fmt" ) -func Call(ctx context.Context, cfg config.ToolConfig, orderNumbers []string) (OrderAfterSaleResellerBatchData, error) { +func Call(ctx context.Context, cfg config.ToolConfig, orderNumbers []string) (*OrderAfterSaleResellerBatchResponse, error) { if len(orderNumbers) == 0 { - return OrderAfterSaleResellerBatchData{}, errors.New("批充订单号不能为空") + return nil, errors.New("批充订单号不能为空") } token := util.GetTokenFromContext(ctx) if token == "" { - return OrderAfterSaleResellerBatchData{}, errors.New("token 未注入") + return nil, errors.New("token 未注入") } r := l_request.Request{ Url: cfg.BaseURL, @@ -31,17 +31,18 @@ func Call(ctx context.Context, cfg config.ToolConfig, orderNumbers []string) (Or } res, err := r.Send() if err != nil { - return OrderAfterSaleResellerBatchData{}, err + return nil, err } - var response OrderAfterSaleResellerBatchResponse + + response := &OrderAfterSaleResellerBatchResponse{} if err = json.Unmarshal(res.Content, &response); err != nil { - return OrderAfterSaleResellerBatchData{}, err + return nil, err } if response.Code != 200 { - return OrderAfterSaleResellerBatchData{}, fmt.Errorf("售后订单查询异常: %s", response.Error) + return nil, fmt.Errorf("售后订单查询异常: %s", response.Error) } if len(response.Data.Data) == 0 { - return OrderAfterSaleResellerBatchData{}, errors.New("未查询到相应售后订单,请核实订单号是否正确") + return nil, errors.New("未查询到相应售后订单,请核实订单号是否正确") } - return response.Data, nil + return response, nil } diff --git a/internal/domain/tools/zltx/order_after_reseller_batch/invokable.go b/internal/domain/tools/zltx/order_after_reseller_batch/invokable.go index 34da7e1..e7e99b5 100644 --- a/internal/domain/tools/zltx/order_after_reseller_batch/invokable.go +++ b/internal/domain/tools/zltx/order_after_reseller_batch/invokable.go @@ -13,7 +13,7 @@ type Args struct { } func NewInvokable(cfg config.ToolConfig) tool.InvokableTool { - run := func(ctx context.Context, in Args) (OrderAfterSaleResellerBatchData, error) { + run := func(ctx context.Context, in Args) (*OrderAfterSaleResellerBatchResponse, error) { return Call(ctx, cfg, in.OrderNumber) } t, err := toolutils.InferTool("zltxOrderAfterSaleResellerBatch", "直连天下下游分销商批充订单售后工具", run) diff --git a/internal/domain/tools/zltx/order_after_reseller_batch/types.go b/internal/domain/tools/zltx/order_after_reseller_batch/types.go index 9d2a071..3e35115 100644 --- a/internal/domain/tools/zltx/order_after_reseller_batch/types.go +++ b/internal/domain/tools/zltx/order_after_reseller_batch/types.go @@ -1,14 +1,14 @@ package order_after_reseller_batch type OrderAfterSaleResellerBatchResponse struct { - Code int `json:"code"` - Error string `json:"error"` - Data OrderAfterSaleResellerBatchData `json:"data"` + Code int `json:"code"` + Error string `json:"error"` + Data *OrderAfterSaleResellerBatchData `json:"data"` } type OrderAfterSaleResellerBatchData struct { - Data []OrderAfterSaleResellerBatchBase `json:"data"` - ExtData map[string]OrderAfterSaleResellerBatchExtItem `json:"extraData"` + Data []*OrderAfterSaleResellerBatchBase `json:"data"` + ExtData map[string]*OrderAfterSaleResellerBatchExtItem `json:"extraData"` } type OrderAfterSaleResellerBatchBase struct { diff --git a/internal/domain/workflow/runtime/registry.go b/internal/domain/workflow/runtime/registry.go index 1260173..8ae7a82 100644 --- a/internal/domain/workflow/runtime/registry.go +++ b/internal/domain/workflow/runtime/registry.go @@ -2,6 +2,7 @@ package runtime import ( "ai_scheduler/internal/config" + "ai_scheduler/internal/entitys" "ai_scheduler/internal/pkg/utils_ollama" "context" "errors" @@ -11,7 +12,7 @@ import ( type Workflow interface { ID() string Schema() map[string]any - Invoke(ctx context.Context, input map[string]any) (map[string]any, error) + Invoke(ctx context.Context, requireData *entitys.RequireData) (map[string]any, error) } type Deps struct { @@ -62,10 +63,7 @@ func Default() *Registry { return r } -func (r *Registry) Invoke(ctx context.Context, id string, input map[string]any) (map[string]any, error) { - if input == nil { - input = map[string]any{} - } +func (r *Registry) Invoke(ctx context.Context, id string, requireData *entitys.RequireData) (map[string]any, error) { regMu.RLock() f, ok := factories[id] regMu.RUnlock() @@ -91,5 +89,5 @@ func (r *Registry) Invoke(ctx context.Context, id string, input map[string]any) r.mu.Unlock() } - return w.Invoke(ctx, input) + return w.Invoke(ctx, requireData) } diff --git a/internal/domain/workflow/zltx/order_after_reseller_batch.go b/internal/domain/workflow/zltx/order_after_reseller_batch.go index e00b726..4d151d2 100644 --- a/internal/domain/workflow/zltx/order_after_reseller_batch.go +++ b/internal/domain/workflow/zltx/order_after_reseller_batch.go @@ -2,12 +2,17 @@ package zltx import ( "ai_scheduler/internal/config" + "ai_scheduler/internal/data/model" toolZoarb "ai_scheduler/internal/domain/tools/zltx/order_after_reseller_batch" "ai_scheduler/internal/domain/workflow/runtime" + "ai_scheduler/internal/entitys" + "ai_scheduler/internal/pkg/util" "context" + "encoding/json" "errors" "github.com/cloudwego/eino/compose" + "github.com/cloudwego/eino/schema" ) func init() { @@ -17,7 +22,57 @@ func init() { } type orderAfterSaleResellerBatch struct { - cfg config.ToolConfig + cfg config.ToolConfig + data *OrderAfterSaleResellerBatchWorkflowInput +} + +// 工作流入参 +type OrderAfterSaleResellerBatchWorkflowInput struct { + Ch chan entitys.Response // 响应通道 + UserInput string // 用户输入文本 + FileContent string // 文件解析结果 + UserHistory []model.AiChatHi // 用户对话历史 + ParameterResult string // 参数解析结果 + Data *OrderAfterSaleResellerBatchNodeData // 节点所需参数 +} + +// 节点所需参数 +type OrderAfterSaleResellerBatchNodeData struct { + OrderNumber []string `json:"orderNumber"` // 订单号 + AfterType string `json:"afterType"` // 处理方式 1.退款 2.扣款 + AfterSalesPrice string `json:"afterSalesPrice"` // 售后金额 + AfterSalesReason string `json:"afterSalesReason"` // 售后原因 + ResponsibleType string `json:"responsibleType"` // 费用承担者 1.供应商 2.商务 3.公司 4.无 + ResponsiblePerson string `json:"responsiblePerson"` // 费用承担供应商 +} + +// 工作流出参 +type OrderAfterSaleResellerBatchWorkflowOutput struct { + Code int `json:"code"` + Msg string `json:"msg"` + Data []*OrderAfterSaleResellerBatchData `json:"data"` +} + +type OrderAfterSaleResellerBatchData struct { + OrderType int `json:"orderType"` + OrderNumber string `json:"orderNumber"` + OrderAmount float64 `json:"orderAmount"` + OrderPrice float64 `json:"orderPrice"` + SignCompany int `json:"signCompany"` + OrderQuantity int `json:"orderQuantity"` + ResellerID int `json:"resellerId"` + ResellerName string `json:"resellerName"` + OurProductID int `json:"ourProductId"` + OurProductTitle string `json:"ourProductTitle"` + Account []string `json:"account"` + Platforms map[int]string `json:"platforms"` + AfterType int `json:"afterType"` // 处理方式 1.退款 2.扣款 + Remark string `json:"remark"` // 售后原因 + AfterAmount float64 `json:"afterAmount"` // 售后金额 + ResponsibleType int `json:"responsibleType"` // 费用承担者 1.供应商 2.商务 3.公司 4.无 + ResponsiblePerson string `json:"responsiblePerson"` // 费用承担供应商 + IsExistsAfterSale bool `json:"isExistsAfterSale"` // 是否已存在售后 + CreateTime int `json:"createTime"` // 创建时间 } // ID 返回工作流唯一标识 @@ -33,18 +88,22 @@ func (o *orderAfterSaleResellerBatch) Schema() map[string]any { } // Invoke 调用原有编排工作流并规范化输出 -func (o *orderAfterSaleResellerBatch) Invoke(ctx context.Context, input map[string]any) (map[string]any, error) { +func (o *orderAfterSaleResellerBatch) Invoke(ctx context.Context, requireData *entitys.RequireData) (map[string]any, error) { // 构建工作流 chain, err := o.buildWorkflow(ctx) if err != nil { return nil, err } - var in OrderAfterSaleResellerBatchWorkflowInput - if v, ok := input["orderNumber"].([]string); ok { - in.OrderNumber = v + o.data = &OrderAfterSaleResellerBatchWorkflowInput{ + Ch: requireData.Ch, + UserInput: requireData.Req.Text, + FileContent: "", + UserHistory: requireData.Histories, + ParameterResult: requireData.Match.Parameters, } - _, err = chain.Invoke(ctx, in) + // 工作流过程输出,不关注最终输出 + _, err = chain.Invoke(ctx, o.data) if err != nil { return nil, err } @@ -55,35 +114,117 @@ func (o *orderAfterSaleResellerBatch) Invoke(ctx context.Context, input map[stri return nil, nil } -type OrderAfterSaleResellerBatchWorkflowInput struct { - OrderNumber []string `json:"orderNumber"` -} - var ErrInvalidOrderNumbers = errors.New("orderNumber 不能为空") // buildWorkflow 构建工作流 -func (o *orderAfterSaleResellerBatch) buildWorkflow(ctx context.Context) (compose.Runnable[OrderAfterSaleResellerBatchWorkflowInput, toolZoarb.OrderAfterSaleResellerBatchData], error) { +func (o *orderAfterSaleResellerBatch) buildWorkflow(ctx context.Context) (compose.Runnable[*OrderAfterSaleResellerBatchWorkflowInput, *OrderAfterSaleResellerBatchWorkflowOutput], error) { // 定义工作流、出入参 - c := compose.NewChain[OrderAfterSaleResellerBatchWorkflowInput, toolZoarb.OrderAfterSaleResellerBatchData]() + c := compose.NewChain[*OrderAfterSaleResellerBatchWorkflowInput, *OrderAfterSaleResellerBatchWorkflowOutput]() - // 1.入参解析与校验 - c.AppendLambda(compose.InvokableLambda(func(_ context.Context, in OrderAfterSaleResellerBatchWorkflowInput) (OrderAfterSaleResellerBatchWorkflowInput, error) { + // 1.llm 推断参数 (若需要) + c.AppendLambda(compose.InvokableLambda(func(_ context.Context, in *OrderAfterSaleResellerBatchWorkflowInput) (*schema.Message, error) { + // 已推断完,直接使用 + parameters := in.ParameterResult + return &schema.Message{Content: parameters}, nil + })) + + // 2.参数解析为结构体 + c.AppendLambda(compose.MessageParser( + schema.NewMessageJSONParser[*OrderAfterSaleResellerBatchNodeData](&schema.MessageJSONParseConfig{ + ParseFrom: schema.MessageParseFromContent, + ParseKeyPath: "", // 如果仅需要 parse 子字段,可用 "key.sub.grandsub" + }), + )) + + // 3.参数校验 + c.AppendLambda(compose.InvokableLambda(func(_ context.Context, in *OrderAfterSaleResellerBatchNodeData) (*OrderAfterSaleResellerBatchNodeData, error) { + // 校验必填项 if len(in.OrderNumber) == 0 { - return OrderAfterSaleResellerBatchWorkflowInput{}, ErrInvalidOrderNumbers + return nil, ErrInvalidOrderNumbers } + + o.data.Data = in + return in, nil })) - // 2.调用工具 - c.AppendLambda(compose.InvokableLambda(func(ctx context.Context, in OrderAfterSaleResellerBatchWorkflowInput) (toolZoarb.OrderAfterSaleResellerBatchData, error) { - return toolZoarb.Call(ctx, o.cfg, in.OrderNumber) + // 4.工具调用 + c.AppendLambda(compose.InvokableLambda(func(_ context.Context, in *OrderAfterSaleResellerBatchNodeData) (*toolZoarb.OrderAfterSaleResellerBatchResponse, error) { + entitys.ResLoading(o.data.Ch, o.ID(), "数据拉取中") + + toolRes, err := toolZoarb.Call(ctx, o.cfg, in.OrderNumber) + + entitys.ResLog(o.data.Ch, o.ID(), "数据拉取完成") + + return toolRes, err })) - // 3.结果映射与整形 - c.AppendLambda(compose.InvokableLambda(func(_ context.Context, in toolZoarb.OrderAfterSaleResellerBatchData) (toolZoarb.OrderAfterSaleResellerBatchData, error) { - return in, nil - })) + // 5.结果数据映射 + c.AppendLambda(compose.InvokableLambda(o.dataMapping)) // 编译工作流 return c.Compile(ctx) } + +// 结果数据映射 +func (o *orderAfterSaleResellerBatch) dataMapping(_ context.Context, in *toolZoarb.OrderAfterSaleResellerBatchResponse) (*OrderAfterSaleResellerBatchWorkflowOutput, error) { + entitys.ResLog(o.data.Ch, o.ID(), "数据整理中") + + toolResp := &OrderAfterSaleResellerBatchWorkflowOutput{ + Code: in.Code, + Msg: in.Error, + Data: make([]*OrderAfterSaleResellerBatchData, 0, len(in.Data.Data)), + } + + // 转换数据 + for _, item := range in.Data.Data { + // 处理方式 + afterType := util.StringToInt(o.data.Data.AfterType) + if afterType == 0 { + afterType = 1 // 默认退款 + } + // 费用承担者 + responsibleType := util.StringToInt(o.data.Data.ResponsibleType) + if responsibleType == 0 { + responsibleType = 4 // 默认无 + } + // 售后金额 + afterSalesPrice := util.StringToFloat64(o.data.Data.AfterSalesPrice) + if afterSalesPrice == 0 { + afterSalesPrice = item.OrderPrice + } + + toolResp.Data = append(toolResp.Data, &OrderAfterSaleResellerBatchData{ + OrderType: item.OrderType, + OrderNumber: item.OrderNumber, + OrderAmount: item.OrderAmount, + OrderPrice: item.OrderPrice, + SignCompany: item.SignCompany, + OrderQuantity: item.OrderQuantity, + ResellerID: item.ResellerID, + ResellerName: item.ResellerName, + OurProductID: item.OurProductID, + OurProductTitle: item.OurProductTitle, + Account: item.Account, + Platforms: item.Platforms, + AfterType: afterType, + Remark: o.data.Data.AfterSalesReason, + AfterAmount: afterSalesPrice, + ResponsibleType: responsibleType, + ResponsiblePerson: o.data.Data.ResponsiblePerson, + }) + } + + // 追加扩展数据 + for _, item := range toolResp.Data { + if extItem, ok := in.Data.ExtData[item.OrderNumber]; ok { + item.IsExistsAfterSale = item.OrderType > 100 // 102 批充&已售后 + item.CreateTime = extItem.SerialCreateTime + } + } + + toolRespJson, _ := json.Marshal(toolResp) + entitys.ResJson(o.data.Ch, o.ID(), string(toolRespJson)) + + return toolResp, nil +} From fc90dc5b38fd5f2f6c63d1645edb9d8303ea5d93 Mon Sep 17 00:00:00 2001 From: wolter <11@gmail> Date: Mon, 15 Dec 2025 15:48:57 +0800 Subject: [PATCH 38/66] =?UTF-8?q?feat:=20=E5=BF=AB=E9=80=92=E6=9F=A5?= =?UTF-8?q?=E8=AF=A2=E4=B8=8E=E4=BC=81=E4=B8=9A=E6=9F=A5=E8=AF=A2=E5=B7=A5?= =?UTF-8?q?=E5=85=B7?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- config/config_test.yaml | 10 + go.mod | 4 +- go.sum | 6 + internal/config/config.go | 4 + internal/tools/manager.go | 10 + internal/tools/public/coze_company.go | 264 ++++++++++++++++++++++++++ internal/tools/public/coze_express.go | 140 ++++++++++++++ 7 files changed, 437 insertions(+), 1 deletion(-) create mode 100644 internal/tools/public/coze_company.go create mode 100644 internal/tools/public/coze_express.go diff --git a/config/config_test.yaml b/config/config_test.yaml index 85a7eb1..0ef429d 100644 --- a/config/config_test.yaml +++ b/config/config_test.yaml @@ -75,6 +75,16 @@ tools: enabled: true base_url: "https://restapi.amap.com/v3/weather/weatherInfo" api_key: "12afbde5ab78cb7e575ff76bd0bdef2b" + cozeExpress: + enabled: true + base_url: "https://api.coze.cn" + api_key: "7582477438102552616" + api_secret: "pat_eEN0BdLNDughEtABjJJRYTW71olvDU0qUbfQUeaPc2NnYWO8HeyNoui5aR9z0sSZ" + cozeCompany: + enabled: true + base_url: "https://api.coze.cn" + api_key: "7583905168607100978" + api_secret: "pat_eEN0BdLNDughEtABjJJRYTW71olvDU0qUbfQUeaPc2NnYWO8HeyNoui5aR9z0sSZ" diff --git a/go.mod b/go.mod index 0f01662..a398d0d 100644 --- a/go.mod +++ b/go.mod @@ -51,6 +51,7 @@ require ( github.com/clbanning/mxj/v2 v2.5.5 // indirect github.com/cloudwego/base64x v0.1.6 // indirect github.com/cloudwego/eino-ext/libs/acl/openai v0.1.2 // indirect + github.com/coze-dev/coze-go v0.0.0-20251029161603-312b7fd62d20 // indirect github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect github.com/dlclark/regexp2 v1.11.4 // indirect github.com/dustin/go-humanize v1.0.1 // indirect @@ -59,8 +60,9 @@ require ( github.com/fasthttp/websocket v1.5.3 // indirect github.com/fsnotify/fsnotify v1.6.0 // indirect github.com/go-sql-driver/mysql v1.8.1 // indirect + github.com/golang-jwt/jwt/v5 v5.2.2 // indirect github.com/goph/emperror v0.17.2 // indirect - github.com/gorilla/websocket v1.5.0 // indirect + github.com/gorilla/websocket v1.5.3 // indirect github.com/hashicorp/hcl v1.0.0 // indirect github.com/jinzhu/inflection v1.0.0 // indirect github.com/jinzhu/now v1.1.5 // indirect diff --git a/go.sum b/go.sum index c1b06f6..b7f99b7 100644 --- a/go.sum +++ b/go.sum @@ -139,6 +139,8 @@ github.com/cloudwego/eino-ext/libs/acl/openai v0.1.2/go.mod h1:S4OkvglPY9hsm9tXe github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc= github.com/cncf/udpa/go v0.0.0-20200629203442-efcf912fb354/go.mod h1:WmhPx2Nbnhtbo57+VJT5O0JRkEi1Wbu0z5j0R8u5Hbk= github.com/cncf/udpa/go v0.0.0-20201120205902-5459f2c99403/go.mod h1:WmhPx2Nbnhtbo57+VJT5O0JRkEi1Wbu0z5j0R8u5Hbk= +github.com/coze-dev/coze-go v0.0.0-20251029161603-312b7fd62d20 h1:m6P88V9lLrxZsE7uj9otq7l7nqDuCSAJ86KhzRlWf0M= +github.com/coze-dev/coze-go v0.0.0-20251029161603-312b7fd62d20/go.mod h1:wdT5CFt/sFsWz9hna2Z7DWzUra9spx0SoX1PUZyoSB0= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= @@ -200,6 +202,8 @@ github.com/gofiber/fiber/v2 v2.52.9/go.mod h1:YEcBbO/FB+5M1IZNBP9FO3J9281zgPArei github.com/gofiber/websocket/v2 v2.2.1 h1:C9cjxvloojayOp9AovmpQrk8VqvVnT8Oao3+IUygH7w= github.com/gofiber/websocket/v2 v2.2.1/go.mod h1:Ao/+nyNnX5u/hIFPuHl28a+NIkrqK7PRimyKaj4JxVU= github.com/gofrs/uuid v3.2.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM= +github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeDy8= +github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= github.com/golang/groupcache v0.0.0-20190702054246-869f871628b6/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= github.com/golang/groupcache v0.0.0-20191227052852-215e87163ea7/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= @@ -269,6 +273,8 @@ github.com/gopherjs/gopherjs v1.17.2 h1:fQnZVsXk8uxXIStYb0N4bGk7jeyTalG/wsZjQ25d github.com/gopherjs/gopherjs v1.17.2/go.mod h1:pRRIvn/QzFLrKfvEz3qUuEhtE/zLCWfreZ6J5gM2i+k= github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc= github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= +github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg= +github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/hashicorp/golang-lru v0.5.0/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= github.com/hashicorp/golang-lru v0.5.1/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= github.com/hashicorp/hcl v1.0.0 h1:0Anlzjpi4vEasTeNFn2mLJgTSwt0+6sfsiTG8qcWGx4= diff --git a/internal/config/config.go b/internal/config/config.go index 160275f..df79ed0 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -105,6 +105,10 @@ type ToolsConfig struct { ZltxOrderAfterSaleReseller ToolConfig `mapstructure:"zltxOrderAfterSaleReseller"` // 下游批充订单售后 ZltxOrderAfterSaleResellerBatch ToolConfig `mapstructure:"zltxOrderAfterSaleResellerBatch"` + // Coze 快递查询工具 + CozeExpress ToolConfig `mapstructure:"cozeExpress"` + // Coze 公司查询工具 + CozeCompany ToolConfig `mapstructure:"cozeCompany"` } // ToolConfig 单个工具配置 diff --git a/internal/tools/manager.go b/internal/tools/manager.go index 87212b7..7025d8c 100644 --- a/internal/tools/manager.go +++ b/internal/tools/manager.go @@ -92,6 +92,16 @@ func NewManager(config *config.Config, llm *utils_ollama.Client) *Manager { weatherTool := public.NewWeatherTool(config.Tools.Weather) m.tools[weatherTool.Name()] = weatherTool } + // 注册 Coze 快递查询工具 + if config.Tools.CozeExpress.Enabled { + cozeTool := public.NewCozeExpress(config.Tools.CozeExpress, m.llm) + m.tools[cozeTool.Name()] = cozeTool + } + // 注册 Coze 公司查询工具 + if config.Tools.CozeCompany.Enabled { + cozeTool := public.NewCozeCompany(config.Tools.CozeCompany, m.llm) + m.tools[cozeTool.Name()] = cozeTool + } // 普通对话 chat := public.NewNormalChatTool(m.llm, config) diff --git a/internal/tools/public/coze_company.go b/internal/tools/public/coze_company.go new file mode 100644 index 0000000..3f10638 --- /dev/null +++ b/internal/tools/public/coze_company.go @@ -0,0 +1,264 @@ +package public + +import ( + "ai_scheduler/internal/config" + "ai_scheduler/internal/entitys" + "ai_scheduler/internal/pkg/utils_ollama" + "context" + "encoding/json" + "fmt" + "github.com/ollama/ollama/api" + "net/http" + "time" + + "github.com/coze-dev/coze-go" +) + +type CozeCompany struct { + cozeApi coze.CozeAPI + config config.ToolConfig + llm *utils_ollama.Client +} + +// NewCoze 创建 Coze 实例 +func NewCozeCompany(config config.ToolConfig, llm *utils_ollama.Client) *CozeCompany { + return &CozeCompany{ + cozeApi: newCozeApi(config), + config: config, + llm: llm, + } +} + +// newCozeClient 创建 Coze 客户端 +func newCozeApi(config config.ToolConfig) coze.CozeAPI { + authCli := coze.NewTokenAuth(config.APISecret) + cozeApi := coze.NewCozeAPI(authCli, coze.WithBaseURL(config.BaseURL), coze.WithHttpClient(&http.Client{ + Timeout: time.Second * 120, + })) + return cozeApi +} + +// Name 返回工具名称 +func (c *CozeCompany) Name() string { + return "coze_company" +} + +// Description 返回工具描述 +func (c *CozeCompany) Description() string { + return "查询企业信息" +} + +// Definition 返回工具定义 +func (c *CozeCompany) Definition() entitys.ToolDefinition { + return entitys.ToolDefinition{ + Type: "function", + Function: entitys.FunctionDef{ + Name: c.Name(), + Description: c.Description(), + Parameters: map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "company_name": map[string]interface{}{ + "type": "string", + "description": "企业名称", + }, + }, + "required": []string{"company_name"}, + }, + }, + } +} + +// Execute 执行查询 +func (c *CozeCompany) Execute(ctx context.Context, requireData *entitys.RequireData) error { + var req map[string]interface{} + if err := json.Unmarshal([]byte(requireData.Match.Parameters), &req); err != nil { + return fmt.Errorf("invalid express request: %w", err) + } + + if req["company_name"] == "" { + return fmt.Errorf("company_name is required") + } + + // 调用 Coze 工作流 + rsp, err := c.callWorkflow(ctx, req) + if err != nil { + return fmt.Errorf("failed to get real weather: %w", err) + } + + companyInfo := CompanyInfo{} + err = json.Unmarshal([]byte(rsp.Data), &companyInfo) + if err != nil { + return fmt.Errorf("failed to unmarshal company info: %w", err) + } + + // 调用 LLM 模型 + err = c.llm.ChatStream(ctx, requireData.Ch, []api.Message{ + { + Role: "system", + Content: `# Role: 企业信息分析与经营诊断专家: +请基于以下12项指定数据字段(无需补充未提供的信息),完成目标企业的全维度分析总结,要求每部分结论必须100%锚定对应数据,拒绝主观推测,突出“风险可见性”与“关键信息关联性”: +一、输入数据清单(需逐一对应分析) +行政处罚:公司是否有行政处罚(含处罚事由、处罚机关、处罚日期、处罚文号) +清算信息:公司的清算信息(含清算原因、清算组构成、清算进展状态) +变更记录:公司的变更记录(含变更事项:注册资本/股东/法定代表人/经营范围、变更时间、变更前后内容) +主要成员:公司名搜索公司的主要成员(含姓名、职位、任职时间、核心履历关键词) +企业详情:公司名称,搜索公司详细信息(含成立时间、注册资本、经营范围、行业分类、注册地址) +经营异常:公司是否有经营异常(含列入原因、列入日期、移出状态) +破产重组:公司破产重组的信息(含申请法院、受理时间、重组方案核心内容) +严重违法:是否有严重违法信息(含违法事由、认定机关、公示期限) +司法信息:司法信息(含案件类型、原被告身份、判决结果/进展) +被执行信息:公司的被执行信息(含执行法院、执行标的、未履行金额、失信状态) +企业手机号:公司名称查询企业手机号(需标注是否为公开备案号) +股东信息:公司对应的股东(含股东名称、出资额、出资比例、股东类型:自然人/企业/机构) +二、分析总结框架(严格对应数据) +1. 企业基础画像(锚定公司名称,搜索公司详细信息) +核心属性:成立时间、注册资本(实缴/认缴)、行业分类(如“批发和零售业”“软件和信息技术服务业”)、主营业务(从经营范围中提炼1-2个核心赛道,如“专注于智能仓储设备研发与销售”); +注册地址:是否与主要经营地一致(若公司名称,搜索公司详细信息有披露)。 +2. 股权结构与股东特征(锚定公司对应的股东) +股权集中度:前三大股东出资占比之和(如“第一大股东持股60%,为绝对控股股东”); +股东类型分布:自然人股东、企业股东、机构股东的占比(如“70%为企业股东,含1家行业头部企业”); +关键股东亮点:若有知名企业/机构股东,需明确标注(如“股东含XX产业投资基金,具备产业链资源协同潜力”)。 +3. 经营稳定性与合规风险(锚定公司是否有行政处罚/公司是否有经营异常/是否有严重违法信息/公司的清算信息/公司破产重组的信息) +行政处罚风险:是否存在行政处罚? +若有:列清「处罚事由(如“虚假宣传”“税务逾期申报”)、处罚机关、处罚日期」,并判断“是否属于高频违规”(如1年内≥2次同类处罚→高风险); +若无:标注“无公开行政处罚记录”。 +经营异常风险:是否存在经营异常? +若有:说明「列入原因(如“通过登记的住所无法联系”“未公示年度报告”)、是否已移出」,并评估对业务的影响(如“地址失联可能导致客户信任下降”); +若无:标注“无经营异常记录”。 +严重违法红线:是否存在严重违法? +若有:明确「违法事由(如“拒不履行生效法律文书”“欺诈消费者”)、认定机关、公示期限」,并标注“触及监管红线,需重点核查整改情况”; +若无:标注“无严重违法记录”。 +极端状态预警:是否存在清算或破产重组? +若有清算:说明「清算原因(如“股东会决议解散”“营业期限届满”)、清算进展(如“清算组已完成资产清查”)」; +若有破产重组:说明「申请法院、受理时间、重组方案核心内容(如“拟引入战略投资者注资5000万”)」; +若无:标注“无清算或破产重组记录”。 +4. 司法与执行风险(锚定司法信息/公司的被执行信息) +涉诉情况:司法案件的核心特征(如“80%为买卖合同纠纷,20%为劳动争议仲裁”;“作为被告的案件占比75%”;“判决胜诉率约60%”); +被执行压力:是否存在被执行信息? +若有:列清「执行法院、执行标的金额、未履行金额、是否纳入失信被执行人名单」,并计算“未履行金额占注册资本的比例”(如“未履行800万,占注册资本的20%”); +若无:标注“无公开被执行记录”。 +5. 管理团队与联系信息(锚定公司名搜索公司的主要成员/公司名称查询企业手机号) +核心团队稳定性:主要成员的任职时间分布(如“CEO任职4年,CFO任职2年,核心团队平均任职3年”);是否有频繁变动(如“1年内2位高管离职”→需提示“管理稳定性风险”); +关键岗位资质:核心成员(如法定代表人、CEO、CFO)的过往履历亮点(如“CEO曾任职XX上市公司,主导过亿元级项目落地”); +联系信息可信度:企业手机号是否为公开备案(如“与工商登记预留电话一致→可信度高”;“为非备案号→需提示‘联系信息真实性存疑’”)。 +6. 综合结论与行动建议(整合所有数据) +整体风险评级:基于数据密度,给出“低风险/中风险/高风险”定性(示例:“无行政处罚、无被执行、仅1条普通合同纠纷→低风险”;“有失信被执行+严重违法→高风险”); +Top3核心风险:按“严重违法>被执行>破产重组>行政处罚>经营异常>司法纠纷”排序,列出最需关注的3个问题(如“1. 未履行金额占注册资本20%,存在债务违约风险;2. 1年内2次地址异常,经营稳定性弱;3. 自然人股东占比过高,决策易受个人因素影响”); + actionable 建议:针对每个核心风险,给出可执行的核查/应对动作(如“核查未履行金额对应的案件进展,评估企业偿债能力;要求企业提供近1年的地址证明,确认经营场所稳定性;穿透核查自然人股东的资产状况,降低决策风险”)。 +三、输出规则(强制遵守) +数据溯源:每句结论必须标注对应数据字段(如“根据公司是否有行政处罚,企业2023年因‘税务逾期申报’被区税务局处罚”); +量化优先:拒绝模糊表述(如不说“很多案件”,要说“近1年涉及5起买卖合同纠纷”); +风险分级:用“★”标注风险等级(★越多越严重,如“严重违法★★★”“被执行★★”“经营异常★”); +语言风格:专业简洁,避免冗余,适合风控/投资/合作前的快速决策阅读。 + +【示例输出片段】 +3. 经营稳定性与合规风险 + +行政处罚:根据公司是否有行政处罚,企业2022年11月因“发布虚假广告”被区市场监管局处以3万元罚款(文号:X市监罚字〔2022〕456号),无后续同类处罚→风险等级★; + +经营异常:根据公司是否有经营异常,企业2023年6月因“通过登记的住所无法联系”被列入异常名录,2023年12月已移出→风险等级★; + +严重违法:根据是否有严重违法信息,无公开严重违法记录→风险等级无; + +极端状态:根据公司的清算信息/公司破产重组的信息,无清算或破产重组记录→风险等级无。 + +6. 综合结论与行动建议 + +整体评级:低风险(仅1次轻微行政处罚,无重大合规瑕疵); + +Top3核心风险:1. 根据司法信息,近1年作为被告的合同纠纷占比75%,需警惕应收账款回收风险★★;2. 根据公司名搜索公司的主要成员,1年内2位销售总监离职,管理稳定性弱★;3. 根据公司名称查询企业手机号,企业手机号为非备案号,联系信息真实性存疑★; + +建议:核查合同纠纷案件的原告身份及回款情况,评估坏账概率;要求企业提供离职人员的交接说明,确认业务连续性;索要企业备案的联系方式,验证沟通有效性。 +`, + }, + + { + Role: "assistant", + Content: fmt.Sprintf(`请分析企业:%s, +公司是否有行政处罚:%v, +公司的清算信息:%v, +公司的变更记录:%v, +公司名搜索公司的主要成员:%v, +公司名称,搜索公司详细信息:%v, +公司是否有经营异常:%v, +公司破产重组的信息:%v, +是否有严重违法信息:%v, +司法信息:%v, +公司的被执行信息:%v, +公司名称查询企业手机号:%v, +公司对应的股东:%v`, req["company_name"], companyInfo.Xzcf, companyInfo.Clears, companyInfo.Changes, companyInfo.Employees, companyInfo.Searchdata, companyInfo.Operations, companyInfo.BankruptcyPublicList, companyInfo.Illegals, companyInfo.JudicialList, companyInfo.Executes, companyInfo.Phone, companyInfo.Partners), + }, + { + Role: "user", + Content: requireData.Req.Text, + }, + }, + c.Name(), "") + if err != nil { + return fmt.Errorf("failed to get express info: %w", err) + } + + //entitys.ResText(requireData.Ch, "", rsp.Data) + + return nil +} + +// CallWorkflow 调用 Coze 工作流 +// 参数: +// - ctx: 上下文,用于控制超时和取消 +// - workflowId: 工作流 ID +// - params: 工作流参数 +// 返回: +// - interface{}: 工作流执行结果 +// - error: 错误信息 +func (c *CozeCompany) callWorkflow(ctx context.Context, params map[string]interface{}) (*coze.RunWorkflowsResp, error) { + // 准备工作流请求参数 + workflowReq := &coze.RunWorkflowsReq{ + WorkflowID: c.config.APIKey, + Parameters: params, + } + + // 调用工作流 + resp, err := c.cozeApi.Workflows.Runs.Create(ctx, workflowReq) + if err != nil { + return nil, fmt.Errorf("工作流调用失败: %w", err) + } + + // 处理工作流响应 + if resp == nil { + return nil, fmt.Errorf("工作流响应为空") + } + + // 返回工作流执行结果 + return resp, nil +} + +type CompanyInfo struct { + BankruptcyPublicList interface{} `json:"bankruptcy_public_list"` // 破产公示列表 + Changes interface{} `json:"changes"` // 变更记录 + Clears interface{} `json:"clears"` // 清算记录 + Employees interface{} `json:"employees"` // 员工列表 + Executes interface{} `json:"executes"` // 执行记录 + Illegals interface{} `json:"illegals"` // 违法记录 + JudicialList interface{} `json:"judicial_list"` // 司法记录 + Operations interface{} `json:"operations"` // 经营记录 + Partners interface{} `json:"partners"` // 合伙人列表 + Phone string `json:"phone"` // 联系电话 + // 搜索数据 + Searchdata struct { + Authority interface{} `json:"authority"` + BusinessScope interface{} `json:"business_scope"` + Capital interface{} `json:"capital"` + CompanyAddress interface{} `json:"company_address"` + CompanyName string `json:"company_name"` + CompanyStatus interface{} `json:"company_status"` + CompanyType interface{} `json:"company_type"` + CreditNo interface{} `json:"credit_no"` + EstablishDate interface{} `json:"establish_date"` + Industry interface{} `json:"industry"` + LegalPerson interface{} `json:"legal_person"` + Province interface{} `json:"province"` + } `json:"searchdata"` + Xzcf interface{} `json:"xzcf"` // 行政处罚 +} diff --git a/internal/tools/public/coze_express.go b/internal/tools/public/coze_express.go new file mode 100644 index 0000000..de316b5 --- /dev/null +++ b/internal/tools/public/coze_express.go @@ -0,0 +1,140 @@ +package public + +import ( + "ai_scheduler/internal/config" + "ai_scheduler/internal/entitys" + "ai_scheduler/internal/pkg" + "ai_scheduler/internal/pkg/utils_ollama" + "context" + "encoding/json" + "fmt" + "github.com/ollama/ollama/api" + + "github.com/coze-dev/coze-go" +) + +type CozeExpress struct { + cozeApi coze.CozeAPI + config config.ToolConfig + llm *utils_ollama.Client +} + +// NewCozeExpress 创建 CozeExpress 实例 +func NewCozeExpress(config config.ToolConfig, llm *utils_ollama.Client) *CozeExpress { + return &CozeExpress{ + cozeApi: newCozeApi(config), + config: config, + llm: llm, + } +} + +// newCozeExpressClient 创建 CozeExpress 客户端 +func newCozeExpressApi(config config.ToolConfig) coze.CozeAPI { + authCli := coze.NewTokenAuth(config.APISecret) + cozeApi := coze.NewCozeAPI(authCli, coze.WithBaseURL(config.BaseURL)) + return cozeApi +} + +// Name 返回工具名称 +func (c *CozeExpress) Name() string { + return "coze_express" +} + +// Description 返回工具描述 +func (c *CozeExpress) Description() string { + return "查询快递物流信息" +} + +// Definition 返回工具定义 +func (c *CozeExpress) Definition() entitys.ToolDefinition { + return entitys.ToolDefinition{ + Type: "function", + Function: entitys.FunctionDef{ + Name: c.Name(), + Description: c.Description(), + Parameters: map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "express_id": map[string]interface{}{ + "type": "string", + "description": "快递单号", + }, + }, + "required": []string{"express_id"}, + }, + }, + } +} + +// Execute 执行查询 +func (c *CozeExpress) Execute(ctx context.Context, requireData *entitys.RequireData) error { + var req map[string]interface{} + if err := json.Unmarshal([]byte(requireData.Match.Parameters), &req); err != nil { + return fmt.Errorf("invalid express request: %w", err) + } + + if req["express_id"] == "" { + return fmt.Errorf("express_id is required") + } + + // 调用 Coze 工作流查询快递物流信息 + rsp, err := c.callWorkflow(ctx, req) + if err != nil { + return fmt.Errorf("failed to get real weather: %w", err) + } + err = c.llm.ChatStream(ctx, requireData.Ch, []api.Message{ + { + Role: "system", + Content: "你是一个快递查询助手。用户可能会提供快递单号,你需要分析快递单号,根据快递单号查询物流信息并反馈给我", + }, + { + Role: "assistant", + Content: fmt.Sprintf("聊天记录:%s", pkg.JsonStringIgonErr(requireData.Histories)), + }, + { + Role: "assistant", + Content: fmt.Sprintf("需要分析的快递单号:%s", rsp.Data), + }, + { + Role: "user", + Content: requireData.Req.Text, + }, + }, c.Name(), "") + if err != nil { + return fmt.Errorf("failed to get express info: %w", err) + } + + //entitys.ResText(requireData.Ch, "", rsp.Data) + + return nil +} + +// CallWorkflow 调用 Coze 工作流 +// 参数: +// - ctx: 上下文,用于控制超时和取消 +// - workflowId: 工作流 ID +// - params: 工作流参数 +// 返回: +// - interface{}: 工作流执行结果 +// - error: 错误信息 +func (c *CozeExpress) callWorkflow(ctx context.Context, params map[string]interface{}) (*coze.RunWorkflowsResp, error) { + // 准备工作流请求参数 + workflowReq := &coze.RunWorkflowsReq{ + WorkflowID: c.config.APIKey, + Parameters: params, + } + + // 调用工作流 + resp, err := c.cozeApi.Workflows.Runs.Create(ctx, workflowReq) + if err != nil { + return nil, fmt.Errorf("工作流调用失败: %w", err) + } + + // 处理工作流响应 + if resp == nil { + return nil, fmt.Errorf("工作流响应为空") + } + + // 返回工作流执行结果 + return resp, nil +} From 384eddcf04e292124144fc76aefee7500dbe32cc Mon Sep 17 00:00:00 2001 From: fuzhongyun <15339891972@163.com> Date: Mon, 15 Dec 2025 16:52:02 +0800 Subject: [PATCH 39/66] fix: go1.24.7 --- Dockerfile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Dockerfile b/Dockerfile index 2bd6192..8b02af6 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,5 +1,5 @@ ## 使用官方Go镜像作为构建环境 -FROM golang:1.24.1-alpine AS builder +FROM golang:1.24.7-alpine AS builder # 设置工作目录 WORKDIR /app From b409266751c8a8d8a79b067dccf40c36f4a4a484 Mon Sep 17 00:00:00 2001 From: renzhiyuan <465386466@qq.com> Date: Tue, 16 Dec 2025 11:09:53 +0800 Subject: [PATCH 40/66] refactor: optimize tools execution and authentication --- cmd/server/wire.go | 2 + internal/biz/ding_talk_bot.go | 212 +++++++++++++++++- internal/biz/do/ctx.go | 9 - internal/biz/do/handle.go | 128 ++++++----- internal/biz/handle/dingtalk/dept.go | 5 +- internal/biz/handle/dingtalk/user.go | 2 + internal/biz/router.go | 9 +- internal/biz/tools_regis/provider_set.go | 9 + internal/biz/tools_regis/tools_regis.go | 30 +++ internal/data/constants/bot.go | 8 + internal/data/constants/dingtalk.go | 13 ++ internal/data/impl/bot_group.go | 27 +++ internal/data/impl/bot_tools.go | 17 ++ internal/data/impl/bot_user.go | 10 +- internal/data/impl/provider_set.go | 2 + internal/data/model/ai_bot_dept.gen.go | 13 +- internal/data/model/ai_bot_group.gen.go | 27 +++ internal/data/model/ai_bot_tools.gen.go | 32 +++ internal/domain/workflow/runtime/registry.go | 6 +- .../zltx/order_after_reseller_batch.go | 12 +- internal/entitys/bot.go | 17 +- internal/entitys/dingtalk.go | 5 +- internal/entitys/recognize.go | 18 +- internal/entitys/response.go | 9 + internal/entitys/types.go | 2 +- internal/pkg/rec_extra/ext.go | 22 ++ internal/services/dtalk_bot.go | 51 +++-- internal/tools/manager.go | 73 +----- internal/tools/public/coze_company.go | 7 +- internal/tools/public/coze_express.go | 7 +- internal/tools/public/konwledge_base.go | 18 +- internal/tools/public/normal_chat.go | 12 +- internal/tools/public/weather.go | 8 +- internal/tools/zltx/order_after_reseller.go | 26 ++- .../tools/zltx/order_after_reseller_batch.go | 21 +- internal/tools/zltx/order_after_supplier.go | 28 ++- internal/tools/zltx/zltx_order_detail.go | 29 ++- internal/tools/zltx/zltx_order_direct_log.go | 18 +- internal/tools/zltx/zltx_product.go | 20 +- internal/tools/zltx/zltx_statistics.go | 17 +- internal/tools_bot/dtalk_bot.go | 27 --- 41 files changed, 680 insertions(+), 328 deletions(-) create mode 100644 internal/biz/tools_regis/provider_set.go create mode 100644 internal/biz/tools_regis/tools_regis.go create mode 100644 internal/data/impl/bot_group.go create mode 100644 internal/data/impl/bot_tools.go create mode 100644 internal/data/model/ai_bot_group.gen.go create mode 100644 internal/data/model/ai_bot_tools.gen.go create mode 100644 internal/pkg/rec_extra/ext.go delete mode 100644 internal/tools_bot/dtalk_bot.go diff --git a/cmd/server/wire.go b/cmd/server/wire.go index 79f3667..22d62b3 100644 --- a/cmd/server/wire.go +++ b/cmd/server/wire.go @@ -6,6 +6,7 @@ package main import ( "ai_scheduler/internal/biz" "ai_scheduler/internal/biz/handle/dingtalk" + "ai_scheduler/internal/biz/tools_regis" "ai_scheduler/internal/config" "ai_scheduler/internal/data/impl" "ai_scheduler/internal/domain/workflow" @@ -33,6 +34,7 @@ func InitializeApp(*config.Config, log.AllLogger) (*server.Servers, func(), erro utils.ProviderUtils, tools_bot.ProviderSetBotTools, dingtalk.ProviderSetDingTalk, + tools_regis.ProviderToolsRegis, )) } diff --git a/internal/biz/ding_talk_bot.go b/internal/biz/ding_talk_bot.go index c67c8e5..ca7431b 100644 --- a/internal/biz/ding_talk_bot.go +++ b/internal/biz/ding_talk_bot.go @@ -3,16 +3,23 @@ package biz import ( "ai_scheduler/internal/biz/do" "ai_scheduler/internal/biz/handle/dingtalk" + "ai_scheduler/internal/biz/tools_regis" "ai_scheduler/internal/data/constants" "ai_scheduler/internal/data/impl" "ai_scheduler/internal/data/model" "ai_scheduler/internal/entitys" + "ai_scheduler/internal/tools" + "context" + "database/sql" "encoding/json" + "errors" "fmt" + "strings" "github.com/gofiber/fiber/v2/log" "github.com/open-dingtalk/dingtalk-stream-sdk-go/chatbot" + "xorm.io/builder" ) @@ -24,6 +31,9 @@ type DingTalkBotBiz struct { replier *chatbot.ChatbotReplier log log.Logger dingTalkUser *dingtalk.User + botTools []model.AiBotTool + botGroupImpl *impl.BotGroupImpl + toolManager *tools.Manager } // NewDingTalkBotBiz @@ -31,8 +41,10 @@ func NewDingTalkBotBiz( do *do.Do, handle *do.Handle, botConfigImpl *impl.BotConfigImpl, + botGroupImpl *impl.BotGroupImpl, dingTalkUser *dingtalk.User, - + tools *tools_regis.ToolRegis, + toolManager *tools.Manager, ) *DingTalkBotBiz { return &DingTalkBotBiz{ do: do, @@ -40,6 +52,9 @@ func NewDingTalkBotBiz( botConfigImpl: botConfigImpl, replier: chatbot.NewChatbotReplier(), dingTalkUser: dingTalkUser, + botTools: tools.BootTools, + botGroupImpl: botGroupImpl, + toolManager: toolManager, } } @@ -68,18 +83,201 @@ func (d *DingTalkBotBiz) InitRequire(ctx context.Context, data *chatbot.BotCallb Req: data, Ch: make(chan entitys.Response, 2), } - entitys.ResLog(requireData.Ch, "recognize_start", "收到消息,正在处理中,请稍等") - - requireData.UserInfo, err = d.dingTalkUser.GetUserInfoFromBot(ctx, data.SenderStaffId, dingtalk.WithId(1)) return } -func (d *DingTalkBotBiz) Recognize(ctx context.Context, bot *chatbot.BotCallbackDataModel) (match entitys.Match, err error) { - - return d.handle.Recognize(ctx, nil, &do.WithDingTalkBot{}) +func (d *DingTalkBotBiz) Do(ctx context.Context, requireData *entitys.RequireDataDingTalkBot) (err error) { + entitys.ResText(requireData.Ch, "", "收到消息,正在处理中,请稍等") + defer close(requireData.Ch) + switch constants.ConversationType(requireData.Req.ConversationType) { + case constants.ConversationTypeSingle: + err = d.handleSingleChat(ctx, requireData) + case constants.ConversationTypeGroup: + err = d.handleGroupChat(ctx, requireData) + default: + err = errors.New("未知的聊天类型:" + requireData.Req.ConversationType) + } + return } +func (d *DingTalkBotBiz) handleSingleChat(ctx context.Context, requireData *entitys.RequireDataDingTalkBot) (err error) { + entitys.ResLog(requireData.Ch, "", "个人聊天暂未开启,请期待后续更新") + return + //requireData.UserInfo, err = d.dingTalkUser.GetUserInfoFromBot(ctx, requireData.Req.SenderStaffId, dingtalk.WithId(1)) + //if err != nil { + // return + //} + ////如果不是管理或者不是老板,则进行权限判断 + //if requireData.UserInfo.IsSenior == constants.IsSeniorFalse && requireData.UserInfo.IsBoss == constants.IsBossFalse { + // + //} + //return +} + +func (d *DingTalkBotBiz) handleGroupChat(ctx context.Context, requireData *entitys.RequireDataDingTalkBot) (err error) { + group, err := d.initGroup(ctx, requireData.Req.ConversationId, requireData.Req.ConversationTitle) + if err != nil { + return + } + groupTools, err := d.getGroupTools(ctx, group) + if err != nil { + return + } + rec, err := d.recognize(ctx, requireData, groupTools) + if err != nil { + return + } + + return d.handleMatch(ctx, rec) +} + +func (d *DingTalkBotBiz) initGroup(ctx context.Context, conversationId string, conversationTitle string) (group *model.AiBotGroup, err error) { + group, err = d.botGroupImpl.GetByConversationId(conversationId) + if err != nil { + if !errors.Is(err, sql.ErrNoRows) { + + return + } + } + + if group.GroupID == 0 { + group = &model.AiBotGroup{ + ConversationID: conversationId, + Title: conversationTitle, + ToolList: "", + } + //如果不存在则创建 + d.botGroupImpl.Add(group) + } + return +} + +func (d *DingTalkBotBiz) getGroupTools(ctx context.Context, group *model.AiBotGroup) (tools []model.AiBotTool, err error) { + if len(d.botTools) == 0 { + return + } + var ( + groupRegisTools map[string]struct{} + ) + if group.ToolList != "" { + groupList := strings.Split(group.ToolList, ",") + for _, tool := range groupList { + groupRegisTools[tool] = struct{}{} + } + } + + for _, v := range d.botTools { + if v.PermissionType == constants.PermissionTypeNone { + tools = append(tools, v) + continue + } + if _, ex := groupRegisTools[v.Index]; ex { + tools = append(tools, v) + } + } + return +} +func (d *DingTalkBotBiz) recognize(ctx context.Context, requireData *entitys.RequireDataDingTalkBot, tools []model.AiBotTool) (rec *entitys.Recognize, err error) { + + userContent, err := d.getUserContent(requireData.Req.Msgtype, requireData.Req.Text.Content) + if err != nil { + return nil, err + } + rec = &entitys.Recognize{ + Ch: requireData.Ch, + SystemPrompt: d.defaultPrompt(), + UserContent: userContent, + } + if len(tools) > 0 { + rec.Tasks = make([]entitys.RegistrationTask, 0, len(tools)) + for _, task := range tools { + taskConfig := entitys.TaskConfigDetail{} + if err = json.Unmarshal([]byte(task.Config), &taskConfig); err != nil { + log.Errorf("解析任务配置失败: %s, 任务ID: %s", err.Error(), task.Index) + continue // 解析失败时跳过该任务,而不是直接返回错误 + } + + rec.Tasks = append(rec.Tasks, entitys.RegistrationTask{ + Name: task.Index, + Desc: task.Desc, + TaskConfigDetail: taskConfig, // 直接使用解析后的配置,避免重复构建 + }) + } + } + err = d.handle.Recognize(ctx, rec, &do.WithDingTalkBot{}) + return +} + +func (d *DingTalkBotBiz) getUserContent(msgType string, msgContent interface{}) (content *entitys.RecognizeUserContent, err error) { + switch constants.BotMsgType(msgType) { + case constants.BotMsgTypeText: + content = &entitys.RecognizeUserContent{ + Text: msgContent.(string), + } + default: + return nil, errors.New("未知的消息类型:" + msgType) + } + return +} + +func (d *DingTalkBotBiz) defaultPrompt() string { + + return `{"system":"智能路由系统,精准解析用户意图并路由至任务模块,遵循以下规则:","rule":{"返回格式":"{\\"index\\":\\"工具索引\\",\\"confidence\\":\\"0.0-1.0\\",\\"reasoning\\":\\"判断理由\\",\\"parameters\\":\\"转义JSON参数\\",\\"is_match\\":true|false,\\"chat\\":\\"追问内容\\"}","工具匹配":["用工具parameters匹配,区分必选(required)和可选(optional)参数","无法匹配时,is_match=false,chat提醒用户适用工具(例:'请问您要查询订单还是商品?')"],"参数提取":["从用户输入提取parameters中明确提及的参数","必须参数仅用用户直接提及的,缺失时is_match=false,chat提醒补充(例:'需补充XX信息')"],"格式要求":["所有字段值为字符串(含confidence)","parameters为转义JSON字符串(如\\"{\\\\"key\\\\":\\\\"value\\\\"}\\")"]}}` +} + +func (d *DingTalkBotBiz) handleMatch(ctx context.Context, rec *entitys.Recognize) (err error) { + + if !rec.Match.IsMatch { + if len(rec.Match.Chat) != 0 { + entitys.ResText(rec.Ch, "", rec.Match.Chat) + } else { + entitys.ResText(rec.Ch, "", rec.Match.Reasoning) + } + return + } + var pointTask *model.AiBotTool + for _, task := range d.botTools { + if task.Index == rec.Match.Index { + pointTask = &task + + break + } + } + + if pointTask == nil || pointTask.Index == "other" { + return d.otherTask(ctx, rec) + } + switch constants.TaskType(pointTask.Type) { + //case constants.TaskTypeApi: + //return d.handleApiTask(ctx, requireData, pointTask) + case constants.TaskTypeFunc: + return d.handleTask(ctx, rec, pointTask) + default: + return d.otherTask(ctx, rec) + } + return +} + +func (d *DingTalkBotBiz) handleTask(ctx context.Context, rec *entitys.Recognize, task *model.AiBotTool) (err error) { + var configData entitys.ConfigDataTool + err = json.Unmarshal([]byte(task.Config), &configData) + if err != nil { + return + } + + err = d.toolManager.ExecuteTool(ctx, configData.Tool, rec) + if err != nil { + return + } + + return +} + +func (d *DingTalkBotBiz) otherTask(ctx context.Context, rec *entitys.Recognize) (err error) { + entitys.ResText(rec.Ch, "", rec.Match.Reasoning) + return +} func (d *DingTalkBotBiz) HandleRes(ctx context.Context, data *chatbot.BotCallbackDataModel, resp entitys.Response) error { switch resp.Type { case entitys.ResponseText: diff --git a/internal/biz/do/ctx.go b/internal/biz/do/ctx.go index 80deb73..4b5d2ec 100644 --- a/internal/biz/do/ctx.go +++ b/internal/biz/do/ctx.go @@ -225,15 +225,6 @@ func (d *Do) GetSysInfo(requireData *entitys.RequireData) (sysInfo model.AiSy, e return } -func (d *Do) GetSysInfoForDingTalkBot(requireData *entitys.RequireDataDingTalkBot) (sysInfo model.AiSy, err error) { - cond := builder.NewCond() - cond = cond.And(builder.Eq{"app_key": requireData.Auth}) - cond = cond.And(builder.IsNull{"delete_at"}) - cond = cond.And(builder.Eq{"status": 1}) - err = d.sysImpl.GetOneBySearchToStrut(&cond, &sysInfo) - return -} - func (d *Do) getSessionChatHis(requireData *entitys.RequireData) (his []model.AiChatHi, err error) { cond := builder.NewCond() diff --git a/internal/biz/do/handle.go b/internal/biz/do/handle.go index 8a9162d..76de3d0 100644 --- a/internal/biz/do/handle.go +++ b/internal/biz/do/handle.go @@ -12,10 +12,11 @@ import ( "ai_scheduler/internal/gateway" "ai_scheduler/internal/pkg/l_request" "ai_scheduler/internal/pkg/mapstructure" + "ai_scheduler/internal/pkg/rec_extra" "ai_scheduler/internal/pkg/util" "ai_scheduler/internal/tools" "ai_scheduler/internal/tools/public" - "ai_scheduler/internal/tools_bot" + "context" "encoding/json" "fmt" @@ -25,9 +26,9 @@ import ( ) type Handle struct { - Ollama *llm_service.OllamaService - toolManager *tools.Manager - Bot *tools_bot.BotTool + Ollama *llm_service.OllamaService + toolManager *tools.Manager + conf *config.Config sessionImpl *impl.SessionImpl workflowManager *runtime.Registry @@ -38,20 +39,20 @@ func NewHandle( toolManager *tools.Manager, conf *config.Config, sessionImpl *impl.SessionImpl, - dTalkBot *tools_bot.BotTool, + workflowManager *runtime.Registry, ) *Handle { return &Handle{ - Ollama: Ollama, - toolManager: toolManager, - conf: conf, - sessionImpl: sessionImpl, - Bot: dTalkBot, + Ollama: Ollama, + toolManager: toolManager, + conf: conf, + sessionImpl: sessionImpl, + workflowManager: workflowManager, } } -func (r *Handle) Recognize(ctx context.Context, rec *entitys.Recognize, promptProcessor PromptOption) (match entitys.Match, err error) { +func (r *Handle) Recognize(ctx context.Context, rec *entitys.Recognize, promptProcessor PromptOption) (err error) { entitys.ResLog(rec.Ch, "recognize_start", "准备意图识别") prompt, err := promptProcessor.CreatePrompt(ctx, rec) @@ -65,11 +66,13 @@ func (r *Handle) Recognize(ctx context.Context, rec *entitys.Recognize, promptPr } entitys.ResLog(rec.Ch, "recognize", recognizeMsg) entitys.ResLog(rec.Ch, "recognize_end", "意图识别结束") - + var match entitys.Match if err = json.Unmarshal([]byte(recognizeMsg), &match); err != nil { err = errors.SysErr("数据结构错误:%v", err.Error()) return } + rec.Match = &match + return } @@ -78,28 +81,27 @@ func (r *Handle) handleOtherTask(ctx context.Context, requireData *entitys.Requi return } -func (r *Handle) HandleMatch(ctx context.Context, client *gateway.Client, requireData *entitys.RequireData) (err error) { +func (r *Handle) HandleMatch(ctx context.Context, client *gateway.Client, rec *entitys.Recognize, requireData *entitys.RequireData) (err error) { - if !requireData.Match.IsMatch { - if len(requireData.Match.Chat) != 0 { - entitys.ResText(requireData.Ch, "", requireData.Match.Chat) + if !rec.Match.IsMatch { + if len(rec.Match.Chat) != 0 { + entitys.ResText(rec.Ch, "", rec.Match.Chat) } else { - entitys.ResText(requireData.Ch, "", requireData.Match.Reasoning) + entitys.ResText(rec.Ch, "", rec.Match.Reasoning) } return } var pointTask *model.AiTask for _, task := range requireData.Tasks { - if task.Index == requireData.Match.Index { + if task.Index == rec.Match.Index { pointTask = &task - requireData.Task = task break } } if pointTask == nil || pointTask.Index == "other" { - return r.OtherTask(ctx, requireData) + return r.OtherTask(ctx, rec) } // 校验用户权限 @@ -110,47 +112,32 @@ func (r *Handle) HandleMatch(ctx context.Context, client *gateway.Client, requir switch constants.TaskType(pointTask.Type) { case constants.TaskTypeApi: - return r.handleApiTask(ctx, requireData, pointTask) + return r.handleApiTask(ctx, rec, pointTask) case constants.TaskTypeFunc: - return r.handleTask(ctx, requireData, pointTask) + return r.handleTask(ctx, rec, pointTask) case constants.TaskTypeKnowle: - return r.handleKnowle(ctx, requireData, pointTask) - case constants.TaskTypeBot: - return r.handleBot(ctx, requireData, pointTask) + return r.handleKnowle(ctx, rec, pointTask) + case constants.TaskTypeEinoWorkflow: - return r.handleEinoWorkflow(ctx, requireData, pointTask) + return r.handleEinoWorkflow(ctx, rec, pointTask) default: return r.handleOtherTask(ctx, requireData) } } -func (r *Handle) OtherTask(ctx context.Context, requireData *entitys.RequireData) (err error) { +func (r *Handle) OtherTask(ctx context.Context, requireData *entitys.Recognize) (err error) { entitys.ResText(requireData.Ch, "", requireData.Match.Reasoning) return } -func (r *Handle) handleBot(ctx context.Context, requireData *entitys.RequireData, task *model.AiTask) (err error) { - var configData entitys.ConfigDataTool - err = json.Unmarshal([]byte(task.Config), &configData) - if err != nil { - return - } - err = r.Bot.Execute(ctx, configData.Tool, requireData) - if err != nil { - return - } - - return -} - -func (r *Handle) handleTask(ctx context.Context, requireData *entitys.RequireData, task *model.AiTask) (err error) { +func (r *Handle) handleTask(ctx context.Context, rec *entitys.Recognize, task *model.AiTask) (err error) { var configData entitys.ConfigDataTool err = json.Unmarshal([]byte(task.Config), &configData) if err != nil { return } - err = r.toolManager.ExecuteTool(ctx, configData.Tool, requireData) + err = r.toolManager.ExecuteTool(ctx, configData.Tool, rec) if err != nil { return } @@ -159,7 +146,7 @@ func (r *Handle) handleTask(ctx context.Context, requireData *entitys.RequireDat } // 知识库 -func (r *Handle) handleKnowle(ctx context.Context, requireData *entitys.RequireData, task *model.AiTask) (err error) { +func (r *Handle) handleKnowle(ctx context.Context, rec *entitys.Recognize, task *model.AiTask) (err error) { var ( configData entitys.ConfigDataTool @@ -171,13 +158,16 @@ func (r *Handle) handleKnowle(ctx context.Context, requireData *entitys.RequireD if err != nil { return } - + ext, err := rec_extra.GetTaskRecExt(rec) + if err != nil { + return + } // 通过session 找到知识库session var has bool - if len(requireData.Session) == 0 { + if len(ext.Session) == 0 { return errors.SessionNotFound } - requireData.SessionInfo, has, err = r.sessionImpl.FindOne(r.sessionImpl.WithSessionId(requireData.Session)) + ext.SessionInfo, has, err = r.sessionImpl.FindOne(r.sessionImpl.WithSessionId(ext.Session)) if err != nil { return } else if !has { @@ -200,15 +190,15 @@ func (r *Handle) handleKnowle(ctx context.Context, requireData *entitys.RequireD } // 知识库的session为空,请求知识库获取, 并绑定 - if requireData.SessionInfo.KnowlegeSessionID == "" { + if ext.SessionInfo.KnowlegeSessionID == "" { // 请求知识库 - if sessionIdKnowledge, err = public.GetKnowledgeBaseSession(host, requireData.Sys.KnowlegeBaseID, requireData.Sys.KnowlegeTenantKey); err != nil { + if sessionIdKnowledge, err = public.GetKnowledgeBaseSession(host, ext.Sys.KnowlegeBaseID, ext.Sys.KnowlegeTenantKey); err != nil { return } // 绑定知识库session,下次可以使用 - requireData.SessionInfo.KnowlegeSessionID = sessionIdKnowledge - if err = r.sessionImpl.Update(&requireData.SessionInfo, r.sessionImpl.WithSessionId(requireData.SessionInfo.SessionID)); err != nil { + ext.SessionInfo.KnowlegeSessionID = sessionIdKnowledge + if err = r.sessionImpl.Update(&ext.SessionInfo, r.sessionImpl.WithSessionId(ext.SessionInfo.SessionID)); err != nil { return } } @@ -216,21 +206,21 @@ func (r *Handle) handleKnowle(ctx context.Context, requireData *entitys.RequireD // 用户输入解析 var ok bool input := make(map[string]string) - if err = json.Unmarshal([]byte(requireData.Match.Parameters), &input); err != nil { + if err = json.Unmarshal([]byte(rec.Match.Parameters), &input); err != nil { return } if query, ok = input["query"]; !ok { return fmt.Errorf("query不能为空") } - requireData.KnowledgeConf = entitys.KnowledgeBaseRequest{ - Session: requireData.SessionInfo.KnowlegeSessionID, - ApiKey: requireData.Sys.KnowlegeTenantKey, + ext.KnowledgeConf = entitys.KnowledgeBaseRequest{ + Session: ext.SessionInfo.KnowlegeSessionID, + ApiKey: ext.Sys.KnowlegeTenantKey, Query: query, } // 执行工具 - err = r.toolManager.ExecuteTool(ctx, configData.Tool, requireData) + err = r.toolManager.ExecuteTool(ctx, configData.Tool, rec) if err != nil { return } @@ -238,17 +228,21 @@ func (r *Handle) handleKnowle(ctx context.Context, requireData *entitys.RequireD return } -func (r *Handle) handleApiTask(ctx context.Context, requireData *entitys.RequireData, task *model.AiTask) (err error) { +func (r *Handle) handleApiTask(ctx context.Context, rec *entitys.Recognize, task *model.AiTask) (err error) { var ( request l_request.Request requestParam map[string]interface{} ) - err = json.Unmarshal([]byte(requireData.Match.Parameters), &requestParam) + ext, err := rec_extra.GetTaskRecExt(rec) + if err != nil { + return + } + err = json.Unmarshal([]byte(rec.Match.Parameters), &requestParam) if err != nil { return } // request.Url = strings.ReplaceAll(task.Config, "${authorization}", requireData.Auth) - task.Config = strings.ReplaceAll(task.Config, "${authorization}", requireData.Auth) + task.Config = strings.ReplaceAll(task.Config, "${authorization}", ext.Auth) for k, v := range requestParam { if vStr, ok := v.(string); ok { task.Config = strings.ReplaceAll(task.Config, "${"+k+"}", vStr) @@ -275,27 +269,31 @@ func (r *Handle) handleApiTask(ctx context.Context, requireData *entitys.Require return } - entitys.ResLoading(requireData.Ch, requireData.Task.Index, "正在请求数据") + entitys.ResLoading(rec.Ch, task.Index, "正在请求数据") res, err := request.Send() if err != nil { return } - entitys.ResJson(requireData.Ch, requireData.Task.Index, res.Text) + entitys.ResJson(rec.Ch, task.Index, res.Text) return } // eino 工作流 -func (r *Handle) handleEinoWorkflow(ctx context.Context, requireData *entitys.RequireData, task *model.AiTask) (err error) { +func (r *Handle) handleEinoWorkflow(ctx context.Context, rec *entitys.Recognize, task *model.AiTask) (err error) { // token 写入ctx - ctx = util.SetTokenToContext(ctx, requireData.Auth) + ext, err := rec_extra.GetTaskRecExt(rec) + if err != nil { + return + } + ctx = util.SetTokenToContext(ctx, ext.Auth) - entitys.ResLoading(requireData.Ch, requireData.Task.Index, "正在执行工作流") + entitys.ResLoading(rec.Ch, task.Index, "正在执行工作流") // 工作流内部输出 workflowId := task.Index - _, err = r.workflowManager.Invoke(ctx, workflowId, requireData) + _, err = r.workflowManager.Invoke(ctx, workflowId, rec) if err != nil { return err } diff --git a/internal/biz/handle/dingtalk/dept.go b/internal/biz/handle/dingtalk/dept.go index 0a6bd15..df8a607 100644 --- a/internal/biz/handle/dingtalk/dept.go +++ b/internal/biz/handle/dingtalk/dept.go @@ -41,8 +41,9 @@ func (d *Dept) GetDeptInfoByDeptIds(ctx context.Context, deptIds []int, authInfo var existDept = make([]int, len(deptsInfo), 0) for _, dept := range deptsInfo { depts = append(depts, &entitys.Dept{ - DeptId: int(dept.DeptID), - Name: dept.Name, + DeptId: int(dept.DeptID), + Name: dept.Name, + ToolList: dept.ToolList, }) existDept = append(existDept, int(dept.DeptID)) } diff --git a/internal/biz/handle/dingtalk/user.go b/internal/biz/handle/dingtalk/user.go index 900a492..2e4e615 100644 --- a/internal/biz/handle/dingtalk/user.go +++ b/internal/biz/handle/dingtalk/user.go @@ -46,12 +46,14 @@ func (u *User) GetUserInfoFromBot(ctx context.Context, staffId string, botOption return } } + //待优化 authInfo, err := u.auth.GetTokenFromBotOption(ctx, botOption...) if err != nil || authInfo == nil { return } //如果没有找到,则新增 if user == nil { + DingUserInfo, _err := u.getUserInfoFromDingTalk(ctx, authInfo.AccessToken, staffId) if _err != nil { return nil, _err diff --git a/internal/biz/router.go b/internal/biz/router.go index 5127a54..28ce0b8 100644 --- a/internal/biz/router.go +++ b/internal/biz/router.go @@ -5,6 +5,8 @@ import ( "ai_scheduler/internal/data/constants" errors "ai_scheduler/internal/data/error" "ai_scheduler/internal/gateway" + + "ai_scheduler/internal/pkg/rec_extra" "context" "encoding/json" "strings" @@ -68,14 +70,15 @@ func (r *AiRouterBiz) RouteWithSocket(client *gateway.Client, req *entitys.ChatS } //意图识别 - requireData.Match, err = r.handle.Recognize(ctx, &rec, sys) + err = r.handle.Recognize(ctx, &rec, sys) if err != nil { log.Errorf("意图识别失败: %s", err.Error()) return } - + //任务处理 + rec_extra.SetTaskRecExt(requireData, &rec) //向下传递 - if err = r.handle.HandleMatch(ctx, client, requireData); err != nil { + if err = r.handle.HandleMatch(ctx, client, &rec, requireData); err != nil { log.Errorf("任务处理失败: %s", err.Error()) return } diff --git a/internal/biz/tools_regis/provider_set.go b/internal/biz/tools_regis/provider_set.go new file mode 100644 index 0000000..8294cf0 --- /dev/null +++ b/internal/biz/tools_regis/provider_set.go @@ -0,0 +1,9 @@ +package tools_regis + +import ( + "github.com/google/wire" +) + +var ProviderToolsRegis = wire.NewSet( + NewToolsRegis, +) diff --git a/internal/biz/tools_regis/tools_regis.go b/internal/biz/tools_regis/tools_regis.go new file mode 100644 index 0000000..0109849 --- /dev/null +++ b/internal/biz/tools_regis/tools_regis.go @@ -0,0 +1,30 @@ +package tools_regis + +import ( + "ai_scheduler/internal/data/constants" + "ai_scheduler/internal/data/impl" + "ai_scheduler/internal/data/model" + + "xorm.io/builder" +) + +type ToolRegis struct { + //待优化 + BootTools []model.AiBotTool +} + +func NewToolsRegis(botToolsImpl *impl.BotToolsImpl) *ToolRegis { + botTools := &ToolRegis{} + err := botTools.RegisTools(botToolsImpl) + if err != nil { + panic(err) + } + return botTools +} + +func (t *ToolRegis) RegisTools(botToolsImpl *impl.BotToolsImpl) error { + cond := builder.NewCond() + cond = cond.And(builder.Eq{"status": constants.Enable}) + err := botToolsImpl.GetRangeToMapStruct(&cond, &t.BootTools) + return err +} diff --git a/internal/data/constants/bot.go b/internal/data/constants/bot.go index 446519a..78a46f1 100644 --- a/internal/data/constants/bot.go +++ b/internal/data/constants/bot.go @@ -33,3 +33,11 @@ const ( ) const DingTalkAuthBaseKeyPrefix = "dingTalk_auth" + +// PermissionType 工具使用权限 +type PermissionType int32 + +const ( + PermissionTypeNone = 1 + PermissionTypeDept = 2 +) diff --git a/internal/data/constants/dingtalk.go b/internal/data/constants/dingtalk.go index 5c9de89..c6d55b0 100644 --- a/internal/data/constants/dingtalk.go +++ b/internal/data/constants/dingtalk.go @@ -36,3 +36,16 @@ const ( IsSeniorTrue IsSenior = 1 IsSeniorFalse IsSenior = 0 ) + +type ConversationType string + +const ( + ConversationTypeSingle = "1" // 单聊 + ConversationTypeGroup = "2" //群聊 +) + +type BotMsgType string + +const ( + BotMsgTypeText BotMsgType = "text" +) diff --git a/internal/data/impl/bot_group.go b/internal/data/impl/bot_group.go new file mode 100644 index 0000000..4382d82 --- /dev/null +++ b/internal/data/impl/bot_group.go @@ -0,0 +1,27 @@ +package impl + +import ( + "ai_scheduler/internal/data/model" + "ai_scheduler/tmpl/dataTemp" + "ai_scheduler/utils" + "database/sql" +) + +type BotGroupImpl struct { + dataTemp.DataTemp +} + +func NewBotGroupImpl(db *utils.Db) *BotGroupImpl { + return &BotGroupImpl{ + DataTemp: *dataTemp.NewDataTemp(db, new(model.AiBotGroup)), + } +} + +func (k BotGroupImpl) GetByConversationId(staffId string) (*model.AiBotGroup, error) { + var data model.AiBotGroup + err := k.Db.Model(k.Model).Where("conversation_id = ?", staffId).Find(&data).Error + if data.GroupID == 0 { + err = sql.ErrNoRows + } + return &data, err +} diff --git a/internal/data/impl/bot_tools.go b/internal/data/impl/bot_tools.go new file mode 100644 index 0000000..d119098 --- /dev/null +++ b/internal/data/impl/bot_tools.go @@ -0,0 +1,17 @@ +package impl + +import ( + "ai_scheduler/internal/data/model" + "ai_scheduler/tmpl/dataTemp" + "ai_scheduler/utils" +) + +type BotToolsImpl struct { + dataTemp.DataTemp +} + +func NewBotToolsImpl(db *utils.Db) *BotToolsImpl { + return &BotToolsImpl{ + DataTemp: *dataTemp.NewDataTemp(db, new(model.AiBotTool)), + } +} diff --git a/internal/data/impl/bot_user.go b/internal/data/impl/bot_user.go index 46d8442..862292f 100644 --- a/internal/data/impl/bot_user.go +++ b/internal/data/impl/bot_user.go @@ -17,11 +17,11 @@ func NewBotUserImpl(db *utils.Db) *BotUserImpl { } } -func (k BotUserImpl) GetByStaffId(staffId string) (data *model.AiBotUser, err error) { - - err = k.Db.Model(k.Model).Where("staff_id = ?", staffId).Find(data).Error - if data == nil { +func (k BotUserImpl) GetByStaffId(staffId string) (*model.AiBotUser, error) { + var data model.AiBotUser + err := k.Db.Model(k.Model).Where("staff_id = ?", staffId).Find(&data).Error + if data.UserID == 0 { err = sql.ErrNoRows } - return + return &data, err } diff --git a/internal/data/impl/provider_set.go b/internal/data/impl/provider_set.go index 1563258..5624b3e 100644 --- a/internal/data/impl/provider_set.go +++ b/internal/data/impl/provider_set.go @@ -13,4 +13,6 @@ var ProviderImpl = wire.NewSet( NewBotDeptImpl, NewBotUserImpl, NewBotChatHisImpl, + NewBotToolsImpl, + NewBotGroupImpl, ) diff --git a/internal/data/model/ai_bot_dept.gen.go b/internal/data/model/ai_bot_dept.gen.go index db8ba60..ceddcce 100644 --- a/internal/data/model/ai_bot_dept.gen.go +++ b/internal/data/model/ai_bot_dept.gen.go @@ -12,12 +12,13 @@ const TableNameAiBotDept = "ai_bot_dept" // AiBotDept mapped from table type AiBotDept struct { - DeptID int32 `gorm:"column:dept_id;primaryKey" json:"dept_id"` - DingtalkDeptID int32 `gorm:"column:dingtalk_dept_id;not null;comment:标记部门的唯一id,钉钉:钉钉侧提供的dept_id" json:"dingtalk_dept_id"` // 标记部门的唯一id,钉钉:钉钉侧提供的dept_id - Name string `gorm:"column:name;not null;comment:用户名称" json:"name"` // 用户名称 - Status int32 `gorm:"column:status;not null" json:"status"` - DeleteAt time.Time `gorm:"column:delete_at" json:"delete_at"` - CreateAt time.Time `gorm:"column:create_at;default:CURRENT_TIMESTAMP" json:"create_at"` + DeptID int32 `gorm:"column:dept_id;primaryKey;autoIncrement:true" json:"dept_id"` + DingtalkDeptID int32 `gorm:"column:dingtalk_dept_id;not null;comment:标记部门的唯一id,钉钉:钉钉侧提供的dept_id" json:"dingtalk_dept_id"` // 标记部门的唯一id,钉钉:钉钉侧提供的dept_id + Name string `gorm:"column:name;not null;comment:用户名称" json:"name"` // 用户名称 + ToolList string `gorm:"column:tool_list;not null;comment:该部门支持的权限" json:"tool_list"` // 该部门支持的权限 + Status int32 `gorm:"column:status;not null;default:1" json:"status"` + DeleteAt *time.Time `gorm:"column:delete_at" json:"delete_at"` + CreateAt time.Time `gorm:"column:create_at;default:CURRENT_TIMESTAMP" json:"create_at"` } // TableName AiBotDept's table name diff --git a/internal/data/model/ai_bot_group.gen.go b/internal/data/model/ai_bot_group.gen.go new file mode 100644 index 0000000..d0ff93a --- /dev/null +++ b/internal/data/model/ai_bot_group.gen.go @@ -0,0 +1,27 @@ +// Code generated by gorm.io/gen. DO NOT EDIT. +// Code generated by gorm.io/gen. DO NOT EDIT. +// Code generated by gorm.io/gen. DO NOT EDIT. + +package model + +import ( + "time" +) + +const TableNameAiBotGroup = "ai_bot_group" + +// AiBotGroup mapped from table +type AiBotGroup struct { + GroupID int32 `gorm:"column:group_id;primaryKey;autoIncrement:true" json:"group_id"` + ConversationID string `gorm:"column:conversation_id;not null;comment:会话ID" json:"conversation_id"` // 会话ID + Title string `gorm:"column:title;not null;comment:群名称" json:"title"` // 群名称 + ToolList string `gorm:"column:tool_list;not null;comment:开通工具列表" json:"tool_list"` // 开通工具列表 + Status int32 `gorm:"column:status;not null;default:1" json:"status"` + DeleteAt *time.Time `gorm:"column:delete_at" json:"delete_at"` + CreateAt time.Time `gorm:"column:create_at;default:CURRENT_TIMESTAMP" json:"create_at"` +} + +// TableName AiBotGroup's table name +func (*AiBotGroup) TableName() string { + return TableNameAiBotGroup +} diff --git a/internal/data/model/ai_bot_tools.gen.go b/internal/data/model/ai_bot_tools.gen.go new file mode 100644 index 0000000..f57889b --- /dev/null +++ b/internal/data/model/ai_bot_tools.gen.go @@ -0,0 +1,32 @@ +// Code generated by gorm.io/gen. DO NOT EDIT. +// Code generated by gorm.io/gen. DO NOT EDIT. +// Code generated by gorm.io/gen. DO NOT EDIT. + +package model + +import ( + "time" +) + +const TableNameAiBotTool = "ai_bot_tools" + +// AiBotTool mapped from table +type AiBotTool struct { + ToolID int32 `gorm:"column:tool_id;primaryKey;autoIncrement:true" json:"tool_id"` + PermissionType int32 `gorm:"column:permission_type;not null;comment:类型,1为公共工具,不需要进行权限管理,反之则为2" json:"permission_type"` // 类型,1为公共工具,不需要进行权限管理,反之则为2 + Config string `gorm:"column:config;not null;comment:类型下所需路由以及参数" json:"config"` // 类型下所需路由以及参数 + Type int32 `gorm:"column:type;not null;default:3" json:"type"` + Name string `gorm:"column:name;not null;default:1;comment:工具名称" json:"name"` // 工具名称 + Index string `gorm:"column:index;not null;comment:索引" json:"index"` // 索引 + Desc string `gorm:"column:desc;not null;comment:工具描述" json:"desc"` // 工具描述 + TempPrompt string `gorm:"column:temp_prompt;not null;comment:提示词模板" json:"temp_prompt"` // 提示词模板 + CreateAt time.Time `gorm:"column:create_at;default:CURRENT_TIMESTAMP" json:"create_at"` + UpdatedAt time.Time `gorm:"column:updated_at;default:CURRENT_TIMESTAMP" json:"updated_at"` + Status int32 `gorm:"column:status;not null" json:"status"` + DeleteAt time.Time `gorm:"column:delete_at" json:"delete_at"` +} + +// TableName AiBotTool's table name +func (*AiBotTool) TableName() string { + return TableNameAiBotTool +} diff --git a/internal/domain/workflow/runtime/registry.go b/internal/domain/workflow/runtime/registry.go index 8ae7a82..c854053 100644 --- a/internal/domain/workflow/runtime/registry.go +++ b/internal/domain/workflow/runtime/registry.go @@ -12,7 +12,7 @@ import ( type Workflow interface { ID() string Schema() map[string]any - Invoke(ctx context.Context, requireData *entitys.RequireData) (map[string]any, error) + Invoke(ctx context.Context, requireData *entitys.Recognize) (map[string]any, error) } type Deps struct { @@ -63,7 +63,7 @@ func Default() *Registry { return r } -func (r *Registry) Invoke(ctx context.Context, id string, requireData *entitys.RequireData) (map[string]any, error) { +func (r *Registry) Invoke(ctx context.Context, id string, rec *entitys.Recognize) (map[string]any, error) { regMu.RLock() f, ok := factories[id] regMu.RUnlock() @@ -89,5 +89,5 @@ func (r *Registry) Invoke(ctx context.Context, id string, requireData *entitys.R r.mu.Unlock() } - return w.Invoke(ctx, requireData) + return w.Invoke(ctx, rec) } diff --git a/internal/domain/workflow/zltx/order_after_reseller_batch.go b/internal/domain/workflow/zltx/order_after_reseller_batch.go index 4d151d2..b4df00d 100644 --- a/internal/domain/workflow/zltx/order_after_reseller_batch.go +++ b/internal/domain/workflow/zltx/order_after_reseller_batch.go @@ -31,7 +31,7 @@ type OrderAfterSaleResellerBatchWorkflowInput struct { Ch chan entitys.Response // 响应通道 UserInput string // 用户输入文本 FileContent string // 文件解析结果 - UserHistory []model.AiChatHi // 用户对话历史 + UserHistory entitys.ChatHis // 用户对话历史 ParameterResult string // 参数解析结果 Data *OrderAfterSaleResellerBatchNodeData // 节点所需参数 } @@ -88,7 +88,7 @@ func (o *orderAfterSaleResellerBatch) Schema() map[string]any { } // Invoke 调用原有编排工作流并规范化输出 -func (o *orderAfterSaleResellerBatch) Invoke(ctx context.Context, requireData *entitys.RequireData) (map[string]any, error) { +func (o *orderAfterSaleResellerBatch) Invoke(ctx context.Context, rec *entitys.Recognize) (map[string]any, error) { // 构建工作流 chain, err := o.buildWorkflow(ctx) if err != nil { @@ -96,11 +96,11 @@ func (o *orderAfterSaleResellerBatch) Invoke(ctx context.Context, requireData *e } o.data = &OrderAfterSaleResellerBatchWorkflowInput{ - Ch: requireData.Ch, - UserInput: requireData.Req.Text, + Ch: rec.Ch, + UserInput: rec.UserContent.Text, FileContent: "", - UserHistory: requireData.Histories, - ParameterResult: requireData.Match.Parameters, + UserHistory: rec.ChatHis, + ParameterResult: rec.Match.Parameters, } // 工作流过程输出,不关注最终输出 _, err = chain.Invoke(ctx, o.data) diff --git a/internal/entitys/bot.go b/internal/entitys/bot.go index af92ba8..7085a37 100644 --- a/internal/entitys/bot.go +++ b/internal/entitys/bot.go @@ -3,21 +3,16 @@ package entitys import ( "ai_scheduler/internal/data/model" - "github.com/ollama/ollama/api" "github.com/open-dingtalk/dingtalk-stream-sdk-go/chatbot" ) type RequireDataDingTalkBot struct { - Histories []model.AiChatHi - UserInfo *DingTalkUserInfo - Tasks []model.AiTask - Match *Match - Req *chatbot.BotCallbackDataModel - Auth string - Ch chan Response - KnowledgeConf KnowledgeBaseRequest - ImgByte []api.ImageData - ImgUrls []string + Histories []model.AiChatHi + UserInfo *DingTalkUserInfo + Tools []model.AiBotTool + Match *Match + Req *chatbot.BotCallbackDataModel + Ch chan Response } type DingTalkBot struct { diff --git a/internal/entitys/dingtalk.go b/internal/entitys/dingtalk.go index 3595bf3..93cc825 100644 --- a/internal/entitys/dingtalk.go +++ b/internal/entitys/dingtalk.go @@ -17,6 +17,7 @@ type DingTalkUserInfo struct { } type Dept struct { - Name string `json:"name"` - DeptId int `json:"dept_id"` + Name string `json:"name"` + DeptId int `json:"dept_id"` + ToolList string `json:"tool_list"` } diff --git a/internal/entitys/recognize.go b/internal/entitys/recognize.go index 3d68be7..831bef7 100644 --- a/internal/entitys/recognize.go +++ b/internal/entitys/recognize.go @@ -1,6 +1,9 @@ package entitys -import "ai_scheduler/internal/data/constants" +import ( + "ai_scheduler/internal/data/constants" + "ai_scheduler/internal/data/model" +) type Recognize struct { SystemPrompt string // 系统提示内容 @@ -8,11 +11,23 @@ type Recognize struct { ChatHis ChatHis // 会话历史记录 Tasks []RegistrationTask Ch chan Response + Match *Match + Ext []byte +} + +type TaskExt struct { + Auth string `json:"auth"` + Session string `json:"session"` + Key string `json:"key"` + SessionInfo model.AiSession + Sys model.AiSy + KnowledgeConf KnowledgeBaseRequest } type RegistrationTask struct { Name string Desc string + Index string TaskConfigDetail TaskConfigDetail } @@ -26,6 +41,7 @@ type RecognizeUserContent struct { type FileData []byte type RecognizeFile struct { + FileRec string //文件识别内容 FileData FileData // 文件数据(二进制格式) FileType constants.Caller // 文件类型(文件类型,能填最好填,可以跳过一层判断) FileUrl string // 文件下载链接 diff --git a/internal/entitys/response.go b/internal/entitys/response.go index 44d4d18..39e81ad 100644 --- a/internal/entitys/response.go +++ b/internal/entitys/response.go @@ -25,6 +25,9 @@ const ( ) func ResLog(ch chan Response, index string, content string) { + if ch == nil { + return + } ch <- Response{ Index: index, Content: content, @@ -33,6 +36,9 @@ func ResLog(ch chan Response, index string, content string) { } func ResStream(ch chan Response, index string, content string) { + if ch == nil { + return + } ch <- Response{ Index: index, Content: content, @@ -41,6 +47,9 @@ func ResStream(ch chan Response, index string, content string) { } func ResJson(ch chan Response, index string, content string) { + if ch == nil { + return + } ch <- Response{ Index: index, Content: content, diff --git a/internal/entitys/types.go b/internal/entitys/types.go index b8f8caa..601f50e 100644 --- a/internal/entitys/types.go +++ b/internal/entitys/types.go @@ -78,7 +78,7 @@ type Tool interface { Name() string Description() string Definition() ToolDefinition - Execute(ctx context.Context, requireData *RequireData) error + Execute(ctx context.Context, requireData *Recognize) error } type ConfigDataHttp struct { diff --git a/internal/pkg/rec_extra/ext.go b/internal/pkg/rec_extra/ext.go new file mode 100644 index 0000000..53fbbe7 --- /dev/null +++ b/internal/pkg/rec_extra/ext.go @@ -0,0 +1,22 @@ +package rec_extra + +import ( + "ai_scheduler/internal/entitys" + "ai_scheduler/internal/pkg" + "encoding/json" +) + +func SetTaskRecExt(requireData *entitys.RequireData, rec *entitys.Recognize) { + TaskExt := entitys.TaskExt{ + Auth: requireData.Auth, + Session: requireData.Session, + Key: requireData.Key, + Sys: requireData.Sys, + } + rec.Ext = pkg.JsonByteIgonErr(TaskExt) +} + +func GetTaskRecExt(rec *entitys.Recognize) (ext *entitys.TaskExt, err error) { + err = json.Unmarshal(rec.Ext, ext) + return ext, err +} diff --git a/internal/services/dtalk_bot.go b/internal/services/dtalk_bot.go index e4e9e8d..0bfe129 100644 --- a/internal/services/dtalk_bot.go +++ b/internal/services/dtalk_bot.go @@ -2,18 +2,19 @@ package services import ( "ai_scheduler/internal/biz" + "log" + "time" + "ai_scheduler/internal/config" "ai_scheduler/internal/entitys" "context" - "fmt" "github.com/open-dingtalk/dingtalk-stream-sdk-go/chatbot" ) type DingBotService struct { - config *config.Config - + config *config.Config dingTalkBotBiz *biz.DingTalkBotBiz } @@ -30,38 +31,42 @@ func (d *DingBotService) OnChatBotMessageReceived(ctx context.Context, data *cha if err != nil { return } + // 使用 ctx.Done() 通知 Do 方法提前终止 + subCtx, cancel := context.WithCancel(ctx) + defer cancel() + // 异步执行 Do 方法 + done := make(chan error, 1) go func() { - //defer close(requireData.Ch) - //if match, _err := d.dingTalkBotBiz.Recognize(ctx, data); _err != nil { - // requireData.Ch <- entitys.Response{ - // Type: entitys.ResponseEnd, - // Content: fmt.Sprintf("处理消息时出错: %s", _err.Error()), - // } - //} - ////向下传递 - //if err = d.dingTalkBotBiz.HandleMatch(ctx, nil, requireData); err != nil { - // requireData.Ch <- entitys.Response{ - // Type: entitys.ResponseEnd, - // Content: fmt.Sprintf("匹配失败: %v", err), - // } - //} + done <- d.dingTalkBotBiz.Do(subCtx, requireData) }() + + var lastErr error for { select { case <-ctx.Done(): - return nil, ctx.Err() + lastErr = ctx.Err() + goto cleanup case resp, ok := <-requireData.Ch: if !ok { - return []byte("success"), nil // 通道关闭,处理完成 + return []byte("success"), nil } - if resp.Type == entitys.ResponseLog { - return + continue } if err := d.dingTalkBotBiz.HandleRes(ctx, data, resp); err != nil { - return nil, fmt.Errorf("回复失败: %w", err) + log.Printf("HandleRes 失败: %v", err) } } } - return +cleanup: + select { + case err := <-done: + if err != nil { + log.Printf("Do 方法执行失败: %v", err) + } + case <-time.After(1 * time.Second): + log.Println("警告:等待 Do 方法超时,可能发生 goroutine 泄漏") + } + + return nil, lastErr } diff --git a/internal/tools/manager.go b/internal/tools/manager.go index cd1940d..eedb4ee 100644 --- a/internal/tools/manager.go +++ b/internal/tools/manager.go @@ -24,24 +24,6 @@ func NewManager(config *config.Config, llm *utils_ollama.Client) *Manager { llm: llm, } - // 注册天气工具 - //if config.Tools.Weather.Enabled { - // weatherTool := NewWeatherTool() - // m.tools[weatherTool.Name()] = weatherTool - //} - // - //// 注册计算器工具 - //if config.Tools.Calculator.Enabled { - // calcTool := NewCalculatorTool() - // m.tools[calcTool.Name()] = calcTool - //} - - // 注册知识库工具 - // if config.Knowledge.Enabled { - // knowledgeTool := NewKnowledgeTool() - // m.tools[knowledgeTool.Name()] = knowledgeTool - // } - // 注册直连天下订单详情工具 if config.Tools.ZltxOrderDetail.Enabled { zltxOrderDetailTool := zltxtool.NewZltxOrderDetailTool(config.Tools.ZltxOrderDetail, m.llm) @@ -115,63 +97,12 @@ func (m *Manager) GetTool(name string) (entitys.Tool, bool) { return tool, exists } -// GetAllTools 获取所有工具 -// func (m *Manager) GetAllTools() []entitys.Tool { -// tools := make([]entitys.Tool, 0, len(m.tools)) -// for _, tool := range m.tools { -// tools = append(tools, tool) -// } -// return tools -// } - -// // GetToolDefinitions 获取所有工具定义 -// func (m *Manager) GetToolDefinitions(caller constants.Caller) []entitys.ToolDefinition { -// definitions := make([]entitys.ToolDefinition, 0, len(m.tools)) -// for _, tool := range m.tools { -// definitions = append(definitions, tool.Definition()) -// } - -// return definitions -// } - // ExecuteTool 执行工具 -func (m *Manager) ExecuteTool(ctx context.Context, name string, requireData *entitys.RequireData) error { +func (m *Manager) ExecuteTool(ctx context.Context, name string, rec *entitys.Recognize) error { tool, exists := m.GetTool(name) if !exists { return fmt.Errorf("tool not found: %s", name) } - return tool.Execute(ctx, requireData) + return tool.Execute(ctx, rec) } - -// ExecuteToolCalls 执行多个工具调用 -//func (m *Manager) ExecuteToolCalls(c *websocket.Conn, toolCalls []entitys.ToolCall) ([]entitys.ToolCall, error) { -// results := make([]entitys.ToolCall, len(toolCalls)) -// -// for i, toolCall := range toolCalls { -// results[i] = toolCall -// -// // 执行工具 -// err := m.ExecuteTool(c, toolCall.Function.Name, toolCall.Function.Arguments) -// if err != nil { -// // 将错误信息作为结果返回 -// errorResult := map[string]interface{}{ -// "error": err.Error(), -// } -// resultBytes, _ := json.Marshal(errorResult) -// results[i].Result = resultBytes -// } else { -// // 将成功结果序列化 -// resultBytes, err := json.Marshal(result) -// if err != nil { -// errorResult := map[string]interface{}{ -// "error": fmt.Sprintf("failed to serialize result: %v", err), -// } -// resultBytes, _ = json.Marshal(errorResult) -// } -// results[i].Result = resultBytes -// } -// } -// -// return results, nil -//} diff --git a/internal/tools/public/coze_company.go b/internal/tools/public/coze_company.go index 3f10638..e061965 100644 --- a/internal/tools/public/coze_company.go +++ b/internal/tools/public/coze_company.go @@ -7,10 +7,11 @@ import ( "context" "encoding/json" "fmt" - "github.com/ollama/ollama/api" "net/http" "time" + "github.com/ollama/ollama/api" + "github.com/coze-dev/coze-go" ) @@ -70,7 +71,7 @@ func (c *CozeCompany) Definition() entitys.ToolDefinition { } // Execute 执行查询 -func (c *CozeCompany) Execute(ctx context.Context, requireData *entitys.RequireData) error { +func (c *CozeCompany) Execute(ctx context.Context, requireData *entitys.Recognize) error { var req map[string]interface{} if err := json.Unmarshal([]byte(requireData.Match.Parameters), &req); err != nil { return fmt.Errorf("invalid express request: %w", err) @@ -191,7 +192,7 @@ Top3核心风险:1. 根据司法信息,近1年作为被告的合同纠纷占 }, { Role: "user", - Content: requireData.Req.Text, + Content: requireData.UserContent.Text, }, }, c.Name(), "") diff --git a/internal/tools/public/coze_express.go b/internal/tools/public/coze_express.go index de316b5..58e6172 100644 --- a/internal/tools/public/coze_express.go +++ b/internal/tools/public/coze_express.go @@ -8,6 +8,7 @@ import ( "context" "encoding/json" "fmt" + "github.com/ollama/ollama/api" "github.com/coze-dev/coze-go" @@ -67,7 +68,7 @@ func (c *CozeExpress) Definition() entitys.ToolDefinition { } // Execute 执行查询 -func (c *CozeExpress) Execute(ctx context.Context, requireData *entitys.RequireData) error { +func (c *CozeExpress) Execute(ctx context.Context, requireData *entitys.Recognize) error { var req map[string]interface{} if err := json.Unmarshal([]byte(requireData.Match.Parameters), &req); err != nil { return fmt.Errorf("invalid express request: %w", err) @@ -89,7 +90,7 @@ func (c *CozeExpress) Execute(ctx context.Context, requireData *entitys.RequireD }, { Role: "assistant", - Content: fmt.Sprintf("聊天记录:%s", pkg.JsonStringIgonErr(requireData.Histories)), + Content: fmt.Sprintf("聊天记录:%s", pkg.JsonStringIgonErr(requireData.ChatHis)), }, { Role: "assistant", @@ -97,7 +98,7 @@ func (c *CozeExpress) Execute(ctx context.Context, requireData *entitys.RequireD }, { Role: "user", - Content: requireData.Req.Text, + Content: requireData.UserContent.Text, }, }, c.Name(), "") if err != nil { diff --git a/internal/tools/public/konwledge_base.go b/internal/tools/public/konwledge_base.go index 505a349..48ddc1e 100644 --- a/internal/tools/public/konwledge_base.go +++ b/internal/tools/public/konwledge_base.go @@ -4,6 +4,7 @@ import ( "ai_scheduler/internal/config" "ai_scheduler/internal/entitys" "ai_scheduler/internal/pkg/l_request" + "ai_scheduler/internal/pkg/rec_extra" "bufio" "context" "encoding/json" @@ -58,7 +59,7 @@ func (k *KnowledgeBaseTool) Definition() entitys.ToolDefinition { } // Execute 执行知识库查询 -func (k *KnowledgeBaseTool) Execute(ctx context.Context, requireData *entitys.RequireData) error { +func (k *KnowledgeBaseTool) Execute(ctx context.Context, requireData *entitys.Recognize) error { entitys.ResLoading(requireData.Ch, k.Name(), "正在为您搜索相关信息") return k.chat(requireData) @@ -91,20 +92,23 @@ func (this *KnowledgeBaseTool) msgContentParse(input string, channel chan entity } // 请求知识库聊天 -func (this *KnowledgeBaseTool) chat(requireData *entitys.RequireData) (err error) { - +func (this *KnowledgeBaseTool) chat(rec *entitys.Recognize) (err error) { + ext, err := rec_extra.GetTaskRecExt(rec) + if err != nil { + return + } req := l_request.Request{ Method: "post", - Url: this.config.BaseURL + "/api/v1/knowledge-chat/" + requireData.KnowledgeConf.Session, + Url: this.config.BaseURL + "/api/v1/knowledge-chat/" + ext.KnowledgeConf.Session, Params: nil, Headers: map[string]string{ "Content-Type": "application/json", - "X-API-Key": requireData.KnowledgeConf.ApiKey, + "X-API-Key": ext.KnowledgeConf.ApiKey, }, Cookies: nil, Data: nil, Json: map[string]interface{}{ - "query": requireData.KnowledgeConf.Query, + "query": ext.KnowledgeConf.Query, }, Files: nil, Raw: "", @@ -118,7 +122,7 @@ func (this *KnowledgeBaseTool) chat(requireData *entitys.RequireData) (err error } defer rsp.Body.Close() - err = this.connectAndReadSSE(rsp, requireData.Ch) + err = this.connectAndReadSSE(rsp, rec.Ch) if err != nil { return } diff --git a/internal/tools/public/normal_chat.go b/internal/tools/public/normal_chat.go index a4c96e0..cde667c 100644 --- a/internal/tools/public/normal_chat.go +++ b/internal/tools/public/normal_chat.go @@ -43,9 +43,9 @@ func (w *NormalChatTool) Definition() entitys.ToolDefinition { } // Execute 执行直连天下订单详情查询 -func (w *NormalChatTool) Execute(ctx context.Context, requireData *entitys.RequireData) error { +func (w *NormalChatTool) Execute(ctx context.Context, rec *entitys.Recognize) error { var req NormalChat - if err := json.Unmarshal([]byte(requireData.Match.Parameters), &req); err != nil { + if err := json.Unmarshal([]byte(rec.Match.Parameters), &req); err != nil { return fmt.Errorf("invalid zltxOrderDetail request: %w", err) } if req.ChatContent == "" { @@ -53,25 +53,25 @@ func (w *NormalChatTool) Execute(ctx context.Context, requireData *entitys.Requi } // 这里可以集成真实的直连天下订单详情API - return w.chat(requireData, &req) + return w.chat(rec, &req) } // getMockZltxOrderDetail 获取模拟直连天下订单详情数据 -func (w *NormalChatTool) chat(requireData *entitys.RequireData, chat *NormalChat) (err error) { +func (w *NormalChatTool) chat(rec *entitys.Recognize, chat *NormalChat) (err error) { //requireData.Ch <- entitys.Response{ // Index: w.Name(), // Content: "", // Type: entitys.ResponseStream, //} - err = w.llm.ChatStream(context.TODO(), requireData.Ch, []api.Message{ + err = w.llm.ChatStream(context.TODO(), rec.Ch, []api.Message{ { Role: "system", Content: "你是一个聊天助手", }, { Role: "assistant", - Content: fmt.Sprintf("聊天记录:%s", pkg.JsonStringIgonErr(requireData.Histories)), + Content: fmt.Sprintf("聊天记录:%s", pkg.JsonStringIgonErr(rec.ChatHis)), }, { Role: "user", diff --git a/internal/tools/public/weather.go b/internal/tools/public/weather.go index 0292fa4..ca36907 100644 --- a/internal/tools/public/weather.go +++ b/internal/tools/public/weather.go @@ -107,9 +107,9 @@ type LiveWeather struct { } // Execute 执行天气查询 -func (w *WeatherTool) Execute(ctx context.Context, requireData *entitys.RequireData) error { +func (w *WeatherTool) Execute(ctx context.Context, rec *entitys.Recognize) error { var req WeatherRequest - if err := json.Unmarshal([]byte(requireData.Match.Parameters), &req); err != nil { + if err := json.Unmarshal([]byte(rec.Match.Parameters), &req); err != nil { return fmt.Errorf("invalid weather request: %w", err) } @@ -134,7 +134,7 @@ func (w *WeatherTool) Execute(ctx context.Context, requireData *entitys.RequireD // 根据 extensions 参数返回不同的天气信息 if req.Extensions == "base" { - entitys.ResText(requireData.Ch, "", fmt.Sprintf("%s实时天气:%s,温度:%.1f℃,湿度:%d%%,风速:%.1fkm/h,风向:%s", + entitys.ResText(rec.Ch, "", fmt.Sprintf("%s实时天气:%s,温度:%.1f℃,湿度:%d%%,风速:%.1fkm/h,风向:%s", req.City, responseMsg.LiveWeather.Condition, responseMsg.LiveWeather.Temperature, @@ -149,7 +149,7 @@ func (w *WeatherTool) Execute(ctx context.Context, requireData *entitys.RequireD forecast.Date, forecast.DayTemp, forecast.NightTemp, forecast.DayWind, forecast.NightWind) } - entitys.ResText(requireData.Ch, "", rspStr) + entitys.ResText(rec.Ch, "", rspStr) } return nil } diff --git a/internal/tools/zltx/order_after_reseller.go b/internal/tools/zltx/order_after_reseller.go index 71e0629..fcb71e0 100644 --- a/internal/tools/zltx/order_after_reseller.go +++ b/internal/tools/zltx/order_after_reseller.go @@ -4,6 +4,7 @@ import ( "ai_scheduler/internal/config" "ai_scheduler/internal/entitys" "ai_scheduler/internal/pkg/l_request" + "ai_scheduler/internal/pkg/rec_extra" "ai_scheduler/internal/pkg/util" "context" "encoding/json" @@ -104,9 +105,9 @@ type OrderAfterSaleResellerApiExtItem struct { SerialCreateTime int `json:"createTime"` // 流水创建时间 } -func (t *OrderAfterSaleResellerTool) Execute(ctx context.Context, requireData *entitys.RequireData) error { +func (t *OrderAfterSaleResellerTool) Execute(ctx context.Context, rec *entitys.Recognize) error { var req OrderAfterSaleResellerRequest - if err := json.Unmarshal([]byte(requireData.Match.Parameters), &req); err != nil { + if err := json.Unmarshal([]byte(rec.Match.Parameters), &req); err != nil { return fmt.Errorf("解析参数失败,请重试或联系管理员") } if len(req.OrderNumber) == 0 && len(req.Account) == 0 { @@ -116,18 +117,22 @@ func (t *OrderAfterSaleResellerTool) Execute(ctx context.Context, requireData *e if req.SerialCreateTime != "" { _, err := time.ParseInLocation(time.DateTime, req.SerialCreateTime, time.Local) if err != nil { - entitys.ResLog(requireData.Ch, t.Name(), "时间格式不匹配,已置为空") + entitys.ResLog(rec.Ch, t.Name(), "时间格式不匹配,已置为空") req.SerialCreateTime = "" } } - entitys.ResLog(requireData.Ch, t.Name(), "正在拉取售后订单信息") + entitys.ResLog(rec.Ch, t.Name(), "正在拉取售后订单信息") - return t.checkOrderAfterSaleReseller(req, requireData) + return t.checkOrderAfterSaleReseller(req, rec) } -func (t *OrderAfterSaleResellerTool) checkOrderAfterSaleReseller(toolReq OrderAfterSaleResellerRequest, requireData *entitys.RequireData) error { +func (t *OrderAfterSaleResellerTool) checkOrderAfterSaleReseller(toolReq OrderAfterSaleResellerRequest, rec *entitys.Recognize) error { var serialStartTime, serialEndTime int64 + ext, err := rec_extra.GetTaskRecExt(rec) + if err != nil { + return err + } if toolReq.SerialCreateTime != "" { // 流水创建时间上下浮动10min serialCreateTime, err := time.ParseInLocation(time.DateTime, toolReq.SerialCreateTime, time.Local) @@ -144,17 +149,16 @@ func (t *OrderAfterSaleResellerTool) checkOrderAfterSaleReseller(toolReq OrderAf // 账号数量超过10直接截断 if len(toolReq.Account) > 10 { - entitys.ResLog(requireData.Ch, t.Name(), "账号数量超过10已被截断") + entitys.ResLog(rec.Ch, t.Name(), "账号数量超过10已被截断") toolReq.Account = toolReq.Account[:10] } headers := map[string]string{ - "Authorization": fmt.Sprintf("Bearer %s", requireData.Auth), + "Authorization": fmt.Sprintf("Bearer %s", ext.Auth), } // 最终输出 var orderList []*OrderAfterSaleResellerData - var err error // 多订单号 if len(toolReq.OrderNumber) > 0 { @@ -217,8 +221,8 @@ func (t *OrderAfterSaleResellerTool) checkOrderAfterSaleReseller(toolReq OrderAf return err } - entitys.ResLog(requireData.Ch, t.Name(), "售后订单信息拉取完成") - entitys.ResJson(requireData.Ch, t.Name(), string(jsonByte)) + entitys.ResLog(rec.Ch, t.Name(), "售后订单信息拉取完成") + entitys.ResJson(rec.Ch, t.Name(), string(jsonByte)) return nil } diff --git a/internal/tools/zltx/order_after_reseller_batch.go b/internal/tools/zltx/order_after_reseller_batch.go index e12e602..664d3d5 100644 --- a/internal/tools/zltx/order_after_reseller_batch.go +++ b/internal/tools/zltx/order_after_reseller_batch.go @@ -4,6 +4,7 @@ import ( "ai_scheduler/internal/config" "ai_scheduler/internal/entitys" "ai_scheduler/internal/pkg/l_request" + "ai_scheduler/internal/pkg/rec_extra" "ai_scheduler/internal/pkg/util" "context" "encoding/json" @@ -100,25 +101,29 @@ type OrderAfterSaleResellerBatchApiExtItem struct { SerialCreateTime int `json:"createTime"` // 流水创建时间 } -func (t *OrderAfterSaleResellerBatchTool) Execute(ctx context.Context, requireData *entitys.RequireData) error { +func (t *OrderAfterSaleResellerBatchTool) Execute(ctx context.Context, rec *entitys.Recognize) error { var req OrderAfterSaleResellerBatchRequest - if err := json.Unmarshal([]byte(requireData.Match.Parameters), &req); err != nil { + if err := json.Unmarshal([]byte(rec.Match.Parameters), &req); err != nil { return fmt.Errorf("解析参数失败,请重试或联系管理员") } if len(req.OrderNumber) == 0 { return fmt.Errorf("批充订单号不能为空") } - entitys.ResLog(requireData.Ch, t.Name(), "正在拉取售后订单信息") + entitys.ResLog(rec.Ch, t.Name(), "正在拉取售后订单信息") - return t.checkOrderAfterSaleResellerBatch(req, requireData) + return t.checkOrderAfterSaleResellerBatch(req, rec) } -func (t *OrderAfterSaleResellerBatchTool) checkOrderAfterSaleResellerBatch(toolReq OrderAfterSaleResellerBatchRequest, requireData *entitys.RequireData) error { +func (t *OrderAfterSaleResellerBatchTool) checkOrderAfterSaleResellerBatch(toolReq OrderAfterSaleResellerBatchRequest, rec *entitys.Recognize) error { + ext, err := rec_extra.GetTaskRecExt(rec) + if err != nil { + return err + } req := l_request.Request{ Url: t.config.BaseURL, Headers: map[string]string{ - "Authorization": fmt.Sprintf("Bearer %s", requireData.Auth), + "Authorization": fmt.Sprintf("Bearer %s", ext.Auth), }, Method: "POST", Json: map[string]any{ @@ -200,7 +205,7 @@ func (t *OrderAfterSaleResellerBatchTool) checkOrderAfterSaleResellerBatch(toolR return err } - entitys.ResLog(requireData.Ch, t.Name(), "售后订单信息拉取完成") - entitys.ResJson(requireData.Ch, t.Name(), string(jsonByte)) + entitys.ResLog(rec.Ch, t.Name(), "售后订单信息拉取完成") + entitys.ResJson(rec.Ch, t.Name(), string(jsonByte)) return nil } diff --git a/internal/tools/zltx/order_after_supplier.go b/internal/tools/zltx/order_after_supplier.go index 4baf7f9..9e9ad82 100644 --- a/internal/tools/zltx/order_after_supplier.go +++ b/internal/tools/zltx/order_after_supplier.go @@ -4,6 +4,7 @@ import ( "ai_scheduler/internal/config" "ai_scheduler/internal/entitys" "ai_scheduler/internal/pkg/l_request" + "ai_scheduler/internal/pkg/rec_extra" "ai_scheduler/internal/pkg/util" "context" "encoding/json" @@ -98,9 +99,9 @@ type OrderAfterSaleSupplierApiExtItem struct { SerialCreateTime int `json:"createTime"` // 流水创建时间 } -func (t *OrderAfterSaleSupplierTool) Execute(ctx context.Context, requireData *entitys.RequireData) error { +func (t *OrderAfterSaleSupplierTool) Execute(ctx context.Context, rec *entitys.Recognize) error { var req OrderAfterSaleSupplierRequest - if err := json.Unmarshal([]byte(requireData.Match.Parameters), &req); err != nil { + if err := json.Unmarshal([]byte(rec.Match.Parameters), &req); err != nil { return fmt.Errorf("解析参数失败,请重试或联系管理员") } if len(req.SerialNumber) == 0 && len(req.Account) == 0 { @@ -110,18 +111,24 @@ func (t *OrderAfterSaleSupplierTool) Execute(ctx context.Context, requireData *e if req.SerialCreateTime != "" { _, err := time.ParseInLocation(time.DateTime, req.SerialCreateTime, time.Local) if err != nil { - entitys.ResLog(requireData.Ch, t.Name(), "时间格式不匹配,已置为空") + entitys.ResLog(rec.Ch, t.Name(), "时间格式不匹配,已置为空") req.SerialCreateTime = "" } } - entitys.ResLog(requireData.Ch, t.Name(), "正在拉取售后订单信息") + entitys.ResLog(rec.Ch, t.Name(), "正在拉取售后订单信息") - return t.checkOrderAfterSaleSupplier(req, requireData) + return t.checkOrderAfterSaleSupplier(req, rec) } -func (t *OrderAfterSaleSupplierTool) checkOrderAfterSaleSupplier(toolReq OrderAfterSaleSupplierRequest, requireData *entitys.RequireData) error { +func (t *OrderAfterSaleSupplierTool) checkOrderAfterSaleSupplier(toolReq OrderAfterSaleSupplierRequest, rec *entitys.Recognize) error { + var serialStartTime, serialEndTime int64 + ext, err := rec_extra.GetTaskRecExt(rec) + if err != nil { + return err + } + if toolReq.SerialCreateTime != "" { // 流水创建时间上下浮动10min serialCreateTime, err := time.ParseInLocation(time.DateTime, toolReq.SerialCreateTime, time.Local) @@ -138,17 +145,16 @@ func (t *OrderAfterSaleSupplierTool) checkOrderAfterSaleSupplier(toolReq OrderAf // 账号数量超过10直接截断 if len(toolReq.Account) > 10 { - entitys.ResLog(requireData.Ch, t.Name(), "账号数量超过10已被截断") + entitys.ResLog(rec.Ch, t.Name(), "账号数量超过10已被截断") toolReq.Account = toolReq.Account[:10] } headers := map[string]string{ - "Authorization": fmt.Sprintf("Bearer %s", requireData.Auth), + "Authorization": fmt.Sprintf("Bearer %s", ext.Auth), } // 最终输出 var orderList []*OrderAfterSaleSupplierData - var err error // 多流水号 if len(toolReq.SerialNumber) > 0 { @@ -210,8 +216,8 @@ func (t *OrderAfterSaleSupplierTool) checkOrderAfterSaleSupplier(toolReq OrderAf return err } - entitys.ResLog(requireData.Ch, t.Name(), "售后订单信息拉取完成") - entitys.ResJson(requireData.Ch, t.Name(), string(jsonByte)) + entitys.ResLog(rec.Ch, t.Name(), "售后订单信息拉取完成") + entitys.ResJson(rec.Ch, t.Name(), string(jsonByte)) return nil } diff --git a/internal/tools/zltx/zltx_order_detail.go b/internal/tools/zltx/zltx_order_detail.go index 7dae9f1..52d6bbb 100644 --- a/internal/tools/zltx/zltx_order_detail.go +++ b/internal/tools/zltx/zltx_order_detail.go @@ -4,6 +4,7 @@ import ( "ai_scheduler/internal/config" "ai_scheduler/internal/entitys" "ai_scheduler/internal/pkg" + "ai_scheduler/internal/pkg/rec_extra" "ai_scheduler/internal/pkg/utils_ollama" "context" "encoding/json" @@ -81,9 +82,9 @@ type ZltxOrderDetailData struct { } // Execute 执行直连天下订单详情查询 -func (w *ZltxOrderDetailTool) Execute(ctx context.Context, requireData *entitys.RequireData) error { +func (w *ZltxOrderDetailTool) Execute(ctx context.Context, rec *entitys.Recognize) error { var req ZltxOrderDetailRequest - if err := json.Unmarshal([]byte(requireData.Match.Parameters), &req); err != nil { + if err := json.Unmarshal([]byte(rec.Match.Parameters), &req); err != nil { return fmt.Errorf("invalid zltxOrderDetail request: %w", err) } if req.OrderNumber == "" { @@ -91,16 +92,20 @@ func (w *ZltxOrderDetailTool) Execute(ctx context.Context, requireData *entitys. } // 这里可以集成真实的直连天下订单详情API - return w.getZltxOrderDetail(requireData, req.OrderNumber) + return w.getZltxOrderDetail(rec, req.OrderNumber) } // getMockZltxOrderDetail 获取模拟直连天下订单详情数据 -func (w *ZltxOrderDetailTool) getZltxOrderDetail(requireData *entitys.RequireData, number string) (err error) { +func (w *ZltxOrderDetailTool) getZltxOrderDetail(rec *entitys.Recognize, number string) (err error) { + ext, err := rec_extra.GetTaskRecExt(rec) + if err != nil { + return + } //查询订单详情 req := l_request.Request{ Url: fmt.Sprintf(w.config.BaseURL, number), Headers: map[string]string{ - "Authorization": fmt.Sprintf("Bearer %s", requireData.Auth), + "Authorization": fmt.Sprintf("Bearer %s", ext.Auth), }, Method: "GET", } @@ -121,15 +126,15 @@ func (w *ZltxOrderDetailTool) getZltxOrderDetail(requireData *entitys.RequireDat if err = json.Unmarshal(res.Content, &resData); err != nil { return } - entitys.ResJson(requireData.Ch, w.Name(), res.Text) + entitys.ResJson(rec.Ch, w.Name(), res.Text) if resData.Data.Direct != nil { - entitys.ResLoading(requireData.Ch, w.Name(), "正在分析订单日志") + entitys.ResLoading(rec.Ch, w.Name(), "正在分析订单日志") req = l_request.Request{ Url: fmt.Sprintf(w.config.AddURL, resData.Data.Direct["orderOrderNumber"].(string), resData.Data.Direct["serialNumber"].(string)), Headers: map[string]string{ - "Authorization": fmt.Sprintf("Bearer %s", requireData.Auth), + "Authorization": fmt.Sprintf("Bearer %s", ext.Auth), }, Method: "GET", } @@ -149,14 +154,14 @@ func (w *ZltxOrderDetailTool) getZltxOrderDetail(requireData *entitys.RequireDat return fmt.Errorf("订单日志解析失败:%s", err) } - err = w.llm.ChatStream(context.TODO(), requireData.Ch, []api.Message{ + err = w.llm.ChatStream(context.TODO(), rec.Ch, []api.Message{ { Role: "system", Content: "你是一个订单日志助手。用户可能会提供订单日志,你需要分析订单日志,失败订单->分析失败原因,成功订单->找出整个日志的 Base64 编码的 JSON 数据的内容进行转换并反馈给我", }, { Role: "assistant", - Content: fmt.Sprintf("聊天记录:%s", pkg.JsonStringIgonErr(requireData.Histories)), + Content: fmt.Sprintf("聊天记录:%s", pkg.JsonStringIgonErr(rec.ChatHis)), }, { Role: "assistant", @@ -164,7 +169,7 @@ func (w *ZltxOrderDetailTool) getZltxOrderDetail(requireData *entitys.RequireDat }, { Role: "user", - Content: requireData.Req.Text, + Content: rec.UserContent.Text, }, }, w.Name(), "") if err != nil { @@ -172,7 +177,7 @@ func (w *ZltxOrderDetailTool) getZltxOrderDetail(requireData *entitys.RequireDat } } if resData.Data.Direct == nil { - entitys.ResText(requireData.Ch, w.Name(), "该订单无充值流水记录,不需要进行订单错误分析,建议检查订单收单模式以及扣款方式等原因😘") + entitys.ResText(rec.Ch, w.Name(), "该订单无充值流水记录,不需要进行订单错误分析,建议检查订单收单模式以及扣款方式等原因😘") } return } diff --git a/internal/tools/zltx/zltx_order_direct_log.go b/internal/tools/zltx/zltx_order_direct_log.go index b1b1483..188a954 100644 --- a/internal/tools/zltx/zltx_order_direct_log.go +++ b/internal/tools/zltx/zltx_order_direct_log.go @@ -3,6 +3,7 @@ package zltx import ( "ai_scheduler/internal/config" "ai_scheduler/internal/entitys" + "ai_scheduler/internal/pkg/rec_extra" "context" "encoding/json" "fmt" @@ -67,25 +68,28 @@ type ZltxOrderDirectLogData struct { Data map[string]interface{} `json:"data"` } -func (t *ZltxOrderLogTool) Execute(ctx context.Context, requireData *entitys.RequireData) error { +func (t *ZltxOrderLogTool) Execute(ctx context.Context, rec *entitys.Recognize) error { var req ZltxOrderLogRequest - if err := json.Unmarshal([]byte(requireData.Match.Parameters), &req); err != nil { + if err := json.Unmarshal([]byte(rec.Match.Parameters), &req); err != nil { return fmt.Errorf("invalid zltxOrderLog request: %w", err) } if req.OrderNumber == "" || req.SerialNumber == "" { return fmt.Errorf("orderNumber and serialNumber is required") } - return t.getZltxOrderLog(req.OrderNumber, req.SerialNumber, requireData) + return t.getZltxOrderLog(req.OrderNumber, req.SerialNumber, rec) } -func (t *ZltxOrderLogTool) getZltxOrderLog(orderNumber, serialNumber string, requireData *entitys.RequireData) (err error) { +func (t *ZltxOrderLogTool) getZltxOrderLog(orderNumber, serialNumber string, rec *entitys.Recognize) (err error) { //查询订单详情 - + ext, err := rec_extra.GetTaskRecExt(rec) + if err != nil { + return + } url := fmt.Sprintf("%s%s/%s", t.config.BaseURL, orderNumber, serialNumber) req := l_request.Request{ Url: url, Headers: map[string]string{ - "Authorization": fmt.Sprintf("Bearer %s", requireData.Auth), + "Authorization": fmt.Sprintf("Bearer %s", ext.Auth), }, Method: "GET", } @@ -100,7 +104,7 @@ func (t *ZltxOrderLogTool) getZltxOrderLog(orderNumber, serialNumber string, req if err = json.Unmarshal(res.Content, &resData); err != nil { return } - entitys.ResJson(requireData.Ch, t.Name(), res.Text) + entitys.ResJson(rec.Ch, t.Name(), res.Text) return } diff --git a/internal/tools/zltx/zltx_product.go b/internal/tools/zltx/zltx_product.go index de24236..0c63d84 100644 --- a/internal/tools/zltx/zltx_product.go +++ b/internal/tools/zltx/zltx_product.go @@ -3,6 +3,7 @@ package zltx import ( "ai_scheduler/internal/config" "ai_scheduler/internal/entitys" + "ai_scheduler/internal/pkg/rec_extra" "context" "encoding/json" "fmt" @@ -53,12 +54,12 @@ type ZltxProductRequest struct { Name string `json:"name"` } -func (z ZltxProductTool) Execute(ctx context.Context, requireData *entitys.RequireData) error { +func (z ZltxProductTool) Execute(ctx context.Context, rec *entitys.Recognize) error { var req ZltxProductRequest - if err := json.Unmarshal([]byte(requireData.Match.Parameters), &req); err != nil { + if err := json.Unmarshal([]byte(rec.Match.Parameters), &req); err != nil { return fmt.Errorf("invalid zltxProduct request: %w", err) } - return z.getZltxProduct(&req, requireData) + return z.getZltxProduct(&req, rec) } type ZltxProductResponse struct { @@ -133,8 +134,11 @@ type ZltxProductData struct { PlatformProductList interface{} `json:"platform_product_list"` } -func (z ZltxProductTool) getZltxProduct(body *ZltxProductRequest, requireData *entitys.RequireData) error { - +func (z ZltxProductTool) getZltxProduct(body *ZltxProductRequest, rec *entitys.Recognize) error { + ext, err := rec_extra.GetTaskRecExt(rec) + if err != nil { + return err + } var Url string var params map[string]string if body.Id != "" { @@ -153,7 +157,7 @@ func (z ZltxProductTool) getZltxProduct(body *ZltxProductRequest, requireData *e //根据商品ID或名称走不同的接口查询 Url: Url, Headers: map[string]string{ - "Authorization": fmt.Sprintf("Bearer %s", requireData.Auth), + "Authorization": fmt.Sprintf("Bearer %s", ext.Auth), }, Params: params, Method: "GET", @@ -185,7 +189,7 @@ func (z ZltxProductTool) getZltxProduct(body *ZltxProductRequest, requireData *e for i := range resp.Data.List { // 调用 平台商品列表 if resp.Data.List[i].AuthProductIds != "" { - platformProductList := z.ExecutePlatformProductList(requireData.Auth, resp.Data.List[i].AuthProductIds, resp.Data.List[i].OfficialProductID) + platformProductList := z.ExecutePlatformProductList(ext.Auth, resp.Data.List[i].AuthProductIds, resp.Data.List[i].OfficialProductID) resp.Data.List[i].PlatformProductList = platformProductList } } @@ -194,7 +198,7 @@ func (z ZltxProductTool) getZltxProduct(body *ZltxProductRequest, requireData *e if err != nil { return err } - entitys.ResJson(requireData.Ch, z.Name(), string(marshal)) + entitys.ResJson(rec.Ch, z.Name(), string(marshal)) return nil } diff --git a/internal/tools/zltx/zltx_statistics.go b/internal/tools/zltx/zltx_statistics.go index 4a051a8..5d71a9b 100644 --- a/internal/tools/zltx/zltx_statistics.go +++ b/internal/tools/zltx/zltx_statistics.go @@ -3,6 +3,7 @@ package zltx import ( "ai_scheduler/internal/config" "ai_scheduler/internal/entitys" + "ai_scheduler/internal/pkg/rec_extra" "context" "encoding/json" "fmt" @@ -47,15 +48,15 @@ type ZltxOrderStatisticsRequest struct { Number string `json:"number"` } -func (z ZltxOrderStatisticsTool) Execute(ctx context.Context, requireData *entitys.RequireData) error { +func (z ZltxOrderStatisticsTool) Execute(ctx context.Context, rec *entitys.Recognize) error { var req ZltxOrderStatisticsRequest - if err := json.Unmarshal([]byte(requireData.Match.Parameters), &req); err != nil { + if err := json.Unmarshal([]byte(rec.Match.Parameters), &req); err != nil { return err } if req.Number == "" { return fmt.Errorf("number is required") } - return z.getZltxOrderStatistics(req.Number, requireData) + return z.getZltxOrderStatistics(req.Number, rec) } type ZltxOrderStatisticsResponse struct { @@ -75,14 +76,18 @@ type ZltxOrderStatisticsData struct { Total int `json:"total"` } -func (z ZltxOrderStatisticsTool) getZltxOrderStatistics(number string, requireData *entitys.RequireData) error { +func (z ZltxOrderStatisticsTool) getZltxOrderStatistics(number string, rec *entitys.Recognize) error { + ext, err := rec_extra.GetTaskRecExt(rec) + if err != nil { + return err + } //查询订单详情 url := fmt.Sprintf("%s%s", z.config.BaseURL, number) req := l_request.Request{ Url: url, Headers: map[string]string{ - "Authorization": fmt.Sprintf("Bearer %s", requireData.Auth), + "Authorization": fmt.Sprintf("Bearer %s", ext.Auth), }, Method: "GET", } @@ -108,7 +113,7 @@ func (z ZltxOrderStatisticsTool) getZltxOrderStatistics(number string, requireDa if err != nil { return err } - entitys.ResJson(requireData.Ch, z.Name(), string(jsonByte)) + entitys.ResJson(rec.Ch, z.Name(), string(jsonByte)) return nil } diff --git a/internal/tools_bot/dtalk_bot.go b/internal/tools_bot/dtalk_bot.go deleted file mode 100644 index 2eae1d5..0000000 --- a/internal/tools_bot/dtalk_bot.go +++ /dev/null @@ -1,27 +0,0 @@ -package tools_bot - -import ( - "ai_scheduler/internal/config" - "ai_scheduler/internal/data/impl" - "ai_scheduler/internal/entitys" - "ai_scheduler/internal/pkg/utils_ollama" - - "context" -) - -type BotTool struct { - config *config.Config - llm *utils_ollama.Client - sessionImpl *impl.SessionImpl - taskMap map[string]string -} - -// NewBotTool 创建直连天下订单详情工具 -func NewBotTool(config *config.Config, llm *utils_ollama.Client, sessionImpl *impl.SessionImpl) *BotTool { - return &BotTool{config: config, llm: llm, sessionImpl: sessionImpl, taskMap: make(map[string]string)} -} - -// Execute 执行直连天下订单详情查询 -func (w *BotTool) Execute(ctx context.Context, toolName string, requireData *entitys.RequireData) (err error) { - return -} From 7a48333efcced932d4d77f32e2f05ada4e42816f Mon Sep 17 00:00:00 2001 From: renzhiyuan <465386466@qq.com> Date: Tue, 16 Dec 2025 11:35:10 +0800 Subject: [PATCH 41/66] refactor: rename tools_bot to tool_callback --- cmd/server/wire.go | 4 +- internal/biz/do/handle.go | 3 +- internal/data/constants/user_test.go | 91 ------------------- .../zltx/order_after_reseller_batch.go | 2 +- internal/pkg/rec_extra/ext.go | 4 +- internal/services/callback.go | 14 +-- .../bug_optimization_submit.go | 30 +++++- internal/tool_callback/provider_set.go | 9 ++ internal/tools_bot/provider_set.go | 9 -- 9 files changed, 48 insertions(+), 118 deletions(-) delete mode 100644 internal/data/constants/user_test.go rename internal/{tools_bot => tool_callback}/bug_optimization_submit.go (74%) create mode 100644 internal/tool_callback/provider_set.go delete mode 100644 internal/tools_bot/provider_set.go diff --git a/cmd/server/wire.go b/cmd/server/wire.go index 22d62b3..a35fa9b 100644 --- a/cmd/server/wire.go +++ b/cmd/server/wire.go @@ -13,8 +13,8 @@ import ( "ai_scheduler/internal/pkg" "ai_scheduler/internal/server" "ai_scheduler/internal/services" + "ai_scheduler/internal/tool_callback" "ai_scheduler/internal/tools" - "ai_scheduler/internal/tools_bot" "ai_scheduler/utils" "github.com/gofiber/fiber/v2/log" @@ -32,9 +32,9 @@ func InitializeApp(*config.Config, log.AllLogger) (*server.Servers, func(), erro biz.ProviderSetBiz, impl.ProviderImpl, utils.ProviderUtils, - tools_bot.ProviderSetBotTools, dingtalk.ProviderSetDingTalk, tools_regis.ProviderToolsRegis, + tool_callback.ProviderSetCallBackTools, )) } diff --git a/internal/biz/do/handle.go b/internal/biz/do/handle.go index 76de3d0..a09eb48 100644 --- a/internal/biz/do/handle.go +++ b/internal/biz/do/handle.go @@ -10,6 +10,7 @@ import ( "ai_scheduler/internal/domain/workflow/runtime" "ai_scheduler/internal/entitys" "ai_scheduler/internal/gateway" + "ai_scheduler/internal/pkg" "ai_scheduler/internal/pkg/l_request" "ai_scheduler/internal/pkg/mapstructure" "ai_scheduler/internal/pkg/rec_extra" @@ -218,7 +219,7 @@ func (r *Handle) handleKnowle(ctx context.Context, rec *entitys.Recognize, task ApiKey: ext.Sys.KnowlegeTenantKey, Query: query, } - + rec.Ext = pkg.JsonByteIgonErr(ext) // 执行工具 err = r.toolManager.ExecuteTool(ctx, configData.Tool, rec) if err != nil { diff --git a/internal/data/constants/user_test.go b/internal/data/constants/user_test.go deleted file mode 100644 index eee1b17..0000000 --- a/internal/data/constants/user_test.go +++ /dev/null @@ -1,91 +0,0 @@ -package constants - -import ( - "ai_scheduler/internal/biz" - "ai_scheduler/internal/biz/do" - "ai_scheduler/internal/biz/llm_service" - "ai_scheduler/internal/data/impl" - "ai_scheduler/internal/entitys" - "ai_scheduler/internal/gateway" - "ai_scheduler/internal/pkg/dingtalk" - "ai_scheduler/internal/pkg/utils_ollama" - "ai_scheduler/internal/pkg/utils_vllm" - "ai_scheduler/internal/server" - "ai_scheduler/internal/services" - "ai_scheduler/internal/tools" - "ai_scheduler/internal/tools_bot" - "ai_scheduler/utils" - "os" - "testing" -) - -const - -func TestMain(m *testing.M) { - bootstrap := initialize.LoadConfigWithTest() - businessLogger := log2.NewBusinessLogger(bootstrap.Logs, id, Name, Version) - helper := pkg.NewLogHelper(businessLogger, bootstrap) - - db, cleanup := utils.NewGormDb(configConfig) - sysImpl := impl.NewSysImpl(db) - taskImpl := impl.NewTaskImpl(db) - chatHisImpl := impl.NewChatHisImpl(db) - doDo := do.NewDo(sysImpl, taskImpl, chatHisImpl, configConfig) - client, cleanup2, err := utils_ollama.NewClient(configConfig) - if err != nil { - cleanup() - return nil, nil, err - } - utils_vllmClient, cleanup3, err := utils_vllm.NewClient(configConfig) - if err != nil { - cleanup2() - cleanup() - return nil, nil, err - } - ollamaService := llm_service.NewOllamaGenerate(client, utils_vllmClient, configConfig, chatHisImpl) - manager := tools.NewManager(configConfig, client) - sessionImpl := impl.NewSessionImpl(db) - botTool := tools_bot.NewBotTool(configConfig, client, sessionImpl) - handle := do.NewHandle(ollamaService, manager, configConfig, sessionImpl, botTool) - aiRouterBiz := biz.NewAiRouterBiz(doDo, handle) - chatHistoryBiz := biz.NewChatHistoryBiz(chatHisImpl, taskImpl) - gatewayGateway := gateway.NewGateway() - chatService := services.NewChatService(aiRouterBiz, chatHistoryBiz, gatewayGateway, configConfig) - sessionBiz := biz.NewSessionBiz(configConfig, sessionImpl, sysImpl, chatHisImpl) - sessionService := services.NewSessionService(sessionBiz, chatHistoryBiz) - taskBiz := biz.NewTaskBiz(configConfig, taskImpl) - taskService := services.NewTaskService(sessionBiz, taskBiz) - oldClient := dingtalk.NewOldClient(configConfig) - contactClient, err := dingtalk.NewContactClient(configConfig) - if err != nil { - cleanup3() - cleanup2() - cleanup() - return nil, nil, err - } - notableClient, err := dingtalk.NewNotableClient(configConfig) - if err != nil { - cleanup3() - cleanup2() - cleanup() - return nil, nil, err - } - callbackService := services.NewCallbackService(configConfig, gatewayGateway, oldClient, contactClient, notableClient, botTool) - historyService := services.NewHistoryService(chatHistoryBiz) - app := server.NewHTTPServer(chatService, sessionService, taskService, gatewayGateway, callbackService, historyService) - botConfigImpl := impl.NewBotConfigImpl(db) - dingTalkBotBiz := biz.NewDingTalkBotBiz(doDo, handle, botConfigImpl) - dingBotService := services.NewDingBotService(configConfig, dingTalkBotBiz) - v := server.ProvideAllDingBotServices(dingBotService) - dingTalkBotServer := server.NewDingTalkBotServer(v) - servers := server.NewServers(configConfig, app, dingTalkBotServer) - code := m.Run() - os.Exit(code) -} - -func Test_GetUserInfo(t *testing.T) { - var c entitys.TaskConfig - config := `{"param": {"type": "object", "required": ["number"], "properties": {"number": {"type": "string", "description": "订单编号/流水号"}}}, "request": {"url": "http://www.baidu.com/${number}", "headers": {"Authorization": "${authorization}"}, "method": "GET"}}` - err := json.Unmarshal([]byte(config), &c) - t.Log(err) -} diff --git a/internal/domain/workflow/zltx/order_after_reseller_batch.go b/internal/domain/workflow/zltx/order_after_reseller_batch.go index b4df00d..9749bf7 100644 --- a/internal/domain/workflow/zltx/order_after_reseller_batch.go +++ b/internal/domain/workflow/zltx/order_after_reseller_batch.go @@ -2,7 +2,7 @@ package zltx import ( "ai_scheduler/internal/config" - "ai_scheduler/internal/data/model" + toolZoarb "ai_scheduler/internal/domain/tools/zltx/order_after_reseller_batch" "ai_scheduler/internal/domain/workflow/runtime" "ai_scheduler/internal/entitys" diff --git a/internal/pkg/rec_extra/ext.go b/internal/pkg/rec_extra/ext.go index 53fbbe7..b8b3c86 100644 --- a/internal/pkg/rec_extra/ext.go +++ b/internal/pkg/rec_extra/ext.go @@ -16,7 +16,7 @@ func SetTaskRecExt(requireData *entitys.RequireData, rec *entitys.Recognize) { rec.Ext = pkg.JsonByteIgonErr(TaskExt) } -func GetTaskRecExt(rec *entitys.Recognize) (ext *entitys.TaskExt, err error) { - err = json.Unmarshal(rec.Ext, ext) +func GetTaskRecExt(rec *entitys.Recognize) (ext entitys.TaskExt, err error) { + err = json.Unmarshal(rec.Ext, &ext) return ext, err } diff --git a/internal/services/callback.go b/internal/services/callback.go index b0697c0..c903bea 100644 --- a/internal/services/callback.go +++ b/internal/services/callback.go @@ -9,7 +9,7 @@ import ( "ai_scheduler/internal/pkg" "ai_scheduler/internal/pkg/dingtalk" "ai_scheduler/internal/pkg/util" - "ai_scheduler/internal/tools_bot" + "ai_scheduler/internal/tool_callback" "context" "encoding/json" "strings" @@ -25,17 +25,17 @@ type CallbackService struct { dingtalkOldClient *dingtalk.OldClient dingtalkContactClient *dingtalk.ContactClient dingtalkNotableClient *dingtalk.NotableClient - botTool *tools_bot.BotTool + callBackTool *tool_callback.CallBackTool } -func NewCallbackService(cfg *config.Config, gateway *gateway.Gateway, dingtalkOldClient *dingtalk.OldClient, dingtalkContactClient *dingtalk.ContactClient, dingtalkNotableClient *dingtalk.NotableClient, botTool *tools_bot.BotTool) *CallbackService { +func NewCallbackService(cfg *config.Config, gateway *gateway.Gateway, dingtalkOldClient *dingtalk.OldClient, dingtalkContactClient *dingtalk.ContactClient, dingtalkNotableClient *dingtalk.NotableClient, callBackTool *tool_callback.CallBackTool) *CallbackService { return &CallbackService{ cfg: cfg, gateway: gateway, dingtalkOldClient: dingtalkOldClient, dingtalkContactClient: dingtalkContactClient, dingtalkNotableClient: dingtalkNotableClient, - botTool: botTool, + callBackTool: callBackTool, } } @@ -138,7 +138,7 @@ func parseInt64(s string) (int64, bool) { func (s *CallbackService) handleDingTalkCallback(c *fiber.Ctx, env Envelope) error { // 校验taskId - sessionID, ok := s.botTool.GetSessionByTaskID(env.TaskID) + sessionID, ok := s.callBackTool.GetSessionByTaskID(env.TaskID) if !ok { return errorcode.ParamErr("missing session_id for task_id: %s", env.TaskID) } @@ -166,7 +166,7 @@ func (s *CallbackService) handleDingTalkCallback(c *fiber.Ctx, env Envelope) err s.sendStreamTxt(sessionID, msg) // 删除映射 - s.botTool.DelTaskMapping(env.TaskID) + s.callBackTool.DelTaskMapping(env.TaskID) return c.JSON(fiber.Map{"code": 0, "message": "ok"}) case ActionBugOptimizationSubmitProcess: @@ -283,7 +283,7 @@ func (s *CallbackService) handleBugOptimizationSubmitUpdate(ctx context.Context, BaseId: data.BaseId, SheetId: data.SheetId, RecordId: data.RecordId, - OperatorId: tools_bot.BotBugOptimizationSubmitAdminUnionId, + OperatorId: tool_callback.BotBugOptimizationSubmitAdminUnionId, CreatorUnionId: unionId, }) if err != nil { diff --git a/internal/tools_bot/bug_optimization_submit.go b/internal/tool_callback/bug_optimization_submit.go similarity index 74% rename from internal/tools_bot/bug_optimization_submit.go rename to internal/tool_callback/bug_optimization_submit.go index 1bb2d31..245ab8c 100644 --- a/internal/tools_bot/bug_optimization_submit.go +++ b/internal/tool_callback/bug_optimization_submit.go @@ -1,10 +1,13 @@ -package tools_bot +package tool_callback import ( + "ai_scheduler/internal/config" errors "ai_scheduler/internal/data/error" + "ai_scheduler/internal/data/impl" "ai_scheduler/internal/entitys" "ai_scheduler/internal/pkg" "ai_scheduler/internal/pkg/l_request" + "ai_scheduler/internal/pkg/utils_ollama" "context" "encoding/json" "fmt" @@ -15,6 +18,23 @@ import ( "xorm.io/builder" ) +type CallBackTool struct { + config *config.Config + llm *utils_ollama.Client + sessionImpl *impl.SessionImpl + taskMap map[string]string +} + +// NewBotTool 创建直连天下订单详情工具 +func NewCallBackTool(config *config.Config, llm *utils_ollama.Client, sessionImpl *impl.SessionImpl) *CallBackTool { + return &CallBackTool{config: config, llm: llm, sessionImpl: sessionImpl, taskMap: make(map[string]string)} +} + +// Execute 执行直连天下订单详情查询 +func (w *CallBackTool) Execute(ctx context.Context, toolName string, requireData *entitys.RequireData) (err error) { + return +} + // BugOptimizationSubmitForm 工单提交表单参数 type BugOptimizationSubmitForm struct { Mark string `json:"mark"` // 工单标识 @@ -34,7 +54,7 @@ const ( ) // BugOptimizationSubmit 工单提交 -func (w *BotTool) BugOptimizationSubmit(ctx context.Context, requireData *entitys.RequireData) (err error) { +func (w *CallBackTool) BugOptimizationSubmit(ctx context.Context, requireData *entitys.RequireData) (err error) { // 获取用户信息 cond := builder.NewCond() cond = cond.And(builder.Eq{"session_id": requireData.Session}) @@ -98,7 +118,7 @@ func (w *BotTool) BugOptimizationSubmit(ctx context.Context, requireData *entity // SetTaskMapping 设置 task_id 到 session_id 的映射(内存版)。 // 后续考虑使用 Redis,确保幂等与过期清理。 -func (w *BotTool) SetTaskMapping(taskID, sessionID string) { +func (w *CallBackTool) SetTaskMapping(taskID, sessionID string) { if taskID == "" || sessionID == "" { return } @@ -106,12 +126,12 @@ func (w *BotTool) SetTaskMapping(taskID, sessionID string) { } // GetSessionByTaskID 读取映射 -func (w *BotTool) GetSessionByTaskID(taskID string) (string, bool) { +func (w *CallBackTool) GetSessionByTaskID(taskID string) (string, bool) { v, ok := w.taskMap[taskID] return v, ok } // DelTaskMapping 删除 task_id 到 session_id 的映射(内存版)。 -func (w *BotTool) DelTaskMapping(taskID string) { +func (w *CallBackTool) DelTaskMapping(taskID string) { delete(w.taskMap, taskID) } diff --git a/internal/tool_callback/provider_set.go b/internal/tool_callback/provider_set.go new file mode 100644 index 0000000..c2671d4 --- /dev/null +++ b/internal/tool_callback/provider_set.go @@ -0,0 +1,9 @@ +package tool_callback + +import ( + "github.com/google/wire" +) + +var ProviderSetCallBackTools = wire.NewSet( + NewCallBackTool, +) diff --git a/internal/tools_bot/provider_set.go b/internal/tools_bot/provider_set.go deleted file mode 100644 index 7bcff12..0000000 --- a/internal/tools_bot/provider_set.go +++ /dev/null @@ -1,9 +0,0 @@ -package tools_bot - -import ( - "github.com/google/wire" -) - -var ProviderSetBotTools = wire.NewSet( - NewBotTool, -) From 33b6363233523e70a59ce4fcf262bd0db44b879b Mon Sep 17 00:00:00 2001 From: fuzhongyun <15339891972@163.com> Date: Tue, 16 Dec 2025 14:06:14 +0800 Subject: [PATCH 42/66] =?UTF-8?q?feat:=20=20=E5=A2=9E=E5=8A=A0=E5=9B=BE?= =?UTF-8?q?=E7=89=87=E7=BB=9F=E4=B8=80=E5=A4=84=E7=90=86=E4=B8=BA=E4=BA=8C?= =?UTF-8?q?=E8=BF=9B=E5=88=B6=E6=96=B9=E6=B3=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- internal/biz/handle/file.go | 105 ++++++++++++++++++++++++++++++-- internal/data/constants/file.go | 10 +++ 2 files changed, 110 insertions(+), 5 deletions(-) diff --git a/internal/biz/handle/file.go b/internal/biz/handle/file.go index e9332aa..882ab86 100644 --- a/internal/biz/handle/file.go +++ b/internal/biz/handle/file.go @@ -4,10 +4,12 @@ import ( "ai_scheduler/internal/data/constants" "ai_scheduler/internal/entitys" "ai_scheduler/internal/pkg/l_request" + "bytes" "errors" "fmt" "io" "net/http" + "net/url" "path/filepath" "strings" @@ -15,12 +17,105 @@ import ( ) // HandleRecognizeFile 这里的目的是无论将什么类型的file都转为二进制格式 -// 判断文件大小 -// 判断文件类型 -// 判断文件是否合法 +// 最终输出:1.将 files.FileData 填充为文件的二进制数据 2.将 files.FileType 填充为文件的类型(当前为 constants.Caller,兼容写入其字符串值) +// 判断文件大小(统一限制为10MB);判断文件类型;判断文件是否合法(类型在白名单映射中);无法识别/非法/超限→填充unknown并兼容返回 +// 若 FileData 不存在 且 FileUrl 不存在, 则直接退出 +// 若 FileData 存在 FileType 存在, 则直接退出 +// 若 FileData 存在 FileType 不存在, 则根据 FileData 推断文件类型并填充 FileType +// 若 FileUrl 存在, 则下载文件并填充 FileData 和 FileType func HandleRecognizeFile(files *entitys.RecognizeFile) { - //Todo 仲云 - return + if files == nil { + return + } + + const maxSize = 10 * 1024 * 1024 // 10MB 上限 + + // 工具:根据 MIME 或扩展名映射到 FileType + mapToFileType := func(s string) constants.FileType { + if len(s) == 0 { + return constants.FileTypeUnknown + } + s = strings.ToLower(strings.TrimSpace(s)) + for ft, items := range constants.FileTypeMappings { + for _, item := range items { + if !strings.HasPrefix(item, ".") { // MIME + if s == item { + return ft + } + } else { // 扩展名 + if s == item { + return ft + } + } + } + } + return constants.FileTypeUnknown + } + + // 分支1:无数据、无URL→直接返回 + if len(files.FileData) == 0 && len(files.FileUrl) == 0 { + return + } + + // 分支2:已有数据且已有类型→直接返回 + if len(files.FileData) > 0 && len(strings.TrimSpace(files.FileType.String())) > 0 { + return + } + + // 分支3:仅有数据、无类型→内容检测并填充 + if len(files.FileData) > 0 && len(strings.TrimSpace(files.FileType.String())) == 0 { + if len(files.FileData) > maxSize { + files.FileType = constants.Caller(constants.FileTypeUnknown) + return + } + + reader := bytes.NewReader(files.FileData) + detected := detectFileType(reader, "") + if detected == constants.FileTypeUnknown { + files.FileType = constants.Caller(constants.FileTypeUnknown) + return + } + files.FileType = constants.Caller(detected) + return + } + + // 分支4:存在URL→下载并填充数据与类型 + if len(files.FileUrl) > 0 { + fileBytes, contentType, err := downloadFile(files.FileUrl) + if err != nil || len(fileBytes) == 0 { + files.FileType = constants.Caller(constants.FileTypeUnknown) + return + } + + if len(fileBytes) > maxSize { + // 超限:不写入数据,类型置 unknown + files.FileType = constants.Caller(constants.FileTypeUnknown) + return + } + + // 优先使用响应头的 Content-Type 映射 + detected := mapToFileType(contentType) + + if detected == constants.FileTypeUnknown { + // 回退:内容检测 + URL 文件名扩展名辅助 + var fname string + if u, perr := url.Parse(files.FileUrl); perr == nil { + fname = filepath.Base(u.Path) + } + reader := bytes.NewReader(fileBytes) + detected = detectFileType(reader, fname) + } + + // 写入数据 + files.FileData = fileBytes + + if detected == constants.FileTypeUnknown { + files.FileType = constants.Caller(constants.FileTypeUnknown) + return + } + files.FileType = constants.Caller(detected) + return + } } // 下载文件并返回二进制数据、MIME 类型 diff --git a/internal/data/constants/file.go b/internal/data/constants/file.go index 132ad0e..b825971 100644 --- a/internal/data/constants/file.go +++ b/internal/data/constants/file.go @@ -10,6 +10,8 @@ const ( FileTypeWord FileType = "word" FileTypeTxt FileType = "txt" FileTypePDF FileType = "pdf" + FileTypePPT FileType = "ppt" + FileTypeCSV FileType = "csv" ) var FileTypeMappings = map[FileType][]string{ @@ -35,4 +37,12 @@ var FileTypeMappings = map[FileType][]string{ "text/plain", ".txt", }, + FileTypePPT: { + "application/vnd.openxmlformats-officedocument.presentationml.presentation", + ".pptx", + }, + FileTypeCSV: { + "text/csv", + ".csv", + }, } From fbaa8cdc808e04eee10d4a9c4dd75f86ec9efbe5 Mon Sep 17 00:00:00 2001 From: fuzhongyun <15339891972@163.com> Date: Tue, 16 Dec 2025 14:53:16 +0800 Subject: [PATCH 43/66] =?UTF-8?q?fix=EF=BC=9A=20=E9=85=8D=E7=BD=AE?= =?UTF-8?q?=E6=9B=B4=E6=96=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- config/config_test.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/config/config_test.yaml b/config/config_test.yaml index cbeb64f..dc20fd7 100644 --- a/config/config_test.yaml +++ b/config/config_test.yaml @@ -5,7 +5,7 @@ server: ollama: - base_url: "http://127.0.0.1:11434" + base_url: "http://host.docker.internal:11434" model: "qwen3-coder:480b-cloud" generate_model: "qwen3-coder:480b-cloud" vl_model: "gemini-3-pro-preview" @@ -105,7 +105,7 @@ permissionConfig: llm: providers: ollama: - endpoint: http://127.0.0.1:11434 + endpoint: http://host.docker.internal:11434 timeout: 60s models: - id: qwen3-coder:480b-cloud From 3dccccda9ecd2a8eb83b4a216e75e2da5c08085a Mon Sep 17 00:00:00 2001 From: fuzhongyun <15339891972@163.com> Date: Tue, 16 Dec 2025 18:25:13 +0800 Subject: [PATCH 44/66] =?UTF-8?q?feat:=201.=20=E6=96=B0=E5=A2=9E=E2=80=9Cc?= =?UTF-8?q?oze=E5=B7=A5=E4=BD=9C=E6=B5=81=E2=80=9D=E4=BB=BB=E5=8A=A1?= =?UTF-8?q?=E7=B1=BB=E5=9E=8B=E5=8F=8A=E7=9B=B8=E5=85=B3=E9=85=8D=E7=BD=AE?= =?UTF-8?q?=202.=20=E5=A2=9E=E5=8A=A0coze=E5=B7=A5=E4=BD=9C=E6=B5=81?= =?UTF-8?q?=E7=BB=9F=E4=B8=80=E6=B5=81=E5=BC=8F&=E9=9D=9E=E6=B5=81?= =?UTF-8?q?=E5=BC=8F=E4=BB=BB=E5=8A=A1=E5=A4=84=E7=90=86=E6=96=B9=E6=B3=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- config/config_env.yaml | 4 ++ internal/biz/do/handle.go | 108 ++++++++++++++++++++++++++++++- internal/config/config.go | 8 +++ internal/data/constants/const.go | 1 + 4 files changed, 120 insertions(+), 1 deletion(-) diff --git a/config/config_env.yaml b/config/config_env.yaml index d2cd2a4..188e5f9 100644 --- a/config/config_env.yaml +++ b/config/config_env.yaml @@ -18,6 +18,10 @@ vllm: timeout: "120s" level: "info" +coze: + base_url: "https://api.coze.cn" + api_secret: "pat_guUSPk8KZFvIIbVReuaMlOBVAaIISSdkTEV8MaRgVPNv6UEYPHKTBUXznFcxl04H" + sys: session_len: 6 diff --git a/internal/biz/do/handle.go b/internal/biz/do/handle.go index a09eb48..3a4a12b 100644 --- a/internal/biz/do/handle.go +++ b/internal/biz/do/handle.go @@ -17,12 +17,17 @@ import ( "ai_scheduler/internal/pkg/util" "ai_scheduler/internal/tools" "ai_scheduler/internal/tools/public" + errorsSpecial "errors" + "io" + "net/http" + "time" "context" "encoding/json" "fmt" "strings" + "github.com/coze-dev/coze-go" "gorm.io/gorm/utils" ) @@ -118,9 +123,10 @@ func (r *Handle) HandleMatch(ctx context.Context, client *gateway.Client, rec *e return r.handleTask(ctx, rec, pointTask) case constants.TaskTypeKnowle: return r.handleKnowle(ctx, rec, pointTask) - case constants.TaskTypeEinoWorkflow: return r.handleEinoWorkflow(ctx, rec, pointTask) + case constants.TaskTypeCozeWorkflow: + return r.handleCozeWorkflow(ctx, rec, pointTask) default: return r.handleOtherTask(ctx, requireData) } @@ -302,6 +308,106 @@ func (r *Handle) handleEinoWorkflow(ctx context.Context, rec *entitys.Recognize, return nil } +func (r *Handle) handleCozeWorkflow(ctx context.Context, rec *entitys.Recognize, task *model.AiTask) (err error) { + entitys.ResLoading(rec.Ch, task.Index, "正在执行工作流(coze)") + + customClient := &http.Client{ + Timeout: time.Minute * 30, + } + + authCli := coze.NewTokenAuth(r.conf.Coze.ApiSecret) + cozeCli := coze.NewCozeAPI( + authCli, + coze.WithBaseURL(r.conf.Coze.BaseURL), + coze.WithHttpClient(customClient), + ) + + // 从参数中获取workflowID + type requestParams struct { + Request l_request.Request `json:"request"` + } + var config requestParams + err = json.Unmarshal([]byte(task.Config), &config) + if err != nil { + return err + } + workflowId, ok := config.Request.Json["workflow_id"].(string) + if !ok { + return fmt.Errorf("workflow_id不能为空") + } + // 提取参数 + var data map[string]interface{} + err = json.Unmarshal([]byte(rec.Match.Parameters), &data) + + req := &coze.RunWorkflowsReq{ + WorkflowID: workflowId, + Parameters: data, + // IsAsync: true, + } + + stream := config.Request.Json["stream"].(bool) + + entitys.ResLog(rec.Ch, task.Index, "工作流执行中...") + + if stream { + streamResp, err := cozeCli.Workflows.Runs.Stream(ctx, req) + if err != nil { + return err + } + + handleCozeWorkflowEvents(ctx, streamResp, cozeCli, workflowId, rec.Ch, task.Index) + } else { + resp, err := cozeCli.Workflows.Runs.Create(ctx, req) + if err != nil { + return err + } + + entitys.ResJson(rec.Ch, task.Index, resp.Data) + } + + return +} + +// handleCozeWorkflowEvents 处理 coze 工作流事件 +func handleCozeWorkflowEvents(ctx context.Context, resp coze.Stream[coze.WorkflowEvent], cozeCli coze.CozeAPI, workflowID string, ch chan entitys.Response, index string) { + defer resp.Close() + for { + event, err := resp.Recv() + if errorsSpecial.Is(err, io.EOF) { + fmt.Println("Stream finished") + break + } + if err != nil { + fmt.Println("Error receiving event:", err) + break + } + + switch event.Event { + case coze.WorkflowEventTypeMessage: + entitys.ResStream(ch, index, event.Message.Content) + case coze.WorkflowEventTypeError: + entitys.ResError(ch, index, fmt.Sprintf("工作流执行错误: %s", event.Error)) + case coze.WorkflowEventTypeDone: + entitys.ResEnd(ch, index, "工作流执行完成") + case coze.WorkflowEventTypeInterrupt: + resumeReq := &coze.ResumeRunWorkflowsReq{ + WorkflowID: workflowID, + EventID: event.Interrupt.InterruptData.EventID, + ResumeData: "your data", + InterruptType: event.Interrupt.InterruptData.Type, + } + newResp, err := cozeCli.Workflows.Runs.Resume(ctx, resumeReq) + if err != nil { + entitys.ResError(ch, index, fmt.Sprintf("工作流恢复执行错误: %s", err.Error())) + return + } + entitys.ResLog(ch, index, "工作流恢复执行中...") + handleCozeWorkflowEvents(ctx, newResp, cozeCli, workflowID, ch, index) + } + } + fmt.Printf("done, log:%s\n", resp.Response().LogID()) +} + // 权限验证 func (r *Handle) PermissionAuth(client *gateway.Client, pointTask *model.AiTask) (err error) { // 授权检查权限 diff --git a/internal/config/config.go b/internal/config/config.go index 91115cb..ee0c1a5 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -12,6 +12,7 @@ type Config struct { Server ServerConfig `mapstructure:"server"` Ollama OllamaConfig `mapstructure:"ollama"` Vllm VllmConfig `mapstructure:"vllm"` + Coze CozeConfig `mapstructure:"coze"` Sys SysConfig `mapstructure:"sys"` Tools ToolsConfig `mapstructure:"tools"` Logging LoggingConfig `mapstructure:"logging"` @@ -90,6 +91,13 @@ type VllmConfig struct { Level string `mapstructure:"level"` } +// CozeConfig Coze配置 +type CozeConfig struct { + BaseURL string `mapstructure:"base_url"` + ApiKey string `mapstructure:"api_key"` + ApiSecret string `mapstructure:"api_secret"` +} + type Redis struct { Host string `mapstructure:"host"` Type string `mapstructure:"type"` diff --git a/internal/data/constants/const.go b/internal/data/constants/const.go index 604848e..b3c6ef0 100644 --- a/internal/data/constants/const.go +++ b/internal/data/constants/const.go @@ -16,6 +16,7 @@ const ( TaskTypeFunc TaskType = 3 TaskTypeBot TaskType = 4 TaskTypeEinoWorkflow TaskType = 5 // eino 工作流 + TaskTypeCozeWorkflow TaskType = 6 // coze 工作流 ) type UseFul int32 From 9d95106b038fd196a2e33b8cf83f73fcb81a04b3 Mon Sep 17 00:00:00 2001 From: renzhiyuan <465386466@qq.com> Date: Tue, 16 Dec 2025 21:13:24 +0800 Subject: [PATCH 45/66] refactor: enhance bot chat history and error handling --- internal/biz/ding_talk_bot.go | 83 +++++++++++++++++++--- internal/data/model/ai_bot_chat_his.gen.go | 3 +- internal/entitys/bot.go | 1 + internal/services/dtalk_bot.go | 20 ++++-- 4 files changed, 90 insertions(+), 17 deletions(-) diff --git a/internal/biz/ding_talk_bot.go b/internal/biz/ding_talk_bot.go index ca7431b..d799796 100644 --- a/internal/biz/ding_talk_bot.go +++ b/internal/biz/ding_talk_bot.go @@ -9,6 +9,7 @@ import ( "ai_scheduler/internal/data/model" "ai_scheduler/internal/entitys" "ai_scheduler/internal/tools" + "strconv" "context" "database/sql" @@ -34,6 +35,7 @@ type DingTalkBotBiz struct { botTools []model.AiBotTool botGroupImpl *impl.BotGroupImpl toolManager *tools.Manager + chatHis *impl.BotChatHisImpl } // NewDingTalkBotBiz @@ -44,6 +46,7 @@ func NewDingTalkBotBiz( botGroupImpl *impl.BotGroupImpl, dingTalkUser *dingtalk.User, tools *tools_regis.ToolRegis, + chatHis *impl.BotChatHisImpl, toolManager *tools.Manager, ) *DingTalkBotBiz { return &DingTalkBotBiz{ @@ -55,6 +58,7 @@ func NewDingTalkBotBiz( botTools: tools.BootTools, botGroupImpl: botGroupImpl, toolManager: toolManager, + chatHis: chatHis, } } @@ -88,7 +92,7 @@ func (d *DingTalkBotBiz) InitRequire(ctx context.Context, data *chatbot.BotCallb } func (d *DingTalkBotBiz) Do(ctx context.Context, requireData *entitys.RequireDataDingTalkBot) (err error) { - entitys.ResText(requireData.Ch, "", "收到消息,正在处理中,请稍等") + entitys.ResLoading(requireData.Ch, "", "收到消息,正在处理中,请稍等") defer close(requireData.Ch) switch constants.ConversationType(requireData.Req.ConversationType) { case constants.ConversationTypeSingle: @@ -98,6 +102,9 @@ func (d *DingTalkBotBiz) Do(ctx context.Context, requireData *entitys.RequireDat default: err = errors.New("未知的聊天类型:" + requireData.Req.ConversationType) } + if err != nil { + entitys.ResText(requireData.Ch, "", err.Error()) + } return } @@ -108,6 +115,7 @@ func (d *DingTalkBotBiz) handleSingleChat(ctx context.Context, requireData *enti //if err != nil { // return //} + //requireData.ID=requireData.UserInfo.UserID ////如果不是管理或者不是老板,则进行权限判断 //if requireData.UserInfo.IsSenior == constants.IsSeniorFalse && requireData.UserInfo.IsBoss == constants.IsBossFalse { // @@ -117,9 +125,11 @@ func (d *DingTalkBotBiz) handleSingleChat(ctx context.Context, requireData *enti func (d *DingTalkBotBiz) handleGroupChat(ctx context.Context, requireData *entitys.RequireDataDingTalkBot) (err error) { group, err := d.initGroup(ctx, requireData.Req.ConversationId, requireData.Req.ConversationTitle) + if err != nil { return } + requireData.ID = group.GroupID groupTools, err := d.getGroupTools(ctx, group) if err != nil { return @@ -158,12 +168,19 @@ func (d *DingTalkBotBiz) getGroupTools(ctx context.Context, group *model.AiBotGr return } var ( - groupRegisTools map[string]struct{} + groupRegisTools = make(map[int]struct{}) ) if group.ToolList != "" { - groupList := strings.Split(group.ToolList, ",") - for _, tool := range groupList { - groupRegisTools[tool] = struct{}{} + groupToolList := strings.Split(group.ToolList, ",") + for _, tool := range groupToolList { + if tool == "" { + continue + } + num, _err := strconv.Atoi(tool) + if _err != nil { + continue + } + groupRegisTools[num] = struct{}{} } } @@ -172,7 +189,7 @@ func (d *DingTalkBotBiz) getGroupTools(ctx context.Context, group *model.AiBotGr tools = append(tools, v) continue } - if _, ex := groupRegisTools[v.Index]; ex { + if _, ex := groupRegisTools[int(v.ToolID)]; ex { tools = append(tools, v) } } @@ -221,11 +238,6 @@ func (d *DingTalkBotBiz) getUserContent(msgType string, msgContent interface{}) return } -func (d *DingTalkBotBiz) defaultPrompt() string { - - return `{"system":"智能路由系统,精准解析用户意图并路由至任务模块,遵循以下规则:","rule":{"返回格式":"{\\"index\\":\\"工具索引\\",\\"confidence\\":\\"0.0-1.0\\",\\"reasoning\\":\\"判断理由\\",\\"parameters\\":\\"转义JSON参数\\",\\"is_match\\":true|false,\\"chat\\":\\"追问内容\\"}","工具匹配":["用工具parameters匹配,区分必选(required)和可选(optional)参数","无法匹配时,is_match=false,chat提醒用户适用工具(例:'请问您要查询订单还是商品?')"],"参数提取":["从用户输入提取parameters中明确提及的参数","必须参数仅用用户直接提及的,缺失时is_match=false,chat提醒补充(例:'需补充XX信息')"],"格式要求":["所有字段值为字符串(含confidence)","parameters为转义JSON字符串(如\\"{\\\\"key\\\\":\\\\"value\\\\"}\\")"]}}` -} - func (d *DingTalkBotBiz) handleMatch(ctx context.Context, rec *entitys.Recognize) (err error) { if !rec.Match.IsMatch { @@ -297,6 +309,28 @@ func (d *DingTalkBotBiz) HandleRes(ctx context.Context, data *chatbot.BotCallbac } } +func (d *DingTalkBotBiz) SaveHis(ctx context.Context, requireData *entitys.RequireDataDingTalkBot, chat []string) (err error) { + if len(chat) == 0 { + return + } + his := []*model.AiBotChatHi{ + { + HisType: requireData.Req.ConversationType, + ID: requireData.ID, + Role: "user", + Content: requireData.Req.Text.Content, + }, + { + HisType: requireData.Req.ConversationType, + ID: requireData.ID, + Role: "system", + Content: strings.Join(chat, "\n"), + }, + } + _, err = d.chatHis.Add(his) + return err +} + func (d *DingTalkBotBiz) replySteam(ctx context.Context, SessionWebhook string, content string, arg ...string) error { msg := content if len(arg) > 0 { @@ -344,3 +378,30 @@ func (d *DingTalkBotBiz) replyActionCard(ctx context.Context, SessionWebhook str } return d.replier.SimpleReplyText(ctx, SessionWebhook, []byte(msg)) } + +func (d *DingTalkBotBiz) defaultPrompt() string { + + return `[system] 你是一个智能路由系统,核心职责是 **精准解析用户意图并路由至对应任务模块**。请严格遵循以下规则: +[rule] +1. **返回格式**: +仅输出以下 **严格格式化的 JSON 字符串**(禁用 Markdown): +{ "index": "工具索引index", "confidence": 0.0-1.0,"reasoning": "判断理由","parameters":"jsonstring |提取参数","is_match":true||false,"chat": "追问内容"} +关键规则(按优先级排序): + +2. **工具匹配**: + +- 若匹配到工具,使用工具的 parameters 作为模板做参数匹配 +- 注意区分 parameters 中的 必须参数(required) 和 可选参数(optional),按下述参数提取规则处理。 +- 若**完全无法匹配**,立即设置 is_match: false,并在 chat 中已第一人称的角度提醒用户需要适用何种工具(例:"请问您是要查询订单还是商品呢")。 + +1. **参数提取**: + +- 根据 parameters 字段列出的参数名,从用户输入中提取对应值。 +- **仅提取**明确提及的参数,忽略未列出的内容。 +- 必须参数仅使用用户直接提及的参数,不允许从上下文推断。 +- 若必须参数缺失,立即设置 is_match: false,并在 chat 中已第一人称的角度提醒用户提供缺少的参数追问(例:"需要您补充XX信息")。 + +4. 格式强制要求: +-所有字段值必须是**字符串**(包括 confidence)。 +-parameters 必须是 **转义后的 JSON 字符串**(如 "{\"product_name\": \"京东月卡\"}")。` +} diff --git a/internal/data/model/ai_bot_chat_his.gen.go b/internal/data/model/ai_bot_chat_his.gen.go index 1e4bfff..2285343 100644 --- a/internal/data/model/ai_bot_chat_his.gen.go +++ b/internal/data/model/ai_bot_chat_his.gen.go @@ -13,7 +13,8 @@ const TableNameAiBotChatHi = "ai_bot_chat_his" // AiBotChatHi mapped from table type AiBotChatHi struct { HisID int64 `gorm:"column:his_id;primaryKey;autoIncrement:true" json:"his_id"` - SessionID string `gorm:"column:session_id;not null" json:"session_id"` + HisType string `gorm:"column:his_type;not null;default:1;comment:1为个人,2为群聊" json:"his_type"` // 1为个人,2为群聊 + ID int32 `gorm:"column:id;not null;comment:对应的id" json:"id"` // 对应的id Role string `gorm:"column:role;not null;comment:system系统输出,assistant助手输出,user用户输入" json:"role"` // system系统输出,assistant助手输出,user用户输入 Content string `gorm:"column:content;not null" json:"content"` Files string `gorm:"column:files;not null" json:"files"` diff --git a/internal/entitys/bot.go b/internal/entitys/bot.go index 7085a37..46661ab 100644 --- a/internal/entitys/bot.go +++ b/internal/entitys/bot.go @@ -13,6 +13,7 @@ type RequireDataDingTalkBot struct { Match *Match Req *chatbot.BotCallbackDataModel Ch chan Response + ID int32 } type DingTalkBot struct { diff --git a/internal/services/dtalk_bot.go b/internal/services/dtalk_bot.go index 0bfe129..5c9c92b 100644 --- a/internal/services/dtalk_bot.go +++ b/internal/services/dtalk_bot.go @@ -27,20 +27,27 @@ func (d *DingBotService) GetServiceCfg() ([]entitys.DingTalkBot, error) { } func (d *DingBotService) OnChatBotMessageReceived(ctx context.Context, data *chatbot.BotCallbackDataModel) (content []byte, err error) { + var ( + lastErr error + chat []string + ) requireData, err := d.dingTalkBotBiz.InitRequire(ctx, data) if err != nil { return } // 使用 ctx.Done() 通知 Do 方法提前终止 subCtx, cancel := context.WithCancel(ctx) - defer cancel() + defer func() { + cancel() + _ = d.dingTalkBotBiz.SaveHis(ctx, requireData, chat) + + }() // 异步执行 Do 方法 done := make(chan error, 1) go func() { done <- d.dingTalkBotBiz.Do(subCtx, requireData) }() - var lastErr error for { select { case <-ctx.Done(): @@ -53,6 +60,9 @@ func (d *DingBotService) OnChatBotMessageReceived(ctx context.Context, data *cha if resp.Type == entitys.ResponseLog { continue } + if resp.Type == entitys.ResponseText || resp.Type == entitys.ResponseStream || resp.Type == entitys.ResponseJson { + chat = append(chat, resp.Content) + } if err := d.dingTalkBotBiz.HandleRes(ctx, data, resp); err != nil { log.Printf("HandleRes 失败: %v", err) } @@ -60,9 +70,9 @@ func (d *DingBotService) OnChatBotMessageReceived(ctx context.Context, data *cha } cleanup: select { - case err := <-done: - if err != nil { - log.Printf("Do 方法执行失败: %v", err) + case _err := <-done: + if _err != nil { + panic(_err) } case <-time.After(1 * time.Second): log.Println("警告:等待 Do 方法超时,可能发生 goroutine 泄漏") From 1d604a4af9ccb9fa1069d18014e537f7d19f25e3 Mon Sep 17 00:00:00 2001 From: wolter <11@gmail> Date: Wed, 17 Dec 2025 15:04:14 +0800 Subject: [PATCH 46/66] feat: config_test.yaml --- config/config_test.yaml | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/config/config_test.yaml b/config/config_test.yaml index dc20fd7..5fa0dbd 100644 --- a/config/config_test.yaml +++ b/config/config_test.yaml @@ -19,6 +19,10 @@ vllm: timeout: "120s" level: "info" +coze: + base_url: "https://api.coze.cn" + api_secret: "sat_AqvFcdNgesP8megy1ItTscWFXRcsHRzmM4NJ1KNavfcdT0EPwYuCPkDqGhItpx13" + sys: session_len: 6 @@ -149,8 +153,6 @@ llm: temperature: 0.7 max_tokens: 4096 stream: true - - #ding_talk_bots: # public: # client_id: "dingchg59zwwvmuuvldx", From ae34efb98953d58bc695ad7ebc81d4ef96db9911 Mon Sep 17 00:00:00 2001 From: fuzhongyun <15339891972@163.com> Date: Wed, 17 Dec 2025 18:11:29 +0800 Subject: [PATCH 47/66] =?UTF-8?q?feat:=201.=E8=B0=83=E6=95=B4=E9=80=9A?= =?UTF-8?q?=E7=94=A8=E6=96=87=E4=BB=B6=E8=AF=86=E5=88=AB=202.=E6=96=B0?= =?UTF-8?q?=E5=A2=9Eeino=20vllm=20=E5=9B=BE=E7=89=87=E5=AD=97=E8=8A=82?= =?UTF-8?q?=E8=AF=86=E5=88=AB=E6=96=B9=E6=B3=95=20=203.=E5=AE=9E=E7=8E=B0?= =?UTF-8?q?=E7=9B=B4=E8=BF=9EAI=E4=B8=9A=E5=8A=A1=E7=9A=84=E5=9B=BE?= =?UTF-8?q?=E7=89=87=E8=AF=86=E5=88=AB=E6=8E=A5=E5=85=A5?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- config/config_env.yaml | 2 +- config/config_test.yaml | 2 +- internal/biz/do/prompt.go | 48 +++++++++++++++++++++++++++++-- internal/biz/handle/file.go | 34 +++++++++------------- internal/biz/router.go | 6 +++- internal/data/constants/file.go | 4 +++ internal/entitys/recognize.go | 9 +++--- internal/pkg/util/point.go | 6 ++++ internal/pkg/utils_vllm/client.go | 30 +++++++++++++++++++ 9 files changed, 111 insertions(+), 30 deletions(-) create mode 100644 internal/pkg/util/point.go diff --git a/config/config_env.yaml b/config/config_env.yaml index 188e5f9..23a123b 100644 --- a/config/config_env.yaml +++ b/config/config_env.yaml @@ -14,7 +14,7 @@ ollama: vllm: base_url: "http://117.175.169.61:16001/v1" - vl_model: "models/Qwen2.5-VL-3B-Instruct-AWQ" + vl_model: "Qwen2.5-VL-3B-Instruct-AWQ" timeout: "120s" level: "info" diff --git a/config/config_test.yaml b/config/config_test.yaml index 5fa0dbd..ab91824 100644 --- a/config/config_test.yaml +++ b/config/config_test.yaml @@ -15,7 +15,7 @@ ollama: vllm: base_url: "http://host.docker.internal:8001/v1" - vl_model: "models/Qwen2.5-VL-3B-Instruct-AWQ" + vl_model: "Qwen2.5-VL-3B-Instruct-AWQ" timeout: "120s" level: "info" diff --git a/internal/biz/do/prompt.go b/internal/biz/do/prompt.go index ea82f84..cada763 100644 --- a/internal/biz/do/prompt.go +++ b/internal/biz/do/prompt.go @@ -2,8 +2,11 @@ package do import ( "ai_scheduler/internal/biz/handle" + "ai_scheduler/internal/config" + "ai_scheduler/internal/data/constants" "ai_scheduler/internal/entitys" "ai_scheduler/internal/pkg" + "ai_scheduler/internal/pkg/utils_vllm" "context" "strings" @@ -15,6 +18,7 @@ type PromptOption interface { } type WithSys struct { + Config *config.Config } func (f *WithSys) CreatePrompt(ctx context.Context, rec *entitys.Recognize) (mes []api.Message, err error) { @@ -43,7 +47,7 @@ func (f *WithSys) CreatePrompt(ctx context.Context, rec *entitys.Recognize) (mes func (f *WithSys) getUserContent(ctx context.Context, rec *entitys.Recognize) (content strings.Builder, err error) { var hasFile bool - if rec.UserContent.File != nil && len(rec.UserContent.File) > 0 { + if len(rec.UserContent.File) > 0 { hasFile = true } content.WriteString(rec.UserContent.Text) @@ -67,13 +71,51 @@ func (f *WithSys) getUserContent(ctx context.Context, rec *entitys.Recognize) (c content.WriteString("### 文件内容:\n") for _, file := range rec.UserContent.File { handle.HandleRecognizeFile(file) - } + // 文件识别 + switch file.FileType { + case constants.FileTypeImage: + entitys.ResLog(rec.Ch, "recognize_img_start", "图片识别中...") + var imageContent string + imageContent, err = f.recognizeWithImgVllm(ctx, file) + if err != nil { + return + } + entitys.ResLog(rec.Ch, "recognize_img_end", "图片识别完成,识别内容:"+imageContent) - //...do something with file + // 解析结果回写到file + file.FileRec = imageContent + default: + content.WriteString(file.FileRec) + } + } } return } +func (f *WithSys) recognizeWithImgVllm(ctx context.Context, file *entitys.RecognizeFile) (content string, err error) { + if file.FileData == nil || file.FileType != constants.FileTypeImage { + return + } + + client, cleanup, err := utils_vllm.NewClient(f.Config) + if err != nil { + return "", err + } + defer cleanup() + + outMsg, err := client.RecognizeWithImgBytes(ctx, + f.Config.DefaultPrompt.ImgRecognize.SystemPrompt, + f.Config.DefaultPrompt.ImgRecognize.UserPrompt, + file.FileData, + file.FileRealMime, + ) + if err != nil { + return "", err + } + + return outMsg.Content, nil +} + type WithDingTalkBot struct { } diff --git a/internal/biz/handle/file.go b/internal/biz/handle/file.go index 882ab86..e93fbc3 100644 --- a/internal/biz/handle/file.go +++ b/internal/biz/handle/file.go @@ -65,17 +65,14 @@ func HandleRecognizeFile(files *entitys.RecognizeFile) { // 分支3:仅有数据、无类型→内容检测并填充 if len(files.FileData) > 0 && len(strings.TrimSpace(files.FileType.String())) == 0 { if len(files.FileData) > maxSize { - files.FileType = constants.Caller(constants.FileTypeUnknown) + files.FileType = constants.FileTypeUnknown return } reader := bytes.NewReader(files.FileData) - detected := detectFileType(reader, "") - if detected == constants.FileTypeUnknown { - files.FileType = constants.Caller(constants.FileTypeUnknown) - return - } - files.FileType = constants.Caller(detected) + detected, fileRealMime := detectFileType(reader, "") + files.FileType = detected + files.FileRealMime = fileRealMime return } @@ -83,18 +80,19 @@ func HandleRecognizeFile(files *entitys.RecognizeFile) { if len(files.FileUrl) > 0 { fileBytes, contentType, err := downloadFile(files.FileUrl) if err != nil || len(fileBytes) == 0 { - files.FileType = constants.Caller(constants.FileTypeUnknown) + files.FileType = constants.FileTypeUnknown return } if len(fileBytes) > maxSize { // 超限:不写入数据,类型置 unknown - files.FileType = constants.Caller(constants.FileTypeUnknown) + files.FileType = constants.FileTypeUnknown return } // 优先使用响应头的 Content-Type 映射 detected := mapToFileType(contentType) + fileRealMime := contentType if detected == constants.FileTypeUnknown { // 回退:内容检测 + URL 文件名扩展名辅助 @@ -103,17 +101,13 @@ func HandleRecognizeFile(files *entitys.RecognizeFile) { fname = filepath.Base(u.Path) } reader := bytes.NewReader(fileBytes) - detected = detectFileType(reader, fname) + detected, fileRealMime = detectFileType(reader, fname) } // 写入数据 files.FileData = fileBytes - - if detected == constants.FileTypeUnknown { - files.FileType = constants.Caller(constants.FileTypeUnknown) - return - } - files.FileType = constants.Caller(detected) + files.FileType = detected + files.FileRealMime = fileRealMime return } } @@ -150,7 +144,7 @@ func downloadFile(fileUrl string) (fileBytes []byte, contentType string, err err } // detectFileType 判断文件类型 -func detectFileType(file io.ReadSeeker, filename string) constants.FileType { +func detectFileType(file io.ReadSeeker, filename string) (constants.FileType, string) { // 1. 读取文件头检测 MIME buffer := make([]byte, 512) n, _ := file.Read(buffer) @@ -160,7 +154,7 @@ func detectFileType(file io.ReadSeeker, filename string) constants.FileType { for fileType, items := range constants.FileTypeMappings { for _, item := range items { if !strings.HasPrefix(item, ".") && item == detectedMIME { - return fileType + return fileType, detectedMIME } } } @@ -170,10 +164,10 @@ func detectFileType(file io.ReadSeeker, filename string) constants.FileType { for fileType, items := range constants.FileTypeMappings { for _, item := range items { if strings.HasPrefix(item, ".") && item == ext { - return fileType + return fileType, ext } } } - return constants.FileTypeUnknown + return constants.FileTypeUnknown, "" } diff --git a/internal/biz/router.go b/internal/biz/router.go index 28ce0b8..7d045f8 100644 --- a/internal/biz/router.go +++ b/internal/biz/router.go @@ -2,6 +2,7 @@ package biz import ( "ai_scheduler/internal/biz/do" + "ai_scheduler/internal/config" "ai_scheduler/internal/data/constants" errors "ai_scheduler/internal/data/error" "ai_scheduler/internal/gateway" @@ -21,16 +22,19 @@ import ( type AiRouterBiz struct { do *do.Do handle *do.Handle + config *config.Config } // NewAiRouterBiz 创建路由服务 func NewAiRouterBiz( do *do.Do, handle *do.Handle, + config *config.Config, ) *AiRouterBiz { return &AiRouterBiz{ do: do, handle: handle, + config: config, } } @@ -94,7 +98,7 @@ func (r *AiRouterBiz) SetRec(ctx context.Context, requireData *entitys.RequireDa // 对应不同的appKey, 配置不同的系统提示词 switch requireData.Sys.AppKey { default: - sys = &do.WithSys{} + sys = &do.WithSys{Config: r.config} } // 1. 系统提示词 diff --git a/internal/data/constants/file.go b/internal/data/constants/file.go index b825971..f234572 100644 --- a/internal/data/constants/file.go +++ b/internal/data/constants/file.go @@ -46,3 +46,7 @@ var FileTypeMappings = map[FileType][]string{ ".csv", }, } + +func (ft FileType) String() string { + return string(ft) +} diff --git a/internal/entitys/recognize.go b/internal/entitys/recognize.go index 831bef7..87f684c 100644 --- a/internal/entitys/recognize.go +++ b/internal/entitys/recognize.go @@ -41,8 +41,9 @@ type RecognizeUserContent struct { type FileData []byte type RecognizeFile struct { - FileRec string //文件识别内容 - FileData FileData // 文件数据(二进制格式) - FileType constants.Caller // 文件类型(文件类型,能填最好填,可以跳过一层判断) - FileUrl string // 文件下载链接 + FileRec string //文件识别内容 + FileData FileData // 文件数据(二进制格式) + FileType constants.FileType // 文件类型(文件类型,能填最好填,可以跳过一层判断) + FileRealMime string // 文件真实MIME类型 + FileUrl string // 文件下载链接 } diff --git a/internal/pkg/util/point.go b/internal/pkg/util/point.go new file mode 100644 index 0000000..beceddf --- /dev/null +++ b/internal/pkg/util/point.go @@ -0,0 +1,6 @@ +package util + +// AnyToPoint converts any value to a pointer. +func AnyToPoint[T any](v T) *T { + return &v +} diff --git a/internal/pkg/utils_vllm/client.go b/internal/pkg/utils_vllm/client.go index c333350..c8c4aec 100644 --- a/internal/pkg/utils_vllm/client.go +++ b/internal/pkg/utils_vllm/client.go @@ -2,7 +2,9 @@ package utils_vllm import ( "ai_scheduler/internal/config" + "ai_scheduler/internal/pkg/util" "context" + "encoding/base64" "github.com/cloudwego/eino-ext/components/model/openai" "github.com/cloudwego/eino/schema" @@ -58,3 +60,31 @@ func (c *Client) RecognizeWithImg(ctx context.Context, systemPrompt, userPrompt in[1].UserInputMultiContent = parts return c.model.Generate(ctx, in) } + +// 识别图片by二进制文件 +func (c *Client) RecognizeWithImgBytes(ctx context.Context, systemPrompt, userPrompt string, imgBytes []byte, imgType string) (*schema.Message, error) { + in := []*schema.Message{ + { + Role: schema.System, + Content: systemPrompt, + }, + { + Role: schema.User, + }, + } + parts := []schema.MessageInputPart{ + {Type: schema.ChatMessagePartTypeText, Text: userPrompt}, + } + parts = append(parts, schema.MessageInputPart{ + Type: schema.ChatMessagePartTypeImageURL, + Image: &schema.MessageInputImage{ + MessagePartCommon: schema.MessagePartCommon{ + MIMEType: imgType, + Base64Data: util.AnyToPoint(base64.StdEncoding.EncodeToString(imgBytes)), + }, + }, + }) + + in[1].UserInputMultiContent = parts + return c.model.Generate(ctx, in) +} From c2906ad9265deb53dbec5bb3cd88a915549e8935 Mon Sep 17 00:00:00 2001 From: renzhiyuan <465386466@qq.com> Date: Thu, 18 Dec 2025 18:15:44 +0800 Subject: [PATCH 48/66] feat: enhance dingtalk card send and auth handling --- go.mod | 5 +- go.sum | 6 +- internal/biz/ding_talk_bot.go | 277 +++++++++++++---- internal/biz/handle/dingtalk/auth.go | 71 ++++- internal/biz/handle/dingtalk/option.go | 25 +- internal/biz/handle/dingtalk/provider_set.go | 1 + internal/biz/handle/dingtalk/send_card.go | 287 ++++++++++++++++++ .../biz/handle/dingtalk/send_card.go.bak1 | 280 +++++++++++++++++ internal/biz/handle/dingtalk/types.go | 34 ++- internal/data/constants/bot.go | 2 + internal/data/constants/dingtalk.go | 29 ++ internal/data/impl/bot_group.go | 4 +- internal/data/model/ai_bot_config.gen.go | 8 +- internal/data/model/ai_bot_group.gen.go | 1 + internal/entitys/bot.go | 2 +- internal/pkg/func.go | 32 ++ internal/server/ding_talk_bot.go | 58 +++- internal/services/dtalk_bot.go | 146 ++++++--- internal/services/dtalk_bot.go.bak | 130 ++++++++ 19 files changed, 1240 insertions(+), 158 deletions(-) create mode 100644 internal/biz/handle/dingtalk/send_card.go create mode 100644 internal/biz/handle/dingtalk/send_card.go.bak1 create mode 100644 internal/services/dtalk_bot.go.bak diff --git a/go.mod b/go.mod index 521dbdb..76bed20 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,7 @@ module ai_scheduler go 1.24.7 require ( + gitea.cdlsxd.cn/self-tools/l-dingtalk-stream-sdk-go v0.9.3 gitea.cdlsxd.cn/self-tools/l_request v1.0.8 github.com/alibabacloud-go/darabonba-openapi/v2 v2.0.12 github.com/alibabacloud-go/dingtalk v1.6.96 @@ -11,6 +12,7 @@ require ( github.com/cloudwego/eino v0.7.7 github.com/cloudwego/eino-ext/components/model/ollama v0.1.6 github.com/cloudwego/eino-ext/components/model/openai v0.1.5 + github.com/coze-dev/coze-go v0.0.0-20251029161603-312b7fd62d20 github.com/emirpasic/gods v1.18.1 github.com/faabiosr/cachego v0.26.0 github.com/fastwego/dingding v1.0.0-beta.4 @@ -24,10 +26,10 @@ require ( github.com/google/uuid v1.6.0 github.com/google/wire v0.7.0 github.com/ollama/ollama v0.12.7 - github.com/open-dingtalk/dingtalk-stream-sdk-go v0.9.1 github.com/redis/go-redis/v9 v9.16.0 github.com/spf13/viper v1.17.0 github.com/tmc/langchaingo v0.1.13 + golang.org/x/sync v0.15.0 google.golang.org/grpc v1.64.0 gorm.io/driver/mysql v1.6.0 gorm.io/gorm v1.31.0 @@ -52,7 +54,6 @@ require ( github.com/clbanning/mxj/v2 v2.5.5 // indirect github.com/cloudwego/base64x v0.1.6 // indirect github.com/cloudwego/eino-ext/libs/acl/openai v0.1.2 // indirect - github.com/coze-dev/coze-go v0.0.0-20251029161603-312b7fd62d20 // indirect github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect github.com/dlclark/regexp2 v1.11.4 // indirect github.com/dustin/go-humanize v1.0.1 // indirect diff --git a/go.sum b/go.sum index 6a18537..a6e9a9c 100644 --- a/go.sum +++ b/go.sum @@ -38,6 +38,8 @@ cloud.google.com/go/storage v1.14.0/go.mod h1:GrKmX003DSIwi9o29oFT7YDnHYwZoctc3f dmitri.shuralyov.com/gpu/mtl v0.0.0-20190408044501-666a987793e9/go.mod h1:H6x//7gZCb22OMCxBHrMx7a5I7Hp++hsVxbQ4BYO7hU= filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA= filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4= +gitea.cdlsxd.cn/self-tools/l-dingtalk-stream-sdk-go v0.9.3 h1:qaSPxVz5kHCs2AWvShnOG8mUgrUP9Gc3uUB4ZX1BF5A= +gitea.cdlsxd.cn/self-tools/l-dingtalk-stream-sdk-go v0.9.3/go.mod h1:5mCPTjBxOk69LRJPHWJRNTkfxcffqlQSOBMD4M5JVnE= gitea.cdlsxd.cn/self-tools/l_request v1.0.8 h1:FaKRql9mCVcSoaGqPeBOAruZ52slzRngQ6VRTYKNSsA= gitea.cdlsxd.cn/self-tools/l_request v1.0.8/go.mod h1:Qf4hVXm2Eu5vOvwXk8D7U0q/aekMCkZ4Fg9wnRKlasQ= gitea.com/xorm/sqlfiddle v0.0.0-20180821085327-62ce714f951a h1:lSA0F4e9A2NcQSqGqTOXqu2aRi/XEQxDCBwM8yJtE6s= @@ -275,8 +277,6 @@ github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1/go.mod h1:wJfORR github.com/gopherjs/gopherjs v0.0.0-20200217142428-fce0ec30dd00/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY= github.com/gopherjs/gopherjs v1.17.2 h1:fQnZVsXk8uxXIStYb0N4bGk7jeyTalG/wsZjQ25dO0g= github.com/gopherjs/gopherjs v1.17.2/go.mod h1:pRRIvn/QzFLrKfvEz3qUuEhtE/zLCWfreZ6J5gM2i+k= -github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc= -github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg= github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/hashicorp/golang-lru v0.5.0/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= @@ -354,8 +354,6 @@ github.com/onsi/ginkgo v1.13.0/go.mod h1:+REjRxOmWfHCjfv9TTWB1jD1Frx4XydAD3zm1ls github.com/onsi/gomega v1.5.0/go.mod h1:ex+gbHU/CVuBBDIJjb2X0qEXbFg53c61hWP/1CpauHY= github.com/onsi/gomega v1.7.1/go.mod h1:XdKZgCCFLUoM/7CFJVPcG8C1xQ1AJ0vpAezJrB7JYyY= github.com/onsi/gomega v1.10.1/go.mod h1:iN09h71vgCQne3DLsj+A5owkum+a2tYe+TOCB1ybHNo= -github.com/open-dingtalk/dingtalk-stream-sdk-go v0.9.1 h1:Lb/Uzkiw2Ugt2Xf03J5wmv81PdkYOiWbI8CNBi1boC8= -github.com/open-dingtalk/dingtalk-stream-sdk-go v0.9.1/go.mod h1:ln3IqPYYocZbYvl9TAOrG/cxGR9xcn4pnZRLdCTEGEU= github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM= github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs= github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= diff --git a/internal/biz/ding_talk_bot.go b/internal/biz/ding_talk_bot.go index d799796..b8976df 100644 --- a/internal/biz/ding_talk_bot.go +++ b/internal/biz/ding_talk_bot.go @@ -8,9 +8,15 @@ import ( "ai_scheduler/internal/data/impl" "ai_scheduler/internal/data/model" "ai_scheduler/internal/entitys" + "ai_scheduler/internal/pkg/l_request" "ai_scheduler/internal/tools" + "ai_scheduler/tmpl/dataTemp" + "io" + "net/http" "strconv" + "time" + "ai_scheduler/internal/config" "context" "database/sql" "encoding/json" @@ -18,9 +24,9 @@ import ( "fmt" "strings" + "gitea.cdlsxd.cn/self-tools/l-dingtalk-stream-sdk-go/chatbot" + "github.com/coze-dev/coze-go" "github.com/gofiber/fiber/v2/log" - "github.com/open-dingtalk/dingtalk-stream-sdk-go/chatbot" - "xorm.io/builder" ) @@ -36,6 +42,8 @@ type DingTalkBotBiz struct { botGroupImpl *impl.BotGroupImpl toolManager *tools.Manager chatHis *impl.BotChatHisImpl + conf *config.Config + cardSend *dingtalk.SendCardClient } // NewDingTalkBotBiz @@ -48,6 +56,8 @@ func NewDingTalkBotBiz( tools *tools_regis.ToolRegis, chatHis *impl.BotChatHisImpl, toolManager *tools.Manager, + conf *config.Config, + cardSend *dingtalk.SendCardClient, ) *DingTalkBotBiz { return &DingTalkBotBiz{ do: do, @@ -59,6 +69,8 @@ func NewDingTalkBotBiz( botGroupImpl: botGroupImpl, toolManager: toolManager, chatHis: chatHis, + conf: conf, + cardSend: cardSend, } } @@ -75,7 +87,7 @@ func (d *DingTalkBotBiz) GetDingTalkBotCfgList() (dingBotList []entitys.DingTalk if err != nil { d.log.Info("初始化“%s”失败:%s", v.BotName, err.Error()) } - config.BotIndex = v.BotIndex + config.BotIndex = v.RobotCode dingBotList = append(dingBotList, config) } return @@ -92,8 +104,8 @@ func (d *DingTalkBotBiz) InitRequire(ctx context.Context, data *chatbot.BotCallb } func (d *DingTalkBotBiz) Do(ctx context.Context, requireData *entitys.RequireDataDingTalkBot) (err error) { - entitys.ResLoading(requireData.Ch, "", "收到消息,正在处理中,请稍等") - defer close(requireData.Ch) + //entitys.ResLoading(requireData.Ch, "", "收到消息,正在处理中,请稍等") + //defer close(requireData.Ch) switch constants.ConversationType(requireData.Req.ConversationType) { case constants.ConversationTypeSingle: err = d.handleSingleChat(ctx, requireData) @@ -124,7 +136,7 @@ func (d *DingTalkBotBiz) handleSingleChat(ctx context.Context, requireData *enti } func (d *DingTalkBotBiz) handleGroupChat(ctx context.Context, requireData *entitys.RequireDataDingTalkBot) (err error) { - group, err := d.initGroup(ctx, requireData.Req.ConversationId, requireData.Req.ConversationTitle) + group, err := d.initGroup(ctx, requireData.Req.ConversationId, requireData.Req.ConversationTitle, requireData.Req.RobotCode) if err != nil { return @@ -142,8 +154,8 @@ func (d *DingTalkBotBiz) handleGroupChat(ctx context.Context, requireData *entit return d.handleMatch(ctx, rec) } -func (d *DingTalkBotBiz) initGroup(ctx context.Context, conversationId string, conversationTitle string) (group *model.AiBotGroup, err error) { - group, err = d.botGroupImpl.GetByConversationId(conversationId) +func (d *DingTalkBotBiz) initGroup(ctx context.Context, conversationId string, conversationTitle string, robotCode string) (group *model.AiBotGroup, err error) { + group, err = d.botGroupImpl.GetByConversationIdAndRobotCode(conversationId, robotCode) if err != nil { if !errors.Is(err, sql.ErrNoRows) { @@ -155,10 +167,11 @@ func (d *DingTalkBotBiz) initGroup(ctx context.Context, conversationId string, c group = &model.AiBotGroup{ ConversationID: conversationId, Title: conversationTitle, + RobotCode: robotCode, ToolList: "", } //如果不存在则创建 - d.botGroupImpl.Add(group) + _, err = d.botGroupImpl.Add(group) } return } @@ -199,13 +212,19 @@ func (d *DingTalkBotBiz) recognize(ctx context.Context, requireData *entitys.Req userContent, err := d.getUserContent(requireData.Req.Msgtype, requireData.Req.Text.Content) if err != nil { - return nil, err + return } rec = &entitys.Recognize{ Ch: requireData.Ch, SystemPrompt: d.defaultPrompt(), UserContent: userContent, } + //历史记录 + rec.ChatHis, err = d.getHis(ctx, constants.ConversationType(requireData.Req.ConversationType), requireData.ID) + if err != nil { + return + } + //工具注册 if len(tools) > 0 { rec.Tasks = make([]entitys.RegistrationTask, 0, len(tools)) for _, task := range tools { @@ -226,6 +245,36 @@ func (d *DingTalkBotBiz) recognize(ctx context.Context, requireData *entitys.Req return } +func (d *DingTalkBotBiz) getHis(ctx context.Context, conversationType constants.ConversationType, Id int32) (content entitys.ChatHis, err error) { + + var ( + his []model.AiBotChatHi + ) + cond := builder.NewCond() + cond = cond.And(builder.Eq{"his_type": conversationType}) + cond = cond.And(builder.Eq{"id": Id}) + _, err = d.chatHis.GetListToStruct(&cond, &dataTemp.ReqPageBo{Limit: d.conf.Sys.SessionLen}, &his, "his_id desc") + if err != nil { + return + } + messages := make([]entitys.HisMessage, 0, len(his)) + for _, v := range his { + messages = append(messages, entitys.HisMessage{ + Role: constants.Caller(v.Role), // 用户角色 + Content: v.Content, // 用户输入内容 + Timestamp: v.CreateAt.Format(time.DateTime), + }) + } + return entitys.ChatHis{ + SessionId: fmt.Sprintf("%s_%d", conversationType, Id), + Messages: messages, + Context: entitys.HisContext{ + UserLanguage: constants.LangZhCN, // 默认中文 + SystemMode: constants.SystemModeTechnicalSupport, // 默认技术支持模式 + }, + }, nil +} + func (d *DingTalkBotBiz) getUserContent(msgType string, msgContent interface{}) (content *entitys.RecognizeUserContent, err error) { switch constants.BotMsgType(msgType) { case constants.BotMsgTypeText: @@ -261,16 +310,116 @@ func (d *DingTalkBotBiz) handleMatch(ctx context.Context, rec *entitys.Recognize return d.otherTask(ctx, rec) } switch constants.TaskType(pointTask.Type) { - //case constants.TaskTypeApi: - //return d.handleApiTask(ctx, requireData, pointTask) case constants.TaskTypeFunc: return d.handleTask(ctx, rec, pointTask) + case constants.TaskTypeCozeWorkflow: + return d.handleCozeWorkflow(ctx, rec, pointTask) default: return d.otherTask(ctx, rec) } return } +func (d *DingTalkBotBiz) handleCozeWorkflow(ctx context.Context, rec *entitys.Recognize, task *model.AiBotTool) (err error) { + entitys.ResLoading(rec.Ch, task.Index, "正在执行工作流(coze)\n") + + customClient := &http.Client{ + Timeout: time.Minute * 30, + } + + authCli := coze.NewTokenAuth(d.conf.Coze.ApiSecret) + cozeCli := coze.NewCozeAPI( + authCli, + coze.WithBaseURL(d.conf.Coze.BaseURL), + coze.WithHttpClient(customClient), + ) + + // 从参数中获取workflowID + type requestParams struct { + Request l_request.Request `json:"request"` + } + var config requestParams + err = json.Unmarshal([]byte(task.Config), &config) + if err != nil { + return err + } + workflowId, ok := config.Request.Json["workflow_id"].(string) + if !ok { + return fmt.Errorf("workflow_id不能为空") + } + // 提取参数 + var data map[string]interface{} + err = json.Unmarshal([]byte(rec.Match.Parameters), &data) + + req := &coze.RunWorkflowsReq{ + WorkflowID: workflowId, + Parameters: data, + // IsAsync: true, + } + + stream := config.Request.Json["stream"].(bool) + + entitys.ResLog(rec.Ch, task.Index, "工作流执行中...") + + if stream { + streamResp, err := cozeCli.Workflows.Runs.Stream(ctx, req) + if err != nil { + return err + } + + handleCozeWorkflowEvents(ctx, streamResp, cozeCli, workflowId, rec.Ch, task.Index) + } else { + resp, err := cozeCli.Workflows.Runs.Create(ctx, req) + if err != nil { + return err + } + + entitys.ResJson(rec.Ch, task.Index, resp.Data) + } + + return +} + +// handleCozeWorkflowEvents 处理 coze 工作流事件 +func handleCozeWorkflowEvents(ctx context.Context, resp coze.Stream[coze.WorkflowEvent], cozeCli coze.CozeAPI, workflowID string, ch chan entitys.Response, index string) { + defer resp.Close() + for { + event, err := resp.Recv() + if errors.Is(err, io.EOF) { + fmt.Println("Stream finished") + break + } + if err != nil { + fmt.Println("Error receiving event:", err) + break + } + + switch event.Event { + case coze.WorkflowEventTypeMessage: + entitys.ResStream(ch, index, event.Message.Content) + case coze.WorkflowEventTypeError: + entitys.ResError(ch, index, fmt.Sprintf("工作流执行错误: %s", event.Error)) + case coze.WorkflowEventTypeDone: + entitys.ResEnd(ch, index, "工作流执行完成") + case coze.WorkflowEventTypeInterrupt: + resumeReq := &coze.ResumeRunWorkflowsReq{ + WorkflowID: workflowID, + EventID: event.Interrupt.InterruptData.EventID, + ResumeData: "your data", + InterruptType: event.Interrupt.InterruptData.Type, + } + newResp, err := cozeCli.Workflows.Runs.Resume(ctx, resumeReq) + if err != nil { + entitys.ResError(ch, index, fmt.Sprintf("工作流恢复执行错误: %s", err.Error())) + return + } + entitys.ResLog(ch, index, "工作流恢复执行中...") + handleCozeWorkflowEvents(ctx, newResp, cozeCli, workflowID, ch, index) + } + } + fmt.Printf("done, log:%s\n", resp.Response().LogID()) +} + func (d *DingTalkBotBiz) handleTask(ctx context.Context, rec *entitys.Recognize, task *model.AiBotTool) (err error) { var configData entitys.ConfigDataTool err = json.Unmarshal([]byte(task.Config), &configData) @@ -290,56 +439,42 @@ func (d *DingTalkBotBiz) otherTask(ctx context.Context, rec *entitys.Recognize) entitys.ResText(rec.Ch, "", rec.Match.Reasoning) return } -func (d *DingTalkBotBiz) HandleRes(ctx context.Context, data *chatbot.BotCallbackDataModel, resp entitys.Response) error { - switch resp.Type { - case entitys.ResponseText: - return d.replyText(ctx, data.SessionWebhook, resp.Content) - case entitys.ResponseStream: - return d.replySteam(ctx, data.SessionWebhook, resp.Content) - case entitys.ResponseImg: - return d.replyImg(ctx, data.SessionWebhook, resp.Content) - case entitys.ResponseFile: - return d.replyFile(ctx, data.SessionWebhook, resp.Content) - case entitys.ResponseMarkdown: - return d.replyMarkdown(ctx, data.SessionWebhook, resp.Content) - case entitys.ResponseActionCard: - return d.replyActionCard(ctx, data.SessionWebhook, resp.Content) - default: - return nil - } + +//func (d *DingTalkBotBiz) HandleRes(ctx context.Context, data *chatbot.BotCallbackDataModel, resp entitys.Response, ch chan string) error { +// switch resp.Type { +// case entitys.ResponseText: +// return d.replyText(ctx, data.SessionWebhook, resp.Content) +// case entitys.ResponseStream: +// +// return d.replySteam(ctx, data, ch) +// case entitys.ResponseImg: +// return d.replyImg(ctx, data.SessionWebhook, resp.Content) +// case entitys.ResponseFile: +// return d.replyFile(ctx, data.SessionWebhook, resp.Content) +// case entitys.ResponseMarkdown: +// return d.replyMarkdown(ctx, data.SessionWebhook, resp.Content) +// case entitys.ResponseActionCard: +// return d.replyActionCard(ctx, data.SessionWebhook, resp.Content) +// default: +// return nil +// } +//} + +func (d *DingTalkBotBiz) HandleStreamRes(ctx context.Context, data *chatbot.BotCallbackDataModel, content chan string) (err error) { + err = d.cardSend.NewCard(ctx, &dingtalk.CardSend{ + RobotCode: data.RobotCode, + ConversationType: constants.ConversationType(data.ConversationType), + Template: constants.CardTempDefault, + ContentChannel: content, // 指定内容通道 + ConversationId: data.ConversationId, + SenderStaffId: data.SenderStaffId, + Title: data.Text.Content, + }) + + return } -func (d *DingTalkBotBiz) SaveHis(ctx context.Context, requireData *entitys.RequireDataDingTalkBot, chat []string) (err error) { - if len(chat) == 0 { - return - } - his := []*model.AiBotChatHi{ - { - HisType: requireData.Req.ConversationType, - ID: requireData.ID, - Role: "user", - Content: requireData.Req.Text.Content, - }, - { - HisType: requireData.Req.ConversationType, - ID: requireData.ID, - Role: "system", - Content: strings.Join(chat, "\n"), - }, - } - _, err = d.chatHis.Add(his) - return err -} - -func (d *DingTalkBotBiz) replySteam(ctx context.Context, SessionWebhook string, content string, arg ...string) error { - msg := content - if len(arg) > 0 { - msg = fmt.Sprintf(content, arg) - } - return d.replier.SimpleReplyText(ctx, SessionWebhook, []byte(msg)) -} - -func (d *DingTalkBotBiz) replyText(ctx context.Context, SessionWebhook string, content string, arg ...string) error { +func (d *DingTalkBotBiz) ReplyText(ctx context.Context, SessionWebhook string, content string, arg ...string) error { msg := content if len(arg) > 0 { msg = fmt.Sprintf(content, arg) @@ -379,6 +514,28 @@ func (d *DingTalkBotBiz) replyActionCard(ctx context.Context, SessionWebhook str return d.replier.SimpleReplyText(ctx, SessionWebhook, []byte(msg)) } +func (d *DingTalkBotBiz) SaveHis(ctx context.Context, requireData *entitys.RequireDataDingTalkBot, chat []string) (err error) { + if len(chat) == 0 { + return + } + his := []*model.AiBotChatHi{ + { + HisType: requireData.Req.ConversationType, + ID: requireData.ID, + Role: "user", + Content: requireData.Req.Text.Content, + }, + { + HisType: requireData.Req.ConversationType, + ID: requireData.ID, + Role: "system", + Content: strings.Join(chat, "\n"), + }, + } + _, err = d.chatHis.Add(his) + return err +} + func (d *DingTalkBotBiz) defaultPrompt() string { return `[system] 你是一个智能路由系统,核心职责是 **精准解析用户意图并路由至对应任务模块**。请严格遵循以下规则: diff --git a/internal/biz/handle/dingtalk/auth.go b/internal/biz/handle/dingtalk/auth.go index 2b359a9..143ab45 100644 --- a/internal/biz/handle/dingtalk/auth.go +++ b/internal/biz/handle/dingtalk/auth.go @@ -6,6 +6,7 @@ import ( "ai_scheduler/internal/data/impl" "ai_scheduler/internal/data/model" "ai_scheduler/internal/entitys" + "ai_scheduler/internal/pkg" "ai_scheduler/internal/pkg/l_request" "ai_scheduler/utils" "context" @@ -38,21 +39,26 @@ func (a *Auth) GetAccessToken(ctx context.Context, clientId string, clientSecret return nil, errors.New("clientId is empty") } accessToken := a.redis.Get(ctx, a.getKey(clientId)).Val() + var expire time.Duration if accessToken == "" { dingTalkAuthRes, _err := a.getNewAccessToken(ctx, clientId, clientSecret) if _err != nil { return nil, _err } - err = a.redis.SetEx(ctx, a.getKey(clientId), dingTalkAuthRes.AccessToken, time.Duration(dingTalkAuthRes.ExpireIn-3600)*time.Second).Err() + expire = time.Duration(dingTalkAuthRes.ExpireIn-3600) * time.Second + err = a.redis.SetEx(ctx, a.getKey(clientId), dingTalkAuthRes.AccessToken, expire).Err() if err != nil { return } accessToken = dingTalkAuthRes.AccessToken + } else { + expire, _ = a.redis.TTL(ctx, a.getKey(clientId)).Result() } return &AuthInfo{ ClientId: clientId, ClientSecret: clientSecret, AccessToken: accessToken, + Expire: expire, }, nil } @@ -60,6 +66,10 @@ func (a *Auth) getKey(clientId string) string { return a.cfg.Redis.Key + ":" + constants.DingTalkAuthBaseKeyPrefix + ":" + clientId } +func (a *Auth) getKeyBot(botCode string) string { + return a.cfg.Redis.Key + ":" + constants.DingTalkAuthBaseKeyBotPrefix + ":" + botCode +} + func (a *Auth) getNewAccessToken(ctx context.Context, clientId string, clientSecret string) (auth DingTalkAuthIRes, err error) { if clientId == "" || clientSecret == "" { err = errors.New("clientId or clientSecret is empty") @@ -89,30 +99,61 @@ func (a *Auth) GetTokenFromBotOption(ctx context.Context, botOption ...BotOption option(botInfo) } - if botInfo.id == 0 && botInfo.botConfig == nil { + if botInfo.Id == 0 && botInfo.BotConfig == nil && botInfo.BotCode == "" { err = errors.New("botInfo is nil") return } - if botInfo.botConfig == nil { - var botConfigDo model.AiBotConfig - cond := builder.NewCond() - cond = cond.And(builder.Eq{"bot_id": botInfo.id}) - err = a.botConfigImpl.GetOneBySearchToStrut(&cond, &botConfigDo) + + if botInfo.BotConfig == nil { + err = a.GetBotConfigFromModel(botInfo) if err != nil { return } - if botConfigDo.BotID == 0 { - err = errors.New("未找到机器人服务配置") + } + + authInfo := a.redis.Get(ctx, a.getKeyBot(botInfo.BotConfig.RobotCode)).Val() + if authInfo == "" { + var botConfig entitys.DingTalkBot + err = json.Unmarshal([]byte(botInfo.BotConfig.BotConfig), &botConfig) + if err != nil { + log.Infof("初始化“%s”失败:%s", botInfo.BotConfig.BotName, err.Error()) return } - botInfo.botConfig = &botConfigDo + token, err = a.GetAccessToken(ctx, botConfig.ClientId, botConfig.ClientSecret) + if err != nil { + return + } + err = a.redis.SetEx(ctx, a.getKeyBot(botInfo.BotConfig.RobotCode), pkg.JsonStringIgonErr(token), token.Expire).Err() + if err != nil { + return + } + } else { + var tokenData AuthInfo + err = json.Unmarshal([]byte(authInfo), &tokenData) + token = &tokenData } - var botConfig entitys.DingTalkBot - err = json.Unmarshal([]byte(botInfo.botConfig.BotConfig), &botConfig) + return +} + +func (a *Auth) GetBotConfigFromModel(botInfo *Bot) (err error) { + var ( + botConfigDo model.AiBotConfig + ) + cond := builder.NewCond() + if botInfo.Id > 0 { + cond = cond.And(builder.Eq{"bot_id": botInfo.Id}) + } + if botInfo.BotCode != "" { + cond = cond.And(builder.Eq{"robot_code": botInfo.BotCode}) + } + err = a.botConfigImpl.GetOneBySearchToStrut(&cond, &botConfigDo) if err != nil { - log.Infof("初始化“%s”失败:%s", botInfo.botConfig.BotName, err.Error()) return } - return a.GetAccessToken(ctx, botConfig.ClientId, botConfig.ClientSecret) - + if botConfigDo.BotID == 0 { + err = errors.New("未找到机器人服务配置") + return + } + botInfo.BotConfig = &botConfigDo + return nil } diff --git a/internal/biz/handle/dingtalk/option.go b/internal/biz/handle/dingtalk/option.go index fb473c7..3b72795 100644 --- a/internal/biz/handle/dingtalk/option.go +++ b/internal/biz/handle/dingtalk/option.go @@ -3,19 +3,34 @@ package dingtalk import "ai_scheduler/internal/data/model" type Bot struct { - id int - botConfig *model.AiBotConfig + Id int + BotCode string + BotConfig *model.AiBotConfig } type BotOption func(*Bot) func WithId(id int) BotOption { return func(b *Bot) { - b.id = id + b.Id = id } } -func WithBootConfig(BotConfig *model.AiBotConfig) BotOption { +func WithBotConfig(BotConfig *model.AiBotConfig) BotOption { return func(bot *Bot) { - bot.botConfig = BotConfig + bot.BotConfig = BotConfig + } +} + +func WithBotCode(BotCode string) BotOption { + return func(bot *Bot) { + bot.BotCode = BotCode + } +} + +func WithBot(botSelf *Bot) BotOption { + return func(bot *Bot) { + bot.BotCode = botSelf.BotCode + bot.Id = botSelf.Id + bot.BotConfig = botSelf.BotConfig } } diff --git a/internal/biz/handle/dingtalk/provider_set.go b/internal/biz/handle/dingtalk/provider_set.go index 579d464..70f31ff 100644 --- a/internal/biz/handle/dingtalk/provider_set.go +++ b/internal/biz/handle/dingtalk/provider_set.go @@ -8,4 +8,5 @@ var ProviderSetDingTalk = wire.NewSet( NewUser, NewAuth, NewDept, + NewSendCardClient, ) diff --git a/internal/biz/handle/dingtalk/send_card.go b/internal/biz/handle/dingtalk/send_card.go new file mode 100644 index 0000000..c2063da --- /dev/null +++ b/internal/biz/handle/dingtalk/send_card.go @@ -0,0 +1,287 @@ +package dingtalk + +import ( + "ai_scheduler/internal/data/constants" + "ai_scheduler/internal/pkg" + "context" + "encoding/json" + "errors" + "fmt" + "strings" + "sync" + "time" + + openapi "github.com/alibabacloud-go/darabonba-openapi/v2/client" + dingtalkim_1_0 "github.com/alibabacloud-go/dingtalk/im_1_0" + util "github.com/alibabacloud-go/tea-utils/v2/service" + "github.com/alibabacloud-go/tea/tea" + "github.com/gofiber/fiber/v2/log" + "github.com/google/uuid" +) + +const DefaultInterval = 100 * time.Millisecond +const HeardBeatX = 50 + +type SendCardClient struct { + Auth *Auth + CardClient *sync.Map + mu sync.RWMutex // 保护 CardClient 的并发访问 + logger log.AllLogger // 日志记录 + botOption *Bot +} + +func NewSendCardClient(auth *Auth, logger log.AllLogger) *SendCardClient { + return &SendCardClient{ + Auth: auth, + CardClient: &sync.Map{}, + logger: logger, + botOption: &Bot{}, + } +} + +// initClient 初始化或复用 DingTalk 客户端 +func (s *SendCardClient) initClient(robotCode string) (*dingtalkim_1_0.Client, error) { + if client, ok := s.CardClient.Load(robotCode); ok { + return client.(*dingtalkim_1_0.Client), nil + } + s.botOption.BotCode = robotCode + config := &openapi.Config{ + Protocol: tea.String("https"), + RegionId: tea.String("central"), + } + client, err := dingtalkim_1_0.NewClient(config) + if err != nil { + s.logger.Error("failed to init DingTalk client") + return nil, fmt.Errorf("init client failed: %w", err) + } + + s.CardClient.Store(robotCode, client) + return client, nil +} + +func (s *SendCardClient) NewCard(ctx context.Context, cardSend *CardSend) error { + // 参数校验 + if (len(cardSend.ContentSlice) == 0 || cardSend.ContentSlice == nil) && cardSend.ContentChannel == nil { + return errors.New("卡片内容不能为空") + } + if cardSend.UpdateInterval == 0 { + cardSend.UpdateInterval = DefaultInterval // 默认更新间隔 + } + if cardSend.Title == "" { + cardSend.Title = "钉钉卡片" + } + //替换标题 + replace, err := pkg.SafeReplace(string(cardSend.Template), "${title}", cardSend.Title) + if err != nil { + return err + } + cardSend.Template = constants.CardTemp(replace) + // 初始化客户端 + client, err := s.initClient(cardSend.RobotCode) + if err != nil { + return fmt.Errorf("初始化client失败: %w", err) + } + + // 生成卡片实例ID + cardInstanceId, err := uuid.NewUUID() + if err != nil { + return fmt.Errorf("创建uuid失败: %w", err) + } + + // 构建初始请求 + request, err := s.buildBaseRequest(cardSend, cardInstanceId.String()) + if err != nil { + return fmt.Errorf("请求失败: %w", err) + } + + // 发送初始卡片 + if _, err := s.SendInteractiveCard(ctx, request, cardSend.RobotCode, client); err != nil { + return fmt.Errorf("发送初始卡片失败: %w", err) + } + + // 处理切片内容(同步) + if len(cardSend.ContentSlice) > 0 { + if err := s.processContentSlice(ctx, cardSend, cardInstanceId.String(), client); err != nil { + return fmt.Errorf("内容同步失败: %w", err) + } + } + + // 处理通道内容(异步) + if cardSend.ContentChannel != nil { + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + s.processContentChannel(ctx, cardSend, cardInstanceId.String(), client) + }() + wg.Wait() + } + + return nil +} + +// buildBaseRequest 构建基础请求 +func (s *SendCardClient) buildBaseRequest(cardSend *CardSend, cardInstanceId string) (*dingtalkim_1_0.SendRobotInteractiveCardRequest, error) { + cardData := fmt.Sprintf(string(cardSend.Template), "") // 初始空内容 + request := &dingtalkim_1_0.SendRobotInteractiveCardRequest{ + CardTemplateId: tea.String("StandardCard"), + CardBizId: tea.String(cardInstanceId), + CardData: tea.String(cardData), + RobotCode: tea.String(cardSend.RobotCode), + SendOptions: &dingtalkim_1_0.SendRobotInteractiveCardRequestSendOptions{}, + PullStrategy: tea.Bool(false), + } + + switch cardSend.ConversationType { + case constants.ConversationTypeGroup: + request.SetOpenConversationId(cardSend.ConversationId) + case constants.ConversationTypeSingle: + receiver, err := json.Marshal(map[string]string{"userId": cardSend.SenderStaffId}) + if err != nil { + return nil, fmt.Errorf("数据整理失败: %w", err) + } + request.SetSingleChatReceiver(string(receiver)) + default: + return nil, errors.New("未知的聊天场景") + } + + return request, nil +} + +// processContentChannel 处理通道内容(异步更新) +func (s *SendCardClient) processContentChannel(ctx context.Context, cardSend *CardSend, cardInstanceId string, client *dingtalkim_1_0.Client) { + defer func() { + if r := recover(); r != nil { + s.logger.Error("panic in processContentChannel") + } + }() + + ticker := time.NewTicker(cardSend.UpdateInterval) + defer ticker.Stop() + heartbeatTicker := time.NewTicker(time.Duration(HeardBeatX) * DefaultInterval) + defer heartbeatTicker.Stop() + + var ( + contentBuilder strings.Builder + lastUpdate time.Time + ) + for { + + select { + case content, ok := <-cardSend.ContentChannel: + if !ok { + // 通道关闭,发送最终内容 + if contentBuilder.Len() > 0 { + if err := s.updateCardContent(ctx, cardSend, cardInstanceId, contentBuilder.String(), client); err != nil { + s.logger.Errorf("更新卡片失败1:%s", err.Error()) + } + } + return + } + contentBuilder.WriteString(content) + if contentBuilder.Len() > 0 { + if err := s.updateCardContent(ctx, cardSend, cardInstanceId, contentBuilder.String(), client); err != nil { + s.logger.Errorf("更新卡片失败2:%s", err.Error()) + } + } + lastUpdate = time.Now() + + case <-heartbeatTicker.C: + if time.Now().Unix()-lastUpdate.Unix() >= HeardBeatX { + return + } + + case <-ctx.Done(): + s.logger.Info("context canceled, stop channel processing") + return + } + } + +} + +// processContentSlice 处理切片内容(同步更新) +func (s *SendCardClient) processContentSlice(ctx context.Context, cardSend *CardSend, cardInstanceId string, client *dingtalkim_1_0.Client) error { + var contentBuilder strings.Builder + for _, content := range cardSend.ContentSlice { + + contentBuilder.WriteString(content) + err := s.updateCardRequest(ctx, &UpdateCardRequest{ + Template: string(cardSend.Template), + Content: contentBuilder.String(), + Client: client, + RobotCode: cardSend.RobotCode, + CardInstanceId: cardInstanceId, + }) + if err != nil { + return fmt.Errorf("更新卡片失败: %w", err) + } + time.Sleep(cardSend.UpdateInterval) // 控制更新频率 + } + return nil +} + +// updateCardContent 封装卡片更新逻辑 +func (s *SendCardClient) updateCardContent(ctx context.Context, cardSend *CardSend, cardInstanceId, content string, client *dingtalkim_1_0.Client) error { + err := s.updateCardRequest(ctx, &UpdateCardRequest{ + Template: string(cardSend.Template), + Content: content, + Client: client, + RobotCode: cardSend.RobotCode, + CardInstanceId: cardInstanceId, + }) + + return err +} + +func (s *SendCardClient) updateCardRequest(ctx context.Context, updateCardRequest *UpdateCardRequest) error { + content, err := pkg.SafeReplace(updateCardRequest.Template, "%s", updateCardRequest.Content) + if err != nil { + return err + } + updateRequest := &dingtalkim_1_0.UpdateRobotInteractiveCardRequest{ + CardBizId: tea.String(updateCardRequest.CardInstanceId), + CardData: tea.String(content), + } + _, err = s.UpdateInteractiveCard(ctx, updateRequest, updateCardRequest.RobotCode, updateCardRequest.Client) + return err +} + +// UpdateInteractiveCard 更新交互卡片(封装错误处理) +func (s *SendCardClient) UpdateInteractiveCard(ctx context.Context, request *dingtalkim_1_0.UpdateRobotInteractiveCardRequest, robotCode string, client *dingtalkim_1_0.Client) (*dingtalkim_1_0.UpdateRobotInteractiveCardResponse, error) { + authInfo, err := s.Auth.GetTokenFromBotOption(ctx, WithBot(s.botOption)) + if err != nil { + return nil, fmt.Errorf("get token failed: %w", err) + } + + headers := &dingtalkim_1_0.UpdateRobotInteractiveCardHeaders{ + XAcsDingtalkAccessToken: tea.String(authInfo.AccessToken), + } + + response, err := client.UpdateRobotInteractiveCardWithOptions(request, headers, &util.RuntimeOptions{}) + if err != nil { + return nil, fmt.Errorf("API call failed: %w,request:%v", err, request.String()) + } + return response, nil +} + +// SendInteractiveCard 发送交互卡片(封装错误处理) +func (s *SendCardClient) SendInteractiveCard(ctx context.Context, request *dingtalkim_1_0.SendRobotInteractiveCardRequest, robotCode string, client *dingtalkim_1_0.Client) (res *dingtalkim_1_0.SendRobotInteractiveCardResponse, err error) { + err = s.Auth.GetBotConfigFromModel(s.botOption) + if err != nil { + return nil, fmt.Errorf("初始化bot失败: %w", err) + } + authInfo, err := s.Auth.GetTokenFromBotOption(ctx, WithBot(s.botOption)) + if err != nil { + return nil, fmt.Errorf("get token failed: %w", err) + } + + headers := &dingtalkim_1_0.SendRobotInteractiveCardHeaders{ + XAcsDingtalkAccessToken: tea.String(authInfo.AccessToken), + } + + response, err := client.SendRobotInteractiveCardWithOptions(request, headers, &util.RuntimeOptions{}) + if err != nil { + return nil, fmt.Errorf("API call failed: %w", err) + } + return response, nil +} diff --git a/internal/biz/handle/dingtalk/send_card.go.bak1 b/internal/biz/handle/dingtalk/send_card.go.bak1 new file mode 100644 index 0000000..9fb1e8d --- /dev/null +++ b/internal/biz/handle/dingtalk/send_card.go.bak1 @@ -0,0 +1,280 @@ +package dingtalk + +import ( + "ai_scheduler/internal/data/constants" + "context" + "encoding/json" + "errors" + "fmt" + "strings" + "sync" + "time" + + openapi "github.com/alibabacloud-go/darabonba-openapi/v2/client" + dingtalkcard_1_0 "github.com/alibabacloud-go/dingtalk/card_1_0" + dingtalkim_1_0 "github.com/alibabacloud-go/dingtalk/im_1_0" + util "github.com/alibabacloud-go/tea-utils/v2/service" + "github.com/alibabacloud-go/tea/tea" + "github.com/gofiber/fiber/v2/log" + "github.com/google/uuid" +) + +const DefaultInterval = 100 * time.Millisecond +const HeardBeatX = 50 + +type SendCardClient struct { + Auth *Auth + CardClient *sync.Map + mu sync.RWMutex // 保护 CardClient 的并发访问 + logger log.AllLogger // 日志记录 + botOption *Bot +} + +func NewSendCardClient(auth *Auth, logger log.AllLogger) *SendCardClient { + return &SendCardClient{ + Auth: auth, + CardClient: &sync.Map{}, + logger: logger, + botOption: &Bot{}, + } +} + +// initClient 初始化或复用 DingTalk 客户端 +func (s *SendCardClient) initClient(robotCode string) (*dingtalkcard_1_0.Client, error) { + if client, ok := s.CardClient.Load(robotCode); ok { + return client.(*dingtalkcard_1_0.Client), nil + } + s.botOption.BotCode = robotCode + config := &openapi.Config{ + Protocol: tea.String("https"), + RegionId: tea.String("central"), + } + client, err := dingtalkcard_1_0.NewClient(config) + if err != nil { + s.logger.Error("failed to init DingTalk client") + return nil, fmt.Errorf("init client failed: %w", err) + } + + s.CardClient.Store(robotCode, client) + return client, nil +} + +func (s *SendCardClient) NewCard(ctx context.Context, cardSend *CardSend) error { + // 参数校验 + if (len(cardSend.ContentSlice) == 0 || cardSend.ContentSlice == nil) && cardSend.ContentChannel == nil { + return errors.New("卡片内容不能为空") + } + if cardSend.UpdateInterval == 0 { + cardSend.UpdateInterval = DefaultInterval // 默认更新间隔 + } + if cardSend.Title == "" { + cardSend.Title = "钉钉卡片" + } + //替换标题 + cardSend.Template = constants.CardTemp(strings.Replace(string(cardSend.Template), "${title}", cardSend.Title, 1)) + // 初始化客户端 + client, err := s.initClient(cardSend.RobotCode) + if err != nil { + return fmt.Errorf("初始化client失败: %w", err) + } + + // 生成卡片实例ID + cardInstanceId, err := uuid.NewUUID() + if err != nil { + return fmt.Errorf("创建uuid失败: %w", err) + } + + // 构建初始请求 + request, err := s.buildBaseRequest(cardSend, cardInstanceId.String()) + if err != nil { + return fmt.Errorf("请求失败: %w", err) + } + + // 发送初始卡片 + if _, err := s.SendInteractiveCard(ctx, request, cardSend.RobotCode, client); err != nil { + return fmt.Errorf("发送初始卡片失败: %w", err) + } + + // 处理切片内容(同步) + if len(cardSend.ContentSlice) > 0 { + if err := s.processContentSlice(ctx, cardSend, cardInstanceId.String(), client); err != nil { + return fmt.Errorf("内容同步失败: %w", err) + } + } + + // 处理通道内容(异步) + if cardSend.ContentChannel != nil { + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + s.processContentChannel(ctx, cardSend, cardInstanceId.String(), client) + }() + wg.Wait() + } + + return nil +} + +// buildBaseRequest 构建基础请求 +func (s *SendCardClient) buildBaseRequest(cardSend *CardSend, cardInstanceId string) (*dingtalkcard_1_0.StreamingUpdateRequest, error) { + cardData := fmt.Sprintf(string(cardSend.Template), "") // 初始空内容 + request := &dingtalkcard_1_0.StreamingUpdateRequest{ + OutTrackId: tea.String("your-out-track-id"), + Guid: tea.String("0F714542-0AFC-2B0E-CF14-E2D39F5BFFE8"), + Key: tea.String("your-ai-param"), + Content: tea.String("test"), + IsFull: tea.Bool(false), + IsFinalize: tea.Bool(false), + IsError: tea.Bool(false), + } + + switch cardSend.ConversationType { + case constants.ConversationTypeGroup: + request.SetOpenConversationId(cardSend.ConversationId) + case constants.ConversationTypeSingle: + receiver, err := json.Marshal(map[string]string{"userId": cardSend.SenderStaffId}) + if err != nil { + return nil, fmt.Errorf("数据整理失败: %w", err) + } + request.SetSingleChatReceiver(string(receiver)) + default: + return nil, errors.New("未知的聊天场景") + } + + return request, nil +} + +// processContentChannel 处理通道内容(异步更新) +func (s *SendCardClient) processContentChannel(ctx context.Context, cardSend *CardSend, cardInstanceId string, client *dingtalkim_1_0.Client) { + defer func() { + if r := recover(); r != nil { + s.logger.Error("panic in processContentChannel") + } + }() + + ticker := time.NewTicker(cardSend.UpdateInterval) + defer ticker.Stop() + heartbeatTicker := time.NewTicker(time.Duration(HeardBeatX) * DefaultInterval) + defer heartbeatTicker.Stop() + + var ( + contentBuilder strings.Builder + lastUpdate time.Time + ) + for { + + select { + case content, ok := <-cardSend.ContentChannel: + if !ok { + // 通道关闭,发送最终内容 + if contentBuilder.Len() > 0 { + if err := s.updateCardContent(ctx, cardSend, cardInstanceId, contentBuilder.String(), client); err != nil { + s.logger.Errorf("更新卡片失败1:%s", err.Error()) + } + } + return + } + contentBuilder.WriteString(content) + if contentBuilder.Len() > 0 { + if err := s.updateCardContent(ctx, cardSend, cardInstanceId, contentBuilder.String(), client); err != nil { + s.logger.Errorf("更新卡片失败2:%s", err.Error()) + } + } + lastUpdate = time.Now() + + case <-heartbeatTicker.C: + if time.Now().Unix()-lastUpdate.Unix() >= HeardBeatX { + return + } + + case <-ctx.Done(): + s.logger.Info("context canceled, stop channel processing") + return + } + } + +} + +// processContentSlice 处理切片内容(同步更新) +func (s *SendCardClient) processContentSlice(ctx context.Context, cardSend *CardSend, cardInstanceId string, client *dingtalkim_1_0.Client) error { + var contentBuilder strings.Builder + for _, content := range cardSend.ContentSlice { + contentBuilder.WriteString(content) + err := s.updateCardRequest(ctx, &UpdateCardRequest{ + Template: string(cardSend.Template), + Content: contentBuilder.String(), + Client: client, + RobotCode: cardSend.RobotCode, + CardInstanceId: cardInstanceId, + }) + if err != nil { + return fmt.Errorf("更新卡片失败: %w", err) + } + time.Sleep(cardSend.UpdateInterval) // 控制更新频率 + } + return nil +} + +// updateCardContent 封装卡片更新逻辑 +func (s *SendCardClient) updateCardContent(ctx context.Context, cardSend *CardSend, cardInstanceId, content string, client *dingtalkim_1_0.Client) error { + err := s.updateCardRequest(ctx, &UpdateCardRequest{ + Template: string(cardSend.Template), + Content: content, + Client: client, + RobotCode: cardSend.RobotCode, + CardInstanceId: cardInstanceId, + }) + + return err +} + +func (s *SendCardClient) updateCardRequest(ctx context.Context, updateCardRequest *UpdateCardRequest) error { + + updateRequest := &dingtalkim_1_0.UpdateRobotInteractiveCardRequest{ + CardBizId: tea.String(updateCardRequest.CardInstanceId), + CardData: tea.String(fmt.Sprintf(updateCardRequest.Template, updateCardRequest.Content)), + } + _, err := s.UpdateInteractiveCard(ctx, updateRequest, updateCardRequest.RobotCode, updateCardRequest.Client) + return err +} + +// UpdateInteractiveCard 更新交互卡片(封装错误处理) +func (s *SendCardClient) UpdateInteractiveCard(ctx context.Context, request *dingtalkim_1_0.UpdateRobotInteractiveCardRequest, robotCode string, client *dingtalkim_1_0.Client) (*dingtalkim_1_0.UpdateRobotInteractiveCardResponse, error) { + authInfo, err := s.Auth.GetTokenFromBotOption(ctx, WithBot(s.botOption)) + if err != nil { + return nil, fmt.Errorf("get token failed: %w", err) + } + + headers := &dingtalkim_1_0.UpdateRobotInteractiveCardHeaders{ + XAcsDingtalkAccessToken: tea.String(authInfo.AccessToken), + } + + response, err := client.UpdateRobotInteractiveCardWithOptions(request, headers, &util.RuntimeOptions{}) + if err != nil { + return nil, fmt.Errorf("API call failed: %w,request:%v", err, request.String()) + } + return response, nil +} + +// SendInteractiveCard 发送交互卡片(封装错误处理) +func (s *SendCardClient) SendInteractiveCard(ctx context.Context, request *dingtalkim_1_0.SendRobotInteractiveCardRequest, robotCode string, client *dingtalkim_1_0.Client) (res *dingtalkim_1_0.SendRobotInteractiveCardResponse, err error) { + err = s.Auth.GetBotConfigFromModel(s.botOption) + if err != nil { + return nil, fmt.Errorf("初始化bot失败: %w", err) + } + authInfo, err := s.Auth.GetTokenFromBotOption(ctx, WithBot(s.botOption)) + if err != nil { + return nil, fmt.Errorf("get token failed: %w", err) + } + + headers := &dingtalkim_1_0.SendRobotInteractiveCardHeaders{ + XAcsDingtalkAccessToken: tea.String(authInfo.AccessToken), + } + + response, err := client.SendRobotInteractiveCardWithOptions(request, headers, &util.RuntimeOptions{}) + if err != nil { + return nil, fmt.Errorf("API call failed: %w", err) + } + return response, nil +} diff --git a/internal/biz/handle/dingtalk/types.go b/internal/biz/handle/dingtalk/types.go index a36ea76..5baa770 100644 --- a/internal/biz/handle/dingtalk/types.go +++ b/internal/biz/handle/dingtalk/types.go @@ -1,6 +1,11 @@ package dingtalk -import "time" +import ( + "ai_scheduler/internal/data/constants" + "time" + + dingtalkim_1_0 "github.com/alibabacloud-go/dingtalk/im_1_0" +) type DingTalkAuthIRes struct { AccessToken string `json:"accessToken"` @@ -78,7 +83,28 @@ type DeptResResult struct { } type AuthInfo struct { - ClientId string `json:"clientId"` - ClientSecret string `json:"clientSecret"` - AccessToken string `json:"accessToken"` + ClientId string `json:"clientId"` + ClientSecret string `json:"clientSecret"` + AccessToken string `json:"accessToken"` + Expire time.Duration `json:"expireIn"` +} + +type CardSend struct { + RobotCode string + ConversationType constants.ConversationType + ConversationId string + Template constants.CardTemp + SenderStaffId string + Title string + ContentSlice []string + ContentChannel chan string + UpdateInterval time.Duration // 控制通道更新的频率 +} + +type UpdateCardRequest struct { + Template string + Content string + Client *dingtalkim_1_0.Client + RobotCode string + CardInstanceId string } diff --git a/internal/data/constants/bot.go b/internal/data/constants/bot.go index 78a46f1..d0ca85c 100644 --- a/internal/data/constants/bot.go +++ b/internal/data/constants/bot.go @@ -34,6 +34,8 @@ const ( const DingTalkAuthBaseKeyPrefix = "dingTalk_auth" +const DingTalkAuthBaseKeyBotPrefix = "dingTalk_auth_bot" + // PermissionType 工具使用权限 type PermissionType int32 diff --git a/internal/data/constants/dingtalk.go b/internal/data/constants/dingtalk.go index c6d55b0..fbbc7b8 100644 --- a/internal/data/constants/dingtalk.go +++ b/internal/data/constants/dingtalk.go @@ -49,3 +49,32 @@ type BotMsgType string const ( BotMsgTypeText BotMsgType = "text" ) + +type CardTemp string + +const ( + CardTempDefault CardTemp = `{ + "config": { + "autoLayout": true, + "enableForward": true + }, + "header": { + "title": { + "type": "text", + "text": "${title}", + }, + "logo": "@lALPDfJ6V_FPDmvNAfTNAfQ" + }, + "contents": [ + { + "type": "divider", + "id": "divider_1765952728523" + }, + { + "type": "markdown", + "text": "%s", + "id": "markdown_1765970168635" + } + ] +}` +) diff --git a/internal/data/impl/bot_group.go b/internal/data/impl/bot_group.go index 4382d82..e0593c4 100644 --- a/internal/data/impl/bot_group.go +++ b/internal/data/impl/bot_group.go @@ -17,9 +17,9 @@ func NewBotGroupImpl(db *utils.Db) *BotGroupImpl { } } -func (k BotGroupImpl) GetByConversationId(staffId string) (*model.AiBotGroup, error) { +func (k BotGroupImpl) GetByConversationIdAndRobotCode(staffId string, robotCode string) (*model.AiBotGroup, error) { var data model.AiBotGroup - err := k.Db.Model(k.Model).Where("conversation_id = ?", staffId).Find(&data).Error + err := k.Db.Model(k.Model).Where("conversation_id = ? and robot_code = ?", staffId, robotCode).Find(&data).Error if data.GroupID == 0 { err = sql.ErrNoRows } diff --git a/internal/data/model/ai_bot_config.gen.go b/internal/data/model/ai_bot_config.gen.go index 6885e81..e6142f7 100644 --- a/internal/data/model/ai_bot_config.gen.go +++ b/internal/data/model/ai_bot_config.gen.go @@ -13,11 +13,11 @@ const TableNameAiBotConfig = "ai_bot_config" // AiBotConfig mapped from table type AiBotConfig struct { BotID int32 `gorm:"column:bot_id;primaryKey;autoIncrement:true" json:"bot_id"` - SysID int32 `gorm:"column:sys_id;not null" json:"sys_id"` BotType int32 `gorm:"column:bot_type;not null;default:1;comment:类型,1为钉钉机器人" json:"bot_type"` // 类型,1为钉钉机器人 - BotName string `gorm:"column:bot_name;not null;comment:名字" json:"bot_name"` // 名字 - BotConfig string `gorm:"column:bot_config;not null;comment:配置" json:"bot_config"` // 配置 - BotIndex string `gorm:"column:bot_index;not null;comment:索引" json:"bot_index"` // 索引 + SysPrompt string `gorm:"column:sys_prompt" json:"sys_prompt"` + BotName string `gorm:"column:bot_name;not null;comment:名字" json:"bot_name"` // 名字 + BotConfig string `gorm:"column:bot_config;not null;comment:配置" json:"bot_config"` // 配置 + RobotCode string `gorm:"column:robot_code;not null;comment:索引" json:"robot_code"` // 索引 CreateAt time.Time `gorm:"column:create_at;default:CURRENT_TIMESTAMP" json:"create_at"` UpdatedAt time.Time `gorm:"column:updated_at;default:CURRENT_TIMESTAMP" json:"updated_at"` Status int32 `gorm:"column:status;not null" json:"status"` diff --git a/internal/data/model/ai_bot_group.gen.go b/internal/data/model/ai_bot_group.gen.go index d0ff93a..80c50d1 100644 --- a/internal/data/model/ai_bot_group.gen.go +++ b/internal/data/model/ai_bot_group.gen.go @@ -14,6 +14,7 @@ const TableNameAiBotGroup = "ai_bot_group" type AiBotGroup struct { GroupID int32 `gorm:"column:group_id;primaryKey;autoIncrement:true" json:"group_id"` ConversationID string `gorm:"column:conversation_id;not null;comment:会话ID" json:"conversation_id"` // 会话ID + RobotCode string `gorm:"column:robot_code;not null;comment:绑定机器人code" json:"robot_code"` // 绑定机器人code Title string `gorm:"column:title;not null;comment:群名称" json:"title"` // 群名称 ToolList string `gorm:"column:tool_list;not null;comment:开通工具列表" json:"tool_list"` // 开通工具列表 Status int32 `gorm:"column:status;not null;default:1" json:"status"` diff --git a/internal/entitys/bot.go b/internal/entitys/bot.go index 46661ab..ed0902c 100644 --- a/internal/entitys/bot.go +++ b/internal/entitys/bot.go @@ -3,7 +3,7 @@ package entitys import ( "ai_scheduler/internal/data/model" - "github.com/open-dingtalk/dingtalk-stream-sdk-go/chatbot" + "gitea.cdlsxd.cn/self-tools/l-dingtalk-stream-sdk-go/chatbot" ) type RequireDataDingTalkBot struct { diff --git a/internal/pkg/func.go b/internal/pkg/func.go index 57321bd..32c404b 100644 --- a/internal/pkg/func.go +++ b/internal/pkg/func.go @@ -133,3 +133,35 @@ func SliceIntToString(slice []int) []string { } return strSlice } + +// SafeReplace 替换字符串中的 %s,并自动转义特殊字符(如 ") +/** + * SafeReplace 函数用于安全地替换模板字符串中的占位符 + * @param template 原始模板字符串 + * @param replaceTag 要被替换的占位符(如 "%s") + * @param replacements 可变参数,用于替换占位符的字符串 + * @return 返回替换后的字符串和可能的错误 + */ +func SafeReplace(template string, replaceTag string, replacements ...string) (string, error) { + // 如果没有提供替换参数,直接返回原始模板 + if len(replacements) == 0 { + return template, nil + } + + // 检查模板中 %s 的数量是否匹配替换参数 + expectedReplacements := strings.Count(template, replaceTag) + if expectedReplacements != len(replacements) { + return "", fmt.Errorf("模板需要 %d 个替换参数,但提供了 %d 个", expectedReplacements, len(replacements)) + } + + // 逐个替换 %s,并转义特殊字符 + for _, rep := range replacements { + // 转义特殊字符(如 ", \, \n 等) + escaped := strconv.Quote(rep) + // 去掉 strconv.Quote 添加的额外引号 + escaped = escaped[1 : len(escaped)-1] + template = strings.Replace(template, replaceTag, escaped, 1) + } + + return template, nil +} diff --git a/internal/server/ding_talk_bot.go b/internal/server/ding_talk_bot.go index f68f543..2eb31c6 100644 --- a/internal/server/ding_talk_bot.go +++ b/internal/server/ding_talk_bot.go @@ -4,10 +4,12 @@ import ( "ai_scheduler/internal/entitys" "ai_scheduler/internal/services" "context" + "fmt" + "sync" + "gitea.cdlsxd.cn/self-tools/l-dingtalk-stream-sdk-go/chatbot" + "gitea.cdlsxd.cn/self-tools/l-dingtalk-stream-sdk-go/client" "github.com/go-kratos/kratos/v2/log" - "github.com/open-dingtalk/dingtalk-stream-sdk-go/chatbot" - "github.com/open-dingtalk/dingtalk-stream-sdk-go/client" ) type DingBotServiceInterface interface { @@ -54,18 +56,48 @@ func ProvideAllDingBotServices( } func (d *DingTalkBotServer) Run(ctx context.Context, botIndex string) { - for name, cli := range d.Clients { - if botIndex != "All" { - if name != botIndex { - continue + if botIndex == "" { + log.Info("未指定机器人索引,跳过启动") + return + } + + var targets []string + switch { + case botIndex == "All": + targets = make([]string, 0, len(d.Clients)) + for name := range d.Clients { + targets = append(targets, name) + } + default: + if _, exists := d.Clients[botIndex]; exists { + targets = []string{botIndex} + } else { + log.Infof("未找到索引为 %s 的机器人", botIndex) + return + } + } + + var wg sync.WaitGroup + errors := make([]error, 0, len(targets)) + + for _, name := range targets { + wg.Add(1) + go func(name string) { + defer wg.Done() + err := d.Clients[name].Start(ctx) + if err != nil { + log.Errorf("%s 启动失败: %v", name, err) + errors = append(errors, fmt.Errorf("%s: %w", name, err)) + } else { + log.Infof("%s 启动成功", name) } - } - err := cli.Start(ctx) - if err != nil { - log.Infof("%s启动失败", name) - continue - } - log.Infof("%s启动成功", name) + }(name) + } + + wg.Wait() + if len(errors) > 0 { + log.Errorf("部分机器人启动失败,总数: %d, 成功: %d, 失败: %d", + len(targets), len(targets)-len(errors), len(errors)) } } func DingBotServerInit(clientId string, clientSecret string, service DingBotServiceInterface) (cli *client.StreamClient) { diff --git a/internal/services/dtalk_bot.go b/internal/services/dtalk_bot.go index 5c9c92b..5b9a8e6 100644 --- a/internal/services/dtalk_bot.go +++ b/internal/services/dtalk_bot.go @@ -2,15 +2,15 @@ package services import ( "ai_scheduler/internal/biz" - "log" - "time" - "ai_scheduler/internal/config" - "ai_scheduler/internal/entitys" "context" + "log" + "sync" + "time" - "github.com/open-dingtalk/dingtalk-stream-sdk-go/chatbot" + "gitea.cdlsxd.cn/self-tools/l-dingtalk-stream-sdk-go/chatbot" + "golang.org/x/sync/errgroup" ) type DingBotService struct { @@ -18,65 +18,115 @@ type DingBotService struct { dingTalkBotBiz *biz.DingTalkBotBiz } -func NewDingBotService(config *config.Config, DingTalkBotBiz *biz.DingTalkBotBiz) *DingBotService { - return &DingBotService{config: config, dingTalkBotBiz: DingTalkBotBiz} +func NewDingBotService(config *config.Config, dingTalkBotBiz *biz.DingTalkBotBiz) *DingBotService { + return &DingBotService{ + config: config, + dingTalkBotBiz: dingTalkBotBiz, + } } func (d *DingBotService) GetServiceCfg() ([]entitys.DingTalkBot, error) { return d.dingTalkBotBiz.GetDingTalkBotCfgList() } -func (d *DingBotService) OnChatBotMessageReceived(ctx context.Context, data *chatbot.BotCallbackDataModel) (content []byte, err error) { - var ( - lastErr error - chat []string - ) +func (d *DingBotService) OnChatBotMessageReceived(ctx context.Context, data *chatbot.BotCallbackDataModel) ([]byte, error) { requireData, err := d.dingTalkBotBiz.InitRequire(ctx, data) if err != nil { - return + return nil, err } - // 使用 ctx.Done() 通知 Do 方法提前终止 - subCtx, cancel := context.WithCancel(ctx) - defer func() { - cancel() - _ = d.dingTalkBotBiz.SaveHis(ctx, requireData, chat) - }() - // 异步执行 Do 方法 - done := make(chan error, 1) + // 启动后台任务(独立生命周期,带超时控制) go func() { - done <- d.dingTalkBotBiz.Do(subCtx, requireData) + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) + defer cancel() + if err := d.runBackgroundTasks(ctx, data, requireData); err != nil { + log.Printf("后台任务执行失败: %v", err) + } }() - for { - select { - case <-ctx.Done(): - lastErr = ctx.Err() - goto cleanup - case resp, ok := <-requireData.Ch: - if !ok { - return []byte("success"), nil - } - if resp.Type == entitys.ResponseLog { - continue - } - if resp.Type == entitys.ResponseText || resp.Type == entitys.ResponseStream || resp.Type == entitys.ResponseJson { - chat = append(chat, resp.Content) - } - if err := d.dingTalkBotBiz.HandleRes(ctx, data, resp); err != nil { - log.Printf("HandleRes 失败: %v", err) + return []byte("success"), nil +} + +func (d *DingBotService) runBackgroundTasks(ctx context.Context, data *chatbot.BotCallbackDataModel, requireData *entitys.RequireDataDingTalkBot) error { + g, ctx := errgroup.WithContext(ctx) + var ( + chat []string + chatMu sync.Mutex + resChan = make(chan string, 10) + ) + + // 1. 流式处理协程 + g.Go(func() error { + defer func() { + // 确保通道最终关闭 + close(resChan) + }() + return d.dingTalkBotBiz.HandleStreamRes(ctx, data, resChan) + }) + + // 2. 业务处理协程(负责关闭requireData.Ch) + g.Go(func() error { + // 在完成时关闭通道 + defer close(requireData.Ch) + return d.dingTalkBotBiz.Do(ctx, requireData) + }) + + // 3. 结果收集协程(修改后的版本) + resultDone := make(chan struct{}) + g.Go(func() error { + // 使用defer确保通道关闭 + defer close(resultDone) + + // 处理通道中的数据 + for { + select { + case resp, ok := <-requireData.Ch: + if !ok { + return nil // 通道已关闭,正常退出 + } + if resp.Type != entitys.ResponseLog { + chatMu.Lock() + chat = append(chat, resp.Content) + chatMu.Unlock() + + select { + case resChan <- resp.Content: + case <-ctx.Done(): + return ctx.Err() + } + } + case <-ctx.Done(): + return ctx.Err() // 上下文取消,提前退出 } } - } -cleanup: - select { - case _err := <-done: - if _err != nil { - panic(_err) + }) + + // 4. 统一关闭通道的协程(只关闭resChan) + g.Go(func() error { + <-resultDone + // resChan已在流式处理协程关闭 + return nil + }) + + // 5. 历史记录保存协程 + g.Go(func() error { + <-resultDone + chatMu.Lock() + savedChat := make([]string, len(chat)) + copy(savedChat, chat) + chatMu.Unlock() + + if err := d.dingTalkBotBiz.SaveHis(ctx, requireData, savedChat); err != nil { + log.Printf("保存历史记录失败: %v", err) + return err } - case <-time.After(1 * time.Second): - log.Println("警告:等待 Do 方法超时,可能发生 goroutine 泄漏") + return nil + }) + + // 阻塞直到所有协程完成或出错 + if err := g.Wait(); err != nil { + return err } - return nil, lastErr + return nil } diff --git a/internal/services/dtalk_bot.go.bak b/internal/services/dtalk_bot.go.bak new file mode 100644 index 0000000..75c2c7f --- /dev/null +++ b/internal/services/dtalk_bot.go.bak @@ -0,0 +1,130 @@ +package services + +import ( + "ai_scheduler/internal/biz" + "log" + "sync" + "time" + + "ai_scheduler/internal/config" + "ai_scheduler/internal/entitys" + "context" + + "gitea.cdlsxd.cn/self-tools/l-dingtalk-stream-sdk-go/chatbot" +) + +type DingBotService struct { + config *config.Config + dingTalkBotBiz *biz.DingTalkBotBiz +} + +func NewDingBotService(config *config.Config, DingTalkBotBiz *biz.DingTalkBotBiz) *DingBotService { + return &DingBotService{config: config, dingTalkBotBiz: DingTalkBotBiz} +} + +func (d *DingBotService) GetServiceCfg() ([]entitys.DingTalkBot, error) { + return d.dingTalkBotBiz.GetDingTalkBotCfgList() +} + +func (d *DingBotService) OnChatBotMessageReceived(ctx context.Context, data *chatbot.BotCallbackDataModel) ([]byte, error) { + var ( + lastErr error + chat []string + streamWG sync.WaitGroup + resChan = make(chan string, 100) // 缓冲通道防止阻塞 + ) + + // 初始化请求 + requireData, err := d.dingTalkBotBiz.InitRequire(ctx, data) + if err != nil { + return nil, err + } + + // 创建子上下文用于控制goroutine生命周期 + subCtx, cancel := context.WithCancel(ctx) + defer cancel() + + // 启动流式处理goroutine + streamWG.Add(1) + go func() { + defer streamWG.Done() + err = d.dingTalkBotBiz.HandleStreamRes(subCtx, data, resChan) + if err != nil { + return + } + }() + + // 启动业务处理goroutine + done := make(chan error, 1) + go func() { + done <- d.dingTalkBotBiz.Do(subCtx, requireData) + }() + + // 主处理循环 + for { + select { + case <-ctx.Done(): + lastErr = ctx.Err() + goto cleanup + + case resp, ok := <-requireData.Ch: + if !ok { + goto cleanup + } + + // 处理不同类型响应 + switch resp.Type { + case entitys.ResponseLog: + // 忽略日志类型 + continue + + //case entitys.ResponseText, entitys.ResponseJson: + // chat = append(chat, resp.Content) + // if err := d.dingTalkBotBiz.ReplyText(ctx, data.SessionWebhook, resp.Content); err != nil { + // log.Printf("处理非流响应失败: %v", err) + // lastErr = err + // } + + default: + chat = append(chat, resp.Content) + select { + case resChan <- resp.Content: + case <-ctx.Done(): + lastErr = ctx.Err() + goto cleanup + } + } + } + } + +cleanup: + streamWG.Wait() + // 关闭流式通道 + close(resChan) + + // 保存历史记录 + if saveErr := d.dingTalkBotBiz.SaveHis(ctx, requireData, chat); saveErr != nil { + log.Printf("保存历史记录失败: %v", saveErr) + if lastErr == nil { + lastErr = saveErr + } + } + + // 等待业务处理完成(带超时) + select { + case err := <-done: + if err != nil { + log.Printf("业务处理失败: %v", err) + if lastErr == nil { + lastErr = err + } + } + case <-time.After(3 * time.Second): // 增加超时时间 + log.Println("警告:等待业务处理超时,可能发生goroutine泄漏") + } + + if lastErr != nil { + return nil, lastErr + } + return []byte("success"), nil +} From 284624bcba2a6c1e5d9dc9b916d05070a57089d5 Mon Sep 17 00:00:00 2001 From: fuzhongyun <15339891972@163.com> Date: Thu, 18 Dec 2025 18:19:46 +0800 Subject: [PATCH 49/66] =?UTF-8?q?feat:=201.=20=E6=96=B0=E5=A2=9E=20ollamaC?= =?UTF-8?q?lient=20chat=E6=96=B9=E6=B3=95=202.=20=E5=A2=9E=E5=8A=A0?= =?UTF-8?q?=E4=BA=A7=E5=93=81=E6=95=B0=E6=8D=AE=E6=8F=90=E5=8F=96=E8=83=BD?= =?UTF-8?q?=E5=8A=9B=E6=8E=A5=E5=8F=A3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- internal/data/error/error_code.go | 4 + internal/domain/llm/options.go | 21 +-- .../domain/llm/provider/ollama/adapter.go | 9 +- internal/pkg/util/time.go | 41 ++++++ internal/pkg/utils_ollama/client.go | 19 +++ internal/server/http.go | 15 +- internal/server/router/router.go | 15 +- internal/services/callback.go | 69 +++++----- internal/services/capability.go | 129 ++++++++++++++++++ internal/services/provider_set.go | 4 +- 10 files changed, 266 insertions(+), 60 deletions(-) create mode 100644 internal/pkg/util/time.go create mode 100644 internal/services/capability.go diff --git a/internal/data/error/error_code.go b/internal/data/error/error_code.go index 9c28865..9d907e7 100644 --- a/internal/data/error/error_code.go +++ b/internal/data/error/error_code.go @@ -54,3 +54,7 @@ func ParamErr(message string, arg ...any) *BusinessErr { func (e *BusinessErr) Wrap(err error) *BusinessErr { return NewBusinessErr(e.code, err.Error()) } + +func KeyErr() *BusinessErr { + return &BusinessErr{code: KeyNotFound.code, message: KeyNotFound.message} +} diff --git a/internal/domain/llm/options.go b/internal/domain/llm/options.go index c427153..19f8983 100644 --- a/internal/domain/llm/options.go +++ b/internal/domain/llm/options.go @@ -3,14 +3,15 @@ package llm import "time" type Options struct { - Temperature float32 - MaxTokens int - Stream bool - Timeout time.Duration - Modalities []string - SystemPrompt string - Model string - TopP float32 - Stop []string - Endpoint string + Temperature float32 + MaxTokens int + Stream bool + Timeout time.Duration + Modalities []string + SystemPrompt string + Model string + TopP float32 + Stop []string + Endpoint string + Thinking bool } diff --git a/internal/domain/llm/provider/ollama/adapter.go b/internal/domain/llm/provider/ollama/adapter.go index 1c26bab..554bcf0 100644 --- a/internal/domain/llm/provider/ollama/adapter.go +++ b/internal/domain/llm/provider/ollama/adapter.go @@ -15,10 +15,11 @@ func New() *Adapter { return &Adapter{} } func (a *Adapter) Generate(ctx context.Context, input []*schema.Message, opts llm.Options) (*schema.Message, error) { cm, err := eino_ollama.NewChatModel(ctx, &eino_ollama.ChatModelConfig{ - BaseURL: opts.Endpoint, - Timeout: opts.Timeout, - Model: opts.Model, - Options: &eino_ollama.Options{Temperature: opts.Temperature, NumPredict: opts.MaxTokens}, + BaseURL: opts.Endpoint, + Timeout: opts.Timeout, + Model: opts.Model, + Options: &eino_ollama.Options{Temperature: opts.Temperature, NumPredict: opts.MaxTokens}, + Thinking: &eino_ollama.ThinkValue{Value: opts.Thinking}, }) if err != nil { return nil, err diff --git a/internal/pkg/util/time.go b/internal/pkg/util/time.go new file mode 100644 index 0000000..1c1bf3e --- /dev/null +++ b/internal/pkg/util/time.go @@ -0,0 +1,41 @@ +package util + +import "time" + +// 判断当前时间是否在时间窗口内 +// ts 时间戳字符串,支持秒级或毫秒级 +// window 时间窗口,例如 10 * time.Minute +func IsInTimeWindow(ts string, window time.Duration) bool { + // 期望毫秒时间戳或秒级,简单容错 + // 尝试解析为整数 + var n int64 + for _, base := range []int64{1, 1000} { // 秒或毫秒 + if v, ok := parseInt64(ts); ok { + n = v + // 归一为毫秒 + if base == 1 && len(ts) <= 10 { + n = n * 1000 + } + now := time.Now().UnixMilli() + diff := now - n + if diff < 0 { + diff = -diff + } + if diff <= window.Milliseconds() { + return true + } + } + } + return false +} + +func parseInt64(s string) (int64, bool) { + var n int64 + for _, ch := range s { + if ch < '0' || ch > '9' { + return 0, false + } + n = n*10 + int64(ch-'0') + } + return n, true +} diff --git a/internal/pkg/utils_ollama/client.go b/internal/pkg/utils_ollama/client.go index 91640f0..1f67774 100644 --- a/internal/pkg/utils_ollama/client.go +++ b/internal/pkg/utils_ollama/client.go @@ -90,6 +90,25 @@ func (c *Client) ChatStream(ctx context.Context, ch chan entitys.Response, messa return } +func (c *Client) Chat(ctx context.Context, messages []api.Message) (res api.ChatResponse, err error) { + // 构建聊天请求 + req := &api.ChatRequest{ + Model: c.config.Model, + Messages: messages, + Stream: new(bool), // 设置为false,不使用流式响应 + Think: &api.ThinkValue{Value: true}, + } + err = c.client.Chat(ctx, req, func(resp api.ChatResponse) error { + res = resp + return nil + }) + if err != nil { + return + } + + return +} + func (c *Client) Generation(ctx context.Context, generateRequest *api.GenerateRequest) (result api.GenerateResponse, err error) { err = c.client.Generate(ctx, generateRequest, func(resp api.GenerateResponse) error { result = resp diff --git a/internal/server/http.go b/internal/server/http.go index fd7e49e..53446c8 100644 --- a/internal/server/http.go +++ b/internal/server/http.go @@ -11,11 +11,13 @@ import ( ) type HTTPServer struct { - app *fiber.App - service *services.ChatService - session *services.SessionService - gateway *gateway.Gateway - callback *services.CallbackService + app *fiber.App + service *services.ChatService + session *services.SessionService + gateway *gateway.Gateway + callback *services.CallbackService + chatHis *services.HistoryService + capabilityService *services.CapabilityService } func NewHTTPServer( @@ -25,10 +27,11 @@ func NewHTTPServer( gateway *gateway.Gateway, callback *services.CallbackService, chatHis *services.HistoryService, + capabilityService *services.CapabilityService, ) *fiber.App { //构建 server app := initRoute() - router.SetupRoutes(app, service, session, task, gateway, callback, chatHis) + router.SetupRoutes(app, service, session, task, gateway, callback, chatHis, capabilityService) return app } diff --git a/internal/server/router/router.go b/internal/server/router/router.go index e2645bb..be861aa 100644 --- a/internal/server/router/router.go +++ b/internal/server/router/router.go @@ -15,16 +15,18 @@ import ( ) type RouterServer struct { - app *fiber.App - service *services.ChatService - session *services.SessionService - gateway *gateway.Gateway - chatHist *services.HistoryService + app *fiber.App + service *services.ChatService + session *services.SessionService + gateway *gateway.Gateway + chatHist *services.HistoryService + capabilityService *services.CapabilityService } // SetupRoutes 设置路由 func SetupRoutes(app *fiber.App, ChatService *services.ChatService, sessionService *services.SessionService, task *services.TaskService, gateway *gateway.Gateway, callbackService *services.CallbackService, chatHist *services.HistoryService, + capabilityService *services.CapabilityService, ) { app.Use(func(c *fiber.Ctx) error { // 设置 CORS 头 @@ -84,6 +86,9 @@ func SetupRoutes(app *fiber.App, ChatService *services.ChatService, sessionServi // 会话历史 r.Post("/chat/history/list", chatHist.List) r.Post("/chat/history/update/content", chatHist.UpdateContent) + + // 能力 + r.Post("/capability/product/ingest", capabilityService.ProductIngest) // 商品数据提取 } func routerSocket(app *fiber.App, chatService *services.ChatService) { diff --git a/internal/services/callback.go b/internal/services/callback.go index c903bea..cc32660 100644 --- a/internal/services/callback.go +++ b/internal/services/callback.go @@ -77,7 +77,8 @@ func (s *CallbackService) Callback(c *fiber.Ctx) error { ts := strings.TrimSpace(c.Get("X-Timestamp")) // 时间窗口(如果提供了 ts 则校验,否则跳过),窗口 5 分钟 - if ts != "" && !validateTimestamp(ts, 5*time.Minute) { + // if ts != "" && !validateTimestamp(ts, 5*time.Minute) { + if ts != "" && !util.IsInTimeWindow(ts, 5*time.Minute) { return errorcode.AuthNotFound } @@ -101,40 +102,40 @@ func (s *CallbackService) Callback(c *fiber.Ctx) error { } } -func validateTimestamp(ts string, window time.Duration) bool { - // 期望毫秒时间戳或秒级,简单容错 - // 尝试解析为整数 - var n int64 - for _, base := range []int64{1, 1000} { // 秒或毫秒 - if v, ok := parseInt64(ts); ok { - n = v - // 归一为毫秒 - if base == 1 && len(ts) <= 10 { - n = n * 1000 - } - now := time.Now().UnixMilli() - diff := now - n - if diff < 0 { - diff = -diff - } - if diff <= window.Milliseconds() { - return true - } - } - } - return false -} +// func validateTimestamp(ts string, window time.Duration) bool { +// // 期望毫秒时间戳或秒级,简单容错 +// // 尝试解析为整数 +// var n int64 +// for _, base := range []int64{1, 1000} { // 秒或毫秒 +// if v, ok := parseInt64(ts); ok { +// n = v +// // 归一为毫秒 +// if base == 1 && len(ts) <= 10 { +// n = n * 1000 +// } +// now := time.Now().UnixMilli() +// diff := now - n +// if diff < 0 { +// diff = -diff +// } +// if diff <= window.Milliseconds() { +// return true +// } +// } +// } +// return false +// } -func parseInt64(s string) (int64, bool) { - var n int64 - for _, ch := range s { - if ch < '0' || ch > '9' { - return 0, false - } - n = n*10 + int64(ch-'0') - } - return n, true -} +// func parseInt64(s string) (int64, bool) { +// var n int64 +// for _, ch := range s { +// if ch < '0' || ch > '9' { +// return 0, false +// } +// n = n*10 + int64(ch-'0') +// } +// return n, true +// } func (s *CallbackService) handleDingTalkCallback(c *fiber.Ctx, env Envelope) error { // 校验taskId diff --git a/internal/services/capability.go b/internal/services/capability.go new file mode 100644 index 0000000..e0adea0 --- /dev/null +++ b/internal/services/capability.go @@ -0,0 +1,129 @@ +package services + +import ( + "ai_scheduler/internal/config" + errorcode "ai_scheduler/internal/data/error" + "ai_scheduler/internal/pkg/util" + "ai_scheduler/internal/pkg/utils_ollama" + "context" + "fmt" + "strings" + "time" + + "github.com/gofiber/fiber/v2" + "github.com/ollama/ollama/api" +) + +// CapabilityService 统一回调入口 +type CapabilityService struct { + cfg *config.Config +} + +func NewCapabilityService(cfg *config.Config) *CapabilityService { + return &CapabilityService{ + cfg: cfg, + } +} + +// 产品数据提取入参 +type ProductIngestReq struct { + Url string `json:"url"` // 商品详情页URL + Title string `json:"title"` // 商品标题 + Text string `json:"text"` // 商品描述 + Images []string `json:"images"` // 商品图片URL列表 + Timestamp int64 `json:"timestamp"` // 商品发布时间戳 +} + +const ( + // 货易通商品属性模板-中文 + HYTProductPropertyTemplateZH = `{ + "条码": "string", // 商品编号 + "分类名称": "string", // 商品分类 + "货品名称": "string", // 商品名称 + "货品编号": "string", // 商品编号 + "商品货号": "string", // 商品编号 + "品牌": "string", // 商品品牌 + "单位": "string", // 商品单位,若无则使用'个' + "规格参数": "string", // 商品规格参数 + "货品说明": "string", // 商品说明 + "保质期": "string", // 商品保质期 + "保质期单位": "string", // 商品保质期单位 + "链接": "string", // + "货品图片": ["string"], // 商品多图,取1-2个即可 + "电商销售价格": "decimal(10,2)", // 商品电商销售价格 + "销售价": "decimal(10,2)", // 商品销售价格 + "供应商报价": "decimal(10,2)", // 商品供应商报价 + "税率": "number%", // 商品税率 x% + "默认供应商": "", // 空即可 + "默认存放仓库": "", // 空即可 + "备注": "", // 备注 + "长": "string", // 商品长度,decimal(10,2)+单位 + "宽": "string", // 商品宽度,decimal(10,2)+单位 + "高": "string", // 商品高度,decimal(10,2)+单位 + "重量": "string", // 商品重量(kg) + "SPU名称": "string", // 商品SPU名称 + "SPU编码": "string" // 编码串,jd_{timestamp}_rand(1000-999) + }` + + SystemPrompt = `你是一个专业的商品属性提取助手,你的任务是根据用户输入提取商品的属性信息。 + 目标属性模板:%s。 + 最终输出格式为纯JSON字符串,键值对对应目标属性和提取到的属性值。 + 最终输出不要携带markdown标识,不要携带回车换行` +) + +// ProductIngest 产品数据提取 +func (s *CapabilityService) ProductIngest(c *fiber.Ctx) error { + // 读取头 + token := strings.TrimSpace(c.Get("X-Source-Key")) + ts := strings.TrimSpace(c.Get("X-Timestamp")) + + // 时间窗口校验 + if ts != "" && !util.IsInTimeWindow(ts, 5*time.Minute) { + // return errorcode.AuthNotFound + } + // token校验 + if token == "" || token != "A7f9KQ3mP2X8LZC4R5e" { + // return errorcode.KeyErr() + } + + // 解析请求参数 + req := ProductIngestReq{} + if err := c.BodyParser(&req); err != nil { + return errorcode.ParamErr("invalid request body: %v", err) + } + + // 必要参数校验 + if req.Text == "" { + return errorcode.ParamErr("missing required fields") + } + + // 模型调用 + client, cleanup, err := utils_ollama.NewClient(s.cfg) + if err != nil { + return err + } + defer cleanup() + + res, err := client.Chat(context.Background(), []api.Message{ + { + Role: "system", + Content: fmt.Sprintf(SystemPrompt, HYTProductPropertyTemplateZH), + }, + { + Role: "user", + Content: req.Text, + }, + { + Role: "user", + Content: "商品图片URL列表:" + strings.Join(req.Images, ","), + }, + }) + if err != nil { + return err + } + + // 解析模型输出 + c.JSON(res.Message.Content) + + return nil +} diff --git a/internal/services/provider_set.go b/internal/services/provider_set.go index 867eb11..375a886 100644 --- a/internal/services/provider_set.go +++ b/internal/services/provider_set.go @@ -12,4 +12,6 @@ var ProviderSetServices = wire.NewSet( NewTaskService, NewCallbackService, NewDingBotService, - NewHistoryService) + NewHistoryService, + NewCapabilityService, +) From 8414a57661f7633843caeb471c60eeadb514db63 Mon Sep 17 00:00:00 2001 From: fuzhongyun <15339891972@163.com> Date: Fri, 19 Dec 2025 10:10:45 +0800 Subject: [PATCH 50/66] =?UTF-8?q?fix=EF=BC=9A=E6=89=93=E5=BC=80=E5=AE=89?= =?UTF-8?q?=E5=85=A8=E6=A0=A1=E9=AA=8C?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- internal/services/capability.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/internal/services/capability.go b/internal/services/capability.go index e0adea0..aac842d 100644 --- a/internal/services/capability.go +++ b/internal/services/capability.go @@ -79,11 +79,11 @@ func (s *CapabilityService) ProductIngest(c *fiber.Ctx) error { // 时间窗口校验 if ts != "" && !util.IsInTimeWindow(ts, 5*time.Minute) { - // return errorcode.AuthNotFound + return errorcode.AuthNotFound } // token校验 if token == "" || token != "A7f9KQ3mP2X8LZC4R5e" { - // return errorcode.KeyErr() + return errorcode.KeyErr() } // 解析请求参数 From 5b789e557a876bb270844bb7601729b608429a7d Mon Sep 17 00:00:00 2001 From: fuzhongyun <15339891972@163.com> Date: Fri, 19 Dec 2025 18:38:06 +0800 Subject: [PATCH 51/66] =?UTF-8?q?feat=EF=BC=9A=201.=20=E6=96=B0=E5=A2=9E?= =?UTF-8?q?=E8=B4=A7=E6=98=93=E9=80=9A=E5=95=86=E5=93=81=E4=B8=8A=E4=BC=A0?= =?UTF-8?q?=E5=B7=A5=E5=85=B7=202.=20=E6=96=B0=E5=A2=9E=E8=B4=A7=E6=98=93?= =?UTF-8?q?=E9=80=9A=E5=95=86=E5=93=81=E4=B8=8A=E4=BC=A0=E5=B7=A5=E4=BD=9C?= =?UTF-8?q?=E6=B5=81=203.=20=E6=96=B0=E5=A2=9E=E5=95=86=E5=93=81=E4=B8=8A?= =?UTF-8?q?=E4=BC=A0=E8=87=B3=E8=B4=A7=E6=98=93=E9=80=9A=E6=8E=A5=E5=8F=A3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- config/config_env.yaml | 3 + internal/config/config.go | 2 + internal/data/constants/capability.go | 174 ++++++++++++++++++ .../domain/tools/hyt/product_upload/client.go | 60 ++++++ .../tools/hyt/product_upload/client_test.go | 60 ++++++ .../domain/tools/hyt/product_upload/types.go | 54 ++++++ .../domain/workflow/hyt/product_upload.go | 112 +++++++++++ internal/domain/workflow/runtime/registry.go | 2 +- .../zltx/order_after_reseller_batch.go | 14 +- internal/pkg/util/map.go | 14 ++ internal/server/router/router.go | 8 +- internal/services/capability.go | 107 +++++------ 12 files changed, 548 insertions(+), 62 deletions(-) create mode 100644 internal/data/constants/capability.go create mode 100644 internal/domain/tools/hyt/product_upload/client.go create mode 100644 internal/domain/tools/hyt/product_upload/client_test.go create mode 100644 internal/domain/tools/hyt/product_upload/types.go create mode 100644 internal/domain/workflow/hyt/product_upload.go create mode 100644 internal/pkg/util/map.go diff --git a/config/config_env.yaml b/config/config_env.yaml index 23a123b..aa8b07b 100644 --- a/config/config_env.yaml +++ b/config/config_env.yaml @@ -74,6 +74,9 @@ tools: zltxOrderAfterSaleResellerBatch: enabled: true base_url: "https://gateway.dev.cdlsxd.cn/zltx_api/admin/afterSales/reseller_pre_ai" + hytProductUpload: + enabled: true + base_url: "https://gateway.dev.cdlsxd.cn/zltx_api/admin/oursProduct" diff --git a/internal/config/config.go b/internal/config/config.go index ee0c1a5..0c3f4dd 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -141,6 +141,8 @@ type ToolsConfig struct { CozeExpress ToolConfig `mapstructure:"cozeExpress"` // Coze 公司查询工具 CozeCompany ToolConfig `mapstructure:"cozeCompany"` + // 货易通商品上传 + HytProductUpload ToolConfig `mapstructure:"hytProductUpload"` } // ToolConfig 单个工具配置 diff --git a/internal/data/constants/capability.go b/internal/data/constants/capability.go new file mode 100644 index 0000000..4ee518b --- /dev/null +++ b/internal/data/constants/capability.go @@ -0,0 +1,174 @@ +package constants + +// Token +const ( + CapabilityProductIngestToken = "A7f9KQ3mP2X8LZC4R5e" +) + +// Prompt +const ( + SystemPrompt = ` + #你是一个专业的商品属性提取助手,你的任务是根据用户输入提取商品的属性信息。 + 目标属性模板:%s。 + 1.最终输出格式为纯JSON字符串,键值对对应目标属性和提取到的属性值。 + 2.最终输出不要携带markdown标识,不要携带回车换行` +) + +// 商品属性模板-中文 +const ( + // 货易通商品属性模板-中文 + HYTProductPropertyTemplateZH = `{ + "条码": "string", // 商品编号 + "分类名称": "string", // 商品分类 + "货品名称": "string", // 商品名称 + "货品编号": "string", // 商品编号 + "商品货号": "string", // 商品编号 + "品牌": "string", // 商品品牌 + "单位": "string", // 商品单位,若无则使用'个' + "规格参数": "string", // 商品规格参数 + "货品说明": "string", // 商品说明 + "保质期": "string", // 商品保质期 + "保质期单位": "string", // 商品保质期单位 + "链接": "string", // + "货品图片": ["string"], // 商品多图,取1-2个即可 + "电商销售价格": "decimal(10,2)", // 商品电商销售价格 + "销售价": "decimal(10,2)", // 商品销售价格 + "供应商报价": "decimal(10,2)", // 商品供应商报价 + "税率": "number%", // 商品税率 x% + "默认供应商": "", // 空即可 + "默认存放仓库": "", // 空即可 + "备注": "", // 备注 + "长": "string", // 商品长度,decimal(10,2)+单位 + "宽": "string", // 商品宽度,decimal(10,2)+单位 + "高": "string", // 商品高度,decimal(10,2)+单位 + "重量": "string", // 商品重量(kg) + "SPU名称": "string", // 商品SPU名称 + "SPU编码": "string" // 编码串,jd_{timestamp}_rand(1000-999) + }` +) + +// 商品属性模板 +const ( + HYTProductPropertyTemplate = `{ + "important_data": { + "type": "object", + "properties": { + "supplier_name": { + "type": "string", + "description": "供应商名称" + }, + "warehouse_name": { + "type": "string", + "description": "仓库名称" + }, + "profit": { + "type": "float64", + "description": "利润 decimal(10,2)" + }, + "tax_rate": { + "type": "integer", + "description": "税率 (x)%" + } + } + }, + "goods_info": { + "type": "object", + "properties": { + "title": { + "type": "string", + "description": "商品名称" + }, + "brand": { + "type": "string", + "description": "品牌" + }, + "category": { + "type": "string", + "description": "分类" + }, + "price": { + "type": "float64", + "description": "市场价 decimal(10,2)" + }, + "sales_price": { + "type": "float64", + "description": "建议销售价 decimal(10,2)" + }, + "discount": { + "type": "integer", + "description": "折扣百分比 公式:(市场价-建议销售价)/市场价*100" + }, + "goods_attributes": { + "type": "string", + "description": "商品属性" + }, + "goods_bar_code": { + "type": "string", + "description": "商品条码" + }, + "goods_illustration": { + "type": "string", + "description": "商品插图" + }, + "goods_num": { + "type": "string", + "description": "商品编号" + }, + "introduction": { + "type": "string", + "description": "商品介绍" + }, + "spu_name": { + "type": "string", + "description": "SPU名称" + }, + "spu_num": { + "type": "string", + "description": "SPU编号" + }, + "stock": { + "type": "integer", + "description": "库存" + }, + "tax_rate": { + "type": "integer", + "description": "税率" + }, + "unit": { + "type": "string", + "description": "单位" + }, + "weight": { + "type": "string", + "description": "重量" + } + } + }, + "goods_media_list": { + "type": "array", + "description": "商品媒体文件列表", + "items": { + "type": "object", + "properties": { + "url": { + "type": "string", + "description": "媒体文件URL" + }, + "type": { + "type": "integer", + "description": "媒体类型(1:图片, 2:视频)" + }, + "sort": { + "type": "integer", + "description": "排序序号" + } + } + } + } + }` +) + +// 外部平台地址 +const ( + HYTProductListPageURL = "https://gateway.dev.cdlsxd.cn/sw//#/goods/goodsManage" +) diff --git a/internal/domain/tools/hyt/product_upload/client.go b/internal/domain/tools/hyt/product_upload/client.go new file mode 100644 index 0000000..4924b50 --- /dev/null +++ b/internal/domain/tools/hyt/product_upload/client.go @@ -0,0 +1,60 @@ +package product_upload + +import ( + "ai_scheduler/internal/config" + "ai_scheduler/internal/pkg/l_request" + "ai_scheduler/internal/pkg/util" + "context" + "encoding/json" + "errors" +) + +func Call(ctx context.Context, cfg config.ToolConfig, toolReq *ProductUploadRequest) (toolResp *ProductUploadResponse, err error) { + // 商品有且只能有一个 + if len(toolReq.GoodsList) != 1 { + err = errors.New("商品只能有一个") + return + } + + apiReq, _ := util.StructToMap(toolReq) + + req := l_request.Request{ + Method: "Post", + Url: "http://120.55.12.245:8100/api/v1/goods/supplier/batch/add/complete", + Json: apiReq, + } + res, err := req.Send() + + if err != nil { + return + } + + type resType struct { + Code int `json:"code"` + Msg string `json:"msg"` + Data struct { + Ids []int `json:"ids"` // 预览URL + } `json:"data"` + } + var resMap resType + err = json.Unmarshal([]byte(res.Text), &resMap) + if err != nil { + return + } + if resMap.Code != 200 { + err = errors.New("货易通商品创建失败") + return + } + if len(resMap.Data.Ids) == 0 { + err = errors.New("货易通商品创建失败") + return + } + + toolResp = &ProductUploadResponse{ + PreviewUrl: "https://gateway.dev.cdlsxd.cn/sw//#/goods/goodsManage", + SpuNum: toolReq.GoodsList[0].GoodsInfo.SpuNum, + Id: resMap.Data.Ids[0], + } + + return toolResp, nil +} diff --git a/internal/domain/tools/hyt/product_upload/client_test.go b/internal/domain/tools/hyt/product_upload/client_test.go new file mode 100644 index 0000000..2f4b01b --- /dev/null +++ b/internal/domain/tools/hyt/product_upload/client_test.go @@ -0,0 +1,60 @@ +package product_upload + +import ( + "ai_scheduler/internal/config" + "context" + "fmt" + "testing" +) + +// Test_Call +func Test_Call(t *testing.T) { + req := &ProductUploadRequest{ + SupplierId: 261, + WarehouseId: 257, + IsDefaultWarehouse: 1, + Sort: 1, + Profit: 40, + TaxRate: 13, + GoodsList: []Goods{ + { + GoodsInfo: GoodsInfo{ + Title: "Apple iPhone 17 Pro Max 星宇橙色 256GB", + Brand: "Apple/苹果", + Category: "手机", + CostPrice: 9999.00, + GoodsAttributes: "CPU型号:A19 Pro;操作系统:iOS;机身存储:256GB;屏幕尺寸:6.86英寸;屏幕材质:OLED直屏;屏幕技术:视网膜XDR;后置摄像头:4800万像素三主摄系统(主摄4800万+超广角4800万+长焦4800万);前置摄像头:1800万像素;网络支持:5G双卡双待(移动/联通/电信);生物识别:人脸识别;防水等级:IP68;充电功率:40W;无线充电:支持;机身尺寸:163.4mm×78.0mm×8.75mm;机身重量:231g;机身颜色:星宇橙色;特征特质:轻薄、防水防尘、无线充电、NFC、磁吸无线充", + GoodsBarCode: "10181383848993", + GoodsIllustration: "Apple/苹果 iPhone 17 Pro Max 【需当面激活】支持移动联通电信 5G 双卡双待手机 星宇橙色 256GB 官方标配。搭载A19 Pro芯片,6.86英寸OLED视网膜XDR直屏,4800万像素三主摄系统,支持5G双卡双待,IP68防水防尘,40W有线充电,支持无线充电和磁吸充电。", + GoodsNum: "10181383848993", + Introduction: "Apple/苹果 iPhone 17 Pro Max 【需当面激活】支持移动联通电信 5G 双卡双待手机 星宇橙色 256GB 官方标配。搭载A19 Pro芯片,6.86英寸OLED视网膜XDR直屏,4800万像素三主摄系统,支持5G双卡双待,IP68防水防尘,40W有线充电,支持无线充电和磁吸充电。", + IsBind: 1, + SpuName: "Apple iPhone 17 Pro Max", + SpuNum: "jd_1766038130329_8721", + TaxRate: 13, + Unit: "台", + Weight: "0.231", // 单位:kg + Price: 9999.00, + SalesPrice: 9999.00, + Stock: 0, // JSON 中未提供库存信息 + Discount: 10, // JSON 中未提供折扣信息 + IsComposeGoods: 2, + IsHot: 2, + }, + GoodsMediaList: []GoodsMedia{ + { + Type: 1, + Url: "https://img10.360buyimg.com/pcpubliccms/s228x228_jfs/t1/363919/12/2409/45712/691d9970F84b99d32/9f9a5d5d16efeb79.jpg.avif", + }, + }, + }, + }, + } + toolResp, err := Call(context.Background(), config.ToolConfig{}, req) + + if err != nil { + t.Errorf("Call() error = %v", err) + } + + fmt.Printf("toolResp: %v\n", toolResp) +} diff --git a/internal/domain/tools/hyt/product_upload/types.go b/internal/domain/tools/hyt/product_upload/types.go new file mode 100644 index 0000000..947dbe5 --- /dev/null +++ b/internal/domain/tools/hyt/product_upload/types.go @@ -0,0 +1,54 @@ +package product_upload + +type ProductUploadRequest struct { + SupplierId int `json:"supplier_id"` // 供应商ID + WarehouseId int `json:"warehouse_id"` // 仓库ID + IsDefaultWarehouse int `json:"is_default_warehouse"` // 是否默认仓库 + Sort int `json:"sort"` // 排序 + Profit float64 `json:"profit"` // 利润 + TaxRate int `json:"tax_rate"` // 税率 + GoodsList []Goods `json:"goods_list"` // 商品列表 +} + +type Goods struct { + GoodsInfo GoodsInfo `json:"goods_info"` + GoodsMediaList []GoodsMedia `json:"goods_media_list"` +} + +type GoodsInfo struct { + Title string `json:"title"` // 商品名称 + Brand string `json:"brand"` // 品牌 + Category string `json:"category"` // 分类 + Discount int `json:"discount"` // 折扣 + GoodsAttributes string `json:"goods_attributes"` // 商品属性 + GoodsBarCode string `json:"goods_bar_code"` // 商品条码 + GoodsNum string `json:"goods_num"` // 商品编号 + Introduction string `json:"introduction"` // 商品介绍 + SpuName string `json:"spu_name"` // SPU名称 + SpuNum string `json:"spu_num"` // SPU编号 + Stock int `json:"stock"` // 库存 + TaxRate int `json:"tax_rate"` // 税率 + Unit string `json:"unit"` // 单位 + Weight string `json:"weight"` // 重量 + Price float64 `json:"price"` // 市场价 + SalesPrice float64 `json:"sales_price"` // 建议销售价格 + GoodsIllustration string `json:"goods_illustration"` // 商品插图 - 暂不提供 + Id int `json:"id"` // 商品ID - 无需 + CostPrice float64 `json:"cost_price"` // 成本价格 - 无需 + IsBind int `json:"is_bind"` // 是否绑定 - 默认0 + IsComposeGoods int32 `json:"is_compose_goods"` // 是否组合商品 - 默认2 + IsHot int `json:"is_hot"` // 是否热门商品 - 默认2 +} + +type GoodsMedia struct { + Remark string `json:"remark"` // 备注 + Sort int `json:"sort"` // 排序 + Type int `json:"type"` // 类型 + Url string `json:"url"` // URL +} + +type ProductUploadResponse struct { + PreviewUrl string `json:"preview_url"` // 预览URL + SpuNum string `json:"spu_code"` // SPU编码 + Id int `json:"id"` // 商品ID +} diff --git a/internal/domain/workflow/hyt/product_upload.go b/internal/domain/workflow/hyt/product_upload.go new file mode 100644 index 0000000..6ab98f1 --- /dev/null +++ b/internal/domain/workflow/hyt/product_upload.go @@ -0,0 +1,112 @@ +package hyt + +import ( + "ai_scheduler/internal/config" + "ai_scheduler/internal/data/constants" + toolHytPu "ai_scheduler/internal/domain/tools/hyt/product_upload" + "ai_scheduler/internal/domain/workflow/runtime" + "ai_scheduler/internal/entitys" + "context" + "encoding/json" + "fmt" + + eino_ollama "github.com/cloudwego/eino-ext/components/model/ollama" + "github.com/cloudwego/eino/components/prompt" + "github.com/cloudwego/eino/compose" + "github.com/cloudwego/eino/schema" +) + +const WorkflowID = "hyt.productUpload" + +func init() { + runtime.Register(WorkflowID, func(d *runtime.Deps) (runtime.Workflow, error) { + return &productUpload{cfg: d.Conf}, nil + }) +} + +type productUpload struct { + cfg *config.Config + data *ProductUploadWorkflowInput +} + +type ProductUploadWorkflowInput struct { + Text string `mapstructure:"text"` +} + +func (o *productUpload) ID() string { return WorkflowID } + +func (o *productUpload) Invoke(ctx context.Context, rec *entitys.Recognize) (map[string]any, error) { + // 构建工作流 + chain, err := o.buildWorkflow(ctx) + if err != nil { + return nil, err + } + + o.data = &ProductUploadWorkflowInput{ + Text: rec.UserContent.Text, + } + // 工作流过程调用 + output, err := chain.Invoke(ctx, o.data) + if err != nil { + return nil, err + } + + fmt.Printf("workflow output: %v\n", output) + + // 不关心输出,全部在途中输出 + return output, nil +} + +func (o *productUpload) buildWorkflow(ctx context.Context) (compose.Runnable[*ProductUploadWorkflowInput, map[string]any], error) { + // 定义工作流 + c := compose.NewChain[*ProductUploadWorkflowInput, map[string]any]() + + // AI映射工具所需参数' + paramMappingModel, err := eino_ollama.NewChatModel(ctx, &eino_ollama.ChatModelConfig{ + BaseURL: o.cfg.Ollama.BaseURL, + Timeout: o.cfg.Ollama.Timeout, + Model: o.cfg.Ollama.Model, + Thinking: &eino_ollama.ThinkValue{Value: true}, + Options: &eino_ollama.Options{Temperature: 0.5}, + }) + if err != nil { + return nil, err + } + + // 1. 构建参LLM数映射提示词 + c.AppendChatTemplate(prompt.FromMessages( + schema.FString, + schema.SystemMessage("你是一个专业的商品参数解析器,你需要根据用户输入的商品描述,解析出商品的目标参数。"), + schema.SystemMessage("目标参数:"+constants.HYTProductPropertyTemplate), + schema.UserMessage("用户输入:{{.Text}}"), + )) + // 2. 调用LLM + c.AppendChatModel(paramMappingModel) + + // 3.工具参数整理 + c.AppendLambda(compose.InvokableLambda(func(_ context.Context, in *schema.Message) (*toolHytPu.ProductUploadRequest, error) { + toolReq := &toolHytPu.ProductUploadRequest{} + if err := json.Unmarshal([]byte(in.Content), toolReq); err != nil { + return nil, err + } + return toolReq, nil + })) + + // 4.工具调用 + c.AppendLambda(compose.InvokableLambda(func(_ context.Context, in *toolHytPu.ProductUploadRequest) (*toolHytPu.ProductUploadResponse, error) { + toolRes, err := toolHytPu.Call(ctx, o.cfg.Tools.HytProductUpload, in) + return toolRes, err + })) + + // 5.结果数据映射 + c.AppendLambda(compose.InvokableLambda(func(_ context.Context, in *toolHytPu.ProductUploadResponse) (map[string]any, error) { + return map[string]any{ + "预览URL(货易通商品列表)": in.PreviewUrl, + "SPU编码": in.SpuNum, + "商品ID": in.Id, + }, nil + })) + + // 6.编译工作流 + return c.Compile(ctx) +} diff --git a/internal/domain/workflow/runtime/registry.go b/internal/domain/workflow/runtime/registry.go index c854053..bf840e6 100644 --- a/internal/domain/workflow/runtime/registry.go +++ b/internal/domain/workflow/runtime/registry.go @@ -11,7 +11,7 @@ import ( type Workflow interface { ID() string - Schema() map[string]any + // Schema() map[string]any Invoke(ctx context.Context, requireData *entitys.Recognize) (map[string]any, error) } diff --git a/internal/domain/workflow/zltx/order_after_reseller_batch.go b/internal/domain/workflow/zltx/order_after_reseller_batch.go index 9749bf7..5e5fa51 100644 --- a/internal/domain/workflow/zltx/order_after_reseller_batch.go +++ b/internal/domain/workflow/zltx/order_after_reseller_batch.go @@ -79,13 +79,13 @@ type OrderAfterSaleResellerBatchData struct { func (o *orderAfterSaleResellerBatch) ID() string { return "zltx.orderAfterSaleResellerBatch" } // Schema 返回入参约束(用于校验/表单生成) -func (o *orderAfterSaleResellerBatch) Schema() map[string]any { - return map[string]any{ - "type": "object", - "properties": map[string]any{"orderNumber": map[string]any{"type": "array", "items": map[string]any{"type": "string"}}}, - "required": []string{"orderNumber"}, - } -} +// func (o *orderAfterSaleResellerBatch) Schema() map[string]any { +// return map[string]any{ +// "type": "object", +// "properties": map[string]any{"orderNumber": map[string]any{"type": "array", "items": map[string]any{"type": "string"}}}, +// "required": []string{"orderNumber"}, +// } +// } // Invoke 调用原有编排工作流并规范化输出 func (o *orderAfterSaleResellerBatch) Invoke(ctx context.Context, rec *entitys.Recognize) (map[string]any, error) { diff --git a/internal/pkg/util/map.go b/internal/pkg/util/map.go new file mode 100644 index 0000000..5ca80c8 --- /dev/null +++ b/internal/pkg/util/map.go @@ -0,0 +1,14 @@ +package util + +import "encoding/json" + +// StructToMap 将结构体转换为 map[string]any +func StructToMap(v any) (map[string]any, error) { + b, err := json.Marshal(v) + if err != nil { + return nil, err + } + var m map[string]any + err = json.Unmarshal(b, &m) + return m, err +} diff --git a/internal/server/router/router.go b/internal/server/router/router.go index be861aa..3359984 100644 --- a/internal/server/router/router.go +++ b/internal/server/router/router.go @@ -34,6 +34,11 @@ func SetupRoutes(app *fiber.App, ChatService *services.ChatService, sessionServi c.Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS") c.Set("Access-Control-Allow-Headers", "Content-Type, Authorization") + // AI能力调用路由,设置不同的 CORS 头 + if strings.HasPrefix(c.Path(), "/api/v1/capability") { + c.Set("Access-Control-Allow-Headers", "Content-Type, X-Source-Key, X-Timestamp") + } + // 如果是预检请求(OPTIONS),直接返回 204 if c.Method() == "OPTIONS" { return c.SendStatus(fiber.StatusNoContent) // 204 @@ -88,7 +93,8 @@ func SetupRoutes(app *fiber.App, ChatService *services.ChatService, sessionServi r.Post("/chat/history/update/content", chatHist.UpdateContent) // 能力 - r.Post("/capability/product/ingest", capabilityService.ProductIngest) // 商品数据提取 + r.Post("/capability/product/ingest", capabilityService.ProductIngest) // 商品数据提取 + r.Post("/capability/product/upload/hyt", capabilityService.ProductUploadHyt) // 货易通商品数据上传 } func routerSocket(app *fiber.App, chatService *services.ChatService) { diff --git a/internal/services/capability.go b/internal/services/capability.go index aac842d..d0eb1fe 100644 --- a/internal/services/capability.go +++ b/internal/services/capability.go @@ -2,7 +2,10 @@ package services import ( "ai_scheduler/internal/config" + "ai_scheduler/internal/data/constants" errorcode "ai_scheduler/internal/data/error" + "ai_scheduler/internal/domain/workflow/runtime" + "ai_scheduler/internal/entitys" "ai_scheduler/internal/pkg/util" "ai_scheduler/internal/pkg/utils_ollama" "context" @@ -16,12 +19,14 @@ import ( // CapabilityService 统一回调入口 type CapabilityService struct { - cfg *config.Config + cfg *config.Config + workflowManager *runtime.Registry } -func NewCapabilityService(cfg *config.Config) *CapabilityService { +func NewCapabilityService(cfg *config.Config, workflowManager *runtime.Registry) *CapabilityService { return &CapabilityService{ - cfg: cfg, + cfg: cfg, + workflowManager: workflowManager, } } @@ -34,56 +39,11 @@ type ProductIngestReq struct { Timestamp int64 `json:"timestamp"` // 商品发布时间戳 } -const ( - // 货易通商品属性模板-中文 - HYTProductPropertyTemplateZH = `{ - "条码": "string", // 商品编号 - "分类名称": "string", // 商品分类 - "货品名称": "string", // 商品名称 - "货品编号": "string", // 商品编号 - "商品货号": "string", // 商品编号 - "品牌": "string", // 商品品牌 - "单位": "string", // 商品单位,若无则使用'个' - "规格参数": "string", // 商品规格参数 - "货品说明": "string", // 商品说明 - "保质期": "string", // 商品保质期 - "保质期单位": "string", // 商品保质期单位 - "链接": "string", // - "货品图片": ["string"], // 商品多图,取1-2个即可 - "电商销售价格": "decimal(10,2)", // 商品电商销售价格 - "销售价": "decimal(10,2)", // 商品销售价格 - "供应商报价": "decimal(10,2)", // 商品供应商报价 - "税率": "number%", // 商品税率 x% - "默认供应商": "", // 空即可 - "默认存放仓库": "", // 空即可 - "备注": "", // 备注 - "长": "string", // 商品长度,decimal(10,2)+单位 - "宽": "string", // 商品宽度,decimal(10,2)+单位 - "高": "string", // 商品高度,decimal(10,2)+单位 - "重量": "string", // 商品重量(kg) - "SPU名称": "string", // 商品SPU名称 - "SPU编码": "string" // 编码串,jd_{timestamp}_rand(1000-999) - }` - - SystemPrompt = `你是一个专业的商品属性提取助手,你的任务是根据用户输入提取商品的属性信息。 - 目标属性模板:%s。 - 最终输出格式为纯JSON字符串,键值对对应目标属性和提取到的属性值。 - 最终输出不要携带markdown标识,不要携带回车换行` -) - // ProductIngest 产品数据提取 func (s *CapabilityService) ProductIngest(c *fiber.Ctx) error { - // 读取头 - token := strings.TrimSpace(c.Get("X-Source-Key")) - ts := strings.TrimSpace(c.Get("X-Timestamp")) - - // 时间窗口校验 - if ts != "" && !util.IsInTimeWindow(ts, 5*time.Minute) { - return errorcode.AuthNotFound - } - // token校验 - if token == "" || token != "A7f9KQ3mP2X8LZC4R5e" { - return errorcode.KeyErr() + // 请求头校验 + if err := s.checkRequestHeader(c); err != nil { + return err } // 解析请求参数 @@ -91,7 +51,6 @@ func (s *CapabilityService) ProductIngest(c *fiber.Ctx) error { if err := c.BodyParser(&req); err != nil { return errorcode.ParamErr("invalid request body: %v", err) } - // 必要参数校验 if req.Text == "" { return errorcode.ParamErr("missing required fields") @@ -107,7 +66,7 @@ func (s *CapabilityService) ProductIngest(c *fiber.Ctx) error { res, err := client.Chat(context.Background(), []api.Message{ { Role: "system", - Content: fmt.Sprintf(SystemPrompt, HYTProductPropertyTemplateZH), + Content: fmt.Sprintf(constants.SystemPrompt, constants.HYTProductPropertyTemplateZH), }, { Role: "user", @@ -122,8 +81,50 @@ func (s *CapabilityService) ProductIngest(c *fiber.Ctx) error { return err } + // res.Message.Content Go中map会无序,交给前端解析 + // 解析模型输出 c.JSON(res.Message.Content) return nil } + +// checkRequestHeader 校验请求头 +func (s *CapabilityService) checkRequestHeader(c *fiber.Ctx) error { + // 读取头 + token := strings.TrimSpace(c.Get("X-Source-Key")) + ts := strings.TrimSpace(c.Get("X-Timestamp")) + + // 时间窗口校验 + if ts != "" && !util.IsInTimeWindow(ts, 5*time.Minute) { + return errorcode.AuthNotFound + } + // token校验 + if token == "" || token != "A7f9KQ3mP2X8LZC4R5e" { + return errorcode.KeyErr() + } + + return nil +} + +// ProductUploadHyt 商品上传至货易通 +func (s *CapabilityService) ProductUploadHyt(c *fiber.Ctx) error { + // 请求头校验 + if err := s.checkRequestHeader(c); err != nil { + return err + } + + // 获取 body json 串 + raw := append([]byte(nil), c.BodyRaw()...) + bodyStr := string(raw) + + // 调用eino工作流,实现商品上传到货易通 + workflowId := "hyt.productUpload" + rec := &entitys.Recognize{UserContent: &entitys.RecognizeUserContent{Text: bodyStr}} + res, err := s.workflowManager.Invoke(context.Background(), workflowId, rec) + if err != nil { + return err + } + + return c.JSON(res) +} From d0ba329024f343e9eb6c7d9fd1f1786c0f7dec6b Mon Sep 17 00:00:00 2001 From: fuzhongyun <15339891972@163.com> Date: Sat, 20 Dec 2025 18:42:37 +0800 Subject: [PATCH 52/66] =?UTF-8?q?feat:=201.=E8=B0=83=E6=95=B4=E5=B1=9E?= =?UTF-8?q?=E6=80=A7=E6=A8=A1=E6=9D=BF=202.=E4=BA=AC=E4=B8=9C=E5=95=86?= =?UTF-8?q?=E5=93=81=E6=8A=93=E5=8F=96=E5=B7=A5=E4=BD=9C=E6=B5=81=203.?= =?UTF-8?q?=E6=96=B0=E5=A2=9E=E6=89=80=E9=9C=80=E5=B7=A5=E5=85=B7?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- internal/data/constants/capability.go | 23 +- .../domain/tools/hyt/product_upload/client.go | 2 +- .../tools/hyt/product_upload/client_test.go | 8 +- .../tools/hyt/supplier_search/client.go | 61 +++++ .../domain/tools/hyt/supplier_search/types.go | 24 ++ .../tools/hyt/warehouse_search/client.go | 56 +++++ .../tools/hyt/warehouse_search/types.go | 14 ++ .../domain/workflow/hyt/product_upload.go | 237 +++++++++++++++++- internal/server/router/router.go | 4 +- internal/services/capability.go | 121 +++++++-- 10 files changed, 501 insertions(+), 49 deletions(-) create mode 100644 internal/domain/tools/hyt/supplier_search/client.go create mode 100644 internal/domain/tools/hyt/supplier_search/types.go create mode 100644 internal/domain/tools/hyt/warehouse_search/client.go create mode 100644 internal/domain/tools/hyt/warehouse_search/types.go diff --git a/internal/data/constants/capability.go b/internal/data/constants/capability.go index 4ee518b..9956336 100644 --- a/internal/data/constants/capability.go +++ b/internal/data/constants/capability.go @@ -9,7 +9,6 @@ const ( const ( SystemPrompt = ` #你是一个专业的商品属性提取助手,你的任务是根据用户输入提取商品的属性信息。 - 目标属性模板:%s。 1.最终输出格式为纯JSON字符串,键值对对应目标属性和提取到的属性值。 2.最终输出不要携带markdown标识,不要携带回车换行` ) @@ -29,15 +28,16 @@ const ( "货品说明": "string", // 商品说明 "保质期": "string", // 商品保质期 "保质期单位": "string", // 商品保质期单位 - "链接": "string", // + "链接": "string", // 商品链接 "货品图片": ["string"], // 商品多图,取1-2个即可 - "电商销售价格": "decimal(10,2)", // 商品电商销售价格 - "销售价": "decimal(10,2)", // 商品销售价格 - "供应商报价": "decimal(10,2)", // 商品供应商报价 - "税率": "number%", // 商品税率 x% - "默认供应商": "", // 空即可 - "默认存放仓库": "", // 空即可 - "备注": "", // 备注 + "电商销售价格": "string", // 商品电商销售价格 decimal(10,2) + "销售价": "string", // 商品销售价格 decimal(10,2) + "供应商报价": "string", // 商品供应商报价 decimal(10,2) + "税率": "string", // 商品税率 x% + "默认供应商": "string", // 供应商名称 + "默认存放仓库": "string", // 仓库名称 + "利润": "string", // 商品利润 decimal(10,2) + "备注": "string", // 备注 "长": "string", // 商品长度,decimal(10,2)+单位 "宽": "string", // 商品宽度,decimal(10,2)+单位 "高": "string", // 商品高度,decimal(10,2)+单位 @@ -172,3 +172,8 @@ const ( const ( HYTProductListPageURL = "https://gateway.dev.cdlsxd.cn/sw//#/goods/goodsManage" ) + +// 缓存key +const ( + CapabilityProductIngestCacheKey = "ai_scheduler:capability:product_ingest:%s" +) diff --git a/internal/domain/tools/hyt/product_upload/client.go b/internal/domain/tools/hyt/product_upload/client.go index 4924b50..1cacbd9 100644 --- a/internal/domain/tools/hyt/product_upload/client.go +++ b/internal/domain/tools/hyt/product_upload/client.go @@ -20,7 +20,7 @@ func Call(ctx context.Context, cfg config.ToolConfig, toolReq *ProductUploadRequ req := l_request.Request{ Method: "Post", - Url: "http://120.55.12.245:8100/api/v1/goods/supplier/batch/add/complete", + Url: "https://gateway.dev.cdlsxd.cn/goods-admin/api/v1/goods/supplier/batch/add/complete", Json: apiReq, } res, err := req.Send() diff --git a/internal/domain/tools/hyt/product_upload/client_test.go b/internal/domain/tools/hyt/product_upload/client_test.go index 2f4b01b..5e8b111 100644 --- a/internal/domain/tools/hyt/product_upload/client_test.go +++ b/internal/domain/tools/hyt/product_upload/client_test.go @@ -19,10 +19,10 @@ func Test_Call(t *testing.T) { GoodsList: []Goods{ { GoodsInfo: GoodsInfo{ - Title: "Apple iPhone 17 Pro Max 星宇橙色 256GB", - Brand: "Apple/苹果", - Category: "手机", - CostPrice: 9999.00, + Title: "Apple iPhone 17 Pro Max 星宇橙色 256GB", + Brand: "Apple/苹果", + Category: "手机", + // CostPrice: 9999.00, GoodsAttributes: "CPU型号:A19 Pro;操作系统:iOS;机身存储:256GB;屏幕尺寸:6.86英寸;屏幕材质:OLED直屏;屏幕技术:视网膜XDR;后置摄像头:4800万像素三主摄系统(主摄4800万+超广角4800万+长焦4800万);前置摄像头:1800万像素;网络支持:5G双卡双待(移动/联通/电信);生物识别:人脸识别;防水等级:IP68;充电功率:40W;无线充电:支持;机身尺寸:163.4mm×78.0mm×8.75mm;机身重量:231g;机身颜色:星宇橙色;特征特质:轻薄、防水防尘、无线充电、NFC、磁吸无线充", GoodsBarCode: "10181383848993", GoodsIllustration: "Apple/苹果 iPhone 17 Pro Max 【需当面激活】支持移动联通电信 5G 双卡双待手机 星宇橙色 256GB 官方标配。搭载A19 Pro芯片,6.86英寸OLED视网膜XDR直屏,4800万像素三主摄系统,支持5G双卡双待,IP68防水防尘,40W有线充电,支持无线充电和磁吸充电。", diff --git a/internal/domain/tools/hyt/supplier_search/client.go b/internal/domain/tools/hyt/supplier_search/client.go new file mode 100644 index 0000000..ebffe0b --- /dev/null +++ b/internal/domain/tools/hyt/supplier_search/client.go @@ -0,0 +1,61 @@ +package supplier_search + +import ( + "ai_scheduler/internal/pkg/l_request" + "context" + "encoding/json" + "errors" + "fmt" +) + +func Call(ctx context.Context, name string) (int, error) { + if name == "" { + return 0, errors.New("supplier name is empty") + } + + reqBody := SearchRequest{ + Page: 1, + Limit: 1, + Search: SearchCondition{ + Name: name, + }, + } + + apiReq := make(map[string]interface{}) + bytes, _ := json.Marshal(reqBody) + _ = json.Unmarshal(bytes, &apiReq) + + req := l_request.Request{ + Method: "Post", + Url: "https://gateway.dev.cdlsxd.cn/goods-admin/api/v1/supplier/list", + Json: apiReq, + Headers: map[string]string{ + "User-Agent": "Apifox/1.0.0 (https://apifox.com)", + "Content-Type": "application/json", + }, + } + + res, err := req.Send() + if err != nil { + return 0, err + } + + if res.StatusCode != 200 { + return 0, fmt.Errorf("supplier search failed with status code: %d", res.StatusCode) + } + + var resData SearchResponse + if err := json.Unmarshal([]byte(res.Text), &resData); err != nil { + return 0, fmt.Errorf("failed to parse supplier search response: %w", err) + } + + if resData.Code != 200 { + return 0, fmt.Errorf("supplier search business error: %s", resData.Msg) + } + + if len(resData.Data.List) == 0 { + return 0, fmt.Errorf("supplier not found: %s", name) + } + + return resData.Data.List[0].ID, nil +} diff --git a/internal/domain/tools/hyt/supplier_search/types.go b/internal/domain/tools/hyt/supplier_search/types.go new file mode 100644 index 0000000..46a452c --- /dev/null +++ b/internal/domain/tools/hyt/supplier_search/types.go @@ -0,0 +1,24 @@ +package supplier_search + +type SearchRequest struct { + Page int `json:"page"` + Limit int `json:"limit"` + Search SearchCondition `json:"search"` +} + +type SearchCondition struct { + Name string `json:"name"` +} + +type SearchResponse struct { + Code int `json:"code"` + Msg string `json:"msg"` + Data struct { + List []SupplierInfo `json:"list"` + } `json:"data"` +} + +type SupplierInfo struct { + ID int `json:"id"` + Name string `json:"name"` +} diff --git a/internal/domain/tools/hyt/warehouse_search/client.go b/internal/domain/tools/hyt/warehouse_search/client.go new file mode 100644 index 0000000..7b0190b --- /dev/null +++ b/internal/domain/tools/hyt/warehouse_search/client.go @@ -0,0 +1,56 @@ +package warehouse_search + +import ( + "ai_scheduler/internal/pkg/l_request" + "context" + "encoding/json" + "fmt" +) + +func Call(ctx context.Context, name string) (int, error) { + if name == "" { + // 如果没有仓库名,返回0,不报错,由上层业务决定是否允许 + return 0, nil + } + + // GET 请求参数 + params := map[string]string{ + "name": name, + "page": "1", + "limit": "1", + } + + req := l_request.Request{ + Method: "Get", + Url: "https://gateway.dev.cdlsxd.cn/goods-admin/api/v1/warehouse/list", + Params: params, + Headers: map[string]string{ + "User-Agent": "Apifox/1.0.0 (https://apifox.com)", + "Content-Type": "application/json", + }, + } + + res, err := req.Send() + if err != nil { + return 0, err + } + + if res.StatusCode != 200 { + return 0, fmt.Errorf("warehouse search failed with status code: %d", res.StatusCode) + } + + var resData SearchResponse + if err := json.Unmarshal([]byte(res.Text), &resData); err != nil { + return 0, fmt.Errorf("failed to parse warehouse search response: %w", err) + } + + if resData.Code != 200 { + return 0, fmt.Errorf("warehouse search business error: %s", resData.Msg) + } + + if len(resData.Data.List) == 0 { + return 0, fmt.Errorf("warehouse not found: %s", name) + } + + return resData.Data.List[0].ID, nil +} diff --git a/internal/domain/tools/hyt/warehouse_search/types.go b/internal/domain/tools/hyt/warehouse_search/types.go new file mode 100644 index 0000000..a5ae237 --- /dev/null +++ b/internal/domain/tools/hyt/warehouse_search/types.go @@ -0,0 +1,14 @@ +package warehouse_search + +type SearchResponse struct { + Code int `json:"code"` + Msg string `json:"msg"` + Data struct { + List []WarehouseInfo `json:"list"` + } `json:"data"` +} + +type WarehouseInfo struct { + ID int `json:"id"` + Name string `json:"name"` +} diff --git a/internal/domain/workflow/hyt/product_upload.go b/internal/domain/workflow/hyt/product_upload.go index 6ab98f1..4b85f37 100644 --- a/internal/domain/workflow/hyt/product_upload.go +++ b/internal/domain/workflow/hyt/product_upload.go @@ -3,12 +3,17 @@ package hyt import ( "ai_scheduler/internal/config" "ai_scheduler/internal/data/constants" - toolHytPu "ai_scheduler/internal/domain/tools/hyt/product_upload" + toolPu "ai_scheduler/internal/domain/tools/hyt/product_upload" + toolSs "ai_scheduler/internal/domain/tools/hyt/supplier_search" + toolWs "ai_scheduler/internal/domain/tools/hyt/warehouse_search" "ai_scheduler/internal/domain/workflow/runtime" "ai_scheduler/internal/entitys" "context" "encoding/json" "fmt" + "strconv" + "strings" + "sync" eino_ollama "github.com/cloudwego/eino-ext/components/model/ollama" "github.com/cloudwego/eino/components/prompt" @@ -36,8 +41,8 @@ type ProductUploadWorkflowInput struct { func (o *productUpload) ID() string { return WorkflowID } func (o *productUpload) Invoke(ctx context.Context, rec *entitys.Recognize) (map[string]any, error) { - // 构建工作流 - chain, err := o.buildWorkflow(ctx) + // 构建工作流 (使用 V2 Graph 版本) + runnable, err := o.buildWorkflowV2(ctx) if err != nil { return nil, err } @@ -46,17 +51,58 @@ func (o *productUpload) Invoke(ctx context.Context, rec *entitys.Recognize) (map Text: rec.UserContent.Text, } // 工作流过程调用 - output, err := chain.Invoke(ctx, o.data) + output, err := runnable.Invoke(ctx, o.data) if err != nil { return nil, err } fmt.Printf("workflow output: %v\n", output) - // 不关心输出,全部在途中输出 return output, nil } +// ProductIngestData 对应 HYTProductPropertyTemplateZH 的结构 +type ProductIngestData struct { + BarCode string `json:"条码"` + CategoryName string `json:"分类名称"` + GoodsName string `json:"货品名称"` + GoodsNum string `json:"货品编号"` + GoodsArticleNum string `json:"商品货号"` + Brand string `json:"品牌"` + Unit string `json:"单位"` + Specs string `json:"规格参数"` + Description string `json:"货品说明"` + ShelfLife string `json:"保质期"` + ShelfLifeUnit string `json:"保质期单位"` + Link string `json:"链接"` + Images []string `json:"货品图片"` + EPrice string `json:"电商销售价格"` + SalesPrice string `json:"销售价"` + SupplierPrice string `json:"供应商报价"` + TaxRate string `json:"税率"` + SupplierName string `json:"默认供应商"` + WarehouseName string `json:"默认存放仓库"` + Remark string `json:"备注"` + Length string `json:"长"` + Width string `json:"宽"` + Height string `json:"高"` + Weight string `json:"重量"` + SpuName string `json:"SPU名称"` + SpuCode string `json:"SPU编码"` + Profit string `json:"利润"` +} + +// ProductUploadContext Graph 执行上下文状态 +type ProductUploadContext struct { + mu *sync.Mutex + InputText string + IngestData *ProductIngestData + UploadReq *toolPu.ProductUploadRequest + SupplierName string + WarehouseName string + UploadResp *toolPu.ProductUploadResponse +} + func (o *productUpload) buildWorkflow(ctx context.Context) (compose.Runnable[*ProductUploadWorkflowInput, map[string]any], error) { // 定义工作流 c := compose.NewChain[*ProductUploadWorkflowInput, map[string]any]() @@ -73,7 +119,7 @@ func (o *productUpload) buildWorkflow(ctx context.Context) (compose.Runnable[*Pr return nil, err } - // 1. 构建参LLM数映射提示词 + // 1. 构建参数LLM数映射提示词 c.AppendChatTemplate(prompt.FromMessages( schema.FString, schema.SystemMessage("你是一个专业的商品参数解析器,你需要根据用户输入的商品描述,解析出商品的目标参数。"), @@ -84,8 +130,8 @@ func (o *productUpload) buildWorkflow(ctx context.Context) (compose.Runnable[*Pr c.AppendChatModel(paramMappingModel) // 3.工具参数整理 - c.AppendLambda(compose.InvokableLambda(func(_ context.Context, in *schema.Message) (*toolHytPu.ProductUploadRequest, error) { - toolReq := &toolHytPu.ProductUploadRequest{} + c.AppendLambda(compose.InvokableLambda(func(_ context.Context, in *schema.Message) (*toolPu.ProductUploadRequest, error) { + toolReq := &toolPu.ProductUploadRequest{} if err := json.Unmarshal([]byte(in.Content), toolReq); err != nil { return nil, err } @@ -93,13 +139,13 @@ func (o *productUpload) buildWorkflow(ctx context.Context) (compose.Runnable[*Pr })) // 4.工具调用 - c.AppendLambda(compose.InvokableLambda(func(_ context.Context, in *toolHytPu.ProductUploadRequest) (*toolHytPu.ProductUploadResponse, error) { - toolRes, err := toolHytPu.Call(ctx, o.cfg.Tools.HytProductUpload, in) + c.AppendLambda(compose.InvokableLambda(func(_ context.Context, in *toolPu.ProductUploadRequest) (*toolPu.ProductUploadResponse, error) { + toolRes, err := toolPu.Call(ctx, o.cfg.Tools.HytProductUpload, in) return toolRes, err })) // 5.结果数据映射 - c.AppendLambda(compose.InvokableLambda(func(_ context.Context, in *toolHytPu.ProductUploadResponse) (map[string]any, error) { + c.AppendLambda(compose.InvokableLambda(func(_ context.Context, in *toolPu.ProductUploadResponse) (map[string]any, error) { return map[string]any{ "预览URL(货易通商品列表)": in.PreviewUrl, "SPU编码": in.SpuNum, @@ -110,3 +156,172 @@ func (o *productUpload) buildWorkflow(ctx context.Context) (compose.Runnable[*Pr // 6.编译工作流 return c.Compile(ctx) } + +// buildWorkflowV2 构建基于 Graph 的并行工作流 +func (o *productUpload) buildWorkflowV2(ctx context.Context) (compose.Runnable[*ProductUploadWorkflowInput, map[string]any], error) { + g := compose.NewGraph[*ProductUploadWorkflowInput, map[string]any]() + + // 1. DataMapping 节点: 解析 JSON -> 填充基础 Request -> 提取供应商/仓库名 + g.AddLambdaNode("data_mapping", compose.InvokableLambda(func(ctx context.Context, in *ProductUploadWorkflowInput) (*ProductUploadContext, error) { + state := &ProductUploadContext{ + mu: &sync.Mutex{}, // 初始化锁 + InputText: in.Text, + UploadReq: &toolPu.ProductUploadRequest{ + GoodsList: make([]toolPu.Goods, 1), // 初始化一个商品 + }, + } + + // 解析用户输入的中文 JSON + var ingestData ProductIngestData + if err := json.Unmarshal([]byte(in.Text), &ingestData); err != nil { + return nil, fmt.Errorf("解析商品数据失败: %w", err) + } + state.IngestData = &ingestData + state.SupplierName = ingestData.SupplierName + state.WarehouseName = ingestData.WarehouseName + + // 映射字段到 UploadReq + goodsInfo := &state.UploadReq.GoodsList[0].GoodsInfo + goodsInfo.Title = ingestData.GoodsName + goodsInfo.Brand = ingestData.Brand + goodsInfo.Category = ingestData.CategoryName + goodsInfo.GoodsBarCode = ingestData.BarCode + goodsInfo.GoodsNum = ingestData.GoodsNum + if goodsInfo.GoodsNum == "" { + goodsInfo.GoodsNum = ingestData.GoodsArticleNum + } + goodsInfo.Unit = ingestData.Unit + goodsInfo.GoodsAttributes = ingestData.Specs + goodsInfo.Introduction = ingestData.Description + goodsInfo.SpuName = ingestData.SpuName + goodsInfo.SpuNum = ingestData.SpuCode + goodsInfo.Weight = ingestData.Weight + + // 数值处理 + if val, err := strconv.ParseFloat(strings.TrimSuffix(ingestData.SalesPrice, "元"), 64); err == nil { + goodsInfo.SalesPrice = val + } + if val, err := strconv.ParseFloat(strings.TrimSuffix(ingestData.EPrice, "元"), 64); err == nil { + goodsInfo.Price = val // 假设电商价为市场价 + } + // 价格兼容 + if goodsInfo.CostPrice == 0 { + goodsInfo.CostPrice = goodsInfo.Price + } + // 税率处理 "13%" -> 13 + taxStr := strings.TrimSuffix(strings.TrimSuffix(ingestData.TaxRate, "%"), " ") + if val, err := strconv.Atoi(taxStr); err == nil { + goodsInfo.TaxRate = val + state.UploadReq.TaxRate = val + } + // 利润处理 + if val, err := strconv.ParseFloat(strings.TrimSuffix(ingestData.Profit, "元"), 64); err == nil { + state.UploadReq.Profit = val + } + + // 图片处理 + for i, imgUrl := range ingestData.Images { + state.UploadReq.GoodsList[0].GoodsMediaList = append(state.UploadReq.GoodsList[0].GoodsMediaList, toolPu.GoodsMedia{ + Url: imgUrl, + Type: 1, // 图片 + Sort: i, + }) + } + + // 默认值字段 + goodsInfo.IsComposeGoods = 2 + goodsInfo.IsBind = 0 + goodsInfo.IsHot = 2 + state.UploadReq.IsDefaultWarehouse = 1 + state.UploadReq.Sort = 1 + + return state, nil + })) + + // 2. 获取供应商ID 节点 + g.AddLambdaNode("get_supplier_id", compose.InvokableLambda(func(ctx context.Context, state *ProductUploadContext) (*ProductUploadContext, error) { + if state.SupplierName != "" { + supplierId, err := toolSs.Call(ctx, state.SupplierName) + if err != nil { + // 记录日志,但不阻断流程,可能允许 ID 为 0 + fmt.Printf("warning: failed to get supplier id for %s: %v\n", state.SupplierName, err) + } else { + state.mu.Lock() + defer state.mu.Unlock() + state.UploadReq.SupplierId = supplierId + } + } + return state, nil + })) + + // 3. 获取仓库ID 节点 + g.AddLambdaNode("get_warehouse_id", compose.InvokableLambda(func(ctx context.Context, state *ProductUploadContext) (*ProductUploadContext, error) { + if state.WarehouseName != "" { + warehouseId, err := toolWs.Call(ctx, state.WarehouseName) + if err != nil { + fmt.Printf("warning: failed to get warehouse id for %s: %v\n", state.WarehouseName, err) + } else { + state.mu.Lock() + defer state.mu.Unlock() + state.UploadReq.WarehouseId = warehouseId + } + } + return state, nil + })) + + // 4. 合并/同步节点 + g.AddLambdaNode("merge_node", compose.InvokableLambda(func(ctx context.Context, state *ProductUploadContext) (*ProductUploadContext, error) { + // 可以在这里做最终校验,例如必须有 SupplierId + if state.UploadReq.SupplierId == 0 { + return nil, fmt.Errorf("供应商获取失败") + } + return state, nil + })) + + // 5. 上传节点 + g.AddLambdaNode("upload_product", compose.InvokableLambda(func(ctx context.Context, state *ProductUploadContext) (*ProductUploadContext, error) { + toolRes, err := toolPu.Call(ctx, o.cfg.Tools.HytProductUpload, state.UploadReq) + if err != nil { + return nil, err + } + state.UploadResp = toolRes + return state, nil + })) + + // 6. 结果格式化节点 + g.AddLambdaNode("format_output", compose.InvokableLambda(func(ctx context.Context, state *ProductUploadContext) (map[string]any, error) { + if state.UploadResp == nil { + return nil, fmt.Errorf("upload response is nil") + } + return map[string]any{ + "预览URL(货易通商品列表)": state.UploadResp.PreviewUrl, + "SPU编码": state.UploadResp.SpuNum, + "商品ID": state.UploadResp.Id, + }, nil + })) + + // 构建边 + // Start -> Mapping + g.AddEdge(compose.START, "data_mapping") + + // 串行化执行以规避 Eino 指针合并问题 + // Mapping -> Supplier + g.AddEdge("data_mapping", "get_supplier_id") + + // Supplier -> Warehouse + g.AddEdge("get_supplier_id", "get_warehouse_id") + + // Warehouse -> Merge (虽然串行了,保留 Merge 节点做校验) + g.AddEdge("get_warehouse_id", "merge_node") + + // Merge -> Upload + g.AddEdge("merge_node", "upload_product") + + // Upload -> Format + g.AddEdge("upload_product", "format_output") + + // Format -> END + g.AddEdge("format_output", compose.END) + + return g.Compile(ctx) +} diff --git a/internal/server/router/router.go b/internal/server/router/router.go index 3359984..15e1554 100644 --- a/internal/server/router/router.go +++ b/internal/server/router/router.go @@ -93,8 +93,8 @@ func SetupRoutes(app *fiber.App, ChatService *services.ChatService, sessionServi r.Post("/chat/history/update/content", chatHist.UpdateContent) // 能力 - r.Post("/capability/product/ingest", capabilityService.ProductIngest) // 商品数据提取 - r.Post("/capability/product/upload/hyt", capabilityService.ProductUploadHyt) // 货易通商品数据上传 + r.Post("/capability/product/ingest", capabilityService.ProductIngest) // 商品数据提取 + r.Post("/capability/product/ingest/:thread_id/confirm", capabilityService.ProductIngestConfirm) // 商品数据提取确认 } func routerSocket(app *fiber.App, chatService *services.ChatService) { diff --git a/internal/services/capability.go b/internal/services/capability.go index d0eb1fe..89d97cc 100644 --- a/internal/services/capability.go +++ b/internal/services/capability.go @@ -8,39 +8,55 @@ import ( "ai_scheduler/internal/entitys" "ai_scheduler/internal/pkg/util" "ai_scheduler/internal/pkg/utils_ollama" + "ai_scheduler/utils" "context" + "encoding/json" "fmt" "strings" "time" + hytWorkflow "ai_scheduler/internal/domain/workflow/hyt" + "github.com/gofiber/fiber/v2" + "github.com/google/uuid" "github.com/ollama/ollama/api" + "github.com/redis/go-redis/v9" ) // CapabilityService 统一回调入口 type CapabilityService struct { cfg *config.Config workflowManager *runtime.Registry + rdsCli *redis.Client } -func NewCapabilityService(cfg *config.Config, workflowManager *runtime.Registry) *CapabilityService { +func NewCapabilityService(cfg *config.Config, workflowManager *runtime.Registry, rdb *utils.Rdb) *CapabilityService { return &CapabilityService{ cfg: cfg, workflowManager: workflowManager, + rdsCli: rdb.Rdb, } } // 产品数据提取入参 type ProductIngestReq struct { - Url string `json:"url"` // 商品详情页URL - Title string `json:"title"` // 商品标题 - Text string `json:"text"` // 商品描述 - Images []string `json:"images"` // 商品图片URL列表 - Timestamp int64 `json:"timestamp"` // 商品发布时间戳 + SysId string `json:"sys_id"` // 业务系统ID - 当前仅支持货易通(hyt) + Url string `json:"url"` // 商品详情页URL + Title string `json:"title"` // 商品标题 + Text string `json:"text"` // 商品描述 + Images []string `json:"images"` // 商品图片URL列表 +} + +type ProductIngestResp struct { + ThreadId string `json:"thread_id"` // 线程ID,后续确认调用时需要 + SysId string `json:"sys_id"` // 业务系统ID + MetaData any `json:"meta"` // 元数据 + Draft string `json:"draft"` // 草稿数据,后续确认调用时需要 } // ProductIngest 产品数据提取 func (s *CapabilityService) ProductIngest(c *fiber.Ctx) error { + ctx := context.Background() // 请求头校验 if err := s.checkRequestHeader(c); err != nil { return err @@ -52,21 +68,33 @@ func (s *CapabilityService) ProductIngest(c *fiber.Ctx) error { return errorcode.ParamErr("invalid request body: %v", err) } // 必要参数校验 - if req.Text == "" { + if req.Text == "" || req.SysId == "" { return errorcode.ParamErr("missing required fields") } + // 映射目标系统商品属性中文模板 + var sysProductPropertyTemplateZH string + switch req.SysId { + case "hyt": // 货易通 + sysProductPropertyTemplateZH = constants.HYTProductPropertyTemplateZH + default: + return errorcode.ParamErr("invalid sys_id") + } + // 模型调用 client, cleanup, err := utils_ollama.NewClient(s.cfg) if err != nil { return err } defer cleanup() - - res, err := client.Chat(context.Background(), []api.Message{ + res, err := client.Chat(ctx, []api.Message{ { Role: "system", - Content: fmt.Sprintf(constants.SystemPrompt, constants.HYTProductPropertyTemplateZH), + Content: constants.SystemPrompt, + }, + { + Role: "assistant", + Content: fmt.Sprintf("目标属性模板:%s。", sysProductPropertyTemplateZH), }, { Role: "user", @@ -81,10 +109,23 @@ func (s *CapabilityService) ProductIngest(c *fiber.Ctx) error { return err } - // res.Message.Content Go中map会无序,交给前端解析 + // 生成thread_id + threadId := uuid.NewString() + resp := &ProductIngestResp{ + ThreadId: threadId, + SysId: req.SysId, + MetaData: req, + Draft: res.Message.Content, // Go中map会无序,交给前端解析 + } + respJson, _ := json.Marshal(resp) + + // 存redis缓存 + if err = s.rdsCli.Set(ctx, fmt.Sprintf(constants.CapabilityProductIngestCacheKey, threadId), respJson, 30*time.Minute).Err(); err != nil { + return err + } // 解析模型输出 - c.JSON(res.Message.Content) + c.JSON(resp) return nil } @@ -97,7 +138,7 @@ func (s *CapabilityService) checkRequestHeader(c *fiber.Ctx) error { // 时间窗口校验 if ts != "" && !util.IsInTimeWindow(ts, 5*time.Minute) { - return errorcode.AuthNotFound + // return errorcode.AuthNotFound } // token校验 if token == "" || token != "A7f9KQ3mP2X8LZC4R5e" { @@ -107,21 +148,57 @@ func (s *CapabilityService) checkRequestHeader(c *fiber.Ctx) error { return nil } -// ProductUploadHyt 商品上传至货易通 -func (s *CapabilityService) ProductUploadHyt(c *fiber.Ctx) error { +type ProductIngestConfirmReq struct { + ThreadId string `json:"thread_id"` // 线程ID + Confirmed string `json:"confirmed"` // 已确认数据json字符串 +} + +// ProductIngestConfirm 商品数据提取确认 +func (s *CapabilityService) ProductIngestConfirm(c *fiber.Ctx) error { + ctx := context.Background() + // 请求头校验 if err := s.checkRequestHeader(c); err != nil { return err } + // 获取路径参数中的 thread_id + threadId := c.Params("thread_id") + if threadId == "" { + return errorcode.ParamErr("missing required fields") + } + // 解析请求参数 body + req := ProductIngestConfirmReq{} + if err := c.BodyParser(&req); err != nil { + return errorcode.ParamErr("invalid request body: %v", err) + } + // 必要参数校验 + if req.Confirmed == "" || threadId == "" { + return errorcode.ParamErr("missing required fields") + } - // 获取 body json 串 - raw := append([]byte(nil), c.BodyRaw()...) - bodyStr := string(raw) + // 校验线程ID是否存在 + resp, err := s.rdsCli.Get(ctx, fmt.Sprintf(constants.CapabilityProductIngestCacheKey, threadId)).Result() + if err != nil { + return errorcode.ParamErr("invalid thread_id") + } + var respData ProductIngestResp + if err = json.Unmarshal([]byte(resp), &respData); err != nil { + return errorcode.ParamErr("invalid thread_id data") + } - // 调用eino工作流,实现商品上传到货易通 - workflowId := "hyt.productUpload" - rec := &entitys.Recognize{UserContent: &entitys.RecognizeUserContent{Text: bodyStr}} - res, err := s.workflowManager.Invoke(context.Background(), workflowId, rec) + // 映射目标系统工作流ID + var workflowId string + switch respData.SysId { + // 货易通 + case "hyt": + workflowId = hytWorkflow.WorkflowID + default: + return errorcode.ParamErr("invalid sys_id") + } + + // 调用eino工作流,实现商品上传到目标系统 + rec := &entitys.Recognize{UserContent: &entitys.RecognizeUserContent{Text: req.Confirmed}} + res, err := s.workflowManager.Invoke(ctx, workflowId, rec) if err != nil { return err } From d8df571cce11f718ea53f0bff3bad5b3b189b962 Mon Sep 17 00:00:00 2001 From: fuzhongyun <15339891972@163.com> Date: Mon, 22 Dec 2025 11:14:15 +0800 Subject: [PATCH 53/66] =?UTF-8?q?feat:=201.=E5=A2=9E=E5=8A=A0=20eino=20too?= =?UTF-8?q?l=20=E7=9B=B8=E5=85=B3=E9=85=8D=E7=BD=AE=EF=BC=8C=E8=B4=A7?= =?UTF-8?q?=E6=98=93=E9=80=9A=E5=95=86=E5=93=81=E4=B8=8A=E4=BC=A0=E5=8F=82?= =?UTF-8?q?=E6=95=B0=E9=85=8D=E7=BD=AE=E5=8C=96=202.=20eino=20tool=20?= =?UTF-8?q?=E6=B3=A8=E5=86=8C=E6=96=B9=E6=B3=95=E8=B0=83=E6=95=B4?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- config/config_env.yaml | 16 +- internal/config/config.go | 13 +- internal/data/constants/capability.go | 136 +---------------- .../domain/tools/hyt/product_upload/client.go | 16 +- .../tools/hyt/product_upload/client_test.go | 3 +- .../tools/hyt/supplier_search/client.go | 19 ++- .../tools/hyt/warehouse_search/client.go | 15 +- internal/domain/tools/registry.go | 31 ++-- .../domain/workflow/hyt/product_upload.go | 138 +++++++----------- internal/domain/workflow/provider_set.go | 4 +- internal/domain/workflow/registry.go | 6 +- internal/domain/workflow/runtime/registry.go | 6 +- 12 files changed, 155 insertions(+), 248 deletions(-) diff --git a/config/config_env.yaml b/config/config_env.yaml index aa8b07b..a10a7d4 100644 --- a/config/config_env.yaml +++ b/config/config_env.yaml @@ -74,11 +74,19 @@ tools: zltxOrderAfterSaleResellerBatch: enabled: true base_url: "https://gateway.dev.cdlsxd.cn/zltx_api/admin/afterSales/reseller_pre_ai" + +# eino tool 配置 +eino_tools: + # 货易通商品上传 hytProductUpload: - enabled: true - base_url: "https://gateway.dev.cdlsxd.cn/zltx_api/admin/oursProduct" - - + base_url: "https://gateway.dev.cdlsxd.cn/goods-admin/api/v1/goods/supplier/batch/add/complete" + add_url: "https://gateway.dev.cdlsxd.cn/sw//#/goods/goodsManage" + # 货易通供应商查询 + hytSupplierSearch: + base_url: "https://gateway.dev.cdlsxd.cn/goods-admin/api/v1/supplier/list" + # 货易通仓库查询 + hytWarehouseSearch: + base_url: "https://gateway.dev.cdlsxd.cn/goods-admin/api/v1/warehouse/list" default_prompt: img_recognize: diff --git a/internal/config/config.go b/internal/config/config.go index 0c3f4dd..5cbb506 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -15,6 +15,7 @@ type Config struct { Coze CozeConfig `mapstructure:"coze"` Sys SysConfig `mapstructure:"sys"` Tools ToolsConfig `mapstructure:"tools"` + EinoTools EinoToolsConfig `mapstructure:"eino_tools"` Logging LoggingConfig `mapstructure:"logging"` Redis Redis `mapstructure:"redis"` DB DB `mapstructure:"db"` @@ -141,8 +142,6 @@ type ToolsConfig struct { CozeExpress ToolConfig `mapstructure:"cozeExpress"` // Coze 公司查询工具 CozeCompany ToolConfig `mapstructure:"cozeCompany"` - // 货易通商品上传 - HytProductUpload ToolConfig `mapstructure:"hytProductUpload"` } // ToolConfig 单个工具配置 @@ -155,6 +154,16 @@ type ToolConfig struct { AddURL string `mapstructure:"add_url"` } +// EinoToolsConfig eino tool 配置 +type EinoToolsConfig struct { + // 货易通商品上传 + HytProductUpload ToolConfig `mapstructure:"hytProductUpload"` + // 货易通供应商查询 + HytSupplierSearch ToolConfig `mapstructure:"hytSupplierSearch"` + // 货易通仓库查询 + HytWarehouseSearch ToolConfig `mapstructure:"hytWarehouseSearch"` +} + // LoggingConfig 日志配置 type LoggingConfig struct { Level string `mapstructure:"level"` diff --git a/internal/data/constants/capability.go b/internal/data/constants/capability.go index 9956336..fab9db7 100644 --- a/internal/data/constants/capability.go +++ b/internal/data/constants/capability.go @@ -32,11 +32,6 @@ const ( "货品图片": ["string"], // 商品多图,取1-2个即可 "电商销售价格": "string", // 商品电商销售价格 decimal(10,2) "销售价": "string", // 商品销售价格 decimal(10,2) - "供应商报价": "string", // 商品供应商报价 decimal(10,2) - "税率": "string", // 商品税率 x% - "默认供应商": "string", // 供应商名称 - "默认存放仓库": "string", // 仓库名称 - "利润": "string", // 商品利润 decimal(10,2) "备注": "string", // 备注 "长": "string", // 商品长度,decimal(10,2)+单位 "宽": "string", // 商品宽度,decimal(10,2)+单位 @@ -44,135 +39,14 @@ const ( "重量": "string", // 商品重量(kg) "SPU名称": "string", // 商品SPU名称 "SPU编码": "string" // 编码串,jd_{timestamp}_rand(1000-999) + "供应商报价": "string", // 商品供应商报价 decimal(10,2) + "税率": "string", // 商品税率 x% + "利润": "string", // 商品利润 decimal(10,2) + "默认供应商": "string", // 供应商名称 + "默认存放仓库": "string", // 仓库名称 }` ) -// 商品属性模板 -const ( - HYTProductPropertyTemplate = `{ - "important_data": { - "type": "object", - "properties": { - "supplier_name": { - "type": "string", - "description": "供应商名称" - }, - "warehouse_name": { - "type": "string", - "description": "仓库名称" - }, - "profit": { - "type": "float64", - "description": "利润 decimal(10,2)" - }, - "tax_rate": { - "type": "integer", - "description": "税率 (x)%" - } - } - }, - "goods_info": { - "type": "object", - "properties": { - "title": { - "type": "string", - "description": "商品名称" - }, - "brand": { - "type": "string", - "description": "品牌" - }, - "category": { - "type": "string", - "description": "分类" - }, - "price": { - "type": "float64", - "description": "市场价 decimal(10,2)" - }, - "sales_price": { - "type": "float64", - "description": "建议销售价 decimal(10,2)" - }, - "discount": { - "type": "integer", - "description": "折扣百分比 公式:(市场价-建议销售价)/市场价*100" - }, - "goods_attributes": { - "type": "string", - "description": "商品属性" - }, - "goods_bar_code": { - "type": "string", - "description": "商品条码" - }, - "goods_illustration": { - "type": "string", - "description": "商品插图" - }, - "goods_num": { - "type": "string", - "description": "商品编号" - }, - "introduction": { - "type": "string", - "description": "商品介绍" - }, - "spu_name": { - "type": "string", - "description": "SPU名称" - }, - "spu_num": { - "type": "string", - "description": "SPU编号" - }, - "stock": { - "type": "integer", - "description": "库存" - }, - "tax_rate": { - "type": "integer", - "description": "税率" - }, - "unit": { - "type": "string", - "description": "单位" - }, - "weight": { - "type": "string", - "description": "重量" - } - } - }, - "goods_media_list": { - "type": "array", - "description": "商品媒体文件列表", - "items": { - "type": "object", - "properties": { - "url": { - "type": "string", - "description": "媒体文件URL" - }, - "type": { - "type": "integer", - "description": "媒体类型(1:图片, 2:视频)" - }, - "sort": { - "type": "integer", - "description": "排序序号" - } - } - } - } - }` -) - -// 外部平台地址 -const ( - HYTProductListPageURL = "https://gateway.dev.cdlsxd.cn/sw//#/goods/goodsManage" -) - // 缓存key const ( CapabilityProductIngestCacheKey = "ai_scheduler:capability:product_ingest:%s" diff --git a/internal/domain/tools/hyt/product_upload/client.go b/internal/domain/tools/hyt/product_upload/client.go index 1cacbd9..6ebcc62 100644 --- a/internal/domain/tools/hyt/product_upload/client.go +++ b/internal/domain/tools/hyt/product_upload/client.go @@ -9,7 +9,17 @@ import ( "errors" ) -func Call(ctx context.Context, cfg config.ToolConfig, toolReq *ProductUploadRequest) (toolResp *ProductUploadResponse, err error) { +type Client struct { + cfg config.ToolConfig +} + +func New(cfg config.ToolConfig) *Client { + return &Client{ + cfg: cfg, + } +} + +func (c *Client) Call(ctx context.Context, toolReq *ProductUploadRequest) (toolResp *ProductUploadResponse, err error) { // 商品有且只能有一个 if len(toolReq.GoodsList) != 1 { err = errors.New("商品只能有一个") @@ -20,7 +30,7 @@ func Call(ctx context.Context, cfg config.ToolConfig, toolReq *ProductUploadRequ req := l_request.Request{ Method: "Post", - Url: "https://gateway.dev.cdlsxd.cn/goods-admin/api/v1/goods/supplier/batch/add/complete", + Url: c.cfg.BaseURL, Json: apiReq, } res, err := req.Send() @@ -51,7 +61,7 @@ func Call(ctx context.Context, cfg config.ToolConfig, toolReq *ProductUploadRequ } toolResp = &ProductUploadResponse{ - PreviewUrl: "https://gateway.dev.cdlsxd.cn/sw//#/goods/goodsManage", + PreviewUrl: c.cfg.AddURL, SpuNum: toolReq.GoodsList[0].GoodsInfo.SpuNum, Id: resMap.Data.Ids[0], } diff --git a/internal/domain/tools/hyt/product_upload/client_test.go b/internal/domain/tools/hyt/product_upload/client_test.go index 5e8b111..fdd99f0 100644 --- a/internal/domain/tools/hyt/product_upload/client_test.go +++ b/internal/domain/tools/hyt/product_upload/client_test.go @@ -50,7 +50,8 @@ func Test_Call(t *testing.T) { }, }, } - toolResp, err := Call(context.Background(), config.ToolConfig{}, req) + client := New(config.ToolConfig{}) + toolResp, err := client.Call(context.Background(), req) if err != nil { t.Errorf("Call() error = %v", err) diff --git a/internal/domain/tools/hyt/supplier_search/client.go b/internal/domain/tools/hyt/supplier_search/client.go index ebffe0b..bd53aa9 100644 --- a/internal/domain/tools/hyt/supplier_search/client.go +++ b/internal/domain/tools/hyt/supplier_search/client.go @@ -1,16 +1,27 @@ package supplier_search import ( + "ai_scheduler/internal/config" "ai_scheduler/internal/pkg/l_request" "context" "encoding/json" - "errors" "fmt" ) -func Call(ctx context.Context, name string) (int, error) { +type Client struct { + cfg config.ToolConfig +} + +func New(cfg config.ToolConfig) *Client { + return &Client{ + cfg: cfg, + } +} + +func (c *Client) Call(ctx context.Context, name string) (int, error) { if name == "" { - return 0, errors.New("supplier name is empty") + // 如果没有供应商名,返回0,不报错,由上层业务决定是否允许 + return 0, nil } reqBody := SearchRequest{ @@ -27,7 +38,7 @@ func Call(ctx context.Context, name string) (int, error) { req := l_request.Request{ Method: "Post", - Url: "https://gateway.dev.cdlsxd.cn/goods-admin/api/v1/supplier/list", + Url: c.cfg.BaseURL, Json: apiReq, Headers: map[string]string{ "User-Agent": "Apifox/1.0.0 (https://apifox.com)", diff --git a/internal/domain/tools/hyt/warehouse_search/client.go b/internal/domain/tools/hyt/warehouse_search/client.go index 7b0190b..32d7fa4 100644 --- a/internal/domain/tools/hyt/warehouse_search/client.go +++ b/internal/domain/tools/hyt/warehouse_search/client.go @@ -1,13 +1,24 @@ package warehouse_search import ( + "ai_scheduler/internal/config" "ai_scheduler/internal/pkg/l_request" "context" "encoding/json" "fmt" ) -func Call(ctx context.Context, name string) (int, error) { +type Client struct { + cfg config.ToolConfig +} + +func New(cfg config.ToolConfig) *Client { + return &Client{ + cfg: cfg, + } +} + +func (c *Client) Call(ctx context.Context, name string) (int, error) { if name == "" { // 如果没有仓库名,返回0,不报错,由上层业务决定是否允许 return 0, nil @@ -22,7 +33,7 @@ func Call(ctx context.Context, name string) (int, error) { req := l_request.Request{ Method: "Get", - Url: "https://gateway.dev.cdlsxd.cn/goods-admin/api/v1/warehouse/list", + Url: c.cfg.BaseURL, Params: params, Headers: map[string]string{ "User-Agent": "Apifox/1.0.0 (https://apifox.com)", diff --git a/internal/domain/tools/registry.go b/internal/domain/tools/registry.go index dc6a67e..31a8636 100644 --- a/internal/domain/tools/registry.go +++ b/internal/domain/tools/registry.go @@ -1,16 +1,29 @@ package tools -type Tool interface{ - Name() string +import ( + "ai_scheduler/internal/config" + "ai_scheduler/internal/domain/tools/hyt/product_upload" + "ai_scheduler/internal/domain/tools/hyt/supplier_search" + "ai_scheduler/internal/domain/tools/hyt/warehouse_search" +) + +type Manager struct { + Hyt *HytTools + // Zltx *ZltxTools } -var registry = map[string]Tool{} - -func Register(t Tool){ - registry[t.Name()] = t +type HytTools struct { + ProductUpload *product_upload.Client + SupplierSearch *supplier_search.Client + WarehouseSearch *warehouse_search.Client } -func Get(name string) Tool{ - return registry[name] +func NewManager(cfg *config.Config) *Manager { + return &Manager{ + Hyt: &HytTools{ + ProductUpload: product_upload.New(cfg.EinoTools.HytProductUpload), + SupplierSearch: supplier_search.New(cfg.EinoTools.HytSupplierSearch), + WarehouseSearch: warehouse_search.New(cfg.EinoTools.HytWarehouseSearch), + }, + } } - diff --git a/internal/domain/workflow/hyt/product_upload.go b/internal/domain/workflow/hyt/product_upload.go index 4b85f37..0a93ce1 100644 --- a/internal/domain/workflow/hyt/product_upload.go +++ b/internal/domain/workflow/hyt/product_upload.go @@ -2,36 +2,33 @@ package hyt import ( "ai_scheduler/internal/config" - "ai_scheduler/internal/data/constants" + toolManager "ai_scheduler/internal/domain/tools" toolPu "ai_scheduler/internal/domain/tools/hyt/product_upload" - toolSs "ai_scheduler/internal/domain/tools/hyt/supplier_search" - toolWs "ai_scheduler/internal/domain/tools/hyt/warehouse_search" "ai_scheduler/internal/domain/workflow/runtime" "ai_scheduler/internal/entitys" "context" "encoding/json" + "errors" "fmt" "strconv" "strings" "sync" - eino_ollama "github.com/cloudwego/eino-ext/components/model/ollama" - "github.com/cloudwego/eino/components/prompt" "github.com/cloudwego/eino/compose" - "github.com/cloudwego/eino/schema" ) const WorkflowID = "hyt.productUpload" func init() { runtime.Register(WorkflowID, func(d *runtime.Deps) (runtime.Workflow, error) { - return &productUpload{cfg: d.Conf}, nil + return &productUpload{cfg: d.Conf, toolManager: d.ToolManager}, nil }) } type productUpload struct { - cfg *config.Config - data *ProductUploadWorkflowInput + cfg *config.Config + toolManager *toolManager.Manager + data *ProductUploadWorkflowInput } type ProductUploadWorkflowInput struct { @@ -41,8 +38,8 @@ type ProductUploadWorkflowInput struct { func (o *productUpload) ID() string { return WorkflowID } func (o *productUpload) Invoke(ctx context.Context, rec *entitys.Recognize) (map[string]any, error) { - // 构建工作流 (使用 V2 Graph 版本) - runnable, err := o.buildWorkflowV2(ctx) + // 构建工作流 + runnable, err := o.buildWorkflow(ctx) if err != nil { return nil, err } @@ -103,65 +100,11 @@ type ProductUploadContext struct { UploadResp *toolPu.ProductUploadResponse } +// buildWorkflow 构建基于 Graph 的并行工作流 func (o *productUpload) buildWorkflow(ctx context.Context) (compose.Runnable[*ProductUploadWorkflowInput, map[string]any], error) { - // 定义工作流 - c := compose.NewChain[*ProductUploadWorkflowInput, map[string]any]() - - // AI映射工具所需参数' - paramMappingModel, err := eino_ollama.NewChatModel(ctx, &eino_ollama.ChatModelConfig{ - BaseURL: o.cfg.Ollama.BaseURL, - Timeout: o.cfg.Ollama.Timeout, - Model: o.cfg.Ollama.Model, - Thinking: &eino_ollama.ThinkValue{Value: true}, - Options: &eino_ollama.Options{Temperature: 0.5}, - }) - if err != nil { - return nil, err - } - - // 1. 构建参数LLM数映射提示词 - c.AppendChatTemplate(prompt.FromMessages( - schema.FString, - schema.SystemMessage("你是一个专业的商品参数解析器,你需要根据用户输入的商品描述,解析出商品的目标参数。"), - schema.SystemMessage("目标参数:"+constants.HYTProductPropertyTemplate), - schema.UserMessage("用户输入:{{.Text}}"), - )) - // 2. 调用LLM - c.AppendChatModel(paramMappingModel) - - // 3.工具参数整理 - c.AppendLambda(compose.InvokableLambda(func(_ context.Context, in *schema.Message) (*toolPu.ProductUploadRequest, error) { - toolReq := &toolPu.ProductUploadRequest{} - if err := json.Unmarshal([]byte(in.Content), toolReq); err != nil { - return nil, err - } - return toolReq, nil - })) - - // 4.工具调用 - c.AppendLambda(compose.InvokableLambda(func(_ context.Context, in *toolPu.ProductUploadRequest) (*toolPu.ProductUploadResponse, error) { - toolRes, err := toolPu.Call(ctx, o.cfg.Tools.HytProductUpload, in) - return toolRes, err - })) - - // 5.结果数据映射 - c.AppendLambda(compose.InvokableLambda(func(_ context.Context, in *toolPu.ProductUploadResponse) (map[string]any, error) { - return map[string]any{ - "预览URL(货易通商品列表)": in.PreviewUrl, - "SPU编码": in.SpuNum, - "商品ID": in.Id, - }, nil - })) - - // 6.编译工作流 - return c.Compile(ctx) -} - -// buildWorkflowV2 构建基于 Graph 的并行工作流 -func (o *productUpload) buildWorkflowV2(ctx context.Context) (compose.Runnable[*ProductUploadWorkflowInput, map[string]any], error) { g := compose.NewGraph[*ProductUploadWorkflowInput, map[string]any]() - // 1. DataMapping 节点: 解析 JSON -> 填充基础 Request -> 提取供应商/仓库名 + // 1. DataMapping 节点: 解析 JSON -> 填充基础 Request g.AddLambdaNode("data_mapping", compose.InvokableLambda(func(ctx context.Context, in *ProductUploadWorkflowInput) (*ProductUploadContext, error) { state := &ProductUploadContext{ mu: &sync.Mutex{}, // 初始化锁 @@ -176,6 +119,21 @@ func (o *productUpload) buildWorkflowV2(ctx context.Context) (compose.Runnable[* if err := json.Unmarshal([]byte(in.Text), &ingestData); err != nil { return nil, fmt.Errorf("解析商品数据失败: %w", err) } + + // 必填校验 + if ingestData.SupplierName == "" { + return nil, errors.New("供应商名称不能为空") + } + if ingestData.WarehouseName == "" { + return nil, errors.New("仓库名称不能为空") + } + if ingestData.Profit == "" { + return nil, errors.New("利润不能为空") + } + if ingestData.TaxRate == "" { + return nil, errors.New("税率不能为空") + } + state.IngestData = &ingestData state.SupplierName = ingestData.SupplierName state.WarehouseName = ingestData.WarehouseName @@ -240,32 +198,38 @@ func (o *productUpload) buildWorkflowV2(ctx context.Context) (compose.Runnable[* // 2. 获取供应商ID 节点 g.AddLambdaNode("get_supplier_id", compose.InvokableLambda(func(ctx context.Context, state *ProductUploadContext) (*ProductUploadContext, error) { - if state.SupplierName != "" { - supplierId, err := toolSs.Call(ctx, state.SupplierName) - if err != nil { - // 记录日志,但不阻断流程,可能允许 ID 为 0 - fmt.Printf("warning: failed to get supplier id for %s: %v\n", state.SupplierName, err) - } else { - state.mu.Lock() - defer state.mu.Unlock() - state.UploadReq.SupplierId = supplierId - } + if state.SupplierName == "" { + return state, errors.New("供应商名称不能为空") } + + supplierId, err := o.toolManager.Hyt.SupplierSearch.Call(ctx, state.SupplierName) + if err != nil { + // 记录日志,但不阻断流程,可能允许 ID 为 0 + fmt.Printf("warning: failed to get supplier id for %s: %v\n", state.SupplierName, err) + } else { + state.mu.Lock() + defer state.mu.Unlock() + state.UploadReq.SupplierId = supplierId + } + return state, nil })) // 3. 获取仓库ID 节点 g.AddLambdaNode("get_warehouse_id", compose.InvokableLambda(func(ctx context.Context, state *ProductUploadContext) (*ProductUploadContext, error) { - if state.WarehouseName != "" { - warehouseId, err := toolWs.Call(ctx, state.WarehouseName) - if err != nil { - fmt.Printf("warning: failed to get warehouse id for %s: %v\n", state.WarehouseName, err) - } else { - state.mu.Lock() - defer state.mu.Unlock() - state.UploadReq.WarehouseId = warehouseId - } + if state.WarehouseName == "" { + return state, errors.New("仓库名称不能为空") } + + warehouseId, err := o.toolManager.Hyt.WarehouseSearch.Call(ctx, state.WarehouseName) + if err != nil { + fmt.Printf("warning: failed to get warehouse id for %s: %v\n", state.WarehouseName, err) + } else { + state.mu.Lock() + defer state.mu.Unlock() + state.UploadReq.WarehouseId = warehouseId + } + return state, nil })) @@ -280,7 +244,7 @@ func (o *productUpload) buildWorkflowV2(ctx context.Context) (compose.Runnable[* // 5. 上传节点 g.AddLambdaNode("upload_product", compose.InvokableLambda(func(ctx context.Context, state *ProductUploadContext) (*ProductUploadContext, error) { - toolRes, err := toolPu.Call(ctx, o.cfg.Tools.HytProductUpload, state.UploadReq) + toolRes, err := o.toolManager.Hyt.ProductUpload.Call(ctx, state.UploadReq) if err != nil { return nil, err } diff --git a/internal/domain/workflow/provider_set.go b/internal/domain/workflow/provider_set.go index c728b44..97e1b5d 100644 --- a/internal/domain/workflow/provider_set.go +++ b/internal/domain/workflow/provider_set.go @@ -5,6 +5,8 @@ import ( "ai_scheduler/internal/domain/workflow/runtime" "ai_scheduler/internal/pkg/utils_ollama" + toolManager "ai_scheduler/internal/domain/tools" + "github.com/google/wire" ) @@ -13,7 +15,7 @@ var ProviderSetWorkflow = wire.NewSet(NewRegistry) // NewRegistry 注入共享依赖并注册默认 Registry,确保自注册工作流可被发现 func NewRegistry(conf *config.Config, llm *utils_ollama.Client) *runtime.Registry { // 步骤1:设置运行时依赖(配置与LLM客户端),供工作流工厂在首次实例化时使用;必须在任何调用 Invoke 之前完成,否则会触发 "deps not set" - runtime.SetDeps(&runtime.Deps{Conf: conf, LLM: llm}) + runtime.SetDeps(&runtime.Deps{Conf: conf, LLM: llm, ToolManager: toolManager.NewManager(conf)}) // 步骤2:创建新的工作流注册表;注册表负责按工作流ID惰性实例化并缓存单例实例,保障并发访问下的安全 r := runtime.NewRegistry() // 步骤3:将该注册表设置为全局默认,便于通过 runtime.Default() 获取;自注册的工作流可通过默认注册表被发现并调用 diff --git a/internal/domain/workflow/registry.go b/internal/domain/workflow/registry.go index cbde3b6..af69a03 100644 --- a/internal/domain/workflow/registry.go +++ b/internal/domain/workflow/registry.go @@ -2,11 +2,13 @@ package workflow import ( "ai_scheduler/internal/config" + toolManager "ai_scheduler/internal/domain/tools" "ai_scheduler/internal/pkg/utils_ollama" ) // 仅声明依赖结构,避免在 workflow 包内实现注册中心逻辑导致循环依赖 type Deps struct { - Conf *config.Config - LLM *utils_ollama.Client + Conf *config.Config + LLM *utils_ollama.Client + ToolManager *toolManager.Manager } diff --git a/internal/domain/workflow/runtime/registry.go b/internal/domain/workflow/runtime/registry.go index bf840e6..2b4049b 100644 --- a/internal/domain/workflow/runtime/registry.go +++ b/internal/domain/workflow/runtime/registry.go @@ -2,6 +2,7 @@ package runtime import ( "ai_scheduler/internal/config" + toolManager "ai_scheduler/internal/domain/tools" "ai_scheduler/internal/entitys" "ai_scheduler/internal/pkg/utils_ollama" "context" @@ -16,8 +17,9 @@ type Workflow interface { } type Deps struct { - Conf *config.Config - LLM *utils_ollama.Client + Conf *config.Config + LLM *utils_ollama.Client + ToolManager *toolManager.Manager } type Factory func(deps *Deps) (Workflow, error) From b87767ea5a290fc39db1447f837cd87bb573024e Mon Sep 17 00:00:00 2001 From: fuzhongyun <15339891972@163.com> Date: Mon, 22 Dec 2025 11:56:32 +0800 Subject: [PATCH 54/66] =?UTF-8?q?fix:=20=E6=96=B0=E5=A2=9E=E5=B7=A5?= =?UTF-8?q?=E4=BD=9C=E6=B5=81=E4=B8=9A=E5=8A=A1=E9=94=99=E8=AF=AF=EF=BC=8C?= =?UTF-8?q?=E8=B0=83=E6=95=B4=E5=B7=A5=E4=BD=9C=E6=B5=81=E9=94=99=E8=AF=AF?= =?UTF-8?q?=E6=97=B6=E8=BE=93=E5=87=BA?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- internal/data/error/error_code.go | 5 +++++ internal/domain/workflow/hyt/product_upload.go | 12 ++++++++++-- .../workflow/zltx/order_after_reseller_batch.go | 9 --------- internal/server/router/router.go | 1 - 4 files changed, 15 insertions(+), 12 deletions(-) diff --git a/internal/data/error/error_code.go b/internal/data/error/error_code.go index 9d907e7..85abd88 100644 --- a/internal/data/error/error_code.go +++ b/internal/data/error/error_code.go @@ -15,6 +15,7 @@ var ( SysNotFound = &BusinessErr{code: 410, message: "未找到系统信息"} SysCodeNotFound = &BusinessErr{code: 411, message: "未找到系统编码"} InvalidParam = &BusinessErr{code: InvalidParamCode, message: "无效参数"} + WorkflowError = &BusinessErr{code: 501, message: "工作流过程错误"} ) const ( @@ -58,3 +59,7 @@ func (e *BusinessErr) Wrap(err error) *BusinessErr { func KeyErr() *BusinessErr { return &BusinessErr{code: KeyNotFound.code, message: KeyNotFound.message} } + +func WorkflowErr(message string) *BusinessErr { + return NewBusinessErr(WorkflowError.code, message) +} diff --git a/internal/domain/workflow/hyt/product_upload.go b/internal/domain/workflow/hyt/product_upload.go index 0a93ce1..de74977 100644 --- a/internal/domain/workflow/hyt/product_upload.go +++ b/internal/domain/workflow/hyt/product_upload.go @@ -2,6 +2,7 @@ package hyt import ( "ai_scheduler/internal/config" + errorcode "ai_scheduler/internal/data/error" toolManager "ai_scheduler/internal/domain/tools" toolPu "ai_scheduler/internal/domain/tools/hyt/product_upload" "ai_scheduler/internal/domain/workflow/runtime" @@ -50,7 +51,11 @@ func (o *productUpload) Invoke(ctx context.Context, rec *entitys.Recognize) (map // 工作流过程调用 output, err := runnable.Invoke(ctx, o.data) if err != nil { - return nil, err + errStr := err.Error() + if u := errors.Unwrap(err); u != nil { + errStr = u.Error() + } + return nil, errorcode.WorkflowErr(errStr) } fmt.Printf("workflow output: %v\n", output) @@ -235,10 +240,13 @@ func (o *productUpload) buildWorkflow(ctx context.Context) (compose.Runnable[*Pr // 4. 合并/同步节点 g.AddLambdaNode("merge_node", compose.InvokableLambda(func(ctx context.Context, state *ProductUploadContext) (*ProductUploadContext, error) { - // 可以在这里做最终校验,例如必须有 SupplierId + // 最终校验 if state.UploadReq.SupplierId == 0 { return nil, fmt.Errorf("供应商获取失败") } + if state.UploadReq.WarehouseId == 0 { + return nil, fmt.Errorf("仓库获取失败") + } return state, nil })) diff --git a/internal/domain/workflow/zltx/order_after_reseller_batch.go b/internal/domain/workflow/zltx/order_after_reseller_batch.go index 5e5fa51..ff21d84 100644 --- a/internal/domain/workflow/zltx/order_after_reseller_batch.go +++ b/internal/domain/workflow/zltx/order_after_reseller_batch.go @@ -78,15 +78,6 @@ type OrderAfterSaleResellerBatchData struct { // ID 返回工作流唯一标识 func (o *orderAfterSaleResellerBatch) ID() string { return "zltx.orderAfterSaleResellerBatch" } -// Schema 返回入参约束(用于校验/表单生成) -// func (o *orderAfterSaleResellerBatch) Schema() map[string]any { -// return map[string]any{ -// "type": "object", -// "properties": map[string]any{"orderNumber": map[string]any{"type": "array", "items": map[string]any{"type": "string"}}}, -// "required": []string{"orderNumber"}, -// } -// } - // Invoke 调用原有编排工作流并规范化输出 func (o *orderAfterSaleResellerBatch) Invoke(ctx context.Context, rec *entitys.Recognize) (map[string]any, error) { // 构建工作流 diff --git a/internal/server/router/router.go b/internal/server/router/router.go index 15e1554..091c85f 100644 --- a/internal/server/router/router.go +++ b/internal/server/router/router.go @@ -52,7 +52,6 @@ func SetupRoutes(app *fiber.App, ChatService *services.ChatService, sessionServi r := app.Group("api/v1/") registerResponse(r) - // 注册 CORS 中间件 r.Get("/health", func(c *fiber.Ctx) error { c.Response().SetBody([]byte("1")) return nil From c1b7cd6bf5cc8597c8ee6cc1372be8c5cd1be4cb Mon Sep 17 00:00:00 2001 From: fuzhongyun <15339891972@163.com> Date: Mon, 22 Dec 2025 14:08:48 +0800 Subject: [PATCH 55/66] =?UTF-8?q?fix:=201.=E8=B0=83=E6=95=B4=E5=85=AC?= =?UTF-8?q?=E5=85=B1code=E8=BE=93=E5=87=BA=E6=96=B9=E6=B3=95=202.=E8=B0=83?= =?UTF-8?q?=E6=95=B4=E5=95=86=E5=93=81=E6=8A=93=E5=8F=96=E5=B7=A5=E4=BD=9C?= =?UTF-8?q?=E6=B5=81=E6=97=A5=E5=BF=97=E5=8F=8A=E9=94=99=E8=AF=AF=E8=BE=93?= =?UTF-8?q?=E5=87=BA?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- internal/biz/do/ctx.go | 2 +- internal/biz/do/handle.go | 2 +- internal/biz/task.go | 4 +-- internal/data/error/error_code.go | 16 ++++++---- .../domain/tools/hyt/product_upload/client.go | 14 ++++---- .../tools/hyt/supplier_search/client.go | 12 +++---- .../tools/hyt/warehouse_search/client.go | 12 +++---- .../domain/workflow/hyt/product_upload.go | 7 ++-- internal/gateway/client.go | 5 +-- internal/pkg/dingtalk/contact_client.go | 4 +-- internal/pkg/dingtalk/notable_client.go | 2 +- internal/services/callback.go | 32 +++++++++---------- internal/services/capability.go | 14 ++++---- .../tool_callback/bug_optimization_submit.go | 2 +- 14 files changed, 62 insertions(+), 66 deletions(-) diff --git a/internal/biz/do/ctx.go b/internal/biz/do/ctx.go index 4b5d2ec..5d9ba72 100644 --- a/internal/biz/do/ctx.go +++ b/internal/biz/do/ctx.go @@ -349,7 +349,7 @@ func (d *Do) LoadUserPermission(client *gateway.Client, requireData *entitys.Req // 检查响应状态码 if res.StatusCode != http.StatusOK { - err = errors.SysErr("获取用户权限失败") + err = errors.SysErrf("获取用户权限失败") return } diff --git a/internal/biz/do/handle.go b/internal/biz/do/handle.go index 3a4a12b..76f1257 100644 --- a/internal/biz/do/handle.go +++ b/internal/biz/do/handle.go @@ -74,7 +74,7 @@ func (r *Handle) Recognize(ctx context.Context, rec *entitys.Recognize, promptPr entitys.ResLog(rec.Ch, "recognize_end", "意图识别结束") var match entitys.Match if err = json.Unmarshal([]byte(recognizeMsg), &match); err != nil { - err = errors.SysErr("数据结构错误:%v", err.Error()) + err = errors.SysErrf("数据结构错误:%v", err.Error()) return } rec.Match = &match diff --git a/internal/biz/task.go b/internal/biz/task.go index 277a97e..f9c03f0 100644 --- a/internal/biz/task.go +++ b/internal/biz/task.go @@ -70,13 +70,13 @@ func (t *TaskBiz) GetUserPermission(req *entitys.TaskRequest, auth string) (code // 发送请求 res, err := request.Send() if err != nil { - err = errors.SysErr("请求用户权限失败") + err = errors.SysErrf("请求用户权限失败") return } // 检查响应状态码 if res.StatusCode != http.StatusOK { - err = errors.SysErr("获取用户权限失败") + err = errors.SysErrf("获取用户权限失败") return } diff --git a/internal/data/error/error_code.go b/internal/data/error/error_code.go index 85abd88..1e2f0b6 100644 --- a/internal/data/error/error_code.go +++ b/internal/data/error/error_code.go @@ -44,22 +44,26 @@ func NewBusinessErr(code int, message string) *BusinessErr { return &BusinessErr{code: code, message: message} } -func SysErr(message string, arg ...any) *BusinessErr { +func SysErrf(message string, arg ...any) *BusinessErr { return &BusinessErr{code: SystemError.code, message: fmt.Sprintf(message, arg)} } -func ParamErr(message string, arg ...any) *BusinessErr { +func SysErr(message string) *BusinessErr { + return &BusinessErr{code: SystemError.code, message: message} +} + +func ParamErrf(message string, arg ...any) *BusinessErr { return &BusinessErr{code: ParamError.code, message: fmt.Sprintf(message, arg)} } +func ParamErr(message string) *BusinessErr { + return &BusinessErr{code: ParamError.code, message: message} +} + func (e *BusinessErr) Wrap(err error) *BusinessErr { return NewBusinessErr(e.code, err.Error()) } -func KeyErr() *BusinessErr { - return &BusinessErr{code: KeyNotFound.code, message: KeyNotFound.message} -} - func WorkflowErr(message string) *BusinessErr { return NewBusinessErr(WorkflowError.code, message) } diff --git a/internal/domain/tools/hyt/product_upload/client.go b/internal/domain/tools/hyt/product_upload/client.go index 6ebcc62..1cd68b8 100644 --- a/internal/domain/tools/hyt/product_upload/client.go +++ b/internal/domain/tools/hyt/product_upload/client.go @@ -7,6 +7,7 @@ import ( "context" "encoding/json" "errors" + "fmt" ) type Client struct { @@ -22,8 +23,7 @@ func New(cfg config.ToolConfig) *Client { func (c *Client) Call(ctx context.Context, toolReq *ProductUploadRequest) (toolResp *ProductUploadResponse, err error) { // 商品有且只能有一个 if len(toolReq.GoodsList) != 1 { - err = errors.New("商品只能有一个") - return + return nil, errors.New("商品只能有一个") } apiReq, _ := util.StructToMap(toolReq) @@ -36,7 +36,7 @@ func (c *Client) Call(ctx context.Context, toolReq *ProductUploadRequest) (toolR res, err := req.Send() if err != nil { - return + return nil, fmt.Errorf("请求失败,err: %v", err) } type resType struct { @@ -49,15 +49,13 @@ func (c *Client) Call(ctx context.Context, toolReq *ProductUploadRequest) (toolR var resMap resType err = json.Unmarshal([]byte(res.Text), &resMap) if err != nil { - return + return nil, fmt.Errorf("解析响应失败,err: %v", err) } if resMap.Code != 200 { - err = errors.New("货易通商品创建失败") - return + return nil, fmt.Errorf("业务错误,code: %d, msg: %s", resMap.Code, resMap.Msg) } if len(resMap.Data.Ids) == 0 { - err = errors.New("货易通商品创建失败") - return + return nil, fmt.Errorf("ids为空") } toolResp = &ProductUploadResponse{ diff --git a/internal/domain/tools/hyt/supplier_search/client.go b/internal/domain/tools/hyt/supplier_search/client.go index bd53aa9..cbb20b4 100644 --- a/internal/domain/tools/hyt/supplier_search/client.go +++ b/internal/domain/tools/hyt/supplier_search/client.go @@ -48,24 +48,20 @@ func (c *Client) Call(ctx context.Context, name string) (int, error) { res, err := req.Send() if err != nil { - return 0, err - } - - if res.StatusCode != 200 { - return 0, fmt.Errorf("supplier search failed with status code: %d", res.StatusCode) + return 0, fmt.Errorf("请求失败,err: %v", err) } var resData SearchResponse if err := json.Unmarshal([]byte(res.Text), &resData); err != nil { - return 0, fmt.Errorf("failed to parse supplier search response: %w", err) + return 0, fmt.Errorf("解析响应失败,err: %v", err) } if resData.Code != 200 { - return 0, fmt.Errorf("supplier search business error: %s", resData.Msg) + return 0, fmt.Errorf("业务错误,code: %d, msg: %s", resData.Code, resData.Msg) } if len(resData.Data.List) == 0 { - return 0, fmt.Errorf("supplier not found: %s", name) + return 0, fmt.Errorf("供应商不存在") } return resData.Data.List[0].ID, nil diff --git a/internal/domain/tools/hyt/warehouse_search/client.go b/internal/domain/tools/hyt/warehouse_search/client.go index 32d7fa4..4502c42 100644 --- a/internal/domain/tools/hyt/warehouse_search/client.go +++ b/internal/domain/tools/hyt/warehouse_search/client.go @@ -43,24 +43,20 @@ func (c *Client) Call(ctx context.Context, name string) (int, error) { res, err := req.Send() if err != nil { - return 0, err - } - - if res.StatusCode != 200 { - return 0, fmt.Errorf("warehouse search failed with status code: %d", res.StatusCode) + return 0, fmt.Errorf("请求失败,err: %v", err) } var resData SearchResponse if err := json.Unmarshal([]byte(res.Text), &resData); err != nil { - return 0, fmt.Errorf("failed to parse warehouse search response: %w", err) + return 0, fmt.Errorf("解析响应失败,err: %v", err) } if resData.Code != 200 { - return 0, fmt.Errorf("warehouse search business error: %s", resData.Msg) + return 0, fmt.Errorf("业务错误,code: %d, msg: %s", resData.Code, resData.Msg) } if len(resData.Data.List) == 0 { - return 0, fmt.Errorf("warehouse not found: %s", name) + return 0, fmt.Errorf("仓库不存在: %s", name) } return resData.Data.List[0].ID, nil diff --git a/internal/domain/workflow/hyt/product_upload.go b/internal/domain/workflow/hyt/product_upload.go index de74977..fca3aec 100644 --- a/internal/domain/workflow/hyt/product_upload.go +++ b/internal/domain/workflow/hyt/product_upload.go @@ -11,6 +11,7 @@ import ( "encoding/json" "errors" "fmt" + "log" "strconv" "strings" "sync" @@ -210,7 +211,7 @@ func (o *productUpload) buildWorkflow(ctx context.Context) (compose.Runnable[*Pr supplierId, err := o.toolManager.Hyt.SupplierSearch.Call(ctx, state.SupplierName) if err != nil { // 记录日志,但不阻断流程,可能允许 ID 为 0 - fmt.Printf("warning: failed to get supplier id for %s: %v\n", state.SupplierName, err) + log.Printf("warning: 供应商ID获取失败,%s: %v\n", state.SupplierName, err) } else { state.mu.Lock() defer state.mu.Unlock() @@ -228,7 +229,7 @@ func (o *productUpload) buildWorkflow(ctx context.Context) (compose.Runnable[*Pr warehouseId, err := o.toolManager.Hyt.WarehouseSearch.Call(ctx, state.WarehouseName) if err != nil { - fmt.Printf("warning: failed to get warehouse id for %s: %v\n", state.WarehouseName, err) + log.Printf("warning: 仓库ID获取失败,%s: %v\n", state.WarehouseName, err) } else { state.mu.Lock() defer state.mu.Unlock() @@ -254,7 +255,7 @@ func (o *productUpload) buildWorkflow(ctx context.Context) (compose.Runnable[*Pr g.AddLambdaNode("upload_product", compose.InvokableLambda(func(ctx context.Context, state *ProductUploadContext) (*ProductUploadContext, error) { toolRes, err := o.toolManager.Hyt.ProductUpload.Call(ctx, state.UploadReq) if err != nil { - return nil, err + return nil, fmt.Errorf("商品上传失败") } state.UploadResp = toolRes return state, nil diff --git a/internal/gateway/client.go b/internal/gateway/client.go index a293daa..1da0dd8 100644 --- a/internal/gateway/client.go +++ b/internal/gateway/client.go @@ -4,17 +4,18 @@ import ( errors "ai_scheduler/internal/data/error" "ai_scheduler/internal/data/model" "context" - "github.com/google/uuid" "log" "math/rand" "sync" "time" + "github.com/google/uuid" + "github.com/gofiber/websocket/v2" ) var ( - ErrConnClosed = errors.SysErr("连接不存在或已关闭") + ErrConnClosed = errors.SysErrf("连接不存在或已关闭") rng = rand.New(rand.NewSource(time.Now().UnixNano())) idBuf = make([]byte, 20) ) diff --git a/internal/pkg/dingtalk/contact_client.go b/internal/pkg/dingtalk/contact_client.go index 4461995..ed3e3dd 100644 --- a/internal/pkg/dingtalk/contact_client.go +++ b/internal/pkg/dingtalk/contact_client.go @@ -54,10 +54,10 @@ func (c *ContactClient) SearchUserOne(accessToken string, name string) (string, } if resp.Body == nil { - return "", errorcode.ParamErr("empty response body") + return "", errorcode.ParamErrf("empty response body") } if len(resp.Body.List) == 0 { - return "", errorcode.ParamErr("empty user list") + return "", errorcode.ParamErrf("empty user list") } userId := resp.Body.List[0] diff --git a/internal/pkg/dingtalk/notable_client.go b/internal/pkg/dingtalk/notable_client.go index 7cfe04f..d7d5434 100644 --- a/internal/pkg/dingtalk/notable_client.go +++ b/internal/pkg/dingtalk/notable_client.go @@ -67,7 +67,7 @@ func (c *NotableClient) UpdateRecord(accessToken string, req *UpdateRecordReq) ( } if resp.Body == nil { - return false, errorcode.ParamErr("empty response body") + return false, errorcode.ParamErrf("empty response body") } return true, nil diff --git a/internal/services/callback.go b/internal/services/callback.go index cc32660..e224ce3 100644 --- a/internal/services/callback.go +++ b/internal/services/callback.go @@ -85,13 +85,13 @@ func (s *CallbackService) Callback(c *fiber.Ctx) error { // 解析 Envelope var env Envelope if err := json.Unmarshal(c.Body(), &env); err != nil { - return errorcode.ParamErr("invalid json: %v", err) + return errorcode.ParamErrf("invalid json: %v", err) } if env.Action == "" || env.TaskID == "" { - return errorcode.ParamErr("missing action/task_id") + return errorcode.ParamErrf("missing action/task_id") } if env.Data == nil { - return errorcode.ParamErr("missing data") + return errorcode.ParamErrf("missing data") } switch sourceKey { @@ -141,7 +141,7 @@ func (s *CallbackService) handleDingTalkCallback(c *fiber.Ctx, env Envelope) err // 校验taskId sessionID, ok := s.callBackTool.GetSessionByTaskID(env.TaskID) if !ok { - return errorcode.ParamErr("missing session_id for task_id: %s", env.TaskID) + return errorcode.ParamErrf("missing session_id for task_id: %s", env.TaskID) } ctx := c.Context() @@ -176,14 +176,14 @@ func (s *CallbackService) handleDingTalkCallback(c *fiber.Ctx, env Envelope) err } var data processData if err := json.Unmarshal(env.Data, &data); err != nil { - return errorcode.ParamErr("invalid json: %v", err) + return errorcode.ParamErrf("invalid json: %v", err) } s.sendStreamLoading(sessionID, data.Process) return c.JSON(fiber.Map{"code": 0, "message": "ok"}) default: - return errorcode.ParamErr("unknown action: %s", env.Action) + return errorcode.ParamErrf("unknown action: %s", env.Action) } } @@ -255,27 +255,27 @@ func (s *CallbackService) sendStreamLoading(sessionID string, content string) { func (s *CallbackService) handleBugOptimizationSubmitUpdate(ctx context.Context, taskData json.RawMessage) (string, *errorcode.BusinessErr) { var data BugOptimizationSubmitUpdateData if err := json.Unmarshal(taskData, &data); err != nil { - return "", errorcode.ParamErr("invalid data type: %v", err) + return "", errorcode.ParamErrf("invalid data type: %v", err) } if data.Creator == "" { - return "", errorcode.ParamErr("empty creator") + return "", errorcode.ParamErrf("empty creator") } // 获取创建者uid accessToken, _ := s.dingtalkOldClient.GetAccessToken() creatorId, err := s.dingtalkContactClient.SearchUserOne(accessToken, data.Creator) if err != nil { - return "", errorcode.ParamErr("invalid data type: %v", err) + return "", errorcode.ParamErrf("invalid data type: %v", err) } // 获取用户详情 userDetails, err := s.dingtalkOldClient.QueryUserDetails(ctx, creatorId) if err != nil { - return "", errorcode.ParamErr("invalid data type: %v", err) + return "", errorcode.ParamErrf("invalid data type: %v", err) } if userDetails == nil { - return "", errorcode.ParamErr("user details not found") + return "", errorcode.ParamErrf("user details not found") } unionId := userDetails.UnionID @@ -288,10 +288,10 @@ func (s *CallbackService) handleBugOptimizationSubmitUpdate(ctx context.Context, CreatorUnionId: unionId, }) if err != nil { - return "", errorcode.ParamErr("invalid data type: %v", err) + return "", errorcode.ParamErrf("invalid data type: %v", err) } if !ok { - return "", errorcode.ParamErr("update record failed") + return "", errorcode.ParamErrf("update record failed") } return "问题记录即将完成", nil @@ -301,16 +301,16 @@ func (s *CallbackService) handleBugOptimizationSubmitUpdate(ctx context.Context, func (s *CallbackService) handleBugOptimizationSubmitDone(ctx context.Context, taskData json.RawMessage) (string, *errorcode.BusinessErr) { var data BugOptimizationSubmitDoneData if err := json.Unmarshal(taskData, &data); err != nil { - return "", errorcode.ParamErr("invalid data type: %v", err) + return "", errorcode.ParamErrf("invalid data type: %v", err) } if len(data.Receivers) == 0 { - return "", errorcode.ParamErr("empty receivers") + return "", errorcode.ParamErrf("empty receivers") } // 构建接收者 receivers := s.getDingtalkReceivers(ctx, data.Receivers) if receivers == "" { - return "", errorcode.ParamErr("invalid receivers") + return "", errorcode.ParamErrf("invalid receivers") } // 构建跳转链接 diff --git a/internal/services/capability.go b/internal/services/capability.go index 89d97cc..5b3d152 100644 --- a/internal/services/capability.go +++ b/internal/services/capability.go @@ -65,11 +65,11 @@ func (s *CapabilityService) ProductIngest(c *fiber.Ctx) error { // 解析请求参数 req := ProductIngestReq{} if err := c.BodyParser(&req); err != nil { - return errorcode.ParamErr("invalid request body: %v", err) + return errorcode.ParamErrf("invalid request body: %v", err) } // 必要参数校验 if req.Text == "" || req.SysId == "" { - return errorcode.ParamErr("missing required fields") + return errorcode.ParamErrf("missing required fields") } // 映射目标系统商品属性中文模板 @@ -78,7 +78,7 @@ func (s *CapabilityService) ProductIngest(c *fiber.Ctx) error { case "hyt": // 货易通 sysProductPropertyTemplateZH = constants.HYTProductPropertyTemplateZH default: - return errorcode.ParamErr("invalid sys_id") + return errorcode.ParamErrf("invalid sys_id") } // 模型调用 @@ -138,11 +138,11 @@ func (s *CapabilityService) checkRequestHeader(c *fiber.Ctx) error { // 时间窗口校验 if ts != "" && !util.IsInTimeWindow(ts, 5*time.Minute) { - // return errorcode.AuthNotFound + return errorcode.AuthNotFound } // token校验 if token == "" || token != "A7f9KQ3mP2X8LZC4R5e" { - return errorcode.KeyErr() + return errorcode.KeyNotFound } return nil @@ -164,12 +164,12 @@ func (s *CapabilityService) ProductIngestConfirm(c *fiber.Ctx) error { // 获取路径参数中的 thread_id threadId := c.Params("thread_id") if threadId == "" { - return errorcode.ParamErr("missing required fields") + return errorcode.ParamErrf("missing required fields") } // 解析请求参数 body req := ProductIngestConfirmReq{} if err := c.BodyParser(&req); err != nil { - return errorcode.ParamErr("invalid request body: %v", err) + return errorcode.ParamErrf("invalid request body: %v", err) } // 必要参数校验 if req.Confirmed == "" || threadId == "" { diff --git a/internal/tool_callback/bug_optimization_submit.go b/internal/tool_callback/bug_optimization_submit.go index 245ab8c..57a5e73 100644 --- a/internal/tool_callback/bug_optimization_submit.go +++ b/internal/tool_callback/bug_optimization_submit.go @@ -60,7 +60,7 @@ func (w *CallBackTool) BugOptimizationSubmit(ctx context.Context, requireData *e cond = cond.And(builder.Eq{"session_id": requireData.Session}) sessionInfo, err := w.sessionImpl.GetOneBySearch(&cond) if err != nil { - err = errors.SysErr("获取会话信息失败:%v", err.Error()) + err = errors.SysErrf("获取会话信息失败:%v", err.Error()) return } userName := sessionInfo["user_name"].(string) From cd13e1dbfa1e44b2dbca7b9cf6377dc0aca75410 Mon Sep 17 00:00:00 2001 From: fuzhongyun <15339891972@163.com> Date: Mon, 22 Dec 2025 14:17:37 +0800 Subject: [PATCH 56/66] =?UTF-8?q?feat:=20=E5=A2=9E=E5=8A=A0=E6=B5=8B?= =?UTF-8?q?=E8=AF=95=E7=8E=AF=E5=A2=83=E9=85=8D=E7=BD=AE?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- config/config_test.yaml | 13 ++++++++++++- internal/services/capability.go | 2 +- 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/config/config_test.yaml b/config/config_test.yaml index ab91824..fe68122 100644 --- a/config/config_test.yaml +++ b/config/config_test.yaml @@ -3,7 +3,6 @@ server: port: 8090 host: "0.0.0.0" - ollama: base_url: "http://host.docker.internal:11434" model: "qwen3-coder:480b-cloud" @@ -90,6 +89,18 @@ tools: api_key: "7583905168607100978" api_secret: "pat_eEN0BdLNDughEtABjJJRYTW71olvDU0qUbfQUeaPc2NnYWO8HeyNoui5aR9z0sSZ" +# eino tool 配置 +eino_tools: + # 货易通商品上传 + hytProductUpload: + base_url: "https://gateway.dev.cdlsxd.cn/goods-admin/api/v1/goods/supplier/batch/add/complete" + add_url: "https://gateway.dev.cdlsxd.cn/sw//#/goods/goodsManage" + # 货易通供应商查询 + hytSupplierSearch: + base_url: "https://gateway.dev.cdlsxd.cn/goods-admin/api/v1/supplier/list" + # 货易通仓库查询 + hytWarehouseSearch: + base_url: "https://gateway.dev.cdlsxd.cn/goods-admin/api/v1/warehouse/list" default_prompt: diff --git a/internal/services/capability.go b/internal/services/capability.go index 5b3d152..c5dec87 100644 --- a/internal/services/capability.go +++ b/internal/services/capability.go @@ -141,7 +141,7 @@ func (s *CapabilityService) checkRequestHeader(c *fiber.Ctx) error { return errorcode.AuthNotFound } // token校验 - if token == "" || token != "A7f9KQ3mP2X8LZC4R5e" { + if token == "" || token != constants.CapabilityProductIngestToken { return errorcode.KeyNotFound } From adda03e5d871dfef45f4fa0b4078149ec4a6eb22 Mon Sep 17 00:00:00 2001 From: fuzhongyun <15339891972@163.com> Date: Mon, 22 Dec 2025 15:01:21 +0800 Subject: [PATCH 57/66] =?UTF-8?q?fix:=20=E5=A2=9E=E5=8A=A0=E4=BE=9B?= =?UTF-8?q?=E5=BA=94=E5=95=86=E6=8A=A5=E4=BB=B7=E5=BF=85=E5=A1=AB=E6=A0=A1?= =?UTF-8?q?=E9=AA=8C?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- internal/domain/workflow/hyt/product_upload.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/internal/domain/workflow/hyt/product_upload.go b/internal/domain/workflow/hyt/product_upload.go index fca3aec..cc9ab70 100644 --- a/internal/domain/workflow/hyt/product_upload.go +++ b/internal/domain/workflow/hyt/product_upload.go @@ -139,6 +139,9 @@ func (o *productUpload) buildWorkflow(ctx context.Context) (compose.Runnable[*Pr if ingestData.TaxRate == "" { return nil, errors.New("税率不能为空") } + if ingestData.SupplierPrice == "" { + return nil, errors.New("供应商报价不能为空") + } state.IngestData = &ingestData state.SupplierName = ingestData.SupplierName From 8d4f3c494ed8b7410d8b28edd2090a88b2dacaa0 Mon Sep 17 00:00:00 2001 From: fuzhongyun <15339891972@163.com> Date: Mon, 22 Dec 2025 17:48:48 +0800 Subject: [PATCH 58/66] =?UTF-8?q?fix:=20=E6=98=A0=E5=B0=84=E6=A8=A1?= =?UTF-8?q?=E5=9E=8B=E8=B0=83=E6=95=B4=E3=80=81chat=E6=96=B9=E6=B3=95?= =?UTF-8?q?=E8=B0=83=E6=95=B4=E3=80=81=E6=8F=90=E7=A4=BA=E8=AF=8D=E8=B0=83?= =?UTF-8?q?=E6=95=B4?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- config/config_env.yaml | 1 + config/config_test.yaml | 1 + internal/config/config.go | 1 + internal/data/constants/capability.go | 36 ++++++++++++++------------- internal/pkg/utils_ollama/client.go | 6 ++--- internal/services/capability.go | 2 +- 6 files changed, 26 insertions(+), 21 deletions(-) diff --git a/config/config_env.yaml b/config/config_env.yaml index a10a7d4..ec4f4c2 100644 --- a/config/config_env.yaml +++ b/config/config_env.yaml @@ -7,6 +7,7 @@ ollama: base_url: "http://192.168.6.109:11434" model: "qwen3-coder:480b-cloud" generate_model: "qwen3-coder:480b-cloud" + mapping_model: "deepseek-v3.2:cloud" vl_model: "qwen2.5vl:7b" timeout: "120s" level: "info" diff --git a/config/config_test.yaml b/config/config_test.yaml index fe68122..45fb701 100644 --- a/config/config_test.yaml +++ b/config/config_test.yaml @@ -7,6 +7,7 @@ ollama: base_url: "http://host.docker.internal:11434" model: "qwen3-coder:480b-cloud" generate_model: "qwen3-coder:480b-cloud" + mapping_model: "deepseek-v3.2:cloud" vl_model: "gemini-3-pro-preview" timeout: "120s" level: "info" diff --git a/internal/config/config.go b/internal/config/config.go index 5cbb506..b30adae 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -81,6 +81,7 @@ type OllamaConfig struct { BaseURL string `mapstructure:"base_url"` Model string `mapstructure:"model"` GenerateModel string `mapstructure:"generate_model"` + MappingModel string `mapstructure:"mapping_model"` VlModel string `mapstructure:"vl_model"` Timeout time.Duration `mapstructure:"timeout"` } diff --git a/internal/data/constants/capability.go b/internal/data/constants/capability.go index fab9db7..5b40d9f 100644 --- a/internal/data/constants/capability.go +++ b/internal/data/constants/capability.go @@ -9,41 +9,43 @@ const ( const ( SystemPrompt = ` #你是一个专业的商品属性提取助手,你的任务是根据用户输入提取商品的属性信息。 - 1.最终输出格式为纯JSON字符串,键值对对应目标属性和提取到的属性值。 - 2.最终输出不要携带markdown标识,不要携带回车换行` + 关键格式要求: + 1.输出必须是一个紧凑的、无任何多余空白字符的纯JSON字符串。 + 2.确保整个JSON输出在一行内完成,键、值、冒号、引号、括号之间均不要换行。 + 3.最终输出不要携带任何markdown标识(如json),直接输出纯JSON内容。` ) // 商品属性模板-中文 const ( // 货易通商品属性模板-中文 HYTProductPropertyTemplateZH = `{ - "条码": "string", // 商品编号 + "货品编号": "string", // 商品编号 + "条码": "string", // 货品编号 "分类名称": "string", // 商品分类 "货品名称": "string", // 商品名称 - "货品编号": "string", // 商品编号 - "商品货号": "string", // 商品编号 + "商品货号": "string", // 货品编号 "品牌": "string", // 商品品牌 "单位": "string", // 商品单位,若无则使用'个' "规格参数": "string", // 商品规格参数 "货品说明": "string", // 商品说明 - "保质期": "string", // 商品保质期 - "保质期单位": "string", // 商品保质期单位 - "链接": "string", // 商品链接 - "货品图片": ["string"], // 商品多图,取1-2个即可 + "保质期": "string", // 商品保质期,无则空 + "保质期单位": "string", // 商品保质期单位,无则空 + "链接": "string", // 空 + "货品图片": ["string"], // 商品多图,取前2个即可 "电商销售价格": "string", // 商品电商销售价格 decimal(10,2) "销售价": "string", // 商品销售价格 decimal(10,2) - "备注": "string", // 备注 + "备注": "string", // 无则空 "长": "string", // 商品长度,decimal(10,2)+单位 "宽": "string", // 商品宽度,decimal(10,2)+单位 "高": "string", // 商品高度,decimal(10,2)+单位 - "重量": "string", // 商品重量(kg) + "重量": "string", // 商品重量,decimal(10,2)+单位(kg) "SPU名称": "string", // 商品SPU名称 - "SPU编码": "string" // 编码串,jd_{timestamp}_rand(1000-999) - "供应商报价": "string", // 商品供应商报价 decimal(10,2) - "税率": "string", // 商品税率 x% - "利润": "string", // 商品利润 decimal(10,2) - "默认供应商": "string", // 供应商名称 - "默认存放仓库": "string", // 仓库名称 + "SPU编码": "string" // 货品编号 + "供应商报价": "string", // 空 + "税率": "string", // 商品税率 x%,无则空 + "利润": "string", // 空 + "默认供应商": "string", // 空 + "默认存放仓库": "string", // 空 }` ) diff --git a/internal/pkg/utils_ollama/client.go b/internal/pkg/utils_ollama/client.go index 1f67774..fa88afa 100644 --- a/internal/pkg/utils_ollama/client.go +++ b/internal/pkg/utils_ollama/client.go @@ -90,13 +90,13 @@ func (c *Client) ChatStream(ctx context.Context, ch chan entitys.Response, messa return } -func (c *Client) Chat(ctx context.Context, messages []api.Message) (res api.ChatResponse, err error) { +func (c *Client) Chat(ctx context.Context, model string, messages []api.Message) (res api.ChatResponse, err error) { // 构建聊天请求 req := &api.ChatRequest{ - Model: c.config.Model, + Model: model, Messages: messages, Stream: new(bool), // 设置为false,不使用流式响应 - Think: &api.ThinkValue{Value: true}, + Think: &api.ThinkValue{Value: false}, } err = c.client.Chat(ctx, req, func(resp api.ChatResponse) error { res = resp diff --git a/internal/services/capability.go b/internal/services/capability.go index c5dec87..f163a79 100644 --- a/internal/services/capability.go +++ b/internal/services/capability.go @@ -87,7 +87,7 @@ func (s *CapabilityService) ProductIngest(c *fiber.Ctx) error { return err } defer cleanup() - res, err := client.Chat(ctx, []api.Message{ + res, err := client.Chat(ctx, s.cfg.Ollama.MappingModel, []api.Message{ { Role: "system", Content: constants.SystemPrompt, From 208f749483c24b942a002652b2438e6a6ec95924 Mon Sep 17 00:00:00 2001 From: fuzhongyun <15339891972@163.com> Date: Wed, 24 Dec 2025 11:52:00 +0800 Subject: [PATCH 59/66] =?UTF-8?q?feat:=201.=E6=96=B0=E5=A2=9E=E5=A4=9A?= =?UTF-8?q?=E4=B8=AA=E8=B4=A7=E6=98=93=E9=80=9A=E5=B7=A5=E5=85=B7=202.?= =?UTF-8?q?=E6=96=B0=E5=A2=9E=E8=B4=A7=E6=98=93=E9=80=9A=E5=88=9B=E5=BB=BA?= =?UTF-8?q?=E5=95=86=E5=93=81=E5=B7=A5=E4=BD=9C=E6=B5=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- config/config_env.yaml | 17 + internal/config/config.go | 10 + internal/data/constants/capability.go | 31 +- internal/domain/tools/hyt/goods_add/client.go | 49 +++ .../domain/tools/hyt/goods_add/client_test.go | 51 +++ internal/domain/tools/hyt/goods_add/types.go | 35 ++ .../tools/hyt/goods_brand_search/client.go | 67 +++ .../hyt/goods_brand_search/client_test.go | 28 ++ .../tools/hyt/goods_brand_search/types.go | 25 ++ .../tools/hyt/goods_category_add/client.go | 49 +++ .../hyt/goods_category_add/client_test.go | 31 ++ .../tools/hyt/goods_category_add/types.go | 15 + .../tools/hyt/goods_category_search/client.go | 66 +++ .../tools/hyt/goods_category_search/types.go | 24 ++ .../tools/hyt/goods_media_add/client.go | 49 +++ .../tools/hyt/goods_media_add/client_test.go | 37 ++ .../domain/tools/hyt/goods_media_add/types.go | 21 + .../domain/tools/hyt/product_upload/client.go | 2 +- .../tools/hyt/product_upload/client_test.go | 61 --- .../tools/hyt/supplier_search/client.go | 1 - .../tools/hyt/warehouse_search/client.go | 1 - internal/domain/tools/registry.go | 27 +- internal/domain/workflow/hyt/goods_add.go | 387 ++++++++++++++++++ .../domain/workflow/hyt/product_upload.go | 14 +- internal/services/capability.go | 4 +- 25 files changed, 1021 insertions(+), 81 deletions(-) create mode 100644 internal/domain/tools/hyt/goods_add/client.go create mode 100644 internal/domain/tools/hyt/goods_add/client_test.go create mode 100644 internal/domain/tools/hyt/goods_add/types.go create mode 100644 internal/domain/tools/hyt/goods_brand_search/client.go create mode 100644 internal/domain/tools/hyt/goods_brand_search/client_test.go create mode 100644 internal/domain/tools/hyt/goods_brand_search/types.go create mode 100644 internal/domain/tools/hyt/goods_category_add/client.go create mode 100644 internal/domain/tools/hyt/goods_category_add/client_test.go create mode 100644 internal/domain/tools/hyt/goods_category_add/types.go create mode 100644 internal/domain/tools/hyt/goods_category_search/client.go create mode 100644 internal/domain/tools/hyt/goods_category_search/types.go create mode 100644 internal/domain/tools/hyt/goods_media_add/client.go create mode 100644 internal/domain/tools/hyt/goods_media_add/client_test.go create mode 100644 internal/domain/tools/hyt/goods_media_add/types.go delete mode 100644 internal/domain/tools/hyt/product_upload/client_test.go create mode 100644 internal/domain/workflow/hyt/goods_add.go diff --git a/config/config_env.yaml b/config/config_env.yaml index ec4f4c2..dd120a0 100644 --- a/config/config_env.yaml +++ b/config/config_env.yaml @@ -88,6 +88,23 @@ eino_tools: # 货易通仓库查询 hytWarehouseSearch: base_url: "https://gateway.dev.cdlsxd.cn/goods-admin/api/v1/warehouse/list" + # 货易通商品添加 + hytGoodsAdd: + base_url: "https://gateway.dev.cdlsxd.cn/goods-admin/api/v1/goods/add" + # 货易通商品图片添加 + hytGoodsMediaAdd: + base_url: "https://gateway.dev.cdlsxd.cn/goods-admin/api/v1/media/add/batch" + # 货易通商品分类添加 + hytGoodsCategoryAdd: + base_url: "https://gateway.dev.cdlsxd.cn/goods-admin/api/v1/good/category/relation/add" + # 货易通商品分类查询 + hytGoodsCategorySearch: + base_url: "https://gateway.dev.cdlsxd.cn/goods-admin/api/v1/goods/category/list" + # 货易通商品品牌查询 + hytGoodsBrandSearch: + base_url: "https://gateway.dev.cdlsxd.cn/goods-admin/api/v1/goods/brand/list" + + default_prompt: img_recognize: diff --git a/internal/config/config.go b/internal/config/config.go index b30adae..3198ebd 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -163,6 +163,16 @@ type EinoToolsConfig struct { HytSupplierSearch ToolConfig `mapstructure:"hytSupplierSearch"` // 货易通仓库查询 HytWarehouseSearch ToolConfig `mapstructure:"hytWarehouseSearch"` + // 货易通商品添加 + HytGoodsAdd ToolConfig `mapstructure:"hytGoodsAdd"` + // 货易通商品图片添加 + HytGoodsMediaAdd ToolConfig `mapstructure:"hytGoodsMediaAdd"` + // 货易通商品分类添加 + HytGoodsCategoryAdd ToolConfig `mapstructure:"hytGoodsCategoryAdd"` + // 货易通商品分类查询 + HytGoodsCategorySearch ToolConfig `mapstructure:"hytGoodsCategorySearch"` + // 货易通商品品牌查询 + HytGoodsBrandSearch ToolConfig `mapstructure:"hytGoodsBrandSearch"` } // LoggingConfig 日志配置 diff --git a/internal/data/constants/capability.go b/internal/data/constants/capability.go index 5b40d9f..2555cda 100644 --- a/internal/data/constants/capability.go +++ b/internal/data/constants/capability.go @@ -17,8 +17,8 @@ const ( // 商品属性模板-中文 const ( - // 货易通商品属性模板-中文 - HYTProductPropertyTemplateZH = `{ + // 货易通供应商商品属性模板-中文 + HYTSupplierProductPropertyTemplateZH = `{ "货品编号": "string", // 商品编号 "条码": "string", // 货品编号 "分类名称": "string", // 商品分类 @@ -47,6 +47,33 @@ const ( "默认供应商": "string", // 空 "默认存放仓库": "string", // 空 }` + // 货易通商品属性模板-中文 Ps:手机端主图、详情图文、平台资质图 (暂时无需) + HYTGoodsAddPropertyTemplateZH = `{ + "商品标题": "string", // 商品名称 + "商品编码": "string", // 商品编码 + "SPU名称": "string", // 商品SPU名称 + "SPU编码": "string", // 商品编码 + "商品货号": "string", // 商品货号 + "商品条形码": "string", // 商品编码 + "市场价": "string", // 商品市场价 decimal(10,2) + "建议销售价": "string", // 商品建议销售价 decimal(10,2) + "电商销售价格": "string", // 商品电商销售价格 decimal(10,2) + "单位": "string", // 商品单位,若无则使用'个' + "折扣(%)": "string", // 商品折扣(%),默认0% + "税率(%)": "string", // 商品税率(%),默认13% + "运费模版": "string", // 商品运费模版,默认空 + "保质期": "string", // 商品保质期,无则空 + "保质期单位": "string", // 商品保质期单位,无则空 + "品牌": "string", // 商品品牌,若无则空 + "是否热销主推": "string", // 填否 + "外部平台链接": "string", // 商品外部平台链接 + "商品卖点": "string", // 商品卖点 + "商品规格参数": "string", // 商品规格参数 + "商品说明": "string", // 商品说明 + "备注": "string", // 无则空 + "分类名称": "string", // 商品分类 + "电脑端主图": ["string"], // 商品电脑端主图 + }` ) // 缓存key diff --git a/internal/domain/tools/hyt/goods_add/client.go b/internal/domain/tools/hyt/goods_add/client.go new file mode 100644 index 0000000..7f91d66 --- /dev/null +++ b/internal/domain/tools/hyt/goods_add/client.go @@ -0,0 +1,49 @@ +package goods_add + +import ( + "ai_scheduler/internal/config" + "ai_scheduler/internal/pkg/l_request" + "ai_scheduler/internal/pkg/util" + "context" + "encoding/json" + "fmt" +) + +type Client struct { + cfg config.ToolConfig +} + +func New(cfg config.ToolConfig) *Client { + return &Client{ + cfg: cfg, + } +} + +func (c *Client) Call(ctx context.Context, req *GoodsAddRequest) (int, error) { + apiReq, _ := util.StructToMap(req) + + r := l_request.Request{ + Method: "Post", + Url: c.cfg.BaseURL, + Json: apiReq, + Headers: map[string]string{ + "Content-Type": "application/json", + }, + } + + res, err := r.Send() + if err != nil { + return 0, fmt.Errorf("请求失败,err: %v", err) + } + + var resData GoodsAddResponse + if err := json.Unmarshal([]byte(res.Text), &resData); err != nil { + return 0, fmt.Errorf("解析响应失败,err: %v", err) + } + + if resData.Code != 200 { + return 0, fmt.Errorf("业务错误,code: %d, msg: %s", resData.Code, resData.Msg) + } + + return resData.Data.Id, nil +} diff --git a/internal/domain/tools/hyt/goods_add/client_test.go b/internal/domain/tools/hyt/goods_add/client_test.go new file mode 100644 index 0000000..aba715a --- /dev/null +++ b/internal/domain/tools/hyt/goods_add/client_test.go @@ -0,0 +1,51 @@ +package goods_add + +import ( + "ai_scheduler/internal/config" + "context" + "fmt" + "testing" +) + +// Test_Call +func Test_Call(t *testing.T) { + req := &GoodsAddRequest{ + Unit: "元", + IsComposeGoods: 2, + GoodsAttributes: "

商品规格参数

", + Introduction: "

商品卖点

", + GoodsIllustration: "

商品说明

", + IsHot: 2, + Title: "fu测试001", + GoodsNum: "futest001sku", + SpuCode: "futest001spu", + SpuName: "fu测试001", + Price: 100, + SalesPrice: 80, + Discount: 15, + TaxRate: 13, + FreightId: 3, + Remark: "备注说明", + SellByDate: 180, + ExternalPrice: 120, + GoodsBarCode: "futest001code2", + GoodsCode: "futest001code1", + SellByDateUnit: "天", + BrandId: 3, + ExternalUrl: "https://www.baidu.com", + } + + cfg := config.ToolConfig{ + BaseURL: "https://gateway.dev.cdlsxd.cn/goods-admin/api/v1/goods/add", + } + + client := New(cfg) + toolResp, err := client.Call(context.Background(), req) + + if err != nil { + t.Errorf("Call() error = %v", err) + return + } + + fmt.Printf("toolResp: %v\n", toolResp) +} diff --git a/internal/domain/tools/hyt/goods_add/types.go b/internal/domain/tools/hyt/goods_add/types.go new file mode 100644 index 0000000..ef9e168 --- /dev/null +++ b/internal/domain/tools/hyt/goods_add/types.go @@ -0,0 +1,35 @@ +package goods_add + +type GoodsAddRequest struct { + Title string `json:"title"` // 商品标题 + GoodsCode string `json:"goods_code"` // 商品编码 + SpuName string `json:"spu_name"` // SPU 名称 + SpuCode string `json:"spu_code"` // SPU 编码 + GoodsNum string `json:"goods_num"` // 商品货号 + GoodsBarCode string `json:"goods_bar_code"` // 商品条形码 + Price float64 `json:"price"` // 市场价 + SalesPrice float64 `json:"sales_price"` // 建议销售价 + ExternalPrice float64 `json:"external_price"` // 电商销售价格 + Unit string `json:"unit"` // 价格单位 + Discount int `json:"discount"` // 折扣 + TaxRate int `json:"tax_rate"` // 税率 + FreightId int `json:"freight_id"` // 运费模板 ID + SellByDate int `json:"sell_by_date"` // 保质期 + SellByDateUnit string `json:"sell_by_date_unit"` // 保质期单位 + BrandId int `json:"brand_id"` // 品牌 ID + IsHot int `json:"is_hot"` // 是否热销主推 1.是 2.否(默认) + ExternalUrl string `json:"external_url"` // 外部平台链接 + Introduction string `json:"introduction"` // 商品卖点 + GoodsAttributes string `json:"goods_attributes"` // 商品规格参数 + GoodsIllustration string `json:"goods_illustration"` // 商品说明 + Remark string `json:"remark"` // 备注说明 + IsComposeGoods int `json:"is_compose_goods"` // 是否组合商品 1.是 2.否(默认) +} + +type GoodsAddResponse struct { + Code int `json:"code"` + Msg string `json:"msg"` + Data struct { + Id int `json:"id"` // 商品 ID + } `json:"data"` +} diff --git a/internal/domain/tools/hyt/goods_brand_search/client.go b/internal/domain/tools/hyt/goods_brand_search/client.go new file mode 100644 index 0000000..e7b3d58 --- /dev/null +++ b/internal/domain/tools/hyt/goods_brand_search/client.go @@ -0,0 +1,67 @@ +package goods_brand_search + +import ( + "ai_scheduler/internal/config" + "ai_scheduler/internal/pkg/l_request" + "ai_scheduler/internal/pkg/util" + "context" + "encoding/json" + "fmt" +) + +type Client struct { + cfg config.ToolConfig +} + +func New(cfg config.ToolConfig) *Client { + return &Client{ + cfg: cfg, + } +} + +func (c *Client) Call(ctx context.Context, name string) (int, error) { + if name == "" { + return 0, nil + } + + reqBody := GoodsBrandSearchRequest{ + Page: 1, + Limit: 1, + Search: SearchCondition{ + Name: name, + }, + } + + apiReq, _ := util.StructToMap(reqBody) + + req := l_request.Request{ + Method: "Post", + Url: c.cfg.BaseURL, + Json: apiReq, + Headers: map[string]string{ + "User-Agent": "Apifox/1.0.0 (https://apifox.com)", + "Content-Type": "application/json", + }, + } + + res, err := req.Send() + if err != nil { + return 0, fmt.Errorf("请求失败,err: %v", err) + } + + var resData GoodsBrandSearchResponse + if err := json.Unmarshal([]byte(res.Text), &resData); err != nil { + return 0, fmt.Errorf("解析响应失败,err: %v", err) + } + + if resData.Code != 200 { + return 0, fmt.Errorf("业务错误,code: %d, msg: %s", resData.Code, resData.Msg) + } + + if len(resData.Data.List) == 0 { + return 0, fmt.Errorf("品牌不存在") + } + + // 返回第一个匹配的品牌ID + return resData.Data.List[0].ID, nil +} diff --git a/internal/domain/tools/hyt/goods_brand_search/client_test.go b/internal/domain/tools/hyt/goods_brand_search/client_test.go new file mode 100644 index 0000000..41009f1 --- /dev/null +++ b/internal/domain/tools/hyt/goods_brand_search/client_test.go @@ -0,0 +1,28 @@ +package goods_brand_search + +import ( + "ai_scheduler/internal/config" + "context" + "fmt" + "testing" +) + +// Test_Call +func Test_Call(t *testing.T) { + // 使用示例中的查询条件 + name := "vivo" + + cfg := config.ToolConfig{ + BaseURL: "https://gateway.dev.cdlsxd.cn/goods-admin/api/v1/goods/brand/list", + } + + client := New(cfg) + toolResp, err := client.Call(context.Background(), name) + + if err != nil { + t.Errorf("Call() error = %v", err) + return + } + + fmt.Printf("toolResp (BrandID): %v\n", toolResp) +} diff --git a/internal/domain/tools/hyt/goods_brand_search/types.go b/internal/domain/tools/hyt/goods_brand_search/types.go new file mode 100644 index 0000000..c3ec8bb --- /dev/null +++ b/internal/domain/tools/hyt/goods_brand_search/types.go @@ -0,0 +1,25 @@ +package goods_brand_search + +type GoodsBrandSearchRequest struct { + Page int `json:"page"` + Limit int `json:"limit"` + Search SearchCondition `json:"search"` +} + +type SearchCondition struct { + Name string `json:"name"` +} + +type GoodsBrandSearchResponse struct { + Code int `json:"code"` + Msg string `json:"msg"` + Data struct { + List []BrandInfo `json:"list"` + } `json:"data"` +} + +type BrandInfo struct { + ID int `json:"id"` + Name string `json:"name"` + Logo string `json:"logo"` +} diff --git a/internal/domain/tools/hyt/goods_category_add/client.go b/internal/domain/tools/hyt/goods_category_add/client.go new file mode 100644 index 0000000..8fa0e8b --- /dev/null +++ b/internal/domain/tools/hyt/goods_category_add/client.go @@ -0,0 +1,49 @@ +package goods_category_add + +import ( + "ai_scheduler/internal/config" + "ai_scheduler/internal/pkg/l_request" + "ai_scheduler/internal/pkg/util" + "context" + "encoding/json" + "fmt" +) + +type Client struct { + cfg config.ToolConfig +} + +func New(cfg config.ToolConfig) *Client { + return &Client{ + cfg: cfg, + } +} + +func (c *Client) Call(ctx context.Context, req *GoodsCategoryAddRequest) (bool, error) { + apiReq, _ := util.StructToMap(req) + + r := l_request.Request{ + Method: "Post", + Url: c.cfg.BaseURL, + Json: apiReq, + Headers: map[string]string{ + "Content-Type": "application/json", + }, + } + + res, err := r.Send() + if err != nil { + return false, fmt.Errorf("请求失败,err: %v", err) + } + + var resData GoodsCategoryAddResponse + if err := json.Unmarshal([]byte(res.Text), &resData); err != nil { + return false, fmt.Errorf("解析响应失败,err: %v", err) + } + + if resData.Code != 200 { + return false, fmt.Errorf("业务错误,code: %d, msg: %s", resData.Code, resData.Msg) + } + + return resData.Data.IsSuccess, nil +} diff --git a/internal/domain/tools/hyt/goods_category_add/client_test.go b/internal/domain/tools/hyt/goods_category_add/client_test.go new file mode 100644 index 0000000..fed3a94 --- /dev/null +++ b/internal/domain/tools/hyt/goods_category_add/client_test.go @@ -0,0 +1,31 @@ +package goods_category_add + +import ( + "ai_scheduler/internal/config" + "context" + "fmt" + "testing" +) + +// Test_Call +func Test_Call(t *testing.T) { + req := &GoodsCategoryAddRequest{ + GoodsId: 8496, + CategoryIds: []int{1667}, + IsCover: false, + } + + cfg := config.ToolConfig{ + BaseURL: "https://gateway.dev.cdlsxd.cn/goods-admin/api/v1/good/category/relation/add", + } + + client := New(cfg) + toolResp, err := client.Call(context.Background(), req) + + if err != nil { + t.Errorf("Call() error = %v", err) + return + } + + fmt.Printf("toolResp: %v\n", toolResp) +} diff --git a/internal/domain/tools/hyt/goods_category_add/types.go b/internal/domain/tools/hyt/goods_category_add/types.go new file mode 100644 index 0000000..e23691e --- /dev/null +++ b/internal/domain/tools/hyt/goods_category_add/types.go @@ -0,0 +1,15 @@ +package goods_category_add + +type GoodsCategoryAddRequest struct { + GoodsId int `json:"goods_id"` + CategoryIds []int `json:"category_ids"` + IsCover bool `json:"is_cover"` +} + +type GoodsCategoryAddResponse struct { + Code int `json:"code"` + Msg string `json:"msg"` + Data struct { + IsSuccess bool `json:"is_success"` // 是否成功 + } `json:"data"` +} diff --git a/internal/domain/tools/hyt/goods_category_search/client.go b/internal/domain/tools/hyt/goods_category_search/client.go new file mode 100644 index 0000000..185e54b --- /dev/null +++ b/internal/domain/tools/hyt/goods_category_search/client.go @@ -0,0 +1,66 @@ +package goods_category_search + +import ( + "ai_scheduler/internal/config" + "ai_scheduler/internal/pkg/l_request" + "ai_scheduler/internal/pkg/util" + "context" + "encoding/json" + "fmt" +) + +type Client struct { + cfg config.ToolConfig +} + +func New(cfg config.ToolConfig) *Client { + return &Client{ + cfg: cfg, + } +} + +func (c *Client) Call(ctx context.Context, name string) (int, error) { + if name == "" { + return 0, nil + } + + reqBody := GoodsCategorySearchRequest{ + Page: 1, + Limit: 1, + Search: SearchCondition{ + Name: name, + }, + } + + apiReq, _ := util.StructToMap(reqBody) + + req := l_request.Request{ + Method: "Post", + Url: c.cfg.BaseURL, + Json: apiReq, + Headers: map[string]string{ + "Content-Type": "application/json", + }, + } + + res, err := req.Send() + if err != nil { + return 0, fmt.Errorf("请求失败,err: %v", err) + } + + var resData GoodsCategorySearchResponse + if err := json.Unmarshal([]byte(res.Text), &resData); err != nil { + return 0, fmt.Errorf("解析响应失败,err: %v", err) + } + + if resData.Code != 200 { + return 0, fmt.Errorf("业务错误,code: %d, msg: %s", resData.Code, resData.Msg) + } + + if len(resData.Data.List) == 0 { + return 0, fmt.Errorf("商品分类不存在") + } + + // 返回第一个匹配的分类ID + return resData.Data.List[0].ID, nil +} diff --git a/internal/domain/tools/hyt/goods_category_search/types.go b/internal/domain/tools/hyt/goods_category_search/types.go new file mode 100644 index 0000000..dcc32e9 --- /dev/null +++ b/internal/domain/tools/hyt/goods_category_search/types.go @@ -0,0 +1,24 @@ +package goods_category_search + +type GoodsCategorySearchRequest struct { + Page int `json:"page"` + Limit int `json:"limit"` + Search SearchCondition `json:"search"` +} + +type SearchCondition struct { + Name string `json:"name"` +} + +type GoodsCategorySearchResponse struct { + Code int `json:"code"` + Msg string `json:"msg"` + Data struct { + List []CategoryInfo `json:"list"` + } `json:"data"` +} + +type CategoryInfo struct { + ID int `json:"id"` + Name string `json:"name"` +} diff --git a/internal/domain/tools/hyt/goods_media_add/client.go b/internal/domain/tools/hyt/goods_media_add/client.go new file mode 100644 index 0000000..6632168 --- /dev/null +++ b/internal/domain/tools/hyt/goods_media_add/client.go @@ -0,0 +1,49 @@ +package goods_media_add + +import ( + "ai_scheduler/internal/config" + "ai_scheduler/internal/pkg/l_request" + "ai_scheduler/internal/pkg/util" + "context" + "encoding/json" + "fmt" +) + +type Client struct { + cfg config.ToolConfig +} + +func New(cfg config.ToolConfig) *Client { + return &Client{ + cfg: cfg, + } +} + +func (c *Client) Call(ctx context.Context, req *GoodsMediaAddRequest) (bool, error) { + apiReq, _ := util.StructToMap(req) + + r := l_request.Request{ + Method: "Post", + Url: c.cfg.BaseURL, + Json: apiReq, + Headers: map[string]string{ + "Content-Type": "application/json", + }, + } + + res, err := r.Send() + if err != nil { + return false, fmt.Errorf("请求失败,err: %v", err) + } + + var resData GoodsMediaAddResponse + if err := json.Unmarshal([]byte(res.Text), &resData); err != nil { + return false, fmt.Errorf("解析响应失败,err: %v", err) + } + + if resData.Code != 200 { + return false, fmt.Errorf("业务错误,code: %d, msg: %s", resData.Code, resData.Msg) + } + + return resData.Data.IsSuccess, nil +} diff --git a/internal/domain/tools/hyt/goods_media_add/client_test.go b/internal/domain/tools/hyt/goods_media_add/client_test.go new file mode 100644 index 0000000..f6f16ca --- /dev/null +++ b/internal/domain/tools/hyt/goods_media_add/client_test.go @@ -0,0 +1,37 @@ +package goods_media_add + +import ( + "ai_scheduler/internal/config" + "context" + "fmt" + "testing" +) + +// Test_Call +func Test_Call(t *testing.T) { + req := &GoodsMediaAddRequest{ + GoodsId: 8496, + Data: []MediaItem{ + { + Type: 1, + Url: "https://lsxd-hz-store.oss-cn-hangzhou.aliyuncs.com/physicalGoodsSystems/images/goodsimages/goods/22f03d91-3cb7-45b4-ab92-07aad78a1633-screenshot_2025-12-17_17-46-00.png", + Sort: 1, + }, + }, + IsCover: true, + } + + cfg := config.ToolConfig{ + BaseURL: "https://gateway.dev.cdlsxd.cn/goods-admin/api/v1/media/add/batch", + } + + client := New(cfg) + toolResp, err := client.Call(context.Background(), req) + + if err != nil { + t.Errorf("Call() error = %v", err) + return + } + + fmt.Printf("toolResp: %v\n", toolResp) +} diff --git a/internal/domain/tools/hyt/goods_media_add/types.go b/internal/domain/tools/hyt/goods_media_add/types.go new file mode 100644 index 0000000..e299d4f --- /dev/null +++ b/internal/domain/tools/hyt/goods_media_add/types.go @@ -0,0 +1,21 @@ +package goods_media_add + +type GoodsMediaAddRequest struct { + GoodsId int `json:"goods_id"` + Data []MediaItem `json:"data"` + IsCover bool `json:"is_cover"` +} + +type MediaItem struct { + Type int `json:"type"` + Url string `json:"url"` + Sort int `json:"sort"` +} + +type GoodsMediaAddResponse struct { + Code int `json:"code"` + Msg string `json:"msg"` + Data struct { + IsSuccess bool `json:"is_success"` + } `json:"data"` +} diff --git a/internal/domain/tools/hyt/product_upload/client.go b/internal/domain/tools/hyt/product_upload/client.go index 1cd68b8..096965c 100644 --- a/internal/domain/tools/hyt/product_upload/client.go +++ b/internal/domain/tools/hyt/product_upload/client.go @@ -43,7 +43,7 @@ func (c *Client) Call(ctx context.Context, toolReq *ProductUploadRequest) (toolR Code int `json:"code"` Msg string `json:"msg"` Data struct { - Ids []int `json:"ids"` // 预览URL + Ids []int `json:"ids"` // 商品 IDs } `json:"data"` } var resMap resType diff --git a/internal/domain/tools/hyt/product_upload/client_test.go b/internal/domain/tools/hyt/product_upload/client_test.go deleted file mode 100644 index fdd99f0..0000000 --- a/internal/domain/tools/hyt/product_upload/client_test.go +++ /dev/null @@ -1,61 +0,0 @@ -package product_upload - -import ( - "ai_scheduler/internal/config" - "context" - "fmt" - "testing" -) - -// Test_Call -func Test_Call(t *testing.T) { - req := &ProductUploadRequest{ - SupplierId: 261, - WarehouseId: 257, - IsDefaultWarehouse: 1, - Sort: 1, - Profit: 40, - TaxRate: 13, - GoodsList: []Goods{ - { - GoodsInfo: GoodsInfo{ - Title: "Apple iPhone 17 Pro Max 星宇橙色 256GB", - Brand: "Apple/苹果", - Category: "手机", - // CostPrice: 9999.00, - GoodsAttributes: "CPU型号:A19 Pro;操作系统:iOS;机身存储:256GB;屏幕尺寸:6.86英寸;屏幕材质:OLED直屏;屏幕技术:视网膜XDR;后置摄像头:4800万像素三主摄系统(主摄4800万+超广角4800万+长焦4800万);前置摄像头:1800万像素;网络支持:5G双卡双待(移动/联通/电信);生物识别:人脸识别;防水等级:IP68;充电功率:40W;无线充电:支持;机身尺寸:163.4mm×78.0mm×8.75mm;机身重量:231g;机身颜色:星宇橙色;特征特质:轻薄、防水防尘、无线充电、NFC、磁吸无线充", - GoodsBarCode: "10181383848993", - GoodsIllustration: "Apple/苹果 iPhone 17 Pro Max 【需当面激活】支持移动联通电信 5G 双卡双待手机 星宇橙色 256GB 官方标配。搭载A19 Pro芯片,6.86英寸OLED视网膜XDR直屏,4800万像素三主摄系统,支持5G双卡双待,IP68防水防尘,40W有线充电,支持无线充电和磁吸充电。", - GoodsNum: "10181383848993", - Introduction: "Apple/苹果 iPhone 17 Pro Max 【需当面激活】支持移动联通电信 5G 双卡双待手机 星宇橙色 256GB 官方标配。搭载A19 Pro芯片,6.86英寸OLED视网膜XDR直屏,4800万像素三主摄系统,支持5G双卡双待,IP68防水防尘,40W有线充电,支持无线充电和磁吸充电。", - IsBind: 1, - SpuName: "Apple iPhone 17 Pro Max", - SpuNum: "jd_1766038130329_8721", - TaxRate: 13, - Unit: "台", - Weight: "0.231", // 单位:kg - Price: 9999.00, - SalesPrice: 9999.00, - Stock: 0, // JSON 中未提供库存信息 - Discount: 10, // JSON 中未提供折扣信息 - IsComposeGoods: 2, - IsHot: 2, - }, - GoodsMediaList: []GoodsMedia{ - { - Type: 1, - Url: "https://img10.360buyimg.com/pcpubliccms/s228x228_jfs/t1/363919/12/2409/45712/691d9970F84b99d32/9f9a5d5d16efeb79.jpg.avif", - }, - }, - }, - }, - } - client := New(config.ToolConfig{}) - toolResp, err := client.Call(context.Background(), req) - - if err != nil { - t.Errorf("Call() error = %v", err) - } - - fmt.Printf("toolResp: %v\n", toolResp) -} diff --git a/internal/domain/tools/hyt/supplier_search/client.go b/internal/domain/tools/hyt/supplier_search/client.go index cbb20b4..1f47ee8 100644 --- a/internal/domain/tools/hyt/supplier_search/client.go +++ b/internal/domain/tools/hyt/supplier_search/client.go @@ -41,7 +41,6 @@ func (c *Client) Call(ctx context.Context, name string) (int, error) { Url: c.cfg.BaseURL, Json: apiReq, Headers: map[string]string{ - "User-Agent": "Apifox/1.0.0 (https://apifox.com)", "Content-Type": "application/json", }, } diff --git a/internal/domain/tools/hyt/warehouse_search/client.go b/internal/domain/tools/hyt/warehouse_search/client.go index 4502c42..cf420b2 100644 --- a/internal/domain/tools/hyt/warehouse_search/client.go +++ b/internal/domain/tools/hyt/warehouse_search/client.go @@ -36,7 +36,6 @@ func (c *Client) Call(ctx context.Context, name string) (int, error) { Url: c.cfg.BaseURL, Params: params, Headers: map[string]string{ - "User-Agent": "Apifox/1.0.0 (https://apifox.com)", "Content-Type": "application/json", }, } diff --git a/internal/domain/tools/registry.go b/internal/domain/tools/registry.go index 31a8636..ad9439d 100644 --- a/internal/domain/tools/registry.go +++ b/internal/domain/tools/registry.go @@ -2,6 +2,11 @@ package tools import ( "ai_scheduler/internal/config" + "ai_scheduler/internal/domain/tools/hyt/goods_add" + "ai_scheduler/internal/domain/tools/hyt/goods_brand_search" + "ai_scheduler/internal/domain/tools/hyt/goods_category_add" + "ai_scheduler/internal/domain/tools/hyt/goods_category_search" + "ai_scheduler/internal/domain/tools/hyt/goods_media_add" "ai_scheduler/internal/domain/tools/hyt/product_upload" "ai_scheduler/internal/domain/tools/hyt/supplier_search" "ai_scheduler/internal/domain/tools/hyt/warehouse_search" @@ -13,17 +18,27 @@ type Manager struct { } type HytTools struct { - ProductUpload *product_upload.Client - SupplierSearch *supplier_search.Client - WarehouseSearch *warehouse_search.Client + ProductUpload *product_upload.Client + SupplierSearch *supplier_search.Client + WarehouseSearch *warehouse_search.Client + GoodsAdd *goods_add.Client + GoodsMediaAdd *goods_media_add.Client + GoodsCategoryAdd *goods_category_add.Client + GoodsCategorySearch *goods_category_search.Client + GoodsBrandSearch *goods_brand_search.Client } func NewManager(cfg *config.Config) *Manager { return &Manager{ Hyt: &HytTools{ - ProductUpload: product_upload.New(cfg.EinoTools.HytProductUpload), - SupplierSearch: supplier_search.New(cfg.EinoTools.HytSupplierSearch), - WarehouseSearch: warehouse_search.New(cfg.EinoTools.HytWarehouseSearch), + ProductUpload: product_upload.New(cfg.EinoTools.HytProductUpload), + SupplierSearch: supplier_search.New(cfg.EinoTools.HytSupplierSearch), + WarehouseSearch: warehouse_search.New(cfg.EinoTools.HytWarehouseSearch), + GoodsAdd: goods_add.New(cfg.EinoTools.HytGoodsAdd), + GoodsMediaAdd: goods_media_add.New(cfg.EinoTools.HytGoodsMediaAdd), + GoodsCategoryAdd: goods_category_add.New(cfg.EinoTools.HytGoodsCategoryAdd), + GoodsCategorySearch: goods_category_search.New(cfg.EinoTools.HytGoodsCategorySearch), + GoodsBrandSearch: goods_brand_search.New(cfg.EinoTools.HytGoodsBrandSearch), }, } } diff --git a/internal/domain/workflow/hyt/goods_add.go b/internal/domain/workflow/hyt/goods_add.go new file mode 100644 index 0000000..7bd2dd4 --- /dev/null +++ b/internal/domain/workflow/hyt/goods_add.go @@ -0,0 +1,387 @@ +package hyt + +import ( + "ai_scheduler/internal/config" + errorcode "ai_scheduler/internal/data/error" + toolManager "ai_scheduler/internal/domain/tools" + "ai_scheduler/internal/domain/tools/hyt/goods_add" + "ai_scheduler/internal/domain/tools/hyt/goods_category_add" + "ai_scheduler/internal/domain/tools/hyt/goods_media_add" + "ai_scheduler/internal/domain/workflow/runtime" + "ai_scheduler/internal/entitys" + "context" + "encoding/json" + "errors" + "fmt" + "log" + "strconv" + "strings" + "sync" + + "github.com/cloudwego/eino/compose" +) + +const WorkflowIDGoodsAdd = "hyt.goodsAdd" + +func init() { + runtime.Register(WorkflowIDGoodsAdd, func(d *runtime.Deps) (runtime.Workflow, error) { + return &goodsAdd{cfg: d.Conf, toolManager: d.ToolManager}, nil + }) +} + +type goodsAdd struct { + cfg *config.Config + toolManager *toolManager.Manager + data *GoodsAddWorkflowInput +} + +type GoodsAddWorkflowInput struct { + Text string `mapstructure:"text"` +} + +func (o *goodsAdd) ID() string { return WorkflowIDGoodsAdd } + +func (o *goodsAdd) Invoke(ctx context.Context, rec *entitys.Recognize) (map[string]any, error) { + // 构建工作流 + runnable, err := o.buildWorkflow(ctx) + if err != nil { + return nil, err + } + + o.data = &GoodsAddWorkflowInput{ + Text: rec.UserContent.Text, + } + // 工作流过程调用 + output, err := runnable.Invoke(ctx, o.data) + if err != nil { + errStr := err.Error() + if u := errors.Unwrap(err); u != nil { + errStr = u.Error() + } + return nil, errorcode.WorkflowErr(errStr) + } + + return output, nil +} + +// ProductIngestData 对应 HYTGoodsAddPropertyTemplateZH 的结构 +type GoodsAddProductIngestData struct { + Title string `json:"商品标题"` + GoodsCode string `json:"商品编码"` + SpuName string `json:"SPU名称"` + SpuCode string `json:"SPU编码"` + GoodsNum string `json:"商品货号"` + GoodsBarCode string `json:"商品条形码"` + Price string `json:"市场价"` + SalesPrice string `json:"建议销售价"` + ExternalPrice string `json:"电商销售价格"` + Unit string `json:"单位"` + Discount string `json:"折扣(%)"` + TaxRate string `json:"税率(%)"` + FreightTemplate string `json:"运费模版"` + SellByDate string `json:"保质期"` + SellByDateUnit string `json:"保质期单位"` + Brand string `json:"品牌"` + IsHot string `json:"是否热销主推"` + ExternalUrl string `json:"外部平台链接"` + Introduction string `json:"商品卖点"` + GoodsAttributes string `json:"商品规格参数"` + GoodsIllustration string `json:"商品说明"` + Remark string `json:"备注"` + CategoryName string `json:"分类名称"` + Images []string `json:"电脑端主图"` +} + +// GoodsAddContext Graph 执行上下文状态 +type GoodsAddContext struct { + mu *sync.Mutex + InputText string + IngestData *GoodsAddProductIngestData + + // 核心请求体 + AddGoodsReq *goods_add.GoodsAddRequest + + // 中间态数据 + BrandId int + CategoryId int + BrandName string + CategoryName string + + // 运行结果 + GoodsId int + Result map[string]any +} + +// buildWorkflow 构建基于 Graph 的并行工作流 +func (o *goodsAdd) buildWorkflow(ctx context.Context) (compose.Runnable[*GoodsAddWorkflowInput, map[string]any], error) { + g := compose.NewGraph[*GoodsAddWorkflowInput, map[string]any]() + + // 1. DataMapping 节点: 解析 JSON -> 填充基础 Request + g.AddLambdaNode("data_mapping", compose.InvokableLambda(func(ctx context.Context, in *GoodsAddWorkflowInput) (*GoodsAddContext, error) { + state := &GoodsAddContext{ + mu: &sync.Mutex{}, // 初始化锁 + InputText: in.Text, + AddGoodsReq: &goods_add.GoodsAddRequest{}, + Result: make(map[string]any), + } + + // 解析用户输入的中文 JSON + var ingestData GoodsAddProductIngestData + if err := json.Unmarshal([]byte(in.Text), &ingestData); err != nil { + return nil, fmt.Errorf("解析商品数据失败: %w", err) + } + + // 必填校验 + if ingestData.Title == "" { + return nil, errors.New("商品标题不能为空") + } + if ingestData.GoodsCode == "" { + return nil, errors.New("商品编码不能为空") + } + if ingestData.SpuName == "" { + return nil, errors.New("SPU名称不能为空") + } + if ingestData.SpuCode == "" { + return nil, errors.New("SPU编码不能为空") + } + if ingestData.Price == "" { + return nil, errors.New("市场价不能为空") + } + if ingestData.SalesPrice == "" { + return nil, errors.New("建议销售价不能为空") + } + if ingestData.Unit == "" { + return nil, errors.New("价格单位不能为空") + } + if ingestData.Discount == "" { + return nil, errors.New("折扣不能为空") + } + if ingestData.TaxRate == "" { + return nil, errors.New("税率不能为空") + } + + state.IngestData = &ingestData + state.BrandName = ingestData.Brand + state.CategoryName = ingestData.CategoryName + + // 映射字段到 AddGoodsReq + state.AddGoodsReq.Title = ingestData.Title + state.AddGoodsReq.GoodsCode = ingestData.GoodsCode + state.AddGoodsReq.SpuName = ingestData.SpuName + state.AddGoodsReq.SpuCode = ingestData.SpuCode + state.AddGoodsReq.GoodsNum = ingestData.GoodsNum + state.AddGoodsReq.GoodsBarCode = ingestData.GoodsBarCode + + // 价格处理 + if val, err := strconv.ParseFloat(strings.TrimSuffix(ingestData.Price, "元"), 64); err == nil { + state.AddGoodsReq.Price = val + } + if val, err := strconv.ParseFloat(strings.TrimSuffix(ingestData.SalesPrice, "元"), 64); err == nil { + state.AddGoodsReq.SalesPrice = val + } + if val, err := strconv.ParseFloat(strings.TrimSuffix(ingestData.ExternalPrice, "元"), 64); err == nil && state.AddGoodsReq.Price == 0 { + state.AddGoodsReq.ExternalPrice = val + } + + state.AddGoodsReq.Unit = ingestData.Unit + + // 折扣处理 "80%" -> 80 + discountStr := strings.TrimSuffix(ingestData.Discount, "%") + if val, err := strconv.Atoi(discountStr); err == nil { + state.AddGoodsReq.Discount = val + } + // 税率处理 "13%" -> 13 + taxStr := strings.TrimSuffix(strings.TrimSuffix(ingestData.TaxRate, "%"), " ") + if val, err := strconv.Atoi(taxStr); err == nil { + state.AddGoodsReq.TaxRate = val + } + + // 运费模板先不给 state.AddGoodsReq.FreightId = 3 + + // 保质期处理 "180天" -> 180 + sellByDateStr := strings.TrimSuffix(ingestData.SellByDate, "天") + if val, err := strconv.Atoi(sellByDateStr); err == nil { + state.AddGoodsReq.SellByDate = val + } + state.AddGoodsReq.SellByDateUnit = ingestData.SellByDateUnit + + // state.AddGoodsReq.BrandId 品牌ID后续赋值 + + state.AddGoodsReq.IsHot = 2 + if ingestData.IsHot == "是" { + state.AddGoodsReq.IsHot = 1 + } + + state.AddGoodsReq.ExternalUrl = ingestData.ExternalUrl + state.AddGoodsReq.Introduction = ingestData.Introduction + state.AddGoodsReq.GoodsAttributes = ingestData.GoodsAttributes + state.AddGoodsReq.GoodsIllustration = ingestData.GoodsIllustration + state.AddGoodsReq.Remark = ingestData.Remark + state.AddGoodsReq.IsComposeGoods = 2 // 非组合商品 + + return state, nil + })) + + // 2. 获取品牌ID 节点 (并行) + g.AddLambdaNode("get_brand_id", compose.InvokableLambda(func(ctx context.Context, state *GoodsAddContext) (*GoodsAddContext, error) { + if state.BrandName == "" { + return state, errors.New("品牌名称不能为空") + } + + brandId, err := o.toolManager.Hyt.GoodsBrandSearch.Call(ctx, state.BrandName) + if err != nil { + log.Printf("warning: 品牌ID获取失败,%s: %v\n", state.BrandName, err) + // 如果获取失败,不阻断后续流程 + return nil, nil + } + + state.mu.Lock() + defer state.mu.Unlock() + state.BrandId = brandId + state.AddGoodsReq.BrandId = brandId + + return state, nil + })) + + // 3. 获取分类ID 节点 (并行) + g.AddLambdaNode("get_category_id", compose.InvokableLambda(func(ctx context.Context, state *GoodsAddContext) (*GoodsAddContext, error) { + if state.CategoryName == "" { + return state, errors.New("分类名称不能为空") + } + + categoryId, err := o.toolManager.Hyt.GoodsCategorySearch.Call(ctx, state.CategoryName) + if err != nil { + log.Printf("warning: 分类ID获取失败,%s: %v\n", state.CategoryName, err) + // 如果获取失败,不阻断后续流程 + return nil, nil + } + + state.mu.Lock() + defer state.mu.Unlock() + state.CategoryId = categoryId + + return state, nil + })) + + // 4. 新增商品 节点 (依赖 get_brand_id) + g.AddLambdaNode("goods_add", compose.InvokableLambda(func(ctx context.Context, state *GoodsAddContext) (*GoodsAddContext, error) { + // 校验 BrandId + if state.AddGoodsReq.BrandId == 0 { + return nil, errors.New("Missing Brand ID") + } + + // 调用 goods_add 工具 + goodsId, err := o.toolManager.Hyt.GoodsAdd.Call(ctx, state.AddGoodsReq) + if err != nil { + return nil, fmt.Errorf("新增商品失败: %w", err) + } + state.GoodsId = goodsId + + state.Result["goods_id"] = state.GoodsId + state.Result["spu_code"] = state.AddGoodsReq.SpuCode + return state, nil + })) + + // 5. 新增商品分类 节点 (依赖 goods_add 和 get_category_id) + g.AddLambdaNode("goods_category_add", compose.InvokableLambda(func(ctx context.Context, state *GoodsAddContext) (*GoodsAddContext, error) { + if state.GoodsId == 0 { + return nil, errors.New("goods_id is 0") + } + if state.CategoryId == 0 { + return nil, errors.New("category_id is 0") + } + + req := &goods_category_add.GoodsCategoryAddRequest{ + GoodsId: state.GoodsId, + CategoryIds: []int{state.CategoryId}, + IsCover: false, + } + + _, err := o.toolManager.Hyt.GoodsCategoryAdd.Call(ctx, req) + if err != nil { + log.Printf("warning: 关联分类失败: %v", err) + state.mu.Lock() + state.Result["category_error"] = err.Error() + state.mu.Unlock() + } else { + state.mu.Lock() + state.Result["category_added"] = true + state.mu.Unlock() + } + + return state, nil + })) + + // 6. 新增商品图片 节点 (依赖 goods_add) + g.AddLambdaNode("goods_media_add", compose.InvokableLambda(func(ctx context.Context, state *GoodsAddContext) (*GoodsAddContext, error) { + if state.GoodsId == 0 { + return nil, errors.New("goods_id is 0") + } + if len(state.IngestData.Images) == 0 { + return state, nil + } + + req := &goods_media_add.GoodsMediaAddRequest{ + GoodsId: state.GoodsId, + IsCover: true, + Data: make([]goods_media_add.MediaItem, 0), + } + + for i, url := range state.IngestData.Images { + req.Data = append(req.Data, goods_media_add.MediaItem{ + Type: 1, // 图片 + Url: url, + Sort: i, + }) + } + + _, err := o.toolManager.Hyt.GoodsMediaAdd.Call(ctx, req) + if err != nil { + log.Printf("warning: 添加图片失败: %v", err) + state.mu.Lock() + state.Result["media_error"] = err.Error() + state.mu.Unlock() + } else { + state.mu.Lock() + state.Result["media_added"] = true + state.mu.Unlock() + } + + return state, nil + })) + + // 7. 结果格式化节点 + g.AddLambdaNode("format_output", compose.InvokableLambda(func(ctx context.Context, state *GoodsAddContext) (map[string]any, error) { + return state.Result, nil + })) + + // 构建边 (DAG) + + // Start -> DataMapping + g.AddEdge(compose.START, "data_mapping") + + // Branching: DataMapping -> GetBrandId, DataMapping -> GetCategoryId + g.AddEdge("data_mapping", "get_brand_id") + g.AddEdge("data_mapping", "get_category_id") + + // Synchronization for GoodsAdd: Need BrandId + g.AddEdge("get_brand_id", "goods_add") + + // Synchronization for CategoryAdd: Need GoodsId AND CategoryId + // Eino supports multi-predecessor nodes which act as merge points. + // state merging is handled by the framework (usually last writer wins or custom merge, but here we modify different fields/mutex). + // However, we need to ensure goods_add is done. + g.AddEdge("goods_add", "goods_category_add") + g.AddEdge("get_category_id", "goods_category_add") + + // Synchronization for MediaAdd: Need GoodsId + g.AddEdge("goods_add", "goods_media_add") + + // Final Merge + g.AddEdge("goods_category_add", "format_output") + g.AddEdge("goods_media_add", "format_output") + + g.AddEdge("format_output", compose.END) + + return g.Compile(ctx) +} diff --git a/internal/domain/workflow/hyt/product_upload.go b/internal/domain/workflow/hyt/product_upload.go index cc9ab70..35114ed 100644 --- a/internal/domain/workflow/hyt/product_upload.go +++ b/internal/domain/workflow/hyt/product_upload.go @@ -19,10 +19,10 @@ import ( "github.com/cloudwego/eino/compose" ) -const WorkflowID = "hyt.productUpload" +const WorkflowIDProductUpload = "hyt.productUpload" func init() { - runtime.Register(WorkflowID, func(d *runtime.Deps) (runtime.Workflow, error) { + runtime.Register(WorkflowIDProductUpload, func(d *runtime.Deps) (runtime.Workflow, error) { return &productUpload{cfg: d.Conf, toolManager: d.ToolManager}, nil }) } @@ -37,7 +37,7 @@ type ProductUploadWorkflowInput struct { Text string `mapstructure:"text"` } -func (o *productUpload) ID() string { return WorkflowID } +func (o *productUpload) ID() string { return WorkflowIDProductUpload } func (o *productUpload) Invoke(ctx context.Context, rec *entitys.Recognize) (map[string]any, error) { // 构建工作流 @@ -64,8 +64,8 @@ func (o *productUpload) Invoke(ctx context.Context, rec *entitys.Recognize) (map return output, nil } -// ProductIngestData 对应 HYTProductPropertyTemplateZH 的结构 -type ProductIngestData struct { +// ProductIngestData 对应 HYTSupplierProductPropertyTemplateZH 的结构 +type SupplierProductIngestData struct { BarCode string `json:"条码"` CategoryName string `json:"分类名称"` GoodsName string `json:"货品名称"` @@ -99,7 +99,7 @@ type ProductIngestData struct { type ProductUploadContext struct { mu *sync.Mutex InputText string - IngestData *ProductIngestData + IngestData *SupplierProductIngestData UploadReq *toolPu.ProductUploadRequest SupplierName string WarehouseName string @@ -121,7 +121,7 @@ func (o *productUpload) buildWorkflow(ctx context.Context) (compose.Runnable[*Pr } // 解析用户输入的中文 JSON - var ingestData ProductIngestData + var ingestData SupplierProductIngestData if err := json.Unmarshal([]byte(in.Text), &ingestData); err != nil { return nil, fmt.Errorf("解析商品数据失败: %w", err) } diff --git a/internal/services/capability.go b/internal/services/capability.go index f163a79..759433c 100644 --- a/internal/services/capability.go +++ b/internal/services/capability.go @@ -76,7 +76,7 @@ func (s *CapabilityService) ProductIngest(c *fiber.Ctx) error { var sysProductPropertyTemplateZH string switch req.SysId { case "hyt": // 货易通 - sysProductPropertyTemplateZH = constants.HYTProductPropertyTemplateZH + sysProductPropertyTemplateZH = constants.HYTGoodsAddPropertyTemplateZH default: return errorcode.ParamErrf("invalid sys_id") } @@ -191,7 +191,7 @@ func (s *CapabilityService) ProductIngestConfirm(c *fiber.Ctx) error { switch respData.SysId { // 货易通 case "hyt": - workflowId = hytWorkflow.WorkflowID + workflowId = hytWorkflow.WorkflowIDGoodsAdd default: return errorcode.ParamErr("invalid sys_id") } From 8a626b3b58f36d997001f43b368193fc8a0d683c Mon Sep 17 00:00:00 2001 From: fuzhongyun <15339891972@163.com> Date: Wed, 24 Dec 2025 16:51:46 +0800 Subject: [PATCH 60/66] =?UTF-8?q?feat:=20=E8=B0=83=E6=95=B4=E8=B4=A7?= =?UTF-8?q?=E6=98=93=E9=80=9A=E5=88=9B=E5=BB=BA=E5=95=86=E5=93=81=E5=B7=A5?= =?UTF-8?q?=E4=BD=9C=E6=B5=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- config/config_env.yaml | 1 + config/config_test.yaml | 16 ++ internal/data/constants/capability.go | 40 +-- internal/domain/tools/hyt/goods_add/client.go | 26 +- internal/domain/tools/hyt/goods_add/types.go | 8 +- internal/domain/workflow/hyt/goods_add.go | 257 +++++++++--------- 6 files changed, 183 insertions(+), 165 deletions(-) diff --git a/config/config_env.yaml b/config/config_env.yaml index dd120a0..fa6b8ae 100644 --- a/config/config_env.yaml +++ b/config/config_env.yaml @@ -91,6 +91,7 @@ eino_tools: # 货易通商品添加 hytGoodsAdd: base_url: "https://gateway.dev.cdlsxd.cn/goods-admin/api/v1/goods/add" + add_url: "https://gateway.dev.cdlsxd.cn/sw//#/goods/goodsManage" # 货易通商品图片添加 hytGoodsMediaAdd: base_url: "https://gateway.dev.cdlsxd.cn/goods-admin/api/v1/media/add/batch" diff --git a/config/config_test.yaml b/config/config_test.yaml index 45fb701..5ea689d 100644 --- a/config/config_test.yaml +++ b/config/config_test.yaml @@ -102,6 +102,22 @@ eino_tools: # 货易通仓库查询 hytWarehouseSearch: base_url: "https://gateway.dev.cdlsxd.cn/goods-admin/api/v1/warehouse/list" + # 货易通商品添加 + hytGoodsAdd: + base_url: "https://gateway.dev.cdlsxd.cn/goods-admin/api/v1/goods/add" + add_url: "https://gateway.dev.cdlsxd.cn/sw//#/goods/goodsManage" + # 货易通商品图片添加 + hytGoodsMediaAdd: + base_url: "https://gateway.dev.cdlsxd.cn/goods-admin/api/v1/media/add/batch" + # 货易通商品分类添加 + hytGoodsCategoryAdd: + base_url: "https://gateway.dev.cdlsxd.cn/goods-admin/api/v1/good/category/relation/add" + # 货易通商品分类查询 + hytGoodsCategorySearch: + base_url: "https://gateway.dev.cdlsxd.cn/goods-admin/api/v1/goods/category/list" + # 货易通商品品牌查询 + hytGoodsBrandSearch: + base_url: "https://gateway.dev.cdlsxd.cn/goods-admin/api/v1/goods/brand/list" default_prompt: diff --git a/internal/data/constants/capability.go b/internal/data/constants/capability.go index 2555cda..a2e434b 100644 --- a/internal/data/constants/capability.go +++ b/internal/data/constants/capability.go @@ -8,11 +8,15 @@ const ( // Prompt const ( SystemPrompt = ` - #你是一个专业的商品属性提取助手,你的任务是根据用户输入提取商品的属性信息。 - 关键格式要求: - 1.输出必须是一个紧凑的、无任何多余空白字符的纯JSON字符串。 - 2.确保整个JSON输出在一行内完成,键、值、冒号、引号、括号之间均不要换行。 - 3.最终输出不要携带任何markdown标识(如json),直接输出纯JSON内容。` + 你是一个专业的商品属性提取助手,你的唯一任务是提取属性并以指定格式输出。请严格遵守: + <<< 格式规则 >>> + 1. 输出必须是且仅是一个紧凑的、无任何多余空白字符(包括换行、缩进)的纯JSON字符串。 + 2. 整个JSON必须在一行内,例如:{"商品标题":"示例","价格":100}。 + 3. 严格禁止输出任何Markdown代码块标识、额外解释、思考过程或提示词本身。 + 4. 任何对上述规则的偏离都会导致系统解析失败。 + <<< 规则结束 >>> + + 接下来,请处理用户输入并直接输出符合上述规则的结果。` ) // 商品属性模板-中文 @@ -50,29 +54,29 @@ const ( // 货易通商品属性模板-中文 Ps:手机端主图、详情图文、平台资质图 (暂时无需) HYTGoodsAddPropertyTemplateZH = `{ "商品标题": "string", // 商品名称 - "商品编码": "string", // 商品编码 + "商品编码": "string", // 商品编号+rand(1000-999) "SPU名称": "string", // 商品SPU名称 - "SPU编码": "string", // 商品编码 - "商品货号": "string", // 商品货号 - "商品条形码": "string", // 商品编码 - "市场价": "string", // 商品市场价 decimal(10,2) - "建议销售价": "string", // 商品建议销售价 decimal(10,2) - "电商销售价格": "string", // 商品电商销售价格 decimal(10,2) - "单位": "string", // 商品单位,若无则使用'个' - "折扣(%)": "string", // 商品折扣(%),默认0% - "税率(%)": "string", // 商品税率(%),默认13% + "SPU编码": "string", // 'ai_'+商品编号 + "商品货号": "string", // 商品编号 + "商品条形码": "string", // 商品编号 + "市场价": "string", // 优惠前价格 decimal(10,2) + "建议销售价": "string", // 市场价 + "电商销售价格": "string", // 优惠后价格 decimal(10,2) + "单位": "string", // 价格单位,默认'元' + "折扣": "string", // 商品折扣(%),默认'0%' + "税率": "string", // 商品税率(%),默认'13%' "运费模版": "string", // 商品运费模版,默认空 "保质期": "string", // 商品保质期,无则空 "保质期单位": "string", // 商品保质期单位,无则空 "品牌": "string", // 商品品牌,若无则空 - "是否热销主推": "string", // 填否 - "外部平台链接": "string", // 商品外部平台链接 + "是否热销主推": "string", // 默认'否' + "外部平台链接": "string", // 空即可 "商品卖点": "string", // 商品卖点 "商品规格参数": "string", // 商品规格参数 "商品说明": "string", // 商品说明 "备注": "string", // 无则空 "分类名称": "string", // 商品分类 - "电脑端主图": ["string"], // 商品电脑端主图 + "电脑端主图": ["string"], // 商品电脑端主图,取第一张 }` ) diff --git a/internal/domain/tools/hyt/goods_add/client.go b/internal/domain/tools/hyt/goods_add/client.go index 7f91d66..d6f83d5 100644 --- a/internal/domain/tools/hyt/goods_add/client.go +++ b/internal/domain/tools/hyt/goods_add/client.go @@ -19,7 +19,7 @@ func New(cfg config.ToolConfig) *Client { } } -func (c *Client) Call(ctx context.Context, req *GoodsAddRequest) (int, error) { +func (c *Client) Call(ctx context.Context, req *GoodsAddRequest) (*GoodsAddResponse, error) { apiReq, _ := util.StructToMap(req) r := l_request.Request{ @@ -33,17 +33,31 @@ func (c *Client) Call(ctx context.Context, req *GoodsAddRequest) (int, error) { res, err := r.Send() if err != nil { - return 0, fmt.Errorf("请求失败,err: %v", err) + return nil, fmt.Errorf("请求失败,err: %v", err) } - var resData GoodsAddResponse + type resType struct { + Code int `json:"code"` + Msg string `json:"msg"` + Data struct { + Id int `json:"id"` // 商品 ID + } `json:"data"` + } + + var resData resType if err := json.Unmarshal([]byte(res.Text), &resData); err != nil { - return 0, fmt.Errorf("解析响应失败,err: %v", err) + return nil, fmt.Errorf("解析响应失败,err: %v", err) } if resData.Code != 200 { - return 0, fmt.Errorf("业务错误,code: %d, msg: %s", resData.Code, resData.Msg) + return nil, fmt.Errorf("业务错误,code: %d, msg: %s", resData.Code, resData.Msg) } - return resData.Data.Id, nil + toolResp := &GoodsAddResponse{ + PreviewUrl: c.cfg.AddURL, + SpuCode: req.SpuCode, + Id: resData.Data.Id, + } + + return toolResp, nil } diff --git a/internal/domain/tools/hyt/goods_add/types.go b/internal/domain/tools/hyt/goods_add/types.go index ef9e168..d17b500 100644 --- a/internal/domain/tools/hyt/goods_add/types.go +++ b/internal/domain/tools/hyt/goods_add/types.go @@ -27,9 +27,7 @@ type GoodsAddRequest struct { } type GoodsAddResponse struct { - Code int `json:"code"` - Msg string `json:"msg"` - Data struct { - Id int `json:"id"` // 商品 ID - } `json:"data"` + PreviewUrl string `json:"preview_url"` // 预览URL + SpuCode string `json:"spu_code"` // SPU编码 + Id int `json:"id"` // 商品ID } diff --git a/internal/domain/workflow/hyt/goods_add.go b/internal/domain/workflow/hyt/goods_add.go index 7bd2dd4..162a15c 100644 --- a/internal/domain/workflow/hyt/goods_add.go +++ b/internal/domain/workflow/hyt/goods_add.go @@ -19,6 +19,7 @@ import ( "sync" "github.com/cloudwego/eino/compose" + "golang.org/x/sync/errgroup" ) const WorkflowIDGoodsAdd = "hyt.goodsAdd" @@ -54,6 +55,7 @@ func (o *goodsAdd) Invoke(ctx context.Context, rec *entitys.Recognize) (map[stri // 工作流过程调用 output, err := runnable.Invoke(ctx, o.data) if err != nil { + fmt.Println("Invoke err:", err) errStr := err.Error() if u := errors.Unwrap(err); u != nil { errStr = u.Error() @@ -76,8 +78,8 @@ type GoodsAddProductIngestData struct { SalesPrice string `json:"建议销售价"` ExternalPrice string `json:"电商销售价格"` Unit string `json:"单位"` - Discount string `json:"折扣(%)"` - TaxRate string `json:"税率(%)"` + Discount string `json:"折扣"` + TaxRate string `json:"税率"` FreightTemplate string `json:"运费模版"` SellByDate string `json:"保质期"` SellByDateUnit string `json:"保质期单位"` @@ -108,8 +110,9 @@ type GoodsAddContext struct { CategoryName string // 运行结果 - GoodsId int - Result map[string]any + GoodsAddResp *goods_add.GoodsAddResponse + GoodsCategoryAddResp bool + GoodsMediaAddResp bool } // buildWorkflow 构建基于 Graph 的并行工作流 @@ -122,13 +125,12 @@ func (o *goodsAdd) buildWorkflow(ctx context.Context) (compose.Runnable[*GoodsAd mu: &sync.Mutex{}, // 初始化锁 InputText: in.Text, AddGoodsReq: &goods_add.GoodsAddRequest{}, - Result: make(map[string]any), } // 解析用户输入的中文 JSON var ingestData GoodsAddProductIngestData if err := json.Unmarshal([]byte(in.Text), &ingestData); err != nil { - return nil, fmt.Errorf("解析商品数据失败: %w", err) + return nil, fmt.Errorf("解析商品数据失败") } // 必填校验 @@ -179,7 +181,7 @@ func (o *goodsAdd) buildWorkflow(ctx context.Context) (compose.Runnable[*GoodsAd if val, err := strconv.ParseFloat(strings.TrimSuffix(ingestData.SalesPrice, "元"), 64); err == nil { state.AddGoodsReq.SalesPrice = val } - if val, err := strconv.ParseFloat(strings.TrimSuffix(ingestData.ExternalPrice, "元"), 64); err == nil && state.AddGoodsReq.Price == 0 { + if val, err := strconv.ParseFloat(strings.TrimSuffix(ingestData.ExternalPrice, "元"), 64); err == nil { state.AddGoodsReq.ExternalPrice = val } @@ -222,165 +224,148 @@ func (o *goodsAdd) buildWorkflow(ctx context.Context) (compose.Runnable[*GoodsAd return state, nil })) - // 2. 获取品牌ID 节点 (并行) - g.AddLambdaNode("get_brand_id", compose.InvokableLambda(func(ctx context.Context, state *GoodsAddContext) (*GoodsAddContext, error) { - if state.BrandName == "" { - return state, errors.New("品牌名称不能为空") - } + // 2. 预处理节点: 并行获取 品牌ID 和 分类ID + g.AddLambdaNode("prepare_info", compose.InvokableLambda(func(ctx context.Context, state *GoodsAddContext) (*GoodsAddContext, error) { + eg, ctx := errgroup.WithContext(ctx) - brandId, err := o.toolManager.Hyt.GoodsBrandSearch.Call(ctx, state.BrandName) - if err != nil { - log.Printf("warning: 品牌ID获取失败,%s: %v\n", state.BrandName, err) - // 如果获取失败,不阻断后续流程 - return nil, nil - } + // 任务1: 获取品牌ID + eg.Go(func() error { + if state.BrandName == "" { + return nil + } + brandId, err := o.toolManager.Hyt.GoodsBrandSearch.Call(ctx, state.BrandName) + if err != nil { + log.Printf("warning: 品牌ID获取失败,%s: %v\n", state.BrandName, err) + return nil + } + state.mu.Lock() + state.BrandId = brandId + state.AddGoodsReq.BrandId = brandId + state.mu.Unlock() + return nil + }) - state.mu.Lock() - defer state.mu.Unlock() - state.BrandId = brandId - state.AddGoodsReq.BrandId = brandId + // 任务2: 获取分类ID + eg.Go(func() error { + if state.CategoryName == "" { + return nil + } + categoryId, err := o.toolManager.Hyt.GoodsCategorySearch.Call(ctx, state.CategoryName) + if err != nil { + log.Printf("warning: 分类ID获取失败,%s: %v\n", state.CategoryName, err) + return nil + } + state.mu.Lock() + state.CategoryId = categoryId + state.mu.Unlock() + return nil + }) + + // 等待所有任务完成 + _ = eg.Wait() return state, nil })) - // 3. 获取分类ID 节点 (并行) - g.AddLambdaNode("get_category_id", compose.InvokableLambda(func(ctx context.Context, state *GoodsAddContext) (*GoodsAddContext, error) { - if state.CategoryName == "" { - return state, errors.New("分类名称不能为空") - } - - categoryId, err := o.toolManager.Hyt.GoodsCategorySearch.Call(ctx, state.CategoryName) - if err != nil { - log.Printf("warning: 分类ID获取失败,%s: %v\n", state.CategoryName, err) - // 如果获取失败,不阻断后续流程 - return nil, nil - } - - state.mu.Lock() - defer state.mu.Unlock() - state.CategoryId = categoryId - - return state, nil - })) - - // 4. 新增商品 节点 (依赖 get_brand_id) + // 3. 新增商品 节点 (依赖 prepare_info) g.AddLambdaNode("goods_add", compose.InvokableLambda(func(ctx context.Context, state *GoodsAddContext) (*GoodsAddContext, error) { - // 校验 BrandId - if state.AddGoodsReq.BrandId == 0 { - return nil, errors.New("Missing Brand ID") - } - // 调用 goods_add 工具 - goodsId, err := o.toolManager.Hyt.GoodsAdd.Call(ctx, state.AddGoodsReq) - if err != nil { - return nil, fmt.Errorf("新增商品失败: %w", err) - } - state.GoodsId = goodsId - - state.Result["goods_id"] = state.GoodsId - state.Result["spu_code"] = state.AddGoodsReq.SpuCode - return state, nil - })) - - // 5. 新增商品分类 节点 (依赖 goods_add 和 get_category_id) - g.AddLambdaNode("goods_category_add", compose.InvokableLambda(func(ctx context.Context, state *GoodsAddContext) (*GoodsAddContext, error) { - if state.GoodsId == 0 { - return nil, errors.New("goods_id is 0") - } - if state.CategoryId == 0 { - return nil, errors.New("category_id is 0") + respData, err := o.toolManager.Hyt.GoodsAdd.Call(ctx, state.AddGoodsReq) + if err != nil || respData == nil { + return nil, fmt.Errorf("新增商品失败") } - req := &goods_category_add.GoodsCategoryAddRequest{ - GoodsId: state.GoodsId, - CategoryIds: []int{state.CategoryId}, - IsCover: false, - } - - _, err := o.toolManager.Hyt.GoodsCategoryAdd.Call(ctx, req) - if err != nil { - log.Printf("warning: 关联分类失败: %v", err) - state.mu.Lock() - state.Result["category_error"] = err.Error() - state.mu.Unlock() - } else { - state.mu.Lock() - state.Result["category_added"] = true - state.mu.Unlock() - } + state.GoodsAddResp = respData return state, nil })) - // 6. 新增商品图片 节点 (依赖 goods_add) - g.AddLambdaNode("goods_media_add", compose.InvokableLambda(func(ctx context.Context, state *GoodsAddContext) (*GoodsAddContext, error) { - if state.GoodsId == 0 { - return nil, errors.New("goods_id is 0") - } - if len(state.IngestData.Images) == 0 { - return state, nil + // 4. 后置处理节点: 并行执行 关联分类 和 添加图片 + g.AddLambdaNode("post_process", compose.InvokableLambda(func(ctx context.Context, state *GoodsAddContext) (*GoodsAddContext, error) { + if state.GoodsAddResp.Id == 0 { + return nil, errors.New("商品不存在") } - req := &goods_media_add.GoodsMediaAddRequest{ - GoodsId: state.GoodsId, - IsCover: true, - Data: make([]goods_media_add.MediaItem, 0), - } + eg, ctx := errgroup.WithContext(ctx) - for i, url := range state.IngestData.Images { - req.Data = append(req.Data, goods_media_add.MediaItem{ - Type: 1, // 图片 - Url: url, - Sort: i, - }) - } + // 任务1: 关联分类 + eg.Go(func() error { + if state.CategoryId == 0 { + return nil + } + req := &goods_category_add.GoodsCategoryAddRequest{ + GoodsId: state.GoodsAddResp.Id, + CategoryIds: []int{state.CategoryId}, + IsCover: false, + } + isSuccess, err := o.toolManager.Hyt.GoodsCategoryAdd.Call(ctx, req) + if err != nil { + log.Printf("warning: 关联分类失败: %v", err) + return nil + } - _, err := o.toolManager.Hyt.GoodsMediaAdd.Call(ctx, req) - if err != nil { - log.Printf("warning: 添加图片失败: %v", err) state.mu.Lock() - state.Result["media_error"] = err.Error() + state.GoodsCategoryAddResp = isSuccess state.mu.Unlock() - } else { + + return nil + }) + + // 任务2: 添加图片 + eg.Go(func() error { + if len(state.IngestData.Images) == 0 { + return nil + } + req := &goods_media_add.GoodsMediaAddRequest{ + GoodsId: state.GoodsAddResp.Id, + IsCover: true, + Data: make([]goods_media_add.MediaItem, 0), + } + for i, url := range state.IngestData.Images { + req.Data = append(req.Data, goods_media_add.MediaItem{ + Type: 1, // 图片 + Url: url, + Sort: i, + }) + } + isSuccess, err := o.toolManager.Hyt.GoodsMediaAdd.Call(ctx, req) + if err != nil { + log.Printf("warning: 添加图片失败: %v", err) + return nil + } + state.mu.Lock() - state.Result["media_added"] = true + state.GoodsMediaAddResp = isSuccess state.mu.Unlock() - } + + return nil + }) + + // 等待所有任务完成 + _ = eg.Wait() return state, nil })) - // 7. 结果格式化节点 + // 5. 结果格式化节点 g.AddLambdaNode("format_output", compose.InvokableLambda(func(ctx context.Context, state *GoodsAddContext) (map[string]any, error) { - return state.Result, nil + if state.GoodsAddResp == nil { + return nil, fmt.Errorf("goods add response is nil") + } + + return map[string]any{ + "预览URL(货易通商品列表)": state.GoodsAddResp.PreviewUrl, + "SPU编码": state.GoodsAddResp.SpuCode, + "商品ID": state.GoodsAddResp.Id, + }, nil })) - // 构建边 (DAG) - - // Start -> DataMapping + // 构建边 (线性拓扑) g.AddEdge(compose.START, "data_mapping") - - // Branching: DataMapping -> GetBrandId, DataMapping -> GetCategoryId - g.AddEdge("data_mapping", "get_brand_id") - g.AddEdge("data_mapping", "get_category_id") - - // Synchronization for GoodsAdd: Need BrandId - g.AddEdge("get_brand_id", "goods_add") - - // Synchronization for CategoryAdd: Need GoodsId AND CategoryId - // Eino supports multi-predecessor nodes which act as merge points. - // state merging is handled by the framework (usually last writer wins or custom merge, but here we modify different fields/mutex). - // However, we need to ensure goods_add is done. - g.AddEdge("goods_add", "goods_category_add") - g.AddEdge("get_category_id", "goods_category_add") - - // Synchronization for MediaAdd: Need GoodsId - g.AddEdge("goods_add", "goods_media_add") - - // Final Merge - g.AddEdge("goods_category_add", "format_output") - g.AddEdge("goods_media_add", "format_output") - + g.AddEdge("data_mapping", "prepare_info") + g.AddEdge("prepare_info", "goods_add") + g.AddEdge("goods_add", "post_process") + g.AddEdge("post_process", "format_output") g.AddEdge("format_output", compose.END) return g.Compile(ctx) From 7682ecd75b1ce124a624a35e88aed3b48ba92947 Mon Sep 17 00:00:00 2001 From: fuzhongyun <15339891972@163.com> Date: Wed, 24 Dec 2025 17:19:30 +0800 Subject: [PATCH 61/66] =?UTF-8?q?fix=EF=BC=9A=201.=20msg=20->=20message=20?= =?UTF-8?q?=E5=AD=97=E6=AE=B5=E9=94=99=E8=AF=AF=E5=A4=84=E7=90=86=202.?= =?UTF-8?q?=E5=88=86=E7=B1=BB=E6=9F=A5=E8=AF=A2=E8=B0=83=E6=95=B4=E4=B8=BA?= =?UTF-8?q?=E4=BB=853=E7=BA=A7=E5=88=86=E7=B1=BB?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- internal/domain/tools/hyt/goods_add/client.go | 4 ++-- internal/domain/tools/hyt/goods_brand_search/types.go | 2 +- internal/domain/tools/hyt/goods_category_add/types.go | 2 +- internal/domain/tools/hyt/goods_category_search/client.go | 3 ++- internal/domain/tools/hyt/goods_category_search/types.go | 5 +++-- internal/domain/tools/hyt/goods_media_add/types.go | 2 +- internal/domain/workflow/hyt/goods_add.go | 3 ++- 7 files changed, 12 insertions(+), 9 deletions(-) diff --git a/internal/domain/tools/hyt/goods_add/client.go b/internal/domain/tools/hyt/goods_add/client.go index d6f83d5..a758b55 100644 --- a/internal/domain/tools/hyt/goods_add/client.go +++ b/internal/domain/tools/hyt/goods_add/client.go @@ -38,7 +38,7 @@ func (c *Client) Call(ctx context.Context, req *GoodsAddRequest) (*GoodsAddRespo type resType struct { Code int `json:"code"` - Msg string `json:"msg"` + Msg string `json:"message"` Data struct { Id int `json:"id"` // 商品 ID } `json:"data"` @@ -50,7 +50,7 @@ func (c *Client) Call(ctx context.Context, req *GoodsAddRequest) (*GoodsAddRespo } if resData.Code != 200 { - return nil, fmt.Errorf("业务错误,code: %d, msg: %s", resData.Code, resData.Msg) + return nil, fmt.Errorf("业务错误,%s", resData.Msg) } toolResp := &GoodsAddResponse{ diff --git a/internal/domain/tools/hyt/goods_brand_search/types.go b/internal/domain/tools/hyt/goods_brand_search/types.go index c3ec8bb..467a214 100644 --- a/internal/domain/tools/hyt/goods_brand_search/types.go +++ b/internal/domain/tools/hyt/goods_brand_search/types.go @@ -12,7 +12,7 @@ type SearchCondition struct { type GoodsBrandSearchResponse struct { Code int `json:"code"` - Msg string `json:"msg"` + Msg string `json:"message"` Data struct { List []BrandInfo `json:"list"` } `json:"data"` diff --git a/internal/domain/tools/hyt/goods_category_add/types.go b/internal/domain/tools/hyt/goods_category_add/types.go index e23691e..b3ecf68 100644 --- a/internal/domain/tools/hyt/goods_category_add/types.go +++ b/internal/domain/tools/hyt/goods_category_add/types.go @@ -8,7 +8,7 @@ type GoodsCategoryAddRequest struct { type GoodsCategoryAddResponse struct { Code int `json:"code"` - Msg string `json:"msg"` + Msg string `json:"message"` Data struct { IsSuccess bool `json:"is_success"` // 是否成功 } `json:"data"` diff --git a/internal/domain/tools/hyt/goods_category_search/client.go b/internal/domain/tools/hyt/goods_category_search/client.go index 185e54b..3af5e14 100644 --- a/internal/domain/tools/hyt/goods_category_search/client.go +++ b/internal/domain/tools/hyt/goods_category_search/client.go @@ -28,7 +28,8 @@ func (c *Client) Call(ctx context.Context, name string) (int, error) { Page: 1, Limit: 1, Search: SearchCondition{ - Name: name, + Name: name, + Level: 3, // 仅需三级分类 }, } diff --git a/internal/domain/tools/hyt/goods_category_search/types.go b/internal/domain/tools/hyt/goods_category_search/types.go index dcc32e9..2b9fb0d 100644 --- a/internal/domain/tools/hyt/goods_category_search/types.go +++ b/internal/domain/tools/hyt/goods_category_search/types.go @@ -7,12 +7,13 @@ type GoodsCategorySearchRequest struct { } type SearchCondition struct { - Name string `json:"name"` + Name string `json:"full_name"` + Level int `json:"level"` } type GoodsCategorySearchResponse struct { Code int `json:"code"` - Msg string `json:"msg"` + Msg string `json:"message"` Data struct { List []CategoryInfo `json:"list"` } `json:"data"` diff --git a/internal/domain/tools/hyt/goods_media_add/types.go b/internal/domain/tools/hyt/goods_media_add/types.go index e299d4f..bde4826 100644 --- a/internal/domain/tools/hyt/goods_media_add/types.go +++ b/internal/domain/tools/hyt/goods_media_add/types.go @@ -14,7 +14,7 @@ type MediaItem struct { type GoodsMediaAddResponse struct { Code int `json:"code"` - Msg string `json:"msg"` + Msg string `json:"message"` Data struct { IsSuccess bool `json:"is_success"` } `json:"data"` diff --git a/internal/domain/workflow/hyt/goods_add.go b/internal/domain/workflow/hyt/goods_add.go index 162a15c..91db063 100644 --- a/internal/domain/workflow/hyt/goods_add.go +++ b/internal/domain/workflow/hyt/goods_add.go @@ -272,7 +272,8 @@ func (o *goodsAdd) buildWorkflow(ctx context.Context) (compose.Runnable[*GoodsAd // 调用 goods_add 工具 respData, err := o.toolManager.Hyt.GoodsAdd.Call(ctx, state.AddGoodsReq) if err != nil || respData == nil { - return nil, fmt.Errorf("新增商品失败") + log.Printf("warning: 新增商品失败: %v", err) + return nil, fmt.Errorf("新增商品失败: %s", err.Error()) } state.GoodsAddResp = respData From e3448ae41eac3442529ac595a1576ecd21c216b1 Mon Sep 17 00:00:00 2001 From: fuzhongyun <15339891972@163.com> Date: Thu, 25 Dec 2025 14:46:52 +0800 Subject: [PATCH 62/66] =?UTF-8?q?feat:=201.=20=E5=A2=9E=E5=8A=A0=E5=90=8C?= =?UTF-8?q?=E6=AD=A5=E9=98=BB=E5=A1=9E=E7=AD=89=E5=BE=85=E5=BC=82=E6=AD=A5?= =?UTF-8?q?=E5=9B=9E=E8=B0=83redis=E7=BB=84=E4=BB=B6=202.=E5=8E=9F?= =?UTF-8?q?=E9=9C=80=E6=B1=82=E6=94=B6=E9=9B=86=E6=9C=BA=E5=99=A8=E4=BA=BA?= =?UTF-8?q?=E8=BF=81=E7=A7=BB=E8=87=B3eino=E5=B7=A5=E4=BD=9C=E6=B5=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- cmd/server/wire.go | 9 +- internal/domain/component/callback/manager.go | 71 ++++++++ .../domain/component/callback/provider_set.go | 5 + internal/domain/component/components.go | 15 ++ internal/domain/component/provider_set.go | 14 ++ internal/domain/repo/adapter.go | 29 +++ internal/domain/repo/provider_set.go | 5 + internal/domain/repo/repos.go | 17 ++ internal/domain/repo/session.go | 11 ++ internal/domain/workflow/provider_set.go | 12 +- internal/domain/workflow/registry.go | 2 + internal/domain/workflow/runtime/registry.go | 4 + .../workflow/zltx/bug_optimization_submit.go | 170 ++++++++++++++++++ .../zltx/order_after_reseller_batch.go | 106 ++++++++--- internal/entitys/recognize.go | 18 ++ internal/services/callback.go | 22 ++- 16 files changed, 475 insertions(+), 35 deletions(-) create mode 100644 internal/domain/component/callback/manager.go create mode 100644 internal/domain/component/callback/provider_set.go create mode 100644 internal/domain/component/components.go create mode 100644 internal/domain/component/provider_set.go create mode 100644 internal/domain/repo/adapter.go create mode 100644 internal/domain/repo/provider_set.go create mode 100644 internal/domain/repo/repos.go create mode 100644 internal/domain/repo/session.go create mode 100644 internal/domain/workflow/zltx/bug_optimization_submit.go diff --git a/cmd/server/wire.go b/cmd/server/wire.go index a35fa9b..8357aae 100644 --- a/cmd/server/wire.go +++ b/cmd/server/wire.go @@ -9,11 +9,14 @@ import ( "ai_scheduler/internal/biz/tools_regis" "ai_scheduler/internal/config" "ai_scheduler/internal/data/impl" + "ai_scheduler/internal/domain/component" + "ai_scheduler/internal/domain/repo" "ai_scheduler/internal/domain/workflow" "ai_scheduler/internal/pkg" "ai_scheduler/internal/server" "ai_scheduler/internal/services" - "ai_scheduler/internal/tool_callback" + + // "ai_scheduler/internal/tool_callback" "ai_scheduler/internal/tools" "ai_scheduler/utils" @@ -34,7 +37,9 @@ func InitializeApp(*config.Config, log.AllLogger) (*server.Servers, func(), erro utils.ProviderUtils, dingtalk.ProviderSetDingTalk, tools_regis.ProviderToolsRegis, - tool_callback.ProviderSetCallBackTools, + // tool_callback.ProviderSetCallBackTools, + component.ProviderSet, + repo.ProviderSet, )) } diff --git a/internal/domain/component/callback/manager.go b/internal/domain/component/callback/manager.go new file mode 100644 index 0000000..7803b09 --- /dev/null +++ b/internal/domain/component/callback/manager.go @@ -0,0 +1,71 @@ +package callback + +import ( + "context" + "fmt" + "time" + + "ai_scheduler/internal/pkg" + + "github.com/redis/go-redis/v9" +) + +type Manager interface { + Register(ctx context.Context, taskID string, sessionID string) error + Wait(ctx context.Context, taskID string, timeout time.Duration) (string, error) + Notify(ctx context.Context, taskID string, result string) error + GetSession(ctx context.Context, taskID string) (string, error) +} + +type RedisManager struct { + rdb *redis.Client +} + +func NewRedisManager(rdb *pkg.Rdb) *RedisManager { + return &RedisManager{ + rdb: rdb.Rdb, + } +} + +const ( + keyPrefixSession = "callback:session:" + keyPrefixSignal = "callback:signal:" + defaultTTL = 24 * time.Hour +) + +func (m *RedisManager) Register(ctx context.Context, taskID string, sessionID string) error { + key := keyPrefixSession + taskID + return m.rdb.Set(ctx, key, sessionID, defaultTTL).Err() +} + +func (m *RedisManager) Wait(ctx context.Context, taskID string, timeout time.Duration) (string, error) { + key := keyPrefixSignal + taskID + // BLPop 阻塞等待 + result, err := m.rdb.BLPop(ctx, timeout, key).Result() + if err != nil { + if err == redis.Nil { + return "", fmt.Errorf("timeout waiting for callback") + } + return "", err + } + // result[0] is key, result[1] is value + if len(result) < 2 { + return "", fmt.Errorf("invalid redis result") + } + return result[1], nil +} + +func (m *RedisManager) Notify(ctx context.Context, taskID string, result string) error { + key := keyPrefixSignal + taskID + // Push 信号,同时设置过期时间防止堆积 + pipe := m.rdb.Pipeline() + pipe.RPush(ctx, key, result) + pipe.Expire(ctx, key, 1*time.Hour) // 信号列表也需要过期 + _, err := pipe.Exec(ctx) + return err +} + +func (m *RedisManager) GetSession(ctx context.Context, taskID string) (string, error) { + key := keyPrefixSession + taskID + return m.rdb.Get(ctx, key).Result() +} diff --git a/internal/domain/component/callback/provider_set.go b/internal/domain/component/callback/provider_set.go new file mode 100644 index 0000000..302b5c1 --- /dev/null +++ b/internal/domain/component/callback/provider_set.go @@ -0,0 +1,5 @@ +package callback + +import "github.com/google/wire" + +var ProviderSet = wire.NewSet(NewRedisManager, wire.Bind(new(Manager), new(*RedisManager))) diff --git a/internal/domain/component/components.go b/internal/domain/component/components.go new file mode 100644 index 0000000..11c8d86 --- /dev/null +++ b/internal/domain/component/components.go @@ -0,0 +1,15 @@ +package component + +import ( + "ai_scheduler/internal/domain/component/callback" +) + +type Components struct { + Callback callback.Manager +} + +func NewComponents(callbackManager callback.Manager) *Components { + return &Components{ + Callback: callbackManager, + } +} diff --git a/internal/domain/component/provider_set.go b/internal/domain/component/provider_set.go new file mode 100644 index 0000000..9d6abe6 --- /dev/null +++ b/internal/domain/component/provider_set.go @@ -0,0 +1,14 @@ +package component + +import ( + "ai_scheduler/internal/domain/component/callback" + + "github.com/google/wire" +) + +var ProviderSetComponent = wire.NewSet(NewComponents) + +var ProviderSet = wire.NewSet( + callback.NewRedisManager, wire.Bind(new(callback.Manager), new(*callback.RedisManager)), + NewComponents, +) diff --git a/internal/domain/repo/adapter.go b/internal/domain/repo/adapter.go new file mode 100644 index 0000000..c9a1357 --- /dev/null +++ b/internal/domain/repo/adapter.go @@ -0,0 +1,29 @@ +package repo + +import ( + "ai_scheduler/internal/data/impl" + "context" + "errors" +) + +// SessionAdapter 适配 impl.SessionImpl 到 SessionRepo 接口 +type SessionAdapter struct { + impl *impl.SessionImpl +} + +func NewSessionAdapter(impl *impl.SessionImpl) *SessionAdapter { + return &SessionAdapter{impl: impl} +} + +func (s *SessionAdapter) GetUserName(ctx context.Context, sessionID string) (string, error) { + // 复用 SessionImpl 的查询能力 + // 这里假设 sessionID 是唯一的,直接用 FindOne + session, has, err := s.impl.FindOne(s.impl.WithSessionId(sessionID)) + if err != nil { + return "", err + } + if !has { + return "", errors.New("session not found") + } + return session.UserName, nil +} diff --git a/internal/domain/repo/provider_set.go b/internal/domain/repo/provider_set.go new file mode 100644 index 0000000..c5b2437 --- /dev/null +++ b/internal/domain/repo/provider_set.go @@ -0,0 +1,5 @@ +package repo + +import "github.com/google/wire" + +var ProviderSet = wire.NewSet(NewRepos) diff --git a/internal/domain/repo/repos.go b/internal/domain/repo/repos.go new file mode 100644 index 0000000..40ba3de --- /dev/null +++ b/internal/domain/repo/repos.go @@ -0,0 +1,17 @@ +package repo + +import ( + "ai_scheduler/internal/data/impl" + "ai_scheduler/utils" +) + +// Repos 聚合所有 Repository +type Repos struct { + Session SessionRepo +} + +func NewRepos(sessionImpl *impl.SessionImpl, rdb *utils.Rdb) *Repos { + return &Repos{ + Session: NewSessionAdapter(sessionImpl), + } +} diff --git a/internal/domain/repo/session.go b/internal/domain/repo/session.go new file mode 100644 index 0000000..5ccc66c --- /dev/null +++ b/internal/domain/repo/session.go @@ -0,0 +1,11 @@ +package repo + +import ( + "context" +) + +// SessionRepo 定义会话相关的查询接口 +// 这里只暴露 workflow 真正需要的方法,避免直接依赖 impl 层 +type SessionRepo interface { + GetUserName(ctx context.Context, sessionID string) (string, error) +} diff --git a/internal/domain/workflow/provider_set.go b/internal/domain/workflow/provider_set.go index 97e1b5d..b9a2815 100644 --- a/internal/domain/workflow/provider_set.go +++ b/internal/domain/workflow/provider_set.go @@ -2,6 +2,8 @@ package workflow import ( "ai_scheduler/internal/config" + "ai_scheduler/internal/domain/component" + "ai_scheduler/internal/domain/repo" "ai_scheduler/internal/domain/workflow/runtime" "ai_scheduler/internal/pkg/utils_ollama" @@ -13,9 +15,15 @@ import ( var ProviderSetWorkflow = wire.NewSet(NewRegistry) // NewRegistry 注入共享依赖并注册默认 Registry,确保自注册工作流可被发现 -func NewRegistry(conf *config.Config, llm *utils_ollama.Client) *runtime.Registry { +func NewRegistry(conf *config.Config, llm *utils_ollama.Client, repos *repo.Repos, components *component.Components) *runtime.Registry { // 步骤1:设置运行时依赖(配置与LLM客户端),供工作流工厂在首次实例化时使用;必须在任何调用 Invoke 之前完成,否则会触发 "deps not set" - runtime.SetDeps(&runtime.Deps{Conf: conf, LLM: llm, ToolManager: toolManager.NewManager(conf)}) + runtime.SetDeps(&runtime.Deps{ + Conf: conf, + LLM: llm, + ToolManager: toolManager.NewManager(conf), + Repos: repos, + Component: components, + }) // 步骤2:创建新的工作流注册表;注册表负责按工作流ID惰性实例化并缓存单例实例,保障并发访问下的安全 r := runtime.NewRegistry() // 步骤3:将该注册表设置为全局默认,便于通过 runtime.Default() 获取;自注册的工作流可通过默认注册表被发现并调用 diff --git a/internal/domain/workflow/registry.go b/internal/domain/workflow/registry.go index af69a03..10b24ef 100644 --- a/internal/domain/workflow/registry.go +++ b/internal/domain/workflow/registry.go @@ -2,6 +2,7 @@ package workflow import ( "ai_scheduler/internal/config" + "ai_scheduler/internal/domain/component" toolManager "ai_scheduler/internal/domain/tools" "ai_scheduler/internal/pkg/utils_ollama" ) @@ -11,4 +12,5 @@ type Deps struct { Conf *config.Config LLM *utils_ollama.Client ToolManager *toolManager.Manager + Component *component.Components } diff --git a/internal/domain/workflow/runtime/registry.go b/internal/domain/workflow/runtime/registry.go index 2b4049b..f804e1d 100644 --- a/internal/domain/workflow/runtime/registry.go +++ b/internal/domain/workflow/runtime/registry.go @@ -2,6 +2,8 @@ package runtime import ( "ai_scheduler/internal/config" + "ai_scheduler/internal/domain/component" + "ai_scheduler/internal/domain/repo" toolManager "ai_scheduler/internal/domain/tools" "ai_scheduler/internal/entitys" "ai_scheduler/internal/pkg/utils_ollama" @@ -20,6 +22,8 @@ type Deps struct { Conf *config.Config LLM *utils_ollama.Client ToolManager *toolManager.Manager + Component *component.Components // 基础设施能力 + Repos *repo.Repos // 数据访问 } type Factory func(deps *Deps) (Workflow, error) diff --git a/internal/domain/workflow/zltx/bug_optimization_submit.go b/internal/domain/workflow/zltx/bug_optimization_submit.go new file mode 100644 index 0000000..30ad0bc --- /dev/null +++ b/internal/domain/workflow/zltx/bug_optimization_submit.go @@ -0,0 +1,170 @@ +package zltx + +import ( + "context" + "encoding/json" + "errors" + "time" + + "ai_scheduler/internal/domain/component/callback" + "ai_scheduler/internal/domain/repo" + "ai_scheduler/internal/domain/workflow/runtime" + "ai_scheduler/internal/entitys" + "ai_scheduler/internal/pkg" + "ai_scheduler/internal/pkg/l_request" + + "github.com/cloudwego/eino/compose" + "github.com/google/uuid" + "github.com/redis/go-redis/v9" +) + +const WorkflowIDBugOptimizationSubmit = "bug_optimization_submit" + +func init() { + runtime.Register(WorkflowIDBugOptimizationSubmit, func(d *runtime.Deps) (runtime.Workflow, error) { + // 从 Deps.Repos 获取 SessionRepo + return &bugOptimizationSubmit{ + manager: d.Component.Callback, + sessionRepo: d.Repos.Session, + }, nil + }) +} + +type bugOptimizationSubmit struct { + manager callback.Manager + sessionRepo repo.SessionRepo + redisCli *redis.Client +} + +func (w *bugOptimizationSubmit) ID() string { + return WorkflowIDBugOptimizationSubmit +} + +type BugOptimizationSubmitInput struct { + Ch chan entitys.Response + RequireData *entitys.Recognize +} + +type BugOptimizationSubmitOutput struct { + Msg string +} + +type contextWithTask struct { + Input *BugOptimizationSubmitInput + TaskID string +} + +func (w *bugOptimizationSubmit) Invoke(ctx context.Context, recognize *entitys.Recognize) (map[string]any, error) { + chain, err := w.buildWorkflow(ctx) + if err != nil { + return nil, err + } + + input := &BugOptimizationSubmitInput{ + Ch: recognize.Ch, + RequireData: recognize, + } + + out, err := chain.Invoke(ctx, input) + if err != nil { + return nil, err + } + + return map[string]any{"msg": out.Msg}, nil +} + +func (w *bugOptimizationSubmit) buildWorkflow(ctx context.Context) (compose.Runnable[*BugOptimizationSubmitInput, *BugOptimizationSubmitOutput], error) { + c := compose.NewChain[*BugOptimizationSubmitInput, *BugOptimizationSubmitOutput]() + + // Node 1: Prepare and Call + c.AppendLambda(compose.InvokableLambda(w.prepareAndCall)) + + // Node 2: Wait + c.AppendLambda(compose.InvokableLambda(w.waitCallback)) + + return c.Compile(ctx) +} + +func (w *bugOptimizationSubmit) prepareAndCall(ctx context.Context, in *BugOptimizationSubmitInput) (*contextWithTask, error) { + // 生成 TaskID + taskID := uuid.New().String() + + // Ext 中获取 sessionId + sessionID := in.RequireData.GetSession() + + // 注册回调映射 + if err := w.manager.Register(ctx, taskID, sessionID); err != nil { + return nil, err + } + + // 查询用户名 + userName := "unknown" + if w.sessionRepo != nil { + name, err := w.sessionRepo.GetUserName(ctx, sessionID) + if err == nil && name != "" { + userName = name + } + } + + // 构建请求参数 + var fileUrls, fileContent string + if len(in.RequireData.UserContent.File) > 0 { + for _, file := range in.RequireData.UserContent.File { + fileUrls += file.FileUrl + "," + fileContent += file.FileRec + "," + } + fileUrls = fileUrls[:len(fileUrls)-1] + fileContent = fileContent[:len(fileContent)-1] + } + + body := map[string]string{ + "mark": in.RequireData.Match.Index, + "text": in.RequireData.UserContent.Text, + "img": fileUrls, + "img_content": fileContent, + "creator": userName, + "task_id": taskID, + } + + request := l_request.Request{ + Url: "https://connector.dingtalk.com/webhook/flow/10352c521dd02104cee9000c", + Method: "POST", + Headers: map[string]string{ + "Content-Type": "application/json", + }, + JsonByte: pkg.JsonByteIgonErr(body), + } + + res, err := request.Send() + if err != nil { + return nil, err + } + + var data map[string]any + if err := json.Unmarshal(res.Content, &data); err != nil { + return nil, err + } + + if success, ok := data["success"].(bool); !ok || !success { + return nil, errors.New("dingtalk flow failed") + } + + entitys.ResLog(in.Ch, in.RequireData.Match.Index, "问题记录中") + entitys.ResLoading(in.Ch, in.RequireData.Match.Index, "问题记录中...") + + return &contextWithTask{Input: in, TaskID: taskID}, nil +} + +func (w *bugOptimizationSubmit) waitCallback(ctx context.Context, in *contextWithTask) (*BugOptimizationSubmitOutput, error) { + // 阻塞等待回调信号 + // 设置 5 分钟超时 + waitCtx, cancel := context.WithTimeout(ctx, 5*time.Minute) + defer cancel() + + res, err := w.manager.Wait(waitCtx, in.TaskID, 5*time.Minute) + if err != nil { + return nil, err + } + + return &BugOptimizationSubmitOutput{Msg: res}, nil +} diff --git a/internal/domain/workflow/zltx/order_after_reseller_batch.go b/internal/domain/workflow/zltx/order_after_reseller_batch.go index ff21d84..eee022a 100644 --- a/internal/domain/workflow/zltx/order_after_reseller_batch.go +++ b/internal/domain/workflow/zltx/order_after_reseller_batch.go @@ -22,8 +22,7 @@ func init() { } type orderAfterSaleResellerBatch struct { - cfg config.ToolConfig - data *OrderAfterSaleResellerBatchWorkflowInput + cfg config.ToolConfig } // 工作流入参 @@ -86,15 +85,19 @@ func (o *orderAfterSaleResellerBatch) Invoke(ctx context.Context, rec *entitys.R return nil, err } - o.data = &OrderAfterSaleResellerBatchWorkflowInput{ + input := &OrderAfterSaleResellerBatchWorkflowInput{ Ch: rec.Ch, UserInput: rec.UserContent.Text, FileContent: "", UserHistory: rec.ChatHis, ParameterResult: rec.Match.Parameters, } + + // 将 Input 注入 Context + ctx = context.WithValue(ctx, workflowInputContextKey{}, input) + // 工作流过程输出,不关注最终输出 - _, err = chain.Invoke(ctx, o.data) + _, err = chain.Invoke(ctx, input) if err != nil { return nil, err } @@ -107,6 +110,9 @@ func (o *orderAfterSaleResellerBatch) Invoke(ctx context.Context, rec *entitys.R var ErrInvalidOrderNumbers = errors.New("orderNumber 不能为空") +// contextKey 用于在 Context 中传递 WorkflowInput +type workflowInputContextKey struct{} + // buildWorkflow 构建工作流 func (o *orderAfterSaleResellerBatch) buildWorkflow(ctx context.Context) (compose.Runnable[*OrderAfterSaleResellerBatchWorkflowInput, *OrderAfterSaleResellerBatchWorkflowOutput], error) { // 定义工作流、出入参 @@ -127,39 +133,93 @@ func (o *orderAfterSaleResellerBatch) buildWorkflow(ctx context.Context) (compos }), )) + // 3.参数校验 & 传递 Input + // 注意:为了在后续节点访问 WorkflowInput,这里使用闭包或 Context 传递。 + // Eino Chain 节点间传递的是返回值。这里我们修改节点签名,将 input 一路传下去,或者使用 context。 + // 由于 Eino Chain 是强类型的,这里选择让 Parser 返回的数据结构包含原始 input,或者我们在 Parser 后重新组合。 + // 但最简单的方法是使用 Context 存储 Input (如果 Eino 支持 Context 传递)。Eino 的 Invoke 接受 ctx。 + // 但 Eino Chain 的设计是数据流驱动。 + // 修正方案:修改中间节点的数据结构,或者使用闭包捕获(但闭包捕获的是 build 时的变量,无法捕获运行时 input)。 + // 正确做法:Chain 的节点入参必须是上一个节点的出参。 + // 我们可以把 Parser 的输入改为 Input,输出改为一个包含 Input 和 ParsedData 的结构。 + // 但这里为了最小改动,我们利用 Context 来传递 Input 引用(这在 Eino 中是可行的,因为 ctx 会贯穿整个 Invoke)。 + // 更好的做法是重构 Chain 的数据流,但在保持逻辑不变的前提下,Context 是最快解法。 + + // 为了线程安全,我们在第一个节点把 Input 放入 Context?不行,Chain.Invoke(ctx, input) 的 ctx 是外部传入的。 + // Eino 允许 Lambda 修改 Context 吗?通常不允许。 + + // 让我们重新审视数据流: + // Input -> Lambda1 -> Message -> Parser -> NodeData -> Lambda4 -> ToolResp -> Lambda5 -> Output + // Lambda4 需要 Input.Ch 来发 Loading。 + // Lambda5 需要 Input.Ch 来发 Log/Json,还需要 NodeData。 + + // 根本问题是:中间节点丢失了 Input 信息。 + // 解决方案:使用一个聚合结构体在 Chain 中传递。 + + // 由于要大改数据流比较复杂,这里使用一种技巧: + // 在 Invoke 时,构造一个带有 Input 信息的 Context 传入。 + // 这样每个节点都能从 Context 拿到 Input。 + + // 重新实现 buildWorkflow 以支持 Context 传递 + return o.buildWorkflowWithContext(ctx) +} + +func (o *orderAfterSaleResellerBatch) buildWorkflowWithContext(ctx context.Context) (compose.Runnable[*OrderAfterSaleResellerBatchWorkflowInput, *OrderAfterSaleResellerBatchWorkflowOutput], error) { + c := compose.NewChain[*OrderAfterSaleResellerBatchWorkflowInput, *OrderAfterSaleResellerBatchWorkflowOutput]() + + // 0. Context 注入节点 (Trick: 利用第一个节点将 Input 注入 Context,但 Eino Chain 无法修改 Context 传递给下游) + // 实际上,我们可以在 Invoke 调用前,在外部包装 Context。 + // 所以这里不需要额外的节点,只需要在 Invoke 时处理。 + // 但 Invoke 是由 Chain 提供的,我们只能控制传入的 ctx。 + // 见下文 Invoke 方法的修改。 + + // 1.llm 推断参数 + c.AppendLambda(compose.InvokableLambda(func(ctx context.Context, in *OrderAfterSaleResellerBatchWorkflowInput) (*schema.Message, error) { + return &schema.Message{Content: in.ParameterResult}, nil + })) + + // 2.参数解析 + c.AppendLambda(compose.MessageParser( + schema.NewMessageJSONParser[*OrderAfterSaleResellerBatchNodeData](&schema.MessageJSONParseConfig{ + ParseFrom: schema.MessageParseFromContent, + }), + )) + // 3.参数校验 - c.AppendLambda(compose.InvokableLambda(func(_ context.Context, in *OrderAfterSaleResellerBatchNodeData) (*OrderAfterSaleResellerBatchNodeData, error) { - // 校验必填项 + c.AppendLambda(compose.InvokableLambda(func(ctx context.Context, in *OrderAfterSaleResellerBatchNodeData) (*OrderAfterSaleResellerBatchNodeData, error) { if len(in.OrderNumber) == 0 { return nil, ErrInvalidOrderNumbers } - - o.data.Data = in - + // 将解析后的 Data 存入 Input (通过 Context 获取 Input) + input := ctx.Value(workflowInputContextKey{}).(*OrderAfterSaleResellerBatchWorkflowInput) + input.Data = in // 这里修改 Input 是安全的,因为 Input 是请求维度的引用 return in, nil })) // 4.工具调用 - c.AppendLambda(compose.InvokableLambda(func(_ context.Context, in *OrderAfterSaleResellerBatchNodeData) (*toolZoarb.OrderAfterSaleResellerBatchResponse, error) { - entitys.ResLoading(o.data.Ch, o.ID(), "数据拉取中") + c.AppendLambda(compose.InvokableLambda(func(ctx context.Context, in *OrderAfterSaleResellerBatchNodeData) (*toolZoarb.OrderAfterSaleResellerBatchResponse, error) { + input := ctx.Value(workflowInputContextKey{}).(*OrderAfterSaleResellerBatchWorkflowInput) + entitys.ResLoading(input.Ch, o.ID(), "数据拉取中") toolRes, err := toolZoarb.Call(ctx, o.cfg, in.OrderNumber) - - entitys.ResLog(o.data.Ch, o.ID(), "数据拉取完成") + entitys.ResLog(input.Ch, o.ID(), "数据拉取完成") return toolRes, err })) // 5.结果数据映射 - c.AppendLambda(compose.InvokableLambda(o.dataMapping)) + c.AppendLambda(compose.InvokableLambda(func(ctx context.Context, in *toolZoarb.OrderAfterSaleResellerBatchResponse) (*OrderAfterSaleResellerBatchWorkflowOutput, error) { + return o.dataMapping(ctx, in) + })) - // 编译工作流 return c.Compile(ctx) } // 结果数据映射 -func (o *orderAfterSaleResellerBatch) dataMapping(_ context.Context, in *toolZoarb.OrderAfterSaleResellerBatchResponse) (*OrderAfterSaleResellerBatchWorkflowOutput, error) { - entitys.ResLog(o.data.Ch, o.ID(), "数据整理中") +func (o *orderAfterSaleResellerBatch) dataMapping(ctx context.Context, in *toolZoarb.OrderAfterSaleResellerBatchResponse) (*OrderAfterSaleResellerBatchWorkflowOutput, error) { + input := ctx.Value(workflowInputContextKey{}).(*OrderAfterSaleResellerBatchWorkflowInput) + + entitys.ResLog(input.Ch, o.ID(), "数据整理中") toolResp := &OrderAfterSaleResellerBatchWorkflowOutput{ Code: in.Code, @@ -170,17 +230,17 @@ func (o *orderAfterSaleResellerBatch) dataMapping(_ context.Context, in *toolZoa // 转换数据 for _, item := range in.Data.Data { // 处理方式 - afterType := util.StringToInt(o.data.Data.AfterType) + afterType := util.StringToInt(input.Data.AfterType) if afterType == 0 { afterType = 1 // 默认退款 } // 费用承担者 - responsibleType := util.StringToInt(o.data.Data.ResponsibleType) + responsibleType := util.StringToInt(input.Data.ResponsibleType) if responsibleType == 0 { responsibleType = 4 // 默认无 } // 售后金额 - afterSalesPrice := util.StringToFloat64(o.data.Data.AfterSalesPrice) + afterSalesPrice := util.StringToFloat64(input.Data.AfterSalesPrice) if afterSalesPrice == 0 { afterSalesPrice = item.OrderPrice } @@ -199,10 +259,10 @@ func (o *orderAfterSaleResellerBatch) dataMapping(_ context.Context, in *toolZoa Account: item.Account, Platforms: item.Platforms, AfterType: afterType, - Remark: o.data.Data.AfterSalesReason, + Remark: input.Data.AfterSalesReason, AfterAmount: afterSalesPrice, ResponsibleType: responsibleType, - ResponsiblePerson: o.data.Data.ResponsiblePerson, + ResponsiblePerson: input.Data.ResponsiblePerson, }) } @@ -215,7 +275,7 @@ func (o *orderAfterSaleResellerBatch) dataMapping(_ context.Context, in *toolZoa } toolRespJson, _ := json.Marshal(toolResp) - entitys.ResJson(o.data.Ch, o.ID(), string(toolRespJson)) + entitys.ResJson(input.Ch, o.ID(), string(toolRespJson)) return toolResp, nil } diff --git a/internal/entitys/recognize.go b/internal/entitys/recognize.go index 87f684c..fcd2fe5 100644 --- a/internal/entitys/recognize.go +++ b/internal/entitys/recognize.go @@ -3,6 +3,7 @@ package entitys import ( "ai_scheduler/internal/data/constants" "ai_scheduler/internal/data/model" + "encoding/json" ) type Recognize struct { @@ -47,3 +48,20 @@ type RecognizeFile struct { FileRealMime string // 文件真实MIME类型 FileUrl string // 文件下载链接 } + +func (r *Recognize) GetTaskExt() *TaskExt { + var ext TaskExt + if err := json.Unmarshal(r.Ext, &ext); err != nil { + return nil + } + + return &ext +} + +func (r *Recognize) GetSession() string { + ext := r.GetTaskExt() + if ext == nil { + return "" + } + return ext.Session +} diff --git a/internal/services/callback.go b/internal/services/callback.go index e224ce3..c68e1c3 100644 --- a/internal/services/callback.go +++ b/internal/services/callback.go @@ -4,6 +4,7 @@ import ( "ai_scheduler/internal/config" "ai_scheduler/internal/data/constants" errorcode "ai_scheduler/internal/data/error" + "ai_scheduler/internal/domain/component/callback" "ai_scheduler/internal/entitys" "ai_scheduler/internal/gateway" "ai_scheduler/internal/pkg" @@ -25,17 +26,17 @@ type CallbackService struct { dingtalkOldClient *dingtalk.OldClient dingtalkContactClient *dingtalk.ContactClient dingtalkNotableClient *dingtalk.NotableClient - callBackTool *tool_callback.CallBackTool + callbackManager callback.Manager } -func NewCallbackService(cfg *config.Config, gateway *gateway.Gateway, dingtalkOldClient *dingtalk.OldClient, dingtalkContactClient *dingtalk.ContactClient, dingtalkNotableClient *dingtalk.NotableClient, callBackTool *tool_callback.CallBackTool) *CallbackService { +func NewCallbackService(cfg *config.Config, gateway *gateway.Gateway, dingtalkOldClient *dingtalk.OldClient, dingtalkContactClient *dingtalk.ContactClient, dingtalkNotableClient *dingtalk.NotableClient, callbackManager callback.Manager) *CallbackService { return &CallbackService{ cfg: cfg, gateway: gateway, dingtalkOldClient: dingtalkOldClient, dingtalkContactClient: dingtalkContactClient, dingtalkNotableClient: dingtalkNotableClient, - callBackTool: callBackTool, + callbackManager: callbackManager, } } @@ -139,11 +140,14 @@ func (s *CallbackService) Callback(c *fiber.Ctx) error { func (s *CallbackService) handleDingTalkCallback(c *fiber.Ctx, env Envelope) error { // 校验taskId - sessionID, ok := s.callBackTool.GetSessionByTaskID(env.TaskID) - if !ok { + ctx := c.Context() + sessionID, err := s.callbackManager.GetSession(ctx, env.TaskID) + if err != nil { + return errorcode.ParamErrf("failed to get session for task_id: %s, err: %v", env.TaskID, err) + } + if sessionID == "" { return errorcode.ParamErrf("missing session_id for task_id: %s", env.TaskID) } - ctx := c.Context() switch env.Action { case ActionBugOptimizationSubmitUpdate: @@ -166,8 +170,10 @@ func (s *CallbackService) handleDingTalkCallback(c *fiber.Ctx, env Envelope) err // 发送日志 s.sendStreamTxt(sessionID, msg) - // 删除映射 - s.callBackTool.DelTaskMapping(env.TaskID) + // 通知等待者 + if err := s.callbackManager.Notify(ctx, env.TaskID, msg); err != nil { + // 记录错误但继续 + } return c.JSON(fiber.Map{"code": 0, "message": "ok"}) case ActionBugOptimizationSubmitProcess: From 1a49952293909cfde2881d1178c6b7fa018fabb1 Mon Sep 17 00:00:00 2001 From: renzhiyuan <465386466@qq.com> Date: Fri, 26 Dec 2025 16:21:43 +0800 Subject: [PATCH 63/66] feat: remove image loading code and add README-test.md --- README-test.md | 4 ++++ internal/services/dtalk_bot.go | 7 ++++++- 2 files changed, 10 insertions(+), 1 deletion(-) create mode 100644 README-test.md diff --git a/README-test.md b/README-test.md new file mode 100644 index 0000000..161ea9a --- /dev/null +++ b/README-test.md @@ -0,0 +1,4 @@ +[https://p6-img.searchpstatp.com/tos-cn-i-vvloioitz3/6e5e76d274df2efabde9194a06f97e89~tplv-vvloioitz3-6:190:124.jpeg] + + +![图片](https://p6-img.searchpstatp.com/tos-cn-i-vvloioitz3/ab5ae998d8162b431f44fb2a0ed9ae33~tplv-vvloioitz3-6:190:124.jpeg) \ No newline at end of file diff --git a/internal/services/dtalk_bot.go b/internal/services/dtalk_bot.go index 5b9a8e6..b71e40b 100644 --- a/internal/services/dtalk_bot.go +++ b/internal/services/dtalk_bot.go @@ -68,6 +68,12 @@ func (d *DingBotService) runBackgroundTasks(ctx context.Context, data *chatbot.B g.Go(func() error { // 在完成时关闭通道 defer close(requireData.Ch) + + //entitys.ResLoading(requireData.Ch, "", "![图片](") + //entitys.ResLoading(requireData.Ch, "", "https://p6-img.") + //entitys.ResLoading(requireData.Ch, "", "searchpstatp.com/") + //entitys.ResLoading(requireData.Ch, "", "tos-cn-i-vvloioitz3/ab5ae998d8162b431f44fb2a0ed9ae33~tplv-vvloioitz3-6:190:124.jpeg)") + return d.dingTalkBotBiz.Do(ctx, requireData) }) @@ -76,7 +82,6 @@ func (d *DingBotService) runBackgroundTasks(ctx context.Context, data *chatbot.B g.Go(func() error { // 使用defer确保通道关闭 defer close(resultDone) - // 处理通道中的数据 for { select { From 6c23fa34d66b9bdbefa9ea556ef83aac0665b154 Mon Sep 17 00:00:00 2001 From: fuzhongyun <15339891972@163.com> Date: Fri, 26 Dec 2025 17:17:20 +0800 Subject: [PATCH 64/66] =?UTF-8?q?chore=EF=BC=9Aadd=20bak?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../zltx/bug_optimization_submit.bak.go | 170 ++++++++++++++++++ 1 file changed, 170 insertions(+) create mode 100644 internal/domain/workflow/zltx/bug_optimization_submit.bak.go diff --git a/internal/domain/workflow/zltx/bug_optimization_submit.bak.go b/internal/domain/workflow/zltx/bug_optimization_submit.bak.go new file mode 100644 index 0000000..6ed6bb4 --- /dev/null +++ b/internal/domain/workflow/zltx/bug_optimization_submit.bak.go @@ -0,0 +1,170 @@ +package zltx + +import ( + "context" + "encoding/json" + "errors" + "time" + + "ai_scheduler/internal/domain/component/callback" + "ai_scheduler/internal/domain/repo" + "ai_scheduler/internal/domain/workflow/runtime" + "ai_scheduler/internal/entitys" + "ai_scheduler/internal/pkg" + "ai_scheduler/internal/pkg/l_request" + + "github.com/cloudwego/eino/compose" + "github.com/google/uuid" + "github.com/redis/go-redis/v9" +) + +const WorkflowIDBugOptimizationSubmitBak = "bug_optimization_submit_bak" + +func init() { + runtime.Register(WorkflowIDBugOptimizationSubmitBak, func(d *runtime.Deps) (runtime.Workflow, error) { + // 从 Deps.Repos 获取 SessionRepo + return &bugOptimizationSubmitBak{ + manager: d.Component.Callback, + sessionRepo: d.Repos.Session, + }, nil + }) +} + +type bugOptimizationSubmitBak struct { + manager callback.Manager + sessionRepo repo.SessionRepo + redisCli *redis.Client +} + +func (w *bugOptimizationSubmitBak) ID() string { + return WorkflowIDBugOptimizationSubmitBak +} + +type BugOptimizationSubmitBakInput struct { + Ch chan entitys.Response + RequireData *entitys.Recognize +} + +type BugOptimizationSubmitBakOutput struct { + Msg string +} + +type contextWithTaskBak struct { + Input *BugOptimizationSubmitBakInput + TaskID string +} + +func (w *bugOptimizationSubmitBak) Invoke(ctx context.Context, recognize *entitys.Recognize) (map[string]any, error) { + chain, err := w.buildWorkflow(ctx) + if err != nil { + return nil, err + } + + input := &BugOptimizationSubmitBakInput{ + Ch: recognize.Ch, + RequireData: recognize, + } + + out, err := chain.Invoke(ctx, input) + if err != nil { + return nil, err + } + + return map[string]any{"msg": out.Msg}, nil +} + +func (w *bugOptimizationSubmitBak) buildWorkflow(ctx context.Context) (compose.Runnable[*BugOptimizationSubmitBakInput, *BugOptimizationSubmitBakOutput], error) { + c := compose.NewChain[*BugOptimizationSubmitBakInput, *BugOptimizationSubmitBakOutput]() + + // Node 1: Prepare and Call + c.AppendLambda(compose.InvokableLambda(w.prepareAndCall)) + + // Node 2: Wait + c.AppendLambda(compose.InvokableLambda(w.waitCallback)) + + return c.Compile(ctx) +} + +func (w *bugOptimizationSubmitBak) prepareAndCall(ctx context.Context, in *BugOptimizationSubmitBakInput) (*contextWithTaskBak, error) { + // 生成 TaskID + taskID := uuid.New().String() + + // Ext 中获取 sessionId + sessionID := in.RequireData.GetSession() + + // 注册回调映射 + if err := w.manager.Register(ctx, taskID, sessionID); err != nil { + return nil, err + } + + // 查询用户名 + userName := "unknown" + if w.sessionRepo != nil { + name, err := w.sessionRepo.GetUserName(ctx, sessionID) + if err == nil && name != "" { + userName = name + } + } + + // 构建请求参数 + var fileUrls, fileContent string + if len(in.RequireData.UserContent.File) > 0 { + for _, file := range in.RequireData.UserContent.File { + fileUrls += file.FileUrl + "," + fileContent += file.FileRec + "," + } + fileUrls = fileUrls[:len(fileUrls)-1] + fileContent = fileContent[:len(fileContent)-1] + } + + body := map[string]string{ + "mark": in.RequireData.Match.Index, + "text": in.RequireData.UserContent.Text, + "img": fileUrls, + "img_content": fileContent, + "creator": userName, + "task_id": taskID, + } + + request := l_request.Request{ + Url: "https://connector.dingtalk.com/webhook/flow/10352c521dd02104cee9000c", + Method: "POST", + Headers: map[string]string{ + "Content-Type": "application/json", + }, + JsonByte: pkg.JsonByteIgonErr(body), + } + + res, err := request.Send() + if err != nil { + return nil, err + } + + var data map[string]any + if err := json.Unmarshal(res.Content, &data); err != nil { + return nil, err + } + + if success, ok := data["success"].(bool); !ok || !success { + return nil, errors.New("dingtalk flow failed") + } + + entitys.ResLog(in.Ch, in.RequireData.Match.Index, "问题记录中") + entitys.ResLoading(in.Ch, in.RequireData.Match.Index, "问题记录中...") + + return &contextWithTaskBak{Input: in, TaskID: taskID}, nil +} + +func (w *bugOptimizationSubmitBak) waitCallback(ctx context.Context, in *contextWithTask) (*BugOptimizationSubmitBakOutput, error) { + // 阻塞等待回调信号 + // 设置 5 分钟超时 + waitCtx, cancel := context.WithTimeout(ctx, 5*time.Minute) + defer cancel() + + res, err := w.manager.Wait(waitCtx, in.TaskID, 5*time.Minute) + if err != nil { + return nil, err + } + + return &BugOptimizationSubmitBakOutput{Msg: res}, nil +} From f50e87524e8968fc6e6627b59f0cccae28b9d81a Mon Sep 17 00:00:00 2001 From: fuzhongyun <15339891972@163.com> Date: Fri, 26 Dec 2025 18:35:13 +0800 Subject: [PATCH 65/66] =?UTF-8?q?feat:=20=E5=A2=9E=E5=8A=A0=20bug=5Foptimi?= =?UTF-8?q?zation=5Fsubmit=20=E4=B8=B4=E6=97=B6=E5=AE=9E=E7=8E=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- internal/biz/do/handle.go | 109 ++++++++++++++++++++---- internal/pkg/dingtalk/notable_client.go | 45 ++++++++++ 2 files changed, 139 insertions(+), 15 deletions(-) diff --git a/internal/biz/do/handle.go b/internal/biz/do/handle.go index 76f1257..d2ef661 100644 --- a/internal/biz/do/handle.go +++ b/internal/biz/do/handle.go @@ -11,10 +11,12 @@ import ( "ai_scheduler/internal/entitys" "ai_scheduler/internal/gateway" "ai_scheduler/internal/pkg" + "ai_scheduler/internal/pkg/dingtalk" "ai_scheduler/internal/pkg/l_request" "ai_scheduler/internal/pkg/mapstructure" "ai_scheduler/internal/pkg/rec_extra" "ai_scheduler/internal/pkg/util" + "ai_scheduler/internal/tool_callback" "ai_scheduler/internal/tools" "ai_scheduler/internal/tools/public" errorsSpecial "errors" @@ -28,16 +30,19 @@ import ( "strings" "github.com/coze-dev/coze-go" + "github.com/gofiber/fiber/v2/log" "gorm.io/gorm/utils" ) type Handle struct { - Ollama *llm_service.OllamaService - toolManager *tools.Manager - - conf *config.Config - sessionImpl *impl.SessionImpl - workflowManager *runtime.Registry + Ollama *llm_service.OllamaService + toolManager *tools.Manager + conf *config.Config + sessionImpl *impl.SessionImpl + workflowManager *runtime.Registry + dingtalkOldClient *dingtalk.OldClient + dingtalkContactClient *dingtalk.ContactClient + dingtalkNotableClient *dingtalk.NotableClient } func NewHandle( @@ -45,16 +50,20 @@ func NewHandle( toolManager *tools.Manager, conf *config.Config, sessionImpl *impl.SessionImpl, - workflowManager *runtime.Registry, + dingtalkOldClient *dingtalk.OldClient, + dingtalkContactClient *dingtalk.ContactClient, + dingtalkNotableClient *dingtalk.NotableClient, ) *Handle { return &Handle{ - Ollama: Ollama, - toolManager: toolManager, - conf: conf, - sessionImpl: sessionImpl, - - workflowManager: workflowManager, + Ollama: Ollama, + toolManager: toolManager, + conf: conf, + sessionImpl: sessionImpl, + workflowManager: workflowManager, + dingtalkOldClient: dingtalkOldClient, + dingtalkContactClient: dingtalkContactClient, + dingtalkNotableClient: dingtalkNotableClient, } } @@ -119,10 +128,12 @@ func (r *Handle) HandleMatch(ctx context.Context, client *gateway.Client, rec *e switch constants.TaskType(pointTask.Type) { case constants.TaskTypeApi: return r.handleApiTask(ctx, rec, pointTask) - case constants.TaskTypeFunc: - return r.handleTask(ctx, rec, pointTask) case constants.TaskTypeKnowle: return r.handleKnowle(ctx, rec, pointTask) + case constants.TaskTypeFunc: + return r.handleTask(ctx, rec, pointTask) + case constants.TaskTypeBot: + return r.handleBot(ctx, rec, pointTask) case constants.TaskTypeEinoWorkflow: return r.handleEinoWorkflow(ctx, rec, pointTask) case constants.TaskTypeCozeWorkflow: @@ -235,6 +246,74 @@ func (r *Handle) handleKnowle(ctx context.Context, rec *entitys.Recognize, task return } +// bot 临时实现,后续转到 eino 工作流 +func (r *Handle) handleBot(ctx context.Context, rec *entitys.Recognize, task *model.AiTask) (err error) { + if task.Index == "bug_optimization_submit" { + // Ext 中获取 sessionId + sessionID := rec.GetSession() + // 获取dingtalk accessToken + accessToken, _ := r.dingtalkOldClient.GetAccessToken() + // 获取创建者 dingtalk unionId + unionId := r.getUserDingtalkUnionId(ctx, accessToken, sessionID) + baseId := "YQBnd5ExVE6qAbnOiANQg2KKJyeZqMmz" + // 获取第一个数据表 + // sheetIdOrName := r.getFirstSheetIdOrName(ctx, baseId) + recordId, err := r.dingtalkNotableClient.InsertRecord(accessToken, &dingtalk.InsertRecordReq{ + BaseId: baseId, + SheetIdOrName: "数据表", + OperatorId: tool_callback.BotBugOptimizationSubmitAdminUnionId, + CreatorUnionId: unionId, + Content: rec.Match.Parameters, + }) + if err != nil { + return err + } + + if recordId == "" { + return errors.NewBusinessErr(422, "创建记录失败") + } + + entitys.ResText(rec.Ch, "", fmt.Sprintf("创建记录成功,记录ID: %s", recordId)) + + return nil + } + + return errors.NewBusinessErr(422, "bot 任务未实现") +} + +// getUserDingtalkUnionId 获取用户的 dingtalk unionId +func (r *Handle) getUserDingtalkUnionId(ctx context.Context, accessToken, sessionID string) (unionId string) { + // 查询用户名 + session, has, err := r.sessionImpl.FindOne(r.sessionImpl.WithSessionId(sessionID)) + if err != nil || !has { + log.Warnf("session not found: %s", sessionID) + return + } + creatorName := session.UserName + + // 获取创建者uid 用户名 -> dingtalk uid + creatorId, err := r.dingtalkContactClient.SearchUserOne(accessToken, creatorName) + if err != nil { + log.Warnf("search dingtalk user one failed: %v", err) + return + } + + // 获取用户详情 dingtalk uid -> dingtalk unionId + userDetails, err := r.dingtalkOldClient.QueryUserDetails(ctx, creatorId) + if err != nil { + log.Warnf("query user dingtalk details failed: %v", err) + return + } + if userDetails == nil { + log.Warnf("user details not found: %s", creatorId) + return + } + + unionId = userDetails.UnionID + + return +} + func (r *Handle) handleApiTask(ctx context.Context, rec *entitys.Recognize, task *model.AiTask) (err error) { var ( request l_request.Request diff --git a/internal/pkg/dingtalk/notable_client.go b/internal/pkg/dingtalk/notable_client.go index d7d5434..27a2f89 100644 --- a/internal/pkg/dingtalk/notable_client.go +++ b/internal/pkg/dingtalk/notable_client.go @@ -72,3 +72,48 @@ func (c *NotableClient) UpdateRecord(accessToken string, req *UpdateRecordReq) ( return true, nil } + +type InsertRecordReq struct { + BaseId string + SheetIdOrName string + OperatorId string + CreatorUnionId string + Content string +} + +func (c *NotableClient) InsertRecord(accessToken string, req *InsertRecordReq) (string, error) { + // 默认使用“数据表” + if req.SheetIdOrName == "" { + req.SheetIdOrName = "数据表" + } + + headers := ¬able.InsertRecordsHeaders{} + headers.XAcsDingtalkAccessToken = tea.String(accessToken) + resp, err := c.cli.InsertRecordsWithOptions( + tea.String(req.BaseId), + tea.String(req.SheetIdOrName), + ¬able.InsertRecordsRequest{ + OperatorId: tea.String(req.OperatorId), + Records: []*notable.InsertRecordsRequestRecords{ + { + Fields: map[string]any{ + "需求内容": req.Content, + "提交人": []map[string]any{ + { + "unionId": req.CreatorUnionId, + }, + }, + }, + }, + }, + }, headers, &util.RuntimeOptions{}) + if err != nil { + return "", err + } + + if resp.Body == nil || resp.Body.Value == nil || len(resp.Body.Value) == 0 { + return "", errorcode.ParamErrf("empty response body") + } + + return *resp.Body.Value[0].Id, nil +} From 43edfc54c76f5628d14284da0f546def465bfd8e Mon Sep 17 00:00:00 2001 From: fuzhongyun <15339891972@163.com> Date: Sat, 27 Dec 2025 10:13:57 +0800 Subject: [PATCH 66/66] =?UTF-8?q?feat:=20=E5=AE=8C=E5=96=84=20bug=5Foptimi?= =?UTF-8?q?zation=5Fsubmit=20=E5=AE=9E=E7=8E=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- config/config.yaml | 7 ++++++ config/config_env.yaml | 7 ++++++ config/config_test.yaml | 7 ++++++ internal/biz/do/handle.go | 33 +++++++++++++++++-------- internal/config/config.go | 15 +++++++++++ internal/data/error/error_code.go | 13 +++++++--- internal/pkg/dingtalk/notable_client.go | 24 ++++++++++++++++++ 7 files changed, 92 insertions(+), 14 deletions(-) diff --git a/config/config.yaml b/config/config.yaml index 4b568b4..a1497f4 100644 --- a/config/config.yaml +++ b/config/config.yaml @@ -65,6 +65,13 @@ tools: enabled: true base_url: "https://revcl.1688sup.com/api/admin/afterSales/reseller_pre_ai" +dingtalk: + api_key: "dingsbbntrkeiyazcfdg" + api_secret: "ObqxwyR20r9rVNhju0sCPQyQA98_FZSc32W4vgxnGFH_b02HZr1BPCJsOAF816nu" + table_demand: + url: "https://alidocs.dingtalk.com/i/nodes/YQBnd5ExVE6qAbnOiANQg2KKJyeZqMmz" + base_id: "2Amq4vjg89RnYx9DTp66m2orW3kdP0wQ" + sheet_id_or_name: "数据表" default_prompt: img_recognize: diff --git a/config/config_env.yaml b/config/config_env.yaml index fa6b8ae..e180a4b 100644 --- a/config/config_env.yaml +++ b/config/config_env.yaml @@ -105,6 +105,13 @@ eino_tools: hytGoodsBrandSearch: base_url: "https://gateway.dev.cdlsxd.cn/goods-admin/api/v1/goods/brand/list" +dingtalk: + api_key: "dingsbbntrkeiyazcfdg" + api_secret: "ObqxwyR20r9rVNhju0sCPQyQA98_FZSc32W4vgxnGFH_b02HZr1BPCJsOAF816nu" + table_demand: + url: "https://alidocs.dingtalk.com/i/nodes/YQBnd5ExVE6qAbnOiANQg2KKJyeZqMmz" + base_id: "YQBnd5ExVE6qAbnOiANQg2KKJyeZqMmz" + sheet_id_or_name: "数据表" default_prompt: diff --git a/config/config_test.yaml b/config/config_test.yaml index 5ea689d..2180e18 100644 --- a/config/config_test.yaml +++ b/config/config_test.yaml @@ -119,6 +119,13 @@ eino_tools: hytGoodsBrandSearch: base_url: "https://gateway.dev.cdlsxd.cn/goods-admin/api/v1/goods/brand/list" +dingtalk: + api_key: "dingsbbntrkeiyazcfdg" + api_secret: "ObqxwyR20r9rVNhju0sCPQyQA98_FZSc32W4vgxnGFH_b02HZr1BPCJsOAF816nu" + table_demand: + url: "https://alidocs.dingtalk.com/i/nodes/YQBnd5ExVE6qAbnOiANQg2KKJyeZqMmz" + base_id: "YQBnd5ExVE6qAbnOiANQg2KKJyeZqMmz" + sheet_id_or_name: "数据表" default_prompt: img_recognize: diff --git a/internal/biz/do/handle.go b/internal/biz/do/handle.go index d2ef661..ca2392c 100644 --- a/internal/biz/do/handle.go +++ b/internal/biz/do/handle.go @@ -4,6 +4,7 @@ import ( "ai_scheduler/internal/biz/llm_service" "ai_scheduler/internal/config" "ai_scheduler/internal/data/constants" + errorcode "ai_scheduler/internal/data/error" errors "ai_scheduler/internal/data/error" "ai_scheduler/internal/data/impl" "ai_scheduler/internal/data/model" @@ -16,7 +17,6 @@ import ( "ai_scheduler/internal/pkg/mapstructure" "ai_scheduler/internal/pkg/rec_extra" "ai_scheduler/internal/pkg/util" - "ai_scheduler/internal/tool_callback" "ai_scheduler/internal/tools" "ai_scheduler/internal/tools/public" errorsSpecial "errors" @@ -255,25 +255,38 @@ func (r *Handle) handleBot(ctx context.Context, rec *entitys.Recognize, task *mo accessToken, _ := r.dingtalkOldClient.GetAccessToken() // 获取创建者 dingtalk unionId unionId := r.getUserDingtalkUnionId(ctx, accessToken, sessionID) - baseId := "YQBnd5ExVE6qAbnOiANQg2KKJyeZqMmz" - // 获取第一个数据表 - // sheetIdOrName := r.getFirstSheetIdOrName(ctx, baseId) + // 附件url + var attachmentUrl string + for _, file := range rec.UserContent.File { + attachmentUrl = file.FileUrl + break + } recordId, err := r.dingtalkNotableClient.InsertRecord(accessToken, &dingtalk.InsertRecordReq{ - BaseId: baseId, - SheetIdOrName: "数据表", - OperatorId: tool_callback.BotBugOptimizationSubmitAdminUnionId, + BaseId: r.conf.Dingtalk.TableDemand.BaseId, + SheetIdOrName: r.conf.Dingtalk.TableDemand.SheetIdOrName, + // OperatorId: tool_callback.BotBugOptimizationSubmitAdminUnionId, + OperatorId: unionId, CreatorUnionId: unionId, - Content: rec.Match.Parameters, + Content: rec.UserContent.Text, + AttachmentUrl: attachmentUrl, }) if err != nil { + errCode := r.dingtalkNotableClient.GetHTTPStatus(err) + // 权限不足 + if errCode == 403 { + return errorcode.ForbiddenErr("您当前没有AI需求表编辑权限,请联系管理员添加权限") + } return err } if recordId == "" { - return errors.NewBusinessErr(422, "创建记录失败") + return errors.NewBusinessErr(422, "创建记录失败,请联系管理员") } - entitys.ResText(rec.Ch, "", fmt.Sprintf("创建记录成功,记录ID: %s", recordId)) + // 构建跳转链接 + detailPage := util.BuildJumpLink(r.conf.Dingtalk.TableDemand.Url, "去查看") + + entitys.ResText(rec.Ch, "", fmt.Sprintf("问题已记录,正在分配相关人员处理,请您耐心等待处理结果。点击查看工单进度:%s", detailPage)) return nil } diff --git a/internal/config/config.go b/internal/config/config.go index 3198ebd..441d09d 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -23,6 +23,7 @@ type Config struct { PermissionConfig PermissionConfig `mapstructure:"permissionConfig"` LLM LLM `mapstructure:"llm"` // DingTalkBots map[string]*DingTalkBot `mapstructure:"ding_talk_bots"` + Dingtalk DingtalkConfig `mapstructure:"dingtalk"` } type SysPrompt struct { @@ -61,6 +62,20 @@ type LLMCapabilityConfig struct { Parameters LLMParameters `mapstructure:"parameters"` } +// DingtalkConfig 钉钉配置 +type DingtalkConfig struct { + ApiKey string `mapstructure:"api_key"` + ApiSecret string `mapstructure:"api_secret"` + TableDemand AITableConfig `mapstructure:"table_demand"` +} + +// TableDemandConfig 需求表配置 +type AITableConfig struct { + Url string `mapstructure:"url"` + BaseId string `mapstructure:"base_id"` + SheetIdOrName string `mapstructure:"sheet_id_or_name"` +} + // SysConfig 系统配置 type SysConfig struct { SessionLen int `mapstructure:"session_len"` diff --git a/internal/data/error/error_code.go b/internal/data/error/error_code.go index 1e2f0b6..390f448 100644 --- a/internal/data/error/error_code.go +++ b/internal/data/error/error_code.go @@ -3,10 +3,11 @@ package errorcode import "fmt" var ( - Success = &BusinessErr{code: 200, message: "成功"} - ParamError = &BusinessErr{code: 401, message: "参数错误"} - NotFoundError = &BusinessErr{code: 404, message: "请求地址未找到"} - SystemError = &BusinessErr{code: 405, message: "系统错误"} + Success = &BusinessErr{code: 200, message: "成功"} + ParamError = &BusinessErr{code: 401, message: "参数错误"} + ForbiddenError = &BusinessErr{code: 403, message: "权限不足"} + NotFoundError = &BusinessErr{code: 404, message: "请求地址未找到"} + SystemError = &BusinessErr{code: 405, message: "系统错误"} ClientNotFound = &BusinessErr{code: 406, message: "未找到client_id"} SessionNotFound = &BusinessErr{code: 407, message: "未找到会话信息"} @@ -67,3 +68,7 @@ func (e *BusinessErr) Wrap(err error) *BusinessErr { func WorkflowErr(message string) *BusinessErr { return NewBusinessErr(WorkflowError.code, message) } + +func ForbiddenErr(message string) *BusinessErr { + return NewBusinessErr(ForbiddenError.code, message) +} diff --git a/internal/pkg/dingtalk/notable_client.go b/internal/pkg/dingtalk/notable_client.go index 27a2f89..885e111 100644 --- a/internal/pkg/dingtalk/notable_client.go +++ b/internal/pkg/dingtalk/notable_client.go @@ -3,6 +3,8 @@ package dingtalk import ( "ai_scheduler/internal/config" errorcode "ai_scheduler/internal/data/error" + "encoding/json" + "time" openapi "github.com/alibabacloud-go/darabonba-openapi/v2/client" notable "github.com/alibabacloud-go/dingtalk/notable_1_0" @@ -79,6 +81,7 @@ type InsertRecordReq struct { OperatorId string CreatorUnionId string Content string + AttachmentUrl string } func (c *NotableClient) InsertRecord(accessToken string, req *InsertRecordReq) (string, error) { @@ -97,12 +100,16 @@ func (c *NotableClient) InsertRecord(accessToken string, req *InsertRecordReq) ( Records: []*notable.InsertRecordsRequestRecords{ { Fields: map[string]any{ + "创建日期": time.Now().Format(time.DateTime), "需求内容": req.Content, "提交人": []map[string]any{ { "unionId": req.CreatorUnionId, }, }, + "附件": map[string]any{ + "link": req.AttachmentUrl, + }, }, }, }, @@ -117,3 +124,20 @@ func (c *NotableClient) InsertRecord(accessToken string, req *InsertRecordReq) ( return *resp.Body.Value[0].Id, nil } + +func (c *NotableClient) GetHTTPStatus(err error) int { + if sdkErr, ok := err.(*tea.SDKError); ok { + if sdkErr.StatusCode != nil { + return *sdkErr.StatusCode + } + if sdkErr.Data != nil { + var m struct { + StatusCode int `json:"statusCode"` + } + if json.Unmarshal([]byte(*sdkErr.Data), &m) == nil { + return m.StatusCode + } + } + } + return 0 // 0 = 非 HTTP 错误 +}