ai_scheduler/internal/pkg/utils_langchain/client.go

130 lines
2.9 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package utils_langchain
import (
"ai_scheduler/internal/config"
"math/rand"
"net"
"net/http"
"os"
"sync"
"time"
"github.com/gofiber/fiber/v2/log"
"github.com/tmc/langchaingo/llms/ollama"
)
// UtilLangChain bundles a sync.Pool of Ollama LLM clients together with the
// settings used to build them.
type UtilLangChain struct {
// LlmClientPool stores *LlmObj values; Get and Put operate on it.
LlmClientPool *sync.Pool
poolSize int // pool size recorded at construction, kept for debugging
model string // Ollama model name copied from the config
serverURL string // Ollama server URL as resolved by getUrl
c *config.Config // configuration read by NewClient
}
// LlmObj pairs an Ollama client with a short tag that identifies it.
type LlmObj struct {
Number string // random 5-letter tag assigned when the pool creates the object
Llm *ollama.LLM // the Ollama client itself
}
// NewUtilLangChain builds a UtilLangChain whose pool is pre-filled with
// Ollama clients.
//
// The pool size comes from c.Sys.LlmPoolLen; a non-positive value falls back
// to 10 and logs a warning. c is retained on the returned value because
// NewClient reads the model and server URL from it.
//
// If a client cannot be created, the failure is reported through
// logger.Fatalf and the pool's New function then panics.
func NewUtilLangChain(c *config.Config, logger log.AllLogger) *UtilLangChain {
	poolSize := c.Sys.LlmPoolLen
	if poolSize <= 0 {
		poolSize = 10 // default
		logger.Warnf("LlmPoolLen not set, using default: %d", poolSize)
	}

	// Initialise the pool. sync.Pool also calls New later, whenever Get finds
	// the pool empty, so this closure is not limited to the pre-fill below.
	pool := &sync.Pool{
		New: func() interface{} {
			llm, err := ollama.New(
				ollama.WithModel(c.Ollama.Model),
				// NOTE(review): http.DefaultClient has no request timeout.
				ollama.WithHTTPClient(http.DefaultClient),
				ollama.WithServerURL(getUrl(c)),
				ollama.WithKeepAlive("-1s"),
			)
			if err != nil {
				logger.Fatalf("Failed to create Ollama client: %v", err)
				panic(err) // alternatively: return nil and handle the error
			}
			number := randStr(5)
			log.Info(number)
			return &LlmObj{
				Number: number,
				Llm:    llm,
			}
		},
	}

	// Pre-fill the pool.
	for i := 0; i < poolSize; i++ {
		pool.Put(pool.New())
	}

	return &UtilLangChain{
		LlmClientPool: pool,
		poolSize:      poolSize,
		model:         c.Ollama.Model,
		serverURL:     getUrl(c),
		// Keep the config: NewClient dereferences it, and leaving it nil
		// made every NewClient call panic.
		c: c,
	}
}
// NewClient builds a standalone Ollama client that is not taken from the pool.
//
// The client gets its own http.Client with a 60s overall timeout and a
// transport that keeps up to 100 idle connections. The model and server URL
// are read from the stored config when there is one; otherwise the values
// captured at construction time are used, so a missing config does not cause
// a nil pointer dereference.
//
// It returns nil when ollama.New reports an error, so callers must check the
// result before using it.
func (o *UtilLangChain) NewClient() *ollama.LLM {
	model, serverURL := o.model, o.serverURL
	if o.c != nil {
		model = o.c.Ollama.Model
		serverURL = getUrl(o.c)
	}

	httpClient := &http.Client{
		Transport: &http.Transport{
			MaxIdleConns:        100,              // maximum idle connections across all hosts
			MaxIdleConnsPerHost: 100,              // maximum idle connections per host (net/http defaults to 2)
			IdleConnTimeout:     90 * time.Second, // how long an idle connection is kept
			DialContext: (&net.Dialer{
				Timeout:   30 * time.Second, // connect timeout
				KeepAlive: 30 * time.Second, // TCP keep-alive interval
			}).DialContext,
		},
		Timeout: 60 * time.Second, // overall request timeout, avoids waiting forever
	}

	llm, err := ollama.New(
		ollama.WithModel(model),
		ollama.WithHTTPClient(httpClient),
		ollama.WithServerURL(serverURL),
		ollama.WithKeepAlive("-1s"),
	)
	if err != nil {
		// The signature has no error result; nil signals the failure.
		return nil
	}
	return llm
}
// Get takes an LLM client wrapper out of the pool and returns it.
// The pooled value is asserted to be *LlmObj.
func (o *UtilLangChain) Get() *LlmObj {
	return o.LlmClientPool.Get().(*LlmObj)
}
// Put hands a client back to the pool. A nil argument is ignored.
// (Optional future work: check that the client is still usable.)
func (o *UtilLangChain) Put(llm *LlmObj) {
	if llm != nil {
		o.LlmClientPool.Put(llm)
	}
}
// Stats reports pool figures for monitoring. Both results are the pool size
// recorded at construction; neither is a live count.
func (o *UtilLangChain) Stats() (current, max int) {
	current = o.poolSize
	max = o.poolSize
	return current, max
}
// getUrl resolves the Ollama server URL: the OLLAMA_BASE_URL environment
// variable wins when it is non-empty, otherwise c.Ollama.BaseURL is used.
func getUrl(c *config.Config) string {
	resolved := c.Ollama.BaseURL
	if override := os.Getenv("OLLAMA_BASE_URL"); override != "" {
		resolved = override
	}
	return resolved
}
// letters is the alphabet randStr draws from: ASCII lower- and upper-case.
var letters = []rune("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ")

// randStr returns a string of n letters, each picked from letters with
// math/rand.
func randStr(n int) string {
	out := make([]rune, n)
	for pos := 0; pos < len(out); pos++ {
		pick := rand.Intn(len(letters))
		out[pos] = letters[pick]
	}
	return string(out)
}