diff --git a/config/config.yaml b/config/config.yaml
index c85ac02..c6945d6 100644
--- a/config/config.yaml
+++ b/config/config.yaml
@@ -9,12 +9,13 @@ ollama:
   timeout: "120s"
 log:
   level: "info"
   format: "json"
+
 sys:
   session_len: 3
   channel_pool_len: 100
   channel_pool_size: 32
-
+  llm_pool_len: 5
 redis:
   host: 47.97.27.195:6379
   type: node
diff --git a/internal/biz/router.go b/internal/biz/router.go
index 6e5fccb..1ed0ed8 100644
--- a/internal/biz/router.go
+++ b/internal/biz/router.go
@@ -16,10 +16,12 @@ import (
 	"encoding/json"
 	"fmt"
 	"strings"
+	"time"
 
 	"gitea.cdlsxd.cn/self-tools/l_request"
 	"github.com/gofiber/fiber/v2/log"
 	"github.com/gofiber/websocket/v2"
+	"github.com/ollama/ollama/api"
 	"github.com/tmc/langchaingo/llms"
 	"xorm.io/builder"
 )
@@ -34,6 +36,7 @@ type AiRouterBiz struct {
 	hisImpl     *impl.ChatImpl
 	conf        *config.Config
 	utilAgent   *utils_ollama.UtilOllama
+	ollama      *utils_ollama.Client
 	channelPool *pkg.SafeChannelPool
 }
 
