l_ai_knowledge/internal/models/embedding/ollama.go

134 lines
3.5 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package embedding
import (
"context"
"fmt"
"time"
"knowlege-lsxd/internal/logger"
"knowlege-lsxd/internal/models/utils/ollama"
ollamaapi "github.com/ollama/ollama/api"
)
// OllamaEmbedder implements text vectorization functionality using Ollama.
// It embeds EmbedderPooler, so that type's methods are promoted onto the embedder.
type OllamaEmbedder struct {
// modelName is the model passed to the availability check and to embed requests.
modelName string
// truncatePromptTokens is copied into the request's "truncate" option when it is positive.
truncatePromptTokens int
// ollamaService performs the model availability check and the embedding call.
ollamaService *ollama.OllamaService
// dimensions is the value reported by GetDimensions; this file stores it without validating it.
dimensions int
// modelID is the identifier reported by GetModelID.
modelID string
// EmbedderPooler is declared outside this file.
EmbedderPooler
}
// OllamaEmbedRequest represents an Ollama embedding request. Its fields are
// serialized to JSON as "model", "prompt" and "truncate_prompt_tokens".
// NOTE(review): nothing in the visible part of this file references this type;
// BatchEmbed builds an ollamaapi.EmbedRequest instead. Confirm it is still used.
type OllamaEmbedRequest struct {
Model string `json:"model"`
Prompt string `json:"prompt"`
TruncatePromptTokens int `json:"truncate_prompt_tokens"`
}
// OllamaEmbedResponse represents an Ollama embedding response holding a single
// vector, read from the JSON key "embedding".
// NOTE(review): nothing in the visible part of this file references this type;
// BatchEmbed reads the Embeddings field of the service response instead.
// Confirm it is still used.
type OllamaEmbedResponse struct {
Embedding []float32 `json:"embedding"`
}
// NewOllamaEmbedder builds an OllamaEmbedder from the supplied configuration.
//
// An empty modelName falls back to "nomic-embed-text" and a zero
// truncatePromptTokens falls back to 511. All other arguments are stored as
// given. baseURL is accepted but not used by this constructor.
//
// The returned error is always nil.
func NewOllamaEmbedder(baseURL,
	modelName string,
	truncatePromptTokens int,
	dimensions int,
	modelID string,
	pooler EmbedderPooler,
	ollamaService *ollama.OllamaService,
) (*OllamaEmbedder, error) {
	// Start from the defaults and override them only when the caller
	// provided an explicit value.
	embedder := &OllamaEmbedder{
		modelName:            "nomic-embed-text",
		truncatePromptTokens: 511,
		ollamaService:        ollamaService,
		dimensions:           dimensions,
		modelID:              modelID,
		EmbedderPooler:       pooler,
	}
	if modelName != "" {
		embedder.modelName = modelName
	}
	if truncatePromptTokens != 0 {
		embedder.truncatePromptTokens = truncatePromptTokens
	}
	return embedder, nil
}
// ensureModelAvailable logs the configured model name at info level and then
// delegates to the Ollama service's EnsureModelAvailable, returning its result.
func (e *OllamaEmbedder) ensureModelAvailable(ctx context.Context) error {
	model := e.modelName

	lg := logger.GetLogger(ctx)
	lg.Infof("Ensuring model %s is available", model)

	return e.ollamaService.EnsureModelAvailable(ctx, model)
}
// Embed converts a single text into its embedding vector.
//
// It delegates to BatchEmbed with a one-element batch and returns the first
// vector of the result.
//
// It returns an error wrapping the BatchEmbed failure, or a plain error when
// the call succeeded but produced no vectors.
func (e *OllamaEmbedder) Embed(ctx context.Context, text string) ([]float32, error) {
	embeddings, err := e.BatchEmbed(ctx, []string{text})
	if err != nil {
		return nil, fmt.Errorf("failed to embed text: %w", err)
	}
	// err is nil on this path, so wrapping it with %w would yield a malformed
	// message and wrap nothing; report the empty result explicitly instead.
	if len(embeddings) == 0 {
		return nil, fmt.Errorf("failed to embed text: no embedding returned")
	}
	return embeddings[0], nil
}
// BatchEmbed generates embedding vectors for a list of texts in one request.
//
// Parameters:
//   - ctx: context controlling the lifetime of the request
//   - texts: the texts to generate embedding vectors for
//
// Returns:
//   - [][]float32: the generated embedding vectors, one per input text
//   - error: the failure, or nil on success
//
// An empty texts slice yields an empty result without checking the model or
// calling the service. A response whose vector count differs from len(texts)
// is reported as an error, so callers can index the result by input position.
func (e *OllamaEmbedder) BatchEmbed(ctx context.Context, texts []string) ([][]float32, error) {
	// Nothing to embed: skip the model check and the service round trip.
	if len(texts) == 0 {
		return [][]float32{}, nil
	}

	// Ensure model is available
	if err := e.ensureModelAvailable(ctx); err != nil {
		return nil, err
	}

	// Create request
	req := &ollamaapi.EmbedRequest{
		Model:   e.modelName,
		Input:   texts,
		Options: make(map[string]interface{}),
	}

	// Set truncation parameters
	// NOTE(review): the Ollama embed API appears to expose truncation as a
	// boolean request field rather than a model option; confirm that an
	// integer "truncate" entry in Options has the intended effect.
	if e.truncatePromptTokens > 0 {
		req.Options["truncate"] = e.truncatePromptTokens
	}

	// Send request
	startTime := time.Now()
	resp, err := e.ollamaService.Embeddings(ctx, req)
	if err != nil {
		return nil, fmt.Errorf("failed to get embedding vectors: %w", err)
	}
	logger.GetLogger(ctx).Debugf("Embedding vector retrieval took: %v", time.Since(startTime))

	// Enforce the one-vector-per-text contract before handing the result back.
	if got, want := len(resp.Embeddings), len(texts); got != want {
		return nil, fmt.Errorf("failed to get embedding vectors: got %d vectors for %d texts", got, want)
	}
	return resp.Embeddings, nil
}
// GetModelName returns the model name this embedder sends to the Ollama
// service for availability checks and embed requests.
func (e *OllamaEmbedder) GetModelName() string {
return e.modelName
}
// GetDimensions returns the configured vector dimensions. The value is the
// one stored on the embedder (NewOllamaEmbedder sets it from its dimensions
// argument); it is not derived from the vectors the service returns.
func (e *OllamaEmbedder) GetDimensions() int {
return e.dimensions
}
// GetModelID returns the model ID stored on the embedder (NewOllamaEmbedder
// sets it from its modelID argument).
func (e *OllamaEmbedder) GetModelID() string {
return e.modelID
}