结构修改
This commit is contained in:
parent
aa86882f7d
commit
5523a8f78a
|
@ -9,12 +9,13 @@ ollama:
|
||||||
timeout: "120s"
|
timeout: "120s"
|
||||||
level: "info"
|
level: "info"
|
||||||
format: "json"
|
format: "json"
|
||||||
|
|
||||||
|
|
||||||
sys:
|
sys:
|
||||||
session_len: 3
|
session_len: 3
|
||||||
channel_pool_len: 100
|
channel_pool_len: 100
|
||||||
channel_pool_size: 32
|
channel_pool_size: 32
|
||||||
|
llm_pool_len: 5
|
||||||
redis:
|
redis:
|
||||||
host: 47.97.27.195:6379
|
host: 47.97.27.195:6379
|
||||||
type: node
|
type: node
|
||||||
|
|
|
@ -16,10 +16,12 @@ import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
"gitea.cdlsxd.cn/self-tools/l_request"
|
"gitea.cdlsxd.cn/self-tools/l_request"
|
||||||
"github.com/gofiber/fiber/v2/log"
|
"github.com/gofiber/fiber/v2/log"
|
||||||
"github.com/gofiber/websocket/v2"
|
"github.com/gofiber/websocket/v2"
|
||||||
|
"github.com/ollama/ollama/api"
|
||||||
"github.com/tmc/langchaingo/llms"
|
"github.com/tmc/langchaingo/llms"
|
||||||
"xorm.io/builder"
|
"xorm.io/builder"
|
||||||
)
|
)
|
||||||
|
@ -34,6 +36,7 @@ type AiRouterBiz struct {
|
||||||
hisImpl *impl.ChatImpl
|
hisImpl *impl.ChatImpl
|
||||||
conf *config.Config
|
conf *config.Config
|
||||||
utilAgent *utils_ollama.UtilOllama
|
utilAgent *utils_ollama.UtilOllama
|
||||||
|
ollama *utils_ollama.Client
|
||||||
channelPool *pkg.SafeChannelPool
|
channelPool *pkg.SafeChannelPool
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -48,6 +51,7 @@ func NewAiRouterBiz(
|
||||||
conf *config.Config,
|
conf *config.Config,
|
||||||
utilAgent *utils_ollama.UtilOllama,
|
utilAgent *utils_ollama.UtilOllama,
|
||||||
channelPool *pkg.SafeChannelPool,
|
channelPool *pkg.SafeChannelPool,
|
||||||
|
ollama *utils_ollama.Client,
|
||||||
) *AiRouterBiz {
|
) *AiRouterBiz {
|
||||||
return &AiRouterBiz{
|
return &AiRouterBiz{
|
||||||
//aiClient: aiClient,
|
//aiClient: aiClient,
|
||||||
|
@ -59,6 +63,7 @@ func NewAiRouterBiz(
|
||||||
taskImpl: taskImpl,
|
taskImpl: taskImpl,
|
||||||
utilAgent: utilAgent,
|
utilAgent: utilAgent,
|
||||||
channelPool: channelPool,
|
channelPool: channelPool,
|
||||||
|
ollama: ollama,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -68,30 +73,52 @@ func (r *AiRouterBiz) Route(ctx context.Context, req *entitys.ChatRequest) (*ent
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Route 执行智能路由
|
|
||||||
func (r *AiRouterBiz) RouteWithSocket(c *websocket.Conn, req *entitys.ChatSockRequest) (err error) {
|
func (r *AiRouterBiz) RouteWithSocket(c *websocket.Conn, req *entitys.ChatSockRequest) (err error) {
|
||||||
ch, err := r.channelPool.Get()
|
|
||||||
if err != nil {
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
return err
|
defer cancel()
|
||||||
}
|
//ch := r.channelPool.Get()
|
||||||
|
ch := make(chan entitys.ResponseData)
|
||||||
|
done := make(chan struct{})
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
defer close(done)
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case v, ok := <-ch:
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// 带超时的发送,避免阻塞
|
||||||
|
if err := sendWithTimeout(c, v, 2*time.Second); err != nil {
|
||||||
|
log.Errorf("Send error: %v", err)
|
||||||
|
cancel() // 通知主流程退出
|
||||||
|
return
|
||||||
|
}
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
entitys.MsgSend(c, entitys.ResponseData{
|
_ = entitys.MsgSend(c, entitys.ResponseData{
|
||||||
Done: false,
|
Done: false,
|
||||||
Content: err.Error(),
|
Content: err.Error(),
|
||||||
Type: entitys.ResponseErr,
|
Type: entitys.ResponseErr,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
entitys.MsgSend(c, entitys.ResponseData{
|
_ = entitys.MsgSend(c, entitys.ResponseData{
|
||||||
Done: true,
|
Done: true,
|
||||||
Content: "",
|
Content: "",
|
||||||
Type: entitys.ResponseEnd,
|
Type: entitys.ResponseEnd,
|
||||||
})
|
})
|
||||||
err = r.channelPool.Put(ch)
|
//r.channelPool.Put(ch)
|
||||||
if err != nil {
|
close(ch)
|
||||||
close(ch)
|
|
||||||
}
|
|
||||||
}()
|
}()
|
||||||
|
|
||||||
session := c.Headers("X-Session", "")
|
session := c.Headers("X-Session", "")
|
||||||
if len(session) == 0 {
|
if len(session) == 0 {
|
||||||
return errors.SessionNotFound
|
return errors.SessionNotFound
|
||||||
|
@ -119,46 +146,94 @@ func (r *AiRouterBiz) RouteWithSocket(c *websocket.Conn, req *entitys.ChatSockRe
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.SystemError
|
return errors.SystemError
|
||||||
}
|
}
|
||||||
//意图预测
|
|
||||||
prompt := r.getPromptLLM(sysInfo, history, req.Text, task)
|
prompt := r.getPromptLLM(sysInfo, history, req.Text, task)
|
||||||
match, err := r.utilAgent.Llm.GenerateContent(context.TODO(), prompt,
|
|
||||||
llms.WithJSONMode(),
|
AgentClient := r.utilAgent.Get()
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return errors.SystemError
|
|
||||||
}
|
|
||||||
log.Info(match.Choices[0].Content)
|
|
||||||
ch <- entitys.ResponseData{
|
ch <- entitys.ResponseData{
|
||||||
Done: false,
|
Done: false,
|
||||||
Content: match.Choices[0].Content,
|
Content: "准备意图识别",
|
||||||
Type: entitys.ResponseLog,
|
Type: entitys.ResponseLog,
|
||||||
}
|
}
|
||||||
var matchJson entitys.Match
|
|
||||||
err = json.Unmarshal([]byte(match.Choices[0].Content), &matchJson)
|
match, err := AgentClient.Llm.GenerateContent(
|
||||||
|
ctx, // 使用可取消的上下文
|
||||||
|
prompt,
|
||||||
|
llms.WithJSONMode(),
|
||||||
|
)
|
||||||
|
resMsg := match.Choices[0].Content
|
||||||
|
|
||||||
|
r.utilAgent.Put(AgentClient)
|
||||||
|
ch <- entitys.ResponseData{
|
||||||
|
Done: false,
|
||||||
|
Content: resMsg,
|
||||||
|
Type: entitys.ResponseLog,
|
||||||
|
}
|
||||||
|
ch <- entitys.ResponseData{
|
||||||
|
Done: false,
|
||||||
|
Content: "意图识别结束",
|
||||||
|
Type: entitys.ResponseLog,
|
||||||
|
}
|
||||||
|
//for i := 1; i < 10; i++ {
|
||||||
|
// ch <- entitys.ResponseData{
|
||||||
|
// Done: false,
|
||||||
|
// Content: fmt.Sprintf("%d", i),
|
||||||
|
// Type: entitys.ResponseLog,
|
||||||
|
// }
|
||||||
|
// time.Sleep(1 * time.Second)
|
||||||
|
//}
|
||||||
|
//return
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
log.Errorf("LLM error: %v", err)
|
||||||
return errors.SystemError
|
return errors.SystemError
|
||||||
}
|
}
|
||||||
go func() {
|
|
||||||
defer func() {
|
|
||||||
if r := recover(); r != nil {
|
|
||||||
err = fmt.Errorf("recovered from panic: %v", r)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
defer close(ch)
|
|
||||||
err = r.handleMatch(c, ch, &matchJson, task, sysInfo)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
for v := range ch {
|
//msg, err := r.ollama.ToolSelect(ctx, r.getPromptOllama(sysInfo, history, req.Text), []api.Tool{})
|
||||||
if err := entitys.MsgSend(c, v); err != nil {
|
//if err != nil {
|
||||||
return err
|
// return
|
||||||
}
|
//}
|
||||||
|
//resMsg := msg.Message.Content
|
||||||
|
select {
|
||||||
|
case ch <- entitys.ResponseData{
|
||||||
|
Done: false,
|
||||||
|
Content: resMsg,
|
||||||
|
Type: entitys.ResponseLog,
|
||||||
|
}:
|
||||||
|
case <-ctx.Done():
|
||||||
|
return ctx.Err()
|
||||||
}
|
}
|
||||||
return
|
|
||||||
|
var matchJson entitys.Match
|
||||||
|
if err := json.Unmarshal([]byte(resMsg), &matchJson); err != nil {
|
||||||
|
return errors.SystemError
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := r.handleMatch(ctx, c, ch, &matchJson, task, sysInfo); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 辅助函数:带超时的 WebSocket 发送
|
||||||
|
func sendWithTimeout(c *websocket.Conn, data entitys.ResponseData, timeout time.Duration) error {
|
||||||
|
sendCtx, cancel := context.WithTimeout(context.Background(), timeout)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
done := make(chan error, 1)
|
||||||
|
go func() {
|
||||||
|
done <- entitys.MsgSend(c, data)
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case err := <-done:
|
||||||
|
return err
|
||||||
|
case <-sendCtx.Done():
|
||||||
|
return sendCtx.Err()
|
||||||
|
}
|
||||||
|
}
|
||||||
func (r *AiRouterBiz) handleOtherTask(c *websocket.Conn, ch chan entitys.ResponseData, matchJson *entitys.Match) (err error) {
|
func (r *AiRouterBiz) handleOtherTask(c *websocket.Conn, ch chan entitys.ResponseData, matchJson *entitys.Match) (err error) {
|
||||||
ch <- entitys.ResponseData{
|
ch <- entitys.ResponseData{
|
||||||
Done: false,
|
Done: false,
|
||||||
|
@ -169,7 +244,7 @@ func (r *AiRouterBiz) handleOtherTask(c *websocket.Conn, ch chan entitys.Respons
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *AiRouterBiz) handleMatch(c *websocket.Conn, ch chan entitys.ResponseData, matchJson *entitys.Match, tasks []model.AiTask, sysInfo model.AiSy) (err error) {
|
func (r *AiRouterBiz) handleMatch(ctx context.Context, c *websocket.Conn, ch chan entitys.ResponseData, matchJson *entitys.Match, tasks []model.AiTask, sysInfo model.AiSy) (err error) {
|
||||||
|
|
||||||
if !matchJson.IsMatch {
|
if !matchJson.IsMatch {
|
||||||
ch <- entitys.ResponseData{
|
ch <- entitys.ResponseData{
|
||||||
|
@ -404,6 +479,23 @@ func (r *AiRouterBiz) getPrompt(sysInfo model.AiSy, history []model.AiChatHi, re
|
||||||
return prompt
|
return prompt
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *AiRouterBiz) getPromptOllama(sysInfo model.AiSy, history []model.AiChatHi, reqInput string) []api.Message {
|
||||||
|
var (
|
||||||
|
prompt = make([]api.Message, 0)
|
||||||
|
)
|
||||||
|
prompt = append(prompt, api.Message{
|
||||||
|
Role: "system",
|
||||||
|
Content: r.buildSystemPrompt(sysInfo.SysPrompt),
|
||||||
|
}, api.Message{
|
||||||
|
Role: "assistant",
|
||||||
|
Content: pkg.JsonStringIgonErr(r.buildAssistant(history)),
|
||||||
|
}, api.Message{
|
||||||
|
Role: "user",
|
||||||
|
Content: reqInput,
|
||||||
|
})
|
||||||
|
return prompt
|
||||||
|
}
|
||||||
|
|
||||||
func (r *AiRouterBiz) getPromptLLM(sysInfo model.AiSy, history []model.AiChatHi, reqInput string, tasks []model.AiTask) []llms.MessageContent {
|
func (r *AiRouterBiz) getPromptLLM(sysInfo model.AiSy, history []model.AiChatHi, reqInput string, tasks []model.AiTask) []llms.MessageContent {
|
||||||
var (
|
var (
|
||||||
prompt = make([]llms.MessageContent, 0)
|
prompt = make([]llms.MessageContent, 0)
|
||||||
|
|
|
@ -28,6 +28,7 @@ type SysConfig struct {
|
||||||
SessionLen int `mapstructure:"session_len"`
|
SessionLen int `mapstructure:"session_len"`
|
||||||
ChannelPoolLen int `mapstructure:"channel_pool_len"`
|
ChannelPoolLen int `mapstructure:"channel_pool_len"`
|
||||||
ChannelPoolSize int `mapstructure:"channel_pool_size"`
|
ChannelPoolSize int `mapstructure:"channel_pool_size"`
|
||||||
|
LlmPoolLen int `mapstructure:"llm_pool_len"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// ServerConfig 服务器配置
|
// ServerConfig 服务器配置
|
||||||
|
|
|
@ -3,7 +3,6 @@ package pkg
|
||||||
import (
|
import (
|
||||||
"ai_scheduler/internal/config"
|
"ai_scheduler/internal/config"
|
||||||
"ai_scheduler/internal/entitys"
|
"ai_scheduler/internal/entitys"
|
||||||
"errors"
|
|
||||||
"sync"
|
"sync"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -25,29 +24,29 @@ func NewSafeChannelPool(c *config.Config) (*SafeChannelPool, func()) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// 从池中获取 channel(若无空闲则创建新 channel)
|
// 从池中获取 channel(若无空闲则创建新 channel)
|
||||||
func (p *SafeChannelPool) Get() (chan entitys.ResponseData, error) {
|
func (p *SafeChannelPool) Get() chan entitys.ResponseData {
|
||||||
p.mu.Lock()
|
p.mu.Lock()
|
||||||
defer p.mu.Unlock()
|
defer p.mu.Unlock()
|
||||||
|
|
||||||
if p.closed {
|
if p.closed {
|
||||||
return nil, errors.New("pool is closed")
|
return make(chan entitys.ResponseData, p.bufSize)
|
||||||
}
|
}
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case ch := <-p.pool: // 从池中取
|
case ch := <-p.pool: // 从池中取
|
||||||
return ch, nil
|
return ch
|
||||||
default: // 池为空,创建新 channel
|
default: // 池为空,创建新 channel
|
||||||
return make(chan entitys.ResponseData, p.bufSize), nil
|
return make(chan entitys.ResponseData, p.bufSize)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 将 channel 放回池中(必须确保 channel 已清空!)
|
// 将 channel 放回池中(必须确保 channel 已清空!)
|
||||||
func (p *SafeChannelPool) Put(ch chan entitys.ResponseData) error {
|
func (p *SafeChannelPool) Put(ch chan entitys.ResponseData) {
|
||||||
p.mu.Lock()
|
p.mu.Lock()
|
||||||
defer p.mu.Unlock()
|
defer p.mu.Unlock()
|
||||||
|
|
||||||
if p.closed {
|
if p.closed {
|
||||||
return errors.New("pool is closed")
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// 清空 channel(防止复用时读取旧数据)
|
// 清空 channel(防止复用时读取旧数据)
|
||||||
|
@ -62,7 +61,7 @@ func (p *SafeChannelPool) Put(ch chan entitys.ResponseData) error {
|
||||||
default: // 池已满,直接关闭 channel(避免泄漏)
|
default: // 池已满,直接关闭 channel(避免泄漏)
|
||||||
close(ch)
|
close(ch)
|
||||||
}
|
}
|
||||||
return nil
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// 关闭池(释放所有资源)
|
// 关闭池(释放所有资源)
|
||||||
|
|
|
@ -1,8 +1,23 @@
|
||||||
package pkg
|
package pkg
|
||||||
|
|
||||||
import "encoding/json"
|
import (
|
||||||
|
"ai_scheduler/internal/entitys"
|
||||||
|
"encoding/json"
|
||||||
|
)
|
||||||
|
|
||||||
func JsonStringIgonErr(data interface{}) string {
|
func JsonStringIgonErr(data interface{}) string {
|
||||||
dataByte, _ := json.Marshal(data)
|
dataByte, _ := json.Marshal(data)
|
||||||
return string(dataByte)
|
return string(dataByte)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// IsChannelClosed 检查给定的 channel 是否已经关闭
|
||||||
|
// 参数 ch: 要检查的 channel,类型为 chan entitys.ResponseData
|
||||||
|
// 返回值: bool 类型,true 表示 channel 已关闭,false 表示未关闭
|
||||||
|
func IsChannelClosed(ch chan entitys.ResponseData) bool {
|
||||||
|
select {
|
||||||
|
case _, ok := <-ch: // 尝试从 channel 中读取数据
|
||||||
|
return !ok // 如果 ok=false,说明 channel 已关闭
|
||||||
|
default: // 如果 channel 暂时无数据可读(但不一定关闭)
|
||||||
|
return false // channel 未关闭(但可能有数据未读取)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -46,7 +46,7 @@ func (c *Client) ToolSelect(ctx context.Context, messages []api.Message, tools [
|
||||||
Messages: messages,
|
Messages: messages,
|
||||||
Stream: new(bool), // 设置为false,不使用流式响应
|
Stream: new(bool), // 设置为false,不使用流式响应
|
||||||
Think: &api.ThinkValue{Value: true},
|
Think: &api.ThinkValue{Value: true},
|
||||||
Tools: tools,
|
//Tools: tools,
|
||||||
}
|
}
|
||||||
|
|
||||||
err = c.client.Chat(ctx, req, func(resp api.ChatResponse) error {
|
err = c.client.Chat(ctx, req, func(resp api.ChatResponse) error {
|
||||||
|
|
|
@ -2,40 +2,112 @@ package utils_ollama
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"ai_scheduler/internal/config"
|
"ai_scheduler/internal/config"
|
||||||
|
"math/rand"
|
||||||
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/gofiber/fiber/v2/log"
|
"github.com/gofiber/fiber/v2/log"
|
||||||
"github.com/tmc/langchaingo/llms/ollama"
|
"github.com/tmc/langchaingo/llms/ollama"
|
||||||
)
|
)
|
||||||
|
|
||||||
type UtilOllama struct {
|
type UtilOllama struct {
|
||||||
Llm *ollama.LLM
|
LlmClientPool *sync.Pool
|
||||||
|
poolSize int // 记录池大小,用于调试
|
||||||
|
model string
|
||||||
|
serverURL string
|
||||||
|
c *config.Config
|
||||||
|
}
|
||||||
|
|
||||||
|
type LlmObj struct {
|
||||||
|
Number string
|
||||||
|
Llm *ollama.LLM
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewUtilOllama(c *config.Config, logger log.AllLogger) *UtilOllama {
|
func NewUtilOllama(c *config.Config, logger log.AllLogger) *UtilOllama {
|
||||||
llm, err := ollama.New(
|
poolSize := c.Sys.LlmPoolLen
|
||||||
ollama.WithModel(c.Ollama.Model),
|
if poolSize <= 0 {
|
||||||
ollama.WithHTTPClient(http.DefaultClient),
|
poolSize = 10 // 默认值
|
||||||
ollama.WithServerURL(getUrl(c)),
|
logger.Warnf("LlmPoolLen not set, using default: %d", poolSize)
|
||||||
ollama.WithKeepAlive("-1s"),
|
}
|
||||||
)
|
|
||||||
if err != nil {
|
// 初始化 Pool
|
||||||
logger.Fatal(err)
|
pool := &sync.Pool{
|
||||||
panic(err)
|
New: func() interface{} {
|
||||||
|
llm, err := ollama.New(
|
||||||
|
ollama.WithModel(c.Ollama.Model),
|
||||||
|
ollama.WithHTTPClient(http.DefaultClient),
|
||||||
|
ollama.WithServerURL(getUrl(c)),
|
||||||
|
ollama.WithKeepAlive("-1s"),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
logger.Fatalf("Failed to create Ollama client: %v", err)
|
||||||
|
panic(err) // 或者返回 nil + 错误处理
|
||||||
|
}
|
||||||
|
number := randStr(5)
|
||||||
|
log.Info(number)
|
||||||
|
return &LlmObj{
|
||||||
|
Number: number,
|
||||||
|
Llm: llm,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// 预填充 Pool
|
||||||
|
for i := 0; i < poolSize; i++ {
|
||||||
|
pool.Put(pool.New())
|
||||||
}
|
}
|
||||||
|
|
||||||
return &UtilOllama{
|
return &UtilOllama{
|
||||||
Llm: llm,
|
LlmClientPool: pool,
|
||||||
|
poolSize: poolSize,
|
||||||
|
model: c.Ollama.Model,
|
||||||
|
serverURL: getUrl(c),
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//func (o *UtilOllama) a() {
|
func (o *UtilOllama) NewClient() *ollama.LLM {
|
||||||
// var agent agents.Agent
|
llm, _ := ollama.New(
|
||||||
// agent = agents.NewOneShotAgent(llm, tools, opts...)
|
ollama.WithModel(o.c.Ollama.Model),
|
||||||
//
|
ollama.WithHTTPClient(&http.Client{
|
||||||
// agents.NewExecutor()
|
Transport: &http.Transport{
|
||||||
//}
|
MaxIdleConns: 100, // 最大空闲连接数(默认 2,太小)
|
||||||
|
MaxIdleConnsPerHost: 100, // 每个 Host 的最大空闲连接数(默认 2)
|
||||||
|
IdleConnTimeout: 90 * time.Second, // 空闲连接超时时间
|
||||||
|
DialContext: (&net.Dialer{
|
||||||
|
Timeout: 30 * time.Second, // 连接超时
|
||||||
|
KeepAlive: 30 * time.Second, // TCP Keep-Alive
|
||||||
|
}).DialContext,
|
||||||
|
},
|
||||||
|
Timeout: 60 * time.Second, // 整体请求超时(避免无限等待)
|
||||||
|
}),
|
||||||
|
ollama.WithServerURL(getUrl(o.c)),
|
||||||
|
ollama.WithKeepAlive("-1s"),
|
||||||
|
)
|
||||||
|
return llm
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get 返回一个可用的 LLM 客户端
|
||||||
|
func (o *UtilOllama) Get() *LlmObj {
|
||||||
|
client := o.LlmClientPool.Get().(*LlmObj)
|
||||||
|
return client
|
||||||
|
}
|
||||||
|
|
||||||
|
// Put 归还客户端(可选:检查是否仍可用)
|
||||||
|
func (o *UtilOllama) Put(llm *LlmObj) {
|
||||||
|
if llm == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
o.LlmClientPool.Put(llm)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stats 返回池的统计信息(用于监控)
|
||||||
|
func (o *UtilOllama) Stats() (current, max int) {
|
||||||
|
return o.poolSize, o.poolSize
|
||||||
|
}
|
||||||
|
|
||||||
func getUrl(c *config.Config) string {
|
func getUrl(c *config.Config) string {
|
||||||
baseURL := c.Ollama.BaseURL
|
baseURL := c.Ollama.BaseURL
|
||||||
|
@ -45,3 +117,13 @@ func getUrl(c *config.Config) string {
|
||||||
}
|
}
|
||||||
return baseURL
|
return baseURL
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var letters = []rune("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ")
|
||||||
|
|
||||||
|
func randStr(n int) string {
|
||||||
|
b := make([]rune, n)
|
||||||
|
for i := range b {
|
||||||
|
b[i] = letters[rand.Intn(len(letters))]
|
||||||
|
}
|
||||||
|
return string(b)
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue