package service

import (
	"context"
	"errors"

	"knowlege-lsxd/internal/logger"
	"knowlege-lsxd/internal/models/chat"
	"knowlege-lsxd/internal/models/embedding"
	"knowlege-lsxd/internal/models/rerank"
	"knowlege-lsxd/internal/models/utils/ollama"
	"knowlege-lsxd/internal/types"
	"knowlege-lsxd/internal/types/interfaces"
)

// ErrModelNotFound is returned when the repository yields no model for the
// requested ID.
var ErrModelNotFound = errors.New("model not found")

// modelService implements the model service interface on top of a model
// repository and an Ollama service used to pull local models.
type modelService struct {
	repo          interfaces.ModelRepository
	ollamaService *ollama.OllamaService
}

// NewModelService creates a new model service instance backed by repo and
// ollamaService.
func NewModelService(repo interfaces.ModelRepository, ollamaService *ollama.OllamaService) interfaces.ModelService {
	return &modelService{
		repo:          repo,
		ollamaService: ollamaService,
	}
}

// tenantIDFromContext reads the tenant ID stored in ctx under
// types.TenantIDContextKey.
//
// It returns an error, rather than panicking, when the value is absent or is
// not a uint, so a request that never passed through tenant resolution fails
// with an ordinary error.
func (s *modelService) tenantIDFromContext(ctx context.Context) (uint, error) {
	tenantID, ok := ctx.Value(types.TenantIDContextKey).(uint)
	if !ok {
		return 0, errors.New("tenant ID not found in context")
	}
	return tenantID, nil
}

// CreateModel creates a new model in the repository.
//
// Remote models are stored with status active and the call returns once they
// are saved. Any other source is treated as a local model: it is stored with
// status downloading and a background goroutine pulls it through the Ollama
// service, then persists the final status (active or download failed).
//
// The background goroutine works on its own copy of the model; after
// CreateModel returns, the caller's *model is not modified any further.
// Returns the repository error if saving the model fails.
func (s *modelService) CreateModel(ctx context.Context, model *types.Model) error {
	logger.Info(ctx, "Start creating model")
	logger.Infof(ctx, "Creating model: %s, type: %s, source: %s",
		model.Name, model.Type, model.Source)

	// Handle remote models (e.g., OpenAI, Azure)
	if model.Source == types.ModelSourceRemote {
		logger.Info(ctx, "Remote model detected, setting status to active")
		model.Status = types.ModelStatusActive

		logger.Info(ctx, "Saving remote model to repository")
		if err := s.repo.Create(ctx, model); err != nil {
			logger.ErrorWithFields(ctx, err, map[string]interface{}{
				"model_name": model.Name,
				"model_type": model.Type,
			})
			return err
		}

		logger.Infof(ctx, "Remote model created successfully: %s", model.ID)
		return nil
	}

	// Handle local models (e.g., Ollama)
	logger.Info(ctx, "Local model detected, setting status to downloading")
	model.Status = types.ModelStatusDownloading

	logger.Info(ctx, "Saving local model to repository")
	if err := s.repo.Create(ctx, model); err != nil {
		logger.ErrorWithFields(ctx, err, map[string]interface{}{
			"model_name": model.Name,
			"model_type": model.Type,
		})
		return err
	}

	// Start asynchronous model download
	logger.Infof(ctx, "Starting background download for model: %s", model.Name)
	// NOTE(review): CloneContext presumably yields a context that outlives the
	// request while keeping its logging fields — confirm against the logger package.
	newCtx := logger.CloneContext(ctx)

	// Snapshot the model after Create (so repository-assigned fields are
	// included) and hand the copy to the goroutine. Mutating the caller's
	// pointer from the background would race with callers that keep reading
	// the model after this function returns.
	pending := *model
	go s.downloadLocalModel(newCtx, &pending)

	logger.Infof(ctx, "Model creation initiated successfully: %s", model.ID)
	return nil
}

// downloadLocalModel pulls model through the Ollama service and persists the
// resulting status: active on success, download failed otherwise. It is meant
// to run in its own goroutine and therefore reports failures by logging only.
func (s *modelService) downloadLocalModel(ctx context.Context, model *types.Model) {
	logger.Info(ctx, "Background download started")

	if err := s.ollamaService.PullModel(ctx, model.Name); err != nil {
		logger.ErrorWithFields(ctx, err, map[string]interface{}{
			"model_name": model.Name,
		})
		model.Status = types.ModelStatusDownloadFailed
	} else {
		logger.Infof(ctx, "Model download completed successfully: %s", model.Name)
		model.Status = types.ModelStatusActive
	}

	logger.Infof(ctx, "Updating model status to: %s", model.Status)
	// There is no caller to return this error to; log it so a model stuck in
	// the downloading state can be diagnosed.
	if err := s.repo.Update(ctx, model); err != nil {
		logger.ErrorWithFields(ctx, err, map[string]interface{}{
			"model_name": model.Name,
			"status":     model.Status,
		})
	}
}

