# ai-lightrag/app/api/routes.py
import json
import logging
import os
from fastapi import APIRouter, UploadFile, File, HTTPException, Body, Header, Depends
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from lightrag import LightRAG, QueryParam
from app.core.rag import get_rag_manager, llm_func
from app.core.prompts import CUSTOM_RAG_RESPONSE_PROMPT
from app.config import settings
from app.core.ingest import process_pdf_with_images
router = APIRouter()
# ==========================================
# Dependency injection
# ==========================================
async def get_current_rag(
    x_tenant_id: str = Header("default", alias="X-Tenant-ID")
) -> LightRAG:
    """Dependency: resolve the LightRAG instance for the current tenant.

    The tenant id is read from the ``X-Tenant-ID`` request header and
    falls back to ``'default'`` when the header is absent.
    """
    return await get_rag_manager().get_rag(x_tenant_id)
# ==========================================
# Data models
# ==========================================
class QueryRequest(BaseModel):
    # Request payload for POST /query.
    query: str            # user question
    mode: str = "hybrid"  # retrieval mode; one of: naive, local, global, hybrid
    top_k: int = 5        # number of retrieved items to keep
    stream: bool = False  # stream the answer via SSE when True
    think: bool = False   # enable the LLM "thinking" mode
class IngestResponse(BaseModel):
    # Response body for file-ingestion endpoints.
    filename: str  # name of the uploaded file
    status: str    # "success" | "skipped" | "failed"
    message: str   # human-readable detail
class QAPair(BaseModel):
    # A single question/answer pair for batch QA ingestion.
    question: str
    answer: str
# ==========================================
# Endpoint implementations
# ==========================================
import secrets
import string
def _get_or_create_tenant_secret(tenant_path: str):
    """Read the tenant's persisted secret, generating and saving one if missing.

    Returns the secret string, or None when a freshly generated secret
    could not be written to disk (the tenant is then skipped).
    """
    secret_file = os.path.join(tenant_path, ".secret")
    if os.path.exists(secret_file):
        with open(secret_file, "r") as f:
            return f.read().strip()
    # Generate a 16-character random alphanumeric secret.
    alphabet = string.ascii_letters + string.digits
    secret = ''.join(secrets.choice(alphabet) for _ in range(16))
    try:
        with open(secret_file, "w") as f:
            f.write(secret)
    except OSError as e:
        logging.error(f"Failed to write secret for tenant {os.path.basename(tenant_path)}: {e}")
        return None
    return secret


@router.get("/admin/tenants")
async def list_tenants(token: str):
    """
    Admin endpoint: list all tenants with their access tokens.

    - token: admin token, passed as a query parameter.

    Raises 403 on an invalid admin token, 500 on unexpected errors.
    """
    # Constant-time comparison prevents timing attacks on the admin token
    # (plain `!=` short-circuits on the first differing character).
    if not secrets.compare_digest(token, settings.ADMIN_TOKEN):
        raise HTTPException(status_code=403, detail="Invalid admin token")
    try:
        if not os.path.exists(settings.DATA_DIR):
            return {"tenants": []}
        tenants = []
        for entry in os.scandir(settings.DATA_DIR):
            # Every non-hidden directory under DATA_DIR is a tenant workspace.
            if not entry.is_dir() or entry.name.startswith("."):
                continue
            secret = _get_or_create_tenant_secret(entry.path)
            if secret is None:
                continue
            # Tenant access token format: "<tenant_id>_<secret>".
            tenants.append({
                "id": entry.name,
                "token": f"{entry.name}_{secret}"
            })
        return {"tenants": tenants}
    except Exception as e:
        logging.error(f"Failed to list tenants: {e}")
        raise HTTPException(status_code=500, detail=str(e))
@router.get("/health")
async def health_check():
    """Liveness probe: report service status and the configured LLM model."""
    return {"status": "ok", "llm": settings.LLM_MODEL}
