143 lines
4.0 KiB
Go
143 lines
4.0 KiB
Go
package tools
|
|
|
|
import (
|
|
"ai_scheduler/internal/config"
|
|
"ai_scheduler/internal/data/constants"
|
|
"ai_scheduler/internal/entitys"
|
|
"ai_scheduler/internal/pkg/utils_ollama"
|
|
|
|
"encoding/json"
|
|
"fmt"
|
|
|
|
"github.com/gofiber/websocket/v2"
|
|
)
|
|
|
|
// Manager 工具管理器
|
|
type Manager struct {
|
|
tools map[string]entitys.Tool
|
|
llm *utils_ollama.Client
|
|
}
|
|
|
|
// NewManager 创建工具管理器
|
|
func NewManager(config *config.Config, llm *utils_ollama.Client) *Manager {
|
|
m := &Manager{
|
|
tools: make(map[string]entitys.Tool),
|
|
llm: llm,
|
|
}
|
|
|
|
// 注册天气工具
|
|
//if config.Tools.Weather.Enabled {
|
|
// weatherTool := NewWeatherTool()
|
|
// m.tools[weatherTool.Name()] = weatherTool
|
|
//}
|
|
//
|
|
//// 注册计算器工具
|
|
//if config.Tools.Calculator.Enabled {
|
|
// calcTool := NewCalculatorTool()
|
|
// m.tools[calcTool.Name()] = calcTool
|
|
//}
|
|
|
|
// 注册知识库工具
|
|
// if config.Knowledge.Enabled {
|
|
// knowledgeTool := NewKnowledgeTool()
|
|
// m.tools[knowledgeTool.Name()] = knowledgeTool
|
|
// }
|
|
|
|
// 注册直连天下订单详情工具
|
|
if config.Tools.ZltxOrderDetail.Enabled {
|
|
zltxOrderDetailTool := NewZltxOrderDetailTool(config.Tools.ZltxOrderDetail, m.llm)
|
|
m.tools[zltxOrderDetailTool.Name()] = zltxOrderDetailTool
|
|
}
|
|
|
|
//注册直连天下订单日志工具
|
|
if config.Tools.ZltxOrderDirectLog.Enabled {
|
|
zltxOrderLogTool := NewZltxOrderLogTool(config.Tools.ZltxOrderDirectLog)
|
|
m.tools[zltxOrderLogTool.Name()] = zltxOrderLogTool
|
|
}
|
|
|
|
//注册直连天下商品工具
|
|
if config.Tools.ZltxProduct.Enabled {
|
|
zltxProductTool := NewZltxProductTool(config.Tools.ZltxProduct)
|
|
m.tools[zltxProductTool.Name()] = zltxProductTool
|
|
}
|
|
//注册直连天下订单统计工具
|
|
if config.Tools.ZltxOrderStatistics.Enabled {
|
|
zltxOrderStatisticsTool := NewZltxOrderStatisticsTool(config.Tools.ZltxOrderStatistics)
|
|
m.tools[zltxOrderStatisticsTool.Name()] = zltxOrderStatisticsTool
|
|
}
|
|
|
|
// 注册知识库工具
|
|
if config.Tools.Knowledge.Enabled {
|
|
knowledgeTool := NewKnowledgeBaseTool(config.Tools.Knowledge)
|
|
m.tools[knowledgeTool.Name()] = knowledgeTool
|
|
}
|
|
return m
|
|
}
|
|
|
|
// GetTool 获取工具
|
|
func (m *Manager) GetTool(name string) (entitys.Tool, bool) {
|
|
tool, exists := m.tools[name]
|
|
return tool, exists
|
|
}
|
|
|
|
// GetAllTools 获取所有工具
|
|
func (m *Manager) GetAllTools() []entitys.Tool {
|
|
tools := make([]entitys.Tool, 0, len(m.tools))
|
|
for _, tool := range m.tools {
|
|
tools = append(tools, tool)
|
|
}
|
|
return tools
|
|
}
|
|
|
|
// GetToolDefinitions 获取所有工具定义
|
|
func (m *Manager) GetToolDefinitions(caller constants.Caller) []entitys.ToolDefinition {
|
|
definitions := make([]entitys.ToolDefinition, 0, len(m.tools))
|
|
for _, tool := range m.tools {
|
|
definitions = append(definitions, tool.Definition())
|
|
}
|
|
|
|
return definitions
|
|
}
|
|
|
|
// ExecuteTool 执行工具
|
|
func (m *Manager) ExecuteTool(channel chan entitys.Response, c *websocket.Conn, name string, args json.RawMessage, matchJson *entitys.Match) error {
|
|
tool, exists := m.GetTool(name)
|
|
if !exists {
|
|
return fmt.Errorf("tool not found: %s", name)
|
|
}
|
|
|
|
return tool.Execute(channel, c, args, matchJson)
|
|
}
|
|
|
|
// ExecuteToolCalls 执行多个工具调用
|
|
//func (m *Manager) ExecuteToolCalls(c *websocket.Conn, toolCalls []entitys.ToolCall) ([]entitys.ToolCall, error) {
|
|
// results := make([]entitys.ToolCall, len(toolCalls))
|
|
//
|
|
// for i, toolCall := range toolCalls {
|
|
// results[i] = toolCall
|
|
//
|
|
// // 执行工具
|
|
// err := m.ExecuteTool(c, toolCall.Function.Name, toolCall.Function.Arguments)
|
|
// if err != nil {
|
|
// // 将错误信息作为结果返回
|
|
// errorResult := map[string]interface{}{
|
|
// "error": err.Error(),
|
|
// }
|
|
// resultBytes, _ := json.Marshal(errorResult)
|
|
// results[i].Result = resultBytes
|
|
// } else {
|
|
// // 将成功结果序列化
|
|
// resultBytes, err := json.Marshal(result)
|
|
// if err != nil {
|
|
// errorResult := map[string]interface{}{
|
|
// "error": fmt.Sprintf("failed to serialize result: %v", err),
|
|
// }
|
|
// resultBytes, _ = json.Marshal(errorResult)
|
|
// }
|
|
// results[i].Result = resultBytes
|
|
// }
|
|
// }
|
|
//
|
|
// return results, nil
|
|
//}
|