feat: 增加 rag_manager 多租户

This commit is contained in:
fuzhongyun 2026-01-12 14:32:51 +08:00
parent 08dee51b9f
commit 0ce9215ee3
4 changed files with 118 additions and 54 deletions

View File

@ -2,17 +2,30 @@ import json
import logging import logging
import os import os
import io import io
from fastapi import APIRouter, UploadFile, File, HTTPException, Body from fastapi import APIRouter, UploadFile, File, HTTPException, Body, Header, Depends
from fastapi.responses import StreamingResponse from fastapi.responses import StreamingResponse
from pydantic import BaseModel from pydantic import BaseModel
from pypdf import PdfReader from pypdf import PdfReader
from lightrag import QueryParam from lightrag import LightRAG, QueryParam
from app.core.rag import get_rag, llm_func from app.core.rag import get_rag_manager, llm_func
from app.core.prompts import CUSTOM_RAG_RESPONSE_PROMPT from app.core.prompts import CUSTOM_RAG_RESPONSE_PROMPT
from app.config import settings from app.config import settings
router = APIRouter() router = APIRouter()
# ==========================================
# 依赖注入
# ==========================================
async def get_current_rag(
    x_tenant_id: str = Header("default", alias="X-Tenant-ID")
) -> LightRAG:
    """FastAPI dependency resolving the LightRAG instance for the caller's tenant.

    The tenant id is taken from the ``X-Tenant-ID`` request header and falls
    back to ``'default'`` when the header is absent; the global RAG manager
    then supplies (or lazily creates) that tenant's instance.
    """
    return await get_rag_manager().get_rag(x_tenant_id)
