ai-courseware/eino-project/internal/service/chat.go

216 lines
8.6 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package service
import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"strings"
	"time"

	"eino-project/internal/domain/agent"
	"eino-project/internal/domain/llm"
	"eino-project/internal/domain/monitor"
	"eino-project/internal/domain/workflow"
	"eino-project/internal/pkg/adkutil"
	"eino-project/internal/pkg/sseutil"

	"github.com/cloudwego/eino/schema"
	"github.com/go-kratos/kratos/v2/log"
	"github.com/gorilla/websocket"
)
// ChatService handles chat traffic arriving over WebSocket (see
// HandleWebSocketChat): it classifies each user message by intent and routes
// it to a product workflow, an order-lookup agent, or a plain chat model.
type ChatService struct {
// models supplies the chat model and is passed to the agents and workflows.
models llm.LLM
// log wraps the logger given to NewChatService.
log *log.Helper
// monitor records request/LLM metrics; it may be nil, in which case metrics
// are skipped.
monitor monitor.Monitor
}
// NewChatService assembles a ChatService. logger is wrapped in a log.Helper;
// m may be nil to run without metrics.
func NewChatService(logger log.Logger, models llm.LLM, m monitor.Monitor) *ChatService {
	svc := new(ChatService)
	svc.models = models
	svc.monitor = m
	svc.log = log.NewHelper(logger)
	return svc
}
// wsDialogWindow is the number of messages kept in a connection's rolling
// dialog context: two user/assistant rounds.
const wsDialogWindow = 4

// HandleWebSocketChat serves one WebSocket chat connection until a read fails
// (typically because the client disconnected); conn is closed on return.
//
// Every text frame must be a JSON object {"message": ..., "session_id": ...}.
// For each valid request the handler emits log/progress events, classifies the
// message with the intent agent and dispatches it:
//   - "product": run the product workflow and reply with a JSON event;
//   - "order": look the order up and, when it has a status other than
//     "completed"/"delivered", stream a Markdown diagnosis;
//   - anything else (including failed classification): plain chat reply.
//
// Every request's event sequence, error paths included, ends with exactly one
// done event. Non-text frames are ignored.
//
// Write errors on conn are deliberately ignored here; a broken connection is
// expected to surface on the next ReadMessage, which ends the loop.
func (s *ChatService) HandleWebSocketChat(conn *websocket.Conn) {
	defer conn.Close()
	ctx := context.Background()

	// Seed two demo rounds of context. They are never sent to the client and
	// are only used as history for the order diagnosis prompt.
	dialog := []schema.Message{
		{Role: schema.User, Content: "你好"},
		{Role: schema.Assistant, Content: "您好,请问需要什么帮助?"},
		{Role: schema.User, Content: "查询订单 O271 的进度"},
		{Role: schema.Assistant, Content: "订单 O271 已创建,正在处理"},
	}

	for {
		mt, msg, err := conn.ReadMessage()
		if err != nil {
			return
		}
		if mt != websocket.TextMessage {
			continue
		}

		var req struct {
			Message   string `json:"message"`
			SessionID string `json:"session_id"`
		}
		if json.Unmarshal(msg, &req) != nil || req.Message == "" {
			_ = sseutil.WSWriteJSON(conn, sseutil.BuildError(req.SessionID, "invalid input"))
			_ = sseutil.WSWriteJSON(conn, sseutil.BuildDone(req.SessionID))
			continue
		}

		dialog = wsAppendDialog(dialog, schema.Message{Role: schema.User, Content: req.Message})

		// Cost-accounting placeholder (meant to be persisted to a DB) plus
		// simplified (Coze Loop style) request monitoring.
		// TODO: record token usage, model and latency per request.
		s.log.Infof("cost-record session=%s ts=%s", req.SessionID, time.Now().Format("2006-01-02 15:04:05"))
		if s.monitor != nil {
			_ = s.monitor.RecordRequest(ctx, "ws_chat", 0, true)
		}

		_ = sseutil.WSWriteJSON(conn, sseutil.BuildLog(req.SessionID, "意图识别中"))
		intent := s.wsRecognizeIntent(ctx, req.Message)
		_ = sseutil.WSWriteJSON(conn, sseutil.BuildLog(req.SessionID, "意图识别结果:"+intent))

		switch intent {
		case "product":
			s.wsHandleProduct(ctx, conn, req.SessionID, req.Message)
		case "order":
			dialog = s.wsHandleOrder(ctx, conn, req.SessionID, req.Message, dialog)
		default:
			s.wsHandleChat(ctx, conn, req.SessionID, req.Message)
		}
		// Each handler reports its own errors; the done event is always last.
		_ = sseutil.WSWriteJSON(conn, sseutil.BuildDone(req.SessionID))
	}
}

