101 lines
2.4 KiB
Go
101 lines
2.4 KiB
Go
package tools
|
|
|
|
import (
	"context"
	"encoding/json"
	"fmt"
	"sort"

	"ai_scheduler/internal/config"
	"ai_scheduler/pkg/types"
)
|
|
|
|
// Manager 工具管理器
|
|
type Manager struct {
|
|
tools map[string]types.Tool
|
|
}
|
|
|
|
// NewManager 创建工具管理器
|
|
func NewManager(config *config.ToolsConfig) *Manager {
|
|
m := &Manager{
|
|
tools: make(map[string]types.Tool),
|
|
}
|
|
|
|
// 注册天气工具
|
|
if config.Weather.Enabled {
|
|
weatherTool := NewWeatherTool(config.Weather.MockData)
|
|
m.tools[weatherTool.Name()] = weatherTool
|
|
}
|
|
|
|
// 注册计算器工具
|
|
if config.Calculator.Enabled {
|
|
calcTool := NewCalculatorTool()
|
|
m.tools[calcTool.Name()] = calcTool
|
|
}
|
|
|
|
return m
|
|
}
|
|
|
|
// GetTool 获取工具
|
|
func (m *Manager) GetTool(name string) (types.Tool, bool) {
|
|
tool, exists := m.tools[name]
|
|
return tool, exists
|
|
}
|
|
|
|
// GetAllTools 获取所有工具
|
|
func (m *Manager) GetAllTools() []types.Tool {
|
|
tools := make([]types.Tool, 0, len(m.tools))
|
|
for _, tool := range m.tools {
|
|
tools = append(tools, tool)
|
|
}
|
|
return tools
|
|
}
|
|
|
|
// GetToolDefinitions 获取所有工具定义
|
|
func (m *Manager) GetToolDefinitions() []types.ToolDefinition {
|
|
definitions := make([]types.ToolDefinition, 0, len(m.tools))
|
|
for _, tool := range m.tools {
|
|
definitions = append(definitions, tool.Definition())
|
|
}
|
|
return definitions
|
|
}
|
|
|
|
// ExecuteTool 执行工具
|
|
func (m *Manager) ExecuteTool(ctx context.Context, name string, args json.RawMessage) (interface{}, error) {
|
|
tool, exists := m.GetTool(name)
|
|
if !exists {
|
|
return nil, fmt.Errorf("tool not found: %s", name)
|
|
}
|
|
|
|
return tool.Execute(ctx, args)
|
|
}
|
|
|
|
// ExecuteToolCalls 执行多个工具调用
|
|
func (m *Manager) ExecuteToolCalls(ctx context.Context, toolCalls []types.ToolCall) ([]types.ToolCall, error) {
|
|
results := make([]types.ToolCall, len(toolCalls))
|
|
|
|
for i, toolCall := range toolCalls {
|
|
results[i] = toolCall
|
|
|
|
// 执行工具
|
|
result, err := m.ExecuteTool(ctx, toolCall.Function.Name, toolCall.Function.Arguments)
|
|
if err != nil {
|
|
// 将错误信息作为结果返回
|
|
errorResult := map[string]interface{}{
|
|
"error": err.Error(),
|
|
}
|
|
resultBytes, _ := json.Marshal(errorResult)
|
|
results[i].Result = resultBytes
|
|
} else {
|
|
// 将成功结果序列化
|
|
resultBytes, err := json.Marshal(result)
|
|
if err != nil {
|
|
errorResult := map[string]interface{}{
|
|
"error": fmt.Sprintf("failed to serialize result: %v", err),
|
|
}
|
|
resultBytes, _ = json.Marshal(errorResult)
|
|
}
|
|
results[i].Result = resultBytes
|
|
}
|
|
}
|
|
|
|
return results, nil
|
|
} |