132 lines
		
	
	
		
			3.7 KiB
		
	
	
	
		
			Go
		
	
	
	
			
		
		
	
	
			132 lines
		
	
	
		
			3.7 KiB
		
	
	
	
		
			Go
		
	
	
	
package router
 | 
						||
 | 
						||
import (
 | 
						||
	errors "ai_scheduler/internal/data/error"
 | 
						||
	"ai_scheduler/internal/gateway"
 | 
						||
	"ai_scheduler/internal/services"
 | 
						||
	"encoding/json"
 | 
						||
	"fmt"
 | 
						||
	"strings"
 | 
						||
	"sync"
 | 
						||
	"time"
 | 
						||
 | 
						||
	"github.com/gofiber/fiber/v2"
 | 
						||
	"github.com/gofiber/websocket/v2"
 | 
						||
)
 | 
						||
 | 
						||
// SetupRoutes 设置路由
 | 
						||
func SetupRoutes(app *fiber.App, ChatService *services.ChatService, sessionService *services.SessionService, gateway *gateway.Gateway) {
 | 
						||
	app.Use(func(c *fiber.Ctx) error {
 | 
						||
		// 设置 CORS 头
 | 
						||
		c.Set("Access-Control-Allow-Origin", "*")
 | 
						||
		c.Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
 | 
						||
		c.Set("Access-Control-Allow-Headers", "Content-Type, Authorization")
 | 
						||
 | 
						||
		// 如果是预检请求(OPTIONS),直接返回 204
 | 
						||
		if c.Method() == "OPTIONS" {
 | 
						||
			return c.SendStatus(fiber.StatusNoContent) // 204
 | 
						||
		}
 | 
						||
 | 
						||
		// 继续处理后续中间件或路由
 | 
						||
		return c.Next()
 | 
						||
	})
 | 
						||
	routerHttp(app, sessionService, gateway)
 | 
						||
	routerSocket(app, ChatService)
 | 
						||
 | 
						||
}
 | 
						||
 | 
						||
// routerSocket registers the WebSocket endpoints under ws/v1.
//
// GET ws/v1/chat is upgraded to a WebSocket connection, which is handed to
// chatService.Chat. The handler receives an already-upgraded *websocket.Conn,
// so checks that must happen before the handshake (headers, auth) belong in
// a middleware registered on the group ahead of this route, not in the handler.
//
// Write buffers come from the shared bufferPool; that pool must only be used
// with this single WriteBufferSize.
//
// NOTE(review): Origins is left empty, which allows every origin. Restrict it
// to known front-end origins to prevent cross-site WebSocket hijacking.
func routerSocket(app *fiber.App, chatService *services.ChatService) {
	chat := func(conn *websocket.Conn) {
		chatService.Chat(conn)
	}

	cfg := websocket.Config{
		HandshakeTimeout: 10 * time.Second,
		ReadBufferSize:   8192,
		WriteBufferSize:  8192,
		WriteBufferPool:  bufferPool,
		// Optional settings, currently disabled:
		// Subprotocols: []string{"json", "msgpack"},
		// Origins: []string{"http://localhost:3000", "https://yourdomain.com"},
	}

	app.Group("ws/v1/").Get("/chat", websocket.New(chat, cfg))
}
 | 
						||
 | 
						||
var bufferPool = &sync.Pool{
 | 
						||
	New: func() interface{} {
 | 
						||
		return make([]byte, 4096) // 4KB 缓冲区
 | 
						||
	},
 | 
						||
}
 | 
						||
 | 
						||
// routerHttp registers the REST endpoints under api/v1.
//
// registerResponse is installed on the group first, so every handler below
// has its result re-emitted inside the common {"code","message","data"}
// envelope. CORS is handled globally by SetupRoutes, not here.
//
// Endpoints:
//
//   - GET  /health: health check; the body is "1".
//   - POST /session/init: create the session if it does not exist, otherwise
//     return its ID and the default number of history messages.
//   - POST /session/list: list sessions.
//   - GET  /broadcast: push msg to WebSocket clients through gateway;
//     action=sendToAll, or action=sendToUid together with uid.
//
// NOTE(review): /broadcast is an unauthenticated GET with side effects. Anyone
// who can reach the API can push arbitrary messages to every connected user.
// Consider protecting it or restricting it to internal callers.
func routerHttp(app *fiber.App, sessionService *services.SessionService, gateway *gateway.Gateway) {
	r := app.Group("api/v1/")
	// Registered before the routes so it wraps every handler in this group.
	registerResponse(r)

	r.Get("/health", func(c *fiber.Ctx) error {
		return c.SendString("1")
	})

	r.Post("/session/init", sessionService.SessionInit)
	r.Post("/session/list", sessionService.SessionList)

	// Broadcast a message to WebSocket clients.
	r.Get("/broadcast", func(ctx *fiber.Ctx) error {
		action := ctx.Query("action")
		uid := ctx.Query("uid")
		msg := ctx.Query("msg")

		switch action {
		case "sendToAll":
			gateway.SendToAll([]byte(msg))
			return ctx.SendString("sent to all")
		case "sendToUid":
			if uid == "" {
				return ctx.Status(fiber.StatusBadRequest).SendString("missing uid")
			}
			gateway.SendToUid(uid, []byte(msg))
			return ctx.SendString(fmt.Sprintf("sent to uid %s", uid))
		default:
			return ctx.Status(fiber.StatusBadRequest).SendString("unknown action")
		}
	})
}
 | 
						||
 | 
						||
func registerResponse(router fiber.Router) {
 | 
						||
	// 自定义返回
 | 
						||
	router.Use(func(c *fiber.Ctx) error {
 | 
						||
		err := c.Next()
 | 
						||
		return registerCommon(c, err)
 | 
						||
	})
 | 
						||
 | 
						||
}
 | 
						||
 | 
						||
func registerCommon(c *fiber.Ctx, err error) error {
 | 
						||
	// 调用下一个中间件或路由处理函数
 | 
						||
 | 
						||
	bsErr, ok := err.(*errors.BusinessErr)
 | 
						||
	if !ok {
 | 
						||
		bsErr = errors.SystemError
 | 
						||
	}
 | 
						||
	// 如果有错误发生
 | 
						||
	if err != nil {
 | 
						||
		// 返回自定义错误响应
 | 
						||
		return c.JSON(fiber.Map{
 | 
						||
			"message": bsErr.Error(),
 | 
						||
			"code":    bsErr.Code(),
 | 
						||
			"data":    nil,
 | 
						||
		})
 | 
						||
	}
 | 
						||
	contentType := strings.ToLower(string(c.Response().Header.Peek("Content-Type")))
 | 
						||
	if strings.Contains(strings.ToLower(contentType), "text/event-stream") {
 | 
						||
		// 是 SSE 请求
 | 
						||
		return c.SendString("这是 SSE 请求")
 | 
						||
	}
 | 
						||
	var data interface{}
 | 
						||
	json.Unmarshal(c.Response().Body(), &data)
 | 
						||
	return c.JSON(fiber.Map{
 | 
						||
		"data":    data,
 | 
						||
		"message": errors.Success.Error(),
 | 
						||
		"code":    errors.Success.Code(),
 | 
						||
	})
 | 
						||
}
 |