ai_scheduler/internal/services/chat.go

134 lines
3.5 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package services
import (
"ai_scheduler/internal/data/constant"
"ai_scheduler/internal/entitys"
"ai_scheduler/internal/gateway"
"encoding/hex"
"encoding/json"
"fmt"
"log"
"math/rand"
"sync"
"time"
"github.com/gofiber/websocket/v2"
)
// ChatService serves chat WebSocket connections. Each connection is
// registered with the gateway, and every decoded chat request read from it
// is forwarded to the router service.
type ChatService struct {
	// routerService handles each decoded chat request (see Chat).
	routerService entitys.RouterService
	// Gw is the gateway that connections are added to / removed from and
	// whose client IDs are bound to user IDs (see Chat).
	Gw *gateway.Gateway
	// mu serializes client ID generation in Chat.
	mu sync.Mutex
}
// NewChatService returns a ChatService that routes chat requests through
// routerService and registers connections with gw.
func NewChatService(routerService entitys.RouterService, gw *gateway.Gateway) *ChatService {
	return &ChatService{
		routerService: routerService,
		Gw:            gw,
	}
}
// Wire types describing a tool (function) call reported in a chat response.
type (
	// ToolCallResponse is one tool invocation sent back to the client. The
	// tool's result is included only when Result is non-empty.
	ToolCallResponse struct {
		ID       string               `json:"id" example:"call_1"`
		Type     string               `json:"type" example:"function"`
		Function FunctionCallResponse `json:"function"`
		Result   any                  `json:"result,omitempty"`
	}

	// FunctionCallResponse names the invoked function and the arguments it
	// was invoked with. Arguments is always serialized, as null when unset.
	FunctionCallResponse struct {
		Name      string `json:"name" example:"get_weather"`
		Arguments any    `json:"arguments"`
	}
)
// ChatFail sends content to the client as one text frame and then closes the
// connection. A failed send is only logged; the close error is discarded
// because the connection is being abandoned either way.
func (h *ChatService) ChatFail(c *websocket.Conn, content string) {
	payload := []byte(content)
	if writeErr := c.WriteMessage(websocket.TextMessage, payload); writeErr != nil {
		log.Println("发送错误:", writeErr)
	}
	_ = c.Close()
}
// generateClientID builds a connection identifier from two parts:
//   - the current Unix time in nanoseconds, in decimal;
//   - 8 lowercase hex digits encoding 4 pseudo-random bytes.
//
// The random suffix separates IDs created within the same nanosecond.
func generateClientID() string {
	nanos := time.Now().UnixNano()

	var entropy [4]byte
	// math/rand.Read always returns len(p) and a nil error, so both results
	// are safe to ignore.
	_, _ = rand.Read(entropy[:])

	return fmt.Sprintf("%d%s", nanos, hex.EncodeToString(entropy[:]))
}
// Chat serves one chat WebSocket connection for its whole lifetime.
//
// Lifecycle:
//   - Register the connection with the gateway under a freshly generated
//     client ID.
//   - If the handshake carried "Sec-Websocket-Protocol: bind" and a non-empty
//     "X-Session" header, bind that client ID to the X-Session value once.
//   - Read messages until the connection fails or a close frame is reported.
//     Each data message is decoded as an entitys.ChatSockRequest and passed to
//     the router. Malformed JSON and routing errors are logged and the loop
//     continues.
//   - On exit, even on panic, remove the client from the gateway and close
//     the connection.
func (h *ChatService) Chat(c *websocket.Conn) {
	h.mu.Lock()
	clientID := generateClientID()
	h.mu.Unlock()

	client := &gateway.Client{
		ID: clientID,
		// NOTE(review): the gateway may call SendFunc from other goroutines
		// while this handler (ping replies) and the router also write to c.
		// Gorilla-style websocket connections allow only one concurrent
		// writer — confirm writes are serialized somewhere.
		SendFunc: func(data []byte) error {
			return c.WriteMessage(websocket.TextMessage, data)
		},
	}
	h.Gw.AddClient(client)
	log.Println("client connected:", clientID)
	log.Println("客户端已连接")

	// Deferred so a panic while routing cannot leave a stale client
	// registered in the gateway or the connection open.
	defer func() {
		h.Gw.RemoveClient(clientID)
		_ = c.Close()
		log.Println("client disconnected:", clientID)
	}()

	// Handshake headers are fixed once the connection is upgraded, so the
	// optional uid binding is performed once here rather than per message.
	if c.Headers("Sec-Websocket-Protocol") == "bind" {
		if uid := c.Headers("X-Session"); uid != "" {
			if err := h.Gw.BindUid(clientID, uid); err != nil {
				log.Printf("bind %s -> uid:%s failed: %v\n", clientID, uid, err)
			} else {
				log.Printf("bind %s -> uid:%s\n", clientID, uid)
			}
		}
	}

	for {
		messageType, message, err := c.ReadMessage()
		if err != nil {
			log.Println("读取错误:", err)
			return
		}

		msg, status := h.handleMessageToString(c, messageType, message)
		switch status {
		case constant.ConnStatusClosed:
			return
		case constant.ConnStatusIgnore:
			continue
		}

		log.Printf("收到消息: %s", msg)

		var req entitys.ChatSockRequest
		if err := json.Unmarshal(msg, &req); err != nil {
			log.Println("JSON parse error:", err)
			continue
		}
		if err := h.routerService.RouteWithSocket(c, &req); err != nil {
			log.Println("处理失败:", err)
		}
	}
}
// handleMessageToString classifies a frame read from c and, for data frames,
// returns its payload.
//
// Return values by frame type:
//   - Text and binary frames: the payload, with ConnStatusNormal.
//   - Close frames: ConnStatusClosed.
//   - Ping, pong and unrecognized frames: ConnStatusIgnore. A pong is written
//     back in reply to a ping.
//
// msg is expected to be the []byte returned by ReadMessage. Any other
// dynamic type is logged and ignored instead of panicking.
//
// NOTE(review): gorilla-derived connections usually handle control frames
// internally and only report text/binary from ReadMessage, so the control
// branches are likely defensive — confirm for the websocket version in use.
func (h *ChatService) handleMessageToString(c *websocket.Conn, msgType int, msg any) (text []byte, chatType constant.ConnStatus) {
	switch msgType {
	case websocket.TextMessage, websocket.BinaryMessage:
		payload, ok := msg.([]byte)
		if !ok {
			log.Printf("unexpected payload type %T for message type %d", msg, msgType)
			return nil, constant.ConnStatusIgnore
		}
		return payload, constant.ConnStatusNormal
	case websocket.CloseMessage:
		return nil, constant.ConnStatusClosed
	case websocket.PingMessage:
		// Reply to the ping with a pong; a failed write is only logged.
		if err := c.WriteMessage(websocket.PongMessage, nil); err != nil {
			log.Println("pong write failed:", err)
		}
		return nil, constant.ConnStatusIgnore
	default:
		// Pong frames and unknown types carry nothing to route.
		return nil, constant.ConnStatusIgnore
	}
}