146 lines
		
	
	
		
			3.7 KiB
		
	
	
	
		
			Go
		
	
	
	
			
		
		
	
	
			146 lines
		
	
	
		
			3.7 KiB
		
	
	
	
		
			Go
		
	
	
	
package services
 | 
						||
 | 
						||
import (
 | 
						||
	"ai_scheduler/internal/biz"
 | 
						||
	"ai_scheduler/internal/data/constants"
 | 
						||
	"ai_scheduler/internal/entitys"
 | 
						||
	"ai_scheduler/internal/gateway"
 | 
						||
	"encoding/hex"
 | 
						||
	"encoding/json"
 | 
						||
	"fmt"
 | 
						||
	"log"
 | 
						||
	"math/rand"
 | 
						||
	"sync"
 | 
						||
	"time"
 | 
						||
 | 
						||
	"github.com/gofiber/websocket/v2"
 | 
						||
)
 | 
						||
 | 
						||
// ChatHandler 聊天处理器
 | 
						||
type ChatService struct {
 | 
						||
	routerBiz *biz.AiRouterBiz
 | 
						||
	Gw        *gateway.Gateway
 | 
						||
	mu        sync.Mutex
 | 
						||
}
 | 
						||
 | 
						||
// NewChatHandler 创建聊天处理器
 | 
						||
func NewChatService(routerService *biz.AiRouterBiz, gw *gateway.Gateway) *ChatService {
 | 
						||
	return &ChatService{
 | 
						||
		routerBiz: routerService,
 | 
						||
		Gw:        gw,
 | 
						||
	}
 | 
						||
}
 | 
						||
 | 
						||
// ToolCallResponse 工具调用响应
 | 
						||
type ToolCallResponse struct {
 | 
						||
	ID       string               `json:"id" example:"call_1"`
 | 
						||
	Type     string               `json:"type" example:"function"`
 | 
						||
	Function FunctionCallResponse `json:"function"`
 | 
						||
	Result   interface{}          `json:"result,omitempty"`
 | 
						||
}
 | 
						||
 | 
						||
// FunctionCallResponse is the JSON shape of the function part of a tool call:
// the function's name and its arguments. Arguments is always emitted (as null
// when unset) because its tag has no omitempty.
type FunctionCallResponse struct {
	Name      string `json:"name" example:"get_weather"`
	Arguments any    `json:"arguments"`
}
 | 
						||
 | 
						||
// ChatFail sends content to the client as one text frame and then closes the
// connection. A failed write is only logged; the connection is closed either
// way and any error from Close is discarded.
func (h *ChatService) ChatFail(c *websocket.Conn, content string) {
	payload := []byte(content)
	if writeErr := c.WriteMessage(websocket.TextMessage, payload); writeErr != nil {
		log.Println("发送错误:", writeErr)
	}
	// Best effort: the connection is being abandoned, so a Close failure is not actionable.
	_ = c.Close()
}
 | 
						||
 | 
						||
func generateClientID() string {
 | 
						||
	// 使用时间戳+随机数确保唯一性
 | 
						||
	timestamp := time.Now().UnixNano()
 | 
						||
	randomBytes := make([]byte, 4)
 | 
						||
	rand.Read(randomBytes)
 | 
						||
	randomStr := hex.EncodeToString(randomBytes)
 | 
						||
	return fmt.Sprintf("%d%s", timestamp, randomStr)
 | 
						||
}
 | 
						||
