package services

import (
	"encoding/hex"
	"encoding/json"
	"fmt"
	"log"
	"math/rand"
	"sync"
	"time"

	"ai_scheduler/internal/data/constant"
	"ai_scheduler/internal/entitys"
	"ai_scheduler/internal/gateway"

	"github.com/gofiber/websocket/v2"
)

// ChatService serves WebSocket chat connections: it registers each connection
// with the gateway and forwards decoded chat requests to the router service.
type ChatService struct {
	routerService entitys.RouterService
	Gw            *gateway.Gateway
	mu            sync.Mutex // serializes client-ID generation in Chat
}

// NewChatService creates a ChatService that routes chat requests through
// routerService and registers connections with gw.
func NewChatService(routerService entitys.RouterService, gw *gateway.Gateway) *ChatService {
	return &ChatService{
		routerService: routerService,
		Gw:            gw,
	}
}

// ToolCallResponse is the JSON shape of a tool call response.
type ToolCallResponse struct {
	ID       string               `json:"id" example:"call_1"`
	Type     string               `json:"type" example:"function"`
	Function FunctionCallResponse `json:"function"`
	Result   any                  `json:"result,omitempty"`
}

// FunctionCallResponse is the JSON shape of the function invoked by a tool call.
type FunctionCallResponse struct {
	Name      string `json:"name" example:"get_weather"`
	Arguments any    `json:"arguments"`
}

// ChatFail sends content to the client as a text frame and then closes the
// connection. A write failure is logged; the connection is closed regardless.
func (h *ChatService) ChatFail(c *websocket.Conn, content string) {
	if err := c.WriteMessage(websocket.TextMessage, []byte(content)); err != nil {
		log.Println("发送错误:", err)
	}
	// Best-effort close: the connection is being abandoned, so a close error
	// leaves nothing actionable to do.
	_ = c.Close()
}

// generateClientID returns a connection ID made of the current Unix time in
// nanoseconds (decimal) followed by 8 lowercase hex characters derived from 4
// random bytes, which makes collisions between simultaneous connections
// unlikely. math/rand is not a cryptographic source, so the ID must not be
// treated as a secret.
func generateClientID() string {
	timestamp := time.Now().UnixNano()
	randomBytes := make([]byte, 4)
	// math/rand.Read always returns len(p) and a nil error.
	_, _ = rand.Read(randomBytes)
	randomStr := hex.EncodeToString(randomBytes)
	return fmt.Sprintf("%d%s", timestamp, randomStr)
}

// Chat serves one WebSocket connection until a read fails or a close frame is
// seen. It registers the connection with the gateway under a fresh client ID
// and binds that ID to the X-Session header value when the handshake asked for
// the "bind" subprotocol. It then decodes each incoming text/binary frame as an
// entitys.ChatSockRequest and hands it to the router service. Malformed JSON
// and routing errors are logged and the connection keeps serving. On return
// the client is removed from the gateway and the connection is closed.
func (h *ChatService) Chat(c *websocket.Conn) {
	h.mu.Lock()
	clientID := generateClientID()
	h.mu.Unlock()

	client := &gateway.Client{
		ID: clientID,
		// NOTE(review): the gateway may call SendFunc from other goroutines
		// while RouteWithSocket writes to c on this one. Gorilla-derived
		// websocket connections allow only one concurrent writer — confirm
		// that writes to c are serialized.
		SendFunc: func(data []byte) error {
			return c.WriteMessage(websocket.TextMessage, data)
		},
	}
	h.Gw.AddClient(client)
	// Deferred so the gateway entry and the socket are released on every exit
	// path, including a panic raised while routing a request.
	defer func() {
		h.Gw.RemoveClient(clientID)
		_ = c.Close()
		log.Println("client disconnected:", clientID)
	}()
	log.Println("client connected:", clientID)
	log.Println("客户端已连接")

	// Simple bind protocol. The handshake headers are fixed for the lifetime of
	// the connection, so the binding is done once here rather than per message.
	if c.Headers("Sec-Websocket-Protocol") == "bind" && c.Headers("X-Session") != "" {
		uid := c.Headers("X-Session")
		// NOTE(review): the BindUid result is discarded; if it is an error,
		// consider logging it.
		_ = h.Gw.BindUid(clientID, uid)
		log.Printf("bind %s -> uid:%s\n", clientID, uid)
	}

	// Read client messages until the connection fails or is closed.
	for {
		messageType, message, err := c.ReadMessage()
		if err != nil {
			log.Println("读取错误:", err)
			return
		}

		msg, chatType := h.handleMessageToString(c, messageType, message)
		if chatType == constant.ConnStatusClosed {
			return
		}
		if chatType == constant.ConnStatusIgnore {
			continue
		}
		log.Printf("收到消息: %s", msg)

		var req entitys.ChatSockRequest
		if err := json.Unmarshal(msg, &req); err != nil {
			log.Println("JSON parse error:", err)
			continue
		}
		if err := h.routerService.RouteWithSocket(c, &req); err != nil {
			log.Println("处理失败:", err)
		}
	}
}

// handleMessageToString classifies a frame read from the connection.
//   - Text and binary frames return their payload with ConnStatusNormal.
//   - A close frame returns ConnStatusClosed.
//   - A ping frame is answered with a pong and returns ConnStatusIgnore.
//   - Pong frames and unknown frame types return ConnStatusIgnore.
//   - A payload that is not a []byte is ignored instead of causing a panic.
//
// NOTE(review): gorilla-derived connections return only text/binary frames
// from ReadMessage and handle control frames internally, so the close, ping
// and pong branches are likely defensive only.
func (h *ChatService) handleMessageToString(c *websocket.Conn, msgType int, msg any) (text []byte, chatType constant.ConnStatus) {
	switch msgType {
	case websocket.TextMessage, websocket.BinaryMessage:
		data, ok := msg.([]byte)
		if !ok {
			return nil, constant.ConnStatusIgnore
		}
		return data, constant.ConnStatusNormal
	case websocket.CloseMessage:
		return nil, constant.ConnStatusClosed
	case websocket.PingMessage:
		if err := c.WriteMessage(websocket.PongMessage, nil); err != nil {
			log.Println("pong write error:", err)
		}
		return nil, constant.ConnStatusIgnore
	default:
		// Pong and unrecognised frame types carry nothing to route.
		return nil, constant.ConnStatusIgnore
	}
}