feat: 增加 rag_manager 多租户
This commit is contained in:
parent
08dee51b9f
commit
0ce9215ee3
|
|
@ -2,17 +2,30 @@ import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import io
|
import io
|
||||||
from fastapi import APIRouter, UploadFile, File, HTTPException, Body
|
from fastapi import APIRouter, UploadFile, File, HTTPException, Body, Header, Depends
|
||||||
from fastapi.responses import StreamingResponse
|
from fastapi.responses import StreamingResponse
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from pypdf import PdfReader
|
from pypdf import PdfReader
|
||||||
from lightrag import QueryParam
|
from lightrag import LightRAG, QueryParam
|
||||||
from app.core.rag import get_rag, llm_func
|
from app.core.rag import get_rag_manager, llm_func
|
||||||
from app.core.prompts import CUSTOM_RAG_RESPONSE_PROMPT
|
from app.core.prompts import CUSTOM_RAG_RESPONSE_PROMPT
|
||||||
from app.config import settings
|
from app.config import settings
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
|
# ==========================================
|
||||||
|
# 依赖注入
|
||||||
|
# ==========================================
|
||||||
|
async def get_current_rag(
    x_tenant_id: str = Header("default", alias="X-Tenant-ID")
) -> LightRAG:
    """
    FastAPI dependency: resolve the LightRAG instance of the calling tenant.

    The tenant id is read from the ``X-Tenant-ID`` request header and
    defaults to ``"default"`` when the header is absent.

    Raises:
        HTTPException: 400 when the header value is not a single, plain
            path component (empty, ``.``, ``..``, contains a path separator
            or a NUL byte).
    """
    # The manager uses the tenant id as a sub-directory name under the data
    # directory, so anything that is not one plain path component could
    # address another tenant's data (or a location outside the data dir).
    is_plain_name = (
        x_tenant_id not in ("", os.curdir, os.pardir)
        and "\0" not in x_tenant_id
        and os.path.basename(x_tenant_id) == x_tenant_id
    )
    if not is_plain_name:
        raise HTTPException(status_code=400, detail="Invalid X-Tenant-ID header")

    manager = get_rag_manager()
    return await manager.get_rag(x_tenant_id)
