109 lines
2.8 KiB
Go
109 lines
2.8 KiB
Go
package workflow
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"strings"
|
|
"time"
|
|
|
|
contextpkg "eino-project/internal/domain/context"
|
|
"eino-project/internal/domain/llm"
|
|
"eino-project/internal/domain/vector"
|
|
|
|
"github.com/cloudwego/eino/schema"
|
|
)
|
|
|
|
// ChatWorkflow answers a user's message in the scope of a session identified
// by sessionID, either as one complete reply or as a stream of text chunks.
type ChatWorkflow interface {
	// Chat returns the complete reply to message.
	Chat(ctx context.Context, message, sessionID string) (string, error)

	// Stream returns a receive-only channel that yields the reply to message
	// piece by piece.
	Stream(ctx context.Context, message, sessionID string) (<-chan string, error)
}
|
|
|
|
type chatWorkflow struct {
|
|
models llm.LLM
|
|
searcher vector.KnowledgeSearcher
|
|
ctxMgr contextpkg.ContextManager
|
|
}
|
|
|
|
func NewChatWorkflow(models llm.LLM, searcher vector.KnowledgeSearcher, ctxMgr contextpkg.ContextManager) ChatWorkflow {
|
|
return &chatWorkflow{
|
|
models: models,
|
|
searcher: searcher,
|
|
ctxMgr: ctxMgr,
|
|
}
|
|
}
|
|
|
|
func (w *chatWorkflow) Chat(ctx context.Context, message string, sessionID string) (string, error) {
|
|
if w.ctxMgr != nil {
|
|
w.ctxMgr.AddMessage(ctx, sessionID, contextpkg.Message{Role: "user", Content: message, Timestamp: time.Now()})
|
|
}
|
|
|
|
var knowledgeContext string
|
|
if w.searcher != nil {
|
|
results, err := w.searcher.SearchKnowledge(ctx, message, 3)
|
|
if err == nil && len(results) > 0 {
|
|
var parts []string
|
|
for _, r := range results {
|
|
parts = append(parts, fmt.Sprintf("相关知识: %s", r.Document.Content))
|
|
}
|
|
knowledgeContext = strings.Join(parts, "\n")
|
|
}
|
|
}
|
|
|
|
enhanced := message
|
|
if knowledgeContext != "" {
|
|
enhanced = fmt.Sprintf("基于以下知识库内容回答用户问题:\n%s\n\n用户问题: %s", knowledgeContext, message)
|
|
}
|
|
|
|
msgs := []*schema.Message{{Role: schema.User, Content: enhanced}}
|
|
chatModel, err := w.models.Chat()
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
resp, err := chatModel.Generate(ctx, msgs)
|
|
if err != nil || resp == nil {
|
|
return "", err
|
|
}
|
|
return resp.Content, nil
|
|
}
|
|
|
|
func (w *chatWorkflow) Stream(ctx context.Context, message string, sessionID string) (<-chan string, error) {
|
|
var knowledgeContext string
|
|
if w.searcher != nil {
|
|
results, err := w.searcher.SearchKnowledge(ctx, message, 3)
|
|
if err == nil && len(results) > 0 {
|
|
var parts []string
|
|
for _, r := range results {
|
|
parts = append(parts, fmt.Sprintf("相关知识: %s", r.Document.Content))
|
|
}
|
|
knowledgeContext = strings.Join(parts, "\n")
|
|
}
|
|
}
|
|
enhanced := message
|
|
if knowledgeContext != "" {
|
|
enhanced = fmt.Sprintf("基于以下知识库内容回答用户问题:\n%s\n\n用户问题: %s", knowledgeContext, message)
|
|
}
|
|
msgs := []*schema.Message{{Role: schema.User, Content: enhanced}}
|
|
chatModel, err := w.models.Chat()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
reader, err := chatModel.Stream(ctx, msgs)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
ch := make(chan string, 8)
|
|
go func() {
|
|
defer close(ch)
|
|
for {
|
|
chunk, err := reader.Recv()
|
|
if err != nil {
|
|
return
|
|
}
|
|
if chunk != nil && chunk.Content != "" {
|
|
ch <- chunk.Content
|
|
}
|
|
}
|
|
}()
|
|
return ch, nil
|
|
}
|