// Package service 提供应用程序的核心业务逻辑服务层 // 此包包含了知识库管理、用户租户管理、模型服务等核心功能实现 package service import ( "context" "fmt" "os" "strconv" "knowlege-lsxd/internal/config" "knowlege-lsxd/internal/logger" "knowlege-lsxd/internal/models/chat" "knowlege-lsxd/internal/models/embedding" "knowlege-lsxd/internal/models/rerank" "knowlege-lsxd/internal/models/utils/ollama" "knowlege-lsxd/internal/types" "knowlege-lsxd/internal/types/interfaces" ) // TestDataService 测试数据服务 // 负责初始化测试环境所需的数据,包括创建测试租户、测试知识库 // 以及配置必要的模型服务实例 type TestDataService struct { config *config.Config // 应用程序配置 kbRepo interfaces.KnowledgeBaseRepository // 知识库存储库接口 tenantService interfaces.TenantService // 租户服务接口 ollamaService *ollama.OllamaService // Ollama模型服务 modelService interfaces.ModelService // 模型服务接口 EmbedModel embedding.Embedder // 嵌入模型实例 RerankModel rerank.Reranker // 重排模型实例 LLMModel chat.Chat // 大语言模型实例 } // NewTestDataService 创建测试数据服务 // 注入所需的依赖服务和组件 func NewTestDataService( config *config.Config, kbRepo interfaces.KnowledgeBaseRepository, tenantService interfaces.TenantService, ollamaService *ollama.OllamaService, modelService interfaces.ModelService, ) *TestDataService { return &TestDataService{ config: config, kbRepo: kbRepo, tenantService: tenantService, ollamaService: ollamaService, modelService: modelService, } } // initTenant 初始化测试租户 // 通过环境变量获取租户ID,如果租户不存在则创建新租户,否则更新现有租户 // 同时配置租户的检索引擎参数 func (s *TestDataService) initTenant(ctx context.Context) error { logger.Info(ctx, "Start initializing test tenant") // 从环境变量获取租户ID tenantID := os.Getenv("INIT_TEST_TENANT_ID") logger.Infof(ctx, "Test tenant ID from environment: %s", tenantID) // 将字符串ID转换为uint64 tenantIDUint, err := strconv.ParseUint(tenantID, 10, 64) if err != nil { logger.Errorf(ctx, "Failed to parse tenant ID: %v", err) return err } // 创建租户配置 tenantConfig := &types.Tenant{ Name: "Test Tenant", Description: "Test Tenant for Testing", RetrieverEngines: types.RetrieverEngines{ Engines: []types.RetrieverEngineParams{ { RetrieverType: types.KeywordsRetrieverType, RetrieverEngineType: types.PostgresRetrieverEngineType, }, { RetrieverType: types.VectorRetrieverType, RetrieverEngineType: types.PostgresRetrieverEngineType, }, }, }, } // 获取或创建测试租户 logger.Infof(ctx, "Attempting to get tenant with ID: %d", tenantIDUint) tenant, err := s.tenantService.GetTenantByID(ctx, uint(tenantIDUint)) if err != nil { // 租户不存在,创建新租户 logger.Info(ctx, "Tenant not found, creating a new test tenant") tenant, err = s.tenantService.CreateTenant(ctx, tenantConfig) if err != nil { logger.Errorf(ctx, "Failed to create tenant: %v", err) return err } logger.Infof(ctx, "Created new test tenant with ID: %d", tenant.ID) } else { // 租户存在,更新检索引擎配置 logger.Info(ctx, "Test tenant found, updating retriever engines") tenant.RetrieverEngines = tenantConfig.RetrieverEngines tenant, err = s.tenantService.UpdateTenant(ctx, tenant) if err != nil { logger.Errorf(ctx, "Failed to update tenant: %v", err) return err } logger.Info(ctx, "Test tenant updated successfully") } logger.Infof(ctx, "Test tenant configured - ID: %d, Name: %s, API Key: %s", tenant.ID, tenant.Name, tenant.APIKey) return nil } // initKnowledgeBase 初始化测试知识库 // 从环境变量获取知识库ID,创建或更新知识库 // 配置知识库的分块策略、嵌入模型和摘要模型 func (s *TestDataService) initKnowledgeBase(ctx context.Context) error { logger.Info(ctx, "Start initializing test knowledge base") // 检查上下文中的租户ID if ctx.Value(types.TenantIDContextKey).(uint) == 0 { logger.Warn(ctx, "Tenant ID is 0, skipping knowledge base initialization") return nil } // 从环境变量获取知识库ID knowledgeBaseID := os.Getenv("INIT_TEST_KNOWLEDGE_BASE_ID") logger.Infof(ctx, "Test knowledge base ID from environment: %s", knowledgeBaseID) // 创建知识库配置 kbConfig := &types.KnowledgeBase{ ID: knowledgeBaseID, Name: "Test Knowledge Base", Description: "Knowledge Base for Testing", TenantID: ctx.Value(types.TenantIDContextKey).(uint), ChunkingConfig: types.ChunkingConfig{ ChunkSize: s.config.KnowledgeBase.ChunkSize, ChunkOverlap: s.config.KnowledgeBase.ChunkOverlap, Separators: s.config.KnowledgeBase.SplitMarkers, EnableMultimodal: s.config.KnowledgeBase.ImageProcessing.EnableMultimodal, }, EmbeddingModelID: s.EmbedModel.GetModelID(), SummaryModelID: s.LLMModel.GetModelID(), RerankModelID: s.RerankModel.GetModelID(), } // 初始化测试知识库 logger.Info(ctx, "Attempting to get existing knowledge base") _, err := s.kbRepo.GetKnowledgeBaseByID(ctx, knowledgeBaseID) if err != nil { // 知识库不存在,创建新知识库 logger.Info(ctx, "Knowledge base not found, creating a new one") logger.Infof(ctx, "Creating knowledge base with ID: %s, tenant ID: %d", kbConfig.ID, kbConfig.TenantID) if err := s.kbRepo.CreateKnowledgeBase(ctx, kbConfig); err != nil { logger.Errorf(ctx, "Failed to create knowledge base: %v", err) return err } logger.Info(ctx, "Knowledge base created successfully") } else { // 知识库存在,更新配置 logger.Info(ctx, "Knowledge base found, updating configuration") logger.Infof(ctx, "Updating knowledge base with ID: %s", kbConfig.ID) err = s.kbRepo.UpdateKnowledgeBase(ctx, kbConfig) if err != nil { logger.Errorf(ctx, "Failed to update knowledge base: %v", err) return err } logger.Info(ctx, "Knowledge base updated successfully") } logger.Infof(ctx, "Test knowledge base configured - ID: %s, Name: %s", kbConfig.ID, kbConfig.Name) return nil } // InitializeTestData 初始化测试数据 // 这是对外暴露的主要方法,负责协调所有测试数据的初始化过程 // 包括初始化租户、嵌入模型、重排模型、LLM模型和知识库 func (s *TestDataService) InitializeTestData(ctx context.Context) error { logger.Info(ctx, "Start initializing test data") // 从环境变量获取租户ID tenantID := os.Getenv("INIT_TEST_TENANT_ID") logger.Infof(ctx, "Test tenant ID from environment: %s", tenantID) // 解析租户ID tenantIDUint, err := strconv.ParseUint(tenantID, 10, 64) if err != nil { // 解析失败时使用默认值0 logger.Warn(ctx, "Failed to parse tenant ID, using default value 0") tenantIDUint = 0 } else { // 初始化租户 logger.Info(ctx, "Initializing tenant") err = s.initTenant(ctx) if err != nil { logger.Errorf(ctx, "Failed to initialize tenant: %v", err) return err } logger.Info(ctx, "Tenant initialized successfully") } // 创建带有租户ID的新上下文 newCtx := context.Background() newCtx = context.WithValue(newCtx, types.TenantIDContextKey, uint(tenantIDUint)) logger.Infof(ctx, "Created new context with tenant ID: %d", tenantIDUint) // 初始化模型 modelInitFuncs := []struct { name string fn func(context.Context) error }{ {"embedding model", s.initEmbeddingModel}, {"rerank model", s.initRerankModel}, {"LLM model", s.initLLMModel}, } for _, initFunc := range modelInitFuncs { logger.Infof(ctx, "Initializing %s", initFunc.name) if err := initFunc.fn(newCtx); err != nil { logger.Errorf(ctx, "Failed to initialize %s: %v", initFunc.name, err) return err } logger.Infof(ctx, "%s initialized successfully", initFunc.name) } // 初始化知识库 logger.Info(ctx, "Initializing knowledge base") if err := s.initKnowledgeBase(newCtx); err != nil { logger.Errorf(ctx, "Failed to initialize knowledge base: %v", err) return err } logger.Info(ctx, "Knowledge base initialized successfully") logger.Info(ctx, "Test data initialization completed") return nil } // getEnvOrError 获取环境变量值,如果不存在则返回错误 func (s *TestDataService) getEnvOrError(name string) (string, error) { value := os.Getenv(name) if value == "" { return "", fmt.Errorf("%s environment variable is not set", name) } return value, nil } // updateOrCreateModel 更新或创建模型 func (s *TestDataService) updateOrCreateModel(ctx context.Context, modelConfig *types.Model) error { model, err := s.modelService.GetModelByID(ctx, modelConfig.ID) if err != nil { // 模型不存在,创建新模型 return s.modelService.CreateModel(ctx, modelConfig) } // 模型存在,更新属性 model.TenantID = modelConfig.TenantID model.Name = modelConfig.Name model.Source = modelConfig.Source model.Type = modelConfig.Type model.Parameters = modelConfig.Parameters model.Status = modelConfig.Status return s.modelService.UpdateModel(ctx, model) } // initEmbeddingModel 初始化嵌入模型 func (s *TestDataService) initEmbeddingModel(ctx context.Context) error { // 从环境变量获取模型参数 modelName, err := s.getEnvOrError("INIT_EMBEDDING_MODEL_NAME") if err != nil { return err } dimensionStr := os.Getenv("INIT_EMBEDDING_MODEL_DIMENSION") dimension, err := strconv.Atoi(dimensionStr) if err != nil || dimension == 0 { return fmt.Errorf("invalid embedding model dimension: %s", dimensionStr) } baseURL := os.Getenv("INIT_EMBEDDING_MODEL_BASE_URL") apiKey := os.Getenv("INIT_EMBEDDING_MODEL_API_KEY") // 确定模型来源 source := types.ModelSourceRemote if baseURL == "" { source = types.ModelSourceLocal } // 确定模型ID modelID := os.Getenv("INIT_EMBEDDING_MODEL_ID") if modelID == "" { modelID = fmt.Sprintf("builtin:%s:%d", modelName, dimension) } // 创建嵌入模型实例 s.EmbedModel, err = embedding.NewEmbedder(embedding.Config{ Source: source, BaseURL: baseURL, ModelName: modelName, APIKey: apiKey, Dimensions: dimension, ModelID: modelID, }) if err != nil { return fmt.Errorf("failed to create embedder: %w", err) } // 如果是本地模型,使用Ollama拉取模型 if source == types.ModelSourceLocal && s.ollamaService != nil { if err := s.ollamaService.PullModel(context.Background(), modelName); err != nil { return fmt.Errorf("failed to pull embedding model: %w", err) } } // 创建模型配置 modelConfig := &types.Model{ ID: modelID, TenantID: ctx.Value(types.TenantIDContextKey).(uint), Name: modelName, Source: source, Type: types.ModelTypeEmbedding, Parameters: types.ModelParameters{ BaseURL: baseURL, APIKey: apiKey, EmbeddingParameters: types.EmbeddingParameters{ Dimension: dimension, }, }, Status: "active", } // 更新或创建模型 return s.updateOrCreateModel(ctx, modelConfig) } // initRerankModel 初始化重排模型 func (s *TestDataService) initRerankModel(ctx context.Context) error { // 从环境变量获取模型参数 modelName, err := s.getEnvOrError("INIT_RERANK_MODEL_NAME") if err != nil { logger.Warnf(ctx, "Skip Rerank Model: %v", err) return nil } baseURL, err := s.getEnvOrError("INIT_RERANK_MODEL_BASE_URL") if err != nil { return err } apiKey := os.Getenv("INIT_RERANK_MODEL_API_KEY") modelID := fmt.Sprintf("builtin:%s:rerank:%s", types.ModelSourceRemote, modelName) // 创建重排模型实例 s.RerankModel, err = rerank.NewReranker(&rerank.RerankerConfig{ Source: types.ModelSourceRemote, BaseURL: baseURL, ModelName: modelName, APIKey: apiKey, ModelID: modelID, }) if err != nil { return fmt.Errorf("failed to create reranker: %w", err) } // 创建模型配置 modelConfig := &types.Model{ ID: modelID, TenantID: ctx.Value(types.TenantIDContextKey).(uint), Name: modelName, Source: types.ModelSourceRemote, Type: types.ModelTypeRerank, Parameters: types.ModelParameters{ BaseURL: baseURL, APIKey: apiKey, }, Status: "active", } // 更新或创建模型 return s.updateOrCreateModel(ctx, modelConfig) } // initLLMModel 初始化大语言模型 func (s *TestDataService) initLLMModel(ctx context.Context) error { // 从环境变量获取模型参数 modelName, err := s.getEnvOrError("INIT_LLM_MODEL_NAME") if err != nil { return err } baseURL := os.Getenv("INIT_LLM_MODEL_BASE_URL") apiKey := os.Getenv("INIT_LLM_MODEL_API_KEY") // 确定模型来源 source := types.ModelSourceRemote if baseURL == "" { source = types.ModelSourceLocal } // 确定模型ID modelID := fmt.Sprintf("builtin:%s:llm:%s", source, modelName) // 创建大语言模型实例 s.LLMModel, err = chat.NewChat(&chat.ChatConfig{ Source: source, BaseURL: baseURL, ModelName: modelName, APIKey: apiKey, ModelID: modelID, }) if err != nil { return fmt.Errorf("failed to create llm: %w", err) } // 如果是本地模型,使用Ollama拉取模型 if source == types.ModelSourceLocal && s.ollamaService != nil { if err := s.ollamaService.PullModel(context.Background(), modelName); err != nil { return fmt.Errorf("failed to pull llm model: %w", err) } } // 创建模型配置 modelConfig := &types.Model{ ID: modelID, TenantID: ctx.Value(types.TenantIDContextKey).(uint), Name: modelName, Source: source, Type: types.ModelTypeKnowledgeQA, Parameters: types.ModelParameters{ BaseURL: baseURL, APIKey: apiKey, }, Status: "active", } // 更新或创建模型 return s.updateOrCreateModel(ctx, modelConfig) }