MarketingSystemDataExportTool/server/internal/api/middleware.go

172 lines
4.5 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package api
import (
	"context"
	"crypto/rand"
	"encoding/hex"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"time"
)
// ctxKey is the private key type for values this package stores in a request
// context. A distinct named type keeps these keys from colliding with keys
// set by other packages, even when the underlying strings are equal.
type ctxKey string

// Context keys used by the middleware and accessors in this file.
var (
	traceKey      ctxKey = "trace_id"    // trace ID, stored by withTrace
	sqlKey        ctxKey = "sql"         // SQL text, stored by WithSQL
	metaKey       ctxKey = "req_meta"    // ReqMeta snapshot, stored by withTrace
	payloadKey    ctxKey = "payload"     // JSON-encoded payload, stored by WithPayload
	creatorIDsKey ctxKey = "creator_ids" // the user's list of creator IDs, stored by withAuth
)
// withTrace wraps h so that every request carries a trace ID and a ReqMeta
// snapshot in its context (read them back with TraceIDFrom and MetaFrom).
//
// The trace ID is taken from the incoming X-Request-ID header. When that
// header is empty, 16 random bytes are hex-encoded instead. The chosen ID is
// echoed in the X-Request-ID response header before h is invoked.
func withTrace(h http.Handler) http.Handler {
	serve := func(w http.ResponseWriter, r *http.Request) {
		traceID := r.Header.Get("X-Request-ID")
		if len(traceID) == 0 {
			var raw [16]byte
			// The error is deliberately discarded: the ID is only for
			// correlation. If rand.Read ever failed, the ID would simply
			// not be random.
			_, _ = rand.Read(raw[:])
			traceID = hex.EncodeToString(raw[:])
		}
		w.Header().Set("X-Request-ID", traceID)

		meta := ReqMeta{
			Method: r.Method,
			Path:   r.URL.Path,
			Query:  r.URL.RawQuery,
			Remote: r.RemoteAddr,
		}
		ctx := context.WithValue(
			context.WithValue(r.Context(), traceKey, traceID),
			metaKey, meta,
		)
		h.ServeHTTP(w, r.WithContext(ctx))
	}
	return http.HandlerFunc(serve)
}
// TraceIDFrom returns the trace ID stored under traceKey in r's context
// (withTrace puts it there). It returns "" when no string value is present.
func TraceIDFrom(r *http.Request) string {
	if id, ok := r.Context().Value(traceKey).(string); ok {
		return id
	}
	return ""
}
// WithSQL returns a shallow copy of r whose context carries sql under sqlKey.
// Use SQLFrom to read it back.
func WithSQL(r *http.Request, sql string) *http.Request {
	ctx := context.WithValue(r.Context(), sqlKey, sql)
	return r.WithContext(ctx)
}
// SQLFrom returns the SQL text attached by WithSQL. It returns "" when the
// context holds no string under sqlKey.
func SQLFrom(r *http.Request) string {
	switch sql := r.Context().Value(sqlKey).(type) {
	case string:
		return sql
	default:
		return ""
	}
}
// ReqMeta is a snapshot of basic request attributes kept in the request
// context (see MetaFrom).
type ReqMeta struct {
	// withTrace fills these from r.Method, r.URL.Path, r.URL.RawQuery and
	// r.RemoteAddr, respectively.
	Method, Path, Query, Remote string
}
// MetaFrom returns the ReqMeta stored under metaKey in r's context. It returns
// the zero ReqMeta when nothing (or a value of another type) is stored there.
func MetaFrom(r *http.Request) ReqMeta {
	meta, _ := r.Context().Value(metaKey).(ReqMeta)
	return meta
}
// WithPayload JSON-encodes v and returns a shallow copy of r whose context
// carries the encoded text, as a string, under payloadKey (see PayloadFrom).
// If v cannot be encoded, an empty string is stored instead.
func WithPayload(r *http.Request, v interface{}) *http.Request {
	encoded, err := json.Marshal(v)
	if err != nil {
		// Best-effort: an unencodable value is recorded as an empty payload
		// rather than failing the request.
		encoded = nil
	}
	ctx := context.WithValue(r.Context(), payloadKey, string(encoded))
	return r.WithContext(ctx)
}
// PayloadFrom returns the encoded payload attached by WithPayload. It returns
// "" when the context holds no string under payloadKey.
func PayloadFrom(r *http.Request) string {
	payload, _ := r.Context().Value(payloadKey).(string)
	return payload
}
// AuthResponse is the JSON envelope returned by the marketing system's
// authorization (data-permission) endpoint.
type AuthResponse struct {
	// Code is the business status code. withAuth treats 200 as success.
	Code int `json:"code"`
	// Message is the human-readable status message.
	Message string `json:"message"`
	// Data is the list of creator IDs.
	Data []int `json:"data"`
}
// withAuth returns middleware that validates the caller's token against the
// marketing system and stores the user's data permissions in the request
// context.
//
// For each request, the "token" header is forwarded as
// "Authorization: Bearer <token>" in a GET to
// <apiDomain>/auth/admin/dataPermission. The request reaches h only when that
// endpoint answers HTTP 200 with a JSON body whose "code" is 200. The body's
// "data" list is then stored under creatorIDsKey (read it with CreatorIDsFrom).
//
// Every rejection is written as JSON of the form
// {"code":...,"message":...,"data":null} with Content-Type application/json:
//   - missing token: HTTP 401, code 401;
//   - failure creating, sending, reading or decoding the upstream call:
//     HTTP 500, code 500;
//   - upstream rejection: HTTP 401, carrying the upstream business code and
//     message, or a default message when the upstream message is empty.
//
// NOTE(review): a successful response with an absent or null "data" field is
// passed through as a nil list. Confirm downstream handlers treat that as
// "no access" rather than "no restriction".
func withAuth(apiDomain string) func(http.Handler) http.Handler {
	const (
		// authTimeout bounds the whole upstream exchange, including reading the
		// body, so a hung auth service cannot block handlers indefinitely.
		authTimeout = 10 * time.Second
		// maxAuthBodyBytes caps how much of the upstream response is read.
		maxAuthBodyBytes = 10 << 20
	)

	// The URL and client do not depend on the request. Build them once per
	// middleware instance so connections are reused across requests.
	authURL := fmt.Sprintf("%s/auth/admin/dataPermission", apiDomain)
	client := &http.Client{Timeout: authTimeout}

	// writeJSONError sends a {"code","message","data":null} body with the given
	// HTTP status. It encodes with json.Marshal rather than string formatting,
	// so the body stays valid JSON even when msg comes from the upstream
	// service and contains quotes or backslashes.
	writeJSONError := func(w http.ResponseWriter, status, code int, msg string) {
		// Marshal cannot fail for a struct of an int, a string and a nil []int.
		body, _ := json.Marshal(AuthResponse{Code: code, Message: msg})
		w.Header().Set("Content-Type", "application/json")
		w.Header().Set("X-Content-Type-Options", "nosniff")
		w.WriteHeader(status)
		_, _ = w.Write(body)
	}

	return func(h http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			token := r.Header.Get("token")
			if token == "" {
				writeJSONError(w, http.StatusUnauthorized, http.StatusUnauthorized, "未提供认证token")
				return
			}

			// Tie the upstream call to the incoming request, so it is abandoned
			// if the client goes away.
			req, err := http.NewRequestWithContext(r.Context(), http.MethodGet, authURL, nil)
			if err != nil {
				writeJSONError(w, http.StatusInternalServerError, http.StatusInternalServerError, "创建认证请求失败")
				return
			}
			req.Header.Set("Authorization", "Bearer "+token)

			resp, err := client.Do(req)
			if err != nil {
				writeJSONError(w, http.StatusInternalServerError, http.StatusInternalServerError, "认证请求失败")
				return
			}
			defer resp.Body.Close()

			body, err := io.ReadAll(io.LimitReader(resp.Body, maxAuthBodyBytes))
			if err != nil {
				writeJSONError(w, http.StatusInternalServerError, http.StatusInternalServerError, "读取认证响应失败")
				return
			}

			var authResp AuthResponse
			if err := json.Unmarshal(body, &authResp); err != nil {
				writeJSONError(w, http.StatusInternalServerError, http.StatusInternalServerError, "解析认证响应失败")
				return
			}

			// Both the HTTP status and the business code must indicate success.
			if resp.StatusCode != http.StatusOK || authResp.Code != 200 {
				msg := authResp.Message
				if msg == "" {
					msg = "认证失败"
				}
				// Relay the upstream business code and message to the caller.
				writeJSONError(w, http.StatusUnauthorized, authResp.Code, msg)
				return
			}

			ctx := context.WithValue(r.Context(), creatorIDsKey, authResp.Data)
			h.ServeHTTP(w, r.WithContext(ctx))
		})
	}
}
// CreatorIDsFrom returns the creator ID list that withAuth stored in r's
// context. It returns nil when no []int is stored under creatorIDsKey.
// The returned slice is the stored one, not a copy.
func CreatorIDsFrom(r *http.Request) []int {
	if ids, ok := r.Context().Value(creatorIDsKey).([]int); ok {
		return ids
	}
	return nil
}