结构修改

This commit is contained in:
renzhiyuan 2025-09-22 20:36:17 +08:00
parent b996f62098
commit c07b57a705
11 changed files with 125 additions and 24 deletions

View File

@ -10,18 +10,21 @@ import (
) )
func main() { func main() {
configPath := flag.String("config", "./config/config.yaml", "Path to configuration file") configPath := flag.String("config", "./config/config.yaml", "Path to configuration file")
flag.Parse() flag.Parse()
bc, err := config.LoadConfig(*configPath) bc, err := config.LoadConfig(*configPath)
if err != nil { if err != nil {
log.Fatalf("加载配置失败: %v", err) log.Fatalf("加载配置失败: %v", err)
} }
app, cleanup, err := InitializeApp(bc, log.DefaultLogger()) app, cleanup, err := InitializeApp(bc, log.DefaultLogger())
if err != nil { if err != nil {
log.Fatalf("项目初始化失败: %v", err) log.Fatalf("项目初始化失败: %v", err)
} }
defer cleanup() defer func() {
cleanup()
}()
log.Fatal(app.HttpServer.Listen(fmt.Sprintf(":%d", bc.Server.Port))) log.Fatal(app.HttpServer.Listen(fmt.Sprintf(":%d", bc.Server.Port)))
} }

View File

@ -12,6 +12,8 @@ ollama:
sys: sys:
session_len: 3 session_len: 3
channel_pool_len: 100
channel_pool_size: 32
redis: redis:
host: 47.97.27.195:6379 host: 47.97.27.195:6379

View File

