l_ai_knowledge/internal/stream/memory_manager.go

124 lines
3.1 KiB
Go

package stream
import (
"context"
"sync"
"time"
"knowlege-lsxd/internal/types"
"knowlege-lsxd/internal/types/interfaces"
)
// memoryStreamInfo is the in-memory record kept for a single stream.
type memoryStreamInfo struct {
	// Identity of the stream and the query it was registered with.
	sessionID, requestID, query string

	// Text accumulated so far; UpdateStream appends each chunk to it.
	content string
	// Latest non-empty references passed to UpdateStream.
	knowledgeReferences types.References

	// Set at registration and refreshed on every update.
	lastUpdated time.Time
	// Set to true by CompleteStream.
	isCompleted bool
}
// MemoryStreamManager is an in-memory implementation of
// interfaces.StreamManager. All state is held in process memory and guarded
// by a single RWMutex.
type MemoryStreamManager struct {
	// activeStreams indexes stream records by session ID, then by request ID.
	activeStreams map[string]map[string]*memoryStreamInfo
	// mu guards activeStreams and every record reachable from it.
	mu sync.RWMutex
}
// NewMemoryStreamManager returns a MemoryStreamManager with no registered
// streams and its session index initialized.
func NewMemoryStreamManager() *MemoryStreamManager {
	m := new(MemoryStreamManager)
	m.activeStreams = map[string]map[string]*memoryStreamInfo{}
	return m
}
// RegisterStream records a new, empty stream for (sessionID, requestID),
// creating the session's request map on first use. An existing stream with
// the same IDs is replaced. ctx is not used. The returned error is always nil.
func (m *MemoryStreamManager) RegisterStream(ctx context.Context, sessionID, requestID, query string) error {
	m.mu.Lock()
	defer m.mu.Unlock()

	byRequest, ok := m.activeStreams[sessionID]
	if !ok {
		byRequest = make(map[string]*memoryStreamInfo)
		m.activeStreams[sessionID] = byRequest
	}
	byRequest[requestID] = &memoryStreamInfo{
		sessionID:   sessionID,
		requestID:   requestID,
		query:       query,
		lastUpdated: time.Now(),
	}
	return nil
}
// UpdateStream appends content to the stream identified by (sessionID,
// requestID) and refreshes its timestamp. The stored references are replaced
// only when references is non-empty. Unknown streams are silently ignored.
// ctx is not used. The returned error is always nil.
func (m *MemoryStreamManager) UpdateStream(ctx context.Context,
	sessionID, requestID string, content string, references types.References,
) error {
	m.mu.Lock()
	defer m.mu.Unlock()

	byRequest, ok := m.activeStreams[sessionID]
	if !ok {
		return nil
	}
	s, ok := byRequest[requestID]
	if !ok {
		return nil
	}

	// NOTE(review): string += copies the accumulated text on every chunk;
	// consider a builder if streams get long.
	s.content += content
	if len(references) != 0 {
		s.knowledgeReferences = references
	}
	s.lastUpdated = time.Now()
	return nil
}
// CompleteStream 完成流
func (m *MemoryStreamManager) CompleteStream(ctx context.Context, sessionID, requestID string) error {
m.mu.Lock()
defer m.mu.Unlock()
if sessionMap, exists := m.activeStreams[sessionID]; exists {
if stream, found := sessionMap[requestID]; found {
stream.isCompleted = true
// 30s 后删除流
go func() {
time.Sleep(30 * time.Second)
m.mu.Lock()
defer m.mu.Unlock()
delete(sessionMap, requestID)
if len(sessionMap) == 0 {
delete(m.activeStreams, sessionID)
}
}()
}
}
return nil
}
// GetStream returns a copy of the fields of the stream identified by
// (sessionID, requestID). When no such stream is registered it returns
// (nil, nil) rather than an error, so callers must nil-check the result.
// ctx is not used.
func (m *MemoryStreamManager) GetStream(ctx context.Context,
	sessionID, requestID string,
) (*interfaces.StreamInfo, error) {
	m.mu.RLock()
	defer m.mu.RUnlock()

	byRequest, ok := m.activeStreams[sessionID]
	if !ok {
		return nil, nil
	}
	s, ok := byRequest[requestID]
	if !ok {
		return nil, nil
	}

	// NOTE(review): KnowledgeReferences is assigned directly, not cloned. If
	// types.References is a slice, the caller shares its backing array with
	// the stored record.
	snapshot := &interfaces.StreamInfo{
		SessionID:           s.sessionID,
		RequestID:           s.requestID,
		Query:               s.query,
		Content:             s.content,
		KnowledgeReferences: s.knowledgeReferences,
		LastUpdated:         s.lastUpdated,
		IsCompleted:         s.isCompleted,
	}
	return snapshot, nil
}
// Compile-time assertion that *MemoryStreamManager satisfies
// interfaces.StreamManager.
var (
	_ interfaces.StreamManager = (*MemoryStreamManager)(nil)
)