377 lines
13 KiB
Go
377 lines
13 KiB
Go
// Package container implements dependency injection container setup
|
|
// Provides centralized configuration for services, repositories, and handlers
|
|
// This package is responsible for wiring up all dependencies and ensuring proper lifecycle management
|
|
package container
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"os"
|
|
"slices"
|
|
"strconv"
|
|
"strings"
|
|
"time"
|
|
|
|
esv7 "github.com/elastic/go-elasticsearch/v7"
|
|
"github.com/elastic/go-elasticsearch/v8"
|
|
"github.com/panjf2000/ants/v2"
|
|
"go.uber.org/dig"
|
|
"gorm.io/driver/postgres"
|
|
"gorm.io/gorm"
|
|
|
|
"knowlege-lsxd/internal/application/repository"
|
|
elasticsearchRepoV7 "knowlege-lsxd/internal/application/repository/retriever/elasticsearch/v7"
|
|
elasticsearchRepoV8 "knowlege-lsxd/internal/application/repository/retriever/elasticsearch/v8"
|
|
postgresRepo "knowlege-lsxd/internal/application/repository/retriever/postgres"
|
|
"knowlege-lsxd/internal/application/service"
|
|
chatpipline "knowlege-lsxd/internal/application/service/chat_pipline"
|
|
"knowlege-lsxd/internal/application/service/file"
|
|
"knowlege-lsxd/internal/application/service/retriever"
|
|
"knowlege-lsxd/internal/config"
|
|
"knowlege-lsxd/internal/handler"
|
|
"knowlege-lsxd/internal/logger"
|
|
"knowlege-lsxd/internal/models/embedding"
|
|
"knowlege-lsxd/internal/models/utils/ollama"
|
|
"knowlege-lsxd/internal/router"
|
|
"knowlege-lsxd/internal/stream"
|
|
"knowlege-lsxd/internal/tracing"
|
|
"knowlege-lsxd/internal/types"
|
|
"knowlege-lsxd/internal/types/interfaces"
|
|
"knowlege-lsxd/services/docreader/src/client"
|
|
)
|
|
|
|
// BuildContainer constructs the dependency injection container
|
|
// Registers all components, services, repositories and handlers needed by the application
|
|
// Creates a fully configured application container with proper dependency resolution
|
|
// Parameters:
|
|
// - container: Base dig container to add dependencies to
|
|
//
|
|
// Returns:
|
|
// - Configured container with all application dependencies registered
|
|
func BuildContainer(container *dig.Container) *dig.Container {
|
|
// Register resource cleaner for proper cleanup of resources
|
|
must(container.Provide(NewResourceCleaner, dig.As(new(interfaces.ResourceCleaner))))
|
|
|
|
// Core infrastructure configuration
|
|
must(container.Provide(config.LoadConfig))
|
|
must(container.Provide(initTracer))
|
|
must(container.Provide(initDatabase))
|
|
must(container.Provide(initFileService))
|
|
must(container.Provide(initAntsPool))
|
|
|
|
// Register goroutine pool cleanup handler
|
|
must(container.Invoke(registerPoolCleanup))
|
|
|
|
// Initialize retrieval engine registry for search capabilities
|
|
must(container.Provide(initRetrieveEngineRegistry))
|
|
|
|
// External service clients
|
|
must(container.Provide(initDocReaderClient))
|
|
must(container.Provide(initOllamaService))
|
|
must(container.Provide(stream.NewStreamManager))
|
|
|
|
// Data repositories layer
|
|
must(container.Provide(repository.NewTenantRepository))
|
|
must(container.Provide(repository.NewKnowledgeBaseRepository))
|
|
must(container.Provide(repository.NewKnowledgeRepository))
|
|
must(container.Provide(repository.NewChunkRepository))
|
|
must(container.Provide(repository.NewSessionRepository))
|
|
must(container.Provide(repository.NewMessageRepository))
|
|
must(container.Provide(repository.NewModelRepository))
|
|
|
|
// Business service layer
|
|
must(container.Provide(service.NewTenantService))
|
|
must(container.Provide(service.NewKnowledgeBaseService))
|
|
must(container.Provide(service.NewKnowledgeService))
|
|
must(container.Provide(service.NewSessionService))
|
|
must(container.Provide(service.NewMessageService))
|
|
must(container.Provide(service.NewChunkService))
|
|
must(container.Provide(embedding.NewBatchEmbedder))
|
|
must(container.Provide(service.NewTestDataService))
|
|
must(container.Provide(service.NewModelService))
|
|
must(container.Provide(service.NewDatasetService))
|
|
must(container.Provide(service.NewEvaluationService))
|
|
|
|
// Chat pipeline components for processing chat requests
|
|
must(container.Provide(chatpipline.NewEventManager))
|
|
must(container.Invoke(chatpipline.NewPluginTracing))
|
|
must(container.Invoke(chatpipline.NewPluginSearch))
|
|
must(container.Invoke(chatpipline.NewPluginRerank))
|
|
must(container.Invoke(chatpipline.NewPluginMerge))
|
|
must(container.Invoke(chatpipline.NewPluginIntoChatMessage))
|
|
must(container.Invoke(chatpipline.NewPluginChatCompletion))
|
|
must(container.Invoke(chatpipline.NewPluginChatCompletionStream))
|
|
must(container.Invoke(chatpipline.NewPluginStreamFilter))
|
|
must(container.Invoke(chatpipline.NewPluginFilterTopK))
|
|
must(container.Invoke(chatpipline.NewPluginPreprocess))
|
|
must(container.Invoke(chatpipline.NewPluginRewrite))
|
|
|
|
// HTTP handlers layer
|
|
must(container.Provide(handler.NewTenantHandler))
|
|
must(container.Provide(handler.NewKnowledgeBaseHandler))
|
|
must(container.Provide(handler.NewKnowledgeHandler))
|
|
must(container.Provide(handler.NewChunkHandler))
|
|
must(container.Provide(handler.NewSessionHandler))
|
|
must(container.Provide(handler.NewMessageHandler))
|
|
must(container.Provide(handler.NewTestDataHandler))
|
|
must(container.Provide(handler.NewModelHandler))
|
|
must(container.Provide(handler.NewEvaluationHandler))
|
|
must(container.Provide(handler.NewInitializationHandler))
|
|
|
|
// Router configuration
|
|
must(container.Provide(router.NewRouter))
|
|
|
|
return container
|
|
}
|
|
|
|
// must panics with err when it is non-nil and does nothing otherwise.
//
// It is meant for setup steps that have to succeed, such as container
// registration, where there is no sensible way to continue after a failure.
//
// Parameters:
//   - err: the error to check; the panic value is err itself
func must(err error) {
	if err == nil {
		return
	}
	panic(err)
}
|
|
|
|
// initTracer initializes OpenTelemetry tracer
|
|
// Sets up distributed tracing for observability across the application
|
|
// Parameters:
|
|
// - None
|
|
//
|
|
// Returns:
|
|
// - Configured tracer instance
|
|
// - Error if initialization fails
|
|
func initTracer() (*tracing.Tracer, error) {
|
|
return tracing.InitTracer()
|
|
}
|
|
|
|
// initDatabase initializes database connection
|
|
// Creates and configures database connection based on environment configuration
|
|
// Supports multiple database backends (PostgreSQL)
|
|
// Parameters:
|
|
// - cfg: Application configuration
|
|
//
|
|
// Returns:
|
|
// - Configured database connection
|
|
// - Error if connection fails
|
|
func initDatabase(cfg *config.Config) (*gorm.DB, error) {
|
|
var dialector gorm.Dialector
|
|
driver := os.Getenv("DB_DRIVER")
|
|
switch driver {
|
|
case "postgres":
|
|
dsn := fmt.Sprintf(
|
|
"host=%s port=%s user=%s password=%s dbname=%s sslmode=%s",
|
|
os.Getenv("DB_HOST"),
|
|
os.Getenv("DB_PORT"),
|
|
os.Getenv("DB_USER"),
|
|
os.Getenv("DB_PASSWORD"),
|
|
os.Getenv("DB_NAME"),
|
|
"disable",
|
|
)
|
|
dialector = postgres.Open(dsn)
|
|
default:
|
|
return nil, fmt.Errorf("unsupported database driver: %s", os.Getenv("DB_DRIVER"))
|
|
}
|
|
db, err := gorm.Open(dialector, &gorm.Config{})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Get underlying SQL DB object
|
|
sqlDB, err := db.DB()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Configure connection pool parameters
|
|
sqlDB.SetMaxIdleConns(10)
|
|
sqlDB.SetConnMaxLifetime(time.Duration(10) * time.Minute)
|
|
|
|
return db, nil
|
|
}
|
|
|
|
// initFileService initializes file storage service
|
|
// Creates the appropriate file storage service based on configuration
|
|
// Supports multiple storage backends (MinIO, COS, local filesystem)
|
|
// Parameters:
|
|
// - cfg: Application configuration
|
|
//
|
|
// Returns:
|
|
// - Configured file service implementation
|
|
// - Error if initialization fails
|
|
func initFileService(cfg *config.Config) (interfaces.FileService, error) {
|
|
switch os.Getenv("STORAGE_TYPE") {
|
|
case "minio":
|
|
if os.Getenv("MINIO_ENDPOINT") == "" ||
|
|
os.Getenv("MINIO_ACCESS_KEY_ID") == "" ||
|
|
os.Getenv("MINIO_SECRET_ACCESS_KEY") == "" ||
|
|
os.Getenv("MINIO_BUCKET_NAME") == "" {
|
|
return nil, fmt.Errorf("missing MinIO configuration")
|
|
}
|
|
return file.NewMinioFileService(
|
|
os.Getenv("MINIO_ENDPOINT"),
|
|
os.Getenv("MINIO_ACCESS_KEY_ID"),
|
|
os.Getenv("MINIO_SECRET_ACCESS_KEY"),
|
|
os.Getenv("MINIO_BUCKET_NAME"),
|
|
false,
|
|
)
|
|
case "cos":
|
|
if os.Getenv("COS_APP_ID") == "" ||
|
|
os.Getenv("COS_REGION") == "" ||
|
|
os.Getenv("COS_SECRET_ID") == "" ||
|
|
os.Getenv("COS_SECRET_KEY") == "" ||
|
|
os.Getenv("COS_PATH_PREFIX") == "" {
|
|
return nil, fmt.Errorf("missing COS configuration")
|
|
}
|
|
return file.NewCosFileService(
|
|
os.Getenv("COS_APP_ID"),
|
|
os.Getenv("COS_REGION"),
|
|
os.Getenv("COS_SECRET_ID"),
|
|
os.Getenv("COS_SECRET_KEY"),
|
|
os.Getenv("COS_PATH_PREFIX"),
|
|
)
|
|
case "local":
|
|
return file.NewLocalFileService(os.Getenv("LOCAL_STORAGE_BASE_DIR")), nil
|
|
case "dummy":
|
|
return file.NewDummyFileService(), nil
|
|
default:
|
|
return nil, fmt.Errorf("unsupported storage type: %s", os.Getenv("STORAGE_TYPE"))
|
|
}
|
|
}
|
|
|
|
// initRetrieveEngineRegistry initializes the retrieval engine registry
|
|
// Sets up and configures various search engine backends based on configuration
|
|
// Supports multiple retrieval engines (PostgreSQL, ElasticsearchV7, ElasticsearchV8)
|
|
// Parameters:
|
|
// - db: Database connection
|
|
// - cfg: Application configuration
|
|
//
|
|
// Returns:
|
|
// - Configured retrieval engine registry
|
|
// - Error if initialization fails
|
|
func initRetrieveEngineRegistry(db *gorm.DB, cfg *config.Config) (interfaces.RetrieveEngineRegistry, error) {
|
|
registry := retriever.NewRetrieveEngineRegistry()
|
|
retrieveDriver := strings.Split(os.Getenv("RETRIEVE_DRIVER"), ",")
|
|
log := logger.GetLogger(context.Background())
|
|
|
|
if slices.Contains(retrieveDriver, "postgres") {
|
|
postgresRepo := postgresRepo.NewPostgresRetrieveEngineRepository(db)
|
|
if err := registry.Register(
|
|
retriever.NewKVHybridRetrieveEngine(postgresRepo, types.PostgresRetrieverEngineType),
|
|
); err != nil {
|
|
log.Errorf("Register postgres retrieve engine failed: %v", err)
|
|
} else {
|
|
log.Infof("Register postgres retrieve engine success")
|
|
}
|
|
}
|
|
if slices.Contains(retrieveDriver, "elasticsearch_v8") {
|
|
client, err := elasticsearch.NewTypedClient(elasticsearch.Config{
|
|
Addresses: []string{os.Getenv("ELASTICSEARCH_ADDR")},
|
|
Username: os.Getenv("ELASTICSEARCH_USERNAME"),
|
|
Password: os.Getenv("ELASTICSEARCH_PASSWORD"),
|
|
})
|
|
if err != nil {
|
|
log.Errorf("Create elasticsearch_v8 client failed: %v", err)
|
|
} else {
|
|
elasticsearchRepo := elasticsearchRepoV8.NewElasticsearchEngineRepository(client, cfg)
|
|
if err := registry.Register(
|
|
retriever.NewKVHybridRetrieveEngine(
|
|
elasticsearchRepo, types.ElasticsearchRetrieverEngineType,
|
|
),
|
|
); err != nil {
|
|
log.Errorf("Register elasticsearch_v8 retrieve engine failed: %v", err)
|
|
} else {
|
|
log.Infof("Register elasticsearch_v8 retrieve engine success")
|
|
}
|
|
}
|
|
}
|
|
|
|
if slices.Contains(retrieveDriver, "elasticsearch_v7") {
|
|
client, err := esv7.NewClient(esv7.Config{
|
|
Addresses: []string{os.Getenv("ELASTICSEARCH_ADDR")},
|
|
Username: os.Getenv("ELASTICSEARCH_USERNAME"),
|
|
Password: os.Getenv("ELASTICSEARCH_PASSWORD"),
|
|
})
|
|
if err != nil {
|
|
log.Errorf("Create elasticsearch_v7 client failed: %v", err)
|
|
} else {
|
|
elasticsearchRepo := elasticsearchRepoV7.NewElasticsearchEngineRepository(client, cfg)
|
|
if err := registry.Register(
|
|
retriever.NewKVHybridRetrieveEngine(
|
|
elasticsearchRepo, types.ElasticsearchRetrieverEngineType,
|
|
),
|
|
); err != nil {
|
|
log.Errorf("Register elasticsearch_v7 retrieve engine failed: %v", err)
|
|
} else {
|
|
log.Infof("Register elasticsearch_v7 retrieve engine success")
|
|
}
|
|
}
|
|
}
|
|
return registry, nil
|
|
}
|
|
|
|
// initAntsPool initializes the goroutine pool
|
|
// Creates a managed goroutine pool for concurrent task execution
|
|
// Parameters:
|
|
// - cfg: Application configuration
|
|
//
|
|
// Returns:
|
|
// - Configured goroutine pool
|
|
// - Error if initialization fails
|
|
func initAntsPool(cfg *config.Config) (*ants.Pool, error) {
|
|
// Default to 5 if not specified in config
|
|
poolSize := os.Getenv("CONCURRENCY_POOL_SIZE")
|
|
if poolSize == "" {
|
|
poolSize = "5"
|
|
}
|
|
poolSizeInt, err := strconv.Atoi(poolSize)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
// Set up the pool with pre-allocation for better performance
|
|
return ants.NewPool(poolSizeInt, ants.WithPreAlloc(true))
|
|
}
|
|
|
|
// registerPoolCleanup registers the goroutine pool for cleanup
|
|
// Ensures proper cleanup of the goroutine pool when application shuts down
|
|
// Parameters:
|
|
// - pool: Goroutine pool
|
|
// - cleaner: Resource cleaner
|
|
func registerPoolCleanup(pool *ants.Pool, cleaner interfaces.ResourceCleaner) {
|
|
cleaner.RegisterWithName("AntsPool", func() error {
|
|
pool.Release()
|
|
return nil
|
|
})
|
|
}
|
|
|
|
// initDocReaderClient initializes the document reader client
|
|
// Creates a client for interacting with the document reader service
|
|
// Parameters:
|
|
// - cfg: Application configuration
|
|
//
|
|
// Returns:
|
|
// - Configured document reader client
|
|
// - Error if initialization fails
|
|
func initDocReaderClient(cfg *config.Config) (*client.Client, error) {
|
|
// Use the DocReader URL from environment or config
|
|
docReaderURL := os.Getenv("DOCREADER_ADDR")
|
|
if docReaderURL == "" && cfg.DocReader != nil {
|
|
docReaderURL = cfg.DocReader.Addr
|
|
}
|
|
return client.NewClient(docReaderURL)
|
|
}
|
|
|
|
// initOllamaService initializes the Ollama service client
|
|
// Creates a client for interacting with Ollama API for model inference
|
|
// Parameters:
|
|
// - None
|
|
//
|
|
// Returns:
|
|
// - Configured Ollama service client
|
|
// - Error if initialization fails
|
|
func initOllamaService() (*ollama.OllamaService, error) {
|
|
// Get Ollama service from existing factory function
|
|
return ollama.GetOllamaService()
|
|
}
|