ai_scheduler/internal/tools/manager.go

146 lines
4.0 KiB
Go

package tools
import (
"ai_scheduler/internal/config"
"ai_scheduler/internal/data/constants"
"ai_scheduler/internal/entitys"
"ai_scheduler/internal/pkg/utils_ollama"
"context"
"fmt"
)
// Manager 工具管理器
type Manager struct {
tools map[string]entitys.Tool
llm *utils_ollama.Client
}
// NewManager 创建工具管理器
func NewManager(config *config.Config, llm *utils_ollama.Client) *Manager {
m := &Manager{
tools: make(map[string]entitys.Tool),
llm: llm,
}
// 注册天气工具
//if config.Tools.Weather.Enabled {
// weatherTool := NewWeatherTool()
// m.tools[weatherTool.Name()] = weatherTool
//}
//
//// 注册计算器工具
//if config.Tools.Calculator.Enabled {
// calcTool := NewCalculatorTool()
// m.tools[calcTool.Name()] = calcTool
//}
// 注册知识库工具
// if config.Knowledge.Enabled {
// knowledgeTool := NewKnowledgeTool()
// m.tools[knowledgeTool.Name()] = knowledgeTool
// }
// 注册直连天下订单详情工具
if config.Tools.ZltxOrderDetail.Enabled {
zltxOrderDetailTool := NewZltxOrderDetailTool(config.Tools.ZltxOrderDetail, m.llm)
m.tools[zltxOrderDetailTool.Name()] = zltxOrderDetailTool
}
//注册直连天下订单日志工具
if config.Tools.ZltxOrderDirectLog.Enabled {
zltxOrderLogTool := NewZltxOrderLogTool(config.Tools.ZltxOrderDirectLog)
m.tools[zltxOrderLogTool.Name()] = zltxOrderLogTool
}
//注册直连天下商品工具
if config.Tools.ZltxProduct.Enabled {
zltxProductTool := NewZltxProductTool(config.Tools.ZltxProduct)
m.tools[zltxProductTool.Name()] = zltxProductTool
}
//注册直连天下订单统计工具
if config.Tools.ZltxOrderStatistics.Enabled {
zltxOrderStatisticsTool := NewZltxOrderStatisticsTool(config.Tools.ZltxOrderStatistics)
m.tools[zltxOrderStatisticsTool.Name()] = zltxOrderStatisticsTool
}
// 注册知识库工具
if config.Tools.Knowledge.Enabled {
knowledgeTool := NewKnowledgeBaseTool(config.Tools.Knowledge)
m.tools[knowledgeTool.Name()] = knowledgeTool
}
// 普通对话
chat := NewNormalChatTool(m.llm)
m.tools[chat.Name()] = chat
return m
}
// GetTool 获取工具
func (m *Manager) GetTool(name string) (entitys.Tool, bool) {
tool, exists := m.tools[name]
return tool, exists
}
// GetAllTools 获取所有工具
func (m *Manager) GetAllTools() []entitys.Tool {
tools := make([]entitys.Tool, 0, len(m.tools))
for _, tool := range m.tools {
tools = append(tools, tool)
}
return tools
}
// GetToolDefinitions 获取所有工具定义
func (m *Manager) GetToolDefinitions(caller constants.Caller) []entitys.ToolDefinition {
definitions := make([]entitys.ToolDefinition, 0, len(m.tools))
for _, tool := range m.tools {
definitions = append(definitions, tool.Definition())
}
return definitions
}
// ExecuteTool 执行工具
func (m *Manager) ExecuteTool(ctx context.Context, name string, requireData *entitys.RequireData) error {
tool, exists := m.GetTool(name)
if !exists {
return fmt.Errorf("tool not found: %s", name)
}
return tool.Execute(ctx, requireData)
}
// ExecuteToolCalls 执行多个工具调用
//func (m *Manager) ExecuteToolCalls(c *websocket.Conn, toolCalls []entitys.ToolCall) ([]entitys.ToolCall, error) {
// results := make([]entitys.ToolCall, len(toolCalls))
//
// for i, toolCall := range toolCalls {
// results[i] = toolCall
//
// // 执行工具
// err := m.ExecuteTool(c, toolCall.Function.Name, toolCall.Function.Arguments)
// if err != nil {
// // 将错误信息作为结果返回
// errorResult := map[string]interface{}{
// "error": err.Error(),
// }
// resultBytes, _ := json.Marshal(errorResult)
// results[i].Result = resultBytes
// } else {
// // 将成功结果序列化
// resultBytes, err := json.Marshal(result)
// if err != nil {
// errorResult := map[string]interface{}{
// "error": fmt.Sprintf("failed to serialize result: %v", err),
// }
// resultBytes, _ = json.Marshal(errorResult)
// }
// results[i].Result = resultBytes
// }
// }
//
// return results, nil
//}