328 lines
10 KiB
Go
328 lines
10 KiB
Go
package service
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
|
|
"knowlege-lsxd/internal/logger"
|
|
"knowlege-lsxd/internal/models/chat"
|
|
"knowlege-lsxd/internal/models/embedding"
|
|
"knowlege-lsxd/internal/models/rerank"
|
|
"knowlege-lsxd/internal/models/utils/ollama"
|
|
"knowlege-lsxd/internal/types"
|
|
"knowlege-lsxd/internal/types/interfaces"
|
|
)
|
|
|
|
// ErrModelNotFound is the sentinel error reported when a repository lookup
// succeeds but yields no model.
var ErrModelNotFound = errors.New("model not found")
|
|
|
|
// modelService implements the model service interface
|
|
type modelService struct {
|
|
repo interfaces.ModelRepository
|
|
ollamaService *ollama.OllamaService
|
|
}
|
|
|
|
// NewModelService creates a new model service instance
|
|
func NewModelService(repo interfaces.ModelRepository, ollamaService *ollama.OllamaService) interfaces.ModelService {
|
|
return &modelService{
|
|
repo: repo,
|
|
ollamaService: ollamaService,
|
|
}
|
|
}
|
|
|
|
// CreateModel creates a new model in the repository
|
|
// For local models, it initiates an asynchronous download process
|
|
// Remote models are immediately set to active status
|
|
func (s *modelService) CreateModel(ctx context.Context, model *types.Model) error {
|
|
logger.Info(ctx, "Start creating model")
|
|
logger.Infof(ctx, "Creating model: %s, type: %s, source: %s", model.Name, model.Type, model.Source)
|
|
|
|
// Handle remote models (e.g., OpenAI, Azure)
|
|
if model.Source == types.ModelSourceRemote {
|
|
logger.Info(ctx, "Remote model detected, setting status to active")
|
|
model.Status = types.ModelStatusActive
|
|
|
|
logger.Info(ctx, "Saving remote model to repository")
|
|
err := s.repo.Create(ctx, model)
|
|
if err != nil {
|
|
logger.ErrorWithFields(ctx, err, map[string]interface{}{
|
|
"model_name": model.Name,
|
|
"model_type": model.Type,
|
|
})
|
|
return err
|
|
}
|
|
|
|
logger.Infof(ctx, "Remote model created successfully: %s", model.ID)
|
|
return nil
|
|
}
|
|
|
|
// Handle local models (e.g., Ollama)
|
|
logger.Info(ctx, "Local model detected, setting status to downloading")
|
|
model.Status = types.ModelStatusDownloading
|
|
|
|
logger.Info(ctx, "Saving local model to repository")
|
|
err := s.repo.Create(ctx, model)
|
|
if err != nil {
|
|
logger.ErrorWithFields(ctx, err, map[string]interface{}{
|
|
"model_name": model.Name,
|
|
"model_type": model.Type,
|
|
})
|
|
return err
|
|
}
|
|
|
|
// Start asynchronous model download
|
|
logger.Infof(ctx, "Starting background download for model: %s", model.Name)
|
|
newCtx := logger.CloneContext(ctx)
|
|
go func() {
|
|
logger.Info(newCtx, "Background download started")
|
|
err := s.ollamaService.PullModel(newCtx, model.Name)
|
|
if err != nil {
|
|
logger.ErrorWithFields(newCtx, err, map[string]interface{}{
|
|
"model_name": model.Name,
|
|
})
|
|
model.Status = types.ModelStatusDownloadFailed
|
|
} else {
|
|
logger.Infof(newCtx, "Model download completed successfully: %s", model.Name)
|
|
model.Status = types.ModelStatusActive
|
|
}
|
|
logger.Infof(newCtx, "Updating model status to: %s", model.Status)
|
|
s.repo.Update(newCtx, model)
|
|
}()
|
|
|
|
logger.Infof(ctx, "Model creation initiated successfully: %s", model.ID)
|
|
return nil
|
|
}
|
|
|
|
// GetModelByID retrieves a model by its ID
|
|
// Returns an error if the model is not found or is in a non-active state
|
|
func (s *modelService) GetModelByID(ctx context.Context, id string) (*types.Model, error) {
|
|
logger.Info(ctx, "Start getting model by ID")
|
|
logger.Infof(ctx, "Getting model with ID: %s", id)
|
|
|
|
tenantID := ctx.Value(types.TenantIDContextKey).(uint)
|
|
logger.Infof(ctx, "Tenant ID: %d", tenantID)
|
|
|
|
// Fetch model from repository
|
|
model, err := s.repo.GetByID(ctx, tenantID, id)
|
|
if err != nil {
|
|
logger.ErrorWithFields(ctx, err, map[string]interface{}{
|
|
"model_id": id,
|
|
"tenant_id": tenantID,
|
|
})
|
|
return nil, err
|
|
}
|
|
|
|
// Check if model exists
|
|
if model == nil {
|
|
logger.Error(ctx, "Model not found")
|
|
return nil, ErrModelNotFound
|
|
}
|
|
|
|
logger.Infof(ctx, "Model found, name: %s, status: %s", model.Name, model.Status)
|
|
|
|
// Check model status
|
|
if model.Status == types.ModelStatusActive {
|
|
logger.Info(ctx, "Model is active and ready to use")
|
|
return model, nil
|
|
}
|
|
|
|
if model.Status == types.ModelStatusDownloading {
|
|
logger.Warn(ctx, "Model is currently downloading")
|
|
return nil, errors.New("model is currently downloading")
|
|
}
|
|
|
|
if model.Status == types.ModelStatusDownloadFailed {
|
|
logger.Error(ctx, "Model download failed")
|
|
return nil, errors.New("model download failed")
|
|
}
|
|
|
|
logger.Error(ctx, "Model status is abnormal")
|
|
return nil, errors.New("abnormal model status")
|
|
}
|
|
|
|
// ListModels returns all models belonging to the tenant
|
|
func (s *modelService) ListModels(ctx context.Context) ([]*types.Model, error) {
|
|
logger.Info(ctx, "Start listing models")
|
|
|
|
tenantID := ctx.Value(types.TenantIDContextKey).(uint)
|
|
logger.Infof(ctx, "Listing models for tenant ID: %d", tenantID)
|
|
|
|
// List models from repository with no additional filters
|
|
models, err := s.repo.List(ctx, tenantID, "", "")
|
|
if err != nil {
|
|
logger.ErrorWithFields(ctx, err, map[string]interface{}{
|
|
"tenant_id": tenantID,
|
|
})
|
|
return nil, err
|
|
}
|
|
|
|
logger.Infof(ctx, "Retrieved %d models successfully", len(models))
|
|
return models, nil
|
|
}
|
|
|
|
// UpdateModel updates an existing model in the repository
|
|
func (s *modelService) UpdateModel(ctx context.Context, model *types.Model) error {
|
|
logger.Info(ctx, "Start updating model")
|
|
logger.Infof(ctx, "Updating model ID: %s, name: %s", model.ID, model.Name)
|
|
|
|
// Update model in repository
|
|
err := s.repo.Update(ctx, model)
|
|
if err != nil {
|
|
logger.ErrorWithFields(ctx, err, map[string]interface{}{
|
|
"model_id": model.ID,
|
|
"model_name": model.Name,
|
|
})
|
|
return err
|
|
}
|
|
|
|
logger.Infof(ctx, "Model updated successfully: %s", model.ID)
|
|
return nil
|
|
}
|
|
|
|
// DeleteModel removes a model from the repository
|
|
func (s *modelService) DeleteModel(ctx context.Context, id string) error {
|
|
logger.Info(ctx, "Start deleting model")
|
|
logger.Infof(ctx, "Deleting model ID: %s", id)
|
|
|
|
tenantID := ctx.Value(types.TenantIDContextKey).(uint)
|
|
logger.Infof(ctx, "Tenant ID: %d", tenantID)
|
|
|
|
// Delete model from repository
|
|
err := s.repo.Delete(ctx, tenantID, id)
|
|
if err != nil {
|
|
logger.ErrorWithFields(ctx, err, map[string]interface{}{
|
|
"model_id": id,
|
|
"tenant_id": tenantID,
|
|
})
|
|
return err
|
|
}
|
|
|
|
logger.Infof(ctx, "Model deleted successfully: %s", id)
|
|
return nil
|
|
}
|
|
|
|
// GetEmbeddingModel retrieves and initializes an embedding model instance
|
|
// Takes a model ID and returns an Embedder interface implementation
|
|
func (s *modelService) GetEmbeddingModel(ctx context.Context, modelId string) (embedding.Embedder, error) {
|
|
logger.Info(ctx, "Start getting embedding model")
|
|
logger.Infof(ctx, "Getting embedding model with ID: %s", modelId)
|
|
|
|
// Get the model details
|
|
model, err := s.GetModelByID(ctx, modelId)
|
|
if err != nil {
|
|
logger.ErrorWithFields(ctx, err, map[string]interface{}{
|
|
"model_id": modelId,
|
|
})
|
|
return nil, err
|
|
}
|
|
|
|
logger.Info(ctx, "Creating embedder instance")
|
|
logger.Infof(ctx, "Model name: %s, source: %s", model.Name, model.Source)
|
|
|
|
// Initialize the embedder with model configuration
|
|
embedder, err := embedding.NewEmbedder(embedding.Config{
|
|
Source: model.Source,
|
|
BaseURL: model.Parameters.BaseURL,
|
|
APIKey: model.Parameters.APIKey,
|
|
ModelID: model.ID,
|
|
ModelName: model.Name,
|
|
Dimensions: model.Parameters.EmbeddingParameters.Dimension,
|
|
TruncatePromptTokens: model.Parameters.EmbeddingParameters.TruncatePromptTokens,
|
|
})
|
|
if err != nil {
|
|
logger.ErrorWithFields(ctx, err, map[string]interface{}{
|
|
"model_id": model.ID,
|
|
"model_name": model.Name,
|
|
})
|
|
return nil, err
|
|
}
|
|
|
|
logger.Info(ctx, "Embedding model initialized successfully")
|
|
return embedder, nil
|
|
}
|
|
|
|
// GetRerankModel retrieves and initializes a reranking model instance
|
|
// Takes a model ID and returns a Reranker interface implementation
|
|
func (s *modelService) GetRerankModel(ctx context.Context, modelId string) (rerank.Reranker, error) {
|
|
logger.Info(ctx, "Start getting rerank model")
|
|
logger.Infof(ctx, "Getting rerank model with ID: %s", modelId)
|
|
|
|
// Get the model details
|
|
model, err := s.GetModelByID(ctx, modelId)
|
|
if err != nil {
|
|
logger.ErrorWithFields(ctx, err, map[string]interface{}{
|
|
"model_id": modelId,
|
|
})
|
|
return nil, err
|
|
}
|
|
|
|
logger.Info(ctx, "Creating reranker instance")
|
|
logger.Infof(ctx, "Model name: %s, source: %s", model.Name, model.Source)
|
|
|
|
// Initialize the reranker with model configuration
|
|
reranker, err := rerank.NewReranker(&rerank.RerankerConfig{
|
|
ModelID: model.ID,
|
|
APIKey: model.Parameters.APIKey,
|
|
BaseURL: model.Parameters.BaseURL,
|
|
ModelName: model.Name,
|
|
Source: model.Source,
|
|
})
|
|
if err != nil {
|
|
logger.ErrorWithFields(ctx, err, map[string]interface{}{
|
|
"model_id": model.ID,
|
|
"model_name": model.Name,
|
|
})
|
|
return nil, err
|
|
}
|
|
|
|
logger.Info(ctx, "Rerank model initialized successfully")
|
|
return reranker, nil
|
|
}
|
|
|
|
// GetChatModel retrieves and initializes a chat model instance
|
|
// Takes a model ID and returns a Chat interface implementation
|
|
func (s *modelService) GetChatModel(ctx context.Context, modelId string) (chat.Chat, error) {
|
|
logger.Info(ctx, "Start getting chat model")
|
|
logger.Infof(ctx, "Getting chat model with ID: %s", modelId)
|
|
|
|
tenantID := ctx.Value(types.TenantIDContextKey).(uint)
|
|
logger.Infof(ctx, "Tenant ID: %d", tenantID)
|
|
|
|
// Get the model directly from repository to avoid status checks
|
|
model, err := s.repo.GetByID(ctx, tenantID, modelId)
|
|
if err != nil {
|
|
logger.ErrorWithFields(ctx, err, map[string]interface{}{
|
|
"model_id": modelId,
|
|
"tenant_id": tenantID,
|
|
})
|
|
return nil, err
|
|
}
|
|
|
|
if model == nil {
|
|
logger.Error(ctx, "Chat model not found")
|
|
return nil, ErrModelNotFound
|
|
}
|
|
|
|
logger.Info(ctx, "Creating chat model instance")
|
|
logger.Infof(ctx, "Model name: %s, source: %s", model.Name, model.Source)
|
|
|
|
// Initialize the chat model with model configuration
|
|
chatModel, err := chat.NewChat(&chat.ChatConfig{
|
|
ModelID: model.ID,
|
|
APIKey: model.Parameters.APIKey,
|
|
BaseURL: model.Parameters.BaseURL,
|
|
ModelName: model.Name,
|
|
Source: model.Source,
|
|
})
|
|
if err != nil {
|
|
logger.ErrorWithFields(ctx, err, map[string]interface{}{
|
|
"model_id": model.ID,
|
|
"model_name": model.Name,
|
|
})
|
|
return nil, err
|
|
}
|
|
|
|
logger.Info(ctx, "Chat model initialized successfully")
|
|
return chatModel, nil
|
|
}
|