import json
import logging
import os
import secrets
import string
import time

from fastapi import APIRouter, UploadFile, File, HTTPException, Body, Header, Depends
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from lightrag import LightRAG, QueryParam

from app.core.rag import get_rag_manager, llm_func
from app.core.prompts import CUSTOM_RAG_RESPONSE_PROMPT
from app.config import settings
from app.core.ingest import process_pdf_with_images

router = APIRouter()

# Answer returned when only_rag is set and retrieval found nothing.
_NO_KB_CONTENT_MSG = "未找到相关知识库内容。"

# Per-tenant secret: 16 random alphanumeric characters drawn with `secrets`.
_SECRET_ALPHABET = string.ascii_letters + string.digits
_SECRET_LENGTH = 16

# Max characters per merged QA text block (~1000 tokens, a conservative estimate).
_MAX_QA_BATCH_CHARS = 2000

# Max characters of a document summary returned by /documents.
_SUMMARY_MAX_CHARS = 100


# ==========================================
# Dependency injection
# ==========================================
async def get_current_rag(
    x_tenant_id: str = Header("default", alias="X-Tenant-ID")
) -> LightRAG:
    """Dependency: return the LightRAG instance for the current tenant.

    The tenant id is read from the ``X-Tenant-ID`` header and defaults to
    ``"default"`` when the header is absent.

    NOTE(review): the header is not checked against the tenant token produced
    by ``list_tenants`` here; unless that is enforced elsewhere (middleware?),
    any client can select any tenant — confirm.

    Raises:
        HTTPException(400): if the id is empty, starts with ``.``, or contains
            ``/``, ``\\`` or NUL. ``list_tenants`` treats subdirectories of
            ``settings.DATA_DIR`` as tenants, so the id is presumably used as a
            directory name by the RAG manager and must not escape it.
    """
    if (
        not x_tenant_id
        or x_tenant_id.startswith(".")
        or any(bad in x_tenant_id for bad in ("/", "\\", "\x00"))
    ):
        raise HTTPException(status_code=400, detail="Invalid X-Tenant-ID header")
    manager = get_rag_manager()
    return await manager.get_rag(x_tenant_id)


# ==========================================
# Data models
# ==========================================
class QueryRequest(BaseModel):
    query: str
    mode: str = "hybrid"  # one of: naive, local, global, hybrid
    top_k: int = 5
    stream: bool = False
    think: bool = False
    only_rag: bool = False  # answer from RAG retrieval only; no LLM fallback on a miss


class IngestResponse(BaseModel):
    filename: str
    status: str
    message: str


class QAPair(BaseModel):
    question: str
    answer: str


# ==========================================
# Internal helpers
# ==========================================
def _load_or_create_tenant_secret(tenant_dir: str) -> str:
    """Return the secret stored in ``<tenant_dir>/.secret``, creating it if needed.

    A missing or empty secret file is (re)written with a fresh random
    alphanumeric string, so a tenant token never degenerates to ``"<id>_"``.

    Raises:
        OSError: if the file exists but cannot be read, or cannot be written.
    """
    secret_file = os.path.join(tenant_dir, ".secret")
    try:
        with open(secret_file, "r", encoding="utf-8") as f:
            secret = f.read().strip()
        if secret:
            return secret
    except FileNotFoundError:
        pass

    secret = "".join(secrets.choice(_SECRET_ALPHABET) for _ in range(_SECRET_LENGTH))
    with open(secret_file, "w", encoding="utf-8") as f:
        f.write(secret)
    return secret


def _has_usable_context(context_resp) -> bool:
    """Return True when a context-only retrieval result contains real context.

    A miss is: a non-string or blank result, a result containing the
    ``[no-context]`` marker, or exactly the string ``"None"``. Only an exact
    ``"None"`` counts as a miss. A substring test would wrongly reject any
    retrieved text that merely contains the word "None".
    """
    if not isinstance(context_resp, str):
        return False
    stripped = context_resp.strip()
    return bool(stripped) and stripped != "None" and "[no-context]" not in stripped


