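"""FastAPI router exposing the RAG service endpoints: query (plain and SSE
streaming), text and file ingestion, document listing, and document deletion,
all backed by a LightRAG instance."""
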
import io
import json
import logging
import os

from fastapi import APIRouter, UploadFile, File, HTTPException, Body
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from pypdf import PdfReader
from lightrag import QueryParam

from app.core.rag import get_rag, llm_func
from app.core.prompts import CUSTOM_RAG_RESPONSE_PROMPT
from app.config import settings

router = APIRouter()


# ==========================================
# Data models
# ==========================================
class QueryRequest(BaseModel):
    query: str
    mode: str = "hybrid"  # options: naive, local, global, hybrid
    top_k: int = 5
    stream: bool = False


class IngestResponse(BaseModel):
    filename: str
    status: str
    message: str


# ==========================================
# Endpoint implementations
# ==========================================

@router.get("/health")
|
|
async def health_check():
|
|
"""健康检查接口"""
|
|
return {"status": "ok", "llm": settings.LLM_MODEL}
|
|
|
|
@router.post("/query")
|
|
async def query_knowledge_base(request: QueryRequest):
|
|
"""
|
|
查询接口
|
|
- query: 用户问题
|
|
- mode: 检索模式 (推荐 hybrid 用于事实类查询)
|
|
- stream: 是否流式输出 (默认 False)
|
|
"""
|
|
try:
|
|
        rag = get_rag()
        # Build the query parameters
        param = QueryParam(
            mode=request.mode,
            top_k=request.top_k,
            stream=request.stream,
            enable_rerank=False  # explicitly disable reranking to silence the rerank warning
        )

        # Handle streaming output (SSE protocol)
        if request.stream:
            async def stream_generator():
                # SSE formatting helper
                def sse_pack(event: str, text: str) -> str:
                    # Wrap the data payload in JSON so newlines and special characters are escaped correctly
                    data = json.dumps({"text": text}, ensure_ascii=False)
                    return f"event: {event}\ndata: {data}\n\n"

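                # Each emitted frame looks like (illustrative):
                #   event: thinking
                #   data: {"text": "1. Retrieving context...\n"}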
yield sse_pack("thinking", "1. 上下文检索中...\n")
|
|
|
|
context_param = QueryParam(
|
|
mode=request.mode,
|
|
top_k=request.top_k,
|
|
only_need_context=True,
|
|
enable_rerank=False # 显式关闭 Rerank 以消除 Warning
|
|
)
|
|
|
|
# 获取上下文 (这步耗时较长,包含图遍历)
|
|
context_resp = await rag.aquery(request.query, param=context_param)
|
|
|
|
                # Simple check for whether any content was found
                if not context_resp or "Sorry, I'm not able to answer" in context_resp:
                    yield sse_pack("thinking", " (no relevant context found; falling back to the LLM's own knowledge)")
                else:
                    yield sse_pack("thinking", f"2. Context retrieved (length: {len(context_resp)} characters).\n")

                yield sse_pack("thinking", "3. Generating the answer...\n")

                # Generate the answer
                # Build the system prompt manually
                sys_prompt = CUSTOM_RAG_RESPONSE_PROMPT.format(
                    context_data=context_resp,
                    response_type="Multiple Paragraphs",
                    user_prompt=""
                )

                # Call the LLM to generate (streaming)
                stream_resp = await llm_func(
                    request.query,
                    system_prompt=sys_prompt,
                    stream=True,
                    hashing_kv=rag.llm_response_cache
                )

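                # The LLM stream may yield dicts with a "type" field ("thinking" for
                # reasoning tokens, "content" for answer tokens) or plain text chunks;
                # route each kind to the matching SSE event.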
                think_state = 0  # 0: thinking not started, 1: thinking in progress, 2: thinking finished
                async for chunk in stream_resp:
                    if isinstance(chunk, dict):
                        if chunk.get("type") == "thinking":
                            if think_state == 0:
                                yield sse_pack("thinking", "\nThinking:\n")
                                think_state = 1

                            yield sse_pack("thinking", chunk["content"])
                        elif chunk.get("type") == "content":
                            if think_state == 1:
                                yield sse_pack("none", "\n\n\n")
                                think_state = 2

                            yield sse_pack("answer", chunk["content"])
                    elif chunk:
                        yield sse_pack("answer", chunk)

            # Use the text/event-stream Content-Type
            return StreamingResponse(stream_generator(), media_type="text/event-stream")

        # Handle non-streaming output
        response = await rag.aquery(request.query, param=param)
        return {"response": response}

    except Exception as e:
        logging.error(f"Query failed: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))


@router.post("/ingest/text")
|
|
async def ingest_text(text: str = Body(..., embed=True)):
|
|
"""直接摄入文本内容"""
|
|
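    # Note: Body(..., embed=True) means the request body is a JSON object of the
    # form {"text": "..."} rather than a bare string.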
    try:
        rag = get_rag()
        # Use the async ainsert method
        await rag.ainsert(text)
        return {"status": "success", "message": "Text ingested successfully"}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


@router.post("/ingest/file", response_model=IngestResponse)
|
|
async def upload_file(file: UploadFile = File(...)):
|
|
"""
|
|
文件上传与索引接口
|
|
支持 .txt, .md, .pdf
|
|
"""
|
|
try:
|
|
rag = get_rag()
|
|
content = ""
|
|
filename = file.filename
|
|
|
|
# 读取文件内容
|
|
file_bytes = await file.read()
|
|
|
|
# 根据文件类型解析
|
|
if filename.endswith(".pdf"):
|
|
# PDF 解析逻辑
|
|
pdf_reader = PdfReader(io.BytesIO(file_bytes))
|
|
for page in pdf_reader.pages:
|
|
content += page.extract_text() + "\n"
|
|
elif filename.endswith(".txt") or filename.endswith(".md"):
|
|
# 文本文件直接解码
|
|
content = file_bytes.decode("utf-8")
|
|
else:
|
|
return IngestResponse(
|
|
filename=filename,
|
|
status="skipped",
|
|
message="Unsupported file format. Only PDF, TXT, MD supported."
|
|
)
|
|
|
|
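        # pypdf only extracts the text layer (it does not OCR), so scanned or
        # image-only PDFs come back empty and are rejected here.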
        if not content.strip():
            return IngestResponse(
                filename=filename,
                status="failed",
                message="Empty content extracted."
            )

        # Insert into the LightRAG index (this is a slow operation)
        # Use the async ainsert method
        await rag.ainsert(content)

        return IngestResponse(
            filename=filename,
            status="success",
            message=f"Successfully indexed {len(content)} characters."
        )

    except Exception as e:
        logging.error(f"Failed to process file {file.filename}: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Processing failed: {str(e)}")


@router.get("/documents")
|
|
async def list_documents():
|
|
"""
|
|
获取文档列表
|
|
返回当前知识库中已索引的所有文档及其状态
|
|
"""
|
|
try:
|
|
# 直接读取 doc_status_storage 的底层文件
|
|
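        # (this assumes LightRAG's default JSON doc-status backend, which persists
        # to kv_store_doc_status.json under the data directory)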
        doc_status_path = os.path.join(settings.DATA_DIR, "kv_store_doc_status.json")
        if not os.path.exists(doc_status_path):
            return {"count": 0, "docs": []}

        with open(doc_status_path, "r", encoding="utf-8") as f:
            data = json.load(f)

        # Format the results
        docs = []
        for doc_id, info in data.items():
            docs.append({
                "id": doc_id,
                "summary": info.get("content_summary", "")[:100] + "...",  # truncate the summary
                "length": info.get("content_length", 0),
                "created_at": info.get("created_at"),
                "status": info.get("status")
            })

        return {
            "count": len(docs),
            "docs": docs
        }
    except Exception as e:
        logging.error(f"Failed to list documents: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))


@router.delete("/docs/{doc_id}")
|
|
async def delete_document(doc_id: str):
|
|
"""
|
|
删除指定文档
|
|
- doc_id: 文档ID (例如 doc-xxxxx)
|
|
"""
|
|
try:
|
|
rag = get_rag()
|
|
logging.info(f"正在删除文档: {doc_id}")
|
|
# 调用 LightRAG 的删除方法
|
|
await rag.adelete_by_doc_id(doc_id)
|
|
return {"status": "success", "message": f"Document {doc_id} deleted successfully"}
|
|
except Exception as e:
|
|
logging.error(f"删除文档 {doc_id} 失败: {str(e)}")
|
|
raise HTTPException(status_code=500, detail=f"Delete failed: {str(e)}")
|