Compare commits

..

2 Commits

Author SHA1 Message Date
wuchao addcf2d24d feat(internal): 实现网关服务并优化聊天功能
- 新增 Gateway 结构体和相关方法,用于管理客户端连接和消息分发
- 重构 ChatService,集成 Gateway 功能
- 添加客户端绑定 UID 功能,支持消息发送到指定用户
- 实现全局消息广播和单用户消息发送的 HTTP 接口
- 优化聊天消息处理逻辑,增加消息回显功能
2025-09-19 11:56:39 +08:00
wuchao 8a2411016c feat(internal): 实现网关服务并优化聊天功能
- 新增 Gateway 结构体和相关方法,用于管理客户端连接和消息分发
- 重构 ChatService,集成 Gateway 功能
- 添加客户端绑定 UID 功能,支持消息发送到指定用户
- 实现全局消息广播和单用户消息发送的 HTTP 接口
- 优化聊天消息处理逻辑,增加消息回显功能
2025-09-19 11:23:56 +08:00
5 changed files with 179 additions and 10 deletions

104
internal/gateway/gateway.go Normal file
View File

@ -0,0 +1,104 @@
package gateway
import (
"errors"
"sync"
)
type Client struct {
ID string
SendFunc func(data []byte) error
}
type Gateway struct {
mu sync.RWMutex
clients map[string]*Client // clientID -> Client
uidMap map[string][]string // uid -> []clientID
}
func NewGateway() *Gateway {
return &Gateway{
clients: make(map[string]*Client),
uidMap: make(map[string][]string),
}
}
func (g *Gateway) AddClient(c *Client) {
g.mu.Lock()
defer g.mu.Unlock()
g.clients[c.ID] = c
}
func (g *Gateway) RemoveClient(clientID string) {
g.mu.Lock()
defer g.mu.Unlock()
delete(g.clients, clientID)
for uid, list := range g.uidMap {
newList := []string{}
for _, cid := range list {
if cid != clientID {
newList = append(newList, cid)
}
}
g.uidMap[uid] = newList
}
}
func (g *Gateway) SendToAll(msg []byte) {
g.mu.RLock()
defer g.mu.RUnlock()
for _, c := range g.clients {
_ = c.SendFunc(msg)
}
}
func (g *Gateway) SendToClient(clientID string, msg []byte) error {
g.mu.RLock()
defer g.mu.RUnlock()
if c, ok := g.clients[clientID]; ok {
return c.SendFunc(msg)
}
return errors.New("client not found")
}
func (g *Gateway) BindUid(clientID, uid string) error {
g.mu.Lock()
defer g.mu.Unlock()
if _, ok := g.clients[clientID]; !ok {
return errors.New("client not found")
}
g.uidMap[uid] = append(g.uidMap[uid], clientID)
return nil
}
func (g *Gateway) SendToUid(uid string, msg []byte) {
g.mu.RLock()
defer g.mu.RUnlock()
if list, ok := g.uidMap[uid]; ok {
for _, cid := range list {
if c, ok := g.clients[cid]; ok {
_ = c.SendFunc(msg)
}
}
}
}
func (g *Gateway) ListClients() []string {
g.mu.RLock()
defer g.mu.RUnlock()
ids := make([]string, 0, len(g.clients))
for id := range g.clients {
ids = append(ids, id)
}
return ids
}
func (g *Gateway) ListUids() map[string][]string {
g.mu.RLock()
defer g.mu.RUnlock()
result := make(map[string][]string, len(g.uidMap))
for uid, list := range g.uidMap {
result[uid] = append([]string(nil), list...)
}
return result
}

View File