def _build_system_prompt(context_data) -> str:
    """Fill CUSTOM_RAG_RESPONSE_PROMPT with the retrieved context."""
    return CUSTOM_RAG_RESPONSE_PROMPT.format(
        context_data=context_data,
        response_type="Multiple Paragraphs",
        user_prompt="",
    )


def _truncate_summary(summary, limit: int = _SUMMARY_MAX_CHARS) -> str:
    """Return *summary* (None treated as ""), cut to *limit* chars with "..." only if it was longer."""
    text = summary or ""
    return text if len(text) <= limit else text[:limit] + "..."


async def _stream_answer(request: QueryRequest, rag: LightRAG, context_resp,
                         has_context: bool, rag_status: str):
    """Yield the /query answer as OpenAI-compatible SSE ``chat.completion.chunk`` events.

    Progress notes are sent as ``reasoning_content`` and answer text as
    ``content``. The first chunk carries the non-standard ``x_rag_status``
    field. Every path ends with ``data: [DONE]``.
    """
    chat_id = f"chatcmpl-{secrets.token_hex(12)}"
    created_time = int(time.time())
    model_name = settings.LLM_MODEL

    def openai_chunk(content=None, reasoning_content=None, finish_reason=None, extra_delta=None):
        """Serialize one SSE ``data:`` line; empty/None fields are left out of the delta."""
        delta = {}
        if content:
            delta["content"] = content
        if reasoning_content:
            delta["reasoning_content"] = reasoning_content
        if extra_delta:
            delta.update(extra_delta)
        chunk = {
            "id": chat_id,
            "object": "chat.completion.chunk",
            "created": created_time,
            "model": model_name,
            "choices": [
                {
                    "index": 0,
                    "delta": delta,
                    "finish_reason": finish_reason,
                }
            ],
        }
        return f"data: {json.dumps(chunk, ensure_ascii=False)}\n\n"

    # Thinking mode is switched off when nothing was retrieved. Use a local
    # rather than mutating the request model.
    think = request.think and has_context

    try:
        if has_context:
            yield openai_chunk(
                reasoning_content=f"1. 上下文已检索 (长度: {len(context_resp)} 字符).\n",
                extra_delta={"x_rag_status": rag_status},
            )
        else:
            yield openai_chunk(
                reasoning_content="未找到相关上下文\n",
                extra_delta={"x_rag_status": rag_status},
            )
            # Strict RAG mode with a miss: finish without calling the LLM.
            if request.only_rag:
                yield openai_chunk(content=_NO_KB_CONTENT_MSG, finish_reason="stop")
                yield "data: [DONE]\n\n"
                return
            yield openai_chunk(reasoning_content=" (将依赖 LLM 自身知识)\n")

        yield openai_chunk(reasoning_content="2. 答案生成中...\n")

        stream_resp = await llm_func(
            request.query,
            system_prompt=_build_system_prompt(context_resp),
            stream=True,
            think=think,
            hashing_kv=rag.llm_response_cache,
        )
        # llm_func may yield dicts tagged "thinking"/"content" or plain strings.
        async for chunk in stream_resp:
            if isinstance(chunk, dict):
                if chunk.get("type") == "thinking":
                    yield openai_chunk(reasoning_content=chunk["content"])
                elif chunk.get("type") == "content":
                    yield openai_chunk(content=chunk["content"])
            elif chunk:
                yield openai_chunk(content=chunk)

        yield openai_chunk(finish_reason="stop")
        yield "data: [DONE]\n\n"
    except Exception as e:
        # Once streaming has started the response status has already been sent,
        # so the endpoint's try/except cannot turn this into an HTTP 500.
        # Report the error in-band and close the stream properly instead.
        logging.exception(f"流式查询失败: {e}")
        yield openai_chunk(finish_reason="stop", extra_delta={"x_error": str(e)})
        yield "data: [DONE]\n\n"


