This commit is contained in:
renzhiyuan 2026-04-07 21:55:25 +08:00
commit 83f7d34f3e
66 changed files with 12102 additions and 0 deletions

8
.idea/.gitignore vendored Normal file
View File

@ -0,0 +1,8 @@
# Default ignored files
/shelf/
/workspace.xml
# Editor-based HTTP Client requests
/httpRequests/
# Datasource local storage ignored files
/dataSources/
/dataSources.local.xml

9
.idea/geogo.iml Normal file
View File

@ -0,0 +1,9 @@
<?xml version="1.0" encoding="UTF-8"?>
<module type="WEB_MODULE" version="4">
<component name="Go" enabled="true" />
<component name="NewModuleRootManager">
<content url="file://$MODULE_DIR$" />
<orderEntry type="inheritedJdk" />
<orderEntry type="sourceFolder" forTests="false" />
</component>
</module>

8
.idea/modules.xml Normal file
View File

@ -0,0 +1,8 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="ProjectModuleManager">
<modules>
<module fileurl="file://$PROJECT_DIR$/.idea/geogo.iml" filepath="$PROJECT_DIR$/.idea/geogo.iml" />
</modules>
</component>
</project>

6
.idea/vcs.xml Normal file
View File

@ -0,0 +1,6 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="VcsDirectoryMappings">
<mapping directory="$PROJECT_DIR$" vcs="Git" />
</component>
</project>

30
Makefile Normal file
View File

@ -0,0 +1,30 @@
GOHOSTOS:=$(shell go env GOHOSTOS)
GOPATH:=$(shell go env GOPATH)
VERSION=$(shell git describe --tags --always)
PROJECTNAME=yumchina
ifeq ($(GOHOSTOS), windows)
#the `find.exe` is different from `find` in bash/shell.
#to see https://docs.microsoft.com/en-us/windows-server/administration/windows-commands/find.
#changed to use git-bash.exe to run find cli or other cli friendly, caused of every developer has a Git.
#Git_Bash= $(subst cmd\,bin\bash.exe,$(dir $(shell where git)))
Git_Bash=$(subst \,/,$(subst cmd\,bin\bash.exe,$(dir $(shell where git | grep cmd))))
INTERNAL_PROTO_FILES=$(shell $(Git_Bash) -c "find internal -name *.proto")
API_PROTO_FILES=$(shell $(Git_Bash) -c "find api -name *.proto")
else
INTERNAL_PROTO_FILES=$(shell find internal -name *.proto)
API_PROTO_FILES=$(shell find api -name *.proto)
endif
.PHONY: wire
# generate wire
wire:
cd ./cmd/server && wire
.PHONY: build
# build
build:
# make config;
make wire;
mkdir -p bin/ && go build -ldflags "-X main.Version=$(VERSION)" -o ./bin/ ./...

26
cmd/server/main.go Normal file
View File

@ -0,0 +1,26 @@
package main
import (
"fmt"
"geo/internal/config"
"github.com/gofiber/fiber/v2/log"
)
func main() {
bc, err := config.LoadConfig()
if err != nil {
log.Fatalf("加载配置失败: %v", err)
}
//ctx, cancel := context.WithCancel(nil)
app, cleanup, err := InitializeApp(bc, log.DefaultLogger())
if err != nil {
log.Fatalf("项目初始化失败: %v", err)
}
defer func() {
cleanup()
//cancel()
}()
addr := fmt.Sprintf("0.0.0.0:%d", bc.Server.Port)
log.Fatal(app.HttpServer.Listen(addr))
}

28
cmd/server/wire.go Normal file
View File

@ -0,0 +1,28 @@
//go:build wireinject
// +build wireinject
package main
import (
"geo/internal/biz"
"geo/internal/config"
"geo/internal/data/impl"
"geo/internal/server"
"geo/internal/service"
"geo/utils"
"github.com/gofiber/fiber/v2/log"
"github.com/google/wire"
)
// InitializeApp 初始化应用程序
func InitializeApp(*config.Config, log.AllLogger) (*server.Servers, func(), error) {
panic(wire.Build(
server.ProviderSetServer,
service.ProviderSetAppService,
impl.ProviderImpl,
utils.ProviderUtils,
biz.ProviderSetBiz,
))
}

41
cmd/server/wire_gen.go Normal file
View File

@ -0,0 +1,41 @@
// Code generated by Wire. DO NOT EDIT.
//go:generate go run -mod=mod github.com/google/wire/cmd/wire
//go:build !wireinject
// +build !wireinject
package main
import (
"geo/internal/biz"
"geo/internal/config"
"geo/internal/data/impl"
"geo/internal/server"
"geo/internal/server/router"
"geo/internal/service"
"geo/utils"
"github.com/gofiber/fiber/v2/log"
)
// Injectors from wire.go:
// InitializeApp 初始化应用程序
func InitializeApp(configConfig *config.Config, allLogger log.AllLogger) (*server.Servers, func(), error) {
db, cleanup := utils.NewGormDb(configConfig)
tokenImpl := impl.NewTokenImpl(db)
userImpl := impl.NewUserImpl(db)
platImpl := impl.NewPlatImpl(db)
publishImpl := impl.NewPublishImpl(db)
loginRelationImpl := impl.NewLoginRelationImpl(db)
publishBiz := biz.NewPublishBiz(configConfig, publishImpl, userImpl, platImpl, tokenImpl, loginRelationImpl)
appService := service.NewAppService(configConfig, tokenImpl, userImpl, platImpl, publishBiz)
loginService := service.NewLoginService(configConfig, publishBiz)
publishService := service.NewPublishService(configConfig, publishBiz, db)
appModule := router.NewAppModule(configConfig, appService, loginService, publishService)
routerServer := router.NewRouterServer(appModule)
app := server.NewHTTPServer(routerServer)
servers := server.NewServers(configConfig, app)
return servers, func() {
cleanup()
}, nil
}

View File

@ -0,0 +1 @@
[]

31
gen.bat Normal file
View File

@ -0,0 +1,31 @@
@echo off
chcp 65001 >nul
setlocal enabledelayedexpansion
REM Usage: gen.bat user
REM Example: gen.bat usercenter
if "%1"=="" (
echo Error: Please provide table name
echo Usage: gen.bat user
pause
exit /b 1
)
set tables=%1
set modeldir=.\internal\data\model
set prefix=edu_
set fullTableName=%prefix%%tables%
echo Generating table: %fullTableName%
go run gorm.io/gen/tools/gentool -dsn "root:lansexiongdi6,@tcp(47.97.27.195:3306)/geo?charset=utf8mb4&parseTime=true&loc=Asia/Shanghai" -outPath "%modeldir%" -onlyModel -modelPkgName "model" -tables "%fullTableName%"
if %errorlevel% equ 0 (
echo Success! Generated files in: %modeldir%
) else (
echo Failed to generate models
)
pause

19
gen.sh Normal file
View File

@ -0,0 +1,19 @@
#!/bin/bash
# 使用方法:
# ./genModel.sh usercenter user
# ./genModel.sh usercenter user_auth
# 再将./genModel下的文件剪切到对应服务的model目录里面记得改package
#生成的表名
tables=$1
#表生成的genmodel目录
modeldir=./internal/data/model
# 数据库配置
prefix=edu_
gentool --dsn "root:lansexiongdi6,@tcp(47.97.27.195:3306)/geo?charset=utf8mb4&parseTime=true&loc=Asia%2FShanghai" -outPath ${modeldir} -onlyModel -modelPkgName "model" -tables ${prefix}${tables}

64
go.mod Normal file
View File

@ -0,0 +1,64 @@
module geo
go 1.26.1
require (
github.com/JohannesKaufmann/html-to-markdown v1.6.0
github.com/go-kratos/kratos/v2 v2.9.2
github.com/go-playground/validator/v10 v10.30.2
github.com/go-rod/rod v0.116.2
github.com/go-viper/mapstructure/v2 v2.5.0
github.com/gofiber/fiber/v2 v2.52.12
github.com/golang-jwt/jwt/v5 v5.3.1
github.com/google/uuid v1.6.0
github.com/google/wire v0.7.0
github.com/grokify/html-strip-tags-go v0.1.0
github.com/redis/go-redis/v9 v9.18.0
github.com/spf13/viper v1.21.0
gorm.io/driver/mysql v1.6.0
gorm.io/gorm v1.31.1
xorm.io/builder v0.3.13
)
require (
filippo.io/edwards25519 v1.1.0 // indirect
github.com/PuerkitoBio/goquery v1.9.2 // indirect
github.com/andybalholm/brotli v1.1.0 // indirect
github.com/andybalholm/cascadia v1.3.2 // indirect
github.com/cespare/xxhash/v2 v2.3.0 // indirect
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
github.com/fsnotify/fsnotify v1.9.0 // indirect
github.com/gabriel-vasile/mimetype v1.4.13 // indirect
github.com/go-playground/locales v0.14.1 // indirect
github.com/go-playground/universal-translator v0.18.1 // indirect
github.com/go-sql-driver/mysql v1.8.1 // indirect
github.com/jinzhu/inflection v1.0.0 // indirect
github.com/jinzhu/now v1.1.5 // indirect
github.com/klauspost/compress v1.17.9 // indirect
github.com/leodido/go-urn v1.4.0 // indirect
github.com/mattn/go-colorable v0.1.13 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect
github.com/mattn/go-runewidth v0.0.16 // indirect
github.com/pelletier/go-toml/v2 v2.2.4 // indirect
github.com/rivo/uniseg v0.2.0 // indirect
github.com/sagikazarmark/locafero v0.11.0 // indirect
github.com/sourcegraph/conc v0.3.1-0.20240121214520-5f936abd7ae8 // indirect
github.com/spf13/afero v1.15.0 // indirect
github.com/spf13/cast v1.10.0 // indirect
github.com/spf13/pflag v1.0.10 // indirect
github.com/subosito/gotenv v1.6.0 // indirect
github.com/valyala/bytebufferpool v1.0.0 // indirect
github.com/valyala/fasthttp v1.51.0 // indirect
github.com/valyala/tcplisten v1.0.0 // indirect
github.com/ysmood/fetchup v0.2.3 // indirect
github.com/ysmood/goob v0.4.0 // indirect
github.com/ysmood/got v0.40.0 // indirect
github.com/ysmood/gson v0.7.3 // indirect
github.com/ysmood/leakless v0.9.0 // indirect
go.uber.org/atomic v1.11.0 // indirect
go.yaml.in/yaml/v3 v3.0.4 // indirect
golang.org/x/crypto v0.49.0 // indirect
golang.org/x/net v0.51.0 // indirect
golang.org/x/sys v0.42.0 // indirect
golang.org/x/text v0.35.0 // indirect
)

222
go.sum Normal file
View File

@ -0,0 +1,222 @@
filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA=
filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4=
gitea.com/xorm/sqlfiddle v0.0.0-20180821085327-62ce714f951a h1:lSA0F4e9A2NcQSqGqTOXqu2aRi/XEQxDCBwM8yJtE6s=
gitea.com/xorm/sqlfiddle v0.0.0-20180821085327-62ce714f951a/go.mod h1:EXuID2Zs0pAQhH8yz+DNjUbjppKQzKFAn28TMYPB6IU=
github.com/JohannesKaufmann/html-to-markdown v1.6.0 h1:04VXMiE50YYfCfLboJCLcgqF5x+rHJnb1ssNmqpLH/k=
github.com/JohannesKaufmann/html-to-markdown v1.6.0/go.mod h1:NUI78lGg/a7vpEJTz/0uOcYMaibytE4BUOQS8k78yPQ=
github.com/PuerkitoBio/goquery v1.9.2 h1:4/wZksC3KgkQw7SQgkKotmKljk0M6V8TUvA8Wb4yPeE=
github.com/PuerkitoBio/goquery v1.9.2/go.mod h1:GHPCaP0ODyyxqcNoFGYlAprUFH81NuRPd0GX3Zu2Mvk=
github.com/andybalholm/brotli v1.1.0 h1:eLKJA0d02Lf0mVpIDgYnqXcUn0GqVmEFny3VuID1U3M=
github.com/andybalholm/brotli v1.1.0/go.mod h1:sms7XGricyQI9K10gOSf56VKKWS4oLer58Q+mhRPtnY=
github.com/andybalholm/cascadia v1.3.2 h1:3Xi6Dw5lHF15JtdcmAHD3i1+T8plmv7BQ/nsViSLyss=
github.com/andybalholm/cascadia v1.3.2/go.mod h1:7gtRlve5FxPPgIgX36uWBX58OdBsSS6lUvCFb+h7KvU=
github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs=
github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c=
github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA=
github.com/bsm/gomega v1.27.10/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0H+O0=
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78=
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8=
github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0=
github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k=
github.com/fsnotify/fsnotify v1.9.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0=
github.com/gabriel-vasile/mimetype v1.4.13 h1:46nXokslUBsAJE/wMsp5gtO500a4F3Nkz9Ufpk2AcUM=
github.com/gabriel-vasile/mimetype v1.4.13/go.mod h1:d+9Oxyo1wTzWdyVUPMmXFvp4F9tea18J8ufA774AB3s=
github.com/go-kratos/kratos/v2 v2.9.2 h1:px8GJQBeLpquDKQWQ9zohEWiLA8n4D/pv7aH3asvUvo=
github.com/go-kratos/kratos/v2 v2.9.2/go.mod h1:Jc7jaeYd4RAPjetun2C+oFAOO7HNMHTT/Z4LxpuEDJM=
github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s=
github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA=
github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY=
github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY=
github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY=
github.com/go-playground/validator/v10 v10.30.2 h1:JiFIMtSSHb2/XBUbWM4i/MpeQm9ZK2xqPNk8vgvu5JQ=
github.com/go-playground/validator/v10 v10.30.2/go.mod h1:mAf2pIOVXjTEBrwUMGKkCWKKPs9NheYGabeB04txQSc=
github.com/go-rod/rod v0.116.2 h1:A5t2Ky2A+5eD/ZJQr1EfsQSe5rms5Xof/qj296e+ZqA=
github.com/go-rod/rod v0.116.2/go.mod h1:H+CMO9SCNc2TJ2WfrG+pKhITz57uGNYU43qYHh438Mg=
github.com/go-sql-driver/mysql v1.8.1 h1:LedoTUt/eveggdHS9qUFC1EFSa8bU2+1pZjSRpvNJ1Y=
github.com/go-sql-driver/mysql v1.8.1/go.mod h1:wEBSXgmK//2ZFJyE+qWnIsVGmvmEKlqwuVSjsCm7DZg=
github.com/go-viper/mapstructure/v2 v2.5.0 h1:vM5IJoUAy3d7zRSVtIwQgBj7BiWtMPfmPEgAXnvj1Ro=
github.com/go-viper/mapstructure/v2 v2.5.0/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM=
github.com/gofiber/fiber/v2 v2.52.12 h1:0LdToKclcPOj8PktUdIKo9BUohjjwfnQl42Dhw8/WUw=
github.com/gofiber/fiber/v2 v2.52.12/go.mod h1:YEcBbO/FB+5M1IZNBP9FO3J9281zgPAreiI1oqg8nDw=
github.com/golang-jwt/jwt/v5 v5.3.1 h1:kYf81DTWFe7t+1VvL7eS+jKFVWaUnK9cB1qbwn63YCY=
github.com/golang-jwt/jwt/v5 v5.3.1/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE=
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/google/wire v0.7.0 h1:JxUKI6+CVBgCO2WToKy/nQk0sS+amI9z9EjVmdaocj4=
github.com/google/wire v0.7.0/go.mod h1:n6YbUQD9cPKTnHXEBN2DXlOp/mVADhVErcMFb0v3J18=
github.com/grokify/html-strip-tags-go v0.1.0 h1:03UrQLjAny8xci+R+qjCce/MYnpNXCtgzltlQbOBae4=
github.com/grokify/html-strip-tags-go v0.1.0/go.mod h1:ZdzgfHEzAfz9X6Xe5eBLVblWIxXfYSQ40S/VKrAOGpc=
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
github.com/klauspost/compress v1.17.9 h1:6KIumPrER1LHsvBVuDa0r5xaG0Es51mhhB9BQB2qeMA=
github.com/klauspost/compress v1.17.9/go.mod h1:Di0epgTjJY877eYKx5yC51cX2A2Vl2ibi7bDH9ttBbw=
github.com/klauspost/cpuid/v2 v2.0.9 h1:lgaqFMSdTdQYdZ04uHyN2d/eKdOMyi2YLSvlQIBFYa4=
github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ=
github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI=
github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA=
github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg=
github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-runewidth v0.0.16 h1:E5ScNMtiwvlvB5paMFdw9p4kSQzbXFikJ5SQO6TULQc=
github.com/mattn/go-runewidth v0.0.16/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w=
github.com/pelletier/go-toml/v2 v2.2.4 h1:mye9XuhQ6gvn5h28+VilKrrPoQVanw5PMw/TB0t5Ec4=
github.com/pelletier/go-toml/v2 v2.2.4/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY=
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/redis/go-redis/v9 v9.18.0 h1:pMkxYPkEbMPwRdenAzUNyFNrDgHx9U+DrBabWNfSRQs=
github.com/redis/go-redis/v9 v9.18.0/go.mod h1:k3ufPphLU5YXwNTUcCRXGxUoF1fqxnhFQmscfkCoDA0=
github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY=
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
github.com/rogpeppe/go-internal v1.11.0 h1:cWPaGQEPrBb5/AsnsZesgZZ9yb1OQ+GOISoDNXVBh4M=
github.com/rogpeppe/go-internal v1.11.0/go.mod h1:ddIwULY96R17DhadqLgMfk9H9tvdUzkipdSkR5nkCZA=
github.com/sagikazarmark/locafero v0.11.0 h1:1iurJgmM9G3PA/I+wWYIOw/5SyBtxapeHDcg+AAIFXc=
github.com/sagikazarmark/locafero v0.11.0/go.mod h1:nVIGvgyzw595SUSUE6tvCp3YYTeHs15MvlmU87WwIik=
github.com/sebdah/goldie/v2 v2.5.3 h1:9ES/mNN+HNUbNWpVAlrzuZ7jE+Nrczbj8uFRjM7624Y=
github.com/sebdah/goldie/v2 v2.5.3/go.mod h1:oZ9fp0+se1eapSRjfYbsV/0Hqhbuu3bJVvKI/NNtssI=
github.com/sergi/go-diff v1.0.0/go.mod h1:0CfEIISq7TuYL3j771MWULgwwjU+GofnZX9QAmXWZgo=
github.com/sergi/go-diff v1.3.1 h1:xkr+Oxo4BOQKmkn/B9eMK0g5Kg/983T9DqqPHwYqD+8=
github.com/sergi/go-diff v1.3.1/go.mod h1:aMJSSKb2lpPvRNec0+w3fl7LP9IOFzdc9Pa4NFbPK1I=
github.com/sourcegraph/conc v0.3.1-0.20240121214520-5f936abd7ae8 h1:+jumHNA0Wrelhe64i8F6HNlS8pkoyMv5sreGx2Ry5Rw=
github.com/sourcegraph/conc v0.3.1-0.20240121214520-5f936abd7ae8/go.mod h1:3n1Cwaq1E1/1lhQhtRK2ts/ZwZEhjcQeJQ1RuC6Q/8U=
github.com/spf13/afero v1.15.0 h1:b/YBCLWAJdFWJTN9cLhiXXcD7mzKn9Dm86dNnfyQw1I=
github.com/spf13/afero v1.15.0/go.mod h1:NC2ByUVxtQs4b3sIUphxK0NioZnmxgyCrfzeuq8lxMg=
github.com/spf13/cast v1.10.0 h1:h2x0u2shc1QuLHfxi+cTJvs30+ZAHOGRic8uyGTDWxY=
github.com/spf13/cast v1.10.0/go.mod h1:jNfB8QC9IA6ZuY2ZjDp0KtFO2LZZlg4S/7bzP6qqeHo=
github.com/spf13/pflag v1.0.10 h1:4EBh2KAYBwaONj6b2Ye1GiHfwjqyROoF4RwYO+vPwFk=
github.com/spf13/pflag v1.0.10/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
github.com/spf13/viper v1.21.0 h1:x5S+0EU27Lbphp4UKm1C+1oQO+rKx36vfCoaVebLFSU=
github.com/spf13/viper v1.21.0/go.mod h1:P0lhsswPGWD/1lZJ9ny3fYnVqxiegrlNrEmgLjbTCAY=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8=
github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU=
github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw=
github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc=
github.com/valyala/fasthttp v1.51.0 h1:8b30A5JlZ6C7AS81RsWjYMQmrZG6feChmgAolCl1SqA=
github.com/valyala/fasthttp v1.51.0/go.mod h1:oI2XroL+lI7vdXyYoQk03bXBThfFl2cVdIA3Xl7cH8g=
github.com/valyala/tcplisten v1.0.0 h1:rBHj/Xf+E1tRGZyWIWwJDiRY0zc1Js+CV5DqwacVSA8=
github.com/valyala/tcplisten v1.0.0/go.mod h1:T0xQ8SeCZGxckz9qRXTfG43PvQ/mcWh7FwZEA7Ioqkc=
github.com/ysmood/fetchup v0.2.3 h1:ulX+SonA0Vma5zUFXtv52Kzip/xe7aj4vqT5AJwQ+ZQ=
github.com/ysmood/fetchup v0.2.3/go.mod h1:xhibcRKziSvol0H1/pj33dnKrYyI2ebIvz5cOOkYGns=
github.com/ysmood/goob v0.4.0 h1:HsxXhyLBeGzWXnqVKtmT9qM7EuVs/XOgkX7T6r1o1AQ=
github.com/ysmood/goob v0.4.0/go.mod h1:u6yx7ZhS4Exf2MwciFr6nIM8knHQIE22lFpWHnfql18=
github.com/ysmood/gop v0.2.0 h1:+tFrG0TWPxT6p9ZaZs+VY+opCvHU8/3Fk6BaNv6kqKg=
github.com/ysmood/gop v0.2.0/go.mod h1:rr5z2z27oGEbyB787hpEcx4ab8cCiPnKxn0SUHt6xzk=
github.com/ysmood/got v0.40.0 h1:ZQk1B55zIvS7zflRrkGfPDrPG3d7+JOza1ZkNxcc74Q=
github.com/ysmood/got v0.40.0/go.mod h1:W7DdpuX6skL3NszLmAsC5hT7JAhuLZhByVzHTq874Qg=
github.com/ysmood/gotrace v0.6.0 h1:SyI1d4jclswLhg7SWTL6os3L1WOKeNn/ZtzVQF8QmdY=
github.com/ysmood/gotrace v0.6.0/go.mod h1:TzhIG7nHDry5//eYZDYcTzuJLYQIkykJzCRIo4/dzQM=
github.com/ysmood/gson v0.7.3 h1:QFkWbTH8MxyUTKPkVWAENJhxqdBa4lYTQWqZCiLG6kE=
github.com/ysmood/gson v0.7.3/go.mod h1:3Kzs5zDl21g5F/BlLTNcuAGAYLKt2lV5G8D1zF3RNmg=
github.com/ysmood/leakless v0.9.0 h1:qxCG5VirSBvmi3uynXFkcnLMzkphdh3xx5FtrORwDCU=
github.com/ysmood/leakless v0.9.0/go.mod h1:R8iAXPRaG97QJwqxs74RdwzcRHT1SWCGTNqY8q0JvMQ=
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
github.com/yuin/goldmark v1.7.1 h1:3bajkSilaCbjdKVsKdZjZCLBNPL9pYzrCakKaf4U49U=
github.com/yuin/goldmark v1.7.1/go.mod h1:uzxRWxtg69N339t3louHJ7+O03ezfj6PlliRlaOzY1E=
github.com/zeebo/xxh3 v1.0.2 h1:xZmwmqxHZA8AI603jOQ0tMqmBr9lPeFwGg6d+xy9DC0=
github.com/zeebo/xxh3 v1.0.2/go.mod h1:5NWz9Sef7zIDm2JHfFlcQvNekmcEl9ekUZQQKCYaDcA=
go.uber.org/atomic v1.11.0 h1:ZvwS0R+56ePWxUNi+Atn9dWONBPp/AUETXlHW0DxSjE=
go.uber.org/atomic v1.11.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0=
go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc=
go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU=
golang.org/x/crypto v0.22.0/go.mod h1:vr6Su+7cTlO45qkww3VDJlzDn0ctJvRgYbC2NvXHt+M=
golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8=
golang.org/x/crypto v0.49.0 h1:+Ng2ULVvLHnJ/ZFEq4KdcDd/cfjrrjjNSXNzxg0Y4U4=
golang.org/x/crypto v0.49.0/go.mod h1:ErX4dUh2UM+CFYiXZRTcMpEcN8b/1gxEuv3nODoYtCA=
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs=
golang.org/x/net v0.9.0/go.mod h1:d48xBJpPfHeWQsugry2m+kC02ZBRGRgulfHnEXEuWns=
golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg=
golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44=
golang.org/x/net v0.24.0/go.mod h1:2Q7sJY5mzlzWjKtYUEXSlBWCdyaioyXzRB2RtU8KVE8=
golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM=
golang.org/x/net v0.51.0 h1:94R/GTO7mt3/4wIKpcR5gkGmRLOuE/2hNGeWq/GBIFo=
golang.org/x/net v0.51.0/go.mod h1:aamm+2QF5ogm02fjy5Bb7CQ0WMt1/WVM7FtyaTLlA9Y=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4=
golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.7.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.19.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.42.0 h1:omrd2nAlyT5ESRdCLYdm3+fMfNFE/+Rf4bDIQImRJeo=
golang.org/x/sys v0.42.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k=
golang.org/x/term v0.7.0/go.mod h1:P32HKFT3hSsZrRxla30E9HqToFYAQPCMs/zFMBUFqPY=
golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo=
golang.org/x/term v0.17.0/go.mod h1:lLRBjIVuehSbZlaOtGMbcMncT+aqLLLmKrsjNrUguwk=
golang.org/x/term v0.19.0/go.mod h1:2CuTdWZ7KHSQwUzKva0cbMg6q2DMI3Mmxp+gKJbskEk=
golang.org/x/term v0.20.0/go.mod h1:8UkIAJTvZgivsXaD6/pH6U9ecQzZ45awqEOzuCvwpFY=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/text v0.35.0 h1:JOVx6vVDFokkpaq1AEptVzLTpDe9KGpj5tR4/X+ybL8=
golang.org/x/text v0.35.0/go.mod h1:khi/HExzZJ2pGnjenulevKNX1W67CUy0AsXcNubPGCA=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY=
gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gorm.io/driver/mysql v1.6.0 h1:eNbLmNTpPpTOVZi8MMxCi2aaIm0ZpInbORNXDwyLGvg=
gorm.io/driver/mysql v1.6.0/go.mod h1:D/oCC2GWK3M/dqoLxnOlaNKmXz8WNTfcS9y5ovaSqKo=
gorm.io/gorm v1.31.1 h1:7CA8FTFz/gRfgqgpeKIBcervUn3xSyPUmr6B2WXJ7kg=
gorm.io/gorm v1.31.1/go.mod h1:XyQVbO2k6YkOis7C2437jSit3SsDK72s7n7rsSHd+Gs=
xorm.io/builder v0.3.13 h1:a3jmiVVL19psGeXx8GIurTp7p0IIgqeDmwhcR6BAOAo=
xorm.io/builder v0.3.13/go.mod h1:aUW0S9eb9VCaPohFCH3j7czOx1PMW3i1HrSzbLYGBSE=

View File

@ -0,0 +1,9 @@
package biz
import (
"github.com/google/wire"
)
var ProviderSetBiz = wire.NewSet(
NewPublishBiz,
)

102
internal/biz/public.go Normal file
View File

