198 lines
5.0 KiB
Go
198 lines
5.0 KiB
Go
package services
|
|
|
|
import (
|
|
"ai_scheduler/internal/biz"
|
|
"ai_scheduler/internal/config"
|
|
"ai_scheduler/internal/data/constants"
|
|
"ai_scheduler/internal/entitys"
|
|
"ai_scheduler/internal/gateway"
|
|
"context"
|
|
"encoding/json"
|
|
"log"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/gofiber/fiber/v2"
|
|
"github.com/gofiber/websocket/v2"
|
|
)
|
|
|
|
// ChatHandler 聊天处理器
|
|
type ChatService struct {
|
|
routerBiz *biz.AiRouterBiz
|
|
Gw *gateway.Gateway
|
|
mu sync.Mutex
|
|
ChatHis *biz.ChatHistoryBiz
|
|
cfg *config.Config
|
|
}
|
|
|
|
// NewChatHandler 创建聊天处理器
|
|
func NewChatService(
|
|
routerService *biz.AiRouterBiz,
|
|
chatHis *biz.ChatHistoryBiz,
|
|
gw *gateway.Gateway,
|
|
cfg *config.Config,
|
|
) *ChatService {
|
|
return &ChatService{
|
|
routerBiz: routerService,
|
|
Gw: gw,
|
|
ChatHis: chatHis,
|
|
cfg: cfg,
|
|
}
|
|
}
|
|
|
|
func (h *ChatService) ChatFail(c *websocket.Conn, content string) {
|
|
err := c.WriteMessage(websocket.TextMessage, []byte(content))
|
|
if err != nil {
|
|
log.Println("发送错误:", err)
|
|
}
|
|
_ = c.Close()
|
|
}
|
|
|
|
// Chat 处理WebSocket聊天连接
|
|
// 这是WebSocket处理的主入口函数
|
|
func (h *ChatService) Chat(c *websocket.Conn) {
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
|
|
// 创建新的客户端实例
|
|
client := gateway.NewClient(c, ctx, cancel)
|
|
|
|
// 验证并收集连接数据,后续对话中会使用
|
|
if err := client.DataAuth(); err != nil {
|
|
log.Println("数据验证错误:", err)
|
|
_ = client.SendFunc([]byte(err.Error()))
|
|
client.Close()
|
|
return
|
|
}
|
|
|
|
// 验证通过后,将客户端添加到网关管理
|
|
h.Gw.AddClient(client)
|
|
|
|
// 使用信号量限制并发处理的消息数量
|
|
semaphore := make(chan struct{}, 1) // 最多1个并发消息处理
|
|
// 用于等待所有goroutine完成的wait group
|
|
var wg sync.WaitGroup
|
|
// 确保在函数返回时移除客户端并关闭连接
|
|
defer func() {
|
|
wg.Wait() // 等待所有消息处理goroutine完成
|
|
close(semaphore) // 关闭信号量通道
|
|
h.Gw.Cleanup(client.GetID())
|
|
}()
|
|
|
|
// 绑定会话ID, sessionId 为空时, 则不绑定
|
|
if uid := client.GetSession(); uid != "" {
|
|
if err := h.Gw.BindUid(client.GetID(), uid); err != nil {
|
|
log.Println("绑定UID错误:", err)
|
|
}
|
|
}
|
|
|
|
// 开启心跳检测
|
|
go client.InitHeartbeat(time.Duration(h.cfg.Sys.HeartbeatInterval))
|
|
|
|
// 循环读取客户端消息
|
|
for {
|
|
messageType, message, err := client.ReadMessage()
|
|
if err != nil {
|
|
log.Printf("读取错误: %v, 客户端ID: %s", err, client.GetID())
|
|
break
|
|
}
|
|
|
|
// 处理心跳消息
|
|
if messageType == websocket.PingMessage || string(message) == "PING" {
|
|
client.LastActive = time.Now()
|
|
msgType := websocket.TextMessage
|
|
if messageType == websocket.PingMessage {
|
|
msgType = websocket.PongMessage
|
|
}
|
|
if err = client.SendMessage(msgType, []byte(`PONG`)); err != nil {
|
|
log.Printf("发送pong消息失败: %v", err)
|
|
}
|
|
continue
|
|
}
|
|
|
|
// 使用信号量限制并发
|
|
semaphore <- struct{}{}
|
|
wg.Add(1)
|
|
go func(msgType int, msg []byte) {
|
|
defer func() {
|
|
<-semaphore
|
|
wg.Done()
|
|
// 恢复panic
|
|
if r := recover(); r != nil {
|
|
log.Printf("消息处理goroutine发生panic: %v", r)
|
|
}
|
|
}()
|
|
|
|
// 消息处理逻辑
|
|
h.processMessage(client, msgType, msg)
|
|
}(messageType, message)
|
|
}
|
|
}
|
|
|
|
// 将消息处理逻辑提取到单独的方法
|
|
func (h *ChatService) processMessage(client *gateway.Client, msgType int, msg []byte) {
|
|
// 处理消息
|
|
processedMsg, _ := h.handleMessageToString(client, msgType, msg)
|
|
log.Printf("收到消息:消息类型 %d, 内容 %s, 客户端ID: %s",
|
|
msgType, string(processedMsg), client.GetID())
|
|
|
|
// 解析请求
|
|
var req entitys.ChatSockRequest
|
|
if err := json.Unmarshal(processedMsg, &req); err != nil {
|
|
log.Printf("JSON解析错误: %v, 客户端ID: %s", err, client.GetID())
|
|
return
|
|
}
|
|
|
|
// 路由处理请求
|
|
if err := h.routerBiz.RouteWithSocket(client, &req); err != nil {
|
|
log.Printf("处理失败: %v, 客户端ID: %s", err, client.GetID())
|
|
}
|
|
}
|
|
|
|
// handleMessageToString 处理不同类型的WebSocket消息
|
|
// 参数:
|
|
// - client: 客户端对象
|
|
// - msgType: 消息类型
|
|
// - msg: 消息内容
|
|
//
|
|
// 返回:
|
|
// - text: 处理后的文本内容
|
|
// - chatType: 连接状态
|
|
func (h *ChatService) handleMessageToString(client *gateway.Client, msgType int, msg any) (text []byte, chatType constants.ConnStatus) {
|
|
switch msgType {
|
|
case websocket.TextMessage:
|
|
return msg.([]byte), constants.ConnStatusNormal
|
|
case websocket.BinaryMessage:
|
|
return msg.([]byte), constants.ConnStatusNormal
|
|
case websocket.CloseMessage:
|
|
|
|
return nil, constants.ConnStatusClosed
|
|
case websocket.PingMessage:
|
|
return nil, constants.ConnStatusIgnore
|
|
case websocket.PongMessage:
|
|
return nil, constants.ConnStatusIgnore
|
|
default:
|
|
log.Printf("未知的消息类型: %d", msgType)
|
|
return nil, constants.ConnStatusIgnore
|
|
}
|
|
}
|
|
|
|
func (s *ChatService) Useful(c *fiber.Ctx) error {
|
|
req := &entitys.UseFulRequest{}
|
|
if err := c.BodyParser(req); err != nil {
|
|
return err
|
|
}
|
|
|
|
err := s.ChatHis.Update(c.Context(), req)
|
|
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (s *ChatService) UsefulList(c *fiber.Ctx) error {
|
|
|
|
return c.JSON(constants.UseFulMap)
|
|
}
|