265 lines
9.1 KiB
Python
265 lines
9.1 KiB
Python
import io
import json
import logging
import os

from fastapi import APIRouter, Body, Depends, File, Header, HTTPException, UploadFile
from fastapi.responses import StreamingResponse
from lightrag import LightRAG, QueryParam
from pydantic import BaseModel
from pypdf import PdfReader

from app.config import settings
from app.core.prompts import CUSTOM_RAG_RESPONSE_PROMPT
from app.core.rag import get_rag_manager, llm_func
||
|
||
router = APIRouter()
|
||
|
||
# ==========================================
|
||
# 依赖注入
|
||
# ==========================================
|
||
async def get_current_rag(
    x_tenant_id: str = Header("default", alias="X-Tenant-ID")
) -> LightRAG:
    """Dependency: resolve the LightRAG instance for the current tenant.

    The tenant id is read from the ``X-Tenant-ID`` request header and
    falls back to ``'default'`` when the header is absent.
    """
    return await get_rag_manager().get_rag(x_tenant_id)
|
||
|
||
# ==========================================
# Data models
# ==========================================
class QueryRequest(BaseModel):
    """Request body for the /query endpoint."""
    # The user's question.
    query: str
    # Retrieval mode; one of: naive, local, global, hybrid.
    mode: str = "hybrid"
    # Number of top results to retrieve.
    top_k: int = 5
    # Whether to stream the answer back via SSE (default: plain JSON).
    stream: bool = False
|
||
|
||
class IngestResponse(BaseModel):
    """Response body for the file-ingestion endpoint."""
    # Name of the uploaded file as reported by the client.
    filename: str
    # Outcome: "success", "failed", or "skipped" (see upload_file).
    status: str
    # Human-readable detail about the outcome.
    message: str
|
||
|
||
# ==========================================
# Endpoint implementations
# ==========================================


@router.get("/health")
async def health_check():
    """Health check: report service status and the configured LLM model."""
    payload = {"status": "ok", "llm": settings.LLM_MODEL}
    return payload
