84 lines
1.8 KiB
Go
84 lines
1.8 KiB
Go
package embedding
|
|
|
|
import (
	"context"
	"fmt"
	"sync"

	"github.com/panjf2000/ants/v2"

	"knowlege-lsxd/internal/models/utils"
)
|
|
|
|
type batchEmbedder struct {
|
|
pool *ants.Pool
|
|
}
|
|
|
|
func NewBatchEmbedder(pool *ants.Pool) EmbedderPooler {
|
|
return &batchEmbedder{pool: pool}
|
|
}
|
|
|
|
// textEmbedding pairs one input text with the vector produced for it, so
// BatchEmbedWithPool can fill results per batch and still return them in
// the original input order.
type textEmbedding struct {
	text    string    // input text sent to the embedding model
	results []float32 // embedding vector; nil until its batch completes
}
|
|
|
|
func (e *batchEmbedder) BatchEmbedWithPool(ctx context.Context, model Embedder, texts []string) ([][]float32, error) {
|
|
// Create goroutine pool for concurrent processing of document chunks
|
|
var wg sync.WaitGroup
|
|
var mu sync.Mutex // For synchronizing access to error
|
|
var firstErr error // Record the first error that occurs
|
|
batchSize := 5
|
|
textEmbeddings := utils.MapSlice(texts, func(text string) *textEmbedding {
|
|
return &textEmbedding{text: text}
|
|
})
|
|
|
|
// Function to process each document chunk
|
|
processChunk := func(texts []*textEmbedding) func() {
|
|
return func() {
|
|
defer wg.Done()
|
|
// If an error has already occurred, don't continue processing
|
|
if firstErr != nil {
|
|
return
|
|
}
|
|
// Embed text
|
|
embedding, err := model.BatchEmbed(ctx, utils.MapSlice(texts, func(text *textEmbedding) string {
|
|
return text.text
|
|
}))
|
|
if err != nil {
|
|
mu.Lock()
|
|
if firstErr == nil {
|
|
firstErr = err
|
|
}
|
|
mu.Unlock()
|
|
return
|
|
}
|
|
mu.Lock()
|
|
for i, text := range texts {
|
|
text.results = embedding[i]
|
|
}
|
|
mu.Unlock()
|
|
}
|
|
}
|
|
|
|
// Submit all tasks to the goroutine pool
|
|
for _, texts := range utils.ChunkSlice(textEmbeddings, batchSize) {
|
|
wg.Add(1)
|
|
err := e.pool.Submit(processChunk(texts))
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
// Wait for all tasks to complete
|
|
wg.Wait()
|
|
|
|
// Check if any errors occurred
|
|
if firstErr != nil {
|
|
return nil, firstErr
|
|
}
|
|
|
|
results := utils.MapSlice(textEmbeddings, func(text *textEmbedding) []float32 {
|
|
return text.results
|
|
})
|
|
return results, nil
|
|
}
|