// wsRecognizeIntent asks the intent agent to classify message and returns the
// reported intent, or "" when the agent is unavailable, the query fails, or its
// output carries no usable "intent" string.
func (s *ChatService) wsRecognizeIntent(ctx context.Context, message string) string {
	intentAgent := agent.NewIntentAgent(ctx, s.models)
	if intentAgent == nil {
		return ""
	}
	r, err := adkutil.Query(ctx, intentAgent, message)
	if err != nil {
		s.log.Warnf("intent agent error: %v", err)
		return ""
	}
	if r.Message == nil || r.Message.Content == "" {
		return ""
	}
	return wsParseIntent(r.Message.Content)
}

// wsHandleProduct runs the product workflow for message and sends its output
// as a JSON event. Output that is not a JSON object is wrapped as
// {"result": output}.
func (s *ChatService) wsHandleProduct(ctx context.Context, conn *websocket.Conn, sessionID, message string) {
	_ = sseutil.WSWriteJSON(conn, sseutil.BuildProcess(sessionID, "进入产品工作流", false))
	wf := workflow.NewZltxProductWorkflow(s.models)
	start := time.Now()
	out, err := wf.Run(ctx, message)
	if err != nil {
		_ = sseutil.WSWriteJSON(conn, sseutil.BuildError(sessionID, err.Error()))
		return
	}
	if s.monitor != nil {
		elapsed := time.Since(start)
		_ = s.monitor.RecordAIRequest(ctx, elapsed, true)
		_ = s.monitor.RecordLLMUsage(ctx, &monitor.LLMUsage{Model: "workflow:product", SessionID: sessionID, PromptPreview: message, LatencyMS: elapsed.Milliseconds(), Timestamp: time.Now()})
	}
	var comp map[string]any
	if json.Unmarshal([]byte(out), &comp) != nil {
		comp = map[string]any{"result": out}
	}
	_ = sseutil.WSWriteJSON(conn, sseutil.BuildJson(sessionID, comp, true))
}

