feat(api): 引入营销系统鉴权中间件,增强接口安全控制

- 新增配置项 MarketingAPIDomain,用于设置营销系统API域名
- 服务器启动时打印营销API域名警告或信息
- api/router.go中添加认证中间件,所有API路由均需认证访问
- api/middleware.go新增withAuth认证中间件,实现Token验证和数据权限接口调用
- 认证中间件通过HTTP请求调用营销系统鉴权接口获取创建者ID列表
- 请求头Access-Control-Allow-Headers新增token,支持跨域传递认证token
- 提供从请求上下文获取创建者ID列表的辅助函数CreatorIDsFrom
- 适配NewRouter函数,新增marketingAPIDomain参数用于中间件配置
This commit is contained in:
zhouyonggao 2025-12-23 15:38:48 +08:00
parent 4742b3db05
commit ade149c67c
5 changed files with 159 additions and 65 deletions

View File

@ -84,7 +84,14 @@ func main() {
} else {
log.Println("warning: gRPC server address not configured, /api/ymt/users will not work")
}
r := api.NewRouter(meta, marketing, marketingAuth, resellerDB, ymt, grpcAddr)
// 获取营销系统API域名
marketingAPIDomain := cfg.MarketingAPIDomain
if marketingAPIDomain == "" {
log.Println("warning: marketing API domain not configured, authentication will not work")
} else {
log.Println("marketing API domain:", marketingAPIDomain)
}
r := api.NewRouter(meta, marketing, marketingAuth, resellerDB, ymt, grpcAddr, marketingAPIDomain)
addr := ":" + func() string {
s := cfg.Port
if s == "" {

View File

@ -26,7 +26,7 @@ func withAccess(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Access-Control-Allow-Origin", "*")
w.Header().Set("Access-Control-Allow-Methods", "GET,POST,PATCH,DELETE,OPTIONS")
w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization")
w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization, token")
if r.Method == http.MethodOptions {
w.WriteHeader(http.StatusOK)
return

View File

@ -1,11 +1,13 @@
package api
import (
"context"
"crypto/rand"
"encoding/hex"
"encoding/json"
"net/http"
"context"
"crypto/rand"
"encoding/hex"
"encoding/json"
"fmt"
"io"
"net/http"
)
type ctxKey string
@ -14,21 +16,22 @@ var traceKey ctxKey = "trace_id"
var sqlKey ctxKey = "sql"
var metaKey ctxKey = "req_meta"
var payloadKey ctxKey = "payload"
var creatorIDsKey ctxKey = "creator_ids" // 存储用户的创建者ID列表
func withTrace(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
tid := r.Header.Get("X-Request-ID")
if tid == "" {
buf := make([]byte, 16)
_, _ = rand.Read(buf)
tid = hex.EncodeToString(buf)
}
w.Header().Set("X-Request-ID", tid)
m := ReqMeta{Method: r.Method, Path: r.URL.Path, Query: r.URL.RawQuery, Remote: r.RemoteAddr}
ctx := context.WithValue(r.Context(), traceKey, tid)
ctx = context.WithValue(ctx, metaKey, m)
h.ServeHTTP(w, r.WithContext(ctx))
})
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
tid := r.Header.Get("X-Request-ID")
if tid == "" {
buf := make([]byte, 16)
_, _ = rand.Read(buf)
tid = hex.EncodeToString(buf)
}
w.Header().Set("X-Request-ID", tid)
m := ReqMeta{Method: r.Method, Path: r.URL.Path, Query: r.URL.RawQuery, Remote: r.RemoteAddr}
ctx := context.WithValue(r.Context(), traceKey, tid)
ctx = context.WithValue(ctx, metaKey, m)
h.ServeHTTP(w, r.WithContext(ctx))
})
}
func TraceIDFrom(r *http.Request) string {
@ -41,44 +44,120 @@ func TraceIDFrom(r *http.Request) string {
}
func WithSQL(r *http.Request, sql string) *http.Request {
return r.WithContext(context.WithValue(r.Context(), sqlKey, sql))
return r.WithContext(context.WithValue(r.Context(), sqlKey, sql))
}
func SQLFrom(r *http.Request) string {
v := r.Context().Value(sqlKey)
if v == nil {
return ""
}
s, _ := v.(string)
return s
v := r.Context().Value(sqlKey)
if v == nil {
return ""
}
s, _ := v.(string)
return s
}
type ReqMeta struct {
Method string
Path string
Query string
Remote string
Method string
Path string
Query string
Remote string
}
func MetaFrom(r *http.Request) ReqMeta {
v := r.Context().Value(metaKey)
if v == nil {
return ReqMeta{}
}
m, _ := v.(ReqMeta)
return m
v := r.Context().Value(metaKey)
if v == nil {
return ReqMeta{}
}
m, _ := v.(ReqMeta)
return m
}
func WithPayload(r *http.Request, v interface{}) *http.Request {
b, _ := json.Marshal(v)
return r.WithContext(context.WithValue(r.Context(), payloadKey, string(b)))
b, _ := json.Marshal(v)
return r.WithContext(context.WithValue(r.Context(), payloadKey, string(b)))
}
func PayloadFrom(r *http.Request) string {
v := r.Context().Value(payloadKey)
if v == nil {
return ""
}
s, _ := v.(string)
return s
v := r.Context().Value(payloadKey)
if v == nil {
return ""
}
s, _ := v.(string)
return s
}
// AuthResponse 营销系统鉴权接口返回结构
type AuthResponse struct {
Code int `json:"code"`
Message string `json:"message"`
Data []int `json:"data"` // 创建者ID列表
}
// withAuth 认证中间件:验证 token 并获取用户数据权限
func withAuth(apiDomain string) func(http.Handler) http.Handler {
return func(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// 获取 token
token := r.Header.Get("token")
if token == "" {
http.Error(w, "{\"code\":401,\"message\":\"未提供认证token\",\"data\":null}", http.StatusUnauthorized)
return
}
// 请求营销系统鉴权接口
authURL := fmt.Sprintf("%s/auth/admin/dataPermission", apiDomain)
req, err := http.NewRequest("GET", authURL, nil)
if err != nil {
http.Error(w, "{\"code\":500,\"message\":\"创建认证请求失败\",\"data\":null}", http.StatusInternalServerError)
return
}
// 设置请求头(根据图片示例)
req.Header.Set("Authorization", "Bearer "+token)
// 发起请求
client := &http.Client{}
resp, err := client.Do(req)
if err != nil {
http.Error(w, "{\"code\":500,\"message\":\"认证请求失败\",\"data\":null}", http.StatusInternalServerError)
return
}
defer resp.Body.Close()
// 读取响应
body, err := io.ReadAll(resp.Body)
if err != nil {
http.Error(w, "{\"code\":500,\"message\":\"读取认证响应失败\",\"data\":null}", http.StatusInternalServerError)
return
}
// 解析响应
var authResp AuthResponse
if err := json.Unmarshal(body, &authResp); err != nil {
http.Error(w, "{\"code\":500,\"message\":\"解析认证响应失败\",\"data\":null}", http.StatusInternalServerError)
return
}
// 检查认证是否成功
if authResp.Code != 200 {
errorMsg := fmt.Sprintf("{\"code\":%d,\"message\":\"%s\",\"data\":null}", authResp.Code, authResp.Message)
http.Error(w, errorMsg, http.StatusUnauthorized)
return
}
// 将创建者ID列表存储到 context 中
ctx := context.WithValue(r.Context(), creatorIDsKey, authResp.Data)
h.ServeHTTP(w, r.WithContext(ctx))
})
}
}
// CreatorIDsFrom 从 context 中获取创建者ID列表
func CreatorIDsFrom(r *http.Request) []int {
v := r.Context().Value(creatorIDsKey)
if v == nil {
return nil
}
ids, _ := v.([]int)
return ids
}

