134 lines
3.5 KiB
Go
134 lines
3.5 KiB
Go
package services
|
||
|
||
import (
|
||
"ai_scheduler/internal/data/constant"
|
||
"ai_scheduler/internal/entitys"
|
||
"ai_scheduler/internal/gateway"
|
||
"encoding/hex"
|
||
"encoding/json"
|
||
"fmt"
|
||
"log"
|
||
"math/rand"
|
||
"sync"
|
||
"time"
|
||
|
||
"github.com/gofiber/websocket/v2"
|
||
)
|
||
|
||
// ChatHandler 聊天处理器
|
||
type ChatService struct {
|
||
routerService entitys.RouterService
|
||
Gw *gateway.Gateway
|
||
mu sync.Mutex
|
||
}
|
||
|
||
// NewChatHandler 创建聊天处理器
|
||
func NewChatService(routerService entitys.RouterService, gw *gateway.Gateway) *ChatService {
|
||
return &ChatService{
|
||
routerService: routerService,
|
||
Gw: gw,
|
||
}
|
||
}
|
||
|
||
// ToolCallResponse is the JSON shape of a single tool call: its ID, its
// call type, the function that was invoked and, when non-nil, the result
// (the "result" key is omitted from the encoded JSON when Result is nil).
type ToolCallResponse struct {
	ID       string               `json:"id" example:"call_1"`
	Type     string               `json:"type" example:"function"`
	Function FunctionCallResponse `json:"function"`
	Result   any                  `json:"result,omitempty"`
}

// FunctionCallResponse is the JSON shape of a function invocation: the
// function's name and its arguments (encoded as null when nil).
type FunctionCallResponse struct {
	Name      string `json:"name" example:"get_weather"`
	Arguments any    `json:"arguments"`
}
|
||
|
||
func (h *ChatService) ChatFail(c *websocket.Conn, content string) {
|
||
err := c.WriteMessage(websocket.TextMessage, []byte(content))
|
||
if err != nil {
|
||
log.Println("发送错误:", err)
|
||
}
|
||
_ = c.Close()
|
||
}
|
||
|
||
// generateClientID returns a connection identifier: the current Unix time
// in nanoseconds (decimal) followed by 8 lowercase hex characters encoding
// 4 pseudo-random bytes. Uniqueness is probabilistic, not guaranteed.
func generateClientID() string {
	nanos := time.Now().UnixNano()

	var noise [4]byte
	// math/rand's Read always fills the buffer and returns a nil error, so
	// the result is not checked.
	// NOTE(review): math/rand.Read is deprecated since Go 1.20; crypto/rand
	// is preferable if IDs must be unpredictable.
	rand.Read(noise[:])

	return fmt.Sprintf("%d%s", nanos, hex.EncodeToString(noise[:]))
}
|
||
func (h *ChatService) Chat(c *websocket.Conn) {
|
||
h.mu.Lock()
|
||
clientID := generateClientID()
|
||
h.mu.Unlock()
|
||
client := &gateway.Client{
|
||
ID: clientID,
|
||
SendFunc: func(data []byte) error {
|
||
return c.WriteMessage(websocket.TextMessage, data)
|
||
},
|
||
}
|
||
h.Gw.AddClient(client)
|
||
log.Println("client connected:", clientID)
|
||
log.Println("客户端已连接")
|
||
// 循环读取客户端消息
|
||
for {
|
||
messageType, message, err := c.ReadMessage()
|
||
if err != nil {
|
||
log.Println("读取错误:", err)
|
||
break
|
||
}
|
||
//简单协议:bind:<uid>
|
||
if c.Headers("Sec-Websocket-Protocol") == "bind" && c.Headers("X-Session") != "" {
|
||
uid := c.Headers("X-Session")
|
||
_ = h.Gw.BindUid(clientID, uid)
|
||
log.Printf("bind %s -> uid:%s\n", clientID, uid)
|
||
}
|
||
msg, chatType := h.handleMessageToString(c, messageType, message)
|
||
if chatType == constant.ConnStatusClosed {
|
||
break
|
||
}
|
||
if chatType == constant.ConnStatusIgnore {
|
||
continue
|
||
}
|
||
|
||
log.Printf("收到消息: %s", string(msg))
|
||
var req entitys.ChatSockRequest
|
||
if err := json.Unmarshal(msg, &req); err != nil {
|
||
log.Println("JSON parse error:", err)
|
||
continue
|
||
}
|
||
err = h.routerService.RouteWithSocket(c, &req)
|
||
if err != nil {
|
||
log.Println("处理失败:", err)
|
||
continue
|
||
}
|
||
}
|
||
h.Gw.RemoveClient(clientID)
|
||
_ = c.Close()
|
||
log.Println("client disconnected:", clientID)
|
||
}
|
||
|
||
func (h *ChatService) handleMessageToString(c *websocket.Conn, msgType int, msg any) (text []byte, chatType constant.ConnStatus) {
|
||
switch msgType {
|
||
case websocket.TextMessage:
|
||
return msg.([]byte), constant.ConnStatusNormal
|
||
case websocket.BinaryMessage:
|
||
return msg.([]byte), constant.ConnStatusNormal
|
||
case websocket.CloseMessage:
|
||
|
||
return nil, constant.ConnStatusClosed
|
||
case websocket.PingMessage:
|
||
// 可选:回复 Pong
|
||
c.WriteMessage(websocket.PongMessage, nil)
|
||
return nil, constant.ConnStatusIgnore
|
||
case websocket.PongMessage:
|
||
return nil, constant.ConnStatusIgnore
|
||
default:
|
||
return nil, constant.ConnStatusIgnore
|
||
}
|
||
return msg.([]byte), constant.ConnStatusIgnore
|
||
}
|