290 lines
9.2 KiB
Go
290 lines
9.2 KiB
Go
package retriever
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"maps"
|
|
"slices"
|
|
"sync"
|
|
"sync/atomic"
|
|
|
|
"go.opentelemetry.io/otel/attribute"
|
|
"knowlege-lsxd/internal/common"
|
|
"knowlege-lsxd/internal/logger"
|
|
"knowlege-lsxd/internal/models/embedding"
|
|
"knowlege-lsxd/internal/runtime"
|
|
"knowlege-lsxd/internal/tracing"
|
|
"knowlege-lsxd/internal/types"
|
|
"knowlege-lsxd/internal/types/interfaces"
|
|
)
|
|
|
|
// engineInfo holds information about a retrieve engine and its supported retriever types
|
|
type engineInfo struct {
|
|
retrieveEngine interfaces.RetrieveEngineService
|
|
retrieverType []types.RetrieverType
|
|
}
|
|
|
|
// CompositeRetrieveEngine implements a composite pattern for retrieval engines,
|
|
// delegating operations to all registered engines
|
|
type CompositeRetrieveEngine struct {
|
|
engineInfos []*engineInfo
|
|
}
|
|
|
|
// Retrieve performs retrieval operations by delegating to the appropriate engine
|
|
// based on the retriever type specified in the parameters
|
|
func (c *CompositeRetrieveEngine) Retrieve(ctx context.Context,
|
|
retrieveParams []types.RetrieveParams,
|
|
) ([]*types.RetrieveResult, error) {
|
|
return concurrentRetrieve(ctx, retrieveParams,
|
|
func(ctx context.Context, param types.RetrieveParams, results *[]*types.RetrieveResult, mu *sync.Mutex) error {
|
|
found := false
|
|
for _, engineInfo := range c.engineInfos {
|
|
if slices.Contains(engineInfo.retrieverType, param.RetrieverType) {
|
|
result, err := engineInfo.retrieveEngine.Retrieve(ctx, param)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
mu.Lock()
|
|
*results = append(*results, result...)
|
|
mu.Unlock()
|
|
found = true
|
|
break
|
|
}
|
|
}
|
|
if !found {
|
|
return fmt.Errorf("retriever type %s not found", param.RetrieverType)
|
|
}
|
|
return nil
|
|
},
|
|
)
|
|
}
|
|
|
|
// NewCompositeRetrieveEngine creates a new composite retrieve engine with the given parameters
|
|
func NewCompositeRetrieveEngine(engineParams []types.RetrieverEngineParams) (*CompositeRetrieveEngine, error) {
|
|
var registry interfaces.RetrieveEngineRegistry
|
|
runtime.GetContainer().Invoke(func(r interfaces.RetrieveEngineRegistry) {
|
|
registry = r
|
|
})
|
|
engineInfos := make(map[types.RetrieverEngineType]*engineInfo)
|
|
for _, engineParam := range engineParams {
|
|
repo, err := registry.GetRetrieveEngineService(engineParam.RetrieverEngineType)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if !slices.Contains(repo.Support(), engineParam.RetrieverType) {
|
|
return nil, fmt.Errorf("retrieval engine %s does not support retriever type: %s",
|
|
repo.EngineType(), engineParam.RetrieverType)
|
|
}
|
|
if _, exists := engineInfos[repo.EngineType()]; exists {
|
|
engineInfos[repo.EngineType()].retrieverType = append(engineInfos[repo.EngineType()].retrieverType,
|
|
engineParam.RetrieverType)
|
|
continue
|
|
}
|
|
engineInfos[repo.EngineType()] = &engineInfo{
|
|
retrieveEngine: repo,
|
|
retrieverType: []types.RetrieverType{engineParam.RetrieverType},
|
|
}
|
|
}
|
|
return &CompositeRetrieveEngine{engineInfos: slices.Collect(maps.Values(engineInfos))}, nil
|
|
}
|
|
|
|
// SupportRetriever checks if a retriever type is supported by any of the registered engines
|
|
func (c *CompositeRetrieveEngine) SupportRetriever(r types.RetrieverType) bool {
|
|
for _, engineInfo := range c.engineInfos {
|
|
if slices.Contains(engineInfo.retrieverType, r) {
|
|
return true
|
|
}
|
|
}
|
|
return false
|
|
}
|
|
|
|
// concurrentRetrieve is a helper function for concurrent processing of retrieval parameters
|
|
// and collecting results
|
|
func concurrentRetrieve(
|
|
ctx context.Context,
|
|
retrieveParams []types.RetrieveParams,
|
|
fn func(ctx context.Context, param types.RetrieveParams, results *[]*types.RetrieveResult, mu *sync.Mutex) error,
|
|
) ([]*types.RetrieveResult, error) {
|
|
var results []*types.RetrieveResult
|
|
var mu sync.Mutex
|
|
var wg sync.WaitGroup
|
|
errCh := make(chan error, len(retrieveParams))
|
|
|
|
for _, param := range retrieveParams {
|
|
wg.Add(1)
|
|
p := param // Create local copy for safe use in closure
|
|
go func() {
|
|
defer wg.Done()
|
|
if err := fn(ctx, p, &results, &mu); err != nil {
|
|
errCh <- err
|
|
}
|
|
}()
|
|
}
|
|
|
|
wg.Wait()
|
|
close(errCh)
|
|
|
|
// Check for errors
|
|
for err := range errCh {
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
return results, nil
|
|
}
|
|
|
|
// concurrentExecWithError is a generic function for concurrent execution of operations
|
|
// and handling errors
|
|
func (c *CompositeRetrieveEngine) concurrentExecWithError(
|
|
ctx context.Context,
|
|
fn func(ctx context.Context, engineInfo *engineInfo) error,
|
|
) error {
|
|
var wg sync.WaitGroup
|
|
errCh := make(chan error, len(c.engineInfos))
|
|
|
|
for _, engineInfo := range c.engineInfos {
|
|
wg.Add(1)
|
|
eng := engineInfo // Create local copy for safe use in closure
|
|
go func() {
|
|
defer wg.Done()
|
|
if err := fn(ctx, eng); err != nil {
|
|
errCh <- err
|
|
}
|
|
}()
|
|
}
|
|
|
|
wg.Wait()
|
|
close(errCh)
|
|
|
|
// Return the first error (if any)
|
|
for err := range errCh {
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// Index saves vector embeddings to all registered repositories
|
|
func (c *CompositeRetrieveEngine) Index(ctx context.Context,
|
|
embedder embedding.Embedder, indexInfo *types.IndexInfo,
|
|
) error {
|
|
ctx, span := tracing.ContextWithSpan(ctx, "CompositeRetrieveEngine.Index")
|
|
defer span.End()
|
|
err := c.concurrentExecWithError(ctx, func(ctx context.Context, engineInfo *engineInfo) error {
|
|
if err := engineInfo.retrieveEngine.Index(ctx, embedder, indexInfo, engineInfo.retrieverType); err != nil {
|
|
logger.Errorf(ctx, "Repository %s failed to save: %v", engineInfo.retrieveEngine.EngineType(), err)
|
|
return err
|
|
}
|
|
return nil
|
|
})
|
|
span.RecordError(err)
|
|
span.SetAttributes(
|
|
attribute.String("embedder", embedder.GetModelName()),
|
|
attribute.String("source_id", indexInfo.SourceID),
|
|
)
|
|
return err
|
|
}
|
|
|
|
// BatchIndex batch saves vector embeddings to all registered repositories
|
|
func (c *CompositeRetrieveEngine) BatchIndex(ctx context.Context,
|
|
embedder embedding.Embedder, indexInfoList []*types.IndexInfo,
|
|
) error {
|
|
ctx, span := tracing.ContextWithSpan(ctx, "CompositeRetrieveEngine.BatchIndex")
|
|
defer span.End()
|
|
// Deduplicate sourceIDs
|
|
indexInfoList = common.Deduplicate(func(info *types.IndexInfo) string { return info.SourceID }, indexInfoList...)
|
|
err := c.concurrentExecWithError(ctx, func(ctx context.Context, engineInfo *engineInfo) error {
|
|
if err := engineInfo.retrieveEngine.BatchIndex(
|
|
ctx,
|
|
embedder,
|
|
indexInfoList,
|
|
engineInfo.retrieverType,
|
|
); err != nil {
|
|
logger.Errorf(ctx, "Repository %s failed to batch save: %v", engineInfo.retrieveEngine.EngineType(), err)
|
|
return err
|
|
}
|
|
return nil
|
|
})
|
|
span.RecordError(err)
|
|
span.SetAttributes(
|
|
attribute.String("embedder", embedder.GetModelName()),
|
|
attribute.Int("index_info_count", len(indexInfoList)),
|
|
)
|
|
return err
|
|
}
|
|
|
|
// DeleteByChunkIDList deletes vector embeddings by chunk ID list from all registered repositories
|
|
func (c *CompositeRetrieveEngine) DeleteByChunkIDList(ctx context.Context,
|
|
chunkIDList []string, dimension int,
|
|
) error {
|
|
return c.concurrentExecWithError(ctx, func(ctx context.Context, engineInfo *engineInfo) error {
|
|
if err := engineInfo.retrieveEngine.DeleteByChunkIDList(ctx, chunkIDList, dimension); err != nil {
|
|
logger.GetLogger(ctx).Errorf("Repository %s failed to delete chunk ID list: %v",
|
|
engineInfo.retrieveEngine.EngineType(), err)
|
|
return err
|
|
}
|
|
return nil
|
|
})
|
|
}
|
|
|
|
// CopyIndices copies indices from a source knowledge base to a target knowledge base
|
|
func (c *CompositeRetrieveEngine) CopyIndices(
|
|
ctx context.Context,
|
|
sourceKnowledgeBaseID string,
|
|
targetKnowledgeBaseID string,
|
|
sourceToTargetKBIDMap map[string]string,
|
|
sourceToTargetChunkIDMap map[string]string,
|
|
dimension int,
|
|
) error {
|
|
return c.concurrentExecWithError(ctx, func(ctx context.Context, engineInfo *engineInfo) error {
|
|
if err := engineInfo.retrieveEngine.CopyIndices(
|
|
ctx,
|
|
sourceKnowledgeBaseID,
|
|
sourceToTargetKBIDMap,
|
|
sourceToTargetChunkIDMap,
|
|
targetKnowledgeBaseID,
|
|
dimension,
|
|
); err != nil {
|
|
logger.Errorf(ctx, "Repository %s failed to copy indices: %v", engineInfo.retrieveEngine.EngineType(), err)
|
|
return err
|
|
}
|
|
return nil
|
|
})
|
|
}
|
|
|
|
// DeleteByKnowledgeIDList deletes vector embeddings by knowledge ID list from all registered repositories
|
|
func (c *CompositeRetrieveEngine) DeleteByKnowledgeIDList(ctx context.Context,
|
|
knowledgeIDList []string, dimension int,
|
|
) error {
|
|
return c.concurrentExecWithError(ctx, func(ctx context.Context, engineInfo *engineInfo) error {
|
|
if err := engineInfo.retrieveEngine.DeleteByKnowledgeIDList(ctx, knowledgeIDList, dimension); err != nil {
|
|
logger.GetLogger(ctx).Errorf("Repository %s failed to delete knowledge ID list: %v",
|
|
engineInfo.retrieveEngine.EngineType(), err)
|
|
return err
|
|
}
|
|
return nil
|
|
})
|
|
}
|
|
|
|
// EstimateStorageSize estimates the storage size required for the provided index information
|
|
func (c *CompositeRetrieveEngine) EstimateStorageSize(ctx context.Context,
|
|
embedder embedding.Embedder, indexInfoList []*types.IndexInfo,
|
|
) int64 {
|
|
ctx, span := tracing.ContextWithSpan(ctx, "CompositeRetrieveEngine.EstimateStorageSize")
|
|
defer span.End()
|
|
sum := atomic.Int64{}
|
|
err := c.concurrentExecWithError(ctx, func(ctx context.Context, engineInfo *engineInfo) error {
|
|
sum.Add(engineInfo.retrieveEngine.EstimateStorageSize(ctx, embedder, indexInfoList, engineInfo.retrieverType))
|
|
return nil
|
|
})
|
|
span.RecordError(err)
|
|
span.SetAttributes(
|
|
attribute.String("embedder", embedder.GetModelName()),
|
|
attribute.Int("index_info_count", len(indexInfoList)),
|
|
attribute.Int64("storage_size", sum.Load()),
|
|
)
|
|
return sum.Load()
|
|
}
|