diff --git a/app/api/routes.py b/app/api/routes.py
index 0d2c07a..d221903 100644
--- a/app/api/routes.py
+++ b/app/api/routes.py
@@ -2,17 +2,38 @@ import json
 import logging
 import os
 import io
-from fastapi import APIRouter, UploadFile, File, HTTPException, Body
+from fastapi import APIRouter, UploadFile, File, HTTPException, Body, Header, Depends
 from fastapi.responses import StreamingResponse
 from pydantic import BaseModel
 from pypdf import PdfReader
-from lightrag import QueryParam
-from app.core.rag import get_rag, llm_func
+from lightrag import LightRAG, QueryParam
+from app.core.rag import InvalidTenantError, get_rag_manager, llm_func
 from app.core.prompts import CUSTOM_RAG_RESPONSE_PROMPT
 from app.config import settings
 
 router = APIRouter()
 
+# ==========================================
+# Dependency injection
+# ==========================================
+async def get_current_rag(
+    x_tenant_id: str = Header("default", alias="X-Tenant-ID")
+) -> LightRAG:
+    """Resolve the LightRAG instance for the requesting tenant.
+
+    The tenant id is read from the ``X-Tenant-ID`` header and defaults to
+    ``"default"`` when the header is absent.
+
+    Raises:
+        HTTPException: 400 if the tenant id is rejected as invalid.
+    """
+    manager = get_rag_manager()
+    try:
+        return await manager.get_rag(x_tenant_id)
+    except InvalidTenantError as exc:
+        # The header is client-controlled, so a bad id is a client error.
+        raise HTTPException(status_code=400, detail=str(exc)) from exc
+
 # ==========================================
 # 数据模型
 # ==========================================
@@ -37,7 +58,10 @@ async def health_check():
     return {"status": "ok", "llm": settings.LLM_MODEL}
 
 @router.post("/query")
-async def query_knowledge_base(request: QueryRequest):
+async def query_knowledge_base(
+    request: QueryRequest,
+    rag: LightRAG = Depends(get_current_rag)
+):
     """
     查询接口
     - query: 用户问题
@@ -45,7 +69,6 @@ async def query_knowledge_base(request: QueryRequest):
     - stream: 是否流式输出 (默认 False)
     """
     try:
