package gateway

import (
	"context"
	"encoding/binary"
	"log"
	"math/rand"
	"sync"
	"time"

	"github.com/gofiber/websocket/v2"

	errors "ai_scheduler/internal/data/error"
	"ai_scheduler/internal/data/model"
	"ai_scheduler/internal/pkg"
)

const (
	// clientIDRawLen is the number of raw bytes in a client ID:
	// an 8-byte nanosecond timestamp followed by a 4-byte random number.
	clientIDRawLen = 12
	// clientIDHexLen is the room reserved for the hex-encoded ID
	// (two hex digits per raw byte).
	clientIDHexLen = 2 * clientIDRawLen
	// heartbeatGrace is the extra tolerance for network latency that is added
	// on top of two missed heartbeat intervals.
	heartbeatGrace = 5 * time.Second
)

var (
	// ErrConnClosed is returned when the client has no usable connection.
	ErrConnClosed = errors.SysErr("连接不存在或已关闭")

	// clientIDMu guards rng and idBuf. A *rand.Rand is not safe for concurrent
	// use, and idBuf is shared scratch space written on every ID generation.
	clientIDMu sync.Mutex
	rng        = rand.New(rand.NewSource(time.Now().UnixNano()))
	// idBuf layout: [0, clientIDRawLen) raw ID bytes, followed by
	// clientIDHexLen bytes of hex output.
	idBuf = make([]byte, clientIDRawLen+clientIDHexLen)
)

// Client holds the state associated with one WebSocket connection.
type Client struct {
	id      string          // unique client ID
	conn    *websocket.Conn // WebSocket connection
	session string          // session ID
	key     string          // application key
	auth    string          // user credential token
	codes   []string        // user permission codes
	sysInfo *model.AiSy     // system information
	tasks   []model.AiTask  // task list
	sysCode string          // system code

	Ctx    context.Context
	Cancel context.CancelFunc
	// LastActive is compared against the heartbeat deadline in InitHeartbeat.
	// NOTE(review): presumably refreshed by the connection's read loop on every
	// inbound message; access is not synchronized here — confirm with callers.
	LastActive time.Time
}

// NewClient builds a Client for conn with a freshly generated ID.
//
// LastActive starts at the current time so that a client which has not sent
// anything yet is not treated as timed out on the first heartbeat tick.
func NewClient(conn *websocket.Conn, ctx context.Context, cancel context.CancelFunc) *Client {
	return &Client{
		id:         generateClientID(),
		conn:       conn,
		Ctx:        ctx,
		Cancel:     cancel,
		LastActive: time.Now(),
	}
}

// GetID returns the client's unique ID.
func (c *Client) GetID() string {
	return c.id
}

// GetConn returns the client's WebSocket connection.
func (c *Client) GetConn() *websocket.Conn {
	return c.conn
}

// GetSession returns the session ID.
func (c *Client) GetSession() string {
	return c.session
}

// GetKey returns the application key.
func (c *Client) GetKey() string {
	return c.key
}

// GetAuth returns the user credential token.
func (c *Client) GetAuth() string {
	return c.auth
}

// GetCodes returns the user's permission codes. The slice is not copied.
func (c *Client) GetCodes() []string {
	return c.codes
}

// GetSysCode returns the system code.
func (c *Client) GetSysCode() string {
	return c.sysCode
}

// GetSysInfo returns the system information.
func (c *Client) GetSysInfo() *model.AiSy {
	return c.sysInfo
}

// SetSysInfo stores the system information.
func (c *Client) SetSysInfo(sysInfo *model.AiSy) {
	c.sysInfo = sysInfo
}

// GetTasks returns the task list. The slice is not copied.
func (c *Client) GetTasks() []model.AiTask {
	return c.tasks
}

// SetTasks stores the task list.
func (c *Client) SetTasks(tasks []model.AiTask) {
	c.tasks = tasks
}

// SetCodes stores the user's permission codes.
func (c *Client) SetCodes(codes []string) {
	c.codes = codes
}

