l_ai_knowledge/internal/handler/session.go

810 lines
26 KiB
Go

package handler
import (
"context"
"fmt"
"io"
"net/http"
"time"
"github.com/gin-gonic/gin"
"knowlege-lsxd/internal/application/service"
"knowlege-lsxd/internal/config"
"knowlege-lsxd/internal/errors"
"knowlege-lsxd/internal/logger"
"knowlege-lsxd/internal/types"
"knowlege-lsxd/internal/types/interfaces"
)
// SessionHandler serves the HTTP endpoints for conversation sessions. It
// bundles the services, stream manager and configuration that the handler
// methods in this file call into.
type SessionHandler struct {
	// messageService manages messages.
	messageService interfaces.MessageService
	// sessionService manages sessions.
	sessionService interfaces.SessionService
	// streamManager handles streaming responses.
	streamManager interfaces.StreamManager
	// config is the application configuration.
	config *config.Config
	// testDataService provides test data (models, etc.).
	testDataService *service.TestDataService
	// knowledgebaseService is used to look up knowledge bases by ID.
	knowledgebaseService interfaces.KnowledgeBaseService
}
// NewSessionHandler returns a SessionHandler wired with the supplied
// dependencies. Each argument is stored as-is in the field of the same name.
func NewSessionHandler(
	sessionService interfaces.SessionService,
	messageService interfaces.MessageService,
	streamManager interfaces.StreamManager,
	config *config.Config,
	testDataService *service.TestDataService,
	knowledgebaseService interfaces.KnowledgeBaseService,
) *SessionHandler {
	handler := new(SessionHandler)
	handler.sessionService = sessionService
	handler.messageService = messageService
	handler.streamManager = streamManager
	handler.config = config
	handler.testDataService = testDataService
	handler.knowledgebaseService = knowledgebaseService
	return handler
}
// SessionStrategy is the per-session strategy configuration a client may
// supply when creating a session: conversation limits, fallback behaviour,
// retrieval and rerank settings, and summary-model settings.
type SessionStrategy struct {
	// MaxRounds is the maximum number of conversation rounds to maintain.
	MaxRounds int `json:"max_rounds"`
	// EnableRewrite turns on query rewrite for multi-round conversations.
	EnableRewrite bool `json:"enable_rewrite"`
	// FallbackStrategy is the strategy used when no relevant knowledge is found.
	FallbackStrategy types.FallbackStrategy `json:"fallback_strategy"`
	// FallbackResponse is the fixed response content used for fallback.
	FallbackResponse string `json:"fallback_response"`
	// EmbeddingTopK is the number of top results to retrieve from vector search.
	EmbeddingTopK int `json:"embedding_top_k"`
	// KeywordThreshold is the threshold for keyword-based retrieval.
	KeywordThreshold float64 `json:"keyword_threshold"`
	// VectorThreshold is the threshold for vector-based retrieval.
	VectorThreshold float64 `json:"vector_threshold"`
	// RerankModelID is the ID of the model used for reranking results.
	RerankModelID string `json:"rerank_model_id"`
	// RerankTopK is the number of top results kept after reranking.
	RerankTopK int `json:"rerank_top_k"`
	// RerankThreshold is the threshold applied to reranked results.
	RerankThreshold float64 `json:"rerank_threshold"`
	// SummaryModelID is the ID of the model used for summarization.
	SummaryModelID string `json:"summary_model_id"`
	// SummaryParameters holds the parameters for the summary model.
	SummaryParameters *types.SummaryConfig `json:"summary_parameters" gorm:"type:json"`
	// NoMatchPrefix is the prefix for responses when no match is found.
	NoMatchPrefix string `json:"no_match_prefix"`
}
// CreateSessionRequest is the JSON body accepted by CreateSession.
type CreateSessionRequest struct {
	// KnowledgeBaseID is the ID of the associated knowledge base (required).
	KnowledgeBaseID string `json:"knowledge_base_id" binding:"required"`
	// SessionStrategy is the optional strategy configuration; when it is
	// absent, CreateSession uses the values from the application config.
	SessionStrategy *SessionStrategy `json:"session_strategy"`
}
// CreateSession handles the creation of a new conversation session.
//
// It binds the JSON body to CreateSessionRequest, reads the tenant ID from the
// gin context, and builds a types.Session either from the supplied session
// strategy or from the conversation defaults in the application config.
// Summary and rerank model IDs left empty are taken from the knowledge base.
// On success it responds with 201 and the created session; failures are
// reported through c.Error.
func (h *SessionHandler) CreateSession(c *gin.Context) {
	ctx := c.Request.Context()
	logger.Info(ctx, "Start creating session")

	// Parse and validate the request body
	var request CreateSessionRequest
	if err := c.ShouldBindJSON(&request); err != nil {
		logger.Error(ctx, "Failed to validate session creation parameters", err)
		c.Error(errors.NewBadRequestError(err.Error()))
		return
	}

	// Get tenant ID from context
	tenantValue, exists := c.Get(types.TenantIDContextKey.String())
	if !exists {
		logger.Error(ctx, "Failed to get tenant ID")
		c.Error(errors.NewUnauthorizedError("Unauthorized"))
		return
	}
	// Use a checked assertion so that a context value of an unexpected type is
	// rejected instead of panicking the handler.
	tenantID, ok := tenantValue.(uint)
	if !ok {
		logger.Errorf(ctx, "Tenant ID has unexpected type: %T", tenantValue)
		c.Error(errors.NewUnauthorizedError("Unauthorized"))
		return
	}

	// Validate knowledge base ID
	if request.KnowledgeBaseID == "" {
		logger.Error(ctx, "Knowledge base ID is empty")
		c.Error(errors.NewBadRequestError("Knowledge base cannot be empty"))
		return
	}

	logger.Infof(
		ctx,
		"Processing session creation request, tenant ID: %d, knowledge base ID: %s",
		tenantID,
		request.KnowledgeBaseID,
	)

	// Create session object with base properties
	createdSession := &types.Session{
		TenantID:        tenantID,
		KnowledgeBaseID: request.KnowledgeBaseID,
	}

	if strategy := request.SessionStrategy; strategy != nil {
		// Copy the caller-supplied strategy onto the session.
		// NOTE(review): strategy.NoMatchPrefix is not read here; only
		// SummaryParameters.NoMatchPrefix is used — confirm this is intended.
		createdSession.RerankModelID = strategy.RerankModelID
		createdSession.SummaryModelID = strategy.SummaryModelID
		createdSession.MaxRounds = strategy.MaxRounds
		createdSession.EnableRewrite = strategy.EnableRewrite
		createdSession.FallbackStrategy = strategy.FallbackStrategy
		createdSession.FallbackResponse = strategy.FallbackResponse
		createdSession.EmbeddingTopK = strategy.EmbeddingTopK
		createdSession.KeywordThreshold = strategy.KeywordThreshold
		createdSession.VectorThreshold = strategy.VectorThreshold
		createdSession.RerankTopK = strategy.RerankTopK
		createdSession.RerankThreshold = strategy.RerankThreshold

		// If summary model parameters are missing, start from the defaults.
		if strategy.SummaryParameters != nil {
			createdSession.SummaryParameters = strategy.SummaryParameters
		} else {
			createdSession.SummaryParameters = h.defaultSummaryParameters()
		}

		// Fill in any text settings the caller left empty.
		summary := createdSession.SummaryParameters
		if summary.Prompt == "" {
			summary.Prompt = h.config.Conversation.Summary.Prompt
		}
		if summary.ContextTemplate == "" {
			summary.ContextTemplate = h.config.Conversation.Summary.ContextTemplate
		}
		if summary.NoMatchPrefix == "" {
			summary.NoMatchPrefix = h.config.Conversation.Summary.NoMatchPrefix
		}
		logger.Debug(ctx, "Custom session strategy set")
	} else {
		// Use default configuration from global config
		createdSession.MaxRounds = h.config.Conversation.MaxRounds
		createdSession.EnableRewrite = h.config.Conversation.EnableRewrite
		createdSession.FallbackStrategy = types.FallbackStrategy(h.config.Conversation.FallbackStrategy)
		createdSession.FallbackResponse = h.config.Conversation.FallbackResponse
		createdSession.EmbeddingTopK = h.config.Conversation.EmbeddingTopK
		createdSession.KeywordThreshold = h.config.Conversation.KeywordThreshold
		createdSession.VectorThreshold = h.config.Conversation.VectorThreshold
		createdSession.RerankThreshold = h.config.Conversation.RerankThreshold
		createdSession.RerankTopK = h.config.Conversation.RerankTopK
		createdSession.SummaryParameters = h.defaultSummaryParameters()
		logger.Debug(ctx, "Using default session strategy")
	}

	kb, err := h.knowledgebaseService.GetKnowledgeBaseByID(ctx, request.KnowledgeBaseID)
	if err != nil {
		logger.Error(ctx, "Failed to get knowledge base", err)
		c.Error(errors.NewInternalServerError(err.Error()))
		return
	}

	// Get model IDs from knowledge base if not provided
	if createdSession.SummaryModelID == "" {
		createdSession.SummaryModelID = kb.SummaryModelID
	}
	if createdSession.RerankModelID == "" {
		createdSession.RerankModelID = kb.RerankModelID
	}

	// Call service to create session
	logger.Infof(ctx, "Calling session service to create session")
	createdSession, err = h.sessionService.CreateSession(ctx, createdSession)
	if err != nil {
		logger.ErrorWithFields(ctx, err, nil)
		c.Error(errors.NewInternalServerError(err.Error()))
		return
	}

	// Return created session
	logger.Infof(ctx, "Session created successfully, ID: %s", createdSession.ID)
	c.JSON(http.StatusCreated, gin.H{
		"success": true,
		"data":    createdSession,
	})
}

