107 lines
2.8 KiB
Go
107 lines
2.8 KiB
Go
package router
|
||
|
||
import (
	"encoding/json"
	stderrors "errors"
	"strings"
	"sync"
	"time"

	"github.com/gofiber/fiber/v2"
	"github.com/gofiber/websocket/v2"

	errors "ai_scheduler/internal/data/error"
	"ai_scheduler/internal/services"
)
|
||
|
||
// SetupRoutes 设置路由
|
||
func SetupRoutes(app *fiber.App, ChatService *services.ChatService) {
|
||
app.Use(func(c *fiber.Ctx) error {
|
||
// 设置 CORS 头
|
||
c.Set("Access-Control-Allow-Origin", "*")
|
||
c.Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
|
||
c.Set("Access-Control-Allow-Headers", "Content-Type, Authorization")
|
||
|
||
// 如果是预检请求(OPTIONS),直接返回 204
|
||
if c.Method() == "OPTIONS" {
|
||
return c.SendStatus(fiber.StatusNoContent) // 204
|
||
}
|
||
|
||
// 继续处理后续中间件或路由
|
||
return c.Next()
|
||
})
|
||
routerHttp(app)
|
||
routerSocket(app, ChatService)
|
||
|
||
}
|
||
|
||
// routerSocket registers the WebSocket endpoints under the ws/v1 group.
//
// GET ws/v1/chat is upgraded to a WebSocket connection. The established
// *websocket.Conn is then passed to chatService.Chat. The upgrade uses a
// 10-second handshake timeout, 8192-byte read/write buffers and the shared
// bufferPool for write buffers.
func routerSocket(app *fiber.App, chatService *services.ChatService) {
	upgradeCfg := websocket.Config{
		HandshakeTimeout: 10 * time.Second,
		ReadBufferSize:   8192,
		WriteBufferSize:  8192,
		WriteBufferPool:  bufferPool,

		// Optional settings, currently disabled:
		//   Subprotocols: []string{"json", "msgpack"},
		//   Origins:      []string{"http://localhost:3000", "https://yourdomain.com"},
		// Origins validates the Origin header to prevent cross-site request
		// forgery; when it is empty, every origin is accepted.
	}

	// The handler receives an already-upgraded connection, so it runs after
	// the handshake. Header checks that must happen before the upgrade belong
	// in a fiber middleware registered ahead of this route.
	chatHandler := func(conn *websocket.Conn) {
		chatService.Chat(conn)
	}

	ws := app.Group("ws/v1/")
	ws.Get("/chat", websocket.New(chatHandler, upgradeCfg))
}
|
||
|
||
var bufferPool = &sync.Pool{
|
||
New: func() interface{} {
|
||
return make([]byte, 4096) // 4KB 缓冲区
|
||
},
|
||
}
|
||
|
||
func routerHttp(app *fiber.App) {
|
||
r := app.Group("api/v1/")
|
||
registerResponse(r)
|
||
// 注册 CORS 中间件
|
||
r.Get("/health", func(c *fiber.Ctx) error {
|
||
c.Response().SetBody([]byte("1"))
|
||
return nil
|
||
})
|
||
}
|
||
|
||
func registerResponse(router fiber.Router) {
|
||
// 自定义返回
|
||
router.Use(func(c *fiber.Ctx) error {
|
||
err := c.Next()
|
||
return registerCommon(c, err)
|
||
})
|
||
|
||
}
|
||
|
||
// registerCommon rewrites the response produced by a route handler into the
// API's standard JSON envelope with "code", "message" and "data" fields.
//
// err is the error returned by the downstream handler chain (c.Next()):
//   - err != nil: the body is replaced with an error envelope ("data": null).
//     The code and message come from the first *errors.BusinessErr found in
//     err's wrap chain, so handlers may wrap business errors with %w. Any
//     other error is reported as errors.SystemError.
//   - the response Content-Type contains text/event-stream: the response is
//     left untouched so the handler's SSE output reaches the client.
//   - otherwise: the handler's body becomes "data", wrapped in a
//     success envelope:
//       - valid JSON is embedded verbatim, so large integers keep their
//         precision;
//       - an empty body becomes null;
//       - any other body is embedded as a JSON string rather than dropped.
func registerCommon(c *fiber.Ctx, err error) error {
	if err != nil {
		var bsErr *errors.BusinessErr
		if !stderrors.As(err, &bsErr) {
			bsErr = errors.SystemError
		}
		return c.JSON(fiber.Map{
			"message": bsErr.Error(),
			"code":    bsErr.Code(),
			"data":    nil,
		})
	}

	contentType := strings.ToLower(string(c.Response().Header.Peek(fiber.HeaderContentType)))
	if strings.Contains(contentType, "text/event-stream") {
		// Writing a new body here would replace whatever the handler
		// streamed, so SSE responses pass through unchanged.
		return nil
	}

	body := c.Response().Body()
	var data interface{}
	switch {
	case len(body) == 0:
		// No body: report "data": null.
	case json.Valid(body):
		// Embed the JSON as-is instead of decoding it into interface{}. That
		// decode would turn large integers into float64 and lose precision.
		// Copy the bytes so the value does not alias the response buffer
		// that c.JSON rewrites.
		data = json.RawMessage(append([]byte(nil), body...))
	default:
		// Not JSON (e.g. plain text): keep it as a string.
		data = string(body)
	}

	return c.JSON(fiber.Map{
		"data":    data,
		"message": errors.Success.Error(),
		"code":    errors.Success.Code(),
	})
}
|