diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..d5b50ee --- /dev/null +++ b/Makefile @@ -0,0 +1,23 @@ +GOHOSTOS:=$(shell go env GOHOSTOS) +GOPATH:=$(shell go env GOPATH) +VERSION=$(shell git describe --tags --always) +PROJECTNAME=yumchina + +ifeq ($(GOHOSTOS), windows) + #the `find.exe` is different from `find` in bash/shell. + #to see https://docs.microsoft.com/en-us/windows-server/administration/windows-commands/find. + #changed to use git-bash.exe to run find cli or other cli friendly, caused of every developer has a Git. + #Git_Bash= $(subst cmd\,bin\bash.exe,$(dir $(shell where git))) + Git_Bash=$(subst \,/,$(subst cmd\,bin\bash.exe,$(dir $(shell where git | grep cmd)))) + INTERNAL_PROTO_FILES=$(shell $(Git_Bash) -c "find internal -name *.proto") + API_PROTO_FILES=$(shell $(Git_Bash) -c "find api -name *.proto") +else + INTERNAL_PROTO_FILES=$(shell find internal -name *.proto) + API_PROTO_FILES=$(shell find api -name *.proto) +endif + + +.PHONY: wire +# generate wire +wire: + cd ./cmd/server && wire diff --git a/cmd/server/main.go b/cmd/server/main.go index 71bc59f..2b18181 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -1,87 +1,43 @@ package main import ( - "ai_scheduler/internal/handlers" - "context" - "flag" - "log" - "net/http" - "os" - "os/signal" - "syscall" - "time" + "ai_scheduler/internal/config" - _ "ai_scheduler/docs" + "flag" + "fmt" "github.com/gin-gonic/gin" + "github.com/gofiber/fiber/v2/log" ) +// Swagger 文档注解(保持不变) // @title AI Scheduler API // @version 1.0 // @description 智能路由调度系统API文档 // @termsOfService http://swagger.io/terms/ - // @contact.name API Support // @contact.url http://www.swagger.io/support // @contact.email support@swagger.io - // @license.name Apache 2.0 // @license.url http://www.apache.org/licenses/LICENSE-2.0.html - // @host localhost:8080 // @BasePath / + func main() { - // 解析命令行参数 - configPath := flag.String("config", "config.yaml", "配置文件路径") - port := flag.String("port", "8080", "服务端口") + + configPath := flag.String("config", "config.yaml", "Path to configuration file") flag.Parse() - - // 初始化应用程序 - app, err := InitializeApp(*configPath) + bc, err := config.LoadConfig(*configPath) if err != nil { - log.Fatalf("Failed to initialize app: %v", err) + log.Fatalf("加载配置失败: %v", err) } + app, cleanup, err := InitializeApp(bc, log.DefaultLogger()) + if err != nil { + log.Fatalf("项目初始化失败: %v", err) + } + defer cleanup() - // 设置Gin模式为发布模式 gin.SetMode(gin.ReleaseMode) - // 设置路由 - router := handlers.SetupRoutes(app.RouterService) - - // 创建HTTP服务器 - server := &http.Server{ - Addr: ":" + *port, - Handler: router, - ReadTimeout: 30 * time.Second, - WriteTimeout: 30 * time.Second, - } - - // 启动服务器 - go func() { - log.Printf("Starting server on port %s", *port) - log.Printf("Swagger UI: http://localhost:%s/swagger/index.html", *port) - log.Printf("Health check: http://localhost:%s/health", *port) - log.Printf("Chat API: http://localhost:%s/api/v1/chat", *port) - - if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed { - log.Fatalf("Failed to start server: %v", err) - } - }() - - // 等待中断信号 - quit := make(chan os.Signal, 1) - signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM) - <-quit - - log.Println("Shutting down server...") - - // 优雅关闭服务器 - ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) - defer cancel() - - if err := server.Shutdown(ctx); err != nil { - log.Printf("Server forced to shutdown: %v", err) - } else { - log.Println("Server exited gracefully") - } + log.Fatal(app.HttpServer.Listen(fmt.Sprintf(":%d", bc.Server.Port))) } diff --git a/cmd/server/wire.go b/cmd/server/wire.go index 1e0a950..826becf 100644 --- a/cmd/server/wire.go +++ b/cmd/server/wire.go @@ -4,63 +4,24 @@ package main import ( + "ai_scheduler/internal/biz" "ai_scheduler/internal/config" + "ai_scheduler/internal/pkg" + "ai_scheduler/internal/server" "ai_scheduler/internal/services" "ai_scheduler/internal/tools" - "ai_scheduler/pkg/ollama" - "ai_scheduler/pkg/types" - + "github.com/gofiber/fiber/v2/log" "github.com/google/wire" ) // InitializeApp 初始化应用程序 -func InitializeApp(configPath string) (*App, error) { - wire.Build( - // 配置 - config.LoadConfig, +func InitializeApp(*config.Config, log.AllLogger) (*server.Servers, func(), error) { + panic(wire.Build( + server.ProviderSetServer, + tools.ProviderSetTools, + pkg.ProviderSetClient, + services.ProviderSetServices, + biz.ProviderSetBiz, + )) - // Ollama客户端 - provideOllamaClient, - - // 工具管理器 - provideToolsConfig, - tools.NewManager, - - // 路由服务 - provideRouterService, - - // 应用程序 - NewApp, - ) - return &App{}, nil -} - -// provideOllamaClient 提供Ollama客户端 -func provideOllamaClient(cfg *config.Config) types.AIClient { - client, _ := ollama.NewClient(&cfg.Ollama) - return client -} - -// provideToolsConfig 提供工具配置 -func provideToolsConfig(cfg *config.Config) *config.ToolsConfig { - return &cfg.Tools -} - -// provideRouterService 提供路由服务 -func provideRouterService(aiClient types.AIClient, toolManager *tools.Manager) types.RouterService { - return services.NewRouterService(aiClient, toolManager) -} - -// App 应用程序结构 -type App struct { - Config *config.Config - RouterService types.RouterService -} - -// NewApp 创建应用程序 -func NewApp(cfg *config.Config, routerService types.RouterService) *App { - return &App{ - Config: cfg, - RouterService: routerService, - } } diff --git a/config/config.yaml b/config/config.yaml new file mode 100644 index 0000000..71b1cd6 --- /dev/null +++ b/config/config.yaml @@ -0,0 +1,11 @@ +# 服务器配置 +server: + port: 8090 + host: "0.0.0.0" + +ollama: + base_url: "http://localhost:11434" + model: "deepseek-r1:8b" + timeout: "120s" + level: "info" + format: "json" diff --git a/go.mod b/go.mod index 2539454..7b5b30d 100644 --- a/go.mod +++ b/go.mod @@ -5,19 +5,23 @@ go 1.24.0 toolchain go1.24.7 require ( + github.com/gin-contrib/cors v1.7.2 github.com/gin-gonic/gin v1.10.0 + github.com/gofiber/fiber/v2 v2.52.9 + github.com/google/uuid v1.6.0 github.com/google/wire v0.7.0 github.com/ollama/ollama v0.11.10 github.com/spf13/viper v1.17.0 github.com/swaggo/files v1.0.1 github.com/swaggo/gin-swagger v1.6.1 - github.com/swaggo/swag v1.16.6 + go.opentelemetry.io/otel v1.38.0 ) require ( github.com/KyleBanks/depth v1.2.1 // indirect github.com/PuerkitoBio/purell v1.1.1 // indirect github.com/PuerkitoBio/urlesc v0.0.0-20170810143723-de5bf2ad4578 // indirect + github.com/andybalholm/brotli v1.1.0 // indirect github.com/bytedance/sonic v1.11.6 // indirect github.com/bytedance/sonic/loader v0.1.1 // indirect github.com/cloudwego/base64x v0.1.4 // indirect @@ -36,15 +40,19 @@ require ( github.com/hashicorp/hcl v1.0.0 // indirect github.com/josharian/intern v1.0.0 // indirect github.com/json-iterator/go v1.1.12 // indirect + github.com/klauspost/compress v1.17.9 // indirect github.com/klauspost/cpuid/v2 v2.2.7 // indirect github.com/leodido/go-urn v1.4.0 // indirect github.com/magiconair/properties v1.8.7 // indirect github.com/mailru/easyjson v0.7.6 // indirect + github.com/mattn/go-colorable v0.1.13 // indirect github.com/mattn/go-isatty v0.0.20 // indirect + github.com/mattn/go-runewidth v0.0.16 // indirect github.com/mitchellh/mapstructure v1.5.0 // indirect github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/reflect2 v1.0.2 // indirect github.com/pelletier/go-toml/v2 v2.2.2 // indirect + github.com/rivo/uniseg v0.2.0 // indirect github.com/sagikazarmark/locafero v0.3.0 // indirect github.com/sagikazarmark/slog-shim v0.1.0 // indirect github.com/sourcegraph/conc v0.3.0 // indirect @@ -52,8 +60,12 @@ require ( github.com/spf13/cast v1.5.1 // indirect github.com/spf13/pflag v1.0.5 // indirect github.com/subosito/gotenv v1.6.0 // indirect + github.com/swaggo/swag v1.16.6 // indirect github.com/twitchyliquid64/golang-asm v0.15.1 // indirect github.com/ugorji/go/codec v1.2.12 // indirect + github.com/valyala/bytebufferpool v1.0.0 // indirect + github.com/valyala/fasthttp v1.51.0 // indirect + github.com/valyala/tcplisten v1.0.0 // indirect go.uber.org/atomic v1.9.0 // indirect go.uber.org/multierr v1.9.0 // indirect golang.org/x/arch v0.8.0 // indirect diff --git a/go.sum b/go.sum index 2e9d8ac..bfe4df9 100644 --- a/go.sum +++ b/go.sum @@ -44,6 +44,8 @@ github.com/PuerkitoBio/purell v1.1.1 h1:WEQqlqaGbrPkxLJWfBwQmfEAE1Z7ONdDLqrN38tN github.com/PuerkitoBio/purell v1.1.1/go.mod h1:c11w/QuzBsJSee3cPx9rAFu61PvFxuPbtSwDGJws/X0= github.com/PuerkitoBio/urlesc v0.0.0-20170810143723-de5bf2ad4578 h1:d+Bc7a5rLufV/sSk/8dngufqelfh6jnri85riMAaF/M= github.com/PuerkitoBio/urlesc v0.0.0-20170810143723-de5bf2ad4578/go.mod h1:uGdkoq3SwY9Y+13GIhn11/XLaGBb4BfwItxLd5jeuXE= +github.com/andybalholm/brotli v1.1.0 h1:eLKJA0d02Lf0mVpIDgYnqXcUn0GqVmEFny3VuID1U3M= +github.com/andybalholm/brotli v1.1.0/go.mod h1:sms7XGricyQI9K10gOSf56VKKWS4oLer58Q+mhRPtnY= github.com/bytedance/sonic v1.11.6 h1:oUp34TzMlL+OY1OUWxHqsdkgC/Zfc85zGqw9siXjrc0= github.com/bytedance/sonic v1.11.6/go.mod h1:LysEHSvpvDySVdC2f87zGWf6CIKJcAvqab1ZaiQtds4= github.com/bytedance/sonic/loader v0.1.1 h1:c+e5Pt1k/cy5wMveRDyk2X4B9hF4g7an8N3zCYjJFNM= @@ -77,6 +79,8 @@ github.com/fsnotify/fsnotify v1.6.0 h1:n+5WquG0fcWoWp6xPWfHdbskMCQaFnG6PfBrh1Ky4 github.com/fsnotify/fsnotify v1.6.0/go.mod h1:sl3t1tCWJFWoRz9R8WJCbQihKKwmorjAbSClcnxKAGw= github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0= github.com/gabriel-vasile/mimetype v1.4.3/go.mod h1:d8uq/6HKRL6CGdk+aubisF/M5GcPfT7nKyLpA0lbSSk= +github.com/gin-contrib/cors v1.7.2 h1:oLDHxdg8W/XDoN/8zamqk/Drgt4oVZDvaV0YmvVICQw= +github.com/gin-contrib/cors v1.7.2/go.mod h1:SUJVARKgQ40dmrzgXEVxj2m7Ig1v1qIboQkPDTQ9t2E= github.com/gin-contrib/gzip v0.0.6 h1:NjcunTcGAj5CO1gn4N8jHOSIeRFHIbn51z6K+xaN4d4= github.com/gin-contrib/gzip v0.0.6/go.mod h1:QOJlmV2xmayAjkNS2Y8NQsMneuRShOU/kjovCXNuzzk= github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE= @@ -106,6 +110,8 @@ github.com/go-playground/validator/v10 v10.20.0 h1:K9ISHbSaI0lyB2eWMPJo+kOS/FBEx github.com/go-playground/validator/v10 v10.20.0/go.mod h1:dbuPbCMFw/DrkbEynArYaCwl3amGuJotoKCe95atGMM= github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU= github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I= +github.com/gofiber/fiber/v2 v2.52.9 h1:YjKl5DOiyP3j0mO61u3NTmK7or8GzzWzCFzkboyP5cw= +github.com/gofiber/fiber/v2 v2.52.9/go.mod h1:YEcBbO/FB+5M1IZNBP9FO3J9281zgPAreiI1oqg8nDw= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= github.com/golang/groupcache v0.0.0-20190702054246-869f871628b6/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= github.com/golang/groupcache v0.0.0-20191227052852-215e87163ea7/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= @@ -160,6 +166,8 @@ github.com/google/pprof v0.0.0-20201203190320-1bf35d6f28c2/go.mod h1:kpwsk12EmLe github.com/google/pprof v0.0.0-20201218002935-b9804c9f04c2/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE= github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI= github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/wire v0.7.0 h1:JxUKI6+CVBgCO2WToKy/nQk0sS+amI9z9EjVmdaocj4= github.com/google/wire v0.7.0/go.mod h1:n6YbUQD9cPKTnHXEBN2DXlOp/mVADhVErcMFb0v3J18= github.com/googleapis/gax-go/v2 v2.0.4/go.mod h1:0Wqv26UfaUD9n4G6kQubkQ+KchISgw+vpHVxEJEs9eg= @@ -178,6 +186,8 @@ github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHm github.com/jstemmer/go-junit-report v0.0.0-20190106144839-af01ea7f8024/go.mod h1:6v2b51hI/fHJwM22ozAgKL4VKDeJcHhJFhtBdhmNjmU= github.com/jstemmer/go-junit-report v0.9.1/go.mod h1:Brl9GWCQeLvo8nXZwPNNblvFj/XSXhF0NWZEnDohbsk= github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= +github.com/klauspost/compress v1.17.9 h1:6KIumPrER1LHsvBVuDa0r5xaG0Es51mhhB9BQB2qeMA= +github.com/klauspost/compress v1.17.9/go.mod h1:Di0epgTjJY877eYKx5yC51cX2A2Vl2ibi7bDH9ttBbw= github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= github.com/klauspost/cpuid/v2 v2.2.7 h1:ZWSB3igEs+d0qvnxR/ZBzXVmxkgt8DdzP6m9pfuVLDM= github.com/klauspost/cpuid/v2 v2.2.7/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws= @@ -198,8 +208,13 @@ github.com/mailru/easyjson v0.0.0-20190614124828-94de47d64c63/go.mod h1:C1wdFJiN github.com/mailru/easyjson v0.0.0-20190626092158-b2ccc519800e/go.mod h1:C1wdFJiN94OJF2b5HbByQZoLdCWB1Yqtg26g4irojpc= github.com/mailru/easyjson v0.7.6 h1:8yTIVnZgCoiM1TgqoeTl+LfU5Jg6/xL3QhGQnimLYnA= github.com/mailru/easyjson v0.7.6/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc= +github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= +github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= +github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/mattn/go-runewidth v0.0.16 h1:E5ScNMtiwvlvB5paMFdw9p4kSQzbXFikJ5SQO6TULQc= +github.com/mattn/go-runewidth v0.0.16/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w= github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY= github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= @@ -207,7 +222,6 @@ github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M= github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= -github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e h1:fD57ERR4JtEqsWbfPhv4DMiApHyliiK5xCTNVSPiaAs= github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno= github.com/ollama/ollama v0.11.10 h1:J9zaoTPwIXOrYXCRAqI7rV4cJ+FOMuQc/vBqQ5GIdWg= github.com/ollama/ollama v0.11.10/go.mod h1:9+1//yWPsDE2u+l1a5mpaKrYw4VdnSsRU3ioq5BvMms= @@ -219,6 +233,8 @@ github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZN github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= +github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY= +github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8= github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= @@ -249,8 +265,9 @@ github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/ github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= -github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8= github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU= github.com/swaggo/files v1.0.1 h1:J1bVJ4XHZNq0I46UU90611i9/YzdrF7x92oX1ig5IdE= @@ -263,6 +280,12 @@ github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08= github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE= github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg= +github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= +github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= +github.com/valyala/fasthttp v1.51.0 h1:8b30A5JlZ6C7AS81RsWjYMQmrZG6feChmgAolCl1SqA= +github.com/valyala/fasthttp v1.51.0/go.mod h1:oI2XroL+lI7vdXyYoQk03bXBThfFl2cVdIA3Xl7cH8g= +github.com/valyala/tcplisten v1.0.0 h1:rBHj/Xf+E1tRGZyWIWwJDiRY0zc1Js+CV5DqwacVSA8= +github.com/valyala/tcplisten v1.0.0/go.mod h1:T0xQ8SeCZGxckz9qRXTfG43PvQ/mcWh7FwZEA7Ioqkc= github.com/yuin/goldmark v1.1.25/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.1.32/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= @@ -274,6 +297,8 @@ go.opencensus.io v0.22.2/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw= go.opencensus.io v0.22.3/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw= go.opencensus.io v0.22.4/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw= go.opencensus.io v0.22.5/go.mod h1:5pWMHQbX5EPX2/62yrJeAkowc+lfs/XD7Uxpq3pI6kk= +go.opentelemetry.io/otel v1.38.0 h1:RkfdswUDRimDg0m2Az18RKOsnI8UDzppJAtj01/Ymk8= +go.opentelemetry.io/otel v1.38.0/go.mod h1:zcmtmQ1+YmQM9wrNsTGV/q/uyusom3P8RxwExxkZhjM= go.uber.org/atomic v1.9.0 h1:ECmE8Bn/WFTYwEW/bpKD3M8VtR/zQVbavAoalC1PYyE= go.uber.org/atomic v1.9.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= go.uber.org/multierr v1.9.0 h1:7fIwc/ZtS0q++VgcfqFDxSBZVv/Xo49/SYnDFupUwlI= @@ -426,6 +451,7 @@ golang.org/x/sys v0.0.0-20210423185535-09eb48e85fd7/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220908164124-27713097b956/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= @@ -596,8 +622,9 @@ google.golang.org/protobuf v1.34.1 h1:9ddQBjfCyZPOHPUiPxpYESBLc+T8P3E+Vo4IbKZgFW google.golang.org/protobuf v1.34.1/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f h1:BLraFXnmrev5lT+xlilqcH8XK9/i0At2xKjWk4p6zsU= gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI= gopkg.in/ini.v1 v1.67.0 h1:Dgnx+6+nfE+IfzjUEISNeydPJh9AXNNsWbGP9KzCsOA= gopkg.in/ini.v1 v1.67.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k= diff --git a/internal/biz/provider_set.go b/internal/biz/provider_set.go new file mode 100644 index 0000000..e745ee0 --- /dev/null +++ b/internal/biz/provider_set.go @@ -0,0 +1,5 @@ +package biz + +import "github.com/google/wire" + +var ProviderSetBiz = wire.NewSet(NewAiRouterBiz) diff --git a/internal/services/router.go b/internal/biz/router.go similarity index 82% rename from internal/services/router.go rename to internal/biz/router.go index d25bd77..a163bd8 100644 --- a/internal/services/router.go +++ b/internal/biz/router.go @@ -1,9 +1,10 @@ -package services +package biz import ( "ai_scheduler/internal/constants" + "ai_scheduler/internal/entitys" "ai_scheduler/internal/tools" - "ai_scheduler/pkg/types" + "context" "encoding/json" "fmt" @@ -11,24 +12,24 @@ import ( "strings" ) -// RouterService 智能路由服务 -type RouterService struct { - aiClient types.AIClient +// AiRouterService 智能路由服务 +type AiRouterService struct { + aiClient entitys.AIClient toolManager *tools.Manager } // NewRouterService 创建路由服务 -func NewRouterService(aiClient types.AIClient, toolManager *tools.Manager) *RouterService { - return &RouterService{ +func NewAiRouterBiz(aiClient entitys.AIClient, toolManager *tools.Manager) entitys.RouterService { + return &AiRouterService{ aiClient: aiClient, toolManager: toolManager, } } // Route 执行智能路由 -func (r *RouterService) Route(ctx context.Context, req *types.ChatRequest) (*types.ChatResponse, error) { +func (r *AiRouterService) Route(ctx context.Context, req *entitys.ChatRequest) (*entitys.ChatResponse, error) { // 构建消息 - messages := []types.Message{ + messages := []entitys.Message{ // { // Role: "system", // Content: r.buildSystemPrompt(), @@ -88,7 +89,7 @@ func (r *RouterService) Route(ctx context.Context, req *types.ChatRequest) (*typ } // 构建包含工具结果的消息 - messages = append(messages, types.Message{ + messages = append(messages, entitys.Message{ Role: "assistant", Content: response.Message, }) @@ -96,7 +97,7 @@ func (r *RouterService) Route(ctx context.Context, req *types.ChatRequest) (*typ // 添加工具调用结果 for _, toolResult := range toolResults { toolResultStr, _ := json.Marshal(toolResult.Result) - messages = append(messages, types.Message{ + messages = append(messages, entitys.Message{ Role: "tool", Content: fmt.Sprintf("Tool %s result: %s", toolResult.Function.Name, string(toolResultStr)), }) @@ -117,14 +118,14 @@ func (r *RouterService) Route(ctx context.Context, req *types.ChatRequest) (*typ } // buildSystemPrompt 构建系统提示词 -func (r *RouterService) buildSystemPrompt() string { +func (r *AiRouterService) buildSystemPrompt() string { prompt := `你是一个智能路由系统,你的任务是根据用户输入判断用户的意图,并且执行对应的任务。` return prompt } // buildIntentPrompt 构建意图识别提示词 -func (r *RouterService) buildIntentPrompt(userInput string) string { +func (r *AiRouterService) buildIntentPrompt(userInput string) string { prompt := `请分析以下用户输入,判断用户的意图类型。 用户输入:{user_input} @@ -150,7 +151,7 @@ func (r *RouterService) buildIntentPrompt(userInput string) string { } // extractIntent 从AI响应中提取意图 -func (r *RouterService) extractIntent(response *types.ChatResponse) string { +func (r *AiRouterService) extractIntent(response *entitys.ChatResponse) string { if response == nil || response.Message == "" { return "" } @@ -171,7 +172,7 @@ func (r *RouterService) extractIntent(response *types.ChatResponse) string { } // handleOrderDiagnosis 处理订单诊断意图 -func (r *RouterService) handleOrderDiagnosis(ctx context.Context, req *types.ChatRequest, messages []types.Message) (*types.ChatResponse, error) { +func (r *AiRouterService) handleOrderDiagnosis(ctx context.Context, req *entitys.ChatRequest, messages []entitys.Message) (*entitys.ChatResponse, error) { // 调用订单详情工具 orderDetailTool, ok := r.toolManager.GetTool("zltxOrderDetail") if orderDetailTool == nil || !ok { @@ -203,7 +204,7 @@ func (r *RouterService) handleOrderDiagnosis(ctx context.Context, req *types.Cha } // handleKnowledgeQA 处理知识问答意图 -func (r *RouterService) handleKnowledgeQA(ctx context.Context, req *types.ChatRequest, messages []types.Message) (*types.ChatResponse, error) { +func (r *AiRouterService) handleKnowledgeQA(ctx context.Context, req *entitys.ChatRequest, messages []entitys.Message) (*entitys.ChatResponse, error) { return nil, nil } diff --git a/internal/config/config.go b/internal/config/config.go index 178649c..6c966b0 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -57,13 +57,13 @@ func LoadConfig(configPath string) (*Config, error) { viper.SetConfigType("yaml") // 设置默认值 - viper.SetDefault("server.port", "8080") - viper.SetDefault("server.host", "0.0.0.0") - viper.SetDefault("ollama.base_url", "http://localhost:11434") - viper.SetDefault("ollama.model", "llama2") - viper.SetDefault("ollama.timeout", "30s") - viper.SetDefault("logging.level", "info") - viper.SetDefault("logging.format", "json") + //viper.SetDefault("server.port", "8080") + //viper.SetDefault("server.host", "0.0.0.0") + //viper.SetDefault("ollama.base_url", "http://localhost:11434") + //viper.SetDefault("ollama.model", "llama2") + //viper.SetDefault("ollama.timeout", "30s") + //viper.SetDefault("logging.level", "info") + //viper.SetDefault("logging.format", "json") // 读取配置文件 if err := viper.ReadInConfig(); err != nil { diff --git a/internal/data/constant/const.go b/internal/data/constant/const.go new file mode 100644 index 0000000..19f4670 --- /dev/null +++ b/internal/data/constant/const.go @@ -0,0 +1,3 @@ +package constant + +const () diff --git a/internal/data/constant/supplier.go b/internal/data/constant/supplier.go new file mode 100644 index 0000000..19f4670 --- /dev/null +++ b/internal/data/constant/supplier.go @@ -0,0 +1,3 @@ +package constant + +const () diff --git a/internal/data/error/error_code.go b/internal/data/error/error_code.go new file mode 100644 index 0000000..ef8c36a --- /dev/null +++ b/internal/data/error/error_code.go @@ -0,0 +1,47 @@ +package errorcode + +var ( + Success = &BusinessErr{code: "0000", message: "成功"} + ParamError = &BusinessErr{code: "0001", message: "参数错误"} + NotFoundError = &BusinessErr{code: "0004", message: "请求地址未找到"} + SystemError = &BusinessErr{code: "0005", message: "系统错误"} + + SupplierNotFound = &BusinessErr{code: "0006", message: "供应商不存在"} + SupplierApiError = &BusinessErr{code: "0007", message: "第三方供应商接口报错"} + + InvalidParam = &BusinessErr{code: InvalidParamCode, message: "无效参数"} +) + +const ( + InvalidParamCode = "0008" +) + +type BusinessErr struct { + code string + message string +} + +func (e *BusinessErr) Error() string { + return e.message +} +func (e *BusinessErr) Code() string { + return e.code +} + +func (e *BusinessErr) Is(target error) bool { + _, ok := target.(*BusinessErr) + return ok +} + +// CustomErr 自定义错误 +func NewBusinessErr(code string, message string) *BusinessErr { + return &BusinessErr{code: code, message: message} +} + +func (e *BusinessErr) Wrap(err error) *BusinessErr { + return NewBusinessErr(e.code, err.Error()) +} + +func SupplierApiErrorDiy(message string) *BusinessErr { + return &BusinessErr{code: SupplierApiError.code, message: message} +} diff --git a/internal/data/impl/notify_data_impl.go b/internal/data/impl/notify_data_impl.go new file mode 100644 index 0000000..ce5d4b3 --- /dev/null +++ b/internal/data/impl/notify_data_impl.go @@ -0,0 +1,15 @@ +package impl + +import ( + "trans_hub/app/physical_goods/supplier/goods/service/internal/data/model" + "trans_hub/tmpl/dataTemp" + "trans_hub/utils" +) + +type NotifyDataImpl struct { + dataTemp.DataTemp +} + +func NewOrderImpl(db *utils.Db) *NotifyDataImpl { + return &NotifyDataImpl{*dataTemp.NewDataTemp(db, new(model.SupplierNotifyDatum))} +} diff --git a/internal/data/impl/provider_set.go b/internal/data/impl/provider_set.go new file mode 100644 index 0000000..92fbe40 --- /dev/null +++ b/internal/data/impl/provider_set.go @@ -0,0 +1,7 @@ +package impl + +import ( + "github.com/google/wire" +) + +var ProviderImpl = wire.NewSet(NewOrderImpl) diff --git a/pkg/types/types.go b/internal/entitys/types.go similarity index 99% rename from pkg/types/types.go rename to internal/entitys/types.go index fe5321c..5109189 100644 --- a/pkg/types/types.go +++ b/internal/entitys/types.go @@ -1,4 +1,4 @@ -package types +package entitys import ( "context" diff --git a/internal/handlers/router.go b/internal/handlers/router.go deleted file mode 100644 index 6b55820..0000000 --- a/internal/handlers/router.go +++ /dev/null @@ -1,40 +0,0 @@ -package handlers - -import ( - "ai_scheduler/pkg/types" - - "github.com/gin-gonic/gin" - swaggerFiles "github.com/swaggo/files" - ginSwagger "github.com/swaggo/gin-swagger" -) - -// SetupRoutes 设置路由 -func SetupRoutes(routerService types.RouterService) *gin.Engine { - r := gin.Default() - - // 添加CORS中间件 - r.Use(func(c *gin.Context) { - c.Header("Access-Control-Allow-Origin", "*") - c.Header("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS") - c.Header("Access-Control-Allow-Headers", "Content-Type, Authorization") - if c.Request.Method == "OPTIONS" { - c.AbortWithStatus(204) - return - } - c.Next() - }) - - // 创建处理器 - chatHandler := NewChatHandler(routerService) - - // API路由组 - v1 := r.Group("/api/v1") - { - v1.POST("/chat", chatHandler.Chat) - } - - // Swagger文档 - r.GET("/swagger/*any", ginSwagger.WrapHandler(swaggerFiles.Handler)) - - return r -} diff --git a/internal/middleware/auth.go b/internal/middleware/auth.go new file mode 100644 index 0000000..c7cc178 --- /dev/null +++ b/internal/middleware/auth.go @@ -0,0 +1,112 @@ +package middleware + +import ( + "context" + "errors" + "log" + "net/http" + "slices" + "strings" + + "knowlege-lsxd/internal/config" + "knowlege-lsxd/internal/types" + "knowlege-lsxd/internal/types/interfaces" + + "github.com/gin-gonic/gin" +) + +// 无需认证的API列表 +var noAuthAPI = map[string][]string{ + "/api/v1/test-data": {"GET"}, + "/api/v1/tenants": {"POST"}, + "/api/v1/initialization/*": {"GET", "POST"}, +} + +// 检查请求是否在无需认证的API列表中 +func isNoAuthAPI(path string, method string) bool { + for api, methods := range noAuthAPI { + // 如果以*结尾,按照前缀匹配,否则按照全路径匹配 + if strings.HasSuffix(api, "*") { + if strings.HasPrefix(path, strings.TrimSuffix(api, "*")) && slices.Contains(methods, method) { + return true + } + } else if path == api && slices.Contains(methods, method) { + return true + } + } + return false +} + +// Auth 认证中间件 +func Auth(tenantService interfaces.TenantService, cfg *config.Config) gin.HandlerFunc { + return func(c *gin.Context) { + // ignore OPTIONS request + if c.Request.Method == "OPTIONS" { + c.Next() + return + } + + // 检查请求是否在无需认证的API列表中 + if isNoAuthAPI(c.Request.URL.Path, c.Request.Method) { + c.Next() + return + } + + // Get API Key from request header + apiKey := c.GetHeader("X-API-Key") + if apiKey == "" { + c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"}) + c.Abort() + return + } + + // Get tenant information + //tenantID, err := tenantService.ExtractTenantIDFromAPIKey(apiKey) + //if err != nil { + // c.JSON(http.StatusUnauthorized, gin.H{ + // "error": "Unauthorized: invalid API key format", + // }) + // c.Abort() + // return + //} + + // Verify API key validity (matches the one in database) + t, err := tenantService.GetTenantByApiKey(c.Request.Context(), apiKey) + if err != nil { + log.Printf("Error getting tenant by ID: %v, tenantID: %d, apiKey: %s", err, t.ID, apiKey) + c.JSON(http.StatusUnauthorized, gin.H{ + "error": "Unauthorized: invalid API key", + }) + c.Abort() + return + } + + if t == nil || t.APIKey != apiKey { + c.JSON(http.StatusUnauthorized, gin.H{ + "error": "Unauthorized: invalid API key", + }) + c.Abort() + return + } + + // Store tenant ID in context + c.Set(types.TenantIDContextKey.String(), t.ID) + c.Set(types.TenantInfoContextKey.String(), t) + c.Request = c.Request.WithContext( + context.WithValue( + context.WithValue(c.Request.Context(), types.TenantIDContextKey, t.ID), + types.TenantInfoContextKey, t, + ), + ) + c.Next() + } +} + +// GetTenantIDFromContext helper function to get tenant ID from context +func GetTenantIDFromContext(ctx context.Context) (uint, error) { + tenantID, ok := ctx.Value("tenantID").(uint) + if !ok { + return 0, errors.New("tenant ID not found in context") + } + return tenantID, nil +} diff --git a/internal/middleware/error_handler.go b/internal/middleware/error_handler.go new file mode 100644 index 0000000..12648ef --- /dev/null +++ b/internal/middleware/error_handler.go @@ -0,0 +1,45 @@ +package middleware + +import ( + "net/http" + + "github.com/gin-gonic/gin" + + "ai_scheduler/internal/data/error" +) + +// ErrorHandler 是一个处理应用错误的中间件 +func ErrorHandler() gin.HandlerFunc { + return func(c *gin.Context) { + // 处理请求 + c.Next() + + // 检查是否有错误 + if len(c.Errors) > 0 { + // 获取最后一个错误 + err := c.Errors.Last().Err + + // 检查是否为应用错误 + if appErr, ok := errors.IsAppError(err); ok { + // 返回应用错误 + c.JSON(appErr.Code(), gin.H{ + "status": "error", + "error": gin.H{ + "code": appErr.Code(), + "message": appErr.Error(), + }, + }) + return + } + + // 处理其他类型的错误 + c.JSON(http.StatusInternalServerError, gin.H{ + "status": "error", + "error": gin.H{ + "code": http.StatusInternalServerError, + "message": "Internal server error", + }, + }) + } + } +} diff --git a/internal/middleware/logger.go b/internal/middleware/logger.go new file mode 100644 index 0000000..54370b7 --- /dev/null +++ b/internal/middleware/logger.go @@ -0,0 +1,83 @@ +package middleware + +import ( + "context" + "time" + + "github.com/gin-gonic/gin" + "github.com/google/uuid" + "knowlege-lsxd/internal/logger" + "knowlege-lsxd/internal/types" +) + +// RequestID middleware adds a unique request ID to the context +func RequestID() gin.HandlerFunc { + return func(c *gin.Context) { + // Get request ID from header or generate a new one + requestID := c.GetHeader("X-Request-ID") + if requestID == "" { + requestID = uuid.New().String() + } + // Set request ID in header + c.Header("X-Request-ID", requestID) + + // Set request ID in context + c.Set(types.RequestIDContextKey.String(), requestID) + + // Set logger in context + requestLogger := logger.GetLogger(c) + requestLogger = requestLogger.WithField("request_id", requestID) + c.Set(types.LoggerContextKey.String(), requestLogger) + + // Set request ID in the global context for logging + c.Request = c.Request.WithContext( + context.WithValue( + context.WithValue(c.Request.Context(), types.RequestIDContextKey, requestID), + types.LoggerContextKey, requestLogger, + ), + ) + + c.Next() + } +} + +// Logger middleware logs request details with request ID +func Logger() gin.HandlerFunc { + return func(c *gin.Context) { + start := time.Now() + path := c.Request.URL.Path + raw := c.Request.URL.RawQuery + + // Process request + c.Next() + + // Get request ID from context + requestID, exists := c.Get(types.RequestIDContextKey.String()) + if !exists { + requestID = "unknown" + } + + // Calculate latency + latency := time.Since(start) + + // Get client IP and status code + clientIP := c.ClientIP() + statusCode := c.Writer.Status() + method := c.Request.Method + + if raw != "" { + path = path + "?" + raw + } + + // Log with request ID + logger.GetLogger(c).Infof("[%s] %d | %3d | %13v | %15s | %s %s", + requestID, + statusCode, + c.Writer.Size(), + latency, + clientIP, + method, + path, + ) + } +} diff --git a/internal/middleware/recovery.go b/internal/middleware/recovery.go new file mode 100644 index 0000000..9b60f13 --- /dev/null +++ b/internal/middleware/recovery.go @@ -0,0 +1,34 @@ +package middleware + +import ( + "fmt" + "log" + "runtime/debug" + + "github.com/gin-gonic/gin" +) + +// Recovery is a middleware that recovers from panics +func Recovery() gin.HandlerFunc { + return func(c *gin.Context) { + defer func() { + if err := recover(); err != nil { + // Get request ID + requestID, _ := c.Get("RequestID") + + // Print stacktrace + stacktrace := debug.Stack() + // Log error + log.Printf("[PANIC] %s | %v | %s", requestID, err, stacktrace) + + // 返回500错误 + c.AbortWithStatusJSON(500, gin.H{ + "error": "Internal Server Error", + "message": fmt.Sprintf("%v", err), + }) + } + }() + + c.Next() + } +} diff --git a/internal/middleware/trace.go b/internal/middleware/trace.go new file mode 100644 index 0000000..d74513f --- /dev/null +++ b/internal/middleware/trace.go @@ -0,0 +1,124 @@ +package middleware + +import ( + "bytes" + "fmt" + "io" + "strings" + + "github.com/gin-gonic/gin" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/codes" + + "knowlege-lsxd/internal/tracing" + "knowlege-lsxd/internal/types" +) + +// Custom ResponseWriter to capture response content +type responseBodyWriter struct { + gin.ResponseWriter + body *bytes.Buffer +} + +// Override Write method to write response content to buffer and original writer +func (r responseBodyWriter) Write(b []byte) (int, error) { + r.body.Write(b) + return r.ResponseWriter.Write(b) +} + +// TracingMiddleware provides a Gin middleware that creates a trace span for each request +func TracingMiddleware() gin.HandlerFunc { + return func(c *gin.Context) { + // Extract trace context from request headers + propagator := tracing.GetTracer() + if propagator == nil { + c.Next() + return + } + + // Get request ID as Span ID + requestID := c.GetString(string(types.RequestIDContextKey)) + if requestID == "" { + requestID = c.GetHeader("X-Request-ID") + } + + // Create new span + spanName := fmt.Sprintf("%s %s", c.Request.Method, c.FullPath()) + ctx, span := tracing.ContextWithSpan(c.Request.Context(), spanName) + defer span.End() + + // Set basic span attributes + span.SetAttributes( + attribute.String("http.method", c.Request.Method), + attribute.String("http.url", c.Request.URL.String()), + attribute.String("http.path", c.FullPath()), + ) + + // Record request headers (optional, or selectively record important headers) + for key, values := range c.Request.Header { + // Skip sensitive or unnecessary headers + if strings.ToLower(key) == "authorization" || strings.ToLower(key) == "cookie" { + continue + } + span.SetAttributes(attribute.String("http.request.header."+key, strings.Join(values, ";"))) + } + + // Record request body (for POST/PUT/PATCH requests) + if c.Request.Method == "POST" || c.Request.Method == "PUT" || c.Request.Method == "PATCH" { + if c.Request.Body != nil { + bodyBytes, _ := io.ReadAll(c.Request.Body) + span.SetAttributes(attribute.String("http.request.body", string(bodyBytes))) + // Reset request body because ReadAll consumes the Reader content + c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) + } + } + + // Record query parameters + if len(c.Request.URL.RawQuery) > 0 { + span.SetAttributes(attribute.String("http.request.query", c.Request.URL.RawQuery)) + } + + // Set request context with span context + c.Request = c.Request.WithContext(ctx) + + // Store tracing context in Gin context + c.Set("trace.span", span) + c.Set("trace.ctx", ctx) + + // Create response body capturer + responseBody := &bytes.Buffer{} + responseWriter := &responseBodyWriter{ + ResponseWriter: c.Writer, + body: responseBody, + } + c.Writer = responseWriter + + // Process request + c.Next() + + // Set response status code + statusCode := c.Writer.Status() + span.SetAttributes(attribute.Int("http.status_code", statusCode)) + + // Record response body + responseContent := responseBody.String() + if len(responseContent) > 0 { + span.SetAttributes(attribute.String("http.response.body", responseContent)) + } + + // Record response headers (optional, or selectively record important headers) + for key, values := range c.Writer.Header() { + span.SetAttributes(attribute.String("http.response.header."+key, strings.Join(values, ";"))) + } + + // Mark as error if status code >= 400 + if statusCode >= 400 { + span.SetStatus(codes.Error, fmt.Sprintf("HTTP %d", statusCode)) + if err := c.Errors.Last(); err != nil { + span.RecordError(err.Err) + } + } else { + span.SetStatus(codes.Ok, "") + } + } +} diff --git a/pkg/ollama/client.go b/internal/pkg/ollama/client.go similarity index 75% rename from pkg/ollama/client.go rename to internal/pkg/ollama/client.go index c7a9a87..9c6335d 100644 --- a/pkg/ollama/client.go +++ b/internal/pkg/ollama/client.go @@ -2,7 +2,7 @@ package ollama import ( "ai_scheduler/internal/config" - "ai_scheduler/pkg/types" + "ai_scheduler/internal/entitys" "context" "encoding/json" "fmt" @@ -18,20 +18,25 @@ type Client struct { } // NewClient 创建新的Ollama客户端 -func NewClient(config *config.OllamaConfig) (*Client, error) { +func NewClient(config *config.Config) (entitys.AIClient, func(), error) { client, err := api.ClientFromEnvironment() + cleanup := func() { + if client != nil { + client = nil + } + } if err != nil { - return nil, fmt.Errorf("failed to create ollama client: %w", err) + return nil, cleanup, fmt.Errorf("failed to create ollama client: %w", err) } return &Client{ client: client, - config: config, - }, nil + config: &config.Ollama, + }, cleanup, nil } // Chat 实现聊天功能 -func (c *Client) Chat(ctx context.Context, messages []types.Message, tools []types.ToolDefinition) (*types.ChatResponse, error) { +func (c *Client) Chat(ctx context.Context, messages []entitys.Message, tools []entitys.ToolDefinition) (*entitys.ChatResponse, error) { // 构建聊天请求 req := &api.ChatRequest{ Model: c.config.Model, @@ -89,23 +94,23 @@ func (c *Client) Chat(ctx context.Context, messages []types.Message, tools []typ } // convertResponse 转换响应格式 -func (c *Client) convertResponse(resp *api.ChatResponse) *types.ChatResponse { - result := &types.ChatResponse{ +func (c *Client) convertResponse(resp *api.ChatResponse) *entitys.ChatResponse { + result := &entitys.ChatResponse{ Message: resp.Message.Content, Finished: resp.Done, } // 转换工具调用 if len(resp.Message.ToolCalls) > 0 { - result.ToolCalls = make([]types.ToolCall, len(resp.Message.ToolCalls)) + result.ToolCalls = make([]entitys.ToolCall, len(resp.Message.ToolCalls)) for i, toolCall := range resp.Message.ToolCalls { // 转换函数参数 argBytes, _ := json.Marshal(toolCall.Function.Arguments) - result.ToolCalls[i] = types.ToolCall{ + result.ToolCalls[i] = entitys.ToolCall{ ID: fmt.Sprintf("call_%d", i), Type: "function", - Function: types.FunctionCall{ + Function: entitys.FunctionCall{ Name: toolCall.Function.Name, Arguments: json.RawMessage(argBytes), }, diff --git a/internal/pkg/provider_set.go b/internal/pkg/provider_set.go new file mode 100644 index 0000000..60b276d --- /dev/null +++ b/internal/pkg/provider_set.go @@ -0,0 +1,9 @@ +package pkg + +import ( + "ai_scheduler/internal/pkg/ollama" + + "github.com/google/wire" +) + +var ProviderSetClient = wire.NewSet(ollama.NewClient) diff --git a/internal/server/http.go b/internal/server/http.go new file mode 100644 index 0000000..65e73a1 --- /dev/null +++ b/internal/server/http.go @@ -0,0 +1,28 @@ +package server + +import ( + "ai_scheduler/internal/server/router" + "ai_scheduler/internal/services" + + "github.com/gofiber/fiber/v2" + "github.com/gofiber/fiber/v2/middleware/logger" + "github.com/gofiber/fiber/v2/middleware/recover" +) + +func NewHTTPServer( + service *services.ChatService, +) *fiber.App { + //构建 server + app := initRoute() + router.SetupRoutes(app, service) + return app +} + +func initRoute() *fiber.App { + app := fiber.New() + app.Use( + recover.New(), + logger.New(), + ) + return app +} diff --git a/internal/server/provider_set.go b/internal/server/provider_set.go new file mode 100644 index 0000000..dd6b4b4 --- /dev/null +++ b/internal/server/provider_set.go @@ -0,0 +1,5 @@ +package server + +import "github.com/google/wire" + +var ProviderSetServer = wire.NewSet(NewServers, NewHTTPServer) diff --git a/internal/server/router/router.go b/internal/server/router/router.go new file mode 100644 index 0000000..c909b31 --- /dev/null +++ b/internal/server/router/router.go @@ -0,0 +1,73 @@ +package router + +import ( + errors "ai_scheduler/internal/data/error" + "ai_scheduler/internal/services" + "encoding/json" + "strings" + + "github.com/gofiber/fiber/v2" +) + +// SetupRoutes 设置路由 +func SetupRoutes(app *fiber.App, ChatService *services.ChatService) { + app.Use(func(c *fiber.Ctx) error { + // 设置 CORS 头 + c.Set("Access-Control-Allow-Origin", "*") + c.Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS") + c.Set("Access-Control-Allow-Headers", "Content-Type, Authorization") + + // 如果是预检请求(OPTIONS),直接返回 204 + if c.Method() == "OPTIONS" { + return c.SendStatus(fiber.StatusNoContent) // 204 + } + + // 继续处理后续中间件或路由 + return c.Next() + }) + + r := app.Group("api/v1/") + registerResponse(r) + // 注册 CORS 中间件 + + r.Post("/chat", ChatService.Chat) +} + +func registerResponse(router fiber.Router) { + // 自定义返回 + router.Use(func(c *fiber.Ctx) error { + err := c.Next() + return registerCommon(c, err) + }) + +} + +func registerCommon(c *fiber.Ctx, err error) error { + // 调用下一个中间件或路由处理函数 + + bsErr, ok := err.(*errors.BusinessErr) + if !ok { + bsErr = errors.SystemError + } + // 如果有错误发生 + if err != nil { + // 返回自定义错误响应 + return c.JSON(fiber.Map{ + "message": bsErr.Error(), + "code": bsErr.Code(), + "data": nil, + }) + } + contentType := strings.ToLower(string(c.Response().Header.Peek("Content-Type"))) + if strings.Contains(strings.ToLower(contentType), "text/event-stream") { + // 是 SSE 请求 + return c.SendString("这是 SSE 请求") + } + var data interface{} + json.Unmarshal(c.Response().Body(), &data) + return c.JSON(fiber.Map{ + "data": data, + "message": errors.Success.Error(), + "code": errors.Success.Code(), + }) +} diff --git a/internal/server/server.go b/internal/server/server.go new file mode 100644 index 0000000..488ef38 --- /dev/null +++ b/internal/server/server.go @@ -0,0 +1,13 @@ +package server + +import "github.com/gofiber/fiber/v2" + +type Servers struct { + HttpServer *fiber.App +} + +func NewServers(fiber *fiber.App) *Servers { + return &Servers{ + HttpServer: fiber, + } +} diff --git a/internal/handlers/chat.go b/internal/services/chat.go similarity index 51% rename from internal/handlers/chat.go rename to internal/services/chat.go index 6961c9e..d9ec1d7 100644 --- a/internal/handlers/chat.go +++ b/internal/services/chat.go @@ -1,39 +1,26 @@ -package handlers +package services import ( - "ai_scheduler/pkg/types" + errors "ai_scheduler/internal/data/error" + "ai_scheduler/internal/entitys" "net/http" "github.com/gin-gonic/gin" + "github.com/gofiber/fiber/v2" ) // ChatHandler 聊天处理器 -type ChatHandler struct { - routerService types.RouterService +type ChatService struct { + routerService entitys.RouterService } // NewChatHandler 创建聊天处理器 -func NewChatHandler(routerService types.RouterService) *ChatHandler { - return &ChatHandler{ +func NewChatService(routerService entitys.RouterService) *ChatService { + return &ChatService{ routerService: routerService, } } -// ChatRequest HTTP聊天请求 -type ChatRequest struct { - UserInput string `json:"user_input" binding:"required" example:"考勤规则"` - Caller string `json:"caller" binding:"required" example:"zltx"` - SessionID string `json:"session_id" example:"default"` -} - -// ChatResponse HTTP聊天响应 -type ChatResponse struct { - Status string `json:"status" example:"success"` // 处理状态 - Message string `json:"message" example:""` // 响应消息 - Data any `json:"data,omitempty"` // 响应数据 - TaskCode string `json:"task_code,omitempty"` // 任务代码 -} - // ToolCallResponse 工具调用响应 type ToolCallResponse struct { ID string `json:"id" example:"call_1"` @@ -48,23 +35,19 @@ type FunctionCallResponse struct { Arguments interface{} `json:"arguments"` } -func (h *ChatHandler) Chat(c *gin.Context) { - var req ChatRequest - if err := c.ShouldBindJSON(&req); err != nil { - c.JSON(http.StatusBadRequest, ChatResponse{ - Status: "error", - Message: "请求参数错误", - }) - return +func (h *ChatService) Chat(c *fiber.Ctx) error { + var req entitys.ChatRequest + if err := c.BodyParser(&req); err != nil { + return errors.ParamError } // 转换为服务层请求 - serviceReq := &types.ChatRequest{ + serviceReq := &entitys.ChatRequest{ UserInput: req.UserInput, Caller: req.Caller, SessionID: req.SessionID, - ChatRequestMeta: types.ChatRequestMeta{ - Authorization: c.GetHeader("Authorization"), + ChatRequestMeta: entitys.ChatRequestMeta{ + Authorization: c.Request().Header(), }, } diff --git a/internal/services/provider_set.go b/internal/services/provider_set.go new file mode 100644 index 0000000..1fdac15 --- /dev/null +++ b/internal/services/provider_set.go @@ -0,0 +1,5 @@ +package services + +import "github.com/google/wire" + +var ProviderSetServices = wire.NewSet(NewChatService) diff --git a/internal/tools/calculator.go b/internal/tools/calculator.go index 39beccf..fecde64 100644 --- a/internal/tools/calculator.go +++ b/internal/tools/calculator.go @@ -1,7 +1,8 @@ package tools import ( - "ai_scheduler/pkg/types" + "ai_scheduler/internal/entitys" + "context" "encoding/json" "fmt" @@ -27,10 +28,10 @@ func (c *CalculatorTool) Description() string { } // Definition 返回工具定义 -func (c *CalculatorTool) Definition() types.ToolDefinition { - return types.ToolDefinition{ +func (c *CalculatorTool) Definition() entitys.ToolDefinition { + return entitys.ToolDefinition{ Type: "function", - Function: types.FunctionDef{ + Function: entitys.FunctionDef{ Name: c.Name(), Description: c.Description(), Parameters: map[string]interface{}{ @@ -65,11 +66,11 @@ type CalculateRequest struct { // CalculateResponse 计算响应 type CalculateResponse struct { - Operation string `json:"operation"` - A float64 `json:"a"` - B float64 `json:"b"` - Result float64 `json:"result"` - Expression string `json:"expression"` + Operation string `json:"operation"` + A float64 `json:"a"` + B float64 `json:"b"` + Result float64 `json:"result"` + Expression string `json:"expression"` } // Execute 执行计算 @@ -117,4 +118,4 @@ func (c *CalculatorTool) Execute(ctx context.Context, args json.RawMessage) (int Result: result, Expression: expression, }, nil -} \ No newline at end of file +} diff --git a/internal/tools/manager.go b/internal/tools/manager.go index dbd4f46..6b907e2 100644 --- a/internal/tools/manager.go +++ b/internal/tools/manager.go @@ -3,7 +3,8 @@ package tools import ( "ai_scheduler/internal/config" "ai_scheduler/internal/constants" - "ai_scheduler/pkg/types" + "ai_scheduler/internal/entitys" + "context" "encoding/json" "fmt" @@ -11,23 +12,23 @@ import ( // Manager 工具管理器 type Manager struct { - tools map[string]types.Tool + tools map[string]entitys.Tool } // NewManager 创建工具管理器 -func NewManager(config *config.ToolsConfig) *Manager { +func NewManager(config *config.Config) *Manager { m := &Manager{ - tools: make(map[string]types.Tool), + tools: make(map[string]entitys.Tool), } // 注册天气工具 - if config.Weather.Enabled { + if config.Tools.Weather.Enabled { weatherTool := NewWeatherTool() m.tools[weatherTool.Name()] = weatherTool } // 注册计算器工具 - if config.Calculator.Enabled { + if config.Tools.Calculator.Enabled { calcTool := NewCalculatorTool() m.tools[calcTool.Name()] = calcTool } @@ -39,8 +40,8 @@ func NewManager(config *config.ToolsConfig) *Manager { // } // 注册直连天下订单详情工具 - if config.ZltxOrderDetail.Enabled { - zltxOrderDetailTool := NewZltxOrderDetailTool(config.ZltxOrderDetail) + if config.Tools.ZltxOrderDetail.Enabled { + zltxOrderDetailTool := NewZltxOrderDetailTool(config.Tools.ZltxOrderDetail) m.tools[zltxOrderDetailTool.Name()] = zltxOrderDetailTool } @@ -54,14 +55,14 @@ func NewManager(config *config.ToolsConfig) *Manager { } // GetTool 获取工具 -func (m *Manager) GetTool(name string) (types.Tool, bool) { +func (m *Manager) GetTool(name string) (entitys.Tool, bool) { tool, exists := m.tools[name] return tool, exists } // GetAllTools 获取所有工具 -func (m *Manager) GetAllTools() []types.Tool { - tools := make([]types.Tool, 0, len(m.tools)) +func (m *Manager) GetAllTools() []entitys.Tool { + tools := make([]entitys.Tool, 0, len(m.tools)) for _, tool := range m.tools { tools = append(tools, tool) } @@ -69,8 +70,8 @@ func (m *Manager) GetAllTools() []types.Tool { } // GetToolDefinitions 获取所有工具定义 -func (m *Manager) GetToolDefinitions(caller constants.Caller) []types.ToolDefinition { - definitions := make([]types.ToolDefinition, 0, len(m.tools)) +func (m *Manager) GetToolDefinitions(caller constants.Caller) []entitys.ToolDefinition { + definitions := make([]entitys.ToolDefinition, 0, len(m.tools)) for _, tool := range m.tools { definitions = append(definitions, tool.Definition()) } @@ -89,8 +90,8 @@ func (m *Manager) ExecuteTool(ctx context.Context, name string, args json.RawMes } // ExecuteToolCalls 执行多个工具调用 -func (m *Manager) ExecuteToolCalls(ctx context.Context, toolCalls []types.ToolCall) ([]types.ToolCall, error) { - results := make([]types.ToolCall, len(toolCalls)) +func (m *Manager) ExecuteToolCalls(ctx context.Context, toolCalls []entitys.ToolCall) ([]entitys.ToolCall, error) { + results := make([]entitys.ToolCall, len(toolCalls)) for i, toolCall := range toolCalls { results[i] = toolCall diff --git a/internal/tools/provider_set.go b/internal/tools/provider_set.go new file mode 100644 index 0000000..d038f67 --- /dev/null +++ b/internal/tools/provider_set.go @@ -0,0 +1,7 @@ +package tools + +import ( + "github.com/google/wire" +) + +var ProviderSetTools = wire.NewSet(NewManager) diff --git a/internal/tools/weather.go b/internal/tools/weather.go index 9d662b7..744216f 100644 --- a/internal/tools/weather.go +++ b/internal/tools/weather.go @@ -1,7 +1,8 @@ package tools import ( - "ai_scheduler/pkg/types" + "ai_scheduler/internal/entitys" + "context" "encoding/json" "fmt" @@ -30,10 +31,10 @@ func (w *WeatherTool) Description() string { } // Definition 返回工具定义 -func (w *WeatherTool) Definition() types.ToolDefinition { - return types.ToolDefinition{ +func (w *WeatherTool) Definition() entitys.ToolDefinition { + return entitys.ToolDefinition{ Type: "function", - Function: types.FunctionDef{ + Function: entitys.FunctionDef{ Name: w.Name(), Description: w.Description(), Parameters: map[string]interface{}{ diff --git a/internal/tools/zltx_order_detail.go b/internal/tools/zltx_order_detail.go index 934ae38..4d6ddd8 100644 --- a/internal/tools/zltx_order_detail.go +++ b/internal/tools/zltx_order_detail.go @@ -2,7 +2,8 @@ package tools import ( "ai_scheduler/internal/config" - "ai_scheduler/pkg/types" + "ai_scheduler/internal/entitys" + "context" "encoding/json" "fmt" @@ -30,10 +31,10 @@ func (w *ZltxOrderDetailTool) Description() string { } // Definition 返回工具定义 -func (w *ZltxOrderDetailTool) Definition() types.ToolDefinition { - return types.ToolDefinition{ +func (w *ZltxOrderDetailTool) Definition() entitys.ToolDefinition { + return entitys.ToolDefinition{ Type: "function", - Function: types.FunctionDef{ + Function: entitys.FunctionDef{ Name: w.Name(), Description: w.Description(), Parameters: map[string]interface{}{