feat: 增加 rag_manager 多租户
This commit is contained in:
parent
08dee51b9f
commit
0ce9215ee3
|
|
@ -2,17 +2,30 @@ import json
|
|||
import logging
|
||||
import os
|
||||
import io
|
||||
from fastapi import APIRouter, UploadFile, File, HTTPException, Body
|
||||
from fastapi import APIRouter, UploadFile, File, HTTPException, Body, Header, Depends
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel
|
||||
from pypdf import PdfReader
|
||||
from lightrag import QueryParam
|
||||
from app.core.rag import get_rag, llm_func
|
||||
from lightrag import LightRAG, QueryParam
|
||||
from app.core.rag import get_rag_manager, llm_func
|
||||
from app.core.prompts import CUSTOM_RAG_RESPONSE_PROMPT
|
||||
from app.config import settings
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
# ==========================================
|
||||
# 依赖注入
|
||||
# ==========================================
|
||||
async def get_current_rag(
|
||||
x_tenant_id: str = Header("default", alias="X-Tenant-ID")
|
||||
) -> LightRAG:
|
||||
"""
|
||||
依赖项:获取当前租户的 LightRAG 实例
|
||||
从 Header 中读取 X-Tenant-ID,默认为 'default'
|
||||
"""
|
||||
manager = get_rag_manager()
|
||||
return await manager.get_rag(x_tenant_id)
|
||||
|
||||
# ==========================================
|
||||
# 数据模型
|
||||
# ==========================================
|
||||
|
|
@ -37,7 +50,10 @@ async def health_check():
|
|||
return {"status": "ok", "llm": settings.LLM_MODEL}
|
||||
|
||||
@router.post("/query")
|
||||
async def query_knowledge_base(request: QueryRequest):
|
||||
async def query_knowledge_base(
|
||||
request: QueryRequest,
|
||||
rag: LightRAG = Depends(get_current_rag)
|
||||
):
|
||||
"""
|
||||
查询接口
|
||||
- query: 用户问题
|
||||
|
|
@ -45,7 +61,6 @@ async def query_knowledge_base(request: QueryRequest):
|
|||
- stream: 是否流式输出 (默认 False)
|
||||
"""
|
||||
try:
|
||||
rag = get_rag()
|
||||
# 构造查询参数
|
||||
param = QueryParam(
|
||||
mode=request.mode,
|
||||
|
|
@ -129,10 +144,12 @@ async def query_knowledge_base(request: QueryRequest):
|
|||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@router.post("/ingest/text")
|
||||
async def ingest_text(text: str = Body(..., embed=True)):
|
||||
async def ingest_text(
|
||||
text: str = Body(..., embed=True),
|
||||
rag: LightRAG = Depends(get_current_rag)
|
||||
):
|
||||
"""直接摄入文本内容"""
|
||||
try:
|
||||
rag = get_rag()
|
||||
# 使用异步方法 ainsert
|
||||
await rag.ainsert(text)
|
||||
return {"status": "success", "message": "Text ingested successfully"}
|
||||
|
|
@ -140,13 +157,15 @@ async def ingest_text(text: str = Body(..., embed=True)):
|
|||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@router.post("/ingest/file", response_model=IngestResponse)
|
||||
async def upload_file(file: UploadFile = File(...)):
|
||||
async def upload_file(
|
||||
file: UploadFile = File(...),
|
||||
rag: LightRAG = Depends(get_current_rag)
|
||||
):
|
||||
"""
|
||||
文件上传与索引接口
|
||||
支持 .txt, .md, .pdf
|
||||
"""
|
||||
try:
|
||||
rag = get_rag()
|
||||
content = ""
|
||||
filename = file.filename
|
||||
|
||||
|
|
@ -191,14 +210,16 @@ async def upload_file(file: UploadFile = File(...)):
|
|||
raise HTTPException(status_code=500, detail=f"Processing failed: {str(e)}")
|
||||
|
||||
@router.get("/documents")
|
||||
async def list_documents():
|
||||
async def list_documents(
|
||||
rag: LightRAG = Depends(get_current_rag)
|
||||
):
|
||||
"""
|
||||
获取文档列表
|
||||
返回当前知识库中已索引的所有文档及其状态
|
||||
"""
|
||||
try:
|
||||
# 直接读取 doc_status_storage 的底层文件
|
||||
doc_status_path = os.path.join(settings.DATA_DIR, "kv_store_doc_status.json")
|
||||
# 从 rag 实例获取工作目录
|
||||
doc_status_path = os.path.join(rag.working_dir, "kv_store_doc_status.json")
|
||||
if not os.path.exists(doc_status_path):
|
||||
return {"count": 0, "docs": []}
|
||||
|
||||
|
|
@ -225,13 +246,15 @@ async def list_documents():
|
|||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@router.delete("/docs/{doc_id}")
|
||||
async def delete_document(doc_id: str):
|
||||
async def delete_document(
|
||||
doc_id: str,
|
||||
rag: LightRAG = Depends(get_current_rag)
|
||||
):
|
||||
"""
|
||||
删除指定文档
|
||||
- doc_id: 文档ID (例如 doc-xxxxx)
|
||||
"""
|
||||
try:
|
||||
rag = get_rag()
|
||||
logging.info(f"正在删除文档: {doc_id}")
|
||||
# 调用 LightRAG 的删除方法
|
||||
await rag.adelete_by_doc_id(doc_id)
|
||||
|
|
|
|||
|
|
@ -24,6 +24,7 @@ class Settings(BaseSettings):
|
|||
# RAG Config
|
||||
EMBEDDING_DIM: int = 1024
|
||||
MAX_TOKEN_SIZE: int = 8192
|
||||
MAX_RAG_INSTANCES: int = 3 # 最大活跃 RAG 实例数
|
||||
|
||||
class Config:
|
||||
env_file = ".env"
|
||||
|
|
|
|||
110
app/core/rag.py
110
app/core/rag.py
|
|
@ -1,14 +1,18 @@
|
|||
import logging
|
||||
import os
|
||||
import ollama
|
||||
import threading
|
||||
from collections import OrderedDict
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import ollama
|
||||
from lightrag import LightRAG
|
||||
from lightrag.utils import EmbeddingFunc
|
||||
from lightrag.llm.ollama import ollama_embed
|
||||
from app.config import settings
|
||||
|
||||
# 全局 RAG 实例
|
||||
rag = None
|
||||
# 全局 RAG 管理器
|
||||
rag_manager = None
|
||||
|
||||
async def llm_func(prompt, system_prompt=None, history_messages=[], **kwargs) -> str:
|
||||
"""定义 LLM 函数:使用 Ollama 生成回复"""
|
||||
|
|
@ -59,40 +63,76 @@ async def embedding_func(texts: list[str]) -> np.ndarray:
|
|||
host=settings.EMBEDDING_BINDING_HOST
|
||||
)
|
||||
|
||||
async def initialize_rag():
|
||||
"""初始化 LightRAG 实例"""
|
||||
global rag
|
||||
class RAGManager:
|
||||
def __init__(self, capacity: int = 3):
|
||||
self.capacity = capacity
|
||||
self.cache = OrderedDict()
|
||||
self.lock = threading.Lock()
|
||||
|
||||
# 确保工作目录存在
|
||||
if not os.path.exists(settings.DATA_DIR):
|
||||
os.makedirs(settings.DATA_DIR)
|
||||
async def get_rag(self, user_id: str) -> LightRAG:
|
||||
"""获取指定用户的 LightRAG 实例 (LRU 缓存)"""
|
||||
# 1. 尝试从缓存获取
|
||||
with self.lock:
|
||||
if user_id in self.cache:
|
||||
self.cache.move_to_end(user_id)
|
||||
logging.debug(f"Cache hit for user: {user_id}")
|
||||
return self.cache[user_id]
|
||||
|
||||
print(f"正在初始化 LightRAG...")
|
||||
print(f"LLM: {settings.LLM_MODEL} @ {settings.LLM_BINDING_HOST}")
|
||||
print(f"Embedding: {settings.EMBEDDING_MODEL} @ {settings.EMBEDDING_BINDING_HOST}")
|
||||
# 2. 缓存未命中,需要初始化
|
||||
logging.info(f"Initializing RAG instance for user: {user_id}")
|
||||
user_data_dir = os.path.join(settings.DATA_DIR, user_id)
|
||||
|
||||
rag = LightRAG(
|
||||
working_dir=settings.DATA_DIR,
|
||||
llm_model_func=llm_func,
|
||||
llm_model_name=settings.LLM_MODEL,
|
||||
llm_model_max_async=1,
|
||||
max_parallel_insert=1,
|
||||
embedding_func=EmbeddingFunc(
|
||||
embedding_dim=settings.EMBEDDING_DIM,
|
||||
max_token_size=settings.MAX_TOKEN_SIZE,
|
||||
func=embedding_func
|
||||
),
|
||||
embedding_func_max_async=4,
|
||||
enable_llm_cache=True
|
||||
)
|
||||
# 严格隔离:确保工作目录绝对干净,不应包含上级目录或其他用户数据
|
||||
if not os.path.exists(user_data_dir):
|
||||
os.makedirs(user_data_dir)
|
||||
|
||||
print("正在初始化 LightRAG 存储...")
|
||||
await rag.initialize_storages()
|
||||
print("LightRAG 存储初始化完成")
|
||||
return rag
|
||||
# 实例化 LightRAG
|
||||
# 关键修复:确保 working_dir 严格指向用户子目录
|
||||
rag = LightRAG(
|
||||
working_dir=user_data_dir,
|
||||
llm_model_func=llm_func,
|
||||
llm_model_name=settings.LLM_MODEL,
|
||||
llm_model_max_async=1,
|
||||
max_parallel_insert=1,
|
||||
embedding_func=EmbeddingFunc(
|
||||
embedding_dim=settings.EMBEDDING_DIM,
|
||||
max_token_size=settings.MAX_TOKEN_SIZE,
|
||||
func=embedding_func
|
||||
),
|
||||
embedding_func_max_async=4,
|
||||
enable_llm_cache=True
|
||||
)
|
||||
|
||||
def get_rag():
|
||||
"""获取全局 RAG 实例"""
|
||||
if rag is None:
|
||||
raise RuntimeError("RAG instance not initialized. Call initialize_rag() first.")
|
||||
return rag
|
||||
# 异步初始化存储
|
||||
await rag.initialize_storages()
|
||||
|
||||
# 3. 放入缓存并处理驱逐
|
||||
with self.lock:
|
||||
# 双重检查,防止在初始化期间其他线程已经创建了
|
||||
if user_id in self.cache:
|
||||
self.cache.move_to_end(user_id)
|
||||
return self.cache[user_id]
|
||||
|
||||
self.cache[user_id] = rag
|
||||
self.cache.move_to_end(user_id)
|
||||
|
||||
# 检查容量
|
||||
while len(self.cache) > self.capacity:
|
||||
oldest_user, _ = self.cache.popitem(last=False)
|
||||
logging.info(f"Evicting RAG instance for user: {oldest_user} due to capacity limit")
|
||||
# 这里依赖 Python GC 回收资源
|
||||
|
||||
return rag
|
||||
|
||||
def initialize_rag_manager():
|
||||
"""初始化全局 RAG 管理器"""
|
||||
global rag_manager
|
||||
rag_manager = RAGManager(capacity=settings.MAX_RAG_INSTANCES)
|
||||
logging.info(f"RAG Manager initialized with capacity: {settings.MAX_RAG_INSTANCES}")
|
||||
return rag_manager
|
||||
|
||||
def get_rag_manager() -> RAGManager:
|
||||
"""获取全局 RAG 管理器"""
|
||||
if rag_manager is None:
|
||||
return initialize_rag_manager()
|
||||
return rag_manager
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ import logging
|
|||
from fastapi import FastAPI
|
||||
from contextlib import asynccontextmanager
|
||||
from app.config import settings
|
||||
from app.core.rag import initialize_rag
|
||||
from app.core.rag import initialize_rag_manager
|
||||
from app.core.prompts import patch_prompts
|
||||
from app.api.routes import router
|
||||
|
||||
|
|
@ -15,8 +15,8 @@ async def lifespan(app: FastAPI):
|
|||
# 1. Patch Prompts
|
||||
patch_prompts()
|
||||
|
||||
# 2. Init RAG
|
||||
await initialize_rag()
|
||||
# 2. Init RAG Manager
|
||||
initialize_rag_manager()
|
||||
|
||||
yield
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue