diff --git a/internal/biz/router.go b/internal/biz/router.go index b2384f1..f721791 100644 --- a/internal/biz/router.go +++ b/internal/biz/router.go @@ -15,13 +15,11 @@ import ( "context" "encoding/json" "fmt" - "strings" "gitea.cdlsxd.cn/self-tools/l_request" "github.com/gofiber/fiber/v2/log" "github.com/gofiber/websocket/v2" - "github.com/ollama/ollama/api" "github.com/tmc/langchaingo/llms" "xorm.io/builder" ) @@ -35,7 +33,7 @@ type AiRouterBiz struct { taskImpl *impl.TaskImpl hisImpl *impl.ChatImpl conf *config.Config - ai *utils_ollama.Client + utilAgent *utils_ollama.UtilOllama } // NewRouterService 创建路由服务 @@ -47,7 +45,7 @@ func NewAiRouterBiz( taskImpl *impl.TaskImpl, hisImpl *impl.ChatImpl, conf *config.Config, - ai *utils_ollama.Client, + utilAgent *utils_ollama.UtilOllama, ) *AiRouterBiz { return &AiRouterBiz{ @@ -58,7 +56,7 @@ func NewAiRouterBiz( sysImpl: sysImpl, hisImpl: hisImpl, taskImpl: taskImpl, - ai: ai, + utilAgent: utilAgent, } } @@ -69,8 +67,15 @@ func (r *AiRouterBiz) Route(ctx context.Context, req *entitys.ChatRequest) (*ent } // Route 执行智能路由 -func (r *AiRouterBiz) RouteWithSocket(c *websocket.Conn, req *entitys.ChatSockRequest) error { +func (r *AiRouterBiz) RouteWithSocket(c *websocket.Conn, req *entitys.ChatSockRequest) (err error) { + var ch = make(chan []byte) + defer func() { + if err != nil { + _ = c.WriteMessage(websocket.TextMessage, []byte(err.Error())) + } + _ = c.WriteMessage(websocket.TextMessage, []byte("EOF")) + }() session := c.Headers("X-Session", "") if len(session) == 0 { return errors.SessionNotFound @@ -98,118 +103,53 @@ func (r *AiRouterBiz) RouteWithSocket(c *websocket.Conn, req *entitys.ChatSockRe if err != nil { return errors.SystemError } - - //toolDefinitions := r.registerTools(task) - //prompt := r.getPrompt(sysInfo, history, req.Text) - //意图预测 prompt := r.getPromptLLM(sysInfo, history, req.Text, task) - toolDefinitions := r.registerTools(task) - match, err := r.ai.ToolSelect(context.TODO(), prompt, toolDefinitions) + match, err := r.utilAgent.Llm.GenerateContent(context.TODO(), prompt, + llms.WithJSONMode(), + ) if err != nil { return errors.SystemError } - log.Info(match) - //var matchJson entitys.Match - //err = json.Unmarshal([]byte(match.Choices[0].Content), &matchJson) - //if err != nil { - // return errors.SystemError - //} - //return r.handleMatch(c, &matchJson, task) - c.WriteMessage(1, []byte(match.Message.Content)) - // 构建消息 - //messages := []entitys.Message{ - // { - // Role: "user", - // Content: req.UserInput, - // }, - //} - // - //// 第1次调用AI,获取用户意图 - //intentResponse, err := r.aiClient.Chat(ctx, messages, nil) - //if err != nil { - // return nil, fmt.Errorf("AI响应失败: %w", err) - //} - // - //// 从AI响应中提取意图 - //intent := r.extractIntent(intentResponse) - //if intent == "" { - // return nil, fmt.Errorf("未识别到用户意图") - //} - // - //switch intent { - //case "order_diagnosis": - // // 订单诊断意图 - // return r.handleOrderDiagnosis(ctx, req, messages) - //case "knowledge_qa": - // // 知识问答意图 - // return r.handleKnowledgeQA(ctx, req, messages) - //default: - // // 未知意图 - // return nil, fmt.Errorf("意图识别失败,请明确您的需求呢,我可以为您") - //} - // - //// 获取工具定义 - //toolDefinitions := r.toolManager.GetToolDefinitions(constants.Caller(req.Caller)) - // - //// 第2次调用AI,获取是否需要使用工具 - //response, err := r.aiClient.Chat(ctx, messages, toolDefinitions) - //if err != nil { - // return nil, fmt.Errorf("failed to chat with AI: %w", err) - //} - // - //// 如果没有工具调用,直接返回 - //if len(response.ToolCalls) == 0 { - // return response, nil - //} - // - //// 执行工具调用 - //toolResults, err := r.toolManager.ExecuteToolCalls(ctx, response.ToolCalls) - //if err != nil { - // return nil, fmt.Errorf("failed to execute tools: %w", err) - //} - // - //// 构建包含工具结果的消息 - //messages = append(messages, entitys.Message{ - // Role: "assistant", - // Content: response.Message, - //}) - // - //// 添加工具调用结果 - //for _, toolResult := range toolResults { - // toolResultStr, _ := json.Marshal(toolResult.Result) - // messages = append(messages, entitys.Message{ - // Role: "tool", - // Content: fmt.Sprintf("Tool %s result: %s", toolResult.Function.Name, string(toolResultStr)), - // }) - //} - // - //// 第二次调用AI,生成最终回复 - //finalResponse, err := r.aiClient.Chat(ctx, messages, nil) - //if err != nil { - // return nil, fmt.Errorf("failed to generate final response: %w", err) - //} - // - //// 合并工具调用信息到最终响应 - //finalResponse.ToolCalls = toolResults - // - //log.Printf("Router processed request: %s, used %d tools", req.UserInput, len(toolResults)) + log.Info(match.Choices[0].Content) + _ = c.WriteMessage(websocket.TextMessage, []byte(match.Choices[0].Content)) + var matchJson entitys.Match + err = json.Unmarshal([]byte(match.Choices[0].Content), &matchJson) + if err != nil { + return errors.SystemError + } + go func() { + defer func() { + if r := recover(); r != nil { + err = fmt.Errorf("recovered from panic: %v", r) + } + }() + defer close(ch) + err = r.handleMatch(c, ch, &matchJson, task) + if err != nil { + return + } + }() - //return finalResponse, nil - return nil + for v := range ch { + if err := c.WriteMessage(websocket.TextMessage, v); err != nil { + return err + } + } + _ = c.WriteMessage(websocket.TextMessage, []byte("结束")) + return } -func (r *AiRouterBiz) handleMatch(c *websocket.Conn, matchJson *entitys.Match, tasks []model.AiTask) (err error) { - var resChan = make(chan []byte, 10) - defer func() { - close(resChan) - if err != nil { - c.WriteMessage(websocket.TextMessage, []byte(err.Error())) - } - c.WriteMessage(websocket.TextMessage, []byte("EOF")) - }() +func (r *AiRouterBiz) handleOtherTask(c *websocket.Conn, ch chan []byte, matchJson *entitys.Match) (err error) { + ch <- []byte(matchJson.Reasoning) + + return +} + +func (r *AiRouterBiz) handleMatch(c *websocket.Conn, ch chan []byte, matchJson *entitys.Match, tasks []model.AiTask) (err error) { + if !matchJson.IsMatch { - c.WriteMessage(websocket.TextMessage, []byte(matchJson.Reasoning)) + ch <- []byte(matchJson.Reasoning) return } var pointTask *model.AiTask @@ -221,23 +161,16 @@ func (r *AiRouterBiz) handleMatch(c *websocket.Conn, matchJson *entitys.Match, t } if pointTask == nil || pointTask.Index == "other" { - return r.handleOtherTask(resChan, c, matchJson) + return r.handleOtherTask(c, ch, matchJson) } switch pointTask.Type { case constant.TaskTypeApi: - err = r.handleApiTask(resChan, c, matchJson, pointTask) + return r.handleApiTask(ch, c, matchJson, pointTask) case constant.TaskTypeFunc: - err = r.handleTask(resChan, c, matchJson, pointTask) + return r.handleTask(ch, c, matchJson, pointTask) default: - return r.handleOtherTask(resChan, c, matchJson) + return r.handleOtherTask(c, ch, matchJson) } - select { - case v := <-resChan: // 尝试接收 - fmt.Println("接收到值:", v) - default: - fmt.Println("无数据可接收") - } - return } func (r *AiRouterBiz) handleTask(channel chan []byte, c *websocket.Conn, matchJson *entitys.Match, task *model.AiTask) (err error) { @@ -255,11 +188,6 @@ func (r *AiRouterBiz) handleTask(channel chan []byte, c *websocket.Conn, matchJs return } -func (r *AiRouterBiz) handleOtherTask(channel chan []byte, c *websocket.Conn, matchJson *entitys.Match) (err error) { - channel <- []byte(matchJson.Reasoning) - return -} - func (r *AiRouterBiz) handleApiTask(channels chan []byte, c *websocket.Conn, matchJson *entitys.Match, task *model.AiTask) (err error) { var ( request l_request.Request @@ -325,25 +253,20 @@ func (r *AiRouterBiz) getTasks(sysId int32) (tasks []model.AiTask, err error) { return } -func (r *AiRouterBiz) registerTools(tasks []model.AiTask) []api.Tool { - taskPrompt := make([]api.Tool, 0) +func (r *AiRouterBiz) registerTools(tasks []model.AiTask) []llms.Tool { + taskPrompt := make([]llms.Tool, 0) for _, task := range tasks { - var taskConfig entitys.TaskConfigDetail + var taskConfig entitys.TaskConfig err := json.Unmarshal([]byte(task.Config), &taskConfig) if err != nil { continue } - - taskPrompt = append(taskPrompt, api.Tool{ + taskPrompt = append(taskPrompt, llms.Tool{ Type: "function", - Function: api.ToolFunction{ + Function: &llms.FunctionDefinition{ Name: task.Index, Description: task.Desc, - Parameters: api.ToolFunctionParameters{ - Type: taskConfig.Param.Type, - Required: taskConfig.Param.Required, - Properties: taskConfig.Param.Properties, - }, + Parameters: taskConfig.Param, }, }) @@ -368,22 +291,30 @@ func (r *AiRouterBiz) getPrompt(sysInfo model.AiSy, history []model.AiChatHi, re return prompt } -func (r *AiRouterBiz) getPromptLLM(sysInfo model.AiSy, history []model.AiChatHi, reqInput string, tasks []model.AiTask) []api.Message { +func (r *AiRouterBiz) getPromptLLM(sysInfo model.AiSy, history []model.AiChatHi, reqInput string, tasks []model.AiTask) []llms.MessageContent { var ( - prompt = make([]api.Message, 0) + prompt = make([]llms.MessageContent, 0) ) - prompt = append(prompt, api.Message{ - Role: string(llms.ChatMessageTypeSystem), - Content: r.buildSystemPrompt(sysInfo.SysPrompt), - }, api.Message{ - Role: "chatHistory", - Content: pkg.JsonStringIgonErr(r.buildAssistant(history)), - }, api.Message{ - Role: string(llms.ChatMessageTypeTool), - Content: pkg.JsonStringIgonErr(r.registerTools(tasks)), - }, api.Message{ - Role: string(llms.ChatMessageTypeHuman), - Content: reqInput, + prompt = append(prompt, llms.MessageContent{ + Role: llms.ChatMessageTypeSystem, + Parts: []llms.ContentPart{ + llms.TextPart(r.buildSystemPrompt(sysInfo.SysPrompt)), + }, + }, llms.MessageContent{ + Role: llms.ChatMessageTypeTool, + Parts: []llms.ContentPart{ + llms.TextPart(pkg.JsonStringIgonErr(r.buildAssistant(history))), + }, + }, llms.MessageContent{ + Role: llms.ChatMessageTypeTool, + Parts: []llms.ContentPart{ + llms.TextPart(pkg.JsonStringIgonErr(r.registerTools(tasks))), + }, + }, llms.MessageContent{ + Role: llms.ChatMessageTypeHuman, + Parts: []llms.ContentPart{ + llms.TextPart(reqInput), + }, }) return prompt } diff --git a/internal/entitys/types.go b/internal/entitys/types.go index bdb9702..820bedc 100644 --- a/internal/entitys/types.go +++ b/internal/entitys/types.go @@ -104,11 +104,11 @@ type ConfigParam struct { Type string `json:"type"` } type Match struct { - Confidence float64 `json:"confidence"` - Index string `json:"index"` - IsMatch bool `json:"is_match"` - Parameters string `json:"parameters"` - Reasoning string `json:"reasoning"` + Confidence string `json:"confidence"` + Index string `json:"index"` + IsMatch bool `json:"is_match"` + Parameters string `json:"parameters"` + Reasoning string `json:"reasoning"` } type ChatHis struct { SessionId string `json:"session_id"` diff --git a/internal/pkg/utils_ollama/client.go b/internal/pkg/utils_ollama/client.go new file mode 100644 index 0000000..bae73cc --- /dev/null +++ b/internal/pkg/utils_ollama/client.go @@ -0,0 +1,131 @@ +package utils_ollama + +import ( + "ai_scheduler/internal/config" + "ai_scheduler/internal/entitys" + "context" + "net/http" + "net/url" + "os" + "sync" + + "github.com/ollama/ollama/api" +) + +// Client Ollama客户端适配器 +type Client struct { + client *api.Client + config *config.OllamaConfig +} + +// NewClient 创建新的Ollama客户端 +func NewClient(config *config.Config) (client *Client, cleanFunc func(), err error) { + client = &Client{ + config: &config.Ollama, + } + url, err := client.getUrl() + if err != nil { + return + } + client.client = api.NewClient(url, http.DefaultClient) + + cleanup := func() { + if client != nil { + client = nil + } + } + + return client, cleanup, nil +} + +// ToolSelect 工具选择 +func (c *Client) ToolSelect(ctx context.Context, messages []api.Message, tools []api.Tool) (res api.ChatResponse, err error) { + // 构建聊天请求 + req := &api.ChatRequest{ + Model: c.config.Model, + Messages: messages, + Stream: new(bool), // 设置为false,不使用流式响应 + Think: &api.ThinkValue{Value: true}, + Tools: tools, + } + + err = c.client.Chat(ctx, req, func(resp api.ChatResponse) error { + res = resp + return nil + }) + if err != nil { + return + } + + return +} + +func (c *Client) ChatStream(ctx context.Context, ch chan []byte, messages []api.Message) (err error) { + // 构建聊天请求 + req := &api.ChatRequest{ + Model: c.config.Model, + Messages: messages, + Stream: nil, + Think: &api.ThinkValue{Value: true}, + } + var w sync.WaitGroup + w.Add(1) + go func() { + defer w.Done() + err = c.client.Chat(ctx, req, func(resp api.ChatResponse) error { + if resp.Message.Content != "" { + ch <- []byte(resp.Message.Content) + } + if resp.Done { + ch <- []byte("EOF") + } + return nil + }) + if err != nil { + ch <- []byte("EOF") + return + } + }() + w.Wait() + + return +} + +// convertResponse 转换响应格式 +func (c *Client) convertResponse(resp *api.ChatResponse) *entitys.ChatResponse { + //result := &entitys.ChatResponse{ + // Message: resp.Message.Content, + // Finished: resp.Done, + //} + // + //// 转换工具调用 + //if len(resp.Message.ToolCalls) > 0 { + // result.ToolCalls = make([]entitys.ToolCall, len(resp.Message.ToolCalls)) + // for i, toolCall := range resp.Message.ToolCalls { + // // 转换函数参数 + // argBytes, _ := json.Marshal(toolCall.Function.Arguments) + // + // result.ToolCalls[i] = entitys.ToolCall{ + // ID: fmt.Sprintf("call_%d", i), + // Type: "function", + // Function: entitys.FunctionCall{ + // Name: toolCall.Function.Name, + // Arguments: json.RawMessage(argBytes), + // }, + // } + // } + //} + + //return result + return nil +} + +func (c *Client) getUrl() (*url.URL, error) { + baseURL := c.config.BaseURL + envURL := os.Getenv("OLLAMA_BASE_URL") + if envURL != "" { + baseURL = envURL + } + + return url.Parse(baseURL) +} diff --git a/internal/test/task_test.go b/internal/test/task_test.go index f46832f..b50a2cf 100644 --- a/internal/test/task_test.go +++ b/internal/test/task_test.go @@ -4,7 +4,9 @@ import ( "ai_scheduler/internal/entitys" "ai_scheduler/internal/pkg/mapstructure" "encoding/json" + "fmt" "testing" + "time" "gitea.cdlsxd.cn/self-tools/l_request" ) @@ -36,6 +38,24 @@ func Test_task2(t *testing.T) { t.Log(err) } -func in() { - +func producer(ch chan<- int) { + for i := 0; i < 100; i++ { + ch <- i // 发送数据到通道 + fmt.Printf("Sent: %d\n", i) + time.Sleep(500 * time.Millisecond) // 模拟生产延迟 + } + close(ch) // 关闭通道,通知接收方数据发送完毕 +} + +func consumer(ch <-chan int) { + for v := range ch { // 阻塞等待数据,有数据立即处理 + fmt.Printf("Received: %d\n", v) + } +} + +func Test_a(t *testing.T) { + ch := make(chan int, 3) // 有缓冲通道(可选) + + go producer(ch) + consumer(ch) // 主线程阻塞,直到通道关闭 } diff --git a/internal/tools/manager.go b/internal/tools/manager.go index f048b58..4e4e1fe 100644 --- a/internal/tools/manager.go +++ b/internal/tools/manager.go @@ -15,11 +15,11 @@ import ( // Manager 工具管理器 type Manager struct { tools map[string]entitys.Tool - llm *utils_ollama.UtilOllama + llm *utils_ollama.Client } // NewManager 创建工具管理器 -func NewManager(config *config.Config, llm *utils_ollama.UtilOllama) *Manager { +func NewManager(config *config.Config, llm *utils_ollama.Client) *Manager { m := &Manager{ tools: make(map[string]entitys.Tool), llm: llm, diff --git a/internal/tools/zltx_order_detail.go b/internal/tools/zltx_order_detail.go index db89642..1181e74 100644 --- a/internal/tools/zltx_order_detail.go +++ b/internal/tools/zltx_order_detail.go @@ -4,21 +4,23 @@ import ( "ai_scheduler/internal/config" "ai_scheduler/internal/entitys" "ai_scheduler/internal/pkg/utils_ollama" + "context" "encoding/json" "fmt" "gitea.cdlsxd.cn/self-tools/l_request" "github.com/gofiber/websocket/v2" + "github.com/ollama/ollama/api" ) // ZltxOrderDetailTool 直连天下订单详情工具 type ZltxOrderDetailTool struct { config config.ToolConfig - llm *utils_ollama.UtilOllama + llm *utils_ollama.Client } // NewZltxOrderDetailTool 创建直连天下订单详情工具 -func NewZltxOrderDetailTool(config config.ToolConfig, llm *utils_ollama.UtilOllama) *ZltxOrderDetailTool { +func NewZltxOrderDetailTool(config config.ToolConfig, llm *utils_ollama.Client) *ZltxOrderDetailTool { return &ZltxOrderDetailTool{config: config, llm: llm} } @@ -63,6 +65,7 @@ type ZltxOrderDetailResponse struct { Code int `json:"code"` Error string `json:"error"` Data ZltxOrderDetailData `json:"data"` + Mes string `json:"mes"` } type ZltxOrderLogResponse struct { @@ -112,15 +115,20 @@ func (w *ZltxOrderDetailTool) getZltxOrderDetail(ch chan []byte, c *websocket.Co res, err := req.Send() if err != nil { + return fmt.Errorf("订单查询失败:%s", err.Error()) + } + var codeMap map[string]interface{} + if err = json.Unmarshal(res.Content, &codeMap); err != nil { return } + if codeMap["code"].(float64) != 200 { + return fmt.Errorf("订单查询失败:%s", res.Text) + } + var resData ZltxOrderDetailResponse if err = json.Unmarshal(res.Content, &resData); err != nil { return } - if resData.Code != 200 { - return fmt.Errorf("订单查询失败:%s", resData.Error) - } ch <- res.Content if resData.Data.Direct != nil && resData.Data.Direct["needAi"].(bool) { ch <- []byte("orderErrorChecking") @@ -142,7 +150,23 @@ func (w *ZltxOrderDetailTool) getZltxOrderDetail(ch chan []byte, c *websocket.Co if orderLog.Code != 200 { return fmt.Errorf("订单日志查询失败:%s", orderLog.Error) } - + dataJson, err := json.Marshal(orderLog.Data) + if err != nil { + return fmt.Errorf("订单日志解析失败:%s", err) + } + err = w.llm.ChatStream(context.TODO(), ch, []api.Message{ + { + Role: "system", + Content: "你是一个订单日志助手。用户可能会提供订单日志,你需要分析订单日志,提取出订单失败的原因。", + }, + { + Role: "user", + Content: fmt.Sprintf("订单日志:%s", string(dataJson)), + }, + }) + if err != nil { + return fmt.Errorf("订单日志解析失败:%s", err) + } } return