package service

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"strconv"
	"strings"
	"time"

	"eino-project/internal/domain/agent"
	"eino-project/internal/domain/llm"
	"eino-project/internal/domain/monitor"
	"eino-project/internal/domain/workflow"
	"eino-project/internal/pkg/adkutil"
	"eino-project/internal/pkg/sseutil"

	"github.com/cloudwego/eino/schema"
	"github.com/go-kratos/kratos/v2/log"
	"github.com/gorilla/websocket"
)

// chatDialogWindow caps the rolling dialog context at the most recent two
// rounds (one user message plus one assistant reply per round).
const chatDialogWindow = 4

// ChatService serves the WebSocket chat: it classifies each user message and
// routes it to the product workflow, the order agents, or plain chat.
type ChatService struct {
	models  llm.LLM
	log     *log.Helper
	monitor monitor.Monitor // optional; nil disables all monitoring calls
}

// NewChatService builds a ChatService. m may be nil, in which case no
// monitoring calls are made.
func NewChatService(logger log.Logger, models llm.LLM, m monitor.Monitor) *ChatService {
	return &ChatService{models: models, log: log.NewHelper(logger), monitor: m}
}

// HandleWebSocketChat serves conn until a read fails, then closes it.
//
// Each text frame must be a JSON object {"message": "...", "session_id": "..."};
// non-text frames are ignored. Every accepted frame produces log/progress
// frames, the payload of the branch chosen by intent recognition, and exactly
// one final "done" frame. Frames that fail to parse or carry an empty message
// get an "invalid input" error frame followed by "done".
//
// Write errors are deliberately ignored: a broken connection surfaces on the
// next ReadMessage, which ends the loop.
func (s *ChatService) HandleWebSocketChat(conn *websocket.Conn) {
	defer conn.Close()
	ctx := context.Background()

	// Seed two rounds of context. They are never sent to the client directly;
	// they only become part of the order-diagnosis prompt.
	dialogCtx := []schema.Message{
		{Role: schema.User, Content: "你好"},
		{Role: schema.Assistant, Content: "您好,请问需要什么帮助?"},
		{Role: schema.User, Content: "查询订单 O271 的进度"},
		{Role: schema.Assistant, Content: "订单 O271 已创建,正在处理"},
	}

	for {
		mt, msg, err := conn.ReadMessage()
		if err != nil {
			return
		}
		if mt != websocket.TextMessage {
			continue
		}

		var req struct {
			Message   string `json:"message"`
			SessionID string `json:"session_id"`
		}
		if err := json.Unmarshal(msg, &req); err != nil || req.Message == "" {
			_ = sseutil.WSWriteJSON(conn, sseutil.BuildError(req.SessionID, "invalid input"))
			_ = sseutil.WSWriteJSON(conn, sseutil.BuildDone(req.SessionID))
			continue
		}

		dialogCtx = appendChatDialog(dialogCtx, schema.Message{Role: schema.User, Content: req.Message})
		s.recordRequest(ctx, req.SessionID)

		_ = sseutil.WSWriteJSON(conn, sseutil.BuildLog(req.SessionID, "意图识别中"))
		intent := s.recognizeIntent(ctx, req.Message)
		_ = sseutil.WSWriteJSON(conn, sseutil.BuildLog(req.SessionID, "意图识别结果:"+intent))

		switch intent {
		case "product":
			s.handleProduct(ctx, conn, req.SessionID, req.Message)
		case "order":
			dialogCtx = s.handleOrder(ctx, conn, req.SessionID, req.Message, dialogCtx)
		default:
			dialogCtx = s.handleChat(ctx, conn, req.SessionID, req.Message, dialogCtx)
		}
		// Every branch, successful or not, is terminated by exactly one "done".
		_ = sseutil.WSWriteJSON(conn, sseutil.BuildDone(req.SessionID))
	}
}

// recordRequest logs a per-request cost marker and, when monitoring is
// enabled, counts the request under "ws_chat".
// TODO: record real cost data (tokens, model, latency).
func (s *ChatService) recordRequest(ctx context.Context, sessionID string) {
	s.log.Infof("cost-record session=%s ts=%s", sessionID, time.Now().Format("2006-01-02 15:04:05"))
	if s.monitor != nil {
		_ = s.monitor.RecordRequest(ctx, "ws_chat", 0, true)
	}
}

// recognizeIntent asks the intent agent to classify message. The agent's
// reply is parsed as a JSON object and its string "intent" field returned.
// Any failure (no agent, query error, empty or non-JSON reply, missing or
// non-string field) yields "", which routes the message to plain chat.
func (s *ChatService) recognizeIntent(ctx context.Context, message string) string {
	intentAgent := agent.NewIntentAgent(ctx, s.models)
	if intentAgent == nil {
		return ""
	}
	res, err := adkutil.Query(ctx, intentAgent, message)
	if err != nil {
		s.log.Warnf("intent agent error: %v", err)
		return ""
	}
	if res.Message == nil || res.Message.Content == "" {
		return ""
	}
	var parsed map[string]any
	if err := json.Unmarshal([]byte(res.Message.Content), &parsed); err != nil {
		return ""
	}
	// Comma-ok assertion: a missing or non-string "intent" must not panic.
	intent, _ := parsed["intent"].(string)
	return intent
}