# ==========================================
# Endpoints
# ==========================================
@router.get("/admin/tenants")
async def list_tenants(token: str):
    """Admin endpoint: list tenants and their access tokens.

    Every non-hidden subdirectory of ``settings.DATA_DIR`` is reported as a
    tenant with token ``"<tenant_id>_<secret>"``. Tenants whose secret cannot
    be read or written are logged and skipped.

    Args:
        token: admin token (query parameter). NOTE(review): query strings often
            end up in access logs; a header would be safer.

    Raises:
        HTTPException(403): token mismatch, or ADMIN_TOKEN not configured.
        HTTPException(500): unexpected error while scanning DATA_DIR.
    """
    admin_token = settings.ADMIN_TOKEN
    # Constant-time comparison. An unset/empty ADMIN_TOKEN never grants access.
    if not admin_token or not secrets.compare_digest(
        token.encode("utf-8"), str(admin_token).encode("utf-8")
    ):
        raise HTTPException(status_code=403, detail="Invalid admin token")

    try:
        if not os.path.exists(settings.DATA_DIR):
            return {"tenants": []}

        tenants = []
        with os.scandir(settings.DATA_DIR) as entries:
            for entry in entries:
                if not entry.is_dir() or entry.name.startswith("."):
                    continue
                tenant_id = entry.name
                try:
                    secret = _load_or_create_tenant_secret(entry.path)
                except OSError as e:
                    logging.error(f"Failed to load or create secret for tenant {tenant_id}: {e}")
                    continue
                tenants.append({
                    "id": tenant_id,
                    "token": f"{tenant_id}_{secret}",
                })
        return {"tenants": tenants}
    except Exception as e:
        logging.exception(f"Failed to list tenants: {e}")
        raise HTTPException(status_code=500, detail=str(e))


@router.get("/health")
async def health_check():
    """Health check: report status and the configured LLM model name."""
    return {"status": "ok", "llm": settings.LLM_MODEL}


