diff --git a/.env.example b/.env.example
index b99f704..8525e65 100644
--- a/.env.example
+++ b/.env.example
@@ -4,15 +4,28 @@ APP_VERSION="1.0"
 HOST="0.0.0.0"
 PORT=9600
 
-# LLM Configuration
-LLM_BINDING=ollama
-LLM_BINDING_HOST=http://localhost:11434
-LLM_MODEL=deepseek-v3.2:cloud
+# LLM(Text) Configuration
+LLM_BINDING=vllm # ollama, vllm, openai
+LLM_BINDING_HOST=http://192.168.6.115:8002/v1 # vLLM OpenAI API base
+LLM_MODEL=qwen2.5-7b-awq
+LLM_KEY=EMPTY # vLLM default key
+
+# LLM(Vision) Configuration
+VL_BINDING_HOST=http://192.168.6.115:8001/v1
+VL_MODEL=qwen2.5-vl-3b-awq
+VL_KEY=EMPTY
 
 # Embedding Configuration
-EMBEDDING_BINDING=ollama
-EMBEDDING_BINDING_HOST=http://localhost:11434
-EMBEDDING_MODEL=bge-m3
+EMBEDDING_BINDING=tei # ollama, tei, openai
+EMBEDDING_BINDING_HOST=http://192.168.6.115:8003 # TEI usually exposes /embed
+EMBEDDING_MODEL=BAAI/bge-m3 # model id in TEI
+EMBEDDING_KEY=EMPTY
+
+# Rerank - TEI
+RERANK_ENABLED=True
+RERANK_BINDING_HOST=http://192.168.6.115:8004
+RERANK_MODEL=BAAI/bge-reranker-v2-m3
+RERANK_KEY=EMPTY
 
 # Storage
 DATA_DIR=./index_data
@@ -20,3 +33,4 @@ DATA_DIR=./index_data
 # RAG Configuration
 EMBEDDING_DIM=1024
 MAX_TOKEN_SIZE=8192
+MAX_RAG_INSTANCES=5 # maximum number of active RAG instances
\ No newline at end of file
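Before pointing the app at these services, it can help to confirm they answer on the routes the code below relies on. This is a minimal sketch, not part of the patch: it assumes the placeholder hosts from this .env.example, vLLM's OpenAI-style /models route, and TEI's /embed and /rerank routes; adjust to your deployment.

```python
# Quick reachability check for the services configured above (placeholder hosts from .env.example).
import asyncio
import httpx

LLM_BASE = "http://192.168.6.115:8002/v1"   # vLLM, OpenAI-compatible
EMBED_BASE = "http://192.168.6.115:8003"    # TEI embeddings
RERANK_BASE = "http://192.168.6.115:8004"   # TEI reranker

async def main():
    async with httpx.AsyncClient(timeout=10.0) as client:
        # vLLM exposes an OpenAI-style model listing
        models = (await client.get(f"{LLM_BASE}/models")).json()
        print("LLM models:", [m["id"] for m in models.get("data", [])])

        # TEI /embed should return one 1024-dim vector per input for bge-m3 (EMBEDDING_DIM=1024)
        emb = (await client.post(f"{EMBED_BASE}/embed", json={"inputs": ["ping"]})).json()
        print("embedding dim:", len(emb[0]))

        # TEI /rerank returns index/score pairs
        rr = (await client.post(f"{RERANK_BASE}/rerank",
                                json={"query": "ping", "texts": ["pong", "other"]})).json()
        print("rerank:", rr)

asyncio.run(main())
```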
diff --git a/app/api/routes.py b/app/api/routes.py
index cba5952..137c4a6 100644
--- a/app/api/routes.py
+++ b/app/api/routes.py
@@ -1,15 +1,14 @@
 import json
 import logging
 import os
-import io
 from fastapi import APIRouter, UploadFile, File, HTTPException, Body, Header, Depends
 from fastapi.responses import StreamingResponse
 from pydantic import BaseModel
-from pypdf import PdfReader
 from lightrag import LightRAG, QueryParam
 from app.core.rag import get_rag_manager, llm_func
 from app.core.prompts import CUSTOM_RAG_RESPONSE_PROMPT
 from app.config import settings
+from app.core.ingest import process_pdf_with_images
 
 router = APIRouter()
 
@@ -41,6 +40,10 @@ class IngestResponse(BaseModel):
     status: str
     message: str
 
+class QAPair(BaseModel):
+    question: str
+    answer: str
+
 # ==========================================
 # Endpoints
 # ==========================================
@@ -67,7 +70,7 @@ async def query_knowledge_base(
             mode=request.mode,
             top_k=request.top_k,
             stream=request.stream,
-            enable_rerank=False # explicitly disable rerank to silence the warning
+            enable_rerank=settings.RERANK_ENABLED
         )
 
         # Handle streaming output (SSE protocol)
@@ -85,7 +88,7 @@ async def query_knowledge_base(
             mode=request.mode,
             top_k=request.top_k,
             only_need_context=True,
-            enable_rerank=False # explicitly disable rerank to silence the warning
+            enable_rerank=settings.RERANK_ENABLED
         )
 
         # Fetch the context (this step is slow; it includes the graph traversal)
@@ -158,6 +161,55 @@ async def ingest_text(
     except Exception as e:
         raise HTTPException(status_code=500, detail=str(e))
 
+@router.post("/ingest/batch_qa")
+async def ingest_batch_qa(
+    qa_list: list[QAPair],
+    rag: LightRAG = Depends(get_current_rag)
+):
+    """
+    Batch-ingest QA pairs
+    - Formats each QA pair into a semantically coherent text block
+    - Merges short QA pairs to keep indexing efficient
+    """
+    if not qa_list:
+        return {"status": "skipped", "message": "Empty QA list"}
+
+    try:
+        # 1. Format and merge the text
+        batch_text = ""
+        current_batch_size = 0
+        MAX_BATCH_CHARS = 2000 # roughly 1000 tokens, a conservative estimate
+
+        inserted_count = 0
+
+        for qa in qa_list:
+            # Format a single QA pair
+            entry = f"--- Q&A Entry ---\nQuestion: {qa.question}\nAnswer: {qa.answer}\n\n"
+            entry_len = len(entry)
+
+            # If the current batch would grow too large, flush it first
+            if current_batch_size + entry_len > MAX_BATCH_CHARS:
+                await rag.ainsert(batch_text)
+                batch_text = ""
+                current_batch_size = 0
+                inserted_count += 1
+
+            batch_text += entry
+            current_batch_size += entry_len
+
+        # Flush the remaining text
+        if batch_text:
+            await rag.ainsert(batch_text)
+            inserted_count += 1
+
+        return {
+            "status": "success",
+            "message": f"Successfully processed {len(qa_list)} QA pairs into {inserted_count} text chunks."
+        }
+    except Exception as e:
+        logging.error(f"QA批量导入失败: {str(e)}")
+        raise HTTPException(status_code=500, detail=str(e))
+
 @router.post("/ingest/file", response_model=IngestResponse)
 async def upload_file(
     file: UploadFile = File(...),
@@ -176,10 +228,8 @@ async def upload_file(
 
         # Parse according to file type
         if filename.endswith(".pdf"):
-            # PDF parsing logic
-            pdf_reader = PdfReader(io.BytesIO(file_bytes))
-            for page in pdf_reader.pages:
-                content += page.extract_text() + "\n"
+            # PDF parsing logic (text + images)
+            content = await process_pdf_with_images(file_bytes)
         elif filename.endswith(".txt") or filename.endswith(".md"):
             # Plain-text files are decoded directly
             content = file_bytes.decode("utf-8")
@@ -247,6 +297,40 @@ async def list_documents(
         logging.error(f"获取文档列表失败: {str(e)}")
         raise HTTPException(status_code=500, detail=str(e))
 
+@router.get("/documents/{doc_id}")
+async def get_document_detail(
+    doc_id: str,
+    rag: LightRAG = Depends(get_current_rag)
+):
+    """
+    Get document details
+    Returns the full content of the specified document
+    """
+    try:
+        # Read kv_store_full_docs.json
+        full_docs_path = os.path.join(rag.working_dir, "kv_store_full_docs.json")
+        if not os.path.exists(full_docs_path):
+            raise HTTPException(status_code=404, detail="Document store not found")
+
+        with open(full_docs_path, "r", encoding="utf-8") as f:
+            data = json.load(f)
+
+        if doc_id not in data:
+            raise HTTPException(status_code=404, detail="Document not found")
+
+        doc_info = data[doc_id]
+        return {
+            "id": doc_id,
+            "content": doc_info.get("content", ""),
+            "create_time": doc_info.get("create_time"),
+            "file_path": doc_info.get("file_path", "unknown")
+        }
+    except HTTPException:
+        raise
+    except Exception as e:
+        logging.error(f"获取文档详情失败: {str(e)}")
+        raise HTTPException(status_code=500, detail=str(e))
+
 @router.delete("/docs/{doc_id}")
 async def delete_document(
     doc_id: str,
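A hedged usage sketch for the new batch-QA route (not part of the patch). The JSON body mirrors the QAPair model above; the base URL assumes the default HOST/PORT with the router mounted at the root, and any user/tenant header expected by get_current_rag is omitted because it is not defined in this diff.

```python
# Example client call for the new /ingest/batch_qa endpoint.
# Assumes the service on 127.0.0.1:9600 with the router mounted at the root;
# whatever header get_current_rag requires is intentionally left out here.
import httpx

qa_payload = [
    {"question": "What port does the service listen on?", "answer": "9600 by default."},
    {"question": "Which reranker is used?", "answer": "BAAI/bge-reranker-v2-m3 served by TEI."},
]

resp = httpx.post("http://127.0.0.1:9600/ingest/batch_qa", json=qa_payload, timeout=120)
print(resp.json())
# e.g. {"status": "success", "message": "Successfully processed 2 QA pairs into 1 text chunks."}

# Fetching a single document's full content via the new detail route
# (doc IDs come from the existing document-listing endpoint):
# detail = httpx.get("http://127.0.0.1:9600/documents/<doc_id>").json()
```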
EMBEDDING_KEY: str = "EMPTY" + + # Rerank - TEI + RERANK_ENABLED: bool = True + RERANK_BINDING_HOST: str = "http://192.168.6.115:8004" + RERANK_MODEL: str = "BAAI/bge-reranker-v2-m3" + RERANK_KEY: str = "EMPTY" # RAG Config EMBEDDING_DIM: int = 1024 diff --git a/app/core/ingest.py b/app/core/ingest.py new file mode 100644 index 0000000..a1e1fcc --- /dev/null +++ b/app/core/ingest.py @@ -0,0 +1,88 @@ +import base64 +import logging +import httpx +from io import BytesIO +from app.config import settings + +async def vl_image_caption_func(image_data: bytes, prompt: str = "请详细描述这张图片") -> str: + """ + 使用 VL 模型 (vLLM OpenAI API) 生成图片描述 + """ + if not settings.VL_BINDING_HOST: + return "[Image Processing Skipped: No VL Model Configured]" + + try: + # 1. 编码图片为 Base64 + base64_image = base64.b64encode(image_data).decode('utf-8') + + # 2. 构造 OpenAI 格式请求 + # vLLM 支持 OpenAI Vision API + url = f"{settings.VL_BINDING_HOST}/chat/completions" + headers = { + "Content-Type": "application/json", + "Authorization": f"Bearer {settings.VL_KEY}" + } + + payload = { + "model": settings.VL_MODEL, + "messages": [ + { + "role": "user", + "content": [ + {"type": "text", "text": prompt}, + { + "type": "image_url", + "image_url": { + "url": f"data:image/jpeg;base64,{base64_image}" + } + } + ] + } + ], + "max_tokens": 300 + } + + async with httpx.AsyncClient(timeout=30.0) as client: + response = await client.post(url, headers=headers, json=payload) + response.raise_for_status() + result = response.json() + description = result['choices'][0]['message']['content'] + return f"[Image Description: {description}]" + + except Exception as e: + logging.error(f"VL Caption failed: {str(e)}") + return f"[Image Processing Failed: {str(e)}]" + +async def process_pdf_with_images(file_bytes: bytes) -> str: + """ + 解析 PDF,提取文本并对图片进行 Caption + """ + import pypdf + from PIL import Image + + text_content = "" + pdf_file = BytesIO(file_bytes) + reader = pypdf.PdfReader(pdf_file) + + for page_num, page in enumerate(reader.pages): + # 1. 提取文本 + page_text = page.extract_text() + text_content += f"--- Page {page_num + 1} Text ---\n{page_text}\n\n" + + # 2. 
diff --git a/app/core/ingest.py b/app/core/ingest.py
new file mode 100644
index 0000000..a1e1fcc
--- /dev/null
+++ b/app/core/ingest.py
@@ -0,0 +1,88 @@
+import base64
+import logging
+import httpx
+from io import BytesIO
+from app.config import settings
+
+async def vl_image_caption_func(image_data: bytes, prompt: str = "请详细描述这张图片") -> str:
+    """
+    Generate an image caption with the VL model (vLLM OpenAI API)
+    """
+    if not settings.VL_BINDING_HOST:
+        return "[Image Processing Skipped: No VL Model Configured]"
+
+    try:
+        # 1. Base64-encode the image
+        base64_image = base64.b64encode(image_data).decode('utf-8')
+
+        # 2. Build an OpenAI-style request
+        # vLLM supports the OpenAI Vision API
+        url = f"{settings.VL_BINDING_HOST}/chat/completions"
+        headers = {
+            "Content-Type": "application/json",
+            "Authorization": f"Bearer {settings.VL_KEY}"
+        }
+
+        payload = {
+            "model": settings.VL_MODEL,
+            "messages": [
+                {
+                    "role": "user",
+                    "content": [
+                        {"type": "text", "text": prompt},
+                        {
+                            "type": "image_url",
+                            "image_url": {
+                                "url": f"data:image/jpeg;base64,{base64_image}"
+                            }
+                        }
+                    ]
+                }
+            ],
+            "max_tokens": 300
+        }
+
+        async with httpx.AsyncClient(timeout=30.0) as client:
+            response = await client.post(url, headers=headers, json=payload)
+            response.raise_for_status()
+            result = response.json()
+            description = result['choices'][0]['message']['content']
+            return f"[Image Description: {description}]"
+
+    except Exception as e:
+        logging.error(f"VL Caption failed: {str(e)}")
+        return f"[Image Processing Failed: {str(e)}]"
+
+async def process_pdf_with_images(file_bytes: bytes) -> str:
+    """
+    Parse a PDF: extract the text and caption the images
+    """
+    import pypdf
+    from PIL import Image
+
+    text_content = ""
+    pdf_file = BytesIO(file_bytes)
+    reader = pypdf.PdfReader(pdf_file)
+
+    for page_num, page in enumerate(reader.pages):
+        # 1. Extract text
+        page_text = page.extract_text()
+        text_content += f"--- Page {page_num + 1} Text ---\n{page_text}\n\n"
+
+        # 2. Extract images
+        if settings.VL_BINDING_HOST:
+            for count, image_file_object in enumerate(page.images):
+                try:
+                    # Get the raw image data
+                    image_data = image_file_object.data
+
+                    # Basic image validity check
+                    # Image.open(BytesIO(image_data)).verify()
+
+                    # Call the VL model
+                    caption = await vl_image_caption_func(image_data)
+                    text_content += f"--- Page {page_num + 1} Image {count + 1} ---\n{caption}\n\n"
+                except Exception as e:
+                    logging.warning(f"Failed to process image {count} on page {page_num}: {e}")
+
+    return text_content
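A minimal local harness for the new PDF path (not part of the patch). It assumes an arbitrary sample.pdf on disk; per the code above, image captioning only runs when VL_BINDING_HOST is configured, otherwise pages are ingested text-only.

```python
# Minimal local check of the new PDF ingestion path (sample.pdf is any local PDF).
import asyncio
from app.core.ingest import process_pdf_with_images

async def main():
    with open("sample.pdf", "rb") as f:
        text = await process_pdf_with_images(f.read())
    # Expect "--- Page N Text ---" sections, plus "--- Page N Image M ---" captions
    # when VL_BINDING_HOST is configured and the page embeds images.
    print(text[:1500])

asyncio.run(main())
```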
messages.append({"role": "system", "content": system_prompt}) + messages.extend(history_messages) + messages.append({"role": "user", "content": prompt}) + + url = f"{settings.LLM_BINDING_HOST}/chat/completions" + headers = { + "Content-Type": "application/json", + "Authorization": f"Bearer {settings.LLM_KEY}" + } + + payload = { + "model": settings.LLM_MODEL, + "messages": messages, + "stream": stream, + **kwargs + } + + if stream: + async def inner(): + async with httpx.AsyncClient(timeout=120.0) as client: + async with client.stream("POST", url, headers=headers, json=payload) as response: + response.raise_for_status() + async for line in response.aiter_lines(): + if line.startswith("data: "): + data_str = line[6:] + if data_str.strip() == "[DONE]": + break + try: + import json + chunk = json.loads(data_str) + delta = chunk['choices'][0]['delta'] + if 'content' in delta: + yield {"type": "content", "content": delta['content']} + # vLLM 可能不返回 thinking 字段,除非是 DeepSeek 模型且配置了 + except: + pass + return inner() + else: + async with httpx.AsyncClient(timeout=120.0) as client: + response = await client.post(url, headers=headers, json=payload) + response.raise_for_status() + result = response.json() + return result['choices'][0]['message']['content'] + +async def llm_func(prompt, system_prompt=None, history_messages=[], **kwargs) -> str: + """LLM 调度函数""" + if settings.LLM_BINDING == "ollama": + return await ollama_llm_func(prompt, system_prompt, history_messages, **kwargs) + elif settings.LLM_BINDING in ["vllm", "openai", "custom"]: + return await openai_llm_func(prompt, system_prompt, history_messages, **kwargs) + else: + raise ValueError(f"Unsupported LLM_BINDING: {settings.LLM_BINDING}") + +# ============================================================================== +# Embedding Functions +# ============================================================================== + +async def ollama_embedding_func(texts: list[str]) -> np.ndarray: return await ollama_embed( texts, embed_model=settings.EMBEDDING_MODEL, host=settings.EMBEDDING_BINDING_HOST ) +async def tei_embedding_func(texts: list[str]) -> np.ndarray: + """TEI (Text Embeddings Inference) Embedding 实现""" + url = f"{settings.EMBEDDING_BINDING_HOST}/embed" # TEI 标准接口 + headers = {"Content-Type": "application/json"} + if settings.EMBEDDING_KEY and settings.EMBEDDING_KEY != "EMPTY": + headers["Authorization"] = f"Bearer {settings.EMBEDDING_KEY}" + + payload = { + "inputs": texts, + "truncate": True # TEI 参数,防止超长报错 + } + + async with httpx.AsyncClient(timeout=60.0) as client: + response = await client.post(url, headers=headers, json=payload) + response.raise_for_status() + # TEI 返回: [[0.1, ...], [0.2, ...]] + embeddings = response.json() + return np.array(embeddings) + +async def embedding_func(texts: list[str]) -> np.ndarray: + """Embedding 调度函数""" + if settings.EMBEDDING_BINDING == "ollama": + return await ollama_embedding_func(texts) + elif settings.EMBEDDING_BINDING == "tei": + return await tei_embedding_func(texts) + else: + # 默认回退到 ollama 或者报错 + raise ValueError(f"Unsupported EMBEDDING_BINDING: {settings.EMBEDDING_BINDING}") + +# ============================================================================== +# Rerank Functions +# ============================================================================== + +async def tei_rerank_func(query: str, documents: list[str]) -> np.ndarray: + """TEI Rerank 实现""" + if not documents: + return np.array([]) + + url = f"{settings.RERANK_BINDING_HOST}/rerank" + headers = {"Content-Type": 
"application/json"} + if settings.RERANK_KEY and settings.RERANK_KEY != "EMPTY": + headers["Authorization"] = f"Bearer {settings.RERANK_KEY}" + + payload = { + "query": query, + "texts": documents, + "return_text": False + } + + async with httpx.AsyncClient(timeout=30.0) as client: + response = await client.post(url, headers=headers, json=payload) + response.raise_for_status() + # TEI 返回: [{"index": 0, "score": 0.99}, {"index": 1, "score": 0.5}] + results = response.json() + + # LightRAG 期望返回一个分数数组,对应输入的 documents 顺序 + # TEI 返回的结果是排序过的,我们需要根据 index 还原顺序 + scores = np.zeros(len(documents)) + for res in results: + idx = res['index'] + scores[idx] = res['score'] + + return scores + +# ============================================================================== +# RAG Manager +# ============================================================================== + class RAGManager: def __init__(self, capacity: int = 3): self.capacity = capacity @@ -73,44 +212,44 @@ class RAGManager: async def get_rag(self, user_id: str) -> LightRAG: """获取指定用户的 LightRAG 实例 (LRU 缓存)""" - # 1. 尝试从缓存获取 with self.lock: if user_id in self.cache: self.cache.move_to_end(user_id) logging.debug(f"Cache hit for user: {user_id}") return self.cache[user_id] - # 2. 缓存未命中,需要初始化 logging.info(f"Initializing RAG instance for user: {user_id}") user_data_dir = os.path.join(settings.DATA_DIR, user_id) - # 严格隔离:确保工作目录绝对干净,不应包含上级目录或其他用户数据 if not os.path.exists(user_data_dir): os.makedirs(user_data_dir) - # 实例化 LightRAG - # 关键修复:确保 working_dir 严格指向用户子目录 - rag = LightRAG( - working_dir=user_data_dir, - llm_model_func=llm_func, - llm_model_name=settings.LLM_MODEL, - llm_model_max_async=1, - max_parallel_insert=1, - embedding_func=EmbeddingFunc( + # 准备参数 + rag_params = { + "working_dir": user_data_dir, + "llm_model_func": llm_func, + "llm_model_name": settings.LLM_MODEL, + "llm_model_max_async": 4, # vLLM 并发能力强,可以调高 + "max_parallel_insert": 1, + "embedding_func": EmbeddingFunc( embedding_dim=settings.EMBEDDING_DIM, max_token_size=settings.MAX_TOKEN_SIZE, func=embedding_func ), - embedding_func_max_async=4, - enable_llm_cache=True - ) + "embedding_func_max_async": 8, # TEI 并发强 + "enable_llm_cache": True + } + + # 如果启用了 Rerank,注入 rerank_model_func + if settings.RERANK_ENABLED: + logging.info("Rerank enabled for RAG instance") + rag_params["rerank_model_func"] = tei_rerank_func + + rag = LightRAG(**rag_params) - # 异步初始化存储 await rag.initialize_storages() - # 3. 放入缓存并处理驱逐 with self.lock: - # 双重检查,防止在初始化期间其他线程已经创建了 if user_id in self.cache: self.cache.move_to_end(user_id) return self.cache[user_id] @@ -118,23 +257,19 @@ class RAGManager: self.cache[user_id] = rag self.cache.move_to_end(user_id) - # 检查容量 while len(self.cache) > self.capacity: oldest_user, _ = self.cache.popitem(last=False) - logging.info(f"Evicting RAG instance for user: {oldest_user} due to capacity limit") - # 这里依赖 Python GC 回收资源 + logging.info(f"Evicting RAG instance for user: {oldest_user}") return rag def initialize_rag_manager(): - """初始化全局 RAG 管理器""" global rag_manager rag_manager = RAGManager(capacity=settings.MAX_RAG_INSTANCES) logging.info(f"RAG Manager initialized with capacity: {settings.MAX_RAG_INSTANCES}") return rag_manager def get_rag_manager() -> RAGManager: - """获取全局 RAG 管理器""" if rag_manager is None: return initialize_rag_manager() return rag_manager