124 lines
3.1 KiB
Go
124 lines
3.1 KiB
Go
package stream
|
|
|
|
import (
|
|
"context"
|
|
"sync"
|
|
"time"
|
|
|
|
"knowlege-lsxd/internal/types"
|
|
"knowlege-lsxd/internal/types/interfaces"
|
|
)
|
|
|
|
// 内存流信息
|
|
type memoryStreamInfo struct {
|
|
sessionID string
|
|
requestID string
|
|
query string
|
|
content string
|
|
knowledgeReferences types.References
|
|
lastUpdated time.Time
|
|
isCompleted bool
|
|
}
|
|
|
|
// MemoryStreamManager 基于内存的流管理器实现
|
|
type MemoryStreamManager struct {
|
|
// 会话ID -> 请求ID -> 流数据
|
|
activeStreams map[string]map[string]*memoryStreamInfo
|
|
mu sync.RWMutex
|
|
}
|
|
|
|
// NewMemoryStreamManager 创建一个新的内存流管理器
|
|
func NewMemoryStreamManager() *MemoryStreamManager {
|
|
return &MemoryStreamManager{
|
|
activeStreams: make(map[string]map[string]*memoryStreamInfo),
|
|
}
|
|
}
|
|
|
|
// RegisterStream 注册一个新的流
|
|
func (m *MemoryStreamManager) RegisterStream(ctx context.Context, sessionID, requestID, query string) error {
|
|
m.mu.Lock()
|
|
defer m.mu.Unlock()
|
|
|
|
info := &memoryStreamInfo{
|
|
sessionID: sessionID,
|
|
requestID: requestID,
|
|
query: query,
|
|
lastUpdated: time.Now(),
|
|
}
|
|
|
|
if _, exists := m.activeStreams[sessionID]; !exists {
|
|
m.activeStreams[sessionID] = make(map[string]*memoryStreamInfo)
|
|
}
|
|
|
|
m.activeStreams[sessionID][requestID] = info
|
|
return nil
|
|
}
|
|
|
|
// UpdateStream 更新流内容
|
|
func (m *MemoryStreamManager) UpdateStream(ctx context.Context,
|
|
sessionID, requestID string, content string, references types.References,
|
|
) error {
|
|
m.mu.Lock()
|
|
defer m.mu.Unlock()
|
|
|
|
if sessionMap, exists := m.activeStreams[sessionID]; exists {
|
|
if stream, found := sessionMap[requestID]; found {
|
|
stream.content += content
|
|
if len(references) > 0 {
|
|
stream.knowledgeReferences = references
|
|
}
|
|
stream.lastUpdated = time.Now()
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// CompleteStream 完成流
|
|
func (m *MemoryStreamManager) CompleteStream(ctx context.Context, sessionID, requestID string) error {
|
|
m.mu.Lock()
|
|
defer m.mu.Unlock()
|
|
|
|
if sessionMap, exists := m.activeStreams[sessionID]; exists {
|
|
if stream, found := sessionMap[requestID]; found {
|
|
stream.isCompleted = true
|
|
// 30s 后删除流
|
|
go func() {
|
|
time.Sleep(30 * time.Second)
|
|
m.mu.Lock()
|
|
defer m.mu.Unlock()
|
|
delete(sessionMap, requestID)
|
|
if len(sessionMap) == 0 {
|
|
delete(m.activeStreams, sessionID)
|
|
}
|
|
}()
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// GetStream 获取特定流
|
|
func (m *MemoryStreamManager) GetStream(ctx context.Context,
|
|
sessionID, requestID string,
|
|
) (*interfaces.StreamInfo, error) {
|
|
m.mu.RLock()
|
|
defer m.mu.RUnlock()
|
|
|
|
if sessionMap, exists := m.activeStreams[sessionID]; exists {
|
|
if stream, found := sessionMap[requestID]; found {
|
|
return &interfaces.StreamInfo{
|
|
SessionID: stream.sessionID,
|
|
RequestID: stream.requestID,
|
|
Query: stream.query,
|
|
Content: stream.content,
|
|
KnowledgeReferences: stream.knowledgeReferences,
|
|
LastUpdated: stream.lastUpdated,
|
|
IsCompleted: stream.isCompleted,
|
|
}, nil
|
|
}
|
|
}
|
|
return nil, nil
|
|
}
|
|
|
|
// 确保实现了接口
|
|
var _ interfaces.StreamManager = (*MemoryStreamManager)(nil)
|