import json
import logging
import os
import secrets
import string
import time

from fastapi import APIRouter, UploadFile, File, HTTPException, Body, Header, Depends
from fastapi.responses import StreamingResponse
from pydantic import BaseModel

from lightrag import LightRAG, QueryParam

from app.core.rag import get_rag_manager, llm_func
from app.core.prompts import CUSTOM_RAG_RESPONSE_PROMPT
from app.config import settings
from app.core.ingest import process_pdf_with_images

router = APIRouter()


# ==========================================
# Dependency injection
# ==========================================
async def get_current_rag(
    x_tenant_id: str = Header("default", alias="X-Tenant-ID")
) -> LightRAG:
    """
    Dependency: resolve the LightRAG instance for the current tenant.
    The tenant ID is read from the X-Tenant-ID header and defaults to 'default'.
    """
    manager = get_rag_manager()
    return await manager.get_rag(x_tenant_id)


# ==========================================
# Data models
# ==========================================
class QueryRequest(BaseModel):
    query: str
    mode: str = "hybrid"  # one of: naive, local, global, hybrid
    top_k: int = 5
    stream: bool = False
    think: bool = False
    only_rag: bool = False  # use only RAG retrieval results, without an LLM fallback


class IngestResponse(BaseModel):
    filename: str
    status: str
    message: str


class QAPair(BaseModel):
    question: str
    answer: str

# ==========================================
# Endpoint implementations
# ==========================================

@router.get("/admin/tenants")
|
||
async def list_tenants(token: str):
|
||
"""
|
||
管理员接口:获取租户列表
|
||
"""
|
||
if token != settings.ADMIN_TOKEN:
|
||
raise HTTPException(status_code=403, detail="Invalid admin token")
|
||
|
||
try:
|
||
if not os.path.exists(settings.DATA_DIR):
|
||
return {"tenants": []}
|
||
|
||
tenants = []
|
||
for entry in os.scandir(settings.DATA_DIR):
|
||
if entry.is_dir() and not entry.name.startswith("."):
|
||
tenant_id = entry.name
|
||
secret_file = os.path.join(entry.path, ".secret")
|
||
|
||
# 读取或生成租户专属 Secret
|
||
if os.path.exists(secret_file):
|
||
with open(secret_file, "r") as f:
|
||
secret = f.read().strip()
|
||
else:
|
||
# 生成16位随机字符串
|
||
alphabet = string.ascii_letters + string.digits
|
||
secret = ''.join(secrets.choice(alphabet) for i in range(16))
|
||
try:
|
||
with open(secret_file, "w") as f:
|
||
f.write(secret)
|
||
except Exception as e:
|
||
logging.error(f"Failed to write secret for tenant {tenant_id}: {e}")
|
||
continue
|
||
|
||
# 生成租户访问 Token (租户ID_随机串)
|
||
tenant_token = f"{tenant_id}_{secret}"
|
||
tenants.append({
|
||
"id": tenant_id,
|
||
"token": tenant_token
|
||
})
|
||
return {"tenants": tenants}
|
||
except Exception as e:
|
||
logging.error(f"Failed to list tenants: {e}")
|
||
raise HTTPException(status_code=500, detail=str(e))
|
||
|
||
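# Illustrative admin call (ADMIN_TOKEN comes from app.config.settings; the host
# and the absence of a mount prefix are assumptions about the deployment):
#
#   curl "http://localhost:8000/admin/tenants?token=<ADMIN_TOKEN>"
#
# Each returned entry has the shape {"id": "<tenant>", "token": "<tenant>_<16-char secret>"}.
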
@router.get("/health")
|
||
async def health_check():
|
||
"""健康检查接口"""
|
||
return {"status": "ok", "llm": settings.LLM_MODEL}
|
||
|
||
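# Illustrative request body for POST /query (fields mirror QueryRequest above;
# "mode" accepts naive, local, global, or hybrid; the question text is a made-up example):
#
#   {
#     "query": "How do I reset my password?",
#     "mode": "hybrid",
#     "top_k": 5,
#     "stream": false,
#     "think": false,
#     "only_rag": false
#   }
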
@router.post("/query")
|
||
async def query_knowledge_base(
|
||
request: QueryRequest,
|
||
rag: LightRAG = Depends(get_current_rag)
|
||
):
|
||
"""
|
||
查询接口
|
||
- query: 用户问题
|
||
- mode: 检索模式 (推荐 hybrid 用于事实类查询)
|
||
- stream: 是否流式输出 (默认 False)
|
||
- think: 是否启用思考模式 (默认 False)
|
||
"""
|
||
    try:
        # Build the retrieval-only query parameters
        context_param = QueryParam(
            mode=request.mode,
            top_k=request.top_k,
            only_need_context=True,
            enable_rerank=settings.RERANK_ENABLED
        )

        # Retrieve context for the question
        context_resp = await rag.aquery(request.query, param=context_param)

        # Decide whether retrieval hit: responses containing "[no-context]" or "None"
        # are treated as a miss
        rag_status = "miss"
        has_context = False
        if context_resp and "[no-context]" not in context_resp and "None" not in context_resp:
            rag_status = "hit"
            has_context = True

        # Streaming output (SSE protocol, OpenAI-compatible chunk format)
        if request.stream:
            async def stream_generator():
                chat_id = f"chatcmpl-{secrets.token_hex(12)}"
                created_time = int(time.time())
                model_name = settings.LLM_MODEL

                # Helper: build one OpenAI-compatible chunk and wrap it as an SSE event
                # (each yield emits a "data: {json}\n\n" line)
                def openai_chunk(content=None, reasoning_content=None, finish_reason=None, extra_delta=None):
                    delta = {}
                    if content:
                        delta["content"] = content
                    if reasoning_content:
                        delta["reasoning_content"] = reasoning_content
                    if extra_delta:
                        delta.update(extra_delta)

                    chunk = {
                        "id": chat_id,
                        "object": "chat.completion.chunk",
                        "created": created_time,
                        "model": model_name,
                        "choices": [
                            {
                                "index": 0,
                                "delta": delta,
                                "finish_reason": finish_reason
                            }
                        ]
                    }
                    return f"data: {json.dumps(chunk, ensure_ascii=False)}\n\n"

                if has_context:
                    yield openai_chunk(
                        reasoning_content=f"1. Context retrieved ({len(context_resp)} characters).\n",
                        extra_delta={"x_rag_status": rag_status}
                    )
                else:
                    yield openai_chunk(
                        reasoning_content="No relevant context found\n",
                        extra_delta={"x_rag_status": rag_status}
                    )

                    # In only-RAG mode, stop here when nothing was retrieved
                    if request.only_rag:
                        yield openai_chunk(content="No relevant knowledge base content found.", finish_reason="stop")
                        yield "data: [DONE]\n\n"
                        return

                    yield openai_chunk(reasoning_content=" (falling back to the LLM's own knowledge)\n")

                    # No context found, so disable thinking mode
                    request.think = False

                yield openai_chunk(reasoning_content="2. Generating answer...\n")

                # 2. Generate the answer
                sys_prompt = CUSTOM_RAG_RESPONSE_PROMPT.format(
                    context_data=context_resp,
                    response_type="Multiple Paragraphs",
                    user_prompt=""
                )

                # Call the LLM (streaming)
                stream_resp = await llm_func(
                    request.query,
                    system_prompt=sys_prompt,
                    stream=True,
                    think=request.think,
                    hashing_kv=rag.llm_response_cache
                )

                async for chunk in stream_resp:
                    if isinstance(chunk, dict):
                        if chunk.get("type") == "thinking":
                            yield openai_chunk(reasoning_content=chunk["content"])
                        elif chunk.get("type") == "content":
                            yield openai_chunk(content=chunk["content"])
                    elif chunk:
                        yield openai_chunk(content=chunk)

                # Send the end-of-stream markers
                yield openai_chunk(finish_reason="stop")
                yield "data: [DONE]\n\n"

            # Serve with the text/event-stream content type
            return StreamingResponse(stream_generator(), media_type="text/event-stream")

        # Non-streaming output (OpenAI-compatible format)
        # Generate the answer according to the selected policy
        final_answer = ""

        if not has_context and request.only_rag:
            # Strict only-RAG mode with no retrieval hit: return immediately
            final_answer = "No relevant knowledge base content found."
        else:
            # Normal generation
            sys_prompt = CUSTOM_RAG_RESPONSE_PROMPT.format(
                context_data=context_resp,
                response_type="Multiple Paragraphs",
                user_prompt=""
            )

            # Call the LLM (a non-streaming call returns the full answer directly)
            stream_resp = await llm_func(
                request.query,
                system_prompt=sys_prompt,
                stream=False,
                think=request.think,
                hashing_kv=rag.llm_response_cache
            )
            final_answer = stream_resp

        return {
            "id": f"chatcmpl-{secrets.token_hex(12)}",
            "object": "chat.completion",
            "created": int(time.time()),
            "model": settings.LLM_MODEL,
            "choices": [
                {
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "content": final_answer,
                        # Non-standard extension field; some clients may surface it,
                        # alternatively it could live under usage/metadata
                        "x_rag_status": rag_status
                    },
                    "finish_reason": "stop"
                }
            ],
            "usage": {
                "prompt_tokens": -1,
                "completion_tokens": len(final_answer),  # character count, not a true token count
                "total_tokens": -1
            }
        }

    except Exception as e:
        logging.error(f"Query failed: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))

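# Illustrative usage (assumptions: the API is served on localhost:8000 and this
# router is mounted without a prefix; adjust both to your deployment):
#
#   # Non-streaming query
#   curl -X POST http://localhost:8000/query \
#        -H "Content-Type: application/json" \
#        -H "X-Tenant-ID: default" \
#        -d '{"query": "What does the refund policy say?", "mode": "hybrid"}'
#
#   # Streaming query (Server-Sent Events, OpenAI-compatible chunks)
#   curl -N -X POST http://localhost:8000/query \
#        -H "Content-Type: application/json" \
#        -d '{"query": "What does the refund policy say?", "stream": true}'
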
@router.post("/ingest/text")
|
||
async def ingest_text(
|
||
text: str = Body(..., embed=True),
|
||
rag: LightRAG = Depends(get_current_rag)
|
||
):
|
||
"""直接摄入文本内容"""
|
||
try:
|
||
# 使用异步方法 ainsert
|
||
await rag.ainsert(text)
|
||
return {"status": "success", "message": "Text ingested successfully"}
|
||
except Exception as e:
|
||
raise HTTPException(status_code=500, detail=str(e))
|
||
|
||
@router.post("/ingest/batch_qa")
|
||
async def ingest_batch_qa(
|
||
qa_list: list[QAPair],
|
||
rag: LightRAG = Depends(get_current_rag)
|
||
):
|
||
"""
|
||
批量摄入 QA 对
|
||
- 自动将 QA 格式化为语义连贯的文本块
|
||
- 自动合并短 QA 以优化索引效率
|
||
"""
|
||
if not qa_list:
|
||
return {"status": "skipped", "message": "Empty QA list"}
|
||
|
||
try:
|
||
# 1. 格式化并合并文本
|
||
batch_text = ""
|
||
current_batch_size = 0
|
||
MAX_BATCH_CHARS = 2000 # 约 1000 tokens,保守估计
|
||
|
||
inserted_count = 0
|
||
|
||
for qa in qa_list:
|
||
# 格式化单条 QA
|
||
entry = f"--- Q&A Entry ---\nQuestion: {qa.question}\nAnswer: {qa.answer}\n\n"
|
||
entry_len = len(entry)
|
||
|
||
# 如果当前批次过大,先提交一次
|
||
if current_batch_size + entry_len > MAX_BATCH_CHARS:
|
||
await rag.ainsert(batch_text)
|
||
batch_text = ""
|
||
current_batch_size = 0
|
||
inserted_count += 1
|
||
|
||
batch_text += entry
|
||
current_batch_size += entry_len
|
||
|
||
# 提交剩余的文本
|
||
if batch_text:
|
||
await rag.ainsert(batch_text)
|
||
inserted_count += 1
|
||
|
||
return {
|
||
"status": "success",
|
||
"message": f"Successfully processed {len(qa_list)} QA pairs into {inserted_count} text chunks."
|
||
}
|
||
except Exception as e:
|
||
logging.error(f"QA批量导入失败: {str(e)}")
|
||
raise HTTPException(status_code=500, detail=str(e))
|
||
|
||
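# Illustrative request body for POST /ingest/batch_qa (a JSON array of QAPair
# objects; the content below is a made-up example):
#
#   [
#     {"question": "What are the support hours?", "answer": "Weekdays 9:00-18:00."},
#     {"question": "How do I open a ticket?", "answer": "Email support@example.com."}
#   ]
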
@router.post("/ingest/file", response_model=IngestResponse)
|
||
async def upload_file(
|
||
file: UploadFile = File(...),
|
||
rag: LightRAG = Depends(get_current_rag)
|
||
):
|
||
"""
|
||
文件上传与索引接口
|
||
支持 .txt, .md, .pdf
|
||
"""
|
||
try:
|
||
content = ""
|
||
filename = file.filename
|
||
|
||
# 读取文件内容
|
||
file_bytes = await file.read()
|
||
|
||
# 根据文件类型解析
|
||
if filename.endswith(".pdf"):
|
||
# PDF 解析逻辑 (支持图文)
|
||
content = await process_pdf_with_images(file_bytes)
|
||
elif filename.endswith(".txt") or filename.endswith(".md"):
|
||
# 文本文件直接解码
|
||
content = file_bytes.decode("utf-8")
|
||
else:
|
||
return IngestResponse(
|
||
filename=filename,
|
||
status="skipped",
|
||
message="Unsupported file format. Only PDF, TXT, MD supported."
|
||
)
|
||
|
||
if not content.strip():
|
||
return IngestResponse(
|
||
filename=filename,
|
||
status="failed",
|
||
message="Empty content extracted."
|
||
)
|
||
|
||
# 插入 LightRAG 索引 (这是一个耗时操作)
|
||
# 使用异步方法 ainsert
|
||
await rag.ainsert(content)
|
||
|
||
return IngestResponse(
|
||
filename=filename,
|
||
status="success",
|
||
message=f"Successfully indexed {len(content)} characters."
|
||
)
|
||
|
||
except Exception as e:
|
||
logging.error(f"处理文件 {file.filename} 失败: {str(e)}")
|
||
raise HTTPException(status_code=500, detail=f"Processing failed: {str(e)}")
|
||
|
||
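# Illustrative upload (multipart/form-data; host, tenant, and file path are assumptions):
#
#   curl -X POST http://localhost:8000/ingest/file \
#        -H "X-Tenant-ID: default" \
#        -F "file=@./docs/manual.pdf"
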
@router.get("/documents")
|
||
async def list_documents(
|
||
rag: LightRAG = Depends(get_current_rag)
|
||
):
|
||
"""
|
||
获取文档列表
|
||
返回当前知识库中已索引的所有文档及其状态
|
||
"""
|
||
try:
|
||
# 从 rag 实例获取工作目录
|
||
doc_status_path = os.path.join(rag.working_dir, "kv_store_doc_status.json")
|
||
if not os.path.exists(doc_status_path):
|
||
return {"count": 0, "docs": []}
|
||
|
||
with open(doc_status_path, "r", encoding="utf-8") as f:
|
||
data = json.load(f)
|
||
|
||
# 格式化返回结果
|
||
docs = []
|
||
for doc_id, info in data.items():
|
||
docs.append({
|
||
"id": doc_id,
|
||
"summary": info.get("content_summary", "")[:100] + "...", # 摘要截断
|
||
"length": info.get("content_length", 0),
|
||
"created_at": info.get("created_at"),
|
||
"status": info.get("status")
|
||
})
|
||
|
||
return {
|
||
"count": len(docs),
|
||
"docs": docs
|
||
}
|
||
except Exception as e:
|
||
logging.error(f"获取文档列表失败: {str(e)}")
|
||
raise HTTPException(status_code=500, detail=str(e))
|
||
|
||
@router.get("/documents/{doc_id}")
|
||
async def get_document_detail(
|
||
doc_id: str,
|
||
rag: LightRAG = Depends(get_current_rag)
|
||
):
|
||
"""
|
||
获取文档详情
|
||
返回指定文档的完整内容
|
||
"""
|
||
try:
|
||
# 读取 kv_store_full_docs.json
|
||
full_docs_path = os.path.join(rag.working_dir, "kv_store_full_docs.json")
|
||
if not os.path.exists(full_docs_path):
|
||
raise HTTPException(status_code=404, detail="Document store not found")
|
||
|
||
with open(full_docs_path, "r", encoding="utf-8") as f:
|
||
data = json.load(f)
|
||
|
||
if doc_id not in data:
|
||
raise HTTPException(status_code=404, detail="Document not found")
|
||
|
||
doc_info = data[doc_id]
|
||
return {
|
||
"id": doc_id,
|
||
"content": doc_info.get("content", ""),
|
||
"create_time": doc_info.get("create_time"),
|
||
"file_path": doc_info.get("file_path", "unknown")
|
||
}
|
||
except HTTPException:
|
||
raise
|
||
except Exception as e:
|
||
logging.error(f"获取文档详情失败: {str(e)}")
|
||
raise HTTPException(status_code=500, detail=str(e))
|
||
|
||
@router.delete("/docs/{doc_id}")
|
||
async def delete_document(
|
||
doc_id: str,
|
||
rag: LightRAG = Depends(get_current_rag)
|
||
):
|
||
"""
|
||
删除指定文档
|
||
- doc_id: 文档ID (例如 doc-xxxxx)
|
||
"""
|
||
try:
|
||
logging.info(f"正在删除文档: {doc_id}")
|
||
# 调用 LightRAG 的删除方法
|
||
await rag.adelete_by_doc_id(doc_id)
|
||
return {"status": "success", "message": f"Document {doc_id} deleted successfully"}
|
||
except Exception as e:
|
||
logging.error(f"删除文档 {doc_id} 失败: {str(e)}")
|
||
raise HTTPException(status_code=500, detail=f"Delete failed: {str(e)}")
|