import logging
import os
import threading
from collections import OrderedDict

import numpy as np
import ollama
from lightrag import LightRAG
from lightrag.utils import EmbeddingFunc
from lightrag.llm.ollama import ollama_embed

from app.config import settings

# Global RAG manager
rag_manager = None


async def llm_func(prompt, system_prompt=None, history_messages=None, **kwargs) -> str:
    """LLM function: generate a reply via Ollama."""
    if history_messages is None:
        history_messages = []
    # Drop parameters that would conflict with the explicit arguments below
    kwargs.pop('model', None)
    kwargs.pop('hashing_kv', None)
    keyword_extraction = kwargs.pop("keyword_extraction", False)
    if keyword_extraction:
        kwargs["format"] = "json"

    stream = kwargs.pop("stream", False)
    logging.info(f"LLM called with stream={stream}")

    messages = []
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})
    messages.extend(history_messages)
    messages.append({"role": "user", "content": prompt})

    client = ollama.AsyncClient(host=settings.LLM_BINDING_HOST)
    if stream:
        async def inner():
            # Pass **kwargs through so top-level options such as `format` take effect
            response = await client.chat(model=settings.LLM_MODEL, messages=messages, stream=True, **kwargs)
            async for chunk in response:
                msg = chunk.get("message", {})
                if msg.get("thinking"):
                    yield {"type": "thinking", "content": msg["thinking"]}
                if msg.get("content"):
                    yield {"type": "content", "content": msg["content"]}
        return inner()
    else:
        # Non-streaming requests skip thinking to reduce latency
        response = await client.chat(model=settings.LLM_MODEL, messages=messages, stream=False, think=False, **kwargs)
        return response["message"]["content"]


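# Illustrative only: a minimal sketch of how a caller might consume the
# streaming generator returned above. The chunk shape
# ({"type": ..., "content": ...}) matches what inner() yields; the
# surrounding print/logging handling is an assumption, not part of this module.
#
#     gen = await llm_func("Hello", stream=True)
#     async for chunk in gen:
#         if chunk["type"] == "thinking":
#             logging.debug(f"thinking: {chunk['content']}")
#         else:
#             print(chunk["content"], end="", flush=True)

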
async def embedding_func(texts: list[str]) -> np.ndarray:
    """Embedding function: compute vectors via Ollama."""
    return await ollama_embed(
        texts,
        embed_model=settings.EMBEDDING_MODEL,
        host=settings.EMBEDDING_BINDING_HOST
    )


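# Illustrative sanity check (assumes EMBEDDING_DIM matches the model's actual
# output dimension):
#
#     vecs = await embedding_func(["hello", "world"])
#     assert vecs.shape == (2, settings.EMBEDDING_DIM)

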
class RAGManager:
    def __init__(self, capacity: int = 3):
        self.capacity = capacity
        self.cache = OrderedDict()
        self.lock = threading.Lock()

    async def get_rag(self, user_id: str) -> LightRAG:
        """Return the LightRAG instance for the given user (LRU cache)."""
        # 1. Try the cache first
        with self.lock:
            if user_id in self.cache:
                self.cache.move_to_end(user_id)
                logging.debug(f"Cache hit for user: {user_id}")
                return self.cache[user_id]

        # 2. Cache miss: initialize a new instance
        logging.info(f"Initializing RAG instance for user: {user_id}")
        user_data_dir = os.path.join(settings.DATA_DIR, user_id)

        # Strict isolation: the working directory must contain only this user's
        # data, never the parent directory or another user's data
        os.makedirs(user_data_dir, exist_ok=True)

        # Instantiate LightRAG
        # Key fix: working_dir must point strictly at the per-user subdirectory
        rag = LightRAG(
            working_dir=user_data_dir,
            llm_model_func=llm_func,
            llm_model_name=settings.LLM_MODEL,
            llm_model_max_async=1,
            max_parallel_insert=1,
            embedding_func=EmbeddingFunc(
                embedding_dim=settings.EMBEDDING_DIM,
                max_token_size=settings.MAX_TOKEN_SIZE,
                func=embedding_func
            ),
            embedding_func_max_async=4,
            enable_llm_cache=True
        )

        # Initialize the storage backends asynchronously
        await rag.initialize_storages()

        # 3. Insert into the cache and handle eviction
        with self.lock:
            # Double-check: another caller may have created an instance for
            # this user while we were initializing
            if user_id in self.cache:
                self.cache.move_to_end(user_id)
                return self.cache[user_id]

            self.cache[user_id] = rag
            self.cache.move_to_end(user_id)

            # Enforce the capacity limit
            while len(self.cache) > self.capacity:
                oldest_user, _ = self.cache.popitem(last=False)
                logging.info(f"Evicting RAG instance for user: {oldest_user} due to capacity limit")
                # Resource reclamation for the evicted instance is left to Python's GC
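                # Note (assumption): current LightRAG releases also expose an
                # async finalize_storages(); scheduling it on the evicted
                # instance (e.g. asyncio.create_task(old_rag.finalize_storages()))
                # would release storage handles deterministically instead of
                # relying on GC. Verify against the installed version first.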

        return rag


def initialize_rag_manager():
    """Initialize the global RAG manager."""
    global rag_manager
    rag_manager = RAGManager(capacity=settings.MAX_RAG_INSTANCES)
    logging.info(f"RAG Manager initialized with capacity: {settings.MAX_RAG_INSTANCES}")
    return rag_manager


def get_rag_manager() -> RAGManager:
    """Return the global RAG manager, initializing it on first use."""
    if rag_manager is None:
        return initialize_rag_manager()
    return rag_manager
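

# Illustrative smoke test, not part of the application. It assumes QueryParam
# is importable from lightrag (true for current releases), that the configured
# Ollama endpoints are reachable, and uses a made-up user id and query text.
if __name__ == "__main__":
    import asyncio

    from lightrag import QueryParam

    async def _demo():
        manager = get_rag_manager()
        rag = await manager.get_rag("demo-user")
        await rag.ainsert("LightRAG builds a knowledge graph from inserted text.")
        answer = await rag.aquery("What does LightRAG build?", param=QueryParam(mode="hybrid"))
        print(answer)

    asyncio.run(_demo())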