package service import ( "context" "errors" "knowlege-lsxd/internal/config" "slices" "time" "encoding/json" "knowlege-lsxd/internal/application/service/retriever" "knowlege-lsxd/internal/common" "knowlege-lsxd/internal/logger" "knowlege-lsxd/internal/models/embedding" "knowlege-lsxd/internal/types" "knowlege-lsxd/internal/types/interfaces" "github.com/google/uuid" ) // ErrInvalidTenantID represents an error for invalid tenant ID var ErrInvalidTenantID = errors.New("invalid tenant ID") // knowledgeBaseService implements the knowledge base service interface type knowledgeBaseService struct { repo interfaces.KnowledgeBaseRepository kgRepo interfaces.KnowledgeRepository chunkRepo interfaces.ChunkRepository modelService interfaces.ModelService } // NewKnowledgeBaseService creates a new knowledge base service func NewKnowledgeBaseService( cfg *config.Config, repo interfaces.KnowledgeBaseRepository, kgRepo interfaces.KnowledgeRepository, chunkRepo interfaces.ChunkRepository, modelService interfaces.ModelService, ) interfaces.KnowledgeBaseService { return &knowledgeBaseService{ repo: repo, kgRepo: kgRepo, chunkRepo: chunkRepo, modelService: modelService, } } // CreateKnowledgeBase creates a new knowledge base func (s *knowledgeBaseService) CreateKnowledgeBase(ctx context.Context, kb *types.KnowledgeBase, ) (*types.KnowledgeBase, error) { // Generate UUID and set creation timestamps if kb.ID == "" { kb.ID = uuid.New().String() } kb.CreatedAt = time.Now() kb.TenantID = ctx.Value(types.TenantIDContextKey).(uint) kb.UpdatedAt = time.Now() logger.Infof(ctx, "Creating knowledge base, ID: %s, tenant ID: %d, name: %s", kb.ID, kb.TenantID, kb.Name) if err := s.repo.CreateKnowledgeBase(ctx, kb); err != nil { logger.ErrorWithFields(ctx, err, map[string]interface{}{ "knowledge_base_id": kb.ID, "tenant_id": kb.TenantID, }) return nil, err } logger.Infof(ctx, "Knowledge base created successfully, ID: %s, name: %s", kb.ID, kb.Name) return kb, nil } // GetKnowledgeBaseByID retrieves a knowledge base by its ID func (s *knowledgeBaseService) GetKnowledgeBaseByID(ctx context.Context, id string) (*types.KnowledgeBase, error) { if id == "" { logger.Error(ctx, "Knowledge base ID is empty") return nil, errors.New("knowledge base ID cannot be empty") } logger.Infof(ctx, "Retrieving knowledge base, ID: %s", id) kb, err := s.repo.GetKnowledgeBaseByID(ctx, id) if err != nil { logger.ErrorWithFields(ctx, err, map[string]interface{}{ "knowledge_base_id": id, }) return nil, err } logger.Infof(ctx, "Knowledge base retrieved successfully, ID: %s, name: %s", kb.ID, kb.Name) return kb, nil } // ListKnowledgeBases returns all knowledge bases for a tenant func (s *knowledgeBaseService) ListKnowledgeBases(ctx context.Context) ([]*types.KnowledgeBase, error) { tenantID := ctx.Value(types.TenantIDContextKey).(uint) logger.Infof(ctx, "Retrieving knowledge base list for tenant, tenant ID: %d", tenantID) kbs, err := s.repo.ListKnowledgeBasesByTenantID(ctx, tenantID) if err != nil { logger.ErrorWithFields(ctx, err, map[string]interface{}{ "tenant_id": tenantID, }) return nil, err } logger.Infof( ctx, "Knowledge base list retrieved successfully, tenant ID: %d, knowledge base count: %d", tenantID, len(kbs), ) return kbs, nil } // UpdateKnowledgeBase updates a knowledge base's properties func (s *knowledgeBaseService) UpdateKnowledgeBase(ctx context.Context, id string, name string, description string, config *types.KnowledgeBaseConfig, ) (*types.KnowledgeBase, error) { if id == "" { logger.Error(ctx, "Knowledge base ID is empty") return nil, errors.New("knowledge base ID cannot be empty") } logger.Infof(ctx, "Updating knowledge base, ID: %s, name: %s", id, name) // Get existing knowledge base kb, err := s.repo.GetKnowledgeBaseByID(ctx, id) if err != nil { logger.ErrorWithFields(ctx, err, map[string]interface{}{ "knowledge_base_id": id, }) return nil, err } // Update the knowledge base properties kb.Name = name kb.Description = description kb.ChunkingConfig = config.ChunkingConfig kb.ImageProcessingConfig = config.ImageProcessingConfig kb.UpdatedAt = time.Now() logger.Info(ctx, "Saving knowledge base update") if err := s.repo.UpdateKnowledgeBase(ctx, kb); err != nil { logger.ErrorWithFields(ctx, err, map[string]interface{}{ "knowledge_base_id": id, }) return nil, err } logger.Infof(ctx, "Knowledge base updated successfully, ID: %s, name: %s", kb.ID, kb.Name) return kb, nil } // DeleteKnowledgeBase deletes a knowledge base by its ID func (s *knowledgeBaseService) DeleteKnowledgeBase(ctx context.Context, id string) error { if id == "" { logger.Error(ctx, "Knowledge base ID is empty") return errors.New("knowledge base ID cannot be empty") } logger.Infof(ctx, "Deleting knowledge base, ID: %s", id) err := s.repo.DeleteKnowledgeBase(ctx, id) if err != nil { logger.ErrorWithFields(ctx, err, map[string]interface{}{ "knowledge_base_id": id, }) return err } logger.Infof(ctx, "Knowledge base deleted successfully, ID: %s", id) return nil } // SetEmbeddingModel sets the embedding model for a knowledge base func (s *knowledgeBaseService) SetEmbeddingModel(ctx context.Context, id string, modelID string) error { if id == "" { logger.Error(ctx, "Knowledge base ID is empty") return errors.New("knowledge base ID cannot be empty") } if modelID == "" { logger.Error(ctx, "Model ID is empty") return errors.New("model ID cannot be empty") } logger.Infof(ctx, "Setting embedding model for knowledge base, knowledge base ID: %s, model ID: %s", id, modelID) // Get the knowledge base kb, err := s.repo.GetKnowledgeBaseByID(ctx, id) if err != nil { logger.ErrorWithFields(ctx, err, map[string]interface{}{ "knowledge_base_id": id, }) return err } // Update the knowledge base's embedding model kb.EmbeddingModelID = modelID kb.UpdatedAt = time.Now() logger.Info(ctx, "Saving knowledge base embedding model update") err = s.repo.UpdateKnowledgeBase(ctx, kb) if err != nil { logger.ErrorWithFields(ctx, err, map[string]interface{}{ "knowledge_base_id": id, "embedding_model_id": modelID, }) return err } logger.Infof( ctx, "Knowledge base embedding model set successfully, knowledge base ID: %s, model ID: %s", id, modelID, ) return nil } // CopyKnowledgeBase copies a knowledge base to a new knowledge base // 浅拷贝 func (s *knowledgeBaseService) CopyKnowledgeBase(ctx context.Context, srcKB string, dstKB string, ) (*types.KnowledgeBase, *types.KnowledgeBase, error) { sourceKB, err := s.repo.GetKnowledgeBaseByID(ctx, srcKB) if err != nil { logger.Errorf(ctx, "Get source knowledge base failed: %v", err) return nil, nil, err } tenantID := ctx.Value(types.TenantIDContextKey).(uint) var targetKB *types.KnowledgeBase if dstKB != "" { targetKB, err = s.repo.GetKnowledgeBaseByID(ctx, dstKB) if err != nil { return nil, nil, err } } else { targetKB = &types.KnowledgeBase{ ID: uuid.New().String(), Name: sourceKB.Name, Description: sourceKB.Description, TenantID: tenantID, ChunkingConfig: sourceKB.ChunkingConfig, ImageProcessingConfig: sourceKB.ImageProcessingConfig, EmbeddingModelID: sourceKB.EmbeddingModelID, SummaryModelID: sourceKB.SummaryModelID, RerankModelID: sourceKB.RerankModelID, VLMModelID: sourceKB.VLMModelID, StorageConfig: sourceKB.StorageConfig, } if err := s.repo.CreateKnowledgeBase(ctx, targetKB); err != nil { return nil, nil, err } } return sourceKB, targetKB, nil } // HybridSearch 执行混合搜索,结合向量和关键词检索方法 // 参数: // - ctx: 上下文,用于传递请求信息和控制执行 // - id: 知识库ID // - params: 搜索参数,包含查询文本、匹配数量等 // // 返回值: // - []*types.SearchResult: 搜索结果列表 // - error: 错误信息 func (s *knowledgeBaseService) HybridSearch(ctx context.Context, id string, params types.SearchParams, ) ([]*types.SearchResult, error) { // 记录混合搜索的参数信息,包括知识库ID和查询文本 logger.Infof(ctx, "Hybrid search parameters, knowledge base ID: %s, query text: %s", id, params.QueryText) // 从上下文中获取租户信息 tenantInfo := ctx.Value(types.TenantInfoContextKey).(*types.Tenant) logger.Infof(ctx, "Creating composite retrieval engine, tenant ID: %d", tenantInfo.ID) // Create a composite retrieval engine with tenant's configured retrievers retrieveEngine, err := retriever.NewCompositeRetrieveEngine(tenantInfo.RetrieverEngines.Engines) if err != nil { logger.Errorf(ctx, "Failed to create retrieval engine: %v", err) return nil, err } var retrieveParams []types.RetrieveParams var embeddingModel embedding.Embedder var kb *types.KnowledgeBase // Add vector retrieval params if supported if retrieveEngine.SupportRetriever(types.VectorRetrieverType) { logger.Info(ctx, "Vector retrieval supported, preparing vector retrieval parameters") kb, err = s.repo.GetKnowledgeBaseByID(ctx, id) if err != nil { logger.ErrorWithFields(ctx, err, map[string]interface{}{ "knowledge_base_id": id, }) return nil, err } logger.Infof(ctx, "Getting embedding model, model ID: %s", kb.EmbeddingModelID) embeddingModel, err = s.modelService.GetEmbeddingModel(ctx, kb.EmbeddingModelID) if err != nil { logger.Errorf(ctx, "Failed to get embedding model, model ID: %s, error: %v", kb.EmbeddingModelID, err) return nil, err } logger.Infof(ctx, "Embedding model retrieved: %v", embeddingModel) // Generate embedding vector for the query text logger.Info(ctx, "Starting to generate query embedding") queryEmbedding, err := embeddingModel.Embed(ctx, params.QueryText) if err != nil { logger.Errorf(ctx, "Failed to embed query text, query text: %s, error: %v", params.QueryText, err) return nil, err } logger.Infof(ctx, "Query embedding generated successfully, embedding vector length: %d", len(queryEmbedding)) retrieveParams = append(retrieveParams, types.RetrieveParams{ Query: params.QueryText, Embedding: queryEmbedding, KnowledgeBaseIDs: []string{id}, TopK: params.MatchCount, Threshold: params.VectorThreshold, RetrieverType: types.VectorRetrieverType, }) logger.Info(ctx, "Vector retrieval parameters setup completed") } // Add keyword retrieval params if supported if retrieveEngine.SupportRetriever(types.KeywordsRetrieverType) { logger.Info(ctx, "Keyword retrieval supported, preparing keyword retrieval parameters") retrieveParams = append(retrieveParams, types.RetrieveParams{ Query: params.QueryText, KnowledgeBaseIDs: []string{id}, TopK: params.MatchCount, Threshold: params.KeywordThreshold, RetrieverType: types.KeywordsRetrieverType, }) logger.Info(ctx, "Keyword retrieval parameters setup completed") } if len(retrieveParams) == 0 { logger.Error(ctx, "No retrieval parameters available") return nil, errors.New("no retrieve params") } // Execute retrieval using the configured engines logger.Infof(ctx, "Starting retrieval, parameter count: %d", len(retrieveParams)) retrieveResults, err := retrieveEngine.Retrieve(ctx, retrieveParams) if err != nil { logger.ErrorWithFields(ctx, err, map[string]interface{}{ "knowledge_base_id": id, "query_text": params.QueryText, }) return nil, err } // Collect all results from different retrievers and deduplicate by chunk ID logger.Infof(ctx, "Processing retrieval results") matchResults := []*types.IndexWithScore{} for _, retrieveResult := range retrieveResults { logger.Infof(ctx, "Retrieval results, engine: %v, retriever: %v, count: %v", retrieveResult.RetrieverEngineType, retrieveResult.RetrieverType, len(retrieveResult.Results), ) matchResults = append(matchResults, retrieveResult.Results...) } // Early return if no results if len(matchResults) == 0 { logger.Info(ctx, "No search results found") return nil, nil } // Deduplicate results by chunk ID logger.Infof(ctx, "Result count before deduplication: %d", len(matchResults)) deduplicatedChunks := common.Deduplicate(func(r *types.IndexWithScore) string { return r.ChunkID }, matchResults...) logger.Infof(ctx, "Result count after deduplication: %d", len(deduplicatedChunks)) return s.processSearchResults(ctx, deduplicatedChunks) } // processSearchResults 处理搜索结果的方法 // 参数: // // ctx - 上下文,用于传递请求相关的信息和控制请求的生命周期 // chunks - 带有分数的索引块列表 // // 返回值: // // []*types.SearchResult - 处理后的搜索结果列表 // error - 处理过程中可能出现的错误 func (s *knowledgeBaseService) processSearchResults(ctx context.Context, chunks []*types.IndexWithScore) ([]*types.SearchResult, error) { // 如果输入的块列表为空,直接返回nil if len(chunks) == 0 { return nil, nil } // 从上下文中获取租户ID tenantID := ctx.Value(types.TenantIDContextKey).(uint) // 初始化必要的变量 var knowledgeIDs []string // 知识ID列表 var chunkIDs []string // 块ID列表 chunkScores := make(map[string]float64) // 块ID到分数的映射 chunkMatchTypes := make(map[string]types.MatchType) // 块ID到匹配类型的映射 processedKnowledgeIDs := make(map[string]bool) // 已处理的知识ID集合 // 遍历处理输入的块列表 for _, chunk := range chunks { // 如果知识ID未处理过,则添加到知识ID列表 if !processedKnowledgeIDs[chunk.KnowledgeID] { knowledgeIDs = append(knowledgeIDs, chunk.KnowledgeID) processedKnowledgeIDs[chunk.KnowledgeID] = true } // 添加块ID到块ID列表,并记录分数和匹配类型 chunkIDs = append(chunkIDs, chunk.ChunkID) chunkScores[chunk.ChunkID] = chunk.Score chunkMatchTypes[chunk.ChunkID] = chunk.MatchType } // 获取知识数据 logger.Infof(ctx, "Fetching knowledge data for %d IDs", len(knowledgeIDs)) knowledgeMap, err := s.fetchKnowledgeData(ctx, tenantID, knowledgeIDs) if err != nil { return nil, err } // 获取块数据 logger.Infof(ctx, "Fetching chunk data for %d IDs", len(chunkIDs)) allChunks, err := s.chunkRepo.ListChunksByID(ctx, tenantID, chunkIDs) if err != nil { logger.ErrorWithFields(ctx, err, map[string]interface{}{ "tenant_id": tenantID, "chunk_ids": chunkIDs, }) return nil, err } // 创建块映射并处理相关块 chunkMap := make(map[string]*types.Chunk, len(allChunks)) var additionalChunkIDs []string processedChunkIDs := make(map[string]bool) for _, chunk := range allChunks { // 将块添加到块映射中 chunkMap[chunk.ID] = chunk processedChunkIDs[chunk.ID] = true // 处理父块 if chunk.ParentChunkID != "" && !processedChunkIDs[chunk.ParentChunkID] { additionalChunkIDs = append(additionalChunkIDs, chunk.ParentChunkID) processedChunkIDs[chunk.ParentChunkID] = true chunkScores[chunk.ParentChunkID] = chunkScores[chunk.ID] chunkMatchTypes[chunk.ParentChunkID] = types.MatchTypeParentChunk } // 处理相关块 relationChunkIDs := s.collectRelatedChunkIDs(chunk, processedChunkIDs) for _, chunkID := range relationChunkIDs { additionalChunkIDs = append(additionalChunkIDs, chunkID) chunkMatchTypes[chunkID] = types.MatchTypeRelationChunk } // 处理文本类型块的相邻块 if chunk.ChunkType == types.ChunkTypeText { if chunk.NextChunkID != "" && !processedChunkIDs[chunk.NextChunkID] { additionalChunkIDs = append(additionalChunkIDs, chunk.NextChunkID) processedChunkIDs[chunk.NextChunkID] = true chunkMatchTypes[chunk.NextChunkID] = types.MatchTypeNearByChunk } if chunk.PreChunkID != "" && !processedChunkIDs[chunk.PreChunkID] { additionalChunkIDs = append(additionalChunkIDs, chunk.PreChunkID) processedChunkIDs[chunk.PreChunkID] = true chunkMatchTypes[chunk.PreChunkID] = types.MatchTypeNearByChunk } } } // 获取额外的块数据 if len(additionalChunkIDs) > 0 { logger.Infof(ctx, "Fetching %d additional chunks", len(additionalChunkIDs)) additionalChunks, err := s.chunkRepo.ListChunksByID(ctx, tenantID, additionalChunkIDs) if err != nil { logger.Warnf(ctx, "Failed to fetch some additional chunks: %v", err) } else { // 将获取到的额外块添加到块映射中 for _, chunk := range additionalChunks { chunkMap[chunk.ID] = chunk } } } // 构建搜索结果 var searchResults []*types.SearchResult for chunkID, chunk := range chunkMap { // 检查是否为有效的文本块 if !s.isValidTextChunk(chunk) { continue } // 获取块的分数 score, hasScore := chunkScores[chunkID] if !hasScore || score <= 0 { score = 0.0 } // 构建搜索结果 if knowledge, ok := knowledgeMap[chunk.KnowledgeID]; ok { matchType := types.MatchTypeParentChunk if specificType, exists := chunkMatchTypes[chunkID]; exists { matchType = specificType } else { logger.Warnf(ctx, "Unkonwn match type for chunk: %s", chunkID) continue } searchResults = append(searchResults, s.buildSearchResult(chunk, knowledge, score, matchType)) } } logger.Infof(ctx, "Search results processed, total: %d", len(searchResults)) return searchResults, nil } // collectRelatedChunkIDs extracts related chunk IDs from a chunk func (s *knowledgeBaseService) collectRelatedChunkIDs(chunk *types.Chunk, processedIDs map[string]bool) []string { var relatedIDs []string // Process direct relations if len(chunk.RelationChunks) > 0 { var relations []string if err := json.Unmarshal(chunk.RelationChunks, &relations); err == nil { for _, id := range relations { if !processedIDs[id] { relatedIDs = append(relatedIDs, id) processedIDs[id] = true } } } } return relatedIDs } // buildSearchResult creates a search result from chunk and knowledge func (s *knowledgeBaseService) buildSearchResult(chunk *types.Chunk, knowledge *types.Knowledge, score float64, matchType types.MatchType) *types.SearchResult { return &types.SearchResult{ ID: chunk.ID, Content: chunk.Content, KnowledgeID: chunk.KnowledgeID, ChunkIndex: chunk.ChunkIndex, KnowledgeTitle: knowledge.Title, StartAt: chunk.StartAt, EndAt: chunk.EndAt, Seq: chunk.ChunkIndex, Score: score, MatchType: matchType, Metadata: knowledge.GetMetadata(), ChunkType: string(chunk.ChunkType), ParentChunkID: chunk.ParentChunkID, ImageInfo: chunk.ImageInfo, KnowledgeFilename: knowledge.FileName, KnowledgeSource: knowledge.Source, } } // isValidTextChunk checks if a chunk is a valid text chunk func (s *knowledgeBaseService) isValidTextChunk(chunk *types.Chunk) bool { return slices.Contains([]types.ChunkType{ types.ChunkTypeText, types.ChunkTypeSummary, }, chunk.ChunkType) } // fetchKnowledgeData gets knowledge data in batch func (s *knowledgeBaseService) fetchKnowledgeData(ctx context.Context, tenantID uint, knowledgeIDs []string, ) (map[string]*types.Knowledge, error) { knowledges, err := s.kgRepo.GetKnowledgeBatch(ctx, tenantID, knowledgeIDs) if err != nil { logger.ErrorWithFields(ctx, err, map[string]interface{}{ "tenant_id": tenantID, "knowledge_ids": knowledgeIDs, }) return nil, err } knowledgeMap := make(map[string]*types.Knowledge, len(knowledges)) for _, knowledge := range knowledges { knowledgeMap[knowledge.ID] = knowledge } return knowledgeMap, nil }