feat: 增加 rag_manager 多租户

This commit is contained in:
fuzhongyun 2026-01-12 14:32:51 +08:00
parent 08dee51b9f
commit 0ce9215ee3
4 changed files with 118 additions and 54 deletions

View File

@ -2,17 +2,30 @@ import json
import logging
import os
import io
from fastapi import APIRouter, UploadFile, File, HTTPException, Body
from fastapi import APIRouter, UploadFile, File, HTTPException, Body, Header, Depends
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from pypdf import PdfReader
from lightrag import QueryParam
from app.core.rag import get_rag, llm_func
from lightrag import LightRAG, QueryParam
from app.core.rag import get_rag_manager, llm_func
from app.core.prompts import CUSTOM_RAG_RESPONSE_PROMPT
from app.config import settings
router = APIRouter()
# ==========================================
# 依赖注入
# ==========================================
async def get_current_rag(
    x_tenant_id: str = Header("default", alias="X-Tenant-ID")
) -> LightRAG:
    """
    Dependency that resolves the LightRAG instance for the current tenant.

    Reads the tenant id from the ``X-Tenant-ID`` header (defaults to
    ``'default'``) and asks the global RAGManager for that tenant's
    dedicated LightRAG instance.

    Raises:
        HTTPException: 400 when the tenant id is empty or contains
            characters unsafe for use as a directory name — the manager
            uses the id as a path component of the tenant's working dir,
            so an unchecked value like ``../other`` could escape DATA_DIR.
    """
    tenant_id = x_tenant_id.strip()
    # Security: allow only alphanumerics, dash and underscore, since the
    # tenant id is later joined onto the data directory path.
    if not tenant_id or not all(ch.isalnum() or ch in "-_" for ch in tenant_id):
        raise HTTPException(status_code=400, detail="Invalid X-Tenant-ID header")
    manager = get_rag_manager()
    return await manager.get_rag(tenant_id)