// GetModelByID retrieves a model by its ID for the tenant found in ctx.
//
// It returns the model only when its status is active. It returns
// ErrModelNotFound when the repository yields no model, the repository error
// when the lookup fails, and a descriptive error when the model is still
// downloading, its download failed, or its status is otherwise unexpected.
// An error is also returned when ctx carries no tenant ID.
func (s *modelService) GetModelByID(ctx context.Context, id string) (*types.Model, error) {
	logger.Info(ctx, "Start getting model by ID")
	logger.Infof(ctx, "Getting model with ID: %s", id)

	tenantID, err := s.tenantIDFromContext(ctx)
	if err != nil {
		logger.ErrorWithFields(ctx, err, map[string]interface{}{
			"model_id": id,
		})
		return nil, err
	}
	logger.Infof(ctx, "Tenant ID: %d", tenantID)

	// Fetch model from repository
	model, err := s.repo.GetByID(ctx, tenantID, id)
	if err != nil {
		logger.ErrorWithFields(ctx, err, map[string]interface{}{
			"model_id":  id,
			"tenant_id": tenantID,
		})
		return nil, err
	}

	// Check if model exists
	if model == nil {
		logger.Error(ctx, "Model not found")
		return nil, ErrModelNotFound
	}

	logger.Infof(ctx, "Model found, name: %s, status: %s", model.Name, model.Status)

	// Only an active model is handed back to the caller.
	switch model.Status {
	case types.ModelStatusActive:
		logger.Info(ctx, "Model is active and ready to use")
		return model, nil
	case types.ModelStatusDownloading:
		logger.Warn(ctx, "Model is currently downloading")
		return nil, errors.New("model is currently downloading")
	case types.ModelStatusDownloadFailed:
		logger.Error(ctx, "Model download failed")
		return nil, errors.New("model download failed")
	default:
		logger.Error(ctx, "Model status is abnormal")
		return nil, errors.New("abnormal model status")
	}
}

// ListModels returns all models belonging to the tenant found in ctx.
// It returns an error when ctx carries no tenant ID or the repository fails.
func (s *modelService) ListModels(ctx context.Context) ([]*types.Model, error) {
	logger.Info(ctx, "Start listing models")

	tenantID, err := s.tenantIDFromContext(ctx)
	if err != nil {
		logger.Error(ctx, err.Error())
		return nil, err
	}
	logger.Infof(ctx, "Listing models for tenant ID: %d", tenantID)

	// List models from repository with no additional filters
	models, err := s.repo.List(ctx, tenantID, "", "")
	if err != nil {
		logger.ErrorWithFields(ctx, err, map[string]interface{}{
			"tenant_id": tenantID,
		})
		return nil, err
	}

	logger.Infof(ctx, "Retrieved %d models successfully", len(models))
	return models, nil
}

// UpdateModel updates an existing model in the repository and returns the
// repository error, if any.
func (s *modelService) UpdateModel(ctx context.Context, model *types.Model) error {
	logger.Info(ctx, "Start updating model")
	logger.Infof(ctx, "Updating model ID: %s, name: %s", model.ID, model.Name)

	// Update model in repository
	if err := s.repo.Update(ctx, model); err != nil {
		logger.ErrorWithFields(ctx, err, map[string]interface{}{
			"model_id":   model.ID,
			"model_name": model.Name,
		})
		return err
	}

	logger.Infof(ctx, "Model updated successfully: %s", model.ID)
	return nil
}

// DeleteModel removes the model with the given ID from the repository for the
// tenant found in ctx. It returns an error when ctx carries no tenant ID or
// the repository fails.
func (s *modelService) DeleteModel(ctx context.Context, id string) error {
	logger.Info(ctx, "Start deleting model")
	logger.Infof(ctx, "Deleting model ID: %s", id)

	tenantID, err := s.tenantIDFromContext(ctx)
	if err != nil {
		logger.ErrorWithFields(ctx, err, map[string]interface{}{
			"model_id": id,
		})
		return err
	}
	logger.Infof(ctx, "Tenant ID: %d", tenantID)

	// Delete model from repository
	if err := s.repo.Delete(ctx, tenantID, id); err != nil {
		logger.ErrorWithFields(ctx, err, map[string]interface{}{
			"model_id":  id,
			"tenant_id": tenantID,
		})
		return err
	}

	logger.Infof(ctx, "Model deleted successfully: %s", id)
	return nil
}

