243 lines
5.7 KiB
Go
243 lines
5.7 KiB
Go
package services
|
|
|
|
import (
|
|
"ai_scheduler/internal/biz"
|
|
"ai_scheduler/internal/config"
|
|
"ai_scheduler/internal/data/constants"
|
|
errors "ai_scheduler/internal/data/error"
|
|
"ai_scheduler/internal/entitys"
|
|
"ai_scheduler/internal/gateway"
|
|
"ai_scheduler/internal/pkg/l_request"
|
|
"encoding/json"
|
|
"github.com/gofiber/fiber/v2"
|
|
"github.com/gofiber/websocket/v2"
|
|
"log"
|
|
"net/http"
|
|
"sync"
|
|
)
|
|
|
|
// ChatHandler 聊天处理器
|
|
type ChatService struct {
|
|
routerBiz *biz.AiRouterBiz
|
|
Gw *gateway.Gateway
|
|
mu sync.Mutex
|
|
ChatHis *biz.ChatHistoryBiz
|
|
cfg *config.Config
|
|
}
|
|
|
|
// NewChatHandler 创建聊天处理器
|
|
func NewChatService(
|
|
routerService *biz.AiRouterBiz,
|
|
chatHis *biz.ChatHistoryBiz,
|
|
gw *gateway.Gateway,
|
|
cfg *config.Config,
|
|
) *ChatService {
|
|
return &ChatService{
|
|
routerBiz: routerService,
|
|
Gw: gw,
|
|
ChatHis: chatHis,
|
|
cfg: cfg,
|
|
}
|
|
}
|
|
|
|
// ToolCallResponse describes one tool call in a response payload.
//
// Result holds the tool's output; because it is tagged omitempty, it is
// left out of the JSON encoding entirely when it is nil.
type ToolCallResponse struct {
	ID       string               `json:"id" example:"call_1"`
	Type     string               `json:"type" example:"function"`
	Function FunctionCallResponse `json:"function"`
	Result   any                  `json:"result,omitempty"`
}

