From 8a2411016cd8db90877075968d7ab3a6b8c59256 Mon Sep 17 00:00:00 2001 From: wuchao <1272174216@qq.com> Date: Fri, 19 Sep 2025 11:23:56 +0800 Subject: [PATCH] =?UTF-8?q?feat(internal):=20=E5=AE=9E=E7=8E=B0=E7=BD=91?= =?UTF-8?q?=E5=85=B3=E6=9C=8D=E5=8A=A1=E5=B9=B6=E4=BC=98=E5=8C=96=E8=81=8A?= =?UTF-8?q?=E5=A4=A9=E5=8A=9F=E8=83=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 新增 Gateway 结构体和相关方法,用于管理客户端连接和消息分发 - 重构 ChatService,集成 Gateway 功能 - 添加客户端绑定 UID 功能,支持消息发送到指定用户 - 实现全局消息广播和单用户消息发送的 HTTP 接口 - 优化聊天消息处理逻辑,增加消息回显功能 --- internal/gateway/gateway.go | 104 ++++++++++++++++++++++++++++++ internal/server/http.go | 4 +- internal/server/router/router.go | 27 +++++++- internal/services/chat.go | 54 ++++++++++++++-- internal/services/provider_set.go | 7 +- 5 files changed, 186 insertions(+), 10 deletions(-) create mode 100644 internal/gateway/gateway.go diff --git a/internal/gateway/gateway.go b/internal/gateway/gateway.go new file mode 100644 index 0000000..f04ed3a --- /dev/null +++ b/internal/gateway/gateway.go @@ -0,0 +1,104 @@ +package gateway + +import ( + "errors" + "sync" +) + +type Client struct { + ID string + SendFunc func(data []byte) error +} + +type Gateway struct { + mu sync.RWMutex + clients map[string]*Client // clientID -> Client + uidMap map[string][]string // uid -> []clientID +} + +func NewGateway() *Gateway { + return &Gateway{ + clients: make(map[string]*Client), + uidMap: make(map[string][]string), + } +} + +func (g *Gateway) AddClient(c *Client) { + g.mu.Lock() + defer g.mu.Unlock() + g.clients[c.ID] = c +} + +func (g *Gateway) RemoveClient(clientID string) { + g.mu.Lock() + defer g.mu.Unlock() + delete(g.clients, clientID) + for uid, list := range g.uidMap { + newList := []string{} + for _, cid := range list { + if cid != clientID { + newList = append(newList, cid) + } + } + g.uidMap[uid] = newList + } +} + +func (g *Gateway) SendToAll(msg []byte) { + g.mu.RLock() + defer g.mu.RUnlock() + for _, c := range g.clients { + _ = c.SendFunc(msg) + } +} + +func (g *Gateway) SendToClient(clientID string, msg []byte) error { + g.mu.RLock() + defer g.mu.RUnlock() + if c, ok := g.clients[clientID]; ok { + return c.SendFunc(msg) + } + return errors.New("client not found") +} + +func (g *Gateway) BindUid(clientID, uid string) error { + g.mu.Lock() + defer g.mu.Unlock() + if _, ok := g.clients[clientID]; !ok { + return errors.New("client not found") + } + g.uidMap[uid] = append(g.uidMap[uid], clientID) + return nil +} + +func (g *Gateway) SendToUid(uid string, msg []byte) { + g.mu.RLock() + defer g.mu.RUnlock() + if list, ok := g.uidMap[uid]; ok { + for _, cid := range list { + if c, ok := g.clients[cid]; ok { + _ = c.SendFunc(msg) + } + } + } +} + +func (g *Gateway) ListClients() []string { + g.mu.RLock() + defer g.mu.RUnlock() + ids := make([]string, 0, len(g.clients)) + for id := range g.clients { + ids = append(ids, id) + } + return ids +} + +func (g *Gateway) ListUids() map[string][]string { + g.mu.RLock() + defer g.mu.RUnlock() + result := make(map[string][]string, len(g.uidMap)) + for uid, list := range g.uidMap { + result[uid] = append([]string(nil), list...) + } + return result +} diff --git a/internal/server/http.go b/internal/server/http.go index 19eb931..d874dfa 100644 --- a/internal/server/http.go +++ b/internal/server/http.go @@ -1,6 +1,7 @@ package server import ( + "ai_scheduler/internal/gateway" "ai_scheduler/internal/server/router" "ai_scheduler/internal/services" @@ -12,10 +13,11 @@ import ( func NewHTTPServer( service *services.ChatService, session *services.SessionService, + gateway *gateway.Gateway, ) *fiber.App { //构建 server app := initRoute() - router.SetupRoutes(app, service, session) + router.SetupRoutes(app, service, session, gateway) return app } diff --git a/internal/server/router/router.go b/internal/server/router/router.go index c5e3936..30a0a6a 100644 --- a/internal/server/router/router.go +++ b/internal/server/router/router.go @@ -2,8 +2,10 @@ package router import ( errors "ai_scheduler/internal/data/error" + "ai_scheduler/internal/gateway" "ai_scheduler/internal/services" "encoding/json" + "fmt" "strings" "sync" "time" @@ -13,7 +15,7 @@ import ( ) // SetupRoutes 设置路由 -func SetupRoutes(app *fiber.App, ChatService *services.ChatService, sessionService *services.SessionService) { +func SetupRoutes(app *fiber.App, ChatService *services.ChatService, sessionService *services.SessionService, gateway *gateway.Gateway) { app.Use(func(c *fiber.Ctx) error { // 设置 CORS 头 c.Set("Access-Control-Allow-Origin", "*") @@ -28,7 +30,7 @@ func SetupRoutes(app *fiber.App, ChatService *services.ChatService, sessionServi // 继续处理后续中间件或路由 return c.Next() }) - routerHttp(app, sessionService) + routerHttp(app, sessionService, gateway) routerSocket(app, ChatService) } @@ -56,7 +58,7 @@ var bufferPool = &sync.Pool{ }, } -func routerHttp(app *fiber.App, sessionService *services.SessionService) { +func routerHttp(app *fiber.App, sessionService *services.SessionService, gateway *gateway.Gateway) { r := app.Group("api/v1/") registerResponse(r) // 注册 CORS 中间件 @@ -67,7 +69,26 @@ func routerHttp(app *fiber.App, sessionService *services.SessionService) { r.Post("/session/init", sessionService.SessionInit) // 会话初始化,不存在则创建,存在则返回会话ID和默认条数会话历史 r.Post("/session/list", sessionService.SessionList) + //广播 + r.Get("/broadcast", func(ctx *fiber.Ctx) error { + action := ctx.Query("action") + uid := ctx.Query("uid") + msg := ctx.Query("msg") + switch action { + case "sendToAll": + gateway.SendToAll([]byte(msg)) + return ctx.SendString("sent to all") + case "sendToUid": + if uid == "" { + return ctx.Status(400).SendString("missing uid") + } + gateway.SendToUid(uid, []byte(msg)) + return ctx.SendString(fmt.Sprintf("sent to uid %s", uid)) + default: + return ctx.Status(400).SendString("unknown action") + } + }) } func registerResponse(router fiber.Router) { diff --git a/internal/services/chat.go b/internal/services/chat.go index 0dd09fb..d44d032 100644 --- a/internal/services/chat.go +++ b/internal/services/chat.go @@ -3,8 +3,14 @@ package services import ( "ai_scheduler/internal/data/constant" "ai_scheduler/internal/entitys" + "ai_scheduler/internal/gateway" + "encoding/hex" "encoding/json" + "fmt" "log" + "math/rand" + "sync" + "time" "github.com/gofiber/websocket/v2" ) @@ -12,12 +18,15 @@ import ( // ChatHandler 聊天处理器 type ChatService struct { routerService entitys.RouterService + Gw *gateway.Gateway + mu sync.Mutex } // NewChatHandler 创建聊天处理器 -func NewChatService(routerService entitys.RouterService) *ChatService { +func NewChatService(routerService entitys.RouterService, gw *gateway.Gateway) *ChatService { return &ChatService{ routerService: routerService, + Gw: gw, } } @@ -36,12 +45,34 @@ type FunctionCallResponse struct { } func (h *ChatService) ChatFail(c *websocket.Conn, content string) { - //c.WriteMessage(messageType, message) + err := c.WriteMessage(websocket.TextMessage, []byte(content)) + if err != nil { + log.Println("发送错误:", err) + } + _ = c.Close() } +func generateClientID() string { + // 使用时间戳+随机数确保唯一性 + timestamp := time.Now().UnixNano() + randomBytes := make([]byte, 4) + rand.Read(randomBytes) + randomStr := hex.EncodeToString(randomBytes) + return fmt.Sprintf("%d%s", timestamp, randomStr) +} func (h *ChatService) Chat(c *websocket.Conn) { + h.mu.Lock() + clientID := generateClientID() + h.mu.Unlock() + client := &gateway.Client{ + ID: clientID, + SendFunc: func(data []byte) error { + return c.WriteMessage(websocket.TextMessage, data) + }, + } + h.Gw.AddClient(client) + log.Println("client connected:", clientID) log.Println("客户端已连接") - defer c.Close() // 循环读取客户端消息 for { messageType, message, err := c.ReadMessage() @@ -49,6 +80,19 @@ func (h *ChatService) Chat(c *websocket.Conn) { log.Println("读取错误:", err) break } + //简单协议:bind: + if len(message) > 5 && string(message[:5]) == "bind:" { + uid := string(message[5:]) + _ = h.Gw.BindUid(clientID, uid) + log.Printf("bind %s -> uid:%s\n", clientID, uid) + continue + } + // 回显 + err = h.Gw.SendToClient(clientID, []byte("echo: "+string(message))) + if err != nil { + h.ChatFail(c, "send to client failed") + continue + } msg, chatType := h.handleMessageToString(c, messageType, message) if chatType == constant.ConnStatusClosed { break @@ -69,7 +113,9 @@ func (h *ChatService) Chat(c *websocket.Conn) { continue } } - log.Println("客户端已断开") + h.Gw.RemoveClient(clientID) + _ = c.Close() + log.Println("client disconnected:", clientID) } func (h *ChatService) handleMessageToString(c *websocket.Conn, msgType int, msg any) (text []byte, chatType constant.ConnStatus) { diff --git a/internal/services/provider_set.go b/internal/services/provider_set.go index dbe2e9e..f5b03d4 100644 --- a/internal/services/provider_set.go +++ b/internal/services/provider_set.go @@ -1,5 +1,8 @@ package services -import "github.com/google/wire" +import ( + "ai_scheduler/internal/gateway" + "github.com/google/wire" +) -var ProviderSetServices = wire.NewSet(NewChatService, NewSessionService) +var ProviderSetServices = wire.NewSet(NewChatService, NewSessionService, gateway.NewGateway)