This commit is contained in:
commit
83f7d34f3e
|
|
@ -0,0 +1,8 @@
|
|||
# Default ignored files
|
||||
/shelf/
|
||||
/workspace.xml
|
||||
# Editor-based HTTP Client requests
|
||||
/httpRequests/
|
||||
# Datasource local storage ignored files
|
||||
/dataSources/
|
||||
/dataSources.local.xml
|
||||
|
|
@ -0,0 +1,9 @@
|
|||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<module type="WEB_MODULE" version="4">
|
||||
<component name="Go" enabled="true" />
|
||||
<component name="NewModuleRootManager">
|
||||
<content url="file://$MODULE_DIR$" />
|
||||
<orderEntry type="inheritedJdk" />
|
||||
<orderEntry type="sourceFolder" forTests="false" />
|
||||
</component>
|
||||
</module>
|
||||
|
|
@ -0,0 +1,8 @@
|
|||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="ProjectModuleManager">
|
||||
<modules>
|
||||
<module fileurl="file://$PROJECT_DIR$/.idea/geogo.iml" filepath="$PROJECT_DIR$/.idea/geogo.iml" />
|
||||
</modules>
|
||||
</component>
|
||||
</project>
|
||||
|
|
@ -0,0 +1,6 @@
|
|||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="VcsDirectoryMappings">
|
||||
<mapping directory="$PROJECT_DIR$" vcs="Git" />
|
||||
</component>
|
||||
</project>
|
||||
|
|
@ -0,0 +1,30 @@
|
|||
GOHOSTOS:=$(shell go env GOHOSTOS)
|
||||
GOPATH:=$(shell go env GOPATH)
|
||||
VERSION=$(shell git describe --tags --always)
|
||||
PROJECTNAME=yumchina
|
||||
|
||||
ifeq ($(GOHOSTOS), windows)
|
||||
#the `find.exe` is different from `find` in bash/shell.
|
||||
#to see https://docs.microsoft.com/en-us/windows-server/administration/windows-commands/find.
|
||||
#changed to use git-bash.exe to run find cli or other cli friendly, caused of every developer has a Git.
|
||||
#Git_Bash= $(subst cmd\,bin\bash.exe,$(dir $(shell where git)))
|
||||
Git_Bash=$(subst \,/,$(subst cmd\,bin\bash.exe,$(dir $(shell where git | grep cmd))))
|
||||
INTERNAL_PROTO_FILES=$(shell $(Git_Bash) -c "find internal -name *.proto")
|
||||
API_PROTO_FILES=$(shell $(Git_Bash) -c "find api -name *.proto")
|
||||
else
|
||||
INTERNAL_PROTO_FILES=$(shell find internal -name *.proto)
|
||||
API_PROTO_FILES=$(shell find api -name *.proto)
|
||||
endif
|
||||
|
||||
|
||||
.PHONY: wire
|
||||
# generate wire
|
||||
wire:
|
||||
cd ./cmd/server && wire
|
||||
|
||||
.PHONY: build
|
||||
# build
|
||||
build:
|
||||
# make config;
|
||||
make wire;
|
||||
mkdir -p bin/ && go build -ldflags "-X main.Version=$(VERSION)" -o ./bin/ ./...
|
||||
|
|
@ -0,0 +1,26 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"geo/internal/config"
|
||||
|
||||
"github.com/gofiber/fiber/v2/log"
|
||||
)
|
||||
|
||||
func main() {
|
||||
bc, err := config.LoadConfig()
|
||||
if err != nil {
|
||||
log.Fatalf("加载配置失败: %v", err)
|
||||
}
|
||||
//ctx, cancel := context.WithCancel(nil)
|
||||
app, cleanup, err := InitializeApp(bc, log.DefaultLogger())
|
||||
if err != nil {
|
||||
log.Fatalf("项目初始化失败: %v", err)
|
||||
}
|
||||
defer func() {
|
||||
cleanup()
|
||||
//cancel()
|
||||
}()
|
||||
addr := fmt.Sprintf("0.0.0.0:%d", bc.Server.Port)
|
||||
log.Fatal(app.HttpServer.Listen(addr))
|
||||
}
|
||||
|
|
@ -0,0 +1,28 @@
|
|||
//go:build wireinject
|
||||
// +build wireinject
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"geo/internal/biz"
|
||||
"geo/internal/config"
|
||||
"geo/internal/data/impl"
|
||||
"geo/internal/server"
|
||||
"geo/internal/service"
|
||||
"geo/utils"
|
||||
|
||||
"github.com/gofiber/fiber/v2/log"
|
||||
"github.com/google/wire"
|
||||
)
|
||||
|
||||
// InitializeApp 初始化应用程序
|
||||
func InitializeApp(*config.Config, log.AllLogger) (*server.Servers, func(), error) {
|
||||
panic(wire.Build(
|
||||
server.ProviderSetServer,
|
||||
|
||||
service.ProviderSetAppService,
|
||||
impl.ProviderImpl,
|
||||
utils.ProviderUtils,
|
||||
biz.ProviderSetBiz,
|
||||
))
|
||||
}
|
||||
|
|
@ -0,0 +1,41 @@
|
|||
// Code generated by Wire. DO NOT EDIT.
|
||||
|
||||
//go:generate go run -mod=mod github.com/google/wire/cmd/wire
|
||||
//go:build !wireinject
|
||||
// +build !wireinject
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"geo/internal/biz"
|
||||
"geo/internal/config"
|
||||
"geo/internal/data/impl"
|
||||
"geo/internal/server"
|
||||
"geo/internal/server/router"
|
||||
"geo/internal/service"
|
||||
"geo/utils"
|
||||
"github.com/gofiber/fiber/v2/log"
|
||||
)
|
||||
|
||||
// Injectors from wire.go:
|
||||
|
||||
// InitializeApp 初始化应用程序
|
||||
func InitializeApp(configConfig *config.Config, allLogger log.AllLogger) (*server.Servers, func(), error) {
|
||||
db, cleanup := utils.NewGormDb(configConfig)
|
||||
tokenImpl := impl.NewTokenImpl(db)
|
||||
userImpl := impl.NewUserImpl(db)
|
||||
platImpl := impl.NewPlatImpl(db)
|
||||
publishImpl := impl.NewPublishImpl(db)
|
||||
loginRelationImpl := impl.NewLoginRelationImpl(db)
|
||||
publishBiz := biz.NewPublishBiz(configConfig, publishImpl, userImpl, platImpl, tokenImpl, loginRelationImpl)
|
||||
appService := service.NewAppService(configConfig, tokenImpl, userImpl, platImpl, publishBiz)
|
||||
loginService := service.NewLoginService(configConfig, publishBiz)
|
||||
publishService := service.NewPublishService(configConfig, publishBiz, db)
|
||||
appModule := router.NewAppModule(configConfig, appService, loginService, publishService)
|
||||
routerServer := router.NewRouterServer(appModule)
|
||||
app := server.NewHTTPServer(routerServer)
|
||||
servers := server.NewServers(configConfig, app)
|
||||
return servers, func() {
|
||||
cleanup()
|
||||
}, nil
|
||||
}
|
||||
|
|
@ -0,0 +1 @@
|
|||
[]
|
||||
|
|
@ -0,0 +1,31 @@
|
|||
@echo off
|
||||
chcp 65001 >nul
|
||||
setlocal enabledelayedexpansion
|
||||
|
||||
REM Usage: gen.bat user
|
||||
REM Example: gen.bat usercenter
|
||||
|
||||
if "%1"=="" (
|
||||
echo Error: Please provide table name
|
||||
echo Usage: gen.bat user
|
||||
pause
|
||||
exit /b 1
|
||||
)
|
||||
|
||||
set tables=%1
|
||||
set modeldir=.\internal\data\model
|
||||
set prefix=edu_
|
||||
|
||||
set fullTableName=%prefix%%tables%
|
||||
|
||||
echo Generating table: %fullTableName%
|
||||
|
||||
go run gorm.io/gen/tools/gentool -dsn "root:lansexiongdi6,@tcp(47.97.27.195:3306)/geo?charset=utf8mb4&parseTime=true&loc=Asia/Shanghai" -outPath "%modeldir%" -onlyModel -modelPkgName "model" -tables "%fullTableName%"
|
||||
|
||||
if %errorlevel% equ 0 (
|
||||
echo Success! Generated files in: %modeldir%
|
||||
) else (
|
||||
echo Failed to generate models
|
||||
)
|
||||
|
||||
pause
|
||||
|
|
@ -0,0 +1,19 @@
|
|||
#!/bin/bash
|
||||
|
||||
# 使用方法:
|
||||
# ./genModel.sh usercenter user
|
||||
# ./genModel.sh usercenter user_auth
|
||||
# 再将./genModel下的文件剪切到对应服务的model目录里面,记得改package
|
||||
|
||||
|
||||
#生成的表名
|
||||
tables=$1
|
||||
#表生成的genmodel目录
|
||||
modeldir=./internal/data/model
|
||||
|
||||
# 数据库配置
|
||||
prefix=edu_
|
||||
|
||||
|
||||
|
||||
gentool --dsn "root:lansexiongdi6,@tcp(47.97.27.195:3306)/geo?charset=utf8mb4&parseTime=true&loc=Asia%2FShanghai" -outPath ${modeldir} -onlyModel -modelPkgName "model" -tables ${prefix}${tables}
|
||||
|
|
@ -0,0 +1,64 @@
|
|||
module geo
|
||||
|
||||
go 1.26.1
|
||||
|
||||
require (
|
||||
github.com/JohannesKaufmann/html-to-markdown v1.6.0
|
||||
github.com/go-kratos/kratos/v2 v2.9.2
|
||||
github.com/go-playground/validator/v10 v10.30.2
|
||||
github.com/go-rod/rod v0.116.2
|
||||
github.com/go-viper/mapstructure/v2 v2.5.0
|
||||
github.com/gofiber/fiber/v2 v2.52.12
|
||||
github.com/golang-jwt/jwt/v5 v5.3.1
|
||||
github.com/google/uuid v1.6.0
|
||||
github.com/google/wire v0.7.0
|
||||
github.com/grokify/html-strip-tags-go v0.1.0
|
||||
github.com/redis/go-redis/v9 v9.18.0
|
||||
github.com/spf13/viper v1.21.0
|
||||
gorm.io/driver/mysql v1.6.0
|
||||
gorm.io/gorm v1.31.1
|
||||
xorm.io/builder v0.3.13
|
||||
)
|
||||
|
||||
require (
|
||||
filippo.io/edwards25519 v1.1.0 // indirect
|
||||
github.com/PuerkitoBio/goquery v1.9.2 // indirect
|
||||
github.com/andybalholm/brotli v1.1.0 // indirect
|
||||
github.com/andybalholm/cascadia v1.3.2 // indirect
|
||||
github.com/cespare/xxhash/v2 v2.3.0 // indirect
|
||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
|
||||
github.com/fsnotify/fsnotify v1.9.0 // indirect
|
||||
github.com/gabriel-vasile/mimetype v1.4.13 // indirect
|
||||
github.com/go-playground/locales v0.14.1 // indirect
|
||||
github.com/go-playground/universal-translator v0.18.1 // indirect
|
||||
github.com/go-sql-driver/mysql v1.8.1 // indirect
|
||||
github.com/jinzhu/inflection v1.0.0 // indirect
|
||||
github.com/jinzhu/now v1.1.5 // indirect
|
||||
github.com/klauspost/compress v1.17.9 // indirect
|
||||
github.com/leodido/go-urn v1.4.0 // indirect
|
||||
github.com/mattn/go-colorable v0.1.13 // indirect
|
||||
github.com/mattn/go-isatty v0.0.20 // indirect
|
||||
github.com/mattn/go-runewidth v0.0.16 // indirect
|
||||
github.com/pelletier/go-toml/v2 v2.2.4 // indirect
|
||||
github.com/rivo/uniseg v0.2.0 // indirect
|
||||
github.com/sagikazarmark/locafero v0.11.0 // indirect
|
||||
github.com/sourcegraph/conc v0.3.1-0.20240121214520-5f936abd7ae8 // indirect
|
||||
github.com/spf13/afero v1.15.0 // indirect
|
||||
github.com/spf13/cast v1.10.0 // indirect
|
||||
github.com/spf13/pflag v1.0.10 // indirect
|
||||
github.com/subosito/gotenv v1.6.0 // indirect
|
||||
github.com/valyala/bytebufferpool v1.0.0 // indirect
|
||||
github.com/valyala/fasthttp v1.51.0 // indirect
|
||||
github.com/valyala/tcplisten v1.0.0 // indirect
|
||||
github.com/ysmood/fetchup v0.2.3 // indirect
|
||||
github.com/ysmood/goob v0.4.0 // indirect
|
||||
github.com/ysmood/got v0.40.0 // indirect
|
||||
github.com/ysmood/gson v0.7.3 // indirect
|
||||
github.com/ysmood/leakless v0.9.0 // indirect
|
||||
go.uber.org/atomic v1.11.0 // indirect
|
||||
go.yaml.in/yaml/v3 v3.0.4 // indirect
|
||||
golang.org/x/crypto v0.49.0 // indirect
|
||||
golang.org/x/net v0.51.0 // indirect
|
||||
golang.org/x/sys v0.42.0 // indirect
|
||||
golang.org/x/text v0.35.0 // indirect
|
||||
)
|
||||
|
|
@ -0,0 +1,222 @@
|
|||
filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA=
|
||||
filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4=
|
||||
gitea.com/xorm/sqlfiddle v0.0.0-20180821085327-62ce714f951a h1:lSA0F4e9A2NcQSqGqTOXqu2aRi/XEQxDCBwM8yJtE6s=
|
||||
gitea.com/xorm/sqlfiddle v0.0.0-20180821085327-62ce714f951a/go.mod h1:EXuID2Zs0pAQhH8yz+DNjUbjppKQzKFAn28TMYPB6IU=
|
||||
github.com/JohannesKaufmann/html-to-markdown v1.6.0 h1:04VXMiE50YYfCfLboJCLcgqF5x+rHJnb1ssNmqpLH/k=
|
||||
github.com/JohannesKaufmann/html-to-markdown v1.6.0/go.mod h1:NUI78lGg/a7vpEJTz/0uOcYMaibytE4BUOQS8k78yPQ=
|
||||
github.com/PuerkitoBio/goquery v1.9.2 h1:4/wZksC3KgkQw7SQgkKotmKljk0M6V8TUvA8Wb4yPeE=
|
||||
github.com/PuerkitoBio/goquery v1.9.2/go.mod h1:GHPCaP0ODyyxqcNoFGYlAprUFH81NuRPd0GX3Zu2Mvk=
|
||||
github.com/andybalholm/brotli v1.1.0 h1:eLKJA0d02Lf0mVpIDgYnqXcUn0GqVmEFny3VuID1U3M=
|
||||
github.com/andybalholm/brotli v1.1.0/go.mod h1:sms7XGricyQI9K10gOSf56VKKWS4oLer58Q+mhRPtnY=
|
||||
github.com/andybalholm/cascadia v1.3.2 h1:3Xi6Dw5lHF15JtdcmAHD3i1+T8plmv7BQ/nsViSLyss=
|
||||
github.com/andybalholm/cascadia v1.3.2/go.mod h1:7gtRlve5FxPPgIgX36uWBX58OdBsSS6lUvCFb+h7KvU=
|
||||
github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs=
|
||||
github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c=
|
||||
github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA=
|
||||
github.com/bsm/gomega v1.27.10/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0H+O0=
|
||||
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
|
||||
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78=
|
||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
|
||||
github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8=
|
||||
github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0=
|
||||
github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k=
|
||||
github.com/fsnotify/fsnotify v1.9.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0=
|
||||
github.com/gabriel-vasile/mimetype v1.4.13 h1:46nXokslUBsAJE/wMsp5gtO500a4F3Nkz9Ufpk2AcUM=
|
||||
github.com/gabriel-vasile/mimetype v1.4.13/go.mod h1:d+9Oxyo1wTzWdyVUPMmXFvp4F9tea18J8ufA774AB3s=
|
||||
github.com/go-kratos/kratos/v2 v2.9.2 h1:px8GJQBeLpquDKQWQ9zohEWiLA8n4D/pv7aH3asvUvo=
|
||||
github.com/go-kratos/kratos/v2 v2.9.2/go.mod h1:Jc7jaeYd4RAPjetun2C+oFAOO7HNMHTT/Z4LxpuEDJM=
|
||||
github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s=
|
||||
github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
|
||||
github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA=
|
||||
github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY=
|
||||
github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY=
|
||||
github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY=
|
||||
github.com/go-playground/validator/v10 v10.30.2 h1:JiFIMtSSHb2/XBUbWM4i/MpeQm9ZK2xqPNk8vgvu5JQ=
|
||||
github.com/go-playground/validator/v10 v10.30.2/go.mod h1:mAf2pIOVXjTEBrwUMGKkCWKKPs9NheYGabeB04txQSc=
|
||||
github.com/go-rod/rod v0.116.2 h1:A5t2Ky2A+5eD/ZJQr1EfsQSe5rms5Xof/qj296e+ZqA=
|
||||
github.com/go-rod/rod v0.116.2/go.mod h1:H+CMO9SCNc2TJ2WfrG+pKhITz57uGNYU43qYHh438Mg=
|
||||
github.com/go-sql-driver/mysql v1.8.1 h1:LedoTUt/eveggdHS9qUFC1EFSa8bU2+1pZjSRpvNJ1Y=
|
||||
github.com/go-sql-driver/mysql v1.8.1/go.mod h1:wEBSXgmK//2ZFJyE+qWnIsVGmvmEKlqwuVSjsCm7DZg=
|
||||
github.com/go-viper/mapstructure/v2 v2.5.0 h1:vM5IJoUAy3d7zRSVtIwQgBj7BiWtMPfmPEgAXnvj1Ro=
|
||||
github.com/go-viper/mapstructure/v2 v2.5.0/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM=
|
||||
github.com/gofiber/fiber/v2 v2.52.12 h1:0LdToKclcPOj8PktUdIKo9BUohjjwfnQl42Dhw8/WUw=
|
||||
github.com/gofiber/fiber/v2 v2.52.12/go.mod h1:YEcBbO/FB+5M1IZNBP9FO3J9281zgPAreiI1oqg8nDw=
|
||||
github.com/golang-jwt/jwt/v5 v5.3.1 h1:kYf81DTWFe7t+1VvL7eS+jKFVWaUnK9cB1qbwn63YCY=
|
||||
github.com/golang-jwt/jwt/v5 v5.3.1/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE=
|
||||
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
|
||||
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
|
||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/google/wire v0.7.0 h1:JxUKI6+CVBgCO2WToKy/nQk0sS+amI9z9EjVmdaocj4=
|
||||
github.com/google/wire v0.7.0/go.mod h1:n6YbUQD9cPKTnHXEBN2DXlOp/mVADhVErcMFb0v3J18=
|
||||
github.com/grokify/html-strip-tags-go v0.1.0 h1:03UrQLjAny8xci+R+qjCce/MYnpNXCtgzltlQbOBae4=
|
||||
github.com/grokify/html-strip-tags-go v0.1.0/go.mod h1:ZdzgfHEzAfz9X6Xe5eBLVblWIxXfYSQ40S/VKrAOGpc=
|
||||
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
|
||||
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
|
||||
github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
|
||||
github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
|
||||
github.com/klauspost/compress v1.17.9 h1:6KIumPrER1LHsvBVuDa0r5xaG0Es51mhhB9BQB2qeMA=
|
||||
github.com/klauspost/compress v1.17.9/go.mod h1:Di0epgTjJY877eYKx5yC51cX2A2Vl2ibi7bDH9ttBbw=
|
||||
github.com/klauspost/cpuid/v2 v2.0.9 h1:lgaqFMSdTdQYdZ04uHyN2d/eKdOMyi2YLSvlQIBFYa4=
|
||||
github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
|
||||
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
|
||||
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
|
||||
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
|
||||
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
|
||||
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
|
||||
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
|
||||
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
|
||||
github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ=
|
||||
github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI=
|
||||
github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA=
|
||||
github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg=
|
||||
github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
|
||||
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
||||
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||
github.com/mattn/go-runewidth v0.0.16 h1:E5ScNMtiwvlvB5paMFdw9p4kSQzbXFikJ5SQO6TULQc=
|
||||
github.com/mattn/go-runewidth v0.0.16/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w=
|
||||
github.com/pelletier/go-toml/v2 v2.2.4 h1:mye9XuhQ6gvn5h28+VilKrrPoQVanw5PMw/TB0t5Ec4=
|
||||
github.com/pelletier/go-toml/v2 v2.2.4/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY=
|
||||
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/redis/go-redis/v9 v9.18.0 h1:pMkxYPkEbMPwRdenAzUNyFNrDgHx9U+DrBabWNfSRQs=
|
||||
github.com/redis/go-redis/v9 v9.18.0/go.mod h1:k3ufPphLU5YXwNTUcCRXGxUoF1fqxnhFQmscfkCoDA0=
|
||||
github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY=
|
||||
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
|
||||
github.com/rogpeppe/go-internal v1.11.0 h1:cWPaGQEPrBb5/AsnsZesgZZ9yb1OQ+GOISoDNXVBh4M=
|
||||
github.com/rogpeppe/go-internal v1.11.0/go.mod h1:ddIwULY96R17DhadqLgMfk9H9tvdUzkipdSkR5nkCZA=
|
||||
github.com/sagikazarmark/locafero v0.11.0 h1:1iurJgmM9G3PA/I+wWYIOw/5SyBtxapeHDcg+AAIFXc=
|
||||
github.com/sagikazarmark/locafero v0.11.0/go.mod h1:nVIGvgyzw595SUSUE6tvCp3YYTeHs15MvlmU87WwIik=
|
||||
github.com/sebdah/goldie/v2 v2.5.3 h1:9ES/mNN+HNUbNWpVAlrzuZ7jE+Nrczbj8uFRjM7624Y=
|
||||
github.com/sebdah/goldie/v2 v2.5.3/go.mod h1:oZ9fp0+se1eapSRjfYbsV/0Hqhbuu3bJVvKI/NNtssI=
|
||||
github.com/sergi/go-diff v1.0.0/go.mod h1:0CfEIISq7TuYL3j771MWULgwwjU+GofnZX9QAmXWZgo=
|
||||
github.com/sergi/go-diff v1.3.1 h1:xkr+Oxo4BOQKmkn/B9eMK0g5Kg/983T9DqqPHwYqD+8=
|
||||
github.com/sergi/go-diff v1.3.1/go.mod h1:aMJSSKb2lpPvRNec0+w3fl7LP9IOFzdc9Pa4NFbPK1I=
|
||||
github.com/sourcegraph/conc v0.3.1-0.20240121214520-5f936abd7ae8 h1:+jumHNA0Wrelhe64i8F6HNlS8pkoyMv5sreGx2Ry5Rw=
|
||||
github.com/sourcegraph/conc v0.3.1-0.20240121214520-5f936abd7ae8/go.mod h1:3n1Cwaq1E1/1lhQhtRK2ts/ZwZEhjcQeJQ1RuC6Q/8U=
|
||||
github.com/spf13/afero v1.15.0 h1:b/YBCLWAJdFWJTN9cLhiXXcD7mzKn9Dm86dNnfyQw1I=
|
||||
github.com/spf13/afero v1.15.0/go.mod h1:NC2ByUVxtQs4b3sIUphxK0NioZnmxgyCrfzeuq8lxMg=
|
||||
github.com/spf13/cast v1.10.0 h1:h2x0u2shc1QuLHfxi+cTJvs30+ZAHOGRic8uyGTDWxY=
|
||||
github.com/spf13/cast v1.10.0/go.mod h1:jNfB8QC9IA6ZuY2ZjDp0KtFO2LZZlg4S/7bzP6qqeHo=
|
||||
github.com/spf13/pflag v1.0.10 h1:4EBh2KAYBwaONj6b2Ye1GiHfwjqyROoF4RwYO+vPwFk=
|
||||
github.com/spf13/pflag v1.0.10/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
|
||||
github.com/spf13/viper v1.21.0 h1:x5S+0EU27Lbphp4UKm1C+1oQO+rKx36vfCoaVebLFSU=
|
||||
github.com/spf13/viper v1.21.0/go.mod h1:P0lhsswPGWD/1lZJ9ny3fYnVqxiegrlNrEmgLjbTCAY=
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
|
||||
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
|
||||
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
|
||||
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
|
||||
github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8=
|
||||
github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU=
|
||||
github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw=
|
||||
github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc=
|
||||
github.com/valyala/fasthttp v1.51.0 h1:8b30A5JlZ6C7AS81RsWjYMQmrZG6feChmgAolCl1SqA=
|
||||
github.com/valyala/fasthttp v1.51.0/go.mod h1:oI2XroL+lI7vdXyYoQk03bXBThfFl2cVdIA3Xl7cH8g=
|
||||
github.com/valyala/tcplisten v1.0.0 h1:rBHj/Xf+E1tRGZyWIWwJDiRY0zc1Js+CV5DqwacVSA8=
|
||||
github.com/valyala/tcplisten v1.0.0/go.mod h1:T0xQ8SeCZGxckz9qRXTfG43PvQ/mcWh7FwZEA7Ioqkc=
|
||||
github.com/ysmood/fetchup v0.2.3 h1:ulX+SonA0Vma5zUFXtv52Kzip/xe7aj4vqT5AJwQ+ZQ=
|
||||
github.com/ysmood/fetchup v0.2.3/go.mod h1:xhibcRKziSvol0H1/pj33dnKrYyI2ebIvz5cOOkYGns=
|
||||
github.com/ysmood/goob v0.4.0 h1:HsxXhyLBeGzWXnqVKtmT9qM7EuVs/XOgkX7T6r1o1AQ=
|
||||
github.com/ysmood/goob v0.4.0/go.mod h1:u6yx7ZhS4Exf2MwciFr6nIM8knHQIE22lFpWHnfql18=
|
||||
github.com/ysmood/gop v0.2.0 h1:+tFrG0TWPxT6p9ZaZs+VY+opCvHU8/3Fk6BaNv6kqKg=
|
||||
github.com/ysmood/gop v0.2.0/go.mod h1:rr5z2z27oGEbyB787hpEcx4ab8cCiPnKxn0SUHt6xzk=
|
||||
github.com/ysmood/got v0.40.0 h1:ZQk1B55zIvS7zflRrkGfPDrPG3d7+JOza1ZkNxcc74Q=
|
||||
github.com/ysmood/got v0.40.0/go.mod h1:W7DdpuX6skL3NszLmAsC5hT7JAhuLZhByVzHTq874Qg=
|
||||
github.com/ysmood/gotrace v0.6.0 h1:SyI1d4jclswLhg7SWTL6os3L1WOKeNn/ZtzVQF8QmdY=
|
||||
github.com/ysmood/gotrace v0.6.0/go.mod h1:TzhIG7nHDry5//eYZDYcTzuJLYQIkykJzCRIo4/dzQM=
|
||||
github.com/ysmood/gson v0.7.3 h1:QFkWbTH8MxyUTKPkVWAENJhxqdBa4lYTQWqZCiLG6kE=
|
||||
github.com/ysmood/gson v0.7.3/go.mod h1:3Kzs5zDl21g5F/BlLTNcuAGAYLKt2lV5G8D1zF3RNmg=
|
||||
github.com/ysmood/leakless v0.9.0 h1:qxCG5VirSBvmi3uynXFkcnLMzkphdh3xx5FtrORwDCU=
|
||||
github.com/ysmood/leakless v0.9.0/go.mod h1:R8iAXPRaG97QJwqxs74RdwzcRHT1SWCGTNqY8q0JvMQ=
|
||||
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
|
||||
github.com/yuin/goldmark v1.7.1 h1:3bajkSilaCbjdKVsKdZjZCLBNPL9pYzrCakKaf4U49U=
|
||||
github.com/yuin/goldmark v1.7.1/go.mod h1:uzxRWxtg69N339t3louHJ7+O03ezfj6PlliRlaOzY1E=
|
||||
github.com/zeebo/xxh3 v1.0.2 h1:xZmwmqxHZA8AI603jOQ0tMqmBr9lPeFwGg6d+xy9DC0=
|
||||
github.com/zeebo/xxh3 v1.0.2/go.mod h1:5NWz9Sef7zIDm2JHfFlcQvNekmcEl9ekUZQQKCYaDcA=
|
||||
go.uber.org/atomic v1.11.0 h1:ZvwS0R+56ePWxUNi+Atn9dWONBPp/AUETXlHW0DxSjE=
|
||||
go.uber.org/atomic v1.11.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0=
|
||||
go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc=
|
||||
go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg=
|
||||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
|
||||
golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU=
|
||||
golang.org/x/crypto v0.22.0/go.mod h1:vr6Su+7cTlO45qkww3VDJlzDn0ctJvRgYbC2NvXHt+M=
|
||||
golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8=
|
||||
golang.org/x/crypto v0.49.0 h1:+Ng2ULVvLHnJ/ZFEq4KdcDd/cfjrrjjNSXNzxg0Y4U4=
|
||||
golang.org/x/crypto v0.49.0/go.mod h1:ErX4dUh2UM+CFYiXZRTcMpEcN8b/1gxEuv3nODoYtCA=
|
||||
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
|
||||
golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
|
||||
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
|
||||
golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
|
||||
golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs=
|
||||
golang.org/x/net v0.9.0/go.mod h1:d48xBJpPfHeWQsugry2m+kC02ZBRGRgulfHnEXEuWns=
|
||||
golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg=
|
||||
golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44=
|
||||
golang.org/x/net v0.24.0/go.mod h1:2Q7sJY5mzlzWjKtYUEXSlBWCdyaioyXzRB2RtU8KVE8=
|
||||
golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM=
|
||||
golang.org/x/net v0.51.0 h1:94R/GTO7mt3/4wIKpcR5gkGmRLOuE/2hNGeWq/GBIFo=
|
||||
golang.org/x/net v0.51.0/go.mod h1:aamm+2QF5ogm02fjy5Bb7CQ0WMt1/WVM7FtyaTLlA9Y=
|
||||
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4=
|
||||
golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0=
|
||||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.7.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/sys v0.19.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/sys v0.42.0 h1:omrd2nAlyT5ESRdCLYdm3+fMfNFE/+Rf4bDIQImRJeo=
|
||||
golang.org/x/sys v0.42.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
|
||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
|
||||
golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k=
|
||||
golang.org/x/term v0.7.0/go.mod h1:P32HKFT3hSsZrRxla30E9HqToFYAQPCMs/zFMBUFqPY=
|
||||
golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo=
|
||||
golang.org/x/term v0.17.0/go.mod h1:lLRBjIVuehSbZlaOtGMbcMncT+aqLLLmKrsjNrUguwk=
|
||||
golang.org/x/term v0.19.0/go.mod h1:2CuTdWZ7KHSQwUzKva0cbMg6q2DMI3Mmxp+gKJbskEk=
|
||||
golang.org/x/term v0.20.0/go.mod h1:8UkIAJTvZgivsXaD6/pH6U9ecQzZ45awqEOzuCvwpFY=
|
||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
|
||||
golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
|
||||
golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
|
||||
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
|
||||
golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
|
||||
golang.org/x/text v0.35.0 h1:JOVx6vVDFokkpaq1AEptVzLTpDe9KGpj5tR4/X+ybL8=
|
||||
golang.org/x/text v0.35.0/go.mod h1:khi/HExzZJ2pGnjenulevKNX1W67CUy0AsXcNubPGCA=
|
||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
|
||||
golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
|
||||
golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU=
|
||||
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
|
||||
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||
gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY=
|
||||
gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
gorm.io/driver/mysql v1.6.0 h1:eNbLmNTpPpTOVZi8MMxCi2aaIm0ZpInbORNXDwyLGvg=
|
||||
gorm.io/driver/mysql v1.6.0/go.mod h1:D/oCC2GWK3M/dqoLxnOlaNKmXz8WNTfcS9y5ovaSqKo=
|
||||
gorm.io/gorm v1.31.1 h1:7CA8FTFz/gRfgqgpeKIBcervUn3xSyPUmr6B2WXJ7kg=
|
||||
gorm.io/gorm v1.31.1/go.mod h1:XyQVbO2k6YkOis7C2437jSit3SsDK72s7n7rsSHd+Gs=
|
||||
xorm.io/builder v0.3.13 h1:a3jmiVVL19psGeXx8GIurTp7p0IIgqeDmwhcR6BAOAo=
|
||||
xorm.io/builder v0.3.13/go.mod h1:aUW0S9eb9VCaPohFCH3j7czOx1PMW3i1HrSzbLYGBSE=
|
||||
|
|
@ -0,0 +1,9 @@
|
|||
package biz
|
||||
|
||||
import (
|
||||
"github.com/google/wire"
|
||||
)
|
||||
|
||||
var ProviderSetBiz = wire.NewSet(
|
||||
NewPublishBiz,
|
||||
)
|
||||
|
|
@ -0,0 +1,102 @@
|
|||
package biz
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"geo/internal/config"
|
||||
"geo/internal/data/impl"
|
||||
"geo/internal/data/model"
|
||||
"geo/tmpl/errcode"
|
||||
|
||||
"xorm.io/builder"
|
||||
)
|
||||
|
||||
type PublishBiz struct {
|
||||
cfg *config.Config
|
||||
publishImpl *impl.PublishImpl
|
||||
userImpl *impl.UserImpl
|
||||
platImpl *impl.PlatImpl
|
||||
tokenImpl *impl.TokenImpl
|
||||
loginRelationImpl *impl.LoginRelationImpl
|
||||
}
|
||||
|
||||
func NewPublishBiz(
|
||||
cfg *config.Config,
|
||||
publishImpl *impl.PublishImpl,
|
||||
userImpl *impl.UserImpl,
|
||||
platImpl *impl.PlatImpl,
|
||||
tokenImpl *impl.TokenImpl,
|
||||
loginRelationImpl *impl.LoginRelationImpl,
|
||||
|
||||
) *PublishBiz {
|
||||
return &PublishBiz{
|
||||
cfg: cfg,
|
||||
publishImpl: publishImpl,
|
||||
loginRelationImpl: loginRelationImpl,
|
||||
userImpl: userImpl,
|
||||
platImpl: platImpl,
|
||||
tokenImpl: tokenImpl,
|
||||
}
|
||||
}
|
||||
|
||||
func (b *PublishBiz) ValidateAccessToken(ctx context.Context, accessToken string) (*model.Token, error) {
|
||||
cond := builder.NewCond().
|
||||
And(builder.Eq{"access_token": accessToken}).
|
||||
And(builder.Eq{"status": 1})
|
||||
tokenInfo := &model.Token{}
|
||||
err := b.tokenImpl.GetOneBySearchStruct(ctx, &cond, tokenInfo)
|
||||
if err != nil || tokenInfo == nil {
|
||||
return nil, errcode.Forbidden("密钥无效或已禁用")
|
||||
}
|
||||
return tokenInfo, nil
|
||||
}
|
||||
|
||||
func (b *PublishBiz) BatchInsertPublish(ctx context.Context, list []*model.Publish) error {
|
||||
|
||||
return b.publishImpl.Add(ctx, list)
|
||||
}
|
||||
|
||||
func (b *PublishBiz) GetPendingPublish(ctx context.Context, tokenID int) (map[string]interface{}, error) {
|
||||
|
||||
currentTime := time.Now()
|
||||
cond := builder.NewCond().
|
||||
And(builder.Eq{"p.token_id": tokenID}).
|
||||
And(builder.Eq{"p.status": 1}).
|
||||
And(builder.Lte{"p.publish_time": currentTime})
|
||||
|
||||
return b.publishImpl.GetOneWithPlat(ctx, &cond)
|
||||
}
|
||||
|
||||
func (b *PublishBiz) GetTaskByRequestID(ctx context.Context, requestID string) (map[string]interface{}, error) {
|
||||
cond := builder.NewCond().
|
||||
And(builder.Eq{"p.request_id": requestID})
|
||||
return b.publishImpl.GetOneWithPlat(ctx, &cond)
|
||||
}
|
||||
|
||||
func (b *PublishBiz) UpdatePublishStatus(ctx context.Context, requestID string, status int, msg string) error {
|
||||
return b.publishImpl.UpdateStatus(ctx, requestID, status, msg)
|
||||
}
|
||||
|
||||
func (b *PublishBiz) GetPublishList(ctx context.Context, tokenID int32, page, pageSize int, filters map[string]interface{}) ([]map[string]interface{}, int64, error) {
|
||||
return b.publishImpl.GetListWithUser(ctx, tokenID, page, pageSize, filters)
|
||||
}
|
||||
|
||||
func (b *PublishBiz) GetPlatInfo(ctx context.Context, platIndex string) (*model.Plat, error) {
|
||||
cond := builder.NewCond().
|
||||
And(builder.Eq{"`index`": platIndex}).
|
||||
And(builder.Eq{"status": 1})
|
||||
plat := &model.Plat{}
|
||||
err := b.platImpl.GetOneBySearchStruct(ctx, &cond, plat)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return plat, nil
|
||||
}
|
||||
|
||||
func (b *PublishBiz) UpdateLoginStatus(ctx context.Context, userIndex, platIndex string, loginStatus int32) error {
|
||||
cond := builder.NewCond().
|
||||
And(builder.Eq{"user_index": userIndex}).
|
||||
And(builder.Eq{"plat_index": platIndex})
|
||||
return b.loginRelationImpl.UpdateByCond(ctx, &cond, &model.User{Status: loginStatus})
|
||||
}
|
||||
|
|
@ -0,0 +1,72 @@
|
|||
package config
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
)
|
||||
|
||||
// Config 应用配置
|
||||
type Config struct {
|
||||
Server ServerConfig `mapstructure:"server"`
|
||||
DB DB `mapstructure:"db"`
|
||||
Sys Sys `mapstructure:"sys"`
|
||||
}
|
||||
|
||||
// ServerConfig 服务器配置
|
||||
type ServerConfig struct {
|
||||
Port int `mapstructure:"port"`
|
||||
Host string `mapstructure:"host"`
|
||||
}
|
||||
|
||||
type DB struct {
|
||||
Driver string `mapstructure:"driver"`
|
||||
Source string `mapstructure:"source"`
|
||||
MaxIdle int32 `mapstructure:"maxIdle"`
|
||||
MaxOpen int32 `mapstructure:"maxOpen"`
|
||||
MaxLifetime int32 `mapstructure:"maxLifetime"`
|
||||
IsDebug bool `mapstructure:"isDebug"`
|
||||
}
|
||||
|
||||
type Sys struct {
|
||||
MaxConcurrent int `mapstructure:"maxConcurrent"`
|
||||
TaskTimeout int `mapstructure:"taskTimeout"`
|
||||
SessionTimeout int `mapstructure:"sessionTimeout "`
|
||||
MaxImageSize int `mapstructure:"maxImageSize"`
|
||||
LogsDir string `mapstructure:"logsDir"`
|
||||
UploadDir string `mapstructure:"uploadDir"`
|
||||
VideosDir string `mapstructure:"videosDir"`
|
||||
DocsDir string `mapstructure:"docsDir"`
|
||||
CookiesDir string `mapstructure:"cookiesDir"`
|
||||
QrcodesDir string `mapstructure:"qrcodesDir"`
|
||||
ChromePath string `mapstructure:"chromePath"`
|
||||
ChromeDataDir string `mapstructure:"chromeDataDir"`
|
||||
}
|
||||
|
||||
// LoadConfig 加载配置
|
||||
func LoadConfig() (*Config, error) {
|
||||
BaseDir, _ := os.Getwd()
|
||||
return &Config{
|
||||
Server: ServerConfig{
|
||||
Port: 5001,
|
||||
Host: "0.0.0.0",
|
||||
},
|
||||
DB: DB{
|
||||
Driver: "mysql",
|
||||
Source: "root:lansexiongdi6,@tcp(47.97.27.195:3306)/geo?charset=utf8mb4&parseTime=true&loc=Asia%2FShanghai",
|
||||
},
|
||||
Sys: Sys{
|
||||
MaxConcurrent: 1,
|
||||
TaskTimeout: 60,
|
||||
SessionTimeout: 300,
|
||||
MaxImageSize: 5 * 1024 * 1024,
|
||||
LogsDir: filepath.Join(BaseDir, "logs"),
|
||||
UploadDir: filepath.Join(BaseDir, "images"),
|
||||
VideosDir: filepath.Join(BaseDir, "videos"),
|
||||
DocsDir: filepath.Join(BaseDir, "docs"),
|
||||
CookiesDir: filepath.Join(BaseDir, "cookies"),
|
||||
QrcodesDir: filepath.Join(BaseDir, "qrcodes"),
|
||||
ChromePath: "./chrome/chrome.exe",
|
||||
ChromeDataDir: "./chrome_data",
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
|
@ -0,0 +1,27 @@
|
|||
package impl
|
||||
|
||||
import (
|
||||
"geo/internal/data/model"
|
||||
"geo/tmpl/dataTemp"
|
||||
"geo/utils"
|
||||
)
|
||||
|
||||
type LoginRelationImpl struct {
|
||||
dataTemp.DataTemp
|
||||
db *utils.Db
|
||||
}
|
||||
|
||||
func NewLoginRelationImpl(db *utils.Db) *LoginRelationImpl {
|
||||
return &LoginRelationImpl{
|
||||
DataTemp: *dataTemp.NewDataTemp(db, new(model.LoginRelation)),
|
||||
db: db,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *LoginRelationImpl) PrimaryKey() string {
|
||||
return "id"
|
||||
}
|
||||
|
||||
func (m *LoginRelationImpl) GetTemp() *dataTemp.DataTemp {
|
||||
return &m.DataTemp
|
||||
}
|
||||
|
|
@ -0,0 +1,88 @@
|
|||
package impl
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"geo/internal/data/model"
|
||||
"geo/tmpl/dataTemp"
|
||||
"geo/utils"
|
||||
|
||||
"xorm.io/builder"
|
||||
)
|
||||
|
||||
type PlatImpl struct {
|
||||
dataTemp.DataTemp
|
||||
db *utils.Db
|
||||
}
|
||||
|
||||
func NewPlatImpl(db *utils.Db) *PlatImpl {
|
||||
return &PlatImpl{
|
||||
DataTemp: *dataTemp.NewDataTemp(db, new(model.Plat)),
|
||||
db: db,
|
||||
}
|
||||
}
|
||||
|
||||
func (p *PlatImpl) PrimaryKey() string {
|
||||
return "id"
|
||||
}
|
||||
|
||||
func (p *PlatImpl) GetTemp() *dataTemp.DataTemp {
|
||||
return &p.DataTemp
|
||||
}
|
||||
|
||||
// GetPlatListWithLoginStatus 获取平台列表并关联登录状态
|
||||
func (p *PlatImpl) GetPlatListWithLoginStatus(ctx context.Context, cond *builder.Cond) ([]map[string]interface{}, error) {
|
||||
query, err := builder.ToBoundSQL(*cond)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
sql := `
|
||||
SELECT
|
||||
plat.name,
|
||||
plat.img_url,
|
||||
plat.index,
|
||||
plat.desc,
|
||||
COALESCE(login_relation.login_status, 2) as login_status
|
||||
FROM plat
|
||||
LEFT JOIN login_relation ON login_relation.plat_index COLLATE utf8mb4_unicode_ci = plat.index AND login_relation.status = 1
|
||||
WHERE %s
|
||||
ORDER BY plat.id ASC
|
||||
`
|
||||
|
||||
finalSQL := fmt.Sprintf(sql, query)
|
||||
|
||||
var results []map[string]interface{}
|
||||
err = p.db.Client.WithContext(ctx).Raw(finalSQL).Scan(&results).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// GetPlatListWithLoginStatusByUserIndex 根据用户索引获取平台列表及登录状态
|
||||
func (p *PlatImpl) GetPlatListWithLoginStatusByUserIndex(ctx context.Context, userIndex string) ([]map[string]interface{}, error) {
|
||||
sql := `
|
||||
SELECT
|
||||
pl.name,
|
||||
pl.img_url,
|
||||
pl.index,
|
||||
pl.desc,
|
||||
pl.login_url,
|
||||
pl.edit_url,
|
||||
pl.logined_url,
|
||||
COALESCE(lr.login_status, 2) as login_status
|
||||
FROM plat pl
|
||||
LEFT JOIN login_relation lr ON lr.plat_index = pl.index AND lr.user_index = ? AND lr.status = 1
|
||||
WHERE pl.status = 1
|
||||
ORDER BY pl.id ASC
|
||||
`
|
||||
|
||||
var results []map[string]interface{}
|
||||
|
||||
err := p.db.Client.WithContext(ctx).Raw(sql, userIndex).Scan(&results).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return results, nil
|
||||
}
|
||||
|
|
@ -0,0 +1,13 @@
|
|||
package impl
|
||||
|
||||
import (
|
||||
"github.com/google/wire"
|
||||
)
|
||||
|
||||
var ProviderImpl = wire.NewSet(
|
||||
NewPlatImpl,
|
||||
NewLoginRelationImpl,
|
||||
NewUserImpl,
|
||||
NewTokenImpl,
|
||||
NewPublishImpl,
|
||||
)
|
||||
|
|
@ -0,0 +1,244 @@
|
|||
package impl
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"geo/internal/data/model"
|
||||
"geo/tmpl/dataTemp"
|
||||
"geo/utils"
|
||||
"time"
|
||||
|
||||
"xorm.io/builder"
|
||||
)
|
||||
|
||||
type PublishImpl struct {
|
||||
dataTemp.DataTemp
|
||||
db *utils.Db
|
||||
}
|
||||
|
||||
func NewPublishImpl(db *utils.Db) *PublishImpl {
|
||||
return &PublishImpl{
|
||||
DataTemp: *dataTemp.NewDataTemp(db, new(model.Publish)),
|
||||
db: db,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *PublishImpl) PrimaryKey() string {
|
||||
return "id"
|
||||
}
|
||||
|
||||
func (m *PublishImpl) GetTemp() *dataTemp.DataTemp {
|
||||
return &m.DataTemp
|
||||
}
|
||||
|
||||
// BatchInsert 批量插入发布记录
|
||||
func (p *PublishImpl) BatchInsert(ctx context.Context, tokenID int, records []map[string]interface{}) (int64, error) {
|
||||
if len(records) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
sql := `INSERT INTO publish
|
||||
(token_id, user_index, request_id, title, tag, type, plat_index, url, publish_time, status, create_time, img)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`
|
||||
|
||||
var total int64
|
||||
for _, record := range records {
|
||||
result := p.db.Client.WithContext(ctx).Exec(sql,
|
||||
tokenID,
|
||||
record["user_index"],
|
||||
record["request_id"],
|
||||
record["title"],
|
||||
record["tag"],
|
||||
record["type"],
|
||||
record["plat_index"],
|
||||
record["url"],
|
||||
record["publish_time"],
|
||||
1, // status = 1 待发布
|
||||
time.Now(),
|
||||
record["img"],
|
||||
)
|
||||
if result.Error != nil {
|
||||
return total, result.Error
|
||||
}
|
||||
total += result.RowsAffected
|
||||
}
|
||||
return total, nil
|
||||
}
|
||||
|
||||
// GetOneWithPlat 查询单条发布记录并关联平台信息
|
||||
func (p *PublishImpl) GetOneWithPlat(ctx context.Context, cond *builder.Cond) (map[string]interface{}, error) {
|
||||
query, err := builder.ToBoundSQL(*cond)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// 构建带平台信息的查询
|
||||
sql := fmt.Sprintf(`
|
||||
SELECT
|
||||
p.id,
|
||||
p.token_id,
|
||||
p.user_index,
|
||||
p.request_id,
|
||||
p.title,
|
||||
p.tag,
|
||||
p.type,
|
||||
p.plat_index,
|
||||
p.url,
|
||||
p.publish_time,
|
||||
p.status,
|
||||
p.create_time,
|
||||
p.img,
|
||||
p.msg,
|
||||
pl.index as plat_index_db,
|
||||
pl.name as plat_name,
|
||||
pl.login_url,
|
||||
pl.edit_url,
|
||||
pl.logined_url,
|
||||
pl.desc as plat_desc,
|
||||
pl.img_url as plat_img_url
|
||||
FROM publish p
|
||||
INNER JOIN plat pl ON p.plat_index = pl.index AND pl.status = 1
|
||||
WHERE %s
|
||||
LIMIT 1
|
||||
`, query)
|
||||
|
||||
var result map[string]interface{}
|
||||
err = p.db.Client.WithContext(ctx).Raw(sql).Scan(&result).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if result == nil || len(result) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// GetOneWithPlatByRequestID 根据request_id查询发布记录并关联平台信息
|
||||
func (p *PublishImpl) GetOneWithPlatByRequestID(ctx context.Context, requestID string) (map[string]interface{}, error) {
|
||||
sql := `
|
||||
SELECT
|
||||
p.id,
|
||||
p.token_id,
|
||||
p.user_index,
|
||||
p.request_id,
|
||||
p.title,
|
||||
p.tag,
|
||||
p.type,
|
||||
p.plat_index,
|
||||
p.url,
|
||||
p.publish_time,
|
||||
p.status,
|
||||
p.create_time,
|
||||
p.img,
|
||||
p.msg,
|
||||
pl.index as plat_index_db,
|
||||
pl.name as plat_name,
|
||||
pl.login_url,
|
||||
pl.edit_url,
|
||||
pl.logined_url,
|
||||
pl.desc as plat_desc,
|
||||
pl.img_url as plat_img_url
|
||||
FROM publish p
|
||||
INNER JOIN plat pl ON p.plat_index = pl.index AND pl.status = 1
|
||||
WHERE p.request_id = ?
|
||||
LIMIT 1
|
||||
`
|
||||
|
||||
var result map[string]interface{}
|
||||
err := p.db.Client.WithContext(ctx).Raw(sql, requestID).Scan(&result).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if result == nil || len(result) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// UpdateStatus 更新发布状态
|
||||
func (p *PublishImpl) UpdateStatus(ctx context.Context, requestID string, status int, msg string) error {
|
||||
updateData := map[string]interface{}{
|
||||
"status": status,
|
||||
}
|
||||
if msg != "" {
|
||||
updateData["msg"] = msg
|
||||
}
|
||||
return p.db.Client.WithContext(ctx).
|
||||
Model(&model.Publish{}).
|
||||
Where("request_id = ?", requestID).
|
||||
Updates(updateData).Error
|
||||
}
|
||||
|
||||
// GetListWithUser 获取发布列表并关联用户信息
|
||||
func (p *PublishImpl) GetListWithUser(ctx context.Context, tokenID int32, page, pageSize int, filters map[string]interface{}) ([]map[string]interface{}, int64, error) {
|
||||
// 构建基础查询
|
||||
query := p.db.Client.WithContext(ctx).
|
||||
Table("publish p").
|
||||
Select(`
|
||||
p.id,
|
||||
p.user_index,
|
||||
u.name as user_name,
|
||||
p.request_id,
|
||||
p.title,
|
||||
p.tag,
|
||||
p.type,
|
||||
CASE
|
||||
WHEN p.type = 1 THEN '文章'
|
||||
WHEN p.type = 2 THEN '视频'
|
||||
ELSE '未知'
|
||||
END as type_name,
|
||||
p.plat_index,
|
||||
pl.name as plat_name,
|
||||
p.url,
|
||||
p.publish_time,
|
||||
p.status,
|
||||
CASE
|
||||
WHEN p.status = 1 THEN '待发布'
|
||||
WHEN p.status = 2 THEN '发布中'
|
||||
WHEN p.status = 3 THEN '发布失败'
|
||||
WHEN p.status = 4 THEN '发布成功'
|
||||
ELSE '未知'
|
||||
END as status_name,
|
||||
p.create_time,
|
||||
p.msg
|
||||
`).
|
||||
Joins("LEFT JOIN user u ON p.user_index = u.user_index").
|
||||
Joins("LEFT JOIN plat pl ON p.plat_index = pl.index").
|
||||
Where("u.token_id = ?", tokenID)
|
||||
|
||||
// 添加过滤条件
|
||||
if userIndex, ok := filters["user_index"]; ok && userIndex != "" {
|
||||
query = query.Where("p.user_index = ?", userIndex)
|
||||
}
|
||||
if tag, ok := filters["tag"]; ok && tag != "" {
|
||||
query = query.Where("p.tag LIKE ?", "%"+tag.(string)+"%")
|
||||
}
|
||||
if typeFilter, ok := filters["type"]; ok && typeFilter != 0 {
|
||||
query = query.Where("p.type = ?", typeFilter)
|
||||
}
|
||||
if platIndex, ok := filters["plat_index"]; ok && platIndex != "" {
|
||||
query = query.Where("p.plat_index = ?", platIndex)
|
||||
}
|
||||
if status, ok := filters["status"]; ok && status != 0 {
|
||||
query = query.Where("p.status = ?", status)
|
||||
}
|
||||
if requestID, ok := filters["request_id"]; ok && requestID != "" {
|
||||
query = query.Where("p.request_id LIKE ?", "%"+requestID.(string)+"%")
|
||||
}
|
||||
|
||||
// 获取总数
|
||||
var total int64
|
||||
if err := query.Count(&total).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// 分页查询
|
||||
var results []map[string]interface{}
|
||||
offset := (page - 1) * pageSize
|
||||
err := query.
|
||||
Order("p.publish_time DESC, p.create_time DESC").
|
||||
Limit(pageSize).
|
||||
Offset(offset).
|
||||
Scan(&results).Error
|
||||
|
||||
return results, total, err
|
||||
}
|
||||
|
|
@ -0,0 +1,31 @@
|
|||
package impl
|
||||
|
||||
import (
|
||||
"geo/internal/data/model"
|
||||
"geo/tmpl/dataTemp"
|
||||
"geo/utils"
|
||||
)
|
||||
|
||||
type TokenImpl struct {
|
||||
dataTemp.DataTemp
|
||||
db *utils.Db
|
||||
}
|
||||
|
||||
func NewTokenImpl(db *utils.Db) *TokenImpl {
|
||||
return &TokenImpl{
|
||||
DataTemp: *dataTemp.NewDataTemp(db, new(model.Token)),
|
||||
db: db,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *TokenImpl) PrimaryKey() string {
|
||||
return "id"
|
||||
}
|
||||
|
||||
func (m *TokenImpl) GetTemp() *dataTemp.DataTemp {
|
||||
return &m.DataTemp
|
||||
}
|
||||
|
||||
func (m *TokenImpl) GetDb() *utils.Db {
|
||||
return m.db
|
||||
}
|
||||
|
|
@ -0,0 +1,27 @@
|
|||
package impl
|
||||
|
||||
import (
|
||||
"geo/internal/data/model"
|
||||
"geo/tmpl/dataTemp"
|
||||
"geo/utils"
|
||||
)
|
||||
|
||||
type UserImpl struct {
|
||||
dataTemp.DataTemp
|
||||
db *utils.Db
|
||||
}
|
||||
|
||||
func NewUserImpl(db *utils.Db) *UserImpl {
|
||||
return &UserImpl{
|
||||
DataTemp: *dataTemp.NewDataTemp(db, new(model.User)),
|
||||
db: db,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *UserImpl) PrimaryKey() string {
|
||||
return "id"
|
||||
}
|
||||
|
||||
func (m *UserImpl) GetTemp() *dataTemp.DataTemp {
|
||||
return &m.DataTemp
|
||||
}
|
||||
|
|
@ -0,0 +1,21 @@
|
|||
// Code generated by gorm.io/gen. DO NOT EDIT.
|
||||
// Code generated by gorm.io/gen. DO NOT EDIT.
|
||||
// Code generated by gorm.io/gen. DO NOT EDIT.
|
||||
|
||||
package model
|
||||
|
||||
const TableNameLoginRelation = "login_relation"
|
||||
|
||||
// LoginRelation mapped from table <login_relation>
|
||||
type LoginRelation struct {
|
||||
ID int32 `gorm:"column:id;primaryKey;autoIncrement:true" json:"id"`
|
||||
UserIndex string `gorm:"column:user_index;not null;default:0" json:"user_index"`
|
||||
PlatIndex string `gorm:"column:plat_index;not null;default:0" json:"plat_index"`
|
||||
LoginStatus int32 `gorm:"column:login_status;not null;default:2" json:"login_status"`
|
||||
Status int32 `gorm:"column:status;not null;default:1" json:"status"`
|
||||
}
|
||||
|
||||
// TableName LoginRelation's table name
|
||||
func (*LoginRelation) TableName() string {
|
||||
return TableNameLoginRelation
|
||||
}
|
||||
|
|
@ -0,0 +1,25 @@
|
|||
// Code generated by gorm.io/gen. DO NOT EDIT.
|
||||
// Code generated by gorm.io/gen. DO NOT EDIT.
|
||||
// Code generated by gorm.io/gen. DO NOT EDIT.
|
||||
|
||||
package model
|
||||
|
||||
const TableNamePlat = "plat"
|
||||
|
||||
// Plat mapped from table <plat>
|
||||
type Plat struct {
|
||||
ID int32 `gorm:"column:id;primaryKey" json:"id"`
|
||||
Name string `gorm:"column:name;not null" json:"name"`
|
||||
Index string `gorm:"column:index;not null" json:"index"`
|
||||
ImgURL string `gorm:"column:img_url;not null" json:"img_url"`
|
||||
LoginURL string `gorm:"column:login_url;not null" json:"login_url"`
|
||||
EditURL string `gorm:"column:edit_url;not null" json:"edit_url"`
|
||||
LoginedURL string `gorm:"column:logined_url;not null" json:"logined_url"`
|
||||
Desc string `gorm:"column:desc;not null" json:"desc"`
|
||||
Status bool `gorm:"column:status;not null;default:1" json:"status"`
|
||||
}
|
||||
|
||||
// TableName Plat's table name
|
||||
func (*Plat) TableName() string {
|
||||
return TableNamePlat
|
||||
}
|
||||
|
|
@ -0,0 +1,34 @@
|
|||
// Code generated by gorm.io/gen. DO NOT EDIT.
|
||||
// Code generated by gorm.io/gen. DO NOT EDIT.
|
||||
// Code generated by gorm.io/gen. DO NOT EDIT.
|
||||
|
||||
package model
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
const TableNamePublish = "publish"
|
||||
|
||||
// Publish mapped from table <publish>
|
||||
type Publish struct {
|
||||
ID int64 `gorm:"column:id;primaryKey;autoIncrement:true" json:"id"`
|
||||
TokenID int32 `gorm:"column:token_id;not null" json:"token_id"`
|
||||
UserIndex string `gorm:"column:user_index;not null;comment:关联user.user_index" json:"user_index"` // 关联user.user_index
|
||||
RequestID string `gorm:"column:request_id;not null;comment:日志id" json:"request_id"` // 日志id
|
||||
Title string `gorm:"column:title;not null" json:"title"`
|
||||
Tag string `gorm:"column:tag;not null;comment:标签,多个英文都好隔开’,‘" json:"tag"` // 标签,多个英文都好隔开’,‘
|
||||
Type int32 `gorm:"column:type;not null;default:1;comment:1:文章,2:视频" json:"type"` // 1:文章,2:视频
|
||||
PlatIndex string `gorm:"column:plat_index;not null;comment:关联plat.index" json:"plat_index"` // 关联plat.index
|
||||
URL string `gorm:"column:url;not null;comment:资源文件下载地址" json:"url"` // 资源文件下载地址
|
||||
PublishTime time.Time `gorm:"column:publish_time;not null;comment:发布时间" json:"publish_time"` // 发布时间
|
||||
Img string `gorm:"column:img;not null" json:"img"`
|
||||
Status int32 `gorm:"column:status;not null;default:1;comment:1:待发布,2:发布中,3:发布失败,4:发布成功" json:"status"` // 1:待发布,2:发布中,3:发布失败,4:发布成功
|
||||
CreateTime time.Time `gorm:"column:create_time;not null;default:CURRENT_TIMESTAMP;comment:创建时间" json:"create_time"` // 创建时间
|
||||
Msg string `gorm:"column:msg" json:"msg"`
|
||||
}
|
||||
|
||||
// TableName Publish's table name
|
||||
func (*Publish) TableName() string {
|
||||
return TableNamePublish
|
||||
}
|
||||
|
|
@ -0,0 +1,27 @@
|
|||
// Code generated by gorm.io/gen. DO NOT EDIT.
|
||||
// Code generated by gorm.io/gen. DO NOT EDIT.
|
||||
// Code generated by gorm.io/gen. DO NOT EDIT.
|
||||
|
||||
package model
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
const TableNameToken = "token"
|
||||
|
||||
// Token mapped from table <token>
|
||||
type Token struct {
|
||||
ID int32 `gorm:"column:id;primaryKey" json:"id"`
|
||||
Name string `gorm:"column:name" json:"name"`
|
||||
Secret string `gorm:"column:secret;not null" json:"secret"`
|
||||
AccessToken string `gorm:"column:access_token;not null" json:"access_token"`
|
||||
UserLimit int32 `gorm:"column:user_limit;primaryKey" json:"user_limit"`
|
||||
ExpireTime time.Time `gorm:"column:expire_time;not null" json:"expire_time"`
|
||||
Status int32 `gorm:"column:status;not null;default:1" json:"status"`
|
||||
}
|
||||
|
||||
// TableName Token's table name
|
||||
func (*Token) TableName() string {
|
||||
return TableNameToken
|
||||
}
|
||||
|
|
@ -0,0 +1,21 @@
|
|||
// Code generated by gorm.io/gen. DO NOT EDIT.
|
||||
// Code generated by gorm.io/gen. DO NOT EDIT.
|
||||
// Code generated by gorm.io/gen. DO NOT EDIT.
|
||||
|
||||
package model
|
||||
|
||||
const TableNameUser = "user"
|
||||
|
||||
// User mapped from table <user>
|
||||
type User struct {
|
||||
ID int64 `gorm:"column:id;primaryKey;autoIncrement:true" json:"id"`
|
||||
TokenID int32 `gorm:"column:token_id;not null" json:"token_id"`
|
||||
Name string `gorm:"column:name;not null" json:"name"`
|
||||
Status int32 `gorm:"column:status;not null;default:1" json:"status"`
|
||||
UserIndex string `gorm:"column:user_index;not null" json:"user_index"`
|
||||
}
|
||||
|
||||
// TableName User's table name
|
||||
func (*User) TableName() string {
|
||||
return TableNameUser
|
||||
}
|
||||
|
|
@ -0,0 +1,88 @@
|
|||
package entitys
|
||||
|
||||
type (
|
||||
LoginAppRequest struct {
|
||||
Secret string `json:"secret" validate:"required" zh:"密钥"`
|
||||
}
|
||||
|
||||
GetUserAndAutoStatusRequest struct {
|
||||
AccessToken string `json:"access_token" validate:"required" zh:"access_token"`
|
||||
}
|
||||
|
||||
AddUserRequest struct {
|
||||
AccessToken string `json:"access_token" validate:"required" zh:"access_token"`
|
||||
Name string `json:"name" validate:"required" zh:"用户名"`
|
||||
}
|
||||
|
||||
DelUserRequest struct {
|
||||
ID int `json:"id" validate:"required" zh:"用户ID"`
|
||||
}
|
||||
|
||||
GetAppRequest struct {
|
||||
AccessToken string `json:"access_token" validate:"required" zh:"access_token"`
|
||||
UserIndex string `json:"user_index" validate:"required" zh:"用户索引"`
|
||||
}
|
||||
|
||||
PublishRecordsRequest struct {
|
||||
AccessToken string `json:"access_token" validate:"required" zh:"access_token"`
|
||||
Records []PublishRecordItem `json:"records" validate:"required" zh:"发布记录"`
|
||||
}
|
||||
|
||||
PublishRecordItem struct {
|
||||
UserIndex string `json:"user_index"`
|
||||
PlatIndex string `json:"plat_index" validate:"required" zh:"平台索引"`
|
||||
Title string `json:"title"`
|
||||
Tag string `json:"tag"`
|
||||
Type int32 `json:"type" validate:"required" zh:"类型"`
|
||||
URL string `json:"url" validate:"required" zh:"链接"`
|
||||
PublishTime string `json:"publish_time" validate:"required" zh:"发布时间"`
|
||||
Img string `json:"img"`
|
||||
RequestID string `json:"request_id"`
|
||||
}
|
||||
|
||||
PublishOnRequest struct {
|
||||
AccessToken string `json:"access_token" validate:"required" zh:"access_token"`
|
||||
}
|
||||
|
||||
PublishOffRequest struct {
|
||||
AccessToken string `json:"access_token" validate:"required" zh:"access_token"`
|
||||
}
|
||||
|
||||
PublishStatusRequest struct {
|
||||
AccessToken string `json:"access_token" validate:"required" zh:"access_token"`
|
||||
RequestID string `json:"request_id"`
|
||||
}
|
||||
|
||||
PublishExecuteOnceRequest struct {
|
||||
AccessToken string `json:"access_token" validate:"required" zh:"access_token"`
|
||||
}
|
||||
|
||||
PublishExecuteRetryRequest struct {
|
||||
AccessToken string `json:"access_token" validate:"required" zh:"access_token"`
|
||||
RequestID string `json:"request_id" validate:"required" zh:"请求ID"`
|
||||
}
|
||||
|
||||
GetPublishListRequest struct {
|
||||
AccessToken string `json:"access_token" validate:"required" zh:"access_token"`
|
||||
Page int `json:"page"`
|
||||
PageSize int `json:"page_size"`
|
||||
UserIndex string `json:"user_index"`
|
||||
Tag string `json:"tag"`
|
||||
Type int `json:"type"`
|
||||
PlatIndex string `json:"plat_index"`
|
||||
Status int `json:"status"`
|
||||
RequestID string `json:"request_id"`
|
||||
}
|
||||
|
||||
LoginPlatformRequest struct {
|
||||
AccessToken string `json:"access_token" validate:"required" zh:"access_token"`
|
||||
UserIndex string `json:"user_index" validate:"required" zh:"用户索引"`
|
||||
PlatIndex string `json:"plat_index" validate:"required" zh:"平台索引"`
|
||||
}
|
||||
|
||||
LogoutPlatformRequest struct {
|
||||
AccessToken string `json:"access_token" validate:"required" zh:"access_token"`
|
||||
UserIndex string `json:"user_index" validate:"required" zh:"用户索引"`
|
||||
PlatIndex string `json:"plat_index" validate:"required" zh:"平台索引"`
|
||||
}
|
||||
)
|
||||
|
|
@ -0,0 +1,297 @@
|
|||
package manager
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"geo/internal/config"
|
||||
"geo/internal/publisher"
|
||||
"geo/pkg"
|
||||
"log"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"geo/utils"
|
||||
)
|
||||
|
||||
type PublishManager struct {
|
||||
AutoStatus bool
|
||||
Conf *config.Config
|
||||
TokenID int
|
||||
running bool
|
||||
mu sync.Mutex
|
||||
stopCh chan struct{}
|
||||
currentPublisher interface{}
|
||||
db *utils.Db
|
||||
}
|
||||
|
||||
var publishManager *PublishManager
|
||||
var once sync.Once
|
||||
|
||||
func GetPublishManager(config *config.Config, db *utils.Db) *PublishManager {
|
||||
once.Do(func() {
|
||||
publishManager = &PublishManager{
|
||||
AutoStatus: false,
|
||||
Conf: config,
|
||||
stopCh: make(chan struct{}),
|
||||
db: db,
|
||||
}
|
||||
})
|
||||
return publishManager
|
||||
}
|
||||
|
||||
func (pm *PublishManager) Start(tokenID int) bool {
|
||||
pm.mu.Lock()
|
||||
defer pm.mu.Unlock()
|
||||
|
||||
if pm.AutoStatus {
|
||||
return false
|
||||
}
|
||||
|
||||
pm.TokenID = tokenID
|
||||
pm.AutoStatus = true
|
||||
pm.stopCh = make(chan struct{})
|
||||
|
||||
go pm.autoPublishLoop()
|
||||
return true
|
||||
}
|
||||
|
||||
func (pm *PublishManager) Stop() bool {
|
||||
pm.mu.Lock()
|
||||
defer pm.mu.Unlock()
|
||||
|
||||
if !pm.AutoStatus {
|
||||
return false
|
||||
}
|
||||
|
||||
pm.AutoStatus = false
|
||||
close(pm.stopCh)
|
||||
return true
|
||||
}
|
||||
|
||||
func (pm *PublishManager) autoPublishLoop() {
|
||||
log.Println("自动发布服务已启动")
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-pm.stopCh:
|
||||
log.Println("自动发布服务已停止")
|
||||
return
|
||||
default:
|
||||
pm.batchPublish()
|
||||
time.Sleep(30 * time.Second)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (pm *PublishManager) batchPublish() {
|
||||
if !pm.AutoStatus {
|
||||
return
|
||||
}
|
||||
|
||||
publishData := pm.getPendingPublish()
|
||||
if publishData == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// 使用 defer recover 防止 panic 导致整个循环崩溃
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
log.Printf("批处理发布发生 panic: %v", r)
|
||||
}
|
||||
}()
|
||||
|
||||
pm.processSingleTask(publishData)
|
||||
}
|
||||
|
||||
func (pm *PublishManager) getPendingPublish() map[string]interface{} {
|
||||
currentTime := time.Now().Format("2006-01-02 15:04:05")
|
||||
|
||||
sql := `
|
||||
SELECT p.*, pl.*
|
||||
FROM publish p
|
||||
INNER JOIN plat pl ON p.plat_index = pl.index AND pl.status = 1
|
||||
WHERE p.token_id = ? AND p.status = 1 AND p.publish_time <= ?
|
||||
ORDER BY p.publish_time DESC
|
||||
LIMIT 1
|
||||
`
|
||||
|
||||
result, err := pm.db.GetOne(sql, pm.TokenID, currentTime)
|
||||
if err != nil {
|
||||
log.Printf("查询待发布任务失败: token_id=%d, error=%v", pm.TokenID, err)
|
||||
return nil
|
||||
}
|
||||
if result == nil {
|
||||
log.Printf("没有待发布任务: token_id=%d, current_time=%s", pm.TokenID, currentTime)
|
||||
return nil
|
||||
}
|
||||
|
||||
requestID := getString(result, "request_id")
|
||||
log.Printf("获取到待发布任务: token_id=%d, request_id=%s", pm.TokenID, requestID)
|
||||
return result
|
||||
}
|
||||
|
||||
func (pm *PublishManager) GetTaskByRequestID(requestID string) (map[string]interface{}, error) {
|
||||
sql := `
|
||||
SELECT p.*, pl.*
|
||||
FROM publish p
|
||||
INNER JOIN plat pl ON p.plat_index COLLATE utf8mb4_unicode_ci = pl.index AND pl.status = 1
|
||||
WHERE p.request_id = ?
|
||||
`
|
||||
return pm.db.GetOne(sql, requestID)
|
||||
}
|
||||
|
||||
func (pm *PublishManager) processSingleTask(publishData map[string]interface{}) map[string]interface{} {
|
||||
requestID := getString(publishData, "request_id")
|
||||
platIndex := getString(publishData, "plat_index")
|
||||
title := getString(publishData, "title")
|
||||
tagRaw := getString(publishData, "tag")
|
||||
userIndex := getString(publishData, "user_index")
|
||||
url := getString(publishData, "url")
|
||||
imgURL := getString(publishData, "img")
|
||||
|
||||
log.Printf("[任务 %s] 开始处理,平台:%s,标题:%s", requestID, platIndex, title)
|
||||
|
||||
// 更新状态为发布中
|
||||
pm.updatePublishStatus(requestID, 2, "")
|
||||
log.Printf("[任务 %s] 状态已更新为发布中", requestID)
|
||||
|
||||
// 下载文件
|
||||
docPath, err := pkg.DownloadFile(url, "", requestID+".docx")
|
||||
if err != nil {
|
||||
errMsg := fmt.Sprintf("下载文档失败: %v", err)
|
||||
log.Printf("[任务 %s] %s", requestID, errMsg)
|
||||
pm.updatePublishStatus(requestID, 3, errMsg)
|
||||
return map[string]interface{}{"success": false, "message": errMsg, "request_id": requestID}
|
||||
}
|
||||
log.Printf("[任务 %s] 文档下载成功: %s", requestID, docPath)
|
||||
|
||||
// 下载图片
|
||||
imgPath, err := pkg.DownloadImage(imgURL, requestID, "img")
|
||||
if err != nil {
|
||||
errMsg := fmt.Sprintf("下载图片失败: %v", err)
|
||||
log.Printf("[任务 %s] %s", requestID, errMsg)
|
||||
pm.updatePublishStatus(requestID, 3, errMsg)
|
||||
// 图片下载失败,清理已下载的文档
|
||||
pkg.DeleteFile(docPath)
|
||||
return map[string]interface{}{"success": false, "message": errMsg, "request_id": requestID}
|
||||
}
|
||||
log.Printf("[任务 %s] 图片下载成功: %s", requestID, imgPath)
|
||||
|
||||
// 确保清理临时文件
|
||||
defer func() {
|
||||
pkg.DeleteFile(docPath)
|
||||
pkg.DeleteFile(imgPath)
|
||||
}()
|
||||
|
||||
// 解析标签
|
||||
tags := pkg.ParseTags(tagRaw)
|
||||
log.Printf("[任务 %s] 标签解析完成: %v", requestID, tags)
|
||||
|
||||
// 提取内容
|
||||
content, err := pkg.ExtractWordContent(docPath, "html")
|
||||
if err != nil {
|
||||
errMsg := fmt.Sprintf("提取文档内容失败: %v", err)
|
||||
log.Printf("[任务 %s] %s", requestID, errMsg)
|
||||
pm.updatePublishStatus(requestID, 3, errMsg)
|
||||
return map[string]interface{}{"success": false, "message": errMsg, "request_id": requestID}
|
||||
}
|
||||
log.Printf("[任务 %s] 内容提取成功,长度: %d", requestID, len(content))
|
||||
|
||||
// 获取发布器
|
||||
publisherClass := getPublisherClass(platIndex)
|
||||
if publisherClass == nil {
|
||||
errMsg := fmt.Sprintf("不支持的平台: %s", platIndex)
|
||||
log.Printf("[任务 %s] %s", requestID, errMsg)
|
||||
pm.updatePublishStatus(requestID, 3, errMsg)
|
||||
return map[string]interface{}{"success": false, "message": errMsg, "request_id": requestID}
|
||||
}
|
||||
|
||||
// 创建并执行发布器
|
||||
var pub interface{ PublishNote() (bool, string) }
|
||||
|
||||
switch platIndex {
|
||||
case "xhs":
|
||||
pub = publisher.NewXiaohongshuPublisher(false, title, content, tags, userIndex, platIndex, requestID, imgPath, docPath, publishData, pm.Conf)
|
||||
log.Printf("[任务 %s] 创建小红书发布器", requestID)
|
||||
case "bjh":
|
||||
pub = publisher.NewBaijiahaoPublisher(false, title, content, tags, userIndex, platIndex, requestID, imgPath, docPath, publishData, pm.Conf)
|
||||
log.Printf("[任务 %s] 创建百家号发布器", requestID)
|
||||
default:
|
||||
log.Printf("[任务 %s] 未知平台 %s,使用默认小红书发布器", requestID, platIndex)
|
||||
pub = publisher.NewXiaohongshuPublisher(false, title, content, tags, userIndex, platIndex, requestID, imgPath, docPath, publishData, pm.Conf)
|
||||
}
|
||||
|
||||
log.Printf("[任务 %s] 开始执行发布...", requestID)
|
||||
success, message := pub.PublishNote()
|
||||
|
||||
if success {
|
||||
log.Printf("[任务 %s] 发布成功: %s", requestID, message)
|
||||
pm.updatePublishStatus(requestID, 4, message)
|
||||
} else {
|
||||
log.Printf("[任务 %s] 发布失败: %s", requestID, message)
|
||||
pm.updatePublishStatus(requestID, 3, message)
|
||||
}
|
||||
|
||||
return map[string]interface{}{
|
||||
"success": success,
|
||||
"message": message,
|
||||
"request_id": requestID,
|
||||
}
|
||||
}
|
||||
|
||||
func (pm *PublishManager) updatePublishStatus(requestID string, status int, message string) {
|
||||
if message != "" {
|
||||
pm.db.Execute("UPDATE publish SET status = ?, msg = ? WHERE request_id = ?", status, message, requestID)
|
||||
} else {
|
||||
pm.db.Execute("UPDATE publish SET status = ? WHERE request_id = ?", status, requestID)
|
||||
}
|
||||
}
|
||||
|
||||
func (pm *PublishManager) ExecuteOnce(tokenId int32) map[string]interface{} {
|
||||
publishData := pm.getPendingPublish()
|
||||
if publishData == nil {
|
||||
return map[string]interface{}{"success": false, "message": "没有待发布任务"}
|
||||
}
|
||||
return pm.processSingleTask(publishData)
|
||||
}
|
||||
|
||||
func (pm *PublishManager) RetryTask(requestID string) map[string]interface{} {
|
||||
publishData, err := pm.GetTaskByRequestID(requestID)
|
||||
if err != nil || publishData == nil {
|
||||
return map[string]interface{}{"success": false, "message": "任务不存在"}
|
||||
}
|
||||
return pm.processSingleTask(publishData)
|
||||
}
|
||||
|
||||
func (pm *PublishManager) GetStatus() map[string]interface{} {
|
||||
return map[string]interface{}{
|
||||
"auto_status": pm.AutoStatus,
|
||||
"max_concurrent": pm.Conf.Sys.MaxConcurrent,
|
||||
"task_timeout": pm.Conf.Sys.TaskTimeout,
|
||||
}
|
||||
}
|
||||
|
||||
func getPublisherClass(platIndex string) interface{} {
|
||||
platformMap := map[string]interface{}{
|
||||
"xhs": struct{}{},
|
||||
"bjh": struct{}{},
|
||||
"csdn": struct{}{},
|
||||
}
|
||||
return platformMap[platIndex]
|
||||
}
|
||||
|
||||
func getString(m map[string]interface{}, key string) string {
|
||||
if v, ok := m[key]; ok {
|
||||
switch v.(type) {
|
||||
case []uint8:
|
||||
return string(v.([]uint8))
|
||||
case string:
|
||||
return v.(string)
|
||||
case int64:
|
||||
return fmt.Sprintf("%d", v)
|
||||
default:
|
||||
return fmt.Sprintf("%v", v)
|
||||
}
|
||||
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
|
@ -0,0 +1,383 @@
|
|||
package publisher
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"geo/internal/config"
|
||||
|
||||
"github.com/go-rod/rod"
|
||||
"github.com/go-rod/rod/lib/proto"
|
||||
)
|
||||
|
||||
type BaijiahaoPublisher struct {
|
||||
*BasePublisher
|
||||
Category string
|
||||
ArticleType string
|
||||
IsTop bool
|
||||
}
|
||||
|
||||
func NewBaijiahaoPublisher(headless bool, title, content string, tags []string, tenantID, platIndex, requestID, imagePath, wordPath string, platInfo map[string]interface{}, cfg *config.Config) *BaijiahaoPublisher {
|
||||
base := NewBasePublisher(headless, title, content, tags, tenantID, platIndex, requestID, imagePath, wordPath, platInfo, cfg)
|
||||
if platInfo != nil {
|
||||
base.LoginURL = getString(platInfo, "login_url")
|
||||
base.EditorURL = getString(platInfo, "edit_url")
|
||||
base.LoginedURL = getString(platInfo, "logined_url")
|
||||
}
|
||||
return &BaijiahaoPublisher{BasePublisher: base}
|
||||
}
|
||||
|
||||
func (p *BaijiahaoPublisher) CheckLoginStatus() bool {
|
||||
url := p.GetCurrentURL()
|
||||
// 如果URL包含登录相关关键词,表示未登录
|
||||
if strings.Contains(url, "login") || strings.Contains(url, "passport") {
|
||||
return false
|
||||
}
|
||||
// 如果URL是编辑页面或主页,表示已登录
|
||||
if strings.Contains(url, "baijiahao") || strings.Contains(url, "edit") {
|
||||
return true
|
||||
}
|
||||
return url != p.LoginURL
|
||||
}
|
||||
|
||||
func (p *BaijiahaoPublisher) CheckLogin() (bool, string) {
|
||||
p.LogInfo("检查登录状态...")
|
||||
|
||||
if err := p.SetupDriver(); err != nil {
|
||||
return false, fmt.Sprintf("浏览器启动失败: %v", err)
|
||||
}
|
||||
defer p.Close()
|
||||
|
||||
p.Page.MustNavigate(p.LoginedURL)
|
||||
p.Sleep(3)
|
||||
p.WaitForPageReady(5)
|
||||
|
||||
if p.CheckLoginStatus() {
|
||||
p.SaveCookies()
|
||||
return true, "已登录"
|
||||
}
|
||||
return false, "未登录"
|
||||
}
|
||||
|
||||
func (p *BaijiahaoPublisher) WaitLogin() (bool, string) {
|
||||
p.LogInfo("开始等待登录...")
|
||||
|
||||
if err := p.SetupDriver(); err != nil {
|
||||
return false, fmt.Sprintf("浏览器启动失败: %v", err)
|
||||
}
|
||||
defer p.Close()
|
||||
|
||||
// 先尝试访问已登录页面
|
||||
p.Page.MustNavigate(p.LoginedURL)
|
||||
p.Sleep(3)
|
||||
|
||||
if p.CheckLoginStatus() {
|
||||
p.SaveCookies()
|
||||
p.LogInfo("已有登录状态")
|
||||
return true, "already_logged_in"
|
||||
}
|
||||
|
||||
// 未登录,跳转到登录页
|
||||
p.Page.MustNavigate(p.LoginURL)
|
||||
p.LogInfo("请扫描二维码登录...")
|
||||
|
||||
// 等待登录完成,最多120秒
|
||||
for i := 0; i < 120; i++ {
|
||||
time.Sleep(1 * time.Second)
|
||||
if p.CheckLoginStatus() {
|
||||
p.SaveCookies()
|
||||
p.LogInfo("登录成功")
|
||||
return true, "login_success"
|
||||
}
|
||||
}
|
||||
|
||||
return false, "登录超时"
|
||||
}
|
||||
|
||||
func (p *BaijiahaoPublisher) inputTitle() error {
|
||||
p.LogInfo("输入标题...")
|
||||
|
||||
titleSelectors := []string{
|
||||
".client_pages_edit_components_titleInput [contenteditable='true']",
|
||||
".input-box [contenteditable='true']",
|
||||
"[contenteditable='true']",
|
||||
}
|
||||
|
||||
var titleInput *rod.Element
|
||||
var err error
|
||||
|
||||
for _, selector := range titleSelectors {
|
||||
titleInput, err = p.WaitForElementVisible(selector, 5)
|
||||
if err == nil && titleInput != nil {
|
||||
p.LogInfo(fmt.Sprintf("找到标题输入框: %s", selector))
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if titleInput == nil {
|
||||
return fmt.Errorf("未找到标题输入框")
|
||||
}
|
||||
|
||||
// 点击获取焦点
|
||||
if err := titleInput.Click(proto.InputMouseButtonLeft, 1); err != nil {
|
||||
return fmt.Errorf("点击标题框失败: %v", err)
|
||||
}
|
||||
p.SleepMs(500)
|
||||
|
||||
// 清空输入框
|
||||
if err := p.ClearContentEditable(titleInput); err != nil {
|
||||
p.LogInfo(fmt.Sprintf("清空标题框失败: %v", err))
|
||||
}
|
||||
p.SleepMs(300)
|
||||
|
||||
// 输入标题
|
||||
if err := p.SetContentEditable(titleInput, p.Title); err != nil {
|
||||
// 备用输入方式
|
||||
titleInput.Input(p.Title)
|
||||
}
|
||||
|
||||
p.LogInfo(fmt.Sprintf("标题已输入: %s", p.Title))
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *BaijiahaoPublisher) inputContent() error {
|
||||
p.LogInfo("输入内容...")
|
||||
|
||||
// 查找内容编辑器
|
||||
contentEditor, err := p.WaitForElementVisible(".ProseMirror", 10)
|
||||
if err != nil {
|
||||
contentEditor, err = p.WaitForElementVisible("[contenteditable='true']", 10)
|
||||
if err != nil {
|
||||
return fmt.Errorf("未找到内容编辑器: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// 点击获取焦点
|
||||
if err := contentEditor.Click(proto.InputMouseButtonLeft, 1); err != nil {
|
||||
return fmt.Errorf("点击编辑器失败: %v", err)
|
||||
}
|
||||
p.SleepMs(500)
|
||||
|
||||
// 清空编辑器
|
||||
if err := p.ClearContentEditable(contentEditor); err != nil {
|
||||
p.LogInfo(fmt.Sprintf("清空编辑器失败: %v", err))
|
||||
}
|
||||
p.SleepMs(300)
|
||||
|
||||
// 输入内容
|
||||
if err := p.SetContentEditable(contentEditor, p.Content); err != nil {
|
||||
// 备用输入方式
|
||||
contentEditor.Input(p.Content)
|
||||
}
|
||||
|
||||
p.LogInfo(fmt.Sprintf("内容已输入,长度: %d", len(p.Content)))
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *BaijiahaoPublisher) uploadImage() error {
|
||||
if p.ImagePath == "" {
|
||||
p.LogInfo("无封面图片,跳过")
|
||||
return nil
|
||||
}
|
||||
|
||||
p.LogInfo(fmt.Sprintf("上传封面: %s", p.ImagePath))
|
||||
|
||||
// 查找封面区域
|
||||
coverArea, err := p.WaitForElementClickable(".cheetah-spin-container", 5)
|
||||
if err != nil {
|
||||
p.LogInfo("未找到封面区域,跳过")
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := coverArea.Click(proto.InputMouseButtonLeft, 1); err != nil {
|
||||
p.LogInfo(fmt.Sprintf("点击封面区域失败: %v", err))
|
||||
}
|
||||
p.SleepMs(1000)
|
||||
|
||||
// 查找文件输入框
|
||||
fileInput, err := p.Page.Element("input[type='file'][accept*='image']")
|
||||
if err != nil {
|
||||
fileInput, err = p.Page.Element("input[type='file']")
|
||||
if err != nil {
|
||||
return fmt.Errorf("未找到文件输入框: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// 上传图片
|
||||
if err := fileInput.SetFiles([]string{p.ImagePath}); err != nil {
|
||||
return fmt.Errorf("上传图片失败: %v", err)
|
||||
}
|
||||
|
||||
p.LogInfo("图片上传成功")
|
||||
p.Sleep(3)
|
||||
|
||||
// 查找确认按钮
|
||||
confirmBtn, err := p.WaitForElementClickable(".cheetah-btn-primary", 5)
|
||||
if err == nil && confirmBtn != nil {
|
||||
if err := confirmBtn.Click(proto.InputMouseButtonLeft, 1); err != nil {
|
||||
p.LogInfo(fmt.Sprintf("点击确认按钮失败: %v", err))
|
||||
}
|
||||
p.LogInfo("已确认封面")
|
||||
p.SleepMs(1000)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *BaijiahaoPublisher) clickPublish() error {
|
||||
p.LogInfo("点击发布按钮...")
|
||||
|
||||
// 滚动到底部
|
||||
if _, err := p.Page.Eval(`() => window.scrollTo(0, document.body.scrollHeight)`); err != nil {
|
||||
p.LogInfo(fmt.Sprintf("滚动到底部失败: %v", err))
|
||||
}
|
||||
p.SleepMs(1000)
|
||||
|
||||
// 查找发布按钮
|
||||
publishSelectors := []string{
|
||||
"[data-testid='publish-btn']",
|
||||
".op-list-right .cheetah-btn-primary",
|
||||
".cheetah-btn-primary",
|
||||
"button:contains('发布')",
|
||||
}
|
||||
|
||||
var publishBtn *rod.Element
|
||||
var err error
|
||||
|
||||
for _, selector := range publishSelectors {
|
||||
publishBtn, err = p.WaitForElementClickable(selector, 5)
|
||||
if err == nil && publishBtn != nil {
|
||||
p.LogInfo(fmt.Sprintf("找到发布按钮: %s", selector))
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// 如果还是没找到,通过 XPath 查找
|
||||
if publishBtn == nil {
|
||||
publishBtn, err = p.Page.ElementX("//button[contains(text(), '发布')]")
|
||||
if err != nil {
|
||||
return fmt.Errorf("未找到发布按钮: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// 滚动到按钮位置
|
||||
if err := p.ScrollToElement(publishBtn); err != nil {
|
||||
p.LogInfo(fmt.Sprintf("滚动到按钮失败: %v", err))
|
||||
}
|
||||
p.SleepMs(500)
|
||||
|
||||
// 点击发布
|
||||
if err := publishBtn.Click(proto.InputMouseButtonLeft, 1); err != nil {
|
||||
return fmt.Errorf("点击发布按钮失败: %v", err)
|
||||
}
|
||||
|
||||
p.LogInfo("已点击发布按钮")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *BaijiahaoPublisher) waitForPublishResult() (bool, string) {
|
||||
p.LogInfo("等待发布结果...")
|
||||
|
||||
// 等待最多60秒
|
||||
for i := 0; i < 60; i++ {
|
||||
p.SleepMs(1000)
|
||||
|
||||
// 检查URL是否跳转到成功页面
|
||||
currentURL := p.GetCurrentURL()
|
||||
if strings.Contains(currentURL, "clue") ||
|
||||
strings.Contains(currentURL, "success") ||
|
||||
strings.Contains(currentURL, "article/list") {
|
||||
p.LogInfo("发布成功!")
|
||||
return true, "发布成功"
|
||||
}
|
||||
|
||||
// 检查是否有成功提示
|
||||
elements, _ := p.Page.Elements(".cheetah-message-success, .cheetah-message-info, [class*='success']")
|
||||
for _, el := range elements {
|
||||
text, _ := el.Text()
|
||||
if strings.Contains(text, "成功") || strings.Contains(text, "已发布") {
|
||||
p.LogInfo(fmt.Sprintf("发布成功: %s", text))
|
||||
return true, text
|
||||
}
|
||||
}
|
||||
|
||||
// 检查是否有失败提示
|
||||
elements, _ = p.Page.Elements(".cheetah-message-error, .cheetah-message-warning, [class*='error']")
|
||||
for _, el := range elements {
|
||||
text, _ := el.Text()
|
||||
if strings.Contains(text, "失败") || strings.Contains(text, "错误") {
|
||||
p.LogError(fmt.Sprintf("发布失败: %s", text))
|
||||
return false, text
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false, "发布结果未知(超时)"
|
||||
}
|
||||
|
||||
func (p *BaijiahaoPublisher) PublishNote() (bool, string) {
|
||||
p.LogInfo(strings.Repeat("=", 50))
|
||||
p.LogInfo("开始发布百家号文章...")
|
||||
p.LogInfo(fmt.Sprintf("标题: %s", p.Title))
|
||||
p.LogInfo(fmt.Sprintf("内容长度: %d", len(p.Content)))
|
||||
p.LogInfo(strings.Repeat("=", 50))
|
||||
|
||||
// 初始化浏览器
|
||||
if err := p.SetupDriver(); err != nil {
|
||||
return false, fmt.Sprintf("浏览器启动失败: %v", err)
|
||||
}
|
||||
defer p.Close()
|
||||
|
||||
// 访问编辑器页面
|
||||
p.Page.MustNavigate(p.EditorURL)
|
||||
p.Sleep(3)
|
||||
p.WaitForPageReady(5)
|
||||
|
||||
// 尝试加载cookies
|
||||
if err := p.LoadCookies(); err == nil {
|
||||
p.RefreshPage()
|
||||
p.Sleep(2)
|
||||
if p.CheckLoginStatus() {
|
||||
p.LogInfo("使用cookies登录成功")
|
||||
} else {
|
||||
p.LogInfo("cookies已过期,需要重新登录")
|
||||
return false, "需要登录"
|
||||
}
|
||||
}
|
||||
|
||||
// 检查登录状态
|
||||
if !p.CheckLoginStatus() {
|
||||
return false, "需要登录"
|
||||
}
|
||||
|
||||
// 保存cookies
|
||||
p.SaveCookies()
|
||||
|
||||
// 执行发布流程
|
||||
steps := []struct {
|
||||
name string
|
||||
fn func() error
|
||||
}{
|
||||
{"输入标题", p.inputTitle},
|
||||
{"输入内容", p.inputContent},
|
||||
{"上传封面", p.uploadImage},
|
||||
}
|
||||
|
||||
for _, step := range steps {
|
||||
if err := step.fn(); err != nil {
|
||||
p.LogStep(step.name, false, err.Error())
|
||||
return false, fmt.Sprintf("%s失败: %v", step.name, err)
|
||||
}
|
||||
p.LogStep(step.name, true, "")
|
||||
p.SleepMs(500)
|
||||
}
|
||||
|
||||
// 点击发布
|
||||
if err := p.clickPublish(); err != nil {
|
||||
return false, err.Error()
|
||||
}
|
||||
|
||||
// 等待发布结果
|
||||
return p.waitForPublishResult()
|
||||
}
|
||||
|
|
@ -0,0 +1,282 @@
|
|||
package publisher
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"geo/internal/config"
|
||||
|
||||
"github.com/go-rod/rod"
|
||||
"github.com/go-rod/rod/lib/launcher"
|
||||
"github.com/go-rod/rod/lib/proto"
|
||||
)
|
||||
|
||||
type BasePublisher struct {
|
||||
Headless bool
|
||||
Title string
|
||||
Content string
|
||||
Tags []string
|
||||
TenantID string
|
||||
PlatIndex string
|
||||
RequestID string
|
||||
ImagePath string
|
||||
WordPath string
|
||||
|
||||
Browser *rod.Browser
|
||||
Page *rod.Page
|
||||
Logger *log.Logger
|
||||
LogFile *os.File
|
||||
|
||||
LoginURL string
|
||||
EditorURL string
|
||||
LoginedURL string
|
||||
CookiesFile string
|
||||
|
||||
PlatInfo map[string]interface{}
|
||||
config *config.Config
|
||||
}
|
||||
|
||||
func NewBasePublisher(headless bool, title, content string, tags []string, tenantID, platIndex, requestID, imagePath, wordPath string, platInfo map[string]interface{}, config *config.Config) *BasePublisher {
|
||||
cookiesDir := filepath.Join(config.Sys.CookiesDir, tenantID)
|
||||
os.MkdirAll(cookiesDir, 0755)
|
||||
cookiesFile := filepath.Join(cookiesDir, platIndex+".json")
|
||||
|
||||
logFile, _ := os.Create(filepath.Join(config.Sys.LogsDir, requestID+".log"))
|
||||
logger := log.New(logFile, "", log.LstdFlags)
|
||||
|
||||
return &BasePublisher{
|
||||
Headless: headless,
|
||||
Title: title,
|
||||
Content: content,
|
||||
Tags: tags,
|
||||
TenantID: tenantID,
|
||||
PlatIndex: platIndex,
|
||||
RequestID: requestID,
|
||||
ImagePath: imagePath,
|
||||
WordPath: wordPath,
|
||||
Logger: logger,
|
||||
LogFile: logFile,
|
||||
CookiesFile: cookiesFile,
|
||||
PlatInfo: platInfo,
|
||||
config: config,
|
||||
}
|
||||
}
|
||||
|
||||
func (b *BasePublisher) SetupDriver() error {
|
||||
l := launcher.New()
|
||||
l.Headless(b.Headless)
|
||||
// 设置用户数据目录
|
||||
userDataDir := filepath.Join(b.config.Sys.ChromeDataDir, b.TenantID)
|
||||
os.MkdirAll(userDataDir, 0755)
|
||||
l.UserDataDir(userDataDir)
|
||||
|
||||
// 设置 Leakless 模式(解决 Windows 上的问题)
|
||||
l.Leakless(false)
|
||||
|
||||
// 设置 Chrome 启动参数
|
||||
l.Set("disable-blink-features", "AutomationControlled")
|
||||
l.Set("no-sandbox")
|
||||
l.Set("disable-dev-shm-usage")
|
||||
l.Set("disable-gpu")
|
||||
l.Set("disable-software-rasterizer")
|
||||
l.Set("disable-setuid-sandbox")
|
||||
l.Set("remote-debugging-port", "9222")
|
||||
|
||||
// 窗口大小
|
||||
l.Set("window-size", "1920,1080")
|
||||
l.Set("lang", "zh-CN")
|
||||
|
||||
url, err := l.Launch()
|
||||
if err != nil {
|
||||
return fmt.Errorf("启动浏览器失败: %v", err)
|
||||
}
|
||||
|
||||
b.Browser = rod.New().ControlURL(url).MustConnect()
|
||||
b.Page = b.Browser.MustPage()
|
||||
b.Page.MustSetViewport(1920, 1080, 1, false)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *BasePublisher) Close() {
|
||||
if b.Page != nil {
|
||||
b.Page.Close()
|
||||
}
|
||||
if b.Browser != nil {
|
||||
b.Browser.Close()
|
||||
}
|
||||
if b.LogFile != nil {
|
||||
b.LogFile.Close()
|
||||
}
|
||||
}
|
||||
|
||||
func (b *BasePublisher) SaveCookies() error {
|
||||
cookies, err := b.Page.Cookies(nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
data, err := json.Marshal(cookies)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return os.WriteFile(b.CookiesFile, data, 0644)
|
||||
}
|
||||
|
||||
func (b *BasePublisher) LoadCookies() error {
|
||||
data, err := os.ReadFile(b.CookiesFile)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var cookies []*proto.NetworkCookieParam
|
||||
if err := json.Unmarshal(data, &cookies); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return b.Page.SetCookies(cookies)
|
||||
}
|
||||
|
||||
func (b *BasePublisher) RefreshPage() error {
|
||||
_, err := b.Page.Eval(`() => location.reload()`)
|
||||
return err
|
||||
}
|
||||
|
||||
func (b *BasePublisher) WaitForPageReady(timeout int) error {
|
||||
return b.Page.Timeout(time.Duration(timeout) * time.Second).WaitLoad()
|
||||
}
|
||||
|
||||
func (b *BasePublisher) WaitForElement(selector string, timeout int) (*rod.Element, error) {
|
||||
return b.Page.Timeout(time.Duration(timeout) * time.Second).Element(selector)
|
||||
}
|
||||
|
||||
func (b *BasePublisher) WaitForElementVisible(selector string, timeout int) (*rod.Element, error) {
|
||||
el, err := b.WaitForElement(selector, timeout)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := el.WaitVisible(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return el, nil
|
||||
}
|
||||
|
||||
func (b *BasePublisher) WaitForElementClickable(selector string, timeout int) (*rod.Element, error) {
|
||||
el, err := b.WaitForElementVisible(selector, timeout)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := el.WaitVisible(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return el, nil
|
||||
}
|
||||
|
||||
func (b *BasePublisher) JSClick(element *rod.Element) error {
|
||||
// 使用 element.Evaluate 并传入 EvalOptions
|
||||
_, err := element.Evaluate(&rod.EvalOptions{
|
||||
JS: `el => el.click()`,
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
func (b *BasePublisher) ScrollToElement(element *rod.Element) error {
|
||||
_, err := element.Evaluate(&rod.EvalOptions{
|
||||
JS: `el => el.scrollIntoView({block: 'center', behavior: 'smooth'})`,
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
func (b *BasePublisher) Sleep(seconds int) {
|
||||
time.Sleep(time.Duration(seconds) * time.Second)
|
||||
}
|
||||
|
||||
func (b *BasePublisher) LogStep(stepName string, success bool, message string) {
|
||||
if success {
|
||||
b.Logger.Printf("✅ %s: 成功 %s", stepName, message)
|
||||
} else {
|
||||
b.Logger.Printf("❌ %s: 失败 %s", stepName, message)
|
||||
}
|
||||
}
|
||||
|
||||
func (b *BasePublisher) LogInfo(message string) {
|
||||
b.Logger.Printf("📌 %s", message)
|
||||
}
|
||||
|
||||
func (b *BasePublisher) LogError(message string) {
|
||||
b.Logger.Printf("❌ %s", message)
|
||||
}
|
||||
|
||||
func (b *BasePublisher) GetCurrentURL() string {
|
||||
info := b.Page.MustInfo()
|
||||
return info.URL
|
||||
}
|
||||
|
||||
func (b *BasePublisher) Screenshot(filename string) error {
|
||||
data, err := b.Page.Screenshot(false, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return os.WriteFile(filename, data, 0644)
|
||||
}
|
||||
|
||||
// 抽象方法 - 子类需要实现
|
||||
func (b *BasePublisher) WaitLogin() (bool, string) {
|
||||
return false, "需要实现"
|
||||
}
|
||||
|
||||
func (b *BasePublisher) CheckLoginStatus() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (b *BasePublisher) CheckLogin() (bool, string) {
|
||||
return false, "需要实现"
|
||||
}
|
||||
|
||||
func (b *BasePublisher) PublishNote() (bool, string) {
|
||||
return false, "需要实现"
|
||||
}
|
||||
|
||||
// ClearContentEditable 清空 contenteditable 元素的内容
|
||||
func (b *BasePublisher) ClearContentEditable(element *rod.Element) error {
|
||||
_, err := element.Evaluate(&rod.EvalOptions{
|
||||
JS: `el => { el.innerText = ''; el.innerHTML = ''; el.dispatchEvent(new Event('input', {bubbles: true})); }`,
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
// SetContentEditable 设置 contenteditable 元素的内容
|
||||
func (b *BasePublisher) SetContentEditable(element *rod.Element, content string) error {
|
||||
_, err := element.Evaluate(&rod.EvalOptions{
|
||||
JS: `(el, val) => { el.innerText = val; el.dispatchEvent(new Event('input', {bubbles: true})); }`,
|
||||
JSArgs: []interface{}{content},
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
// SetInputValue 设置输入框值并触发事件
|
||||
func (b *BasePublisher) SetInputValue(element *rod.Element, value string) error {
|
||||
_, err := element.Evaluate(&rod.EvalOptions{
|
||||
JS: `(el, val) => { el.value = val; el.dispatchEvent(new Event('input', {bubbles: true})); el.dispatchEvent(new Event('change', {bubbles: true})); }`,
|
||||
JSArgs: []interface{}{value},
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
// ClearInput 清空输入框
|
||||
func (b *BasePublisher) ClearInput(element *rod.Element) error {
|
||||
_, err := element.Evaluate(&rod.EvalOptions{
|
||||
JS: `el => { el.value = ''; el.dispatchEvent(new Event('input', {bubbles: true})); }`,
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
// SleepMs 毫秒级等待
|
||||
func (b *BasePublisher) SleepMs(milliseconds int) {
|
||||
time.Sleep(time.Duration(milliseconds) * time.Millisecond)
|
||||
}
|
||||
|
|
@ -0,0 +1,425 @@
|
|||
package publisher
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"geo/internal/config"
|
||||
|
||||
"github.com/go-rod/rod"
|
||||
"github.com/go-rod/rod/lib/proto"
|
||||
)
|
||||
|
||||
type XiaohongshuPublisher struct {
|
||||
*BasePublisher
|
||||
}
|
||||
|
||||
func NewXiaohongshuPublisher(headless bool, title, content string, tags []string, tenantID, platIndex, requestID, imagePath, wordPath string, platInfo map[string]interface{}, cfg *config.Config) *XiaohongshuPublisher {
|
||||
base := NewBasePublisher(headless, title, content, tags, tenantID, platIndex, requestID, imagePath, wordPath, platInfo, cfg)
|
||||
if platInfo != nil {
|
||||
base.LoginURL = getString(platInfo, "login_url")
|
||||
base.EditorURL = getString(platInfo, "edit_url")
|
||||
base.LoginedURL = getString(platInfo, "logined_url")
|
||||
}
|
||||
return &XiaohongshuPublisher{BasePublisher: base}
|
||||
}
|
||||
|
||||
func (p *XiaohongshuPublisher) CheckLoginStatus() bool {
|
||||
url := p.GetCurrentURL()
|
||||
// 如果URL包含登录相关关键词,表示未登录
|
||||
if strings.Contains(url, "login") || strings.Contains(url, "signin") || strings.Contains(url, "passport") {
|
||||
return false
|
||||
}
|
||||
// 如果URL是编辑页面或主页,表示已登录
|
||||
if strings.Contains(url, "creator") || strings.Contains(url, "editor") || strings.Contains(url, "publish") {
|
||||
return true
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (p *XiaohongshuPublisher) CheckLogin() (bool, string) {
|
||||
p.LogInfo("检查登录状态...")
|
||||
|
||||
if err := p.SetupDriver(); err != nil {
|
||||
return false, fmt.Sprintf("浏览器启动失败: %v", err)
|
||||
}
|
||||
defer p.Close()
|
||||
|
||||
p.Page.MustNavigate(p.LoginedURL)
|
||||
p.Sleep(3)
|
||||
p.WaitForPageReady(5)
|
||||
|
||||
if p.CheckLoginStatus() {
|
||||
p.SaveCookies()
|
||||
return true, "已登录"
|
||||
}
|
||||
return false, "未登录"
|
||||
}
|
||||
|
||||
func (p *XiaohongshuPublisher) WaitLogin() (bool, string) {
|
||||
p.LogInfo("开始等待登录...")
|
||||
|
||||
if err := p.SetupDriver(); err != nil {
|
||||
return false, fmt.Sprintf("浏览器启动失败: %v", err)
|
||||
}
|
||||
defer p.Close()
|
||||
|
||||
// 先尝试访问已登录页面
|
||||
p.Page.MustNavigate(p.LoginedURL)
|
||||
p.Sleep(3)
|
||||
|
||||
if p.CheckLoginStatus() {
|
||||
p.SaveCookies()
|
||||
p.LogInfo("已有登录状态")
|
||||
return true, "already_logged_in"
|
||||
}
|
||||
|
||||
// 未登录,跳转到登录页
|
||||
p.Page.MustNavigate(p.LoginURL)
|
||||
p.LogInfo("请扫描二维码登录...")
|
||||
|
||||
// 等待登录完成,最多120秒
|
||||
for i := 0; i < 120; i++ {
|
||||
time.Sleep(1 * time.Second)
|
||||
if p.CheckLoginStatus() {
|
||||
p.SaveCookies()
|
||||
p.LogInfo("登录成功")
|
||||
return true, "login_success"
|
||||
}
|
||||
}
|
||||
|
||||
return false, "登录超时"
|
||||
}
|
||||
|
||||
func (p *XiaohongshuPublisher) inputContent() error {
|
||||
p.LogInfo("输入文章内容...")
|
||||
|
||||
// 等待编辑器加载
|
||||
contentEditor, err := p.WaitForElementVisible(".tiptap.ProseMirror", 10)
|
||||
if err != nil {
|
||||
// 尝试其他选择器
|
||||
contentEditor, err = p.WaitForElementVisible("[contenteditable='true']", 10)
|
||||
if err != nil {
|
||||
return fmt.Errorf("未找到内容编辑器: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// 点击获取焦点 - 使用 Click 方法
|
||||
if err := contentEditor.Click(proto.InputMouseButtonLeft, 1); err != nil {
|
||||
return fmt.Errorf("点击编辑器失败: %v", err)
|
||||
}
|
||||
p.SleepMs(500)
|
||||
|
||||
// 清空现有内容 - 使用 JavaScript 清空
|
||||
if err := p.ClearContentEditable(contentEditor); err != nil {
|
||||
p.LogInfo(fmt.Sprintf("清空编辑器失败: %v", err))
|
||||
}
|
||||
p.SleepMs(300)
|
||||
|
||||
// 输入新内容 - 使用 JavaScript 设置内容
|
||||
if err := p.SetContentEditable(contentEditor, p.Content); err != nil {
|
||||
// 如果 JS 方式失败,尝试直接输入
|
||||
contentEditor.Input(p.Content)
|
||||
}
|
||||
p.LogInfo(fmt.Sprintf("内容已输入,长度: %d", len(p.Content)))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *XiaohongshuPublisher) inputTitle() error {
|
||||
p.LogInfo("输入标题...")
|
||||
|
||||
// 查找标题输入框
|
||||
titleSelectors := []string{
|
||||
"textarea.d-input",
|
||||
".d-input input",
|
||||
"textarea[placeholder*='标题']",
|
||||
"textarea",
|
||||
}
|
||||
|
||||
var titleInput *rod.Element
|
||||
var err error
|
||||
|
||||
for _, selector := range titleSelectors {
|
||||
titleInput, err = p.WaitForElementVisible(selector, 3)
|
||||
if err == nil && titleInput != nil {
|
||||
p.LogInfo(fmt.Sprintf("找到标题输入框: %s", selector))
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if titleInput == nil {
|
||||
return fmt.Errorf("未找到标题输入框")
|
||||
}
|
||||
|
||||
// 点击获取焦点
|
||||
if err := titleInput.Click(proto.InputMouseButtonLeft, 1); err != nil {
|
||||
return fmt.Errorf("点击标题框失败: %v", err)
|
||||
}
|
||||
p.SleepMs(500)
|
||||
|
||||
// 清空输入框
|
||||
if err := p.ClearInput(titleInput); err != nil {
|
||||
// 备用清空方式
|
||||
titleInput.Input("")
|
||||
}
|
||||
p.SleepMs(300)
|
||||
|
||||
// 输入标题
|
||||
if err := p.SetInputValue(titleInput, p.Title); err != nil {
|
||||
// 备用输入方式
|
||||
titleInput.Input(p.Title)
|
||||
}
|
||||
|
||||
p.LogInfo(fmt.Sprintf("标题已输入: %s", p.Title))
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *XiaohongshuPublisher) inputTags() error {
|
||||
if len(p.Tags) == 0 {
|
||||
p.LogInfo("无标签需要设置")
|
||||
return nil
|
||||
}
|
||||
|
||||
p.LogInfo(fmt.Sprintf("设置标签: %v", p.Tags))
|
||||
|
||||
// 构建标签字符串
|
||||
tagStr := ""
|
||||
for _, tag := range p.Tags {
|
||||
if tagStr != "" {
|
||||
tagStr += " "
|
||||
}
|
||||
tagStr += "#" + tag
|
||||
}
|
||||
|
||||
// 查找标签输入区域
|
||||
tagInput, err := p.WaitForElementVisible(".tiptap-container [contenteditable='true']", 5)
|
||||
if err != nil {
|
||||
p.LogInfo("未找到标签输入框,跳过标签设置")
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := tagInput.Click(proto.InputMouseButtonLeft, 1); err != nil {
|
||||
p.LogInfo(fmt.Sprintf("点击标签框失败: %v", err))
|
||||
}
|
||||
p.SleepMs(500)
|
||||
|
||||
if err := p.SetContentEditable(tagInput, tagStr); err != nil {
|
||||
tagInput.Input(tagStr)
|
||||
}
|
||||
|
||||
p.LogInfo("标签设置完成")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *XiaohongshuPublisher) uploadImage() error {
|
||||
if p.ImagePath == "" {
|
||||
p.LogInfo("无封面图片,跳过上传")
|
||||
return nil
|
||||
}
|
||||
|
||||
p.LogInfo(fmt.Sprintf("上传封面图片: %s", p.ImagePath))
|
||||
|
||||
// 查找封面上传按钮
|
||||
uploadBtn, err := p.WaitForElementClickable(".upload-content", 5)
|
||||
if err != nil {
|
||||
p.LogInfo("未找到封面上传区域,跳过")
|
||||
return nil
|
||||
}
|
||||
|
||||
// 使用 Click 方法
|
||||
if err := uploadBtn.Click(proto.InputMouseButtonLeft, 1); err != nil {
|
||||
p.LogInfo(fmt.Sprintf("点击上传按钮失败: %v", err))
|
||||
}
|
||||
p.SleepMs(1000)
|
||||
|
||||
// 查找文件输入框
|
||||
fileInput, err := p.Page.Element("input[type='file']")
|
||||
if err != nil {
|
||||
return fmt.Errorf("未找到文件输入框: %v", err)
|
||||
}
|
||||
|
||||
// 使用 SetFiles 上传文件
|
||||
if err := fileInput.SetFiles([]string{p.ImagePath}); err != nil {
|
||||
return fmt.Errorf("上传图片失败: %v", err)
|
||||
}
|
||||
|
||||
p.LogInfo("图片上传成功")
|
||||
p.Sleep(3)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *XiaohongshuPublisher) clickPublish() error {
|
||||
p.LogInfo("点击发布按钮...")
|
||||
|
||||
// 滚动到底部
|
||||
if _, err := p.Page.Eval(`() => window.scrollTo(0, document.body.scrollHeight)`); err != nil {
|
||||
p.LogInfo(fmt.Sprintf("滚动到底部失败: %v", err))
|
||||
}
|
||||
p.SleepMs(1000)
|
||||
|
||||
// 查找发布按钮
|
||||
publishSelectors := []string{
|
||||
".publish-page-publish-btn button",
|
||||
".publish-btn",
|
||||
".submit-btn",
|
||||
"button[type='submit']",
|
||||
}
|
||||
|
||||
var publishBtn *rod.Element
|
||||
var err error
|
||||
|
||||
for _, selector := range publishSelectors {
|
||||
publishBtn, err = p.WaitForElementClickable(selector, 5)
|
||||
if err == nil && publishBtn != nil {
|
||||
p.LogInfo(fmt.Sprintf("找到发布按钮: %s", selector))
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// 如果还是没找到,通过文本查找
|
||||
if publishBtn == nil {
|
||||
publishBtn, err = p.Page.ElementX("//button[contains(text(), '发布')]")
|
||||
if err != nil {
|
||||
return fmt.Errorf("未找到发布按钮: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// 滚动到按钮位置
|
||||
if err := p.ScrollToElement(publishBtn); err != nil {
|
||||
p.LogInfo(fmt.Sprintf("滚动到按钮失败: %v", err))
|
||||
}
|
||||
p.SleepMs(500)
|
||||
|
||||
// 点击发布 - 使用 Click 方法
|
||||
if err := publishBtn.Click(proto.InputMouseButtonLeft, 1); err != nil {
|
||||
return fmt.Errorf("点击发布按钮失败: %v", err)
|
||||
}
|
||||
|
||||
p.LogInfo("已点击发布按钮")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *XiaohongshuPublisher) waitForPublishResult() (bool, string) {
|
||||
p.LogInfo("等待发布结果...")
|
||||
|
||||
// 等待最多60秒
|
||||
for i := 0; i < 60; i++ {
|
||||
p.SleepMs(1000)
|
||||
|
||||
// 检查URL是否跳转到成功页面
|
||||
currentURL := p.GetCurrentURL()
|
||||
if strings.Contains(currentURL, "success") ||
|
||||
strings.Contains(currentURL, "content/manage") ||
|
||||
strings.Contains(currentURL, "work-management") {
|
||||
p.LogInfo("发布成功!")
|
||||
return true, "发布成功"
|
||||
}
|
||||
|
||||
// 检查是否有成功提示
|
||||
elements, _ := p.Page.Elements(".semi-toast-content, .toast-success, [class*='success']")
|
||||
for _, el := range elements {
|
||||
text, _ := el.Text()
|
||||
if strings.Contains(text, "成功") || strings.Contains(text, "已发布") {
|
||||
p.LogInfo(fmt.Sprintf("发布成功: %s", text))
|
||||
return true, text
|
||||
}
|
||||
}
|
||||
|
||||
// 检查是否有失败提示
|
||||
elements, _ = p.Page.Elements(".semi-toast-error, .toast-error, [class*='error']")
|
||||
for _, el := range elements {
|
||||
text, _ := el.Text()
|
||||
if strings.Contains(text, "失败") || strings.Contains(text, "错误") {
|
||||
p.LogError(fmt.Sprintf("发布失败: %s", text))
|
||||
return false, text
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false, "发布结果未知(超时)"
|
||||
}
|
||||
|
||||
func (p *XiaohongshuPublisher) PublishNote() (bool, string) {
|
||||
p.LogInfo(strings.Repeat("=", 50))
|
||||
p.LogInfo("开始发布小红书笔记...")
|
||||
p.LogInfo(fmt.Sprintf("标题: %s", p.Title))
|
||||
p.LogInfo(fmt.Sprintf("内容长度: %d", len(p.Content)))
|
||||
p.LogInfo(fmt.Sprintf("标签: %v", p.Tags))
|
||||
p.LogInfo(strings.Repeat("=", 50))
|
||||
|
||||
// 初始化浏览器
|
||||
if err := p.SetupDriver(); err != nil {
|
||||
return false, fmt.Sprintf("浏览器启动失败: %v", err)
|
||||
}
|
||||
defer p.Close()
|
||||
|
||||
// 访问已登录页面
|
||||
p.Page.MustNavigate(p.LoginedURL)
|
||||
p.Sleep(3)
|
||||
p.WaitForPageReady(5)
|
||||
|
||||
// 尝试加载cookies
|
||||
if err := p.LoadCookies(); err == nil {
|
||||
p.RefreshPage()
|
||||
p.Sleep(2)
|
||||
if p.CheckLoginStatus() {
|
||||
p.LogInfo("使用cookies登录成功")
|
||||
} else {
|
||||
p.LogInfo("cookies已过期,需要重新登录")
|
||||
return false, "需要登录"
|
||||
}
|
||||
}
|
||||
|
||||
// 检查登录状态
|
||||
if !p.CheckLoginStatus() {
|
||||
return false, "需要登录"
|
||||
}
|
||||
|
||||
// 保存cookies
|
||||
p.SaveCookies()
|
||||
|
||||
// 访问发布页面
|
||||
p.Page.MustNavigate(p.EditorURL)
|
||||
p.Sleep(3)
|
||||
p.WaitForPageReady(5)
|
||||
|
||||
// 执行发布流程
|
||||
steps := []struct {
|
||||
name string
|
||||
fn func() error
|
||||
}{
|
||||
{"输入内容", p.inputContent},
|
||||
{"输入标题", p.inputTitle},
|
||||
{"设置标签", p.inputTags},
|
||||
{"上传封面", p.uploadImage},
|
||||
}
|
||||
|
||||
for _, step := range steps {
|
||||
if err := step.fn(); err != nil {
|
||||
p.LogStep(step.name, false, err.Error())
|
||||
return false, fmt.Sprintf("%s失败: %v", step.name, err)
|
||||
}
|
||||
p.LogStep(step.name, true, "")
|
||||
p.SleepMs(500)
|
||||
}
|
||||
|
||||
// 点击发布
|
||||
if err := p.clickPublish(); err != nil {
|
||||
return false, err.Error()
|
||||
}
|
||||
|
||||
// 等待发布结果
|
||||
return p.waitForPublishResult()
|
||||
}
|
||||
|
||||
func getString(m map[string]interface{}, key string) string {
|
||||
if v, ok := m[key]; ok {
|
||||
if s, ok := v.(string); ok {
|
||||
return s
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
|
@ -0,0 +1,25 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"geo/internal/server/router"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/gofiber/fiber/v2/middleware/logger"
|
||||
"github.com/gofiber/fiber/v2/middleware/recover"
|
||||
)
|
||||
|
||||
func NewHTTPServer(routerServer *router.RouterServer) *fiber.App {
|
||||
//构建 server
|
||||
app := initRoute()
|
||||
router.SetupRoutes(app, routerServer)
|
||||
return app
|
||||
}
|
||||
|
||||
func initRoute() *fiber.App {
|
||||
app := fiber.New()
|
||||
app.Use(
|
||||
recover.New(),
|
||||
logger.New(),
|
||||
)
|
||||
return app
|
||||
}
|
||||
|
|
@ -0,0 +1,14 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"geo/internal/server/router"
|
||||
|
||||
"github.com/google/wire"
|
||||
)
|
||||
|
||||
var ProviderSetServer = wire.NewSet(
|
||||
NewServers,
|
||||
NewHTTPServer,
|
||||
router.NewRouterServer,
|
||||
router.NewAppModule,
|
||||
)
|
||||
|
|
@ -0,0 +1,47 @@
|
|||
package router
|
||||
|
||||
import (
|
||||
"geo/internal/config"
|
||||
"geo/internal/entitys"
|
||||
"geo/internal/service"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
)
|
||||
|
||||
type AppModule struct {
|
||||
cfg *config.Config
|
||||
appService *service.AppService
|
||||
loginService *service.LoginService
|
||||
publishService *service.PublishService
|
||||
}
|
||||
|
||||
func NewAppModule(
|
||||
cfg *config.Config,
|
||||
appService *service.AppService,
|
||||
loginService *service.LoginService,
|
||||
publishService *service.PublishService,
|
||||
) *AppModule {
|
||||
return &AppModule{
|
||||
cfg: cfg,
|
||||
appService: appService,
|
||||
loginService: loginService,
|
||||
publishService: publishService,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *AppModule) Register(router fiber.Router) {
|
||||
router.Post("/login_app", vali(m.appService.LoginApp, &entitys.LoginAppRequest{}))
|
||||
router.Post("/get_user_and_auto_status", vali(m.appService.GetUserAndAutoStatus, &entitys.GetUserAndAutoStatusRequest{}))
|
||||
router.Post("/add_user", vali(m.appService.AddUser, &entitys.AddUserRequest{}))
|
||||
router.Post("/del_user", vali(m.appService.DelUser, &entitys.DelUserRequest{}))
|
||||
router.Post("/get_app", vali(m.appService.GetApp, &entitys.GetAppRequest{}))
|
||||
router.Post("/publish_records", vali(m.publishService.PublishRecords, &entitys.PublishRecordsRequest{}))
|
||||
router.Post("/publish_on", vali(m.publishService.PublishOn, &entitys.PublishOnRequest{}))
|
||||
router.Post("/publish_off", vali(m.publishService.PublishOff, &entitys.PublishOffRequest{}))
|
||||
router.Post("/publish_status", vali(m.publishService.PublishStatus, &entitys.PublishStatusRequest{}))
|
||||
router.Post("/publish_execute_once", vali(m.publishService.PublishExecuteOnce, &entitys.PublishExecuteOnceRequest{}))
|
||||
router.Post("/publish_execute_retry", vali(m.publishService.PublishExecuteRetry, &entitys.PublishExecuteRetryRequest{}))
|
||||
router.Post("/get_publish_list", vali(m.publishService.GetPublishList, &entitys.GetPublishListRequest{}))
|
||||
router.Post("/login_platform", vali(m.loginService.LoginPlatform, &entitys.LoginPlatformRequest{}))
|
||||
router.Post("/logout_platform", vali(m.loginService.LogoutPlatform, &entitys.LogoutPlatformRequest{}))
|
||||
}
|
||||
|
|
@ -0,0 +1,128 @@
|
|||
package router
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"geo/pkg"
|
||||
"geo/tmpl/errcode"
|
||||
"reflect"
|
||||
"strings"
|
||||
|
||||
"github.com/go-playground/validator/v10"
|
||||
"github.com/gofiber/fiber/v2"
|
||||
)
|
||||
|
||||
type Module interface {
|
||||
Register(r fiber.Router)
|
||||
}
|
||||
type RouterServer struct {
|
||||
modules []Module
|
||||
}
|
||||
|
||||
func NewRouterServer(
|
||||
app *AppModule,
|
||||
) *RouterServer {
|
||||
return &RouterServer{
|
||||
modules: []Module{app},
|
||||
}
|
||||
}
|
||||
|
||||
// SetupRoutes 设置路由
|
||||
func SetupRoutes(app *fiber.App, routerServer *RouterServer) {
|
||||
app.Use(func(c *fiber.Ctx) error {
|
||||
// 设置 CORS 头
|
||||
c.Set("Access-Control-Allow-Origin", "*")
|
||||
c.Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
|
||||
c.Set("Access-Control-Allow-Headers", "Content-Type, Authorization")
|
||||
|
||||
// 如果是预检请求(OPTIONS),直接返回 204
|
||||
if c.Method() == "OPTIONS" {
|
||||
return c.SendStatus(fiber.StatusNoContent) // 204
|
||||
}
|
||||
|
||||
// 继续处理后续中间件或路由
|
||||
return c.Next()
|
||||
})
|
||||
|
||||
registerResponse(app)
|
||||
|
||||
// 注册所有模块
|
||||
for _, module := range routerServer.modules {
|
||||
module.Register(app)
|
||||
}
|
||||
}
|
||||
|
||||
func registerResponse(router fiber.Router) {
|
||||
// 自定义返回
|
||||
router.Use(func(c *fiber.Ctx) error {
|
||||
err := c.Next()
|
||||
return registerCommon(c, err)
|
||||
})
|
||||
|
||||
}
|
||||
|
||||
func registerCommon(c *fiber.Ctx, err error) error {
|
||||
// 调用下一个中间件或路由处理函数
|
||||
|
||||
// 如果有错误发生
|
||||
if err != nil {
|
||||
var businessErr *errcode.BusinessErr
|
||||
var fiberError *fiber.Error
|
||||
switch {
|
||||
case errors.As(err, &businessErr):
|
||||
case errors.As(err, &fiberError):
|
||||
errors.As(err, &fiberError)
|
||||
businessErr = errcode.NewBusinessErr(fiberError.Code, fiberError.Message)
|
||||
default:
|
||||
businessErr = errcode.SystemError
|
||||
}
|
||||
// 返回自定义错误响应
|
||||
return c.JSON(fiber.Map{
|
||||
"message": businessErr.Error(),
|
||||
"code": businessErr.Code(),
|
||||
"data": nil,
|
||||
})
|
||||
}
|
||||
// 如果没有错误发生,继续处理请求
|
||||
var data interface{}
|
||||
json.Unmarshal(c.Response().Body(), &data)
|
||||
return c.JSON(fiber.Map{
|
||||
"data": data,
|
||||
"message": errcode.Success.Error(),
|
||||
"code": errcode.Success.Code(),
|
||||
})
|
||||
}
|
||||
|
||||
func vali[T any](handler func(*fiber.Ctx, *T) error, _ *T) fiber.Handler {
|
||||
return func(c *fiber.Ctx) error {
|
||||
var data T
|
||||
|
||||
// 解析请求
|
||||
if err := c.BodyParser(&data); err != nil {
|
||||
return errcode.ParamErr(err.Error())
|
||||
}
|
||||
|
||||
// 创建验证器
|
||||
validate := validator.New()
|
||||
|
||||
// 注册中文标签
|
||||
validate.RegisterTagNameFunc(func(fld reflect.StructField) string {
|
||||
name := fld.Tag.Get("zh")
|
||||
if name == "" {
|
||||
name = fld.Tag.Get("json")
|
||||
}
|
||||
return name
|
||||
})
|
||||
|
||||
// 验证
|
||||
if err := validate.Struct(&data); err != nil {
|
||||
er := make([]string, len(err.(validator.ValidationErrors)))
|
||||
for k, e := range err.(validator.ValidationErrors) {
|
||||
er[k] = pkg.GetErr(e.Tag(), e.Field(), e.Param())
|
||||
}
|
||||
return errcode.ParamErr(strings.Join(er, ","))
|
||||
}
|
||||
|
||||
return handler(c, &data)
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,19 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"geo/internal/config"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
)
|
||||
|
||||
type Servers struct {
|
||||
cfg *config.Config
|
||||
HttpServer *fiber.App
|
||||
}
|
||||
|
||||
func NewServers(cfg *config.Config, fiber *fiber.App) *Servers {
|
||||
return &Servers{
|
||||
HttpServer: fiber,
|
||||
cfg: cfg,
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,173 @@
|
|||
package service
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"xorm.io/builder"
|
||||
|
||||
"geo/internal/biz"
|
||||
"geo/internal/config"
|
||||
"geo/internal/data/impl"
|
||||
"geo/internal/data/model"
|
||||
"geo/internal/entitys"
|
||||
"geo/internal/manager"
|
||||
"geo/pkg"
|
||||
"geo/tmpl/errcode"
|
||||
)
|
||||
|
||||
type AppService struct {
|
||||
cfg *config.Config
|
||||
tokenImpl *impl.TokenImpl
|
||||
userImpl *impl.UserImpl
|
||||
platImpl *impl.PlatImpl
|
||||
publishBiz *biz.PublishBiz
|
||||
}
|
||||
|
||||
func NewAppService(
|
||||
cfg *config.Config,
|
||||
tokenImpl *impl.TokenImpl,
|
||||
|
||||
userImpl *impl.UserImpl,
|
||||
platImpl *impl.PlatImpl,
|
||||
publishBiz *biz.PublishBiz,
|
||||
) *AppService {
|
||||
return &AppService{
|
||||
cfg: cfg,
|
||||
tokenImpl: tokenImpl,
|
||||
|
||||
userImpl: userImpl,
|
||||
platImpl: platImpl,
|
||||
publishBiz: publishBiz,
|
||||
}
|
||||
}
|
||||
|
||||
func (a *AppService) LoginApp(c *fiber.Ctx, req *entitys.LoginAppRequest) error {
|
||||
cond := builder.NewCond().
|
||||
And(builder.Eq{"secret": req.Secret}).
|
||||
And(builder.Eq{"status": 1})
|
||||
tokenInfo := &model.Token{}
|
||||
err := a.tokenImpl.GetOneBySearchStruct(c.UserContext(), &cond, tokenInfo)
|
||||
if err != nil || tokenInfo == nil {
|
||||
return errcode.Forbidden("密钥无效")
|
||||
}
|
||||
|
||||
accessToken := pkg.GenerateUUID()
|
||||
err = a.tokenImpl.UpdateByKey(c.UserContext(), a.tokenImpl.PrimaryKey(), tokenInfo.ID, &model.Token{
|
||||
AccessToken: accessToken,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return pkg.HandleResponse(c, fiber.Map{
|
||||
"access_token": accessToken,
|
||||
"user_limit": tokenInfo.UserLimit,
|
||||
"expire_time": tokenInfo.ExpireTime,
|
||||
})
|
||||
}
|
||||
|
||||
func (a *AppService) GetUserAndAutoStatus(c *fiber.Ctx, req *entitys.GetUserAndAutoStatusRequest) error {
|
||||
tokenInfo, err := a.publishBiz.ValidateAccessToken(c.UserContext(), req.AccessToken)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cond := builder.NewCond().
|
||||
And(builder.Eq{"token_id": tokenInfo.ID}).
|
||||
And(builder.Eq{"status": 1})
|
||||
users := make([]model.User, 0)
|
||||
_, err = a.userImpl.GetListToStruct(c.UserContext(), &cond, nil, &users, "id desc")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
pm := manager.GetPublishManager(a.cfg, a.tokenImpl.GetDb())
|
||||
return pkg.HandleResponse(c, fiber.Map{
|
||||
"user": users,
|
||||
"auto_status": pm.AutoStatus,
|
||||
})
|
||||
}
|
||||
|
||||
func (a *AppService) AddUser(c *fiber.Ctx, req *entitys.AddUserRequest) error {
|
||||
tokenInfo, err := a.publishBiz.ValidateAccessToken(c.UserContext(), req.AccessToken)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cond := builder.NewCond().
|
||||
And(builder.Eq{"token_id": tokenInfo.ID}).
|
||||
And(builder.Eq{"status": 1})
|
||||
userCount, err := a.userImpl.CountByCond(c.UserContext(), &cond)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if userCount >= int64(tokenInfo.UserLimit) {
|
||||
return errcode.ParamErr(fmt.Sprintf("用户已达上限(%d)", tokenInfo.UserLimit))
|
||||
}
|
||||
|
||||
userIndex := pkg.GenerateUserIndex()
|
||||
err = a.userImpl.Add(c.UserContext(), &model.User{
|
||||
TokenID: tokenInfo.ID,
|
||||
UserIndex: userIndex,
|
||||
Name: req.Name,
|
||||
Status: 1,
|
||||
})
|
||||
if err != nil {
|
||||
return errcode.SqlErr(err)
|
||||
}
|
||||
|
||||
// 获取平台列表并关联
|
||||
platCond := builder.NewCond().And(builder.Eq{"status": 1})
|
||||
plats := make([]model.Plat, 0)
|
||||
_, err = a.platImpl.GetListToStruct(c.UserContext(), &platCond, nil, &plats, "")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, plat := range plats {
|
||||
loginRelation := &model.LoginRelation{
|
||||
UserIndex: userIndex,
|
||||
PlatIndex: plat.Index,
|
||||
LoginStatus: 2,
|
||||
Status: 1,
|
||||
}
|
||||
// 这里需要调用 loginRelationImpl 的 Add 方法
|
||||
_ = loginRelation
|
||||
}
|
||||
|
||||
return pkg.HandleResponse(c, fiber.Map{
|
||||
"user_name": req.Name,
|
||||
"user_index": userIndex,
|
||||
})
|
||||
}
|
||||
|
||||
func (a *AppService) DelUser(c *fiber.Ctx, req *entitys.DelUserRequest) error {
|
||||
// 需要从请求中获取 access_token
|
||||
// 这里简化处理
|
||||
err := a.userImpl.DeleteByKey(c.UserContext(), a.userImpl.PrimaryKey(), int64(req.ID))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return pkg.HandleResponse(c, fiber.Map{})
|
||||
}
|
||||
|
||||
func (a *AppService) GetApp(c *fiber.Ctx, req *entitys.GetAppRequest) error {
|
||||
_, err := a.publishBiz.ValidateAccessToken(c.UserContext(), req.AccessToken)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cond := builder.NewCond().
|
||||
And(builder.Eq{"login_relation.user_index": req.UserIndex}).
|
||||
And(builder.Eq{"login_relation.status": 1}).
|
||||
And(builder.Eq{"plat.status": 1})
|
||||
|
||||
result, err := a.platImpl.GetPlatListWithLoginStatus(c.UserContext(), &cond)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return pkg.HandleResponse(c, result)
|
||||
}
|
||||
|
|
@ -0,0 +1,95 @@
|
|||
package service
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
|
||||
"geo/internal/biz"
|
||||
"geo/internal/config"
|
||||
"geo/internal/entitys"
|
||||
"geo/internal/publisher"
|
||||
"geo/pkg"
|
||||
"geo/tmpl/errcode"
|
||||
)
|
||||
|
||||
type LoginService struct {
|
||||
cfg *config.Config
|
||||
publishBiz *biz.PublishBiz
|
||||
}
|
||||
|
||||
func NewLoginService(
|
||||
cfg *config.Config,
|
||||
publishBiz *biz.PublishBiz,
|
||||
|
||||
) *LoginService {
|
||||
return &LoginService{
|
||||
cfg: cfg,
|
||||
publishBiz: publishBiz,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *LoginService) LoginPlatform(c *fiber.Ctx, req *entitys.LoginPlatformRequest) error {
|
||||
_, err := s.publishBiz.ValidateAccessToken(c.UserContext(), req.AccessToken)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 获取平台信息
|
||||
platInfo, err := s.publishBiz.GetPlatInfo(c.UserContext(), req.PlatIndex)
|
||||
if err != nil {
|
||||
return errcode.NotFound("平台不存在")
|
||||
}
|
||||
|
||||
// 创建发布器
|
||||
platMap := map[string]interface{}{
|
||||
"login_url": platInfo.LoginURL,
|
||||
"edit_url": platInfo.EditURL,
|
||||
"logined_url": platInfo.LoginedURL,
|
||||
}
|
||||
|
||||
var pub interface{ WaitLogin() (bool, string) }
|
||||
switch req.PlatIndex {
|
||||
case "xhs":
|
||||
pub = publisher.NewXiaohongshuPublisher(false, "", "", nil, req.UserIndex, req.PlatIndex, "", "", "", platMap, s.cfg)
|
||||
default:
|
||||
pub = publisher.NewXiaohongshuPublisher(false, "", "", nil, req.UserIndex, req.PlatIndex, "", "", "", platMap, s.cfg)
|
||||
}
|
||||
|
||||
success, msg := pub.WaitLogin()
|
||||
if !success {
|
||||
return errcode.SysErr(msg)
|
||||
}
|
||||
|
||||
// 更新登录状态
|
||||
err = s.publishBiz.UpdateLoginStatus(c.UserContext(), req.UserIndex, req.PlatIndex, 1)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return pkg.HandleResponse(c, fiber.Map{})
|
||||
}
|
||||
|
||||
func (s *LoginService) LogoutPlatform(c *fiber.Ctx, req *entitys.LogoutPlatformRequest) error {
|
||||
_, err := s.publishBiz.ValidateAccessToken(c.UserContext(), req.AccessToken)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 更新登录状态为未登录
|
||||
err = s.publishBiz.UpdateLoginStatus(c.UserContext(), req.UserIndex, req.PlatIndex, 2)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return pkg.HandleResponse(c, fiber.Map{})
|
||||
}
|
||||
|
||||
func (s *LoginService) ServeQrcode(c *fiber.Ctx, filename string) error {
|
||||
filepath := filepath.Join(s.cfg.Sys.QrcodesDir, filename)
|
||||
if _, err := os.Stat(filepath); os.IsNotExist(err) {
|
||||
return errcode.NotFound("二维码不存在")
|
||||
}
|
||||
return c.SendFile(filepath)
|
||||
}
|
||||
|
|
@ -0,0 +1,11 @@
|
|||
package service
|
||||
|
||||
import (
|
||||
"github.com/google/wire"
|
||||
)
|
||||
|
||||
var ProviderSetAppService = wire.NewSet(
|
||||
NewAppService,
|
||||
NewPublishService,
|
||||
NewLoginService,
|
||||
)
|
||||
|
|
@ -0,0 +1,205 @@
|
|||
package service
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"geo/internal/data/model"
|
||||
"geo/utils"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"time"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
|
||||
"geo/internal/biz"
|
||||
"geo/internal/config"
|
||||
"geo/internal/entitys"
|
||||
"geo/internal/manager"
|
||||
"geo/pkg"
|
||||
"geo/tmpl/errcode"
|
||||
)
|
||||
|
||||
type PublishService struct {
|
||||
cfg *config.Config
|
||||
publishBiz *biz.PublishBiz
|
||||
db *utils.Db
|
||||
}
|
||||
|
||||
func NewPublishService(
|
||||
cfg *config.Config,
|
||||
publishBiz *biz.PublishBiz,
|
||||
db *utils.Db,
|
||||
) *PublishService {
|
||||
return &PublishService{
|
||||
cfg: cfg,
|
||||
publishBiz: publishBiz,
|
||||
db: db,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *PublishService) PublishRecords(c *fiber.Ctx, req *entitys.PublishRecordsRequest) error {
|
||||
// 验证token
|
||||
tokenInfo, err := s.publishBiz.ValidateAccessToken(c.UserContext(), req.AccessToken)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 转换记录
|
||||
validRecords := make([]*model.Publish, 0)
|
||||
for _, record := range req.Records {
|
||||
publishTime, err := time.Parse("2006-01-02 15:04:05", record.PublishTime)
|
||||
if err != nil {
|
||||
return errcode.ParamErr(fmt.Sprintf("时间格式错误: %v", err))
|
||||
}
|
||||
|
||||
validRecords = append(validRecords, &model.Publish{
|
||||
UserIndex: record.UserIndex,
|
||||
RequestID: record.RequestID,
|
||||
Title: record.Title,
|
||||
Tag: record.Tag,
|
||||
Type: record.Type,
|
||||
PlatIndex: record.PlatIndex,
|
||||
URL: record.URL,
|
||||
PublishTime: publishTime,
|
||||
Img: record.Img,
|
||||
TokenID: tokenInfo.ID,
|
||||
})
|
||||
|
||||
}
|
||||
|
||||
err = s.publishBiz.BatchInsertPublish(c.UserContext(), validRecords)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return pkg.HandleResponse(c, fiber.Map{
|
||||
"total": len(validRecords),
|
||||
})
|
||||
}
|
||||
|
||||
func (s *PublishService) PublishOn(c *fiber.Ctx, req *entitys.PublishOnRequest) error {
|
||||
tokenInfo, err := s.publishBiz.ValidateAccessToken(c.UserContext(), req.AccessToken)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
pm := manager.GetPublishManager(s.cfg, s.db)
|
||||
if pm.Start(int(tokenInfo.ID)) {
|
||||
return pkg.HandleResponse(c, fiber.Map{
|
||||
"auto_status": pm.AutoStatus,
|
||||
})
|
||||
}
|
||||
return errors.New("自动发布服务已在运行中")
|
||||
}
|
||||
|
||||
func (s *PublishService) PublishOff(c *fiber.Ctx, req *entitys.PublishOffRequest) error {
|
||||
_, err := s.publishBiz.ValidateAccessToken(c.UserContext(), req.AccessToken)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
pm := manager.GetPublishManager(s.cfg, s.db)
|
||||
if pm.Stop() {
|
||||
return pkg.HandleResponse(c, fiber.Map{})
|
||||
}
|
||||
return errors.New("自动发布服务未运行")
|
||||
}
|
||||
|
||||
func (s *PublishService) PublishStatus(c *fiber.Ctx, req *entitys.PublishStatusRequest) error {
|
||||
_, err := s.publishBiz.ValidateAccessToken(c.UserContext(), req.AccessToken)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
pm := manager.GetPublishManager(s.cfg, s.db)
|
||||
|
||||
if req.RequestID != "" {
|
||||
// 查询单个任务
|
||||
task, err := s.publishBiz.GetTaskByRequestID(c.UserContext(), req.RequestID)
|
||||
if err != nil {
|
||||
return errcode.NotFound("任务不存在")
|
||||
}
|
||||
return pkg.HandleResponse(c, task)
|
||||
}
|
||||
|
||||
// 查询整体状态
|
||||
return pkg.HandleResponse(c, pm.GetStatus())
|
||||
}
|
||||
|
||||
func (s *PublishService) PublishExecuteOnce(c *fiber.Ctx, req *entitys.PublishExecuteOnceRequest) error {
|
||||
tokenInfo, err := s.publishBiz.ValidateAccessToken(c.UserContext(), req.AccessToken)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
pm := manager.GetPublishManager(s.cfg, s.db)
|
||||
result := pm.ExecuteOnce(tokenInfo.ID)
|
||||
return pkg.HandleResponse(c, result)
|
||||
}
|
||||
|
||||
func (s *PublishService) PublishExecuteRetry(c *fiber.Ctx, req *entitys.PublishExecuteRetryRequest) error {
|
||||
_, err := s.publishBiz.ValidateAccessToken(c.UserContext(), req.AccessToken)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
pm := manager.GetPublishManager(s.cfg, s.db)
|
||||
result := pm.RetryTask(req.RequestID)
|
||||
return pkg.HandleResponse(c, result)
|
||||
}
|
||||
|
||||
func (s *PublishService) GetPublishList(c *fiber.Ctx, req *entitys.GetPublishListRequest) error {
|
||||
tokenInfo, err := s.publishBiz.ValidateAccessToken(c.UserContext(), req.AccessToken)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
page := req.Page
|
||||
if page < 1 {
|
||||
page = 1
|
||||
}
|
||||
pageSize := req.PageSize
|
||||
if pageSize < 1 {
|
||||
pageSize = 20
|
||||
}
|
||||
if pageSize > 100 {
|
||||
pageSize = 100
|
||||
}
|
||||
|
||||
filters := map[string]interface{}{
|
||||
"user_index": req.UserIndex,
|
||||
"tag": req.Tag,
|
||||
"type": req.Type,
|
||||
"plat_index": req.PlatIndex,
|
||||
"status": req.Status,
|
||||
"request_id": req.RequestID,
|
||||
}
|
||||
|
||||
list, total, err := s.publishBiz.GetPublishList(c.UserContext(), tokenInfo.ID, page, pageSize, filters)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return pkg.HandleResponse(c, fiber.Map{
|
||||
"list": list,
|
||||
"pagination": fiber.Map{
|
||||
"page": page,
|
||||
"page_size": pageSize,
|
||||
"total": total,
|
||||
"total_pages": (total + int64(pageSize) - 1) / int64(pageSize),
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func (s *PublishService) GetLogs(c *fiber.Ctx, requestID string) error {
|
||||
logFile := filepath.Join(s.cfg.Sys.LogsDir, requestID+".log")
|
||||
content, err := os.ReadFile(logFile)
|
||||
if err != nil {
|
||||
return errcode.NotFound("日志文件不存在")
|
||||
}
|
||||
return pkg.HandleResponse(c, fiber.Map{
|
||||
"request_id": requestID,
|
||||
"log_content": string(content),
|
||||
})
|
||||
}
|
||||
|
|
@ -0,0 +1,45 @@
|
|||
package pkg
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
md "github.com/JohannesKaufmann/html-to-markdown"
|
||||
strip "github.com/grokify/html-strip-tags-go"
|
||||
)
|
||||
|
||||
func ExtractWordContent(filePath string, format string) (string, error) {
|
||||
// 简化版:读取文件内容并转换格式
|
||||
// 完整实现需要使用专门的docx库
|
||||
|
||||
content := "从Word文档提取的内容"
|
||||
|
||||
switch format {
|
||||
case "html":
|
||||
return "<p>" + content + "</p>", nil
|
||||
case "markdown":
|
||||
converter := md.NewConverter("", true, nil)
|
||||
return converter.ConvertString("<p>" + content + "</p>")
|
||||
default:
|
||||
return strip.StripTags(content), nil
|
||||
}
|
||||
}
|
||||
|
||||
func ParseTags(tagStr string) []string {
|
||||
if tagStr == "" {
|
||||
return []string{}
|
||||
}
|
||||
tags := strings.Split(tagStr, ",")
|
||||
result := make([]string, 0)
|
||||
for _, t := range tags {
|
||||
trimmed := strings.TrimSpace(t)
|
||||
if trimmed != "" {
|
||||
result = append(result, trimmed)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func IsTimeExceeded(targetTime string) bool {
|
||||
// 实现时间比较
|
||||
return false
|
||||
}
|
||||
|
|
@ -0,0 +1,247 @@
|
|||
package pkg
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"math/rand/v2"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"reflect"
|
||||
"time"
|
||||
|
||||
"github.com/go-viper/mapstructure/v2"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
func GetModuleDir() (string, error) {
|
||||
dir, err := os.Getwd()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
for {
|
||||
modPath := filepath.Join(dir, "go.mod")
|
||||
if _, err := os.Stat(modPath); err == nil {
|
||||
return dir, nil // 找到 go.mod
|
||||
}
|
||||
|
||||
// 向上查找父目录
|
||||
parent := filepath.Dir(dir)
|
||||
if parent == dir {
|
||||
break // 到达根目录,未找到
|
||||
}
|
||||
dir = parent
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("go.mod not found in current directory or parents")
|
||||
}
|
||||
|
||||
// GetCacheDir 用于获取缓存目录路径
|
||||
// 如果缓存目录不存在,则会自动创建
|
||||
// 返回值:
|
||||
// - string: 缓存目录的路径
|
||||
// - error: 如果获取模块目录失败或创建缓存目录失败,则返回错误信息
|
||||
func GetCacheDir() (string, error) {
|
||||
// 获取模块目录
|
||||
modDir, err := GetModuleDir()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
// 拼接缓存目录路径
|
||||
path := fmt.Sprintf("%s/cache", modDir)
|
||||
// 创建目录(包括所有必要的父目录),权限设置为0755
|
||||
err = os.MkdirAll(path, 0755)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("创建目录失败: %w", err)
|
||||
}
|
||||
// 返回成功创建的缓存目录路径
|
||||
return path, nil
|
||||
}
|
||||
|
||||
func GetTmplDir() (string, error) {
|
||||
modDir, err := GetModuleDir()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
path := fmt.Sprintf("%s/tmpl", modDir)
|
||||
err = os.MkdirAll(path, 0755)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("创建目录失败: %w", err)
|
||||
}
|
||||
return path, nil
|
||||
}
|
||||
|
||||
func ReverseSliceNew[T any](s []T) []T {
|
||||
result := make([]T, len(s))
|
||||
for i := 0; i < len(s); i++ {
|
||||
result[i] = s[len(s)-1-i]
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func JsonStringIgonErr(data interface{}) string {
|
||||
return string(JsonByteIgonErr(data))
|
||||
}
|
||||
|
||||
func JsonByteIgonErr(data interface{}) []byte {
|
||||
dataByte, _ := json.Marshal(data)
|
||||
return dataByte
|
||||
}
|
||||
|
||||
func IntersectionGeneric[T comparable](slice1, slice2 []T) []T {
|
||||
m := make(map[T]bool)
|
||||
result := []T{}
|
||||
|
||||
for _, v := range slice1 {
|
||||
m[v] = true
|
||||
}
|
||||
|
||||
for _, v := range slice2 {
|
||||
if m[v] {
|
||||
result = append(result, v)
|
||||
delete(m, v) // 避免重复
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
func CreateOrderNum(prefix string) string {
|
||||
code := fmt.Sprintf("%04d", rand.IntN(10000))
|
||||
fmt.Println("4位随机数字:", code) // 输出示例: "0837"
|
||||
return prefix + time.Now().Format("20060102150405") + code
|
||||
}
|
||||
|
||||
func BuildUpdateMap(obj interface{}, omitFields ...string) map[string]interface{} {
|
||||
result := make(map[string]interface{})
|
||||
omitMap := make(map[string]bool)
|
||||
for _, f := range omitFields {
|
||||
omitMap[f] = true
|
||||
}
|
||||
|
||||
v := reflect.ValueOf(obj)
|
||||
if v.Kind() == reflect.Ptr {
|
||||
v = v.Elem()
|
||||
}
|
||||
t := v.Type()
|
||||
|
||||
for i := 0; i < v.NumField(); i++ {
|
||||
field := v.Field(i)
|
||||
fieldName := t.Field(i).Name
|
||||
|
||||
if omitMap[fieldName] {
|
||||
continue
|
||||
}
|
||||
|
||||
// 只处理非 nil 的指针字段
|
||||
if field.Kind() == reflect.Ptr && !field.IsNil() {
|
||||
// 将驼峰转为下划线(可选的,根据你的数据库列名决定)
|
||||
colName := CamelToSnake(fieldName)
|
||||
result[colName] = field.Elem().Interface()
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func CamelToSnake(s string) string {
|
||||
var result []rune
|
||||
for i, r := range s {
|
||||
if i > 0 && r >= 'A' && r <= 'Z' {
|
||||
result = append(result, '_')
|
||||
}
|
||||
result = append(result, r)
|
||||
}
|
||||
return string(result)
|
||||
}
|
||||
|
||||
func CopyNonNilFields(src, dst interface{}) error {
|
||||
config := &mapstructure.DecoderConfig{
|
||||
Result: dst,
|
||||
TagName: "json",
|
||||
ZeroFields: false, // 重要:不清零目标字段
|
||||
Squash: false,
|
||||
}
|
||||
|
||||
decoder, err := mapstructure.NewDecoder(config)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return decoder.Decode(src)
|
||||
}
|
||||
|
||||
func DownloadFile(url string, saveDir string, filename string) (string, error) {
|
||||
os.MkdirAll(saveDir, 0755)
|
||||
|
||||
if filename == "" {
|
||||
filename = uuid.New().String() + ".docx"
|
||||
}
|
||||
|
||||
filePath := filepath.Join(saveDir, filename)
|
||||
|
||||
resp, err := http.Get(url)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
out, err := os.Create(filePath)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer out.Close()
|
||||
|
||||
_, err = io.Copy(out, resp.Body)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
absPath, _ := filepath.Abs(filePath)
|
||||
return absPath, nil
|
||||
}
|
||||
|
||||
func DownloadImage(url string, requestID string, dir string) (string, error) {
|
||||
os.MkdirAll(dir, 0755)
|
||||
|
||||
ext := filepath.Ext(url)
|
||||
if ext == "" {
|
||||
ext = ".jpg"
|
||||
}
|
||||
filename := requestID + "_" + uuid.New().String() + ext
|
||||
filePath := filepath.Join(dir, filename)
|
||||
|
||||
resp, err := http.Get(url)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
out, err := os.Create(filePath)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer out.Close()
|
||||
|
||||
_, err = io.Copy(out, resp.Body)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return filepath.Abs(filePath)
|
||||
}
|
||||
|
||||
func DeleteFile(path string) {
|
||||
if path != "" {
|
||||
os.Remove(path)
|
||||
}
|
||||
}
|
||||
|
||||
func GenerateUUID() string {
|
||||
return uuid.New().String()
|
||||
}
|
||||
|
||||
func GenerateUserIndex() string {
|
||||
return uuid.New().String()[:20]
|
||||
}
|
||||
|
|
@ -0,0 +1,279 @@
|
|||
package mapstructure
|
||||
|
||||
import (
|
||||
"encoding"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// typedDecodeHook takes a raw DecodeHookFunc (an interface{}) and turns
|
||||
// it into the proper DecodeHookFunc type, such as DecodeHookFuncType.
|
||||
func typedDecodeHook(h DecodeHookFunc) DecodeHookFunc {
|
||||
// Create variables here so we can reference them with the reflect pkg
|
||||
var f1 DecodeHookFuncType
|
||||
var f2 DecodeHookFuncKind
|
||||
var f3 DecodeHookFuncValue
|
||||
|
||||
// Fill in the variables into this interface and the rest is done
|
||||
// automatically using the reflect package.
|
||||
potential := []interface{}{f1, f2, f3}
|
||||
|
||||
v := reflect.ValueOf(h)
|
||||
vt := v.Type()
|
||||
for _, raw := range potential {
|
||||
pt := reflect.ValueOf(raw).Type()
|
||||
if vt.ConvertibleTo(pt) {
|
||||
return v.Convert(pt).Interface()
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DecodeHookExec executes the given decode hook. This should be used
|
||||
// since it'll naturally degrade to the older backwards compatible DecodeHookFunc
|
||||
// that took reflect.Kind instead of reflect.Type.
|
||||
func DecodeHookExec(
|
||||
raw DecodeHookFunc,
|
||||
from reflect.Value, to reflect.Value) (interface{}, error) {
|
||||
|
||||
switch f := typedDecodeHook(raw).(type) {
|
||||
case DecodeHookFuncType:
|
||||
return f(from.Type(), to.Type(), from.Interface())
|
||||
case DecodeHookFuncKind:
|
||||
return f(from.Kind(), to.Kind(), from.Interface())
|
||||
case DecodeHookFuncValue:
|
||||
return f(from, to)
|
||||
default:
|
||||
return nil, errors.New("invalid decode hook signature")
|
||||
}
|
||||
}
|
||||
|
||||
// ComposeDecodeHookFunc creates a single DecodeHookFunc that
|
||||
// automatically composes multiple DecodeHookFuncs.
|
||||
//
|
||||
// The composed funcs are called in order, with the result of the
|
||||
// previous transformation.
|
||||
func ComposeDecodeHookFunc(fs ...DecodeHookFunc) DecodeHookFunc {
|
||||
return func(f reflect.Value, t reflect.Value) (interface{}, error) {
|
||||
var err error
|
||||
data := f.Interface()
|
||||
|
||||
newFrom := f
|
||||
for _, f1 := range fs {
|
||||
data, err = DecodeHookExec(f1, newFrom, t)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
newFrom = reflect.ValueOf(data)
|
||||
}
|
||||
|
||||
return data, nil
|
||||
}
|
||||
}
|
||||
|
||||
// OrComposeDecodeHookFunc executes all input hook functions until one of them returns no error. In that case its value is returned.
|
||||
// If all hooks return an error, OrComposeDecodeHookFunc returns an error concatenating all error messages.
|
||||
func OrComposeDecodeHookFunc(ff ...DecodeHookFunc) DecodeHookFunc {
|
||||
return func(a, b reflect.Value) (interface{}, error) {
|
||||
var allErrs string
|
||||
var out interface{}
|
||||
var err error
|
||||
|
||||
for _, f := range ff {
|
||||
out, err = DecodeHookExec(f, a, b)
|
||||
if err != nil {
|
||||
allErrs += err.Error() + "\n"
|
||||
continue
|
||||
}
|
||||
|
||||
return out, nil
|
||||
}
|
||||
|
||||
return nil, errors.New(allErrs)
|
||||
}
|
||||
}
|
||||
|
||||
// StringToSliceHookFunc returns a DecodeHookFunc that converts
|
||||
// string to []string by splitting on the given sep.
|
||||
func StringToSliceHookFunc(sep string) DecodeHookFunc {
|
||||
return func(
|
||||
f reflect.Kind,
|
||||
t reflect.Kind,
|
||||
data interface{}) (interface{}, error) {
|
||||
if f != reflect.String || t != reflect.Slice {
|
||||
return data, nil
|
||||
}
|
||||
|
||||
raw := data.(string)
|
||||
if raw == "" {
|
||||
return []string{}, nil
|
||||
}
|
||||
|
||||
return strings.Split(raw, sep), nil
|
||||
}
|
||||
}
|
||||
|
||||
// StringToTimeDurationHookFunc returns a DecodeHookFunc that converts
|
||||
// strings to time.Duration.
|
||||
func StringToTimeDurationHookFunc() DecodeHookFunc {
|
||||
return func(
|
||||
f reflect.Type,
|
||||
t reflect.Type,
|
||||
data interface{}) (interface{}, error) {
|
||||
if f.Kind() != reflect.String {
|
||||
return data, nil
|
||||
}
|
||||
if t != reflect.TypeOf(time.Duration(5)) {
|
||||
return data, nil
|
||||
}
|
||||
|
||||
// Convert it by parsing
|
||||
return time.ParseDuration(data.(string))
|
||||
}
|
||||
}
|
||||
|
||||
// StringToIPHookFunc returns a DecodeHookFunc that converts
|
||||
// strings to net.IP
|
||||
func StringToIPHookFunc() DecodeHookFunc {
|
||||
return func(
|
||||
f reflect.Type,
|
||||
t reflect.Type,
|
||||
data interface{}) (interface{}, error) {
|
||||
if f.Kind() != reflect.String {
|
||||
return data, nil
|
||||
}
|
||||
if t != reflect.TypeOf(net.IP{}) {
|
||||
return data, nil
|
||||
}
|
||||
|
||||
// Convert it by parsing
|
||||
ip := net.ParseIP(data.(string))
|
||||
if ip == nil {
|
||||
return net.IP{}, fmt.Errorf("failed parsing ip %v", data)
|
||||
}
|
||||
|
||||
return ip, nil
|
||||
}
|
||||
}
|
||||
|
||||
// StringToIPNetHookFunc returns a DecodeHookFunc that converts
|
||||
// strings to net.IPNet
|
||||
func StringToIPNetHookFunc() DecodeHookFunc {
|
||||
return func(
|
||||
f reflect.Type,
|
||||
t reflect.Type,
|
||||
data interface{}) (interface{}, error) {
|
||||
if f.Kind() != reflect.String {
|
||||
return data, nil
|
||||
}
|
||||
if t != reflect.TypeOf(net.IPNet{}) {
|
||||
return data, nil
|
||||
}
|
||||
|
||||
// Convert it by parsing
|
||||
_, net, err := net.ParseCIDR(data.(string))
|
||||
return net, err
|
||||
}
|
||||
}
|
||||
|
||||
// StringToTimeHookFunc returns a DecodeHookFunc that converts
|
||||
// strings to time.Time.
|
||||
func StringToTimeHookFunc(layout string) DecodeHookFunc {
|
||||
return func(
|
||||
f reflect.Type,
|
||||
t reflect.Type,
|
||||
data interface{}) (interface{}, error) {
|
||||
if f.Kind() != reflect.String {
|
||||
return data, nil
|
||||
}
|
||||
if t != reflect.TypeOf(time.Time{}) {
|
||||
return data, nil
|
||||
}
|
||||
|
||||
// Convert it by parsing
|
||||
return time.Parse(layout, data.(string))
|
||||
}
|
||||
}
|
||||
|
||||
// WeaklyTypedHook is a DecodeHookFunc which adds support for weak typing to
|
||||
// the decoder.
|
||||
//
|
||||
// Note that this is significantly different from the WeaklyTypedInput option
|
||||
// of the DecoderConfig.
|
||||
func WeaklyTypedHook(
|
||||
f reflect.Kind,
|
||||
t reflect.Kind,
|
||||
data interface{}) (interface{}, error) {
|
||||
dataVal := reflect.ValueOf(data)
|
||||
switch t {
|
||||
case reflect.String:
|
||||
switch f {
|
||||
case reflect.Bool:
|
||||
if dataVal.Bool() {
|
||||
return "1", nil
|
||||
}
|
||||
return "0", nil
|
||||
case reflect.Float32:
|
||||
return strconv.FormatFloat(dataVal.Float(), 'f', -1, 64), nil
|
||||
case reflect.Int:
|
||||
return strconv.FormatInt(dataVal.Int(), 10), nil
|
||||
case reflect.Slice:
|
||||
dataType := dataVal.Type()
|
||||
elemKind := dataType.Elem().Kind()
|
||||
if elemKind == reflect.Uint8 {
|
||||
return string(dataVal.Interface().([]uint8)), nil
|
||||
}
|
||||
case reflect.Uint:
|
||||
return strconv.FormatUint(dataVal.Uint(), 10), nil
|
||||
}
|
||||
}
|
||||
|
||||
return data, nil
|
||||
}
|
||||
|
||||
func RecursiveStructToMapHookFunc() DecodeHookFunc {
|
||||
return func(f reflect.Value, t reflect.Value) (interface{}, error) {
|
||||
if f.Kind() != reflect.Struct {
|
||||
return f.Interface(), nil
|
||||
}
|
||||
|
||||
var i interface{} = struct{}{}
|
||||
if t.Type() != reflect.TypeOf(&i).Elem() {
|
||||
return f.Interface(), nil
|
||||
}
|
||||
|
||||
m := make(map[string]interface{})
|
||||
t.Set(reflect.ValueOf(m))
|
||||
|
||||
return f.Interface(), nil
|
||||
}
|
||||
}
|
||||
|
||||
// TextUnmarshallerHookFunc returns a DecodeHookFunc that applies
|
||||
// strings to the UnmarshalText function, when the target type
|
||||
// implements the encoding.TextUnmarshaler interface
|
||||
func TextUnmarshallerHookFunc() DecodeHookFuncType {
|
||||
return func(
|
||||
f reflect.Type,
|
||||
t reflect.Type,
|
||||
data interface{}) (interface{}, error) {
|
||||
if f.Kind() != reflect.String {
|
||||
return data, nil
|
||||
}
|
||||
result := reflect.New(t).Interface()
|
||||
unmarshaller, ok := result.(encoding.TextUnmarshaler)
|
||||
if !ok {
|
||||
return data, nil
|
||||
}
|
||||
if err := unmarshaller.UnmarshalText([]byte(data.(string))); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,567 @@
|
|||
package mapstructure
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"math/big"
|
||||
"net"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestComposeDecodeHookFunc(t *testing.T) {
|
||||
f1 := func(
|
||||
f reflect.Kind,
|
||||
t reflect.Kind,
|
||||
data interface{}) (interface{}, error) {
|
||||
return data.(string) + "foo", nil
|
||||
}
|
||||
|
||||
f2 := func(
|
||||
f reflect.Kind,
|
||||
t reflect.Kind,
|
||||
data interface{}) (interface{}, error) {
|
||||
return data.(string) + "bar", nil
|
||||
}
|
||||
|
||||
f := ComposeDecodeHookFunc(f1, f2)
|
||||
|
||||
result, err := DecodeHookExec(
|
||||
f, reflect.ValueOf(""), reflect.ValueOf([]byte("")))
|
||||
if err != nil {
|
||||
t.Fatalf("bad: %s", err)
|
||||
}
|
||||
if result.(string) != "foobar" {
|
||||
t.Fatalf("bad: %#v", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestComposeDecodeHookFunc_err(t *testing.T) {
|
||||
f1 := func(reflect.Kind, reflect.Kind, interface{}) (interface{}, error) {
|
||||
return nil, errors.New("foo")
|
||||
}
|
||||
|
||||
f2 := func(reflect.Kind, reflect.Kind, interface{}) (interface{}, error) {
|
||||
panic("NOPE")
|
||||
}
|
||||
|
||||
f := ComposeDecodeHookFunc(f1, f2)
|
||||
|
||||
_, err := DecodeHookExec(
|
||||
f, reflect.ValueOf(""), reflect.ValueOf([]byte("")))
|
||||
if err.Error() != "foo" {
|
||||
t.Fatalf("bad: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestComposeDecodeHookFunc_kinds(t *testing.T) {
|
||||
var f2From reflect.Kind
|
||||
|
||||
f1 := func(
|
||||
f reflect.Kind,
|
||||
t reflect.Kind,
|
||||
data interface{}) (interface{}, error) {
|
||||
return int(42), nil
|
||||
}
|
||||
|
||||
f2 := func(
|
||||
f reflect.Kind,
|
||||
t reflect.Kind,
|
||||
data interface{}) (interface{}, error) {
|
||||
f2From = f
|
||||
return data, nil
|
||||
}
|
||||
|
||||
f := ComposeDecodeHookFunc(f1, f2)
|
||||
|
||||
_, err := DecodeHookExec(
|
||||
f, reflect.ValueOf(""), reflect.ValueOf([]byte("")))
|
||||
if err != nil {
|
||||
t.Fatalf("bad: %s", err)
|
||||
}
|
||||
if f2From != reflect.Int {
|
||||
t.Fatalf("bad: %#v", f2From)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOrComposeDecodeHookFunc(t *testing.T) {
|
||||
f1 := func(
|
||||
f reflect.Kind,
|
||||
t reflect.Kind,
|
||||
data interface{}) (interface{}, error) {
|
||||
return data.(string) + "foo", nil
|
||||
}
|
||||
|
||||
f2 := func(
|
||||
f reflect.Kind,
|
||||
t reflect.Kind,
|
||||
data interface{}) (interface{}, error) {
|
||||
return data.(string) + "bar", nil
|
||||
}
|
||||
|
||||
f := OrComposeDecodeHookFunc(f1, f2)
|
||||
|
||||
result, err := DecodeHookExec(
|
||||
f, reflect.ValueOf(""), reflect.ValueOf([]byte("")))
|
||||
if err != nil {
|
||||
t.Fatalf("bad: %s", err)
|
||||
}
|
||||
if result.(string) != "foo" {
|
||||
t.Fatalf("bad: %#v", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOrComposeDecodeHookFunc_correctValueIsLast(t *testing.T) {
|
||||
f1 := func(
|
||||
f reflect.Kind,
|
||||
t reflect.Kind,
|
||||
data interface{}) (interface{}, error) {
|
||||
return nil, errors.New("f1 error")
|
||||
}
|
||||
|
||||
f2 := func(
|
||||
f reflect.Kind,
|
||||
t reflect.Kind,
|
||||
data interface{}) (interface{}, error) {
|
||||
return nil, errors.New("f2 error")
|
||||
}
|
||||
|
||||
f3 := func(
|
||||
f reflect.Kind,
|
||||
t reflect.Kind,
|
||||
data interface{}) (interface{}, error) {
|
||||
return data.(string) + "bar", nil
|
||||
}
|
||||
|
||||
f := OrComposeDecodeHookFunc(f1, f2, f3)
|
||||
|
||||
result, err := DecodeHookExec(
|
||||
f, reflect.ValueOf(""), reflect.ValueOf([]byte("")))
|
||||
if err != nil {
|
||||
t.Fatalf("bad: %s", err)
|
||||
}
|
||||
if result.(string) != "bar" {
|
||||
t.Fatalf("bad: %#v", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOrComposeDecodeHookFunc_err(t *testing.T) {
|
||||
f1 := func(
|
||||
f reflect.Kind,
|
||||
t reflect.Kind,
|
||||
data interface{}) (interface{}, error) {
|
||||
return nil, errors.New("f1 error")
|
||||
}
|
||||
|
||||
f2 := func(
|
||||
f reflect.Kind,
|
||||
t reflect.Kind,
|
||||
data interface{}) (interface{}, error) {
|
||||
return nil, errors.New("f2 error")
|
||||
}
|
||||
|
||||
f := OrComposeDecodeHookFunc(f1, f2)
|
||||
|
||||
_, err := DecodeHookExec(
|
||||
f, reflect.ValueOf(""), reflect.ValueOf([]byte("")))
|
||||
if err == nil {
|
||||
t.Fatalf("bad: should return an error")
|
||||
}
|
||||
if err.Error() != "f1 error\nf2 error\n" {
|
||||
t.Fatalf("bad: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestComposeDecodeHookFunc_safe_nofuncs(t *testing.T) {
|
||||
f := ComposeDecodeHookFunc()
|
||||
type myStruct2 struct {
|
||||
MyInt int
|
||||
}
|
||||
|
||||
type myStruct1 struct {
|
||||
Blah map[string]myStruct2
|
||||
}
|
||||
|
||||
src := &myStruct1{Blah: map[string]myStruct2{
|
||||
"test": {
|
||||
MyInt: 1,
|
||||
},
|
||||
}}
|
||||
|
||||
dst := &myStruct1{}
|
||||
dConf := &DecoderConfig{
|
||||
Result: dst,
|
||||
ErrorUnused: true,
|
||||
DecodeHook: f,
|
||||
}
|
||||
d, err := NewDecoder(dConf)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
err = d.Decode(src)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStringToSliceHookFunc(t *testing.T) {
|
||||
f := StringToSliceHookFunc(",")
|
||||
|
||||
strValue := reflect.ValueOf("42")
|
||||
sliceValue := reflect.ValueOf([]byte("42"))
|
||||
cases := []struct {
|
||||
f, t reflect.Value
|
||||
result interface{}
|
||||
err bool
|
||||
}{
|
||||
{sliceValue, sliceValue, []byte("42"), false},
|
||||
{strValue, strValue, "42", false},
|
||||
{
|
||||
reflect.ValueOf("foo,bar,baz"),
|
||||
sliceValue,
|
||||
[]string{"foo", "bar", "baz"},
|
||||
false,
|
||||
},
|
||||
{
|
||||
reflect.ValueOf(""),
|
||||
sliceValue,
|
||||
[]string{},
|
||||
false,
|
||||
},
|
||||
}
|
||||
|
||||
for i, tc := range cases {
|
||||
actual, err := DecodeHookExec(f, tc.f, tc.t)
|
||||
if tc.err != (err != nil) {
|
||||
t.Fatalf("case %d: expected jderr %#v", i, tc.err)
|
||||
}
|
||||
if !reflect.DeepEqual(actual, tc.result) {
|
||||
t.Fatalf(
|
||||
"case %d: expected %#v, got %#v",
|
||||
i, tc.result, actual)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestStringToTimeDurationHookFunc(t *testing.T) {
|
||||
f := StringToTimeDurationHookFunc()
|
||||
|
||||
timeValue := reflect.ValueOf(time.Duration(5))
|
||||
strValue := reflect.ValueOf("")
|
||||
cases := []struct {
|
||||
f, t reflect.Value
|
||||
result interface{}
|
||||
err bool
|
||||
}{
|
||||
{reflect.ValueOf("5s"), timeValue, 5 * time.Second, false},
|
||||
{reflect.ValueOf("5"), timeValue, time.Duration(0), true},
|
||||
{reflect.ValueOf("5"), strValue, "5", false},
|
||||
}
|
||||
|
||||
for i, tc := range cases {
|
||||
actual, err := DecodeHookExec(f, tc.f, tc.t)
|
||||
if tc.err != (err != nil) {
|
||||
t.Fatalf("case %d: expected jderr %#v", i, tc.err)
|
||||
}
|
||||
if !reflect.DeepEqual(actual, tc.result) {
|
||||
t.Fatalf(
|
||||
"case %d: expected %#v, got %#v",
|
||||
i, tc.result, actual)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestStringToTimeHookFunc(t *testing.T) {
|
||||
strValue := reflect.ValueOf("5")
|
||||
timeValue := reflect.ValueOf(time.Time{})
|
||||
cases := []struct {
|
||||
f, t reflect.Value
|
||||
layout string
|
||||
result interface{}
|
||||
err bool
|
||||
}{
|
||||
{reflect.ValueOf("2006-01-02T15:04:05Z"), timeValue, time.RFC3339,
|
||||
time.Date(2006, 1, 2, 15, 4, 5, 0, time.UTC), false},
|
||||
{strValue, timeValue, time.RFC3339, time.Time{}, true},
|
||||
{strValue, strValue, time.RFC3339, "5", false},
|
||||
}
|
||||
|
||||
for i, tc := range cases {
|
||||
f := StringToTimeHookFunc(tc.layout)
|
||||
actual, err := DecodeHookExec(f, tc.f, tc.t)
|
||||
if tc.err != (err != nil) {
|
||||
t.Fatalf("case %d: expected jderr %#v", i, tc.err)
|
||||
}
|
||||
if !reflect.DeepEqual(actual, tc.result) {
|
||||
t.Fatalf(
|
||||
"case %d: expected %#v, got %#v",
|
||||
i, tc.result, actual)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestStringToIPHookFunc(t *testing.T) {
|
||||
strValue := reflect.ValueOf("5")
|
||||
ipValue := reflect.ValueOf(net.IP{})
|
||||
cases := []struct {
|
||||
f, t reflect.Value
|
||||
result interface{}
|
||||
err bool
|
||||
}{
|
||||
{reflect.ValueOf("1.2.3.4"), ipValue,
|
||||
net.IPv4(0x01, 0x02, 0x03, 0x04), false},
|
||||
{strValue, ipValue, net.IP{}, true},
|
||||
{strValue, strValue, "5", false},
|
||||
}
|
||||
|
||||
for i, tc := range cases {
|
||||
f := StringToIPHookFunc()
|
||||
actual, err := DecodeHookExec(f, tc.f, tc.t)
|
||||
if tc.err != (err != nil) {
|
||||
t.Fatalf("case %d: expected jderr %#v", i, tc.err)
|
||||
}
|
||||
if !reflect.DeepEqual(actual, tc.result) {
|
||||
t.Fatalf(
|
||||
"case %d: expected %#v, got %#v",
|
||||
i, tc.result, actual)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestStringToIPNetHookFunc(t *testing.T) {
|
||||
strValue := reflect.ValueOf("5")
|
||||
ipNetValue := reflect.ValueOf(net.IPNet{})
|
||||
var nilNet *net.IPNet = nil
|
||||
|
||||
cases := []struct {
|
||||
f, t reflect.Value
|
||||
result interface{}
|
||||
err bool
|
||||
}{
|
||||
{reflect.ValueOf("1.2.3.4/24"), ipNetValue,
|
||||
&net.IPNet{
|
||||
IP: net.IP{0x01, 0x02, 0x03, 0x00},
|
||||
Mask: net.IPv4Mask(0xff, 0xff, 0xff, 0x00),
|
||||
}, false},
|
||||
{strValue, ipNetValue, nilNet, true},
|
||||
{strValue, strValue, "5", false},
|
||||
}
|
||||
|
||||
for i, tc := range cases {
|
||||
f := StringToIPNetHookFunc()
|
||||
actual, err := DecodeHookExec(f, tc.f, tc.t)
|
||||
if tc.err != (err != nil) {
|
||||
t.Fatalf("case %d: expected jderr %#v", i, tc.err)
|
||||
}
|
||||
if !reflect.DeepEqual(actual, tc.result) {
|
||||
t.Fatalf(
|
||||
"case %d: expected %#v, got %#v",
|
||||
i, tc.result, actual)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestWeaklyTypedHook(t *testing.T) {
|
||||
var f DecodeHookFunc = WeaklyTypedHook
|
||||
|
||||
strValue := reflect.ValueOf("")
|
||||
cases := []struct {
|
||||
f, t reflect.Value
|
||||
result interface{}
|
||||
err bool
|
||||
}{
|
||||
// TO STRING
|
||||
{
|
||||
reflect.ValueOf(false),
|
||||
strValue,
|
||||
"0",
|
||||
false,
|
||||
},
|
||||
|
||||
{
|
||||
reflect.ValueOf(true),
|
||||
strValue,
|
||||
"1",
|
||||
false,
|
||||
},
|
||||
|
||||
{
|
||||
reflect.ValueOf(float32(7)),
|
||||
strValue,
|
||||
"7",
|
||||
false,
|
||||
},
|
||||
|
||||
{
|
||||
reflect.ValueOf(int(7)),
|
||||
strValue,
|
||||
"7",
|
||||
false,
|
||||
},
|
||||
|
||||
{
|
||||
reflect.ValueOf([]uint8("foo")),
|
||||
strValue,
|
||||
"foo",
|
||||
false,
|
||||
},
|
||||
|
||||
{
|
||||
reflect.ValueOf(uint(7)),
|
||||
strValue,
|
||||
"7",
|
||||
false,
|
||||
},
|
||||
}
|
||||
|
||||
for i, tc := range cases {
|
||||
actual, err := DecodeHookExec(f, tc.f, tc.t)
|
||||
if tc.err != (err != nil) {
|
||||
t.Fatalf("case %d: expected jderr %#v", i, tc.err)
|
||||
}
|
||||
if !reflect.DeepEqual(actual, tc.result) {
|
||||
t.Fatalf(
|
||||
"case %d: expected %#v, got %#v",
|
||||
i, tc.result, actual)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestStructToMapHookFuncTabled(t *testing.T) {
|
||||
var f DecodeHookFunc = RecursiveStructToMapHookFunc()
|
||||
|
||||
type b struct {
|
||||
TestKey string
|
||||
}
|
||||
|
||||
type a struct {
|
||||
Sub b
|
||||
}
|
||||
|
||||
testStruct := a{
|
||||
Sub: b{
|
||||
TestKey: "testval",
|
||||
},
|
||||
}
|
||||
|
||||
testMap := map[string]interface{}{
|
||||
"Sub": map[string]interface{}{
|
||||
"TestKey": "testval",
|
||||
},
|
||||
}
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
receiver interface{}
|
||||
input interface{}
|
||||
expected interface{}
|
||||
err bool
|
||||
}{
|
||||
{
|
||||
"map receiver",
|
||||
func() interface{} {
|
||||
var res map[string]interface{}
|
||||
return &res
|
||||
}(),
|
||||
testStruct,
|
||||
&testMap,
|
||||
false,
|
||||
},
|
||||
{
|
||||
"interface receiver",
|
||||
func() interface{} {
|
||||
var res interface{}
|
||||
return &res
|
||||
}(),
|
||||
testStruct,
|
||||
func() interface{} {
|
||||
var exp interface{} = testMap
|
||||
return &exp
|
||||
}(),
|
||||
false,
|
||||
},
|
||||
{
|
||||
"slice receiver errors",
|
||||
func() interface{} {
|
||||
var res []string
|
||||
return &res
|
||||
}(),
|
||||
testStruct,
|
||||
new([]string),
|
||||
true,
|
||||
},
|
||||
{
|
||||
"slice to slice - no change",
|
||||
func() interface{} {
|
||||
var res []string
|
||||
return &res
|
||||
}(),
|
||||
[]string{"a", "b"},
|
||||
&[]string{"a", "b"},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"string to string - no change",
|
||||
func() interface{} {
|
||||
var res string
|
||||
return &res
|
||||
}(),
|
||||
"test",
|
||||
func() *string {
|
||||
s := "test"
|
||||
return &s
|
||||
}(),
|
||||
false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
cfg := &DecoderConfig{
|
||||
DecodeHook: f,
|
||||
Result: tc.receiver,
|
||||
}
|
||||
|
||||
d, err := NewDecoder(cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected jderr %#v", err)
|
||||
}
|
||||
|
||||
err = d.Decode(tc.input)
|
||||
if tc.err != (err != nil) {
|
||||
t.Fatalf("expected jderr %#v", err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(tc.expected, tc.receiver) {
|
||||
t.Fatalf("expected %#v, got %#v",
|
||||
tc.expected, tc.receiver)
|
||||
}
|
||||
})
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
func TestTextUnmarshallerHookFunc(t *testing.T) {
|
||||
cases := []struct {
|
||||
f, t reflect.Value
|
||||
result interface{}
|
||||
err bool
|
||||
}{
|
||||
{reflect.ValueOf("42"), reflect.ValueOf(big.Int{}), big.NewInt(42), false},
|
||||
{reflect.ValueOf("invalid"), reflect.ValueOf(big.Int{}), nil, true},
|
||||
{reflect.ValueOf("5"), reflect.ValueOf("5"), "5", false},
|
||||
}
|
||||
|
||||
for i, tc := range cases {
|
||||
f := TextUnmarshallerHookFunc()
|
||||
actual, err := DecodeHookExec(f, tc.f, tc.t)
|
||||
if tc.err != (err != nil) {
|
||||
t.Fatalf("case %d: expected jderr %#v", i, tc.err)
|
||||
}
|
||||
if !reflect.DeepEqual(actual, tc.result) {
|
||||
t.Fatalf(
|
||||
"case %d: expected %#v, got %#v",
|
||||
i, tc.result, actual)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,50 @@
|
|||
package mapstructure
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Error implements the error interface and can represents multiple
|
||||
// errors that occur in the course of a single decode.
|
||||
type Error struct {
|
||||
Errors []string
|
||||
}
|
||||
|
||||
func (e *Error) Error() string {
|
||||
points := make([]string, len(e.Errors))
|
||||
for i, err := range e.Errors {
|
||||
points[i] = fmt.Sprintf("* %s", err)
|
||||
}
|
||||
|
||||
sort.Strings(points)
|
||||
return fmt.Sprintf(
|
||||
"%d error(s) decoding:\n\n%s",
|
||||
len(e.Errors), strings.Join(points, "\n"))
|
||||
}
|
||||
|
||||
// WrappedErrors implements the errwrap.Wrapper interface to make this
|
||||
// return value more useful with the errwrap and go-multierror libraries.
|
||||
func (e *Error) WrappedErrors() []error {
|
||||
if e == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
result := make([]error, len(e.Errors))
|
||||
for i, e := range e.Errors {
|
||||
result[i] = errors.New(e)
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
func appendErrors(errors []string, err error) []string {
|
||||
switch e := err.(type) {
|
||||
case *Error:
|
||||
return append(errors, e.Errors...)
|
||||
default:
|
||||
return append(errors, e.Error())
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
|
|
@ -0,0 +1,285 @@
|
|||
package mapstructure
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
)
|
||||
|
||||
type Person struct {
|
||||
Name string
|
||||
Age int
|
||||
Emails []string
|
||||
Extra map[string]string
|
||||
}
|
||||
|
||||
func Benchmark_Decode(b *testing.B) {
|
||||
input := map[string]interface{}{
|
||||
"name": "Mitchell",
|
||||
"age": 91,
|
||||
"emails": []string{"one", "two", "three"},
|
||||
"extra": map[string]string{
|
||||
"twitter": "mitchellh",
|
||||
},
|
||||
}
|
||||
|
||||
var result Person
|
||||
for i := 0; i < b.N; i++ {
|
||||
Decode(input, &result)
|
||||
}
|
||||
}
|
||||
|
||||
// decodeViaJSON takes the map data and passes it through encoding/json to convert it into the
|
||||
// given Go native structure pointed to by v. v must be a pointer to a struct.
|
||||
func decodeViaJSON(data interface{}, v interface{}) error {
|
||||
// Perform the task by simply marshalling the input into JSON,
|
||||
// then unmarshalling it into target native Go struct.
|
||||
b, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return json.Unmarshal(b, v)
|
||||
}
|
||||
|
||||
func Benchmark_DecodeViaJSON(b *testing.B) {
|
||||
input := map[string]interface{}{
|
||||
"name": "Mitchell",
|
||||
"age": 91,
|
||||
"emails": []string{"one", "two", "three"},
|
||||
"extra": map[string]string{
|
||||
"twitter": "mitchellh",
|
||||
},
|
||||
}
|
||||
|
||||
var result Person
|
||||
for i := 0; i < b.N; i++ {
|
||||
decodeViaJSON(input, &result)
|
||||
}
|
||||
}
|
||||
|
||||
func Benchmark_JSONUnmarshal(b *testing.B) {
|
||||
input := map[string]interface{}{
|
||||
"name": "Mitchell",
|
||||
"age": 91,
|
||||
"emails": []string{"one", "two", "three"},
|
||||
"extra": map[string]string{
|
||||
"twitter": "mitchellh",
|
||||
},
|
||||
}
|
||||
|
||||
inputB, err := json.Marshal(input)
|
||||
if err != nil {
|
||||
b.Fatal("Failed to marshal test input:", err)
|
||||
}
|
||||
|
||||
var result Person
|
||||
for i := 0; i < b.N; i++ {
|
||||
json.Unmarshal(inputB, &result)
|
||||
}
|
||||
}
|
||||
|
||||
func Benchmark_DecodeBasic(b *testing.B) {
|
||||
input := map[string]interface{}{
|
||||
"vstring": "foo",
|
||||
"vint": 42,
|
||||
"Vuint": 42,
|
||||
"vbool": true,
|
||||
"Vfloat": 42.42,
|
||||
"vsilent": true,
|
||||
"vdata": 42,
|
||||
"vjsonInt": json.Number("1234"),
|
||||
"vjsonFloat": json.Number("1234.5"),
|
||||
"vjsonNumber": json.Number("1234.5"),
|
||||
}
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
var result Basic
|
||||
Decode(input, &result)
|
||||
}
|
||||
}
|
||||
|
||||
func Benchmark_DecodeEmbedded(b *testing.B) {
|
||||
input := map[string]interface{}{
|
||||
"vstring": "foo",
|
||||
"Basic": map[string]interface{}{
|
||||
"vstring": "innerfoo",
|
||||
},
|
||||
"vunique": "bar",
|
||||
}
|
||||
|
||||
var result Embedded
|
||||
for i := 0; i < b.N; i++ {
|
||||
Decode(input, &result)
|
||||
}
|
||||
}
|
||||
|
||||
func Benchmark_DecodeTypeConversion(b *testing.B) {
|
||||
input := map[string]interface{}{
|
||||
"IntToFloat": 42,
|
||||
"IntToUint": 42,
|
||||
"IntToBool": 1,
|
||||
"IntToString": 42,
|
||||
"UintToInt": 42,
|
||||
"UintToFloat": 42,
|
||||
"UintToBool": 42,
|
||||
"UintToString": 42,
|
||||
"BoolToInt": true,
|
||||
"BoolToUint": true,
|
||||
"BoolToFloat": true,
|
||||
"BoolToString": true,
|
||||
"FloatToInt": 42.42,
|
||||
"FloatToUint": 42.42,
|
||||
"FloatToBool": 42.42,
|
||||
"FloatToString": 42.42,
|
||||
"StringToInt": "42",
|
||||
"StringToUint": "42",
|
||||
"StringToBool": "1",
|
||||
"StringToFloat": "42.42",
|
||||
"SliceToMap": []interface{}{},
|
||||
"MapToSlice": map[string]interface{}{},
|
||||
}
|
||||
|
||||
var resultStrict TypeConversionResult
|
||||
for i := 0; i < b.N; i++ {
|
||||
Decode(input, &resultStrict)
|
||||
}
|
||||
}
|
||||
|
||||
func Benchmark_DecodeMap(b *testing.B) {
|
||||
input := map[string]interface{}{
|
||||
"vfoo": "foo",
|
||||
"vother": map[interface{}]interface{}{
|
||||
"foo": "foo",
|
||||
"bar": "bar",
|
||||
},
|
||||
}
|
||||
|
||||
var result Map
|
||||
for i := 0; i < b.N; i++ {
|
||||
Decode(input, &result)
|
||||
}
|
||||
}
|
||||
|
||||
func Benchmark_DecodeMapOfStruct(b *testing.B) {
|
||||
input := map[string]interface{}{
|
||||
"value": map[string]interface{}{
|
||||
"foo": map[string]string{"vstring": "one"},
|
||||
"bar": map[string]string{"vstring": "two"},
|
||||
},
|
||||
}
|
||||
|
||||
var result MapOfStruct
|
||||
for i := 0; i < b.N; i++ {
|
||||
Decode(input, &result)
|
||||
}
|
||||
}
|
||||
|
||||
func Benchmark_DecodeSlice(b *testing.B) {
|
||||
input := map[string]interface{}{
|
||||
"vfoo": "foo",
|
||||
"vbar": []string{"foo", "bar", "baz"},
|
||||
}
|
||||
|
||||
var result Slice
|
||||
for i := 0; i < b.N; i++ {
|
||||
Decode(input, &result)
|
||||
}
|
||||
}
|
||||
|
||||
func Benchmark_DecodeSliceOfStruct(b *testing.B) {
|
||||
input := map[string]interface{}{
|
||||
"value": []map[string]interface{}{
|
||||
{"vstring": "one"},
|
||||
{"vstring": "two"},
|
||||
},
|
||||
}
|
||||
|
||||
var result SliceOfStruct
|
||||
for i := 0; i < b.N; i++ {
|
||||
Decode(input, &result)
|
||||
}
|
||||
}
|
||||
|
||||
func Benchmark_DecodeWeaklyTypedInput(b *testing.B) {
|
||||
// This input can come from anywhere, but typically comes from
|
||||
// something like decoding JSON, generated by a weakly typed language
|
||||
// such as PHP.
|
||||
input := map[string]interface{}{
|
||||
"name": 123, // number => string
|
||||
"age": "42", // string => number
|
||||
"emails": map[string]interface{}{}, // empty map => empty array
|
||||
}
|
||||
|
||||
var result Person
|
||||
config := &DecoderConfig{
|
||||
WeaklyTypedInput: true,
|
||||
Result: &result,
|
||||
}
|
||||
|
||||
decoder, err := NewDecoder(config)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
decoder.Decode(input)
|
||||
}
|
||||
}
|
||||
|
||||
func Benchmark_DecodeMetadata(b *testing.B) {
|
||||
input := map[string]interface{}{
|
||||
"name": "Mitchell",
|
||||
"age": 91,
|
||||
"email": "foo@bar.com",
|
||||
}
|
||||
|
||||
var md Metadata
|
||||
var result Person
|
||||
config := &DecoderConfig{
|
||||
Metadata: &md,
|
||||
Result: &result,
|
||||
}
|
||||
|
||||
decoder, err := NewDecoder(config)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
decoder.Decode(input)
|
||||
}
|
||||
}
|
||||
|
||||
func Benchmark_DecodeMetadataEmbedded(b *testing.B) {
|
||||
input := map[string]interface{}{
|
||||
"vstring": "foo",
|
||||
"vunique": "bar",
|
||||
}
|
||||
|
||||
var md Metadata
|
||||
var result EmbeddedSquash
|
||||
config := &DecoderConfig{
|
||||
Metadata: &md,
|
||||
Result: &result,
|
||||
}
|
||||
|
||||
decoder, err := NewDecoder(config)
|
||||
if err != nil {
|
||||
b.Fatalf("jderr: %s", err)
|
||||
}
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
decoder.Decode(input)
|
||||
}
|
||||
}
|
||||
|
||||
func Benchmark_DecodeTagged(b *testing.B) {
|
||||
input := map[string]interface{}{
|
||||
"foo": "bar",
|
||||
"bar": "value",
|
||||
}
|
||||
|
||||
var result Tagged
|
||||
for i := 0; i < b.N; i++ {
|
||||
Decode(input, &result)
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,627 @@
|
|||
package mapstructure
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// GH-1, GH-10, GH-96
|
||||
func TestDecode_NilValue(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
in interface{}
|
||||
target interface{}
|
||||
out interface{}
|
||||
metaKeys []string
|
||||
metaUnused []string
|
||||
}{
|
||||
{
|
||||
"all nil",
|
||||
&map[string]interface{}{
|
||||
"vfoo": nil,
|
||||
"vother": nil,
|
||||
},
|
||||
&Map{Vfoo: "foo", Vother: map[string]string{"foo": "bar"}},
|
||||
&Map{Vfoo: "", Vother: nil},
|
||||
[]string{"Vfoo", "Vother"},
|
||||
[]string{},
|
||||
},
|
||||
{
|
||||
"partial nil",
|
||||
&map[string]interface{}{
|
||||
"vfoo": "baz",
|
||||
"vother": nil,
|
||||
},
|
||||
&Map{Vfoo: "foo", Vother: map[string]string{"foo": "bar"}},
|
||||
&Map{Vfoo: "baz", Vother: nil},
|
||||
[]string{"Vfoo", "Vother"},
|
||||
[]string{},
|
||||
},
|
||||
{
|
||||
"partial decode",
|
||||
&map[string]interface{}{
|
||||
"vother": nil,
|
||||
},
|
||||
&Map{Vfoo: "foo", Vother: map[string]string{"foo": "bar"}},
|
||||
&Map{Vfoo: "foo", Vother: nil},
|
||||
[]string{"Vother"},
|
||||
[]string{},
|
||||
},
|
||||
{
|
||||
"unused values",
|
||||
&map[string]interface{}{
|
||||
"vbar": "bar",
|
||||
"vfoo": nil,
|
||||
"vother": nil,
|
||||
},
|
||||
&Map{Vfoo: "foo", Vother: map[string]string{"foo": "bar"}},
|
||||
&Map{Vfoo: "", Vother: nil},
|
||||
[]string{"Vfoo", "Vother"},
|
||||
[]string{"vbar"},
|
||||
},
|
||||
{
|
||||
"map interface all nil",
|
||||
&map[interface{}]interface{}{
|
||||
"vfoo": nil,
|
||||
"vother": nil,
|
||||
},
|
||||
&Map{Vfoo: "foo", Vother: map[string]string{"foo": "bar"}},
|
||||
&Map{Vfoo: "", Vother: nil},
|
||||
[]string{"Vfoo", "Vother"},
|
||||
[]string{},
|
||||
},
|
||||
{
|
||||
"map interface partial nil",
|
||||
&map[interface{}]interface{}{
|
||||
"vfoo": "baz",
|
||||
"vother": nil,
|
||||
},
|
||||
&Map{Vfoo: "foo", Vother: map[string]string{"foo": "bar"}},
|
||||
&Map{Vfoo: "baz", Vother: nil},
|
||||
[]string{"Vfoo", "Vother"},
|
||||
[]string{},
|
||||
},
|
||||
{
|
||||
"map interface partial decode",
|
||||
&map[interface{}]interface{}{
|
||||
"vother": nil,
|
||||
},
|
||||
&Map{Vfoo: "foo", Vother: map[string]string{"foo": "bar"}},
|
||||
&Map{Vfoo: "foo", Vother: nil},
|
||||
[]string{"Vother"},
|
||||
[]string{},
|
||||
},
|
||||
{
|
||||
"map interface unused values",
|
||||
&map[interface{}]interface{}{
|
||||
"vbar": "bar",
|
||||
"vfoo": nil,
|
||||
"vother": nil,
|
||||
},
|
||||
&Map{Vfoo: "foo", Vother: map[string]string{"foo": "bar"}},
|
||||
&Map{Vfoo: "", Vother: nil},
|
||||
[]string{"Vfoo", "Vother"},
|
||||
[]string{"vbar"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
config := &DecoderConfig{
|
||||
Metadata: new(Metadata),
|
||||
Result: tc.target,
|
||||
ZeroFields: true,
|
||||
}
|
||||
|
||||
decoder, err := NewDecoder(config)
|
||||
if err != nil {
|
||||
t.Fatalf("should not error: %s", err)
|
||||
}
|
||||
|
||||
err = decoder.Decode(tc.in)
|
||||
if err != nil {
|
||||
t.Fatalf("should not error: %s", err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(tc.out, tc.target) {
|
||||
t.Fatalf("%q: TestDecode_NilValue() expected: %#v, got: %#v", tc.name, tc.out, tc.target)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(tc.metaKeys, config.Metadata.Keys) {
|
||||
t.Fatalf("%q: Metadata.Keys mismatch expected: %#v, got: %#v", tc.name, tc.metaKeys, config.Metadata.Keys)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(tc.metaUnused, config.Metadata.Unused) {
|
||||
t.Fatalf("%q: Metadata.Unused mismatch expected: %#v, got: %#v", tc.name, tc.metaUnused, config.Metadata.Unused)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// #48
|
||||
func TestNestedTypePointerWithDefaults(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
input := map[string]interface{}{
|
||||
"vfoo": "foo",
|
||||
"vbar": map[string]interface{}{
|
||||
"vstring": "foo",
|
||||
"vint": 42,
|
||||
"vbool": true,
|
||||
},
|
||||
}
|
||||
|
||||
result := NestedPointer{
|
||||
Vbar: &Basic{
|
||||
Vuint: 42,
|
||||
},
|
||||
}
|
||||
err := Decode(input, &result)
|
||||
if err != nil {
|
||||
t.Fatalf("got an jderr: %s", err.Error())
|
||||
}
|
||||
|
||||
if result.Vfoo != "foo" {
|
||||
t.Errorf("vfoo value should be 'foo': %#v", result.Vfoo)
|
||||
}
|
||||
|
||||
if result.Vbar.Vstring != "foo" {
|
||||
t.Errorf("vstring value should be 'foo': %#v", result.Vbar.Vstring)
|
||||
}
|
||||
|
||||
if result.Vbar.Vint != 42 {
|
||||
t.Errorf("vint value should be 42: %#v", result.Vbar.Vint)
|
||||
}
|
||||
|
||||
if result.Vbar.Vbool != true {
|
||||
t.Errorf("vbool value should be true: %#v", result.Vbar.Vbool)
|
||||
}
|
||||
|
||||
if result.Vbar.Vextra != "" {
|
||||
t.Errorf("vextra value should be empty: %#v", result.Vbar.Vextra)
|
||||
}
|
||||
|
||||
// this is the error
|
||||
if result.Vbar.Vuint != 42 {
|
||||
t.Errorf("vuint value should be 42: %#v", result.Vbar.Vuint)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
type NestedSlice struct {
|
||||
Vfoo string
|
||||
Vbars []Basic
|
||||
Vempty []Basic
|
||||
}
|
||||
|
||||
// #48
|
||||
func TestNestedTypeSliceWithDefaults(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
input := map[string]interface{}{
|
||||
"vfoo": "foo",
|
||||
"vbars": []map[string]interface{}{
|
||||
{"vstring": "foo", "vint": 42, "vbool": true},
|
||||
{"vint": 42, "vbool": true},
|
||||
},
|
||||
"vempty": []map[string]interface{}{
|
||||
{"vstring": "foo", "vint": 42, "vbool": true},
|
||||
{"vint": 42, "vbool": true},
|
||||
},
|
||||
}
|
||||
|
||||
result := NestedSlice{
|
||||
Vbars: []Basic{
|
||||
{Vuint: 42},
|
||||
{Vstring: "foo"},
|
||||
},
|
||||
}
|
||||
err := Decode(input, &result)
|
||||
if err != nil {
|
||||
t.Fatalf("got an jderr: %s", err.Error())
|
||||
}
|
||||
|
||||
if result.Vfoo != "foo" {
|
||||
t.Errorf("vfoo value should be 'foo': %#v", result.Vfoo)
|
||||
}
|
||||
|
||||
if result.Vbars[0].Vstring != "foo" {
|
||||
t.Errorf("vstring value should be 'foo': %#v", result.Vbars[0].Vstring)
|
||||
}
|
||||
// this is the error
|
||||
if result.Vbars[0].Vuint != 42 {
|
||||
t.Errorf("vuint value should be 42: %#v", result.Vbars[0].Vuint)
|
||||
}
|
||||
}
|
||||
|
||||
// #48 workaround
|
||||
func TestNestedTypeWithDefaults(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
input := map[string]interface{}{
|
||||
"vfoo": "foo",
|
||||
"vbar": map[string]interface{}{
|
||||
"vstring": "foo",
|
||||
"vint": 42,
|
||||
"vbool": true,
|
||||
},
|
||||
}
|
||||
|
||||
result := Nested{
|
||||
Vbar: Basic{
|
||||
Vuint: 42,
|
||||
},
|
||||
}
|
||||
err := Decode(input, &result)
|
||||
if err != nil {
|
||||
t.Fatalf("got an jderr: %s", err.Error())
|
||||
}
|
||||
|
||||
if result.Vfoo != "foo" {
|
||||
t.Errorf("vfoo value should be 'foo': %#v", result.Vfoo)
|
||||
}
|
||||
|
||||
if result.Vbar.Vstring != "foo" {
|
||||
t.Errorf("vstring value should be 'foo': %#v", result.Vbar.Vstring)
|
||||
}
|
||||
|
||||
if result.Vbar.Vint != 42 {
|
||||
t.Errorf("vint value should be 42: %#v", result.Vbar.Vint)
|
||||
}
|
||||
|
||||
if result.Vbar.Vbool != true {
|
||||
t.Errorf("vbool value should be true: %#v", result.Vbar.Vbool)
|
||||
}
|
||||
|
||||
if result.Vbar.Vextra != "" {
|
||||
t.Errorf("vextra value should be empty: %#v", result.Vbar.Vextra)
|
||||
}
|
||||
|
||||
// this is the error
|
||||
if result.Vbar.Vuint != 42 {
|
||||
t.Errorf("vuint value should be 42: %#v", result.Vbar.Vuint)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// #67 panic() on extending slices (decodeSlice with disabled ZeroValues)
|
||||
func TestDecodeSliceToEmptySliceWOZeroing(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
type TestStruct struct {
|
||||
Vfoo []string
|
||||
}
|
||||
|
||||
decode := func(m interface{}, rawVal interface{}) error {
|
||||
config := &DecoderConfig{
|
||||
Metadata: nil,
|
||||
Result: rawVal,
|
||||
ZeroFields: false,
|
||||
}
|
||||
|
||||
decoder, err := NewDecoder(config)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return decoder.Decode(m)
|
||||
}
|
||||
|
||||
{
|
||||
input := map[string]interface{}{
|
||||
"vfoo": []string{"1"},
|
||||
}
|
||||
|
||||
result := &TestStruct{}
|
||||
|
||||
err := decode(input, &result)
|
||||
if err != nil {
|
||||
t.Fatalf("got an jderr: %s", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
input := map[string]interface{}{
|
||||
"vfoo": []string{"1"},
|
||||
}
|
||||
|
||||
result := &TestStruct{
|
||||
Vfoo: []string{},
|
||||
}
|
||||
|
||||
err := decode(input, &result)
|
||||
if err != nil {
|
||||
t.Fatalf("got an jderr: %s", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
input := map[string]interface{}{
|
||||
"vfoo": []string{"2", "3"},
|
||||
}
|
||||
|
||||
result := &TestStruct{
|
||||
Vfoo: []string{"1"},
|
||||
}
|
||||
|
||||
err := decode(input, &result)
|
||||
if err != nil {
|
||||
t.Fatalf("got an jderr: %s", err.Error())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// #70
|
||||
func TestNextSquashMapstructure(t *testing.T) {
|
||||
data := &struct {
|
||||
Level1 struct {
|
||||
Level2 struct {
|
||||
Foo string
|
||||
} `mapstructure:",squash"`
|
||||
} `mapstructure:",squash"`
|
||||
}{}
|
||||
err := Decode(map[interface{}]interface{}{"foo": "baz"}, &data)
|
||||
if err != nil {
|
||||
t.Fatalf("should not error: %s", err)
|
||||
}
|
||||
if data.Level1.Level2.Foo != "baz" {
|
||||
t.Fatal("value should be baz")
|
||||
}
|
||||
}
|
||||
|
||||
type ImplementsInterfacePointerReceiver struct {
|
||||
Name string
|
||||
}
|
||||
|
||||
func (i *ImplementsInterfacePointerReceiver) DoStuff() {}
|
||||
|
||||
type ImplementsInterfaceValueReceiver string
|
||||
|
||||
func (i ImplementsInterfaceValueReceiver) DoStuff() {}
|
||||
|
||||
// GH-140 Type error when using DecodeHook to decode into interface
|
||||
func TestDecode_DecodeHookInterface(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
type Interface interface {
|
||||
DoStuff()
|
||||
}
|
||||
type DecodeIntoInterface struct {
|
||||
Test Interface
|
||||
}
|
||||
|
||||
testData := map[string]string{"test": "test"}
|
||||
|
||||
stringToPointerInterfaceDecodeHook := func(from, to reflect.Type, data interface{}) (interface{}, error) {
|
||||
if from.Kind() != reflect.String {
|
||||
return data, nil
|
||||
}
|
||||
|
||||
if to != reflect.TypeOf((*Interface)(nil)).Elem() {
|
||||
return data, nil
|
||||
}
|
||||
// Ensure interface is satisfied
|
||||
var impl Interface = &ImplementsInterfacePointerReceiver{data.(string)}
|
||||
return impl, nil
|
||||
}
|
||||
|
||||
stringToValueInterfaceDecodeHook := func(from, to reflect.Type, data interface{}) (interface{}, error) {
|
||||
if from.Kind() != reflect.String {
|
||||
return data, nil
|
||||
}
|
||||
|
||||
if to != reflect.TypeOf((*Interface)(nil)).Elem() {
|
||||
return data, nil
|
||||
}
|
||||
// Ensure interface is satisfied
|
||||
var impl Interface = ImplementsInterfaceValueReceiver(data.(string))
|
||||
return impl, nil
|
||||
}
|
||||
|
||||
{
|
||||
decodeInto := new(DecodeIntoInterface)
|
||||
|
||||
decoder, _ := NewDecoder(&DecoderConfig{
|
||||
DecodeHook: stringToPointerInterfaceDecodeHook,
|
||||
Result: decodeInto,
|
||||
})
|
||||
|
||||
err := decoder.Decode(testData)
|
||||
if err != nil {
|
||||
t.Fatalf("Decode returned error: %s", err)
|
||||
}
|
||||
|
||||
expected := &ImplementsInterfacePointerReceiver{"test"}
|
||||
if !reflect.DeepEqual(decodeInto.Test, expected) {
|
||||
t.Fatalf("expected: %#v (%T), got: %#v (%T)", decodeInto.Test, decodeInto.Test, expected, expected)
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
decodeInto := new(DecodeIntoInterface)
|
||||
|
||||
decoder, _ := NewDecoder(&DecoderConfig{
|
||||
DecodeHook: stringToValueInterfaceDecodeHook,
|
||||
Result: decodeInto,
|
||||
})
|
||||
|
||||
err := decoder.Decode(testData)
|
||||
if err != nil {
|
||||
t.Fatalf("Decode returned error: %s", err)
|
||||
}
|
||||
|
||||
expected := ImplementsInterfaceValueReceiver("test")
|
||||
if !reflect.DeepEqual(decodeInto.Test, expected) {
|
||||
t.Fatalf("expected: %#v (%T), got: %#v (%T)", decodeInto.Test, decodeInto.Test, expected, expected)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// #103 Check for data type before trying to access its composants prevent a panic error
|
||||
// in decodeSlice
|
||||
func TestDecodeBadDataTypeInSlice(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
input := map[string]interface{}{
|
||||
"Toto": "titi",
|
||||
}
|
||||
result := []struct {
|
||||
Toto string
|
||||
}{}
|
||||
|
||||
if err := Decode(input, &result); err == nil {
|
||||
t.Error("An error was expected, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
// #202 Ensure that intermediate maps in the struct -> struct decode process are settable
|
||||
// and not just the elements within them.
|
||||
func TestDecodeIntermediateMapsSettable(t *testing.T) {
|
||||
type Timestamp struct {
|
||||
Seconds int64
|
||||
Nanos int32
|
||||
}
|
||||
|
||||
type TsWrapper struct {
|
||||
Timestamp *Timestamp
|
||||
}
|
||||
|
||||
type TimeWrapper struct {
|
||||
Timestamp time.Time
|
||||
}
|
||||
|
||||
input := TimeWrapper{
|
||||
Timestamp: time.Unix(123456789, 987654),
|
||||
}
|
||||
|
||||
expected := TsWrapper{
|
||||
Timestamp: &Timestamp{
|
||||
Seconds: 123456789,
|
||||
Nanos: 987654,
|
||||
},
|
||||
}
|
||||
|
||||
timePtrType := reflect.TypeOf((*time.Time)(nil))
|
||||
mapStrInfType := reflect.TypeOf((map[string]interface{})(nil))
|
||||
|
||||
var actual TsWrapper
|
||||
decoder, err := NewDecoder(&DecoderConfig{
|
||||
Result: &actual,
|
||||
DecodeHook: func(from, to reflect.Type, data interface{}) (interface{}, error) {
|
||||
if from == timePtrType && to == mapStrInfType {
|
||||
ts := data.(*time.Time)
|
||||
nanos := ts.UnixNano()
|
||||
|
||||
seconds := nanos / 1000000000
|
||||
nanos = nanos % 1000000000
|
||||
|
||||
return &map[string]interface{}{
|
||||
"Seconds": seconds,
|
||||
"Nanos": int32(nanos),
|
||||
}, nil
|
||||
}
|
||||
return data, nil
|
||||
},
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create decoder: %v", err)
|
||||
}
|
||||
|
||||
if err := decoder.Decode(&input); err != nil {
|
||||
t.Fatalf("failed to decode input: %v", err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(expected, actual) {
|
||||
t.Fatalf("expected: %#[1]v (%[1]T), got: %#[2]v (%[2]T)", expected, actual)
|
||||
}
|
||||
}
|
||||
|
||||
// GH-206: decodeInt throws an error for an empty string
|
||||
func TestDecode_weakEmptyStringToInt(t *testing.T) {
|
||||
input := map[string]interface{}{
|
||||
"StringToInt": "",
|
||||
"StringToUint": "",
|
||||
"StringToBool": "",
|
||||
"StringToFloat": "",
|
||||
}
|
||||
|
||||
expectedResultWeak := TypeConversionResult{
|
||||
StringToInt: 0,
|
||||
StringToUint: 0,
|
||||
StringToBool: false,
|
||||
StringToFloat: 0,
|
||||
}
|
||||
|
||||
// Test weak type conversion
|
||||
var resultWeak TypeConversionResult
|
||||
err := WeakDecode(input, &resultWeak)
|
||||
if err != nil {
|
||||
t.Fatalf("got an jderr: %s", err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(resultWeak, expectedResultWeak) {
|
||||
t.Errorf("expected \n%#v, got: \n%#v", expectedResultWeak, resultWeak)
|
||||
}
|
||||
}
|
||||
|
||||
// GH-228: Squash cause *time.Time set to zero
|
||||
func TestMapSquash(t *testing.T) {
|
||||
type AA struct {
|
||||
T *time.Time
|
||||
}
|
||||
type A struct {
|
||||
AA
|
||||
}
|
||||
|
||||
v := time.Now()
|
||||
in := &AA{
|
||||
T: &v,
|
||||
}
|
||||
out := &A{}
|
||||
d, err := NewDecoder(&DecoderConfig{
|
||||
Squash: true,
|
||||
Result: out,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("jderr: %s", err)
|
||||
}
|
||||
if err := d.Decode(in); err != nil {
|
||||
t.Fatalf("jderr: %s", err)
|
||||
}
|
||||
|
||||
// these failed
|
||||
if !v.Equal(*out.T) {
|
||||
t.Fatal("expected equal")
|
||||
}
|
||||
if out.T.IsZero() {
|
||||
t.Fatal("expected false")
|
||||
}
|
||||
}
|
||||
|
||||
// GH-238: Empty key name when decoding map from struct with only omitempty flag
|
||||
func TestMapOmitEmptyWithEmptyFieldnameInTag(t *testing.T) {
|
||||
type Struct struct {
|
||||
Username string `mapstructure:",omitempty"`
|
||||
Age int `mapstructure:",omitempty"`
|
||||
}
|
||||
|
||||
s := Struct{
|
||||
Username: "Joe",
|
||||
}
|
||||
var m map[string]interface{}
|
||||
|
||||
if err := Decode(s, &m); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if len(m) != 1 {
|
||||
t.Fatalf("fail: %#v", m)
|
||||
}
|
||||
if m["Username"] != "Joe" {
|
||||
t.Fatalf("fail: %#v", m)
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,256 @@
|
|||
package mapstructure
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
)
|
||||
|
||||
func ExampleDecode() {
|
||||
type Person struct {
|
||||
Name string
|
||||
Age int
|
||||
Emails []string
|
||||
Extra map[string]string
|
||||
}
|
||||
|
||||
// This input can come from anywhere, but typically comes from
|
||||
// something like decoding JSON where we're not quite sure of the
|
||||
// struct initially.
|
||||
input := map[string]interface{}{
|
||||
"name": "Mitchell",
|
||||
"age": 91,
|
||||
"emails": []string{"one", "two", "three"},
|
||||
"extra": map[string]string{
|
||||
"twitter": "mitchellh",
|
||||
},
|
||||
}
|
||||
|
||||
var result Person
|
||||
err := Decode(input, &result)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
fmt.Printf("%#v", result)
|
||||
// Output:
|
||||
// mapstructure.Person{Name:"Mitchell", Age:91, Emails:[]string{"one", "two", "three"}, Extra:map[string]string{"twitter":"mitchellh"}}
|
||||
}
|
||||
|
||||
func ExampleDecode_errors() {
|
||||
type Person struct {
|
||||
Name string
|
||||
Age int
|
||||
Emails []string
|
||||
Extra map[string]string
|
||||
}
|
||||
|
||||
// This input can come from anywhere, but typically comes from
|
||||
// something like decoding JSON where we're not quite sure of the
|
||||
// struct initially.
|
||||
input := map[string]interface{}{
|
||||
"name": 123,
|
||||
"age": "bad value",
|
||||
"emails": []int{1, 2, 3},
|
||||
}
|
||||
|
||||
var result Person
|
||||
err := Decode(input, &result)
|
||||
if err == nil {
|
||||
panic("should have an error")
|
||||
}
|
||||
|
||||
fmt.Println(err.Error())
|
||||
// Output:
|
||||
// 5 error(s) decoding:
|
||||
//
|
||||
// * 'Age' expected type 'int', got unconvertible type 'string', value: 'bad value'
|
||||
// * 'Emails[0]' expected type 'string', got unconvertible type 'int', value: '1'
|
||||
// * 'Emails[1]' expected type 'string', got unconvertible type 'int', value: '2'
|
||||
// * 'Emails[2]' expected type 'string', got unconvertible type 'int', value: '3'
|
||||
// * 'Name' expected type 'string', got unconvertible type 'int', value: '123'
|
||||
}
|
||||
|
||||
func ExampleDecode_metadata() {
|
||||
type Person struct {
|
||||
Name string
|
||||
Age int
|
||||
}
|
||||
|
||||
// This input can come from anywhere, but typically comes from
|
||||
// something like decoding JSON where we're not quite sure of the
|
||||
// struct initially.
|
||||
input := map[string]interface{}{
|
||||
"name": "Mitchell",
|
||||
"age": 91,
|
||||
"email": "foo@bar.com",
|
||||
}
|
||||
|
||||
// For metadata, we make a more advanced DecoderConfig so we can
|
||||
// more finely configure the decoder that is used. In this case, we
|
||||
// just tell the decoder we want to track metadata.
|
||||
var md Metadata
|
||||
var result Person
|
||||
config := &DecoderConfig{
|
||||
Metadata: &md,
|
||||
Result: &result,
|
||||
}
|
||||
|
||||
decoder, err := NewDecoder(config)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
if err := decoder.Decode(input); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
fmt.Printf("Unused keys: %#v", md.Unused)
|
||||
// Output:
|
||||
// Unused keys: []string{"email"}
|
||||
}
|
||||
|
||||
func ExampleDecode_weaklyTypedInput() {
|
||||
type Person struct {
|
||||
Name string
|
||||
Age int
|
||||
Emails []string
|
||||
}
|
||||
|
||||
// This input can come from anywhere, but typically comes from
|
||||
// something like decoding JSON, generated by a weakly typed language
|
||||
// such as PHP.
|
||||
input := map[string]interface{}{
|
||||
"name": 123, // number => string
|
||||
"age": "42", // string => number
|
||||
"emails": map[string]interface{}{}, // empty map => empty array
|
||||
}
|
||||
|
||||
var result Person
|
||||
config := &DecoderConfig{
|
||||
WeaklyTypedInput: true,
|
||||
Result: &result,
|
||||
}
|
||||
|
||||
decoder, err := NewDecoder(config)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
err = decoder.Decode(input)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
fmt.Printf("%#v", result)
|
||||
// Output: mapstructure.Person{Name:"123", Age:42, Emails:[]string{}}
|
||||
}
|
||||
|
||||
func ExampleDecode_tags() {
|
||||
// Note that the mapstructure tags defined in the struct type
|
||||
// can indicate which fields the values are mapped to.
|
||||
type Person struct {
|
||||
Name string `mapstructure:"person_name"`
|
||||
Age int `mapstructure:"person_age"`
|
||||
}
|
||||
|
||||
input := map[string]interface{}{
|
||||
"person_name": "Mitchell",
|
||||
"person_age": 91,
|
||||
}
|
||||
|
||||
var result Person
|
||||
err := Decode(input, &result)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
fmt.Printf("%#v", result)
|
||||
// Output:
|
||||
// mapstructure.Person{Name:"Mitchell", Age:91}
|
||||
}
|
||||
|
||||
func ExampleDecode_embeddedStruct() {
|
||||
// Squashing multiple embedded structs is allowed using the squash tag.
|
||||
// This is demonstrated by creating a composite struct of multiple types
|
||||
// and decoding into it. In this case, a person can carry with it both
|
||||
// a Family and a Location, as well as their own FirstName.
|
||||
type Family struct {
|
||||
LastName string
|
||||
}
|
||||
type Location struct {
|
||||
City string
|
||||
}
|
||||
type Person struct {
|
||||
Family `mapstructure:",squash"`
|
||||
Location `mapstructure:",squash"`
|
||||
FirstName string
|
||||
}
|
||||
|
||||
input := map[string]interface{}{
|
||||
"FirstName": "Mitchell",
|
||||
"LastName": "Hashimoto",
|
||||
"City": "San Francisco",
|
||||
}
|
||||
|
||||
var result Person
|
||||
err := Decode(input, &result)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
fmt.Printf("%s %s, %s", result.FirstName, result.LastName, result.City)
|
||||
// Output:
|
||||
// Mitchell Hashimoto, San Francisco
|
||||
}
|
||||
|
||||
func ExampleDecode_remainingData() {
|
||||
// Note that the mapstructure tags defined in the struct type
|
||||
// can indicate which fields the values are mapped to.
|
||||
type Person struct {
|
||||
Name string
|
||||
Age int
|
||||
Other map[string]interface{} `mapstructure:",remain"`
|
||||
}
|
||||
|
||||
input := map[string]interface{}{
|
||||
"name": "Mitchell",
|
||||
"age": 91,
|
||||
"email": "mitchell@example.com",
|
||||
}
|
||||
|
||||
var result Person
|
||||
err := Decode(input, &result)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
fmt.Printf("%#v", result)
|
||||
// Output:
|
||||
// mapstructure.Person{Name:"Mitchell", Age:91, Other:map[string]interface {}{"email":"mitchell@example.com"}}
|
||||
}
|
||||
|
||||
func ExampleDecode_omitempty() {
|
||||
// Add omitempty annotation to avoid map keys for empty values
|
||||
type Family struct {
|
||||
LastName string
|
||||
}
|
||||
type Location struct {
|
||||
City string
|
||||
}
|
||||
type Person struct {
|
||||
*Family `mapstructure:",omitempty"`
|
||||
*Location `mapstructure:",omitempty"`
|
||||
Age int
|
||||
FirstName string
|
||||
}
|
||||
|
||||
result := &map[string]interface{}{}
|
||||
input := Person{FirstName: "Somebody"}
|
||||
err := Decode(input, &result)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
fmt.Printf("%+v", result)
|
||||
// Output:
|
||||
// &map[Age:0 FirstName:Somebody]
|
||||
}
|
||||
|
|
@ -0,0 +1,58 @@
|
|||
package mapstructure
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestDecode_Ptr(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
type G struct {
|
||||
Id int
|
||||
Name string
|
||||
}
|
||||
|
||||
type X struct {
|
||||
Id int
|
||||
Name int
|
||||
}
|
||||
|
||||
type AG struct {
|
||||
List []*G
|
||||
}
|
||||
|
||||
type AX struct {
|
||||
List []*X
|
||||
}
|
||||
|
||||
g2 := &AG{
|
||||
List: []*G{
|
||||
{
|
||||
Id: 11,
|
||||
Name: "gg",
|
||||
},
|
||||
},
|
||||
}
|
||||
x2 := AX{}
|
||||
|
||||
// 报错但还是会转换成功,转换后值为目标类型的 0 值
|
||||
err := Decode(g2, &x2)
|
||||
|
||||
res := AX{
|
||||
List: []*X{
|
||||
{
|
||||
Id: 11,
|
||||
Name: 0, // 这个类型的 0 值
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
if err == nil {
|
||||
t.Errorf("Decode_Ptr jderr should not be 'nil': %#v", err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(res, x2) {
|
||||
t.Errorf("result should be %#v: got %#v", res, x2)
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
|
|
@ -0,0 +1,26 @@
|
|||
package mapstructure
|
||||
|
||||
import "time"
|
||||
|
||||
// DecodeWithTime 支持时间转字符串
|
||||
// 支持
|
||||
// 1. *Time.time 转 string/*string
|
||||
// 2. *Time.time 转 uint/uint32/uint64/int/int32/int64,支持带指针
|
||||
// 不能用 Time.time 转,它会在上层认为是一个结构体数据而直接转成map,再到hook方法
|
||||
func DecodeWithTime(input, output interface{}, layout string) error {
|
||||
if layout == "" {
|
||||
layout = time.DateTime
|
||||
}
|
||||
config := &DecoderConfig{
|
||||
Metadata: nil,
|
||||
Result: output,
|
||||
DecodeHook: ComposeDecodeHookFunc(TimeToStringHook(layout), TimeToUnixIntHook()),
|
||||
}
|
||||
|
||||
decoder, err := NewDecoder(config)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return decoder.Decode(input)
|
||||
}
|
||||
|
|
@ -0,0 +1,101 @@
|
|||
package mapstructure
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"time"
|
||||
)
|
||||
|
||||
// TimeToStringHook 时间转字符串
|
||||
// 支持 *Time.time 转 string/*string
|
||||
// 不能用 Time.time 转,它会在上层认为是一个结构体数据而直接转成map,再到hook方法
|
||||
func TimeToStringHook(layout string) DecodeHookFunc {
|
||||
return func(
|
||||
f reflect.Type,
|
||||
t reflect.Type,
|
||||
data interface{}) (interface{}, error) {
|
||||
// 判断目标类型是否为字符串
|
||||
var strType string
|
||||
var isStrPointer *bool // 要转换的目标类型是否为指针字符串
|
||||
if t == reflect.TypeOf(strType) {
|
||||
isStrPointer = new(bool)
|
||||
} else if t == reflect.TypeOf(&strType) {
|
||||
isStrPointer = new(bool)
|
||||
*isStrPointer = true
|
||||
}
|
||||
if isStrPointer == nil {
|
||||
return data, nil
|
||||
}
|
||||
|
||||
// 判断类型是否为时间
|
||||
timeType := time.Time{}
|
||||
if f != reflect.TypeOf(timeType) && f != reflect.TypeOf(&timeType) {
|
||||
return data, nil
|
||||
}
|
||||
|
||||
// 将时间转换为字符串
|
||||
var output string
|
||||
switch v := data.(type) {
|
||||
case *time.Time:
|
||||
output = v.Format(layout)
|
||||
case time.Time:
|
||||
output = v.Format(layout)
|
||||
default:
|
||||
return data, nil
|
||||
}
|
||||
|
||||
if *isStrPointer {
|
||||
return &output, nil
|
||||
}
|
||||
return output, nil
|
||||
}
|
||||
}
|
||||
|
||||
// TimeToUnixIntHook 时间转时间戳
|
||||
// 支持 *Time.time 转 uint/uint32/uint64/int/int32/int64,支持带指针
|
||||
// 不能用 Time.time 转,它会在上层认为是一个结构体数据而直接转成map,再到hook方法
|
||||
func TimeToUnixIntHook() DecodeHookFunc {
|
||||
return func(
|
||||
f reflect.Type,
|
||||
t reflect.Type,
|
||||
data interface{}) (interface{}, error) {
|
||||
|
||||
tkd := t.Kind()
|
||||
if tkd != reflect.Int && tkd != reflect.Int32 && tkd != reflect.Int64 &&
|
||||
tkd != reflect.Uint && tkd != reflect.Uint32 && tkd != reflect.Uint64 {
|
||||
return data, nil
|
||||
}
|
||||
|
||||
// 判断类型是否为时间
|
||||
timeType := time.Time{}
|
||||
if f != reflect.TypeOf(timeType) && f != reflect.TypeOf(&timeType) {
|
||||
return data, nil
|
||||
}
|
||||
|
||||
// 将时间转换为字符串
|
||||
var output int64
|
||||
switch v := data.(type) {
|
||||
case *time.Time:
|
||||
output = v.Unix()
|
||||
case time.Time:
|
||||
output = v.Unix()
|
||||
default:
|
||||
return data, nil
|
||||
}
|
||||
switch tkd {
|
||||
case reflect.Int:
|
||||
return int(output), nil
|
||||
case reflect.Int32:
|
||||
return int32(output), nil
|
||||
case reflect.Int64:
|
||||
return output, nil
|
||||
case reflect.Uint:
|
||||
return uint(output), nil
|
||||
case reflect.Uint32:
|
||||
return uint32(output), nil
|
||||
case reflect.Uint64:
|
||||
return uint64(output), nil
|
||||
default:
|
||||
return data, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,274 @@
|
|||
package mapstructure
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func Test_TimeToStringHook(t *testing.T) {
|
||||
type Input struct {
|
||||
Time time.Time
|
||||
Id int
|
||||
}
|
||||
|
||||
type InputTPointer struct {
|
||||
Time *time.Time
|
||||
Id int
|
||||
}
|
||||
|
||||
type Output struct {
|
||||
Time string
|
||||
Id int
|
||||
}
|
||||
|
||||
type OutputTPointer struct {
|
||||
Time *string
|
||||
Id int
|
||||
}
|
||||
now := time.Now()
|
||||
target := now.Format("2006-01-02 15:04:05")
|
||||
idValue := 1
|
||||
tests := []struct {
|
||||
input any
|
||||
output any
|
||||
name string
|
||||
layout string
|
||||
}{
|
||||
{
|
||||
name: "测试Time.time转string",
|
||||
layout: "2006-01-02 15:04:05",
|
||||
input: InputTPointer{
|
||||
Time: &now,
|
||||
Id: idValue,
|
||||
},
|
||||
output: Output{},
|
||||
},
|
||||
{
|
||||
name: "测试*Time.time转*string",
|
||||
layout: "2006-01-02 15:04:05",
|
||||
input: InputTPointer{
|
||||
Time: &now,
|
||||
Id: idValue,
|
||||
},
|
||||
output: OutputTPointer{},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
decoder, err := NewDecoder(&DecoderConfig{
|
||||
DecodeHook: TimeToStringHook(tt.layout),
|
||||
Result: &tt.output,
|
||||
})
|
||||
if err != nil {
|
||||
t.Errorf("NewDecoder() jderr = %v,want nil", err)
|
||||
}
|
||||
|
||||
if i, isOk := tt.input.(Input); isOk {
|
||||
err = decoder.Decode(i)
|
||||
}
|
||||
if i, isOk := tt.input.(InputTPointer); isOk {
|
||||
err = decoder.Decode(&i)
|
||||
}
|
||||
if err != nil {
|
||||
t.Errorf("Decode jderr = %v,want nil", err)
|
||||
}
|
||||
//验证测试值
|
||||
if output, isOk := tt.output.(OutputTPointer); isOk {
|
||||
if *output.Time != target {
|
||||
t.Errorf("Decode output time = %v,want %v", *output.Time, target)
|
||||
}
|
||||
if output.Id != idValue {
|
||||
t.Errorf("Decode output id = %v,want %v", output.Id, idValue)
|
||||
}
|
||||
}
|
||||
if output, isOk := tt.output.(Output); isOk {
|
||||
if output.Time != target {
|
||||
t.Errorf("Decode output time = %v,want %v", output.Time, target)
|
||||
}
|
||||
if output.Id != idValue {
|
||||
t.Errorf("Decode output id = %v,want %v", output.Id, idValue)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_TimeToUnixIntHook(t *testing.T) {
|
||||
type InputTPointer struct {
|
||||
Time *time.Time
|
||||
Id int
|
||||
}
|
||||
|
||||
type Output[T int | *int | int32 | *int32 | int64 | *int64 | uint | *uint] struct {
|
||||
Time T
|
||||
Id int
|
||||
}
|
||||
|
||||
type test struct {
|
||||
input any
|
||||
output any
|
||||
name string
|
||||
layout string
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
target := now.Unix()
|
||||
idValue := 1
|
||||
tests := []test{
|
||||
{
|
||||
name: "测试Time.time转int",
|
||||
layout: "2006-01-02 15:04:05",
|
||||
input: InputTPointer{
|
||||
Time: &now,
|
||||
Id: idValue,
|
||||
},
|
||||
output: Output[int]{},
|
||||
},
|
||||
{
|
||||
name: "测试Time.time转*int",
|
||||
layout: "2006-01-02 15:04:05",
|
||||
input: InputTPointer{
|
||||
Time: &now,
|
||||
Id: idValue,
|
||||
},
|
||||
output: Output[*int]{},
|
||||
},
|
||||
{
|
||||
name: "测试Time.time转int32",
|
||||
layout: "2006-01-02 15:04:05",
|
||||
input: InputTPointer{
|
||||
Time: &now,
|
||||
Id: idValue,
|
||||
},
|
||||
output: Output[int32]{},
|
||||
},
|
||||
{
|
||||
name: "测试Time.time转*int32",
|
||||
layout: "2006-01-02 15:04:05",
|
||||
input: InputTPointer{
|
||||
Time: &now,
|
||||
Id: idValue,
|
||||
},
|
||||
output: Output[*int32]{},
|
||||
},
|
||||
{
|
||||
name: "测试Time.time转int64",
|
||||
layout: "2006-01-02 15:04:05",
|
||||
input: InputTPointer{
|
||||
Time: &now,
|
||||
Id: idValue,
|
||||
},
|
||||
output: Output[int64]{},
|
||||
},
|
||||
{
|
||||
name: "测试Time.time转*int64",
|
||||
layout: "2006-01-02 15:04:05",
|
||||
input: InputTPointer{
|
||||
Time: &now,
|
||||
Id: idValue,
|
||||
},
|
||||
output: Output[*int64]{},
|
||||
},
|
||||
{
|
||||
name: "测试Time.time转uint",
|
||||
layout: "2006-01-02 15:04:05",
|
||||
input: InputTPointer{
|
||||
Time: &now,
|
||||
Id: idValue,
|
||||
},
|
||||
output: Output[uint]{},
|
||||
},
|
||||
{
|
||||
name: "测试Time.time转*uint",
|
||||
layout: "2006-01-02 15:04:05",
|
||||
input: InputTPointer{
|
||||
Time: &now,
|
||||
Id: idValue,
|
||||
},
|
||||
output: Output[*uint]{},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
decoder, err := NewDecoder(&DecoderConfig{
|
||||
DecodeHook: TimeToUnixIntHook(),
|
||||
Result: &tt.output,
|
||||
})
|
||||
if err != nil {
|
||||
t.Errorf("NewDecoder() jderr = %v,want nil", err)
|
||||
}
|
||||
|
||||
if i, isOk := tt.input.(InputTPointer); isOk {
|
||||
err = decoder.Decode(i)
|
||||
}
|
||||
if i, isOk := tt.input.(InputTPointer); isOk {
|
||||
err = decoder.Decode(&i)
|
||||
}
|
||||
if err != nil {
|
||||
t.Errorf("Decode jderr = %v,want nil", err)
|
||||
}
|
||||
|
||||
//验证测试值
|
||||
switch v := tt.output.(type) {
|
||||
case Output[int]:
|
||||
if int64(v.Time) != target {
|
||||
t.Errorf("Decode output time = %v,want %v", v.Time, target)
|
||||
}
|
||||
if v.Id != idValue {
|
||||
t.Errorf("Decode output id = %v,want %v", v.Id, idValue)
|
||||
}
|
||||
case Output[*int]:
|
||||
if int64(*v.Time) != target {
|
||||
t.Errorf("Decode output time = %v,want %v", v.Time, target)
|
||||
}
|
||||
if v.Id != idValue {
|
||||
t.Errorf("Decode output id = %v,want %v", v.Id, idValue)
|
||||
}
|
||||
case Output[int32]:
|
||||
if int64(v.Time) != target {
|
||||
t.Errorf("Decode output time = %v,want %v", v.Time, target)
|
||||
}
|
||||
if v.Id != idValue {
|
||||
t.Errorf("Decode output id = %v,want %v", v.Id, idValue)
|
||||
}
|
||||
case Output[*int32]:
|
||||
if int64(*v.Time) != target {
|
||||
t.Errorf("Decode output time = %v,want %v", v.Time, target)
|
||||
}
|
||||
if v.Id != idValue {
|
||||
t.Errorf("Decode output id = %v,want %v", v.Id, idValue)
|
||||
}
|
||||
case Output[int64]:
|
||||
if int64(v.Time) != target {
|
||||
t.Errorf("Decode output time = %v,want %v", v.Time, target)
|
||||
}
|
||||
if v.Id != idValue {
|
||||
t.Errorf("Decode output id = %v,want %v", v.Id, idValue)
|
||||
}
|
||||
case Output[*int64]:
|
||||
if int64(*v.Time) != target {
|
||||
t.Errorf("Decode output time = %v,want %v", v.Time, target)
|
||||
}
|
||||
if v.Id != idValue {
|
||||
t.Errorf("Decode output id = %v,want %v", v.Id, idValue)
|
||||
}
|
||||
case Output[uint]:
|
||||
if int64(v.Time) != target {
|
||||
t.Errorf("Decode output time = %v,want %v", v.Time, target)
|
||||
}
|
||||
if v.Id != idValue {
|
||||
t.Errorf("Decode output id = %v,want %v", v.Id, idValue)
|
||||
}
|
||||
case Output[*uint]:
|
||||
if int64(*v.Time) != target {
|
||||
t.Errorf("Decode output time = %v,want %v", v.Time, target)
|
||||
}
|
||||
if v.Id != idValue {
|
||||
t.Errorf("Decode output id = %v,want %v", v.Id, idValue)
|
||||
}
|
||||
}
|
||||
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,38 @@
|
|||
package pkg
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/gofiber/fiber/v2/log"
|
||||
)
|
||||
|
||||
func HandleResponse(c *fiber.Ctx, data interface{}) (err error) {
|
||||
if os.Getenv("env") == "unit_test" {
|
||||
log.Debug(data)
|
||||
}
|
||||
switch data.(type) {
|
||||
case error:
|
||||
err = data.(error)
|
||||
case int, int32, int64, float32, float64, string, bool:
|
||||
c.Response().SetBody([]byte(fmt.Sprintf("%s", data)))
|
||||
case []byte:
|
||||
c.Response().SetBody(data.([]byte))
|
||||
default:
|
||||
dataByte, _ := json.Marshal(data)
|
||||
c.Response().SetBody(dataByte)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func SuccessWithPageMsg(c *fiber.Ctx, list interface{}, total int64, page, pageSize int) error {
|
||||
response := fiber.Map{
|
||||
"list": list,
|
||||
"total": total,
|
||||
"page": page,
|
||||
"pageSize": pageSize,
|
||||
}
|
||||
return HandleResponse(c, response)
|
||||
}
|
||||
|
|
@ -0,0 +1,743 @@
|
|||
package pkg
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"geo/tmpl/errcode"
|
||||
"reflect"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/go-playground/validator/v10"
|
||||
"github.com/gofiber/fiber/v2"
|
||||
)
|
||||
|
||||
const (
|
||||
// 验证标签常量
|
||||
TagRequired = "required"
|
||||
TagEmail = "email"
|
||||
TagMin = "min"
|
||||
TagMax = "max"
|
||||
TagLen = "len"
|
||||
TagNumeric = "numeric"
|
||||
TagOneof = "oneof"
|
||||
TagGte = "gte"
|
||||
TagGt = "gt"
|
||||
TagLte = "lte"
|
||||
TagLt = "lt"
|
||||
TagURL = "url"
|
||||
TagUUID = "uuid"
|
||||
TagDatetime = "datetime"
|
||||
|
||||
// 中文标签常量
|
||||
//LabelComment = "comment"
|
||||
//LabelLabel = "label"
|
||||
LabelZh = "zh"
|
||||
|
||||
// 上下文Key
|
||||
CtxRequestBody = "request_body"
|
||||
|
||||
// 性能优化常量
|
||||
InitialBuilderSize = 64
|
||||
MaxCacheSize = 5000
|
||||
CacheCleanupInterval = 10 * time.Minute
|
||||
)
|
||||
|
||||
// ValidatorConfig 验证器配置
|
||||
type ValidatorConfig struct {
|
||||
StatusCode int // 验证失败时的HTTP状态码,默认422
|
||||
ErrorHandler func(c *fiber.Ctx, err error) error // 自定义错误处理
|
||||
BeforeParse func(c *fiber.Ctx) error // 解析前执行
|
||||
AfterParse func(c *fiber.Ctx, req interface{}) error // 解析后执行
|
||||
DisableCache bool // 是否禁用缓存
|
||||
MaxCacheSize int // 最大缓存数量
|
||||
UseChinese bool // 是否使用中文提示
|
||||
EnableMetrics bool // 是否启用指标收集
|
||||
}
|
||||
|
||||
// CacheStats 缓存统计
|
||||
type CacheStats struct {
|
||||
hitCount int64
|
||||
missCount int64
|
||||
evictCount int64
|
||||
errorCount int64
|
||||
accessCount int64
|
||||
}
|
||||
|
||||
// Metrics 指标收集
|
||||
type Metrics struct {
|
||||
totalRequests int64
|
||||
validationTime int64
|
||||
cacheHitRate float64
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// fieldInfo 字段信息
|
||||
type fieldInfo struct {
|
||||
label string
|
||||
index int
|
||||
typeKind reflect.Kind
|
||||
accessCount int64
|
||||
lastAccess time.Time
|
||||
}
|
||||
|
||||
// typeInfo 类型信息
|
||||
type typeInfo struct {
|
||||
fields map[string]*fieldInfo
|
||||
fieldNames []string
|
||||
mu sync.RWMutex
|
||||
accessCount int64
|
||||
createdAt time.Time
|
||||
lastAccess time.Time
|
||||
typeKey string
|
||||
}
|
||||
|
||||
// ValidatorHelper 验证器助手
|
||||
type ValidatorHelper struct {
|
||||
validate *validator.Validate
|
||||
config *ValidatorConfig
|
||||
typeCache sync.Map
|
||||
cacheStats *CacheStats
|
||||
errorFnCache map[string]func(string, string) string
|
||||
errorFnCacheMu sync.RWMutex
|
||||
pruneLock sync.Mutex
|
||||
stopCleanup chan struct{}
|
||||
cleanupOnce sync.Once
|
||||
metrics *Metrics
|
||||
}
|
||||
|
||||
var (
|
||||
ChineseErrorTemplates = map[string]string{
|
||||
TagRequired: "%s不能为空",
|
||||
TagEmail: "%s格式不正确",
|
||||
TagMin: "%s不能小于%s",
|
||||
TagMax: "%s不能大于%s",
|
||||
TagLen: "%s长度必须为%s位",
|
||||
TagNumeric: "%s必须是数字",
|
||||
TagOneof: "%s必须是以下值之一: %s",
|
||||
TagGte: "%s不能小于%s",
|
||||
TagGt: "%s必须大于%s",
|
||||
TagLte: "%s不能大于%s",
|
||||
TagLt: "%s必须小于%s",
|
||||
TagURL: "%s必须是有效的URL地址",
|
||||
TagUUID: "%s必须是有效的UUID",
|
||||
TagDatetime: "%s日期格式不正确",
|
||||
}
|
||||
|
||||
defaultErrorTemplates = map[string]string{
|
||||
TagRequired: "%s is required",
|
||||
TagEmail: "invalid email format",
|
||||
TagMin: "%s must be at least %s",
|
||||
TagMax: "%s must be at most %s",
|
||||
TagLen: "%s must be exactly %s characters",
|
||||
TagNumeric: "%s must be numeric",
|
||||
TagOneof: "%s must be one of: %s",
|
||||
TagGte: "%s must be greater than or equal to %s",
|
||||
TagGt: "%s must be greater than %s",
|
||||
TagLte: "%s must be less than or equal to %s",
|
||||
TagLt: "%s must be less than %s",
|
||||
TagURL: "%s must be a valid URL",
|
||||
TagUUID: "%s must be a valid UUID",
|
||||
TagDatetime: "%s invalid datetime format",
|
||||
}
|
||||
)
|
||||
|
||||
var (
|
||||
// 对象池
|
||||
builderPool = sync.Pool{
|
||||
New: func() interface{} {
|
||||
b := &strings.Builder{}
|
||||
b.Grow(InitialBuilderSize)
|
||||
return b
|
||||
},
|
||||
}
|
||||
|
||||
// 错误信息切片池
|
||||
errorSlicePool = sync.Pool{
|
||||
New: func() interface{} {
|
||||
slice := make([]string, 0, 8)
|
||||
return &slice
|
||||
},
|
||||
}
|
||||
|
||||
// 字段信息对象池
|
||||
fieldInfoPool = sync.Pool{
|
||||
New: func() interface{} {
|
||||
return &fieldInfo{
|
||||
accessCount: 0,
|
||||
lastAccess: time.Now(),
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
// 类型信息对象池
|
||||
typeInfoPool = sync.Pool{
|
||||
New: func() interface{} {
|
||||
return &typeInfo{
|
||||
fields: make(map[string]*fieldInfo),
|
||||
fieldNames: make([]string, 0, 8),
|
||||
}
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
var (
|
||||
Vh *ValidatorHelper
|
||||
once sync.Once
|
||||
)
|
||||
|
||||
// NewValidatorHelper 初始化验证器助手
|
||||
func NewValidatorHelper(config ...*ValidatorConfig) {
|
||||
once.Do(func() {
|
||||
v := validator.New()
|
||||
|
||||
// 优化JSON标签获取
|
||||
v.RegisterTagNameFunc(func(fld reflect.StructField) string {
|
||||
return fld.Tag.Get("json")
|
||||
})
|
||||
|
||||
// 默认配置
|
||||
cfg := &ValidatorConfig{
|
||||
StatusCode: fiber.StatusUnprocessableEntity,
|
||||
ErrorHandler: defaultErrorHandler,
|
||||
MaxCacheSize: MaxCacheSize,
|
||||
UseChinese: true,
|
||||
EnableMetrics: false,
|
||||
}
|
||||
|
||||
if len(config) > 0 && config[0] != nil {
|
||||
if config[0].StatusCode != 0 {
|
||||
cfg.StatusCode = config[0].StatusCode
|
||||
}
|
||||
if config[0].ErrorHandler != nil {
|
||||
cfg.ErrorHandler = config[0].ErrorHandler
|
||||
}
|
||||
if config[0].MaxCacheSize > 0 {
|
||||
cfg.MaxCacheSize = config[0].MaxCacheSize
|
||||
}
|
||||
cfg.DisableCache = config[0].DisableCache
|
||||
cfg.BeforeParse = config[0].BeforeParse
|
||||
cfg.AfterParse = config[0].AfterParse
|
||||
cfg.UseChinese = config[0].UseChinese
|
||||
cfg.EnableMetrics = config[0].EnableMetrics
|
||||
}
|
||||
|
||||
// 预编译错误函数
|
||||
errorFnCache := make(map[string]func(string, string) string)
|
||||
templates := ChineseErrorTemplates
|
||||
if !cfg.UseChinese {
|
||||
templates = defaultErrorTemplates
|
||||
}
|
||||
|
||||
for tag, tmpl := range templates {
|
||||
t := tmpl // 捕获变量
|
||||
errorFnCache[tag] = func(field, param string) string {
|
||||
if strings.Contains(t, "%s") && strings.Count(t, "%s") == 2 {
|
||||
return fmt.Sprintf(t, field, param)
|
||||
}
|
||||
return fmt.Sprintf(t, field)
|
||||
}
|
||||
}
|
||||
|
||||
Vh = &ValidatorHelper{
|
||||
validate: v,
|
||||
config: cfg,
|
||||
cacheStats: &CacheStats{},
|
||||
errorFnCache: errorFnCache,
|
||||
errorFnCacheMu: sync.RWMutex{},
|
||||
stopCleanup: make(chan struct{}),
|
||||
metrics: &Metrics{},
|
||||
}
|
||||
|
||||
// 启动定期清理
|
||||
if !cfg.DisableCache {
|
||||
go Vh.periodicCleanup()
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// ParseAndValidate 解析并验证请求
|
||||
func ParseAndValidate(c *fiber.Ctx, req interface{}) error {
|
||||
if Vh == nil {
|
||||
NewValidatorHelper()
|
||||
}
|
||||
|
||||
if Vh.config.EnableMetrics {
|
||||
atomic.AddInt64(&Vh.metrics.totalRequests, 1)
|
||||
defer Vh.recordValidationTime(time.Now())
|
||||
}
|
||||
|
||||
// 执行解析前钩子
|
||||
if Vh.config.BeforeParse != nil {
|
||||
if err := Vh.config.BeforeParse(c); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// 解析请求体
|
||||
err := c.BodyParser(req)
|
||||
if err != nil {
|
||||
return errcode.ParamErr("请求格式错误:" + err.Error())
|
||||
}
|
||||
|
||||
// 执行解析后钩子
|
||||
if Vh.config.AfterParse != nil {
|
||||
if err = Vh.config.AfterParse(c, req); err != nil {
|
||||
return errcode.ParamErr(err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
// 验证数据
|
||||
err = Vh.validate.Struct(req)
|
||||
if err != nil {
|
||||
c.Locals(CtxRequestBody, req)
|
||||
|
||||
if !Vh.config.DisableCache {
|
||||
t := reflect.TypeOf(req)
|
||||
if t.Kind() == reflect.Ptr {
|
||||
t = t.Elem()
|
||||
}
|
||||
if t.Kind() == reflect.Struct {
|
||||
Vh.safeGetOrCreateTypeInfo(t)
|
||||
}
|
||||
}
|
||||
|
||||
return Vh.config.ErrorHandler(c, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Validate 直接验证结构体
|
||||
func Validate(req interface{}) error {
|
||||
if Vh == nil {
|
||||
NewValidatorHelper()
|
||||
}
|
||||
|
||||
if err := Vh.validate.Struct(req); err != nil {
|
||||
return Vh.wrapValidationError(err, req)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// 默认错误处理
|
||||
func defaultErrorHandler(c *fiber.Ctx, err error) error {
|
||||
validationErrors, ok := err.(validator.ValidationErrors)
|
||||
if !ok {
|
||||
return errcode.SystemError
|
||||
}
|
||||
|
||||
if len(validationErrors) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 快速路径:单个错误
|
||||
if len(validationErrors) == 1 {
|
||||
e := validationErrors[0]
|
||||
msg := Vh.safeGetErrorMessage(c, e)
|
||||
return errcode.ParamErr(msg)
|
||||
}
|
||||
|
||||
// 从对象池获取builder
|
||||
builder := builderPool.Get().(*strings.Builder)
|
||||
builder.Reset()
|
||||
defer builderPool.Put(builder)
|
||||
|
||||
req := c.Locals(CtxRequestBody)
|
||||
|
||||
for i, e := range validationErrors {
|
||||
if i > 0 {
|
||||
builder.WriteByte('\n')
|
||||
}
|
||||
builder.WriteString(Vh.safeGetErrorMessageWithReq(req, e))
|
||||
}
|
||||
|
||||
return errcode.ParamErr(builder.String())
|
||||
}
|
||||
|
||||
// 包装验证错误
|
||||
func (vh *ValidatorHelper) wrapValidationError(err error, req interface{}) error {
|
||||
validationErrors, ok := err.(validator.ValidationErrors)
|
||||
if !ok {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(validationErrors) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 构建错误消息
|
||||
builder := builderPool.Get().(*strings.Builder)
|
||||
builder.Reset()
|
||||
defer builderPool.Put(builder)
|
||||
|
||||
for i, e := range validationErrors {
|
||||
if i > 0 {
|
||||
builder.WriteByte('\n')
|
||||
}
|
||||
builder.WriteString(vh.safeGetErrorMessageWithReq(req, e))
|
||||
}
|
||||
|
||||
return errcode.ParamErr(builder.String())
|
||||
}
|
||||
|
||||
// 安全获取错误消息
|
||||
func (vh *ValidatorHelper) safeGetErrorMessage(c *fiber.Ctx, e validator.FieldError) string {
|
||||
req := c.Locals(CtxRequestBody)
|
||||
return vh.safeGetErrorMessageWithReq(req, e)
|
||||
}
|
||||
|
||||
// 安全获取错误消息(带请求体)
|
||||
func (vh *ValidatorHelper) safeGetErrorMessageWithReq(req interface{}, e validator.FieldError) string {
|
||||
if req == nil {
|
||||
return vh.safeFormatFieldError(e, nil)
|
||||
}
|
||||
|
||||
t := reflect.TypeOf(req)
|
||||
if t.Kind() == reflect.Ptr {
|
||||
t = t.Elem()
|
||||
}
|
||||
if t.Kind() != reflect.Struct {
|
||||
return vh.safeFormatFieldError(e, nil)
|
||||
}
|
||||
|
||||
// 安全获取类型信息 - 这里需要先声明变量
|
||||
var typeInfoObj *typeInfo // 这里声明变量
|
||||
if !vh.config.DisableCache {
|
||||
if cached, ok := vh.typeCache.Load(t); ok {
|
||||
typeInfoObj = cached.(*typeInfo)
|
||||
atomic.AddInt64(&typeInfoObj.accessCount, 1)
|
||||
}
|
||||
}
|
||||
|
||||
return vh.safeFormatFieldError(e, typeInfoObj)
|
||||
}
|
||||
|
||||
// 安全格式化字段错误
|
||||
func (vh *ValidatorHelper) safeFormatFieldError(e validator.FieldError, typeInfo *typeInfo) string {
|
||||
structField := e.StructField()
|
||||
fieldName := e.Field()
|
||||
|
||||
// 获取字段标签
|
||||
var label string
|
||||
if typeInfo != nil {
|
||||
typeInfo.mu.RLock()
|
||||
if info, ok := typeInfo.fields[structField]; ok {
|
||||
label = info.label
|
||||
atomic.AddInt64(&info.accessCount, 1)
|
||||
}
|
||||
typeInfo.mu.RUnlock()
|
||||
}
|
||||
|
||||
// 如果没有标签,返回默认消息
|
||||
if label == "" {
|
||||
return vh.safeGetDefaultErrorMessage(fieldName, e)
|
||||
}
|
||||
|
||||
// 使用预编译的错误函数生成消息
|
||||
vh.errorFnCacheMu.RLock()
|
||||
fn, ok := vh.errorFnCache[e.Tag()]
|
||||
vh.errorFnCacheMu.RUnlock()
|
||||
|
||||
if ok {
|
||||
return fn(label, e.Param())
|
||||
}
|
||||
|
||||
return label + "格式不正确"
|
||||
}
|
||||
|
||||
// 安全获取默认错误消息
|
||||
func (vh *ValidatorHelper) safeGetDefaultErrorMessage(field string, e validator.FieldError) string {
|
||||
vh.errorFnCacheMu.RLock()
|
||||
defer vh.errorFnCacheMu.RUnlock()
|
||||
|
||||
if fn, ok := vh.errorFnCache[e.Tag()]; ok {
|
||||
return fn(field, e.Param())
|
||||
}
|
||||
|
||||
return field + "验证失败"
|
||||
}
|
||||
|
||||
// safeGetOrCreateTypeInfo 安全地获取或创建类型信息
|
||||
func (vh *ValidatorHelper) safeGetOrCreateTypeInfo(t reflect.Type) *typeInfo {
|
||||
if vh.config.DisableCache || t == nil || t.Kind() != reflect.Struct {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 首先尝试从缓存读取
|
||||
if cached, ok := vh.typeCache.Load(t); ok {
|
||||
info := cached.(*typeInfo)
|
||||
atomic.AddInt64(&info.accessCount, 1)
|
||||
atomic.AddInt64(&vh.cacheStats.hitCount, 1)
|
||||
info.lastAccess = time.Now()
|
||||
return info
|
||||
}
|
||||
|
||||
atomic.AddInt64(&vh.cacheStats.missCount, 1)
|
||||
|
||||
// 从对象池获取typeInfo
|
||||
info := typeInfoPool.Get().(*typeInfo)
|
||||
|
||||
// 重置并初始化
|
||||
info.mu.Lock()
|
||||
|
||||
// 清空现有map
|
||||
for k := range info.fields {
|
||||
delete(info.fields, k)
|
||||
}
|
||||
info.fieldNames = info.fieldNames[:0]
|
||||
|
||||
info.accessCount = 0
|
||||
info.createdAt = time.Now()
|
||||
info.lastAccess = time.Now()
|
||||
info.typeKey = t.String()
|
||||
|
||||
// 预计算所有字段信息
|
||||
for i := 0; i < t.NumField(); i++ {
|
||||
field := t.Field(i)
|
||||
|
||||
// 获取标签
|
||||
//label := field.Tag.Get(LabelComment)
|
||||
//if label == "" {
|
||||
// label = field.Tag.Get(LabelLabel)
|
||||
//}
|
||||
//if label == "" {
|
||||
// label = field.Tag.Get(LabelZh)
|
||||
//}
|
||||
label := field.Tag.Get(LabelZh)
|
||||
// 从对象池获取或创建字段信息
|
||||
fieldInfo := fieldInfoPool.Get().(*fieldInfo)
|
||||
fieldInfo.label = label
|
||||
fieldInfo.index = i
|
||||
fieldInfo.typeKind = field.Type.Kind()
|
||||
fieldInfo.accessCount = 0
|
||||
fieldInfo.lastAccess = time.Now()
|
||||
|
||||
info.fields[field.Name] = fieldInfo
|
||||
info.fieldNames = append(info.fieldNames, field.Name)
|
||||
}
|
||||
|
||||
info.mu.Unlock()
|
||||
|
||||
// 使用原子操作确保线程安全的存储
|
||||
if existing, loaded := vh.typeCache.LoadOrStore(t, info); loaded {
|
||||
// 如果已经有其他goroutine存储了,使用已有的并回收新创建的
|
||||
info.mu.Lock()
|
||||
for _, fieldInfo := range info.fields {
|
||||
fieldInfoPool.Put(fieldInfo)
|
||||
}
|
||||
info.mu.Unlock()
|
||||
typeInfoPool.Put(info)
|
||||
|
||||
existingInfo := existing.(*typeInfo)
|
||||
atomic.AddInt64(&existingInfo.accessCount, 1)
|
||||
return existingInfo
|
||||
}
|
||||
|
||||
return info
|
||||
}
|
||||
|
||||
// 定期清理缓存
|
||||
func (vh *ValidatorHelper) periodicCleanup() {
|
||||
ticker := time.NewTicker(CacheCleanupInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
vh.safeCleanupCache()
|
||||
case <-vh.stopCleanup:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// safeCleanupCache 安全清理缓存
|
||||
func (vh *ValidatorHelper) safeCleanupCache() {
|
||||
vh.pruneLock.Lock()
|
||||
defer vh.pruneLock.Unlock()
|
||||
|
||||
var keysToDelete []interface{}
|
||||
now := time.Now()
|
||||
|
||||
vh.typeCache.Range(func(key, value interface{}) bool {
|
||||
info, ok := value.(*typeInfo)
|
||||
if !ok {
|
||||
return true
|
||||
}
|
||||
|
||||
// 检查是否需要清理
|
||||
accessCount := atomic.LoadInt64(&info.accessCount)
|
||||
age := now.Sub(info.createdAt)
|
||||
idleTime := now.Sub(info.lastAccess)
|
||||
|
||||
// 清理条件:
|
||||
// 1. 很少访问的缓存(访问次数 < 10)
|
||||
// 2. 空闲时间超过30分钟
|
||||
// 3. 缓存年龄超过1小时且访问次数较少
|
||||
if (accessCount < 10 && idleTime > 30*time.Minute) ||
|
||||
(age > 1*time.Hour && accessCount < 100) {
|
||||
keysToDelete = append(keysToDelete, key)
|
||||
atomic.AddInt64(&vh.cacheStats.evictCount, 1)
|
||||
}
|
||||
|
||||
return true
|
||||
})
|
||||
|
||||
// 删除选中的缓存
|
||||
for _, key := range keysToDelete {
|
||||
if val, ok := vh.typeCache.Load(key); ok {
|
||||
if info, ok := val.(*typeInfo); ok {
|
||||
// 安全回收字段信息
|
||||
info.mu.Lock()
|
||||
for _, fieldInfo := range info.fields {
|
||||
fieldInfoPool.Put(fieldInfo)
|
||||
}
|
||||
info.mu.Unlock()
|
||||
typeInfoPool.Put(info)
|
||||
}
|
||||
vh.typeCache.Delete(key)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ==================== 指标收集 ====================
|
||||
|
||||
func (vh *ValidatorHelper) recordValidationTime(start time.Time) {
|
||||
if !vh.config.EnableMetrics {
|
||||
return
|
||||
}
|
||||
duration := time.Since(start).Nanoseconds()
|
||||
atomic.AddInt64(&vh.metrics.validationTime, duration)
|
||||
}
|
||||
|
||||
// GetMetrics 获取性能指标
|
||||
func (vh *ValidatorHelper) GetMetrics() map[string]interface{} {
|
||||
if !vh.config.EnableMetrics {
|
||||
return nil
|
||||
}
|
||||
|
||||
vh.metrics.mu.RLock()
|
||||
defer vh.metrics.mu.RUnlock()
|
||||
|
||||
hitCount := atomic.LoadInt64(&vh.cacheStats.hitCount)
|
||||
missCount := atomic.LoadInt64(&vh.cacheStats.missCount)
|
||||
totalRequests := hitCount + missCount
|
||||
|
||||
var hitRate float64
|
||||
if totalRequests > 0 {
|
||||
hitRate = float64(hitCount) / float64(totalRequests) * 100
|
||||
}
|
||||
|
||||
var cacheSize int64
|
||||
vh.typeCache.Range(func(_, _ interface{}) bool {
|
||||
cacheSize++
|
||||
return true
|
||||
})
|
||||
|
||||
return map[string]interface{}{
|
||||
"cache_hit_count": hitCount,
|
||||
"cache_miss_count": missCount,
|
||||
"cache_evict_count": atomic.LoadInt64(&vh.cacheStats.evictCount),
|
||||
"cache_hit_rate": fmt.Sprintf("%.2f%%", hitRate),
|
||||
"cache_size": cacheSize,
|
||||
"error_count": atomic.LoadInt64(&vh.cacheStats.errorCount),
|
||||
"total_requests": atomic.LoadInt64(&vh.metrics.totalRequests),
|
||||
"avg_validation_time_ns": atomic.LoadInt64(&vh.metrics.validationTime) /
|
||||
max(atomic.LoadInt64(&vh.metrics.totalRequests), 1),
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterValidation 注册自定义验证规则
|
||||
func (vh *ValidatorHelper) RegisterValidation(tag string, fn validator.Func, callValidationEvenIfNull ...bool) error {
|
||||
return vh.validate.RegisterValidation(tag, fn, callValidationEvenIfNull...)
|
||||
}
|
||||
|
||||
// RegisterTranslation 注册自定义翻译
|
||||
func (vh *ValidatorHelper) RegisterTranslation(tag string, template string) {
|
||||
vh.errorFnCacheMu.Lock()
|
||||
defer vh.errorFnCacheMu.Unlock()
|
||||
|
||||
t := template
|
||||
vh.errorFnCache[tag] = func(field, param string) string {
|
||||
if strings.Contains(t, "%s") && strings.Count(t, "%s") == 2 {
|
||||
return fmt.Sprintf(t, field, param)
|
||||
}
|
||||
return fmt.Sprintf(t, field)
|
||||
}
|
||||
}
|
||||
|
||||
// Stop 停止后台清理任务
|
||||
func (vh *ValidatorHelper) Stop() {
|
||||
close(vh.stopCleanup)
|
||||
}
|
||||
|
||||
// Reset 重置验证器状态
|
||||
func (vh *ValidatorHelper) Reset() {
|
||||
vh.ClearCache()
|
||||
atomic.StoreInt64(&vh.cacheStats.hitCount, 0)
|
||||
atomic.StoreInt64(&vh.cacheStats.missCount, 0)
|
||||
atomic.StoreInt64(&vh.cacheStats.evictCount, 0)
|
||||
atomic.StoreInt64(&vh.cacheStats.errorCount, 0)
|
||||
atomic.StoreInt64(&vh.metrics.totalRequests, 0)
|
||||
atomic.StoreInt64(&vh.metrics.validationTime, 0)
|
||||
}
|
||||
|
||||
// ClearCache 清理所有缓存
|
||||
func (vh *ValidatorHelper) ClearCache() {
|
||||
vh.pruneLock.Lock()
|
||||
defer vh.pruneLock.Unlock()
|
||||
|
||||
// 安全回收所有缓存
|
||||
vh.typeCache.Range(func(key, value interface{}) bool {
|
||||
if info, ok := value.(*typeInfo); ok {
|
||||
info.mu.Lock()
|
||||
for _, fieldInfo := range info.fields {
|
||||
fieldInfoPool.Put(fieldInfo)
|
||||
}
|
||||
info.mu.Unlock()
|
||||
typeInfoPool.Put(info)
|
||||
}
|
||||
vh.typeCache.Delete(key)
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
// SetLanguage 设置语言
|
||||
func (vh *ValidatorHelper) SetLanguage(useChinese bool) {
|
||||
vh.config.UseChinese = useChinese
|
||||
|
||||
templates := defaultErrorTemplates
|
||||
if useChinese {
|
||||
templates = ChineseErrorTemplates
|
||||
}
|
||||
|
||||
vh.errorFnCacheMu.Lock()
|
||||
defer vh.errorFnCacheMu.Unlock()
|
||||
|
||||
for tag, tmpl := range templates {
|
||||
t := tmpl
|
||||
vh.errorFnCache[tag] = func(field, param string) string {
|
||||
if strings.Contains(t, "%s") && strings.Count(t, "%s") == 2 {
|
||||
return fmt.Sprintf(t, field, param)
|
||||
}
|
||||
return fmt.Sprintf(t, field)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func max(a, b int64) int64 {
|
||||
if a > b {
|
||||
return a
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
func GetErr(tag, field, param string) string {
|
||||
t := ChineseErrorTemplates[tag]
|
||||
if strings.Contains(t, "%s") && strings.Count(t, "%s") == 2 {
|
||||
return fmt.Sprintf(t, field, param)
|
||||
}
|
||||
return fmt.Sprintf(t, field)
|
||||
}
|
||||
|
|
@ -0,0 +1,195 @@
|
|||
package pkg
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"net/http"
|
||||
|
||||
"errors"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
type WeChatLoginResponse struct {
|
||||
OpenID string `json:"openid"` // 用户唯一标识
|
||||
SessionKey string `json:"session_key"` // 会话密钥
|
||||
UnionID string `json:"unionid"` // 用户在开放平台的唯一标识(如果绑定了开放平台才有)
|
||||
Errcode int `json:"errcode"` // 错误码,0为成功
|
||||
Errmsg string `json:"errmsg"` // 错误信息
|
||||
}
|
||||
|
||||
func GetOpenID(appID, appSecret, code string) (openid string, err error) {
|
||||
if os.Getenv("env") == "unit_test" {
|
||||
return "test_123456", nil
|
||||
}
|
||||
// 1. 构建请求微信接口的 URL
|
||||
url := fmt.Sprintf("https://api.weixin.qq.com/sns/jscode2session?appid=%s&secret=%s&js_code=%s&grant_type=authorization_code",
|
||||
appID, appSecret, code)
|
||||
|
||||
// 2. 创建 HTTP 客户端(设置超时,避免阻塞)
|
||||
client := &http.Client{
|
||||
Timeout: 5 * time.Second,
|
||||
// 在某些网络受限的环境(如本地测试跳过证书验证,生产环境建议去掉)
|
||||
Transport: &http.Transport{
|
||||
TLSClientConfig: &tls.Config{InsecureSkipVerify: false}, // 生产环境建议设为 false
|
||||
},
|
||||
}
|
||||
|
||||
// 3. 发起 GET 请求
|
||||
resp, err := client.Get(url)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("请求微信服务器失败: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// 4. 读取返回的 Body
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("读取微信响应失败: %w", err)
|
||||
}
|
||||
|
||||
// 5. 解析 JSON 数据
|
||||
var wechatResp WeChatLoginResponse
|
||||
err = json.Unmarshal(body, &wechatResp)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("解析微信响应 JSON 失败: %s, 原始数据: %s", err.Error(), string(body))
|
||||
}
|
||||
|
||||
// 6. 检查微信接口返回的错误码
|
||||
if wechatResp.Errcode != 0 {
|
||||
// 这里可以根据不同的错误码做特殊处理,例如 code 无效、过期等
|
||||
return "", fmt.Errorf("微信接口返回错误: code=%d, msg=%s", wechatResp.Errcode, wechatResp.Errmsg)
|
||||
}
|
||||
|
||||
// 7. 检查 OpenID 是否为空(理论上不会,但防御性编程)
|
||||
if wechatResp.OpenID == "" {
|
||||
return "", errors.New("微信返回的 OpenID 为空")
|
||||
}
|
||||
|
||||
// 8. 返回 OpenID
|
||||
return wechatResp.OpenID, nil
|
||||
}
|
||||
|
||||
type AccessTokenResponse struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
Errcode int `json:"errcode"`
|
||||
Errmsg string `json:"errmsg"`
|
||||
}
|
||||
|
||||
var (
|
||||
tokenMutex sync.Mutex
|
||||
cacheKey = "wx:access_token"
|
||||
)
|
||||
|
||||
// GetAccessToken 获取 access_token,带本地缓存
|
||||
func GetAccessToken(ctx context.Context, appID, appSecret string, rdb *redis.Client) (string, error) {
|
||||
if rdb == nil {
|
||||
return "", errors.New("缓存工具未提供")
|
||||
}
|
||||
|
||||
cacheToken := rdb.Get(ctx, cacheKey).Val()
|
||||
if cacheToken != "" {
|
||||
return cacheToken, nil
|
||||
}
|
||||
tokenMutex.Lock()
|
||||
defer tokenMutex.Unlock()
|
||||
|
||||
// 请求微信接口获取新的 access_token
|
||||
url := fmt.Sprintf("https://api.weixin.qq.com/cgi-bin/token?grant_type=client_credential&appid=%s&secret=%s", appID, appSecret)
|
||||
|
||||
client := &http.Client{Timeout: 5 * time.Second}
|
||||
resp, err := client.Get(url)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("请求 access_token 失败: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
var tokenRes AccessTokenResponse
|
||||
if err := json.Unmarshal(body, &tokenRes); err != nil {
|
||||
return "", fmt.Errorf("解析 access_token 响应失败: %w", err)
|
||||
}
|
||||
|
||||
if tokenRes.Errcode != 0 {
|
||||
return "", fmt.Errorf("获取 access_token 失败: code=%d, msg=%s", tokenRes.Errcode, tokenRes.Errmsg)
|
||||
}
|
||||
|
||||
// 缓存 token,提前5分钟过期,避免边界情况
|
||||
|
||||
rdb.Set(ctx, cacheKey, tokenRes.AccessToken, time.Duration(tokenRes.ExpiresIn-300)*time.Second)
|
||||
return tokenRes.AccessToken, nil
|
||||
}
|
||||
|
||||
// PhoneInfo 定义手机号信息的结构体,与微信官方文档对齐 [citation:3][citation:8]
|
||||
type PhoneInfo struct {
|
||||
PhoneNumber string `json:"phoneNumber"` // 用户绑定的手机号(国外手机号会有区号)
|
||||
PurePhoneNumber string `json:"purePhoneNumber"` // 没有区号的手机号
|
||||
CountryCode string `json:"countryCode"` // 区号
|
||||
Watermark struct {
|
||||
Timestamp int64 `json:"timestamp"`
|
||||
Appid string `json:"appid"`
|
||||
} `json:"watermark"`
|
||||
}
|
||||
|
||||
// PhoneInfoResponse 定义微信接口返回的完整结构
|
||||
type PhoneInfoResponse struct {
|
||||
Errcode int `json:"errcode"`
|
||||
Errmsg string `json:"errmsg"`
|
||||
PhoneInfo PhoneInfo `json:"phone_info"`
|
||||
}
|
||||
|
||||
// GetPhoneNumber 通过手机号 code 获取用户手机号
|
||||
// 参数:
|
||||
// - appID: 小程序的 AppID
|
||||
// - appSecret: 小程序的 AppSecret
|
||||
// - phoneCode: 前端通过 getPhoneNumber 获取的 code
|
||||
//
|
||||
// 返回:
|
||||
// - *PhoneInfo: 手机号信息
|
||||
// - error: 错误信息
|
||||
func GetPhoneNumber(ctx context.Context, appID, appSecret, phoneCode string, rdb *redis.Client) (*PhoneInfo, error) {
|
||||
// 1. 获取 access_token
|
||||
accessToken, err := GetAccessToken(ctx, appID, appSecret, rdb)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取 access_token 失败: %w", err)
|
||||
}
|
||||
|
||||
// 2. 调用微信接口换取手机号 [citation:8]
|
||||
url := fmt.Sprintf("https://api.weixin.qq.com/wxa/business/getuserphonenumber?access_token=%s", accessToken)
|
||||
|
||||
// 构建请求体
|
||||
requestBody := map[string]string{
|
||||
"code": phoneCode,
|
||||
}
|
||||
jsonBody, _ := json.Marshal(requestBody)
|
||||
|
||||
client := &http.Client{Timeout: 5 * time.Second}
|
||||
resp, err := client.Post(url, "application/json", bytes.NewReader(jsonBody))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("请求手机号接口失败: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
|
||||
var phoneResp PhoneInfoResponse
|
||||
if err := json.Unmarshal(body, &phoneResp); err != nil {
|
||||
return nil, fmt.Errorf("解析手机号响应失败: %s", string(body))
|
||||
}
|
||||
|
||||
// 3. 检查微信接口返回的错误码 [citation:8]
|
||||
if phoneResp.Errcode != 0 {
|
||||
return nil, fmt.Errorf("微信接口返回错误: code=%d, msg=%s", phoneResp.Errcode, phoneResp.Errmsg)
|
||||
}
|
||||
|
||||
return &phoneResp.PhoneInfo, nil
|
||||
}
|
||||
|
|
@ -0,0 +1,289 @@
|
|||
package dataTemp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"geo/tmpl/errcode"
|
||||
"geo/utils"
|
||||
"reflect"
|
||||
|
||||
"github.com/go-kratos/kratos/v2/log"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"xorm.io/builder"
|
||||
)
|
||||
|
||||
type PrimaryKey struct {
|
||||
Id int `json:"id"`
|
||||
}
|
||||
|
||||
type GormDb struct {
|
||||
Client *gorm.DB
|
||||
}
|
||||
type contextTxKey struct{}
|
||||
|
||||
func (d *Db) DB(ctx context.Context) *gorm.DB {
|
||||
tx, ok := ctx.Value(contextTxKey{}).(*gorm.DB)
|
||||
if ok {
|
||||
return tx
|
||||
}
|
||||
return d.Db.Client
|
||||
}
|
||||
|
||||
func (t *Db) ExecTx(ctx context.Context, f func(ctx context.Context) error) error {
|
||||
return t.Db.Client.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||||
ctx = context.WithValue(ctx, contextTxKey{}, tx)
|
||||
return f(ctx)
|
||||
})
|
||||
}
|
||||
|
||||
type Db struct {
|
||||
Db *GormDb
|
||||
Log *log.Helper
|
||||
}
|
||||
|
||||
type DataTemp struct {
|
||||
Db *gorm.DB
|
||||
ModelType reflect.Type // 改为存储类型而不是实例
|
||||
modelName string // 可选的表名缓存
|
||||
}
|
||||
|
||||
func NewDataTemp(db *utils.Db, model interface{}) *DataTemp {
|
||||
// 获取模型的类型
|
||||
t := reflect.TypeOf(model)
|
||||
if t.Kind() == reflect.Ptr {
|
||||
t = t.Elem()
|
||||
}
|
||||
|
||||
return &DataTemp{
|
||||
Db: db.Client,
|
||||
ModelType: t,
|
||||
}
|
||||
}
|
||||
|
||||
func (k DataTemp) modelInstance() interface{} {
|
||||
return reflect.New(k.ModelType).Interface()
|
||||
}
|
||||
|
||||
func (k DataTemp) GetById(id int32) (data map[string]interface{}, err error) {
|
||||
err = k.Db.Model(k.modelInstance()).Where("id = ?", id).Find(&data).Error
|
||||
if data == nil {
|
||||
err = sql.ErrNoRows
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (k DataTemp) GetByStruct(ctx context.Context, search interface{}, data interface{}, orderBy string) (err error) {
|
||||
|
||||
err = k.Db.Model(k.modelInstance()).WithContext(ctx).Where(search).Find(&data).Error
|
||||
|
||||
if data == nil {
|
||||
err = sql.ErrNoRows
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (k DataTemp) SaveByStruct(search interface{}, data interface{}) (err error) {
|
||||
err = k.Db.Model(k.modelInstance()).Where(search).Save(&data).Error
|
||||
if data == nil {
|
||||
err = sql.ErrNoRows
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (k DataTemp) Add(ctx context.Context, data interface{}) (err error) {
|
||||
m := k.modelInstance()
|
||||
if err = k.Db.Model(m).WithContext(ctx).Create(data).Error; err != nil {
|
||||
return errcode.SqlErr(err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (k DataTemp) AddWithData(data interface{}) (interface{}, error) {
|
||||
result := k.Db.Model(k.modelInstance()).Create(data)
|
||||
if result.Error != nil {
|
||||
return data, result.Error
|
||||
}
|
||||
return data, nil
|
||||
}
|
||||
|
||||
func (k DataTemp) GetList(cond *builder.Cond, pageBoIn *ReqPageBo) (list []map[string]interface{}, pageBoOut *RespPageBo, err error) {
|
||||
var (
|
||||
query, _ = builder.ToBoundSQL(*cond)
|
||||
model = k.Db.Model(k.modelInstance()).Where(query)
|
||||
total int64
|
||||
)
|
||||
model.Count(&total)
|
||||
pageBoOut = pageBoOut.SetDataByReq(total, pageBoIn)
|
||||
model.Limit(pageBoIn.GetSize()).Offset(pageBoIn.GetOffset()).Order("updated_at desc").Find(&list)
|
||||
return
|
||||
}
|
||||
|
||||
func (k DataTemp) GetRange(ctx context.Context, cond *builder.Cond) (list []map[string]interface{}, err error) {
|
||||
var (
|
||||
query, _ = builder.ToBoundSQL(*cond)
|
||||
model = k.Db.Model(k.modelInstance()).Where(query)
|
||||
)
|
||||
err = model.WithContext(ctx).Find(&list).Error
|
||||
return list, err
|
||||
}
|
||||
|
||||
func (k DataTemp) GetRangeToMapStruct(ctx context.Context, cond *builder.Cond, data interface{}) (err error) {
|
||||
var (
|
||||
query, _ = builder.ToBoundSQL(*cond)
|
||||
model = k.Db.Model(k.modelInstance()).Where(query)
|
||||
)
|
||||
err = model.WithContext(ctx).Find(data).Error
|
||||
return err
|
||||
}
|
||||
|
||||
func (k DataTemp) GetOneBySearch(cond *builder.Cond) (data map[string]interface{}, err error) {
|
||||
query, _ := builder.ToBoundSQL(*cond)
|
||||
if err = k.Db.Model(k.modelInstance()).Where(query).Limit(1).Find(&data).Error; err != nil {
|
||||
return nil, errcode.SqlErr(err)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (k DataTemp) Exist(ctx context.Context, cond *builder.Cond) (bool, error) {
|
||||
var data map[string]interface{}
|
||||
query, _ := builder.ToBoundSQL(*cond)
|
||||
err := k.Db.WithContext(ctx).Model(k.modelInstance()).Where(query).Limit(1).Find(&data).Error
|
||||
if err != nil || data != nil {
|
||||
return true, err
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (k DataTemp) GetOneBySearchStruct(ctx context.Context, cond *builder.Cond, data interface{}) (err error) {
|
||||
query, _ := builder.ToBoundSQL(*cond)
|
||||
if err = k.Db.Model(k.modelInstance()).WithContext(ctx).Where(query).Limit(1).Find(&data).Error; err != nil {
|
||||
return errcode.SqlErr(err)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (k DataTemp) GetListToStruct(ctx context.Context, cond *builder.Cond, pageBoIn *ReqPageBo, result interface{}, orderBy string) (pageBoOut *RespPageBo, err error) {
|
||||
// 参数验证
|
||||
if result == nil {
|
||||
return nil, fmt.Errorf("result cannot be nil")
|
||||
}
|
||||
|
||||
val := reflect.ValueOf(result)
|
||||
if val.Kind() != reflect.Ptr {
|
||||
return nil, fmt.Errorf("result must be a pointer")
|
||||
}
|
||||
|
||||
elem := val.Elem()
|
||||
if elem.Kind() != reflect.Slice {
|
||||
return nil, fmt.Errorf("result must be a pointer to slice")
|
||||
}
|
||||
|
||||
// 构建基础查询
|
||||
query, _ := builder.ToBoundSQL(*cond)
|
||||
|
||||
// 预编译 SQL 以提高性能
|
||||
// 使用 Table 指定表名,避免 GORM 的反射开销
|
||||
|
||||
db := k.Db.WithContext(ctx).Model(k.modelInstance()).Where(query)
|
||||
|
||||
// 获取总数(使用单独的计数查询,避免缓存影响)
|
||||
var total int64
|
||||
countDb := db
|
||||
if pageBoIn != nil {
|
||||
if err = countDb.Count(&total).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
// 初始化分页响应
|
||||
pageBoOut = &RespPageBo{}
|
||||
pageBoOut = pageBoOut.SetDataByReq(total, pageBoIn)
|
||||
|
||||
// 如果没有数据,直接返回空切片
|
||||
if total == 0 && pageBoIn != nil {
|
||||
elem.Set(reflect.MakeSlice(elem.Type(), 0, 0))
|
||||
return pageBoOut, nil
|
||||
}
|
||||
|
||||
// 设置排序(使用索引字段提高性能)
|
||||
if orderBy == "" {
|
||||
orderBy = "updated_at desc"
|
||||
}
|
||||
|
||||
// 应用分页和排序,执行查询
|
||||
// 使用 Select 指定字段,避免查询所有字段(如果需要优化)
|
||||
baseQuery := db
|
||||
if pageBoIn != nil {
|
||||
baseQuery = db.Limit(pageBoIn.GetSize()).Offset(pageBoIn.GetOffset()).Order(orderBy)
|
||||
}
|
||||
if err = baseQuery.
|
||||
Order(orderBy).
|
||||
Find(result).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return pageBoOut, nil
|
||||
}
|
||||
|
||||
func (k DataTemp) UpdateByKey(ctx context.Context, key string, id interface{}, data interface{}) (err error) {
|
||||
if err = k.Db.WithContext(ctx).Model(k.modelInstance()).Where(fmt.Sprintf("%s = ?", key), id).Updates(data).Error; err != nil {
|
||||
return errcode.SqlErr(err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (k DataTemp) UpdateByCond(ctx context.Context, cond *builder.Cond, data interface{}) (err error) {
|
||||
var (
|
||||
query, _ = builder.ToBoundSQL(*cond)
|
||||
model = k.Db.Model(k.modelInstance()).Where(query)
|
||||
)
|
||||
err = model.WithContext(ctx).Updates(data).Error
|
||||
return err
|
||||
}
|
||||
|
||||
func (k DataTemp) UpdateColumnByCond(ctx context.Context, cond *builder.Cond, column string, data interface{}) (err error) {
|
||||
var (
|
||||
query, _ = builder.ToBoundSQL(*cond)
|
||||
model = k.Db.Model(k.modelInstance()).Where(query)
|
||||
)
|
||||
err = model.WithContext(ctx).Update(column, data).Error
|
||||
return err
|
||||
}
|
||||
|
||||
func (k DataTemp) GetByKey(ctx context.Context, key string, value interface{}, data interface{}) (err error) {
|
||||
if err = k.Db.WithContext(ctx).Model(k.modelInstance()).Where(fmt.Sprintf("%s = ?", key), value).Find(data).Error; err != nil {
|
||||
return errcode.SqlErr(err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (k DataTemp) DeleteByKey(ctx context.Context, key string, value int64) error {
|
||||
result := k.Db.WithContext(ctx).Model(k.modelInstance()).Where(fmt.Sprintf("%s = ?", key), value).
|
||||
Update("deleted_at", gorm.Expr("CURRENT_TIMESTAMP"))
|
||||
|
||||
if result.Error != nil {
|
||||
return result.Error
|
||||
}
|
||||
if result.RowsAffected == 0 {
|
||||
return errcode.NotFound("不存在或已被删除")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (k DataTemp) CountByCond(ctx context.Context, cond *builder.Cond) (int64, error) {
|
||||
var (
|
||||
count int64
|
||||
query, _ = builder.ToBoundSQL(*cond)
|
||||
model = k.Db.Model(k.modelInstance()).Where(query)
|
||||
)
|
||||
err := model.WithContext(ctx).Count(&count).Error
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return count, err
|
||||
}
|
||||
|
|
@ -0,0 +1,33 @@
|
|||
package dataTemp
|
||||
|
||||
// ReqPageBo 分页请求实体
|
||||
type ReqPageBo struct {
|
||||
Page int //页码,从第1页开始
|
||||
Limit int //分页大小
|
||||
}
|
||||
|
||||
// GetOffset 获取便宜量
|
||||
// 确保 dataTemp/page.go 中有这些方法
|
||||
func (p *ReqPageBo) GetSize() int {
|
||||
if p == nil {
|
||||
return 10 // 默认每页10条
|
||||
}
|
||||
if p.Limit <= 0 {
|
||||
return 10
|
||||
}
|
||||
return p.Limit
|
||||
}
|
||||
|
||||
func (p *ReqPageBo) GetOffset() int {
|
||||
if p == nil {
|
||||
return 0
|
||||
}
|
||||
return (p.GetPage() - 1) * p.GetSize()
|
||||
}
|
||||
|
||||
func (p *ReqPageBo) GetPage() int {
|
||||
if p == nil || p.Page <= 0 {
|
||||
return 1
|
||||
}
|
||||
return p.Page
|
||||
}
|
||||
|
|
@ -0,0 +1,22 @@
|
|||
package dataTemp
|
||||
|
||||
// RespPageBo 分页响应实体
|
||||
type RespPageBo struct {
|
||||
Page int //页码
|
||||
Limit int //每页大小
|
||||
Total int64 //总数
|
||||
}
|
||||
|
||||
// SetDataByReq 通过req 设置响应参数
|
||||
func (r *RespPageBo) SetDataByReq(total int64, reqPage *ReqPageBo) *RespPageBo {
|
||||
resp := r
|
||||
if r == nil {
|
||||
resp = &RespPageBo{}
|
||||
}
|
||||
resp.Total = total
|
||||
if reqPage != nil {
|
||||
resp.Page = reqPage.Page
|
||||
resp.Limit = reqPage.Limit
|
||||
}
|
||||
return resp
|
||||
}
|
||||
|
|
@ -0,0 +1,113 @@
|
|||
package errcode
|
||||
|
||||
import "fmt"
|
||||
|
||||
var (
|
||||
AuthNotFound = &BusinessErr{code: AuthErr, message: "账号不存在"}
|
||||
AuthStatusFreeze = &BusinessErr{code: AuthErr, message: "账号冻结"}
|
||||
AuthStatusDel = &BusinessErr{code: AuthErr, message: "身份验证失败"}
|
||||
AuthStatusPwdFail = &BusinessErr{code: AuthErr, message: "密码错误"}
|
||||
AuthTokenCreateFail = &BusinessErr{code: AuthErr, message: "token生成失败"}
|
||||
AuthTokenDelFail = &BusinessErr{code: AuthErr, message: "删除token失败"}
|
||||
AuthWxLoginFail = &BusinessErr{code: AuthErr, message: "微信登录失败,请稍后重试"}
|
||||
AuthInfoFail = &BusinessErr{code: AuthErr, message: "登录异常,请重新登录"}
|
||||
|
||||
TokenNotFound = &BusinessErr{code: TokenErr, message: "缺少 Authorization Header"}
|
||||
TokenFormatErr = &BusinessErr{code: TokenErr, message: "无效的token格式"}
|
||||
TokenInfoNotFound = &BusinessErr{code: TokenErr, message: "未找到用户信息"}
|
||||
TokenInvalid = &BusinessErr{code: TokenErr, message: "token过期"}
|
||||
|
||||
PlatsNotFound = &BusinessErr{code: NotFoundErr, message: "信息未找到"}
|
||||
BadRequest = &BusinessErr{code: BadReqErr, message: "操作失败"}
|
||||
|
||||
ForbiddenError = &BusinessErr{code: ForbiddenErr, message: "权限不足"}
|
||||
Success = &BusinessErr{code: 200, message: "成功"}
|
||||
ParamError = &BusinessErr{code: ParamsErr, message: "参数错误"}
|
||||
|
||||
SystemError = &BusinessErr{code: 405, message: "系统错误"}
|
||||
|
||||
ClientNotFound = &BusinessErr{code: 406, message: "未找到client_id"}
|
||||
SessionNotFound = &BusinessErr{code: 407, message: "未找到会话信息"}
|
||||
UserNotFound = &BusinessErr{code: NotFoundErr, message: "不存在的用户"}
|
||||
|
||||
KeyNotFound = &BusinessErr{code: 409, message: "身份验证失败"}
|
||||
SysNotFound = &BusinessErr{code: 410, message: "未找到系统信息"}
|
||||
SysCodeNotFound = &BusinessErr{code: 411, message: "未找到系统编码"}
|
||||
InvalidParam = &BusinessErr{code: InvalidParamCode, message: "无效参数"}
|
||||
WorkflowError = &BusinessErr{code: 501, message: "工作流过程错误"}
|
||||
ClientInfoNotFound = &BusinessErr{code: NotFoundErr, message: "用户信息未找到"}
|
||||
)
|
||||
|
||||
const (
|
||||
InvalidParamCode = 408
|
||||
AuthErr = 403
|
||||
TokenErr = 401
|
||||
ParamsErr = 422
|
||||
BadReqErr = 400
|
||||
NotFoundErr = 404
|
||||
ForbiddenErr = 403
|
||||
BalanceNotEnoughCode = 402
|
||||
)
|
||||
|
||||
type BusinessErr struct {
|
||||
code int
|
||||
message string
|
||||
}
|
||||
|
||||
func (e *BusinessErr) Error() string {
|
||||
return e.message
|
||||
}
|
||||
func (e *BusinessErr) Code() int {
|
||||
return e.code
|
||||
}
|
||||
|
||||
func NotFound(message string) *BusinessErr {
|
||||
return &BusinessErr{code: NotFoundErr, message: PlatsNotFound.message + ":" + message}
|
||||
}
|
||||
|
||||
func (e *BusinessErr) Is(target error) bool {
|
||||
_, ok := target.(*BusinessErr)
|
||||
return ok
|
||||
}
|
||||
|
||||
// CustomErr 自定义错误
|
||||
func NewBusinessErr(code int, message string) *BusinessErr {
|
||||
return &BusinessErr{code: code, message: message}
|
||||
}
|
||||
|
||||
func SysErrf(message string, arg ...any) *BusinessErr {
|
||||
return &BusinessErr{code: SystemError.code, message: fmt.Sprintf(message, arg)}
|
||||
}
|
||||
|
||||
func SysErr(message string) *BusinessErr {
|
||||
return &BusinessErr{code: SystemError.code, message: message}
|
||||
}
|
||||
|
||||
func ParamErrf(message string, arg ...any) *BusinessErr {
|
||||
return &BusinessErr{code: ParamError.code, message: fmt.Sprintf(message, arg)}
|
||||
}
|
||||
|
||||
func ParamErr(message string) *BusinessErr {
|
||||
return &BusinessErr{code: ParamError.code, message: ParamError.message + ":" + message}
|
||||
}
|
||||
|
||||
func SqlErr(err error) *BusinessErr {
|
||||
|
||||
return &BusinessErr{code: ParamError.code, message: "数据操作失败,请联系管理员处理:" + err.Error()}
|
||||
}
|
||||
|
||||
func BadReq(message string) *BusinessErr {
|
||||
return &BusinessErr{code: BadReqErr, message: BadRequest.message + ":" + message}
|
||||
}
|
||||
|
||||
func Forbidden(message string) *BusinessErr {
|
||||
return &BusinessErr{code: ForbiddenErr, message: ForbiddenError.message + ":" + message}
|
||||
}
|
||||
|
||||
func (e *BusinessErr) Wrap(err error) *BusinessErr {
|
||||
return NewBusinessErr(e.code, err.Error())
|
||||
}
|
||||
|
||||
func BalanceNotEnoughErr(message string) *BusinessErr {
|
||||
return NewBusinessErr(BalanceNotEnoughCode, message)
|
||||
}
|
||||
|
|
@ -0,0 +1,133 @@
|
|||
package utils
|
||||
|
||||
import (
|
||||
"geo/internal/config"
|
||||
"geo/utils/utils_gorm"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type Db struct {
|
||||
Client *gorm.DB
|
||||
}
|
||||
|
||||
func NewGormDb(c *config.Config) (*Db, func()) {
|
||||
transDBClient, mf := utils_gorm.DBConn(&c.DB)
|
||||
//directDBClient, df := directDB(c, hLog)
|
||||
cleanup := func() {
|
||||
mf()
|
||||
//df()
|
||||
}
|
||||
return &Db{
|
||||
Client: transDBClient,
|
||||
//DirectDBClient: directDBClient,
|
||||
}, cleanup
|
||||
}
|
||||
|
||||
// GetOne 查询单条记录,返回 map
|
||||
func (d *Db) GetOne(sql string, args ...interface{}) (map[string]interface{}, error) {
|
||||
var result map[string]interface{}
|
||||
|
||||
// 使用 Raw 执行原生 SQL,Scan 到 map 需要先获取 rows
|
||||
rows, err := d.Client.Raw(sql, args...).Rows()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
if rows.Next() {
|
||||
// 获取列名
|
||||
columns, err := rows.Columns()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 创建扫描用的切片
|
||||
values := make([]interface{}, len(columns))
|
||||
valuePtrs := make([]interface{}, len(columns))
|
||||
for i := range values {
|
||||
valuePtrs[i] = &values[i]
|
||||
}
|
||||
|
||||
if err := rows.Scan(valuePtrs...); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result = make(map[string]interface{})
|
||||
for i, col := range columns {
|
||||
result[col] = values[i]
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// GetAll 查询多条记录,返回 map 切片
|
||||
func (d *Db) GetAll(sql string, args ...interface{}) ([]map[string]interface{}, error) {
|
||||
rows, err := d.Client.Raw(sql, args...).Rows()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
results := make([]map[string]interface{}, 0)
|
||||
|
||||
for rows.Next() {
|
||||
columns, err := rows.Columns()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
values := make([]interface{}, len(columns))
|
||||
valuePtrs := make([]interface{}, len(columns))
|
||||
for i := range values {
|
||||
valuePtrs[i] = &values[i]
|
||||
}
|
||||
|
||||
if err := rows.Scan(valuePtrs...); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
row := make(map[string]interface{})
|
||||
for i, col := range columns {
|
||||
row[col] = values[i]
|
||||
}
|
||||
results = append(results, row)
|
||||
}
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// Execute 执行单条 SQL(INSERT/UPDATE/DELETE),返回影响行数
|
||||
func (d *Db) Execute(sql string, args ...interface{}) (int64, error) {
|
||||
result := d.Client.Exec(sql, args...)
|
||||
if result.Error != nil {
|
||||
return 0, result.Error
|
||||
}
|
||||
return result.RowsAffected, nil
|
||||
}
|
||||
|
||||
// ExecuteMany 批量执行 SQL,使用事务
|
||||
func (d *Db) ExecuteMany(sql string, argsList [][]interface{}) (int64, error) {
|
||||
var total int64
|
||||
|
||||
// 开始事务
|
||||
tx := d.Client.Begin()
|
||||
if tx.Error != nil {
|
||||
return 0, tx.Error
|
||||
}
|
||||
|
||||
for _, args := range argsList {
|
||||
result := tx.Exec(sql, args...)
|
||||
if result.Error != nil {
|
||||
tx.Rollback()
|
||||
return 0, result.Error
|
||||
}
|
||||
total += result.RowsAffected
|
||||
}
|
||||
|
||||
// 提交事务
|
||||
if err := tx.Commit().Error; err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return total, nil
|
||||
}
|
||||
|
|
@ -0,0 +1,9 @@
|
|||
package utils
|
||||
|
||||
import (
|
||||
"github.com/google/wire"
|
||||
)
|
||||
|
||||
var ProviderUtils = wire.NewSet(
|
||||
NewGormDb,
|
||||
)
|
||||
|
|
@ -0,0 +1,41 @@
|
|||
package utils_gorm
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"geo/internal/config"
|
||||
"gorm.io/driver/mysql"
|
||||
"gorm.io/gorm"
|
||||
"time"
|
||||
)
|
||||
|
||||
func DBConn(c *config.DB) (*gorm.DB, func()) {
|
||||
mysqlConn, err := sql.Open(c.Driver, c.Source)
|
||||
gormDB, err := gorm.Open(
|
||||
mysql.New(mysql.Config{Conn: mysqlConn}),
|
||||
)
|
||||
|
||||
gormDB.Logger = NewCustomLogger(gormDB)
|
||||
if err != nil {
|
||||
panic("failed to connect database")
|
||||
}
|
||||
sqlDB, err := gormDB.DB()
|
||||
|
||||
// SetMaxIdleConns sets the maximum number of connections in the idle connection pool.
|
||||
sqlDB.SetMaxIdleConns(int(c.MaxIdle))
|
||||
|
||||
// SetMaxOpenConns sets the maximum number of open connections to the database.
|
||||
sqlDB.SetMaxOpenConns(int(c.MaxLifetime))
|
||||
|
||||
// SetConnMaxLifetime sets the maximum amount of time a connection may be reused.
|
||||
sqlDB.SetConnMaxLifetime(time.Hour)
|
||||
|
||||
return gormDB, func() {
|
||||
if mysqlConn != nil {
|
||||
fmt.Println("关闭 physicalGoodsDB")
|
||||
if err := mysqlConn.Close(); err != nil {
|
||||
fmt.Println("关闭 physicalGoodsDB 失败:", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,96 @@
|
|||
package utils_gorm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/logger"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
type CustomLogger struct {
|
||||
gormLogger logger.Interface
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
func NewCustomLogger(db *gorm.DB) *CustomLogger {
|
||||
return &CustomLogger{
|
||||
gormLogger: logger.Default.LogMode(logger.Info),
|
||||
db: db,
|
||||
}
|
||||
}
|
||||
|
||||
func (l *CustomLogger) LogMode(level logger.LogLevel) logger.Interface {
|
||||
newlogger := *l
|
||||
newlogger.gormLogger = l.gormLogger.LogMode(level)
|
||||
return &newlogger
|
||||
}
|
||||
|
||||
func (l *CustomLogger) Info(ctx context.Context, msg string, data ...interface{}) {
|
||||
l.gormLogger.Info(ctx, msg, data...)
|
||||
}
|
||||
|
||||
func (l *CustomLogger) Warn(ctx context.Context, msg string, data ...interface{}) {
|
||||
l.gormLogger.Warn(ctx, msg, data...)
|
||||
}
|
||||
|
||||
func (l *CustomLogger) Error(ctx context.Context, msg string, data ...interface{}) {
|
||||
l.gormLogger.Error(ctx, msg, data...)
|
||||
}
|
||||
|
||||
func (l *CustomLogger) Trace(ctx context.Context, begin time.Time, fc func() (sql string, rowsAffected int64), err error) {
|
||||
elapsed := time.Since(begin)
|
||||
sql, _ := fc()
|
||||
l.gormLogger.Trace(ctx, begin, fc, err)
|
||||
operation := extractOperation(sql)
|
||||
tableName := extractTableName(sql)
|
||||
fmt.Println(tableName)
|
||||
//// 将SQL语句保存到数据库
|
||||
if operation == 0 || tableName == "sql_log" {
|
||||
return
|
||||
}
|
||||
//go l.db.Model(&SqlLog{}).Create(&SqlLog{
|
||||
// OperatorID: 1,
|
||||
// OperatorName: "test",
|
||||
// SqlInfo: sql,
|
||||
// TableNames: tableName,
|
||||
// Type: operation,
|
||||
//})
|
||||
|
||||
// 如果有需要,也可以根据执行时间(elapsed)等条件过滤或处理日志记录
|
||||
if elapsed > time.Second {
|
||||
//l.gormLogger.Warn(ctx, "Slow SQL (> 1s): %s", sql)
|
||||
}
|
||||
}
|
||||
|
||||
// extractTableName extracts the table name from a SQL query, supporting quoted table names.
|
||||
func extractTableName(sql string) string {
|
||||
// 使用非捕获组匹配多种SQL操作关键词
|
||||
re := regexp.MustCompile(`(?i)\b(?:from|update|into|delete\s+from)\b\s+[\` + "`" + `"]?(\w+)[\` + "`" + `"]?`)
|
||||
match := re.FindStringSubmatch(sql)
|
||||
|
||||
// 检查是否匹配成功
|
||||
if len(match) > 1 {
|
||||
return match[1]
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
// extractOperation extracts the operation type from a SQL query.
|
||||
func extractOperation(sql string) int32 {
|
||||
sql = strings.TrimSpace(strings.ToLower(sql))
|
||||
var operation int32
|
||||
if strings.HasPrefix(sql, "select") {
|
||||
operation = 0
|
||||
} else if strings.HasPrefix(sql, "insert") {
|
||||
operation = 1
|
||||
} else if strings.HasPrefix(sql, "update") {
|
||||
operation = 3
|
||||
} else if strings.HasPrefix(sql, "delete") {
|
||||
operation = 2
|
||||
}
|
||||
return operation
|
||||
}
|
||||
Loading…
Reference in New Issue