// File: l_ai_knowledge/internal/models/embedding/ollama.go (126 lines, 3.2 KiB, Go)

package embedding
import (
"context"
"fmt"
"time"
ollamaapi "github.com/ollama/ollama/api"
"knowlege-lsxd/internal/logger"
"knowlege-lsxd/internal/models/utils/ollama"
)
// OllamaEmbedder implements text vectorization using an Ollama backend.
// It satisfies pooling behavior through the embedded EmbedderPooler.
type OllamaEmbedder struct {
	modelName            string                // embedding model name, e.g. "nomic-embed-text"
	truncatePromptTokens int                   // forwarded as the "truncate" option when > 0
	ollamaService        *ollama.OllamaService // client used to pull the model and request embeddings
	dimensions           int                   // vector dimensionality reported by GetDimensions
	modelID              string                // model identifier reported by GetModelID
	EmbedderPooler                             // pooling behavior supplied by the caller
}
// OllamaEmbedRequest represents an Ollama embedding request
type OllamaEmbedRequest struct {
Model string `json:"model"`
Prompt string `json:"prompt"`
TruncatePromptTokens int `json:"truncate_prompt_tokens"`
}
// OllamaEmbedResponse represents an Ollama embedding response
type OllamaEmbedResponse struct {
Embedding []float32 `json:"embedding"`
}
// NewOllamaEmbedder creates a new Ollama embedder.
//
// An empty modelName falls back to "nomic-embed-text" and a zero
// truncatePromptTokens falls back to 511; all other arguments are
// stored as given. The error return is always nil today and exists
// for interface stability.
//
// NOTE(review): baseURL is accepted but never stored or used here —
// confirm whether it should configure ollamaService.
func NewOllamaEmbedder(baseURL,
	modelName string,
	truncatePromptTokens int,
	dimensions int,
	modelID string,
	pooler EmbedderPooler,
	ollamaService *ollama.OllamaService,
) (*OllamaEmbedder, error) {
	const (
		defaultModel    = "nomic-embed-text"
		defaultTruncate = 511
	)
	if modelName == "" {
		modelName = defaultModel
	}
	if truncatePromptTokens == 0 {
		truncatePromptTokens = defaultTruncate
	}
	return &OllamaEmbedder{
		modelName:            modelName,
		truncatePromptTokens: truncatePromptTokens,
		dimensions:           dimensions,
		modelID:              modelID,
		ollamaService:        ollamaService,
		EmbedderPooler:       pooler,
	}, nil
}
// ensureModelAvailable delegates to the Ollama service to make sure
// the configured model is present locally before it is used.
func (e *OllamaEmbedder) ensureModelAvailable(ctx context.Context) error {
	log := logger.GetLogger(ctx)
	log.Infof("Ensuring model %s is available", e.modelName)
	return e.ollamaService.EnsureModelAvailable(ctx, e.modelName)
}
// Embed converts a single text into its embedding vector.
//
// It delegates to BatchEmbed with a one-element batch and returns the
// first resulting vector. An error is returned when the backend call
// fails or yields no vectors.
func (e *OllamaEmbedder) Embed(ctx context.Context, text string) ([]float32, error) {
	embeddings, err := e.BatchEmbed(ctx, []string{text})
	if err != nil {
		return nil, fmt.Errorf("failed to embed text: %w", err)
	}
	// Bug fix: err is guaranteed nil on this branch, so the previous
	// code's %w-wrap of err produced a malformed "%!w(<nil>)" message.
	// Report the empty result explicitly instead.
	if len(embeddings) == 0 {
		return nil, fmt.Errorf("failed to embed text: empty embedding response")
	}
	return embeddings[0], nil
}
// BatchEmbed converts multiple texts to vectors in a single request.
//
// The model is ensured to be available locally first. When a positive
// truncation limit is configured it is forwarded through the request
// options. Returns one vector per input text.
func (e *OllamaEmbedder) BatchEmbed(ctx context.Context, texts []string) ([][]float32, error) {
	// Pull the model first if it is not yet present.
	if err := e.ensureModelAvailable(ctx); err != nil {
		return nil, err
	}

	// Build request options; only set truncation when configured.
	opts := map[string]interface{}{}
	if e.truncatePromptTokens > 0 {
		opts["truncate"] = e.truncatePromptTokens
	}

	req := &ollamaapi.EmbedRequest{
		Model:   e.modelName,
		Input:   texts,
		Options: opts,
	}

	begin := time.Now()
	resp, err := e.ollamaService.Embeddings(ctx, req)
	if err != nil {
		return nil, fmt.Errorf("failed to get embedding vectors: %w", err)
	}
	logger.GetLogger(ctx).Debugf("Embedding vector retrieval took: %v", time.Since(begin))

	return resp.Embeddings, nil
}
// GetModelName returns the configured embedding model name
// (e.g. "nomic-embed-text"). The Get prefix is kept because it is
// part of the existing public interface.
func (e *OllamaEmbedder) GetModelName() string {
	return e.modelName
}
// GetDimensions returns the vector dimensionality this embedder was
// configured with at construction time.
func (e *OllamaEmbedder) GetDimensions() int {
	return e.dimensions
}
// GetModelID returns the model identifier this embedder was
// configured with at construction time.
func (e *OllamaEmbedder) GetModelID() string {
	return e.modelID
}