From 4339b6eee80e5bacaa33061c7b4e5b0a7e70a3bd Mon Sep 17 00:00:00 2001 From: renzhiyuan <465386466@qq.com> Date: Thu, 4 Dec 2025 10:14:57 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E4=BC=98=E5=8C=96=E4=BB=BB=E5=8A=A1?= =?UTF-8?q?=E6=A8=A1=E5=9E=8B=E4=B8=8E=E6=9D=83=E9=99=90=E6=8E=A7=E5=88=B6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- config/config_test.yaml | 4 +- internal/biz/llm_service/ollama.go | 1 + internal/config/config.go | 9 ++-- internal/gateway/client.go | 74 ++++++++++++++++++++++-------- internal/gateway/gateway.go | 25 ++++++++-- internal/pkg/func.go | 10 ++++ internal/server/router/router.go | 5 +- internal/services/chat.go | 48 ++++++++----------- 8 files changed, 114 insertions(+), 62 deletions(-) diff --git a/config/config_test.yaml b/config/config_test.yaml index 8275102..2929dd7 100644 --- a/config/config_test.yaml +++ b/config/config_test.yaml @@ -3,11 +3,12 @@ server: port: 8090 host: "0.0.0.0" + ollama: base_url: "http://127.0.0.1:11434" model: "qwen3-coder:480b-cloud" generate_model: "qwen3-coder:480b-cloud" - vl_model: "qwen2.5vl:7b" + vl_model: "gemini-3-pro-preview" timeout: "120s" level: "info" format: "json" @@ -19,6 +20,7 @@ sys: channel_pool_len: 100 channel_pool_size: 32 llm_pool_len: 5 + heartbeat_interval: 30 redis: host: 47.97.27.195:6379 type: node diff --git a/internal/biz/llm_service/ollama.go b/internal/biz/llm_service/ollama.go index 60d9c78..6527a3b 100644 --- a/internal/biz/llm_service/ollama.go +++ b/internal/biz/llm_service/ollama.go @@ -139,6 +139,7 @@ func (r *OllamaService) RecognizeWithImg(ctx context.Context, requireData *entit Prompt: r.config.DefaultPrompt.ImgRecognize.UserPrompt, Images: requireData.ImgByte, KeepAlive: &api.Duration{Duration: 3600 * time.Second}, + //Think: &api.ThinkValue{Value: false}, }) if err != nil { return diff --git a/internal/config/config.go b/internal/config/config.go index 5c9fa21..da25c39 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -36,10 +36,11 @@ type LLM struct { // SysConfig 系统配置 type SysConfig struct { - SessionLen int `mapstructure:"session_len"` - ChannelPoolLen int `mapstructure:"channel_pool_len"` - ChannelPoolSize int `mapstructure:"channel_pool_size"` - LlmPoolLen int `mapstructure:"llm_pool_len"` + SessionLen int `mapstructure:"session_len"` + ChannelPoolLen int `mapstructure:"channel_pool_len"` + ChannelPoolSize int `mapstructure:"channel_pool_size"` + LlmPoolLen int `mapstructure:"llm_pool_len"` + HeartbeatInterval int `mapstructure:"heartbeat_interval"` } // ServerConfig 服务器配置 diff --git a/internal/gateway/client.go b/internal/gateway/client.go index b49bf45..f90678c 100644 --- a/internal/gateway/client.go +++ b/internal/gateway/client.go @@ -3,33 +3,44 @@ package gateway import ( errors "ai_scheduler/internal/data/error" "ai_scheduler/internal/data/model" - "encoding/hex" - "fmt" - "github.com/gofiber/websocket/v2" + "ai_scheduler/internal/pkg" + "context" + "encoding/binary" + "log" "math/rand" "time" + + "github.com/gofiber/websocket/v2" ) var ( ErrConnClosed = errors.SysErr("连接不存在或已关闭") + rng = rand.New(rand.NewSource(time.Now().UnixNano())) + idBuf = make([]byte, 20) ) type Client struct { - id string // 客户端唯一ID - conn *websocket.Conn // WebSocket 连接 - session string // 会话ID - key string // 应用密钥 - auth string // 用户凭证token - codes []string // 用户权限code - sysInfo *model.AiSy // 系统信息 - tasks []model.AiTask // 任务列表 - sysCode string // 系统编码 + id string // 客户端唯一ID + conn *websocket.Conn // WebSocket 连接 + session string // 会话ID + key string // 应用密钥 + auth string // 用户凭证token + codes []string // 用户权限code + sysInfo *model.AiSy // 系统信息 + tasks []model.AiTask // 任务列表 + sysCode string // 系统编码 + Ctx context.Context + Cancel context.CancelFunc + LastActive time.Time } -func NewClient(conn *websocket.Conn) *Client { +func NewClient(conn *websocket.Conn, ctx context.Context, cancel context.CancelFunc) *Client { + return &Client{ - id: generateClientID(), - conn: conn, + id: generateClientID(), + conn: conn, + Ctx: ctx, + Cancel: cancel, } } @@ -103,12 +114,16 @@ func (c *Client) SendFunc(msg []byte) error { // 生成唯一的客户端ID func generateClientID() string { - // 使用时间戳+随机数确保唯一性 + // 1. 时间戳 timestamp := time.Now().UnixNano() - randomBytes := make([]byte, 4) - rand.Read(randomBytes) - randomStr := hex.EncodeToString(randomBytes) - return fmt.Sprintf("%d%s", timestamp, randomStr) + binary.BigEndian.PutUint64(idBuf[:8], uint64(timestamp)) + + // 2. 随机数(4字节) + binary.BigEndian.PutUint32(idBuf[8:12], rng.Uint32()) + + // 3. 十六进制编码 + n := pkg.HexEncode(idBuf[:12], idBuf[12:]) + return string(idBuf[12 : 12+n]) } // 连接数据验证和收集 @@ -136,3 +151,22 @@ func (c *Client) DataAuth() (err error) { } return } + +func (c *Client) InitHeartbeat(timeoutSecond time.Duration) { + ticker := time.NewTicker(timeoutSecond * time.Second) + defer ticker.Stop() + for { + select { + case <-ticker.C: + //*2是防止丢包,连续丢包两次,再加5s网络延迟容错 + if time.Since(c.LastActive) > (timeoutSecond*2*time.Second + 5) { // 5秒容错 + log.Println("Heartbeat timeout", "id", c.id) + c.conn.WriteMessage(websocket.CloseMessage, []byte("Heartbeat timeout")) + c.conn.Close() + return + } + case <-c.Ctx.Done(): + return + } + } +} diff --git a/internal/gateway/gateway.go b/internal/gateway/gateway.go index 0f0e6f3..03ce961 100644 --- a/internal/gateway/gateway.go +++ b/internal/gateway/gateway.go @@ -2,7 +2,9 @@ package gateway import ( "errors" + "log" "sync" + "time" ) type Gateway struct { @@ -20,14 +22,27 @@ func NewGateway() *Gateway { func (g *Gateway) AddClient(c *Client) { g.mu.Lock() - defer g.mu.Unlock() + defer func() { + g.mu.Unlock() + //心跳开始计时 + c.LastActive = time.Now() + log.Println("client connected:", c.GetID()) + log.Println("客户端已连接") + }() g.clients[c.GetID()] = c } -func (g *Gateway) RemoveClient(clientID string) { +func (g *Gateway) Cleanup(clientID string) { g.mu.Lock() - defer g.mu.Unlock() - delete(g.clients, clientID) + defer func() { + if c, ex := g.clients[clientID]; ex { + delete(g.clients, clientID) + _ = c.conn.Close() + c.Cancel() + } + g.mu.Unlock() + log.Println("client disconnected:", clientID) + }() for uid, list := range g.uidMap { newList := []string{} for _, cid := range list { @@ -37,6 +52,7 @@ func (g *Gateway) RemoveClient(clientID string) { } g.uidMap[uid] = newList } + } func (g *Gateway) SendToAll(msg []byte) { @@ -63,6 +79,7 @@ func (g *Gateway) BindUid(clientID, uid string) error { return errors.New("client not found") } g.uidMap[uid] = append(g.uidMap[uid], clientID) + log.Printf("bind %s -> uid:%s\n", clientID, uid) return nil } diff --git a/internal/pkg/func.go b/internal/pkg/func.go index 4e6481a..f6006ac 100644 --- a/internal/pkg/func.go +++ b/internal/pkg/func.go @@ -55,3 +55,13 @@ func ValidateImageURL(rawURL string) error { return nil } + +// hexEncode 将 src 的二进制数据编码为十六进制字符串,写入 dst,返回写入长度 +func HexEncode(src, dst []byte) int { + const hextable = "0123456789abcdef" + for i := 0; i < len(src); i++ { + dst[i*2] = hextable[src[i]>>4] + dst[i*2+1] = hextable[src[i]&0xf] + } + return len(src) * 2 +} diff --git a/internal/server/router/router.go b/internal/server/router/router.go index 5c935d7..1a6fd8b 100644 --- a/internal/server/router/router.go +++ b/internal/server/router/router.go @@ -82,10 +82,7 @@ func SetupRoutes(app *fiber.App, ChatService *services.ChatService, sessionServi func routerSocket(app *fiber.App, chatService *services.ChatService) { ws := app.Group("ws/v1/") // WebSocket 路由配置 - ws.Get("/chat", websocket.New(func(c *websocket.Conn) { - // 可以在这里添加握手前的中间件逻辑(如头校验) - chatService.Chat(c) // 调用实际的 Chat 处理函数 - }, websocket.Config{ + ws.Get("/chat", websocket.New(chatService.Chat, websocket.Config{ // 可选配置:跨域检查、最大负载大小等 HandshakeTimeout: 10 * time.Second, //Subprotocols: []string{"json", "msgpack"}, diff --git a/internal/services/chat.go b/internal/services/chat.go index ff75bee..a01ba73 100644 --- a/internal/services/chat.go +++ b/internal/services/chat.go @@ -6,9 +6,11 @@ import ( "ai_scheduler/internal/data/constants" "ai_scheduler/internal/entitys" "ai_scheduler/internal/gateway" + "context" "encoding/json" "log" "sync" + "time" "github.com/gofiber/fiber/v2" "github.com/gofiber/websocket/v2" @@ -38,20 +40,6 @@ func NewChatService( } } -// ToolCallResponse 工具调用响应 -type ToolCallResponse struct { - ID string `json:"id" example:"call_1"` - Type string `json:"type" example:"function"` - Function FunctionCallResponse `json:"function"` - Result interface{} `json:"result,omitempty"` -} - -// FunctionCallResponse 函数调用响应 -type FunctionCallResponse struct { - Name string `json:"name" example:"get_weather"` - Arguments interface{} `json:"arguments"` -} - func (h *ChatService) ChatFail(c *websocket.Conn, content string) { err := c.WriteMessage(websocket.TextMessage, []byte(content)) if err != nil { @@ -63,15 +51,18 @@ func (h *ChatService) ChatFail(c *websocket.Conn, content string) { // Chat 处理WebSocket聊天连接 // 这是WebSocket处理的主入口函数 func (h *ChatService) Chat(c *websocket.Conn) { - // 创建新的客户端实例 - h.mu.Lock() - client := gateway.NewClient(c) - h.mu.Unlock() + ctx, cancel := context.WithCancel(context.Background()) + // 创建新的客户端实例 + client := gateway.NewClient(c, ctx, cancel) + // 心跳检测 + go client.InitHeartbeat(time.Duration(h.cfg.Sys.HeartbeatInterval)) // 将客户端添加到网关管理 h.Gw.AddClient(client) - log.Println("client connected:", client.GetID()) - log.Println("客户端已连接") + // 确保在函数返回时移除客户端并关闭连接 + defer func() { + h.Gw.Cleanup(client.GetID()) + }() // 绑定会话ID uid := c.Query("x-session") @@ -79,7 +70,6 @@ func (h *ChatService) Chat(c *websocket.Conn) { if err := h.Gw.BindUid(client.GetID(), uid); err != nil { log.Println("绑定UID错误:", err) } - log.Printf("bind %s -> uid:%s\n", client.GetID(), uid) } // 验证并收集连接数据,后续对话中会使用 @@ -89,13 +79,6 @@ func (h *ChatService) Chat(c *websocket.Conn) { return } - // 确保在函数返回时移除客户端并关闭连接 - defer func() { - h.Gw.RemoveClient(client.GetID()) - _ = c.Close() - log.Println("client disconnected:", client.GetID()) - }() - // 循环读取客户端消息 for { // 读取消息 @@ -104,7 +87,14 @@ func (h *ChatService) Chat(c *websocket.Conn) { log.Println("读取错误:", err) break } - + //if string(message) == `{"type":"ping"}` { + // client.LastActive = time.Now() + // if err := c.WriteMessage(websocket.TextMessage, []byte(`{"type":"pong"}`)); err != nil { + // log.Println("Heartbeat response failed", "id", client.GetID(), "err", err) + // return + // } + // continue + //} // 处理消息 msg, chatType := h.handleMessageToString(c, messageType, message) if chatType == constants.ConnStatusClosed {