@ -0,0 +1,102 @@
package biz
import (
"context"
"time"
"geo/internal/config"
"geo/internal/data/impl"
"geo/internal/data/model"
"geo/tmpl/errcode"
"xorm.io/builder"
)
type PublishBiz struct {
cfg *config.Config
publishImpl *impl.PublishImpl
userImpl *impl.UserImpl
platImpl *impl.PlatImpl
tokenImpl *impl.TokenImpl
loginRelationImpl *impl.LoginRelationImpl
}
func NewPublishBiz(
cfg *config.Config,
publishImpl *impl.PublishImpl,
userImpl *impl.UserImpl,
platImpl *impl.PlatImpl,
tokenImpl *impl.TokenImpl,
loginRelationImpl *impl.LoginRelationImpl,
) *PublishBiz {
return &PublishBiz{
cfg: cfg,
publishImpl: publishImpl,
loginRelationImpl: loginRelationImpl,
userImpl: userImpl,
platImpl: platImpl,
tokenImpl: tokenImpl,
}
}
func (b *PublishBiz) ValidateAccessToken(ctx context.Context, accessToken string) (*model.Token, error) {
cond := builder.NewCond().
And(builder.Eq{"access_token": accessToken}).
And(builder.Eq{"status": 1})
tokenInfo := &model.Token{}
err := b.tokenImpl.GetOneBySearchStruct(ctx, &cond, tokenInfo)
if err != nil || tokenInfo == nil {
return nil, errcode.Forbidden("密钥无效或已禁用")
}
return tokenInfo, nil
}
func (b *PublishBiz) BatchInsertPublish(ctx context.Context, list []*model.Publish) error {
return b.publishImpl.Add(ctx, list)
}
func (b *PublishBiz) GetPendingPublish(ctx context.Context, tokenID int) (map[string]interface{}, error) {
currentTime := time.Now()
cond := builder.NewCond().
And(builder.Eq{"p.token_id": tokenID}).
And(builder.Eq{"p.status": 1}).
And(builder.Lte{"p.publish_time": currentTime})
return b.publishImpl.GetOneWithPlat(ctx, &cond)
}
func (b *PublishBiz) GetTaskByRequestID(ctx context.Context, requestID string) (map[string]interface{}, error) {
cond := builder.NewCond().
And(builder.Eq{"p.request_id": requestID})
return b.publishImpl.GetOneWithPlat(ctx, &cond)
}
func (b *PublishBiz) UpdatePublishStatus(ctx context.Context, requestID string, status int, msg string) error {
return b.publishImpl.UpdateStatus(ctx, requestID, status, msg)
}
func (b *PublishBiz) GetPublishList(ctx context.Context, tokenID int32, page, pageSize int, filters map[string]interface{}) ([]map[string]interface{}, int64, error) {
return b.publishImpl.GetListWithUser(ctx, tokenID, page, pageSize, filters)
}
func (b *PublishBiz) GetPlatInfo(ctx context.Context, platIndex string) (*model.Plat, error) {
cond := builder.NewCond().
And(builder.Eq{"`index`": platIndex}).
And(builder.Eq{"status": 1})
plat := &model.Plat{}
err := b.platImpl.GetOneBySearchStruct(ctx, &cond, plat)
if err != nil {
return nil, err
}
return plat, nil
}
func (b *PublishBiz) UpdateLoginStatus(ctx context.Context, userIndex, platIndex string, loginStatus int32) error {
cond := builder.NewCond().
And(builder.Eq{"user_index": userIndex}).
And(builder.Eq{"plat_index": platIndex})
return b.loginRelationImpl.UpdateByCond(ctx, &cond, &model.User{Status: loginStatus})
}

72
internal/config/config.go Normal file
View File

@ -0,0 +1,72 @@
package config
import (
"os"
"path/filepath"
)
// Config 应用配置
type Config struct {
Server ServerConfig `mapstructure:"server"`
DB DB `mapstructure:"db"`
Sys Sys `mapstructure:"sys"`
}
// ServerConfig 服务器配置
type ServerConfig struct {
Port int `mapstructure:"port"`
Host string `mapstructure:"host"`
}
type DB struct {
Driver string `mapstructure:"driver"`
Source string `mapstructure:"source"`
MaxIdle int32 `mapstructure:"maxIdle"`
MaxOpen int32 `mapstructure:"maxOpen"`
MaxLifetime int32 `mapstructure:"maxLifetime"`
IsDebug bool `mapstructure:"isDebug"`
}
type Sys struct {
MaxConcurrent int `mapstructure:"maxConcurrent"`
TaskTimeout int `mapstructure:"taskTimeout"`
SessionTimeout int `mapstructure:"sessionTimeout "`
MaxImageSize int `mapstructure:"maxImageSize"`
LogsDir string `mapstructure:"logsDir"`
UploadDir string `mapstructure:"uploadDir"`
VideosDir string `mapstructure:"videosDir"`
DocsDir string `mapstructure:"docsDir"`
CookiesDir string `mapstructure:"cookiesDir"`
QrcodesDir string `mapstructure:"qrcodesDir"`
ChromePath string `mapstructure:"chromePath"`
ChromeDataDir string `mapstructure:"chromeDataDir"`
}
// LoadConfig 加载配置
func LoadConfig() (*Config, error) {
BaseDir, _ := os.Getwd()
return &Config{
Server: ServerConfig{
Port: 5001,
Host: "0.0.0.0",
},
DB: DB{
Driver: "mysql",
Source: "root:lansexiongdi6,@tcp(47.97.27.195:3306)/geo?charset=utf8mb4&parseTime=true&loc=Asia%2FShanghai",
},
Sys: Sys{
MaxConcurrent: 1,
TaskTimeout: 60,
SessionTimeout: 300,
MaxImageSize: 5 * 1024 * 1024,
LogsDir: filepath.Join(BaseDir, "logs"),
UploadDir: filepath.Join(BaseDir, "images"),
VideosDir: filepath.Join(BaseDir, "videos"),
DocsDir: filepath.Join(BaseDir, "docs"),
CookiesDir: filepath.Join(BaseDir, "cookies"),
QrcodesDir: filepath.Join(BaseDir, "qrcodes"),
ChromePath: "./chrome/chrome.exe",
ChromeDataDir: "./chrome_data",
},
}, nil
}

View File

@ -0,0 +1,27 @@
package impl
import (
"geo/internal/data/model"
"geo/tmpl/dataTemp"
"geo/utils"
)
type LoginRelationImpl struct {
dataTemp.DataTemp
db *utils.Db
}
func NewLoginRelationImpl(db *utils.Db) *LoginRelationImpl {
return &LoginRelationImpl{
DataTemp: *dataTemp.NewDataTemp(db, new(model.LoginRelation)),
db: db,
}
}
func (m *LoginRelationImpl) PrimaryKey() string {
return "id"
}
func (m *LoginRelationImpl) GetTemp() *dataTemp.DataTemp {
return &m.DataTemp
}

View File

@ -0,0 +1,88 @@
package impl
import (
"context"
"fmt"
"geo/internal/data/model"
"geo/tmpl/dataTemp"
"geo/utils"
"xorm.io/builder"
)
type PlatImpl struct {
dataTemp.DataTemp
db *utils.Db
}
func NewPlatImpl(db *utils.Db) *PlatImpl {
return &PlatImpl{
DataTemp: *dataTemp.NewDataTemp(db, new(model.Plat)),
db: db,
}
}
func (p *PlatImpl) PrimaryKey() string {
return "id"
}
func (p *PlatImpl) GetTemp() *dataTemp.DataTemp {
return &p.DataTemp
}
// GetPlatListWithLoginStatus 获取平台列表并关联登录状态
func (p *PlatImpl) GetPlatListWithLoginStatus(ctx context.Context, cond *builder.Cond) ([]map[string]interface{}, error) {
query, err := builder.ToBoundSQL(*cond)
if err != nil {
return nil, err
}
sql := `
SELECT
plat.name,
plat.img_url,
plat.index,
plat.desc,
COALESCE(login_relation.login_status, 2) as login_status
FROM plat
LEFT JOIN login_relation ON login_relation.plat_index COLLATE utf8mb4_unicode_ci = plat.index AND login_relation.status = 1
WHERE %s
ORDER BY plat.id ASC
`
finalSQL := fmt.Sprintf(sql, query)
var results []map[string]interface{}
err = p.db.Client.WithContext(ctx).Raw(finalSQL).Scan(&results).Error
if err != nil {
return nil, err
}
return results, nil
}
// GetPlatListWithLoginStatusByUserIndex 根据用户索引获取平台列表及登录状态
func (p *PlatImpl) GetPlatListWithLoginStatusByUserIndex(ctx context.Context, userIndex string) ([]map[string]interface{}, error) {
sql := `
SELECT
pl.name,
pl.img_url,
pl.index,
pl.desc,
pl.login_url,
pl.edit_url,
pl.logined_url,
COALESCE(lr.login_status, 2) as login_status
FROM plat pl
LEFT JOIN login_relation lr ON lr.plat_index = pl.index AND lr.user_index = ? AND lr.status = 1
WHERE pl.status = 1
ORDER BY pl.id ASC
`
var results []map[string]interface{}
err := p.db.Client.WithContext(ctx).Raw(sql, userIndex).Scan(&results).Error
if err != nil {
return nil, err
}
return results, nil
}

View File

@ -0,0 +1,13 @@
package impl
import (
"github.com/google/wire"
)
var ProviderImpl = wire.NewSet(
NewPlatImpl,
NewLoginRelationImpl,
NewUserImpl,
NewTokenImpl,
NewPublishImpl,
)

View File

@ -0,0 +1,244 @@
package impl
import (
"context"
"fmt"
"geo/internal/data/model"
"geo/tmpl/dataTemp"
"geo/utils"
"time"
"xorm.io/builder"
)
type PublishImpl struct {
dataTemp.DataTemp
db *utils.Db
}
func NewPublishImpl(db *utils.Db) *PublishImpl {
return &PublishImpl{
DataTemp: *dataTemp.NewDataTemp(db, new(model.Publish)),
db: db,
}
}
func (m *PublishImpl) PrimaryKey() string {
return "id"
}
func (m *PublishImpl) GetTemp() *dataTemp.DataTemp {
return &m.DataTemp
}
// BatchInsert 批量插入发布记录
func (p *PublishImpl) BatchInsert(ctx context.Context, tokenID int, records []map[string]interface{}) (int64, error) {
if len(records) == 0 {
return 0, nil
}
sql := `INSERT INTO publish
(token_id, user_index, request_id, title, tag, type, plat_index, url, publish_time, status, create_time, img)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`
var total int64
for _, record := range records {
result := p.db.Client.WithContext(ctx).Exec(sql,
tokenID,
record["user_index"],
record["request_id"],
record["title"],
record["tag"],
record["type"],
record["plat_index"],
record["url"],
record["publish_time"],
1, // status = 1 待发布
time.Now(),
record["img"],
)
if result.Error != nil {
return total, result.Error
}
total += result.RowsAffected
}
return total, nil
}
// GetOneWithPlat 查询单条发布记录并关联平台信息
func (p *PublishImpl) GetOneWithPlat(ctx context.Context, cond *builder.Cond) (map[string]interface{}, error) {
query, err := builder.ToBoundSQL(*cond)
if err != nil {
return nil, err
}
// 构建带平台信息的查询
sql := fmt.Sprintf(`
SELECT
p.id,
p.token_id,
p.user_index,
p.request_id,
p.title,
p.tag,
p.type,
p.plat_index,
p.url,
p.publish_time,
p.status,
p.create_time,
p.img,
p.msg,
pl.index as plat_index_db,
pl.name as plat_name,
pl.login_url,
pl.edit_url,
pl.logined_url,
pl.desc as plat_desc,
pl.img_url as plat_img_url
FROM publish p
INNER JOIN plat pl ON p.plat_index = pl.index AND pl.status = 1
WHERE %s
LIMIT 1
`, query)
var result map[string]interface{}
err = p.db.Client.WithContext(ctx).Raw(sql).Scan(&result).Error
if err != nil {
return nil, err
}
if result == nil || len(result) == 0 {
return nil, nil
}
return result, nil
}
// GetOneWithPlatByRequestID 根据request_id查询发布记录并关联平台信息
func (p *PublishImpl) GetOneWithPlatByRequestID(ctx context.Context, requestID string) (map[string]interface{}, error) {
sql := `
SELECT
p.id,
p.token_id,
p.user_index,
p.request_id,
p.title,
p.tag,
p.type,
p.plat_index,
p.url,
p.publish_time,
p.status,
p.create_time,
p.img,
p.msg,
pl.index as plat_index_db,
pl.name as plat_name,
pl.login_url,
pl.edit_url,
pl.logined_url,
pl.desc as plat_desc,
pl.img_url as plat_img_url
FROM publish p
INNER JOIN plat pl ON p.plat_index = pl.index AND pl.status = 1
WHERE p.request_id = ?
LIMIT 1
`
var result map[string]interface{}
err := p.db.Client.WithContext(ctx).Raw(sql, requestID).Scan(&result).Error
if err != nil {
return nil, err
}
if result == nil || len(result) == 0 {
return nil, nil
}
return result, nil
}
// UpdateStatus 更新发布状态
func (p *PublishImpl) UpdateStatus(ctx context.Context, requestID string, status int, msg string) error {
updateData := map[string]interface{}{
"status": status,
}
if msg != "" {
updateData["msg"] = msg
}
return p.db.Client.WithContext(ctx).
Model(&model.Publish{}).
Where("request_id = ?", requestID).
Updates(updateData).Error
}
// GetListWithUser 获取发布列表并关联用户信息
func (p *PublishImpl) GetListWithUser(ctx context.Context, tokenID int32, page, pageSize int, filters map[string]interface{}) ([]map[string]interface{}, int64, error) {
// 构建基础查询
query := p.db.Client.WithContext(ctx).
Table("publish p").
Select(`
p.id,
p.user_index,
u.name as user_name,
p.request_id,
p.title,
p.tag,
p.type,
CASE
WHEN p.type = 1 THEN '文章'
WHEN p.type = 2 THEN '视频'
ELSE '未知'
END as type_name,
p.plat_index,
pl.name as plat_name,
p.url,
p.publish_time,
p.status,
CASE
WHEN p.status = 1 THEN '待发布'
WHEN p.status = 2 THEN '发布中'
WHEN p.status = 3 THEN '发布失败'
WHEN p.status = 4 THEN '发布成功'
ELSE '未知'
END as status_name,
p.create_time,
p.msg
`).
Joins("LEFT JOIN user u ON p.user_index = u.user_index").
Joins("LEFT JOIN plat pl ON p.plat_index = pl.index").
Where("u.token_id = ?", tokenID)
// 添加过滤条件
if userIndex, ok := filters["user_index"]; ok && userIndex != "" {
query = query.Where("p.user_index = ?", userIndex)
}
if tag, ok := filters["tag"]; ok && tag != "" {
query = query.Where("p.tag LIKE ?", "%"+tag.(string)+"%")
}
if typeFilter, ok := filters["type"]; ok && typeFilter != 0 {
query = query.Where("p.type = ?", typeFilter)
}
if platIndex, ok := filters["plat_index"]; ok && platIndex != "" {
query = query.Where("p.plat_index = ?", platIndex)
}
if status, ok := filters["status"]; ok && status != 0 {
query = query.Where("p.status = ?", status)
}
if requestID, ok := filters["request_id"]; ok && requestID != "" {
query = query.Where("p.request_id LIKE ?", "%"+requestID.(string)+"%")
}
// 获取总数
var total int64
if err := query.Count(&total).Error; err != nil {
return nil, 0, err
}
// 分页查询
var results []map[string]interface{}
offset := (page - 1) * pageSize
err := query.
Order("p.publish_time DESC, p.create_time DESC").
Limit(pageSize).
Offset(offset).
Scan(&results).Error
return results, total, err
}

View File

@ -0,0 +1,31 @@
package impl
import (
"geo/internal/data/model"
"geo/tmpl/dataTemp"
"geo/utils"
)
type TokenImpl struct {
dataTemp.DataTemp
db *utils.Db
}
func NewTokenImpl(db *utils.Db) *TokenImpl {
return &TokenImpl{
DataTemp: *dataTemp.NewDataTemp(db, new(model.Token)),
db: db,
}
}
func (m *TokenImpl) PrimaryKey() string {
return "id"
}
func (m *TokenImpl) GetTemp() *dataTemp.DataTemp {
return &m.DataTemp
}
func (m *TokenImpl) GetDb() *utils.Db {
return m.db
}

View File

@ -0,0 +1,27 @@
package impl
import (
"geo/internal/data/model"
"geo/tmpl/dataTemp"
"geo/utils"
)
type UserImpl struct {
dataTemp.DataTemp
db *utils.Db
}
func NewUserImpl(db *utils.Db) *UserImpl {
return &UserImpl{
DataTemp: *dataTemp.NewDataTemp(db, new(model.User)),
db: db,
}
}
func (m *UserImpl) PrimaryKey() string {
return "id"
}
func (m *UserImpl) GetTemp() *dataTemp.DataTemp {
return &m.DataTemp
}

View File

@ -0,0 +1,21 @@
// Code generated by gorm.io/gen. DO NOT EDIT.
// Code generated by gorm.io/gen. DO NOT EDIT.
// Code generated by gorm.io/gen. DO NOT EDIT.
package model
const TableNameLoginRelation = "login_relation"
// LoginRelation mapped from table <login_relation>
type LoginRelation struct {
ID int32 `gorm:"column:id;primaryKey;autoIncrement:true" json:"id"`
UserIndex string `gorm:"column:user_index;not null;default:0" json:"user_index"`
PlatIndex string `gorm:"column:plat_index;not null;default:0" json:"plat_index"`
LoginStatus int32 `gorm:"column:login_status;not null;default:2" json:"login_status"`
Status int32 `gorm:"column:status;not null;default:1" json:"status"`
}
// TableName LoginRelation's table name
func (*LoginRelation) TableName() string {
return TableNameLoginRelation
}

View File

@ -0,0 +1,25 @@
// Code generated by gorm.io/gen. DO NOT EDIT.
// Code generated by gorm.io/gen. DO NOT EDIT.
// Code generated by gorm.io/gen. DO NOT EDIT.
package model
const TableNamePlat = "plat"
// Plat mapped from table <plat>
type Plat struct {
ID int32 `gorm:"column:id;primaryKey" json:"id"`
Name string `gorm:"column:name;not null" json:"name"`
Index string `gorm:"column:index;not null" json:"index"`
ImgURL string `gorm:"column:img_url;not null" json:"img_url"`
LoginURL string `gorm:"column:login_url;not null" json:"login_url"`
EditURL string `gorm:"column:edit_url;not null" json:"edit_url"`
LoginedURL string `gorm:"column:logined_url;not null" json:"logined_url"`
Desc string `gorm:"column:desc;not null" json:"desc"`
Status bool `gorm:"column:status;not null;default:1" json:"status"`
}
// TableName Plat's table name
func (*Plat) TableName() string {
return TableNamePlat
}

View File

@ -0,0 +1,34 @@
// Code generated by gorm.io/gen. DO NOT EDIT.
// Code generated by gorm.io/gen. DO NOT EDIT.
// Code generated by gorm.io/gen. DO NOT EDIT.
package model
import (
"time"
)
const TableNamePublish = "publish"
// Publish mapped from table <publish>
type Publish struct {
ID int64 `gorm:"column:id;primaryKey;autoIncrement:true" json:"id"`
TokenID int32 `gorm:"column:token_id;not null" json:"token_id"`
UserIndex string `gorm:"column:user_index;not null;comment:关联user.user_index" json:"user_index"` // 关联user.user_index
RequestID string `gorm:"column:request_id;not null;comment:日志id" json:"request_id"` // 日志id
Title string `gorm:"column:title;not null" json:"title"`
Tag string `gorm:"column:tag;not null;comment:标签,多个英文都好隔开’," json:"tag"` // 标签,多个英文都好隔开’,
Type int32 `gorm:"column:type;not null;default:1;comment:1:文章2视频" json:"type"` // 1:文章2视频
PlatIndex string `gorm:"column:plat_index;not null;comment:关联plat.index" json:"plat_index"` // 关联plat.index
URL string `gorm:"column:url;not null;comment:资源文件下载地址" json:"url"` // 资源文件下载地址
PublishTime time.Time `gorm:"column:publish_time;not null;comment:发布时间" json:"publish_time"` // 发布时间
Img string `gorm:"column:img;not null" json:"img"`
Status int32 `gorm:"column:status;not null;default:1;comment:1待发布2发布中3发布失败4发布成功" json:"status"` // 1待发布2发布中3发布失败4发布成功
CreateTime time.Time `gorm:"column:create_time;not null;default:CURRENT_TIMESTAMP;comment:创建时间" json:"create_time"` // 创建时间
Msg string `gorm:"column:msg" json:"msg"`
}
// TableName Publish's table name
func (*Publish) TableName() string {
return TableNamePublish
}

View File

@ -0,0 +1,27 @@
// Code generated by gorm.io/gen. DO NOT EDIT.
// Code generated by gorm.io/gen. DO NOT EDIT.
// Code generated by gorm.io/gen. DO NOT EDIT.
package model
import (
"time"
)
const TableNameToken = "token"
// Token mapped from table <token>
type Token struct {
ID int32 `gorm:"column:id;primaryKey" json:"id"`
Name string `gorm:"column:name" json:"name"`
Secret string `gorm:"column:secret;not null" json:"secret"`
AccessToken string `gorm:"column:access_token;not null" json:"access_token"`
UserLimit int32 `gorm:"column:user_limit;primaryKey" json:"user_limit"`
ExpireTime time.Time `gorm:"column:expire_time;not null" json:"expire_time"`
Status int32 `gorm:"column:status;not null;default:1" json:"status"`
}
// TableName Token's table name
func (*Token) TableName() string {
return TableNameToken
}

View File

@ -0,0 +1,21 @@
// Code generated by gorm.io/gen. DO NOT EDIT.
// Code generated by gorm.io/gen. DO NOT EDIT.
// Code generated by gorm.io/gen. DO NOT EDIT.
package model
const TableNameUser = "user"
// User mapped from table <user>
type User struct {
ID int64 `gorm:"column:id;primaryKey;autoIncrement:true" json:"id"`
TokenID int32 `gorm:"column:token_id;not null" json:"token_id"`
Name string `gorm:"column:name;not null" json:"name"`
Status int32 `gorm:"column:status;not null;default:1" json:"status"`
UserIndex string `gorm:"column:user_index;not null" json:"user_index"`
}
// TableName User's table name
func (*User) TableName() string {
return TableNameUser
}

View File

@ -0,0 +1,88 @@
package entitys
type (
LoginAppRequest struct {
Secret string `json:"secret" validate:"required" zh:"密钥"`
}
GetUserAndAutoStatusRequest struct {
AccessToken string `json:"access_token" validate:"required" zh:"access_token"`
}
AddUserRequest struct {
AccessToken string `json:"access_token" validate:"required" zh:"access_token"`
Name string `json:"name" validate:"required" zh:"用户名"`
}
DelUserRequest struct {
ID int `json:"id" validate:"required" zh:"用户ID"`
}
GetAppRequest struct {
AccessToken string `json:"access_token" validate:"required" zh:"access_token"`
UserIndex string `json:"user_index" validate:"required" zh:"用户索引"`
}
PublishRecordsRequest struct {
AccessToken string `json:"access_token" validate:"required" zh:"access_token"`
Records []PublishRecordItem `json:"records" validate:"required" zh:"发布记录"`
}
PublishRecordItem struct {
UserIndex string `json:"user_index"`
PlatIndex string `json:"plat_index" validate:"required" zh:"平台索引"`
Title string `json:"title"`
Tag string `json:"tag"`
Type int32 `json:"type" validate:"required" zh:"类型"`
URL string `json:"url" validate:"required" zh:"链接"`
PublishTime string `json:"publish_time" validate:"required" zh:"发布时间"`
Img string `json:"img"`
RequestID string `json:"request_id"`
}
PublishOnRequest struct {
AccessToken string `json:"access_token" validate:"required" zh:"access_token"`
}
PublishOffRequest struct {
AccessToken string `json:"access_token" validate:"required" zh:"access_token"`
}
PublishStatusRequest struct {
AccessToken string `json:"access_token" validate:"required" zh:"access_token"`
RequestID string `json:"request_id"`
}
PublishExecuteOnceRequest struct {
AccessToken string `json:"access_token" validate:"required" zh:"access_token"`
}
PublishExecuteRetryRequest struct {
AccessToken string `json:"access_token" validate:"required" zh:"access_token"`
RequestID string `json:"request_id" validate:"required" zh:"请求ID"`
}
GetPublishListRequest struct {
AccessToken string `json:"access_token" validate:"required" zh:"access_token"`
Page int `json:"page"`
PageSize int `json:"page_size"`
UserIndex string `json:"user_index"`
Tag string `json:"tag"`
Type int `json:"type"`
PlatIndex string `json:"plat_index"`
Status int `json:"status"`
RequestID string `json:"request_id"`
}
LoginPlatformRequest struct {
AccessToken string `json:"access_token" validate:"required" zh:"access_token"`
UserIndex string `json:"user_index" validate:"required" zh:"用户索引"`
PlatIndex string `json:"plat_index" validate:"required" zh:"平台索引"`
}
LogoutPlatformRequest struct {
AccessToken string `json:"access_token" validate:"required" zh:"access_token"`
UserIndex string `json:"user_index" validate:"required" zh:"用户索引"`
PlatIndex string `json:"plat_index" validate:"required" zh:"平台索引"`
}
)

View File

