import asyncio
import io
import json
import logging
import os

from fastapi import APIRouter, UploadFile, File, HTTPException, Body, Header, Depends
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from pypdf import PdfReader
from lightrag import LightRAG, QueryParam

from app.core.rag import get_rag_manager, llm_func
from app.core.prompts import CUSTOM_RAG_RESPONSE_PROMPT
from app.config import settings

logger = logging.getLogger(__name__)

router = APIRouter()

# Text that marks a context response in which nothing relevant was retrieved.
# NOTE(review): assumed to match LightRAG's "fail" response text — verify when
# upgrading LightRAG.
_NO_CONTEXT_MARKER = "Sorry, I'm not able to answer"

# Maximum number of summary characters returned per document by /documents.
_SUMMARY_MAX_CHARS = 100

# Plain-text extensions accepted by /ingest/file (compared case-insensitively).
_TEXT_EXTENSIONS = (".txt", ".md")

# States of the model's "thinking" section while relaying a streamed answer.
_THINK_NOT_STARTED, _THINK_IN_PROGRESS, _THINK_DONE = 0, 1, 2


# ==========================================
# Dependency injection
# ==========================================
async def get_current_rag(
    x_tenant_id: str = Header("default", alias="X-Tenant-ID")
) -> LightRAG:
    """
    Dependency: return the LightRAG instance for the current tenant.

    The tenant is read from the X-Tenant-ID header and defaults to 'default'.
    """
    manager = get_rag_manager()
    return await manager.get_rag(x_tenant_id)


# ==========================================
# Data models
# ==========================================
class QueryRequest(BaseModel):
    query: str
    mode: str = "hybrid"  # One of: naive, local, global, hybrid
    top_k: int = 5
    stream: bool = False


class IngestResponse(BaseModel):
    filename: str
    status: str
    message: str


# ==========================================
# Helpers
# ==========================================
def _sse_pack(event: str, text: str) -> str:
    """Format one Server-Sent Events frame.

    The payload is wrapped in JSON so newlines and special characters in
    ``text`` are escaped and cannot break the SSE framing.
    """
    data = json.dumps({"text": text}, ensure_ascii=False)
    return f"event: {event}\ndata: {data}\n\n"


async def _stream_answer(request: QueryRequest, rag: LightRAG):
    """Yield SSE frames for a streamed RAG answer.

    Emits progress messages as ``thinking`` events, relays model reasoning
    chunks as ``thinking`` events and answer chunks as ``answer`` events.
    Because this generator runs after the HTTP response has started, any
    failure is logged and reported in-band as an ``error`` event instead of
    an HTTP status code.
    """
    try:
        yield _sse_pack("thinking", "1. 上下文检索中...\n")

        # Step 1: retrieve context only (this step is slow; it includes graph traversal).
        context_param = QueryParam(
            mode=request.mode,
            top_k=request.top_k,
            only_need_context=True,
            enable_rerank=False,  # Explicitly disable rerank to silence its warning
        )
        context_resp = await rag.aquery(request.query, param=context_param)

        # Rough check for whether any context was found.
        if not context_resp or _NO_CONTEXT_MARKER in context_resp:
            yield _sse_pack("thinking", " (未找到相关上下文,将依赖 LLM 自身知识)\n")
        else:
            yield _sse_pack("thinking", f"2. 上下文已检索 (长度: {len(context_resp)} 字符).\n")

        yield _sse_pack("thinking", "3. 答案生成中...\n")

        # Step 2: generate the answer with a manually built system prompt.
        sys_prompt = CUSTOM_RAG_RESPONSE_PROMPT.format(
            context_data=context_resp,
            response_type="Multiple Paragraphs",
            user_prompt="",
        )
        stream_resp = await llm_func(
            request.query,
            system_prompt=sys_prompt,
            stream=True,
            hashing_kv=rag.llm_response_cache,
        )

        # Guard in case llm_func returns a complete string (e.g. a cached
        # response) instead of an async iterator — assumption, verify against
        # app.core.rag.
        if isinstance(stream_resp, str):
            if stream_resp:
                yield _sse_pack("answer", stream_resp)
            return

        think_state = _THINK_NOT_STARTED
        async for chunk in stream_resp:
            if isinstance(chunk, dict):
                chunk_type = chunk.get("type")
                content = chunk.get("content", "")
                if chunk_type == "thinking":
                    if think_state == _THINK_NOT_STARTED:
                        yield _sse_pack("thinking", "\n思考:\n")
                        think_state = _THINK_IN_PROGRESS
                    yield _sse_pack("thinking", content)
                elif chunk_type == "content":
                    if think_state == _THINK_IN_PROGRESS:
                        # Visual separator between the reasoning and the answer.
                        yield _sse_pack("none", "\n\n\n")
                        think_state = _THINK_DONE
                    yield _sse_pack("answer", content)
            elif chunk:
                yield _sse_pack("answer", chunk)
    except Exception as e:
        logger.exception("流式查询失败: %s", e)
        yield _sse_pack("error", str(e))


def _extract_pdf_text(file_bytes: bytes) -> str:
    """Return the text of every PDF page, each followed by a newline."""
    reader = PdfReader(io.BytesIO(file_bytes))
    return "".join((page.extract_text() or "") + "\n" for page in reader.pages)


