ai-courseware/eino-project/internal/domain/context/context.go

310 lines
8.3 KiB
Go

package context
import (
"context"
"fmt"
"sync"
"time"
"github.com/go-kratos/kratos/v2/log"
)
// ConversationContext holds the state tracked for one conversation session.
type ConversationContext struct {
// SessionID identifies the session; the manager in this file uses it as the map key.
SessionID string `json:"session_id"`
// UserID identifies the user the session belongs to.
UserID string `json:"user_id"`
// Intent is the most recently set intent; empty for a newly created context.
Intent string `json:"intent"`
// Entities accumulates the entities supplied alongside intents.
Entities map[string]interface{} `json:"entities"`
// State is the conversation state; a newly created context starts as "active".
State string `json:"state"`
// History is the message log, oldest first.
History []Message `json:"history"`
// Metadata carries free-form additional data.
Metadata map[string]interface{} `json:"metadata"`
// CreatedAt is the time the context was created.
CreatedAt time.Time `json:"created_at"`
// UpdatedAt is the time of the last modification made through the manager.
UpdatedAt time.Time `json:"updated_at"`
// ExpiresAt is the instant after which the context is treated as expired.
ExpiresAt time.Time `json:"expires_at"`
}
// Message is a single entry in a conversation's history.
type Message struct {
// Role names the author of the message.
Role string `json:"role"` // expected values: user, assistant, system
// Content is the message text.
Content string `json:"content"`
// Timestamp is when the message was recorded; AddMessage in this file
// overwrites it with the current time.
Timestamp time.Time `json:"timestamp"`
// Metadata carries optional extra data and is omitted from JSON when empty.
Metadata map[string]interface{} `json:"metadata,omitempty"`
}
// ContextManager is the interface for managing per-session conversation contexts.
type ContextManager interface {
// CreateContext creates a new conversation context for the given session and user.
CreateContext(ctx context.Context, sessionID, userID string) (*ConversationContext, error)
// GetContext returns the conversation context of a session.
GetContext(ctx context.Context, sessionID string) (*ConversationContext, error)
// UpdateContext stores the given conversation context.
UpdateContext(ctx context.Context, convCtx *ConversationContext) error
// AddMessage appends a message to the session's context.
AddMessage(ctx context.Context, sessionID string, message Message) error
// GetRecentMessages returns the most recent messages of a session, at most limit of them.
GetRecentMessages(ctx context.Context, sessionID string, limit int) ([]Message, error)
// SetIntent sets the intent and entities of a session.
SetIntent(ctx context.Context, sessionID, intent string, entities map[string]interface{}) error
// SetState sets the conversation state of a session.
SetState(ctx context.Context, sessionID, state string) error
// GetContextSummary returns a textual summary of a session's context.
GetContextSummary(ctx context.Context, sessionID string) (string, error)
// CleanupExpiredContexts removes expired contexts.
CleanupExpiredContexts(ctx context.Context) error
// DeleteContext deletes the context of a session.
DeleteContext(ctx context.Context, sessionID string) error
}
// contextManager is the in-memory implementation of ContextManager.
type contextManager struct {
// contexts maps a session ID to its conversation context.
contexts map[string]*ConversationContext
// mutex is locked by every method before it accesses contexts.
mutex sync.RWMutex
// logger receives informational messages about context changes.
logger *log.Helper
// Configuration.
// maxHistorySize caps the number of messages kept in a context's History.
maxHistorySize int
// contextTTL is added to the creation time to compute a context's ExpiresAt.
contextTTL time.Duration
}
// NewContextManager builds an in-memory ContextManager that logs through
// logger. It keeps at most 50 messages per session and gives every new
// context a lifetime of 24 hours.
func NewContextManager(logger log.Logger) ContextManager {
	const (
		historyLimit = 50             // maximum number of history messages
		lifetime     = 24 * time.Hour // time until a context expires
	)

	// The zero value of sync.RWMutex is ready to use, so mutex is not set here.
	manager := &contextManager{
		contexts:       map[string]*ConversationContext{},
		logger:         log.NewHelper(logger),
		maxHistorySize: historyLimit,
		contextTTL:     lifetime,
	}
	return manager
}
// CreateContext registers a fresh conversation context under sessionID and
// returns it. The new context starts in the "active" state with empty
// history, entities and metadata, and expires contextTTL after creation.
// Any context already stored under the same session ID is replaced.
// The returned error is always nil.
func (cm *contextManager) CreateContext(ctx context.Context, sessionID, userID string) (*ConversationContext, error) {
	cm.mutex.Lock()
	defer cm.mutex.Unlock()

	createdAt := time.Now()
	created := &ConversationContext{
		SessionID: sessionID,
		UserID:    userID,
		State:     "active",
		Entities:  map[string]interface{}{},
		History:   []Message{},
		Metadata:  map[string]interface{}{},
		CreatedAt: createdAt,
		UpdatedAt: createdAt,
		ExpiresAt: createdAt.Add(cm.contextTTL),
	}

	cm.contexts[sessionID] = created
	cm.logger.Infof("Created new conversation context for session: %s", sessionID)
	return created, nil
}
// GetContext returns the conversation context stored for sessionID.
//
// It returns an error when no context is stored for the session, or when the
// stored context is past its ExpiresAt time; an expired context is also
// removed from the manager. The returned pointer refers to the stored object
// itself, not to a copy.
func (cm *contextManager) GetContext(ctx context.Context, sessionID string) (*ConversationContext, error) {
	// Fast path: look the session up under the shared lock only.
	cm.mutex.RLock()
	convCtx, exists := cm.contexts[sessionID]
	expired := exists && time.Now().After(convCtx.ExpiresAt)
	cm.mutex.RUnlock()

	if !exists {
		return nil, fmt.Errorf("context not found for session: %s", sessionID)
	}
	if !expired {
		return convCtx, nil
	}

	// Removing the entry writes to the map, so it must happen under the
	// exclusive lock; deleting while holding only RLock races with other
	// readers of the map.
	cm.mutex.Lock()
	defer cm.mutex.Unlock()

	// Re-check under the write lock: the entry may have been replaced,
	// refreshed or removed while no lock was held.
	if current, ok := cm.contexts[sessionID]; ok && !time.Now().After(current.ExpiresAt) {
		return current, nil
	}
	delete(cm.contexts, sessionID) // no-op if the entry is already gone
	return nil, fmt.Errorf("context expired for session: %s", sessionID)
}
// UpdateContext stores convCtx under its SessionID, replacing any context
// previously stored for that session, and sets its UpdatedAt to the current
// time. The manager keeps the pointer it is given; it does not copy the
// context.
//
// It returns an error when convCtx is nil.
func (cm *contextManager) UpdateContext(ctx context.Context, convCtx *ConversationContext) error {
	// Reject nil up front; dereferencing it below would panic.
	if convCtx == nil {
		return fmt.Errorf("cannot update nil conversation context")
	}

	cm.mutex.Lock()
	defer cm.mutex.Unlock()

	convCtx.UpdatedAt = time.Now()
	cm.contexts[convCtx.SessionID] = convCtx
	return nil
}
// AddMessage appends message to the history of the session's context.
//
// The message's Timestamp is overwritten with the current time. When the
// history grows beyond maxHistorySize, the oldest messages are dropped so
// that only the newest maxHistorySize remain. It returns an error when no
// context is stored for sessionID.
func (cm *contextManager) AddMessage(ctx context.Context, sessionID string, message Message) error {
	cm.mutex.Lock()
	defer cm.mutex.Unlock()

	conv, ok := cm.contexts[sessionID]
	if !ok {
		return fmt.Errorf("context not found for session: %s", sessionID)
	}

	// Stamp the message, then record it.
	message.Timestamp = time.Now()
	conv.History = append(conv.History, message)

	// Keep only the newest maxHistorySize entries.
	if overflow := len(conv.History) - cm.maxHistorySize; overflow > 0 {
		conv.History = conv.History[overflow:]
	}

	conv.UpdatedAt = time.Now()
	cm.logger.Infof("Added message to context for session: %s, role: %s", sessionID, message.Role)
	return nil
}
// GetRecentMessages returns a copy of the newest messages of a session,
// oldest first. At most limit messages are returned; a limit of zero or less
// is treated as 10. An empty history yields an empty, non-nil slice. It
// returns an error when no context is stored for sessionID.
func (cm *contextManager) GetRecentMessages(ctx context.Context, sessionID string, limit int) ([]Message, error) {
	cm.mutex.RLock()
	defer cm.mutex.RUnlock()

	conv, ok := cm.contexts[sessionID]
	if !ok {
		return nil, fmt.Errorf("context not found for session: %s", sessionID)
	}

	history := conv.History
	if len(history) == 0 {
		return []Message{}, nil
	}

	// Normalize the requested count into the range [1, len(history)].
	count := limit
	if count <= 0 {
		count = 10
	}
	if count > len(history) {
		count = len(history)
	}

	// Hand back a copy so callers cannot alter the stored history slice.
	recent := make([]Message, count)
	copy(recent, history[len(history)-count:])
	return recent, nil
}
// SetIntent records intent on the session's context and merges entities into
// the context's existing entity map; keys already present are overwritten,
// other existing keys are kept. A nil or empty entities map leaves the stored
// entities unchanged. UpdatedAt is set to the current time.
//
// It returns an error when no context is stored for sessionID.
func (cm *contextManager) SetIntent(ctx context.Context, sessionID, intent string, entities map[string]interface{}) error {
	cm.mutex.Lock()
	defer cm.mutex.Unlock()

	convCtx, exists := cm.contexts[sessionID]
	if !exists {
		return fmt.Errorf("context not found for session: %s", sessionID)
	}

	convCtx.Intent = intent
	if len(entities) > 0 {
		// A context installed through UpdateContext may carry a nil Entities
		// map; writing to a nil map panics, so allocate it on first use.
		if convCtx.Entities == nil {
			convCtx.Entities = make(map[string]interface{}, len(entities))
		}
		for k, v := range entities {
			convCtx.Entities[k] = v
		}
	}

	convCtx.UpdatedAt = time.Now()
	cm.logger.Infof("Set intent for session: %s, intent: %s", sessionID, intent)
	return nil
}
// SetState replaces the conversation state of the session's context and sets
// its UpdatedAt to the current time. It returns an error when no context is
// stored for sessionID.
func (cm *contextManager) SetState(ctx context.Context, sessionID, state string) error {
	cm.mutex.Lock()
	defer cm.mutex.Unlock()

	conv, ok := cm.contexts[sessionID]
	if !ok {
		return fmt.Errorf("context not found for session: %s", sessionID)
	}

	conv.State, conv.UpdatedAt = state, time.Now()
	cm.logger.Infof("Set state for session: %s, state: %s", sessionID, state)
	return nil
}
// GetContextSummary renders a short textual summary of the session's context:
// one line with the session ID, user ID, current intent, state and message
// count, followed, when the history is not empty, by a heading and up to the
// three newest messages, one per line. It returns an error when no context
// is stored for sessionID.
func (cm *contextManager) GetContextSummary(ctx context.Context, sessionID string) (string, error) {
	cm.mutex.RLock()
	defer cm.mutex.RUnlock()

	conv, ok := cm.contexts[sessionID]
	if !ok {
		return "", fmt.Errorf("context not found for session: %s", sessionID)
	}

	const recentCount = 3
	history := conv.History

	summary := fmt.Sprintf("会话ID: %s, 用户ID: %s, 当前意图: %s, 状态: %s, 消息数: %d",
		conv.SessionID, conv.UserID, conv.Intent, conv.State, len(history))
	if len(history) == 0 {
		return summary, nil
	}

	// Append the newest few messages.
	tail := history
	if len(tail) > recentCount {
		tail = tail[len(tail)-recentCount:]
	}
	summary += "\n最近消息:"
	for _, msg := range tail {
		summary += fmt.Sprintf("\n- %s: %s", msg.Role, msg.Content)
	}
	return summary, nil
}
// CleanupExpiredContexts removes every stored context whose ExpiresAt lies
// before the time of the call. Each removal is logged, followed by a total
// when at least one context was removed. The returned error is always nil.
func (cm *contextManager) CleanupExpiredContexts(ctx context.Context) error {
	cm.mutex.Lock()
	defer cm.mutex.Unlock()

	cutoff := time.Now()
	removed := 0
	// Deleting the current entry while ranging over a map is safe in Go.
	for sessionID, conv := range cm.contexts {
		if !cutoff.After(conv.ExpiresAt) {
			continue
		}
		delete(cm.contexts, sessionID)
		cm.logger.Infof("Cleaned up expired context for session: %s", sessionID)
		removed++
	}

	if removed > 0 {
		cm.logger.Infof("Cleaned up %d expired contexts", removed)
	}
	return nil
}
// DeleteContext removes the context stored for sessionID, if any, and logs
// the deletion. Deleting a session that has no stored context is not an
// error; the returned error is always nil.
func (cm *contextManager) DeleteContext(ctx context.Context, sessionID string) error {
	cm.mutex.Lock()
	defer cm.mutex.Unlock()

	// delete is a no-op when the key is absent, so no existence check is needed.
	delete(cm.contexts, sessionID)

	cm.logger.Infof("Deleted context for session: %s", sessionID)
	return nil
}