@ -0,0 +1,297 @@
package manager
import (
"fmt"
"geo/internal/config"
"geo/internal/publisher"
"geo/pkg"
"log"
"sync"
"time"
"geo/utils"
)
type PublishManager struct {
AutoStatus bool
Conf *config.Config
TokenID int
running bool
mu sync.Mutex
stopCh chan struct{}
currentPublisher interface{}
db *utils.Db
}
var publishManager *PublishManager
var once sync.Once
func GetPublishManager(config *config.Config, db *utils.Db) *PublishManager {
once.Do(func() {
publishManager = &PublishManager{
AutoStatus: false,
Conf: config,
stopCh: make(chan struct{}),
db: db,
}
})
return publishManager
}
func (pm *PublishManager) Start(tokenID int) bool {
pm.mu.Lock()
defer pm.mu.Unlock()
if pm.AutoStatus {
return false
}
pm.TokenID = tokenID
pm.AutoStatus = true
pm.stopCh = make(chan struct{})
go pm.autoPublishLoop()
return true
}
func (pm *PublishManager) Stop() bool {
pm.mu.Lock()
defer pm.mu.Unlock()
if !pm.AutoStatus {
return false
}
pm.AutoStatus = false
close(pm.stopCh)
return true
}
func (pm *PublishManager) autoPublishLoop() {
log.Println("自动发布服务已启动")
for {
select {
case <-pm.stopCh:
log.Println("自动发布服务已停止")
return
default:
pm.batchPublish()
time.Sleep(30 * time.Second)
}
}
}
func (pm *PublishManager) batchPublish() {
if !pm.AutoStatus {
return
}
publishData := pm.getPendingPublish()
if publishData == nil {
return
}
// 使用 defer recover 防止 panic 导致整个循环崩溃
defer func() {
if r := recover(); r != nil {
log.Printf("批处理发布发生 panic: %v", r)
}
}()
pm.processSingleTask(publishData)
}
func (pm *PublishManager) getPendingPublish() map[string]interface{} {
currentTime := time.Now().Format("2006-01-02 15:04:05")
sql := `
SELECT p.*, pl.*
FROM publish p
INNER JOIN plat pl ON p.plat_index = pl.index AND pl.status = 1
WHERE p.token_id = ? AND p.status = 1 AND p.publish_time <= ?
ORDER BY p.publish_time DESC
LIMIT 1
`
result, err := pm.db.GetOne(sql, pm.TokenID, currentTime)
if err != nil {
log.Printf("查询待发布任务失败: token_id=%d, error=%v", pm.TokenID, err)
return nil
}
if result == nil {
log.Printf("没有待发布任务: token_id=%d, current_time=%s", pm.TokenID, currentTime)
return nil
}
requestID := getString(result, "request_id")
log.Printf("获取到待发布任务: token_id=%d, request_id=%s", pm.TokenID, requestID)
return result
}
func (pm *PublishManager) GetTaskByRequestID(requestID string) (map[string]interface{}, error) {
sql := `
SELECT p.*, pl.*
FROM publish p
INNER JOIN plat pl ON p.plat_index COLLATE utf8mb4_unicode_ci = pl.index AND pl.status = 1
WHERE p.request_id = ?
`
return pm.db.GetOne(sql, requestID)
}
func (pm *PublishManager) processSingleTask(publishData map[string]interface{}) map[string]interface{} {
requestID := getString(publishData, "request_id")
platIndex := getString(publishData, "plat_index")
title := getString(publishData, "title")
tagRaw := getString(publishData, "tag")
userIndex := getString(publishData, "user_index")
url := getString(publishData, "url")
imgURL := getString(publishData, "img")
log.Printf("[任务 %s] 开始处理,平台:%s标题%s", requestID, platIndex, title)
// 更新状态为发布中
pm.updatePublishStatus(requestID, 2, "")
log.Printf("[任务 %s] 状态已更新为发布中", requestID)
// 下载文件
docPath, err := pkg.DownloadFile(url, "", requestID+".docx")
if err != nil {
errMsg := fmt.Sprintf("下载文档失败: %v", err)
log.Printf("[任务 %s] %s", requestID, errMsg)
pm.updatePublishStatus(requestID, 3, errMsg)
return map[string]interface{}{"success": false, "message": errMsg, "request_id": requestID}
}
log.Printf("[任务 %s] 文档下载成功: %s", requestID, docPath)
// 下载图片
imgPath, err := pkg.DownloadImage(imgURL, requestID, "img")
if err != nil {
errMsg := fmt.Sprintf("下载图片失败: %v", err)
log.Printf("[任务 %s] %s", requestID, errMsg)
pm.updatePublishStatus(requestID, 3, errMsg)
// 图片下载失败,清理已下载的文档
pkg.DeleteFile(docPath)
return map[string]interface{}{"success": false, "message": errMsg, "request_id": requestID}
}
log.Printf("[任务 %s] 图片下载成功: %s", requestID, imgPath)
// 确保清理临时文件
defer func() {
pkg.DeleteFile(docPath)
pkg.DeleteFile(imgPath)
}()
// 解析标签
tags := pkg.ParseTags(tagRaw)
log.Printf("[任务 %s] 标签解析完成: %v", requestID, tags)
// 提取内容
content, err := pkg.ExtractWordContent(docPath, "html")
if err != nil {
errMsg := fmt.Sprintf("提取文档内容失败: %v", err)
log.Printf("[任务 %s] %s", requestID, errMsg)
pm.updatePublishStatus(requestID, 3, errMsg)
return map[string]interface{}{"success": false, "message": errMsg, "request_id": requestID}
}
log.Printf("[任务 %s] 内容提取成功,长度: %d", requestID, len(content))
// 获取发布器
publisherClass := getPublisherClass(platIndex)
if publisherClass == nil {
errMsg := fmt.Sprintf("不支持的平台: %s", platIndex)
log.Printf("[任务 %s] %s", requestID, errMsg)
pm.updatePublishStatus(requestID, 3, errMsg)
return map[string]interface{}{"success": false, "message": errMsg, "request_id": requestID}
}
// 创建并执行发布器
var pub interface{ PublishNote() (bool, string) }
switch platIndex {
case "xhs":
pub = publisher.NewXiaohongshuPublisher(false, title, content, tags, userIndex, platIndex, requestID, imgPath, docPath, publishData, pm.Conf)
log.Printf("[任务 %s] 创建小红书发布器", requestID)
case "bjh":
pub = publisher.NewBaijiahaoPublisher(false, title, content, tags, userIndex, platIndex, requestID, imgPath, docPath, publishData, pm.Conf)
log.Printf("[任务 %s] 创建百家号发布器", requestID)
default:
log.Printf("[任务 %s] 未知平台 %s使用默认小红书发布器", requestID, platIndex)
pub = publisher.NewXiaohongshuPublisher(false, title, content, tags, userIndex, platIndex, requestID, imgPath, docPath, publishData, pm.Conf)
}
log.Printf("[任务 %s] 开始执行发布...", requestID)
success, message := pub.PublishNote()
if success {
log.Printf("[任务 %s] 发布成功: %s", requestID, message)
pm.updatePublishStatus(requestID, 4, message)
} else {
log.Printf("[任务 %s] 发布失败: %s", requestID, message)
pm.updatePublishStatus(requestID, 3, message)
}
return map[string]interface{}{
"success": success,
"message": message,
"request_id": requestID,
}
}
func (pm *PublishManager) updatePublishStatus(requestID string, status int, message string) {
if message != "" {
pm.db.Execute("UPDATE publish SET status = ?, msg = ? WHERE request_id = ?", status, message, requestID)
} else {
pm.db.Execute("UPDATE publish SET status = ? WHERE request_id = ?", status, requestID)
}
}
func (pm *PublishManager) ExecuteOnce(tokenId int32) map[string]interface{} {
publishData := pm.getPendingPublish()
if publishData == nil {
return map[string]interface{}{"success": false, "message": "没有待发布任务"}
}
return pm.processSingleTask(publishData)
}
func (pm *PublishManager) RetryTask(requestID string) map[string]interface{} {
publishData, err := pm.GetTaskByRequestID(requestID)
if err != nil || publishData == nil {
return map[string]interface{}{"success": false, "message": "任务不存在"}
}
return pm.processSingleTask(publishData)
}
func (pm *PublishManager) GetStatus() map[string]interface{} {
return map[string]interface{}{
"auto_status": pm.AutoStatus,
"max_concurrent": pm.Conf.Sys.MaxConcurrent,
"task_timeout": pm.Conf.Sys.TaskTimeout,
}
}
func getPublisherClass(platIndex string) interface{} {
platformMap := map[string]interface{}{
"xhs": struct{}{},
"bjh": struct{}{},
"csdn": struct{}{},
}
return platformMap[platIndex]
}
func getString(m map[string]interface{}, key string) string {
if v, ok := m[key]; ok {
switch v.(type) {
case []uint8:
return string(v.([]uint8))
case string:
return v.(string)
case int64:
return fmt.Sprintf("%d", v)
default:
return fmt.Sprintf("%v", v)
}
}
return ""
}

View File

@ -0,0 +1,383 @@
package publisher
import (
"fmt"
"strings"
"time"
"geo/internal/config"
"github.com/go-rod/rod"
"github.com/go-rod/rod/lib/proto"
)
type BaijiahaoPublisher struct {
*BasePublisher
Category string
ArticleType string
IsTop bool
}
func NewBaijiahaoPublisher(headless bool, title, content string, tags []string, tenantID, platIndex, requestID, imagePath, wordPath string, platInfo map[string]interface{}, cfg *config.Config) *BaijiahaoPublisher {
base := NewBasePublisher(headless, title, content, tags, tenantID, platIndex, requestID, imagePath, wordPath, platInfo, cfg)
if platInfo != nil {
base.LoginURL = getString(platInfo, "login_url")
base.EditorURL = getString(platInfo, "edit_url")
base.LoginedURL = getString(platInfo, "logined_url")
}
return &BaijiahaoPublisher{BasePublisher: base}
}
func (p *BaijiahaoPublisher) CheckLoginStatus() bool {
url := p.GetCurrentURL()
// 如果URL包含登录相关关键词表示未登录
if strings.Contains(url, "login") || strings.Contains(url, "passport") {
return false
}
// 如果URL是编辑页面或主页表示已登录
if strings.Contains(url, "baijiahao") || strings.Contains(url, "edit") {
return true
}
return url != p.LoginURL
}
func (p *BaijiahaoPublisher) CheckLogin() (bool, string) {
p.LogInfo("检查登录状态...")
if err := p.SetupDriver(); err != nil {
return false, fmt.Sprintf("浏览器启动失败: %v", err)
}
defer p.Close()
p.Page.MustNavigate(p.LoginedURL)
p.Sleep(3)
p.WaitForPageReady(5)
if p.CheckLoginStatus() {
p.SaveCookies()
return true, "已登录"
}
return false, "未登录"
}
func (p *BaijiahaoPublisher) WaitLogin() (bool, string) {
p.LogInfo("开始等待登录...")
if err := p.SetupDriver(); err != nil {
return false, fmt.Sprintf("浏览器启动失败: %v", err)
}
defer p.Close()
// 先尝试访问已登录页面
p.Page.MustNavigate(p.LoginedURL)
p.Sleep(3)
if p.CheckLoginStatus() {
p.SaveCookies()
p.LogInfo("已有登录状态")
return true, "already_logged_in"
}
// 未登录,跳转到登录页
p.Page.MustNavigate(p.LoginURL)
p.LogInfo("请扫描二维码登录...")
// 等待登录完成最多120秒
for i := 0; i < 120; i++ {
time.Sleep(1 * time.Second)
if p.CheckLoginStatus() {
p.SaveCookies()
p.LogInfo("登录成功")
return true, "login_success"
}
}
return false, "登录超时"
}
func (p *BaijiahaoPublisher) inputTitle() error {
p.LogInfo("输入标题...")
titleSelectors := []string{
".client_pages_edit_components_titleInput [contenteditable='true']",
".input-box [contenteditable='true']",
"[contenteditable='true']",
}
var titleInput *rod.Element
var err error
for _, selector := range titleSelectors {
titleInput, err = p.WaitForElementVisible(selector, 5)
if err == nil && titleInput != nil {
p.LogInfo(fmt.Sprintf("找到标题输入框: %s", selector))
break
}
}
if titleInput == nil {
return fmt.Errorf("未找到标题输入框")
}
// 点击获取焦点
if err := titleInput.Click(proto.InputMouseButtonLeft, 1); err != nil {
return fmt.Errorf("点击标题框失败: %v", err)
}
p.SleepMs(500)
// 清空输入框
if err := p.ClearContentEditable(titleInput); err != nil {
p.LogInfo(fmt.Sprintf("清空标题框失败: %v", err))
}
p.SleepMs(300)
// 输入标题
if err := p.SetContentEditable(titleInput, p.Title); err != nil {
// 备用输入方式
titleInput.Input(p.Title)
}
p.LogInfo(fmt.Sprintf("标题已输入: %s", p.Title))
return nil
}
func (p *BaijiahaoPublisher) inputContent() error {
p.LogInfo("输入内容...")
// 查找内容编辑器
contentEditor, err := p.WaitForElementVisible(".ProseMirror", 10)
if err != nil {
contentEditor, err = p.WaitForElementVisible("[contenteditable='true']", 10)
if err != nil {
return fmt.Errorf("未找到内容编辑器: %v", err)
}
}
// 点击获取焦点
if err := contentEditor.Click(proto.InputMouseButtonLeft, 1); err != nil {
return fmt.Errorf("点击编辑器失败: %v", err)
}
p.SleepMs(500)
// 清空编辑器
if err := p.ClearContentEditable(contentEditor); err != nil {
p.LogInfo(fmt.Sprintf("清空编辑器失败: %v", err))
}
p.SleepMs(300)
// 输入内容
if err := p.SetContentEditable(contentEditor, p.Content); err != nil {
// 备用输入方式
contentEditor.Input(p.Content)
}
p.LogInfo(fmt.Sprintf("内容已输入,长度: %d", len(p.Content)))
return nil
}
func (p *BaijiahaoPublisher) uploadImage() error {
if p.ImagePath == "" {
p.LogInfo("无封面图片,跳过")
return nil
}
p.LogInfo(fmt.Sprintf("上传封面: %s", p.ImagePath))
// 查找封面区域
coverArea, err := p.WaitForElementClickable(".cheetah-spin-container", 5)
if err != nil {
p.LogInfo("未找到封面区域,跳过")
return nil
}
if err := coverArea.Click(proto.InputMouseButtonLeft, 1); err != nil {
p.LogInfo(fmt.Sprintf("点击封面区域失败: %v", err))
}
p.SleepMs(1000)
// 查找文件输入框
fileInput, err := p.Page.Element("input[type='file'][accept*='image']")
if err != nil {
fileInput, err = p.Page.Element("input[type='file']")
if err != nil {
return fmt.Errorf("未找到文件输入框: %v", err)
}
}
// 上传图片
if err := fileInput.SetFiles([]string{p.ImagePath}); err != nil {
return fmt.Errorf("上传图片失败: %v", err)
}
p.LogInfo("图片上传成功")
p.Sleep(3)
// 查找确认按钮
confirmBtn, err := p.WaitForElementClickable(".cheetah-btn-primary", 5)
if err == nil && confirmBtn != nil {
if err := confirmBtn.Click(proto.InputMouseButtonLeft, 1); err != nil {
p.LogInfo(fmt.Sprintf("点击确认按钮失败: %v", err))
}
p.LogInfo("已确认封面")
p.SleepMs(1000)
}
return nil
}
func (p *BaijiahaoPublisher) clickPublish() error {
p.LogInfo("点击发布按钮...")
// 滚动到底部
if _, err := p.Page.Eval(`() => window.scrollTo(0, document.body.scrollHeight)`); err != nil {
p.LogInfo(fmt.Sprintf("滚动到底部失败: %v", err))
}
p.SleepMs(1000)
// 查找发布按钮
publishSelectors := []string{
"[data-testid='publish-btn']",
".op-list-right .cheetah-btn-primary",
".cheetah-btn-primary",
"button:contains('发布')",
}
var publishBtn *rod.Element
var err error
for _, selector := range publishSelectors {
publishBtn, err = p.WaitForElementClickable(selector, 5)
if err == nil && publishBtn != nil {
p.LogInfo(fmt.Sprintf("找到发布按钮: %s", selector))
break
}
}
// 如果还是没找到,通过 XPath 查找
if publishBtn == nil {
publishBtn, err = p.Page.ElementX("//button[contains(text(), '发布')]")
if err != nil {
return fmt.Errorf("未找到发布按钮: %v", err)
}
}
// 滚动到按钮位置
if err := p.ScrollToElement(publishBtn); err != nil {
p.LogInfo(fmt.Sprintf("滚动到按钮失败: %v", err))
}
p.SleepMs(500)
// 点击发布
if err := publishBtn.Click(proto.InputMouseButtonLeft, 1); err != nil {
return fmt.Errorf("点击发布按钮失败: %v", err)
}
p.LogInfo("已点击发布按钮")
return nil
}
func (p *BaijiahaoPublisher) waitForPublishResult() (bool, string) {
p.LogInfo("等待发布结果...")
// 等待最多60秒
for i := 0; i < 60; i++ {
p.SleepMs(1000)
// 检查URL是否跳转到成功页面
currentURL := p.GetCurrentURL()
if strings.Contains(currentURL, "clue") ||
strings.Contains(currentURL, "success") ||
strings.Contains(currentURL, "article/list") {
p.LogInfo("发布成功!")
return true, "发布成功"
}
// 检查是否有成功提示
elements, _ := p.Page.Elements(".cheetah-message-success, .cheetah-message-info, [class*='success']")
for _, el := range elements {
text, _ := el.Text()
if strings.Contains(text, "成功") || strings.Contains(text, "已发布") {
p.LogInfo(fmt.Sprintf("发布成功: %s", text))
return true, text
}
}
// 检查是否有失败提示
elements, _ = p.Page.Elements(".cheetah-message-error, .cheetah-message-warning, [class*='error']")
for _, el := range elements {
text, _ := el.Text()
if strings.Contains(text, "失败") || strings.Contains(text, "错误") {
p.LogError(fmt.Sprintf("发布失败: %s", text))
return false, text
}
}
}
return false, "发布结果未知(超时)"
}
func (p *BaijiahaoPublisher) PublishNote() (bool, string) {
p.LogInfo(strings.Repeat("=", 50))
p.LogInfo("开始发布百家号文章...")
p.LogInfo(fmt.Sprintf("标题: %s", p.Title))
p.LogInfo(fmt.Sprintf("内容长度: %d", len(p.Content)))
p.LogInfo(strings.Repeat("=", 50))
// 初始化浏览器
if err := p.SetupDriver(); err != nil {
return false, fmt.Sprintf("浏览器启动失败: %v", err)
}
defer p.Close()
// 访问编辑器页面
p.Page.MustNavigate(p.EditorURL)
p.Sleep(3)
p.WaitForPageReady(5)
// 尝试加载cookies
if err := p.LoadCookies(); err == nil {
p.RefreshPage()
p.Sleep(2)
if p.CheckLoginStatus() {
p.LogInfo("使用cookies登录成功")
} else {
p.LogInfo("cookies已过期需要重新登录")
return false, "需要登录"
}
}
// 检查登录状态
if !p.CheckLoginStatus() {
return false, "需要登录"
}
// 保存cookies
p.SaveCookies()
// 执行发布流程
steps := []struct {
name string
fn func() error
}{
{"输入标题", p.inputTitle},
{"输入内容", p.inputContent},
{"上传封面", p.uploadImage},
}
for _, step := range steps {
if err := step.fn(); err != nil {
p.LogStep(step.name, false, err.Error())
return false, fmt.Sprintf("%s失败: %v", step.name, err)
}
p.LogStep(step.name, true, "")
p.SleepMs(500)
}
// 点击发布
if err := p.clickPublish(); err != nil {
return false, err.Error()
}
// 等待发布结果
return p.waitForPublishResult()
}

282
internal/publisher/base.go Normal file
View File

@ -0,0 +1,282 @@
package publisher
import (
"encoding/json"
"fmt"
"log"
"os"
"path/filepath"
"time"
"geo/internal/config"
"github.com/go-rod/rod"
"github.com/go-rod/rod/lib/launcher"
"github.com/go-rod/rod/lib/proto"
)
type BasePublisher struct {
Headless bool
Title string
Content string
Tags []string
TenantID string
PlatIndex string
RequestID string
ImagePath string
WordPath string
Browser *rod.Browser
Page *rod.Page
Logger *log.Logger
LogFile *os.File
LoginURL string
EditorURL string
LoginedURL string
CookiesFile string
PlatInfo map[string]interface{}
config *config.Config
}
func NewBasePublisher(headless bool, title, content string, tags []string, tenantID, platIndex, requestID, imagePath, wordPath string, platInfo map[string]interface{}, config *config.Config) *BasePublisher {
cookiesDir := filepath.Join(config.Sys.CookiesDir, tenantID)
os.MkdirAll(cookiesDir, 0755)
cookiesFile := filepath.Join(cookiesDir, platIndex+".json")
logFile, _ := os.Create(filepath.Join(config.Sys.LogsDir, requestID+".log"))
logger := log.New(logFile, "", log.LstdFlags)
return &BasePublisher{
Headless: headless,
Title: title,
Content: content,
Tags: tags,
TenantID: tenantID,
PlatIndex: platIndex,
RequestID: requestID,
ImagePath: imagePath,
WordPath: wordPath,
Logger: logger,
LogFile: logFile,
CookiesFile: cookiesFile,
PlatInfo: platInfo,
config: config,
}
}
func (b *BasePublisher) SetupDriver() error {
l := launcher.New()
l.Headless(b.Headless)
// 设置用户数据目录
userDataDir := filepath.Join(b.config.Sys.ChromeDataDir, b.TenantID)
os.MkdirAll(userDataDir, 0755)
l.UserDataDir(userDataDir)
// 设置 Leakless 模式(解决 Windows 上的问题)
l.Leakless(false)
// 设置 Chrome 启动参数
l.Set("disable-blink-features", "AutomationControlled")
l.Set("no-sandbox")
l.Set("disable-dev-shm-usage")
l.Set("disable-gpu")
l.Set("disable-software-rasterizer")
l.Set("disable-setuid-sandbox")
l.Set("remote-debugging-port", "9222")
// 窗口大小
l.Set("window-size", "1920,1080")
l.Set("lang", "zh-CN")
url, err := l.Launch()
if err != nil {
return fmt.Errorf("启动浏览器失败: %v", err)
}
b.Browser = rod.New().ControlURL(url).MustConnect()
b.Page = b.Browser.MustPage()
b.Page.MustSetViewport(1920, 1080, 1, false)
return nil
}
func (b *BasePublisher) Close() {
if b.Page != nil {
b.Page.Close()
}
if b.Browser != nil {
b.Browser.Close()
}
if b.LogFile != nil {
b.LogFile.Close()
}
}
func (b *BasePublisher) SaveCookies() error {
cookies, err := b.Page.Cookies(nil)
if err != nil {
return err
}
data, err := json.Marshal(cookies)
if err != nil {
return err
}
return os.WriteFile(b.CookiesFile, data, 0644)
}
func (b *BasePublisher) LoadCookies() error {
data, err := os.ReadFile(b.CookiesFile)
if err != nil {
return err
}
var cookies []*proto.NetworkCookieParam
if err := json.Unmarshal(data, &cookies); err != nil {
return err
}
return b.Page.SetCookies(cookies)
}
func (b *BasePublisher) RefreshPage() error {
_, err := b.Page.Eval(`() => location.reload()`)
return err
}
func (b *BasePublisher) WaitForPageReady(timeout int) error {
return b.Page.Timeout(time.Duration(timeout) * time.Second).WaitLoad()
}
func (b *BasePublisher) WaitForElement(selector string, timeout int) (*rod.Element, error) {
return b.Page.Timeout(time.Duration(timeout) * time.Second).Element(selector)
}
func (b *BasePublisher) WaitForElementVisible(selector string, timeout int) (*rod.Element, error) {
el, err := b.WaitForElement(selector, timeout)
if err != nil {
return nil, err
}
if err := el.WaitVisible(); err != nil {
return nil, err
}
return el, nil
}
func (b *BasePublisher) WaitForElementClickable(selector string, timeout int) (*rod.Element, error) {
el, err := b.WaitForElementVisible(selector, timeout)
if err != nil {
return nil, err
}
if err := el.WaitVisible(); err != nil {
return nil, err
}
return el, nil
}
func (b *BasePublisher) JSClick(element *rod.Element) error {
// 使用 element.Evaluate 并传入 EvalOptions
_, err := element.Evaluate(&rod.EvalOptions{
JS: `el => el.click()`,
})
return err
}
func (b *BasePublisher) ScrollToElement(element *rod.Element) error {
_, err := element.Evaluate(&rod.EvalOptions{
JS: `el => el.scrollIntoView({block: 'center', behavior: 'smooth'})`,
})
return err
}
func (b *BasePublisher) Sleep(seconds int) {
time.Sleep(time.Duration(seconds) * time.Second)
}
func (b *BasePublisher) LogStep(stepName string, success bool, message string) {
if success {
b.Logger.Printf("✅ %s: 成功 %s", stepName, message)
} else {
b.Logger.Printf("❌ %s: 失败 %s", stepName, message)
}
}
func (b *BasePublisher) LogInfo(message string) {
b.Logger.Printf("📌 %s", message)
}
func (b *BasePublisher) LogError(message string) {
b.Logger.Printf("❌ %s", message)
}
func (b *BasePublisher) GetCurrentURL() string {
info := b.Page.MustInfo()
return info.URL
}
func (b *BasePublisher) Screenshot(filename string) error {
data, err := b.Page.Screenshot(false, nil)
if err != nil {
return err
}
return os.WriteFile(filename, data, 0644)
}
// 抽象方法 - 子类需要实现
func (b *BasePublisher) WaitLogin() (bool, string) {
return false, "需要实现"
}
func (b *BasePublisher) CheckLoginStatus() bool {
return false
}
func (b *BasePublisher) CheckLogin() (bool, string) {
return false, "需要实现"
}
func (b *BasePublisher) PublishNote() (bool, string) {
return false, "需要实现"
}
// ClearContentEditable 清空 contenteditable 元素的内容
func (b *BasePublisher) ClearContentEditable(element *rod.Element) error {
_, err := element.Evaluate(&rod.EvalOptions{
JS: `el => { el.innerText = ''; el.innerHTML = ''; el.dispatchEvent(new Event('input', {bubbles: true})); }`,
})
return err
}
// SetContentEditable 设置 contenteditable 元素的内容
func (b *BasePublisher) SetContentEditable(element *rod.Element, content string) error {
_, err := element.Evaluate(&rod.EvalOptions{
JS: `(el, val) => { el.innerText = val; el.dispatchEvent(new Event('input', {bubbles: true})); }`,
JSArgs: []interface{}{content},
})
return err
}
// SetInputValue 设置输入框值并触发事件
func (b *BasePublisher) SetInputValue(element *rod.Element, value string) error {
_, err := element.Evaluate(&rod.EvalOptions{
JS: `(el, val) => { el.value = val; el.dispatchEvent(new Event('input', {bubbles: true})); el.dispatchEvent(new Event('change', {bubbles: true})); }`,
JSArgs: []interface{}{value},
})
return err
}
// ClearInput 清空输入框
func (b *BasePublisher) ClearInput(element *rod.Element) error {
_, err := element.Evaluate(&rod.EvalOptions{
JS: `el => { el.value = ''; el.dispatchEvent(new Event('input', {bubbles: true})); }`,
})
return err
}
// SleepMs 毫秒级等待
func (b *BasePublisher) SleepMs(milliseconds int) {
time.Sleep(time.Duration(milliseconds) * time.Millisecond)
}

View File

