结构修改
This commit is contained in:
parent
b996f62098
commit
c07b57a705
|
@ -10,18 +10,21 @@ import (
|
|||
)
|
||||
|
||||
func main() {
|
||||
|
||||
configPath := flag.String("config", "./config/config.yaml", "Path to configuration file")
|
||||
flag.Parse()
|
||||
bc, err := config.LoadConfig(*configPath)
|
||||
if err != nil {
|
||||
log.Fatalf("加载配置失败: %v", err)
|
||||
}
|
||||
|
||||
app, cleanup, err := InitializeApp(bc, log.DefaultLogger())
|
||||
if err != nil {
|
||||
log.Fatalf("项目初始化失败: %v", err)
|
||||
}
|
||||
defer cleanup()
|
||||
defer func() {
|
||||
cleanup()
|
||||
|
||||
}()
|
||||
|
||||
log.Fatal(app.HttpServer.Listen(fmt.Sprintf(":%d", bc.Server.Port)))
|
||||
}
|
||||
|
|
|
@ -12,6 +12,8 @@ ollama:
|
|||
|
||||
sys:
|
||||
session_len: 3
|
||||
channel_pool_len: 100
|
||||
channel_pool_size: 32
|
||||
|
||||
redis:
|
||||
host: 47.97.27.195:6379
|
||||
|
|
|
@ -2,7 +2,7 @@ package biz
|
|||
|
||||
import (
|
||||
"ai_scheduler/internal/config"
|
||||
"ai_scheduler/internal/data/constant"
|
||||
"ai_scheduler/internal/data/constants"
|
||||
errors "ai_scheduler/internal/data/error"
|
||||
"ai_scheduler/internal/data/impl"
|
||||
"ai_scheduler/internal/data/model"
|
||||
|
@ -34,6 +34,7 @@ type AiRouterBiz struct {
|
|||
hisImpl *impl.ChatImpl
|
||||
conf *config.Config
|
||||
utilAgent *utils_ollama.UtilOllama
|
||||
channelPool *pkg.SafeChannelPool
|
||||
}
|
||||
|
||||
// NewRouterService 创建路由服务
|
||||
|
@ -46,7 +47,7 @@ func NewAiRouterBiz(
|
|||
hisImpl *impl.ChatImpl,
|
||||
conf *config.Config,
|
||||
utilAgent *utils_ollama.UtilOllama,
|
||||
|
||||
channelPool *pkg.SafeChannelPool,
|
||||
) *AiRouterBiz {
|
||||
return &AiRouterBiz{
|
||||
//aiClient: aiClient,
|
||||
|
@ -57,6 +58,7 @@ func NewAiRouterBiz(
|
|||
hisImpl: hisImpl,
|
||||
taskImpl: taskImpl,
|
||||
utilAgent: utilAgent,
|
||||
channelPool: channelPool,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -68,13 +70,16 @@ func (r *AiRouterBiz) Route(ctx context.Context, req *entitys.ChatRequest) (*ent
|
|||
|
||||
// Route 执行智能路由
|
||||
func (r *AiRouterBiz) RouteWithSocket(c *websocket.Conn, req *entitys.ChatSockRequest) (err error) {
|
||||
var ch = make(chan []byte)
|
||||
ch, err := r.channelPool.Get()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() {
|
||||
|
||||
if err != nil {
|
||||
_ = c.WriteMessage(websocket.TextMessage, []byte(err.Error()))
|
||||
}
|
||||
_ = c.WriteMessage(websocket.TextMessage, []byte("EOF"))
|
||||
r.channelPool.Put(ch)
|
||||
}()
|
||||
session := c.Headers("X-Session", "")
|
||||
if len(session) == 0 {
|
||||
|
@ -164,9 +169,9 @@ func (r *AiRouterBiz) handleMatch(c *websocket.Conn, ch chan []byte, matchJson *
|
|||
return r.handleOtherTask(c, ch, matchJson)
|
||||
}
|
||||
switch pointTask.Type {
|
||||
case constant.TaskTypeApi:
|
||||
case constants.TaskTypeApi:
|
||||
return r.handleApiTask(ch, c, matchJson, pointTask)
|
||||
case constant.TaskTypeFunc:
|
||||
case constants.TaskTypeFunc:
|
||||
return r.handleTask(ch, c, matchJson, pointTask)
|
||||
default:
|
||||
return r.handleOtherTask(c, ch, matchJson)
|
||||
|
|
|
@ -25,7 +25,9 @@ type LLM struct {
|
|||
|
||||
// SysConfig 系统配置
|
||||
type SysConfig struct {
|
||||
SessionLen int `mapstructure:"session_len"`
|
||||
SessionLen int `mapstructure:"session_len"`
|
||||
ChannelPoolLen int `mapstructure:"channel_pool_len"`
|
||||
ChannelPoolSize int `mapstructure:"channel_pool_size"`
|
||||
}
|
||||
|
||||
// ServerConfig 服务器配置
|
||||
|
|
|
@ -1,3 +0,0 @@
|
|||
package constant
|
||||
|
||||
const ()
|
|
@ -1,4 +1,4 @@
|
|||
package constant
|
||||
package constants
|
||||
|
||||
type ConnStatus int8
|
||||
|
|
@ -0,0 +1,3 @@
|
|||
package constants
|
||||
|
||||
const ()
|
|
@ -0,0 +1,13 @@
|
|||
package entitys
|
||||
|
||||
type Response string
|
||||
|
||||
const (
|
||||
ResponseJson Response = "json"
|
||||
ResponseLoading Response = "loading"
|
||||
ResponseEnd Response = "end"
|
||||
ResponseStream Response = "stream"
|
||||
ResponseText Response = "txt"
|
||||
ResponseImg Response = "img"
|
||||
ResponseFile Response = "file"
|
||||
)
|
|
@ -0,0 +1,75 @@
|
|||
package pkg
|
||||
|
||||
import (
|
||||
"ai_scheduler/internal/config"
|
||||
"errors"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type SafeChannelPool struct {
|
||||
pool chan chan []byte // 存储空闲 channel 的队列
|
||||
bufSize int // channel 缓冲大小
|
||||
mu sync.Mutex
|
||||
closed bool
|
||||
}
|
||||
|
||||
func NewSafeChannelPool(c *config.Config) (*SafeChannelPool, func()) {
|
||||
pool := &SafeChannelPool{
|
||||
pool: make(chan chan []byte, c.Sys.ChannelPoolLen),
|
||||
bufSize: c.Sys.ChannelPoolSize,
|
||||
}
|
||||
|
||||
cleanup := pool.Close
|
||||
return pool, cleanup
|
||||
}
|
||||
|
||||
// 从池中获取 channel(若无空闲则创建新 channel)
|
||||
func (p *SafeChannelPool) Get() (chan []byte, error) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
if p.closed {
|
||||
return nil, errors.New("pool is closed")
|
||||
}
|
||||
|
||||
select {
|
||||
case ch := <-p.pool: // 从池中取
|
||||
return ch, nil
|
||||
default: // 池为空,创建新 channel
|
||||
return make(chan []byte, p.bufSize), nil
|
||||
}
|
||||
}
|
||||
|
||||
// 将 channel 放回池中(必须确保 channel 已清空!)
|
||||
func (p *SafeChannelPool) Put(ch chan []byte) error {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
if p.closed {
|
||||
return errors.New("pool is closed")
|
||||
}
|
||||
|
||||
// 清空 channel(防止复用时读取旧数据)
|
||||
go func() {
|
||||
for range ch {
|
||||
// 丢弃所有数据(或根据业务需求处理)
|
||||
}
|
||||
}()
|
||||
|
||||
select {
|
||||
case p.pool <- ch: // 尝试放回池中
|
||||
default: // 池已满,直接关闭 channel(避免泄漏)
|
||||
close(ch)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// 关闭池(释放所有资源)
|
||||
func (p *SafeChannelPool) Close() {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
p.closed = true
|
||||
close(p.pool) // 关闭池队列
|
||||
// 需额外逻辑关闭所有内部 channel(此处简化)
|
||||
}
|
|
@ -11,4 +11,5 @@ var ProviderSetClient = wire.NewSet(
|
|||
NewGormDb,
|
||||
utils_ollama.NewUtilOllama,
|
||||
utils_ollama.NewClient,
|
||||
NewSafeChannelPool,
|
||||
)
|
||||
|
|
|
@ -2,7 +2,7 @@ package services
|
|||
|
||||
import (
|
||||
"ai_scheduler/internal/biz"
|
||||
"ai_scheduler/internal/data/constant"
|
||||
"ai_scheduler/internal/data/constants"
|
||||
"ai_scheduler/internal/entitys"
|
||||
"ai_scheduler/internal/gateway"
|
||||
"encoding/hex"
|
||||
|
@ -88,10 +88,10 @@ func (h *ChatService) Chat(c *websocket.Conn) {
|
|||
log.Printf("bind %s -> uid:%s\n", clientID, uid)
|
||||
}
|
||||
msg, chatType := h.handleMessageToString(c, messageType, message)
|
||||
if chatType == constant.ConnStatusClosed {
|
||||
if chatType == constants.ConnStatusClosed {
|
||||
break
|
||||
}
|
||||
if chatType == constant.ConnStatusIgnore {
|
||||
if chatType == constants.ConnStatusIgnore {
|
||||
continue
|
||||
}
|
||||
|
||||
|
@ -112,23 +112,23 @@ func (h *ChatService) Chat(c *websocket.Conn) {
|
|||
log.Println("client disconnected:", clientID)
|
||||
}
|
||||
|
||||
func (h *ChatService) handleMessageToString(c *websocket.Conn, msgType int, msg any) (text []byte, chatType constant.ConnStatus) {
|
||||
func (h *ChatService) handleMessageToString(c *websocket.Conn, msgType int, msg any) (text []byte, chatType constants.ConnStatus) {
|
||||
switch msgType {
|
||||
case websocket.TextMessage:
|
||||
return msg.([]byte), constant.ConnStatusNormal
|
||||
return msg.([]byte), constants.ConnStatusNormal
|
||||
case websocket.BinaryMessage:
|
||||
return msg.([]byte), constant.ConnStatusNormal
|
||||
return msg.([]byte), constants.ConnStatusNormal
|
||||
case websocket.CloseMessage:
|
||||
|
||||
return nil, constant.ConnStatusClosed
|
||||
return nil, constants.ConnStatusClosed
|
||||
case websocket.PingMessage:
|
||||
// 可选:回复 Pong
|
||||
c.WriteMessage(websocket.PongMessage, nil)
|
||||
return nil, constant.ConnStatusIgnore
|
||||
return nil, constants.ConnStatusIgnore
|
||||
case websocket.PongMessage:
|
||||
return nil, constant.ConnStatusIgnore
|
||||
return nil, constants.ConnStatusIgnore
|
||||
default:
|
||||
return nil, constant.ConnStatusIgnore
|
||||
return nil, constants.ConnStatusIgnore
|
||||
}
|
||||
return msg.([]byte), constant.ConnStatusIgnore
|
||||
return msg.([]byte), constants.ConnStatusIgnore
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue