164 lines
4.1 KiB
Go
164 lines
4.1 KiB
Go
package api
|
|
|
|
import (
	"context"
	"crypto/rand"
	"encoding/hex"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"time"
)
|
|
|
|
// ctxKey is the unexported type of every context key declared in this file.
// Using a dedicated type keeps these keys from colliding with context keys
// defined by other packages.
type ctxKey string

// Context keys under which request-scoped values are stored.
var (
	traceKey      ctxKey = "trace_id"    // trace ID of the request
	sqlKey        ctxKey = "sql"         // SQL text attached by WithSQL
	metaKey       ctxKey = "req_meta"    // ReqMeta captured by withTrace
	payloadKey    ctxKey = "payload"     // JSON text attached by WithPayload
	creatorIDsKey ctxKey = "creator_ids" // creator ID list granted to the user
)
|
|
|
|
func withTrace(h http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
tid := r.Header.Get("X-Request-ID")
|
|
if tid == "" {
|
|
buf := make([]byte, 16)
|
|
_, _ = rand.Read(buf)
|
|
tid = hex.EncodeToString(buf)
|
|
}
|
|
w.Header().Set("X-Request-ID", tid)
|
|
m := ReqMeta{Method: r.Method, Path: r.URL.Path, Query: r.URL.RawQuery, Remote: r.RemoteAddr}
|
|
ctx := context.WithValue(r.Context(), traceKey, tid)
|
|
ctx = context.WithValue(ctx, metaKey, m)
|
|
h.ServeHTTP(w, r.WithContext(ctx))
|
|
})
|
|
}
|
|
|
|
func TraceIDFrom(r *http.Request) string {
|
|
v := r.Context().Value(traceKey)
|
|
if v == nil {
|
|
return ""
|
|
}
|
|
s, _ := v.(string)
|
|
return s
|
|
}
|
|
|
|
func WithSQL(r *http.Request, sql string) *http.Request {
|
|
return r.WithContext(context.WithValue(r.Context(), sqlKey, sql))
|
|
}
|
|
|
|
func SQLFrom(r *http.Request) string {
|
|
v := r.Context().Value(sqlKey)
|
|
if v == nil {
|
|
return ""
|
|
}
|
|
s, _ := v.(string)
|
|
return s
|
|
}
|
|
|
|
// ReqMeta is a snapshot of basic facts about an incoming HTTP request.
// withTrace fills one in per request and stores it in the request context;
// MetaFrom reads it back.
type ReqMeta struct {
	// Method is the HTTP method of the request.
	Method string
	// Path is the URL path of the request.
	Path string
	// Query is the raw (still encoded) query string of the request URL.
	Query string
	// Remote is the request's RemoteAddr.
	Remote string
}
|
|
|
|
func MetaFrom(r *http.Request) ReqMeta {
|
|
v := r.Context().Value(metaKey)
|
|
if v == nil {
|
|
return ReqMeta{}
|
|
}
|
|
m, _ := v.(ReqMeta)
|
|
return m
|
|
}
|
|
|
|
func WithPayload(r *http.Request, v interface{}) *http.Request {
|
|
b, _ := json.Marshal(v)
|
|
return r.WithContext(context.WithValue(r.Context(), payloadKey, string(b)))
|
|
}
|
|
|
|
func PayloadFrom(r *http.Request) string {
|
|
v := r.Context().Value(payloadKey)
|
|
if v == nil {
|
|
return ""
|
|
}
|
|
s, _ := v.(string)
|
|
return s
|
|
}
|
|
|
|
// AuthResponse is the JSON body returned by the marketing system's
// authorization endpoint.
type AuthResponse struct {
	// Code is the result code reported by the endpoint; withAuth treats 200
	// as success.
	Code int `json:"code"`
	// Message is the accompanying text.
	Message string `json:"message"`
	// Data is the list of creator IDs.
	Data []int `json:"data"`
}
|
|
|
|
// withAuth 认证中间件:验证 token 并获取用户数据权限
|
|
func withAuth(apiDomain string) func(http.Handler) http.Handler {
|
|
return func(h http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
// 获取 token
|
|
token := r.Header.Get("token")
|
|
if token == "" {
|
|
http.Error(w, "{\"code\":401,\"message\":\"未提供认证token\",\"data\":null}", http.StatusUnauthorized)
|
|
return
|
|
}
|
|
|
|
// 请求营销系统鉴权接口
|
|
authURL := fmt.Sprintf("%s/auth/admin/dataPermission", apiDomain)
|
|
req, err := http.NewRequest("GET", authURL, nil)
|
|
if err != nil {
|
|
http.Error(w, "{\"code\":500,\"message\":\"创建认证请求失败\",\"data\":null}", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
// 设置请求头(根据图片示例)
|
|
req.Header.Set("Authorization", "Bearer "+token)
|
|
|
|
// 发起请求
|
|
client := &http.Client{}
|
|
resp, err := client.Do(req)
|
|
if err != nil {
|
|
http.Error(w, "{\"code\":500,\"message\":\"认证请求失败\",\"data\":null}", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
// 读取响应
|
|
body, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
http.Error(w, "{\"code\":500,\"message\":\"读取认证响应失败\",\"data\":null}", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
// 解析响应
|
|
var authResp AuthResponse
|
|
if err := json.Unmarshal(body, &authResp); err != nil {
|
|
http.Error(w, "{\"code\":500,\"message\":\"解析认证响应失败\",\"data\":null}", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
// 检查认证是否成功
|
|
if authResp.Code != 200 {
|
|
errorMsg := fmt.Sprintf("{\"code\":%d,\"message\":\"%s\",\"data\":null}", authResp.Code, authResp.Message)
|
|
http.Error(w, errorMsg, http.StatusUnauthorized)
|
|
return
|
|
}
|
|
|
|
// 将创建者ID列表存储到 context 中
|
|
ctx := context.WithValue(r.Context(), creatorIDsKey, authResp.Data)
|
|
h.ServeHTTP(w, r.WithContext(ctx))
|
|
})
|
|
}
|
|
}
|
|
|
|
// CreatorIDsFrom 从 context 中获取创建者ID列表
|
|
func CreatorIDsFrom(r *http.Request) []int {
|
|
v := r.Context().Value(creatorIDsKey)
|
|
if v == nil {
|
|
return nil
|
|
}
|
|
ids, _ := v.([]int)
|
|
return ids
|
|
}
|