geoGo/pkg/validata.go

744 lines
18 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package pkg
import (
"fmt"
"geo/tmpl/errcode"
"reflect"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/go-playground/validator/v10"
"github.com/gofiber/fiber/v2"
)
// Validation tag names understood by go-playground/validator. They are also
// the keys of the error-template maps in this package.
const (
	TagRequired = "required"
	TagEmail    = "email"
	TagMin      = "min"
	TagMax      = "max"
	TagLen      = "len"
	TagNumeric  = "numeric"
	TagOneof    = "oneof"
	TagGte      = "gte"
	TagGt       = "gt"
	TagLte      = "lte"
	TagLt       = "lt"
	TagURL      = "url"
	TagUUID     = "uuid"
	TagDatetime = "datetime"
)

// LabelZh is the struct tag whose value is used as the human-readable field
// label in error messages. (The "comment" and "label" tags were consulted in
// an earlier version; that lookup is currently disabled.)
const LabelZh = "zh"

// CtxRequestBody is the fiber.Ctx Locals key under which the parsed request
// is stored when validation fails.
const CtxRequestBody = "request_body"

// Pooling and type-cache tuning.
const (
	InitialBuilderSize   = 64
	MaxCacheSize         = 5000
	CacheCleanupInterval = 10 * time.Minute
)
// ValidatorConfig holds the options accepted by NewValidatorHelper.
// When a config is passed, zero-valued StatusCode, ErrorHandler and
// MaxCacheSize keep their defaults; the remaining fields are copied verbatim.
type ValidatorConfig struct {
StatusCode int // HTTP status for validation failures; defaults to fiber.StatusUnprocessableEntity (422). NOTE(review): not read anywhere else in this file.
ErrorHandler func(c *fiber.Ctx, err error) error // called by ParseAndValidate with the validator error; defaults to defaultErrorHandler
BeforeParse func(c *fiber.Ctx) error // runs before the body is parsed; a non-nil error is returned unchanged
AfterParse func(c *fiber.Ctx, req interface{}) error // runs after parsing and before validation; a non-nil error is wrapped in errcode.ParamErr
DisableCache bool // disables the per-type label cache (so zh labels are not used) and the cleanup goroutine
MaxCacheSize int // defaults to the MaxCacheSize constant. NOTE(review): not enforced anywhere in this file.
UseChinese bool // choose ChineseErrorTemplates over defaultErrorTemplates; copied as-is from a supplied config, so false unless set
EnableMetrics bool // enable request/timing counters and GetMetrics
}
// CacheStats holds counters for the type-info cache, updated with sync/atomic.
type CacheStats struct {
hitCount int64 // cache hits in safeGetOrCreateTypeInfo
missCount int64 // cache misses in safeGetOrCreateTypeInfo
evictCount int64 // entries removed by the periodic cleanup
errorCount int64 // NOTE(review): never incremented in this file
accessCount int64 // NOTE(review): unused in this file
}
// Metrics holds request counters that are collected only when EnableMetrics is set.
type Metrics struct {
totalRequests int64 // ParseAndValidate calls (atomic)
validationTime int64 // cumulative ParseAndValidate duration in nanoseconds (atomic)
cacheHitRate float64 // NOTE(review): unused; GetMetrics computes the hit rate itself
mu sync.RWMutex // read-locked by GetMetrics
}
// fieldInfo is the cached metadata for one struct field.
type fieldInfo struct {
label string // value of the `zh` struct tag ("" if absent)
index int // position of the field in the struct
typeKind reflect.Kind // kind of the field's type
accessCount int64 // bumped atomically each time the label is used in a message
lastAccess time.Time // set when the entry is built
}
// typeInfo is the cached metadata for one struct type, keyed by reflect.Type in ValidatorHelper.typeCache.
type typeInfo struct {
fields map[string]*fieldInfo // keyed by Go struct field name (top-level fields only); guarded by mu
fieldNames []string // field names in declaration order; guarded by mu
mu sync.RWMutex
accessCount int64 // updated atomically; used by the cleanup heuristics
createdAt time.Time // when the entry was built
lastAccess time.Time // last cache hit; used by the cleanup heuristics
typeKey string // t.String() of the cached type
}
// ValidatorHelper wraps a validator instance together with its error-message formatters and a per-type label cache.
type ValidatorHelper struct {
validate *validator.Validate
config *ValidatorConfig
typeCache sync.Map // reflect.Type -> *typeInfo
cacheStats *CacheStats
errorFnCache map[string]func(string, string) string // validation tag -> formatter(field, param); guarded by errorFnCacheMu
errorFnCacheMu sync.RWMutex
pruneLock sync.Mutex // serialises cache cleanup and ClearCache
stopCleanup chan struct{} // closing it ends periodicCleanup
cleanupOnce sync.Once // one-shot guard; NOTE(review): intended for closing stopCleanup exactly once — verify Stop uses it
metrics *Metrics
}
var (
// ChineseErrorTemplates maps a validation tag to a Chinese message template.
// The first %s receives the field label (or field name); templates containing
// exactly two %s verbs receive the tag parameter (e.g. the min/max bound or
// the oneof list) in the second.
ChineseErrorTemplates = map[string]string{
TagRequired: "%s不能为空",
TagEmail: "%s格式不正确",
TagMin: "%s不能小于%s",
TagMax: "%s不能大于%s",
TagLen: "%s长度必须为%s位",
TagNumeric: "%s必须是数字",
TagOneof: "%s必须是以下值之一: %s",
TagGte: "%s不能小于%s",
TagGt: "%s必须大于%s",
TagLte: "%s不能大于%s",
TagLt: "%s必须小于%s",
TagURL: "%s必须是有效的URL地址",
TagUUID: "%s必须是有效的UUID",
TagDatetime: "%s日期格式不正确",
}
// defaultErrorTemplates holds the English equivalents, used when UseChinese is false.
// NOTE(review): the TagEmail template contains no %s verb. A formatter that always
// passes the field name to fmt.Sprintf would append "%!(EXTRA string=...)" to it,
// so confirm the formatter handles zero-verb templates.
defaultErrorTemplates = map[string]string{
TagRequired: "%s is required",
TagEmail: "invalid email format",
TagMin: "%s must be at least %s",
TagMax: "%s must be at most %s",
TagLen: "%s must be exactly %s characters",
TagNumeric: "%s must be numeric",
TagOneof: "%s must be one of: %s",
TagGte: "%s must be greater than or equal to %s",
TagGt: "%s must be greater than %s",
TagLte: "%s must be less than or equal to %s",
TagLt: "%s must be less than %s",
TagURL: "%s must be a valid URL",
TagUUID: "%s must be a valid UUID",
TagDatetime: "%s invalid datetime format",
}
)
var (
// builderPool recycles *strings.Builder values for joining error messages.
// NOTE(review): strings.Builder.Reset discards the underlying buffer, so the
// Grow(InitialBuilderSize) preallocation does not survive a Reset performed
// after Get; pooling a Builder therefore saves little.
builderPool = sync.Pool{
New: func() interface{} {
b := &strings.Builder{}
b.Grow(InitialBuilderSize)
return b
},
}
// errorSlicePool recycles *[]string scratch slices (initial capacity 8).
// NOTE(review): not referenced elsewhere in this file.
errorSlicePool = sync.Pool{
New: func() interface{} {
slice := make([]string, 0, 8)
return &slice
},
}
// fieldInfoPool supplies *fieldInfo values when a typeInfo entry is built.
fieldInfoPool = sync.Pool{
New: func() interface{} {
return &fieldInfo{
accessCount: 0,
lastAccess: time.Now(),
}
},
}
// typeInfoPool supplies *typeInfo values with an empty, allocated fields map.
typeInfoPool = sync.Pool{
New: func() interface{} {
return &typeInfo{
fields: make(map[string]*fieldInfo),
fieldNames: make([]string, 0, 8),
}
},
}
)
var (
// Vh is the package-wide validator helper; it is assigned by NewValidatorHelper.
Vh *ValidatorHelper
// once makes the initialisation in NewValidatorHelper run at most once.
once sync.Once
)
// NewValidatorHelper initialises the package-wide helper Vh.
//
// Only the first call has any effect: the body runs under once, so later calls
// and any config they pass are ignored. When a config is supplied, zero-valued
// StatusCode, ErrorHandler and MaxCacheSize keep their defaults. The hook and
// boolean fields are copied verbatim, so UseChinese must be set explicitly to
// keep Chinese messages.
//
// Unless DisableCache is set, a background goroutine (periodicCleanup) is
// started to prune the type cache; closing Vh.stopCleanup (via Stop) ends it.
func NewValidatorHelper(config ...*ValidatorConfig) {
	once.Do(func() {
		v := validator.New()
		// Report fields by their JSON name. Tag options such as ",omitempty" are
		// stripped. "-" (and a missing tag) map to "", which makes the validator
		// fall back to the Go struct field name.
		v.RegisterTagNameFunc(func(fld reflect.StructField) string {
			name := strings.SplitN(fld.Tag.Get("json"), ",", 2)[0]
			if name == "-" {
				return ""
			}
			return name
		})

		// Defaults.
		cfg := &ValidatorConfig{
			StatusCode:    fiber.StatusUnprocessableEntity,
			ErrorHandler:  defaultErrorHandler,
			MaxCacheSize:  MaxCacheSize,
			UseChinese:    true,
			EnableMetrics: false,
		}
		if len(config) > 0 && config[0] != nil {
			user := config[0]
			if user.StatusCode != 0 {
				cfg.StatusCode = user.StatusCode
			}
			if user.ErrorHandler != nil {
				cfg.ErrorHandler = user.ErrorHandler
			}
			if user.MaxCacheSize > 0 {
				cfg.MaxCacheSize = user.MaxCacheSize
			}
			cfg.DisableCache = user.DisableCache
			cfg.BeforeParse = user.BeforeParse
			cfg.AfterParse = user.AfterParse
			cfg.UseChinese = user.UseChinese
			cfg.EnableMetrics = user.EnableMetrics
		}

		// Build one formatter per tag. The number of %s verbs is counted once
		// per template:
		//   - zero-verb templates are returned verbatim, because passing an
		//     unused argument to fmt.Sprintf appends "%!(EXTRA ...)";
		//   - templates with exactly two verbs also receive the tag parameter;
		//   - all others receive only the field name.
		newErrorFn := func(tmpl string) func(string, string) string {
			switch strings.Count(tmpl, "%s") {
			case 0:
				return func(string, string) string { return tmpl }
			case 2:
				return func(field, param string) string { return fmt.Sprintf(tmpl, field, param) }
			default:
				return func(field, _ string) string { return fmt.Sprintf(tmpl, field) }
			}
		}
		templates := ChineseErrorTemplates
		if !cfg.UseChinese {
			templates = defaultErrorTemplates
		}
		errorFnCache := make(map[string]func(string, string) string, len(templates))
		for tag, tmpl := range templates {
			errorFnCache[tag] = newErrorFn(tmpl)
		}

		Vh = &ValidatorHelper{
			validate:     v,
			config:       cfg,
			cacheStats:   &CacheStats{},
			errorFnCache: errorFnCache,
			stopCleanup:  make(chan struct{}),
			metrics:      &Metrics{},
		}

		// Periodic cache pruning is only needed when the cache is in use.
		if !cfg.DisableCache {
			go Vh.periodicCleanup()
		}
	})
}
// ParseAndValidate parses the body of c into req, runs the configured hooks,
// and validates req.
//
// Order of operations:
//  1. BeforeParse — its error is returned unchanged.
//  2. c.BodyParser — its error is wrapped in errcode.ParamErr.
//  3. AfterParse — its error is wrapped in errcode.ParamErr.
//  4. Validation.
//
// On a validation failure, req is stored in c.Locals under CtxRequestBody, the
// type-info cache is warmed for req's struct type (unless caching is
// disabled), and the configured ErrorHandler produces the returned error.
//
// NOTE(review): the lazy `Vh == nil` check reads Vh without synchronisation;
// calling NewValidatorHelper once at startup avoids a racy first use.
func ParseAndValidate(c *fiber.Ctx, req interface{}) error {
	if Vh == nil {
		NewValidatorHelper()
	}
	cfg := Vh.config

	if cfg.EnableMetrics {
		atomic.AddInt64(&Vh.metrics.totalRequests, 1)
		defer Vh.recordValidationTime(time.Now())
	}

	if before := cfg.BeforeParse; before != nil {
		if hookErr := before(c); hookErr != nil {
			return hookErr
		}
	}

	if parseErr := c.BodyParser(req); parseErr != nil {
		return errcode.ParamErr("请求格式错误:" + parseErr.Error())
	}

	if after := cfg.AfterParse; after != nil {
		if hookErr := after(c, req); hookErr != nil {
			return errcode.ParamErr(hookErr.Error())
		}
	}

	validationErr := Vh.validate.Struct(req)
	if validationErr == nil {
		return nil
	}

	// Make the request available to the error handler for label lookup.
	c.Locals(CtxRequestBody, req)
	if !cfg.DisableCache {
		rt := reflect.TypeOf(req)
		if rt.Kind() == reflect.Ptr {
			rt = rt.Elem()
		}
		if rt.Kind() == reflect.Struct {
			Vh.safeGetOrCreateTypeInfo(rt)
		}
	}
	return cfg.ErrorHandler(c, validationErr)
}
// Validate runs struct validation on req without any HTTP context. A
// validation failure is converted by wrapValidationError; success returns nil.
// The package helper is lazily initialised with defaults if it has not been
// set yet.
func Validate(req interface{}) error {
	if Vh == nil {
		NewValidatorHelper()
	}
	err := Vh.validate.Struct(req)
	if err == nil {
		return nil
	}
	return Vh.wrapValidationError(err, req)
}
// defaultErrorHandler turns a validator error into an errcode error.
//
// Non-ValidationErrors values map to errcode.SystemError, and an empty error
// list maps to nil. Otherwise every field error is rendered (using the request
// body stored under CtxRequestBody for labels) and the messages are joined
// with newlines into a single errcode.ParamErr.
func defaultErrorHandler(c *fiber.Ctx, err error) error {
	fieldErrs, isValidation := err.(validator.ValidationErrors)
	switch {
	case !isValidation:
		return errcode.SystemError
	case len(fieldErrs) == 0:
		return nil
	case len(fieldErrs) == 1:
		return errcode.ParamErr(Vh.safeGetErrorMessage(c, fieldErrs[0]))
	}

	body := c.Locals(CtxRequestBody)
	msgs := make([]string, len(fieldErrs))
	for i, fe := range fieldErrs {
		msgs[i] = Vh.safeGetErrorMessageWithReq(body, fe)
	}
	return errcode.ParamErr(strings.Join(msgs, "\n"))
}
// wrapValidationError converts a validator error for req into an errcode error.
// Errors that are not validator.ValidationErrors are returned unchanged, and an
// empty error list yields nil. Otherwise each field message is rendered and the
// messages are joined with newlines into one errcode.ParamErr.
func (vh *ValidatorHelper) wrapValidationError(err error, req interface{}) error {
	fieldErrs, ok := err.(validator.ValidationErrors)
	if !ok {
		return err
	}
	if len(fieldErrs) == 0 {
		return nil
	}

	msgs := make([]string, 0, len(fieldErrs))
	for _, fe := range fieldErrs {
		msgs = append(msgs, vh.safeGetErrorMessageWithReq(req, fe))
	}
	return errcode.ParamErr(strings.Join(msgs, "\n"))
}
// safeGetErrorMessage renders e using the request body stored in c's locals
// under CtxRequestBody; the value is nil if nothing was stored.
func (vh *ValidatorHelper) safeGetErrorMessage(c *fiber.Ctx, e validator.FieldError) string {
	return vh.safeGetErrorMessageWithReq(c.Locals(CtxRequestBody), e)
}
// safeGetErrorMessageWithReq renders the message for field error e, using the
// `zh` labels of req's struct type when available.
//
// req may be nil, a struct, or a pointer to a struct. For anything else, and
// when caching is disabled, no labels are available and the field name is used.
//
// The type info is fetched with get-or-create rather than a plain cache load.
// Previously, labels appeared only if ParseAndValidate had already cached the
// type, so Validate() produced label-based or name-based messages for the same
// struct depending on request history.
func (vh *ValidatorHelper) safeGetErrorMessageWithReq(req interface{}, e validator.FieldError) string {
	if req == nil {
		return vh.safeFormatFieldError(e, nil)
	}
	t := reflect.TypeOf(req)
	if t.Kind() == reflect.Ptr {
		t = t.Elem()
	}
	// safeGetOrCreateTypeInfo returns nil for non-struct types and when the
	// cache is disabled, which makes safeFormatFieldError fall back to the
	// field name.
	return vh.safeFormatFieldError(e, vh.safeGetOrCreateTypeInfo(t))
}
// safeFormatFieldError renders e, preferring the cached `zh` label of the
// failing struct field over the reported field name.
//
// Behaviour:
//   - Without a label (ti is nil, or the field is unknown or unlabelled), the
//     message comes from safeGetDefaultErrorMessage using e.Field().
//   - With a label, the formatter registered for e.Tag() is used.
//   - If no formatter is registered, the message is label + "格式不正确".
//
// Labels are looked up by Go struct field name among top-level fields only.
func (vh *ValidatorHelper) safeFormatFieldError(e validator.FieldError, ti *typeInfo) string {
	label := ""
	if ti != nil {
		ti.mu.RLock()
		if fi, found := ti.fields[e.StructField()]; found {
			label = fi.label
			atomic.AddInt64(&fi.accessCount, 1)
		}
		ti.mu.RUnlock()
	}

	if label == "" {
		return vh.safeGetDefaultErrorMessage(e.Field(), e)
	}

	vh.errorFnCacheMu.RLock()
	render, known := vh.errorFnCache[e.Tag()]
	vh.errorFnCacheMu.RUnlock()
	if !known {
		return label + "格式不正确"
	}
	return render(label, e.Param())
}
// safeGetDefaultErrorMessage formats e for the given field name using the
// formatter registered for e.Tag(). If no formatter is registered, it returns
// field + "验证失败".
func (vh *ValidatorHelper) safeGetDefaultErrorMessage(field string, e validator.FieldError) string {
	vh.errorFnCacheMu.RLock()
	render := vh.errorFnCache[e.Tag()]
	vh.errorFnCacheMu.RUnlock()

	if render == nil {
		return field + "验证失败"
	}
	return render(field, e.Param())
}
// safeGetOrCreateTypeInfo returns the cached label metadata for struct type t,
// building and publishing it on a miss.
//
// It returns nil when caching is disabled, t is nil, or t is not a struct.
// Hit and miss counters in cacheStats are updated. When two goroutines build
// the same type concurrently, LoadOrStore keeps the first entry. The loser's
// copy was never published, so only that copy is returned to the pools.
func (vh *ValidatorHelper) safeGetOrCreateTypeInfo(t reflect.Type) *typeInfo {
	if vh.config.DisableCache || t == nil || t.Kind() != reflect.Struct {
		return nil
	}

	// Fast path: cache hit.
	if cached, ok := vh.typeCache.Load(t); ok {
		info := cached.(*typeInfo)
		atomic.AddInt64(&info.accessCount, 1)
		atomic.AddInt64(&vh.cacheStats.hitCount, 1)
		// lastAccess is a plain time.Time that the cleanup goroutine reads,
		// so it must not be written without synchronisation.
		info.mu.Lock()
		info.lastAccess = time.Now()
		info.mu.Unlock()
		return info
	}
	atomic.AddInt64(&vh.cacheStats.missCount, 1)

	// Build a fresh entry from a pooled typeInfo.
	info := typeInfoPool.Get().(*typeInfo)
	now := time.Now()
	info.mu.Lock()
	for k := range info.fields {
		delete(info.fields, k)
	}
	info.fieldNames = info.fieldNames[:0]
	info.accessCount = 0
	info.createdAt = now
	info.lastAccess = now
	info.typeKey = t.String()

	// Precompute label metadata for every top-level field.
	for i := 0; i < t.NumField(); i++ {
		field := t.Field(i)
		fi := fieldInfoPool.Get().(*fieldInfo)
		fi.label = field.Tag.Get(LabelZh)
		fi.index = i
		fi.typeKind = field.Type.Kind()
		fi.accessCount = 0
		fi.lastAccess = now
		info.fields[field.Name] = fi
		info.fieldNames = append(info.fieldNames, field.Name)
	}
	info.mu.Unlock()

	// Publish. If another goroutine won the race, use its entry and recycle ours.
	if existing, loaded := vh.typeCache.LoadOrStore(t, info); loaded {
		info.mu.Lock()
		for _, fi := range info.fields {
			fieldInfoPool.Put(fi)
		}
		info.mu.Unlock()
		typeInfoPool.Put(info)

		existingInfo := existing.(*typeInfo)
		atomic.AddInt64(&existingInfo.accessCount, 1)
		return existingInfo
	}
	return info
}
// periodicCleanup runs safeCleanupCache every CacheCleanupInterval until
// stopCleanup is closed. It is intended to run on its own goroutine.
func (vh *ValidatorHelper) periodicCleanup() {
	tick := time.NewTicker(CacheCleanupInterval)
	defer tick.Stop()

	for {
		select {
		case <-vh.stopCleanup:
			return
		case <-tick.C:
			vh.safeCleanupCache()
		}
	}
}
// safeCleanupCache evicts type-info entries that look cold. An entry is evicted
// if either:
//   - it has been accessed fewer than 10 times and has been idle for more than
//     30 minutes, or
//   - it is older than one hour and has been accessed fewer than 100 times.
//
// Each eviction increments cacheStats.evictCount. Runs under pruneLock.
//
// Evicted entries are NOT returned to typeInfoPool/fieldInfoPool. Other
// goroutines may still hold a pointer obtained from the cache before the
// delete. Recycling it would let safeGetOrCreateTypeInfo reinitialise it for a
// different type underneath them, producing wrong labels and data races.
func (vh *ValidatorHelper) safeCleanupCache() {
	vh.pruneLock.Lock()
	defer vh.pruneLock.Unlock()

	now := time.Now()
	vh.typeCache.Range(func(key, value interface{}) bool {
		info, ok := value.(*typeInfo)
		if !ok {
			return true
		}

		accessCount := atomic.LoadInt64(&info.accessCount)
		// createdAt/lastAccess are plain time.Time values; read them under
		// the entry's lock.
		info.mu.RLock()
		age := now.Sub(info.createdAt)
		idle := now.Sub(info.lastAccess)
		info.mu.RUnlock()

		rarelyUsedAndIdle := accessCount < 10 && idle > 30*time.Minute
		oldAndCold := age > time.Hour && accessCount < 100
		if rarelyUsedAndIdle || oldAndCold {
			// sync.Map permits Delete while ranging.
			vh.typeCache.Delete(key)
			atomic.AddInt64(&vh.cacheStats.evictCount, 1)
		}
		return true
	})
}
// ==================== Metrics ====================

