feat: 心跳检测

This commit is contained in:
wolter 2025-12-05 09:12:06 +08:00
parent 4339b6eee8
commit f2638b32b5
6 changed files with 144 additions and 81 deletions

View File

@ -19,8 +19,6 @@ import (
"gitea.cdlsxd.cn/self-tools/l_request" "gitea.cdlsxd.cn/self-tools/l_request"
"github.com/gofiber/fiber/v2/log" "github.com/gofiber/fiber/v2/log"
"github.com/gofiber/websocket/v2"
"xorm.io/builder" "xorm.io/builder"
) )
@ -141,10 +139,10 @@ func (d *Do) loadChatHistory(ctx context.Context, requireData *entitys.RequireDa
return nil return nil
} }
func (d *Do) MakeCh(c *websocket.Conn, requireData *entitys.RequireData) (ctx context.Context, deferFunc func()) { func (d *Do) MakeCh(client *gateway.Client, requireData *entitys.RequireData) (ctx context.Context, deferFunc func()) {
requireData.Ch = make(chan entitys.Response) requireData.Ch = make(chan entitys.Response)
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
done := d.startMessageHandler(ctx, c, requireData) done := d.startMessageHandler(ctx, client, requireData)
return ctx, func() { return ctx, func() {
close(requireData.Ch) //关闭主通道 close(requireData.Ch) //关闭主通道
<-done // 等待消息处理完成 <-done // 等待消息处理完成
@ -235,7 +233,7 @@ func (d *Do) getTasks(sysId int32) (tasks []model.AiTask, err error) {
// startMessageHandler 启动独立的消息处理协程 // startMessageHandler 启动独立的消息处理协程
func (d *Do) startMessageHandler( func (d *Do) startMessageHandler(
ctx context.Context, ctx context.Context,
c *websocket.Conn, client *gateway.Client,
requireData *entitys.RequireData, requireData *entitys.RequireData,
) <-chan struct{} { ) <-chan struct{} {
done := make(chan struct{}) done := make(chan struct{})
@ -259,7 +257,7 @@ func (d *Do) startMessageHandler(
hisLog.HisId = AiRes.HisID hisLog.HisId = AiRes.HisID
} }
_ = entitys.MsgSend(c, entitys.Response{ _ = entitys.MsgSend(client, entitys.Response{
Content: pkg.JsonStringIgonErr(hisLog), Content: pkg.JsonStringIgonErr(hisLog),
Type: entitys.ResponseEnd, Type: entitys.ResponseEnd,
}) })
@ -267,7 +265,7 @@ func (d *Do) startMessageHandler(
}() }()
for v := range requireData.Ch { // 自动检测通道关闭 for v := range requireData.Ch { // 自动检测通道关闭
if err := sendWithTimeout(c, v, 2*time.Second); err != nil { if err := sendWithTimeout(client, v, 10*time.Second); err != nil {
log.Errorf("Send error: %v", err) log.Errorf("Send error: %v", err)
return return
} }
@ -281,7 +279,7 @@ func (d *Do) startMessageHandler(
} }
// 辅助函数:带超时的 WebSocket 发送 // 辅助函数:带超时的 WebSocket 发送
func sendWithTimeout(c *websocket.Conn, data entitys.Response, timeout time.Duration) error { func sendWithTimeout(client *gateway.Client, data entitys.Response, timeout time.Duration) error {
sendCtx, cancel := context.WithTimeout(context.Background(), timeout) sendCtx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel() defer cancel()
@ -294,7 +292,7 @@ func sendWithTimeout(c *websocket.Conn, data entitys.Response, timeout time.Dura
close(done) close(done)
}() }()
// 如果 MsgSend 阻塞,这里会卡住 // 如果 MsgSend 阻塞,这里会卡住
err := entitys.MsgSend(c, data) err := entitys.MsgSend(client, data)
done <- err done <- err
}() }()

View File

@ -39,11 +39,9 @@ func (r *AiRouterBiz) RouteWithSocket(client *gateway.Client, req *entitys.ChatS
requireData := &entitys.RequireData{ requireData := &entitys.RequireData{
Req: req, Req: req,
} }
// 获取WebSocket连接
conn := client.GetConn()
//初始化通道/上下文 //初始化通道/上下文
ctx, clearFunc := r.do.MakeCh(conn, requireData) ctx, clearFunc := r.do.MakeCh(client, requireData)
defer func() { defer func() {
if err != nil { if err != nil {
entitys.ResError(requireData.Ch, "", err.Error()) entitys.ResError(requireData.Ch, "", err.Error())

View File

@ -1,6 +1,7 @@
package entitys package entitys
import ( import (
"ai_scheduler/internal/gateway"
"encoding/json" "encoding/json"
"github.com/gofiber/websocket/v2" "github.com/gofiber/websocket/v2"
) )
@ -100,13 +101,13 @@ func MsgSet(msgType ResponseType, msg string, done bool) []byte {
return jsonByte return jsonByte
} }
func MsgSend(c *websocket.Conn, msg Response) error { func MsgSend(client *gateway.Client, msg Response) error {
// 检查上下文是否已取消 // 检查上下文是否已取消
if msg.Type == ResponseText { if msg.Type == ResponseText {
} }
jsonByte, _ := json.Marshal(msg) jsonByte, _ := json.Marshal(msg)
return c.WriteMessage(websocket.TextMessage, jsonByte) return client.SendFunc(jsonByte)
} }
func MsgSendByte(c *websocket.Conn, msg []byte) { func MsgSendByte(c *websocket.Conn, msg []byte) {

View File

@ -3,11 +3,11 @@ package gateway
import ( import (
errors "ai_scheduler/internal/data/error" errors "ai_scheduler/internal/data/error"
"ai_scheduler/internal/data/model" "ai_scheduler/internal/data/model"
"ai_scheduler/internal/pkg"
"context" "context"
"encoding/binary" "github.com/google/uuid"
"log" "log"
"math/rand" "math/rand"
"sync"
"time" "time"
"github.com/gofiber/websocket/v2" "github.com/gofiber/websocket/v2"
@ -32,6 +32,7 @@ type Client struct {
Ctx context.Context Ctx context.Context
Cancel context.CancelFunc Cancel context.CancelFunc
LastActive time.Time LastActive time.Time
mu sync.Mutex
} }
func NewClient(conn *websocket.Conn, ctx context.Context, cancel context.CancelFunc) *Client { func NewClient(conn *websocket.Conn, ctx context.Context, cancel context.CancelFunc) *Client {
@ -41,6 +42,7 @@ func NewClient(conn *websocket.Conn, ctx context.Context, cancel context.CancelF
conn: conn, conn: conn,
Ctx: ctx, Ctx: ctx,
Cancel: cancel, Cancel: cancel,
mu: sync.Mutex{},
} }
} }
@ -74,7 +76,7 @@ func (c *Client) GetCodes() []string {
return c.codes return c.codes
} }
// GetSysCode 获取系统编码 // 获取系统编码
func (c *Client) GetSysCode() string { func (c *Client) GetSysCode() string {
return c.sysCode return c.sysCode
} }
@ -104,26 +106,51 @@ func (c *Client) SetCodes(codes []string) {
c.codes = codes c.codes = codes
} }
// Close closes the client's underlying WebSocket connection and clears the
// reference to it. Calling Close when the connection is already nil is a no-op.
//
// NOTE(review): c.mu is not taken here (the locking was commented out in the
// original), whereas SendMessage holds c.mu while writing — confirm that
// running Close concurrently with a send is intended.
func (c *Client) Close() {
	// Nothing to release if the connection was never set or is already cleared.
	if c.conn == nil {
		return
	}
	// The close error is deliberately discarded: this is best-effort teardown.
	_ = c.conn.Close()
	c.conn = nil
}
// SendFunc 发送消息到客户端 // SendFunc 发送消息到客户端
func (c *Client) SendFunc(msg []byte) error { func (c *Client) SendFunc(msg []byte) error {
if c.conn != nil { return c.SendMessage(websocket.TextMessage, msg)
return c.conn.WriteMessage(websocket.TextMessage, msg) }
// 在Client结构体中添加更详细的日志
func (c *Client) SendMessage(msgType int, msg []byte) error {
c.mu.Lock()
defer c.mu.Unlock()
if c.conn == nil {
return ErrConnClosed
} }
return ErrConnClosed
err := c.conn.WriteMessage(msgType, msg)
if err != nil {
log.Printf("发送消息失败: %v, 客户端ID: %s, 消息类型: %d",
err, c.id, msgType)
}
return err
} }
// 生成唯一的客户端ID // 生成唯一的客户端ID
func generateClientID() string { func generateClientID() string {
// 1. 时间戳 return uuid.New().String()
timestamp := time.Now().UnixNano() //// 1. 时间戳
binary.BigEndian.PutUint64(idBuf[:8], uint64(timestamp)) //timestamp := time.Now().UnixNano()
//binary.BigEndian.PutUint64(idBuf[:8], uint64(timestamp))
// 2. 随机数4字节 //
binary.BigEndian.PutUint32(idBuf[8:12], rng.Uint32()) //// 2. 随机数4字节
//binary.BigEndian.PutUint32(idBuf[8:12], rng.Uint32())
// 3. 十六进制编码 //
n := pkg.HexEncode(idBuf[:12], idBuf[12:]) //// 3. 十六进制编码
return string(idBuf[12 : 12+n]) //n := pkg.HexEncode(idBuf[:12], idBuf[12:])
//return string(idBuf[12 : 12+n])
} }
// 连接数据验证和收集 // 连接数据验证和收集
@ -152,6 +179,7 @@ func (c *Client) DataAuth() (err error) {
return return
} }
// 总结：目前绝大多数浏览器不支持直接发送 WebSocket Ping 帧，因此在实际开发中应该实现应用层 ping 机制作为主要心跳检测方案。TODO: 同时保留对未来可能的原生支持的兼容检测。
func (c *Client) InitHeartbeat(timeoutSecond time.Duration) { func (c *Client) InitHeartbeat(timeoutSecond time.Duration) {
ticker := time.NewTicker(timeoutSecond * time.Second) ticker := time.NewTicker(timeoutSecond * time.Second)
defer ticker.Stop() defer ticker.Stop()
@ -160,9 +188,12 @@ func (c *Client) InitHeartbeat(timeoutSecond time.Duration) {
case <-ticker.C: case <-ticker.C:
//*2是防止丢包连续丢包两次再加5s网络延迟容错 //*2是防止丢包连续丢包两次再加5s网络延迟容错
if time.Since(c.LastActive) > (timeoutSecond*2*time.Second + 5) { // 5秒容错 if time.Since(c.LastActive) > (timeoutSecond*2*time.Second + 5) { // 5秒容错
log.Println("Heartbeat timeout", "id", c.id) log.Println("心跳超时", "clientId", c.id)
c.conn.WriteMessage(websocket.CloseMessage, []byte("Heartbeat timeout")) err := c.SendMessage(websocket.CloseMessage, []byte("Heartbeat timeout"))
c.conn.Close() if err != nil {
log.Println("发送心跳超时消息失败", err)
}
c.Close()
return return
} }
case <-c.Ctx.Done(): case <-c.Ctx.Done():
@ -170,3 +201,11 @@ func (c *Client) InitHeartbeat(timeoutSecond time.Duration) {
} }
} }
} }
// ReadMessage reads one message from the client's underlying WebSocket
// connection and returns whatever that connection's ReadMessage returns.
//
// If the connection reference is nil (for example after Close has cleared
// it), it returns a zero message type, a nil payload and ErrConnClosed.
func (c *Client) ReadMessage() (messageType int, message []byte, err error) {
	if c.conn != nil {
		return c.conn.ReadMessage()
	}
	return 0, nil, ErrConnClosed
}

View File

@ -34,15 +34,17 @@ func (g *Gateway) AddClient(c *Client) {
func (g *Gateway) Cleanup(clientID string) { func (g *Gateway) Cleanup(clientID string) {
g.mu.Lock() g.mu.Lock()
// 从网关管理中移除客户端
defer func() { defer func() {
if c, ex := g.clients[clientID]; ex { if c, ex := g.clients[clientID]; ex {
delete(g.clients, clientID) delete(g.clients, clientID)
_ = c.conn.Close() c.Close()
c.Cancel() c.Cancel()
} }
g.mu.Unlock() g.mu.Unlock()
log.Println("client disconnected:", clientID) log.Println("client disconnected:", clientID)
}() }()
// 从所有绑定的UID列表中移除该客户端
for uid, list := range g.uidMap { for uid, list := range g.uidMap {
newList := []string{} newList := []string{}
for _, cid := range list { for _, cid := range list {
@ -79,7 +81,7 @@ func (g *Gateway) BindUid(clientID, uid string) error {
return errors.New("client not found") return errors.New("client not found")
} }
g.uidMap[uid] = append(g.uidMap[uid], clientID) g.uidMap[uid] = append(g.uidMap[uid], clientID)
log.Printf("bind %s -> uid:%s\n", clientID, uid) log.Printf("绑定 clientId %s -> uid:%s\n", clientID, uid)
return nil return nil
} }

View File

@ -55,82 +55,109 @@ func (h *ChatService) Chat(c *websocket.Conn) {
// 创建新的客户端实例 // 创建新的客户端实例
client := gateway.NewClient(c, ctx, cancel) client := gateway.NewClient(c, ctx, cancel)
// 心跳检测
go client.InitHeartbeat(time.Duration(h.cfg.Sys.HeartbeatInterval)) // 验证并收集连接数据,后续对话中会使用
// 将客户端添加到网关管理 if err := client.DataAuth(); err != nil {
log.Println("数据验证错误:", err)
_ = client.SendFunc([]byte(err.Error()))
client.Close()
return
}
// 验证通过后,将客户端添加到网关管理
h.Gw.AddClient(client) h.Gw.AddClient(client)
// 使用信号量限制并发处理的消息数量
semaphore := make(chan struct{}, 1) // 最多1个并发消息处理
// 用于等待所有goroutine完成的wait group
var wg sync.WaitGroup
// 确保在函数返回时移除客户端并关闭连接 // 确保在函数返回时移除客户端并关闭连接
defer func() { defer func() {
h.Gw.Cleanup(client.GetID()) h.Gw.Cleanup(client.GetID())
close(semaphore) // 关闭信号量通道
wg.Wait() // 等待所有消息处理goroutine完成
}() }()
// 绑定会话ID // 绑定会话ID, sessionId 为空时, 则不绑定
uid := c.Query("x-session") if uid := client.GetSession(); uid != "" {
if uid != "" {
if err := h.Gw.BindUid(client.GetID(), uid); err != nil { if err := h.Gw.BindUid(client.GetID(), uid); err != nil {
log.Println("绑定UID错误:", err) log.Println("绑定UID错误:", err)
} }
} }
// 验证并收集连接数据,后续对话中会使用 // 开启心跳检测
if err := client.DataAuth(); err != nil { go client.InitHeartbeat(time.Duration(h.cfg.Sys.HeartbeatInterval))
log.Println("数据验证错误:", err)
h.ChatFail(c, err.Error())
return
}
// 循环读取客户端消息 // 循环读取客户端消息
for { for {
// 读取消息 messageType, message, err := client.ReadMessage()
messageType, message, err := c.ReadMessage()
if err != nil { if err != nil {
log.Println("读取错误:", err) log.Printf("读取错误: %v, 客户端ID: %s", err, client.GetID())
break break
} }
//if string(message) == `{"type":"ping"}` {
// client.LastActive = time.Now() // 处理心跳消息
// if err := c.WriteMessage(websocket.TextMessage, []byte(`{"type":"pong"}`)); err != nil { if messageType == websocket.PingMessage || string(message) == "PING" {
// log.Println("Heartbeat response failed", "id", client.GetID(), "err", err) client.LastActive = time.Now()
// return msgType := websocket.TextMessage
// } if messageType == websocket.PingMessage {
// continue msgType = websocket.PongMessage
//} }
// 处理消息 if err = client.SendMessage(msgType, []byte(`PONG`)); err != nil {
msg, chatType := h.handleMessageToString(c, messageType, message) log.Printf("发送pong消息失败: %v", err)
if chatType == constants.ConnStatusClosed { }
break
}
if chatType == constants.ConnStatusIgnore {
continue continue
} }
log.Printf("收到消息: %s", string(msg)) // 使用信号量限制并发
semaphore <- struct{}{}
wg.Add(1)
go func(msgType int, msg []byte) {
defer func() {
<-semaphore
wg.Done()
// 恢复panic
if r := recover(); r != nil {
log.Printf("消息处理goroutine发生panic: %v", r)
}
}()
// 解析请求 // 消息处理逻辑
var req entitys.ChatSockRequest h.processMessage(client, msgType, msg)
if err = json.Unmarshal(msg, &req); err != nil { }(messageType, message)
log.Println("JSON parse error:", err) }
continue }
}
// 路由处理请求 // 将消息处理逻辑提取到单独的方法
err = h.routerBiz.RouteWithSocket(client, &req) func (h *ChatService) processMessage(client *gateway.Client, msgType int, msg []byte) {
if err != nil { // 处理消息
log.Println("处理失败:", err) processedMsg, _ := h.handleMessageToString(client, msgType, msg)
} log.Printf("收到消息:消息类型 %d, 内容 %s, 客户端ID: %s",
msgType, string(processedMsg), client.GetID())
// 解析请求
var req entitys.ChatSockRequest
if err := json.Unmarshal(processedMsg, &req); err != nil {
log.Printf("JSON解析错误: %v, 客户端ID: %s", err, client.GetID())
return
}
// 路由处理请求
if err := h.routerBiz.RouteWithSocket(client, &req); err != nil {
log.Printf("处理失败: %v, 客户端ID: %s", err, client.GetID())
} }
} }
// handleMessageToString 处理不同类型的WebSocket消息 // handleMessageToString 处理不同类型的WebSocket消息
// 参数: // 参数:
// - c: WebSocket连接 // - client: 客户端对象
// - msgType: 消息类型 // - msgType: 消息类型
// - msg: 消息内容 // - msg: 消息内容
// //
// 返回: // 返回:
// - text: 处理后的文本内容 // - text: 处理后的文本内容
// - chatType: 连接状态 // - chatType: 连接状态
func (h *ChatService) handleMessageToString(c *websocket.Conn, msgType int, msg any) (text []byte, chatType constants.ConnStatus) { func (h *ChatService) handleMessageToString(client *gateway.Client, msgType int, msg any) (text []byte, chatType constants.ConnStatus) {
switch msgType { switch msgType {
case websocket.TextMessage: case websocket.TextMessage:
return msg.([]byte), constants.ConnStatusNormal return msg.([]byte), constants.ConnStatusNormal
@ -140,15 +167,13 @@ func (h *ChatService) handleMessageToString(c *websocket.Conn, msgType int, msg
return nil, constants.ConnStatusClosed return nil, constants.ConnStatusClosed
case websocket.PingMessage: case websocket.PingMessage:
// 可选:回复 Pong
c.WriteMessage(websocket.PongMessage, nil)
return nil, constants.ConnStatusIgnore return nil, constants.ConnStatusIgnore
case websocket.PongMessage: case websocket.PongMessage:
return nil, constants.ConnStatusIgnore return nil, constants.ConnStatusIgnore
default: default:
log.Printf("未知的消息类型: %d", msgType)
return nil, constants.ConnStatusIgnore return nil, constants.ConnStatusIgnore
} }
return msg.([]byte), constants.ConnStatusIgnore
} }
func (s *ChatService) Useful(c *fiber.Ctx) error { func (s *ChatService) Useful(c *fiber.Ctx) error {