diff --git a/cmd/server/wire.go b/cmd/server/wire.go index 79f3667..22d62b3 100644 --- a/cmd/server/wire.go +++ b/cmd/server/wire.go @@ -6,6 +6,7 @@ package main import ( "ai_scheduler/internal/biz" "ai_scheduler/internal/biz/handle/dingtalk" + "ai_scheduler/internal/biz/tools_regis" "ai_scheduler/internal/config" "ai_scheduler/internal/data/impl" "ai_scheduler/internal/domain/workflow" @@ -33,6 +34,7 @@ func InitializeApp(*config.Config, log.AllLogger) (*server.Servers, func(), erro utils.ProviderUtils, tools_bot.ProviderSetBotTools, dingtalk.ProviderSetDingTalk, + tools_regis.ProviderToolsRegis, )) } diff --git a/internal/biz/ding_talk_bot.go b/internal/biz/ding_talk_bot.go index c67c8e5..ca7431b 100644 --- a/internal/biz/ding_talk_bot.go +++ b/internal/biz/ding_talk_bot.go @@ -3,16 +3,23 @@ package biz import ( "ai_scheduler/internal/biz/do" "ai_scheduler/internal/biz/handle/dingtalk" + "ai_scheduler/internal/biz/tools_regis" "ai_scheduler/internal/data/constants" "ai_scheduler/internal/data/impl" "ai_scheduler/internal/data/model" "ai_scheduler/internal/entitys" + "ai_scheduler/internal/tools" + "context" + "database/sql" "encoding/json" + "errors" "fmt" + "strings" "github.com/gofiber/fiber/v2/log" "github.com/open-dingtalk/dingtalk-stream-sdk-go/chatbot" + "xorm.io/builder" ) @@ -24,6 +31,9 @@ type DingTalkBotBiz struct { replier *chatbot.ChatbotReplier log log.Logger dingTalkUser *dingtalk.User + botTools []model.AiBotTool + botGroupImpl *impl.BotGroupImpl + toolManager *tools.Manager } // NewDingTalkBotBiz @@ -31,8 +41,10 @@ func NewDingTalkBotBiz( do *do.Do, handle *do.Handle, botConfigImpl *impl.BotConfigImpl, + botGroupImpl *impl.BotGroupImpl, dingTalkUser *dingtalk.User, - + tools *tools_regis.ToolRegis, + toolManager *tools.Manager, ) *DingTalkBotBiz { return &DingTalkBotBiz{ do: do, @@ -40,6 +52,9 @@ func NewDingTalkBotBiz( botConfigImpl: botConfigImpl, replier: chatbot.NewChatbotReplier(), dingTalkUser: dingTalkUser, + botTools: tools.BootTools, + botGroupImpl: botGroupImpl, + toolManager: toolManager, } } @@ -68,18 +83,201 @@ func (d *DingTalkBotBiz) InitRequire(ctx context.Context, data *chatbot.BotCallb Req: data, Ch: make(chan entitys.Response, 2), } - entitys.ResLog(requireData.Ch, "recognize_start", "收到消息,正在处理中,请稍等") - - requireData.UserInfo, err = d.dingTalkUser.GetUserInfoFromBot(ctx, data.SenderStaffId, dingtalk.WithId(1)) return } -func (d *DingTalkBotBiz) Recognize(ctx context.Context, bot *chatbot.BotCallbackDataModel) (match entitys.Match, err error) { - - return d.handle.Recognize(ctx, nil, &do.WithDingTalkBot{}) +func (d *DingTalkBotBiz) Do(ctx context.Context, requireData *entitys.RequireDataDingTalkBot) (err error) { + entitys.ResText(requireData.Ch, "", "收到消息,正在处理中,请稍等") + defer close(requireData.Ch) + switch constants.ConversationType(requireData.Req.ConversationType) { + case constants.ConversationTypeSingle: + err = d.handleSingleChat(ctx, requireData) + case constants.ConversationTypeGroup: + err = d.handleGroupChat(ctx, requireData) + default: + err = errors.New("未知的聊天类型:" + requireData.Req.ConversationType) + } + return } +func (d *DingTalkBotBiz) handleSingleChat(ctx context.Context, requireData *entitys.RequireDataDingTalkBot) (err error) { + entitys.ResLog(requireData.Ch, "", "个人聊天暂未开启,请期待后续更新") + return + //requireData.UserInfo, err = d.dingTalkUser.GetUserInfoFromBot(ctx, requireData.Req.SenderStaffId, dingtalk.WithId(1)) + //if err != nil { + // return + //} + ////如果不是管理或者不是老板,则进行权限判断 + //if requireData.UserInfo.IsSenior == constants.IsSeniorFalse && requireData.UserInfo.IsBoss == constants.IsBossFalse { + // + //} + //return +} + +func (d *DingTalkBotBiz) handleGroupChat(ctx context.Context, requireData *entitys.RequireDataDingTalkBot) (err error) { + group, err := d.initGroup(ctx, requireData.Req.ConversationId, requireData.Req.ConversationTitle) + if err != nil { + return + } + groupTools, err := d.getGroupTools(ctx, group) + if err != nil { + return + } + rec, err := d.recognize(ctx, requireData, groupTools) + if err != nil { + return + } + + return d.handleMatch(ctx, rec) +} + +func (d *DingTalkBotBiz) initGroup(ctx context.Context, conversationId string, conversationTitle string) (group *model.AiBotGroup, err error) { + group, err = d.botGroupImpl.GetByConversationId(conversationId) + if err != nil { + if !errors.Is(err, sql.ErrNoRows) { + + return + } + } + + if group.GroupID == 0 { + group = &model.AiBotGroup{ + ConversationID: conversationId, + Title: conversationTitle, + ToolList: "", + } + //如果不存在则创建 + d.botGroupImpl.Add(group) + } + return +} + +func (d *DingTalkBotBiz) getGroupTools(ctx context.Context, group *model.AiBotGroup) (tools []model.AiBotTool, err error) { + if len(d.botTools) == 0 { + return + } + var ( + groupRegisTools map[string]struct{} + ) + if group.ToolList != "" { + groupList := strings.Split(group.ToolList, ",") + for _, tool := range groupList { + groupRegisTools[tool] = struct{}{} + } + } + + for _, v := range d.botTools { + if v.PermissionType == constants.PermissionTypeNone { + tools = append(tools, v) + continue + } + if _, ex := groupRegisTools[v.Index]; ex { + tools = append(tools, v) + } + } + return +} +func (d *DingTalkBotBiz) recognize(ctx context.Context, requireData *entitys.RequireDataDingTalkBot, tools []model.AiBotTool) (rec *entitys.Recognize, err error) { + + userContent, err := d.getUserContent(requireData.Req.Msgtype, requireData.Req.Text.Content) + if err != nil { + return nil, err + } + rec = &entitys.Recognize{ + Ch: requireData.Ch, + SystemPrompt: d.defaultPrompt(), + UserContent: userContent, + } + if len(tools) > 0 { + rec.Tasks = make([]entitys.RegistrationTask, 0, len(tools)) + for _, task := range tools { + taskConfig := entitys.TaskConfigDetail{} + if err = json.Unmarshal([]byte(task.Config), &taskConfig); err != nil { + log.Errorf("解析任务配置失败: %s, 任务ID: %s", err.Error(), task.Index) + continue // 解析失败时跳过该任务,而不是直接返回错误 + } + + rec.Tasks = append(rec.Tasks, entitys.RegistrationTask{ + Name: task.Index, + Desc: task.Desc, + TaskConfigDetail: taskConfig, // 直接使用解析后的配置,避免重复构建 + }) + } + } + err = d.handle.Recognize(ctx, rec, &do.WithDingTalkBot{}) + return +} + +func (d *DingTalkBotBiz) getUserContent(msgType string, msgContent interface{}) (content *entitys.RecognizeUserContent, err error) { + switch constants.BotMsgType(msgType) { + case constants.BotMsgTypeText: + content = &entitys.RecognizeUserContent{ + Text: msgContent.(string), + } + default: + return nil, errors.New("未知的消息类型:" + msgType) + } + return +} + +func (d *DingTalkBotBiz) defaultPrompt() string { + + return `{"system":"智能路由系统,精准解析用户意图并路由至任务模块,遵循以下规则:","rule":{"返回格式":"{\\"index\\":\\"工具索引\\",\\"confidence\\":\\"0.0-1.0\\",\\"reasoning\\":\\"判断理由\\",\\"parameters\\":\\"转义JSON参数\\",\\"is_match\\":true|false,\\"chat\\":\\"追问内容\\"}","工具匹配":["用工具parameters匹配,区分必选(required)和可选(optional)参数","无法匹配时,is_match=false,chat提醒用户适用工具(例:'请问您要查询订单还是商品?')"],"参数提取":["从用户输入提取parameters中明确提及的参数","必须参数仅用用户直接提及的,缺失时is_match=false,chat提醒补充(例:'需补充XX信息')"],"格式要求":["所有字段值为字符串(含confidence)","parameters为转义JSON字符串(如\\"{\\\\"key\\\\":\\\\"value\\\\"}\\")"]}}` +} + +func (d *DingTalkBotBiz) handleMatch(ctx context.Context, rec *entitys.Recognize) (err error) { + + if !rec.Match.IsMatch { + if len(rec.Match.Chat) != 0 { + entitys.ResText(rec.Ch, "", rec.Match.Chat) + } else { + entitys.ResText(rec.Ch, "", rec.Match.Reasoning) + } + return + } + var pointTask *model.AiBotTool + for _, task := range d.botTools { + if task.Index == rec.Match.Index { + pointTask = &task + + break + } + } + + if pointTask == nil || pointTask.Index == "other" { + return d.otherTask(ctx, rec) + } + switch constants.TaskType(pointTask.Type) { + //case constants.TaskTypeApi: + //return d.handleApiTask(ctx, requireData, pointTask) + case constants.TaskTypeFunc: + return d.handleTask(ctx, rec, pointTask) + default: + return d.otherTask(ctx, rec) + } + return +} + +func (d *DingTalkBotBiz) handleTask(ctx context.Context, rec *entitys.Recognize, task *model.AiBotTool) (err error) { + var configData entitys.ConfigDataTool + err = json.Unmarshal([]byte(task.Config), &configData) + if err != nil { + return + } + + err = d.toolManager.ExecuteTool(ctx, configData.Tool, rec) + if err != nil { + return + } + + return +} + +func (d *DingTalkBotBiz) otherTask(ctx context.Context, rec *entitys.Recognize) (err error) { + entitys.ResText(rec.Ch, "", rec.Match.Reasoning) + return +} func (d *DingTalkBotBiz) HandleRes(ctx context.Context, data *chatbot.BotCallbackDataModel, resp entitys.Response) error { switch resp.Type { case entitys.ResponseText: diff --git a/internal/biz/do/ctx.go b/internal/biz/do/ctx.go index 80deb73..4b5d2ec 100644 --- a/internal/biz/do/ctx.go +++ b/internal/biz/do/ctx.go @@ -225,15 +225,6 @@ func (d *Do) GetSysInfo(requireData *entitys.RequireData) (sysInfo model.AiSy, e return } -func (d *Do) GetSysInfoForDingTalkBot(requireData *entitys.RequireDataDingTalkBot) (sysInfo model.AiSy, err error) { - cond := builder.NewCond() - cond = cond.And(builder.Eq{"app_key": requireData.Auth}) - cond = cond.And(builder.IsNull{"delete_at"}) - cond = cond.And(builder.Eq{"status": 1}) - err = d.sysImpl.GetOneBySearchToStrut(&cond, &sysInfo) - return -} - func (d *Do) getSessionChatHis(requireData *entitys.RequireData) (his []model.AiChatHi, err error) { cond := builder.NewCond() diff --git a/internal/biz/do/handle.go b/internal/biz/do/handle.go index 8a9162d..76de3d0 100644 --- a/internal/biz/do/handle.go +++ b/internal/biz/do/handle.go @@ -12,10 +12,11 @@ import ( "ai_scheduler/internal/gateway" "ai_scheduler/internal/pkg/l_request" "ai_scheduler/internal/pkg/mapstructure" + "ai_scheduler/internal/pkg/rec_extra" "ai_scheduler/internal/pkg/util" "ai_scheduler/internal/tools" "ai_scheduler/internal/tools/public" - "ai_scheduler/internal/tools_bot" + "context" "encoding/json" "fmt" @@ -25,9 +26,9 @@ import ( ) type Handle struct { - Ollama *llm_service.OllamaService - toolManager *tools.Manager - Bot *tools_bot.BotTool + Ollama *llm_service.OllamaService + toolManager *tools.Manager + conf *config.Config sessionImpl *impl.SessionImpl workflowManager *runtime.Registry @@ -38,20 +39,20 @@ func NewHandle( toolManager *tools.Manager, conf *config.Config, sessionImpl *impl.SessionImpl, - dTalkBot *tools_bot.BotTool, + workflowManager *runtime.Registry, ) *Handle { return &Handle{ - Ollama: Ollama, - toolManager: toolManager, - conf: conf, - sessionImpl: sessionImpl, - Bot: dTalkBot, + Ollama: Ollama, + toolManager: toolManager, + conf: conf, + sessionImpl: sessionImpl, + workflowManager: workflowManager, } } -func (r *Handle) Recognize(ctx context.Context, rec *entitys.Recognize, promptProcessor PromptOption) (match entitys.Match, err error) { +func (r *Handle) Recognize(ctx context.Context, rec *entitys.Recognize, promptProcessor PromptOption) (err error) { entitys.ResLog(rec.Ch, "recognize_start", "准备意图识别") prompt, err := promptProcessor.CreatePrompt(ctx, rec) @@ -65,11 +66,13 @@ func (r *Handle) Recognize(ctx context.Context, rec *entitys.Recognize, promptPr } entitys.ResLog(rec.Ch, "recognize", recognizeMsg) entitys.ResLog(rec.Ch, "recognize_end", "意图识别结束") - + var match entitys.Match if err = json.Unmarshal([]byte(recognizeMsg), &match); err != nil { err = errors.SysErr("数据结构错误:%v", err.Error()) return } + rec.Match = &match + return } @@ -78,28 +81,27 @@ func (r *Handle) handleOtherTask(ctx context.Context, requireData *entitys.Requi return } -func (r *Handle) HandleMatch(ctx context.Context, client *gateway.Client, requireData *entitys.RequireData) (err error) { +func (r *Handle) HandleMatch(ctx context.Context, client *gateway.Client, rec *entitys.Recognize, requireData *entitys.RequireData) (err error) { - if !requireData.Match.IsMatch { - if len(requireData.Match.Chat) != 0 { - entitys.ResText(requireData.Ch, "", requireData.Match.Chat) + if !rec.Match.IsMatch { + if len(rec.Match.Chat) != 0 { + entitys.ResText(rec.Ch, "", rec.Match.Chat) } else { - entitys.ResText(requireData.Ch, "", requireData.Match.Reasoning) + entitys.ResText(rec.Ch, "", rec.Match.Reasoning) } return } var pointTask *model.AiTask for _, task := range requireData.Tasks { - if task.Index == requireData.Match.Index { + if task.Index == rec.Match.Index { pointTask = &task - requireData.Task = task break } } if pointTask == nil || pointTask.Index == "other" { - return r.OtherTask(ctx, requireData) + return r.OtherTask(ctx, rec) } // 校验用户权限 @@ -110,47 +112,32 @@ func (r *Handle) HandleMatch(ctx context.Context, client *gateway.Client, requir switch constants.TaskType(pointTask.Type) { case constants.TaskTypeApi: - return r.handleApiTask(ctx, requireData, pointTask) + return r.handleApiTask(ctx, rec, pointTask) case constants.TaskTypeFunc: - return r.handleTask(ctx, requireData, pointTask) + return r.handleTask(ctx, rec, pointTask) case constants.TaskTypeKnowle: - return r.handleKnowle(ctx, requireData, pointTask) - case constants.TaskTypeBot: - return r.handleBot(ctx, requireData, pointTask) + return r.handleKnowle(ctx, rec, pointTask) + case constants.TaskTypeEinoWorkflow: - return r.handleEinoWorkflow(ctx, requireData, pointTask) + return r.handleEinoWorkflow(ctx, rec, pointTask) default: return r.handleOtherTask(ctx, requireData) } } -func (r *Handle) OtherTask(ctx context.Context, requireData *entitys.RequireData) (err error) { +func (r *Handle) OtherTask(ctx context.Context, requireData *entitys.Recognize) (err error) { entitys.ResText(requireData.Ch, "", requireData.Match.Reasoning) return } -func (r *Handle) handleBot(ctx context.Context, requireData *entitys.RequireData, task *model.AiTask) (err error) { - var configData entitys.ConfigDataTool - err = json.Unmarshal([]byte(task.Config), &configData) - if err != nil { - return - } - err = r.Bot.Execute(ctx, configData.Tool, requireData) - if err != nil { - return - } - - return -} - -func (r *Handle) handleTask(ctx context.Context, requireData *entitys.RequireData, task *model.AiTask) (err error) { +func (r *Handle) handleTask(ctx context.Context, rec *entitys.Recognize, task *model.AiTask) (err error) { var configData entitys.ConfigDataTool err = json.Unmarshal([]byte(task.Config), &configData) if err != nil { return } - err = r.toolManager.ExecuteTool(ctx, configData.Tool, requireData) + err = r.toolManager.ExecuteTool(ctx, configData.Tool, rec) if err != nil { return } @@ -159,7 +146,7 @@ func (r *Handle) handleTask(ctx context.Context, requireData *entitys.RequireDat } // 知识库 -func (r *Handle) handleKnowle(ctx context.Context, requireData *entitys.RequireData, task *model.AiTask) (err error) { +func (r *Handle) handleKnowle(ctx context.Context, rec *entitys.Recognize, task *model.AiTask) (err error) { var ( configData entitys.ConfigDataTool @@ -171,13 +158,16 @@ func (r *Handle) handleKnowle(ctx context.Context, requireData *entitys.RequireD if err != nil { return } - + ext, err := rec_extra.GetTaskRecExt(rec) + if err != nil { + return + } // 通过session 找到知识库session var has bool - if len(requireData.Session) == 0 { + if len(ext.Session) == 0 { return errors.SessionNotFound } - requireData.SessionInfo, has, err = r.sessionImpl.FindOne(r.sessionImpl.WithSessionId(requireData.Session)) + ext.SessionInfo, has, err = r.sessionImpl.FindOne(r.sessionImpl.WithSessionId(ext.Session)) if err != nil { return } else if !has { @@ -200,15 +190,15 @@ func (r *Handle) handleKnowle(ctx context.Context, requireData *entitys.RequireD } // 知识库的session为空,请求知识库获取, 并绑定 - if requireData.SessionInfo.KnowlegeSessionID == "" { + if ext.SessionInfo.KnowlegeSessionID == "" { // 请求知识库 - if sessionIdKnowledge, err = public.GetKnowledgeBaseSession(host, requireData.Sys.KnowlegeBaseID, requireData.Sys.KnowlegeTenantKey); err != nil { + if sessionIdKnowledge, err = public.GetKnowledgeBaseSession(host, ext.Sys.KnowlegeBaseID, ext.Sys.KnowlegeTenantKey); err != nil { return } // 绑定知识库session,下次可以使用 - requireData.SessionInfo.KnowlegeSessionID = sessionIdKnowledge - if err = r.sessionImpl.Update(&requireData.SessionInfo, r.sessionImpl.WithSessionId(requireData.SessionInfo.SessionID)); err != nil { + ext.SessionInfo.KnowlegeSessionID = sessionIdKnowledge + if err = r.sessionImpl.Update(&ext.SessionInfo, r.sessionImpl.WithSessionId(ext.SessionInfo.SessionID)); err != nil { return } } @@ -216,21 +206,21 @@ func (r *Handle) handleKnowle(ctx context.Context, requireData *entitys.RequireD // 用户输入解析 var ok bool input := make(map[string]string) - if err = json.Unmarshal([]byte(requireData.Match.Parameters), &input); err != nil { + if err = json.Unmarshal([]byte(rec.Match.Parameters), &input); err != nil { return } if query, ok = input["query"]; !ok { return fmt.Errorf("query不能为空") } - requireData.KnowledgeConf = entitys.KnowledgeBaseRequest{ - Session: requireData.SessionInfo.KnowlegeSessionID, - ApiKey: requireData.Sys.KnowlegeTenantKey, + ext.KnowledgeConf = entitys.KnowledgeBaseRequest{ + Session: ext.SessionInfo.KnowlegeSessionID, + ApiKey: ext.Sys.KnowlegeTenantKey, Query: query, } // 执行工具 - err = r.toolManager.ExecuteTool(ctx, configData.Tool, requireData) + err = r.toolManager.ExecuteTool(ctx, configData.Tool, rec) if err != nil { return } @@ -238,17 +228,21 @@ func (r *Handle) handleKnowle(ctx context.Context, requireData *entitys.RequireD return } -func (r *Handle) handleApiTask(ctx context.Context, requireData *entitys.RequireData, task *model.AiTask) (err error) { +func (r *Handle) handleApiTask(ctx context.Context, rec *entitys.Recognize, task *model.AiTask) (err error) { var ( request l_request.Request requestParam map[string]interface{} ) - err = json.Unmarshal([]byte(requireData.Match.Parameters), &requestParam) + ext, err := rec_extra.GetTaskRecExt(rec) + if err != nil { + return + } + err = json.Unmarshal([]byte(rec.Match.Parameters), &requestParam) if err != nil { return } // request.Url = strings.ReplaceAll(task.Config, "${authorization}", requireData.Auth) - task.Config = strings.ReplaceAll(task.Config, "${authorization}", requireData.Auth) + task.Config = strings.ReplaceAll(task.Config, "${authorization}", ext.Auth) for k, v := range requestParam { if vStr, ok := v.(string); ok { task.Config = strings.ReplaceAll(task.Config, "${"+k+"}", vStr) @@ -275,27 +269,31 @@ func (r *Handle) handleApiTask(ctx context.Context, requireData *entitys.Require return } - entitys.ResLoading(requireData.Ch, requireData.Task.Index, "正在请求数据") + entitys.ResLoading(rec.Ch, task.Index, "正在请求数据") res, err := request.Send() if err != nil { return } - entitys.ResJson(requireData.Ch, requireData.Task.Index, res.Text) + entitys.ResJson(rec.Ch, task.Index, res.Text) return } // eino 工作流 -func (r *Handle) handleEinoWorkflow(ctx context.Context, requireData *entitys.RequireData, task *model.AiTask) (err error) { +func (r *Handle) handleEinoWorkflow(ctx context.Context, rec *entitys.Recognize, task *model.AiTask) (err error) { // token 写入ctx - ctx = util.SetTokenToContext(ctx, requireData.Auth) + ext, err := rec_extra.GetTaskRecExt(rec) + if err != nil { + return + } + ctx = util.SetTokenToContext(ctx, ext.Auth) - entitys.ResLoading(requireData.Ch, requireData.Task.Index, "正在执行工作流") + entitys.ResLoading(rec.Ch, task.Index, "正在执行工作流") // 工作流内部输出 workflowId := task.Index - _, err = r.workflowManager.Invoke(ctx, workflowId, requireData) + _, err = r.workflowManager.Invoke(ctx, workflowId, rec) if err != nil { return err } diff --git a/internal/biz/handle/dingtalk/dept.go b/internal/biz/handle/dingtalk/dept.go index 0a6bd15..df8a607 100644 --- a/internal/biz/handle/dingtalk/dept.go +++ b/internal/biz/handle/dingtalk/dept.go @@ -41,8 +41,9 @@ func (d *Dept) GetDeptInfoByDeptIds(ctx context.Context, deptIds []int, authInfo var existDept = make([]int, len(deptsInfo), 0) for _, dept := range deptsInfo { depts = append(depts, &entitys.Dept{ - DeptId: int(dept.DeptID), - Name: dept.Name, + DeptId: int(dept.DeptID), + Name: dept.Name, + ToolList: dept.ToolList, }) existDept = append(existDept, int(dept.DeptID)) } diff --git a/internal/biz/handle/dingtalk/user.go b/internal/biz/handle/dingtalk/user.go index 900a492..2e4e615 100644 --- a/internal/biz/handle/dingtalk/user.go +++ b/internal/biz/handle/dingtalk/user.go @@ -46,12 +46,14 @@ func (u *User) GetUserInfoFromBot(ctx context.Context, staffId string, botOption return } } + //待优化 authInfo, err := u.auth.GetTokenFromBotOption(ctx, botOption...) if err != nil || authInfo == nil { return } //如果没有找到,则新增 if user == nil { + DingUserInfo, _err := u.getUserInfoFromDingTalk(ctx, authInfo.AccessToken, staffId) if _err != nil { return nil, _err diff --git a/internal/biz/router.go b/internal/biz/router.go index 5127a54..28ce0b8 100644 --- a/internal/biz/router.go +++ b/internal/biz/router.go @@ -5,6 +5,8 @@ import ( "ai_scheduler/internal/data/constants" errors "ai_scheduler/internal/data/error" "ai_scheduler/internal/gateway" + + "ai_scheduler/internal/pkg/rec_extra" "context" "encoding/json" "strings" @@ -68,14 +70,15 @@ func (r *AiRouterBiz) RouteWithSocket(client *gateway.Client, req *entitys.ChatS } //意图识别 - requireData.Match, err = r.handle.Recognize(ctx, &rec, sys) + err = r.handle.Recognize(ctx, &rec, sys) if err != nil { log.Errorf("意图识别失败: %s", err.Error()) return } - + //任务处理 + rec_extra.SetTaskRecExt(requireData, &rec) //向下传递 - if err = r.handle.HandleMatch(ctx, client, requireData); err != nil { + if err = r.handle.HandleMatch(ctx, client, &rec, requireData); err != nil { log.Errorf("任务处理失败: %s", err.Error()) return } diff --git a/internal/biz/tools_regis/provider_set.go b/internal/biz/tools_regis/provider_set.go new file mode 100644 index 0000000..8294cf0 --- /dev/null +++ b/internal/biz/tools_regis/provider_set.go @@ -0,0 +1,9 @@ +package tools_regis + +import ( + "github.com/google/wire" +) + +var ProviderToolsRegis = wire.NewSet( + NewToolsRegis, +) diff --git a/internal/biz/tools_regis/tools_regis.go b/internal/biz/tools_regis/tools_regis.go new file mode 100644 index 0000000..0109849 --- /dev/null +++ b/internal/biz/tools_regis/tools_regis.go @@ -0,0 +1,30 @@ +package tools_regis + +import ( + "ai_scheduler/internal/data/constants" + "ai_scheduler/internal/data/impl" + "ai_scheduler/internal/data/model" + + "xorm.io/builder" +) + +type ToolRegis struct { + //待优化 + BootTools []model.AiBotTool +} + +func NewToolsRegis(botToolsImpl *impl.BotToolsImpl) *ToolRegis { + botTools := &ToolRegis{} + err := botTools.RegisTools(botToolsImpl) + if err != nil { + panic(err) + } + return botTools +} + +func (t *ToolRegis) RegisTools(botToolsImpl *impl.BotToolsImpl) error { + cond := builder.NewCond() + cond = cond.And(builder.Eq{"status": constants.Enable}) + err := botToolsImpl.GetRangeToMapStruct(&cond, &t.BootTools) + return err +} diff --git a/internal/data/constants/bot.go b/internal/data/constants/bot.go index 446519a..78a46f1 100644 --- a/internal/data/constants/bot.go +++ b/internal/data/constants/bot.go @@ -33,3 +33,11 @@ const ( ) const DingTalkAuthBaseKeyPrefix = "dingTalk_auth" + +// PermissionType 工具使用权限 +type PermissionType int32 + +const ( + PermissionTypeNone = 1 + PermissionTypeDept = 2 +) diff --git a/internal/data/constants/dingtalk.go b/internal/data/constants/dingtalk.go index 5c9de89..c6d55b0 100644 --- a/internal/data/constants/dingtalk.go +++ b/internal/data/constants/dingtalk.go @@ -36,3 +36,16 @@ const ( IsSeniorTrue IsSenior = 1 IsSeniorFalse IsSenior = 0 ) + +type ConversationType string + +const ( + ConversationTypeSingle = "1" // 单聊 + ConversationTypeGroup = "2" //群聊 +) + +type BotMsgType string + +const ( + BotMsgTypeText BotMsgType = "text" +) diff --git a/internal/data/impl/bot_group.go b/internal/data/impl/bot_group.go new file mode 100644 index 0000000..4382d82 --- /dev/null +++ b/internal/data/impl/bot_group.go @@ -0,0 +1,27 @@ +package impl + +import ( + "ai_scheduler/internal/data/model" + "ai_scheduler/tmpl/dataTemp" + "ai_scheduler/utils" + "database/sql" +) + +type BotGroupImpl struct { + dataTemp.DataTemp +} + +func NewBotGroupImpl(db *utils.Db) *BotGroupImpl { + return &BotGroupImpl{ + DataTemp: *dataTemp.NewDataTemp(db, new(model.AiBotGroup)), + } +} + +func (k BotGroupImpl) GetByConversationId(staffId string) (*model.AiBotGroup, error) { + var data model.AiBotGroup + err := k.Db.Model(k.Model).Where("conversation_id = ?", staffId).Find(&data).Error + if data.GroupID == 0 { + err = sql.ErrNoRows + } + return &data, err +} diff --git a/internal/data/impl/bot_tools.go b/internal/data/impl/bot_tools.go new file mode 100644 index 0000000..d119098 --- /dev/null +++ b/internal/data/impl/bot_tools.go @@ -0,0 +1,17 @@ +package impl + +import ( + "ai_scheduler/internal/data/model" + "ai_scheduler/tmpl/dataTemp" + "ai_scheduler/utils" +) + +type BotToolsImpl struct { + dataTemp.DataTemp +} + +func NewBotToolsImpl(db *utils.Db) *BotToolsImpl { + return &BotToolsImpl{ + DataTemp: *dataTemp.NewDataTemp(db, new(model.AiBotTool)), + } +} diff --git a/internal/data/impl/bot_user.go b/internal/data/impl/bot_user.go index 46d8442..862292f 100644 --- a/internal/data/impl/bot_user.go +++ b/internal/data/impl/bot_user.go @@ -17,11 +17,11 @@ func NewBotUserImpl(db *utils.Db) *BotUserImpl { } } -func (k BotUserImpl) GetByStaffId(staffId string) (data *model.AiBotUser, err error) { - - err = k.Db.Model(k.Model).Where("staff_id = ?", staffId).Find(data).Error - if data == nil { +func (k BotUserImpl) GetByStaffId(staffId string) (*model.AiBotUser, error) { + var data model.AiBotUser + err := k.Db.Model(k.Model).Where("staff_id = ?", staffId).Find(&data).Error + if data.UserID == 0 { err = sql.ErrNoRows } - return + return &data, err } diff --git a/internal/data/impl/provider_set.go b/internal/data/impl/provider_set.go index 1563258..5624b3e 100644 --- a/internal/data/impl/provider_set.go +++ b/internal/data/impl/provider_set.go @@ -13,4 +13,6 @@ var ProviderImpl = wire.NewSet( NewBotDeptImpl, NewBotUserImpl, NewBotChatHisImpl, + NewBotToolsImpl, + NewBotGroupImpl, ) diff --git a/internal/data/model/ai_bot_dept.gen.go b/internal/data/model/ai_bot_dept.gen.go index db8ba60..ceddcce 100644 --- a/internal/data/model/ai_bot_dept.gen.go +++ b/internal/data/model/ai_bot_dept.gen.go @@ -12,12 +12,13 @@ const TableNameAiBotDept = "ai_bot_dept" // AiBotDept mapped from table type AiBotDept struct { - DeptID int32 `gorm:"column:dept_id;primaryKey" json:"dept_id"` - DingtalkDeptID int32 `gorm:"column:dingtalk_dept_id;not null;comment:标记部门的唯一id,钉钉:钉钉侧提供的dept_id" json:"dingtalk_dept_id"` // 标记部门的唯一id,钉钉:钉钉侧提供的dept_id - Name string `gorm:"column:name;not null;comment:用户名称" json:"name"` // 用户名称 - Status int32 `gorm:"column:status;not null" json:"status"` - DeleteAt time.Time `gorm:"column:delete_at" json:"delete_at"` - CreateAt time.Time `gorm:"column:create_at;default:CURRENT_TIMESTAMP" json:"create_at"` + DeptID int32 `gorm:"column:dept_id;primaryKey;autoIncrement:true" json:"dept_id"` + DingtalkDeptID int32 `gorm:"column:dingtalk_dept_id;not null;comment:标记部门的唯一id,钉钉:钉钉侧提供的dept_id" json:"dingtalk_dept_id"` // 标记部门的唯一id,钉钉:钉钉侧提供的dept_id + Name string `gorm:"column:name;not null;comment:用户名称" json:"name"` // 用户名称 + ToolList string `gorm:"column:tool_list;not null;comment:该部门支持的权限" json:"tool_list"` // 该部门支持的权限 + Status int32 `gorm:"column:status;not null;default:1" json:"status"` + DeleteAt *time.Time `gorm:"column:delete_at" json:"delete_at"` + CreateAt time.Time `gorm:"column:create_at;default:CURRENT_TIMESTAMP" json:"create_at"` } // TableName AiBotDept's table name diff --git a/internal/data/model/ai_bot_group.gen.go b/internal/data/model/ai_bot_group.gen.go new file mode 100644 index 0000000..d0ff93a --- /dev/null +++ b/internal/data/model/ai_bot_group.gen.go @@ -0,0 +1,27 @@ +// Code generated by gorm.io/gen. DO NOT EDIT. +// Code generated by gorm.io/gen. DO NOT EDIT. +// Code generated by gorm.io/gen. DO NOT EDIT. + +package model + +import ( + "time" +) + +const TableNameAiBotGroup = "ai_bot_group" + +// AiBotGroup mapped from table +type AiBotGroup struct { + GroupID int32 `gorm:"column:group_id;primaryKey;autoIncrement:true" json:"group_id"` + ConversationID string `gorm:"column:conversation_id;not null;comment:会话ID" json:"conversation_id"` // 会话ID + Title string `gorm:"column:title;not null;comment:群名称" json:"title"` // 群名称 + ToolList string `gorm:"column:tool_list;not null;comment:开通工具列表" json:"tool_list"` // 开通工具列表 + Status int32 `gorm:"column:status;not null;default:1" json:"status"` + DeleteAt *time.Time `gorm:"column:delete_at" json:"delete_at"` + CreateAt time.Time `gorm:"column:create_at;default:CURRENT_TIMESTAMP" json:"create_at"` +} + +// TableName AiBotGroup's table name +func (*AiBotGroup) TableName() string { + return TableNameAiBotGroup +} diff --git a/internal/data/model/ai_bot_tools.gen.go b/internal/data/model/ai_bot_tools.gen.go new file mode 100644 index 0000000..f57889b --- /dev/null +++ b/internal/data/model/ai_bot_tools.gen.go @@ -0,0 +1,32 @@ +// Code generated by gorm.io/gen. DO NOT EDIT. +// Code generated by gorm.io/gen. DO NOT EDIT. +// Code generated by gorm.io/gen. DO NOT EDIT. + +package model + +import ( + "time" +) + +const TableNameAiBotTool = "ai_bot_tools" + +// AiBotTool mapped from table +type AiBotTool struct { + ToolID int32 `gorm:"column:tool_id;primaryKey;autoIncrement:true" json:"tool_id"` + PermissionType int32 `gorm:"column:permission_type;not null;comment:类型,1为公共工具,不需要进行权限管理,反之则为2" json:"permission_type"` // 类型,1为公共工具,不需要进行权限管理,反之则为2 + Config string `gorm:"column:config;not null;comment:类型下所需路由以及参数" json:"config"` // 类型下所需路由以及参数 + Type int32 `gorm:"column:type;not null;default:3" json:"type"` + Name string `gorm:"column:name;not null;default:1;comment:工具名称" json:"name"` // 工具名称 + Index string `gorm:"column:index;not null;comment:索引" json:"index"` // 索引 + Desc string `gorm:"column:desc;not null;comment:工具描述" json:"desc"` // 工具描述 + TempPrompt string `gorm:"column:temp_prompt;not null;comment:提示词模板" json:"temp_prompt"` // 提示词模板 + CreateAt time.Time `gorm:"column:create_at;default:CURRENT_TIMESTAMP" json:"create_at"` + UpdatedAt time.Time `gorm:"column:updated_at;default:CURRENT_TIMESTAMP" json:"updated_at"` + Status int32 `gorm:"column:status;not null" json:"status"` + DeleteAt time.Time `gorm:"column:delete_at" json:"delete_at"` +} + +// TableName AiBotTool's table name +func (*AiBotTool) TableName() string { + return TableNameAiBotTool +} diff --git a/internal/domain/workflow/runtime/registry.go b/internal/domain/workflow/runtime/registry.go index 8ae7a82..c854053 100644 --- a/internal/domain/workflow/runtime/registry.go +++ b/internal/domain/workflow/runtime/registry.go @@ -12,7 +12,7 @@ import ( type Workflow interface { ID() string Schema() map[string]any - Invoke(ctx context.Context, requireData *entitys.RequireData) (map[string]any, error) + Invoke(ctx context.Context, requireData *entitys.Recognize) (map[string]any, error) } type Deps struct { @@ -63,7 +63,7 @@ func Default() *Registry { return r } -func (r *Registry) Invoke(ctx context.Context, id string, requireData *entitys.RequireData) (map[string]any, error) { +func (r *Registry) Invoke(ctx context.Context, id string, rec *entitys.Recognize) (map[string]any, error) { regMu.RLock() f, ok := factories[id] regMu.RUnlock() @@ -89,5 +89,5 @@ func (r *Registry) Invoke(ctx context.Context, id string, requireData *entitys.R r.mu.Unlock() } - return w.Invoke(ctx, requireData) + return w.Invoke(ctx, rec) } diff --git a/internal/domain/workflow/zltx/order_after_reseller_batch.go b/internal/domain/workflow/zltx/order_after_reseller_batch.go index 4d151d2..b4df00d 100644 --- a/internal/domain/workflow/zltx/order_after_reseller_batch.go +++ b/internal/domain/workflow/zltx/order_after_reseller_batch.go @@ -31,7 +31,7 @@ type OrderAfterSaleResellerBatchWorkflowInput struct { Ch chan entitys.Response // 响应通道 UserInput string // 用户输入文本 FileContent string // 文件解析结果 - UserHistory []model.AiChatHi // 用户对话历史 + UserHistory entitys.ChatHis // 用户对话历史 ParameterResult string // 参数解析结果 Data *OrderAfterSaleResellerBatchNodeData // 节点所需参数 } @@ -88,7 +88,7 @@ func (o *orderAfterSaleResellerBatch) Schema() map[string]any { } // Invoke 调用原有编排工作流并规范化输出 -func (o *orderAfterSaleResellerBatch) Invoke(ctx context.Context, requireData *entitys.RequireData) (map[string]any, error) { +func (o *orderAfterSaleResellerBatch) Invoke(ctx context.Context, rec *entitys.Recognize) (map[string]any, error) { // 构建工作流 chain, err := o.buildWorkflow(ctx) if err != nil { @@ -96,11 +96,11 @@ func (o *orderAfterSaleResellerBatch) Invoke(ctx context.Context, requireData *e } o.data = &OrderAfterSaleResellerBatchWorkflowInput{ - Ch: requireData.Ch, - UserInput: requireData.Req.Text, + Ch: rec.Ch, + UserInput: rec.UserContent.Text, FileContent: "", - UserHistory: requireData.Histories, - ParameterResult: requireData.Match.Parameters, + UserHistory: rec.ChatHis, + ParameterResult: rec.Match.Parameters, } // 工作流过程输出,不关注最终输出 _, err = chain.Invoke(ctx, o.data) diff --git a/internal/entitys/bot.go b/internal/entitys/bot.go index af92ba8..7085a37 100644 --- a/internal/entitys/bot.go +++ b/internal/entitys/bot.go @@ -3,21 +3,16 @@ package entitys import ( "ai_scheduler/internal/data/model" - "github.com/ollama/ollama/api" "github.com/open-dingtalk/dingtalk-stream-sdk-go/chatbot" ) type RequireDataDingTalkBot struct { - Histories []model.AiChatHi - UserInfo *DingTalkUserInfo - Tasks []model.AiTask - Match *Match - Req *chatbot.BotCallbackDataModel - Auth string - Ch chan Response - KnowledgeConf KnowledgeBaseRequest - ImgByte []api.ImageData - ImgUrls []string + Histories []model.AiChatHi + UserInfo *DingTalkUserInfo + Tools []model.AiBotTool + Match *Match + Req *chatbot.BotCallbackDataModel + Ch chan Response } type DingTalkBot struct { diff --git a/internal/entitys/dingtalk.go b/internal/entitys/dingtalk.go index 3595bf3..93cc825 100644 --- a/internal/entitys/dingtalk.go +++ b/internal/entitys/dingtalk.go @@ -17,6 +17,7 @@ type DingTalkUserInfo struct { } type Dept struct { - Name string `json:"name"` - DeptId int `json:"dept_id"` + Name string `json:"name"` + DeptId int `json:"dept_id"` + ToolList string `json:"tool_list"` } diff --git a/internal/entitys/recognize.go b/internal/entitys/recognize.go index 3d68be7..831bef7 100644 --- a/internal/entitys/recognize.go +++ b/internal/entitys/recognize.go @@ -1,6 +1,9 @@ package entitys -import "ai_scheduler/internal/data/constants" +import ( + "ai_scheduler/internal/data/constants" + "ai_scheduler/internal/data/model" +) type Recognize struct { SystemPrompt string // 系统提示内容 @@ -8,11 +11,23 @@ type Recognize struct { ChatHis ChatHis // 会话历史记录 Tasks []RegistrationTask Ch chan Response + Match *Match + Ext []byte +} + +type TaskExt struct { + Auth string `json:"auth"` + Session string `json:"session"` + Key string `json:"key"` + SessionInfo model.AiSession + Sys model.AiSy + KnowledgeConf KnowledgeBaseRequest } type RegistrationTask struct { Name string Desc string + Index string TaskConfigDetail TaskConfigDetail } @@ -26,6 +41,7 @@ type RecognizeUserContent struct { type FileData []byte type RecognizeFile struct { + FileRec string //文件识别内容 FileData FileData // 文件数据(二进制格式) FileType constants.Caller // 文件类型(文件类型,能填最好填,可以跳过一层判断) FileUrl string // 文件下载链接 diff --git a/internal/entitys/response.go b/internal/entitys/response.go index 44d4d18..39e81ad 100644 --- a/internal/entitys/response.go +++ b/internal/entitys/response.go @@ -25,6 +25,9 @@ const ( ) func ResLog(ch chan Response, index string, content string) { + if ch == nil { + return + } ch <- Response{ Index: index, Content: content, @@ -33,6 +36,9 @@ func ResLog(ch chan Response, index string, content string) { } func ResStream(ch chan Response, index string, content string) { + if ch == nil { + return + } ch <- Response{ Index: index, Content: content, @@ -41,6 +47,9 @@ func ResStream(ch chan Response, index string, content string) { } func ResJson(ch chan Response, index string, content string) { + if ch == nil { + return + } ch <- Response{ Index: index, Content: content, diff --git a/internal/entitys/types.go b/internal/entitys/types.go index b8f8caa..601f50e 100644 --- a/internal/entitys/types.go +++ b/internal/entitys/types.go @@ -78,7 +78,7 @@ type Tool interface { Name() string Description() string Definition() ToolDefinition - Execute(ctx context.Context, requireData *RequireData) error + Execute(ctx context.Context, requireData *Recognize) error } type ConfigDataHttp struct { diff --git a/internal/pkg/rec_extra/ext.go b/internal/pkg/rec_extra/ext.go new file mode 100644 index 0000000..53fbbe7 --- /dev/null +++ b/internal/pkg/rec_extra/ext.go @@ -0,0 +1,22 @@ +package rec_extra + +import ( + "ai_scheduler/internal/entitys" + "ai_scheduler/internal/pkg" + "encoding/json" +) + +func SetTaskRecExt(requireData *entitys.RequireData, rec *entitys.Recognize) { + TaskExt := entitys.TaskExt{ + Auth: requireData.Auth, + Session: requireData.Session, + Key: requireData.Key, + Sys: requireData.Sys, + } + rec.Ext = pkg.JsonByteIgonErr(TaskExt) +} + +func GetTaskRecExt(rec *entitys.Recognize) (ext *entitys.TaskExt, err error) { + err = json.Unmarshal(rec.Ext, ext) + return ext, err +} diff --git a/internal/services/dtalk_bot.go b/internal/services/dtalk_bot.go index e4e9e8d..0bfe129 100644 --- a/internal/services/dtalk_bot.go +++ b/internal/services/dtalk_bot.go @@ -2,18 +2,19 @@ package services import ( "ai_scheduler/internal/biz" + "log" + "time" + "ai_scheduler/internal/config" "ai_scheduler/internal/entitys" "context" - "fmt" "github.com/open-dingtalk/dingtalk-stream-sdk-go/chatbot" ) type DingBotService struct { - config *config.Config - + config *config.Config dingTalkBotBiz *biz.DingTalkBotBiz } @@ -30,38 +31,42 @@ func (d *DingBotService) OnChatBotMessageReceived(ctx context.Context, data *cha if err != nil { return } + // 使用 ctx.Done() 通知 Do 方法提前终止 + subCtx, cancel := context.WithCancel(ctx) + defer cancel() + // 异步执行 Do 方法 + done := make(chan error, 1) go func() { - //defer close(requireData.Ch) - //if match, _err := d.dingTalkBotBiz.Recognize(ctx, data); _err != nil { - // requireData.Ch <- entitys.Response{ - // Type: entitys.ResponseEnd, - // Content: fmt.Sprintf("处理消息时出错: %s", _err.Error()), - // } - //} - ////向下传递 - //if err = d.dingTalkBotBiz.HandleMatch(ctx, nil, requireData); err != nil { - // requireData.Ch <- entitys.Response{ - // Type: entitys.ResponseEnd, - // Content: fmt.Sprintf("匹配失败: %v", err), - // } - //} + done <- d.dingTalkBotBiz.Do(subCtx, requireData) }() + + var lastErr error for { select { case <-ctx.Done(): - return nil, ctx.Err() + lastErr = ctx.Err() + goto cleanup case resp, ok := <-requireData.Ch: if !ok { - return []byte("success"), nil // 通道关闭,处理完成 + return []byte("success"), nil } - if resp.Type == entitys.ResponseLog { - return + continue } if err := d.dingTalkBotBiz.HandleRes(ctx, data, resp); err != nil { - return nil, fmt.Errorf("回复失败: %w", err) + log.Printf("HandleRes 失败: %v", err) } } } - return +cleanup: + select { + case err := <-done: + if err != nil { + log.Printf("Do 方法执行失败: %v", err) + } + case <-time.After(1 * time.Second): + log.Println("警告:等待 Do 方法超时,可能发生 goroutine 泄漏") + } + + return nil, lastErr } diff --git a/internal/tools/manager.go b/internal/tools/manager.go index cd1940d..eedb4ee 100644 --- a/internal/tools/manager.go +++ b/internal/tools/manager.go @@ -24,24 +24,6 @@ func NewManager(config *config.Config, llm *utils_ollama.Client) *Manager { llm: llm, } - // 注册天气工具 - //if config.Tools.Weather.Enabled { - // weatherTool := NewWeatherTool() - // m.tools[weatherTool.Name()] = weatherTool - //} - // - //// 注册计算器工具 - //if config.Tools.Calculator.Enabled { - // calcTool := NewCalculatorTool() - // m.tools[calcTool.Name()] = calcTool - //} - - // 注册知识库工具 - // if config.Knowledge.Enabled { - // knowledgeTool := NewKnowledgeTool() - // m.tools[knowledgeTool.Name()] = knowledgeTool - // } - // 注册直连天下订单详情工具 if config.Tools.ZltxOrderDetail.Enabled { zltxOrderDetailTool := zltxtool.NewZltxOrderDetailTool(config.Tools.ZltxOrderDetail, m.llm) @@ -115,63 +97,12 @@ func (m *Manager) GetTool(name string) (entitys.Tool, bool) { return tool, exists } -// GetAllTools 获取所有工具 -// func (m *Manager) GetAllTools() []entitys.Tool { -// tools := make([]entitys.Tool, 0, len(m.tools)) -// for _, tool := range m.tools { -// tools = append(tools, tool) -// } -// return tools -// } - -// // GetToolDefinitions 获取所有工具定义 -// func (m *Manager) GetToolDefinitions(caller constants.Caller) []entitys.ToolDefinition { -// definitions := make([]entitys.ToolDefinition, 0, len(m.tools)) -// for _, tool := range m.tools { -// definitions = append(definitions, tool.Definition()) -// } - -// return definitions -// } - // ExecuteTool 执行工具 -func (m *Manager) ExecuteTool(ctx context.Context, name string, requireData *entitys.RequireData) error { +func (m *Manager) ExecuteTool(ctx context.Context, name string, rec *entitys.Recognize) error { tool, exists := m.GetTool(name) if !exists { return fmt.Errorf("tool not found: %s", name) } - return tool.Execute(ctx, requireData) + return tool.Execute(ctx, rec) } - -// ExecuteToolCalls 执行多个工具调用 -//func (m *Manager) ExecuteToolCalls(c *websocket.Conn, toolCalls []entitys.ToolCall) ([]entitys.ToolCall, error) { -// results := make([]entitys.ToolCall, len(toolCalls)) -// -// for i, toolCall := range toolCalls { -// results[i] = toolCall -// -// // 执行工具 -// err := m.ExecuteTool(c, toolCall.Function.Name, toolCall.Function.Arguments) -// if err != nil { -// // 将错误信息作为结果返回 -// errorResult := map[string]interface{}{ -// "error": err.Error(), -// } -// resultBytes, _ := json.Marshal(errorResult) -// results[i].Result = resultBytes -// } else { -// // 将成功结果序列化 -// resultBytes, err := json.Marshal(result) -// if err != nil { -// errorResult := map[string]interface{}{ -// "error": fmt.Sprintf("failed to serialize result: %v", err), -// } -// resultBytes, _ = json.Marshal(errorResult) -// } -// results[i].Result = resultBytes -// } -// } -// -// return results, nil -//} diff --git a/internal/tools/public/coze_company.go b/internal/tools/public/coze_company.go index 3f10638..e061965 100644 --- a/internal/tools/public/coze_company.go +++ b/internal/tools/public/coze_company.go @@ -7,10 +7,11 @@ import ( "context" "encoding/json" "fmt" - "github.com/ollama/ollama/api" "net/http" "time" + "github.com/ollama/ollama/api" + "github.com/coze-dev/coze-go" ) @@ -70,7 +71,7 @@ func (c *CozeCompany) Definition() entitys.ToolDefinition { } // Execute 执行查询 -func (c *CozeCompany) Execute(ctx context.Context, requireData *entitys.RequireData) error { +func (c *CozeCompany) Execute(ctx context.Context, requireData *entitys.Recognize) error { var req map[string]interface{} if err := json.Unmarshal([]byte(requireData.Match.Parameters), &req); err != nil { return fmt.Errorf("invalid express request: %w", err) @@ -191,7 +192,7 @@ Top3核心风险:1. 根据司法信息,近1年作为被告的合同纠纷占 }, { Role: "user", - Content: requireData.Req.Text, + Content: requireData.UserContent.Text, }, }, c.Name(), "") diff --git a/internal/tools/public/coze_express.go b/internal/tools/public/coze_express.go index de316b5..58e6172 100644 --- a/internal/tools/public/coze_express.go +++ b/internal/tools/public/coze_express.go @@ -8,6 +8,7 @@ import ( "context" "encoding/json" "fmt" + "github.com/ollama/ollama/api" "github.com/coze-dev/coze-go" @@ -67,7 +68,7 @@ func (c *CozeExpress) Definition() entitys.ToolDefinition { } // Execute 执行查询 -func (c *CozeExpress) Execute(ctx context.Context, requireData *entitys.RequireData) error { +func (c *CozeExpress) Execute(ctx context.Context, requireData *entitys.Recognize) error { var req map[string]interface{} if err := json.Unmarshal([]byte(requireData.Match.Parameters), &req); err != nil { return fmt.Errorf("invalid express request: %w", err) @@ -89,7 +90,7 @@ func (c *CozeExpress) Execute(ctx context.Context, requireData *entitys.RequireD }, { Role: "assistant", - Content: fmt.Sprintf("聊天记录:%s", pkg.JsonStringIgonErr(requireData.Histories)), + Content: fmt.Sprintf("聊天记录:%s", pkg.JsonStringIgonErr(requireData.ChatHis)), }, { Role: "assistant", @@ -97,7 +98,7 @@ func (c *CozeExpress) Execute(ctx context.Context, requireData *entitys.RequireD }, { Role: "user", - Content: requireData.Req.Text, + Content: requireData.UserContent.Text, }, }, c.Name(), "") if err != nil { diff --git a/internal/tools/public/konwledge_base.go b/internal/tools/public/konwledge_base.go index 505a349..48ddc1e 100644 --- a/internal/tools/public/konwledge_base.go +++ b/internal/tools/public/konwledge_base.go @@ -4,6 +4,7 @@ import ( "ai_scheduler/internal/config" "ai_scheduler/internal/entitys" "ai_scheduler/internal/pkg/l_request" + "ai_scheduler/internal/pkg/rec_extra" "bufio" "context" "encoding/json" @@ -58,7 +59,7 @@ func (k *KnowledgeBaseTool) Definition() entitys.ToolDefinition { } // Execute 执行知识库查询 -func (k *KnowledgeBaseTool) Execute(ctx context.Context, requireData *entitys.RequireData) error { +func (k *KnowledgeBaseTool) Execute(ctx context.Context, requireData *entitys.Recognize) error { entitys.ResLoading(requireData.Ch, k.Name(), "正在为您搜索相关信息") return k.chat(requireData) @@ -91,20 +92,23 @@ func (this *KnowledgeBaseTool) msgContentParse(input string, channel chan entity } // 请求知识库聊天 -func (this *KnowledgeBaseTool) chat(requireData *entitys.RequireData) (err error) { - +func (this *KnowledgeBaseTool) chat(rec *entitys.Recognize) (err error) { + ext, err := rec_extra.GetTaskRecExt(rec) + if err != nil { + return + } req := l_request.Request{ Method: "post", - Url: this.config.BaseURL + "/api/v1/knowledge-chat/" + requireData.KnowledgeConf.Session, + Url: this.config.BaseURL + "/api/v1/knowledge-chat/" + ext.KnowledgeConf.Session, Params: nil, Headers: map[string]string{ "Content-Type": "application/json", - "X-API-Key": requireData.KnowledgeConf.ApiKey, + "X-API-Key": ext.KnowledgeConf.ApiKey, }, Cookies: nil, Data: nil, Json: map[string]interface{}{ - "query": requireData.KnowledgeConf.Query, + "query": ext.KnowledgeConf.Query, }, Files: nil, Raw: "", @@ -118,7 +122,7 @@ func (this *KnowledgeBaseTool) chat(requireData *entitys.RequireData) (err error } defer rsp.Body.Close() - err = this.connectAndReadSSE(rsp, requireData.Ch) + err = this.connectAndReadSSE(rsp, rec.Ch) if err != nil { return } diff --git a/internal/tools/public/normal_chat.go b/internal/tools/public/normal_chat.go index a4c96e0..cde667c 100644 --- a/internal/tools/public/normal_chat.go +++ b/internal/tools/public/normal_chat.go @@ -43,9 +43,9 @@ func (w *NormalChatTool) Definition() entitys.ToolDefinition { } // Execute 执行直连天下订单详情查询 -func (w *NormalChatTool) Execute(ctx context.Context, requireData *entitys.RequireData) error { +func (w *NormalChatTool) Execute(ctx context.Context, rec *entitys.Recognize) error { var req NormalChat - if err := json.Unmarshal([]byte(requireData.Match.Parameters), &req); err != nil { + if err := json.Unmarshal([]byte(rec.Match.Parameters), &req); err != nil { return fmt.Errorf("invalid zltxOrderDetail request: %w", err) } if req.ChatContent == "" { @@ -53,25 +53,25 @@ func (w *NormalChatTool) Execute(ctx context.Context, requireData *entitys.Requi } // 这里可以集成真实的直连天下订单详情API - return w.chat(requireData, &req) + return w.chat(rec, &req) } // getMockZltxOrderDetail 获取模拟直连天下订单详情数据 -func (w *NormalChatTool) chat(requireData *entitys.RequireData, chat *NormalChat) (err error) { +func (w *NormalChatTool) chat(rec *entitys.Recognize, chat *NormalChat) (err error) { //requireData.Ch <- entitys.Response{ // Index: w.Name(), // Content: "", // Type: entitys.ResponseStream, //} - err = w.llm.ChatStream(context.TODO(), requireData.Ch, []api.Message{ + err = w.llm.ChatStream(context.TODO(), rec.Ch, []api.Message{ { Role: "system", Content: "你是一个聊天助手", }, { Role: "assistant", - Content: fmt.Sprintf("聊天记录:%s", pkg.JsonStringIgonErr(requireData.Histories)), + Content: fmt.Sprintf("聊天记录:%s", pkg.JsonStringIgonErr(rec.ChatHis)), }, { Role: "user", diff --git a/internal/tools/public/weather.go b/internal/tools/public/weather.go index 0292fa4..ca36907 100644 --- a/internal/tools/public/weather.go +++ b/internal/tools/public/weather.go @@ -107,9 +107,9 @@ type LiveWeather struct { } // Execute 执行天气查询 -func (w *WeatherTool) Execute(ctx context.Context, requireData *entitys.RequireData) error { +func (w *WeatherTool) Execute(ctx context.Context, rec *entitys.Recognize) error { var req WeatherRequest - if err := json.Unmarshal([]byte(requireData.Match.Parameters), &req); err != nil { + if err := json.Unmarshal([]byte(rec.Match.Parameters), &req); err != nil { return fmt.Errorf("invalid weather request: %w", err) } @@ -134,7 +134,7 @@ func (w *WeatherTool) Execute(ctx context.Context, requireData *entitys.RequireD // 根据 extensions 参数返回不同的天气信息 if req.Extensions == "base" { - entitys.ResText(requireData.Ch, "", fmt.Sprintf("%s实时天气:%s,温度:%.1f℃,湿度:%d%%,风速:%.1fkm/h,风向:%s", + entitys.ResText(rec.Ch, "", fmt.Sprintf("%s实时天气:%s,温度:%.1f℃,湿度:%d%%,风速:%.1fkm/h,风向:%s", req.City, responseMsg.LiveWeather.Condition, responseMsg.LiveWeather.Temperature, @@ -149,7 +149,7 @@ func (w *WeatherTool) Execute(ctx context.Context, requireData *entitys.RequireD forecast.Date, forecast.DayTemp, forecast.NightTemp, forecast.DayWind, forecast.NightWind) } - entitys.ResText(requireData.Ch, "", rspStr) + entitys.ResText(rec.Ch, "", rspStr) } return nil } diff --git a/internal/tools/zltx/order_after_reseller.go b/internal/tools/zltx/order_after_reseller.go index 71e0629..fcb71e0 100644 --- a/internal/tools/zltx/order_after_reseller.go +++ b/internal/tools/zltx/order_after_reseller.go @@ -4,6 +4,7 @@ import ( "ai_scheduler/internal/config" "ai_scheduler/internal/entitys" "ai_scheduler/internal/pkg/l_request" + "ai_scheduler/internal/pkg/rec_extra" "ai_scheduler/internal/pkg/util" "context" "encoding/json" @@ -104,9 +105,9 @@ type OrderAfterSaleResellerApiExtItem struct { SerialCreateTime int `json:"createTime"` // 流水创建时间 } -func (t *OrderAfterSaleResellerTool) Execute(ctx context.Context, requireData *entitys.RequireData) error { +func (t *OrderAfterSaleResellerTool) Execute(ctx context.Context, rec *entitys.Recognize) error { var req OrderAfterSaleResellerRequest - if err := json.Unmarshal([]byte(requireData.Match.Parameters), &req); err != nil { + if err := json.Unmarshal([]byte(rec.Match.Parameters), &req); err != nil { return fmt.Errorf("解析参数失败,请重试或联系管理员") } if len(req.OrderNumber) == 0 && len(req.Account) == 0 { @@ -116,18 +117,22 @@ func (t *OrderAfterSaleResellerTool) Execute(ctx context.Context, requireData *e if req.SerialCreateTime != "" { _, err := time.ParseInLocation(time.DateTime, req.SerialCreateTime, time.Local) if err != nil { - entitys.ResLog(requireData.Ch, t.Name(), "时间格式不匹配,已置为空") + entitys.ResLog(rec.Ch, t.Name(), "时间格式不匹配,已置为空") req.SerialCreateTime = "" } } - entitys.ResLog(requireData.Ch, t.Name(), "正在拉取售后订单信息") + entitys.ResLog(rec.Ch, t.Name(), "正在拉取售后订单信息") - return t.checkOrderAfterSaleReseller(req, requireData) + return t.checkOrderAfterSaleReseller(req, rec) } -func (t *OrderAfterSaleResellerTool) checkOrderAfterSaleReseller(toolReq OrderAfterSaleResellerRequest, requireData *entitys.RequireData) error { +func (t *OrderAfterSaleResellerTool) checkOrderAfterSaleReseller(toolReq OrderAfterSaleResellerRequest, rec *entitys.Recognize) error { var serialStartTime, serialEndTime int64 + ext, err := rec_extra.GetTaskRecExt(rec) + if err != nil { + return err + } if toolReq.SerialCreateTime != "" { // 流水创建时间上下浮动10min serialCreateTime, err := time.ParseInLocation(time.DateTime, toolReq.SerialCreateTime, time.Local) @@ -144,17 +149,16 @@ func (t *OrderAfterSaleResellerTool) checkOrderAfterSaleReseller(toolReq OrderAf // 账号数量超过10直接截断 if len(toolReq.Account) > 10 { - entitys.ResLog(requireData.Ch, t.Name(), "账号数量超过10已被截断") + entitys.ResLog(rec.Ch, t.Name(), "账号数量超过10已被截断") toolReq.Account = toolReq.Account[:10] } headers := map[string]string{ - "Authorization": fmt.Sprintf("Bearer %s", requireData.Auth), + "Authorization": fmt.Sprintf("Bearer %s", ext.Auth), } // 最终输出 var orderList []*OrderAfterSaleResellerData - var err error // 多订单号 if len(toolReq.OrderNumber) > 0 { @@ -217,8 +221,8 @@ func (t *OrderAfterSaleResellerTool) checkOrderAfterSaleReseller(toolReq OrderAf return err } - entitys.ResLog(requireData.Ch, t.Name(), "售后订单信息拉取完成") - entitys.ResJson(requireData.Ch, t.Name(), string(jsonByte)) + entitys.ResLog(rec.Ch, t.Name(), "售后订单信息拉取完成") + entitys.ResJson(rec.Ch, t.Name(), string(jsonByte)) return nil } diff --git a/internal/tools/zltx/order_after_reseller_batch.go b/internal/tools/zltx/order_after_reseller_batch.go index e12e602..664d3d5 100644 --- a/internal/tools/zltx/order_after_reseller_batch.go +++ b/internal/tools/zltx/order_after_reseller_batch.go @@ -4,6 +4,7 @@ import ( "ai_scheduler/internal/config" "ai_scheduler/internal/entitys" "ai_scheduler/internal/pkg/l_request" + "ai_scheduler/internal/pkg/rec_extra" "ai_scheduler/internal/pkg/util" "context" "encoding/json" @@ -100,25 +101,29 @@ type OrderAfterSaleResellerBatchApiExtItem struct { SerialCreateTime int `json:"createTime"` // 流水创建时间 } -func (t *OrderAfterSaleResellerBatchTool) Execute(ctx context.Context, requireData *entitys.RequireData) error { +func (t *OrderAfterSaleResellerBatchTool) Execute(ctx context.Context, rec *entitys.Recognize) error { var req OrderAfterSaleResellerBatchRequest - if err := json.Unmarshal([]byte(requireData.Match.Parameters), &req); err != nil { + if err := json.Unmarshal([]byte(rec.Match.Parameters), &req); err != nil { return fmt.Errorf("解析参数失败,请重试或联系管理员") } if len(req.OrderNumber) == 0 { return fmt.Errorf("批充订单号不能为空") } - entitys.ResLog(requireData.Ch, t.Name(), "正在拉取售后订单信息") + entitys.ResLog(rec.Ch, t.Name(), "正在拉取售后订单信息") - return t.checkOrderAfterSaleResellerBatch(req, requireData) + return t.checkOrderAfterSaleResellerBatch(req, rec) } -func (t *OrderAfterSaleResellerBatchTool) checkOrderAfterSaleResellerBatch(toolReq OrderAfterSaleResellerBatchRequest, requireData *entitys.RequireData) error { +func (t *OrderAfterSaleResellerBatchTool) checkOrderAfterSaleResellerBatch(toolReq OrderAfterSaleResellerBatchRequest, rec *entitys.Recognize) error { + ext, err := rec_extra.GetTaskRecExt(rec) + if err != nil { + return err + } req := l_request.Request{ Url: t.config.BaseURL, Headers: map[string]string{ - "Authorization": fmt.Sprintf("Bearer %s", requireData.Auth), + "Authorization": fmt.Sprintf("Bearer %s", ext.Auth), }, Method: "POST", Json: map[string]any{ @@ -200,7 +205,7 @@ func (t *OrderAfterSaleResellerBatchTool) checkOrderAfterSaleResellerBatch(toolR return err } - entitys.ResLog(requireData.Ch, t.Name(), "售后订单信息拉取完成") - entitys.ResJson(requireData.Ch, t.Name(), string(jsonByte)) + entitys.ResLog(rec.Ch, t.Name(), "售后订单信息拉取完成") + entitys.ResJson(rec.Ch, t.Name(), string(jsonByte)) return nil } diff --git a/internal/tools/zltx/order_after_supplier.go b/internal/tools/zltx/order_after_supplier.go index 4baf7f9..9e9ad82 100644 --- a/internal/tools/zltx/order_after_supplier.go +++ b/internal/tools/zltx/order_after_supplier.go @@ -4,6 +4,7 @@ import ( "ai_scheduler/internal/config" "ai_scheduler/internal/entitys" "ai_scheduler/internal/pkg/l_request" + "ai_scheduler/internal/pkg/rec_extra" "ai_scheduler/internal/pkg/util" "context" "encoding/json" @@ -98,9 +99,9 @@ type OrderAfterSaleSupplierApiExtItem struct { SerialCreateTime int `json:"createTime"` // 流水创建时间 } -func (t *OrderAfterSaleSupplierTool) Execute(ctx context.Context, requireData *entitys.RequireData) error { +func (t *OrderAfterSaleSupplierTool) Execute(ctx context.Context, rec *entitys.Recognize) error { var req OrderAfterSaleSupplierRequest - if err := json.Unmarshal([]byte(requireData.Match.Parameters), &req); err != nil { + if err := json.Unmarshal([]byte(rec.Match.Parameters), &req); err != nil { return fmt.Errorf("解析参数失败,请重试或联系管理员") } if len(req.SerialNumber) == 0 && len(req.Account) == 0 { @@ -110,18 +111,24 @@ func (t *OrderAfterSaleSupplierTool) Execute(ctx context.Context, requireData *e if req.SerialCreateTime != "" { _, err := time.ParseInLocation(time.DateTime, req.SerialCreateTime, time.Local) if err != nil { - entitys.ResLog(requireData.Ch, t.Name(), "时间格式不匹配,已置为空") + entitys.ResLog(rec.Ch, t.Name(), "时间格式不匹配,已置为空") req.SerialCreateTime = "" } } - entitys.ResLog(requireData.Ch, t.Name(), "正在拉取售后订单信息") + entitys.ResLog(rec.Ch, t.Name(), "正在拉取售后订单信息") - return t.checkOrderAfterSaleSupplier(req, requireData) + return t.checkOrderAfterSaleSupplier(req, rec) } -func (t *OrderAfterSaleSupplierTool) checkOrderAfterSaleSupplier(toolReq OrderAfterSaleSupplierRequest, requireData *entitys.RequireData) error { +func (t *OrderAfterSaleSupplierTool) checkOrderAfterSaleSupplier(toolReq OrderAfterSaleSupplierRequest, rec *entitys.Recognize) error { + var serialStartTime, serialEndTime int64 + ext, err := rec_extra.GetTaskRecExt(rec) + if err != nil { + return err + } + if toolReq.SerialCreateTime != "" { // 流水创建时间上下浮动10min serialCreateTime, err := time.ParseInLocation(time.DateTime, toolReq.SerialCreateTime, time.Local) @@ -138,17 +145,16 @@ func (t *OrderAfterSaleSupplierTool) checkOrderAfterSaleSupplier(toolReq OrderAf // 账号数量超过10直接截断 if len(toolReq.Account) > 10 { - entitys.ResLog(requireData.Ch, t.Name(), "账号数量超过10已被截断") + entitys.ResLog(rec.Ch, t.Name(), "账号数量超过10已被截断") toolReq.Account = toolReq.Account[:10] } headers := map[string]string{ - "Authorization": fmt.Sprintf("Bearer %s", requireData.Auth), + "Authorization": fmt.Sprintf("Bearer %s", ext.Auth), } // 最终输出 var orderList []*OrderAfterSaleSupplierData - var err error // 多流水号 if len(toolReq.SerialNumber) > 0 { @@ -210,8 +216,8 @@ func (t *OrderAfterSaleSupplierTool) checkOrderAfterSaleSupplier(toolReq OrderAf return err } - entitys.ResLog(requireData.Ch, t.Name(), "售后订单信息拉取完成") - entitys.ResJson(requireData.Ch, t.Name(), string(jsonByte)) + entitys.ResLog(rec.Ch, t.Name(), "售后订单信息拉取完成") + entitys.ResJson(rec.Ch, t.Name(), string(jsonByte)) return nil } diff --git a/internal/tools/zltx/zltx_order_detail.go b/internal/tools/zltx/zltx_order_detail.go index 7dae9f1..52d6bbb 100644 --- a/internal/tools/zltx/zltx_order_detail.go +++ b/internal/tools/zltx/zltx_order_detail.go @@ -4,6 +4,7 @@ import ( "ai_scheduler/internal/config" "ai_scheduler/internal/entitys" "ai_scheduler/internal/pkg" + "ai_scheduler/internal/pkg/rec_extra" "ai_scheduler/internal/pkg/utils_ollama" "context" "encoding/json" @@ -81,9 +82,9 @@ type ZltxOrderDetailData struct { } // Execute 执行直连天下订单详情查询 -func (w *ZltxOrderDetailTool) Execute(ctx context.Context, requireData *entitys.RequireData) error { +func (w *ZltxOrderDetailTool) Execute(ctx context.Context, rec *entitys.Recognize) error { var req ZltxOrderDetailRequest - if err := json.Unmarshal([]byte(requireData.Match.Parameters), &req); err != nil { + if err := json.Unmarshal([]byte(rec.Match.Parameters), &req); err != nil { return fmt.Errorf("invalid zltxOrderDetail request: %w", err) } if req.OrderNumber == "" { @@ -91,16 +92,20 @@ func (w *ZltxOrderDetailTool) Execute(ctx context.Context, requireData *entitys. } // 这里可以集成真实的直连天下订单详情API - return w.getZltxOrderDetail(requireData, req.OrderNumber) + return w.getZltxOrderDetail(rec, req.OrderNumber) } // getMockZltxOrderDetail 获取模拟直连天下订单详情数据 -func (w *ZltxOrderDetailTool) getZltxOrderDetail(requireData *entitys.RequireData, number string) (err error) { +func (w *ZltxOrderDetailTool) getZltxOrderDetail(rec *entitys.Recognize, number string) (err error) { + ext, err := rec_extra.GetTaskRecExt(rec) + if err != nil { + return + } //查询订单详情 req := l_request.Request{ Url: fmt.Sprintf(w.config.BaseURL, number), Headers: map[string]string{ - "Authorization": fmt.Sprintf("Bearer %s", requireData.Auth), + "Authorization": fmt.Sprintf("Bearer %s", ext.Auth), }, Method: "GET", } @@ -121,15 +126,15 @@ func (w *ZltxOrderDetailTool) getZltxOrderDetail(requireData *entitys.RequireDat if err = json.Unmarshal(res.Content, &resData); err != nil { return } - entitys.ResJson(requireData.Ch, w.Name(), res.Text) + entitys.ResJson(rec.Ch, w.Name(), res.Text) if resData.Data.Direct != nil { - entitys.ResLoading(requireData.Ch, w.Name(), "正在分析订单日志") + entitys.ResLoading(rec.Ch, w.Name(), "正在分析订单日志") req = l_request.Request{ Url: fmt.Sprintf(w.config.AddURL, resData.Data.Direct["orderOrderNumber"].(string), resData.Data.Direct["serialNumber"].(string)), Headers: map[string]string{ - "Authorization": fmt.Sprintf("Bearer %s", requireData.Auth), + "Authorization": fmt.Sprintf("Bearer %s", ext.Auth), }, Method: "GET", } @@ -149,14 +154,14 @@ func (w *ZltxOrderDetailTool) getZltxOrderDetail(requireData *entitys.RequireDat return fmt.Errorf("订单日志解析失败:%s", err) } - err = w.llm.ChatStream(context.TODO(), requireData.Ch, []api.Message{ + err = w.llm.ChatStream(context.TODO(), rec.Ch, []api.Message{ { Role: "system", Content: "你是一个订单日志助手。用户可能会提供订单日志,你需要分析订单日志,失败订单->分析失败原因,成功订单->找出整个日志的 Base64 编码的 JSON 数据的内容进行转换并反馈给我", }, { Role: "assistant", - Content: fmt.Sprintf("聊天记录:%s", pkg.JsonStringIgonErr(requireData.Histories)), + Content: fmt.Sprintf("聊天记录:%s", pkg.JsonStringIgonErr(rec.ChatHis)), }, { Role: "assistant", @@ -164,7 +169,7 @@ func (w *ZltxOrderDetailTool) getZltxOrderDetail(requireData *entitys.RequireDat }, { Role: "user", - Content: requireData.Req.Text, + Content: rec.UserContent.Text, }, }, w.Name(), "") if err != nil { @@ -172,7 +177,7 @@ func (w *ZltxOrderDetailTool) getZltxOrderDetail(requireData *entitys.RequireDat } } if resData.Data.Direct == nil { - entitys.ResText(requireData.Ch, w.Name(), "该订单无充值流水记录,不需要进行订单错误分析,建议检查订单收单模式以及扣款方式等原因😘") + entitys.ResText(rec.Ch, w.Name(), "该订单无充值流水记录,不需要进行订单错误分析,建议检查订单收单模式以及扣款方式等原因😘") } return } diff --git a/internal/tools/zltx/zltx_order_direct_log.go b/internal/tools/zltx/zltx_order_direct_log.go index b1b1483..188a954 100644 --- a/internal/tools/zltx/zltx_order_direct_log.go +++ b/internal/tools/zltx/zltx_order_direct_log.go @@ -3,6 +3,7 @@ package zltx import ( "ai_scheduler/internal/config" "ai_scheduler/internal/entitys" + "ai_scheduler/internal/pkg/rec_extra" "context" "encoding/json" "fmt" @@ -67,25 +68,28 @@ type ZltxOrderDirectLogData struct { Data map[string]interface{} `json:"data"` } -func (t *ZltxOrderLogTool) Execute(ctx context.Context, requireData *entitys.RequireData) error { +func (t *ZltxOrderLogTool) Execute(ctx context.Context, rec *entitys.Recognize) error { var req ZltxOrderLogRequest - if err := json.Unmarshal([]byte(requireData.Match.Parameters), &req); err != nil { + if err := json.Unmarshal([]byte(rec.Match.Parameters), &req); err != nil { return fmt.Errorf("invalid zltxOrderLog request: %w", err) } if req.OrderNumber == "" || req.SerialNumber == "" { return fmt.Errorf("orderNumber and serialNumber is required") } - return t.getZltxOrderLog(req.OrderNumber, req.SerialNumber, requireData) + return t.getZltxOrderLog(req.OrderNumber, req.SerialNumber, rec) } -func (t *ZltxOrderLogTool) getZltxOrderLog(orderNumber, serialNumber string, requireData *entitys.RequireData) (err error) { +func (t *ZltxOrderLogTool) getZltxOrderLog(orderNumber, serialNumber string, rec *entitys.Recognize) (err error) { //查询订单详情 - + ext, err := rec_extra.GetTaskRecExt(rec) + if err != nil { + return + } url := fmt.Sprintf("%s%s/%s", t.config.BaseURL, orderNumber, serialNumber) req := l_request.Request{ Url: url, Headers: map[string]string{ - "Authorization": fmt.Sprintf("Bearer %s", requireData.Auth), + "Authorization": fmt.Sprintf("Bearer %s", ext.Auth), }, Method: "GET", } @@ -100,7 +104,7 @@ func (t *ZltxOrderLogTool) getZltxOrderLog(orderNumber, serialNumber string, req if err = json.Unmarshal(res.Content, &resData); err != nil { return } - entitys.ResJson(requireData.Ch, t.Name(), res.Text) + entitys.ResJson(rec.Ch, t.Name(), res.Text) return } diff --git a/internal/tools/zltx/zltx_product.go b/internal/tools/zltx/zltx_product.go index de24236..0c63d84 100644 --- a/internal/tools/zltx/zltx_product.go +++ b/internal/tools/zltx/zltx_product.go @@ -3,6 +3,7 @@ package zltx import ( "ai_scheduler/internal/config" "ai_scheduler/internal/entitys" + "ai_scheduler/internal/pkg/rec_extra" "context" "encoding/json" "fmt" @@ -53,12 +54,12 @@ type ZltxProductRequest struct { Name string `json:"name"` } -func (z ZltxProductTool) Execute(ctx context.Context, requireData *entitys.RequireData) error { +func (z ZltxProductTool) Execute(ctx context.Context, rec *entitys.Recognize) error { var req ZltxProductRequest - if err := json.Unmarshal([]byte(requireData.Match.Parameters), &req); err != nil { + if err := json.Unmarshal([]byte(rec.Match.Parameters), &req); err != nil { return fmt.Errorf("invalid zltxProduct request: %w", err) } - return z.getZltxProduct(&req, requireData) + return z.getZltxProduct(&req, rec) } type ZltxProductResponse struct { @@ -133,8 +134,11 @@ type ZltxProductData struct { PlatformProductList interface{} `json:"platform_product_list"` } -func (z ZltxProductTool) getZltxProduct(body *ZltxProductRequest, requireData *entitys.RequireData) error { - +func (z ZltxProductTool) getZltxProduct(body *ZltxProductRequest, rec *entitys.Recognize) error { + ext, err := rec_extra.GetTaskRecExt(rec) + if err != nil { + return err + } var Url string var params map[string]string if body.Id != "" { @@ -153,7 +157,7 @@ func (z ZltxProductTool) getZltxProduct(body *ZltxProductRequest, requireData *e //根据商品ID或名称走不同的接口查询 Url: Url, Headers: map[string]string{ - "Authorization": fmt.Sprintf("Bearer %s", requireData.Auth), + "Authorization": fmt.Sprintf("Bearer %s", ext.Auth), }, Params: params, Method: "GET", @@ -185,7 +189,7 @@ func (z ZltxProductTool) getZltxProduct(body *ZltxProductRequest, requireData *e for i := range resp.Data.List { // 调用 平台商品列表 if resp.Data.List[i].AuthProductIds != "" { - platformProductList := z.ExecutePlatformProductList(requireData.Auth, resp.Data.List[i].AuthProductIds, resp.Data.List[i].OfficialProductID) + platformProductList := z.ExecutePlatformProductList(ext.Auth, resp.Data.List[i].AuthProductIds, resp.Data.List[i].OfficialProductID) resp.Data.List[i].PlatformProductList = platformProductList } } @@ -194,7 +198,7 @@ func (z ZltxProductTool) getZltxProduct(body *ZltxProductRequest, requireData *e if err != nil { return err } - entitys.ResJson(requireData.Ch, z.Name(), string(marshal)) + entitys.ResJson(rec.Ch, z.Name(), string(marshal)) return nil } diff --git a/internal/tools/zltx/zltx_statistics.go b/internal/tools/zltx/zltx_statistics.go index 4a051a8..5d71a9b 100644 --- a/internal/tools/zltx/zltx_statistics.go +++ b/internal/tools/zltx/zltx_statistics.go @@ -3,6 +3,7 @@ package zltx import ( "ai_scheduler/internal/config" "ai_scheduler/internal/entitys" + "ai_scheduler/internal/pkg/rec_extra" "context" "encoding/json" "fmt" @@ -47,15 +48,15 @@ type ZltxOrderStatisticsRequest struct { Number string `json:"number"` } -func (z ZltxOrderStatisticsTool) Execute(ctx context.Context, requireData *entitys.RequireData) error { +func (z ZltxOrderStatisticsTool) Execute(ctx context.Context, rec *entitys.Recognize) error { var req ZltxOrderStatisticsRequest - if err := json.Unmarshal([]byte(requireData.Match.Parameters), &req); err != nil { + if err := json.Unmarshal([]byte(rec.Match.Parameters), &req); err != nil { return err } if req.Number == "" { return fmt.Errorf("number is required") } - return z.getZltxOrderStatistics(req.Number, requireData) + return z.getZltxOrderStatistics(req.Number, rec) } type ZltxOrderStatisticsResponse struct { @@ -75,14 +76,18 @@ type ZltxOrderStatisticsData struct { Total int `json:"total"` } -func (z ZltxOrderStatisticsTool) getZltxOrderStatistics(number string, requireData *entitys.RequireData) error { +func (z ZltxOrderStatisticsTool) getZltxOrderStatistics(number string, rec *entitys.Recognize) error { + ext, err := rec_extra.GetTaskRecExt(rec) + if err != nil { + return err + } //查询订单详情 url := fmt.Sprintf("%s%s", z.config.BaseURL, number) req := l_request.Request{ Url: url, Headers: map[string]string{ - "Authorization": fmt.Sprintf("Bearer %s", requireData.Auth), + "Authorization": fmt.Sprintf("Bearer %s", ext.Auth), }, Method: "GET", } @@ -108,7 +113,7 @@ func (z ZltxOrderStatisticsTool) getZltxOrderStatistics(number string, requireDa if err != nil { return err } - entitys.ResJson(requireData.Ch, z.Name(), string(jsonByte)) + entitys.ResJson(rec.Ch, z.Name(), string(jsonByte)) return nil } diff --git a/internal/tools_bot/dtalk_bot.go b/internal/tools_bot/dtalk_bot.go deleted file mode 100644 index 2eae1d5..0000000 --- a/internal/tools_bot/dtalk_bot.go +++ /dev/null @@ -1,27 +0,0 @@ -package tools_bot - -import ( - "ai_scheduler/internal/config" - "ai_scheduler/internal/data/impl" - "ai_scheduler/internal/entitys" - "ai_scheduler/internal/pkg/utils_ollama" - - "context" -) - -type BotTool struct { - config *config.Config - llm *utils_ollama.Client - sessionImpl *impl.SessionImpl - taskMap map[string]string -} - -// NewBotTool 创建直连天下订单详情工具 -func NewBotTool(config *config.Config, llm *utils_ollama.Client, sessionImpl *impl.SessionImpl) *BotTool { - return &BotTool{config: config, llm: llm, sessionImpl: sessionImpl, taskMap: make(map[string]string)} -} - -// Execute 执行直连天下订单详情查询 -func (w *BotTool) Execute(ctx context.Context, toolName string, requireData *entitys.RequireData) (err error) { - return -}