# ==========================================
# 数据模型
# ==========================================
@ -37,7 +50,10 @@ async def health_check():
return {"status": "ok", "llm": settings.LLM_MODEL}
@router.post("/query")
async def query_knowledge_base(request: QueryRequest):
async def query_knowledge_base(
request: QueryRequest,
rag: LightRAG = Depends(get_current_rag)
):
"""
查询接口
- query: 用户问题
@ -45,7 +61,6 @@ async def query_knowledge_base(request: QueryRequest):
- stream: 是否流式输出 (默认 False)
"""
try:
rag = get_rag()
# 构造查询参数
param = QueryParam(
mode=request.mode,
@ -129,10 +144,12 @@ async def query_knowledge_base(request: QueryRequest):
raise HTTPException(status_code=500, detail=str(e))
@router.post("/ingest/text")
async def ingest_text(text: str = Body(..., embed=True)):
async def ingest_text(
text: str = Body(..., embed=True),
rag: LightRAG = Depends(get_current_rag)
):
"""直接摄入文本内容"""
try:
rag = get_rag()
# 使用异步方法 ainsert
await rag.ainsert(text)
return {"status": "success", "message": "Text ingested successfully"}
@ -140,13 +157,15 @@ async def ingest_text(text: str = Body(..., embed=True)):
raise HTTPException(status_code=500, detail=str(e))
@router.post("/ingest/file", response_model=IngestResponse)
async def upload_file(file: UploadFile = File(...)):
async def upload_file(
file: UploadFile = File(...),
rag: LightRAG = Depends(get_current_rag)
):
"""
文件上传与索引接口
支持 .txt, .md, .pdf
"""
try:
rag = get_rag()
content = ""
filename = file.filename
@ -191,14 +210,16 @@ async def upload_file(file: UploadFile = File(...)):
raise HTTPException(status_code=500, detail=f"Processing failed: {str(e)}")
@router.get("/documents")
async def list_documents():
async def list_documents(
rag: LightRAG = Depends(get_current_rag)
):
"""
获取文档列表
返回当前知识库中已索引的所有文档及其状态
"""
try:
# 直接读取 doc_status_storage 的底层文件
doc_status_path = os.path.join(settings.DATA_DIR, "kv_store_doc_status.json")
# 从 rag 实例获取工作目录
doc_status_path = os.path.join(rag.working_dir, "kv_store_doc_status.json")
if not os.path.exists(doc_status_path):
return {"count": 0, "docs": []}
@ -225,13 +246,15 @@ async def list_documents():
raise HTTPException(status_code=500, detail=str(e))
@router.delete("/docs/{doc_id}")
async def delete_document(doc_id: str):
async def delete_document(
doc_id: str,
rag: LightRAG = Depends(get_current_rag)
):
"""
删除指定文档
- doc_id: 文档ID (例如 doc-xxxxx)
"""
try:
rag = get_rag()
logging.info(f"正在删除文档: {doc_id}")
# 调用 LightRAG 的删除方法
await rag.adelete_by_doc_id(doc_id)

View File

@ -24,6 +24,7 @@ class Settings(BaseSettings):
# RAG Config
EMBEDDING_DIM: int = 1024
MAX_TOKEN_SIZE: int = 8192
MAX_RAG_INSTANCES: int = 3 # 最大活跃 RAG 实例数
class Config:
env_file = ".env"

View File

@ -1,14 +1,18 @@
import logging
import os
import ollama
import threading
from collections import OrderedDict
from typing import Optional
import numpy as np
import ollama
from lightrag import LightRAG
from lightrag.utils import EmbeddingFunc
from lightrag.llm.ollama import ollama_embed
from app.config import settings
# 全局 RAG 实例
rag = None
# 全局 RAG 管理器
rag_manager = None
async def llm_func(prompt, system_prompt=None, history_messages=[], **kwargs) -> str:
"""定义 LLM 函数:使用 Ollama 生成回复"""
@ -59,40 +63,76 @@ async def embedding_func(texts: list[str]) -> np.ndarray:
host=settings.EMBEDDING_BINDING_HOST
)
async def initialize_rag():
"""初始化 LightRAG 实例"""
global rag
class RAGManager:
    """LRU cache of per-tenant LightRAG instances.

    Holds at most ``capacity`` live instances; the least-recently-used one
    is evicted when the limit is exceeded. Each tenant works in its own
    subdirectory of ``settings.DATA_DIR`` so tenant data stays isolated.
    """

    def __init__(self, capacity: int = 3):
        self.capacity = capacity
        # user_id -> LightRAG, ordered by recency (last item = most recent).
        self.cache = OrderedDict()
        # Protects cache bookkeeping only; never held across an await.
        self.lock = threading.Lock()
        # Ensure the shared data root exists (exist_ok avoids the
        # check-then-create TOCTOU race of exists()+makedirs()).
        os.makedirs(settings.DATA_DIR, exist_ok=True)

    async def get_rag(self, user_id: str) -> LightRAG:
        """Return the LightRAG instance for *user_id*, creating it on demand.

        Raises:
            ValueError: if *user_id* is unsafe to use as a directory name
                (it becomes a path component under DATA_DIR).
        """
        # 1. Fast path: cache hit.
        with self.lock:
            if user_id in self.cache:
                self.cache.move_to_end(user_id)
                logging.debug(f"Cache hit for user: {user_id}")
                return self.cache[user_id]

        # Security: user_id is used to build a filesystem path below, so
        # reject anything that could escape DATA_DIR (e.g. "../x", "a/b").
        if not user_id or not all(ch.isalnum() or ch in "-_" for ch in user_id):
            raise ValueError(f"Unsafe tenant id: {user_id!r}")

        # 2. Cache miss: build a fresh instance outside the lock, because
        # storage initialization awaits and must not block other requests.
        logging.info(f"Initializing RAG instance for user: {user_id}")
        user_data_dir = os.path.join(settings.DATA_DIR, user_id)
        # Strict isolation: each tenant gets its own working directory.
        os.makedirs(user_data_dir, exist_ok=True)

        rag = LightRAG(
            working_dir=user_data_dir,
            llm_model_func=llm_func,
            llm_model_name=settings.LLM_MODEL,
            llm_model_max_async=1,
            max_parallel_insert=1,
            embedding_func=EmbeddingFunc(
                embedding_dim=settings.EMBEDDING_DIM,
                max_token_size=settings.MAX_TOKEN_SIZE,
                func=embedding_func
            ),
            embedding_func_max_async=4,
            enable_llm_cache=True
        )
        # Async storage initialization (must complete before first use).
        await rag.initialize_storages()

        # 3. Publish to the cache, handling the race where another coroutine
        # initialized the same tenant while we were awaiting above.
        with self.lock:
            if user_id in self.cache:
                # Lost the race: keep the already-cached instance.
                # NOTE(review): the instance built here is simply dropped and
                # left to GC — confirm LightRAG needs no explicit storage
                # shutdown for a discarded instance.
                self.cache.move_to_end(user_id)
                return self.cache[user_id]
            # Inserting into an OrderedDict already places the key at the
            # most-recent end; no extra move_to_end needed.
            self.cache[user_id] = rag
            # Evict least-recently-used entries beyond capacity.
            while len(self.cache) > self.capacity:
                oldest_user, _ = self.cache.popitem(last=False)
                logging.info(f"Evicting RAG instance for user: {oldest_user} due to capacity limit")
                # Relies on Python GC to reclaim the evicted instance's resources.
        return rag
def initialize_rag_manager():
    """Create the process-wide RAG manager and return it.

    Capacity comes from ``settings.MAX_RAG_INSTANCES``; the instance is
    stored in the module-level ``rag_manager`` global.
    """
    global rag_manager
    manager = RAGManager(capacity=settings.MAX_RAG_INSTANCES)
    rag_manager = manager
    logging.info(f"RAG Manager initialized with capacity: {settings.MAX_RAG_INSTANCES}")
    return manager
def get_rag_manager() -> RAGManager:
    """Return the global RAG manager, lazily creating it on first access."""
    return rag_manager if rag_manager is not None else initialize_rag_manager()

View File

@ -2,7 +2,7 @@ import logging
from fastapi import FastAPI
from contextlib import asynccontextmanager
from app.config import settings
from app.core.rag import initialize_rag
from app.core.rag import initialize_rag_manager
from app.core.prompts import patch_prompts
from app.api.routes import router
@ -15,8 +15,8 @@ async def lifespan(app: FastAPI):
# 1. Patch Prompts
patch_prompts()
# 2. Init RAG
await initialize_rag()
# 2. Init RAG Manager
initialize_rag_manager()
yield