@ -2,7 +2,7 @@ package biz
import ( import (
"ai_scheduler/internal/config" "ai_scheduler/internal/config"
"ai_scheduler/internal/data/constant" "ai_scheduler/internal/data/constants"
errors "ai_scheduler/internal/data/error" errors "ai_scheduler/internal/data/error"
"ai_scheduler/internal/data/impl" "ai_scheduler/internal/data/impl"
"ai_scheduler/internal/data/model" "ai_scheduler/internal/data/model"
@ -34,6 +34,7 @@ type AiRouterBiz struct {
hisImpl *impl.ChatImpl hisImpl *impl.ChatImpl
conf *config.Config conf *config.Config
utilAgent *utils_ollama.UtilOllama utilAgent *utils_ollama.UtilOllama
channelPool *pkg.SafeChannelPool
} }
// NewRouterService 创建路由服务 // NewRouterService 创建路由服务
@ -46,7 +47,7 @@ func NewAiRouterBiz(
hisImpl *impl.ChatImpl, hisImpl *impl.ChatImpl,
conf *config.Config, conf *config.Config,
utilAgent *utils_ollama.UtilOllama, utilAgent *utils_ollama.UtilOllama,
channelPool *pkg.SafeChannelPool,
) *AiRouterBiz { ) *AiRouterBiz {
return &AiRouterBiz{ return &AiRouterBiz{
//aiClient: aiClient, //aiClient: aiClient,
@ -57,6 +58,7 @@ func NewAiRouterBiz(
hisImpl: hisImpl, hisImpl: hisImpl,
taskImpl: taskImpl, taskImpl: taskImpl,
utilAgent: utilAgent, utilAgent: utilAgent,
channelPool: channelPool,
} }
} }
@ -68,13 +70,16 @@ func (r *AiRouterBiz) Route(ctx context.Context, req *entitys.ChatRequest) (*ent
// Route 执行智能路由 // Route 执行智能路由
func (r *AiRouterBiz) RouteWithSocket(c *websocket.Conn, req *entitys.ChatSockRequest) (err error) { func (r *AiRouterBiz) RouteWithSocket(c *websocket.Conn, req *entitys.ChatSockRequest) (err error) {
var ch = make(chan []byte) ch, err := r.channelPool.Get()
if err != nil {
return err
}
defer func() { defer func() {
if err != nil { if err != nil {
_ = c.WriteMessage(websocket.TextMessage, []byte(err.Error())) _ = c.WriteMessage(websocket.TextMessage, []byte(err.Error()))
} }
_ = c.WriteMessage(websocket.TextMessage, []byte("EOF")) _ = c.WriteMessage(websocket.TextMessage, []byte("EOF"))
r.channelPool.Put(ch)
}() }()
session := c.Headers("X-Session", "") session := c.Headers("X-Session", "")
if len(session) == 0 { if len(session) == 0 {
@ -164,9 +169,9 @@ func (r *AiRouterBiz) handleMatch(c *websocket.Conn, ch chan []byte, matchJson *
return r.handleOtherTask(c, ch, matchJson) return r.handleOtherTask(c, ch, matchJson)
} }
switch pointTask.Type { switch pointTask.Type {
case constant.TaskTypeApi: case constants.TaskTypeApi:
return r.handleApiTask(ch, c, matchJson, pointTask) return r.handleApiTask(ch, c, matchJson, pointTask)
case constant.TaskTypeFunc: case constants.TaskTypeFunc:
return r.handleTask(ch, c, matchJson, pointTask) return r.handleTask(ch, c, matchJson, pointTask)
default: default:
return r.handleOtherTask(c, ch, matchJson) return r.handleOtherTask(c, ch, matchJson)

View File

@ -26,6 +26,8 @@ type LLM struct {
// SysConfig 系统配置 // SysConfig 系统配置
type SysConfig struct { type SysConfig struct {
SessionLen int `mapstructure:"session_len"` SessionLen int `mapstructure:"session_len"`
ChannelPoolLen int `mapstructure:"channel_pool_len"`
ChannelPoolSize int `mapstructure:"channel_pool_size"`
} }
// ServerConfig 服务器配置 // ServerConfig 服务器配置

View File

@ -1,3 +0,0 @@
package constant
const ()

View File

@ -1,4 +1,4 @@
package constant package constants
type ConnStatus int8 type ConnStatus int8

View File

@ -0,0 +1,3 @@
package constants
const ()

View File

@ -0,0 +1,13 @@
package entitys
type Response string
const (
ResponseJson Response = "json"
ResponseLoading Response = "loading"
ResponseEnd Response = "end"
ResponseStream Response = "stream"
ResponseText Response = "txt"
ResponseImg Response = "img"
ResponseFile Response = "file"
)

View File

@ -0,0 +1,75 @@
package pkg
import (
"ai_scheduler/internal/config"
"errors"
"sync"
)
type SafeChannelPool struct {
pool chan chan []byte // 存储空闲 channel 的队列
bufSize int // channel 缓冲大小
mu sync.Mutex
closed bool
}
func NewSafeChannelPool(c *config.Config) (*SafeChannelPool, func()) {
pool := &SafeChannelPool{
pool: make(chan chan []byte, c.Sys.ChannelPoolLen),
bufSize: c.Sys.ChannelPoolSize,
}
cleanup := pool.Close
return pool, cleanup
}
// 从池中获取 channel若无空闲则创建新 channel
func (p *SafeChannelPool) Get() (chan []byte, error) {
p.mu.Lock()
defer p.mu.Unlock()
if p.closed {
return nil, errors.New("pool is closed")
}
select {
case ch := <-p.pool: // 从池中取
return ch, nil
default: // 池为空,创建新 channel
return make(chan []byte, p.bufSize), nil
}
}
// 将 channel 放回池中(必须确保 channel 已清空!)
func (p *SafeChannelPool) Put(ch chan []byte) error {
p.mu.Lock()
defer p.mu.Unlock()
if p.closed {
return errors.New("pool is closed")
}
// 清空 channel防止复用时读取旧数据
go func() {
for range ch {
// 丢弃所有数据(或根据业务需求处理)
}
}()
select {
case p.pool <- ch: // 尝试放回池中
default: // 池已满,直接关闭 channel避免泄漏
close(ch)
}
return nil
}
// 关闭池(释放所有资源)
func (p *SafeChannelPool) Close() {
p.mu.Lock()
defer p.mu.Unlock()
p.closed = true
close(p.pool) // 关闭池队列
// 需额外逻辑关闭所有内部 channel此处简化
}

View File

@ -11,4 +11,5 @@ var ProviderSetClient = wire.NewSet(
NewGormDb, NewGormDb,
utils_ollama.NewUtilOllama, utils_ollama.NewUtilOllama,
utils_ollama.NewClient, utils_ollama.NewClient,
NewSafeChannelPool,
) )

View File

@ -2,7 +2,7 @@ package services
import ( import (
"ai_scheduler/internal/biz" "ai_scheduler/internal/biz"
"ai_scheduler/internal/data/constant" "ai_scheduler/internal/data/constants"
"ai_scheduler/internal/entitys" "ai_scheduler/internal/entitys"
"ai_scheduler/internal/gateway" "ai_scheduler/internal/gateway"
"encoding/hex" "encoding/hex"
@ -88,10 +88,10 @@ func (h *ChatService) Chat(c *websocket.Conn) {
log.Printf("bind %s -> uid:%s\n", clientID, uid) log.Printf("bind %s -> uid:%s\n", clientID, uid)
} }
msg, chatType := h.handleMessageToString(c, messageType, message) msg, chatType := h.handleMessageToString(c, messageType, message)
if chatType == constant.ConnStatusClosed { if chatType == constants.ConnStatusClosed {
break break
} }
if chatType == constant.ConnStatusIgnore { if chatType == constants.ConnStatusIgnore {
continue continue
} }
@ -112,23 +112,23 @@ func (h *ChatService) Chat(c *websocket.Conn) {
log.Println("client disconnected:", clientID) log.Println("client disconnected:", clientID)
} }
func (h *ChatService) handleMessageToString(c *websocket.Conn, msgType int, msg any) (text []byte, chatType constant.ConnStatus) { func (h *ChatService) handleMessageToString(c *websocket.Conn, msgType int, msg any) (text []byte, chatType constants.ConnStatus) {
switch msgType { switch msgType {
case websocket.TextMessage: case websocket.TextMessage:
return msg.([]byte), constant.ConnStatusNormal return msg.([]byte), constants.ConnStatusNormal
case websocket.BinaryMessage: case websocket.BinaryMessage:
return msg.([]byte), constant.ConnStatusNormal return msg.([]byte), constants.ConnStatusNormal
case websocket.CloseMessage: case websocket.CloseMessage:
return nil, constant.ConnStatusClosed return nil, constants.ConnStatusClosed
case websocket.PingMessage: case websocket.PingMessage:
// 可选:回复 Pong // 可选:回复 Pong
c.WriteMessage(websocket.PongMessage, nil) c.WriteMessage(websocket.PongMessage, nil)
return nil, constant.ConnStatusIgnore return nil, constants.ConnStatusIgnore
case websocket.PongMessage: case websocket.PongMessage:
return nil, constant.ConnStatusIgnore return nil, constants.ConnStatusIgnore
default: default:
return nil, constant.ConnStatusIgnore return nil, constants.ConnStatusIgnore
} }
return msg.([]byte), constant.ConnStatusIgnore return msg.([]byte), constants.ConnStatusIgnore
} }