diff --git a/internal/biz/do/ctx.go b/internal/biz/do/ctx.go index ed8749d..eb2c5a5 100644 --- a/internal/biz/do/ctx.go +++ b/internal/biz/do/ctx.go @@ -19,8 +19,6 @@ import ( "gitea.cdlsxd.cn/self-tools/l_request" "github.com/gofiber/fiber/v2/log" - "github.com/gofiber/websocket/v2" - "xorm.io/builder" ) @@ -141,10 +139,10 @@ func (d *Do) loadChatHistory(ctx context.Context, requireData *entitys.RequireDa return nil } -func (d *Do) MakeCh(c *websocket.Conn, requireData *entitys.RequireData) (ctx context.Context, deferFunc func()) { +func (d *Do) MakeCh(client *gateway.Client, requireData *entitys.RequireData) (ctx context.Context, deferFunc func()) { requireData.Ch = make(chan entitys.Response) ctx, cancel := context.WithCancel(context.Background()) - done := d.startMessageHandler(ctx, c, requireData) + done := d.startMessageHandler(ctx, client, requireData) return ctx, func() { close(requireData.Ch) //关闭主通道 <-done // 等待消息处理完成 @@ -235,7 +233,7 @@ func (d *Do) getTasks(sysId int32) (tasks []model.AiTask, err error) { // startMessageHandler 启动独立的消息处理协程 func (d *Do) startMessageHandler( ctx context.Context, - c *websocket.Conn, + client *gateway.Client, requireData *entitys.RequireData, ) <-chan struct{} { done := make(chan struct{}) @@ -259,7 +257,7 @@ func (d *Do) startMessageHandler( hisLog.HisId = AiRes.HisID } - _ = entitys.MsgSend(c, entitys.Response{ + _ = entitys.MsgSend(client, entitys.Response{ Content: pkg.JsonStringIgonErr(hisLog), Type: entitys.ResponseEnd, }) @@ -267,7 +265,7 @@ func (d *Do) startMessageHandler( }() for v := range requireData.Ch { // 自动检测通道关闭 - if err := sendWithTimeout(c, v, 2*time.Second); err != nil { + if err := sendWithTimeout(client, v, 10*time.Second); err != nil { log.Errorf("Send error: %v", err) return } @@ -281,7 +279,7 @@ func (d *Do) startMessageHandler( } // 辅助函数:带超时的 WebSocket 发送 -func sendWithTimeout(c *websocket.Conn, data entitys.Response, timeout time.Duration) error { +func sendWithTimeout(client *gateway.Client, data entitys.Response, timeout time.Duration) error { sendCtx, cancel := context.WithTimeout(context.Background(), timeout) defer cancel() @@ -294,7 +292,7 @@ func sendWithTimeout(c *websocket.Conn, data entitys.Response, timeout time.Dura close(done) }() // 如果 MsgSend 阻塞,这里会卡住 - err := entitys.MsgSend(c, data) + err := entitys.MsgSend(client, data) done <- err }() diff --git a/internal/biz/router.go b/internal/biz/router.go index 6dcc233..975a487 100644 --- a/internal/biz/router.go +++ b/internal/biz/router.go @@ -39,11 +39,9 @@ func (r *AiRouterBiz) RouteWithSocket(client *gateway.Client, req *entitys.ChatS requireData := &entitys.RequireData{ Req: req, } - // 获取WebSocket连接 - conn := client.GetConn() //初始化通道/上下文 - ctx, clearFunc := r.do.MakeCh(conn, requireData) + ctx, clearFunc := r.do.MakeCh(client, requireData) defer func() { if err != nil { entitys.ResError(requireData.Ch, "", err.Error()) diff --git a/internal/entitys/response.go b/internal/entitys/response.go index cdadc98..44e053b 100644 --- a/internal/entitys/response.go +++ b/internal/entitys/response.go @@ -1,6 +1,7 @@ package entitys import ( + "ai_scheduler/internal/gateway" "encoding/json" "github.com/gofiber/websocket/v2" ) @@ -100,13 +101,13 @@ func MsgSet(msgType ResponseType, msg string, done bool) []byte { return jsonByte } -func MsgSend(c *websocket.Conn, msg Response) error { +func MsgSend(client *gateway.Client, msg Response) error { // 检查上下文是否已取消 if msg.Type == ResponseText { } jsonByte, _ := json.Marshal(msg) - return c.WriteMessage(websocket.TextMessage, jsonByte) + return client.SendFunc(jsonByte) } func MsgSendByte(c *websocket.Conn, msg []byte) { diff --git a/internal/gateway/client.go b/internal/gateway/client.go index f90678c..a293daa 100644 --- a/internal/gateway/client.go +++ b/internal/gateway/client.go @@ -3,11 +3,11 @@ package gateway import ( errors "ai_scheduler/internal/data/error" "ai_scheduler/internal/data/model" - "ai_scheduler/internal/pkg" "context" - "encoding/binary" + "github.com/google/uuid" "log" "math/rand" + "sync" "time" "github.com/gofiber/websocket/v2" @@ -32,6 +32,7 @@ type Client struct { Ctx context.Context Cancel context.CancelFunc LastActive time.Time + mu sync.Mutex } func NewClient(conn *websocket.Conn, ctx context.Context, cancel context.CancelFunc) *Client { @@ -41,6 +42,7 @@ func NewClient(conn *websocket.Conn, ctx context.Context, cancel context.CancelF conn: conn, Ctx: ctx, Cancel: cancel, + mu: sync.Mutex{}, } } @@ -74,7 +76,7 @@ func (c *Client) GetCodes() []string { return c.codes } -// GetSysCode 获取系统编码 +// 获取系统编码 func (c *Client) GetSysCode() string { return c.sysCode } @@ -104,26 +106,51 @@ func (c *Client) SetCodes(codes []string) { c.codes = codes } +// Close 关闭客户端连接 +func (c *Client) Close() { + //c.mu.Lock() + //defer c.mu.Unlock() + if c.conn != nil { + _ = c.conn.Close() + c.conn = nil + } +} + // SendFunc 发送消息到客户端 func (c *Client) SendFunc(msg []byte) error { - if c.conn != nil { - return c.conn.WriteMessage(websocket.TextMessage, msg) + return c.SendMessage(websocket.TextMessage, msg) +} + +// 在Client结构体中添加更详细的日志 +func (c *Client) SendMessage(msgType int, msg []byte) error { + c.mu.Lock() + defer c.mu.Unlock() + + if c.conn == nil { + return ErrConnClosed } - return ErrConnClosed + + err := c.conn.WriteMessage(msgType, msg) + if err != nil { + log.Printf("发送消息失败: %v, 客户端ID: %s, 消息类型: %d", + err, c.id, msgType) + } + return err } // 生成唯一的客户端ID func generateClientID() string { - // 1. 时间戳 - timestamp := time.Now().UnixNano() - binary.BigEndian.PutUint64(idBuf[:8], uint64(timestamp)) - - // 2. 随机数(4字节) - binary.BigEndian.PutUint32(idBuf[8:12], rng.Uint32()) - - // 3. 十六进制编码 - n := pkg.HexEncode(idBuf[:12], idBuf[12:]) - return string(idBuf[12 : 12+n]) + return uuid.New().String() + //// 1. 时间戳 + //timestamp := time.Now().UnixNano() + //binary.BigEndian.PutUint64(idBuf[:8], uint64(timestamp)) + // + //// 2. 随机数(4字节) + //binary.BigEndian.PutUint32(idBuf[8:12], rng.Uint32()) + // + //// 3. 十六进制编码 + //n := pkg.HexEncode(idBuf[:12], idBuf[12:]) + //return string(idBuf[12 : 12+n]) } // 连接数据验证和收集 @@ -152,6 +179,7 @@ func (c *Client) DataAuth() (err error) { return } +// 总结:目前绝大多数浏览器不支持直接发送WebSocket Ping帧,因此在实际开发中,应该实现应用层ping机制作为主要心跳检测方案,todo 同时保留对未来可能的原生支持的兼容检测。 func (c *Client) InitHeartbeat(timeoutSecond time.Duration) { ticker := time.NewTicker(timeoutSecond * time.Second) defer ticker.Stop() @@ -160,9 +188,12 @@ func (c *Client) InitHeartbeat(timeoutSecond time.Duration) { case <-ticker.C: //*2是防止丢包,连续丢包两次,再加5s网络延迟容错 if time.Since(c.LastActive) > (timeoutSecond*2*time.Second + 5) { // 5秒容错 - log.Println("Heartbeat timeout", "id", c.id) - c.conn.WriteMessage(websocket.CloseMessage, []byte("Heartbeat timeout")) - c.conn.Close() + log.Println("心跳超时", "clientId", c.id) + err := c.SendMessage(websocket.CloseMessage, []byte("Heartbeat timeout")) + if err != nil { + log.Println("发送心跳超时消息失败", err) + } + c.Close() return } case <-c.Ctx.Done(): @@ -170,3 +201,11 @@ func (c *Client) InitHeartbeat(timeoutSecond time.Duration) { } } } + +// 在Client结构体中添加ReadMessage方法 +func (c *Client) ReadMessage() (messageType int, message []byte, err error) { + if c.conn == nil { + return 0, nil, ErrConnClosed + } + return c.conn.ReadMessage() +} diff --git a/internal/gateway/gateway.go b/internal/gateway/gateway.go index 03ce961..6e09bdc 100644 --- a/internal/gateway/gateway.go +++ b/internal/gateway/gateway.go @@ -34,15 +34,17 @@ func (g *Gateway) AddClient(c *Client) { func (g *Gateway) Cleanup(clientID string) { g.mu.Lock() + // 从网关管理中移除客户端 defer func() { if c, ex := g.clients[clientID]; ex { delete(g.clients, clientID) - _ = c.conn.Close() + c.Close() c.Cancel() } g.mu.Unlock() log.Println("client disconnected:", clientID) }() + // 从所有绑定的UID列表中移除该客户端 for uid, list := range g.uidMap { newList := []string{} for _, cid := range list { @@ -79,7 +81,7 @@ func (g *Gateway) BindUid(clientID, uid string) error { return errors.New("client not found") } g.uidMap[uid] = append(g.uidMap[uid], clientID) - log.Printf("bind %s -> uid:%s\n", clientID, uid) + log.Printf("绑定 clientId %s -> uid:%s\n", clientID, uid) return nil } diff --git a/internal/services/chat.go b/internal/services/chat.go index a01ba73..11dda38 100644 --- a/internal/services/chat.go +++ b/internal/services/chat.go @@ -55,82 +55,109 @@ func (h *ChatService) Chat(c *websocket.Conn) { // 创建新的客户端实例 client := gateway.NewClient(c, ctx, cancel) - // 心跳检测 - go client.InitHeartbeat(time.Duration(h.cfg.Sys.HeartbeatInterval)) - // 将客户端添加到网关管理 + + // 验证并收集连接数据,后续对话中会使用 + if err := client.DataAuth(); err != nil { + log.Println("数据验证错误:", err) + _ = client.SendFunc([]byte(err.Error())) + client.Close() + return + } + + // 验证通过后,将客户端添加到网关管理 h.Gw.AddClient(client) + + // 使用信号量限制并发处理的消息数量 + semaphore := make(chan struct{}, 1) // 最多1个并发消息处理 + // 用于等待所有goroutine完成的wait group + var wg sync.WaitGroup // 确保在函数返回时移除客户端并关闭连接 defer func() { h.Gw.Cleanup(client.GetID()) + close(semaphore) // 关闭信号量通道 + wg.Wait() // 等待所有消息处理goroutine完成 }() - // 绑定会话ID - uid := c.Query("x-session") - if uid != "" { + // 绑定会话ID, sessionId 为空时, 则不绑定 + if uid := client.GetSession(); uid != "" { if err := h.Gw.BindUid(client.GetID(), uid); err != nil { log.Println("绑定UID错误:", err) } } - // 验证并收集连接数据,后续对话中会使用 - if err := client.DataAuth(); err != nil { - log.Println("数据验证错误:", err) - h.ChatFail(c, err.Error()) - return - } + // 开启心跳检测 + go client.InitHeartbeat(time.Duration(h.cfg.Sys.HeartbeatInterval)) // 循环读取客户端消息 for { - // 读取消息 - messageType, message, err := c.ReadMessage() + messageType, message, err := client.ReadMessage() if err != nil { - log.Println("读取错误:", err) + log.Printf("读取错误: %v, 客户端ID: %s", err, client.GetID()) break } - //if string(message) == `{"type":"ping"}` { - // client.LastActive = time.Now() - // if err := c.WriteMessage(websocket.TextMessage, []byte(`{"type":"pong"}`)); err != nil { - // log.Println("Heartbeat response failed", "id", client.GetID(), "err", err) - // return - // } - // continue - //} - // 处理消息 - msg, chatType := h.handleMessageToString(c, messageType, message) - if chatType == constants.ConnStatusClosed { - break - } - if chatType == constants.ConnStatusIgnore { + + // 处理心跳消息 + if messageType == websocket.PingMessage || string(message) == "PING" { + client.LastActive = time.Now() + msgType := websocket.TextMessage + if messageType == websocket.PingMessage { + msgType = websocket.PongMessage + } + if err = client.SendMessage(msgType, []byte(`PONG`)); err != nil { + log.Printf("发送pong消息失败: %v", err) + } continue } - log.Printf("收到消息: %s", string(msg)) + // 使用信号量限制并发 + semaphore <- struct{}{} + wg.Add(1) + go func(msgType int, msg []byte) { + defer func() { + <-semaphore + wg.Done() + // 恢复panic + if r := recover(); r != nil { + log.Printf("消息处理goroutine发生panic: %v", r) + } + }() - // 解析请求 - var req entitys.ChatSockRequest - if err = json.Unmarshal(msg, &req); err != nil { - log.Println("JSON parse error:", err) - continue - } + // 消息处理逻辑 + h.processMessage(client, msgType, msg) + }(messageType, message) + } +} - // 路由处理请求 - err = h.routerBiz.RouteWithSocket(client, &req) - if err != nil { - log.Println("处理失败:", err) - } +// 将消息处理逻辑提取到单独的方法 +func (h *ChatService) processMessage(client *gateway.Client, msgType int, msg []byte) { + // 处理消息 + processedMsg, _ := h.handleMessageToString(client, msgType, msg) + log.Printf("收到消息:消息类型 %d, 内容 %s, 客户端ID: %s", + msgType, string(processedMsg), client.GetID()) + + // 解析请求 + var req entitys.ChatSockRequest + if err := json.Unmarshal(processedMsg, &req); err != nil { + log.Printf("JSON解析错误: %v, 客户端ID: %s", err, client.GetID()) + return + } + + // 路由处理请求 + if err := h.routerBiz.RouteWithSocket(client, &req); err != nil { + log.Printf("处理失败: %v, 客户端ID: %s", err, client.GetID()) } } // handleMessageToString 处理不同类型的WebSocket消息 // 参数: -// - c: WebSocket连接 +// - client: 客户端对象 // - msgType: 消息类型 // - msg: 消息内容 // // 返回: // - text: 处理后的文本内容 // - chatType: 连接状态 -func (h *ChatService) handleMessageToString(c *websocket.Conn, msgType int, msg any) (text []byte, chatType constants.ConnStatus) { +func (h *ChatService) handleMessageToString(client *gateway.Client, msgType int, msg any) (text []byte, chatType constants.ConnStatus) { switch msgType { case websocket.TextMessage: return msg.([]byte), constants.ConnStatusNormal @@ -140,15 +167,13 @@ func (h *ChatService) handleMessageToString(c *websocket.Conn, msgType int, msg return nil, constants.ConnStatusClosed case websocket.PingMessage: - // 可选:回复 Pong - c.WriteMessage(websocket.PongMessage, nil) return nil, constants.ConnStatusIgnore case websocket.PongMessage: return nil, constants.ConnStatusIgnore default: + log.Printf("未知的消息类型: %d", msgType) return nil, constants.ConnStatusIgnore } - return msg.([]byte), constants.ConnStatusIgnore } func (s *ChatService) Useful(c *fiber.Ctx) error {