package session import ( "context" "fmt" "sync" "time" "github.com/go-kratos/kratos/v2/log" ) // Message 消息结构 type Message struct { ID string `json:"id"` SessionID string `json:"session_id"` Role string `json:"role"` // user, assistant, system Content string `json:"content"` Timestamp time.Time `json:"timestamp"` } // Session 会话结构 type Session struct { ID string `json:"id"` UserID string `json:"user_id"` Title string `json:"title"` Messages []Message `json:"messages"` CreatedAt time.Time `json:"created_at"` UpdatedAt time.Time `json:"updated_at"` IsActive bool `json:"is_active"` } // SessionManager 会话管理器接口 type SessionManager interface { // CreateSession 创建新会话 CreateSession(ctx context.Context, userID string, title string) (*Session, error) // GetSession 获取会话 GetSession(ctx context.Context, sessionID string) (*Session, error) // AddMessage 添加消息到会话 AddMessage(ctx context.Context, sessionID string, role string, content string) (*Message, error) // GetMessages 获取会话消息历史 GetMessages(ctx context.Context, sessionID string, limit int) ([]Message, error) // UpdateSession 更新会话 UpdateSession(ctx context.Context, session *Session) error // ListSessions 获取用户会话列表 ListSessions(ctx context.Context, userID string) ([]*Session, error) // DeleteSession 删除会话 DeleteSession(ctx context.Context, sessionID string) error } // memorySessionManager 内存会话管理器实现 type memorySessionManager struct { sessions map[string]*Session messages map[string][]Message mutex sync.RWMutex log *log.Helper } // NewMemorySessionManager 创建内存会话管理器 func NewMemorySessionManager(logger log.Logger) SessionManager { return &memorySessionManager{ sessions: make(map[string]*Session), messages: make(map[string][]Message), log: log.NewHelper(logger), } } // CreateSession 创建新会话 func (m *memorySessionManager) CreateSession(ctx context.Context, userID string, title string) (*Session, error) { m.mutex.Lock() defer m.mutex.Unlock() sessionID := fmt.Sprintf("session_%d", time.Now().UnixNano()) session := &Session{ ID: sessionID, UserID: userID, Title: title, Messages: []Message{}, CreatedAt: time.Now(), UpdatedAt: time.Now(), IsActive: true, } m.sessions[sessionID] = session m.messages[sessionID] = []Message{} m.log.WithContext(ctx).Infof("Created session: %s for user: %s", sessionID, userID) return session, nil } // GetSession 获取会话 func (m *memorySessionManager) GetSession(ctx context.Context, sessionID string) (*Session, error) { m.mutex.RLock() defer m.mutex.RUnlock() session, exists := m.sessions[sessionID] if !exists { return nil, fmt.Errorf("session not found: %s", sessionID) } // 获取最新的消息列表 if messages, ok := m.messages[sessionID]; ok { session.Messages = messages } return session, nil } // AddMessage 添加消息到会话 func (m *memorySessionManager) AddMessage(ctx context.Context, sessionID string, role string, content string) (*Message, error) { m.mutex.Lock() defer m.mutex.Unlock() // 检查会话是否存在 session, exists := m.sessions[sessionID] if !exists { return nil, fmt.Errorf("session not found: %s", sessionID) } messageID := fmt.Sprintf("msg_%d", time.Now().UnixNano()) message := Message{ ID: messageID, SessionID: sessionID, Role: role, Content: content, Timestamp: time.Now(), } // 添加消息到会话 m.messages[sessionID] = append(m.messages[sessionID], message) // 更新会话时间 session.UpdatedAt = time.Now() m.log.WithContext(ctx).Infof("Added message to session %s: %s", sessionID, messageID) return &message, nil } // GetMessages 获取会话消息历史 func (m *memorySessionManager) GetMessages(ctx context.Context, sessionID string, limit int) ([]Message, error) { m.mutex.RLock() defer m.mutex.RUnlock() messages, exists := m.messages[sessionID] if !exists { return []Message{}, nil } // 如果limit <= 0,返回所有消息 if limit <= 0 { return messages, nil } // 返回最后limit条消息 start := len(messages) - limit if start < 0 { start = 0 } return messages[start:], nil } // UpdateSession 更新会话 func (m *memorySessionManager) UpdateSession(ctx context.Context, session *Session) error { m.mutex.Lock() defer m.mutex.Unlock() if _, exists := m.sessions[session.ID]; !exists { return fmt.Errorf("session not found: %s", session.ID) } session.UpdatedAt = time.Now() m.sessions[session.ID] = session m.log.WithContext(ctx).Infof("Updated session: %s", session.ID) return nil } // ListSessions 获取用户会话列表 func (m *memorySessionManager) ListSessions(ctx context.Context, userID string) ([]*Session, error) { m.mutex.RLock() defer m.mutex.RUnlock() var sessions []*Session for _, session := range m.sessions { if session.UserID == userID && session.IsActive { // 创建副本避免并发问题 sessionCopy := *session if messages, ok := m.messages[session.ID]; ok { sessionCopy.Messages = make([]Message, len(messages)) copy(sessionCopy.Messages, messages) } sessions = append(sessions, &sessionCopy) } } return sessions, nil } // DeleteSession 删除会话 func (m *memorySessionManager) DeleteSession(ctx context.Context, sessionID string) error { m.mutex.Lock() defer m.mutex.Unlock() if _, exists := m.sessions[sessionID]; !exists { return fmt.Errorf("session not found: %s", sessionID) } // 标记为非活跃而不是直接删除 m.sessions[sessionID].IsActive = false m.sessions[sessionID].UpdatedAt = time.Now() m.log.WithContext(ctx).Infof("Deleted session: %s", sessionID) return nil } // GetContextMessages 获取用于AI对话的上下文消息 func (m *memorySessionManager) GetContextMessages(ctx context.Context, sessionID string, maxMessages int) ([]Message, error) { messages, err := m.GetMessages(ctx, sessionID, maxMessages) if err != nil { return nil, err } // 过滤掉系统消息,只保留用户和助手的对话 var contextMessages []Message for _, msg := range messages { if msg.Role == "user" || msg.Role == "assistant" { contextMessages = append(contextMessages, msg) } } return contextMessages, nil }