commit 83f7d34f3ea5487c3527b005d080fa22c3da617d Author: renzhiyuan <465386466@qq.com> Date: Tue Apr 7 21:55:25 2026 +0800 1 diff --git a/.idea/.gitignore b/.idea/.gitignore new file mode 100644 index 0000000..13566b8 --- /dev/null +++ b/.idea/.gitignore @@ -0,0 +1,8 @@ +# Default ignored files +/shelf/ +/workspace.xml +# Editor-based HTTP Client requests +/httpRequests/ +# Datasource local storage ignored files +/dataSources/ +/dataSources.local.xml diff --git a/.idea/geogo.iml b/.idea/geogo.iml new file mode 100644 index 0000000..5e764c4 --- /dev/null +++ b/.idea/geogo.iml @@ -0,0 +1,9 @@ + + + + + + + + + \ No newline at end of file diff --git a/.idea/modules.xml b/.idea/modules.xml new file mode 100644 index 0000000..b26b745 --- /dev/null +++ b/.idea/modules.xml @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file diff --git a/.idea/vcs.xml b/.idea/vcs.xml new file mode 100644 index 0000000..94a25f7 --- /dev/null +++ b/.idea/vcs.xml @@ -0,0 +1,6 @@ + + + + + + \ No newline at end of file diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..3df7a74 --- /dev/null +++ b/Makefile @@ -0,0 +1,30 @@ +GOHOSTOS:=$(shell go env GOHOSTOS) +GOPATH:=$(shell go env GOPATH) +VERSION=$(shell git describe --tags --always) +PROJECTNAME=yumchina + +ifeq ($(GOHOSTOS), windows) + #the `find.exe` is different from `find` in bash/shell. + #to see https://docs.microsoft.com/en-us/windows-server/administration/windows-commands/find. + #changed to use git-bash.exe to run find cli or other cli friendly, caused of every developer has a Git. + #Git_Bash= $(subst cmd\,bin\bash.exe,$(dir $(shell where git))) + Git_Bash=$(subst \,/,$(subst cmd\,bin\bash.exe,$(dir $(shell where git | grep cmd)))) + INTERNAL_PROTO_FILES=$(shell $(Git_Bash) -c "find internal -name *.proto") + API_PROTO_FILES=$(shell $(Git_Bash) -c "find api -name *.proto") +else + INTERNAL_PROTO_FILES=$(shell find internal -name *.proto) + API_PROTO_FILES=$(shell find api -name *.proto) +endif + + +.PHONY: wire +# generate wire +wire: + cd ./cmd/server && wire + +.PHONY: build +# build +build: +# make config; + make wire; + mkdir -p bin/ && go build -ldflags "-X main.Version=$(VERSION)" -o ./bin/ ./... \ No newline at end of file diff --git a/cmd/server/main.go b/cmd/server/main.go new file mode 100644 index 0000000..ad0e0ce --- /dev/null +++ b/cmd/server/main.go @@ -0,0 +1,26 @@ +package main + +import ( + "fmt" + "geo/internal/config" + + "github.com/gofiber/fiber/v2/log" +) + +func main() { + bc, err := config.LoadConfig() + if err != nil { + log.Fatalf("加载配置失败: %v", err) + } + //ctx, cancel := context.WithCancel(nil) + app, cleanup, err := InitializeApp(bc, log.DefaultLogger()) + if err != nil { + log.Fatalf("项目初始化失败: %v", err) + } + defer func() { + cleanup() + //cancel() + }() + addr := fmt.Sprintf("0.0.0.0:%d", bc.Server.Port) + log.Fatal(app.HttpServer.Listen(addr)) +} diff --git a/cmd/server/wire.go b/cmd/server/wire.go new file mode 100644 index 0000000..3999fc5 --- /dev/null +++ b/cmd/server/wire.go @@ -0,0 +1,28 @@ +//go:build wireinject +// +build wireinject + +package main + +import ( + "geo/internal/biz" + "geo/internal/config" + "geo/internal/data/impl" + "geo/internal/server" + "geo/internal/service" + "geo/utils" + + "github.com/gofiber/fiber/v2/log" + "github.com/google/wire" +) + +// InitializeApp 初始化应用程序 +func InitializeApp(*config.Config, log.AllLogger) (*server.Servers, func(), error) { + panic(wire.Build( + server.ProviderSetServer, + + service.ProviderSetAppService, + impl.ProviderImpl, + utils.ProviderUtils, + biz.ProviderSetBiz, + )) +} diff --git a/cmd/server/wire_gen.go b/cmd/server/wire_gen.go new file mode 100644 index 0000000..033d04f --- /dev/null +++ b/cmd/server/wire_gen.go @@ -0,0 +1,41 @@ +// Code generated by Wire. DO NOT EDIT. + +//go:generate go run -mod=mod github.com/google/wire/cmd/wire +//go:build !wireinject +// +build !wireinject + +package main + +import ( + "geo/internal/biz" + "geo/internal/config" + "geo/internal/data/impl" + "geo/internal/server" + "geo/internal/server/router" + "geo/internal/service" + "geo/utils" + "github.com/gofiber/fiber/v2/log" +) + +// Injectors from wire.go: + +// InitializeApp 初始化应用程序 +func InitializeApp(configConfig *config.Config, allLogger log.AllLogger) (*server.Servers, func(), error) { + db, cleanup := utils.NewGormDb(configConfig) + tokenImpl := impl.NewTokenImpl(db) + userImpl := impl.NewUserImpl(db) + platImpl := impl.NewPlatImpl(db) + publishImpl := impl.NewPublishImpl(db) + loginRelationImpl := impl.NewLoginRelationImpl(db) + publishBiz := biz.NewPublishBiz(configConfig, publishImpl, userImpl, platImpl, tokenImpl, loginRelationImpl) + appService := service.NewAppService(configConfig, tokenImpl, userImpl, platImpl, publishBiz) + loginService := service.NewLoginService(configConfig, publishBiz) + publishService := service.NewPublishService(configConfig, publishBiz, db) + appModule := router.NewAppModule(configConfig, appService, loginService, publishService) + routerServer := router.NewRouterServer(appModule) + app := server.NewHTTPServer(routerServer) + servers := server.NewServers(configConfig, app) + return servers, func() { + cleanup() + }, nil +} diff --git a/cookies/0d86b848uu2183uu4a08/xhs.json b/cookies/0d86b848uu2183uu4a08/xhs.json new file mode 100644 index 0000000..0637a08 --- /dev/null +++ b/cookies/0d86b848uu2183uu4a08/xhs.json @@ -0,0 +1 @@ +[] \ No newline at end of file diff --git a/gen.bat b/gen.bat new file mode 100644 index 0000000..c9dbd11 --- /dev/null +++ b/gen.bat @@ -0,0 +1,31 @@ +@echo off +chcp 65001 >nul +setlocal enabledelayedexpansion + +REM Usage: gen.bat user +REM Example: gen.bat usercenter + +if "%1"=="" ( + echo Error: Please provide table name + echo Usage: gen.bat user + pause + exit /b 1 +) + +set tables=%1 +set modeldir=.\internal\data\model +set prefix=edu_ + +set fullTableName=%prefix%%tables% + +echo Generating table: %fullTableName% + +go run gorm.io/gen/tools/gentool -dsn "root:lansexiongdi6,@tcp(47.97.27.195:3306)/geo?charset=utf8mb4&parseTime=true&loc=Asia/Shanghai" -outPath "%modeldir%" -onlyModel -modelPkgName "model" -tables "%fullTableName%" + +if %errorlevel% equ 0 ( + echo Success! Generated files in: %modeldir% +) else ( + echo Failed to generate models +) + +pause \ No newline at end of file diff --git a/gen.sh b/gen.sh new file mode 100644 index 0000000..18e4ef4 --- /dev/null +++ b/gen.sh @@ -0,0 +1,19 @@ +#!/bin/bash + + # 使用方法: + # ./genModel.sh usercenter user + # ./genModel.sh usercenter user_auth + # 再将./genModel下的文件剪切到对应服务的model目录里面,记得改package + + + #生成的表名 + tables=$1 + #表生成的genmodel目录 + modeldir=./internal/data/model + + # 数据库配置 + prefix=edu_ + + + +gentool --dsn "root:lansexiongdi6,@tcp(47.97.27.195:3306)/geo?charset=utf8mb4&parseTime=true&loc=Asia%2FShanghai" -outPath ${modeldir} -onlyModel -modelPkgName "model" -tables ${prefix}${tables} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..4834fb1 --- /dev/null +++ b/go.mod @@ -0,0 +1,64 @@ +module geo + +go 1.26.1 + +require ( + github.com/JohannesKaufmann/html-to-markdown v1.6.0 + github.com/go-kratos/kratos/v2 v2.9.2 + github.com/go-playground/validator/v10 v10.30.2 + github.com/go-rod/rod v0.116.2 + github.com/go-viper/mapstructure/v2 v2.5.0 + github.com/gofiber/fiber/v2 v2.52.12 + github.com/golang-jwt/jwt/v5 v5.3.1 + github.com/google/uuid v1.6.0 + github.com/google/wire v0.7.0 + github.com/grokify/html-strip-tags-go v0.1.0 + github.com/redis/go-redis/v9 v9.18.0 + github.com/spf13/viper v1.21.0 + gorm.io/driver/mysql v1.6.0 + gorm.io/gorm v1.31.1 + xorm.io/builder v0.3.13 +) + +require ( + filippo.io/edwards25519 v1.1.0 // indirect + github.com/PuerkitoBio/goquery v1.9.2 // indirect + github.com/andybalholm/brotli v1.1.0 // indirect + github.com/andybalholm/cascadia v1.3.2 // indirect + github.com/cespare/xxhash/v2 v2.3.0 // indirect + github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect + github.com/fsnotify/fsnotify v1.9.0 // indirect + github.com/gabriel-vasile/mimetype v1.4.13 // indirect + github.com/go-playground/locales v0.14.1 // indirect + github.com/go-playground/universal-translator v0.18.1 // indirect + github.com/go-sql-driver/mysql v1.8.1 // indirect + github.com/jinzhu/inflection v1.0.0 // indirect + github.com/jinzhu/now v1.1.5 // indirect + github.com/klauspost/compress v1.17.9 // indirect + github.com/leodido/go-urn v1.4.0 // indirect + github.com/mattn/go-colorable v0.1.13 // indirect + github.com/mattn/go-isatty v0.0.20 // indirect + github.com/mattn/go-runewidth v0.0.16 // indirect + github.com/pelletier/go-toml/v2 v2.2.4 // indirect + github.com/rivo/uniseg v0.2.0 // indirect + github.com/sagikazarmark/locafero v0.11.0 // indirect + github.com/sourcegraph/conc v0.3.1-0.20240121214520-5f936abd7ae8 // indirect + github.com/spf13/afero v1.15.0 // indirect + github.com/spf13/cast v1.10.0 // indirect + github.com/spf13/pflag v1.0.10 // indirect + github.com/subosito/gotenv v1.6.0 // indirect + github.com/valyala/bytebufferpool v1.0.0 // indirect + github.com/valyala/fasthttp v1.51.0 // indirect + github.com/valyala/tcplisten v1.0.0 // indirect + github.com/ysmood/fetchup v0.2.3 // indirect + github.com/ysmood/goob v0.4.0 // indirect + github.com/ysmood/got v0.40.0 // indirect + github.com/ysmood/gson v0.7.3 // indirect + github.com/ysmood/leakless v0.9.0 // indirect + go.uber.org/atomic v1.11.0 // indirect + go.yaml.in/yaml/v3 v3.0.4 // indirect + golang.org/x/crypto v0.49.0 // indirect + golang.org/x/net v0.51.0 // indirect + golang.org/x/sys v0.42.0 // indirect + golang.org/x/text v0.35.0 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..8a5547e --- /dev/null +++ b/go.sum @@ -0,0 +1,222 @@ +filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA= +filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4= +gitea.com/xorm/sqlfiddle v0.0.0-20180821085327-62ce714f951a h1:lSA0F4e9A2NcQSqGqTOXqu2aRi/XEQxDCBwM8yJtE6s= +gitea.com/xorm/sqlfiddle v0.0.0-20180821085327-62ce714f951a/go.mod h1:EXuID2Zs0pAQhH8yz+DNjUbjppKQzKFAn28TMYPB6IU= +github.com/JohannesKaufmann/html-to-markdown v1.6.0 h1:04VXMiE50YYfCfLboJCLcgqF5x+rHJnb1ssNmqpLH/k= +github.com/JohannesKaufmann/html-to-markdown v1.6.0/go.mod h1:NUI78lGg/a7vpEJTz/0uOcYMaibytE4BUOQS8k78yPQ= +github.com/PuerkitoBio/goquery v1.9.2 h1:4/wZksC3KgkQw7SQgkKotmKljk0M6V8TUvA8Wb4yPeE= +github.com/PuerkitoBio/goquery v1.9.2/go.mod h1:GHPCaP0ODyyxqcNoFGYlAprUFH81NuRPd0GX3Zu2Mvk= +github.com/andybalholm/brotli v1.1.0 h1:eLKJA0d02Lf0mVpIDgYnqXcUn0GqVmEFny3VuID1U3M= +github.com/andybalholm/brotli v1.1.0/go.mod h1:sms7XGricyQI9K10gOSf56VKKWS4oLer58Q+mhRPtnY= +github.com/andybalholm/cascadia v1.3.2 h1:3Xi6Dw5lHF15JtdcmAHD3i1+T8plmv7BQ/nsViSLyss= +github.com/andybalholm/cascadia v1.3.2/go.mod h1:7gtRlve5FxPPgIgX36uWBX58OdBsSS6lUvCFb+h7KvU= +github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs= +github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c= +github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA= +github.com/bsm/gomega v1.27.10/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0H+O0= +github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= +github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= +github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= +github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8= +github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0= +github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k= +github.com/fsnotify/fsnotify v1.9.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0= +github.com/gabriel-vasile/mimetype v1.4.13 h1:46nXokslUBsAJE/wMsp5gtO500a4F3Nkz9Ufpk2AcUM= +github.com/gabriel-vasile/mimetype v1.4.13/go.mod h1:d+9Oxyo1wTzWdyVUPMmXFvp4F9tea18J8ufA774AB3s= +github.com/go-kratos/kratos/v2 v2.9.2 h1:px8GJQBeLpquDKQWQ9zohEWiLA8n4D/pv7aH3asvUvo= +github.com/go-kratos/kratos/v2 v2.9.2/go.mod h1:Jc7jaeYd4RAPjetun2C+oFAOO7HNMHTT/Z4LxpuEDJM= +github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s= +github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4= +github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA= +github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY= +github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY= +github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY= +github.com/go-playground/validator/v10 v10.30.2 h1:JiFIMtSSHb2/XBUbWM4i/MpeQm9ZK2xqPNk8vgvu5JQ= +github.com/go-playground/validator/v10 v10.30.2/go.mod h1:mAf2pIOVXjTEBrwUMGKkCWKKPs9NheYGabeB04txQSc= +github.com/go-rod/rod v0.116.2 h1:A5t2Ky2A+5eD/ZJQr1EfsQSe5rms5Xof/qj296e+ZqA= +github.com/go-rod/rod v0.116.2/go.mod h1:H+CMO9SCNc2TJ2WfrG+pKhITz57uGNYU43qYHh438Mg= +github.com/go-sql-driver/mysql v1.8.1 h1:LedoTUt/eveggdHS9qUFC1EFSa8bU2+1pZjSRpvNJ1Y= +github.com/go-sql-driver/mysql v1.8.1/go.mod h1:wEBSXgmK//2ZFJyE+qWnIsVGmvmEKlqwuVSjsCm7DZg= +github.com/go-viper/mapstructure/v2 v2.5.0 h1:vM5IJoUAy3d7zRSVtIwQgBj7BiWtMPfmPEgAXnvj1Ro= +github.com/go-viper/mapstructure/v2 v2.5.0/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM= +github.com/gofiber/fiber/v2 v2.52.12 h1:0LdToKclcPOj8PktUdIKo9BUohjjwfnQl42Dhw8/WUw= +github.com/gofiber/fiber/v2 v2.52.12/go.mod h1:YEcBbO/FB+5M1IZNBP9FO3J9281zgPAreiI1oqg8nDw= +github.com/golang-jwt/jwt/v5 v5.3.1 h1:kYf81DTWFe7t+1VvL7eS+jKFVWaUnK9cB1qbwn63YCY= +github.com/golang-jwt/jwt/v5 v5.3.1/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE= +github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= +github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/google/wire v0.7.0 h1:JxUKI6+CVBgCO2WToKy/nQk0sS+amI9z9EjVmdaocj4= +github.com/google/wire v0.7.0/go.mod h1:n6YbUQD9cPKTnHXEBN2DXlOp/mVADhVErcMFb0v3J18= +github.com/grokify/html-strip-tags-go v0.1.0 h1:03UrQLjAny8xci+R+qjCce/MYnpNXCtgzltlQbOBae4= +github.com/grokify/html-strip-tags-go v0.1.0/go.mod h1:ZdzgfHEzAfz9X6Xe5eBLVblWIxXfYSQ40S/VKrAOGpc= +github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= +github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= +github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ= +github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= +github.com/klauspost/compress v1.17.9 h1:6KIumPrER1LHsvBVuDa0r5xaG0Es51mhhB9BQB2qeMA= +github.com/klauspost/compress v1.17.9/go.mod h1:Di0epgTjJY877eYKx5yC51cX2A2Vl2ibi7bDH9ttBbw= +github.com/klauspost/cpuid/v2 v2.0.9 h1:lgaqFMSdTdQYdZ04uHyN2d/eKdOMyi2YLSvlQIBFYa4= +github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= +github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= +github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= +github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ= +github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI= +github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= +github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= +github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= +github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= +github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/mattn/go-runewidth v0.0.16 h1:E5ScNMtiwvlvB5paMFdw9p4kSQzbXFikJ5SQO6TULQc= +github.com/mattn/go-runewidth v0.0.16/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w= +github.com/pelletier/go-toml/v2 v2.2.4 h1:mye9XuhQ6gvn5h28+VilKrrPoQVanw5PMw/TB0t5Ec4= +github.com/pelletier/go-toml/v2 v2.2.4/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY= +github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/redis/go-redis/v9 v9.18.0 h1:pMkxYPkEbMPwRdenAzUNyFNrDgHx9U+DrBabWNfSRQs= +github.com/redis/go-redis/v9 v9.18.0/go.mod h1:k3ufPphLU5YXwNTUcCRXGxUoF1fqxnhFQmscfkCoDA0= +github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY= +github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= +github.com/rogpeppe/go-internal v1.11.0 h1:cWPaGQEPrBb5/AsnsZesgZZ9yb1OQ+GOISoDNXVBh4M= +github.com/rogpeppe/go-internal v1.11.0/go.mod h1:ddIwULY96R17DhadqLgMfk9H9tvdUzkipdSkR5nkCZA= +github.com/sagikazarmark/locafero v0.11.0 h1:1iurJgmM9G3PA/I+wWYIOw/5SyBtxapeHDcg+AAIFXc= +github.com/sagikazarmark/locafero v0.11.0/go.mod h1:nVIGvgyzw595SUSUE6tvCp3YYTeHs15MvlmU87WwIik= +github.com/sebdah/goldie/v2 v2.5.3 h1:9ES/mNN+HNUbNWpVAlrzuZ7jE+Nrczbj8uFRjM7624Y= +github.com/sebdah/goldie/v2 v2.5.3/go.mod h1:oZ9fp0+se1eapSRjfYbsV/0Hqhbuu3bJVvKI/NNtssI= +github.com/sergi/go-diff v1.0.0/go.mod h1:0CfEIISq7TuYL3j771MWULgwwjU+GofnZX9QAmXWZgo= +github.com/sergi/go-diff v1.3.1 h1:xkr+Oxo4BOQKmkn/B9eMK0g5Kg/983T9DqqPHwYqD+8= +github.com/sergi/go-diff v1.3.1/go.mod h1:aMJSSKb2lpPvRNec0+w3fl7LP9IOFzdc9Pa4NFbPK1I= +github.com/sourcegraph/conc v0.3.1-0.20240121214520-5f936abd7ae8 h1:+jumHNA0Wrelhe64i8F6HNlS8pkoyMv5sreGx2Ry5Rw= +github.com/sourcegraph/conc v0.3.1-0.20240121214520-5f936abd7ae8/go.mod h1:3n1Cwaq1E1/1lhQhtRK2ts/ZwZEhjcQeJQ1RuC6Q/8U= +github.com/spf13/afero v1.15.0 h1:b/YBCLWAJdFWJTN9cLhiXXcD7mzKn9Dm86dNnfyQw1I= +github.com/spf13/afero v1.15.0/go.mod h1:NC2ByUVxtQs4b3sIUphxK0NioZnmxgyCrfzeuq8lxMg= +github.com/spf13/cast v1.10.0 h1:h2x0u2shc1QuLHfxi+cTJvs30+ZAHOGRic8uyGTDWxY= +github.com/spf13/cast v1.10.0/go.mod h1:jNfB8QC9IA6ZuY2ZjDp0KtFO2LZZlg4S/7bzP6qqeHo= +github.com/spf13/pflag v1.0.10 h1:4EBh2KAYBwaONj6b2Ye1GiHfwjqyROoF4RwYO+vPwFk= +github.com/spf13/pflag v1.0.10/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= +github.com/spf13/viper v1.21.0 h1:x5S+0EU27Lbphp4UKm1C+1oQO+rKx36vfCoaVebLFSU= +github.com/spf13/viper v1.21.0/go.mod h1:P0lhsswPGWD/1lZJ9ny3fYnVqxiegrlNrEmgLjbTCAY= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= +github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8= +github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU= +github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= +github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= +github.com/valyala/fasthttp v1.51.0 h1:8b30A5JlZ6C7AS81RsWjYMQmrZG6feChmgAolCl1SqA= +github.com/valyala/fasthttp v1.51.0/go.mod h1:oI2XroL+lI7vdXyYoQk03bXBThfFl2cVdIA3Xl7cH8g= +github.com/valyala/tcplisten v1.0.0 h1:rBHj/Xf+E1tRGZyWIWwJDiRY0zc1Js+CV5DqwacVSA8= +github.com/valyala/tcplisten v1.0.0/go.mod h1:T0xQ8SeCZGxckz9qRXTfG43PvQ/mcWh7FwZEA7Ioqkc= +github.com/ysmood/fetchup v0.2.3 h1:ulX+SonA0Vma5zUFXtv52Kzip/xe7aj4vqT5AJwQ+ZQ= +github.com/ysmood/fetchup v0.2.3/go.mod h1:xhibcRKziSvol0H1/pj33dnKrYyI2ebIvz5cOOkYGns= +github.com/ysmood/goob v0.4.0 h1:HsxXhyLBeGzWXnqVKtmT9qM7EuVs/XOgkX7T6r1o1AQ= +github.com/ysmood/goob v0.4.0/go.mod h1:u6yx7ZhS4Exf2MwciFr6nIM8knHQIE22lFpWHnfql18= +github.com/ysmood/gop v0.2.0 h1:+tFrG0TWPxT6p9ZaZs+VY+opCvHU8/3Fk6BaNv6kqKg= +github.com/ysmood/gop v0.2.0/go.mod h1:rr5z2z27oGEbyB787hpEcx4ab8cCiPnKxn0SUHt6xzk= +github.com/ysmood/got v0.40.0 h1:ZQk1B55zIvS7zflRrkGfPDrPG3d7+JOza1ZkNxcc74Q= +github.com/ysmood/got v0.40.0/go.mod h1:W7DdpuX6skL3NszLmAsC5hT7JAhuLZhByVzHTq874Qg= +github.com/ysmood/gotrace v0.6.0 h1:SyI1d4jclswLhg7SWTL6os3L1WOKeNn/ZtzVQF8QmdY= +github.com/ysmood/gotrace v0.6.0/go.mod h1:TzhIG7nHDry5//eYZDYcTzuJLYQIkykJzCRIo4/dzQM= +github.com/ysmood/gson v0.7.3 h1:QFkWbTH8MxyUTKPkVWAENJhxqdBa4lYTQWqZCiLG6kE= +github.com/ysmood/gson v0.7.3/go.mod h1:3Kzs5zDl21g5F/BlLTNcuAGAYLKt2lV5G8D1zF3RNmg= +github.com/ysmood/leakless v0.9.0 h1:qxCG5VirSBvmi3uynXFkcnLMzkphdh3xx5FtrORwDCU= +github.com/ysmood/leakless v0.9.0/go.mod h1:R8iAXPRaG97QJwqxs74RdwzcRHT1SWCGTNqY8q0JvMQ= +github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= +github.com/yuin/goldmark v1.7.1 h1:3bajkSilaCbjdKVsKdZjZCLBNPL9pYzrCakKaf4U49U= +github.com/yuin/goldmark v1.7.1/go.mod h1:uzxRWxtg69N339t3louHJ7+O03ezfj6PlliRlaOzY1E= +github.com/zeebo/xxh3 v1.0.2 h1:xZmwmqxHZA8AI603jOQ0tMqmBr9lPeFwGg6d+xy9DC0= +github.com/zeebo/xxh3 v1.0.2/go.mod h1:5NWz9Sef7zIDm2JHfFlcQvNekmcEl9ekUZQQKCYaDcA= +go.uber.org/atomic v1.11.0 h1:ZvwS0R+56ePWxUNi+Atn9dWONBPp/AUETXlHW0DxSjE= +go.uber.org/atomic v1.11.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0= +go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc= +go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= +golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU= +golang.org/x/crypto v0.22.0/go.mod h1:vr6Su+7cTlO45qkww3VDJlzDn0ctJvRgYbC2NvXHt+M= +golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8= +golang.org/x/crypto v0.49.0 h1:+Ng2ULVvLHnJ/ZFEq4KdcDd/cfjrrjjNSXNzxg0Y4U4= +golang.org/x/crypto v0.49.0/go.mod h1:ErX4dUh2UM+CFYiXZRTcMpEcN8b/1gxEuv3nODoYtCA= +golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= +golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= +golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= +golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= +golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= +golang.org/x/net v0.9.0/go.mod h1:d48xBJpPfHeWQsugry2m+kC02ZBRGRgulfHnEXEuWns= +golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= +golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44= +golang.org/x/net v0.24.0/go.mod h1:2Q7sJY5mzlzWjKtYUEXSlBWCdyaioyXzRB2RtU8KVE8= +golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM= +golang.org/x/net v0.51.0 h1:94R/GTO7mt3/4wIKpcR5gkGmRLOuE/2hNGeWq/GBIFo= +golang.org/x/net v0.51.0/go.mod h1:aamm+2QF5ogm02fjy5Bb7CQ0WMt1/WVM7FtyaTLlA9Y= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4= +golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.7.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.19.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.42.0 h1:omrd2nAlyT5ESRdCLYdm3+fMfNFE/+Rf4bDIQImRJeo= +golang.org/x/sys v0.42.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw= +golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= +golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= +golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= +golang.org/x/term v0.7.0/go.mod h1:P32HKFT3hSsZrRxla30E9HqToFYAQPCMs/zFMBUFqPY= +golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo= +golang.org/x/term v0.17.0/go.mod h1:lLRBjIVuehSbZlaOtGMbcMncT+aqLLLmKrsjNrUguwk= +golang.org/x/term v0.19.0/go.mod h1:2CuTdWZ7KHSQwUzKva0cbMg6q2DMI3Mmxp+gKJbskEk= +golang.org/x/term v0.20.0/go.mod h1:8UkIAJTvZgivsXaD6/pH6U9ecQzZ45awqEOzuCvwpFY= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= +golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= +golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= +golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= +golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= +golang.org/x/text v0.35.0 h1:JOVx6vVDFokkpaq1AEptVzLTpDe9KGpj5tR4/X+ybL8= +golang.org/x/text v0.35.0/go.mod h1:khi/HExzZJ2pGnjenulevKNX1W67CUy0AsXcNubPGCA= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= +golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= +golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= +gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= +gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gorm.io/driver/mysql v1.6.0 h1:eNbLmNTpPpTOVZi8MMxCi2aaIm0ZpInbORNXDwyLGvg= +gorm.io/driver/mysql v1.6.0/go.mod h1:D/oCC2GWK3M/dqoLxnOlaNKmXz8WNTfcS9y5ovaSqKo= +gorm.io/gorm v1.31.1 h1:7CA8FTFz/gRfgqgpeKIBcervUn3xSyPUmr6B2WXJ7kg= +gorm.io/gorm v1.31.1/go.mod h1:XyQVbO2k6YkOis7C2437jSit3SsDK72s7n7rsSHd+Gs= +xorm.io/builder v0.3.13 h1:a3jmiVVL19psGeXx8GIurTp7p0IIgqeDmwhcR6BAOAo= +xorm.io/builder v0.3.13/go.mod h1:aUW0S9eb9VCaPohFCH3j7czOx1PMW3i1HrSzbLYGBSE= diff --git a/internal/biz/provider_set.go b/internal/biz/provider_set.go new file mode 100644 index 0000000..144f86e --- /dev/null +++ b/internal/biz/provider_set.go @@ -0,0 +1,9 @@ +package biz + +import ( + "github.com/google/wire" +) + +var ProviderSetBiz = wire.NewSet( + NewPublishBiz, +) diff --git a/internal/biz/public.go b/internal/biz/public.go new file mode 100644 index 0000000..410eca6 --- /dev/null +++ b/internal/biz/public.go @@ -0,0 +1,102 @@ +package biz + +import ( + "context" + "time" + + "geo/internal/config" + "geo/internal/data/impl" + "geo/internal/data/model" + "geo/tmpl/errcode" + + "xorm.io/builder" +) + +type PublishBiz struct { + cfg *config.Config + publishImpl *impl.PublishImpl + userImpl *impl.UserImpl + platImpl *impl.PlatImpl + tokenImpl *impl.TokenImpl + loginRelationImpl *impl.LoginRelationImpl +} + +func NewPublishBiz( + cfg *config.Config, + publishImpl *impl.PublishImpl, + userImpl *impl.UserImpl, + platImpl *impl.PlatImpl, + tokenImpl *impl.TokenImpl, + loginRelationImpl *impl.LoginRelationImpl, + +) *PublishBiz { + return &PublishBiz{ + cfg: cfg, + publishImpl: publishImpl, + loginRelationImpl: loginRelationImpl, + userImpl: userImpl, + platImpl: platImpl, + tokenImpl: tokenImpl, + } +} + +func (b *PublishBiz) ValidateAccessToken(ctx context.Context, accessToken string) (*model.Token, error) { + cond := builder.NewCond(). + And(builder.Eq{"access_token": accessToken}). + And(builder.Eq{"status": 1}) + tokenInfo := &model.Token{} + err := b.tokenImpl.GetOneBySearchStruct(ctx, &cond, tokenInfo) + if err != nil || tokenInfo == nil { + return nil, errcode.Forbidden("密钥无效或已禁用") + } + return tokenInfo, nil +} + +func (b *PublishBiz) BatchInsertPublish(ctx context.Context, list []*model.Publish) error { + + return b.publishImpl.Add(ctx, list) +} + +func (b *PublishBiz) GetPendingPublish(ctx context.Context, tokenID int) (map[string]interface{}, error) { + + currentTime := time.Now() + cond := builder.NewCond(). + And(builder.Eq{"p.token_id": tokenID}). + And(builder.Eq{"p.status": 1}). + And(builder.Lte{"p.publish_time": currentTime}) + + return b.publishImpl.GetOneWithPlat(ctx, &cond) +} + +func (b *PublishBiz) GetTaskByRequestID(ctx context.Context, requestID string) (map[string]interface{}, error) { + cond := builder.NewCond(). + And(builder.Eq{"p.request_id": requestID}) + return b.publishImpl.GetOneWithPlat(ctx, &cond) +} + +func (b *PublishBiz) UpdatePublishStatus(ctx context.Context, requestID string, status int, msg string) error { + return b.publishImpl.UpdateStatus(ctx, requestID, status, msg) +} + +func (b *PublishBiz) GetPublishList(ctx context.Context, tokenID int32, page, pageSize int, filters map[string]interface{}) ([]map[string]interface{}, int64, error) { + return b.publishImpl.GetListWithUser(ctx, tokenID, page, pageSize, filters) +} + +func (b *PublishBiz) GetPlatInfo(ctx context.Context, platIndex string) (*model.Plat, error) { + cond := builder.NewCond(). + And(builder.Eq{"`index`": platIndex}). + And(builder.Eq{"status": 1}) + plat := &model.Plat{} + err := b.platImpl.GetOneBySearchStruct(ctx, &cond, plat) + if err != nil { + return nil, err + } + return plat, nil +} + +func (b *PublishBiz) UpdateLoginStatus(ctx context.Context, userIndex, platIndex string, loginStatus int32) error { + cond := builder.NewCond(). + And(builder.Eq{"user_index": userIndex}). + And(builder.Eq{"plat_index": platIndex}) + return b.loginRelationImpl.UpdateByCond(ctx, &cond, &model.User{Status: loginStatus}) +} diff --git a/internal/config/config.go b/internal/config/config.go new file mode 100644 index 0000000..cebafd6 --- /dev/null +++ b/internal/config/config.go @@ -0,0 +1,72 @@ +package config + +import ( + "os" + "path/filepath" +) + +// Config 应用配置 +type Config struct { + Server ServerConfig `mapstructure:"server"` + DB DB `mapstructure:"db"` + Sys Sys `mapstructure:"sys"` +} + +// ServerConfig 服务器配置 +type ServerConfig struct { + Port int `mapstructure:"port"` + Host string `mapstructure:"host"` +} + +type DB struct { + Driver string `mapstructure:"driver"` + Source string `mapstructure:"source"` + MaxIdle int32 `mapstructure:"maxIdle"` + MaxOpen int32 `mapstructure:"maxOpen"` + MaxLifetime int32 `mapstructure:"maxLifetime"` + IsDebug bool `mapstructure:"isDebug"` +} + +type Sys struct { + MaxConcurrent int `mapstructure:"maxConcurrent"` + TaskTimeout int `mapstructure:"taskTimeout"` + SessionTimeout int `mapstructure:"sessionTimeout "` + MaxImageSize int `mapstructure:"maxImageSize"` + LogsDir string `mapstructure:"logsDir"` + UploadDir string `mapstructure:"uploadDir"` + VideosDir string `mapstructure:"videosDir"` + DocsDir string `mapstructure:"docsDir"` + CookiesDir string `mapstructure:"cookiesDir"` + QrcodesDir string `mapstructure:"qrcodesDir"` + ChromePath string `mapstructure:"chromePath"` + ChromeDataDir string `mapstructure:"chromeDataDir"` +} + +// LoadConfig 加载配置 +func LoadConfig() (*Config, error) { + BaseDir, _ := os.Getwd() + return &Config{ + Server: ServerConfig{ + Port: 5001, + Host: "0.0.0.0", + }, + DB: DB{ + Driver: "mysql", + Source: "root:lansexiongdi6,@tcp(47.97.27.195:3306)/geo?charset=utf8mb4&parseTime=true&loc=Asia%2FShanghai", + }, + Sys: Sys{ + MaxConcurrent: 1, + TaskTimeout: 60, + SessionTimeout: 300, + MaxImageSize: 5 * 1024 * 1024, + LogsDir: filepath.Join(BaseDir, "logs"), + UploadDir: filepath.Join(BaseDir, "images"), + VideosDir: filepath.Join(BaseDir, "videos"), + DocsDir: filepath.Join(BaseDir, "docs"), + CookiesDir: filepath.Join(BaseDir, "cookies"), + QrcodesDir: filepath.Join(BaseDir, "qrcodes"), + ChromePath: "./chrome/chrome.exe", + ChromeDataDir: "./chrome_data", + }, + }, nil +} diff --git a/internal/data/impl/login_relation.go b/internal/data/impl/login_relation.go new file mode 100644 index 0000000..2a4c301 --- /dev/null +++ b/internal/data/impl/login_relation.go @@ -0,0 +1,27 @@ +package impl + +import ( + "geo/internal/data/model" + "geo/tmpl/dataTemp" + "geo/utils" +) + +type LoginRelationImpl struct { + dataTemp.DataTemp + db *utils.Db +} + +func NewLoginRelationImpl(db *utils.Db) *LoginRelationImpl { + return &LoginRelationImpl{ + DataTemp: *dataTemp.NewDataTemp(db, new(model.LoginRelation)), + db: db, + } +} + +func (m *LoginRelationImpl) PrimaryKey() string { + return "id" +} + +func (m *LoginRelationImpl) GetTemp() *dataTemp.DataTemp { + return &m.DataTemp +} diff --git a/internal/data/impl/plat.go b/internal/data/impl/plat.go new file mode 100644 index 0000000..2989190 --- /dev/null +++ b/internal/data/impl/plat.go @@ -0,0 +1,88 @@ +package impl + +import ( + "context" + "fmt" + "geo/internal/data/model" + "geo/tmpl/dataTemp" + "geo/utils" + + "xorm.io/builder" +) + +type PlatImpl struct { + dataTemp.DataTemp + db *utils.Db +} + +func NewPlatImpl(db *utils.Db) *PlatImpl { + return &PlatImpl{ + DataTemp: *dataTemp.NewDataTemp(db, new(model.Plat)), + db: db, + } +} + +func (p *PlatImpl) PrimaryKey() string { + return "id" +} + +func (p *PlatImpl) GetTemp() *dataTemp.DataTemp { + return &p.DataTemp +} + +// GetPlatListWithLoginStatus 获取平台列表并关联登录状态 +func (p *PlatImpl) GetPlatListWithLoginStatus(ctx context.Context, cond *builder.Cond) ([]map[string]interface{}, error) { + query, err := builder.ToBoundSQL(*cond) + if err != nil { + return nil, err + } + + sql := ` + SELECT + plat.name, + plat.img_url, + plat.index, + plat.desc, + COALESCE(login_relation.login_status, 2) as login_status + FROM plat + LEFT JOIN login_relation ON login_relation.plat_index COLLATE utf8mb4_unicode_ci = plat.index AND login_relation.status = 1 + WHERE %s + ORDER BY plat.id ASC + ` + + finalSQL := fmt.Sprintf(sql, query) + + var results []map[string]interface{} + err = p.db.Client.WithContext(ctx).Raw(finalSQL).Scan(&results).Error + if err != nil { + return nil, err + } + return results, nil +} + +// GetPlatListWithLoginStatusByUserIndex 根据用户索引获取平台列表及登录状态 +func (p *PlatImpl) GetPlatListWithLoginStatusByUserIndex(ctx context.Context, userIndex string) ([]map[string]interface{}, error) { + sql := ` + SELECT + pl.name, + pl.img_url, + pl.index, + pl.desc, + pl.login_url, + pl.edit_url, + pl.logined_url, + COALESCE(lr.login_status, 2) as login_status + FROM plat pl + LEFT JOIN login_relation lr ON lr.plat_index = pl.index AND lr.user_index = ? AND lr.status = 1 + WHERE pl.status = 1 + ORDER BY pl.id ASC + ` + + var results []map[string]interface{} + + err := p.db.Client.WithContext(ctx).Raw(sql, userIndex).Scan(&results).Error + if err != nil { + return nil, err + } + return results, nil +} diff --git a/internal/data/impl/provider_set.go b/internal/data/impl/provider_set.go new file mode 100644 index 0000000..bf7aaa0 --- /dev/null +++ b/internal/data/impl/provider_set.go @@ -0,0 +1,13 @@ +package impl + +import ( + "github.com/google/wire" +) + +var ProviderImpl = wire.NewSet( + NewPlatImpl, + NewLoginRelationImpl, + NewUserImpl, + NewTokenImpl, + NewPublishImpl, +) diff --git a/internal/data/impl/publish.go b/internal/data/impl/publish.go new file mode 100644 index 0000000..2888659 --- /dev/null +++ b/internal/data/impl/publish.go @@ -0,0 +1,244 @@ +package impl + +import ( + "context" + "fmt" + "geo/internal/data/model" + "geo/tmpl/dataTemp" + "geo/utils" + "time" + + "xorm.io/builder" +) + +type PublishImpl struct { + dataTemp.DataTemp + db *utils.Db +} + +func NewPublishImpl(db *utils.Db) *PublishImpl { + return &PublishImpl{ + DataTemp: *dataTemp.NewDataTemp(db, new(model.Publish)), + db: db, + } +} + +func (m *PublishImpl) PrimaryKey() string { + return "id" +} + +func (m *PublishImpl) GetTemp() *dataTemp.DataTemp { + return &m.DataTemp +} + +// BatchInsert 批量插入发布记录 +func (p *PublishImpl) BatchInsert(ctx context.Context, tokenID int, records []map[string]interface{}) (int64, error) { + if len(records) == 0 { + return 0, nil + } + + sql := `INSERT INTO publish + (token_id, user_index, request_id, title, tag, type, plat_index, url, publish_time, status, create_time, img) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)` + + var total int64 + for _, record := range records { + result := p.db.Client.WithContext(ctx).Exec(sql, + tokenID, + record["user_index"], + record["request_id"], + record["title"], + record["tag"], + record["type"], + record["plat_index"], + record["url"], + record["publish_time"], + 1, // status = 1 待发布 + time.Now(), + record["img"], + ) + if result.Error != nil { + return total, result.Error + } + total += result.RowsAffected + } + return total, nil +} + +// GetOneWithPlat 查询单条发布记录并关联平台信息 +func (p *PublishImpl) GetOneWithPlat(ctx context.Context, cond *builder.Cond) (map[string]interface{}, error) { + query, err := builder.ToBoundSQL(*cond) + if err != nil { + return nil, err + } + // 构建带平台信息的查询 + sql := fmt.Sprintf(` + SELECT + p.id, + p.token_id, + p.user_index, + p.request_id, + p.title, + p.tag, + p.type, + p.plat_index, + p.url, + p.publish_time, + p.status, + p.create_time, + p.img, + p.msg, + pl.index as plat_index_db, + pl.name as plat_name, + pl.login_url, + pl.edit_url, + pl.logined_url, + pl.desc as plat_desc, + pl.img_url as plat_img_url + FROM publish p + INNER JOIN plat pl ON p.plat_index = pl.index AND pl.status = 1 + WHERE %s + LIMIT 1 + `, query) + + var result map[string]interface{} + err = p.db.Client.WithContext(ctx).Raw(sql).Scan(&result).Error + if err != nil { + return nil, err + } + if result == nil || len(result) == 0 { + return nil, nil + } + return result, nil +} + +// GetOneWithPlatByRequestID 根据request_id查询发布记录并关联平台信息 +func (p *PublishImpl) GetOneWithPlatByRequestID(ctx context.Context, requestID string) (map[string]interface{}, error) { + sql := ` + SELECT + p.id, + p.token_id, + p.user_index, + p.request_id, + p.title, + p.tag, + p.type, + p.plat_index, + p.url, + p.publish_time, + p.status, + p.create_time, + p.img, + p.msg, + pl.index as plat_index_db, + pl.name as plat_name, + pl.login_url, + pl.edit_url, + pl.logined_url, + pl.desc as plat_desc, + pl.img_url as plat_img_url + FROM publish p + INNER JOIN plat pl ON p.plat_index = pl.index AND pl.status = 1 + WHERE p.request_id = ? + LIMIT 1 + ` + + var result map[string]interface{} + err := p.db.Client.WithContext(ctx).Raw(sql, requestID).Scan(&result).Error + if err != nil { + return nil, err + } + if result == nil || len(result) == 0 { + return nil, nil + } + return result, nil +} + +// UpdateStatus 更新发布状态 +func (p *PublishImpl) UpdateStatus(ctx context.Context, requestID string, status int, msg string) error { + updateData := map[string]interface{}{ + "status": status, + } + if msg != "" { + updateData["msg"] = msg + } + return p.db.Client.WithContext(ctx). + Model(&model.Publish{}). + Where("request_id = ?", requestID). + Updates(updateData).Error +} + +// GetListWithUser 获取发布列表并关联用户信息 +func (p *PublishImpl) GetListWithUser(ctx context.Context, tokenID int32, page, pageSize int, filters map[string]interface{}) ([]map[string]interface{}, int64, error) { + // 构建基础查询 + query := p.db.Client.WithContext(ctx). + Table("publish p"). + Select(` + p.id, + p.user_index, + u.name as user_name, + p.request_id, + p.title, + p.tag, + p.type, + CASE + WHEN p.type = 1 THEN '文章' + WHEN p.type = 2 THEN '视频' + ELSE '未知' + END as type_name, + p.plat_index, + pl.name as plat_name, + p.url, + p.publish_time, + p.status, + CASE + WHEN p.status = 1 THEN '待发布' + WHEN p.status = 2 THEN '发布中' + WHEN p.status = 3 THEN '发布失败' + WHEN p.status = 4 THEN '发布成功' + ELSE '未知' + END as status_name, + p.create_time, + p.msg + `). + Joins("LEFT JOIN user u ON p.user_index = u.user_index"). + Joins("LEFT JOIN plat pl ON p.plat_index = pl.index"). + Where("u.token_id = ?", tokenID) + + // 添加过滤条件 + if userIndex, ok := filters["user_index"]; ok && userIndex != "" { + query = query.Where("p.user_index = ?", userIndex) + } + if tag, ok := filters["tag"]; ok && tag != "" { + query = query.Where("p.tag LIKE ?", "%"+tag.(string)+"%") + } + if typeFilter, ok := filters["type"]; ok && typeFilter != 0 { + query = query.Where("p.type = ?", typeFilter) + } + if platIndex, ok := filters["plat_index"]; ok && platIndex != "" { + query = query.Where("p.plat_index = ?", platIndex) + } + if status, ok := filters["status"]; ok && status != 0 { + query = query.Where("p.status = ?", status) + } + if requestID, ok := filters["request_id"]; ok && requestID != "" { + query = query.Where("p.request_id LIKE ?", "%"+requestID.(string)+"%") + } + + // 获取总数 + var total int64 + if err := query.Count(&total).Error; err != nil { + return nil, 0, err + } + + // 分页查询 + var results []map[string]interface{} + offset := (page - 1) * pageSize + err := query. + Order("p.publish_time DESC, p.create_time DESC"). + Limit(pageSize). + Offset(offset). + Scan(&results).Error + + return results, total, err +} diff --git a/internal/data/impl/token.go b/internal/data/impl/token.go new file mode 100644 index 0000000..c66d096 --- /dev/null +++ b/internal/data/impl/token.go @@ -0,0 +1,31 @@ +package impl + +import ( + "geo/internal/data/model" + "geo/tmpl/dataTemp" + "geo/utils" +) + +type TokenImpl struct { + dataTemp.DataTemp + db *utils.Db +} + +func NewTokenImpl(db *utils.Db) *TokenImpl { + return &TokenImpl{ + DataTemp: *dataTemp.NewDataTemp(db, new(model.Token)), + db: db, + } +} + +func (m *TokenImpl) PrimaryKey() string { + return "id" +} + +func (m *TokenImpl) GetTemp() *dataTemp.DataTemp { + return &m.DataTemp +} + +func (m *TokenImpl) GetDb() *utils.Db { + return m.db +} diff --git a/internal/data/impl/user.go b/internal/data/impl/user.go new file mode 100644 index 0000000..b364540 --- /dev/null +++ b/internal/data/impl/user.go @@ -0,0 +1,27 @@ +package impl + +import ( + "geo/internal/data/model" + "geo/tmpl/dataTemp" + "geo/utils" +) + +type UserImpl struct { + dataTemp.DataTemp + db *utils.Db +} + +func NewUserImpl(db *utils.Db) *UserImpl { + return &UserImpl{ + DataTemp: *dataTemp.NewDataTemp(db, new(model.User)), + db: db, + } +} + +func (m *UserImpl) PrimaryKey() string { + return "id" +} + +func (m *UserImpl) GetTemp() *dataTemp.DataTemp { + return &m.DataTemp +} diff --git a/internal/data/model/login_relation.gen.go b/internal/data/model/login_relation.gen.go new file mode 100644 index 0000000..03b62ad --- /dev/null +++ b/internal/data/model/login_relation.gen.go @@ -0,0 +1,21 @@ +// Code generated by gorm.io/gen. DO NOT EDIT. +// Code generated by gorm.io/gen. DO NOT EDIT. +// Code generated by gorm.io/gen. DO NOT EDIT. + +package model + +const TableNameLoginRelation = "login_relation" + +// LoginRelation mapped from table +type LoginRelation struct { + ID int32 `gorm:"column:id;primaryKey;autoIncrement:true" json:"id"` + UserIndex string `gorm:"column:user_index;not null;default:0" json:"user_index"` + PlatIndex string `gorm:"column:plat_index;not null;default:0" json:"plat_index"` + LoginStatus int32 `gorm:"column:login_status;not null;default:2" json:"login_status"` + Status int32 `gorm:"column:status;not null;default:1" json:"status"` +} + +// TableName LoginRelation's table name +func (*LoginRelation) TableName() string { + return TableNameLoginRelation +} diff --git a/internal/data/model/plat.gen.go b/internal/data/model/plat.gen.go new file mode 100644 index 0000000..e129363 --- /dev/null +++ b/internal/data/model/plat.gen.go @@ -0,0 +1,25 @@ +// Code generated by gorm.io/gen. DO NOT EDIT. +// Code generated by gorm.io/gen. DO NOT EDIT. +// Code generated by gorm.io/gen. DO NOT EDIT. + +package model + +const TableNamePlat = "plat" + +// Plat mapped from table +type Plat struct { + ID int32 `gorm:"column:id;primaryKey" json:"id"` + Name string `gorm:"column:name;not null" json:"name"` + Index string `gorm:"column:index;not null" json:"index"` + ImgURL string `gorm:"column:img_url;not null" json:"img_url"` + LoginURL string `gorm:"column:login_url;not null" json:"login_url"` + EditURL string `gorm:"column:edit_url;not null" json:"edit_url"` + LoginedURL string `gorm:"column:logined_url;not null" json:"logined_url"` + Desc string `gorm:"column:desc;not null" json:"desc"` + Status bool `gorm:"column:status;not null;default:1" json:"status"` +} + +// TableName Plat's table name +func (*Plat) TableName() string { + return TableNamePlat +} diff --git a/internal/data/model/publish.gen.go b/internal/data/model/publish.gen.go new file mode 100644 index 0000000..6318312 --- /dev/null +++ b/internal/data/model/publish.gen.go @@ -0,0 +1,34 @@ +// Code generated by gorm.io/gen. DO NOT EDIT. +// Code generated by gorm.io/gen. DO NOT EDIT. +// Code generated by gorm.io/gen. DO NOT EDIT. + +package model + +import ( + "time" +) + +const TableNamePublish = "publish" + +// Publish mapped from table +type Publish struct { + ID int64 `gorm:"column:id;primaryKey;autoIncrement:true" json:"id"` + TokenID int32 `gorm:"column:token_id;not null" json:"token_id"` + UserIndex string `gorm:"column:user_index;not null;comment:关联user.user_index" json:"user_index"` // 关联user.user_index + RequestID string `gorm:"column:request_id;not null;comment:日志id" json:"request_id"` // 日志id + Title string `gorm:"column:title;not null" json:"title"` + Tag string `gorm:"column:tag;not null;comment:标签,多个英文都好隔开’,‘" json:"tag"` // 标签,多个英文都好隔开’,‘ + Type int32 `gorm:"column:type;not null;default:1;comment:1:文章,2:视频" json:"type"` // 1:文章,2:视频 + PlatIndex string `gorm:"column:plat_index;not null;comment:关联plat.index" json:"plat_index"` // 关联plat.index + URL string `gorm:"column:url;not null;comment:资源文件下载地址" json:"url"` // 资源文件下载地址 + PublishTime time.Time `gorm:"column:publish_time;not null;comment:发布时间" json:"publish_time"` // 发布时间 + Img string `gorm:"column:img;not null" json:"img"` + Status int32 `gorm:"column:status;not null;default:1;comment:1:待发布,2:发布中,3:发布失败,4:发布成功" json:"status"` // 1:待发布,2:发布中,3:发布失败,4:发布成功 + CreateTime time.Time `gorm:"column:create_time;not null;default:CURRENT_TIMESTAMP;comment:创建时间" json:"create_time"` // 创建时间 + Msg string `gorm:"column:msg" json:"msg"` +} + +// TableName Publish's table name +func (*Publish) TableName() string { + return TableNamePublish +} diff --git a/internal/data/model/token.gen.go b/internal/data/model/token.gen.go new file mode 100644 index 0000000..f2f1719 --- /dev/null +++ b/internal/data/model/token.gen.go @@ -0,0 +1,27 @@ +// Code generated by gorm.io/gen. DO NOT EDIT. +// Code generated by gorm.io/gen. DO NOT EDIT. +// Code generated by gorm.io/gen. DO NOT EDIT. + +package model + +import ( + "time" +) + +const TableNameToken = "token" + +// Token mapped from table +type Token struct { + ID int32 `gorm:"column:id;primaryKey" json:"id"` + Name string `gorm:"column:name" json:"name"` + Secret string `gorm:"column:secret;not null" json:"secret"` + AccessToken string `gorm:"column:access_token;not null" json:"access_token"` + UserLimit int32 `gorm:"column:user_limit;primaryKey" json:"user_limit"` + ExpireTime time.Time `gorm:"column:expire_time;not null" json:"expire_time"` + Status int32 `gorm:"column:status;not null;default:1" json:"status"` +} + +// TableName Token's table name +func (*Token) TableName() string { + return TableNameToken +} diff --git a/internal/data/model/user.gen.go b/internal/data/model/user.gen.go new file mode 100644 index 0000000..83e5605 --- /dev/null +++ b/internal/data/model/user.gen.go @@ -0,0 +1,21 @@ +// Code generated by gorm.io/gen. DO NOT EDIT. +// Code generated by gorm.io/gen. DO NOT EDIT. +// Code generated by gorm.io/gen. DO NOT EDIT. + +package model + +const TableNameUser = "user" + +// User mapped from table +type User struct { + ID int64 `gorm:"column:id;primaryKey;autoIncrement:true" json:"id"` + TokenID int32 `gorm:"column:token_id;not null" json:"token_id"` + Name string `gorm:"column:name;not null" json:"name"` + Status int32 `gorm:"column:status;not null;default:1" json:"status"` + UserIndex string `gorm:"column:user_index;not null" json:"user_index"` +} + +// TableName User's table name +func (*User) TableName() string { + return TableNameUser +} diff --git a/internal/entitys/request.go b/internal/entitys/request.go new file mode 100644 index 0000000..1270140 --- /dev/null +++ b/internal/entitys/request.go @@ -0,0 +1,88 @@ +package entitys + +type ( + LoginAppRequest struct { + Secret string `json:"secret" validate:"required" zh:"密钥"` + } + + GetUserAndAutoStatusRequest struct { + AccessToken string `json:"access_token" validate:"required" zh:"access_token"` + } + + AddUserRequest struct { + AccessToken string `json:"access_token" validate:"required" zh:"access_token"` + Name string `json:"name" validate:"required" zh:"用户名"` + } + + DelUserRequest struct { + ID int `json:"id" validate:"required" zh:"用户ID"` + } + + GetAppRequest struct { + AccessToken string `json:"access_token" validate:"required" zh:"access_token"` + UserIndex string `json:"user_index" validate:"required" zh:"用户索引"` + } + + PublishRecordsRequest struct { + AccessToken string `json:"access_token" validate:"required" zh:"access_token"` + Records []PublishRecordItem `json:"records" validate:"required" zh:"发布记录"` + } + + PublishRecordItem struct { + UserIndex string `json:"user_index"` + PlatIndex string `json:"plat_index" validate:"required" zh:"平台索引"` + Title string `json:"title"` + Tag string `json:"tag"` + Type int32 `json:"type" validate:"required" zh:"类型"` + URL string `json:"url" validate:"required" zh:"链接"` + PublishTime string `json:"publish_time" validate:"required" zh:"发布时间"` + Img string `json:"img"` + RequestID string `json:"request_id"` + } + + PublishOnRequest struct { + AccessToken string `json:"access_token" validate:"required" zh:"access_token"` + } + + PublishOffRequest struct { + AccessToken string `json:"access_token" validate:"required" zh:"access_token"` + } + + PublishStatusRequest struct { + AccessToken string `json:"access_token" validate:"required" zh:"access_token"` + RequestID string `json:"request_id"` + } + + PublishExecuteOnceRequest struct { + AccessToken string `json:"access_token" validate:"required" zh:"access_token"` + } + + PublishExecuteRetryRequest struct { + AccessToken string `json:"access_token" validate:"required" zh:"access_token"` + RequestID string `json:"request_id" validate:"required" zh:"请求ID"` + } + + GetPublishListRequest struct { + AccessToken string `json:"access_token" validate:"required" zh:"access_token"` + Page int `json:"page"` + PageSize int `json:"page_size"` + UserIndex string `json:"user_index"` + Tag string `json:"tag"` + Type int `json:"type"` + PlatIndex string `json:"plat_index"` + Status int `json:"status"` + RequestID string `json:"request_id"` + } + + LoginPlatformRequest struct { + AccessToken string `json:"access_token" validate:"required" zh:"access_token"` + UserIndex string `json:"user_index" validate:"required" zh:"用户索引"` + PlatIndex string `json:"plat_index" validate:"required" zh:"平台索引"` + } + + LogoutPlatformRequest struct { + AccessToken string `json:"access_token" validate:"required" zh:"access_token"` + UserIndex string `json:"user_index" validate:"required" zh:"用户索引"` + PlatIndex string `json:"plat_index" validate:"required" zh:"平台索引"` + } +) diff --git a/internal/manager/publish_manager.go b/internal/manager/publish_manager.go new file mode 100644 index 0000000..782ea4b --- /dev/null +++ b/internal/manager/publish_manager.go @@ -0,0 +1,297 @@ +package manager + +import ( + "fmt" + "geo/internal/config" + "geo/internal/publisher" + "geo/pkg" + "log" + "sync" + "time" + + "geo/utils" +) + +type PublishManager struct { + AutoStatus bool + Conf *config.Config + TokenID int + running bool + mu sync.Mutex + stopCh chan struct{} + currentPublisher interface{} + db *utils.Db +} + +var publishManager *PublishManager +var once sync.Once + +func GetPublishManager(config *config.Config, db *utils.Db) *PublishManager { + once.Do(func() { + publishManager = &PublishManager{ + AutoStatus: false, + Conf: config, + stopCh: make(chan struct{}), + db: db, + } + }) + return publishManager +} + +func (pm *PublishManager) Start(tokenID int) bool { + pm.mu.Lock() + defer pm.mu.Unlock() + + if pm.AutoStatus { + return false + } + + pm.TokenID = tokenID + pm.AutoStatus = true + pm.stopCh = make(chan struct{}) + + go pm.autoPublishLoop() + return true +} + +func (pm *PublishManager) Stop() bool { + pm.mu.Lock() + defer pm.mu.Unlock() + + if !pm.AutoStatus { + return false + } + + pm.AutoStatus = false + close(pm.stopCh) + return true +} + +func (pm *PublishManager) autoPublishLoop() { + log.Println("自动发布服务已启动") + + for { + select { + case <-pm.stopCh: + log.Println("自动发布服务已停止") + return + default: + pm.batchPublish() + time.Sleep(30 * time.Second) + } + } +} + +func (pm *PublishManager) batchPublish() { + if !pm.AutoStatus { + return + } + + publishData := pm.getPendingPublish() + if publishData == nil { + return + } + + // 使用 defer recover 防止 panic 导致整个循环崩溃 + defer func() { + if r := recover(); r != nil { + log.Printf("批处理发布发生 panic: %v", r) + } + }() + + pm.processSingleTask(publishData) +} + +func (pm *PublishManager) getPendingPublish() map[string]interface{} { + currentTime := time.Now().Format("2006-01-02 15:04:05") + + sql := ` + SELECT p.*, pl.* + FROM publish p + INNER JOIN plat pl ON p.plat_index = pl.index AND pl.status = 1 + WHERE p.token_id = ? AND p.status = 1 AND p.publish_time <= ? + ORDER BY p.publish_time DESC + LIMIT 1 + ` + + result, err := pm.db.GetOne(sql, pm.TokenID, currentTime) + if err != nil { + log.Printf("查询待发布任务失败: token_id=%d, error=%v", pm.TokenID, err) + return nil + } + if result == nil { + log.Printf("没有待发布任务: token_id=%d, current_time=%s", pm.TokenID, currentTime) + return nil + } + + requestID := getString(result, "request_id") + log.Printf("获取到待发布任务: token_id=%d, request_id=%s", pm.TokenID, requestID) + return result +} + +func (pm *PublishManager) GetTaskByRequestID(requestID string) (map[string]interface{}, error) { + sql := ` + SELECT p.*, pl.* + FROM publish p + INNER JOIN plat pl ON p.plat_index COLLATE utf8mb4_unicode_ci = pl.index AND pl.status = 1 + WHERE p.request_id = ? + ` + return pm.db.GetOne(sql, requestID) +} + +func (pm *PublishManager) processSingleTask(publishData map[string]interface{}) map[string]interface{} { + requestID := getString(publishData, "request_id") + platIndex := getString(publishData, "plat_index") + title := getString(publishData, "title") + tagRaw := getString(publishData, "tag") + userIndex := getString(publishData, "user_index") + url := getString(publishData, "url") + imgURL := getString(publishData, "img") + + log.Printf("[任务 %s] 开始处理,平台:%s,标题:%s", requestID, platIndex, title) + + // 更新状态为发布中 + pm.updatePublishStatus(requestID, 2, "") + log.Printf("[任务 %s] 状态已更新为发布中", requestID) + + // 下载文件 + docPath, err := pkg.DownloadFile(url, "", requestID+".docx") + if err != nil { + errMsg := fmt.Sprintf("下载文档失败: %v", err) + log.Printf("[任务 %s] %s", requestID, errMsg) + pm.updatePublishStatus(requestID, 3, errMsg) + return map[string]interface{}{"success": false, "message": errMsg, "request_id": requestID} + } + log.Printf("[任务 %s] 文档下载成功: %s", requestID, docPath) + + // 下载图片 + imgPath, err := pkg.DownloadImage(imgURL, requestID, "img") + if err != nil { + errMsg := fmt.Sprintf("下载图片失败: %v", err) + log.Printf("[任务 %s] %s", requestID, errMsg) + pm.updatePublishStatus(requestID, 3, errMsg) + // 图片下载失败,清理已下载的文档 + pkg.DeleteFile(docPath) + return map[string]interface{}{"success": false, "message": errMsg, "request_id": requestID} + } + log.Printf("[任务 %s] 图片下载成功: %s", requestID, imgPath) + + // 确保清理临时文件 + defer func() { + pkg.DeleteFile(docPath) + pkg.DeleteFile(imgPath) + }() + + // 解析标签 + tags := pkg.ParseTags(tagRaw) + log.Printf("[任务 %s] 标签解析完成: %v", requestID, tags) + + // 提取内容 + content, err := pkg.ExtractWordContent(docPath, "html") + if err != nil { + errMsg := fmt.Sprintf("提取文档内容失败: %v", err) + log.Printf("[任务 %s] %s", requestID, errMsg) + pm.updatePublishStatus(requestID, 3, errMsg) + return map[string]interface{}{"success": false, "message": errMsg, "request_id": requestID} + } + log.Printf("[任务 %s] 内容提取成功,长度: %d", requestID, len(content)) + + // 获取发布器 + publisherClass := getPublisherClass(platIndex) + if publisherClass == nil { + errMsg := fmt.Sprintf("不支持的平台: %s", platIndex) + log.Printf("[任务 %s] %s", requestID, errMsg) + pm.updatePublishStatus(requestID, 3, errMsg) + return map[string]interface{}{"success": false, "message": errMsg, "request_id": requestID} + } + + // 创建并执行发布器 + var pub interface{ PublishNote() (bool, string) } + + switch platIndex { + case "xhs": + pub = publisher.NewXiaohongshuPublisher(false, title, content, tags, userIndex, platIndex, requestID, imgPath, docPath, publishData, pm.Conf) + log.Printf("[任务 %s] 创建小红书发布器", requestID) + case "bjh": + pub = publisher.NewBaijiahaoPublisher(false, title, content, tags, userIndex, platIndex, requestID, imgPath, docPath, publishData, pm.Conf) + log.Printf("[任务 %s] 创建百家号发布器", requestID) + default: + log.Printf("[任务 %s] 未知平台 %s,使用默认小红书发布器", requestID, platIndex) + pub = publisher.NewXiaohongshuPublisher(false, title, content, tags, userIndex, platIndex, requestID, imgPath, docPath, publishData, pm.Conf) + } + + log.Printf("[任务 %s] 开始执行发布...", requestID) + success, message := pub.PublishNote() + + if success { + log.Printf("[任务 %s] 发布成功: %s", requestID, message) + pm.updatePublishStatus(requestID, 4, message) + } else { + log.Printf("[任务 %s] 发布失败: %s", requestID, message) + pm.updatePublishStatus(requestID, 3, message) + } + + return map[string]interface{}{ + "success": success, + "message": message, + "request_id": requestID, + } +} + +func (pm *PublishManager) updatePublishStatus(requestID string, status int, message string) { + if message != "" { + pm.db.Execute("UPDATE publish SET status = ?, msg = ? WHERE request_id = ?", status, message, requestID) + } else { + pm.db.Execute("UPDATE publish SET status = ? WHERE request_id = ?", status, requestID) + } +} + +func (pm *PublishManager) ExecuteOnce(tokenId int32) map[string]interface{} { + publishData := pm.getPendingPublish() + if publishData == nil { + return map[string]interface{}{"success": false, "message": "没有待发布任务"} + } + return pm.processSingleTask(publishData) +} + +func (pm *PublishManager) RetryTask(requestID string) map[string]interface{} { + publishData, err := pm.GetTaskByRequestID(requestID) + if err != nil || publishData == nil { + return map[string]interface{}{"success": false, "message": "任务不存在"} + } + return pm.processSingleTask(publishData) +} + +func (pm *PublishManager) GetStatus() map[string]interface{} { + return map[string]interface{}{ + "auto_status": pm.AutoStatus, + "max_concurrent": pm.Conf.Sys.MaxConcurrent, + "task_timeout": pm.Conf.Sys.TaskTimeout, + } +} + +func getPublisherClass(platIndex string) interface{} { + platformMap := map[string]interface{}{ + "xhs": struct{}{}, + "bjh": struct{}{}, + "csdn": struct{}{}, + } + return platformMap[platIndex] +} + +func getString(m map[string]interface{}, key string) string { + if v, ok := m[key]; ok { + switch v.(type) { + case []uint8: + return string(v.([]uint8)) + case string: + return v.(string) + case int64: + return fmt.Sprintf("%d", v) + default: + return fmt.Sprintf("%v", v) + } + + } + return "" +} diff --git a/internal/publisher/baijiahao.go b/internal/publisher/baijiahao.go new file mode 100644 index 0000000..73402b2 --- /dev/null +++ b/internal/publisher/baijiahao.go @@ -0,0 +1,383 @@ +package publisher + +import ( + "fmt" + "strings" + "time" + + "geo/internal/config" + + "github.com/go-rod/rod" + "github.com/go-rod/rod/lib/proto" +) + +type BaijiahaoPublisher struct { + *BasePublisher + Category string + ArticleType string + IsTop bool +} + +func NewBaijiahaoPublisher(headless bool, title, content string, tags []string, tenantID, platIndex, requestID, imagePath, wordPath string, platInfo map[string]interface{}, cfg *config.Config) *BaijiahaoPublisher { + base := NewBasePublisher(headless, title, content, tags, tenantID, platIndex, requestID, imagePath, wordPath, platInfo, cfg) + if platInfo != nil { + base.LoginURL = getString(platInfo, "login_url") + base.EditorURL = getString(platInfo, "edit_url") + base.LoginedURL = getString(platInfo, "logined_url") + } + return &BaijiahaoPublisher{BasePublisher: base} +} + +func (p *BaijiahaoPublisher) CheckLoginStatus() bool { + url := p.GetCurrentURL() + // 如果URL包含登录相关关键词,表示未登录 + if strings.Contains(url, "login") || strings.Contains(url, "passport") { + return false + } + // 如果URL是编辑页面或主页,表示已登录 + if strings.Contains(url, "baijiahao") || strings.Contains(url, "edit") { + return true + } + return url != p.LoginURL +} + +func (p *BaijiahaoPublisher) CheckLogin() (bool, string) { + p.LogInfo("检查登录状态...") + + if err := p.SetupDriver(); err != nil { + return false, fmt.Sprintf("浏览器启动失败: %v", err) + } + defer p.Close() + + p.Page.MustNavigate(p.LoginedURL) + p.Sleep(3) + p.WaitForPageReady(5) + + if p.CheckLoginStatus() { + p.SaveCookies() + return true, "已登录" + } + return false, "未登录" +} + +func (p *BaijiahaoPublisher) WaitLogin() (bool, string) { + p.LogInfo("开始等待登录...") + + if err := p.SetupDriver(); err != nil { + return false, fmt.Sprintf("浏览器启动失败: %v", err) + } + defer p.Close() + + // 先尝试访问已登录页面 + p.Page.MustNavigate(p.LoginedURL) + p.Sleep(3) + + if p.CheckLoginStatus() { + p.SaveCookies() + p.LogInfo("已有登录状态") + return true, "already_logged_in" + } + + // 未登录,跳转到登录页 + p.Page.MustNavigate(p.LoginURL) + p.LogInfo("请扫描二维码登录...") + + // 等待登录完成,最多120秒 + for i := 0; i < 120; i++ { + time.Sleep(1 * time.Second) + if p.CheckLoginStatus() { + p.SaveCookies() + p.LogInfo("登录成功") + return true, "login_success" + } + } + + return false, "登录超时" +} + +func (p *BaijiahaoPublisher) inputTitle() error { + p.LogInfo("输入标题...") + + titleSelectors := []string{ + ".client_pages_edit_components_titleInput [contenteditable='true']", + ".input-box [contenteditable='true']", + "[contenteditable='true']", + } + + var titleInput *rod.Element + var err error + + for _, selector := range titleSelectors { + titleInput, err = p.WaitForElementVisible(selector, 5) + if err == nil && titleInput != nil { + p.LogInfo(fmt.Sprintf("找到标题输入框: %s", selector)) + break + } + } + + if titleInput == nil { + return fmt.Errorf("未找到标题输入框") + } + + // 点击获取焦点 + if err := titleInput.Click(proto.InputMouseButtonLeft, 1); err != nil { + return fmt.Errorf("点击标题框失败: %v", err) + } + p.SleepMs(500) + + // 清空输入框 + if err := p.ClearContentEditable(titleInput); err != nil { + p.LogInfo(fmt.Sprintf("清空标题框失败: %v", err)) + } + p.SleepMs(300) + + // 输入标题 + if err := p.SetContentEditable(titleInput, p.Title); err != nil { + // 备用输入方式 + titleInput.Input(p.Title) + } + + p.LogInfo(fmt.Sprintf("标题已输入: %s", p.Title)) + return nil +} + +func (p *BaijiahaoPublisher) inputContent() error { + p.LogInfo("输入内容...") + + // 查找内容编辑器 + contentEditor, err := p.WaitForElementVisible(".ProseMirror", 10) + if err != nil { + contentEditor, err = p.WaitForElementVisible("[contenteditable='true']", 10) + if err != nil { + return fmt.Errorf("未找到内容编辑器: %v", err) + } + } + + // 点击获取焦点 + if err := contentEditor.Click(proto.InputMouseButtonLeft, 1); err != nil { + return fmt.Errorf("点击编辑器失败: %v", err) + } + p.SleepMs(500) + + // 清空编辑器 + if err := p.ClearContentEditable(contentEditor); err != nil { + p.LogInfo(fmt.Sprintf("清空编辑器失败: %v", err)) + } + p.SleepMs(300) + + // 输入内容 + if err := p.SetContentEditable(contentEditor, p.Content); err != nil { + // 备用输入方式 + contentEditor.Input(p.Content) + } + + p.LogInfo(fmt.Sprintf("内容已输入,长度: %d", len(p.Content))) + return nil +} + +func (p *BaijiahaoPublisher) uploadImage() error { + if p.ImagePath == "" { + p.LogInfo("无封面图片,跳过") + return nil + } + + p.LogInfo(fmt.Sprintf("上传封面: %s", p.ImagePath)) + + // 查找封面区域 + coverArea, err := p.WaitForElementClickable(".cheetah-spin-container", 5) + if err != nil { + p.LogInfo("未找到封面区域,跳过") + return nil + } + + if err := coverArea.Click(proto.InputMouseButtonLeft, 1); err != nil { + p.LogInfo(fmt.Sprintf("点击封面区域失败: %v", err)) + } + p.SleepMs(1000) + + // 查找文件输入框 + fileInput, err := p.Page.Element("input[type='file'][accept*='image']") + if err != nil { + fileInput, err = p.Page.Element("input[type='file']") + if err != nil { + return fmt.Errorf("未找到文件输入框: %v", err) + } + } + + // 上传图片 + if err := fileInput.SetFiles([]string{p.ImagePath}); err != nil { + return fmt.Errorf("上传图片失败: %v", err) + } + + p.LogInfo("图片上传成功") + p.Sleep(3) + + // 查找确认按钮 + confirmBtn, err := p.WaitForElementClickable(".cheetah-btn-primary", 5) + if err == nil && confirmBtn != nil { + if err := confirmBtn.Click(proto.InputMouseButtonLeft, 1); err != nil { + p.LogInfo(fmt.Sprintf("点击确认按钮失败: %v", err)) + } + p.LogInfo("已确认封面") + p.SleepMs(1000) + } + + return nil +} + +func (p *BaijiahaoPublisher) clickPublish() error { + p.LogInfo("点击发布按钮...") + + // 滚动到底部 + if _, err := p.Page.Eval(`() => window.scrollTo(0, document.body.scrollHeight)`); err != nil { + p.LogInfo(fmt.Sprintf("滚动到底部失败: %v", err)) + } + p.SleepMs(1000) + + // 查找发布按钮 + publishSelectors := []string{ + "[data-testid='publish-btn']", + ".op-list-right .cheetah-btn-primary", + ".cheetah-btn-primary", + "button:contains('发布')", + } + + var publishBtn *rod.Element + var err error + + for _, selector := range publishSelectors { + publishBtn, err = p.WaitForElementClickable(selector, 5) + if err == nil && publishBtn != nil { + p.LogInfo(fmt.Sprintf("找到发布按钮: %s", selector)) + break + } + } + + // 如果还是没找到,通过 XPath 查找 + if publishBtn == nil { + publishBtn, err = p.Page.ElementX("//button[contains(text(), '发布')]") + if err != nil { + return fmt.Errorf("未找到发布按钮: %v", err) + } + } + + // 滚动到按钮位置 + if err := p.ScrollToElement(publishBtn); err != nil { + p.LogInfo(fmt.Sprintf("滚动到按钮失败: %v", err)) + } + p.SleepMs(500) + + // 点击发布 + if err := publishBtn.Click(proto.InputMouseButtonLeft, 1); err != nil { + return fmt.Errorf("点击发布按钮失败: %v", err) + } + + p.LogInfo("已点击发布按钮") + return nil +} + +func (p *BaijiahaoPublisher) waitForPublishResult() (bool, string) { + p.LogInfo("等待发布结果...") + + // 等待最多60秒 + for i := 0; i < 60; i++ { + p.SleepMs(1000) + + // 检查URL是否跳转到成功页面 + currentURL := p.GetCurrentURL() + if strings.Contains(currentURL, "clue") || + strings.Contains(currentURL, "success") || + strings.Contains(currentURL, "article/list") { + p.LogInfo("发布成功!") + return true, "发布成功" + } + + // 检查是否有成功提示 + elements, _ := p.Page.Elements(".cheetah-message-success, .cheetah-message-info, [class*='success']") + for _, el := range elements { + text, _ := el.Text() + if strings.Contains(text, "成功") || strings.Contains(text, "已发布") { + p.LogInfo(fmt.Sprintf("发布成功: %s", text)) + return true, text + } + } + + // 检查是否有失败提示 + elements, _ = p.Page.Elements(".cheetah-message-error, .cheetah-message-warning, [class*='error']") + for _, el := range elements { + text, _ := el.Text() + if strings.Contains(text, "失败") || strings.Contains(text, "错误") { + p.LogError(fmt.Sprintf("发布失败: %s", text)) + return false, text + } + } + } + + return false, "发布结果未知(超时)" +} + +func (p *BaijiahaoPublisher) PublishNote() (bool, string) { + p.LogInfo(strings.Repeat("=", 50)) + p.LogInfo("开始发布百家号文章...") + p.LogInfo(fmt.Sprintf("标题: %s", p.Title)) + p.LogInfo(fmt.Sprintf("内容长度: %d", len(p.Content))) + p.LogInfo(strings.Repeat("=", 50)) + + // 初始化浏览器 + if err := p.SetupDriver(); err != nil { + return false, fmt.Sprintf("浏览器启动失败: %v", err) + } + defer p.Close() + + // 访问编辑器页面 + p.Page.MustNavigate(p.EditorURL) + p.Sleep(3) + p.WaitForPageReady(5) + + // 尝试加载cookies + if err := p.LoadCookies(); err == nil { + p.RefreshPage() + p.Sleep(2) + if p.CheckLoginStatus() { + p.LogInfo("使用cookies登录成功") + } else { + p.LogInfo("cookies已过期,需要重新登录") + return false, "需要登录" + } + } + + // 检查登录状态 + if !p.CheckLoginStatus() { + return false, "需要登录" + } + + // 保存cookies + p.SaveCookies() + + // 执行发布流程 + steps := []struct { + name string + fn func() error + }{ + {"输入标题", p.inputTitle}, + {"输入内容", p.inputContent}, + {"上传封面", p.uploadImage}, + } + + for _, step := range steps { + if err := step.fn(); err != nil { + p.LogStep(step.name, false, err.Error()) + return false, fmt.Sprintf("%s失败: %v", step.name, err) + } + p.LogStep(step.name, true, "") + p.SleepMs(500) + } + + // 点击发布 + if err := p.clickPublish(); err != nil { + return false, err.Error() + } + + // 等待发布结果 + return p.waitForPublishResult() +} diff --git a/internal/publisher/base.go b/internal/publisher/base.go new file mode 100644 index 0000000..18ae520 --- /dev/null +++ b/internal/publisher/base.go @@ -0,0 +1,282 @@ +package publisher + +import ( + "encoding/json" + "fmt" + "log" + "os" + "path/filepath" + "time" + + "geo/internal/config" + + "github.com/go-rod/rod" + "github.com/go-rod/rod/lib/launcher" + "github.com/go-rod/rod/lib/proto" +) + +type BasePublisher struct { + Headless bool + Title string + Content string + Tags []string + TenantID string + PlatIndex string + RequestID string + ImagePath string + WordPath string + + Browser *rod.Browser + Page *rod.Page + Logger *log.Logger + LogFile *os.File + + LoginURL string + EditorURL string + LoginedURL string + CookiesFile string + + PlatInfo map[string]interface{} + config *config.Config +} + +func NewBasePublisher(headless bool, title, content string, tags []string, tenantID, platIndex, requestID, imagePath, wordPath string, platInfo map[string]interface{}, config *config.Config) *BasePublisher { + cookiesDir := filepath.Join(config.Sys.CookiesDir, tenantID) + os.MkdirAll(cookiesDir, 0755) + cookiesFile := filepath.Join(cookiesDir, platIndex+".json") + + logFile, _ := os.Create(filepath.Join(config.Sys.LogsDir, requestID+".log")) + logger := log.New(logFile, "", log.LstdFlags) + + return &BasePublisher{ + Headless: headless, + Title: title, + Content: content, + Tags: tags, + TenantID: tenantID, + PlatIndex: platIndex, + RequestID: requestID, + ImagePath: imagePath, + WordPath: wordPath, + Logger: logger, + LogFile: logFile, + CookiesFile: cookiesFile, + PlatInfo: platInfo, + config: config, + } +} + +func (b *BasePublisher) SetupDriver() error { + l := launcher.New() + l.Headless(b.Headless) + // 设置用户数据目录 + userDataDir := filepath.Join(b.config.Sys.ChromeDataDir, b.TenantID) + os.MkdirAll(userDataDir, 0755) + l.UserDataDir(userDataDir) + + // 设置 Leakless 模式(解决 Windows 上的问题) + l.Leakless(false) + + // 设置 Chrome 启动参数 + l.Set("disable-blink-features", "AutomationControlled") + l.Set("no-sandbox") + l.Set("disable-dev-shm-usage") + l.Set("disable-gpu") + l.Set("disable-software-rasterizer") + l.Set("disable-setuid-sandbox") + l.Set("remote-debugging-port", "9222") + + // 窗口大小 + l.Set("window-size", "1920,1080") + l.Set("lang", "zh-CN") + + url, err := l.Launch() + if err != nil { + return fmt.Errorf("启动浏览器失败: %v", err) + } + + b.Browser = rod.New().ControlURL(url).MustConnect() + b.Page = b.Browser.MustPage() + b.Page.MustSetViewport(1920, 1080, 1, false) + + return nil +} + +func (b *BasePublisher) Close() { + if b.Page != nil { + b.Page.Close() + } + if b.Browser != nil { + b.Browser.Close() + } + if b.LogFile != nil { + b.LogFile.Close() + } +} + +func (b *BasePublisher) SaveCookies() error { + cookies, err := b.Page.Cookies(nil) + if err != nil { + return err + } + + data, err := json.Marshal(cookies) + if err != nil { + return err + } + + return os.WriteFile(b.CookiesFile, data, 0644) +} + +func (b *BasePublisher) LoadCookies() error { + data, err := os.ReadFile(b.CookiesFile) + if err != nil { + return err + } + + var cookies []*proto.NetworkCookieParam + if err := json.Unmarshal(data, &cookies); err != nil { + return err + } + + return b.Page.SetCookies(cookies) +} + +func (b *BasePublisher) RefreshPage() error { + _, err := b.Page.Eval(`() => location.reload()`) + return err +} + +func (b *BasePublisher) WaitForPageReady(timeout int) error { + return b.Page.Timeout(time.Duration(timeout) * time.Second).WaitLoad() +} + +func (b *BasePublisher) WaitForElement(selector string, timeout int) (*rod.Element, error) { + return b.Page.Timeout(time.Duration(timeout) * time.Second).Element(selector) +} + +func (b *BasePublisher) WaitForElementVisible(selector string, timeout int) (*rod.Element, error) { + el, err := b.WaitForElement(selector, timeout) + if err != nil { + return nil, err + } + if err := el.WaitVisible(); err != nil { + return nil, err + } + return el, nil +} + +func (b *BasePublisher) WaitForElementClickable(selector string, timeout int) (*rod.Element, error) { + el, err := b.WaitForElementVisible(selector, timeout) + if err != nil { + return nil, err + } + if err := el.WaitVisible(); err != nil { + return nil, err + } + return el, nil +} + +func (b *BasePublisher) JSClick(element *rod.Element) error { + // 使用 element.Evaluate 并传入 EvalOptions + _, err := element.Evaluate(&rod.EvalOptions{ + JS: `el => el.click()`, + }) + return err +} + +func (b *BasePublisher) ScrollToElement(element *rod.Element) error { + _, err := element.Evaluate(&rod.EvalOptions{ + JS: `el => el.scrollIntoView({block: 'center', behavior: 'smooth'})`, + }) + return err +} + +func (b *BasePublisher) Sleep(seconds int) { + time.Sleep(time.Duration(seconds) * time.Second) +} + +func (b *BasePublisher) LogStep(stepName string, success bool, message string) { + if success { + b.Logger.Printf("✅ %s: 成功 %s", stepName, message) + } else { + b.Logger.Printf("❌ %s: 失败 %s", stepName, message) + } +} + +func (b *BasePublisher) LogInfo(message string) { + b.Logger.Printf("📌 %s", message) +} + +func (b *BasePublisher) LogError(message string) { + b.Logger.Printf("❌ %s", message) +} + +func (b *BasePublisher) GetCurrentURL() string { + info := b.Page.MustInfo() + return info.URL +} + +func (b *BasePublisher) Screenshot(filename string) error { + data, err := b.Page.Screenshot(false, nil) + if err != nil { + return err + } + return os.WriteFile(filename, data, 0644) +} + +// 抽象方法 - 子类需要实现 +func (b *BasePublisher) WaitLogin() (bool, string) { + return false, "需要实现" +} + +func (b *BasePublisher) CheckLoginStatus() bool { + return false +} + +func (b *BasePublisher) CheckLogin() (bool, string) { + return false, "需要实现" +} + +func (b *BasePublisher) PublishNote() (bool, string) { + return false, "需要实现" +} + +// ClearContentEditable 清空 contenteditable 元素的内容 +func (b *BasePublisher) ClearContentEditable(element *rod.Element) error { + _, err := element.Evaluate(&rod.EvalOptions{ + JS: `el => { el.innerText = ''; el.innerHTML = ''; el.dispatchEvent(new Event('input', {bubbles: true})); }`, + }) + return err +} + +// SetContentEditable 设置 contenteditable 元素的内容 +func (b *BasePublisher) SetContentEditable(element *rod.Element, content string) error { + _, err := element.Evaluate(&rod.EvalOptions{ + JS: `(el, val) => { el.innerText = val; el.dispatchEvent(new Event('input', {bubbles: true})); }`, + JSArgs: []interface{}{content}, + }) + return err +} + +// SetInputValue 设置输入框值并触发事件 +func (b *BasePublisher) SetInputValue(element *rod.Element, value string) error { + _, err := element.Evaluate(&rod.EvalOptions{ + JS: `(el, val) => { el.value = val; el.dispatchEvent(new Event('input', {bubbles: true})); el.dispatchEvent(new Event('change', {bubbles: true})); }`, + JSArgs: []interface{}{value}, + }) + return err +} + +// ClearInput 清空输入框 +func (b *BasePublisher) ClearInput(element *rod.Element) error { + _, err := element.Evaluate(&rod.EvalOptions{ + JS: `el => { el.value = ''; el.dispatchEvent(new Event('input', {bubbles: true})); }`, + }) + return err +} + +// SleepMs 毫秒级等待 +func (b *BasePublisher) SleepMs(milliseconds int) { + time.Sleep(time.Duration(milliseconds) * time.Millisecond) +} diff --git a/internal/publisher/xiaohongshu.go b/internal/publisher/xiaohongshu.go new file mode 100644 index 0000000..42f8236 --- /dev/null +++ b/internal/publisher/xiaohongshu.go @@ -0,0 +1,425 @@ +package publisher + +import ( + "fmt" + "strings" + "time" + + "geo/internal/config" + + "github.com/go-rod/rod" + "github.com/go-rod/rod/lib/proto" +) + +type XiaohongshuPublisher struct { + *BasePublisher +} + +func NewXiaohongshuPublisher(headless bool, title, content string, tags []string, tenantID, platIndex, requestID, imagePath, wordPath string, platInfo map[string]interface{}, cfg *config.Config) *XiaohongshuPublisher { + base := NewBasePublisher(headless, title, content, tags, tenantID, platIndex, requestID, imagePath, wordPath, platInfo, cfg) + if platInfo != nil { + base.LoginURL = getString(platInfo, "login_url") + base.EditorURL = getString(platInfo, "edit_url") + base.LoginedURL = getString(platInfo, "logined_url") + } + return &XiaohongshuPublisher{BasePublisher: base} +} + +func (p *XiaohongshuPublisher) CheckLoginStatus() bool { + url := p.GetCurrentURL() + // 如果URL包含登录相关关键词,表示未登录 + if strings.Contains(url, "login") || strings.Contains(url, "signin") || strings.Contains(url, "passport") { + return false + } + // 如果URL是编辑页面或主页,表示已登录 + if strings.Contains(url, "creator") || strings.Contains(url, "editor") || strings.Contains(url, "publish") { + return true + } + return true +} + +func (p *XiaohongshuPublisher) CheckLogin() (bool, string) { + p.LogInfo("检查登录状态...") + + if err := p.SetupDriver(); err != nil { + return false, fmt.Sprintf("浏览器启动失败: %v", err) + } + defer p.Close() + + p.Page.MustNavigate(p.LoginedURL) + p.Sleep(3) + p.WaitForPageReady(5) + + if p.CheckLoginStatus() { + p.SaveCookies() + return true, "已登录" + } + return false, "未登录" +} + +func (p *XiaohongshuPublisher) WaitLogin() (bool, string) { + p.LogInfo("开始等待登录...") + + if err := p.SetupDriver(); err != nil { + return false, fmt.Sprintf("浏览器启动失败: %v", err) + } + defer p.Close() + + // 先尝试访问已登录页面 + p.Page.MustNavigate(p.LoginedURL) + p.Sleep(3) + + if p.CheckLoginStatus() { + p.SaveCookies() + p.LogInfo("已有登录状态") + return true, "already_logged_in" + } + + // 未登录,跳转到登录页 + p.Page.MustNavigate(p.LoginURL) + p.LogInfo("请扫描二维码登录...") + + // 等待登录完成,最多120秒 + for i := 0; i < 120; i++ { + time.Sleep(1 * time.Second) + if p.CheckLoginStatus() { + p.SaveCookies() + p.LogInfo("登录成功") + return true, "login_success" + } + } + + return false, "登录超时" +} + +func (p *XiaohongshuPublisher) inputContent() error { + p.LogInfo("输入文章内容...") + + // 等待编辑器加载 + contentEditor, err := p.WaitForElementVisible(".tiptap.ProseMirror", 10) + if err != nil { + // 尝试其他选择器 + contentEditor, err = p.WaitForElementVisible("[contenteditable='true']", 10) + if err != nil { + return fmt.Errorf("未找到内容编辑器: %v", err) + } + } + + // 点击获取焦点 - 使用 Click 方法 + if err := contentEditor.Click(proto.InputMouseButtonLeft, 1); err != nil { + return fmt.Errorf("点击编辑器失败: %v", err) + } + p.SleepMs(500) + + // 清空现有内容 - 使用 JavaScript 清空 + if err := p.ClearContentEditable(contentEditor); err != nil { + p.LogInfo(fmt.Sprintf("清空编辑器失败: %v", err)) + } + p.SleepMs(300) + + // 输入新内容 - 使用 JavaScript 设置内容 + if err := p.SetContentEditable(contentEditor, p.Content); err != nil { + // 如果 JS 方式失败,尝试直接输入 + contentEditor.Input(p.Content) + } + p.LogInfo(fmt.Sprintf("内容已输入,长度: %d", len(p.Content))) + + return nil +} + +func (p *XiaohongshuPublisher) inputTitle() error { + p.LogInfo("输入标题...") + + // 查找标题输入框 + titleSelectors := []string{ + "textarea.d-input", + ".d-input input", + "textarea[placeholder*='标题']", + "textarea", + } + + var titleInput *rod.Element + var err error + + for _, selector := range titleSelectors { + titleInput, err = p.WaitForElementVisible(selector, 3) + if err == nil && titleInput != nil { + p.LogInfo(fmt.Sprintf("找到标题输入框: %s", selector)) + break + } + } + + if titleInput == nil { + return fmt.Errorf("未找到标题输入框") + } + + // 点击获取焦点 + if err := titleInput.Click(proto.InputMouseButtonLeft, 1); err != nil { + return fmt.Errorf("点击标题框失败: %v", err) + } + p.SleepMs(500) + + // 清空输入框 + if err := p.ClearInput(titleInput); err != nil { + // 备用清空方式 + titleInput.Input("") + } + p.SleepMs(300) + + // 输入标题 + if err := p.SetInputValue(titleInput, p.Title); err != nil { + // 备用输入方式 + titleInput.Input(p.Title) + } + + p.LogInfo(fmt.Sprintf("标题已输入: %s", p.Title)) + return nil +} + +func (p *XiaohongshuPublisher) inputTags() error { + if len(p.Tags) == 0 { + p.LogInfo("无标签需要设置") + return nil + } + + p.LogInfo(fmt.Sprintf("设置标签: %v", p.Tags)) + + // 构建标签字符串 + tagStr := "" + for _, tag := range p.Tags { + if tagStr != "" { + tagStr += " " + } + tagStr += "#" + tag + } + + // 查找标签输入区域 + tagInput, err := p.WaitForElementVisible(".tiptap-container [contenteditable='true']", 5) + if err != nil { + p.LogInfo("未找到标签输入框,跳过标签设置") + return nil + } + + if err := tagInput.Click(proto.InputMouseButtonLeft, 1); err != nil { + p.LogInfo(fmt.Sprintf("点击标签框失败: %v", err)) + } + p.SleepMs(500) + + if err := p.SetContentEditable(tagInput, tagStr); err != nil { + tagInput.Input(tagStr) + } + + p.LogInfo("标签设置完成") + return nil +} + +func (p *XiaohongshuPublisher) uploadImage() error { + if p.ImagePath == "" { + p.LogInfo("无封面图片,跳过上传") + return nil + } + + p.LogInfo(fmt.Sprintf("上传封面图片: %s", p.ImagePath)) + + // 查找封面上传按钮 + uploadBtn, err := p.WaitForElementClickable(".upload-content", 5) + if err != nil { + p.LogInfo("未找到封面上传区域,跳过") + return nil + } + + // 使用 Click 方法 + if err := uploadBtn.Click(proto.InputMouseButtonLeft, 1); err != nil { + p.LogInfo(fmt.Sprintf("点击上传按钮失败: %v", err)) + } + p.SleepMs(1000) + + // 查找文件输入框 + fileInput, err := p.Page.Element("input[type='file']") + if err != nil { + return fmt.Errorf("未找到文件输入框: %v", err) + } + + // 使用 SetFiles 上传文件 + if err := fileInput.SetFiles([]string{p.ImagePath}); err != nil { + return fmt.Errorf("上传图片失败: %v", err) + } + + p.LogInfo("图片上传成功") + p.Sleep(3) + + return nil +} + +func (p *XiaohongshuPublisher) clickPublish() error { + p.LogInfo("点击发布按钮...") + + // 滚动到底部 + if _, err := p.Page.Eval(`() => window.scrollTo(0, document.body.scrollHeight)`); err != nil { + p.LogInfo(fmt.Sprintf("滚动到底部失败: %v", err)) + } + p.SleepMs(1000) + + // 查找发布按钮 + publishSelectors := []string{ + ".publish-page-publish-btn button", + ".publish-btn", + ".submit-btn", + "button[type='submit']", + } + + var publishBtn *rod.Element + var err error + + for _, selector := range publishSelectors { + publishBtn, err = p.WaitForElementClickable(selector, 5) + if err == nil && publishBtn != nil { + p.LogInfo(fmt.Sprintf("找到发布按钮: %s", selector)) + break + } + } + + // 如果还是没找到,通过文本查找 + if publishBtn == nil { + publishBtn, err = p.Page.ElementX("//button[contains(text(), '发布')]") + if err != nil { + return fmt.Errorf("未找到发布按钮: %v", err) + } + } + + // 滚动到按钮位置 + if err := p.ScrollToElement(publishBtn); err != nil { + p.LogInfo(fmt.Sprintf("滚动到按钮失败: %v", err)) + } + p.SleepMs(500) + + // 点击发布 - 使用 Click 方法 + if err := publishBtn.Click(proto.InputMouseButtonLeft, 1); err != nil { + return fmt.Errorf("点击发布按钮失败: %v", err) + } + + p.LogInfo("已点击发布按钮") + return nil +} + +func (p *XiaohongshuPublisher) waitForPublishResult() (bool, string) { + p.LogInfo("等待发布结果...") + + // 等待最多60秒 + for i := 0; i < 60; i++ { + p.SleepMs(1000) + + // 检查URL是否跳转到成功页面 + currentURL := p.GetCurrentURL() + if strings.Contains(currentURL, "success") || + strings.Contains(currentURL, "content/manage") || + strings.Contains(currentURL, "work-management") { + p.LogInfo("发布成功!") + return true, "发布成功" + } + + // 检查是否有成功提示 + elements, _ := p.Page.Elements(".semi-toast-content, .toast-success, [class*='success']") + for _, el := range elements { + text, _ := el.Text() + if strings.Contains(text, "成功") || strings.Contains(text, "已发布") { + p.LogInfo(fmt.Sprintf("发布成功: %s", text)) + return true, text + } + } + + // 检查是否有失败提示 + elements, _ = p.Page.Elements(".semi-toast-error, .toast-error, [class*='error']") + for _, el := range elements { + text, _ := el.Text() + if strings.Contains(text, "失败") || strings.Contains(text, "错误") { + p.LogError(fmt.Sprintf("发布失败: %s", text)) + return false, text + } + } + } + + return false, "发布结果未知(超时)" +} + +func (p *XiaohongshuPublisher) PublishNote() (bool, string) { + p.LogInfo(strings.Repeat("=", 50)) + p.LogInfo("开始发布小红书笔记...") + p.LogInfo(fmt.Sprintf("标题: %s", p.Title)) + p.LogInfo(fmt.Sprintf("内容长度: %d", len(p.Content))) + p.LogInfo(fmt.Sprintf("标签: %v", p.Tags)) + p.LogInfo(strings.Repeat("=", 50)) + + // 初始化浏览器 + if err := p.SetupDriver(); err != nil { + return false, fmt.Sprintf("浏览器启动失败: %v", err) + } + defer p.Close() + + // 访问已登录页面 + p.Page.MustNavigate(p.LoginedURL) + p.Sleep(3) + p.WaitForPageReady(5) + + // 尝试加载cookies + if err := p.LoadCookies(); err == nil { + p.RefreshPage() + p.Sleep(2) + if p.CheckLoginStatus() { + p.LogInfo("使用cookies登录成功") + } else { + p.LogInfo("cookies已过期,需要重新登录") + return false, "需要登录" + } + } + + // 检查登录状态 + if !p.CheckLoginStatus() { + return false, "需要登录" + } + + // 保存cookies + p.SaveCookies() + + // 访问发布页面 + p.Page.MustNavigate(p.EditorURL) + p.Sleep(3) + p.WaitForPageReady(5) + + // 执行发布流程 + steps := []struct { + name string + fn func() error + }{ + {"输入内容", p.inputContent}, + {"输入标题", p.inputTitle}, + {"设置标签", p.inputTags}, + {"上传封面", p.uploadImage}, + } + + for _, step := range steps { + if err := step.fn(); err != nil { + p.LogStep(step.name, false, err.Error()) + return false, fmt.Sprintf("%s失败: %v", step.name, err) + } + p.LogStep(step.name, true, "") + p.SleepMs(500) + } + + // 点击发布 + if err := p.clickPublish(); err != nil { + return false, err.Error() + } + + // 等待发布结果 + return p.waitForPublishResult() +} + +func getString(m map[string]interface{}, key string) string { + if v, ok := m[key]; ok { + if s, ok := v.(string); ok { + return s + } + } + return "" +} diff --git a/internal/server/http.go b/internal/server/http.go new file mode 100644 index 0000000..b2d3cea --- /dev/null +++ b/internal/server/http.go @@ -0,0 +1,25 @@ +package server + +import ( + "geo/internal/server/router" + + "github.com/gofiber/fiber/v2" + "github.com/gofiber/fiber/v2/middleware/logger" + "github.com/gofiber/fiber/v2/middleware/recover" +) + +func NewHTTPServer(routerServer *router.RouterServer) *fiber.App { + //构建 server + app := initRoute() + router.SetupRoutes(app, routerServer) + return app +} + +func initRoute() *fiber.App { + app := fiber.New() + app.Use( + recover.New(), + logger.New(), + ) + return app +} diff --git a/internal/server/provider_set.go b/internal/server/provider_set.go new file mode 100644 index 0000000..f325719 --- /dev/null +++ b/internal/server/provider_set.go @@ -0,0 +1,14 @@ +package server + +import ( + "geo/internal/server/router" + + "github.com/google/wire" +) + +var ProviderSetServer = wire.NewSet( + NewServers, + NewHTTPServer, + router.NewRouterServer, + router.NewAppModule, +) diff --git a/internal/server/router/app.go b/internal/server/router/app.go new file mode 100644 index 0000000..304b66b --- /dev/null +++ b/internal/server/router/app.go @@ -0,0 +1,47 @@ +package router + +import ( + "geo/internal/config" + "geo/internal/entitys" + "geo/internal/service" + + "github.com/gofiber/fiber/v2" +) + +type AppModule struct { + cfg *config.Config + appService *service.AppService + loginService *service.LoginService + publishService *service.PublishService +} + +func NewAppModule( + cfg *config.Config, + appService *service.AppService, + loginService *service.LoginService, + publishService *service.PublishService, +) *AppModule { + return &AppModule{ + cfg: cfg, + appService: appService, + loginService: loginService, + publishService: publishService, + } +} + +func (m *AppModule) Register(router fiber.Router) { + router.Post("/login_app", vali(m.appService.LoginApp, &entitys.LoginAppRequest{})) + router.Post("/get_user_and_auto_status", vali(m.appService.GetUserAndAutoStatus, &entitys.GetUserAndAutoStatusRequest{})) + router.Post("/add_user", vali(m.appService.AddUser, &entitys.AddUserRequest{})) + router.Post("/del_user", vali(m.appService.DelUser, &entitys.DelUserRequest{})) + router.Post("/get_app", vali(m.appService.GetApp, &entitys.GetAppRequest{})) + router.Post("/publish_records", vali(m.publishService.PublishRecords, &entitys.PublishRecordsRequest{})) + router.Post("/publish_on", vali(m.publishService.PublishOn, &entitys.PublishOnRequest{})) + router.Post("/publish_off", vali(m.publishService.PublishOff, &entitys.PublishOffRequest{})) + router.Post("/publish_status", vali(m.publishService.PublishStatus, &entitys.PublishStatusRequest{})) + router.Post("/publish_execute_once", vali(m.publishService.PublishExecuteOnce, &entitys.PublishExecuteOnceRequest{})) + router.Post("/publish_execute_retry", vali(m.publishService.PublishExecuteRetry, &entitys.PublishExecuteRetryRequest{})) + router.Post("/get_publish_list", vali(m.publishService.GetPublishList, &entitys.GetPublishListRequest{})) + router.Post("/login_platform", vali(m.loginService.LoginPlatform, &entitys.LoginPlatformRequest{})) + router.Post("/logout_platform", vali(m.loginService.LogoutPlatform, &entitys.LogoutPlatformRequest{})) +} diff --git a/internal/server/router/router.go b/internal/server/router/router.go new file mode 100644 index 0000000..48ef4ff --- /dev/null +++ b/internal/server/router/router.go @@ -0,0 +1,128 @@ +package router + +import ( + "encoding/json" + "errors" + "geo/pkg" + "geo/tmpl/errcode" + "reflect" + "strings" + + "github.com/go-playground/validator/v10" + "github.com/gofiber/fiber/v2" +) + +type Module interface { + Register(r fiber.Router) +} +type RouterServer struct { + modules []Module +} + +func NewRouterServer( + app *AppModule, +) *RouterServer { + return &RouterServer{ + modules: []Module{app}, + } +} + +// SetupRoutes 设置路由 +func SetupRoutes(app *fiber.App, routerServer *RouterServer) { + app.Use(func(c *fiber.Ctx) error { + // 设置 CORS 头 + c.Set("Access-Control-Allow-Origin", "*") + c.Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS") + c.Set("Access-Control-Allow-Headers", "Content-Type, Authorization") + + // 如果是预检请求(OPTIONS),直接返回 204 + if c.Method() == "OPTIONS" { + return c.SendStatus(fiber.StatusNoContent) // 204 + } + + // 继续处理后续中间件或路由 + return c.Next() + }) + + registerResponse(app) + + // 注册所有模块 + for _, module := range routerServer.modules { + module.Register(app) + } +} + +func registerResponse(router fiber.Router) { + // 自定义返回 + router.Use(func(c *fiber.Ctx) error { + err := c.Next() + return registerCommon(c, err) + }) + +} + +func registerCommon(c *fiber.Ctx, err error) error { + // 调用下一个中间件或路由处理函数 + + // 如果有错误发生 + if err != nil { + var businessErr *errcode.BusinessErr + var fiberError *fiber.Error + switch { + case errors.As(err, &businessErr): + case errors.As(err, &fiberError): + errors.As(err, &fiberError) + businessErr = errcode.NewBusinessErr(fiberError.Code, fiberError.Message) + default: + businessErr = errcode.SystemError + } + // 返回自定义错误响应 + return c.JSON(fiber.Map{ + "message": businessErr.Error(), + "code": businessErr.Code(), + "data": nil, + }) + } + // 如果没有错误发生,继续处理请求 + var data interface{} + json.Unmarshal(c.Response().Body(), &data) + return c.JSON(fiber.Map{ + "data": data, + "message": errcode.Success.Error(), + "code": errcode.Success.Code(), + }) +} + +func vali[T any](handler func(*fiber.Ctx, *T) error, _ *T) fiber.Handler { + return func(c *fiber.Ctx) error { + var data T + + // 解析请求 + if err := c.BodyParser(&data); err != nil { + return errcode.ParamErr(err.Error()) + } + + // 创建验证器 + validate := validator.New() + + // 注册中文标签 + validate.RegisterTagNameFunc(func(fld reflect.StructField) string { + name := fld.Tag.Get("zh") + if name == "" { + name = fld.Tag.Get("json") + } + return name + }) + + // 验证 + if err := validate.Struct(&data); err != nil { + er := make([]string, len(err.(validator.ValidationErrors))) + for k, e := range err.(validator.ValidationErrors) { + er[k] = pkg.GetErr(e.Tag(), e.Field(), e.Param()) + } + return errcode.ParamErr(strings.Join(er, ",")) + } + + return handler(c, &data) + } +} diff --git a/internal/server/server.go b/internal/server/server.go new file mode 100644 index 0000000..e21faed --- /dev/null +++ b/internal/server/server.go @@ -0,0 +1,19 @@ +package server + +import ( + "geo/internal/config" + + "github.com/gofiber/fiber/v2" +) + +type Servers struct { + cfg *config.Config + HttpServer *fiber.App +} + +func NewServers(cfg *config.Config, fiber *fiber.App) *Servers { + return &Servers{ + HttpServer: fiber, + cfg: cfg, + } +} diff --git a/internal/service/app.go b/internal/service/app.go new file mode 100644 index 0000000..860d60c --- /dev/null +++ b/internal/service/app.go @@ -0,0 +1,173 @@ +package service + +import ( + "fmt" + + "github.com/gofiber/fiber/v2" + "xorm.io/builder" + + "geo/internal/biz" + "geo/internal/config" + "geo/internal/data/impl" + "geo/internal/data/model" + "geo/internal/entitys" + "geo/internal/manager" + "geo/pkg" + "geo/tmpl/errcode" +) + +type AppService struct { + cfg *config.Config + tokenImpl *impl.TokenImpl + userImpl *impl.UserImpl + platImpl *impl.PlatImpl + publishBiz *biz.PublishBiz +} + +func NewAppService( + cfg *config.Config, + tokenImpl *impl.TokenImpl, + + userImpl *impl.UserImpl, + platImpl *impl.PlatImpl, + publishBiz *biz.PublishBiz, +) *AppService { + return &AppService{ + cfg: cfg, + tokenImpl: tokenImpl, + + userImpl: userImpl, + platImpl: platImpl, + publishBiz: publishBiz, + } +} + +func (a *AppService) LoginApp(c *fiber.Ctx, req *entitys.LoginAppRequest) error { + cond := builder.NewCond(). + And(builder.Eq{"secret": req.Secret}). + And(builder.Eq{"status": 1}) + tokenInfo := &model.Token{} + err := a.tokenImpl.GetOneBySearchStruct(c.UserContext(), &cond, tokenInfo) + if err != nil || tokenInfo == nil { + return errcode.Forbidden("密钥无效") + } + + accessToken := pkg.GenerateUUID() + err = a.tokenImpl.UpdateByKey(c.UserContext(), a.tokenImpl.PrimaryKey(), tokenInfo.ID, &model.Token{ + AccessToken: accessToken, + }) + if err != nil { + return err + } + + return pkg.HandleResponse(c, fiber.Map{ + "access_token": accessToken, + "user_limit": tokenInfo.UserLimit, + "expire_time": tokenInfo.ExpireTime, + }) +} + +func (a *AppService) GetUserAndAutoStatus(c *fiber.Ctx, req *entitys.GetUserAndAutoStatusRequest) error { + tokenInfo, err := a.publishBiz.ValidateAccessToken(c.UserContext(), req.AccessToken) + if err != nil { + return err + } + + cond := builder.NewCond(). + And(builder.Eq{"token_id": tokenInfo.ID}). + And(builder.Eq{"status": 1}) + users := make([]model.User, 0) + _, err = a.userImpl.GetListToStruct(c.UserContext(), &cond, nil, &users, "id desc") + if err != nil { + return err + } + + pm := manager.GetPublishManager(a.cfg, a.tokenImpl.GetDb()) + return pkg.HandleResponse(c, fiber.Map{ + "user": users, + "auto_status": pm.AutoStatus, + }) +} + +func (a *AppService) AddUser(c *fiber.Ctx, req *entitys.AddUserRequest) error { + tokenInfo, err := a.publishBiz.ValidateAccessToken(c.UserContext(), req.AccessToken) + if err != nil { + return err + } + + cond := builder.NewCond(). + And(builder.Eq{"token_id": tokenInfo.ID}). + And(builder.Eq{"status": 1}) + userCount, err := a.userImpl.CountByCond(c.UserContext(), &cond) + if err != nil { + return err + } + + if userCount >= int64(tokenInfo.UserLimit) { + return errcode.ParamErr(fmt.Sprintf("用户已达上限(%d)", tokenInfo.UserLimit)) + } + + userIndex := pkg.GenerateUserIndex() + err = a.userImpl.Add(c.UserContext(), &model.User{ + TokenID: tokenInfo.ID, + UserIndex: userIndex, + Name: req.Name, + Status: 1, + }) + if err != nil { + return errcode.SqlErr(err) + } + + // 获取平台列表并关联 + platCond := builder.NewCond().And(builder.Eq{"status": 1}) + plats := make([]model.Plat, 0) + _, err = a.platImpl.GetListToStruct(c.UserContext(), &platCond, nil, &plats, "") + if err != nil { + return err + } + + for _, plat := range plats { + loginRelation := &model.LoginRelation{ + UserIndex: userIndex, + PlatIndex: plat.Index, + LoginStatus: 2, + Status: 1, + } + // 这里需要调用 loginRelationImpl 的 Add 方法 + _ = loginRelation + } + + return pkg.HandleResponse(c, fiber.Map{ + "user_name": req.Name, + "user_index": userIndex, + }) +} + +func (a *AppService) DelUser(c *fiber.Ctx, req *entitys.DelUserRequest) error { + // 需要从请求中获取 access_token + // 这里简化处理 + err := a.userImpl.DeleteByKey(c.UserContext(), a.userImpl.PrimaryKey(), int64(req.ID)) + if err != nil { + return err + } + return pkg.HandleResponse(c, fiber.Map{}) +} + +func (a *AppService) GetApp(c *fiber.Ctx, req *entitys.GetAppRequest) error { + _, err := a.publishBiz.ValidateAccessToken(c.UserContext(), req.AccessToken) + if err != nil { + return err + } + + cond := builder.NewCond(). + And(builder.Eq{"login_relation.user_index": req.UserIndex}). + And(builder.Eq{"login_relation.status": 1}). + And(builder.Eq{"plat.status": 1}) + + result, err := a.platImpl.GetPlatListWithLoginStatus(c.UserContext(), &cond) + if err != nil { + return err + } + + return pkg.HandleResponse(c, result) +} diff --git a/internal/service/login.go b/internal/service/login.go new file mode 100644 index 0000000..3f59466 --- /dev/null +++ b/internal/service/login.go @@ -0,0 +1,95 @@ +package service + +import ( + "os" + "path/filepath" + + "github.com/gofiber/fiber/v2" + + "geo/internal/biz" + "geo/internal/config" + "geo/internal/entitys" + "geo/internal/publisher" + "geo/pkg" + "geo/tmpl/errcode" +) + +type LoginService struct { + cfg *config.Config + publishBiz *biz.PublishBiz +} + +func NewLoginService( + cfg *config.Config, + publishBiz *biz.PublishBiz, + +) *LoginService { + return &LoginService{ + cfg: cfg, + publishBiz: publishBiz, + } +} + +func (s *LoginService) LoginPlatform(c *fiber.Ctx, req *entitys.LoginPlatformRequest) error { + _, err := s.publishBiz.ValidateAccessToken(c.UserContext(), req.AccessToken) + if err != nil { + return err + } + + // 获取平台信息 + platInfo, err := s.publishBiz.GetPlatInfo(c.UserContext(), req.PlatIndex) + if err != nil { + return errcode.NotFound("平台不存在") + } + + // 创建发布器 + platMap := map[string]interface{}{ + "login_url": platInfo.LoginURL, + "edit_url": platInfo.EditURL, + "logined_url": platInfo.LoginedURL, + } + + var pub interface{ WaitLogin() (bool, string) } + switch req.PlatIndex { + case "xhs": + pub = publisher.NewXiaohongshuPublisher(false, "", "", nil, req.UserIndex, req.PlatIndex, "", "", "", platMap, s.cfg) + default: + pub = publisher.NewXiaohongshuPublisher(false, "", "", nil, req.UserIndex, req.PlatIndex, "", "", "", platMap, s.cfg) + } + + success, msg := pub.WaitLogin() + if !success { + return errcode.SysErr(msg) + } + + // 更新登录状态 + err = s.publishBiz.UpdateLoginStatus(c.UserContext(), req.UserIndex, req.PlatIndex, 1) + if err != nil { + return err + } + + return pkg.HandleResponse(c, fiber.Map{}) +} + +func (s *LoginService) LogoutPlatform(c *fiber.Ctx, req *entitys.LogoutPlatformRequest) error { + _, err := s.publishBiz.ValidateAccessToken(c.UserContext(), req.AccessToken) + if err != nil { + return err + } + + // 更新登录状态为未登录 + err = s.publishBiz.UpdateLoginStatus(c.UserContext(), req.UserIndex, req.PlatIndex, 2) + if err != nil { + return err + } + + return pkg.HandleResponse(c, fiber.Map{}) +} + +func (s *LoginService) ServeQrcode(c *fiber.Ctx, filename string) error { + filepath := filepath.Join(s.cfg.Sys.QrcodesDir, filename) + if _, err := os.Stat(filepath); os.IsNotExist(err) { + return errcode.NotFound("二维码不存在") + } + return c.SendFile(filepath) +} diff --git a/internal/service/provider_set.go b/internal/service/provider_set.go new file mode 100644 index 0000000..ab60193 --- /dev/null +++ b/internal/service/provider_set.go @@ -0,0 +1,11 @@ +package service + +import ( + "github.com/google/wire" +) + +var ProviderSetAppService = wire.NewSet( + NewAppService, + NewPublishService, + NewLoginService, +) diff --git a/internal/service/publish.go b/internal/service/publish.go new file mode 100644 index 0000000..d9508ed --- /dev/null +++ b/internal/service/publish.go @@ -0,0 +1,205 @@ +package service + +import ( + "errors" + "fmt" + "geo/internal/data/model" + "geo/utils" + "os" + "path/filepath" + + "time" + + "github.com/gofiber/fiber/v2" + + "geo/internal/biz" + "geo/internal/config" + "geo/internal/entitys" + "geo/internal/manager" + "geo/pkg" + "geo/tmpl/errcode" +) + +type PublishService struct { + cfg *config.Config + publishBiz *biz.PublishBiz + db *utils.Db +} + +func NewPublishService( + cfg *config.Config, + publishBiz *biz.PublishBiz, + db *utils.Db, +) *PublishService { + return &PublishService{ + cfg: cfg, + publishBiz: publishBiz, + db: db, + } +} + +func (s *PublishService) PublishRecords(c *fiber.Ctx, req *entitys.PublishRecordsRequest) error { + // 验证token + tokenInfo, err := s.publishBiz.ValidateAccessToken(c.UserContext(), req.AccessToken) + if err != nil { + return err + } + + // 转换记录 + validRecords := make([]*model.Publish, 0) + for _, record := range req.Records { + publishTime, err := time.Parse("2006-01-02 15:04:05", record.PublishTime) + if err != nil { + return errcode.ParamErr(fmt.Sprintf("时间格式错误: %v", err)) + } + + validRecords = append(validRecords, &model.Publish{ + UserIndex: record.UserIndex, + RequestID: record.RequestID, + Title: record.Title, + Tag: record.Tag, + Type: record.Type, + PlatIndex: record.PlatIndex, + URL: record.URL, + PublishTime: publishTime, + Img: record.Img, + TokenID: tokenInfo.ID, + }) + + } + + err = s.publishBiz.BatchInsertPublish(c.UserContext(), validRecords) + if err != nil { + return err + } + + return pkg.HandleResponse(c, fiber.Map{ + "total": len(validRecords), + }) +} + +func (s *PublishService) PublishOn(c *fiber.Ctx, req *entitys.PublishOnRequest) error { + tokenInfo, err := s.publishBiz.ValidateAccessToken(c.UserContext(), req.AccessToken) + if err != nil { + return err + } + + pm := manager.GetPublishManager(s.cfg, s.db) + if pm.Start(int(tokenInfo.ID)) { + return pkg.HandleResponse(c, fiber.Map{ + "auto_status": pm.AutoStatus, + }) + } + return errors.New("自动发布服务已在运行中") +} + +func (s *PublishService) PublishOff(c *fiber.Ctx, req *entitys.PublishOffRequest) error { + _, err := s.publishBiz.ValidateAccessToken(c.UserContext(), req.AccessToken) + if err != nil { + return err + } + + pm := manager.GetPublishManager(s.cfg, s.db) + if pm.Stop() { + return pkg.HandleResponse(c, fiber.Map{}) + } + return errors.New("自动发布服务未运行") +} + +func (s *PublishService) PublishStatus(c *fiber.Ctx, req *entitys.PublishStatusRequest) error { + _, err := s.publishBiz.ValidateAccessToken(c.UserContext(), req.AccessToken) + if err != nil { + return err + } + + pm := manager.GetPublishManager(s.cfg, s.db) + + if req.RequestID != "" { + // 查询单个任务 + task, err := s.publishBiz.GetTaskByRequestID(c.UserContext(), req.RequestID) + if err != nil { + return errcode.NotFound("任务不存在") + } + return pkg.HandleResponse(c, task) + } + + // 查询整体状态 + return pkg.HandleResponse(c, pm.GetStatus()) +} + +func (s *PublishService) PublishExecuteOnce(c *fiber.Ctx, req *entitys.PublishExecuteOnceRequest) error { + tokenInfo, err := s.publishBiz.ValidateAccessToken(c.UserContext(), req.AccessToken) + if err != nil { + return err + } + + pm := manager.GetPublishManager(s.cfg, s.db) + result := pm.ExecuteOnce(tokenInfo.ID) + return pkg.HandleResponse(c, result) +} + +func (s *PublishService) PublishExecuteRetry(c *fiber.Ctx, req *entitys.PublishExecuteRetryRequest) error { + _, err := s.publishBiz.ValidateAccessToken(c.UserContext(), req.AccessToken) + if err != nil { + return err + } + + pm := manager.GetPublishManager(s.cfg, s.db) + result := pm.RetryTask(req.RequestID) + return pkg.HandleResponse(c, result) +} + +func (s *PublishService) GetPublishList(c *fiber.Ctx, req *entitys.GetPublishListRequest) error { + tokenInfo, err := s.publishBiz.ValidateAccessToken(c.UserContext(), req.AccessToken) + if err != nil { + return err + } + + page := req.Page + if page < 1 { + page = 1 + } + pageSize := req.PageSize + if pageSize < 1 { + pageSize = 20 + } + if pageSize > 100 { + pageSize = 100 + } + + filters := map[string]interface{}{ + "user_index": req.UserIndex, + "tag": req.Tag, + "type": req.Type, + "plat_index": req.PlatIndex, + "status": req.Status, + "request_id": req.RequestID, + } + + list, total, err := s.publishBiz.GetPublishList(c.UserContext(), tokenInfo.ID, page, pageSize, filters) + if err != nil { + return err + } + + return pkg.HandleResponse(c, fiber.Map{ + "list": list, + "pagination": fiber.Map{ + "page": page, + "page_size": pageSize, + "total": total, + "total_pages": (total + int64(pageSize) - 1) / int64(pageSize), + }, + }) +} + +func (s *PublishService) GetLogs(c *fiber.Ctx, requestID string) error { + logFile := filepath.Join(s.cfg.Sys.LogsDir, requestID+".log") + content, err := os.ReadFile(logFile) + if err != nil { + return errcode.NotFound("日志文件不存在") + } + return pkg.HandleResponse(c, fiber.Map{ + "request_id": requestID, + "log_content": string(content), + }) +} diff --git a/pkg/doc.go b/pkg/doc.go new file mode 100644 index 0000000..637b72d --- /dev/null +++ b/pkg/doc.go @@ -0,0 +1,45 @@ +package pkg + +import ( + "strings" + + md "github.com/JohannesKaufmann/html-to-markdown" + strip "github.com/grokify/html-strip-tags-go" +) + +func ExtractWordContent(filePath string, format string) (string, error) { + // 简化版:读取文件内容并转换格式 + // 完整实现需要使用专门的docx库 + + content := "从Word文档提取的内容" + + switch format { + case "html": + return "

" + content + "

", nil + case "markdown": + converter := md.NewConverter("", true, nil) + return converter.ConvertString("

" + content + "

") + default: + return strip.StripTags(content), nil + } +} + +func ParseTags(tagStr string) []string { + if tagStr == "" { + return []string{} + } + tags := strings.Split(tagStr, ",") + result := make([]string, 0) + for _, t := range tags { + trimmed := strings.TrimSpace(t) + if trimmed != "" { + result = append(result, trimmed) + } + } + return result +} + +func IsTimeExceeded(targetTime string) bool { + // 实现时间比较 + return false +} diff --git a/pkg/func.go b/pkg/func.go new file mode 100644 index 0000000..03f68c7 --- /dev/null +++ b/pkg/func.go @@ -0,0 +1,247 @@ +package pkg + +import ( + "encoding/json" + "fmt" + "io" + "math/rand/v2" + "net/http" + "os" + "path/filepath" + "reflect" + "time" + + "github.com/go-viper/mapstructure/v2" + "github.com/google/uuid" +) + +func GetModuleDir() (string, error) { + dir, err := os.Getwd() + if err != nil { + return "", err + } + + for { + modPath := filepath.Join(dir, "go.mod") + if _, err := os.Stat(modPath); err == nil { + return dir, nil // 找到 go.mod + } + + // 向上查找父目录 + parent := filepath.Dir(dir) + if parent == dir { + break // 到达根目录,未找到 + } + dir = parent + } + + return "", fmt.Errorf("go.mod not found in current directory or parents") +} + +// GetCacheDir 用于获取缓存目录路径 +// 如果缓存目录不存在,则会自动创建 +// 返回值: +// - string: 缓存目录的路径 +// - error: 如果获取模块目录失败或创建缓存目录失败,则返回错误信息 +func GetCacheDir() (string, error) { + // 获取模块目录 + modDir, err := GetModuleDir() + if err != nil { + return "", err + } + // 拼接缓存目录路径 + path := fmt.Sprintf("%s/cache", modDir) + // 创建目录(包括所有必要的父目录),权限设置为0755 + err = os.MkdirAll(path, 0755) + if err != nil { + return "", fmt.Errorf("创建目录失败: %w", err) + } + // 返回成功创建的缓存目录路径 + return path, nil +} + +func GetTmplDir() (string, error) { + modDir, err := GetModuleDir() + if err != nil { + return "", err + } + path := fmt.Sprintf("%s/tmpl", modDir) + err = os.MkdirAll(path, 0755) + if err != nil { + return "", fmt.Errorf("创建目录失败: %w", err) + } + return path, nil +} + +func ReverseSliceNew[T any](s []T) []T { + result := make([]T, len(s)) + for i := 0; i < len(s); i++ { + result[i] = s[len(s)-1-i] + } + return result +} + +func JsonStringIgonErr(data interface{}) string { + return string(JsonByteIgonErr(data)) +} + +func JsonByteIgonErr(data interface{}) []byte { + dataByte, _ := json.Marshal(data) + return dataByte +} + +func IntersectionGeneric[T comparable](slice1, slice2 []T) []T { + m := make(map[T]bool) + result := []T{} + + for _, v := range slice1 { + m[v] = true + } + + for _, v := range slice2 { + if m[v] { + result = append(result, v) + delete(m, v) // 避免重复 + } + } + + return result +} + +func CreateOrderNum(prefix string) string { + code := fmt.Sprintf("%04d", rand.IntN(10000)) + fmt.Println("4位随机数字:", code) // 输出示例: "0837" + return prefix + time.Now().Format("20060102150405") + code +} + +func BuildUpdateMap(obj interface{}, omitFields ...string) map[string]interface{} { + result := make(map[string]interface{}) + omitMap := make(map[string]bool) + for _, f := range omitFields { + omitMap[f] = true + } + + v := reflect.ValueOf(obj) + if v.Kind() == reflect.Ptr { + v = v.Elem() + } + t := v.Type() + + for i := 0; i < v.NumField(); i++ { + field := v.Field(i) + fieldName := t.Field(i).Name + + if omitMap[fieldName] { + continue + } + + // 只处理非 nil 的指针字段 + if field.Kind() == reflect.Ptr && !field.IsNil() { + // 将驼峰转为下划线(可选的,根据你的数据库列名决定) + colName := CamelToSnake(fieldName) + result[colName] = field.Elem().Interface() + } + } + return result +} + +func CamelToSnake(s string) string { + var result []rune + for i, r := range s { + if i > 0 && r >= 'A' && r <= 'Z' { + result = append(result, '_') + } + result = append(result, r) + } + return string(result) +} + +func CopyNonNilFields(src, dst interface{}) error { + config := &mapstructure.DecoderConfig{ + Result: dst, + TagName: "json", + ZeroFields: false, // 重要:不清零目标字段 + Squash: false, + } + + decoder, err := mapstructure.NewDecoder(config) + if err != nil { + return err + } + + return decoder.Decode(src) +} + +func DownloadFile(url string, saveDir string, filename string) (string, error) { + os.MkdirAll(saveDir, 0755) + + if filename == "" { + filename = uuid.New().String() + ".docx" + } + + filePath := filepath.Join(saveDir, filename) + + resp, err := http.Get(url) + if err != nil { + return "", err + } + defer resp.Body.Close() + + out, err := os.Create(filePath) + if err != nil { + return "", err + } + defer out.Close() + + _, err = io.Copy(out, resp.Body) + if err != nil { + return "", err + } + + absPath, _ := filepath.Abs(filePath) + return absPath, nil +} + +func DownloadImage(url string, requestID string, dir string) (string, error) { + os.MkdirAll(dir, 0755) + + ext := filepath.Ext(url) + if ext == "" { + ext = ".jpg" + } + filename := requestID + "_" + uuid.New().String() + ext + filePath := filepath.Join(dir, filename) + + resp, err := http.Get(url) + if err != nil { + return "", err + } + defer resp.Body.Close() + + out, err := os.Create(filePath) + if err != nil { + return "", err + } + defer out.Close() + + _, err = io.Copy(out, resp.Body) + if err != nil { + return "", err + } + + return filepath.Abs(filePath) +} + +func DeleteFile(path string) { + if path != "" { + os.Remove(path) + } +} + +func GenerateUUID() string { + return uuid.New().String() +} + +func GenerateUserIndex() string { + return uuid.New().String()[:20] +} diff --git a/pkg/mapstructure/decode_hooks.go b/pkg/mapstructure/decode_hooks.go new file mode 100644 index 0000000..3a754ca --- /dev/null +++ b/pkg/mapstructure/decode_hooks.go @@ -0,0 +1,279 @@ +package mapstructure + +import ( + "encoding" + "errors" + "fmt" + "net" + "reflect" + "strconv" + "strings" + "time" +) + +// typedDecodeHook takes a raw DecodeHookFunc (an interface{}) and turns +// it into the proper DecodeHookFunc type, such as DecodeHookFuncType. +func typedDecodeHook(h DecodeHookFunc) DecodeHookFunc { + // Create variables here so we can reference them with the reflect pkg + var f1 DecodeHookFuncType + var f2 DecodeHookFuncKind + var f3 DecodeHookFuncValue + + // Fill in the variables into this interface and the rest is done + // automatically using the reflect package. + potential := []interface{}{f1, f2, f3} + + v := reflect.ValueOf(h) + vt := v.Type() + for _, raw := range potential { + pt := reflect.ValueOf(raw).Type() + if vt.ConvertibleTo(pt) { + return v.Convert(pt).Interface() + } + } + + return nil +} + +// DecodeHookExec executes the given decode hook. This should be used +// since it'll naturally degrade to the older backwards compatible DecodeHookFunc +// that took reflect.Kind instead of reflect.Type. +func DecodeHookExec( + raw DecodeHookFunc, + from reflect.Value, to reflect.Value) (interface{}, error) { + + switch f := typedDecodeHook(raw).(type) { + case DecodeHookFuncType: + return f(from.Type(), to.Type(), from.Interface()) + case DecodeHookFuncKind: + return f(from.Kind(), to.Kind(), from.Interface()) + case DecodeHookFuncValue: + return f(from, to) + default: + return nil, errors.New("invalid decode hook signature") + } +} + +// ComposeDecodeHookFunc creates a single DecodeHookFunc that +// automatically composes multiple DecodeHookFuncs. +// +// The composed funcs are called in order, with the result of the +// previous transformation. +func ComposeDecodeHookFunc(fs ...DecodeHookFunc) DecodeHookFunc { + return func(f reflect.Value, t reflect.Value) (interface{}, error) { + var err error + data := f.Interface() + + newFrom := f + for _, f1 := range fs { + data, err = DecodeHookExec(f1, newFrom, t) + if err != nil { + return nil, err + } + newFrom = reflect.ValueOf(data) + } + + return data, nil + } +} + +// OrComposeDecodeHookFunc executes all input hook functions until one of them returns no error. In that case its value is returned. +// If all hooks return an error, OrComposeDecodeHookFunc returns an error concatenating all error messages. +func OrComposeDecodeHookFunc(ff ...DecodeHookFunc) DecodeHookFunc { + return func(a, b reflect.Value) (interface{}, error) { + var allErrs string + var out interface{} + var err error + + for _, f := range ff { + out, err = DecodeHookExec(f, a, b) + if err != nil { + allErrs += err.Error() + "\n" + continue + } + + return out, nil + } + + return nil, errors.New(allErrs) + } +} + +// StringToSliceHookFunc returns a DecodeHookFunc that converts +// string to []string by splitting on the given sep. +func StringToSliceHookFunc(sep string) DecodeHookFunc { + return func( + f reflect.Kind, + t reflect.Kind, + data interface{}) (interface{}, error) { + if f != reflect.String || t != reflect.Slice { + return data, nil + } + + raw := data.(string) + if raw == "" { + return []string{}, nil + } + + return strings.Split(raw, sep), nil + } +} + +// StringToTimeDurationHookFunc returns a DecodeHookFunc that converts +// strings to time.Duration. +func StringToTimeDurationHookFunc() DecodeHookFunc { + return func( + f reflect.Type, + t reflect.Type, + data interface{}) (interface{}, error) { + if f.Kind() != reflect.String { + return data, nil + } + if t != reflect.TypeOf(time.Duration(5)) { + return data, nil + } + + // Convert it by parsing + return time.ParseDuration(data.(string)) + } +} + +// StringToIPHookFunc returns a DecodeHookFunc that converts +// strings to net.IP +func StringToIPHookFunc() DecodeHookFunc { + return func( + f reflect.Type, + t reflect.Type, + data interface{}) (interface{}, error) { + if f.Kind() != reflect.String { + return data, nil + } + if t != reflect.TypeOf(net.IP{}) { + return data, nil + } + + // Convert it by parsing + ip := net.ParseIP(data.(string)) + if ip == nil { + return net.IP{}, fmt.Errorf("failed parsing ip %v", data) + } + + return ip, nil + } +} + +// StringToIPNetHookFunc returns a DecodeHookFunc that converts +// strings to net.IPNet +func StringToIPNetHookFunc() DecodeHookFunc { + return func( + f reflect.Type, + t reflect.Type, + data interface{}) (interface{}, error) { + if f.Kind() != reflect.String { + return data, nil + } + if t != reflect.TypeOf(net.IPNet{}) { + return data, nil + } + + // Convert it by parsing + _, net, err := net.ParseCIDR(data.(string)) + return net, err + } +} + +// StringToTimeHookFunc returns a DecodeHookFunc that converts +// strings to time.Time. +func StringToTimeHookFunc(layout string) DecodeHookFunc { + return func( + f reflect.Type, + t reflect.Type, + data interface{}) (interface{}, error) { + if f.Kind() != reflect.String { + return data, nil + } + if t != reflect.TypeOf(time.Time{}) { + return data, nil + } + + // Convert it by parsing + return time.Parse(layout, data.(string)) + } +} + +// WeaklyTypedHook is a DecodeHookFunc which adds support for weak typing to +// the decoder. +// +// Note that this is significantly different from the WeaklyTypedInput option +// of the DecoderConfig. +func WeaklyTypedHook( + f reflect.Kind, + t reflect.Kind, + data interface{}) (interface{}, error) { + dataVal := reflect.ValueOf(data) + switch t { + case reflect.String: + switch f { + case reflect.Bool: + if dataVal.Bool() { + return "1", nil + } + return "0", nil + case reflect.Float32: + return strconv.FormatFloat(dataVal.Float(), 'f', -1, 64), nil + case reflect.Int: + return strconv.FormatInt(dataVal.Int(), 10), nil + case reflect.Slice: + dataType := dataVal.Type() + elemKind := dataType.Elem().Kind() + if elemKind == reflect.Uint8 { + return string(dataVal.Interface().([]uint8)), nil + } + case reflect.Uint: + return strconv.FormatUint(dataVal.Uint(), 10), nil + } + } + + return data, nil +} + +func RecursiveStructToMapHookFunc() DecodeHookFunc { + return func(f reflect.Value, t reflect.Value) (interface{}, error) { + if f.Kind() != reflect.Struct { + return f.Interface(), nil + } + + var i interface{} = struct{}{} + if t.Type() != reflect.TypeOf(&i).Elem() { + return f.Interface(), nil + } + + m := make(map[string]interface{}) + t.Set(reflect.ValueOf(m)) + + return f.Interface(), nil + } +} + +// TextUnmarshallerHookFunc returns a DecodeHookFunc that applies +// strings to the UnmarshalText function, when the target type +// implements the encoding.TextUnmarshaler interface +func TextUnmarshallerHookFunc() DecodeHookFuncType { + return func( + f reflect.Type, + t reflect.Type, + data interface{}) (interface{}, error) { + if f.Kind() != reflect.String { + return data, nil + } + result := reflect.New(t).Interface() + unmarshaller, ok := result.(encoding.TextUnmarshaler) + if !ok { + return data, nil + } + if err := unmarshaller.UnmarshalText([]byte(data.(string))); err != nil { + return nil, err + } + return result, nil + } +} diff --git a/pkg/mapstructure/decode_hooks_test.go b/pkg/mapstructure/decode_hooks_test.go new file mode 100644 index 0000000..07fbedf --- /dev/null +++ b/pkg/mapstructure/decode_hooks_test.go @@ -0,0 +1,567 @@ +package mapstructure + +import ( + "errors" + "math/big" + "net" + "reflect" + "testing" + "time" +) + +func TestComposeDecodeHookFunc(t *testing.T) { + f1 := func( + f reflect.Kind, + t reflect.Kind, + data interface{}) (interface{}, error) { + return data.(string) + "foo", nil + } + + f2 := func( + f reflect.Kind, + t reflect.Kind, + data interface{}) (interface{}, error) { + return data.(string) + "bar", nil + } + + f := ComposeDecodeHookFunc(f1, f2) + + result, err := DecodeHookExec( + f, reflect.ValueOf(""), reflect.ValueOf([]byte(""))) + if err != nil { + t.Fatalf("bad: %s", err) + } + if result.(string) != "foobar" { + t.Fatalf("bad: %#v", result) + } +} + +func TestComposeDecodeHookFunc_err(t *testing.T) { + f1 := func(reflect.Kind, reflect.Kind, interface{}) (interface{}, error) { + return nil, errors.New("foo") + } + + f2 := func(reflect.Kind, reflect.Kind, interface{}) (interface{}, error) { + panic("NOPE") + } + + f := ComposeDecodeHookFunc(f1, f2) + + _, err := DecodeHookExec( + f, reflect.ValueOf(""), reflect.ValueOf([]byte(""))) + if err.Error() != "foo" { + t.Fatalf("bad: %s", err) + } +} + +func TestComposeDecodeHookFunc_kinds(t *testing.T) { + var f2From reflect.Kind + + f1 := func( + f reflect.Kind, + t reflect.Kind, + data interface{}) (interface{}, error) { + return int(42), nil + } + + f2 := func( + f reflect.Kind, + t reflect.Kind, + data interface{}) (interface{}, error) { + f2From = f + return data, nil + } + + f := ComposeDecodeHookFunc(f1, f2) + + _, err := DecodeHookExec( + f, reflect.ValueOf(""), reflect.ValueOf([]byte(""))) + if err != nil { + t.Fatalf("bad: %s", err) + } + if f2From != reflect.Int { + t.Fatalf("bad: %#v", f2From) + } +} + +func TestOrComposeDecodeHookFunc(t *testing.T) { + f1 := func( + f reflect.Kind, + t reflect.Kind, + data interface{}) (interface{}, error) { + return data.(string) + "foo", nil + } + + f2 := func( + f reflect.Kind, + t reflect.Kind, + data interface{}) (interface{}, error) { + return data.(string) + "bar", nil + } + + f := OrComposeDecodeHookFunc(f1, f2) + + result, err := DecodeHookExec( + f, reflect.ValueOf(""), reflect.ValueOf([]byte(""))) + if err != nil { + t.Fatalf("bad: %s", err) + } + if result.(string) != "foo" { + t.Fatalf("bad: %#v", result) + } +} + +func TestOrComposeDecodeHookFunc_correctValueIsLast(t *testing.T) { + f1 := func( + f reflect.Kind, + t reflect.Kind, + data interface{}) (interface{}, error) { + return nil, errors.New("f1 error") + } + + f2 := func( + f reflect.Kind, + t reflect.Kind, + data interface{}) (interface{}, error) { + return nil, errors.New("f2 error") + } + + f3 := func( + f reflect.Kind, + t reflect.Kind, + data interface{}) (interface{}, error) { + return data.(string) + "bar", nil + } + + f := OrComposeDecodeHookFunc(f1, f2, f3) + + result, err := DecodeHookExec( + f, reflect.ValueOf(""), reflect.ValueOf([]byte(""))) + if err != nil { + t.Fatalf("bad: %s", err) + } + if result.(string) != "bar" { + t.Fatalf("bad: %#v", result) + } +} + +func TestOrComposeDecodeHookFunc_err(t *testing.T) { + f1 := func( + f reflect.Kind, + t reflect.Kind, + data interface{}) (interface{}, error) { + return nil, errors.New("f1 error") + } + + f2 := func( + f reflect.Kind, + t reflect.Kind, + data interface{}) (interface{}, error) { + return nil, errors.New("f2 error") + } + + f := OrComposeDecodeHookFunc(f1, f2) + + _, err := DecodeHookExec( + f, reflect.ValueOf(""), reflect.ValueOf([]byte(""))) + if err == nil { + t.Fatalf("bad: should return an error") + } + if err.Error() != "f1 error\nf2 error\n" { + t.Fatalf("bad: %s", err) + } +} + +func TestComposeDecodeHookFunc_safe_nofuncs(t *testing.T) { + f := ComposeDecodeHookFunc() + type myStruct2 struct { + MyInt int + } + + type myStruct1 struct { + Blah map[string]myStruct2 + } + + src := &myStruct1{Blah: map[string]myStruct2{ + "test": { + MyInt: 1, + }, + }} + + dst := &myStruct1{} + dConf := &DecoderConfig{ + Result: dst, + ErrorUnused: true, + DecodeHook: f, + } + d, err := NewDecoder(dConf) + if err != nil { + t.Fatal(err) + } + err = d.Decode(src) + if err != nil { + t.Fatal(err) + } +} + +func TestStringToSliceHookFunc(t *testing.T) { + f := StringToSliceHookFunc(",") + + strValue := reflect.ValueOf("42") + sliceValue := reflect.ValueOf([]byte("42")) + cases := []struct { + f, t reflect.Value + result interface{} + err bool + }{ + {sliceValue, sliceValue, []byte("42"), false}, + {strValue, strValue, "42", false}, + { + reflect.ValueOf("foo,bar,baz"), + sliceValue, + []string{"foo", "bar", "baz"}, + false, + }, + { + reflect.ValueOf(""), + sliceValue, + []string{}, + false, + }, + } + + for i, tc := range cases { + actual, err := DecodeHookExec(f, tc.f, tc.t) + if tc.err != (err != nil) { + t.Fatalf("case %d: expected jderr %#v", i, tc.err) + } + if !reflect.DeepEqual(actual, tc.result) { + t.Fatalf( + "case %d: expected %#v, got %#v", + i, tc.result, actual) + } + } +} + +func TestStringToTimeDurationHookFunc(t *testing.T) { + f := StringToTimeDurationHookFunc() + + timeValue := reflect.ValueOf(time.Duration(5)) + strValue := reflect.ValueOf("") + cases := []struct { + f, t reflect.Value + result interface{} + err bool + }{ + {reflect.ValueOf("5s"), timeValue, 5 * time.Second, false}, + {reflect.ValueOf("5"), timeValue, time.Duration(0), true}, + {reflect.ValueOf("5"), strValue, "5", false}, + } + + for i, tc := range cases { + actual, err := DecodeHookExec(f, tc.f, tc.t) + if tc.err != (err != nil) { + t.Fatalf("case %d: expected jderr %#v", i, tc.err) + } + if !reflect.DeepEqual(actual, tc.result) { + t.Fatalf( + "case %d: expected %#v, got %#v", + i, tc.result, actual) + } + } +} + +func TestStringToTimeHookFunc(t *testing.T) { + strValue := reflect.ValueOf("5") + timeValue := reflect.ValueOf(time.Time{}) + cases := []struct { + f, t reflect.Value + layout string + result interface{} + err bool + }{ + {reflect.ValueOf("2006-01-02T15:04:05Z"), timeValue, time.RFC3339, + time.Date(2006, 1, 2, 15, 4, 5, 0, time.UTC), false}, + {strValue, timeValue, time.RFC3339, time.Time{}, true}, + {strValue, strValue, time.RFC3339, "5", false}, + } + + for i, tc := range cases { + f := StringToTimeHookFunc(tc.layout) + actual, err := DecodeHookExec(f, tc.f, tc.t) + if tc.err != (err != nil) { + t.Fatalf("case %d: expected jderr %#v", i, tc.err) + } + if !reflect.DeepEqual(actual, tc.result) { + t.Fatalf( + "case %d: expected %#v, got %#v", + i, tc.result, actual) + } + } +} + +func TestStringToIPHookFunc(t *testing.T) { + strValue := reflect.ValueOf("5") + ipValue := reflect.ValueOf(net.IP{}) + cases := []struct { + f, t reflect.Value + result interface{} + err bool + }{ + {reflect.ValueOf("1.2.3.4"), ipValue, + net.IPv4(0x01, 0x02, 0x03, 0x04), false}, + {strValue, ipValue, net.IP{}, true}, + {strValue, strValue, "5", false}, + } + + for i, tc := range cases { + f := StringToIPHookFunc() + actual, err := DecodeHookExec(f, tc.f, tc.t) + if tc.err != (err != nil) { + t.Fatalf("case %d: expected jderr %#v", i, tc.err) + } + if !reflect.DeepEqual(actual, tc.result) { + t.Fatalf( + "case %d: expected %#v, got %#v", + i, tc.result, actual) + } + } +} + +func TestStringToIPNetHookFunc(t *testing.T) { + strValue := reflect.ValueOf("5") + ipNetValue := reflect.ValueOf(net.IPNet{}) + var nilNet *net.IPNet = nil + + cases := []struct { + f, t reflect.Value + result interface{} + err bool + }{ + {reflect.ValueOf("1.2.3.4/24"), ipNetValue, + &net.IPNet{ + IP: net.IP{0x01, 0x02, 0x03, 0x00}, + Mask: net.IPv4Mask(0xff, 0xff, 0xff, 0x00), + }, false}, + {strValue, ipNetValue, nilNet, true}, + {strValue, strValue, "5", false}, + } + + for i, tc := range cases { + f := StringToIPNetHookFunc() + actual, err := DecodeHookExec(f, tc.f, tc.t) + if tc.err != (err != nil) { + t.Fatalf("case %d: expected jderr %#v", i, tc.err) + } + if !reflect.DeepEqual(actual, tc.result) { + t.Fatalf( + "case %d: expected %#v, got %#v", + i, tc.result, actual) + } + } +} + +func TestWeaklyTypedHook(t *testing.T) { + var f DecodeHookFunc = WeaklyTypedHook + + strValue := reflect.ValueOf("") + cases := []struct { + f, t reflect.Value + result interface{} + err bool + }{ + // TO STRING + { + reflect.ValueOf(false), + strValue, + "0", + false, + }, + + { + reflect.ValueOf(true), + strValue, + "1", + false, + }, + + { + reflect.ValueOf(float32(7)), + strValue, + "7", + false, + }, + + { + reflect.ValueOf(int(7)), + strValue, + "7", + false, + }, + + { + reflect.ValueOf([]uint8("foo")), + strValue, + "foo", + false, + }, + + { + reflect.ValueOf(uint(7)), + strValue, + "7", + false, + }, + } + + for i, tc := range cases { + actual, err := DecodeHookExec(f, tc.f, tc.t) + if tc.err != (err != nil) { + t.Fatalf("case %d: expected jderr %#v", i, tc.err) + } + if !reflect.DeepEqual(actual, tc.result) { + t.Fatalf( + "case %d: expected %#v, got %#v", + i, tc.result, actual) + } + } +} + +func TestStructToMapHookFuncTabled(t *testing.T) { + var f DecodeHookFunc = RecursiveStructToMapHookFunc() + + type b struct { + TestKey string + } + + type a struct { + Sub b + } + + testStruct := a{ + Sub: b{ + TestKey: "testval", + }, + } + + testMap := map[string]interface{}{ + "Sub": map[string]interface{}{ + "TestKey": "testval", + }, + } + + cases := []struct { + name string + receiver interface{} + input interface{} + expected interface{} + err bool + }{ + { + "map receiver", + func() interface{} { + var res map[string]interface{} + return &res + }(), + testStruct, + &testMap, + false, + }, + { + "interface receiver", + func() interface{} { + var res interface{} + return &res + }(), + testStruct, + func() interface{} { + var exp interface{} = testMap + return &exp + }(), + false, + }, + { + "slice receiver errors", + func() interface{} { + var res []string + return &res + }(), + testStruct, + new([]string), + true, + }, + { + "slice to slice - no change", + func() interface{} { + var res []string + return &res + }(), + []string{"a", "b"}, + &[]string{"a", "b"}, + false, + }, + { + "string to string - no change", + func() interface{} { + var res string + return &res + }(), + "test", + func() *string { + s := "test" + return &s + }(), + false, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + cfg := &DecoderConfig{ + DecodeHook: f, + Result: tc.receiver, + } + + d, err := NewDecoder(cfg) + if err != nil { + t.Fatalf("unexpected jderr %#v", err) + } + + err = d.Decode(tc.input) + if tc.err != (err != nil) { + t.Fatalf("expected jderr %#v", err) + } + + if !reflect.DeepEqual(tc.expected, tc.receiver) { + t.Fatalf("expected %#v, got %#v", + tc.expected, tc.receiver) + } + }) + + } +} + +func TestTextUnmarshallerHookFunc(t *testing.T) { + cases := []struct { + f, t reflect.Value + result interface{} + err bool + }{ + {reflect.ValueOf("42"), reflect.ValueOf(big.Int{}), big.NewInt(42), false}, + {reflect.ValueOf("invalid"), reflect.ValueOf(big.Int{}), nil, true}, + {reflect.ValueOf("5"), reflect.ValueOf("5"), "5", false}, + } + + for i, tc := range cases { + f := TextUnmarshallerHookFunc() + actual, err := DecodeHookExec(f, tc.f, tc.t) + if tc.err != (err != nil) { + t.Fatalf("case %d: expected jderr %#v", i, tc.err) + } + if !reflect.DeepEqual(actual, tc.result) { + t.Fatalf( + "case %d: expected %#v, got %#v", + i, tc.result, actual) + } + } +} diff --git a/pkg/mapstructure/error.go b/pkg/mapstructure/error.go new file mode 100644 index 0000000..47a99e5 --- /dev/null +++ b/pkg/mapstructure/error.go @@ -0,0 +1,50 @@ +package mapstructure + +import ( + "errors" + "fmt" + "sort" + "strings" +) + +// Error implements the error interface and can represents multiple +// errors that occur in the course of a single decode. +type Error struct { + Errors []string +} + +func (e *Error) Error() string { + points := make([]string, len(e.Errors)) + for i, err := range e.Errors { + points[i] = fmt.Sprintf("* %s", err) + } + + sort.Strings(points) + return fmt.Sprintf( + "%d error(s) decoding:\n\n%s", + len(e.Errors), strings.Join(points, "\n")) +} + +// WrappedErrors implements the errwrap.Wrapper interface to make this +// return value more useful with the errwrap and go-multierror libraries. +func (e *Error) WrappedErrors() []error { + if e == nil { + return nil + } + + result := make([]error, len(e.Errors)) + for i, e := range e.Errors { + result[i] = errors.New(e) + } + + return result +} + +func appendErrors(errors []string, err error) []string { + switch e := err.(type) { + case *Error: + return append(errors, e.Errors...) + default: + return append(errors, e.Error()) + } +} diff --git a/pkg/mapstructure/mapstructure.go b/pkg/mapstructure/mapstructure.go new file mode 100644 index 0000000..0d26c75 --- /dev/null +++ b/pkg/mapstructure/mapstructure.go @@ -0,0 +1,1386 @@ +package mapstructure + +import ( + "encoding/json" + "errors" + "fmt" + "reflect" + "sort" + "strconv" + "strings" +) + +// DecodeHookFunc is the callback function that can be used for +// data transformations. See "DecodeHook" in the DecoderConfig +// struct. +// +// The type must be one of DecodeHookFuncType, DecodeHookFuncKind, or +// DecodeHookFuncValue. +// Values are a superset of Types (Values can return types), and Types are a +// superset of Kinds (Types can return Kinds) and are generally a richer thing +// to use, but Kinds are simpler if you only need those. +// +// The reason DecodeHookFunc is multi-typed is for backwards compatibility: +// we started with Kinds and then realized Types were the better solution, +// but have a promise to not break backwards compat so we now support +// both. +type DecodeHookFunc interface{} + +// DecodeHookFuncType is a DecodeHookFunc which has complete information about +// the source and target types. +type DecodeHookFuncType func(reflect.Type, reflect.Type, interface{}) (interface{}, error) + +// DecodeHookFuncKind is a DecodeHookFunc which knows only the Kinds of the +// source and target types. +type DecodeHookFuncKind func(reflect.Kind, reflect.Kind, interface{}) (interface{}, error) + +// DecodeHookFuncValue is a DecodeHookFunc which has complete access to both the source and target +// values. +type DecodeHookFuncValue func(from reflect.Value, to reflect.Value) (interface{}, error) + +// DecoderConfig is the configuration that is used to create a new decoder +// and allows customization of various aspects of decoding. +type DecoderConfig struct { + // DecodeHook, if set, will be called before any decoding and any + // type conversion (if WeaklyTypedInput is on). This lets you modify + // the values before they're set down onto the resulting struct. The + // DecodeHook is called for every map and value in the input. This means + // that if a struct has embedded fields with squash tags the decode hook + // is called only once with all of the input data, not once for each + // embedded struct. + // + // If an error is returned, the entire decode will fail with that error. + DecodeHook DecodeHookFunc + + // If ErrorUnused is true, then it is an error for there to exist + // keys in the original map that were unused in the decoding process + // (extra keys). + ErrorUnused bool + + // If ErrorUnset is true, then it is an error for there to exist + // fields in the result that were not set in the decoding process + // (extra fields). This only applies to decoding to a struct. This + // will affect all nested structs as well. + ErrorUnset bool + + // ZeroFields, if set to true, will zero fields before writing them. + // For example, a map will be emptied before decoded values are put in + // it. If this is false, a map will be merged. + ZeroFields bool + + // If WeaklyTypedInput is true, the decoder will make the following + // "weak" conversions: + // + // - bools to string (true = "1", false = "0") + // - numbers to string (base 10) + // - bools to int/uint (true = 1, false = 0) + // - strings to int/uint (base implied by prefix) + // - int to bool (true if value != 0) + // - string to bool (accepts: 1, t, T, TRUE, true, True, 0, f, F, + // FALSE, false, False. Anything else is an error) + // - empty array = empty map and vice versa + // - negative numbers to overflowed uint values (base 10) + // - slice of maps to a merged map + // - single values are converted to slices if required. Each + // element is weakly decoded. For example: "4" can become []int{4} + // if the target type is an int slice. + // + WeaklyTypedInput bool + + // Squash will squash embedded structs. A squash tag may also be + // added to an individual struct field using a tag. For example: + // + // type Parent struct { + // Child `mapstructure:",squash"` + // } + Squash bool + + // Metadata is the struct that will contain extra metadata about + // the decoding. If this is nil, then no metadata will be tracked. + Metadata *Metadata + + // Result is a pointer to the struct that will contain the decoded + // value. + Result interface{} + + // The tag name that mapstructure reads for field names. This + // defaults to "mapstructure" + TagName string + + // IgnoreUntaggedFields ignores all struct fields without explicit + // TagName, comparable to `mapstructure:"-"` as default behaviour. + IgnoreUntaggedFields bool + + // MatchName is the function used to match the map key to the struct + // field name or tag. Defaults to `strings.EqualFold`. This can be used + // to implement case-sensitive tag values, support snake casing, etc. + MatchName func(mapKey, fieldName string) bool +} + +// A Decoder takes a raw interface value and turns it into structured +// data, keeping track of rich error information along the way in case +// anything goes wrong. Unlike the basic top-level Decode method, you can +// more finely control how the Decoder behaves using the DecoderConfig +// structure. The top-level Decode method is just a convenience that sets +// up the most basic Decoder. +type Decoder struct { + config *DecoderConfig +} + +// Metadata contains information about decoding a structure that +// is tedious or difficult to get otherwise. +type Metadata struct { + // Keys are the keys of the structure which were successfully decoded + Keys []string + + // Unused is a slice of keys that were found in the raw value but + // weren't decoded since there was no matching field in the result interface + Unused []string + + // Unset is a slice of field names that were found in the result interface + // but weren't set in the decoding process since there was no matching value + // in the input + Unset []string +} + +// Decode takes an input structure and uses reflection to translate it to +// the output structure. output must be a pointer to a map or struct. +func Decode(input interface{}, output interface{}) error { + config := &DecoderConfig{ + Metadata: nil, + Result: output, + } + + decoder, err := NewDecoder(config) + if err != nil { + return err + } + + return decoder.Decode(input) +} + +// WeakDecode is the same as Decode but is shorthand to enable +// WeaklyTypedInput. See DecoderConfig for more info. +func WeakDecode(input, output interface{}) error { + config := &DecoderConfig{ + Metadata: nil, + Result: output, + WeaklyTypedInput: true, + } + + decoder, err := NewDecoder(config) + if err != nil { + return err + } + + return decoder.Decode(input) +} + +// DecodeMetadata is the same as Decode, but is shorthand to +// enable metadata collection. See DecoderConfig for more info. +func DecodeMetadata(input interface{}, output interface{}, metadata *Metadata) error { + config := &DecoderConfig{ + Metadata: metadata, + Result: output, + } + + decoder, err := NewDecoder(config) + if err != nil { + return err + } + + return decoder.Decode(input) +} + +// WeakDecodeMetadata is the same as Decode, but is shorthand to +// enable both WeaklyTypedInput and metadata collection. See +// DecoderConfig for more info. +func WeakDecodeMetadata(input interface{}, output interface{}, metadata *Metadata) error { + config := &DecoderConfig{ + Metadata: metadata, + Result: output, + WeaklyTypedInput: true, + } + + decoder, err := NewDecoder(config) + if err != nil { + return err + } + + return decoder.Decode(input) +} + +// NewDecoder returns a new decoder for the given configuration. Once +// a decoder has been returned, the same configuration must not be used +// again. +func NewDecoder(config *DecoderConfig) (*Decoder, error) { + val := reflect.ValueOf(config.Result) + if val.Kind() != reflect.Ptr { + return nil, errors.New("result must be a pointer") + } + + val = val.Elem() + if !val.CanAddr() { + return nil, errors.New("result must be addressable (a pointer)") + } + + if config.Metadata != nil { + if config.Metadata.Keys == nil { + config.Metadata.Keys = make([]string, 0) + } + + if config.Metadata.Unused == nil { + config.Metadata.Unused = make([]string, 0) + } + + if config.Metadata.Unset == nil { + config.Metadata.Unset = make([]string, 0) + } + } + + if config.TagName == "" { + config.TagName = "mapstructure" + } + + if config.MatchName == nil { + config.MatchName = strings.EqualFold + } + + result := &Decoder{ + config: config, + } + + return result, nil +} + +// Decode decodes the given raw interface to the target pointer specified +// by the configuration. +func (d *Decoder) Decode(input interface{}) error { + return d.decode("", input, reflect.ValueOf(d.config.Result).Elem()) +} + +// Decodes an unknown data type into a specific reflection value. +func (d *Decoder) decode(name string, input interface{}, outVal reflect.Value) error { + var inputVal reflect.Value + if input != nil { + inputVal = reflect.ValueOf(input) + + // We need to check here if input is a typed nil. Typed nils won't + // match the "input == nil" below so we check that here. + if inputVal.Kind() == reflect.Ptr && inputVal.IsNil() { + input = nil + } + } + + if input == nil { + // If the data is nil, then we don't set anything, unless ZeroFields is set + // to true. + if d.config.ZeroFields { + outVal.Set(reflect.Zero(outVal.Type())) + + if d.config.Metadata != nil && name != "" { + d.config.Metadata.Keys = append(d.config.Metadata.Keys, name) + } + } + return nil + } + + if !inputVal.IsValid() { + // If the input value is invalid, then we just set the value + // to be the zero value. + outVal.Set(reflect.Zero(outVal.Type())) + if d.config.Metadata != nil && name != "" { + d.config.Metadata.Keys = append(d.config.Metadata.Keys, name) + } + return nil + } + + if d.config.DecodeHook != nil { + // We have a DecodeHook, so let's pre-process the input. + var err error + input, err = DecodeHookExec(d.config.DecodeHook, inputVal, outVal) + if err != nil { + return fmt.Errorf("error decoding '%s': %s", name, err) + } + } + + var err error + outputKind := getKind(outVal) + addMetaKey := true + switch outputKind { + case reflect.Bool: + err = d.decodeBool(name, input, outVal) + case reflect.Interface: + err = d.decodeBasic(name, input, outVal) + case reflect.String: + err = d.decodeString(name, input, outVal) + case reflect.Int: + err = d.decodeInt(name, input, outVal) + case reflect.Uint: + err = d.decodeUint(name, input, outVal) + case reflect.Float32: + err = d.decodeFloat(name, input, outVal) + case reflect.Struct: + err = d.decodeStruct(name, input, outVal) + case reflect.Map: + err = d.decodeMap(name, input, outVal) + case reflect.Ptr: + addMetaKey, err = d.decodePtr(name, input, outVal) + case reflect.Slice: + err = d.decodeSlice(name, input, outVal) + case reflect.Array: + err = d.decodeArray(name, input, outVal) + case reflect.Func: + err = d.decodeFunc(name, input, outVal) + default: + // If we reached this point then we weren't able to decode it + return fmt.Errorf("%s: unsupported type: %s", name, outputKind) + } + + // If we reached here, then we successfully decoded SOMETHING, so + // mark the key as used if we're tracking metainput. + if addMetaKey && d.config.Metadata != nil && name != "" { + d.config.Metadata.Keys = append(d.config.Metadata.Keys, name) + } + + return err +} + +// This decodes a basic type (bool, int, string, etc.) and sets the +// value to "data" of that type. +func (d *Decoder) decodeBasic(name string, data interface{}, val reflect.Value) error { + if val.IsValid() && val.Elem().IsValid() { + elem := val.Elem() + + // If we can't address this element, then its not writable. Instead, + // we make a copy of the value (which is a pointer and therefore + // writable), decode into that, and replace the whole value. + copied := false + if !elem.CanAddr() { + copied = true + + // Make *T + copy := reflect.New(elem.Type()) + + // *T = elem + copy.Elem().Set(elem) + + // Set elem so we decode into it + elem = copy + } + + // Decode. If we have an error then return. We also return right + // away if we're not a copy because that means we decoded directly. + if err := d.decode(name, data, elem); err != nil || !copied { + return err + } + + // If we're a copy, we need to set te final result + val.Set(elem.Elem()) + return nil + } + + dataVal := reflect.ValueOf(data) + + // If the input data is a pointer, and the assigned type is the dereference + // of that exact pointer, then indirect it so that we can assign it. + // Example: *string to string + if dataVal.Kind() == reflect.Ptr && dataVal.Type().Elem() == val.Type() { + dataVal = reflect.Indirect(dataVal) + } + + if !dataVal.IsValid() { + dataVal = reflect.Zero(val.Type()) + } + + dataValType := dataVal.Type() + if !dataValType.AssignableTo(val.Type()) { + return fmt.Errorf( + "'%s' expected type '%s', got '%s'", + name, val.Type(), dataValType) + } + + val.Set(dataVal) + return nil +} + +func (d *Decoder) decodeString(name string, data interface{}, val reflect.Value) error { + dataVal := reflect.Indirect(reflect.ValueOf(data)) + dataKind := getKind(dataVal) + + converted := true + switch { + case dataKind == reflect.String: + val.SetString(dataVal.String()) + case dataKind == reflect.Bool && d.config.WeaklyTypedInput: + if dataVal.Bool() { + val.SetString("1") + } else { + val.SetString("0") + } + case dataKind == reflect.Int && d.config.WeaklyTypedInput: + val.SetString(strconv.FormatInt(dataVal.Int(), 10)) + case dataKind == reflect.Uint && d.config.WeaklyTypedInput: + val.SetString(strconv.FormatUint(dataVal.Uint(), 10)) + case dataKind == reflect.Float32 && d.config.WeaklyTypedInput: + val.SetString(strconv.FormatFloat(dataVal.Float(), 'f', -1, 64)) + case dataKind == reflect.Slice && d.config.WeaklyTypedInput, + dataKind == reflect.Array && d.config.WeaklyTypedInput: + dataType := dataVal.Type() + elemKind := dataType.Elem().Kind() + switch elemKind { + case reflect.Uint8: + var uints []uint8 + if dataKind == reflect.Array { + uints = make([]uint8, dataVal.Len(), dataVal.Len()) + for i := range uints { + uints[i] = dataVal.Index(i).Interface().(uint8) + } + } else { + uints = dataVal.Interface().([]uint8) + } + val.SetString(string(uints)) + default: + converted = false + } + default: + converted = false + } + + if !converted { + return fmt.Errorf( + "'%s' expected type '%s', got unconvertible type '%s', value: '%v'", + name, val.Type(), dataVal.Type(), data) + } + + return nil +} + +func (d *Decoder) decodeInt(name string, data interface{}, val reflect.Value) error { + dataVal := reflect.Indirect(reflect.ValueOf(data)) + dataKind := getKind(dataVal) + dataType := dataVal.Type() + + switch { + case dataKind == reflect.Int: + val.SetInt(dataVal.Int()) + case dataKind == reflect.Uint: + val.SetInt(int64(dataVal.Uint())) + case dataKind == reflect.Float32: + val.SetInt(int64(dataVal.Float())) + case dataKind == reflect.Bool && d.config.WeaklyTypedInput: + if dataVal.Bool() { + val.SetInt(1) + } else { + val.SetInt(0) + } + case dataKind == reflect.String && d.config.WeaklyTypedInput: + str := dataVal.String() + if str == "" { + str = "0" + } + + i, err := strconv.ParseInt(str, 0, val.Type().Bits()) + if err == nil { + val.SetInt(i) + } else { + return fmt.Errorf("cannot parse '%s' as int: %s", name, err) + } + case dataType.PkgPath() == "encoding/json" && dataType.Name() == "Number": + jn := data.(json.Number) + i, err := jn.Int64() + if err != nil { + return fmt.Errorf( + "error decoding json.Number into %s: %s", name, err) + } + val.SetInt(i) + default: + return fmt.Errorf( + "'%s' expected type '%s', got unconvertible type '%s', value: '%v'", + name, val.Type(), dataVal.Type(), data) + } + + return nil +} + +func (d *Decoder) decodeUint(name string, data interface{}, val reflect.Value) error { + dataVal := reflect.Indirect(reflect.ValueOf(data)) + dataKind := getKind(dataVal) + dataType := dataVal.Type() + + switch { + case dataKind == reflect.Int: + i := dataVal.Int() + if i < 0 && !d.config.WeaklyTypedInput { + return fmt.Errorf("cannot parse '%s', %d overflows uint", + name, i) + } + val.SetUint(uint64(i)) + case dataKind == reflect.Uint: + val.SetUint(dataVal.Uint()) + case dataKind == reflect.Float32: + f := dataVal.Float() + if f < 0 && !d.config.WeaklyTypedInput { + return fmt.Errorf("cannot parse '%s', %f overflows uint", + name, f) + } + val.SetUint(uint64(f)) + case dataKind == reflect.Bool && d.config.WeaklyTypedInput: + if dataVal.Bool() { + val.SetUint(1) + } else { + val.SetUint(0) + } + case dataKind == reflect.String && d.config.WeaklyTypedInput: + str := dataVal.String() + if str == "" { + str = "0" + } + + i, err := strconv.ParseUint(str, 0, val.Type().Bits()) + if err == nil { + val.SetUint(i) + } else { + return fmt.Errorf("cannot parse '%s' as uint: %s", name, err) + } + case dataType.PkgPath() == "encoding/json" && dataType.Name() == "Number": + jn := data.(json.Number) + i, err := strconv.ParseUint(string(jn), 0, 64) + if err != nil { + return fmt.Errorf( + "error decoding json.Number into %s: %s", name, err) + } + val.SetUint(i) + default: + return fmt.Errorf( + "'%s' expected type '%s', got unconvertible type '%s', value: '%v'", + name, val.Type(), dataVal.Type(), data) + } + + return nil +} + +func (d *Decoder) decodeBool(name string, data interface{}, val reflect.Value) error { + dataVal := reflect.Indirect(reflect.ValueOf(data)) + dataKind := getKind(dataVal) + + switch { + case dataKind == reflect.Bool: + val.SetBool(dataVal.Bool()) + case dataKind == reflect.Int && d.config.WeaklyTypedInput: + val.SetBool(dataVal.Int() != 0) + case dataKind == reflect.Uint && d.config.WeaklyTypedInput: + val.SetBool(dataVal.Uint() != 0) + case dataKind == reflect.Float32 && d.config.WeaklyTypedInput: + val.SetBool(dataVal.Float() != 0) + case dataKind == reflect.String && d.config.WeaklyTypedInput: + b, err := strconv.ParseBool(dataVal.String()) + if err == nil { + val.SetBool(b) + } else if dataVal.String() == "" { + val.SetBool(false) + } else { + return fmt.Errorf("cannot parse '%s' as bool: %s", name, err) + } + default: + return fmt.Errorf( + "'%s' expected type '%s', got unconvertible type '%s', value: '%v'", + name, val.Type(), dataVal.Type(), data) + } + + return nil +} + +func (d *Decoder) decodeFloat(name string, data interface{}, val reflect.Value) error { + dataVal := reflect.Indirect(reflect.ValueOf(data)) + dataKind := getKind(dataVal) + dataType := dataVal.Type() + + switch { + case dataKind == reflect.Int: + val.SetFloat(float64(dataVal.Int())) + case dataKind == reflect.Uint: + val.SetFloat(float64(dataVal.Uint())) + case dataKind == reflect.Float32: + val.SetFloat(dataVal.Float()) + case dataKind == reflect.Bool && d.config.WeaklyTypedInput: + if dataVal.Bool() { + val.SetFloat(1) + } else { + val.SetFloat(0) + } + case dataKind == reflect.String && d.config.WeaklyTypedInput: + str := dataVal.String() + if str == "" { + str = "0" + } + + f, err := strconv.ParseFloat(str, val.Type().Bits()) + if err == nil { + val.SetFloat(f) + } else { + return fmt.Errorf("cannot parse '%s' as float: %s", name, err) + } + case dataType.PkgPath() == "encoding/json" && dataType.Name() == "Number": + jn := data.(json.Number) + i, err := jn.Float64() + if err != nil { + return fmt.Errorf( + "error decoding json.Number into %s: %s", name, err) + } + val.SetFloat(i) + default: + return fmt.Errorf( + "'%s' expected type '%s', got unconvertible type '%s', value: '%v'", + name, val.Type(), dataVal.Type(), data) + } + + return nil +} + +func (d *Decoder) decodeMap(name string, data interface{}, val reflect.Value) error { + valType := val.Type() + valKeyType := valType.Key() + valElemType := valType.Elem() + + // By default we overwrite keys in the current map + valMap := val + + // If the map is nil or we're purposely zeroing fields, make a new map + if valMap.IsNil() || d.config.ZeroFields { + // Make a new map to hold our result + mapType := reflect.MapOf(valKeyType, valElemType) + valMap = reflect.MakeMap(mapType) + } + + // Check input type and based on the input type jump to the proper func + dataVal := reflect.Indirect(reflect.ValueOf(data)) + switch dataVal.Kind() { + case reflect.Map: + return d.decodeMapFromMap(name, dataVal, val, valMap) + + case reflect.Struct: + return d.decodeMapFromStruct(name, dataVal, val, valMap, data) + + case reflect.Array, reflect.Slice: + if d.config.WeaklyTypedInput { + return d.decodeMapFromSlice(name, dataVal, val, valMap) + } + + fallthrough + + default: + return fmt.Errorf("'%s' expected a map, got '%s'", name, dataVal.Kind()) + } +} + +func (d *Decoder) decodeMapFromSlice(name string, dataVal reflect.Value, val reflect.Value, valMap reflect.Value) error { + // Special case for BC reasons (covered by tests) + if dataVal.Len() == 0 { + val.Set(valMap) + return nil + } + + for i := 0; i < dataVal.Len(); i++ { + err := d.decode( + name+"["+strconv.Itoa(i)+"]", + dataVal.Index(i).Interface(), val) + if err != nil { + return err + } + } + + return nil +} + +func (d *Decoder) decodeMapFromMap(name string, dataVal reflect.Value, val reflect.Value, valMap reflect.Value) error { + valType := val.Type() + valKeyType := valType.Key() + valElemType := valType.Elem() + + // Accumulate errors + errors := make([]string, 0) + + // If the input data is empty, then we just match what the input data is. + if dataVal.Len() == 0 { + if dataVal.IsNil() { + if !val.IsNil() { + val.Set(dataVal) + } + } else { + // Set to empty allocated value + val.Set(valMap) + } + + return nil + } + + for _, k := range dataVal.MapKeys() { + fieldName := name + "[" + k.String() + "]" + + // First decode the key into the proper type + currentKey := reflect.Indirect(reflect.New(valKeyType)) + if err := d.decode(fieldName, k.Interface(), currentKey); err != nil { + errors = appendErrors(errors, err) + continue + } + + // Next decode the data into the proper type + v := dataVal.MapIndex(k).Interface() + currentVal := reflect.Indirect(reflect.New(valElemType)) + if err := d.decode(fieldName, v, currentVal); err != nil { + errors = appendErrors(errors, err) + continue + } + + valMap.SetMapIndex(currentKey, currentVal) + } + + // Set the built up map to the value + val.Set(valMap) + + // If we had errors, return those + if len(errors) > 0 { + return &Error{errors} + } + + return nil +} + +func (d *Decoder) decodeMapFromStruct(name string, dataVal reflect.Value, val reflect.Value, valMap reflect.Value, inData interface{}) error { + typ := dataVal.Type() + + for i := 0; i < typ.NumField(); i++ { + // Get the StructField first since this is a cheap operation. If the + // field is unexported, then ignore it. + f := typ.Field(i) + if f.PkgPath != "" { + continue + } + + // Next get the actual value of this field and verify it is assignable + // to the map value. + v := dataVal.Field(i) + if !v.Type().AssignableTo(valMap.Type().Elem()) { + return fmt.Errorf("cannot assign type '%s' to map value field of type '%s'", v.Type(), valMap.Type().Elem()) + } + + tagValue := f.Tag.Get(d.config.TagName) + keyName := f.Name + + if tagValue == "" && d.config.IgnoreUntaggedFields { + continue + } + + // If Squash is set in the config, we squash the field down. + squash := d.config.Squash && v.Kind() == reflect.Struct && f.Anonymous + + v = dereferencePtrToStructIfNeeded(v, d.config.TagName) + + // Determine the name of the key in the map + if index := strings.Index(tagValue, ","); index != -1 { + if tagValue[:index] == "-" { + continue + } + // If "omitempty" is specified in the tag, it ignores empty values. + if strings.Index(tagValue[index+1:], "omitempty") != -1 && isEmptyValue(v) { + continue + } + + // If "squash" is specified in the tag, we squash the field down. + squash = squash || strings.Index(tagValue[index+1:], "squash") != -1 + if squash { + // When squashing, the embedded type can be a pointer to a struct. + if v.Kind() == reflect.Ptr && v.Elem().Kind() == reflect.Struct { + v = v.Elem() + } + + // The final type must be a struct + if v.Kind() != reflect.Struct { + return fmt.Errorf("cannot squash non-struct type '%s'", v.Type()) + } + } + if keyNameTagValue := tagValue[:index]; keyNameTagValue != "" { + keyName = keyNameTagValue + } + } else if len(tagValue) > 0 { + if tagValue == "-" { + continue + } + keyName = tagValue + } + + switch v.Kind() { + // this is an embedded struct, so handle it differently + case reflect.Struct: + x := reflect.New(v.Type()) + x.Elem().Set(v) + + vType := valMap.Type() + vKeyType := vType.Key() + vElemType := vType.Elem() + mType := reflect.MapOf(vKeyType, vElemType) + vMap := reflect.MakeMap(mType) + + // Creating a pointer to a map so that other methods can completely + // overwrite the map if need be (looking at you decodeMapFromMap). The + // indirection allows the underlying map to be settable (CanSet() == true) + // where as reflect.MakeMap returns an unsettable map. + addrVal := reflect.New(vMap.Type()) + reflect.Indirect(addrVal).Set(vMap) + + err := d.decode(keyName, x.Interface(), reflect.Indirect(addrVal)) + if err != nil { + return err + } + + // the underlying map may have been completely overwritten so pull + // it indirectly out of the enclosing value. + vMap = reflect.Indirect(addrVal) + + if squash { + for _, k := range vMap.MapKeys() { + valMap.SetMapIndex(k, vMap.MapIndex(k)) + } + } else { + valMap.SetMapIndex(reflect.ValueOf(keyName), vMap) + } + + default: + valMap.SetMapIndex(reflect.ValueOf(keyName), v) + } + + } + + if val.CanAddr() { + val.Set(valMap) + } + + return nil +} + +func (d *Decoder) decodePtr(name string, data interface{}, val reflect.Value) (bool, error) { + // If the input data is nil, then we want to just set the output + // pointer to be nil as well. + isNil := data == nil + if !isNil { + switch v := reflect.Indirect(reflect.ValueOf(data)); v.Kind() { + case reflect.Chan, + reflect.Func, + reflect.Interface, + reflect.Map, + reflect.Ptr, + reflect.Slice: + isNil = v.IsNil() + } + } + if isNil { + if !val.IsNil() && val.CanSet() { + nilValue := reflect.New(val.Type()).Elem() + val.Set(nilValue) + } + + return true, nil + } + + // Create an element of the concrete (non pointer) type and decode + // into that. Then set the value of the pointer to this type. + valType := val.Type() + valElemType := valType.Elem() + if val.CanSet() { + realVal := val + if realVal.IsNil() || d.config.ZeroFields { + realVal = reflect.New(valElemType) + } + + if err := d.decode(name, data, reflect.Indirect(realVal)); err != nil { + // 报错情况下依旧设置指针 + val.Set(realVal) + return false, err + } + + val.Set(realVal) + } else { + if err := d.decode(name, data, reflect.Indirect(val)); err != nil { + return false, err + } + } + return false, nil +} + +func (d *Decoder) decodeFunc(name string, data interface{}, val reflect.Value) error { + // Create an element of the concrete (non pointer) type and decode + // into that. Then set the value of the pointer to this type. + dataVal := reflect.Indirect(reflect.ValueOf(data)) + if val.Type() != dataVal.Type() { + return fmt.Errorf( + "'%s' expected type '%s', got unconvertible type '%s', value: '%v'", + name, val.Type(), dataVal.Type(), data) + } + val.Set(dataVal) + return nil +} + +func (d *Decoder) decodeSlice(name string, data interface{}, val reflect.Value) error { + dataVal := reflect.Indirect(reflect.ValueOf(data)) + dataValKind := dataVal.Kind() + valType := val.Type() + valElemType := valType.Elem() + sliceType := reflect.SliceOf(valElemType) + + // If we have a non array/slice type then we first attempt to convert. + if dataValKind != reflect.Array && dataValKind != reflect.Slice { + if d.config.WeaklyTypedInput { + switch { + // Slice and array we use the normal logic + case dataValKind == reflect.Slice, dataValKind == reflect.Array: + break + + // Empty maps turn into empty slices + case dataValKind == reflect.Map: + if dataVal.Len() == 0 { + val.Set(reflect.MakeSlice(sliceType, 0, 0)) + return nil + } + // Create slice of maps of other sizes + return d.decodeSlice(name, []interface{}{data}, val) + + case dataValKind == reflect.String && valElemType.Kind() == reflect.Uint8: + return d.decodeSlice(name, []byte(dataVal.String()), val) + + // All other types we try to convert to the slice type + // and "lift" it into it. i.e. a string becomes a string slice. + default: + // Just re-try this function with data as a slice. + return d.decodeSlice(name, []interface{}{data}, val) + } + } + + return fmt.Errorf( + "'%s': source data must be an array or slice, got %s", name, dataValKind) + } + + // If the input value is nil, then don't allocate since empty != nil + if dataValKind != reflect.Array && dataVal.IsNil() { + return nil + } + + valSlice := val + if valSlice.IsNil() || d.config.ZeroFields { + // Make a new slice to hold our result, same size as the original data. + valSlice = reflect.MakeSlice(sliceType, dataVal.Len(), dataVal.Len()) + } + + // Accumulate any errors + errors := make([]string, 0) + + for i := 0; i < dataVal.Len(); i++ { + currentData := dataVal.Index(i).Interface() + for valSlice.Len() <= i { + valSlice = reflect.Append(valSlice, reflect.Zero(valElemType)) + } + currentField := valSlice.Index(i) + + fieldName := name + "[" + strconv.Itoa(i) + "]" + if err := d.decode(fieldName, currentData, currentField); err != nil { + errors = appendErrors(errors, err) + } + } + + // Finally, set the value to the slice we built up + val.Set(valSlice) + + // If there were errors, we return those + if len(errors) > 0 { + return &Error{errors} + } + + return nil +} + +func (d *Decoder) decodeArray(name string, data interface{}, val reflect.Value) error { + dataVal := reflect.Indirect(reflect.ValueOf(data)) + dataValKind := dataVal.Kind() + valType := val.Type() + valElemType := valType.Elem() + arrayType := reflect.ArrayOf(valType.Len(), valElemType) + + valArray := val + + if valArray.Interface() == reflect.Zero(valArray.Type()).Interface() || d.config.ZeroFields { + // Check input type + if dataValKind != reflect.Array && dataValKind != reflect.Slice { + if d.config.WeaklyTypedInput { + switch { + // Empty maps turn into empty arrays + case dataValKind == reflect.Map: + if dataVal.Len() == 0 { + val.Set(reflect.Zero(arrayType)) + return nil + } + + // All other types we try to convert to the array type + // and "lift" it into it. i.e. a string becomes a string array. + default: + // Just re-try this function with data as a slice. + return d.decodeArray(name, []interface{}{data}, val) + } + } + + return fmt.Errorf( + "'%s': source data must be an array or slice, got %s", name, dataValKind) + + } + if dataVal.Len() > arrayType.Len() { + return fmt.Errorf( + "'%s': expected source data to have length less or equal to %d, got %d", name, arrayType.Len(), dataVal.Len()) + + } + + // Make a new array to hold our result, same size as the original data. + valArray = reflect.New(arrayType).Elem() + } + + // Accumulate any errors + errors := make([]string, 0) + + for i := 0; i < dataVal.Len(); i++ { + currentData := dataVal.Index(i).Interface() + currentField := valArray.Index(i) + + fieldName := name + "[" + strconv.Itoa(i) + "]" + if err := d.decode(fieldName, currentData, currentField); err != nil { + errors = appendErrors(errors, err) + } + } + + // Finally, set the value to the array we built up + val.Set(valArray) + + // If there were errors, we return those + if len(errors) > 0 { + return &Error{errors} + } + + return nil +} + +func (d *Decoder) decodeStruct(name string, data interface{}, val reflect.Value) error { + dataVal := reflect.Indirect(reflect.ValueOf(data)) + + // If the type of the value to write to and the data match directly, + // then we just set it directly instead of recursing into the structure. + if dataVal.Type() == val.Type() { + val.Set(dataVal) + return nil + } + + dataValKind := dataVal.Kind() + switch dataValKind { + case reflect.Map: + return d.decodeStructFromMap(name, dataVal, val) + + case reflect.Struct: + // Not the most efficient way to do this but we can optimize later if + // we want to. To convert from struct to struct we go to map first + // as an intermediary. + + // Make a new map to hold our result + mapType := reflect.TypeOf((map[string]interface{})(nil)) + mval := reflect.MakeMap(mapType) + + // Creating a pointer to a map so that other methods can completely + // overwrite the map if need be (looking at you decodeMapFromMap). The + // indirection allows the underlying map to be settable (CanSet() == true) + // where as reflect.MakeMap returns an unsettable map. + addrVal := reflect.New(mval.Type()) + + reflect.Indirect(addrVal).Set(mval) + if err := d.decodeMapFromStruct(name, dataVal, reflect.Indirect(addrVal), mval, data); err != nil { + return err + } + + result := d.decodeStructFromMap(name, reflect.Indirect(addrVal), val) + return result + + default: + return fmt.Errorf("'%s' expected a map, got '%s'", name, dataVal.Kind()) + } +} + +func (d *Decoder) decodeStructFromMap(name string, dataVal, val reflect.Value) error { + dataValType := dataVal.Type() + if kind := dataValType.Key().Kind(); kind != reflect.String && kind != reflect.Interface { + return fmt.Errorf( + "'%s' needs a map with string keys, has '%s' keys", + name, dataValType.Key().Kind()) + } + + dataValKeys := make(map[reflect.Value]struct{}) + dataValKeysUnused := make(map[interface{}]struct{}) + for _, dataValKey := range dataVal.MapKeys() { + dataValKeys[dataValKey] = struct{}{} + dataValKeysUnused[dataValKey.Interface()] = struct{}{} + } + + targetValKeysUnused := make(map[interface{}]struct{}) + errors := make([]string, 0) + + // This slice will keep track of all the structs we'll be decoding. + // There can be more than one struct if there are embedded structs + // that are squashed. + structs := make([]reflect.Value, 1, 5) + structs[0] = val + + // Compile the list of all the fields that we're going to be decoding + // from all the structs. + type field struct { + field reflect.StructField + val reflect.Value + } + + // remainField is set to a valid field set with the "remain" tag if + // we are keeping track of remaining values. + var remainField *field + + fields := []field{} + for len(structs) > 0 { + structVal := structs[0] + structs = structs[1:] + + structType := structVal.Type() + + for i := 0; i < structType.NumField(); i++ { + fieldType := structType.Field(i) + fieldVal := structVal.Field(i) + if fieldVal.Kind() == reflect.Ptr && fieldVal.Elem().Kind() == reflect.Struct { + // Handle embedded struct pointers as embedded structs. + fieldVal = fieldVal.Elem() + } + + // If "squash" is specified in the tag, we squash the field down. + squash := d.config.Squash && fieldVal.Kind() == reflect.Struct && fieldType.Anonymous + remain := false + + // We always parse the tags cause we're looking for other tags too + tagParts := strings.Split(fieldType.Tag.Get(d.config.TagName), ",") + for _, tag := range tagParts[1:] { + if tag == "squash" { + squash = true + break + } + + if tag == "remain" { + remain = true + break + } + } + + if squash { + if fieldVal.Kind() != reflect.Struct { + errors = appendErrors(errors, + fmt.Errorf("%s: unsupported type for squash: %s", fieldType.Name, fieldVal.Kind())) + } else { + structs = append(structs, fieldVal) + } + continue + } + + // Build our field + if remain { + remainField = &field{fieldType, fieldVal} + } else { + // Normal struct field, store it away + fields = append(fields, field{fieldType, fieldVal}) + } + } + } + + // for fieldType, field := range fields { + for _, f := range fields { + field, fieldValue := f.field, f.val + fieldName := field.Name + + tagValue := field.Tag.Get(d.config.TagName) + tagValue = strings.SplitN(tagValue, ",", 2)[0] + if tagValue != "" { + fieldName = tagValue + } + + rawMapKey := reflect.ValueOf(fieldName) + rawMapVal := dataVal.MapIndex(rawMapKey) + if !rawMapVal.IsValid() { + // Do a slower search by iterating over each key and + // doing case-insensitive search. + for dataValKey := range dataValKeys { + mK, ok := dataValKey.Interface().(string) + if !ok { + // Not a string key + continue + } + + if d.config.MatchName(mK, fieldName) { + rawMapKey = dataValKey + rawMapVal = dataVal.MapIndex(dataValKey) + break + } + } + + if !rawMapVal.IsValid() { + // There was no matching key in the map for the value in + // the struct. Remember it for potential errors and metadata. + targetValKeysUnused[fieldName] = struct{}{} + continue + } + } + + if !fieldValue.IsValid() { + // This should never happen + panic("field is not valid") + } + + // If we can't set the field, then it is unexported or something, + // and we just continue onwards. + if !fieldValue.CanSet() { + continue + } + + // Delete the key we're using from the unused map so we stop tracking + delete(dataValKeysUnused, rawMapKey.Interface()) + + // If the name is empty string, then we're at the root, and we + // don't dot-join the fields. + if name != "" { + fieldName = name + "." + fieldName + } + + if err := d.decode(fieldName, rawMapVal.Interface(), fieldValue); err != nil { + errors = appendErrors(errors, err) + } + } + + // If we have a "remain"-tagged field and we have unused keys then + // we put the unused keys directly into the remain field. + if remainField != nil && len(dataValKeysUnused) > 0 { + // Build a map of only the unused values + remain := map[interface{}]interface{}{} + for key := range dataValKeysUnused { + remain[key] = dataVal.MapIndex(reflect.ValueOf(key)).Interface() + } + + // Decode it as-if we were just decoding this map onto our map. + if err := d.decodeMap(name, remain, remainField.val); err != nil { + errors = appendErrors(errors, err) + } + + // Set the map to nil so we have none so that the next check will + // not error (ErrorUnused) + dataValKeysUnused = nil + } + + if d.config.ErrorUnused && len(dataValKeysUnused) > 0 { + keys := make([]string, 0, len(dataValKeysUnused)) + for rawKey := range dataValKeysUnused { + keys = append(keys, rawKey.(string)) + } + sort.Strings(keys) + + err := fmt.Errorf("'%s' has invalid keys: %s", name, strings.Join(keys, ", ")) + errors = appendErrors(errors, err) + } + + if d.config.ErrorUnset && len(targetValKeysUnused) > 0 { + keys := make([]string, 0, len(targetValKeysUnused)) + for rawKey := range targetValKeysUnused { + keys = append(keys, rawKey.(string)) + } + sort.Strings(keys) + + err := fmt.Errorf("'%s' has unset fields: %s", name, strings.Join(keys, ", ")) + errors = appendErrors(errors, err) + } + + if len(errors) > 0 { + return &Error{errors} + } + + // Add the unused keys to the list of unused keys if we're tracking metadata + if d.config.Metadata != nil { + for rawKey := range dataValKeysUnused { + key := rawKey.(string) + if name != "" { + key = name + "." + key + } + + d.config.Metadata.Unused = append(d.config.Metadata.Unused, key) + } + for rawKey := range targetValKeysUnused { + key := rawKey.(string) + if name != "" { + key = name + "." + key + } + + d.config.Metadata.Unset = append(d.config.Metadata.Unset, key) + } + } + + return nil +} + +func isEmptyValue(v reflect.Value) bool { + switch getKind(v) { + case reflect.Array, reflect.Map, reflect.Slice, reflect.String: + return v.Len() == 0 + case reflect.Bool: + return !v.Bool() + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return v.Int() == 0 + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + return v.Uint() == 0 + case reflect.Float32, reflect.Float64: + return v.Float() == 0 + case reflect.Interface, reflect.Ptr: + return v.IsNil() + } + return false +} + +func getKind(val reflect.Value) reflect.Kind { + kind := val.Kind() + + switch { + case kind >= reflect.Int && kind <= reflect.Int64: + return reflect.Int + case kind >= reflect.Uint && kind <= reflect.Uint64: + return reflect.Uint + case kind >= reflect.Float32 && kind <= reflect.Float64: + return reflect.Float32 + default: + return kind + } +} + +func isStructTypeConvertibleToMap(typ reflect.Type, checkMapstructureTags bool, tagName string) bool { + for i := 0; i < typ.NumField(); i++ { + f := typ.Field(i) + if f.PkgPath == "" && !checkMapstructureTags { // check for unexported fields + return true + } + if checkMapstructureTags && f.Tag.Get(tagName) != "" { // check for mapstructure tags inside + return true + } + } + return false +} + +func dereferencePtrToStructIfNeeded(v reflect.Value, tagName string) reflect.Value { + if v.Kind() != reflect.Ptr || v.Elem().Kind() != reflect.Struct { + return v + } + deref := v.Elem() + derefT := deref.Type() + if isStructTypeConvertibleToMap(derefT, true, tagName) { + return deref + } + return v +} diff --git a/pkg/mapstructure/mapstructure_benchmark_test.go b/pkg/mapstructure/mapstructure_benchmark_test.go new file mode 100644 index 0000000..b9bde7e --- /dev/null +++ b/pkg/mapstructure/mapstructure_benchmark_test.go @@ -0,0 +1,285 @@ +package mapstructure + +import ( + "encoding/json" + "testing" +) + +type Person struct { + Name string + Age int + Emails []string + Extra map[string]string +} + +func Benchmark_Decode(b *testing.B) { + input := map[string]interface{}{ + "name": "Mitchell", + "age": 91, + "emails": []string{"one", "two", "three"}, + "extra": map[string]string{ + "twitter": "mitchellh", + }, + } + + var result Person + for i := 0; i < b.N; i++ { + Decode(input, &result) + } +} + +// decodeViaJSON takes the map data and passes it through encoding/json to convert it into the +// given Go native structure pointed to by v. v must be a pointer to a struct. +func decodeViaJSON(data interface{}, v interface{}) error { + // Perform the task by simply marshalling the input into JSON, + // then unmarshalling it into target native Go struct. + b, err := json.Marshal(data) + if err != nil { + return err + } + return json.Unmarshal(b, v) +} + +func Benchmark_DecodeViaJSON(b *testing.B) { + input := map[string]interface{}{ + "name": "Mitchell", + "age": 91, + "emails": []string{"one", "two", "three"}, + "extra": map[string]string{ + "twitter": "mitchellh", + }, + } + + var result Person + for i := 0; i < b.N; i++ { + decodeViaJSON(input, &result) + } +} + +func Benchmark_JSONUnmarshal(b *testing.B) { + input := map[string]interface{}{ + "name": "Mitchell", + "age": 91, + "emails": []string{"one", "two", "three"}, + "extra": map[string]string{ + "twitter": "mitchellh", + }, + } + + inputB, err := json.Marshal(input) + if err != nil { + b.Fatal("Failed to marshal test input:", err) + } + + var result Person + for i := 0; i < b.N; i++ { + json.Unmarshal(inputB, &result) + } +} + +func Benchmark_DecodeBasic(b *testing.B) { + input := map[string]interface{}{ + "vstring": "foo", + "vint": 42, + "Vuint": 42, + "vbool": true, + "Vfloat": 42.42, + "vsilent": true, + "vdata": 42, + "vjsonInt": json.Number("1234"), + "vjsonFloat": json.Number("1234.5"), + "vjsonNumber": json.Number("1234.5"), + } + + for i := 0; i < b.N; i++ { + var result Basic + Decode(input, &result) + } +} + +func Benchmark_DecodeEmbedded(b *testing.B) { + input := map[string]interface{}{ + "vstring": "foo", + "Basic": map[string]interface{}{ + "vstring": "innerfoo", + }, + "vunique": "bar", + } + + var result Embedded + for i := 0; i < b.N; i++ { + Decode(input, &result) + } +} + +func Benchmark_DecodeTypeConversion(b *testing.B) { + input := map[string]interface{}{ + "IntToFloat": 42, + "IntToUint": 42, + "IntToBool": 1, + "IntToString": 42, + "UintToInt": 42, + "UintToFloat": 42, + "UintToBool": 42, + "UintToString": 42, + "BoolToInt": true, + "BoolToUint": true, + "BoolToFloat": true, + "BoolToString": true, + "FloatToInt": 42.42, + "FloatToUint": 42.42, + "FloatToBool": 42.42, + "FloatToString": 42.42, + "StringToInt": "42", + "StringToUint": "42", + "StringToBool": "1", + "StringToFloat": "42.42", + "SliceToMap": []interface{}{}, + "MapToSlice": map[string]interface{}{}, + } + + var resultStrict TypeConversionResult + for i := 0; i < b.N; i++ { + Decode(input, &resultStrict) + } +} + +func Benchmark_DecodeMap(b *testing.B) { + input := map[string]interface{}{ + "vfoo": "foo", + "vother": map[interface{}]interface{}{ + "foo": "foo", + "bar": "bar", + }, + } + + var result Map + for i := 0; i < b.N; i++ { + Decode(input, &result) + } +} + +func Benchmark_DecodeMapOfStruct(b *testing.B) { + input := map[string]interface{}{ + "value": map[string]interface{}{ + "foo": map[string]string{"vstring": "one"}, + "bar": map[string]string{"vstring": "two"}, + }, + } + + var result MapOfStruct + for i := 0; i < b.N; i++ { + Decode(input, &result) + } +} + +func Benchmark_DecodeSlice(b *testing.B) { + input := map[string]interface{}{ + "vfoo": "foo", + "vbar": []string{"foo", "bar", "baz"}, + } + + var result Slice + for i := 0; i < b.N; i++ { + Decode(input, &result) + } +} + +func Benchmark_DecodeSliceOfStruct(b *testing.B) { + input := map[string]interface{}{ + "value": []map[string]interface{}{ + {"vstring": "one"}, + {"vstring": "two"}, + }, + } + + var result SliceOfStruct + for i := 0; i < b.N; i++ { + Decode(input, &result) + } +} + +func Benchmark_DecodeWeaklyTypedInput(b *testing.B) { + // This input can come from anywhere, but typically comes from + // something like decoding JSON, generated by a weakly typed language + // such as PHP. + input := map[string]interface{}{ + "name": 123, // number => string + "age": "42", // string => number + "emails": map[string]interface{}{}, // empty map => empty array + } + + var result Person + config := &DecoderConfig{ + WeaklyTypedInput: true, + Result: &result, + } + + decoder, err := NewDecoder(config) + if err != nil { + panic(err) + } + + for i := 0; i < b.N; i++ { + decoder.Decode(input) + } +} + +func Benchmark_DecodeMetadata(b *testing.B) { + input := map[string]interface{}{ + "name": "Mitchell", + "age": 91, + "email": "foo@bar.com", + } + + var md Metadata + var result Person + config := &DecoderConfig{ + Metadata: &md, + Result: &result, + } + + decoder, err := NewDecoder(config) + if err != nil { + panic(err) + } + + for i := 0; i < b.N; i++ { + decoder.Decode(input) + } +} + +func Benchmark_DecodeMetadataEmbedded(b *testing.B) { + input := map[string]interface{}{ + "vstring": "foo", + "vunique": "bar", + } + + var md Metadata + var result EmbeddedSquash + config := &DecoderConfig{ + Metadata: &md, + Result: &result, + } + + decoder, err := NewDecoder(config) + if err != nil { + b.Fatalf("jderr: %s", err) + } + + for i := 0; i < b.N; i++ { + decoder.Decode(input) + } +} + +func Benchmark_DecodeTagged(b *testing.B) { + input := map[string]interface{}{ + "foo": "bar", + "bar": "value", + } + + var result Tagged + for i := 0; i < b.N; i++ { + Decode(input, &result) + } +} diff --git a/pkg/mapstructure/mapstructure_bugs_test.go b/pkg/mapstructure/mapstructure_bugs_test.go new file mode 100644 index 0000000..31fa5cd --- /dev/null +++ b/pkg/mapstructure/mapstructure_bugs_test.go @@ -0,0 +1,627 @@ +package mapstructure + +import ( + "reflect" + "testing" + "time" +) + +// GH-1, GH-10, GH-96 +func TestDecode_NilValue(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + in interface{} + target interface{} + out interface{} + metaKeys []string + metaUnused []string + }{ + { + "all nil", + &map[string]interface{}{ + "vfoo": nil, + "vother": nil, + }, + &Map{Vfoo: "foo", Vother: map[string]string{"foo": "bar"}}, + &Map{Vfoo: "", Vother: nil}, + []string{"Vfoo", "Vother"}, + []string{}, + }, + { + "partial nil", + &map[string]interface{}{ + "vfoo": "baz", + "vother": nil, + }, + &Map{Vfoo: "foo", Vother: map[string]string{"foo": "bar"}}, + &Map{Vfoo: "baz", Vother: nil}, + []string{"Vfoo", "Vother"}, + []string{}, + }, + { + "partial decode", + &map[string]interface{}{ + "vother": nil, + }, + &Map{Vfoo: "foo", Vother: map[string]string{"foo": "bar"}}, + &Map{Vfoo: "foo", Vother: nil}, + []string{"Vother"}, + []string{}, + }, + { + "unused values", + &map[string]interface{}{ + "vbar": "bar", + "vfoo": nil, + "vother": nil, + }, + &Map{Vfoo: "foo", Vother: map[string]string{"foo": "bar"}}, + &Map{Vfoo: "", Vother: nil}, + []string{"Vfoo", "Vother"}, + []string{"vbar"}, + }, + { + "map interface all nil", + &map[interface{}]interface{}{ + "vfoo": nil, + "vother": nil, + }, + &Map{Vfoo: "foo", Vother: map[string]string{"foo": "bar"}}, + &Map{Vfoo: "", Vother: nil}, + []string{"Vfoo", "Vother"}, + []string{}, + }, + { + "map interface partial nil", + &map[interface{}]interface{}{ + "vfoo": "baz", + "vother": nil, + }, + &Map{Vfoo: "foo", Vother: map[string]string{"foo": "bar"}}, + &Map{Vfoo: "baz", Vother: nil}, + []string{"Vfoo", "Vother"}, + []string{}, + }, + { + "map interface partial decode", + &map[interface{}]interface{}{ + "vother": nil, + }, + &Map{Vfoo: "foo", Vother: map[string]string{"foo": "bar"}}, + &Map{Vfoo: "foo", Vother: nil}, + []string{"Vother"}, + []string{}, + }, + { + "map interface unused values", + &map[interface{}]interface{}{ + "vbar": "bar", + "vfoo": nil, + "vother": nil, + }, + &Map{Vfoo: "foo", Vother: map[string]string{"foo": "bar"}}, + &Map{Vfoo: "", Vother: nil}, + []string{"Vfoo", "Vother"}, + []string{"vbar"}, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + config := &DecoderConfig{ + Metadata: new(Metadata), + Result: tc.target, + ZeroFields: true, + } + + decoder, err := NewDecoder(config) + if err != nil { + t.Fatalf("should not error: %s", err) + } + + err = decoder.Decode(tc.in) + if err != nil { + t.Fatalf("should not error: %s", err) + } + + if !reflect.DeepEqual(tc.out, tc.target) { + t.Fatalf("%q: TestDecode_NilValue() expected: %#v, got: %#v", tc.name, tc.out, tc.target) + } + + if !reflect.DeepEqual(tc.metaKeys, config.Metadata.Keys) { + t.Fatalf("%q: Metadata.Keys mismatch expected: %#v, got: %#v", tc.name, tc.metaKeys, config.Metadata.Keys) + } + + if !reflect.DeepEqual(tc.metaUnused, config.Metadata.Unused) { + t.Fatalf("%q: Metadata.Unused mismatch expected: %#v, got: %#v", tc.name, tc.metaUnused, config.Metadata.Unused) + } + }) + } +} + +// #48 +func TestNestedTypePointerWithDefaults(t *testing.T) { + t.Parallel() + + input := map[string]interface{}{ + "vfoo": "foo", + "vbar": map[string]interface{}{ + "vstring": "foo", + "vint": 42, + "vbool": true, + }, + } + + result := NestedPointer{ + Vbar: &Basic{ + Vuint: 42, + }, + } + err := Decode(input, &result) + if err != nil { + t.Fatalf("got an jderr: %s", err.Error()) + } + + if result.Vfoo != "foo" { + t.Errorf("vfoo value should be 'foo': %#v", result.Vfoo) + } + + if result.Vbar.Vstring != "foo" { + t.Errorf("vstring value should be 'foo': %#v", result.Vbar.Vstring) + } + + if result.Vbar.Vint != 42 { + t.Errorf("vint value should be 42: %#v", result.Vbar.Vint) + } + + if result.Vbar.Vbool != true { + t.Errorf("vbool value should be true: %#v", result.Vbar.Vbool) + } + + if result.Vbar.Vextra != "" { + t.Errorf("vextra value should be empty: %#v", result.Vbar.Vextra) + } + + // this is the error + if result.Vbar.Vuint != 42 { + t.Errorf("vuint value should be 42: %#v", result.Vbar.Vuint) + } + +} + +type NestedSlice struct { + Vfoo string + Vbars []Basic + Vempty []Basic +} + +// #48 +func TestNestedTypeSliceWithDefaults(t *testing.T) { + t.Parallel() + + input := map[string]interface{}{ + "vfoo": "foo", + "vbars": []map[string]interface{}{ + {"vstring": "foo", "vint": 42, "vbool": true}, + {"vint": 42, "vbool": true}, + }, + "vempty": []map[string]interface{}{ + {"vstring": "foo", "vint": 42, "vbool": true}, + {"vint": 42, "vbool": true}, + }, + } + + result := NestedSlice{ + Vbars: []Basic{ + {Vuint: 42}, + {Vstring: "foo"}, + }, + } + err := Decode(input, &result) + if err != nil { + t.Fatalf("got an jderr: %s", err.Error()) + } + + if result.Vfoo != "foo" { + t.Errorf("vfoo value should be 'foo': %#v", result.Vfoo) + } + + if result.Vbars[0].Vstring != "foo" { + t.Errorf("vstring value should be 'foo': %#v", result.Vbars[0].Vstring) + } + // this is the error + if result.Vbars[0].Vuint != 42 { + t.Errorf("vuint value should be 42: %#v", result.Vbars[0].Vuint) + } +} + +// #48 workaround +func TestNestedTypeWithDefaults(t *testing.T) { + t.Parallel() + + input := map[string]interface{}{ + "vfoo": "foo", + "vbar": map[string]interface{}{ + "vstring": "foo", + "vint": 42, + "vbool": true, + }, + } + + result := Nested{ + Vbar: Basic{ + Vuint: 42, + }, + } + err := Decode(input, &result) + if err != nil { + t.Fatalf("got an jderr: %s", err.Error()) + } + + if result.Vfoo != "foo" { + t.Errorf("vfoo value should be 'foo': %#v", result.Vfoo) + } + + if result.Vbar.Vstring != "foo" { + t.Errorf("vstring value should be 'foo': %#v", result.Vbar.Vstring) + } + + if result.Vbar.Vint != 42 { + t.Errorf("vint value should be 42: %#v", result.Vbar.Vint) + } + + if result.Vbar.Vbool != true { + t.Errorf("vbool value should be true: %#v", result.Vbar.Vbool) + } + + if result.Vbar.Vextra != "" { + t.Errorf("vextra value should be empty: %#v", result.Vbar.Vextra) + } + + // this is the error + if result.Vbar.Vuint != 42 { + t.Errorf("vuint value should be 42: %#v", result.Vbar.Vuint) + } + +} + +// #67 panic() on extending slices (decodeSlice with disabled ZeroValues) +func TestDecodeSliceToEmptySliceWOZeroing(t *testing.T) { + t.Parallel() + + type TestStruct struct { + Vfoo []string + } + + decode := func(m interface{}, rawVal interface{}) error { + config := &DecoderConfig{ + Metadata: nil, + Result: rawVal, + ZeroFields: false, + } + + decoder, err := NewDecoder(config) + if err != nil { + return err + } + + return decoder.Decode(m) + } + + { + input := map[string]interface{}{ + "vfoo": []string{"1"}, + } + + result := &TestStruct{} + + err := decode(input, &result) + if err != nil { + t.Fatalf("got an jderr: %s", err.Error()) + } + } + + { + input := map[string]interface{}{ + "vfoo": []string{"1"}, + } + + result := &TestStruct{ + Vfoo: []string{}, + } + + err := decode(input, &result) + if err != nil { + t.Fatalf("got an jderr: %s", err.Error()) + } + } + + { + input := map[string]interface{}{ + "vfoo": []string{"2", "3"}, + } + + result := &TestStruct{ + Vfoo: []string{"1"}, + } + + err := decode(input, &result) + if err != nil { + t.Fatalf("got an jderr: %s", err.Error()) + } + } +} + +// #70 +func TestNextSquashMapstructure(t *testing.T) { + data := &struct { + Level1 struct { + Level2 struct { + Foo string + } `mapstructure:",squash"` + } `mapstructure:",squash"` + }{} + err := Decode(map[interface{}]interface{}{"foo": "baz"}, &data) + if err != nil { + t.Fatalf("should not error: %s", err) + } + if data.Level1.Level2.Foo != "baz" { + t.Fatal("value should be baz") + } +} + +type ImplementsInterfacePointerReceiver struct { + Name string +} + +func (i *ImplementsInterfacePointerReceiver) DoStuff() {} + +type ImplementsInterfaceValueReceiver string + +func (i ImplementsInterfaceValueReceiver) DoStuff() {} + +// GH-140 Type error when using DecodeHook to decode into interface +func TestDecode_DecodeHookInterface(t *testing.T) { + t.Parallel() + + type Interface interface { + DoStuff() + } + type DecodeIntoInterface struct { + Test Interface + } + + testData := map[string]string{"test": "test"} + + stringToPointerInterfaceDecodeHook := func(from, to reflect.Type, data interface{}) (interface{}, error) { + if from.Kind() != reflect.String { + return data, nil + } + + if to != reflect.TypeOf((*Interface)(nil)).Elem() { + return data, nil + } + // Ensure interface is satisfied + var impl Interface = &ImplementsInterfacePointerReceiver{data.(string)} + return impl, nil + } + + stringToValueInterfaceDecodeHook := func(from, to reflect.Type, data interface{}) (interface{}, error) { + if from.Kind() != reflect.String { + return data, nil + } + + if to != reflect.TypeOf((*Interface)(nil)).Elem() { + return data, nil + } + // Ensure interface is satisfied + var impl Interface = ImplementsInterfaceValueReceiver(data.(string)) + return impl, nil + } + + { + decodeInto := new(DecodeIntoInterface) + + decoder, _ := NewDecoder(&DecoderConfig{ + DecodeHook: stringToPointerInterfaceDecodeHook, + Result: decodeInto, + }) + + err := decoder.Decode(testData) + if err != nil { + t.Fatalf("Decode returned error: %s", err) + } + + expected := &ImplementsInterfacePointerReceiver{"test"} + if !reflect.DeepEqual(decodeInto.Test, expected) { + t.Fatalf("expected: %#v (%T), got: %#v (%T)", decodeInto.Test, decodeInto.Test, expected, expected) + } + } + + { + decodeInto := new(DecodeIntoInterface) + + decoder, _ := NewDecoder(&DecoderConfig{ + DecodeHook: stringToValueInterfaceDecodeHook, + Result: decodeInto, + }) + + err := decoder.Decode(testData) + if err != nil { + t.Fatalf("Decode returned error: %s", err) + } + + expected := ImplementsInterfaceValueReceiver("test") + if !reflect.DeepEqual(decodeInto.Test, expected) { + t.Fatalf("expected: %#v (%T), got: %#v (%T)", decodeInto.Test, decodeInto.Test, expected, expected) + } + } +} + +// #103 Check for data type before trying to access its composants prevent a panic error +// in decodeSlice +func TestDecodeBadDataTypeInSlice(t *testing.T) { + t.Parallel() + + input := map[string]interface{}{ + "Toto": "titi", + } + result := []struct { + Toto string + }{} + + if err := Decode(input, &result); err == nil { + t.Error("An error was expected, got nil") + } +} + +// #202 Ensure that intermediate maps in the struct -> struct decode process are settable +// and not just the elements within them. +func TestDecodeIntermediateMapsSettable(t *testing.T) { + type Timestamp struct { + Seconds int64 + Nanos int32 + } + + type TsWrapper struct { + Timestamp *Timestamp + } + + type TimeWrapper struct { + Timestamp time.Time + } + + input := TimeWrapper{ + Timestamp: time.Unix(123456789, 987654), + } + + expected := TsWrapper{ + Timestamp: &Timestamp{ + Seconds: 123456789, + Nanos: 987654, + }, + } + + timePtrType := reflect.TypeOf((*time.Time)(nil)) + mapStrInfType := reflect.TypeOf((map[string]interface{})(nil)) + + var actual TsWrapper + decoder, err := NewDecoder(&DecoderConfig{ + Result: &actual, + DecodeHook: func(from, to reflect.Type, data interface{}) (interface{}, error) { + if from == timePtrType && to == mapStrInfType { + ts := data.(*time.Time) + nanos := ts.UnixNano() + + seconds := nanos / 1000000000 + nanos = nanos % 1000000000 + + return &map[string]interface{}{ + "Seconds": seconds, + "Nanos": int32(nanos), + }, nil + } + return data, nil + }, + }) + + if err != nil { + t.Fatalf("failed to create decoder: %v", err) + } + + if err := decoder.Decode(&input); err != nil { + t.Fatalf("failed to decode input: %v", err) + } + + if !reflect.DeepEqual(expected, actual) { + t.Fatalf("expected: %#[1]v (%[1]T), got: %#[2]v (%[2]T)", expected, actual) + } +} + +// GH-206: decodeInt throws an error for an empty string +func TestDecode_weakEmptyStringToInt(t *testing.T) { + input := map[string]interface{}{ + "StringToInt": "", + "StringToUint": "", + "StringToBool": "", + "StringToFloat": "", + } + + expectedResultWeak := TypeConversionResult{ + StringToInt: 0, + StringToUint: 0, + StringToBool: false, + StringToFloat: 0, + } + + // Test weak type conversion + var resultWeak TypeConversionResult + err := WeakDecode(input, &resultWeak) + if err != nil { + t.Fatalf("got an jderr: %s", err) + } + + if !reflect.DeepEqual(resultWeak, expectedResultWeak) { + t.Errorf("expected \n%#v, got: \n%#v", expectedResultWeak, resultWeak) + } +} + +// GH-228: Squash cause *time.Time set to zero +func TestMapSquash(t *testing.T) { + type AA struct { + T *time.Time + } + type A struct { + AA + } + + v := time.Now() + in := &AA{ + T: &v, + } + out := &A{} + d, err := NewDecoder(&DecoderConfig{ + Squash: true, + Result: out, + }) + if err != nil { + t.Fatalf("jderr: %s", err) + } + if err := d.Decode(in); err != nil { + t.Fatalf("jderr: %s", err) + } + + // these failed + if !v.Equal(*out.T) { + t.Fatal("expected equal") + } + if out.T.IsZero() { + t.Fatal("expected false") + } +} + +// GH-238: Empty key name when decoding map from struct with only omitempty flag +func TestMapOmitEmptyWithEmptyFieldnameInTag(t *testing.T) { + type Struct struct { + Username string `mapstructure:",omitempty"` + Age int `mapstructure:",omitempty"` + } + + s := Struct{ + Username: "Joe", + } + var m map[string]interface{} + + if err := Decode(s, &m); err != nil { + t.Fatal(err) + } + + if len(m) != 1 { + t.Fatalf("fail: %#v", m) + } + if m["Username"] != "Joe" { + t.Fatalf("fail: %#v", m) + } +} diff --git a/pkg/mapstructure/mapstructure_examples_test.go b/pkg/mapstructure/mapstructure_examples_test.go new file mode 100644 index 0000000..2413b69 --- /dev/null +++ b/pkg/mapstructure/mapstructure_examples_test.go @@ -0,0 +1,256 @@ +package mapstructure + +import ( + "fmt" +) + +func ExampleDecode() { + type Person struct { + Name string + Age int + Emails []string + Extra map[string]string + } + + // This input can come from anywhere, but typically comes from + // something like decoding JSON where we're not quite sure of the + // struct initially. + input := map[string]interface{}{ + "name": "Mitchell", + "age": 91, + "emails": []string{"one", "two", "three"}, + "extra": map[string]string{ + "twitter": "mitchellh", + }, + } + + var result Person + err := Decode(input, &result) + if err != nil { + panic(err) + } + + fmt.Printf("%#v", result) + // Output: + // mapstructure.Person{Name:"Mitchell", Age:91, Emails:[]string{"one", "two", "three"}, Extra:map[string]string{"twitter":"mitchellh"}} +} + +func ExampleDecode_errors() { + type Person struct { + Name string + Age int + Emails []string + Extra map[string]string + } + + // This input can come from anywhere, but typically comes from + // something like decoding JSON where we're not quite sure of the + // struct initially. + input := map[string]interface{}{ + "name": 123, + "age": "bad value", + "emails": []int{1, 2, 3}, + } + + var result Person + err := Decode(input, &result) + if err == nil { + panic("should have an error") + } + + fmt.Println(err.Error()) + // Output: + // 5 error(s) decoding: + // + // * 'Age' expected type 'int', got unconvertible type 'string', value: 'bad value' + // * 'Emails[0]' expected type 'string', got unconvertible type 'int', value: '1' + // * 'Emails[1]' expected type 'string', got unconvertible type 'int', value: '2' + // * 'Emails[2]' expected type 'string', got unconvertible type 'int', value: '3' + // * 'Name' expected type 'string', got unconvertible type 'int', value: '123' +} + +func ExampleDecode_metadata() { + type Person struct { + Name string + Age int + } + + // This input can come from anywhere, but typically comes from + // something like decoding JSON where we're not quite sure of the + // struct initially. + input := map[string]interface{}{ + "name": "Mitchell", + "age": 91, + "email": "foo@bar.com", + } + + // For metadata, we make a more advanced DecoderConfig so we can + // more finely configure the decoder that is used. In this case, we + // just tell the decoder we want to track metadata. + var md Metadata + var result Person + config := &DecoderConfig{ + Metadata: &md, + Result: &result, + } + + decoder, err := NewDecoder(config) + if err != nil { + panic(err) + } + + if err := decoder.Decode(input); err != nil { + panic(err) + } + + fmt.Printf("Unused keys: %#v", md.Unused) + // Output: + // Unused keys: []string{"email"} +} + +func ExampleDecode_weaklyTypedInput() { + type Person struct { + Name string + Age int + Emails []string + } + + // This input can come from anywhere, but typically comes from + // something like decoding JSON, generated by a weakly typed language + // such as PHP. + input := map[string]interface{}{ + "name": 123, // number => string + "age": "42", // string => number + "emails": map[string]interface{}{}, // empty map => empty array + } + + var result Person + config := &DecoderConfig{ + WeaklyTypedInput: true, + Result: &result, + } + + decoder, err := NewDecoder(config) + if err != nil { + panic(err) + } + + err = decoder.Decode(input) + if err != nil { + panic(err) + } + + fmt.Printf("%#v", result) + // Output: mapstructure.Person{Name:"123", Age:42, Emails:[]string{}} +} + +func ExampleDecode_tags() { + // Note that the mapstructure tags defined in the struct type + // can indicate which fields the values are mapped to. + type Person struct { + Name string `mapstructure:"person_name"` + Age int `mapstructure:"person_age"` + } + + input := map[string]interface{}{ + "person_name": "Mitchell", + "person_age": 91, + } + + var result Person + err := Decode(input, &result) + if err != nil { + panic(err) + } + + fmt.Printf("%#v", result) + // Output: + // mapstructure.Person{Name:"Mitchell", Age:91} +} + +func ExampleDecode_embeddedStruct() { + // Squashing multiple embedded structs is allowed using the squash tag. + // This is demonstrated by creating a composite struct of multiple types + // and decoding into it. In this case, a person can carry with it both + // a Family and a Location, as well as their own FirstName. + type Family struct { + LastName string + } + type Location struct { + City string + } + type Person struct { + Family `mapstructure:",squash"` + Location `mapstructure:",squash"` + FirstName string + } + + input := map[string]interface{}{ + "FirstName": "Mitchell", + "LastName": "Hashimoto", + "City": "San Francisco", + } + + var result Person + err := Decode(input, &result) + if err != nil { + panic(err) + } + + fmt.Printf("%s %s, %s", result.FirstName, result.LastName, result.City) + // Output: + // Mitchell Hashimoto, San Francisco +} + +func ExampleDecode_remainingData() { + // Note that the mapstructure tags defined in the struct type + // can indicate which fields the values are mapped to. + type Person struct { + Name string + Age int + Other map[string]interface{} `mapstructure:",remain"` + } + + input := map[string]interface{}{ + "name": "Mitchell", + "age": 91, + "email": "mitchell@example.com", + } + + var result Person + err := Decode(input, &result) + if err != nil { + panic(err) + } + + fmt.Printf("%#v", result) + // Output: + // mapstructure.Person{Name:"Mitchell", Age:91, Other:map[string]interface {}{"email":"mitchell@example.com"}} +} + +func ExampleDecode_omitempty() { + // Add omitempty annotation to avoid map keys for empty values + type Family struct { + LastName string + } + type Location struct { + City string + } + type Person struct { + *Family `mapstructure:",omitempty"` + *Location `mapstructure:",omitempty"` + Age int + FirstName string + } + + result := &map[string]interface{}{} + input := Person{FirstName: "Somebody"} + err := Decode(input, &result) + if err != nil { + panic(err) + } + + fmt.Printf("%+v", result) + // Output: + // &map[Age:0 FirstName:Somebody] +} diff --git a/pkg/mapstructure/mapstructure_ext_test.go b/pkg/mapstructure/mapstructure_ext_test.go new file mode 100644 index 0000000..a646a51 --- /dev/null +++ b/pkg/mapstructure/mapstructure_ext_test.go @@ -0,0 +1,58 @@ +package mapstructure + +import ( + "reflect" + "testing" +) + +func TestDecode_Ptr(t *testing.T) { + t.Parallel() + + type G struct { + Id int + Name string + } + + type X struct { + Id int + Name int + } + + type AG struct { + List []*G + } + + type AX struct { + List []*X + } + + g2 := &AG{ + List: []*G{ + { + Id: 11, + Name: "gg", + }, + }, + } + x2 := AX{} + + // 报错但还是会转换成功,转换后值为目标类型的 0 值 + err := Decode(g2, &x2) + + res := AX{ + List: []*X{ + { + Id: 11, + Name: 0, // 这个类型的 0 值 + }, + }, + } + + if err == nil { + t.Errorf("Decode_Ptr jderr should not be 'nil': %#v", err) + } + + if !reflect.DeepEqual(res, x2) { + t.Errorf("result should be %#v: got %#v", res, x2) + } +} diff --git a/pkg/mapstructure/mapstructure_test.go b/pkg/mapstructure/mapstructure_test.go new file mode 100644 index 0000000..17e609a --- /dev/null +++ b/pkg/mapstructure/mapstructure_test.go @@ -0,0 +1,2763 @@ +package mapstructure + +import ( + "encoding/json" + "io" + "reflect" + "sort" + "strings" + "testing" + "time" +) + +type Basic struct { + Vstring string + Vint int + Vint8 int8 + Vint16 int16 + Vint32 int32 + Vint64 int64 + Vuint uint + Vbool bool + Vfloat float64 + Vextra string + vsilent bool + Vdata interface{} + VjsonInt int + VjsonUint uint + VjsonUint64 uint64 + VjsonFloat float64 + VjsonNumber json.Number +} + +type BasicPointer struct { + Vstring *string + Vint *int + Vuint *uint + Vbool *bool + Vfloat *float64 + Vextra *string + vsilent *bool + Vdata *interface{} + VjsonInt *int + VjsonFloat *float64 + VjsonNumber *json.Number +} + +type BasicSquash struct { + Test Basic `mapstructure:",squash"` +} + +type Embedded struct { + Basic + Vunique string +} + +type EmbeddedPointer struct { + *Basic + Vunique string +} + +type EmbeddedSquash struct { + Basic `mapstructure:",squash"` + Vunique string +} + +type EmbeddedPointerSquash struct { + *Basic `mapstructure:",squash"` + Vunique string +} + +type BasicMapStructure struct { + Vunique string `mapstructure:"vunique"` + Vtime *time.Time `mapstructure:"time"` +} + +type NestedPointerWithMapstructure struct { + Vbar *BasicMapStructure `mapstructure:"vbar"` +} + +type EmbeddedPointerSquashWithNestedMapstructure struct { + *NestedPointerWithMapstructure `mapstructure:",squash"` + Vunique string +} + +type EmbeddedAndNamed struct { + Basic + Named Basic + Vunique string +} + +type SliceAlias []string + +type EmbeddedSlice struct { + SliceAlias `mapstructure:"slice_alias"` + Vunique string +} + +type ArrayAlias [2]string + +type EmbeddedArray struct { + ArrayAlias `mapstructure:"array_alias"` + Vunique string +} + +type SquashOnNonStructType struct { + InvalidSquashType int `mapstructure:",squash"` +} + +type Map struct { + Vfoo string + Vother map[string]string +} + +type MapOfStruct struct { + Value map[string]Basic +} + +type Nested struct { + Vfoo string + Vbar Basic +} + +type NestedPointer struct { + Vfoo string + Vbar *Basic +} + +type NilInterface struct { + W io.Writer +} + +type NilPointer struct { + Value *string +} + +type Slice struct { + Vfoo string + Vbar []string +} + +type SliceOfAlias struct { + Vfoo string + Vbar SliceAlias +} + +type SliceOfStruct struct { + Value []Basic +} + +type SlicePointer struct { + Vbar *[]string +} + +type Array struct { + Vfoo string + Vbar [2]string +} + +type ArrayOfStruct struct { + Value [2]Basic +} + +type Func struct { + Foo func() string +} + +type Tagged struct { + Extra string `mapstructure:"bar,what,what"` + Value string `mapstructure:"foo"` +} + +type Remainder struct { + A string + Extra map[string]interface{} `mapstructure:",remain"` +} + +type StructWithOmitEmpty struct { + VisibleStringField string `mapstructure:"visible-string"` + OmitStringField string `mapstructure:"omittable-string,omitempty"` + VisibleIntField int `mapstructure:"visible-int"` + OmitIntField int `mapstructure:"omittable-int,omitempty"` + VisibleFloatField float64 `mapstructure:"visible-float"` + OmitFloatField float64 `mapstructure:"omittable-float,omitempty"` + VisibleSliceField []interface{} `mapstructure:"visible-slice"` + OmitSliceField []interface{} `mapstructure:"omittable-slice,omitempty"` + VisibleMapField map[string]interface{} `mapstructure:"visible-map"` + OmitMapField map[string]interface{} `mapstructure:"omittable-map,omitempty"` + NestedField *Nested `mapstructure:"visible-nested"` + OmitNestedField *Nested `mapstructure:"omittable-nested,omitempty"` +} + +type TypeConversionResult struct { + IntToFloat float32 + IntToUint uint + IntToBool bool + IntToString string + UintToInt int + UintToFloat float32 + UintToBool bool + UintToString string + BoolToInt int + BoolToUint uint + BoolToFloat float32 + BoolToString string + FloatToInt int + FloatToUint uint + FloatToBool bool + FloatToString string + SliceUint8ToString string + StringToSliceUint8 []byte + ArrayUint8ToString string + StringToInt int + StringToUint uint + StringToBool bool + StringToFloat float32 + StringToStrSlice []string + StringToIntSlice []int + StringToStrArray [1]string + StringToIntArray [1]int + SliceToMap map[string]interface{} + MapToSlice []interface{} + ArrayToMap map[string]interface{} + MapToArray [1]interface{} +} + +func TestBasicTypes(t *testing.T) { + t.Parallel() + + input := map[string]interface{}{ + "vstring": "foo", + "vint": 42, + "vint8": 42, + "vint16": 42, + "vint32": 42, + "vint64": 42, + "Vuint": 42, + "vbool": true, + "Vfloat": 42.42, + "vsilent": true, + "vdata": 42, + "vjsonInt": json.Number("1234"), + "vjsonUint": json.Number("1234"), + "vjsonUint64": json.Number("9223372036854775809"), // 2^63 + 1 + "vjsonFloat": json.Number("1234.5"), + "vjsonNumber": json.Number("1234.5"), + } + + var result Basic + err := Decode(input, &result) + if err != nil { + t.Errorf("got an jderr: %s", err.Error()) + t.FailNow() + } + + if result.Vstring != "foo" { + t.Errorf("vstring value should be 'foo': %#v", result.Vstring) + } + + if result.Vint != 42 { + t.Errorf("vint value should be 42: %#v", result.Vint) + } + if result.Vint8 != 42 { + t.Errorf("vint8 value should be 42: %#v", result.Vint) + } + if result.Vint16 != 42 { + t.Errorf("vint16 value should be 42: %#v", result.Vint) + } + if result.Vint32 != 42 { + t.Errorf("vint32 value should be 42: %#v", result.Vint) + } + if result.Vint64 != 42 { + t.Errorf("vint64 value should be 42: %#v", result.Vint) + } + + if result.Vuint != 42 { + t.Errorf("vuint value should be 42: %#v", result.Vuint) + } + + if result.Vbool != true { + t.Errorf("vbool value should be true: %#v", result.Vbool) + } + + if result.Vfloat != 42.42 { + t.Errorf("vfloat value should be 42.42: %#v", result.Vfloat) + } + + if result.Vextra != "" { + t.Errorf("vextra value should be empty: %#v", result.Vextra) + } + + if result.vsilent != false { + t.Error("vsilent should not be set, it is unexported") + } + + if result.Vdata != 42 { + t.Error("vdata should be valid") + } + + if result.VjsonInt != 1234 { + t.Errorf("vjsonint value should be 1234: %#v", result.VjsonInt) + } + + if result.VjsonUint != 1234 { + t.Errorf("vjsonuint value should be 1234: %#v", result.VjsonUint) + } + + if result.VjsonUint64 != 9223372036854775809 { + t.Errorf("vjsonuint64 value should be 9223372036854775809: %#v", result.VjsonUint64) + } + + if result.VjsonFloat != 1234.5 { + t.Errorf("vjsonfloat value should be 1234.5: %#v", result.VjsonFloat) + } + + if !reflect.DeepEqual(result.VjsonNumber, json.Number("1234.5")) { + t.Errorf("vjsonnumber value should be '1234.5': %T, %#v", result.VjsonNumber, result.VjsonNumber) + } +} + +func TestBasic_IntWithFloat(t *testing.T) { + t.Parallel() + + input := map[string]interface{}{ + "vint": float64(42), + } + + var result Basic + err := Decode(input, &result) + if err != nil { + t.Fatalf("got an jderr: %s", err) + } +} + +func TestBasic_Merge(t *testing.T) { + t.Parallel() + + input := map[string]interface{}{ + "vint": 42, + } + + var result Basic + result.Vuint = 100 + err := Decode(input, &result) + if err != nil { + t.Fatalf("got an jderr: %s", err) + } + + expected := Basic{ + Vint: 42, + Vuint: 100, + } + if !reflect.DeepEqual(result, expected) { + t.Fatalf("bad: %#v", result) + } +} + +// Test for issue #46. +func TestBasic_Struct(t *testing.T) { + t.Parallel() + + input := map[string]interface{}{ + "vdata": map[string]interface{}{ + "vstring": "foo", + }, + } + + var result, inner Basic + result.Vdata = &inner + err := Decode(input, &result) + if err != nil { + t.Fatalf("got an jderr: %s", err) + } + expected := Basic{ + Vdata: &Basic{ + Vstring: "foo", + }, + } + if !reflect.DeepEqual(result, expected) { + t.Fatalf("bad: %#v", result) + } +} + +func TestBasic_interfaceStruct(t *testing.T) { + t.Parallel() + + input := map[string]interface{}{ + "vstring": "foo", + } + + var iface interface{} = &Basic{} + err := Decode(input, &iface) + if err != nil { + t.Fatalf("got an jderr: %s", err) + } + + expected := &Basic{ + Vstring: "foo", + } + if !reflect.DeepEqual(iface, expected) { + t.Fatalf("bad: %#v", iface) + } +} + +// Issue 187 +func TestBasic_interfaceStructNonPtr(t *testing.T) { + t.Parallel() + + input := map[string]interface{}{ + "vstring": "foo", + } + + var iface interface{} = Basic{} + err := Decode(input, &iface) + if err != nil { + t.Fatalf("got an jderr: %s", err) + } + + expected := Basic{ + Vstring: "foo", + } + if !reflect.DeepEqual(iface, expected) { + t.Fatalf("bad: %#v", iface) + } +} + +func TestDecode_BasicSquash(t *testing.T) { + t.Parallel() + + input := map[string]interface{}{ + "vstring": "foo", + } + + var result BasicSquash + err := Decode(input, &result) + if err != nil { + t.Fatalf("got an jderr: %s", err.Error()) + } + + if result.Test.Vstring != "foo" { + t.Errorf("vstring value should be 'foo': %#v", result.Test.Vstring) + } +} + +func TestDecodeFrom_BasicSquash(t *testing.T) { + t.Parallel() + + var v interface{} + var ok bool + + input := BasicSquash{ + Test: Basic{ + Vstring: "foo", + }, + } + + var result map[string]interface{} + err := Decode(input, &result) + if err != nil { + t.Fatalf("got an jderr: %s", err.Error()) + } + + if _, ok = result["Test"]; ok { + t.Error("test should not be present in map") + } + + v, ok = result["Vstring"] + if !ok { + t.Error("vstring should be present in map") + } else if !reflect.DeepEqual(v, "foo") { + t.Errorf("vstring value should be 'foo': %#v", v) + } +} + +func TestDecode_Embedded(t *testing.T) { + t.Parallel() + + input := map[string]interface{}{ + "vstring": "foo", + "Basic": map[string]interface{}{ + "vstring": "innerfoo", + }, + "vunique": "bar", + } + + var result Embedded + err := Decode(input, &result) + if err != nil { + t.Fatalf("got an jderr: %s", err.Error()) + } + + if result.Vstring != "innerfoo" { + t.Errorf("vstring value should be 'innerfoo': %#v", result.Vstring) + } + + if result.Vunique != "bar" { + t.Errorf("vunique value should be 'bar': %#v", result.Vunique) + } +} + +func TestDecode_EmbeddedPointer(t *testing.T) { + t.Parallel() + + input := map[string]interface{}{ + "vstring": "foo", + "Basic": map[string]interface{}{ + "vstring": "innerfoo", + }, + "vunique": "bar", + } + + var result EmbeddedPointer + err := Decode(input, &result) + if err != nil { + t.Fatalf("jderr: %s", err) + } + + expected := EmbeddedPointer{ + Basic: &Basic{ + Vstring: "innerfoo", + }, + Vunique: "bar", + } + if !reflect.DeepEqual(result, expected) { + t.Fatalf("bad: %#v", result) + } +} + +func TestDecode_EmbeddedSlice(t *testing.T) { + t.Parallel() + + input := map[string]interface{}{ + "slice_alias": []string{"foo", "bar"}, + "vunique": "bar", + } + + var result EmbeddedSlice + err := Decode(input, &result) + if err != nil { + t.Fatalf("got an jderr: %s", err.Error()) + } + + if !reflect.DeepEqual(result.SliceAlias, SliceAlias([]string{"foo", "bar"})) { + t.Errorf("slice value: %#v", result.SliceAlias) + } + + if result.Vunique != "bar" { + t.Errorf("vunique value should be 'bar': %#v", result.Vunique) + } +} + +func TestDecode_EmbeddedArray(t *testing.T) { + t.Parallel() + + input := map[string]interface{}{ + "array_alias": [2]string{"foo", "bar"}, + "vunique": "bar", + } + + var result EmbeddedArray + err := Decode(input, &result) + if err != nil { + t.Fatalf("got an jderr: %s", err.Error()) + } + + if !reflect.DeepEqual(result.ArrayAlias, ArrayAlias([2]string{"foo", "bar"})) { + t.Errorf("array value: %#v", result.ArrayAlias) + } + + if result.Vunique != "bar" { + t.Errorf("vunique value should be 'bar': %#v", result.Vunique) + } +} + +func TestDecode_decodeSliceWithArray(t *testing.T) { + t.Parallel() + + var result []int + input := [1]int{1} + expected := []int{1} + if err := Decode(input, &result); err != nil { + t.Fatalf("got an jderr: %s", err.Error()) + } + + if !reflect.DeepEqual(expected, result) { + t.Errorf("wanted %+v, got %+v", expected, result) + } +} + +func TestDecode_EmbeddedNoSquash(t *testing.T) { + t.Parallel() + + input := map[string]interface{}{ + "vstring": "foo", + "vunique": "bar", + } + + var result Embedded + err := Decode(input, &result) + if err != nil { + t.Fatalf("got an jderr: %s", err.Error()) + } + + if result.Vstring != "" { + t.Errorf("vstring value should be empty: %#v", result.Vstring) + } + + if result.Vunique != "bar" { + t.Errorf("vunique value should be 'bar': %#v", result.Vunique) + } +} + +func TestDecode_EmbeddedPointerNoSquash(t *testing.T) { + t.Parallel() + + input := map[string]interface{}{ + "vstring": "foo", + "vunique": "bar", + } + + result := EmbeddedPointer{ + Basic: &Basic{}, + } + + err := Decode(input, &result) + if err != nil { + t.Fatalf("jderr: %s", err) + } + + if result.Vstring != "" { + t.Errorf("vstring value should be empty: %#v", result.Vstring) + } + + if result.Vunique != "bar" { + t.Errorf("vunique value should be 'bar': %#v", result.Vunique) + } +} + +func TestDecode_EmbeddedSquash(t *testing.T) { + t.Parallel() + + input := map[string]interface{}{ + "vstring": "foo", + "vunique": "bar", + } + + var result EmbeddedSquash + err := Decode(input, &result) + if err != nil { + t.Fatalf("got an jderr: %s", err.Error()) + } + + if result.Vstring != "foo" { + t.Errorf("vstring value should be 'foo': %#v", result.Vstring) + } + + if result.Vunique != "bar" { + t.Errorf("vunique value should be 'bar': %#v", result.Vunique) + } +} + +func TestDecodeFrom_EmbeddedSquash(t *testing.T) { + t.Parallel() + + var v interface{} + var ok bool + + input := EmbeddedSquash{ + Basic: Basic{ + Vstring: "foo", + }, + Vunique: "bar", + } + + var result map[string]interface{} + err := Decode(input, &result) + if err != nil { + t.Fatalf("got an jderr: %s", err.Error()) + } + + if _, ok = result["Basic"]; ok { + t.Error("basic should not be present in map") + } + + v, ok = result["Vstring"] + if !ok { + t.Error("vstring should be present in map") + } else if !reflect.DeepEqual(v, "foo") { + t.Errorf("vstring value should be 'foo': %#v", v) + } + + v, ok = result["Vunique"] + if !ok { + t.Error("vunique should be present in map") + } else if !reflect.DeepEqual(v, "bar") { + t.Errorf("vunique value should be 'bar': %#v", v) + } +} + +func TestDecode_EmbeddedPointerSquash_FromStructToMap(t *testing.T) { + t.Parallel() + + input := EmbeddedPointerSquash{ + Basic: &Basic{ + Vstring: "foo", + }, + Vunique: "bar", + } + + var result map[string]interface{} + err := Decode(input, &result) + if err != nil { + t.Fatalf("got an jderr: %s", err.Error()) + } + + if result["Vstring"] != "foo" { + t.Errorf("vstring value should be 'foo': %#v", result["Vstring"]) + } + + if result["Vunique"] != "bar" { + t.Errorf("vunique value should be 'bar': %#v", result["Vunique"]) + } +} + +func TestDecode_EmbeddedPointerSquash_FromMapToStruct(t *testing.T) { + t.Parallel() + + input := map[string]interface{}{ + "Vstring": "foo", + "Vunique": "bar", + } + + result := EmbeddedPointerSquash{ + Basic: &Basic{}, + } + err := Decode(input, &result) + if err != nil { + t.Fatalf("got an jderr: %s", err.Error()) + } + + if result.Vstring != "foo" { + t.Errorf("vstring value should be 'foo': %#v", result.Vstring) + } + + if result.Vunique != "bar" { + t.Errorf("vunique value should be 'bar': %#v", result.Vunique) + } +} + +func TestDecode_EmbeddedPointerSquashWithNestedMapstructure_FromStructToMap(t *testing.T) { + t.Parallel() + + vTime := time.Now() + + input := EmbeddedPointerSquashWithNestedMapstructure{ + NestedPointerWithMapstructure: &NestedPointerWithMapstructure{ + Vbar: &BasicMapStructure{ + Vunique: "bar", + Vtime: &vTime, + }, + }, + Vunique: "foo", + } + + var result map[string]interface{} + err := Decode(input, &result) + if err != nil { + t.Fatalf("got an jderr: %s", err.Error()) + } + expected := map[string]interface{}{ + "vbar": map[string]interface{}{ + "vunique": "bar", + "time": &vTime, + }, + "Vunique": "foo", + } + + if !reflect.DeepEqual(result, expected) { + t.Errorf("result should be %#v: got %#v", expected, result) + } +} + +func TestDecode_EmbeddedPointerSquashWithNestedMapstructure_FromMapToStruct(t *testing.T) { + t.Parallel() + + vTime := time.Now() + + input := map[string]interface{}{ + "vbar": map[string]interface{}{ + "vunique": "bar", + "time": &vTime, + }, + "Vunique": "foo", + } + + result := EmbeddedPointerSquashWithNestedMapstructure{ + NestedPointerWithMapstructure: &NestedPointerWithMapstructure{}, + } + err := Decode(input, &result) + if err != nil { + t.Fatalf("got an jderr: %s", err.Error()) + } + expected := EmbeddedPointerSquashWithNestedMapstructure{ + NestedPointerWithMapstructure: &NestedPointerWithMapstructure{ + Vbar: &BasicMapStructure{ + Vunique: "bar", + Vtime: &vTime, + }, + }, + Vunique: "foo", + } + + if !reflect.DeepEqual(result, expected) { + t.Errorf("result should be %#v: got %#v", expected, result) + } +} + +func TestDecode_EmbeddedSquashConfig(t *testing.T) { + t.Parallel() + + input := map[string]interface{}{ + "vstring": "foo", + "vunique": "bar", + "Named": map[string]interface{}{ + "vstring": "baz", + }, + } + + var result EmbeddedAndNamed + config := &DecoderConfig{ + Squash: true, + Result: &result, + } + + decoder, err := NewDecoder(config) + if err != nil { + t.Fatalf("jderr: %s", err) + } + + err = decoder.Decode(input) + if err != nil { + t.Fatalf("got an jderr: %s", err) + } + + if result.Vstring != "foo" { + t.Errorf("vstring value should be 'foo': %#v", result.Vstring) + } + + if result.Vunique != "bar" { + t.Errorf("vunique value should be 'bar': %#v", result.Vunique) + } + + if result.Named.Vstring != "baz" { + t.Errorf("Named.vstring value should be 'baz': %#v", result.Named.Vstring) + } +} + +func TestDecodeFrom_EmbeddedSquashConfig(t *testing.T) { + t.Parallel() + + input := EmbeddedAndNamed{ + Basic: Basic{Vstring: "foo"}, + Named: Basic{Vstring: "baz"}, + Vunique: "bar", + } + + result := map[string]interface{}{} + config := &DecoderConfig{ + Squash: true, + Result: &result, + } + decoder, err := NewDecoder(config) + if err != nil { + t.Fatalf("got an jderr: %s", err.Error()) + } + + err = decoder.Decode(input) + if err != nil { + t.Fatalf("got an jderr: %s", err.Error()) + } + + if _, ok := result["Basic"]; ok { + t.Error("basic should not be present in map") + } + + v, ok := result["Vstring"] + if !ok { + t.Error("vstring should be present in map") + } else if !reflect.DeepEqual(v, "foo") { + t.Errorf("vstring value should be 'foo': %#v", v) + } + + v, ok = result["Vunique"] + if !ok { + t.Error("vunique should be present in map") + } else if !reflect.DeepEqual(v, "bar") { + t.Errorf("vunique value should be 'bar': %#v", v) + } + + v, ok = result["Named"] + if !ok { + t.Error("Named should be present in map") + } else { + named := v.(map[string]interface{}) + v, ok := named["Vstring"] + if !ok { + t.Error("Named: vstring should be present in map") + } else if !reflect.DeepEqual(v, "baz") { + t.Errorf("Named: vstring should be 'baz': %#v", v) + } + } +} + +func TestDecodeFrom_EmbeddedSquashConfig_WithTags(t *testing.T) { + t.Parallel() + + var v interface{} + var ok bool + + input := EmbeddedSquash{ + Basic: Basic{ + Vstring: "foo", + }, + Vunique: "bar", + } + + result := map[string]interface{}{} + config := &DecoderConfig{ + Squash: true, + Result: &result, + } + decoder, err := NewDecoder(config) + if err != nil { + t.Fatalf("got an jderr: %s", err.Error()) + } + + err = decoder.Decode(input) + if err != nil { + t.Fatalf("got an jderr: %s", err.Error()) + } + + if _, ok = result["Basic"]; ok { + t.Error("basic should not be present in map") + } + + v, ok = result["Vstring"] + if !ok { + t.Error("vstring should be present in map") + } else if !reflect.DeepEqual(v, "foo") { + t.Errorf("vstring value should be 'foo': %#v", v) + } + + v, ok = result["Vunique"] + if !ok { + t.Error("vunique should be present in map") + } else if !reflect.DeepEqual(v, "bar") { + t.Errorf("vunique value should be 'bar': %#v", v) + } +} + +func TestDecode_SquashOnNonStructType(t *testing.T) { + t.Parallel() + + input := map[string]interface{}{ + "InvalidSquashType": 42, + } + + var result SquashOnNonStructType + err := Decode(input, &result) + if err == nil { + t.Fatal("unexpected success decoding invalid squash field type") + } else if !strings.Contains(err.Error(), "unsupported type for squash") { + t.Fatalf("unexpected error message for invalid squash field type: %s", err) + } +} + +func TestDecode_DecodeHook(t *testing.T) { + t.Parallel() + + input := map[string]interface{}{ + "vint": "WHAT", + } + + decodeHook := func(from reflect.Kind, to reflect.Kind, v interface{}) (interface{}, error) { + if from == reflect.String && to != reflect.String { + return 5, nil + } + + return v, nil + } + + var result Basic + config := &DecoderConfig{ + DecodeHook: decodeHook, + Result: &result, + } + + decoder, err := NewDecoder(config) + if err != nil { + t.Fatalf("jderr: %s", err) + } + + err = decoder.Decode(input) + if err != nil { + t.Fatalf("got an jderr: %s", err) + } + + if result.Vint != 5 { + t.Errorf("vint should be 5: %#v", result.Vint) + } +} + +func TestDecode_DecodeHookType(t *testing.T) { + t.Parallel() + + input := map[string]interface{}{ + "vint": "WHAT", + } + + decodeHook := func(from reflect.Type, to reflect.Type, v interface{}) (interface{}, error) { + if from.Kind() == reflect.String && + to.Kind() != reflect.String { + return 5, nil + } + + return v, nil + } + + var result Basic + config := &DecoderConfig{ + DecodeHook: decodeHook, + Result: &result, + } + + decoder, err := NewDecoder(config) + if err != nil { + t.Fatalf("jderr: %s", err) + } + + err = decoder.Decode(input) + if err != nil { + t.Fatalf("got an jderr: %s", err) + } + + if result.Vint != 5 { + t.Errorf("vint should be 5: %#v", result.Vint) + } +} + +func TestDecode_Nil(t *testing.T) { + t.Parallel() + + var input interface{} + result := Basic{ + Vstring: "foo", + } + + err := Decode(input, &result) + if err != nil { + t.Fatalf("jderr: %s", err) + } + + if result.Vstring != "foo" { + t.Fatalf("bad: %#v", result.Vstring) + } +} + +func TestDecode_NilInterfaceHook(t *testing.T) { + t.Parallel() + + input := map[string]interface{}{ + "w": "", + } + + decodeHook := func(f, t reflect.Type, v interface{}) (interface{}, error) { + if t.String() == "io.Writer" { + return nil, nil + } + + return v, nil + } + + var result NilInterface + config := &DecoderConfig{ + DecodeHook: decodeHook, + Result: &result, + } + + decoder, err := NewDecoder(config) + if err != nil { + t.Fatalf("jderr: %s", err) + } + + err = decoder.Decode(input) + if err != nil { + t.Fatalf("got an jderr: %s", err) + } + + if result.W != nil { + t.Errorf("W should be nil: %#v", result.W) + } +} + +func TestDecode_NilPointerHook(t *testing.T) { + t.Parallel() + + input := map[string]interface{}{ + "value": "", + } + + decodeHook := func(f, t reflect.Type, v interface{}) (interface{}, error) { + if typed, ok := v.(string); ok { + if typed == "" { + return nil, nil + } + } + return v, nil + } + + var result NilPointer + config := &DecoderConfig{ + DecodeHook: decodeHook, + Result: &result, + } + + decoder, err := NewDecoder(config) + if err != nil { + t.Fatalf("jderr: %s", err) + } + + err = decoder.Decode(input) + if err != nil { + t.Fatalf("got an jderr: %s", err) + } + + if result.Value != nil { + t.Errorf("W should be nil: %#v", result.Value) + } +} + +func TestDecode_FuncHook(t *testing.T) { + t.Parallel() + + input := map[string]interface{}{ + "foo": "baz", + } + + decodeHook := func(f, t reflect.Type, v interface{}) (interface{}, error) { + if t.Kind() != reflect.Func { + return v, nil + } + val := v.(string) + return func() string { return val }, nil + } + + var result Func + config := &DecoderConfig{ + DecodeHook: decodeHook, + Result: &result, + } + + decoder, err := NewDecoder(config) + if err != nil { + t.Fatalf("jderr: %s", err) + } + + err = decoder.Decode(input) + if err != nil { + t.Fatalf("got an jderr: %s", err) + } + + if result.Foo() != "baz" { + t.Errorf("Foo call result should be 'baz': %s", result.Foo()) + } +} + +func TestDecode_NonStruct(t *testing.T) { + t.Parallel() + + input := map[string]interface{}{ + "foo": "bar", + "bar": "baz", + } + + var result map[string]string + err := Decode(input, &result) + if err != nil { + t.Fatalf("jderr: %s", err) + } + + if result["foo"] != "bar" { + t.Fatal("foo is not bar") + } +} + +func TestDecode_StructMatch(t *testing.T) { + t.Parallel() + + input := map[string]interface{}{ + "vbar": Basic{ + Vstring: "foo", + }, + } + + var result Nested + err := Decode(input, &result) + if err != nil { + t.Fatalf("got an jderr: %s", err.Error()) + } + + if result.Vbar.Vstring != "foo" { + t.Errorf("bad: %#v", result) + } +} + +func TestDecode_TypeConversion(t *testing.T) { + input := map[string]interface{}{ + "IntToFloat": 42, + "IntToUint": 42, + "IntToBool": 1, + "IntToString": 42, + "UintToInt": 42, + "UintToFloat": 42, + "UintToBool": 42, + "UintToString": 42, + "BoolToInt": true, + "BoolToUint": true, + "BoolToFloat": true, + "BoolToString": true, + "FloatToInt": 42.42, + "FloatToUint": 42.42, + "FloatToBool": 42.42, + "FloatToString": 42.42, + "SliceUint8ToString": []uint8("foo"), + "StringToSliceUint8": "foo", + "ArrayUint8ToString": [3]uint8{'f', 'o', 'o'}, + "StringToInt": "42", + "StringToUint": "42", + "StringToBool": "1", + "StringToFloat": "42.42", + "StringToStrSlice": "A", + "StringToIntSlice": "42", + "StringToStrArray": "A", + "StringToIntArray": "42", + "SliceToMap": []interface{}{}, + "MapToSlice": map[string]interface{}{}, + "ArrayToMap": []interface{}{}, + "MapToArray": map[string]interface{}{}, + } + + expectedResultStrict := TypeConversionResult{ + IntToFloat: 42.0, + IntToUint: 42, + UintToInt: 42, + UintToFloat: 42, + BoolToInt: 0, + BoolToUint: 0, + BoolToFloat: 0, + FloatToInt: 42, + FloatToUint: 42, + } + + expectedResultWeak := TypeConversionResult{ + IntToFloat: 42.0, + IntToUint: 42, + IntToBool: true, + IntToString: "42", + UintToInt: 42, + UintToFloat: 42, + UintToBool: true, + UintToString: "42", + BoolToInt: 1, + BoolToUint: 1, + BoolToFloat: 1, + BoolToString: "1", + FloatToInt: 42, + FloatToUint: 42, + FloatToBool: true, + FloatToString: "42.42", + SliceUint8ToString: "foo", + StringToSliceUint8: []byte("foo"), + ArrayUint8ToString: "foo", + StringToInt: 42, + StringToUint: 42, + StringToBool: true, + StringToFloat: 42.42, + StringToStrSlice: []string{"A"}, + StringToIntSlice: []int{42}, + StringToStrArray: [1]string{"A"}, + StringToIntArray: [1]int{42}, + SliceToMap: map[string]interface{}{}, + MapToSlice: []interface{}{}, + ArrayToMap: map[string]interface{}{}, + MapToArray: [1]interface{}{}, + } + + // Test strict type conversion + var resultStrict TypeConversionResult + err := Decode(input, &resultStrict) + if err == nil { + t.Errorf("should return an error") + } + if !reflect.DeepEqual(resultStrict, expectedResultStrict) { + t.Errorf("expected %v, got: %v", expectedResultStrict, resultStrict) + } + + // Test weak type conversion + var decoder *Decoder + var resultWeak TypeConversionResult + + config := &DecoderConfig{ + WeaklyTypedInput: true, + Result: &resultWeak, + } + + decoder, err = NewDecoder(config) + if err != nil { + t.Fatalf("jderr: %s", err) + } + + err = decoder.Decode(input) + if err != nil { + t.Fatalf("got an jderr: %s", err) + } + + if !reflect.DeepEqual(resultWeak, expectedResultWeak) { + t.Errorf("expected \n%#v, got: \n%#v", expectedResultWeak, resultWeak) + } +} + +func TestDecoder_ErrorUnused(t *testing.T) { + t.Parallel() + + input := map[string]interface{}{ + "vstring": "hello", + "foo": "bar", + } + + var result Basic + config := &DecoderConfig{ + ErrorUnused: true, + Result: &result, + } + + decoder, err := NewDecoder(config) + if err != nil { + t.Fatalf("jderr: %s", err) + } + + err = decoder.Decode(input) + if err == nil { + t.Fatal("expected error") + } +} + +func TestDecoder_ErrorUnused_NotSetable(t *testing.T) { + t.Parallel() + + // lowercase vsilent is unexported and cannot be set + input := map[string]interface{}{ + "vsilent": "false", + } + + var result Basic + config := &DecoderConfig{ + ErrorUnused: true, + Result: &result, + } + + decoder, err := NewDecoder(config) + if err != nil { + t.Fatalf("jderr: %s", err) + } + + err = decoder.Decode(input) + if err == nil { + t.Fatal("expected error") + } +} +func TestDecoder_ErrorUnset(t *testing.T) { + t.Parallel() + + input := map[string]interface{}{ + "vstring": "hello", + "foo": "bar", + } + + var result Basic + config := &DecoderConfig{ + ErrorUnset: true, + Result: &result, + } + + decoder, err := NewDecoder(config) + if err != nil { + t.Fatalf("jderr: %s", err) + } + + err = decoder.Decode(input) + if err == nil { + t.Fatal("expected error") + } +} + +func TestMap(t *testing.T) { + t.Parallel() + + input := map[string]interface{}{ + "vfoo": "foo", + "vother": map[interface{}]interface{}{ + "foo": "foo", + "bar": "bar", + }, + } + + var result Map + err := Decode(input, &result) + if err != nil { + t.Fatalf("got an error: %s", err) + } + + if result.Vfoo != "foo" { + t.Errorf("vfoo value should be 'foo': %#v", result.Vfoo) + } + + if result.Vother == nil { + t.Fatal("vother should not be nil") + } + + if len(result.Vother) != 2 { + t.Error("vother should have two items") + } + + if result.Vother["foo"] != "foo" { + t.Errorf("'foo' key should be foo, got: %#v", result.Vother["foo"]) + } + + if result.Vother["bar"] != "bar" { + t.Errorf("'bar' key should be bar, got: %#v", result.Vother["bar"]) + } +} + +func TestMapMerge(t *testing.T) { + t.Parallel() + + input := map[string]interface{}{ + "vfoo": "foo", + "vother": map[interface{}]interface{}{ + "foo": "foo", + "bar": "bar", + }, + } + + var result Map + result.Vother = map[string]string{"hello": "world"} + err := Decode(input, &result) + if err != nil { + t.Fatalf("got an error: %s", err) + } + + if result.Vfoo != "foo" { + t.Errorf("vfoo value should be 'foo': %#v", result.Vfoo) + } + + expected := map[string]string{ + "foo": "foo", + "bar": "bar", + "hello": "world", + } + if !reflect.DeepEqual(result.Vother, expected) { + t.Errorf("bad: %#v", result.Vother) + } +} + +func TestMapOfStruct(t *testing.T) { + t.Parallel() + + input := map[string]interface{}{ + "value": map[string]interface{}{ + "foo": map[string]string{"vstring": "one"}, + "bar": map[string]string{"vstring": "two"}, + }, + } + + var result MapOfStruct + err := Decode(input, &result) + if err != nil { + t.Fatalf("got an jderr: %s", err) + } + + if result.Value == nil { + t.Fatal("value should not be nil") + } + + if len(result.Value) != 2 { + t.Error("value should have two items") + } + + if result.Value["foo"].Vstring != "one" { + t.Errorf("foo value should be 'one', got: %s", result.Value["foo"].Vstring) + } + + if result.Value["bar"].Vstring != "two" { + t.Errorf("bar value should be 'two', got: %s", result.Value["bar"].Vstring) + } +} + +func TestNestedType(t *testing.T) { + t.Parallel() + + input := map[string]interface{}{ + "vfoo": "foo", + "vbar": map[string]interface{}{ + "vstring": "foo", + "vint": 42, + "vbool": true, + }, + } + + var result Nested + err := Decode(input, &result) + if err != nil { + t.Fatalf("got an jderr: %s", err.Error()) + } + + if result.Vfoo != "foo" { + t.Errorf("vfoo value should be 'foo': %#v", result.Vfoo) + } + + if result.Vbar.Vstring != "foo" { + t.Errorf("vstring value should be 'foo': %#v", result.Vbar.Vstring) + } + + if result.Vbar.Vint != 42 { + t.Errorf("vint value should be 42: %#v", result.Vbar.Vint) + } + + if result.Vbar.Vbool != true { + t.Errorf("vbool value should be true: %#v", result.Vbar.Vbool) + } + + if result.Vbar.Vextra != "" { + t.Errorf("vextra value should be empty: %#v", result.Vbar.Vextra) + } +} + +func TestNestedTypePointer(t *testing.T) { + t.Parallel() + + input := map[string]interface{}{ + "vfoo": "foo", + "vbar": &map[string]interface{}{ + "vstring": "foo", + "vint": 42, + "vbool": true, + }, + } + + var result NestedPointer + err := Decode(input, &result) + if err != nil { + t.Fatalf("got an jderr: %s", err.Error()) + } + + if result.Vfoo != "foo" { + t.Errorf("vfoo value should be 'foo': %#v", result.Vfoo) + } + + if result.Vbar.Vstring != "foo" { + t.Errorf("vstring value should be 'foo': %#v", result.Vbar.Vstring) + } + + if result.Vbar.Vint != 42 { + t.Errorf("vint value should be 42: %#v", result.Vbar.Vint) + } + + if result.Vbar.Vbool != true { + t.Errorf("vbool value should be true: %#v", result.Vbar.Vbool) + } + + if result.Vbar.Vextra != "" { + t.Errorf("vextra value should be empty: %#v", result.Vbar.Vextra) + } +} + +// Test for issue #46. +func TestNestedTypeInterface(t *testing.T) { + t.Parallel() + + input := map[string]interface{}{ + "vfoo": "foo", + "vbar": &map[string]interface{}{ + "vstring": "foo", + "vint": 42, + "vbool": true, + + "vdata": map[string]interface{}{ + "vstring": "bar", + }, + }, + } + + var result NestedPointer + result.Vbar = new(Basic) + result.Vbar.Vdata = new(Basic) + err := Decode(input, &result) + if err != nil { + t.Fatalf("got an jderr: %s", err.Error()) + } + + if result.Vfoo != "foo" { + t.Errorf("vfoo value should be 'foo': %#v", result.Vfoo) + } + + if result.Vbar.Vstring != "foo" { + t.Errorf("vstring value should be 'foo': %#v", result.Vbar.Vstring) + } + + if result.Vbar.Vint != 42 { + t.Errorf("vint value should be 42: %#v", result.Vbar.Vint) + } + + if result.Vbar.Vbool != true { + t.Errorf("vbool value should be true: %#v", result.Vbar.Vbool) + } + + if result.Vbar.Vextra != "" { + t.Errorf("vextra value should be empty: %#v", result.Vbar.Vextra) + } + + if result.Vbar.Vdata.(*Basic).Vstring != "bar" { + t.Errorf("vstring value should be 'bar': %#v", result.Vbar.Vdata.(*Basic).Vstring) + } +} + +func TestSlice(t *testing.T) { + t.Parallel() + + inputStringSlice := map[string]interface{}{ + "vfoo": "foo", + "vbar": []string{"foo", "bar", "baz"}, + } + + inputStringSlicePointer := map[string]interface{}{ + "vfoo": "foo", + "vbar": &[]string{"foo", "bar", "baz"}, + } + + outputStringSlice := &Slice{ + "foo", + []string{"foo", "bar", "baz"}, + } + + testSliceInput(t, inputStringSlice, outputStringSlice) + testSliceInput(t, inputStringSlicePointer, outputStringSlice) +} + +func TestInvalidSlice(t *testing.T) { + t.Parallel() + + input := map[string]interface{}{ + "vfoo": "foo", + "vbar": 42, + } + + result := Slice{} + err := Decode(input, &result) + if err == nil { + t.Errorf("expected failure") + } +} + +func TestSliceOfStruct(t *testing.T) { + t.Parallel() + + input := map[string]interface{}{ + "value": []map[string]interface{}{ + {"vstring": "one"}, + {"vstring": "two"}, + }, + } + + var result SliceOfStruct + err := Decode(input, &result) + if err != nil { + t.Fatalf("got unexpected error: %s", err) + } + + if len(result.Value) != 2 { + t.Fatalf("expected two values, got %d", len(result.Value)) + } + + if result.Value[0].Vstring != "one" { + t.Errorf("first value should be 'one', got: %s", result.Value[0].Vstring) + } + + if result.Value[1].Vstring != "two" { + t.Errorf("second value should be 'two', got: %s", result.Value[1].Vstring) + } +} + +func TestSliceCornerCases(t *testing.T) { + t.Parallel() + + // Input with a map with zero values + input := map[string]interface{}{} + var resultWeak []Basic + + err := WeakDecode(input, &resultWeak) + if err != nil { + t.Fatalf("got unexpected error: %s", err) + } + + if len(resultWeak) != 0 { + t.Errorf("length should be 0") + } + // Input with more values + input = map[string]interface{}{ + "Vstring": "foo", + } + + resultWeak = nil + err = WeakDecode(input, &resultWeak) + if err != nil { + t.Fatalf("got unexpected error: %s", err) + } + + if resultWeak[0].Vstring != "foo" { + t.Errorf("value does not match") + } +} + +func TestSliceToMap(t *testing.T) { + t.Parallel() + + input := []map[string]interface{}{ + { + "foo": "bar", + }, + { + "bar": "baz", + }, + } + + var result map[string]interface{} + err := WeakDecode(input, &result) + if err != nil { + t.Fatalf("got an error: %s", err) + } + + expected := map[string]interface{}{ + "foo": "bar", + "bar": "baz", + } + if !reflect.DeepEqual(result, expected) { + t.Errorf("bad: %#v", result) + } +} + +func TestArray(t *testing.T) { + t.Parallel() + + inputStringArray := map[string]interface{}{ + "vfoo": "foo", + "vbar": [2]string{"foo", "bar"}, + } + + inputStringArrayPointer := map[string]interface{}{ + "vfoo": "foo", + "vbar": &[2]string{"foo", "bar"}, + } + + outputStringArray := &Array{ + "foo", + [2]string{"foo", "bar"}, + } + + testArrayInput(t, inputStringArray, outputStringArray) + testArrayInput(t, inputStringArrayPointer, outputStringArray) +} + +func TestInvalidArray(t *testing.T) { + t.Parallel() + + input := map[string]interface{}{ + "vfoo": "foo", + "vbar": 42, + } + + result := Array{} + err := Decode(input, &result) + if err == nil { + t.Errorf("expected failure") + } +} + +func TestArrayOfStruct(t *testing.T) { + t.Parallel() + + input := map[string]interface{}{ + "value": []map[string]interface{}{ + {"vstring": "one"}, + {"vstring": "two"}, + }, + } + + var result ArrayOfStruct + err := Decode(input, &result) + if err != nil { + t.Fatalf("got unexpected error: %s", err) + } + + if len(result.Value) != 2 { + t.Fatalf("expected two values, got %d", len(result.Value)) + } + + if result.Value[0].Vstring != "one" { + t.Errorf("first value should be 'one', got: %s", result.Value[0].Vstring) + } + + if result.Value[1].Vstring != "two" { + t.Errorf("second value should be 'two', got: %s", result.Value[1].Vstring) + } +} + +func TestArrayToMap(t *testing.T) { + t.Parallel() + + input := []map[string]interface{}{ + { + "foo": "bar", + }, + { + "bar": "baz", + }, + } + + var result map[string]interface{} + err := WeakDecode(input, &result) + if err != nil { + t.Fatalf("got an error: %s", err) + } + + expected := map[string]interface{}{ + "foo": "bar", + "bar": "baz", + } + if !reflect.DeepEqual(result, expected) { + t.Errorf("bad: %#v", result) + } +} + +func TestDecodeTable(t *testing.T) { + t.Parallel() + + // We need to make new types so that we don't get the short-circuit + // copy functionality. We want to test the deep copying functionality. + type BasicCopy Basic + type NestedPointerCopy NestedPointer + type MapCopy Map + + tests := []struct { + name string + in interface{} + target interface{} + out interface{} + wantErr bool + }{ + { + "basic struct input", + &Basic{ + Vstring: "vstring", + Vint: 2, + Vint8: 2, + Vint16: 2, + Vint32: 2, + Vint64: 2, + Vuint: 3, + Vbool: true, + Vfloat: 4.56, + Vextra: "vextra", + vsilent: true, + Vdata: []byte("data"), + }, + &map[string]interface{}{}, + &map[string]interface{}{ + "Vstring": "vstring", + "Vint": 2, + "Vint8": int8(2), + "Vint16": int16(2), + "Vint32": int32(2), + "Vint64": int64(2), + "Vuint": uint(3), + "Vbool": true, + "Vfloat": 4.56, + "Vextra": "vextra", + "Vdata": []byte("data"), + "VjsonInt": 0, + "VjsonUint": uint(0), + "VjsonUint64": uint64(0), + "VjsonFloat": 0.0, + "VjsonNumber": json.Number(""), + }, + false, + }, + { + "embedded struct input", + &Embedded{ + Vunique: "vunique", + Basic: Basic{ + Vstring: "vstring", + Vint: 2, + Vint8: 2, + Vint16: 2, + Vint32: 2, + Vint64: 2, + Vuint: 3, + Vbool: true, + Vfloat: 4.56, + Vextra: "vextra", + vsilent: true, + Vdata: []byte("data"), + }, + }, + &map[string]interface{}{}, + &map[string]interface{}{ + "Vunique": "vunique", + "Basic": map[string]interface{}{ + "Vstring": "vstring", + "Vint": 2, + "Vint8": int8(2), + "Vint16": int16(2), + "Vint32": int32(2), + "Vint64": int64(2), + "Vuint": uint(3), + "Vbool": true, + "Vfloat": 4.56, + "Vextra": "vextra", + "Vdata": []byte("data"), + "VjsonInt": 0, + "VjsonUint": uint(0), + "VjsonUint64": uint64(0), + "VjsonFloat": 0.0, + "VjsonNumber": json.Number(""), + }, + }, + false, + }, + { + "struct => struct", + &Basic{ + Vstring: "vstring", + Vint: 2, + Vuint: 3, + Vbool: true, + Vfloat: 4.56, + Vextra: "vextra", + Vdata: []byte("data"), + vsilent: true, + }, + &BasicCopy{}, + &BasicCopy{ + Vstring: "vstring", + Vint: 2, + Vuint: 3, + Vbool: true, + Vfloat: 4.56, + Vextra: "vextra", + Vdata: []byte("data"), + }, + false, + }, + { + "struct => struct with pointers", + &NestedPointer{ + Vfoo: "hello", + Vbar: nil, + }, + &NestedPointerCopy{}, + &NestedPointerCopy{ + Vfoo: "hello", + }, + false, + }, + { + "basic pointer to non-pointer", + &BasicPointer{ + Vstring: stringPtr("vstring"), + Vint: intPtr(2), + Vuint: uintPtr(3), + Vbool: boolPtr(true), + Vfloat: floatPtr(4.56), + Vdata: interfacePtr([]byte("data")), + }, + &Basic{}, + &Basic{ + Vstring: "vstring", + Vint: 2, + Vuint: 3, + Vbool: true, + Vfloat: 4.56, + Vdata: []byte("data"), + }, + false, + }, + { + "slice non-pointer to pointer", + &Slice{}, + &SlicePointer{}, + &SlicePointer{}, + false, + }, + { + "slice non-pointer to pointer, zero field", + &Slice{}, + &SlicePointer{ + Vbar: &[]string{"yo"}, + }, + &SlicePointer{}, + false, + }, + { + "slice to slice alias", + &Slice{}, + &SliceOfAlias{}, + &SliceOfAlias{}, + false, + }, + { + "nil map to map", + &Map{}, + &MapCopy{}, + &MapCopy{}, + false, + }, + { + "nil map to non-empty map", + &Map{}, + &MapCopy{Vother: map[string]string{"foo": "bar"}}, + &MapCopy{}, + false, + }, + + { + "slice input - should error", + []string{"foo", "bar"}, + &map[string]interface{}{}, + &map[string]interface{}{}, + true, + }, + { + "struct with slice property", + &Slice{ + Vfoo: "vfoo", + Vbar: []string{"foo", "bar"}, + }, + &map[string]interface{}{}, + &map[string]interface{}{ + "Vfoo": "vfoo", + "Vbar": []string{"foo", "bar"}, + }, + false, + }, + { + "struct with empty slice", + &map[string]interface{}{ + "Vbar": []string{}, + }, + &Slice{}, + &Slice{ + Vbar: []string{}, + }, + false, + }, + { + "struct with slice of struct property", + &SliceOfStruct{ + Value: []Basic{ + Basic{ + Vstring: "vstring", + Vint: 2, + Vuint: 3, + Vbool: true, + Vfloat: 4.56, + Vextra: "vextra", + vsilent: true, + Vdata: []byte("data"), + }, + }, + }, + &map[string]interface{}{}, + &map[string]interface{}{ + "Value": []Basic{ + Basic{ + Vstring: "vstring", + Vint: 2, + Vuint: 3, + Vbool: true, + Vfloat: 4.56, + Vextra: "vextra", + vsilent: true, + Vdata: []byte("data"), + }, + }, + }, + false, + }, + { + "struct with map property", + &Map{ + Vfoo: "vfoo", + Vother: map[string]string{"vother": "vother"}, + }, + &map[string]interface{}{}, + &map[string]interface{}{ + "Vfoo": "vfoo", + "Vother": map[string]string{ + "vother": "vother", + }}, + false, + }, + { + "tagged struct", + &Tagged{ + Extra: "extra", + Value: "value", + }, + &map[string]string{}, + &map[string]string{ + "bar": "extra", + "foo": "value", + }, + false, + }, + { + "omit tag struct", + &struct { + Value string `mapstructure:"value"` + Omit string `mapstructure:"-"` + }{ + Value: "value", + Omit: "omit", + }, + &map[string]string{}, + &map[string]string{ + "value": "value", + }, + false, + }, + { + "decode to wrong map type", + &struct { + Value string + }{ + Value: "string", + }, + &map[string]int{}, + &map[string]int{}, + true, + }, + { + "remainder", + map[string]interface{}{ + "A": "hello", + "B": "goodbye", + "C": "yo", + }, + &Remainder{}, + &Remainder{ + A: "hello", + Extra: map[string]interface{}{ + "B": "goodbye", + "C": "yo", + }, + }, + false, + }, + { + "remainder with no extra", + map[string]interface{}{ + "A": "hello", + }, + &Remainder{}, + &Remainder{ + A: "hello", + Extra: nil, + }, + false, + }, + { + "struct with omitempty tag return non-empty values", + &struct { + VisibleField interface{} `mapstructure:"visible"` + OmitField interface{} `mapstructure:"omittable,omitempty"` + }{ + VisibleField: nil, + OmitField: "string", + }, + &map[string]interface{}{}, + &map[string]interface{}{"visible": nil, "omittable": "string"}, + false, + }, + { + "struct with omitempty tag ignore empty values", + &struct { + VisibleField interface{} `mapstructure:"visible"` + OmitField interface{} `mapstructure:"omittable,omitempty"` + }{ + VisibleField: nil, + OmitField: nil, + }, + &map[string]interface{}{}, + &map[string]interface{}{"visible": nil}, + false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if err := Decode(tt.in, tt.target); (err != nil) != tt.wantErr { + t.Fatalf("%q: TestMapOutputForStructuredInputs() unexpected error: %s", tt.name, err) + } + + if !reflect.DeepEqual(tt.out, tt.target) { + t.Fatalf("%q: TestMapOutputForStructuredInputs() expected: %#v, got: %#v", tt.name, tt.out, tt.target) + } + }) + } +} + +func TestInvalidType(t *testing.T) { + t.Parallel() + + input := map[string]interface{}{ + "vstring": 42, + } + + var result Basic + err := Decode(input, &result) + if err == nil { + t.Fatal("error should exist") + } + + derr, ok := err.(*Error) + if !ok { + t.Fatalf("error should be kind of Error, instead: %#v", err) + } + + if derr.Errors[0] != + "'Vstring' expected type 'string', got unconvertible type 'int', value: '42'" { + t.Errorf("got unexpected error: %s", err) + } + + inputNegIntUint := map[string]interface{}{ + "vuint": -42, + } + + err = Decode(inputNegIntUint, &result) + if err == nil { + t.Fatal("error should exist") + } + + derr, ok = err.(*Error) + if !ok { + t.Fatalf("error should be kind of Error, instead: %#v", err) + } + + if derr.Errors[0] != "cannot parse 'Vuint', -42 overflows uint" { + t.Errorf("got unexpected error: %s", err) + } + + inputNegFloatUint := map[string]interface{}{ + "vuint": -42.0, + } + + err = Decode(inputNegFloatUint, &result) + if err == nil { + t.Fatal("error should exist") + } + + derr, ok = err.(*Error) + if !ok { + t.Fatalf("error should be kind of Error, instead: %#v", err) + } + + if derr.Errors[0] != "cannot parse 'Vuint', -42.000000 overflows uint" { + t.Errorf("got unexpected error: %s", err) + } +} + +func TestDecodeMetadata(t *testing.T) { + t.Parallel() + + input := map[string]interface{}{ + "vfoo": "foo", + "vbar": map[string]interface{}{ + "vstring": "foo", + "Vuint": 42, + "vsilent": "false", + "foo": "bar", + }, + "bar": "nil", + } + + var md Metadata + var result Nested + + err := DecodeMetadata(input, &result, &md) + if err != nil { + t.Fatalf("jderr: %s", err.Error()) + } + + expectedKeys := []string{"Vbar", "Vbar.Vstring", "Vbar.Vuint", "Vfoo"} + sort.Strings(md.Keys) + if !reflect.DeepEqual(md.Keys, expectedKeys) { + t.Fatalf("bad keys: %#v", md.Keys) + } + + expectedUnused := []string{"Vbar.foo", "Vbar.vsilent", "bar"} + sort.Strings(md.Unused) + if !reflect.DeepEqual(md.Unused, expectedUnused) { + t.Fatalf("bad unused: %#v", md.Unused) + } +} + +func TestMetadata(t *testing.T) { + t.Parallel() + + type testResult struct { + Vfoo string + Vbar BasicPointer + } + + input := map[string]interface{}{ + "vfoo": "foo", + "vbar": map[string]interface{}{ + "vstring": "foo", + "Vuint": 42, + "vsilent": "false", + "foo": "bar", + }, + "bar": "nil", + } + + var md Metadata + var result testResult + config := &DecoderConfig{ + Metadata: &md, + Result: &result, + } + + decoder, err := NewDecoder(config) + if err != nil { + t.Fatalf("jderr: %s", err) + } + + err = decoder.Decode(input) + if err != nil { + t.Fatalf("jderr: %s", err.Error()) + } + + expectedKeys := []string{"Vbar", "Vbar.Vstring", "Vbar.Vuint", "Vfoo"} + sort.Strings(md.Keys) + if !reflect.DeepEqual(md.Keys, expectedKeys) { + t.Fatalf("bad keys: %#v", md.Keys) + } + + expectedUnused := []string{"Vbar.foo", "Vbar.vsilent", "bar"} + sort.Strings(md.Unused) + if !reflect.DeepEqual(md.Unused, expectedUnused) { + t.Fatalf("bad unused: %#v", md.Unused) + } + + expectedUnset := []string{ + "Vbar.Vbool", "Vbar.Vdata", "Vbar.Vextra", "Vbar.Vfloat", "Vbar.Vint", + "Vbar.VjsonFloat", "Vbar.VjsonInt", "Vbar.VjsonNumber"} + sort.Strings(md.Unset) + if !reflect.DeepEqual(md.Unset, expectedUnset) { + t.Fatalf("bad unset: %#v", md.Unset) + } +} + +func TestMetadata_Embedded(t *testing.T) { + t.Parallel() + + input := map[string]interface{}{ + "vstring": "foo", + "vunique": "bar", + } + + var md Metadata + var result EmbeddedSquash + config := &DecoderConfig{ + Metadata: &md, + Result: &result, + } + + decoder, err := NewDecoder(config) + if err != nil { + t.Fatalf("jderr: %s", err) + } + + err = decoder.Decode(input) + if err != nil { + t.Fatalf("jderr: %s", err.Error()) + } + + expectedKeys := []string{"Vstring", "Vunique"} + + sort.Strings(md.Keys) + if !reflect.DeepEqual(md.Keys, expectedKeys) { + t.Fatalf("bad keys: %#v", md.Keys) + } + + expectedUnused := []string{} + if !reflect.DeepEqual(md.Unused, expectedUnused) { + t.Fatalf("bad unused: %#v", md.Unused) + } +} + +func TestNonPtrValue(t *testing.T) { + t.Parallel() + + err := Decode(map[string]interface{}{}, Basic{}) + if err == nil { + t.Fatal("error should exist") + } + + if err.Error() != "result must be a pointer" { + t.Errorf("got unexpected error: %s", err) + } +} + +func TestTagged(t *testing.T) { + t.Parallel() + + input := map[string]interface{}{ + "foo": "bar", + "bar": "value", + } + + var result Tagged + err := Decode(input, &result) + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + + if result.Value != "bar" { + t.Errorf("value should be 'bar', got: %#v", result.Value) + } + + if result.Extra != "value" { + t.Errorf("extra should be 'value', got: %#v", result.Extra) + } +} + +func TestWeakDecode(t *testing.T) { + t.Parallel() + + input := map[string]interface{}{ + "foo": "4", + "bar": "value", + } + + var result struct { + Foo int + Bar string + } + + if err := WeakDecode(input, &result); err != nil { + t.Fatalf("jderr: %s", err) + } + if result.Foo != 4 { + t.Fatalf("bad: %#v", result) + } + if result.Bar != "value" { + t.Fatalf("bad: %#v", result) + } +} + +func TestWeakDecodeMetadata(t *testing.T) { + t.Parallel() + + input := map[string]interface{}{ + "foo": "4", + "bar": "value", + "unused": "value", + "unexported": "value", + } + + var md Metadata + var result struct { + Foo int + Bar string + unexported string + } + + if err := WeakDecodeMetadata(input, &result, &md); err != nil { + t.Fatalf("jderr: %s", err) + } + if result.Foo != 4 { + t.Fatalf("bad: %#v", result) + } + if result.Bar != "value" { + t.Fatalf("bad: %#v", result) + } + + expectedKeys := []string{"Bar", "Foo"} + sort.Strings(md.Keys) + if !reflect.DeepEqual(md.Keys, expectedKeys) { + t.Fatalf("bad keys: %#v", md.Keys) + } + + expectedUnused := []string{"unexported", "unused"} + sort.Strings(md.Unused) + if !reflect.DeepEqual(md.Unused, expectedUnused) { + t.Fatalf("bad unused: %#v", md.Unused) + } +} + +func TestDecode_StructTaggedWithOmitempty_OmitEmptyValues(t *testing.T) { + t.Parallel() + + input := &StructWithOmitEmpty{} + + var emptySlice []interface{} + var emptyMap map[string]interface{} + var emptyNested *Nested + expected := &map[string]interface{}{ + "visible-string": "", + "visible-int": 0, + "visible-float": 0.0, + "visible-slice": emptySlice, + "visible-map": emptyMap, + "visible-nested": emptyNested, + } + + actual := &map[string]interface{}{} + Decode(input, actual) + + if !reflect.DeepEqual(actual, expected) { + t.Fatalf("Decode() expected: %#v, got: %#v", expected, actual) + } +} + +func TestDecode_StructTaggedWithOmitempty_KeepNonEmptyValues(t *testing.T) { + t.Parallel() + + input := &StructWithOmitEmpty{ + VisibleStringField: "", + OmitStringField: "string", + VisibleIntField: 0, + OmitIntField: 1, + VisibleFloatField: 0.0, + OmitFloatField: 1.0, + VisibleSliceField: nil, + OmitSliceField: []interface{}{1}, + VisibleMapField: nil, + OmitMapField: map[string]interface{}{"k": "v"}, + NestedField: nil, + OmitNestedField: &Nested{}, + } + + var emptySlice []interface{} + var emptyMap map[string]interface{} + var emptyNested *Nested + expected := &map[string]interface{}{ + "visible-string": "", + "omittable-string": "string", + "visible-int": 0, + "omittable-int": 1, + "visible-float": 0.0, + "omittable-float": 1.0, + "visible-slice": emptySlice, + "omittable-slice": []interface{}{1}, + "visible-map": emptyMap, + "omittable-map": map[string]interface{}{"k": "v"}, + "visible-nested": emptyNested, + "omittable-nested": &Nested{}, + } + + actual := &map[string]interface{}{} + Decode(input, actual) + + if !reflect.DeepEqual(actual, expected) { + t.Fatalf("Decode() expected: %#v, got: %#v", expected, actual) + } +} + +func TestDecode_mapToStruct(t *testing.T) { + type Target struct { + String string + StringPtr *string + } + + expected := Target{ + String: "hello", + } + + var target Target + err := Decode(map[string]interface{}{ + "string": "hello", + "StringPtr": "goodbye", + }, &target) + if err != nil { + t.Fatalf("got error: %s", err) + } + + // Pointers fail reflect test so do those manually + if target.StringPtr == nil || *target.StringPtr != "goodbye" { + t.Fatalf("bad: %#v", target) + } + target.StringPtr = nil + + if !reflect.DeepEqual(target, expected) { + t.Fatalf("bad: %#v", target) + } +} + +func TestDecoder_MatchName(t *testing.T) { + t.Parallel() + + type Target struct { + FirstMatch string `mapstructure:"first_match"` + SecondMatch string + NoMatch string `mapstructure:"no_match"` + } + + input := map[string]interface{}{ + "first_match": "foo", + "SecondMatch": "bar", + "NO_MATCH": "baz", + } + + expected := Target{ + FirstMatch: "foo", + SecondMatch: "bar", + } + + var actual Target + config := &DecoderConfig{ + Result: &actual, + MatchName: func(mapKey, fieldName string) bool { + return mapKey == fieldName + }, + } + + decoder, err := NewDecoder(config) + if err != nil { + t.Fatalf("jderr: %s", err) + } + + err = decoder.Decode(input) + if err != nil { + t.Fatalf("jderr: %s", err) + } + + if !reflect.DeepEqual(expected, actual) { + t.Fatalf("Decode() expected: %#v, got: %#v", expected, actual) + } +} + +func TestDecoder_IgnoreUntaggedFields(t *testing.T) { + type Input struct { + UntaggedNumber int + TaggedNumber int `mapstructure:"tagged_number"` + UntaggedString string + TaggedString string `mapstructure:"tagged_string"` + } + input := &Input{ + UntaggedNumber: 31, + TaggedNumber: 42, + UntaggedString: "hidden", + TaggedString: "visible", + } + + actual := make(map[string]interface{}) + config := &DecoderConfig{ + Result: &actual, + IgnoreUntaggedFields: true, + } + + decoder, err := NewDecoder(config) + if err != nil { + t.Fatalf("jderr: %s", err) + } + + err = decoder.Decode(input) + if err != nil { + t.Fatalf("jderr: %s", err) + } + + expected := map[string]interface{}{ + "tagged_number": 42, + "tagged_string": "visible", + } + + if !reflect.DeepEqual(expected, actual) { + t.Fatalf("Decode() expected: %#v\ngot: %#v", expected, actual) + } +} + +func testSliceInput(t *testing.T, input map[string]interface{}, expected *Slice) { + var result Slice + err := Decode(input, &result) + if err != nil { + t.Fatalf("got error: %s", err) + } + + if result.Vfoo != expected.Vfoo { + t.Errorf("Vfoo expected '%s', got '%s'", expected.Vfoo, result.Vfoo) + } + + if result.Vbar == nil { + t.Fatalf("Vbar a slice, got '%#v'", result.Vbar) + } + + if len(result.Vbar) != len(expected.Vbar) { + t.Errorf("Vbar length should be %d, got %d", len(expected.Vbar), len(result.Vbar)) + } + + for i, v := range result.Vbar { + if v != expected.Vbar[i] { + t.Errorf( + "Vbar[%d] should be '%#v', got '%#v'", + i, expected.Vbar[i], v) + } + } +} + +func testArrayInput(t *testing.T, input map[string]interface{}, expected *Array) { + var result Array + err := Decode(input, &result) + if err != nil { + t.Fatalf("got error: %s", err) + } + + if result.Vfoo != expected.Vfoo { + t.Errorf("Vfoo expected '%s', got '%s'", expected.Vfoo, result.Vfoo) + } + + if result.Vbar == [2]string{} { + t.Fatalf("Vbar a slice, got '%#v'", result.Vbar) + } + + if len(result.Vbar) != len(expected.Vbar) { + t.Errorf("Vbar length should be %d, got %d", len(expected.Vbar), len(result.Vbar)) + } + + for i, v := range result.Vbar { + if v != expected.Vbar[i] { + t.Errorf( + "Vbar[%d] should be '%#v', got '%#v'", + i, expected.Vbar[i], v) + } + } +} + +func stringPtr(v string) *string { return &v } +func intPtr(v int) *int { return &v } +func uintPtr(v uint) *uint { return &v } +func boolPtr(v bool) *bool { return &v } +func floatPtr(v float64) *float64 { return &v } +func interfacePtr(v interface{}) *interface{} { return &v } diff --git a/pkg/mapstructure/my_decode.go b/pkg/mapstructure/my_decode.go new file mode 100644 index 0000000..6cafe11 --- /dev/null +++ b/pkg/mapstructure/my_decode.go @@ -0,0 +1,26 @@ +package mapstructure + +import "time" + +// DecodeWithTime 支持时间转字符串 +// 支持 +// 1. *Time.time 转 string/*string +// 2. *Time.time 转 uint/uint32/uint64/int/int32/int64,支持带指针 +// 不能用 Time.time 转,它会在上层认为是一个结构体数据而直接转成map,再到hook方法 +func DecodeWithTime(input, output interface{}, layout string) error { + if layout == "" { + layout = time.DateTime + } + config := &DecoderConfig{ + Metadata: nil, + Result: output, + DecodeHook: ComposeDecodeHookFunc(TimeToStringHook(layout), TimeToUnixIntHook()), + } + + decoder, err := NewDecoder(config) + if err != nil { + return err + } + + return decoder.Decode(input) +} diff --git a/pkg/mapstructure/my_decode_hook.go b/pkg/mapstructure/my_decode_hook.go new file mode 100644 index 0000000..4873731 --- /dev/null +++ b/pkg/mapstructure/my_decode_hook.go @@ -0,0 +1,101 @@ +package mapstructure + +import ( + "reflect" + "time" +) + +// TimeToStringHook 时间转字符串 +// 支持 *Time.time 转 string/*string +// 不能用 Time.time 转,它会在上层认为是一个结构体数据而直接转成map,再到hook方法 +func TimeToStringHook(layout string) DecodeHookFunc { + return func( + f reflect.Type, + t reflect.Type, + data interface{}) (interface{}, error) { + // 判断目标类型是否为字符串 + var strType string + var isStrPointer *bool // 要转换的目标类型是否为指针字符串 + if t == reflect.TypeOf(strType) { + isStrPointer = new(bool) + } else if t == reflect.TypeOf(&strType) { + isStrPointer = new(bool) + *isStrPointer = true + } + if isStrPointer == nil { + return data, nil + } + + // 判断类型是否为时间 + timeType := time.Time{} + if f != reflect.TypeOf(timeType) && f != reflect.TypeOf(&timeType) { + return data, nil + } + + // 将时间转换为字符串 + var output string + switch v := data.(type) { + case *time.Time: + output = v.Format(layout) + case time.Time: + output = v.Format(layout) + default: + return data, nil + } + + if *isStrPointer { + return &output, nil + } + return output, nil + } +} + +// TimeToUnixIntHook 时间转时间戳 +// 支持 *Time.time 转 uint/uint32/uint64/int/int32/int64,支持带指针 +// 不能用 Time.time 转,它会在上层认为是一个结构体数据而直接转成map,再到hook方法 +func TimeToUnixIntHook() DecodeHookFunc { + return func( + f reflect.Type, + t reflect.Type, + data interface{}) (interface{}, error) { + + tkd := t.Kind() + if tkd != reflect.Int && tkd != reflect.Int32 && tkd != reflect.Int64 && + tkd != reflect.Uint && tkd != reflect.Uint32 && tkd != reflect.Uint64 { + return data, nil + } + + // 判断类型是否为时间 + timeType := time.Time{} + if f != reflect.TypeOf(timeType) && f != reflect.TypeOf(&timeType) { + return data, nil + } + + // 将时间转换为字符串 + var output int64 + switch v := data.(type) { + case *time.Time: + output = v.Unix() + case time.Time: + output = v.Unix() + default: + return data, nil + } + switch tkd { + case reflect.Int: + return int(output), nil + case reflect.Int32: + return int32(output), nil + case reflect.Int64: + return output, nil + case reflect.Uint: + return uint(output), nil + case reflect.Uint32: + return uint32(output), nil + case reflect.Uint64: + return uint64(output), nil + default: + return data, nil + } + } +} diff --git a/pkg/mapstructure/my_decode_hook_test.go b/pkg/mapstructure/my_decode_hook_test.go new file mode 100644 index 0000000..2f5687d --- /dev/null +++ b/pkg/mapstructure/my_decode_hook_test.go @@ -0,0 +1,274 @@ +package mapstructure + +import ( + "testing" + "time" +) + +func Test_TimeToStringHook(t *testing.T) { + type Input struct { + Time time.Time + Id int + } + + type InputTPointer struct { + Time *time.Time + Id int + } + + type Output struct { + Time string + Id int + } + + type OutputTPointer struct { + Time *string + Id int + } + now := time.Now() + target := now.Format("2006-01-02 15:04:05") + idValue := 1 + tests := []struct { + input any + output any + name string + layout string + }{ + { + name: "测试Time.time转string", + layout: "2006-01-02 15:04:05", + input: InputTPointer{ + Time: &now, + Id: idValue, + }, + output: Output{}, + }, + { + name: "测试*Time.time转*string", + layout: "2006-01-02 15:04:05", + input: InputTPointer{ + Time: &now, + Id: idValue, + }, + output: OutputTPointer{}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + decoder, err := NewDecoder(&DecoderConfig{ + DecodeHook: TimeToStringHook(tt.layout), + Result: &tt.output, + }) + if err != nil { + t.Errorf("NewDecoder() jderr = %v,want nil", err) + } + + if i, isOk := tt.input.(Input); isOk { + err = decoder.Decode(i) + } + if i, isOk := tt.input.(InputTPointer); isOk { + err = decoder.Decode(&i) + } + if err != nil { + t.Errorf("Decode jderr = %v,want nil", err) + } + //验证测试值 + if output, isOk := tt.output.(OutputTPointer); isOk { + if *output.Time != target { + t.Errorf("Decode output time = %v,want %v", *output.Time, target) + } + if output.Id != idValue { + t.Errorf("Decode output id = %v,want %v", output.Id, idValue) + } + } + if output, isOk := tt.output.(Output); isOk { + if output.Time != target { + t.Errorf("Decode output time = %v,want %v", output.Time, target) + } + if output.Id != idValue { + t.Errorf("Decode output id = %v,want %v", output.Id, idValue) + } + } + }) + } +} + +func Test_TimeToUnixIntHook(t *testing.T) { + type InputTPointer struct { + Time *time.Time + Id int + } + + type Output[T int | *int | int32 | *int32 | int64 | *int64 | uint | *uint] struct { + Time T + Id int + } + + type test struct { + input any + output any + name string + layout string + } + + now := time.Now() + target := now.Unix() + idValue := 1 + tests := []test{ + { + name: "测试Time.time转int", + layout: "2006-01-02 15:04:05", + input: InputTPointer{ + Time: &now, + Id: idValue, + }, + output: Output[int]{}, + }, + { + name: "测试Time.time转*int", + layout: "2006-01-02 15:04:05", + input: InputTPointer{ + Time: &now, + Id: idValue, + }, + output: Output[*int]{}, + }, + { + name: "测试Time.time转int32", + layout: "2006-01-02 15:04:05", + input: InputTPointer{ + Time: &now, + Id: idValue, + }, + output: Output[int32]{}, + }, + { + name: "测试Time.time转*int32", + layout: "2006-01-02 15:04:05", + input: InputTPointer{ + Time: &now, + Id: idValue, + }, + output: Output[*int32]{}, + }, + { + name: "测试Time.time转int64", + layout: "2006-01-02 15:04:05", + input: InputTPointer{ + Time: &now, + Id: idValue, + }, + output: Output[int64]{}, + }, + { + name: "测试Time.time转*int64", + layout: "2006-01-02 15:04:05", + input: InputTPointer{ + Time: &now, + Id: idValue, + }, + output: Output[*int64]{}, + }, + { + name: "测试Time.time转uint", + layout: "2006-01-02 15:04:05", + input: InputTPointer{ + Time: &now, + Id: idValue, + }, + output: Output[uint]{}, + }, + { + name: "测试Time.time转*uint", + layout: "2006-01-02 15:04:05", + input: InputTPointer{ + Time: &now, + Id: idValue, + }, + output: Output[*uint]{}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + decoder, err := NewDecoder(&DecoderConfig{ + DecodeHook: TimeToUnixIntHook(), + Result: &tt.output, + }) + if err != nil { + t.Errorf("NewDecoder() jderr = %v,want nil", err) + } + + if i, isOk := tt.input.(InputTPointer); isOk { + err = decoder.Decode(i) + } + if i, isOk := tt.input.(InputTPointer); isOk { + err = decoder.Decode(&i) + } + if err != nil { + t.Errorf("Decode jderr = %v,want nil", err) + } + + //验证测试值 + switch v := tt.output.(type) { + case Output[int]: + if int64(v.Time) != target { + t.Errorf("Decode output time = %v,want %v", v.Time, target) + } + if v.Id != idValue { + t.Errorf("Decode output id = %v,want %v", v.Id, idValue) + } + case Output[*int]: + if int64(*v.Time) != target { + t.Errorf("Decode output time = %v,want %v", v.Time, target) + } + if v.Id != idValue { + t.Errorf("Decode output id = %v,want %v", v.Id, idValue) + } + case Output[int32]: + if int64(v.Time) != target { + t.Errorf("Decode output time = %v,want %v", v.Time, target) + } + if v.Id != idValue { + t.Errorf("Decode output id = %v,want %v", v.Id, idValue) + } + case Output[*int32]: + if int64(*v.Time) != target { + t.Errorf("Decode output time = %v,want %v", v.Time, target) + } + if v.Id != idValue { + t.Errorf("Decode output id = %v,want %v", v.Id, idValue) + } + case Output[int64]: + if int64(v.Time) != target { + t.Errorf("Decode output time = %v,want %v", v.Time, target) + } + if v.Id != idValue { + t.Errorf("Decode output id = %v,want %v", v.Id, idValue) + } + case Output[*int64]: + if int64(*v.Time) != target { + t.Errorf("Decode output time = %v,want %v", v.Time, target) + } + if v.Id != idValue { + t.Errorf("Decode output id = %v,want %v", v.Id, idValue) + } + case Output[uint]: + if int64(v.Time) != target { + t.Errorf("Decode output time = %v,want %v", v.Time, target) + } + if v.Id != idValue { + t.Errorf("Decode output id = %v,want %v", v.Id, idValue) + } + case Output[*uint]: + if int64(*v.Time) != target { + t.Errorf("Decode output time = %v,want %v", v.Time, target) + } + if v.Id != idValue { + t.Errorf("Decode output id = %v,want %v", v.Id, idValue) + } + } + + }) + } +} diff --git a/pkg/response.go b/pkg/response.go new file mode 100644 index 0000000..7df9813 --- /dev/null +++ b/pkg/response.go @@ -0,0 +1,38 @@ +package pkg + +import ( + "encoding/json" + "fmt" + "os" + + "github.com/gofiber/fiber/v2" + "github.com/gofiber/fiber/v2/log" +) + +func HandleResponse(c *fiber.Ctx, data interface{}) (err error) { + if os.Getenv("env") == "unit_test" { + log.Debug(data) + } + switch data.(type) { + case error: + err = data.(error) + case int, int32, int64, float32, float64, string, bool: + c.Response().SetBody([]byte(fmt.Sprintf("%s", data))) + case []byte: + c.Response().SetBody(data.([]byte)) + default: + dataByte, _ := json.Marshal(data) + c.Response().SetBody(dataByte) + } + return +} + +func SuccessWithPageMsg(c *fiber.Ctx, list interface{}, total int64, page, pageSize int) error { + response := fiber.Map{ + "list": list, + "total": total, + "page": page, + "pageSize": pageSize, + } + return HandleResponse(c, response) +} diff --git a/pkg/validata.go b/pkg/validata.go new file mode 100644 index 0000000..5c282e9 --- /dev/null +++ b/pkg/validata.go @@ -0,0 +1,743 @@ +package pkg + +import ( + "fmt" + "geo/tmpl/errcode" + "reflect" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/go-playground/validator/v10" + "github.com/gofiber/fiber/v2" +) + +const ( + // 验证标签常量 + TagRequired = "required" + TagEmail = "email" + TagMin = "min" + TagMax = "max" + TagLen = "len" + TagNumeric = "numeric" + TagOneof = "oneof" + TagGte = "gte" + TagGt = "gt" + TagLte = "lte" + TagLt = "lt" + TagURL = "url" + TagUUID = "uuid" + TagDatetime = "datetime" + + // 中文标签常量 + //LabelComment = "comment" + //LabelLabel = "label" + LabelZh = "zh" + + // 上下文Key + CtxRequestBody = "request_body" + + // 性能优化常量 + InitialBuilderSize = 64 + MaxCacheSize = 5000 + CacheCleanupInterval = 10 * time.Minute +) + +// ValidatorConfig 验证器配置 +type ValidatorConfig struct { + StatusCode int // 验证失败时的HTTP状态码,默认422 + ErrorHandler func(c *fiber.Ctx, err error) error // 自定义错误处理 + BeforeParse func(c *fiber.Ctx) error // 解析前执行 + AfterParse func(c *fiber.Ctx, req interface{}) error // 解析后执行 + DisableCache bool // 是否禁用缓存 + MaxCacheSize int // 最大缓存数量 + UseChinese bool // 是否使用中文提示 + EnableMetrics bool // 是否启用指标收集 +} + +// CacheStats 缓存统计 +type CacheStats struct { + hitCount int64 + missCount int64 + evictCount int64 + errorCount int64 + accessCount int64 +} + +// Metrics 指标收集 +type Metrics struct { + totalRequests int64 + validationTime int64 + cacheHitRate float64 + mu sync.RWMutex +} + +// fieldInfo 字段信息 +type fieldInfo struct { + label string + index int + typeKind reflect.Kind + accessCount int64 + lastAccess time.Time +} + +// typeInfo 类型信息 +type typeInfo struct { + fields map[string]*fieldInfo + fieldNames []string + mu sync.RWMutex + accessCount int64 + createdAt time.Time + lastAccess time.Time + typeKey string +} + +// ValidatorHelper 验证器助手 +type ValidatorHelper struct { + validate *validator.Validate + config *ValidatorConfig + typeCache sync.Map + cacheStats *CacheStats + errorFnCache map[string]func(string, string) string + errorFnCacheMu sync.RWMutex + pruneLock sync.Mutex + stopCleanup chan struct{} + cleanupOnce sync.Once + metrics *Metrics +} + +var ( + ChineseErrorTemplates = map[string]string{ + TagRequired: "%s不能为空", + TagEmail: "%s格式不正确", + TagMin: "%s不能小于%s", + TagMax: "%s不能大于%s", + TagLen: "%s长度必须为%s位", + TagNumeric: "%s必须是数字", + TagOneof: "%s必须是以下值之一: %s", + TagGte: "%s不能小于%s", + TagGt: "%s必须大于%s", + TagLte: "%s不能大于%s", + TagLt: "%s必须小于%s", + TagURL: "%s必须是有效的URL地址", + TagUUID: "%s必须是有效的UUID", + TagDatetime: "%s日期格式不正确", + } + + defaultErrorTemplates = map[string]string{ + TagRequired: "%s is required", + TagEmail: "invalid email format", + TagMin: "%s must be at least %s", + TagMax: "%s must be at most %s", + TagLen: "%s must be exactly %s characters", + TagNumeric: "%s must be numeric", + TagOneof: "%s must be one of: %s", + TagGte: "%s must be greater than or equal to %s", + TagGt: "%s must be greater than %s", + TagLte: "%s must be less than or equal to %s", + TagLt: "%s must be less than %s", + TagURL: "%s must be a valid URL", + TagUUID: "%s must be a valid UUID", + TagDatetime: "%s invalid datetime format", + } +) + +var ( + // 对象池 + builderPool = sync.Pool{ + New: func() interface{} { + b := &strings.Builder{} + b.Grow(InitialBuilderSize) + return b + }, + } + + // 错误信息切片池 + errorSlicePool = sync.Pool{ + New: func() interface{} { + slice := make([]string, 0, 8) + return &slice + }, + } + + // 字段信息对象池 + fieldInfoPool = sync.Pool{ + New: func() interface{} { + return &fieldInfo{ + accessCount: 0, + lastAccess: time.Now(), + } + }, + } + + // 类型信息对象池 + typeInfoPool = sync.Pool{ + New: func() interface{} { + return &typeInfo{ + fields: make(map[string]*fieldInfo), + fieldNames: make([]string, 0, 8), + } + }, + } +) + +var ( + Vh *ValidatorHelper + once sync.Once +) + +// NewValidatorHelper 初始化验证器助手 +func NewValidatorHelper(config ...*ValidatorConfig) { + once.Do(func() { + v := validator.New() + + // 优化JSON标签获取 + v.RegisterTagNameFunc(func(fld reflect.StructField) string { + return fld.Tag.Get("json") + }) + + // 默认配置 + cfg := &ValidatorConfig{ + StatusCode: fiber.StatusUnprocessableEntity, + ErrorHandler: defaultErrorHandler, + MaxCacheSize: MaxCacheSize, + UseChinese: true, + EnableMetrics: false, + } + + if len(config) > 0 && config[0] != nil { + if config[0].StatusCode != 0 { + cfg.StatusCode = config[0].StatusCode + } + if config[0].ErrorHandler != nil { + cfg.ErrorHandler = config[0].ErrorHandler + } + if config[0].MaxCacheSize > 0 { + cfg.MaxCacheSize = config[0].MaxCacheSize + } + cfg.DisableCache = config[0].DisableCache + cfg.BeforeParse = config[0].BeforeParse + cfg.AfterParse = config[0].AfterParse + cfg.UseChinese = config[0].UseChinese + cfg.EnableMetrics = config[0].EnableMetrics + } + + // 预编译错误函数 + errorFnCache := make(map[string]func(string, string) string) + templates := ChineseErrorTemplates + if !cfg.UseChinese { + templates = defaultErrorTemplates + } + + for tag, tmpl := range templates { + t := tmpl // 捕获变量 + errorFnCache[tag] = func(field, param string) string { + if strings.Contains(t, "%s") && strings.Count(t, "%s") == 2 { + return fmt.Sprintf(t, field, param) + } + return fmt.Sprintf(t, field) + } + } + + Vh = &ValidatorHelper{ + validate: v, + config: cfg, + cacheStats: &CacheStats{}, + errorFnCache: errorFnCache, + errorFnCacheMu: sync.RWMutex{}, + stopCleanup: make(chan struct{}), + metrics: &Metrics{}, + } + + // 启动定期清理 + if !cfg.DisableCache { + go Vh.periodicCleanup() + } + }) +} + +// ParseAndValidate 解析并验证请求 +func ParseAndValidate(c *fiber.Ctx, req interface{}) error { + if Vh == nil { + NewValidatorHelper() + } + + if Vh.config.EnableMetrics { + atomic.AddInt64(&Vh.metrics.totalRequests, 1) + defer Vh.recordValidationTime(time.Now()) + } + + // 执行解析前钩子 + if Vh.config.BeforeParse != nil { + if err := Vh.config.BeforeParse(c); err != nil { + return err + } + } + + // 解析请求体 + err := c.BodyParser(req) + if err != nil { + return errcode.ParamErr("请求格式错误:" + err.Error()) + } + + // 执行解析后钩子 + if Vh.config.AfterParse != nil { + if err = Vh.config.AfterParse(c, req); err != nil { + return errcode.ParamErr(err.Error()) + } + } + + // 验证数据 + err = Vh.validate.Struct(req) + if err != nil { + c.Locals(CtxRequestBody, req) + + if !Vh.config.DisableCache { + t := reflect.TypeOf(req) + if t.Kind() == reflect.Ptr { + t = t.Elem() + } + if t.Kind() == reflect.Struct { + Vh.safeGetOrCreateTypeInfo(t) + } + } + + return Vh.config.ErrorHandler(c, err) + } + + return nil +} + +// Validate 直接验证结构体 +func Validate(req interface{}) error { + if Vh == nil { + NewValidatorHelper() + } + + if err := Vh.validate.Struct(req); err != nil { + return Vh.wrapValidationError(err, req) + } + return nil +} + +// 默认错误处理 +func defaultErrorHandler(c *fiber.Ctx, err error) error { + validationErrors, ok := err.(validator.ValidationErrors) + if !ok { + return errcode.SystemError + } + + if len(validationErrors) == 0 { + return nil + } + + // 快速路径:单个错误 + if len(validationErrors) == 1 { + e := validationErrors[0] + msg := Vh.safeGetErrorMessage(c, e) + return errcode.ParamErr(msg) + } + + // 从对象池获取builder + builder := builderPool.Get().(*strings.Builder) + builder.Reset() + defer builderPool.Put(builder) + + req := c.Locals(CtxRequestBody) + + for i, e := range validationErrors { + if i > 0 { + builder.WriteByte('\n') + } + builder.WriteString(Vh.safeGetErrorMessageWithReq(req, e)) + } + + return errcode.ParamErr(builder.String()) +} + +// 包装验证错误 +func (vh *ValidatorHelper) wrapValidationError(err error, req interface{}) error { + validationErrors, ok := err.(validator.ValidationErrors) + if !ok { + return err + } + + if len(validationErrors) == 0 { + return nil + } + + // 构建错误消息 + builder := builderPool.Get().(*strings.Builder) + builder.Reset() + defer builderPool.Put(builder) + + for i, e := range validationErrors { + if i > 0 { + builder.WriteByte('\n') + } + builder.WriteString(vh.safeGetErrorMessageWithReq(req, e)) + } + + return errcode.ParamErr(builder.String()) +} + +// 安全获取错误消息 +func (vh *ValidatorHelper) safeGetErrorMessage(c *fiber.Ctx, e validator.FieldError) string { + req := c.Locals(CtxRequestBody) + return vh.safeGetErrorMessageWithReq(req, e) +} + +// 安全获取错误消息(带请求体) +func (vh *ValidatorHelper) safeGetErrorMessageWithReq(req interface{}, e validator.FieldError) string { + if req == nil { + return vh.safeFormatFieldError(e, nil) + } + + t := reflect.TypeOf(req) + if t.Kind() == reflect.Ptr { + t = t.Elem() + } + if t.Kind() != reflect.Struct { + return vh.safeFormatFieldError(e, nil) + } + + // 安全获取类型信息 - 这里需要先声明变量 + var typeInfoObj *typeInfo // 这里声明变量 + if !vh.config.DisableCache { + if cached, ok := vh.typeCache.Load(t); ok { + typeInfoObj = cached.(*typeInfo) + atomic.AddInt64(&typeInfoObj.accessCount, 1) + } + } + + return vh.safeFormatFieldError(e, typeInfoObj) +} + +// 安全格式化字段错误 +func (vh *ValidatorHelper) safeFormatFieldError(e validator.FieldError, typeInfo *typeInfo) string { + structField := e.StructField() + fieldName := e.Field() + + // 获取字段标签 + var label string + if typeInfo != nil { + typeInfo.mu.RLock() + if info, ok := typeInfo.fields[structField]; ok { + label = info.label + atomic.AddInt64(&info.accessCount, 1) + } + typeInfo.mu.RUnlock() + } + + // 如果没有标签,返回默认消息 + if label == "" { + return vh.safeGetDefaultErrorMessage(fieldName, e) + } + + // 使用预编译的错误函数生成消息 + vh.errorFnCacheMu.RLock() + fn, ok := vh.errorFnCache[e.Tag()] + vh.errorFnCacheMu.RUnlock() + + if ok { + return fn(label, e.Param()) + } + + return label + "格式不正确" +} + +// 安全获取默认错误消息 +func (vh *ValidatorHelper) safeGetDefaultErrorMessage(field string, e validator.FieldError) string { + vh.errorFnCacheMu.RLock() + defer vh.errorFnCacheMu.RUnlock() + + if fn, ok := vh.errorFnCache[e.Tag()]; ok { + return fn(field, e.Param()) + } + + return field + "验证失败" +} + +// safeGetOrCreateTypeInfo 安全地获取或创建类型信息 +func (vh *ValidatorHelper) safeGetOrCreateTypeInfo(t reflect.Type) *typeInfo { + if vh.config.DisableCache || t == nil || t.Kind() != reflect.Struct { + return nil + } + + // 首先尝试从缓存读取 + if cached, ok := vh.typeCache.Load(t); ok { + info := cached.(*typeInfo) + atomic.AddInt64(&info.accessCount, 1) + atomic.AddInt64(&vh.cacheStats.hitCount, 1) + info.lastAccess = time.Now() + return info + } + + atomic.AddInt64(&vh.cacheStats.missCount, 1) + + // 从对象池获取typeInfo + info := typeInfoPool.Get().(*typeInfo) + + // 重置并初始化 + info.mu.Lock() + + // 清空现有map + for k := range info.fields { + delete(info.fields, k) + } + info.fieldNames = info.fieldNames[:0] + + info.accessCount = 0 + info.createdAt = time.Now() + info.lastAccess = time.Now() + info.typeKey = t.String() + + // 预计算所有字段信息 + for i := 0; i < t.NumField(); i++ { + field := t.Field(i) + + // 获取标签 + //label := field.Tag.Get(LabelComment) + //if label == "" { + // label = field.Tag.Get(LabelLabel) + //} + //if label == "" { + // label = field.Tag.Get(LabelZh) + //} + label := field.Tag.Get(LabelZh) + // 从对象池获取或创建字段信息 + fieldInfo := fieldInfoPool.Get().(*fieldInfo) + fieldInfo.label = label + fieldInfo.index = i + fieldInfo.typeKind = field.Type.Kind() + fieldInfo.accessCount = 0 + fieldInfo.lastAccess = time.Now() + + info.fields[field.Name] = fieldInfo + info.fieldNames = append(info.fieldNames, field.Name) + } + + info.mu.Unlock() + + // 使用原子操作确保线程安全的存储 + if existing, loaded := vh.typeCache.LoadOrStore(t, info); loaded { + // 如果已经有其他goroutine存储了,使用已有的并回收新创建的 + info.mu.Lock() + for _, fieldInfo := range info.fields { + fieldInfoPool.Put(fieldInfo) + } + info.mu.Unlock() + typeInfoPool.Put(info) + + existingInfo := existing.(*typeInfo) + atomic.AddInt64(&existingInfo.accessCount, 1) + return existingInfo + } + + return info +} + +// 定期清理缓存 +func (vh *ValidatorHelper) periodicCleanup() { + ticker := time.NewTicker(CacheCleanupInterval) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + vh.safeCleanupCache() + case <-vh.stopCleanup: + return + } + } +} + +// safeCleanupCache 安全清理缓存 +func (vh *ValidatorHelper) safeCleanupCache() { + vh.pruneLock.Lock() + defer vh.pruneLock.Unlock() + + var keysToDelete []interface{} + now := time.Now() + + vh.typeCache.Range(func(key, value interface{}) bool { + info, ok := value.(*typeInfo) + if !ok { + return true + } + + // 检查是否需要清理 + accessCount := atomic.LoadInt64(&info.accessCount) + age := now.Sub(info.createdAt) + idleTime := now.Sub(info.lastAccess) + + // 清理条件: + // 1. 很少访问的缓存(访问次数 < 10) + // 2. 空闲时间超过30分钟 + // 3. 缓存年龄超过1小时且访问次数较少 + if (accessCount < 10 && idleTime > 30*time.Minute) || + (age > 1*time.Hour && accessCount < 100) { + keysToDelete = append(keysToDelete, key) + atomic.AddInt64(&vh.cacheStats.evictCount, 1) + } + + return true + }) + + // 删除选中的缓存 + for _, key := range keysToDelete { + if val, ok := vh.typeCache.Load(key); ok { + if info, ok := val.(*typeInfo); ok { + // 安全回收字段信息 + info.mu.Lock() + for _, fieldInfo := range info.fields { + fieldInfoPool.Put(fieldInfo) + } + info.mu.Unlock() + typeInfoPool.Put(info) + } + vh.typeCache.Delete(key) + } + } +} + +// ==================== 指标收集 ==================== + +func (vh *ValidatorHelper) recordValidationTime(start time.Time) { + if !vh.config.EnableMetrics { + return + } + duration := time.Since(start).Nanoseconds() + atomic.AddInt64(&vh.metrics.validationTime, duration) +} + +// GetMetrics 获取性能指标 +func (vh *ValidatorHelper) GetMetrics() map[string]interface{} { + if !vh.config.EnableMetrics { + return nil + } + + vh.metrics.mu.RLock() + defer vh.metrics.mu.RUnlock() + + hitCount := atomic.LoadInt64(&vh.cacheStats.hitCount) + missCount := atomic.LoadInt64(&vh.cacheStats.missCount) + totalRequests := hitCount + missCount + + var hitRate float64 + if totalRequests > 0 { + hitRate = float64(hitCount) / float64(totalRequests) * 100 + } + + var cacheSize int64 + vh.typeCache.Range(func(_, _ interface{}) bool { + cacheSize++ + return true + }) + + return map[string]interface{}{ + "cache_hit_count": hitCount, + "cache_miss_count": missCount, + "cache_evict_count": atomic.LoadInt64(&vh.cacheStats.evictCount), + "cache_hit_rate": fmt.Sprintf("%.2f%%", hitRate), + "cache_size": cacheSize, + "error_count": atomic.LoadInt64(&vh.cacheStats.errorCount), + "total_requests": atomic.LoadInt64(&vh.metrics.totalRequests), + "avg_validation_time_ns": atomic.LoadInt64(&vh.metrics.validationTime) / + max(atomic.LoadInt64(&vh.metrics.totalRequests), 1), + } +} + +// RegisterValidation 注册自定义验证规则 +func (vh *ValidatorHelper) RegisterValidation(tag string, fn validator.Func, callValidationEvenIfNull ...bool) error { + return vh.validate.RegisterValidation(tag, fn, callValidationEvenIfNull...) +} + +// RegisterTranslation 注册自定义翻译 +func (vh *ValidatorHelper) RegisterTranslation(tag string, template string) { + vh.errorFnCacheMu.Lock() + defer vh.errorFnCacheMu.Unlock() + + t := template + vh.errorFnCache[tag] = func(field, param string) string { + if strings.Contains(t, "%s") && strings.Count(t, "%s") == 2 { + return fmt.Sprintf(t, field, param) + } + return fmt.Sprintf(t, field) + } +} + +// Stop 停止后台清理任务 +func (vh *ValidatorHelper) Stop() { + close(vh.stopCleanup) +} + +// Reset 重置验证器状态 +func (vh *ValidatorHelper) Reset() { + vh.ClearCache() + atomic.StoreInt64(&vh.cacheStats.hitCount, 0) + atomic.StoreInt64(&vh.cacheStats.missCount, 0) + atomic.StoreInt64(&vh.cacheStats.evictCount, 0) + atomic.StoreInt64(&vh.cacheStats.errorCount, 0) + atomic.StoreInt64(&vh.metrics.totalRequests, 0) + atomic.StoreInt64(&vh.metrics.validationTime, 0) +} + +// ClearCache 清理所有缓存 +func (vh *ValidatorHelper) ClearCache() { + vh.pruneLock.Lock() + defer vh.pruneLock.Unlock() + + // 安全回收所有缓存 + vh.typeCache.Range(func(key, value interface{}) bool { + if info, ok := value.(*typeInfo); ok { + info.mu.Lock() + for _, fieldInfo := range info.fields { + fieldInfoPool.Put(fieldInfo) + } + info.mu.Unlock() + typeInfoPool.Put(info) + } + vh.typeCache.Delete(key) + return true + }) +} + +// SetLanguage 设置语言 +func (vh *ValidatorHelper) SetLanguage(useChinese bool) { + vh.config.UseChinese = useChinese + + templates := defaultErrorTemplates + if useChinese { + templates = ChineseErrorTemplates + } + + vh.errorFnCacheMu.Lock() + defer vh.errorFnCacheMu.Unlock() + + for tag, tmpl := range templates { + t := tmpl + vh.errorFnCache[tag] = func(field, param string) string { + if strings.Contains(t, "%s") && strings.Count(t, "%s") == 2 { + return fmt.Sprintf(t, field, param) + } + return fmt.Sprintf(t, field) + } + } +} + +func max(a, b int64) int64 { + if a > b { + return a + } + return b +} + +func GetErr(tag, field, param string) string { + t := ChineseErrorTemplates[tag] + if strings.Contains(t, "%s") && strings.Count(t, "%s") == 2 { + return fmt.Sprintf(t, field, param) + } + return fmt.Sprintf(t, field) +} diff --git a/pkg/wx.go b/pkg/wx.go new file mode 100644 index 0000000..4a7ffd8 --- /dev/null +++ b/pkg/wx.go @@ -0,0 +1,195 @@ +package pkg + +import ( + "bytes" + "context" + "crypto/tls" + "encoding/json" + "fmt" + "io" + "os" + "sync" + "time" + + "net/http" + + "errors" + + "github.com/redis/go-redis/v9" +) + +type WeChatLoginResponse struct { + OpenID string `json:"openid"` // 用户唯一标识 + SessionKey string `json:"session_key"` // 会话密钥 + UnionID string `json:"unionid"` // 用户在开放平台的唯一标识(如果绑定了开放平台才有) + Errcode int `json:"errcode"` // 错误码,0为成功 + Errmsg string `json:"errmsg"` // 错误信息 +} + +func GetOpenID(appID, appSecret, code string) (openid string, err error) { + if os.Getenv("env") == "unit_test" { + return "test_123456", nil + } + // 1. 构建请求微信接口的 URL + url := fmt.Sprintf("https://api.weixin.qq.com/sns/jscode2session?appid=%s&secret=%s&js_code=%s&grant_type=authorization_code", + appID, appSecret, code) + + // 2. 创建 HTTP 客户端(设置超时,避免阻塞) + client := &http.Client{ + Timeout: 5 * time.Second, + // 在某些网络受限的环境(如本地测试跳过证书验证,生产环境建议去掉) + Transport: &http.Transport{ + TLSClientConfig: &tls.Config{InsecureSkipVerify: false}, // 生产环境建议设为 false + }, + } + + // 3. 发起 GET 请求 + resp, err := client.Get(url) + if err != nil { + return "", fmt.Errorf("请求微信服务器失败: %w", err) + } + defer resp.Body.Close() + + // 4. 读取返回的 Body + body, err := io.ReadAll(resp.Body) + if err != nil { + return "", fmt.Errorf("读取微信响应失败: %w", err) + } + + // 5. 解析 JSON 数据 + var wechatResp WeChatLoginResponse + err = json.Unmarshal(body, &wechatResp) + if err != nil { + return "", fmt.Errorf("解析微信响应 JSON 失败: %s, 原始数据: %s", err.Error(), string(body)) + } + + // 6. 检查微信接口返回的错误码 + if wechatResp.Errcode != 0 { + // 这里可以根据不同的错误码做特殊处理,例如 code 无效、过期等 + return "", fmt.Errorf("微信接口返回错误: code=%d, msg=%s", wechatResp.Errcode, wechatResp.Errmsg) + } + + // 7. 检查 OpenID 是否为空(理论上不会,但防御性编程) + if wechatResp.OpenID == "" { + return "", errors.New("微信返回的 OpenID 为空") + } + + // 8. 返回 OpenID + return wechatResp.OpenID, nil +} + +type AccessTokenResponse struct { + AccessToken string `json:"access_token"` + ExpiresIn int `json:"expires_in"` + Errcode int `json:"errcode"` + Errmsg string `json:"errmsg"` +} + +var ( + tokenMutex sync.Mutex + cacheKey = "wx:access_token" +) + +// GetAccessToken 获取 access_token,带本地缓存 +func GetAccessToken(ctx context.Context, appID, appSecret string, rdb *redis.Client) (string, error) { + if rdb == nil { + return "", errors.New("缓存工具未提供") + } + + cacheToken := rdb.Get(ctx, cacheKey).Val() + if cacheToken != "" { + return cacheToken, nil + } + tokenMutex.Lock() + defer tokenMutex.Unlock() + + // 请求微信接口获取新的 access_token + url := fmt.Sprintf("https://api.weixin.qq.com/cgi-bin/token?grant_type=client_credential&appid=%s&secret=%s", appID, appSecret) + + client := &http.Client{Timeout: 5 * time.Second} + resp, err := client.Get(url) + if err != nil { + return "", fmt.Errorf("请求 access_token 失败: %w", err) + } + defer resp.Body.Close() + + body, _ := io.ReadAll(resp.Body) + var tokenRes AccessTokenResponse + if err := json.Unmarshal(body, &tokenRes); err != nil { + return "", fmt.Errorf("解析 access_token 响应失败: %w", err) + } + + if tokenRes.Errcode != 0 { + return "", fmt.Errorf("获取 access_token 失败: code=%d, msg=%s", tokenRes.Errcode, tokenRes.Errmsg) + } + + // 缓存 token,提前5分钟过期,避免边界情况 + + rdb.Set(ctx, cacheKey, tokenRes.AccessToken, time.Duration(tokenRes.ExpiresIn-300)*time.Second) + return tokenRes.AccessToken, nil +} + +// PhoneInfo 定义手机号信息的结构体,与微信官方文档对齐 [citation:3][citation:8] +type PhoneInfo struct { + PhoneNumber string `json:"phoneNumber"` // 用户绑定的手机号(国外手机号会有区号) + PurePhoneNumber string `json:"purePhoneNumber"` // 没有区号的手机号 + CountryCode string `json:"countryCode"` // 区号 + Watermark struct { + Timestamp int64 `json:"timestamp"` + Appid string `json:"appid"` + } `json:"watermark"` +} + +// PhoneInfoResponse 定义微信接口返回的完整结构 +type PhoneInfoResponse struct { + Errcode int `json:"errcode"` + Errmsg string `json:"errmsg"` + PhoneInfo PhoneInfo `json:"phone_info"` +} + +// GetPhoneNumber 通过手机号 code 获取用户手机号 +// 参数: +// - appID: 小程序的 AppID +// - appSecret: 小程序的 AppSecret +// - phoneCode: 前端通过 getPhoneNumber 获取的 code +// +// 返回: +// - *PhoneInfo: 手机号信息 +// - error: 错误信息 +func GetPhoneNumber(ctx context.Context, appID, appSecret, phoneCode string, rdb *redis.Client) (*PhoneInfo, error) { + // 1. 获取 access_token + accessToken, err := GetAccessToken(ctx, appID, appSecret, rdb) + if err != nil { + return nil, fmt.Errorf("获取 access_token 失败: %w", err) + } + + // 2. 调用微信接口换取手机号 [citation:8] + url := fmt.Sprintf("https://api.weixin.qq.com/wxa/business/getuserphonenumber?access_token=%s", accessToken) + + // 构建请求体 + requestBody := map[string]string{ + "code": phoneCode, + } + jsonBody, _ := json.Marshal(requestBody) + + client := &http.Client{Timeout: 5 * time.Second} + resp, err := client.Post(url, "application/json", bytes.NewReader(jsonBody)) + if err != nil { + return nil, fmt.Errorf("请求手机号接口失败: %w", err) + } + defer resp.Body.Close() + + body, _ := io.ReadAll(resp.Body) + + var phoneResp PhoneInfoResponse + if err := json.Unmarshal(body, &phoneResp); err != nil { + return nil, fmt.Errorf("解析手机号响应失败: %s", string(body)) + } + + // 3. 检查微信接口返回的错误码 [citation:8] + if phoneResp.Errcode != 0 { + return nil, fmt.Errorf("微信接口返回错误: code=%d, msg=%s", phoneResp.Errcode, phoneResp.Errmsg) + } + + return &phoneResp.PhoneInfo, nil +} diff --git a/tmpl/dataTemp/queryTempl.go b/tmpl/dataTemp/queryTempl.go new file mode 100644 index 0000000..c871f3b --- /dev/null +++ b/tmpl/dataTemp/queryTempl.go @@ -0,0 +1,289 @@ +package dataTemp + +import ( + "context" + "database/sql" + "fmt" + "geo/tmpl/errcode" + "geo/utils" + "reflect" + + "github.com/go-kratos/kratos/v2/log" + "gorm.io/gorm" + + "xorm.io/builder" +) + +type PrimaryKey struct { + Id int `json:"id"` +} + +type GormDb struct { + Client *gorm.DB +} +type contextTxKey struct{} + +func (d *Db) DB(ctx context.Context) *gorm.DB { + tx, ok := ctx.Value(contextTxKey{}).(*gorm.DB) + if ok { + return tx + } + return d.Db.Client +} + +func (t *Db) ExecTx(ctx context.Context, f func(ctx context.Context) error) error { + return t.Db.Client.WithContext(ctx).Transaction(func(tx *gorm.DB) error { + ctx = context.WithValue(ctx, contextTxKey{}, tx) + return f(ctx) + }) +} + +type Db struct { + Db *GormDb + Log *log.Helper +} + +type DataTemp struct { + Db *gorm.DB + ModelType reflect.Type // 改为存储类型而不是实例 + modelName string // 可选的表名缓存 +} + +func NewDataTemp(db *utils.Db, model interface{}) *DataTemp { + // 获取模型的类型 + t := reflect.TypeOf(model) + if t.Kind() == reflect.Ptr { + t = t.Elem() + } + + return &DataTemp{ + Db: db.Client, + ModelType: t, + } +} + +func (k DataTemp) modelInstance() interface{} { + return reflect.New(k.ModelType).Interface() +} + +func (k DataTemp) GetById(id int32) (data map[string]interface{}, err error) { + err = k.Db.Model(k.modelInstance()).Where("id = ?", id).Find(&data).Error + if data == nil { + err = sql.ErrNoRows + } + return +} + +func (k DataTemp) GetByStruct(ctx context.Context, search interface{}, data interface{}, orderBy string) (err error) { + + err = k.Db.Model(k.modelInstance()).WithContext(ctx).Where(search).Find(&data).Error + + if data == nil { + err = sql.ErrNoRows + } + return +} + +func (k DataTemp) SaveByStruct(search interface{}, data interface{}) (err error) { + err = k.Db.Model(k.modelInstance()).Where(search).Save(&data).Error + if data == nil { + err = sql.ErrNoRows + } + return +} + +func (k DataTemp) Add(ctx context.Context, data interface{}) (err error) { + m := k.modelInstance() + if err = k.Db.Model(m).WithContext(ctx).Create(data).Error; err != nil { + return errcode.SqlErr(err) + } + return +} + +func (k DataTemp) AddWithData(data interface{}) (interface{}, error) { + result := k.Db.Model(k.modelInstance()).Create(data) + if result.Error != nil { + return data, result.Error + } + return data, nil +} + +func (k DataTemp) GetList(cond *builder.Cond, pageBoIn *ReqPageBo) (list []map[string]interface{}, pageBoOut *RespPageBo, err error) { + var ( + query, _ = builder.ToBoundSQL(*cond) + model = k.Db.Model(k.modelInstance()).Where(query) + total int64 + ) + model.Count(&total) + pageBoOut = pageBoOut.SetDataByReq(total, pageBoIn) + model.Limit(pageBoIn.GetSize()).Offset(pageBoIn.GetOffset()).Order("updated_at desc").Find(&list) + return +} + +func (k DataTemp) GetRange(ctx context.Context, cond *builder.Cond) (list []map[string]interface{}, err error) { + var ( + query, _ = builder.ToBoundSQL(*cond) + model = k.Db.Model(k.modelInstance()).Where(query) + ) + err = model.WithContext(ctx).Find(&list).Error + return list, err +} + +func (k DataTemp) GetRangeToMapStruct(ctx context.Context, cond *builder.Cond, data interface{}) (err error) { + var ( + query, _ = builder.ToBoundSQL(*cond) + model = k.Db.Model(k.modelInstance()).Where(query) + ) + err = model.WithContext(ctx).Find(data).Error + return err +} + +func (k DataTemp) GetOneBySearch(cond *builder.Cond) (data map[string]interface{}, err error) { + query, _ := builder.ToBoundSQL(*cond) + if err = k.Db.Model(k.modelInstance()).Where(query).Limit(1).Find(&data).Error; err != nil { + return nil, errcode.SqlErr(err) + } + + return +} + +func (k DataTemp) Exist(ctx context.Context, cond *builder.Cond) (bool, error) { + var data map[string]interface{} + query, _ := builder.ToBoundSQL(*cond) + err := k.Db.WithContext(ctx).Model(k.modelInstance()).Where(query).Limit(1).Find(&data).Error + if err != nil || data != nil { + return true, err + } + return false, nil +} + +func (k DataTemp) GetOneBySearchStruct(ctx context.Context, cond *builder.Cond, data interface{}) (err error) { + query, _ := builder.ToBoundSQL(*cond) + if err = k.Db.Model(k.modelInstance()).WithContext(ctx).Where(query).Limit(1).Find(&data).Error; err != nil { + return errcode.SqlErr(err) + } + + return +} + +func (k DataTemp) GetListToStruct(ctx context.Context, cond *builder.Cond, pageBoIn *ReqPageBo, result interface{}, orderBy string) (pageBoOut *RespPageBo, err error) { + // 参数验证 + if result == nil { + return nil, fmt.Errorf("result cannot be nil") + } + + val := reflect.ValueOf(result) + if val.Kind() != reflect.Ptr { + return nil, fmt.Errorf("result must be a pointer") + } + + elem := val.Elem() + if elem.Kind() != reflect.Slice { + return nil, fmt.Errorf("result must be a pointer to slice") + } + + // 构建基础查询 + query, _ := builder.ToBoundSQL(*cond) + + // 预编译 SQL 以提高性能 + // 使用 Table 指定表名,避免 GORM 的反射开销 + + db := k.Db.WithContext(ctx).Model(k.modelInstance()).Where(query) + + // 获取总数(使用单独的计数查询,避免缓存影响) + var total int64 + countDb := db + if pageBoIn != nil { + if err = countDb.Count(&total).Error; err != nil { + return nil, err + } + } + + // 初始化分页响应 + pageBoOut = &RespPageBo{} + pageBoOut = pageBoOut.SetDataByReq(total, pageBoIn) + + // 如果没有数据,直接返回空切片 + if total == 0 && pageBoIn != nil { + elem.Set(reflect.MakeSlice(elem.Type(), 0, 0)) + return pageBoOut, nil + } + + // 设置排序(使用索引字段提高性能) + if orderBy == "" { + orderBy = "updated_at desc" + } + + // 应用分页和排序,执行查询 + // 使用 Select 指定字段,避免查询所有字段(如果需要优化) + baseQuery := db + if pageBoIn != nil { + baseQuery = db.Limit(pageBoIn.GetSize()).Offset(pageBoIn.GetOffset()).Order(orderBy) + } + if err = baseQuery. + Order(orderBy). + Find(result).Error; err != nil { + return nil, err + } + + return pageBoOut, nil +} + +func (k DataTemp) UpdateByKey(ctx context.Context, key string, id interface{}, data interface{}) (err error) { + if err = k.Db.WithContext(ctx).Model(k.modelInstance()).Where(fmt.Sprintf("%s = ?", key), id).Updates(data).Error; err != nil { + return errcode.SqlErr(err) + } + return +} + +func (k DataTemp) UpdateByCond(ctx context.Context, cond *builder.Cond, data interface{}) (err error) { + var ( + query, _ = builder.ToBoundSQL(*cond) + model = k.Db.Model(k.modelInstance()).Where(query) + ) + err = model.WithContext(ctx).Updates(data).Error + return err +} + +func (k DataTemp) UpdateColumnByCond(ctx context.Context, cond *builder.Cond, column string, data interface{}) (err error) { + var ( + query, _ = builder.ToBoundSQL(*cond) + model = k.Db.Model(k.modelInstance()).Where(query) + ) + err = model.WithContext(ctx).Update(column, data).Error + return err +} + +func (k DataTemp) GetByKey(ctx context.Context, key string, value interface{}, data interface{}) (err error) { + if err = k.Db.WithContext(ctx).Model(k.modelInstance()).Where(fmt.Sprintf("%s = ?", key), value).Find(data).Error; err != nil { + return errcode.SqlErr(err) + } + return +} + +func (k DataTemp) DeleteByKey(ctx context.Context, key string, value int64) error { + result := k.Db.WithContext(ctx).Model(k.modelInstance()).Where(fmt.Sprintf("%s = ?", key), value). + Update("deleted_at", gorm.Expr("CURRENT_TIMESTAMP")) + + if result.Error != nil { + return result.Error + } + if result.RowsAffected == 0 { + return errcode.NotFound("不存在或已被删除") + } + return nil +} + +func (k DataTemp) CountByCond(ctx context.Context, cond *builder.Cond) (int64, error) { + var ( + count int64 + query, _ = builder.ToBoundSQL(*cond) + model = k.Db.Model(k.modelInstance()).Where(query) + ) + err := model.WithContext(ctx).Count(&count).Error + if err != nil { + return 0, err + } + + return count, err +} diff --git a/tmpl/dataTemp/req_page.go b/tmpl/dataTemp/req_page.go new file mode 100644 index 0000000..87f17a2 --- /dev/null +++ b/tmpl/dataTemp/req_page.go @@ -0,0 +1,33 @@ +package dataTemp + +// ReqPageBo 分页请求实体 +type ReqPageBo struct { + Page int //页码,从第1页开始 + Limit int //分页大小 +} + +// GetOffset 获取便宜量 +// 确保 dataTemp/page.go 中有这些方法 +func (p *ReqPageBo) GetSize() int { + if p == nil { + return 10 // 默认每页10条 + } + if p.Limit <= 0 { + return 10 + } + return p.Limit +} + +func (p *ReqPageBo) GetOffset() int { + if p == nil { + return 0 + } + return (p.GetPage() - 1) * p.GetSize() +} + +func (p *ReqPageBo) GetPage() int { + if p == nil || p.Page <= 0 { + return 1 + } + return p.Page +} diff --git a/tmpl/dataTemp/resp_page.go b/tmpl/dataTemp/resp_page.go new file mode 100644 index 0000000..e707c8c --- /dev/null +++ b/tmpl/dataTemp/resp_page.go @@ -0,0 +1,22 @@ +package dataTemp + +// RespPageBo 分页响应实体 +type RespPageBo struct { + Page int //页码 + Limit int //每页大小 + Total int64 //总数 +} + +// SetDataByReq 通过req 设置响应参数 +func (r *RespPageBo) SetDataByReq(total int64, reqPage *ReqPageBo) *RespPageBo { + resp := r + if r == nil { + resp = &RespPageBo{} + } + resp.Total = total + if reqPage != nil { + resp.Page = reqPage.Page + resp.Limit = reqPage.Limit + } + return resp +} diff --git a/tmpl/errcode/common.go b/tmpl/errcode/common.go new file mode 100644 index 0000000..458b7db --- /dev/null +++ b/tmpl/errcode/common.go @@ -0,0 +1,113 @@ +package errcode + +import "fmt" + +var ( + AuthNotFound = &BusinessErr{code: AuthErr, message: "账号不存在"} + AuthStatusFreeze = &BusinessErr{code: AuthErr, message: "账号冻结"} + AuthStatusDel = &BusinessErr{code: AuthErr, message: "身份验证失败"} + AuthStatusPwdFail = &BusinessErr{code: AuthErr, message: "密码错误"} + AuthTokenCreateFail = &BusinessErr{code: AuthErr, message: "token生成失败"} + AuthTokenDelFail = &BusinessErr{code: AuthErr, message: "删除token失败"} + AuthWxLoginFail = &BusinessErr{code: AuthErr, message: "微信登录失败,请稍后重试"} + AuthInfoFail = &BusinessErr{code: AuthErr, message: "登录异常,请重新登录"} + + TokenNotFound = &BusinessErr{code: TokenErr, message: "缺少 Authorization Header"} + TokenFormatErr = &BusinessErr{code: TokenErr, message: "无效的token格式"} + TokenInfoNotFound = &BusinessErr{code: TokenErr, message: "未找到用户信息"} + TokenInvalid = &BusinessErr{code: TokenErr, message: "token过期"} + + PlatsNotFound = &BusinessErr{code: NotFoundErr, message: "信息未找到"} + BadRequest = &BusinessErr{code: BadReqErr, message: "操作失败"} + + ForbiddenError = &BusinessErr{code: ForbiddenErr, message: "权限不足"} + Success = &BusinessErr{code: 200, message: "成功"} + ParamError = &BusinessErr{code: ParamsErr, message: "参数错误"} + + SystemError = &BusinessErr{code: 405, message: "系统错误"} + + ClientNotFound = &BusinessErr{code: 406, message: "未找到client_id"} + SessionNotFound = &BusinessErr{code: 407, message: "未找到会话信息"} + UserNotFound = &BusinessErr{code: NotFoundErr, message: "不存在的用户"} + + KeyNotFound = &BusinessErr{code: 409, message: "身份验证失败"} + SysNotFound = &BusinessErr{code: 410, message: "未找到系统信息"} + SysCodeNotFound = &BusinessErr{code: 411, message: "未找到系统编码"} + InvalidParam = &BusinessErr{code: InvalidParamCode, message: "无效参数"} + WorkflowError = &BusinessErr{code: 501, message: "工作流过程错误"} + ClientInfoNotFound = &BusinessErr{code: NotFoundErr, message: "用户信息未找到"} +) + +const ( + InvalidParamCode = 408 + AuthErr = 403 + TokenErr = 401 + ParamsErr = 422 + BadReqErr = 400 + NotFoundErr = 404 + ForbiddenErr = 403 + BalanceNotEnoughCode = 402 +) + +type BusinessErr struct { + code int + message string +} + +func (e *BusinessErr) Error() string { + return e.message +} +func (e *BusinessErr) Code() int { + return e.code +} + +func NotFound(message string) *BusinessErr { + return &BusinessErr{code: NotFoundErr, message: PlatsNotFound.message + ":" + message} +} + +func (e *BusinessErr) Is(target error) bool { + _, ok := target.(*BusinessErr) + return ok +} + +// CustomErr 自定义错误 +func NewBusinessErr(code int, message string) *BusinessErr { + return &BusinessErr{code: code, message: message} +} + +func SysErrf(message string, arg ...any) *BusinessErr { + return &BusinessErr{code: SystemError.code, message: fmt.Sprintf(message, arg)} +} + +func SysErr(message string) *BusinessErr { + return &BusinessErr{code: SystemError.code, message: message} +} + +func ParamErrf(message string, arg ...any) *BusinessErr { + return &BusinessErr{code: ParamError.code, message: fmt.Sprintf(message, arg)} +} + +func ParamErr(message string) *BusinessErr { + return &BusinessErr{code: ParamError.code, message: ParamError.message + ":" + message} +} + +func SqlErr(err error) *BusinessErr { + + return &BusinessErr{code: ParamError.code, message: "数据操作失败,请联系管理员处理:" + err.Error()} +} + +func BadReq(message string) *BusinessErr { + return &BusinessErr{code: BadReqErr, message: BadRequest.message + ":" + message} +} + +func Forbidden(message string) *BusinessErr { + return &BusinessErr{code: ForbiddenErr, message: ForbiddenError.message + ":" + message} +} + +func (e *BusinessErr) Wrap(err error) *BusinessErr { + return NewBusinessErr(e.code, err.Error()) +} + +func BalanceNotEnoughErr(message string) *BusinessErr { + return NewBusinessErr(BalanceNotEnoughCode, message) +} diff --git a/utils/gorm.go b/utils/gorm.go new file mode 100644 index 0000000..9923987 --- /dev/null +++ b/utils/gorm.go @@ -0,0 +1,133 @@ +package utils + +import ( + "geo/internal/config" + "geo/utils/utils_gorm" + + "gorm.io/gorm" +) + +type Db struct { + Client *gorm.DB +} + +func NewGormDb(c *config.Config) (*Db, func()) { + transDBClient, mf := utils_gorm.DBConn(&c.DB) + //directDBClient, df := directDB(c, hLog) + cleanup := func() { + mf() + //df() + } + return &Db{ + Client: transDBClient, + //DirectDBClient: directDBClient, + }, cleanup +} + +// GetOne 查询单条记录,返回 map +func (d *Db) GetOne(sql string, args ...interface{}) (map[string]interface{}, error) { + var result map[string]interface{} + + // 使用 Raw 执行原生 SQL,Scan 到 map 需要先获取 rows + rows, err := d.Client.Raw(sql, args...).Rows() + if err != nil { + return nil, err + } + defer rows.Close() + + if rows.Next() { + // 获取列名 + columns, err := rows.Columns() + if err != nil { + return nil, err + } + + // 创建扫描用的切片 + values := make([]interface{}, len(columns)) + valuePtrs := make([]interface{}, len(columns)) + for i := range values { + valuePtrs[i] = &values[i] + } + + if err := rows.Scan(valuePtrs...); err != nil { + return nil, err + } + + result = make(map[string]interface{}) + for i, col := range columns { + result[col] = values[i] + } + return result, nil + } + return nil, nil +} + +// GetAll 查询多条记录,返回 map 切片 +func (d *Db) GetAll(sql string, args ...interface{}) ([]map[string]interface{}, error) { + rows, err := d.Client.Raw(sql, args...).Rows() + if err != nil { + return nil, err + } + defer rows.Close() + + results := make([]map[string]interface{}, 0) + + for rows.Next() { + columns, err := rows.Columns() + if err != nil { + return nil, err + } + + values := make([]interface{}, len(columns)) + valuePtrs := make([]interface{}, len(columns)) + for i := range values { + valuePtrs[i] = &values[i] + } + + if err := rows.Scan(valuePtrs...); err != nil { + return nil, err + } + + row := make(map[string]interface{}) + for i, col := range columns { + row[col] = values[i] + } + results = append(results, row) + } + return results, nil +} + +// Execute 执行单条 SQL(INSERT/UPDATE/DELETE),返回影响行数 +func (d *Db) Execute(sql string, args ...interface{}) (int64, error) { + result := d.Client.Exec(sql, args...) + if result.Error != nil { + return 0, result.Error + } + return result.RowsAffected, nil +} + +// ExecuteMany 批量执行 SQL,使用事务 +func (d *Db) ExecuteMany(sql string, argsList [][]interface{}) (int64, error) { + var total int64 + + // 开始事务 + tx := d.Client.Begin() + if tx.Error != nil { + return 0, tx.Error + } + + for _, args := range argsList { + result := tx.Exec(sql, args...) + if result.Error != nil { + tx.Rollback() + return 0, result.Error + } + total += result.RowsAffected + } + + // 提交事务 + if err := tx.Commit().Error; err != nil { + return 0, err + } + return total, nil +} diff --git a/utils/provider_set.go b/utils/provider_set.go new file mode 100644 index 0000000..09fee00 --- /dev/null +++ b/utils/provider_set.go @@ -0,0 +1,9 @@ +package utils + +import ( + "github.com/google/wire" +) + +var ProviderUtils = wire.NewSet( + NewGormDb, +) diff --git a/utils/utils_gorm/gorm.go b/utils/utils_gorm/gorm.go new file mode 100644 index 0000000..d64b671 --- /dev/null +++ b/utils/utils_gorm/gorm.go @@ -0,0 +1,41 @@ +package utils_gorm + +import ( + "database/sql" + "fmt" + "geo/internal/config" + "gorm.io/driver/mysql" + "gorm.io/gorm" + "time" +) + +func DBConn(c *config.DB) (*gorm.DB, func()) { + mysqlConn, err := sql.Open(c.Driver, c.Source) + gormDB, err := gorm.Open( + mysql.New(mysql.Config{Conn: mysqlConn}), + ) + + gormDB.Logger = NewCustomLogger(gormDB) + if err != nil { + panic("failed to connect database") + } + sqlDB, err := gormDB.DB() + + // SetMaxIdleConns sets the maximum number of connections in the idle connection pool. + sqlDB.SetMaxIdleConns(int(c.MaxIdle)) + + // SetMaxOpenConns sets the maximum number of open connections to the database. + sqlDB.SetMaxOpenConns(int(c.MaxLifetime)) + + // SetConnMaxLifetime sets the maximum amount of time a connection may be reused. + sqlDB.SetConnMaxLifetime(time.Hour) + + return gormDB, func() { + if mysqlConn != nil { + fmt.Println("关闭 physicalGoodsDB") + if err := mysqlConn.Close(); err != nil { + fmt.Println("关闭 physicalGoodsDB 失败:", err) + } + } + } +} diff --git a/utils/utils_gorm/sql_log.go b/utils/utils_gorm/sql_log.go new file mode 100644 index 0000000..b34a29d --- /dev/null +++ b/utils/utils_gorm/sql_log.go @@ -0,0 +1,96 @@ +package utils_gorm + +import ( + "context" + "fmt" + "gorm.io/gorm" + "gorm.io/gorm/logger" + "regexp" + "strings" + "time" +) + +type CustomLogger struct { + gormLogger logger.Interface + db *gorm.DB +} + +func NewCustomLogger(db *gorm.DB) *CustomLogger { + return &CustomLogger{ + gormLogger: logger.Default.LogMode(logger.Info), + db: db, + } +} + +func (l *CustomLogger) LogMode(level logger.LogLevel) logger.Interface { + newlogger := *l + newlogger.gormLogger = l.gormLogger.LogMode(level) + return &newlogger +} + +func (l *CustomLogger) Info(ctx context.Context, msg string, data ...interface{}) { + l.gormLogger.Info(ctx, msg, data...) +} + +func (l *CustomLogger) Warn(ctx context.Context, msg string, data ...interface{}) { + l.gormLogger.Warn(ctx, msg, data...) +} + +func (l *CustomLogger) Error(ctx context.Context, msg string, data ...interface{}) { + l.gormLogger.Error(ctx, msg, data...) +} + +func (l *CustomLogger) Trace(ctx context.Context, begin time.Time, fc func() (sql string, rowsAffected int64), err error) { + elapsed := time.Since(begin) + sql, _ := fc() + l.gormLogger.Trace(ctx, begin, fc, err) + operation := extractOperation(sql) + tableName := extractTableName(sql) + fmt.Println(tableName) + //// 将SQL语句保存到数据库 + if operation == 0 || tableName == "sql_log" { + return + } + //go l.db.Model(&SqlLog{}).Create(&SqlLog{ + // OperatorID: 1, + // OperatorName: "test", + // SqlInfo: sql, + // TableNames: tableName, + // Type: operation, + //}) + + // 如果有需要,也可以根据执行时间(elapsed)等条件过滤或处理日志记录 + if elapsed > time.Second { + //l.gormLogger.Warn(ctx, "Slow SQL (> 1s): %s", sql) + } +} + +// extractTableName extracts the table name from a SQL query, supporting quoted table names. +func extractTableName(sql string) string { + // 使用非捕获组匹配多种SQL操作关键词 + re := regexp.MustCompile(`(?i)\b(?:from|update|into|delete\s+from)\b\s+[\` + "`" + `"]?(\w+)[\` + "`" + `"]?`) + match := re.FindStringSubmatch(sql) + + // 检查是否匹配成功 + if len(match) > 1 { + return match[1] + } + + return "" +} + +// extractOperation extracts the operation type from a SQL query. +func extractOperation(sql string) int32 { + sql = strings.TrimSpace(strings.ToLower(sql)) + var operation int32 + if strings.HasPrefix(sql, "select") { + operation = 0 + } else if strings.HasPrefix(sql, "insert") { + operation = 1 + } else if strings.HasPrefix(sql, "update") { + operation = 3 + } else if strings.HasPrefix(sql, "delete") { + operation = 2 + } + return operation +}