// GetEmbeddingModel retrieves the model with the given ID via GetModelByID
// (so it must be active) and builds an embedding.Embedder from its
// configuration. It returns the lookup error or the embedder construction
// error.
func (s *modelService) GetEmbeddingModel(ctx context.Context, modelId string) (embedding.Embedder, error) {
	logger.Info(ctx, "Start getting embedding model")
	logger.Infof(ctx, "Getting embedding model with ID: %s", modelId)

	// Get the model details
	model, err := s.GetModelByID(ctx, modelId)
	if err != nil {
		logger.ErrorWithFields(ctx, err, map[string]interface{}{
			"model_id": modelId,
		})
		return nil, err
	}

	logger.Info(ctx, "Creating embedder instance")
	logger.Infof(ctx, "Model name: %s, source: %s", model.Name, model.Source)

	// Initialize the embedder with model configuration
	embedder, err := embedding.NewEmbedder(embedding.Config{
		Source:               model.Source,
		BaseURL:              model.Parameters.BaseURL,
		APIKey:               model.Parameters.APIKey,
		ModelID:              model.ID,
		ModelName:            model.Name,
		Dimensions:           model.Parameters.EmbeddingParameters.Dimension,
		TruncatePromptTokens: model.Parameters.EmbeddingParameters.TruncatePromptTokens,
	})
	if err != nil {
		logger.ErrorWithFields(ctx, err, map[string]interface{}{
			"model_id":   model.ID,
			"model_name": model.Name,
		})
		return nil, err
	}

	logger.Info(ctx, "Embedding model initialized successfully")
	return embedder, nil
}

// GetRerankModel retrieves the model with the given ID via GetModelByID (so it
// must be active) and builds a rerank.Reranker from its configuration. It
// returns the lookup error or the reranker construction error.
func (s *modelService) GetRerankModel(ctx context.Context, modelId string) (rerank.Reranker, error) {
	logger.Info(ctx, "Start getting rerank model")
	logger.Infof(ctx, "Getting rerank model with ID: %s", modelId)

	// Get the model details
	model, err := s.GetModelByID(ctx, modelId)
	if err != nil {
		logger.ErrorWithFields(ctx, err, map[string]interface{}{
			"model_id": modelId,
		})
		return nil, err
	}

	logger.Info(ctx, "Creating reranker instance")
	logger.Infof(ctx, "Model name: %s, source: %s", model.Name, model.Source)

	// Initialize the reranker with model configuration
	reranker, err := rerank.NewReranker(&rerank.RerankerConfig{
		ModelID:   model.ID,
		APIKey:    model.Parameters.APIKey,
		BaseURL:   model.Parameters.BaseURL,
		ModelName: model.Name,
		Source:    model.Source,
	})
	if err != nil {
		logger.ErrorWithFields(ctx, err, map[string]interface{}{
			"model_id":   model.ID,
			"model_name": model.Name,
		})
		return nil, err
	}

	logger.Info(ctx, "Rerank model initialized successfully")
	return reranker, nil
}

// GetChatModel retrieves the model with the given ID for the tenant found in
// ctx and builds a chat.Chat from its configuration.
//
// Unlike GetEmbeddingModel and GetRerankModel it reads the repository
// directly, so the model's status is not checked. It returns ErrModelNotFound
// when the repository yields no model, and an error when ctx carries no
// tenant ID, the lookup fails, or the chat model cannot be constructed.
func (s *modelService) GetChatModel(ctx context.Context, modelId string) (chat.Chat, error) {
	logger.Info(ctx, "Start getting chat model")
	logger.Infof(ctx, "Getting chat model with ID: %s", modelId)

	tenantID, err := s.tenantIDFromContext(ctx)
	if err != nil {
		logger.ErrorWithFields(ctx, err, map[string]interface{}{
			"model_id": modelId,
		})
		return nil, err
	}
	logger.Infof(ctx, "Tenant ID: %d", tenantID)

	// Get the model directly from repository to avoid status checks
	model, err := s.repo.GetByID(ctx, tenantID, modelId)
	if err != nil {
		logger.ErrorWithFields(ctx, err, map[string]interface{}{
			"model_id":  modelId,
			"tenant_id": tenantID,
		})
		return nil, err
	}

	if model == nil {
		logger.Error(ctx, "Chat model not found")
		return nil, ErrModelNotFound
	}

	logger.Info(ctx, "Creating chat model instance")
	logger.Infof(ctx, "Model name: %s, source: %s", model.Name, model.Source)

	// Initialize the chat model with model configuration
	chatModel, err := chat.NewChat(&chat.ChatConfig{
		ModelID:   model.ID,
		APIKey:    model.Parameters.APIKey,
		BaseURL:   model.Parameters.BaseURL,
		ModelName: model.Name,
		Source:    model.Source,
	})
	if err != nil {
		logger.ErrorWithFields(ctx, err, map[string]interface{}{
			"model_id":   model.ID,
			"model_name": model.Name,
		})
		return nil, err
	}

	logger.Info(ctx, "Chat model initialized successfully")
	return chatModel, nil
}