-        rag = get_rag()
         # 构造查询参数
         param = QueryParam(
             mode=request.mode,
@@ -129,10 +152,12 @@ async def query_knowledge_base(request: QueryRequest):
         raise HTTPException(status_code=500, detail=str(e))
 
 @router.post("/ingest/text")
-async def ingest_text(text: str = Body(..., embed=True)):
+async def ingest_text(
+    text: str = Body(..., embed=True),
+    rag: LightRAG = Depends(get_current_rag)
+):
     """直接摄入文本内容"""
     try:
-        rag = get_rag()
         # 使用异步方法 ainsert
         await rag.ainsert(text)
         return {"status": "success", "message": "Text ingested successfully"}
@@ -140,13 +165,15 @@ async def ingest_text(text: str = Body(..., embed=True)):
         raise HTTPException(status_code=500, detail=str(e))
 
 @router.post("/ingest/file", response_model=IngestResponse)
-async def upload_file(file: UploadFile = File(...)):
+async def upload_file(
+    file: UploadFile = File(...),
+    rag: LightRAG = Depends(get_current_rag)
+):
     """
     文件上传与索引接口
     支持 .txt, .md, .pdf
     """
     try:
-        rag = get_rag()
         content = ""
         filename = file.filename
 
@@ -191,14 +218,16 @@ async def upload_file(file: UploadFile = File(...)):
         raise HTTPException(status_code=500, detail=f"Processing failed: {str(e)}")
 
 @router.get("/documents")
-async def list_documents():
+async def list_documents(
+    rag: LightRAG = Depends(get_current_rag)
+):
     """
     获取文档列表
     返回当前知识库中已索引的所有文档及其状态
     """
     try:
-        # 直接读取 doc_status_storage 的底层文件
-        doc_status_path = os.path.join(settings.DATA_DIR, "kv_store_doc_status.json")
+        # Read the status file from this tenant's working directory
+        doc_status_path = os.path.join(rag.working_dir, "kv_store_doc_status.json")
 
         if not os.path.exists(doc_status_path):
             return {"count": 0, "docs": []}
@@ -225,13 +254,15 @@ async def list_documents():
         raise HTTPException(status_code=500, detail=str(e))
 
 @router.delete("/docs/{doc_id}")
-async def delete_document(doc_id: str):
+async def delete_document(
+    doc_id: str,
+    rag: LightRAG = Depends(get_current_rag)
+):
     """
     删除指定文档
     - doc_id: 文档ID (例如 doc-xxxxx)
     """
     try:
-        rag = get_rag()
         logging.info(f"正在删除文档: {doc_id}")
         # 调用 LightRAG 的删除方法
         await rag.adelete_by_doc_id(doc_id)
diff --git a/app/config.py b/app/config.py
index fb34df8..f2fa93b 100644
--- a/app/config.py
+++ b/app/config.py
@@ -24,6 +24,7 @@ class Settings(BaseSettings):
     # RAG Config
     EMBEDDING_DIM: int = 1024
     MAX_TOKEN_SIZE: int = 8192
+    MAX_RAG_INSTANCES: int = 3  # Maximum number of active RAG instances
 
     class Config:
         env_file = ".env"
diff --git a/app/core/rag.py b/app/core/rag.py
index e7890e5..8968f46 100644
--- a/app/core/rag.py
+++ b/app/core/rag.py
@@ -1,14 +1,19 @@
+import asyncio
 import logging
 import os
-import ollama
+import threading
+from collections import OrderedDict
+from typing import Optional
+
 import numpy as np
+import ollama
 from lightrag import LightRAG
 from lightrag.utils import EmbeddingFunc
 from lightrag.llm.ollama import ollama_embed
 from app.config import settings
 
-# 全局 RAG 实例
-rag = None
+# Global RAG manager (created by initialize_rag_manager)
+rag_manager: Optional["RAGManager"] = None
 
 async def llm_func(prompt, system_prompt=None, history_messages=[], **kwargs) -> str:
     """定义 LLM 函数:使用 Ollama 生成回复"""
@@ -59,40 +64,102 @@ async def embedding_func(texts: list[str]) -> np.ndarray:
         host=settings.EMBEDDING_BINDING_HOST
     )
 
