l_ai_knowledge/internal/models/embedding/batch.go

84 lines
1.8 KiB
Go

package embedding
import (
	"context"
	"fmt"
	"sync"

	"github.com/panjf2000/ants/v2"
	"knowlege-lsxd/internal/models/utils"
)
// batchEmbedder embeds texts in batches, running each batch as a task on a
// goroutine pool. It is the value NewBatchEmbedder returns as an EmbedderPooler.
type batchEmbedder struct {
// pool is the ants goroutine pool that BatchEmbedWithPool submits batch tasks to.
pool *ants.Pool
}
// NewBatchEmbedder returns an EmbedderPooler backed by a batchEmbedder that
// submits its embedding tasks to the given ants goroutine pool. The pool is
// stored as-is; it is not copied or validated here.
func NewBatchEmbedder(pool *ants.Pool) EmbedderPooler {
	embedder := new(batchEmbedder)
	embedder.pool = pool
	return embedder
}
// textEmbedding pairs one input text with the embedding computed for it, so a
// batch task can write its result back through a shared pointer.
type textEmbedding struct {
// text is the input string sent to the embedding model.
text string
// results is the embedding vector for text; it stays nil until
// BatchEmbedWithPool assigns the model output to it.
results []float32
}
// BatchEmbedWithPool embeds texts concurrently: the texts are split into
// fixed-size batches and model.BatchEmbed is called once per batch from a task
// submitted to the embedder's goroutine pool.
//
// On success it returns the embeddings collected from the per-text wrappers
// built from texts. If a batch fails, returns a different number of vectors
// than it was given texts, or cannot be submitted to the pool, the first
// recorded error is returned together with a nil result.
//
// The function always waits for every task it managed to submit before
// returning, so no task is still using model or ctx once the call has returned.
func (e *batchEmbedder) BatchEmbedWithPool(ctx context.Context, model Embedder, texts []string) ([][]float32, error) {
	// Number of texts sent to the model in a single BatchEmbed call.
	const batchSize = 5

	var (
		wg       sync.WaitGroup
		mu       sync.Mutex // guards firstErr and the result write-back
		firstErr error      // first error reported by a batch or by Submit
	)

	// recordErr stores err only if no earlier error has been recorded.
	recordErr := func(err error) {
		mu.Lock()
		defer mu.Unlock()
		if firstErr == nil {
			firstErr = err
		}
	}
	// failed reports whether an error has been recorded. Reading firstErr
	// under the lock avoids racing with workers that are writing it.
	failed := func() bool {
		mu.Lock()
		defer mu.Unlock()
		return firstErr != nil
	}

	// Wrap every text so each batch can write its vector back in place.
	items := utils.MapSlice(texts, func(text string) *textEmbedding {
		return &textEmbedding{text: text}
	})

	// processChunk builds the pool task that embeds one batch.
	processChunk := func(chunk []*textEmbedding) func() {
		return func() {
			defer wg.Done()

			// Skip the model call once another batch has already failed.
			if failed() {
				return
			}

			vectors, err := model.BatchEmbed(ctx, utils.MapSlice(chunk, func(item *textEmbedding) string {
				return item.text
			}))
			if err != nil {
				recordErr(err)
				return
			}
			// Guard the index below: a short response would otherwise panic
			// inside the pool worker instead of surfacing as an error.
			if len(vectors) != len(chunk) {
				recordErr(fmt.Errorf("batch embed returned %d embeddings for %d texts", len(vectors), len(chunk)))
				return
			}

			mu.Lock()
			for i, item := range chunk {
				item.results = vectors[i]
			}
			mu.Unlock()
		}
	}

	// Submit one task per batch, stopping as soon as anything has failed.
	for _, chunk := range utils.ChunkSlice(items, batchSize) {
		if failed() {
			break
		}
		wg.Add(1)
		if err := e.pool.Submit(processChunk(chunk)); err != nil {
			// The task will never run, so balance the Add ourselves.
			wg.Done()
			recordErr(err)
			break
		}
	}

	// Wait for every submitted task; after this no goroutine writes firstErr
	// or the items, so they can be read without the lock.
	wg.Wait()

	if firstErr != nil {
		return nil, firstErr
	}

	results := utils.MapSlice(items, func(item *textEmbedding) []float32 {
		return item.results
	})
	return results, nil
}