refactor: optimize tools execution and authentication

This commit is contained in:
renzhiyuan 2025-12-16 11:09:53 +08:00
parent 4b83314849
commit b409266751
41 changed files with 680 additions and 328 deletions

View File

@ -6,6 +6,7 @@ package main
import ( import (
"ai_scheduler/internal/biz" "ai_scheduler/internal/biz"
"ai_scheduler/internal/biz/handle/dingtalk" "ai_scheduler/internal/biz/handle/dingtalk"
"ai_scheduler/internal/biz/tools_regis"
"ai_scheduler/internal/config" "ai_scheduler/internal/config"
"ai_scheduler/internal/data/impl" "ai_scheduler/internal/data/impl"
"ai_scheduler/internal/domain/workflow" "ai_scheduler/internal/domain/workflow"
@ -33,6 +34,7 @@ func InitializeApp(*config.Config, log.AllLogger) (*server.Servers, func(), erro
utils.ProviderUtils, utils.ProviderUtils,
tools_bot.ProviderSetBotTools, tools_bot.ProviderSetBotTools,
dingtalk.ProviderSetDingTalk, dingtalk.ProviderSetDingTalk,
tools_regis.ProviderToolsRegis,
)) ))
} }

View File

@ -3,16 +3,23 @@ package biz
import ( import (
"ai_scheduler/internal/biz/do" "ai_scheduler/internal/biz/do"
"ai_scheduler/internal/biz/handle/dingtalk" "ai_scheduler/internal/biz/handle/dingtalk"
"ai_scheduler/internal/biz/tools_regis"
"ai_scheduler/internal/data/constants" "ai_scheduler/internal/data/constants"
"ai_scheduler/internal/data/impl" "ai_scheduler/internal/data/impl"
"ai_scheduler/internal/data/model" "ai_scheduler/internal/data/model"
"ai_scheduler/internal/entitys" "ai_scheduler/internal/entitys"
"ai_scheduler/internal/tools"
"context" "context"
"database/sql"
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"strings"
"github.com/gofiber/fiber/v2/log" "github.com/gofiber/fiber/v2/log"
"github.com/open-dingtalk/dingtalk-stream-sdk-go/chatbot" "github.com/open-dingtalk/dingtalk-stream-sdk-go/chatbot"
"xorm.io/builder" "xorm.io/builder"
) )
@ -24,6 +31,9 @@ type DingTalkBotBiz struct {
replier *chatbot.ChatbotReplier replier *chatbot.ChatbotReplier
log log.Logger log log.Logger
dingTalkUser *dingtalk.User dingTalkUser *dingtalk.User
botTools []model.AiBotTool
botGroupImpl *impl.BotGroupImpl
toolManager *tools.Manager
} }
// NewDingTalkBotBiz // NewDingTalkBotBiz
@ -31,8 +41,10 @@ func NewDingTalkBotBiz(
do *do.Do, do *do.Do,
handle *do.Handle, handle *do.Handle,
botConfigImpl *impl.BotConfigImpl, botConfigImpl *impl.BotConfigImpl,
botGroupImpl *impl.BotGroupImpl,
dingTalkUser *dingtalk.User, dingTalkUser *dingtalk.User,
tools *tools_regis.ToolRegis,
toolManager *tools.Manager,
) *DingTalkBotBiz { ) *DingTalkBotBiz {
return &DingTalkBotBiz{ return &DingTalkBotBiz{
do: do, do: do,
@ -40,6 +52,9 @@ func NewDingTalkBotBiz(
botConfigImpl: botConfigImpl, botConfigImpl: botConfigImpl,
replier: chatbot.NewChatbotReplier(), replier: chatbot.NewChatbotReplier(),
dingTalkUser: dingTalkUser, dingTalkUser: dingTalkUser,
botTools: tools.BootTools,
botGroupImpl: botGroupImpl,
toolManager: toolManager,
} }
} }
@ -68,18 +83,201 @@ func (d *DingTalkBotBiz) InitRequire(ctx context.Context, data *chatbot.BotCallb
Req: data, Req: data,
Ch: make(chan entitys.Response, 2), Ch: make(chan entitys.Response, 2),
} }
entitys.ResLog(requireData.Ch, "recognize_start", "收到消息,正在处理中,请稍等")
requireData.UserInfo, err = d.dingTalkUser.GetUserInfoFromBot(ctx, data.SenderStaffId, dingtalk.WithId(1))
return return
} }
func (d *DingTalkBotBiz) Recognize(ctx context.Context, bot *chatbot.BotCallbackDataModel) (match entitys.Match, err error) { func (d *DingTalkBotBiz) Do(ctx context.Context, requireData *entitys.RequireDataDingTalkBot) (err error) {
entitys.ResText(requireData.Ch, "", "收到消息,正在处理中,请稍等")
return d.handle.Recognize(ctx, nil, &do.WithDingTalkBot{}) defer close(requireData.Ch)
switch constants.ConversationType(requireData.Req.ConversationType) {
case constants.ConversationTypeSingle:
err = d.handleSingleChat(ctx, requireData)
case constants.ConversationTypeGroup:
err = d.handleGroupChat(ctx, requireData)
default:
err = errors.New("未知的聊天类型:" + requireData.Req.ConversationType)
}
return
} }
func (d *DingTalkBotBiz) handleSingleChat(ctx context.Context, requireData *entitys.RequireDataDingTalkBot) (err error) {
entitys.ResLog(requireData.Ch, "", "个人聊天暂未开启,请期待后续更新")
return
//requireData.UserInfo, err = d.dingTalkUser.GetUserInfoFromBot(ctx, requireData.Req.SenderStaffId, dingtalk.WithId(1))
//if err != nil {
// return
//}
////如果不是管理或者不是老板,则进行权限判断
//if requireData.UserInfo.IsSenior == constants.IsSeniorFalse && requireData.UserInfo.IsBoss == constants.IsBossFalse {
//
//}
//return
}
func (d *DingTalkBotBiz) handleGroupChat(ctx context.Context, requireData *entitys.RequireDataDingTalkBot) (err error) {
group, err := d.initGroup(ctx, requireData.Req.ConversationId, requireData.Req.ConversationTitle)
if err != nil {
return
}
groupTools, err := d.getGroupTools(ctx, group)
if err != nil {
return
}
rec, err := d.recognize(ctx, requireData, groupTools)
if err != nil {
return
}
return d.handleMatch(ctx, rec)
}
func (d *DingTalkBotBiz) initGroup(ctx context.Context, conversationId string, conversationTitle string) (group *model.AiBotGroup, err error) {
group, err = d.botGroupImpl.GetByConversationId(conversationId)
if err != nil {
if !errors.Is(err, sql.ErrNoRows) {
return
}
}
if group.GroupID == 0 {
group = &model.AiBotGroup{
ConversationID: conversationId,
Title: conversationTitle,
ToolList: "",
}
//如果不存在则创建
d.botGroupImpl.Add(group)
}
return
}
func (d *DingTalkBotBiz) getGroupTools(ctx context.Context, group *model.AiBotGroup) (tools []model.AiBotTool, err error) {
if len(d.botTools) == 0 {
return
}
var (
groupRegisTools map[string]struct{}
)
if group.ToolList != "" {
groupList := strings.Split(group.ToolList, ",")
for _, tool := range groupList {
groupRegisTools[tool] = struct{}{}
}
}
for _, v := range d.botTools {
if v.PermissionType == constants.PermissionTypeNone {
tools = append(tools, v)
continue
}
if _, ex := groupRegisTools[v.Index]; ex {
tools = append(tools, v)
}
}
return
}
func (d *DingTalkBotBiz) recognize(ctx context.Context, requireData *entitys.RequireDataDingTalkBot, tools []model.AiBotTool) (rec *entitys.Recognize, err error) {
userContent, err := d.getUserContent(requireData.Req.Msgtype, requireData.Req.Text.Content)
if err != nil {
return nil, err
}
rec = &entitys.Recognize{
Ch: requireData.Ch,
SystemPrompt: d.defaultPrompt(),
UserContent: userContent,
}
if len(tools) > 0 {
rec.Tasks = make([]entitys.RegistrationTask, 0, len(tools))
for _, task := range tools {
taskConfig := entitys.TaskConfigDetail{}
if err = json.Unmarshal([]byte(task.Config), &taskConfig); err != nil {
log.Errorf("解析任务配置失败: %s, 任务ID: %s", err.Error(), task.Index)
continue // 解析失败时跳过该任务,而不是直接返回错误
}
rec.Tasks = append(rec.Tasks, entitys.RegistrationTask{
Name: task.Index,
Desc: task.Desc,
TaskConfigDetail: taskConfig, // 直接使用解析后的配置,避免重复构建
})
}
}
err = d.handle.Recognize(ctx, rec, &do.WithDingTalkBot{})
return
}
func (d *DingTalkBotBiz) getUserContent(msgType string, msgContent interface{}) (content *entitys.RecognizeUserContent, err error) {
switch constants.BotMsgType(msgType) {
case constants.BotMsgTypeText:
content = &entitys.RecognizeUserContent{
Text: msgContent.(string),
}
default:
return nil, errors.New("未知的消息类型:" + msgType)
}
return
}
func (d *DingTalkBotBiz) defaultPrompt() string {
return `{"system":"智能路由系统,精准解析用户意图并路由至任务模块,遵循以下规则:","rule":{"返回格式":"{\\"index\\":\\"工具索引\\",\\"confidence\\":\\"0.0-1.0\\",\\"reasoning\\":\\"判断理由\\",\\"parameters\\":\\"转义JSON参数\\",\\"is_match\\":true|false,\\"chat\\":\\"追问内容\\"}","工具匹配":["用工具parameters匹配区分必选required和可选optional参数","无法匹配时is_match=falsechat提醒用户适用工具'请问您要查询订单还是商品?'"],"参数提取":["从用户输入提取parameters中明确提及的参数","必须参数仅用用户直接提及的缺失时is_match=falsechat提醒补充'需补充XX信息'"],"格式要求":["所有字段值为字符串含confidence","parameters为转义JSON字符串如\\"{\\\\"key\\\\":\\\\"value\\\\"}\\""]}}`
}
func (d *DingTalkBotBiz) handleMatch(ctx context.Context, rec *entitys.Recognize) (err error) {
if !rec.Match.IsMatch {
if len(rec.Match.Chat) != 0 {
entitys.ResText(rec.Ch, "", rec.Match.Chat)
} else {
entitys.ResText(rec.Ch, "", rec.Match.Reasoning)
}
return
}
var pointTask *model.AiBotTool
for _, task := range d.botTools {
if task.Index == rec.Match.Index {
pointTask = &task
break
}
}
if pointTask == nil || pointTask.Index == "other" {
return d.otherTask(ctx, rec)
}
switch constants.TaskType(pointTask.Type) {
//case constants.TaskTypeApi:
//return d.handleApiTask(ctx, requireData, pointTask)
case constants.TaskTypeFunc:
return d.handleTask(ctx, rec, pointTask)
default:
return d.otherTask(ctx, rec)
}
return
}
func (d *DingTalkBotBiz) handleTask(ctx context.Context, rec *entitys.Recognize, task *model.AiBotTool) (err error) {
var configData entitys.ConfigDataTool
err = json.Unmarshal([]byte(task.Config), &configData)
if err != nil {
return
}
err = d.toolManager.ExecuteTool(ctx, configData.Tool, rec)
if err != nil {
return
}
return
}
func (d *DingTalkBotBiz) otherTask(ctx context.Context, rec *entitys.Recognize) (err error) {
entitys.ResText(rec.Ch, "", rec.Match.Reasoning)
return
}
func (d *DingTalkBotBiz) HandleRes(ctx context.Context, data *chatbot.BotCallbackDataModel, resp entitys.Response) error { func (d *DingTalkBotBiz) HandleRes(ctx context.Context, data *chatbot.BotCallbackDataModel, resp entitys.Response) error {
switch resp.Type { switch resp.Type {
case entitys.ResponseText: case entitys.ResponseText:

View File

@ -225,15 +225,6 @@ func (d *Do) GetSysInfo(requireData *entitys.RequireData) (sysInfo model.AiSy, e
return return
} }
func (d *Do) GetSysInfoForDingTalkBot(requireData *entitys.RequireDataDingTalkBot) (sysInfo model.AiSy, err error) {
cond := builder.NewCond()
cond = cond.And(builder.Eq{"app_key": requireData.Auth})
cond = cond.And(builder.IsNull{"delete_at"})
cond = cond.And(builder.Eq{"status": 1})
err = d.sysImpl.GetOneBySearchToStrut(&cond, &sysInfo)
return
}
func (d *Do) getSessionChatHis(requireData *entitys.RequireData) (his []model.AiChatHi, err error) { func (d *Do) getSessionChatHis(requireData *entitys.RequireData) (his []model.AiChatHi, err error) {
cond := builder.NewCond() cond := builder.NewCond()

View File

@ -12,10 +12,11 @@ import (
"ai_scheduler/internal/gateway" "ai_scheduler/internal/gateway"
"ai_scheduler/internal/pkg/l_request" "ai_scheduler/internal/pkg/l_request"
"ai_scheduler/internal/pkg/mapstructure" "ai_scheduler/internal/pkg/mapstructure"
"ai_scheduler/internal/pkg/rec_extra"
"ai_scheduler/internal/pkg/util" "ai_scheduler/internal/pkg/util"
"ai_scheduler/internal/tools" "ai_scheduler/internal/tools"
"ai_scheduler/internal/tools/public" "ai_scheduler/internal/tools/public"
"ai_scheduler/internal/tools_bot"
"context" "context"
"encoding/json" "encoding/json"
"fmt" "fmt"
@ -25,9 +26,9 @@ import (
) )
type Handle struct { type Handle struct {
Ollama *llm_service.OllamaService Ollama *llm_service.OllamaService
toolManager *tools.Manager toolManager *tools.Manager
Bot *tools_bot.BotTool
conf *config.Config conf *config.Config
sessionImpl *impl.SessionImpl sessionImpl *impl.SessionImpl
workflowManager *runtime.Registry workflowManager *runtime.Registry
@ -38,20 +39,20 @@ func NewHandle(
toolManager *tools.Manager, toolManager *tools.Manager,
conf *config.Config, conf *config.Config,
sessionImpl *impl.SessionImpl, sessionImpl *impl.SessionImpl,
dTalkBot *tools_bot.BotTool,
workflowManager *runtime.Registry, workflowManager *runtime.Registry,
) *Handle { ) *Handle {
return &Handle{ return &Handle{
Ollama: Ollama, Ollama: Ollama,
toolManager: toolManager, toolManager: toolManager,
conf: conf, conf: conf,
sessionImpl: sessionImpl, sessionImpl: sessionImpl,
Bot: dTalkBot,
workflowManager: workflowManager, workflowManager: workflowManager,
} }
} }
func (r *Handle) Recognize(ctx context.Context, rec *entitys.Recognize, promptProcessor PromptOption) (match entitys.Match, err error) { func (r *Handle) Recognize(ctx context.Context, rec *entitys.Recognize, promptProcessor PromptOption) (err error) {
entitys.ResLog(rec.Ch, "recognize_start", "准备意图识别") entitys.ResLog(rec.Ch, "recognize_start", "准备意图识别")
prompt, err := promptProcessor.CreatePrompt(ctx, rec) prompt, err := promptProcessor.CreatePrompt(ctx, rec)
@ -65,11 +66,13 @@ func (r *Handle) Recognize(ctx context.Context, rec *entitys.Recognize, promptPr
} }
entitys.ResLog(rec.Ch, "recognize", recognizeMsg) entitys.ResLog(rec.Ch, "recognize", recognizeMsg)
entitys.ResLog(rec.Ch, "recognize_end", "意图识别结束") entitys.ResLog(rec.Ch, "recognize_end", "意图识别结束")
var match entitys.Match
if err = json.Unmarshal([]byte(recognizeMsg), &match); err != nil { if err = json.Unmarshal([]byte(recognizeMsg), &match); err != nil {
err = errors.SysErr("数据结构错误:%v", err.Error()) err = errors.SysErr("数据结构错误:%v", err.Error())
return return
} }
rec.Match = &match
return return
} }
@ -78,28 +81,27 @@ func (r *Handle) handleOtherTask(ctx context.Context, requireData *entitys.Requi
return return
} }
func (r *Handle) HandleMatch(ctx context.Context, client *gateway.Client, requireData *entitys.RequireData) (err error) { func (r *Handle) HandleMatch(ctx context.Context, client *gateway.Client, rec *entitys.Recognize, requireData *entitys.RequireData) (err error) {
if !requireData.Match.IsMatch { if !rec.Match.IsMatch {
if len(requireData.Match.Chat) != 0 { if len(rec.Match.Chat) != 0 {
entitys.ResText(requireData.Ch, "", requireData.Match.Chat) entitys.ResText(rec.Ch, "", rec.Match.Chat)
} else { } else {
entitys.ResText(requireData.Ch, "", requireData.Match.Reasoning) entitys.ResText(rec.Ch, "", rec.Match.Reasoning)
} }
return return
} }
var pointTask *model.AiTask var pointTask *model.AiTask
for _, task := range requireData.Tasks { for _, task := range requireData.Tasks {
if task.Index == requireData.Match.Index { if task.Index == rec.Match.Index {
pointTask = &task pointTask = &task
requireData.Task = task
break break
} }
} }
if pointTask == nil || pointTask.Index == "other" { if pointTask == nil || pointTask.Index == "other" {
return r.OtherTask(ctx, requireData) return r.OtherTask(ctx, rec)
} }
// 校验用户权限 // 校验用户权限
@ -110,47 +112,32 @@ func (r *Handle) HandleMatch(ctx context.Context, client *gateway.Client, requir
switch constants.TaskType(pointTask.Type) { switch constants.TaskType(pointTask.Type) {
case constants.TaskTypeApi: case constants.TaskTypeApi:
return r.handleApiTask(ctx, requireData, pointTask) return r.handleApiTask(ctx, rec, pointTask)
case constants.TaskTypeFunc: case constants.TaskTypeFunc:
return r.handleTask(ctx, requireData, pointTask) return r.handleTask(ctx, rec, pointTask)
case constants.TaskTypeKnowle: case constants.TaskTypeKnowle:
return r.handleKnowle(ctx, requireData, pointTask) return r.handleKnowle(ctx, rec, pointTask)
case constants.TaskTypeBot:
return r.handleBot(ctx, requireData, pointTask)
case constants.TaskTypeEinoWorkflow: case constants.TaskTypeEinoWorkflow:
return r.handleEinoWorkflow(ctx, requireData, pointTask) return r.handleEinoWorkflow(ctx, rec, pointTask)
default: default:
return r.handleOtherTask(ctx, requireData) return r.handleOtherTask(ctx, requireData)
} }
} }
func (r *Handle) OtherTask(ctx context.Context, requireData *entitys.RequireData) (err error) { func (r *Handle) OtherTask(ctx context.Context, requireData *entitys.Recognize) (err error) {
entitys.ResText(requireData.Ch, "", requireData.Match.Reasoning) entitys.ResText(requireData.Ch, "", requireData.Match.Reasoning)
return return
} }
func (r *Handle) handleBot(ctx context.Context, requireData *entitys.RequireData, task *model.AiTask) (err error) { func (r *Handle) handleTask(ctx context.Context, rec *entitys.Recognize, task *model.AiTask) (err error) {
var configData entitys.ConfigDataTool
err = json.Unmarshal([]byte(task.Config), &configData)
if err != nil {
return
}
err = r.Bot.Execute(ctx, configData.Tool, requireData)
if err != nil {
return
}
return
}
func (r *Handle) handleTask(ctx context.Context, requireData *entitys.RequireData, task *model.AiTask) (err error) {
var configData entitys.ConfigDataTool var configData entitys.ConfigDataTool
err = json.Unmarshal([]byte(task.Config), &configData) err = json.Unmarshal([]byte(task.Config), &configData)
if err != nil { if err != nil {
return return
} }
err = r.toolManager.ExecuteTool(ctx, configData.Tool, requireData) err = r.toolManager.ExecuteTool(ctx, configData.Tool, rec)
if err != nil { if err != nil {
return return
} }
@ -159,7 +146,7 @@ func (r *Handle) handleTask(ctx context.Context, requireData *entitys.RequireDat
} }
// 知识库 // 知识库
func (r *Handle) handleKnowle(ctx context.Context, requireData *entitys.RequireData, task *model.AiTask) (err error) { func (r *Handle) handleKnowle(ctx context.Context, rec *entitys.Recognize, task *model.AiTask) (err error) {
var ( var (
configData entitys.ConfigDataTool configData entitys.ConfigDataTool
@ -171,13 +158,16 @@ func (r *Handle) handleKnowle(ctx context.Context, requireData *entitys.RequireD
if err != nil { if err != nil {
return return
} }
ext, err := rec_extra.GetTaskRecExt(rec)
if err != nil {
return
}
// 通过session 找到知识库session // 通过session 找到知识库session
var has bool var has bool
if len(requireData.Session) == 0 { if len(ext.Session) == 0 {
return errors.SessionNotFound return errors.SessionNotFound
} }
requireData.SessionInfo, has, err = r.sessionImpl.FindOne(r.sessionImpl.WithSessionId(requireData.Session)) ext.SessionInfo, has, err = r.sessionImpl.FindOne(r.sessionImpl.WithSessionId(ext.Session))
if err != nil { if err != nil {
return return
} else if !has { } else if !has {
@ -200,15 +190,15 @@ func (r *Handle) handleKnowle(ctx context.Context, requireData *entitys.RequireD
} }
// 知识库的session为空请求知识库获取, 并绑定 // 知识库的session为空请求知识库获取, 并绑定
if requireData.SessionInfo.KnowlegeSessionID == "" { if ext.SessionInfo.KnowlegeSessionID == "" {
// 请求知识库 // 请求知识库
if sessionIdKnowledge, err = public.GetKnowledgeBaseSession(host, requireData.Sys.KnowlegeBaseID, requireData.Sys.KnowlegeTenantKey); err != nil { if sessionIdKnowledge, err = public.GetKnowledgeBaseSession(host, ext.Sys.KnowlegeBaseID, ext.Sys.KnowlegeTenantKey); err != nil {
return return
} }
// 绑定知识库session下次可以使用 // 绑定知识库session下次可以使用
requireData.SessionInfo.KnowlegeSessionID = sessionIdKnowledge ext.SessionInfo.KnowlegeSessionID = sessionIdKnowledge
if err = r.sessionImpl.Update(&requireData.SessionInfo, r.sessionImpl.WithSessionId(requireData.SessionInfo.SessionID)); err != nil { if err = r.sessionImpl.Update(&ext.SessionInfo, r.sessionImpl.WithSessionId(ext.SessionInfo.SessionID)); err != nil {
return return
} }
} }
@ -216,21 +206,21 @@ func (r *Handle) handleKnowle(ctx context.Context, requireData *entitys.RequireD
// 用户输入解析 // 用户输入解析
var ok bool var ok bool
input := make(map[string]string) input := make(map[string]string)
if err = json.Unmarshal([]byte(requireData.Match.Parameters), &input); err != nil { if err = json.Unmarshal([]byte(rec.Match.Parameters), &input); err != nil {
return return
} }
if query, ok = input["query"]; !ok { if query, ok = input["query"]; !ok {
return fmt.Errorf("query不能为空") return fmt.Errorf("query不能为空")
} }
requireData.KnowledgeConf = entitys.KnowledgeBaseRequest{ ext.KnowledgeConf = entitys.KnowledgeBaseRequest{
Session: requireData.SessionInfo.KnowlegeSessionID, Session: ext.SessionInfo.KnowlegeSessionID,
ApiKey: requireData.Sys.KnowlegeTenantKey, ApiKey: ext.Sys.KnowlegeTenantKey,
Query: query, Query: query,
} }
// 执行工具 // 执行工具
err = r.toolManager.ExecuteTool(ctx, configData.Tool, requireData) err = r.toolManager.ExecuteTool(ctx, configData.Tool, rec)
if err != nil { if err != nil {
return return
} }
@ -238,17 +228,21 @@ func (r *Handle) handleKnowle(ctx context.Context, requireData *entitys.RequireD
return return
} }
func (r *Handle) handleApiTask(ctx context.Context, requireData *entitys.RequireData, task *model.AiTask) (err error) { func (r *Handle) handleApiTask(ctx context.Context, rec *entitys.Recognize, task *model.AiTask) (err error) {
var ( var (
request l_request.Request request l_request.Request
requestParam map[string]interface{} requestParam map[string]interface{}
) )
err = json.Unmarshal([]byte(requireData.Match.Parameters), &requestParam) ext, err := rec_extra.GetTaskRecExt(rec)
if err != nil {
return
}
err = json.Unmarshal([]byte(rec.Match.Parameters), &requestParam)
if err != nil { if err != nil {
return return
} }
// request.Url = strings.ReplaceAll(task.Config, "${authorization}", requireData.Auth) // request.Url = strings.ReplaceAll(task.Config, "${authorization}", requireData.Auth)
task.Config = strings.ReplaceAll(task.Config, "${authorization}", requireData.Auth) task.Config = strings.ReplaceAll(task.Config, "${authorization}", ext.Auth)
for k, v := range requestParam { for k, v := range requestParam {
if vStr, ok := v.(string); ok { if vStr, ok := v.(string); ok {
task.Config = strings.ReplaceAll(task.Config, "${"+k+"}", vStr) task.Config = strings.ReplaceAll(task.Config, "${"+k+"}", vStr)
@ -275,27 +269,31 @@ func (r *Handle) handleApiTask(ctx context.Context, requireData *entitys.Require
return return
} }
entitys.ResLoading(requireData.Ch, requireData.Task.Index, "正在请求数据") entitys.ResLoading(rec.Ch, task.Index, "正在请求数据")
res, err := request.Send() res, err := request.Send()
if err != nil { if err != nil {
return return
} }
entitys.ResJson(requireData.Ch, requireData.Task.Index, res.Text) entitys.ResJson(rec.Ch, task.Index, res.Text)
return return
} }
// eino 工作流 // eino 工作流
func (r *Handle) handleEinoWorkflow(ctx context.Context, requireData *entitys.RequireData, task *model.AiTask) (err error) { func (r *Handle) handleEinoWorkflow(ctx context.Context, rec *entitys.Recognize, task *model.AiTask) (err error) {
// token 写入ctx // token 写入ctx
ctx = util.SetTokenToContext(ctx, requireData.Auth) ext, err := rec_extra.GetTaskRecExt(rec)
if err != nil {
return
}
ctx = util.SetTokenToContext(ctx, ext.Auth)
entitys.ResLoading(requireData.Ch, requireData.Task.Index, "正在执行工作流") entitys.ResLoading(rec.Ch, task.Index, "正在执行工作流")
// 工作流内部输出 // 工作流内部输出
workflowId := task.Index workflowId := task.Index
_, err = r.workflowManager.Invoke(ctx, workflowId, requireData) _, err = r.workflowManager.Invoke(ctx, workflowId, rec)
if err != nil { if err != nil {
return err return err
} }

View File

@ -41,8 +41,9 @@ func (d *Dept) GetDeptInfoByDeptIds(ctx context.Context, deptIds []int, authInfo
var existDept = make([]int, len(deptsInfo), 0) var existDept = make([]int, len(deptsInfo), 0)
for _, dept := range deptsInfo { for _, dept := range deptsInfo {
depts = append(depts, &entitys.Dept{ depts = append(depts, &entitys.Dept{
DeptId: int(dept.DeptID), DeptId: int(dept.DeptID),
Name: dept.Name, Name: dept.Name,
ToolList: dept.ToolList,
}) })
existDept = append(existDept, int(dept.DeptID)) existDept = append(existDept, int(dept.DeptID))
} }

View File

@ -46,12 +46,14 @@ func (u *User) GetUserInfoFromBot(ctx context.Context, staffId string, botOption
return return
} }
} }
//待优化
authInfo, err := u.auth.GetTokenFromBotOption(ctx, botOption...) authInfo, err := u.auth.GetTokenFromBotOption(ctx, botOption...)
if err != nil || authInfo == nil { if err != nil || authInfo == nil {
return return
} }
//如果没有找到,则新增 //如果没有找到,则新增
if user == nil { if user == nil {
DingUserInfo, _err := u.getUserInfoFromDingTalk(ctx, authInfo.AccessToken, staffId) DingUserInfo, _err := u.getUserInfoFromDingTalk(ctx, authInfo.AccessToken, staffId)
if _err != nil { if _err != nil {
return nil, _err return nil, _err

View File

@ -5,6 +5,8 @@ import (
"ai_scheduler/internal/data/constants" "ai_scheduler/internal/data/constants"
errors "ai_scheduler/internal/data/error" errors "ai_scheduler/internal/data/error"
"ai_scheduler/internal/gateway" "ai_scheduler/internal/gateway"
"ai_scheduler/internal/pkg/rec_extra"
"context" "context"
"encoding/json" "encoding/json"
"strings" "strings"
@ -68,14 +70,15 @@ func (r *AiRouterBiz) RouteWithSocket(client *gateway.Client, req *entitys.ChatS
} }
//意图识别 //意图识别
requireData.Match, err = r.handle.Recognize(ctx, &rec, sys) err = r.handle.Recognize(ctx, &rec, sys)
if err != nil { if err != nil {
log.Errorf("意图识别失败: %s", err.Error()) log.Errorf("意图识别失败: %s", err.Error())
return return
} }
//任务处理
rec_extra.SetTaskRecExt(requireData, &rec)
//向下传递 //向下传递
if err = r.handle.HandleMatch(ctx, client, requireData); err != nil { if err = r.handle.HandleMatch(ctx, client, &rec, requireData); err != nil {
log.Errorf("任务处理失败: %s", err.Error()) log.Errorf("任务处理失败: %s", err.Error())
return return
} }

View File

@ -0,0 +1,9 @@
package tools_regis
import (
"github.com/google/wire"
)
var ProviderToolsRegis = wire.NewSet(
NewToolsRegis,
)

View File

@ -0,0 +1,30 @@
package tools_regis
import (
"ai_scheduler/internal/data/constants"
"ai_scheduler/internal/data/impl"
"ai_scheduler/internal/data/model"
"xorm.io/builder"
)
type ToolRegis struct {
//待优化
BootTools []model.AiBotTool
}
func NewToolsRegis(botToolsImpl *impl.BotToolsImpl) *ToolRegis {
botTools := &ToolRegis{}
err := botTools.RegisTools(botToolsImpl)
if err != nil {
panic(err)
}
return botTools
}
func (t *ToolRegis) RegisTools(botToolsImpl *impl.BotToolsImpl) error {
cond := builder.NewCond()
cond = cond.And(builder.Eq{"status": constants.Enable})
err := botToolsImpl.GetRangeToMapStruct(&cond, &t.BootTools)
return err
}

View File

@ -33,3 +33,11 @@ const (
) )
const DingTalkAuthBaseKeyPrefix = "dingTalk_auth" const DingTalkAuthBaseKeyPrefix = "dingTalk_auth"
// PermissionType 工具使用权限
type PermissionType int32
const (
PermissionTypeNone = 1
PermissionTypeDept = 2
)

View File

@ -36,3 +36,16 @@ const (
IsSeniorTrue IsSenior = 1 IsSeniorTrue IsSenior = 1
IsSeniorFalse IsSenior = 0 IsSeniorFalse IsSenior = 0
) )
type ConversationType string
const (
ConversationTypeSingle = "1" // 单聊
ConversationTypeGroup = "2" //群聊
)
type BotMsgType string
const (
BotMsgTypeText BotMsgType = "text"
)

View File

@ -0,0 +1,27 @@
package impl
import (
"ai_scheduler/internal/data/model"
"ai_scheduler/tmpl/dataTemp"
"ai_scheduler/utils"
"database/sql"
)
type BotGroupImpl struct {
dataTemp.DataTemp
}
func NewBotGroupImpl(db *utils.Db) *BotGroupImpl {
return &BotGroupImpl{
DataTemp: *dataTemp.NewDataTemp(db, new(model.AiBotGroup)),
}
}
func (k BotGroupImpl) GetByConversationId(staffId string) (*model.AiBotGroup, error) {
var data model.AiBotGroup
err := k.Db.Model(k.Model).Where("conversation_id = ?", staffId).Find(&data).Error
if data.GroupID == 0 {
err = sql.ErrNoRows
}
return &data, err
}

View File

@ -0,0 +1,17 @@
package impl
import (
"ai_scheduler/internal/data/model"
"ai_scheduler/tmpl/dataTemp"
"ai_scheduler/utils"
)
type BotToolsImpl struct {
dataTemp.DataTemp
}
func NewBotToolsImpl(db *utils.Db) *BotToolsImpl {
return &BotToolsImpl{
DataTemp: *dataTemp.NewDataTemp(db, new(model.AiBotTool)),
}
}

View File

@ -17,11 +17,11 @@ func NewBotUserImpl(db *utils.Db) *BotUserImpl {
} }
} }
func (k BotUserImpl) GetByStaffId(staffId string) (data *model.AiBotUser, err error) { func (k BotUserImpl) GetByStaffId(staffId string) (*model.AiBotUser, error) {
var data model.AiBotUser
err = k.Db.Model(k.Model).Where("staff_id = ?", staffId).Find(data).Error err := k.Db.Model(k.Model).Where("staff_id = ?", staffId).Find(&data).Error
if data == nil { if data.UserID == 0 {
err = sql.ErrNoRows err = sql.ErrNoRows
} }
return return &data, err
} }

View File

@ -13,4 +13,6 @@ var ProviderImpl = wire.NewSet(
NewBotDeptImpl, NewBotDeptImpl,
NewBotUserImpl, NewBotUserImpl,
NewBotChatHisImpl, NewBotChatHisImpl,
NewBotToolsImpl,
NewBotGroupImpl,
) )

View File

@ -12,12 +12,13 @@ const TableNameAiBotDept = "ai_bot_dept"
// AiBotDept mapped from table <ai_bot_dept> // AiBotDept mapped from table <ai_bot_dept>
type AiBotDept struct { type AiBotDept struct {
DeptID int32 `gorm:"column:dept_id;primaryKey" json:"dept_id"` DeptID int32 `gorm:"column:dept_id;primaryKey;autoIncrement:true" json:"dept_id"`
DingtalkDeptID int32 `gorm:"column:dingtalk_dept_id;not null;comment:标记部门的唯一id钉钉钉钉侧提供的dept_id" json:"dingtalk_dept_id"` // 标记部门的唯一id钉钉钉钉侧提供的dept_id DingtalkDeptID int32 `gorm:"column:dingtalk_dept_id;not null;comment:标记部门的唯一id钉钉钉钉侧提供的dept_id" json:"dingtalk_dept_id"` // 标记部门的唯一id钉钉钉钉侧提供的dept_id
Name string `gorm:"column:name;not null;comment:用户名称" json:"name"` // 用户名称 Name string `gorm:"column:name;not null;comment:用户名称" json:"name"` // 用户名称
Status int32 `gorm:"column:status;not null" json:"status"` ToolList string `gorm:"column:tool_list;not null;comment:该部门支持的权限" json:"tool_list"` // 该部门支持的权限
DeleteAt time.Time `gorm:"column:delete_at" json:"delete_at"` Status int32 `gorm:"column:status;not null;default:1" json:"status"`
CreateAt time.Time `gorm:"column:create_at;default:CURRENT_TIMESTAMP" json:"create_at"` DeleteAt *time.Time `gorm:"column:delete_at" json:"delete_at"`
CreateAt time.Time `gorm:"column:create_at;default:CURRENT_TIMESTAMP" json:"create_at"`
} }
// TableName AiBotDept's table name // TableName AiBotDept's table name

View File

@ -0,0 +1,27 @@
// Code generated by gorm.io/gen. DO NOT EDIT.
// Code generated by gorm.io/gen. DO NOT EDIT.
// Code generated by gorm.io/gen. DO NOT EDIT.
package model
import (
"time"
)
const TableNameAiBotGroup = "ai_bot_group"
// AiBotGroup mapped from table <ai_bot_group>
type AiBotGroup struct {
GroupID int32 `gorm:"column:group_id;primaryKey;autoIncrement:true" json:"group_id"`
ConversationID string `gorm:"column:conversation_id;not null;comment:会话ID" json:"conversation_id"` // 会话ID
Title string `gorm:"column:title;not null;comment:群名称" json:"title"` // 群名称
ToolList string `gorm:"column:tool_list;not null;comment:开通工具列表" json:"tool_list"` // 开通工具列表
Status int32 `gorm:"column:status;not null;default:1" json:"status"`
DeleteAt *time.Time `gorm:"column:delete_at" json:"delete_at"`
CreateAt time.Time `gorm:"column:create_at;default:CURRENT_TIMESTAMP" json:"create_at"`
}
// TableName AiBotGroup's table name
func (*AiBotGroup) TableName() string {
return TableNameAiBotGroup
}

View File

@ -0,0 +1,32 @@
// Code generated by gorm.io/gen. DO NOT EDIT.
// Code generated by gorm.io/gen. DO NOT EDIT.
// Code generated by gorm.io/gen. DO NOT EDIT.
package model
import (
"time"
)
const TableNameAiBotTool = "ai_bot_tools"
// AiBotTool mapped from table <ai_bot_tools>
type AiBotTool struct {
ToolID int32 `gorm:"column:tool_id;primaryKey;autoIncrement:true" json:"tool_id"`
PermissionType int32 `gorm:"column:permission_type;not null;comment:类型1为公共工具不需要进行权限管理反之则为2" json:"permission_type"` // 类型1为公共工具不需要进行权限管理反之则为2
Config string `gorm:"column:config;not null;comment:类型下所需路由以及参数" json:"config"` // 类型下所需路由以及参数
Type int32 `gorm:"column:type;not null;default:3" json:"type"`
Name string `gorm:"column:name;not null;default:1;comment:工具名称" json:"name"` // 工具名称
Index string `gorm:"column:index;not null;comment:索引" json:"index"` // 索引
Desc string `gorm:"column:desc;not null;comment:工具描述" json:"desc"` // 工具描述
TempPrompt string `gorm:"column:temp_prompt;not null;comment:提示词模板" json:"temp_prompt"` // 提示词模板
CreateAt time.Time `gorm:"column:create_at;default:CURRENT_TIMESTAMP" json:"create_at"`
UpdatedAt time.Time `gorm:"column:updated_at;default:CURRENT_TIMESTAMP" json:"updated_at"`
Status int32 `gorm:"column:status;not null" json:"status"`
DeleteAt time.Time `gorm:"column:delete_at" json:"delete_at"`
}
// TableName AiBotTool's table name
func (*AiBotTool) TableName() string {
return TableNameAiBotTool
}

View File

@ -12,7 +12,7 @@ import (
type Workflow interface { type Workflow interface {
ID() string ID() string
Schema() map[string]any Schema() map[string]any
Invoke(ctx context.Context, requireData *entitys.RequireData) (map[string]any, error) Invoke(ctx context.Context, requireData *entitys.Recognize) (map[string]any, error)
} }
type Deps struct { type Deps struct {
@ -63,7 +63,7 @@ func Default() *Registry {
return r return r
} }
func (r *Registry) Invoke(ctx context.Context, id string, requireData *entitys.RequireData) (map[string]any, error) { func (r *Registry) Invoke(ctx context.Context, id string, rec *entitys.Recognize) (map[string]any, error) {
regMu.RLock() regMu.RLock()
f, ok := factories[id] f, ok := factories[id]
regMu.RUnlock() regMu.RUnlock()
@ -89,5 +89,5 @@ func (r *Registry) Invoke(ctx context.Context, id string, requireData *entitys.R
r.mu.Unlock() r.mu.Unlock()
} }
return w.Invoke(ctx, requireData) return w.Invoke(ctx, rec)
} }

View File

@ -31,7 +31,7 @@ type OrderAfterSaleResellerBatchWorkflowInput struct {
Ch chan entitys.Response // 响应通道 Ch chan entitys.Response // 响应通道
UserInput string // 用户输入文本 UserInput string // 用户输入文本
FileContent string // 文件解析结果 FileContent string // 文件解析结果
UserHistory []model.AiChatHi // 用户对话历史 UserHistory entitys.ChatHis // 用户对话历史
ParameterResult string // 参数解析结果 ParameterResult string // 参数解析结果
Data *OrderAfterSaleResellerBatchNodeData // 节点所需参数 Data *OrderAfterSaleResellerBatchNodeData // 节点所需参数
} }
@ -88,7 +88,7 @@ func (o *orderAfterSaleResellerBatch) Schema() map[string]any {
} }
// Invoke 调用原有编排工作流并规范化输出 // Invoke 调用原有编排工作流并规范化输出
func (o *orderAfterSaleResellerBatch) Invoke(ctx context.Context, requireData *entitys.RequireData) (map[string]any, error) { func (o *orderAfterSaleResellerBatch) Invoke(ctx context.Context, rec *entitys.Recognize) (map[string]any, error) {
// 构建工作流 // 构建工作流
chain, err := o.buildWorkflow(ctx) chain, err := o.buildWorkflow(ctx)
if err != nil { if err != nil {
@ -96,11 +96,11 @@ func (o *orderAfterSaleResellerBatch) Invoke(ctx context.Context, requireData *e
} }
o.data = &OrderAfterSaleResellerBatchWorkflowInput{ o.data = &OrderAfterSaleResellerBatchWorkflowInput{
Ch: requireData.Ch, Ch: rec.Ch,
UserInput: requireData.Req.Text, UserInput: rec.UserContent.Text,
FileContent: "", FileContent: "",
UserHistory: requireData.Histories, UserHistory: rec.ChatHis,
ParameterResult: requireData.Match.Parameters, ParameterResult: rec.Match.Parameters,
} }
// 工作流过程输出,不关注最终输出 // 工作流过程输出,不关注最终输出
_, err = chain.Invoke(ctx, o.data) _, err = chain.Invoke(ctx, o.data)

View File

@ -3,21 +3,16 @@ package entitys
import ( import (
"ai_scheduler/internal/data/model" "ai_scheduler/internal/data/model"
"github.com/ollama/ollama/api"
"github.com/open-dingtalk/dingtalk-stream-sdk-go/chatbot" "github.com/open-dingtalk/dingtalk-stream-sdk-go/chatbot"
) )
type RequireDataDingTalkBot struct { type RequireDataDingTalkBot struct {
Histories []model.AiChatHi Histories []model.AiChatHi
UserInfo *DingTalkUserInfo UserInfo *DingTalkUserInfo
Tasks []model.AiTask Tools []model.AiBotTool
Match *Match Match *Match
Req *chatbot.BotCallbackDataModel Req *chatbot.BotCallbackDataModel
Auth string Ch chan Response
Ch chan Response
KnowledgeConf KnowledgeBaseRequest
ImgByte []api.ImageData
ImgUrls []string
} }
type DingTalkBot struct { type DingTalkBot struct {

View File

@ -17,6 +17,7 @@ type DingTalkUserInfo struct {
} }
type Dept struct { type Dept struct {
Name string `json:"name"` Name string `json:"name"`
DeptId int `json:"dept_id"` DeptId int `json:"dept_id"`
ToolList string `json:"tool_list"`
} }

View File

@ -1,6 +1,9 @@
package entitys package entitys
import "ai_scheduler/internal/data/constants" import (
"ai_scheduler/internal/data/constants"
"ai_scheduler/internal/data/model"
)
type Recognize struct { type Recognize struct {
SystemPrompt string // 系统提示内容 SystemPrompt string // 系统提示内容
@ -8,11 +11,23 @@ type Recognize struct {
ChatHis ChatHis // 会话历史记录 ChatHis ChatHis // 会话历史记录
Tasks []RegistrationTask Tasks []RegistrationTask
Ch chan Response Ch chan Response
Match *Match
Ext []byte
}
type TaskExt struct {
Auth string `json:"auth"`
Session string `json:"session"`
Key string `json:"key"`
SessionInfo model.AiSession
Sys model.AiSy
KnowledgeConf KnowledgeBaseRequest
} }
type RegistrationTask struct { type RegistrationTask struct {
Name string Name string
Desc string Desc string
Index string
TaskConfigDetail TaskConfigDetail TaskConfigDetail TaskConfigDetail
} }
@ -26,6 +41,7 @@ type RecognizeUserContent struct {
type FileData []byte type FileData []byte
type RecognizeFile struct { type RecognizeFile struct {
FileRec string //文件识别内容
FileData FileData // 文件数据(二进制格式) FileData FileData // 文件数据(二进制格式)
FileType constants.Caller // 文件类型(文件类型,能填最好填,可以跳过一层判断) FileType constants.Caller // 文件类型(文件类型,能填最好填,可以跳过一层判断)
FileUrl string // 文件下载链接 FileUrl string // 文件下载链接

View File

@ -25,6 +25,9 @@ const (
) )
func ResLog(ch chan Response, index string, content string) { func ResLog(ch chan Response, index string, content string) {
if ch == nil {
return
}
ch <- Response{ ch <- Response{
Index: index, Index: index,
Content: content, Content: content,
@ -33,6 +36,9 @@ func ResLog(ch chan Response, index string, content string) {
} }
func ResStream(ch chan Response, index string, content string) { func ResStream(ch chan Response, index string, content string) {
if ch == nil {
return
}
ch <- Response{ ch <- Response{
Index: index, Index: index,
Content: content, Content: content,
@ -41,6 +47,9 @@ func ResStream(ch chan Response, index string, content string) {
} }
func ResJson(ch chan Response, index string, content string) { func ResJson(ch chan Response, index string, content string) {
if ch == nil {
return
}
ch <- Response{ ch <- Response{
Index: index, Index: index,
Content: content, Content: content,

View File

@ -78,7 +78,7 @@ type Tool interface {
Name() string Name() string
Description() string Description() string
Definition() ToolDefinition Definition() ToolDefinition
Execute(ctx context.Context, requireData *RequireData) error Execute(ctx context.Context, requireData *Recognize) error
} }
type ConfigDataHttp struct { type ConfigDataHttp struct {

View File

@ -0,0 +1,22 @@
package rec_extra
import (
"ai_scheduler/internal/entitys"
"ai_scheduler/internal/pkg"
"encoding/json"
)
func SetTaskRecExt(requireData *entitys.RequireData, rec *entitys.Recognize) {
TaskExt := entitys.TaskExt{
Auth: requireData.Auth,
Session: requireData.Session,
Key: requireData.Key,
Sys: requireData.Sys,
}
rec.Ext = pkg.JsonByteIgonErr(TaskExt)
}
func GetTaskRecExt(rec *entitys.Recognize) (ext *entitys.TaskExt, err error) {
err = json.Unmarshal(rec.Ext, ext)
return ext, err
}

View File

@ -2,18 +2,19 @@ package services
import ( import (
"ai_scheduler/internal/biz" "ai_scheduler/internal/biz"
"log"
"time"
"ai_scheduler/internal/config" "ai_scheduler/internal/config"
"ai_scheduler/internal/entitys" "ai_scheduler/internal/entitys"
"context" "context"
"fmt"
"github.com/open-dingtalk/dingtalk-stream-sdk-go/chatbot" "github.com/open-dingtalk/dingtalk-stream-sdk-go/chatbot"
) )
type DingBotService struct { type DingBotService struct {
config *config.Config config *config.Config
dingTalkBotBiz *biz.DingTalkBotBiz dingTalkBotBiz *biz.DingTalkBotBiz
} }
@ -30,38 +31,42 @@ func (d *DingBotService) OnChatBotMessageReceived(ctx context.Context, data *cha
if err != nil { if err != nil {
return return
} }
// 使用 ctx.Done() 通知 Do 方法提前终止
subCtx, cancel := context.WithCancel(ctx)
defer cancel()
// 异步执行 Do 方法
done := make(chan error, 1)
go func() { go func() {
//defer close(requireData.Ch) done <- d.dingTalkBotBiz.Do(subCtx, requireData)
//if match, _err := d.dingTalkBotBiz.Recognize(ctx, data); _err != nil {
// requireData.Ch <- entitys.Response{
// Type: entitys.ResponseEnd,
// Content: fmt.Sprintf("处理消息时出错: %s", _err.Error()),
// }
//}
////向下传递
//if err = d.dingTalkBotBiz.HandleMatch(ctx, nil, requireData); err != nil {
// requireData.Ch <- entitys.Response{
// Type: entitys.ResponseEnd,
// Content: fmt.Sprintf("匹配失败: %v", err),
// }
//}
}() }()
var lastErr error
for { for {
select { select {
case <-ctx.Done(): case <-ctx.Done():
return nil, ctx.Err() lastErr = ctx.Err()
goto cleanup
case resp, ok := <-requireData.Ch: case resp, ok := <-requireData.Ch:
if !ok { if !ok {
return []byte("success"), nil // 通道关闭,处理完成 return []byte("success"), nil
} }
if resp.Type == entitys.ResponseLog { if resp.Type == entitys.ResponseLog {
return continue
} }
if err := d.dingTalkBotBiz.HandleRes(ctx, data, resp); err != nil { if err := d.dingTalkBotBiz.HandleRes(ctx, data, resp); err != nil {
return nil, fmt.Errorf("回复失败: %w", err) log.Printf("HandleRes 失败: %v", err)
} }
} }
} }
return cleanup:
select {
case err := <-done:
if err != nil {
log.Printf("Do 方法执行失败: %v", err)
}
case <-time.After(1 * time.Second):
log.Println("警告:等待 Do 方法超时,可能发生 goroutine 泄漏")
}
return nil, lastErr
} }

View File

@ -24,24 +24,6 @@ func NewManager(config *config.Config, llm *utils_ollama.Client) *Manager {
llm: llm, llm: llm,
} }
// 注册天气工具
//if config.Tools.Weather.Enabled {
// weatherTool := NewWeatherTool()
// m.tools[weatherTool.Name()] = weatherTool
//}
//
//// 注册计算器工具
//if config.Tools.Calculator.Enabled {
// calcTool := NewCalculatorTool()
// m.tools[calcTool.Name()] = calcTool
//}
// 注册知识库工具
// if config.Knowledge.Enabled {
// knowledgeTool := NewKnowledgeTool()
// m.tools[knowledgeTool.Name()] = knowledgeTool
// }
// 注册直连天下订单详情工具 // 注册直连天下订单详情工具
if config.Tools.ZltxOrderDetail.Enabled { if config.Tools.ZltxOrderDetail.Enabled {
zltxOrderDetailTool := zltxtool.NewZltxOrderDetailTool(config.Tools.ZltxOrderDetail, m.llm) zltxOrderDetailTool := zltxtool.NewZltxOrderDetailTool(config.Tools.ZltxOrderDetail, m.llm)
@ -115,63 +97,12 @@ func (m *Manager) GetTool(name string) (entitys.Tool, bool) {
return tool, exists return tool, exists
} }
// GetAllTools 获取所有工具
// func (m *Manager) GetAllTools() []entitys.Tool {
// tools := make([]entitys.Tool, 0, len(m.tools))
// for _, tool := range m.tools {
// tools = append(tools, tool)
// }
// return tools
// }
// // GetToolDefinitions 获取所有工具定义
// func (m *Manager) GetToolDefinitions(caller constants.Caller) []entitys.ToolDefinition {
// definitions := make([]entitys.ToolDefinition, 0, len(m.tools))
// for _, tool := range m.tools {
// definitions = append(definitions, tool.Definition())
// }
// return definitions
// }
// ExecuteTool 执行工具 // ExecuteTool 执行工具
func (m *Manager) ExecuteTool(ctx context.Context, name string, requireData *entitys.RequireData) error { func (m *Manager) ExecuteTool(ctx context.Context, name string, rec *entitys.Recognize) error {
tool, exists := m.GetTool(name) tool, exists := m.GetTool(name)
if !exists { if !exists {
return fmt.Errorf("tool not found: %s", name) return fmt.Errorf("tool not found: %s", name)
} }
return tool.Execute(ctx, requireData) return tool.Execute(ctx, rec)
} }
// ExecuteToolCalls 执行多个工具调用
//func (m *Manager) ExecuteToolCalls(c *websocket.Conn, toolCalls []entitys.ToolCall) ([]entitys.ToolCall, error) {
// results := make([]entitys.ToolCall, len(toolCalls))
//
// for i, toolCall := range toolCalls {
// results[i] = toolCall
//
// // 执行工具
// err := m.ExecuteTool(c, toolCall.Function.Name, toolCall.Function.Arguments)
// if err != nil {
// // 将错误信息作为结果返回
// errorResult := map[string]interface{}{
// "error": err.Error(),
// }
// resultBytes, _ := json.Marshal(errorResult)
// results[i].Result = resultBytes
// } else {
// // 将成功结果序列化
// resultBytes, err := json.Marshal(result)
// if err != nil {
// errorResult := map[string]interface{}{
// "error": fmt.Sprintf("failed to serialize result: %v", err),
// }
// resultBytes, _ = json.Marshal(errorResult)
// }
// results[i].Result = resultBytes
// }
// }
//
// return results, nil
//}

View File

@ -7,10 +7,11 @@ import (
"context" "context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"github.com/ollama/ollama/api"
"net/http" "net/http"
"time" "time"
"github.com/ollama/ollama/api"
"github.com/coze-dev/coze-go" "github.com/coze-dev/coze-go"
) )
@ -70,7 +71,7 @@ func (c *CozeCompany) Definition() entitys.ToolDefinition {
} }
// Execute 执行查询 // Execute 执行查询
func (c *CozeCompany) Execute(ctx context.Context, requireData *entitys.RequireData) error { func (c *CozeCompany) Execute(ctx context.Context, requireData *entitys.Recognize) error {
var req map[string]interface{} var req map[string]interface{}
if err := json.Unmarshal([]byte(requireData.Match.Parameters), &req); err != nil { if err := json.Unmarshal([]byte(requireData.Match.Parameters), &req); err != nil {
return fmt.Errorf("invalid express request: %w", err) return fmt.Errorf("invalid express request: %w", err)
@ -191,7 +192,7 @@ Top3核心风险1. 根据司法信息近1年作为被告的合同纠纷占
}, },
{ {
Role: "user", Role: "user",
Content: requireData.Req.Text, Content: requireData.UserContent.Text,
}, },
}, },
c.Name(), "") c.Name(), "")

View File

@ -8,6 +8,7 @@ import (
"context" "context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
"github.com/coze-dev/coze-go" "github.com/coze-dev/coze-go"
@ -67,7 +68,7 @@ func (c *CozeExpress) Definition() entitys.ToolDefinition {
} }
// Execute 执行查询 // Execute 执行查询
func (c *CozeExpress) Execute(ctx context.Context, requireData *entitys.RequireData) error { func (c *CozeExpress) Execute(ctx context.Context, requireData *entitys.Recognize) error {
var req map[string]interface{} var req map[string]interface{}
if err := json.Unmarshal([]byte(requireData.Match.Parameters), &req); err != nil { if err := json.Unmarshal([]byte(requireData.Match.Parameters), &req); err != nil {
return fmt.Errorf("invalid express request: %w", err) return fmt.Errorf("invalid express request: %w", err)
@ -89,7 +90,7 @@ func (c *CozeExpress) Execute(ctx context.Context, requireData *entitys.RequireD
}, },
{ {
Role: "assistant", Role: "assistant",
Content: fmt.Sprintf("聊天记录:%s", pkg.JsonStringIgonErr(requireData.Histories)), Content: fmt.Sprintf("聊天记录:%s", pkg.JsonStringIgonErr(requireData.ChatHis)),
}, },
{ {
Role: "assistant", Role: "assistant",
@ -97,7 +98,7 @@ func (c *CozeExpress) Execute(ctx context.Context, requireData *entitys.RequireD
}, },
{ {
Role: "user", Role: "user",
Content: requireData.Req.Text, Content: requireData.UserContent.Text,
}, },
}, c.Name(), "") }, c.Name(), "")
if err != nil { if err != nil {

View File

@ -4,6 +4,7 @@ import (
"ai_scheduler/internal/config" "ai_scheduler/internal/config"
"ai_scheduler/internal/entitys" "ai_scheduler/internal/entitys"
"ai_scheduler/internal/pkg/l_request" "ai_scheduler/internal/pkg/l_request"
"ai_scheduler/internal/pkg/rec_extra"
"bufio" "bufio"
"context" "context"
"encoding/json" "encoding/json"
@ -58,7 +59,7 @@ func (k *KnowledgeBaseTool) Definition() entitys.ToolDefinition {
} }
// Execute 执行知识库查询 // Execute 执行知识库查询
func (k *KnowledgeBaseTool) Execute(ctx context.Context, requireData *entitys.RequireData) error { func (k *KnowledgeBaseTool) Execute(ctx context.Context, requireData *entitys.Recognize) error {
entitys.ResLoading(requireData.Ch, k.Name(), "正在为您搜索相关信息") entitys.ResLoading(requireData.Ch, k.Name(), "正在为您搜索相关信息")
return k.chat(requireData) return k.chat(requireData)
@ -91,20 +92,23 @@ func (this *KnowledgeBaseTool) msgContentParse(input string, channel chan entity
} }
// 请求知识库聊天 // 请求知识库聊天
func (this *KnowledgeBaseTool) chat(requireData *entitys.RequireData) (err error) { func (this *KnowledgeBaseTool) chat(rec *entitys.Recognize) (err error) {
ext, err := rec_extra.GetTaskRecExt(rec)
if err != nil {
return
}
req := l_request.Request{ req := l_request.Request{
Method: "post", Method: "post",
Url: this.config.BaseURL + "/api/v1/knowledge-chat/" + requireData.KnowledgeConf.Session, Url: this.config.BaseURL + "/api/v1/knowledge-chat/" + ext.KnowledgeConf.Session,
Params: nil, Params: nil,
Headers: map[string]string{ Headers: map[string]string{
"Content-Type": "application/json", "Content-Type": "application/json",
"X-API-Key": requireData.KnowledgeConf.ApiKey, "X-API-Key": ext.KnowledgeConf.ApiKey,
}, },
Cookies: nil, Cookies: nil,
Data: nil, Data: nil,
Json: map[string]interface{}{ Json: map[string]interface{}{
"query": requireData.KnowledgeConf.Query, "query": ext.KnowledgeConf.Query,
}, },
Files: nil, Files: nil,
Raw: "", Raw: "",
@ -118,7 +122,7 @@ func (this *KnowledgeBaseTool) chat(requireData *entitys.RequireData) (err error
} }
defer rsp.Body.Close() defer rsp.Body.Close()
err = this.connectAndReadSSE(rsp, requireData.Ch) err = this.connectAndReadSSE(rsp, rec.Ch)
if err != nil { if err != nil {
return return
} }

View File

@ -43,9 +43,9 @@ func (w *NormalChatTool) Definition() entitys.ToolDefinition {
} }
// Execute 执行直连天下订单详情查询 // Execute 执行直连天下订单详情查询
func (w *NormalChatTool) Execute(ctx context.Context, requireData *entitys.RequireData) error { func (w *NormalChatTool) Execute(ctx context.Context, rec *entitys.Recognize) error {
var req NormalChat var req NormalChat
if err := json.Unmarshal([]byte(requireData.Match.Parameters), &req); err != nil { if err := json.Unmarshal([]byte(rec.Match.Parameters), &req); err != nil {
return fmt.Errorf("invalid zltxOrderDetail request: %w", err) return fmt.Errorf("invalid zltxOrderDetail request: %w", err)
} }
if req.ChatContent == "" { if req.ChatContent == "" {
@ -53,25 +53,25 @@ func (w *NormalChatTool) Execute(ctx context.Context, requireData *entitys.Requi
} }
// 这里可以集成真实的直连天下订单详情API // 这里可以集成真实的直连天下订单详情API
return w.chat(requireData, &req) return w.chat(rec, &req)
} }
// getMockZltxOrderDetail 获取模拟直连天下订单详情数据 // getMockZltxOrderDetail 获取模拟直连天下订单详情数据
func (w *NormalChatTool) chat(requireData *entitys.RequireData, chat *NormalChat) (err error) { func (w *NormalChatTool) chat(rec *entitys.Recognize, chat *NormalChat) (err error) {
//requireData.Ch <- entitys.Response{ //requireData.Ch <- entitys.Response{
// Index: w.Name(), // Index: w.Name(),
// Content: "<think></think>", // Content: "<think></think>",
// Type: entitys.ResponseStream, // Type: entitys.ResponseStream,
//} //}
err = w.llm.ChatStream(context.TODO(), requireData.Ch, []api.Message{ err = w.llm.ChatStream(context.TODO(), rec.Ch, []api.Message{
{ {
Role: "system", Role: "system",
Content: "你是一个聊天助手", Content: "你是一个聊天助手",
}, },
{ {
Role: "assistant", Role: "assistant",
Content: fmt.Sprintf("聊天记录:%s", pkg.JsonStringIgonErr(requireData.Histories)), Content: fmt.Sprintf("聊天记录:%s", pkg.JsonStringIgonErr(rec.ChatHis)),
}, },
{ {
Role: "user", Role: "user",

View File

@ -107,9 +107,9 @@ type LiveWeather struct {
} }
// Execute 执行天气查询 // Execute 执行天气查询
func (w *WeatherTool) Execute(ctx context.Context, requireData *entitys.RequireData) error { func (w *WeatherTool) Execute(ctx context.Context, rec *entitys.Recognize) error {
var req WeatherRequest var req WeatherRequest
if err := json.Unmarshal([]byte(requireData.Match.Parameters), &req); err != nil { if err := json.Unmarshal([]byte(rec.Match.Parameters), &req); err != nil {
return fmt.Errorf("invalid weather request: %w", err) return fmt.Errorf("invalid weather request: %w", err)
} }
@ -134,7 +134,7 @@ func (w *WeatherTool) Execute(ctx context.Context, requireData *entitys.RequireD
// 根据 extensions 参数返回不同的天气信息 // 根据 extensions 参数返回不同的天气信息
if req.Extensions == "base" { if req.Extensions == "base" {
entitys.ResText(requireData.Ch, "", fmt.Sprintf("%s实时天气%s温度%.1f℃,湿度:%d%%,风速:%.1fkm/h风向%s", entitys.ResText(rec.Ch, "", fmt.Sprintf("%s实时天气%s温度%.1f℃,湿度:%d%%,风速:%.1fkm/h风向%s",
req.City, req.City,
responseMsg.LiveWeather.Condition, responseMsg.LiveWeather.Condition,
responseMsg.LiveWeather.Temperature, responseMsg.LiveWeather.Temperature,
@ -149,7 +149,7 @@ func (w *WeatherTool) Execute(ctx context.Context, requireData *entitys.RequireD
forecast.Date, forecast.DayTemp, forecast.NightTemp, forecast.DayWind, forecast.NightWind) forecast.Date, forecast.DayTemp, forecast.NightTemp, forecast.DayWind, forecast.NightWind)
} }
entitys.ResText(requireData.Ch, "", rspStr) entitys.ResText(rec.Ch, "", rspStr)
} }
return nil return nil
} }

View File

@ -4,6 +4,7 @@ import (
"ai_scheduler/internal/config" "ai_scheduler/internal/config"
"ai_scheduler/internal/entitys" "ai_scheduler/internal/entitys"
"ai_scheduler/internal/pkg/l_request" "ai_scheduler/internal/pkg/l_request"
"ai_scheduler/internal/pkg/rec_extra"
"ai_scheduler/internal/pkg/util" "ai_scheduler/internal/pkg/util"
"context" "context"
"encoding/json" "encoding/json"
@ -104,9 +105,9 @@ type OrderAfterSaleResellerApiExtItem struct {
SerialCreateTime int `json:"createTime"` // 流水创建时间 SerialCreateTime int `json:"createTime"` // 流水创建时间
} }
func (t *OrderAfterSaleResellerTool) Execute(ctx context.Context, requireData *entitys.RequireData) error { func (t *OrderAfterSaleResellerTool) Execute(ctx context.Context, rec *entitys.Recognize) error {
var req OrderAfterSaleResellerRequest var req OrderAfterSaleResellerRequest
if err := json.Unmarshal([]byte(requireData.Match.Parameters), &req); err != nil { if err := json.Unmarshal([]byte(rec.Match.Parameters), &req); err != nil {
return fmt.Errorf("解析参数失败,请重试或联系管理员") return fmt.Errorf("解析参数失败,请重试或联系管理员")
} }
if len(req.OrderNumber) == 0 && len(req.Account) == 0 { if len(req.OrderNumber) == 0 && len(req.Account) == 0 {
@ -116,18 +117,22 @@ func (t *OrderAfterSaleResellerTool) Execute(ctx context.Context, requireData *e
if req.SerialCreateTime != "" { if req.SerialCreateTime != "" {
_, err := time.ParseInLocation(time.DateTime, req.SerialCreateTime, time.Local) _, err := time.ParseInLocation(time.DateTime, req.SerialCreateTime, time.Local)
if err != nil { if err != nil {
entitys.ResLog(requireData.Ch, t.Name(), "时间格式不匹配,已置为空") entitys.ResLog(rec.Ch, t.Name(), "时间格式不匹配,已置为空")
req.SerialCreateTime = "" req.SerialCreateTime = ""
} }
} }
entitys.ResLog(requireData.Ch, t.Name(), "正在拉取售后订单信息") entitys.ResLog(rec.Ch, t.Name(), "正在拉取售后订单信息")
return t.checkOrderAfterSaleReseller(req, requireData) return t.checkOrderAfterSaleReseller(req, rec)
} }
func (t *OrderAfterSaleResellerTool) checkOrderAfterSaleReseller(toolReq OrderAfterSaleResellerRequest, requireData *entitys.RequireData) error { func (t *OrderAfterSaleResellerTool) checkOrderAfterSaleReseller(toolReq OrderAfterSaleResellerRequest, rec *entitys.Recognize) error {
var serialStartTime, serialEndTime int64 var serialStartTime, serialEndTime int64
ext, err := rec_extra.GetTaskRecExt(rec)
if err != nil {
return err
}
if toolReq.SerialCreateTime != "" { if toolReq.SerialCreateTime != "" {
// 流水创建时间上下浮动10min // 流水创建时间上下浮动10min
serialCreateTime, err := time.ParseInLocation(time.DateTime, toolReq.SerialCreateTime, time.Local) serialCreateTime, err := time.ParseInLocation(time.DateTime, toolReq.SerialCreateTime, time.Local)
@ -144,17 +149,16 @@ func (t *OrderAfterSaleResellerTool) checkOrderAfterSaleReseller(toolReq OrderAf
// 账号数量超过10直接截断 // 账号数量超过10直接截断
if len(toolReq.Account) > 10 { if len(toolReq.Account) > 10 {
entitys.ResLog(requireData.Ch, t.Name(), "账号数量超过10已被截断") entitys.ResLog(rec.Ch, t.Name(), "账号数量超过10已被截断")
toolReq.Account = toolReq.Account[:10] toolReq.Account = toolReq.Account[:10]
} }
headers := map[string]string{ headers := map[string]string{
"Authorization": fmt.Sprintf("Bearer %s", requireData.Auth), "Authorization": fmt.Sprintf("Bearer %s", ext.Auth),
} }
// 最终输出 // 最终输出
var orderList []*OrderAfterSaleResellerData var orderList []*OrderAfterSaleResellerData
var err error
// 多订单号 // 多订单号
if len(toolReq.OrderNumber) > 0 { if len(toolReq.OrderNumber) > 0 {
@ -217,8 +221,8 @@ func (t *OrderAfterSaleResellerTool) checkOrderAfterSaleReseller(toolReq OrderAf
return err return err
} }
entitys.ResLog(requireData.Ch, t.Name(), "售后订单信息拉取完成") entitys.ResLog(rec.Ch, t.Name(), "售后订单信息拉取完成")
entitys.ResJson(requireData.Ch, t.Name(), string(jsonByte)) entitys.ResJson(rec.Ch, t.Name(), string(jsonByte))
return nil return nil
} }

View File

@ -4,6 +4,7 @@ import (
"ai_scheduler/internal/config" "ai_scheduler/internal/config"
"ai_scheduler/internal/entitys" "ai_scheduler/internal/entitys"
"ai_scheduler/internal/pkg/l_request" "ai_scheduler/internal/pkg/l_request"
"ai_scheduler/internal/pkg/rec_extra"
"ai_scheduler/internal/pkg/util" "ai_scheduler/internal/pkg/util"
"context" "context"
"encoding/json" "encoding/json"
@ -100,25 +101,29 @@ type OrderAfterSaleResellerBatchApiExtItem struct {
SerialCreateTime int `json:"createTime"` // 流水创建时间 SerialCreateTime int `json:"createTime"` // 流水创建时间
} }
func (t *OrderAfterSaleResellerBatchTool) Execute(ctx context.Context, requireData *entitys.RequireData) error { func (t *OrderAfterSaleResellerBatchTool) Execute(ctx context.Context, rec *entitys.Recognize) error {
var req OrderAfterSaleResellerBatchRequest var req OrderAfterSaleResellerBatchRequest
if err := json.Unmarshal([]byte(requireData.Match.Parameters), &req); err != nil { if err := json.Unmarshal([]byte(rec.Match.Parameters), &req); err != nil {
return fmt.Errorf("解析参数失败,请重试或联系管理员") return fmt.Errorf("解析参数失败,请重试或联系管理员")
} }
if len(req.OrderNumber) == 0 { if len(req.OrderNumber) == 0 {
return fmt.Errorf("批充订单号不能为空") return fmt.Errorf("批充订单号不能为空")
} }
entitys.ResLog(requireData.Ch, t.Name(), "正在拉取售后订单信息") entitys.ResLog(rec.Ch, t.Name(), "正在拉取售后订单信息")
return t.checkOrderAfterSaleResellerBatch(req, requireData) return t.checkOrderAfterSaleResellerBatch(req, rec)
} }
func (t *OrderAfterSaleResellerBatchTool) checkOrderAfterSaleResellerBatch(toolReq OrderAfterSaleResellerBatchRequest, requireData *entitys.RequireData) error { func (t *OrderAfterSaleResellerBatchTool) checkOrderAfterSaleResellerBatch(toolReq OrderAfterSaleResellerBatchRequest, rec *entitys.Recognize) error {
ext, err := rec_extra.GetTaskRecExt(rec)
if err != nil {
return err
}
req := l_request.Request{ req := l_request.Request{
Url: t.config.BaseURL, Url: t.config.BaseURL,
Headers: map[string]string{ Headers: map[string]string{
"Authorization": fmt.Sprintf("Bearer %s", requireData.Auth), "Authorization": fmt.Sprintf("Bearer %s", ext.Auth),
}, },
Method: "POST", Method: "POST",
Json: map[string]any{ Json: map[string]any{
@ -200,7 +205,7 @@ func (t *OrderAfterSaleResellerBatchTool) checkOrderAfterSaleResellerBatch(toolR
return err return err
} }
entitys.ResLog(requireData.Ch, t.Name(), "售后订单信息拉取完成") entitys.ResLog(rec.Ch, t.Name(), "售后订单信息拉取完成")
entitys.ResJson(requireData.Ch, t.Name(), string(jsonByte)) entitys.ResJson(rec.Ch, t.Name(), string(jsonByte))
return nil return nil
} }

View File

@ -4,6 +4,7 @@ import (
"ai_scheduler/internal/config" "ai_scheduler/internal/config"
"ai_scheduler/internal/entitys" "ai_scheduler/internal/entitys"
"ai_scheduler/internal/pkg/l_request" "ai_scheduler/internal/pkg/l_request"
"ai_scheduler/internal/pkg/rec_extra"
"ai_scheduler/internal/pkg/util" "ai_scheduler/internal/pkg/util"
"context" "context"
"encoding/json" "encoding/json"
@ -98,9 +99,9 @@ type OrderAfterSaleSupplierApiExtItem struct {
SerialCreateTime int `json:"createTime"` // 流水创建时间 SerialCreateTime int `json:"createTime"` // 流水创建时间
} }
func (t *OrderAfterSaleSupplierTool) Execute(ctx context.Context, requireData *entitys.RequireData) error { func (t *OrderAfterSaleSupplierTool) Execute(ctx context.Context, rec *entitys.Recognize) error {
var req OrderAfterSaleSupplierRequest var req OrderAfterSaleSupplierRequest
if err := json.Unmarshal([]byte(requireData.Match.Parameters), &req); err != nil { if err := json.Unmarshal([]byte(rec.Match.Parameters), &req); err != nil {
return fmt.Errorf("解析参数失败,请重试或联系管理员") return fmt.Errorf("解析参数失败,请重试或联系管理员")
} }
if len(req.SerialNumber) == 0 && len(req.Account) == 0 { if len(req.SerialNumber) == 0 && len(req.Account) == 0 {
@ -110,18 +111,24 @@ func (t *OrderAfterSaleSupplierTool) Execute(ctx context.Context, requireData *e
if req.SerialCreateTime != "" { if req.SerialCreateTime != "" {
_, err := time.ParseInLocation(time.DateTime, req.SerialCreateTime, time.Local) _, err := time.ParseInLocation(time.DateTime, req.SerialCreateTime, time.Local)
if err != nil { if err != nil {
entitys.ResLog(requireData.Ch, t.Name(), "时间格式不匹配,已置为空") entitys.ResLog(rec.Ch, t.Name(), "时间格式不匹配,已置为空")
req.SerialCreateTime = "" req.SerialCreateTime = ""
} }
} }
entitys.ResLog(requireData.Ch, t.Name(), "正在拉取售后订单信息") entitys.ResLog(rec.Ch, t.Name(), "正在拉取售后订单信息")
return t.checkOrderAfterSaleSupplier(req, requireData) return t.checkOrderAfterSaleSupplier(req, rec)
} }
func (t *OrderAfterSaleSupplierTool) checkOrderAfterSaleSupplier(toolReq OrderAfterSaleSupplierRequest, requireData *entitys.RequireData) error { func (t *OrderAfterSaleSupplierTool) checkOrderAfterSaleSupplier(toolReq OrderAfterSaleSupplierRequest, rec *entitys.Recognize) error {
var serialStartTime, serialEndTime int64 var serialStartTime, serialEndTime int64
ext, err := rec_extra.GetTaskRecExt(rec)
if err != nil {
return err
}
if toolReq.SerialCreateTime != "" { if toolReq.SerialCreateTime != "" {
// 流水创建时间上下浮动10min // 流水创建时间上下浮动10min
serialCreateTime, err := time.ParseInLocation(time.DateTime, toolReq.SerialCreateTime, time.Local) serialCreateTime, err := time.ParseInLocation(time.DateTime, toolReq.SerialCreateTime, time.Local)
@ -138,17 +145,16 @@ func (t *OrderAfterSaleSupplierTool) checkOrderAfterSaleSupplier(toolReq OrderAf
// 账号数量超过10直接截断 // 账号数量超过10直接截断
if len(toolReq.Account) > 10 { if len(toolReq.Account) > 10 {
entitys.ResLog(requireData.Ch, t.Name(), "账号数量超过10已被截断") entitys.ResLog(rec.Ch, t.Name(), "账号数量超过10已被截断")
toolReq.Account = toolReq.Account[:10] toolReq.Account = toolReq.Account[:10]
} }
headers := map[string]string{ headers := map[string]string{
"Authorization": fmt.Sprintf("Bearer %s", requireData.Auth), "Authorization": fmt.Sprintf("Bearer %s", ext.Auth),
} }
// 最终输出 // 最终输出
var orderList []*OrderAfterSaleSupplierData var orderList []*OrderAfterSaleSupplierData
var err error
// 多流水号 // 多流水号
if len(toolReq.SerialNumber) > 0 { if len(toolReq.SerialNumber) > 0 {
@ -210,8 +216,8 @@ func (t *OrderAfterSaleSupplierTool) checkOrderAfterSaleSupplier(toolReq OrderAf
return err return err
} }
entitys.ResLog(requireData.Ch, t.Name(), "售后订单信息拉取完成") entitys.ResLog(rec.Ch, t.Name(), "售后订单信息拉取完成")
entitys.ResJson(requireData.Ch, t.Name(), string(jsonByte)) entitys.ResJson(rec.Ch, t.Name(), string(jsonByte))
return nil return nil
} }

View File

@ -4,6 +4,7 @@ import (
"ai_scheduler/internal/config" "ai_scheduler/internal/config"
"ai_scheduler/internal/entitys" "ai_scheduler/internal/entitys"
"ai_scheduler/internal/pkg" "ai_scheduler/internal/pkg"
"ai_scheduler/internal/pkg/rec_extra"
"ai_scheduler/internal/pkg/utils_ollama" "ai_scheduler/internal/pkg/utils_ollama"
"context" "context"
"encoding/json" "encoding/json"
@ -81,9 +82,9 @@ type ZltxOrderDetailData struct {
} }
// Execute 执行直连天下订单详情查询 // Execute 执行直连天下订单详情查询
func (w *ZltxOrderDetailTool) Execute(ctx context.Context, requireData *entitys.RequireData) error { func (w *ZltxOrderDetailTool) Execute(ctx context.Context, rec *entitys.Recognize) error {
var req ZltxOrderDetailRequest var req ZltxOrderDetailRequest
if err := json.Unmarshal([]byte(requireData.Match.Parameters), &req); err != nil { if err := json.Unmarshal([]byte(rec.Match.Parameters), &req); err != nil {
return fmt.Errorf("invalid zltxOrderDetail request: %w", err) return fmt.Errorf("invalid zltxOrderDetail request: %w", err)
} }
if req.OrderNumber == "" { if req.OrderNumber == "" {
@ -91,16 +92,20 @@ func (w *ZltxOrderDetailTool) Execute(ctx context.Context, requireData *entitys.
} }
// 这里可以集成真实的直连天下订单详情API // 这里可以集成真实的直连天下订单详情API
return w.getZltxOrderDetail(requireData, req.OrderNumber) return w.getZltxOrderDetail(rec, req.OrderNumber)
} }
// getMockZltxOrderDetail 获取模拟直连天下订单详情数据 // getMockZltxOrderDetail 获取模拟直连天下订单详情数据
func (w *ZltxOrderDetailTool) getZltxOrderDetail(requireData *entitys.RequireData, number string) (err error) { func (w *ZltxOrderDetailTool) getZltxOrderDetail(rec *entitys.Recognize, number string) (err error) {
ext, err := rec_extra.GetTaskRecExt(rec)
if err != nil {
return
}
//查询订单详情 //查询订单详情
req := l_request.Request{ req := l_request.Request{
Url: fmt.Sprintf(w.config.BaseURL, number), Url: fmt.Sprintf(w.config.BaseURL, number),
Headers: map[string]string{ Headers: map[string]string{
"Authorization": fmt.Sprintf("Bearer %s", requireData.Auth), "Authorization": fmt.Sprintf("Bearer %s", ext.Auth),
}, },
Method: "GET", Method: "GET",
} }
@ -121,15 +126,15 @@ func (w *ZltxOrderDetailTool) getZltxOrderDetail(requireData *entitys.RequireDat
if err = json.Unmarshal(res.Content, &resData); err != nil { if err = json.Unmarshal(res.Content, &resData); err != nil {
return return
} }
entitys.ResJson(requireData.Ch, w.Name(), res.Text) entitys.ResJson(rec.Ch, w.Name(), res.Text)
if resData.Data.Direct != nil { if resData.Data.Direct != nil {
entitys.ResLoading(requireData.Ch, w.Name(), "正在分析订单日志") entitys.ResLoading(rec.Ch, w.Name(), "正在分析订单日志")
req = l_request.Request{ req = l_request.Request{
Url: fmt.Sprintf(w.config.AddURL, resData.Data.Direct["orderOrderNumber"].(string), resData.Data.Direct["serialNumber"].(string)), Url: fmt.Sprintf(w.config.AddURL, resData.Data.Direct["orderOrderNumber"].(string), resData.Data.Direct["serialNumber"].(string)),
Headers: map[string]string{ Headers: map[string]string{
"Authorization": fmt.Sprintf("Bearer %s", requireData.Auth), "Authorization": fmt.Sprintf("Bearer %s", ext.Auth),
}, },
Method: "GET", Method: "GET",
} }
@ -149,14 +154,14 @@ func (w *ZltxOrderDetailTool) getZltxOrderDetail(requireData *entitys.RequireDat
return fmt.Errorf("订单日志解析失败:%s", err) return fmt.Errorf("订单日志解析失败:%s", err)
} }
err = w.llm.ChatStream(context.TODO(), requireData.Ch, []api.Message{ err = w.llm.ChatStream(context.TODO(), rec.Ch, []api.Message{
{ {
Role: "system", Role: "system",
Content: "你是一个订单日志助手。用户可能会提供订单日志,你需要分析订单日志,失败订单->分析失败原因,成功订单->找出整个日志的 Base64 编码的 JSON 数据的内容进行转换并反馈给我", Content: "你是一个订单日志助手。用户可能会提供订单日志,你需要分析订单日志,失败订单->分析失败原因,成功订单->找出整个日志的 Base64 编码的 JSON 数据的内容进行转换并反馈给我",
}, },
{ {
Role: "assistant", Role: "assistant",
Content: fmt.Sprintf("聊天记录:%s", pkg.JsonStringIgonErr(requireData.Histories)), Content: fmt.Sprintf("聊天记录:%s", pkg.JsonStringIgonErr(rec.ChatHis)),
}, },
{ {
Role: "assistant", Role: "assistant",
@ -164,7 +169,7 @@ func (w *ZltxOrderDetailTool) getZltxOrderDetail(requireData *entitys.RequireDat
}, },
{ {
Role: "user", Role: "user",
Content: requireData.Req.Text, Content: rec.UserContent.Text,
}, },
}, w.Name(), "") }, w.Name(), "")
if err != nil { if err != nil {
@ -172,7 +177,7 @@ func (w *ZltxOrderDetailTool) getZltxOrderDetail(requireData *entitys.RequireDat
} }
} }
if resData.Data.Direct == nil { if resData.Data.Direct == nil {
entitys.ResText(requireData.Ch, w.Name(), "该订单无充值流水记录,不需要进行订单错误分析,建议检查订单收单模式以及扣款方式等原因😘") entitys.ResText(rec.Ch, w.Name(), "该订单无充值流水记录,不需要进行订单错误分析,建议检查订单收单模式以及扣款方式等原因😘")
} }
return return
} }

View File

@ -3,6 +3,7 @@ package zltx
import ( import (
"ai_scheduler/internal/config" "ai_scheduler/internal/config"
"ai_scheduler/internal/entitys" "ai_scheduler/internal/entitys"
"ai_scheduler/internal/pkg/rec_extra"
"context" "context"
"encoding/json" "encoding/json"
"fmt" "fmt"
@ -67,25 +68,28 @@ type ZltxOrderDirectLogData struct {
Data map[string]interface{} `json:"data"` Data map[string]interface{} `json:"data"`
} }
func (t *ZltxOrderLogTool) Execute(ctx context.Context, requireData *entitys.RequireData) error { func (t *ZltxOrderLogTool) Execute(ctx context.Context, rec *entitys.Recognize) error {
var req ZltxOrderLogRequest var req ZltxOrderLogRequest
if err := json.Unmarshal([]byte(requireData.Match.Parameters), &req); err != nil { if err := json.Unmarshal([]byte(rec.Match.Parameters), &req); err != nil {
return fmt.Errorf("invalid zltxOrderLog request: %w", err) return fmt.Errorf("invalid zltxOrderLog request: %w", err)
} }
if req.OrderNumber == "" || req.SerialNumber == "" { if req.OrderNumber == "" || req.SerialNumber == "" {
return fmt.Errorf("orderNumber and serialNumber is required") return fmt.Errorf("orderNumber and serialNumber is required")
} }
return t.getZltxOrderLog(req.OrderNumber, req.SerialNumber, requireData) return t.getZltxOrderLog(req.OrderNumber, req.SerialNumber, rec)
} }
func (t *ZltxOrderLogTool) getZltxOrderLog(orderNumber, serialNumber string, requireData *entitys.RequireData) (err error) { func (t *ZltxOrderLogTool) getZltxOrderLog(orderNumber, serialNumber string, rec *entitys.Recognize) (err error) {
//查询订单详情 //查询订单详情
ext, err := rec_extra.GetTaskRecExt(rec)
if err != nil {
return
}
url := fmt.Sprintf("%s%s/%s", t.config.BaseURL, orderNumber, serialNumber) url := fmt.Sprintf("%s%s/%s", t.config.BaseURL, orderNumber, serialNumber)
req := l_request.Request{ req := l_request.Request{
Url: url, Url: url,
Headers: map[string]string{ Headers: map[string]string{
"Authorization": fmt.Sprintf("Bearer %s", requireData.Auth), "Authorization": fmt.Sprintf("Bearer %s", ext.Auth),
}, },
Method: "GET", Method: "GET",
} }
@ -100,7 +104,7 @@ func (t *ZltxOrderLogTool) getZltxOrderLog(orderNumber, serialNumber string, req
if err = json.Unmarshal(res.Content, &resData); err != nil { if err = json.Unmarshal(res.Content, &resData); err != nil {
return return
} }
entitys.ResJson(requireData.Ch, t.Name(), res.Text) entitys.ResJson(rec.Ch, t.Name(), res.Text)
return return
} }

View File

@ -3,6 +3,7 @@ package zltx
import ( import (
"ai_scheduler/internal/config" "ai_scheduler/internal/config"
"ai_scheduler/internal/entitys" "ai_scheduler/internal/entitys"
"ai_scheduler/internal/pkg/rec_extra"
"context" "context"
"encoding/json" "encoding/json"
"fmt" "fmt"
@ -53,12 +54,12 @@ type ZltxProductRequest struct {
Name string `json:"name"` Name string `json:"name"`
} }
func (z ZltxProductTool) Execute(ctx context.Context, requireData *entitys.RequireData) error { func (z ZltxProductTool) Execute(ctx context.Context, rec *entitys.Recognize) error {
var req ZltxProductRequest var req ZltxProductRequest
if err := json.Unmarshal([]byte(requireData.Match.Parameters), &req); err != nil { if err := json.Unmarshal([]byte(rec.Match.Parameters), &req); err != nil {
return fmt.Errorf("invalid zltxProduct request: %w", err) return fmt.Errorf("invalid zltxProduct request: %w", err)
} }
return z.getZltxProduct(&req, requireData) return z.getZltxProduct(&req, rec)
} }
type ZltxProductResponse struct { type ZltxProductResponse struct {
@ -133,8 +134,11 @@ type ZltxProductData struct {
PlatformProductList interface{} `json:"platform_product_list"` PlatformProductList interface{} `json:"platform_product_list"`
} }
func (z ZltxProductTool) getZltxProduct(body *ZltxProductRequest, requireData *entitys.RequireData) error { func (z ZltxProductTool) getZltxProduct(body *ZltxProductRequest, rec *entitys.Recognize) error {
ext, err := rec_extra.GetTaskRecExt(rec)
if err != nil {
return err
}
var Url string var Url string
var params map[string]string var params map[string]string
if body.Id != "" { if body.Id != "" {
@ -153,7 +157,7 @@ func (z ZltxProductTool) getZltxProduct(body *ZltxProductRequest, requireData *e
//根据商品ID或名称走不同的接口查询 //根据商品ID或名称走不同的接口查询
Url: Url, Url: Url,
Headers: map[string]string{ Headers: map[string]string{
"Authorization": fmt.Sprintf("Bearer %s", requireData.Auth), "Authorization": fmt.Sprintf("Bearer %s", ext.Auth),
}, },
Params: params, Params: params,
Method: "GET", Method: "GET",
@ -185,7 +189,7 @@ func (z ZltxProductTool) getZltxProduct(body *ZltxProductRequest, requireData *e
for i := range resp.Data.List { for i := range resp.Data.List {
// 调用 平台商品列表 // 调用 平台商品列表
if resp.Data.List[i].AuthProductIds != "" { if resp.Data.List[i].AuthProductIds != "" {
platformProductList := z.ExecutePlatformProductList(requireData.Auth, resp.Data.List[i].AuthProductIds, resp.Data.List[i].OfficialProductID) platformProductList := z.ExecutePlatformProductList(ext.Auth, resp.Data.List[i].AuthProductIds, resp.Data.List[i].OfficialProductID)
resp.Data.List[i].PlatformProductList = platformProductList resp.Data.List[i].PlatformProductList = platformProductList
} }
} }
@ -194,7 +198,7 @@ func (z ZltxProductTool) getZltxProduct(body *ZltxProductRequest, requireData *e
if err != nil { if err != nil {
return err return err
} }
entitys.ResJson(requireData.Ch, z.Name(), string(marshal)) entitys.ResJson(rec.Ch, z.Name(), string(marshal))
return nil return nil
} }

View File

@ -3,6 +3,7 @@ package zltx
import ( import (
"ai_scheduler/internal/config" "ai_scheduler/internal/config"
"ai_scheduler/internal/entitys" "ai_scheduler/internal/entitys"
"ai_scheduler/internal/pkg/rec_extra"
"context" "context"
"encoding/json" "encoding/json"
"fmt" "fmt"
@ -47,15 +48,15 @@ type ZltxOrderStatisticsRequest struct {
Number string `json:"number"` Number string `json:"number"`
} }
func (z ZltxOrderStatisticsTool) Execute(ctx context.Context, requireData *entitys.RequireData) error { func (z ZltxOrderStatisticsTool) Execute(ctx context.Context, rec *entitys.Recognize) error {
var req ZltxOrderStatisticsRequest var req ZltxOrderStatisticsRequest
if err := json.Unmarshal([]byte(requireData.Match.Parameters), &req); err != nil { if err := json.Unmarshal([]byte(rec.Match.Parameters), &req); err != nil {
return err return err
} }
if req.Number == "" { if req.Number == "" {
return fmt.Errorf("number is required") return fmt.Errorf("number is required")
} }
return z.getZltxOrderStatistics(req.Number, requireData) return z.getZltxOrderStatistics(req.Number, rec)
} }
type ZltxOrderStatisticsResponse struct { type ZltxOrderStatisticsResponse struct {
@ -75,14 +76,18 @@ type ZltxOrderStatisticsData struct {
Total int `json:"total"` Total int `json:"total"`
} }
func (z ZltxOrderStatisticsTool) getZltxOrderStatistics(number string, requireData *entitys.RequireData) error { func (z ZltxOrderStatisticsTool) getZltxOrderStatistics(number string, rec *entitys.Recognize) error {
ext, err := rec_extra.GetTaskRecExt(rec)
if err != nil {
return err
}
//查询订单详情 //查询订单详情
url := fmt.Sprintf("%s%s", z.config.BaseURL, number) url := fmt.Sprintf("%s%s", z.config.BaseURL, number)
req := l_request.Request{ req := l_request.Request{
Url: url, Url: url,
Headers: map[string]string{ Headers: map[string]string{
"Authorization": fmt.Sprintf("Bearer %s", requireData.Auth), "Authorization": fmt.Sprintf("Bearer %s", ext.Auth),
}, },
Method: "GET", Method: "GET",
} }
@ -108,7 +113,7 @@ func (z ZltxOrderStatisticsTool) getZltxOrderStatistics(number string, requireDa
if err != nil { if err != nil {
return err return err
} }
entitys.ResJson(requireData.Ch, z.Name(), string(jsonByte)) entitys.ResJson(rec.Ch, z.Name(), string(jsonByte))
return nil return nil
} }

View File

@ -1,27 +0,0 @@
package tools_bot
import (
"ai_scheduler/internal/config"
"ai_scheduler/internal/data/impl"
"ai_scheduler/internal/entitys"
"ai_scheduler/internal/pkg/utils_ollama"
"context"
)
type BotTool struct {
config *config.Config
llm *utils_ollama.Client
sessionImpl *impl.SessionImpl
taskMap map[string]string
}
// NewBotTool 创建直连天下订单详情工具
func NewBotTool(config *config.Config, llm *utils_ollama.Client, sessionImpl *impl.SessionImpl) *BotTool {
return &BotTool{config: config, llm: llm, sessionImpl: sessionImpl, taskMap: make(map[string]string)}
}
// Execute 执行直连天下订单详情查询
func (w *BotTool) Execute(ctx context.Context, toolName string, requireData *entitys.RequireData) (err error) {
return
}