feat: 心跳检测
This commit is contained in:
parent
4339b6eee8
commit
f2638b32b5
|
|
@ -19,8 +19,6 @@ import (
|
|||
|
||||
"gitea.cdlsxd.cn/self-tools/l_request"
|
||||
"github.com/gofiber/fiber/v2/log"
|
||||
"github.com/gofiber/websocket/v2"
|
||||
|
||||
"xorm.io/builder"
|
||||
)
|
||||
|
||||
|
|
@ -141,10 +139,10 @@ func (d *Do) loadChatHistory(ctx context.Context, requireData *entitys.RequireDa
|
|||
return nil
|
||||
}
|
||||
|
||||
func (d *Do) MakeCh(c *websocket.Conn, requireData *entitys.RequireData) (ctx context.Context, deferFunc func()) {
|
||||
func (d *Do) MakeCh(client *gateway.Client, requireData *entitys.RequireData) (ctx context.Context, deferFunc func()) {
|
||||
requireData.Ch = make(chan entitys.Response)
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
done := d.startMessageHandler(ctx, c, requireData)
|
||||
done := d.startMessageHandler(ctx, client, requireData)
|
||||
return ctx, func() {
|
||||
close(requireData.Ch) //关闭主通道
|
||||
<-done // 等待消息处理完成
|
||||
|
|
@ -235,7 +233,7 @@ func (d *Do) getTasks(sysId int32) (tasks []model.AiTask, err error) {
|
|||
// startMessageHandler 启动独立的消息处理协程
|
||||
func (d *Do) startMessageHandler(
|
||||
ctx context.Context,
|
||||
c *websocket.Conn,
|
||||
client *gateway.Client,
|
||||
requireData *entitys.RequireData,
|
||||
) <-chan struct{} {
|
||||
done := make(chan struct{})
|
||||
|
|
@ -259,7 +257,7 @@ func (d *Do) startMessageHandler(
|
|||
hisLog.HisId = AiRes.HisID
|
||||
}
|
||||
|
||||
_ = entitys.MsgSend(c, entitys.Response{
|
||||
_ = entitys.MsgSend(client, entitys.Response{
|
||||
Content: pkg.JsonStringIgonErr(hisLog),
|
||||
Type: entitys.ResponseEnd,
|
||||
})
|
||||
|
|
@ -267,7 +265,7 @@ func (d *Do) startMessageHandler(
|
|||
}()
|
||||
|
||||
for v := range requireData.Ch { // 自动检测通道关闭
|
||||
if err := sendWithTimeout(c, v, 2*time.Second); err != nil {
|
||||
if err := sendWithTimeout(client, v, 10*time.Second); err != nil {
|
||||
log.Errorf("Send error: %v", err)
|
||||
return
|
||||
}
|
||||
|
|
@ -281,7 +279,7 @@ func (d *Do) startMessageHandler(
|
|||
}
|
||||
|
||||
// 辅助函数:带超时的 WebSocket 发送
|
||||
func sendWithTimeout(c *websocket.Conn, data entitys.Response, timeout time.Duration) error {
|
||||
func sendWithTimeout(client *gateway.Client, data entitys.Response, timeout time.Duration) error {
|
||||
sendCtx, cancel := context.WithTimeout(context.Background(), timeout)
|
||||
defer cancel()
|
||||
|
||||
|
|
@ -294,7 +292,7 @@ func sendWithTimeout(c *websocket.Conn, data entitys.Response, timeout time.Dura
|
|||
close(done)
|
||||
}()
|
||||
// 如果 MsgSend 阻塞,这里会卡住
|
||||
err := entitys.MsgSend(c, data)
|
||||
err := entitys.MsgSend(client, data)
|
||||
done <- err
|
||||
}()
|
||||
|
||||
|
|
|
|||
|
|
@ -39,11 +39,9 @@ func (r *AiRouterBiz) RouteWithSocket(client *gateway.Client, req *entitys.ChatS
|
|||
requireData := &entitys.RequireData{
|
||||
Req: req,
|
||||
}
|
||||
// 获取WebSocket连接
|
||||
conn := client.GetConn()
|
||||
|
||||
//初始化通道/上下文
|
||||
ctx, clearFunc := r.do.MakeCh(conn, requireData)
|
||||
ctx, clearFunc := r.do.MakeCh(client, requireData)
|
||||
defer func() {
|
||||
if err != nil {
|
||||
entitys.ResError(requireData.Ch, "", err.Error())
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
package entitys
|
||||
|
||||
import (
|
||||
"ai_scheduler/internal/gateway"
|
||||
"encoding/json"
|
||||
"github.com/gofiber/websocket/v2"
|
||||
)
|
||||
|
|
@ -100,13 +101,13 @@ func MsgSet(msgType ResponseType, msg string, done bool) []byte {
|
|||
return jsonByte
|
||||
}
|
||||
|
||||
func MsgSend(c *websocket.Conn, msg Response) error {
|
||||
func MsgSend(client *gateway.Client, msg Response) error {
|
||||
// 检查上下文是否已取消
|
||||
if msg.Type == ResponseText {
|
||||
|
||||
}
|
||||
jsonByte, _ := json.Marshal(msg)
|
||||
return c.WriteMessage(websocket.TextMessage, jsonByte)
|
||||
return client.SendFunc(jsonByte)
|
||||
}
|
||||
|
||||
func MsgSendByte(c *websocket.Conn, msg []byte) {
|
||||
|
|
|
|||
|
|
@ -3,11 +3,11 @@ package gateway
|
|||
import (
|
||||
errors "ai_scheduler/internal/data/error"
|
||||
"ai_scheduler/internal/data/model"
|
||||
"ai_scheduler/internal/pkg"
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"github.com/google/uuid"
|
||||
"log"
|
||||
"math/rand"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gofiber/websocket/v2"
|
||||
|
|
@ -32,6 +32,7 @@ type Client struct {
|
|||
Ctx context.Context
|
||||
Cancel context.CancelFunc
|
||||
LastActive time.Time
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func NewClient(conn *websocket.Conn, ctx context.Context, cancel context.CancelFunc) *Client {
|
||||
|
|
@ -41,6 +42,7 @@ func NewClient(conn *websocket.Conn, ctx context.Context, cancel context.CancelF
|
|||
conn: conn,
|
||||
Ctx: ctx,
|
||||
Cancel: cancel,
|
||||
mu: sync.Mutex{},
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -74,7 +76,7 @@ func (c *Client) GetCodes() []string {
|
|||
return c.codes
|
||||
}
|
||||
|
||||
// GetSysCode 获取系统编码
|
||||
// 获取系统编码
|
||||
func (c *Client) GetSysCode() string {
|
||||
return c.sysCode
|
||||
}
|
||||
|
|
@ -104,26 +106,51 @@ func (c *Client) SetCodes(codes []string) {
|
|||
c.codes = codes
|
||||
}
|
||||
|
||||
// Close 关闭客户端连接
|
||||
func (c *Client) Close() {
|
||||
//c.mu.Lock()
|
||||
//defer c.mu.Unlock()
|
||||
if c.conn != nil {
|
||||
_ = c.conn.Close()
|
||||
c.conn = nil
|
||||
}
|
||||
}
|
||||
|
||||
// SendFunc 发送消息到客户端
|
||||
func (c *Client) SendFunc(msg []byte) error {
|
||||
if c.conn != nil {
|
||||
return c.conn.WriteMessage(websocket.TextMessage, msg)
|
||||
return c.SendMessage(websocket.TextMessage, msg)
|
||||
}
|
||||
|
||||
// 在Client结构体中添加更详细的日志
|
||||
func (c *Client) SendMessage(msgType int, msg []byte) error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
if c.conn == nil {
|
||||
return ErrConnClosed
|
||||
}
|
||||
return ErrConnClosed
|
||||
|
||||
err := c.conn.WriteMessage(msgType, msg)
|
||||
if err != nil {
|
||||
log.Printf("发送消息失败: %v, 客户端ID: %s, 消息类型: %d",
|
||||
err, c.id, msgType)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// 生成唯一的客户端ID
|
||||
func generateClientID() string {
|
||||
// 1. 时间戳
|
||||
timestamp := time.Now().UnixNano()
|
||||
binary.BigEndian.PutUint64(idBuf[:8], uint64(timestamp))
|
||||
|
||||
// 2. 随机数(4字节)
|
||||
binary.BigEndian.PutUint32(idBuf[8:12], rng.Uint32())
|
||||
|
||||
// 3. 十六进制编码
|
||||
n := pkg.HexEncode(idBuf[:12], idBuf[12:])
|
||||
return string(idBuf[12 : 12+n])
|
||||
return uuid.New().String()
|
||||
//// 1. 时间戳
|
||||
//timestamp := time.Now().UnixNano()
|
||||
//binary.BigEndian.PutUint64(idBuf[:8], uint64(timestamp))
|
||||
//
|
||||
//// 2. 随机数(4字节)
|
||||
//binary.BigEndian.PutUint32(idBuf[8:12], rng.Uint32())
|
||||
//
|
||||
//// 3. 十六进制编码
|
||||
//n := pkg.HexEncode(idBuf[:12], idBuf[12:])
|
||||
//return string(idBuf[12 : 12+n])
|
||||
}
|
||||
|
||||
// 连接数据验证和收集
|
||||
|
|
@ -152,6 +179,7 @@ func (c *Client) DataAuth() (err error) {
|
|||
return
|
||||
}
|
||||
|
||||
// 总结:目前绝大多数浏览器不支持直接发送WebSocket Ping帧,因此在实际开发中,应该实现应用层ping机制作为主要心跳检测方案,todo 同时保留对未来可能的原生支持的兼容检测。
|
||||
func (c *Client) InitHeartbeat(timeoutSecond time.Duration) {
|
||||
ticker := time.NewTicker(timeoutSecond * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
|
@ -160,9 +188,12 @@ func (c *Client) InitHeartbeat(timeoutSecond time.Duration) {
|
|||
case <-ticker.C:
|
||||
//*2是防止丢包,连续丢包两次,再加5s网络延迟容错
|
||||
if time.Since(c.LastActive) > (timeoutSecond*2*time.Second + 5) { // 5秒容错
|
||||
log.Println("Heartbeat timeout", "id", c.id)
|
||||
c.conn.WriteMessage(websocket.CloseMessage, []byte("Heartbeat timeout"))
|
||||
c.conn.Close()
|
||||
log.Println("心跳超时", "clientId", c.id)
|
||||
err := c.SendMessage(websocket.CloseMessage, []byte("Heartbeat timeout"))
|
||||
if err != nil {
|
||||
log.Println("发送心跳超时消息失败", err)
|
||||
}
|
||||
c.Close()
|
||||
return
|
||||
}
|
||||
case <-c.Ctx.Done():
|
||||
|
|
@ -170,3 +201,11 @@ func (c *Client) InitHeartbeat(timeoutSecond time.Duration) {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 在Client结构体中添加ReadMessage方法
|
||||
func (c *Client) ReadMessage() (messageType int, message []byte, err error) {
|
||||
if c.conn == nil {
|
||||
return 0, nil, ErrConnClosed
|
||||
}
|
||||
return c.conn.ReadMessage()
|
||||
}
|
||||
|
|
|
|||
|
|
@ -34,15 +34,17 @@ func (g *Gateway) AddClient(c *Client) {
|
|||
|
||||
func (g *Gateway) Cleanup(clientID string) {
|
||||
g.mu.Lock()
|
||||
// 从网关管理中移除客户端
|
||||
defer func() {
|
||||
if c, ex := g.clients[clientID]; ex {
|
||||
delete(g.clients, clientID)
|
||||
_ = c.conn.Close()
|
||||
c.Close()
|
||||
c.Cancel()
|
||||
}
|
||||
g.mu.Unlock()
|
||||
log.Println("client disconnected:", clientID)
|
||||
}()
|
||||
// 从所有绑定的UID列表中移除该客户端
|
||||
for uid, list := range g.uidMap {
|
||||
newList := []string{}
|
||||
for _, cid := range list {
|
||||
|
|
@ -79,7 +81,7 @@ func (g *Gateway) BindUid(clientID, uid string) error {
|
|||
return errors.New("client not found")
|
||||
}
|
||||
g.uidMap[uid] = append(g.uidMap[uid], clientID)
|
||||
log.Printf("bind %s -> uid:%s\n", clientID, uid)
|
||||
log.Printf("绑定 clientId %s -> uid:%s\n", clientID, uid)
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -55,82 +55,109 @@ func (h *ChatService) Chat(c *websocket.Conn) {
|
|||
|
||||
// 创建新的客户端实例
|
||||
client := gateway.NewClient(c, ctx, cancel)
|
||||
// 心跳检测
|
||||
go client.InitHeartbeat(time.Duration(h.cfg.Sys.HeartbeatInterval))
|
||||
// 将客户端添加到网关管理
|
||||
|
||||
// 验证并收集连接数据,后续对话中会使用
|
||||
if err := client.DataAuth(); err != nil {
|
||||
log.Println("数据验证错误:", err)
|
||||
_ = client.SendFunc([]byte(err.Error()))
|
||||
client.Close()
|
||||
return
|
||||
}
|
||||
|
||||
// 验证通过后,将客户端添加到网关管理
|
||||
h.Gw.AddClient(client)
|
||||
|
||||
// 使用信号量限制并发处理的消息数量
|
||||
semaphore := make(chan struct{}, 1) // 最多1个并发消息处理
|
||||
// 用于等待所有goroutine完成的wait group
|
||||
var wg sync.WaitGroup
|
||||
// 确保在函数返回时移除客户端并关闭连接
|
||||
defer func() {
|
||||
h.Gw.Cleanup(client.GetID())
|
||||
close(semaphore) // 关闭信号量通道
|
||||
wg.Wait() // 等待所有消息处理goroutine完成
|
||||
}()
|
||||
|
||||
// 绑定会话ID
|
||||
uid := c.Query("x-session")
|
||||
if uid != "" {
|
||||
// 绑定会话ID, sessionId 为空时, 则不绑定
|
||||
if uid := client.GetSession(); uid != "" {
|
||||
if err := h.Gw.BindUid(client.GetID(), uid); err != nil {
|
||||
log.Println("绑定UID错误:", err)
|
||||
}
|
||||
}
|
||||
|
||||
// 验证并收集连接数据,后续对话中会使用
|
||||
if err := client.DataAuth(); err != nil {
|
||||
log.Println("数据验证错误:", err)
|
||||
h.ChatFail(c, err.Error())
|
||||
return
|
||||
}
|
||||
// 开启心跳检测
|
||||
go client.InitHeartbeat(time.Duration(h.cfg.Sys.HeartbeatInterval))
|
||||
|
||||
// 循环读取客户端消息
|
||||
for {
|
||||
// 读取消息
|
||||
messageType, message, err := c.ReadMessage()
|
||||
messageType, message, err := client.ReadMessage()
|
||||
if err != nil {
|
||||
log.Println("读取错误:", err)
|
||||
log.Printf("读取错误: %v, 客户端ID: %s", err, client.GetID())
|
||||
break
|
||||
}
|
||||
//if string(message) == `{"type":"ping"}` {
|
||||
// client.LastActive = time.Now()
|
||||
// if err := c.WriteMessage(websocket.TextMessage, []byte(`{"type":"pong"}`)); err != nil {
|
||||
// log.Println("Heartbeat response failed", "id", client.GetID(), "err", err)
|
||||
// return
|
||||
// }
|
||||
// continue
|
||||
//}
|
||||
// 处理消息
|
||||
msg, chatType := h.handleMessageToString(c, messageType, message)
|
||||
if chatType == constants.ConnStatusClosed {
|
||||
break
|
||||
}
|
||||
if chatType == constants.ConnStatusIgnore {
|
||||
|
||||
// 处理心跳消息
|
||||
if messageType == websocket.PingMessage || string(message) == "PING" {
|
||||
client.LastActive = time.Now()
|
||||
msgType := websocket.TextMessage
|
||||
if messageType == websocket.PingMessage {
|
||||
msgType = websocket.PongMessage
|
||||
}
|
||||
if err = client.SendMessage(msgType, []byte(`PONG`)); err != nil {
|
||||
log.Printf("发送pong消息失败: %v", err)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
log.Printf("收到消息: %s", string(msg))
|
||||
// 使用信号量限制并发
|
||||
semaphore <- struct{}{}
|
||||
wg.Add(1)
|
||||
go func(msgType int, msg []byte) {
|
||||
defer func() {
|
||||
<-semaphore
|
||||
wg.Done()
|
||||
// 恢复panic
|
||||
if r := recover(); r != nil {
|
||||
log.Printf("消息处理goroutine发生panic: %v", r)
|
||||
}
|
||||
}()
|
||||
|
||||
// 解析请求
|
||||
var req entitys.ChatSockRequest
|
||||
if err = json.Unmarshal(msg, &req); err != nil {
|
||||
log.Println("JSON parse error:", err)
|
||||
continue
|
||||
}
|
||||
// 消息处理逻辑
|
||||
h.processMessage(client, msgType, msg)
|
||||
}(messageType, message)
|
||||
}
|
||||
}
|
||||
|
||||
// 路由处理请求
|
||||
err = h.routerBiz.RouteWithSocket(client, &req)
|
||||
if err != nil {
|
||||
log.Println("处理失败:", err)
|
||||
}
|
||||
// 将消息处理逻辑提取到单独的方法
|
||||
func (h *ChatService) processMessage(client *gateway.Client, msgType int, msg []byte) {
|
||||
// 处理消息
|
||||
processedMsg, _ := h.handleMessageToString(client, msgType, msg)
|
||||
log.Printf("收到消息:消息类型 %d, 内容 %s, 客户端ID: %s",
|
||||
msgType, string(processedMsg), client.GetID())
|
||||
|
||||
// 解析请求
|
||||
var req entitys.ChatSockRequest
|
||||
if err := json.Unmarshal(processedMsg, &req); err != nil {
|
||||
log.Printf("JSON解析错误: %v, 客户端ID: %s", err, client.GetID())
|
||||
return
|
||||
}
|
||||
|
||||
// 路由处理请求
|
||||
if err := h.routerBiz.RouteWithSocket(client, &req); err != nil {
|
||||
log.Printf("处理失败: %v, 客户端ID: %s", err, client.GetID())
|
||||
}
|
||||
}
|
||||
|
||||
// handleMessageToString 处理不同类型的WebSocket消息
|
||||
// 参数:
|
||||
// - c: WebSocket连接
|
||||
// - client: 客户端对象
|
||||
// - msgType: 消息类型
|
||||
// - msg: 消息内容
|
||||
//
|
||||
// 返回:
|
||||
// - text: 处理后的文本内容
|
||||
// - chatType: 连接状态
|
||||
func (h *ChatService) handleMessageToString(c *websocket.Conn, msgType int, msg any) (text []byte, chatType constants.ConnStatus) {
|
||||
func (h *ChatService) handleMessageToString(client *gateway.Client, msgType int, msg any) (text []byte, chatType constants.ConnStatus) {
|
||||
switch msgType {
|
||||
case websocket.TextMessage:
|
||||
return msg.([]byte), constants.ConnStatusNormal
|
||||
|
|
@ -140,15 +167,13 @@ func (h *ChatService) handleMessageToString(c *websocket.Conn, msgType int, msg
|
|||
|
||||
return nil, constants.ConnStatusClosed
|
||||
case websocket.PingMessage:
|
||||
// 可选:回复 Pong
|
||||
c.WriteMessage(websocket.PongMessage, nil)
|
||||
return nil, constants.ConnStatusIgnore
|
||||
case websocket.PongMessage:
|
||||
return nil, constants.ConnStatusIgnore
|
||||
default:
|
||||
log.Printf("未知的消息类型: %d", msgType)
|
||||
return nil, constants.ConnStatusIgnore
|
||||
}
|
||||
return msg.([]byte), constants.ConnStatusIgnore
|
||||
}
|
||||
|
||||
func (s *ChatService) Useful(c *fiber.Ctx) error {
|
||||
|
|
|
|||
Loading…
Reference in New Issue