feat(api): 引入营销系统鉴权中间件,增强接口安全控制
- 新增配置项 MarketingAPIDomain,用于设置营销系统API域名 - 服务器启动时打印营销API域名警告或信息 - api/router.go中添加认证中间件,所有API路由均需认证访问 - api/middleware.go新增withAuth认证中间件,实现Token验证和数据权限接口调用 - 认证中间件通过HTTP请求调用营销系统鉴权接口获取创建者ID列表 - 请求头Access-Control-Allow-Headers新增token,支持跨域传递认证token - 提供从请求上下文获取创建者ID列表的辅助函数CreatorIDsFrom - 适配NewRouter函数,新增marketingAPIDomain参数用于中间件配置
This commit is contained in:
parent
4742b3db05
commit
ade149c67c
|
|
@ -84,7 +84,14 @@ func main() {
|
|||
} else {
|
||||
log.Println("warning: gRPC server address not configured, /api/ymt/users will not work")
|
||||
}
|
||||
r := api.NewRouter(meta, marketing, marketingAuth, resellerDB, ymt, grpcAddr)
|
||||
// 获取营销系统API域名
|
||||
marketingAPIDomain := cfg.MarketingAPIDomain
|
||||
if marketingAPIDomain == "" {
|
||||
log.Println("warning: marketing API domain not configured, authentication will not work")
|
||||
} else {
|
||||
log.Println("marketing API domain:", marketingAPIDomain)
|
||||
}
|
||||
r := api.NewRouter(meta, marketing, marketingAuth, resellerDB, ymt, grpcAddr, marketingAPIDomain)
|
||||
addr := ":" + func() string {
|
||||
s := cfg.Port
|
||||
if s == "" {
|
||||
|
|
|
|||
|
|
@ -26,7 +26,7 @@ func withAccess(h http.Handler) http.Handler {
|
|||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Access-Control-Allow-Origin", "*")
|
||||
w.Header().Set("Access-Control-Allow-Methods", "GET,POST,PATCH,DELETE,OPTIONS")
|
||||
w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization")
|
||||
w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization, token")
|
||||
if r.Method == http.MethodOptions {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
return
|
||||
|
|
|
|||
|
|
@ -1,11 +1,13 @@
|
|||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
type ctxKey string
|
||||
|
|
@ -14,21 +16,22 @@ var traceKey ctxKey = "trace_id"
|
|||
var sqlKey ctxKey = "sql"
|
||||
var metaKey ctxKey = "req_meta"
|
||||
var payloadKey ctxKey = "payload"
|
||||
var creatorIDsKey ctxKey = "creator_ids" // 存储用户的创建者ID列表
|
||||
|
||||
func withTrace(h http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
tid := r.Header.Get("X-Request-ID")
|
||||
if tid == "" {
|
||||
buf := make([]byte, 16)
|
||||
_, _ = rand.Read(buf)
|
||||
tid = hex.EncodeToString(buf)
|
||||
}
|
||||
w.Header().Set("X-Request-ID", tid)
|
||||
m := ReqMeta{Method: r.Method, Path: r.URL.Path, Query: r.URL.RawQuery, Remote: r.RemoteAddr}
|
||||
ctx := context.WithValue(r.Context(), traceKey, tid)
|
||||
ctx = context.WithValue(ctx, metaKey, m)
|
||||
h.ServeHTTP(w, r.WithContext(ctx))
|
||||
})
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
tid := r.Header.Get("X-Request-ID")
|
||||
if tid == "" {
|
||||
buf := make([]byte, 16)
|
||||
_, _ = rand.Read(buf)
|
||||
tid = hex.EncodeToString(buf)
|
||||
}
|
||||
w.Header().Set("X-Request-ID", tid)
|
||||
m := ReqMeta{Method: r.Method, Path: r.URL.Path, Query: r.URL.RawQuery, Remote: r.RemoteAddr}
|
||||
ctx := context.WithValue(r.Context(), traceKey, tid)
|
||||
ctx = context.WithValue(ctx, metaKey, m)
|
||||
h.ServeHTTP(w, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
|
||||
func TraceIDFrom(r *http.Request) string {
|
||||
|
|
@ -41,44 +44,120 @@ func TraceIDFrom(r *http.Request) string {
|
|||
}
|
||||
|
||||
func WithSQL(r *http.Request, sql string) *http.Request {
|
||||
return r.WithContext(context.WithValue(r.Context(), sqlKey, sql))
|
||||
return r.WithContext(context.WithValue(r.Context(), sqlKey, sql))
|
||||
}
|
||||
|
||||
func SQLFrom(r *http.Request) string {
|
||||
v := r.Context().Value(sqlKey)
|
||||
if v == nil {
|
||||
return ""
|
||||
}
|
||||
s, _ := v.(string)
|
||||
return s
|
||||
v := r.Context().Value(sqlKey)
|
||||
if v == nil {
|
||||
return ""
|
||||
}
|
||||
s, _ := v.(string)
|
||||
return s
|
||||
}
|
||||
|
||||
type ReqMeta struct {
|
||||
Method string
|
||||
Path string
|
||||
Query string
|
||||
Remote string
|
||||
Method string
|
||||
Path string
|
||||
Query string
|
||||
Remote string
|
||||
}
|
||||
|
||||
func MetaFrom(r *http.Request) ReqMeta {
|
||||
v := r.Context().Value(metaKey)
|
||||
if v == nil {
|
||||
return ReqMeta{}
|
||||
}
|
||||
m, _ := v.(ReqMeta)
|
||||
return m
|
||||
v := r.Context().Value(metaKey)
|
||||
if v == nil {
|
||||
return ReqMeta{}
|
||||
}
|
||||
m, _ := v.(ReqMeta)
|
||||
return m
|
||||
}
|
||||
|
||||
func WithPayload(r *http.Request, v interface{}) *http.Request {
|
||||
b, _ := json.Marshal(v)
|
||||
return r.WithContext(context.WithValue(r.Context(), payloadKey, string(b)))
|
||||
b, _ := json.Marshal(v)
|
||||
return r.WithContext(context.WithValue(r.Context(), payloadKey, string(b)))
|
||||
}
|
||||
|
||||
func PayloadFrom(r *http.Request) string {
|
||||
v := r.Context().Value(payloadKey)
|
||||
if v == nil {
|
||||
return ""
|
||||
}
|
||||
s, _ := v.(string)
|
||||
return s
|
||||
v := r.Context().Value(payloadKey)
|
||||
if v == nil {
|
||||
return ""
|
||||
}
|
||||
s, _ := v.(string)
|
||||
return s
|
||||
}
|
||||
|
||||
// AuthResponse 营销系统鉴权接口返回结构
|
||||
type AuthResponse struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Data []int `json:"data"` // 创建者ID列表
|
||||
}
|
||||
|
||||
// withAuth 认证中间件:验证 token 并获取用户数据权限
|
||||
func withAuth(apiDomain string) func(http.Handler) http.Handler {
|
||||
return func(h http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// 获取 token
|
||||
token := r.Header.Get("token")
|
||||
if token == "" {
|
||||
http.Error(w, "{\"code\":401,\"message\":\"未提供认证token\",\"data\":null}", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
// 请求营销系统鉴权接口
|
||||
authURL := fmt.Sprintf("%s/auth/admin/dataPermission", apiDomain)
|
||||
req, err := http.NewRequest("GET", authURL, nil)
|
||||
if err != nil {
|
||||
http.Error(w, "{\"code\":500,\"message\":\"创建认证请求失败\",\"data\":null}", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// 设置请求头(根据图片示例)
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
|
||||
// 发起请求
|
||||
client := &http.Client{}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
http.Error(w, "{\"code\":500,\"message\":\"认证请求失败\",\"data\":null}", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// 读取响应
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
http.Error(w, "{\"code\":500,\"message\":\"读取认证响应失败\",\"data\":null}", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// 解析响应
|
||||
var authResp AuthResponse
|
||||
if err := json.Unmarshal(body, &authResp); err != nil {
|
||||
http.Error(w, "{\"code\":500,\"message\":\"解析认证响应失败\",\"data\":null}", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// 检查认证是否成功
|
||||
if authResp.Code != 200 {
|
||||
errorMsg := fmt.Sprintf("{\"code\":%d,\"message\":\"%s\",\"data\":null}", authResp.Code, authResp.Message)
|
||||
http.Error(w, errorMsg, http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
// 将创建者ID列表存储到 context 中
|
||||
ctx := context.WithValue(r.Context(), creatorIDsKey, authResp.Data)
|
||||
h.ServeHTTP(w, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// CreatorIDsFrom 从 context 中获取创建者ID列表
|
||||
func CreatorIDsFrom(r *http.Request) []int {
|
||||
v := r.Context().Value(creatorIDsKey)
|
||||
if v == nil {
|
||||
return nil
|
||||
}
|
||||
ids, _ := v.([]int)
|
||||
return ids
|
||||
}
|
||||
|
|
|
|||
|
|
@ -6,27 +6,34 @@ import (
|
|||
"os"
|
||||
)
|
||||
|
||||
func NewRouter(metaDB *sql.DB, marketingDB *sql.DB, marketingAuthDB *sql.DB, resellerDB *sql.DB, ymtDB *sql.DB, grpcAddr string) http.Handler {
|
||||
func NewRouter(metaDB *sql.DB, marketingDB *sql.DB, marketingAuthDB *sql.DB, resellerDB *sql.DB, ymtDB *sql.DB, grpcAddr string, marketingAPIDomain string) http.Handler {
|
||||
mux := http.NewServeMux()
|
||||
mux.Handle("/api/templates", withAccess(withTrace(TemplatesHandler(metaDB, marketingDB))))
|
||||
mux.Handle("/api/templates/", withAccess(withTrace(TemplatesHandler(metaDB, marketingDB))))
|
||||
mux.Handle("/api/exports", withAccess(withTrace(ExportsHandler(metaDB, marketingDB, ymtDB))))
|
||||
mux.Handle("/api/exports/", withAccess(withTrace(ExportsHandler(metaDB, marketingDB, ymtDB))))
|
||||
mux.Handle("/api/metadata/fields", withAccess(withTrace(MetadataHandler(metaDB, marketingDB, ymtDB))))
|
||||
mux.Handle("/api/fields", withAccess(withTrace(FieldsHandler(marketingDB, ymtDB))))
|
||||
mux.Handle("/api/fields/", withAccess(withTrace(FieldsHandler(marketingDB, ymtDB))))
|
||||
mux.Handle("/api/creators", withAccess(withTrace(CreatorsHandler(marketingAuthDB))))
|
||||
mux.Handle("/api/creators/", withAccess(withTrace(CreatorsHandler(marketingAuthDB))))
|
||||
mux.Handle("/api/resellers", withAccess(withTrace(ResellersHandler(resellerDB))))
|
||||
mux.Handle("/api/resellers/", withAccess(withTrace(ResellersHandler(resellerDB))))
|
||||
mux.Handle("/api/plans", withAccess(withTrace(PlansHandler(marketingDB))))
|
||||
mux.Handle("/api/plans/", withAccess(withTrace(PlansHandler(marketingDB))))
|
||||
mux.Handle("/api/ymt/users", withAccess(withTrace(YMTUsersHandler(grpcAddr))))
|
||||
mux.Handle("/api/ymt/users/", withAccess(withTrace(YMTUsersHandler(grpcAddr))))
|
||||
mux.Handle("/api/ymt/merchants", withAccess(withTrace(YMTMerchantsHandler(ymtDB))))
|
||||
mux.Handle("/api/ymt/merchants/", withAccess(withTrace(YMTMerchantsHandler(ymtDB))))
|
||||
mux.Handle("/api/ymt/activities", withAccess(withTrace(YMTActivitiesHandler(ymtDB))))
|
||||
mux.Handle("/api/ymt/activities/", withAccess(withTrace(YMTActivitiesHandler(ymtDB))))
|
||||
|
||||
// 创建认证中间件
|
||||
authMiddleware := withAuth(marketingAPIDomain)
|
||||
|
||||
// 需要认证的路由(所有 API 路由)
|
||||
mux.Handle("/api/templates", withAccess(withTrace(authMiddleware(TemplatesHandler(metaDB, marketingDB)))))
|
||||
mux.Handle("/api/templates/", withAccess(withTrace(authMiddleware(TemplatesHandler(metaDB, marketingDB)))))
|
||||
mux.Handle("/api/exports", withAccess(withTrace(authMiddleware(ExportsHandler(metaDB, marketingDB, ymtDB)))))
|
||||
mux.Handle("/api/exports/", withAccess(withTrace(authMiddleware(ExportsHandler(metaDB, marketingDB, ymtDB)))))
|
||||
mux.Handle("/api/metadata/fields", withAccess(withTrace(authMiddleware(MetadataHandler(metaDB, marketingDB, ymtDB)))))
|
||||
mux.Handle("/api/fields", withAccess(withTrace(authMiddleware(FieldsHandler(marketingDB, ymtDB)))))
|
||||
mux.Handle("/api/fields/", withAccess(withTrace(authMiddleware(FieldsHandler(marketingDB, ymtDB)))))
|
||||
mux.Handle("/api/creators", withAccess(withTrace(authMiddleware(CreatorsHandler(marketingAuthDB)))))
|
||||
mux.Handle("/api/creators/", withAccess(withTrace(authMiddleware(CreatorsHandler(marketingAuthDB)))))
|
||||
mux.Handle("/api/resellers", withAccess(withTrace(authMiddleware(ResellersHandler(resellerDB)))))
|
||||
mux.Handle("/api/resellers/", withAccess(withTrace(authMiddleware(ResellersHandler(resellerDB)))))
|
||||
mux.Handle("/api/plans", withAccess(withTrace(authMiddleware(PlansHandler(marketingDB)))))
|
||||
mux.Handle("/api/plans/", withAccess(withTrace(authMiddleware(PlansHandler(marketingDB)))))
|
||||
mux.Handle("/api/ymt/users", withAccess(withTrace(authMiddleware(YMTUsersHandler(grpcAddr)))))
|
||||
mux.Handle("/api/ymt/users/", withAccess(withTrace(authMiddleware(YMTUsersHandler(grpcAddr)))))
|
||||
mux.Handle("/api/ymt/merchants", withAccess(withTrace(authMiddleware(YMTMerchantsHandler(ymtDB)))))
|
||||
mux.Handle("/api/ymt/merchants/", withAccess(withTrace(authMiddleware(YMTMerchantsHandler(ymtDB)))))
|
||||
mux.Handle("/api/ymt/activities", withAccess(withTrace(authMiddleware(YMTActivitiesHandler(ymtDB)))))
|
||||
mux.Handle("/api/ymt/activities/", withAccess(withTrace(authMiddleware(YMTActivitiesHandler(ymtDB)))))
|
||||
|
||||
// 工具类接口(不需要认证)
|
||||
mux.HandleFunc("/api/utils/decode_key", func(w http.ResponseWriter, r *http.Request) {
|
||||
v := r.URL.Query().Get("v")
|
||||
if v == "" {
|
||||
|
|
|
|||
|
|
@ -38,6 +38,7 @@ type App struct {
|
|||
YMTTestDB DB `yaml:"ymt_test_db"`
|
||||
YMTKeyDecryptKeyB64 string `yaml:"ymt_key_decrypt_key_b64"`
|
||||
GRPCServer GRPCServer `yaml:"grpc_server"`
|
||||
MarketingAPIDomain string `yaml:"marketing_api_domain"`
|
||||
}
|
||||
|
||||
type root struct {
|
||||
|
|
|
|||
Loading…
Reference in New Issue