ai_scheduler/internal/services/chat.go

183 lines
4.3 KiB
Go

package services
import (
"ai_scheduler/internal/biz"
"ai_scheduler/internal/config"
"ai_scheduler/internal/data/constants"
"ai_scheduler/internal/entitys"
"ai_scheduler/internal/gateway"
"encoding/json"
"log"
"sync"
"github.com/gofiber/fiber/v2"
"github.com/gofiber/websocket/v2"
)
// ChatService handles WebSocket chat connections (Chat) and the Useful /
// UsefulList HTTP handlers.
type ChatService struct {
	routerBiz *biz.AiRouterBiz    // routes each decoded chat request (RouteWithSocket)
	Gw        *gateway.Gateway    // tracks connected clients (AddClient/RemoveClient/BindUid)
	mu        sync.Mutex          // held around gateway.NewClient in Chat
	ChatHis   *biz.ChatHistoryBiz // chat-history store updated by Useful
	cfg       *config.Config      // NOTE(review): not read by the methods visible in this file
}
// NewChatService creates a ChatService wired to the given request router,
// chat-history store, client gateway and configuration.
func NewChatService(
	routerService *biz.AiRouterBiz,
	chatHis *biz.ChatHistoryBiz,
	gw *gateway.Gateway,
	cfg *config.Config,
) *ChatService {
	return &ChatService{
		routerBiz: routerService,
		Gw:        gw,
		ChatHis:   chatHis,
		cfg:       cfg,
	}
}
// ToolCallResponse is the JSON shape of a single tool call in a response,
// optionally carrying the call's result.
type ToolCallResponse struct {
	ID       string               `json:"id" example:"call_1"`
	Type     string               `json:"type" example:"function"`
	Function FunctionCallResponse `json:"function"`
	// Result is omitted from the JSON output when nil (omitempty).
	Result any `json:"result,omitempty"`
}

// FunctionCallResponse is the JSON shape of the function invoked by a tool
// call: its name and the arguments it was called with.
type FunctionCallResponse struct {
	Name      string `json:"name" example:"get_weather"`
	Arguments any    `json:"arguments"`
}
// ChatFail sends content to the client as a single text frame and then closes
// the connection. A failed send is only logged; the Close error is discarded
// because the connection is being abandoned either way.
func (s *ChatService) ChatFail(c *websocket.Conn, content string) {
	payload := []byte(content)
	if writeErr := c.WriteMessage(websocket.TextMessage, payload); writeErr != nil {
		log.Println("发送错误:", writeErr)
	}
	_ = c.Close()
}
// Chat is the entry point for a WebSocket chat connection.
//
// It creates a gateway client for c and registers it, binds it to the
// session ID from the "x-session" query parameter when one is given, and
// validates the connection data with client.DataAuth (on failure the error
// text is sent to the peer via ChatFail). It then reads messages until the
// connection is closed or a read fails. Each text/binary message is decoded
// as an entitys.ChatSockRequest and passed to routerBiz.RouteWithSocket.
// Undecodable messages and routing errors are logged and the loop continues.
//
// On every return path, including a failed DataAuth, the client is removed
// from the gateway and the connection is closed.
func (s *ChatService) Chat(c *websocket.Conn) {
	// Client construction is serialized across connections.
	// NOTE(review): confirm whether gateway.NewClient actually needs this lock.
	s.mu.Lock()
	client := gateway.NewClient(c)
	s.mu.Unlock()

	// Register the client and arrange its cleanup immediately. Previously the
	// defer was installed only after DataAuth, so an auth failure left a dead
	// client registered in the gateway.
	s.Gw.AddClient(client)
	defer func() {
		s.Gw.RemoveClient(client.GetID())
		// The connection may already be closed (e.g. by ChatFail); the error
		// from a second Close is deliberately ignored.
		_ = c.Close()
		log.Println("client disconnected:", client.GetID())
	}()
	log.Println("client connected:", client.GetID())
	log.Println("客户端已连接")

	// Bind the session ID, if the client supplied one. Only report the binding
	// when it actually succeeded.
	if uid := c.Query("x-session"); uid != "" {
		if err := s.Gw.BindUid(client.GetID(), uid); err != nil {
			log.Println("绑定UID错误:", err)
		} else {
			log.Printf("bind %s -> uid:%s\n", client.GetID(), uid)
		}
	}

	// Validate and collect connection data used later in the conversation.
	if err := client.DataAuth(); err != nil {
		log.Println("数据验证错误:", err)
		s.ChatFail(c, err.Error())
		return
	}

	for {
		messageType, message, err := c.ReadMessage()
		if err != nil {
			log.Println("读取错误:", err)
			return
		}

		msg, status := s.handleMessageToString(c, messageType, message)
		switch status {
		case constants.ConnStatusClosed:
			return
		case constants.ConnStatusIgnore:
			continue // skips to the next read (applies to the for loop)
		}
		log.Printf("收到消息: %s", msg)

		var req entitys.ChatSockRequest
		if err := json.Unmarshal(msg, &req); err != nil {
			log.Println("JSON parse error:", err)
			continue
		}

		if err := s.routerBiz.RouteWithSocket(client, &req); err != nil {
			log.Println("处理失败:", err)
		}
	}
}
// handleMessageToString classifies a WebSocket message read from c.
//
// Parameters:
//   - c: the connection the message was read from (used to answer pings)
//   - msgType: the message type returned by ReadMessage
//   - msg: the message payload; expected to hold a []byte
//
// Returns:
//   - text: the payload for text and binary messages, nil otherwise
//   - chatType: ConnStatusNormal for text/binary messages, ConnStatusClosed
//     for a close message, and ConnStatusIgnore for everything else (ping,
//     pong, unknown types, or a text/binary payload that is not a []byte)
func (s *ChatService) handleMessageToString(c *websocket.Conn, msgType int, msg any) (text []byte, chatType constants.ConnStatus) {
	switch msgType {
	case websocket.TextMessage, websocket.BinaryMessage:
		// Checked assertion: an unexpected payload type is skipped instead of
		// panicking the connection's read loop.
		payload, ok := msg.([]byte)
		if !ok {
			log.Printf("unexpected payload type %T for message type %d", msg, msgType)
			return nil, constants.ConnStatusIgnore
		}
		return payload, constants.ConnStatusNormal
	case websocket.CloseMessage:
		return nil, constants.ConnStatusClosed
	case websocket.PingMessage:
		// RFC 6455 §5.5.3: a Pong sent in reply to a Ping must carry the same
		// application data. A non-[]byte payload yields an empty Pong.
		payload, _ := msg.([]byte)
		if err := c.WriteMessage(websocket.PongMessage, payload); err != nil {
			log.Println("pong reply error:", err)
		}
		return nil, constants.ConnStatusIgnore
	default:
		// Pong and unknown message types carry nothing to process.
		return nil, constants.ConnStatusIgnore
	}
}
// Useful parses an entitys.UseFulRequest from the request body and passes it
// to ChatHis.Update. A parse or update error is returned to Fiber unchanged;
// on success the handler writes no response body itself.
func (s *ChatService) Useful(c *fiber.Ctx) error {
	var req entitys.UseFulRequest
	if parseErr := c.BodyParser(&req); parseErr != nil {
		return parseErr
	}
	if updateErr := s.ChatHis.Update(c.Context(), &req); updateErr != nil {
		return updateErr
	}
	return nil
}
// UsefulList responds with constants.UseFulMap serialized as JSON.
func (s *ChatService) UsefulList(c *fiber.Ctx) error {
	options := constants.UseFulMap
	return c.JSON(options)
}