ai_scheduler/internal/gateway/client.go

173 lines
3.8 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package gateway
import (
	"context"
	"encoding/binary"
	"log"
	"math/rand"
	"sync"
	"time"

	errors "ai_scheduler/internal/data/error"
	"ai_scheduler/internal/data/model"
	"ai_scheduler/internal/pkg"

	"github.com/gofiber/websocket/v2"
)
// ErrConnClosed is returned by SendFunc when the client holds no WebSocket
// connection.
var ErrConnClosed = errors.SysErr("连接不存在或已关闭")

// rng is a package-level pseudo-random source seeded from the wall clock at
// package initialization. A *rand.Rand is not safe for concurrent use, so
// every caller must serialize access to it.
var rng = rand.New(rand.NewSource(time.Now().UnixNano()))

// idBuf is a 20-byte package-level scratch buffer. It is shared mutable state
// and must not be written by several goroutines at once.
var idBuf = make([]byte, 20)
// Client holds the per-connection state of one WebSocket client handled by the
// gateway: its identity, the credentials read by DataAuth, and the data
// attached to it afterwards (system info, tasks, permission codes).
type Client struct {
	id      string          // unique client ID, generated by NewClient
	conn    *websocket.Conn // WebSocket connection
	session string          // session ID (query parameter x-session)
	key     string          // application key (query parameter x-app-key)
	auth    string          // user credential token (query parameter x-authorization)
	codes   []string        // user permission codes
	sysInfo *model.AiSy     // system information
	tasks   []model.AiTask  // task list
	sysCode string          // system code (query parameter x-sys-code)
	// Ctx is watched by InitHeartbeat, which returns once it is done.
	Ctx context.Context
	// Cancel is the cancel function supplied to NewClient; it is not called in
	// the code visible here — presumably the owner calls it on disconnect (confirm).
	Cancel context.CancelFunc
	// LastActive is the time of the client's most recent activity. InitHeartbeat
	// closes the connection when it becomes too old. NOTE(review): it is read by
	// the heartbeat goroutine without synchronization; confirm writers are safe.
	LastActive time.Time
}
// NewClient wraps an upgraded WebSocket connection in a Client with a freshly
// generated unique ID. ctx and cancel are stored as the client's Ctx and Cancel.
//
// LastActive is initialized to the creation time. Otherwise a heartbeat monitor
// (InitHeartbeat) started before any activity is recorded would see the zero
// time.Time, treat the client as idle since year 1, and close the connection
// on its first tick.
func NewClient(conn *websocket.Conn, ctx context.Context, cancel context.CancelFunc) *Client {
	return &Client{
		id:         generateClientID(),
		conn:       conn,
		Ctx:        ctx,
		Cancel:     cancel,
		LastActive: time.Now(),
	}
}
// GetID returns the unique identifier assigned to the client by NewClient.
func (c *Client) GetID() string { return c.id }

// GetConn returns the client's underlying WebSocket connection.
func (c *Client) GetConn() *websocket.Conn { return c.conn }

// GetSession returns the session ID that DataAuth read from the x-session
// query parameter.
func (c *Client) GetSession() string { return c.session }

// GetKey returns the application key that DataAuth read from the x-app-key
// query parameter.
func (c *Client) GetKey() string { return c.key }

// GetAuth returns the user credential token that DataAuth read from the
// x-authorization query parameter.
func (c *Client) GetAuth() string { return c.auth }

// GetCodes returns the user's permission codes. The stored slice is returned
// as-is, not copied.
func (c *Client) GetCodes() []string { return c.codes }

// GetSysCode returns the system code that DataAuth read from the x-sys-code
// query parameter.
func (c *Client) GetSysCode() string { return c.sysCode }
// GetSysInfo returns the system information stored with SetSysInfo, or nil if
// none has been set.
func (c *Client) GetSysInfo() *model.AiSy { return c.sysInfo }

// SetSysInfo attaches system information to the client.
func (c *Client) SetSysInfo(info *model.AiSy) { c.sysInfo = info }

// GetTasks returns the client's task list. The stored slice is returned as-is,
// not copied.
func (c *Client) GetTasks() []model.AiTask { return c.tasks }

// SetTasks replaces the client's task list. The slice is stored without copying.
func (c *Client) SetTasks(list []model.AiTask) { c.tasks = list }

