diff --git a/internal/biz/router.go b/internal/biz/router.go index 0cdc93d..45ad515 100644 --- a/internal/biz/router.go +++ b/internal/biz/router.go @@ -15,7 +15,7 @@ import ( "context" "encoding/json" "fmt" - "net/http" + "strings" "gitea.cdlsxd.cn/self-tools/l_request" @@ -203,10 +203,13 @@ func (r *AiRouterService) RouteWithSocket(c *websocket.Conn, req *entitys.ChatSo func (r *AiRouterService) handleMatch(c *websocket.Conn, matchJson *entitys.Match, tasks []model.AiTask) (err error) { defer func() { - c.WriteMessage(1, []byte("EOF")) + if err != nil { + c.WriteMessage(websocket.TextMessage, []byte(err.Error())) + } + c.WriteMessage(websocket.TextMessage, []byte("EOF")) }() if !matchJson.IsMatch { - c.WriteMessage(1, []byte(matchJson.Reasoning)) + c.WriteMessage(websocket.TextMessage, []byte(matchJson.Reasoning)) return } var pointTask *model.AiTask @@ -219,33 +222,32 @@ func (r *AiRouterService) handleMatch(c *websocket.Conn, matchJson *entitys.Matc if pointTask == nil || pointTask.Index == "other" { return r.handleOtherTask(c, matchJson) } - var res []byte + switch pointTask.Type { case constant.TaskTypeApi: - res, err = r.handleApiTask(c, matchJson, pointTask) + err = r.handleApiTask(c, matchJson, pointTask) case constant.TaskTypeFunc: - ctx := context.TODO() - res, err = r.handleTask(ctx, matchJson, pointTask) + err = r.handleTask(c, matchJson, pointTask) default: return r.handleOtherTask(c, matchJson) } - fmt.Println(res) + return } -func (r *AiRouterService) handleTask(c context.Context, matchJson *entitys.Match, task *model.AiTask) (res []byte, err error) { +func (r *AiRouterService) handleTask(c *websocket.Conn, matchJson *entitys.Match, task *model.AiTask) (err error) { var configData entitys.ConfigDataTool err = json.Unmarshal([]byte(task.Config), &configData) if err != nil { return } - resInterface, err := r.toolManager.ExecuteTool(c, configData.Tool, []byte(matchJson.Parameters)) + err = r.toolManager.ExecuteTool(c, configData.Tool, []byte(matchJson.Parameters)) if err != nil { - return nil, err + return } - return json.Marshal(resInterface) + return } func (r *AiRouterService) handleOtherTask(c *websocket.Conn, matchJson *entitys.Match) (err error) { @@ -255,7 +257,7 @@ func (r *AiRouterService) handleOtherTask(c *websocket.Conn, matchJson *entitys. return } -func (r *AiRouterService) handleApiTask(c *websocket.Conn, matchJson *entitys.Match, task *model.AiTask) (resByte []byte, err error) { +func (r *AiRouterService) handleApiTask(c *websocket.Conn, matchJson *entitys.Match, task *model.AiTask) (err error) { var ( request l_request.Request auth = c.Headers("X-Authorization", "") @@ -286,7 +288,8 @@ func (r *AiRouterService) handleApiTask(c *websocket.Conn, matchJson *entitys.Ma if err != nil { return } - return res.Content, nil + c.WriteMessage(1, res.Content) + return } func (r *AiRouterService) getSessionChatHis(sessionId string) (his []model.AiChatHi, err error) { @@ -412,59 +415,6 @@ func (r *AiRouterService) buildAssistant(his []model.AiChatHi) (chatHis entitys. return } -// extractIntent 从AI响应中提取意图 -func (r *AiRouterService) extractIntent(response *entitys.ChatResponse) string { - if response == nil || response.Message == "" { - return "" - } - - // 尝试解析JSON - var intent struct { - Intent string `json:"intent"` - Confidence string `json:"confidence"` - Reasoning string `json:"reasoning"` - } - err := json.Unmarshal([]byte(response.Message), &intent) - if err != nil { - log.Printf("Failed to parse intent JSON: %v", err) - return "" - } - - return intent.Intent -} - -// handleOrderDiagnosis 处理订单诊断意图 -func (r *AiRouterService) handleOrderDiagnosis(ctx context.Context, req *entitys.ChatRequest, messages []entitys.Message) (*entitys.ChatResponse, error) { - // 调用订单详情工具 - //orderDetailTool, ok := r.toolManager.GetTool("zltxOrderDetail") - //if orderDetailTool == nil || !ok { - // return nil, fmt.Errorf("order detail tool not found") - //} - //orderDetailTool.Execute(ctx, json.RawMessage{}) - // - //// 获取相关工具定义 - //toolDefinitions := r.toolManager.GetToolDefinitions(constants.Caller(req.Caller)) - // - //// 调用AI,获取是否需要使用工具 - //response, err := r.aiClient.Chat(ctx, messages, toolDefinitions) - //if err != nil { - // return nil, fmt.Errorf("failed to chat with AI: %w", err) - //} - // - //// 如果没有工具调用,直接返回 - //if len(response.ToolCalls) == 0 { - // return response, nil - //} - // - //// 执行工具调用 - //toolResults, err := r.toolManager.ExecuteToolCalls(ctx, response.ToolCalls) - //if err != nil { - // return nil, fmt.Errorf("failed to execute tools: %w", err) - //} - - return nil, nil -} - // handleKnowledgeQA 处理知识问答意图 func (r *AiRouterService) handleKnowledgeQA(ctx context.Context, req *entitys.ChatRequest, messages []entitys.Message) (*entitys.ChatResponse, error) { diff --git a/internal/entitys/types.go b/internal/entitys/types.go index fea6177..8705225 100644 --- a/internal/entitys/types.go +++ b/internal/entitys/types.go @@ -66,7 +66,7 @@ type Tool interface { Name() string Description() string Definition() ToolDefinition - Execute(ctx context.Context, args json.RawMessage) (interface{}, error) + Execute(c *websocket.Conn, args json.RawMessage) error } type ConfigDataHttp struct { diff --git a/internal/tools/manager.go b/internal/tools/manager.go index 6b907e2..a587a61 100644 --- a/internal/tools/manager.go +++ b/internal/tools/manager.go @@ -4,10 +4,10 @@ import ( "ai_scheduler/internal/config" "ai_scheduler/internal/constants" "ai_scheduler/internal/entitys" - "context" "encoding/json" "fmt" + "github.com/gofiber/websocket/v2" ) // Manager 工具管理器 @@ -22,16 +22,16 @@ func NewManager(config *config.Config) *Manager { } // 注册天气工具 - if config.Tools.Weather.Enabled { - weatherTool := NewWeatherTool() - m.tools[weatherTool.Name()] = weatherTool - } - - // 注册计算器工具 - if config.Tools.Calculator.Enabled { - calcTool := NewCalculatorTool() - m.tools[calcTool.Name()] = calcTool - } + //if config.Tools.Weather.Enabled { + // weatherTool := NewWeatherTool() + // m.tools[weatherTool.Name()] = weatherTool + //} + // + //// 注册计算器工具 + //if config.Tools.Calculator.Enabled { + // calcTool := NewCalculatorTool() + // m.tools[calcTool.Name()] = calcTool + //} // 注册知识库工具 // if config.Knowledge.Enabled { @@ -80,43 +80,43 @@ func (m *Manager) GetToolDefinitions(caller constants.Caller) []entitys.ToolDefi } // ExecuteTool 执行工具 -func (m *Manager) ExecuteTool(ctx context.Context, name string, args json.RawMessage) (interface{}, error) { +func (m *Manager) ExecuteTool(c *websocket.Conn, name string, args json.RawMessage) error { tool, exists := m.GetTool(name) if !exists { - return nil, fmt.Errorf("tool not found: %s", name) + return fmt.Errorf("tool not found: %s", name) } - return tool.Execute(ctx, args) + return tool.Execute(c, args) } // ExecuteToolCalls 执行多个工具调用 -func (m *Manager) ExecuteToolCalls(ctx context.Context, toolCalls []entitys.ToolCall) ([]entitys.ToolCall, error) { - results := make([]entitys.ToolCall, len(toolCalls)) - - for i, toolCall := range toolCalls { - results[i] = toolCall - - // 执行工具 - result, err := m.ExecuteTool(ctx, toolCall.Function.Name, toolCall.Function.Arguments) - if err != nil { - // 将错误信息作为结果返回 - errorResult := map[string]interface{}{ - "error": err.Error(), - } - resultBytes, _ := json.Marshal(errorResult) - results[i].Result = resultBytes - } else { - // 将成功结果序列化 - resultBytes, err := json.Marshal(result) - if err != nil { - errorResult := map[string]interface{}{ - "error": fmt.Sprintf("failed to serialize result: %v", err), - } - resultBytes, _ = json.Marshal(errorResult) - } - results[i].Result = resultBytes - } - } - - return results, nil -} +//func (m *Manager) ExecuteToolCalls(c *websocket.Conn, toolCalls []entitys.ToolCall) ([]entitys.ToolCall, error) { +// results := make([]entitys.ToolCall, len(toolCalls)) +// +// for i, toolCall := range toolCalls { +// results[i] = toolCall +// +// // 执行工具 +// err := m.ExecuteTool(c, toolCall.Function.Name, toolCall.Function.Arguments) +// if err != nil { +// // 将错误信息作为结果返回 +// errorResult := map[string]interface{}{ +// "error": err.Error(), +// } +// resultBytes, _ := json.Marshal(errorResult) +// results[i].Result = resultBytes +// } else { +// // 将成功结果序列化 +// resultBytes, err := json.Marshal(result) +// if err != nil { +// errorResult := map[string]interface{}{ +// "error": fmt.Sprintf("failed to serialize result: %v", err), +// } +// resultBytes, _ = json.Marshal(errorResult) +// } +// results[i].Result = resultBytes +// } +// } +// +// return results, nil +//} diff --git a/internal/tools/zltx_order_detail.go b/internal/tools/zltx_order_detail.go index 4d6ddd8..417a409 100644 --- a/internal/tools/zltx_order_detail.go +++ b/internal/tools/zltx_order_detail.go @@ -4,10 +4,11 @@ import ( "ai_scheduler/internal/config" "ai_scheduler/internal/entitys" - "context" "encoding/json" "fmt" "net/http" + + "github.com/gofiber/websocket/v2" ) // ZltxOrderDetailTool 直连天下订单详情工具 @@ -70,37 +71,37 @@ type ZltxOrderDetailData struct { } // Execute 执行直连天下订单详情查询 -func (w *ZltxOrderDetailTool) Execute(ctx context.Context, args json.RawMessage) (interface{}, error) { +func (w *ZltxOrderDetailTool) Execute(c *websocket.Conn, args json.RawMessage) error { var req ZltxOrderDetailRequest if err := json.Unmarshal(args, &req); err != nil { - return nil, fmt.Errorf("invalid zltxOrderDetail request: %w", err) + return fmt.Errorf("invalid zltxOrderDetail request: %w", err) } if req.Number == "" { - return nil, fmt.Errorf("number is required") + return fmt.Errorf("number is required") } // 这里可以集成真实的直连天下订单详情API - return w.getZltxOrderDetail(ctx, req.Number), nil + return w.getZltxOrderDetail(c, req.Number) } // getMockZltxOrderDetail 获取模拟直连天下订单详情数据 -func (w *ZltxOrderDetailTool) getZltxOrderDetail(ctx context.Context, number string) *ZltxOrderDetailResponse { +func (w *ZltxOrderDetailTool) getZltxOrderDetail(c *websocket.Conn, number string) (err error) { url := fmt.Sprintf("%s/admin/direct/ai/%s", w.config.BaseURL, number) authorization := fmt.Sprintf("Bearer %s", w.config.APIKey) // 发送http请求 req, err := http.NewRequest("GET", url, nil) if err != nil { - return &ZltxOrderDetailResponse{} + return } req.Header.Set("Authorization", authorization) resp, err := http.DefaultClient.Do(req) if err != nil { - return &ZltxOrderDetailResponse{} + return } defer resp.Body.Close() - return &ZltxOrderDetailResponse{} + return }