@ -0,0 +1,425 @@
package publisher
import (
"fmt"
"strings"
"time"
"geo/internal/config"
"github.com/go-rod/rod"
"github.com/go-rod/rod/lib/proto"
)
type XiaohongshuPublisher struct {
*BasePublisher
}
func NewXiaohongshuPublisher(headless bool, title, content string, tags []string, tenantID, platIndex, requestID, imagePath, wordPath string, platInfo map[string]interface{}, cfg *config.Config) *XiaohongshuPublisher {
base := NewBasePublisher(headless, title, content, tags, tenantID, platIndex, requestID, imagePath, wordPath, platInfo, cfg)
if platInfo != nil {
base.LoginURL = getString(platInfo, "login_url")
base.EditorURL = getString(platInfo, "edit_url")
base.LoginedURL = getString(platInfo, "logined_url")
}
return &XiaohongshuPublisher{BasePublisher: base}
}
func (p *XiaohongshuPublisher) CheckLoginStatus() bool {
url := p.GetCurrentURL()
// 如果URL包含登录相关关键词表示未登录
if strings.Contains(url, "login") || strings.Contains(url, "signin") || strings.Contains(url, "passport") {
return false
}
// 如果URL是编辑页面或主页表示已登录
if strings.Contains(url, "creator") || strings.Contains(url, "editor") || strings.Contains(url, "publish") {
return true
}
return true
}
func (p *XiaohongshuPublisher) CheckLogin() (bool, string) {
p.LogInfo("检查登录状态...")
if err := p.SetupDriver(); err != nil {
return false, fmt.Sprintf("浏览器启动失败: %v", err)
}
defer p.Close()
p.Page.MustNavigate(p.LoginedURL)
p.Sleep(3)
p.WaitForPageReady(5)
if p.CheckLoginStatus() {
p.SaveCookies()
return true, "已登录"
}
return false, "未登录"
}
func (p *XiaohongshuPublisher) WaitLogin() (bool, string) {
p.LogInfo("开始等待登录...")
if err := p.SetupDriver(); err != nil {
return false, fmt.Sprintf("浏览器启动失败: %v", err)
}
defer p.Close()
// 先尝试访问已登录页面
p.Page.MustNavigate(p.LoginedURL)
p.Sleep(3)
if p.CheckLoginStatus() {
p.SaveCookies()
p.LogInfo("已有登录状态")
return true, "already_logged_in"
}
// 未登录,跳转到登录页
p.Page.MustNavigate(p.LoginURL)
p.LogInfo("请扫描二维码登录...")
// 等待登录完成最多120秒
for i := 0; i < 120; i++ {
time.Sleep(1 * time.Second)
if p.CheckLoginStatus() {
p.SaveCookies()
p.LogInfo("登录成功")
return true, "login_success"
}
}
return false, "登录超时"
}
func (p *XiaohongshuPublisher) inputContent() error {
p.LogInfo("输入文章内容...")
// 等待编辑器加载
contentEditor, err := p.WaitForElementVisible(".tiptap.ProseMirror", 10)
if err != nil {
// 尝试其他选择器
contentEditor, err = p.WaitForElementVisible("[contenteditable='true']", 10)
if err != nil {
return fmt.Errorf("未找到内容编辑器: %v", err)
}
}
// 点击获取焦点 - 使用 Click 方法
if err := contentEditor.Click(proto.InputMouseButtonLeft, 1); err != nil {
return fmt.Errorf("点击编辑器失败: %v", err)
}
p.SleepMs(500)
// 清空现有内容 - 使用 JavaScript 清空
if err := p.ClearContentEditable(contentEditor); err != nil {
p.LogInfo(fmt.Sprintf("清空编辑器失败: %v", err))
}
p.SleepMs(300)
// 输入新内容 - 使用 JavaScript 设置内容
if err := p.SetContentEditable(contentEditor, p.Content); err != nil {
// 如果 JS 方式失败,尝试直接输入
contentEditor.Input(p.Content)
}
p.LogInfo(fmt.Sprintf("内容已输入,长度: %d", len(p.Content)))
return nil
}
func (p *XiaohongshuPublisher) inputTitle() error {
p.LogInfo("输入标题...")
// 查找标题输入框
titleSelectors := []string{
"textarea.d-input",
".d-input input",
"textarea[placeholder*='标题']",
"textarea",
}
var titleInput *rod.Element
var err error
for _, selector := range titleSelectors {
titleInput, err = p.WaitForElementVisible(selector, 3)
if err == nil && titleInput != nil {
p.LogInfo(fmt.Sprintf("找到标题输入框: %s", selector))
break
}
}
if titleInput == nil {
return fmt.Errorf("未找到标题输入框")
}
// 点击获取焦点
if err := titleInput.Click(proto.InputMouseButtonLeft, 1); err != nil {
return fmt.Errorf("点击标题框失败: %v", err)
}
p.SleepMs(500)
// 清空输入框
if err := p.ClearInput(titleInput); err != nil {
// 备用清空方式
titleInput.Input("")
}
p.SleepMs(300)
// 输入标题
if err := p.SetInputValue(titleInput, p.Title); err != nil {
// 备用输入方式
titleInput.Input(p.Title)
}
p.LogInfo(fmt.Sprintf("标题已输入: %s", p.Title))
return nil
}
func (p *XiaohongshuPublisher) inputTags() error {
if len(p.Tags) == 0 {
p.LogInfo("无标签需要设置")
return nil
}
p.LogInfo(fmt.Sprintf("设置标签: %v", p.Tags))
// 构建标签字符串
tagStr := ""
for _, tag := range p.Tags {
if tagStr != "" {
tagStr += " "
}
tagStr += "#" + tag
}
// 查找标签输入区域
tagInput, err := p.WaitForElementVisible(".tiptap-container [contenteditable='true']", 5)
if err != nil {
p.LogInfo("未找到标签输入框,跳过标签设置")
return nil
}
if err := tagInput.Click(proto.InputMouseButtonLeft, 1); err != nil {
p.LogInfo(fmt.Sprintf("点击标签框失败: %v", err))
}
p.SleepMs(500)
if err := p.SetContentEditable(tagInput, tagStr); err != nil {
tagInput.Input(tagStr)
}
p.LogInfo("标签设置完成")
return nil
}
func (p *XiaohongshuPublisher) uploadImage() error {
if p.ImagePath == "" {
p.LogInfo("无封面图片,跳过上传")
return nil
}
p.LogInfo(fmt.Sprintf("上传封面图片: %s", p.ImagePath))
// 查找封面上传按钮
uploadBtn, err := p.WaitForElementClickable(".upload-content", 5)
if err != nil {
p.LogInfo("未找到封面上传区域,跳过")
return nil
}
// 使用 Click 方法
if err := uploadBtn.Click(proto.InputMouseButtonLeft, 1); err != nil {
p.LogInfo(fmt.Sprintf("点击上传按钮失败: %v", err))
}
p.SleepMs(1000)
// 查找文件输入框
fileInput, err := p.Page.Element("input[type='file']")
if err != nil {
return fmt.Errorf("未找到文件输入框: %v", err)
}
// 使用 SetFiles 上传文件
if err := fileInput.SetFiles([]string{p.ImagePath}); err != nil {
return fmt.Errorf("上传图片失败: %v", err)
}
p.LogInfo("图片上传成功")
p.Sleep(3)
return nil
}
func (p *XiaohongshuPublisher) clickPublish() error {
p.LogInfo("点击发布按钮...")
// 滚动到底部
if _, err := p.Page.Eval(`() => window.scrollTo(0, document.body.scrollHeight)`); err != nil {
p.LogInfo(fmt.Sprintf("滚动到底部失败: %v", err))
}
p.SleepMs(1000)
// 查找发布按钮
publishSelectors := []string{
".publish-page-publish-btn button",
".publish-btn",
".submit-btn",
"button[type='submit']",
}
var publishBtn *rod.Element
var err error
for _, selector := range publishSelectors {
publishBtn, err = p.WaitForElementClickable(selector, 5)
if err == nil && publishBtn != nil {
p.LogInfo(fmt.Sprintf("找到发布按钮: %s", selector))
break
}
}
// 如果还是没找到,通过文本查找
if publishBtn == nil {
publishBtn, err = p.Page.ElementX("//button[contains(text(), '发布')]")
if err != nil {
return fmt.Errorf("未找到发布按钮: %v", err)
}
}
// 滚动到按钮位置
if err := p.ScrollToElement(publishBtn); err != nil {
p.LogInfo(fmt.Sprintf("滚动到按钮失败: %v", err))
}
p.SleepMs(500)
// 点击发布 - 使用 Click 方法
if err := publishBtn.Click(proto.InputMouseButtonLeft, 1); err != nil {
return fmt.Errorf("点击发布按钮失败: %v", err)
}
p.LogInfo("已点击发布按钮")
return nil
}
func (p *XiaohongshuPublisher) waitForPublishResult() (bool, string) {
p.LogInfo("等待发布结果...")
// 等待最多60秒
for i := 0; i < 60; i++ {
p.SleepMs(1000)
// 检查URL是否跳转到成功页面
currentURL := p.GetCurrentURL()
if strings.Contains(currentURL, "success") ||
strings.Contains(currentURL, "content/manage") ||
strings.Contains(currentURL, "work-management") {
p.LogInfo("发布成功!")
return true, "发布成功"
}
// 检查是否有成功提示
elements, _ := p.Page.Elements(".semi-toast-content, .toast-success, [class*='success']")
for _, el := range elements {
text, _ := el.Text()
if strings.Contains(text, "成功") || strings.Contains(text, "已发布") {
p.LogInfo(fmt.Sprintf("发布成功: %s", text))
return true, text
}
}
// 检查是否有失败提示
elements, _ = p.Page.Elements(".semi-toast-error, .toast-error, [class*='error']")
for _, el := range elements {
text, _ := el.Text()
if strings.Contains(text, "失败") || strings.Contains(text, "错误") {
p.LogError(fmt.Sprintf("发布失败: %s", text))
return false, text
}
}
}
return false, "发布结果未知(超时)"
}
func (p *XiaohongshuPublisher) PublishNote() (bool, string) {
p.LogInfo(strings.Repeat("=", 50))
p.LogInfo("开始发布小红书笔记...")
p.LogInfo(fmt.Sprintf("标题: %s", p.Title))
p.LogInfo(fmt.Sprintf("内容长度: %d", len(p.Content)))
p.LogInfo(fmt.Sprintf("标签: %v", p.Tags))
p.LogInfo(strings.Repeat("=", 50))
// 初始化浏览器
if err := p.SetupDriver(); err != nil {
return false, fmt.Sprintf("浏览器启动失败: %v", err)
}
defer p.Close()
// 访问已登录页面
p.Page.MustNavigate(p.LoginedURL)
p.Sleep(3)
p.WaitForPageReady(5)
// 尝试加载cookies
if err := p.LoadCookies(); err == nil {
p.RefreshPage()
p.Sleep(2)
if p.CheckLoginStatus() {
p.LogInfo("使用cookies登录成功")
} else {
p.LogInfo("cookies已过期需要重新登录")
return false, "需要登录"
}
}
// 检查登录状态
if !p.CheckLoginStatus() {
return false, "需要登录"
}
// 保存cookies
p.SaveCookies()
// 访问发布页面
p.Page.MustNavigate(p.EditorURL)
p.Sleep(3)
p.WaitForPageReady(5)
// 执行发布流程
steps := []struct {
name string
fn func() error
}{
{"输入内容", p.inputContent},
{"输入标题", p.inputTitle},
{"设置标签", p.inputTags},
{"上传封面", p.uploadImage},
}
for _, step := range steps {
if err := step.fn(); err != nil {
p.LogStep(step.name, false, err.Error())
return false, fmt.Sprintf("%s失败: %v", step.name, err)
}
p.LogStep(step.name, true, "")
p.SleepMs(500)
}
// 点击发布
if err := p.clickPublish(); err != nil {
return false, err.Error()
}
// 等待发布结果
return p.waitForPublishResult()
}
func getString(m map[string]interface{}, key string) string {
if v, ok := m[key]; ok {
if s, ok := v.(string); ok {
return s
}
}
return ""
}

25
internal/server/http.go Normal file
View File

@ -0,0 +1,25 @@
package server
import (
"geo/internal/server/router"
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/middleware/logger"
"github.com/gofiber/fiber/v2/middleware/recover"
)
func NewHTTPServer(routerServer *router.RouterServer) *fiber.App {
//构建 server
app := initRoute()
router.SetupRoutes(app, routerServer)
return app
}
func initRoute() *fiber.App {
app := fiber.New()
app.Use(
recover.New(),
logger.New(),
)
return app
}

View File

@ -0,0 +1,14 @@
package server
import (
"geo/internal/server/router"
"github.com/google/wire"
)
var ProviderSetServer = wire.NewSet(
NewServers,
NewHTTPServer,
router.NewRouterServer,
router.NewAppModule,
)

View File

@ -0,0 +1,47 @@
package router
import (
"geo/internal/config"
"geo/internal/entitys"
"geo/internal/service"
"github.com/gofiber/fiber/v2"
)
type AppModule struct {
cfg *config.Config
appService *service.AppService
loginService *service.LoginService
publishService *service.PublishService
}
func NewAppModule(
cfg *config.Config,
appService *service.AppService,
loginService *service.LoginService,
publishService *service.PublishService,
) *AppModule {
return &AppModule{
cfg: cfg,
appService: appService,
loginService: loginService,
publishService: publishService,
}
}
func (m *AppModule) Register(router fiber.Router) {
router.Post("/login_app", vali(m.appService.LoginApp, &entitys.LoginAppRequest{}))
router.Post("/get_user_and_auto_status", vali(m.appService.GetUserAndAutoStatus, &entitys.GetUserAndAutoStatusRequest{}))
router.Post("/add_user", vali(m.appService.AddUser, &entitys.AddUserRequest{}))
router.Post("/del_user", vali(m.appService.DelUser, &entitys.DelUserRequest{}))
router.Post("/get_app", vali(m.appService.GetApp, &entitys.GetAppRequest{}))
router.Post("/publish_records", vali(m.publishService.PublishRecords, &entitys.PublishRecordsRequest{}))
router.Post("/publish_on", vali(m.publishService.PublishOn, &entitys.PublishOnRequest{}))
router.Post("/publish_off", vali(m.publishService.PublishOff, &entitys.PublishOffRequest{}))
router.Post("/publish_status", vali(m.publishService.PublishStatus, &entitys.PublishStatusRequest{}))
router.Post("/publish_execute_once", vali(m.publishService.PublishExecuteOnce, &entitys.PublishExecuteOnceRequest{}))
router.Post("/publish_execute_retry", vali(m.publishService.PublishExecuteRetry, &entitys.PublishExecuteRetryRequest{}))
router.Post("/get_publish_list", vali(m.publishService.GetPublishList, &entitys.GetPublishListRequest{}))
router.Post("/login_platform", vali(m.loginService.LoginPlatform, &entitys.LoginPlatformRequest{}))
router.Post("/logout_platform", vali(m.loginService.LogoutPlatform, &entitys.LogoutPlatformRequest{}))
}

View File

@ -0,0 +1,128 @@
package router
import (
"encoding/json"
"errors"
"geo/pkg"
"geo/tmpl/errcode"
"reflect"
"strings"
"github.com/go-playground/validator/v10"
"github.com/gofiber/fiber/v2"
)
type Module interface {
Register(r fiber.Router)
}
type RouterServer struct {
modules []Module
}
func NewRouterServer(
app *AppModule,
) *RouterServer {
return &RouterServer{
modules: []Module{app},
}
}
// SetupRoutes 设置路由
func SetupRoutes(app *fiber.App, routerServer *RouterServer) {
app.Use(func(c *fiber.Ctx) error {
// 设置 CORS 头
c.Set("Access-Control-Allow-Origin", "*")
c.Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
c.Set("Access-Control-Allow-Headers", "Content-Type, Authorization")
// 如果是预检请求OPTIONS直接返回 204
if c.Method() == "OPTIONS" {
return c.SendStatus(fiber.StatusNoContent) // 204
}
// 继续处理后续中间件或路由
return c.Next()
})
registerResponse(app)
// 注册所有模块
for _, module := range routerServer.modules {
module.Register(app)
}
}
func registerResponse(router fiber.Router) {
// 自定义返回
router.Use(func(c *fiber.Ctx) error {
err := c.Next()
return registerCommon(c, err)
})
}
func registerCommon(c *fiber.Ctx, err error) error {
// 调用下一个中间件或路由处理函数
// 如果有错误发生
if err != nil {
var businessErr *errcode.BusinessErr
var fiberError *fiber.Error
switch {
case errors.As(err, &businessErr):
case errors.As(err, &fiberError):
errors.As(err, &fiberError)
businessErr = errcode.NewBusinessErr(fiberError.Code, fiberError.Message)
default:
businessErr = errcode.SystemError
}
// 返回自定义错误响应
return c.JSON(fiber.Map{
"message": businessErr.Error(),
"code": businessErr.Code(),
"data": nil,
})
}
// 如果没有错误发生,继续处理请求
var data interface{}
json.Unmarshal(c.Response().Body(), &data)
return c.JSON(fiber.Map{
"data": data,
"message": errcode.Success.Error(),
"code": errcode.Success.Code(),
})
}
func vali[T any](handler func(*fiber.Ctx, *T) error, _ *T) fiber.Handler {
return func(c *fiber.Ctx) error {
var data T
// 解析请求
if err := c.BodyParser(&data); err != nil {
return errcode.ParamErr(err.Error())
}
// 创建验证器
validate := validator.New()
// 注册中文标签
validate.RegisterTagNameFunc(func(fld reflect.StructField) string {
name := fld.Tag.Get("zh")
if name == "" {
name = fld.Tag.Get("json")
}
return name
})
// 验证
if err := validate.Struct(&data); err != nil {
er := make([]string, len(err.(validator.ValidationErrors)))
for k, e := range err.(validator.ValidationErrors) {
er[k] = pkg.GetErr(e.Tag(), e.Field(), e.Param())
}
return errcode.ParamErr(strings.Join(er, ","))
}
return handler(c, &data)
}
}

19
internal/server/server.go Normal file
View File

@ -0,0 +1,19 @@
package server
import (
"geo/internal/config"
"github.com/gofiber/fiber/v2"
)
type Servers struct {
cfg *config.Config
HttpServer *fiber.App
}
func NewServers(cfg *config.Config, fiber *fiber.App) *Servers {
return &Servers{
HttpServer: fiber,
cfg: cfg,
}
}

173
internal/service/app.go Normal file
View File

@ -0,0 +1,173 @@
package service
import (
"fmt"
"github.com/gofiber/fiber/v2"
"xorm.io/builder"
"geo/internal/biz"
"geo/internal/config"
"geo/internal/data/impl"
"geo/internal/data/model"
"geo/internal/entitys"
"geo/internal/manager"
"geo/pkg"
"geo/tmpl/errcode"
)
type AppService struct {
cfg *config.Config
tokenImpl *impl.TokenImpl
userImpl *impl.UserImpl
platImpl *impl.PlatImpl
publishBiz *biz.PublishBiz
}
func NewAppService(
cfg *config.Config,
tokenImpl *impl.TokenImpl,
userImpl *impl.UserImpl,
platImpl *impl.PlatImpl,
publishBiz *biz.PublishBiz,
) *AppService {
return &AppService{
cfg: cfg,
tokenImpl: tokenImpl,
userImpl: userImpl,
platImpl: platImpl,
publishBiz: publishBiz,
}
}
func (a *AppService) LoginApp(c *fiber.Ctx, req *entitys.LoginAppRequest) error {
cond := builder.NewCond().
And(builder.Eq{"secret": req.Secret}).
And(builder.Eq{"status": 1})
tokenInfo := &model.Token{}
err := a.tokenImpl.GetOneBySearchStruct(c.UserContext(), &cond, tokenInfo)
if err != nil || tokenInfo == nil {
return errcode.Forbidden("密钥无效")
}
accessToken := pkg.GenerateUUID()
err = a.tokenImpl.UpdateByKey(c.UserContext(), a.tokenImpl.PrimaryKey(), tokenInfo.ID, &model.Token{
AccessToken: accessToken,
})
if err != nil {
return err
}
return pkg.HandleResponse(c, fiber.Map{
"access_token": accessToken,
"user_limit": tokenInfo.UserLimit,
"expire_time": tokenInfo.ExpireTime,
})
}
func (a *AppService) GetUserAndAutoStatus(c *fiber.Ctx, req *entitys.GetUserAndAutoStatusRequest) error {
tokenInfo, err := a.publishBiz.ValidateAccessToken(c.UserContext(), req.AccessToken)
if err != nil {
return err
}
cond := builder.NewCond().
And(builder.Eq{"token_id": tokenInfo.ID}).
And(builder.Eq{"status": 1})
users := make([]model.User, 0)
_, err = a.userImpl.GetListToStruct(c.UserContext(), &cond, nil, &users, "id desc")
if err != nil {
return err
}
pm := manager.GetPublishManager(a.cfg, a.tokenImpl.GetDb())
return pkg.HandleResponse(c, fiber.Map{
"user": users,
"auto_status": pm.AutoStatus,
})
}
func (a *AppService) AddUser(c *fiber.Ctx, req *entitys.AddUserRequest) error {
tokenInfo, err := a.publishBiz.ValidateAccessToken(c.UserContext(), req.AccessToken)
if err != nil {
return err
}
cond := builder.NewCond().
And(builder.Eq{"token_id": tokenInfo.ID}).
And(builder.Eq{"status": 1})
userCount, err := a.userImpl.CountByCond(c.UserContext(), &cond)
if err != nil {
return err
}
if userCount >= int64(tokenInfo.UserLimit) {
return errcode.ParamErr(fmt.Sprintf("用户已达上限(%d)", tokenInfo.UserLimit))
}
userIndex := pkg.GenerateUserIndex()
err = a.userImpl.Add(c.UserContext(), &model.User{
TokenID: tokenInfo.ID,
UserIndex: userIndex,
Name: req.Name,
Status: 1,
})
if err != nil {
return errcode.SqlErr(err)
}
// 获取平台列表并关联
platCond := builder.NewCond().And(builder.Eq{"status": 1})
plats := make([]model.Plat, 0)
_, err = a.platImpl.GetListToStruct(c.UserContext(), &platCond, nil, &plats, "")
if err != nil {
return err
}
for _, plat := range plats {
loginRelation := &model.LoginRelation{
UserIndex: userIndex,
PlatIndex: plat.Index,
LoginStatus: 2,
Status: 1,
}
// 这里需要调用 loginRelationImpl 的 Add 方法
_ = loginRelation
}
return pkg.HandleResponse(c, fiber.Map{
"user_name": req.Name,
"user_index": userIndex,
})
}
func (a *AppService) DelUser(c *fiber.Ctx, req *entitys.DelUserRequest) error {
// 需要从请求中获取 access_token
// 这里简化处理
err := a.userImpl.DeleteByKey(c.UserContext(), a.userImpl.PrimaryKey(), int64(req.ID))
if err != nil {
return err
}
return pkg.HandleResponse(c, fiber.Map{})
}
func (a *AppService) GetApp(c *fiber.Ctx, req *entitys.GetAppRequest) error {
_, err := a.publishBiz.ValidateAccessToken(c.UserContext(), req.AccessToken)
if err != nil {
return err
}
cond := builder.NewCond().
And(builder.Eq{"login_relation.user_index": req.UserIndex}).
And(builder.Eq{"login_relation.status": 1}).
And(builder.Eq{"plat.status": 1})
result, err := a.platImpl.GetPlatListWithLoginStatus(c.UserContext(), &cond)
if err != nil {
return err
}
return pkg.HandleResponse(c, result)
}

95
internal/service/login.go Normal file
View File

@ -0,0 +1,95 @@
package service
import (
"os"
"path/filepath"
"github.com/gofiber/fiber/v2"
"geo/internal/biz"
"geo/internal/config"
"geo/internal/entitys"
"geo/internal/publisher"
"geo/pkg"
"geo/tmpl/errcode"
)
type LoginService struct {
cfg *config.Config
publishBiz *biz.PublishBiz
}
func NewLoginService(
cfg *config.Config,
publishBiz *biz.PublishBiz,
) *LoginService {
return &LoginService{
cfg: cfg,
publishBiz: publishBiz,
}
}
func (s *LoginService) LoginPlatform(c *fiber.Ctx, req *entitys.LoginPlatformRequest) error {
_, err := s.publishBiz.ValidateAccessToken(c.UserContext(), req.AccessToken)
if err != nil {
return err
}
// 获取平台信息
platInfo, err := s.publishBiz.GetPlatInfo(c.UserContext(), req.PlatIndex)
if err != nil {
return errcode.NotFound("平台不存在")
}
// 创建发布器
platMap := map[string]interface{}{
"login_url": platInfo.LoginURL,
"edit_url": platInfo.EditURL,
"logined_url": platInfo.LoginedURL,
}
var pub interface{ WaitLogin() (bool, string) }
switch req.PlatIndex {
case "xhs":
pub = publisher.NewXiaohongshuPublisher(false, "", "", nil, req.UserIndex, req.PlatIndex, "", "", "", platMap, s.cfg)
default:
pub = publisher.NewXiaohongshuPublisher(false, "", "", nil, req.UserIndex, req.PlatIndex, "", "", "", platMap, s.cfg)
}
success, msg := pub.WaitLogin()
if !success {
return errcode.SysErr(msg)
}
// 更新登录状态
err = s.publishBiz.UpdateLoginStatus(c.UserContext(), req.UserIndex, req.PlatIndex, 1)
if err != nil {
return err
}
return pkg.HandleResponse(c, fiber.Map{})
}
func (s *LoginService) LogoutPlatform(c *fiber.Ctx, req *entitys.LogoutPlatformRequest) error {
_, err := s.publishBiz.ValidateAccessToken(c.UserContext(), req.AccessToken)
if err != nil {
return err
}
// 更新登录状态为未登录
err = s.publishBiz.UpdateLoginStatus(c.UserContext(), req.UserIndex, req.PlatIndex, 2)
if err != nil {
return err
}
return pkg.HandleResponse(c, fiber.Map{})
}
func (s *LoginService) ServeQrcode(c *fiber.Ctx, filename string) error {
filepath := filepath.Join(s.cfg.Sys.QrcodesDir, filename)
if _, err := os.Stat(filepath); os.IsNotExist(err) {
return errcode.NotFound("二维码不存在")
}
return c.SendFile(filepath)
}

View File

@ -0,0 +1,11 @@
package service
import (
"github.com/google/wire"
)
var ProviderSetAppService = wire.NewSet(
NewAppService,
NewPublishService,
NewLoginService,
)

205
internal/service/publish.go Normal file
View File

@ -0,0 +1,205 @@
package service
import (
"errors"
"fmt"
"geo/internal/data/model"
"geo/utils"
"os"
"path/filepath"
"time"
"github.com/gofiber/fiber/v2"
"geo/internal/biz"
"geo/internal/config"
"geo/internal/entitys"
"geo/internal/manager"
"geo/pkg"
"geo/tmpl/errcode"
)
type PublishService struct {
cfg *config.Config
publishBiz *biz.PublishBiz
db *utils.Db
}
func NewPublishService(
cfg *config.Config,
publishBiz *biz.PublishBiz,
db *utils.Db,
) *PublishService {
return &PublishService{
cfg: cfg,
publishBiz: publishBiz,
db: db,
}
}
func (s *PublishService) PublishRecords(c *fiber.Ctx, req *entitys.PublishRecordsRequest) error {
// 验证token
tokenInfo, err := s.publishBiz.ValidateAccessToken(c.UserContext(), req.AccessToken)
if err != nil {
return err
}
// 转换记录
validRecords := make([]*model.Publish, 0)
for _, record := range req.Records {
publishTime, err := time.Parse("2006-01-02 15:04:05", record.PublishTime)
if err != nil {
return errcode.ParamErr(fmt.Sprintf("时间格式错误: %v", err))
}
validRecords = append(validRecords, &model.Publish{
UserIndex: record.UserIndex,
RequestID: record.RequestID,
Title: record.Title,
Tag: record.Tag,
Type: record.Type,
PlatIndex: record.PlatIndex,
URL: record.URL,
PublishTime: publishTime,
Img: record.Img,
TokenID: tokenInfo.ID,
})
}
err = s.publishBiz.BatchInsertPublish(c.UserContext(), validRecords)
if err != nil {
return err
}
return pkg.HandleResponse(c, fiber.Map{
"total": len(validRecords),
})
}
func (s *PublishService) PublishOn(c *fiber.Ctx, req *entitys.PublishOnRequest) error {
tokenInfo, err := s.publishBiz.ValidateAccessToken(c.UserContext(), req.AccessToken)
if err != nil {
return err
}
pm := manager.GetPublishManager(s.cfg, s.db)
if pm.Start(int(tokenInfo.ID)) {
return pkg.HandleResponse(c, fiber.Map{
"auto_status": pm.AutoStatus,
})
}
return errors.New("自动发布服务已在运行中")
}
func (s *PublishService) PublishOff(c *fiber.Ctx, req *entitys.PublishOffRequest) error {
_, err := s.publishBiz.ValidateAccessToken(c.UserContext(), req.AccessToken)
if err != nil {
return err
}
pm := manager.GetPublishManager(s.cfg, s.db)
if pm.Stop() {
return pkg.HandleResponse(c, fiber.Map{})
}
return errors.New("自动发布服务未运行")
}
func (s *PublishService) PublishStatus(c *fiber.Ctx, req *entitys.PublishStatusRequest) error {
_, err := s.publishBiz.ValidateAccessToken(c.UserContext(), req.AccessToken)
if err != nil {
return err
}
pm := manager.GetPublishManager(s.cfg, s.db)
if req.RequestID != "" {
// 查询单个任务
task, err := s.publishBiz.GetTaskByRequestID(c.UserContext(), req.RequestID)
if err != nil {
return errcode.NotFound("任务不存在")
}
return pkg.HandleResponse(c, task)
}
// 查询整体状态
return pkg.HandleResponse(c, pm.GetStatus())
}
func (s *PublishService) PublishExecuteOnce(c *fiber.Ctx, req *entitys.PublishExecuteOnceRequest) error {
tokenInfo, err := s.publishBiz.ValidateAccessToken(c.UserContext(), req.AccessToken)
if err != nil {
return err
}
pm := manager.GetPublishManager(s.cfg, s.db)
result := pm.ExecuteOnce(tokenInfo.ID)
return pkg.HandleResponse(c, result)
}
func (s *PublishService) PublishExecuteRetry(c *fiber.Ctx, req *entitys.PublishExecuteRetryRequest) error {
_, err := s.publishBiz.ValidateAccessToken(c.UserContext(), req.AccessToken)
if err != nil {
return err
}
pm := manager.GetPublishManager(s.cfg, s.db)
result := pm.RetryTask(req.RequestID)
return pkg.HandleResponse(c, result)
}
func (s *PublishService) GetPublishList(c *fiber.Ctx, req *entitys.GetPublishListRequest) error {
tokenInfo, err := s.publishBiz.ValidateAccessToken(c.UserContext(), req.AccessToken)
if err != nil {
return err
}
page := req.Page
if page < 1 {
page = 1
}
pageSize := req.PageSize
if pageSize < 1 {
pageSize = 20
}
if pageSize > 100 {
pageSize = 100
}
filters := map[string]interface{}{
"user_index": req.UserIndex,
"tag": req.Tag,
"type": req.Type,
"plat_index": req.PlatIndex,
"status": req.Status,
"request_id": req.RequestID,
}
list, total, err := s.publishBiz.GetPublishList(c.UserContext(), tokenInfo.ID, page, pageSize, filters)
if err != nil {
return err
}
return pkg.HandleResponse(c, fiber.Map{
"list": list,
"pagination": fiber.Map{
"page": page,
"page_size": pageSize,
"total": total,
"total_pages": (total + int64(pageSize) - 1) / int64(pageSize),
},
})
}
func (s *PublishService) GetLogs(c *fiber.Ctx, requestID string) error {
logFile := filepath.Join(s.cfg.Sys.LogsDir, requestID+".log")
content, err := os.ReadFile(logFile)
if err != nil {
return errcode.NotFound("日志文件不存在")
}
return pkg.HandleResponse(c, fiber.Map{
"request_id": requestID,
"log_content": string(content),
})
}

45
pkg/doc.go Normal file
View File

@ -0,0 +1,45 @@
package pkg
import (
"strings"
md "github.com/JohannesKaufmann/html-to-markdown"
strip "github.com/grokify/html-strip-tags-go"
)
func ExtractWordContent(filePath string, format string) (string, error) {
// 简化版:读取文件内容并转换格式
// 完整实现需要使用专门的docx库
content := "从Word文档提取的内容"
switch format {
case "html":
return "<p>" + content + "</p>", nil
case "markdown":
converter := md.NewConverter("", true, nil)
return converter.ConvertString("<p>" + content + "</p>")
default:
return strip.StripTags(content), nil
}
}
func ParseTags(tagStr string) []string {
if tagStr == "" {
return []string{}
}
tags := strings.Split(tagStr, ",")
result := make([]string, 0)
for _, t := range tags {
trimmed := strings.TrimSpace(t)
if trimmed != "" {
result = append(result, trimmed)
}
}
return result
}
func IsTimeExceeded(targetTime string) bool {
// 实现时间比较
return false
}

247
pkg/func.go Normal file
View File

