结构修改

This commit is contained in:
renzhiyuan 2025-09-23 21:17:44 +08:00
parent aa86882f7d
commit 5523a8f78a
7 changed files with 258 additions and 68 deletions

View File

@ -9,12 +9,13 @@ ollama:
timeout: "120s" timeout: "120s"
level: "info" level: "info"
format: "json" format: "json"
sys: sys:
session_len: 3 session_len: 3
channel_pool_len: 100 channel_pool_len: 100
channel_pool_size: 32 channel_pool_size: 32
llm_pool_len: 5
redis: redis:
host: 47.97.27.195:6379 host: 47.97.27.195:6379
type: node type: node

View File

@ -16,10 +16,12 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"strings" "strings"
"time"
"gitea.cdlsxd.cn/self-tools/l_request" "gitea.cdlsxd.cn/self-tools/l_request"
"github.com/gofiber/fiber/v2/log" "github.com/gofiber/fiber/v2/log"
"github.com/gofiber/websocket/v2" "github.com/gofiber/websocket/v2"
"github.com/ollama/ollama/api"
"github.com/tmc/langchaingo/llms" "github.com/tmc/langchaingo/llms"
"xorm.io/builder" "xorm.io/builder"
) )
@ -34,6 +36,7 @@ type AiRouterBiz struct {
hisImpl *impl.ChatImpl hisImpl *impl.ChatImpl
conf *config.Config conf *config.Config
utilAgent *utils_ollama.UtilOllama utilAgent *utils_ollama.UtilOllama
ollama *utils_ollama.Client
channelPool *pkg.SafeChannelPool channelPool *pkg.SafeChannelPool
} }
@ -48,6 +51,7 @@ func NewAiRouterBiz(
conf *config.Config, conf *config.Config,
utilAgent *utils_ollama.UtilOllama, utilAgent *utils_ollama.UtilOllama,
channelPool *pkg.SafeChannelPool, channelPool *pkg.SafeChannelPool,
ollama *utils_ollama.Client,
) *AiRouterBiz { ) *AiRouterBiz {
return &AiRouterBiz{ return &AiRouterBiz{
//aiClient: aiClient, //aiClient: aiClient,
@ -59,6 +63,7 @@ func NewAiRouterBiz(
taskImpl: taskImpl, taskImpl: taskImpl,
utilAgent: utilAgent, utilAgent: utilAgent,
channelPool: channelPool, channelPool: channelPool,
ollama: ollama,
} }
} }
@ -68,30 +73,52 @@ func (r *AiRouterBiz) Route(ctx context.Context, req *entitys.ChatRequest) (*ent
return nil, nil return nil, nil
} }
// Route 执行智能路由
func (r *AiRouterBiz) RouteWithSocket(c *websocket.Conn, req *entitys.ChatSockRequest) (err error) { func (r *AiRouterBiz) RouteWithSocket(c *websocket.Conn, req *entitys.ChatSockRequest) (err error) {
ch, err := r.channelPool.Get()
if err != nil { ctx, cancel := context.WithCancel(context.Background())
return err defer cancel()
} //ch := r.channelPool.Get()
ch := make(chan entitys.ResponseData)
done := make(chan struct{})
go func() {
defer close(done)
for {
select {
case v, ok := <-ch:
if !ok {
return
}
// 带超时的发送,避免阻塞
if err := sendWithTimeout(c, v, 2*time.Second); err != nil {
log.Errorf("Send error: %v", err)
cancel() // 通知主流程退出
return
}
case <-ctx.Done():
return
}
}
}()
defer func() { defer func() {
if err != nil { if err != nil {
entitys.MsgSend(c, entitys.ResponseData{ _ = entitys.MsgSend(c, entitys.ResponseData{
Done: false, Done: false,
Content: err.Error(), Content: err.Error(),
Type: entitys.ResponseErr, Type: entitys.ResponseErr,
}) })
} }
entitys.MsgSend(c, entitys.ResponseData{ _ = entitys.MsgSend(c, entitys.ResponseData{
Done: true, Done: true,
Content: "", Content: "",
Type: entitys.ResponseEnd, Type: entitys.ResponseEnd,
}) })
err = r.channelPool.Put(ch) //r.channelPool.Put(ch)
if err != nil { close(ch)
close(ch)
}
}() }()
session := c.Headers("X-Session", "") session := c.Headers("X-Session", "")
if len(session) == 0 { if len(session) == 0 {
return errors.SessionNotFound return errors.SessionNotFound
@ -119,46 +146,94 @@ func (r *AiRouterBiz) RouteWithSocket(c *websocket.Conn, req *entitys.ChatSockRe
if err != nil { if err != nil {
return errors.SystemError return errors.SystemError
} }
//意图预测
prompt := r.getPromptLLM(sysInfo, history, req.Text, task) prompt := r.getPromptLLM(sysInfo, history, req.Text, task)
match, err := r.utilAgent.Llm.GenerateContent(context.TODO(), prompt,
llms.WithJSONMode(), AgentClient := r.utilAgent.Get()
)
if err != nil {
return errors.SystemError
}
log.Info(match.Choices[0].Content)
ch <- entitys.ResponseData{ ch <- entitys.ResponseData{
Done: false, Done: false,
Content: match.Choices[0].Content, Content: "准备意图识别",
Type: entitys.ResponseLog, Type: entitys.ResponseLog,
} }
var matchJson entitys.Match
err = json.Unmarshal([]byte(match.Choices[0].Content), &matchJson) match, err := AgentClient.Llm.GenerateContent(
ctx, // 使用可取消的上下文
prompt,
llms.WithJSONMode(),
)
resMsg := match.Choices[0].Content
r.utilAgent.Put(AgentClient)
ch <- entitys.ResponseData{
Done: false,
Content: resMsg,
Type: entitys.ResponseLog,
}
ch <- entitys.ResponseData{
Done: false,
Content: "意图识别结束",
Type: entitys.ResponseLog,
}
//for i := 1; i < 10; i++ {
// ch <- entitys.ResponseData{
// Done: false,
// Content: fmt.Sprintf("%d", i),
// Type: entitys.ResponseLog,
// }
// time.Sleep(1 * time.Second)
//}
//return
if err != nil { if err != nil {
log.Errorf("LLM error: %v", err)
return errors.SystemError return errors.SystemError
} }
go func() {
defer func() {
if r := recover(); r != nil {
err = fmt.Errorf("recovered from panic: %v", r)
}
}()
defer close(ch)
err = r.handleMatch(c, ch, &matchJson, task, sysInfo)
if err != nil {
return
}
}()
for v := range ch { //msg, err := r.ollama.ToolSelect(ctx, r.getPromptOllama(sysInfo, history, req.Text), []api.Tool{})
if err := entitys.MsgSend(c, v); err != nil { //if err != nil {
return err // return
} //}
//resMsg := msg.Message.Content
select {
case ch <- entitys.ResponseData{
Done: false,
Content: resMsg,
Type: entitys.ResponseLog,
}:
case <-ctx.Done():
return ctx.Err()
} }
return
var matchJson entitys.Match
if err := json.Unmarshal([]byte(resMsg), &matchJson); err != nil {
return errors.SystemError
}
if err := r.handleMatch(ctx, c, ch, &matchJson, task, sysInfo); err != nil {
return err
}
return nil
} }
// sendWithTimeout writes a single ResponseData frame to the WebSocket,
// giving up after the supplied timeout instead of blocking the caller.
//
// The previous implementation raced a per-send goroutine against a context
// timer. When the timer won, the abandoned goroutine kept running and could
// call entitys.MsgSend concurrently with the caller's next send — WebSocket
// connections support at most one concurrent writer, so that is a data race.
// A write deadline on the connection enforces the same timeout without
// spawning anything.
func sendWithTimeout(c *websocket.Conn, data entitys.ResponseData, timeout time.Duration) error {
	// Arm the deadline for this one write; the underlying write fails with
	// a timeout error once it elapses.
	// NOTE(review): assumes entitys.MsgSend writes through c — confirm.
	if err := c.SetWriteDeadline(time.Now().Add(timeout)); err != nil {
		return err
	}
	// Clear the deadline afterwards so later writes are unaffected.
	defer c.SetWriteDeadline(time.Time{}) //nolint:errcheck // best-effort reset
	return entitys.MsgSend(c, data)
}
func (r *AiRouterBiz) handleOtherTask(c *websocket.Conn, ch chan entitys.ResponseData, matchJson *entitys.Match) (err error) { func (r *AiRouterBiz) handleOtherTask(c *websocket.Conn, ch chan entitys.ResponseData, matchJson *entitys.Match) (err error) {
ch <- entitys.ResponseData{ ch <- entitys.ResponseData{
Done: false, Done: false,
@ -169,7 +244,7 @@ func (r *AiRouterBiz) handleOtherTask(c *websocket.Conn, ch chan entitys.Respons
return return
} }
func (r *AiRouterBiz) handleMatch(c *websocket.Conn, ch chan entitys.ResponseData, matchJson *entitys.Match, tasks []model.AiTask, sysInfo model.AiSy) (err error) { func (r *AiRouterBiz) handleMatch(ctx context.Context, c *websocket.Conn, ch chan entitys.ResponseData, matchJson *entitys.Match, tasks []model.AiTask, sysInfo model.AiSy) (err error) {
if !matchJson.IsMatch { if !matchJson.IsMatch {
ch <- entitys.ResponseData{ ch <- entitys.ResponseData{
@ -404,6 +479,23 @@ func (r *AiRouterBiz) getPrompt(sysInfo model.AiSy, history []model.AiChatHi, re
return prompt return prompt
} }
// getPromptOllama assembles the three-part chat prompt in the native Ollama
// API message format: the system directive, the conversation history encoded
// as assistant context, and finally the user's current input.
func (r *AiRouterBiz) getPromptOllama(sysInfo model.AiSy, history []model.AiChatHi, reqInput string) []api.Message {
	systemMsg := api.Message{
		Role:    "system",
		Content: r.buildSystemPrompt(sysInfo.SysPrompt),
	}
	historyMsg := api.Message{
		Role:    "assistant",
		Content: pkg.JsonStringIgonErr(r.buildAssistant(history)),
	}
	userMsg := api.Message{
		Role:    "user",
		Content: reqInput,
	}
	return []api.Message{systemMsg, historyMsg, userMsg}
}
func (r *AiRouterBiz) getPromptLLM(sysInfo model.AiSy, history []model.AiChatHi, reqInput string, tasks []model.AiTask) []llms.MessageContent { func (r *AiRouterBiz) getPromptLLM(sysInfo model.AiSy, history []model.AiChatHi, reqInput string, tasks []model.AiTask) []llms.MessageContent {
var ( var (
prompt = make([]llms.MessageContent, 0) prompt = make([]llms.MessageContent, 0)

View File

@ -28,6 +28,7 @@ type SysConfig struct {
SessionLen int `mapstructure:"session_len"` SessionLen int `mapstructure:"session_len"`
ChannelPoolLen int `mapstructure:"channel_pool_len"` ChannelPoolLen int `mapstructure:"channel_pool_len"`
ChannelPoolSize int `mapstructure:"channel_pool_size"` ChannelPoolSize int `mapstructure:"channel_pool_size"`
LlmPoolLen int `mapstructure:"llm_pool_len"`
} }
// ServerConfig 服务器配置 // ServerConfig 服务器配置

View File

@ -3,7 +3,6 @@ package pkg
import ( import (
"ai_scheduler/internal/config" "ai_scheduler/internal/config"
"ai_scheduler/internal/entitys" "ai_scheduler/internal/entitys"
"errors"
"sync" "sync"
) )
@ -25,29 +24,29 @@ func NewSafeChannelPool(c *config.Config) (*SafeChannelPool, func()) {
} }
// 从池中获取 channel若无空闲则创建新 channel // 从池中获取 channel若无空闲则创建新 channel
func (p *SafeChannelPool) Get() (chan entitys.ResponseData, error) { func (p *SafeChannelPool) Get() chan entitys.ResponseData {
p.mu.Lock() p.mu.Lock()
defer p.mu.Unlock() defer p.mu.Unlock()
if p.closed { if p.closed {
return nil, errors.New("pool is closed") return make(chan entitys.ResponseData, p.bufSize)
} }
select { select {
case ch := <-p.pool: // 从池中取 case ch := <-p.pool: // 从池中取
return ch, nil return ch
default: // 池为空,创建新 channel default: // 池为空,创建新 channel
return make(chan entitys.ResponseData, p.bufSize), nil return make(chan entitys.ResponseData, p.bufSize)
} }
} }
// 将 channel 放回池中(必须确保 channel 已清空!) // 将 channel 放回池中(必须确保 channel 已清空!)
func (p *SafeChannelPool) Put(ch chan entitys.ResponseData) error { func (p *SafeChannelPool) Put(ch chan entitys.ResponseData) {
p.mu.Lock() p.mu.Lock()
defer p.mu.Unlock() defer p.mu.Unlock()
if p.closed { if p.closed {
return errors.New("pool is closed") return
} }
// 清空 channel防止复用时读取旧数据 // 清空 channel防止复用时读取旧数据
@ -62,7 +61,7 @@ func (p *SafeChannelPool) Put(ch chan entitys.ResponseData) error {
default: // 池已满,直接关闭 channel避免泄漏 default: // 池已满,直接关闭 channel避免泄漏
close(ch) close(ch)
} }
return nil return
} }
// 关闭池(释放所有资源) // 关闭池(释放所有资源)

View File

@ -1,8 +1,23 @@
package pkg package pkg
import "encoding/json" import (
"ai_scheduler/internal/entitys"
"encoding/json"
)
func JsonStringIgonErr(data interface{}) string { func JsonStringIgonErr(data interface{}) string {
dataByte, _ := json.Marshal(data) dataByte, _ := json.Marshal(data)
return string(dataByte) return string(dataByte)
} }
// IsChannelClosed reports whether ch has been closed.
//
// WARNING(review): this check is destructive and only best-effort. The
// receive in the select CONSUMES one value when the channel is open and has
// buffered data — that ResponseData is silently lost, and the function still
// returns false. An open-but-empty channel also returns false, so a false
// result never proves anything. Go offers no non-destructive closedness
// test; prefer restructuring so the receiver observes closure through the
// `v, ok := <-ch` it already performs.
func IsChannelClosed(ch chan entitys.ResponseData) bool {
	select {
	case _, ok := <-ch: // try a receive; succeeds immediately when closed (ok=false)
		return !ok // ok == false means the channel was closed
	default: // nothing receivable right now — channel may still be open
		return false // treated as "not closed" (data may simply be pending)
	}
}

View File

@ -46,7 +46,7 @@ func (c *Client) ToolSelect(ctx context.Context, messages []api.Message, tools [
Messages: messages, Messages: messages,
Stream: new(bool), // 设置为false不使用流式响应 Stream: new(bool), // 设置为false不使用流式响应
Think: &api.ThinkValue{Value: true}, Think: &api.ThinkValue{Value: true},
Tools: tools, //Tools: tools,
} }
err = c.client.Chat(ctx, req, func(resp api.ChatResponse) error { err = c.client.Chat(ctx, req, func(resp api.ChatResponse) error {

View File

@ -2,40 +2,112 @@ package utils_ollama
import ( import (
"ai_scheduler/internal/config" "ai_scheduler/internal/config"
"math/rand"
"net"
"net/http" "net/http"
"os" "os"
"sync"
"time"
"github.com/gofiber/fiber/v2/log" "github.com/gofiber/fiber/v2/log"
"github.com/tmc/langchaingo/llms/ollama" "github.com/tmc/langchaingo/llms/ollama"
) )
type UtilOllama struct { type UtilOllama struct {
Llm *ollama.LLM LlmClientPool *sync.Pool
poolSize int // 记录池大小,用于调试
model string
serverURL string
c *config.Config
}
type LlmObj struct {
Number string
Llm *ollama.LLM
} }
func NewUtilOllama(c *config.Config, logger log.AllLogger) *UtilOllama { func NewUtilOllama(c *config.Config, logger log.AllLogger) *UtilOllama {
llm, err := ollama.New( poolSize := c.Sys.LlmPoolLen
ollama.WithModel(c.Ollama.Model), if poolSize <= 0 {
ollama.WithHTTPClient(http.DefaultClient), poolSize = 10 // 默认值
ollama.WithServerURL(getUrl(c)), logger.Warnf("LlmPoolLen not set, using default: %d", poolSize)
ollama.WithKeepAlive("-1s"), }
)
if err != nil { // 初始化 Pool
logger.Fatal(err) pool := &sync.Pool{
panic(err) New: func() interface{} {
llm, err := ollama.New(
ollama.WithModel(c.Ollama.Model),
ollama.WithHTTPClient(http.DefaultClient),
ollama.WithServerURL(getUrl(c)),
ollama.WithKeepAlive("-1s"),
)
if err != nil {
logger.Fatalf("Failed to create Ollama client: %v", err)
panic(err) // 或者返回 nil + 错误处理
}
number := randStr(5)
log.Info(number)
return &LlmObj{
Number: number,
Llm: llm,
}
},
}
// 预填充 Pool
for i := 0; i < poolSize; i++ {
pool.Put(pool.New())
} }
return &UtilOllama{ return &UtilOllama{
Llm: llm, LlmClientPool: pool,
poolSize: poolSize,
model: c.Ollama.Model,
serverURL: getUrl(c),
} }
} }
//func (o *UtilOllama) a() { func (o *UtilOllama) NewClient() *ollama.LLM {
// var agent agents.Agent llm, _ := ollama.New(
// agent = agents.NewOneShotAgent(llm, tools, opts...) ollama.WithModel(o.c.Ollama.Model),
// ollama.WithHTTPClient(&http.Client{
// agents.NewExecutor() Transport: &http.Transport{
//} MaxIdleConns: 100, // 最大空闲连接数(默认 2太小
MaxIdleConnsPerHost: 100, // 每个 Host 的最大空闲连接数(默认 2
IdleConnTimeout: 90 * time.Second, // 空闲连接超时时间
DialContext: (&net.Dialer{
Timeout: 30 * time.Second, // 连接超时
KeepAlive: 30 * time.Second, // TCP Keep-Alive
}).DialContext,
},
Timeout: 60 * time.Second, // 整体请求超时(避免无限等待)
}),
ollama.WithServerURL(getUrl(o.c)),
ollama.WithKeepAlive("-1s"),
)
return llm
}
// Get borrows an LLM client wrapper from the pool. When no idle entry is
// available, sync.Pool's New hook builds a fresh one, so Get never blocks.
func (o *UtilOllama) Get() *LlmObj {
	return o.LlmClientPool.Get().(*LlmObj)
}
// Put hands a client wrapper back to the pool for reuse; a nil wrapper is
// ignored so callers may Put unconditionally.
func (o *UtilOllama) Put(llm *LlmObj) {
	if llm != nil {
		o.LlmClientPool.Put(llm)
	}
}
// Stats returns the pool's size figures (for monitoring).
//
// NOTE(review): sync.Pool does not expose a live element count, so both
// return values are simply the configured pool size; "current" does NOT
// reflect how many idle clients actually sit in the pool — confirm callers
// do not rely on that.
func (o *UtilOllama) Stats() (current, max int) {
	return o.poolSize, o.poolSize
}
func getUrl(c *config.Config) string { func getUrl(c *config.Config) string {
baseURL := c.Ollama.BaseURL baseURL := c.Ollama.BaseURL
@ -45,3 +117,13 @@ func getUrl(c *config.Config) string {
} }
return baseURL return baseURL
} }
// letters is the alphabet randStr draws from: lower- then upper-case ASCII.
var letters = []rune("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ")

// randStr returns a pseudo-random identifier of n letters, using the
// package-level math/rand source.
func randStr(n int) string {
	out := make([]rune, 0, n)
	for i := 0; i < n; i++ {
		out = append(out, letters[rand.Intn(len(letters))])
	}
	return string(out)
}