package services import ( "ai_scheduler/internal/biz" "ai_scheduler/internal/config" "ai_scheduler/internal/data/constants" "ai_scheduler/internal/entitys" "ai_scheduler/internal/gateway" "encoding/json" "log" "sync" "github.com/gofiber/fiber/v2" "github.com/gofiber/websocket/v2" ) // ChatHandler 聊天处理器 type ChatService struct { routerBiz *biz.AiRouterBiz Gw *gateway.Gateway mu sync.Mutex ChatHis *biz.ChatHistoryBiz cfg *config.Config } // NewChatHandler 创建聊天处理器 func NewChatService( routerService *biz.AiRouterBiz, chatHis *biz.ChatHistoryBiz, gw *gateway.Gateway, cfg *config.Config, ) *ChatService { return &ChatService{ routerBiz: routerService, Gw: gw, ChatHis: chatHis, cfg: cfg, } } // ToolCallResponse 工具调用响应 type ToolCallResponse struct { ID string `json:"id" example:"call_1"` Type string `json:"type" example:"function"` Function FunctionCallResponse `json:"function"` Result interface{} `json:"result,omitempty"` } // FunctionCallResponse 函数调用响应 type FunctionCallResponse struct { Name string `json:"name" example:"get_weather"` Arguments interface{} `json:"arguments"` } func (h *ChatService) ChatFail(c *websocket.Conn, content string) { err := c.WriteMessage(websocket.TextMessage, []byte(content)) if err != nil { log.Println("发送错误:", err) } _ = c.Close() } // Chat 处理WebSocket聊天连接 // 这是WebSocket处理的主入口函数 func (h *ChatService) Chat(c *websocket.Conn) { // 创建新的客户端实例 h.mu.Lock() client := gateway.NewClient(c) h.mu.Unlock() // 将客户端添加到网关管理 h.Gw.AddClient(client) log.Println("client connected:", client.GetID()) log.Println("客户端已连接") // 绑定会话ID uid := c.Query("x-session") if uid != "" { if err := h.Gw.BindUid(client.GetID(), uid); err != nil { log.Println("绑定UID错误:", err) } log.Printf("bind %s -> uid:%s\n", client.GetID(), uid) } // 验证并收集连接数据,后续对话中会使用 if err := client.DataAuth(); err != nil { log.Println("数据验证错误:", err) h.ChatFail(c, err.Error()) return } // 确保在函数返回时移除客户端并关闭连接 defer func() { h.Gw.RemoveClient(client.GetID()) _ = c.Close() log.Println("client disconnected:", client.GetID()) }() // 循环读取客户端消息 for { // 读取消息 messageType, message, err := c.ReadMessage() if err != nil { log.Println("读取错误:", err) break } // 处理消息 msg, chatType := h.handleMessageToString(c, messageType, message) if chatType == constants.ConnStatusClosed { break } if chatType == constants.ConnStatusIgnore { continue } log.Printf("收到消息: %s", string(msg)) // 解析请求 var req entitys.ChatSockRequest if err = json.Unmarshal(msg, &req); err != nil { log.Println("JSON parse error:", err) continue } // 路由处理请求 err = h.routerBiz.RouteWithSocket(client, &req) if err != nil { log.Println("处理失败:", err) } } } // handleMessageToString 处理不同类型的WebSocket消息 // 参数: // - c: WebSocket连接 // - msgType: 消息类型 // - msg: 消息内容 // // 返回: // - text: 处理后的文本内容 // - chatType: 连接状态 func (h *ChatService) handleMessageToString(c *websocket.Conn, msgType int, msg any) (text []byte, chatType constants.ConnStatus) { switch msgType { case websocket.TextMessage: return msg.([]byte), constants.ConnStatusNormal case websocket.BinaryMessage: return msg.([]byte), constants.ConnStatusNormal case websocket.CloseMessage: return nil, constants.ConnStatusClosed case websocket.PingMessage: // 可选:回复 Pong c.WriteMessage(websocket.PongMessage, nil) return nil, constants.ConnStatusIgnore case websocket.PongMessage: return nil, constants.ConnStatusIgnore default: return nil, constants.ConnStatusIgnore } return msg.([]byte), constants.ConnStatusIgnore } func (s *ChatService) Useful(c *fiber.Ctx) error { req := &entitys.UseFulRequest{} if err := c.BodyParser(req); err != nil { return err } err := s.ChatHis.Update(c.Context(), req) if err != nil { return err } return nil } func (s *ChatService) UsefulList(c *fiber.Ctx) error { return c.JSON(constants.UseFulMap) }