ai_scheduler/internal/server/router/router.go

132 lines
3.7 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package router
import (
errors "ai_scheduler/internal/data/error"
"ai_scheduler/internal/gateway"
"ai_scheduler/internal/services"
"encoding/json"
"fmt"
"strings"
"sync"
"time"
"github.com/gofiber/fiber/v2"
"github.com/gofiber/websocket/v2"
)
// SetupRoutes registers every HTTP and WebSocket route on app.
//
// It first installs an application-wide middleware that adds permissive CORS
// headers to every response:
//   - any origin;
//   - the GET, POST, PUT, DELETE and OPTIONS methods;
//   - the Content-Type and Authorization request headers.
//
// The same middleware answers OPTIONS preflight requests with 204 No Content
// without running any later handler. SetupRoutes then mounts the REST API
// (routerHttp) and the chat WebSocket endpoint (routerSocket).
func SetupRoutes(app *fiber.App, ChatService *services.ChatService, sessionService *services.SessionService, gateway *gateway.Gateway) {
	withCORS := func(c *fiber.Ctx) error {
		c.Set("Access-Control-Allow-Origin", "*")
		c.Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
		c.Set("Access-Control-Allow-Headers", "Content-Type, Authorization")

		if c.Method() != fiber.MethodOptions {
			// Regular request: continue down the middleware/route chain.
			return c.Next()
		}
		// Preflight request: the headers above are the whole answer.
		return c.SendStatus(fiber.StatusNoContent)
	}
	app.Use(withCORS)

	routerHttp(app, sessionService, gateway)
	routerSocket(app, ChatService)
}
// routerSocket mounts the chat WebSocket endpoint as the "chat" route of the
// "ws/v1/" group and hands every upgraded connection to chatService.Chat.
//
// The callback passed to websocket.New receives an already-upgraded
// *websocket.Conn. Pre-handshake checks (headers, auth) therefore cannot live
// inside it; they belong in a fiber middleware registered ahead of this route.
//
// Connection settings:
//   - 10 second handshake timeout;
//   - 8 KiB read and write buffers;
//   - write buffers drawn from the shared bufferPool.
//
// Subprotocols and Origins are left unset.
// NOTE(review): with Origins empty every origin is accepted. Set it to guard
// against cross-site WebSocket hijacking if browsers connect to this endpoint.
func routerSocket(app *fiber.App, chatService *services.ChatService) {
	cfg := websocket.Config{
		HandshakeTimeout: 10 * time.Second,
		ReadBufferSize:   8192,
		WriteBufferSize:  8192,
		WriteBufferPool:  bufferPool,
		// Optional settings, currently disabled:
		// Subprotocols: []string{"json", "msgpack"},
		// Origins:      []string{"http://localhost:3000", "https://yourdomain.com"},
	}

	chat := func(conn *websocket.Conn) {
		chatService.Chat(conn)
	}

	app.Group("ws/v1/").Get("/chat", websocket.New(chat, cfg))
}
// bufferPool is the shared WriteBufferPool for the chat WebSocket endpoint.
//
// The websocket library owns the type of the values it keeps in this pool:
// it Puts its own buffer wrapper and, on Get, only reuses values of that
// wrapper type. Anything else is thrown away and a fresh WriteBufferSize
// buffer is allocated instead. The pool must therefore have no New function.
// A New that returned a plain []byte (the previous 4 KiB version) was never
// usable by the library and only produced garbage whenever the pool was empty.
// Use one pool per distinct WriteBufferSize.
var bufferPool = &sync.Pool{}
// routerHttp mounts the REST API in the "api/v1/" group.
//
// Every route in the group is wrapped by registerResponse, which re-encodes
// the handler's outcome into the common {data, message, code} envelope.
//
// Routes:
//   - GET  health: sets the response body to "1".
//   - POST session/init: sessionService.SessionInit. Initializes a session,
//     creating it if it does not exist; otherwise returns the session ID and a
//     default number of history entries.
//   - POST session/list: sessionService.SessionList.
//   - GET  broadcast: pushes the "msg" query value through gateway.
//     action=sendToAll sends to everyone; action=sendToUid sends to the "uid"
//     query value (400 if uid is empty). Any other action answers 400.
//
// NOTE(review): broadcast is a state-changing GET and this router applies no
// authentication to it. Confirm it is protected elsewhere, or restricted to
// internal callers, before exposing it.
func routerHttp(app *fiber.App, sessionService *services.SessionService, gateway *gateway.Gateway) {
	api := app.Group("api/v1/")
	// Must be registered before the routes so it wraps all of them.
	registerResponse(api)

	api.Get("/health", func(c *fiber.Ctx) error {
		return c.SendString("1")
	})
	api.Post("/session/init", sessionService.SessionInit)
	api.Post("/session/list", sessionService.SessionList)

	broadcast := func(ctx *fiber.Ctx) error {
		payload := []byte(ctx.Query("msg"))
		switch ctx.Query("action") {
		case "sendToAll":
			gateway.SendToAll(payload)
			return ctx.SendString("sent to all")
		case "sendToUid":
			target := ctx.Query("uid")
			if target == "" {
				return ctx.Status(fiber.StatusBadRequest).SendString("missing uid")
			}
			gateway.SendToUid(target, payload)
			return ctx.SendString(fmt.Sprintf("sent to uid %s", target))
		}
		return ctx.Status(fiber.StatusBadRequest).SendString("unknown action")
	}
	api.Get("/broadcast", broadcast)
}
// registerResponse installs a middleware on router. The middleware runs the
// rest of the handler chain, then passes the context and the chain's error to
// registerCommon, which writes the unified response envelope.
func registerResponse(router fiber.Router) {
	router.Use(func(c *fiber.Ctx) error {
		return registerCommon(c, c.Next())
	})
}
// registerCommon wraps the outcome of a handler chain in the API's standard
// {data, message, code} JSON envelope.
//
// c is the request context after the downstream handlers have run. err is
// the error they returned (nil on success).
//
//   - If err is a *errors.BusinessErr, its message and code are reported. Any
//     other non-nil error is reported as errors.SystemError, so internal
//     details are not exposed to the client.
//   - A Server-Sent Events response (Content-Type containing
//     "text/event-stream") is passed through untouched. Rewriting the body
//     would discard the stream.
//   - Otherwise the handler's body is decoded as JSON and placed under "data".
//     A non-empty body that is not valid JSON (e.g. written with SendString)
//     is kept as a plain string instead of being silently dropped. An empty
//     body yields null.
//
// The HTTP status code set by the handler is left unchanged.
func registerCommon(c *fiber.Ctx, err error) error {
	if err != nil {
		bsErr, ok := err.(*errors.BusinessErr)
		if !ok {
			bsErr = errors.SystemError
		}
		return c.JSON(fiber.Map{
			"message": bsErr.Error(),
			"code":    bsErr.Code(),
			"data":    nil,
		})
	}

	contentType := strings.ToLower(string(c.Response().Header.Peek(fiber.HeaderContentType)))
	if strings.Contains(contentType, "text/event-stream") {
		// SSE: leave the streamed body alone. Calling SendString here would
		// reset the body and close the stream.
		return nil
	}

	var data interface{}
	if body := c.Response().Body(); len(body) > 0 {
		if jsonErr := json.Unmarshal(body, &data); jsonErr != nil {
			// Not JSON (e.g. a plain-text reply): keep it rather than lose it.
			data = string(body)
		}
	}
	return c.JSON(fiber.Map{
		"data":    data,
		"message": errors.Success.Error(),
		"code":    errors.Success.Code(),
	})
}