216 lines
8.6 KiB
Go
216 lines
8.6 KiB
Go
package service
|
||
|
||
import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"strings"
	"time"

	"eino-project/internal/domain/agent"
	"eino-project/internal/domain/llm"
	"eino-project/internal/domain/monitor"
	"eino-project/internal/domain/workflow"
	"eino-project/internal/pkg/adkutil"
	"eino-project/internal/pkg/sseutil"

	"github.com/cloudwego/eino/schema"
	"github.com/go-kratos/kratos/v2/log"
	"github.com/gorilla/websocket"
)
|
||
|
||
type ChatService struct {
|
||
models llm.LLM
|
||
log *log.Helper
|
||
monitor monitor.Monitor
|
||
}
|
||
|
||
func NewChatService(logger log.Logger, models llm.LLM, m monitor.Monitor) *ChatService {
|
||
return &ChatService{models: models, log: log.NewHelper(logger), monitor: m}
|
||
}
|
||
|
||
func (s *ChatService) HandleWebSocketChat(conn *websocket.Conn) {
|
||
defer conn.Close()
|
||
ctx := context.Background()
|
||
// 初始化两轮上下文(不输出,仅用于后续对话)
|
||
var dialogCtx []schema.Message = []schema.Message{
|
||
{Role: schema.User, Content: "你好"},
|
||
{Role: schema.Assistant, Content: "您好,请问需要什么帮助?"},
|
||
{Role: schema.User, Content: "查询订单 O271 的进度"},
|
||
{Role: schema.Assistant, Content: "订单 O271 已创建,正在处理"},
|
||
}
|
||
for {
|
||
mt, msg, err := conn.ReadMessage()
|
||
if err != nil {
|
||
return
|
||
}
|
||
if mt != websocket.TextMessage {
|
||
continue
|
||
}
|
||
var req struct {
|
||
Message string `json:"message"`
|
||
SessionID string `json:"session_id"`
|
||
}
|
||
if json.Unmarshal(msg, &req) != nil || req.Message == "" {
|
||
_ = sseutil.WSWriteJSON(conn, sseutil.BuildError(req.SessionID, "invalid input"))
|
||
_ = sseutil.WSWriteJSON(conn, sseutil.BuildDone(req.SessionID))
|
||
continue
|
||
}
|
||
// 追加到上下文并限制近2轮(最多保留4条)
|
||
dialogCtx = append(dialogCtx, schema.Message{Role: schema.User, Content: req.Message})
|
||
if len(dialogCtx) > 4 {
|
||
dialogCtx = dialogCtx[len(dialogCtx)-4:]
|
||
}
|
||
// 成本消耗记录(示意:写入DB) + Coze Loop 简化监控
|
||
_ = func() error {
|
||
// TODO: 记录成本消耗(token、模型、耗时等)
|
||
s.log.Infof("cost-record session=%s ts=%s", req.SessionID, time.Now().Format("2006-01-02 15:04:05"))
|
||
if s.monitor != nil {
|
||
_ = s.monitor.RecordRequest(ctx, "ws_chat", 0, true)
|
||
}
|
||
return nil
|
||
}()
|
||
|
||
_ = sseutil.WSWriteJSON(conn, sseutil.BuildLog(req.SessionID, "意图识别中"))
|
||
intentAgent := agent.NewIntentAgent(ctx, s.models)
|
||
var intent string
|
||
if intentAgent != nil {
|
||
r, err := adkutil.Query(ctx, intentAgent, req.Message)
|
||
if err == nil && r.Message != nil && r.Message.Content != "" {
|
||
var intentJson map[string]interface{}
|
||
if json.Unmarshal([]byte(r.Message.Content), &intentJson) == nil {
|
||
intent = intentJson["intent"].(string)
|
||
}
|
||
}
|
||
}
|
||
_ = sseutil.WSWriteJSON(conn, sseutil.BuildLog(req.SessionID, "意图识别结果:"+intent))
|
||
|
||
switch intent {
|
||
case "product":
|
||
_ = sseutil.WSWriteJSON(conn, sseutil.BuildProcess(req.SessionID, "进入产品工作流", false))
|
||
wf := workflow.NewZltxProductWorkflow(s.models)
|
||
startWF := time.Now()
|
||
out, err := wf.Run(ctx, req.Message)
|
||
if err != nil {
|
||
_ = sseutil.WSWriteJSON(conn, sseutil.BuildError(req.SessionID, err.Error()))
|
||
_ = sseutil.WSWriteJSON(conn, sseutil.BuildDone(req.SessionID))
|
||
continue
|
||
}
|
||
if s.monitor != nil {
|
||
_ = s.monitor.RecordAIRequest(ctx, time.Since(startWF), true)
|
||
_ = s.monitor.RecordLLMUsage(ctx, &monitor.LLMUsage{Model: "workflow:product", SessionID: req.SessionID, PromptPreview: req.Message, LatencyMS: time.Since(startWF).Milliseconds(), Timestamp: time.Now()})
|
||
}
|
||
var comp map[string]interface{}
|
||
if json.Unmarshal([]byte(out), &comp) != nil {
|
||
comp = map[string]interface{}{"result": out}
|
||
}
|
||
_ = sseutil.WSWriteJSON(conn, sseutil.BuildJson(req.SessionID, comp, true))
|
||
_ = sseutil.WSWriteJSON(conn, sseutil.BuildDone(req.SessionID))
|
||
case "order":
|
||
_ = sseutil.WSWriteJSON(conn, sseutil.BuildProcess(req.SessionID, "获取订单信息", false))
|
||
ordAgent := agent.NewOrderChatAgent(ctx, s.models)
|
||
ordr, err := adkutil.QueryWithLogger(ctx, ordAgent, req.Message, s.log)
|
||
if err != nil {
|
||
_ = sseutil.WSWriteJSON(conn, sseutil.BuildError(req.SessionID, err.Error()))
|
||
_ = sseutil.WSWriteJSON(conn, sseutil.BuildDone(req.SessionID))
|
||
continue
|
||
}
|
||
var orderInfo map[string]interface{}
|
||
if ordr.Customized != nil {
|
||
b, _ := json.Marshal(ordr.Customized)
|
||
_ = json.Unmarshal(b, &orderInfo)
|
||
} else if ordr.Message != nil {
|
||
_ = json.Unmarshal([]byte(ordr.Message.Content), &orderInfo)
|
||
}
|
||
if orderInfo["id"] == nil {
|
||
_ = sseutil.WSWriteJSON(conn, sseutil.BuildError(req.SessionID, "订单ID不存在"))
|
||
_ = sseutil.WSWriteJSON(conn, sseutil.BuildDone(req.SessionID))
|
||
continue
|
||
}
|
||
_ = sseutil.WSWriteJSON(conn, sseutil.BuildJson(req.SessionID, map[string]interface{}{"order": orderInfo}, false))
|
||
|
||
status := fmt.Sprintf("%v", orderInfo["status"])
|
||
if status != "" && status != "completed" && status != "delivered" {
|
||
_ = sseutil.WSWriteJSON(conn, sseutil.BuildProcess(req.SessionID, "订单未完成,生成诊断(Markdown流式)", false))
|
||
logAgent := agent.NewOrderLogAgent(ctx, s.models)
|
||
id := fmt.Sprintf("%v", orderInfo["id"])
|
||
res, err := adkutil.Query(ctx, logAgent, fmt.Sprintf("order_id=%s", id))
|
||
if err != nil {
|
||
s.log.Infof("order log agent error: %v", err)
|
||
_ = sseutil.WSWriteJSON(conn, sseutil.BuildError(req.SessionID, "日志查询失败"))
|
||
} else {
|
||
var logs []map[string]interface{}
|
||
if res.Customized != nil {
|
||
b, _ := json.Marshal(res.Customized)
|
||
_ = json.Unmarshal(b, &logs)
|
||
} else if res.Message != nil && res.Message.Content != "" {
|
||
_ = json.Unmarshal([]byte(res.Message.Content), &logs)
|
||
}
|
||
// 仅使用最近两轮上下文
|
||
ctx2 := dialogCtx
|
||
if len(ctx2) > 4 {
|
||
ctx2 = ctx2[len(ctx2)-4:]
|
||
}
|
||
ctxTextDialog, _ := json.Marshal(ctx2)
|
||
ctxTextOrder, _ := json.Marshal(orderInfo)
|
||
ctxTextLogs, _ := json.Marshal(logs)
|
||
prompt := fmt.Sprintf("基于最近两轮对话上下文、订单信息与处理日志生成Markdown诊断。\n\n上下文:%s\n\n订单:%s\n\n日志:%s\n\n请给出状态判定、关键日志条目与建议。", string(ctxTextDialog), string(ctxTextOrder), string(ctxTextLogs))
|
||
chatModel, err := s.models.Chat()
|
||
if err != nil {
|
||
_ = sseutil.WSWriteJSON(conn, sseutil.BuildError(req.SessionID, err.Error()))
|
||
} else {
|
||
reader, err := chatModel.Stream(ctx, []*schema.Message{{Role: schema.User, Content: prompt}})
|
||
if err != nil {
|
||
_ = sseutil.WSWriteJSON(conn, sseutil.BuildError(req.SessionID, "诊断流式生成失败"))
|
||
} else {
|
||
start := time.Now()
|
||
var full string
|
||
for {
|
||
chunk, rerr := reader.Recv()
|
||
if rerr != nil {
|
||
break
|
||
}
|
||
if chunk != nil && chunk.Content != "" {
|
||
full += chunk.Content
|
||
_ = sseutil.WSWriteJSON(conn, sseutil.BuildStreamChunk(req.SessionID, chunk.Content))
|
||
}
|
||
}
|
||
if s.monitor != nil {
|
||
_ = s.monitor.RecordAIRequest(ctx, time.Since(start), true)
|
||
_ = s.monitor.RecordLLMUsage(ctx, &monitor.LLMUsage{Model: "agent:order_diagnosis", SessionID: req.SessionID, PromptPreview: prompt, LatencyMS: time.Since(start).Milliseconds(), Timestamp: time.Now()})
|
||
}
|
||
_ = sseutil.WSWriteJSON(conn, sseutil.BuildStreamFinal(req.SessionID, full))
|
||
// 追加助手回复到上下文并限制两轮
|
||
dialogCtx = append(dialogCtx, schema.Message{Role: schema.Assistant, Content: full})
|
||
if len(dialogCtx) > 4 {
|
||
dialogCtx = dialogCtx[len(dialogCtx)-4:]
|
||
}
|
||
}
|
||
}
|
||
}
|
||
}
|
||
_ = sseutil.WSWriteJSON(conn, sseutil.BuildDone(req.SessionID))
|
||
default:
|
||
_ = sseutil.WSWriteJSON(conn, sseutil.BuildProcess(req.SessionID, "进入自然对话", false))
|
||
chatModel, err := s.models.Chat()
|
||
if err != nil {
|
||
_ = sseutil.WSWriteJSON(conn, sseutil.BuildError(req.SessionID, err.Error()))
|
||
_ = sseutil.WSWriteJSON(conn, sseutil.BuildDone(req.SessionID))
|
||
continue
|
||
}
|
||
start := time.Now()
|
||
resp, err := chatModel.Generate(ctx, []*schema.Message{{Role: schema.User, Content: req.Message}})
|
||
if err != nil || resp == nil {
|
||
_ = sseutil.WSWriteJSON(conn, sseutil.BuildError(req.SessionID, "对话失败"))
|
||
_ = sseutil.WSWriteJSON(conn, sseutil.BuildDone(req.SessionID))
|
||
continue
|
||
}
|
||
if s.monitor != nil {
|
||
_ = s.monitor.RecordAIRequest(ctx, time.Since(start), true)
|
||
_ = s.monitor.RecordLLMUsage(ctx, &monitor.LLMUsage{Model: "chat", SessionID: req.SessionID, PromptPreview: req.Message, LatencyMS: time.Since(start).Milliseconds(), Timestamp: time.Now()})
|
||
}
|
||
_ = sseutil.WSWriteJSON(conn, sseutil.BuildStreamFinal(req.SessionID, resp.Content))
|
||
_ = sseutil.WSWriteJSON(conn, sseutil.BuildDone(req.SessionID))
|
||
}
|
||
}
|
||
}
|