// l_ai_knowledge/internal/application/service/common_test.go
//
// 89 lines
// 2.1 KiB
// Go

package service
import (
"fmt"
"knowlege-lsxd/internal/application/repository"
"knowlege-lsxd/internal/config"
"knowlege-lsxd/internal/models/utils/ollama"
"knowlege-lsxd/internal/types/interfaces"
"os"
"time"
"gorm.io/driver/postgres"
"gorm.io/gorm"
)
// repos groups the repository implementations used by the tests in this
// package.
type repos struct {
	Chunk        interfaces.ChunkRepository
	KnowlegeBase interfaces.KnowledgeBaseRepository
	Knowlege     interfaces.KnowledgeRepository
	Model        interfaces.ModelRepository
	Tenants      interfaces.TenantRepository
}

// Repo is the shared set of repositories for this package's tests. It is
// built once, during package initialization, by calling setRepo.
var Repo = setRepo()
// setRepo prepares the environment, opens a database connection and wires
// every repository in repos to that single connection.
//
// config.SetEnv is called first; presumably it populates the DB_* environment
// variables that InitDatabase reads — TODO confirm.
//
// It panics if InitDatabase fails. The panic value wraps the underlying error
// (so errors.Is/As still work on a recovered value) and says which step failed.
func setRepo() *repos {
	config.SetEnv()

	database, err := InitDatabase(nil)
	if err != nil {
		panic(fmt.Errorf("service test setup: initialize database: %w", err))
	}

	return &repos{
		Chunk:        repository.NewChunkRepository(database),
		KnowlegeBase: repository.NewKnowledgeBaseRepository(database),
		Knowlege:     repository.NewKnowledgeRepository(database),
		Model:        repository.NewModelRepository(database),
		Tenants:      repository.NewTenantRepository(database),
	}
}
// Ollama returns the service instance obtained from ollama.GetOllamaService.
//
// It panics with the returned error if the service cannot be obtained, so
// callers always receive a usable *ollama.OllamaService.
func Ollama() *ollama.OllamaService {
	svc, err := ollama.GetOllamaService()
	if err == nil {
		return svc
	}
	panic(err)
}
// InitDatabase opens a GORM connection configured from environment variables.
//
// DB_DRIVER selects the dialect; only "postgres" is supported. For postgres the
// connection string is built from DB_HOST, DB_PORT, DB_USER, DB_PASSWORD and
// DB_NAME, with sslmode=disable.
//
// Each value is single-quoted and escaped, so values containing spaces, single
// quotes or backslashes, and empty values, cannot corrupt the keyword/value
// DSN.
//
// cfg is currently unused. It is kept so the signature stays stable for
// callers.
//
// On success, the underlying *sql.DB pool keeps at most 10 idle connections and
// reuses each connection for at most 10 minutes.
//
// It returns an error if DB_DRIVER is unsupported, if the connection cannot be
// opened, or if the underlying *sql.DB cannot be obtained.
func InitDatabase(cfg *config.Config) (*gorm.DB, error) {
	var dialector gorm.Dialector

	driver := os.Getenv("DB_DRIVER")
	switch driver {
	case "postgres":
		// quote renders v as a quoted keyword/value DSN value: wrap it in
		// single quotes and backslash-escape any '\' or '\''. Both characters
		// are ASCII, so scanning byte-by-byte is safe for UTF-8 input.
		quote := func(v string) string {
			buf := make([]byte, 0, len(v)+2)
			buf = append(buf, '\'')
			for i := 0; i < len(v); i++ {
				if v[i] == '\\' || v[i] == '\'' {
					buf = append(buf, '\\')
				}
				buf = append(buf, v[i])
			}
			return string(append(buf, '\''))
		}
		dsn := fmt.Sprintf(
			"host=%s port=%s user=%s password=%s dbname=%s sslmode=%s",
			quote(os.Getenv("DB_HOST")),
			quote(os.Getenv("DB_PORT")),
			quote(os.Getenv("DB_USER")),
			quote(os.Getenv("DB_PASSWORD")),
			quote(os.Getenv("DB_NAME")),
			quote("disable"),
		)
		dialector = postgres.Open(dsn)
	default:
		return nil, fmt.Errorf("unsupported database driver: %q", driver)
	}

	db, err := gorm.Open(dialector, &gorm.Config{})
	if err != nil {
		return nil, fmt.Errorf("open %s database: %w", driver, err)
	}

	// Get the underlying *sql.DB so the connection pool can be tuned.
	sqlDB, err := db.DB()
	if err != nil {
		return nil, fmt.Errorf("get underlying sql.DB: %w", err)
	}

	// Configure connection pool parameters.
	sqlDB.SetMaxIdleConns(10)
	sqlDB.SetConnMaxLifetime(10 * time.Minute)

	return db, nil
}