package embedding

import (
	"context"
	"fmt"
	"sync"

	"github.com/panjf2000/ants/v2"

	"knowlege-lsxd/internal/models/utils"
)

// defaultBatchSize is the number of texts sent to the model in a single
// BatchEmbed call.
const defaultBatchSize = 5

// batchEmbedder splits texts into fixed-size batches and embeds each batch
// as a task on a caller-supplied ants goroutine pool.
type batchEmbedder struct {
	pool *ants.Pool
}

// NewBatchEmbedder returns an EmbedderPooler that runs embedding batches on
// pool. The pool is not owned by the returned value and is never released by it.
func NewBatchEmbedder(pool *ants.Pool) EmbedderPooler {
	return &batchEmbedder{pool: pool}
}

// textEmbedding pairs one input text with the vector produced for it.
type textEmbedding struct {
	text    string
	results []float32
}

// BatchEmbedWithPool embeds texts with model, splitting them into batches of
// defaultBatchSize and running each batch as a task on e.pool.
//
// The returned slice has one vector per input text, in input order. If any
// batch fails, the context is cancelled, the model returns a vector count that
// does not match its batch, or a task cannot be submitted to the pool, the
// first such error is returned and the results are discarded. The function
// always waits for every submitted task to finish before returning.
func (e *batchEmbedder) BatchEmbedWithPool(ctx context.Context, model Embedder, texts []string) ([][]float32, error) {
	var (
		wg       sync.WaitGroup
		mu       sync.Mutex // guards firstErr while tasks are running
		firstErr error      // first error observed by any task or by submission
	)

	// setErr records err if no earlier error has been recorded.
	setErr := func(err error) {
		mu.Lock()
		if firstErr == nil {
			firstErr = err
		}
		mu.Unlock()
	}
	// failed reports whether an error has already been recorded. It reads
	// firstErr under the lock so concurrent tasks do not race with setErr.
	failed := func() bool {
		mu.Lock()
		defer mu.Unlock()
		return firstErr != nil
	}

	textEmbeddings := utils.MapSlice(texts, func(text string) *textEmbedding {
		return &textEmbedding{text: text}
	})

	// processChunk builds the pool task that embeds one batch.
	processChunk := func(chunk []*textEmbedding) func() {
		return func() {
			defer wg.Done()

			// Skip work once another batch has failed.
			if failed() {
				return
			}
			if err := ctx.Err(); err != nil {
				setErr(err)
				return
			}

			embeddings, err := model.BatchEmbed(ctx, utils.MapSlice(chunk, func(t *textEmbedding) string {
				return t.text
			}))
			if err != nil {
				setErr(err)
				return
			}
			// Guard against a model that returns the wrong number of vectors;
			// indexing blindly would panic inside the pool worker or leave
			// some texts without an embedding.
			if len(embeddings) != len(chunk) {
				setErr(fmt.Errorf("embedding model returned %d vectors for %d texts", len(embeddings), len(chunk)))
				return
			}

			// Each task writes only the elements of its own chunk. The reads
			// happen after wg.Wait, which orders them after these writes.
			for i, t := range chunk {
				t.results = embeddings[i]
			}
		}
	}

	for _, chunk := range utils.ChunkSlice(textEmbeddings, defaultBatchSize) {
		wg.Add(1)
		if err := e.pool.Submit(processChunk(chunk)); err != nil {
			// The task never ran, so release its WaitGroup slot ourselves.
			// Stop submitting, but still wait for tasks that are in flight.
			wg.Done()
			setErr(err)
			break
		}
	}

	wg.Wait()

	// All tasks have finished, so firstErr can be read without the lock.
	if firstErr != nil {
		return nil, firstErr
	}

	results := utils.MapSlice(textEmbeddings, func(t *textEmbedding) []float32 {
		return t.results
	})
	return results, nil
}