From 2bb054940a90e1c89e0981815ead8abf16ac8e66 Mon Sep 17 00:00:00 2001 From: renzhiyuan <465386466@qq.com> Date: Tue, 16 Sep 2025 21:31:47 +0800 Subject: [PATCH] =?UTF-8?q?=E7=BB=93=E6=9E=84=E4=BF=AE=E6=94=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- cmd/server/wire.go | 5 + config/config.yaml | 17 + gen.sh | 19 + go.mod | 44 +- go.sum | 94 +- internal/biz/router.go | 116 +- internal/config/config.go | 29 + internal/data/constant/const.go | 8 +- internal/data/error/error_code.go | 12 +- internal/data/impl/notify_data_impl.go | 15 - internal/data/impl/provider_set.go | 2 +- internal/data/impl/session_impl.go | 15 + internal/data/impl/sys_impl.go | 15 + internal/data/impl/task_impl.go | 15 + internal/data/model/ai_session.gen.go | 28 + internal/data/model/ai_sys.gen.go | 30 + internal/data/model/ai_task.gen.go | 29 + internal/entitys/types.go | 3 + internal/middleware/auth.go | 112 - internal/middleware/error_handler.go | 45 - internal/middleware/logger.go | 83 - internal/middleware/recovery.go | 34 - internal/middleware/trace.go | 124 - internal/pkg/func.go | 1 + internal/pkg/gorm.go | 25 + internal/pkg/mapstructure/decode_hooks.go | 279 ++ .../pkg/mapstructure/decode_hooks_test.go | 567 ++++ internal/pkg/mapstructure/error.go | 50 + internal/pkg/mapstructure/mapstructure.go | 1386 +++++++++ .../mapstructure_benchmark_test.go | 285 ++ .../mapstructure/mapstructure_bugs_test.go | 627 ++++ .../mapstructure_examples_test.go | 256 ++ .../pkg/mapstructure/mapstructure_ext_test.go | 58 + .../pkg/mapstructure/mapstructure_test.go | 2763 +++++++++++++++++ internal/pkg/mapstructure/my_decode.go | 26 + internal/pkg/mapstructure/my_decode_hook.go | 101 + .../pkg/mapstructure/my_decode_hook_test.go | 274 ++ internal/pkg/provider_set.go | 7 +- internal/pkg/rds.go | 48 + internal/pkg/utils_gorm/gorm.go | 42 + internal/pkg/utils_gorm/sql_log.go | 96 + internal/services/chat.go | 79 +- tmpl/dataTemp/queryTempl.go | 113 + tmpl/dataTemp/req_page.go | 35 + tmpl/dataTemp/resp_page.go | 22 + tmpl/errcode/common.go | 86 + utils/gorm.go | 25 + utils/provider_set.go | 10 + utils/rds.go | 48 + utils/utils_gorm/gorm.go | 41 + utils/utils_gorm/sql_log.go | 96 + utils/utils_sys_cache/gods.go | 14 + utils/utils_test.go | 72 + 53 files changed, 7851 insertions(+), 575 deletions(-) create mode 100755 gen.sh delete mode 100644 internal/data/impl/notify_data_impl.go create mode 100644 internal/data/impl/session_impl.go create mode 100644 internal/data/impl/sys_impl.go create mode 100644 internal/data/impl/task_impl.go create mode 100644 internal/data/model/ai_session.gen.go create mode 100644 internal/data/model/ai_sys.gen.go create mode 100644 internal/data/model/ai_task.gen.go delete mode 100644 internal/middleware/auth.go delete mode 100644 internal/middleware/error_handler.go delete mode 100644 internal/middleware/logger.go delete mode 100644 internal/middleware/recovery.go delete mode 100644 internal/middleware/trace.go create mode 100644 internal/pkg/func.go create mode 100644 internal/pkg/gorm.go create mode 100644 internal/pkg/mapstructure/decode_hooks.go create mode 100644 internal/pkg/mapstructure/decode_hooks_test.go create mode 100644 internal/pkg/mapstructure/error.go create mode 100644 internal/pkg/mapstructure/mapstructure.go create mode 100644 internal/pkg/mapstructure/mapstructure_benchmark_test.go create mode 100644 internal/pkg/mapstructure/mapstructure_bugs_test.go create mode 100644 internal/pkg/mapstructure/mapstructure_examples_test.go create mode 100644 internal/pkg/mapstructure/mapstructure_ext_test.go create mode 100644 internal/pkg/mapstructure/mapstructure_test.go create mode 100644 internal/pkg/mapstructure/my_decode.go create mode 100644 internal/pkg/mapstructure/my_decode_hook.go create mode 100644 internal/pkg/mapstructure/my_decode_hook_test.go create mode 100644 internal/pkg/rds.go create mode 100644 internal/pkg/utils_gorm/gorm.go create mode 100644 internal/pkg/utils_gorm/sql_log.go create mode 100644 tmpl/dataTemp/queryTempl.go create mode 100644 tmpl/dataTemp/req_page.go create mode 100644 tmpl/dataTemp/resp_page.go create mode 100644 tmpl/errcode/common.go create mode 100644 utils/gorm.go create mode 100644 utils/provider_set.go create mode 100644 utils/rds.go create mode 100644 utils/utils_gorm/gorm.go create mode 100644 utils/utils_gorm/sql_log.go create mode 100644 utils/utils_sys_cache/gods.go create mode 100644 utils/utils_test.go diff --git a/cmd/server/wire.go b/cmd/server/wire.go index 826becf..ba76674 100644 --- a/cmd/server/wire.go +++ b/cmd/server/wire.go @@ -6,10 +6,13 @@ package main import ( "ai_scheduler/internal/biz" "ai_scheduler/internal/config" + "ai_scheduler/internal/data/impl" "ai_scheduler/internal/pkg" "ai_scheduler/internal/server" "ai_scheduler/internal/services" "ai_scheduler/internal/tools" + "ai_scheduler/utils" + "github.com/gofiber/fiber/v2/log" "github.com/google/wire" ) @@ -22,6 +25,8 @@ func InitializeApp(*config.Config, log.AllLogger) (*server.Servers, func(), erro pkg.ProviderSetClient, services.ProviderSetServices, biz.ProviderSetBiz, + impl.ProviderImpl, + utils.ProviderUtils, )) } diff --git a/config/config.yaml b/config/config.yaml index 71b1cd6..26ed0a1 100644 --- a/config/config.yaml +++ b/config/config.yaml @@ -9,3 +9,20 @@ ollama: timeout: "120s" level: "info" format: "json" + +sys: + session_len: 3 + +Redis: + host: 47.97.27.195:6379 + type: node + pass: lansexiongdi@666 + key: report-api-test + pollSize: 5 #连接池大小,不配置,或配置为0表示不启用连接池 + minIdleConns: 2 #最小空闲连接数 + maxIdleTime: 30 #每个连接最大空闲时间,如果超过了这个时间会被关闭 + tls: 30 + db: +DB: + driver: mysql + source: transfer:Lsxd@34234QW@tcp(lsxdpolar.rwlb.rds.aliyuncs.com:3306)/transfer?charset=utf8mb4&parseTime=true& \ No newline at end of file diff --git a/gen.sh b/gen.sh new file mode 100755 index 0000000..2004ddd --- /dev/null +++ b/gen.sh @@ -0,0 +1,19 @@ +#!/bin/bash + + # 使用方法: + # ./genModel.sh usercenter user + # ./genModel.sh usercenter user_auth + # 再将./genModel下的文件剪切到对应服务的model目录里面,记得改package + + + #生成的表名 + tables=$1 + #表生成的genmodel目录 + modeldir=./internal/data/model + + # 数据库配置 + prefix=ai_ + + + +gentool --dsn "root:SD###sdf323r343@tcp(121.199.38.107:3306)/sys_ai?charset=utf8mb4&parseTime=true&loc=Asia%2FShanghai" -outPath ${modeldir} -onlyModel -modelPkgName "model" -tables ${prefix}${tables} diff --git a/go.mod b/go.mod index dc6332f..fe95157 100644 --- a/go.mod +++ b/go.mod @@ -5,42 +5,41 @@ go 1.24.0 toolchain go1.24.7 require ( - github.com/gin-gonic/gin v1.10.0 + github.com/emirpasic/gods v1.18.1 + github.com/go-kratos/kratos/v2 v2.9.1 github.com/gofiber/fiber/v2 v2.52.9 github.com/gofiber/websocket/v2 v2.2.1 - github.com/google/uuid v1.6.0 github.com/google/wire v0.7.0 github.com/ollama/ollama v0.11.10 + github.com/redis/go-redis/v9 v9.14.0 github.com/spf13/viper v1.17.0 - go.opentelemetry.io/otel v1.38.0 + google.golang.org/grpc v1.61.1 + google.golang.org/protobuf v1.34.1 + gopkg.in/yaml.v3 v3.0.1 + gorm.io/driver/mysql v1.6.0 + gorm.io/gorm v1.31.0 + xorm.io/builder v0.3.13 ) require ( + filippo.io/edwards25519 v1.1.0 // indirect github.com/andybalholm/brotli v1.1.0 // indirect - github.com/bytedance/sonic v1.11.6 // indirect - github.com/bytedance/sonic/loader v0.1.1 // indirect - github.com/cloudwego/base64x v0.1.4 // indirect - github.com/cloudwego/iasm v0.2.0 // indirect + github.com/cespare/xxhash/v2 v2.3.0 // indirect + github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect github.com/fasthttp/websocket v1.5.3 // indirect github.com/fsnotify/fsnotify v1.6.0 // indirect - github.com/gabriel-vasile/mimetype v1.4.3 // indirect - github.com/gin-contrib/sse v0.1.0 // indirect - github.com/go-playground/locales v0.14.1 // indirect - github.com/go-playground/universal-translator v0.18.1 // indirect - github.com/go-playground/validator/v10 v10.20.0 // indirect - github.com/goccy/go-json v0.10.2 // indirect + github.com/go-sql-driver/mysql v1.8.1 // indirect + github.com/golang/protobuf v1.5.4 // indirect + github.com/google/uuid v1.6.0 // indirect github.com/hashicorp/hcl v1.0.0 // indirect - github.com/json-iterator/go v1.1.12 // indirect + github.com/jinzhu/inflection v1.0.0 // indirect + github.com/jinzhu/now v1.1.5 // indirect github.com/klauspost/compress v1.17.9 // indirect - github.com/klauspost/cpuid/v2 v2.2.7 // indirect - github.com/leodido/go-urn v1.4.0 // indirect github.com/magiconair/properties v1.8.7 // indirect github.com/mattn/go-colorable v0.1.13 // indirect github.com/mattn/go-isatty v0.0.20 // indirect github.com/mattn/go-runewidth v0.0.16 // indirect github.com/mitchellh/mapstructure v1.5.0 // indirect - github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect - github.com/modern-go/reflect2 v1.0.2 // indirect github.com/pelletier/go-toml/v2 v2.2.2 // indirect github.com/rivo/uniseg v0.2.0 // indirect github.com/sagikazarmark/locafero v0.3.0 // indirect @@ -50,22 +49,17 @@ require ( github.com/spf13/afero v1.10.0 // indirect github.com/spf13/cast v1.5.1 // indirect github.com/spf13/pflag v1.0.5 // indirect + github.com/stretchr/testify v1.11.1 // indirect github.com/subosito/gotenv v1.6.0 // indirect - github.com/twitchyliquid64/golang-asm v0.15.1 // indirect - github.com/ugorji/go/codec v1.2.12 // indirect github.com/valyala/bytebufferpool v1.0.0 // indirect github.com/valyala/fasthttp v1.51.0 // indirect github.com/valyala/tcplisten v1.0.0 // indirect go.uber.org/atomic v1.9.0 // indirect go.uber.org/multierr v1.9.0 // indirect - golang.org/x/arch v0.8.0 // indirect golang.org/x/crypto v0.36.0 // indirect golang.org/x/exp v0.0.0-20250218142911-aa4b98e5adaa // indirect - golang.org/x/net v0.38.0 // indirect golang.org/x/sys v0.31.0 // indirect golang.org/x/text v0.23.0 // indirect - google.golang.org/protobuf v1.34.1 // indirect - gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect + google.golang.org/genproto/googleapis/rpc v0.0.0-20240102182953-50ed04b92917 // indirect gopkg.in/ini.v1 v1.67.0 // indirect - gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index ba6b9dc..37747f2 100644 --- a/go.sum +++ b/go.sum @@ -36,23 +36,25 @@ cloud.google.com/go/storage v1.8.0/go.mod h1:Wv1Oy7z6Yz3DshWRJFhqM/UCfaWIRTdp0RX cloud.google.com/go/storage v1.10.0/go.mod h1:FLPqc6j+Ki4BU591ie1oL6qBQGu2Bl/tZ9ullr3+Kg0= cloud.google.com/go/storage v1.14.0/go.mod h1:GrKmX003DSIwi9o29oFT7YDnHYwZoctc3fOKtUw0Xmo= dmitri.shuralyov.com/gpu/mtl v0.0.0-20190408044501-666a987793e9/go.mod h1:H6x//7gZCb22OMCxBHrMx7a5I7Hp++hsVxbQ4BYO7hU= +filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA= +filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4= +gitea.com/xorm/sqlfiddle v0.0.0-20180821085327-62ce714f951a h1:lSA0F4e9A2NcQSqGqTOXqu2aRi/XEQxDCBwM8yJtE6s= +gitea.com/xorm/sqlfiddle v0.0.0-20180821085327-62ce714f951a/go.mod h1:EXuID2Zs0pAQhH8yz+DNjUbjppKQzKFAn28TMYPB6IU= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo= github.com/andybalholm/brotli v1.1.0 h1:eLKJA0d02Lf0mVpIDgYnqXcUn0GqVmEFny3VuID1U3M= github.com/andybalholm/brotli v1.1.0/go.mod h1:sms7XGricyQI9K10gOSf56VKKWS4oLer58Q+mhRPtnY= -github.com/bytedance/sonic v1.11.6 h1:oUp34TzMlL+OY1OUWxHqsdkgC/Zfc85zGqw9siXjrc0= -github.com/bytedance/sonic v1.11.6/go.mod h1:LysEHSvpvDySVdC2f87zGWf6CIKJcAvqab1ZaiQtds4= -github.com/bytedance/sonic/loader v0.1.1 h1:c+e5Pt1k/cy5wMveRDyk2X4B9hF4g7an8N3zCYjJFNM= -github.com/bytedance/sonic/loader v0.1.1/go.mod h1:ncP89zfokxS5LZrJxl5z0UJcsk4M4yY2JpfqGeCtNLU= +github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs= +github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c= +github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA= +github.com/bsm/gomega v1.27.10/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0H+O0= github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= +github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= +github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/chzyer/logex v1.1.10/go.mod h1:+Ywpsq7O8HXn0nuIou7OrIPyXbp3wmkHB+jjWRnGsAI= github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e/go.mod h1:nSuG5e5PlCu98SY8svDHJxuZscDgtXS6KTTbou5AhLI= github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod h1:Q3SI9o4m/ZMnBNeIyt5eFwwo7qiLfzFZmjNmxjkiQlU= github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= -github.com/cloudwego/base64x v0.1.4 h1:jwCgWpFanWmN8xoIUHa2rtzmkd5J2plF/dnLS6Xd/0Y= -github.com/cloudwego/base64x v0.1.4/go.mod h1:0zlkT4Wn5C6NdauXdJRhSKRlJvmclQ1hhJgA0rcu/8w= -github.com/cloudwego/iasm v0.2.0 h1:1KNIy1I1H9hNNFEEH3DVnI4UujN+1zjpuk6gwHLTssg= -github.com/cloudwego/iasm v0.2.0/go.mod h1:8rXZaNYT2n95jn+zTI1sDr+IgcD2GVs0nlbbQPiEFhY= github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc= github.com/cncf/udpa/go v0.0.0-20200629203442-efcf912fb354/go.mod h1:WmhPx2Nbnhtbo57+VJT5O0JRkEi1Wbu0z5j0R8u5Hbk= github.com/cncf/udpa/go v0.0.0-20201120205902-5459f2c99403/go.mod h1:WmhPx2Nbnhtbo57+VJT5O0JRkEi1Wbu0z5j0R8u5Hbk= @@ -60,6 +62,10 @@ github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSs github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= +github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= +github.com/emirpasic/gods v1.18.1 h1:FXtiHYKDGKCW2KzwZKx0iC0PQmdlorYgdFG9jPXJ1Bc= +github.com/emirpasic/gods v1.18.1/go.mod h1:8tpGGwCnJ5H4r6BWwaV6OrWmMoPhUl5jm/FMNAnJvWQ= github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98= @@ -72,25 +78,13 @@ github.com/frankban/quicktest v1.14.4 h1:g2rn0vABPOOXmZUj+vbmUp0lPoXEMuhTpIluN0X github.com/frankban/quicktest v1.14.4/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0= github.com/fsnotify/fsnotify v1.6.0 h1:n+5WquG0fcWoWp6xPWfHdbskMCQaFnG6PfBrh1Ky4HY= github.com/fsnotify/fsnotify v1.6.0/go.mod h1:sl3t1tCWJFWoRz9R8WJCbQihKKwmorjAbSClcnxKAGw= -github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0= -github.com/gabriel-vasile/mimetype v1.4.3/go.mod h1:d8uq/6HKRL6CGdk+aubisF/M5GcPfT7nKyLpA0lbSSk= -github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE= -github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI= -github.com/gin-gonic/gin v1.10.0 h1:nTuyha1TYqgedzytsKYqna+DfLos46nTv2ygFy86HFU= -github.com/gin-gonic/gin v1.10.0/go.mod h1:4PMNQiOhvDRa013RKVbsiNwoyezlm2rm0uX/T7kzp5Y= github.com/go-gl/glfw v0.0.0-20190409004039-e6da0acd62b1/go.mod h1:vR7hzQXu2zJy9AVAgeJqvqgH9Q5CA+iKCZ2gyEVpxRU= github.com/go-gl/glfw/v3.3/glfw v0.0.0-20191125211704-12ad95a8df72/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8= github.com/go-gl/glfw/v3.3/glfw v0.0.0-20200222043503-6f7a984d4dc4/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8= -github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s= -github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4= -github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA= -github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY= -github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY= -github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY= -github.com/go-playground/validator/v10 v10.20.0 h1:K9ISHbSaI0lyB2eWMPJo+kOS/FBExVwjEviJTixqxL8= -github.com/go-playground/validator/v10 v10.20.0/go.mod h1:dbuPbCMFw/DrkbEynArYaCwl3amGuJotoKCe95atGMM= -github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU= -github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I= +github.com/go-kratos/kratos/v2 v2.9.1 h1:EGif6/S/aK/RCR5clIbyhioTNyoSrii3FC118jG40Z0= +github.com/go-kratos/kratos/v2 v2.9.1/go.mod h1:a1MQLjMhIh7R0kcJS9SzJYR43BRI7EPzzN0J1Ksu2bA= +github.com/go-sql-driver/mysql v1.8.1 h1:LedoTUt/eveggdHS9qUFC1EFSa8bU2+1pZjSRpvNJ1Y= +github.com/go-sql-driver/mysql v1.8.1/go.mod h1:wEBSXgmK//2ZFJyE+qWnIsVGmvmEKlqwuVSjsCm7DZg= github.com/gofiber/fiber/v2 v2.52.9 h1:YjKl5DOiyP3j0mO61u3NTmK7or8GzzWzCFzkboyP5cw= github.com/gofiber/fiber/v2 v2.52.9/go.mod h1:YEcBbO/FB+5M1IZNBP9FO3J9281zgPAreiI1oqg8nDw= github.com/gofiber/websocket/v2 v2.2.1 h1:C9cjxvloojayOp9AovmpQrk8VqvVnT8Oao3+IUygH7w= @@ -120,6 +114,8 @@ github.com/golang/protobuf v1.4.0/go.mod h1:jodUvKwWbYaEsadDk5Fwe5c77LiNKVO9IDvq github.com/golang/protobuf v1.4.1/go.mod h1:U8fpvMrcmy5pZrNK1lt4xCsGvpyWQ/VVv6QDs8UjoX8= github.com/golang/protobuf v1.4.2/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= github.com/golang/protobuf v1.4.3/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= +github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek= +github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps= github.com/google/btree v0.0.0-20180813153112-4030bb1f1f0c/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= github.com/google/btree v1.0.0/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= @@ -133,7 +129,6 @@ github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/ github.com/google/go-cmp v0.5.4/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= -github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/martian v2.1.0+incompatible/go.mod h1:9I4somxYTbIHy5NJKHRl3wXiIaQGbYVAs8BPL6v8lEs= github.com/google/martian/v3 v3.0.0/go.mod h1:y5Zk1BBys9G+gd6Jrk0W3cC1+ELVxBWuIGO+w/tUAp0= github.com/google/martian/v3 v3.1.0/go.mod h1:y5Zk1BBys9G+gd6Jrk0W3cC1+ELVxBWuIGO+w/tUAp0= @@ -162,28 +157,23 @@ github.com/hashicorp/hcl v1.0.0 h1:0Anlzjpi4vEasTeNFn2mLJgTSwt0+6sfsiTG8qcWGx4= github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ= github.com/ianlancetaylor/demangle v0.0.0-20181102032728-5e5cf60278f6/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc= github.com/ianlancetaylor/demangle v0.0.0-20200824232613-28f6c0f3b639/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc= -github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= -github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= +github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= +github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= +github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ= +github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= github.com/jstemmer/go-junit-report v0.0.0-20190106144839-af01ea7f8024/go.mod h1:6v2b51hI/fHJwM22ozAgKL4VKDeJcHhJFhtBdhmNjmU= github.com/jstemmer/go-junit-report v0.9.1/go.mod h1:Brl9GWCQeLvo8nXZwPNNblvFj/XSXhF0NWZEnDohbsk= github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= github.com/klauspost/compress v1.17.9 h1:6KIumPrER1LHsvBVuDa0r5xaG0Es51mhhB9BQB2qeMA= github.com/klauspost/compress v1.17.9/go.mod h1:Di0epgTjJY877eYKx5yC51cX2A2Vl2ibi7bDH9ttBbw= -github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= -github.com/klauspost/cpuid/v2 v2.2.7 h1:ZWSB3igEs+d0qvnxR/ZBzXVmxkgt8DdzP6m9pfuVLDM= -github.com/klauspost/cpuid/v2 v2.2.7/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws= -github.com/knz/go-libedit v1.10.1/go.mod h1:MZTVkCWyz0oBc7JOWP3wNAzd002ZbM/5hgShxwh4x8M= github.com/kr/fs v0.1.0/go.mod h1:FFnZGqtBN9Gxj7eW1uZ42v5BccTP0vu6NEaFoC2HwRg= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= -github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= -github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ= -github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI= github.com/magiconair/properties v1.8.7 h1:IeQXZAiQcpL9mgcAe1Nu6cX9LLw6ExEHKjN0VQdvPDY= github.com/magiconair/properties v1.8.7/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0= github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= @@ -195,11 +185,6 @@ github.com/mattn/go-runewidth v0.0.16 h1:E5ScNMtiwvlvB5paMFdw9p4kSQzbXFikJ5SQO6T github.com/mattn/go-runewidth v0.0.16/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w= github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY= github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= -github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= -github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg= -github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= -github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M= -github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= github.com/ollama/ollama v0.11.10 h1:J9zaoTPwIXOrYXCRAqI7rV4cJ+FOMuQc/vBqQ5GIdWg= github.com/ollama/ollama v0.11.10/go.mod h1:9+1//yWPsDE2u+l1a5mpaKrYw4VdnSsRU3ioq5BvMms= github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM= @@ -210,11 +195,13 @@ github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZN github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= +github.com/redis/go-redis/v9 v9.14.0 h1:u4tNCjXOyzfgeLN+vAZaW1xUooqWDqVEsZN0U01jfAE= +github.com/redis/go-redis/v9 v9.14.0/go.mod h1:huWgSWd8mW6+m0VPhJjSSQ+d6Nh1VICQ6Q5lHuCH/Iw= github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY= github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= -github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8= -github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= +github.com/rogpeppe/go-internal v1.11.0 h1:cWPaGQEPrBb5/AsnsZesgZZ9yb1OQ+GOISoDNXVBh4M= +github.com/rogpeppe/go-internal v1.11.0/go.mod h1:ddIwULY96R17DhadqLgMfk9H9tvdUzkipdSkR5nkCZA= github.com/sagikazarmark/locafero v0.3.0 h1:zT7VEGWC2DTflmccN/5T1etyKvxSxpHsjb9cJvm4SvQ= github.com/sagikazarmark/locafero v0.3.0/go.mod h1:w+v7UsPNFwzF1cHuOajOOzoq4U7v/ig1mpRjqV+Bu1U= github.com/sagikazarmark/slog-shim v0.1.0 h1:diDBnUNK9N/354PgrxMywXnAwEr1QZcOr6gto+ugjYE= @@ -241,17 +228,12 @@ github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5 github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= -github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8= github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU= -github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI= -github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08= -github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE= -github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg= github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= github.com/valyala/fasthttp v1.51.0 h1:8b30A5JlZ6C7AS81RsWjYMQmrZG6feChmgAolCl1SqA= @@ -268,15 +250,10 @@ go.opencensus.io v0.22.2/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw= go.opencensus.io v0.22.3/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw= go.opencensus.io v0.22.4/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw= go.opencensus.io v0.22.5/go.mod h1:5pWMHQbX5EPX2/62yrJeAkowc+lfs/XD7Uxpq3pI6kk= -go.opentelemetry.io/otel v1.38.0 h1:RkfdswUDRimDg0m2Az18RKOsnI8UDzppJAtj01/Ymk8= -go.opentelemetry.io/otel v1.38.0/go.mod h1:zcmtmQ1+YmQM9wrNsTGV/q/uyusom3P8RxwExxkZhjM= go.uber.org/atomic v1.9.0 h1:ECmE8Bn/WFTYwEW/bpKD3M8VtR/zQVbavAoalC1PYyE= go.uber.org/atomic v1.9.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= go.uber.org/multierr v1.9.0 h1:7fIwc/ZtS0q++VgcfqFDxSBZVv/Xo49/SYnDFupUwlI= go.uber.org/multierr v1.9.0/go.mod h1:X2jQV1h+kxSjClGpnseKVIxpmcjrj7MNnI0bnlfKTVQ= -golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= -golang.org/x/arch v0.8.0 h1:3wRIsP3pM4yUptoR96otTUOXI367OS0+c9eeRi9doIc= -golang.org/x/arch v0.8.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20190605123033-f99c8df09eb5/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= @@ -374,6 +351,8 @@ golang.org/x/sync v0.0.0-20200317015054-43a5402ce75a/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20200625203802-6e8e738ad208/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.12.0 h1:MHc5BpPuC30uJk597Ri8TV3CNZcTLu6B6z4lJy+g6Jw= +golang.org/x/sync v0.12.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= @@ -411,7 +390,6 @@ golang.org/x/sys v0.0.0-20210423185535-09eb48e85fd7/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220908164124-27713097b956/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.31.0 h1:ioabZlmFYtWhL+TRYpcnNlLwhyxaM9kWTDEmfnprqik= golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= @@ -544,6 +522,8 @@ google.golang.org/genproto v0.0.0-20201210142538-e3217bee35cc/go.mod h1:FWY/as6D google.golang.org/genproto v0.0.0-20201214200347-8c77b98c765d/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no= google.golang.org/genproto v0.0.0-20210108203827-ffc7fda8c3d7/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no= google.golang.org/genproto v0.0.0-20210226172003-ab064af71705/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no= +google.golang.org/genproto/googleapis/rpc v0.0.0-20240102182953-50ed04b92917 h1:6G8oQ016D88m1xAKljMlBOOGWDZkes4kMhgGFlf8WcQ= +google.golang.org/genproto/googleapis/rpc v0.0.0-20240102182953-50ed04b92917/go.mod h1:xtjpI3tXFPP051KaWnhvxkiubL/6dJ18vLVf7q2pTOU= google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= google.golang.org/grpc v1.20.1/go.mod h1:10oTOabMzJvdu6/UiuZezV6QK5dSlG84ov/aaiqXj38= google.golang.org/grpc v1.21.1/go.mod h1:oYelfM1adQP15Ek0mdvEgi9Df8B9CZIaU1084ijfRaM= @@ -560,6 +540,8 @@ google.golang.org/grpc v1.31.1/go.mod h1:N36X2cJ7JwdamYAgDz+s+rVMFjt3numwzf/HckM google.golang.org/grpc v1.33.2/go.mod h1:JMHMWHQWaTccqQQlmk3MJZS+GWXOdAesneDmEnv2fbc= google.golang.org/grpc v1.34.0/go.mod h1:WotjhfgOW/POjDeRt8vscBtXq+2VjORFy659qA51WJ8= google.golang.org/grpc v1.35.0/go.mod h1:qjiiYl8FncCW8feJPdyg3v6XW24KsRHe+dy9BAGRRjU= +google.golang.org/grpc v1.61.1 h1:kLAiWrZs7YeDM6MumDe7m3y4aM6wacLzM1Y/wiLP9XY= +google.golang.org/grpc v1.61.1/go.mod h1:VUbo7IFqmF1QtCAstipjG0GIoq49KvMe9+h1jFLBNJs= google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0= google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM= @@ -583,6 +565,10 @@ gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gorm.io/driver/mysql v1.6.0 h1:eNbLmNTpPpTOVZi8MMxCi2aaIm0ZpInbORNXDwyLGvg= +gorm.io/driver/mysql v1.6.0/go.mod h1:D/oCC2GWK3M/dqoLxnOlaNKmXz8WNTfcS9y5ovaSqKo= +gorm.io/gorm v1.31.0 h1:0VlycGreVhK7RF/Bwt51Fk8v0xLiiiFdbGDPIZQ7mJY= +gorm.io/gorm v1.31.0/go.mod h1:XyQVbO2k6YkOis7C2437jSit3SsDK72s7n7rsSHd+Gs= honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190106161140-3f1c8253044a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190418001031-e561f6794a2a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= @@ -590,8 +576,8 @@ honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWh honnef.co/go/tools v0.0.1-2019.2.3/go.mod h1:a3bituU0lyd329TUQxRnasdCoJDkEUEAqEt0JzvZhAg= honnef.co/go/tools v0.0.1-2020.1.3/go.mod h1:X/FiERA/W4tHapMX5mGpAtMSVEeEUOyHaw9vFzvIQ3k= honnef.co/go/tools v0.0.1-2020.1.4/go.mod h1:X/FiERA/W4tHapMX5mGpAtMSVEeEUOyHaw9vFzvIQ3k= -nullprogram.com/x/optparse v1.0.0/go.mod h1:KdyPE+Igbe0jQUrVfMqDMeJQIJZEuyV7pjYmp6pbG50= rsc.io/binaryregexp v0.2.0/go.mod h1:qTv7/COck+e2FymRvadv62gMdZztPaShugOCi3I+8D8= -rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4= rsc.io/quote/v3 v3.1.0/go.mod h1:yEA65RcK8LyAZtP9Kv3t0HmxON59tX3rD+tICJqUlj0= rsc.io/sampler v1.3.0/go.mod h1:T1hPZKmBbMNahiBKFy5HrXp6adAjACjK9JXDnKaTXpA= +xorm.io/builder v0.3.13 h1:a3jmiVVL19psGeXx8GIurTp7p0IIgqeDmwhcR6BAOAo= +xorm.io/builder v0.3.13/go.mod h1:aUW0S9eb9VCaPohFCH3j7czOx1PMW3i1HrSzbLYGBSE= diff --git a/internal/biz/router.go b/internal/biz/router.go index 3987001..ddc24aa 100644 --- a/internal/biz/router.go +++ b/internal/biz/router.go @@ -1,41 +1,98 @@ package biz import ( + "ai_scheduler/internal/config" + errors "ai_scheduler/internal/data/error" + "ai_scheduler/internal/data/impl" + "ai_scheduler/internal/data/model" "ai_scheduler/internal/entitys" "ai_scheduler/internal/tools" + "ai_scheduler/tmpl/dataTemp" + "fmt" + "sync" "context" "encoding/json" "log" "strings" + + "github.com/gofiber/websocket/v2" + "xorm.io/builder" ) // AiRouterService 智能路由服务 type AiRouterService struct { aiClient entitys.AIClient toolManager *tools.Manager + sessionImpl *impl.SessionImpl + conf *config.Config } // NewRouterService 创建路由服务 -func NewAiRouterBiz(aiClient entitys.AIClient, toolManager *tools.Manager) entitys.RouterService { +func NewAiRouterBiz(aiClient entitys.AIClient, toolManager *tools.Manager, sessionImpl *impl.SessionImpl, conf *config.Config) entitys.RouterService { return &AiRouterService{ aiClient: aiClient, toolManager: toolManager, + sessionImpl: sessionImpl, + conf: conf, } } // Route 执行智能路由 func (r *AiRouterService) Route(ctx context.Context, req *entitys.ChatRequest) (*entitys.ChatResponse, error) { + + return nil, nil +} + +// Route 执行智能路由 +func (r *AiRouterService) RouteWithSocket(c *websocket.Conn, req *entitys.ChatSockRequest) error { + session := c.Headers("x-session", "") + if len(session) == 0 { + return errors.SessionNotFound + } + auth := c.Headers("x-authorization", "") + + if len(auth) == 0 { + return errors.AuthNotFound + } + key := c.Headers("x-app-key", "") + if len(key) == 0 { + return errors.KeyNotFound + } + var sysInfo model.AiSy + cond := builder.NewCond() + cond = cond.And(builder.Eq{"app_key": key}) + err := r.sessionImpl.GetOneBySearchToStrut(&cond, &sysInfo) + if err != nil { + return errors.SysNotFound + } + cond = builder.NewCond() + cond = cond.And(builder.Eq{"session_id": session}) + history, _, err := r.sessionImpl.GetList(&cond, &dataTemp.ReqPageBo{Limit: r.conf.Sys.SessionLen}) + if err != nil { + return errors.SystemError + } + fmt.Printf("history:%v\n", history) + var ( + messages = make([]entitys.Message, 0) + onece sync.Once + ) + onece.Do(func() { + + messages = append(messages, entitys.Message{ + Role: "system", + Content: r.buildSystemPrompt(sysInfo.SysPrompt), + }) + }) + messages = append(messages, entitys.Message{}, entitys.Message{ + Role: "assistant", + Content: r.buildIntentPrompt(req.Text), + }, entitys.Message{ + Role: "user", + Content: req.Text, + }) // 构建消息 //messages := []entitys.Message{ - // // { - // // Role: "system", - // // Content: r.buildSystemPrompt(), - // // }, - // { - // Role: "assistant", - // Content: r.buildIntentPrompt(req.UserInput), - // }, // { // Role: "user", // Content: req.UserInput, @@ -113,35 +170,36 @@ func (r *AiRouterService) Route(ctx context.Context, req *entitys.ChatRequest) ( //log.Printf("Router processed request: %s, used %d tools", req.UserInput, len(toolResults)) //return finalResponse, nil - return nil, nil + return nil } // buildSystemPrompt 构建系统提示词 -func (r *AiRouterService) buildSystemPrompt() string { - prompt := `你是一个智能路由系统,你的任务是根据用户输入判断用户的意图,并且执行对应的任务。` +func (r *AiRouterService) buildSystemPrompt(prompt string) string { + if len(prompt) == 0 { + prompt = "[system] 你是一个智能路由系统,核心职责是 **精准解析用户意图并路由至对应任务模块**\n[rule]\n1.返回以下格式的JSON:{ \"index\": \"工具索引index\", \"confidence\": 0.0-1.0,\"reasoning\": \"判断理由\"}\n2.严格返回字符串格式,禁用markdown格式返回\n3.只返回json字符串,不包含任何其他解释性文字\n4.当用户意图非常不清晰时使用,尝试进行追问具体希望查询内容" + } return prompt } // buildIntentPrompt 构建意图识别提示词 func (r *AiRouterService) buildIntentPrompt(userInput string) string { - prompt := `请分析以下用户输入,判断用户的意图类型。 - -用户输入:{user_input} - -意图类型说明: -1. order_diagnosis - 订单诊断:用户想要查询、诊断或了解订单相关信息 -2. knowledge_qa - 知识问答:用户想要进行一般性问答或获取知识信息 - -- 当用户意图不够清晰且不匹配 knowledge_qa 以外意图时,使用knowledge_qa -- 当用户意图非常不清晰时使用 unknown - -请只返回以下格式的JSON: -{ - "intent": "order_diagnosis" | "knowledge_qa" | "unknown", - "confidence": 0.0-1.0, - "reasoning": "判断理由" -} + prompt := `##任务 +分析用户输入,判断用户的意图类型,没有使用Markdown格式的json格式回复 +##意图类型 +1. product_diagnosis - 商品诊断:用户想要查询、诊断或了解商品相关信息 +2. order_diagnosis - 订单诊断:用户想要查询、诊断或了解订单相关信息 +3. knowledge_qa - 知识问答:用户想要进行一般性问答或获取知识信息 +##判断规则 +1.当用户意图不够清晰且不匹配 knowledge_qa 以外意图时,使用knowledge_qa +2.当用户意图非常不清晰时使用 unknown +##格式要求 +1.返回以下格式的JSON: +{ "intent": "product_diagnosis" | "order_diagnosis" | "knowledge_qa" | "unknown", "confidence": 0.0-1.0,"reasoning": "判断理由"} +2.严格返回字符串格式,禁用markdown格式返回 +3.只返回json字符串,不包含任何其他解释性文字 +## 用户当前的问题是: +{user_input} ` prompt = strings.ReplaceAll(prompt, "{user_input}", userInput) diff --git a/internal/config/config.go b/internal/config/config.go index 2ca92c2..8e6906e 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -11,8 +11,16 @@ import ( type Config struct { Server ServerConfig `mapstructure:"server"` Ollama OllamaConfig `mapstructure:"ollama"` + Sys SysConfig `mapstructure:"sys"` Tools ToolsConfig `mapstructure:"tools"` Logging LoggingConfig `mapstructure:"logging"` + Redis *Redis `protobuf:"bytes,1,opt,name=Redis,proto3" json:"Redis,omitempty"` + DB *DB `protobuf:"bytes,3,opt,name=TransDB,proto3" json:"TransDB,omitempty"` +} + +// SysConfig 系统配置 +type SysConfig struct { + SessionLen int `mapstructure:"session_len"` } // ServerConfig 服务器配置 @@ -28,6 +36,27 @@ type OllamaConfig struct { Timeout time.Duration `mapstructure:"timeout"` } +type Redis struct { + Host string `protobuf:"bytes,1,opt,name=host,proto3" json:"host,omitempty"` + Type string `protobuf:"bytes,2,opt,name=type,proto3" json:"type,omitempty"` + Pass string `protobuf:"bytes,3,opt,name=pass,proto3" json:"pass,omitempty"` + Key string `protobuf:"bytes,4,opt,name=key,proto3" json:"key,omitempty"` + Tls int32 `protobuf:"varint,5,opt,name=tls,proto3" json:"tls,omitempty"` + Db int32 `protobuf:"varint,6,opt,name=db,proto3" json:"db,omitempty"` + MaxIdle int32 `protobuf:"varint,7,opt,name=maxIdle,proto3" json:"maxIdle,omitempty"` + PoolSize int32 `protobuf:"varint,8,opt,name=poolSize,proto3" json:"poolSize,omitempty"` + MaxIdleTime int32 `protobuf:"varint,9,opt,name=maxIdleTime,proto3" json:"maxIdleTime,omitempty"` +} + +type DB struct { + Driver string `protobuf:"bytes,1,opt,name=driver,proto3" json:"driver,omitempty"` + Source string `protobuf:"bytes,2,opt,name=source,proto3" json:"source,omitempty"` + MaxIdle int32 `protobuf:"varint,3,opt,name=maxIdle,proto3" json:"maxIdle,omitempty"` + MaxOpen int32 `protobuf:"varint,4,opt,name=maxOpen,proto3" json:"maxOpen,omitempty"` + MaxLifetime int32 `protobuf:"varint,5,opt,name=maxLifetime,proto3" json:"maxLifetime,omitempty"` + IsDebug bool `protobuf:"varint,6,opt,name=isDebug,proto3" json:"isDebug,omitempty"` +} + // ToolsConfig 工具配置 type ToolsConfig struct { Weather ToolConfig `mapstructure:"weather"` diff --git a/internal/data/constant/const.go b/internal/data/constant/const.go index 19f4670..a93ee03 100644 --- a/internal/data/constant/const.go +++ b/internal/data/constant/const.go @@ -1,3 +1,9 @@ package constant -const () +type ConnStatus int8 + +const ( + ConnStatusClosed ConnStatus = iota + ConnStatusNormal + ConnStatusIgnore +) diff --git a/internal/data/error/error_code.go b/internal/data/error/error_code.go index ef8c36a..2813bc8 100644 --- a/internal/data/error/error_code.go +++ b/internal/data/error/error_code.go @@ -7,9 +7,11 @@ var ( SystemError = &BusinessErr{code: "0005", message: "系统错误"} SupplierNotFound = &BusinessErr{code: "0006", message: "供应商不存在"} - SupplierApiError = &BusinessErr{code: "0007", message: "第三方供应商接口报错"} - - InvalidParam = &BusinessErr{code: InvalidParamCode, message: "无效参数"} + SessionNotFound = &BusinessErr{code: "0007", message: "未找到会话信息"} + AuthNotFound = &BusinessErr{code: "0008", message: "身份验证失败"} + KeyNotFound = &BusinessErr{code: "0009", message: "身份验证失败"} + SysNotFound = &BusinessErr{code: "0010", message: "未找到系统信息"} + InvalidParam = &BusinessErr{code: InvalidParamCode, message: "无效参数"} ) const ( @@ -41,7 +43,3 @@ func NewBusinessErr(code string, message string) *BusinessErr { func (e *BusinessErr) Wrap(err error) *BusinessErr { return NewBusinessErr(e.code, err.Error()) } - -func SupplierApiErrorDiy(message string) *BusinessErr { - return &BusinessErr{code: SupplierApiError.code, message: message} -} diff --git a/internal/data/impl/notify_data_impl.go b/internal/data/impl/notify_data_impl.go deleted file mode 100644 index ce5d4b3..0000000 --- a/internal/data/impl/notify_data_impl.go +++ /dev/null @@ -1,15 +0,0 @@ -package impl - -import ( - "trans_hub/app/physical_goods/supplier/goods/service/internal/data/model" - "trans_hub/tmpl/dataTemp" - "trans_hub/utils" -) - -type NotifyDataImpl struct { - dataTemp.DataTemp -} - -func NewOrderImpl(db *utils.Db) *NotifyDataImpl { - return &NotifyDataImpl{*dataTemp.NewDataTemp(db, new(model.SupplierNotifyDatum))} -} diff --git a/internal/data/impl/provider_set.go b/internal/data/impl/provider_set.go index 92fbe40..3d8d477 100644 --- a/internal/data/impl/provider_set.go +++ b/internal/data/impl/provider_set.go @@ -4,4 +4,4 @@ import ( "github.com/google/wire" ) -var ProviderImpl = wire.NewSet(NewOrderImpl) +var ProviderImpl = wire.NewSet(NewSessionImpl) diff --git a/internal/data/impl/session_impl.go b/internal/data/impl/session_impl.go new file mode 100644 index 0000000..f0ff5e2 --- /dev/null +++ b/internal/data/impl/session_impl.go @@ -0,0 +1,15 @@ +package impl + +import ( + "ai_scheduler/internal/data/model" + "ai_scheduler/tmpl/dataTemp" + "ai_scheduler/utils" +) + +type SessionImpl struct { + dataTemp.DataTemp +} + +func NewSessionImpl(db *utils.Db) *SessionImpl { + return &SessionImpl{*dataTemp.NewDataTemp(db, new(model.AiSession))} +} diff --git a/internal/data/impl/sys_impl.go b/internal/data/impl/sys_impl.go new file mode 100644 index 0000000..4370f33 --- /dev/null +++ b/internal/data/impl/sys_impl.go @@ -0,0 +1,15 @@ +package impl + +import ( + "ai_scheduler/internal/data/model" + "ai_scheduler/tmpl/dataTemp" + "ai_scheduler/utils" +) + +type SysImpl struct { + dataTemp.DataTemp +} + +func NewSysImpl(db *utils.Db) *SysImpl { + return &SysImpl{*dataTemp.NewDataTemp(db, new(model.AiSy))} +} diff --git a/internal/data/impl/task_impl.go b/internal/data/impl/task_impl.go new file mode 100644 index 0000000..8b3f246 --- /dev/null +++ b/internal/data/impl/task_impl.go @@ -0,0 +1,15 @@ +package impl + +import ( + "ai_scheduler/internal/data/model" + "ai_scheduler/tmpl/dataTemp" + "ai_scheduler/utils" +) + +type TaskImpl struct { + dataTemp.DataTemp +} + +func NewTaskImpl(db *utils.Db) *TaskImpl { + return &TaskImpl{*dataTemp.NewDataTemp(db, new(model.AiTask))} +} diff --git a/internal/data/model/ai_session.gen.go b/internal/data/model/ai_session.gen.go new file mode 100644 index 0000000..0206ee0 --- /dev/null +++ b/internal/data/model/ai_session.gen.go @@ -0,0 +1,28 @@ +// Code generated by gorm.io/gen. DO NOT EDIT. +// Code generated by gorm.io/gen. DO NOT EDIT. +// Code generated by gorm.io/gen. DO NOT EDIT. + +package model + +import ( + "time" +) + +const TableNameAiSession = "ai_session" + +// AiSession mapped from table +type AiSession struct { + SysID int32 `gorm:"column:sys_id;not null" json:"sys_id"` + SessionID string `gorm:"column:session_id;primaryKey" json:"session_id"` + KnowlegeSessionID string `gorm:"column:knowlege_session_id;not null" json:"knowlege_session_id"` + Title string `gorm:"column:title;not null" json:"title"` + CreateAt time.Time `gorm:"column:create_at;default:CURRENT_TIMESTAMP" json:"create_at"` + UpdateAt time.Time `gorm:"column:update_at;default:CURRENT_TIMESTAMP" json:"update_at"` + Status int32 `gorm:"column:status;not null" json:"status"` + DeleteAt time.Time `gorm:"column:delete_at" json:"delete_at"` +} + +// TableName AiSession's table name +func (*AiSession) TableName() string { + return TableNameAiSession +} diff --git a/internal/data/model/ai_sys.gen.go b/internal/data/model/ai_sys.gen.go new file mode 100644 index 0000000..0f2fdff --- /dev/null +++ b/internal/data/model/ai_sys.gen.go @@ -0,0 +1,30 @@ +// Code generated by gorm.io/gen. DO NOT EDIT. +// Code generated by gorm.io/gen. DO NOT EDIT. +// Code generated by gorm.io/gen. DO NOT EDIT. + +package model + +import ( + "time" +) + +const TableNameAiSy = "ai_sys" + +// AiSy mapped from table +type AiSy struct { + SysID int32 `gorm:"column:sys_id;primaryKey;autoIncrement:true" json:"sys_id"` + SysName string `gorm:"column:sys_name;not null" json:"sys_name"` + AppKey string `gorm:"column:app_key;not null" json:"app_key"` + KnowlegeTenantKey string `gorm:"column:knowlege_tenant_key;not null" json:"knowlege_tenant_key"` + KnowlegeBaseID string `gorm:"column:knowlege_base_id;not null" json:"knowlege_base_id"` + SysPrompt string `gorm:"column:sys_prompt" json:"sys_prompt"` + CreateAt time.Time `gorm:"column:create_at;default:CURRENT_TIMESTAMP" json:"create_at"` + UpdateAt time.Time `gorm:"column:update_at;default:CURRENT_TIMESTAMP" json:"update_at"` + Status int32 `gorm:"column:status;not null" json:"status"` + DeleteAt time.Time `gorm:"column:delete_at" json:"delete_at"` +} + +// TableName AiSy's table name +func (*AiSy) TableName() string { + return TableNameAiSy +} diff --git a/internal/data/model/ai_task.gen.go b/internal/data/model/ai_task.gen.go new file mode 100644 index 0000000..a962e8b --- /dev/null +++ b/internal/data/model/ai_task.gen.go @@ -0,0 +1,29 @@ +// Code generated by gorm.io/gen. DO NOT EDIT. +// Code generated by gorm.io/gen. DO NOT EDIT. +// Code generated by gorm.io/gen. DO NOT EDIT. + +package model + +import ( + "time" +) + +const TableNameAiTask = "ai_task" + +// AiTask mapped from table +type AiTask struct { + TaskID int32 `gorm:"column:task_id;primaryKey" json:"task_id"` + SysID int32 `gorm:"column:sys_id;not null" json:"sys_id"` + Name string `gorm:"column:name;not null" json:"name"` + Index string `gorm:"column:index;not null" json:"index"` + Desc string `gorm:"column:desc;not null" json:"desc"` + CreateAt time.Time `gorm:"column:create_at;default:CURRENT_TIMESTAMP" json:"create_at"` + UpdateAt time.Time `gorm:"column:update_at;default:CURRENT_TIMESTAMP" json:"update_at"` + Status int32 `gorm:"column:status;not null;default:1" json:"status"` + DeleteAt time.Time `gorm:"column:delete_at" json:"delete_at"` +} + +// TableName AiTask's table name +func (*AiTask) TableName() string { + return TableNameAiTask +} diff --git a/internal/entitys/types.go b/internal/entitys/types.go index b9d7d72..9bcb886 100644 --- a/internal/entitys/types.go +++ b/internal/entitys/types.go @@ -3,6 +3,8 @@ package entitys import ( "context" "encoding/json" + + "github.com/gofiber/websocket/v2" ) // ChatRequest 聊天请求 @@ -80,4 +82,5 @@ type Message struct { // RouterService 路由服务接口 type RouterService interface { Route(ctx context.Context, req *ChatRequest) (*ChatResponse, error) + RouteWithSocket(c *websocket.Conn, req *ChatSockRequest) error } diff --git a/internal/middleware/auth.go b/internal/middleware/auth.go deleted file mode 100644 index c7cc178..0000000 --- a/internal/middleware/auth.go +++ /dev/null @@ -1,112 +0,0 @@ -package middleware - -import ( - "context" - "errors" - "log" - "net/http" - "slices" - "strings" - - "knowlege-lsxd/internal/config" - "knowlege-lsxd/internal/types" - "knowlege-lsxd/internal/types/interfaces" - - "github.com/gin-gonic/gin" -) - -// 无需认证的API列表 -var noAuthAPI = map[string][]string{ - "/api/v1/test-data": {"GET"}, - "/api/v1/tenants": {"POST"}, - "/api/v1/initialization/*": {"GET", "POST"}, -} - -// 检查请求是否在无需认证的API列表中 -func isNoAuthAPI(path string, method string) bool { - for api, methods := range noAuthAPI { - // 如果以*结尾,按照前缀匹配,否则按照全路径匹配 - if strings.HasSuffix(api, "*") { - if strings.HasPrefix(path, strings.TrimSuffix(api, "*")) && slices.Contains(methods, method) { - return true - } - } else if path == api && slices.Contains(methods, method) { - return true - } - } - return false -} - -// Auth 认证中间件 -func Auth(tenantService interfaces.TenantService, cfg *config.Config) gin.HandlerFunc { - return func(c *gin.Context) { - // ignore OPTIONS request - if c.Request.Method == "OPTIONS" { - c.Next() - return - } - - // 检查请求是否在无需认证的API列表中 - if isNoAuthAPI(c.Request.URL.Path, c.Request.Method) { - c.Next() - return - } - - // Get API Key from request header - apiKey := c.GetHeader("X-API-Key") - if apiKey == "" { - c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"}) - c.Abort() - return - } - - // Get tenant information - //tenantID, err := tenantService.ExtractTenantIDFromAPIKey(apiKey) - //if err != nil { - // c.JSON(http.StatusUnauthorized, gin.H{ - // "error": "Unauthorized: invalid API key format", - // }) - // c.Abort() - // return - //} - - // Verify API key validity (matches the one in database) - t, err := tenantService.GetTenantByApiKey(c.Request.Context(), apiKey) - if err != nil { - log.Printf("Error getting tenant by ID: %v, tenantID: %d, apiKey: %s", err, t.ID, apiKey) - c.JSON(http.StatusUnauthorized, gin.H{ - "error": "Unauthorized: invalid API key", - }) - c.Abort() - return - } - - if t == nil || t.APIKey != apiKey { - c.JSON(http.StatusUnauthorized, gin.H{ - "error": "Unauthorized: invalid API key", - }) - c.Abort() - return - } - - // Store tenant ID in context - c.Set(types.TenantIDContextKey.String(), t.ID) - c.Set(types.TenantInfoContextKey.String(), t) - c.Request = c.Request.WithContext( - context.WithValue( - context.WithValue(c.Request.Context(), types.TenantIDContextKey, t.ID), - types.TenantInfoContextKey, t, - ), - ) - c.Next() - } -} - -// GetTenantIDFromContext helper function to get tenant ID from context -func GetTenantIDFromContext(ctx context.Context) (uint, error) { - tenantID, ok := ctx.Value("tenantID").(uint) - if !ok { - return 0, errors.New("tenant ID not found in context") - } - return tenantID, nil -} diff --git a/internal/middleware/error_handler.go b/internal/middleware/error_handler.go deleted file mode 100644 index 12648ef..0000000 --- a/internal/middleware/error_handler.go +++ /dev/null @@ -1,45 +0,0 @@ -package middleware - -import ( - "net/http" - - "github.com/gin-gonic/gin" - - "ai_scheduler/internal/data/error" -) - -// ErrorHandler 是一个处理应用错误的中间件 -func ErrorHandler() gin.HandlerFunc { - return func(c *gin.Context) { - // 处理请求 - c.Next() - - // 检查是否有错误 - if len(c.Errors) > 0 { - // 获取最后一个错误 - err := c.Errors.Last().Err - - // 检查是否为应用错误 - if appErr, ok := errors.IsAppError(err); ok { - // 返回应用错误 - c.JSON(appErr.Code(), gin.H{ - "status": "error", - "error": gin.H{ - "code": appErr.Code(), - "message": appErr.Error(), - }, - }) - return - } - - // 处理其他类型的错误 - c.JSON(http.StatusInternalServerError, gin.H{ - "status": "error", - "error": gin.H{ - "code": http.StatusInternalServerError, - "message": "Internal server error", - }, - }) - } - } -} diff --git a/internal/middleware/logger.go b/internal/middleware/logger.go deleted file mode 100644 index 54370b7..0000000 --- a/internal/middleware/logger.go +++ /dev/null @@ -1,83 +0,0 @@ -package middleware - -import ( - "context" - "time" - - "github.com/gin-gonic/gin" - "github.com/google/uuid" - "knowlege-lsxd/internal/logger" - "knowlege-lsxd/internal/types" -) - -// RequestID middleware adds a unique request ID to the context -func RequestID() gin.HandlerFunc { - return func(c *gin.Context) { - // Get request ID from header or generate a new one - requestID := c.GetHeader("X-Request-ID") - if requestID == "" { - requestID = uuid.New().String() - } - // Set request ID in header - c.Header("X-Request-ID", requestID) - - // Set request ID in context - c.Set(types.RequestIDContextKey.String(), requestID) - - // Set logger in context - requestLogger := logger.GetLogger(c) - requestLogger = requestLogger.WithField("request_id", requestID) - c.Set(types.LoggerContextKey.String(), requestLogger) - - // Set request ID in the global context for logging - c.Request = c.Request.WithContext( - context.WithValue( - context.WithValue(c.Request.Context(), types.RequestIDContextKey, requestID), - types.LoggerContextKey, requestLogger, - ), - ) - - c.Next() - } -} - -// Logger middleware logs request details with request ID -func Logger() gin.HandlerFunc { - return func(c *gin.Context) { - start := time.Now() - path := c.Request.URL.Path - raw := c.Request.URL.RawQuery - - // Process request - c.Next() - - // Get request ID from context - requestID, exists := c.Get(types.RequestIDContextKey.String()) - if !exists { - requestID = "unknown" - } - - // Calculate latency - latency := time.Since(start) - - // Get client IP and status code - clientIP := c.ClientIP() - statusCode := c.Writer.Status() - method := c.Request.Method - - if raw != "" { - path = path + "?" + raw - } - - // Log with request ID - logger.GetLogger(c).Infof("[%s] %d | %3d | %13v | %15s | %s %s", - requestID, - statusCode, - c.Writer.Size(), - latency, - clientIP, - method, - path, - ) - } -} diff --git a/internal/middleware/recovery.go b/internal/middleware/recovery.go deleted file mode 100644 index 9b60f13..0000000 --- a/internal/middleware/recovery.go +++ /dev/null @@ -1,34 +0,0 @@ -package middleware - -import ( - "fmt" - "log" - "runtime/debug" - - "github.com/gin-gonic/gin" -) - -// Recovery is a middleware that recovers from panics -func Recovery() gin.HandlerFunc { - return func(c *gin.Context) { - defer func() { - if err := recover(); err != nil { - // Get request ID - requestID, _ := c.Get("RequestID") - - // Print stacktrace - stacktrace := debug.Stack() - // Log error - log.Printf("[PANIC] %s | %v | %s", requestID, err, stacktrace) - - // 返回500错误 - c.AbortWithStatusJSON(500, gin.H{ - "error": "Internal Server Error", - "message": fmt.Sprintf("%v", err), - }) - } - }() - - c.Next() - } -} diff --git a/internal/middleware/trace.go b/internal/middleware/trace.go deleted file mode 100644 index d74513f..0000000 --- a/internal/middleware/trace.go +++ /dev/null @@ -1,124 +0,0 @@ -package middleware - -import ( - "bytes" - "fmt" - "io" - "strings" - - "github.com/gin-gonic/gin" - "go.opentelemetry.io/otel/attribute" - "go.opentelemetry.io/otel/codes" - - "knowlege-lsxd/internal/tracing" - "knowlege-lsxd/internal/types" -) - -// Custom ResponseWriter to capture response content -type responseBodyWriter struct { - gin.ResponseWriter - body *bytes.Buffer -} - -// Override Write method to write response content to buffer and original writer -func (r responseBodyWriter) Write(b []byte) (int, error) { - r.body.Write(b) - return r.ResponseWriter.Write(b) -} - -// TracingMiddleware provides a Gin middleware that creates a trace span for each request -func TracingMiddleware() gin.HandlerFunc { - return func(c *gin.Context) { - // Extract trace context from request headers - propagator := tracing.GetTracer() - if propagator == nil { - c.Next() - return - } - - // Get request ID as Span ID - requestID := c.GetString(string(types.RequestIDContextKey)) - if requestID == "" { - requestID = c.GetHeader("X-Request-ID") - } - - // Create new span - spanName := fmt.Sprintf("%s %s", c.Request.Method, c.FullPath()) - ctx, span := tracing.ContextWithSpan(c.Request.Context(), spanName) - defer span.End() - - // Set basic span attributes - span.SetAttributes( - attribute.String("http.method", c.Request.Method), - attribute.String("http.url", c.Request.URL.String()), - attribute.String("http.path", c.FullPath()), - ) - - // Record request headers (optional, or selectively record important headers) - for key, values := range c.Request.Header { - // Skip sensitive or unnecessary headers - if strings.ToLower(key) == "authorization" || strings.ToLower(key) == "cookie" { - continue - } - span.SetAttributes(attribute.String("http.request.header."+key, strings.Join(values, ";"))) - } - - // Record request body (for POST/PUT/PATCH requests) - if c.Request.Method == "POST" || c.Request.Method == "PUT" || c.Request.Method == "PATCH" { - if c.Request.Body != nil { - bodyBytes, _ := io.ReadAll(c.Request.Body) - span.SetAttributes(attribute.String("http.request.body", string(bodyBytes))) - // Reset request body because ReadAll consumes the Reader content - c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) - } - } - - // Record query parameters - if len(c.Request.URL.RawQuery) > 0 { - span.SetAttributes(attribute.String("http.request.query", c.Request.URL.RawQuery)) - } - - // Set request context with span context - c.Request = c.Request.WithContext(ctx) - - // Store tracing context in Gin context - c.Set("trace.span", span) - c.Set("trace.ctx", ctx) - - // Create response body capturer - responseBody := &bytes.Buffer{} - responseWriter := &responseBodyWriter{ - ResponseWriter: c.Writer, - body: responseBody, - } - c.Writer = responseWriter - - // Process request - c.Next() - - // Set response status code - statusCode := c.Writer.Status() - span.SetAttributes(attribute.Int("http.status_code", statusCode)) - - // Record response body - responseContent := responseBody.String() - if len(responseContent) > 0 { - span.SetAttributes(attribute.String("http.response.body", responseContent)) - } - - // Record response headers (optional, or selectively record important headers) - for key, values := range c.Writer.Header() { - span.SetAttributes(attribute.String("http.response.header."+key, strings.Join(values, ";"))) - } - - // Mark as error if status code >= 400 - if statusCode >= 400 { - span.SetStatus(codes.Error, fmt.Sprintf("HTTP %d", statusCode)) - if err := c.Errors.Last(); err != nil { - span.RecordError(err.Err) - } - } else { - span.SetStatus(codes.Ok, "") - } - } -} diff --git a/internal/pkg/func.go b/internal/pkg/func.go new file mode 100644 index 0000000..c1caffe --- /dev/null +++ b/internal/pkg/func.go @@ -0,0 +1 @@ +package pkg diff --git a/internal/pkg/gorm.go b/internal/pkg/gorm.go new file mode 100644 index 0000000..5daecb6 --- /dev/null +++ b/internal/pkg/gorm.go @@ -0,0 +1,25 @@ +package pkg + +import ( + "ai_scheduler/internal/config" + "ai_scheduler/internal/pkg/utils_gorm" + + "gorm.io/gorm" +) + +type Db struct { + Client *gorm.DB +} + +func NewGormDb(c *config.Config) (*Db, func()) { + transDBClient, mf := utils_gorm.DBConn(c.DB) + //directDBClient, df := directDB(c, hLog) + cleanup := func() { + mf() + //df() + } + return &Db{ + Client: transDBClient, + //DirectDBClient: directDBClient, + }, cleanup +} diff --git a/internal/pkg/mapstructure/decode_hooks.go b/internal/pkg/mapstructure/decode_hooks.go new file mode 100644 index 0000000..3a754ca --- /dev/null +++ b/internal/pkg/mapstructure/decode_hooks.go @@ -0,0 +1,279 @@ +package mapstructure + +import ( + "encoding" + "errors" + "fmt" + "net" + "reflect" + "strconv" + "strings" + "time" +) + +// typedDecodeHook takes a raw DecodeHookFunc (an interface{}) and turns +// it into the proper DecodeHookFunc type, such as DecodeHookFuncType. +func typedDecodeHook(h DecodeHookFunc) DecodeHookFunc { + // Create variables here so we can reference them with the reflect pkg + var f1 DecodeHookFuncType + var f2 DecodeHookFuncKind + var f3 DecodeHookFuncValue + + // Fill in the variables into this interface and the rest is done + // automatically using the reflect package. + potential := []interface{}{f1, f2, f3} + + v := reflect.ValueOf(h) + vt := v.Type() + for _, raw := range potential { + pt := reflect.ValueOf(raw).Type() + if vt.ConvertibleTo(pt) { + return v.Convert(pt).Interface() + } + } + + return nil +} + +// DecodeHookExec executes the given decode hook. This should be used +// since it'll naturally degrade to the older backwards compatible DecodeHookFunc +// that took reflect.Kind instead of reflect.Type. +func DecodeHookExec( + raw DecodeHookFunc, + from reflect.Value, to reflect.Value) (interface{}, error) { + + switch f := typedDecodeHook(raw).(type) { + case DecodeHookFuncType: + return f(from.Type(), to.Type(), from.Interface()) + case DecodeHookFuncKind: + return f(from.Kind(), to.Kind(), from.Interface()) + case DecodeHookFuncValue: + return f(from, to) + default: + return nil, errors.New("invalid decode hook signature") + } +} + +// ComposeDecodeHookFunc creates a single DecodeHookFunc that +// automatically composes multiple DecodeHookFuncs. +// +// The composed funcs are called in order, with the result of the +// previous transformation. +func ComposeDecodeHookFunc(fs ...DecodeHookFunc) DecodeHookFunc { + return func(f reflect.Value, t reflect.Value) (interface{}, error) { + var err error + data := f.Interface() + + newFrom := f + for _, f1 := range fs { + data, err = DecodeHookExec(f1, newFrom, t) + if err != nil { + return nil, err + } + newFrom = reflect.ValueOf(data) + } + + return data, nil + } +} + +// OrComposeDecodeHookFunc executes all input hook functions until one of them returns no error. In that case its value is returned. +// If all hooks return an error, OrComposeDecodeHookFunc returns an error concatenating all error messages. +func OrComposeDecodeHookFunc(ff ...DecodeHookFunc) DecodeHookFunc { + return func(a, b reflect.Value) (interface{}, error) { + var allErrs string + var out interface{} + var err error + + for _, f := range ff { + out, err = DecodeHookExec(f, a, b) + if err != nil { + allErrs += err.Error() + "\n" + continue + } + + return out, nil + } + + return nil, errors.New(allErrs) + } +} + +// StringToSliceHookFunc returns a DecodeHookFunc that converts +// string to []string by splitting on the given sep. +func StringToSliceHookFunc(sep string) DecodeHookFunc { + return func( + f reflect.Kind, + t reflect.Kind, + data interface{}) (interface{}, error) { + if f != reflect.String || t != reflect.Slice { + return data, nil + } + + raw := data.(string) + if raw == "" { + return []string{}, nil + } + + return strings.Split(raw, sep), nil + } +} + +// StringToTimeDurationHookFunc returns a DecodeHookFunc that converts +// strings to time.Duration. +func StringToTimeDurationHookFunc() DecodeHookFunc { + return func( + f reflect.Type, + t reflect.Type, + data interface{}) (interface{}, error) { + if f.Kind() != reflect.String { + return data, nil + } + if t != reflect.TypeOf(time.Duration(5)) { + return data, nil + } + + // Convert it by parsing + return time.ParseDuration(data.(string)) + } +} + +// StringToIPHookFunc returns a DecodeHookFunc that converts +// strings to net.IP +func StringToIPHookFunc() DecodeHookFunc { + return func( + f reflect.Type, + t reflect.Type, + data interface{}) (interface{}, error) { + if f.Kind() != reflect.String { + return data, nil + } + if t != reflect.TypeOf(net.IP{}) { + return data, nil + } + + // Convert it by parsing + ip := net.ParseIP(data.(string)) + if ip == nil { + return net.IP{}, fmt.Errorf("failed parsing ip %v", data) + } + + return ip, nil + } +} + +// StringToIPNetHookFunc returns a DecodeHookFunc that converts +// strings to net.IPNet +func StringToIPNetHookFunc() DecodeHookFunc { + return func( + f reflect.Type, + t reflect.Type, + data interface{}) (interface{}, error) { + if f.Kind() != reflect.String { + return data, nil + } + if t != reflect.TypeOf(net.IPNet{}) { + return data, nil + } + + // Convert it by parsing + _, net, err := net.ParseCIDR(data.(string)) + return net, err + } +} + +// StringToTimeHookFunc returns a DecodeHookFunc that converts +// strings to time.Time. +func StringToTimeHookFunc(layout string) DecodeHookFunc { + return func( + f reflect.Type, + t reflect.Type, + data interface{}) (interface{}, error) { + if f.Kind() != reflect.String { + return data, nil + } + if t != reflect.TypeOf(time.Time{}) { + return data, nil + } + + // Convert it by parsing + return time.Parse(layout, data.(string)) + } +} + +// WeaklyTypedHook is a DecodeHookFunc which adds support for weak typing to +// the decoder. +// +// Note that this is significantly different from the WeaklyTypedInput option +// of the DecoderConfig. +func WeaklyTypedHook( + f reflect.Kind, + t reflect.Kind, + data interface{}) (interface{}, error) { + dataVal := reflect.ValueOf(data) + switch t { + case reflect.String: + switch f { + case reflect.Bool: + if dataVal.Bool() { + return "1", nil + } + return "0", nil + case reflect.Float32: + return strconv.FormatFloat(dataVal.Float(), 'f', -1, 64), nil + case reflect.Int: + return strconv.FormatInt(dataVal.Int(), 10), nil + case reflect.Slice: + dataType := dataVal.Type() + elemKind := dataType.Elem().Kind() + if elemKind == reflect.Uint8 { + return string(dataVal.Interface().([]uint8)), nil + } + case reflect.Uint: + return strconv.FormatUint(dataVal.Uint(), 10), nil + } + } + + return data, nil +} + +func RecursiveStructToMapHookFunc() DecodeHookFunc { + return func(f reflect.Value, t reflect.Value) (interface{}, error) { + if f.Kind() != reflect.Struct { + return f.Interface(), nil + } + + var i interface{} = struct{}{} + if t.Type() != reflect.TypeOf(&i).Elem() { + return f.Interface(), nil + } + + m := make(map[string]interface{}) + t.Set(reflect.ValueOf(m)) + + return f.Interface(), nil + } +} + +// TextUnmarshallerHookFunc returns a DecodeHookFunc that applies +// strings to the UnmarshalText function, when the target type +// implements the encoding.TextUnmarshaler interface +func TextUnmarshallerHookFunc() DecodeHookFuncType { + return func( + f reflect.Type, + t reflect.Type, + data interface{}) (interface{}, error) { + if f.Kind() != reflect.String { + return data, nil + } + result := reflect.New(t).Interface() + unmarshaller, ok := result.(encoding.TextUnmarshaler) + if !ok { + return data, nil + } + if err := unmarshaller.UnmarshalText([]byte(data.(string))); err != nil { + return nil, err + } + return result, nil + } +} diff --git a/internal/pkg/mapstructure/decode_hooks_test.go b/internal/pkg/mapstructure/decode_hooks_test.go new file mode 100644 index 0000000..07fbedf --- /dev/null +++ b/internal/pkg/mapstructure/decode_hooks_test.go @@ -0,0 +1,567 @@ +package mapstructure + +import ( + "errors" + "math/big" + "net" + "reflect" + "testing" + "time" +) + +func TestComposeDecodeHookFunc(t *testing.T) { + f1 := func( + f reflect.Kind, + t reflect.Kind, + data interface{}) (interface{}, error) { + return data.(string) + "foo", nil + } + + f2 := func( + f reflect.Kind, + t reflect.Kind, + data interface{}) (interface{}, error) { + return data.(string) + "bar", nil + } + + f := ComposeDecodeHookFunc(f1, f2) + + result, err := DecodeHookExec( + f, reflect.ValueOf(""), reflect.ValueOf([]byte(""))) + if err != nil { + t.Fatalf("bad: %s", err) + } + if result.(string) != "foobar" { + t.Fatalf("bad: %#v", result) + } +} + +func TestComposeDecodeHookFunc_err(t *testing.T) { + f1 := func(reflect.Kind, reflect.Kind, interface{}) (interface{}, error) { + return nil, errors.New("foo") + } + + f2 := func(reflect.Kind, reflect.Kind, interface{}) (interface{}, error) { + panic("NOPE") + } + + f := ComposeDecodeHookFunc(f1, f2) + + _, err := DecodeHookExec( + f, reflect.ValueOf(""), reflect.ValueOf([]byte(""))) + if err.Error() != "foo" { + t.Fatalf("bad: %s", err) + } +} + +func TestComposeDecodeHookFunc_kinds(t *testing.T) { + var f2From reflect.Kind + + f1 := func( + f reflect.Kind, + t reflect.Kind, + data interface{}) (interface{}, error) { + return int(42), nil + } + + f2 := func( + f reflect.Kind, + t reflect.Kind, + data interface{}) (interface{}, error) { + f2From = f + return data, nil + } + + f := ComposeDecodeHookFunc(f1, f2) + + _, err := DecodeHookExec( + f, reflect.ValueOf(""), reflect.ValueOf([]byte(""))) + if err != nil { + t.Fatalf("bad: %s", err) + } + if f2From != reflect.Int { + t.Fatalf("bad: %#v", f2From) + } +} + +func TestOrComposeDecodeHookFunc(t *testing.T) { + f1 := func( + f reflect.Kind, + t reflect.Kind, + data interface{}) (interface{}, error) { + return data.(string) + "foo", nil + } + + f2 := func( + f reflect.Kind, + t reflect.Kind, + data interface{}) (interface{}, error) { + return data.(string) + "bar", nil + } + + f := OrComposeDecodeHookFunc(f1, f2) + + result, err := DecodeHookExec( + f, reflect.ValueOf(""), reflect.ValueOf([]byte(""))) + if err != nil { + t.Fatalf("bad: %s", err) + } + if result.(string) != "foo" { + t.Fatalf("bad: %#v", result) + } +} + +func TestOrComposeDecodeHookFunc_correctValueIsLast(t *testing.T) { + f1 := func( + f reflect.Kind, + t reflect.Kind, + data interface{}) (interface{}, error) { + return nil, errors.New("f1 error") + } + + f2 := func( + f reflect.Kind, + t reflect.Kind, + data interface{}) (interface{}, error) { + return nil, errors.New("f2 error") + } + + f3 := func( + f reflect.Kind, + t reflect.Kind, + data interface{}) (interface{}, error) { + return data.(string) + "bar", nil + } + + f := OrComposeDecodeHookFunc(f1, f2, f3) + + result, err := DecodeHookExec( + f, reflect.ValueOf(""), reflect.ValueOf([]byte(""))) + if err != nil { + t.Fatalf("bad: %s", err) + } + if result.(string) != "bar" { + t.Fatalf("bad: %#v", result) + } +} + +func TestOrComposeDecodeHookFunc_err(t *testing.T) { + f1 := func( + f reflect.Kind, + t reflect.Kind, + data interface{}) (interface{}, error) { + return nil, errors.New("f1 error") + } + + f2 := func( + f reflect.Kind, + t reflect.Kind, + data interface{}) (interface{}, error) { + return nil, errors.New("f2 error") + } + + f := OrComposeDecodeHookFunc(f1, f2) + + _, err := DecodeHookExec( + f, reflect.ValueOf(""), reflect.ValueOf([]byte(""))) + if err == nil { + t.Fatalf("bad: should return an error") + } + if err.Error() != "f1 error\nf2 error\n" { + t.Fatalf("bad: %s", err) + } +} + +func TestComposeDecodeHookFunc_safe_nofuncs(t *testing.T) { + f := ComposeDecodeHookFunc() + type myStruct2 struct { + MyInt int + } + + type myStruct1 struct { + Blah map[string]myStruct2 + } + + src := &myStruct1{Blah: map[string]myStruct2{ + "test": { + MyInt: 1, + }, + }} + + dst := &myStruct1{} + dConf := &DecoderConfig{ + Result: dst, + ErrorUnused: true, + DecodeHook: f, + } + d, err := NewDecoder(dConf) + if err != nil { + t.Fatal(err) + } + err = d.Decode(src) + if err != nil { + t.Fatal(err) + } +} + +func TestStringToSliceHookFunc(t *testing.T) { + f := StringToSliceHookFunc(",") + + strValue := reflect.ValueOf("42") + sliceValue := reflect.ValueOf([]byte("42")) + cases := []struct { + f, t reflect.Value + result interface{} + err bool + }{ + {sliceValue, sliceValue, []byte("42"), false}, + {strValue, strValue, "42", false}, + { + reflect.ValueOf("foo,bar,baz"), + sliceValue, + []string{"foo", "bar", "baz"}, + false, + }, + { + reflect.ValueOf(""), + sliceValue, + []string{}, + false, + }, + } + + for i, tc := range cases { + actual, err := DecodeHookExec(f, tc.f, tc.t) + if tc.err != (err != nil) { + t.Fatalf("case %d: expected jderr %#v", i, tc.err) + } + if !reflect.DeepEqual(actual, tc.result) { + t.Fatalf( + "case %d: expected %#v, got %#v", + i, tc.result, actual) + } + } +} + +func TestStringToTimeDurationHookFunc(t *testing.T) { + f := StringToTimeDurationHookFunc() + + timeValue := reflect.ValueOf(time.Duration(5)) + strValue := reflect.ValueOf("") + cases := []struct { + f, t reflect.Value + result interface{} + err bool + }{ + {reflect.ValueOf("5s"), timeValue, 5 * time.Second, false}, + {reflect.ValueOf("5"), timeValue, time.Duration(0), true}, + {reflect.ValueOf("5"), strValue, "5", false}, + } + + for i, tc := range cases { + actual, err := DecodeHookExec(f, tc.f, tc.t) + if tc.err != (err != nil) { + t.Fatalf("case %d: expected jderr %#v", i, tc.err) + } + if !reflect.DeepEqual(actual, tc.result) { + t.Fatalf( + "case %d: expected %#v, got %#v", + i, tc.result, actual) + } + } +} + +func TestStringToTimeHookFunc(t *testing.T) { + strValue := reflect.ValueOf("5") + timeValue := reflect.ValueOf(time.Time{}) + cases := []struct { + f, t reflect.Value + layout string + result interface{} + err bool + }{ + {reflect.ValueOf("2006-01-02T15:04:05Z"), timeValue, time.RFC3339, + time.Date(2006, 1, 2, 15, 4, 5, 0, time.UTC), false}, + {strValue, timeValue, time.RFC3339, time.Time{}, true}, + {strValue, strValue, time.RFC3339, "5", false}, + } + + for i, tc := range cases { + f := StringToTimeHookFunc(tc.layout) + actual, err := DecodeHookExec(f, tc.f, tc.t) + if tc.err != (err != nil) { + t.Fatalf("case %d: expected jderr %#v", i, tc.err) + } + if !reflect.DeepEqual(actual, tc.result) { + t.Fatalf( + "case %d: expected %#v, got %#v", + i, tc.result, actual) + } + } +} + +func TestStringToIPHookFunc(t *testing.T) { + strValue := reflect.ValueOf("5") + ipValue := reflect.ValueOf(net.IP{}) + cases := []struct { + f, t reflect.Value + result interface{} + err bool + }{ + {reflect.ValueOf("1.2.3.4"), ipValue, + net.IPv4(0x01, 0x02, 0x03, 0x04), false}, + {strValue, ipValue, net.IP{}, true}, + {strValue, strValue, "5", false}, + } + + for i, tc := range cases { + f := StringToIPHookFunc() + actual, err := DecodeHookExec(f, tc.f, tc.t) + if tc.err != (err != nil) { + t.Fatalf("case %d: expected jderr %#v", i, tc.err) + } + if !reflect.DeepEqual(actual, tc.result) { + t.Fatalf( + "case %d: expected %#v, got %#v", + i, tc.result, actual) + } + } +} + +func TestStringToIPNetHookFunc(t *testing.T) { + strValue := reflect.ValueOf("5") + ipNetValue := reflect.ValueOf(net.IPNet{}) + var nilNet *net.IPNet = nil + + cases := []struct { + f, t reflect.Value + result interface{} + err bool + }{ + {reflect.ValueOf("1.2.3.4/24"), ipNetValue, + &net.IPNet{ + IP: net.IP{0x01, 0x02, 0x03, 0x00}, + Mask: net.IPv4Mask(0xff, 0xff, 0xff, 0x00), + }, false}, + {strValue, ipNetValue, nilNet, true}, + {strValue, strValue, "5", false}, + } + + for i, tc := range cases { + f := StringToIPNetHookFunc() + actual, err := DecodeHookExec(f, tc.f, tc.t) + if tc.err != (err != nil) { + t.Fatalf("case %d: expected jderr %#v", i, tc.err) + } + if !reflect.DeepEqual(actual, tc.result) { + t.Fatalf( + "case %d: expected %#v, got %#v", + i, tc.result, actual) + } + } +} + +func TestWeaklyTypedHook(t *testing.T) { + var f DecodeHookFunc = WeaklyTypedHook + + strValue := reflect.ValueOf("") + cases := []struct { + f, t reflect.Value + result interface{} + err bool + }{ + // TO STRING + { + reflect.ValueOf(false), + strValue, + "0", + false, + }, + + { + reflect.ValueOf(true), + strValue, + "1", + false, + }, + + { + reflect.ValueOf(float32(7)), + strValue, + "7", + false, + }, + + { + reflect.ValueOf(int(7)), + strValue, + "7", + false, + }, + + { + reflect.ValueOf([]uint8("foo")), + strValue, + "foo", + false, + }, + + { + reflect.ValueOf(uint(7)), + strValue, + "7", + false, + }, + } + + for i, tc := range cases { + actual, err := DecodeHookExec(f, tc.f, tc.t) + if tc.err != (err != nil) { + t.Fatalf("case %d: expected jderr %#v", i, tc.err) + } + if !reflect.DeepEqual(actual, tc.result) { + t.Fatalf( + "case %d: expected %#v, got %#v", + i, tc.result, actual) + } + } +} + +func TestStructToMapHookFuncTabled(t *testing.T) { + var f DecodeHookFunc = RecursiveStructToMapHookFunc() + + type b struct { + TestKey string + } + + type a struct { + Sub b + } + + testStruct := a{ + Sub: b{ + TestKey: "testval", + }, + } + + testMap := map[string]interface{}{ + "Sub": map[string]interface{}{ + "TestKey": "testval", + }, + } + + cases := []struct { + name string + receiver interface{} + input interface{} + expected interface{} + err bool + }{ + { + "map receiver", + func() interface{} { + var res map[string]interface{} + return &res + }(), + testStruct, + &testMap, + false, + }, + { + "interface receiver", + func() interface{} { + var res interface{} + return &res + }(), + testStruct, + func() interface{} { + var exp interface{} = testMap + return &exp + }(), + false, + }, + { + "slice receiver errors", + func() interface{} { + var res []string + return &res + }(), + testStruct, + new([]string), + true, + }, + { + "slice to slice - no change", + func() interface{} { + var res []string + return &res + }(), + []string{"a", "b"}, + &[]string{"a", "b"}, + false, + }, + { + "string to string - no change", + func() interface{} { + var res string + return &res + }(), + "test", + func() *string { + s := "test" + return &s + }(), + false, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + cfg := &DecoderConfig{ + DecodeHook: f, + Result: tc.receiver, + } + + d, err := NewDecoder(cfg) + if err != nil { + t.Fatalf("unexpected jderr %#v", err) + } + + err = d.Decode(tc.input) + if tc.err != (err != nil) { + t.Fatalf("expected jderr %#v", err) + } + + if !reflect.DeepEqual(tc.expected, tc.receiver) { + t.Fatalf("expected %#v, got %#v", + tc.expected, tc.receiver) + } + }) + + } +} + +func TestTextUnmarshallerHookFunc(t *testing.T) { + cases := []struct { + f, t reflect.Value + result interface{} + err bool + }{ + {reflect.ValueOf("42"), reflect.ValueOf(big.Int{}), big.NewInt(42), false}, + {reflect.ValueOf("invalid"), reflect.ValueOf(big.Int{}), nil, true}, + {reflect.ValueOf("5"), reflect.ValueOf("5"), "5", false}, + } + + for i, tc := range cases { + f := TextUnmarshallerHookFunc() + actual, err := DecodeHookExec(f, tc.f, tc.t) + if tc.err != (err != nil) { + t.Fatalf("case %d: expected jderr %#v", i, tc.err) + } + if !reflect.DeepEqual(actual, tc.result) { + t.Fatalf( + "case %d: expected %#v, got %#v", + i, tc.result, actual) + } + } +} diff --git a/internal/pkg/mapstructure/error.go b/internal/pkg/mapstructure/error.go new file mode 100644 index 0000000..47a99e5 --- /dev/null +++ b/internal/pkg/mapstructure/error.go @@ -0,0 +1,50 @@ +package mapstructure + +import ( + "errors" + "fmt" + "sort" + "strings" +) + +// Error implements the error interface and can represents multiple +// errors that occur in the course of a single decode. +type Error struct { + Errors []string +} + +func (e *Error) Error() string { + points := make([]string, len(e.Errors)) + for i, err := range e.Errors { + points[i] = fmt.Sprintf("* %s", err) + } + + sort.Strings(points) + return fmt.Sprintf( + "%d error(s) decoding:\n\n%s", + len(e.Errors), strings.Join(points, "\n")) +} + +// WrappedErrors implements the errwrap.Wrapper interface to make this +// return value more useful with the errwrap and go-multierror libraries. +func (e *Error) WrappedErrors() []error { + if e == nil { + return nil + } + + result := make([]error, len(e.Errors)) + for i, e := range e.Errors { + result[i] = errors.New(e) + } + + return result +} + +func appendErrors(errors []string, err error) []string { + switch e := err.(type) { + case *Error: + return append(errors, e.Errors...) + default: + return append(errors, e.Error()) + } +} diff --git a/internal/pkg/mapstructure/mapstructure.go b/internal/pkg/mapstructure/mapstructure.go new file mode 100644 index 0000000..0d26c75 --- /dev/null +++ b/internal/pkg/mapstructure/mapstructure.go @@ -0,0 +1,1386 @@ +package mapstructure + +import ( + "encoding/json" + "errors" + "fmt" + "reflect" + "sort" + "strconv" + "strings" +) + +// DecodeHookFunc is the callback function that can be used for +// data transformations. See "DecodeHook" in the DecoderConfig +// struct. +// +// The type must be one of DecodeHookFuncType, DecodeHookFuncKind, or +// DecodeHookFuncValue. +// Values are a superset of Types (Values can return types), and Types are a +// superset of Kinds (Types can return Kinds) and are generally a richer thing +// to use, but Kinds are simpler if you only need those. +// +// The reason DecodeHookFunc is multi-typed is for backwards compatibility: +// we started with Kinds and then realized Types were the better solution, +// but have a promise to not break backwards compat so we now support +// both. +type DecodeHookFunc interface{} + +// DecodeHookFuncType is a DecodeHookFunc which has complete information about +// the source and target types. +type DecodeHookFuncType func(reflect.Type, reflect.Type, interface{}) (interface{}, error) + +// DecodeHookFuncKind is a DecodeHookFunc which knows only the Kinds of the +// source and target types. +type DecodeHookFuncKind func(reflect.Kind, reflect.Kind, interface{}) (interface{}, error) + +// DecodeHookFuncValue is a DecodeHookFunc which has complete access to both the source and target +// values. +type DecodeHookFuncValue func(from reflect.Value, to reflect.Value) (interface{}, error) + +// DecoderConfig is the configuration that is used to create a new decoder +// and allows customization of various aspects of decoding. +type DecoderConfig struct { + // DecodeHook, if set, will be called before any decoding and any + // type conversion (if WeaklyTypedInput is on). This lets you modify + // the values before they're set down onto the resulting struct. The + // DecodeHook is called for every map and value in the input. This means + // that if a struct has embedded fields with squash tags the decode hook + // is called only once with all of the input data, not once for each + // embedded struct. + // + // If an error is returned, the entire decode will fail with that error. + DecodeHook DecodeHookFunc + + // If ErrorUnused is true, then it is an error for there to exist + // keys in the original map that were unused in the decoding process + // (extra keys). + ErrorUnused bool + + // If ErrorUnset is true, then it is an error for there to exist + // fields in the result that were not set in the decoding process + // (extra fields). This only applies to decoding to a struct. This + // will affect all nested structs as well. + ErrorUnset bool + + // ZeroFields, if set to true, will zero fields before writing them. + // For example, a map will be emptied before decoded values are put in + // it. If this is false, a map will be merged. + ZeroFields bool + + // If WeaklyTypedInput is true, the decoder will make the following + // "weak" conversions: + // + // - bools to string (true = "1", false = "0") + // - numbers to string (base 10) + // - bools to int/uint (true = 1, false = 0) + // - strings to int/uint (base implied by prefix) + // - int to bool (true if value != 0) + // - string to bool (accepts: 1, t, T, TRUE, true, True, 0, f, F, + // FALSE, false, False. Anything else is an error) + // - empty array = empty map and vice versa + // - negative numbers to overflowed uint values (base 10) + // - slice of maps to a merged map + // - single values are converted to slices if required. Each + // element is weakly decoded. For example: "4" can become []int{4} + // if the target type is an int slice. + // + WeaklyTypedInput bool + + // Squash will squash embedded structs. A squash tag may also be + // added to an individual struct field using a tag. For example: + // + // type Parent struct { + // Child `mapstructure:",squash"` + // } + Squash bool + + // Metadata is the struct that will contain extra metadata about + // the decoding. If this is nil, then no metadata will be tracked. + Metadata *Metadata + + // Result is a pointer to the struct that will contain the decoded + // value. + Result interface{} + + // The tag name that mapstructure reads for field names. This + // defaults to "mapstructure" + TagName string + + // IgnoreUntaggedFields ignores all struct fields without explicit + // TagName, comparable to `mapstructure:"-"` as default behaviour. + IgnoreUntaggedFields bool + + // MatchName is the function used to match the map key to the struct + // field name or tag. Defaults to `strings.EqualFold`. This can be used + // to implement case-sensitive tag values, support snake casing, etc. + MatchName func(mapKey, fieldName string) bool +} + +// A Decoder takes a raw interface value and turns it into structured +// data, keeping track of rich error information along the way in case +// anything goes wrong. Unlike the basic top-level Decode method, you can +// more finely control how the Decoder behaves using the DecoderConfig +// structure. The top-level Decode method is just a convenience that sets +// up the most basic Decoder. +type Decoder struct { + config *DecoderConfig +} + +// Metadata contains information about decoding a structure that +// is tedious or difficult to get otherwise. +type Metadata struct { + // Keys are the keys of the structure which were successfully decoded + Keys []string + + // Unused is a slice of keys that were found in the raw value but + // weren't decoded since there was no matching field in the result interface + Unused []string + + // Unset is a slice of field names that were found in the result interface + // but weren't set in the decoding process since there was no matching value + // in the input + Unset []string +} + +// Decode takes an input structure and uses reflection to translate it to +// the output structure. output must be a pointer to a map or struct. +func Decode(input interface{}, output interface{}) error { + config := &DecoderConfig{ + Metadata: nil, + Result: output, + } + + decoder, err := NewDecoder(config) + if err != nil { + return err + } + + return decoder.Decode(input) +} + +// WeakDecode is the same as Decode but is shorthand to enable +// WeaklyTypedInput. See DecoderConfig for more info. +func WeakDecode(input, output interface{}) error { + config := &DecoderConfig{ + Metadata: nil, + Result: output, + WeaklyTypedInput: true, + } + + decoder, err := NewDecoder(config) + if err != nil { + return err + } + + return decoder.Decode(input) +} + +// DecodeMetadata is the same as Decode, but is shorthand to +// enable metadata collection. See DecoderConfig for more info. +func DecodeMetadata(input interface{}, output interface{}, metadata *Metadata) error { + config := &DecoderConfig{ + Metadata: metadata, + Result: output, + } + + decoder, err := NewDecoder(config) + if err != nil { + return err + } + + return decoder.Decode(input) +} + +// WeakDecodeMetadata is the same as Decode, but is shorthand to +// enable both WeaklyTypedInput and metadata collection. See +// DecoderConfig for more info. +func WeakDecodeMetadata(input interface{}, output interface{}, metadata *Metadata) error { + config := &DecoderConfig{ + Metadata: metadata, + Result: output, + WeaklyTypedInput: true, + } + + decoder, err := NewDecoder(config) + if err != nil { + return err + } + + return decoder.Decode(input) +} + +// NewDecoder returns a new decoder for the given configuration. Once +// a decoder has been returned, the same configuration must not be used +// again. +func NewDecoder(config *DecoderConfig) (*Decoder, error) { + val := reflect.ValueOf(config.Result) + if val.Kind() != reflect.Ptr { + return nil, errors.New("result must be a pointer") + } + + val = val.Elem() + if !val.CanAddr() { + return nil, errors.New("result must be addressable (a pointer)") + } + + if config.Metadata != nil { + if config.Metadata.Keys == nil { + config.Metadata.Keys = make([]string, 0) + } + + if config.Metadata.Unused == nil { + config.Metadata.Unused = make([]string, 0) + } + + if config.Metadata.Unset == nil { + config.Metadata.Unset = make([]string, 0) + } + } + + if config.TagName == "" { + config.TagName = "mapstructure" + } + + if config.MatchName == nil { + config.MatchName = strings.EqualFold + } + + result := &Decoder{ + config: config, + } + + return result, nil +} + +// Decode decodes the given raw interface to the target pointer specified +// by the configuration. +func (d *Decoder) Decode(input interface{}) error { + return d.decode("", input, reflect.ValueOf(d.config.Result).Elem()) +} + +// Decodes an unknown data type into a specific reflection value. +func (d *Decoder) decode(name string, input interface{}, outVal reflect.Value) error { + var inputVal reflect.Value + if input != nil { + inputVal = reflect.ValueOf(input) + + // We need to check here if input is a typed nil. Typed nils won't + // match the "input == nil" below so we check that here. + if inputVal.Kind() == reflect.Ptr && inputVal.IsNil() { + input = nil + } + } + + if input == nil { + // If the data is nil, then we don't set anything, unless ZeroFields is set + // to true. + if d.config.ZeroFields { + outVal.Set(reflect.Zero(outVal.Type())) + + if d.config.Metadata != nil && name != "" { + d.config.Metadata.Keys = append(d.config.Metadata.Keys, name) + } + } + return nil + } + + if !inputVal.IsValid() { + // If the input value is invalid, then we just set the value + // to be the zero value. + outVal.Set(reflect.Zero(outVal.Type())) + if d.config.Metadata != nil && name != "" { + d.config.Metadata.Keys = append(d.config.Metadata.Keys, name) + } + return nil + } + + if d.config.DecodeHook != nil { + // We have a DecodeHook, so let's pre-process the input. + var err error + input, err = DecodeHookExec(d.config.DecodeHook, inputVal, outVal) + if err != nil { + return fmt.Errorf("error decoding '%s': %s", name, err) + } + } + + var err error + outputKind := getKind(outVal) + addMetaKey := true + switch outputKind { + case reflect.Bool: + err = d.decodeBool(name, input, outVal) + case reflect.Interface: + err = d.decodeBasic(name, input, outVal) + case reflect.String: + err = d.decodeString(name, input, outVal) + case reflect.Int: + err = d.decodeInt(name, input, outVal) + case reflect.Uint: + err = d.decodeUint(name, input, outVal) + case reflect.Float32: + err = d.decodeFloat(name, input, outVal) + case reflect.Struct: + err = d.decodeStruct(name, input, outVal) + case reflect.Map: + err = d.decodeMap(name, input, outVal) + case reflect.Ptr: + addMetaKey, err = d.decodePtr(name, input, outVal) + case reflect.Slice: + err = d.decodeSlice(name, input, outVal) + case reflect.Array: + err = d.decodeArray(name, input, outVal) + case reflect.Func: + err = d.decodeFunc(name, input, outVal) + default: + // If we reached this point then we weren't able to decode it + return fmt.Errorf("%s: unsupported type: %s", name, outputKind) + } + + // If we reached here, then we successfully decoded SOMETHING, so + // mark the key as used if we're tracking metainput. + if addMetaKey && d.config.Metadata != nil && name != "" { + d.config.Metadata.Keys = append(d.config.Metadata.Keys, name) + } + + return err +} + +// This decodes a basic type (bool, int, string, etc.) and sets the +// value to "data" of that type. +func (d *Decoder) decodeBasic(name string, data interface{}, val reflect.Value) error { + if val.IsValid() && val.Elem().IsValid() { + elem := val.Elem() + + // If we can't address this element, then its not writable. Instead, + // we make a copy of the value (which is a pointer and therefore + // writable), decode into that, and replace the whole value. + copied := false + if !elem.CanAddr() { + copied = true + + // Make *T + copy := reflect.New(elem.Type()) + + // *T = elem + copy.Elem().Set(elem) + + // Set elem so we decode into it + elem = copy + } + + // Decode. If we have an error then return. We also return right + // away if we're not a copy because that means we decoded directly. + if err := d.decode(name, data, elem); err != nil || !copied { + return err + } + + // If we're a copy, we need to set te final result + val.Set(elem.Elem()) + return nil + } + + dataVal := reflect.ValueOf(data) + + // If the input data is a pointer, and the assigned type is the dereference + // of that exact pointer, then indirect it so that we can assign it. + // Example: *string to string + if dataVal.Kind() == reflect.Ptr && dataVal.Type().Elem() == val.Type() { + dataVal = reflect.Indirect(dataVal) + } + + if !dataVal.IsValid() { + dataVal = reflect.Zero(val.Type()) + } + + dataValType := dataVal.Type() + if !dataValType.AssignableTo(val.Type()) { + return fmt.Errorf( + "'%s' expected type '%s', got '%s'", + name, val.Type(), dataValType) + } + + val.Set(dataVal) + return nil +} + +func (d *Decoder) decodeString(name string, data interface{}, val reflect.Value) error { + dataVal := reflect.Indirect(reflect.ValueOf(data)) + dataKind := getKind(dataVal) + + converted := true + switch { + case dataKind == reflect.String: + val.SetString(dataVal.String()) + case dataKind == reflect.Bool && d.config.WeaklyTypedInput: + if dataVal.Bool() { + val.SetString("1") + } else { + val.SetString("0") + } + case dataKind == reflect.Int && d.config.WeaklyTypedInput: + val.SetString(strconv.FormatInt(dataVal.Int(), 10)) + case dataKind == reflect.Uint && d.config.WeaklyTypedInput: + val.SetString(strconv.FormatUint(dataVal.Uint(), 10)) + case dataKind == reflect.Float32 && d.config.WeaklyTypedInput: + val.SetString(strconv.FormatFloat(dataVal.Float(), 'f', -1, 64)) + case dataKind == reflect.Slice && d.config.WeaklyTypedInput, + dataKind == reflect.Array && d.config.WeaklyTypedInput: + dataType := dataVal.Type() + elemKind := dataType.Elem().Kind() + switch elemKind { + case reflect.Uint8: + var uints []uint8 + if dataKind == reflect.Array { + uints = make([]uint8, dataVal.Len(), dataVal.Len()) + for i := range uints { + uints[i] = dataVal.Index(i).Interface().(uint8) + } + } else { + uints = dataVal.Interface().([]uint8) + } + val.SetString(string(uints)) + default: + converted = false + } + default: + converted = false + } + + if !converted { + return fmt.Errorf( + "'%s' expected type '%s', got unconvertible type '%s', value: '%v'", + name, val.Type(), dataVal.Type(), data) + } + + return nil +} + +func (d *Decoder) decodeInt(name string, data interface{}, val reflect.Value) error { + dataVal := reflect.Indirect(reflect.ValueOf(data)) + dataKind := getKind(dataVal) + dataType := dataVal.Type() + + switch { + case dataKind == reflect.Int: + val.SetInt(dataVal.Int()) + case dataKind == reflect.Uint: + val.SetInt(int64(dataVal.Uint())) + case dataKind == reflect.Float32: + val.SetInt(int64(dataVal.Float())) + case dataKind == reflect.Bool && d.config.WeaklyTypedInput: + if dataVal.Bool() { + val.SetInt(1) + } else { + val.SetInt(0) + } + case dataKind == reflect.String && d.config.WeaklyTypedInput: + str := dataVal.String() + if str == "" { + str = "0" + } + + i, err := strconv.ParseInt(str, 0, val.Type().Bits()) + if err == nil { + val.SetInt(i) + } else { + return fmt.Errorf("cannot parse '%s' as int: %s", name, err) + } + case dataType.PkgPath() == "encoding/json" && dataType.Name() == "Number": + jn := data.(json.Number) + i, err := jn.Int64() + if err != nil { + return fmt.Errorf( + "error decoding json.Number into %s: %s", name, err) + } + val.SetInt(i) + default: + return fmt.Errorf( + "'%s' expected type '%s', got unconvertible type '%s', value: '%v'", + name, val.Type(), dataVal.Type(), data) + } + + return nil +} + +func (d *Decoder) decodeUint(name string, data interface{}, val reflect.Value) error { + dataVal := reflect.Indirect(reflect.ValueOf(data)) + dataKind := getKind(dataVal) + dataType := dataVal.Type() + + switch { + case dataKind == reflect.Int: + i := dataVal.Int() + if i < 0 && !d.config.WeaklyTypedInput { + return fmt.Errorf("cannot parse '%s', %d overflows uint", + name, i) + } + val.SetUint(uint64(i)) + case dataKind == reflect.Uint: + val.SetUint(dataVal.Uint()) + case dataKind == reflect.Float32: + f := dataVal.Float() + if f < 0 && !d.config.WeaklyTypedInput { + return fmt.Errorf("cannot parse '%s', %f overflows uint", + name, f) + } + val.SetUint(uint64(f)) + case dataKind == reflect.Bool && d.config.WeaklyTypedInput: + if dataVal.Bool() { + val.SetUint(1) + } else { + val.SetUint(0) + } + case dataKind == reflect.String && d.config.WeaklyTypedInput: + str := dataVal.String() + if str == "" { + str = "0" + } + + i, err := strconv.ParseUint(str, 0, val.Type().Bits()) + if err == nil { + val.SetUint(i) + } else { + return fmt.Errorf("cannot parse '%s' as uint: %s", name, err) + } + case dataType.PkgPath() == "encoding/json" && dataType.Name() == "Number": + jn := data.(json.Number) + i, err := strconv.ParseUint(string(jn), 0, 64) + if err != nil { + return fmt.Errorf( + "error decoding json.Number into %s: %s", name, err) + } + val.SetUint(i) + default: + return fmt.Errorf( + "'%s' expected type '%s', got unconvertible type '%s', value: '%v'", + name, val.Type(), dataVal.Type(), data) + } + + return nil +} + +func (d *Decoder) decodeBool(name string, data interface{}, val reflect.Value) error { + dataVal := reflect.Indirect(reflect.ValueOf(data)) + dataKind := getKind(dataVal) + + switch { + case dataKind == reflect.Bool: + val.SetBool(dataVal.Bool()) + case dataKind == reflect.Int && d.config.WeaklyTypedInput: + val.SetBool(dataVal.Int() != 0) + case dataKind == reflect.Uint && d.config.WeaklyTypedInput: + val.SetBool(dataVal.Uint() != 0) + case dataKind == reflect.Float32 && d.config.WeaklyTypedInput: + val.SetBool(dataVal.Float() != 0) + case dataKind == reflect.String && d.config.WeaklyTypedInput: + b, err := strconv.ParseBool(dataVal.String()) + if err == nil { + val.SetBool(b) + } else if dataVal.String() == "" { + val.SetBool(false) + } else { + return fmt.Errorf("cannot parse '%s' as bool: %s", name, err) + } + default: + return fmt.Errorf( + "'%s' expected type '%s', got unconvertible type '%s', value: '%v'", + name, val.Type(), dataVal.Type(), data) + } + + return nil +} + +func (d *Decoder) decodeFloat(name string, data interface{}, val reflect.Value) error { + dataVal := reflect.Indirect(reflect.ValueOf(data)) + dataKind := getKind(dataVal) + dataType := dataVal.Type() + + switch { + case dataKind == reflect.Int: + val.SetFloat(float64(dataVal.Int())) + case dataKind == reflect.Uint: + val.SetFloat(float64(dataVal.Uint())) + case dataKind == reflect.Float32: + val.SetFloat(dataVal.Float()) + case dataKind == reflect.Bool && d.config.WeaklyTypedInput: + if dataVal.Bool() { + val.SetFloat(1) + } else { + val.SetFloat(0) + } + case dataKind == reflect.String && d.config.WeaklyTypedInput: + str := dataVal.String() + if str == "" { + str = "0" + } + + f, err := strconv.ParseFloat(str, val.Type().Bits()) + if err == nil { + val.SetFloat(f) + } else { + return fmt.Errorf("cannot parse '%s' as float: %s", name, err) + } + case dataType.PkgPath() == "encoding/json" && dataType.Name() == "Number": + jn := data.(json.Number) + i, err := jn.Float64() + if err != nil { + return fmt.Errorf( + "error decoding json.Number into %s: %s", name, err) + } + val.SetFloat(i) + default: + return fmt.Errorf( + "'%s' expected type '%s', got unconvertible type '%s', value: '%v'", + name, val.Type(), dataVal.Type(), data) + } + + return nil +} + +func (d *Decoder) decodeMap(name string, data interface{}, val reflect.Value) error { + valType := val.Type() + valKeyType := valType.Key() + valElemType := valType.Elem() + + // By default we overwrite keys in the current map + valMap := val + + // If the map is nil or we're purposely zeroing fields, make a new map + if valMap.IsNil() || d.config.ZeroFields { + // Make a new map to hold our result + mapType := reflect.MapOf(valKeyType, valElemType) + valMap = reflect.MakeMap(mapType) + } + + // Check input type and based on the input type jump to the proper func + dataVal := reflect.Indirect(reflect.ValueOf(data)) + switch dataVal.Kind() { + case reflect.Map: + return d.decodeMapFromMap(name, dataVal, val, valMap) + + case reflect.Struct: + return d.decodeMapFromStruct(name, dataVal, val, valMap, data) + + case reflect.Array, reflect.Slice: + if d.config.WeaklyTypedInput { + return d.decodeMapFromSlice(name, dataVal, val, valMap) + } + + fallthrough + + default: + return fmt.Errorf("'%s' expected a map, got '%s'", name, dataVal.Kind()) + } +} + +func (d *Decoder) decodeMapFromSlice(name string, dataVal reflect.Value, val reflect.Value, valMap reflect.Value) error { + // Special case for BC reasons (covered by tests) + if dataVal.Len() == 0 { + val.Set(valMap) + return nil + } + + for i := 0; i < dataVal.Len(); i++ { + err := d.decode( + name+"["+strconv.Itoa(i)+"]", + dataVal.Index(i).Interface(), val) + if err != nil { + return err + } + } + + return nil +} + +func (d *Decoder) decodeMapFromMap(name string, dataVal reflect.Value, val reflect.Value, valMap reflect.Value) error { + valType := val.Type() + valKeyType := valType.Key() + valElemType := valType.Elem() + + // Accumulate errors + errors := make([]string, 0) + + // If the input data is empty, then we just match what the input data is. + if dataVal.Len() == 0 { + if dataVal.IsNil() { + if !val.IsNil() { + val.Set(dataVal) + } + } else { + // Set to empty allocated value + val.Set(valMap) + } + + return nil + } + + for _, k := range dataVal.MapKeys() { + fieldName := name + "[" + k.String() + "]" + + // First decode the key into the proper type + currentKey := reflect.Indirect(reflect.New(valKeyType)) + if err := d.decode(fieldName, k.Interface(), currentKey); err != nil { + errors = appendErrors(errors, err) + continue + } + + // Next decode the data into the proper type + v := dataVal.MapIndex(k).Interface() + currentVal := reflect.Indirect(reflect.New(valElemType)) + if err := d.decode(fieldName, v, currentVal); err != nil { + errors = appendErrors(errors, err) + continue + } + + valMap.SetMapIndex(currentKey, currentVal) + } + + // Set the built up map to the value + val.Set(valMap) + + // If we had errors, return those + if len(errors) > 0 { + return &Error{errors} + } + + return nil +} + +func (d *Decoder) decodeMapFromStruct(name string, dataVal reflect.Value, val reflect.Value, valMap reflect.Value, inData interface{}) error { + typ := dataVal.Type() + + for i := 0; i < typ.NumField(); i++ { + // Get the StructField first since this is a cheap operation. If the + // field is unexported, then ignore it. + f := typ.Field(i) + if f.PkgPath != "" { + continue + } + + // Next get the actual value of this field and verify it is assignable + // to the map value. + v := dataVal.Field(i) + if !v.Type().AssignableTo(valMap.Type().Elem()) { + return fmt.Errorf("cannot assign type '%s' to map value field of type '%s'", v.Type(), valMap.Type().Elem()) + } + + tagValue := f.Tag.Get(d.config.TagName) + keyName := f.Name + + if tagValue == "" && d.config.IgnoreUntaggedFields { + continue + } + + // If Squash is set in the config, we squash the field down. + squash := d.config.Squash && v.Kind() == reflect.Struct && f.Anonymous + + v = dereferencePtrToStructIfNeeded(v, d.config.TagName) + + // Determine the name of the key in the map + if index := strings.Index(tagValue, ","); index != -1 { + if tagValue[:index] == "-" { + continue + } + // If "omitempty" is specified in the tag, it ignores empty values. + if strings.Index(tagValue[index+1:], "omitempty") != -1 && isEmptyValue(v) { + continue + } + + // If "squash" is specified in the tag, we squash the field down. + squash = squash || strings.Index(tagValue[index+1:], "squash") != -1 + if squash { + // When squashing, the embedded type can be a pointer to a struct. + if v.Kind() == reflect.Ptr && v.Elem().Kind() == reflect.Struct { + v = v.Elem() + } + + // The final type must be a struct + if v.Kind() != reflect.Struct { + return fmt.Errorf("cannot squash non-struct type '%s'", v.Type()) + } + } + if keyNameTagValue := tagValue[:index]; keyNameTagValue != "" { + keyName = keyNameTagValue + } + } else if len(tagValue) > 0 { + if tagValue == "-" { + continue + } + keyName = tagValue + } + + switch v.Kind() { + // this is an embedded struct, so handle it differently + case reflect.Struct: + x := reflect.New(v.Type()) + x.Elem().Set(v) + + vType := valMap.Type() + vKeyType := vType.Key() + vElemType := vType.Elem() + mType := reflect.MapOf(vKeyType, vElemType) + vMap := reflect.MakeMap(mType) + + // Creating a pointer to a map so that other methods can completely + // overwrite the map if need be (looking at you decodeMapFromMap). The + // indirection allows the underlying map to be settable (CanSet() == true) + // where as reflect.MakeMap returns an unsettable map. + addrVal := reflect.New(vMap.Type()) + reflect.Indirect(addrVal).Set(vMap) + + err := d.decode(keyName, x.Interface(), reflect.Indirect(addrVal)) + if err != nil { + return err + } + + // the underlying map may have been completely overwritten so pull + // it indirectly out of the enclosing value. + vMap = reflect.Indirect(addrVal) + + if squash { + for _, k := range vMap.MapKeys() { + valMap.SetMapIndex(k, vMap.MapIndex(k)) + } + } else { + valMap.SetMapIndex(reflect.ValueOf(keyName), vMap) + } + + default: + valMap.SetMapIndex(reflect.ValueOf(keyName), v) + } + + } + + if val.CanAddr() { + val.Set(valMap) + } + + return nil +} + +func (d *Decoder) decodePtr(name string, data interface{}, val reflect.Value) (bool, error) { + // If the input data is nil, then we want to just set the output + // pointer to be nil as well. + isNil := data == nil + if !isNil { + switch v := reflect.Indirect(reflect.ValueOf(data)); v.Kind() { + case reflect.Chan, + reflect.Func, + reflect.Interface, + reflect.Map, + reflect.Ptr, + reflect.Slice: + isNil = v.IsNil() + } + } + if isNil { + if !val.IsNil() && val.CanSet() { + nilValue := reflect.New(val.Type()).Elem() + val.Set(nilValue) + } + + return true, nil + } + + // Create an element of the concrete (non pointer) type and decode + // into that. Then set the value of the pointer to this type. + valType := val.Type() + valElemType := valType.Elem() + if val.CanSet() { + realVal := val + if realVal.IsNil() || d.config.ZeroFields { + realVal = reflect.New(valElemType) + } + + if err := d.decode(name, data, reflect.Indirect(realVal)); err != nil { + // 报错情况下依旧设置指针 + val.Set(realVal) + return false, err + } + + val.Set(realVal) + } else { + if err := d.decode(name, data, reflect.Indirect(val)); err != nil { + return false, err + } + } + return false, nil +} + +func (d *Decoder) decodeFunc(name string, data interface{}, val reflect.Value) error { + // Create an element of the concrete (non pointer) type and decode + // into that. Then set the value of the pointer to this type. + dataVal := reflect.Indirect(reflect.ValueOf(data)) + if val.Type() != dataVal.Type() { + return fmt.Errorf( + "'%s' expected type '%s', got unconvertible type '%s', value: '%v'", + name, val.Type(), dataVal.Type(), data) + } + val.Set(dataVal) + return nil +} + +func (d *Decoder) decodeSlice(name string, data interface{}, val reflect.Value) error { + dataVal := reflect.Indirect(reflect.ValueOf(data)) + dataValKind := dataVal.Kind() + valType := val.Type() + valElemType := valType.Elem() + sliceType := reflect.SliceOf(valElemType) + + // If we have a non array/slice type then we first attempt to convert. + if dataValKind != reflect.Array && dataValKind != reflect.Slice { + if d.config.WeaklyTypedInput { + switch { + // Slice and array we use the normal logic + case dataValKind == reflect.Slice, dataValKind == reflect.Array: + break + + // Empty maps turn into empty slices + case dataValKind == reflect.Map: + if dataVal.Len() == 0 { + val.Set(reflect.MakeSlice(sliceType, 0, 0)) + return nil + } + // Create slice of maps of other sizes + return d.decodeSlice(name, []interface{}{data}, val) + + case dataValKind == reflect.String && valElemType.Kind() == reflect.Uint8: + return d.decodeSlice(name, []byte(dataVal.String()), val) + + // All other types we try to convert to the slice type + // and "lift" it into it. i.e. a string becomes a string slice. + default: + // Just re-try this function with data as a slice. + return d.decodeSlice(name, []interface{}{data}, val) + } + } + + return fmt.Errorf( + "'%s': source data must be an array or slice, got %s", name, dataValKind) + } + + // If the input value is nil, then don't allocate since empty != nil + if dataValKind != reflect.Array && dataVal.IsNil() { + return nil + } + + valSlice := val + if valSlice.IsNil() || d.config.ZeroFields { + // Make a new slice to hold our result, same size as the original data. + valSlice = reflect.MakeSlice(sliceType, dataVal.Len(), dataVal.Len()) + } + + // Accumulate any errors + errors := make([]string, 0) + + for i := 0; i < dataVal.Len(); i++ { + currentData := dataVal.Index(i).Interface() + for valSlice.Len() <= i { + valSlice = reflect.Append(valSlice, reflect.Zero(valElemType)) + } + currentField := valSlice.Index(i) + + fieldName := name + "[" + strconv.Itoa(i) + "]" + if err := d.decode(fieldName, currentData, currentField); err != nil { + errors = appendErrors(errors, err) + } + } + + // Finally, set the value to the slice we built up + val.Set(valSlice) + + // If there were errors, we return those + if len(errors) > 0 { + return &Error{errors} + } + + return nil +} + +func (d *Decoder) decodeArray(name string, data interface{}, val reflect.Value) error { + dataVal := reflect.Indirect(reflect.ValueOf(data)) + dataValKind := dataVal.Kind() + valType := val.Type() + valElemType := valType.Elem() + arrayType := reflect.ArrayOf(valType.Len(), valElemType) + + valArray := val + + if valArray.Interface() == reflect.Zero(valArray.Type()).Interface() || d.config.ZeroFields { + // Check input type + if dataValKind != reflect.Array && dataValKind != reflect.Slice { + if d.config.WeaklyTypedInput { + switch { + // Empty maps turn into empty arrays + case dataValKind == reflect.Map: + if dataVal.Len() == 0 { + val.Set(reflect.Zero(arrayType)) + return nil + } + + // All other types we try to convert to the array type + // and "lift" it into it. i.e. a string becomes a string array. + default: + // Just re-try this function with data as a slice. + return d.decodeArray(name, []interface{}{data}, val) + } + } + + return fmt.Errorf( + "'%s': source data must be an array or slice, got %s", name, dataValKind) + + } + if dataVal.Len() > arrayType.Len() { + return fmt.Errorf( + "'%s': expected source data to have length less or equal to %d, got %d", name, arrayType.Len(), dataVal.Len()) + + } + + // Make a new array to hold our result, same size as the original data. + valArray = reflect.New(arrayType).Elem() + } + + // Accumulate any errors + errors := make([]string, 0) + + for i := 0; i < dataVal.Len(); i++ { + currentData := dataVal.Index(i).Interface() + currentField := valArray.Index(i) + + fieldName := name + "[" + strconv.Itoa(i) + "]" + if err := d.decode(fieldName, currentData, currentField); err != nil { + errors = appendErrors(errors, err) + } + } + + // Finally, set the value to the array we built up + val.Set(valArray) + + // If there were errors, we return those + if len(errors) > 0 { + return &Error{errors} + } + + return nil +} + +func (d *Decoder) decodeStruct(name string, data interface{}, val reflect.Value) error { + dataVal := reflect.Indirect(reflect.ValueOf(data)) + + // If the type of the value to write to and the data match directly, + // then we just set it directly instead of recursing into the structure. + if dataVal.Type() == val.Type() { + val.Set(dataVal) + return nil + } + + dataValKind := dataVal.Kind() + switch dataValKind { + case reflect.Map: + return d.decodeStructFromMap(name, dataVal, val) + + case reflect.Struct: + // Not the most efficient way to do this but we can optimize later if + // we want to. To convert from struct to struct we go to map first + // as an intermediary. + + // Make a new map to hold our result + mapType := reflect.TypeOf((map[string]interface{})(nil)) + mval := reflect.MakeMap(mapType) + + // Creating a pointer to a map so that other methods can completely + // overwrite the map if need be (looking at you decodeMapFromMap). The + // indirection allows the underlying map to be settable (CanSet() == true) + // where as reflect.MakeMap returns an unsettable map. + addrVal := reflect.New(mval.Type()) + + reflect.Indirect(addrVal).Set(mval) + if err := d.decodeMapFromStruct(name, dataVal, reflect.Indirect(addrVal), mval, data); err != nil { + return err + } + + result := d.decodeStructFromMap(name, reflect.Indirect(addrVal), val) + return result + + default: + return fmt.Errorf("'%s' expected a map, got '%s'", name, dataVal.Kind()) + } +} + +func (d *Decoder) decodeStructFromMap(name string, dataVal, val reflect.Value) error { + dataValType := dataVal.Type() + if kind := dataValType.Key().Kind(); kind != reflect.String && kind != reflect.Interface { + return fmt.Errorf( + "'%s' needs a map with string keys, has '%s' keys", + name, dataValType.Key().Kind()) + } + + dataValKeys := make(map[reflect.Value]struct{}) + dataValKeysUnused := make(map[interface{}]struct{}) + for _, dataValKey := range dataVal.MapKeys() { + dataValKeys[dataValKey] = struct{}{} + dataValKeysUnused[dataValKey.Interface()] = struct{}{} + } + + targetValKeysUnused := make(map[interface{}]struct{}) + errors := make([]string, 0) + + // This slice will keep track of all the structs we'll be decoding. + // There can be more than one struct if there are embedded structs + // that are squashed. + structs := make([]reflect.Value, 1, 5) + structs[0] = val + + // Compile the list of all the fields that we're going to be decoding + // from all the structs. + type field struct { + field reflect.StructField + val reflect.Value + } + + // remainField is set to a valid field set with the "remain" tag if + // we are keeping track of remaining values. + var remainField *field + + fields := []field{} + for len(structs) > 0 { + structVal := structs[0] + structs = structs[1:] + + structType := structVal.Type() + + for i := 0; i < structType.NumField(); i++ { + fieldType := structType.Field(i) + fieldVal := structVal.Field(i) + if fieldVal.Kind() == reflect.Ptr && fieldVal.Elem().Kind() == reflect.Struct { + // Handle embedded struct pointers as embedded structs. + fieldVal = fieldVal.Elem() + } + + // If "squash" is specified in the tag, we squash the field down. + squash := d.config.Squash && fieldVal.Kind() == reflect.Struct && fieldType.Anonymous + remain := false + + // We always parse the tags cause we're looking for other tags too + tagParts := strings.Split(fieldType.Tag.Get(d.config.TagName), ",") + for _, tag := range tagParts[1:] { + if tag == "squash" { + squash = true + break + } + + if tag == "remain" { + remain = true + break + } + } + + if squash { + if fieldVal.Kind() != reflect.Struct { + errors = appendErrors(errors, + fmt.Errorf("%s: unsupported type for squash: %s", fieldType.Name, fieldVal.Kind())) + } else { + structs = append(structs, fieldVal) + } + continue + } + + // Build our field + if remain { + remainField = &field{fieldType, fieldVal} + } else { + // Normal struct field, store it away + fields = append(fields, field{fieldType, fieldVal}) + } + } + } + + // for fieldType, field := range fields { + for _, f := range fields { + field, fieldValue := f.field, f.val + fieldName := field.Name + + tagValue := field.Tag.Get(d.config.TagName) + tagValue = strings.SplitN(tagValue, ",", 2)[0] + if tagValue != "" { + fieldName = tagValue + } + + rawMapKey := reflect.ValueOf(fieldName) + rawMapVal := dataVal.MapIndex(rawMapKey) + if !rawMapVal.IsValid() { + // Do a slower search by iterating over each key and + // doing case-insensitive search. + for dataValKey := range dataValKeys { + mK, ok := dataValKey.Interface().(string) + if !ok { + // Not a string key + continue + } + + if d.config.MatchName(mK, fieldName) { + rawMapKey = dataValKey + rawMapVal = dataVal.MapIndex(dataValKey) + break + } + } + + if !rawMapVal.IsValid() { + // There was no matching key in the map for the value in + // the struct. Remember it for potential errors and metadata. + targetValKeysUnused[fieldName] = struct{}{} + continue + } + } + + if !fieldValue.IsValid() { + // This should never happen + panic("field is not valid") + } + + // If we can't set the field, then it is unexported or something, + // and we just continue onwards. + if !fieldValue.CanSet() { + continue + } + + // Delete the key we're using from the unused map so we stop tracking + delete(dataValKeysUnused, rawMapKey.Interface()) + + // If the name is empty string, then we're at the root, and we + // don't dot-join the fields. + if name != "" { + fieldName = name + "." + fieldName + } + + if err := d.decode(fieldName, rawMapVal.Interface(), fieldValue); err != nil { + errors = appendErrors(errors, err) + } + } + + // If we have a "remain"-tagged field and we have unused keys then + // we put the unused keys directly into the remain field. + if remainField != nil && len(dataValKeysUnused) > 0 { + // Build a map of only the unused values + remain := map[interface{}]interface{}{} + for key := range dataValKeysUnused { + remain[key] = dataVal.MapIndex(reflect.ValueOf(key)).Interface() + } + + // Decode it as-if we were just decoding this map onto our map. + if err := d.decodeMap(name, remain, remainField.val); err != nil { + errors = appendErrors(errors, err) + } + + // Set the map to nil so we have none so that the next check will + // not error (ErrorUnused) + dataValKeysUnused = nil + } + + if d.config.ErrorUnused && len(dataValKeysUnused) > 0 { + keys := make([]string, 0, len(dataValKeysUnused)) + for rawKey := range dataValKeysUnused { + keys = append(keys, rawKey.(string)) + } + sort.Strings(keys) + + err := fmt.Errorf("'%s' has invalid keys: %s", name, strings.Join(keys, ", ")) + errors = appendErrors(errors, err) + } + + if d.config.ErrorUnset && len(targetValKeysUnused) > 0 { + keys := make([]string, 0, len(targetValKeysUnused)) + for rawKey := range targetValKeysUnused { + keys = append(keys, rawKey.(string)) + } + sort.Strings(keys) + + err := fmt.Errorf("'%s' has unset fields: %s", name, strings.Join(keys, ", ")) + errors = appendErrors(errors, err) + } + + if len(errors) > 0 { + return &Error{errors} + } + + // Add the unused keys to the list of unused keys if we're tracking metadata + if d.config.Metadata != nil { + for rawKey := range dataValKeysUnused { + key := rawKey.(string) + if name != "" { + key = name + "." + key + } + + d.config.Metadata.Unused = append(d.config.Metadata.Unused, key) + } + for rawKey := range targetValKeysUnused { + key := rawKey.(string) + if name != "" { + key = name + "." + key + } + + d.config.Metadata.Unset = append(d.config.Metadata.Unset, key) + } + } + + return nil +} + +func isEmptyValue(v reflect.Value) bool { + switch getKind(v) { + case reflect.Array, reflect.Map, reflect.Slice, reflect.String: + return v.Len() == 0 + case reflect.Bool: + return !v.Bool() + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return v.Int() == 0 + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + return v.Uint() == 0 + case reflect.Float32, reflect.Float64: + return v.Float() == 0 + case reflect.Interface, reflect.Ptr: + return v.IsNil() + } + return false +} + +func getKind(val reflect.Value) reflect.Kind { + kind := val.Kind() + + switch { + case kind >= reflect.Int && kind <= reflect.Int64: + return reflect.Int + case kind >= reflect.Uint && kind <= reflect.Uint64: + return reflect.Uint + case kind >= reflect.Float32 && kind <= reflect.Float64: + return reflect.Float32 + default: + return kind + } +} + +func isStructTypeConvertibleToMap(typ reflect.Type, checkMapstructureTags bool, tagName string) bool { + for i := 0; i < typ.NumField(); i++ { + f := typ.Field(i) + if f.PkgPath == "" && !checkMapstructureTags { // check for unexported fields + return true + } + if checkMapstructureTags && f.Tag.Get(tagName) != "" { // check for mapstructure tags inside + return true + } + } + return false +} + +func dereferencePtrToStructIfNeeded(v reflect.Value, tagName string) reflect.Value { + if v.Kind() != reflect.Ptr || v.Elem().Kind() != reflect.Struct { + return v + } + deref := v.Elem() + derefT := deref.Type() + if isStructTypeConvertibleToMap(derefT, true, tagName) { + return deref + } + return v +} diff --git a/internal/pkg/mapstructure/mapstructure_benchmark_test.go b/internal/pkg/mapstructure/mapstructure_benchmark_test.go new file mode 100644 index 0000000..b9bde7e --- /dev/null +++ b/internal/pkg/mapstructure/mapstructure_benchmark_test.go @@ -0,0 +1,285 @@ +package mapstructure + +import ( + "encoding/json" + "testing" +) + +type Person struct { + Name string + Age int + Emails []string + Extra map[string]string +} + +func Benchmark_Decode(b *testing.B) { + input := map[string]interface{}{ + "name": "Mitchell", + "age": 91, + "emails": []string{"one", "two", "three"}, + "extra": map[string]string{ + "twitter": "mitchellh", + }, + } + + var result Person + for i := 0; i < b.N; i++ { + Decode(input, &result) + } +} + +// decodeViaJSON takes the map data and passes it through encoding/json to convert it into the +// given Go native structure pointed to by v. v must be a pointer to a struct. +func decodeViaJSON(data interface{}, v interface{}) error { + // Perform the task by simply marshalling the input into JSON, + // then unmarshalling it into target native Go struct. + b, err := json.Marshal(data) + if err != nil { + return err + } + return json.Unmarshal(b, v) +} + +func Benchmark_DecodeViaJSON(b *testing.B) { + input := map[string]interface{}{ + "name": "Mitchell", + "age": 91, + "emails": []string{"one", "two", "three"}, + "extra": map[string]string{ + "twitter": "mitchellh", + }, + } + + var result Person + for i := 0; i < b.N; i++ { + decodeViaJSON(input, &result) + } +} + +func Benchmark_JSONUnmarshal(b *testing.B) { + input := map[string]interface{}{ + "name": "Mitchell", + "age": 91, + "emails": []string{"one", "two", "three"}, + "extra": map[string]string{ + "twitter": "mitchellh", + }, + } + + inputB, err := json.Marshal(input) + if err != nil { + b.Fatal("Failed to marshal test input:", err) + } + + var result Person + for i := 0; i < b.N; i++ { + json.Unmarshal(inputB, &result) + } +} + +func Benchmark_DecodeBasic(b *testing.B) { + input := map[string]interface{}{ + "vstring": "foo", + "vint": 42, + "Vuint": 42, + "vbool": true, + "Vfloat": 42.42, + "vsilent": true, + "vdata": 42, + "vjsonInt": json.Number("1234"), + "vjsonFloat": json.Number("1234.5"), + "vjsonNumber": json.Number("1234.5"), + } + + for i := 0; i < b.N; i++ { + var result Basic + Decode(input, &result) + } +} + +func Benchmark_DecodeEmbedded(b *testing.B) { + input := map[string]interface{}{ + "vstring": "foo", + "Basic": map[string]interface{}{ + "vstring": "innerfoo", + }, + "vunique": "bar", + } + + var result Embedded + for i := 0; i < b.N; i++ { + Decode(input, &result) + } +} + +func Benchmark_DecodeTypeConversion(b *testing.B) { + input := map[string]interface{}{ + "IntToFloat": 42, + "IntToUint": 42, + "IntToBool": 1, + "IntToString": 42, + "UintToInt": 42, + "UintToFloat": 42, + "UintToBool": 42, + "UintToString": 42, + "BoolToInt": true, + "BoolToUint": true, + "BoolToFloat": true, + "BoolToString": true, + "FloatToInt": 42.42, + "FloatToUint": 42.42, + "FloatToBool": 42.42, + "FloatToString": 42.42, + "StringToInt": "42", + "StringToUint": "42", + "StringToBool": "1", + "StringToFloat": "42.42", + "SliceToMap": []interface{}{}, + "MapToSlice": map[string]interface{}{}, + } + + var resultStrict TypeConversionResult + for i := 0; i < b.N; i++ { + Decode(input, &resultStrict) + } +} + +func Benchmark_DecodeMap(b *testing.B) { + input := map[string]interface{}{ + "vfoo": "foo", + "vother": map[interface{}]interface{}{ + "foo": "foo", + "bar": "bar", + }, + } + + var result Map + for i := 0; i < b.N; i++ { + Decode(input, &result) + } +} + +func Benchmark_DecodeMapOfStruct(b *testing.B) { + input := map[string]interface{}{ + "value": map[string]interface{}{ + "foo": map[string]string{"vstring": "one"}, + "bar": map[string]string{"vstring": "two"}, + }, + } + + var result MapOfStruct + for i := 0; i < b.N; i++ { + Decode(input, &result) + } +} + +func Benchmark_DecodeSlice(b *testing.B) { + input := map[string]interface{}{ + "vfoo": "foo", + "vbar": []string{"foo", "bar", "baz"}, + } + + var result Slice + for i := 0; i < b.N; i++ { + Decode(input, &result) + } +} + +func Benchmark_DecodeSliceOfStruct(b *testing.B) { + input := map[string]interface{}{ + "value": []map[string]interface{}{ + {"vstring": "one"}, + {"vstring": "two"}, + }, + } + + var result SliceOfStruct + for i := 0; i < b.N; i++ { + Decode(input, &result) + } +} + +func Benchmark_DecodeWeaklyTypedInput(b *testing.B) { + // This input can come from anywhere, but typically comes from + // something like decoding JSON, generated by a weakly typed language + // such as PHP. + input := map[string]interface{}{ + "name": 123, // number => string + "age": "42", // string => number + "emails": map[string]interface{}{}, // empty map => empty array + } + + var result Person + config := &DecoderConfig{ + WeaklyTypedInput: true, + Result: &result, + } + + decoder, err := NewDecoder(config) + if err != nil { + panic(err) + } + + for i := 0; i < b.N; i++ { + decoder.Decode(input) + } +} + +func Benchmark_DecodeMetadata(b *testing.B) { + input := map[string]interface{}{ + "name": "Mitchell", + "age": 91, + "email": "foo@bar.com", + } + + var md Metadata + var result Person + config := &DecoderConfig{ + Metadata: &md, + Result: &result, + } + + decoder, err := NewDecoder(config) + if err != nil { + panic(err) + } + + for i := 0; i < b.N; i++ { + decoder.Decode(input) + } +} + +func Benchmark_DecodeMetadataEmbedded(b *testing.B) { + input := map[string]interface{}{ + "vstring": "foo", + "vunique": "bar", + } + + var md Metadata + var result EmbeddedSquash + config := &DecoderConfig{ + Metadata: &md, + Result: &result, + } + + decoder, err := NewDecoder(config) + if err != nil { + b.Fatalf("jderr: %s", err) + } + + for i := 0; i < b.N; i++ { + decoder.Decode(input) + } +} + +func Benchmark_DecodeTagged(b *testing.B) { + input := map[string]interface{}{ + "foo": "bar", + "bar": "value", + } + + var result Tagged + for i := 0; i < b.N; i++ { + Decode(input, &result) + } +} diff --git a/internal/pkg/mapstructure/mapstructure_bugs_test.go b/internal/pkg/mapstructure/mapstructure_bugs_test.go new file mode 100644 index 0000000..31fa5cd --- /dev/null +++ b/internal/pkg/mapstructure/mapstructure_bugs_test.go @@ -0,0 +1,627 @@ +package mapstructure + +import ( + "reflect" + "testing" + "time" +) + +// GH-1, GH-10, GH-96 +func TestDecode_NilValue(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + in interface{} + target interface{} + out interface{} + metaKeys []string + metaUnused []string + }{ + { + "all nil", + &map[string]interface{}{ + "vfoo": nil, + "vother": nil, + }, + &Map{Vfoo: "foo", Vother: map[string]string{"foo": "bar"}}, + &Map{Vfoo: "", Vother: nil}, + []string{"Vfoo", "Vother"}, + []string{}, + }, + { + "partial nil", + &map[string]interface{}{ + "vfoo": "baz", + "vother": nil, + }, + &Map{Vfoo: "foo", Vother: map[string]string{"foo": "bar"}}, + &Map{Vfoo: "baz", Vother: nil}, + []string{"Vfoo", "Vother"}, + []string{}, + }, + { + "partial decode", + &map[string]interface{}{ + "vother": nil, + }, + &Map{Vfoo: "foo", Vother: map[string]string{"foo": "bar"}}, + &Map{Vfoo: "foo", Vother: nil}, + []string{"Vother"}, + []string{}, + }, + { + "unused values", + &map[string]interface{}{ + "vbar": "bar", + "vfoo": nil, + "vother": nil, + }, + &Map{Vfoo: "foo", Vother: map[string]string{"foo": "bar"}}, + &Map{Vfoo: "", Vother: nil}, + []string{"Vfoo", "Vother"}, + []string{"vbar"}, + }, + { + "map interface all nil", + &map[interface{}]interface{}{ + "vfoo": nil, + "vother": nil, + }, + &Map{Vfoo: "foo", Vother: map[string]string{"foo": "bar"}}, + &Map{Vfoo: "", Vother: nil}, + []string{"Vfoo", "Vother"}, + []string{}, + }, + { + "map interface partial nil", + &map[interface{}]interface{}{ + "vfoo": "baz", + "vother": nil, + }, + &Map{Vfoo: "foo", Vother: map[string]string{"foo": "bar"}}, + &Map{Vfoo: "baz", Vother: nil}, + []string{"Vfoo", "Vother"}, + []string{}, + }, + { + "map interface partial decode", + &map[interface{}]interface{}{ + "vother": nil, + }, + &Map{Vfoo: "foo", Vother: map[string]string{"foo": "bar"}}, + &Map{Vfoo: "foo", Vother: nil}, + []string{"Vother"}, + []string{}, + }, + { + "map interface unused values", + &map[interface{}]interface{}{ + "vbar": "bar", + "vfoo": nil, + "vother": nil, + }, + &Map{Vfoo: "foo", Vother: map[string]string{"foo": "bar"}}, + &Map{Vfoo: "", Vother: nil}, + []string{"Vfoo", "Vother"}, + []string{"vbar"}, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + config := &DecoderConfig{ + Metadata: new(Metadata), + Result: tc.target, + ZeroFields: true, + } + + decoder, err := NewDecoder(config) + if err != nil { + t.Fatalf("should not error: %s", err) + } + + err = decoder.Decode(tc.in) + if err != nil { + t.Fatalf("should not error: %s", err) + } + + if !reflect.DeepEqual(tc.out, tc.target) { + t.Fatalf("%q: TestDecode_NilValue() expected: %#v, got: %#v", tc.name, tc.out, tc.target) + } + + if !reflect.DeepEqual(tc.metaKeys, config.Metadata.Keys) { + t.Fatalf("%q: Metadata.Keys mismatch expected: %#v, got: %#v", tc.name, tc.metaKeys, config.Metadata.Keys) + } + + if !reflect.DeepEqual(tc.metaUnused, config.Metadata.Unused) { + t.Fatalf("%q: Metadata.Unused mismatch expected: %#v, got: %#v", tc.name, tc.metaUnused, config.Metadata.Unused) + } + }) + } +} + +// #48 +func TestNestedTypePointerWithDefaults(t *testing.T) { + t.Parallel() + + input := map[string]interface{}{ + "vfoo": "foo", + "vbar": map[string]interface{}{ + "vstring": "foo", + "vint": 42, + "vbool": true, + }, + } + + result := NestedPointer{ + Vbar: &Basic{ + Vuint: 42, + }, + } + err := Decode(input, &result) + if err != nil { + t.Fatalf("got an jderr: %s", err.Error()) + } + + if result.Vfoo != "foo" { + t.Errorf("vfoo value should be 'foo': %#v", result.Vfoo) + } + + if result.Vbar.Vstring != "foo" { + t.Errorf("vstring value should be 'foo': %#v", result.Vbar.Vstring) + } + + if result.Vbar.Vint != 42 { + t.Errorf("vint value should be 42: %#v", result.Vbar.Vint) + } + + if result.Vbar.Vbool != true { + t.Errorf("vbool value should be true: %#v", result.Vbar.Vbool) + } + + if result.Vbar.Vextra != "" { + t.Errorf("vextra value should be empty: %#v", result.Vbar.Vextra) + } + + // this is the error + if result.Vbar.Vuint != 42 { + t.Errorf("vuint value should be 42: %#v", result.Vbar.Vuint) + } + +} + +type NestedSlice struct { + Vfoo string + Vbars []Basic + Vempty []Basic +} + +// #48 +func TestNestedTypeSliceWithDefaults(t *testing.T) { + t.Parallel() + + input := map[string]interface{}{ + "vfoo": "foo", + "vbars": []map[string]interface{}{ + {"vstring": "foo", "vint": 42, "vbool": true}, + {"vint": 42, "vbool": true}, + }, + "vempty": []map[string]interface{}{ + {"vstring": "foo", "vint": 42, "vbool": true}, + {"vint": 42, "vbool": true}, + }, + } + + result := NestedSlice{ + Vbars: []Basic{ + {Vuint: 42}, + {Vstring: "foo"}, + }, + } + err := Decode(input, &result) + if err != nil { + t.Fatalf("got an jderr: %s", err.Error()) + } + + if result.Vfoo != "foo" { + t.Errorf("vfoo value should be 'foo': %#v", result.Vfoo) + } + + if result.Vbars[0].Vstring != "foo" { + t.Errorf("vstring value should be 'foo': %#v", result.Vbars[0].Vstring) + } + // this is the error + if result.Vbars[0].Vuint != 42 { + t.Errorf("vuint value should be 42: %#v", result.Vbars[0].Vuint) + } +} + +// #48 workaround +func TestNestedTypeWithDefaults(t *testing.T) { + t.Parallel() + + input := map[string]interface{}{ + "vfoo": "foo", + "vbar": map[string]interface{}{ + "vstring": "foo", + "vint": 42, + "vbool": true, + }, + } + + result := Nested{ + Vbar: Basic{ + Vuint: 42, + }, + } + err := Decode(input, &result) + if err != nil { + t.Fatalf("got an jderr: %s", err.Error()) + } + + if result.Vfoo != "foo" { + t.Errorf("vfoo value should be 'foo': %#v", result.Vfoo) + } + + if result.Vbar.Vstring != "foo" { + t.Errorf("vstring value should be 'foo': %#v", result.Vbar.Vstring) + } + + if result.Vbar.Vint != 42 { + t.Errorf("vint value should be 42: %#v", result.Vbar.Vint) + } + + if result.Vbar.Vbool != true { + t.Errorf("vbool value should be true: %#v", result.Vbar.Vbool) + } + + if result.Vbar.Vextra != "" { + t.Errorf("vextra value should be empty: %#v", result.Vbar.Vextra) + } + + // this is the error + if result.Vbar.Vuint != 42 { + t.Errorf("vuint value should be 42: %#v", result.Vbar.Vuint) + } + +} + +// #67 panic() on extending slices (decodeSlice with disabled ZeroValues) +func TestDecodeSliceToEmptySliceWOZeroing(t *testing.T) { + t.Parallel() + + type TestStruct struct { + Vfoo []string + } + + decode := func(m interface{}, rawVal interface{}) error { + config := &DecoderConfig{ + Metadata: nil, + Result: rawVal, + ZeroFields: false, + } + + decoder, err := NewDecoder(config) + if err != nil { + return err + } + + return decoder.Decode(m) + } + + { + input := map[string]interface{}{ + "vfoo": []string{"1"}, + } + + result := &TestStruct{} + + err := decode(input, &result) + if err != nil { + t.Fatalf("got an jderr: %s", err.Error()) + } + } + + { + input := map[string]interface{}{ + "vfoo": []string{"1"}, + } + + result := &TestStruct{ + Vfoo: []string{}, + } + + err := decode(input, &result) + if err != nil { + t.Fatalf("got an jderr: %s", err.Error()) + } + } + + { + input := map[string]interface{}{ + "vfoo": []string{"2", "3"}, + } + + result := &TestStruct{ + Vfoo: []string{"1"}, + } + + err := decode(input, &result) + if err != nil { + t.Fatalf("got an jderr: %s", err.Error()) + } + } +} + +// #70 +func TestNextSquashMapstructure(t *testing.T) { + data := &struct { + Level1 struct { + Level2 struct { + Foo string + } `mapstructure:",squash"` + } `mapstructure:",squash"` + }{} + err := Decode(map[interface{}]interface{}{"foo": "baz"}, &data) + if err != nil { + t.Fatalf("should not error: %s", err) + } + if data.Level1.Level2.Foo != "baz" { + t.Fatal("value should be baz") + } +} + +type ImplementsInterfacePointerReceiver struct { + Name string +} + +func (i *ImplementsInterfacePointerReceiver) DoStuff() {} + +type ImplementsInterfaceValueReceiver string + +func (i ImplementsInterfaceValueReceiver) DoStuff() {} + +// GH-140 Type error when using DecodeHook to decode into interface +func TestDecode_DecodeHookInterface(t *testing.T) { + t.Parallel() + + type Interface interface { + DoStuff() + } + type DecodeIntoInterface struct { + Test Interface + } + + testData := map[string]string{"test": "test"} + + stringToPointerInterfaceDecodeHook := func(from, to reflect.Type, data interface{}) (interface{}, error) { + if from.Kind() != reflect.String { + return data, nil + } + + if to != reflect.TypeOf((*Interface)(nil)).Elem() { + return data, nil + } + // Ensure interface is satisfied + var impl Interface = &ImplementsInterfacePointerReceiver{data.(string)} + return impl, nil + } + + stringToValueInterfaceDecodeHook := func(from, to reflect.Type, data interface{}) (interface{}, error) { + if from.Kind() != reflect.String { + return data, nil + } + + if to != reflect.TypeOf((*Interface)(nil)).Elem() { + return data, nil + } + // Ensure interface is satisfied + var impl Interface = ImplementsInterfaceValueReceiver(data.(string)) + return impl, nil + } + + { + decodeInto := new(DecodeIntoInterface) + + decoder, _ := NewDecoder(&DecoderConfig{ + DecodeHook: stringToPointerInterfaceDecodeHook, + Result: decodeInto, + }) + + err := decoder.Decode(testData) + if err != nil { + t.Fatalf("Decode returned error: %s", err) + } + + expected := &ImplementsInterfacePointerReceiver{"test"} + if !reflect.DeepEqual(decodeInto.Test, expected) { + t.Fatalf("expected: %#v (%T), got: %#v (%T)", decodeInto.Test, decodeInto.Test, expected, expected) + } + } + + { + decodeInto := new(DecodeIntoInterface) + + decoder, _ := NewDecoder(&DecoderConfig{ + DecodeHook: stringToValueInterfaceDecodeHook, + Result: decodeInto, + }) + + err := decoder.Decode(testData) + if err != nil { + t.Fatalf("Decode returned error: %s", err) + } + + expected := ImplementsInterfaceValueReceiver("test") + if !reflect.DeepEqual(decodeInto.Test, expected) { + t.Fatalf("expected: %#v (%T), got: %#v (%T)", decodeInto.Test, decodeInto.Test, expected, expected) + } + } +} + +// #103 Check for data type before trying to access its composants prevent a panic error +// in decodeSlice +func TestDecodeBadDataTypeInSlice(t *testing.T) { + t.Parallel() + + input := map[string]interface{}{ + "Toto": "titi", + } + result := []struct { + Toto string + }{} + + if err := Decode(input, &result); err == nil { + t.Error("An error was expected, got nil") + } +} + +// #202 Ensure that intermediate maps in the struct -> struct decode process are settable +// and not just the elements within them. +func TestDecodeIntermediateMapsSettable(t *testing.T) { + type Timestamp struct { + Seconds int64 + Nanos int32 + } + + type TsWrapper struct { + Timestamp *Timestamp + } + + type TimeWrapper struct { + Timestamp time.Time + } + + input := TimeWrapper{ + Timestamp: time.Unix(123456789, 987654), + } + + expected := TsWrapper{ + Timestamp: &Timestamp{ + Seconds: 123456789, + Nanos: 987654, + }, + } + + timePtrType := reflect.TypeOf((*time.Time)(nil)) + mapStrInfType := reflect.TypeOf((map[string]interface{})(nil)) + + var actual TsWrapper + decoder, err := NewDecoder(&DecoderConfig{ + Result: &actual, + DecodeHook: func(from, to reflect.Type, data interface{}) (interface{}, error) { + if from == timePtrType && to == mapStrInfType { + ts := data.(*time.Time) + nanos := ts.UnixNano() + + seconds := nanos / 1000000000 + nanos = nanos % 1000000000 + + return &map[string]interface{}{ + "Seconds": seconds, + "Nanos": int32(nanos), + }, nil + } + return data, nil + }, + }) + + if err != nil { + t.Fatalf("failed to create decoder: %v", err) + } + + if err := decoder.Decode(&input); err != nil { + t.Fatalf("failed to decode input: %v", err) + } + + if !reflect.DeepEqual(expected, actual) { + t.Fatalf("expected: %#[1]v (%[1]T), got: %#[2]v (%[2]T)", expected, actual) + } +} + +// GH-206: decodeInt throws an error for an empty string +func TestDecode_weakEmptyStringToInt(t *testing.T) { + input := map[string]interface{}{ + "StringToInt": "", + "StringToUint": "", + "StringToBool": "", + "StringToFloat": "", + } + + expectedResultWeak := TypeConversionResult{ + StringToInt: 0, + StringToUint: 0, + StringToBool: false, + StringToFloat: 0, + } + + // Test weak type conversion + var resultWeak TypeConversionResult + err := WeakDecode(input, &resultWeak) + if err != nil { + t.Fatalf("got an jderr: %s", err) + } + + if !reflect.DeepEqual(resultWeak, expectedResultWeak) { + t.Errorf("expected \n%#v, got: \n%#v", expectedResultWeak, resultWeak) + } +} + +// GH-228: Squash cause *time.Time set to zero +func TestMapSquash(t *testing.T) { + type AA struct { + T *time.Time + } + type A struct { + AA + } + + v := time.Now() + in := &AA{ + T: &v, + } + out := &A{} + d, err := NewDecoder(&DecoderConfig{ + Squash: true, + Result: out, + }) + if err != nil { + t.Fatalf("jderr: %s", err) + } + if err := d.Decode(in); err != nil { + t.Fatalf("jderr: %s", err) + } + + // these failed + if !v.Equal(*out.T) { + t.Fatal("expected equal") + } + if out.T.IsZero() { + t.Fatal("expected false") + } +} + +// GH-238: Empty key name when decoding map from struct with only omitempty flag +func TestMapOmitEmptyWithEmptyFieldnameInTag(t *testing.T) { + type Struct struct { + Username string `mapstructure:",omitempty"` + Age int `mapstructure:",omitempty"` + } + + s := Struct{ + Username: "Joe", + } + var m map[string]interface{} + + if err := Decode(s, &m); err != nil { + t.Fatal(err) + } + + if len(m) != 1 { + t.Fatalf("fail: %#v", m) + } + if m["Username"] != "Joe" { + t.Fatalf("fail: %#v", m) + } +} diff --git a/internal/pkg/mapstructure/mapstructure_examples_test.go b/internal/pkg/mapstructure/mapstructure_examples_test.go new file mode 100644 index 0000000..2413b69 --- /dev/null +++ b/internal/pkg/mapstructure/mapstructure_examples_test.go @@ -0,0 +1,256 @@ +package mapstructure + +import ( + "fmt" +) + +func ExampleDecode() { + type Person struct { + Name string + Age int + Emails []string + Extra map[string]string + } + + // This input can come from anywhere, but typically comes from + // something like decoding JSON where we're not quite sure of the + // struct initially. + input := map[string]interface{}{ + "name": "Mitchell", + "age": 91, + "emails": []string{"one", "two", "three"}, + "extra": map[string]string{ + "twitter": "mitchellh", + }, + } + + var result Person + err := Decode(input, &result) + if err != nil { + panic(err) + } + + fmt.Printf("%#v", result) + // Output: + // mapstructure.Person{Name:"Mitchell", Age:91, Emails:[]string{"one", "two", "three"}, Extra:map[string]string{"twitter":"mitchellh"}} +} + +func ExampleDecode_errors() { + type Person struct { + Name string + Age int + Emails []string + Extra map[string]string + } + + // This input can come from anywhere, but typically comes from + // something like decoding JSON where we're not quite sure of the + // struct initially. + input := map[string]interface{}{ + "name": 123, + "age": "bad value", + "emails": []int{1, 2, 3}, + } + + var result Person + err := Decode(input, &result) + if err == nil { + panic("should have an error") + } + + fmt.Println(err.Error()) + // Output: + // 5 error(s) decoding: + // + // * 'Age' expected type 'int', got unconvertible type 'string', value: 'bad value' + // * 'Emails[0]' expected type 'string', got unconvertible type 'int', value: '1' + // * 'Emails[1]' expected type 'string', got unconvertible type 'int', value: '2' + // * 'Emails[2]' expected type 'string', got unconvertible type 'int', value: '3' + // * 'Name' expected type 'string', got unconvertible type 'int', value: '123' +} + +func ExampleDecode_metadata() { + type Person struct { + Name string + Age int + } + + // This input can come from anywhere, but typically comes from + // something like decoding JSON where we're not quite sure of the + // struct initially. + input := map[string]interface{}{ + "name": "Mitchell", + "age": 91, + "email": "foo@bar.com", + } + + // For metadata, we make a more advanced DecoderConfig so we can + // more finely configure the decoder that is used. In this case, we + // just tell the decoder we want to track metadata. + var md Metadata + var result Person + config := &DecoderConfig{ + Metadata: &md, + Result: &result, + } + + decoder, err := NewDecoder(config) + if err != nil { + panic(err) + } + + if err := decoder.Decode(input); err != nil { + panic(err) + } + + fmt.Printf("Unused keys: %#v", md.Unused) + // Output: + // Unused keys: []string{"email"} +} + +func ExampleDecode_weaklyTypedInput() { + type Person struct { + Name string + Age int + Emails []string + } + + // This input can come from anywhere, but typically comes from + // something like decoding JSON, generated by a weakly typed language + // such as PHP. + input := map[string]interface{}{ + "name": 123, // number => string + "age": "42", // string => number + "emails": map[string]interface{}{}, // empty map => empty array + } + + var result Person + config := &DecoderConfig{ + WeaklyTypedInput: true, + Result: &result, + } + + decoder, err := NewDecoder(config) + if err != nil { + panic(err) + } + + err = decoder.Decode(input) + if err != nil { + panic(err) + } + + fmt.Printf("%#v", result) + // Output: mapstructure.Person{Name:"123", Age:42, Emails:[]string{}} +} + +func ExampleDecode_tags() { + // Note that the mapstructure tags defined in the struct type + // can indicate which fields the values are mapped to. + type Person struct { + Name string `mapstructure:"person_name"` + Age int `mapstructure:"person_age"` + } + + input := map[string]interface{}{ + "person_name": "Mitchell", + "person_age": 91, + } + + var result Person + err := Decode(input, &result) + if err != nil { + panic(err) + } + + fmt.Printf("%#v", result) + // Output: + // mapstructure.Person{Name:"Mitchell", Age:91} +} + +func ExampleDecode_embeddedStruct() { + // Squashing multiple embedded structs is allowed using the squash tag. + // This is demonstrated by creating a composite struct of multiple types + // and decoding into it. In this case, a person can carry with it both + // a Family and a Location, as well as their own FirstName. + type Family struct { + LastName string + } + type Location struct { + City string + } + type Person struct { + Family `mapstructure:",squash"` + Location `mapstructure:",squash"` + FirstName string + } + + input := map[string]interface{}{ + "FirstName": "Mitchell", + "LastName": "Hashimoto", + "City": "San Francisco", + } + + var result Person + err := Decode(input, &result) + if err != nil { + panic(err) + } + + fmt.Printf("%s %s, %s", result.FirstName, result.LastName, result.City) + // Output: + // Mitchell Hashimoto, San Francisco +} + +func ExampleDecode_remainingData() { + // Note that the mapstructure tags defined in the struct type + // can indicate which fields the values are mapped to. + type Person struct { + Name string + Age int + Other map[string]interface{} `mapstructure:",remain"` + } + + input := map[string]interface{}{ + "name": "Mitchell", + "age": 91, + "email": "mitchell@example.com", + } + + var result Person + err := Decode(input, &result) + if err != nil { + panic(err) + } + + fmt.Printf("%#v", result) + // Output: + // mapstructure.Person{Name:"Mitchell", Age:91, Other:map[string]interface {}{"email":"mitchell@example.com"}} +} + +func ExampleDecode_omitempty() { + // Add omitempty annotation to avoid map keys for empty values + type Family struct { + LastName string + } + type Location struct { + City string + } + type Person struct { + *Family `mapstructure:",omitempty"` + *Location `mapstructure:",omitempty"` + Age int + FirstName string + } + + result := &map[string]interface{}{} + input := Person{FirstName: "Somebody"} + err := Decode(input, &result) + if err != nil { + panic(err) + } + + fmt.Printf("%+v", result) + // Output: + // &map[Age:0 FirstName:Somebody] +} diff --git a/internal/pkg/mapstructure/mapstructure_ext_test.go b/internal/pkg/mapstructure/mapstructure_ext_test.go new file mode 100644 index 0000000..a646a51 --- /dev/null +++ b/internal/pkg/mapstructure/mapstructure_ext_test.go @@ -0,0 +1,58 @@ +package mapstructure + +import ( + "reflect" + "testing" +) + +func TestDecode_Ptr(t *testing.T) { + t.Parallel() + + type G struct { + Id int + Name string + } + + type X struct { + Id int + Name int + } + + type AG struct { + List []*G + } + + type AX struct { + List []*X + } + + g2 := &AG{ + List: []*G{ + { + Id: 11, + Name: "gg", + }, + }, + } + x2 := AX{} + + // 报错但还是会转换成功,转换后值为目标类型的 0 值 + err := Decode(g2, &x2) + + res := AX{ + List: []*X{ + { + Id: 11, + Name: 0, // 这个类型的 0 值 + }, + }, + } + + if err == nil { + t.Errorf("Decode_Ptr jderr should not be 'nil': %#v", err) + } + + if !reflect.DeepEqual(res, x2) { + t.Errorf("result should be %#v: got %#v", res, x2) + } +} diff --git a/internal/pkg/mapstructure/mapstructure_test.go b/internal/pkg/mapstructure/mapstructure_test.go new file mode 100644 index 0000000..17e609a --- /dev/null +++ b/internal/pkg/mapstructure/mapstructure_test.go @@ -0,0 +1,2763 @@ +package mapstructure + +import ( + "encoding/json" + "io" + "reflect" + "sort" + "strings" + "testing" + "time" +) + +type Basic struct { + Vstring string + Vint int + Vint8 int8 + Vint16 int16 + Vint32 int32 + Vint64 int64 + Vuint uint + Vbool bool + Vfloat float64 + Vextra string + vsilent bool + Vdata interface{} + VjsonInt int + VjsonUint uint + VjsonUint64 uint64 + VjsonFloat float64 + VjsonNumber json.Number +} + +type BasicPointer struct { + Vstring *string + Vint *int + Vuint *uint + Vbool *bool + Vfloat *float64 + Vextra *string + vsilent *bool + Vdata *interface{} + VjsonInt *int + VjsonFloat *float64 + VjsonNumber *json.Number +} + +type BasicSquash struct { + Test Basic `mapstructure:",squash"` +} + +type Embedded struct { + Basic + Vunique string +} + +type EmbeddedPointer struct { + *Basic + Vunique string +} + +type EmbeddedSquash struct { + Basic `mapstructure:",squash"` + Vunique string +} + +type EmbeddedPointerSquash struct { + *Basic `mapstructure:",squash"` + Vunique string +} + +type BasicMapStructure struct { + Vunique string `mapstructure:"vunique"` + Vtime *time.Time `mapstructure:"time"` +} + +type NestedPointerWithMapstructure struct { + Vbar *BasicMapStructure `mapstructure:"vbar"` +} + +type EmbeddedPointerSquashWithNestedMapstructure struct { + *NestedPointerWithMapstructure `mapstructure:",squash"` + Vunique string +} + +type EmbeddedAndNamed struct { + Basic + Named Basic + Vunique string +} + +type SliceAlias []string + +type EmbeddedSlice struct { + SliceAlias `mapstructure:"slice_alias"` + Vunique string +} + +type ArrayAlias [2]string + +type EmbeddedArray struct { + ArrayAlias `mapstructure:"array_alias"` + Vunique string +} + +type SquashOnNonStructType struct { + InvalidSquashType int `mapstructure:",squash"` +} + +type Map struct { + Vfoo string + Vother map[string]string +} + +type MapOfStruct struct { + Value map[string]Basic +} + +type Nested struct { + Vfoo string + Vbar Basic +} + +type NestedPointer struct { + Vfoo string + Vbar *Basic +} + +type NilInterface struct { + W io.Writer +} + +type NilPointer struct { + Value *string +} + +type Slice struct { + Vfoo string + Vbar []string +} + +type SliceOfAlias struct { + Vfoo string + Vbar SliceAlias +} + +type SliceOfStruct struct { + Value []Basic +} + +type SlicePointer struct { + Vbar *[]string +} + +type Array struct { + Vfoo string + Vbar [2]string +} + +type ArrayOfStruct struct { + Value [2]Basic +} + +type Func struct { + Foo func() string +} + +type Tagged struct { + Extra string `mapstructure:"bar,what,what"` + Value string `mapstructure:"foo"` +} + +type Remainder struct { + A string + Extra map[string]interface{} `mapstructure:",remain"` +} + +type StructWithOmitEmpty struct { + VisibleStringField string `mapstructure:"visible-string"` + OmitStringField string `mapstructure:"omittable-string,omitempty"` + VisibleIntField int `mapstructure:"visible-int"` + OmitIntField int `mapstructure:"omittable-int,omitempty"` + VisibleFloatField float64 `mapstructure:"visible-float"` + OmitFloatField float64 `mapstructure:"omittable-float,omitempty"` + VisibleSliceField []interface{} `mapstructure:"visible-slice"` + OmitSliceField []interface{} `mapstructure:"omittable-slice,omitempty"` + VisibleMapField map[string]interface{} `mapstructure:"visible-map"` + OmitMapField map[string]interface{} `mapstructure:"omittable-map,omitempty"` + NestedField *Nested `mapstructure:"visible-nested"` + OmitNestedField *Nested `mapstructure:"omittable-nested,omitempty"` +} + +type TypeConversionResult struct { + IntToFloat float32 + IntToUint uint + IntToBool bool + IntToString string + UintToInt int + UintToFloat float32 + UintToBool bool + UintToString string + BoolToInt int + BoolToUint uint + BoolToFloat float32 + BoolToString string + FloatToInt int + FloatToUint uint + FloatToBool bool + FloatToString string + SliceUint8ToString string + StringToSliceUint8 []byte + ArrayUint8ToString string + StringToInt int + StringToUint uint + StringToBool bool + StringToFloat float32 + StringToStrSlice []string + StringToIntSlice []int + StringToStrArray [1]string + StringToIntArray [1]int + SliceToMap map[string]interface{} + MapToSlice []interface{} + ArrayToMap map[string]interface{} + MapToArray [1]interface{} +} + +func TestBasicTypes(t *testing.T) { + t.Parallel() + + input := map[string]interface{}{ + "vstring": "foo", + "vint": 42, + "vint8": 42, + "vint16": 42, + "vint32": 42, + "vint64": 42, + "Vuint": 42, + "vbool": true, + "Vfloat": 42.42, + "vsilent": true, + "vdata": 42, + "vjsonInt": json.Number("1234"), + "vjsonUint": json.Number("1234"), + "vjsonUint64": json.Number("9223372036854775809"), // 2^63 + 1 + "vjsonFloat": json.Number("1234.5"), + "vjsonNumber": json.Number("1234.5"), + } + + var result Basic + err := Decode(input, &result) + if err != nil { + t.Errorf("got an jderr: %s", err.Error()) + t.FailNow() + } + + if result.Vstring != "foo" { + t.Errorf("vstring value should be 'foo': %#v", result.Vstring) + } + + if result.Vint != 42 { + t.Errorf("vint value should be 42: %#v", result.Vint) + } + if result.Vint8 != 42 { + t.Errorf("vint8 value should be 42: %#v", result.Vint) + } + if result.Vint16 != 42 { + t.Errorf("vint16 value should be 42: %#v", result.Vint) + } + if result.Vint32 != 42 { + t.Errorf("vint32 value should be 42: %#v", result.Vint) + } + if result.Vint64 != 42 { + t.Errorf("vint64 value should be 42: %#v", result.Vint) + } + + if result.Vuint != 42 { + t.Errorf("vuint value should be 42: %#v", result.Vuint) + } + + if result.Vbool != true { + t.Errorf("vbool value should be true: %#v", result.Vbool) + } + + if result.Vfloat != 42.42 { + t.Errorf("vfloat value should be 42.42: %#v", result.Vfloat) + } + + if result.Vextra != "" { + t.Errorf("vextra value should be empty: %#v", result.Vextra) + } + + if result.vsilent != false { + t.Error("vsilent should not be set, it is unexported") + } + + if result.Vdata != 42 { + t.Error("vdata should be valid") + } + + if result.VjsonInt != 1234 { + t.Errorf("vjsonint value should be 1234: %#v", result.VjsonInt) + } + + if result.VjsonUint != 1234 { + t.Errorf("vjsonuint value should be 1234: %#v", result.VjsonUint) + } + + if result.VjsonUint64 != 9223372036854775809 { + t.Errorf("vjsonuint64 value should be 9223372036854775809: %#v", result.VjsonUint64) + } + + if result.VjsonFloat != 1234.5 { + t.Errorf("vjsonfloat value should be 1234.5: %#v", result.VjsonFloat) + } + + if !reflect.DeepEqual(result.VjsonNumber, json.Number("1234.5")) { + t.Errorf("vjsonnumber value should be '1234.5': %T, %#v", result.VjsonNumber, result.VjsonNumber) + } +} + +func TestBasic_IntWithFloat(t *testing.T) { + t.Parallel() + + input := map[string]interface{}{ + "vint": float64(42), + } + + var result Basic + err := Decode(input, &result) + if err != nil { + t.Fatalf("got an jderr: %s", err) + } +} + +func TestBasic_Merge(t *testing.T) { + t.Parallel() + + input := map[string]interface{}{ + "vint": 42, + } + + var result Basic + result.Vuint = 100 + err := Decode(input, &result) + if err != nil { + t.Fatalf("got an jderr: %s", err) + } + + expected := Basic{ + Vint: 42, + Vuint: 100, + } + if !reflect.DeepEqual(result, expected) { + t.Fatalf("bad: %#v", result) + } +} + +// Test for issue #46. +func TestBasic_Struct(t *testing.T) { + t.Parallel() + + input := map[string]interface{}{ + "vdata": map[string]interface{}{ + "vstring": "foo", + }, + } + + var result, inner Basic + result.Vdata = &inner + err := Decode(input, &result) + if err != nil { + t.Fatalf("got an jderr: %s", err) + } + expected := Basic{ + Vdata: &Basic{ + Vstring: "foo", + }, + } + if !reflect.DeepEqual(result, expected) { + t.Fatalf("bad: %#v", result) + } +} + +func TestBasic_interfaceStruct(t *testing.T) { + t.Parallel() + + input := map[string]interface{}{ + "vstring": "foo", + } + + var iface interface{} = &Basic{} + err := Decode(input, &iface) + if err != nil { + t.Fatalf("got an jderr: %s", err) + } + + expected := &Basic{ + Vstring: "foo", + } + if !reflect.DeepEqual(iface, expected) { + t.Fatalf("bad: %#v", iface) + } +} + +// Issue 187 +func TestBasic_interfaceStructNonPtr(t *testing.T) { + t.Parallel() + + input := map[string]interface{}{ + "vstring": "foo", + } + + var iface interface{} = Basic{} + err := Decode(input, &iface) + if err != nil { + t.Fatalf("got an jderr: %s", err) + } + + expected := Basic{ + Vstring: "foo", + } + if !reflect.DeepEqual(iface, expected) { + t.Fatalf("bad: %#v", iface) + } +} + +func TestDecode_BasicSquash(t *testing.T) { + t.Parallel() + + input := map[string]interface{}{ + "vstring": "foo", + } + + var result BasicSquash + err := Decode(input, &result) + if err != nil { + t.Fatalf("got an jderr: %s", err.Error()) + } + + if result.Test.Vstring != "foo" { + t.Errorf("vstring value should be 'foo': %#v", result.Test.Vstring) + } +} + +func TestDecodeFrom_BasicSquash(t *testing.T) { + t.Parallel() + + var v interface{} + var ok bool + + input := BasicSquash{ + Test: Basic{ + Vstring: "foo", + }, + } + + var result map[string]interface{} + err := Decode(input, &result) + if err != nil { + t.Fatalf("got an jderr: %s", err.Error()) + } + + if _, ok = result["Test"]; ok { + t.Error("test should not be present in map") + } + + v, ok = result["Vstring"] + if !ok { + t.Error("vstring should be present in map") + } else if !reflect.DeepEqual(v, "foo") { + t.Errorf("vstring value should be 'foo': %#v", v) + } +} + +func TestDecode_Embedded(t *testing.T) { + t.Parallel() + + input := map[string]interface{}{ + "vstring": "foo", + "Basic": map[string]interface{}{ + "vstring": "innerfoo", + }, + "vunique": "bar", + } + + var result Embedded + err := Decode(input, &result) + if err != nil { + t.Fatalf("got an jderr: %s", err.Error()) + } + + if result.Vstring != "innerfoo" { + t.Errorf("vstring value should be 'innerfoo': %#v", result.Vstring) + } + + if result.Vunique != "bar" { + t.Errorf("vunique value should be 'bar': %#v", result.Vunique) + } +} + +func TestDecode_EmbeddedPointer(t *testing.T) { + t.Parallel() + + input := map[string]interface{}{ + "vstring": "foo", + "Basic": map[string]interface{}{ + "vstring": "innerfoo", + }, + "vunique": "bar", + } + + var result EmbeddedPointer + err := Decode(input, &result) + if err != nil { + t.Fatalf("jderr: %s", err) + } + + expected := EmbeddedPointer{ + Basic: &Basic{ + Vstring: "innerfoo", + }, + Vunique: "bar", + } + if !reflect.DeepEqual(result, expected) { + t.Fatalf("bad: %#v", result) + } +} + +func TestDecode_EmbeddedSlice(t *testing.T) { + t.Parallel() + + input := map[string]interface{}{ + "slice_alias": []string{"foo", "bar"}, + "vunique": "bar", + } + + var result EmbeddedSlice + err := Decode(input, &result) + if err != nil { + t.Fatalf("got an jderr: %s", err.Error()) + } + + if !reflect.DeepEqual(result.SliceAlias, SliceAlias([]string{"foo", "bar"})) { + t.Errorf("slice value: %#v", result.SliceAlias) + } + + if result.Vunique != "bar" { + t.Errorf("vunique value should be 'bar': %#v", result.Vunique) + } +} + +func TestDecode_EmbeddedArray(t *testing.T) { + t.Parallel() + + input := map[string]interface{}{ + "array_alias": [2]string{"foo", "bar"}, + "vunique": "bar", + } + + var result EmbeddedArray + err := Decode(input, &result) + if err != nil { + t.Fatalf("got an jderr: %s", err.Error()) + } + + if !reflect.DeepEqual(result.ArrayAlias, ArrayAlias([2]string{"foo", "bar"})) { + t.Errorf("array value: %#v", result.ArrayAlias) + } + + if result.Vunique != "bar" { + t.Errorf("vunique value should be 'bar': %#v", result.Vunique) + } +} + +func TestDecode_decodeSliceWithArray(t *testing.T) { + t.Parallel() + + var result []int + input := [1]int{1} + expected := []int{1} + if err := Decode(input, &result); err != nil { + t.Fatalf("got an jderr: %s", err.Error()) + } + + if !reflect.DeepEqual(expected, result) { + t.Errorf("wanted %+v, got %+v", expected, result) + } +} + +func TestDecode_EmbeddedNoSquash(t *testing.T) { + t.Parallel() + + input := map[string]interface{}{ + "vstring": "foo", + "vunique": "bar", + } + + var result Embedded + err := Decode(input, &result) + if err != nil { + t.Fatalf("got an jderr: %s", err.Error()) + } + + if result.Vstring != "" { + t.Errorf("vstring value should be empty: %#v", result.Vstring) + } + + if result.Vunique != "bar" { + t.Errorf("vunique value should be 'bar': %#v", result.Vunique) + } +} + +func TestDecode_EmbeddedPointerNoSquash(t *testing.T) { + t.Parallel() + + input := map[string]interface{}{ + "vstring": "foo", + "vunique": "bar", + } + + result := EmbeddedPointer{ + Basic: &Basic{}, + } + + err := Decode(input, &result) + if err != nil { + t.Fatalf("jderr: %s", err) + } + + if result.Vstring != "" { + t.Errorf("vstring value should be empty: %#v", result.Vstring) + } + + if result.Vunique != "bar" { + t.Errorf("vunique value should be 'bar': %#v", result.Vunique) + } +} + +func TestDecode_EmbeddedSquash(t *testing.T) { + t.Parallel() + + input := map[string]interface{}{ + "vstring": "foo", + "vunique": "bar", + } + + var result EmbeddedSquash + err := Decode(input, &result) + if err != nil { + t.Fatalf("got an jderr: %s", err.Error()) + } + + if result.Vstring != "foo" { + t.Errorf("vstring value should be 'foo': %#v", result.Vstring) + } + + if result.Vunique != "bar" { + t.Errorf("vunique value should be 'bar': %#v", result.Vunique) + } +} + +func TestDecodeFrom_EmbeddedSquash(t *testing.T) { + t.Parallel() + + var v interface{} + var ok bool + + input := EmbeddedSquash{ + Basic: Basic{ + Vstring: "foo", + }, + Vunique: "bar", + } + + var result map[string]interface{} + err := Decode(input, &result) + if err != nil { + t.Fatalf("got an jderr: %s", err.Error()) + } + + if _, ok = result["Basic"]; ok { + t.Error("basic should not be present in map") + } + + v, ok = result["Vstring"] + if !ok { + t.Error("vstring should be present in map") + } else if !reflect.DeepEqual(v, "foo") { + t.Errorf("vstring value should be 'foo': %#v", v) + } + + v, ok = result["Vunique"] + if !ok { + t.Error("vunique should be present in map") + } else if !reflect.DeepEqual(v, "bar") { + t.Errorf("vunique value should be 'bar': %#v", v) + } +} + +func TestDecode_EmbeddedPointerSquash_FromStructToMap(t *testing.T) { + t.Parallel() + + input := EmbeddedPointerSquash{ + Basic: &Basic{ + Vstring: "foo", + }, + Vunique: "bar", + } + + var result map[string]interface{} + err := Decode(input, &result) + if err != nil { + t.Fatalf("got an jderr: %s", err.Error()) + } + + if result["Vstring"] != "foo" { + t.Errorf("vstring value should be 'foo': %#v", result["Vstring"]) + } + + if result["Vunique"] != "bar" { + t.Errorf("vunique value should be 'bar': %#v", result["Vunique"]) + } +} + +func TestDecode_EmbeddedPointerSquash_FromMapToStruct(t *testing.T) { + t.Parallel() + + input := map[string]interface{}{ + "Vstring": "foo", + "Vunique": "bar", + } + + result := EmbeddedPointerSquash{ + Basic: &Basic{}, + } + err := Decode(input, &result) + if err != nil { + t.Fatalf("got an jderr: %s", err.Error()) + } + + if result.Vstring != "foo" { + t.Errorf("vstring value should be 'foo': %#v", result.Vstring) + } + + if result.Vunique != "bar" { + t.Errorf("vunique value should be 'bar': %#v", result.Vunique) + } +} + +func TestDecode_EmbeddedPointerSquashWithNestedMapstructure_FromStructToMap(t *testing.T) { + t.Parallel() + + vTime := time.Now() + + input := EmbeddedPointerSquashWithNestedMapstructure{ + NestedPointerWithMapstructure: &NestedPointerWithMapstructure{ + Vbar: &BasicMapStructure{ + Vunique: "bar", + Vtime: &vTime, + }, + }, + Vunique: "foo", + } + + var result map[string]interface{} + err := Decode(input, &result) + if err != nil { + t.Fatalf("got an jderr: %s", err.Error()) + } + expected := map[string]interface{}{ + "vbar": map[string]interface{}{ + "vunique": "bar", + "time": &vTime, + }, + "Vunique": "foo", + } + + if !reflect.DeepEqual(result, expected) { + t.Errorf("result should be %#v: got %#v", expected, result) + } +} + +func TestDecode_EmbeddedPointerSquashWithNestedMapstructure_FromMapToStruct(t *testing.T) { + t.Parallel() + + vTime := time.Now() + + input := map[string]interface{}{ + "vbar": map[string]interface{}{ + "vunique": "bar", + "time": &vTime, + }, + "Vunique": "foo", + } + + result := EmbeddedPointerSquashWithNestedMapstructure{ + NestedPointerWithMapstructure: &NestedPointerWithMapstructure{}, + } + err := Decode(input, &result) + if err != nil { + t.Fatalf("got an jderr: %s", err.Error()) + } + expected := EmbeddedPointerSquashWithNestedMapstructure{ + NestedPointerWithMapstructure: &NestedPointerWithMapstructure{ + Vbar: &BasicMapStructure{ + Vunique: "bar", + Vtime: &vTime, + }, + }, + Vunique: "foo", + } + + if !reflect.DeepEqual(result, expected) { + t.Errorf("result should be %#v: got %#v", expected, result) + } +} + +func TestDecode_EmbeddedSquashConfig(t *testing.T) { + t.Parallel() + + input := map[string]interface{}{ + "vstring": "foo", + "vunique": "bar", + "Named": map[string]interface{}{ + "vstring": "baz", + }, + } + + var result EmbeddedAndNamed + config := &DecoderConfig{ + Squash: true, + Result: &result, + } + + decoder, err := NewDecoder(config) + if err != nil { + t.Fatalf("jderr: %s", err) + } + + err = decoder.Decode(input) + if err != nil { + t.Fatalf("got an jderr: %s", err) + } + + if result.Vstring != "foo" { + t.Errorf("vstring value should be 'foo': %#v", result.Vstring) + } + + if result.Vunique != "bar" { + t.Errorf("vunique value should be 'bar': %#v", result.Vunique) + } + + if result.Named.Vstring != "baz" { + t.Errorf("Named.vstring value should be 'baz': %#v", result.Named.Vstring) + } +} + +func TestDecodeFrom_EmbeddedSquashConfig(t *testing.T) { + t.Parallel() + + input := EmbeddedAndNamed{ + Basic: Basic{Vstring: "foo"}, + Named: Basic{Vstring: "baz"}, + Vunique: "bar", + } + + result := map[string]interface{}{} + config := &DecoderConfig{ + Squash: true, + Result: &result, + } + decoder, err := NewDecoder(config) + if err != nil { + t.Fatalf("got an jderr: %s", err.Error()) + } + + err = decoder.Decode(input) + if err != nil { + t.Fatalf("got an jderr: %s", err.Error()) + } + + if _, ok := result["Basic"]; ok { + t.Error("basic should not be present in map") + } + + v, ok := result["Vstring"] + if !ok { + t.Error("vstring should be present in map") + } else if !reflect.DeepEqual(v, "foo") { + t.Errorf("vstring value should be 'foo': %#v", v) + } + + v, ok = result["Vunique"] + if !ok { + t.Error("vunique should be present in map") + } else if !reflect.DeepEqual(v, "bar") { + t.Errorf("vunique value should be 'bar': %#v", v) + } + + v, ok = result["Named"] + if !ok { + t.Error("Named should be present in map") + } else { + named := v.(map[string]interface{}) + v, ok := named["Vstring"] + if !ok { + t.Error("Named: vstring should be present in map") + } else if !reflect.DeepEqual(v, "baz") { + t.Errorf("Named: vstring should be 'baz': %#v", v) + } + } +} + +func TestDecodeFrom_EmbeddedSquashConfig_WithTags(t *testing.T) { + t.Parallel() + + var v interface{} + var ok bool + + input := EmbeddedSquash{ + Basic: Basic{ + Vstring: "foo", + }, + Vunique: "bar", + } + + result := map[string]interface{}{} + config := &DecoderConfig{ + Squash: true, + Result: &result, + } + decoder, err := NewDecoder(config) + if err != nil { + t.Fatalf("got an jderr: %s", err.Error()) + } + + err = decoder.Decode(input) + if err != nil { + t.Fatalf("got an jderr: %s", err.Error()) + } + + if _, ok = result["Basic"]; ok { + t.Error("basic should not be present in map") + } + + v, ok = result["Vstring"] + if !ok { + t.Error("vstring should be present in map") + } else if !reflect.DeepEqual(v, "foo") { + t.Errorf("vstring value should be 'foo': %#v", v) + } + + v, ok = result["Vunique"] + if !ok { + t.Error("vunique should be present in map") + } else if !reflect.DeepEqual(v, "bar") { + t.Errorf("vunique value should be 'bar': %#v", v) + } +} + +func TestDecode_SquashOnNonStructType(t *testing.T) { + t.Parallel() + + input := map[string]interface{}{ + "InvalidSquashType": 42, + } + + var result SquashOnNonStructType + err := Decode(input, &result) + if err == nil { + t.Fatal("unexpected success decoding invalid squash field type") + } else if !strings.Contains(err.Error(), "unsupported type for squash") { + t.Fatalf("unexpected error message for invalid squash field type: %s", err) + } +} + +func TestDecode_DecodeHook(t *testing.T) { + t.Parallel() + + input := map[string]interface{}{ + "vint": "WHAT", + } + + decodeHook := func(from reflect.Kind, to reflect.Kind, v interface{}) (interface{}, error) { + if from == reflect.String && to != reflect.String { + return 5, nil + } + + return v, nil + } + + var result Basic + config := &DecoderConfig{ + DecodeHook: decodeHook, + Result: &result, + } + + decoder, err := NewDecoder(config) + if err != nil { + t.Fatalf("jderr: %s", err) + } + + err = decoder.Decode(input) + if err != nil { + t.Fatalf("got an jderr: %s", err) + } + + if result.Vint != 5 { + t.Errorf("vint should be 5: %#v", result.Vint) + } +} + +func TestDecode_DecodeHookType(t *testing.T) { + t.Parallel() + + input := map[string]interface{}{ + "vint": "WHAT", + } + + decodeHook := func(from reflect.Type, to reflect.Type, v interface{}) (interface{}, error) { + if from.Kind() == reflect.String && + to.Kind() != reflect.String { + return 5, nil + } + + return v, nil + } + + var result Basic + config := &DecoderConfig{ + DecodeHook: decodeHook, + Result: &result, + } + + decoder, err := NewDecoder(config) + if err != nil { + t.Fatalf("jderr: %s", err) + } + + err = decoder.Decode(input) + if err != nil { + t.Fatalf("got an jderr: %s", err) + } + + if result.Vint != 5 { + t.Errorf("vint should be 5: %#v", result.Vint) + } +} + +func TestDecode_Nil(t *testing.T) { + t.Parallel() + + var input interface{} + result := Basic{ + Vstring: "foo", + } + + err := Decode(input, &result) + if err != nil { + t.Fatalf("jderr: %s", err) + } + + if result.Vstring != "foo" { + t.Fatalf("bad: %#v", result.Vstring) + } +} + +func TestDecode_NilInterfaceHook(t *testing.T) { + t.Parallel() + + input := map[string]interface{}{ + "w": "", + } + + decodeHook := func(f, t reflect.Type, v interface{}) (interface{}, error) { + if t.String() == "io.Writer" { + return nil, nil + } + + return v, nil + } + + var result NilInterface + config := &DecoderConfig{ + DecodeHook: decodeHook, + Result: &result, + } + + decoder, err := NewDecoder(config) + if err != nil { + t.Fatalf("jderr: %s", err) + } + + err = decoder.Decode(input) + if err != nil { + t.Fatalf("got an jderr: %s", err) + } + + if result.W != nil { + t.Errorf("W should be nil: %#v", result.W) + } +} + +func TestDecode_NilPointerHook(t *testing.T) { + t.Parallel() + + input := map[string]interface{}{ + "value": "", + } + + decodeHook := func(f, t reflect.Type, v interface{}) (interface{}, error) { + if typed, ok := v.(string); ok { + if typed == "" { + return nil, nil + } + } + return v, nil + } + + var result NilPointer + config := &DecoderConfig{ + DecodeHook: decodeHook, + Result: &result, + } + + decoder, err := NewDecoder(config) + if err != nil { + t.Fatalf("jderr: %s", err) + } + + err = decoder.Decode(input) + if err != nil { + t.Fatalf("got an jderr: %s", err) + } + + if result.Value != nil { + t.Errorf("W should be nil: %#v", result.Value) + } +} + +func TestDecode_FuncHook(t *testing.T) { + t.Parallel() + + input := map[string]interface{}{ + "foo": "baz", + } + + decodeHook := func(f, t reflect.Type, v interface{}) (interface{}, error) { + if t.Kind() != reflect.Func { + return v, nil + } + val := v.(string) + return func() string { return val }, nil + } + + var result Func + config := &DecoderConfig{ + DecodeHook: decodeHook, + Result: &result, + } + + decoder, err := NewDecoder(config) + if err != nil { + t.Fatalf("jderr: %s", err) + } + + err = decoder.Decode(input) + if err != nil { + t.Fatalf("got an jderr: %s", err) + } + + if result.Foo() != "baz" { + t.Errorf("Foo call result should be 'baz': %s", result.Foo()) + } +} + +func TestDecode_NonStruct(t *testing.T) { + t.Parallel() + + input := map[string]interface{}{ + "foo": "bar", + "bar": "baz", + } + + var result map[string]string + err := Decode(input, &result) + if err != nil { + t.Fatalf("jderr: %s", err) + } + + if result["foo"] != "bar" { + t.Fatal("foo is not bar") + } +} + +func TestDecode_StructMatch(t *testing.T) { + t.Parallel() + + input := map[string]interface{}{ + "vbar": Basic{ + Vstring: "foo", + }, + } + + var result Nested + err := Decode(input, &result) + if err != nil { + t.Fatalf("got an jderr: %s", err.Error()) + } + + if result.Vbar.Vstring != "foo" { + t.Errorf("bad: %#v", result) + } +} + +func TestDecode_TypeConversion(t *testing.T) { + input := map[string]interface{}{ + "IntToFloat": 42, + "IntToUint": 42, + "IntToBool": 1, + "IntToString": 42, + "UintToInt": 42, + "UintToFloat": 42, + "UintToBool": 42, + "UintToString": 42, + "BoolToInt": true, + "BoolToUint": true, + "BoolToFloat": true, + "BoolToString": true, + "FloatToInt": 42.42, + "FloatToUint": 42.42, + "FloatToBool": 42.42, + "FloatToString": 42.42, + "SliceUint8ToString": []uint8("foo"), + "StringToSliceUint8": "foo", + "ArrayUint8ToString": [3]uint8{'f', 'o', 'o'}, + "StringToInt": "42", + "StringToUint": "42", + "StringToBool": "1", + "StringToFloat": "42.42", + "StringToStrSlice": "A", + "StringToIntSlice": "42", + "StringToStrArray": "A", + "StringToIntArray": "42", + "SliceToMap": []interface{}{}, + "MapToSlice": map[string]interface{}{}, + "ArrayToMap": []interface{}{}, + "MapToArray": map[string]interface{}{}, + } + + expectedResultStrict := TypeConversionResult{ + IntToFloat: 42.0, + IntToUint: 42, + UintToInt: 42, + UintToFloat: 42, + BoolToInt: 0, + BoolToUint: 0, + BoolToFloat: 0, + FloatToInt: 42, + FloatToUint: 42, + } + + expectedResultWeak := TypeConversionResult{ + IntToFloat: 42.0, + IntToUint: 42, + IntToBool: true, + IntToString: "42", + UintToInt: 42, + UintToFloat: 42, + UintToBool: true, + UintToString: "42", + BoolToInt: 1, + BoolToUint: 1, + BoolToFloat: 1, + BoolToString: "1", + FloatToInt: 42, + FloatToUint: 42, + FloatToBool: true, + FloatToString: "42.42", + SliceUint8ToString: "foo", + StringToSliceUint8: []byte("foo"), + ArrayUint8ToString: "foo", + StringToInt: 42, + StringToUint: 42, + StringToBool: true, + StringToFloat: 42.42, + StringToStrSlice: []string{"A"}, + StringToIntSlice: []int{42}, + StringToStrArray: [1]string{"A"}, + StringToIntArray: [1]int{42}, + SliceToMap: map[string]interface{}{}, + MapToSlice: []interface{}{}, + ArrayToMap: map[string]interface{}{}, + MapToArray: [1]interface{}{}, + } + + // Test strict type conversion + var resultStrict TypeConversionResult + err := Decode(input, &resultStrict) + if err == nil { + t.Errorf("should return an error") + } + if !reflect.DeepEqual(resultStrict, expectedResultStrict) { + t.Errorf("expected %v, got: %v", expectedResultStrict, resultStrict) + } + + // Test weak type conversion + var decoder *Decoder + var resultWeak TypeConversionResult + + config := &DecoderConfig{ + WeaklyTypedInput: true, + Result: &resultWeak, + } + + decoder, err = NewDecoder(config) + if err != nil { + t.Fatalf("jderr: %s", err) + } + + err = decoder.Decode(input) + if err != nil { + t.Fatalf("got an jderr: %s", err) + } + + if !reflect.DeepEqual(resultWeak, expectedResultWeak) { + t.Errorf("expected \n%#v, got: \n%#v", expectedResultWeak, resultWeak) + } +} + +func TestDecoder_ErrorUnused(t *testing.T) { + t.Parallel() + + input := map[string]interface{}{ + "vstring": "hello", + "foo": "bar", + } + + var result Basic + config := &DecoderConfig{ + ErrorUnused: true, + Result: &result, + } + + decoder, err := NewDecoder(config) + if err != nil { + t.Fatalf("jderr: %s", err) + } + + err = decoder.Decode(input) + if err == nil { + t.Fatal("expected error") + } +} + +func TestDecoder_ErrorUnused_NotSetable(t *testing.T) { + t.Parallel() + + // lowercase vsilent is unexported and cannot be set + input := map[string]interface{}{ + "vsilent": "false", + } + + var result Basic + config := &DecoderConfig{ + ErrorUnused: true, + Result: &result, + } + + decoder, err := NewDecoder(config) + if err != nil { + t.Fatalf("jderr: %s", err) + } + + err = decoder.Decode(input) + if err == nil { + t.Fatal("expected error") + } +} +func TestDecoder_ErrorUnset(t *testing.T) { + t.Parallel() + + input := map[string]interface{}{ + "vstring": "hello", + "foo": "bar", + } + + var result Basic + config := &DecoderConfig{ + ErrorUnset: true, + Result: &result, + } + + decoder, err := NewDecoder(config) + if err != nil { + t.Fatalf("jderr: %s", err) + } + + err = decoder.Decode(input) + if err == nil { + t.Fatal("expected error") + } +} + +func TestMap(t *testing.T) { + t.Parallel() + + input := map[string]interface{}{ + "vfoo": "foo", + "vother": map[interface{}]interface{}{ + "foo": "foo", + "bar": "bar", + }, + } + + var result Map + err := Decode(input, &result) + if err != nil { + t.Fatalf("got an error: %s", err) + } + + if result.Vfoo != "foo" { + t.Errorf("vfoo value should be 'foo': %#v", result.Vfoo) + } + + if result.Vother == nil { + t.Fatal("vother should not be nil") + } + + if len(result.Vother) != 2 { + t.Error("vother should have two items") + } + + if result.Vother["foo"] != "foo" { + t.Errorf("'foo' key should be foo, got: %#v", result.Vother["foo"]) + } + + if result.Vother["bar"] != "bar" { + t.Errorf("'bar' key should be bar, got: %#v", result.Vother["bar"]) + } +} + +func TestMapMerge(t *testing.T) { + t.Parallel() + + input := map[string]interface{}{ + "vfoo": "foo", + "vother": map[interface{}]interface{}{ + "foo": "foo", + "bar": "bar", + }, + } + + var result Map + result.Vother = map[string]string{"hello": "world"} + err := Decode(input, &result) + if err != nil { + t.Fatalf("got an error: %s", err) + } + + if result.Vfoo != "foo" { + t.Errorf("vfoo value should be 'foo': %#v", result.Vfoo) + } + + expected := map[string]string{ + "foo": "foo", + "bar": "bar", + "hello": "world", + } + if !reflect.DeepEqual(result.Vother, expected) { + t.Errorf("bad: %#v", result.Vother) + } +} + +func TestMapOfStruct(t *testing.T) { + t.Parallel() + + input := map[string]interface{}{ + "value": map[string]interface{}{ + "foo": map[string]string{"vstring": "one"}, + "bar": map[string]string{"vstring": "two"}, + }, + } + + var result MapOfStruct + err := Decode(input, &result) + if err != nil { + t.Fatalf("got an jderr: %s", err) + } + + if result.Value == nil { + t.Fatal("value should not be nil") + } + + if len(result.Value) != 2 { + t.Error("value should have two items") + } + + if result.Value["foo"].Vstring != "one" { + t.Errorf("foo value should be 'one', got: %s", result.Value["foo"].Vstring) + } + + if result.Value["bar"].Vstring != "two" { + t.Errorf("bar value should be 'two', got: %s", result.Value["bar"].Vstring) + } +} + +func TestNestedType(t *testing.T) { + t.Parallel() + + input := map[string]interface{}{ + "vfoo": "foo", + "vbar": map[string]interface{}{ + "vstring": "foo", + "vint": 42, + "vbool": true, + }, + } + + var result Nested + err := Decode(input, &result) + if err != nil { + t.Fatalf("got an jderr: %s", err.Error()) + } + + if result.Vfoo != "foo" { + t.Errorf("vfoo value should be 'foo': %#v", result.Vfoo) + } + + if result.Vbar.Vstring != "foo" { + t.Errorf("vstring value should be 'foo': %#v", result.Vbar.Vstring) + } + + if result.Vbar.Vint != 42 { + t.Errorf("vint value should be 42: %#v", result.Vbar.Vint) + } + + if result.Vbar.Vbool != true { + t.Errorf("vbool value should be true: %#v", result.Vbar.Vbool) + } + + if result.Vbar.Vextra != "" { + t.Errorf("vextra value should be empty: %#v", result.Vbar.Vextra) + } +} + +func TestNestedTypePointer(t *testing.T) { + t.Parallel() + + input := map[string]interface{}{ + "vfoo": "foo", + "vbar": &map[string]interface{}{ + "vstring": "foo", + "vint": 42, + "vbool": true, + }, + } + + var result NestedPointer + err := Decode(input, &result) + if err != nil { + t.Fatalf("got an jderr: %s", err.Error()) + } + + if result.Vfoo != "foo" { + t.Errorf("vfoo value should be 'foo': %#v", result.Vfoo) + } + + if result.Vbar.Vstring != "foo" { + t.Errorf("vstring value should be 'foo': %#v", result.Vbar.Vstring) + } + + if result.Vbar.Vint != 42 { + t.Errorf("vint value should be 42: %#v", result.Vbar.Vint) + } + + if result.Vbar.Vbool != true { + t.Errorf("vbool value should be true: %#v", result.Vbar.Vbool) + } + + if result.Vbar.Vextra != "" { + t.Errorf("vextra value should be empty: %#v", result.Vbar.Vextra) + } +} + +// Test for issue #46. +func TestNestedTypeInterface(t *testing.T) { + t.Parallel() + + input := map[string]interface{}{ + "vfoo": "foo", + "vbar": &map[string]interface{}{ + "vstring": "foo", + "vint": 42, + "vbool": true, + + "vdata": map[string]interface{}{ + "vstring": "bar", + }, + }, + } + + var result NestedPointer + result.Vbar = new(Basic) + result.Vbar.Vdata = new(Basic) + err := Decode(input, &result) + if err != nil { + t.Fatalf("got an jderr: %s", err.Error()) + } + + if result.Vfoo != "foo" { + t.Errorf("vfoo value should be 'foo': %#v", result.Vfoo) + } + + if result.Vbar.Vstring != "foo" { + t.Errorf("vstring value should be 'foo': %#v", result.Vbar.Vstring) + } + + if result.Vbar.Vint != 42 { + t.Errorf("vint value should be 42: %#v", result.Vbar.Vint) + } + + if result.Vbar.Vbool != true { + t.Errorf("vbool value should be true: %#v", result.Vbar.Vbool) + } + + if result.Vbar.Vextra != "" { + t.Errorf("vextra value should be empty: %#v", result.Vbar.Vextra) + } + + if result.Vbar.Vdata.(*Basic).Vstring != "bar" { + t.Errorf("vstring value should be 'bar': %#v", result.Vbar.Vdata.(*Basic).Vstring) + } +} + +func TestSlice(t *testing.T) { + t.Parallel() + + inputStringSlice := map[string]interface{}{ + "vfoo": "foo", + "vbar": []string{"foo", "bar", "baz"}, + } + + inputStringSlicePointer := map[string]interface{}{ + "vfoo": "foo", + "vbar": &[]string{"foo", "bar", "baz"}, + } + + outputStringSlice := &Slice{ + "foo", + []string{"foo", "bar", "baz"}, + } + + testSliceInput(t, inputStringSlice, outputStringSlice) + testSliceInput(t, inputStringSlicePointer, outputStringSlice) +} + +func TestInvalidSlice(t *testing.T) { + t.Parallel() + + input := map[string]interface{}{ + "vfoo": "foo", + "vbar": 42, + } + + result := Slice{} + err := Decode(input, &result) + if err == nil { + t.Errorf("expected failure") + } +} + +func TestSliceOfStruct(t *testing.T) { + t.Parallel() + + input := map[string]interface{}{ + "value": []map[string]interface{}{ + {"vstring": "one"}, + {"vstring": "two"}, + }, + } + + var result SliceOfStruct + err := Decode(input, &result) + if err != nil { + t.Fatalf("got unexpected error: %s", err) + } + + if len(result.Value) != 2 { + t.Fatalf("expected two values, got %d", len(result.Value)) + } + + if result.Value[0].Vstring != "one" { + t.Errorf("first value should be 'one', got: %s", result.Value[0].Vstring) + } + + if result.Value[1].Vstring != "two" { + t.Errorf("second value should be 'two', got: %s", result.Value[1].Vstring) + } +} + +func TestSliceCornerCases(t *testing.T) { + t.Parallel() + + // Input with a map with zero values + input := map[string]interface{}{} + var resultWeak []Basic + + err := WeakDecode(input, &resultWeak) + if err != nil { + t.Fatalf("got unexpected error: %s", err) + } + + if len(resultWeak) != 0 { + t.Errorf("length should be 0") + } + // Input with more values + input = map[string]interface{}{ + "Vstring": "foo", + } + + resultWeak = nil + err = WeakDecode(input, &resultWeak) + if err != nil { + t.Fatalf("got unexpected error: %s", err) + } + + if resultWeak[0].Vstring != "foo" { + t.Errorf("value does not match") + } +} + +func TestSliceToMap(t *testing.T) { + t.Parallel() + + input := []map[string]interface{}{ + { + "foo": "bar", + }, + { + "bar": "baz", + }, + } + + var result map[string]interface{} + err := WeakDecode(input, &result) + if err != nil { + t.Fatalf("got an error: %s", err) + } + + expected := map[string]interface{}{ + "foo": "bar", + "bar": "baz", + } + if !reflect.DeepEqual(result, expected) { + t.Errorf("bad: %#v", result) + } +} + +func TestArray(t *testing.T) { + t.Parallel() + + inputStringArray := map[string]interface{}{ + "vfoo": "foo", + "vbar": [2]string{"foo", "bar"}, + } + + inputStringArrayPointer := map[string]interface{}{ + "vfoo": "foo", + "vbar": &[2]string{"foo", "bar"}, + } + + outputStringArray := &Array{ + "foo", + [2]string{"foo", "bar"}, + } + + testArrayInput(t, inputStringArray, outputStringArray) + testArrayInput(t, inputStringArrayPointer, outputStringArray) +} + +func TestInvalidArray(t *testing.T) { + t.Parallel() + + input := map[string]interface{}{ + "vfoo": "foo", + "vbar": 42, + } + + result := Array{} + err := Decode(input, &result) + if err == nil { + t.Errorf("expected failure") + } +} + +func TestArrayOfStruct(t *testing.T) { + t.Parallel() + + input := map[string]interface{}{ + "value": []map[string]interface{}{ + {"vstring": "one"}, + {"vstring": "two"}, + }, + } + + var result ArrayOfStruct + err := Decode(input, &result) + if err != nil { + t.Fatalf("got unexpected error: %s", err) + } + + if len(result.Value) != 2 { + t.Fatalf("expected two values, got %d", len(result.Value)) + } + + if result.Value[0].Vstring != "one" { + t.Errorf("first value should be 'one', got: %s", result.Value[0].Vstring) + } + + if result.Value[1].Vstring != "two" { + t.Errorf("second value should be 'two', got: %s", result.Value[1].Vstring) + } +} + +func TestArrayToMap(t *testing.T) { + t.Parallel() + + input := []map[string]interface{}{ + { + "foo": "bar", + }, + { + "bar": "baz", + }, + } + + var result map[string]interface{} + err := WeakDecode(input, &result) + if err != nil { + t.Fatalf("got an error: %s", err) + } + + expected := map[string]interface{}{ + "foo": "bar", + "bar": "baz", + } + if !reflect.DeepEqual(result, expected) { + t.Errorf("bad: %#v", result) + } +} + +func TestDecodeTable(t *testing.T) { + t.Parallel() + + // We need to make new types so that we don't get the short-circuit + // copy functionality. We want to test the deep copying functionality. + type BasicCopy Basic + type NestedPointerCopy NestedPointer + type MapCopy Map + + tests := []struct { + name string + in interface{} + target interface{} + out interface{} + wantErr bool + }{ + { + "basic struct input", + &Basic{ + Vstring: "vstring", + Vint: 2, + Vint8: 2, + Vint16: 2, + Vint32: 2, + Vint64: 2, + Vuint: 3, + Vbool: true, + Vfloat: 4.56, + Vextra: "vextra", + vsilent: true, + Vdata: []byte("data"), + }, + &map[string]interface{}{}, + &map[string]interface{}{ + "Vstring": "vstring", + "Vint": 2, + "Vint8": int8(2), + "Vint16": int16(2), + "Vint32": int32(2), + "Vint64": int64(2), + "Vuint": uint(3), + "Vbool": true, + "Vfloat": 4.56, + "Vextra": "vextra", + "Vdata": []byte("data"), + "VjsonInt": 0, + "VjsonUint": uint(0), + "VjsonUint64": uint64(0), + "VjsonFloat": 0.0, + "VjsonNumber": json.Number(""), + }, + false, + }, + { + "embedded struct input", + &Embedded{ + Vunique: "vunique", + Basic: Basic{ + Vstring: "vstring", + Vint: 2, + Vint8: 2, + Vint16: 2, + Vint32: 2, + Vint64: 2, + Vuint: 3, + Vbool: true, + Vfloat: 4.56, + Vextra: "vextra", + vsilent: true, + Vdata: []byte("data"), + }, + }, + &map[string]interface{}{}, + &map[string]interface{}{ + "Vunique": "vunique", + "Basic": map[string]interface{}{ + "Vstring": "vstring", + "Vint": 2, + "Vint8": int8(2), + "Vint16": int16(2), + "Vint32": int32(2), + "Vint64": int64(2), + "Vuint": uint(3), + "Vbool": true, + "Vfloat": 4.56, + "Vextra": "vextra", + "Vdata": []byte("data"), + "VjsonInt": 0, + "VjsonUint": uint(0), + "VjsonUint64": uint64(0), + "VjsonFloat": 0.0, + "VjsonNumber": json.Number(""), + }, + }, + false, + }, + { + "struct => struct", + &Basic{ + Vstring: "vstring", + Vint: 2, + Vuint: 3, + Vbool: true, + Vfloat: 4.56, + Vextra: "vextra", + Vdata: []byte("data"), + vsilent: true, + }, + &BasicCopy{}, + &BasicCopy{ + Vstring: "vstring", + Vint: 2, + Vuint: 3, + Vbool: true, + Vfloat: 4.56, + Vextra: "vextra", + Vdata: []byte("data"), + }, + false, + }, + { + "struct => struct with pointers", + &NestedPointer{ + Vfoo: "hello", + Vbar: nil, + }, + &NestedPointerCopy{}, + &NestedPointerCopy{ + Vfoo: "hello", + }, + false, + }, + { + "basic pointer to non-pointer", + &BasicPointer{ + Vstring: stringPtr("vstring"), + Vint: intPtr(2), + Vuint: uintPtr(3), + Vbool: boolPtr(true), + Vfloat: floatPtr(4.56), + Vdata: interfacePtr([]byte("data")), + }, + &Basic{}, + &Basic{ + Vstring: "vstring", + Vint: 2, + Vuint: 3, + Vbool: true, + Vfloat: 4.56, + Vdata: []byte("data"), + }, + false, + }, + { + "slice non-pointer to pointer", + &Slice{}, + &SlicePointer{}, + &SlicePointer{}, + false, + }, + { + "slice non-pointer to pointer, zero field", + &Slice{}, + &SlicePointer{ + Vbar: &[]string{"yo"}, + }, + &SlicePointer{}, + false, + }, + { + "slice to slice alias", + &Slice{}, + &SliceOfAlias{}, + &SliceOfAlias{}, + false, + }, + { + "nil map to map", + &Map{}, + &MapCopy{}, + &MapCopy{}, + false, + }, + { + "nil map to non-empty map", + &Map{}, + &MapCopy{Vother: map[string]string{"foo": "bar"}}, + &MapCopy{}, + false, + }, + + { + "slice input - should error", + []string{"foo", "bar"}, + &map[string]interface{}{}, + &map[string]interface{}{}, + true, + }, + { + "struct with slice property", + &Slice{ + Vfoo: "vfoo", + Vbar: []string{"foo", "bar"}, + }, + &map[string]interface{}{}, + &map[string]interface{}{ + "Vfoo": "vfoo", + "Vbar": []string{"foo", "bar"}, + }, + false, + }, + { + "struct with empty slice", + &map[string]interface{}{ + "Vbar": []string{}, + }, + &Slice{}, + &Slice{ + Vbar: []string{}, + }, + false, + }, + { + "struct with slice of struct property", + &SliceOfStruct{ + Value: []Basic{ + Basic{ + Vstring: "vstring", + Vint: 2, + Vuint: 3, + Vbool: true, + Vfloat: 4.56, + Vextra: "vextra", + vsilent: true, + Vdata: []byte("data"), + }, + }, + }, + &map[string]interface{}{}, + &map[string]interface{}{ + "Value": []Basic{ + Basic{ + Vstring: "vstring", + Vint: 2, + Vuint: 3, + Vbool: true, + Vfloat: 4.56, + Vextra: "vextra", + vsilent: true, + Vdata: []byte("data"), + }, + }, + }, + false, + }, + { + "struct with map property", + &Map{ + Vfoo: "vfoo", + Vother: map[string]string{"vother": "vother"}, + }, + &map[string]interface{}{}, + &map[string]interface{}{ + "Vfoo": "vfoo", + "Vother": map[string]string{ + "vother": "vother", + }}, + false, + }, + { + "tagged struct", + &Tagged{ + Extra: "extra", + Value: "value", + }, + &map[string]string{}, + &map[string]string{ + "bar": "extra", + "foo": "value", + }, + false, + }, + { + "omit tag struct", + &struct { + Value string `mapstructure:"value"` + Omit string `mapstructure:"-"` + }{ + Value: "value", + Omit: "omit", + }, + &map[string]string{}, + &map[string]string{ + "value": "value", + }, + false, + }, + { + "decode to wrong map type", + &struct { + Value string + }{ + Value: "string", + }, + &map[string]int{}, + &map[string]int{}, + true, + }, + { + "remainder", + map[string]interface{}{ + "A": "hello", + "B": "goodbye", + "C": "yo", + }, + &Remainder{}, + &Remainder{ + A: "hello", + Extra: map[string]interface{}{ + "B": "goodbye", + "C": "yo", + }, + }, + false, + }, + { + "remainder with no extra", + map[string]interface{}{ + "A": "hello", + }, + &Remainder{}, + &Remainder{ + A: "hello", + Extra: nil, + }, + false, + }, + { + "struct with omitempty tag return non-empty values", + &struct { + VisibleField interface{} `mapstructure:"visible"` + OmitField interface{} `mapstructure:"omittable,omitempty"` + }{ + VisibleField: nil, + OmitField: "string", + }, + &map[string]interface{}{}, + &map[string]interface{}{"visible": nil, "omittable": "string"}, + false, + }, + { + "struct with omitempty tag ignore empty values", + &struct { + VisibleField interface{} `mapstructure:"visible"` + OmitField interface{} `mapstructure:"omittable,omitempty"` + }{ + VisibleField: nil, + OmitField: nil, + }, + &map[string]interface{}{}, + &map[string]interface{}{"visible": nil}, + false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if err := Decode(tt.in, tt.target); (err != nil) != tt.wantErr { + t.Fatalf("%q: TestMapOutputForStructuredInputs() unexpected error: %s", tt.name, err) + } + + if !reflect.DeepEqual(tt.out, tt.target) { + t.Fatalf("%q: TestMapOutputForStructuredInputs() expected: %#v, got: %#v", tt.name, tt.out, tt.target) + } + }) + } +} + +func TestInvalidType(t *testing.T) { + t.Parallel() + + input := map[string]interface{}{ + "vstring": 42, + } + + var result Basic + err := Decode(input, &result) + if err == nil { + t.Fatal("error should exist") + } + + derr, ok := err.(*Error) + if !ok { + t.Fatalf("error should be kind of Error, instead: %#v", err) + } + + if derr.Errors[0] != + "'Vstring' expected type 'string', got unconvertible type 'int', value: '42'" { + t.Errorf("got unexpected error: %s", err) + } + + inputNegIntUint := map[string]interface{}{ + "vuint": -42, + } + + err = Decode(inputNegIntUint, &result) + if err == nil { + t.Fatal("error should exist") + } + + derr, ok = err.(*Error) + if !ok { + t.Fatalf("error should be kind of Error, instead: %#v", err) + } + + if derr.Errors[0] != "cannot parse 'Vuint', -42 overflows uint" { + t.Errorf("got unexpected error: %s", err) + } + + inputNegFloatUint := map[string]interface{}{ + "vuint": -42.0, + } + + err = Decode(inputNegFloatUint, &result) + if err == nil { + t.Fatal("error should exist") + } + + derr, ok = err.(*Error) + if !ok { + t.Fatalf("error should be kind of Error, instead: %#v", err) + } + + if derr.Errors[0] != "cannot parse 'Vuint', -42.000000 overflows uint" { + t.Errorf("got unexpected error: %s", err) + } +} + +func TestDecodeMetadata(t *testing.T) { + t.Parallel() + + input := map[string]interface{}{ + "vfoo": "foo", + "vbar": map[string]interface{}{ + "vstring": "foo", + "Vuint": 42, + "vsilent": "false", + "foo": "bar", + }, + "bar": "nil", + } + + var md Metadata + var result Nested + + err := DecodeMetadata(input, &result, &md) + if err != nil { + t.Fatalf("jderr: %s", err.Error()) + } + + expectedKeys := []string{"Vbar", "Vbar.Vstring", "Vbar.Vuint", "Vfoo"} + sort.Strings(md.Keys) + if !reflect.DeepEqual(md.Keys, expectedKeys) { + t.Fatalf("bad keys: %#v", md.Keys) + } + + expectedUnused := []string{"Vbar.foo", "Vbar.vsilent", "bar"} + sort.Strings(md.Unused) + if !reflect.DeepEqual(md.Unused, expectedUnused) { + t.Fatalf("bad unused: %#v", md.Unused) + } +} + +func TestMetadata(t *testing.T) { + t.Parallel() + + type testResult struct { + Vfoo string + Vbar BasicPointer + } + + input := map[string]interface{}{ + "vfoo": "foo", + "vbar": map[string]interface{}{ + "vstring": "foo", + "Vuint": 42, + "vsilent": "false", + "foo": "bar", + }, + "bar": "nil", + } + + var md Metadata + var result testResult + config := &DecoderConfig{ + Metadata: &md, + Result: &result, + } + + decoder, err := NewDecoder(config) + if err != nil { + t.Fatalf("jderr: %s", err) + } + + err = decoder.Decode(input) + if err != nil { + t.Fatalf("jderr: %s", err.Error()) + } + + expectedKeys := []string{"Vbar", "Vbar.Vstring", "Vbar.Vuint", "Vfoo"} + sort.Strings(md.Keys) + if !reflect.DeepEqual(md.Keys, expectedKeys) { + t.Fatalf("bad keys: %#v", md.Keys) + } + + expectedUnused := []string{"Vbar.foo", "Vbar.vsilent", "bar"} + sort.Strings(md.Unused) + if !reflect.DeepEqual(md.Unused, expectedUnused) { + t.Fatalf("bad unused: %#v", md.Unused) + } + + expectedUnset := []string{ + "Vbar.Vbool", "Vbar.Vdata", "Vbar.Vextra", "Vbar.Vfloat", "Vbar.Vint", + "Vbar.VjsonFloat", "Vbar.VjsonInt", "Vbar.VjsonNumber"} + sort.Strings(md.Unset) + if !reflect.DeepEqual(md.Unset, expectedUnset) { + t.Fatalf("bad unset: %#v", md.Unset) + } +} + +func TestMetadata_Embedded(t *testing.T) { + t.Parallel() + + input := map[string]interface{}{ + "vstring": "foo", + "vunique": "bar", + } + + var md Metadata + var result EmbeddedSquash + config := &DecoderConfig{ + Metadata: &md, + Result: &result, + } + + decoder, err := NewDecoder(config) + if err != nil { + t.Fatalf("jderr: %s", err) + } + + err = decoder.Decode(input) + if err != nil { + t.Fatalf("jderr: %s", err.Error()) + } + + expectedKeys := []string{"Vstring", "Vunique"} + + sort.Strings(md.Keys) + if !reflect.DeepEqual(md.Keys, expectedKeys) { + t.Fatalf("bad keys: %#v", md.Keys) + } + + expectedUnused := []string{} + if !reflect.DeepEqual(md.Unused, expectedUnused) { + t.Fatalf("bad unused: %#v", md.Unused) + } +} + +func TestNonPtrValue(t *testing.T) { + t.Parallel() + + err := Decode(map[string]interface{}{}, Basic{}) + if err == nil { + t.Fatal("error should exist") + } + + if err.Error() != "result must be a pointer" { + t.Errorf("got unexpected error: %s", err) + } +} + +func TestTagged(t *testing.T) { + t.Parallel() + + input := map[string]interface{}{ + "foo": "bar", + "bar": "value", + } + + var result Tagged + err := Decode(input, &result) + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + + if result.Value != "bar" { + t.Errorf("value should be 'bar', got: %#v", result.Value) + } + + if result.Extra != "value" { + t.Errorf("extra should be 'value', got: %#v", result.Extra) + } +} + +func TestWeakDecode(t *testing.T) { + t.Parallel() + + input := map[string]interface{}{ + "foo": "4", + "bar": "value", + } + + var result struct { + Foo int + Bar string + } + + if err := WeakDecode(input, &result); err != nil { + t.Fatalf("jderr: %s", err) + } + if result.Foo != 4 { + t.Fatalf("bad: %#v", result) + } + if result.Bar != "value" { + t.Fatalf("bad: %#v", result) + } +} + +func TestWeakDecodeMetadata(t *testing.T) { + t.Parallel() + + input := map[string]interface{}{ + "foo": "4", + "bar": "value", + "unused": "value", + "unexported": "value", + } + + var md Metadata + var result struct { + Foo int + Bar string + unexported string + } + + if err := WeakDecodeMetadata(input, &result, &md); err != nil { + t.Fatalf("jderr: %s", err) + } + if result.Foo != 4 { + t.Fatalf("bad: %#v", result) + } + if result.Bar != "value" { + t.Fatalf("bad: %#v", result) + } + + expectedKeys := []string{"Bar", "Foo"} + sort.Strings(md.Keys) + if !reflect.DeepEqual(md.Keys, expectedKeys) { + t.Fatalf("bad keys: %#v", md.Keys) + } + + expectedUnused := []string{"unexported", "unused"} + sort.Strings(md.Unused) + if !reflect.DeepEqual(md.Unused, expectedUnused) { + t.Fatalf("bad unused: %#v", md.Unused) + } +} + +func TestDecode_StructTaggedWithOmitempty_OmitEmptyValues(t *testing.T) { + t.Parallel() + + input := &StructWithOmitEmpty{} + + var emptySlice []interface{} + var emptyMap map[string]interface{} + var emptyNested *Nested + expected := &map[string]interface{}{ + "visible-string": "", + "visible-int": 0, + "visible-float": 0.0, + "visible-slice": emptySlice, + "visible-map": emptyMap, + "visible-nested": emptyNested, + } + + actual := &map[string]interface{}{} + Decode(input, actual) + + if !reflect.DeepEqual(actual, expected) { + t.Fatalf("Decode() expected: %#v, got: %#v", expected, actual) + } +} + +func TestDecode_StructTaggedWithOmitempty_KeepNonEmptyValues(t *testing.T) { + t.Parallel() + + input := &StructWithOmitEmpty{ + VisibleStringField: "", + OmitStringField: "string", + VisibleIntField: 0, + OmitIntField: 1, + VisibleFloatField: 0.0, + OmitFloatField: 1.0, + VisibleSliceField: nil, + OmitSliceField: []interface{}{1}, + VisibleMapField: nil, + OmitMapField: map[string]interface{}{"k": "v"}, + NestedField: nil, + OmitNestedField: &Nested{}, + } + + var emptySlice []interface{} + var emptyMap map[string]interface{} + var emptyNested *Nested + expected := &map[string]interface{}{ + "visible-string": "", + "omittable-string": "string", + "visible-int": 0, + "omittable-int": 1, + "visible-float": 0.0, + "omittable-float": 1.0, + "visible-slice": emptySlice, + "omittable-slice": []interface{}{1}, + "visible-map": emptyMap, + "omittable-map": map[string]interface{}{"k": "v"}, + "visible-nested": emptyNested, + "omittable-nested": &Nested{}, + } + + actual := &map[string]interface{}{} + Decode(input, actual) + + if !reflect.DeepEqual(actual, expected) { + t.Fatalf("Decode() expected: %#v, got: %#v", expected, actual) + } +} + +func TestDecode_mapToStruct(t *testing.T) { + type Target struct { + String string + StringPtr *string + } + + expected := Target{ + String: "hello", + } + + var target Target + err := Decode(map[string]interface{}{ + "string": "hello", + "StringPtr": "goodbye", + }, &target) + if err != nil { + t.Fatalf("got error: %s", err) + } + + // Pointers fail reflect test so do those manually + if target.StringPtr == nil || *target.StringPtr != "goodbye" { + t.Fatalf("bad: %#v", target) + } + target.StringPtr = nil + + if !reflect.DeepEqual(target, expected) { + t.Fatalf("bad: %#v", target) + } +} + +func TestDecoder_MatchName(t *testing.T) { + t.Parallel() + + type Target struct { + FirstMatch string `mapstructure:"first_match"` + SecondMatch string + NoMatch string `mapstructure:"no_match"` + } + + input := map[string]interface{}{ + "first_match": "foo", + "SecondMatch": "bar", + "NO_MATCH": "baz", + } + + expected := Target{ + FirstMatch: "foo", + SecondMatch: "bar", + } + + var actual Target + config := &DecoderConfig{ + Result: &actual, + MatchName: func(mapKey, fieldName string) bool { + return mapKey == fieldName + }, + } + + decoder, err := NewDecoder(config) + if err != nil { + t.Fatalf("jderr: %s", err) + } + + err = decoder.Decode(input) + if err != nil { + t.Fatalf("jderr: %s", err) + } + + if !reflect.DeepEqual(expected, actual) { + t.Fatalf("Decode() expected: %#v, got: %#v", expected, actual) + } +} + +func TestDecoder_IgnoreUntaggedFields(t *testing.T) { + type Input struct { + UntaggedNumber int + TaggedNumber int `mapstructure:"tagged_number"` + UntaggedString string + TaggedString string `mapstructure:"tagged_string"` + } + input := &Input{ + UntaggedNumber: 31, + TaggedNumber: 42, + UntaggedString: "hidden", + TaggedString: "visible", + } + + actual := make(map[string]interface{}) + config := &DecoderConfig{ + Result: &actual, + IgnoreUntaggedFields: true, + } + + decoder, err := NewDecoder(config) + if err != nil { + t.Fatalf("jderr: %s", err) + } + + err = decoder.Decode(input) + if err != nil { + t.Fatalf("jderr: %s", err) + } + + expected := map[string]interface{}{ + "tagged_number": 42, + "tagged_string": "visible", + } + + if !reflect.DeepEqual(expected, actual) { + t.Fatalf("Decode() expected: %#v\ngot: %#v", expected, actual) + } +} + +func testSliceInput(t *testing.T, input map[string]interface{}, expected *Slice) { + var result Slice + err := Decode(input, &result) + if err != nil { + t.Fatalf("got error: %s", err) + } + + if result.Vfoo != expected.Vfoo { + t.Errorf("Vfoo expected '%s', got '%s'", expected.Vfoo, result.Vfoo) + } + + if result.Vbar == nil { + t.Fatalf("Vbar a slice, got '%#v'", result.Vbar) + } + + if len(result.Vbar) != len(expected.Vbar) { + t.Errorf("Vbar length should be %d, got %d", len(expected.Vbar), len(result.Vbar)) + } + + for i, v := range result.Vbar { + if v != expected.Vbar[i] { + t.Errorf( + "Vbar[%d] should be '%#v', got '%#v'", + i, expected.Vbar[i], v) + } + } +} + +func testArrayInput(t *testing.T, input map[string]interface{}, expected *Array) { + var result Array + err := Decode(input, &result) + if err != nil { + t.Fatalf("got error: %s", err) + } + + if result.Vfoo != expected.Vfoo { + t.Errorf("Vfoo expected '%s', got '%s'", expected.Vfoo, result.Vfoo) + } + + if result.Vbar == [2]string{} { + t.Fatalf("Vbar a slice, got '%#v'", result.Vbar) + } + + if len(result.Vbar) != len(expected.Vbar) { + t.Errorf("Vbar length should be %d, got %d", len(expected.Vbar), len(result.Vbar)) + } + + for i, v := range result.Vbar { + if v != expected.Vbar[i] { + t.Errorf( + "Vbar[%d] should be '%#v', got '%#v'", + i, expected.Vbar[i], v) + } + } +} + +func stringPtr(v string) *string { return &v } +func intPtr(v int) *int { return &v } +func uintPtr(v uint) *uint { return &v } +func boolPtr(v bool) *bool { return &v } +func floatPtr(v float64) *float64 { return &v } +func interfacePtr(v interface{}) *interface{} { return &v } diff --git a/internal/pkg/mapstructure/my_decode.go b/internal/pkg/mapstructure/my_decode.go new file mode 100644 index 0000000..6cafe11 --- /dev/null +++ b/internal/pkg/mapstructure/my_decode.go @@ -0,0 +1,26 @@ +package mapstructure + +import "time" + +// DecodeWithTime 支持时间转字符串 +// 支持 +// 1. *Time.time 转 string/*string +// 2. *Time.time 转 uint/uint32/uint64/int/int32/int64,支持带指针 +// 不能用 Time.time 转,它会在上层认为是一个结构体数据而直接转成map,再到hook方法 +func DecodeWithTime(input, output interface{}, layout string) error { + if layout == "" { + layout = time.DateTime + } + config := &DecoderConfig{ + Metadata: nil, + Result: output, + DecodeHook: ComposeDecodeHookFunc(TimeToStringHook(layout), TimeToUnixIntHook()), + } + + decoder, err := NewDecoder(config) + if err != nil { + return err + } + + return decoder.Decode(input) +} diff --git a/internal/pkg/mapstructure/my_decode_hook.go b/internal/pkg/mapstructure/my_decode_hook.go new file mode 100644 index 0000000..4873731 --- /dev/null +++ b/internal/pkg/mapstructure/my_decode_hook.go @@ -0,0 +1,101 @@ +package mapstructure + +import ( + "reflect" + "time" +) + +// TimeToStringHook 时间转字符串 +// 支持 *Time.time 转 string/*string +// 不能用 Time.time 转,它会在上层认为是一个结构体数据而直接转成map,再到hook方法 +func TimeToStringHook(layout string) DecodeHookFunc { + return func( + f reflect.Type, + t reflect.Type, + data interface{}) (interface{}, error) { + // 判断目标类型是否为字符串 + var strType string + var isStrPointer *bool // 要转换的目标类型是否为指针字符串 + if t == reflect.TypeOf(strType) { + isStrPointer = new(bool) + } else if t == reflect.TypeOf(&strType) { + isStrPointer = new(bool) + *isStrPointer = true + } + if isStrPointer == nil { + return data, nil + } + + // 判断类型是否为时间 + timeType := time.Time{} + if f != reflect.TypeOf(timeType) && f != reflect.TypeOf(&timeType) { + return data, nil + } + + // 将时间转换为字符串 + var output string + switch v := data.(type) { + case *time.Time: + output = v.Format(layout) + case time.Time: + output = v.Format(layout) + default: + return data, nil + } + + if *isStrPointer { + return &output, nil + } + return output, nil + } +} + +// TimeToUnixIntHook 时间转时间戳 +// 支持 *Time.time 转 uint/uint32/uint64/int/int32/int64,支持带指针 +// 不能用 Time.time 转,它会在上层认为是一个结构体数据而直接转成map,再到hook方法 +func TimeToUnixIntHook() DecodeHookFunc { + return func( + f reflect.Type, + t reflect.Type, + data interface{}) (interface{}, error) { + + tkd := t.Kind() + if tkd != reflect.Int && tkd != reflect.Int32 && tkd != reflect.Int64 && + tkd != reflect.Uint && tkd != reflect.Uint32 && tkd != reflect.Uint64 { + return data, nil + } + + // 判断类型是否为时间 + timeType := time.Time{} + if f != reflect.TypeOf(timeType) && f != reflect.TypeOf(&timeType) { + return data, nil + } + + // 将时间转换为字符串 + var output int64 + switch v := data.(type) { + case *time.Time: + output = v.Unix() + case time.Time: + output = v.Unix() + default: + return data, nil + } + switch tkd { + case reflect.Int: + return int(output), nil + case reflect.Int32: + return int32(output), nil + case reflect.Int64: + return output, nil + case reflect.Uint: + return uint(output), nil + case reflect.Uint32: + return uint32(output), nil + case reflect.Uint64: + return uint64(output), nil + default: + return data, nil + } + } +} diff --git a/internal/pkg/mapstructure/my_decode_hook_test.go b/internal/pkg/mapstructure/my_decode_hook_test.go new file mode 100644 index 0000000..2f5687d --- /dev/null +++ b/internal/pkg/mapstructure/my_decode_hook_test.go @@ -0,0 +1,274 @@ +package mapstructure + +import ( + "testing" + "time" +) + +func Test_TimeToStringHook(t *testing.T) { + type Input struct { + Time time.Time + Id int + } + + type InputTPointer struct { + Time *time.Time + Id int + } + + type Output struct { + Time string + Id int + } + + type OutputTPointer struct { + Time *string + Id int + } + now := time.Now() + target := now.Format("2006-01-02 15:04:05") + idValue := 1 + tests := []struct { + input any + output any + name string + layout string + }{ + { + name: "测试Time.time转string", + layout: "2006-01-02 15:04:05", + input: InputTPointer{ + Time: &now, + Id: idValue, + }, + output: Output{}, + }, + { + name: "测试*Time.time转*string", + layout: "2006-01-02 15:04:05", + input: InputTPointer{ + Time: &now, + Id: idValue, + }, + output: OutputTPointer{}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + decoder, err := NewDecoder(&DecoderConfig{ + DecodeHook: TimeToStringHook(tt.layout), + Result: &tt.output, + }) + if err != nil { + t.Errorf("NewDecoder() jderr = %v,want nil", err) + } + + if i, isOk := tt.input.(Input); isOk { + err = decoder.Decode(i) + } + if i, isOk := tt.input.(InputTPointer); isOk { + err = decoder.Decode(&i) + } + if err != nil { + t.Errorf("Decode jderr = %v,want nil", err) + } + //验证测试值 + if output, isOk := tt.output.(OutputTPointer); isOk { + if *output.Time != target { + t.Errorf("Decode output time = %v,want %v", *output.Time, target) + } + if output.Id != idValue { + t.Errorf("Decode output id = %v,want %v", output.Id, idValue) + } + } + if output, isOk := tt.output.(Output); isOk { + if output.Time != target { + t.Errorf("Decode output time = %v,want %v", output.Time, target) + } + if output.Id != idValue { + t.Errorf("Decode output id = %v,want %v", output.Id, idValue) + } + } + }) + } +} + +func Test_TimeToUnixIntHook(t *testing.T) { + type InputTPointer struct { + Time *time.Time + Id int + } + + type Output[T int | *int | int32 | *int32 | int64 | *int64 | uint | *uint] struct { + Time T + Id int + } + + type test struct { + input any + output any + name string + layout string + } + + now := time.Now() + target := now.Unix() + idValue := 1 + tests := []test{ + { + name: "测试Time.time转int", + layout: "2006-01-02 15:04:05", + input: InputTPointer{ + Time: &now, + Id: idValue, + }, + output: Output[int]{}, + }, + { + name: "测试Time.time转*int", + layout: "2006-01-02 15:04:05", + input: InputTPointer{ + Time: &now, + Id: idValue, + }, + output: Output[*int]{}, + }, + { + name: "测试Time.time转int32", + layout: "2006-01-02 15:04:05", + input: InputTPointer{ + Time: &now, + Id: idValue, + }, + output: Output[int32]{}, + }, + { + name: "测试Time.time转*int32", + layout: "2006-01-02 15:04:05", + input: InputTPointer{ + Time: &now, + Id: idValue, + }, + output: Output[*int32]{}, + }, + { + name: "测试Time.time转int64", + layout: "2006-01-02 15:04:05", + input: InputTPointer{ + Time: &now, + Id: idValue, + }, + output: Output[int64]{}, + }, + { + name: "测试Time.time转*int64", + layout: "2006-01-02 15:04:05", + input: InputTPointer{ + Time: &now, + Id: idValue, + }, + output: Output[*int64]{}, + }, + { + name: "测试Time.time转uint", + layout: "2006-01-02 15:04:05", + input: InputTPointer{ + Time: &now, + Id: idValue, + }, + output: Output[uint]{}, + }, + { + name: "测试Time.time转*uint", + layout: "2006-01-02 15:04:05", + input: InputTPointer{ + Time: &now, + Id: idValue, + }, + output: Output[*uint]{}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + decoder, err := NewDecoder(&DecoderConfig{ + DecodeHook: TimeToUnixIntHook(), + Result: &tt.output, + }) + if err != nil { + t.Errorf("NewDecoder() jderr = %v,want nil", err) + } + + if i, isOk := tt.input.(InputTPointer); isOk { + err = decoder.Decode(i) + } + if i, isOk := tt.input.(InputTPointer); isOk { + err = decoder.Decode(&i) + } + if err != nil { + t.Errorf("Decode jderr = %v,want nil", err) + } + + //验证测试值 + switch v := tt.output.(type) { + case Output[int]: + if int64(v.Time) != target { + t.Errorf("Decode output time = %v,want %v", v.Time, target) + } + if v.Id != idValue { + t.Errorf("Decode output id = %v,want %v", v.Id, idValue) + } + case Output[*int]: + if int64(*v.Time) != target { + t.Errorf("Decode output time = %v,want %v", v.Time, target) + } + if v.Id != idValue { + t.Errorf("Decode output id = %v,want %v", v.Id, idValue) + } + case Output[int32]: + if int64(v.Time) != target { + t.Errorf("Decode output time = %v,want %v", v.Time, target) + } + if v.Id != idValue { + t.Errorf("Decode output id = %v,want %v", v.Id, idValue) + } + case Output[*int32]: + if int64(*v.Time) != target { + t.Errorf("Decode output time = %v,want %v", v.Time, target) + } + if v.Id != idValue { + t.Errorf("Decode output id = %v,want %v", v.Id, idValue) + } + case Output[int64]: + if int64(v.Time) != target { + t.Errorf("Decode output time = %v,want %v", v.Time, target) + } + if v.Id != idValue { + t.Errorf("Decode output id = %v,want %v", v.Id, idValue) + } + case Output[*int64]: + if int64(*v.Time) != target { + t.Errorf("Decode output time = %v,want %v", v.Time, target) + } + if v.Id != idValue { + t.Errorf("Decode output id = %v,want %v", v.Id, idValue) + } + case Output[uint]: + if int64(v.Time) != target { + t.Errorf("Decode output time = %v,want %v", v.Time, target) + } + if v.Id != idValue { + t.Errorf("Decode output id = %v,want %v", v.Id, idValue) + } + case Output[*uint]: + if int64(*v.Time) != target { + t.Errorf("Decode output time = %v,want %v", v.Time, target) + } + if v.Id != idValue { + t.Errorf("Decode output id = %v,want %v", v.Id, idValue) + } + } + + }) + } +} diff --git a/internal/pkg/provider_set.go b/internal/pkg/provider_set.go index 60b276d..a761c17 100644 --- a/internal/pkg/provider_set.go +++ b/internal/pkg/provider_set.go @@ -2,8 +2,11 @@ package pkg import ( "ai_scheduler/internal/pkg/ollama" - "github.com/google/wire" ) -var ProviderSetClient = wire.NewSet(ollama.NewClient) +var ProviderSetClient = wire.NewSet( + NewRdb, + NewGormDb, + ollama.NewClient, +) diff --git a/internal/pkg/rds.go b/internal/pkg/rds.go new file mode 100644 index 0000000..a54e9fc --- /dev/null +++ b/internal/pkg/rds.go @@ -0,0 +1,48 @@ +package pkg + +import ( + "ai_scheduler/internal/config" + "time" + + "github.com/redis/go-redis/v9" +) + +type Rdb struct { + Rdb *redis.Client +} + +var rdb *Rdb + +func NewRdb(c *config.Config) *Rdb { + if rdb == nil { + //构建 redis + rdbBuild := buildRdb(c.Redis) + //退出时清理资源 + rdb = &Rdb{Rdb: rdbBuild} + } + //cleanup := func() { + // if rdb != nil { + // if err := rdb.Rdb.Close(); err != nil { + // fmt.Println("关闭 redis 失败:%v", err) + // } + // } + // fmt.Println("关闭 data 中的连接资源已完成") + //} + return rdb +} + +// buildRdb 构建redis client +func buildRdb(c *config.Redis) *redis.Client { + + rdb := redis.NewClient(&redis.Options{ + Addr: c.Host, + Password: c.Pass, + ReadTimeout: time.Duration(c.Tls) * time.Second, + WriteTimeout: time.Duration(c.Tls) * time.Second, + PoolSize: int(c.PoolSize), + MinIdleConns: int(c.MaxIdle), + ConnMaxIdleTime: time.Duration(c.MaxIdleTime) * time.Second, + DB: int(c.Db), + }) + return rdb +} diff --git a/internal/pkg/utils_gorm/gorm.go b/internal/pkg/utils_gorm/gorm.go new file mode 100644 index 0000000..0a579d9 --- /dev/null +++ b/internal/pkg/utils_gorm/gorm.go @@ -0,0 +1,42 @@ +package utils_gorm + +import ( + "ai_scheduler/internal/config" + "database/sql" + "fmt" + "time" + + "gorm.io/driver/mysql" + "gorm.io/gorm" +) + +func DBConn(c *config.DB) (*gorm.DB, func()) { + mysqlConn, err := sql.Open(c.Driver, c.Source) + gormDB, err := gorm.Open( + mysql.New(mysql.Config{Conn: mysqlConn}), + ) + + gormDB.Logger = NewCustomLogger(gormDB) + if err != nil { + panic("failed to connect database") + } + sqlDB, err := gormDB.DB() + + // SetMaxIdleConns sets the maximum number of connections in the idle connection pool. + sqlDB.SetMaxIdleConns(int(c.MaxIdle)) + + // SetMaxOpenConns sets the maximum number of open connections to the database. + sqlDB.SetMaxOpenConns(int(c.MaxLifetime)) + + // SetConnMaxLifetime sets the maximum amount of time a connection may be reused. + sqlDB.SetConnMaxLifetime(time.Hour) + + return gormDB, func() { + if mysqlConn != nil { + fmt.Println("关闭 physicalGoodsDB") + if err := mysqlConn.Close(); err != nil { + fmt.Println("关闭 physicalGoodsDB 失败:", err) + } + } + } +} diff --git a/internal/pkg/utils_gorm/sql_log.go b/internal/pkg/utils_gorm/sql_log.go new file mode 100644 index 0000000..b34a29d --- /dev/null +++ b/internal/pkg/utils_gorm/sql_log.go @@ -0,0 +1,96 @@ +package utils_gorm + +import ( + "context" + "fmt" + "gorm.io/gorm" + "gorm.io/gorm/logger" + "regexp" + "strings" + "time" +) + +type CustomLogger struct { + gormLogger logger.Interface + db *gorm.DB +} + +func NewCustomLogger(db *gorm.DB) *CustomLogger { + return &CustomLogger{ + gormLogger: logger.Default.LogMode(logger.Info), + db: db, + } +} + +func (l *CustomLogger) LogMode(level logger.LogLevel) logger.Interface { + newlogger := *l + newlogger.gormLogger = l.gormLogger.LogMode(level) + return &newlogger +} + +func (l *CustomLogger) Info(ctx context.Context, msg string, data ...interface{}) { + l.gormLogger.Info(ctx, msg, data...) +} + +func (l *CustomLogger) Warn(ctx context.Context, msg string, data ...interface{}) { + l.gormLogger.Warn(ctx, msg, data...) +} + +func (l *CustomLogger) Error(ctx context.Context, msg string, data ...interface{}) { + l.gormLogger.Error(ctx, msg, data...) +} + +func (l *CustomLogger) Trace(ctx context.Context, begin time.Time, fc func() (sql string, rowsAffected int64), err error) { + elapsed := time.Since(begin) + sql, _ := fc() + l.gormLogger.Trace(ctx, begin, fc, err) + operation := extractOperation(sql) + tableName := extractTableName(sql) + fmt.Println(tableName) + //// 将SQL语句保存到数据库 + if operation == 0 || tableName == "sql_log" { + return + } + //go l.db.Model(&SqlLog{}).Create(&SqlLog{ + // OperatorID: 1, + // OperatorName: "test", + // SqlInfo: sql, + // TableNames: tableName, + // Type: operation, + //}) + + // 如果有需要,也可以根据执行时间(elapsed)等条件过滤或处理日志记录 + if elapsed > time.Second { + //l.gormLogger.Warn(ctx, "Slow SQL (> 1s): %s", sql) + } +} + +// extractTableName extracts the table name from a SQL query, supporting quoted table names. +func extractTableName(sql string) string { + // 使用非捕获组匹配多种SQL操作关键词 + re := regexp.MustCompile(`(?i)\b(?:from|update|into|delete\s+from)\b\s+[\` + "`" + `"]?(\w+)[\` + "`" + `"]?`) + match := re.FindStringSubmatch(sql) + + // 检查是否匹配成功 + if len(match) > 1 { + return match[1] + } + + return "" +} + +// extractOperation extracts the operation type from a SQL query. +func extractOperation(sql string) int32 { + sql = strings.TrimSpace(strings.ToLower(sql)) + var operation int32 + if strings.HasPrefix(sql, "select") { + operation = 0 + } else if strings.HasPrefix(sql, "insert") { + operation = 1 + } else if strings.HasPrefix(sql, "update") { + operation = 3 + } else if strings.HasPrefix(sql, "delete") { + operation = 2 + } + return operation +} diff --git a/internal/services/chat.go b/internal/services/chat.go index b10429f..0dd09fb 100644 --- a/internal/services/chat.go +++ b/internal/services/chat.go @@ -1,6 +1,7 @@ package services import ( + "ai_scheduler/internal/data/constant" "ai_scheduler/internal/entitys" "encoding/json" "log" @@ -40,62 +41,54 @@ func (h *ChatService) ChatFail(c *websocket.Conn, content string) { func (h *ChatService) Chat(c *websocket.Conn) { log.Println("客户端已连接") - + defer c.Close() // 循环读取客户端消息 for { - messageType, msg, err := c.ReadMessage() + messageType, message, err := c.ReadMessage() if err != nil { log.Println("读取错误:", err) break } + msg, chatType := h.handleMessageToString(c, messageType, message) + if chatType == constant.ConnStatusClosed { + break + } + if chatType == constant.ConnStatusIgnore { + continue + } - log.Printf("收到消息: %s", msg) + log.Printf("收到消息: %s", string(msg)) var req entitys.ChatSockRequest if err := json.Unmarshal(msg, &req); err != nil { log.Println("JSON parse error:", err) continue } - - // 回显消息给客户端 - if err := c.WriteMessage(messageType, msg); err != nil { - log.Println("写入错误:", err) - break + err = h.routerService.RouteWithSocket(c, &req) + if err != nil { + log.Println("处理失败:", err) + continue } } - log.Println("客户端已断开") - //var req entitys.ChatRequest - //if err := c.BodyParser(&req); err != nil { - // return errors.ParamError - //} - // - //// 转换为服务层请求 - //serviceReq := &entitys.ChatRequest{ - // UserInput: req.UserInput, - // Caller: req.Caller, - // SessionID: req.SessionID, - // ChatRequestMeta: entitys.ChatRequestMeta{ - // Authorization: c.Request().Header(), - // }, - //} - // - //// 调用路由服务 - //response, err := h.routerService.Route(c.Request.Context(), serviceReq) - //if err != nil { - // c.JSON(http.StatusInternalServerError, ChatResponse{ - // Status: "error", - // Message: err.Error(), - // }) - // return - //} - // - //// 转换响应格式 - //httpResponse := &ChatResponse{ - // Message: response.Message, - // Status: response.Status, - // Data: response.Data, - // TaskCode: response.TaskCode, - //} - // - //c.JSON(http.StatusOK, httpResponse) +} + +func (h *ChatService) handleMessageToString(c *websocket.Conn, msgType int, msg any) (text []byte, chatType constant.ConnStatus) { + switch msgType { + case websocket.TextMessage: + return msg.([]byte), constant.ConnStatusNormal + case websocket.BinaryMessage: + return msg.([]byte), constant.ConnStatusNormal + case websocket.CloseMessage: + + return nil, constant.ConnStatusClosed + case websocket.PingMessage: + // 可选:回复 Pong + c.WriteMessage(websocket.PongMessage, nil) + return nil, constant.ConnStatusIgnore + case websocket.PongMessage: + return nil, constant.ConnStatusIgnore + default: + return nil, constant.ConnStatusIgnore + } + return msg.([]byte), constant.ConnStatusIgnore } diff --git a/tmpl/dataTemp/queryTempl.go b/tmpl/dataTemp/queryTempl.go new file mode 100644 index 0000000..f3f44f2 --- /dev/null +++ b/tmpl/dataTemp/queryTempl.go @@ -0,0 +1,113 @@ +package dataTemp + +import ( + "ai_scheduler/internal/pkg/mapstructure" + "ai_scheduler/utils" + "context" + "database/sql" + + "github.com/go-kratos/kratos/v2/log" + "gorm.io/gorm" + + "xorm.io/builder" +) + +type PrimaryKey struct { + Id int `json:"id"` +} + +type GormDb struct { + Client *gorm.DB +} +type contextTxKey struct{} + +func (d *Db) DB(ctx context.Context) *gorm.DB { + tx, ok := ctx.Value(contextTxKey{}).(*gorm.DB) + if ok { + return tx + } + return d.Db.Client +} + +func (t *Db) ExecTx(ctx context.Context, f func(ctx context.Context) error) error { + return t.Db.Client.WithContext(ctx).Transaction(func(tx *gorm.DB) error { + ctx = context.WithValue(ctx, contextTxKey{}, tx) + return f(ctx) + }) +} + +type Db struct { + Db *GormDb + Log *log.Helper +} + +type DataTemp struct { + Db *gorm.DB + Model interface{} + Do interface{} +} + +func NewDataTemp(db *utils.Db, model interface{}) *DataTemp { + return &DataTemp{Db: db.Client, Model: model} +} + +func (k DataTemp) GetById(id int) (data map[string]interface{}, err error) { + err = k.Db.Model(k.Model).Where("id = ?", id).Find(&data).Error + if data == nil { + err = sql.ErrNoRows + } + return +} + +func (k DataTemp) Add(data interface{}) (id int, err error) { + var primary *PrimaryKey + add := k.Db.Model(k.Model).Create(data) + _ = mapstructure.Decode(data, &primary) + return primary.Id, add.Error +} + +func (k DataTemp) GetList(cond *builder.Cond, pageBoIn *ReqPageBo) (list []map[string]interface{}, pageBoOut *RespPageBo, err error) { + var ( + query, _ = builder.ToBoundSQL(*cond) + model = k.Db.Model(k.Model).Where(query) + total int64 + ) + model.Count(&total) + pageBoOut = pageBoOut.SetDataByReq(total, pageBoIn) + model.Limit(pageBoIn.GetSize()).Offset(pageBoIn.GetOffset()).Order("updated_at desc").Find(&list) + return +} + +func (k DataTemp) GetRange(cond *builder.Cond) (list []map[string]interface{}, err error) { + var ( + query, _ = builder.ToBoundSQL(*cond) + model = k.Db.Model(k.Model).Where(query) + ) + err = model.Find(&list).Error + return list, err +} + +func (k DataTemp) GetOneBySearch(cond *builder.Cond) (data map[string]interface{}, err error) { + query, _ := builder.ToBoundSQL(*cond) + err = k.Db.Model(k.Model).Where(query).Limit(1).Find(&data).Error + if data == nil { + err = sql.ErrNoRows + } + return +} + +func (k DataTemp) GetOneBySearchToStrut(cond *builder.Cond, result interface{}) error { + query, _ := builder.ToBoundSQL(*cond) + err := k.Db.Model(k.Model).Where(query).Limit(1).Find(&result).Error + + return err +} + +func (k DataTemp) UpdateByCond(cond *builder.Cond, data interface{}) (err error) { + var ( + query, _ = builder.ToBoundSQL(*cond) + model = k.Db.Model(k.Model) + ) + err = model.Where(query).Updates(data).Error + return +} diff --git a/tmpl/dataTemp/req_page.go b/tmpl/dataTemp/req_page.go new file mode 100644 index 0000000..985cc5a --- /dev/null +++ b/tmpl/dataTemp/req_page.go @@ -0,0 +1,35 @@ +package dataTemp + +// ReqPageBo 分页请求实体 +type ReqPageBo struct { + Page int //页码,从第1页开始 + Limit int //分页大小 +} + +// GetOffset 获取便宜量 +func (r *ReqPageBo) GetOffset() int { + if r == nil { + return 0 + } + offset := (r.Page - 1) * r.Limit + if offset < 0 { + return 0 + } + return offset +} + +// GetSize 获取分页 +func (r *ReqPageBo) GetSize() int { + if r == nil { + return 20 + } + return r.Limit +} + +// GetNum 获取页码 +func (r *ReqPageBo) GetNum() int { + if r == nil { + return 1 + } + return r.Page +} diff --git a/tmpl/dataTemp/resp_page.go b/tmpl/dataTemp/resp_page.go new file mode 100644 index 0000000..e707c8c --- /dev/null +++ b/tmpl/dataTemp/resp_page.go @@ -0,0 +1,22 @@ +package dataTemp + +// RespPageBo 分页响应实体 +type RespPageBo struct { + Page int //页码 + Limit int //每页大小 + Total int64 //总数 +} + +// SetDataByReq 通过req 设置响应参数 +func (r *RespPageBo) SetDataByReq(total int64, reqPage *ReqPageBo) *RespPageBo { + resp := r + if r == nil { + resp = &RespPageBo{} + } + resp.Total = total + if reqPage != nil { + resp.Page = reqPage.Page + resp.Limit = reqPage.Limit + } + return resp +} diff --git a/tmpl/errcode/common.go b/tmpl/errcode/common.go new file mode 100644 index 0000000..bb082b1 --- /dev/null +++ b/tmpl/errcode/common.go @@ -0,0 +1,86 @@ +package errcode + +import ( + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +// SuccessMsg 自定义成功消息 +var SuccessMsg = "成功" +var SuccessCode = 200 + +func SetSuccessMsg(msg string) { + SuccessMsg = msg +} + +type BusinessErr struct { + Code int32 + Message string +} + +func (e *BusinessErr) Error() string { + return e.Message +} +func (e *BusinessErr) GRPCStatus() *status.Status { + var code = codes.Code(e.Code) + return status.New(code, e.Message) +} + +// CustomErr 自定义错误 +func CustomErr(code int32, message string) *BusinessErr { + return &BusinessErr{Code: code, Message: message} +} + +var ( + NotFoundErr = &BusinessErr{Code: 404, Message: "资源未找到"} + ParamError = &BusinessErr{Code: 400, Message: "参数错误"} + //未经授权 + NotAuth = &BusinessErr{Code: 401, Message: "未授权"} + EnCrypt = &BusinessErr{Code: 401, Message: "加密失败"} + DeCrypt = &BusinessErr{Code: 401, Message: "解密失败"} + + //请求被禁止 + Forbidden = &BusinessErr{Code: 403, Message: "禁止访问"} + WhiteIp = &BusinessErr{Code: 403, Message: "访问IP,不在白名单内"} + + //系统错误 + SystemError = &BusinessErr{Code: 500, Message: "系统错误"} + + ThirtyDayQueryLimit = &BusinessErr{Code: 420, Message: "只能查询最近31天的数据"} + + GoodsSameErr = &BusinessErr{Code: 404, Message: "存在相同货品编码|货品名称的商品,请检查后重试"} + + HadDefaultWareHouseErr = &BusinessErr{Code: 400, Message: "该商品已存在默认仓,请检查后重试"} + + HadSameSupplierRelation = &BusinessErr{Code: 400, Message: "已存在相同的供应商商品关系,请检查后重试"} + + AppRsaEncryptKeyNotFound = &BusinessErr{Code: 400, Message: "密钥缺失"} + + AppRsaEncryptFail = &BusinessErr{Code: 400, Message: "Rsa加密失败"} + + AppRsaDecryptKeyNotFound = &BusinessErr{Code: 400, Message: "密钥缺失"} + + AppRsaDecryptFail = &BusinessErr{Code: 400, Message: "Rsa解密失败"} + + AppSM2EncryptKeyNotFound = &BusinessErr{Code: 400, Message: "密钥缺失"} + + AppSM2EncryptFail = &BusinessErr{Code: 400, Message: "Sm2加密失败"} + + AppSM2DecryptKeyNotFound = &BusinessErr{Code: 400, Message: "密钥缺失"} + + AppSM2DecryptFail = &BusinessErr{Code: 400, Message: "Sm2解密失败"} + + AppSM4EncryptKeyNotFound = &BusinessErr{Code: 400, Message: "密钥缺失"} + + AppSM4EncryptFail = &BusinessErr{Code: 400, Message: "Sm4加密失败"} + + AppSM4DecryptKeyNotFound = &BusinessErr{Code: 400, Message: "密钥缺失"} + + AppSM4DecryptFail = &BusinessErr{Code: 400, Message: "Sm4解密失败"} + + BalanceNotEnough = &BusinessErr{Code: 400, Message: "余额不足"} + + CusStatusException = &BusinessErr{Code: 400, Message: "客户状态异常"} + + HadSameCusGoods = &BusinessErr{Code: 400, Message: "已存在相同的客户商品授权,请检查后重试"} +) diff --git a/utils/gorm.go b/utils/gorm.go new file mode 100644 index 0000000..976cc51 --- /dev/null +++ b/utils/gorm.go @@ -0,0 +1,25 @@ +package utils + +import ( + "ai_scheduler/internal/config" + "ai_scheduler/internal/pkg/utils_gorm" + + "gorm.io/gorm" +) + +type Db struct { + Client *gorm.DB +} + +func NewGormDb(c *config.Config) (*Db, func()) { + transDBClient, mf := utils_gorm.DBConn(c.DB) + //directDBClient, df := directDB(c, hLog) + cleanup := func() { + mf() + //df() + } + return &Db{ + Client: transDBClient, + //DirectDBClient: directDBClient, + }, cleanup +} diff --git a/utils/provider_set.go b/utils/provider_set.go new file mode 100644 index 0000000..478d456 --- /dev/null +++ b/utils/provider_set.go @@ -0,0 +1,10 @@ +package utils + +import ( + "github.com/google/wire" +) + +var ProviderUtils = wire.NewSet( + NewRdb, + NewGormDb, +) diff --git a/utils/rds.go b/utils/rds.go new file mode 100644 index 0000000..57b89f1 --- /dev/null +++ b/utils/rds.go @@ -0,0 +1,48 @@ +package utils + +import ( + "ai_scheduler/internal/config" + "time" + + "github.com/redis/go-redis/v9" +) + +type Rdb struct { + Rdb *redis.Client +} + +var rdb *Rdb + +func NewRdb(c *config.Redis) *Rdb { + if rdb == nil { + //构建 redis + rdbBuild := buildRdb(c) + //退出时清理资源 + rdb = &Rdb{Rdb: rdbBuild} + } + //cleanup := func() { + // if rdb != nil { + // if err := rdb.Rdb.Close(); err != nil { + // fmt.Println("关闭 redis 失败:%v", err) + // } + // } + // fmt.Println("关闭 data 中的连接资源已完成") + //} + return rdb +} + +// buildRdb 构建redis client +func buildRdb(c *config.Redis) *redis.Client { + + rdb := redis.NewClient(&redis.Options{ + Addr: c.Host, + Password: c.Pass, + ReadTimeout: time.Duration(c.Tls) * time.Second, + WriteTimeout: time.Duration(c.Tls) * time.Second, + PoolSize: int(c.PoolSize), + MinIdleConns: int(c.MaxIdle), + ConnMaxIdleTime: time.Duration(c.MaxIdleTime) * time.Second, + DB: int(c.Db), + }) + return rdb +} diff --git a/utils/utils_gorm/gorm.go b/utils/utils_gorm/gorm.go new file mode 100644 index 0000000..ec9c9f9 --- /dev/null +++ b/utils/utils_gorm/gorm.go @@ -0,0 +1,41 @@ +package utils_gorm + +import ( + "ai_scheduler/internal/config" + "database/sql" + "fmt" + "gorm.io/driver/mysql" + "gorm.io/gorm" + "time" +) + +func TransDB(c *config.DB) (*gorm.DB, func()) { + mysqlConn, err := sql.Open(c.Driver, c.Source) + gormDB, err := gorm.Open( + mysql.New(mysql.Config{Conn: mysqlConn}), + ) + + gormDB.Logger = NewCustomLogger(gormDB) + if err != nil { + panic("failed to connect database") + } + sqlDB, err := gormDB.DB() + + // SetMaxIdleConns sets the maximum number of connections in the idle connection pool. + sqlDB.SetMaxIdleConns(int(c.MaxIdle)) + + // SetMaxOpenConns sets the maximum number of open connections to the database. + sqlDB.SetMaxOpenConns(int(c.MaxLifetime)) + + // SetConnMaxLifetime sets the maximum amount of time a connection may be reused. + sqlDB.SetConnMaxLifetime(time.Hour) + + return gormDB, func() { + if mysqlConn != nil { + fmt.Println("关闭 physicalGoodsDB") + if err := mysqlConn.Close(); err != nil { + fmt.Println("关闭 physicalGoodsDB 失败:", err) + } + } + } +} diff --git a/utils/utils_gorm/sql_log.go b/utils/utils_gorm/sql_log.go new file mode 100644 index 0000000..b34a29d --- /dev/null +++ b/utils/utils_gorm/sql_log.go @@ -0,0 +1,96 @@ +package utils_gorm + +import ( + "context" + "fmt" + "gorm.io/gorm" + "gorm.io/gorm/logger" + "regexp" + "strings" + "time" +) + +type CustomLogger struct { + gormLogger logger.Interface + db *gorm.DB +} + +func NewCustomLogger(db *gorm.DB) *CustomLogger { + return &CustomLogger{ + gormLogger: logger.Default.LogMode(logger.Info), + db: db, + } +} + +func (l *CustomLogger) LogMode(level logger.LogLevel) logger.Interface { + newlogger := *l + newlogger.gormLogger = l.gormLogger.LogMode(level) + return &newlogger +} + +func (l *CustomLogger) Info(ctx context.Context, msg string, data ...interface{}) { + l.gormLogger.Info(ctx, msg, data...) +} + +func (l *CustomLogger) Warn(ctx context.Context, msg string, data ...interface{}) { + l.gormLogger.Warn(ctx, msg, data...) +} + +func (l *CustomLogger) Error(ctx context.Context, msg string, data ...interface{}) { + l.gormLogger.Error(ctx, msg, data...) +} + +func (l *CustomLogger) Trace(ctx context.Context, begin time.Time, fc func() (sql string, rowsAffected int64), err error) { + elapsed := time.Since(begin) + sql, _ := fc() + l.gormLogger.Trace(ctx, begin, fc, err) + operation := extractOperation(sql) + tableName := extractTableName(sql) + fmt.Println(tableName) + //// 将SQL语句保存到数据库 + if operation == 0 || tableName == "sql_log" { + return + } + //go l.db.Model(&SqlLog{}).Create(&SqlLog{ + // OperatorID: 1, + // OperatorName: "test", + // SqlInfo: sql, + // TableNames: tableName, + // Type: operation, + //}) + + // 如果有需要,也可以根据执行时间(elapsed)等条件过滤或处理日志记录 + if elapsed > time.Second { + //l.gormLogger.Warn(ctx, "Slow SQL (> 1s): %s", sql) + } +} + +// extractTableName extracts the table name from a SQL query, supporting quoted table names. +func extractTableName(sql string) string { + // 使用非捕获组匹配多种SQL操作关键词 + re := regexp.MustCompile(`(?i)\b(?:from|update|into|delete\s+from)\b\s+[\` + "`" + `"]?(\w+)[\` + "`" + `"]?`) + match := re.FindStringSubmatch(sql) + + // 检查是否匹配成功 + if len(match) > 1 { + return match[1] + } + + return "" +} + +// extractOperation extracts the operation type from a SQL query. +func extractOperation(sql string) int32 { + sql = strings.TrimSpace(strings.ToLower(sql)) + var operation int32 + if strings.HasPrefix(sql, "select") { + operation = 0 + } else if strings.HasPrefix(sql, "insert") { + operation = 1 + } else if strings.HasPrefix(sql, "update") { + operation = 3 + } else if strings.HasPrefix(sql, "delete") { + operation = 2 + } + return operation +} diff --git a/utils/utils_sys_cache/gods.go b/utils/utils_sys_cache/gods.go new file mode 100644 index 0000000..cf1bbc2 --- /dev/null +++ b/utils/utils_sys_cache/gods.go @@ -0,0 +1,14 @@ +package utils_sys_cache + +import ( + "github.com/emirpasic/gods/maps/hashmap" + "github.com/emirpasic/gods/maps/treemap" +) + +func NewCacheHashMap() *hashmap.Map { + return hashmap.New() +} + +func NewCacheTreeMap() *treemap.Map { + return treemap.NewWithIntComparator() +} diff --git a/utils/utils_test.go b/utils/utils_test.go new file mode 100644 index 0000000..63363f0 --- /dev/null +++ b/utils/utils_test.go @@ -0,0 +1,72 @@ +package utils + +import ( + "fmt" + "github.com/go-kratos/kratos/v2/log" + "google.golang.org/protobuf/runtime/protoimpl" + "gopkg.in/yaml.v3" + "io/fs" + "os" + "path/filepath" + "testing" + baseconf "trans_hub/base_conf" + "trans_hub/pkg" + "trans_hub/pkg/mapstructure" +) + +const SPACE = "public" +const PORT = 8848 +const User = "" +const Pass = "" +const IP = "192.168.110.93" +const Group = "DEFAULT_GROUP" +const DataId = "PG_BASE_CONFIG" + +func TestConfig(t *testing.T) { + type Nacos struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + Ip string `protobuf:"bytes,1,opt,name=ip,proto3" json:"ip,omitempty"` + Port uint64 `protobuf:"varint,2,opt,name=port,proto3" json:"port,omitempty"` + } + type Conf struct { + Nacos *Nacos `protobuf:"bytes,8,opt,name=nacos,proto3" json:"nacos,omitempty"` + } + var c Conf + nc := &baseconf.Nacos{Ip: IP, Port: PORT, Space: SPACE, User: User, Password: Pass} + + var s = ServerConfig(nc, Group, DataId) + err := mapstructure.Decode(s, &c) + t.Log(s, err) +} + +func TestMod(t *testing.T) { + dir := pkg.GetRootPath() + // 读取目录内容 + err := filepath.WalkDir(dir, func(path string, d fs.DirEntry, err error) error { + if err != nil { + return err + } + if !d.IsDir() && filepath.Ext(path) == ".yaml" { + data, err := os.ReadFile(path) + if err != nil { + return err + } + var result map[string]interface{} + err = yaml.Unmarshal(data, &result) // 解析YAML到map中,使用gopkg.v3的yaml包或其他你选择的版本(例如encoding/yaml) + if err != nil { + return err + } + fmt.Printf("File: %s\nContent: %+v\n", path, result) + } + return nil + }) + if err != nil { + log.Fatal(err) + } +} + +func TestYaml(t *testing.T) { + t.Log(GetBaseYaml()) +}