def _read_json_file(path: str):
    """Load and return the JSON document stored at ``path`` (UTF-8)."""
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)


def _truncate_summary(summary, limit: int = _SUMMARY_MAX_CHARS) -> str:
    """Return ``summary`` cut to ``limit`` chars, with "..." only if it was cut.

    ``None`` is treated as an empty summary.
    """
    summary = summary or ""
    if len(summary) <= limit:
        return summary
    return summary[:limit] + "..."


# ==========================================
# Endpoints
# ==========================================
@router.get("/health")
async def health_check():
    """Health check endpoint."""
    return {"status": "ok", "llm": settings.LLM_MODEL}


@router.post("/query")
async def query_knowledge_base(
    request: QueryRequest,
    rag: LightRAG = Depends(get_current_rag)
):
    """
    Query the knowledge base.

    - query: the user's question
    - mode: retrieval mode (hybrid is recommended for factual questions)
    - stream: whether to stream the output as SSE (default False)
    """
    if request.stream:
        # All work happens inside the generator, after the response has
        # started, so errors are reported in-band by _stream_answer.
        return StreamingResponse(
            _stream_answer(request, rag), media_type="text/event-stream"
        )

    try:
        param = QueryParam(
            mode=request.mode,
            top_k=request.top_k,
            stream=False,
            enable_rerank=False,  # Explicitly disable rerank to silence its warning
        )
        response = await rag.aquery(request.query, param=param)
        return {"response": response}
    except Exception as e:
        logger.exception("查询失败: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e


@router.post("/ingest/text")
async def ingest_text(
    text: str = Body(..., embed=True),
    rag: LightRAG = Depends(get_current_rag)
):
    """Ingest raw text content directly."""
    try:
        await rag.ainsert(text)
        return {"status": "success", "message": "Text ingested successfully"}
    except Exception as e:
        logger.exception("摄入文本失败: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e


@router.post("/ingest/file", response_model=IngestResponse)
async def upload_file(
    file: UploadFile = File(...),
    rag: LightRAG = Depends(get_current_rag)
):
    """
    Upload a file and index it.

    Supports .txt, .md and .pdf (extension matched case-insensitively).
    Text files must be UTF-8; a leading BOM is stripped.
    """
    # The client may omit the filename; treat that as an unsupported type.
    filename = file.filename or ""
    try:
        file_bytes = await file.read()
        lowered = filename.lower()

        if lowered.endswith(".pdf"):
            # PDF parsing is CPU-bound; run it in a worker thread so the
            # event loop keeps serving other requests.
            content = await asyncio.to_thread(_extract_pdf_text, file_bytes)
        elif lowered.endswith(_TEXT_EXTENSIONS):
            try:
                content = file_bytes.decode("utf-8-sig")
            except UnicodeDecodeError:
                return IngestResponse(
                    filename=filename,
                    status="failed",
                    message="File is not valid UTF-8 text."
                )
        else:
            return IngestResponse(
                filename=filename,
                status="skipped",
                message="Unsupported file format. Only PDF, TXT, MD supported."
            )

        if not content.strip():
            return IngestResponse(
                filename=filename,
                status="failed",
                message="Empty content extracted."
            )

        # Insert into the LightRAG index (slow operation).
        await rag.ainsert(content)

        return IngestResponse(
            filename=filename,
            status="success",
            message=f"Successfully indexed {len(content)} characters."
        )
    except Exception as e:
        logger.exception("处理文件 %s 失败: %s", filename, e)
        raise HTTPException(status_code=500, detail=f"Processing failed: {str(e)}") from e


@router.get("/documents")
async def list_documents(
    rag: LightRAG = Depends(get_current_rag)
):
    """
    List documents.

    Returns every document recorded in the tenant's
    ``kv_store_doc_status.json`` together with its status.
    """
    try:
        doc_status_path = os.path.join(rag.working_dir, "kv_store_doc_status.json")
        try:
            data = await asyncio.to_thread(_read_json_file, doc_status_path)
        except FileNotFoundError:
            # Nothing has been indexed yet for this tenant.
            return {"count": 0, "docs": []}

        docs = [
            {
                "id": doc_id,
                "summary": _truncate_summary(info.get("content_summary")),
                "length": info.get("content_length", 0),
                "created_at": info.get("created_at"),
                "status": info.get("status"),
            }
            for doc_id, info in data.items()
        ]
        return {"count": len(docs), "docs": docs}
    except Exception as e:
        logger.exception("获取文档列表失败: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e


@router.delete("/docs/{doc_id}")
async def delete_document(
    doc_id: str,
    rag: LightRAG = Depends(get_current_rag)
):
    """
    Delete a document.

    - doc_id: document ID (e.g. doc-xxxxx)
    """
    try:
        logger.info("正在删除文档: %s", doc_id)
        await rag.adelete_by_doc_id(doc_id)
        return {"status": "success", "message": f"Document {doc_id} deleted successfully"}
    except Exception as e:
        logger.exception("删除文档 %s 失败: %s", doc_id, e)
        raise HTTPException(status_code=500, detail=f"Delete failed: {str(e)}") from e