233 lines
6.3 KiB
Go
233 lines
6.3 KiB
Go
package session
|
||
|
||
import (
	"context"
	"fmt"
	"sort"
	"sync"
	"time"

	"github.com/go-kratos/kratos/v2/log"
)
|
||
|
||
// Message is one entry in a session's conversation history.
type Message struct {
	ID        string    `json:"id"`         // message identifier
	SessionID string    `json:"session_id"` // ID of the session the message belongs to
	Role      string    `json:"role"`       // user, assistant, system
	Content   string    `json:"content"`    // message text
	Timestamp time.Time `json:"timestamp"`  // time the message was recorded
}
|
||
|
||
// Session 会话结构
|
||
type Session struct {
|
||
ID string `json:"id"`
|
||
UserID string `json:"user_id"`
|
||
Title string `json:"title"`
|
||
Messages []Message `json:"messages"`
|
||
CreatedAt time.Time `json:"created_at"`
|
||
UpdatedAt time.Time `json:"updated_at"`
|
||
IsActive bool `json:"is_active"`
|
||
}
|
||
|
||
// SessionManager 会话管理器接口
|
||
type SessionManager interface {
|
||
// CreateSession 创建新会话
|
||
CreateSession(ctx context.Context, userID string, title string) (*Session, error)
|
||
// GetSession 获取会话
|
||
GetSession(ctx context.Context, sessionID string) (*Session, error)
|
||
// AddMessage 添加消息到会话
|
||
AddMessage(ctx context.Context, sessionID string, role string, content string) (*Message, error)
|
||
// GetMessages 获取会话消息历史
|
||
GetMessages(ctx context.Context, sessionID string, limit int) ([]Message, error)
|
||
// UpdateSession 更新会话
|
||
UpdateSession(ctx context.Context, session *Session) error
|
||
// ListSessions 获取用户会话列表
|
||
ListSessions(ctx context.Context, userID string) ([]*Session, error)
|
||
// DeleteSession 删除会话
|
||
DeleteSession(ctx context.Context, sessionID string) error
|
||
}
|
||
|
||
// memorySessionManager 内存会话管理器实现
|
||
type memorySessionManager struct {
|
||
sessions map[string]*Session
|
||
messages map[string][]Message
|
||
mutex sync.RWMutex
|
||
log *log.Helper
|
||
}
|
||
|
||
// NewMemorySessionManager 创建内存会话管理器
|
||
func NewMemorySessionManager(logger log.Logger) SessionManager {
|
||
return &memorySessionManager{
|
||
sessions: make(map[string]*Session),
|
||
messages: make(map[string][]Message),
|
||
log: log.NewHelper(logger),
|
||
}
|
||
}
|
||
|
||
// CreateSession 创建新会话
|
||
func (m *memorySessionManager) CreateSession(ctx context.Context, userID string, title string) (*Session, error) {
|
||
m.mutex.Lock()
|
||
defer m.mutex.Unlock()
|
||
|
||
sessionID := fmt.Sprintf("session_%d", time.Now().UnixNano())
|
||
|
||
session := &Session{
|
||
ID: sessionID,
|
||
UserID: userID,
|
||
Title: title,
|
||
Messages: []Message{},
|
||
CreatedAt: time.Now(),
|
||
UpdatedAt: time.Now(),
|
||
IsActive: true,
|
||
}
|
||
|
||
m.sessions[sessionID] = session
|
||
m.messages[sessionID] = []Message{}
|
||
|
||
m.log.WithContext(ctx).Infof("Created session: %s for user: %s", sessionID, userID)
|
||
return session, nil
|
||
}
|
||
|
||
// GetSession 获取会话
|
||
func (m *memorySessionManager) GetSession(ctx context.Context, sessionID string) (*Session, error) {
|
||
m.mutex.RLock()
|
||
defer m.mutex.RUnlock()
|
||
|
||
session, exists := m.sessions[sessionID]
|
||
if !exists {
|
||
return nil, fmt.Errorf("session not found: %s", sessionID)
|
||
}
|
||
|
||
// 获取最新的消息列表
|
||
if messages, ok := m.messages[sessionID]; ok {
|
||
session.Messages = messages
|
||
}
|
||
|
||
return session, nil
|
||
}
|
||
|
||
// AddMessage 添加消息到会话
|
||
func (m *memorySessionManager) AddMessage(ctx context.Context, sessionID string, role string, content string) (*Message, error) {
|
||
m.mutex.Lock()
|
||
defer m.mutex.Unlock()
|
||
|
||
// 检查会话是否存在
|
||
session, exists := m.sessions[sessionID]
|
||
if !exists {
|
||
return nil, fmt.Errorf("session not found: %s", sessionID)
|
||
}
|
||
|
||
messageID := fmt.Sprintf("msg_%d", time.Now().UnixNano())
|
||
message := Message{
|
||
ID: messageID,
|
||
SessionID: sessionID,
|
||
Role: role,
|
||
Content: content,
|
||
Timestamp: time.Now(),
|
||
}
|
||
|
||
// 添加消息到会话
|
||
m.messages[sessionID] = append(m.messages[sessionID], message)
|
||
|
||
// 更新会话时间
|
||
session.UpdatedAt = time.Now()
|
||
|
||
m.log.WithContext(ctx).Infof("Added message to session %s: %s", sessionID, messageID)
|
||
return &message, nil
|
||
}
|
||
|
||
// GetMessages 获取会话消息历史
|
||
func (m *memorySessionManager) GetMessages(ctx context.Context, sessionID string, limit int) ([]Message, error) {
|
||
m.mutex.RLock()
|
||
defer m.mutex.RUnlock()
|
||
|
||
messages, exists := m.messages[sessionID]
|
||
if !exists {
|
||
return []Message{}, nil
|
||
}
|
||
|
||
// 如果limit <= 0,返回所有消息
|
||
if limit <= 0 {
|
||
return messages, nil
|
||
}
|
||
|
||
// 返回最后limit条消息
|
||
start := len(messages) - limit
|
||
if start < 0 {
|
||
start = 0
|
||
}
|
||
|
||
return messages[start:], nil
|
||
}
|
||
|
||
// UpdateSession 更新会话
|
||
func (m *memorySessionManager) UpdateSession(ctx context.Context, session *Session) error {
|
||
m.mutex.Lock()
|
||
defer m.mutex.Unlock()
|
||
|
||
if _, exists := m.sessions[session.ID]; !exists {
|
||
return fmt.Errorf("session not found: %s", session.ID)
|
||
}
|
||
|
||
session.UpdatedAt = time.Now()
|
||
m.sessions[session.ID] = session
|
||
|
||
m.log.WithContext(ctx).Infof("Updated session: %s", session.ID)
|
||
return nil
|
||
}
|
||
|
||
// ListSessions 获取用户会话列表
|
||
func (m *memorySessionManager) ListSessions(ctx context.Context, userID string) ([]*Session, error) {
|
||
m.mutex.RLock()
|
||
defer m.mutex.RUnlock()
|
||
|
||
var sessions []*Session
|
||
for _, session := range m.sessions {
|
||
if session.UserID == userID && session.IsActive {
|
||
// 创建副本避免并发问题
|
||
sessionCopy := *session
|
||
if messages, ok := m.messages[session.ID]; ok {
|
||
sessionCopy.Messages = make([]Message, len(messages))
|
||
copy(sessionCopy.Messages, messages)
|
||
}
|
||
sessions = append(sessions, &sessionCopy)
|
||
}
|
||
}
|
||
|
||
return sessions, nil
|
||
}
|
||
|
||
// DeleteSession 删除会话
|
||
func (m *memorySessionManager) DeleteSession(ctx context.Context, sessionID string) error {
|
||
m.mutex.Lock()
|
||
defer m.mutex.Unlock()
|
||
|
||
if _, exists := m.sessions[sessionID]; !exists {
|
||
return fmt.Errorf("session not found: %s", sessionID)
|
||
}
|
||
|
||
// 标记为非活跃而不是直接删除
|
||
m.sessions[sessionID].IsActive = false
|
||
m.sessions[sessionID].UpdatedAt = time.Now()
|
||
|
||
m.log.WithContext(ctx).Infof("Deleted session: %s", sessionID)
|
||
return nil
|
||
}
|
||
|
||
// GetContextMessages 获取用于AI对话的上下文消息
|
||
func (m *memorySessionManager) GetContextMessages(ctx context.Context, sessionID string, maxMessages int) ([]Message, error) {
|
||
messages, err := m.GetMessages(ctx, sessionID, maxMessages)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
// 过滤掉系统消息,只保留用户和助手的对话
|
||
var contextMessages []Message
|
||
for _, msg := range messages {
|
||
if msg.Role == "user" || msg.Role == "assistant" {
|
||
contextMessages = append(contextMessages, msg)
|
||
}
|
||
}
|
||
|
||
return contextMessages, nil
|
||
} |