// wsHandleOrder looks up the order referenced by message and sends it as a
// JSON event. If the order reports a status other than "completed" or
// "delivered", it also fetches the order's logs and streams a Markdown
// diagnosis. It returns dialog, extended with the diagnosis when one was
// produced successfully.
func (s *ChatService) wsHandleOrder(ctx context.Context, conn *websocket.Conn, sessionID, message string, dialog []schema.Message) []schema.Message {
	_ = sseutil.WSWriteJSON(conn, sseutil.BuildProcess(sessionID, "获取订单信息", false))
	ordAgent := agent.NewOrderChatAgent(ctx, s.models)
	ordr, err := adkutil.QueryWithLogger(ctx, ordAgent, message, s.log)
	if err != nil {
		_ = sseutil.WSWriteJSON(conn, sseutil.BuildError(sessionID, err.Error()))
		return dialog
	}

	// Prefer the agent's structured output; fall back to its text reply.
	var orderInfo map[string]any
	switch {
	case ordr.Customized != nil:
		b, _ := json.Marshal(ordr.Customized)
		_ = json.Unmarshal(b, &orderInfo)
	case ordr.Message != nil:
		_ = json.Unmarshal([]byte(ordr.Message.Content), &orderInfo)
	}
	if orderInfo["id"] == nil {
		_ = sseutil.WSWriteJSON(conn, sseutil.BuildError(sessionID, "订单ID不存在"))
		return dialog
	}
	_ = sseutil.WSWriteJSON(conn, sseutil.BuildJson(sessionID, map[string]any{"order": orderInfo}, false))

	status := wsOrderStatus(orderInfo)
	if status == "" || status == "completed" || status == "delivered" {
		return dialog
	}

	_ = sseutil.WSWriteJSON(conn, sseutil.BuildProcess(sessionID, "订单未完成生成诊断Markdown流式", false))
	logAgent := agent.NewOrderLogAgent(ctx, s.models)
	id := fmt.Sprintf("%v", orderInfo["id"])
	res, err := adkutil.Query(ctx, logAgent, fmt.Sprintf("order_id=%s", id))
	if err != nil {
		s.log.Errorf("order log agent error: %v", err)
		_ = sseutil.WSWriteJSON(conn, sseutil.BuildError(sessionID, "日志查询失败"))
		return dialog
	}
	var logs []map[string]any
	switch {
	case res.Customized != nil:
		b, _ := json.Marshal(res.Customized)
		_ = json.Unmarshal(b, &logs)
	case res.Message != nil && res.Message.Content != "":
		_ = json.Unmarshal([]byte(res.Message.Content), &logs)
	}

	full, err := s.wsStreamDiagnosis(ctx, conn, sessionID, wsDiagnosisPrompt(dialog, orderInfo, logs))
	if err != nil {
		// The failure was already reported to the client (or the client is
		// gone); do not store a partial answer in the dialog history.
		return dialog
	}
	return wsAppendDialog(dialog, schema.Message{Role: schema.Assistant, Content: full})
}

// wsStreamDiagnosis streams the chat model's answer to prompt to the client
// chunk by chunk, then sends the complete text as a final event and returns
// it. Errors are reported to the client (except write failures) and returned.
func (s *ChatService) wsStreamDiagnosis(ctx context.Context, conn *websocket.Conn, sessionID, prompt string) (string, error) {
	chatModel, err := s.models.Chat()
	if err != nil {
		_ = sseutil.WSWriteJSON(conn, sseutil.BuildError(sessionID, err.Error()))
		return "", err
	}
	reader, err := chatModel.Stream(ctx, []*schema.Message{{Role: schema.User, Content: prompt}})
	if err != nil {
		_ = sseutil.WSWriteJSON(conn, sseutil.BuildError(sessionID, "诊断流式生成失败"))
		return "", err
	}
	// The stream must be closed to release its resources, on every path.
	defer reader.Close()

	start := time.Now()
	var full strings.Builder
	for {
		chunk, rerr := reader.Recv()
		if errors.Is(rerr, io.EOF) {
			break
		}
		if rerr != nil {
			// A mid-stream failure is not a completed answer.
			s.log.Errorf("order diagnosis stream error: %v", rerr)
			_ = sseutil.WSWriteJSON(conn, sseutil.BuildError(sessionID, "诊断流式生成失败"))
			return "", rerr
		}
		if chunk == nil || chunk.Content == "" {
			continue
		}
		full.WriteString(chunk.Content)
		if werr := sseutil.WSWriteJSON(conn, sseutil.BuildStreamChunk(sessionID, chunk.Content)); werr != nil {
			// The client is gone; stop consuming (and paying for) the stream.
			return "", werr
		}
	}

	text := full.String()
	if s.monitor != nil {
		elapsed := time.Since(start)
		_ = s.monitor.RecordAIRequest(ctx, elapsed, true)
		_ = s.monitor.RecordLLMUsage(ctx, &monitor.LLMUsage{Model: "agent:order_diagnosis", SessionID: sessionID, PromptPreview: prompt, LatencyMS: elapsed.Milliseconds(), Timestamp: time.Now()})
	}
	_ = sseutil.WSWriteJSON(conn, sseutil.BuildStreamFinal(sessionID, text))
	return text, nil
}