// recordValidationTime adds the time elapsed since start (in nanoseconds) to
// the cumulative validation-time counter. It does nothing unless metrics are
// enabled.
func (vh *ValidatorHelper) recordValidationTime(start time.Time) {
	if vh.config.EnableMetrics {
		atomic.AddInt64(&vh.metrics.validationTime, time.Since(start).Nanoseconds())
	}
}
// GetMetrics returns a snapshot of cache and request counters. It returns nil
// when metrics are disabled.
//
// "cache_hit_rate" is a formatted percentage string. "avg_validation_time_ns"
// divides the cumulative time by the request count, clamped to at least 1.
func (vh *ValidatorHelper) GetMetrics() map[string]interface{} {
	if !vh.config.EnableMetrics {
		return nil
	}
	vh.metrics.mu.RLock()
	defer vh.metrics.mu.RUnlock()

	hits := atomic.LoadInt64(&vh.cacheStats.hitCount)
	misses := atomic.LoadInt64(&vh.cacheStats.missCount)
	hitRate := 0.0
	if lookups := hits + misses; lookups > 0 {
		hitRate = float64(hits) / float64(lookups) * 100
	}

	var entries int64
	vh.typeCache.Range(func(interface{}, interface{}) bool {
		entries++
		return true
	})

	requests := atomic.LoadInt64(&vh.metrics.totalRequests)
	return map[string]interface{}{
		"cache_hit_count":        hits,
		"cache_miss_count":       misses,
		"cache_evict_count":      atomic.LoadInt64(&vh.cacheStats.evictCount),
		"cache_hit_rate":         fmt.Sprintf("%.2f%%", hitRate),
		"cache_size":             entries,
		"error_count":            atomic.LoadInt64(&vh.cacheStats.errorCount),
		"total_requests":         requests,
		"avg_validation_time_ns": atomic.LoadInt64(&vh.metrics.validationTime) / max(requests, 1),
	}
}
// RegisterValidation registers fn as a custom validation rule under tag on the
// underlying validator and returns that registration's error.
func (vh *ValidatorHelper) RegisterValidation(tag string, fn validator.Func, callValidationEvenIfNull ...bool) error {
	v := vh.validate
	return v.RegisterValidation(tag, fn, callValidationEvenIfNull...)
}
// RegisterTranslation installs (or replaces) the message template for tag.
//
// The template's first %s receives the field label or name. If the template
// contains exactly two %s verbs, the second receives the tag parameter. A
// template with no %s is returned verbatim. Previously a zero-verb template
// was still passed the field name, producing "...%!(EXTRA string=field)".
func (vh *ValidatorHelper) RegisterTranslation(tag string, template string) {
	// Decide the argument shape once, not on every message.
	var render func(field, param string) string
	switch strings.Count(template, "%s") {
	case 0:
		render = func(string, string) string { return template }
	case 2:
		render = func(field, param string) string { return fmt.Sprintf(template, field, param) }
	default:
		render = func(field, _ string) string { return fmt.Sprintf(template, field) }
	}

	vh.errorFnCacheMu.Lock()
	defer vh.errorFnCacheMu.Unlock()
	vh.errorFnCache[tag] = render
}
// Stop signals the background cache-cleanup goroutine to exit by closing
// stopCleanup. It is safe to call more than once: only the first call closes
// the channel. Previously a second call panicked with "close of closed
// channel", even though the cleanupOnce guard already existed on the struct.
func (vh *ValidatorHelper) Stop() {
	vh.cleanupOnce.Do(func() {
		close(vh.stopCleanup)
	})
}
// Reset empties the type-info cache and zeroes every cache and metrics counter.
func (vh *ValidatorHelper) Reset() {
	vh.ClearCache()
	counters := []*int64{
		&vh.cacheStats.hitCount,
		&vh.cacheStats.missCount,
		&vh.cacheStats.evictCount,
		&vh.cacheStats.errorCount,
		&vh.metrics.totalRequests,
		&vh.metrics.validationTime,
	}
	for _, counter := range counters {
		atomic.StoreInt64(counter, 0)
	}
}
// ClearCache removes every entry from the type-info cache, under pruneLock.
//
// Entries are dropped, not returned to typeInfoPool/fieldInfoPool. A
// concurrent request may still be reading an entry it loaded before the clear.
// Recycling it would let a later cache miss reinitialise the same object for
// another type while that reader uses it, yielding wrong labels and data races.
func (vh *ValidatorHelper) ClearCache() {
	vh.pruneLock.Lock()
	defer vh.pruneLock.Unlock()

	vh.typeCache.Range(func(key, _ interface{}) bool {
		vh.typeCache.Delete(key) // sync.Map permits Delete while ranging
		return true
	})
}
// SetLanguage switches the built-in message templates to Chinese
// (useChinese=true) or English.
//
// Formatters for every built-in tag are replaced. This overwrites any custom
// translation registered for one of those tags. Tags outside the built-in set
// are left untouched.
//
// Zero-verb templates (e.g. the English email message) are returned verbatim.
// Previously the field name was passed to them anyway, producing
// "invalid email format%!(EXTRA string=...)".
func (vh *ValidatorHelper) SetLanguage(useChinese bool) {
	vh.config.UseChinese = useChinese
	templates := defaultErrorTemplates
	if useChinese {
		templates = ChineseErrorTemplates
	}

	// Count the %s verbs once per template rather than on every message.
	compile := func(tmpl string) func(string, string) string {
		switch strings.Count(tmpl, "%s") {
		case 0:
			return func(string, string) string { return tmpl }
		case 2:
			return func(field, param string) string { return fmt.Sprintf(tmpl, field, param) }
		default:
			return func(field, _ string) string { return fmt.Sprintf(tmpl, field) }
		}
	}

	vh.errorFnCacheMu.Lock()
	defer vh.errorFnCacheMu.Unlock()
	for tag, tmpl := range templates {
		vh.errorFnCache[tag] = compile(tmpl)
	}
}
// max returns the larger of a and b.
// NOTE: this package-level helper shadows the Go 1.21+ built-in max inside
// this package.
func max(a, b int64) int64 {
	if b > a {
		return b
	}
	return a
}
// GetErr renders the Chinese error message for a validation tag.
//
// field fills the first %s. param fills the second %s when the template has
// exactly two. A template without %s is returned as-is.
//
// For a tag with no Chinese template, it returns field + "验证失败", matching
// the fallback used elsewhere in this file. Previously an unknown tag formatted
// the empty template and returned "%!(EXTRA string=<field>)".
func GetErr(tag, field, param string) string {
	tmpl, ok := ChineseErrorTemplates[tag]
	if !ok {
		return field + "验证失败"
	}
	switch strings.Count(tmpl, "%s") {
	case 0:
		return tmpl
	case 2:
		return fmt.Sprintf(tmpl, field, param)
	default:
		return fmt.Sprintf(tmpl, field)
	}
}