l_ai_knowledge/internal/models/utils/ollama/ollama.go

package ollama

import (
	"context"
	"fmt"
	"net/http"
	"net/url"
	"os"
	"strings"
	"sync"

	"github.com/ollama/ollama/api"
	"knowlege-lsxd/internal/logger"
)

// OllamaService manages access to an Ollama service.
type OllamaService struct {
	client      *api.Client
	baseURL     string
	mu          sync.Mutex
	isAvailable bool
	isOptional  bool // marks whether the Ollama service is optional (failures are tolerated)
}
// GetOllamaService builds an OllamaService configured from environment
// variables. Note that, despite the "singleton" intent implied by the name, a
// new instance is returned on every call, so callers that need a shared
// instance should cache the result.
func GetOllamaService() (*OllamaService, error) {
	// Resolve the Ollama base URL: use OLLAMA_BASE_URL if set, otherwise fall back to the default.
	baseURL := "http://localhost:11434"
	if envURL := os.Getenv("OLLAMA_BASE_URL"); envURL != "" {
		baseURL = envURL
	}
	logger.GetLogger(context.Background()).Infof("Ollama base URL: %s", baseURL)

	// Parse the URL for the official client.
	parsedURL, err := url.Parse(baseURL)
	if err != nil {
		return nil, fmt.Errorf("invalid Ollama service URL: %w", err)
	}

	// Create the official client.
	client := api.NewClient(parsedURL, http.DefaultClient)

	// Check whether Ollama is configured as optional.
	isOptional := false
	if os.Getenv("OLLAMA_OPTIONAL") == "true" {
		isOptional = true
		logger.GetLogger(context.Background()).Info("Ollama service set to optional mode")
	}

	service := &OllamaService{
		client:     client,
		baseURL:    baseURL,
		isOptional: isOptional,
	}
	return service, nil
}

// StartService checks whether the Ollama service is reachable and records the result.
func (s *OllamaService) StartService(ctx context.Context) error {
	s.mu.Lock()
	defer s.mu.Unlock()

	// Probe the service.
	if err := s.client.Heartbeat(ctx); err != nil {
		logger.GetLogger(ctx).Warnf("Ollama service unavailable: %v", err)
		s.isAvailable = false
		// If configured as optional, do not treat this as a fatal error.
		if s.isOptional {
			logger.GetLogger(ctx).Info("Ollama service is optional, the application will continue running")
			return nil
		}
		return fmt.Errorf("Ollama service unavailable: %w", err)
	}

	s.isAvailable = true
	logger.GetLogger(ctx).Info("Ollama service ready")
	return nil
}
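
// Usage sketch (illustrative, not part of the original file): construct the
// service from the environment and verify availability. OLLAMA_BASE_URL and
// OLLAMA_OPTIONAL are the variables read by GetOllamaService above; they are
// normally set outside the process.
//
//	svc, err := GetOllamaService()
//	if err != nil {
//		return err
//	}
//	if err := svc.StartService(ctx); err != nil {
//		return err // only reached when Ollama is not optional
//	}
//	if !svc.IsAvailable() {
//		// optional mode: continue without Ollama-backed features
//	}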

// IsAvailable returns whether the service is available.
func (s *OllamaService) IsAvailable() bool {
	s.mu.Lock()
	defer s.mu.Unlock()
	return s.isAvailable
}

// IsModelAvailable checks whether a model is present on the Ollama server.
func (s *OllamaService) IsModelAvailable(ctx context.Context, modelName string) (bool, error) {
	// First check that the service is available.
	if err := s.StartService(ctx); err != nil {
		return false, err
	}
	// If the service is unavailable but marked optional, report false without an error.
	if !s.isAvailable && s.isOptional {
		return false, nil
	}

	// Get the list of local models.
	listResp, err := s.client.List(ctx)
	if err != nil {
		return false, fmt.Errorf("failed to get model list: %w", err)
	}

	// Check whether the model is in the list. Names returned by List usually
	// include a tag (for example "llama3:latest"), so callers should pass the
	// fully tagged name.
	for _, model := range listResp.Models {
		if model.Name == modelName {
			return true, nil
		}
	}
	return false, nil
}

// PullModel pulls a model.
func (s *OllamaService) PullModel(ctx context.Context, modelName string) error {
	// First check if the service is available.
	if err := s.StartService(ctx); err != nil {
		return err
	}
	// If service is not available but set as optional, return nil without further operations.
	if !s.isAvailable && s.isOptional {
		logger.GetLogger(ctx).Warnf("Ollama service unavailable, unable to pull model %s", modelName)
		return nil
	}

	// Check if the model already exists.
	available, err := s.IsModelAvailable(ctx, modelName)
	if err != nil {
		return err
	}
	if available {
		logger.GetLogger(ctx).Infof("Model %s already exists", modelName)
		return nil
	}

	logger.GetLogger(ctx).Infof("Pulling model %s...", modelName)
	// Use the official client to pull the model, logging progress as it streams in.
	pullReq := &api.PullRequest{
		Name: modelName,
	}
	err = s.client.Pull(ctx, pullReq, func(progress api.ProgressResponse) error {
		if progress.Status != "" {
			if progress.Total > 0 && progress.Completed > 0 {
				percentage := float64(progress.Completed) / float64(progress.Total) * 100
				logger.GetLogger(ctx).Infof("Pull progress: %s (%.2f%%)", progress.Status, percentage)
			} else {
				logger.GetLogger(ctx).Infof("Pull status: %s", progress.Status)
			}
		}
		if progress.Total > 0 && progress.Completed == progress.Total {
			logger.GetLogger(ctx).Infof("Model %s pull completed", modelName)
		}
		return nil
	})
	if err != nil {
		return fmt.Errorf("failed to pull model: %w", err)
	}
	return nil
}

// EnsureModelAvailable ensures the model is available, pulling it if necessary.
func (s *OllamaService) EnsureModelAvailable(ctx context.Context, modelName string) error {
	// If the service is not available but set as optional, return nil directly.
	if !s.IsAvailable() && s.isOptional {
		logger.GetLogger(ctx).Warnf("Ollama service unavailable, skipping ensuring model %s availability", modelName)
		return nil
	}

	available, err := s.IsModelAvailable(ctx, modelName)
	if err != nil {
		if s.isOptional {
			logger.GetLogger(ctx).Warnf("Failed to check model %s availability, but Ollama is set as optional", modelName)
			return nil
		}
		return err
	}
	if !available {
		return s.PullModel(ctx, modelName)
	}
	return nil
}
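
// Usage sketch (illustrative; the model name is only an example): make sure a
// model is present before using it, pulling it on first use. In optional mode
// this becomes a no-op when the server is unreachable.
//
//	if err := svc.EnsureModelAvailable(ctx, "nomic-embed-text:latest"); err != nil {
//		return err
//	}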

// GetVersion gets the Ollama version.
func (s *OllamaService) GetVersion(ctx context.Context) (string, error) {
	// If the service is not available but set as optional, return a placeholder version.
	if !s.IsAvailable() && s.isOptional {
		return "unavailable", nil
	}

	version, err := s.client.Version(ctx)
	if err != nil {
		return "", fmt.Errorf("failed to get Ollama version: %w", err)
	}
	return version, nil
}

// CreateModel creates a custom model.
func (s *OllamaService) CreateModel(ctx context.Context, name, modelfile string) error {
	req := &api.CreateRequest{
		Model:    name,
		Template: modelfile, // use the Template field instead of Modelfile
	}
	err := s.client.Create(ctx, req, func(progress api.ProgressResponse) error {
		if progress.Status != "" {
			logger.GetLogger(ctx).Infof("Model creation status: %s", progress.Status)
		}
		return nil
	})
	if err != nil {
		return fmt.Errorf("failed to create model: %w", err)
	}
	return nil
}

// GetModelInfo gets model information.
func (s *OllamaService) GetModelInfo(ctx context.Context, modelName string) (*api.ShowResponse, error) {
	req := &api.ShowRequest{
		Name: modelName,
	}
	resp, err := s.client.Show(ctx, req)
	if err != nil {
		return nil, fmt.Errorf("failed to get model information: %w", err)
	}
	return resp, nil
}

// ListModels lists all available models.
func (s *OllamaService) ListModels(ctx context.Context) ([]string, error) {
	listResp, err := s.client.List(ctx)
	if err != nil {
		return nil, fmt.Errorf("failed to get model list: %w", err)
	}
	modelNames := make([]string, len(listResp.Models))
	for i, model := range listResp.Models {
		modelNames[i] = model.Name
	}
	return modelNames, nil
}

// DeleteModel deletes a model.
func (s *OllamaService) DeleteModel(ctx context.Context, modelName string) error {
	req := &api.DeleteRequest{
		Name: modelName,
	}
	err := s.client.Delete(ctx, req)
	if err != nil {
		return fmt.Errorf("failed to delete model: %w", err)
	}
	return nil
}
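
// Usage sketch (illustrative; the model name is a placeholder): basic model
// administration with the helpers above.
//
//	names, err := svc.ListModels(ctx)
//	if err != nil {
//		return err
//	}
//	for _, name := range names {
//		info, err := svc.GetModelInfo(ctx, name)
//		if err != nil {
//			return err
//		}
//		_ = info // inspect parameters, template, details, ...
//	}
//	// Remove a model that is no longer needed.
//	if err := svc.DeleteModel(ctx, "old-model:latest"); err != nil {
//		return err
//	}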

// IsValidModelName checks whether a model name is valid.
func IsValidModelName(name string) bool {
	// Simple format check: non-empty and no spaces.
	return name != "" && !strings.Contains(name, " ")
}

// Chat proxies a chat request to the Ollama chat API.
func (s *OllamaService) Chat(ctx context.Context, req *api.ChatRequest, fn api.ChatResponseFunc) error {
	// First check that the service is available.
	if err := s.StartService(ctx); err != nil {
		return err
	}
	// Delegate to the official client's Chat method.
	return s.client.Chat(ctx, req, fn)
}
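
// Usage sketch (illustrative; field names assume a recent
// github.com/ollama/ollama/api version, and the model name is a placeholder):
// a minimal non-streaming chat call through the wrapper above.
//
//	stream := false
//	req := &api.ChatRequest{
//		Model: "llama3:latest",
//		Messages: []api.Message{
//			{Role: "user", Content: "Summarize the document in one sentence."},
//		},
//		Stream: &stream,
//	}
//	var answer string
//	if err := svc.Chat(ctx, req, func(resp api.ChatResponse) error {
//		answer += resp.Message.Content
//		return nil
//	}); err != nil {
//		return err
//	}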

// Embeddings gets text embedding vectors.
func (s *OllamaService) Embeddings(ctx context.Context, req *api.EmbedRequest) (*api.EmbedResponse, error) {
	// First check that the service is available.
	if err := s.StartService(ctx); err != nil {
		return nil, err
	}
	// Delegate to the official client's Embed method.
	return s.client.Embed(ctx, req)
}
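
// Usage sketch (illustrative; EmbedRequest.Input accepts a single string or a
// slice of strings in recent api versions, and the model name is a
// placeholder): request embeddings for a batch of text chunks.
//
//	resp, err := svc.Embeddings(ctx, &api.EmbedRequest{
//		Model: "nomic-embed-text:latest",
//		Input: []string{"first chunk", "second chunk"},
//	})
//	if err != nil {
//		return err
//	}
//	vectors := resp.Embeddings // one vector per input string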

// Generate generates text (used for reranking).
func (s *OllamaService) Generate(ctx context.Context, req *api.GenerateRequest, fn api.GenerateResponseFunc) error {
	// First check that the service is available.
	if err := s.StartService(ctx); err != nil {
		return err
	}
	// Delegate to the official client's Generate method.
	return s.client.Generate(ctx, req, fn)
}
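
// Usage sketch (illustrative; the prompt shown is an assumption about how a
// rerank-style call might look, not something defined in this file): a
// single-shot, non-streaming generation.
//
//	stream := false
//	var output string
//	if err := svc.Generate(ctx, &api.GenerateRequest{
//		Model:  "llama3:latest",
//		Prompt: "Rate the relevance of the passage to the query on a scale of 0-10: ...",
//		Stream: &stream,
//	}, func(resp api.GenerateResponse) error {
//		output += resp.Response
//		return nil
//	}); err != nil {
//		return err
//	}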

// GetClient returns the underlying Ollama client for advanced operations.
func (s *OllamaService) GetClient() *api.Client {
	return s.client
}