134 lines
3.5 KiB
Go
134 lines
3.5 KiB
Go
package embedding
|
||
|
||
import (
|
||
"context"
|
||
"fmt"
|
||
"time"
|
||
|
||
"knowlege-lsxd/internal/logger"
|
||
"knowlege-lsxd/internal/models/utils/ollama"
|
||
|
||
ollamaapi "github.com/ollama/ollama/api"
|
||
)
|
||
|
||
// OllamaEmbedder implements text vectorization functionality using Ollama.
//
// Embedding requests and model-availability checks are delegated to an
// ollama.OllamaService. Instances are built by NewOllamaEmbedder, which fills
// in defaults for the model name and truncation setting.
type OllamaEmbedder struct {
	// modelName is the Ollama model sent with every embed request.
	modelName string

	// truncatePromptTokens is forwarded as the "truncate" request option by
	// BatchEmbed when it is positive.
	truncatePromptTokens int

	// ollamaService performs the model availability check and the embed call.
	ollamaService *ollama.OllamaService

	// dimensions is the value reported by GetDimensions, as supplied to the
	// constructor; it is not checked against returned vectors in this file.
	dimensions int

	// modelID is the opaque identifier reported by GetModelID.
	modelID string

	EmbedderPooler
}
|
||
|
||
// OllamaEmbedRequest represents an Ollama embedding request with a single
// prompt and a token-based truncation limit.
//
// NOTE(review): nothing in this file uses this type — BatchEmbed builds an
// ollamaapi.EmbedRequest instead. Confirm whether other code still needs it.
type OllamaEmbedRequest struct {
	Model                string `json:"model"`
	Prompt               string `json:"prompt"`
	TruncatePromptTokens int    `json:"truncate_prompt_tokens"`
}
|
||
|
||
// OllamaEmbedResponse represents an Ollama embedding response carrying a
// single embedding vector.
//
// NOTE(review): nothing in this file uses this type — BatchEmbed reads
// resp.Embeddings from ollamaService.Embeddings instead. Confirm whether
// other code still needs it.
type OllamaEmbedResponse struct {
	Embedding []float32 `json:"embedding"`
}
|
||
|
||
// NewOllamaEmbedder creates a new Ollama embedder
|
||
func NewOllamaEmbedder(baseURL,
|
||
modelName string,
|
||
truncatePromptTokens int,
|
||
dimensions int,
|
||
modelID string,
|
||
pooler EmbedderPooler,
|
||
ollamaService *ollama.OllamaService,
|
||
) (*OllamaEmbedder, error) {
|
||
if modelName == "" {
|
||
modelName = "nomic-embed-text"
|
||
}
|
||
|
||
if truncatePromptTokens == 0 {
|
||
truncatePromptTokens = 511
|
||
}
|
||
|
||
return &OllamaEmbedder{
|
||
modelName: modelName,
|
||
truncatePromptTokens: truncatePromptTokens,
|
||
ollamaService: ollamaService,
|
||
EmbedderPooler: pooler,
|
||
dimensions: dimensions,
|
||
modelID: modelID,
|
||
}, nil
|
||
}
|
||
|
||
// ensureModelAvailable ensures that the model is available
|
||
func (e *OllamaEmbedder) ensureModelAvailable(ctx context.Context) error {
|
||
logger.GetLogger(ctx).Infof("Ensuring model %s is available", e.modelName)
|
||
return e.ollamaService.EnsureModelAvailable(ctx, e.modelName)
|
||
}
|
||
|
||
// Embed converts text to vector
|
||
func (e *OllamaEmbedder) Embed(ctx context.Context, text string) ([]float32, error) {
|
||
embedding, err := e.BatchEmbed(ctx, []string{text})
|
||
if err != nil {
|
||
return nil, fmt.Errorf("failed to embed text: %w", err)
|
||
}
|
||
|
||
if len(embedding) == 0 {
|
||
return nil, fmt.Errorf("failed to embed text: %w", err)
|
||
}
|
||
|
||
return embedding[0], nil
|
||
}
|
||
|
||
// BatchEmbed 为文本列表批量生成嵌入向量
|
||
// 参数:
|
||
// - ctx: 上下文信息,用于控制请求的生命周期
|
||
// - texts: 需要生成嵌入向量的文本列表
|
||
//
|
||
// 返回值:
|
||
// - [][]float32: 生成的嵌入向量列表,每个文本对应一个向量
|
||
// - error: 错误信息,如果成功则为nil
|
||
func (e *OllamaEmbedder) BatchEmbed(ctx context.Context, texts []string) ([][]float32, error) {
|
||
// Ensure model is available
|
||
if err := e.ensureModelAvailable(ctx); err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
// Create request
|
||
req := &ollamaapi.EmbedRequest{
|
||
Model: e.modelName,
|
||
Input: texts,
|
||
Options: make(map[string]interface{}),
|
||
}
|
||
|
||
// Set truncation parameters
|
||
if e.truncatePromptTokens > 0 {
|
||
req.Options["truncate"] = e.truncatePromptTokens
|
||
}
|
||
|
||
// Send request
|
||
startTime := time.Now()
|
||
resp, err := e.ollamaService.Embeddings(ctx, req)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("failed to get embedding vectors: %w", err)
|
||
}
|
||
|
||
logger.GetLogger(ctx).Debugf("Embedding vector retrieval took: %v", time.Since(startTime))
|
||
return resp.Embeddings, nil
|
||
}
|
||
|
||
// GetModelName returns the model name
|
||
func (e *OllamaEmbedder) GetModelName() string {
|
||
return e.modelName
|
||
}
|
||
|
||
// GetDimensions returns the vector dimensions
|
||
func (e *OllamaEmbedder) GetDimensions() int {
|
||
return e.dimensions
|
||
}
|
||
|
||
// GetModelID returns the model ID
|
||
func (e *OllamaEmbedder) GetModelID() string {
|
||
return e.modelID
|
||
}
|