diff --git a/internal/biz/router.go b/internal/biz/router.go
index dac3809..b2384f1 100644
--- a/internal/biz/router.go
+++ b/internal/biz/router.go
@@ -21,6 +21,7 @@ import (
 	"gitea.cdlsxd.cn/self-tools/l_request"
 	"github.com/gofiber/fiber/v2/log"
 	"github.com/gofiber/websocket/v2"
+	"github.com/ollama/ollama/api"
 	"github.com/tmc/langchaingo/llms"
 	"xorm.io/builder"
 )
@@ -34,7 +35,7 @@ type AiRouterBiz struct {
 	taskImpl  *impl.TaskImpl
 	hisImpl   *impl.ChatImpl
 	conf      *config.Config
-	utilAgent *utils_ollama.UtilOllama
+	ai        *utils_ollama.Client
 }
 
 // NewRouterService 创建路由服务
@@ -46,7 +47,7 @@ func NewAiRouterBiz(
 	taskImpl *impl.TaskImpl,
 	hisImpl *impl.ChatImpl,
 	conf *config.Config,
-	utilAgent *utils_ollama.UtilOllama,
+	ai *utils_ollama.Client,
 ) *AiRouterBiz {
 
 	return &AiRouterBiz{
@@ -57,7 +58,7 @@ func NewAiRouterBiz(
 		sysImpl:  sysImpl,
 		hisImpl:  hisImpl,
 		taskImpl: taskImpl,
-		utilAgent: utilAgent,
+		ai:       ai,
 	}
 }
 
@@ -103,22 +104,22 @@ func (r *AiRouterBiz) RouteWithSocket(c *websocket.Conn, req *entitys.ChatSockRe
 	//意图预测
 	prompt := r.getPromptLLM(sysInfo, history, req.Text, task)
-	match, err := r.utilAgent.Llm.GenerateContent(context.TODO(), prompt,
-		//llms.WithTools(toolDefinitions),
-		//llms.WithToolChoice("tool_name"),
-		llms.WithJSONMode(),
-	)
+	toolDefinitions := r.registerTools(task)
+	match, err := r.ai.ToolSelect(context.TODO(), prompt, toolDefinitions)
 	if err != nil {
 		return errors.SystemError
 	}
 	log.Info(match)
 
-	var matchJson entitys.Match
-	err = json.Unmarshal([]byte(match.Choices[0].Content), &matchJson)
-	if err != nil {
-		return errors.SystemError
-	}
-	return r.handleMatch(c, &matchJson, task)
-	//c.WriteMessage(1, []byte(msg.Choices[0].Content))
+	//var matchJson entitys.Match
+	//err = json.Unmarshal([]byte(match.Choices[0].Content), &matchJson)
+	//if err != nil {
+	//	return errors.SystemError
+	//}
+	//return r.handleMatch(c, &matchJson, task)
+	// Write the selection result back to the client; do not swallow write errors.
+	if err = c.WriteMessage(websocket.TextMessage, []byte(match.Message.Content)); err != nil {
+		return errors.SystemError
+	}
 	// 构建消息
 	//messages := []entitys.Message{
 	//	{
@@ -327,20 +328,25 @@ func (r *AiRouterBiz) getTasks(sysId int32) (tasks []model.AiTask, err error) {
 	return
 }
 
-func (r *AiRouterBiz) registerTools(tasks []model.AiTask) []llms.Tool {
-	taskPrompt := make([]llms.Tool, 0)
+func (r *AiRouterBiz) registerTools(tasks []model.AiTask) []api.Tool {
+	taskPrompt := make([]api.Tool, 0)
 	for _, task := range tasks {
-		var taskConfig entitys.TaskConfig
+		var taskConfig entitys.TaskConfigDetail
 		err := json.Unmarshal([]byte(task.Config), &taskConfig)
 		if err != nil {
 			continue
 		}
-		taskPrompt = append(taskPrompt, llms.Tool{
+
+		taskPrompt = append(taskPrompt, api.Tool{
 			Type: "function",
-			Function: &llms.FunctionDefinition{
+			Function: api.ToolFunction{
 				Name:        task.Index,
 				Description: task.Desc,
-				Parameters:  taskConfig.Param,
+				Parameters: api.ToolFunctionParameters{
+					Type:       taskConfig.Param.Type,
+					Required:   taskConfig.Param.Required,
+					Properties: taskConfig.Param.Properties,
+				},
 			},
 		})
 	}
@@ -365,30 +371,22 @@ func (r *AiRouterBiz) getPrompt(sysInfo model.AiSy, history []model.AiChatHi, re
 	return prompt
 }
 
-func (r *AiRouterBiz) getPromptLLM(sysInfo model.AiSy, history []model.AiChatHi, reqInput string, tasks []model.AiTask) []llms.MessageContent {
+func (r *AiRouterBiz) getPromptLLM(sysInfo model.AiSy, history []model.AiChatHi, reqInput string, tasks []model.AiTask) []api.Message {
 	var (
-		prompt = make([]llms.MessageContent, 0)
+		prompt = make([]api.Message, 0)
 	)
-	prompt = append(prompt, llms.MessageContent{
-		Role: llms.ChatMessageTypeSystem,
-		Parts: []llms.ContentPart{
-			llms.TextPart(r.buildSystemPrompt(sysInfo.SysPrompt)),
-		},
-	}, llms.MessageContent{
-		Role: llms.ChatMessageTypeTool,
-		Parts: []llms.ContentPart{
-			llms.TextPart(pkg.JsonStringIgonErr(r.buildAssistant(history))),
-		},
-	}, llms.MessageContent{
-		Role: llms.ChatMessageTypeTool,
-		Parts: []llms.ContentPart{
-			llms.TextPart(pkg.JsonStringIgonErr(r.registerTools(tasks))),
-		},
-	}, llms.MessageContent{
-		Role: llms.ChatMessageTypeHuman,
-		Parts: []llms.ContentPart{
-			llms.TextPart(reqInput),
-		},
+	prompt = append(prompt, api.Message{
+		Role:    string(llms.ChatMessageTypeSystem),
+		Content: r.buildSystemPrompt(sysInfo.SysPrompt),
+	}, api.Message{
+		Role:    "assistant",
+		Content: pkg.JsonStringIgonErr(r.buildAssistant(history)),
+	}, api.Message{
+		Role:    string(llms.ChatMessageTypeTool),
+		Content: pkg.JsonStringIgonErr(r.registerTools(tasks)),
+	}, api.Message{
+		Role:    "user",
+		Content: reqInput,
 	})
 	return prompt
 }
diff --git a/internal/entitys/types.go b/internal/entitys/types.go
index 3e61305..bdb9702 100644
--- a/internal/entitys/types.go
+++ b/internal/entitys/types.go
@@ -6,6 +6,7 @@ import (
 
 	"gitea.cdlsxd.cn/self-tools/l_request"
 	"github.com/gofiber/websocket/v2"
+	"github.com/ollama/ollama/api"
 )
 
 // ChatRequest 聊天请求
@@ -93,6 +94,15 @@ type FuncApi struct {
 type TaskConfig struct {
 	Param interface{} `json:"param"`
 }
+
+type TaskConfigDetail struct {
+	Param ConfigParam `json:"param"`
+}
+type ConfigParam struct {
+	Properties map[string]api.ToolProperty `json:"properties"`
+	Required   []string                    `json:"required"`
+	Type       string                      `json:"type"`
+}
 type Match struct {
 	Confidence float64 `json:"confidence"`
 	Index      string  `json:"index"`
diff --git a/internal/pkg/provider_set.go b/internal/pkg/provider_set.go
index 5986313..d123eed 100644
--- a/internal/pkg/provider_set.go
+++ b/internal/pkg/provider_set.go
@@ -10,4 +10,5 @@ var ProviderSetClient = wire.NewSet(
 	NewRdb,
 	NewGormDb,
 	utils_ollama.NewUtilOllama,
+	utils_ollama.NewClient,
 )
diff --git a/internal/pkg/utils_ollama/client.go.bak b/internal/pkg/utils_ollama/client.go.bak
deleted file mode 100644
index d20b7f0..0000000
--- a/internal/pkg/utils_ollama/client.go.bak
+++ /dev/null
@@ -1,124 +0,0 @@
-package utils_ollama
-
-import (
-	"ai_scheduler/internal/config"
-	"ai_scheduler/internal/entitys"
-	"context"
-	"encoding/json"
-	"fmt"
-	"time"
-
-	"github.com/ollama/ollama/api"
-	"github.com/tmc/langchaingo/llms/ollama"
-)
-
-// Client Ollama客户端适配器
-type Client struct {
-	client *api.Client
-	config *config.OllamaConfig
-}
-
-// NewClient 创建新的Ollama客户端
-func NewClient(config *config.Config) (entitys.AIClient, func(), error) {
-	client, err := api.ClientFromEnvironment()
-	cleanup := func() {
-		if client != nil {
-			client = nil
-		}
-	}
-	if err != nil {
-		return nil, cleanup, fmt.Errorf("failed to create ollama client: %w", err)
-	}
-
-	return &Client{
-		client: client,
-		config: &config.Ollama,
-	}, cleanup, nil
-}
-
-// Chat 实现聊天功能
-func (c *Client) Chat(ctx context.Context, messages []entitys.Message, tools []entitys.ToolDefinition) (*entitys.ChatResponse, error) {
-	// 构建聊天请求
-	req := &api.ChatRequest{
-		Model:    c.config.Model,
-		Messages: make([]api.Message, len(messages)),
-		Stream:   new(bool), // 设置为false,不使用流式响应
-		Think:    &api.ThinkValue{Value: true},
-	}
-
-	// 转换消息格式
-	for i, msg := range messages {
-		req.Messages[i] = api.Message{
-			Role:    msg.Role,
-			Content: msg.Content,
-		}
-	}
-
-	// 添加工具定义
-	if len(tools) > 0 {
-		req.Tools = make([]api.Tool, len(tools))
-		for i, tool := range tools {
-			toolData, _ := json.Marshal(tool)
-			var apiTool api.Tool
-			json.Unmarshal(toolData, &apiTool)
-			req.Tools[i] = apiTool
-		}
-	}
-
-	// 发送请求
-	responseChan := make(chan api.ChatResponse)
-	errorChan := make(chan error)
-
-	go func() {
-		err := c.client.Chat(ctx, req, func(resp api.ChatResponse) error {
-			responseChan <- resp
-			return nil
-		})
-		if err != nil {
-			errorChan <- err
-		}
-		close(responseChan)
-		close(errorChan)
-	}()
-
-	// 等待响应
-	select {
-	case resp := <-responseChan:
-		return c.convertResponse(&resp), nil
-	case err := <-errorChan:
-		return nil, fmt.Errorf("chat request failed: %w", err)
-	case <-ctx.Done():
-		return nil, ctx.Err()
-	case <-time.After(c.config.Timeout):
-		return nil, fmt.Errorf("chat request timeout")
-	}
-}
-
-// convertResponse 转换响应格式
-func (c *Client) convertResponse(resp *api.ChatResponse) *entitys.ChatResponse {
-	//result := &entitys.ChatResponse{
-	//	Message:  resp.Message.Content,
-	//	Finished: resp.Done,
-	//}
-	//
-	//// 转换工具调用
-	//if len(resp.Message.ToolCalls) > 0 {
-	//	result.ToolCalls = make([]entitys.ToolCall, len(resp.Message.ToolCalls))
-	//	for i, toolCall := range resp.Message.ToolCalls {
-	//		// 转换函数参数
-	//		argBytes, _ := json.Marshal(toolCall.Function.Arguments)
-	//
-	//		result.ToolCalls[i] = entitys.ToolCall{
-	//			ID:   fmt.Sprintf("call_%d", i),
-	//			Type: "function",
-	//			Function: entitys.FunctionCall{
-	//				Name:      toolCall.Function.Name,
-	//				Arguments: json.RawMessage(argBytes),
-	//			},
-	//		}
-	//	}
-	//}
-
-	//return result
-	return nil
-}
diff --git a/internal/tools/manager.go b/internal/tools/manager.go
index 40b9694..f048b58 100644
--- a/internal/tools/manager.go
+++ b/internal/tools/manager.go
@@ -45,7 +45,7 @@ func NewManager(config *config.Config, llm *utils_ollama.UtilOllama) *Manager {
 
 	// 注册直连天下订单详情工具
 	if config.Tools.ZltxOrderDetail.Enabled {
-		zltxOrderDetailTool := NewZltxOrderDetailTool(config.Tools.ZltxOrderDetail)
+		zltxOrderDetailTool := NewZltxOrderDetailTool(config.Tools.ZltxOrderDetail, m.llm)
 		m.tools[zltxOrderDetailTool.Name()] = zltxOrderDetailTool
 	}
 
diff --git a/internal/tools/zltx_order_detail.go b/internal/tools/zltx_order_detail.go
index 44a83e9..db89642 100644
--- a/internal/tools/zltx_order_detail.go
+++ b/internal/tools/zltx_order_detail.go
@@ -3,6 +3,7 @@ package tools
 
 import (
 	"ai_scheduler/internal/config"
 	"ai_scheduler/internal/entitys"
+	"ai_scheduler/internal/pkg/utils_ollama"
 	"encoding/json"
 	"fmt"
@@ -13,11 +14,12 @@ import (
 // ZltxOrderDetailTool 直连天下订单详情工具
 type ZltxOrderDetailTool struct {
 	config config.ToolConfig
+	llm    *utils_ollama.UtilOllama
 }
 
 // NewZltxOrderDetailTool 创建直连天下订单详情工具
-func NewZltxOrderDetailTool(config config.ToolConfig) *ZltxOrderDetailTool {
-	return &ZltxOrderDetailTool{config: config}
+func NewZltxOrderDetailTool(config config.ToolConfig, llm *utils_ollama.UtilOllama) *ZltxOrderDetailTool {
+	return &ZltxOrderDetailTool{config: config, llm: llm}
 }
 
 // Name 返回工具名称