From ad66615c2e861bffd2deb4b2c0aa3ed5abd419a1 Mon Sep 17 00:00:00 2001 From: renzhiyuan <465386466@qq.com> Date: Wed, 24 Sep 2025 11:54:05 +0800 Subject: [PATCH] =?UTF-8?q?=E7=BB=93=E6=9E=84=E4=BF=AE=E6=94=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- internal/biz/router.go | 172 ++++++++++++------------ internal/entitys/response.go | 32 +++-- internal/entitys/types.go | 2 +- internal/pkg/utils_ollama/client.go | 6 +- internal/tools/konwledge_base.go | 23 ++-- internal/tools/manager.go | 2 +- internal/tools/zltx_order_detail.go | 14 +- internal/tools/zltx_order_direct_log.go | 17 +-- internal/tools/zltx_product.go | 8 +- internal/tools/zltx_statistics.go | 17 +-- 10 files changed, 146 insertions(+), 147 deletions(-) diff --git a/internal/biz/router.go b/internal/biz/router.go index 1ed0ed8..87684fc 100644 --- a/internal/biz/router.go +++ b/internal/biz/router.go @@ -75,50 +75,6 @@ func (r *AiRouterBiz) Route(ctx context.Context, req *entitys.ChatRequest) (*ent func (r *AiRouterBiz) RouteWithSocket(c *websocket.Conn, req *entitys.ChatSockRequest) (err error) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - //ch := r.channelPool.Get() - ch := make(chan entitys.ResponseData) - done := make(chan struct{}) - - go func() { - defer close(done) - for { - select { - case v, ok := <-ch: - if !ok { - return - } - // 带超时的发送,避免阻塞 - if err := sendWithTimeout(c, v, 2*time.Second); err != nil { - log.Errorf("Send error: %v", err) - cancel() // 通知主流程退出 - return - } - case <-ctx.Done(): - return - } - } - }() - - defer func() { - - if err != nil { - _ = entitys.MsgSend(c, entitys.ResponseData{ - Done: false, - Content: err.Error(), - Type: entitys.ResponseErr, - }) - } - _ = entitys.MsgSend(c, entitys.ResponseData{ - Done: true, - Content: "", - Type: entitys.ResponseEnd, - }) - //r.channelPool.Put(ch) - close(ch) - }() - session := c.Headers("X-Session", "") if len(session) == 0 { return errors.SessionNotFound @@ -132,6 +88,77 @@ func (r *AiRouterBiz) RouteWithSocket(c *websocket.Conn, req *entitys.ChatSockRe return errors.KeyNotFound } + var chat = make([]string, 0) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + //ch := r.channelPool.Get() + ch := make(chan entitys.Response) + done := make(chan struct{}) + go func() { + defer func() { + close(done) + if len(chat) > 0 { + + } + var his = []*model.AiChatHi{ + { + SessionID: session, + Role: "user", + Content: req.Text, + }, + } + if len(chat) > 0 { + his = append(his, &model.AiChatHi{ + SessionID: session, + Role: "assistant", + Content: strings.Join(chat, ""), + }) + } + for _, hi := range his { + r.hisImpl.Add(hi) + } + }() + for { + select { + case v, ok := <-ch: + if !ok { + return + } + // 带超时的发送,避免阻塞 + if err = sendWithTimeout(c, v, 2*time.Second); err != nil { + log.Errorf("Send error: %v", err) + cancel() // 通知主流程退出 + return + } + if v.Type == entitys.ResponseText || v.Type == entitys.ResponseStream { + chat = append(chat, v.Content) + } + + case <-ctx.Done(): + return + } + } + }() + + defer func() { + if err != nil { + _ = entitys.MsgSend(c, entitys.Response{ + + Content: err.Error(), + Type: entitys.ResponseErr, + }) + } + _ = entitys.MsgSend(c, entitys.Response{ + + Content: "", + Type: entitys.ResponseEnd, + }) + + //r.channelPool.Put(ch) + close(ch) + }() + sysInfo, err := r.getSysInfo(key) if err != nil { return errors.SysNotFound @@ -151,8 +178,8 @@ func (r *AiRouterBiz) RouteWithSocket(c *websocket.Conn, req *entitys.ChatSockRe AgentClient := r.utilAgent.Get() - ch <- entitys.ResponseData{ - Done: false, + ch <- entitys.Response{ + Index: "", Content: "准备意图识别", Type: entitys.ResponseLog, } @@ -165,46 +192,21 @@ func (r *AiRouterBiz) RouteWithSocket(c *websocket.Conn, req *entitys.ChatSockRe resMsg := match.Choices[0].Content r.utilAgent.Put(AgentClient) - ch <- entitys.ResponseData{ - Done: false, + ch <- entitys.Response{ + Index: "", Content: resMsg, Type: entitys.ResponseLog, } - ch <- entitys.ResponseData{ - Done: false, + ch <- entitys.Response{ + Index: "", Content: "意图识别结束", Type: entitys.ResponseLog, } - //for i := 1; i < 10; i++ { - // ch <- entitys.ResponseData{ - // Done: false, - // Content: fmt.Sprintf("%d", i), - // Type: entitys.ResponseLog, - // } - // time.Sleep(1 * time.Second) - //} - //return if err != nil { log.Errorf("LLM error: %v", err) return errors.SystemError } - - //msg, err := r.ollama.ToolSelect(ctx, r.getPromptOllama(sysInfo, history, req.Text), []api.Tool{}) - //if err != nil { - // return - //} - //resMsg := msg.Message.Content - select { - case ch <- entitys.ResponseData{ - Done: false, - Content: resMsg, - Type: entitys.ResponseLog, - }: - case <-ctx.Done(): - return ctx.Err() - } - var matchJson entitys.Match if err := json.Unmarshal([]byte(resMsg), &matchJson); err != nil { return errors.SystemError @@ -218,7 +220,7 @@ func (r *AiRouterBiz) RouteWithSocket(c *websocket.Conn, req *entitys.ChatSockRe } // 辅助函数:带超时的 WebSocket 发送 -func sendWithTimeout(c *websocket.Conn, data entitys.ResponseData, timeout time.Duration) error { +func sendWithTimeout(c *websocket.Conn, data entitys.Response, timeout time.Duration) error { sendCtx, cancel := context.WithTimeout(context.Background(), timeout) defer cancel() @@ -234,9 +236,9 @@ func sendWithTimeout(c *websocket.Conn, data entitys.ResponseData, timeout time. return sendCtx.Err() } } -func (r *AiRouterBiz) handleOtherTask(c *websocket.Conn, ch chan entitys.ResponseData, matchJson *entitys.Match) (err error) { - ch <- entitys.ResponseData{ - Done: false, +func (r *AiRouterBiz) handleOtherTask(c *websocket.Conn, ch chan entitys.Response, matchJson *entitys.Match) (err error) { + ch <- entitys.Response{ + Index: "", Content: matchJson.Reasoning, Type: entitys.ResponseText, } @@ -244,11 +246,11 @@ func (r *AiRouterBiz) handleOtherTask(c *websocket.Conn, ch chan entitys.Respons return } -func (r *AiRouterBiz) handleMatch(ctx context.Context, c *websocket.Conn, ch chan entitys.ResponseData, matchJson *entitys.Match, tasks []model.AiTask, sysInfo model.AiSy) (err error) { +func (r *AiRouterBiz) handleMatch(ctx context.Context, c *websocket.Conn, ch chan entitys.Response, matchJson *entitys.Match, tasks []model.AiTask, sysInfo model.AiSy) (err error) { if !matchJson.IsMatch { - ch <- entitys.ResponseData{ - Done: false, + ch <- entitys.Response{ + Index: "", Content: matchJson.Reasoning, Type: entitys.ResponseText, } @@ -277,7 +279,7 @@ func (r *AiRouterBiz) handleMatch(ctx context.Context, c *websocket.Conn, ch cha } } -func (r *AiRouterBiz) handleTask(channel chan entitys.ResponseData, c *websocket.Conn, matchJson *entitys.Match, task *model.AiTask) (err error) { +func (r *AiRouterBiz) handleTask(channel chan entitys.Response, c *websocket.Conn, matchJson *entitys.Match, task *model.AiTask) (err error) { var configData entitys.ConfigDataTool err = json.Unmarshal([]byte(task.Config), &configData) @@ -293,7 +295,7 @@ func (r *AiRouterBiz) handleTask(channel chan entitys.ResponseData, c *websocket } // 知识库 -func (r *AiRouterBiz) handleKnowle(channel chan entitys.ResponseData, c *websocket.Conn, matchJson *entitys.Match, task *model.AiTask, sysInfo model.AiSy) (err error) { +func (r *AiRouterBiz) handleKnowle(channel chan entitys.Response, c *websocket.Conn, matchJson *entitys.Match, task *model.AiTask, sysInfo model.AiSy) (err error) { var ( configData entitys.ConfigDataTool @@ -376,7 +378,7 @@ func (r *AiRouterBiz) handleKnowle(channel chan entitys.ResponseData, c *websock return } -func (r *AiRouterBiz) handleApiTask(channels chan entitys.ResponseData, c *websocket.Conn, matchJson *entitys.Match, task *model.AiTask) (err error) { +func (r *AiRouterBiz) handleApiTask(channels chan entitys.Response, c *websocket.Conn, matchJson *entitys.Match, task *model.AiTask) (err error) { var ( request l_request.Request auth = c.Headers("X-Authorization", "") diff --git a/internal/entitys/response.go b/internal/entitys/response.go index e9ecc18..c4e65ee 100644 --- a/internal/entitys/response.go +++ b/internal/entitys/response.go @@ -6,27 +6,33 @@ import ( "github.com/gofiber/websocket/v2" ) -type Response string +type ResponseType string const ( - ResponseJson Response = "json" - ResponseLoading Response = "loading" - ResponseEnd Response = "end" - ResponseStream Response = "stream" - ResponseText Response = "txt" - ResponseImg Response = "img" - ResponseFile Response = "file" - ResponseErr Response = "error" - ResponseLog Response = "log" + ResponseJson ResponseType = "json" + ResponseLoading ResponseType = "loading" + ResponseEnd ResponseType = "end" + ResponseStream ResponseType = "stream" + ResponseText ResponseType = "txt" + ResponseImg ResponseType = "img" + ResponseFile ResponseType = "file" + ResponseErr ResponseType = "error" + ResponseLog ResponseType = "log" ) type ResponseData struct { Done bool Content string - Type Response + Type ResponseType } -func MsgSet(msgType Response, msg string, done bool) []byte { +type Response struct { + Content string + Type ResponseType + Index string +} + +func MsgSet(msgType ResponseType, msg string, done bool) []byte { jsonByte, err := json.Marshal(ResponseData{ Done: done, Content: msg, @@ -39,7 +45,7 @@ func MsgSet(msgType Response, msg string, done bool) []byte { return jsonByte } -func MsgSend(c *websocket.Conn, msg ResponseData) error { +func MsgSend(c *websocket.Conn, msg Response) error { jsonByte, _ := json.Marshal(msg) return c.WriteMessage(websocket.TextMessage, jsonByte) diff --git a/internal/entitys/types.go b/internal/entitys/types.go index 1d10a01..ad12791 100644 --- a/internal/entitys/types.go +++ b/internal/entitys/types.go @@ -67,7 +67,7 @@ type Tool interface { Name() string Description() string Definition() ToolDefinition - Execute(channel chan ResponseData, c *websocket.Conn, args json.RawMessage) error + Execute(channel chan Response, c *websocket.Conn, args json.RawMessage) error } type ConfigDataHttp struct { diff --git a/internal/pkg/utils_ollama/client.go b/internal/pkg/utils_ollama/client.go index 9d9ac82..93b1985 100644 --- a/internal/pkg/utils_ollama/client.go +++ b/internal/pkg/utils_ollama/client.go @@ -60,7 +60,7 @@ func (c *Client) ToolSelect(ctx context.Context, messages []api.Message, tools [ return } -func (c *Client) ChatStream(ctx context.Context, ch chan entitys.ResponseData, messages []api.Message) (err error) { +func (c *Client) ChatStream(ctx context.Context, ch chan entitys.Response, messages []api.Message, index string) (err error) { // 构建聊天请求 req := &api.ChatRequest{ Model: c.config.Model, @@ -74,8 +74,8 @@ func (c *Client) ChatStream(ctx context.Context, ch chan entitys.ResponseData, m defer w.Done() err = c.client.Chat(ctx, req, func(resp api.ChatResponse) error { if resp.Message.Content != "" { - ch <- entitys.ResponseData{ - Done: false, + ch <- entitys.Response{ + Index: index, Content: resp.Message.Content, Type: entitys.ResponseStream, } diff --git a/internal/tools/konwledge_base.go b/internal/tools/konwledge_base.go index b614c11..f7ad5a9 100644 --- a/internal/tools/konwledge_base.go +++ b/internal/tools/konwledge_base.go @@ -7,10 +7,11 @@ import ( "bufio" "encoding/json" "fmt" - "github.com/gofiber/fiber/v2/log" - "github.com/gofiber/websocket/v2" "net/http" "strings" + + "github.com/gofiber/fiber/v2/log" + "github.com/gofiber/websocket/v2" ) // 知识库工具 @@ -59,7 +60,7 @@ func (k *KnowledgeBaseTool) Definition() entitys.ToolDefinition { } // Execute 执行知识库查询 -func (k *KnowledgeBaseTool) Execute(channel chan entitys.ResponseData, c *websocket.Conn, args json.RawMessage) error { +func (k *KnowledgeBaseTool) Execute(channel chan entitys.Response, c *websocket.Conn, args json.RawMessage) error { var params KnowledgeBaseRequest if err := json.Unmarshal(args, ¶ms); err != nil { @@ -93,14 +94,14 @@ type MsgContent struct { } // 解析知识库响应内容,并把通过channel结果返回 -func msgContentParse(input string, channel chan entitys.ResponseData) (msgContent MsgContent, err error) { +func (this *KnowledgeBaseTool) msgContentParse(input string, channel chan entitys.Response) (msgContent MsgContent, err error) { err = json.Unmarshal([]byte(input), &msgContent) if err != nil { err = fmt.Errorf("unmarshal input failed: %w", err) } - channel <- entitys.ResponseData{ - Done: msgContent.Done, + channel <- entitys.Response{ + Index: this.Name(), Content: msgContent.Content, Type: entitys.ResponseStream, } @@ -109,7 +110,7 @@ func msgContentParse(input string, channel chan entitys.ResponseData) (msgConten } // 请求知识库聊天 -func (this *KnowledgeBaseTool) chat(channel chan entitys.ResponseData, c *websocket.Conn, param KnowledgeBaseRequest) (err error) { +func (this *KnowledgeBaseTool) chat(channel chan entitys.Response, c *websocket.Conn, param KnowledgeBaseRequest) (err error) { req := l_request.Request{ Method: "post", @@ -136,7 +137,7 @@ func (this *KnowledgeBaseTool) chat(channel chan entitys.ResponseData, c *websoc } defer rsp.Body.Close() - err = connectAndReadSSE(rsp, channel) + err = this.connectAndReadSSE(rsp, channel) if err != nil { return } @@ -145,7 +146,7 @@ func (this *KnowledgeBaseTool) chat(channel chan entitys.ResponseData, c *websoc } // 连接 SSE 并读取数据 -func connectAndReadSSE(resp *http.Response, channel chan entitys.ResponseData) error { +func (this *KnowledgeBaseTool) connectAndReadSSE(resp *http.Response, channel chan entitys.Response) error { // 验证响应状态和格式 if resp.StatusCode != http.StatusOK { @@ -165,7 +166,7 @@ func connectAndReadSSE(resp *http.Response, channel chan entitys.ResponseData) e if line == "" { // 空行表示一条消息结束,处理当前消息 if currentMsg.Data != "" || currentMsg.Event != "" || currentMsg.ID != "" { - _, err := msgContentParse(currentMsg.Data, channel) + _, err := this.msgContentParse(currentMsg.Data, channel) if err != nil { return fmt.Errorf("msgContentParse failed: %w", err) } @@ -200,7 +201,7 @@ func connectAndReadSSE(resp *http.Response, channel chan entitys.ResponseData) e // 处理最后一条未结束的消息(无结尾空行) if currentMsg.Data != "" || currentMsg.Event != "" || currentMsg.ID != "" { - _, err := msgContentParse(currentMsg.Data, channel) + _, err := this.msgContentParse(currentMsg.Data, channel) if err != nil { return fmt.Errorf("msgContentParse failed: %w", err) } diff --git a/internal/tools/manager.go b/internal/tools/manager.go index 1aa0281..0851f86 100644 --- a/internal/tools/manager.go +++ b/internal/tools/manager.go @@ -100,7 +100,7 @@ func (m *Manager) GetToolDefinitions(caller constants.Caller) []entitys.ToolDefi } // ExecuteTool 执行工具 -func (m *Manager) ExecuteTool(channel chan entitys.ResponseData, c *websocket.Conn, name string, args json.RawMessage) error { +func (m *Manager) ExecuteTool(channel chan entitys.Response, c *websocket.Conn, name string, args json.RawMessage) error { tool, exists := m.GetTool(name) if !exists { return fmt.Errorf("tool not found: %s", name) diff --git a/internal/tools/zltx_order_detail.go b/internal/tools/zltx_order_detail.go index 2b5be0f..79f56f7 100644 --- a/internal/tools/zltx_order_detail.go +++ b/internal/tools/zltx_order_detail.go @@ -81,7 +81,7 @@ type ZltxOrderDetailData struct { } // Execute 执行直连天下订单详情查询 -func (w *ZltxOrderDetailTool) Execute(channel chan entitys.ResponseData, c *websocket.Conn, args json.RawMessage) error { +func (w *ZltxOrderDetailTool) Execute(channel chan entitys.Response, c *websocket.Conn, args json.RawMessage) error { var req ZltxOrderDetailRequest if err := json.Unmarshal(args, &req); err != nil { return fmt.Errorf("invalid zltxOrderDetail request: %w", err) @@ -96,7 +96,7 @@ func (w *ZltxOrderDetailTool) Execute(channel chan entitys.ResponseData, c *webs } // getMockZltxOrderDetail 获取模拟直连天下订单详情数据 -func (w *ZltxOrderDetailTool) getZltxOrderDetail(ch chan entitys.ResponseData, c *websocket.Conn, number string) (err error) { +func (w *ZltxOrderDetailTool) getZltxOrderDetail(ch chan entitys.Response, c *websocket.Conn, number string) (err error) { //查询订单详情 var auth string if c != nil { @@ -129,14 +129,14 @@ func (w *ZltxOrderDetailTool) getZltxOrderDetail(ch chan entitys.ResponseData, c if err = json.Unmarshal(res.Content, &resData); err != nil { return } - ch <- entitys.ResponseData{ - Done: false, + ch <- entitys.Response{ + Index: w.Name(), Content: res.Text, Type: entitys.ResponseJson, } if resData.Data.Direct != nil && resData.Data.Direct["needAi"].(bool) { - ch <- entitys.ResponseData{ - Done: false, + ch <- entitys.Response{ + Index: w.Name(), Content: "正在分析订单日志", Type: entitys.ResponseLoading, } @@ -173,7 +173,7 @@ func (w *ZltxOrderDetailTool) getZltxOrderDetail(ch chan entitys.ResponseData, c Role: "user", Content: fmt.Sprintf("订单日志:%s", string(dataJson)), }, - }) + }, w.Name()) if err != nil { return fmt.Errorf("订单日志解析失败:%s", err) } diff --git a/internal/tools/zltx_order_direct_log.go b/internal/tools/zltx_order_direct_log.go index 51ec941..c6a0a95 100644 --- a/internal/tools/zltx_order_direct_log.go +++ b/internal/tools/zltx_order_direct_log.go @@ -67,7 +67,7 @@ type ZltxOrderDirectLogData struct { Data map[string]interface{} `json:"data"` } -func (t *ZltxOrderLogTool) Execute(channel chan entitys.ResponseData, c *websocket.Conn, args json.RawMessage) error { +func (t *ZltxOrderLogTool) Execute(channel chan entitys.Response, c *websocket.Conn, args json.RawMessage) error { var req ZltxOrderLogRequest if err := json.Unmarshal(args, &req); err != nil { return fmt.Errorf("invalid zltxOrderLog request: %w", err) @@ -78,7 +78,7 @@ func (t *ZltxOrderLogTool) Execute(channel chan entitys.ResponseData, c *websock return t.getZltxOrderLog(channel, c, req.OrderNumber, req.SerialNumber) } -func (t *ZltxOrderLogTool) getZltxOrderLog(channel chan entitys.ResponseData, c *websocket.Conn, orderNumber, serialNumber string) (err error) { +func (t *ZltxOrderLogTool) getZltxOrderLog(channel chan entitys.Response, c *websocket.Conn, orderNumber, serialNumber string) (err error) { //查询订单详情 var auth string if c != nil { @@ -106,15 +106,10 @@ func (t *ZltxOrderLogTool) getZltxOrderLog(channel chan entitys.ResponseData, c if err = json.Unmarshal(res.Content, &resData); err != nil { return } - if c != nil { - _ = c.WriteMessage(websocket.TextMessage, res.Content) - return - } else { - channel <- entitys.ResponseData{ - Done: false, - Content: res.Text, - Type: entitys.ResponseJson, - } + channel <- entitys.Response{ + Index: t.Name(), + Content: res.Text, + Type: entitys.ResponseJson, } return } diff --git a/internal/tools/zltx_product.go b/internal/tools/zltx_product.go index 0bd3d80..71a7daf 100644 --- a/internal/tools/zltx_product.go +++ b/internal/tools/zltx_product.go @@ -52,7 +52,7 @@ type ZltxProductRequest struct { Name string `json:"name"` } -func (z ZltxProductTool) Execute(channel chan entitys.ResponseData, c *websocket.Conn, args json.RawMessage) error { +func (z ZltxProductTool) Execute(channel chan entitys.Response, c *websocket.Conn, args json.RawMessage) error { var req ZltxProductRequest if err := json.Unmarshal(args, &req); err != nil { return fmt.Errorf("invalid zltxProduct request: %w", err) @@ -132,7 +132,7 @@ type ZltxProductData struct { PlatformProductList interface{} `json:"platform_product_list"` } -func (z ZltxProductTool) getZltxProduct(channel chan entitys.ResponseData, c *websocket.Conn, id string, name string) error { +func (z ZltxProductTool) getZltxProduct(channel chan entitys.Response, c *websocket.Conn, id string, name string) error { var auth string if c != nil { auth = c.Headers("X-Authorization", "") @@ -196,8 +196,8 @@ func (z ZltxProductTool) getZltxProduct(channel chan entitys.ResponseData, c *we if err != nil { return err } - channel <- entitys.ResponseData{ - Done: false, + channel <- entitys.Response{ + Index: z.Name(), Content: string(marshal), Type: entitys.ResponseJson, } diff --git a/internal/tools/zltx_statistics.go b/internal/tools/zltx_statistics.go index 9cedc12..64bc80b 100644 --- a/internal/tools/zltx_statistics.go +++ b/internal/tools/zltx_statistics.go @@ -46,7 +46,7 @@ type ZltxOrderStatisticsRequest struct { Number string `json:"number"` } -func (z ZltxOrderStatisticsTool) Execute(channel chan entitys.ResponseData, c *websocket.Conn, args json.RawMessage) error { +func (z ZltxOrderStatisticsTool) Execute(channel chan entitys.Response, c *websocket.Conn, args json.RawMessage) error { var req ZltxOrderStatisticsRequest if err := json.Unmarshal(args, &req); err != nil { return err @@ -74,7 +74,7 @@ type ZltxOrderStatisticsData struct { Total int `json:"total"` } -func (z ZltxOrderStatisticsTool) getZltxOrderStatistics(channel chan entitys.ResponseData, c *websocket.Conn, number string) error { +func (z ZltxOrderStatisticsTool) getZltxOrderStatistics(channel chan entitys.Response, c *websocket.Conn, number string) error { //查询订单详情 var auth string if c != nil { @@ -102,15 +102,10 @@ func (z ZltxOrderStatisticsTool) getZltxOrderStatistics(channel chan entitys.Res if resData.Code != 200 { return fmt.Errorf("zltx order statistics error: %s", resData.Error) } - if c != nil { - _ = c.WriteMessage(websocket.TextMessage, res.Content) - return nil - } else { - channel <- entitys.ResponseData{ - Done: false, - Content: res.Text, - Type: entitys.ResponseJson, - } + channel <- entitys.Response{ + Index: z.Name(), + Content: res.Text, + Type: entitys.ResponseJson, } return nil }