74 lines
1.7 KiB
Go
74 lines
1.7 KiB
Go
package router
|
||
|
||
import (
|
||
errors "ai_scheduler/internal/data/error"
|
||
"ai_scheduler/internal/services"
|
||
"encoding/json"
|
||
"strings"
|
||
|
||
"github.com/gofiber/fiber/v2"
|
||
)
|
||
|
||
// SetupRoutes 设置路由
|
||
func SetupRoutes(app *fiber.App, ChatService *services.ChatService) {
|
||
app.Use(func(c *fiber.Ctx) error {
|
||
// 设置 CORS 头
|
||
c.Set("Access-Control-Allow-Origin", "*")
|
||
c.Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
|
||
c.Set("Access-Control-Allow-Headers", "Content-Type, Authorization")
|
||
|
||
// 如果是预检请求(OPTIONS),直接返回 204
|
||
if c.Method() == "OPTIONS" {
|
||
return c.SendStatus(fiber.StatusNoContent) // 204
|
||
}
|
||
|
||
// 继续处理后续中间件或路由
|
||
return c.Next()
|
||
})
|
||
|
||
r := app.Group("api/v1/")
|
||
registerResponse(r)
|
||
// 注册 CORS 中间件
|
||
|
||
r.Post("/chat", ChatService.Chat)
|
||
}
|
||
|
||
func registerResponse(router fiber.Router) {
|
||
// 自定义返回
|
||
router.Use(func(c *fiber.Ctx) error {
|
||
err := c.Next()
|
||
return registerCommon(c, err)
|
||
})
|
||
|
||
}
|
||
|
||
func registerCommon(c *fiber.Ctx, err error) error {
|
||
// 调用下一个中间件或路由处理函数
|
||
|
||
bsErr, ok := err.(*errors.BusinessErr)
|
||
if !ok {
|
||
bsErr = errors.SystemError
|
||
}
|
||
// 如果有错误发生
|
||
if err != nil {
|
||
// 返回自定义错误响应
|
||
return c.JSON(fiber.Map{
|
||
"message": bsErr.Error(),
|
||
"code": bsErr.Code(),
|
||
"data": nil,
|
||
})
|
||
}
|
||
contentType := strings.ToLower(string(c.Response().Header.Peek("Content-Type")))
|
||
if strings.Contains(strings.ToLower(contentType), "text/event-stream") {
|
||
// 是 SSE 请求
|
||
return c.SendString("这是 SSE 请求")
|
||
}
|
||
var data interface{}
|
||
json.Unmarshal(c.Response().Body(), &data)
|
||
return c.JSON(fiber.Map{
|
||
"data": data,
|
||
"message": errors.Success.Error(),
|
||
"code": errors.Success.Code(),
|
||
})
|
||
}
|