// handleProduct runs the product workflow on message and sends the result as
// a JSON frame. Output that does not decode to a JSON object is wrapped as
// {"result": out}. Workflow errors are sent as an error frame.
func (s *ChatService) handleProduct(ctx context.Context, conn *websocket.Conn, sessionID, message string) {
	_ = sseutil.WSWriteJSON(conn, sseutil.BuildProcess(sessionID, "进入产品工作流", false))
	wf := workflow.NewZltxProductWorkflow(s.models)
	start := time.Now()
	out, err := wf.Run(ctx, message)
	if err != nil {
		_ = sseutil.WSWriteJSON(conn, sseutil.BuildError(sessionID, err.Error()))
		return
	}
	if s.monitor != nil {
		elapsed := time.Since(start)
		_ = s.monitor.RecordAIRequest(ctx, elapsed, true)
		_ = s.monitor.RecordLLMUsage(ctx, &monitor.LLMUsage{
			Model:         "workflow:product",
			SessionID:     sessionID,
			PromptPreview: message,
			LatencyMS:     elapsed.Milliseconds(),
			Timestamp:     time.Now(),
		})
	}
	var comp map[string]any
	// JSON "null" decodes without error into a nil map; wrap it like any other
	// non-object output so the frame always carries a payload.
	if err := json.Unmarshal([]byte(out), &comp); err != nil || comp == nil {
		comp = map[string]any{"result": out}
	}
	_ = sseutil.WSWriteJSON(conn, sseutil.BuildJson(sessionID, comp, true))
}

// handleOrder looks the order up via the order chat agent and sends it as a
// JSON frame. If the order has a status other than "completed"/"delivered",
// it also streams a Markdown diagnosis; the diagnosis text is appended to
// dialogCtx, and the (possibly updated) context is returned.
func (s *ChatService) handleOrder(ctx context.Context, conn *websocket.Conn, sessionID, message string, dialogCtx []schema.Message) []schema.Message {
	_ = sseutil.WSWriteJSON(conn, sseutil.BuildProcess(sessionID, "获取订单信息", false))
	ordAgent := agent.NewOrderChatAgent(ctx, s.models)
	orderRes, err := adkutil.QueryWithLogger(ctx, ordAgent, message, s.log)
	if err != nil {
		_ = sseutil.WSWriteJSON(conn, sseutil.BuildError(sessionID, err.Error()))
		return dialogCtx
	}

	// Prefer the structured payload; fall back to parsing the message text.
	// Decode failures leave orderInfo nil, which the id check below rejects.
	var orderInfo map[string]any
	if orderRes.Customized != nil {
		b, _ := json.Marshal(orderRes.Customized)
		_ = json.Unmarshal(b, &orderInfo)
	} else if orderRes.Message != nil {
		_ = json.Unmarshal([]byte(orderRes.Message.Content), &orderInfo)
	}
	if orderInfo["id"] == nil {
		_ = sseutil.WSWriteJSON(conn, sseutil.BuildError(sessionID, "订单ID不存在"))
		return dialogCtx
	}
	_ = sseutil.WSWriteJSON(conn, sseutil.BuildJson(sessionID, map[string]any{"order": orderInfo}, false))

	status := orderFieldString(orderInfo, "status")
	if status == "" || status == "completed" || status == "delivered" {
		return dialogCtx
	}

	_ = sseutil.WSWriteJSON(conn, sseutil.BuildProcess(sessionID, "订单未完成,生成诊断(Markdown流式)", false))
	if diagnosis, ok := s.streamDiagnosis(ctx, conn, sessionID, orderInfo, dialogCtx); ok {
		dialogCtx = appendChatDialog(dialogCtx, schema.Message{Role: schema.Assistant, Content: diagnosis})
	}
	return dialogCtx
}