View File

@ -6,27 +6,34 @@ import (
"os"
)
func NewRouter(metaDB *sql.DB, marketingDB *sql.DB, marketingAuthDB *sql.DB, resellerDB *sql.DB, ymtDB *sql.DB, grpcAddr string) http.Handler {
func NewRouter(metaDB *sql.DB, marketingDB *sql.DB, marketingAuthDB *sql.DB, resellerDB *sql.DB, ymtDB *sql.DB, grpcAddr string, marketingAPIDomain string) http.Handler {
mux := http.NewServeMux()
mux.Handle("/api/templates", withAccess(withTrace(TemplatesHandler(metaDB, marketingDB))))
mux.Handle("/api/templates/", withAccess(withTrace(TemplatesHandler(metaDB, marketingDB))))
mux.Handle("/api/exports", withAccess(withTrace(ExportsHandler(metaDB, marketingDB, ymtDB))))
mux.Handle("/api/exports/", withAccess(withTrace(ExportsHandler(metaDB, marketingDB, ymtDB))))
mux.Handle("/api/metadata/fields", withAccess(withTrace(MetadataHandler(metaDB, marketingDB, ymtDB))))
mux.Handle("/api/fields", withAccess(withTrace(FieldsHandler(marketingDB, ymtDB))))
mux.Handle("/api/fields/", withAccess(withTrace(FieldsHandler(marketingDB, ymtDB))))
mux.Handle("/api/creators", withAccess(withTrace(CreatorsHandler(marketingAuthDB))))
mux.Handle("/api/creators/", withAccess(withTrace(CreatorsHandler(marketingAuthDB))))
mux.Handle("/api/resellers", withAccess(withTrace(ResellersHandler(resellerDB))))
mux.Handle("/api/resellers/", withAccess(withTrace(ResellersHandler(resellerDB))))
mux.Handle("/api/plans", withAccess(withTrace(PlansHandler(marketingDB))))
mux.Handle("/api/plans/", withAccess(withTrace(PlansHandler(marketingDB))))
mux.Handle("/api/ymt/users", withAccess(withTrace(YMTUsersHandler(grpcAddr))))
mux.Handle("/api/ymt/users/", withAccess(withTrace(YMTUsersHandler(grpcAddr))))
mux.Handle("/api/ymt/merchants", withAccess(withTrace(YMTMerchantsHandler(ymtDB))))
mux.Handle("/api/ymt/merchants/", withAccess(withTrace(YMTMerchantsHandler(ymtDB))))
mux.Handle("/api/ymt/activities", withAccess(withTrace(YMTActivitiesHandler(ymtDB))))
mux.Handle("/api/ymt/activities/", withAccess(withTrace(YMTActivitiesHandler(ymtDB))))
// 创建认证中间件
authMiddleware := withAuth(marketingAPIDomain)
// 需要认证的路由(所有 API 路由)
mux.Handle("/api/templates", withAccess(withTrace(authMiddleware(TemplatesHandler(metaDB, marketingDB)))))
mux.Handle("/api/templates/", withAccess(withTrace(authMiddleware(TemplatesHandler(metaDB, marketingDB)))))
mux.Handle("/api/exports", withAccess(withTrace(authMiddleware(ExportsHandler(metaDB, marketingDB, ymtDB)))))
mux.Handle("/api/exports/", withAccess(withTrace(authMiddleware(ExportsHandler(metaDB, marketingDB, ymtDB)))))
mux.Handle("/api/metadata/fields", withAccess(withTrace(authMiddleware(MetadataHandler(metaDB, marketingDB, ymtDB)))))
mux.Handle("/api/fields", withAccess(withTrace(authMiddleware(FieldsHandler(marketingDB, ymtDB)))))
mux.Handle("/api/fields/", withAccess(withTrace(authMiddleware(FieldsHandler(marketingDB, ymtDB)))))
mux.Handle("/api/creators", withAccess(withTrace(authMiddleware(CreatorsHandler(marketingAuthDB)))))
mux.Handle("/api/creators/", withAccess(withTrace(authMiddleware(CreatorsHandler(marketingAuthDB)))))
mux.Handle("/api/resellers", withAccess(withTrace(authMiddleware(ResellersHandler(resellerDB)))))
mux.Handle("/api/resellers/", withAccess(withTrace(authMiddleware(ResellersHandler(resellerDB)))))
mux.Handle("/api/plans", withAccess(withTrace(authMiddleware(PlansHandler(marketingDB)))))
mux.Handle("/api/plans/", withAccess(withTrace(authMiddleware(PlansHandler(marketingDB)))))
mux.Handle("/api/ymt/users", withAccess(withTrace(authMiddleware(YMTUsersHandler(grpcAddr)))))
mux.Handle("/api/ymt/users/", withAccess(withTrace(authMiddleware(YMTUsersHandler(grpcAddr)))))
mux.Handle("/api/ymt/merchants", withAccess(withTrace(authMiddleware(YMTMerchantsHandler(ymtDB)))))
mux.Handle("/api/ymt/merchants/", withAccess(withTrace(authMiddleware(YMTMerchantsHandler(ymtDB)))))
mux.Handle("/api/ymt/activities", withAccess(withTrace(authMiddleware(YMTActivitiesHandler(ymtDB)))))
mux.Handle("/api/ymt/activities/", withAccess(withTrace(authMiddleware(YMTActivitiesHandler(ymtDB)))))
// 工具类接口(不需要认证)
mux.HandleFunc("/api/utils/decode_key", func(w http.ResponseWriter, r *http.Request) {
v := r.URL.Query().Get("v")
if v == "" {

View File

@ -38,6 +38,7 @@ type App struct {
YMTTestDB DB `yaml:"ymt_test_db"`
YMTKeyDecryptKeyB64 string `yaml:"ymt_key_decrypt_key_b64"`
GRPCServer GRPCServer `yaml:"grpc_server"`
MarketingAPIDomain string `yaml:"marketing_api_domain"`
}
type root struct {