@router.post("/query")
async def query_knowledge_base(
    request: QueryRequest,
    rag: LightRAG = Depends(get_current_rag)
):
    """Answer a question from the tenant's knowledge base.

    The handler first retrieves context only (``only_need_context=True``) and
    classifies it as "hit"/"miss". It then generates the answer with
    ``llm_func`` using CUSTOM_RAG_RESPONSE_PROMPT. With ``only_rag`` and a
    miss, a fixed message is returned without calling the LLM.

    Request fields:
    - query: user question
    - mode: retrieval mode (hybrid recommended for factual questions)
    - top_k: forwarded to QueryParam.top_k
    - stream: stream OpenAI-compatible SSE chunks (default False)
    - think: enable the LLM thinking mode (default False); when streaming it
      is switched off on a retrieval miss
    - only_rag: no LLM fallback when retrieval misses

    Returns:
        A ``text/event-stream`` StreamingResponse when ``stream`` is set,
        otherwise an OpenAI ``chat.completion``-shaped dict. Both include the
        non-standard ``x_rag_status`` field.

    Raises:
        HTTPException(500): retrieval or non-streaming generation failed.
            Errors during streaming are reported inside the stream instead.
    """
    try:
        context_param = QueryParam(
            mode=request.mode,
            top_k=request.top_k,
            only_need_context=True,
            enable_rerank=settings.RERANK_ENABLED,
        )
        context_resp = await rag.aquery(request.query, param=context_param)

        has_context = _has_usable_context(context_resp)
        rag_status = "hit" if has_context else "miss"

        if request.stream:
            return StreamingResponse(
                _stream_answer(request, rag, context_resp, has_context, rag_status),
                media_type="text/event-stream",
            )

        # Non-streaming (OpenAI-compatible) response.
        if not has_context and request.only_rag:
            # Strict mode with a miss: answer without calling the LLM.
            final_answer = _NO_KB_CONTENT_MSG
        else:
            # NOTE(review): unlike the streaming path, `think` is not switched
            # off on a miss here — confirm that is intended.
            final_answer = await llm_func(
                request.query,
                system_prompt=_build_system_prompt(context_resp),
                stream=False,
                think=request.think,
                hashing_kv=rag.llm_response_cache,
            )

        return {
            "id": f"chatcmpl-{secrets.token_hex(12)}",
            "object": "chat.completion",
            "created": int(time.time()),
            "model": settings.LLM_MODEL,
            "choices": [
                {
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "content": final_answer,
                        # Non-standard extension field; some clients may use it.
                        "x_rag_status": rag_status,
                    },
                    "finish_reason": "stop",
                }
            ],
            "usage": {
                "prompt_tokens": -1,  # not computed
                # Character count of the answer, not a real token count.
                "completion_tokens": len(final_answer),
                "total_tokens": -1,  # not computed
            },
        }
    except Exception as e:
        logging.exception(f"查询失败: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))


@router.post("/ingest/text")
async def ingest_text(
    text: str = Body(..., embed=True),
    rag: LightRAG = Depends(get_current_rag)
):
    """Ingest raw text (JSON body ``{"text": "..."}``) via ``rag.ainsert``.

    Blank text is skipped rather than indexed, mirroring the empty-list
    handling of ``/ingest/batch_qa``.
    """
    if not text.strip():
        return {"status": "skipped", "message": "Empty text"}
    try:
        await rag.ainsert(text)
        return {"status": "success", "message": "Text ingested successfully"}
    except Exception as e:
        logging.exception(f"文本导入失败: {e}")
        raise HTTPException(status_code=500, detail=str(e))


@router.post("/ingest/batch_qa")
async def ingest_batch_qa(
    qa_list: list[QAPair],
    rag: LightRAG = Depends(get_current_rag)
):
    """Ingest a batch of QA pairs.

    Behaviour:
    - Each pair is formatted as a readable "Q&A Entry" text block.
    - Consecutive entries are merged into inserts of at most
      ``_MAX_QA_BATCH_CHARS`` characters. A single entry longer than that is
      inserted on its own.
    - If an insert fails, earlier batches remain indexed; the number of
      completed inserts is logged.
    """
    if not qa_list:
        return {"status": "skipped", "message": "Empty QA list"}

    inserted_count = 0
    try:
        batch_parts: list[str] = []
        batch_size = 0
        for qa in qa_list:
            entry = f"--- Q&A Entry ---\nQuestion: {qa.question}\nAnswer: {qa.answer}\n\n"
            # Flush before this entry would overflow the batch. Only flush a
            # non-empty batch, so an oversized entry never triggers an empty insert.
            if batch_parts and batch_size + len(entry) > _MAX_QA_BATCH_CHARS:
                await rag.ainsert("".join(batch_parts))
                inserted_count += 1
                batch_parts = []
                batch_size = 0
            batch_parts.append(entry)
            batch_size += len(entry)

        if batch_parts:
            await rag.ainsert("".join(batch_parts))
            inserted_count += 1

        return {
            "status": "success",
            "message": f"Successfully processed {len(qa_list)} QA pairs into {inserted_count} text chunks."
        }
    except Exception as e:
        logging.exception(f"QA批量导入失败: {str(e)} (inserted before failure: {inserted_count})")
        raise HTTPException(status_code=500, detail=str(e))


@router.post("/ingest/file", response_model=IngestResponse)
async def upload_file(
    file: UploadFile = File(...),
    rag: LightRAG = Depends(get_current_rag)
):
    """Upload a file and index its text.

    Supported types (extension matched case-insensitively):
    - ``.pdf``: text extracted by ``process_pdf_with_images``
    - ``.txt`` / ``.md``: decoded as UTF-8 (a leading BOM is dropped)

    Unsupported types return status "skipped". Empty content or undecodable
    text returns status "failed".

    Raises:
        HTTPException(500): parsing or indexing failed unexpectedly.
    """
    filename = file.filename or ""
    try:
        file_bytes = await file.read()
        lower_name = filename.lower()

        if lower_name.endswith(".pdf"):
            # PDF parsing (text and images).
            content = await process_pdf_with_images(file_bytes)
        elif lower_name.endswith((".txt", ".md")):
            try:
                # utf-8-sig strips a BOM that editors often prepend.
                content = file_bytes.decode("utf-8-sig")
            except UnicodeDecodeError:
                return IngestResponse(
                    filename=filename,
                    status="failed",
                    message="File is not valid UTF-8 text."
                )
        else:
            return IngestResponse(
                filename=filename,
                status="skipped",
                message="Unsupported file format. Only PDF, TXT, MD supported."
            )

        if not content or not content.strip():
            return IngestResponse(
                filename=filename,
                status="failed",
                message="Empty content extracted."
            )

        # Indexing can take a long time.
        await rag.ainsert(content)

        return IngestResponse(
            filename=filename,
            status="success",
            message=f"Successfully indexed {len(content)} characters."
        )
    except Exception as e:
        logging.exception(f"处理文件 {file.filename} 失败: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Processing failed: {str(e)}")


@router.get("/documents")
async def list_documents(
    rag: LightRAG = Depends(get_current_rag)
):
    """List indexed documents and their status.

    Reads ``kv_store_doc_status.json`` from ``rag.working_dir`` directly
    (assumes the JSON KV storage backend — TODO confirm). A missing file
    yields an empty list.
    """
    try:
        doc_status_path = os.path.join(rag.working_dir, "kv_store_doc_status.json")
        if not os.path.exists(doc_status_path):
            return {"count": 0, "docs": []}

        with open(doc_status_path, "r", encoding="utf-8") as f:
            data = json.load(f)

        docs = [
            {
                "id": doc_id,
                "summary": _truncate_summary(info.get("content_summary")),
                "length": info.get("content_length", 0),
                "created_at": info.get("created_at"),
                "status": info.get("status"),
            }
            for doc_id, info in data.items()
        ]
        return {
            "count": len(docs),
            "docs": docs
        }
    except Exception as e:
        logging.exception(f"获取文档列表失败: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))


@router.get("/documents/{doc_id}")
async def get_document_detail(
    doc_id: str,
    rag: LightRAG = Depends(get_current_rag)
):
    """Return the full stored content of one document.

    Reads ``kv_store_full_docs.json`` from ``rag.working_dir``.

    Raises:
        HTTPException(404): the store file or the document does not exist.
        HTTPException(500): any other failure.
    """
    try:
        full_docs_path = os.path.join(rag.working_dir, "kv_store_full_docs.json")
        if not os.path.exists(full_docs_path):
            raise HTTPException(status_code=404, detail="Document store not found")

        with open(full_docs_path, "r", encoding="utf-8") as f:
            data = json.load(f)

        if doc_id not in data:
            raise HTTPException(status_code=404, detail="Document not found")

        doc_info = data[doc_id]
        return {
            "id": doc_id,
            "content": doc_info.get("content", ""),
            "create_time": doc_info.get("create_time"),
            "file_path": doc_info.get("file_path", "unknown")
        }
    except HTTPException:
        # Let the intentional 404s through unchanged.
        raise
    except Exception as e:
        logging.exception(f"获取文档详情失败: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))


# NOTE(review): this route uses "/docs/..." while the read routes use
# "/documents/..."; kept as-is for client compatibility.
@router.delete("/docs/{doc_id}")
async def delete_document(
    doc_id: str,
    rag: LightRAG = Depends(get_current_rag)
):
    """Delete a document by id (e.g. ``doc-xxxxx``) via ``rag.adelete_by_doc_id``.

    NOTE(review): the return value of ``adelete_by_doc_id`` is ignored, so
    "success" is reported whenever no exception is raised — confirm whether
    that call can report a not-found/failed status without raising.
    """
    try:
        logging.info(f"正在删除文档: {doc_id}")
        await rag.adelete_by_doc_id(doc_id)
        return {"status": "success", "message": f"Document {doc_id} deleted successfully"}
    except Exception as e:
        logging.exception(f"删除文档 {doc_id} 失败: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Delete failed: {str(e)}")