diff --git a/cmd/server/main.go b/cmd/server/main.go index 4f1ea4d..b46fe3a 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -30,10 +30,10 @@ func main() { if os.Getenv("GIN_MODE") == "release" { gin.SetMode(gin.ReleaseMode) } else { + config.SetEnv() gin.SetMode(gin.DebugMode) - } - //config.SetEnv() + // Build dependency injection container c := container.BuildContainer(runtime.GetContainer()) diff --git a/docker-compose.yml b/docker-compose.yml index 2b8ea7d..bf5fc5f 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -1,6 +1,6 @@ services: app: - image: wechatopenai/weknora-app:latest + image: ai-app:latest container_name: WeKnora-app ports: - "8080:8080" diff --git a/internal/application/repository/tenant.go b/internal/application/repository/tenant.go index c46618d..037656d 100644 --- a/internal/application/repository/tenant.go +++ b/internal/application/repository/tenant.go @@ -4,11 +4,12 @@ import ( "context" "errors" - "gorm.io/gorm" - "gorm.io/gorm/clause" "knowlege-lsxd/internal/logger" "knowlege-lsxd/internal/types" "knowlege-lsxd/internal/types/interfaces" + + "gorm.io/gorm" + "gorm.io/gorm/clause" ) var ( @@ -43,6 +44,18 @@ func (r *tenantRepository) GetTenantByID(ctx context.Context, id uint) (*types.T return &tenant, nil } +// GetTenantByApiKey gets tenant by apikey +func (r *tenantRepository) GetTenantByApiKey(ctx context.Context, apikey string) (*types.Tenant, error) { + var tenant types.Tenant + if err := r.db.WithContext(ctx).Where("api_key = ?", apikey).First(&tenant).Error; err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, ErrTenantNotFound + } + return nil, err + } + return &tenant, nil +} + // ListTenants lists all tenants func (r *tenantRepository) ListTenants(ctx context.Context) ([]*types.Tenant, error) { var tenants []*types.Tenant diff --git a/internal/application/service/session.go b/internal/application/service/session.go index 0ed14d9..10193ee 100644 --- a/internal/application/service/session.go +++ b/internal/application/service/session.go @@ -355,6 +355,7 @@ func (s *sessionService) KnowledgeQA(ctx context.Context, sessionID, query strin Seed: session.SummaryParameters.Seed, NoMatchPrefix: session.SummaryParameters.NoMatchPrefix, MaxCompletionTokens: session.SummaryParameters.MaxCompletionTokens, + Thinking: session.SummaryParameters.Thinking, }, FallbackResponse: session.FallbackResponse, } diff --git a/internal/application/service/tenant.go b/internal/application/service/tenant.go index a726910..55b2137 100644 --- a/internal/application/service/tenant.go +++ b/internal/application/service/tenant.go @@ -102,6 +102,21 @@ func (s *tenantService) GetTenantByID(ctx context.Context, id uint) (*types.Tena return tenant, nil } +// GetTenantByID retrieves a tenant by their ID +func (s *tenantService) GetTenantByApiKey(ctx context.Context, apikey string) (*types.Tenant, error) { + + tenant, err := s.repo.GetTenantByApiKey(ctx, apikey) + if err != nil { + logger.ErrorWithFields(ctx, err, map[string]interface{}{ + "apikey": apikey, + }) + return nil, err + } + + logger.Infof(ctx, "Tenant retrieved successfully, ID: %d, name: %s", tenant.ID, tenant.Name) + return tenant, nil +} + // ListTenants retrieves a list of all tenants func (s *tenantService) ListTenants(ctx context.Context) ([]*types.Tenant, error) { logger.Info(ctx, "Start retrieving tenant list") diff --git a/internal/config/config.go b/internal/config/config.go index 2e03cb3..53cb3a4 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -210,6 +210,8 @@ func SetEnv() { envSet("DOCREADER_ADDR", "127.0.0.1:50051") envSet("MINIO_ENDPOINT", "127.0.0.1:9000") envSet("REDIS_ADDR", "127.0.0.1:6379") + envSet("OLLAMA_BASE_URL", "http://localhost:11434") + envSet("TENANT_AES_KEY", "knowlege-lsxd-2025-ase-secret") } diff --git a/internal/middleware/auth.go b/internal/middleware/auth.go index caa4820..c7cc178 100644 --- a/internal/middleware/auth.go +++ b/internal/middleware/auth.go @@ -8,10 +8,11 @@ import ( "slices" "strings" - "github.com/gin-gonic/gin" "knowlege-lsxd/internal/config" "knowlege-lsxd/internal/types" "knowlege-lsxd/internal/types/interfaces" + + "github.com/gin-gonic/gin" ) // 无需认证的API列表 @@ -60,19 +61,19 @@ func Auth(tenantService interfaces.TenantService, cfg *config.Config) gin.Handle } // Get tenant information - tenantID, err := tenantService.ExtractTenantIDFromAPIKey(apiKey) - if err != nil { - c.JSON(http.StatusUnauthorized, gin.H{ - "error": "Unauthorized: invalid API key format", - }) - c.Abort() - return - } + //tenantID, err := tenantService.ExtractTenantIDFromAPIKey(apiKey) + //if err != nil { + // c.JSON(http.StatusUnauthorized, gin.H{ + // "error": "Unauthorized: invalid API key format", + // }) + // c.Abort() + // return + //} // Verify API key validity (matches the one in database) - t, err := tenantService.GetTenantByID(c.Request.Context(), tenantID) + t, err := tenantService.GetTenantByApiKey(c.Request.Context(), apiKey) if err != nil { - log.Printf("Error getting tenant by ID: %v, tenantID: %d, apiKey: %s", err, tenantID, apiKey) + log.Printf("Error getting tenant by ID: %v, tenantID: %d, apiKey: %s", err, t.ID, apiKey) c.JSON(http.StatusUnauthorized, gin.H{ "error": "Unauthorized: invalid API key", }) @@ -89,11 +90,11 @@ func Auth(tenantService interfaces.TenantService, cfg *config.Config) gin.Handle } // Store tenant ID in context - c.Set(types.TenantIDContextKey.String(), tenantID) + c.Set(types.TenantIDContextKey.String(), t.ID) c.Set(types.TenantInfoContextKey.String(), t) c.Request = c.Request.WithContext( context.WithValue( - context.WithValue(c.Request.Context(), types.TenantIDContextKey, tenantID), + context.WithValue(c.Request.Context(), types.TenantIDContextKey, t.ID), types.TenantInfoContextKey, t, ), ) diff --git a/internal/models/utils/ollama/ollama.go b/internal/models/utils/ollama/ollama.go index 8fe5e70..e797399 100644 --- a/internal/models/utils/ollama/ollama.go +++ b/internal/models/utils/ollama/ollama.go @@ -8,6 +8,7 @@ import ( "os" "strings" "sync" + "time" "knowlege-lsxd/internal/logger" @@ -28,7 +29,7 @@ func GetOllamaService() (*OllamaService, error) { // Get Ollama base URL from environment variable, if not set use provided baseURL or default value logger.GetLogger(context.Background()).Infof("Ollama base URL: %s", os.Getenv("OLLAMA_BASE_URL")) baseURL := "http://localhost:11434" - envURL := "" //os.Getenv("OLLAMA_BASE_URL") + envURL := os.Getenv("OLLAMA_BASE_URL") if envURL != "" { baseURL = envURL } @@ -285,6 +286,21 @@ func (s *OllamaService) Chat(ctx context.Context, req *api.ChatRequest, fn api.C if err := s.StartService(ctx); err != nil { return err } + if req.KeepAlive == nil { + req.KeepAlive = &api.Duration{ + time.Hour * 24, + } + } + if req.Think != nil { + switch req.Think.Value { + case true: + req.Think = nil + case false: + req.Think = &api.ThinkValue{ + Value: false, + } + } + } // Use official client Chat method return s.client.Chat(ctx, req, fn) diff --git a/internal/router/router.go b/internal/router/router.go index 36b9236..8508495 100644 --- a/internal/router/router.go +++ b/internal/router/router.go @@ -66,6 +66,7 @@ func NewRouter(params RouterParams) *gin.Engine { // 初始化接口(不需要认证) r.GET("/api/v1/initialization/status", params.InitializationHandler.CheckStatus) r.GET("/api/v1/initialization/config", params.InitializationHandler.GetCurrentConfig) + r.POST("/api/v1/initialization/initialize", params.InitializationHandler.Initialize) // Ollama相关接口(不需要认证) r.GET("/api/v1/initialization/ollama/status", params.InitializationHandler.CheckOllamaStatus) diff --git a/internal/types/interfaces/tenant.go b/internal/types/interfaces/tenant.go index 20811ee..19c3286 100644 --- a/internal/types/interfaces/tenant.go +++ b/internal/types/interfaces/tenant.go @@ -12,6 +12,8 @@ type TenantService interface { CreateTenant(ctx context.Context, tenant *types.Tenant) (*types.Tenant, error) // GetTenantByID gets a tenant by ID GetTenantByID(ctx context.Context, id uint) (*types.Tenant, error) + // GetTenantByID gets a tenant by ID + GetTenantByApiKey(ctx context.Context, apikey string) (*types.Tenant, error) // ListTenants lists all tenants ListTenants(ctx context.Context) ([]*types.Tenant, error) // UpdateTenant updates a tenant @@ -30,6 +32,8 @@ type TenantRepository interface { CreateTenant(ctx context.Context, tenant *types.Tenant) error // GetTenantByID gets a tenant by ID GetTenantByID(ctx context.Context, id uint) (*types.Tenant, error) + // GetTenantByID gets a tenant by ID + GetTenantByApiKey(ctx context.Context, apikey string) (*types.Tenant, error) // ListTenants lists all tenants ListTenants(ctx context.Context) ([]*types.Tenant, error) // UpdateTenant updates a tenant