ai-courseware/eino-project/internal/domain/workflow/chat_workflow.go

109 lines
2.8 KiB
Go

package workflow
import (
	"context"
	"errors"
	"fmt"
	"strings"
	"time"

	contextpkg "eino-project/internal/domain/context"
	"eino-project/internal/domain/llm"
	"eino-project/internal/domain/vector"
	"github.com/cloudwego/eino/schema"
)
// ChatWorkflow produces chat-model replies to user messages within a session.
type ChatWorkflow interface {
	// Chat returns the complete reply to message for the given session.
	Chat(ctx context.Context, message, sessionID string) (string, error)

	// Stream returns a channel that yields the reply to message in pieces
	// for the given session.
	Stream(ctx context.Context, message, sessionID string) (<-chan string, error)
}
// chatWorkflow is the concrete ChatWorkflow returned by NewChatWorkflow.
// The searcher and ctxMgr dependencies are nil-checked before use.
type chatWorkflow struct {
	models   llm.LLM                   // provides the chat model used to generate replies
	searcher vector.KnowledgeSearcher  // knowledge-base lookup; nil skips retrieval
	ctxMgr   contextpkg.ContextManager // session message store; nil skips recording
}
// NewChatWorkflow wires models, searcher and ctxMgr into a ChatWorkflow.
// The dependencies are stored as given; searcher and ctxMgr may be nil, in
// which case knowledge retrieval and session recording are skipped.
func NewChatWorkflow(models llm.LLM, searcher vector.KnowledgeSearcher, ctxMgr contextpkg.ContextManager) ChatWorkflow {
	wf := new(chatWorkflow)
	wf.models = models
	wf.searcher = searcher
	wf.ctxMgr = ctxMgr
	return wf
}
// defaultKnowledgeTopK is how many knowledge-base hits are folded into a prompt.
const defaultKnowledgeTopK = 3

// Chat sends message to the chat model and returns the complete reply.
//
// The user message is first recorded for sessionID (when a context manager is
// configured), then optionally augmented with knowledge-base snippets (when a
// searcher is configured) before being sent to the model as a single user
// turn. It returns an error if the chat model cannot be obtained, generation
// fails, or the model returns a nil response.
func (w *chatWorkflow) Chat(ctx context.Context, message string, sessionID string) (string, error) {
	w.recordUserMessage(ctx, sessionID, message)
	msgs := w.buildMessages(ctx, message)

	chatModel, err := w.models.Chat()
	if err != nil {
		return "", fmt.Errorf("get chat model: %w", err)
	}
	resp, err := chatModel.Generate(ctx, msgs)
	if err != nil {
		return "", fmt.Errorf("generate reply: %w", err)
	}
	if resp == nil {
		// Report a nil reply instead of passing it off as an empty answer.
		return "", errors.New("chat model returned nil response")
	}
	return resp.Content, nil
}

// Stream sends message to the chat model and returns a channel that yields
// the reply incrementally as non-empty content chunks.
//
// Like Chat, it records the user message for sessionID and augments the prompt
// with knowledge-base snippets. Errors obtaining the model or opening the
// stream are returned directly. The returned channel is closed when the model
// stream ends, when receiving from it fails, or when ctx is cancelled while a
// chunk is waiting to be delivered; mid-stream errors cannot be carried by the
// string channel and simply end it.
func (w *chatWorkflow) Stream(ctx context.Context, message string, sessionID string) (<-chan string, error) {
	w.recordUserMessage(ctx, sessionID, message)
	msgs := w.buildMessages(ctx, message)

	chatModel, err := w.models.Chat()
	if err != nil {
		return nil, fmt.Errorf("get chat model: %w", err)
	}
	reader, err := chatModel.Stream(ctx, msgs)
	if err != nil {
		return nil, fmt.Errorf("open reply stream: %w", err)
	}

	ch := make(chan string, 8)
	go func() {
		defer close(ch)
		// Release the underlying model stream however the loop exits.
		defer reader.Close()
		for {
			chunk, err := reader.Recv()
			if err != nil {
				// io.EOF is normal completion; other errors are dropped
				// because the channel cannot carry them.
				return
			}
			if chunk == nil || chunk.Content == "" {
				continue
			}
			select {
			case ch <- chunk.Content:
			case <-ctx.Done():
				// The consumer may have stopped reading; exit instead of
				// blocking on the send forever.
				return
			}
		}
	}()
	return ch, nil
}

// recordUserMessage stores message as a "user" entry for sessionID when a
// context manager is configured. Any result of AddMessage is ignored, so
// recording is best-effort.
func (w *chatWorkflow) recordUserMessage(ctx context.Context, sessionID, message string) {
	if w.ctxMgr == nil {
		return
	}
	w.ctxMgr.AddMessage(ctx, sessionID, contextpkg.Message{Role: "user", Content: message, Timestamp: time.Now()})
}

// buildMessages returns the single-turn prompt sent to the model.
func (w *chatWorkflow) buildMessages(ctx context.Context, message string) []*schema.Message {
	return []*schema.Message{{Role: schema.User, Content: w.enhancePrompt(ctx, message)}}
}

// enhancePrompt wraps message with up to defaultKnowledgeTopK knowledge-base
// snippets. Retrieval is best-effort: a nil searcher, a search error, or an
// empty result set leaves message unchanged.
func (w *chatWorkflow) enhancePrompt(ctx context.Context, message string) string {
	if w.searcher == nil {
		return message
	}
	results, err := w.searcher.SearchKnowledge(ctx, message, defaultKnowledgeTopK)
	if err != nil || len(results) == 0 {
		return message
	}
	parts := make([]string, 0, len(results))
	for _, r := range results {
		parts = append(parts, fmt.Sprintf("相关知识: %s", r.Document.Content))
	}
	return fmt.Sprintf("基于以下知识库内容回答用户问题:\n%s\n\n用户问题: %s", strings.Join(parts, "\n"), message)
}