146 lines
		
	
	
		
			4.0 KiB
		
	
	
	
		
			Go
		
	
	
	
			
		
		
	
	
			146 lines
		
	
	
		
			4.0 KiB
		
	
	
	
		
			Go
		
	
	
	
package tools
 | 
						|
 | 
						|
import (
 | 
						|
	"ai_scheduler/internal/config"
 | 
						|
	"ai_scheduler/internal/data/constants"
 | 
						|
	"ai_scheduler/internal/entitys"
 | 
						|
	"ai_scheduler/internal/pkg/utils_ollama"
 | 
						|
	"context"
 | 
						|
 | 
						|
	"fmt"
 | 
						|
)
 | 
						|
 | 
						|
// Manager 工具管理器
 | 
						|
type Manager struct {
 | 
						|
	tools map[string]entitys.Tool
 | 
						|
	llm   *utils_ollama.Client
 | 
						|
}
 | 
						|
 | 
						|
// NewManager 创建工具管理器
 | 
						|
func NewManager(config *config.Config, llm *utils_ollama.Client) *Manager {
 | 
						|
	m := &Manager{
 | 
						|
		tools: make(map[string]entitys.Tool),
 | 
						|
		llm:   llm,
 | 
						|
	}
 | 
						|
 | 
						|
	// 注册天气工具
 | 
						|
	//if config.Tools.Weather.Enabled {
 | 
						|
	//	weatherTool := NewWeatherTool()
 | 
						|
	//	m.tools[weatherTool.Name()] = weatherTool
 | 
						|
	//}
 | 
						|
	//
 | 
						|
	//// 注册计算器工具
 | 
						|
	//if config.Tools.Calculator.Enabled {
 | 
						|
	//	calcTool := NewCalculatorTool()
 | 
						|
	//	m.tools[calcTool.Name()] = calcTool
 | 
						|
	//}
 | 
						|
 | 
						|
	// 注册知识库工具
 | 
						|
	// if config.Knowledge.Enabled {
 | 
						|
	// 	knowledgeTool := NewKnowledgeTool()
 | 
						|
	// 	m.tools[knowledgeTool.Name()] = knowledgeTool
 | 
						|
	// }
 | 
						|
 | 
						|
	// 注册直连天下订单详情工具
 | 
						|
	if config.Tools.ZltxOrderDetail.Enabled {
 | 
						|
		zltxOrderDetailTool := NewZltxOrderDetailTool(config.Tools.ZltxOrderDetail, m.llm)
 | 
						|
		m.tools[zltxOrderDetailTool.Name()] = zltxOrderDetailTool
 | 
						|
	}
 | 
						|
 | 
						|
	//注册直连天下订单日志工具
 | 
						|
	if config.Tools.ZltxOrderDirectLog.Enabled {
 | 
						|
		zltxOrderLogTool := NewZltxOrderLogTool(config.Tools.ZltxOrderDirectLog)
 | 
						|
		m.tools[zltxOrderLogTool.Name()] = zltxOrderLogTool
 | 
						|
	}
 | 
						|
 | 
						|
	//注册直连天下商品工具
 | 
						|
	if config.Tools.ZltxProduct.Enabled {
 | 
						|
		zltxProductTool := NewZltxProductTool(config.Tools.ZltxProduct)
 | 
						|
		m.tools[zltxProductTool.Name()] = zltxProductTool
 | 
						|
	}
 | 
						|
	//注册直连天下订单统计工具
 | 
						|
	if config.Tools.ZltxOrderStatistics.Enabled {
 | 
						|
		zltxOrderStatisticsTool := NewZltxOrderStatisticsTool(config.Tools.ZltxOrderStatistics)
 | 
						|
		m.tools[zltxOrderStatisticsTool.Name()] = zltxOrderStatisticsTool
 | 
						|
	}
 | 
						|
 | 
						|
	// 注册知识库工具
 | 
						|
	if config.Tools.Knowledge.Enabled {
 | 
						|
		knowledgeTool := NewKnowledgeBaseTool(config.Tools.Knowledge)
 | 
						|
		m.tools[knowledgeTool.Name()] = knowledgeTool
 | 
						|
	}
 | 
						|
 | 
						|
	// 普通对话
 | 
						|
	chat := NewNormalChatTool(m.llm, config)
 | 
						|
	m.tools[chat.Name()] = chat
 | 
						|
 | 
						|
	return m
 | 
						|
}
 | 
						|
 | 
						|
