254 lines
9.7 KiB
Go
254 lines
9.7 KiB
Go
package config
|
|
|
|
import (
|
|
"fmt"
|
|
"log"
|
|
"os"
|
|
"path/filepath"
|
|
"regexp"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/go-viper/mapstructure/v2"
|
|
"github.com/joho/godotenv"
|
|
"github.com/spf13/viper"
|
|
)
|
|
|
|
// Config 应用程序总配置
|
|
type Config struct {
|
|
Conversation *ConversationConfig `yaml:"conversation" json:"conversation"`
|
|
Server *ServerConfig `yaml:"server" json:"server"`
|
|
KnowledgeBase *KnowledgeBaseConfig `yaml:"knowledge_base" json:"knowledge_base"`
|
|
Tenant *TenantConfig `yaml:"tenant" json:"tenant"`
|
|
Models []ModelConfig `yaml:"models" json:"models"`
|
|
Asynq *AsynqConfig `yaml:"asynq" json:"asynq"`
|
|
VectorDatabase *VectorDatabaseConfig `yaml:"vector_database" json:"vector_database"`
|
|
DocReader *DocReaderConfig `yaml:"docreader" json:"docreader"`
|
|
StreamManager *StreamManagerConfig `yaml:"stream_manager" json:"stream_manager"`
|
|
}
|
|
|
|
// DocReaderConfig holds connection settings for the document reader service.
type DocReaderConfig struct {
	// Addr is the address of the document reader service.
	Addr string `yaml:"addr" json:"addr"`
}
|
|
|
|
// VectorDatabaseConfig selects the vector database backend.
type VectorDatabaseConfig struct {
	// Driver names the backend; the accepted values are defined by the code
	// that consumes this config.
	Driver string `yaml:"driver" json:"driver"`
}
|
|
|
|
// ConversationConfig 对话服务配置
|
|
type ConversationConfig struct {
|
|
MaxRounds int `yaml:"max_rounds" json:"max_rounds"`
|
|
KeywordThreshold float64 `yaml:"keyword_threshold" json:"keyword_threshold"`
|
|
EmbeddingTopK int `yaml:"embedding_top_k" json:"embedding_top_k"`
|
|
VectorThreshold float64 `yaml:"vector_threshold" json:"vector_threshold"`
|
|
RerankTopK int `yaml:"rerank_top_k" json:"rerank_top_k"`
|
|
RerankThreshold float64 `yaml:"rerank_threshold" json:"rerank_threshold"`
|
|
FallbackStrategy string `yaml:"fallback_strategy" json:"fallback_strategy"`
|
|
FallbackResponse string `yaml:"fallback_response" json:"fallback_response"`
|
|
FallbackPrompt string `yaml:"fallback_prompt" json:"fallback_prompt"`
|
|
EnableRewrite bool `yaml:"enable_rewrite" json:"enable_rewrite"`
|
|
EnableRerank bool `yaml:"enable_rerank" json:"enable_rerank"`
|
|
Summary *SummaryConfig `yaml:"summary" json:"summary"`
|
|
GenerateSessionTitlePrompt string `yaml:"generate_session_title_prompt" json:"generate_session_title_prompt"`
|
|
GenerateSummaryPrompt string `yaml:"generate_summary_prompt" json:"generate_summary_prompt"`
|
|
RewritePromptSystem string `yaml:"rewrite_prompt_system" json:"rewrite_prompt_system"`
|
|
RewritePromptUser string `yaml:"rewrite_prompt_user" json:"rewrite_prompt_user"`
|
|
SimplifyQueryPrompt string `yaml:"simplify_query_prompt" json:"simplify_query_prompt"`
|
|
SimplifyQueryPromptUser string `yaml:"simplify_query_prompt_user" json:"simplify_query_prompt_user"`
|
|
ExtractEntitiesPrompt string `yaml:"extract_entities_prompt" json:"extract_entities_prompt"`
|
|
ExtractRelationshipsPrompt string `yaml:"extract_relationships_prompt" json:"extract_relationships_prompt"`
|
|
}
|
|
|
|
// SummaryConfig holds the parameters used when generating summaries:
// generation limits and sampling values, the prompt and context template,
// and the prefix used when nothing matches.
type SummaryConfig struct {
	MaxTokens           int     `yaml:"max_tokens" json:"max_tokens"`
	RepeatPenalty       float64 `yaml:"repeat_penalty" json:"repeat_penalty"`
	TopK                int     `yaml:"top_k" json:"top_k"`
	TopP                float64 `yaml:"top_p" json:"top_p"`
	FrequencyPenalty    float64 `yaml:"frequency_penalty" json:"frequency_penalty"`
	PresencePenalty     float64 `yaml:"presence_penalty" json:"presence_penalty"`
	Prompt              string  `yaml:"prompt" json:"prompt"`
	ContextTemplate     string  `yaml:"context_template" json:"context_template"`
	Temperature         float64 `yaml:"temperature" json:"temperature"`
	Seed                int     `yaml:"seed" json:"seed"`
	MaxCompletionTokens int     `yaml:"max_completion_tokens" json:"max_completion_tokens"`
	NoMatchPrefix       string  `yaml:"no_match_prefix" json:"no_match_prefix"`
}
|
|
|
|
// ServerConfig holds the HTTP server settings.
type ServerConfig struct {
	Port    int    `yaml:"port" json:"port"`
	Host    string `yaml:"host" json:"host"`
	LogPath string `yaml:"log_path" json:"log_path"`

	// ShutdownTimeout bounds how long shutdown may take.
	// NOTE(review): LoadConfig does not interpret the `default` tag, so when
	// the key is absent this field decodes as 0 unless something else applies
	// the "30s" default — confirm.
	ShutdownTimeout time.Duration `yaml:"shutdown_timeout" json:"shutdown_timeout" default:"30s"`
}
|
|
|
|
// KnowledgeBaseConfig 知识库配置
|
|
type KnowledgeBaseConfig struct {
|
|
ChunkSize int `yaml:"chunk_size" json:"chunk_size"`
|
|
ChunkOverlap int `yaml:"chunk_overlap" json:"chunk_overlap"`
|
|
SplitMarkers []string `yaml:"split_markers" json:"split_markers"`
|
|
KeepSeparator bool `yaml:"keep_separator" json:"keep_separator"`
|
|
ImageProcessing *ImageProcessingConfig `yaml:"image_processing" json:"image_processing"`
|
|
}
|
|
|
|
// ImageProcessingConfig holds image-processing settings.
type ImageProcessingConfig struct {
	// EnableMultimodal toggles multimodal processing.
	EnableMultimodal bool `yaml:"enable_multimodal" json:"enable_multimodal"`
}
|
|
|
|
// TenantConfig holds per-tenant defaults for newly created sessions.
type TenantConfig struct {
	DefaultSessionName        string `yaml:"default_session_name" json:"default_session_name"`
	DefaultSessionTitle       string `yaml:"default_session_title" json:"default_session_title"`
	DefaultSessionDescription string `yaml:"default_session_description" json:"default_session_description"`
}
|
|
|
|
// ModelConfig describes one model entry from the config file.
type ModelConfig struct {
	Type      string `yaml:"type" json:"type"`
	Source    string `yaml:"source" json:"source"`
	ModelName string `yaml:"model_name" json:"model_name"`

	// Parameters holds free-form, model-specific settings exactly as decoded
	// from the config file; interpreting them is up to the consumer.
	Parameters map[string]interface{} `yaml:"parameters" json:"parameters"`
}
|
|
|
|
// AsynqConfig holds the settings for the Asynq integration: connection
// address and credentials, read/write timeouts, and a concurrency level.
type AsynqConfig struct {
	Addr         string        `yaml:"addr" json:"addr"`
	Username     string        `yaml:"username" json:"username"`
	Password     string        `yaml:"password" json:"password"`
	ReadTimeout  time.Duration `yaml:"read_timeout" json:"read_timeout"`
	WriteTimeout time.Duration `yaml:"write_timeout" json:"write_timeout"`
	Concurrency  int           `yaml:"concurrency" json:"concurrency"`
}
|
|
|
|
// StreamManagerConfig 流管理器配置
|
|
type StreamManagerConfig struct {
|
|
Type string `yaml:"type" json:"type"` // 类型: "memory" 或 "redis"
|
|
Redis RedisConfig `yaml:"redis" json:"redis"` // Redis配置
|
|
CleanupTimeout time.Duration `yaml:"cleanup_timeout" json:"cleanup_timeout"` // 清理超时,单位秒
|
|
}
|
|
|
|
// RedisConfig holds Redis connection settings.
type RedisConfig struct {
	Address  string `yaml:"address" json:"address"`   // Redis address
	Password string `yaml:"password" json:"password"` // Redis password
	DB       int    `yaml:"db" json:"db"`             // Redis database index
	Prefix   string `yaml:"prefix" json:"prefix"`     // key prefix

	// TTL is the key expiration time.
	// NOTE(review): the original comment said "hours", but the field is a
	// time.Duration; a bare integer decoded via viper becomes nanoseconds
	// unless a consumer multiplies it by time.Hour — confirm.
	TTL time.Duration `yaml:"ttl" json:"ttl"`
}
|
|
|
|
// LoadConfig 从配置文件加载配置
|
|
func LoadConfig() (*Config, error) {
|
|
// 设置配置文件名和路径
|
|
viper.SetConfigName("config") // 配置文件名称(不带扩展名)
|
|
viper.SetConfigType("yaml") // 配置文件类型
|
|
viper.AddConfigPath(".") // 当前目录
|
|
viper.AddConfigPath("./config") // config子目录
|
|
viper.AddConfigPath("$HOME/.appname") // 用户目录
|
|
viper.AddConfigPath("/etc/appname/") // etc目录
|
|
|
|
// 启用环境变量替换
|
|
viper.AutomaticEnv()
|
|
viper.SetEnvKeyReplacer(strings.NewReplacer(".", "_"))
|
|
|
|
// 读取配置文件
|
|
if err := viper.ReadInConfig(); err != nil {
|
|
return nil, fmt.Errorf("error reading config file: %w", err)
|
|
}
|
|
|
|
// 替换配置中的环境变量引用
|
|
configFileContent, err := os.ReadFile(viper.ConfigFileUsed())
|
|
if err != nil {
|
|
return nil, fmt.Errorf("error reading config file content: %w", err)
|
|
}
|
|
|
|
// 替换${ENV_VAR}格式的环境变量引用
|
|
re := regexp.MustCompile(`\${([^}]+)}`)
|
|
result := re.ReplaceAllStringFunc(string(configFileContent), func(match string) string {
|
|
// 提取环境变量名称(去掉${}部分)
|
|
envVar := match[2 : len(match)-1]
|
|
// 获取环境变量值,如果不存在则保持原样
|
|
if value := os.Getenv(envVar); value != "" {
|
|
return value
|
|
}
|
|
return match
|
|
})
|
|
|
|
// 使用处理后的配置内容
|
|
viper.ReadConfig(strings.NewReader(result))
|
|
|
|
// 解析配置到结构体
|
|
var cfg Config
|
|
if err := viper.Unmarshal(&cfg, func(dc *mapstructure.DecoderConfig) {
|
|
dc.TagName = "yaml"
|
|
}); err != nil {
|
|
return nil, fmt.Errorf("unable to decode config into struct: %w", err)
|
|
}
|
|
fmt.Printf("Using configuration file: %s\n", viper.ConfigFileUsed())
|
|
return &cfg, nil
|
|
}
|
|
|
|
func SetEnv() {
|
|
setEnvFileConfigToEnv()
|
|
envSet("DB_HOST", "127.0.0.1")
|
|
envSet("DB_PORT", "5432")
|
|
envSet("TZ", "Asia/Shanghai")
|
|
envSet("OTEL_SERVICE_NAME", "WeKnora")
|
|
envSet("OTEL_TRACES_EXPORTER", "otlp")
|
|
envSet("OTEL_METRICS_EXPORTER", "none")
|
|
envSet("OTEL_LOGS_EXPORTER", "none")
|
|
envSet("OTEL_PROPAGATORS", "tracecontext,baggage")
|
|
envSet("DOCREADER_ADDR", "127.0.0.1:50051")
|
|
envSet("MINIO_ENDPOINT", "127.0.0.1:9000")
|
|
envSet("REDIS_ADDR", "127.0.0.1:6379")
|
|
|
|
}
|
|
|
|
var (
|
|
loadEnvOnce sync.Once
|
|
)
|
|
|
|
func setEnvFileConfigToEnv() {
|
|
loadEnvOnce.Do(func() {
|
|
modDir, err := getModuleDir()
|
|
if err != nil {
|
|
return
|
|
}
|
|
envDir := fmt.Sprintf("%s/.env", modDir)
|
|
err = godotenv.Load(envDir)
|
|
if err != nil {
|
|
log.Fatal("Error loading .env file:", err)
|
|
}
|
|
})
|
|
}
|
|
|
|
// envSet sets key to value in the process environment, panicking with the
// underlying error if the operating system rejects the assignment (for
// example, an empty key).
func envSet(key, value string) {
	err := os.Setenv(key, value)
	if err == nil {
		return
	}
	panic(err)
}
|
|
|
|
// getModuleDir returns the nearest directory, starting at the current working
// directory and walking up toward the filesystem root, that contains a go.mod
// file. Any os.Stat failure on a candidate go.mod is treated as "not here".
// It returns an error if the working directory cannot be determined or if no
// go.mod exists in it or any of its ancestors.
func getModuleDir() (string, error) {
	cwd, err := os.Getwd()
	if err != nil {
		return "", err
	}

	// Walk upward; filepath.Dir of the root returns the root itself, which
	// makes dir == prev and ends the loop.
	for dir, prev := cwd, ""; dir != prev; dir, prev = filepath.Dir(dir), dir {
		if _, statErr := os.Stat(filepath.Join(dir, "go.mod")); statErr == nil {
			return dir, nil
		}
	}

	return "", fmt.Errorf("go.mod not found in current directory or parents")
}
|