package pkg import ( "fmt" "geo/tmpl/errcode" "reflect" "strings" "sync" "sync/atomic" "time" "github.com/go-playground/validator/v10" "github.com/gofiber/fiber/v2" ) const ( // 验证标签常量 TagRequired = "required" TagEmail = "email" TagMin = "min" TagMax = "max" TagLen = "len" TagNumeric = "numeric" TagOneof = "oneof" TagGte = "gte" TagGt = "gt" TagLte = "lte" TagLt = "lt" TagURL = "url" TagUUID = "uuid" TagDatetime = "datetime" // 中文标签常量 //LabelComment = "comment" //LabelLabel = "label" LabelZh = "zh" // 上下文Key CtxRequestBody = "request_body" // 性能优化常量 InitialBuilderSize = 64 MaxCacheSize = 5000 CacheCleanupInterval = 10 * time.Minute ) // ValidatorConfig 验证器配置 type ValidatorConfig struct { StatusCode int // 验证失败时的HTTP状态码,默认422 ErrorHandler func(c *fiber.Ctx, err error) error // 自定义错误处理 BeforeParse func(c *fiber.Ctx) error // 解析前执行 AfterParse func(c *fiber.Ctx, req interface{}) error // 解析后执行 DisableCache bool // 是否禁用缓存 MaxCacheSize int // 最大缓存数量 UseChinese bool // 是否使用中文提示 EnableMetrics bool // 是否启用指标收集 } // CacheStats 缓存统计 type CacheStats struct { hitCount int64 missCount int64 evictCount int64 errorCount int64 accessCount int64 } // Metrics 指标收集 type Metrics struct { totalRequests int64 validationTime int64 cacheHitRate float64 mu sync.RWMutex } // fieldInfo 字段信息 type fieldInfo struct { label string index int typeKind reflect.Kind accessCount int64 lastAccess time.Time } // typeInfo 类型信息 type typeInfo struct { fields map[string]*fieldInfo fieldNames []string mu sync.RWMutex accessCount int64 createdAt time.Time lastAccess time.Time typeKey string } // ValidatorHelper 验证器助手 type ValidatorHelper struct { validate *validator.Validate config *ValidatorConfig typeCache sync.Map cacheStats *CacheStats errorFnCache map[string]func(string, string) string errorFnCacheMu sync.RWMutex pruneLock sync.Mutex stopCleanup chan struct{} cleanupOnce sync.Once metrics *Metrics } var ( ChineseErrorTemplates = map[string]string{ TagRequired: "%s不能为空", TagEmail: "%s格式不正确", TagMin: "%s不能小于%s", TagMax: "%s不能大于%s", TagLen: "%s长度必须为%s位", TagNumeric: "%s必须是数字", TagOneof: "%s必须是以下值之一: %s", TagGte: "%s不能小于%s", TagGt: "%s必须大于%s", TagLte: "%s不能大于%s", TagLt: "%s必须小于%s", TagURL: "%s必须是有效的URL地址", TagUUID: "%s必须是有效的UUID", TagDatetime: "%s日期格式不正确", } defaultErrorTemplates = map[string]string{ TagRequired: "%s is required", TagEmail: "invalid email format", TagMin: "%s must be at least %s", TagMax: "%s must be at most %s", TagLen: "%s must be exactly %s characters", TagNumeric: "%s must be numeric", TagOneof: "%s must be one of: %s", TagGte: "%s must be greater than or equal to %s", TagGt: "%s must be greater than %s", TagLte: "%s must be less than or equal to %s", TagLt: "%s must be less than %s", TagURL: "%s must be a valid URL", TagUUID: "%s must be a valid UUID", TagDatetime: "%s invalid datetime format", } ) var ( // 对象池 builderPool = sync.Pool{ New: func() interface{} { b := &strings.Builder{} b.Grow(InitialBuilderSize) return b }, } // 错误信息切片池 errorSlicePool = sync.Pool{ New: func() interface{} { slice := make([]string, 0, 8) return &slice }, } // 字段信息对象池 fieldInfoPool = sync.Pool{ New: func() interface{} { return &fieldInfo{ accessCount: 0, lastAccess: time.Now(), } }, } // 类型信息对象池 typeInfoPool = sync.Pool{ New: func() interface{} { return &typeInfo{ fields: make(map[string]*fieldInfo), fieldNames: make([]string, 0, 8), } }, } ) var ( Vh *ValidatorHelper once sync.Once ) // NewValidatorHelper 初始化验证器助手 func NewValidatorHelper(config ...*ValidatorConfig) { once.Do(func() { v := validator.New() // 优化JSON标签获取 v.RegisterTagNameFunc(func(fld reflect.StructField) string { return fld.Tag.Get("json") }) // 默认配置 cfg := &ValidatorConfig{ StatusCode: fiber.StatusUnprocessableEntity, ErrorHandler: defaultErrorHandler, MaxCacheSize: MaxCacheSize, UseChinese: true, EnableMetrics: false, } if len(config) > 0 && config[0] != nil { if config[0].StatusCode != 0 { cfg.StatusCode = config[0].StatusCode } if config[0].ErrorHandler != nil { cfg.ErrorHandler = config[0].ErrorHandler } if config[0].MaxCacheSize > 0 { cfg.MaxCacheSize = config[0].MaxCacheSize } cfg.DisableCache = config[0].DisableCache cfg.BeforeParse = config[0].BeforeParse cfg.AfterParse = config[0].AfterParse cfg.UseChinese = config[0].UseChinese cfg.EnableMetrics = config[0].EnableMetrics } // 预编译错误函数 errorFnCache := make(map[string]func(string, string) string) templates := ChineseErrorTemplates if !cfg.UseChinese { templates = defaultErrorTemplates } for tag, tmpl := range templates { t := tmpl // 捕获变量 errorFnCache[tag] = func(field, param string) string { if strings.Contains(t, "%s") && strings.Count(t, "%s") == 2 { return fmt.Sprintf(t, field, param) } return fmt.Sprintf(t, field) } } Vh = &ValidatorHelper{ validate: v, config: cfg, cacheStats: &CacheStats{}, errorFnCache: errorFnCache, errorFnCacheMu: sync.RWMutex{}, stopCleanup: make(chan struct{}), metrics: &Metrics{}, } // 启动定期清理 if !cfg.DisableCache { go Vh.periodicCleanup() } }) } // ParseAndValidate 解析并验证请求 func ParseAndValidate(c *fiber.Ctx, req interface{}) error { if Vh == nil { NewValidatorHelper() } if Vh.config.EnableMetrics { atomic.AddInt64(&Vh.metrics.totalRequests, 1) defer Vh.recordValidationTime(time.Now()) } // 执行解析前钩子 if Vh.config.BeforeParse != nil { if err := Vh.config.BeforeParse(c); err != nil { return err } } // 解析请求体 err := c.BodyParser(req) if err != nil { return errcode.ParamErr("请求格式错误:" + err.Error()) } // 执行解析后钩子 if Vh.config.AfterParse != nil { if err = Vh.config.AfterParse(c, req); err != nil { return errcode.ParamErr(err.Error()) } } // 验证数据 err = Vh.validate.Struct(req) if err != nil { c.Locals(CtxRequestBody, req) if !Vh.config.DisableCache { t := reflect.TypeOf(req) if t.Kind() == reflect.Ptr { t = t.Elem() } if t.Kind() == reflect.Struct { Vh.safeGetOrCreateTypeInfo(t) } } return Vh.config.ErrorHandler(c, err) } return nil } // Validate 直接验证结构体 func Validate(req interface{}) error { if Vh == nil { NewValidatorHelper() } if err := Vh.validate.Struct(req); err != nil { return Vh.wrapValidationError(err, req) } return nil } // 默认错误处理 func defaultErrorHandler(c *fiber.Ctx, err error) error { validationErrors, ok := err.(validator.ValidationErrors) if !ok { return errcode.SystemError } if len(validationErrors) == 0 { return nil } // 快速路径:单个错误 if len(validationErrors) == 1 { e := validationErrors[0] msg := Vh.safeGetErrorMessage(c, e) return errcode.ParamErr(msg) } // 从对象池获取builder builder := builderPool.Get().(*strings.Builder) builder.Reset() defer builderPool.Put(builder) req := c.Locals(CtxRequestBody) for i, e := range validationErrors { if i > 0 { builder.WriteByte('\n') } builder.WriteString(Vh.safeGetErrorMessageWithReq(req, e)) } return errcode.ParamErr(builder.String()) } // 包装验证错误 func (vh *ValidatorHelper) wrapValidationError(err error, req interface{}) error { validationErrors, ok := err.(validator.ValidationErrors) if !ok { return err } if len(validationErrors) == 0 { return nil } // 构建错误消息 builder := builderPool.Get().(*strings.Builder) builder.Reset() defer builderPool.Put(builder) for i, e := range validationErrors { if i > 0 { builder.WriteByte('\n') } builder.WriteString(vh.safeGetErrorMessageWithReq(req, e)) } return errcode.ParamErr(builder.String()) } // 安全获取错误消息 func (vh *ValidatorHelper) safeGetErrorMessage(c *fiber.Ctx, e validator.FieldError) string { req := c.Locals(CtxRequestBody) return vh.safeGetErrorMessageWithReq(req, e) } // 安全获取错误消息(带请求体) func (vh *ValidatorHelper) safeGetErrorMessageWithReq(req interface{}, e validator.FieldError) string { if req == nil { return vh.safeFormatFieldError(e, nil) } t := reflect.TypeOf(req) if t.Kind() == reflect.Ptr { t = t.Elem() } if t.Kind() != reflect.Struct { return vh.safeFormatFieldError(e, nil) } // 安全获取类型信息 - 这里需要先声明变量 var typeInfoObj *typeInfo // 这里声明变量 if !vh.config.DisableCache { if cached, ok := vh.typeCache.Load(t); ok { typeInfoObj = cached.(*typeInfo) atomic.AddInt64(&typeInfoObj.accessCount, 1) } } return vh.safeFormatFieldError(e, typeInfoObj) } // 安全格式化字段错误 func (vh *ValidatorHelper) safeFormatFieldError(e validator.FieldError, typeInfo *typeInfo) string { structField := e.StructField() fieldName := e.Field() // 获取字段标签 var label string if typeInfo != nil { typeInfo.mu.RLock() if info, ok := typeInfo.fields[structField]; ok { label = info.label atomic.AddInt64(&info.accessCount, 1) } typeInfo.mu.RUnlock() } // 如果没有标签,返回默认消息 if label == "" { return vh.safeGetDefaultErrorMessage(fieldName, e) } // 使用预编译的错误函数生成消息 vh.errorFnCacheMu.RLock() fn, ok := vh.errorFnCache[e.Tag()] vh.errorFnCacheMu.RUnlock() if ok { return fn(label, e.Param()) } return label + "格式不正确" } // 安全获取默认错误消息 func (vh *ValidatorHelper) safeGetDefaultErrorMessage(field string, e validator.FieldError) string { vh.errorFnCacheMu.RLock() defer vh.errorFnCacheMu.RUnlock() if fn, ok := vh.errorFnCache[e.Tag()]; ok { return fn(field, e.Param()) } return field + "验证失败" } // safeGetOrCreateTypeInfo 安全地获取或创建类型信息 func (vh *ValidatorHelper) safeGetOrCreateTypeInfo(t reflect.Type) *typeInfo { if vh.config.DisableCache || t == nil || t.Kind() != reflect.Struct { return nil } // 首先尝试从缓存读取 if cached, ok := vh.typeCache.Load(t); ok { info := cached.(*typeInfo) atomic.AddInt64(&info.accessCount, 1) atomic.AddInt64(&vh.cacheStats.hitCount, 1) info.lastAccess = time.Now() return info } atomic.AddInt64(&vh.cacheStats.missCount, 1) // 从对象池获取typeInfo info := typeInfoPool.Get().(*typeInfo) // 重置并初始化 info.mu.Lock() // 清空现有map for k := range info.fields { delete(info.fields, k) } info.fieldNames = info.fieldNames[:0] info.accessCount = 0 info.createdAt = time.Now() info.lastAccess = time.Now() info.typeKey = t.String() // 预计算所有字段信息 for i := 0; i < t.NumField(); i++ { field := t.Field(i) // 获取标签 //label := field.Tag.Get(LabelComment) //if label == "" { // label = field.Tag.Get(LabelLabel) //} //if label == "" { // label = field.Tag.Get(LabelZh) //} label := field.Tag.Get(LabelZh) // 从对象池获取或创建字段信息 fieldInfo := fieldInfoPool.Get().(*fieldInfo) fieldInfo.label = label fieldInfo.index = i fieldInfo.typeKind = field.Type.Kind() fieldInfo.accessCount = 0 fieldInfo.lastAccess = time.Now() info.fields[field.Name] = fieldInfo info.fieldNames = append(info.fieldNames, field.Name) } info.mu.Unlock() // 使用原子操作确保线程安全的存储 if existing, loaded := vh.typeCache.LoadOrStore(t, info); loaded { // 如果已经有其他goroutine存储了,使用已有的并回收新创建的 info.mu.Lock() for _, fieldInfo := range info.fields { fieldInfoPool.Put(fieldInfo) } info.mu.Unlock() typeInfoPool.Put(info) existingInfo := existing.(*typeInfo) atomic.AddInt64(&existingInfo.accessCount, 1) return existingInfo } return info } // 定期清理缓存 func (vh *ValidatorHelper) periodicCleanup() { ticker := time.NewTicker(CacheCleanupInterval) defer ticker.Stop() for { select { case <-ticker.C: vh.safeCleanupCache() case <-vh.stopCleanup: return } } } // safeCleanupCache 安全清理缓存 func (vh *ValidatorHelper) safeCleanupCache() { vh.pruneLock.Lock() defer vh.pruneLock.Unlock() var keysToDelete []interface{} now := time.Now() vh.typeCache.Range(func(key, value interface{}) bool { info, ok := value.(*typeInfo) if !ok { return true } // 检查是否需要清理 accessCount := atomic.LoadInt64(&info.accessCount) age := now.Sub(info.createdAt) idleTime := now.Sub(info.lastAccess) // 清理条件: // 1. 很少访问的缓存(访问次数 < 10) // 2. 空闲时间超过30分钟 // 3. 缓存年龄超过1小时且访问次数较少 if (accessCount < 10 && idleTime > 30*time.Minute) || (age > 1*time.Hour && accessCount < 100) { keysToDelete = append(keysToDelete, key) atomic.AddInt64(&vh.cacheStats.evictCount, 1) } return true }) // 删除选中的缓存 for _, key := range keysToDelete { if val, ok := vh.typeCache.Load(key); ok { if info, ok := val.(*typeInfo); ok { // 安全回收字段信息 info.mu.Lock() for _, fieldInfo := range info.fields { fieldInfoPool.Put(fieldInfo) } info.mu.Unlock() typeInfoPool.Put(info) } vh.typeCache.Delete(key) } } } // ==================== 指标收集 ==================== func (vh *ValidatorHelper) recordValidationTime(start time.Time) { if !vh.config.EnableMetrics { return } duration := time.Since(start).Nanoseconds() atomic.AddInt64(&vh.metrics.validationTime, duration) } // GetMetrics 获取性能指标 func (vh *ValidatorHelper) GetMetrics() map[string]interface{} { if !vh.config.EnableMetrics { return nil } vh.metrics.mu.RLock() defer vh.metrics.mu.RUnlock() hitCount := atomic.LoadInt64(&vh.cacheStats.hitCount) missCount := atomic.LoadInt64(&vh.cacheStats.missCount) totalRequests := hitCount + missCount var hitRate float64 if totalRequests > 0 { hitRate = float64(hitCount) / float64(totalRequests) * 100 } var cacheSize int64 vh.typeCache.Range(func(_, _ interface{}) bool { cacheSize++ return true }) return map[string]interface{}{ "cache_hit_count": hitCount, "cache_miss_count": missCount, "cache_evict_count": atomic.LoadInt64(&vh.cacheStats.evictCount), "cache_hit_rate": fmt.Sprintf("%.2f%%", hitRate), "cache_size": cacheSize, "error_count": atomic.LoadInt64(&vh.cacheStats.errorCount), "total_requests": atomic.LoadInt64(&vh.metrics.totalRequests), "avg_validation_time_ns": atomic.LoadInt64(&vh.metrics.validationTime) / max(atomic.LoadInt64(&vh.metrics.totalRequests), 1), } } // RegisterValidation 注册自定义验证规则 func (vh *ValidatorHelper) RegisterValidation(tag string, fn validator.Func, callValidationEvenIfNull ...bool) error { return vh.validate.RegisterValidation(tag, fn, callValidationEvenIfNull...) } // RegisterTranslation 注册自定义翻译 func (vh *ValidatorHelper) RegisterTranslation(tag string, template string) { vh.errorFnCacheMu.Lock() defer vh.errorFnCacheMu.Unlock() t := template vh.errorFnCache[tag] = func(field, param string) string { if strings.Contains(t, "%s") && strings.Count(t, "%s") == 2 { return fmt.Sprintf(t, field, param) } return fmt.Sprintf(t, field) } } // Stop 停止后台清理任务 func (vh *ValidatorHelper) Stop() { close(vh.stopCleanup) } // Reset 重置验证器状态 func (vh *ValidatorHelper) Reset() { vh.ClearCache() atomic.StoreInt64(&vh.cacheStats.hitCount, 0) atomic.StoreInt64(&vh.cacheStats.missCount, 0) atomic.StoreInt64(&vh.cacheStats.evictCount, 0) atomic.StoreInt64(&vh.cacheStats.errorCount, 0) atomic.StoreInt64(&vh.metrics.totalRequests, 0) atomic.StoreInt64(&vh.metrics.validationTime, 0) } // ClearCache 清理所有缓存 func (vh *ValidatorHelper) ClearCache() { vh.pruneLock.Lock() defer vh.pruneLock.Unlock() // 安全回收所有缓存 vh.typeCache.Range(func(key, value interface{}) bool { if info, ok := value.(*typeInfo); ok { info.mu.Lock() for _, fieldInfo := range info.fields { fieldInfoPool.Put(fieldInfo) } info.mu.Unlock() typeInfoPool.Put(info) } vh.typeCache.Delete(key) return true }) } // SetLanguage 设置语言 func (vh *ValidatorHelper) SetLanguage(useChinese bool) { vh.config.UseChinese = useChinese templates := defaultErrorTemplates if useChinese { templates = ChineseErrorTemplates } vh.errorFnCacheMu.Lock() defer vh.errorFnCacheMu.Unlock() for tag, tmpl := range templates { t := tmpl vh.errorFnCache[tag] = func(field, param string) string { if strings.Contains(t, "%s") && strings.Count(t, "%s") == 2 { return fmt.Sprintf(t, field, param) } return fmt.Sprintf(t, field) } } } func max(a, b int64) int64 { if a > b { return a } return b } func GetErr(tag, field, param string) string { t := ChineseErrorTemplates[tag] if strings.Contains(t, "%s") && strings.Count(t, "%s") == 2 { return fmt.Sprintf(t, field, param) } return fmt.Sprintf(t, field) }