diff --git a/config/config.yaml b/config/config.yaml index a8a4137..95cd9fa 100644 --- a/config/config.yaml +++ b/config/config.yaml @@ -4,8 +4,8 @@ server: host: "0.0.0.0" ollama: - base_url: "http://localhost:11434" - model: "qwen3:8b" + base_url: "http://127.0.0.1:11434" + model: "qwen3-coder:480b-cloud" timeout: "120s" level: "info" format: "json" diff --git a/internal/biz/llm_service/common.go b/internal/biz/llm_service/common.go new file mode 100644 index 0000000..d833f51 --- /dev/null +++ b/internal/biz/llm_service/common.go @@ -0,0 +1,38 @@ +package llm_service + +import ( + "ai_scheduler/internal/data/model" + "ai_scheduler/internal/entitys" + "context" +) + +type LlmService interface { + IntentRecognize(ctx context.Context, sysInfo model.AiSy, history []model.AiChatHi, userInput string, tasks []model.AiTask) (string, error) +} + +// buildSystemPrompt 构建系统提示词 +func buildSystemPrompt(prompt string) string { + if len(prompt) == 0 { + prompt = "[system] 你是一个智能路由系统,核心职责是 **精准解析用户意图并路由至对应任务模块**\n[rule]\n1.返回以下格式的JSON:{ \"index\": \"工具索引index\", \"confidence\": 0.0-1.0,\"reasoning\": \"判断理由\"}\n2.严格返回字符串格式,禁用markdown格式返回\n3.只返回json字符串,不包含任何其他解释性文字\n4.当用户意图非常不清晰时使用,尝试进行追问具体希望查询内容" + } + + return prompt +} + +func buildAssistant(his []model.AiChatHi) (chatHis entitys.ChatHis) { + for _, item := range his { + if len(chatHis.SessionId) == 0 { + chatHis.SessionId = item.SessionID + } + chatHis.Messages = append(chatHis.Messages, entitys.HisMessage{ + Role: item.Role, + Content: item.Content, + Timestamp: item.CreateAt.Format("2006-01-02 15:04:05"), + }) + } + chatHis.Context = entitys.HisContext{ + UserLanguage: "zh-CN", + SystemMode: "technical_support", + } + return +} diff --git a/internal/biz/llm_service/langchain.go b/internal/biz/llm_service/langchain.go new file mode 100644 index 0000000..c8cc3ee --- /dev/null +++ b/internal/biz/llm_service/langchain.go @@ -0,0 +1,86 @@ +package llm_service + +import ( + "ai_scheduler/internal/data/model" + "ai_scheduler/internal/entitys" + "ai_scheduler/internal/pkg" + "ai_scheduler/internal/pkg/utils_langchain" + "context" + "encoding/json" + + "github.com/tmc/langchaingo/llms" +) + +type LangChainService struct { + client *utils_langchain.UtilLangChain +} + +func NewLangChainGenerate( + client *utils_langchain.UtilLangChain, +) *LangChainService { + return &LangChainService{ + client: client, + } +} + +func (r *LangChainService) IntentRecognize(ctx context.Context, sysInfo model.AiSy, history []model.AiChatHi, userInput string, tasks []model.AiTask) (msg string, err error) { + prompt := r.getPrompt(sysInfo, history, userInput, tasks) + AgentClient := r.client.Get() + defer r.client.Put(AgentClient) + match, err := AgentClient.Llm.GenerateContent( + ctx, // 使用可取消的上下文 + prompt, + llms.WithJSONMode(), + ) + msg = match.Choices[0].Content + return +} + +func (r *LangChainService) getPrompt(sysInfo model.AiSy, history []model.AiChatHi, reqInput string, tasks []model.AiTask) []llms.MessageContent { + var ( + prompt = make([]llms.MessageContent, 0) + ) + prompt = append(prompt, llms.MessageContent{ + Role: llms.ChatMessageTypeSystem, + Parts: []llms.ContentPart{ + llms.TextPart(buildSystemPrompt(sysInfo.SysPrompt)), + }, + }, llms.MessageContent{ + Role: llms.ChatMessageTypeTool, + Parts: []llms.ContentPart{ + llms.TextPart(pkg.JsonStringIgonErr(buildAssistant(history))), + }, + }, llms.MessageContent{ + Role: llms.ChatMessageTypeTool, + Parts: []llms.ContentPart{ + llms.TextPart(pkg.JsonStringIgonErr(r.registerTools(tasks))), + }, + }, llms.MessageContent{ + Role: llms.ChatMessageTypeHuman, + Parts: []llms.ContentPart{ + llms.TextPart(reqInput), + }, + }) + return prompt +} + +func (r *LangChainService) registerTools(tasks []model.AiTask) []llms.Tool { + taskPrompt := make([]llms.Tool, 0) + for _, task := range tasks { + var taskConfig entitys.TaskConfig + err := json.Unmarshal([]byte(task.Config), &taskConfig) + if err != nil { + continue + } + taskPrompt = append(taskPrompt, llms.Tool{ + Type: "function", + Function: &llms.FunctionDefinition{ + Name: task.Index, + Description: task.Desc, + Parameters: taskConfig.Param, + }, + }) + + } + return taskPrompt +} diff --git a/internal/biz/llm_service/ollama.go b/internal/biz/llm_service/ollama.go new file mode 100644 index 0000000..a4a1297 --- /dev/null +++ b/internal/biz/llm_service/ollama.go @@ -0,0 +1,79 @@ +package llm_service + +import ( + "ai_scheduler/internal/data/model" + "ai_scheduler/internal/entitys" + "ai_scheduler/internal/pkg" + "ai_scheduler/internal/pkg/utils_ollama" + "context" + "encoding/json" + "fmt" + + "github.com/ollama/ollama/api" +) + +type OllamaService struct { + client *utils_ollama.Client +} + +func NewOllamaGenerate( + client *utils_ollama.Client, +) *OllamaService { + return &OllamaService{ + client: client, + } +} + +func (r *OllamaService) IntentRecognize(ctx context.Context, requireData *entitys.RequireData) (msg string, err error) { + prompt := r.getPrompt(requireData.Sys, requireData.Histories, requireData.UserInput, requireData.Tasks) + toolDefinitions := r.registerToolsOllama(requireData.Tasks) + match, err := r.client.ToolSelect(context.TODO(), prompt, toolDefinitions) + if err != nil { + return + } + msg = match.Message.Content + return +} + +func (r *OllamaService) getPrompt(sysInfo model.AiSy, history []model.AiChatHi, reqInput string, tasks []model.AiTask) []api.Message { + var ( + prompt = make([]api.Message, 0) + ) + prompt = append(prompt, api.Message{ + Role: "system", + Content: buildSystemPrompt(sysInfo.SysPrompt), + }, api.Message{ + Role: "assistant", + Content: fmt.Sprintf("聊天记录:%s", pkg.JsonStringIgonErr(buildAssistant(history))), + }, api.Message{ + Role: "user", + Content: reqInput, + }) + return prompt +} + +func (r *OllamaService) registerToolsOllama(tasks []model.AiTask) []api.Tool { + taskPrompt := make([]api.Tool, 0) + for _, task := range tasks { + var taskConfig entitys.TaskConfigDetail + err := json.Unmarshal([]byte(task.Config), &taskConfig) + if err != nil { + continue + } + + taskPrompt = append(taskPrompt, api.Tool{ + Type: "function", + Function: api.ToolFunction{ + Name: task.Index, + Description: task.Desc, + Parameters: api.ToolFunctionParameters{ + Type: taskConfig.Param.Type, + Required: taskConfig.Param.Required, + Properties: taskConfig.Param.Properties, + }, + }, + }) + + } + return taskPrompt +} diff --git a/internal/biz/provider_set.go b/internal/biz/provider_set.go index 856f6c0..6dfb2c7 100644 --- a/internal/biz/provider_set.go +++ b/internal/biz/provider_set.go @@ -1,5 +1,15 @@ package biz -import "github.com/google/wire" +import ( + "ai_scheduler/internal/biz/llm_service" -var ProviderSetBiz = wire.NewSet(NewAiRouterBiz, NewSessionBiz, NewChatHistoryBiz) + "github.com/google/wire" +) + +var ProviderSetBiz = wire.NewSet( + NewAiRouterBiz, + NewSessionBiz, + NewChatHistoryBiz, + llm_service.NewLangChainGenerate, + llm_service.NewOllamaGenerate, +) diff --git a/internal/biz/router.go b/internal/biz/router.go index 906f9f0..3a4e8d3 100644 --- a/internal/biz/router.go +++ b/internal/biz/router.go @@ -1,6 +1,7 @@ package biz import ( + "ai_scheduler/internal/biz/llm_service" "ai_scheduler/internal/config" "ai_scheduler/internal/data/constants" errors "ai_scheduler/internal/data/error" @@ -9,8 +10,8 @@ import ( "ai_scheduler/internal/entitys" "ai_scheduler/internal/pkg" "ai_scheduler/internal/pkg/mapstructure" - "ai_scheduler/internal/pkg/utils_ollama" - tools "ai_scheduler/internal/tools" + + "ai_scheduler/internal/tools" "ai_scheduler/tmpl/dataTemp" "context" "encoding/json" @@ -21,96 +22,108 @@ import ( "gitea.cdlsxd.cn/self-tools/l_request" "github.com/gofiber/fiber/v2/log" "github.com/gofiber/websocket/v2" - "github.com/ollama/ollama/api" - "github.com/tmc/langchaingo/llms" "xorm.io/builder" ) // AiRouterBiz 智能路由服务 type AiRouterBiz struct { - //aiClient entitys.AIClient toolManager *tools.Manager sessionImpl *impl.SessionImpl sysImpl *impl.SysImpl taskImpl *impl.TaskImpl hisImpl *impl.ChatImpl conf *config.Config - utilAgent *utils_ollama.UtilOllama - ollama *utils_ollama.Client - channelPool *pkg.SafeChannelPool rds *pkg.Rdb + langChain *llm_service.LangChainService + Ollama *llm_service.OllamaService } // NewRouterService 创建路由服务 func NewAiRouterBiz( - //aiClient entitys.AIClient, toolManager *tools.Manager, sessionImpl *impl.SessionImpl, sysImpl *impl.SysImpl, taskImpl *impl.TaskImpl, hisImpl *impl.ChatImpl, conf *config.Config, - utilAgent *utils_ollama.UtilOllama, - channelPool *pkg.SafeChannelPool, - ollama *utils_ollama.Client, - + langChain *llm_service.LangChainService, + Ollama *llm_service.OllamaService, ) *AiRouterBiz { return &AiRouterBiz{ - //aiClient: aiClient, toolManager: toolManager, sessionImpl: sessionImpl, conf: conf, sysImpl: sysImpl, hisImpl: hisImpl, taskImpl: taskImpl, - utilAgent: utilAgent, - channelPool: channelPool, - ollama: ollama, + langChain: langChain, + Ollama: Ollama, } } -// Route 执行智能路由 -func (r *AiRouterBiz) Route(ctx context.Context, req *entitys.ChatRequest) (*entitys.ChatResponse, error) { - - return nil, nil -} - func (r *AiRouterBiz) RouteWithSocket(c *websocket.Conn, req *entitys.ChatSockRequest) (err error) { - - session := c.Query("x-session", "") - if len(session) == 0 { - return errors.SessionNotFound - } - auth := c.Query("x-authorization", "") - if len(auth) == 0 { - return errors.AuthNotFound - } - key := c.Query("x-app-key", "") - if len(key) == 0 { - return errors.KeyNotFound + //必要数据验证和获取 + var requireData entitys.RequireData + err = r.dataAuth(c, &requireData) + if err != nil { + return } - var chat = make([]string, 0) - + //初始化通道/上下文 + requireData.Ch = make(chan entitys.Response) ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - //ch := r.channelPool.Get() - ch := make(chan entitys.Response) + // 启动独立的消息处理协程 + done := r.startMessageHandler(ctx, c, &requireData) + defer func() { + close(requireData.Ch) //关闭主通道 + <-done // 等待消息处理完成 + cancel() + }() + + //获取系统信息 + err = r.getRequireData(req.Text, &requireData) + if err != nil { + log.Errorf("SQL error: %v", err) + return + } + //意图识别 + err = r.recognize(ctx, &requireData) + if err != nil { + log.Errorf("LLM error: %v", err) + return + } + //向下传递 + if err = r.handleMatch(ctx, &requireData); err != nil { + log.Errorf("Handle error: %v", err) + return + } + + return +} + +// startMessageHandler 启动独立的消息处理协程 +func (r *AiRouterBiz) startMessageHandler( + ctx context.Context, + c *websocket.Conn, + requireData *entitys.RequireData, +) <-chan struct{} { done := make(chan struct{}) + var chat []string + go func() { defer func() { close(done) - + // 保存历史记录 var his = []*model.AiChatHi{ { - SessionID: session, + SessionID: requireData.Session, Role: "user", - Content: req.Text, + Content: "", // 用户输入在外部处理 }, } if len(chat) > 0 { his = append(his, &model.AiChatHi{ - SessionID: session, + SessionID: requireData.Session, Role: "assistant", Content: strings.Join(chat, ""), }) @@ -119,92 +132,19 @@ func (r *AiRouterBiz) RouteWithSocket(c *websocket.Conn, req *entitys.ChatSockRe r.hisImpl.Add(hi) } }() - for { - select { - case v, ok := <-ch: - if !ok { - return - } - // 带超时的发送,避免阻塞 - if err = sendWithTimeout(c, v, 2*time.Second); err != nil { - log.Errorf("Send error: %v", err) - cancel() // 通知主流程退出 - return - } - if v.Type == entitys.ResponseText || v.Type == entitys.ResponseStream || v.Type == entitys.ResponseJson { - chat = append(chat, v.Content) - } - case <-ctx.Done(): + for v := range requireData.Ch { // 自动检测通道关闭 + if err := sendWithTimeout(c, v, 2*time.Second); err != nil { + log.Errorf("Send error: %v", err) return } + if v.Type == entitys.ResponseText || v.Type == entitys.ResponseStream || v.Type == entitys.ResponseJson { + chat = append(chat, v.Content) + } } }() - defer func() { - close(ch) - }() - - sysInfo, err := r.getSysInfo(key) - if err != nil { - return errors.SysErr("获取系统信息失败:%v", err.Error()) - } - - history, err := r.getSessionChatHis(session) - if err != nil { - return errors.SysErr("获取历史记录失败:%v", err.Error()) - } - - task, err := r.getTasks(sysInfo.SysID) - if err != nil { - return errors.SysErr("获取任务列表失败:%v", err.Error()) - } - - prompt := r.getPromptLLM(sysInfo, history, req.Text, task) - - AgentClient := r.utilAgent.Get() - - ch <- entitys.Response{ - Index: "", - Content: "准备意图识别", - Type: entitys.ResponseLog, - } - - match, err := AgentClient.Llm.GenerateContent( - ctx, // 使用可取消的上下文 - prompt, - llms.WithJSONMode(), - ) - resMsg := match.Choices[0].Content - - r.utilAgent.Put(AgentClient) - ch <- entitys.Response{ - Index: "", - Content: resMsg, - Type: entitys.ResponseLog, - } - ch <- entitys.Response{ - Index: "", - Content: "意图识别结束", - Type: entitys.ResponseLog, - } - - if err != nil { - log.Errorf("LLM error: %v", err) - return errors.SystemError - } - var matchJson entitys.Match - if err := json.Unmarshal([]byte(resMsg), &matchJson); err != nil { - log.Info(resMsg) - return errors.SysErr("数据结构错误:%v", err.Error()) - } - matchJson.History = pkg.JsonByteIgonErr(history) - matchJson.UserInput = req.Text - if err := r.handleMatch(ctx, c, ch, &matchJson, task, sysInfo); err != nil { - return err - } - - return nil + return done } // 辅助函数:带超时的 WebSocket 发送 @@ -218,8 +158,11 @@ func sendWithTimeout(c *websocket.Conn, data entitys.Response, timeout time.Dura if r := recover(); r != nil { done <- fmt.Errorf("panic in MsgSend: %v", r) } + close(done) }() - done <- entitys.MsgSend(c, data) + // 如果 MsgSend 阻塞,这里会卡住 + err := entitys.MsgSend(c, data) + done <- err }() select { @@ -229,58 +172,135 @@ func sendWithTimeout(c *websocket.Conn, data entitys.Response, timeout time.Dura return sendCtx.Err() } } -func (r *AiRouterBiz) handleOtherTask(c *websocket.Conn, ch chan entitys.Response, matchJson *entitys.Match) (err error) { - ch <- entitys.Response{ + +func (r *AiRouterBiz) recognize(ctx context.Context, requireData *entitys.RequireData) (err error) { + requireData.Ch <- entitys.Response{ Index: "", - Content: matchJson.Reasoning, + Content: "准备意图识别", + Type: entitys.ResponseLog, + } + //意图识别 + recognizeMsg, err := r.Ollama.IntentRecognize(ctx, requireData) + if err != nil { + return + } + requireData.Ch <- entitys.Response{ + Index: "", + Content: recognizeMsg, + Type: entitys.ResponseLog, + } + requireData.Ch <- entitys.Response{ + Index: "", + Content: "意图识别结束", + Type: entitys.ResponseLog, + } + if err = json.Unmarshal([]byte(recognizeMsg), requireData.Match); err != nil { + err = errors.SysErr("数据结构错误:%v", err.Error()) + return + } + return +} + +func (r *AiRouterBiz) getRequireData(userInput string, requireData *entitys.RequireData) (err error) { + requireData.Sys, err = r.getSysInfo(requireData.Key) + if err != nil { + err = errors.SysErr("获取系统信息失败:%v", err.Error()) + return + } + requireData.Histories, err = r.getSessionChatHis(requireData.Session) + if err != nil { + err = errors.SysErr("获取历史记录失败:%v", err.Error()) + return + } + + requireData.Tasks, err = r.getTasks(requireData.Sys.SysID) + if err != nil { + err = errors.SysErr("获取任务列表失败:%v", err.Error()) + return + } + + requireData.UserInput = userInput + if len(requireData.UserInput) == 0 { + err = errors.SysErr("获取用户输入失败") + return + } + if len(requireData.UserInput) == 0 { + err = errors.SysErr("获取用户输入失败") + return + } + return +} + +func (r *AiRouterBiz) dataAuth(c *websocket.Conn, requireData *entitys.RequireData) (err error) { + requireData.Session = c.Query("x-session", "") + if len(requireData.Session) == 0 { + err = errors.SessionNotFound + return + } + requireData.Auth = c.Query("x-authorization", "") + if len(requireData.Auth) == 0 { + err = errors.AuthNotFound + return + } + requireData.Key = c.Query("x-app-key", "") + if len(requireData.Key) == 0 { + err = errors.KeyNotFound + return + } + return +} + +func (r *AiRouterBiz) handleOtherTask(ctx context.Context, requireData *entitys.RequireData) (err error) { + requireData.Ch <- entitys.Response{ + Index: "", + Content: requireData.Match.Reasoning, Type: entitys.ResponseText, } return } -func (r *AiRouterBiz) handleMatch(ctx context.Context, c *websocket.Conn, ch chan entitys.Response, matchJson *entitys.Match, tasks []model.AiTask, sysInfo model.AiSy) (err error) { +func (r *AiRouterBiz) handleMatch(ctx context.Context, requireData *entitys.RequireData) (err error) { - if !matchJson.IsMatch { - _ = entitys.MsgSend(c, entitys.Response{ + if !requireData.Match.IsMatch { + requireData.Ch <- entitys.Response{ Index: "", - Content: matchJson.Reasoning, + Content: requireData.Match.Reasoning, Type: entitys.ResponseText, - }) - + } return } var pointTask *model.AiTask - for _, task := range tasks { - if task.Index == matchJson.Index { + for _, task := range requireData.Tasks { + if task.Index == requireData.Match.Index { pointTask = &task break } } if pointTask == nil || pointTask.Index == "other" { - return r.handleOtherTask(c, ch, matchJson) + return r.handleOtherTask(ctx, requireData) } switch pointTask.Type { case constants.TaskTypeApi: - return r.handleApiTask(ch, c, matchJson, pointTask) + return r.handleApiTask(ctx, requireData, pointTask) case constants.TaskTypeFunc: - return r.handleTask(ch, c, matchJson, pointTask) + return r.handleTask(ctx, requireData, pointTask) case constants.TaskTypeKnowle: - return r.handleKnowle(ch, c, matchJson, pointTask, sysInfo) + return r.handleKnowle(ctx, requireData, pointTask) default: - return r.handleOtherTask(c, ch, matchJson) + return r.handleOtherTask(ctx, requireData) } } -func (r *AiRouterBiz) handleTask(channel chan entitys.Response, c *websocket.Conn, matchJson *entitys.Match, task *model.AiTask) (err error) { +func (r *AiRouterBiz) handleTask(ctx context.Context, requireData *entitys.RequireData, task *model.AiTask) (err error) { var configData entitys.ConfigDataTool err = json.Unmarshal([]byte(task.Config), &configData) if err != nil { return } - err = r.toolManager.ExecuteTool(channel, c, configData.Tool, []byte(matchJson.Parameters), matchJson) + err = r.toolManager.ExecuteTool(ctx, configData.Tool, requireData) if err != nil { return } @@ -289,7 +309,7 @@ func (r *AiRouterBiz) handleTask(channel chan entitys.Response, c *websocket.Con } // 知识库 -func (r *AiRouterBiz) handleKnowle(channel chan entitys.Response, c *websocket.Conn, matchJson *entitys.Match, task *model.AiTask, sysInfo model.AiSy) (err error) { +func (r *AiRouterBiz) handleKnowle(ctx context.Context, requireData *entitys.RequireData, task *model.AiTask) (err error) { var ( configData entitys.ConfigDataTool @@ -303,11 +323,11 @@ func (r *AiRouterBiz) handleKnowle(channel chan entitys.Response, c *websocket.C } // 通过session 找到知识库session - session := c.Query("x-session", "") - if len(session) == 0 { + var has bool + if len(requireData.Session) == 0 { return errors.SessionNotFound } - sessionInfo, has, err := r.sessionImpl.FindOne(r.sessionImpl.WithSessionId(session)) + requireData.SessionInfo, has, err = r.sessionImpl.FindOne(r.sessionImpl.WithSessionId(requireData.Session)) if err != nil { return } else if !has { @@ -330,15 +350,15 @@ func (r *AiRouterBiz) handleKnowle(channel chan entitys.Response, c *websocket.C } // 知识库的session为空,请求知识库获取, 并绑定 - if sessionInfo.KnowlegeSessionID == "" { + if requireData.SessionInfo.KnowlegeSessionID == "" { // 请求知识库 - if sessionIdKnowledge, err = tools.GetKnowledgeBaseSession(host, sysInfo.KnowlegeBaseID, sysInfo.KnowlegeTenantKey); err != nil { + if sessionIdKnowledge, err = tools.GetKnowledgeBaseSession(host, requireData.Sys.KnowlegeBaseID, requireData.Sys.KnowlegeTenantKey); err != nil { return } // 绑定知识库session,下次可以使用 - sessionInfo.KnowlegeSessionID = sessionIdKnowledge - if err = r.sessionImpl.Update(&sessionInfo, r.sessionImpl.WithSessionId(sessionInfo.SessionID)); err != nil { + requireData.SessionInfo.KnowlegeSessionID = sessionIdKnowledge + if err = r.sessionImpl.Update(&requireData.SessionInfo, r.sessionImpl.WithSessionId(requireData.SessionInfo.SessionID)); err != nil { return } } @@ -346,25 +366,21 @@ func (r *AiRouterBiz) handleKnowle(channel chan entitys.Response, c *websocket.C // 用户输入解析 var ok bool input := make(map[string]string) - if err = json.Unmarshal([]byte(matchJson.Parameters), &input); err != nil { + if err = json.Unmarshal([]byte(requireData.Match.Parameters), &input); err != nil { return } if query, ok = input["query"]; !ok { return fmt.Errorf("query不能为空") } - knowledgeConfig := tools.KnowledgeBaseRequest{ - Session: sessionInfo.KnowlegeSessionID, - ApiKey: sysInfo.KnowlegeTenantKey, + requireData.KnowledgeConf = entitys.KnowledgeBaseRequest{ + Session: requireData.SessionInfo.KnowlegeSessionID, + ApiKey: requireData.Sys.KnowlegeTenantKey, Query: query, } - b, err := json.Marshal(knowledgeConfig) - if err != nil { - return - } // 执行工具 - err = r.toolManager.ExecuteTool(channel, c, configData.Tool, b, nil) + err = r.toolManager.ExecuteTool(ctx, configData.Tool, requireData) if err != nil { return } @@ -372,17 +388,16 @@ func (r *AiRouterBiz) handleKnowle(channel chan entitys.Response, c *websocket.C return } -func (r *AiRouterBiz) handleApiTask(channels chan entitys.Response, c *websocket.Conn, matchJson *entitys.Match, task *model.AiTask) (err error) { +func (r *AiRouterBiz) handleApiTask(ctx context.Context, requireData *entitys.RequireData, task *model.AiTask) (err error) { var ( request l_request.Request - auth = c.Query("x-authorization", "") requestParam map[string]interface{} ) - err = json.Unmarshal([]byte(matchJson.Parameters), &requestParam) + err = json.Unmarshal([]byte(requireData.Match.Parameters), &requestParam) if err != nil { return } - request.Url = strings.ReplaceAll(task.Config, "${authorization}", auth) + request.Url = strings.ReplaceAll(task.Config, "${authorization}", requireData.Auth) for k, v := range requestParam { task.Config = strings.ReplaceAll(task.Config, "${"+k+"}", fmt.Sprintf("%v", v)) } @@ -403,7 +418,11 @@ func (r *AiRouterBiz) handleApiTask(channels chan entitys.Response, c *websocket if err != nil { return } - c.WriteMessage(1, res.Content) + requireData.Ch <- entitys.Response{ + Index: "", + Content: pkg.JsonStringIgonErr(res.Text), + Type: entitys.ResponseJson, + } return } @@ -436,119 +455,3 @@ func (r *AiRouterBiz) getTasks(sysId int32) (tasks []model.AiTask, err error) { return } - -func (r *AiRouterBiz) registerTools(tasks []model.AiTask) []llms.Tool { - taskPrompt := make([]llms.Tool, 0) - for _, task := range tasks { - var taskConfig entitys.TaskConfig - err := json.Unmarshal([]byte(task.Config), &taskConfig) - if err != nil { - continue - } - taskPrompt = append(taskPrompt, llms.Tool{ - Type: "function", - Function: &llms.FunctionDefinition{ - Name: task.Index, - Description: task.Desc, - Parameters: taskConfig.Param, - }, - }) - - } - return taskPrompt -} - -func (r *AiRouterBiz) getPrompt(sysInfo model.AiSy, history []model.AiChatHi, reqInput string) []entitys.Message { - var ( - prompt = make([]entitys.Message, 0) - ) - prompt = append(prompt, entitys.Message{ - Role: "system", - Content: r.buildSystemPrompt(sysInfo.SysPrompt), - }, entitys.Message{ - Role: "assistant", - Content: pkg.JsonStringIgonErr(r.buildAssistant(history)), - }, entitys.Message{ - Role: "user", - Content: reqInput, - }) - return prompt -} - -func (r *AiRouterBiz) getPromptOllama(sysInfo model.AiSy, history []model.AiChatHi, reqInput string) []api.Message { - var ( - prompt = make([]api.Message, 0) - ) - prompt = append(prompt, api.Message{ - Role: "system", - Content: r.buildSystemPrompt(sysInfo.SysPrompt), - }, api.Message{ - Role: "assistant", - Content: pkg.JsonStringIgonErr(r.buildAssistant(history)), - }, api.Message{ - Role: "user", - Content: reqInput, - }) - return prompt -} - -func (r *AiRouterBiz) getPromptLLM(sysInfo model.AiSy, history []model.AiChatHi, reqInput string, tasks []model.AiTask) []llms.MessageContent { - var ( - prompt = make([]llms.MessageContent, 0) - ) - prompt = append(prompt, llms.MessageContent{ - Role: llms.ChatMessageTypeSystem, - Parts: []llms.ContentPart{ - llms.TextPart(r.buildSystemPrompt(sysInfo.SysPrompt)), - }, - }, llms.MessageContent{ - Role: llms.ChatMessageTypeTool, - Parts: []llms.ContentPart{ - llms.TextPart(pkg.JsonStringIgonErr(r.buildAssistant(history))), - }, - }, llms.MessageContent{ - Role: llms.ChatMessageTypeTool, - Parts: []llms.ContentPart{ - llms.TextPart(pkg.JsonStringIgonErr(r.registerTools(tasks))), - }, - }, llms.MessageContent{ - Role: llms.ChatMessageTypeHuman, - Parts: []llms.ContentPart{ - llms.TextPart(reqInput), - }, - }) - return prompt -} - -// buildSystemPrompt 构建系统提示词 -func (r *AiRouterBiz) buildSystemPrompt(prompt string) string { - if len(prompt) == 0 { - prompt = "[system] 你是一个智能路由系统,核心职责是 **精准解析用户意图并路由至对应任务模块**\n[rule]\n1.返回以下格式的JSON:{ \"index\": \"工具索引index\", \"confidence\": 0.0-1.0,\"reasoning\": \"判断理由\"}\n2.严格返回字符串格式,禁用markdown格式返回\n3.只返回json字符串,不包含任何其他解释性文字\n4.当用户意图非常不清晰时使用,尝试进行追问具体希望查询内容" - } - - return prompt -} - -func (r *AiRouterBiz) buildAssistant(his []model.AiChatHi) (chatHis entitys.ChatHis) { - for _, item := range his { - if len(chatHis.SessionId) == 0 { - chatHis.SessionId = item.SessionID - } - chatHis.Messages = append(chatHis.Messages, entitys.HisMessage{ - Role: item.Role, - Content: item.Content, - Timestamp: item.CreateAt.Format("2006-01-02 15:04:05"), - }) - } - chatHis.Context = entitys.HisContext{ - UserLanguage: "zh-CN", - SystemMode: "technical_support", - } - return -} - -// handleKnowledgeQA 处理知识问答意图 -func (r *AiRouterBiz) handleKnowledgeQA(ctx context.Context, req *entitys.ChatRequest, messages []entitys.Message) (*entitys.ChatResponse, error) { - - return nil, nil -} diff --git a/internal/entitys/types.go b/internal/entitys/types.go index d437a86..fe77e4e 100644 --- a/internal/entitys/types.go +++ b/internal/entitys/types.go @@ -1,6 +1,8 @@ package entitys import ( + "ai_scheduler/internal/data/model" + "context" "encoding/json" @@ -73,7 +75,7 @@ type Tool interface { Name() string Description() string Definition() ToolDefinition - Execute(channel chan Response, c *websocket.Conn, args json.RawMessage, matchJson *Match) error + Execute(ctx context.Context, requireData *RequireData) error } type ConfigDataHttp struct { @@ -118,6 +120,7 @@ type Match struct { Reasoning string `json:"reasoning"` History []byte `json:"history"` UserInput string `json:"user_input"` + Auth string `json:"auth"` } type ChatHis struct { SessionId string `json:"session_id"` @@ -135,8 +138,27 @@ type HisContext struct { SystemMode string `json:"system_mode"` } +type RequireData struct { + Session string + Key string + Sys model.AiSy + Histories []model.AiChatHi + SessionInfo model.AiSession + Tasks []model.AiTask + Match *Match + UserInput string + Auth string + Ch chan Response + KnowledgeConf KnowledgeBaseRequest +} + +type KnowledgeBaseRequest struct { + Session string // 知识库会话id + ApiKey string // 知识库apiKey + Query string // 用户输入 +} + // RouterService 路由服务接口 type RouterService interface { - Route(ctx context.Context, req *ChatRequest) (*ChatResponse, error) RouteWithSocket(c *websocket.Conn, req *ChatSockRequest) error } diff --git a/internal/pkg/provider_set.go b/internal/pkg/provider_set.go index f8b1a25..e20b9c2 100644 --- a/internal/pkg/provider_set.go +++ b/internal/pkg/provider_set.go @@ -1,6 +1,7 @@ package pkg import ( + "ai_scheduler/internal/pkg/utils_langchain" "ai_scheduler/internal/pkg/utils_ollama" "github.com/google/wire" @@ -9,7 +10,7 @@ import ( var ProviderSetClient = wire.NewSet( NewRdb, NewGormDb, - utils_ollama.NewUtilOllama, + utils_langchain.NewUtilLangChain, utils_ollama.NewClient, NewSafeChannelPool, ) diff --git a/internal/pkg/utils_ollama/ollama.go b/internal/pkg/utils_langchain/client.go similarity index 88% rename from internal/pkg/utils_ollama/ollama.go rename to internal/pkg/utils_langchain/client.go index 6121838..914d290 100644 --- a/internal/pkg/utils_ollama/ollama.go +++ b/internal/pkg/utils_langchain/client.go @@ -1,4 +1,4 @@ -package utils_ollama +package utils_langchain import ( "ai_scheduler/internal/config" @@ -13,7 +13,7 @@ import ( "github.com/tmc/langchaingo/llms/ollama" ) -type UtilOllama struct { +type UtilLangChain struct { LlmClientPool *sync.Pool poolSize int // 记录池大小,用于调试 model string @@ -26,7 +26,7 @@ type LlmObj struct { Llm *ollama.LLM } -func NewUtilOllama(c *config.Config, logger log.AllLogger) *UtilOllama { +func NewUtilLangChain(c *config.Config, logger log.AllLogger) *UtilLangChain { poolSize := c.Sys.LlmPoolLen if poolSize <= 0 { poolSize = 10 // 默认值 @@ -60,7 +60,7 @@ func NewUtilOllama(c *config.Config, logger log.AllLogger) *UtilOllama { pool.Put(pool.New()) } - return &UtilOllama{ + return &UtilLangChain{ LlmClientPool: pool, poolSize: poolSize, model: c.Ollama.Model, @@ -69,7 +69,7 @@ func NewUtilOllama(c *config.Config, logger log.AllLogger) *UtilOllama { } -func (o *UtilOllama) NewClient() *ollama.LLM { +func (o *UtilLangChain) NewClient() *ollama.LLM { llm, _ := ollama.New( ollama.WithModel(o.c.Ollama.Model), ollama.WithHTTPClient(&http.Client{ @@ -91,13 +91,13 @@ func (o *UtilOllama) NewClient() *ollama.LLM { } // Get 返回一个可用的 LLM 客户端 -func (o *UtilOllama) Get() *LlmObj { +func (o *UtilLangChain) Get() *LlmObj { client := o.LlmClientPool.Get().(*LlmObj) return client } // Put 归还客户端(可选:检查是否仍可用) -func (o *UtilOllama) Put(llm *LlmObj) { +func (o *UtilLangChain) Put(llm *LlmObj) { if llm == nil { return } @@ -105,7 +105,7 @@ func (o *UtilOllama) Put(llm *LlmObj) { } // Stats 返回池的统计信息(用于监控) -func (o *UtilOllama) Stats() (current, max int) { +func (o *UtilLangChain) Stats() (current, max int) { return o.poolSize, o.poolSize } diff --git a/internal/tools/konwledge_base.go b/internal/tools/konwledge_base.go index 9aca733..897f686 100644 --- a/internal/tools/konwledge_base.go +++ b/internal/tools/konwledge_base.go @@ -5,13 +5,11 @@ import ( "ai_scheduler/internal/entitys" "ai_scheduler/internal/pkg/l_request" "bufio" + "context" "encoding/json" "fmt" "net/http" "strings" - - "github.com/gofiber/fiber/v2/log" - "github.com/gofiber/websocket/v2" ) // 知识库工具 @@ -60,22 +58,10 @@ func (k *KnowledgeBaseTool) Definition() entitys.ToolDefinition { } // Execute 执行知识库查询 -func (k *KnowledgeBaseTool) Execute(channel chan entitys.Response, c *websocket.Conn, args json.RawMessage, matchJson *entitys.Match) error { +func (k *KnowledgeBaseTool) Execute(ctx context.Context, requireData *entitys.RequireData) error { - var params KnowledgeBaseRequest - if err := json.Unmarshal(args, ¶ms); err != nil { - return fmt.Errorf("unmarshal args failed: %w", err) - } - log.Info("开始执行知识库 KnowledgeBaseTool Execute, params: %v", params) + return k.chat(requireData) - return k.chat(channel, c, params) - -} - -type KnowledgeBaseRequest struct { - Session string // 知识库会话id - ApiKey string // 知识库apiKey - Query string // 用户输入 } // Message 表示解析后的 SSE 消息 @@ -110,20 +96,20 @@ func (this *KnowledgeBaseTool) msgContentParse(input string, channel chan entity } // 请求知识库聊天 -func (this *KnowledgeBaseTool) chat(channel chan entitys.Response, c *websocket.Conn, param KnowledgeBaseRequest) (err error) { +func (this *KnowledgeBaseTool) chat(requireData *entitys.RequireData) (err error) { req := l_request.Request{ Method: "post", - Url: this.config.BaseURL + "/api/v1/knowledge-chat/" + param.Session, + Url: this.config.BaseURL + "/api/v1/knowledge-chat/" + requireData.KnowledgeConf.Session, Params: nil, Headers: map[string]string{ "Content-Type": "application/json", - "X-API-Key": param.ApiKey, + "X-API-Key": requireData.KnowledgeConf.ApiKey, }, Cookies: nil, Data: nil, Json: map[string]interface{}{ - "query": param.Query, + "query": requireData.KnowledgeConf.Query, }, Files: nil, Raw: "", @@ -137,7 +123,7 @@ func (this *KnowledgeBaseTool) chat(channel chan entitys.Response, c *websocket. } defer rsp.Body.Close() - err = this.connectAndReadSSE(rsp, channel) + err = this.connectAndReadSSE(rsp, requireData.Ch) if err != nil { return } diff --git a/internal/tools/manager.go b/internal/tools/manager.go index c061301..b31277b 100644 --- a/internal/tools/manager.go +++ b/internal/tools/manager.go @@ -5,11 +5,9 @@ import ( "ai_scheduler/internal/data/constants" "ai_scheduler/internal/entitys" "ai_scheduler/internal/pkg/utils_ollama" + "context" - "encoding/json" "fmt" - - "github.com/gofiber/websocket/v2" ) // Manager 工具管理器 @@ -100,13 +98,13 @@ func (m *Manager) GetToolDefinitions(caller constants.Caller) []entitys.ToolDefi } // ExecuteTool 执行工具 -func (m *Manager) ExecuteTool(channel chan entitys.Response, c *websocket.Conn, name string, args json.RawMessage, matchJson *entitys.Match) error { +func (m *Manager) ExecuteTool(ctx context.Context, name string, requireData *entitys.RequireData) error { tool, exists := m.GetTool(name) if !exists { return fmt.Errorf("tool not found: %s", name) } - return tool.Execute(channel, c, args, matchJson) + return tool.Execute(ctx, requireData) } // ExecuteToolCalls 执行多个工具调用 diff --git a/internal/tools/zltx_order_detail.go b/internal/tools/zltx_order_detail.go index bfe4ac0..84ff6f6 100644 --- a/internal/tools/zltx_order_detail.go +++ b/internal/tools/zltx_order_detail.go @@ -3,13 +3,13 @@ package tools import ( "ai_scheduler/internal/config" "ai_scheduler/internal/entitys" + "ai_scheduler/internal/pkg" "ai_scheduler/internal/pkg/utils_ollama" "context" "encoding/json" "fmt" "gitea.cdlsxd.cn/self-tools/l_request" - "github.com/gofiber/websocket/v2" "github.com/ollama/ollama/api" ) @@ -81,34 +81,26 @@ type ZltxOrderDetailData struct { } // Execute 执行直连天下订单详情查询 -func (w *ZltxOrderDetailTool) Execute(channel chan entitys.Response, c *websocket.Conn, args json.RawMessage, matchJson *entitys.Match) error { +func (w *ZltxOrderDetailTool) Execute(ctx context.Context, requireData *entitys.RequireData) error { var req ZltxOrderDetailRequest - if err := json.Unmarshal(args, &req); err != nil { + if err := json.Unmarshal([]byte(requireData.Match.Parameters), &req); err != nil { return fmt.Errorf("invalid zltxOrderDetail request: %w", err) } - if req.OrderNumber == "" { return fmt.Errorf("number is required") } // 这里可以集成真实的直连天下订单详情API - return w.getZltxOrderDetail(channel, c, req.OrderNumber, matchJson) + return w.getZltxOrderDetail(requireData, req.OrderNumber) } // getMockZltxOrderDetail 获取模拟直连天下订单详情数据 -func (w *ZltxOrderDetailTool) getZltxOrderDetail(ch chan entitys.Response, c *websocket.Conn, number string, matchJson *entitys.Match) (err error) { +func (w *ZltxOrderDetailTool) getZltxOrderDetail(requireData *entitys.RequireData, number string) (err error) { //查询订单详情 - var auth string - if c != nil { - auth = c.Query("x-authorization", "") - } - if len(auth) == 0 { - auth = w.config.APIKey - } req := l_request.Request{ Url: fmt.Sprintf("%szltx_api/admin/direct/ai/%s", w.config.BaseURL, number), Headers: map[string]string{ - "Authorization": fmt.Sprintf("Bearer %s", auth), + "Authorization": fmt.Sprintf("Bearer %s", requireData.Auth), }, Method: "GET", } @@ -129,13 +121,13 @@ func (w *ZltxOrderDetailTool) getZltxOrderDetail(ch chan entitys.Response, c *we if err = json.Unmarshal(res.Content, &resData); err != nil { return } - ch <- entitys.Response{ + requireData.Ch <- entitys.Response{ Index: w.Name(), Content: res.Text, Type: entitys.ResponseJson, } if resData.Data.Direct != nil && resData.Data.Direct["needAi"].(bool) { - ch <- entitys.Response{ + requireData.Ch <- entitys.Response{ Index: w.Name(), Content: "正在分析订单日志", Type: entitys.ResponseLoading, @@ -144,7 +136,7 @@ func (w *ZltxOrderDetailTool) getZltxOrderDetail(ch chan entitys.Response, c *we req = l_request.Request{ Url: fmt.Sprintf("%szltx_api/admin/direct/log/%s/%s", w.config.BaseURL, resData.Data.Direct["orderOrderNumber"].(string), resData.Data.Direct["serialNumber"].(string)), Headers: map[string]string{ - "Authorization": fmt.Sprintf("Bearer %s", auth), + "Authorization": fmt.Sprintf("Bearer %s", requireData.Auth), }, Method: "GET", } @@ -164,14 +156,14 @@ func (w *ZltxOrderDetailTool) getZltxOrderDetail(ch chan entitys.Response, c *we return fmt.Errorf("订单日志解析失败:%s", err) } - err = w.llm.ChatStream(context.TODO(), ch, []api.Message{ + err = w.llm.ChatStream(context.TODO(), requireData.Ch, []api.Message{ { Role: "system", Content: "你是一个订单日志助手。用户可能会提供订单日志,你需要分析订单日志,提取出订单失败的原因。", }, { Role: "assistant", - Content: fmt.Sprintf("聊天记录:%s", string(matchJson.History)), + Content: fmt.Sprintf("聊天记录:%s", pkg.JsonStringIgonErr(requireData.Histories)), }, { Role: "assistant", @@ -179,7 +171,7 @@ func (w *ZltxOrderDetailTool) getZltxOrderDetail(ch chan entitys.Response, c *we }, { Role: "user", - Content: matchJson.UserInput, + Content: requireData.UserInput, }, }, w.Name()) if err != nil { @@ -187,15 +179,11 @@ func (w *ZltxOrderDetailTool) getZltxOrderDetail(ch chan entitys.Response, c *we } } if resData.Data.Direct == nil { - ch <- entitys.Response{ + requireData.Ch <- entitys.Response{ Index: w.Name(), Content: "该订单无充值流水记录,不需要进行订单错误分析,建议检查订单收单模式以及扣款方式等原因😘", Type: entitys.ResponseText, } } - // else { - - //} - return } diff --git a/internal/tools/zltx_order_direct_log.go b/internal/tools/zltx_order_direct_log.go index 2ef61ba..d2a0528 100644 --- a/internal/tools/zltx_order_direct_log.go +++ b/internal/tools/zltx_order_direct_log.go @@ -3,11 +3,11 @@ package tools import ( "ai_scheduler/internal/config" "ai_scheduler/internal/entitys" + "context" "encoding/json" "fmt" "gitea.cdlsxd.cn/self-tools/l_request" - "github.com/gofiber/websocket/v2" ) type ZltxOrderLogTool struct { @@ -67,31 +67,25 @@ type ZltxOrderDirectLogData struct { Data map[string]interface{} `json:"data"` } -func (t *ZltxOrderLogTool) Execute(channel chan entitys.Response, c *websocket.Conn, args json.RawMessage, matchJson *entitys.Match) error { +func (t *ZltxOrderLogTool) Execute(ctx context.Context, requireData *entitys.RequireData) error { var req ZltxOrderLogRequest - if err := json.Unmarshal(args, &req); err != nil { + if err := json.Unmarshal([]byte(requireData.Match.Parameters), &req); err != nil { return fmt.Errorf("invalid zltxOrderLog request: %w", err) } if req.OrderNumber == "" || req.SerialNumber == "" { return fmt.Errorf("orderNumber and serialNumber is required") } - return t.getZltxOrderLog(channel, c, req.OrderNumber, req.SerialNumber) + return t.getZltxOrderLog(req.OrderNumber, req.SerialNumber, requireData) } -func (t *ZltxOrderLogTool) getZltxOrderLog(channel chan entitys.Response, c *websocket.Conn, orderNumber, serialNumber string) (err error) { +func (t *ZltxOrderLogTool) getZltxOrderLog(orderNumber, serialNumber string, requireData *entitys.RequireData) (err error) { //查询订单详情 - var auth string - if c != nil { - auth = c.Query("x-authorization", "") - } - if len(auth) == 0 { - auth = t.config.APIKey - } + url := fmt.Sprintf("%s%s/%s", t.config.BaseURL, orderNumber, serialNumber) req := l_request.Request{ Url: url, Headers: map[string]string{ - "Authorization": fmt.Sprintf("Bearer %s", auth), + "Authorization": fmt.Sprintf("Bearer %s", requireData.Auth), }, Method: "GET", } @@ -106,7 +100,7 @@ func (t *ZltxOrderLogTool) getZltxOrderLog(channel chan entitys.Response, c *web if err = json.Unmarshal(res.Content, &resData); err != nil { return } - channel <- entitys.Response{ + requireData.Ch <- entitys.Response{ Index: t.Name(), Content: res.Text, Type: entitys.ResponseJson, diff --git a/internal/tools/zltx_product.go b/internal/tools/zltx_product.go index d6801c0..1d7b346 100644 --- a/internal/tools/zltx_product.go +++ b/internal/tools/zltx_product.go @@ -3,13 +3,13 @@ package tools import ( "ai_scheduler/internal/config" "ai_scheduler/internal/entitys" + "context" "encoding/json" "fmt" "strconv" "strings" "gitea.cdlsxd.cn/self-tools/l_request" - "github.com/gofiber/websocket/v2" ) type ZltxProductTool struct { @@ -53,12 +53,12 @@ type ZltxProductRequest struct { Name string `json:"name"` } -func (z ZltxProductTool) Execute(channel chan entitys.Response, c *websocket.Conn, args json.RawMessage, matchJson *entitys.Match) error { +func (z ZltxProductTool) Execute(ctx context.Context, requireData *entitys.RequireData) error { var req ZltxProductRequest - if err := json.Unmarshal(args, &req); err != nil { + if err := json.Unmarshal([]byte(requireData.Match.Parameters), &req); err != nil { return fmt.Errorf("invalid zltxProduct request: %w", err) } - return z.getZltxProduct(channel, c, req.Id, req.Name) + return z.getZltxProduct(&req, requireData) } type ZltxProductResponse struct { @@ -133,22 +133,16 @@ type ZltxProductData struct { PlatformProductList interface{} `json:"platform_product_list"` } -func (z ZltxProductTool) getZltxProduct(channel chan entitys.Response, c *websocket.Conn, id string, name string) error { - var auth string - if c != nil { - auth = c.Query("x-authorization", "") - } - if len(auth) == 0 { - auth = z.config.APIKey - } +func (z ZltxProductTool) getZltxProduct(body *ZltxProductRequest, requireData *entitys.RequireData) error { + var Url string var params map[string]string - if id != "" { - Url = fmt.Sprintf("%s/%s", z.config.BaseURL, id) + if body.Id != "" { + Url = fmt.Sprintf("%s/%s", z.config.BaseURL, body.Id) } else { - Url = fmt.Sprintf("%s?keyword=%s&limit=10&page=1", z.config.BaseURL, name) + Url = fmt.Sprintf("%s?keyword=%s&limit=10&page=1", z.config.BaseURL, body.Name) params = map[string]string{ - "keyword": name, + "keyword": body.Name, "limit": "10", "page": "1", } @@ -159,7 +153,7 @@ func (z ZltxProductTool) getZltxProduct(channel chan entitys.Response, c *websoc //根据商品ID或名称走不同的接口查询 Url: Url, Headers: map[string]string{ - "Authorization": fmt.Sprintf("Bearer %s", auth), + "Authorization": fmt.Sprintf("Bearer %s", requireData.Auth), }, Params: params, Method: "GET", @@ -191,7 +185,7 @@ func (z ZltxProductTool) getZltxProduct(channel chan entitys.Response, c *websoc for i := range resp.Data.List { // 调用 平台商品列表 if resp.Data.List[i].AuthProductIds != "" { - platformProductList := z.ExecutePlatformProductList(auth, resp.Data.List[i].AuthProductIds, resp.Data.List[i].OfficialProductID) + platformProductList := z.ExecutePlatformProductList(requireData.Auth, resp.Data.List[i].AuthProductIds, resp.Data.List[i].OfficialProductID) resp.Data.List[i].PlatformProductList = platformProductList } } @@ -200,7 +194,7 @@ func (z ZltxProductTool) getZltxProduct(channel chan entitys.Response, c *websoc if err != nil { return err } - channel <- entitys.Response{ + requireData.Ch <- entitys.Response{ Index: z.Name(), Content: string(marshal), Type: entitys.ResponseJson, diff --git a/internal/tools/zltx_statistics.go b/internal/tools/zltx_statistics.go index 8d925ba..7601605 100644 --- a/internal/tools/zltx_statistics.go +++ b/internal/tools/zltx_statistics.go @@ -3,12 +3,12 @@ package tools import ( "ai_scheduler/internal/config" "ai_scheduler/internal/entitys" + "context" "encoding/json" "fmt" "sort" "gitea.cdlsxd.cn/self-tools/l_request" - "github.com/gofiber/websocket/v2" ) type ZltxOrderStatisticsTool struct { @@ -47,15 +47,15 @@ type ZltxOrderStatisticsRequest struct { Number string `json:"number"` } -func (z ZltxOrderStatisticsTool) Execute(channel chan entitys.Response, c *websocket.Conn, args json.RawMessage, matchJson *entitys.Match) error { +func (z ZltxOrderStatisticsTool) Execute(ctx context.Context, requireData *entitys.RequireData) error { var req ZltxOrderStatisticsRequest - if err := json.Unmarshal(args, &req); err != nil { + if err := json.Unmarshal([]byte(requireData.Match.Parameters), &req); err != nil { return err } if req.Number == "" { return fmt.Errorf("number is required") } - return z.getZltxOrderStatistics(channel, c, req.Number) + return z.getZltxOrderStatistics(req.Number, requireData) } type ZltxOrderStatisticsResponse struct { @@ -75,20 +75,14 @@ type ZltxOrderStatisticsData struct { Total int `json:"total"` } -func (z ZltxOrderStatisticsTool) getZltxOrderStatistics(channel chan entitys.Response, c *websocket.Conn, number string) error { +func (z ZltxOrderStatisticsTool) getZltxOrderStatistics(number string, requireData *entitys.RequireData) error { //查询订单详情 - var auth string - if c != nil { - auth = c.Query("x-authorization", "") - } - if len(auth) == 0 { - auth = z.config.APIKey - } + url := fmt.Sprintf("%s%s", z.config.BaseURL, number) req := l_request.Request{ Url: url, Headers: map[string]string{ - "Authorization": fmt.Sprintf("Bearer %s", auth), + "Authorization": fmt.Sprintf("Bearer %s", requireData.Auth), }, Method: "GET", } @@ -114,7 +108,7 @@ func (z ZltxOrderStatisticsTool) getZltxOrderStatistics(channel chan entitys.Res if err != nil { return err } - channel <- entitys.Response{ + requireData.Ch <- entitys.Response{ Index: z.Name(), Content: string(jsonByte), Type: entitys.ResponseJson,