package tools import ( "ai_scheduler/internal/config" "ai_scheduler/internal/data/constants" "ai_scheduler/internal/entitys" "ai_scheduler/internal/pkg/utils_ollama" "encoding/json" "fmt" "github.com/gofiber/websocket/v2" ) // Manager 工具管理器 type Manager struct { tools map[string]entitys.Tool llm *utils_ollama.Client } // NewManager 创建工具管理器 func NewManager(config *config.Config, llm *utils_ollama.Client) *Manager { m := &Manager{ tools: make(map[string]entitys.Tool), llm: llm, } // 注册天气工具 //if config.Tools.Weather.Enabled { // weatherTool := NewWeatherTool() // m.tools[weatherTool.Name()] = weatherTool //} // //// 注册计算器工具 //if config.Tools.Calculator.Enabled { // calcTool := NewCalculatorTool() // m.tools[calcTool.Name()] = calcTool //} // 注册知识库工具 // if config.Knowledge.Enabled { // knowledgeTool := NewKnowledgeTool() // m.tools[knowledgeTool.Name()] = knowledgeTool // } // 注册直连天下订单详情工具 if config.Tools.ZltxOrderDetail.Enabled { zltxOrderDetailTool := NewZltxOrderDetailTool(config.Tools.ZltxOrderDetail, m.llm) m.tools[zltxOrderDetailTool.Name()] = zltxOrderDetailTool } //注册直连天下订单日志工具 if config.Tools.ZltxOrderDirectLog.Enabled { zltxOrderLogTool := NewZltxOrderLogTool(config.Tools.ZltxOrderDirectLog) m.tools[zltxOrderLogTool.Name()] = zltxOrderLogTool } //注册直连天下商品工具 if config.Tools.ZltxProduct.Enabled { zltxProductTool := NewZltxProductTool(config.Tools.ZltxProduct) m.tools[zltxProductTool.Name()] = zltxProductTool } //注册直连天下订单统计工具 if config.Tools.ZltxOrderStatistics.Enabled { zltxOrderStatisticsTool := NewZltxOrderStatisticsTool(config.Tools.ZltxOrderStatistics) m.tools[zltxOrderStatisticsTool.Name()] = zltxOrderStatisticsTool } // 注册知识库工具 if config.Tools.Knowledge.Enabled { knowledgeTool := NewKnowledgeBaseTool(config.Tools.Knowledge) m.tools[knowledgeTool.Name()] = knowledgeTool } return m } // GetTool 获取工具 func (m *Manager) GetTool(name string) (entitys.Tool, bool) { tool, exists := m.tools[name] return tool, exists } // GetAllTools 获取所有工具 func (m *Manager) GetAllTools() []entitys.Tool { tools := make([]entitys.Tool, 0, len(m.tools)) for _, tool := range m.tools { tools = append(tools, tool) } return tools } // GetToolDefinitions 获取所有工具定义 func (m *Manager) GetToolDefinitions(caller constants.Caller) []entitys.ToolDefinition { definitions := make([]entitys.ToolDefinition, 0, len(m.tools)) for _, tool := range m.tools { definitions = append(definitions, tool.Definition()) } return definitions } // ExecuteTool 执行工具 func (m *Manager) ExecuteTool(channel chan entitys.Response, c *websocket.Conn, name string, args json.RawMessage) error { tool, exists := m.GetTool(name) if !exists { return fmt.Errorf("tool not found: %s", name) } return tool.Execute(channel, c, args) } // ExecuteToolCalls 执行多个工具调用 //func (m *Manager) ExecuteToolCalls(c *websocket.Conn, toolCalls []entitys.ToolCall) ([]entitys.ToolCall, error) { // results := make([]entitys.ToolCall, len(toolCalls)) // // for i, toolCall := range toolCalls { // results[i] = toolCall // // // 执行工具 // err := m.ExecuteTool(c, toolCall.Function.Name, toolCall.Function.Arguments) // if err != nil { // // 将错误信息作为结果返回 // errorResult := map[string]interface{}{ // "error": err.Error(), // } // resultBytes, _ := json.Marshal(errorResult) // results[i].Result = resultBytes // } else { // // 将成功结果序列化 // resultBytes, err := json.Marshal(result) // if err != nil { // errorResult := map[string]interface{}{ // "error": fmt.Sprintf("failed to serialize result: %v", err), // } // resultBytes, _ = json.Marshal(errorResult) // } // results[i].Result = resultBytes // } // } // // return results, nil //}