@ -0,0 +1,247 @@
package pkg
import (
"encoding/json"
"fmt"
"io"
"math/rand/v2"
"net/http"
"os"
"path/filepath"
"reflect"
"time"
"github.com/go-viper/mapstructure/v2"
"github.com/google/uuid"
)
func GetModuleDir() (string, error) {
dir, err := os.Getwd()
if err != nil {
return "", err
}
for {
modPath := filepath.Join(dir, "go.mod")
if _, err := os.Stat(modPath); err == nil {
return dir, nil // 找到 go.mod
}
// 向上查找父目录
parent := filepath.Dir(dir)
if parent == dir {
break // 到达根目录,未找到
}
dir = parent
}
return "", fmt.Errorf("go.mod not found in current directory or parents")
}
// GetCacheDir 用于获取缓存目录路径
// 如果缓存目录不存在,则会自动创建
// 返回值:
// - string: 缓存目录的路径
// - error: 如果获取模块目录失败或创建缓存目录失败,则返回错误信息
func GetCacheDir() (string, error) {
// 获取模块目录
modDir, err := GetModuleDir()
if err != nil {
return "", err
}
// 拼接缓存目录路径
path := fmt.Sprintf("%s/cache", modDir)
// 创建目录包括所有必要的父目录权限设置为0755
err = os.MkdirAll(path, 0755)
if err != nil {
return "", fmt.Errorf("创建目录失败: %w", err)
}
// 返回成功创建的缓存目录路径
return path, nil
}
func GetTmplDir() (string, error) {
modDir, err := GetModuleDir()
if err != nil {
return "", err
}
path := fmt.Sprintf("%s/tmpl", modDir)
err = os.MkdirAll(path, 0755)
if err != nil {
return "", fmt.Errorf("创建目录失败: %w", err)
}
return path, nil
}
func ReverseSliceNew[T any](s []T) []T {
result := make([]T, len(s))
for i := 0; i < len(s); i++ {
result[i] = s[len(s)-1-i]
}
return result
}
func JsonStringIgonErr(data interface{}) string {
return string(JsonByteIgonErr(data))
}
func JsonByteIgonErr(data interface{}) []byte {
dataByte, _ := json.Marshal(data)
return dataByte
}
func IntersectionGeneric[T comparable](slice1, slice2 []T) []T {
m := make(map[T]bool)
result := []T{}
for _, v := range slice1 {
m[v] = true
}
for _, v := range slice2 {
if m[v] {
result = append(result, v)
delete(m, v) // 避免重复
}
}
return result
}
func CreateOrderNum(prefix string) string {
code := fmt.Sprintf("%04d", rand.IntN(10000))
fmt.Println("4位随机数字:", code) // 输出示例: "0837"
return prefix + time.Now().Format("20060102150405") + code
}
func BuildUpdateMap(obj interface{}, omitFields ...string) map[string]interface{} {
result := make(map[string]interface{})
omitMap := make(map[string]bool)
for _, f := range omitFields {
omitMap[f] = true
}
v := reflect.ValueOf(obj)
if v.Kind() == reflect.Ptr {
v = v.Elem()
}
t := v.Type()
for i := 0; i < v.NumField(); i++ {
field := v.Field(i)
fieldName := t.Field(i).Name
if omitMap[fieldName] {
continue
}
// 只处理非 nil 的指针字段
if field.Kind() == reflect.Ptr && !field.IsNil() {
// 将驼峰转为下划线(可选的,根据你的数据库列名决定)
colName := CamelToSnake(fieldName)
result[colName] = field.Elem().Interface()
}
}
return result
}
func CamelToSnake(s string) string {
var result []rune
for i, r := range s {
if i > 0 && r >= 'A' && r <= 'Z' {
result = append(result, '_')
}
result = append(result, r)
}
return string(result)
}
func CopyNonNilFields(src, dst interface{}) error {
config := &mapstructure.DecoderConfig{
Result: dst,
TagName: "json",
ZeroFields: false, // 重要:不清零目标字段
Squash: false,
}
decoder, err := mapstructure.NewDecoder(config)
if err != nil {
return err
}
return decoder.Decode(src)
}
func DownloadFile(url string, saveDir string, filename string) (string, error) {
os.MkdirAll(saveDir, 0755)
if filename == "" {
filename = uuid.New().String() + ".docx"
}
filePath := filepath.Join(saveDir, filename)
resp, err := http.Get(url)
if err != nil {
return "", err
}
defer resp.Body.Close()
out, err := os.Create(filePath)
if err != nil {
return "", err
}
defer out.Close()
_, err = io.Copy(out, resp.Body)
if err != nil {
return "", err
}
absPath, _ := filepath.Abs(filePath)
return absPath, nil
}
func DownloadImage(url string, requestID string, dir string) (string, error) {
os.MkdirAll(dir, 0755)
ext := filepath.Ext(url)
if ext == "" {
ext = ".jpg"
}
filename := requestID + "_" + uuid.New().String() + ext
filePath := filepath.Join(dir, filename)
resp, err := http.Get(url)
if err != nil {
return "", err
}
defer resp.Body.Close()
out, err := os.Create(filePath)
if err != nil {
return "", err
}
defer out.Close()
_, err = io.Copy(out, resp.Body)
if err != nil {
return "", err
}
return filepath.Abs(filePath)
}
func DeleteFile(path string) {
if path != "" {
os.Remove(path)
}
}
func GenerateUUID() string {
return uuid.New().String()
}
func GenerateUserIndex() string {
return uuid.New().String()[:20]
}

View File

@ -0,0 +1,279 @@
package mapstructure
import (
"encoding"
"errors"
"fmt"
"net"
"reflect"
"strconv"
"strings"
"time"
)
// typedDecodeHook takes a raw DecodeHookFunc (an interface{}) and turns
// it into the proper DecodeHookFunc type, such as DecodeHookFuncType.
func typedDecodeHook(h DecodeHookFunc) DecodeHookFunc {
// Create variables here so we can reference them with the reflect pkg
var f1 DecodeHookFuncType
var f2 DecodeHookFuncKind
var f3 DecodeHookFuncValue
// Fill in the variables into this interface and the rest is done
// automatically using the reflect package.
potential := []interface{}{f1, f2, f3}
v := reflect.ValueOf(h)
vt := v.Type()
for _, raw := range potential {
pt := reflect.ValueOf(raw).Type()
if vt.ConvertibleTo(pt) {
return v.Convert(pt).Interface()
}
}
return nil
}
// DecodeHookExec executes the given decode hook. This should be used
// since it'll naturally degrade to the older backwards compatible DecodeHookFunc
// that took reflect.Kind instead of reflect.Type.
func DecodeHookExec(
raw DecodeHookFunc,
from reflect.Value, to reflect.Value) (interface{}, error) {
switch f := typedDecodeHook(raw).(type) {
case DecodeHookFuncType:
return f(from.Type(), to.Type(), from.Interface())
case DecodeHookFuncKind:
return f(from.Kind(), to.Kind(), from.Interface())
case DecodeHookFuncValue:
return f(from, to)
default:
return nil, errors.New("invalid decode hook signature")
}
}
// ComposeDecodeHookFunc creates a single DecodeHookFunc that
// automatically composes multiple DecodeHookFuncs.
//
// The composed funcs are called in order, with the result of the
// previous transformation.
func ComposeDecodeHookFunc(fs ...DecodeHookFunc) DecodeHookFunc {
return func(f reflect.Value, t reflect.Value) (interface{}, error) {
var err error
data := f.Interface()
newFrom := f
for _, f1 := range fs {
data, err = DecodeHookExec(f1, newFrom, t)
if err != nil {
return nil, err
}
newFrom = reflect.ValueOf(data)
}
return data, nil
}
}
// OrComposeDecodeHookFunc executes all input hook functions until one of them returns no error. In that case its value is returned.
// If all hooks return an error, OrComposeDecodeHookFunc returns an error concatenating all error messages.
func OrComposeDecodeHookFunc(ff ...DecodeHookFunc) DecodeHookFunc {
return func(a, b reflect.Value) (interface{}, error) {
var allErrs string
var out interface{}
var err error
for _, f := range ff {
out, err = DecodeHookExec(f, a, b)
if err != nil {
allErrs += err.Error() + "\n"
continue
}
return out, nil
}
return nil, errors.New(allErrs)
}
}
// StringToSliceHookFunc returns a DecodeHookFunc that converts
// string to []string by splitting on the given sep.
func StringToSliceHookFunc(sep string) DecodeHookFunc {
return func(
f reflect.Kind,
t reflect.Kind,
data interface{}) (interface{}, error) {
if f != reflect.String || t != reflect.Slice {
return data, nil
}
raw := data.(string)
if raw == "" {
return []string{}, nil
}
return strings.Split(raw, sep), nil
}
}
// StringToTimeDurationHookFunc returns a DecodeHookFunc that converts
// strings to time.Duration.
func StringToTimeDurationHookFunc() DecodeHookFunc {
return func(
f reflect.Type,
t reflect.Type,
data interface{}) (interface{}, error) {
if f.Kind() != reflect.String {
return data, nil
}
if t != reflect.TypeOf(time.Duration(5)) {
return data, nil
}
// Convert it by parsing
return time.ParseDuration(data.(string))
}
}
// StringToIPHookFunc returns a DecodeHookFunc that converts
// strings to net.IP
func StringToIPHookFunc() DecodeHookFunc {
return func(
f reflect.Type,
t reflect.Type,
data interface{}) (interface{}, error) {
if f.Kind() != reflect.String {
return data, nil
}
if t != reflect.TypeOf(net.IP{}) {
return data, nil
}
// Convert it by parsing
ip := net.ParseIP(data.(string))
if ip == nil {
return net.IP{}, fmt.Errorf("failed parsing ip %v", data)
}
return ip, nil
}
}
// StringToIPNetHookFunc returns a DecodeHookFunc that converts
// strings to net.IPNet
func StringToIPNetHookFunc() DecodeHookFunc {
return func(
f reflect.Type,
t reflect.Type,
data interface{}) (interface{}, error) {
if f.Kind() != reflect.String {
return data, nil
}
if t != reflect.TypeOf(net.IPNet{}) {
return data, nil
}
// Convert it by parsing
_, net, err := net.ParseCIDR(data.(string))
return net, err
}
}
// StringToTimeHookFunc returns a DecodeHookFunc that converts
// strings to time.Time.
func StringToTimeHookFunc(layout string) DecodeHookFunc {
return func(
f reflect.Type,
t reflect.Type,
data interface{}) (interface{}, error) {
if f.Kind() != reflect.String {
return data, nil
}
if t != reflect.TypeOf(time.Time{}) {
return data, nil
}
// Convert it by parsing
return time.Parse(layout, data.(string))
}
}
// WeaklyTypedHook is a DecodeHookFunc which adds support for weak typing to
// the decoder.
//
// Note that this is significantly different from the WeaklyTypedInput option
// of the DecoderConfig.
func WeaklyTypedHook(
f reflect.Kind,
t reflect.Kind,
data interface{}) (interface{}, error) {
dataVal := reflect.ValueOf(data)
switch t {
case reflect.String:
switch f {
case reflect.Bool:
if dataVal.Bool() {
return "1", nil
}
return "0", nil
case reflect.Float32:
return strconv.FormatFloat(dataVal.Float(), 'f', -1, 64), nil
case reflect.Int:
return strconv.FormatInt(dataVal.Int(), 10), nil
case reflect.Slice:
dataType := dataVal.Type()
elemKind := dataType.Elem().Kind()
if elemKind == reflect.Uint8 {
return string(dataVal.Interface().([]uint8)), nil
}
case reflect.Uint:
return strconv.FormatUint(dataVal.Uint(), 10), nil
}
}
return data, nil
}
func RecursiveStructToMapHookFunc() DecodeHookFunc {
return func(f reflect.Value, t reflect.Value) (interface{}, error) {
if f.Kind() != reflect.Struct {
return f.Interface(), nil
}
var i interface{} = struct{}{}
if t.Type() != reflect.TypeOf(&i).Elem() {
return f.Interface(), nil
}
m := make(map[string]interface{})
t.Set(reflect.ValueOf(m))
return f.Interface(), nil
}
}
// TextUnmarshallerHookFunc returns a DecodeHookFunc that applies
// strings to the UnmarshalText function, when the target type
// implements the encoding.TextUnmarshaler interface
func TextUnmarshallerHookFunc() DecodeHookFuncType {
return func(
f reflect.Type,
t reflect.Type,
data interface{}) (interface{}, error) {
if f.Kind() != reflect.String {
return data, nil
}
result := reflect.New(t).Interface()
unmarshaller, ok := result.(encoding.TextUnmarshaler)
if !ok {
return data, nil
}
if err := unmarshaller.UnmarshalText([]byte(data.(string))); err != nil {
return nil, err
}
return result, nil
}
}

View File

@ -0,0 +1,567 @@
package mapstructure
import (
"errors"
"math/big"
"net"
"reflect"
"testing"
"time"
)
func TestComposeDecodeHookFunc(t *testing.T) {
f1 := func(
f reflect.Kind,
t reflect.Kind,
data interface{}) (interface{}, error) {
return data.(string) + "foo", nil
}
f2 := func(
f reflect.Kind,
t reflect.Kind,
data interface{}) (interface{}, error) {
return data.(string) + "bar", nil
}
f := ComposeDecodeHookFunc(f1, f2)
result, err := DecodeHookExec(
f, reflect.ValueOf(""), reflect.ValueOf([]byte("")))
if err != nil {
t.Fatalf("bad: %s", err)
}
if result.(string) != "foobar" {
t.Fatalf("bad: %#v", result)
}
}
func TestComposeDecodeHookFunc_err(t *testing.T) {
f1 := func(reflect.Kind, reflect.Kind, interface{}) (interface{}, error) {
return nil, errors.New("foo")
}
f2 := func(reflect.Kind, reflect.Kind, interface{}) (interface{}, error) {
panic("NOPE")
}
f := ComposeDecodeHookFunc(f1, f2)
_, err := DecodeHookExec(
f, reflect.ValueOf(""), reflect.ValueOf([]byte("")))
if err.Error() != "foo" {
t.Fatalf("bad: %s", err)
}
}
func TestComposeDecodeHookFunc_kinds(t *testing.T) {
var f2From reflect.Kind
f1 := func(
f reflect.Kind,
t reflect.Kind,
data interface{}) (interface{}, error) {
return int(42), nil
}
f2 := func(
f reflect.Kind,
t reflect.Kind,
data interface{}) (interface{}, error) {
f2From = f
return data, nil
}
f := ComposeDecodeHookFunc(f1, f2)
_, err := DecodeHookExec(
f, reflect.ValueOf(""), reflect.ValueOf([]byte("")))
if err != nil {
t.Fatalf("bad: %s", err)
}
if f2From != reflect.Int {
t.Fatalf("bad: %#v", f2From)
}
}
func TestOrComposeDecodeHookFunc(t *testing.T) {
f1 := func(
f reflect.Kind,
t reflect.Kind,
data interface{}) (interface{}, error) {
return data.(string) + "foo", nil
}
f2 := func(
f reflect.Kind,
t reflect.Kind,
data interface{}) (interface{}, error) {
return data.(string) + "bar", nil
}
f := OrComposeDecodeHookFunc(f1, f2)
result, err := DecodeHookExec(
f, reflect.ValueOf(""), reflect.ValueOf([]byte("")))
if err != nil {
t.Fatalf("bad: %s", err)
}
if result.(string) != "foo" {
t.Fatalf("bad: %#v", result)
}
}
func TestOrComposeDecodeHookFunc_correctValueIsLast(t *testing.T) {
f1 := func(
f reflect.Kind,
t reflect.Kind,
data interface{}) (interface{}, error) {
return nil, errors.New("f1 error")
}
f2 := func(
f reflect.Kind,
t reflect.Kind,
data interface{}) (interface{}, error) {
return nil, errors.New("f2 error")
}
f3 := func(
f reflect.Kind,
t reflect.Kind,
data interface{}) (interface{}, error) {
return data.(string) + "bar", nil
}
f := OrComposeDecodeHookFunc(f1, f2, f3)
result, err := DecodeHookExec(
f, reflect.ValueOf(""), reflect.ValueOf([]byte("")))
if err != nil {
t.Fatalf("bad: %s", err)
}
if result.(string) != "bar" {
t.Fatalf("bad: %#v", result)
}
}
func TestOrComposeDecodeHookFunc_err(t *testing.T) {
f1 := func(
f reflect.Kind,
t reflect.Kind,
data interface{}) (interface{}, error) {
return nil, errors.New("f1 error")
}
f2 := func(
f reflect.Kind,
t reflect.Kind,
data interface{}) (interface{}, error) {
return nil, errors.New("f2 error")
}
f := OrComposeDecodeHookFunc(f1, f2)
_, err := DecodeHookExec(
f, reflect.ValueOf(""), reflect.ValueOf([]byte("")))
if err == nil {
t.Fatalf("bad: should return an error")
}
if err.Error() != "f1 error\nf2 error\n" {
t.Fatalf("bad: %s", err)
}
}
func TestComposeDecodeHookFunc_safe_nofuncs(t *testing.T) {
f := ComposeDecodeHookFunc()
type myStruct2 struct {
MyInt int
}
type myStruct1 struct {
Blah map[string]myStruct2
}
src := &myStruct1{Blah: map[string]myStruct2{
"test": {
MyInt: 1,
},
}}
dst := &myStruct1{}
dConf := &DecoderConfig{
Result: dst,
ErrorUnused: true,
DecodeHook: f,
}
d, err := NewDecoder(dConf)
if err != nil {
t.Fatal(err)
}
err = d.Decode(src)
if err != nil {
t.Fatal(err)
}
}
func TestStringToSliceHookFunc(t *testing.T) {
f := StringToSliceHookFunc(",")
strValue := reflect.ValueOf("42")
sliceValue := reflect.ValueOf([]byte("42"))
cases := []struct {
f, t reflect.Value
result interface{}
err bool
}{
{sliceValue, sliceValue, []byte("42"), false},
{strValue, strValue, "42", false},
{
reflect.ValueOf("foo,bar,baz"),
sliceValue,
[]string{"foo", "bar", "baz"},
false,
},
{
reflect.ValueOf(""),
sliceValue,
[]string{},
false,
},
}
for i, tc := range cases {
actual, err := DecodeHookExec(f, tc.f, tc.t)
if tc.err != (err != nil) {
t.Fatalf("case %d: expected jderr %#v", i, tc.err)
}
if !reflect.DeepEqual(actual, tc.result) {
t.Fatalf(
"case %d: expected %#v, got %#v",
i, tc.result, actual)
}
}
}
func TestStringToTimeDurationHookFunc(t *testing.T) {
f := StringToTimeDurationHookFunc()
timeValue := reflect.ValueOf(time.Duration(5))
strValue := reflect.ValueOf("")
cases := []struct {
f, t reflect.Value
result interface{}
err bool
}{
{reflect.ValueOf("5s"), timeValue, 5 * time.Second, false},
{reflect.ValueOf("5"), timeValue, time.Duration(0), true},
{reflect.ValueOf("5"), strValue, "5", false},
}
for i, tc := range cases {
actual, err := DecodeHookExec(f, tc.f, tc.t)
if tc.err != (err != nil) {
t.Fatalf("case %d: expected jderr %#v", i, tc.err)
}
if !reflect.DeepEqual(actual, tc.result) {
t.Fatalf(
"case %d: expected %#v, got %#v",
i, tc.result, actual)
}
}
}
func TestStringToTimeHookFunc(t *testing.T) {
strValue := reflect.ValueOf("5")
timeValue := reflect.ValueOf(time.Time{})
cases := []struct {
f, t reflect.Value
layout string
result interface{}
err bool
}{
{reflect.ValueOf("2006-01-02T15:04:05Z"), timeValue, time.RFC3339,
time.Date(2006, 1, 2, 15, 4, 5, 0, time.UTC), false},
{strValue, timeValue, time.RFC3339, time.Time{}, true},
{strValue, strValue, time.RFC3339, "5", false},
}
for i, tc := range cases {
f := StringToTimeHookFunc(tc.layout)
actual, err := DecodeHookExec(f, tc.f, tc.t)
if tc.err != (err != nil) {
t.Fatalf("case %d: expected jderr %#v", i, tc.err)
}
if !reflect.DeepEqual(actual, tc.result) {
t.Fatalf(
"case %d: expected %#v, got %#v",
i, tc.result, actual)
}
}
}
func TestStringToIPHookFunc(t *testing.T) {
strValue := reflect.ValueOf("5")
ipValue := reflect.ValueOf(net.IP{})
cases := []struct {
f, t reflect.Value
result interface{}
err bool
}{
{reflect.ValueOf("1.2.3.4"), ipValue,
net.IPv4(0x01, 0x02, 0x03, 0x04), false},
{strValue, ipValue, net.IP{}, true},
{strValue, strValue, "5", false},
}
for i, tc := range cases {
f := StringToIPHookFunc()
actual, err := DecodeHookExec(f, tc.f, tc.t)
if tc.err != (err != nil) {
t.Fatalf("case %d: expected jderr %#v", i, tc.err)
}
if !reflect.DeepEqual(actual, tc.result) {
t.Fatalf(
"case %d: expected %#v, got %#v",
i, tc.result, actual)
}
}
}
func TestStringToIPNetHookFunc(t *testing.T) {
strValue := reflect.ValueOf("5")
ipNetValue := reflect.ValueOf(net.IPNet{})
var nilNet *net.IPNet = nil
cases := []struct {
f, t reflect.Value
result interface{}
err bool
}{
{reflect.ValueOf("1.2.3.4/24"), ipNetValue,
&net.IPNet{
IP: net.IP{0x01, 0x02, 0x03, 0x00},
Mask: net.IPv4Mask(0xff, 0xff, 0xff, 0x00),
}, false},
{strValue, ipNetValue, nilNet, true},
{strValue, strValue, "5", false},
}
for i, tc := range cases {
f := StringToIPNetHookFunc()
actual, err := DecodeHookExec(f, tc.f, tc.t)
if tc.err != (err != nil) {
t.Fatalf("case %d: expected jderr %#v", i, tc.err)
}
if !reflect.DeepEqual(actual, tc.result) {
t.Fatalf(
"case %d: expected %#v, got %#v",
i, tc.result, actual)
}
}
}
func TestWeaklyTypedHook(t *testing.T) {
var f DecodeHookFunc = WeaklyTypedHook
strValue := reflect.ValueOf("")
cases := []struct {
f, t reflect.Value
result interface{}
err bool
}{
// TO STRING
{
reflect.ValueOf(false),
strValue,
"0",
false,
},
{
reflect.ValueOf(true),
strValue,
"1",
false,
},
{
reflect.ValueOf(float32(7)),
strValue,
"7",
false,
},
{
reflect.ValueOf(int(7)),
strValue,
"7",
false,
},
{
reflect.ValueOf([]uint8("foo")),
strValue,
"foo",
false,
},
{
reflect.ValueOf(uint(7)),
strValue,
"7",
false,
},
}
for i, tc := range cases {
actual, err := DecodeHookExec(f, tc.f, tc.t)
if tc.err != (err != nil) {
t.Fatalf("case %d: expected jderr %#v", i, tc.err)
}
if !reflect.DeepEqual(actual, tc.result) {
t.Fatalf(
"case %d: expected %#v, got %#v",
i, tc.result, actual)
}
}
}
func TestStructToMapHookFuncTabled(t *testing.T) {
var f DecodeHookFunc = RecursiveStructToMapHookFunc()
type b struct {
TestKey string
}
type a struct {
Sub b
}
testStruct := a{
Sub: b{
TestKey: "testval",
},
}
testMap := map[string]interface{}{
"Sub": map[string]interface{}{
"TestKey": "testval",
},
}
cases := []struct {
name string
receiver interface{}
input interface{}
expected interface{}
err bool
}{
{
"map receiver",
func() interface{} {
var res map[string]interface{}
return &res
}(),
testStruct,
&testMap,
false,
},
{
"interface receiver",
func() interface{} {
var res interface{}
return &res
}(),
testStruct,
func() interface{} {
var exp interface{} = testMap
return &exp
}(),
false,
},
{
"slice receiver errors",
func() interface{} {
var res []string
return &res
}(),
testStruct,
new([]string),
true,
},
{
"slice to slice - no change",
func() interface{} {
var res []string
return &res
}(),
[]string{"a", "b"},
&[]string{"a", "b"},
false,
},
{
"string to string - no change",
func() interface{} {
var res string
return &res
}(),
"test",
func() *string {
s := "test"
return &s
}(),
false,
},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
cfg := &DecoderConfig{
DecodeHook: f,
Result: tc.receiver,
}
d, err := NewDecoder(cfg)
if err != nil {
t.Fatalf("unexpected jderr %#v", err)
}
err = d.Decode(tc.input)
if tc.err != (err != nil) {
t.Fatalf("expected jderr %#v", err)
}
if !reflect.DeepEqual(tc.expected, tc.receiver) {
t.Fatalf("expected %#v, got %#v",
tc.expected, tc.receiver)
}
})
}
}
func TestTextUnmarshallerHookFunc(t *testing.T) {
cases := []struct {
f, t reflect.Value
result interface{}
err bool
}{
{reflect.ValueOf("42"), reflect.ValueOf(big.Int{}), big.NewInt(42), false},
{reflect.ValueOf("invalid"), reflect.ValueOf(big.Int{}), nil, true},
{reflect.ValueOf("5"), reflect.ValueOf("5"), "5", false},
}
for i, tc := range cases {
f := TextUnmarshallerHookFunc()
actual, err := DecodeHookExec(f, tc.f, tc.t)
if tc.err != (err != nil) {
t.Fatalf("case %d: expected jderr %#v", i, tc.err)
}
if !reflect.DeepEqual(actual, tc.result) {
t.Fatalf(
"case %d: expected %#v, got %#v",
i, tc.result, actual)
}
}
}

50
pkg/mapstructure/error.go Normal file
View File

