ai-lightrag/app/api/routes.py

492 lines
17 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import json
import logging
import os
import time
from fastapi import APIRouter, UploadFile, File, HTTPException, Body, Header, Depends
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from lightrag import LightRAG, QueryParam
from app.core.rag import get_rag_manager, llm_func
from app.core.prompts import CUSTOM_RAG_RESPONSE_PROMPT
from app.config import settings
from app.core.ingest import process_pdf_with_images
# Shared router: every endpoint in this module registers on it via the
# @router.<method>(...) decorators below.
router = APIRouter()
# ==========================================
# 依赖注入
# ==========================================
async def get_current_rag(
    x_tenant_id: str = Header("default", alias="X-Tenant-ID")
) -> LightRAG:
    """FastAPI dependency that resolves the LightRAG instance of the caller's tenant.

    The tenant is taken from the ``X-Tenant-ID`` request header; when the
    header is absent the tenant ``"default"`` is used.
    """
    rag_instance = await get_rag_manager().get_rag(x_tenant_id)
    return rag_instance
# ==========================================
# 数据模型
# ==========================================
class QueryRequest(BaseModel):
    """Request body accepted by the /query endpoint."""
    query: str  # The user's question
    mode: str = "hybrid"  # Retrieval mode; options: naive, local, global, hybrid
    top_k: int = 5  # Passed through to QueryParam.top_k for retrieval
    stream: bool = False  # Return an SSE stream instead of a single JSON response
    think: bool = False  # Forwarded to llm_func as `think` (thinking mode)
    only_rag: bool = False  # Answer only from RAG retrieval results, without falling back to the LLM
class IngestResponse(BaseModel):
    """Response model returned by the /ingest/file endpoint."""
    filename: str  # Name of the uploaded file
    status: str  # Set by upload_file to "success", "skipped" or "failed"
    message: str  # Human-readable detail about the outcome
class QAPair(BaseModel):
    """A single question/answer pair accepted by the /ingest/batch_qa endpoint."""
    question: str
    answer: str
# ==========================================
# 接口实现
# ==========================================
import secrets
import string
def _generate_tenant_secret(length: int = 16) -> str:
    """Return a cryptographically random alphanumeric secret of ``length`` chars."""
    alphabet = string.ascii_letters + string.digits
    return "".join(secrets.choice(alphabet) for _ in range(length))


