结构修改
This commit is contained in:
parent
b996f62098
commit
c07b57a705
|
@ -10,18 +10,21 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
|
|
||||||
configPath := flag.String("config", "./config/config.yaml", "Path to configuration file")
|
configPath := flag.String("config", "./config/config.yaml", "Path to configuration file")
|
||||||
flag.Parse()
|
flag.Parse()
|
||||||
bc, err := config.LoadConfig(*configPath)
|
bc, err := config.LoadConfig(*configPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("加载配置失败: %v", err)
|
log.Fatalf("加载配置失败: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
app, cleanup, err := InitializeApp(bc, log.DefaultLogger())
|
app, cleanup, err := InitializeApp(bc, log.DefaultLogger())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("项目初始化失败: %v", err)
|
log.Fatalf("项目初始化失败: %v", err)
|
||||||
}
|
}
|
||||||
defer cleanup()
|
defer func() {
|
||||||
|
cleanup()
|
||||||
|
|
||||||
|
}()
|
||||||
|
|
||||||
log.Fatal(app.HttpServer.Listen(fmt.Sprintf(":%d", bc.Server.Port)))
|
log.Fatal(app.HttpServer.Listen(fmt.Sprintf(":%d", bc.Server.Port)))
|
||||||
}
|
}
|
||||||
|
|
|
@ -12,6 +12,8 @@ ollama:
|
||||||
|
|
||||||
sys:
|
sys:
|
||||||
session_len: 3
|
session_len: 3
|
||||||
|
channel_pool_len: 100
|
||||||
|
channel_pool_size: 32
|
||||||
|
|
||||||
redis:
|
redis:
|
||||||
host: 47.97.27.195:6379
|
host: 47.97.27.195:6379
|
||||||
|
|
|
@ -2,7 +2,7 @@ package biz
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"ai_scheduler/internal/config"
|
"ai_scheduler/internal/config"
|
||||||
"ai_scheduler/internal/data/constant"
|
"ai_scheduler/internal/data/constants"
|
||||||
errors "ai_scheduler/internal/data/error"
|
errors "ai_scheduler/internal/data/error"
|
||||||
"ai_scheduler/internal/data/impl"
|
"ai_scheduler/internal/data/impl"
|
||||||
"ai_scheduler/internal/data/model"
|
"ai_scheduler/internal/data/model"
|
||||||
|
@ -34,6 +34,7 @@ type AiRouterBiz struct {
|
||||||
hisImpl *impl.ChatImpl
|
hisImpl *impl.ChatImpl
|
||||||
conf *config.Config
|
conf *config.Config
|
||||||
utilAgent *utils_ollama.UtilOllama
|
utilAgent *utils_ollama.UtilOllama
|
||||||
|
channelPool *pkg.SafeChannelPool
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewRouterService 创建路由服务
|
// NewRouterService 创建路由服务
|
||||||
|
@ -46,7 +47,7 @@ func NewAiRouterBiz(
|
||||||
hisImpl *impl.ChatImpl,
|
hisImpl *impl.ChatImpl,
|
||||||
conf *config.Config,
|
conf *config.Config,
|
||||||
utilAgent *utils_ollama.UtilOllama,
|
utilAgent *utils_ollama.UtilOllama,
|
||||||
|
channelPool *pkg.SafeChannelPool,
|
||||||
) *AiRouterBiz {
|
) *AiRouterBiz {
|
||||||
return &AiRouterBiz{
|
return &AiRouterBiz{
|
||||||
//aiClient: aiClient,
|
//aiClient: aiClient,
|
||||||
|
@ -57,6 +58,7 @@ func NewAiRouterBiz(
|
||||||
hisImpl: hisImpl,
|
hisImpl: hisImpl,
|
||||||
taskImpl: taskImpl,
|
taskImpl: taskImpl,
|
||||||
utilAgent: utilAgent,
|
utilAgent: utilAgent,
|
||||||
|
channelPool: channelPool,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -68,13 +70,16 @@ func (r *AiRouterBiz) Route(ctx context.Context, req *entitys.ChatRequest) (*ent
|
||||||
|
|
||||||
// Route 执行智能路由
|
// Route 执行智能路由
|
||||||
func (r *AiRouterBiz) RouteWithSocket(c *websocket.Conn, req *entitys.ChatSockRequest) (err error) {
|
func (r *AiRouterBiz) RouteWithSocket(c *websocket.Conn, req *entitys.ChatSockRequest) (err error) {
|
||||||
var ch = make(chan []byte)
|
ch, err := r.channelPool.Get()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
defer func() {
|
defer func() {
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
_ = c.WriteMessage(websocket.TextMessage, []byte(err.Error()))
|
_ = c.WriteMessage(websocket.TextMessage, []byte(err.Error()))
|
||||||
}
|
}
|
||||||
_ = c.WriteMessage(websocket.TextMessage, []byte("EOF"))
|
_ = c.WriteMessage(websocket.TextMessage, []byte("EOF"))
|
||||||
|
r.channelPool.Put(ch)
|
||||||
}()
|
}()
|
||||||
session := c.Headers("X-Session", "")
|
session := c.Headers("X-Session", "")
|
||||||
if len(session) == 0 {
|
if len(session) == 0 {
|
||||||
|
@ -164,9 +169,9 @@ func (r *AiRouterBiz) handleMatch(c *websocket.Conn, ch chan []byte, matchJson *
|
||||||
return r.handleOtherTask(c, ch, matchJson)
|
return r.handleOtherTask(c, ch, matchJson)
|
||||||
}
|
}
|
||||||
switch pointTask.Type {
|
switch pointTask.Type {
|
||||||
case constant.TaskTypeApi:
|
case constants.TaskTypeApi:
|
||||||
return r.handleApiTask(ch, c, matchJson, pointTask)
|
return r.handleApiTask(ch, c, matchJson, pointTask)
|
||||||
case constant.TaskTypeFunc:
|
case constants.TaskTypeFunc:
|
||||||
return r.handleTask(ch, c, matchJson, pointTask)
|
return r.handleTask(ch, c, matchJson, pointTask)
|
||||||
default:
|
default:
|
||||||
return r.handleOtherTask(c, ch, matchJson)
|
return r.handleOtherTask(c, ch, matchJson)
|
||||||
|
|
|
@ -26,6 +26,8 @@ type LLM struct {
|
||||||
// SysConfig 系统配置
|
// SysConfig 系统配置
|
||||||
type SysConfig struct {
|
type SysConfig struct {
|
||||||
SessionLen int `mapstructure:"session_len"`
|
SessionLen int `mapstructure:"session_len"`
|
||||||
|
ChannelPoolLen int `mapstructure:"channel_pool_len"`
|
||||||
|
ChannelPoolSize int `mapstructure:"channel_pool_size"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// ServerConfig 服务器配置
|
// ServerConfig 服务器配置
|
||||||
|
|
|
@ -1,3 +0,0 @@
|
||||||
package constant
|
|
||||||
|
|
||||||
const ()
|
|
|
@ -1,4 +1,4 @@
|
||||||
package constant
|
package constants
|
||||||
|
|
||||||
type ConnStatus int8
|
type ConnStatus int8
|
||||||
|
|
|
@ -0,0 +1,3 @@
|
||||||
|
package constants
|
||||||
|
|
||||||
|
const ()
|
|
@ -0,0 +1,13 @@
|
||||||
|
package entitys
|
||||||
|
|
||||||
|
type Response string
|
||||||
|
|
||||||
|
const (
|
||||||
|
ResponseJson Response = "json"
|
||||||
|
ResponseLoading Response = "loading"
|
||||||
|
ResponseEnd Response = "end"
|
||||||
|
ResponseStream Response = "stream"
|
||||||
|
ResponseText Response = "txt"
|
||||||
|
ResponseImg Response = "img"
|
||||||
|
ResponseFile Response = "file"
|
||||||
|
)
|
|
@ -0,0 +1,75 @@
|
||||||
|
package pkg
|
||||||
|
|
||||||
|
import (
|
||||||
|
"ai_scheduler/internal/config"
|
||||||
|
"errors"
|
||||||
|
"sync"
|
||||||
|
)
|
||||||
|
|
||||||
|
type SafeChannelPool struct {
|
||||||
|
pool chan chan []byte // 存储空闲 channel 的队列
|
||||||
|
bufSize int // channel 缓冲大小
|
||||||
|
mu sync.Mutex
|
||||||
|
closed bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewSafeChannelPool(c *config.Config) (*SafeChannelPool, func()) {
|
||||||
|
pool := &SafeChannelPool{
|
||||||
|
pool: make(chan chan []byte, c.Sys.ChannelPoolLen),
|
||||||
|
bufSize: c.Sys.ChannelPoolSize,
|
||||||
|
}
|
||||||
|
|
||||||
|
cleanup := pool.Close
|
||||||
|
return pool, cleanup
|
||||||
|
}
|
||||||
|
|
||||||
|
// 从池中获取 channel(若无空闲则创建新 channel)
|
||||||
|
func (p *SafeChannelPool) Get() (chan []byte, error) {
|
||||||
|
p.mu.Lock()
|
||||||
|
defer p.mu.Unlock()
|
||||||
|
|
||||||
|
if p.closed {
|
||||||
|
return nil, errors.New("pool is closed")
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case ch := <-p.pool: // 从池中取
|
||||||
|
return ch, nil
|
||||||
|
default: // 池为空,创建新 channel
|
||||||
|
return make(chan []byte, p.bufSize), nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 将 channel 放回池中(必须确保 channel 已清空!)
|
||||||
|
func (p *SafeChannelPool) Put(ch chan []byte) error {
|
||||||
|
p.mu.Lock()
|
||||||
|
defer p.mu.Unlock()
|
||||||
|
|
||||||
|
if p.closed {
|
||||||
|
return errors.New("pool is closed")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 清空 channel(防止复用时读取旧数据)
|
||||||
|
go func() {
|
||||||
|
for range ch {
|
||||||
|
// 丢弃所有数据(或根据业务需求处理)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case p.pool <- ch: // 尝试放回池中
|
||||||
|
default: // 池已满,直接关闭 channel(避免泄漏)
|
||||||
|
close(ch)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// 关闭池(释放所有资源)
|
||||||
|
func (p *SafeChannelPool) Close() {
|
||||||
|
p.mu.Lock()
|
||||||
|
defer p.mu.Unlock()
|
||||||
|
|
||||||
|
p.closed = true
|
||||||
|
close(p.pool) // 关闭池队列
|
||||||
|
// 需额外逻辑关闭所有内部 channel(此处简化)
|
||||||
|
}
|
|
@ -11,4 +11,5 @@ var ProviderSetClient = wire.NewSet(
|
||||||
NewGormDb,
|
NewGormDb,
|
||||||
utils_ollama.NewUtilOllama,
|
utils_ollama.NewUtilOllama,
|
||||||
utils_ollama.NewClient,
|
utils_ollama.NewClient,
|
||||||
|
NewSafeChannelPool,
|
||||||
)
|
)
|
||||||
|
|
|
@ -2,7 +2,7 @@ package services
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"ai_scheduler/internal/biz"
|
"ai_scheduler/internal/biz"
|
||||||
"ai_scheduler/internal/data/constant"
|
"ai_scheduler/internal/data/constants"
|
||||||
"ai_scheduler/internal/entitys"
|
"ai_scheduler/internal/entitys"
|
||||||
"ai_scheduler/internal/gateway"
|
"ai_scheduler/internal/gateway"
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
|
@ -88,10 +88,10 @@ func (h *ChatService) Chat(c *websocket.Conn) {
|
||||||
log.Printf("bind %s -> uid:%s\n", clientID, uid)
|
log.Printf("bind %s -> uid:%s\n", clientID, uid)
|
||||||
}
|
}
|
||||||
msg, chatType := h.handleMessageToString(c, messageType, message)
|
msg, chatType := h.handleMessageToString(c, messageType, message)
|
||||||
if chatType == constant.ConnStatusClosed {
|
if chatType == constants.ConnStatusClosed {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
if chatType == constant.ConnStatusIgnore {
|
if chatType == constants.ConnStatusIgnore {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -112,23 +112,23 @@ func (h *ChatService) Chat(c *websocket.Conn) {
|
||||||
log.Println("client disconnected:", clientID)
|
log.Println("client disconnected:", clientID)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *ChatService) handleMessageToString(c *websocket.Conn, msgType int, msg any) (text []byte, chatType constant.ConnStatus) {
|
func (h *ChatService) handleMessageToString(c *websocket.Conn, msgType int, msg any) (text []byte, chatType constants.ConnStatus) {
|
||||||
switch msgType {
|
switch msgType {
|
||||||
case websocket.TextMessage:
|
case websocket.TextMessage:
|
||||||
return msg.([]byte), constant.ConnStatusNormal
|
return msg.([]byte), constants.ConnStatusNormal
|
||||||
case websocket.BinaryMessage:
|
case websocket.BinaryMessage:
|
||||||
return msg.([]byte), constant.ConnStatusNormal
|
return msg.([]byte), constants.ConnStatusNormal
|
||||||
case websocket.CloseMessage:
|
case websocket.CloseMessage:
|
||||||
|
|
||||||
return nil, constant.ConnStatusClosed
|
return nil, constants.ConnStatusClosed
|
||||||
case websocket.PingMessage:
|
case websocket.PingMessage:
|
||||||
// 可选:回复 Pong
|
// 可选:回复 Pong
|
||||||
c.WriteMessage(websocket.PongMessage, nil)
|
c.WriteMessage(websocket.PongMessage, nil)
|
||||||
return nil, constant.ConnStatusIgnore
|
return nil, constants.ConnStatusIgnore
|
||||||
case websocket.PongMessage:
|
case websocket.PongMessage:
|
||||||
return nil, constant.ConnStatusIgnore
|
return nil, constants.ConnStatusIgnore
|
||||||
default:
|
default:
|
||||||
return nil, constant.ConnStatusIgnore
|
return nil, constants.ConnStatusIgnore
|
||||||
}
|
}
|
||||||
return msg.([]byte), constant.ConnStatusIgnore
|
return msg.([]byte), constants.ConnStatusIgnore
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue