l_ai_knowledge/internal/application/service/knowledgebase.go

610 lines
19 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package service
import (
"context"
"errors"
"knowlege-lsxd/internal/config"
"slices"
"time"
"encoding/json"
"knowlege-lsxd/internal/application/service/retriever"
"knowlege-lsxd/internal/common"
"knowlege-lsxd/internal/logger"
"knowlege-lsxd/internal/models/embedding"
"knowlege-lsxd/internal/types"
"knowlege-lsxd/internal/types/interfaces"
"github.com/google/uuid"
)
var (
	// ErrInvalidTenantID is the sentinel error for a missing or invalid
	// tenant ID. Callers can match it with errors.Is.
	ErrInvalidTenantID = errors.New("invalid tenant ID")
)
// knowledgeBaseService implements the knowledge base service interface
// (interfaces.KnowledgeBaseService, as returned by NewKnowledgeBaseService).
type knowledgeBaseService struct {
// repo reads and writes knowledge base records (create/get/list/update/delete).
repo interfaces.KnowledgeBaseRepository
// kgRepo loads knowledge entries in batch to attach document info to search hits.
kgRepo interfaces.KnowledgeRepository
// chunkRepo loads chunks by ID when assembling search results.
chunkRepo interfaces.ChunkRepository
// modelService resolves the embedding model used for vector retrieval.
modelService interfaces.ModelService
}
// NewKnowledgeBaseService wires the knowledge base, knowledge and chunk
// repositories together with the model service into a KnowledgeBaseService.
//
// cfg is currently unused by this constructor.
func NewKnowledgeBaseService(
	cfg *config.Config,
	repo interfaces.KnowledgeBaseRepository,
	kgRepo interfaces.KnowledgeRepository,
	chunkRepo interfaces.ChunkRepository,
	modelService interfaces.ModelService,
) interfaces.KnowledgeBaseService {
	svc := new(knowledgeBaseService)
	svc.repo = repo
	svc.kgRepo = kgRepo
	svc.chunkRepo = chunkRepo
	svc.modelService = modelService
	return svc
}
// CreateKnowledgeBase persists a new knowledge base owned by the tenant in ctx.
//
// If kb.ID is empty a fresh UUID is assigned. CreatedAt and UpdatedAt are set
// to the same timestamp, and TenantID is always taken from the request context
// (any caller-supplied value is overwritten).
//
// Returns an error when kb is nil, ErrInvalidTenantID when ctx carries no uint
// tenant ID under types.TenantIDContextKey, or the repository error if the
// insert fails.
func (s *knowledgeBaseService) CreateKnowledgeBase(ctx context.Context,
	kb *types.KnowledgeBase,
) (*types.KnowledgeBase, error) {
	if kb == nil {
		logger.Error(ctx, "Knowledge base is nil")
		return nil, errors.New("knowledge base cannot be nil")
	}
	// Comma-ok assertion: a missing or mistyped tenant ID must produce an
	// error rather than a runtime panic.
	tenantID, ok := ctx.Value(types.TenantIDContextKey).(uint)
	if !ok {
		logger.Error(ctx, "Tenant ID not found in context")
		return nil, ErrInvalidTenantID
	}

	// Generate a UUID only when the caller did not supply one.
	if kb.ID == "" {
		kb.ID = uuid.New().String()
	}
	// Take a single timestamp so CreatedAt and UpdatedAt are identical on creation.
	now := time.Now()
	kb.CreatedAt = now
	kb.UpdatedAt = now
	kb.TenantID = tenantID

	logger.Infof(ctx, "Creating knowledge base, ID: %s, tenant ID: %d, name: %s", kb.ID, kb.TenantID, kb.Name)
	if err := s.repo.CreateKnowledgeBase(ctx, kb); err != nil {
		logger.ErrorWithFields(ctx, err, map[string]interface{}{
			"knowledge_base_id": kb.ID,
			"tenant_id":         kb.TenantID,
		})
		return nil, err
	}
	logger.Infof(ctx, "Knowledge base created successfully, ID: %s, name: %s", kb.ID, kb.Name)
	return kb, nil
}
// GetKnowledgeBaseByID loads a single knowledge base from the repository.
//
// Returns an error if id is empty, or the repository error (logged with the
// knowledge base ID) if the lookup fails.
func (s *knowledgeBaseService) GetKnowledgeBaseByID(ctx context.Context, id string) (*types.KnowledgeBase, error) {
	if len(id) == 0 {
		logger.Error(ctx, "Knowledge base ID is empty")
		return nil, errors.New("knowledge base ID cannot be empty")
	}

	logger.Infof(ctx, "Retrieving knowledge base, ID: %s", id)
	found, lookupErr := s.repo.GetKnowledgeBaseByID(ctx, id)
	if lookupErr != nil {
		fields := map[string]interface{}{"knowledge_base_id": id}
		logger.ErrorWithFields(ctx, lookupErr, fields)
		return nil, lookupErr
	}

	logger.Infof(ctx, "Knowledge base retrieved successfully, ID: %s, name: %s", found.ID, found.Name)
	return found, nil
}
// ListKnowledgeBases returns all knowledge bases owned by the tenant in ctx.
//
// Returns ErrInvalidTenantID when ctx carries no uint tenant ID under
// types.TenantIDContextKey, or the repository error if the lookup fails.
func (s *knowledgeBaseService) ListKnowledgeBases(ctx context.Context) ([]*types.KnowledgeBase, error) {
	// Comma-ok assertion: a missing tenant ID must not panic the request.
	tenantID, ok := ctx.Value(types.TenantIDContextKey).(uint)
	if !ok {
		logger.Error(ctx, "Tenant ID not found in context")
		return nil, ErrInvalidTenantID
	}

	logger.Infof(ctx, "Retrieving knowledge base list for tenant, tenant ID: %d", tenantID)
	kbs, err := s.repo.ListKnowledgeBasesByTenantID(ctx, tenantID)
	if err != nil {
		logger.ErrorWithFields(ctx, err, map[string]interface{}{
			"tenant_id": tenantID,
		})
		return nil, err
	}
	logger.Infof(
		ctx,
		"Knowledge base list retrieved successfully, tenant ID: %d, knowledge base count: %d",
		tenantID,
		len(kbs),
	)
	return kbs, nil
}
// UpdateKnowledgeBase updates a knowledge base's name, description and,
// optionally, its chunking and image-processing configuration.
//
// A nil config leaves the existing chunking and image-processing settings
// unchanged (previously a nil config caused a nil-pointer panic). UpdatedAt is
// always refreshed.
//
// Returns an error if id is empty, the knowledge base cannot be loaded, or the
// update fails.
func (s *knowledgeBaseService) UpdateKnowledgeBase(ctx context.Context,
	id string,
	name string,
	description string,
	config *types.KnowledgeBaseConfig,
) (*types.KnowledgeBase, error) {
	if id == "" {
		logger.Error(ctx, "Knowledge base ID is empty")
		return nil, errors.New("knowledge base ID cannot be empty")
	}
	logger.Infof(ctx, "Updating knowledge base, ID: %s, name: %s", id, name)

	// Load the existing record so unspecified fields are preserved.
	kb, err := s.repo.GetKnowledgeBaseByID(ctx, id)
	if err != nil {
		logger.ErrorWithFields(ctx, err, map[string]interface{}{
			"knowledge_base_id": id,
		})
		return nil, err
	}

	kb.Name = name
	kb.Description = description
	// Guard against a nil config instead of dereferencing it.
	if config != nil {
		kb.ChunkingConfig = config.ChunkingConfig
		kb.ImageProcessingConfig = config.ImageProcessingConfig
	}
	kb.UpdatedAt = time.Now()

	logger.Info(ctx, "Saving knowledge base update")
	if err := s.repo.UpdateKnowledgeBase(ctx, kb); err != nil {
		logger.ErrorWithFields(ctx, err, map[string]interface{}{
			"knowledge_base_id": id,
		})
		return nil, err
	}
	logger.Infof(ctx, "Knowledge base updated successfully, ID: %s, name: %s", kb.ID, kb.Name)
	return kb, nil
}
// DeleteKnowledgeBase removes the knowledge base with the given ID through the
// repository.
//
// Returns an error if id is empty or the repository delete fails.
// NOTE(review): only the repository delete is invoked here; whether associated
// knowledge entries and chunks are removed depends on the repository — verify.
func (s *knowledgeBaseService) DeleteKnowledgeBase(ctx context.Context, id string) error {
	if len(id) == 0 {
		logger.Error(ctx, "Knowledge base ID is empty")
		return errors.New("knowledge base ID cannot be empty")
	}

	logger.Infof(ctx, "Deleting knowledge base, ID: %s", id)
	if delErr := s.repo.DeleteKnowledgeBase(ctx, id); delErr != nil {
		logger.ErrorWithFields(ctx, delErr, map[string]interface{}{"knowledge_base_id": id})
		return delErr
	}

	logger.Infof(ctx, "Knowledge base deleted successfully, ID: %s", id)
	return nil
}
// SetEmbeddingModel points the knowledge base identified by id at the embedding
// model modelID and persists the change, refreshing UpdatedAt.
//
// Returns an error if either ID is empty, the knowledge base cannot be loaded,
// or the update fails. NOTE(review): modelID is not checked against the model
// service here — confirm whether callers validate it.
func (s *knowledgeBaseService) SetEmbeddingModel(ctx context.Context, id string, modelID string) error {
	switch {
	case id == "":
		logger.Error(ctx, "Knowledge base ID is empty")
		return errors.New("knowledge base ID cannot be empty")
	case modelID == "":
		logger.Error(ctx, "Model ID is empty")
		return errors.New("model ID cannot be empty")
	}
	logger.Infof(ctx, "Setting embedding model for knowledge base, knowledge base ID: %s, model ID: %s", id, modelID)

	target, err := s.repo.GetKnowledgeBaseByID(ctx, id)
	if err != nil {
		logger.ErrorWithFields(ctx, err, map[string]interface{}{
			"knowledge_base_id": id,
		})
		return err
	}

	target.EmbeddingModelID, target.UpdatedAt = modelID, time.Now()

	logger.Info(ctx, "Saving knowledge base embedding model update")
	if err := s.repo.UpdateKnowledgeBase(ctx, target); err != nil {
		logger.ErrorWithFields(ctx, err, map[string]interface{}{
			"embedding_model_id": modelID,
			"knowledge_base_id":  id,
		})
		return err
	}

	logger.Infof(ctx,
		"Knowledge base embedding model set successfully, knowledge base ID: %s, model ID: %s",
		id, modelID)
	return nil
}
// CopyKnowledgeBase resolves the source and destination knowledge bases for a
// copy operation. This is a shallow copy: only the knowledge base record and its
// configuration are handled here; this method does not copy knowledge entries
// or chunks.
//
// When dstKB is non-empty, the existing destination is loaded and returned
// unchanged. When dstKB is empty, a new knowledge base is created for the
// tenant in ctx, with the source's name, description, chunking/image
// configuration, model IDs and storage configuration.
//
// Returns (source, target, nil) on success. Otherwise it returns an error when
// srcKB is empty, when either lookup fails, when the tenant ID is missing from
// ctx (ErrInvalidTenantID, create path only), or when creation fails.
func (s *knowledgeBaseService) CopyKnowledgeBase(ctx context.Context,
	srcKB string, dstKB string,
) (*types.KnowledgeBase, *types.KnowledgeBase, error) {
	if srcKB == "" {
		logger.Error(ctx, "Source knowledge base ID is empty")
		return nil, nil, errors.New("source knowledge base ID cannot be empty")
	}
	sourceKB, err := s.repo.GetKnowledgeBaseByID(ctx, srcKB)
	if err != nil {
		logger.Errorf(ctx, "Get source knowledge base failed: %v", err)
		return nil, nil, err
	}

	// An explicit destination is reused as-is; its configuration is not
	// overwritten with the source's.
	if dstKB != "" {
		targetKB, err := s.repo.GetKnowledgeBaseByID(ctx, dstKB)
		if err != nil {
			logger.Errorf(ctx, "Get target knowledge base failed: %v", err)
			return nil, nil, err
		}
		return sourceKB, targetKB, nil
	}

	// The tenant ID is only needed when creating a new target. Use the
	// comma-ok form so a missing value returns an error instead of panicking.
	tenantID, ok := ctx.Value(types.TenantIDContextKey).(uint)
	if !ok {
		logger.Error(ctx, "Tenant ID not found in context")
		return nil, nil, ErrInvalidTenantID
	}

	// Set timestamps explicitly, as CreateKnowledgeBase does, so the copy
	// is not stored with zero CreatedAt/UpdatedAt.
	now := time.Now()
	targetKB := &types.KnowledgeBase{
		ID:                    uuid.New().String(),
		Name:                  sourceKB.Name,
		Description:           sourceKB.Description,
		TenantID:              tenantID,
		ChunkingConfig:        sourceKB.ChunkingConfig,
		ImageProcessingConfig: sourceKB.ImageProcessingConfig,
		EmbeddingModelID:      sourceKB.EmbeddingModelID,
		SummaryModelID:        sourceKB.SummaryModelID,
		RerankModelID:         sourceKB.RerankModelID,
		VLMModelID:            sourceKB.VLMModelID,
		StorageConfig:         sourceKB.StorageConfig,
		CreatedAt:             now,
		UpdatedAt:             now,
	}
	if err := s.repo.CreateKnowledgeBase(ctx, targetKB); err != nil {
		logger.ErrorWithFields(ctx, err, map[string]interface{}{
			"source_knowledge_base_id": srcKB,
			"knowledge_base_id":        targetKB.ID,
			"tenant_id":                tenantID,
		})
		return nil, nil, err
	}
	logger.Infof(ctx, "Knowledge base copy target created, source ID: %s, target ID: %s", srcKB, targetKB.ID)
	return sourceKB, targetKB, nil
}
// HybridSearch runs a hybrid search over one knowledge base, combining vector
// (embedding) retrieval and keyword retrieval, depending on which retrievers
// the tenant's configured engines support.
//
// Parameters:
//   - ctx: request context. It must carry the tenant (*types.Tenant) under
//     types.TenantInfoContextKey. Result processing also reads the tenant ID
//     from ctx.
//   - id: ID of the knowledge base to search.
//   - params: search parameters (query text, match count, vector and keyword
//     thresholds).
//
// Returns:
//   - []*types.SearchResult: results deduplicated by chunk ID, or nil when
//     nothing matched.
//   - error: ErrInvalidTenantID when the tenant info is missing from ctx, or
//     the first error from engine construction, knowledge base lookup,
//     embedding, retrieval or result processing.
func (s *knowledgeBaseService) HybridSearch(ctx context.Context,
	id string,
	params types.SearchParams,
) ([]*types.SearchResult, error) {
	// Log the knowledge base ID and query text for traceability.
	logger.Infof(ctx, "Hybrid search parameters, knowledge base ID: %s, query text: %s", id, params.QueryText)

	// The tenant determines which retrieval engines are available. Use the
	// comma-ok form plus a nil check so a missing tenant yields an error
	// instead of a runtime panic.
	tenantInfo, ok := ctx.Value(types.TenantInfoContextKey).(*types.Tenant)
	if !ok || tenantInfo == nil {
		logger.Error(ctx, "Tenant info not found in context")
		return nil, ErrInvalidTenantID
	}
	logger.Infof(ctx, "Creating composite retrieval engine, tenant ID: %d", tenantInfo.ID)

	// Create a composite retrieval engine with the tenant's configured retrievers.
	retrieveEngine, err := retriever.NewCompositeRetrieveEngine(tenantInfo.RetrieverEngines.Engines)
	if err != nil {
		logger.Errorf(ctx, "Failed to create retrieval engine: %v", err)
		return nil, err
	}

	var retrieveParams []types.RetrieveParams

	// Vector retrieval needs the knowledge base's embedding model to embed the query.
	if retrieveEngine.SupportRetriever(types.VectorRetrieverType) {
		logger.Info(ctx, "Vector retrieval supported, preparing vector retrieval parameters")
		var (
			kb             *types.KnowledgeBase
			embeddingModel embedding.Embedder
		)
		kb, err = s.repo.GetKnowledgeBaseByID(ctx, id)
		if err != nil {
			logger.ErrorWithFields(ctx, err, map[string]interface{}{
				"knowledge_base_id": id,
			})
			return nil, err
		}

		logger.Infof(ctx, "Getting embedding model, model ID: %s", kb.EmbeddingModelID)
		embeddingModel, err = s.modelService.GetEmbeddingModel(ctx, kb.EmbeddingModelID)
		if err != nil {
			logger.Errorf(ctx, "Failed to get embedding model, model ID: %s, error: %v", kb.EmbeddingModelID, err)
			return nil, err
		}
		logger.Infof(ctx, "Embedding model retrieved: %v", embeddingModel)

		// Generate the embedding vector for the query text.
		logger.Info(ctx, "Starting to generate query embedding")
		queryEmbedding, embedErr := embeddingModel.Embed(ctx, params.QueryText)
		if embedErr != nil {
			logger.Errorf(ctx, "Failed to embed query text, query text: %s, error: %v", params.QueryText, embedErr)
			return nil, embedErr
		}
		logger.Infof(ctx, "Query embedding generated successfully, embedding vector length: %d", len(queryEmbedding))

		retrieveParams = append(retrieveParams, types.RetrieveParams{
			Query:            params.QueryText,
			Embedding:        queryEmbedding,
			KnowledgeBaseIDs: []string{id},
			TopK:             params.MatchCount,
			Threshold:        params.VectorThreshold,
			RetrieverType:    types.VectorRetrieverType,
		})
		logger.Info(ctx, "Vector retrieval parameters setup completed")
	}

	// Keyword retrieval only needs the raw query text.
	if retrieveEngine.SupportRetriever(types.KeywordsRetrieverType) {
		logger.Info(ctx, "Keyword retrieval supported, preparing keyword retrieval parameters")
		retrieveParams = append(retrieveParams, types.RetrieveParams{
			Query:            params.QueryText,
			KnowledgeBaseIDs: []string{id},
			TopK:             params.MatchCount,
			Threshold:        params.KeywordThreshold,
			RetrieverType:    types.KeywordsRetrieverType,
		})
		logger.Info(ctx, "Keyword retrieval parameters setup completed")
	}

	if len(retrieveParams) == 0 {
		logger.Error(ctx, "No retrieval parameters available")
		return nil, errors.New("no retrieve params")
	}

	// Execute retrieval using the configured engines.
	logger.Infof(ctx, "Starting retrieval, parameter count: %d", len(retrieveParams))
	retrieveResults, err := retrieveEngine.Retrieve(ctx, retrieveParams)
	if err != nil {
		logger.ErrorWithFields(ctx, err, map[string]interface{}{
			"knowledge_base_id": id,
			"query_text":        params.QueryText,
		})
		return nil, err
	}

	// Merge the hits from every retriever into one list.
	logger.Infof(ctx, "Processing retrieval results")
	matchResults := []*types.IndexWithScore{}
	for _, retrieveResult := range retrieveResults {
		logger.Infof(ctx, "Retrieval results, engine: %v, retriever: %v, count: %v",
			retrieveResult.RetrieverEngineType,
			retrieveResult.RetrieverType,
			len(retrieveResult.Results),
		)
		matchResults = append(matchResults, retrieveResult.Results...)
	}

	if len(matchResults) == 0 {
		logger.Info(ctx, "No search results found")
		return nil, nil
	}

	// A chunk found by both retrievers must appear only once.
	logger.Infof(ctx, "Result count before deduplication: %d", len(matchResults))
	deduplicatedChunks := common.Deduplicate(func(r *types.IndexWithScore) string { return r.ChunkID }, matchResults...)
	logger.Infof(ctx, "Result count after deduplication: %d", len(deduplicatedChunks))
	return s.processSearchResults(ctx, deduplicatedChunks)
}
// processSearchResults turns raw retrieval hits into SearchResults.
//
// It loads every hit chunk and the knowledge entries they belong to, then
// expands the result set with context chunks:
//   - the parent chunk (inherits the child's score, MatchTypeParentChunk),
//   - chunk IDs listed in the chunk's RelationChunks JSON (MatchTypeRelationChunk),
//   - for text chunks, the previous and next chunks (MatchTypeNearByChunk).
//
// Chunks that were hits themselves are never re-labelled as context chunks;
// they keep their own score and match type. Only text and summary chunks
// whose knowledge entry was loaded are returned. Results are sorted by score
// (descending), with ties broken by chunk ID, so the output order is
// deterministic.
//
// Parameters:
//   - ctx: request context; must carry the tenant ID (uint) under
//     types.TenantIDContextKey.
//   - chunks: scored index hits. Callers are expected to deduplicate by chunk
//     ID; for duplicates, the last score and match type win.
//
// Returns:
//   - nil, nil for empty input.
//   - ErrInvalidTenantID when the tenant ID is missing.
//   - The error from loading the knowledge entries or the hit chunks.
//
// Failure to load the additional context chunks is logged and tolerated.
func (s *knowledgeBaseService) processSearchResults(ctx context.Context,
	chunks []*types.IndexWithScore) ([]*types.SearchResult, error) {
	if len(chunks) == 0 {
		return nil, nil
	}

	// Comma-ok assertion: a missing tenant ID must not panic the request.
	tenantID, ok := ctx.Value(types.TenantIDContextKey).(uint)
	if !ok {
		logger.Error(ctx, "Tenant ID not found in context")
		return nil, ErrInvalidTenantID
	}

	// Index the hits: unique knowledge IDs (first-seen order), chunk IDs, and
	// the score and match type of each chunk.
	knowledgeIDs := make([]string, 0, len(chunks))
	chunkIDs := make([]string, 0, len(chunks))
	chunkScores := make(map[string]float64, len(chunks))
	chunkMatchTypes := make(map[string]types.MatchType, len(chunks))
	seenKnowledgeIDs := make(map[string]bool)
	for _, hit := range chunks {
		if !seenKnowledgeIDs[hit.KnowledgeID] {
			knowledgeIDs = append(knowledgeIDs, hit.KnowledgeID)
			seenKnowledgeIDs[hit.KnowledgeID] = true
		}
		chunkIDs = append(chunkIDs, hit.ChunkID)
		chunkScores[hit.ChunkID] = hit.Score
		chunkMatchTypes[hit.ChunkID] = hit.MatchType
	}

	logger.Infof(ctx, "Fetching knowledge data for %d IDs", len(knowledgeIDs))
	knowledgeMap, err := s.fetchKnowledgeData(ctx, tenantID, knowledgeIDs)
	if err != nil {
		return nil, err
	}

	logger.Infof(ctx, "Fetching chunk data for %d IDs", len(chunkIDs))
	allChunks, err := s.chunkRepo.ListChunksByID(ctx, tenantID, chunkIDs)
	if err != nil {
		logger.ErrorWithFields(ctx, err, map[string]interface{}{
			"tenant_id": tenantID,
			"chunk_ids": chunkIDs,
		})
		return nil, err
	}

	// Seed the processed set with every directly retrieved chunk BEFORE
	// expanding context. Otherwise a hit that appears later in allChunks than
	// a chunk referencing it (as parent, relation or neighbour) would be
	// fetched again and have its own score and match type overwritten.
	chunkMap := make(map[string]*types.Chunk, len(allChunks))
	processedChunkIDs := make(map[string]bool, len(allChunks))
	for _, chunk := range allChunks {
		chunkMap[chunk.ID] = chunk
		processedChunkIDs[chunk.ID] = true
	}

	// Collect the context chunks to fetch in a second batch.
	var additionalChunkIDs []string
	for _, chunk := range allChunks {
		// The parent chunk inherits the child's score.
		if chunk.ParentChunkID != "" && !processedChunkIDs[chunk.ParentChunkID] {
			additionalChunkIDs = append(additionalChunkIDs, chunk.ParentChunkID)
			processedChunkIDs[chunk.ParentChunkID] = true
			chunkScores[chunk.ParentChunkID] = chunkScores[chunk.ID]
			chunkMatchTypes[chunk.ParentChunkID] = types.MatchTypeParentChunk
		}

		// Explicitly related chunks (collectRelatedChunkIDs marks them processed).
		for _, relatedID := range s.collectRelatedChunkIDs(chunk, processedChunkIDs) {
			additionalChunkIDs = append(additionalChunkIDs, relatedID)
			chunkMatchTypes[relatedID] = types.MatchTypeRelationChunk
		}

		// Neighbouring chunks, for text chunks only.
		if chunk.ChunkType == types.ChunkTypeText {
			for _, neighbourID := range []string{chunk.NextChunkID, chunk.PreChunkID} {
				if neighbourID != "" && !processedChunkIDs[neighbourID] {
					additionalChunkIDs = append(additionalChunkIDs, neighbourID)
					processedChunkIDs[neighbourID] = true
					chunkMatchTypes[neighbourID] = types.MatchTypeNearByChunk
				}
			}
		}
	}

	// Context chunks are best-effort: a failed fetch is logged, not returned.
	if len(additionalChunkIDs) > 0 {
		logger.Infof(ctx, "Fetching %d additional chunks", len(additionalChunkIDs))
		additionalChunks, fetchErr := s.chunkRepo.ListChunksByID(ctx, tenantID, additionalChunkIDs)
		if fetchErr != nil {
			logger.Warnf(ctx, "Failed to fetch some additional chunks: %v", fetchErr)
		} else {
			for _, chunk := range additionalChunks {
				chunkMap[chunk.ID] = chunk
			}
		}
	}

	// Build the results. NOTE(review): chunks whose knowledge entry is not in
	// knowledgeMap (possible for relation chunks that point at another
	// knowledge entry) are dropped here — confirm this is intended.
	var searchResults []*types.SearchResult
	for chunkID, chunk := range chunkMap {
		if !s.isValidTextChunk(chunk) {
			continue
		}
		knowledge, found := knowledgeMap[chunk.KnowledgeID]
		if !found {
			continue
		}
		matchType, found := chunkMatchTypes[chunkID]
		if !found {
			logger.Warnf(ctx, "Unknown match type for chunk: %s", chunkID)
			continue
		}
		// Context chunks without an inherited score, and non-positive scores, count as 0.
		score := chunkScores[chunkID]
		if score <= 0 {
			score = 0.0
		}
		searchResults = append(searchResults, s.buildSearchResult(chunk, knowledge, score, matchType))
	}

	// Map iteration order is random; sort so callers get a stable ranking.
	slices.SortFunc(searchResults, func(a, b *types.SearchResult) int {
		switch {
		case a.Score > b.Score:
			return -1
		case a.Score < b.Score:
			return 1
		case a.ID < b.ID:
			return -1
		case a.ID > b.ID:
			return 1
		default:
			return 0
		}
	})

	logger.Infof(ctx, "Search results processed, total: %d", len(searchResults))
	return searchResults, nil
}
// collectRelatedChunkIDs decodes chunk.RelationChunks as a JSON array of chunk
// IDs and returns the IDs not yet present in processedIDs.
//
// Every returned ID is also marked in processedIDs (the map is mutated), so
// repeated IDs are returned only once. Relation data that is empty or fails
// to decode is ignored on a best-effort basis, and the result is nil.
func (s *knowledgeBaseService) collectRelatedChunkIDs(chunk *types.Chunk, processedIDs map[string]bool) []string {
	if len(chunk.RelationChunks) == 0 {
		return nil
	}

	var candidates []string
	if json.Unmarshal(chunk.RelationChunks, &candidates) != nil {
		return nil
	}

	var fresh []string
	for _, candidate := range candidates {
		if processedIDs[candidate] {
			continue
		}
		processedIDs[candidate] = true
		fresh = append(fresh, candidate)
	}
	return fresh
}
// buildSearchResult assembles a SearchResult from a chunk, the knowledge entry
// it belongs to, and the score and match type computed by the caller.
func (s *knowledgeBaseService) buildSearchResult(chunk *types.Chunk,
	knowledge *types.Knowledge,
	score float64,
	matchType types.MatchType) *types.SearchResult {
	result := &types.SearchResult{
		// Fields taken from the chunk.
		ID:            chunk.ID,
		Content:       chunk.Content,
		KnowledgeID:   chunk.KnowledgeID,
		ChunkIndex:    chunk.ChunkIndex,
		Seq:           chunk.ChunkIndex, // Seq mirrors ChunkIndex.
		StartAt:       chunk.StartAt,
		EndAt:         chunk.EndAt,
		ChunkType:     string(chunk.ChunkType),
		ParentChunkID: chunk.ParentChunkID,
		ImageInfo:     chunk.ImageInfo,

		// Fields taken from the owning knowledge entry.
		KnowledgeTitle:    knowledge.Title,
		KnowledgeFilename: knowledge.FileName,
		KnowledgeSource:   knowledge.Source,
		Metadata:          knowledge.GetMetadata(),

		// Ranking information supplied by the caller.
		Score:     score,
		MatchType: matchType,
	}
	return result
}
// isValidTextChunk reports whether chunk has a type that may be returned as a
// search result: a plain text chunk or a summary chunk.
func (s *knowledgeBaseService) isValidTextChunk(chunk *types.Chunk) bool {
	allowed := []types.ChunkType{types.ChunkTypeText, types.ChunkTypeSummary}
	return slices.Index(allowed, chunk.ChunkType) >= 0
}
// fetchKnowledgeData loads the given knowledge entries for a tenant in a single
// batch call and indexes them by knowledge ID.
//
// Returns the repository error (logged with the tenant and requested IDs) if
// the batch lookup fails.
func (s *knowledgeBaseService) fetchKnowledgeData(ctx context.Context,
	tenantID uint,
	knowledgeIDs []string,
) (map[string]*types.Knowledge, error) {
	batch, batchErr := s.kgRepo.GetKnowledgeBatch(ctx, tenantID, knowledgeIDs)
	if batchErr != nil {
		logger.ErrorWithFields(ctx, batchErr, map[string]interface{}{
			"knowledge_ids": knowledgeIDs,
			"tenant_id":     tenantID,
		})
		return nil, batchErr
	}

	byID := make(map[string]*types.Knowledge, len(batch))
	for _, entry := range batch {
		byID[entry.ID] = entry
	}
	return byID, nil
}