129 lines
2.8 KiB
Go
129 lines
2.8 KiB
Go
package router
|
||
|
||
import (
|
||
"encoding/json"
|
||
"errors"
|
||
"geo/pkg"
|
||
"geo/tmpl/errcode"
|
||
"reflect"
|
||
"strings"
|
||
|
||
"github.com/go-playground/validator/v10"
|
||
"github.com/gofiber/fiber/v2"
|
||
)
|
||
|
||
type Module interface {
|
||
Register(r fiber.Router)
|
||
}
|
||
type RouterServer struct {
|
||
modules []Module
|
||
}
|
||
|
||
func NewRouterServer(
|
||
app *AppModule,
|
||
) *RouterServer {
|
||
return &RouterServer{
|
||
modules: []Module{app},
|
||
}
|
||
}
|
||
|
||
// SetupRoutes 设置路由
|
||
func SetupRoutes(app *fiber.App, routerServer *RouterServer) {
|
||
app.Use(func(c *fiber.Ctx) error {
|
||
// 设置 CORS 头
|
||
c.Set("Access-Control-Allow-Origin", "*")
|
||
c.Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
|
||
c.Set("Access-Control-Allow-Headers", "Content-Type, Authorization")
|
||
|
||
// 如果是预检请求(OPTIONS),直接返回 204
|
||
if c.Method() == "OPTIONS" {
|
||
return c.SendStatus(fiber.StatusNoContent) // 204
|
||
}
|
||
|
||
// 继续处理后续中间件或路由
|
||
return c.Next()
|
||
})
|
||
|
||
registerResponse(app)
|
||
|
||
// 注册所有模块
|
||
for _, module := range routerServer.modules {
|
||
module.Register(app)
|
||
}
|
||
}
|
||
|
||
func registerResponse(router fiber.Router) {
|
||
// 自定义返回
|
||
router.Use(func(c *fiber.Ctx) error {
|
||
err := c.Next()
|
||
return registerCommon(c, err)
|
||
})
|
||
|
||
}
|
||
|
||
func registerCommon(c *fiber.Ctx, err error) error {
|
||
// 调用下一个中间件或路由处理函数
|
||
|
||
// 如果有错误发生
|
||
if err != nil {
|
||
var businessErr *errcode.BusinessErr
|
||
var fiberError *fiber.Error
|
||
switch {
|
||
case errors.As(err, &businessErr):
|
||
case errors.As(err, &fiberError):
|
||
errors.As(err, &fiberError)
|
||
businessErr = errcode.NewBusinessErr(fiberError.Code, fiberError.Message)
|
||
default:
|
||
businessErr = errcode.SystemError
|
||
}
|
||
// 返回自定义错误响应
|
||
return c.JSON(fiber.Map{
|
||
"message": businessErr.Error(),
|
||
"code": businessErr.Code(),
|
||
"data": nil,
|
||
})
|
||
}
|
||
// 如果没有错误发生,继续处理请求
|
||
var data interface{}
|
||
json.Unmarshal(c.Response().Body(), &data)
|
||
return c.JSON(fiber.Map{
|
||
"data": data,
|
||
"message": errcode.Success.Error(),
|
||
"code": errcode.Success.Code(),
|
||
})
|
||
}
|
||
|
||
func vali[T any](handler func(*fiber.Ctx, *T) error, _ *T) fiber.Handler {
|
||
return func(c *fiber.Ctx) error {
|
||
var data T
|
||
|
||
// 解析请求
|
||
if err := c.BodyParser(&data); err != nil {
|
||
return errcode.ParamErr(err.Error())
|
||
}
|
||
|
||
// 创建验证器
|
||
validate := validator.New()
|
||
|
||
// 注册中文标签
|
||
validate.RegisterTagNameFunc(func(fld reflect.StructField) string {
|
||
name := fld.Tag.Get("zh")
|
||
if name == "" {
|
||
name = fld.Tag.Get("json")
|
||
}
|
||
return name
|
||
})
|
||
|
||
// 验证
|
||
if err := validate.Struct(&data); err != nil {
|
||
er := make([]string, len(err.(validator.ValidationErrors)))
|
||
for k, e := range err.(validator.ValidationErrors) {
|
||
er[k] = pkg.GetErr(e.Tag(), e.Field(), e.Param())
|
||
}
|
||
return errcode.ParamErr(strings.Join(er, ","))
|
||
}
|
||
|
||
return handler(c, &data)
|
||
}
|
||
}
|