import logging
import os
import threading
from collections import OrderedDict
from typing import Optional

import numpy as np
import ollama
from lightrag import LightRAG
from lightrag.utils import EmbeddingFunc
from lightrag.llm.ollama import ollama_embed

from app.config import settings

# Global RAG manager
rag_manager: Optional["RAGManager"] = None


async def llm_func(prompt, system_prompt=None, history_messages=[], **kwargs) -> str:
    """LLM function: generate a reply with Ollama."""
    # Drop arguments that could conflict with the explicit parameters passed below
    kwargs.pop('model', None)
    kwargs.pop('hashing_kv', None)

    keyword_extraction = kwargs.pop("keyword_extraction", False)
    if keyword_extraction:
        kwargs["format"] = "json"

    stream = kwargs.pop("stream", False)
    # Debug: log whether streaming was requested
    if stream:
        logging.info("LLM called with stream=True")
    else:
        logging.info("LLM called with stream=False")

    messages = []
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})
    messages.extend(history_messages)
    messages.append({"role": "user", "content": prompt})

    client = ollama.AsyncClient(host=settings.LLM_BINDING_HOST)

    if stream:
        async def inner():
            # Pass **kwargs through so top-level options such as "format" take effect
            response = await client.chat(model=settings.LLM_MODEL, messages=messages, stream=True, **kwargs)
            async for chunk in response:
                msg = chunk.get("message", {})
                if "thinking" in msg and msg["thinking"]:
                    yield {"type": "thinking", "content": msg["thinking"]}
                if "content" in msg and msg["content"]:
                    yield {"type": "content", "content": msg["content"]}

        return inner()
    else:
        # Non-streaming requests disable thinking to reduce latency
        response = await client.chat(model=settings.LLM_MODEL, messages=messages, stream=False, think=False, **kwargs)
        return response["message"]["content"]


async def embedding_func(texts: list[str]) -> np.ndarray:
    """Embedding function: compute vectors with Ollama."""
    return await ollama_embed(
        texts,
        embed_model=settings.EMBEDDING_MODEL,
        host=settings.EMBEDDING_BINDING_HOST
    )


class RAGManager:
    def __init__(self, capacity: int = 3):
        self.capacity = capacity
        self.cache = OrderedDict()
        self.lock = threading.Lock()

    async def get_rag(self, user_id: str) -> LightRAG:
        """Get the LightRAG instance for the given user (LRU cache)."""
        # 1. Try the cache first
        with self.lock:
            if user_id in self.cache:
                self.cache.move_to_end(user_id)
                logging.debug(f"Cache hit for user: {user_id}")
                return self.cache[user_id]

        # 2. Cache miss: initialize a new instance
        logging.info(f"Initializing RAG instance for user: {user_id}")
        user_data_dir = os.path.join(settings.DATA_DIR, user_id)

        # Strict isolation: the working directory must be clean and must not
        # contain parent-directory or other users' data
        if not os.path.exists(user_data_dir):
            os.makedirs(user_data_dir)

        # Instantiate LightRAG
        # Key fix: working_dir must point strictly at the per-user subdirectory
        rag = LightRAG(
            working_dir=user_data_dir,
            llm_model_func=llm_func,
            llm_model_name=settings.LLM_MODEL,
            llm_model_max_async=1,
            max_parallel_insert=1,
            embedding_func=EmbeddingFunc(
                embedding_dim=settings.EMBEDDING_DIM,
                max_token_size=settings.MAX_TOKEN_SIZE,
                func=embedding_func
            ),
            embedding_func_max_async=4,
            enable_llm_cache=True
        )

        # Initialize storages asynchronously
        await rag.initialize_storages()
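
        # Note: the lock is not held across the awaits above, so two coroutines may
        # initialize the same user's instance concurrently; the double-check below
        # keeps only one of them in the cache.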

        # 3. Put the instance into the cache and handle eviction
        with self.lock:
            # Double-check: another caller may have created the instance while we were initializing
            if user_id in self.cache:
                self.cache.move_to_end(user_id)
                return self.cache[user_id]

            self.cache[user_id] = rag
            self.cache.move_to_end(user_id)

            # Enforce the capacity limit
            while len(self.cache) > self.capacity:
                oldest_user, _ = self.cache.popitem(last=False)
                logging.info(f"Evicting RAG instance for user: {oldest_user} due to capacity limit")
                # Resource reclamation is left to Python's GC here

        return rag


def initialize_rag_manager():
    """Initialize the global RAG manager."""
    global rag_manager
    rag_manager = RAGManager(capacity=settings.MAX_RAG_INSTANCES)
    logging.info(f"RAG Manager initialized with capacity: {settings.MAX_RAG_INSTANCES}")
    return rag_manager


def get_rag_manager() -> RAGManager:
    """Get the global RAG manager."""
    if rag_manager is None:
        return initialize_rag_manager()
    return rag_manager
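

# Usage sketch (illustrative only, not the app's real entry point): obtain the
# per-user LightRAG instance through the manager and query it. QueryParam and
# aquery are LightRAG's public query API; the sample user id, question text and
# "hybrid" mode are assumptions made for this demo.
if __name__ == "__main__":
    import asyncio

    from lightrag import QueryParam

    async def _demo():
        logging.basicConfig(level=logging.INFO)
        rag = await get_rag_manager().get_rag("demo_user")
        answer = await rag.aquery(
            "What documents have been ingested for this user?",
            param=QueryParam(mode="hybrid"),
        )
        print(answer)

    asyncio.run(_demo())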