def _load_or_create_tenant_secret(tenant_dir: str) -> str:
    """Return the secret persisted in ``<tenant_dir>/.secret``, creating it if needed.

    A missing or blank file gets a freshly generated secret. When the file has
    to be created it is created with owner-only permissions (0o600), since it
    holds a credential.

    Raises:
        OSError: if the secret file cannot be read or written.
    """
    secret_file = os.path.join(tenant_dir, ".secret")
    try:
        with open(secret_file, "r", encoding="utf-8") as f:
            secret = f.read().strip()
    except FileNotFoundError:
        secret = ""
    if secret:
        return secret
    # A blank secret would otherwise produce the guessable token "<tenant_id>_".
    secret = _generate_tenant_secret()
    fd = os.open(secret_file, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
    with os.fdopen(fd, "w", encoding="utf-8") as f:
        f.write(secret)
    return secret


@router.get("/admin/tenants")
async def list_tenants(token: str):
    """Admin endpoint: list tenants together with their access tokens.

    Every non-hidden sub-directory of ``settings.DATA_DIR`` is treated as a
    tenant. A tenant's token is ``"<tenant_id>_<secret>"``; the secret is
    generated and persisted the first time the tenant is listed. Tenants are
    returned sorted by id. A tenant whose secret file cannot be read or
    written is logged and skipped.

    Args:
        token: Admin token; must match ``settings.ADMIN_TOKEN``.

    Raises:
        HTTPException: 403 for a wrong token (or when no admin token is
            configured), 500 for unexpected failures.
    """
    admin_token = settings.ADMIN_TOKEN
    # Constant-time comparison avoids leaking the token via response timing;
    # an unset/empty ADMIN_TOKEN must never grant access (e.g. to "?token=").
    if not admin_token or not secrets.compare_digest(
        token.encode("utf-8"), str(admin_token).encode("utf-8")
    ):
        raise HTTPException(status_code=403, detail="Invalid admin token")
    try:
        if not os.path.exists(settings.DATA_DIR):
            return {"tenants": []}
        # Context manager closes the directory handle deterministically.
        with os.scandir(settings.DATA_DIR) as entries:
            tenant_dirs = sorted(
                (e for e in entries if e.is_dir() and not e.name.startswith(".")),
                key=lambda e: e.name,
            )
        tenants = []
        for entry in tenant_dirs:
            tenant_id = entry.name
            try:
                secret = _load_or_create_tenant_secret(entry.path)
            except OSError as e:
                logging.error(f"Failed to read or write secret for tenant {tenant_id}: {e}")
                continue
            tenants.append({
                "id": tenant_id,
                "token": f"{tenant_id}_{secret}"
            })
        return {"tenants": tenants}
    except Exception as e:
        logging.error(f"Failed to list tenants: {e}")
        raise HTTPException(status_code=500, detail=str(e))
@router.get("/health")
async def health_check():
    """Liveness probe: report that the service is up and which LLM model is configured."""
    payload = {"status": "ok"}
    payload["llm"] = settings.LLM_MODEL
    return payload
def _has_rag_context(context_resp) -> bool:
    """Return True only when the retrieved context contains real document chunks.

    Entities alone are not counted as a hit because generic words easily
    produce spurious entity matches; a hit requires a "Document Chunks"
    section that carries at least one JSON ``"content":`` or ``"text":`` field.
    """
    # Non-string results (e.g. None) cannot carry chunk text.
    if not isinstance(context_resp, str):
        return False
    stripped = context_resp.strip()
    # Compare against "None" exactly: a substring test would wrongly reject any
    # context whose text merely contains the word (e.g. "Nonetheless").
    if not stripped or stripped == "None" or "[no-context]" in context_resp:
        return False
    _, found, chunks_part = context_resp.partition("Document Chunks")
    if not found:
        return False
    return '"content":' in chunks_part or '"text":' in chunks_part


def _build_system_prompt(context_resp) -> str:
    """Render the RAG answer prompt around the retrieved context."""
    return CUSTOM_RAG_RESPONSE_PROMPT.format(
        context_data=context_resp,
        response_type="Multiple Paragraphs",
        user_prompt=""
    )


@router.post("/query")
async def query_knowledge_base(
    request: QueryRequest,
    rag: LightRAG = Depends(get_current_rag)
):
    """Answer a question against the tenant's knowledge base.

    - query: the user's question
    - mode: retrieval mode (hybrid is recommended for factual questions)
    - stream: stream the answer as OpenAI-compatible SSE chunks (default False)
    - think: enable thinking mode (default False); it is only used when
      relevant context was retrieved, in both streaming and normal mode
    - only_rag: when nothing relevant is retrieved, reply with a fixed
      "not found" message instead of letting the LLM answer on its own

    The response (or each streamed delta) carries a non-standard
    ``x_rag_status`` field set to "hit" or "miss".

    Raises:
        HTTPException: 500 if retrieval or non-streaming generation fails.
            Errors after streaming has started are reported in-band as an SSE
            ``error`` event followed by ``[DONE]``.
    """
    try:
        context_param = QueryParam(
            mode=request.mode,
            top_k=request.top_k,
            only_need_context=True,
            enable_rerank=settings.RERANK_ENABLED
        )
        # Retrieve context only; the answer is generated separately below.
        context_resp = await rag.aquery(request.query, param=context_param)
        # The full context can be very large (and contain user data): log its
        # size at INFO and the body only at DEBUG (lazy formatting).
        context_len = len(context_resp) if isinstance(context_resp, str) else 0
        logging.info(f"Context response length: {context_len}")
        logging.debug("Context response: %s", context_resp)

        has_context = _has_rag_context(context_resp)
        rag_status = "hit" if has_context else "miss"
        # Thinking mode is pointless without retrieved context; use a local flag
        # instead of mutating the request model.
        use_think = request.think and has_context

        if request.stream:
            async def stream_generator():
                chat_id = f"chatcmpl-{secrets.token_hex(12)}"
                created_time = int(time.time())
                model_name = settings.LLM_MODEL

                def openai_chunk(content=None, reasoning_content=None, finish_reason=None, extra_delta=None):
                    """Serialize one OpenAI ``chat.completion.chunk`` as an SSE event."""
                    delta = {}
                    if content:
                        delta["content"] = content
                    if reasoning_content:
                        delta["reasoning_content"] = reasoning_content
                    if extra_delta:
                        delta.update(extra_delta)
                    chunk = {
                        "id": chat_id,
                        "object": "chat.completion.chunk",
                        "created": created_time,
                        "model": model_name,
                        "choices": [
                            {
                                "index": 0,
                                "delta": delta,
                                "finish_reason": finish_reason
                            }
                        ]
                    }
                    return f"data: {json.dumps(chunk, ensure_ascii=False)}\n\n"

                try:
                    if has_context:
                        yield openai_chunk(
                            reasoning_content=f"1. 上下文已检索 (长度: {len(context_resp)} 字符).\n",
                            extra_delta={"x_rag_status": rag_status}
                        )
                    else:
                        yield openai_chunk(
                            reasoning_content="未找到相关上下文\n",
                            extra_delta={"x_rag_status": rag_status}
                        )
                        # Strict RAG-only mode with nothing retrieved: end the stream here.
                        if request.only_rag:
                            yield openai_chunk(content="未找到相关知识库内容。", finish_reason="stop")
                            yield "data: [DONE]\n\n"
                            return
                        yield openai_chunk(reasoning_content=" (将依赖 LLM 自身知识)\n")
                    yield openai_chunk(reasoning_content="2. 答案生成中...\n")

                    stream_resp = await llm_func(
                        request.query,
                        system_prompt=_build_system_prompt(context_resp),
                        stream=True,
                        think=use_think,
                        hashing_kv=rag.llm_response_cache
                    )
                    async for piece in stream_resp:
                        if isinstance(piece, dict):
                            piece_type = piece.get("type")
                            if piece_type == "thinking":
                                yield openai_chunk(reasoning_content=piece["content"])
                            elif piece_type == "content":
                                yield openai_chunk(content=piece["content"])
                        elif piece:
                            yield openai_chunk(content=piece)
                    yield openai_chunk(finish_reason="stop")
                except Exception as e:
                    # The generator runs after the HTTP status has been sent, so
                    # the outer handler cannot convert this into a 500; report the
                    # failure in-band and still terminate the stream cleanly.
                    logging.exception(f"Streaming query failed: {e}")
                    error_event = {"error": {"message": str(e), "type": "server_error"}}
                    yield f"data: {json.dumps(error_event, ensure_ascii=False)}\n\n"
                yield "data: [DONE]\n\n"

            return StreamingResponse(stream_generator(), media_type="text/event-stream")

        # Non-streaming (OpenAI-compatible chat.completion) response.
        if not has_context and request.only_rag:
            # Strict mode and nothing retrieved: answer with the fixed message.
            final_answer = "未找到相关知识库内容。"
        else:
            final_answer = await llm_func(
                request.query,
                system_prompt=_build_system_prompt(context_resp),
                stream=False,
                think=use_think,
                hashing_kv=rag.llm_response_cache
            )

        return {
            "id": f"chatcmpl-{secrets.token_hex(12)}",
            "object": "chat.completion",
            "created": int(time.time()),
            "model": settings.LLM_MODEL,
            "choices": [
                {
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "content": final_answer,
                        # Non-standard extension field reporting retrieval hit/miss.
                        "x_rag_status": rag_status
                    },
                    "finish_reason": "stop"
                }
            ],
            # Token counts are not tracked: -1 means unknown, and
            # completion_tokens is the answer's character count.
            "usage": {
                "prompt_tokens": -1,
                "completion_tokens": len(final_answer),
                "total_tokens": -1
            }
        }
    except Exception as e:
        logging.error(f"查询失败: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
@router.post("/ingest/text")
async def ingest_text(
    text: str = Body(..., embed=True),
    rag: LightRAG = Depends(get_current_rag)
):
    """Ingest raw text directly into the tenant's knowledge base.

    The request body is ``{"text": "..."}``. Blank text is skipped, mirroring
    how /ingest/batch_qa handles an empty list, rather than being sent to the
    indexer.

    Raises:
        HTTPException: 500 if indexing fails.
    """
    if not text.strip():
        return {"status": "skipped", "message": "Empty text"}
    try:
        await rag.ainsert(text)
        return {"status": "success", "message": "Text ingested successfully"}
    except Exception as e:
        # Log like the other endpoints so failures are not only visible to the client.
        logging.error(f"Text ingestion failed: {e}")
        raise HTTPException(status_code=500, detail=str(e))
# Character budget per inserted batch (roughly 1000 tokens, a conservative estimate).
_QA_MAX_BATCH_CHARS = 2000


def _pack_qa_batches(qa_list, max_chars=_QA_MAX_BATCH_CHARS):
    """Format QA pairs and greedily pack them into text batches.

    A batch is closed before an entry that would push it past ``max_chars``.
    An entry longer than ``max_chars`` on its own becomes a single-entry
    batch, and no empty batch is ever produced.
    """
    batches = []
    current = []
    current_size = 0
    for qa in qa_list:
        entry = f"--- Q&A Entry ---\nQuestion: {qa.question}\nAnswer: {qa.answer}\n\n"
        # Only flush a non-empty batch; otherwise an oversized first entry
        # would cause an empty string to be inserted.
        if current and current_size + len(entry) > max_chars:
            batches.append("".join(current))
            current = []
            current_size = 0
        current.append(entry)
        current_size += len(entry)
    if current:
        batches.append("".join(current))
    return batches


@router.post("/ingest/batch_qa")
async def ingest_batch_qa(
    qa_list: list[QAPair],
    rag: LightRAG = Depends(get_current_rag)
):
    """Bulk-ingest QA pairs.

    - Each pair is formatted as a self-contained "--- Q&A Entry ---" text block.
    - Short entries are merged into batches of at most ``_QA_MAX_BATCH_CHARS``
      characters to reduce the number of ``ainsert`` calls; an entry longer
      than the budget is inserted on its own.
    - Batches are inserted sequentially, so a failure part-way through leaves
      the earlier batches indexed (the error log reports how many).

    Raises:
        HTTPException: 500 if any insertion fails.
    """
    if not qa_list:
        return {"status": "skipped", "message": "Empty QA list"}
    inserted_count = 0
    try:
        for batch_text in _pack_qa_batches(qa_list):
            await rag.ainsert(batch_text)
            inserted_count += 1
        return {
            "status": "success",
            "message": f"Successfully processed {len(qa_list)} QA pairs into {inserted_count} text chunks."
        }
    except Exception as e:
        logging.error(f"QA批量导入失败: {str(e)} (batches inserted before failure: {inserted_count})")
        raise HTTPException(status_code=500, detail=str(e))
@router.post("/ingest/file", response_model=IngestResponse)
async def upload_file(
    file: UploadFile = File(...),
    rag: LightRAG = Depends(get_current_rag)
):
    """Upload a file and index its content.

    Supports .txt, .md and .pdf; the extension is matched case-insensitively.
    PDFs go through ``process_pdf_with_images``; text files must be UTF-8,
    and a leading BOM is stripped. Unsupported formats return status
    "skipped"; undecodable text or empty extracted content returns "failed".

    Raises:
        HTTPException: 500 if parsing or indexing fails unexpectedly.
    """
    # UploadFile.filename may be None when the client omits it.
    filename = file.filename or ""
    lower_name = filename.lower()
    try:
        # Check the format first so unsupported uploads are not read into memory.
        if lower_name.endswith(".pdf"):
            file_bytes = await file.read()
            # PDF parsing (text plus images).
            content = await process_pdf_with_images(file_bytes)
        elif lower_name.endswith((".txt", ".md")):
            file_bytes = await file.read()
            try:
                # "utf-8-sig" drops a leading BOM that would otherwise end up in the index.
                content = file_bytes.decode("utf-8-sig")
            except UnicodeDecodeError:
                return IngestResponse(
                    filename=filename,
                    status="failed",
                    message="File is not valid UTF-8 text."
                )
        else:
            return IngestResponse(
                filename=filename,
                status="skipped",
                message="Unsupported file format. Only PDF, TXT, MD supported."
            )

        if not content or not content.strip():
            return IngestResponse(
                filename=filename,
                status="failed",
                message="Empty content extracted."
            )

        # Indexing into LightRAG can take a long time.
        await rag.ainsert(content)
        return IngestResponse(
            filename=filename,
            status="success",
            message=f"Successfully indexed {len(content)} characters."
        )
    except Exception as e:
        logging.error(f"处理文件 {filename} 失败: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Processing failed: {str(e)}")
def _truncate_summary(summary, limit=100):
    """Return ``summary`` clipped to ``limit`` characters.

    "..." is appended only when text was actually removed. A missing or None
    summary becomes an empty string.
    """
    text = summary or ""
    if len(text) <= limit:
        return text
    return text[:limit] + "..."


@router.get("/documents")
async def list_documents(
    rag: LightRAG = Depends(get_current_rag)
):
    """List the documents recorded in the tenant's document-status store.

    Reads ``kv_store_doc_status.json`` from ``rag.working_dir`` and returns
    ``{"count": 0, "docs": []}`` when that file does not exist.

    Raises:
        HTTPException: 500 if the store cannot be read or parsed.
    """
    try:
        doc_status_path = os.path.join(rag.working_dir, "kv_store_doc_status.json")
        if not os.path.exists(doc_status_path):
            return {"count": 0, "docs": []}
        # NOTE(review): this reads LightRAG's JSON storage file directly, so it
        # presumably only reflects what LightRAG has already flushed to disk — confirm.
        with open(doc_status_path, "r", encoding="utf-8") as f:
            data = json.load(f)
        docs = [
            {
                "id": doc_id,
                "summary": _truncate_summary(info.get("content_summary")),
                "length": info.get("content_length", 0),
                "created_at": info.get("created_at"),
                "status": info.get("status")
            }
            for doc_id, info in data.items()
        ]
        return {
            "count": len(docs),
            "docs": docs
        }
    except Exception as e:
        logging.error(f"获取文档列表失败: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
@router.get("/documents/{doc_id}")
async def get_document_detail(
    doc_id: str,
    rag: LightRAG = Depends(get_current_rag)
):
    """Return the complete stored content of a single document.

    The document is looked up in ``kv_store_full_docs.json`` inside the
    tenant's working directory.

    Raises:
        HTTPException: 404 if the store file or the document is missing,
            500 on any other failure.
    """
    try:
        store_path = os.path.join(rag.working_dir, "kv_store_full_docs.json")
        if not os.path.exists(store_path):
            raise HTTPException(status_code=404, detail="Document store not found")
        with open(store_path, "r", encoding="utf-8") as store_file:
            full_docs = json.load(store_file)
        if doc_id not in full_docs:
            raise HTTPException(status_code=404, detail="Document not found")
        record = full_docs[doc_id]
        content = record.get("content", "")
        create_time = record.get("create_time")
        file_path = record.get("file_path", "unknown")
        return {
            "id": doc_id,
            "content": content,
            "create_time": create_time,
            "file_path": file_path,
        }
    except HTTPException:
        # Let the intentional 404s through unchanged.
        raise
    except Exception as e:
        logging.error(f"获取文档详情失败: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
@router.delete("/docs/{doc_id}")
async def delete_document(
    doc_id: str,
    rag: LightRAG = Depends(get_current_rag)
):
    """Delete a document from the tenant's knowledge base.

    - doc_id: document id (e.g. doc-xxxxx)

    Raises:
        HTTPException: 404 when LightRAG reports the document does not exist,
            500 when deletion fails.
    """
    try:
        logging.info(f"正在删除文档: {doc_id}")
        result = await rag.adelete_by_doc_id(doc_id)
        # Recent LightRAG releases return a DeletionResult whose ``status`` is
        # "success" / "not_found" / "fail"; older ones return None. Read it
        # defensively so "not found" and failures are not reported as success.
        status = getattr(result, "status", None)
        if status == "not_found":
            raise HTTPException(status_code=404, detail=f"Document {doc_id} not found")
        if status == "fail":
            reason = getattr(result, "message", "") or status
            raise HTTPException(status_code=500, detail=f"Delete failed: {reason}")
        return {"status": "success", "message": f"Document {doc_id} deleted successfully"}
    except HTTPException:
        raise
    except Exception as e:
        logging.error(f"删除文档 {doc_id} 失败: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Delete failed: {str(e)}")