// streamDiagnosis fetches the order's processing logs and streams a Markdown
// diagnosis built from the dialog context, the order and the logs. Each chunk
// is sent as a stream frame, followed by a stream-final frame with the full
// text. It returns the full text and true once the model stream was opened;
// failures before that point are sent as error frames and return false.
func (s *ChatService) streamDiagnosis(ctx context.Context, conn *websocket.Conn, sessionID string, orderInfo map[string]any, dialogCtx []schema.Message) (string, bool) {
	logAgent := agent.NewOrderLogAgent(ctx, s.models)
	id := orderFieldString(orderInfo, "id")
	logRes, err := adkutil.Query(ctx, logAgent, fmt.Sprintf("order_id=%s", id))
	if err != nil {
		s.log.Infof("order log agent error: %v", err)
		_ = sseutil.WSWriteJSON(conn, sseutil.BuildError(sessionID, "日志查询失败"))
		return "", false
	}

	var logs []map[string]any
	if logRes.Customized != nil {
		b, _ := json.Marshal(logRes.Customized)
		_ = json.Unmarshal(b, &logs)
	} else if logRes.Message != nil && logRes.Message.Content != "" {
		_ = json.Unmarshal([]byte(logRes.Message.Content), &logs)
	}

	// dialogCtx is already capped at chatDialogWindow messages by the caller.
	// The marshal errors are ignored: these values originate from JSON or from
	// plain message structs.
	ctxTextDialog, _ := json.Marshal(dialogCtx)
	ctxTextOrder, _ := json.Marshal(orderInfo)
	ctxTextLogs, _ := json.Marshal(logs)
	prompt := fmt.Sprintf("基于最近两轮对话上下文、订单信息与处理日志生成Markdown诊断。\n\n上下文:%s\n\n订单:%s\n\n日志:%s\n\n请给出状态判定、关键日志条目与建议。", string(ctxTextDialog), string(ctxTextOrder), string(ctxTextLogs))

	chatModel, err := s.models.Chat()
	if err != nil {
		_ = sseutil.WSWriteJSON(conn, sseutil.BuildError(sessionID, err.Error()))
		return "", false
	}
	// Start the clock before the call so latency covers stream setup too,
	// consistent with the product and chat paths.
	start := time.Now()
	reader, err := chatModel.Stream(ctx, []*schema.Message{{Role: schema.User, Content: prompt}})
	if err != nil {
		_ = sseutil.WSWriteJSON(conn, sseutil.BuildError(sessionID, "诊断流式生成失败"))
		return "", false
	}
	defer reader.Close()

	var full strings.Builder
	var streamErr error
	for {
		chunk, err := reader.Recv()
		if errors.Is(err, io.EOF) {
			break
		}
		if err != nil {
			// Keep whatever was streamed so far, but surface the failure.
			streamErr = err
			s.log.Warnf("order diagnosis stream error: %v", err)
			break
		}
		if chunk != nil && chunk.Content != "" {
			full.WriteString(chunk.Content)
			_ = sseutil.WSWriteJSON(conn, sseutil.BuildStreamChunk(sessionID, chunk.Content))
		}
	}
	text := full.String()

	if s.monitor != nil {
		elapsed := time.Since(start)
		_ = s.monitor.RecordAIRequest(ctx, elapsed, streamErr == nil)
		_ = s.monitor.RecordLLMUsage(ctx, &monitor.LLMUsage{
			Model:         "agent:order_diagnosis",
			SessionID:     sessionID,
			PromptPreview: prompt,
			LatencyMS:     elapsed.Milliseconds(),
			Timestamp:     time.Now(),
		})
	}
	_ = sseutil.WSWriteJSON(conn, sseutil.BuildStreamFinal(sessionID, text))
	return text, true
}

// handleChat answers message with a single (non-streamed) model call and sends
// the reply as a stream-final frame. On success the reply is appended to
// dialogCtx so the rolling context keeps alternating user/assistant turns.
func (s *ChatService) handleChat(ctx context.Context, conn *websocket.Conn, sessionID, message string, dialogCtx []schema.Message) []schema.Message {
	_ = sseutil.WSWriteJSON(conn, sseutil.BuildProcess(sessionID, "进入自然对话", false))
	chatModel, err := s.models.Chat()
	if err != nil {
		_ = sseutil.WSWriteJSON(conn, sseutil.BuildError(sessionID, err.Error()))
		return dialogCtx
	}
	start := time.Now()
	resp, err := chatModel.Generate(ctx, []*schema.Message{{Role: schema.User, Content: message}})
	if err != nil || resp == nil {
		s.log.Warnf("chat generate failed: err=%v", err)
		_ = sseutil.WSWriteJSON(conn, sseutil.BuildError(sessionID, "对话失败"))
		return dialogCtx
	}
	if s.monitor != nil {
		elapsed := time.Since(start)
		_ = s.monitor.RecordAIRequest(ctx, elapsed, true)
		_ = s.monitor.RecordLLMUsage(ctx, &monitor.LLMUsage{
			Model:         "chat",
			SessionID:     sessionID,
			PromptPreview: message,
			LatencyMS:     elapsed.Milliseconds(),
			Timestamp:     time.Now(),
		})
	}
	_ = sseutil.WSWriteJSON(conn, sseutil.BuildStreamFinal(sessionID, resp.Content))
	return appendChatDialog(dialogCtx, schema.Message{Role: schema.Assistant, Content: resp.Content})
}

// appendChatDialog appends msg and keeps only the most recent
// chatDialogWindow messages.
func appendChatDialog(dialog []schema.Message, msg schema.Message) []schema.Message {
	dialog = append(dialog, msg)
	if len(dialog) > chatDialogWindow {
		dialog = dialog[len(dialog)-chatDialogWindow:]
	}
	return dialog
}

// orderFieldString renders m[key] as text. A missing or JSON-null value yields
// "" (fmt's %v would print "<nil>"), and JSON numbers, which encoding/json
// decodes as float64, are printed in plain decimal instead of %v's exponent
// form (e.g. "10000000" rather than "1e+07").
func orderFieldString(m map[string]any, key string) string {
	switch v := m[key].(type) {
	case nil:
		return ""
	case string:
		return v
	case float64:
		return strconv.FormatFloat(v, 'f', -1, 64)
	default:
		return fmt.Sprintf("%v", v)
	}
}