@@ -48,6 +51,7 @@ func NewAiRouterBiz(
 	conf *config.Config,
 	utilAgent *utils_ollama.UtilOllama,
 	channelPool *pkg.SafeChannelPool,
+	ollama *utils_ollama.Client,
 ) *AiRouterBiz {
 	return &AiRouterBiz{
 		//aiClient: aiClient,
@@ -59,6 +63,7 @@ func NewAiRouterBiz(
 		taskImpl:    taskImpl,
 		utilAgent:   utilAgent,
 		channelPool: channelPool,
+		ollama:      ollama,
 	}
 }
 
@@ -68,30 +73,52 @@ func (r *AiRouterBiz) Route(ctx context.Context, req *entitys.ChatRequest) (*ent
 	return nil, nil
 }
 
-// Route 执行智能路由
 func (r *AiRouterBiz) RouteWithSocket(c *websocket.Conn, req *entitys.ChatSockRequest) (err error) {
-	ch, err := r.channelPool.Get()
-	if err != nil {
-		return err
-	}
+
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+	//ch := r.channelPool.Get()
+	ch := make(chan entitys.ResponseData)
+	done := make(chan struct{})
+
+	go func() {
+		defer close(done)
+		for {
+			select {
+			case v, ok := <-ch:
+				if !ok {
+					return
+				}
+				// send with a timeout so a stuck client cannot block us forever
+				if err := sendWithTimeout(c, v, 2*time.Second); err != nil {
+					log.Errorf("Send error: %v", err)
+					cancel() // tell the main flow to stop
+					return
+				}
+			case <-ctx.Done():
+				return
+			}
+		}
+	}()
+
 	defer func() {
+		close(ch); <-done // FIX: stop the forwarder first so the final frames below cannot race its websocket writes
 		if err != nil {
-			entitys.MsgSend(c, entitys.ResponseData{
+			_ = entitys.MsgSend(c, entitys.ResponseData{
 				Done:    false,
 				Content: err.Error(),
 				Type:    entitys.ResponseErr,
 			})
 		}
-		entitys.MsgSend(c, entitys.ResponseData{
+		_ = entitys.MsgSend(c, entitys.ResponseData{
 			Done:    true,
 			Content: "",
 			Type:    entitys.ResponseEnd,
 		})
-		err = r.channelPool.Put(ch)
-		if err != nil {
-			close(ch)
-		}
+		//r.channelPool.Put(ch)
+		// ch was already closed at the top of this deferred func (closing it again would panic)
 	}()
+
 	session := c.Headers("X-Session", "")
 	if len(session) == 0 {
 		return errors.SessionNotFound
@@ -119,46 +146,94 @@ func (r *AiRouterBiz) RouteWithSocket(c *websocket.Conn, req *entitys.ChatSockRe
 	if err != nil {
 		return errors.SystemError
 	}
-	//意图预测
+
 	prompt := r.getPromptLLM(sysInfo, history, req.Text, task)
-	match, err := r.utilAgent.Llm.GenerateContent(context.TODO(), prompt,
-		llms.WithJSONMode(),
-	)
-	if err != nil {
-		return errors.SystemError
-	}
-	log.Info(match.Choices[0].Content)
+
+	AgentClient := r.utilAgent.Get()
+
 	ch <- entitys.ResponseData{
 		Done:    false,
-		Content: match.Choices[0].Content,
+		Content: "准备意图识别",
 		Type:    entitys.ResponseLog,
 	}
-	var matchJson entitys.Match
-	err = json.Unmarshal([]byte(match.Choices[0].Content), &matchJson)
+
+	match, err := AgentClient.Llm.GenerateContent(ctx, prompt, llms.WithJSONMode())
+	r.utilAgent.Put(AgentClient)
+	// FIX: err must be checked before touching match — match is nil on error
+	// and match.Choices[0] would panic.
+	if err != nil {
+		log.Errorf("LLM error: %v", err)
+		return errors.SystemError
+	}
+	resMsg := match.Choices[0].Content
+	// NOTE(review): resMsg used to be sent here with a plain `ch <-` send and
+	// then again via the ctx-aware select below, so the client received the
+	// same log frame twice. The plain send is dropped: besides the duplicate,
+	// a bare `ch <-` blocks forever once the forwarder goroutine has exited,
+	// while the select below honors ctx.Done().
+	ch <- entitys.ResponseData{
+		Done:    false,
+		Content: "意图识别结束",
+		Type:    entitys.ResponseLog,
+	}
+	//for i := 1; i < 10; i++ {
+	//	ch <- entitys.ResponseData{
+	//		Done:    false,
+	//		Content: fmt.Sprintf("%d", i),
+	//		Type:    entitys.ResponseLog,
+	//	}
+	//	time.Sleep(1 * time.Second)
+	//}
+	//return
+
 	if err != nil {
+		log.Errorf("LLM error: %v", err)
 		return errors.SystemError
 	}
-	go func() {
-		defer func() {
-			if r := recover(); r != nil {
-				err = fmt.Errorf("recovered from panic: %v", r)
-			}
-		}()
-		defer close(ch)
-		err = r.handleMatch(c, ch, &matchJson, task, sysInfo)
-		if err != nil {
-			return
-		}
-	}()
-	for v := range ch {
-		if err := entitys.MsgSend(c, v); err != nil {
-			return err
-		}
+	//msg, err := r.ollama.ToolSelect(ctx, r.getPromptOllama(sysInfo, history, req.Text), []api.Tool{})
+	//if err != nil {
+	//	return
+	//}
+	//resMsg := msg.Message.Content
+	select {
+	case ch <- entitys.ResponseData{
+		Done:    false,
+		Content: resMsg,
+		Type:    entitys.ResponseLog,
+	}:
+	case <-ctx.Done():
+		return ctx.Err()
 	}
-	return
+
+	var matchJson entitys.Match
+	if err := json.Unmarshal([]byte(resMsg), &matchJson); err != nil {
+		return errors.SystemError
+	}
+
+	if err := r.handleMatch(ctx, c, ch, &matchJson, task, sysInfo); err != nil {
+		return err
+	}
+
+	return nil
 }
 
+// sendWithTimeout writes one frame to the websocket, giving up after timeout.
+func sendWithTimeout(c *websocket.Conn, data entitys.ResponseData, timeout time.Duration) error {
+	sendCtx, cancel := context.WithTimeout(context.Background(), timeout)
+	defer cancel()
+
+	done := make(chan error, 1)
+	go func() {
+		done <- entitys.MsgSend(c, data)
+	}()
+
+	select {
+	case err := <-done:
+		return err
+	case <-sendCtx.Done():
+		return sendCtx.Err()
+	}
+}
 func (r *AiRouterBiz) handleOtherTask(c *websocket.Conn, ch chan entitys.ResponseData, matchJson *entitys.Match) (err error) {
 	ch <- entitys.ResponseData{
 		Done:    false,
@@ -169,7 +244,7 @@ func (r *AiRouterBiz) handleOtherTask(c *websocket.Conn, ch chan entitys.Respons
 	return
 }
 
-func (r *AiRouterBiz) handleMatch(c *websocket.Conn, ch chan entitys.ResponseData, matchJson *entitys.Match, tasks []model.AiTask, sysInfo model.AiSy) (err error) {
+func (r *AiRouterBiz) handleMatch(ctx context.Context, c *websocket.Conn, ch chan entitys.ResponseData, matchJson *entitys.Match, tasks []model.AiTask, sysInfo model.AiSy) (err error) {
 
 	if !matchJson.IsMatch {
 		ch <- entitys.ResponseData{
@@ -404,6 +479,23 @@ func (r *AiRouterBiz) getPrompt(sysInfo model.AiSy, history []model.AiChatHi, re
 	return prompt
 }
 
+func (r *AiRouterBiz) getPromptOllama(sysInfo model.AiSy, history []model.AiChatHi, reqInput string) []api.Message {
+	var (
+		prompt = make([]api.Message, 0)
+	)
+	prompt = append(prompt, api.Message{
+		Role:    "system",
+		Content: r.buildSystemPrompt(sysInfo.SysPrompt),
+	}, api.Message{
+		Role:    "assistant",
+		Content: pkg.JsonStringIgonErr(r.buildAssistant(history)),
+	}, api.Message{
+		Role:    "user",
+		Content: reqInput,
+	})
+	return prompt
+}
+
 func (r *AiRouterBiz) getPromptLLM(sysInfo model.AiSy, history []model.AiChatHi, reqInput string, tasks []model.AiTask) []llms.MessageContent {
 	var (
 		prompt = make([]llms.MessageContent, 0)
diff --git a/internal/config/config.go b/internal/config/config.go
index 6a0bcc4..2e80b18 100644
--- a/internal/config/config.go
+++ b/internal/config/config.go
@@ -28,6 +28,7 @@ type SysConfig struct {
 	SessionLen      int `mapstructure:"session_len"`
 	ChannelPoolLen  int `mapstructure:"channel_pool_len"`
 	ChannelPoolSize int `mapstructure:"channel_pool_size"`
+	LlmPoolLen      int `mapstructure:"llm_pool_len"`
 }
 
 // ServerConfig 服务器配置
diff --git a/internal/pkg/channel_pool.go b/internal/pkg/channel_pool.go
index 5fa2137..eda85fa 100644
--- a/internal/pkg/channel_pool.go
+++ b/internal/pkg/channel_pool.go
@@ -3,7 +3,6 @@ package pkg
 import (
 	"ai_scheduler/internal/config"
 	"ai_scheduler/internal/entitys"
-	"errors"
 	"sync"
 )
 
@@ -25,29 +24,29 @@ func NewSafeChannelPool(c *config.Config) (*SafeChannelPool, func()) {
 }
 
 // 从池中获取 channel(若无空闲则创建新 channel)
-func (p *SafeChannelPool) Get() (chan entitys.ResponseData, error) {
+func (p *SafeChannelPool) Get() chan entitys.ResponseData {
 	p.mu.Lock()
 	defer p.mu.Unlock()
 
 	if p.closed {
-		return nil, errors.New("pool is closed")
+		return make(chan entitys.ResponseData, p.bufSize)
	}
 
 	select {
 	case ch := <-p.pool: // 从池中取
-		return ch, nil
+		return ch
 	default: // 池为空,创建新 channel
-		return make(chan entitys.ResponseData, p.bufSize), nil
+		return make(chan entitys.ResponseData, p.bufSize)
 	}
 }
 
 // 将 channel 放回池中(必须确保 channel 已清空!)
-func (p *SafeChannelPool) Put(ch chan entitys.ResponseData) error {
+func (p *SafeChannelPool) Put(ch chan entitys.ResponseData) {
 	p.mu.Lock()
 	defer p.mu.Unlock()
 
 	if p.closed {
-		return errors.New("pool is closed")
+		return
 	}
 
 	// 清空 channel(防止复用时读取旧数据)
@@ -62,7 +61,7 @@ func (p *SafeChannelPool) Put(ch chan entitys.ResponseData) {
 	default: // 池已满,直接关闭 channel(避免泄漏)
 		close(ch)
 	}
-	return nil
+	return
 }
 
 // 关闭池(释放所有资源)
diff --git a/internal/pkg/func.go b/internal/pkg/func.go
index a31a818..20d9d7c 100644
--- a/internal/pkg/func.go
+++ b/internal/pkg/func.go
@@ -1,8 +1,23 @@
 package pkg
 
-import "encoding/json"
+import (
+	"ai_scheduler/internal/entitys"
+	"encoding/json"
+)
 
 func JsonStringIgonErr(data interface{}) string {
 	dataByte, _ := json.Marshal(data)
 	return string(dataByte)
 }
+
+// IsChannelClosed reports whether ch is closed. WARNING(review): this check is
+// destructive and unreliable — if the channel is open with a buffered value,
+// that value is consumed and discarded, and an open-but-empty channel is
+// indistinguishable from one that is simply not closed yet. Prefer the
+// `v, ok := <-ch` idiom at the receive site instead of calling this helper.
+func IsChannelClosed(ch chan entitys.ResponseData) bool {
+	select {
+	case _, ok := <-ch: // receives (and discards!) a value if one is buffered
+		return !ok // ok == false means the channel is closed
+	default: // nothing readable right now — says nothing about closedness
+		return false
+	}
+}
diff --git a/internal/pkg/utils_ollama/client.go b/internal/pkg/utils_ollama/client.go
index 390caa8..9d9ac82 100644
--- a/internal/pkg/utils_ollama/client.go
+++ b/internal/pkg/utils_ollama/client.go
@@ -46,7 +46,7 @@ func (c *Client) ToolSelect(ctx context.Context, messages []api.Message, tools [
 		Messages: messages,
 		Stream:   new(bool), // 设置为false,不使用流式响应
 		Think:    &api.ThinkValue{Value: true},
-		Tools:    tools,
+		//Tools:    tools, // NOTE(review): disabled — the `tools` parameter is now dead; confirm intent before removing it from the signature
 	}
 
 	err = c.client.Chat(ctx, req, func(resp api.ChatResponse) error {
diff --git a/internal/pkg/utils_ollama/ollama.go b/internal/pkg/utils_ollama/ollama.go
index 65388a3..6121838 100644
--- a/internal/pkg/utils_ollama/ollama.go
+++ b/internal/pkg/utils_ollama/ollama.go
@@ -2,40 +2,112 @@ package utils_ollama
 
 import (
 	"ai_scheduler/internal/config"
+	"math/rand"
+	"net"
 	"net/http"
 	"os"
+	"sync"
+	"time"
 
 	"github.com/gofiber/fiber/v2/log"
 	"github.com/tmc/langchaingo/llms/ollama"
 )
 
 type UtilOllama struct {
-	Llm *ollama.LLM
+	LlmClientPool *sync.Pool
+	poolSize      int // configured warm-up size, for debugging/monitoring
+	model         string
+	serverURL     string
+	c             *config.Config
+}
+
+type LlmObj struct {
+	Number string
+	Llm    *ollama.LLM
 }
 
 func NewUtilOllama(c *config.Config, logger log.AllLogger) *UtilOllama {
-	llm, err := ollama.New(
-		ollama.WithModel(c.Ollama.Model),
-		ollama.WithHTTPClient(http.DefaultClient),
-		ollama.WithServerURL(getUrl(c)),
-		ollama.WithKeepAlive("-1s"),
-	)
-	if err != nil {
-		logger.Fatal(err)
-		panic(err)
+	poolSize := c.Sys.LlmPoolLen
+	if poolSize <= 0 {
+		poolSize = 10 // default
+		logger.Warnf("LlmPoolLen not set, using default: %d", poolSize)
+	}
+
+	// Initialize the pool.
+	pool := &sync.Pool{
+		New: func() interface{} {
+			llm, err := ollama.New(
+				ollama.WithModel(c.Ollama.Model),
+				ollama.WithHTTPClient(http.DefaultClient),
+				ollama.WithServerURL(getUrl(c)),
+				ollama.WithKeepAlive("-1s"),
+			)
+			if err != nil {
+				logger.Fatalf("Failed to create Ollama client: %v", err)
+				panic(err)
+			}
+			number := randStr(5)
+			log.Info(number)
+			return &LlmObj{
+				Number: number,
+				Llm:    llm,
+			}
+		},
+	}
+
+	// Pre-fill. NOTE(review): sync.Pool may drop entries at any GC — this is a best-effort warm-up, not a fixed-size pool.
+	for i := 0; i < poolSize; i++ {
+		pool.Put(pool.New())
 	}
 
 	return &UtilOllama{
-		Llm: llm,
+		LlmClientPool: pool,
+		poolSize:      poolSize,
+		c:             c, // FIX: NewClient dereferences o.c — it was never assigned, causing a nil-pointer panic
+		model:         c.Ollama.Model, serverURL: getUrl(c),
 	}
+
 }
 
-//func (o *UtilOllama) a() {
-//	var agent agents.Agent
-//	agent = agents.NewOneShotAgent(llm, tools, opts...)
-//
-//	agents.NewExecutor()
-//}
+func (o *UtilOllama) NewClient() *ollama.LLM {
+	llm, _ := ollama.New(
+		ollama.WithModel(o.c.Ollama.Model),
+		ollama.WithHTTPClient(&http.Client{
+			Transport: &http.Transport{
+				MaxIdleConns:        100,              // max idle connections (default 2 is too small)
+				MaxIdleConnsPerHost: 100,              // max idle connections per host (default 2)
+				IdleConnTimeout:     90 * time.Second, // idle connection timeout
+				DialContext: (&net.Dialer{
+					Timeout:   30 * time.Second, // dial timeout
+					KeepAlive: 30 * time.Second, // TCP keep-alive
+				}).DialContext,
+			},
+			Timeout: 60 * time.Second, // overall request timeout (avoid waiting forever)
+		}),
+		ollama.WithServerURL(getUrl(o.c)),
+		ollama.WithKeepAlive("-1s"),
+	)
+	return llm
+}
+
+// Get returns a usable LLM client from the pool.
+func (o *UtilOllama) Get() *LlmObj {
+	client := o.LlmClientPool.Get().(*LlmObj)
+	return client
+}
+
+// Put returns a client to the pool.
+func (o *UtilOllama) Put(llm *LlmObj) {
+	if llm == nil {
+		return
+	}
+	o.LlmClientPool.Put(llm)
+}
+
+// Stats reports pool statistics. NOTE(review): sync.Pool has no size accessor; both values are the configured warm-up size, not live counts.
+func (o *UtilOllama) Stats() (current, max int) {
+	return o.poolSize, o.poolSize
+}
 
 func getUrl(c *config.Config) string {
 	baseURL := c.Ollama.BaseURL
@@ -45,3 +117,13 @@ func getUrl(c *config.Config) string {
 	}
 	return baseURL
 }
+
+var letters = []rune("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ")
+
+func randStr(n int) string {
+	b := make([]rune, n)
+	for i := range b {
+		b[i] = letters[rand.Intn(len(letters))]
+	}
+	return string(b)
+}