l_ai_knowledge/internal/application/service/chat_pipline/rewrite.go

175 lines
5.9 KiB
Go

// Package chatpipline provides chat pipeline processing capabilities,
// including query rewriting, history processing, model invocation, and other features.
package chatpipline
import (
"bytes"
"context"
"html/template"
"regexp"
"slices"
"sort"
"time"
"knowlege-lsxd/internal/config"
"knowlege-lsxd/internal/logger"
"knowlege-lsxd/internal/models/chat"
"knowlege-lsxd/internal/types"
"knowlege-lsxd/internal/types/interfaces"
)
// PluginRewrite is a chat-pipeline plugin that rewrites the user's query.
//
// On a REWRITE_QUERY event it combines the session's recent conversation
// history (from messageService) with a chat model (resolved through
// modelService) to produce chatManage.RewriteQuery. Prompt templates and the
// history limit come from the Conversation section of config.
type PluginRewrite struct {
	modelService   interfaces.ModelService   // resolves the chat model used to rewrite the query
	messageService interfaces.MessageService // loads recent session messages used as conversation history
	config         *config.Config            // supplies Conversation.MaxRounds and the rewrite prompt templates
}
// reg matches <think>...</think> sections so they can be removed from model
// output. The (?s) flag lets "." cross newlines, so multi-line sections match.
// The lazy ".*?" stops at the first closing tag, so several sections in one
// string are matched separately and the text between them is kept.
var (
	reg = regexp.MustCompile(`(?s)<think>.*?</think>`)
)
// NewPluginRewrite builds a PluginRewrite from the given model service,
// message service and configuration, registers it with eventManager, and
// returns the registered instance.
func NewPluginRewrite(eventManager *EventManager,
	modelService interfaces.ModelService, messageService interfaces.MessageService,
	config *config.Config,
) *PluginRewrite {
	plugin := new(PluginRewrite)
	plugin.modelService = modelService
	plugin.messageService = messageService
	plugin.config = config

	eventManager.Register(plugin)
	return plugin
}
// ActivationEvents reports which event types trigger this plugin.
// The rewrite plugin reacts to exactly one: types.REWRITE_QUERY.
func (p *PluginRewrite) ActivationEvents() []types.EventType {
	events := []types.EventType{types.REWRITE_QUERY}
	return events
}
// OnEvent handles a REWRITE_QUERY event by rewriting the user's query with
// the help of recent conversation history and a chat model.
//
// Steps:
//  1. chatManage.RewriteQuery is initialized to chatManage.Query, so the
//     original query is kept whenever a later step fails.
//  2. The session's recent messages (limit 20) are grouped by RequestID into
//     question/answer rounds. Only rounds with both a query and an answer
//     are kept.
//  3. The newest Conversation.MaxRounds rounds are stored in
//     chatManage.History in chronological order. A negative MaxRounds is
//     treated as 0.
//  4. The system and user prompt templates are rendered with Query,
//     CurrentTime, Yesterday and Conversation.
//  5. The model selected by chatManage.ChatModelID is called.
//  6. Its reply, with any <think> blocks and surrounding whitespace removed,
//     becomes the new RewriteQuery. An empty reply keeps the original query.
//
// Every failure is logged and treated as non-fatal. next() is always called
// so the pipeline continues.
func (p *PluginRewrite) OnEvent(ctx context.Context,
	eventType types.EventType, chatManage *types.ChatManage, next func() *PluginError,
) *PluginError {
	// Default to the original query; only a successful rewrite replaces it.
	chatManage.RewriteQuery = chatManage.Query

	messages, err := p.messageService.GetRecentMessagesBySession(ctx, chatManage.SessionID, 20)
	if err != nil {
		// Best-effort: continue without history rather than aborting the rewrite.
		logger.Errorf(ctx, "Failed to get conversation history, session_id: %s, error: %v", chatManage.SessionID, err)
	}

	// Pair messages into rounds keyed by the request that produced them.
	rounds := make(map[string]*types.History, len(messages))
	for _, message := range messages {
		round, ok := rounds[message.RequestID]
		if !ok {
			round = &types.History{}
			rounds[message.RequestID] = round
		}
		if message.Role == "user" {
			round.Query = message.Content
			round.CreateAt = message.CreatedAt
		} else {
			// Any non-user message is taken as the answer; drop <think> reasoning.
			round.Answer = reg.ReplaceAllString(message.Content, "")
			round.KnowledgeReferences = message.KnowledgeReferences
		}
	}

	// Keep only complete rounds.
	historyList := make([]*types.History, 0, len(rounds))
	for _, round := range rounds {
		if round.Query != "" && round.Answer != "" {
			historyList = append(historyList, round)
		}
	}

	// Newest first, so truncation keeps the most recent rounds.
	sort.Slice(historyList, func(i, j int) bool {
		return historyList[i].CreateAt.After(historyList[j].CreateAt)
	})
	// Clamp so a misconfigured negative limit cannot panic the slice expression.
	maxRounds := max(p.config.Conversation.MaxRounds, 0)
	if len(historyList) > maxRounds {
		historyList = historyList[:maxRounds]
	}
	// Present history to the prompt (and later stages) oldest first.
	slices.Reverse(historyList)
	chatManage.History = historyList

	// Read the clock once so CurrentTime and Yesterday are always consistent.
	now := time.Now()
	promptData := map[string]any{
		"Query":        chatManage.Query,
		"CurrentTime":  now.Format("2006-01-02 15:04:05"),
		"Yesterday":    now.AddDate(0, 0, -1).Format("2006-01-02"),
		"Conversation": historyList,
	}

	userPrompt, err := renderRewritePrompt(p.config.Conversation.RewritePromptUser, promptData)
	if err != nil {
		logger.Errorf(ctx, "Failed to render rewrite user prompt template, session_id: %s, error: %v", chatManage.SessionID, err)
		return next()
	}
	systemPrompt, err := renderRewritePrompt(p.config.Conversation.RewritePromptSystem, promptData)
	if err != nil {
		logger.Errorf(ctx, "Failed to render rewrite system prompt template, session_id: %s, error: %v", chatManage.SessionID, err)
		return next()
	}

	rewriteModel, err := p.modelService.GetChatModel(ctx, chatManage.ChatModelID)
	if err != nil {
		logger.Errorf(ctx, "Failed to get model, session_id: %s, error: %v", chatManage.SessionID, err)
		return next()
	}

	// Low temperature and a small token budget: the rewrite should be a short,
	// deterministic reformulation, not a free-form answer.
	thinking := false
	response, err := rewriteModel.Chat(ctx, []chat.Message{
		{Role: "system", Content: systemPrompt},
		{Role: "user", Content: userPrompt},
	}, &chat.ChatOptions{
		Temperature:         0.3,
		MaxCompletionTokens: 50,
		Thinking:            &thinking,
	})
	if err != nil {
		logger.Errorf(ctx, "Failed to execute model, session_id: %s, error: %v", chatManage.SessionID, err)
		return next()
	}

	// Some models still emit a (possibly empty) <think> block with thinking
	// disabled; strip it and surrounding whitespace. An empty result would
	// wipe out the query, so keep the original in that case.
	rewritten := bytes.TrimSpace([]byte(reg.ReplaceAllString(response.Content, "")))
	if len(rewritten) == 0 {
		logger.GetLogger(ctx).Infof("Rewrite returned empty content, keeping original query, session_id: %s",
			chatManage.SessionID)
		return next()
	}

	chatManage.RewriteQuery = string(rewritten)
	logger.GetLogger(ctx).Infof("Rewritten query, session_id: %s, rewrite_query: %s",
		chatManage.SessionID, chatManage.RewriteQuery)
	return next()
}

// renderRewritePrompt parses tmplText as a template named "rewriteContent",
// executes it against data, and returns the rendered text.
//
// NOTE(review): this file imports html/template. That package HTML-escapes every
// interpolated value ("&" -> "&amp;", "'" -> "&#39;", "<" -> "&lt;", ...),
// which corrupts user queries and history inside LLM prompts. Switching the
// file's import to text/template (same New/Parse/Execute API) would send them
// verbatim — TODO confirm no prompt template relies on the escaping, then
// swap the import.
func renderRewritePrompt(tmplText string, data map[string]any) (string, error) {
	tmpl, err := template.New("rewriteContent").Parse(tmplText)
	if err != nil {
		return "", err
	}
	var buf bytes.Buffer
	if err := tmpl.Execute(&buf, data); err != nil {
		return "", err
	}
	return buf.String(), nil
}