// SendFunc writes msg to the client as a text message.
// It returns ErrConnClosed when the client has no connection.
func (c *Client) SendFunc(msg []byte) error {
	if c.conn == nil {
		return ErrConnClosed
	}
	return c.conn.WriteMessage(websocket.TextMessage, msg)
}

// generateClientID returns a client ID built from the current nanosecond
// timestamp (8 bytes, big-endian) and a random uint32 (4 bytes, big-endian),
// encoded by pkg.HexEncode.
//
// The shared generator and scratch buffer are protected by clientIDMu so that
// connections accepted concurrently cannot corrupt each other's IDs.
func generateClientID() string {
	clientIDMu.Lock()
	defer clientIDMu.Unlock()

	raw, dst := idBuf[:clientIDRawLen], idBuf[clientIDRawLen:]

	// 1. Timestamp.
	binary.BigEndian.PutUint64(raw[:8], uint64(time.Now().UnixNano()))
	// 2. Random number (4 bytes).
	binary.BigEndian.PutUint32(raw[8:], rng.Uint32())
	// 3. Hex encoding. dst has room for two digits per raw byte.
	// NOTE(review): assumes pkg.HexEncode(src, dst) returns the number of
	// bytes written to dst — confirm against its definition.
	n := pkg.HexEncode(raw, dst)

	// string() copies, so the result stays valid after the lock is released.
	return string(dst[:n])
}

// DataAuth reads the connection's query parameters into the client and
// validates that each one is present.
//
// The parameters are checked in order: x-session, x-authorization, x-app-key,
// x-sys-code. The first missing one yields its corresponding error; fields
// read before the failure stay populated. ErrConnClosed is returned when the
// client has no connection.
func (c *Client) DataAuth() (err error) {
	if c.conn == nil {
		return ErrConnClosed
	}

	if c.session = c.conn.Query("x-session", ""); len(c.session) == 0 {
		return errors.SessionNotFound
	}
	if c.auth = c.conn.Query("x-authorization", ""); len(c.auth) == 0 {
		return errors.AuthNotFound
	}
	if c.key = c.conn.Query("x-app-key", ""); len(c.key) == 0 {
		return errors.KeyNotFound
	}
	// System code.
	if c.sysCode = c.conn.Query("x-sys-code", ""); len(c.sysCode) == 0 {
		return errors.SysCodeNotFound
	}
	return nil
}

// InitHeartbeat watches the client for inactivity and blocks until the
// connection times out or c.Ctx is cancelled.
//
// timeoutSecond is a plain number of seconds carried in a time.Duration (it is
// multiplied by time.Second here), so pass e.g. 30, not 30*time.Second.
// On every tick the client is considered dead when LastActive is older than
// two intervals plus heartbeatGrace; a close frame is then sent and the
// connection is closed. A non-positive interval is logged and ignored because
// time.NewTicker would panic on it.
func (c *Client) InitHeartbeat(timeoutSecond time.Duration) {
	interval := timeoutSecond * time.Second
	if interval <= 0 {
		log.Println("Heartbeat disabled: non-positive interval", "id", c.id)
		return
	}
	// Two intervals tolerate one lost heartbeat; the grace period absorbs
	// network latency.
	deadline := 2*interval + heartbeatGrace

	ticker := time.NewTicker(interval)
	defer ticker.Stop()

	for {
		select {
		case <-ticker.C:
			if time.Since(c.LastActive) <= deadline {
				continue
			}
			log.Println("Heartbeat timeout", "id", c.id)
			// A close frame payload must start with a 2-byte status code, so
			// the reason text is wrapped with FormatCloseMessage.
			closeMsg := websocket.FormatCloseMessage(websocket.ClosePolicyViolation, "Heartbeat timeout")
			// Best effort: the peer is presumed gone, so write/close failures
			// are deliberately ignored.
			_ = c.conn.WriteMessage(websocket.CloseMessage, closeMsg)
			_ = c.conn.Close()
			return
		case <-c.Ctx.Done():
			return
		}
	}
}