// FunctionCallResponse names the invoked function and carries its
// arguments. Arguments is always encoded, as JSON null when nil.
type FunctionCallResponse struct {
	Name      string `json:"name" example:"get_weather"`
	Arguments any    `json:"arguments"`
}
|
|
|
|
func (h *ChatService) ChatFail(c *websocket.Conn, content string) {
|
|
err := c.WriteMessage(websocket.TextMessage, []byte(content))
|
|
if err != nil {
|
|
log.Println("发送错误:", err)
|
|
}
|
|
_ = c.Close()
|
|
}
|
|
|
|
// Chat 处理WebSocket聊天连接
|
|
// 这是WebSocket处理的主入口函数
|
|
func (h *ChatService) Chat(c *websocket.Conn) {
|
|
// 创建新的客户端实例
|
|
h.mu.Lock()
|
|
client := gateway.NewClient(c)
|
|
h.mu.Unlock()
|
|
|
|
// 将客户端添加到网关管理
|
|
h.Gw.AddClient(client)
|
|
log.Println("client connected:", client.GetID())
|
|
log.Println("客户端已连接")
|
|
|
|
// 绑定会话ID
|
|
uid := c.Query("x-session")
|
|
if uid != "" {
|
|
if err := h.Gw.BindUid(client.GetID(), uid); err != nil {
|
|
log.Println("绑定UID错误:", err)
|
|
}
|
|
log.Printf("bind %s -> uid:%s\n", client.GetID(), uid)
|
|
}
|
|
|
|
// 验证并收集连接数据,后续对话中会使用
|
|
if err := client.DataAuth(); err != nil {
|
|
log.Println("数据验证错误:", err)
|
|
h.ChatFail(c, err.Error())
|
|
return
|
|
}
|
|
|
|
// 获取用户权限
|
|
codes, err := h.GetUserPermission(client)
|
|
if err != nil {
|
|
log.Println("获取用户权限错误:", err)
|
|
h.ChatFail(c, err.Error())
|
|
return
|
|
}
|
|
client.SetCodes(codes)
|
|
|
|
// 确保在函数返回时移除客户端并关闭连接
|
|
defer func() {
|
|
h.Gw.RemoveClient(client.GetID())
|
|
_ = c.Close()
|
|
log.Println("client disconnected:", client.GetID())
|
|
}()
|
|
|
|
// 循环读取客户端消息
|
|
for {
|
|
// 读取消息
|
|
messageType, message, err := c.ReadMessage()
|
|
if err != nil {
|
|
log.Println("读取错误:", err)
|
|
break
|
|
}
|
|
|
|
// 处理消息
|
|
msg, chatType := h.handleMessageToString(c, messageType, message)
|
|
if chatType == constants.ConnStatusClosed {
|
|
break
|
|
}
|
|
if chatType == constants.ConnStatusIgnore {
|
|
continue
|
|
}
|
|
|
|
log.Printf("收到消息: %s", string(msg))
|
|
|
|
// 解析请求
|
|
var req entitys.ChatSockRequest
|
|
if err = json.Unmarshal(msg, &req); err != nil {
|
|
log.Println("JSON parse error:", err)
|
|
continue
|
|
}
|
|
|
|
// 路由处理请求
|
|
err = h.routerBiz.RouteWithSocket(client, &req)
|
|
if err != nil {
|
|
log.Println("处理失败:", err)
|
|
}
|
|
}
|
|
}
|
|
|
|
// handleMessageToString 处理不同类型的WebSocket消息
|
|
// 参数:
|
|
// - c: WebSocket连接
|
|
// - msgType: 消息类型
|
|
// - msg: 消息内容
|
|
//
|
|
// 返回:
|
|
// - text: 处理后的文本内容
|
|
// - chatType: 连接状态
|
|
func (h *ChatService) handleMessageToString(c *websocket.Conn, msgType int, msg any) (text []byte, chatType constants.ConnStatus) {
|
|
switch msgType {
|
|
case websocket.TextMessage:
|
|
return msg.([]byte), constants.ConnStatusNormal
|
|
case websocket.BinaryMessage:
|
|
return msg.([]byte), constants.ConnStatusNormal
|
|
case websocket.CloseMessage:
|
|
|
|
return nil, constants.ConnStatusClosed
|
|
case websocket.PingMessage:
|
|
// 可选:回复 Pong
|
|
c.WriteMessage(websocket.PongMessage, nil)
|
|
return nil, constants.ConnStatusIgnore
|
|
case websocket.PongMessage:
|
|
return nil, constants.ConnStatusIgnore
|
|
default:
|
|
return nil, constants.ConnStatusIgnore
|
|
}
|
|
return msg.([]byte), constants.ConnStatusIgnore
|
|
}
|
|
|
|
func (s *ChatService) Useful(c *fiber.Ctx) error {
|
|
req := &entitys.UseFulRequest{}
|
|
if err := c.BodyParser(req); err != nil {
|
|
return err
|
|
}
|
|
|
|
err := s.ChatHis.Update(c.Context(), req)
|
|
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (s *ChatService) UsefulList(c *fiber.Ctx) error {
|
|
|
|
return c.JSON(constants.UseFulMap)
|
|
}
|
|
|
|
// 从统一登录平台获取用户权限
|
|
func (s *ChatService) GetUserPermission(client *gateway.Client) (codes []string, err error) {
|
|
var (
|
|
request l_request.Request
|
|
)
|
|
|
|
// 系统编码
|
|
systemCode := client.GetSysCode()
|
|
|
|
// 检查系统编码是否配置
|
|
if v, ok := s.cfg.PermissionConfig.SysPermission[systemCode]; !ok {
|
|
err = errors.SysErr("系统编码 %s 未配置", systemCode)
|
|
return
|
|
} else {
|
|
request.Url = v.PermissionURL
|
|
}
|
|
|
|
request.Method = "GET"
|
|
request.Headers = map[string]string{
|
|
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36",
|
|
"Accept": "application/json, text/plain, */*",
|
|
"Authorization": "Bearer " + client.GetAuth(),
|
|
}
|
|
|
|
// 发送请求
|
|
res, err := request.Send()
|
|
if err != nil {
|
|
return
|
|
}
|
|
|
|
// 检查响应状态码
|
|
if res.StatusCode != http.StatusOK {
|
|
err = errors.SysErr("获取用户权限失败")
|
|
return
|
|
}
|
|
|
|
type resp struct {
|
|
Codes []string `json:"codes"`
|
|
}
|
|
// 解析响应体
|
|
var respBody resp
|
|
err = json.Unmarshal([]byte(res.Text), &respBody)
|
|
if err != nil {
|
|
return
|
|
}
|
|
|
|
return respBody.Codes, nil
|
|
}
|