|
||
|
||
@router.post("/query")
async def query_knowledge_base(
    request: QueryRequest,
    rag: LightRAG = Depends(get_current_rag)
):
    """
    Query endpoint.

    - query: the user's question
    - mode: retrieval mode (hybrid recommended for factual queries)
    - stream: stream the answer via SSE when True (default False)

    Raises HTTPException(500) when retrieval or generation fails.
    """
    try:
        # Parameters for the plain (non-streaming) query path.
        param = QueryParam(
            mode=request.mode,
            top_k=request.top_k,
            stream=request.stream,
            enable_rerank=False  # explicitly disable rerank to silence the warning
        )

        # Streaming output (SSE protocol).
        if request.stream:
            async def stream_generator():
                # SSE framing helper: JSON-wrap the data so newlines and
                # special characters are escaped correctly.
                def sse_pack(event: str, text: str) -> str:
                    data = json.dumps({"text": text}, ensure_ascii=False)
                    return f"event: {event}\ndata: {data}\n\n"

                yield sse_pack("thinking", "1. 上下文检索中...\n")

                # Context-only retrieval parameters.
                context_param = QueryParam(
                    mode=request.mode,
                    top_k=request.top_k,
                    only_need_context=True,
                    enable_rerank=False  # explicitly disable rerank to silence the warning
                )

                # Retrieve the context (slow step: includes graph traversal).
                context_resp = await rag.aquery(request.query, param=context_param)

                # Heuristic check for "nothing found".
                if not context_resp or "Sorry, I'm not able to answer" in context_resp:
                    yield sse_pack("thinking", " (未找到相关上下文,将依赖 LLM 自身知识)")
                else:
                    yield sse_pack("thinking", f"2. 上下文已检索 (长度: {len(context_resp)} 字符).\n")

                yield sse_pack("thinking", "3. 答案生成中...\n")

                # Build the system prompt manually from the retrieved context.
                sys_prompt = CUSTOM_RAG_RESPONSE_PROMPT.format(
                    context_data=context_resp,
                    response_type="Multiple Paragraphs",
                    user_prompt=""
                )

                # Call the LLM in streaming mode.
                stream_resp = await llm_func(
                    request.query,
                    system_prompt=sys_prompt,
                    stream=True,
                    hashing_kv=rag.llm_response_cache
                )

                # "thinking" phase state: 0 = not started, 1 = in progress,
                # 2 = finished. (Renamed from thinkState per PEP 8.)
                think_state = 0
                async for chunk in stream_resp:
                    if isinstance(chunk, dict):
                        if chunk.get("type") == "thinking":
                            if think_state == 0:
                                yield sse_pack("thinking", "\n思考:\n")
                                think_state = 1

                            yield sse_pack("thinking", chunk["content"])
                        elif chunk.get("type") == "content":
                            if think_state == 1:
                                # Separator between thinking and answer.
                                yield sse_pack("none", "\n\n\n")
                                think_state = 2

                            yield sse_pack("answer", chunk["content"])
                    elif chunk:
                        # Plain-string chunks are treated as answer text.
                        yield sse_pack("answer", chunk)

            # Serve with the text/event-stream Content-Type.
            return StreamingResponse(stream_generator(), media_type="text/event-stream")

        # Plain (non-streaming) output.
        response = await rag.aquery(request.query, param=param)
        return {"response": response}

    except Exception as e:
        # logging.exception records the full traceback, which logging.error
        # with a plain message did not.
        logging.exception(f"查询失败: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
|
||
|
||
@router.post("/ingest/text")
async def ingest_text(
    text: str = Body(..., embed=True),
    rag: LightRAG = Depends(get_current_rag)
):
    """Ingest raw text content directly into the knowledge base.

    Raises HTTPException(500) when indexing fails.
    """
    try:
        # Use the async ainsert method.
        await rag.ainsert(text)
        return {"status": "success", "message": "Text ingested successfully"}
    except Exception as e:
        # Log with traceback before converting to an HTTP error, consistent
        # with the other endpoints in this module (which all log failures).
        logging.exception(f"Text ingestion failed: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
|
||
|
||
@router.post("/ingest/file", response_model=IngestResponse)
async def upload_file(
    file: UploadFile = File(...),
    rag: LightRAG = Depends(get_current_rag)
):
    """
    File upload and indexing endpoint.

    Supports .txt, .md and .pdf (extension matched case-insensitively, so
    ".PDF" etc. are also accepted). Unsupported formats are reported as
    "skipped"; empty extractions as "failed".

    Raises HTTPException(500) when parsing or indexing fails.
    """
    try:
        content = ""
        filename = file.filename
        # filename may be None on some clients; normalize for the checks.
        lowered = (filename or "").lower()

        # Read the raw upload bytes.
        file_bytes = await file.read()

        # Parse according to file type.
        if lowered.endswith(".pdf"):
            pdf_reader = PdfReader(io.BytesIO(file_bytes))
            # extract_text() can yield empty/None for image-only pages;
            # guard with `or ""` and join instead of quadratic `+=`.
            content = "".join(
                (page.extract_text() or "") + "\n" for page in pdf_reader.pages
            )
        elif lowered.endswith((".txt", ".md")):
            # Plain-text files: decode directly.
            content = file_bytes.decode("utf-8")
        else:
            return IngestResponse(
                filename=filename,
                status="skipped",
                message="Unsupported file format. Only PDF, TXT, MD supported."
            )

        if not content.strip():
            return IngestResponse(
                filename=filename,
                status="failed",
                message="Empty content extracted."
            )

        # Index into LightRAG (this is a slow operation); use async ainsert.
        await rag.ainsert(content)

        return IngestResponse(
            filename=filename,
            status="success",
            message=f"Successfully indexed {len(content)} characters."
        )

    except Exception as e:
        # logging.exception keeps the traceback for diagnosis.
        logging.exception(f"处理文件 {file.filename} 失败: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Processing failed: {str(e)}")
|
||
|
||
@router.get("/documents")
async def list_documents(
    rag: LightRAG = Depends(get_current_rag)
):
    """
    List indexed documents.

    Returns all documents currently indexed in the tenant's knowledge base
    together with their status metadata.

    Raises HTTPException(500) when the status store cannot be read.
    """
    try:
        # The doc-status KV store lives in the rag instance's working dir.
        doc_status_path = os.path.join(rag.working_dir, "kv_store_doc_status.json")
        if not os.path.exists(doc_status_path):
            return {"count": 0, "docs": []}

        with open(doc_status_path, "r", encoding="utf-8") as f:
            data = json.load(f)

        def _summarize(text: str) -> str:
            # Truncate long summaries; only add the ellipsis when the text
            # was actually cut (the old code appended "..." unconditionally).
            return text[:100] + "..." if len(text) > 100 else text

        # Format the result entries.
        docs = [
            {
                "id": doc_id,
                "summary": _summarize(info.get("content_summary", "")),
                "length": info.get("content_length", 0),
                "created_at": info.get("created_at"),
                "status": info.get("status"),
            }
            for doc_id, info in data.items()
        ]

        return {
            "count": len(docs),
            "docs": docs
        }
    except Exception as e:
        # logging.exception keeps the traceback for diagnosis.
        logging.exception(f"获取文档列表失败: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
|
||
|
||
@router.delete("/docs/{doc_id}")
async def delete_document(
    doc_id: str,
    rag: LightRAG = Depends(get_current_rag)
):
    """
    Delete one document from the knowledge base.

    - doc_id: the document identifier (e.g. ``doc-xxxxx``)

    Raises HTTPException(500) when deletion fails.
    """
    try:
        logging.info(f"正在删除文档: {doc_id}")
        # Delegate the actual removal to LightRAG.
        await rag.adelete_by_doc_id(doc_id)
    except Exception as e:
        logging.error(f"删除文档 {doc_id} 失败: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Delete failed: {str(e)}")
    else:
        return {"status": "success", "message": f"Document {doc_id} deleted successfully"}
|