// l_ai_knowledge/internal/application/service/evaluation.go

package service

import (
	"context"
	"errors"
	"runtime"
	"sync"
	"time"

	"github.com/google/uuid"
	"golang.org/x/sync/errgroup"

	"knowlege-lsxd/internal/config"
	"knowlege-lsxd/internal/logger"
	"knowlege-lsxd/internal/types"
	"knowlege-lsxd/internal/types/interfaces"
)
/*
Dataset layout used by the evaluation task:

	corpus:  pid -> passage content
	queries: qid -> query content
	answers: aid -> answer content
	qrels:   qid -> pid (relevant passage for a query)
	arels:   qid -> aid (reference answer for a query)
*/
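// Illustrative example (hypothetical IDs): a qrels entry mapping "q1" -> "p3"
// marks passage "p3" as the relevant corpus entry for query "q1", and an arels
// entry "q1" -> "a7" marks "a7" as its reference answer.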
// EvaluationService handles evaluation tasks for knowledge base and chat models
type EvaluationService struct {
config *config.Config // Application configuration
dataset interfaces.DatasetService // Service for dataset operations
knowledgeBaseService interfaces.KnowledgeBaseService // Service for knowledge base operations
knowledgeService interfaces.KnowledgeService // Service for knowledge operations
sessionService interfaces.SessionService // Service for chat sessions
testData *TestDataService // Service for test data
evaluationMemoryStorage *evaluationMemoryStorage // In-memory storage for evaluation tasks
}
func NewEvaluationService(
config *config.Config,
dataset interfaces.DatasetService,
knowledgeBaseService interfaces.KnowledgeBaseService,
knowledgeService interfaces.KnowledgeService,
sessionService interfaces.SessionService,
testData *TestDataService,
) interfaces.EvaluationService {
evaluationMemoryStorage := newEvaluationMemoryStorage()
return &EvaluationService{
config: config,
dataset: dataset,
knowledgeBaseService: knowledgeBaseService,
knowledgeService: knowledgeService,
sessionService: sessionService,
testData: testData,
evaluationMemoryStorage: evaluationMemoryStorage,
}
}
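// Wiring sketch (hypothetical variable names, for illustration only): an
// application bootstrap would typically construct the dependencies first and
// then build the service, e.g.
//
//	evalSvc := NewEvaluationService(cfg, datasetSvc, kbSvc, knowledgeSvc, sessionSvc, testDataSvc)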
// evaluationMemoryStorage stores evaluation tasks in memory with thread-safe access
type evaluationMemoryStorage struct {
store map[string]*types.EvaluationDetail // Map of taskID to evaluation details
mu *sync.RWMutex // Read-write lock for concurrent access
}
func newEvaluationMemoryStorage() *evaluationMemoryStorage {
res := &evaluationMemoryStorage{
store: make(map[string]*types.EvaluationDetail),
mu: &sync.RWMutex{},
}
return res
}
func (e *evaluationMemoryStorage) register(params *types.EvaluationDetail) {
e.mu.Lock()
defer e.mu.Unlock()
logger.Infof(context.Background(), "Registering evaluation task: %s", params.Task.ID)
e.store[params.Task.ID] = params
}
func (e *evaluationMemoryStorage) get(taskID string) (*types.EvaluationDetail, error) {
e.mu.RLock()
defer e.mu.RUnlock()
logger.Infof(context.Background(), "Getting evaluation task: %s", taskID)
res, ok := e.store[taskID]
if !ok {
return nil, errors.New("task not found")
}
return res, nil
}
func (e *evaluationMemoryStorage) update(taskID string, fn func(params *types.EvaluationDetail)) error {
e.mu.Lock()
defer e.mu.Unlock()
params, ok := e.store[taskID]
if !ok {
return errors.New("task not found")
}
fn(params)
return nil
}
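// exampleMemoryStorageUsage is an illustrative sketch (never called) of the
// intended register/get/update flow; the task ID and field values are made up.
func exampleMemoryStorageUsage() {
	storage := newEvaluationMemoryStorage()
	storage.register(&types.EvaluationDetail{
		Task: &types.EvaluationTask{ID: "example-task", Status: types.EvaluationStatuePending},
	})
	// Mutations go through update so the write lock is held for the whole change.
	_ = storage.update("example-task", func(params *types.EvaluationDetail) {
		params.Task.Finished = 1
	})
	if detail, err := storage.get("example-task"); err == nil {
		_ = detail.Task.Finished // now 1
	}
}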
func (e *EvaluationService) EvaluationResult(ctx context.Context, taskID string) (*types.EvaluationDetail, error) {
logger.Info(ctx, "Start getting evaluation result")
logger.Infof(ctx, "Task ID: %s", taskID)
detail, err := e.evaluationMemoryStorage.get(taskID)
if err != nil {
logger.Errorf(ctx, "Failed to get evaluation task: %v", err)
return nil, err
}
	tenantID, ok := ctx.Value(types.TenantIDContextKey).(uint)
	if !ok {
		logger.Error(ctx, "Tenant ID missing from context")
		return nil, errors.New("tenant ID not found in context")
	}
logger.Infof(
ctx,
"Checking tenant ID match, task tenant ID: %d, current tenant ID: %d",
detail.Task.TenantID, tenantID,
)
if tenantID != detail.Task.TenantID {
logger.Error(ctx, "Tenant ID mismatch")
return nil, errors.New("tenant ID does not match")
}
logger.Info(ctx, "Evaluation result retrieved successfully")
return detail, nil
}
// Evaluation starts a new evaluation task with given parameters
// datasetID: ID of the dataset to evaluate against
// knowledgeBaseID: ID of the knowledge base to use (empty to create new)
// chatModelID: ID of the chat model to evaluate
// rerankModelID: ID of the rerank model to evaluate
func (e *EvaluationService) Evaluation(ctx context.Context,
datasetID string, knowledgeBaseID string, chatModelID string, rerankModelID string,
) (*types.EvaluationDetail, error) {
logger.Info(ctx, "Start evaluation")
logger.Infof(ctx, "Dataset ID: %s, Knowledge Base ID: %s, Chat Model ID: %s, Rerank Model ID: %s",
datasetID, knowledgeBaseID, chatModelID, rerankModelID)
// Get tenant ID from context for multi-tenancy support
	tenantID, ok := ctx.Value(types.TenantIDContextKey).(uint)
	if !ok {
		logger.Error(ctx, "Tenant ID missing from context")
		return nil, errors.New("tenant ID not found in context")
	}
logger.Infof(ctx, "Tenant ID: %d", tenantID)
// Handle knowledge base creation if not provided
if knowledgeBaseID == "" {
logger.Info(ctx, "No knowledge base ID provided, creating new knowledge base")
// Create new knowledge base with default evaluation settings
kb, err := e.knowledgeBaseService.CreateKnowledgeBase(ctx, &types.KnowledgeBase{
Name: "evaluation",
Description: "evaluation",
EmbeddingModelID: e.testData.EmbedModel.GetModelID(),
SummaryModelID: e.testData.LLMModel.GetModelID(),
})
if err != nil {
logger.Errorf(ctx, "Failed to create knowledge base: %v", err)
return nil, err
}
knowledgeBaseID = kb.ID
logger.Infof(ctx, "Created new knowledge base with ID: %s", knowledgeBaseID)
} else {
logger.Infof(ctx, "Using existing knowledge base ID: %s", knowledgeBaseID)
// Create evaluation-specific knowledge base based on existing one
kb, err := e.knowledgeBaseService.GetKnowledgeBaseByID(ctx, knowledgeBaseID)
if err != nil {
logger.Errorf(ctx, "Failed to get knowledge base: %v", err)
return nil, err
}
kb, err = e.knowledgeBaseService.CreateKnowledgeBase(ctx, &types.KnowledgeBase{
Name: "evaluation",
Description: "evaluation",
EmbeddingModelID: kb.EmbeddingModelID,
SummaryModelID: kb.SummaryModelID,
})
if err != nil {
logger.Errorf(ctx, "Failed to create knowledge base: %v", err)
return nil, err
}
knowledgeBaseID = kb.ID
logger.Infof(ctx, "Created new knowledge base with ID: %s based on existing one", knowledgeBaseID)
}
// Set default values for optional parameters
if datasetID == "" {
datasetID = "default"
logger.Info(ctx, "Using default dataset")
}
if rerankModelID == "" {
rerankModelID = e.testData.RerankModel.GetModelID()
logger.Infof(ctx, "Using default rerank model: %s", rerankModelID)
}
if chatModelID == "" {
chatModelID = e.testData.LLMModel.GetModelID()
logger.Infof(ctx, "Using default chat model: %s", chatModelID)
}
// Create evaluation task with unique ID
logger.Info(ctx, "Creating evaluation task")
taskID := uuid.New().String()
logger.Infof(ctx, "Generated task ID: %s", taskID)
// Prepare evaluation detail with all parameters
detail := &types.EvaluationDetail{
Task: &types.EvaluationTask{
ID: taskID,
TenantID: tenantID,
DatasetID: datasetID,
Status: types.EvaluationStatuePending,
StartTime: time.Now(),
},
Params: &types.ChatManage{
KnowledgeBaseID: knowledgeBaseID,
VectorThreshold: e.config.Conversation.VectorThreshold,
KeywordThreshold: e.config.Conversation.KeywordThreshold,
EmbeddingTopK: e.config.Conversation.EmbeddingTopK,
RerankModelID: rerankModelID,
RerankTopK: e.config.Conversation.RerankTopK,
RerankThreshold: e.config.Conversation.RerankThreshold,
ChatModelID: chatModelID,
SummaryConfig: types.SummaryConfig{
MaxTokens: e.config.Conversation.Summary.MaxTokens,
RepeatPenalty: e.config.Conversation.Summary.RepeatPenalty,
TopK: e.config.Conversation.Summary.TopK,
TopP: e.config.Conversation.Summary.TopP,
Prompt: e.config.Conversation.Summary.Prompt,
ContextTemplate: e.config.Conversation.Summary.ContextTemplate,
FrequencyPenalty: e.config.Conversation.Summary.FrequencyPenalty,
PresencePenalty: e.config.Conversation.Summary.PresencePenalty,
NoMatchPrefix: e.config.Conversation.Summary.NoMatchPrefix,
Temperature: e.config.Conversation.Summary.Temperature,
Seed: e.config.Conversation.Summary.Seed,
MaxCompletionTokens: e.config.Conversation.Summary.MaxCompletionTokens,
},
FallbackResponse: e.config.Conversation.FallbackResponse,
},
}
// Store evaluation task in memory storage
logger.Info(ctx, "Registering evaluation task")
e.evaluationMemoryStorage.register(detail)
// Start evaluation in background goroutine
logger.Info(ctx, "Starting evaluation in background")
go func() {
// Create new context with logger for background task
newCtx := logger.CloneContext(ctx)
logger.Infof(newCtx, "Background evaluation started for task ID: %s", taskID)
// Update task status to running
detail.Task.Status = types.EvaluationStatueRunning
logger.Info(newCtx, "Evaluation task status set to running")
// Execute actual evaluation
if err := e.EvalDataset(newCtx, detail); err != nil {
detail.Task.Status = types.EvaluationStatueFailed
detail.Task.ErrMsg = err.Error()
logger.Errorf(newCtx, "Evaluation task failed: %v, task ID: %s", err, taskID)
return
}
// Mark task as completed successfully
logger.Infof(newCtx, "Evaluation task completed successfully, task ID: %s", taskID)
detail.Task.Status = types.EvaluationStatueSuccess
}()
logger.Infof(ctx, "Evaluation task created successfully, task ID: %s", taskID)
return detail, nil
}
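// exampleStartAndPoll is an illustrative sketch (never called) of how a caller
// might start an evaluation and then poll for its result. It assumes ctx
// already carries types.TenantIDContextKey as a uint; the dataset and model ID
// arguments are placeholders ("" falls back to the defaults above).
func exampleStartAndPoll(ctx context.Context, svc *EvaluationService) (*types.EvaluationDetail, error) {
	detail, err := svc.Evaluation(ctx, "default", "", "", "")
	if err != nil {
		return nil, err
	}
	// Evaluation runs in a background goroutine, so poll until the task settles.
	for {
		res, err := svc.EvaluationResult(ctx, detail.Task.ID)
		if err != nil {
			return nil, err
		}
		if res.Task.Status != types.EvaluationStatuePending && res.Task.Status != types.EvaluationStatueRunning {
			return res, nil
		}
		time.Sleep(time.Second)
	}
}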
// EvalDataset performs the actual evaluation of a dataset
// Processes each QA pair in parallel and records metrics
func (e *EvaluationService) EvalDataset(ctx context.Context, detail *types.EvaluationDetail) error {
logger.Info(ctx, "Start evaluating dataset")
logger.Infof(ctx, "Task ID: %s, Dataset ID: %s", detail.Task.ID, detail.Task.DatasetID)
// Retrieve dataset from storage
dataset, err := e.dataset.GetDatasetByID(ctx, detail.Task.DatasetID)
if err != nil {
logger.Errorf(ctx, "Failed to get dataset: %v", err)
return err
}
logger.Infof(ctx, "Dataset retrieved successfully with %d QA pairs", len(dataset))
// Update total QA pairs count in task details
e.evaluationMemoryStorage.update(detail.Task.ID, func(params *types.EvaluationDetail) {
params.Task.Total = len(dataset)
logger.Infof(ctx, "Updated task total to %d QA pairs", params.Task.Total)
})
// Extract and organize passages from dataset
passages := getPassageList(dataset)
logger.Infof(ctx, "Creating knowledge from %d passages", len(passages))
// Create knowledge base from passages
knowledge, err := e.knowledgeService.CreateKnowledgeFromPassage(ctx, detail.Params.KnowledgeBaseID, passages)
if err != nil {
logger.Errorf(ctx, "Failed to create knowledge from passages: %v", err)
return err
}
logger.Infof(ctx, "Knowledge created successfully, ID: %s", knowledge.ID)
// Setup cleanup of temporary resources
defer func() {
logger.Infof(ctx, "Cleaning up resources - deleting knowledge: %s", knowledge.ID)
if err := e.knowledgeService.DeleteKnowledge(ctx, knowledge.ID); err != nil {
logger.Errorf(ctx, "Failed to delete knowledge: %v, knowledge ID: %s", err, knowledge.ID)
}
logger.Infof(ctx, "Cleaning up resources - deleting knowledge base: %s", detail.Params.KnowledgeBaseID)
if err := e.knowledgeBaseService.DeleteKnowledgeBase(ctx, detail.Params.KnowledgeBaseID); err != nil {
logger.Errorf(
ctx,
"Failed to delete knowledge base: %v, knowledge base ID: %s",
err, detail.Params.KnowledgeBaseID,
)
}
}()
// Initialize parallel evaluation metrics
var finished int
var mu sync.Mutex
var g errgroup.Group
metricHook := NewHookMetric(len(dataset))
// Set worker limit based on available CPUs
	workers := max(runtime.GOMAXPROCS(0)-1, 1)
	g.SetLimit(workers)
	logger.Infof(ctx, "Starting evaluation with %d parallel workers", workers)
// Process each QA pair in parallel
for i, qaPair := range dataset {
qaPair := qaPair
i := i
g.Go(func() error {
logger.Infof(ctx, "Processing QA pair %d, question: %s", i, qaPair.Question)
// Prepare chat management parameters for this QA pair
chatManage := detail.Params.Clone()
chatManage.Query = qaPair.Question
// Execute knowledge QA pipeline
logger.Infof(ctx, "Running knowledge QA for question: %s", qaPair.Question)
			// Use a locally scoped error here; assigning to the shared err from the
			// enclosing function would race across the parallel workers.
			if err := e.sessionService.KnowledgeQAByEvent(ctx, chatManage, types.Pipline["rag"]); err != nil {
				logger.Errorf(ctx, "Failed to process question %d: %v", i, err)
				return err
			}
// Record evaluation metrics
logger.Infof(ctx, "Recording metrics for QA pair %d", i)
metricHook.recordInit(i)
metricHook.recordQaPair(i, qaPair)
metricHook.recordSearchResult(i, chatManage.SearchResult)
metricHook.recordRerankResult(i, chatManage.RerankResult)
metricHook.recordChatResponse(i, chatManage.ChatResponse)
metricHook.recordFinish(i)
// Update progress metrics
mu.Lock()
finished += 1
metricResult := metricHook.MetricResult()
mu.Unlock()
e.evaluationMemoryStorage.update(detail.Task.ID, func(params *types.EvaluationDetail) {
params.Metric = metricResult
params.Task.Finished = finished
logger.Infof(ctx, "Updated task progress: %d/%d completed", finished, params.Task.Total)
})
return nil
})
}
// Wait for all parallel evaluations to complete
logger.Info(ctx, "Waiting for all evaluation tasks to complete")
if err := g.Wait(); err != nil {
logger.Errorf(ctx, "Evaluation error: %v", err)
return err
}
// Final update of evaluation metrics
e.evaluationMemoryStorage.update(detail.Task.ID, func(params *types.EvaluationDetail) {
params.Metric = metricHook.MetricResult()
params.Task.Finished = finished
})
logger.Infof(ctx, "Dataset evaluation completed successfully, task ID: %s", detail.Task.ID)
return nil
}
// getPassageList extracts and deduplicates passages from the QA pairs.
// The returned slice is indexed by passage ID, so passages[pid] holds the
// content for pid; IDs that never appear in the dataset are left empty.
func getPassageList(dataset []*types.QAPair) []string {
	pIDMap := make(map[int]string)
	maxPID := 0
	for _, qaPair := range dataset {
		for i := 0; i < len(qaPair.PIDs); i++ {
			pIDMap[qaPair.PIDs[i]] = qaPair.Passages[i]
			maxPID = max(maxPID, qaPair.PIDs[i])
		}
	}
	// maxPID itself is a valid index, so the slice needs maxPID+1 slots.
	passages := make([]string, maxPID+1)
	for pid, passage := range pIDMap {
		passages[pid] = passage
	}
	return passages
}
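// For instance (hypothetical data): QA pairs carrying PIDs [0, 2] and [1]
// produce a slice of length 3 in which passages[0], passages[1] and passages[2]
// each hold the passage text registered under that ID.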