126 lines
3.2 KiB
Go
126 lines
3.2 KiB
Go
package embedding
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"time"
|
|
|
|
ollamaapi "github.com/ollama/ollama/api"
|
|
"knowlege-lsxd/internal/logger"
|
|
"knowlege-lsxd/internal/models/utils/ollama"
|
|
)
|
|
|
|
// OllamaEmbedder implements text vectorization functionality using Ollama
|
|
type OllamaEmbedder struct {
|
|
modelName string
|
|
truncatePromptTokens int
|
|
ollamaService *ollama.OllamaService
|
|
dimensions int
|
|
modelID string
|
|
EmbedderPooler
|
|
}
|
|
|
|
// OllamaEmbedRequest is the JSON shape of a single-prompt Ollama embedding
// request. No code visible in this file references it.
type OllamaEmbedRequest struct {
	// Model is serialized under the "model" key.
	Model string `json:"model"`
	// Prompt is serialized under the "prompt" key.
	Prompt string `json:"prompt"`
	// TruncatePromptTokens is serialized under the "truncate_prompt_tokens" key.
	TruncatePromptTokens int `json:"truncate_prompt_tokens"`
}
|
|
|
|
// OllamaEmbedResponse is the JSON shape of a single-vector Ollama embedding
// response. No code visible in this file references it.
type OllamaEmbedResponse struct {
	// Embedding is decoded from the "embedding" key.
	Embedding []float32 `json:"embedding"`
}
|
|
|
|
// NewOllamaEmbedder creates a new Ollama embedder
|
|
func NewOllamaEmbedder(baseURL,
|
|
modelName string,
|
|
truncatePromptTokens int,
|
|
dimensions int,
|
|
modelID string,
|
|
pooler EmbedderPooler,
|
|
ollamaService *ollama.OllamaService,
|
|
) (*OllamaEmbedder, error) {
|
|
if modelName == "" {
|
|
modelName = "nomic-embed-text"
|
|
}
|
|
|
|
if truncatePromptTokens == 0 {
|
|
truncatePromptTokens = 511
|
|
}
|
|
|
|
return &OllamaEmbedder{
|
|
modelName: modelName,
|
|
truncatePromptTokens: truncatePromptTokens,
|
|
ollamaService: ollamaService,
|
|
EmbedderPooler: pooler,
|
|
dimensions: dimensions,
|
|
modelID: modelID,
|
|
}, nil
|
|
}
|
|
|
|
// ensureModelAvailable ensures that the model is available
|
|
func (e *OllamaEmbedder) ensureModelAvailable(ctx context.Context) error {
|
|
logger.GetLogger(ctx).Infof("Ensuring model %s is available", e.modelName)
|
|
return e.ollamaService.EnsureModelAvailable(ctx, e.modelName)
|
|
}
|
|
|
|
// Embed converts text to vector
|
|
func (e *OllamaEmbedder) Embed(ctx context.Context, text string) ([]float32, error) {
|
|
embedding, err := e.BatchEmbed(ctx, []string{text})
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to embed text: %w", err)
|
|
}
|
|
|
|
if len(embedding) == 0 {
|
|
return nil, fmt.Errorf("failed to embed text: %w", err)
|
|
}
|
|
|
|
return embedding[0], nil
|
|
}
|
|
|
|
// BatchEmbed converts multiple texts to vectors in batch
|
|
func (e *OllamaEmbedder) BatchEmbed(ctx context.Context, texts []string) ([][]float32, error) {
|
|
// Ensure model is available
|
|
if err := e.ensureModelAvailable(ctx); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Create request
|
|
req := &ollamaapi.EmbedRequest{
|
|
Model: e.modelName,
|
|
Input: texts,
|
|
Options: make(map[string]interface{}),
|
|
}
|
|
|
|
// Set truncation parameters
|
|
if e.truncatePromptTokens > 0 {
|
|
req.Options["truncate"] = e.truncatePromptTokens
|
|
}
|
|
|
|
// Send request
|
|
startTime := time.Now()
|
|
resp, err := e.ollamaService.Embeddings(ctx, req)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to get embedding vectors: %w", err)
|
|
}
|
|
|
|
logger.GetLogger(ctx).Debugf("Embedding vector retrieval took: %v", time.Since(startTime))
|
|
return resp.Embeddings, nil
|
|
}
|
|
|
|
// GetModelName returns the model name
|
|
func (e *OllamaEmbedder) GetModelName() string {
|
|
return e.modelName
|
|
}
|
|
|
|
// GetDimensions returns the vector dimensions
|
|
func (e *OllamaEmbedder) GetDimensions() int {
|
|
return e.dimensions
|
|
}
|
|
|
|
// GetModelID returns the model ID
|
|
func (e *OllamaEmbedder) GetModelID() string {
|
|
return e.modelID
|
|
}
|