diff --git a/config/config.yaml b/config/config.yaml index a51c128..8e9c764 100644 --- a/config/config.yaml +++ b/config/config.yaml @@ -26,3 +26,9 @@ redis: db: driver: mysql source: root:SD###sdf323r343@tcp(121.199.38.107:3306)/sys_ai?charset=utf8mb4&parseTime=true&loc=Asia%2FShanghai + +tools: + zltxOrderDetail: + enabled: true + base_url: "https://gateway.dev.cdlsxd.cn/zltx_api/admin/direct/ai/" + api_key: "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJ1c2VyQ2VudGVyIiwiZXhwIjoxNzU2MTgyNTM1LCJuYmYiOjE3NTYxODA3MzUsImp0aSI6IjEiLCJQaG9uZSI6IjE4MDAwMDAwMDAwIiwiVXNlck5hbWUiOiJsc3hkIiwiUmVhbE5hbWUiOiLotoXnuqfnrqHnkIblkZgiLCJBY2NvdW50VHlwZSI6MSwiR3JvdXBDb2RlcyI6IlZDTF9DQVNISUVSLFZDTF9PUEVSQVRFLFZDTF9BRE1JTixWQ0xfQUFBLFZDTF9WQ0xfT1BFUkFULFZDTF9JTlZPSUNFLENSTV9BRE1JTixMSUFOTElBTl9BRE1JTixNQVJLRVRNQUcyX0FETUlOLFBIT05FQklMTF9BRE1JTixRSUFOWkhVX1NVUFBFUl9BRE0sTUFSS0VUSU5HU0FBU19TVVBFUkFETUlOLENBUkRfQ09ERSxDQVJEX1BST0NVUkVNRU5ULE1BUktFVElOR1NZU1RFTV9TVVBFUixTVEFUSVNUSUNBTFNZU1RFTV9BRE1JTixaTFRYX0FETUlOLFpMVFhfT1BFUkFURSIsIkRpbmdVc2VySWQiOiIxNjIwMjYxMjMwMjg5MzM4MzQifQ.N1xv1PYbcO8_jR5adaczc16YzGsr4z101gwEZdulkRaREBJNYTOnFrvRxTFx3RJTooXsqTqroE1MR84v_1WPX6BS6kKonA-kC1Jgot6yrt5rFWhGNGb2Cpr9rKIFCCQYmiGd3AUgDazEeaQ0_sodv3E-EXg9VfE1SX8nMcck9Yjnc8NCy7RTWaBIaSeOdZcEl-JfCD0S6GSx3oErp_hk-U9FKGwf60wAuDGTY1R0BP4BYpcEqS-C2LSnsSGyURi54Cuk5xH8r1WuF0Dm5bwAj5d7Hvs77-N_sUF-C5ONqyZJRAEhYLgcmN9RX_WQZfizdQJxizlTczdpzYfy-v-1eQ" diff --git a/internal/biz/router.go b/internal/biz/router.go index 1471009..a127e27 100644 --- a/internal/biz/router.go +++ b/internal/biz/router.go @@ -2,25 +2,31 @@ package biz import ( "ai_scheduler/internal/config" + "ai_scheduler/internal/data/constant" errors "ai_scheduler/internal/data/error" "ai_scheduler/internal/data/impl" "ai_scheduler/internal/data/model" "ai_scheduler/internal/entitys" "ai_scheduler/internal/pkg" + "ai_scheduler/internal/pkg/mapstructure" "ai_scheduler/internal/pkg/utils_ollama" "ai_scheduler/internal/tools" "ai_scheduler/tmpl/dataTemp" "context" "encoding/json" - "log" + "fmt" + "strings" + + "gitea.cdlsxd.cn/self-tools/l_request" + "github.com/gofiber/fiber/v2/log" "github.com/gofiber/websocket/v2" "github.com/tmc/langchaingo/llms" "xorm.io/builder" ) -// AiRouterService 智能路由服务 -type AiRouterService struct { +// AiRouterBiz 智能路由服务 +type AiRouterBiz struct { //aiClient entitys.AIClient toolManager *tools.Manager sessionImpl *impl.SessionImpl @@ -41,8 +47,9 @@ func NewAiRouterBiz( hisImpl *impl.ChatImpl, conf *config.Config, utilAgent *utils_ollama.UtilOllama, -) entitys.RouterService { - return &AiRouterService{ + +) *AiRouterBiz { + return &AiRouterBiz{ //aiClient: aiClient, toolManager: toolManager, sessionImpl: sessionImpl, @@ -55,13 +62,13 @@ func NewAiRouterBiz( } // Route 执行智能路由 -func (r *AiRouterService) Route(ctx context.Context, req *entitys.ChatRequest) (*entitys.ChatResponse, error) { +func (r *AiRouterBiz) Route(ctx context.Context, req *entitys.ChatRequest) (*entitys.ChatResponse, error) { return nil, nil } // Route 执行智能路由 -func (r *AiRouterService) RouteWithSocket(c *websocket.Conn, req *entitys.ChatSockRequest) error { +func (r *AiRouterBiz) RouteWithSocket(c *websocket.Conn, req *entitys.ChatSockRequest) error { session := c.Headers("X-Session", "") if len(session) == 0 { @@ -91,26 +98,27 @@ func (r *AiRouterService) RouteWithSocket(c *websocket.Conn, req *entitys.ChatSo return errors.SystemError } - toolDefinitions := r.registerTools(task) + //toolDefinitions := r.registerTools(task) //prompt := r.getPrompt(sysInfo, history, req.Text) //意图预测 - //msg, err := r.utilAgent.Llm.Call(context.TODO(), pkg.JsonStringIgonErr(prompt), - // llms.WithTools(toolDefinitions), - // //llms.WithToolChoice(llms.FunctionCallBehaviorAuto), - // llms.WithJSONMode(), - //) - prompt := r.getPromptLLM(sysInfo, history, req.Text) - msg, err := r.utilAgent.Llm.GenerateContent(context.TODO(), prompt, - llms.WithTools(toolDefinitions), - llms.WithToolChoice("tool_name"), + prompt := r.getPromptLLM(sysInfo, history, req.Text, task) + match, err := r.utilAgent.Llm.GenerateContent(context.TODO(), prompt, + //llms.WithTools(toolDefinitions), + //llms.WithToolChoice("tool_name"), llms.WithJSONMode(), ) if err != nil { return errors.SystemError } - - c.WriteMessage(1, []byte(msg.Choices[0].Content)) + log.Info(match) + var matchJson entitys.Match + err = json.Unmarshal([]byte(match.Choices[0].Content), &matchJson) + if err != nil { + return errors.SystemError + } + return r.handleMatch(c, &matchJson, task) + //c.WriteMessage(1, []byte(msg.Choices[0].Content)) // 构建消息 //messages := []entitys.Message{ // { @@ -193,7 +201,98 @@ func (r *AiRouterService) RouteWithSocket(c *websocket.Conn, req *entitys.ChatSo return nil } -func (r *AiRouterService) getSessionChatHis(sessionId string) (his []model.AiChatHi, err error) { +func (r *AiRouterBiz) handleMatch(c *websocket.Conn, matchJson *entitys.Match, tasks []model.AiTask) (err error) { + defer func() { + if err != nil { + c.WriteMessage(websocket.TextMessage, []byte(err.Error())) + } + c.WriteMessage(websocket.TextMessage, []byte("EOF")) + }() + if !matchJson.IsMatch { + c.WriteMessage(websocket.TextMessage, []byte(matchJson.Reasoning)) + return + } + var pointTask *model.AiTask + for _, task := range tasks { + if task.Index == matchJson.Index { + pointTask = &task + break + } + } + if pointTask == nil || pointTask.Index == "other" { + return r.handleOtherTask(c, matchJson) + } + + switch pointTask.Type { + case constant.TaskTypeApi: + err = r.handleApiTask(c, matchJson, pointTask) + case constant.TaskTypeFunc: + err = r.handleTask(c, matchJson, pointTask) + default: + return r.handleOtherTask(c, matchJson) + } + + return +} + +func (r *AiRouterBiz) handleTask(c *websocket.Conn, matchJson *entitys.Match, task *model.AiTask) (err error) { + + var configData entitys.ConfigDataTool + err = json.Unmarshal([]byte(task.Config), &configData) + if err != nil { + return + } + err = r.toolManager.ExecuteTool(c, configData.Tool, []byte(matchJson.Parameters)) + if err != nil { + return + } + + return +} + +func (r *AiRouterBiz) handleOtherTask(c *websocket.Conn, matchJson *entitys.Match) (err error) { + + c.WriteMessage(1, []byte(matchJson.Reasoning)) + + return +} + +func (r *AiRouterBiz) handleApiTask(c *websocket.Conn, matchJson *entitys.Match, task *model.AiTask) (err error) { + var ( + request l_request.Request + auth = c.Headers("X-Authorization", "") + requestParam map[string]interface{} + ) + err = json.Unmarshal([]byte(matchJson.Parameters), &requestParam) + if err != nil { + return + } + request.Url = strings.ReplaceAll(task.Config, "${authorization}", auth) + for k, v := range requestParam { + task.Config = strings.ReplaceAll(task.Config, "${"+k+"}", fmt.Sprintf("%v", v)) + } + var configData entitys.ConfigDataHttp + err = json.Unmarshal([]byte(task.Config), &configData) + if err != nil { + return + } + err = mapstructure.Decode(configData.Request, &request) + if err != nil { + return + } + if len(request.Url) == 0 { + err = errors.NewBusinessErr("00022", "api地址获取失败") + return + } + res, err := request.Send() + if err != nil { + return + } + c.WriteMessage(1, res.Content) + return +} + +func (r *AiRouterBiz) getSessionChatHis(sessionId string) (his []model.AiChatHi, err error) { cond := builder.NewCond() cond = cond.And(builder.Eq{"session_id": sessionId}) @@ -203,7 +302,7 @@ func (r *AiRouterService) getSessionChatHis(sessionId string) (his []model.AiCha return } -func (r *AiRouterService) getSysInfo(appKey string) (sysInfo model.AiSy, err error) { +func (r *AiRouterBiz) getSysInfo(appKey string) (sysInfo model.AiSy, err error) { cond := builder.NewCond() cond = cond.And(builder.Eq{"app_key": appKey}) cond = cond.And(builder.IsNull{"delete_at"}) @@ -212,7 +311,7 @@ func (r *AiRouterService) getSysInfo(appKey string) (sysInfo model.AiSy, err err return } -func (r *AiRouterService) getTasks(sysId int32) (tasks []model.AiTask, err error) { +func (r *AiRouterBiz) getTasks(sysId int32) (tasks []model.AiTask, err error) { cond := builder.NewCond() cond = cond.And(builder.Eq{"sys_id": sysId}) @@ -223,7 +322,7 @@ func (r *AiRouterService) getTasks(sysId int32) (tasks []model.AiTask, err error return } -func (r *AiRouterService) registerTools(tasks []model.AiTask) []llms.Tool { +func (r *AiRouterBiz) registerTools(tasks []model.AiTask) []llms.Tool { taskPrompt := make([]llms.Tool, 0) for _, task := range tasks { var taskConfig entitys.TaskConfig @@ -244,7 +343,7 @@ func (r *AiRouterService) registerTools(tasks []model.AiTask) []llms.Tool { return taskPrompt } -func (r *AiRouterService) getPrompt(sysInfo model.AiSy, history []model.AiChatHi, reqInput string) []entitys.Message { +func (r *AiRouterBiz) getPrompt(sysInfo model.AiSy, history []model.AiChatHi, reqInput string) []entitys.Message { var ( prompt = make([]entitys.Message, 0) ) @@ -261,7 +360,7 @@ func (r *AiRouterService) getPrompt(sysInfo model.AiSy, history []model.AiChatHi return prompt } -func (r *AiRouterService) getPromptLLM(sysInfo model.AiSy, history []model.AiChatHi, reqInput string) []llms.MessageContent { +func (r *AiRouterBiz) getPromptLLM(sysInfo model.AiSy, history []model.AiChatHi, reqInput string, tasks []model.AiTask) []llms.MessageContent { var ( prompt = make([]llms.MessageContent, 0) ) @@ -275,6 +374,11 @@ func (r *AiRouterService) getPromptLLM(sysInfo model.AiSy, history []model.AiCha Parts: []llms.ContentPart{ llms.TextPart(pkg.JsonStringIgonErr(r.buildAssistant(history))), }, + }, llms.MessageContent{ + Role: llms.ChatMessageTypeTool, + Parts: []llms.ContentPart{ + llms.TextPart(pkg.JsonStringIgonErr(r.registerTools(tasks))), + }, }, llms.MessageContent{ Role: llms.ChatMessageTypeHuman, Parts: []llms.ContentPart{ @@ -285,7 +389,7 @@ func (r *AiRouterService) getPromptLLM(sysInfo model.AiSy, history []model.AiCha } // buildSystemPrompt 构建系统提示词 -func (r *AiRouterService) buildSystemPrompt(prompt string) string { +func (r *AiRouterBiz) buildSystemPrompt(prompt string) string { if len(prompt) == 0 { prompt = "[system] 你是一个智能路由系统,核心职责是 **精准解析用户意图并路由至对应任务模块**\n[rule]\n1.返回以下格式的JSON:{ \"index\": \"工具索引index\", \"confidence\": 0.0-1.0,\"reasoning\": \"判断理由\"}\n2.严格返回字符串格式,禁用markdown格式返回\n3.只返回json字符串,不包含任何其他解释性文字\n4.当用户意图非常不清晰时使用,尝试进行追问具体希望查询内容" } @@ -293,7 +397,7 @@ func (r *AiRouterService) buildSystemPrompt(prompt string) string { return prompt } -func (r *AiRouterService) buildAssistant(his []model.AiChatHi) (chatHis entitys.ChatHis) { +func (r *AiRouterBiz) buildAssistant(his []model.AiChatHi) (chatHis entitys.ChatHis) { for _, item := range his { if len(chatHis.SessionId) == 0 { chatHis.SessionId = item.SessionID @@ -311,61 +415,8 @@ func (r *AiRouterService) buildAssistant(his []model.AiChatHi) (chatHis entitys. return } -// extractIntent 从AI响应中提取意图 -func (r *AiRouterService) extractIntent(response *entitys.ChatResponse) string { - if response == nil || response.Message == "" { - return "" - } - - // 尝试解析JSON - var intent struct { - Intent string `json:"intent"` - Confidence string `json:"confidence"` - Reasoning string `json:"reasoning"` - } - err := json.Unmarshal([]byte(response.Message), &intent) - if err != nil { - log.Printf("Failed to parse intent JSON: %v", err) - return "" - } - - return intent.Intent -} - -// handleOrderDiagnosis 处理订单诊断意图 -func (r *AiRouterService) handleOrderDiagnosis(ctx context.Context, req *entitys.ChatRequest, messages []entitys.Message) (*entitys.ChatResponse, error) { - // 调用订单详情工具 - //orderDetailTool, ok := r.toolManager.GetTool("zltxOrderDetail") - //if orderDetailTool == nil || !ok { - // return nil, fmt.Errorf("order detail tool not found") - //} - //orderDetailTool.Execute(ctx, json.RawMessage{}) - // - //// 获取相关工具定义 - //toolDefinitions := r.toolManager.GetToolDefinitions(constants.Caller(req.Caller)) - // - //// 调用AI,获取是否需要使用工具 - //response, err := r.aiClient.Chat(ctx, messages, toolDefinitions) - //if err != nil { - // return nil, fmt.Errorf("failed to chat with AI: %w", err) - //} - // - //// 如果没有工具调用,直接返回 - //if len(response.ToolCalls) == 0 { - // return response, nil - //} - // - //// 执行工具调用 - //toolResults, err := r.toolManager.ExecuteToolCalls(ctx, response.ToolCalls) - //if err != nil { - // return nil, fmt.Errorf("failed to execute tools: %w", err) - //} - - return nil, nil -} - // handleKnowledgeQA 处理知识问答意图 -func (r *AiRouterService) handleKnowledgeQA(ctx context.Context, req *entitys.ChatRequest, messages []entitys.Message) (*entitys.ChatResponse, error) { +func (r *AiRouterBiz) handleKnowledgeQA(ctx context.Context, req *entitys.ChatRequest, messages []entitys.Message) (*entitys.ChatResponse, error) { return nil, nil } diff --git a/internal/biz/router_test.go b/internal/biz/router_test.go new file mode 100644 index 0000000..d6c8658 --- /dev/null +++ b/internal/biz/router_test.go @@ -0,0 +1,84 @@ +package biz + +import ( + "ai_scheduler/internal/config" + "ai_scheduler/internal/data/impl" + "ai_scheduler/internal/data/model" + "ai_scheduler/internal/entitys" + "ai_scheduler/internal/pkg/utils_ollama" + "ai_scheduler/internal/tools" + "ai_scheduler/utils" + "encoding/json" + "flag" + "fmt" + "os" + "path/filepath" + "testing" + + "github.com/gofiber/fiber/v2/log" +) + +func Test_task(t *testing.T) { + var c entitys.TaskConfig + config := `{"param": {"type": "object", "required": ["number"], "properties": {"number": {"type": "string", "description": "订单编号/流水号"}}}, "request": {"url": "http://www.baidu.com/${number}", "headers": {"Authorization": "${authorization}"}, "method": "GET"}}` + err := json.Unmarshal([]byte(config), &c) + t.Log(err) +} + +type configData struct { + Param map[string]interface{} `json:"param"` + Do map[string]interface{} `json:"do"` +} + +func Test_Order(t *testing.T) { + routerBiz := in() + err := routerBiz.handleTask(nil, &entitys.Match{Index: "order_diagnosis", Parameters: `{"order_number":"12312312312"}`}, &model.AiTask{Config: `{"tool": "zltxOrderDetail", "param": {"type": "object", "optional": [], "required": ["order_number"], "properties": {"order_number": {"type": "string", "description": "订单编号/流水号"}}}}`}) + t.Log(err) +} + +func in() *AiRouterBiz { + modDir, err := getModuleDir() + if err != nil { + panic("1") + } + configPath := flag.String("config", fmt.Sprintf("%s/config/config.yaml", modDir), "Path to configuration file") + flag.Parse() + configConfig, err := config.LoadConfig(*configPath) + if err != nil { + panic("加载配置失败") + } + allLogger := log.DefaultLogger() + manager := tools.NewManager(configConfig) + db, _ := utils.NewGormDb(configConfig) + sessionImpl := impl.NewSessionImpl(db) + sysImpl := impl.NewSysImpl(db) + taskImpl := impl.NewTaskImpl(db) + chatImpl := impl.NewChatImpl(db) + utilOllama := utils_ollama.NewUtilOllama(configConfig, allLogger) + routerBiz := NewAiRouterBiz(manager, sessionImpl, sysImpl, taskImpl, chatImpl, configConfig, utilOllama) + + return routerBiz +} + +func getModuleDir() (string, error) { + dir, err := os.Getwd() + if err != nil { + return "", err + } + + for { + modPath := filepath.Join(dir, "go.mod") + if _, err := os.Stat(modPath); err == nil { + return dir, nil // 找到 go.mod + } + + // 向上查找父目录 + parent := filepath.Dir(dir) + if parent == dir { + break // 到达根目录,未找到 + } + dir = parent + } + + return "", fmt.Errorf("go.mod not found in current directory or parents") +} diff --git a/internal/config/config.go b/internal/config/config.go index bfbedbf..c3c75c9 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -73,10 +73,9 @@ type ToolsConfig struct { // ToolConfig 单个工具配置 type ToolConfig struct { - Enabled bool `mapstructure:"enabled"` - BaseURL string `mapstructure:"base_url"` - APIKey string `mapstructure:"api_key"` - BizSystem string `mapstructure:"biz_system"` + Enabled bool `mapstructure:"enabled"` + BaseURL string `mapstructure:"base_url"` + APIKey string `mapstructure:"api_key"` } // LoggingConfig 日志配置 diff --git a/internal/data/constant/const.go b/internal/data/constant/const.go index d611d70..13ef745 100644 --- a/internal/data/constant/const.go +++ b/internal/data/constant/const.go @@ -11,6 +11,7 @@ const ( type TaskType int32 const ( - TaskTypeApi ConnStatus = iota + 1 - TaskTypeKnowle + TaskTypeApi = 1 + TaskTypeKnowle = 2 + TaskTypeFunc = 3 ) diff --git a/internal/entitys/types.go b/internal/entitys/types.go index 13a1f9f..8705225 100644 --- a/internal/entitys/types.go +++ b/internal/entitys/types.go @@ -66,12 +66,17 @@ type Tool interface { Name() string Description() string Definition() ToolDefinition - Execute(ctx context.Context, args json.RawMessage) (interface{}, error) + Execute(c *websocket.Conn, args json.RawMessage) error } -// AIClient AI客户端接口 -type AIClient interface { - Chat(ctx context.Context, messages []Message, tools []ToolDefinition) (*ChatResponse, error) +type ConfigDataHttp struct { + Param map[string]interface{} `json:"param"` + Request map[string]interface{} `json:"request"` +} + +type ConfigDataTool struct { + Param map[string]interface{} `json:"param"` + Tool string `json:"tool"` } // Message 消息 @@ -88,7 +93,13 @@ type FuncApi struct { type TaskConfig struct { Param interface{} `json:"param"` } - +type Match struct { + Confidence float64 `json:"confidence"` + Index string `json:"index"` + IsMatch bool `json:"is_match"` + Parameters string `json:"parameters"` + Reasoning string `json:"reasoning"` +} type ChatHis struct { SessionId string `json:"session_id"` Messages []HisMessage `json:"messages"` diff --git a/internal/pkg/utils_ollama/ollama.go b/internal/pkg/utils_ollama/ollama.go index f5546bb..65388a3 100644 --- a/internal/pkg/utils_ollama/ollama.go +++ b/internal/pkg/utils_ollama/ollama.go @@ -18,7 +18,7 @@ func NewUtilOllama(c *config.Config, logger log.AllLogger) *UtilOllama { ollama.WithModel(c.Ollama.Model), ollama.WithHTTPClient(http.DefaultClient), ollama.WithServerURL(getUrl(c)), - ollama.WithKeepAlive("1h"), + ollama.WithKeepAlive("-1s"), ) if err != nil { logger.Fatal(err) diff --git a/internal/services/chat.go b/internal/services/chat.go index 43804b8..82995e1 100644 --- a/internal/services/chat.go +++ b/internal/services/chat.go @@ -1,6 +1,7 @@ package services import ( + "ai_scheduler/internal/biz" "ai_scheduler/internal/data/constant" "ai_scheduler/internal/entitys" "ai_scheduler/internal/gateway" @@ -17,16 +18,16 @@ import ( // ChatHandler 聊天处理器 type ChatService struct { - routerService entitys.RouterService - Gw *gateway.Gateway - mu sync.Mutex + routerBiz *biz.AiRouterBiz + Gw *gateway.Gateway + mu sync.Mutex } // NewChatHandler 创建聊天处理器 -func NewChatService(routerService entitys.RouterService, gw *gateway.Gateway) *ChatService { +func NewChatService(routerService *biz.AiRouterBiz, gw *gateway.Gateway) *ChatService { return &ChatService{ - routerService: routerService, - Gw: gw, + routerBiz: routerService, + Gw: gw, } } @@ -100,7 +101,7 @@ func (h *ChatService) Chat(c *websocket.Conn) { log.Println("JSON parse error:", err) continue } - err = h.routerService.RouteWithSocket(c, &req) + err = h.routerBiz.RouteWithSocket(c, &req) if err != nil { log.Println("处理失败:", err) continue diff --git a/internal/test/chat_test.go b/internal/test/chat_test.go deleted file mode 100644 index ebd1bda..0000000 --- a/internal/test/chat_test.go +++ /dev/null @@ -1,14 +0,0 @@ -package test - -import ( - "ai_scheduler/internal/entitys" - "encoding/json" - "testing" -) - -func Test_task(t *testing.T) { - var c entitys.TaskConfig - config := `{"param": {"type": "object", "required": ["number"], "properties": {"number": {"type": "string", "description": "订单编号/流水号"}}}, "request": {"url": "http://www.baidu.com/${number}", "headers": {"Authorization": "${authorization}"}, "method": "GET"}}` - err := json.Unmarshal([]byte(config), &c) - t.Log(err) -} diff --git a/internal/test/task_test.go b/internal/test/task_test.go new file mode 100644 index 0000000..f46832f --- /dev/null +++ b/internal/test/task_test.go @@ -0,0 +1,41 @@ +package test + +import ( + "ai_scheduler/internal/entitys" + "ai_scheduler/internal/pkg/mapstructure" + "encoding/json" + "testing" + + "gitea.cdlsxd.cn/self-tools/l_request" +) + +func Test_task(t *testing.T) { + var c entitys.TaskConfig + config := `{"param": {"type": "object", "required": ["number"], "properties": {"number": {"type": "string", "description": "订单编号/流水号"}}}, "request": {"url": "http://www.baidu.com/${number}", "headers": {"Authorization": "${authorization}"}, "method": "GET"}}` + err := json.Unmarshal([]byte(config), &c) + t.Log(err) +} + +type configData struct { + Param map[string]interface{} `json:"param"` + Do map[string]interface{} `json:"do"` +} + +func Test_task2(t *testing.T) { + var ( + c l_request.Request + config configData + ) + + configJson := `{"tool": "zltxOrderDetail", "param": {"type": "object", "optional": [], "required": ["order_number"], "properties": {"order_number": {"type": "string", "description": "订单编号/流水号"}}}}` + err := json.Unmarshal([]byte(configJson), &config) + if err != nil { + panic(err) + } + mapstructure.Decode(config.Do, &c) + t.Log(err) +} + +func in() { + +} diff --git a/internal/tools/manager.go b/internal/tools/manager.go index 28b1c12..801fdd1 100644 --- a/internal/tools/manager.go +++ b/internal/tools/manager.go @@ -5,9 +5,10 @@ import ( "ai_scheduler/internal/data/constants" "ai_scheduler/internal/entitys" - "context" "encoding/json" "fmt" + + "github.com/gofiber/websocket/v2" ) // Manager 工具管理器 @@ -22,16 +23,16 @@ func NewManager(config *config.Config) *Manager { } // 注册天气工具 - if config.Tools.Weather.Enabled { - weatherTool := NewWeatherTool() - m.tools[weatherTool.Name()] = weatherTool - } - - // 注册计算器工具 - if config.Tools.Calculator.Enabled { - calcTool := NewCalculatorTool() - m.tools[calcTool.Name()] = calcTool - } + //if config.Tools.Weather.Enabled { + // weatherTool := NewWeatherTool() + // m.tools[weatherTool.Name()] = weatherTool + //} + // + //// 注册计算器工具 + //if config.Tools.Calculator.Enabled { + // calcTool := NewCalculatorTool() + // m.tools[calcTool.Name()] = calcTool + //} // 注册知识库工具 // if config.Knowledge.Enabled { @@ -80,43 +81,43 @@ func (m *Manager) GetToolDefinitions(caller constants.Caller) []entitys.ToolDefi } // ExecuteTool 执行工具 -func (m *Manager) ExecuteTool(ctx context.Context, name string, args json.RawMessage) (interface{}, error) { +func (m *Manager) ExecuteTool(c *websocket.Conn, name string, args json.RawMessage) error { tool, exists := m.GetTool(name) if !exists { - return nil, fmt.Errorf("tool not found: %s", name) + return fmt.Errorf("tool not found: %s", name) } - return tool.Execute(ctx, args) + return tool.Execute(c, args) } // ExecuteToolCalls 执行多个工具调用 -func (m *Manager) ExecuteToolCalls(ctx context.Context, toolCalls []entitys.ToolCall) ([]entitys.ToolCall, error) { - results := make([]entitys.ToolCall, len(toolCalls)) - - for i, toolCall := range toolCalls { - results[i] = toolCall - - // 执行工具 - result, err := m.ExecuteTool(ctx, toolCall.Function.Name, toolCall.Function.Arguments) - if err != nil { - // 将错误信息作为结果返回 - errorResult := map[string]interface{}{ - "error": err.Error(), - } - resultBytes, _ := json.Marshal(errorResult) - results[i].Result = resultBytes - } else { - // 将成功结果序列化 - resultBytes, err := json.Marshal(result) - if err != nil { - errorResult := map[string]interface{}{ - "error": fmt.Sprintf("failed to serialize result: %v", err), - } - resultBytes, _ = json.Marshal(errorResult) - } - results[i].Result = resultBytes - } - } - - return results, nil -} +//func (m *Manager) ExecuteToolCalls(c *websocket.Conn, toolCalls []entitys.ToolCall) ([]entitys.ToolCall, error) { +// results := make([]entitys.ToolCall, len(toolCalls)) +// +// for i, toolCall := range toolCalls { +// results[i] = toolCall +// +// // 执行工具 +// err := m.ExecuteTool(c, toolCall.Function.Name, toolCall.Function.Arguments) +// if err != nil { +// // 将错误信息作为结果返回 +// errorResult := map[string]interface{}{ +// "error": err.Error(), +// } +// resultBytes, _ := json.Marshal(errorResult) +// results[i].Result = resultBytes +// } else { +// // 将成功结果序列化 +// resultBytes, err := json.Marshal(result) +// if err != nil { +// errorResult := map[string]interface{}{ +// "error": fmt.Sprintf("failed to serialize result: %v", err), +// } +// resultBytes, _ = json.Marshal(errorResult) +// } +// results[i].Result = resultBytes +// } +// } +// +// return results, nil +//} diff --git a/internal/tools/zltx_order_detail.go b/internal/tools/zltx_order_detail.go index 4d6ddd8..9cbb39e 100644 --- a/internal/tools/zltx_order_detail.go +++ b/internal/tools/zltx_order_detail.go @@ -4,10 +4,11 @@ import ( "ai_scheduler/internal/config" "ai_scheduler/internal/entitys" - "context" "encoding/json" "fmt" - "net/http" + + "gitea.cdlsxd.cn/self-tools/l_request" + "github.com/gofiber/websocket/v2" ) // ZltxOrderDetailTool 直连天下订单详情工具 @@ -53,7 +54,7 @@ func (w *ZltxOrderDetailTool) Definition() entitys.ToolDefinition { // ZltxOrderDetailRequest 直连天下订单详情请求参数 type ZltxOrderDetailRequest struct { - Number string `json:"number"` + OrderNumber string `json:"order_number"` } // ZltxOrderDetailResponse 直连天下订单详情响应 @@ -70,37 +71,43 @@ type ZltxOrderDetailData struct { } // Execute 执行直连天下订单详情查询 -func (w *ZltxOrderDetailTool) Execute(ctx context.Context, args json.RawMessage) (interface{}, error) { +func (w *ZltxOrderDetailTool) Execute(c *websocket.Conn, args json.RawMessage) error { var req ZltxOrderDetailRequest if err := json.Unmarshal(args, &req); err != nil { - return nil, fmt.Errorf("invalid zltxOrderDetail request: %w", err) + return fmt.Errorf("invalid zltxOrderDetail request: %w", err) } - if req.Number == "" { - return nil, fmt.Errorf("number is required") + if req.OrderNumber == "" { + return fmt.Errorf("number is required") } // 这里可以集成真实的直连天下订单详情API - return w.getZltxOrderDetail(ctx, req.Number), nil + return w.getZltxOrderDetail(c, req.OrderNumber) } // getMockZltxOrderDetail 获取模拟直连天下订单详情数据 -func (w *ZltxOrderDetailTool) getZltxOrderDetail(ctx context.Context, number string) *ZltxOrderDetailResponse { - url := fmt.Sprintf("%s/admin/direct/ai/%s", w.config.BaseURL, number) - authorization := fmt.Sprintf("Bearer %s", w.config.APIKey) - - // 发送http请求 - req, err := http.NewRequest("GET", url, nil) - if err != nil { - return &ZltxOrderDetailResponse{} +func (w *ZltxOrderDetailTool) getZltxOrderDetail(c *websocket.Conn, number string) (err error) { + //查询订单详情 + var auth string + if c != nil { + auth = c.Headers("X-Authorization", "") } - req.Header.Set("Authorization", authorization) - - resp, err := http.DefaultClient.Do(req) - if err != nil { - return &ZltxOrderDetailResponse{} + if len(auth) == 0 { + auth = w.config.APIKey } - defer resp.Body.Close() + req := l_request.Request{ + Url: fmt.Sprintf("%s%s", w.config.BaseURL, number), + Headers: map[string]string{ + "Authorization": fmt.Sprintf("Bearer %s", auth), + }, + Method: "GET", + } + res, err := req.Send() - return &ZltxOrderDetailResponse{} + if err != nil { + return + } + c.WriteMessage(websocket.TextMessage, res.Content) + + return }