@ -0,0 +1,50 @@
package mapstructure
import (
"errors"
"fmt"
"sort"
"strings"
)
// Error implements the error interface and can represents multiple
// errors that occur in the course of a single decode.
type Error struct {
Errors []string
}
func (e *Error) Error() string {
points := make([]string, len(e.Errors))
for i, err := range e.Errors {
points[i] = fmt.Sprintf("* %s", err)
}
sort.Strings(points)
return fmt.Sprintf(
"%d error(s) decoding:\n\n%s",
len(e.Errors), strings.Join(points, "\n"))
}
// WrappedErrors implements the errwrap.Wrapper interface to make this
// return value more useful with the errwrap and go-multierror libraries.
func (e *Error) WrappedErrors() []error {
if e == nil {
return nil
}
result := make([]error, len(e.Errors))
for i, e := range e.Errors {
result[i] = errors.New(e)
}
return result
}
func appendErrors(errors []string, err error) []string {
switch e := err.(type) {
case *Error:
return append(errors, e.Errors...)
default:
return append(errors, e.Error())
}
}

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,285 @@
package mapstructure
import (
"encoding/json"
"testing"
)
type Person struct {
Name string
Age int
Emails []string
Extra map[string]string
}
func Benchmark_Decode(b *testing.B) {
input := map[string]interface{}{
"name": "Mitchell",
"age": 91,
"emails": []string{"one", "two", "three"},
"extra": map[string]string{
"twitter": "mitchellh",
},
}
var result Person
for i := 0; i < b.N; i++ {
Decode(input, &result)
}
}
// decodeViaJSON takes the map data and passes it through encoding/json to convert it into the
// given Go native structure pointed to by v. v must be a pointer to a struct.
func decodeViaJSON(data interface{}, v interface{}) error {
// Perform the task by simply marshalling the input into JSON,
// then unmarshalling it into target native Go struct.
b, err := json.Marshal(data)
if err != nil {
return err
}
return json.Unmarshal(b, v)
}
func Benchmark_DecodeViaJSON(b *testing.B) {
input := map[string]interface{}{
"name": "Mitchell",
"age": 91,
"emails": []string{"one", "two", "three"},
"extra": map[string]string{
"twitter": "mitchellh",
},
}
var result Person
for i := 0; i < b.N; i++ {
decodeViaJSON(input, &result)
}
}
func Benchmark_JSONUnmarshal(b *testing.B) {
input := map[string]interface{}{
"name": "Mitchell",
"age": 91,
"emails": []string{"one", "two", "three"},
"extra": map[string]string{
"twitter": "mitchellh",
},
}
inputB, err := json.Marshal(input)
if err != nil {
b.Fatal("Failed to marshal test input:", err)
}
var result Person
for i := 0; i < b.N; i++ {
json.Unmarshal(inputB, &result)
}
}
func Benchmark_DecodeBasic(b *testing.B) {
input := map[string]interface{}{
"vstring": "foo",
"vint": 42,
"Vuint": 42,
"vbool": true,
"Vfloat": 42.42,
"vsilent": true,
"vdata": 42,
"vjsonInt": json.Number("1234"),
"vjsonFloat": json.Number("1234.5"),
"vjsonNumber": json.Number("1234.5"),
}
for i := 0; i < b.N; i++ {
var result Basic
Decode(input, &result)
}
}
func Benchmark_DecodeEmbedded(b *testing.B) {
input := map[string]interface{}{
"vstring": "foo",
"Basic": map[string]interface{}{
"vstring": "innerfoo",
},
"vunique": "bar",
}
var result Embedded
for i := 0; i < b.N; i++ {
Decode(input, &result)
}
}
func Benchmark_DecodeTypeConversion(b *testing.B) {
input := map[string]interface{}{
"IntToFloat": 42,
"IntToUint": 42,
"IntToBool": 1,
"IntToString": 42,
"UintToInt": 42,
"UintToFloat": 42,
"UintToBool": 42,
"UintToString": 42,
"BoolToInt": true,
"BoolToUint": true,
"BoolToFloat": true,
"BoolToString": true,
"FloatToInt": 42.42,
"FloatToUint": 42.42,
"FloatToBool": 42.42,
"FloatToString": 42.42,
"StringToInt": "42",
"StringToUint": "42",
"StringToBool": "1",
"StringToFloat": "42.42",
"SliceToMap": []interface{}{},
"MapToSlice": map[string]interface{}{},
}
var resultStrict TypeConversionResult
for i := 0; i < b.N; i++ {
Decode(input, &resultStrict)
}
}
func Benchmark_DecodeMap(b *testing.B) {
input := map[string]interface{}{
"vfoo": "foo",
"vother": map[interface{}]interface{}{
"foo": "foo",
"bar": "bar",
},
}
var result Map
for i := 0; i < b.N; i++ {
Decode(input, &result)
}
}
func Benchmark_DecodeMapOfStruct(b *testing.B) {
input := map[string]interface{}{
"value": map[string]interface{}{
"foo": map[string]string{"vstring": "one"},
"bar": map[string]string{"vstring": "two"},
},
}
var result MapOfStruct
for i := 0; i < b.N; i++ {
Decode(input, &result)
}
}
func Benchmark_DecodeSlice(b *testing.B) {
input := map[string]interface{}{
"vfoo": "foo",
"vbar": []string{"foo", "bar", "baz"},
}
var result Slice
for i := 0; i < b.N; i++ {
Decode(input, &result)
}
}
func Benchmark_DecodeSliceOfStruct(b *testing.B) {
input := map[string]interface{}{
"value": []map[string]interface{}{
{"vstring": "one"},
{"vstring": "two"},
},
}
var result SliceOfStruct
for i := 0; i < b.N; i++ {
Decode(input, &result)
}
}
func Benchmark_DecodeWeaklyTypedInput(b *testing.B) {
// This input can come from anywhere, but typically comes from
// something like decoding JSON, generated by a weakly typed language
// such as PHP.
input := map[string]interface{}{
"name": 123, // number => string
"age": "42", // string => number
"emails": map[string]interface{}{}, // empty map => empty array
}
var result Person
config := &DecoderConfig{
WeaklyTypedInput: true,
Result: &result,
}
decoder, err := NewDecoder(config)
if err != nil {
panic(err)
}
for i := 0; i < b.N; i++ {
decoder.Decode(input)
}
}
func Benchmark_DecodeMetadata(b *testing.B) {
input := map[string]interface{}{
"name": "Mitchell",
"age": 91,
"email": "foo@bar.com",
}
var md Metadata
var result Person
config := &DecoderConfig{
Metadata: &md,
Result: &result,
}
decoder, err := NewDecoder(config)
if err != nil {
panic(err)
}
for i := 0; i < b.N; i++ {
decoder.Decode(input)
}
}
func Benchmark_DecodeMetadataEmbedded(b *testing.B) {
input := map[string]interface{}{
"vstring": "foo",
"vunique": "bar",
}
var md Metadata
var result EmbeddedSquash
config := &DecoderConfig{
Metadata: &md,
Result: &result,
}
decoder, err := NewDecoder(config)
if err != nil {
b.Fatalf("jderr: %s", err)
}
for i := 0; i < b.N; i++ {
decoder.Decode(input)
}
}
func Benchmark_DecodeTagged(b *testing.B) {
input := map[string]interface{}{
"foo": "bar",
"bar": "value",
}
var result Tagged
for i := 0; i < b.N; i++ {
Decode(input, &result)
}
}

View File

@ -0,0 +1,627 @@
package mapstructure
import (
"reflect"
"testing"
"time"
)
// GH-1, GH-10, GH-96
func TestDecode_NilValue(t *testing.T) {
t.Parallel()
tests := []struct {
name string
in interface{}
target interface{}
out interface{}
metaKeys []string
metaUnused []string
}{
{
"all nil",
&map[string]interface{}{
"vfoo": nil,
"vother": nil,
},
&Map{Vfoo: "foo", Vother: map[string]string{"foo": "bar"}},
&Map{Vfoo: "", Vother: nil},
[]string{"Vfoo", "Vother"},
[]string{},
},
{
"partial nil",
&map[string]interface{}{
"vfoo": "baz",
"vother": nil,
},
&Map{Vfoo: "foo", Vother: map[string]string{"foo": "bar"}},
&Map{Vfoo: "baz", Vother: nil},
[]string{"Vfoo", "Vother"},
[]string{},
},
{
"partial decode",
&map[string]interface{}{
"vother": nil,
},
&Map{Vfoo: "foo", Vother: map[string]string{"foo": "bar"}},
&Map{Vfoo: "foo", Vother: nil},
[]string{"Vother"},
[]string{},
},
{
"unused values",
&map[string]interface{}{
"vbar": "bar",
"vfoo": nil,
"vother": nil,
},
&Map{Vfoo: "foo", Vother: map[string]string{"foo": "bar"}},
&Map{Vfoo: "", Vother: nil},
[]string{"Vfoo", "Vother"},
[]string{"vbar"},
},
{
"map interface all nil",
&map[interface{}]interface{}{
"vfoo": nil,
"vother": nil,
},
&Map{Vfoo: "foo", Vother: map[string]string{"foo": "bar"}},
&Map{Vfoo: "", Vother: nil},
[]string{"Vfoo", "Vother"},
[]string{},
},
{
"map interface partial nil",
&map[interface{}]interface{}{
"vfoo": "baz",
"vother": nil,
},
&Map{Vfoo: "foo", Vother: map[string]string{"foo": "bar"}},
&Map{Vfoo: "baz", Vother: nil},
[]string{"Vfoo", "Vother"},
[]string{},
},
{
"map interface partial decode",
&map[interface{}]interface{}{
"vother": nil,
},
&Map{Vfoo: "foo", Vother: map[string]string{"foo": "bar"}},
&Map{Vfoo: "foo", Vother: nil},
[]string{"Vother"},
[]string{},
},
{
"map interface unused values",
&map[interface{}]interface{}{
"vbar": "bar",
"vfoo": nil,
"vother": nil,
},
&Map{Vfoo: "foo", Vother: map[string]string{"foo": "bar"}},
&Map{Vfoo: "", Vother: nil},
[]string{"Vfoo", "Vother"},
[]string{"vbar"},
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
config := &DecoderConfig{
Metadata: new(Metadata),
Result: tc.target,
ZeroFields: true,
}
decoder, err := NewDecoder(config)
if err != nil {
t.Fatalf("should not error: %s", err)
}
err = decoder.Decode(tc.in)
if err != nil {
t.Fatalf("should not error: %s", err)
}
if !reflect.DeepEqual(tc.out, tc.target) {
t.Fatalf("%q: TestDecode_NilValue() expected: %#v, got: %#v", tc.name, tc.out, tc.target)
}
if !reflect.DeepEqual(tc.metaKeys, config.Metadata.Keys) {
t.Fatalf("%q: Metadata.Keys mismatch expected: %#v, got: %#v", tc.name, tc.metaKeys, config.Metadata.Keys)
}
if !reflect.DeepEqual(tc.metaUnused, config.Metadata.Unused) {
t.Fatalf("%q: Metadata.Unused mismatch expected: %#v, got: %#v", tc.name, tc.metaUnused, config.Metadata.Unused)
}
})
}
}
// #48
func TestNestedTypePointerWithDefaults(t *testing.T) {
t.Parallel()
input := map[string]interface{}{
"vfoo": "foo",
"vbar": map[string]interface{}{
"vstring": "foo",
"vint": 42,
"vbool": true,
},
}
result := NestedPointer{
Vbar: &Basic{
Vuint: 42,
},
}
err := Decode(input, &result)
if err != nil {
t.Fatalf("got an jderr: %s", err.Error())
}
if result.Vfoo != "foo" {
t.Errorf("vfoo value should be 'foo': %#v", result.Vfoo)
}
if result.Vbar.Vstring != "foo" {
t.Errorf("vstring value should be 'foo': %#v", result.Vbar.Vstring)
}
if result.Vbar.Vint != 42 {
t.Errorf("vint value should be 42: %#v", result.Vbar.Vint)
}
if result.Vbar.Vbool != true {
t.Errorf("vbool value should be true: %#v", result.Vbar.Vbool)
}
if result.Vbar.Vextra != "" {
t.Errorf("vextra value should be empty: %#v", result.Vbar.Vextra)
}
// this is the error
if result.Vbar.Vuint != 42 {
t.Errorf("vuint value should be 42: %#v", result.Vbar.Vuint)
}
}
type NestedSlice struct {
Vfoo string
Vbars []Basic
Vempty []Basic
}
// #48
func TestNestedTypeSliceWithDefaults(t *testing.T) {
t.Parallel()
input := map[string]interface{}{
"vfoo": "foo",
"vbars": []map[string]interface{}{
{"vstring": "foo", "vint": 42, "vbool": true},
{"vint": 42, "vbool": true},
},
"vempty": []map[string]interface{}{
{"vstring": "foo", "vint": 42, "vbool": true},
{"vint": 42, "vbool": true},
},
}
result := NestedSlice{
Vbars: []Basic{
{Vuint: 42},
{Vstring: "foo"},
},
}
err := Decode(input, &result)
if err != nil {
t.Fatalf("got an jderr: %s", err.Error())
}
if result.Vfoo != "foo" {
t.Errorf("vfoo value should be 'foo': %#v", result.Vfoo)
}
if result.Vbars[0].Vstring != "foo" {
t.Errorf("vstring value should be 'foo': %#v", result.Vbars[0].Vstring)
}
// this is the error
if result.Vbars[0].Vuint != 42 {
t.Errorf("vuint value should be 42: %#v", result.Vbars[0].Vuint)
}
}
// #48 workaround
func TestNestedTypeWithDefaults(t *testing.T) {
t.Parallel()
input := map[string]interface{}{
"vfoo": "foo",
"vbar": map[string]interface{}{
"vstring": "foo",
"vint": 42,
"vbool": true,
},
}
result := Nested{
Vbar: Basic{
Vuint: 42,
},
}
err := Decode(input, &result)
if err != nil {
t.Fatalf("got an jderr: %s", err.Error())
}
if result.Vfoo != "foo" {
t.Errorf("vfoo value should be 'foo': %#v", result.Vfoo)
}
if result.Vbar.Vstring != "foo" {
t.Errorf("vstring value should be 'foo': %#v", result.Vbar.Vstring)
}
if result.Vbar.Vint != 42 {
t.Errorf("vint value should be 42: %#v", result.Vbar.Vint)
}
if result.Vbar.Vbool != true {
t.Errorf("vbool value should be true: %#v", result.Vbar.Vbool)
}
if result.Vbar.Vextra != "" {
t.Errorf("vextra value should be empty: %#v", result.Vbar.Vextra)
}
// this is the error
if result.Vbar.Vuint != 42 {
t.Errorf("vuint value should be 42: %#v", result.Vbar.Vuint)
}
}
// #67 panic() on extending slices (decodeSlice with disabled ZeroValues)
func TestDecodeSliceToEmptySliceWOZeroing(t *testing.T) {
t.Parallel()
type TestStruct struct {
Vfoo []string
}
decode := func(m interface{}, rawVal interface{}) error {
config := &DecoderConfig{
Metadata: nil,
Result: rawVal,
ZeroFields: false,
}
decoder, err := NewDecoder(config)
if err != nil {
return err
}
return decoder.Decode(m)
}
{
input := map[string]interface{}{
"vfoo": []string{"1"},
}
result := &TestStruct{}
err := decode(input, &result)
if err != nil {
t.Fatalf("got an jderr: %s", err.Error())
}
}
{
input := map[string]interface{}{
"vfoo": []string{"1"},
}
result := &TestStruct{
Vfoo: []string{},
}
err := decode(input, &result)
if err != nil {
t.Fatalf("got an jderr: %s", err.Error())
}
}
{
input := map[string]interface{}{
"vfoo": []string{"2", "3"},
}
result := &TestStruct{
Vfoo: []string{"1"},
}
err := decode(input, &result)
if err != nil {
t.Fatalf("got an jderr: %s", err.Error())
}
}
}
// #70
func TestNextSquashMapstructure(t *testing.T) {
data := &struct {
Level1 struct {
Level2 struct {
Foo string
} `mapstructure:",squash"`
} `mapstructure:",squash"`
}{}
err := Decode(map[interface{}]interface{}{"foo": "baz"}, &data)
if err != nil {
t.Fatalf("should not error: %s", err)
}
if data.Level1.Level2.Foo != "baz" {
t.Fatal("value should be baz")
}
}
type ImplementsInterfacePointerReceiver struct {
Name string
}
func (i *ImplementsInterfacePointerReceiver) DoStuff() {}
type ImplementsInterfaceValueReceiver string
func (i ImplementsInterfaceValueReceiver) DoStuff() {}
// GH-140 Type error when using DecodeHook to decode into interface
func TestDecode_DecodeHookInterface(t *testing.T) {
t.Parallel()
type Interface interface {
DoStuff()
}
type DecodeIntoInterface struct {
Test Interface
}
testData := map[string]string{"test": "test"}
stringToPointerInterfaceDecodeHook := func(from, to reflect.Type, data interface{}) (interface{}, error) {
if from.Kind() != reflect.String {
return data, nil
}
if to != reflect.TypeOf((*Interface)(nil)).Elem() {
return data, nil
}
// Ensure interface is satisfied
var impl Interface = &ImplementsInterfacePointerReceiver{data.(string)}
return impl, nil
}
stringToValueInterfaceDecodeHook := func(from, to reflect.Type, data interface{}) (interface{}, error) {
if from.Kind() != reflect.String {
return data, nil
}
if to != reflect.TypeOf((*Interface)(nil)).Elem() {
return data, nil
}
// Ensure interface is satisfied
var impl Interface = ImplementsInterfaceValueReceiver(data.(string))
return impl, nil
}
{
decodeInto := new(DecodeIntoInterface)
decoder, _ := NewDecoder(&DecoderConfig{
DecodeHook: stringToPointerInterfaceDecodeHook,
Result: decodeInto,
})
err := decoder.Decode(testData)
if err != nil {
t.Fatalf("Decode returned error: %s", err)
}
expected := &ImplementsInterfacePointerReceiver{"test"}
if !reflect.DeepEqual(decodeInto.Test, expected) {
t.Fatalf("expected: %#v (%T), got: %#v (%T)", decodeInto.Test, decodeInto.Test, expected, expected)
}
}
{
decodeInto := new(DecodeIntoInterface)
decoder, _ := NewDecoder(&DecoderConfig{
DecodeHook: stringToValueInterfaceDecodeHook,
Result: decodeInto,
})
err := decoder.Decode(testData)
if err != nil {
t.Fatalf("Decode returned error: %s", err)
}
expected := ImplementsInterfaceValueReceiver("test")
if !reflect.DeepEqual(decodeInto.Test, expected) {
t.Fatalf("expected: %#v (%T), got: %#v (%T)", decodeInto.Test, decodeInto.Test, expected, expected)
}
}
}
// #103 Check for data type before trying to access its composants prevent a panic error
// in decodeSlice
func TestDecodeBadDataTypeInSlice(t *testing.T) {
t.Parallel()
input := map[string]interface{}{
"Toto": "titi",
}
result := []struct {
Toto string
}{}
if err := Decode(input, &result); err == nil {
t.Error("An error was expected, got nil")
}
}
// #202 Ensure that intermediate maps in the struct -> struct decode process are settable
// and not just the elements within them.
func TestDecodeIntermediateMapsSettable(t *testing.T) {
type Timestamp struct {
Seconds int64
Nanos int32
}
type TsWrapper struct {
Timestamp *Timestamp
}
type TimeWrapper struct {
Timestamp time.Time
}
input := TimeWrapper{
Timestamp: time.Unix(123456789, 987654),
}
expected := TsWrapper{
Timestamp: &Timestamp{
Seconds: 123456789,
Nanos: 987654,
},
}
timePtrType := reflect.TypeOf((*time.Time)(nil))
mapStrInfType := reflect.TypeOf((map[string]interface{})(nil))
var actual TsWrapper
decoder, err := NewDecoder(&DecoderConfig{
Result: &actual,
DecodeHook: func(from, to reflect.Type, data interface{}) (interface{}, error) {
if from == timePtrType && to == mapStrInfType {
ts := data.(*time.Time)
nanos := ts.UnixNano()
seconds := nanos / 1000000000
nanos = nanos % 1000000000
return &map[string]interface{}{
"Seconds": seconds,
"Nanos": int32(nanos),
}, nil
}
return data, nil
},
})
if err != nil {
t.Fatalf("failed to create decoder: %v", err)
}
if err := decoder.Decode(&input); err != nil {
t.Fatalf("failed to decode input: %v", err)
}
if !reflect.DeepEqual(expected, actual) {
t.Fatalf("expected: %#[1]v (%[1]T), got: %#[2]v (%[2]T)", expected, actual)
}
}
// GH-206: decodeInt throws an error for an empty string
func TestDecode_weakEmptyStringToInt(t *testing.T) {
input := map[string]interface{}{
"StringToInt": "",
"StringToUint": "",
"StringToBool": "",
"StringToFloat": "",
}
expectedResultWeak := TypeConversionResult{
StringToInt: 0,
StringToUint: 0,
StringToBool: false,
StringToFloat: 0,
}
// Test weak type conversion
var resultWeak TypeConversionResult
err := WeakDecode(input, &resultWeak)
if err != nil {
t.Fatalf("got an jderr: %s", err)
}
if !reflect.DeepEqual(resultWeak, expectedResultWeak) {
t.Errorf("expected \n%#v, got: \n%#v", expectedResultWeak, resultWeak)
}
}
// GH-228: Squash cause *time.Time set to zero
func TestMapSquash(t *testing.T) {
type AA struct {
T *time.Time
}
type A struct {
AA
}
v := time.Now()
in := &AA{
T: &v,
}
out := &A{}
d, err := NewDecoder(&DecoderConfig{
Squash: true,
Result: out,
})
if err != nil {
t.Fatalf("jderr: %s", err)
}
if err := d.Decode(in); err != nil {
t.Fatalf("jderr: %s", err)
}
// these failed
if !v.Equal(*out.T) {
t.Fatal("expected equal")
}
if out.T.IsZero() {
t.Fatal("expected false")
}
}
// GH-238: Empty key name when decoding map from struct with only omitempty flag
func TestMapOmitEmptyWithEmptyFieldnameInTag(t *testing.T) {
type Struct struct {
Username string `mapstructure:",omitempty"`
Age int `mapstructure:",omitempty"`
}
s := Struct{
Username: "Joe",
}
var m map[string]interface{}
if err := Decode(s, &m); err != nil {
t.Fatal(err)
}
if len(m) != 1 {
t.Fatalf("fail: %#v", m)
}
if m["Username"] != "Joe" {
t.Fatalf("fail: %#v", m)
}
}

View File

@ -0,0 +1,256 @@
package mapstructure
import (
"fmt"
)
func ExampleDecode() {
type Person struct {
Name string
Age int
Emails []string
Extra map[string]string
}
// This input can come from anywhere, but typically comes from
// something like decoding JSON where we're not quite sure of the
// struct initially.
input := map[string]interface{}{
"name": "Mitchell",
"age": 91,
"emails": []string{"one", "two", "three"},
"extra": map[string]string{
"twitter": "mitchellh",
},
}
var result Person
err := Decode(input, &result)
if err != nil {
panic(err)
}
fmt.Printf("%#v", result)
// Output:
// mapstructure.Person{Name:"Mitchell", Age:91, Emails:[]string{"one", "two", "three"}, Extra:map[string]string{"twitter":"mitchellh"}}
}
func ExampleDecode_errors() {
type Person struct {
Name string
Age int
Emails []string
Extra map[string]string
}
// This input can come from anywhere, but typically comes from
// something like decoding JSON where we're not quite sure of the
// struct initially.
input := map[string]interface{}{
"name": 123,
"age": "bad value",
"emails": []int{1, 2, 3},
}
var result Person
err := Decode(input, &result)
if err == nil {
panic("should have an error")
}
fmt.Println(err.Error())
// Output:
// 5 error(s) decoding:
//
// * 'Age' expected type 'int', got unconvertible type 'string', value: 'bad value'
// * 'Emails[0]' expected type 'string', got unconvertible type 'int', value: '1'
// * 'Emails[1]' expected type 'string', got unconvertible type 'int', value: '2'
// * 'Emails[2]' expected type 'string', got unconvertible type 'int', value: '3'
// * 'Name' expected type 'string', got unconvertible type 'int', value: '123'
}
func ExampleDecode_metadata() {
type Person struct {
Name string
Age int
}
// This input can come from anywhere, but typically comes from
// something like decoding JSON where we're not quite sure of the
// struct initially.
input := map[string]interface{}{
"name": "Mitchell",
"age": 91,
"email": "foo@bar.com",
}
// For metadata, we make a more advanced DecoderConfig so we can
// more finely configure the decoder that is used. In this case, we
// just tell the decoder we want to track metadata.
var md Metadata
var result Person
config := &DecoderConfig{
Metadata: &md,
Result: &result,
}
decoder, err := NewDecoder(config)
if err != nil {
panic(err)
}
if err := decoder.Decode(input); err != nil {
panic(err)
}
fmt.Printf("Unused keys: %#v", md.Unused)
// Output:
// Unused keys: []string{"email"}
}
func ExampleDecode_weaklyTypedInput() {
type Person struct {
Name string
Age int
Emails []string
}
// This input can come from anywhere, but typically comes from
// something like decoding JSON, generated by a weakly typed language
// such as PHP.
input := map[string]interface{}{
"name": 123, // number => string
"age": "42", // string => number
"emails": map[string]interface{}{}, // empty map => empty array
}
var result Person
config := &DecoderConfig{
WeaklyTypedInput: true,
Result: &result,
}
decoder, err := NewDecoder(config)
if err != nil {
panic(err)
}
err = decoder.Decode(input)
if err != nil {
panic(err)
}
fmt.Printf("%#v", result)
// Output: mapstructure.Person{Name:"123", Age:42, Emails:[]string{}}
}
func ExampleDecode_tags() {
// Note that the mapstructure tags defined in the struct type
// can indicate which fields the values are mapped to.
type Person struct {
Name string `mapstructure:"person_name"`
Age int `mapstructure:"person_age"`
}
input := map[string]interface{}{
"person_name": "Mitchell",
"person_age": 91,
}
var result Person
err := Decode(input, &result)
if err != nil {
panic(err)
}
fmt.Printf("%#v", result)
// Output:
// mapstructure.Person{Name:"Mitchell", Age:91}
}
func ExampleDecode_embeddedStruct() {
// Squashing multiple embedded structs is allowed using the squash tag.
// This is demonstrated by creating a composite struct of multiple types
// and decoding into it. In this case, a person can carry with it both
// a Family and a Location, as well as their own FirstName.
type Family struct {
LastName string
}
type Location struct {
City string
}
type Person struct {
Family `mapstructure:",squash"`
Location `mapstructure:",squash"`
FirstName string
}
input := map[string]interface{}{
"FirstName": "Mitchell",
"LastName": "Hashimoto",
"City": "San Francisco",
}
var result Person
err := Decode(input, &result)
if err != nil {
panic(err)
}
fmt.Printf("%s %s, %s", result.FirstName, result.LastName, result.City)
// Output:
// Mitchell Hashimoto, San Francisco
}
func ExampleDecode_remainingData() {
// Note that the mapstructure tags defined in the struct type
// can indicate which fields the values are mapped to.
type Person struct {
Name string
Age int
Other map[string]interface{} `mapstructure:",remain"`
}
input := map[string]interface{}{
"name": "Mitchell",
"age": 91,
"email": "mitchell@example.com",
}
var result Person
err := Decode(input, &result)
if err != nil {
panic(err)
}
fmt.Printf("%#v", result)
// Output:
// mapstructure.Person{Name:"Mitchell", Age:91, Other:map[string]interface {}{"email":"mitchell@example.com"}}
}
func ExampleDecode_omitempty() {
// Add omitempty annotation to avoid map keys for empty values
type Family struct {
LastName string
}
type Location struct {
City string
}
type Person struct {
*Family `mapstructure:",omitempty"`
*Location `mapstructure:",omitempty"`
Age int
FirstName string
}
result := &map[string]interface{}{}
input := Person{FirstName: "Somebody"}
err := Decode(input, &result)
if err != nil {
panic(err)
}
fmt.Printf("%+v", result)
// Output:
// &map[Age:0 FirstName:Somebody]
}

View File

@ -0,0 +1,58 @@
package mapstructure
import (
"reflect"
"testing"
)
func TestDecode_Ptr(t *testing.T) {
t.Parallel()
type G struct {
Id int
Name string
}
type X struct {
Id int
Name int
}
type AG struct {
List []*G
}
type AX struct {
List []*X
}
g2 := &AG{
List: []*G{
{
Id: 11,
Name: "gg",
},
},
}
x2 := AX{}
// 报错但还是会转换成功,转换后值为目标类型的 0 值
err := Decode(g2, &x2)
res := AX{
List: []*X{
{
Id: 11,
Name: 0, // 这个类型的 0 值
},
},
}
if err == nil {
t.Errorf("Decode_Ptr jderr should not be 'nil': %#v", err)
}
if !reflect.DeepEqual(res, x2) {
t.Errorf("result should be %#v: got %#v", res, x2)
}
}

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,26 @@
package mapstructure
import "time"
// DecodeWithTime 支持时间转字符串
// 支持
// 1. *Time.time 转 string/*string
// 2. *Time.time 转 uint/uint32/uint64/int/int32/int64支持带指针
// 不能用 Time.time 转它会在上层认为是一个结构体数据而直接转成map再到hook方法
func DecodeWithTime(input, output interface{}, layout string) error {
if layout == "" {
layout = time.DateTime
}
config := &DecoderConfig{
Metadata: nil,
Result: output,
DecodeHook: ComposeDecodeHookFunc(TimeToStringHook(layout), TimeToUnixIntHook()),
}
decoder, err := NewDecoder(config)
if err != nil {
return err
}
return decoder.Decode(input)
}

View File

@ -0,0 +1,101 @@
package mapstructure
import (
"reflect"
"time"
)
// TimeToStringHook 时间转字符串
// 支持 *Time.time 转 string/*string
// 不能用 Time.time 转它会在上层认为是一个结构体数据而直接转成map再到hook方法
func TimeToStringHook(layout string) DecodeHookFunc {
return func(
f reflect.Type,
t reflect.Type,
data interface{}) (interface{}, error) {
// 判断目标类型是否为字符串
var strType string
var isStrPointer *bool // 要转换的目标类型是否为指针字符串
if t == reflect.TypeOf(strType) {
isStrPointer = new(bool)
} else if t == reflect.TypeOf(&strType) {
isStrPointer = new(bool)
*isStrPointer = true
}
if isStrPointer == nil {
return data, nil
}
// 判断类型是否为时间
timeType := time.Time{}
if f != reflect.TypeOf(timeType) && f != reflect.TypeOf(&timeType) {
return data, nil
}
// 将时间转换为字符串
var output string
switch v := data.(type) {
case *time.Time:
output = v.Format(layout)
case time.Time:
output = v.Format(layout)
default:
return data, nil
}
if *isStrPointer {
return &output, nil
}
return output, nil
}
}
// TimeToUnixIntHook 时间转时间戳
// 支持 *Time.time 转 uint/uint32/uint64/int/int32/int64支持带指针
// 不能用 Time.time 转它会在上层认为是一个结构体数据而直接转成map再到hook方法
func TimeToUnixIntHook() DecodeHookFunc {
return func(
f reflect.Type,
t reflect.Type,
data interface{}) (interface{}, error) {
tkd := t.Kind()
if tkd != reflect.Int && tkd != reflect.Int32 && tkd != reflect.Int64 &&
tkd != reflect.Uint && tkd != reflect.Uint32 && tkd != reflect.Uint64 {
return data, nil
}
// 判断类型是否为时间
timeType := time.Time{}
if f != reflect.TypeOf(timeType) && f != reflect.TypeOf(&timeType) {
return data, nil
}
// 将时间转换为字符串
var output int64
switch v := data.(type) {
case *time.Time:
output = v.Unix()
case time.Time:
output = v.Unix()
default:
return data, nil
}
switch tkd {
case reflect.Int:
return int(output), nil
case reflect.Int32:
return int32(output), nil
case reflect.Int64:
return output, nil
case reflect.Uint:
return uint(output), nil
case reflect.Uint32:
return uint32(output), nil
case reflect.Uint64:
return uint64(output), nil
default:
return data, nil
}
}
}