# ========================================== # ==========================================
# 数据模型 # 数据模型
# ========================================== # ==========================================
@ -37,7 +50,10 @@ async def health_check():
return {"status": "ok", "llm": settings.LLM_MODEL} return {"status": "ok", "llm": settings.LLM_MODEL}
@router.post("/query") @router.post("/query")
async def query_knowledge_base(request: QueryRequest): async def query_knowledge_base(
request: QueryRequest,
rag: LightRAG = Depends(get_current_rag)
):
""" """
查询接口 查询接口
- query: 用户问题 - query: 用户问题
@ -45,7 +61,6 @@ async def query_knowledge_base(request: QueryRequest):
- stream: 是否流式输出 (默认 False) - stream: 是否流式输出 (默认 False)
""" """
try: try:
rag = get_rag()
# 构造查询参数 # 构造查询参数
param = QueryParam( param = QueryParam(
mode=request.mode, mode=request.mode,
@ -129,10 +144,12 @@ async def query_knowledge_base(request: QueryRequest):
raise HTTPException(status_code=500, detail=str(e)) raise HTTPException(status_code=500, detail=str(e))
@router.post("/ingest/text")
async def ingest_text(
    text: str = Body(..., embed=True),
    rag: LightRAG = Depends(get_current_rag)
):
    """Ingest raw text into the current tenant's knowledge base."""
    try:
        # ainsert is LightRAG's async insertion entry point.
        await rag.ainsert(text)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
    return {"status": "success", "message": "Text ingested successfully"}
@router.post("/ingest/file", response_model=IngestResponse) @router.post("/ingest/file", response_model=IngestResponse)
async def upload_file(file: UploadFile = File(...)): async def upload_file(
file: UploadFile = File(...),
rag: LightRAG = Depends(get_current_rag)
):
""" """
文件上传与索引接口 文件上传与索引接口
支持 .txt, .md, .pdf 支持 .txt, .md, .pdf
""" """
try: try:
rag = get_rag()
content = "" content = ""
filename = file.filename filename = file.filename
@ -191,14 +210,16 @@ async def upload_file(file: UploadFile = File(...)):
raise HTTPException(status_code=500, detail=f"Processing failed: {str(e)}") raise HTTPException(status_code=500, detail=f"Processing failed: {str(e)}")
@router.get("/documents") @router.get("/documents")
async def list_documents(): async def list_documents(
rag: LightRAG = Depends(get_current_rag)
):
""" """
获取文档列表 获取文档列表
返回当前知识库中已索引的所有文档及其状态 返回当前知识库中已索引的所有文档及其状态
""" """
try: try:
# 直接读取 doc_status_storage 的底层文件 # 从 rag 实例获取工作目录
doc_status_path = os.path.join(settings.DATA_DIR, "kv_store_doc_status.json") doc_status_path = os.path.join(rag.working_dir, "kv_store_doc_status.json")
if not os.path.exists(doc_status_path): if not os.path.exists(doc_status_path):
return {"count": 0, "docs": []} return {"count": 0, "docs": []}
@ -225,13 +246,15 @@ async def list_documents():
raise HTTPException(status_code=500, detail=str(e)) raise HTTPException(status_code=500, detail=str(e))
@router.delete("/docs/{doc_id}") @router.delete("/docs/{doc_id}")
async def delete_document(doc_id: str): async def delete_document(
doc_id: str,
rag: LightRAG = Depends(get_current_rag)
):
""" """
删除指定文档 删除指定文档
- doc_id: 文档ID (例如 doc-xxxxx) - doc_id: 文档ID (例如 doc-xxxxx)
""" """
try: try:
rag = get_rag()
logging.info(f"正在删除文档: {doc_id}") logging.info(f"正在删除文档: {doc_id}")
# 调用 LightRAG 的删除方法 # 调用 LightRAG 的删除方法
await rag.adelete_by_doc_id(doc_id) await rag.adelete_by_doc_id(doc_id)

View File

@ -24,6 +24,7 @@ class Settings(BaseSettings):
# RAG Config # RAG Config
EMBEDDING_DIM: int = 1024 EMBEDDING_DIM: int = 1024
MAX_TOKEN_SIZE: int = 8192 MAX_TOKEN_SIZE: int = 8192
MAX_RAG_INSTANCES: int = 3 # 最大活跃 RAG 实例数
class Config: class Config:
env_file = ".env" env_file = ".env"

View File

@ -1,14 +1,18 @@
import logging import logging
import os import os
import ollama import threading
from collections import OrderedDict
from typing import Optional
import numpy as np import numpy as np
import ollama
from lightrag import LightRAG from lightrag import LightRAG
from lightrag.utils import EmbeddingFunc from lightrag.utils import EmbeddingFunc
from lightrag.llm.ollama import ollama_embed from lightrag.llm.ollama import ollama_embed
from app.config import settings from app.config import settings
# Process-wide RAG manager singleton; created lazily by initialize_rag_manager().
rag_manager = None
async def llm_func(prompt, system_prompt=None, history_messages=[], **kwargs) -> str: async def llm_func(prompt, system_prompt=None, history_messages=[], **kwargs) -> str:
"""定义 LLM 函数:使用 Ollama 生成回复""" """定义 LLM 函数:使用 Ollama 生成回复"""
@ -59,40 +63,76 @@ async def embedding_func(texts: list[str]) -> np.ndarray:
host=settings.EMBEDDING_BINDING_HOST host=settings.EMBEDDING_BINDING_HOST
) )
class RAGManager:
    """LRU cache of per-tenant LightRAG instances.

    Each tenant gets its own LightRAG instance rooted at a private working
    directory under ``settings.DATA_DIR``; at most ``capacity`` instances are
    kept alive at once, with the least-recently-used one evicted first.
    """

    def __init__(self, capacity: int = 3):
        self.capacity = capacity
        # user_id -> LightRAG, ordered oldest-first so popitem(last=False)
        # removes the least-recently-used entry.
        self.cache: "OrderedDict[str, LightRAG]" = OrderedDict()
        # asyncio.Lock rather than threading.Lock: the handlers all run on
        # FastAPI's event loop, and the lock must stay held across the
        # awaited storage initialization below — a threading.Lock cannot be
        # held across an await without blocking the loop, and releasing it
        # re-introduces the duplicate-initialization race.
        self.lock = asyncio.Lock()

    @staticmethod
    def _safe_tenant_id(user_id: str) -> str:
        """Validate a tenant id before it is joined into a filesystem path.

        The id comes straight from the ``X-Tenant-ID`` header, so reject
        anything that could escape DATA_DIR: empty strings, path separators,
        ``..`` — only alphanumerics, ``_`` and ``-`` are allowed.

        Raises:
            ValueError: if the id contains any disallowed character.
        """
        if not user_id or not all(c.isalnum() or c in "_-" for c in user_id):
            raise ValueError(f"Invalid tenant id: {user_id!r}")
        return user_id

    async def get_rag(self, user_id: str) -> LightRAG:
        """Return the LightRAG instance for ``user_id``, creating it on demand."""
        user_id = self._safe_tenant_id(user_id)
        async with self.lock:
            # Cache hit: refresh the entry's LRU position and return it.
            if user_id in self.cache:
                self.cache.move_to_end(user_id)
                logging.debug(f"Cache hit for user: {user_id}")
                return self.cache[user_id]

            # Cache miss: build the instance while still holding the lock so
            # two concurrent requests for the same cold tenant can never
            # initialize storage on the same working_dir simultaneously.
            logging.info(f"Initializing RAG instance for user: {user_id}")
            user_data_dir = os.path.join(settings.DATA_DIR, user_id)
            # exist_ok avoids the exists()/makedirs() TOCTOU race.
            os.makedirs(user_data_dir, exist_ok=True)

            # working_dir points at the per-tenant subdirectory: strict data
            # isolation between tenants.
            rag = LightRAG(
                working_dir=user_data_dir,
                llm_model_func=llm_func,
                llm_model_name=settings.LLM_MODEL,
                llm_model_max_async=1,
                max_parallel_insert=1,
                embedding_func=EmbeddingFunc(
                    embedding_dim=settings.EMBEDDING_DIM,
                    max_token_size=settings.MAX_TOKEN_SIZE,
                    func=embedding_func
                ),
                embedding_func_max_async=4,
                enable_llm_cache=True
            )
            await rag.initialize_storages()

            self.cache[user_id] = rag
            # Enforce the capacity bound; evicted instances are reclaimed by
            # the garbage collector.
            while len(self.cache) > self.capacity:
                oldest_user, _ = self.cache.popitem(last=False)
                logging.info(f"Evicting RAG instance for user: {oldest_user} due to capacity limit")
            return rag
def initialize_rag_manager():
    """Build and register the process-wide RAGManager singleton."""
    global rag_manager
    manager = RAGManager(capacity=settings.MAX_RAG_INSTANCES)
    rag_manager = manager
    logging.info(f"RAG Manager initialized with capacity: {settings.MAX_RAG_INSTANCES}")
    return manager
def get_rag_manager() -> RAGManager:
    """Return the global RAG manager, lazily creating it on first use."""
    return rag_manager if rag_manager is not None else initialize_rag_manager()

View File

@ -2,7 +2,7 @@ import logging
from fastapi import FastAPI from fastapi import FastAPI
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from app.config import settings from app.config import settings
from app.core.rag import initialize_rag from app.core.rag import initialize_rag_manager
from app.core.prompts import patch_prompts from app.core.prompts import patch_prompts
from app.api.routes import router from app.api.routes import router
@ -15,8 +15,8 @@ async def lifespan(app: FastAPI):
# 1. Patch Prompts # 1. Patch Prompts
patch_prompts() patch_prompts()
# 2. Init RAG # 2. Init RAG Manager
await initialize_rag() initialize_rag_manager()
yield yield