310 lines
8.3 KiB
Go
310 lines
8.3 KiB
Go
package context
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/go-kratos/kratos/v2/log"
|
|
)
|
|
|
|
// ConversationContext 对话上下文
|
|
type ConversationContext struct {
|
|
SessionID string `json:"session_id"`
|
|
UserID string `json:"user_id"`
|
|
Intent string `json:"intent"`
|
|
Entities map[string]interface{} `json:"entities"`
|
|
State string `json:"state"`
|
|
History []Message `json:"history"`
|
|
Metadata map[string]interface{} `json:"metadata"`
|
|
CreatedAt time.Time `json:"created_at"`
|
|
UpdatedAt time.Time `json:"updated_at"`
|
|
ExpiresAt time.Time `json:"expires_at"`
|
|
}
|
|
|
|
// Message 消息结构
|
|
type Message struct {
|
|
Role string `json:"role"` // user, assistant, system
|
|
Content string `json:"content"`
|
|
Timestamp time.Time `json:"timestamp"`
|
|
Metadata map[string]interface{} `json:"metadata,omitempty"`
|
|
}
|
|
|
|
// ContextManager 上下文管理器接口
|
|
type ContextManager interface {
|
|
// 创建新的对话上下文
|
|
CreateContext(ctx context.Context, sessionID, userID string) (*ConversationContext, error)
|
|
|
|
// 获取对话上下文
|
|
GetContext(ctx context.Context, sessionID string) (*ConversationContext, error)
|
|
|
|
// 更新对话上下文
|
|
UpdateContext(ctx context.Context, convCtx *ConversationContext) error
|
|
|
|
// 添加消息到上下文
|
|
AddMessage(ctx context.Context, sessionID string, message Message) error
|
|
|
|
// 获取最近的消息
|
|
GetRecentMessages(ctx context.Context, sessionID string, limit int) ([]Message, error)
|
|
|
|
// 设置意图和实体
|
|
SetIntent(ctx context.Context, sessionID, intent string, entities map[string]interface{}) error
|
|
|
|
// 设置对话状态
|
|
SetState(ctx context.Context, sessionID, state string) error
|
|
|
|
// 获取上下文摘要
|
|
GetContextSummary(ctx context.Context, sessionID string) (string, error)
|
|
|
|
// 清理过期上下文
|
|
CleanupExpiredContexts(ctx context.Context) error
|
|
|
|
// 删除上下文
|
|
DeleteContext(ctx context.Context, sessionID string) error
|
|
}
|
|
|
|
// contextManager 上下文管理器实现
|
|
type contextManager struct {
|
|
contexts map[string]*ConversationContext
|
|
mutex sync.RWMutex
|
|
logger *log.Helper
|
|
|
|
// 配置
|
|
maxHistorySize int
|
|
contextTTL time.Duration
|
|
}
|
|
|
|
// NewContextManager 创建上下文管理器
|
|
func NewContextManager(logger log.Logger) ContextManager {
|
|
return &contextManager{
|
|
contexts: make(map[string]*ConversationContext),
|
|
mutex: sync.RWMutex{},
|
|
logger: log.NewHelper(logger),
|
|
maxHistorySize: 50, // 最大历史消息数
|
|
contextTTL: 24 * time.Hour, // 上下文过期时间
|
|
}
|
|
}
|
|
|
|
// CreateContext builds a fresh, "active" context for sessionID owned by userID
// and stores it, replacing (not merging with) any context already stored for
// that session. The new context expires contextTTL after its creation time.
func (cm *contextManager) CreateContext(ctx context.Context, sessionID, userID string) (*ConversationContext, error) {
	cm.mutex.Lock()
	defer cm.mutex.Unlock()

	created := time.Now()
	fresh := &ConversationContext{
		SessionID: sessionID,
		UserID:    userID,
		State:     "active",
		// Non-nil empty collections so callers can write into them directly.
		Entities:  map[string]interface{}{},
		History:   []Message{},
		Metadata:  map[string]interface{}{},
		CreatedAt: created,
		UpdatedAt: created,
		ExpiresAt: created.Add(cm.contextTTL),
	}
	cm.contexts[sessionID] = fresh

	cm.logger.Infof("Created new conversation context for session: %s", sessionID)
	return fresh, nil
}
|
|
|
|
// GetContext returns the conversation context stored for sessionID.
//
// A context whose ExpiresAt has passed is treated as absent: it is evicted
// from the store and an "expired" error is returned. A missing session yields
// a "not found" error. The returned pointer is the stored instance itself, not
// a copy, so changes made through it later are not protected by the
// manager's lock.
func (cm *contextManager) GetContext(ctx context.Context, sessionID string) (*ConversationContext, error) {
	// Fast path: looking up a live context only needs the read lock.
	cm.mutex.RLock()
	convCtx, exists := cm.contexts[sessionID]
	expired := exists && time.Now().After(convCtx.ExpiresAt)
	cm.mutex.RUnlock()

	if !exists {
		return nil, fmt.Errorf("context not found for session: %s", sessionID)
	}
	if !expired {
		return convCtx, nil
	}

	// Evicting mutates the map and therefore needs the write lock. Deleting
	// while holding only RLock (as this method used to) races with concurrent
	// readers/writers and can abort the process with a fatal map-access error.
	cm.mutex.Lock()
	defer cm.mutex.Unlock()

	// Re-check after re-locking: between RUnlock and Lock another goroutine may
	// have deleted the entry or replaced it with a fresh one (e.g. CreateContext).
	if current, ok := cm.contexts[sessionID]; ok {
		if !time.Now().After(current.ExpiresAt) {
			return current, nil
		}
		delete(cm.contexts, sessionID)
	}
	return nil, fmt.Errorf("context expired for session: %s", sessionID)
}
|
|
|
|
// UpdateContext stores convCtx under convCtx.SessionID, replacing any context
// already stored for that session, and sets its UpdatedAt to the current time.
//
// If convCtx.ExpiresAt is zero (e.g. the context was assembled by the caller
// rather than obtained from CreateContext), it is set to now plus the
// manager's TTL. Otherwise GetContext would immediately treat it as expired and
// evict it. Returns an error instead of panicking when convCtx is nil.
func (cm *contextManager) UpdateContext(ctx context.Context, convCtx *ConversationContext) error {
	if convCtx == nil {
		return fmt.Errorf("cannot update nil conversation context")
	}

	cm.mutex.Lock()
	defer cm.mutex.Unlock()

	now := time.Now()
	if convCtx.ExpiresAt.IsZero() {
		convCtx.ExpiresAt = now.Add(cm.contextTTL)
	}
	convCtx.UpdatedAt = now
	cm.contexts[convCtx.SessionID] = convCtx

	return nil
}
|
|
|
|
// AddMessage appends message to the history of the context stored for
// sessionID and refreshes the context's UpdatedAt.
//
// The message's Timestamp is always replaced with the insertion time, and once
// the history grows past maxHistorySize only the newest entries are kept.
// ExpiresAt is not consulted here. Returns an error if no context is stored
// for sessionID.
func (cm *contextManager) AddMessage(ctx context.Context, sessionID string, message Message) error {
	cm.mutex.Lock()
	defer cm.mutex.Unlock()

	target, found := cm.contexts[sessionID]
	if !found {
		return fmt.Errorf("context not found for session: %s", sessionID)
	}

	message.Timestamp = time.Now()

	history := append(target.History, message)
	// Drop the oldest entries that exceed the configured cap.
	if overflow := len(history) - cm.maxHistorySize; overflow > 0 {
		history = history[overflow:]
	}
	target.History = history
	target.UpdatedAt = time.Now()

	cm.logger.Infof("Added message to context for session: %s, role: %s", sessionID, message.Role)
	return nil
}
|
|
|
|
// GetRecentMessages returns a copy of the newest messages in sessionID's
// history, oldest first. A non-positive limit falls back to 10. The result is
// a non-nil slice, empty when the history is empty. Returns an error if no
// context is stored for sessionID.
func (cm *contextManager) GetRecentMessages(ctx context.Context, sessionID string, limit int) ([]Message, error) {
	cm.mutex.RLock()
	defer cm.mutex.RUnlock()

	convCtx, ok := cm.contexts[sessionID]
	if !ok {
		return nil, fmt.Errorf("context not found for session: %s", sessionID)
	}

	window := limit
	if window <= 0 {
		window = 10
	}

	tail := convCtx.History
	if len(tail) > window {
		tail = tail[len(tail)-window:]
	}

	// Copy so callers cannot alias the stored history's backing array.
	out := make([]Message, len(tail))
	copy(out, tail)
	return out, nil
}
|
|
|
|
// SetIntent overwrites the intent of sessionID's context and merges entities
// into its Entities map (existing keys are overwritten, other keys are kept).
// A nil or empty entities map leaves Entities untouched. Returns an error if no
// context is stored for sessionID.
func (cm *contextManager) SetIntent(ctx context.Context, sessionID, intent string, entities map[string]interface{}) error {
	cm.mutex.Lock()
	defer cm.mutex.Unlock()

	convCtx, exists := cm.contexts[sessionID]
	if !exists {
		return fmt.Errorf("context not found for session: %s", sessionID)
	}

	convCtx.Intent = intent
	if len(entities) > 0 {
		// Contexts stored via UpdateContext (or JSON-decoded with
		// "entities": null) can carry a nil map; assigning into it would panic.
		if convCtx.Entities == nil {
			convCtx.Entities = make(map[string]interface{}, len(entities))
		}
		for k, v := range entities {
			convCtx.Entities[k] = v
		}
	}
	convCtx.UpdatedAt = time.Now()

	cm.logger.Infof("Set intent for session: %s, intent: %s", sessionID, intent)

	return nil
}
|
|
|
|
// SetState replaces the dialogue state of sessionID's context and refreshes
// its UpdatedAt. Returns an error if no context is stored for sessionID.
func (cm *contextManager) SetState(ctx context.Context, sessionID, state string) error {
	cm.mutex.Lock()
	defer cm.mutex.Unlock()

	target, ok := cm.contexts[sessionID]
	if !ok {
		return fmt.Errorf("context not found for session: %s", sessionID)
	}

	target.State, target.UpdatedAt = state, time.Now()

	cm.logger.Infof("Set state for session: %s, state: %s", sessionID, state)
	return nil
}
|
|
|
|
// GetContextSummary 获取上下文摘要
|
|
func (cm *contextManager) GetContextSummary(ctx context.Context, sessionID string) (string, error) {
|
|
cm.mutex.RLock()
|
|
defer cm.mutex.RUnlock()
|
|
|
|
convCtx, exists := cm.contexts[sessionID]
|
|
if !exists {
|
|
return "", fmt.Errorf("context not found for session: %s", sessionID)
|
|
}
|
|
|
|
summary := fmt.Sprintf("会话ID: %s, 用户ID: %s, 当前意图: %s, 状态: %s, 消息数: %d",
|
|
convCtx.SessionID, convCtx.UserID, convCtx.Intent, convCtx.State, len(convCtx.History))
|
|
|
|
// 添加最近的几条消息
|
|
recentCount := 3
|
|
if len(convCtx.History) > 0 {
|
|
summary += "\n最近消息:"
|
|
start := len(convCtx.History) - recentCount
|
|
if start < 0 {
|
|
start = 0
|
|
}
|
|
|
|
for i := start; i < len(convCtx.History); i++ {
|
|
msg := convCtx.History[i]
|
|
summary += fmt.Sprintf("\n- %s: %s", msg.Role, msg.Content)
|
|
}
|
|
}
|
|
|
|
return summary, nil
|
|
}
|
|
|
|
// CleanupExpiredContexts evicts every stored context whose ExpiresAt is
// before the current time. It logs each evicted session, then a total if
// anything was removed. It always returns nil.
func (cm *contextManager) CleanupExpiredContexts(ctx context.Context) error {
	cm.mutex.Lock()
	defer cm.mutex.Unlock()

	now := time.Now()
	removed := 0
	for id, c := range cm.contexts {
		if !now.After(c.ExpiresAt) {
			continue
		}
		// Deleting the entry currently being visited is allowed during range.
		delete(cm.contexts, id)
		removed++
		cm.logger.Infof("Cleaned up expired context for session: %s", id)
	}

	if removed > 0 {
		cm.logger.Infof("Cleaned up %d expired contexts", removed)
	}
	return nil
}
|
|
|
|
// DeleteContext removes the context stored for sessionID, if any, and logs the
// deletion. Deleting a session that does not exist is not an error; it always
// returns nil.
func (cm *contextManager) DeleteContext(ctx context.Context, sessionID string) error {
	cm.mutex.Lock()
	delete(cm.contexts, sessionID)
	cm.mutex.Unlock()

	cm.logger.Infof("Deleted context for session: %s", sessionID)
	return nil
}