// defaultSummaryParameters returns a new SummaryConfig populated from the
// conversation summary settings in the application config.
func (h *SessionHandler) defaultSummaryParameters() *types.SummaryConfig {
	return &types.SummaryConfig{
		MaxTokens:           h.config.Conversation.Summary.MaxTokens,
		TopP:                h.config.Conversation.Summary.TopP,
		TopK:                h.config.Conversation.Summary.TopK,
		FrequencyPenalty:    h.config.Conversation.Summary.FrequencyPenalty,
		PresencePenalty:     h.config.Conversation.Summary.PresencePenalty,
		RepeatPenalty:       h.config.Conversation.Summary.RepeatPenalty,
		Prompt:              h.config.Conversation.Summary.Prompt,
		ContextTemplate:     h.config.Conversation.Summary.ContextTemplate,
		NoMatchPrefix:       h.config.Conversation.Summary.NoMatchPrefix,
		Temperature:         h.config.Conversation.Summary.Temperature,
		Seed:                h.config.Conversation.Summary.Seed,
		MaxCompletionTokens: h.config.Conversation.Summary.MaxCompletionTokens,
	}
}
// GetSession looks up a single session by the "id" URL parameter and writes it
// as JSON. A missing ID is reported as a bad request, an unknown session as
// not found, and any other service failure as an internal server error.
func (h *SessionHandler) GetSession(c *gin.Context) {
	ctx := c.Request.Context()
	logger.Info(ctx, "Start retrieving session")

	// The session ID comes from the route.
	sessionID := c.Param("id")
	if sessionID == "" {
		logger.Error(ctx, "Session ID is empty")
		c.Error(errors.NewBadRequestError(errors.ErrInvalidSessionID.Error()))
		return
	}

	// Ask the service for the session details.
	logger.Infof(ctx, "Retrieving session, ID: %s", sessionID)
	found, err := h.sessionService.GetSession(ctx, sessionID)
	switch {
	case err == nil:
		// Success; respond below.
	case err == errors.ErrSessionNotFound:
		logger.Warnf(ctx, "Session not found, ID: %s", sessionID)
		c.Error(errors.NewNotFoundError(err.Error()))
		return
	default:
		logger.ErrorWithFields(ctx, err, nil)
		c.Error(errors.NewInternalServerError(err.Error()))
		return
	}

	// Send the session back to the caller.
	logger.Infof(ctx, "Session retrieved successfully, ID: %s", sessionID)
	body := gin.H{
		"success": true,
		"data":    found,
	}
	c.JSON(http.StatusOK, body)
}
// GetSessionsByTenant returns one page of sessions. Pagination is bound from
// the query string and passed to the session service; the response carries the
// page data together with total, page and page_size.
func (h *SessionHandler) GetSessionsByTenant(c *gin.Context) {
	ctx := c.Request.Context()
	logger.Info(ctx, "Start retrieving all sessions for tenant")

	// Bind the pagination parameters from the query string.
	paging := &types.Pagination{}
	if bindErr := c.ShouldBindQuery(paging); bindErr != nil {
		logger.Error(ctx, "Failed to parse pagination parameters", bindErr)
		c.Error(errors.NewBadRequestError(bindErr.Error()))
		return
	}
	logger.Debugf(ctx, "Using pagination parameters: page=%d, page_size=%d", paging.Page, paging.PageSize)

	// Fetch the requested page from the service.
	paged, err := h.sessionService.GetPagedSessionsByTenant(ctx, paging)
	if err != nil {
		logger.ErrorWithFields(ctx, err, nil)
		c.Error(errors.NewInternalServerError(err.Error()))
		return
	}

	// Respond with the sessions and the pagination metadata.
	logger.Infof(ctx, "Successfully retrieved tenant sessions, total: %d", paged.Total)
	body := gin.H{
		"success":   true,
		"data":      paged.Data,
		"total":     paged.Total,
		"page":      paged.Page,
		"page_size": paged.PageSize,
	}
	c.JSON(http.StatusOK, body)
}
// UpdateSession updates an existing session's properties.
//
// The session ID is taken from the "id" URL parameter and the tenant ID from
// the gin context; both overwrite whatever the JSON body carried for those
// fields before the session is handed to the session service. On success it
// responds with 200 and the session as submitted; failures are reported
// through c.Error.
func (h *SessionHandler) UpdateSession(c *gin.Context) {
	ctx := c.Request.Context()
	logger.Info(ctx, "Start updating session")

	// Get session ID from URL parameter
	id := c.Param("id")
	if id == "" {
		logger.Error(ctx, "Session ID is empty")
		c.Error(errors.NewBadRequestError(errors.ErrInvalidSessionID.Error()))
		return
	}

	// Verify tenant ID from context for authorization
	tenantValue, exists := c.Get(types.TenantIDContextKey.String())
	if !exists {
		logger.Error(ctx, "Failed to get tenant ID")
		c.Error(errors.NewUnauthorizedError("Unauthorized"))
		return
	}
	// Use a checked assertion so that a context value of an unexpected type is
	// rejected instead of panicking the handler.
	tenantID, ok := tenantValue.(uint)
	if !ok {
		logger.Errorf(ctx, "Tenant ID has unexpected type: %T", tenantValue)
		c.Error(errors.NewUnauthorizedError("Unauthorized"))
		return
	}

	// Parse request body to session object
	var session types.Session
	if err := c.ShouldBindJSON(&session); err != nil {
		logger.Error(ctx, "Failed to parse session data", err)
		c.Error(errors.NewBadRequestError(err.Error()))
		return
	}

	// Set session ID and tenant ID from trusted sources, not from the body
	logger.Infof(ctx, "Updating session, ID: %s, tenant ID: %d", id, tenantID)
	session.ID = id
	session.TenantID = tenantID

	// Call service to update session
	if err := h.sessionService.UpdateSession(ctx, &session); err != nil {
		if err == errors.ErrSessionNotFound {
			logger.Warnf(ctx, "Session not found, ID: %s", id)
			c.Error(errors.NewNotFoundError(err.Error()))
			return
		}
		logger.ErrorWithFields(ctx, err, nil)
		c.Error(errors.NewInternalServerError(err.Error()))
		return
	}

	// Return updated session
	logger.Infof(ctx, "Session updated successfully, ID: %s", id)
	c.JSON(http.StatusOK, gin.H{
		"success": true,
		"data":    session,
	})
}
// DeleteSession removes the session named by the "id" URL parameter. A missing
// ID is reported as a bad request, an unknown session as not found, and any
// other service failure as an internal server error.
func (h *SessionHandler) DeleteSession(c *gin.Context) {
	ctx := c.Request.Context()
	logger.Info(ctx, "Start deleting session")

	// The session ID comes from the route.
	sessionID := c.Param("id")
	if sessionID == "" {
		logger.Error(ctx, "Session ID is empty")
		c.Error(errors.NewBadRequestError(errors.ErrInvalidSessionID.Error()))
		return
	}

	// Ask the service to delete the session.
	logger.Infof(ctx, "Deleting session, ID: %s", sessionID)
	err := h.sessionService.DeleteSession(ctx, sessionID)
	switch {
	case err == nil:
		// Success; respond below.
	case err == errors.ErrSessionNotFound:
		logger.Warnf(ctx, "Session not found, ID: %s", sessionID)
		c.Error(errors.NewNotFoundError(err.Error()))
		return
	default:
		logger.ErrorWithFields(ctx, err, nil)
		c.Error(errors.NewInternalServerError(err.Error()))
		return
	}

	// Confirm the deletion to the caller.
	logger.Infof(ctx, "Session deleted successfully, ID: %s", sessionID)
	body := gin.H{
		"success": true,
		"message": "Session deleted successfully",
	}
	c.JSON(http.StatusOK, body)
}
// GenerateTitleRequest is the JSON body accepted by GenerateTitle.
type GenerateTitleRequest struct {
	// Messages are the messages used as context for title generation (required).
	Messages []types.Message `json:"messages" binding:"required"`
}
// GenerateTitle asks the session service to produce a title for the session
// named by the "session_id" URL parameter, using the messages supplied in the
// request body, and returns the title as JSON.
func (h *SessionHandler) GenerateTitle(c *gin.Context) {
	ctx := c.Request.Context()
	logger.Info(ctx, "Start generating session title")

	// The session ID comes from the route.
	sessionID := c.Param("session_id")
	if sessionID == "" {
		logger.Error(ctx, "Session ID is empty")
		c.Error(errors.NewBadRequestError(errors.ErrInvalidSessionID.Error()))
		return
	}

	// Bind the request body.
	payload := GenerateTitleRequest{}
	if bindErr := c.ShouldBindJSON(&payload); bindErr != nil {
		logger.Error(ctx, "Failed to parse request data", bindErr)
		c.Error(errors.NewBadRequestError(bindErr.Error()))
		return
	}

	// Delegate title generation to the service.
	messageCount := len(payload.Messages)
	logger.Infof(ctx, "Generating session title, session ID: %s, message count: %d", sessionID, messageCount)
	generated, err := h.sessionService.GenerateTitle(ctx, sessionID, payload.Messages)
	if err != nil {
		logger.ErrorWithFields(ctx, err, nil)
		c.Error(errors.NewInternalServerError(err.Error()))
		return
	}

	// Send the generated title back.
	logger.Infof(ctx, "Session title generated successfully, session ID: %s, title: %s", sessionID, generated)
	body := gin.H{
		"success": true,
		"data":    generated,
	}
	c.JSON(http.StatusOK, body)
}
// CreateKnowledgeQARequest is the JSON body accepted by KnowledgeQA.
type CreateKnowledgeQARequest struct {
	// Query is the query text for the knowledge base search (required).
	Query string `json:"query" binding:"required"`
}
// SearchKnowledgeRequest is the JSON body accepted by SearchKnowledge, which
// searches a knowledge base without LLM summarization.
type SearchKnowledgeRequest struct {
	// Query is the text to search for (required).
	Query string `json:"query" binding:"required"`
	// KnowledgeBaseID is the ID of the knowledge base to search (required).
	KnowledgeBaseID string `json:"knowledge_base_id" binding:"required"`
}
// SearchKnowledge runs a knowledge base search through the session service
// and returns the raw results as JSON, without LLM summarization. The query
// and the knowledge base ID are both required in the request body.
func (h *SessionHandler) SearchKnowledge(c *gin.Context) {
	ctx := logger.CloneContext(c.Request.Context())
	logger.Info(ctx, "Start processing knowledge search request")

	// Bind the request body.
	payload := SearchKnowledgeRequest{}
	if bindErr := c.ShouldBindJSON(&payload); bindErr != nil {
		logger.Error(ctx, "Failed to parse request data", bindErr)
		c.Error(errors.NewBadRequestError(bindErr.Error()))
		return
	}

	// Reject empty parameters; the query is checked before the knowledge base ID.
	switch {
	case payload.Query == "":
		logger.Error(ctx, "Query content is empty")
		c.Error(errors.NewBadRequestError("Query content cannot be empty"))
		return
	case payload.KnowledgeBaseID == "":
		logger.Error(ctx, "Knowledge base ID is empty")
		c.Error(errors.NewBadRequestError("Knowledge base ID cannot be empty"))
		return
	}

	logger.Infof(
		ctx,
		"Knowledge search request, knowledge base ID: %s, query: %s",
		payload.KnowledgeBaseID,
		payload.Query,
	)

	// Call the retrieval service directly; no LLM summarization is involved.
	hits, err := h.sessionService.SearchKnowledge(ctx, payload.KnowledgeBaseID, payload.Query)
	if err != nil {
		logger.ErrorWithFields(ctx, err, nil)
		c.Error(errors.NewInternalServerError(err.Error()))
		return
	}
	logger.Infof(ctx, "Knowledge search completed, found %d results", len(hits))

	// Send the search results back.
	body := gin.H{
		"success": true,
		"data":    hits,
	}
	c.JSON(http.StatusOK, body)
}
// ContinueStream handles continued streaming of an active response stream.
//
// The session ID comes from the "session_id" URL parameter and the message ID
// from the "message_id" query parameter. After checking that the session, the
// message and the stream exist, the handler either returns the full message as
// JSON (when the stream is already completed) or replays the knowledge
// references and the content accumulated so far as SSE events and then keeps
// polling the stream manager every 100ms, forwarding each new content delta
// until the stream completes, disappears, fails, or the client disconnects.
//
// Deltas and the end-of-stream signal travel over one channel that the poller
// closes when it is finished, so a delta that arrives together with completion
// is always delivered before the final "done" event.
func (h *SessionHandler) ContinueStream(c *gin.Context) {
	ctx := c.Request.Context()
	logger.Info(ctx, "Start continuing stream response processing")

	// Get session ID from URL parameter
	sessionID := c.Param("session_id")
	if sessionID == "" {
		logger.Error(ctx, "Session ID is empty")
		c.Error(errors.NewBadRequestError(errors.ErrInvalidSessionID.Error()))
		return
	}

	// Get message ID from query parameter
	messageID := c.Query("message_id")
	if messageID == "" {
		logger.Error(ctx, "Message ID is empty")
		c.Error(errors.NewBadRequestError("Missing message ID"))
		return
	}

	logger.Infof(ctx, "Continuing stream, session ID: %s, message ID: %s", sessionID, messageID)

	// Verify that the session exists
	_, err := h.sessionService.GetSession(ctx, sessionID)
	if err != nil {
		if err == errors.ErrSessionNotFound {
			logger.Warnf(ctx, "Session not found, ID: %s", sessionID)
			c.Error(errors.NewNotFoundError(err.Error()))
		} else {
			logger.ErrorWithFields(ctx, err, nil)
			c.Error(errors.NewInternalServerError(err.Error()))
		}
		return
	}

	// Get the incomplete message
	message, err := h.messageService.GetMessage(ctx, sessionID, messageID)
	if err != nil {
		logger.ErrorWithFields(ctx, err, nil)
		c.Error(errors.NewInternalServerError(err.Error()))
		return
	}
	if message == nil {
		logger.Warnf(ctx, "Incomplete message not found, session ID: %s, message ID: %s", sessionID, messageID)
		c.JSON(http.StatusNotFound, gin.H{
			"success": false,
			"error":   "Incomplete message not found",
		})
		return
	}

	// Get stream information
	streamInfo, err := h.streamManager.GetStream(ctx, sessionID, messageID)
	if err != nil {
		logger.ErrorWithFields(ctx, err, nil)
		c.Error(errors.NewInternalServerError(fmt.Sprintf("Failed to get stream data: %s", err.Error())))
		return
	}
	if streamInfo == nil {
		logger.Warnf(ctx, "Active stream not found, session ID: %s, message ID: %s", sessionID, messageID)
		c.JSON(http.StatusNotFound, gin.H{
			"success": false,
			"error":   "Active stream not found",
		})
		return
	}

	// If stream is already completed, return the full message
	if streamInfo.IsCompleted {
		logger.Infof(
			ctx, "Stream already completed, returning directly, session ID: %s, message ID: %s", sessionID, messageID,
		)
		c.JSON(http.StatusOK, gin.H{
			"id":         message.ID,
			"role":       message.Role,
			"content":    message.Content,
			"created_at": message.CreatedAt,
			"done":       true,
		})
		return
	}

	logger.Infof(
		ctx, "Preparing to set SSE headers and send stream data, session ID: %s, message ID: %s", sessionID, messageID,
	)

	// Send knowledge references first if available
	if len(streamInfo.KnowledgeReferences) > 0 {
		logger.Debug(ctx, "Sending knowledge references")
		c.SSEvent("message", &types.StreamResponse{
			ID:                  message.RequestID,
			ResponseType:        types.ResponseTypeReferences,
			Done:                false,
			KnowledgeReferences: streamInfo.KnowledgeReferences,
		})
	}

	// Send existing content
	if streamInfo.Content != "" {
		logger.Debug(ctx, "Sending current existing content")
		c.SSEvent("message", &types.StreamResponse{
			ID:           message.RequestID,
			ResponseType: types.ResponseTypeAnswer,
			Content:      streamInfo.Content,
			Done:         streamInfo.IsCompleted,
		})
	}

	// updates carries content deltas from the poller to the SSE writer. The
	// poller closes it when polling ends, which doubles as the "done" signal and
	// guarantees every queued delta is read before completion is reported.
	updates := make(chan string, 10)

	logger.Debug(ctx, "Starting content update monitoring")

	// Start a goroutine to monitor for content updates. It only uses the
	// captured ctx (never c) so it stays safe if it outlives the handler, and
	// every blocking operation is paired with ctx.Done() so it cannot leak.
	go func() {
		defer close(updates)

		ticker := time.NewTicker(100 * time.Millisecond)
		defer ticker.Stop()

		// Number of content bytes already delivered to the client.
		sentLen := len(streamInfo.Content)

		for {
			select {
			case <-ctx.Done():
				logger.Debug(ctx, "Client connection closed")
				return
			case <-ticker.C:
				latest, err := h.streamManager.GetStream(ctx, sessionID, messageID)
				if err != nil {
					logger.Errorf(ctx, "Failed to get stream data: %v", err)
					return
				}
				if latest == nil {
					logger.Debug(ctx, "Stream no longer exists")
					return
				}

				// Forward the new content delta before looking at completion,
				// so the final chunk of a completed stream is not dropped.
				if len(latest.Content) > sentLen {
					delta := latest.Content[sentLen:]
					select {
					case updates <- delta:
						sentLen = len(latest.Content)
						logger.Debugf(ctx, "Sending new content: %d bytes", len(delta))
					case <-ctx.Done():
						logger.Debug(ctx, "Client connection closed")
						return
					}
				}

				if latest.IsCompleted {
					logger.Debug(ctx, "Stream completed")
					return
				}
			}
		}
	}()

	logger.Info(ctx, "Starting stream response")

	// Stream updated content to client
	c.Stream(func(w io.Writer) bool {
		select {
		case <-ctx.Done():
			logger.Debug(ctx, "Client connection closed")
			return false
		case content, open := <-updates:
			if !open {
				// The poller has stopped. If that was caused by the client
				// going away there is nobody left to notify.
				if ctx.Err() != nil {
					logger.Debug(ctx, "Client connection closed")
					return false
				}
				logger.Debug(ctx, "Stream completed, sending completion notification")
				c.SSEvent("message", &types.StreamResponse{
					ID:           message.RequestID,
					ResponseType: types.ResponseTypeAnswer,
					Content:      "",
					Done:         true,
				})
				return false
			}
			logger.Debugf(ctx, "Sending content fragment: %d bytes", len(content))
			c.SSEvent("message", &types.StreamResponse{
				ID:           message.RequestID,
				ResponseType: types.ResponseTypeAnswer,
				Content:      content,
				Done:         false,
			})
			return true
		}
	})
}
// KnowledgeQA handles knowledge base question answering requests with LLM
// summarization.
//
// It stores the user's query and a placeholder assistant message, calls the
// session service, registers a stream with the stream manager, and forwards
// every response from the service's channel to the client as an SSE "message"
// event. Answer fragments are also appended to the assistant message and
// pushed to the stream manager. When the handler returns after the assistant
// message has been created, that message is marked completed and updated.
//
// The handler works on a cloned context (logger.CloneContext) rather than on
// the request context directly.
func (h *SessionHandler) KnowledgeQA(c *gin.Context) {
	ctx := logger.CloneContext(c.Request.Context())
	logger.Info(ctx, "Start processing knowledge QA request")

	// Get session ID from URL parameter
	sessionID := c.Param("session_id")
	if sessionID == "" {
		logger.Error(ctx, "Session ID is empty")
		c.Error(errors.NewBadRequestError(errors.ErrInvalidSessionID.Error()))
		return
	}

	// Parse request body
	var request CreateKnowledgeQARequest
	if err := c.ShouldBindJSON(&request); err != nil {
		logger.Error(ctx, "Failed to parse request data", err)
		c.Error(errors.NewBadRequestError(err.Error()))
		return
	}

	// Validate query content
	if request.Query == "" {
		logger.Error(ctx, "Query content is empty")
		c.Error(errors.NewBadRequestError("Query content cannot be empty"))
		return
	}

	logger.Infof(ctx, "Knowledge QA request, session ID: %s, query: %s", sessionID, request.Query)

	// The request ID tags both stored messages and every SSE response.
	requestID := c.GetString(types.RequestIDContextKey.String())

	// Create user message
	if _, err := h.messageService.CreateMessage(ctx, &types.Message{
		SessionID:   sessionID,
		Role:        "user",
		Content:     request.Query,
		RequestID:   requestID,
		CreatedAt:   time.Now(),
		IsCompleted: true,
	}); err != nil {
		logger.ErrorWithFields(ctx, err, nil)
		c.Error(errors.NewInternalServerError(err.Error()))
		return
	}

	// Create assistant message (response)
	assistantMessage := &types.Message{
		SessionID:   sessionID,
		Role:        "assistant",
		RequestID:   requestID,
		CreatedAt:   time.Now(),
		IsCompleted: false,
	}
	if _, err := h.messageService.CreateMessage(ctx, assistantMessage); err != nil {
		logger.ErrorWithFields(ctx, err, nil)
		c.Error(errors.NewInternalServerError(err.Error()))
		return
	}
	// Register completion only once the assistant message has been created, so
	// early-return paths never issue an update for a message that was not
	// stored.
	defer h.completeAssistantMessage(ctx, assistantMessage)

	logger.Infof(ctx, "Calling knowledge QA service, session ID: %s", sessionID)

	// Call service to perform knowledge QA
	searchResults, respCh, err := h.sessionService.KnowledgeQA(ctx, sessionID, request.Query)
	if err != nil {
		logger.ErrorWithFields(ctx, err, nil)
		c.Error(errors.NewInternalServerError(err.Error()))
		return
	}
	assistantMessage.KnowledgeReferences = searchResults

	// Register new stream with stream manager; a failure is logged but does not
	// stop the response from being streamed.
	if err := h.streamManager.RegisterStream(ctx, sessionID, assistantMessage.ID, request.Query); err != nil {
		logger.GetLogger(ctx).Error("Register stream failed", "error", err)
	}

	// Send knowledge references if available
	if len(searchResults) > 0 {
		logger.Debugf(ctx, "Sending reference content, total %d", len(searchResults))
		c.SSEvent("message", &types.StreamResponse{
			ID:                  requestID,
			ResponseType:        types.ResponseTypeReferences,
			KnowledgeReferences: searchResults,
		})
		c.Writer.Flush()
	} else {
		logger.Debug(ctx, "No reference content to send")
	}

	// Mark stream as completed once the response channel has been drained
	// (this deferred call runs before completeAssistantMessage).
	defer func() {
		if err := h.streamManager.CompleteStream(ctx, sessionID, assistantMessage.ID); err != nil {
			logger.GetLogger(ctx).Error("Complete stream failed", "error", err)
		}
	}()

	// Process streamed response
	for response := range respCh {
		response.ID = requestID
		c.SSEvent("message", response)
		c.Writer.Flush()

		if response.ResponseType != types.ResponseTypeAnswer {
			continue
		}
		assistantMessage.Content += response.Content
		// Update stream manager with new content
		if err := h.streamManager.UpdateStream(
			ctx, sessionID, assistantMessage.ID, response.Content, searchResults,
		); err != nil {
			logger.GetLogger(ctx).Error("Update stream content failed", "error", err)
		}
	}
}
// completeAssistantMessage marks an assistant message as complete, stamps its
// update time and saves it through the message service. It is best-effort: a
// failed update is logged rather than returned, because callers invoke it from
// a defer and have no way to act on the error.
func (h *SessionHandler) completeAssistantMessage(ctx context.Context, assistantMessage *types.Message) {
	assistantMessage.UpdatedAt = time.Now()
	assistantMessage.IsCompleted = true
	if err := h.messageService.UpdateMessage(ctx, assistantMessage); err != nil {
		logger.Errorf(ctx, "Failed to update assistant message: %v", err)
	}
}