diff --git a/server/cmd/server/main.go b/server/cmd/server/main.go index d767e43..2e81b2d 100644 --- a/server/cmd/server/main.go +++ b/server/cmd/server/main.go @@ -84,7 +84,14 @@ func main() { } else { log.Println("warning: gRPC server address not configured, /api/ymt/users will not work") } - r := api.NewRouter(meta, marketing, marketingAuth, resellerDB, ymt, grpcAddr) + // 获取营销系统API域名 + marketingAPIDomain := cfg.MarketingAPIDomain + if marketingAPIDomain == "" { + log.Println("warning: marketing API domain not configured, authentication will not work") + } else { + log.Println("marketing API domain:", marketingAPIDomain) + } + r := api.NewRouter(meta, marketing, marketingAuth, resellerDB, ymt, grpcAddr, marketingAPIDomain) addr := ":" + func() string { s := cfg.Port if s == "" { diff --git a/server/internal/api/access.go b/server/internal/api/access.go index 5c838f7..eb3b419 100644 --- a/server/internal/api/access.go +++ b/server/internal/api/access.go @@ -26,7 +26,7 @@ func withAccess(h http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Access-Control-Allow-Origin", "*") w.Header().Set("Access-Control-Allow-Methods", "GET,POST,PATCH,DELETE,OPTIONS") - w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization") + w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization, token") if r.Method == http.MethodOptions { w.WriteHeader(http.StatusOK) return diff --git a/server/internal/api/middleware.go b/server/internal/api/middleware.go index d403f9f..bf45b2c 100644 --- a/server/internal/api/middleware.go +++ b/server/internal/api/middleware.go @@ -1,11 +1,13 @@ package api import ( - "context" - "crypto/rand" - "encoding/hex" - "encoding/json" - "net/http" + "context" + "crypto/rand" + "encoding/hex" + "encoding/json" + "fmt" + "io" + "net/http" ) type ctxKey string @@ -14,21 +16,22 @@ var traceKey ctxKey = "trace_id" var sqlKey ctxKey = "sql" var metaKey ctxKey = "req_meta" var payloadKey ctxKey = "payload" +var creatorIDsKey ctxKey = "creator_ids" // 存储用户的创建者ID列表 func withTrace(h http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - tid := r.Header.Get("X-Request-ID") - if tid == "" { - buf := make([]byte, 16) - _, _ = rand.Read(buf) - tid = hex.EncodeToString(buf) - } - w.Header().Set("X-Request-ID", tid) - m := ReqMeta{Method: r.Method, Path: r.URL.Path, Query: r.URL.RawQuery, Remote: r.RemoteAddr} - ctx := context.WithValue(r.Context(), traceKey, tid) - ctx = context.WithValue(ctx, metaKey, m) - h.ServeHTTP(w, r.WithContext(ctx)) - }) + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + tid := r.Header.Get("X-Request-ID") + if tid == "" { + buf := make([]byte, 16) + _, _ = rand.Read(buf) + tid = hex.EncodeToString(buf) + } + w.Header().Set("X-Request-ID", tid) + m := ReqMeta{Method: r.Method, Path: r.URL.Path, Query: r.URL.RawQuery, Remote: r.RemoteAddr} + ctx := context.WithValue(r.Context(), traceKey, tid) + ctx = context.WithValue(ctx, metaKey, m) + h.ServeHTTP(w, r.WithContext(ctx)) + }) } func TraceIDFrom(r *http.Request) string { @@ -41,44 +44,120 @@ func TraceIDFrom(r *http.Request) string { } func WithSQL(r *http.Request, sql string) *http.Request { - return r.WithContext(context.WithValue(r.Context(), sqlKey, sql)) + return r.WithContext(context.WithValue(r.Context(), sqlKey, sql)) } func SQLFrom(r *http.Request) string { - v := r.Context().Value(sqlKey) - if v == nil { - return "" - } - s, _ := v.(string) - return s + v := r.Context().Value(sqlKey) + if v == nil { + return "" + } + s, _ := v.(string) + return s } type ReqMeta struct { - Method string - Path string - Query string - Remote string + Method string + Path string + Query string + Remote string } func MetaFrom(r *http.Request) ReqMeta { - v := r.Context().Value(metaKey) - if v == nil { - return ReqMeta{} - } - m, _ := v.(ReqMeta) - return m + v := r.Context().Value(metaKey) + if v == nil { + return ReqMeta{} + } + m, _ := v.(ReqMeta) + return m } func WithPayload(r *http.Request, v interface{}) *http.Request { - b, _ := json.Marshal(v) - return r.WithContext(context.WithValue(r.Context(), payloadKey, string(b))) + b, _ := json.Marshal(v) + return r.WithContext(context.WithValue(r.Context(), payloadKey, string(b))) } func PayloadFrom(r *http.Request) string { - v := r.Context().Value(payloadKey) - if v == nil { - return "" - } - s, _ := v.(string) - return s + v := r.Context().Value(payloadKey) + if v == nil { + return "" + } + s, _ := v.(string) + return s +} + +// AuthResponse 营销系统鉴权接口返回结构 +type AuthResponse struct { + Code int `json:"code"` + Message string `json:"message"` + Data []int `json:"data"` // 创建者ID列表 +} + +// withAuth 认证中间件:验证 token 并获取用户数据权限 +func withAuth(apiDomain string) func(http.Handler) http.Handler { + return func(h http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // 获取 token + token := r.Header.Get("token") + if token == "" { + http.Error(w, "{\"code\":401,\"message\":\"未提供认证token\",\"data\":null}", http.StatusUnauthorized) + return + } + + // 请求营销系统鉴权接口 + authURL := fmt.Sprintf("%s/auth/admin/dataPermission", apiDomain) + req, err := http.NewRequest("GET", authURL, nil) + if err != nil { + http.Error(w, "{\"code\":500,\"message\":\"创建认证请求失败\",\"data\":null}", http.StatusInternalServerError) + return + } + + // 设置请求头(根据图片示例) + req.Header.Set("Authorization", "Bearer "+token) + + // 发起请求 + client := &http.Client{} + resp, err := client.Do(req) + if err != nil { + http.Error(w, "{\"code\":500,\"message\":\"认证请求失败\",\"data\":null}", http.StatusInternalServerError) + return + } + defer resp.Body.Close() + + // 读取响应 + body, err := io.ReadAll(resp.Body) + if err != nil { + http.Error(w, "{\"code\":500,\"message\":\"读取认证响应失败\",\"data\":null}", http.StatusInternalServerError) + return + } + + // 解析响应 + var authResp AuthResponse + if err := json.Unmarshal(body, &authResp); err != nil { + http.Error(w, "{\"code\":500,\"message\":\"解析认证响应失败\",\"data\":null}", http.StatusInternalServerError) + return + } + + // 检查认证是否成功 + if authResp.Code != 200 { + errorMsg := fmt.Sprintf("{\"code\":%d,\"message\":\"%s\",\"data\":null}", authResp.Code, authResp.Message) + http.Error(w, errorMsg, http.StatusUnauthorized) + return + } + + // 将创建者ID列表存储到 context 中 + ctx := context.WithValue(r.Context(), creatorIDsKey, authResp.Data) + h.ServeHTTP(w, r.WithContext(ctx)) + }) + } +} + +// CreatorIDsFrom 从 context 中获取创建者ID列表 +func CreatorIDsFrom(r *http.Request) []int { + v := r.Context().Value(creatorIDsKey) + if v == nil { + return nil + } + ids, _ := v.([]int) + return ids } diff --git a/server/internal/api/router.go b/server/internal/api/router.go index 175bc8c..fbc10cf 100644 --- a/server/internal/api/router.go +++ b/server/internal/api/router.go @@ -6,27 +6,34 @@ import ( "os" ) -func NewRouter(metaDB *sql.DB, marketingDB *sql.DB, marketingAuthDB *sql.DB, resellerDB *sql.DB, ymtDB *sql.DB, grpcAddr string) http.Handler { +func NewRouter(metaDB *sql.DB, marketingDB *sql.DB, marketingAuthDB *sql.DB, resellerDB *sql.DB, ymtDB *sql.DB, grpcAddr string, marketingAPIDomain string) http.Handler { mux := http.NewServeMux() - mux.Handle("/api/templates", withAccess(withTrace(TemplatesHandler(metaDB, marketingDB)))) - mux.Handle("/api/templates/", withAccess(withTrace(TemplatesHandler(metaDB, marketingDB)))) - mux.Handle("/api/exports", withAccess(withTrace(ExportsHandler(metaDB, marketingDB, ymtDB)))) - mux.Handle("/api/exports/", withAccess(withTrace(ExportsHandler(metaDB, marketingDB, ymtDB)))) - mux.Handle("/api/metadata/fields", withAccess(withTrace(MetadataHandler(metaDB, marketingDB, ymtDB)))) - mux.Handle("/api/fields", withAccess(withTrace(FieldsHandler(marketingDB, ymtDB)))) - mux.Handle("/api/fields/", withAccess(withTrace(FieldsHandler(marketingDB, ymtDB)))) - mux.Handle("/api/creators", withAccess(withTrace(CreatorsHandler(marketingAuthDB)))) - mux.Handle("/api/creators/", withAccess(withTrace(CreatorsHandler(marketingAuthDB)))) - mux.Handle("/api/resellers", withAccess(withTrace(ResellersHandler(resellerDB)))) - mux.Handle("/api/resellers/", withAccess(withTrace(ResellersHandler(resellerDB)))) - mux.Handle("/api/plans", withAccess(withTrace(PlansHandler(marketingDB)))) - mux.Handle("/api/plans/", withAccess(withTrace(PlansHandler(marketingDB)))) - mux.Handle("/api/ymt/users", withAccess(withTrace(YMTUsersHandler(grpcAddr)))) - mux.Handle("/api/ymt/users/", withAccess(withTrace(YMTUsersHandler(grpcAddr)))) - mux.Handle("/api/ymt/merchants", withAccess(withTrace(YMTMerchantsHandler(ymtDB)))) - mux.Handle("/api/ymt/merchants/", withAccess(withTrace(YMTMerchantsHandler(ymtDB)))) - mux.Handle("/api/ymt/activities", withAccess(withTrace(YMTActivitiesHandler(ymtDB)))) - mux.Handle("/api/ymt/activities/", withAccess(withTrace(YMTActivitiesHandler(ymtDB)))) + + // 创建认证中间件 + authMiddleware := withAuth(marketingAPIDomain) + + // 需要认证的路由(所有 API 路由) + mux.Handle("/api/templates", withAccess(withTrace(authMiddleware(TemplatesHandler(metaDB, marketingDB))))) + mux.Handle("/api/templates/", withAccess(withTrace(authMiddleware(TemplatesHandler(metaDB, marketingDB))))) + mux.Handle("/api/exports", withAccess(withTrace(authMiddleware(ExportsHandler(metaDB, marketingDB, ymtDB))))) + mux.Handle("/api/exports/", withAccess(withTrace(authMiddleware(ExportsHandler(metaDB, marketingDB, ymtDB))))) + mux.Handle("/api/metadata/fields", withAccess(withTrace(authMiddleware(MetadataHandler(metaDB, marketingDB, ymtDB))))) + mux.Handle("/api/fields", withAccess(withTrace(authMiddleware(FieldsHandler(marketingDB, ymtDB))))) + mux.Handle("/api/fields/", withAccess(withTrace(authMiddleware(FieldsHandler(marketingDB, ymtDB))))) + mux.Handle("/api/creators", withAccess(withTrace(authMiddleware(CreatorsHandler(marketingAuthDB))))) + mux.Handle("/api/creators/", withAccess(withTrace(authMiddleware(CreatorsHandler(marketingAuthDB))))) + mux.Handle("/api/resellers", withAccess(withTrace(authMiddleware(ResellersHandler(resellerDB))))) + mux.Handle("/api/resellers/", withAccess(withTrace(authMiddleware(ResellersHandler(resellerDB))))) + mux.Handle("/api/plans", withAccess(withTrace(authMiddleware(PlansHandler(marketingDB))))) + mux.Handle("/api/plans/", withAccess(withTrace(authMiddleware(PlansHandler(marketingDB))))) + mux.Handle("/api/ymt/users", withAccess(withTrace(authMiddleware(YMTUsersHandler(grpcAddr))))) + mux.Handle("/api/ymt/users/", withAccess(withTrace(authMiddleware(YMTUsersHandler(grpcAddr))))) + mux.Handle("/api/ymt/merchants", withAccess(withTrace(authMiddleware(YMTMerchantsHandler(ymtDB))))) + mux.Handle("/api/ymt/merchants/", withAccess(withTrace(authMiddleware(YMTMerchantsHandler(ymtDB))))) + mux.Handle("/api/ymt/activities", withAccess(withTrace(authMiddleware(YMTActivitiesHandler(ymtDB))))) + mux.Handle("/api/ymt/activities/", withAccess(withTrace(authMiddleware(YMTActivitiesHandler(ymtDB))))) + + // 工具类接口(不需要认证) mux.HandleFunc("/api/utils/decode_key", func(w http.ResponseWriter, r *http.Request) { v := r.URL.Query().Get("v") if v == "" { diff --git a/server/internal/config/config.go b/server/internal/config/config.go index e8844b4..9ef83ab 100644 --- a/server/internal/config/config.go +++ b/server/internal/config/config.go @@ -38,6 +38,7 @@ type App struct { YMTTestDB DB `yaml:"ymt_test_db"` YMTKeyDecryptKeyB64 string `yaml:"ymt_key_decrypt_key_b64"` GRPCServer GRPCServer `yaml:"grpc_server"` + MarketingAPIDomain string `yaml:"marketing_api_domain"` } type root struct {