package utils_langchain

import (
	"ai_scheduler/internal/config"
	"math/rand"
	"net"
	"net/http"
	"os"
	"sync"
	"time"

	"github.com/gofiber/fiber/v2/log"
	"github.com/tmc/langchaingo/llms/ollama"
)

type UtilLangChain struct {
	LlmClientPool *sync.Pool
	poolSize      int // configured pool size, kept for debugging
	model         string
	serverURL     string
	c             *config.Config
}

type LlmObj struct {
	Number string
	Llm    *ollama.LLM
}

func NewUtilLangChain(c *config.Config, logger log.AllLogger) *UtilLangChain {
	poolSize := c.Sys.LlmPoolLen
	if poolSize <= 0 {
		poolSize = 10 // default value
		logger.Warnf("LlmPoolLen not set, using default: %d", poolSize)
	}

	// Initialize the pool; New builds a fresh Ollama client whenever the pool is empty.
	pool := &sync.Pool{
		New: func() interface{} {
			llm, err := ollama.New(
				ollama.WithModel(c.Ollama.Model),
				ollama.WithHTTPClient(http.DefaultClient),
				ollama.WithServerURL(getUrl(c)),
				ollama.WithKeepAlive("-1s"),
			)
			if err != nil {
				logger.Fatalf("Failed to create Ollama client: %v", err)
				panic(err) // in case Fatalf does not exit; alternatively return nil and handle the error
			}
			number := randStr(5)
			logger.Info(number)
			return &LlmObj{
				Number: number,
				Llm:    llm,
			}
		},
	}

	// Pre-fill the pool so early callers do not pay the client-creation cost.
	for i := 0; i < poolSize; i++ {
		pool.Put(pool.New())
	}

	return &UtilLangChain{
		LlmClientPool: pool,
		poolSize:      poolSize,
		model:         c.Ollama.Model,
		serverURL:     getUrl(c),
		c:             c, // required by NewClient, which reads o.c
	}
}

// NewClient creates a standalone Ollama client with a tuned HTTP transport, bypassing the pool.
func (o *UtilLangChain) NewClient() *ollama.LLM {
	llm, err := ollama.New(
		ollama.WithModel(o.c.Ollama.Model),
		ollama.WithHTTPClient(&http.Client{
			Transport: &http.Transport{
				MaxIdleConns:        100,              // maximum idle connections across all hosts
				MaxIdleConnsPerHost: 100,              // maximum idle connections per host (default is 2)
				IdleConnTimeout:     90 * time.Second, // how long idle connections are kept open
				DialContext: (&net.Dialer{
					Timeout:   30 * time.Second, // dial timeout
					KeepAlive: 30 * time.Second, // TCP keep-alive interval
				}).DialContext,
			},
			Timeout: 60 * time.Second, // overall request timeout, avoids waiting indefinitely
		}),
		ollama.WithServerURL(getUrl(o.c)),
		ollama.WithKeepAlive("-1s"),
	)
	if err != nil {
		log.Errorf("Failed to create Ollama client: %v", err)
	}
	return llm
}

// Get returns an available LLM client from the pool.
func (o *UtilLangChain) Get() *LlmObj {
	client := o.LlmClientPool.Get().(*LlmObj)
	return client
}

// Put returns a client to the pool (optionally, callers may first check that it is still usable).
func (o *UtilLangChain) Put(llm *LlmObj) {
	if llm == nil {
		return
	}
	o.LlmClientPool.Put(llm)
}
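
// Usage sketch: borrow a client with Get, return it with defer Put, and call it
// through the langchaingo llms package. The handler name and prompt handling
// below are illustrative assumptions, not part of this package; the call relies
// on llms.GenerateFromSinglePrompt from github.com/tmc/langchaingo/llms.
//
//	func answer(ctx context.Context, u *utils_langchain.UtilLangChain, prompt string) (string, error) {
//		obj := u.Get()   // borrow a pooled client
//		defer u.Put(obj) // always return it, even on error
//		return llms.GenerateFromSinglePrompt(ctx, obj.Llm, prompt)
//	}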

// Stats returns pool statistics (for monitoring). sync.Pool does not expose its
// current size, so both values report the configured pool size.
func (o *UtilLangChain) Stats() (current, max int) {
	return o.poolSize, o.poolSize
}

func getUrl(c *config.Config) string {
	baseURL := c.Ollama.BaseURL
	envURL := os.Getenv("OLLAMA_BASE_URL")
	if envURL != "" {
		baseURL = envURL
	}
	return baseURL
}

var letters = []rune("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ")

func randStr(n int) string {
	b := make([]rune, n)
	for i := range b {
		b[i] = letters[rand.Intn(len(letters))]
	}
	return string(b)
}