166 lines
4.2 KiB
Go
166 lines
4.2 KiB
Go
package api
|
||
|
||
import (
	"context"
	"crypto/rand"
	"encoding/hex"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"time"
)
|
||
|
||
// ctxKey is the private type used for this package's context keys. Because it
// is a distinct type, these keys never compare equal to plain string keys
// stored in the same context by other code.
type ctxKey string

// Keys under which per-request values are stored in the request context.
var (
	traceKey      ctxKey = "trace_id"    // trace ID of the request
	sqlKey        ctxKey = "sql"         // SQL text attached with WithSQL
	metaKey       ctxKey = "req_meta"    // ReqMeta snapshot of the request
	payloadKey    ctxKey = "payload"     // JSON-encoded payload attached with WithPayload
	creatorIDsKey ctxKey = "creator_ids" // creator ID list of the authenticated user
)
|
||
|
||
func withTrace(h http.Handler) http.Handler {
|
||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||
tid := r.Header.Get("X-Request-ID")
|
||
if tid == "" {
|
||
buf := make([]byte, 16)
|
||
_, _ = rand.Read(buf)
|
||
tid = hex.EncodeToString(buf)
|
||
}
|
||
w.Header().Set("X-Request-ID", tid)
|
||
m := ReqMeta{Method: r.Method, Path: r.URL.Path, Query: r.URL.RawQuery, Remote: r.RemoteAddr}
|
||
ctx := context.WithValue(r.Context(), traceKey, tid)
|
||
ctx = context.WithValue(ctx, metaKey, m)
|
||
h.ServeHTTP(w, r.WithContext(ctx))
|
||
})
|
||
}
|
||
|
||
func TraceIDFrom(r *http.Request) string {
|
||
v := r.Context().Value(traceKey)
|
||
if v == nil {
|
||
return ""
|
||
}
|
||
s, _ := v.(string)
|
||
return s
|
||
}
|
||
|
||
func WithSQL(r *http.Request, sql string) *http.Request {
|
||
return r.WithContext(context.WithValue(r.Context(), sqlKey, sql))
|
||
}
|
||
|
||
func SQLFrom(r *http.Request) string {
|
||
v := r.Context().Value(sqlKey)
|
||
if v == nil {
|
||
return ""
|
||
}
|
||
s, _ := v.(string)
|
||
return s
|
||
}
|
||
|
||
// ReqMeta is a snapshot of basic facts about an incoming HTTP request.
type ReqMeta struct {
	Method string // HTTP method of the request
	Path   string // URL path of the request
	Query  string // raw (still encoded) query string, without the leading "?"
	Remote string // remote address as reported by the server
}
|
||
|
||
func MetaFrom(r *http.Request) ReqMeta {
|
||
v := r.Context().Value(metaKey)
|
||
if v == nil {
|
||
return ReqMeta{}
|
||
}
|
||
m, _ := v.(ReqMeta)
|
||
return m
|
||
}
|
||
|
||
func WithPayload(r *http.Request, v interface{}) *http.Request {
|
||
b, _ := json.Marshal(v)
|
||
return r.WithContext(context.WithValue(r.Context(), payloadKey, string(b)))
|
||
}
|
||
|
||
func PayloadFrom(r *http.Request) string {
|
||
v := r.Context().Value(payloadKey)
|
||
if v == nil {
|
||
return ""
|
||
}
|
||
s, _ := v.(string)
|
||
return s
|
||
}
|
||
|
||
// AuthResponse is the JSON body returned by the marketing system's
// authorization endpoint.
type AuthResponse struct {
	// Code is the business status code carried in the body.
	Code int `json:"code"`
	// Message is the human-readable message that accompanies Code.
	Message string `json:"message"`
	// Data is the list of creator IDs.
	Data []int `json:"data"`
}
|
||
|
||
// withAuth 认证中间件:验证 token 并获取用户数据权限
|
||
func withAuth(apiDomain string) func(http.Handler) http.Handler {
|
||
return func(h http.Handler) http.Handler {
|
||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||
// 获取 token
|
||
token := r.Header.Get("token")
|
||
if token == "" {
|
||
http.Error(w, "{\"code\":401,\"message\":\"请先登录\",\"data\":null}", http.StatusUnauthorized)
|
||
return
|
||
}
|
||
|
||
// 请求营销系统鉴权接口
|
||
authURL := fmt.Sprintf("%s/auth/admin/dataPermission", apiDomain)
|
||
req, err := http.NewRequest("GET", authURL, nil)
|
||
if err != nil {
|
||
http.Error(w, "{\"code\":401,\"message\":\"请先登录\",\"data\":null}", http.StatusUnauthorized)
|
||
return
|
||
}
|
||
|
||
// 设置请求头(根据图片示例)
|
||
req.Header.Set("Authorization", "Bearer "+token)
|
||
|
||
// 发起请求
|
||
client := &http.Client{}
|
||
resp, err := client.Do(req)
|
||
if err != nil {
|
||
http.Error(w, "{\"code\":500,\"message\":\"认证请求失败\",\"data\":null}", http.StatusInternalServerError)
|
||
return
|
||
}
|
||
defer resp.Body.Close()
|
||
|
||
// 读取响应
|
||
body, err := io.ReadAll(resp.Body)
|
||
if err != nil {
|
||
http.Error(w, "{\"code\":500,\"message\":\"读取认证响应失败\",\"data\":null}", http.StatusInternalServerError)
|
||
return
|
||
}
|
||
|
||
// 解析响应
|
||
var authResp AuthResponse
|
||
if err := json.Unmarshal(body, &authResp); err != nil {
|
||
http.Error(w, "{\"code\":500,\"message\":\"解析认证响应失败\",\"data\":null}", http.StatusInternalServerError)
|
||
return
|
||
}
|
||
|
||
// 检查认证是否成功(支持 HTTP 状态码和业务 code)
|
||
if resp.StatusCode != http.StatusOK || authResp.Code != 200 {
|
||
// 所有认证错误都显示"请先登录"
|
||
w.Header().Set("Content-Type", "application/json")
|
||
w.WriteHeader(http.StatusUnauthorized)
|
||
w.Write([]byte("{\"code\":401,\"message\":\"请先登录\",\"data\":null}"))
|
||
return
|
||
}
|
||
|
||
// 将创建者ID列表存储到 context 中
|
||
ctx := context.WithValue(r.Context(), creatorIDsKey, authResp.Data)
|
||
h.ServeHTTP(w, r.WithContext(ctx))
|
||
})
|
||
}
|
||
}
|
||
|
||
// CreatorIDsFrom 从 context 中获取创建者ID列表
|
||
func CreatorIDsFrom(r *http.Request) []int {
|
||
v := r.Context().Value(creatorIDsKey)
|
||
if v == nil {
|
||
return nil
|
||
}
|
||
ids, _ := v.([]int)
|
||
return ids
|
||
}
|