MarketingSystemDataExportTool/server/internal/api/middleware.go

210 lines
5.8 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package api
import (
	"context"
	"crypto/rand"
	"encoding/hex"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"time"
)
// ctxKey is the key type for values this file stores in a request context.
// Using a dedicated named type keeps these keys from colliding with context
// keys of other types.
type ctxKey string

// Context keys used by the middleware and accessor helpers in this file.
var (
	traceKey      ctxKey = "trace_id"
	sqlKey        ctxKey = "sql"
	metaKey       ctxKey = "req_meta"
	payloadKey    ctxKey = "payload"
	creatorIDsKey ctxKey = "creator_ids" // the user's list of creator IDs
	isAdminKey    ctxKey = "is_admin"    // whether the user is a super administrator
)
// withTrace wraps h so that every request carries a trace ID and a ReqMeta
// snapshot in its context.
//
// The trace ID is taken from the incoming X-Request-ID header; when that
// header is empty, 16 random bytes are generated and hex-encoded instead.
// The chosen ID is echoed back in the X-Request-ID response header.
func withTrace(h http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		traceID := r.Header.Get("X-Request-ID")
		if traceID == "" {
			var raw [16]byte
			// The error from rand.Read is discarded, as in the original code.
			_, _ = rand.Read(raw[:])
			traceID = hex.EncodeToString(raw[:])
		}
		w.Header().Set("X-Request-ID", traceID)

		meta := ReqMeta{
			Method: r.Method,
			Path:   r.URL.Path,
			Query:  r.URL.RawQuery,
			Remote: r.RemoteAddr,
		}
		ctx := context.WithValue(
			context.WithValue(r.Context(), traceKey, traceID),
			metaKey, meta,
		)
		h.ServeHTTP(w, r.WithContext(ctx))
	})
}
// TraceIDFrom returns the trace ID stored in r's context under traceKey.
// It returns "" when no value is stored or the stored value is not a string.
func TraceIDFrom(r *http.Request) string {
	if id, ok := r.Context().Value(traceKey).(string); ok {
		return id
	}
	return ""
}
// WithSQL returns a copy of r whose context additionally carries sql under
// sqlKey. The original request is not modified.
func WithSQL(r *http.Request, sql string) *http.Request {
	ctx := context.WithValue(r.Context(), sqlKey, sql)
	return r.WithContext(ctx)
}
// SQLFrom returns the SQL text stored in r's context under sqlKey.
// It returns "" when no value is stored or the stored value is not a string.
func SQLFrom(r *http.Request) string {
	stmt, ok := r.Context().Value(sqlKey).(string)
	if !ok {
		return ""
	}
	return stmt
}
// ReqMeta is a small snapshot of an incoming HTTP request.
type ReqMeta struct {
	Method string // HTTP method
	Path   string // URL path
	Query  string // raw (still encoded) query string
	Remote string // remote address of the request
}
// MetaFrom returns the ReqMeta stored in r's context under metaKey.
// It returns the zero ReqMeta when no value is stored or the stored value
// has a different type.
func MetaFrom(r *http.Request) ReqMeta {
	if meta, ok := r.Context().Value(metaKey).(ReqMeta); ok {
		return meta
	}
	return ReqMeta{}
}
// WithPayload returns a copy of r whose context carries the JSON encoding of
// v (as a string) under payloadKey.
//
// The marshal error is discarded: when v cannot be encoded, the stored
// payload is the empty string.
func WithPayload(r *http.Request, v interface{}) *http.Request {
	encoded, _ := json.Marshal(v)
	ctx := context.WithValue(r.Context(), payloadKey, string(encoded))
	return r.WithContext(ctx)
}
// PayloadFrom returns the payload string stored in r's context under
// payloadKey. It returns "" when no value is stored or the stored value is
// not a string.
func PayloadFrom(r *http.Request) string {
	payload, ok := r.Context().Value(payloadKey).(string)
	if ok {
		return payload
	}
	return ""
}
// AuthResponse is the response structure returned by the marketing system's
// authentication endpoint.
type AuthResponse struct {
	// Code is the business status code.
	Code int `json:"code"`
	// Message is the accompanying text.
	Message string `json:"message"`
	// Data holds the creator ID list or permission data.
	Data interface{} `json:"data"`
	// IsAdmin is 1 for a super administrator and 0 for a regular user.
	IsAdmin int `json:"isAdmin"`
}
// withAuth returns the authentication middleware: it validates the caller's
// token against the marketing system and stores the caller's data
// permissions in the request context.
//
// apiDomain is the base URL of the marketing system. For every request the
// middleware issues GET {apiDomain}/auth/admin/dataPermission, forwarding the
// value of the incoming "token" header as "Authorization: Bearer <token>".
//
// On success the creator ID list is stored under creatorIDsKey and the admin
// flag under isAdminKey, and the wrapped handler is invoked. On failure a
// JSON envelope {"code","message","data"} is written instead:
//   - 401 when the token header is empty, the upstream request cannot be
//     built, or the upstream answers with a non-200 HTTP status or a business
//     code other than 200;
//   - 500 when the upstream call fails or its response cannot be read or
//     parsed.
func withAuth(apiDomain string) func(http.Handler) http.Handler {
	const (
		// authTimeout bounds the whole upstream call so that a stalled auth
		// service cannot block a request indefinitely.
		authTimeout = 10 * time.Second
		// maxAuthBody caps how much of the upstream response is buffered.
		maxAuthBody = 1 << 20
		// loginMsg is the message shown for authentication failures.
		loginMsg = "请先登录"
	)

	authURL := fmt.Sprintf("%s/auth/admin/dataPermission", apiDomain)
	client := &http.Client{Timeout: authTimeout}

	return func(h http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			// Read the token.
			token := r.Header.Get("token")
			if token == "" {
				writeAuthJSON(w, http.StatusUnauthorized, 401, loginMsg, nil)
				return
			}

			// Build the call to the marketing system's auth endpoint. It is
			// bound to the incoming request's context so it is cancelled when
			// the caller goes away.
			req, err := http.NewRequestWithContext(r.Context(), http.MethodGet, authURL, nil)
			if err != nil {
				writeAuthJSON(w, http.StatusUnauthorized, 401, loginMsg, nil)
				return
			}
			req.Header.Set("Authorization", "Bearer "+token)

			resp, err := client.Do(req)
			if err != nil {
				writeAuthJSON(w, http.StatusInternalServerError, 500, "认证请求失败", nil)
				return
			}
			defer resp.Body.Close()

			body, err := io.ReadAll(io.LimitReader(resp.Body, maxAuthBody))
			if err != nil {
				writeAuthJSON(w, http.StatusInternalServerError, 500, "读取认证响应失败", nil)
				return
			}

			var authResp AuthResponse
			if err := json.Unmarshal(body, &authResp); err != nil {
				writeAuthJSON(w, http.StatusInternalServerError, 500, "解析认证响应失败", nil)
				return
			}

			// Authentication succeeds only when both the HTTP status and the
			// business code say so.
			if resp.StatusCode != http.StatusOK || authResp.Code != 200 {
				code, msg := authResp.Code, authResp.Message
				if code == 200 {
					// The HTTP status failed although the business code
					// claims success; do not echo a success code on a 401.
					code = http.StatusUnauthorized
				}
				if msg == "" {
					msg = loginMsg
				}
				writeAuthJSON(w, http.StatusUnauthorized, code, msg, authResp.Data)
				return
			}

			// Extract the creator ID list and the admin flag. Data may be a
			// JSON array of IDs or a JSON object.
			var creatorIDs []int
			isAdmin := authResp.IsAdmin
			switch data := authResp.Data.(type) {
			case []interface{}:
				creatorIDs = authIntSlice(data)
			case map[string]interface{}:
				if ids, ok := data["creatorIDs"].([]interface{}); ok {
					creatorIDs = authIntSlice(ids)
				}
				// An isAdmin value inside the data object takes precedence
				// over the top-level one.
				if v, ok := data["isAdmin"].(float64); ok {
					isAdmin = int(v)
				}
			}

			ctx := context.WithValue(r.Context(), creatorIDsKey, creatorIDs)
			ctx = context.WithValue(ctx, isAdminKey, isAdmin)
			h.ServeHTTP(w, r.WithContext(ctx))
		})
	}
}