@ -1,6 +1,7 @@
package server package server
import ( import (
"ai_scheduler/internal/gateway"
"ai_scheduler/internal/server/router" "ai_scheduler/internal/server/router"
"ai_scheduler/internal/services" "ai_scheduler/internal/services"
@ -12,10 +13,11 @@ import (
func NewHTTPServer( func NewHTTPServer(
service *services.ChatService, service *services.ChatService,
session *services.SessionService, session *services.SessionService,
gateway *gateway.Gateway,
) *fiber.App { ) *fiber.App {
//构建 server //构建 server
app := initRoute() app := initRoute()
router.SetupRoutes(app, service, session) router.SetupRoutes(app, service, session, gateway)
return app return app
} }

View File

@ -2,8 +2,10 @@ package router
import ( import (
errors "ai_scheduler/internal/data/error" errors "ai_scheduler/internal/data/error"
"ai_scheduler/internal/gateway"
"ai_scheduler/internal/services" "ai_scheduler/internal/services"
"encoding/json" "encoding/json"
"fmt"
"strings" "strings"
"sync" "sync"
"time" "time"
@ -13,7 +15,7 @@ import (
) )
// SetupRoutes 设置路由 // SetupRoutes 设置路由
func SetupRoutes(app *fiber.App, ChatService *services.ChatService, sessionService *services.SessionService) { func SetupRoutes(app *fiber.App, ChatService *services.ChatService, sessionService *services.SessionService, gateway *gateway.Gateway) {
app.Use(func(c *fiber.Ctx) error { app.Use(func(c *fiber.Ctx) error {
// 设置 CORS 头 // 设置 CORS 头
c.Set("Access-Control-Allow-Origin", "*") c.Set("Access-Control-Allow-Origin", "*")
@ -28,7 +30,7 @@ func SetupRoutes(app *fiber.App, ChatService *services.ChatService, sessionServi
// 继续处理后续中间件或路由 // 继续处理后续中间件或路由
return c.Next() return c.Next()
}) })
routerHttp(app, sessionService) routerHttp(app, sessionService, gateway)
routerSocket(app, ChatService) routerSocket(app, ChatService)
} }
@ -56,7 +58,7 @@ var bufferPool = &sync.Pool{
}, },
} }
func routerHttp(app *fiber.App, sessionService *services.SessionService) { func routerHttp(app *fiber.App, sessionService *services.SessionService, gateway *gateway.Gateway) {
r := app.Group("api/v1/") r := app.Group("api/v1/")
registerResponse(r) registerResponse(r)
// 注册 CORS 中间件 // 注册 CORS 中间件
@ -67,7 +69,26 @@ func routerHttp(app *fiber.App, sessionService *services.SessionService) {
r.Post("/session/init", sessionService.SessionInit) // 会话初始化,不存在则创建,存在则返回会话ID和默认条数会话历史 r.Post("/session/init", sessionService.SessionInit) // 会话初始化,不存在则创建,存在则返回会话ID和默认条数会话历史
r.Post("/session/list", sessionService.SessionList) r.Post("/session/list", sessionService.SessionList)
//广播
r.Get("/broadcast", func(ctx *fiber.Ctx) error {
action := ctx.Query("action")
uid := ctx.Query("uid")
msg := ctx.Query("msg")
switch action {
case "sendToAll":
gateway.SendToAll([]byte(msg))
return ctx.SendString("sent to all")
case "sendToUid":
if uid == "" {
return ctx.Status(400).SendString("missing uid")
}
gateway.SendToUid(uid, []byte(msg))
return ctx.SendString(fmt.Sprintf("sent to uid %s", uid))
default:
return ctx.Status(400).SendString("unknown action")
}
})
} }
func registerResponse(router fiber.Router) { func registerResponse(router fiber.Router) {

View File

@ -3,8 +3,14 @@ package services
import ( import (
"ai_scheduler/internal/data/constant" "ai_scheduler/internal/data/constant"
"ai_scheduler/internal/entitys" "ai_scheduler/internal/entitys"
"ai_scheduler/internal/gateway"
"encoding/hex"
"encoding/json" "encoding/json"
"fmt"
"log" "log"
"math/rand"
"sync"
"time"
"github.com/gofiber/websocket/v2" "github.com/gofiber/websocket/v2"
) )
@ -12,12 +18,15 @@ import (
// ChatHandler 聊天处理器 // ChatHandler 聊天处理器
type ChatService struct { type ChatService struct {
routerService entitys.RouterService routerService entitys.RouterService
Gw *gateway.Gateway
mu sync.Mutex
} }
// NewChatHandler 创建聊天处理器 // NewChatHandler 创建聊天处理器
func NewChatService(routerService entitys.RouterService) *ChatService { func NewChatService(routerService entitys.RouterService, gw *gateway.Gateway) *ChatService {
return &ChatService{ return &ChatService{
routerService: routerService, routerService: routerService,
Gw: gw,
} }
} }
@ -36,12 +45,34 @@ type FunctionCallResponse struct {
} }
func (h *ChatService) ChatFail(c *websocket.Conn, content string) { func (h *ChatService) ChatFail(c *websocket.Conn, content string) {
//c.WriteMessage(messageType, message) err := c.WriteMessage(websocket.TextMessage, []byte(content))
if err != nil {
log.Println("发送错误:", err)
}
_ = c.Close()
} }
func generateClientID() string {
// 使用时间戳+随机数确保唯一性
timestamp := time.Now().UnixNano()
randomBytes := make([]byte, 4)
rand.Read(randomBytes)
randomStr := hex.EncodeToString(randomBytes)
return fmt.Sprintf("%d%s", timestamp, randomStr)
}
func (h *ChatService) Chat(c *websocket.Conn) { func (h *ChatService) Chat(c *websocket.Conn) {
h.mu.Lock()
clientID := generateClientID()
h.mu.Unlock()
client := &gateway.Client{
ID: clientID,
SendFunc: func(data []byte) error {
return c.WriteMessage(websocket.TextMessage, data)
},
}
h.Gw.AddClient(client)
log.Println("client connected:", clientID)
log.Println("客户端已连接") log.Println("客户端已连接")
defer c.Close()
// 循环读取客户端消息 // 循环读取客户端消息
for { for {
messageType, message, err := c.ReadMessage() messageType, message, err := c.ReadMessage()
@ -49,6 +80,12 @@ func (h *ChatService) Chat(c *websocket.Conn) {
log.Println("读取错误:", err) log.Println("读取错误:", err)
break break
} }
//简单协议bind:<uid>
if c.Headers("Sec-Websocket-Protocol") == "bind" && c.Headers("X-Session") != "" {
uid := c.Headers("X-Session")
_ = h.Gw.BindUid(clientID, uid)
log.Printf("bind %s -> uid:%s\n", clientID, uid)
}
msg, chatType := h.handleMessageToString(c, messageType, message) msg, chatType := h.handleMessageToString(c, messageType, message)
if chatType == constant.ConnStatusClosed { if chatType == constant.ConnStatusClosed {
break break
@ -69,7 +106,9 @@ func (h *ChatService) Chat(c *websocket.Conn) {
continue continue
} }
} }
log.Println("客户端已断开") h.Gw.RemoveClient(clientID)
_ = c.Close()
log.Println("client disconnected:", clientID)
} }
func (h *ChatService) handleMessageToString(c *websocket.Conn, msgType int, msg any) (text []byte, chatType constant.ConnStatus) { func (h *ChatService) handleMessageToString(c *websocket.Conn, msgType int, msg any) (text []byte, chatType constant.ConnStatus) {

View File

@ -1,5 +1,8 @@
package services package services
import "github.com/google/wire" import (
"ai_scheduler/internal/gateway"
"github.com/google/wire"
)
var ProviderSetServices = wire.NewSet(NewChatService, NewSessionService) var ProviderSetServices = wire.NewSet(NewChatService, NewSessionService, gateway.NewGateway)