l_ai_knowledge/internal/application/service/model.go

328 lines
10 KiB
Go

package service
import (
"context"
"errors"
"knowlege-lsxd/internal/logger"
"knowlege-lsxd/internal/models/chat"
"knowlege-lsxd/internal/models/embedding"
"knowlege-lsxd/internal/models/rerank"
"knowlege-lsxd/internal/models/utils/ollama"
"knowlege-lsxd/internal/types"
"knowlege-lsxd/internal/types/interfaces"
)
// ErrModelNotFound is the sentinel error this service returns when a
// repository lookup succeeds but yields no model (a nil result with a nil
// error). Callers can test for it with errors.Is.
var ErrModelNotFound = errors.New("model not found")
// modelService is this package's implementation of interfaces.ModelService;
// NewModelService returns it behind that interface.
type modelService struct {
// repo is the persistence layer used to create, fetch, list, update and
// delete model records.
repo interfaces.ModelRepository
// ollamaService is used by CreateModel to pull local models.
ollamaService *ollama.OllamaService
}
// NewModelService builds a modelService wired to the given repository and
// Ollama service and hands it back as an interfaces.ModelService.
//
// Neither argument is validated here; they are stored as passed.
func NewModelService(repo interfaces.ModelRepository, ollamaService *ollama.OllamaService) interfaces.ModelService {
	svc := new(modelService)
	svc.repo = repo
	svc.ollamaService = ollamaService
	return svc
}
// CreateModel persists a new model record and, for non-remote models, starts
// the download in the background.
//
// When model.Source is types.ModelSourceRemote the record is stored with
// status types.ModelStatusActive and the call returns as soon as it is saved.
// For any other source the record is stored with status
// types.ModelStatusDownloading; a goroutine then pulls the model by name
// through the Ollama service and writes the final status
// (types.ModelStatusActive or types.ModelStatusDownloadFailed) back to the
// repository.
//
// The returned error is the repository's Create error, if any. A failed pull
// or a failed follow-up status update cannot be returned (the caller is
// already gone by then) and is logged instead.
//
// NOTE(review): the goroutine writes model.Status on the caller's *types.Model
// after CreateModel has returned, so callers must not read or modify that
// struct concurrently with the download.
func (s *modelService) CreateModel(ctx context.Context, model *types.Model) error {
	logger.Info(ctx, "Start creating model")
	logger.Infof(ctx, "Creating model: %s, type: %s, source: %s", model.Name, model.Type, model.Source)

	// Remote models need no download: mark them active and store them.
	if model.Source == types.ModelSourceRemote {
		logger.Info(ctx, "Remote model detected, setting status to active")
		model.Status = types.ModelStatusActive

		logger.Info(ctx, "Saving remote model to repository")
		if err := s.repo.Create(ctx, model); err != nil {
			logger.ErrorWithFields(ctx, err, map[string]interface{}{
				"model_name": model.Name,
				"model_type": model.Type,
			})
			return err
		}

		logger.Infof(ctx, "Remote model created successfully: %s", model.ID)
		return nil
	}

	// Local models (pulled through Ollama) are stored first in the
	// "downloading" state so the record exists while the pull is running.
	logger.Info(ctx, "Local model detected, setting status to downloading")
	model.Status = types.ModelStatusDownloading

	logger.Info(ctx, "Saving local model to repository")
	if err := s.repo.Create(ctx, model); err != nil {
		logger.ErrorWithFields(ctx, err, map[string]interface{}{
			"model_name": model.Name,
			"model_type": model.Type,
		})
		return err
	}

	// Start the asynchronous download. The goroutine works on a context
	// obtained from logger.CloneContext rather than on ctx itself
	// (presumably so it outlives the request — TODO confirm).
	logger.Infof(ctx, "Starting background download for model: %s", model.Name)
	newCtx := logger.CloneContext(ctx)
	go func() {
		logger.Info(newCtx, "Background download started")

		if err := s.ollamaService.PullModel(newCtx, model.Name); err != nil {
			logger.ErrorWithFields(newCtx, err, map[string]interface{}{
				"model_name": model.Name,
			})
			model.Status = types.ModelStatusDownloadFailed
		} else {
			logger.Infof(newCtx, "Model download completed successfully: %s", model.Name)
			model.Status = types.ModelStatusActive
		}

		logger.Infof(newCtx, "Updating model status to: %s", model.Status)
		// The update result used to be discarded; if it fails the stored
		// record stays in the "downloading" state, so at least log it.
		if err := s.repo.Update(newCtx, model); err != nil {
			logger.ErrorWithFields(newCtx, err, map[string]interface{}{
				"model_id":     model.ID,
				"model_name":   model.Name,
				"model_status": model.Status,
			})
		}
	}()

	logger.Infof(ctx, "Model creation initiated successfully: %s", model.ID)
	return nil
}
// GetModelByID fetches the model with the given id for the tenant carried in
// ctx and returns it only when its status is types.ModelStatusActive.
//
// The tenant ID is read from ctx under types.TenantIDContextKey and must be a
// uint; if it is absent or of another type an error is returned instead of
// panicking on the type assertion.
//
// Errors:
//   - the repository error, unchanged, when the lookup fails;
//   - ErrModelNotFound when the repository returns a nil model;
//   - a descriptive error when the model is downloading, its download failed,
//     or its status is none of the known values.
func (s *modelService) GetModelByID(ctx context.Context, id string) (*types.Model, error) {
	logger.Info(ctx, "Start getting model by ID")
	logger.Infof(ctx, "Getting model with ID: %s", id)

	// Comma-ok assertion: a context without a tenant ID must not crash the caller.
	tenantID, ok := ctx.Value(types.TenantIDContextKey).(uint)
	if !ok {
		logger.Error(ctx, "Tenant ID missing from context")
		return nil, errors.New("tenant ID not found in context")
	}
	logger.Infof(ctx, "Tenant ID: %d", tenantID)

	// Fetch model from repository
	model, err := s.repo.GetByID(ctx, tenantID, id)
	if err != nil {
		logger.ErrorWithFields(ctx, err, map[string]interface{}{
			"model_id":  id,
			"tenant_id": tenantID,
		})
		return nil, err
	}

	// A nil model with a nil error means "no such record".
	if model == nil {
		logger.Error(ctx, "Model not found")
		return nil, ErrModelNotFound
	}
	logger.Infof(ctx, "Model found, name: %s, status: %s", model.Name, model.Status)

	// Only an active model is handed out; every other status is an error.
	switch model.Status {
	case types.ModelStatusActive:
		logger.Info(ctx, "Model is active and ready to use")
		return model, nil
	case types.ModelStatusDownloading:
		logger.Warn(ctx, "Model is currently downloading")
		return nil, errors.New("model is currently downloading")
	case types.ModelStatusDownloadFailed:
		logger.Error(ctx, "Model download failed")
		return nil, errors.New("model download failed")
	default:
		logger.Error(ctx, "Model status is abnormal")
		return nil, errors.New("abnormal model status")
	}
}
// ListModels returns the models the repository lists for the tenant carried
// in ctx, passing empty strings for the repository's two remaining filter
// arguments.
//
// The tenant ID is read from ctx under types.TenantIDContextKey and must be a
// uint; if it is absent or of another type an error is returned instead of
// panicking on the type assertion. Repository errors are logged and returned
// unchanged.
func (s *modelService) ListModels(ctx context.Context) ([]*types.Model, error) {
	logger.Info(ctx, "Start listing models")

	// Comma-ok assertion: a context without a tenant ID must not crash the caller.
	tenantID, ok := ctx.Value(types.TenantIDContextKey).(uint)
	if !ok {
		logger.Error(ctx, "Tenant ID missing from context")
		return nil, errors.New("tenant ID not found in context")
	}
	logger.Infof(ctx, "Listing models for tenant ID: %d", tenantID)

	// List models from repository with no additional filters
	models, err := s.repo.List(ctx, tenantID, "", "")
	if err != nil {
		logger.ErrorWithFields(ctx, err, map[string]interface{}{
			"tenant_id": tenantID,
		})
		return nil, err
	}

	logger.Infof(ctx, "Retrieved %d models successfully", len(models))
	return models, nil
}
// UpdateModel forwards the given model to the repository's Update and reports
// the outcome.
//
// A repository failure is logged together with the model's ID and name and
// returned unchanged; on success nil is returned.
func (s *modelService) UpdateModel(ctx context.Context, model *types.Model) error {
	logger.Info(ctx, "Start updating model")
	logger.Infof(ctx, "Updating model ID: %s, name: %s", model.ID, model.Name)

	// Persist the change; bail out early if the repository rejects it.
	if updateErr := s.repo.Update(ctx, model); updateErr != nil {
		fields := map[string]interface{}{
			"model_id":   model.ID,
			"model_name": model.Name,
		}
		logger.ErrorWithFields(ctx, updateErr, fields)
		return updateErr
	}

	logger.Infof(ctx, "Model updated successfully: %s", model.ID)
	return nil
}
// DeleteModel asks the repository to delete the model with the given id for
// the tenant carried in ctx.
//
// The tenant ID is read from ctx under types.TenantIDContextKey and must be a
// uint; if it is absent or of another type an error is returned instead of
// panicking on the type assertion. Repository errors are logged and returned
// unchanged.
func (s *modelService) DeleteModel(ctx context.Context, id string) error {
	logger.Info(ctx, "Start deleting model")
	logger.Infof(ctx, "Deleting model ID: %s", id)

	// Comma-ok assertion: a context without a tenant ID must not crash the caller.
	tenantID, ok := ctx.Value(types.TenantIDContextKey).(uint)
	if !ok {
		logger.Error(ctx, "Tenant ID missing from context")
		return errors.New("tenant ID not found in context")
	}
	logger.Infof(ctx, "Tenant ID: %d", tenantID)

	// Delete model from repository
	if err := s.repo.Delete(ctx, tenantID, id); err != nil {
		logger.ErrorWithFields(ctx, err, map[string]interface{}{
			"model_id":  id,
			"tenant_id": tenantID,
		})
		return err
	}

	logger.Infof(ctx, "Model deleted successfully: %s", id)
	return nil
}
// GetEmbeddingModel resolves modelId through GetModelByID and builds an
// embedding.Embedder from the stored model's settings.
//
// Because the lookup goes through GetModelByID, every error that method
// returns (including those for non-active models) is returned here as well.
// Errors from embedding.NewEmbedder are logged and returned unchanged.
func (s *modelService) GetEmbeddingModel(ctx context.Context, modelId string) (embedding.Embedder, error) {
	logger.Info(ctx, "Start getting embedding model")
	logger.Infof(ctx, "Getting embedding model with ID: %s", modelId)

	// Look the model up first; nothing can be built without its settings.
	record, lookupErr := s.GetModelByID(ctx, modelId)
	if lookupErr != nil {
		logger.ErrorWithFields(ctx, lookupErr, map[string]interface{}{
			"model_id": modelId,
		})
		return nil, lookupErr
	}

	logger.Info(ctx, "Creating embedder instance")
	logger.Infof(ctx, "Model name: %s, source: %s", record.Name, record.Source)

	// Translate the stored model into the embedder's configuration.
	cfg := embedding.Config{
		Source:               record.Source,
		BaseURL:              record.Parameters.BaseURL,
		APIKey:               record.Parameters.APIKey,
		ModelID:              record.ID,
		ModelName:            record.Name,
		Dimensions:           record.Parameters.EmbeddingParameters.Dimension,
		TruncatePromptTokens: record.Parameters.EmbeddingParameters.TruncatePromptTokens,
	}

	instance, buildErr := embedding.NewEmbedder(cfg)
	if buildErr != nil {
		logger.ErrorWithFields(ctx, buildErr, map[string]interface{}{
			"model_id":   record.ID,
			"model_name": record.Name,
		})
		return nil, buildErr
	}

	logger.Info(ctx, "Embedding model initialized successfully")
	return instance, nil
}
// GetRerankModel resolves modelId through GetModelByID and builds a
// rerank.Reranker from the stored model's settings.
//
// Because the lookup goes through GetModelByID, every error that method
// returns (including those for non-active models) is returned here as well.
// Errors from rerank.NewReranker are logged and returned unchanged.
func (s *modelService) GetRerankModel(ctx context.Context, modelId string) (rerank.Reranker, error) {
	logger.Info(ctx, "Start getting rerank model")
	logger.Infof(ctx, "Getting rerank model with ID: %s", modelId)

	// Look the model up first; nothing can be built without its settings.
	record, lookupErr := s.GetModelByID(ctx, modelId)
	if lookupErr != nil {
		logger.ErrorWithFields(ctx, lookupErr, map[string]interface{}{
			"model_id": modelId,
		})
		return nil, lookupErr
	}

	logger.Info(ctx, "Creating reranker instance")
	logger.Infof(ctx, "Model name: %s, source: %s", record.Name, record.Source)

	// Translate the stored model into the reranker's configuration.
	cfg := &rerank.RerankerConfig{
		ModelID:   record.ID,
		APIKey:    record.Parameters.APIKey,
		BaseURL:   record.Parameters.BaseURL,
		ModelName: record.Name,
		Source:    record.Source,
	}

	instance, buildErr := rerank.NewReranker(cfg)
	if buildErr != nil {
		logger.ErrorWithFields(ctx, buildErr, map[string]interface{}{
			"model_id":   record.ID,
			"model_name": record.Name,
		})
		return nil, buildErr
	}

	logger.Info(ctx, "Rerank model initialized successfully")
	return instance, nil
}
// GetChatModel loads the model with the given ID for the tenant carried in
// ctx and builds a chat.Chat from its stored settings.
//
// Unlike GetEmbeddingModel and GetRerankModel, the model is read straight
// from the repository rather than through GetModelByID, so its status is
// deliberately not checked.
//
// The tenant ID is read from ctx under types.TenantIDContextKey and must be a
// uint; if it is absent or of another type an error is returned instead of
// panicking on the type assertion.
//
// Errors:
//   - the repository error, unchanged, when the lookup fails;
//   - ErrModelNotFound when the repository returns a nil model;
//   - the error from chat.NewChat, unchanged, when construction fails.
func (s *modelService) GetChatModel(ctx context.Context, modelId string) (chat.Chat, error) {
	logger.Info(ctx, "Start getting chat model")
	logger.Infof(ctx, "Getting chat model with ID: %s", modelId)

	// Comma-ok assertion: a context without a tenant ID must not crash the caller.
	tenantID, ok := ctx.Value(types.TenantIDContextKey).(uint)
	if !ok {
		logger.Error(ctx, "Tenant ID missing from context")
		return nil, errors.New("tenant ID not found in context")
	}
	logger.Infof(ctx, "Tenant ID: %d", tenantID)

	// Get the model directly from repository to avoid status checks
	model, err := s.repo.GetByID(ctx, tenantID, modelId)
	if err != nil {
		logger.ErrorWithFields(ctx, err, map[string]interface{}{
			"model_id":  modelId,
			"tenant_id": tenantID,
		})
		return nil, err
	}

	// A nil model with a nil error means "no such record".
	if model == nil {
		logger.Error(ctx, "Chat model not found")
		return nil, ErrModelNotFound
	}

	logger.Info(ctx, "Creating chat model instance")
	logger.Infof(ctx, "Model name: %s, source: %s", model.Name, model.Source)

	// Initialize the chat model with model configuration
	chatModel, err := chat.NewChat(&chat.ChatConfig{
		ModelID:   model.ID,
		APIKey:    model.Parameters.APIKey,
		BaseURL:   model.Parameters.BaseURL,
		ModelName: model.Name,
		Source:    model.Source,
	})
	if err != nil {
		logger.ErrorWithFields(ctx, err, map[string]interface{}{
			"model_id":   model.ID,
			"model_name": model.Name,
		})
		return nil, err
	}

	logger.Info(ctx, "Chat model initialized successfully")
	return chatModel, nil
}