fix: 1. Switch configuration from ollama to vllm + TEI 2. Change the text model from ollama-cloud to vLLM 3. Add VL model and rerank model 4. Add batch QA knowledge upload endpoint 5. Add knowledge (document) detail endpoint

This commit is contained in:
fuzhongyun 2026-01-16 15:34:31 +08:00
parent 13bc171e9d
commit 75044940c7
5 changed files with 398 additions and 64 deletions

View File

@ -4,15 +4,28 @@ APP_VERSION="1.0"
HOST="0.0.0.0"
PORT=9600
# LLM Configuration
LLM_BINDING=ollama
LLM_BINDING_HOST=http://localhost:11434
LLM_MODEL=deepseek-v3.2:cloud
# LLM (Text) Configuration
LLM_BINDING=vllm # ollama, vllm, openai
LLM_BINDING_HOST=http://192.168.6.115:8002/v1 # vLLM OpenAI API base
LLM_MODEL=qwen2.5-7b-awq
LLM_KEY=EMPTY # vLLM default key
# LLM (Vision) Configuration
VL_BINDING_HOST=http://192.168.6.115:8001/v1
VL_MODEL=qwen2.5-vl-3b-awq
VL_KEY=EMPTY
# Embedding Configuration
EMBEDDING_BINDING=ollama
EMBEDDING_BINDING_HOST=http://localhost:11434
EMBEDDING_MODEL=bge-m3
EMBEDDING_BINDING=tei # ollama, tei, openai
EMBEDDING_BINDING_HOST=http://192.168.6.115:8003 # TEI usually exposes /embed
EMBEDDING_MODEL=BAAI/bge-m3 # model id in TEI
EMBEDDING_KEY=EMPTY
# Rerank - TEI
RERANK_ENABLED=True
RERANK_BINDING_HOST=http://192.168.6.115:8004
RERANK_MODEL=BAAI/bge-reranker-v2-m3
RERANK_KEY=EMPTY
# Storage
DATA_DIR=./index_data
@ -20,3 +33,4 @@ DATA_DIR=./index_data
# RAG Configuration
EMBEDDING_DIM=1024
MAX_TOKEN_SIZE=8192
MAX_RAG_INSTANCES=5 # maximum number of active RAG instances
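
For reference, a minimal connectivity check for the services configured above (a sketch, not part of this commit). It assumes the vLLM and TEI services at 192.168.6.115 are reachable from the app host and that no real API key is enforced (LLM_KEY=EMPTY).

import httpx

# vLLM exposes an OpenAI-compatible API; listing models verifies LLM_BINDING_HOST
# (the Authorization header is only needed if the server was started with an API key)
print(httpx.get("http://192.168.6.115:8002/v1/models",
                headers={"Authorization": "Bearer EMPTY"}).json())

# TEI embedding server: POST /embed with {"inputs": [...]}
emb = httpx.post("http://192.168.6.115:8003/embed", json={"inputs": ["hello"]}).json()
print(len(emb[0]))  # should match EMBEDDING_DIM (1024 for bge-m3)

# TEI reranker: POST /rerank with {"query": ..., "texts": [...]}
print(httpx.post("http://192.168.6.115:8004/rerank",
                 json={"query": "hello", "texts": ["hi there", "goodbye"]}).json())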

View File

@ -1,15 +1,14 @@
import json
import logging
import os
import io
from fastapi import APIRouter, UploadFile, File, HTTPException, Body, Header, Depends
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from pypdf import PdfReader
from lightrag import LightRAG, QueryParam
from app.core.rag import get_rag_manager, llm_func
from app.core.prompts import CUSTOM_RAG_RESPONSE_PROMPT
from app.config import settings
from app.core.ingest import process_pdf_with_images
router = APIRouter()
@ -41,6 +40,10 @@ class IngestResponse(BaseModel):
status: str
message: str
class QAPair(BaseModel):
question: str
answer: str
# ==========================================
# Endpoint implementations
# ==========================================
@ -67,7 +70,7 @@ async def query_knowledge_base(
mode=request.mode,
top_k=request.top_k,
stream=request.stream,
enable_rerank=False # explicitly disable rerank to silence the warning
enable_rerank=settings.RERANK_ENABLED
)
# Handle streaming output (SSE protocol)
@ -85,7 +88,7 @@ async def query_knowledge_base(
mode=request.mode,
top_k=request.top_k,
only_need_context=True,
enable_rerank=False # explicitly disable rerank to silence the warning
enable_rerank=settings.RERANK_ENABLED
)
# Fetch the context (this step is relatively slow; it includes graph traversal)
@ -158,6 +161,55 @@ async def ingest_text(
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@router.post("/ingest/batch_qa")
async def ingest_batch_qa(
qa_list: list[QAPair],
rag: LightRAG = Depends(get_current_rag)
):
"""
Batch-ingest QA pairs
- Formats each QA pair into a semantically coherent text block
- Merges short QA pairs to improve indexing efficiency
"""
if not qa_list:
return {"status": "skipped", "message": "Empty QA list"}
try:
# 1. Format and merge the text
batch_text = ""
current_batch_size = 0
MAX_BATCH_CHARS = 2000 # roughly 1,000 tokens (conservative estimate)
inserted_count = 0
for qa in qa_list:
# Format a single QA entry
entry = f"--- Q&A Entry ---\nQuestion: {qa.question}\nAnswer: {qa.answer}\n\n"
entry_len = len(entry)
# If the current batch would grow too large, flush it first (skip while the batch is still empty)
if batch_text and current_batch_size + entry_len > MAX_BATCH_CHARS:
await rag.ainsert(batch_text)
batch_text = ""
current_batch_size = 0
inserted_count += 1
batch_text += entry
current_batch_size += entry_len
# Flush any remaining text
if batch_text:
await rag.ainsert(batch_text)
inserted_count += 1
return {
"status": "success",
"message": f"Successfully processed {len(qa_list)} QA pairs into {inserted_count} text chunks."
}
except Exception as e:
logging.error(f"QA批量导入失败: {str(e)}")
raise HTTPException(status_code=500, detail=str(e))
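
A minimal client sketch for the new batch QA endpoint (illustrative, not part of this commit). The base URL comes from HOST/PORT in the config; whatever user or auth header your get_current_rag dependency expects is omitted here.

import httpx

qa_payload = [
    {"question": "What is LightRAG?", "answer": "A lightweight retrieval-augmented generation library."},
    {"question": "Which reranker is used?", "answer": "BAAI/bge-reranker-v2-m3 served by TEI."},
]
resp = httpx.post("http://localhost:9600/ingest/batch_qa", json=qa_payload, timeout=120.0)
print(resp.json())  # e.g. {"status": "success", "message": "Successfully processed 2 QA pairs into 1 text chunks."}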
@router.post("/ingest/file", response_model=IngestResponse)
async def upload_file(
file: UploadFile = File(...),
@ -176,10 +228,8 @@ async def upload_file(
# Parse according to file type
if filename.endswith(".pdf"):
# PDF parsing logic
pdf_reader = PdfReader(io.BytesIO(file_bytes))
for page in pdf_reader.pages:
content += page.extract_text() + "\n"
# PDF parsing logic (text + images)
content = await process_pdf_with_images(file_bytes)
elif filename.endswith(".txt") or filename.endswith(".md"):
# Plain-text files are decoded directly
content = file_bytes.decode("utf-8")
@ -247,6 +297,40 @@ async def list_documents(
logging.error(f"获取文档列表失败: {str(e)}")
raise HTTPException(status_code=500, detail=str(e))
@router.get("/documents/{doc_id}")
async def get_document_detail(
doc_id: str,
rag: LightRAG = Depends(get_current_rag)
):
"""
Get document detail
Returns the full content of the specified document
"""
try:
# Read kv_store_full_docs.json
full_docs_path = os.path.join(rag.working_dir, "kv_store_full_docs.json")
if not os.path.exists(full_docs_path):
raise HTTPException(status_code=404, detail="Document store not found")
with open(full_docs_path, "r", encoding="utf-8") as f:
data = json.load(f)
if doc_id not in data:
raise HTTPException(status_code=404, detail="Document not found")
doc_info = data[doc_id]
return {
"id": doc_id,
"content": doc_info.get("content", ""),
"create_time": doc_info.get("create_time"),
"file_path": doc_info.get("file_path", "unknown")
}
except HTTPException:
raise
except Exception as e:
logging.error(f"获取文档详情失败: {str(e)}")
raise HTTPException(status_code=500, detail=str(e))
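
A matching client sketch for the detail endpoint (illustrative). The doc id below is hypothetical; real ids come from the /documents listing, and the path assumes the router is mounted at the application root.

import httpx

resp = httpx.get("http://localhost:9600/documents/doc-1234567890abcdef")
detail = resp.json()
print(detail["file_path"], len(detail["content"]))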
@router.delete("/docs/{doc_id}")
async def delete_document(
doc_id: str,

View File

@ -11,15 +11,28 @@ class Settings(BaseSettings):
# Data
DATA_DIR: str = "./index_data"
# LLM
LLM_BINDING: str = "ollama"
LLM_BINDING_HOST: str = "http://localhost:11434"
LLM_MODEL: str = "deepseek-v3.2:cloud"
# LLM (Text) - vLLM
LLM_BINDING: str = "vllm" # ollama, vllm, openai
LLM_BINDING_HOST: str = "http://192.168.6.115:8002/v1" # vLLM OpenAI API base
LLM_MODEL: str = "qwen2.5-7b-awq"
LLM_KEY: str = "EMPTY" # vLLM default key
# Embedding
EMBEDDING_BINDING: str = "ollama"
EMBEDDING_BINDING_HOST: str = "http://localhost:11434"
EMBEDDING_MODEL: str = "bge-m3"
# LLM (Vision) - vLLM
VL_BINDING_HOST: str = "http://192.168.6.115:8001/v1"
VL_MODEL: str = "qwen2.5-vl-3b-awq"
VL_KEY: str = "EMPTY"
# Embedding - TEI
EMBEDDING_BINDING: str = "tei" # ollama, tei, openai
EMBEDDING_BINDING_HOST: str = "http://192.168.6.115:8003" # TEI usually exposes /embed
EMBEDDING_MODEL: str = "BAAI/bge-m3" # model id in TEI
EMBEDDING_KEY: str = "EMPTY"
# Rerank - TEI
RERANK_ENABLED: bool = True
RERANK_BINDING_HOST: str = "http://192.168.6.115:8004"
RERANK_MODEL: str = "BAAI/bge-reranker-v2-m3"
RERANK_KEY: str = "EMPTY"
# RAG Config
EMBEDDING_DIM: int = 1024

app/core/ingest.py Normal file (+88 lines)
View File

@ -0,0 +1,88 @@
import base64
import logging
import httpx
from io import BytesIO
from app.config import settings
async def vl_image_caption_func(image_data: bytes, prompt: str = "请详细描述这张图片") -> str:
"""
Generate an image caption with the VL model (vLLM OpenAI API)
"""
if not settings.VL_BINDING_HOST:
return "[Image Processing Skipped: No VL Model Configured]"
try:
# 1. Encode the image as Base64
base64_image = base64.b64encode(image_data).decode('utf-8')
# 2. Build an OpenAI-format request
# vLLM supports the OpenAI Vision API
url = f"{settings.VL_BINDING_HOST}/chat/completions"
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {settings.VL_KEY}"
}
payload = {
"model": settings.VL_MODEL,
"messages": [
{
"role": "user",
"content": [
{"type": "text", "text": prompt},
{
"type": "image_url",
"image_url": {
"url": f"data:image/jpeg;base64,{base64_image}"
}
}
]
}
],
"max_tokens": 300
}
async with httpx.AsyncClient(timeout=30.0) as client:
response = await client.post(url, headers=headers, json=payload)
response.raise_for_status()
result = response.json()
description = result['choices'][0]['message']['content']
return f"[Image Description: {description}]"
except Exception as e:
logging.error(f"VL Caption failed: {str(e)}")
return f"[Image Processing Failed: {str(e)}]"
async def process_pdf_with_images(file_bytes: bytes) -> str:
"""
Parse the PDF: extract the text and caption any images
"""
import pypdf
from PIL import Image
text_content = ""
pdf_file = BytesIO(file_bytes)
reader = pypdf.PdfReader(pdf_file)
for page_num, page in enumerate(reader.pages):
# 1. Extract text
page_text = page.extract_text()
text_content += f"--- Page {page_num + 1} Text ---\n{page_text}\n\n"
# 2. Extract images
if settings.VL_BINDING_HOST:
for count, image_file_object in enumerate(page.images):
try:
# Get the raw image bytes
image_data = image_file_object.data
# Optional: quick sanity check that the image is valid
# Image.open(BytesIO(image_data)).verify()
# Call the VL model
caption = await vl_image_caption_func(image_data)
text_content += f"--- Page {page_num + 1} Image {count + 1} ---\n{caption}\n\n"
except Exception as e:
logging.warning(f"Failed to process image {count} on page {page_num}: {e}")
return text_content
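
A short usage sketch for the new ingest helper (not part of this commit); manual.pdf is a hypothetical local file, and VL captioning only runs when VL_BINDING_HOST is set.

import asyncio
from app.core.ingest import process_pdf_with_images

async def demo():
    with open("manual.pdf", "rb") as f:
        text = await process_pdf_with_images(f.read())
    print(text[:500])  # page text plus [Image Description: ...] blocks

asyncio.run(demo())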

View File

@ -1,11 +1,12 @@
import logging
import os
import threading
from collections import OrderedDict
from typing import Optional
import httpx
import numpy as np
import ollama
from collections import OrderedDict
from typing import Optional, List, Union
from lightrag import LightRAG
from lightrag.utils import EmbeddingFunc
from lightrag.llm.ollama import ollama_embed
@ -14,23 +15,22 @@ from app.config import settings
# Global RAG manager
rag_manager = None
async def llm_func(prompt, system_prompt=None, history_messages=[], **kwargs) -> str:
"""定义 LLM 函数:使用 Ollama 生成回复"""
# 移除可能存在的 model 参数,避免冲突
# ==============================================================================
# LLM Functions
# ==============================================================================
async def ollama_llm_func(prompt, system_prompt=None, history_messages=[], **kwargs) -> str:
"""Ollama LLM 实现"""
# 参数清理
kwargs.pop('model', None)
kwargs.pop('hashing_kv', None)
kwargs.pop('enable_cot', None) # remove unsupported parameter
kwargs.pop('enable_cot', None)
keyword_extraction = kwargs.pop("keyword_extraction", False)
if keyword_extraction:
kwargs["format"] = "json"
stream = kwargs.pop("stream", False)
think = kwargs.pop("think", False)
# Debug: log the streaming flag
if stream:
logging.info("LLM called with stream=True")
else:
logging.info("LLM called with stream=False")
messages = []
if system_prompt:
@ -41,30 +41,169 @@ async def llm_func(prompt, system_prompt=None, history_messages=[], **kwargs) ->
client = ollama.AsyncClient(host=settings.LLM_BINDING_HOST)
if stream:
async def inner():
# Pass parameters through **kwargs so top-level options such as format take effect
response = await client.chat(model=settings.LLM_MODEL, messages=messages, stream=True, think=think, **kwargs)
async for chunk in response:
msg = chunk.get("message", {})
if "thinking" in msg and msg["thinking"]:
yield {"type": "thinking", "content": msg["thinking"]}
if "content" in msg and msg["content"]:
yield {"type": "content", "content": msg["content"]}
return inner()
else:
# Non-streaming requests disable thinking to reduce latency
response = await client.chat(model=settings.LLM_MODEL, messages=messages, stream=False, think=False, **kwargs)
return response["message"]["content"]
async def embedding_func(texts: list[str]) -> np.ndarray:
"""定义 Embedding 函数:使用 Ollama 计算向量"""
async def openai_llm_func(prompt, system_prompt=None, history_messages=[], **kwargs) -> str:
"""OpenAI 兼容 LLM 实现 (适用于 vLLM)"""
# 参数清理
kwargs.pop('model', None)
kwargs.pop('hashing_kv', None)
kwargs.pop('enable_cot', None)
keyword_extraction = kwargs.pop("keyword_extraction", False)
# vLLM/OpenAI does not accept format="json" directly; JSON output is requested via response_format
# (LightRAG's prompts usually already contain an explicit JSON instruction)
if keyword_extraction:
kwargs["response_format"] = {"type": "json_object"}
stream = kwargs.pop("stream", False)
# The think parameter is DeepSeek-specific; the standard OpenAI interface does not support it, so ignore it for now
kwargs.pop("think", None)
messages = []
if system_prompt:
messages.append({"role": "system", "content": system_prompt})
messages.extend(history_messages)
messages.append({"role": "user", "content": prompt})
url = f"{settings.LLM_BINDING_HOST}/chat/completions"
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {settings.LLM_KEY}"
}
payload = {
"model": settings.LLM_MODEL,
"messages": messages,
"stream": stream,
**kwargs
}
if stream:
async def inner():
async with httpx.AsyncClient(timeout=120.0) as client:
async with client.stream("POST", url, headers=headers, json=payload) as response:
response.raise_for_status()
async for line in response.aiter_lines():
if line.startswith("data: "):
data_str = line[6:]
if data_str.strip() == "[DONE]":
break
try:
import json
chunk = json.loads(data_str)
delta = chunk['choices'][0]['delta']
if 'content' in delta:
yield {"type": "content", "content": delta['content']}
# vLLM may not return a thinking field unless the model (e.g. DeepSeek) is configured to expose one
except (json.JSONDecodeError, KeyError, IndexError):
pass  # skip malformed SSE chunks
return inner()
else:
async with httpx.AsyncClient(timeout=120.0) as client:
response = await client.post(url, headers=headers, json=payload)
response.raise_for_status()
result = response.json()
return result['choices'][0]['message']['content']
async def llm_func(prompt, system_prompt=None, history_messages=[], **kwargs) -> str:
"""LLM 调度函数"""
if settings.LLM_BINDING == "ollama":
return await ollama_llm_func(prompt, system_prompt, history_messages, **kwargs)
elif settings.LLM_BINDING in ["vllm", "openai", "custom"]:
return await openai_llm_func(prompt, system_prompt, history_messages, **kwargs)
else:
raise ValueError(f"Unsupported LLM_BINDING: {settings.LLM_BINDING}")
# ==============================================================================
# Embedding Functions
# ==============================================================================
async def ollama_embedding_func(texts: list[str]) -> np.ndarray:
return await ollama_embed(
texts,
embed_model=settings.EMBEDDING_MODEL,
host=settings.EMBEDDING_BINDING_HOST
)
async def tei_embedding_func(texts: list[str]) -> np.ndarray:
"""TEI (Text Embeddings Inference) Embedding 实现"""
url = f"{settings.EMBEDDING_BINDING_HOST}/embed" # TEI 标准接口
headers = {"Content-Type": "application/json"}
if settings.EMBEDDING_KEY and settings.EMBEDDING_KEY != "EMPTY":
headers["Authorization"] = f"Bearer {settings.EMBEDDING_KEY}"
payload = {
"inputs": texts,
"truncate": True # TEI 参数,防止超长报错
}
async with httpx.AsyncClient(timeout=60.0) as client:
response = await client.post(url, headers=headers, json=payload)
response.raise_for_status()
# TEI returns: [[0.1, ...], [0.2, ...]]
embeddings = response.json()
return np.array(embeddings)
async def embedding_func(texts: list[str]) -> np.ndarray:
"""Embedding 调度函数"""
if settings.EMBEDDING_BINDING == "ollama":
return await ollama_embedding_func(texts)
elif settings.EMBEDDING_BINDING == "tei":
return await tei_embedding_func(texts)
else:
# Unknown binding: raise instead of silently falling back to ollama
raise ValueError(f"Unsupported EMBEDDING_BINDING: {settings.EMBEDDING_BINDING}")
# ==============================================================================
# Rerank Functions
# ==============================================================================
async def tei_rerank_func(query: str, documents: list[str]) -> np.ndarray:
"""TEI Rerank 实现"""
if not documents:
return np.array([])
url = f"{settings.RERANK_BINDING_HOST}/rerank"
headers = {"Content-Type": "application/json"}
if settings.RERANK_KEY and settings.RERANK_KEY != "EMPTY":
headers["Authorization"] = f"Bearer {settings.RERANK_KEY}"
payload = {
"query": query,
"texts": documents,
"return_text": False
}
async with httpx.AsyncClient(timeout=30.0) as client:
response = await client.post(url, headers=headers, json=payload)
response.raise_for_status()
# TEI 返回: [{"index": 0, "score": 0.99}, {"index": 1, "score": 0.5}]
results = response.json()
# LightRAG expects an array of scores aligned with the order of the input documents
# TEI returns results sorted by score, so restore the original order using the index field
scores = np.zeros(len(documents))
for res in results:
idx = res['index']
scores[idx] = res['score']
return scores
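
A tiny worked example of the reordering above, using a hypothetical TEI response (sorted by score) mapped back to the input document order.

import numpy as np

results = [{"index": 2, "score": 0.91}, {"index": 0, "score": 0.40}, {"index": 1, "score": 0.05}]
scores = np.zeros(3)
for res in results:
    scores[res["index"]] = res["score"]
print(scores)  # scores for doc0, doc1, doc2: 0.40, 0.05, 0.91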
# ==============================================================================
# RAG Manager
# ==============================================================================
class RAGManager:
def __init__(self, capacity: int = 3):
self.capacity = capacity
@ -73,44 +212,44 @@ class RAGManager:
async def get_rag(self, user_id: str) -> LightRAG:
"""获取指定用户的 LightRAG 实例 (LRU 缓存)"""
# 1. 尝试从缓存获取
with self.lock:
if user_id in self.cache:
self.cache.move_to_end(user_id)
logging.debug(f"Cache hit for user: {user_id}")
return self.cache[user_id]
# 2. Cache miss: initialize a new instance
logging.info(f"Initializing RAG instance for user: {user_id}")
user_data_dir = os.path.join(settings.DATA_DIR, user_id)
# Strict isolation: the working directory must be completely clean and must not include parent-directory or other users' data
if not os.path.exists(user_data_dir):
os.makedirs(user_data_dir)
# Instantiate LightRAG
# Key fix: ensure working_dir points strictly to the user's subdirectory
rag = LightRAG(
working_dir=user_data_dir,
llm_model_func=llm_func,
llm_model_name=settings.LLM_MODEL,
llm_model_max_async=1,
max_parallel_insert=1,
embedding_func=EmbeddingFunc(
# Prepare constructor arguments
rag_params = {
"working_dir": user_data_dir,
"llm_model_func": llm_func,
"llm_model_name": settings.LLM_MODEL,
"llm_model_max_async": 4, # vLLM 并发能力强,可以调高
"max_parallel_insert": 1,
"embedding_func": EmbeddingFunc(
embedding_dim=settings.EMBEDDING_DIM,
max_token_size=settings.MAX_TOKEN_SIZE,
func=embedding_func
),
embedding_func_max_async=4,
enable_llm_cache=True
)
"embedding_func_max_async": 8, # TEI 并发强
"enable_llm_cache": True
}
# If rerank is enabled, inject rerank_model_func
if settings.RERANK_ENABLED:
logging.info("Rerank enabled for RAG instance")
rag_params["rerank_model_func"] = tei_rerank_func
rag = LightRAG(**rag_params)
# Asynchronously initialize the storages
await rag.initialize_storages()
# 3. Put into the cache and handle eviction
with self.lock:
# Double-check: another thread may have created the instance during initialization
if user_id in self.cache:
self.cache.move_to_end(user_id)
return self.cache[user_id]
@ -118,23 +257,19 @@ class RAGManager:
self.cache[user_id] = rag
self.cache.move_to_end(user_id)
# Enforce the capacity limit
while len(self.cache) > self.capacity:
oldest_user, _ = self.cache.popitem(last=False)
logging.info(f"Evicting RAG instance for user: {oldest_user} due to capacity limit")
# Resource reclamation relies on the Python GC
logging.info(f"Evicting RAG instance for user: {oldest_user}")
return rag
def initialize_rag_manager():
"""初始化全局 RAG 管理器"""
global rag_manager
rag_manager = RAGManager(capacity=settings.MAX_RAG_INSTANCES)
logging.info(f"RAG Manager initialized with capacity: {settings.MAX_RAG_INSTANCES}")
return rag_manager
def get_rag_manager() -> RAGManager:
"""获取全局 RAG 管理器"""
if rag_manager is None:
return initialize_rag_manager()
return rag_manager
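
Finally, a minimal usage sketch for the manager (illustrative, not part of this commit); "user_42" is a hypothetical user id.

import asyncio
from app.core.rag import get_rag_manager

async def demo():
    manager = get_rag_manager()              # lazily creates the global manager
    rag = await manager.get_rag("user_42")   # per-user LightRAG instance, LRU-cached
    await rag.ainsert("Some text to index for this user.")

asyncio.run(demo())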