ai-courseware/eino-project/internal/domain/session/session.go

233 lines
6.3 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package session
import (
"context"
"fmt"
"sync"
"time"
"github.com/go-kratos/kratos/v2/log"
)
// Message is one entry in a session's conversation history.
type Message struct {
	// ID identifies the message.
	ID string `json:"id"`
	// SessionID is the ID of the session this message belongs to.
	SessionID string `json:"session_id"`
	// Role is the speaker: "user", "assistant" or "system".
	Role string `json:"role"`
	// Content is the message text.
	Content string `json:"content"`
	// Timestamp records when the message was created.
	Timestamp time.Time `json:"timestamp"`
}
// Session is a conversation owned by a single user.
type Session struct {
	// ID identifies the session.
	ID string `json:"id"`
	// UserID is the ID of the user who owns the session.
	UserID string `json:"user_id"`
	// Title is a human-readable label for the session.
	Title string `json:"title"`
	// Messages holds the session's conversation history.
	Messages []Message `json:"messages"`
	// CreatedAt is when the session was created.
	CreatedAt time.Time `json:"created_at"`
	// UpdatedAt is when the session was last modified.
	UpdatedAt time.Time `json:"updated_at"`
	// IsActive becomes false once the session has been (soft-)deleted.
	IsActive bool `json:"is_active"`
}
// SessionManager stores chat sessions and their message histories.
type SessionManager interface {
	// CreateSession creates a new session for userID with the given title.
	CreateSession(ctx context.Context, userID string, title string) (*Session, error)

	// GetSession returns the session identified by sessionID.
	GetSession(ctx context.Context, sessionID string) (*Session, error)

	// AddMessage appends a message with the given role and content to a session.
	AddMessage(ctx context.Context, sessionID string, role string, content string) (*Message, error)

	// GetMessages returns a session's message history; limit caps how many
	// messages are returned.
	GetMessages(ctx context.Context, sessionID string, limit int) ([]Message, error)

	// UpdateSession persists changes made to an existing session.
	UpdateSession(ctx context.Context, session *Session) error

	// ListSessions returns the sessions that belong to userID.
	ListSessions(ctx context.Context, userID string) ([]*Session, error)

	// DeleteSession removes a session.
	DeleteSession(ctx context.Context, sessionID string) error
}
// memorySessionManager is a SessionManager that keeps all state in in-process
// maps; nothing is persisted, so data does not survive a restart.
type memorySessionManager struct {
	// sessions maps a session ID to its session record.
	sessions map[string]*Session
	// messages maps a session ID to that session's message history, in
	// insertion order.
	messages map[string][]Message
	// mutex guards both maps and the records they point to.
	mutex sync.RWMutex
	// log is the helper used for log output.
	log *log.Helper
}
// NewMemorySessionManager returns an empty, ready-to-use in-memory
// SessionManager that logs through logger.
func NewMemorySessionManager(logger log.Logger) SessionManager {
	mgr := &memorySessionManager{log: log.NewHelper(logger)}
	mgr.sessions = make(map[string]*Session)
	mgr.messages = make(map[string][]Message)
	return mgr
}
// CreateSession creates an active session for userID with the given title and
// an empty message history, and returns it.
//
// The returned *Session is a snapshot. The manager keeps its own record, so
// callers can read the returned value without racing against AddMessage or
// DeleteSession, which update the stored record under the lock.
func (m *memorySessionManager) CreateSession(ctx context.Context, userID string, title string) (*Session, error) {
	m.mutex.Lock()
	defer m.mutex.Unlock()

	// Read the clock once so CreatedAt and UpdatedAt start out equal.
	now := time.Now()

	// Nanosecond timestamps are not guaranteed to be unique (coarse system
	// clocks, back-to-back calls). A duplicate ID would overwrite an existing
	// session and reset its message history, so append a numeric suffix until
	// the ID is free. The map lookup is safe because m.mutex is held.
	baseID := fmt.Sprintf("session_%d", now.UnixNano())
	sessionID := baseID
	for suffix := 1; ; suffix++ {
		if _, taken := m.sessions[sessionID]; !taken {
			break
		}
		sessionID = fmt.Sprintf("%s_%d", baseID, suffix)
	}

	stored := &Session{
		ID:        sessionID,
		UserID:    userID,
		Title:     title,
		Messages:  []Message{},
		CreatedAt: now,
		UpdatedAt: now,
		IsActive:  true,
	}
	m.sessions[sessionID] = stored
	m.messages[sessionID] = []Message{}

	m.log.WithContext(ctx).Infof("Created session: %s for user: %s", sessionID, userID)

	snapshot := *stored
	return &snapshot, nil
}
// GetSession returns a snapshot of the session identified by sessionID, with
// Messages set to a copy of its current message history. Soft-deleted
// (inactive) sessions are returned as well.
//
// It returns an error if no session with that ID exists.
func (m *memorySessionManager) GetSession(ctx context.Context, sessionID string) (*Session, error) {
	m.mutex.RLock()
	defer m.mutex.RUnlock()

	stored, exists := m.sessions[sessionID]
	if !exists {
		return nil, fmt.Errorf("session not found: %s", sessionID)
	}

	// Build a copy instead of writing into the stored record. Only the read
	// lock is held here, so assigning stored.Messages would race with other
	// concurrent GetSession calls. Handing out the stored pointer would also
	// let callers read fields that AddMessage/DeleteSession write under the
	// write lock.
	snapshot := *stored
	history := stored.Messages
	if messages, ok := m.messages[sessionID]; ok {
		history = messages
	}
	snapshot.Messages = make([]Message, len(history))
	copy(snapshot.Messages, history)

	return &snapshot, nil
}
// AddMessage appends a message with the given role and content to the history
// of the session identified by sessionID, and sets the session's UpdatedAt to
// the message's timestamp.
//
// It returns an error if the session does not exist. Soft-deleted (inactive)
// sessions still accept messages. The returned *Message points at a copy;
// modifying it does not change the stored history.
func (m *memorySessionManager) AddMessage(ctx context.Context, sessionID string, role string, content string) (*Message, error) {
	m.mutex.Lock()
	defer m.mutex.Unlock()

	session, exists := m.sessions[sessionID]
	if !exists {
		return nil, fmt.Errorf("session not found: %s", sessionID)
	}

	// Read the clock once so the message ID, its Timestamp and the session's
	// UpdatedAt all describe the same instant.
	now := time.Now()

	// NOTE(review): nanosecond-based IDs can repeat if two messages are created
	// within the clock's resolution; consider a counter if IDs must be unique.
	message := Message{
		ID:        fmt.Sprintf("msg_%d", now.UnixNano()),
		SessionID: sessionID,
		Role:      role,
		Content:   content,
		Timestamp: now,
	}

	m.messages[sessionID] = append(m.messages[sessionID], message)
	session.UpdatedAt = now

	m.log.WithContext(ctx).Infof("Added message to session %s: %s", sessionID, message.ID)
	return &message, nil
}
// GetMessages returns the message history of the session identified by
// sessionID, oldest first.
//
// If limit > 0, only the last limit messages are returned. If limit <= 0, the
// whole history is returned. An unknown sessionID yields an empty, non-nil
// slice and no error.
//
// The result is a fresh copy, so callers may modify or append to it freely.
func (m *memorySessionManager) GetMessages(ctx context.Context, sessionID string, limit int) ([]Message, error) {
	m.mutex.RLock()
	defer m.mutex.RUnlock()

	messages, exists := m.messages[sessionID]
	if !exists {
		return []Message{}, nil
	}

	start := 0
	if limit > 0 && limit < len(messages) {
		start = len(messages) - limit
	}

	// Copy rather than returning messages[start:]. A sub-slice shares the
	// stored backing array, so a caller's append could write into the slot
	// that a later AddMessage also fills, without holding the lock.
	result := make([]Message, len(messages)-start)
	copy(result, messages[start:])
	return result, nil
}
// UpdateSession replaces the stored record of an existing session with the
// contents of session. It sets UpdatedAt to the current time on both the
// caller's value and the stored copy.
//
// It returns an error if session is nil or if no session with session.ID
// exists. The manager keeps its own copy, so later changes to *session are not
// seen by other callers until UpdateSession is called again.
func (m *memorySessionManager) UpdateSession(ctx context.Context, session *Session) error {
	if session == nil {
		return fmt.Errorf("session is nil")
	}

	m.mutex.Lock()
	defer m.mutex.Unlock()

	if _, exists := m.sessions[session.ID]; !exists {
		return fmt.Errorf("session not found: %s", session.ID)
	}

	session.UpdatedAt = time.Now()

	// Store a private copy. Keeping the caller's pointer would let the caller
	// mutate shared state without holding m.mutex.
	stored := *session
	stored.Messages = make([]Message, len(session.Messages))
	copy(stored.Messages, session.Messages)
	m.sessions[session.ID] = &stored

	m.log.WithContext(ctx).Infof("Updated session: %s", session.ID)
	return nil
}
// ListSessions returns copies of all active sessions owned by userID. Each
// copy carries its own copy of the message history.
//
// The order of the result follows map iteration and is therefore unspecified.
// The result is nil when the user has no active sessions.
func (m *memorySessionManager) ListSessions(ctx context.Context, userID string) ([]*Session, error) {
	m.mutex.RLock()
	defer m.mutex.RUnlock()

	var result []*Session
	for _, stored := range m.sessions {
		if stored.UserID != userID || !stored.IsActive {
			continue
		}

		// Hand out copies so callers never touch the records guarded by m.mutex.
		snapshot := *stored
		if history, ok := m.messages[stored.ID]; ok {
			snapshot.Messages = append(make([]Message, 0, len(history)), history...)
		}
		result = append(result, &snapshot)
	}
	return result, nil
}
// DeleteSession soft-deletes the session identified by sessionID. It marks the
// session inactive and bumps UpdatedAt, but keeps it and its messages in
// memory.
//
// Inactive sessions are skipped by ListSessions but can still be fetched with
// GetSession. Returns an error if the session does not exist.
func (m *memorySessionManager) DeleteSession(ctx context.Context, sessionID string) error {
	m.mutex.Lock()
	defer m.mutex.Unlock()

	target, found := m.sessions[sessionID]
	if !found {
		return fmt.Errorf("session not found: %s", sessionID)
	}

	target.IsActive = false
	target.UpdatedAt = time.Now()

	m.log.WithContext(ctx).Infof("Deleted session: %s", sessionID)
	return nil
}
// GetContextMessages returns the conversation context to send to an AI model:
// up to maxMessages of the most recent "user" and "assistant" messages of the
// session, in chronological order. Messages with any other role (e.g.
// "system") are skipped. maxMessages <= 0 means no limit.
//
// The limit is applied after filtering, so system messages do not use up the
// context budget. It returns nil (and no error) when no message qualifies.
func (m *memorySessionManager) GetContextMessages(ctx context.Context, sessionID string, maxMessages int) ([]Message, error) {
	// Fetch the full history. Limiting before filtering would let system
	// messages crowd user/assistant turns out of the window.
	history, err := m.GetMessages(ctx, sessionID, 0)
	if err != nil {
		return nil, err
	}

	// Walk backwards so we can stop as soon as the budget is filled.
	var picked []Message
	for i := len(history) - 1; i >= 0; i-- {
		if maxMessages > 0 && len(picked) >= maxMessages {
			break
		}
		if role := history[i].Role; role == "user" || role == "assistant" {
			picked = append(picked, history[i])
		}
	}

	// Restore chronological (oldest-first) order.
	for i, j := 0, len(picked)-1; i < j; i, j = i+1, j-1 {
		picked[i], picked[j] = picked[j], picked[i]
	}
	return picked, nil
}