// l_ai_knowledge/internal/application/service/retriever/composite.go
//
// 290 lines
// 9.2 KiB
// Go

package retriever
import (
"context"
"fmt"
"maps"
"slices"
"sync"
"sync/atomic"
"go.opentelemetry.io/otel/attribute"
"knowlege-lsxd/internal/common"
"knowlege-lsxd/internal/logger"
"knowlege-lsxd/internal/models/embedding"
"knowlege-lsxd/internal/runtime"
"knowlege-lsxd/internal/tracing"
"knowlege-lsxd/internal/types"
"knowlege-lsxd/internal/types/interfaces"
)
// engineInfo holds information about a retrieve engine and its supported retriever types.
// It pairs one retrieve engine service with the retriever types the composite
// was configured to use on that engine.
type engineInfo struct {
	// retrieveEngine is the engine service that operations are delegated to.
	retrieveEngine interfaces.RetrieveEngineService
	// retrieverType lists the retriever types configured for this engine.
	// Retrieve uses it to choose an engine for a request, and Index, BatchIndex
	// and EstimateStorageSize pass it through to the engine.
	retrieverType []types.RetrieverType
}
// CompositeRetrieveEngine implements a composite pattern for retrieval engines,
// delegating operations to all registered engines.
//
// Index, BatchIndex, CopyIndices, the Delete* methods and EstimateStorageSize
// run against every registered engine concurrently. Retrieve instead sends each
// request to the first registered engine configured for the request's
// retriever type.
type CompositeRetrieveEngine struct {
	// engineInfos holds one entry per distinct engine type, as built by
	// NewCompositeRetrieveEngine.
	engineInfos []*engineInfo
}
// Retrieve serves every entry of retrieveParams concurrently and merges the
// results into one slice.
//
// Each entry is handled by the first registered engine whose configured
// retriever types include param.RetrieverType. If no engine matches an entry,
// or any engine returns an error, the whole call fails and no results are
// returned.
func (c *CompositeRetrieveEngine) Retrieve(ctx context.Context,
	retrieveParams []types.RetrieveParams,
) ([]*types.RetrieveResult, error) {
	retrieveOne := func(ctx context.Context, param types.RetrieveParams,
		results *[]*types.RetrieveResult, mu *sync.Mutex,
	) error {
		idx := slices.IndexFunc(c.engineInfos, func(info *engineInfo) bool {
			return slices.Contains(info.retrieverType, param.RetrieverType)
		})
		if idx < 0 {
			return fmt.Errorf("retriever type %s not found", param.RetrieverType)
		}

		hits, err := c.engineInfos[idx].retrieveEngine.Retrieve(ctx, param)
		if err != nil {
			return err
		}

		// results is shared by all concurrent calls, so appends are serialized.
		mu.Lock()
		defer mu.Unlock()
		*results = append(*results, hits...)
		return nil
	}

	return concurrentRetrieve(ctx, retrieveParams, retrieveOne)
}
// NewCompositeRetrieveEngine creates a new composite retrieve engine with the given parameters.
//
// The retrieve engine registry is resolved from the runtime container. For each
// entry in engineParams, the named engine is looked up in the registry and must
// support the requested retriever type. Entries that resolve to the same engine
// type are merged into a single engineInfo whose retrieverType lists every
// requested type exactly once. The resulting engines are ordered by their first
// appearance in engineParams, so engine selection in Retrieve is deterministic.
//
// It returns an error in these cases:
//   - the registry cannot be resolved from the container;
//   - an engine lookup fails;
//   - an engine does not support a requested retriever type.
func NewCompositeRetrieveEngine(engineParams []types.RetrieverEngineParams) (*CompositeRetrieveEngine, error) {
	var registry interfaces.RetrieveEngineRegistry
	// Check the Invoke error. Otherwise a container that cannot provide the
	// registry leaves it nil, and the lookup below panics.
	if err := runtime.GetContainer().Invoke(func(r interfaces.RetrieveEngineRegistry) {
		registry = r
	}); err != nil {
		return nil, fmt.Errorf("resolve retrieve engine registry: %w", err)
	}

	engineInfos := make(map[types.RetrieverEngineType]*engineInfo, len(engineParams))
	// firstSeen records the position at which each engine first appeared in
	// engineParams. It gives the result a stable order despite map iteration.
	firstSeen := make(map[*engineInfo]int, len(engineParams))

	for _, engineParam := range engineParams {
		repo, err := registry.GetRetrieveEngineService(engineParam.RetrieverEngineType)
		if err != nil {
			return nil, err
		}
		engineType := repo.EngineType()
		if !slices.Contains(repo.Support(), engineParam.RetrieverType) {
			return nil, fmt.Errorf("retrieval engine %s does not support retriever type: %s",
				engineType, engineParam.RetrieverType)
		}

		info, exists := engineInfos[engineType]
		if !exists {
			info = &engineInfo{
				retrieveEngine: repo,
				retrieverType:  []types.RetrieverType{engineParam.RetrieverType},
			}
			engineInfos[engineType] = info
			firstSeen[info] = len(firstSeen)
			continue
		}
		// Skip repeated types so the engine is not handed the same retriever
		// type twice by Index/BatchIndex/EstimateStorageSize.
		if !slices.Contains(info.retrieverType, engineParam.RetrieverType) {
			info.retrieverType = append(info.retrieverType, engineParam.RetrieverType)
		}
	}

	infos := slices.Collect(maps.Values(engineInfos))
	slices.SortFunc(infos, func(a, b *engineInfo) int { return firstSeen[a] - firstSeen[b] })
	return &CompositeRetrieveEngine{engineInfos: infos}, nil
}
// SupportRetriever reports whether at least one registered engine was
// configured with retriever type r.
func (c *CompositeRetrieveEngine) SupportRetriever(r types.RetrieverType) bool {
	return slices.ContainsFunc(c.engineInfos, func(info *engineInfo) bool {
		return slices.Contains(info.retrieverType, r)
	})
}
// concurrentRetrieve runs fn once per entry of retrieveParams, each in its own
// goroutine, and collects the results.
//
// Each call to fn receives its own results slot, and the slots are
// concatenated in the order of retrieveParams. The merged output therefore
// does not depend on goroutine scheduling. mu is still provided to fn for
// compatibility with callers that lock around their appends.
//
// The first error returned by fn cancels the context passed to the remaining
// calls, since their results would be discarded anyway. concurrentRetrieve
// then waits for every call to finish and returns (nil, firstErr). That error
// is the root cause, not a cancellation error it triggered. With no errors, it
// returns the concatenated results, which are nil when nothing was found.
func concurrentRetrieve(
	ctx context.Context,
	retrieveParams []types.RetrieveParams,
	fn func(ctx context.Context, param types.RetrieveParams, results *[]*types.RetrieveResult, mu *sync.Mutex) error,
) ([]*types.RetrieveResult, error) {
	ctx, cancel := context.WithCancel(ctx)
	defer cancel()

	// One slot per param. Each goroutine writes only its own element.
	perParam := make([][]*types.RetrieveResult, len(retrieveParams))

	var (
		mu       sync.Mutex
		wg       sync.WaitGroup
		errOnce  sync.Once
		firstErr error
	)
	for i, param := range retrieveParams {
		wg.Add(1)
		go func() {
			defer wg.Done()
			if err := fn(ctx, param, &perParam[i], &mu); err != nil {
				// Record the error before cancelling. Errors caused by the
				// cancellation itself then cannot replace the root cause.
				errOnce.Do(func() {
					firstErr = err
					cancel()
				})
			}
		}()
	}
	wg.Wait()

	if firstErr != nil {
		return nil, firstErr
	}
	return slices.Concat(perParam...), nil
}
// concurrentExecWithError runs fn once per registered engine, each in its own
// goroutine, and waits for all of them to finish.
//
// An error from one engine does not stop the others. It returns the first
// error reported (in reporting order), or nil if every call succeeded or no
// engines are registered.
func (c *CompositeRetrieveEngine) concurrentExecWithError(
	ctx context.Context,
	fn func(ctx context.Context, engineInfo *engineInfo) error,
) error {
	errs := make(chan error, len(c.engineInfos))
	var pending sync.WaitGroup
	pending.Add(len(c.engineInfos))

	for _, info := range c.engineInfos {
		go func(info *engineInfo) {
			defer pending.Done()
			if err := fn(ctx, info); err != nil {
				errs <- err
			}
		}(info)
	}

	pending.Wait()
	close(errs)

	// Only non-nil errors are sent. A receive on the closed channel yields the
	// earliest buffered error, or nil once the buffer is empty.
	return <-errs
}
// Index saves the vector embedding for indexInfo to every registered engine
// concurrently. Each engine receives the retriever types it was configured
// with.
//
// Each engine failure is logged, and the first one is returned. The call is
// traced as "CompositeRetrieveEngine.Index", with the embedder's model name
// and indexInfo.SourceID attached to the span.
func (c *CompositeRetrieveEngine) Index(ctx context.Context,
	embedder embedding.Embedder, indexInfo *types.IndexInfo,
) error {
	ctx, span := tracing.ContextWithSpan(ctx, "CompositeRetrieveEngine.Index")
	defer span.End()

	indexIntoEngine := func(ctx context.Context, info *engineInfo) error {
		engine := info.retrieveEngine
		err := engine.Index(ctx, embedder, indexInfo, info.retrieverType)
		if err != nil {
			logger.Errorf(ctx, "Repository %s failed to save: %v", engine.EngineType(), err)
		}
		return err
	}
	err := c.concurrentExecWithError(ctx, indexIntoEngine)

	span.RecordError(err)
	span.SetAttributes(
		attribute.String("embedder", embedder.GetModelName()),
		attribute.String("source_id", indexInfo.SourceID),
	)
	return err
}
// BatchIndex saves the vector embeddings for indexInfoList to every registered
// engine concurrently. Before fan-out, the list is passed through
// common.Deduplicate keyed on SourceID.
//
// Each engine failure is logged, and the first one is returned. The call is
// traced as "CompositeRetrieveEngine.BatchIndex", with the embedder's model
// name and the deduplicated entry count attached to the span.
func (c *CompositeRetrieveEngine) BatchIndex(ctx context.Context,
	embedder embedding.Embedder, indexInfoList []*types.IndexInfo,
) error {
	ctx, span := tracing.ContextWithSpan(ctx, "CompositeRetrieveEngine.BatchIndex")
	defer span.End()

	bySourceID := func(info *types.IndexInfo) string { return info.SourceID }
	indexInfoList = common.Deduplicate(bySourceID, indexInfoList...)

	batchIndexIntoEngine := func(ctx context.Context, info *engineInfo) error {
		engine := info.retrieveEngine
		err := engine.BatchIndex(ctx, embedder, indexInfoList, info.retrieverType)
		if err != nil {
			logger.Errorf(ctx, "Repository %s failed to batch save: %v", engine.EngineType(), err)
		}
		return err
	}
	err := c.concurrentExecWithError(ctx, batchIndexIntoEngine)

	span.RecordError(err)
	span.SetAttributes(
		attribute.String("embedder", embedder.GetModelName()),
		attribute.Int("index_info_count", len(indexInfoList)),
	)
	return err
}
// DeleteByChunkIDList asks every registered engine, concurrently, to delete
// the embeddings for the given chunk IDs and dimension. Each engine failure is
// logged, and the first one is returned.
func (c *CompositeRetrieveEngine) DeleteByChunkIDList(ctx context.Context,
	chunkIDList []string, dimension int,
) error {
	deleteFromEngine := func(ctx context.Context, info *engineInfo) error {
		engine := info.retrieveEngine
		err := engine.DeleteByChunkIDList(ctx, chunkIDList, dimension)
		if err != nil {
			logger.GetLogger(ctx).Errorf("Repository %s failed to delete chunk ID list: %v",
				engine.EngineType(), err)
		}
		return err
	}
	return c.concurrentExecWithError(ctx, deleteFromEngine)
}
// CopyIndices asks every registered engine, concurrently, to copy indices from
// sourceKnowledgeBaseID to targetKnowledgeBaseID, using the supplied
// knowledge-base and chunk ID mappings. Each engine failure is logged, and the
// first one is returned.
func (c *CompositeRetrieveEngine) CopyIndices(
	ctx context.Context,
	sourceKnowledgeBaseID string,
	targetKnowledgeBaseID string,
	sourceToTargetKBIDMap map[string]string,
	sourceToTargetChunkIDMap map[string]string,
	dimension int,
) error {
	copyOnEngine := func(ctx context.Context, info *engineInfo) error {
		engine := info.retrieveEngine
		// The target knowledge base ID is passed after the two ID maps, unlike
		// this method's own parameter order.
		// NOTE(review): confirm this matches the engine interface's order.
		err := engine.CopyIndices(
			ctx,
			sourceKnowledgeBaseID,
			sourceToTargetKBIDMap,
			sourceToTargetChunkIDMap,
			targetKnowledgeBaseID,
			dimension,
		)
		if err != nil {
			logger.Errorf(ctx, "Repository %s failed to copy indices: %v", engine.EngineType(), err)
		}
		return err
	}
	return c.concurrentExecWithError(ctx, copyOnEngine)
}
// DeleteByKnowledgeIDList asks every registered engine, concurrently, to
// delete the embeddings for the given knowledge IDs and dimension. Each engine
// failure is logged, and the first one is returned.
func (c *CompositeRetrieveEngine) DeleteByKnowledgeIDList(ctx context.Context,
	knowledgeIDList []string, dimension int,
) error {
	deleteFromEngine := func(ctx context.Context, info *engineInfo) error {
		engine := info.retrieveEngine
		err := engine.DeleteByKnowledgeIDList(ctx, knowledgeIDList, dimension)
		if err != nil {
			logger.GetLogger(ctx).Errorf("Repository %s failed to delete knowledge ID list: %v",
				engine.EngineType(), err)
		}
		return err
	}
	return c.concurrentExecWithError(ctx, deleteFromEngine)
}
// EstimateStorageSize returns the sum of every registered engine's storage
// estimate for indexInfoList. The engines are queried concurrently, and each
// receives the retriever types it was configured with.
//
// The call is traced as "CompositeRetrieveEngine.EstimateStorageSize", with
// the embedder's model name, the entry count and the total size attached to
// the span.
func (c *CompositeRetrieveEngine) EstimateStorageSize(ctx context.Context,
	embedder embedding.Embedder, indexInfoList []*types.IndexInfo,
) int64 {
	ctx, span := tracing.ContextWithSpan(ctx, "CompositeRetrieveEngine.EstimateStorageSize")
	defer span.End()

	// Engines report concurrently, so the running total must be atomic.
	var total atomic.Int64
	err := c.concurrentExecWithError(ctx, func(ctx context.Context, info *engineInfo) error {
		estimate := info.retrieveEngine.EstimateStorageSize(ctx, embedder, indexInfoList, info.retrieverType)
		total.Add(estimate)
		return nil
	})
	// All engines have finished here, so one read gives the final total.
	size := total.Load()

	span.RecordError(err)
	span.SetAttributes(
		attribute.String("embedder", embedder.GetModelName()),
		attribute.Int("index_info_count", len(indexInfoList)),
		attribute.Int64("storage_size", size),
	)
	return size
}