Compare commits
No commits in common. "addcf2d24d4ce60fa254c9f458f8cb56bccfe86d" and "068563b91491d60fab01fdd95d76f83f65e109b0" have entirely different histories.
addcf2d24d
...
068563b914
|
@ -1,104 +0,0 @@
|
||||||
package gateway
|
|
||||||
|
|
||||||
import (
|
|
||||||
"errors"
|
|
||||||
"sync"
|
|
||||||
)
|
|
||||||
|
|
||||||
type Client struct {
|
|
||||||
ID string
|
|
||||||
SendFunc func(data []byte) error
|
|
||||||
}
|
|
||||||
|
|
||||||
type Gateway struct {
|
|
||||||
mu sync.RWMutex
|
|
||||||
clients map[string]*Client // clientID -> Client
|
|
||||||
uidMap map[string][]string // uid -> []clientID
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewGateway() *Gateway {
|
|
||||||
return &Gateway{
|
|
||||||
clients: make(map[string]*Client),
|
|
||||||
uidMap: make(map[string][]string),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (g *Gateway) AddClient(c *Client) {
|
|
||||||
g.mu.Lock()
|
|
||||||
defer g.mu.Unlock()
|
|
||||||
g.clients[c.ID] = c
|
|
||||||
}
|
|
||||||
|
|
||||||
func (g *Gateway) RemoveClient(clientID string) {
|
|
||||||
g.mu.Lock()
|
|
||||||
defer g.mu.Unlock()
|
|
||||||
delete(g.clients, clientID)
|
|
||||||
for uid, list := range g.uidMap {
|
|
||||||
newList := []string{}
|
|
||||||
for _, cid := range list {
|
|
||||||
if cid != clientID {
|
|
||||||
newList = append(newList, cid)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
g.uidMap[uid] = newList
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (g *Gateway) SendToAll(msg []byte) {
|
|
||||||
g.mu.RLock()
|
|
||||||
defer g.mu.RUnlock()
|
|
||||||
for _, c := range g.clients {
|
|
||||||
_ = c.SendFunc(msg)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (g *Gateway) SendToClient(clientID string, msg []byte) error {
|
|
||||||
g.mu.RLock()
|
|
||||||
defer g.mu.RUnlock()
|
|
||||||
if c, ok := g.clients[clientID]; ok {
|
|
||||||
return c.SendFunc(msg)
|
|
||||||
}
|
|
||||||
return errors.New("client not found")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (g *Gateway) BindUid(clientID, uid string) error {
|
|
||||||
g.mu.Lock()
|
|
||||||
defer g.mu.Unlock()
|
|
||||||
if _, ok := g.clients[clientID]; !ok {
|
|
||||||
return errors.New("client not found")
|
|
||||||
}
|
|
||||||
g.uidMap[uid] = append(g.uidMap[uid], clientID)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (g *Gateway) SendToUid(uid string, msg []byte) {
|
|
||||||
g.mu.RLock()
|
|
||||||
defer g.mu.RUnlock()
|
|
||||||
if list, ok := g.uidMap[uid]; ok {
|
|
||||||
for _, cid := range list {
|
|
||||||
if c, ok := g.clients[cid]; ok {
|
|
||||||
_ = c.SendFunc(msg)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (g *Gateway) ListClients() []string {
|
|
||||||
g.mu.RLock()
|
|
||||||
defer g.mu.RUnlock()
|
|
||||||
ids := make([]string, 0, len(g.clients))
|
|
||||||
for id := range g.clients {
|
|
||||||
ids = append(ids, id)
|
|
||||||
}
|
|
||||||
return ids
|
|
||||||
}
|
|
||||||
|
|
||||||
func (g *Gateway) ListUids() map[string][]string {
|
|
||||||
g.mu.RLock()
|
|
||||||
defer g.mu.RUnlock()
|
|
||||||
result := make(map[string][]string, len(g.uidMap))
|
|
||||||
for uid, list := range g.uidMap {
|
|
||||||
result[uid] = append([]string(nil), list...)
|
|
||||||
}
|
|
||||||
return result
|
|
||||||
}
|
|
|
@ -1,7 +1,6 @@
|
||||||
package server
|
package server
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"ai_scheduler/internal/gateway"
|
|
||||||
"ai_scheduler/internal/server/router"
|
"ai_scheduler/internal/server/router"
|
||||||
"ai_scheduler/internal/services"
|
"ai_scheduler/internal/services"
|
||||||
|
|
||||||
|
@ -13,11 +12,10 @@ import (
|
||||||
func NewHTTPServer(
|
func NewHTTPServer(
|
||||||
service *services.ChatService,
|
service *services.ChatService,
|
||||||
session *services.SessionService,
|
session *services.SessionService,
|
||||||
gateway *gateway.Gateway,
|
|
||||||
) *fiber.App {
|
) *fiber.App {
|
||||||
//构建 server
|
//构建 server
|
||||||
app := initRoute()
|
app := initRoute()
|
||||||
router.SetupRoutes(app, service, session, gateway)
|
router.SetupRoutes(app, service, session)
|
||||||
return app
|
return app
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -2,10 +2,8 @@ package router
|
||||||
|
|
||||||
import (
|
import (
|
||||||
errors "ai_scheduler/internal/data/error"
|
errors "ai_scheduler/internal/data/error"
|
||||||
"ai_scheduler/internal/gateway"
|
|
||||||
"ai_scheduler/internal/services"
|
"ai_scheduler/internal/services"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
@ -15,7 +13,7 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
// SetupRoutes 设置路由
|
// SetupRoutes 设置路由
|
||||||
func SetupRoutes(app *fiber.App, ChatService *services.ChatService, sessionService *services.SessionService, gateway *gateway.Gateway) {
|
func SetupRoutes(app *fiber.App, ChatService *services.ChatService, sessionService *services.SessionService) {
|
||||||
app.Use(func(c *fiber.Ctx) error {
|
app.Use(func(c *fiber.Ctx) error {
|
||||||
// 设置 CORS 头
|
// 设置 CORS 头
|
||||||
c.Set("Access-Control-Allow-Origin", "*")
|
c.Set("Access-Control-Allow-Origin", "*")
|
||||||
|
@ -30,7 +28,7 @@ func SetupRoutes(app *fiber.App, ChatService *services.ChatService, sessionServi
|
||||||
// 继续处理后续中间件或路由
|
// 继续处理后续中间件或路由
|
||||||
return c.Next()
|
return c.Next()
|
||||||
})
|
})
|
||||||
routerHttp(app, sessionService, gateway)
|
routerHttp(app, sessionService)
|
||||||
routerSocket(app, ChatService)
|
routerSocket(app, ChatService)
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@ -58,7 +56,7 @@ var bufferPool = &sync.Pool{
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
func routerHttp(app *fiber.App, sessionService *services.SessionService, gateway *gateway.Gateway) {
|
func routerHttp(app *fiber.App, sessionService *services.SessionService) {
|
||||||
r := app.Group("api/v1/")
|
r := app.Group("api/v1/")
|
||||||
registerResponse(r)
|
registerResponse(r)
|
||||||
// 注册 CORS 中间件
|
// 注册 CORS 中间件
|
||||||
|
@ -69,26 +67,7 @@ func routerHttp(app *fiber.App, sessionService *services.SessionService, gateway
|
||||||
|
|
||||||
r.Post("/session/init", sessionService.SessionInit) // 会话初始化,不存在则创建,存在则返回会话ID和默认条数会话历史
|
r.Post("/session/init", sessionService.SessionInit) // 会话初始化,不存在则创建,存在则返回会话ID和默认条数会话历史
|
||||||
r.Post("/session/list", sessionService.SessionList)
|
r.Post("/session/list", sessionService.SessionList)
|
||||||
//广播
|
|
||||||
r.Get("/broadcast", func(ctx *fiber.Ctx) error {
|
|
||||||
action := ctx.Query("action")
|
|
||||||
uid := ctx.Query("uid")
|
|
||||||
msg := ctx.Query("msg")
|
|
||||||
|
|
||||||
switch action {
|
|
||||||
case "sendToAll":
|
|
||||||
gateway.SendToAll([]byte(msg))
|
|
||||||
return ctx.SendString("sent to all")
|
|
||||||
case "sendToUid":
|
|
||||||
if uid == "" {
|
|
||||||
return ctx.Status(400).SendString("missing uid")
|
|
||||||
}
|
|
||||||
gateway.SendToUid(uid, []byte(msg))
|
|
||||||
return ctx.SendString(fmt.Sprintf("sent to uid %s", uid))
|
|
||||||
default:
|
|
||||||
return ctx.Status(400).SendString("unknown action")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func registerResponse(router fiber.Router) {
|
func registerResponse(router fiber.Router) {
|
||||||
|
|
|
@ -3,14 +3,8 @@ package services
|
||||||
import (
|
import (
|
||||||
"ai_scheduler/internal/data/constant"
|
"ai_scheduler/internal/data/constant"
|
||||||
"ai_scheduler/internal/entitys"
|
"ai_scheduler/internal/entitys"
|
||||||
"ai_scheduler/internal/gateway"
|
|
||||||
"encoding/hex"
|
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
|
||||||
"log"
|
"log"
|
||||||
"math/rand"
|
|
||||||
"sync"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/gofiber/websocket/v2"
|
"github.com/gofiber/websocket/v2"
|
||||||
)
|
)
|
||||||
|
@ -18,15 +12,12 @@ import (
|
||||||
// ChatHandler 聊天处理器
|
// ChatHandler 聊天处理器
|
||||||
type ChatService struct {
|
type ChatService struct {
|
||||||
routerService entitys.RouterService
|
routerService entitys.RouterService
|
||||||
Gw *gateway.Gateway
|
|
||||||
mu sync.Mutex
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewChatHandler 创建聊天处理器
|
// NewChatHandler 创建聊天处理器
|
||||||
func NewChatService(routerService entitys.RouterService, gw *gateway.Gateway) *ChatService {
|
func NewChatService(routerService entitys.RouterService) *ChatService {
|
||||||
return &ChatService{
|
return &ChatService{
|
||||||
routerService: routerService,
|
routerService: routerService,
|
||||||
Gw: gw,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -45,34 +36,12 @@ type FunctionCallResponse struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *ChatService) ChatFail(c *websocket.Conn, content string) {
|
func (h *ChatService) ChatFail(c *websocket.Conn, content string) {
|
||||||
err := c.WriteMessage(websocket.TextMessage, []byte(content))
|
//c.WriteMessage(messageType, message)
|
||||||
if err != nil {
|
|
||||||
log.Println("发送错误:", err)
|
|
||||||
}
|
|
||||||
_ = c.Close()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func generateClientID() string {
|
|
||||||
// 使用时间戳+随机数确保唯一性
|
|
||||||
timestamp := time.Now().UnixNano()
|
|
||||||
randomBytes := make([]byte, 4)
|
|
||||||
rand.Read(randomBytes)
|
|
||||||
randomStr := hex.EncodeToString(randomBytes)
|
|
||||||
return fmt.Sprintf("%d%s", timestamp, randomStr)
|
|
||||||
}
|
|
||||||
func (h *ChatService) Chat(c *websocket.Conn) {
|
func (h *ChatService) Chat(c *websocket.Conn) {
|
||||||
h.mu.Lock()
|
|
||||||
clientID := generateClientID()
|
|
||||||
h.mu.Unlock()
|
|
||||||
client := &gateway.Client{
|
|
||||||
ID: clientID,
|
|
||||||
SendFunc: func(data []byte) error {
|
|
||||||
return c.WriteMessage(websocket.TextMessage, data)
|
|
||||||
},
|
|
||||||
}
|
|
||||||
h.Gw.AddClient(client)
|
|
||||||
log.Println("client connected:", clientID)
|
|
||||||
log.Println("客户端已连接")
|
log.Println("客户端已连接")
|
||||||
|
defer c.Close()
|
||||||
// 循环读取客户端消息
|
// 循环读取客户端消息
|
||||||
for {
|
for {
|
||||||
messageType, message, err := c.ReadMessage()
|
messageType, message, err := c.ReadMessage()
|
||||||
|
@ -80,12 +49,6 @@ func (h *ChatService) Chat(c *websocket.Conn) {
|
||||||
log.Println("读取错误:", err)
|
log.Println("读取错误:", err)
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
//简单协议:bind:<uid>
|
|
||||||
if c.Headers("Sec-Websocket-Protocol") == "bind" && c.Headers("X-Session") != "" {
|
|
||||||
uid := c.Headers("X-Session")
|
|
||||||
_ = h.Gw.BindUid(clientID, uid)
|
|
||||||
log.Printf("bind %s -> uid:%s\n", clientID, uid)
|
|
||||||
}
|
|
||||||
msg, chatType := h.handleMessageToString(c, messageType, message)
|
msg, chatType := h.handleMessageToString(c, messageType, message)
|
||||||
if chatType == constant.ConnStatusClosed {
|
if chatType == constant.ConnStatusClosed {
|
||||||
break
|
break
|
||||||
|
@ -106,9 +69,7 @@ func (h *ChatService) Chat(c *websocket.Conn) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
h.Gw.RemoveClient(clientID)
|
log.Println("客户端已断开")
|
||||||
_ = c.Close()
|
|
||||||
log.Println("client disconnected:", clientID)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *ChatService) handleMessageToString(c *websocket.Conn, msgType int, msg any) (text []byte, chatType constant.ConnStatus) {
|
func (h *ChatService) handleMessageToString(c *websocket.Conn, msgType int, msg any) (text []byte, chatType constant.ConnStatus) {
|
||||||
|
|
|
@ -1,8 +1,5 @@
|
||||||
package services
|
package services
|
||||||
|
|
||||||
import (
|
import "github.com/google/wire"
|
||||||
"ai_scheduler/internal/gateway"
|
|
||||||
"github.com/google/wire"
|
|
||||||
)
|
|
||||||
|
|
||||||
var ProviderSetServices = wire.NewSet(NewChatService, NewSessionService, gateway.NewGateway)
|
var ProviderSetServices = wire.NewSet(NewChatService, NewSessionService)
|
||||||
|
|
Loading…
Reference in New Issue