// writeAuthJSON writes a {"code","message","data"} JSON envelope with the
// given HTTP status and an application/json content type.
func writeAuthJSON(w http.ResponseWriter, status, code int, message string, data interface{}) {
	type envelope struct {
		Code    int         `json:"code"`
		Message string      `json:"message"`
		Data    interface{} `json:"data"`
	}
	body, err := json.Marshal(envelope{Code: code, Message: message, Data: data})
	if err != nil {
		// Drop the unencodable data rather than emit a broken body; an
		// envelope holding only an int and a string cannot fail to encode.
		body, _ = json.Marshal(envelope{Code: code, Message: message})
	}
	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(status)
	// A write error means the client is gone; there is nothing left to report to.
	_, _ = w.Write(body)
}

// authIntSlice converts a decoded JSON array into []int, skipping elements
// that are not JSON numbers.
func authIntSlice(items []interface{}) []int {
	out := make([]int, 0, len(items))
	for _, item := range items {
		if num, ok := item.(float64); ok {
			out = append(out, int(num))
		}
	}
	return out
}
// CreatorIDsFrom returns the creator ID list stored in r's context under
// creatorIDsKey. It returns nil when no value is stored or the stored value
// is not a []int.
func CreatorIDsFrom(r *http.Request) []int {
	if ids, ok := r.Context().Value(creatorIDsKey).([]int); ok {
		return ids
	}
	return nil
}
// IsAdminFrom returns the super-administrator flag stored in r's context
// under isAdminKey. It returns 0 when no value is stored or the stored value
// is not an int.
func IsAdminFrom(r *http.Request) int {
	flag, ok := r.Context().Value(isAdminKey).(int)
	if !ok {
		return 0
	}
	return flag
}