130 lines
		
	
	
		
			2.9 KiB
		
	
	
	
		
			Go
		
	
	
	
			
		
		
	
	
			130 lines
		
	
	
		
			2.9 KiB
		
	
	
	
		
			Go
		
	
	
	
package utils_langchain
 | 
						||
 | 
						||
import (
 | 
						||
	"ai_scheduler/internal/config"
 | 
						||
	"math/rand"
 | 
						||
	"net"
 | 
						||
	"net/http"
 | 
						||
	"os"
 | 
						||
	"sync"
 | 
						||
	"time"
 | 
						||
 | 
						||
	"github.com/gofiber/fiber/v2/log"
 | 
						||
	"github.com/tmc/langchaingo/llms/ollama"
 | 
						||
)
 | 
						||
 | 
						||
type UtilLangChain struct {
 | 
						||
	LlmClientPool *sync.Pool
 | 
						||
	poolSize      int // 记录池大小,用于调试
 | 
						||
	model         string
 | 
						||
	serverURL     string
 | 
						||
	c             *config.Config
 | 
						||
}
 | 
						||
 | 
						||
type LlmObj struct {
 | 
						||
	Number string
 | 
						||
	Llm    *ollama.LLM
 | 
						||
}
 | 
						||
 | 
						||
func NewUtilLangChain(c *config.Config, logger log.AllLogger) *UtilLangChain {
 | 
						||
	poolSize := c.Sys.LlmPoolLen
 | 
						||
	if poolSize <= 0 {
 | 
						||
		poolSize = 10 // 默认值
 | 
						||
		logger.Warnf("LlmPoolLen not set, using default: %d", poolSize)
 | 
						||
	}
 | 
						||
 | 
						||
	// 初始化 Pool
 | 
						||
	pool := &sync.Pool{
 | 
						||
		New: func() interface{} {
 | 
						||
			llm, err := ollama.New(
 | 
						||
				ollama.WithModel(c.Ollama.Model),
 | 
						||
				ollama.WithHTTPClient(http.DefaultClient),
 | 
						||
				ollama.WithServerURL(getUrl(c)),
 | 
						||
				ollama.WithKeepAlive("-1s"),
 | 
						||
			)
 | 
						||
			if err != nil {
 | 
						||
				logger.Fatalf("Failed to create Ollama client: %v", err)
 | 
						||
				panic(err) // 或者返回 nil + 错误处理
 | 
						||
			}
 | 
						||
			number := randStr(5)
 | 
						||
			log.Info(number)
 | 
						||
			return &LlmObj{
 | 
						||
				Number: number,
 | 
						||
				Llm:    llm,
 | 
						||
			}
 | 
						||
		},
 | 
						||
	}
 | 
						||
 | 
						||
	// 预填充 Pool
 | 
						||
	for i := 0; i < poolSize; i++ {
 | 
						||
		pool.Put(pool.New())
 | 
						||
	}
 | 
						||
 | 
						||
	return &UtilLangChain{
 | 
						||
		LlmClientPool: pool,
 | 
						||
		poolSize:      poolSize,
 | 
						||
		model:         c.Ollama.Model,
 | 
						||
		serverURL:     getUrl(c),
 | 
						||
	}
 | 
						||
 | 
						||
}
 | 
						||
 | 
						||
func (o *UtilLangChain) NewClient() *ollama.LLM {
 | 
						||
	llm, _ := ollama.New(
 | 
						||
		ollama.WithModel(o.c.Ollama.Model),
 | 
						||
		ollama.WithHTTPClient(&http.Client{
 | 
						||
			Transport: &http.Transport{
 | 
						||
				MaxIdleConns:        100,              // 最大空闲连接数(默认 2,太小)
 | 
						||
				MaxIdleConnsPerHost: 100,              // 每个 Host 的最大空闲连接数(默认 2)
 | 
						||
				IdleConnTimeout:     90 * time.Second, // 空闲连接超时时间
 | 
						||
				DialContext: (&net.Dialer{
 | 
						||
					Timeout:   30 * time.Second, // 连接超时
 | 
						||
					KeepAlive: 30 * time.Second, // TCP Keep-Alive
 | 
						||
				}).DialContext,
 | 
						||
			},
 | 
						||
			Timeout: 60 * time.Second, // 整体请求超时(避免无限等待)
 | 
						||
		}),
 | 
						||
		ollama.WithServerURL(getUrl(o.c)),
 | 
						||
		ollama.WithKeepAlive("-1s"),
 | 
						||
	)
 | 
						||
	return llm
 | 
						||
}
 | 
						||
 | 
						||
// Get 返回一个可用的 LLM 客户端
 | 
						||
func (o *UtilLangChain) Get() *LlmObj {
 | 
						||
	client := o.LlmClientPool.Get().(*LlmObj)
 | 
						||
	return client
 | 
						||
}
 | 
						||
 | 
						||
// Put 归还客户端(可选:检查是否仍可用)
 | 
						||
func (o *UtilLangChain) Put(llm *LlmObj) {
 | 
						||
	if llm == nil {
 | 
						||
		return
 | 
						||
	}
 | 
						||
	o.LlmClientPool.Put(llm)
 | 
						||
}
 | 
						||
 | 
						||
// Stats 返回池的统计信息(用于监控)
 | 
						||
func (o *UtilLangChain) Stats() (current, max int) {
 | 
						||
	return o.poolSize, o.poolSize
 | 
						||
}
 | 
						||
 | 
						||
func getUrl(c *config.Config) string {
 | 
						||
	baseURL := c.Ollama.BaseURL
 | 
						||
	envURL := os.Getenv("OLLAMA_BASE_URL")
 | 
						||
	if envURL != "" {
 | 
						||
		baseURL = envURL
 | 
						||
	}
 | 
						||
	return baseURL
 | 
						||
}
 | 
						||
 | 
						||
// letters is the alphabet randStr draws from: ASCII lower- and upper-case
// letters.
var letters = []rune("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ")

// randStr returns a string of n runes, each picked from letters using
// math/rand.
func randStr(n int) string {
	alphabet := len(letters)
	out := make([]rune, n)
	for i := 0; i < n; i++ {
		out[i] = letters[rand.Intn(alphabet)]
	}
	return string(out)
}
 |