|
||||||
|
|
||||||
# ==========================================
|
# ==========================================
|
||||||
# 数据模型
|
# 数据模型
|
||||||
# ==========================================
|
# ==========================================
|
||||||
|
|
@ -37,7 +50,10 @@ async def health_check():
|
||||||
return {"status": "ok", "llm": settings.LLM_MODEL}
|
return {"status": "ok", "llm": settings.LLM_MODEL}
|
||||||
|
|
||||||
@router.post("/query")
|
@router.post("/query")
|
||||||
async def query_knowledge_base(request: QueryRequest):
|
async def query_knowledge_base(
|
||||||
|
request: QueryRequest,
|
||||||
|
rag: LightRAG = Depends(get_current_rag)
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
查询接口
|
查询接口
|
||||||
- query: 用户问题
|
- query: 用户问题
|
||||||
|
|
@ -45,7 +61,6 @@ async def query_knowledge_base(request: QueryRequest):
|
||||||
- stream: 是否流式输出 (默认 False)
|
- stream: 是否流式输出 (默认 False)
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
rag = get_rag()
|
|
||||||
# 构造查询参数
|
# 构造查询参数
|
||||||
param = QueryParam(
|
param = QueryParam(
|
||||||
mode=request.mode,
|
mode=request.mode,
|
||||||
|
|
@ -129,10 +144,12 @@ async def query_knowledge_base(request: QueryRequest):
|
||||||
raise HTTPException(status_code=500, detail=str(e))
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
@router.post("/ingest/text")
|
@router.post("/ingest/text")
|
||||||
async def ingest_text(text: str = Body(..., embed=True)):
|
async def ingest_text(
|
||||||
|
text: str = Body(..., embed=True),
|
||||||
|
rag: LightRAG = Depends(get_current_rag)
|
||||||
|
):
|
||||||
"""直接摄入文本内容"""
|
"""直接摄入文本内容"""
|
||||||
try:
|
try:
|
||||||
rag = get_rag()
|
|
||||||
# 使用异步方法 ainsert
|
# 使用异步方法 ainsert
|
||||||
await rag.ainsert(text)
|
await rag.ainsert(text)
|
||||||
return {"status": "success", "message": "Text ingested successfully"}
|
return {"status": "success", "message": "Text ingested successfully"}
|
||||||
|
|
@ -140,13 +157,15 @@ async def ingest_text(text: str = Body(..., embed=True)):
|
||||||
raise HTTPException(status_code=500, detail=str(e))
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
@router.post("/ingest/file", response_model=IngestResponse)
|
@router.post("/ingest/file", response_model=IngestResponse)
|
||||||
async def upload_file(file: UploadFile = File(...)):
|
async def upload_file(
|
||||||
|
file: UploadFile = File(...),
|
||||||
|
rag: LightRAG = Depends(get_current_rag)
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
文件上传与索引接口
|
文件上传与索引接口
|
||||||
支持 .txt, .md, .pdf
|
支持 .txt, .md, .pdf
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
rag = get_rag()
|
|
||||||
content = ""
|
content = ""
|
||||||
filename = file.filename
|
filename = file.filename
|
||||||
|
|
||||||
|
|
@ -191,14 +210,16 @@ async def upload_file(file: UploadFile = File(...)):
|
||||||
raise HTTPException(status_code=500, detail=f"Processing failed: {str(e)}")
|
raise HTTPException(status_code=500, detail=f"Processing failed: {str(e)}")
|
||||||
|
|
||||||
@router.get("/documents")
|
@router.get("/documents")
|
||||||
async def list_documents():
|
async def list_documents(
|
||||||
|
rag: LightRAG = Depends(get_current_rag)
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
获取文档列表
|
获取文档列表
|
||||||
返回当前知识库中已索引的所有文档及其状态
|
返回当前知识库中已索引的所有文档及其状态
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
# 直接读取 doc_status_storage 的底层文件
|
# 从 rag 实例获取工作目录
|
||||||
doc_status_path = os.path.join(settings.DATA_DIR, "kv_store_doc_status.json")
|
doc_status_path = os.path.join(rag.working_dir, "kv_store_doc_status.json")
|
||||||
if not os.path.exists(doc_status_path):
|
if not os.path.exists(doc_status_path):
|
||||||
return {"count": 0, "docs": []}
|
return {"count": 0, "docs": []}
|
||||||
|
|
||||||
|
|
@ -225,13 +246,15 @@ async def list_documents():
|
||||||
raise HTTPException(status_code=500, detail=str(e))
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
@router.delete("/docs/{doc_id}")
|
@router.delete("/docs/{doc_id}")
|
||||||
async def delete_document(doc_id: str):
|
async def delete_document(
|
||||||
|
doc_id: str,
|
||||||
|
rag: LightRAG = Depends(get_current_rag)
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
删除指定文档
|
删除指定文档
|
||||||
- doc_id: 文档ID (例如 doc-xxxxx)
|
- doc_id: 文档ID (例如 doc-xxxxx)
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
rag = get_rag()
|
|
||||||
logging.info(f"正在删除文档: {doc_id}")
|
logging.info(f"正在删除文档: {doc_id}")
|
||||||
# 调用 LightRAG 的删除方法
|
# 调用 LightRAG 的删除方法
|
||||||
await rag.adelete_by_doc_id(doc_id)
|
await rag.adelete_by_doc_id(doc_id)
|
||||||
|
|
|
||||||
|
|
@ -24,6 +24,7 @@ class Settings(BaseSettings):
|
||||||
# RAG Config
|
# RAG Config
|
||||||
EMBEDDING_DIM: int = 1024
|
EMBEDDING_DIM: int = 1024
|
||||||
MAX_TOKEN_SIZE: int = 8192
|
MAX_TOKEN_SIZE: int = 8192
|
||||||
|
MAX_RAG_INSTANCES: int = 3 # 最大活跃 RAG 实例数
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
env_file = ".env"
|
env_file = ".env"
|
||||||
|
|
|
||||||
110
app/core/rag.py
110
app/core/rag.py
|
|
@ -1,14 +1,18 @@
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import ollama
|
import threading
|
||||||
|
from collections import OrderedDict
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import ollama
|
||||||
from lightrag import LightRAG
|
from lightrag import LightRAG
|
||||||
from lightrag.utils import EmbeddingFunc
|
from lightrag.utils import EmbeddingFunc
|
||||||
from lightrag.llm.ollama import ollama_embed
|
from lightrag.llm.ollama import ollama_embed
|
||||||
from app.config import settings
|
from app.config import settings
|
||||||
|
|
||||||
# 全局 RAG 实例
|
# 全局 RAG 管理器
|
||||||
rag = None
|
rag_manager = None
|
||||||
|
|
||||||
async def llm_func(prompt, system_prompt=None, history_messages=[], **kwargs) -> str:
|
async def llm_func(prompt, system_prompt=None, history_messages=[], **kwargs) -> str:
|
||||||
"""定义 LLM 函数:使用 Ollama 生成回复"""
|
"""定义 LLM 函数:使用 Ollama 生成回复"""
|
||||||
|
|
@ -59,40 +63,76 @@ async def embedding_func(texts: list[str]) -> np.ndarray:
|
||||||
host=settings.EMBEDDING_BINDING_HOST
|
host=settings.EMBEDDING_BINDING_HOST
|
||||||
)
|
)
|
||||||
|
|
||||||
async def initialize_rag():
|
class RAGManager:
    """Bounded LRU cache of per-tenant LightRAG instances.

    Every tenant gets its own LightRAG instance whose working directory is
    ``os.path.join(settings.DATA_DIR, user_id)``. At most ``capacity``
    instances stay cached; when the limit is exceeded the least recently
    used entry is dropped from the cache.
    """

    def __init__(self, capacity: int = 3):
        """
        Args:
            capacity: Maximum number of instances kept in the cache.
        """
        self.capacity = capacity
        # Iteration order is recency order: least recently used first.
        self.cache = OrderedDict()
        # Guards ``cache``; it is never held across an ``await``.
        self.lock = threading.Lock()

    @staticmethod
    def _is_safe_tenant_id(user_id) -> bool:
        """Return True when *user_id* is usable as one plain directory name."""
        return (
            isinstance(user_id, str)
            and user_id not in ("", os.curdir, os.pardir)
            and "\0" not in user_id
            # basename() differs from the input as soon as the value
            # contains a path separator (or a drive prefix on Windows).
            and os.path.basename(user_id) == user_id
        )

    async def get_rag(self, user_id: str) -> "LightRAG":
        """Return the LightRAG instance for *user_id*, creating it on a miss.

        Args:
            user_id: Tenant identifier; also used as the name of the
                tenant's sub-directory below ``settings.DATA_DIR``.

        Returns:
            The cached instance, or a newly created one whose storages
            have been initialized.

        Raises:
            ValueError: If *user_id* is not a single plain path component
                and could therefore point outside the tenant's own
                directory.
        """
        # 1. Fast path: serve from the cache and refresh recency.
        with self.lock:
            if user_id in self.cache:
                self.cache.move_to_end(user_id)
                logging.debug("Cache hit for user: %s", user_id)
                return self.cache[user_id]

        # 2. Cache miss: validate before the id is used to build a path.
        if not self._is_safe_tenant_id(user_id):
            raise ValueError(f"Invalid tenant id: {user_id!r}")

        logging.info("Initializing RAG instance for user: %s", user_id)
        user_data_dir = os.path.join(settings.DATA_DIR, user_id)
        # exist_ok avoids the check-then-create race between concurrent misses.
        os.makedirs(user_data_dir, exist_ok=True)

        # working_dir must be the tenant's own sub-directory, never DATA_DIR.
        rag = LightRAG(
            working_dir=user_data_dir,
            llm_model_func=llm_func,
            llm_model_name=settings.LLM_MODEL,
            llm_model_max_async=1,
            max_parallel_insert=1,
            embedding_func=EmbeddingFunc(
                embedding_dim=settings.EMBEDDING_DIM,
                max_token_size=settings.MAX_TOKEN_SIZE,
                func=embedding_func
            ),
            embedding_func_max_async=4,
            enable_llm_cache=True
        )

        # The lock is released here, so other callers are not blocked while
        # the storages are being initialized.
        await rag.initialize_storages()

        # 3. Publish the instance and enforce the capacity limit.
        with self.lock:
            # Re-check: another caller may have stored an instance for this
            # tenant while we were awaiting; keep the one already cached.
            if user_id in self.cache:
                self.cache.move_to_end(user_id)
                return self.cache[user_id]

            # A new key is appended at the most-recently-used end.
            self.cache[user_id] = rag

            while len(self.cache) > self.capacity:
                oldest_user, _ = self.cache.popitem(last=False)
                logging.info(
                    "Evicting RAG instance for user: %s due to capacity limit",
                    oldest_user,
                )
                # Evicted instances are only dropped from the cache here;
                # releasing their resources is left to garbage collection.

        return rag
|
||||||
|
|
||||||
|
def initialize_rag_manager():
    """Create the global RAG manager and return it.

    The manager is sized from ``settings.MAX_RAG_INSTANCES`` and stored in
    the module-level ``rag_manager``, replacing any previous value.
    """
    global rag_manager
    manager = RAGManager(capacity=settings.MAX_RAG_INSTANCES)
    rag_manager = manager
    logging.info(f"RAG Manager initialized with capacity: {settings.MAX_RAG_INSTANCES}")
    return manager
|
||||||
|
|
||||||
|
def get_rag_manager() -> RAGManager:
    """Return the global RAG manager, creating it on first use."""
    if rag_manager is not None:
        return rag_manager
    # Not initialized yet (e.g. used before application startup ran).
    return initialize_rag_manager()
|
||||||
|
|
|
||||||
|
|
@ -2,7 +2,7 @@ import logging
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
from app.config import settings
|
from app.config import settings
|
||||||
from app.core.rag import initialize_rag
|
from app.core.rag import initialize_rag_manager
|
||||||
from app.core.prompts import patch_prompts
|
from app.core.prompts import patch_prompts
|
||||||
from app.api.routes import router
|
from app.api.routes import router
|
||||||
|
|
||||||
|
|
@ -15,8 +15,8 @@ async def lifespan(app: FastAPI):
|
||||||
# 1. Patch Prompts
|
# 1. Patch Prompts
|
||||||
patch_prompts()
|
patch_prompts()
|
||||||
|
|
||||||
# 2. Init RAG
|
# 2. Init RAG Manager
|
||||||
await initialize_rag()
|
initialize_rag_manager()
|
||||||
|
|
||||||
yield
|
yield
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue