package context import ( "context" "fmt" "sync" "time" "github.com/go-kratos/kratos/v2/log" ) // ConversationContext 对话上下文 type ConversationContext struct { SessionID string `json:"session_id"` UserID string `json:"user_id"` Intent string `json:"intent"` Entities map[string]interface{} `json:"entities"` State string `json:"state"` History []Message `json:"history"` Metadata map[string]interface{} `json:"metadata"` CreatedAt time.Time `json:"created_at"` UpdatedAt time.Time `json:"updated_at"` ExpiresAt time.Time `json:"expires_at"` } // Message 消息结构 type Message struct { Role string `json:"role"` // user, assistant, system Content string `json:"content"` Timestamp time.Time `json:"timestamp"` Metadata map[string]interface{} `json:"metadata,omitempty"` } // ContextManager 上下文管理器接口 type ContextManager interface { // 创建新的对话上下文 CreateContext(ctx context.Context, sessionID, userID string) (*ConversationContext, error) // 获取对话上下文 GetContext(ctx context.Context, sessionID string) (*ConversationContext, error) // 更新对话上下文 UpdateContext(ctx context.Context, convCtx *ConversationContext) error // 添加消息到上下文 AddMessage(ctx context.Context, sessionID string, message Message) error // 获取最近的消息 GetRecentMessages(ctx context.Context, sessionID string, limit int) ([]Message, error) // 设置意图和实体 SetIntent(ctx context.Context, sessionID, intent string, entities map[string]interface{}) error // 设置对话状态 SetState(ctx context.Context, sessionID, state string) error // 获取上下文摘要 GetContextSummary(ctx context.Context, sessionID string) (string, error) // 清理过期上下文 CleanupExpiredContexts(ctx context.Context) error // 删除上下文 DeleteContext(ctx context.Context, sessionID string) error } // contextManager 上下文管理器实现 type contextManager struct { contexts map[string]*ConversationContext mutex sync.RWMutex logger *log.Helper // 配置 maxHistorySize int contextTTL time.Duration } // NewContextManager 创建上下文管理器 func NewContextManager(logger log.Logger) ContextManager { return &contextManager{ contexts: make(map[string]*ConversationContext), mutex: sync.RWMutex{}, logger: log.NewHelper(logger), maxHistorySize: 50, // 最大历史消息数 contextTTL: 24 * time.Hour, // 上下文过期时间 } } // CreateContext 创建新的对话上下文 func (cm *contextManager) CreateContext(ctx context.Context, sessionID, userID string) (*ConversationContext, error) { cm.mutex.Lock() defer cm.mutex.Unlock() now := time.Now() convCtx := &ConversationContext{ SessionID: sessionID, UserID: userID, Intent: "", Entities: make(map[string]interface{}), State: "active", History: make([]Message, 0), Metadata: make(map[string]interface{}), CreatedAt: now, UpdatedAt: now, ExpiresAt: now.Add(cm.contextTTL), } cm.contexts[sessionID] = convCtx cm.logger.Infof("Created new conversation context for session: %s", sessionID) return convCtx, nil } // GetContext 获取对话上下文 func (cm *contextManager) GetContext(ctx context.Context, sessionID string) (*ConversationContext, error) { cm.mutex.RLock() defer cm.mutex.RUnlock() convCtx, exists := cm.contexts[sessionID] if !exists { return nil, fmt.Errorf("context not found for session: %s", sessionID) } // 检查是否过期 if time.Now().After(convCtx.ExpiresAt) { delete(cm.contexts, sessionID) return nil, fmt.Errorf("context expired for session: %s", sessionID) } return convCtx, nil } // UpdateContext 更新对话上下文 func (cm *contextManager) UpdateContext(ctx context.Context, convCtx *ConversationContext) error { cm.mutex.Lock() defer cm.mutex.Unlock() convCtx.UpdatedAt = time.Now() cm.contexts[convCtx.SessionID] = convCtx return nil } // AddMessage 添加消息到上下文 func (cm *contextManager) AddMessage(ctx context.Context, sessionID string, message Message) error { cm.mutex.Lock() defer cm.mutex.Unlock() convCtx, exists := cm.contexts[sessionID] if !exists { return fmt.Errorf("context not found for session: %s", sessionID) } // 添加时间戳 message.Timestamp = time.Now() // 添加消息到历史 convCtx.History = append(convCtx.History, message) // 限制历史消息数量 if len(convCtx.History) > cm.maxHistorySize { convCtx.History = convCtx.History[len(convCtx.History)-cm.maxHistorySize:] } convCtx.UpdatedAt = time.Now() cm.logger.Infof("Added message to context for session: %s, role: %s", sessionID, message.Role) return nil } // GetRecentMessages 获取最近的消息 func (cm *contextManager) GetRecentMessages(ctx context.Context, sessionID string, limit int) ([]Message, error) { cm.mutex.RLock() defer cm.mutex.RUnlock() convCtx, exists := cm.contexts[sessionID] if !exists { return nil, fmt.Errorf("context not found for session: %s", sessionID) } if limit <= 0 { limit = 10 } historyLen := len(convCtx.History) if historyLen == 0 { return []Message{}, nil } start := historyLen - limit if start < 0 { start = 0 } // 返回最近的消息副本 recentMessages := make([]Message, historyLen-start) copy(recentMessages, convCtx.History[start:]) return recentMessages, nil } // SetIntent 设置意图和实体 func (cm *contextManager) SetIntent(ctx context.Context, sessionID, intent string, entities map[string]interface{}) error { cm.mutex.Lock() defer cm.mutex.Unlock() convCtx, exists := cm.contexts[sessionID] if !exists { return fmt.Errorf("context not found for session: %s", sessionID) } convCtx.Intent = intent if entities != nil { for k, v := range entities { convCtx.Entities[k] = v } } convCtx.UpdatedAt = time.Now() cm.logger.Infof("Set intent for session: %s, intent: %s", sessionID, intent) return nil } // SetState 设置对话状态 func (cm *contextManager) SetState(ctx context.Context, sessionID, state string) error { cm.mutex.Lock() defer cm.mutex.Unlock() convCtx, exists := cm.contexts[sessionID] if !exists { return fmt.Errorf("context not found for session: %s", sessionID) } convCtx.State = state convCtx.UpdatedAt = time.Now() cm.logger.Infof("Set state for session: %s, state: %s", sessionID, state) return nil } // GetContextSummary 获取上下文摘要 func (cm *contextManager) GetContextSummary(ctx context.Context, sessionID string) (string, error) { cm.mutex.RLock() defer cm.mutex.RUnlock() convCtx, exists := cm.contexts[sessionID] if !exists { return "", fmt.Errorf("context not found for session: %s", sessionID) } summary := fmt.Sprintf("会话ID: %s, 用户ID: %s, 当前意图: %s, 状态: %s, 消息数: %d", convCtx.SessionID, convCtx.UserID, convCtx.Intent, convCtx.State, len(convCtx.History)) // 添加最近的几条消息 recentCount := 3 if len(convCtx.History) > 0 { summary += "\n最近消息:" start := len(convCtx.History) - recentCount if start < 0 { start = 0 } for i := start; i < len(convCtx.History); i++ { msg := convCtx.History[i] summary += fmt.Sprintf("\n- %s: %s", msg.Role, msg.Content) } } return summary, nil } // CleanupExpiredContexts 清理过期上下文 func (cm *contextManager) CleanupExpiredContexts(ctx context.Context) error { cm.mutex.Lock() defer cm.mutex.Unlock() now := time.Now() expiredSessions := make([]string, 0) for sessionID, convCtx := range cm.contexts { if now.After(convCtx.ExpiresAt) { expiredSessions = append(expiredSessions, sessionID) } } for _, sessionID := range expiredSessions { delete(cm.contexts, sessionID) cm.logger.Infof("Cleaned up expired context for session: %s", sessionID) } if len(expiredSessions) > 0 { cm.logger.Infof("Cleaned up %d expired contexts", len(expiredSessions)) } return nil } // DeleteContext 删除上下文 func (cm *contextManager) DeleteContext(ctx context.Context, sessionID string) error { cm.mutex.Lock() defer cm.mutex.Unlock() delete(cm.contexts, sessionID) cm.logger.Infof("Deleted context for session: %s", sessionID) return nil }