package router import ( errors "ai_scheduler/internal/data/error" "ai_scheduler/internal/gateway" "ai_scheduler/internal/services" "encoding/json" "fmt" "strings" "sync" "time" "github.com/gofiber/fiber/v2" "github.com/gofiber/websocket/v2" ) // SetupRoutes 设置路由 func SetupRoutes(app *fiber.App, ChatService *services.ChatService, sessionService *services.SessionService, gateway *gateway.Gateway) { app.Use(func(c *fiber.Ctx) error { // 设置 CORS 头 c.Set("Access-Control-Allow-Origin", "*") c.Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS") c.Set("Access-Control-Allow-Headers", "Content-Type, Authorization") // 如果是预检请求(OPTIONS),直接返回 204 if c.Method() == "OPTIONS" { return c.SendStatus(fiber.StatusNoContent) // 204 } // 继续处理后续中间件或路由 return c.Next() }) routerHttp(app, sessionService, gateway) routerSocket(app, ChatService) } func routerSocket(app *fiber.App, chatService *services.ChatService) { ws := app.Group("ws/v1/") // WebSocket 路由配置 ws.Get("/chat", websocket.New(func(c *websocket.Conn) { // 可以在这里添加握手前的中间件逻辑(如头校验) chatService.Chat(c) // 调用实际的 Chat 处理函数 }, websocket.Config{ // 可选配置:跨域检查、最大负载大小等 HandshakeTimeout: 10 * time.Second, //Subprotocols: []string{"json", "msgpack"}, //Origins: []string{"http://localhost:3000", "https://yourdomain.com"} //验证 Origin 头,防止跨站请求伪造(CSRF)。如果为空,则允许所有来源。 ReadBufferSize: 8192, WriteBufferSize: 8192, WriteBufferPool: bufferPool, // 使用缓冲池 })) } var bufferPool = &sync.Pool{ New: func() interface{} { return make([]byte, 4096) // 4KB 缓冲区 }, } func routerHttp(app *fiber.App, sessionService *services.SessionService, gateway *gateway.Gateway) { r := app.Group("api/v1/") registerResponse(r) // 注册 CORS 中间件 r.Get("/health", func(c *fiber.Ctx) error { c.Response().SetBody([]byte("1")) return nil }) r.Post("/session/init", sessionService.SessionInit) // 会话初始化,不存在则创建,存在则返回会话ID和默认条数会话历史 r.Post("/session/list", sessionService.SessionList) //广播 r.Get("/broadcast", func(ctx *fiber.Ctx) error { action := ctx.Query("action") uid := ctx.Query("uid") msg := ctx.Query("msg") switch action { case "sendToAll": gateway.SendToAll([]byte(msg)) return ctx.SendString("sent to all") case "sendToUid": if uid == "" { return ctx.Status(400).SendString("missing uid") } gateway.SendToUid(uid, []byte(msg)) return ctx.SendString(fmt.Sprintf("sent to uid %s", uid)) default: return ctx.Status(400).SendString("unknown action") } }) } func registerResponse(router fiber.Router) { // 自定义返回 router.Use(func(c *fiber.Ctx) error { err := c.Next() return registerCommon(c, err) }) } func registerCommon(c *fiber.Ctx, err error) error { // 调用下一个中间件或路由处理函数 bsErr, ok := err.(*errors.BusinessErr) if !ok { bsErr = errors.SystemError } // 如果有错误发生 if err != nil { // 返回自定义错误响应 return c.JSON(fiber.Map{ "message": bsErr.Error(), "code": bsErr.Code(), "data": nil, }) } contentType := strings.ToLower(string(c.Response().Header.Peek("Content-Type"))) if strings.Contains(strings.ToLower(contentType), "text/event-stream") { // 是 SSE 请求 return c.SendString("这是 SSE 请求") } var data interface{} json.Unmarshal(c.Response().Body(), &data) return c.JSON(fiber.Map{ "data": data, "message": errors.Success.Error(), "code": errors.Success.Code(), }) }