580 lines
18 KiB
Go
580 lines
18 KiB
Go
package service
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"slices"
|
|
"time"
|
|
|
|
"encoding/json"
|
|
|
|
"github.com/google/uuid"
|
|
"knowlege-lsxd/internal/application/service/retriever"
|
|
"knowlege-lsxd/internal/common"
|
|
"knowlege-lsxd/internal/logger"
|
|
"knowlege-lsxd/internal/models/embedding"
|
|
"knowlege-lsxd/internal/types"
|
|
"knowlege-lsxd/internal/types/interfaces"
|
|
)
|
|
|
|
var (
	// ErrInvalidTenantID is the sentinel error for a missing or invalid tenant ID.
	// Match it with errors.Is rather than by comparing message text.
	ErrInvalidTenantID = errors.New("invalid tenant ID")
)
|
|
|
|
// knowledgeBaseService implements the knowledge base service interface
|
|
type knowledgeBaseService struct {
|
|
repo interfaces.KnowledgeBaseRepository
|
|
kgRepo interfaces.KnowledgeRepository
|
|
chunkRepo interfaces.ChunkRepository
|
|
modelService interfaces.ModelService
|
|
}
|
|
|
|
// NewKnowledgeBaseService creates a new knowledge base service
|
|
func NewKnowledgeBaseService(repo interfaces.KnowledgeBaseRepository,
|
|
kgRepo interfaces.KnowledgeRepository,
|
|
chunkRepo interfaces.ChunkRepository,
|
|
modelService interfaces.ModelService,
|
|
) interfaces.KnowledgeBaseService {
|
|
return &knowledgeBaseService{
|
|
repo: repo,
|
|
kgRepo: kgRepo,
|
|
chunkRepo: chunkRepo,
|
|
modelService: modelService,
|
|
}
|
|
}
|
|
|
|
// CreateKnowledgeBase creates a new knowledge base
|
|
func (s *knowledgeBaseService) CreateKnowledgeBase(ctx context.Context,
|
|
kb *types.KnowledgeBase,
|
|
) (*types.KnowledgeBase, error) {
|
|
// Generate UUID and set creation timestamps
|
|
if kb.ID == "" {
|
|
kb.ID = uuid.New().String()
|
|
}
|
|
kb.CreatedAt = time.Now()
|
|
kb.TenantID = ctx.Value(types.TenantIDContextKey).(uint)
|
|
kb.UpdatedAt = time.Now()
|
|
|
|
logger.Infof(ctx, "Creating knowledge base, ID: %s, tenant ID: %d, name: %s", kb.ID, kb.TenantID, kb.Name)
|
|
|
|
if err := s.repo.CreateKnowledgeBase(ctx, kb); err != nil {
|
|
logger.ErrorWithFields(ctx, err, map[string]interface{}{
|
|
"knowledge_base_id": kb.ID,
|
|
"tenant_id": kb.TenantID,
|
|
})
|
|
return nil, err
|
|
}
|
|
|
|
logger.Infof(ctx, "Knowledge base created successfully, ID: %s, name: %s", kb.ID, kb.Name)
|
|
return kb, nil
|
|
}
|
|
|
|
// GetKnowledgeBaseByID retrieves a knowledge base by its ID
|
|
func (s *knowledgeBaseService) GetKnowledgeBaseByID(ctx context.Context, id string) (*types.KnowledgeBase, error) {
|
|
if id == "" {
|
|
logger.Error(ctx, "Knowledge base ID is empty")
|
|
return nil, errors.New("knowledge base ID cannot be empty")
|
|
}
|
|
|
|
logger.Infof(ctx, "Retrieving knowledge base, ID: %s", id)
|
|
|
|
kb, err := s.repo.GetKnowledgeBaseByID(ctx, id)
|
|
if err != nil {
|
|
logger.ErrorWithFields(ctx, err, map[string]interface{}{
|
|
"knowledge_base_id": id,
|
|
})
|
|
return nil, err
|
|
}
|
|
|
|
logger.Infof(ctx, "Knowledge base retrieved successfully, ID: %s, name: %s", kb.ID, kb.Name)
|
|
return kb, nil
|
|
}
|
|
|
|
// ListKnowledgeBases returns all knowledge bases for a tenant
|
|
func (s *knowledgeBaseService) ListKnowledgeBases(ctx context.Context) ([]*types.KnowledgeBase, error) {
|
|
tenantID := ctx.Value(types.TenantIDContextKey).(uint)
|
|
logger.Infof(ctx, "Retrieving knowledge base list for tenant, tenant ID: %d", tenantID)
|
|
|
|
kbs, err := s.repo.ListKnowledgeBasesByTenantID(ctx, tenantID)
|
|
if err != nil {
|
|
logger.ErrorWithFields(ctx, err, map[string]interface{}{
|
|
"tenant_id": tenantID,
|
|
})
|
|
return nil, err
|
|
}
|
|
|
|
logger.Infof(
|
|
ctx,
|
|
"Knowledge base list retrieved successfully, tenant ID: %d, knowledge base count: %d",
|
|
tenantID,
|
|
len(kbs),
|
|
)
|
|
return kbs, nil
|
|
}
|
|
|
|
// UpdateKnowledgeBase updates a knowledge base's properties
|
|
func (s *knowledgeBaseService) UpdateKnowledgeBase(ctx context.Context,
|
|
id string,
|
|
name string,
|
|
description string,
|
|
config *types.KnowledgeBaseConfig,
|
|
) (*types.KnowledgeBase, error) {
|
|
if id == "" {
|
|
logger.Error(ctx, "Knowledge base ID is empty")
|
|
return nil, errors.New("knowledge base ID cannot be empty")
|
|
}
|
|
|
|
logger.Infof(ctx, "Updating knowledge base, ID: %s, name: %s", id, name)
|
|
|
|
// Get existing knowledge base
|
|
kb, err := s.repo.GetKnowledgeBaseByID(ctx, id)
|
|
if err != nil {
|
|
logger.ErrorWithFields(ctx, err, map[string]interface{}{
|
|
"knowledge_base_id": id,
|
|
})
|
|
return nil, err
|
|
}
|
|
|
|
// Update the knowledge base properties
|
|
kb.Name = name
|
|
kb.Description = description
|
|
kb.ChunkingConfig = config.ChunkingConfig
|
|
kb.ImageProcessingConfig = config.ImageProcessingConfig
|
|
kb.UpdatedAt = time.Now()
|
|
|
|
logger.Info(ctx, "Saving knowledge base update")
|
|
if err := s.repo.UpdateKnowledgeBase(ctx, kb); err != nil {
|
|
logger.ErrorWithFields(ctx, err, map[string]interface{}{
|
|
"knowledge_base_id": id,
|
|
})
|
|
return nil, err
|
|
}
|
|
|
|
logger.Infof(ctx, "Knowledge base updated successfully, ID: %s, name: %s", kb.ID, kb.Name)
|
|
return kb, nil
|
|
}
|
|
|
|
// DeleteKnowledgeBase deletes a knowledge base by its ID
|
|
func (s *knowledgeBaseService) DeleteKnowledgeBase(ctx context.Context, id string) error {
|
|
if id == "" {
|
|
logger.Error(ctx, "Knowledge base ID is empty")
|
|
return errors.New("knowledge base ID cannot be empty")
|
|
}
|
|
|
|
logger.Infof(ctx, "Deleting knowledge base, ID: %s", id)
|
|
|
|
err := s.repo.DeleteKnowledgeBase(ctx, id)
|
|
if err != nil {
|
|
logger.ErrorWithFields(ctx, err, map[string]interface{}{
|
|
"knowledge_base_id": id,
|
|
})
|
|
return err
|
|
}
|
|
|
|
logger.Infof(ctx, "Knowledge base deleted successfully, ID: %s", id)
|
|
return nil
|
|
}
|
|
|
|
// SetEmbeddingModel sets the embedding model for a knowledge base
|
|
func (s *knowledgeBaseService) SetEmbeddingModel(ctx context.Context, id string, modelID string) error {
|
|
if id == "" {
|
|
logger.Error(ctx, "Knowledge base ID is empty")
|
|
return errors.New("knowledge base ID cannot be empty")
|
|
}
|
|
|
|
if modelID == "" {
|
|
logger.Error(ctx, "Model ID is empty")
|
|
return errors.New("model ID cannot be empty")
|
|
}
|
|
|
|
logger.Infof(ctx, "Setting embedding model for knowledge base, knowledge base ID: %s, model ID: %s", id, modelID)
|
|
|
|
// Get the knowledge base
|
|
kb, err := s.repo.GetKnowledgeBaseByID(ctx, id)
|
|
if err != nil {
|
|
logger.ErrorWithFields(ctx, err, map[string]interface{}{
|
|
"knowledge_base_id": id,
|
|
})
|
|
return err
|
|
}
|
|
|
|
// Update the knowledge base's embedding model
|
|
kb.EmbeddingModelID = modelID
|
|
kb.UpdatedAt = time.Now()
|
|
|
|
logger.Info(ctx, "Saving knowledge base embedding model update")
|
|
err = s.repo.UpdateKnowledgeBase(ctx, kb)
|
|
if err != nil {
|
|
logger.ErrorWithFields(ctx, err, map[string]interface{}{
|
|
"knowledge_base_id": id,
|
|
"embedding_model_id": modelID,
|
|
})
|
|
return err
|
|
}
|
|
|
|
logger.Infof(
|
|
ctx,
|
|
"Knowledge base embedding model set successfully, knowledge base ID: %s, model ID: %s",
|
|
id,
|
|
modelID,
|
|
)
|
|
return nil
|
|
}
|
|
|
|
// CopyKnowledgeBase copies a knowledge base to a new knowledge base
|
|
// 浅拷贝
|
|
func (s *knowledgeBaseService) CopyKnowledgeBase(ctx context.Context,
|
|
srcKB string, dstKB string,
|
|
) (*types.KnowledgeBase, *types.KnowledgeBase, error) {
|
|
sourceKB, err := s.repo.GetKnowledgeBaseByID(ctx, srcKB)
|
|
if err != nil {
|
|
logger.Errorf(ctx, "Get source knowledge base failed: %v", err)
|
|
return nil, nil, err
|
|
}
|
|
tenantID := ctx.Value(types.TenantIDContextKey).(uint)
|
|
var targetKB *types.KnowledgeBase
|
|
if dstKB != "" {
|
|
targetKB, err = s.repo.GetKnowledgeBaseByID(ctx, dstKB)
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
} else {
|
|
targetKB = &types.KnowledgeBase{
|
|
ID: uuid.New().String(),
|
|
Name: sourceKB.Name,
|
|
Description: sourceKB.Description,
|
|
TenantID: tenantID,
|
|
ChunkingConfig: sourceKB.ChunkingConfig,
|
|
ImageProcessingConfig: sourceKB.ImageProcessingConfig,
|
|
EmbeddingModelID: sourceKB.EmbeddingModelID,
|
|
SummaryModelID: sourceKB.SummaryModelID,
|
|
RerankModelID: sourceKB.RerankModelID,
|
|
VLMModelID: sourceKB.VLMModelID,
|
|
StorageConfig: sourceKB.StorageConfig,
|
|
}
|
|
if err := s.repo.CreateKnowledgeBase(ctx, targetKB); err != nil {
|
|
return nil, nil, err
|
|
}
|
|
}
|
|
return sourceKB, targetKB, nil
|
|
}
|
|
|
|
// HybridSearch performs hybrid search, including vector retrieval and keyword retrieval
|
|
func (s *knowledgeBaseService) HybridSearch(ctx context.Context,
|
|
id string,
|
|
params types.SearchParams,
|
|
) ([]*types.SearchResult, error) {
|
|
logger.Infof(ctx, "Hybrid search parameters, knowledge base ID: %s, query text: %s", id, params.QueryText)
|
|
|
|
tenantInfo := ctx.Value(types.TenantInfoContextKey).(*types.Tenant)
|
|
logger.Infof(ctx, "Creating composite retrieval engine, tenant ID: %d", tenantInfo.ID)
|
|
|
|
// Create a composite retrieval engine with tenant's configured retrievers
|
|
retrieveEngine, err := retriever.NewCompositeRetrieveEngine(tenantInfo.RetrieverEngines.Engines)
|
|
if err != nil {
|
|
logger.Errorf(ctx, "Failed to create retrieval engine: %v", err)
|
|
return nil, err
|
|
}
|
|
|
|
var retrieveParams []types.RetrieveParams
|
|
var embeddingModel embedding.Embedder
|
|
var kb *types.KnowledgeBase
|
|
|
|
// Add vector retrieval params if supported
|
|
if retrieveEngine.SupportRetriever(types.VectorRetrieverType) {
|
|
logger.Info(ctx, "Vector retrieval supported, preparing vector retrieval parameters")
|
|
|
|
kb, err = s.repo.GetKnowledgeBaseByID(ctx, id)
|
|
if err != nil {
|
|
logger.ErrorWithFields(ctx, err, map[string]interface{}{
|
|
"knowledge_base_id": id,
|
|
})
|
|
return nil, err
|
|
}
|
|
|
|
logger.Infof(ctx, "Getting embedding model, model ID: %s", kb.EmbeddingModelID)
|
|
embeddingModel, err = s.modelService.GetEmbeddingModel(ctx, kb.EmbeddingModelID)
|
|
if err != nil {
|
|
logger.Errorf(ctx, "Failed to get embedding model, model ID: %s, error: %v", kb.EmbeddingModelID, err)
|
|
return nil, err
|
|
}
|
|
logger.Infof(ctx, "Embedding model retrieved: %v", embeddingModel)
|
|
|
|
// Generate embedding vector for the query text
|
|
logger.Info(ctx, "Starting to generate query embedding")
|
|
queryEmbedding, err := embeddingModel.Embed(ctx, params.QueryText)
|
|
if err != nil {
|
|
logger.Errorf(ctx, "Failed to embed query text, query text: %s, error: %v", params.QueryText, err)
|
|
return nil, err
|
|
}
|
|
logger.Infof(ctx, "Query embedding generated successfully, embedding vector length: %d", len(queryEmbedding))
|
|
|
|
retrieveParams = append(retrieveParams, types.RetrieveParams{
|
|
Query: params.QueryText,
|
|
Embedding: queryEmbedding,
|
|
KnowledgeBaseIDs: []string{id},
|
|
TopK: params.MatchCount,
|
|
Threshold: params.VectorThreshold,
|
|
RetrieverType: types.VectorRetrieverType,
|
|
})
|
|
logger.Info(ctx, "Vector retrieval parameters setup completed")
|
|
}
|
|
|
|
// Add keyword retrieval params if supported
|
|
if retrieveEngine.SupportRetriever(types.KeywordsRetrieverType) {
|
|
logger.Info(ctx, "Keyword retrieval supported, preparing keyword retrieval parameters")
|
|
retrieveParams = append(retrieveParams, types.RetrieveParams{
|
|
Query: params.QueryText,
|
|
KnowledgeBaseIDs: []string{id},
|
|
TopK: params.MatchCount,
|
|
Threshold: params.KeywordThreshold,
|
|
RetrieverType: types.KeywordsRetrieverType,
|
|
})
|
|
logger.Info(ctx, "Keyword retrieval parameters setup completed")
|
|
}
|
|
|
|
if len(retrieveParams) == 0 {
|
|
logger.Error(ctx, "No retrieval parameters available")
|
|
return nil, errors.New("no retrieve params")
|
|
}
|
|
|
|
// Execute retrieval using the configured engines
|
|
logger.Infof(ctx, "Starting retrieval, parameter count: %d", len(retrieveParams))
|
|
retrieveResults, err := retrieveEngine.Retrieve(ctx, retrieveParams)
|
|
if err != nil {
|
|
logger.ErrorWithFields(ctx, err, map[string]interface{}{
|
|
"knowledge_base_id": id,
|
|
"query_text": params.QueryText,
|
|
})
|
|
return nil, err
|
|
}
|
|
|
|
// Collect all results from different retrievers and deduplicate by chunk ID
|
|
logger.Infof(ctx, "Processing retrieval results")
|
|
matchResults := []*types.IndexWithScore{}
|
|
for _, retrieveResult := range retrieveResults {
|
|
logger.Infof(ctx, "Retrieval results, engine: %v, retriever: %v, count: %v",
|
|
retrieveResult.RetrieverEngineType,
|
|
retrieveResult.RetrieverType,
|
|
len(retrieveResult.Results),
|
|
)
|
|
matchResults = append(matchResults, retrieveResult.Results...)
|
|
}
|
|
|
|
// Early return if no results
|
|
if len(matchResults) == 0 {
|
|
logger.Info(ctx, "No search results found")
|
|
return nil, nil
|
|
}
|
|
|
|
// Deduplicate results by chunk ID
|
|
logger.Infof(ctx, "Result count before deduplication: %d", len(matchResults))
|
|
deduplicatedChunks := common.Deduplicate(func(r *types.IndexWithScore) string { return r.ChunkID }, matchResults...)
|
|
logger.Infof(ctx, "Result count after deduplication: %d", len(deduplicatedChunks))
|
|
|
|
return s.processSearchResults(ctx, deduplicatedChunks)
|
|
}
|
|
|
|
// processSearchResults handles the processing of search results, optimizing database queries
|
|
func (s *knowledgeBaseService) processSearchResults(ctx context.Context,
|
|
chunks []*types.IndexWithScore) ([]*types.SearchResult, error) {
|
|
|
|
if len(chunks) == 0 {
|
|
return nil, nil
|
|
}
|
|
|
|
tenantID := ctx.Value(types.TenantIDContextKey).(uint)
|
|
|
|
// Prepare data structures for efficient processing
|
|
var knowledgeIDs []string
|
|
var chunkIDs []string
|
|
chunkScores := make(map[string]float64)
|
|
chunkMatchTypes := make(map[string]types.MatchType)
|
|
processedKnowledgeIDs := make(map[string]bool)
|
|
|
|
// Collect all knowledge and chunk IDs
|
|
for _, chunk := range chunks {
|
|
if !processedKnowledgeIDs[chunk.KnowledgeID] {
|
|
knowledgeIDs = append(knowledgeIDs, chunk.KnowledgeID)
|
|
processedKnowledgeIDs[chunk.KnowledgeID] = true
|
|
}
|
|
|
|
chunkIDs = append(chunkIDs, chunk.ChunkID)
|
|
chunkScores[chunk.ChunkID] = chunk.Score
|
|
chunkMatchTypes[chunk.ChunkID] = chunk.MatchType
|
|
}
|
|
|
|
// Batch fetch knowledge data
|
|
logger.Infof(ctx, "Fetching knowledge data for %d IDs", len(knowledgeIDs))
|
|
knowledgeMap, err := s.fetchKnowledgeData(ctx, tenantID, knowledgeIDs)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Batch fetch all chunks in one go
|
|
logger.Infof(ctx, "Fetching chunk data for %d IDs", len(chunkIDs))
|
|
allChunks, err := s.chunkRepo.ListChunksByID(ctx, tenantID, chunkIDs)
|
|
if err != nil {
|
|
logger.ErrorWithFields(ctx, err, map[string]interface{}{
|
|
"tenant_id": tenantID,
|
|
"chunk_ids": chunkIDs,
|
|
})
|
|
return nil, err
|
|
}
|
|
|
|
// Build chunk map and collect additional IDs to fetch
|
|
chunkMap := make(map[string]*types.Chunk, len(allChunks))
|
|
var additionalChunkIDs []string
|
|
processedChunkIDs := make(map[string]bool)
|
|
|
|
// First pass: Build chunk map and collect parent IDs
|
|
for _, chunk := range allChunks {
|
|
chunkMap[chunk.ID] = chunk
|
|
processedChunkIDs[chunk.ID] = true
|
|
|
|
// Collect parent chunks
|
|
if chunk.ParentChunkID != "" && !processedChunkIDs[chunk.ParentChunkID] {
|
|
additionalChunkIDs = append(additionalChunkIDs, chunk.ParentChunkID)
|
|
processedChunkIDs[chunk.ParentChunkID] = true
|
|
|
|
// Pass score to parent
|
|
chunkScores[chunk.ParentChunkID] = chunkScores[chunk.ID]
|
|
chunkMatchTypes[chunk.ParentChunkID] = types.MatchTypeParentChunk
|
|
}
|
|
|
|
// Collect related chunks
|
|
relationChunkIDs := s.collectRelatedChunkIDs(chunk, processedChunkIDs)
|
|
for _, chunkID := range relationChunkIDs {
|
|
additionalChunkIDs = append(additionalChunkIDs, chunkID)
|
|
chunkMatchTypes[chunkID] = types.MatchTypeRelationChunk
|
|
}
|
|
|
|
// Add nearby chunks (prev and next)
|
|
if chunk.ChunkType == types.ChunkTypeText {
|
|
if chunk.NextChunkID != "" && !processedChunkIDs[chunk.NextChunkID] {
|
|
additionalChunkIDs = append(additionalChunkIDs, chunk.NextChunkID)
|
|
processedChunkIDs[chunk.NextChunkID] = true
|
|
chunkMatchTypes[chunk.NextChunkID] = types.MatchTypeNearByChunk
|
|
}
|
|
if chunk.PreChunkID != "" && !processedChunkIDs[chunk.PreChunkID] {
|
|
additionalChunkIDs = append(additionalChunkIDs, chunk.PreChunkID)
|
|
processedChunkIDs[chunk.PreChunkID] = true
|
|
chunkMatchTypes[chunk.PreChunkID] = types.MatchTypeNearByChunk
|
|
}
|
|
}
|
|
}
|
|
|
|
// Fetch all additional chunks in one go if needed
|
|
if len(additionalChunkIDs) > 0 {
|
|
logger.Infof(ctx, "Fetching %d additional chunks", len(additionalChunkIDs))
|
|
additionalChunks, err := s.chunkRepo.ListChunksByID(ctx, tenantID, additionalChunkIDs)
|
|
if err != nil {
|
|
logger.Warnf(ctx, "Failed to fetch some additional chunks: %v", err)
|
|
// Continue with what we have
|
|
} else {
|
|
// Add to chunk map
|
|
for _, chunk := range additionalChunks {
|
|
chunkMap[chunk.ID] = chunk
|
|
}
|
|
}
|
|
}
|
|
|
|
// Build final search results
|
|
var searchResults []*types.SearchResult
|
|
for chunkID, chunk := range chunkMap {
|
|
if !s.isValidTextChunk(chunk) {
|
|
continue
|
|
}
|
|
|
|
score, hasScore := chunkScores[chunkID]
|
|
if !hasScore || score <= 0 {
|
|
score = 0.0
|
|
}
|
|
|
|
if knowledge, ok := knowledgeMap[chunk.KnowledgeID]; ok {
|
|
matchType := types.MatchTypeParentChunk
|
|
if specificType, exists := chunkMatchTypes[chunkID]; exists {
|
|
matchType = specificType
|
|
} else {
|
|
logger.Warnf(ctx, "Unkonwn match type for chunk: %s", chunkID)
|
|
continue
|
|
}
|
|
searchResults = append(searchResults, s.buildSearchResult(chunk, knowledge, score, matchType))
|
|
}
|
|
}
|
|
logger.Infof(ctx, "Search results processed, total: %d", len(searchResults))
|
|
return searchResults, nil
|
|
}
|
|
|
|
// collectRelatedChunkIDs extracts related chunk IDs from a chunk
|
|
func (s *knowledgeBaseService) collectRelatedChunkIDs(chunk *types.Chunk, processedIDs map[string]bool) []string {
|
|
var relatedIDs []string
|
|
// Process direct relations
|
|
if len(chunk.RelationChunks) > 0 {
|
|
var relations []string
|
|
if err := json.Unmarshal(chunk.RelationChunks, &relations); err == nil {
|
|
for _, id := range relations {
|
|
if !processedIDs[id] {
|
|
relatedIDs = append(relatedIDs, id)
|
|
processedIDs[id] = true
|
|
}
|
|
}
|
|
}
|
|
}
|
|
return relatedIDs
|
|
}
|
|
|
|
// buildSearchResult creates a search result from chunk and knowledge
|
|
func (s *knowledgeBaseService) buildSearchResult(chunk *types.Chunk,
|
|
knowledge *types.Knowledge,
|
|
score float64,
|
|
matchType types.MatchType) *types.SearchResult {
|
|
|
|
return &types.SearchResult{
|
|
ID: chunk.ID,
|
|
Content: chunk.Content,
|
|
KnowledgeID: chunk.KnowledgeID,
|
|
ChunkIndex: chunk.ChunkIndex,
|
|
KnowledgeTitle: knowledge.Title,
|
|
StartAt: chunk.StartAt,
|
|
EndAt: chunk.EndAt,
|
|
Seq: chunk.ChunkIndex,
|
|
Score: score,
|
|
MatchType: matchType,
|
|
Metadata: knowledge.GetMetadata(),
|
|
ChunkType: string(chunk.ChunkType),
|
|
ParentChunkID: chunk.ParentChunkID,
|
|
ImageInfo: chunk.ImageInfo,
|
|
KnowledgeFilename: knowledge.FileName,
|
|
KnowledgeSource: knowledge.Source,
|
|
}
|
|
}
|
|
|
|
// isValidTextChunk checks if a chunk is a valid text chunk
|
|
func (s *knowledgeBaseService) isValidTextChunk(chunk *types.Chunk) bool {
|
|
return slices.Contains([]types.ChunkType{
|
|
types.ChunkTypeText, types.ChunkTypeSummary,
|
|
}, chunk.ChunkType)
|
|
}
|
|
|
|
// fetchKnowledgeData gets knowledge data in batch
|
|
func (s *knowledgeBaseService) fetchKnowledgeData(ctx context.Context,
|
|
tenantID uint,
|
|
knowledgeIDs []string,
|
|
) (map[string]*types.Knowledge, error) {
|
|
knowledges, err := s.kgRepo.GetKnowledgeBatch(ctx, tenantID, knowledgeIDs)
|
|
if err != nil {
|
|
logger.ErrorWithFields(ctx, err, map[string]interface{}{
|
|
"tenant_id": tenantID,
|
|
"knowledge_ids": knowledgeIDs,
|
|
})
|
|
return nil, err
|
|
}
|
|
|
|
knowledgeMap := make(map[string]*types.Knowledge, len(knowledges))
|
|
for _, knowledge := range knowledges {
|
|
knowledgeMap[knowledge.ID] = knowledge
|
|
}
|
|
|
|
return knowledgeMap, nil
|
|
}
|