// wsHandleChat answers message with a single non-streamed chat completion.
//
// NOTE(review): this branch neither reads nor extends the dialog context, so
// the history used by the order diagnosis can contain consecutive user turns —
// confirm whether that is intended.
func (s *ChatService) wsHandleChat(ctx context.Context, conn *websocket.Conn, sessionID, message string) {
	_ = sseutil.WSWriteJSON(conn, sseutil.BuildProcess(sessionID, "进入自然对话", false))
	chatModel, err := s.models.Chat()
	if err != nil {
		_ = sseutil.WSWriteJSON(conn, sseutil.BuildError(sessionID, err.Error()))
		return
	}
	start := time.Now()
	resp, err := chatModel.Generate(ctx, []*schema.Message{{Role: schema.User, Content: message}})
	if err != nil || resp == nil {
		if err != nil {
			s.log.Errorf("chat generate error: %v", err)
		}
		_ = sseutil.WSWriteJSON(conn, sseutil.BuildError(sessionID, "对话失败"))
		return
	}
	if s.monitor != nil {
		elapsed := time.Since(start)
		_ = s.monitor.RecordAIRequest(ctx, elapsed, true)
		_ = s.monitor.RecordLLMUsage(ctx, &monitor.LLMUsage{Model: "chat", SessionID: sessionID, PromptPreview: message, LatencyMS: elapsed.Milliseconds(), Timestamp: time.Now()})
	}
	_ = sseutil.WSWriteJSON(conn, sseutil.BuildStreamFinal(sessionID, resp.Content))
}

// wsAppendDialog appends msg to dialog and drops the oldest entries so that at
// most wsDialogWindow messages remain.
func wsAppendDialog(dialog []schema.Message, msg schema.Message) []schema.Message {
	dialog = append(dialog, msg)
	if len(dialog) > wsDialogWindow {
		dialog = dialog[len(dialog)-wsDialogWindow:]
	}
	return dialog
}

// wsParseIntent extracts the "intent" string from the intent agent's JSON
// reply. Malformed JSON, a missing key, or a non-string value yields "" rather
// than panicking on a failed type assertion.
func wsParseIntent(content string) string {
	var payload map[string]any
	if err := json.Unmarshal([]byte(content), &payload); err != nil {
		return ""
	}
	intent, _ := payload["intent"].(string)
	return intent
}

// wsOrderStatus returns orderInfo["status"] formatted with %v, or "" when the
// key is absent or JSON null. (fmt.Sprintf("%v", nil) would yield "<nil>" and
// defeat an emptiness check.)
func wsOrderStatus(orderInfo map[string]any) string {
	v, ok := orderInfo["status"]
	if !ok || v == nil {
		return ""
	}
	return fmt.Sprintf("%v", v)
}

// wsDiagnosisPrompt builds the Markdown-diagnosis prompt from the most recent
// dialog messages, the order record and its logs, each embedded as JSON. A
// value that fails to marshal is embedded as an empty string.
func wsDiagnosisPrompt(dialog []schema.Message, orderInfo map[string]any, logs []map[string]any) string {
	if len(dialog) > wsDialogWindow {
		dialog = dialog[len(dialog)-wsDialogWindow:]
	}
	dialogJSON, _ := json.Marshal(dialog)
	orderJSON, _ := json.Marshal(orderInfo)
	logsJSON, _ := json.Marshal(logs)
	return fmt.Sprintf("基于最近两轮对话上下文、订单信息与处理日志生成Markdown诊断。\n\n上下文:%s\n\n订单:%s\n\n日志:%s\n\n请给出状态判定、关键日志条目与建议。", string(dialogJSON), string(orderJSON), string(logsJSON))
}