// Source: l_ai_knowledge/internal/container/container.go (377 lines, 13 KiB, Go)

// Package container implements the dependency injection container setup.
// It provides centralized configuration for services, repositories, and handlers,
// and is responsible for wiring up all dependencies and ensuring proper lifecycle management.
package container
import (
"context"
"fmt"
"os"
"slices"
"strconv"
"strings"
"time"
esv7 "github.com/elastic/go-elasticsearch/v7"
"github.com/elastic/go-elasticsearch/v8"
"github.com/panjf2000/ants/v2"
"go.uber.org/dig"
"gorm.io/driver/postgres"
"gorm.io/gorm"
"knowlege-lsxd/internal/application/repository"
elasticsearchRepoV7 "knowlege-lsxd/internal/application/repository/retriever/elasticsearch/v7"
elasticsearchRepoV8 "knowlege-lsxd/internal/application/repository/retriever/elasticsearch/v8"
postgresRepo "knowlege-lsxd/internal/application/repository/retriever/postgres"
"knowlege-lsxd/internal/application/service"
chatpipline "knowlege-lsxd/internal/application/service/chat_pipline"
"knowlege-lsxd/internal/application/service/file"
"knowlege-lsxd/internal/application/service/retriever"
"knowlege-lsxd/internal/config"
"knowlege-lsxd/internal/handler"
"knowlege-lsxd/internal/logger"
"knowlege-lsxd/internal/models/embedding"
"knowlege-lsxd/internal/models/utils/ollama"
"knowlege-lsxd/internal/router"
"knowlege-lsxd/internal/stream"
"knowlege-lsxd/internal/tracing"
"knowlege-lsxd/internal/types"
"knowlege-lsxd/internal/types/interfaces"
"knowlege-lsxd/services/docreader/src/client"
)
// BuildContainer registers every application dependency on the given dig
// container and hands the same container back.
//
// Constructors are added with Provide. The pool cleanup hook and the chat
// pipeline plugins are run with Invoke while the container is being built.
// Any error from Provide or Invoke panics through must.
//
// Parameters:
//   - container: dig container to register the dependencies on
//
// Returns:
//   - The same container, with all registrations applied
func BuildContainer(container *dig.Container) *dig.Container {
	// provide registers constructors in the order given.
	provide := func(constructors ...any) {
		for _, ctor := range constructors {
			must(container.Provide(ctor))
		}
	}
	// invoke runs functions in the order given.
	invoke := func(functions ...any) {
		for _, fn := range functions {
			must(container.Invoke(fn))
		}
	}

	// Resource cleaner, exposed under its interface type.
	must(container.Provide(NewResourceCleaner, dig.As(new(interfaces.ResourceCleaner))))

	// Core infrastructure.
	provide(
		config.LoadConfig,
		initTracer,
		initDatabase,
		initFileService,
		initAntsPool,
	)

	// Hook the goroutine pool into the resource cleaner.
	invoke(registerPoolCleanup)

	// Retrieval engine registry, external clients and stream manager.
	provide(
		initRetrieveEngineRegistry,
		initDocReaderClient,
		initOllamaService,
		stream.NewStreamManager,
	)

	// Repositories.
	provide(
		repository.NewTenantRepository,
		repository.NewKnowledgeBaseRepository,
		repository.NewKnowledgeRepository,
		repository.NewChunkRepository,
		repository.NewSessionRepository,
		repository.NewMessageRepository,
		repository.NewModelRepository,
	)

	// Business services.
	provide(
		service.NewTenantService,
		service.NewKnowledgeBaseService,
		service.NewKnowledgeService,
		service.NewSessionService,
		service.NewMessageService,
		service.NewChunkService,
		embedding.NewBatchEmbedder,
		service.NewTestDataService,
		service.NewModelService,
		service.NewDatasetService,
		service.NewEvaluationService,
	)

	// Chat pipeline: the event manager is provided, the plugins are invoked.
	provide(chatpipline.NewEventManager)
	invoke(
		chatpipline.NewPluginTracing,
		chatpipline.NewPluginSearch,
		chatpipline.NewPluginRerank,
		chatpipline.NewPluginMerge,
		chatpipline.NewPluginIntoChatMessage,
		chatpipline.NewPluginChatCompletion,
		chatpipline.NewPluginChatCompletionStream,
		chatpipline.NewPluginStreamFilter,
		chatpipline.NewPluginFilterTopK,
		chatpipline.NewPluginPreprocess,
		chatpipline.NewPluginRewrite,
	)

	// HTTP handlers and the router.
	provide(
		handler.NewTenantHandler,
		handler.NewKnowledgeBaseHandler,
		handler.NewKnowledgeHandler,
		handler.NewChunkHandler,
		handler.NewSessionHandler,
		handler.NewMessageHandler,
		handler.NewTestDataHandler,
		handler.NewModelHandler,
		handler.NewEvaluationHandler,
		handler.NewInitializationHandler,
		router.NewRouter,
	)

	return container
}
// must panics with err when err is non-nil and otherwise does nothing.
// It is meant for setup steps that have no sensible recovery path.
// Parameters:
//   - err: Error to check
func must(err error) {
	if err == nil {
		return
	}
	panic(err)
}
// initTracer creates the application tracer by delegating to
// tracing.InitTracer, so the tracer can be registered as a container
// constructor.
// Parameters:
//   - None
//
// Returns:
//   - Tracer instance produced by tracing.InitTracer
//   - Error reported by tracing.InitTracer, if any
func initTracer() (*tracing.Tracer, error) {
	tracer, err := tracing.InitTracer()
	return tracer, err
}
// initDatabase opens the application's database connection through GORM.
//
// The backend is chosen by the DB_DRIVER environment variable; "postgres" is
// the only supported value. Connection settings are read from DB_HOST,
// DB_PORT, DB_USER, DB_PASSWORD and DB_NAME, and sslmode is always "disable".
//
// Parameters:
//   - cfg: Application configuration (not read here; all settings come from
//     the environment)
//
// Returns:
//   - Database handle whose pool is limited to 10 idle connections with a
//     10 minute maximum connection lifetime
//   - Error if the driver is unsupported, the connection cannot be opened, or
//     the underlying sql.DB cannot be obtained
func initDatabase(cfg *config.Config) (*gorm.DB, error) {
	driver := os.Getenv("DB_DRIVER")

	var dialector gorm.Dialector
	switch driver {
	case "postgres":
		// In a keyword/value connection string an empty value, or one that
		// contains spaces, has to be wrapped in single quotes, with ' and \
		// escaped by a backslash. Quoting every value keeps an empty or
		// unusual password from corrupting the rest of the DSN.
		escaper := strings.NewReplacer(`\`, `\\`, `'`, `\'`)
		quote := func(value string) string {
			return "'" + escaper.Replace(value) + "'"
		}
		dsn := fmt.Sprintf(
			"host=%s port=%s user=%s password=%s dbname=%s sslmode=%s",
			quote(os.Getenv("DB_HOST")),
			quote(os.Getenv("DB_PORT")),
			quote(os.Getenv("DB_USER")),
			quote(os.Getenv("DB_PASSWORD")),
			quote(os.Getenv("DB_NAME")),
			"disable",
		)
		dialector = postgres.Open(dsn)
	default:
		return nil, fmt.Errorf("unsupported database driver: %s", driver)
	}

	db, err := gorm.Open(dialector, &gorm.Config{})
	if err != nil {
		return nil, fmt.Errorf("open %s database: %w", driver, err)
	}

	// The pool limits live on the underlying database/sql handle.
	sqlDB, err := db.DB()
	if err != nil {
		return nil, fmt.Errorf("get underlying sql.DB: %w", err)
	}
	sqlDB.SetMaxIdleConns(10)
	sqlDB.SetConnMaxLifetime(10 * time.Minute)

	return db, nil
}
// initFileService builds the file storage service selected by the
// STORAGE_TYPE environment variable.
//
// Recognised values are "minio", "cos", "local" and "dummy". The MinIO and COS
// backends require all of their environment variables to be non-empty; the
// local backend takes its base directory from LOCAL_STORAGE_BASE_DIR without
// validation.
//
// Parameters:
//   - cfg: Application configuration (not read here; all settings come from
//     the environment)
//
// Returns:
//   - File service implementation for the selected backend
//   - Error if the storage type is unsupported, required settings are
//     missing, or the backend constructor fails
func initFileService(cfg *config.Config) (interfaces.FileService, error) {
	// anyBlank reports whether at least one of the given settings is empty.
	anyBlank := func(settings ...string) bool {
		for _, value := range settings {
			if value == "" {
				return true
			}
		}
		return false
	}

	switch storageType := os.Getenv("STORAGE_TYPE"); storageType {
	case "minio":
		endpoint := os.Getenv("MINIO_ENDPOINT")
		accessKeyID := os.Getenv("MINIO_ACCESS_KEY_ID")
		secretAccessKey := os.Getenv("MINIO_SECRET_ACCESS_KEY")
		bucket := os.Getenv("MINIO_BUCKET_NAME")
		if anyBlank(endpoint, accessKeyID, secretAccessKey, bucket) {
			return nil, fmt.Errorf("missing MinIO configuration")
		}
		// NOTE(review): the trailing false presumably disables TLS; confirm
		// against NewMinioFileService.
		return file.NewMinioFileService(endpoint, accessKeyID, secretAccessKey, bucket, false)

	case "cos":
		appID := os.Getenv("COS_APP_ID")
		region := os.Getenv("COS_REGION")
		secretID := os.Getenv("COS_SECRET_ID")
		secretKey := os.Getenv("COS_SECRET_KEY")
		pathPrefix := os.Getenv("COS_PATH_PREFIX")
		if anyBlank(appID, region, secretID, secretKey, pathPrefix) {
			return nil, fmt.Errorf("missing COS configuration")
		}
		return file.NewCosFileService(appID, region, secretID, secretKey, pathPrefix)

	case "local":
		baseDir := os.Getenv("LOCAL_STORAGE_BASE_DIR")
		return file.NewLocalFileService(baseDir), nil

	case "dummy":
		return file.NewDummyFileService(), nil

	default:
		return nil, fmt.Errorf("unsupported storage type: %s", storageType)
	}
}
// initRetrieveEngineRegistry builds the retrieval engine registry from the
// comma-separated RETRIEVE_DRIVER environment variable.
//
// Recognised entries are "postgres", "elasticsearch_v8" and
// "elasticsearch_v7"; whitespace around entries is ignored and unrecognised
// entries are skipped. Elasticsearch clients are configured from
// ELASTICSEARCH_ADDR, ELASTICSEARCH_USERNAME and ELASTICSEARCH_PASSWORD.
//
// Setup is best-effort: a client or registration failure is logged and the
// remaining engines are still attempted, so the returned error is always nil.
//
// NOTE(review): both Elasticsearch versions register under
// types.ElasticsearchRetrieverEngineType, so enabling both at once presumably
// conflicts in the registry; confirm against Register.
//
// Parameters:
//   - db: Database connection used by the postgres engine
//   - cfg: Application configuration passed to the Elasticsearch repositories
//
// Returns:
//   - Registry holding every engine that registered successfully
//   - Always nil
func initRetrieveEngineRegistry(db *gorm.DB, cfg *config.Config) (interfaces.RetrieveEngineRegistry, error) {
	registry := retriever.NewRetrieveEngineRegistry()
	log := logger.GetLogger(context.Background())

	// Tolerate spaces around entries, e.g. "postgres, elasticsearch_v8".
	drivers := strings.Split(os.Getenv("RETRIEVE_DRIVER"), ",")
	for i, name := range drivers {
		drivers[i] = strings.TrimSpace(name)
	}

	// logRegistered reports the outcome of registering the named engine.
	logRegistered := func(name string, err error) {
		if err != nil {
			log.Errorf("Register %s retrieve engine failed: %v", name, err)
			return
		}
		log.Infof("Register %s retrieve engine success", name)
	}

	if slices.Contains(drivers, "postgres") {
		repo := postgresRepo.NewPostgresRetrieveEngineRepository(db)
		logRegistered("postgres", registry.Register(
			retriever.NewKVHybridRetrieveEngine(repo, types.PostgresRetrieverEngineType),
		))
	}

	if slices.Contains(drivers, "elasticsearch_v8") {
		esClient, err := elasticsearch.NewTypedClient(elasticsearch.Config{
			Addresses: []string{os.Getenv("ELASTICSEARCH_ADDR")},
			Username:  os.Getenv("ELASTICSEARCH_USERNAME"),
			Password:  os.Getenv("ELASTICSEARCH_PASSWORD"),
		})
		if err != nil {
			log.Errorf("Create elasticsearch_v8 client failed: %v", err)
		} else {
			repo := elasticsearchRepoV8.NewElasticsearchEngineRepository(esClient, cfg)
			logRegistered("elasticsearch_v8", registry.Register(
				retriever.NewKVHybridRetrieveEngine(repo, types.ElasticsearchRetrieverEngineType),
			))
		}
	}

	if slices.Contains(drivers, "elasticsearch_v7") {
		esClient, err := esv7.NewClient(esv7.Config{
			Addresses: []string{os.Getenv("ELASTICSEARCH_ADDR")},
			Username:  os.Getenv("ELASTICSEARCH_USERNAME"),
			Password:  os.Getenv("ELASTICSEARCH_PASSWORD"),
		})
		if err != nil {
			log.Errorf("Create elasticsearch_v7 client failed: %v", err)
		} else {
			repo := elasticsearchRepoV7.NewElasticsearchEngineRepository(esClient, cfg)
			logRegistered("elasticsearch_v7", registry.Register(
				retriever.NewKVHybridRetrieveEngine(repo, types.ElasticsearchRetrieverEngineType),
			))
		}
	}

	return registry, nil
}
// initAntsPool creates the goroutine pool used for concurrent task execution.
//
// The pool size is read from the CONCURRENCY_POOL_SIZE environment variable
// and defaults to 5 when the variable is unset or empty. The pool is created
// with pre-allocation enabled.
//
// Parameters:
//   - cfg: Application configuration (not read here; the size comes from the
//     environment)
//
// Returns:
//   - Goroutine pool of the requested size
//   - Error if CONCURRENCY_POOL_SIZE is not a positive integer or the pool
//     cannot be created
func initAntsPool(cfg *config.Config) (*ants.Pool, error) {
	const defaultPoolSize = 5

	size := defaultPoolSize
	if raw := os.Getenv("CONCURRENCY_POOL_SIZE"); raw != "" {
		parsed, err := strconv.Atoi(raw)
		if err != nil {
			return nil, fmt.Errorf("invalid CONCURRENCY_POOL_SIZE %q: %w", raw, err)
		}
		size = parsed
	}

	// Reject non-positive sizes here so the error names the offending
	// environment variable.
	if size <= 0 {
		return nil, fmt.Errorf("invalid CONCURRENCY_POOL_SIZE %d: must be positive", size)
	}

	return ants.NewPool(size, ants.WithPreAlloc(true))
}
// registerPoolCleanup registers a cleanup callback named "AntsPool" with the
// resource cleaner. The callback releases the goroutine pool and always
// reports success.
// Parameters:
//   - pool: Goroutine pool to release
//   - cleaner: Resource cleaner that receives the callback
func registerPoolCleanup(pool *ants.Pool, cleaner interfaces.ResourceCleaner) {
	releasePool := func() error {
		pool.Release()
		return nil
	}
	cleaner.RegisterWithName("AntsPool", releasePool)
}
// initDocReaderClient creates the client for the document reader service.
//
// The address comes from the DOCREADER_ADDR environment variable when it is
// non-empty; otherwise cfg.DocReader.Addr is used if the DocReader section is
// present. If neither is available, the client is created with an empty
// address.
//
// Parameters:
//   - cfg: Application configuration
//
// Returns:
//   - Document reader client
//   - Error reported by client.NewClient, if any
func initDocReaderClient(cfg *config.Config) (*client.Client, error) {
	var addr string
	switch envAddr := os.Getenv("DOCREADER_ADDR"); {
	case envAddr != "":
		// The environment takes precedence over the configuration file.
		addr = envAddr
	case cfg.DocReader != nil:
		addr = cfg.DocReader.Addr
	}
	return client.NewClient(addr)
}
// initOllamaService obtains the Ollama service client by delegating to
// ollama.GetOllamaService, so it can be registered as a container
// constructor.
// Parameters:
//   - None
//
// Returns:
//   - Ollama service client produced by ollama.GetOllamaService
//   - Error reported by ollama.GetOllamaService, if any
func initOllamaService() (*ollama.OllamaService, error) {
	svc, err := ollama.GetOllamaService()
	return svc, err
}