@router.post("/query")
async def query_knowledge_base(
    request: QueryRequest,
    rag: LightRAG = Depends(get_current_rag)
):
    """
    Query endpoint.

    - query: user question
    - mode: retrieval mode (hybrid recommended for factual queries)
    - stream: stream the output via SSE (default False)
    - think: enable the LLM "thinking" mode (default False)

    Raises 500 on any failure before streaming starts.
    """
    try:
        # Build the query parameters
        param = QueryParam(
            mode=request.mode,
            top_k=request.top_k,
            stream=request.stream,
            enable_rerank=settings.RERANK_ENABLED
        )
        # Streaming output (SSE protocol)
        if request.stream:
            async def stream_generator():
                # SSE formatting helper
                def sse_pack(event: str, text: str) -> str:
                    # Wrap the payload in JSON so newlines and special
                    # characters are escaped correctly.
                    data = json.dumps({"text": text}, ensure_ascii=False)
                    return f"event: {event}\ndata: {data}\n\n"
                yield sse_pack("thinking", "1. 上下文检索中...\n")
                context_param = QueryParam(
                    mode=request.mode,
                    top_k=request.top_k,
                    only_need_context=True,
                    enable_rerank=settings.RERANK_ENABLED
                )
                # Retrieve the context (slow step: includes graph traversal)
                context_resp = await rag.aquery(request.query, param=context_param)
                logging.info(f"Context Response: {context_resp}")
                # Decide whether retrieval found anything.
                # NOTE(review): these substring checks on "[no-context]" and
                # "None" depend on LightRAG's textual markers and look
                # fragile — confirm against the LightRAG version in use.
                has_context = False
                if context_resp and "[no-context]" not in context_resp and "None" not in context_resp:
                    has_context = True
                # Whether thinking mode is enabled for this request
                think = request.think
                if has_context:
                    yield sse_pack("system", "retrieved")  # system event: context found
                    yield sse_pack("thinking", f"2. 上下文已检索 (长度: {len(context_resp)} 字符).\n")
                else:
                    yield sse_pack("system", "missed")  # system event: nothing retrieved
                    yield sse_pack("thinking", "2. 未找到相关上下文,将依赖 LLM 自身知识\n")
                    # No context retrieved: force thinking mode off.
                    think = False
                yield sse_pack("thinking", "3. 答案生成中...\n")
                # 2. Generate the answer
                # Build the system prompt manually
                sys_prompt = CUSTOM_RAG_RESPONSE_PROMPT.format(
                    context_data=context_resp,
                    response_type="Multiple Paragraphs",
                    user_prompt=""
                )
                # Call the LLM (streaming)
                stream_resp = await llm_func(
                    request.query,
                    system_prompt=sys_prompt,
                    stream=True,
                    think=think,
                    hashing_kv=rag.llm_response_cache
                )
                thinkState = 0  # thinking state — 0: not started, 1: in progress, 2: finished
                async for chunk in stream_resp:
                    if isinstance(chunk, dict):
                        if chunk.get("type") == "thinking":
                            if thinkState == 0:
                                # First thinking chunk: emit the section header once.
                                yield sse_pack("thinking", "\n思考:\n")
                                thinkState = 1
                            yield sse_pack("thinking", chunk["content"])
                        elif chunk.get("type") == "content":
                            if thinkState == 1:
                                # Visually separate the thinking section from the answer.
                                yield sse_pack("none", "\n\n\n")
                                thinkState = 2
                            yield sse_pack("answer", chunk["content"])
                    elif chunk:
                        # Plain (non-dict) truthy chunk: treat as answer text.
                        yield sse_pack("answer", chunk)
            # Respond with the text/event-stream content type
            return StreamingResponse(stream_generator(), media_type="text/event-stream")
        # Non-streaming path
        response = await rag.aquery(request.query, param=param)
        return {"response": response}
    except Exception as e:
        logging.error(f"查询失败: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
@router.post("/ingest/text")
async def ingest_text(
    text: str = Body(..., embed=True),
    rag: LightRAG = Depends(get_current_rag)
):
    """Ingest raw text content directly into the knowledge base."""
    try:
        # Asynchronous insert keeps the event loop responsive.
        await rag.ainsert(text)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
    return {"status": "success", "message": "Text ingested successfully"}
@router.post("/ingest/batch_qa")
async def ingest_batch_qa(
    qa_list: list[QAPair],
    rag: LightRAG = Depends(get_current_rag)
):
    """
    Bulk-ingest QA pairs.

    - Formats each QA pair into a semantically coherent text block.
    - Merges short QA entries into batches to keep indexing efficient.

    Returns a summary of how many QA pairs were processed and how many
    text chunks were inserted; raises 500 on any ingestion failure.
    """
    if not qa_list:
        return {"status": "skipped", "message": "Empty QA list"}
    try:
        MAX_BATCH_CHARS = 2000  # ~1000 tokens, conservative estimate
        batch_text = ""
        inserted_count = 0
        for qa in qa_list:
            # Format a single QA entry
            entry = f"--- Q&A Entry ---\nQuestion: {qa.question}\nAnswer: {qa.answer}\n\n"
            # Flush the current batch before it would overflow.
            # BUGFIX: only flush a NON-EMPTY batch — the previous version
            # inserted an empty document (and counted it) whenever a
            # single oversized entry arrived while the batch was empty.
            if batch_text and len(batch_text) + len(entry) > MAX_BATCH_CHARS:
                await rag.ainsert(batch_text)
                batch_text = ""
                inserted_count += 1
            batch_text += entry
        # Flush whatever is left
        if batch_text:
            await rag.ainsert(batch_text)
            inserted_count += 1
        return {
            "status": "success",
            "message": f"Successfully processed {len(qa_list)} QA pairs into {inserted_count} text chunks."
        }
    except Exception as e:
        logging.error(f"QA批量导入失败: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
@router.post("/ingest/file", response_model=IngestResponse)
async def upload_file(
    file: UploadFile = File(...),
    rag: LightRAG = Depends(get_current_rag)
):
    """
    File upload & indexing endpoint.

    Supports .txt, .md and .pdf. Extension matching is case-insensitive
    (".PDF" is accepted too). Raises 500 on processing failures.
    """
    try:
        # UploadFile.filename may be None; guard before string operations.
        filename = file.filename or "unknown"
        # Case-insensitive extension check (endswith(".pdf") rejected ".PDF").
        suffix = os.path.splitext(filename)[1].lower()
        # Read the whole upload into memory
        file_bytes = await file.read()
        if suffix == ".pdf":
            # PDF parsing (handles text and embedded images)
            content = await process_pdf_with_images(file_bytes)
        elif suffix in (".txt", ".md"):
            # Plain-text files: decode directly
            content = file_bytes.decode("utf-8")
        else:
            return IngestResponse(
                filename=filename,
                status="skipped",
                message="Unsupported file format. Only PDF, TXT, MD supported."
            )
        if not content.strip():
            return IngestResponse(
                filename=filename,
                status="failed",
                message="Empty content extracted."
            )
        # Index into LightRAG (slow operation); async insert.
        await rag.ainsert(content)
        return IngestResponse(
            filename=filename,
            status="success",
            message=f"Successfully indexed {len(content)} characters."
        )
    except Exception as e:
        logging.error(f"处理文件 {file.filename} 失败: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Processing failed: {str(e)}")
@router.get("/documents")
async def list_documents(
    rag: LightRAG = Depends(get_current_rag)
):
    """
    List indexed documents.

    Reads the doc-status KV store from the tenant's working directory
    and returns each document's id, (truncated) summary, length,
    creation time and status. Raises 500 on unexpected errors.
    """
    try:
        # Resolve the store path from the tenant's RAG working directory
        doc_status_path = os.path.join(rag.working_dir, "kv_store_doc_status.json")
        if not os.path.exists(doc_status_path):
            # No store yet: nothing has been ingested for this tenant.
            return {"count": 0, "docs": []}
        with open(doc_status_path, "r", encoding="utf-8") as f:
            data = json.load(f)
        docs = []
        for doc_id, info in data.items():
            summary = info.get("content_summary", "")
            # BUGFIX: only append an ellipsis when the summary was actually
            # truncated (previously "..." was added unconditionally, even
            # to short or empty summaries).
            if len(summary) > 100:
                summary = summary[:100] + "..."
            docs.append({
                "id": doc_id,
                "summary": summary,
                "length": info.get("content_length", 0),
                "created_at": info.get("created_at"),
                "status": info.get("status")
            })
        return {
            "count": len(docs),
            "docs": docs
        }
    except Exception as e:
        logging.error(f"获取文档列表失败: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
@router.get("/documents/{doc_id}")
async def get_document_detail(
    doc_id: str,
    rag: LightRAG = Depends(get_current_rag)
):
    """
    Fetch the full detail of a single document.

    Looks the id up in the full-docs KV store of the tenant's working
    directory. Returns 404 when the store or the document is missing,
    500 on unexpected errors.
    """
    try:
        store_path = os.path.join(rag.working_dir, "kv_store_full_docs.json")
        if not os.path.exists(store_path):
            raise HTTPException(status_code=404, detail="Document store not found")
        with open(store_path, "r", encoding="utf-8") as f:
            store = json.load(f)
        # EAFP lookup: a missing id maps to a 404.
        try:
            doc_info = store[doc_id]
        except KeyError:
            raise HTTPException(status_code=404, detail="Document not found")
        return {
            "id": doc_id,
            "content": doc_info.get("content", ""),
            "create_time": doc_info.get("create_time"),
            "file_path": doc_info.get("file_path", "unknown")
        }
    except HTTPException:
        # Re-raise 404s untouched instead of wrapping them in a 500.
        raise
    except Exception as e:
        logging.error(f"获取文档详情失败: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
@router.delete("/docs/{doc_id}")
async def delete_document(
    doc_id: str,
    rag: LightRAG = Depends(get_current_rag)
):
    """
    Delete one document from the index.

    - doc_id: document identifier (e.g. "doc-xxxxx")

    Raises 500 when the underlying deletion fails.
    """
    logging.info(f"正在删除文档: {doc_id}")
    try:
        # Delegate the actual removal to LightRAG.
        await rag.adelete_by_doc_id(doc_id)
    except Exception as e:
        logging.error(f"删除文档 {doc_id} 失败: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Delete failed: {str(e)}")
    return {"status": "success", "message": f"Document {doc_id} deleted successfully"}