MarketingSystemDataExportTool/server/internal/api/middleware.go

164 lines
4.1 KiB
Go

package api
import (
	"context"
	"crypto/rand"
	"encoding/hex"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"strings"
	"time"
)
// ctxKey is an unexported string type used for request-context keys, so the
// values this package stores cannot collide with keys defined elsewhere.
type ctxKey string

// Request-context keys used by the middleware and helpers in this file.
var (
	traceKey      ctxKey = "trace_id"
	sqlKey        ctxKey = "sql"
	metaKey       ctxKey = "req_meta"
	payloadKey    ctxKey = "payload"
	creatorIDsKey ctxKey = "creator_ids" // the user's creator-ID list, stored by withAuth
)
// withTrace is middleware that tags every request with a trace ID and a
// snapshot of basic request metadata before calling h.
//
// The trace ID is taken from the incoming X-Request-ID header. If that header
// is empty, 16 random bytes are generated and hex-encoded instead. The ID is
// echoed back in the X-Request-ID response header. It is stored in the request
// context together with a ReqMeta (method, path, raw query, remote address),
// where TraceIDFrom and MetaFrom can read them.
func withTrace(h http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		traceID := r.Header.Get("X-Request-ID")
		if len(traceID) == 0 {
			raw := make([]byte, 16)
			// Best-effort: if the read fails, the unfilled bytes stay zero and the
			// request proceeds with a less-random ID instead of failing.
			_, _ = rand.Read(raw)
			traceID = hex.EncodeToString(raw)
		}
		w.Header().Set("X-Request-ID", traceID)

		meta := ReqMeta{
			Method: r.Method,
			Path:   r.URL.Path,
			Query:  r.URL.RawQuery,
			Remote: r.RemoteAddr,
		}

		ctx := r.Context()
		ctx = context.WithValue(ctx, traceKey, traceID)
		ctx = context.WithValue(ctx, metaKey, meta)
		h.ServeHTTP(w, r.WithContext(ctx))
	})
}
// TraceIDFrom returns the trace ID that withTrace stored in r's context.
// It returns "" when no ID is present or the stored value is not a string.
func TraceIDFrom(r *http.Request) string {
	if id, ok := r.Context().Value(traceKey).(string); ok {
		return id
	}
	return ""
}
// WithSQL returns a shallow copy of r whose context carries sql, which
// SQLFrom can later read back.
func WithSQL(r *http.Request, sql string) *http.Request {
	ctx := context.WithValue(r.Context(), sqlKey, sql)
	return r.WithContext(ctx)
}
// SQLFrom returns the SQL text attached to r by WithSQL.
// It returns "" when nothing was attached or the stored value is not a string.
func SQLFrom(r *http.Request) string {
	if sql, ok := r.Context().Value(sqlKey).(string); ok {
		return sql
	}
	return ""
}
// ReqMeta is a snapshot of basic request attributes captured by withTrace.
type ReqMeta struct {
	Method string // HTTP method, e.g. "GET"
	Path   string // URL path, without the query string
	Query  string // raw (still-encoded) query string, without the leading '?'
	Remote string // client network address as reported by http.Request.RemoteAddr
}
// MetaFrom returns the ReqMeta that withTrace stored in r's context.
// It returns the zero ReqMeta when none is present or the value has another type.
func MetaFrom(r *http.Request) ReqMeta {
	meta, _ := r.Context().Value(metaKey).(ReqMeta)
	return meta
}
// WithPayload returns a shallow copy of r whose context carries v serialized
// as a JSON string, which PayloadFrom can later read back.
//
// Serialization is best-effort: if v cannot be marshaled, an empty string is
// stored instead of failing the request.
func WithPayload(r *http.Request, v interface{}) *http.Request {
	var encoded string
	if raw, err := json.Marshal(v); err == nil {
		encoded = string(raw)
	}
	return r.WithContext(context.WithValue(r.Context(), payloadKey, encoded))
}
// PayloadFrom returns the JSON payload string attached to r by WithPayload.
// It returns "" when nothing was attached or the stored value is not a string.
func PayloadFrom(r *http.Request) string {
	if payload, ok := r.Context().Value(payloadKey).(string); ok {
		return payload
	}
	return ""
}
// AuthResponse is the JSON envelope returned by the marketing system's
// authentication (data-permission) endpoint that withAuth calls.
type AuthResponse struct {
	Code    int    `json:"code"`    // business status code; withAuth treats 200 as success
	Message string `json:"message"` // status message, relayed to the client on failure
	Data    []int  `json:"data"`    // creator-ID list defining the user's data permission
}
// dataPermissionClient is the HTTP client withAuth uses to reach the
// marketing system. One shared client lets connections be pooled across
// requests. The timeout bounds how long an unresponsive auth service can hold
// up a request; a zero-value http.Client never times out.
var dataPermissionClient = &http.Client{Timeout: 10 * time.Second}

// maxDataPermissionBody caps how many bytes of the auth service's reply are
// buffered. A longer reply is truncated and then fails JSON decoding.
const maxDataPermissionBody = 10 << 20 // 10 MiB

// withAuth returns authentication middleware. It validates the caller's token
// against the marketing system and records the user's data permission.
//
// The token is read from the "token" request header. It is forwarded as
// "Authorization: Bearer <token>" in a GET to
// <apiDomain>/auth/admin/dataPermission. If the returned AuthResponse has
// code 200, its Data (the creator-ID list) is stored in the request context,
// where CreatorIDsFrom can read it, and h is invoked. Otherwise a JSON body
// {"code":...,"message":...,"data":null} is written and h is not called:
//   - 401 when the token header is missing, or when the auth service answers
//     with a non-200 code (its code and message are passed through);
//   - 500 when the auth request cannot be built or sent, or its reply cannot
//     be read or decoded.
func withAuth(apiDomain string) func(http.Handler) http.Handler {
	// Built once per middleware instance. Trailing slashes are trimmed so a
	// configured domain like "https://host/" does not yield "https://host//auth/...".
	authURL := fmt.Sprintf("%s/auth/admin/dataPermission", strings.TrimRight(apiDomain, "/"))

	return func(h http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			token := r.Header.Get("token")
			if token == "" {
				writeAuthJSONError(w, http.StatusUnauthorized, 401, "未提供认证token")
				return
			}

			// Tie the upstream call to the incoming request so it is abandoned
			// if the client disconnects.
			req, err := http.NewRequestWithContext(r.Context(), http.MethodGet, authURL, nil)
			if err != nil {
				writeAuthJSONError(w, http.StatusInternalServerError, 500, "创建认证请求失败")
				return
			}
			req.Header.Set("Authorization", "Bearer "+token)

			resp, err := dataPermissionClient.Do(req)
			if err != nil {
				writeAuthJSONError(w, http.StatusInternalServerError, 500, "认证请求失败")
				return
			}
			body, err := io.ReadAll(io.LimitReader(resp.Body, maxDataPermissionBody))
			// Close right away instead of deferring. A deferred Close would keep
			// the upstream response open until the downstream handler h returns.
			// The close error is irrelevant once the body has been consumed.
			_ = resp.Body.Close()
			if err != nil {
				writeAuthJSONError(w, http.StatusInternalServerError, 500, "读取认证响应失败")
				return
			}

			var authResp AuthResponse
			if err := json.Unmarshal(body, &authResp); err != nil {
				writeAuthJSONError(w, http.StatusInternalServerError, 500, "解析认证响应失败")
				return
			}

			// Success is decided by the envelope's code, not the HTTP status.
			if authResp.Code != 200 {
				writeAuthJSONError(w, http.StatusUnauthorized, authResp.Code, authResp.Message)
				return
			}

			ctx := context.WithValue(r.Context(), creatorIDsKey, authResp.Data)
			h.ServeHTTP(w, r.WithContext(ctx))
		})
	}
}

// writeAuthJSONError writes {"code":code,"message":message,"data":null} with
// the given HTTP status. The body is produced by encoding/json. A message
// containing quotes, backslashes or control characters (for example one
// relayed from the auth service) therefore still yields valid JSON.
func writeAuthJSONError(w http.ResponseWriter, status, code int, message string) {
	payload := struct {
		Code    int         `json:"code"`
		Message string      `json:"message"`
		Data    interface{} `json:"data"`
	}{Code: code, Message: message}
	// Marshaling an int, a string and a nil interface cannot fail.
	body, _ := json.Marshal(payload)
	w.Header().Set("Content-Type", "application/json; charset=utf-8")
	w.Header().Set("X-Content-Type-Options", "nosniff")
	w.WriteHeader(status)
	_, _ = w.Write(body)
}
// CreatorIDsFrom returns the creator-ID list that withAuth stored in r's
// context. It returns nil when none is present or the value has another type.
func CreatorIDsFrom(r *http.Request) []int {
	ids, ok := r.Context().Value(creatorIDsKey).([]int)
	if !ok {
		return nil
	}
	return ids
}