// SetCodes replaces the user's permission codes. The slice is stored without
// copying.
func (c *Client) SetCodes(list []string) { c.codes = list }
// SendFunc writes msg to the client as a single WebSocket text message.
// It returns ErrConnClosed when the client has no connection; otherwise it
// returns whatever error the underlying write produces.
func (c *Client) SendFunc(msg []byte) error {
	conn := c.conn
	if conn == nil {
		return ErrConnClosed
	}
	return conn.WriteMessage(websocket.TextMessage, msg)
}
// rngMu serializes access to the package-level rng. A *rand.Rand is not safe
// for concurrent use, and client IDs are generated once per incoming connection.
var rngMu sync.Mutex

// generateClientID builds a client ID from 12 raw bytes: the current UnixNano
// timestamp (8 bytes, big-endian) followed by 4 random bytes (big-endian). It
// then encodes them with pkg.HexEncode and returns the encoded tail of the buffer.
//
// Each call uses its own scratch buffer and takes rngMu around rng. Concurrent
// callers (e.g. NewClient from several connection handlers) therefore do not
// race on shared package state.
func generateClientID() string {
	// 12 raw bytes followed by 8 bytes of encoder output space.
	var buf [20]byte

	// 1. Timestamp.
	binary.BigEndian.PutUint64(buf[:8], uint64(time.Now().UnixNano()))

	// 2. Four random bytes.
	rngMu.Lock()
	r := rng.Uint32()
	rngMu.Unlock()
	binary.BigEndian.PutUint32(buf[8:12], r)

	// 3. Encode the raw bytes into the tail of buf.
	// NOTE(review): standard hex of 12 bytes needs 24 output bytes, but only 8
	// remain after index 12. Confirm how pkg.HexEncode handles a short dst.
	n := pkg.HexEncode(buf[:12], buf[12:])
	return string(buf[12 : 12+n])
}
// DataAuth reads the mandatory connection parameters from the WebSocket query
// string and stores them on the client. The parameters are read in this order:
// x-session, x-authorization, x-app-key, x-sys-code (the system code).
//
// It stops at the first parameter that is missing or empty and returns the
// matching error: errors.SessionNotFound, errors.AuthNotFound,
// errors.KeyNotFound or errors.SysCodeNotFound. Fields read before that point
// stay populated. It returns nil when all four are present.
func (c *Client) DataAuth() (err error) {
	required := []struct {
		param   string
		dst     *string
		missing error
	}{
		{"x-session", &c.session, errors.SessionNotFound},
		{"x-authorization", &c.auth, errors.AuthNotFound},
		{"x-app-key", &c.key, errors.KeyNotFound},
		{"x-sys-code", &c.sysCode, errors.SysCodeNotFound},
	}
	for _, r := range required {
		*r.dst = c.conn.Query(r.param, "")
		if *r.dst == "" {
			return r.missing
		}
	}
	return nil
}
// InitHeartbeat blocks and checks the client's liveness every timeoutSecond
// seconds. timeoutSecond is a plain count of seconds: it is multiplied by
// time.Second, so pass e.g. 30, not 30*time.Second. It must be positive,
// because time.NewTicker panics on a non-positive interval.
//
// On each tick, if LastActive is older than two intervals plus 5 seconds, it
// sends a close message, closes the connection and returns. The two intervals
// tolerate two consecutive lost heartbeats; the 5 seconds cover network
// latency. It also returns as soon as c.Ctx is done.
func (c *Client) InitHeartbeat(timeoutSecond time.Duration) {
	interval := timeoutSecond * time.Second
	// Two missed intervals plus 5 seconds of slack. The slack must be
	// 5*time.Second; a bare 5 would be 5 nanoseconds.
	maxIdle := 2*interval + 5*time.Second

	ticker := time.NewTicker(interval)
	defer ticker.Stop()
	for {
		select {
		case <-ticker.C:
			if time.Since(c.LastActive) > maxIdle {
				log.Println("Heartbeat timeout", "id", c.id)
				// Best effort: the connection is torn down regardless, so
				// write/close errors are deliberately ignored.
				// NOTE(review): RFC 6455 expects a close payload to start with a
				// 2-byte status code; consider websocket.FormatCloseMessage.
				_ = c.conn.WriteMessage(websocket.CloseMessage, []byte("Heartbeat timeout"))
				_ = c.conn.Close()
				return
			}
		case <-c.Ctx.Done():
			return
		}
	}
}