package middleware import ( "context" "errors" "log" "net/http" "slices" "strings" "github.com/gin-gonic/gin" "knowlege-lsxd/internal/config" "knowlege-lsxd/internal/types" "knowlege-lsxd/internal/types/interfaces" ) // 无需认证的API列表 var noAuthAPI = map[string][]string{ "/api/v1/test-data": {"GET"}, "/api/v1/tenants": {"POST"}, "/api/v1/initialization/*": {"GET", "POST"}, } // 检查请求是否在无需认证的API列表中 func isNoAuthAPI(path string, method string) bool { for api, methods := range noAuthAPI { // 如果以*结尾,按照前缀匹配,否则按照全路径匹配 if strings.HasSuffix(api, "*") { if strings.HasPrefix(path, strings.TrimSuffix(api, "*")) && slices.Contains(methods, method) { return true } } else if path == api && slices.Contains(methods, method) { return true } } return false } // Auth 认证中间件 func Auth(tenantService interfaces.TenantService, cfg *config.Config) gin.HandlerFunc { return func(c *gin.Context) { // ignore OPTIONS request if c.Request.Method == "OPTIONS" { c.Next() return } // 检查请求是否在无需认证的API列表中 if isNoAuthAPI(c.Request.URL.Path, c.Request.Method) { c.Next() return } // Get API Key from request header apiKey := c.GetHeader("X-API-Key") if apiKey == "" { c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"}) c.Abort() return } // Get tenant information tenantID, err := tenantService.ExtractTenantIDFromAPIKey(apiKey) if err != nil { c.JSON(http.StatusUnauthorized, gin.H{ "error": "Unauthorized: invalid API key format", }) c.Abort() return } // Verify API key validity (matches the one in database) t, err := tenantService.GetTenantByID(c.Request.Context(), tenantID) if err != nil { log.Printf("Error getting tenant by ID: %v, tenantID: %d, apiKey: %s", err, tenantID, apiKey) c.JSON(http.StatusUnauthorized, gin.H{ "error": "Unauthorized: invalid API key", }) c.Abort() return } if t == nil || t.APIKey != apiKey { c.JSON(http.StatusUnauthorized, gin.H{ "error": "Unauthorized: invalid API key", }) c.Abort() return } // Store tenant ID in context c.Set(types.TenantIDContextKey.String(), tenantID) c.Set(types.TenantInfoContextKey.String(), t) c.Request = c.Request.WithContext( context.WithValue( context.WithValue(c.Request.Context(), types.TenantIDContextKey, tenantID), types.TenantInfoContextKey, t, ), ) c.Next() } } // GetTenantIDFromContext helper function to get tenant ID from context func GetTenantIDFromContext(ctx context.Context) (uint, error) { tenantID, ok := ctx.Value("tenantID").(uint) if !ok { return 0, errors.New("tenant ID not found in context") } return tenantID, nil }