// Chat serves one WebSocket connection for its whole lifetime.
//
// It registers the connection with the gateway under a freshly generated
// client ID, then reads frames in a loop. Each data frame is decoded as an
// entitys.ChatSockRequest, optionally bound to a session, and dispatched to
// the router. A routing error is reported to the client as a text response.
// After every dispatched request an entitys.ResponseEnd marker is sent.
// Frames that fail JSON decoding are logged and skipped without a response.
// The loop ends on a read error or a close frame. On exit (including a panic
// in a handler), the client is removed from the gateway and the connection is
// closed.
//
// NOTE(review): both the gateway's SendFunc and this goroutine write to c. If
// the gateway can push from another goroutine, writes may race. Confirm that
// gateway sends are serialized with this handler.
func (h *ChatService) Chat(c *websocket.Conn) {
	h.mu.Lock()
	clientID := generateClientID()
	h.mu.Unlock()

	client := &gateway.Client{
		ID: clientID,
		SendFunc: func(data []byte) error {
			return c.WriteMessage(websocket.TextMessage, data)
		},
	}
	h.Gw.AddClient(client)
	log.Println("client connected:", clientID)
	log.Println("客户端已连接")

	// Deferred so the gateway entry and the socket are released on every exit
	// path, including a panic from the router, not just a normal loop exit.
	defer func() {
		h.Gw.RemoveClient(clientID)
		_ = c.Close()
		log.Println("client disconnected:", clientID)
	}()

	// Read client messages until the connection ends.
	for {
		messageType, message, err := c.ReadMessage()
		if err != nil {
			log.Println("读取错误:", err)
			return
		}

		msg, chatType := h.handleMessageToString(c, messageType, message)
		if chatType == constants.ConnStatusClosed {
			return
		}
		if chatType == constants.ConnStatusIgnore {
			continue
		}

		log.Printf("收到消息: %s", msg)
		var req entitys.ChatSockRequest
		if err := json.Unmarshal(msg, &req); err != nil {
			log.Println("JSON parse error:", err)
			continue
		}

		// Simple bind protocol: when the client negotiated the "bind"
		// subprotocol and the request carries a session ID, associate this
		// connection with that session in the gateway. The log reports the ID
		// that was actually bound.
		if c.Headers("Sec-Websocket-Protocol") == "bind" && req.SessionID != "" {
			if err := h.Gw.BindUid(clientID, req.SessionID); err != nil {
				log.Printf("bind %s -> uid:%s failed: %v\n", clientID, req.SessionID, err)
			} else {
				log.Printf("bind %s -> uid:%s\n", clientID, req.SessionID)
			}
		}

		if err := h.routerBiz.RouteWithSocket(c, &req); err != nil {
			log.Println("处理失败:", err)
			_ = entitys.MsgSend(c, entitys.Response{
				Content: err.Error(),
				Type:    entitys.ResponseText,
			})
		}

		// Always terminate the reply so the client knows this request is done.
		_ = entitys.MsgSend(c, entitys.Response{
			Content: "",
			Type:    entitys.ResponseEnd,
		})
	}
}
 | 
						||
 | 
						||
// handleMessageToString classifies a frame returned by c.ReadMessage and
// extracts its payload.
//
//   - Text and binary frames return their payload with ConnStatusNormal.
//   - A close frame returns ConnStatusClosed so the caller stops reading.
//   - A ping frame is answered with a pong that echoes the ping's payload
//     (RFC 6455 §5.5.3) and returns ConnStatusIgnore.
//   - Pong and any other frame type return ConnStatusIgnore.
//
// msg is expected to hold the []byte payload from ReadMessage. A data frame
// whose msg has any other dynamic type is logged and ignored instead of
// panicking.
//
// NOTE(review): gorilla-derived websocket libraries normally consume control
// frames in handlers, so ReadMessage may never report ping/pong/close here.
// Confirm against github.com/gofiber/websocket/v2.
func (h *ChatService) handleMessageToString(c *websocket.Conn, msgType int, msg any) (text []byte, chatType constants.ConnStatus) {
	switch msgType {
	case websocket.TextMessage, websocket.BinaryMessage:
		payload, ok := msg.([]byte)
		if !ok {
			log.Printf("unexpected websocket payload type %T", msg)
			return nil, constants.ConnStatusIgnore
		}
		return payload, constants.ConnStatusNormal
	case websocket.CloseMessage:
		return nil, constants.ConnStatusClosed
	case websocket.PingMessage:
		// A pong answering a ping must carry the same application data.
		payload, _ := msg.([]byte)
		if err := c.WriteMessage(websocket.PongMessage, payload); err != nil {
			log.Println("pong write error:", err)
		}
		return nil, constants.ConnStatusIgnore
	default:
		// Pong frames and unknown frame types carry nothing for the chat layer.
		return nil, constants.ConnStatusIgnore
	}
}
 |