-async def initialize_rag():
-    """初始化 LightRAG 实例"""
-    global rag
-
-    # 确保工作目录存在
-    if not os.path.exists(settings.DATA_DIR):
-        os.makedirs(settings.DATA_DIR)
-
-    print(f"正在初始化 LightRAG...")
-    print(f"LLM: {settings.LLM_MODEL} @ {settings.LLM_BINDING_HOST}")
-    print(f"Embedding: {settings.EMBEDDING_MODEL} @ {settings.EMBEDDING_BINDING_HOST}")
-
-    rag = LightRAG(
-        working_dir=settings.DATA_DIR,
-        llm_model_func=llm_func,
-        llm_model_name=settings.LLM_MODEL,
-        llm_model_max_async=1,
-        max_parallel_insert=1,
-        embedding_func=EmbeddingFunc(
-            embedding_dim=settings.EMBEDDING_DIM,
-            max_token_size=settings.MAX_TOKEN_SIZE,
-            func=embedding_func
-        ),
-        embedding_func_max_async=4,
-        enable_llm_cache=True
-    )
-
-    print("正在初始化 LightRAG 存储...")
-    await rag.initialize_storages()
-    print("LightRAG 存储初始化完成")
-    return rag
-
-def get_rag():
-    """获取全局 RAG 实例"""
-    if rag is None:
-        raise RuntimeError("RAG instance not initialized. Call initialize_rag() first.")
-    return rag
+class InvalidTenantError(ValueError):
+    """Raised when a tenant id is not usable as a storage directory name."""
+
+class RAGManager:
+    """LRU-bounded cache of per-tenant LightRAG instances."""
+
+    def __init__(self, capacity: int = 3):
+        self.capacity = capacity
+        self.cache = OrderedDict()
+        # Guards mutations of ``cache``; never held across an ``await``.
+        self.lock = threading.Lock()
+        # Serializes cold-start initialization so that concurrent requests
+        # for one tenant cannot build two instances on the same directory.
+        self._init_lock = asyncio.Lock()
+
+    @staticmethod
+    def _validate_user_id(user_id: str) -> None:
+        """Reject ids that are unsafe to use as a directory name.
+
+        The id is joined onto ``settings.DATA_DIR``, so values such as
+        ``..``, ``a/b`` or an empty string would escape or alias the
+        per-tenant directory.
+        """
+        if not user_id or not all(ch.isalnum() or ch in "-_" for ch in user_id):
+            raise InvalidTenantError(f"Invalid tenant id: {user_id!r}")
+
+    def _get_cached(self, user_id: str) -> Optional[LightRAG]:
+        """Return the cached instance, marking it most recently used."""
+        with self.lock:
+            rag = self.cache.get(user_id)
+            if rag is not None:
+                self.cache.move_to_end(user_id)
+            return rag
+
+    async def get_rag(self, user_id: str) -> LightRAG:
+        """Return the LightRAG instance for ``user_id``, creating it on a miss.
+
+        Raises:
+            InvalidTenantError: If ``user_id`` is not a safe directory name.
+        """
+        self._validate_user_id(user_id)
+
+        # 1. Fast path: cache hit.
+        rag = self._get_cached(user_id)
+        if rag is not None:
+            logging.debug(f"Cache hit for user: {user_id}")
+            return rag
+
+        async with self._init_lock:
+            # Another request may have initialized this tenant while we
+            # were waiting for the lock.
+            rag = self._get_cached(user_id)
+            if rag is not None:
+                return rag
+
+            # 2. Cache miss: build an instance rooted in the tenant directory.
+            logging.info(f"Initializing RAG instance for user: {user_id}")
+            user_data_dir = os.path.join(settings.DATA_DIR, user_id)
+            os.makedirs(user_data_dir, exist_ok=True)
+
+            rag = LightRAG(
+                working_dir=user_data_dir,
+                llm_model_func=llm_func,
+                llm_model_name=settings.LLM_MODEL,
+                llm_model_max_async=1,
+                max_parallel_insert=1,
+                embedding_func=EmbeddingFunc(
+                    embedding_dim=settings.EMBEDDING_DIM,
+                    max_token_size=settings.MAX_TOKEN_SIZE,
+                    func=embedding_func
+                ),
+                embedding_func_max_async=4,
+                enable_llm_cache=True
+            )
+            await rag.initialize_storages()
+
+            # 3. Publish the instance and evict least recently used entries.
+            with self.lock:
+                self.cache[user_id] = rag
+                self.cache.move_to_end(user_id)
+                while len(self.cache) > self.capacity:
+                    oldest_user, _ = self.cache.popitem(last=False)
+                    logging.info(f"Evicting RAG instance for user: {oldest_user} due to capacity limit")
+                    # Relies on Python GC to reclaim the evicted instance.
+
+        return rag
+
+def initialize_rag_manager() -> RAGManager:
+    """Create the global RAG manager sized by ``settings.MAX_RAG_INSTANCES``."""
+    global rag_manager
+    rag_manager = RAGManager(capacity=settings.MAX_RAG_INSTANCES)
+    logging.info(f"RAG Manager initialized with capacity: {settings.MAX_RAG_INSTANCES}")
+    return rag_manager
+
+def get_rag_manager() -> RAGManager:
+    """Return the global RAG manager, creating it on first use."""
+    if rag_manager is None:
+        return initialize_rag_manager()
+    return rag_manager
diff --git a/app/main.py b/app/main.py
index 756d269..4a48978 100644
--- a/app/main.py
+++ b/app/main.py
@@ -2,7 +2,7 @@ import logging
 from fastapi import FastAPI
 from contextlib import asynccontextmanager
 from app.config import settings
-from app.core.rag import initialize_rag
+from app.core.rag import initialize_rag_manager
 from app.core.prompts import patch_prompts
 from app.api.routes import router
 
@@ -15,8 +15,8 @@ async def lifespan(app: FastAPI):
     # 1. Patch Prompts
     patch_prompts()
 
-    # 2. Init RAG
-    await initialize_rag()
+    # 2. Init RAG Manager
+    initialize_rag_manager()
 
     yield
 