408 lines
14 KiB
Go
408 lines
14 KiB
Go
package service
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"runtime"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/google/uuid"
|
|
"golang.org/x/sync/errgroup"
|
|
"knowlege-lsxd/internal/config"
|
|
"knowlege-lsxd/internal/logger"
|
|
"knowlege-lsxd/internal/types"
|
|
"knowlege-lsxd/internal/types/interfaces"
|
|
)
|
|
|
|
/*
corpus:  pid -> content
queries: qid -> content
answers: aid -> content
qrels:   qid -> pid
arels:   qid -> aid
*/
|
|
|
|
// EvaluationService handles evaluation tasks for knowledge base and chat models
|
|
type EvaluationService struct {
|
|
config *config.Config // Application configuration
|
|
dataset interfaces.DatasetService // Service for dataset operations
|
|
knowledgeBaseService interfaces.KnowledgeBaseService // Service for knowledge base operations
|
|
knowledgeService interfaces.KnowledgeService // Service for knowledge operations
|
|
sessionService interfaces.SessionService // Service for chat sessions
|
|
testData *TestDataService // Service for test data
|
|
|
|
evaluationMemoryStorage *evaluationMemoryStorage // In-memory storage for evaluation tasks
|
|
}
|
|
|
|
func NewEvaluationService(
|
|
config *config.Config,
|
|
dataset interfaces.DatasetService,
|
|
knowledgeBaseService interfaces.KnowledgeBaseService,
|
|
knowledgeService interfaces.KnowledgeService,
|
|
sessionService interfaces.SessionService,
|
|
testData *TestDataService,
|
|
) interfaces.EvaluationService {
|
|
evaluationMemoryStorage := newEvaluationMemoryStorage()
|
|
return &EvaluationService{
|
|
config: config,
|
|
dataset: dataset,
|
|
knowledgeBaseService: knowledgeBaseService,
|
|
knowledgeService: knowledgeService,
|
|
sessionService: sessionService,
|
|
testData: testData,
|
|
evaluationMemoryStorage: evaluationMemoryStorage,
|
|
}
|
|
}
|
|
|
|
// evaluationMemoryStorage stores evaluation tasks in memory with thread-safe access
|
|
type evaluationMemoryStorage struct {
|
|
store map[string]*types.EvaluationDetail // Map of taskID to evaluation details
|
|
mu *sync.RWMutex // Read-write lock for concurrent access
|
|
}
|
|
|
|
func newEvaluationMemoryStorage() *evaluationMemoryStorage {
|
|
res := &evaluationMemoryStorage{
|
|
store: make(map[string]*types.EvaluationDetail),
|
|
mu: &sync.RWMutex{},
|
|
}
|
|
return res
|
|
}
|
|
|
|
func (e *evaluationMemoryStorage) register(params *types.EvaluationDetail) {
|
|
e.mu.Lock()
|
|
defer e.mu.Unlock()
|
|
logger.Infof(context.Background(), "Registering evaluation task: %s", params.Task.ID)
|
|
e.store[params.Task.ID] = params
|
|
}
|
|
|
|
func (e *evaluationMemoryStorage) get(taskID string) (*types.EvaluationDetail, error) {
|
|
e.mu.RLock()
|
|
defer e.mu.RUnlock()
|
|
logger.Infof(context.Background(), "Getting evaluation task: %s", taskID)
|
|
res, ok := e.store[taskID]
|
|
if !ok {
|
|
return nil, errors.New("task not found")
|
|
}
|
|
return res, nil
|
|
}
|
|
|
|
func (e *evaluationMemoryStorage) update(taskID string, fn func(params *types.EvaluationDetail)) error {
|
|
e.mu.Lock()
|
|
defer e.mu.Unlock()
|
|
params, ok := e.store[taskID]
|
|
if !ok {
|
|
return errors.New("task not found")
|
|
}
|
|
fn(params)
|
|
return nil
|
|
}
|
|
|
|
func (e *EvaluationService) EvaluationResult(ctx context.Context, taskID string) (*types.EvaluationDetail, error) {
|
|
logger.Info(ctx, "Start getting evaluation result")
|
|
logger.Infof(ctx, "Task ID: %s", taskID)
|
|
|
|
detail, err := e.evaluationMemoryStorage.get(taskID)
|
|
if err != nil {
|
|
logger.Errorf(ctx, "Failed to get evaluation task: %v", err)
|
|
return nil, err
|
|
}
|
|
|
|
tenantID := ctx.Value(types.TenantIDContextKey).(uint)
|
|
logger.Infof(
|
|
ctx,
|
|
"Checking tenant ID match, task tenant ID: %d, current tenant ID: %d",
|
|
detail.Task.TenantID, tenantID,
|
|
)
|
|
|
|
if tenantID != detail.Task.TenantID {
|
|
logger.Error(ctx, "Tenant ID mismatch")
|
|
return nil, errors.New("tenant ID does not match")
|
|
}
|
|
|
|
logger.Info(ctx, "Evaluation result retrieved successfully")
|
|
return detail, nil
|
|
}
|
|
|
|
// Evaluation starts a new evaluation task with given parameters
|
|
// datasetID: ID of the dataset to evaluate against
|
|
// knowledgeBaseID: ID of the knowledge base to use (empty to create new)
|
|
// chatModelID: ID of the chat model to evaluate
|
|
// rerankModelID: ID of the rerank model to evaluate
|
|
func (e *EvaluationService) Evaluation(ctx context.Context,
|
|
datasetID string, knowledgeBaseID string, chatModelID string, rerankModelID string,
|
|
) (*types.EvaluationDetail, error) {
|
|
logger.Info(ctx, "Start evaluation")
|
|
logger.Infof(ctx, "Dataset ID: %s, Knowledge Base ID: %s, Chat Model ID: %s, Rerank Model ID: %s",
|
|
datasetID, knowledgeBaseID, chatModelID, rerankModelID)
|
|
|
|
// Get tenant ID from context for multi-tenancy support
|
|
tenantID := ctx.Value(types.TenantIDContextKey).(uint)
|
|
logger.Infof(ctx, "Tenant ID: %d", tenantID)
|
|
|
|
// Handle knowledge base creation if not provided
|
|
if knowledgeBaseID == "" {
|
|
logger.Info(ctx, "No knowledge base ID provided, creating new knowledge base")
|
|
// Create new knowledge base with default evaluation settings
|
|
kb, err := e.knowledgeBaseService.CreateKnowledgeBase(ctx, &types.KnowledgeBase{
|
|
Name: "evaluation",
|
|
Description: "evaluation",
|
|
EmbeddingModelID: e.testData.EmbedModel.GetModelID(),
|
|
SummaryModelID: e.testData.LLMModel.GetModelID(),
|
|
})
|
|
if err != nil {
|
|
logger.Errorf(ctx, "Failed to create knowledge base: %v", err)
|
|
return nil, err
|
|
}
|
|
knowledgeBaseID = kb.ID
|
|
logger.Infof(ctx, "Created new knowledge base with ID: %s", knowledgeBaseID)
|
|
} else {
|
|
logger.Infof(ctx, "Using existing knowledge base ID: %s", knowledgeBaseID)
|
|
// Create evaluation-specific knowledge base based on existing one
|
|
kb, err := e.knowledgeBaseService.GetKnowledgeBaseByID(ctx, knowledgeBaseID)
|
|
if err != nil {
|
|
logger.Errorf(ctx, "Failed to get knowledge base: %v", err)
|
|
return nil, err
|
|
}
|
|
|
|
kb, err = e.knowledgeBaseService.CreateKnowledgeBase(ctx, &types.KnowledgeBase{
|
|
Name: "evaluation",
|
|
Description: "evaluation",
|
|
EmbeddingModelID: kb.EmbeddingModelID,
|
|
SummaryModelID: kb.SummaryModelID,
|
|
})
|
|
if err != nil {
|
|
logger.Errorf(ctx, "Failed to create knowledge base: %v", err)
|
|
return nil, err
|
|
}
|
|
knowledgeBaseID = kb.ID
|
|
logger.Infof(ctx, "Created new knowledge base with ID: %s based on existing one", knowledgeBaseID)
|
|
}
|
|
|
|
// Set default values for optional parameters
|
|
if datasetID == "" {
|
|
datasetID = "default"
|
|
logger.Info(ctx, "Using default dataset")
|
|
}
|
|
|
|
if rerankModelID == "" {
|
|
rerankModelID = e.testData.RerankModel.GetModelID()
|
|
logger.Infof(ctx, "Using default rerank model: %s", rerankModelID)
|
|
}
|
|
|
|
if chatModelID == "" {
|
|
chatModelID = e.testData.LLMModel.GetModelID()
|
|
logger.Infof(ctx, "Using default chat model: %s", chatModelID)
|
|
}
|
|
|
|
// Create evaluation task with unique ID
|
|
logger.Info(ctx, "Creating evaluation task")
|
|
taskID := uuid.New().String()
|
|
logger.Infof(ctx, "Generated task ID: %s", taskID)
|
|
|
|
// Prepare evaluation detail with all parameters
|
|
detail := &types.EvaluationDetail{
|
|
Task: &types.EvaluationTask{
|
|
ID: taskID,
|
|
TenantID: tenantID,
|
|
DatasetID: datasetID,
|
|
Status: types.EvaluationStatuePending,
|
|
StartTime: time.Now(),
|
|
},
|
|
Params: &types.ChatManage{
|
|
KnowledgeBaseID: knowledgeBaseID,
|
|
VectorThreshold: e.config.Conversation.VectorThreshold,
|
|
KeywordThreshold: e.config.Conversation.KeywordThreshold,
|
|
EmbeddingTopK: e.config.Conversation.EmbeddingTopK,
|
|
RerankModelID: rerankModelID,
|
|
RerankTopK: e.config.Conversation.RerankTopK,
|
|
RerankThreshold: e.config.Conversation.RerankThreshold,
|
|
ChatModelID: chatModelID,
|
|
SummaryConfig: types.SummaryConfig{
|
|
MaxTokens: e.config.Conversation.Summary.MaxTokens,
|
|
RepeatPenalty: e.config.Conversation.Summary.RepeatPenalty,
|
|
TopK: e.config.Conversation.Summary.TopK,
|
|
TopP: e.config.Conversation.Summary.TopP,
|
|
Prompt: e.config.Conversation.Summary.Prompt,
|
|
ContextTemplate: e.config.Conversation.Summary.ContextTemplate,
|
|
FrequencyPenalty: e.config.Conversation.Summary.FrequencyPenalty,
|
|
PresencePenalty: e.config.Conversation.Summary.PresencePenalty,
|
|
NoMatchPrefix: e.config.Conversation.Summary.NoMatchPrefix,
|
|
Temperature: e.config.Conversation.Summary.Temperature,
|
|
Seed: e.config.Conversation.Summary.Seed,
|
|
MaxCompletionTokens: e.config.Conversation.Summary.MaxCompletionTokens,
|
|
},
|
|
FallbackResponse: e.config.Conversation.FallbackResponse,
|
|
},
|
|
}
|
|
|
|
// Store evaluation task in memory storage
|
|
logger.Info(ctx, "Registering evaluation task")
|
|
e.evaluationMemoryStorage.register(detail)
|
|
|
|
// Start evaluation in background goroutine
|
|
logger.Info(ctx, "Starting evaluation in background")
|
|
go func() {
|
|
// Create new context with logger for background task
|
|
newCtx := logger.CloneContext(ctx)
|
|
logger.Infof(newCtx, "Background evaluation started for task ID: %s", taskID)
|
|
|
|
// Update task status to running
|
|
detail.Task.Status = types.EvaluationStatueRunning
|
|
logger.Info(newCtx, "Evaluation task status set to running")
|
|
|
|
// Execute actual evaluation
|
|
if err := e.EvalDataset(newCtx, detail); err != nil {
|
|
detail.Task.Status = types.EvaluationStatueFailed
|
|
detail.Task.ErrMsg = err.Error()
|
|
logger.Errorf(newCtx, "Evaluation task failed: %v, task ID: %s", err, taskID)
|
|
return
|
|
}
|
|
|
|
// Mark task as completed successfully
|
|
logger.Infof(newCtx, "Evaluation task completed successfully, task ID: %s", taskID)
|
|
detail.Task.Status = types.EvaluationStatueSuccess
|
|
}()
|
|
|
|
logger.Infof(ctx, "Evaluation task created successfully, task ID: %s", taskID)
|
|
return detail, nil
|
|
}
|
|
|
|
// EvalDataset performs the actual evaluation of a dataset
|
|
// Processes each QA pair in parallel and records metrics
|
|
func (e *EvaluationService) EvalDataset(ctx context.Context, detail *types.EvaluationDetail) error {
|
|
logger.Info(ctx, "Start evaluating dataset")
|
|
logger.Infof(ctx, "Task ID: %s, Dataset ID: %s", detail.Task.ID, detail.Task.DatasetID)
|
|
|
|
// Retrieve dataset from storage
|
|
dataset, err := e.dataset.GetDatasetByID(ctx, detail.Task.DatasetID)
|
|
if err != nil {
|
|
logger.Errorf(ctx, "Failed to get dataset: %v", err)
|
|
return err
|
|
}
|
|
logger.Infof(ctx, "Dataset retrieved successfully with %d QA pairs", len(dataset))
|
|
|
|
// Update total QA pairs count in task details
|
|
e.evaluationMemoryStorage.update(detail.Task.ID, func(params *types.EvaluationDetail) {
|
|
params.Task.Total = len(dataset)
|
|
logger.Infof(ctx, "Updated task total to %d QA pairs", params.Task.Total)
|
|
})
|
|
|
|
// Extract and organize passages from dataset
|
|
passages := getPassageList(dataset)
|
|
logger.Infof(ctx, "Creating knowledge from %d passages", len(passages))
|
|
|
|
// Create knowledge base from passages
|
|
knowledge, err := e.knowledgeService.CreateKnowledgeFromPassage(ctx, detail.Params.KnowledgeBaseID, passages)
|
|
if err != nil {
|
|
logger.Errorf(ctx, "Failed to create knowledge from passages: %v", err)
|
|
return err
|
|
}
|
|
logger.Infof(ctx, "Knowledge created successfully, ID: %s", knowledge.ID)
|
|
|
|
// Setup cleanup of temporary resources
|
|
defer func() {
|
|
logger.Infof(ctx, "Cleaning up resources - deleting knowledge: %s", knowledge.ID)
|
|
if err := e.knowledgeService.DeleteKnowledge(ctx, knowledge.ID); err != nil {
|
|
logger.Errorf(ctx, "Failed to delete knowledge: %v, knowledge ID: %s", err, knowledge.ID)
|
|
}
|
|
|
|
logger.Infof(ctx, "Cleaning up resources - deleting knowledge base: %s", detail.Params.KnowledgeBaseID)
|
|
if err := e.knowledgeBaseService.DeleteKnowledgeBase(ctx, detail.Params.KnowledgeBaseID); err != nil {
|
|
logger.Errorf(
|
|
ctx,
|
|
"Failed to delete knowledge base: %v, knowledge base ID: %s",
|
|
err, detail.Params.KnowledgeBaseID,
|
|
)
|
|
}
|
|
}()
|
|
|
|
// Initialize parallel evaluation metrics
|
|
var finished int
|
|
var mu sync.Mutex
|
|
var g errgroup.Group
|
|
metricHook := NewHookMetric(len(dataset))
|
|
|
|
// Set worker limit based on available CPUs
|
|
g.SetLimit(max(runtime.GOMAXPROCS(0)-1, 1))
|
|
logger.Infof(ctx, "Starting evaluation with %d parallel workers", max(runtime.GOMAXPROCS(0)-1, 1))
|
|
|
|
// Process each QA pair in parallel
|
|
for i, qaPair := range dataset {
|
|
qaPair := qaPair
|
|
i := i
|
|
g.Go(func() error {
|
|
logger.Infof(ctx, "Processing QA pair %d, question: %s", i, qaPair.Question)
|
|
|
|
// Prepare chat management parameters for this QA pair
|
|
chatManage := detail.Params.Clone()
|
|
chatManage.Query = qaPair.Question
|
|
|
|
// Execute knowledge QA pipeline
|
|
logger.Infof(ctx, "Running knowledge QA for question: %s", qaPair.Question)
|
|
err = e.sessionService.KnowledgeQAByEvent(ctx, chatManage, types.Pipline["rag"])
|
|
if err != nil {
|
|
logger.Errorf(ctx, "Failed to process question %d: %v", i, err)
|
|
return err
|
|
}
|
|
|
|
// Record evaluation metrics
|
|
logger.Infof(ctx, "Recording metrics for QA pair %d", i)
|
|
metricHook.recordInit(i)
|
|
metricHook.recordQaPair(i, qaPair)
|
|
metricHook.recordSearchResult(i, chatManage.SearchResult)
|
|
metricHook.recordRerankResult(i, chatManage.RerankResult)
|
|
metricHook.recordChatResponse(i, chatManage.ChatResponse)
|
|
metricHook.recordFinish(i)
|
|
|
|
// Update progress metrics
|
|
mu.Lock()
|
|
finished += 1
|
|
metricResult := metricHook.MetricResult()
|
|
mu.Unlock()
|
|
e.evaluationMemoryStorage.update(detail.Task.ID, func(params *types.EvaluationDetail) {
|
|
params.Metric = metricResult
|
|
params.Task.Finished = finished
|
|
logger.Infof(ctx, "Updated task progress: %d/%d completed", finished, params.Task.Total)
|
|
})
|
|
return nil
|
|
})
|
|
}
|
|
|
|
// Wait for all parallel evaluations to complete
|
|
logger.Info(ctx, "Waiting for all evaluation tasks to complete")
|
|
if err := g.Wait(); err != nil {
|
|
logger.Errorf(ctx, "Evaluation error: %v", err)
|
|
return err
|
|
}
|
|
|
|
// Final update of evaluation metrics
|
|
e.evaluationMemoryStorage.update(detail.Task.ID, func(params *types.EvaluationDetail) {
|
|
params.Metric = metricHook.MetricResult()
|
|
params.Task.Finished = finished
|
|
})
|
|
|
|
logger.Infof(ctx, "Dataset evaluation completed successfully, task ID: %s", detail.Task.ID)
|
|
return nil
|
|
}
|
|
|
|
// getPassageList extracts and organizes passages from QA pairs
|
|
// Returns a slice of passages indexed by their passage IDs
|
|
func getPassageList(dataset []*types.QAPair) []string {
|
|
pIDMap := make(map[int]string)
|
|
maxPID := 0
|
|
for _, qaPair := range dataset {
|
|
for i := 0; i < len(qaPair.PIDs); i++ {
|
|
pIDMap[qaPair.PIDs[i]] = qaPair.Passages[i]
|
|
maxPID = max(maxPID, qaPair.PIDs[i])
|
|
}
|
|
}
|
|
passages := make([]string, maxPID)
|
|
for i := 0; i < maxPID; i++ {
|
|
if _, ok := pIDMap[i]; ok {
|
|
passages[i] = pIDMap[i]
|
|
}
|
|
}
|
|
return passages
|
|
}
|