View File

@ -0,0 +1,274 @@
package mapstructure
import (
"testing"
"time"
)
func Test_TimeToStringHook(t *testing.T) {
type Input struct {
Time time.Time
Id int
}
type InputTPointer struct {
Time *time.Time
Id int
}
type Output struct {
Time string
Id int
}
type OutputTPointer struct {
Time *string
Id int
}
now := time.Now()
target := now.Format("2006-01-02 15:04:05")
idValue := 1
tests := []struct {
input any
output any
name string
layout string
}{
{
name: "测试Time.time转string",
layout: "2006-01-02 15:04:05",
input: InputTPointer{
Time: &now,
Id: idValue,
},
output: Output{},
},
{
name: "测试*Time.time转*string",
layout: "2006-01-02 15:04:05",
input: InputTPointer{
Time: &now,
Id: idValue,
},
output: OutputTPointer{},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
decoder, err := NewDecoder(&DecoderConfig{
DecodeHook: TimeToStringHook(tt.layout),
Result: &tt.output,
})
if err != nil {
t.Errorf("NewDecoder() jderr = %vwant nil", err)
}
if i, isOk := tt.input.(Input); isOk {
err = decoder.Decode(i)
}
if i, isOk := tt.input.(InputTPointer); isOk {
err = decoder.Decode(&i)
}
if err != nil {
t.Errorf("Decode jderr = %vwant nil", err)
}
//验证测试值
if output, isOk := tt.output.(OutputTPointer); isOk {
if *output.Time != target {
t.Errorf("Decode output time = %vwant %v", *output.Time, target)
}
if output.Id != idValue {
t.Errorf("Decode output id = %vwant %v", output.Id, idValue)
}
}
if output, isOk := tt.output.(Output); isOk {
if output.Time != target {
t.Errorf("Decode output time = %vwant %v", output.Time, target)
}
if output.Id != idValue {
t.Errorf("Decode output id = %vwant %v", output.Id, idValue)
}
}
})
}
}
func Test_TimeToUnixIntHook(t *testing.T) {
type InputTPointer struct {
Time *time.Time
Id int
}
type Output[T int | *int | int32 | *int32 | int64 | *int64 | uint | *uint] struct {
Time T
Id int
}
type test struct {
input any
output any
name string
layout string
}
now := time.Now()
target := now.Unix()
idValue := 1
tests := []test{
{
name: "测试Time.time转int",
layout: "2006-01-02 15:04:05",
input: InputTPointer{
Time: &now,
Id: idValue,
},
output: Output[int]{},
},
{
name: "测试Time.time转*int",
layout: "2006-01-02 15:04:05",
input: InputTPointer{
Time: &now,
Id: idValue,
},
output: Output[*int]{},
},
{
name: "测试Time.time转int32",
layout: "2006-01-02 15:04:05",
input: InputTPointer{
Time: &now,
Id: idValue,
},
output: Output[int32]{},
},
{
name: "测试Time.time转*int32",
layout: "2006-01-02 15:04:05",
input: InputTPointer{
Time: &now,
Id: idValue,
},
output: Output[*int32]{},
},
{
name: "测试Time.time转int64",
layout: "2006-01-02 15:04:05",
input: InputTPointer{
Time: &now,
Id: idValue,
},
output: Output[int64]{},
},
{
name: "测试Time.time转*int64",
layout: "2006-01-02 15:04:05",
input: InputTPointer{
Time: &now,
Id: idValue,
},
output: Output[*int64]{},
},
{
name: "测试Time.time转uint",
layout: "2006-01-02 15:04:05",
input: InputTPointer{
Time: &now,
Id: idValue,
},
output: Output[uint]{},
},
{
name: "测试Time.time转*uint",
layout: "2006-01-02 15:04:05",
input: InputTPointer{
Time: &now,
Id: idValue,
},
output: Output[*uint]{},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
decoder, err := NewDecoder(&DecoderConfig{
DecodeHook: TimeToUnixIntHook(),
Result: &tt.output,
})
if err != nil {
t.Errorf("NewDecoder() jderr = %vwant nil", err)
}
if i, isOk := tt.input.(InputTPointer); isOk {
err = decoder.Decode(i)
}
if i, isOk := tt.input.(InputTPointer); isOk {
err = decoder.Decode(&i)
}
if err != nil {
t.Errorf("Decode jderr = %vwant nil", err)
}
//验证测试值
switch v := tt.output.(type) {
case Output[int]:
if int64(v.Time) != target {
t.Errorf("Decode output time = %vwant %v", v.Time, target)
}
if v.Id != idValue {
t.Errorf("Decode output id = %vwant %v", v.Id, idValue)
}
case Output[*int]:
if int64(*v.Time) != target {
t.Errorf("Decode output time = %vwant %v", v.Time, target)
}
if v.Id != idValue {
t.Errorf("Decode output id = %vwant %v", v.Id, idValue)
}
case Output[int32]:
if int64(v.Time) != target {
t.Errorf("Decode output time = %vwant %v", v.Time, target)
}
if v.Id != idValue {
t.Errorf("Decode output id = %vwant %v", v.Id, idValue)
}
case Output[*int32]:
if int64(*v.Time) != target {
t.Errorf("Decode output time = %vwant %v", v.Time, target)
}
if v.Id != idValue {
t.Errorf("Decode output id = %vwant %v", v.Id, idValue)
}
case Output[int64]:
if int64(v.Time) != target {
t.Errorf("Decode output time = %vwant %v", v.Time, target)
}
if v.Id != idValue {
t.Errorf("Decode output id = %vwant %v", v.Id, idValue)
}
case Output[*int64]:
if int64(*v.Time) != target {
t.Errorf("Decode output time = %vwant %v", v.Time, target)
}
if v.Id != idValue {
t.Errorf("Decode output id = %vwant %v", v.Id, idValue)
}
case Output[uint]:
if int64(v.Time) != target {
t.Errorf("Decode output time = %vwant %v", v.Time, target)
}
if v.Id != idValue {
t.Errorf("Decode output id = %vwant %v", v.Id, idValue)
}
case Output[*uint]:
if int64(*v.Time) != target {
t.Errorf("Decode output time = %vwant %v", v.Time, target)
}
if v.Id != idValue {
t.Errorf("Decode output id = %vwant %v", v.Id, idValue)
}
}
})
}
}

38
pkg/response.go Normal file
View File

@ -0,0 +1,38 @@
package pkg
import (
"encoding/json"
"fmt"
"os"
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/log"
)
func HandleResponse(c *fiber.Ctx, data interface{}) (err error) {
if os.Getenv("env") == "unit_test" {
log.Debug(data)
}
switch data.(type) {
case error:
err = data.(error)
case int, int32, int64, float32, float64, string, bool:
c.Response().SetBody([]byte(fmt.Sprintf("%s", data)))
case []byte:
c.Response().SetBody(data.([]byte))
default:
dataByte, _ := json.Marshal(data)
c.Response().SetBody(dataByte)
}
return
}
func SuccessWithPageMsg(c *fiber.Ctx, list interface{}, total int64, page, pageSize int) error {
response := fiber.Map{
"list": list,
"total": total,
"page": page,
"pageSize": pageSize,
}
return HandleResponse(c, response)
}

743
pkg/validata.go Normal file
View File

@ -0,0 +1,743 @@
package pkg
import (
"fmt"
"geo/tmpl/errcode"
"reflect"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/go-playground/validator/v10"
"github.com/gofiber/fiber/v2"
)
const (
// 验证标签常量
TagRequired = "required"
TagEmail = "email"
TagMin = "min"
TagMax = "max"
TagLen = "len"
TagNumeric = "numeric"
TagOneof = "oneof"
TagGte = "gte"
TagGt = "gt"
TagLte = "lte"
TagLt = "lt"
TagURL = "url"
TagUUID = "uuid"
TagDatetime = "datetime"
// 中文标签常量
//LabelComment = "comment"
//LabelLabel = "label"
LabelZh = "zh"
// 上下文Key
CtxRequestBody = "request_body"
// 性能优化常量
InitialBuilderSize = 64
MaxCacheSize = 5000
CacheCleanupInterval = 10 * time.Minute
)
// ValidatorConfig 验证器配置
type ValidatorConfig struct {
StatusCode int // 验证失败时的HTTP状态码默认422
ErrorHandler func(c *fiber.Ctx, err error) error // 自定义错误处理
BeforeParse func(c *fiber.Ctx) error // 解析前执行
AfterParse func(c *fiber.Ctx, req interface{}) error // 解析后执行
DisableCache bool // 是否禁用缓存
MaxCacheSize int // 最大缓存数量
UseChinese bool // 是否使用中文提示
EnableMetrics bool // 是否启用指标收集
}
// CacheStats 缓存统计
type CacheStats struct {
hitCount int64
missCount int64
evictCount int64
errorCount int64
accessCount int64
}
// Metrics 指标收集
type Metrics struct {
totalRequests int64
validationTime int64
cacheHitRate float64
mu sync.RWMutex
}
// fieldInfo 字段信息
type fieldInfo struct {
label string
index int
typeKind reflect.Kind
accessCount int64
lastAccess time.Time
}
// typeInfo 类型信息
type typeInfo struct {
fields map[string]*fieldInfo
fieldNames []string
mu sync.RWMutex
accessCount int64
createdAt time.Time
lastAccess time.Time
typeKey string
}
// ValidatorHelper 验证器助手
type ValidatorHelper struct {
validate *validator.Validate
config *ValidatorConfig
typeCache sync.Map
cacheStats *CacheStats
errorFnCache map[string]func(string, string) string
errorFnCacheMu sync.RWMutex
pruneLock sync.Mutex
stopCleanup chan struct{}
cleanupOnce sync.Once
metrics *Metrics
}
var (
ChineseErrorTemplates = map[string]string{
TagRequired: "%s不能为空",
TagEmail: "%s格式不正确",
TagMin: "%s不能小于%s",
TagMax: "%s不能大于%s",
TagLen: "%s长度必须为%s位",
TagNumeric: "%s必须是数字",
TagOneof: "%s必须是以下值之一: %s",
TagGte: "%s不能小于%s",
TagGt: "%s必须大于%s",
TagLte: "%s不能大于%s",
TagLt: "%s必须小于%s",
TagURL: "%s必须是有效的URL地址",
TagUUID: "%s必须是有效的UUID",
TagDatetime: "%s日期格式不正确",
}
defaultErrorTemplates = map[string]string{
TagRequired: "%s is required",
TagEmail: "invalid email format",
TagMin: "%s must be at least %s",
TagMax: "%s must be at most %s",
TagLen: "%s must be exactly %s characters",
TagNumeric: "%s must be numeric",
TagOneof: "%s must be one of: %s",
TagGte: "%s must be greater than or equal to %s",
TagGt: "%s must be greater than %s",
TagLte: "%s must be less than or equal to %s",
TagLt: "%s must be less than %s",
TagURL: "%s must be a valid URL",
TagUUID: "%s must be a valid UUID",
TagDatetime: "%s invalid datetime format",
}
)
var (
// 对象池
builderPool = sync.Pool{
New: func() interface{} {
b := &strings.Builder{}
b.Grow(InitialBuilderSize)
return b
},
}
// 错误信息切片池
errorSlicePool = sync.Pool{
New: func() interface{} {
slice := make([]string, 0, 8)
return &slice
},
}
// 字段信息对象池
fieldInfoPool = sync.Pool{
New: func() interface{} {
return &fieldInfo{
accessCount: 0,
lastAccess: time.Now(),
}
},
}
// 类型信息对象池
typeInfoPool = sync.Pool{
New: func() interface{} {
return &typeInfo{
fields: make(map[string]*fieldInfo),
fieldNames: make([]string, 0, 8),
}
},
}
)
var (
Vh *ValidatorHelper
once sync.Once
)
// NewValidatorHelper 初始化验证器助手
func NewValidatorHelper(config ...*ValidatorConfig) {
once.Do(func() {
v := validator.New()
// 优化JSON标签获取
v.RegisterTagNameFunc(func(fld reflect.StructField) string {
return fld.Tag.Get("json")
})
// 默认配置
cfg := &ValidatorConfig{
StatusCode: fiber.StatusUnprocessableEntity,
ErrorHandler: defaultErrorHandler,
MaxCacheSize: MaxCacheSize,
UseChinese: true,
EnableMetrics: false,
}
if len(config) > 0 && config[0] != nil {
if config[0].StatusCode != 0 {
cfg.StatusCode = config[0].StatusCode
}
if config[0].ErrorHandler != nil {
cfg.ErrorHandler = config[0].ErrorHandler
}
if config[0].MaxCacheSize > 0 {
cfg.MaxCacheSize = config[0].MaxCacheSize
}
cfg.DisableCache = config[0].DisableCache
cfg.BeforeParse = config[0].BeforeParse
cfg.AfterParse = config[0].AfterParse
cfg.UseChinese = config[0].UseChinese
cfg.EnableMetrics = config[0].EnableMetrics
}
// 预编译错误函数
errorFnCache := make(map[string]func(string, string) string)
templates := ChineseErrorTemplates
if !cfg.UseChinese {
templates = defaultErrorTemplates
}
for tag, tmpl := range templates {
t := tmpl // 捕获变量
errorFnCache[tag] = func(field, param string) string {
if strings.Contains(t, "%s") && strings.Count(t, "%s") == 2 {
return fmt.Sprintf(t, field, param)
}
return fmt.Sprintf(t, field)
}
}
Vh = &ValidatorHelper{
validate: v,
config: cfg,
cacheStats: &CacheStats{},
errorFnCache: errorFnCache,
errorFnCacheMu: sync.RWMutex{},
stopCleanup: make(chan struct{}),
metrics: &Metrics{},
}
// 启动定期清理
if !cfg.DisableCache {
go Vh.periodicCleanup()
}
})
}
// ParseAndValidate 解析并验证请求
func ParseAndValidate(c *fiber.Ctx, req interface{}) error {
if Vh == nil {
NewValidatorHelper()
}
if Vh.config.EnableMetrics {
atomic.AddInt64(&Vh.metrics.totalRequests, 1)
defer Vh.recordValidationTime(time.Now())
}
// 执行解析前钩子
if Vh.config.BeforeParse != nil {
if err := Vh.config.BeforeParse(c); err != nil {
return err
}
}
// 解析请求体
err := c.BodyParser(req)
if err != nil {
return errcode.ParamErr("请求格式错误:" + err.Error())
}
// 执行解析后钩子
if Vh.config.AfterParse != nil {
if err = Vh.config.AfterParse(c, req); err != nil {
return errcode.ParamErr(err.Error())
}
}
// 验证数据
err = Vh.validate.Struct(req)
if err != nil {
c.Locals(CtxRequestBody, req)
if !Vh.config.DisableCache {
t := reflect.TypeOf(req)
if t.Kind() == reflect.Ptr {
t = t.Elem()
}
if t.Kind() == reflect.Struct {
Vh.safeGetOrCreateTypeInfo(t)
}
}
return Vh.config.ErrorHandler(c, err)
}
return nil
}
// Validate 直接验证结构体
func Validate(req interface{}) error {
if Vh == nil {
NewValidatorHelper()
}
if err := Vh.validate.Struct(req); err != nil {
return Vh.wrapValidationError(err, req)
}
return nil
}
// 默认错误处理
func defaultErrorHandler(c *fiber.Ctx, err error) error {
validationErrors, ok := err.(validator.ValidationErrors)
if !ok {
return errcode.SystemError
}
if len(validationErrors) == 0 {
return nil
}
// 快速路径:单个错误
if len(validationErrors) == 1 {
e := validationErrors[0]
msg := Vh.safeGetErrorMessage(c, e)
return errcode.ParamErr(msg)
}
// 从对象池获取builder
builder := builderPool.Get().(*strings.Builder)
builder.Reset()
defer builderPool.Put(builder)
req := c.Locals(CtxRequestBody)
for i, e := range validationErrors {
if i > 0 {
builder.WriteByte('\n')
}
builder.WriteString(Vh.safeGetErrorMessageWithReq(req, e))
}
return errcode.ParamErr(builder.String())
}
// 包装验证错误
func (vh *ValidatorHelper) wrapValidationError(err error, req interface{}) error {
validationErrors, ok := err.(validator.ValidationErrors)
if !ok {
return err
}
if len(validationErrors) == 0 {
return nil
}
// 构建错误消息
builder := builderPool.Get().(*strings.Builder)
builder.Reset()
defer builderPool.Put(builder)
for i, e := range validationErrors {
if i > 0 {
builder.WriteByte('\n')
}
builder.WriteString(vh.safeGetErrorMessageWithReq(req, e))
}
return errcode.ParamErr(builder.String())
}
// 安全获取错误消息
func (vh *ValidatorHelper) safeGetErrorMessage(c *fiber.Ctx, e validator.FieldError) string {
req := c.Locals(CtxRequestBody)
return vh.safeGetErrorMessageWithReq(req, e)
}
// 安全获取错误消息(带请求体)
func (vh *ValidatorHelper) safeGetErrorMessageWithReq(req interface{}, e validator.FieldError) string {
if req == nil {
return vh.safeFormatFieldError(e, nil)
}
t := reflect.TypeOf(req)
if t.Kind() == reflect.Ptr {
t = t.Elem()
}
if t.Kind() != reflect.Struct {
return vh.safeFormatFieldError(e, nil)
}
// 安全获取类型信息 - 这里需要先声明变量
var typeInfoObj *typeInfo // 这里声明变量
if !vh.config.DisableCache {
if cached, ok := vh.typeCache.Load(t); ok {
typeInfoObj = cached.(*typeInfo)
atomic.AddInt64(&typeInfoObj.accessCount, 1)
}
}
return vh.safeFormatFieldError(e, typeInfoObj)
}
// 安全格式化字段错误
func (vh *ValidatorHelper) safeFormatFieldError(e validator.FieldError, typeInfo *typeInfo) string {
structField := e.StructField()
fieldName := e.Field()
// 获取字段标签
var label string
if typeInfo != nil {
typeInfo.mu.RLock()
if info, ok := typeInfo.fields[structField]; ok {
label = info.label
atomic.AddInt64(&info.accessCount, 1)
}
typeInfo.mu.RUnlock()
}
// 如果没有标签,返回默认消息
if label == "" {
return vh.safeGetDefaultErrorMessage(fieldName, e)
}
// 使用预编译的错误函数生成消息
vh.errorFnCacheMu.RLock()
fn, ok := vh.errorFnCache[e.Tag()]
vh.errorFnCacheMu.RUnlock()
if ok {
return fn(label, e.Param())
}
return label + "格式不正确"
}
// 安全获取默认错误消息
func (vh *ValidatorHelper) safeGetDefaultErrorMessage(field string, e validator.FieldError) string {
vh.errorFnCacheMu.RLock()
defer vh.errorFnCacheMu.RUnlock()
if fn, ok := vh.errorFnCache[e.Tag()]; ok {
return fn(field, e.Param())
}
return field + "验证失败"
}
// safeGetOrCreateTypeInfo 安全地获取或创建类型信息
func (vh *ValidatorHelper) safeGetOrCreateTypeInfo(t reflect.Type) *typeInfo {
if vh.config.DisableCache || t == nil || t.Kind() != reflect.Struct {
return nil
}
// 首先尝试从缓存读取
if cached, ok := vh.typeCache.Load(t); ok {
info := cached.(*typeInfo)
atomic.AddInt64(&info.accessCount, 1)
atomic.AddInt64(&vh.cacheStats.hitCount, 1)
info.lastAccess = time.Now()
return info
}
atomic.AddInt64(&vh.cacheStats.missCount, 1)
// 从对象池获取typeInfo
info := typeInfoPool.Get().(*typeInfo)
// 重置并初始化
info.mu.Lock()
// 清空现有map
for k := range info.fields {
delete(info.fields, k)
}
info.fieldNames = info.fieldNames[:0]
info.accessCount = 0
info.createdAt = time.Now()
info.lastAccess = time.Now()
info.typeKey = t.String()
// 预计算所有字段信息
for i := 0; i < t.NumField(); i++ {
field := t.Field(i)
// 获取标签
//label := field.Tag.Get(LabelComment)
//if label == "" {
// label = field.Tag.Get(LabelLabel)
//}
//if label == "" {
// label = field.Tag.Get(LabelZh)
//}
label := field.Tag.Get(LabelZh)
// 从对象池获取或创建字段信息
fieldInfo := fieldInfoPool.Get().(*fieldInfo)
fieldInfo.label = label
fieldInfo.index = i
fieldInfo.typeKind = field.Type.Kind()
fieldInfo.accessCount = 0
fieldInfo.lastAccess = time.Now()
info.fields[field.Name] = fieldInfo
info.fieldNames = append(info.fieldNames, field.Name)
}
info.mu.Unlock()
// 使用原子操作确保线程安全的存储
if existing, loaded := vh.typeCache.LoadOrStore(t, info); loaded {
// 如果已经有其他goroutine存储了使用已有的并回收新创建的
info.mu.Lock()
for _, fieldInfo := range info.fields {
fieldInfoPool.Put(fieldInfo)
}
info.mu.Unlock()
typeInfoPool.Put(info)
existingInfo := existing.(*typeInfo)
atomic.AddInt64(&existingInfo.accessCount, 1)
return existingInfo
}
return info
}
// 定期清理缓存
func (vh *ValidatorHelper) periodicCleanup() {
ticker := time.NewTicker(CacheCleanupInterval)
defer ticker.Stop()
for {
select {
case <-ticker.C:
vh.safeCleanupCache()
case <-vh.stopCleanup:
return
}
}
}
// safeCleanupCache 安全清理缓存
func (vh *ValidatorHelper) safeCleanupCache() {
vh.pruneLock.Lock()
defer vh.pruneLock.Unlock()
var keysToDelete []interface{}
now := time.Now()
vh.typeCache.Range(func(key, value interface{}) bool {
info, ok := value.(*typeInfo)
if !ok {
return true
}
// 检查是否需要清理
accessCount := atomic.LoadInt64(&info.accessCount)
age := now.Sub(info.createdAt)
idleTime := now.Sub(info.lastAccess)
// 清理条件:
// 1. 很少访问的缓存(访问次数 < 10
// 2. 空闲时间超过30分钟
// 3. 缓存年龄超过1小时且访问次数较少
if (accessCount < 10 && idleTime > 30*time.Minute) ||
(age > 1*time.Hour && accessCount < 100) {
keysToDelete = append(keysToDelete, key)
atomic.AddInt64(&vh.cacheStats.evictCount, 1)
}
return true
})
// 删除选中的缓存
for _, key := range keysToDelete {
if val, ok := vh.typeCache.Load(key); ok {
if info, ok := val.(*typeInfo); ok {
// 安全回收字段信息
info.mu.Lock()
for _, fieldInfo := range info.fields {
fieldInfoPool.Put(fieldInfo)
}
info.mu.Unlock()
typeInfoPool.Put(info)
}
vh.typeCache.Delete(key)
}
}
}
// ==================== 指标收集 ====================
func (vh *ValidatorHelper) recordValidationTime(start time.Time) {
if !vh.config.EnableMetrics {
return
}
duration := time.Since(start).Nanoseconds()
atomic.AddInt64(&vh.metrics.validationTime, duration)
}
// GetMetrics 获取性能指标
func (vh *ValidatorHelper) GetMetrics() map[string]interface{} {
if !vh.config.EnableMetrics {
return nil
}
vh.metrics.mu.RLock()
defer vh.metrics.mu.RUnlock()
hitCount := atomic.LoadInt64(&vh.cacheStats.hitCount)
missCount := atomic.LoadInt64(&vh.cacheStats.missCount)
totalRequests := hitCount + missCount
var hitRate float64
if totalRequests > 0 {
hitRate = float64(hitCount) / float64(totalRequests) * 100
}
var cacheSize int64
vh.typeCache.Range(func(_, _ interface{}) bool {
cacheSize++
return true
})
return map[string]interface{}{
"cache_hit_count": hitCount,
"cache_miss_count": missCount,
"cache_evict_count": atomic.LoadInt64(&vh.cacheStats.evictCount),
"cache_hit_rate": fmt.Sprintf("%.2f%%", hitRate),
"cache_size": cacheSize,
"error_count": atomic.LoadInt64(&vh.cacheStats.errorCount),
"total_requests": atomic.LoadInt64(&vh.metrics.totalRequests),
"avg_validation_time_ns": atomic.LoadInt64(&vh.metrics.validationTime) /
max(atomic.LoadInt64(&vh.metrics.totalRequests), 1),
}
}
// RegisterValidation 注册自定义验证规则
func (vh *ValidatorHelper) RegisterValidation(tag string, fn validator.Func, callValidationEvenIfNull ...bool) error {
return vh.validate.RegisterValidation(tag, fn, callValidationEvenIfNull...)
}
// RegisterTranslation 注册自定义翻译
func (vh *ValidatorHelper) RegisterTranslation(tag string, template string) {
vh.errorFnCacheMu.Lock()
defer vh.errorFnCacheMu.Unlock()
t := template
vh.errorFnCache[tag] = func(field, param string) string {
if strings.Contains(t, "%s") && strings.Count(t, "%s") == 2 {
return fmt.Sprintf(t, field, param)
}
return fmt.Sprintf(t, field)
}
}
// Stop 停止后台清理任务
func (vh *ValidatorHelper) Stop() {
close(vh.stopCleanup)
}
// Reset 重置验证器状态
func (vh *ValidatorHelper) Reset() {
vh.ClearCache()
atomic.StoreInt64(&vh.cacheStats.hitCount, 0)
atomic.StoreInt64(&vh.cacheStats.missCount, 0)
atomic.StoreInt64(&vh.cacheStats.evictCount, 0)
atomic.StoreInt64(&vh.cacheStats.errorCount, 0)
atomic.StoreInt64(&vh.metrics.totalRequests, 0)
atomic.StoreInt64(&vh.metrics.validationTime, 0)
}
// ClearCache 清理所有缓存
func (vh *ValidatorHelper) ClearCache() {
vh.pruneLock.Lock()
defer vh.pruneLock.Unlock()
// 安全回收所有缓存
vh.typeCache.Range(func(key, value interface{}) bool {
if info, ok := value.(*typeInfo); ok {
info.mu.Lock()
for _, fieldInfo := range info.fields {
fieldInfoPool.Put(fieldInfo)
}
info.mu.Unlock()
typeInfoPool.Put(info)
}
vh.typeCache.Delete(key)
return true
})
}
// SetLanguage 设置语言
func (vh *ValidatorHelper) SetLanguage(useChinese bool) {
vh.config.UseChinese = useChinese
templates := defaultErrorTemplates
if useChinese {
templates = ChineseErrorTemplates
}
vh.errorFnCacheMu.Lock()
defer vh.errorFnCacheMu.Unlock()
for tag, tmpl := range templates {
t := tmpl
vh.errorFnCache[tag] = func(field, param string) string {
if strings.Contains(t, "%s") && strings.Count(t, "%s") == 2 {
return fmt.Sprintf(t, field, param)
}
return fmt.Sprintf(t, field)
}
}
}
func max(a, b int64) int64 {
if a > b {
return a
}
return b
}
func GetErr(tag, field, param string) string {
t := ChineseErrorTemplates[tag]
if strings.Contains(t, "%s") && strings.Count(t, "%s") == 2 {
return fmt.Sprintf(t, field, param)
}
return fmt.Sprintf(t, field)
}

195
pkg/wx.go Normal file
View File