// GetTool 获取工具
 | 
						|
func (m *Manager) GetTool(name string) (entitys.Tool, bool) {
 | 
						|
	tool, exists := m.tools[name]
 | 
						|
	return tool, exists
 | 
						|
}
 | 
						|
 | 
						|
// GetAllTools 获取所有工具
 | 
						|
func (m *Manager) GetAllTools() []entitys.Tool {
 | 
						|
	tools := make([]entitys.Tool, 0, len(m.tools))
 | 
						|
	for _, tool := range m.tools {
 | 
						|
		tools = append(tools, tool)
 | 
						|
	}
 | 
						|
	return tools
 | 
						|
}
 | 
						|
 | 
						|
// GetToolDefinitions 获取所有工具定义
 | 
						|
func (m *Manager) GetToolDefinitions(caller constants.Caller) []entitys.ToolDefinition {
 | 
						|
	definitions := make([]entitys.ToolDefinition, 0, len(m.tools))
 | 
						|
	for _, tool := range m.tools {
 | 
						|
		definitions = append(definitions, tool.Definition())
 | 
						|
	}
 | 
						|
 | 
						|
	return definitions
 | 
						|
}
 | 
						|
 | 
						|
// ExecuteTool 执行工具
 | 
						|
func (m *Manager) ExecuteTool(ctx context.Context, name string, requireData *entitys.RequireData) error {
 | 
						|
	tool, exists := m.GetTool(name)
 | 
						|
	if !exists {
 | 
						|
		return fmt.Errorf("tool not found: %s", name)
 | 
						|
	}
 | 
						|
 | 
						|
	return tool.Execute(ctx, requireData)
 | 
						|
}
 | 
						|
 | 
						|
// ExecuteToolCalls 执行多个工具调用
 | 
						|
//func (m *Manager) ExecuteToolCalls(c *websocket.Conn, toolCalls []entitys.ToolCall) ([]entitys.ToolCall, error) {
 | 
						|
//	results := make([]entitys.ToolCall, len(toolCalls))
 | 
						|
//
 | 
						|
//	for i, toolCall := range toolCalls {
 | 
						|
//		results[i] = toolCall
 | 
						|
//
 | 
						|
//		// 执行工具
 | 
						|
//		err := m.ExecuteTool(c, toolCall.Function.Name, toolCall.Function.Arguments)
 | 
						|
//		if err != nil {
 | 
						|
//			// 将错误信息作为结果返回
 | 
						|
//			errorResult := map[string]interface{}{
 | 
						|
//				"error": err.Error(),
 | 
						|
//			}
 | 
						|
//			resultBytes, _ := json.Marshal(errorResult)
 | 
						|
//			results[i].Result = resultBytes
 | 
						|
//		} else {
 | 
						|
//			// 将成功结果序列化
 | 
						|
//			resultBytes, err := json.Marshal(result)
 | 
						|
//			if err != nil {
 | 
						|
//				errorResult := map[string]interface{}{
 | 
						|
//					"error": fmt.Sprintf("failed to serialize result: %v", err),
 | 
						|
//				}
 | 
						|
//				resultBytes, _ = json.Marshal(errorResult)
 | 
						|
//			}
 | 
						|
//			results[i].Result = resultBytes
 | 
						|
//		}
 | 
						|
//	}
 | 
						|
//
 | 
						|
//	return results, nil
 | 
						|
//}
 |