diff --git a/cmd/server/main.go b/cmd/server/main.go index 2f94736..e5fbbdd 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -10,18 +10,21 @@ import ( ) func main() { - configPath := flag.String("config", "./config/config.yaml", "Path to configuration file") flag.Parse() bc, err := config.LoadConfig(*configPath) if err != nil { log.Fatalf("加载配置失败: %v", err) } + app, cleanup, err := InitializeApp(bc, log.DefaultLogger()) if err != nil { log.Fatalf("项目初始化失败: %v", err) } - defer cleanup() + defer func() { + cleanup() + + }() log.Fatal(app.HttpServer.Listen(fmt.Sprintf(":%d", bc.Server.Port))) } diff --git a/config/config.yaml b/config/config.yaml index c0f2035..8009dcd 100644 --- a/config/config.yaml +++ b/config/config.yaml @@ -12,6 +12,8 @@ ollama: sys: session_len: 3 + channel_pool_len: 100 + channel_pool_size: 32 redis: host: 47.97.27.195:6379 diff --git a/internal/biz/router.go b/internal/biz/router.go index f721791..d754745 100644 --- a/internal/biz/router.go +++ b/internal/biz/router.go @@ -2,7 +2,7 @@ package biz import ( "ai_scheduler/internal/config" - "ai_scheduler/internal/data/constant" + "ai_scheduler/internal/data/constants" errors "ai_scheduler/internal/data/error" "ai_scheduler/internal/data/impl" "ai_scheduler/internal/data/model" @@ -34,6 +34,7 @@ type AiRouterBiz struct { hisImpl *impl.ChatImpl conf *config.Config utilAgent *utils_ollama.UtilOllama + channelPool *pkg.SafeChannelPool } // NewRouterService 创建路由服务 @@ -46,7 +47,7 @@ func NewAiRouterBiz( hisImpl *impl.ChatImpl, conf *config.Config, utilAgent *utils_ollama.UtilOllama, - + channelPool *pkg.SafeChannelPool, ) *AiRouterBiz { return &AiRouterBiz{ //aiClient: aiClient, @@ -57,6 +58,7 @@ func NewAiRouterBiz( hisImpl: hisImpl, taskImpl: taskImpl, utilAgent: utilAgent, + channelPool: channelPool, } } @@ -68,13 +70,16 @@ func (r *AiRouterBiz) Route(ctx context.Context, req *entitys.ChatRequest) (*ent // Route 执行智能路由 func (r *AiRouterBiz) RouteWithSocket(c *websocket.Conn, req *entitys.ChatSockRequest) (err error) { - var ch = make(chan []byte) + ch, err := r.channelPool.Get() + if err != nil { + return err + } defer func() { - if err != nil { _ = c.WriteMessage(websocket.TextMessage, []byte(err.Error())) } _ = c.WriteMessage(websocket.TextMessage, []byte("EOF")) + r.channelPool.Put(ch) }() session := c.Headers("X-Session", "") if len(session) == 0 { @@ -164,9 +169,9 @@ func (r *AiRouterBiz) handleMatch(c *websocket.Conn, ch chan []byte, matchJson * return r.handleOtherTask(c, ch, matchJson) } switch pointTask.Type { - case constant.TaskTypeApi: + case constants.TaskTypeApi: return r.handleApiTask(ch, c, matchJson, pointTask) - case constant.TaskTypeFunc: + case constants.TaskTypeFunc: return r.handleTask(ch, c, matchJson, pointTask) default: return r.handleOtherTask(c, ch, matchJson) diff --git a/internal/config/config.go b/internal/config/config.go index 6a3ad43..6a0bcc4 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -25,7 +25,9 @@ type LLM struct { // SysConfig 系统配置 type SysConfig struct { - SessionLen int `mapstructure:"session_len"` + SessionLen int `mapstructure:"session_len"` + ChannelPoolLen int `mapstructure:"channel_pool_len"` + ChannelPoolSize int `mapstructure:"channel_pool_size"` } // ServerConfig 服务器配置 diff --git a/internal/data/constant/supplier.go b/internal/data/constant/supplier.go deleted file mode 100644 index 19f4670..0000000 --- a/internal/data/constant/supplier.go +++ /dev/null @@ -1,3 +0,0 @@ -package constant - -const () diff --git a/internal/data/constant/const.go b/internal/data/constants/const.go similarity index 91% rename from internal/data/constant/const.go rename to internal/data/constants/const.go index 13ef745..240188a 100644 --- a/internal/data/constant/const.go +++ b/internal/data/constants/const.go @@ -1,4 +1,4 @@ -package constant +package constants type ConnStatus int8 diff --git a/internal/data/constants/supplier.go b/internal/data/constants/supplier.go new file mode 100644 index 0000000..9c9295e --- /dev/null +++ b/internal/data/constants/supplier.go @@ -0,0 +1,3 @@ +package constants + +const () diff --git a/internal/entitys/response.go b/internal/entitys/response.go new file mode 100644 index 0000000..f8bcb00 --- /dev/null +++ b/internal/entitys/response.go @@ -0,0 +1,13 @@ +package entitys + +type Response string + +const ( + ResponseJson Response = "json" + ResponseLoading Response = "loading" + ResponseEnd Response = "end" + ResponseStream Response = "stream" + ResponseText Response = "txt" + ResponseImg Response = "img" + ResponseFile Response = "file" +) diff --git a/internal/pkg/channel_pool.go b/internal/pkg/channel_pool.go new file mode 100644 index 0000000..70502b2 --- /dev/null +++ b/internal/pkg/channel_pool.go @@ -0,0 +1,75 @@ +package pkg + +import ( + "ai_scheduler/internal/config" + "errors" + "sync" +) + +type SafeChannelPool struct { + pool chan chan []byte // 存储空闲 channel 的队列 + bufSize int // channel 缓冲大小 + mu sync.Mutex + closed bool +} + +func NewSafeChannelPool(c *config.Config) (*SafeChannelPool, func()) { + pool := &SafeChannelPool{ + pool: make(chan chan []byte, c.Sys.ChannelPoolLen), + bufSize: c.Sys.ChannelPoolSize, + } + + cleanup := pool.Close + return pool, cleanup +} + +// 从池中获取 channel(若无空闲则创建新 channel) +func (p *SafeChannelPool) Get() (chan []byte, error) { + p.mu.Lock() + defer p.mu.Unlock() + + if p.closed { + return nil, errors.New("pool is closed") + } + + select { + case ch := <-p.pool: // 从池中取 + return ch, nil + default: // 池为空,创建新 channel + return make(chan []byte, p.bufSize), nil + } +} + +// 将 channel 放回池中(必须确保 channel 已清空!) +func (p *SafeChannelPool) Put(ch chan []byte) error { + p.mu.Lock() + defer p.mu.Unlock() + + if p.closed { + return errors.New("pool is closed") + } + + // 清空 channel(防止复用时读取旧数据) + go func() { + for range ch { + // 丢弃所有数据(或根据业务需求处理) + } + }() + + select { + case p.pool <- ch: // 尝试放回池中 + default: // 池已满,直接关闭 channel(避免泄漏) + close(ch) + } + return nil +} + +// 关闭池(释放所有资源) +func (p *SafeChannelPool) Close() { + p.mu.Lock() + defer p.mu.Unlock() + + p.closed = true + close(p.pool) // 关闭池队列 + // 需额外逻辑关闭所有内部 channel(此处简化) +} diff --git a/internal/pkg/provider_set.go b/internal/pkg/provider_set.go index d123eed..f8b1a25 100644 --- a/internal/pkg/provider_set.go +++ b/internal/pkg/provider_set.go @@ -11,4 +11,5 @@ var ProviderSetClient = wire.NewSet( NewGormDb, utils_ollama.NewUtilOllama, utils_ollama.NewClient, + NewSafeChannelPool, ) diff --git a/internal/services/chat.go b/internal/services/chat.go index 82995e1..bb24ce1 100644 --- a/internal/services/chat.go +++ b/internal/services/chat.go @@ -2,7 +2,7 @@ package services import ( "ai_scheduler/internal/biz" - "ai_scheduler/internal/data/constant" + "ai_scheduler/internal/data/constants" "ai_scheduler/internal/entitys" "ai_scheduler/internal/gateway" "encoding/hex" @@ -88,10 +88,10 @@ func (h *ChatService) Chat(c *websocket.Conn) { log.Printf("bind %s -> uid:%s\n", clientID, uid) } msg, chatType := h.handleMessageToString(c, messageType, message) - if chatType == constant.ConnStatusClosed { + if chatType == constants.ConnStatusClosed { break } - if chatType == constant.ConnStatusIgnore { + if chatType == constants.ConnStatusIgnore { continue } @@ -112,23 +112,23 @@ func (h *ChatService) Chat(c *websocket.Conn) { log.Println("client disconnected:", clientID) } -func (h *ChatService) handleMessageToString(c *websocket.Conn, msgType int, msg any) (text []byte, chatType constant.ConnStatus) { +func (h *ChatService) handleMessageToString(c *websocket.Conn, msgType int, msg any) (text []byte, chatType constants.ConnStatus) { switch msgType { case websocket.TextMessage: - return msg.([]byte), constant.ConnStatusNormal + return msg.([]byte), constants.ConnStatusNormal case websocket.BinaryMessage: - return msg.([]byte), constant.ConnStatusNormal + return msg.([]byte), constants.ConnStatusNormal case websocket.CloseMessage: - return nil, constant.ConnStatusClosed + return nil, constants.ConnStatusClosed case websocket.PingMessage: // 可选:回复 Pong c.WriteMessage(websocket.PongMessage, nil) - return nil, constant.ConnStatusIgnore + return nil, constants.ConnStatusIgnore case websocket.PongMessage: - return nil, constant.ConnStatusIgnore + return nil, constants.ConnStatusIgnore default: - return nil, constant.ConnStatusIgnore + return nil, constants.ConnStatusIgnore } - return msg.([]byte), constant.ConnStatusIgnore + return msg.([]byte), constants.ConnStatusIgnore }