feat: 心跳检测

This commit is contained in:
wolter 2025-12-05 09:12:06 +08:00
parent 4339b6eee8
commit f2638b32b5
6 changed files with 144 additions and 81 deletions

View File

@ -19,8 +19,6 @@ import (
"gitea.cdlsxd.cn/self-tools/l_request"
"github.com/gofiber/fiber/v2/log"
"github.com/gofiber/websocket/v2"
"xorm.io/builder"
)
@ -141,10 +139,10 @@ func (d *Do) loadChatHistory(ctx context.Context, requireData *entitys.RequireDa
return nil
}
func (d *Do) MakeCh(c *websocket.Conn, requireData *entitys.RequireData) (ctx context.Context, deferFunc func()) {
func (d *Do) MakeCh(client *gateway.Client, requireData *entitys.RequireData) (ctx context.Context, deferFunc func()) {
requireData.Ch = make(chan entitys.Response)
ctx, cancel := context.WithCancel(context.Background())
done := d.startMessageHandler(ctx, c, requireData)
done := d.startMessageHandler(ctx, client, requireData)
return ctx, func() {
close(requireData.Ch) //关闭主通道
<-done // 等待消息处理完成
@ -235,7 +233,7 @@ func (d *Do) getTasks(sysId int32) (tasks []model.AiTask, err error) {
// startMessageHandler 启动独立的消息处理协程
func (d *Do) startMessageHandler(
ctx context.Context,
c *websocket.Conn,
client *gateway.Client,
requireData *entitys.RequireData,
) <-chan struct{} {
done := make(chan struct{})
@ -259,7 +257,7 @@ func (d *Do) startMessageHandler(
hisLog.HisId = AiRes.HisID
}
_ = entitys.MsgSend(c, entitys.Response{
_ = entitys.MsgSend(client, entitys.Response{
Content: pkg.JsonStringIgonErr(hisLog),
Type: entitys.ResponseEnd,
})
@ -267,7 +265,7 @@ func (d *Do) startMessageHandler(
}()
for v := range requireData.Ch { // 自动检测通道关闭
if err := sendWithTimeout(c, v, 2*time.Second); err != nil {
if err := sendWithTimeout(client, v, 10*time.Second); err != nil {
log.Errorf("Send error: %v", err)
return
}
@ -281,7 +279,7 @@ func (d *Do) startMessageHandler(
}
// 辅助函数:带超时的 WebSocket 发送
func sendWithTimeout(c *websocket.Conn, data entitys.Response, timeout time.Duration) error {
func sendWithTimeout(client *gateway.Client, data entitys.Response, timeout time.Duration) error {
sendCtx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
@ -294,7 +292,7 @@ func sendWithTimeout(c *websocket.Conn, data entitys.Response, timeout time.Dura
close(done)
}()
// 如果 MsgSend 阻塞,这里会卡住
err := entitys.MsgSend(c, data)
err := entitys.MsgSend(client, data)
done <- err
}()

View File

@ -39,11 +39,9 @@ func (r *AiRouterBiz) RouteWithSocket(client *gateway.Client, req *entitys.ChatS
requireData := &entitys.RequireData{
Req: req,
}
// 获取WebSocket连接
conn := client.GetConn()
//初始化通道/上下文
ctx, clearFunc := r.do.MakeCh(conn, requireData)
ctx, clearFunc := r.do.MakeCh(client, requireData)
defer func() {
if err != nil {
entitys.ResError(requireData.Ch, "", err.Error())

View File

@ -1,6 +1,7 @@
package entitys
import (
"ai_scheduler/internal/gateway"
"encoding/json"
"github.com/gofiber/websocket/v2"
)
@ -100,13 +101,13 @@ func MsgSet(msgType ResponseType, msg string, done bool) []byte {
return jsonByte
}
func MsgSend(c *websocket.Conn, msg Response) error {
func MsgSend(client *gateway.Client, msg Response) error {
// 检查上下文是否已取消
if msg.Type == ResponseText {
}
jsonByte, _ := json.Marshal(msg)
return c.WriteMessage(websocket.TextMessage, jsonByte)
return client.SendFunc(jsonByte)
}
func MsgSendByte(c *websocket.Conn, msg []byte) {

View File

@ -3,11 +3,11 @@ package gateway
import (
errors "ai_scheduler/internal/data/error"
"ai_scheduler/internal/data/model"
"ai_scheduler/internal/pkg"
"context"
"encoding/binary"
"github.com/google/uuid"
"log"
"math/rand"
"sync"
"time"
"github.com/gofiber/websocket/v2"
@ -32,6 +32,7 @@ type Client struct {
Ctx context.Context
Cancel context.CancelFunc
LastActive time.Time
mu sync.Mutex
}
func NewClient(conn *websocket.Conn, ctx context.Context, cancel context.CancelFunc) *Client {
@ -41,6 +42,7 @@ func NewClient(conn *websocket.Conn, ctx context.Context, cancel context.CancelF
conn: conn,
Ctx: ctx,
Cancel: cancel,
mu: sync.Mutex{},
}
}
@ -74,7 +76,7 @@ func (c *Client) GetCodes() []string {
return c.codes
}
// GetSysCode 获取系统编码
// 获取系统编码
func (c *Client) GetSysCode() string {
return c.sysCode
}
@ -104,26 +106,51 @@ func (c *Client) SetCodes(codes []string) {
c.codes = codes
}
// Close 关闭客户端连接
func (c *Client) Close() {
//c.mu.Lock()
//defer c.mu.Unlock()
if c.conn != nil {
_ = c.conn.Close()
c.conn = nil
}
}
// SendFunc 发送消息到客户端
func (c *Client) SendFunc(msg []byte) error {
if c.conn != nil {
return c.conn.WriteMessage(websocket.TextMessage, msg)
return c.SendMessage(websocket.TextMessage, msg)
}
// 在Client结构体中添加更详细的日志
func (c *Client) SendMessage(msgType int, msg []byte) error {
c.mu.Lock()
defer c.mu.Unlock()
if c.conn == nil {
return ErrConnClosed
}
return ErrConnClosed
err := c.conn.WriteMessage(msgType, msg)
if err != nil {
log.Printf("发送消息失败: %v, 客户端ID: %s, 消息类型: %d",
err, c.id, msgType)
}
return err
}
// 生成唯一的客户端ID
func generateClientID() string {
// 1. 时间戳
timestamp := time.Now().UnixNano()
binary.BigEndian.PutUint64(idBuf[:8], uint64(timestamp))
// 2. 随机数4字节
binary.BigEndian.PutUint32(idBuf[8:12], rng.Uint32())
// 3. 十六进制编码
n := pkg.HexEncode(idBuf[:12], idBuf[12:])
return string(idBuf[12 : 12+n])
return uuid.New().String()
//// 1. 时间戳
//timestamp := time.Now().UnixNano()
//binary.BigEndian.PutUint64(idBuf[:8], uint64(timestamp))
//
//// 2. 随机数4字节
//binary.BigEndian.PutUint32(idBuf[8:12], rng.Uint32())
//
//// 3. 十六进制编码
//n := pkg.HexEncode(idBuf[:12], idBuf[12:])
//return string(idBuf[12 : 12+n])
}
// 连接数据验证和收集
@ -152,6 +179,7 @@ func (c *Client) DataAuth() (err error) {
return
}
// 总结目前绝大多数浏览器不支持直接发送WebSocket Ping帧因此在实际开发中应该实现应用层ping机制作为主要心跳检测方案todo 同时保留对未来可能的原生支持的兼容检测。
func (c *Client) InitHeartbeat(timeoutSecond time.Duration) {
ticker := time.NewTicker(timeoutSecond * time.Second)
defer ticker.Stop()
@ -160,9 +188,12 @@ func (c *Client) InitHeartbeat(timeoutSecond time.Duration) {
case <-ticker.C:
//*2是防止丢包连续丢包两次再加5s网络延迟容错
if time.Since(c.LastActive) > (timeoutSecond*2*time.Second + 5) { // 5秒容错
log.Println("Heartbeat timeout", "id", c.id)
c.conn.WriteMessage(websocket.CloseMessage, []byte("Heartbeat timeout"))
c.conn.Close()
log.Println("心跳超时", "clientId", c.id)
err := c.SendMessage(websocket.CloseMessage, []byte("Heartbeat timeout"))
if err != nil {
log.Println("发送心跳超时消息失败", err)
}
c.Close()
return
}
case <-c.Ctx.Done():
@ -170,3 +201,11 @@ func (c *Client) InitHeartbeat(timeoutSecond time.Duration) {
}
}
}
// 在Client结构体中添加ReadMessage方法
func (c *Client) ReadMessage() (messageType int, message []byte, err error) {
if c.conn == nil {
return 0, nil, ErrConnClosed
}
return c.conn.ReadMessage()
}

View File

@ -34,15 +34,17 @@ func (g *Gateway) AddClient(c *Client) {
func (g *Gateway) Cleanup(clientID string) {
g.mu.Lock()
// 从网关管理中移除客户端
defer func() {
if c, ex := g.clients[clientID]; ex {
delete(g.clients, clientID)
_ = c.conn.Close()
c.Close()
c.Cancel()
}
g.mu.Unlock()
log.Println("client disconnected:", clientID)
}()
// 从所有绑定的UID列表中移除该客户端
for uid, list := range g.uidMap {
newList := []string{}
for _, cid := range list {
@ -79,7 +81,7 @@ func (g *Gateway) BindUid(clientID, uid string) error {
return errors.New("client not found")
}
g.uidMap[uid] = append(g.uidMap[uid], clientID)
log.Printf("bind %s -> uid:%s\n", clientID, uid)
log.Printf("绑定 clientId %s -> uid:%s\n", clientID, uid)
return nil
}

View File

@ -55,82 +55,109 @@ func (h *ChatService) Chat(c *websocket.Conn) {
// 创建新的客户端实例
client := gateway.NewClient(c, ctx, cancel)
// 心跳检测
go client.InitHeartbeat(time.Duration(h.cfg.Sys.HeartbeatInterval))
// 将客户端添加到网关管理
// 验证并收集连接数据,后续对话中会使用
if err := client.DataAuth(); err != nil {
log.Println("数据验证错误:", err)
_ = client.SendFunc([]byte(err.Error()))
client.Close()
return
}
// 验证通过后,将客户端添加到网关管理
h.Gw.AddClient(client)
// 使用信号量限制并发处理的消息数量
semaphore := make(chan struct{}, 1) // 最多1个并发消息处理
// 用于等待所有goroutine完成的wait group
var wg sync.WaitGroup
// 确保在函数返回时移除客户端并关闭连接
defer func() {
h.Gw.Cleanup(client.GetID())
close(semaphore) // 关闭信号量通道
wg.Wait() // 等待所有消息处理goroutine完成
}()
// 绑定会话ID
uid := c.Query("x-session")
if uid != "" {
// 绑定会话ID, sessionId 为空时, 则不绑定
if uid := client.GetSession(); uid != "" {
if err := h.Gw.BindUid(client.GetID(), uid); err != nil {
log.Println("绑定UID错误:", err)
}
}
// 验证并收集连接数据,后续对话中会使用
if err := client.DataAuth(); err != nil {
log.Println("数据验证错误:", err)
h.ChatFail(c, err.Error())
return
}
// 开启心跳检测
go client.InitHeartbeat(time.Duration(h.cfg.Sys.HeartbeatInterval))
// 循环读取客户端消息
for {
// 读取消息
messageType, message, err := c.ReadMessage()
messageType, message, err := client.ReadMessage()
if err != nil {
log.Println("读取错误:", err)
log.Printf("读取错误: %v, 客户端ID: %s", err, client.GetID())
break
}
//if string(message) == `{"type":"ping"}` {
// client.LastActive = time.Now()
// if err := c.WriteMessage(websocket.TextMessage, []byte(`{"type":"pong"}`)); err != nil {
// log.Println("Heartbeat response failed", "id", client.GetID(), "err", err)
// return
// }
// continue
//}
// 处理消息
msg, chatType := h.handleMessageToString(c, messageType, message)
if chatType == constants.ConnStatusClosed {
break
}
if chatType == constants.ConnStatusIgnore {
// 处理心跳消息
if messageType == websocket.PingMessage || string(message) == "PING" {
client.LastActive = time.Now()
msgType := websocket.TextMessage
if messageType == websocket.PingMessage {
msgType = websocket.PongMessage
}
if err = client.SendMessage(msgType, []byte(`PONG`)); err != nil {
log.Printf("发送pong消息失败: %v", err)
}
continue
}
log.Printf("收到消息: %s", string(msg))
// 使用信号量限制并发
semaphore <- struct{}{}
wg.Add(1)
go func(msgType int, msg []byte) {
defer func() {
<-semaphore
wg.Done()
// 恢复panic
if r := recover(); r != nil {
log.Printf("消息处理goroutine发生panic: %v", r)
}
}()
// 解析请求
var req entitys.ChatSockRequest
if err = json.Unmarshal(msg, &req); err != nil {
log.Println("JSON parse error:", err)
continue
}
// 消息处理逻辑
h.processMessage(client, msgType, msg)
}(messageType, message)
}
}
// 路由处理请求
err = h.routerBiz.RouteWithSocket(client, &req)
if err != nil {
log.Println("处理失败:", err)
}
// 将消息处理逻辑提取到单独的方法
func (h *ChatService) processMessage(client *gateway.Client, msgType int, msg []byte) {
// 处理消息
processedMsg, _ := h.handleMessageToString(client, msgType, msg)
log.Printf("收到消息:消息类型 %d, 内容 %s, 客户端ID: %s",
msgType, string(processedMsg), client.GetID())
// 解析请求
var req entitys.ChatSockRequest
if err := json.Unmarshal(processedMsg, &req); err != nil {
log.Printf("JSON解析错误: %v, 客户端ID: %s", err, client.GetID())
return
}
// 路由处理请求
if err := h.routerBiz.RouteWithSocket(client, &req); err != nil {
log.Printf("处理失败: %v, 客户端ID: %s", err, client.GetID())
}
}
// handleMessageToString 处理不同类型的WebSocket消息
// 参数:
// - c: WebSocket连接
// - client: 客户端对象
// - msgType: 消息类型
// - msg: 消息内容
//
// 返回:
// - text: 处理后的文本内容
// - chatType: 连接状态
func (h *ChatService) handleMessageToString(c *websocket.Conn, msgType int, msg any) (text []byte, chatType constants.ConnStatus) {
func (h *ChatService) handleMessageToString(client *gateway.Client, msgType int, msg any) (text []byte, chatType constants.ConnStatus) {
switch msgType {
case websocket.TextMessage:
return msg.([]byte), constants.ConnStatusNormal
@ -140,15 +167,13 @@ func (h *ChatService) handleMessageToString(c *websocket.Conn, msgType int, msg
return nil, constants.ConnStatusClosed
case websocket.PingMessage:
// 可选:回复 Pong
c.WriteMessage(websocket.PongMessage, nil)
return nil, constants.ConnStatusIgnore
case websocket.PongMessage:
return nil, constants.ConnStatusIgnore
default:
log.Printf("未知的消息类型: %d", msgType)
return nil, constants.ConnStatusIgnore
}
return msg.([]byte), constants.ConnStatusIgnore
}
func (s *ChatService) Useful(c *fiber.Ctx) error {