@ -0,0 +1,195 @@
package pkg
import (
"bytes"
"context"
"crypto/tls"
"encoding/json"
"fmt"
"io"
"os"
"sync"
"time"
"net/http"
"errors"
"github.com/redis/go-redis/v9"
)
type WeChatLoginResponse struct {
OpenID string `json:"openid"` // 用户唯一标识
SessionKey string `json:"session_key"` // 会话密钥
UnionID string `json:"unionid"` // 用户在开放平台的唯一标识(如果绑定了开放平台才有)
Errcode int `json:"errcode"` // 错误码0为成功
Errmsg string `json:"errmsg"` // 错误信息
}
func GetOpenID(appID, appSecret, code string) (openid string, err error) {
if os.Getenv("env") == "unit_test" {
return "test_123456", nil
}
// 1. 构建请求微信接口的 URL
url := fmt.Sprintf("https://api.weixin.qq.com/sns/jscode2session?appid=%s&secret=%s&js_code=%s&grant_type=authorization_code",
appID, appSecret, code)
// 2. 创建 HTTP 客户端(设置超时,避免阻塞)
client := &http.Client{
Timeout: 5 * time.Second,
// 在某些网络受限的环境(如本地测试跳过证书验证,生产环境建议去掉)
Transport: &http.Transport{
TLSClientConfig: &tls.Config{InsecureSkipVerify: false}, // 生产环境建议设为 false
},
}
// 3. 发起 GET 请求
resp, err := client.Get(url)
if err != nil {
return "", fmt.Errorf("请求微信服务器失败: %w", err)
}
defer resp.Body.Close()
// 4. 读取返回的 Body
body, err := io.ReadAll(resp.Body)
if err != nil {
return "", fmt.Errorf("读取微信响应失败: %w", err)
}
// 5. 解析 JSON 数据
var wechatResp WeChatLoginResponse
err = json.Unmarshal(body, &wechatResp)
if err != nil {
return "", fmt.Errorf("解析微信响应 JSON 失败: %s, 原始数据: %s", err.Error(), string(body))
}
// 6. 检查微信接口返回的错误码
if wechatResp.Errcode != 0 {
// 这里可以根据不同的错误码做特殊处理,例如 code 无效、过期等
return "", fmt.Errorf("微信接口返回错误: code=%d, msg=%s", wechatResp.Errcode, wechatResp.Errmsg)
}
// 7. 检查 OpenID 是否为空(理论上不会,但防御性编程)
if wechatResp.OpenID == "" {
return "", errors.New("微信返回的 OpenID 为空")
}
// 8. 返回 OpenID
return wechatResp.OpenID, nil
}
type AccessTokenResponse struct {
AccessToken string `json:"access_token"`
ExpiresIn int `json:"expires_in"`
Errcode int `json:"errcode"`
Errmsg string `json:"errmsg"`
}
var (
tokenMutex sync.Mutex
cacheKey = "wx:access_token"
)
// GetAccessToken 获取 access_token带本地缓存
func GetAccessToken(ctx context.Context, appID, appSecret string, rdb *redis.Client) (string, error) {
if rdb == nil {
return "", errors.New("缓存工具未提供")
}
cacheToken := rdb.Get(ctx, cacheKey).Val()
if cacheToken != "" {
return cacheToken, nil
}
tokenMutex.Lock()
defer tokenMutex.Unlock()
// 请求微信接口获取新的 access_token
url := fmt.Sprintf("https://api.weixin.qq.com/cgi-bin/token?grant_type=client_credential&appid=%s&secret=%s", appID, appSecret)
client := &http.Client{Timeout: 5 * time.Second}
resp, err := client.Get(url)
if err != nil {
return "", fmt.Errorf("请求 access_token 失败: %w", err)
}
defer resp.Body.Close()
body, _ := io.ReadAll(resp.Body)
var tokenRes AccessTokenResponse
if err := json.Unmarshal(body, &tokenRes); err != nil {
return "", fmt.Errorf("解析 access_token 响应失败: %w", err)
}
if tokenRes.Errcode != 0 {
return "", fmt.Errorf("获取 access_token 失败: code=%d, msg=%s", tokenRes.Errcode, tokenRes.Errmsg)
}
// 缓存 token提前5分钟过期避免边界情况
rdb.Set(ctx, cacheKey, tokenRes.AccessToken, time.Duration(tokenRes.ExpiresIn-300)*time.Second)
return tokenRes.AccessToken, nil
}
// PhoneInfo 定义手机号信息的结构体,与微信官方文档对齐 [citation:3][citation:8]
type PhoneInfo struct {
PhoneNumber string `json:"phoneNumber"` // 用户绑定的手机号(国外手机号会有区号)
PurePhoneNumber string `json:"purePhoneNumber"` // 没有区号的手机号
CountryCode string `json:"countryCode"` // 区号
Watermark struct {
Timestamp int64 `json:"timestamp"`
Appid string `json:"appid"`
} `json:"watermark"`
}
// PhoneInfoResponse 定义微信接口返回的完整结构
type PhoneInfoResponse struct {
Errcode int `json:"errcode"`
Errmsg string `json:"errmsg"`
PhoneInfo PhoneInfo `json:"phone_info"`
}
// GetPhoneNumber 通过手机号 code 获取用户手机号
// 参数:
// - appID: 小程序的 AppID
// - appSecret: 小程序的 AppSecret
// - phoneCode: 前端通过 getPhoneNumber 获取的 code
//
// 返回:
// - *PhoneInfo: 手机号信息
// - error: 错误信息
func GetPhoneNumber(ctx context.Context, appID, appSecret, phoneCode string, rdb *redis.Client) (*PhoneInfo, error) {
// 1. 获取 access_token
accessToken, err := GetAccessToken(ctx, appID, appSecret, rdb)
if err != nil {
return nil, fmt.Errorf("获取 access_token 失败: %w", err)
}
// 2. 调用微信接口换取手机号 [citation:8]
url := fmt.Sprintf("https://api.weixin.qq.com/wxa/business/getuserphonenumber?access_token=%s", accessToken)
// 构建请求体
requestBody := map[string]string{
"code": phoneCode,
}
jsonBody, _ := json.Marshal(requestBody)
client := &http.Client{Timeout: 5 * time.Second}
resp, err := client.Post(url, "application/json", bytes.NewReader(jsonBody))
if err != nil {
return nil, fmt.Errorf("请求手机号接口失败: %w", err)
}
defer resp.Body.Close()
body, _ := io.ReadAll(resp.Body)
var phoneResp PhoneInfoResponse
if err := json.Unmarshal(body, &phoneResp); err != nil {
return nil, fmt.Errorf("解析手机号响应失败: %s", string(body))
}
// 3. 检查微信接口返回的错误码 [citation:8]
if phoneResp.Errcode != 0 {
return nil, fmt.Errorf("微信接口返回错误: code=%d, msg=%s", phoneResp.Errcode, phoneResp.Errmsg)
}
return &phoneResp.PhoneInfo, nil
}

289
tmpl/dataTemp/queryTempl.go Normal file
View File

@ -0,0 +1,289 @@
package dataTemp
import (
"context"
"database/sql"
"fmt"
"geo/tmpl/errcode"
"geo/utils"
"reflect"
"github.com/go-kratos/kratos/v2/log"
"gorm.io/gorm"
"xorm.io/builder"
)
type PrimaryKey struct {
Id int `json:"id"`
}
type GormDb struct {
Client *gorm.DB
}
type contextTxKey struct{}
func (d *Db) DB(ctx context.Context) *gorm.DB {
tx, ok := ctx.Value(contextTxKey{}).(*gorm.DB)
if ok {
return tx
}
return d.Db.Client
}
func (t *Db) ExecTx(ctx context.Context, f func(ctx context.Context) error) error {
return t.Db.Client.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
ctx = context.WithValue(ctx, contextTxKey{}, tx)
return f(ctx)
})
}
type Db struct {
Db *GormDb
Log *log.Helper
}
type DataTemp struct {
Db *gorm.DB
ModelType reflect.Type // 改为存储类型而不是实例
modelName string // 可选的表名缓存
}
func NewDataTemp(db *utils.Db, model interface{}) *DataTemp {
// 获取模型的类型
t := reflect.TypeOf(model)
if t.Kind() == reflect.Ptr {
t = t.Elem()
}
return &DataTemp{
Db: db.Client,
ModelType: t,
}
}
func (k DataTemp) modelInstance() interface{} {
return reflect.New(k.ModelType).Interface()
}
func (k DataTemp) GetById(id int32) (data map[string]interface{}, err error) {
err = k.Db.Model(k.modelInstance()).Where("id = ?", id).Find(&data).Error
if data == nil {
err = sql.ErrNoRows
}
return
}
func (k DataTemp) GetByStruct(ctx context.Context, search interface{}, data interface{}, orderBy string) (err error) {
err = k.Db.Model(k.modelInstance()).WithContext(ctx).Where(search).Find(&data).Error
if data == nil {
err = sql.ErrNoRows
}
return
}
func (k DataTemp) SaveByStruct(search interface{}, data interface{}) (err error) {
err = k.Db.Model(k.modelInstance()).Where(search).Save(&data).Error
if data == nil {
err = sql.ErrNoRows
}
return
}
func (k DataTemp) Add(ctx context.Context, data interface{}) (err error) {
m := k.modelInstance()
if err = k.Db.Model(m).WithContext(ctx).Create(data).Error; err != nil {
return errcode.SqlErr(err)
}
return
}
func (k DataTemp) AddWithData(data interface{}) (interface{}, error) {
result := k.Db.Model(k.modelInstance()).Create(data)
if result.Error != nil {
return data, result.Error
}
return data, nil
}
func (k DataTemp) GetList(cond *builder.Cond, pageBoIn *ReqPageBo) (list []map[string]interface{}, pageBoOut *RespPageBo, err error) {
var (
query, _ = builder.ToBoundSQL(*cond)
model = k.Db.Model(k.modelInstance()).Where(query)
total int64
)
model.Count(&total)
pageBoOut = pageBoOut.SetDataByReq(total, pageBoIn)
model.Limit(pageBoIn.GetSize()).Offset(pageBoIn.GetOffset()).Order("updated_at desc").Find(&list)
return
}
func (k DataTemp) GetRange(ctx context.Context, cond *builder.Cond) (list []map[string]interface{}, err error) {
var (
query, _ = builder.ToBoundSQL(*cond)
model = k.Db.Model(k.modelInstance()).Where(query)
)
err = model.WithContext(ctx).Find(&list).Error
return list, err
}
func (k DataTemp) GetRangeToMapStruct(ctx context.Context, cond *builder.Cond, data interface{}) (err error) {
var (
query, _ = builder.ToBoundSQL(*cond)
model = k.Db.Model(k.modelInstance()).Where(query)
)
err = model.WithContext(ctx).Find(data).Error
return err
}
func (k DataTemp) GetOneBySearch(cond *builder.Cond) (data map[string]interface{}, err error) {
query, _ := builder.ToBoundSQL(*cond)
if err = k.Db.Model(k.modelInstance()).Where(query).Limit(1).Find(&data).Error; err != nil {
return nil, errcode.SqlErr(err)
}
return
}
func (k DataTemp) Exist(ctx context.Context, cond *builder.Cond) (bool, error) {
var data map[string]interface{}
query, _ := builder.ToBoundSQL(*cond)
err := k.Db.WithContext(ctx).Model(k.modelInstance()).Where(query).Limit(1).Find(&data).Error
if err != nil || data != nil {
return true, err
}
return false, nil
}
func (k DataTemp) GetOneBySearchStruct(ctx context.Context, cond *builder.Cond, data interface{}) (err error) {
query, _ := builder.ToBoundSQL(*cond)
if err = k.Db.Model(k.modelInstance()).WithContext(ctx).Where(query).Limit(1).Find(&data).Error; err != nil {
return errcode.SqlErr(err)
}
return
}
func (k DataTemp) GetListToStruct(ctx context.Context, cond *builder.Cond, pageBoIn *ReqPageBo, result interface{}, orderBy string) (pageBoOut *RespPageBo, err error) {
// 参数验证
if result == nil {
return nil, fmt.Errorf("result cannot be nil")
}
val := reflect.ValueOf(result)
if val.Kind() != reflect.Ptr {
return nil, fmt.Errorf("result must be a pointer")
}
elem := val.Elem()
if elem.Kind() != reflect.Slice {
return nil, fmt.Errorf("result must be a pointer to slice")
}
// 构建基础查询
query, _ := builder.ToBoundSQL(*cond)
// 预编译 SQL 以提高性能
// 使用 Table 指定表名,避免 GORM 的反射开销
db := k.Db.WithContext(ctx).Model(k.modelInstance()).Where(query)
// 获取总数(使用单独的计数查询,避免缓存影响)
var total int64
countDb := db
if pageBoIn != nil {
if err = countDb.Count(&total).Error; err != nil {
return nil, err
}
}
// 初始化分页响应
pageBoOut = &RespPageBo{}
pageBoOut = pageBoOut.SetDataByReq(total, pageBoIn)
// 如果没有数据,直接返回空切片
if total == 0 && pageBoIn != nil {
elem.Set(reflect.MakeSlice(elem.Type(), 0, 0))
return pageBoOut, nil
}
// 设置排序(使用索引字段提高性能)
if orderBy == "" {
orderBy = "updated_at desc"
}
// 应用分页和排序,执行查询
// 使用 Select 指定字段,避免查询所有字段(如果需要优化)
baseQuery := db
if pageBoIn != nil {
baseQuery = db.Limit(pageBoIn.GetSize()).Offset(pageBoIn.GetOffset()).Order(orderBy)
}
if err = baseQuery.
Order(orderBy).
Find(result).Error; err != nil {
return nil, err
}
return pageBoOut, nil
}
func (k DataTemp) UpdateByKey(ctx context.Context, key string, id interface{}, data interface{}) (err error) {
if err = k.Db.WithContext(ctx).Model(k.modelInstance()).Where(fmt.Sprintf("%s = ?", key), id).Updates(data).Error; err != nil {
return errcode.SqlErr(err)
}
return
}
func (k DataTemp) UpdateByCond(ctx context.Context, cond *builder.Cond, data interface{}) (err error) {
var (
query, _ = builder.ToBoundSQL(*cond)
model = k.Db.Model(k.modelInstance()).Where(query)
)
err = model.WithContext(ctx).Updates(data).Error
return err
}
func (k DataTemp) UpdateColumnByCond(ctx context.Context, cond *builder.Cond, column string, data interface{}) (err error) {
var (
query, _ = builder.ToBoundSQL(*cond)
model = k.Db.Model(k.modelInstance()).Where(query)
)
err = model.WithContext(ctx).Update(column, data).Error
return err
}
func (k DataTemp) GetByKey(ctx context.Context, key string, value interface{}, data interface{}) (err error) {
if err = k.Db.WithContext(ctx).Model(k.modelInstance()).Where(fmt.Sprintf("%s = ?", key), value).Find(data).Error; err != nil {
return errcode.SqlErr(err)
}
return
}
func (k DataTemp) DeleteByKey(ctx context.Context, key string, value int64) error {
result := k.Db.WithContext(ctx).Model(k.modelInstance()).Where(fmt.Sprintf("%s = ?", key), value).
Update("deleted_at", gorm.Expr("CURRENT_TIMESTAMP"))
if result.Error != nil {
return result.Error
}
if result.RowsAffected == 0 {
return errcode.NotFound("不存在或已被删除")
}
return nil
}
func (k DataTemp) CountByCond(ctx context.Context, cond *builder.Cond) (int64, error) {
var (
count int64
query, _ = builder.ToBoundSQL(*cond)
model = k.Db.Model(k.modelInstance()).Where(query)
)
err := model.WithContext(ctx).Count(&count).Error
if err != nil {
return 0, err
}
return count, err
}

33
tmpl/dataTemp/req_page.go Normal file
View File

@ -0,0 +1,33 @@
package dataTemp
// ReqPageBo 分页请求实体
type ReqPageBo struct {
Page int //页码从第1页开始
Limit int //分页大小
}
// GetOffset 获取便宜量
// 确保 dataTemp/page.go 中有这些方法
func (p *ReqPageBo) GetSize() int {
if p == nil {
return 10 // 默认每页10条
}
if p.Limit <= 0 {
return 10
}
return p.Limit
}
func (p *ReqPageBo) GetOffset() int {
if p == nil {
return 0
}
return (p.GetPage() - 1) * p.GetSize()
}
func (p *ReqPageBo) GetPage() int {
if p == nil || p.Page <= 0 {
return 1
}
return p.Page
}

View File

@ -0,0 +1,22 @@
package dataTemp
// RespPageBo 分页响应实体
type RespPageBo struct {
Page int //页码
Limit int //每页大小
Total int64 //总数
}
// SetDataByReq 通过req 设置响应参数
func (r *RespPageBo) SetDataByReq(total int64, reqPage *ReqPageBo) *RespPageBo {
resp := r
if r == nil {
resp = &RespPageBo{}
}
resp.Total = total
if reqPage != nil {
resp.Page = reqPage.Page
resp.Limit = reqPage.Limit
}
return resp
}

113
tmpl/errcode/common.go Normal file
View File

@ -0,0 +1,113 @@
package errcode
import "fmt"
var (
AuthNotFound = &BusinessErr{code: AuthErr, message: "账号不存在"}
AuthStatusFreeze = &BusinessErr{code: AuthErr, message: "账号冻结"}
AuthStatusDel = &BusinessErr{code: AuthErr, message: "身份验证失败"}
AuthStatusPwdFail = &BusinessErr{code: AuthErr, message: "密码错误"}
AuthTokenCreateFail = &BusinessErr{code: AuthErr, message: "token生成失败"}
AuthTokenDelFail = &BusinessErr{code: AuthErr, message: "删除token失败"}
AuthWxLoginFail = &BusinessErr{code: AuthErr, message: "微信登录失败,请稍后重试"}
AuthInfoFail = &BusinessErr{code: AuthErr, message: "登录异常,请重新登录"}
TokenNotFound = &BusinessErr{code: TokenErr, message: "缺少 Authorization Header"}
TokenFormatErr = &BusinessErr{code: TokenErr, message: "无效的token格式"}
TokenInfoNotFound = &BusinessErr{code: TokenErr, message: "未找到用户信息"}
TokenInvalid = &BusinessErr{code: TokenErr, message: "token过期"}
PlatsNotFound = &BusinessErr{code: NotFoundErr, message: "信息未找到"}
BadRequest = &BusinessErr{code: BadReqErr, message: "操作失败"}
ForbiddenError = &BusinessErr{code: ForbiddenErr, message: "权限不足"}
Success = &BusinessErr{code: 200, message: "成功"}
ParamError = &BusinessErr{code: ParamsErr, message: "参数错误"}
SystemError = &BusinessErr{code: 405, message: "系统错误"}
ClientNotFound = &BusinessErr{code: 406, message: "未找到client_id"}
SessionNotFound = &BusinessErr{code: 407, message: "未找到会话信息"}
UserNotFound = &BusinessErr{code: NotFoundErr, message: "不存在的用户"}
KeyNotFound = &BusinessErr{code: 409, message: "身份验证失败"}
SysNotFound = &BusinessErr{code: 410, message: "未找到系统信息"}
SysCodeNotFound = &BusinessErr{code: 411, message: "未找到系统编码"}
InvalidParam = &BusinessErr{code: InvalidParamCode, message: "无效参数"}
WorkflowError = &BusinessErr{code: 501, message: "工作流过程错误"}
ClientInfoNotFound = &BusinessErr{code: NotFoundErr, message: "用户信息未找到"}
)
const (
InvalidParamCode = 408
AuthErr = 403
TokenErr = 401
ParamsErr = 422
BadReqErr = 400
NotFoundErr = 404
ForbiddenErr = 403
BalanceNotEnoughCode = 402
)
type BusinessErr struct {
code int
message string
}
func (e *BusinessErr) Error() string {
return e.message
}
func (e *BusinessErr) Code() int {
return e.code
}
func NotFound(message string) *BusinessErr {
return &BusinessErr{code: NotFoundErr, message: PlatsNotFound.message + ":" + message}
}
func (e *BusinessErr) Is(target error) bool {
_, ok := target.(*BusinessErr)
return ok
}
// CustomErr 自定义错误
func NewBusinessErr(code int, message string) *BusinessErr {
return &BusinessErr{code: code, message: message}
}
func SysErrf(message string, arg ...any) *BusinessErr {
return &BusinessErr{code: SystemError.code, message: fmt.Sprintf(message, arg)}
}
func SysErr(message string) *BusinessErr {
return &BusinessErr{code: SystemError.code, message: message}
}
func ParamErrf(message string, arg ...any) *BusinessErr {
return &BusinessErr{code: ParamError.code, message: fmt.Sprintf(message, arg)}
}
func ParamErr(message string) *BusinessErr {
return &BusinessErr{code: ParamError.code, message: ParamError.message + ":" + message}
}
func SqlErr(err error) *BusinessErr {
return &BusinessErr{code: ParamError.code, message: "数据操作失败,请联系管理员处理:" + err.Error()}
}
func BadReq(message string) *BusinessErr {
return &BusinessErr{code: BadReqErr, message: BadRequest.message + ":" + message}
}
func Forbidden(message string) *BusinessErr {
return &BusinessErr{code: ForbiddenErr, message: ForbiddenError.message + ":" + message}
}
func (e *BusinessErr) Wrap(err error) *BusinessErr {
return NewBusinessErr(e.code, err.Error())
}
func BalanceNotEnoughErr(message string) *BusinessErr {
return NewBusinessErr(BalanceNotEnoughCode, message)
}

133
utils/gorm.go Normal file
View File

@ -0,0 +1,133 @@
package utils
import (
"geo/internal/config"
"geo/utils/utils_gorm"
"gorm.io/gorm"
)
type Db struct {
Client *gorm.DB
}
func NewGormDb(c *config.Config) (*Db, func()) {
transDBClient, mf := utils_gorm.DBConn(&c.DB)
//directDBClient, df := directDB(c, hLog)
cleanup := func() {
mf()
//df()
}
return &Db{
Client: transDBClient,
//DirectDBClient: directDBClient,
}, cleanup
}
// GetOne 查询单条记录,返回 map
func (d *Db) GetOne(sql string, args ...interface{}) (map[string]interface{}, error) {
var result map[string]interface{}
// 使用 Raw 执行原生 SQLScan 到 map 需要先获取 rows
rows, err := d.Client.Raw(sql, args...).Rows()
if err != nil {
return nil, err
}
defer rows.Close()
if rows.Next() {
// 获取列名
columns, err := rows.Columns()
if err != nil {
return nil, err
}
// 创建扫描用的切片
values := make([]interface{}, len(columns))
valuePtrs := make([]interface{}, len(columns))
for i := range values {
valuePtrs[i] = &values[i]
}
if err := rows.Scan(valuePtrs...); err != nil {
return nil, err
}
result = make(map[string]interface{})
for i, col := range columns {
result[col] = values[i]
}
return result, nil
}
return nil, nil
}
// GetAll 查询多条记录,返回 map 切片
func (d *Db) GetAll(sql string, args ...interface{}) ([]map[string]interface{}, error) {
rows, err := d.Client.Raw(sql, args...).Rows()
if err != nil {
return nil, err
}
defer rows.Close()
results := make([]map[string]interface{}, 0)
for rows.Next() {
columns, err := rows.Columns()
if err != nil {
return nil, err
}
values := make([]interface{}, len(columns))
valuePtrs := make([]interface{}, len(columns))
for i := range values {
valuePtrs[i] = &values[i]
}
if err := rows.Scan(valuePtrs...); err != nil {
return nil, err
}
row := make(map[string]interface{})
for i, col := range columns {
row[col] = values[i]
}
results = append(results, row)
}
return results, nil
}
// Execute 执行单条 SQLINSERT/UPDATE/DELETE返回影响行数
func (d *Db) Execute(sql string, args ...interface{}) (int64, error) {
result := d.Client.Exec(sql, args...)
if result.Error != nil {
return 0, result.Error
}
return result.RowsAffected, nil
}
// ExecuteMany 批量执行 SQL使用事务
func (d *Db) ExecuteMany(sql string, argsList [][]interface{}) (int64, error) {
var total int64
// 开始事务
tx := d.Client.Begin()
if tx.Error != nil {
return 0, tx.Error
}
for _, args := range argsList {
result := tx.Exec(sql, args...)
if result.Error != nil {
tx.Rollback()
return 0, result.Error
}
total += result.RowsAffected
}
// 提交事务
if err := tx.Commit().Error; err != nil {
return 0, err
}
return total, nil
}

9
utils/provider_set.go Normal file
View File

@ -0,0 +1,9 @@
package utils
import (
"github.com/google/wire"
)
var ProviderUtils = wire.NewSet(
NewGormDb,
)

41
utils/utils_gorm/gorm.go Normal file
View File

@ -0,0 +1,41 @@
package utils_gorm
import (
"database/sql"
"fmt"
"geo/internal/config"
"gorm.io/driver/mysql"
"gorm.io/gorm"
"time"
)
func DBConn(c *config.DB) (*gorm.DB, func()) {
mysqlConn, err := sql.Open(c.Driver, c.Source)
gormDB, err := gorm.Open(
mysql.New(mysql.Config{Conn: mysqlConn}),
)
gormDB.Logger = NewCustomLogger(gormDB)
if err != nil {
panic("failed to connect database")
}
sqlDB, err := gormDB.DB()
// SetMaxIdleConns sets the maximum number of connections in the idle connection pool.
sqlDB.SetMaxIdleConns(int(c.MaxIdle))
// SetMaxOpenConns sets the maximum number of open connections to the database.
sqlDB.SetMaxOpenConns(int(c.MaxLifetime))
// SetConnMaxLifetime sets the maximum amount of time a connection may be reused.
sqlDB.SetConnMaxLifetime(time.Hour)
return gormDB, func() {
if mysqlConn != nil {
fmt.Println("关闭 physicalGoodsDB")
if err := mysqlConn.Close(); err != nil {
fmt.Println("关闭 physicalGoodsDB 失败:", err)
}
}
}
}

View File

@ -0,0 +1,96 @@
package utils_gorm
import (
"context"
"fmt"
"gorm.io/gorm"
"gorm.io/gorm/logger"
"regexp"
"strings"
"time"
)
type CustomLogger struct {
gormLogger logger.Interface
db *gorm.DB
}
func NewCustomLogger(db *gorm.DB) *CustomLogger {
return &CustomLogger{
gormLogger: logger.Default.LogMode(logger.Info),
db: db,
}
}
func (l *CustomLogger) LogMode(level logger.LogLevel) logger.Interface {
newlogger := *l
newlogger.gormLogger = l.gormLogger.LogMode(level)
return &newlogger
}
func (l *CustomLogger) Info(ctx context.Context, msg string, data ...interface{}) {
l.gormLogger.Info(ctx, msg, data...)
}
func (l *CustomLogger) Warn(ctx context.Context, msg string, data ...interface{}) {
l.gormLogger.Warn(ctx, msg, data...)
}
func (l *CustomLogger) Error(ctx context.Context, msg string, data ...interface{}) {
l.gormLogger.Error(ctx, msg, data...)
}
func (l *CustomLogger) Trace(ctx context.Context, begin time.Time, fc func() (sql string, rowsAffected int64), err error) {
elapsed := time.Since(begin)
sql, _ := fc()
l.gormLogger.Trace(ctx, begin, fc, err)
operation := extractOperation(sql)
tableName := extractTableName(sql)
fmt.Println(tableName)
//// 将SQL语句保存到数据库
if operation == 0 || tableName == "sql_log" {
return
}
//go l.db.Model(&SqlLog{}).Create(&SqlLog{
// OperatorID: 1,
// OperatorName: "test",
// SqlInfo: sql,
// TableNames: tableName,
// Type: operation,
//})
// 如果有需要也可以根据执行时间elapsed等条件过滤或处理日志记录
if elapsed > time.Second {
//l.gormLogger.Warn(ctx, "Slow SQL (> 1s): %s", sql)
}
}
// extractTableName extracts the table name from a SQL query, supporting quoted table names.
func extractTableName(sql string) string {
// 使用非捕获组匹配多种SQL操作关键词
re := regexp.MustCompile(`(?i)\b(?:from|update|into|delete\s+from)\b\s+[\` + "`" + `"]?(\w+)[\` + "`" + `"]?`)
match := re.FindStringSubmatch(sql)
// 检查是否匹配成功
if len(match) > 1 {
return match[1]
}
return ""
}
// extractOperation extracts the operation type from a SQL query.
func extractOperation(sql string) int32 {
sql = strings.TrimSpace(strings.ToLower(sql))
var operation int32
if strings.HasPrefix(sql, "select") {
operation = 0
} else if strings.HasPrefix(sql, "insert") {
operation = 1
} else if strings.HasPrefix(sql, "update") {
operation = 3
} else if strings.HasPrefix(sql, "delete") {
operation = 2
}
return operation
}