fix: 1. Switch config from ollama to vllm + TEI  2. Change the text model from ollama-cloud to vllm  3. Add a VL model and a rerank model  4. Add a QA knowledge upload endpoint  5. Add a knowledge detail endpoint

parent 13bc171e9d
commit 75044940c7

.env.example (28 lines changed)
@@ -4,15 +4,28 @@ APP_VERSION="1.0"
 HOST="0.0.0.0"
 PORT=9600
 
-# LLM Configuration
-LLM_BINDING=ollama
-LLM_BINDING_HOST=http://localhost:11434
-LLM_MODEL=deepseek-v3.2:cloud
+# LLM(Text) Configuration
+LLM_BINDING=vllm  # ollama, vllm, openai
+LLM_BINDING_HOST=http://192.168.6.115:8002/v1  # vLLM OpenAI API base
+LLM_MODEL=qwen2.5-7b-awq
+LLM_KEY=EMPTY  # vLLM default key
+
+# LLM(Vision) Configuration
+VL_BINDING_HOST=http://192.168.6.115:8001/v1
+VL_MODEL=qwen2.5-vl-3b-awq
+VL_KEY=EMPTY
 
 # Embedding Configuration
-EMBEDDING_BINDING=ollama
-EMBEDDING_BINDING_HOST=http://localhost:11434
-EMBEDDING_MODEL=bge-m3
+EMBEDDING_BINDING=tei  # ollama, tei, openai
+EMBEDDING_BINDING_HOST=http://192.168.6.115:8003  # TEI usually exposes /embed
+EMBEDDING_MODEL=BAAI/bge-m3  # model id in TEI
+EMBEDDING_KEY=EMPTY
+
+# Rerank - TEI
+RERANK_ENABLED=True
+RERANK_BINDING_HOST=http://192.168.6.115:8004
+RERANK_MODEL=BAAI/bge-reranker-v2-m3
+RERANK_KEY=EMPTY
 
 # Storage
 DATA_DIR=./index_data
@@ -20,3 +33,4 @@ DATA_DIR=./index_data
 # RAG Configuration
 EMBEDDING_DIM=1024
 MAX_TOKEN_SIZE=8192
+MAX_RAG_INSTANCES=5  # maximum number of active RAG instances
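Before relying on the new TEI settings, it can help to confirm that the service at EMBEDDING_BINDING_HOST really returns vectors matching EMBEDDING_DIM. A minimal sanity check (not part of this commit; it assumes the TEI server above is reachable and serves BAAI/bge-m3):

```python
import httpx

TEI_HOST = "http://192.168.6.115:8003"  # EMBEDDING_BINDING_HOST from .env.example

# TEI's /embed endpoint accepts {"inputs": [...]} and returns one vector per input.
resp = httpx.post(f"{TEI_HOST}/embed", json={"inputs": ["hello", "world"]}, timeout=30.0)
resp.raise_for_status()
vectors = resp.json()

# bge-m3 embeddings are 1024-dimensional, which must match EMBEDDING_DIM above.
assert len(vectors) == 2
assert len(vectors[0]) == 1024, f"unexpected embedding dim: {len(vectors[0])}"
print("TEI /embed OK, dim =", len(vectors[0]))
```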
@@ -1,15 +1,14 @@
 import json
 import logging
 import os
-import io
 from fastapi import APIRouter, UploadFile, File, HTTPException, Body, Header, Depends
 from fastapi.responses import StreamingResponse
 from pydantic import BaseModel
-from pypdf import PdfReader
 from lightrag import LightRAG, QueryParam
 from app.core.rag import get_rag_manager, llm_func
 from app.core.prompts import CUSTOM_RAG_RESPONSE_PROMPT
 from app.config import settings
+from app.core.ingest import process_pdf_with_images
 
 router = APIRouter()
 
@@ -41,6 +40,10 @@ class IngestResponse(BaseModel):
     status: str
     message: str
 
+class QAPair(BaseModel):
+    question: str
+    answer: str
+
 # ==========================================
 # Endpoint implementations
 # ==========================================
@@ -67,7 +70,7 @@ async def query_knowledge_base(
         mode=request.mode,
         top_k=request.top_k,
         stream=request.stream,
-        enable_rerank=False  # explicitly disable rerank to silence the warning
+        enable_rerank=settings.RERANK_ENABLED
     )
 
     # Handle streaming output (SSE protocol)
@@ -85,7 +88,7 @@ async def query_knowledge_base(
         mode=request.mode,
         top_k=request.top_k,
         only_need_context=True,
-        enable_rerank=False  # explicitly disable rerank to silence the warning
+        enable_rerank=settings.RERANK_ENABLED
     )
 
     # Fetch the context (this step is relatively slow; it includes graph traversal)
@@ -158,6 +161,55 @@ async def ingest_text(
     except Exception as e:
         raise HTTPException(status_code=500, detail=str(e))
 
+@router.post("/ingest/batch_qa")
+async def ingest_batch_qa(
+    qa_list: list[QAPair],
+    rag: LightRAG = Depends(get_current_rag)
+):
+    """
+    Ingest QA pairs in batch
+    - Formats each QA pair into a semantically coherent text block
+    - Merges short QA pairs to keep indexing efficient
+    """
+    if not qa_list:
+        return {"status": "skipped", "message": "Empty QA list"}
+
+    try:
+        # 1. Format and merge the text
+        batch_text = ""
+        current_batch_size = 0
+        MAX_BATCH_CHARS = 2000  # roughly 1000 tokens, conservative estimate
+
+        inserted_count = 0
+
+        for qa in qa_list:
+            # Format a single QA pair
+            entry = f"--- Q&A Entry ---\nQuestion: {qa.question}\nAnswer: {qa.answer}\n\n"
+            entry_len = len(entry)
+
+            # If the current batch is getting too large, flush it first
+            if current_batch_size + entry_len > MAX_BATCH_CHARS:
+                await rag.ainsert(batch_text)
+                batch_text = ""
+                current_batch_size = 0
+                inserted_count += 1
+
+            batch_text += entry
+            current_batch_size += entry_len
+
+        # Flush the remaining text
+        if batch_text:
+            await rag.ainsert(batch_text)
+            inserted_count += 1
+
+        return {
+            "status": "success",
+            "message": f"Successfully processed {len(qa_list)} QA pairs into {inserted_count} text chunks."
+        }
+    except Exception as e:
+        logging.error(f"QA批量导入失败: {str(e)}")
+        raise HTTPException(status_code=500, detail=str(e))
+
 @router.post("/ingest/file", response_model=IngestResponse)
 async def upload_file(
     file: UploadFile = File(...),
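For reference, the new batch endpoint takes a plain JSON array of question/answer objects. A client-side sketch is below; the base URL, route prefix, and the header consumed by get_current_rag are placeholders and depend on how this router is mounted and how users are identified:

```python
import httpx

BASE_URL = "http://localhost:9600"      # PORT from .env.example; adjust to the real host/prefix
HEADERS = {"X-User-Id": "demo_user"}    # placeholder: whatever header get_current_rag actually expects

qa_payload = [
    {"question": "What is LightRAG?", "answer": "A lightweight retrieval-augmented generation library."},
    {"question": "Which reranker is used?", "answer": "BAAI/bge-reranker-v2-m3 served by TEI."},
]

# The endpoint merges pairs into roughly 2000-character chunks before calling rag.ainsert().
resp = httpx.post(f"{BASE_URL}/ingest/batch_qa", json=qa_payload, headers=HEADERS, timeout=120.0)
resp.raise_for_status()
print(resp.json())  # {"status": "success", "message": "Successfully processed ..."}
```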
@@ -176,10 +228,8 @@ async def upload_file(
 
     # Parse according to file type
     if filename.endswith(".pdf"):
-        # PDF parsing logic
-        pdf_reader = PdfReader(io.BytesIO(file_bytes))
-        for page in pdf_reader.pages:
-            content += page.extract_text() + "\n"
+        # PDF parsing logic (text + images)
+        content = await process_pdf_with_images(file_bytes)
     elif filename.endswith(".txt") or filename.endswith(".md"):
         # Text files are decoded directly
         content = file_bytes.decode("utf-8")
@@ -247,6 +297,40 @@ async def list_documents(
         logging.error(f"获取文档列表失败: {str(e)}")
         raise HTTPException(status_code=500, detail=str(e))
 
+@router.get("/documents/{doc_id}")
+async def get_document_detail(
+    doc_id: str,
+    rag: LightRAG = Depends(get_current_rag)
+):
+    """
+    Get document detail
+    Returns the full content of the given document
+    """
+    try:
+        # Read kv_store_full_docs.json
+        full_docs_path = os.path.join(rag.working_dir, "kv_store_full_docs.json")
+        if not os.path.exists(full_docs_path):
+            raise HTTPException(status_code=404, detail="Document store not found")
+
+        with open(full_docs_path, "r", encoding="utf-8") as f:
+            data = json.load(f)
+
+        if doc_id not in data:
+            raise HTTPException(status_code=404, detail="Document not found")
+
+        doc_info = data[doc_id]
+        return {
+            "id": doc_id,
+            "content": doc_info.get("content", ""),
+            "create_time": doc_info.get("create_time"),
+            "file_path": doc_info.get("file_path", "unknown")
+        }
+    except HTTPException:
+        raise
+    except Exception as e:
+        logging.error(f"获取文档详情失败: {str(e)}")
+        raise HTTPException(status_code=500, detail=str(e))
+
 @router.delete("/docs/{doc_id}")
 async def delete_document(
     doc_id: str,
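A matching client sketch for the new detail route, with the same caveats about the base URL and the get_current_rag header:

```python
import httpx

BASE_URL = "http://localhost:9600"    # placeholder; PORT from .env.example
HEADERS = {"X-User-Id": "demo_user"}  # placeholder header for get_current_rag

doc_id = "doc-xxxx"  # an id previously returned by the document list endpoint
resp = httpx.get(f"{BASE_URL}/documents/{doc_id}", headers=HEADERS, timeout=30.0)

if resp.status_code == 404:
    print("document or document store not found")
else:
    resp.raise_for_status()
    detail = resp.json()
    # Fields returned by the endpoint: id, content, create_time, file_path
    print(detail["id"], detail["file_path"], len(detail["content"]))
```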
app/config.py
@@ -11,15 +11,28 @@ class Settings(BaseSettings):
     # Data
     DATA_DIR: str = "./index_data"
 
-    # LLM
-    LLM_BINDING: str = "ollama"
-    LLM_BINDING_HOST: str = "http://localhost:11434"
-    LLM_MODEL: str = "deepseek-v3.2:cloud"
+    # LLM (Text) - vLLM
+    LLM_BINDING: str = "vllm"  # ollama, vllm, openai
+    LLM_BINDING_HOST: str = "http://192.168.6.115:8002/v1"  # vLLM OpenAI API base
+    LLM_MODEL: str = "qwen2.5-7b-awq"
+    LLM_KEY: str = "EMPTY"  # vLLM default key
+
+    # LLM (Vision) - vLLM
+    VL_BINDING_HOST: str = "http://192.168.6.115:8001/v1"
+    VL_MODEL: str = "qwen2.5-vl-3b-awq"
+    VL_KEY: str = "EMPTY"
 
-    # Embedding
-    EMBEDDING_BINDING: str = "ollama"
-    EMBEDDING_BINDING_HOST: str = "http://localhost:11434"
-    EMBEDDING_MODEL: str = "bge-m3"
+    # Embedding - TEI
+    EMBEDDING_BINDING: str = "tei"  # ollama, tei, openai
+    EMBEDDING_BINDING_HOST: str = "http://192.168.6.115:8003"  # TEI usually exposes /embed
+    EMBEDDING_MODEL: str = "BAAI/bge-m3"  # model id in TEI
+    EMBEDDING_KEY: str = "EMPTY"
+
+    # Rerank - TEI
+    RERANK_ENABLED: bool = True
+    RERANK_BINDING_HOST: str = "http://192.168.6.115:8004"
+    RERANK_MODEL: str = "BAAI/bge-reranker-v2-m3"
+    RERANK_KEY: str = "EMPTY"
 
     # RAG Config
     EMBEDDING_DIM: int = 1024
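Since LLM_BINDING now points at a vLLM OpenAI-compatible server, a one-off request against /chat/completions is a quick way to confirm that the host, model name, and key defaults line up. A hedged sketch, assuming the vLLM server above is reachable and the app package is importable:

```python
import httpx

from app.config import settings

payload = {
    "model": settings.LLM_MODEL,  # "qwen2.5-7b-awq"
    "messages": [{"role": "user", "content": "ping"}],
    "max_tokens": 8,
}
headers = {"Authorization": f"Bearer {settings.LLM_KEY}"}  # "EMPTY" works for a default vLLM setup

resp = httpx.post(
    f"{settings.LLM_BINDING_HOST}/chat/completions",
    json=payload,
    headers=headers,
    timeout=60.0,
)
resp.raise_for_status()
print(resp.json()["choices"][0]["message"]["content"])
```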
app/core/ingest.py
@@ -0,0 +1,88 @@
+import base64
+import logging
+import httpx
+from io import BytesIO
+from app.config import settings
+
+async def vl_image_caption_func(image_data: bytes, prompt: str = "请详细描述这张图片") -> str:
+    """
+    Generate an image caption with the VL model (vLLM OpenAI API)
+    """
+    if not settings.VL_BINDING_HOST:
+        return "[Image Processing Skipped: No VL Model Configured]"
+
+    try:
+        # 1. Encode the image as Base64
+        base64_image = base64.b64encode(image_data).decode('utf-8')
+
+        # 2. Build an OpenAI-style request
+        # vLLM supports the OpenAI Vision API
+        url = f"{settings.VL_BINDING_HOST}/chat/completions"
+        headers = {
+            "Content-Type": "application/json",
+            "Authorization": f"Bearer {settings.VL_KEY}"
+        }
+
+        payload = {
+            "model": settings.VL_MODEL,
+            "messages": [
+                {
+                    "role": "user",
+                    "content": [
+                        {"type": "text", "text": prompt},
+                        {
+                            "type": "image_url",
+                            "image_url": {
+                                "url": f"data:image/jpeg;base64,{base64_image}"
+                            }
+                        }
+                    ]
+                }
+            ],
+            "max_tokens": 300
+        }
+
+        async with httpx.AsyncClient(timeout=30.0) as client:
+            response = await client.post(url, headers=headers, json=payload)
+            response.raise_for_status()
+            result = response.json()
+            description = result['choices'][0]['message']['content']
+            return f"[Image Description: {description}]"
+
+    except Exception as e:
+        logging.error(f"VL Caption failed: {str(e)}")
+        return f"[Image Processing Failed: {str(e)}]"
+
+async def process_pdf_with_images(file_bytes: bytes) -> str:
+    """
+    Parse a PDF: extract the text and caption any embedded images
+    """
+    import pypdf
+    from PIL import Image
+
+    text_content = ""
+    pdf_file = BytesIO(file_bytes)
+    reader = pypdf.PdfReader(pdf_file)
+
+    for page_num, page in enumerate(reader.pages):
+        # 1. Extract the text
+        page_text = page.extract_text()
+        text_content += f"--- Page {page_num + 1} Text ---\n{page_text}\n\n"
+
+        # 2. Extract the images
+        if settings.VL_BINDING_HOST:
+            for count, image_file_object in enumerate(page.images):
+                try:
+                    # Get the raw image bytes
+                    image_data = image_file_object.data
+
+                    # Optional lightweight validity check
+                    # Image.open(BytesIO(image_data)).verify()
+
+                    # Call the VL model
+                    caption = await vl_image_caption_func(image_data)
+                    text_content += f"--- Page {page_num + 1} Image {count + 1} ---\n{caption}\n\n"
+                except Exception as e:
+                    logging.warning(f"Failed to process image {count} on page {page_num}: {e}")
+
+    return text_content
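The new module can be exercised outside FastAPI by running it over a local file. A small driver sketch; the PDF path is a placeholder, and image captioning only succeeds when the VL endpoint from the config is reachable (each failed image falls back to a failure marker):

```python
import asyncio

from app.core.ingest import process_pdf_with_images

async def main() -> None:
    # sample.pdf is a placeholder path
    with open("sample.pdf", "rb") as f:
        file_bytes = f.read()

    # Returns interleaved "--- Page N Text ---" and "--- Page N Image M ---" blocks.
    text = await process_pdf_with_images(file_bytes)
    print(text[:500])

if __name__ == "__main__":
    asyncio.run(main())
```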
app/core/rag.py (217 lines changed)
@@ -1,11 +1,12 @@
 import logging
 import os
 import threading
-from collections import OrderedDict
-from typing import Optional
+import httpx
 
 import numpy as np
 import ollama
+from collections import OrderedDict
+from typing import Optional, List, Union
 
 from lightrag import LightRAG
 from lightrag.utils import EmbeddingFunc
 from lightrag.llm.ollama import ollama_embed
@@ -14,23 +15,22 @@ from app.config import settings
 # Global RAG manager
 rag_manager = None
 
-async def llm_func(prompt, system_prompt=None, history_messages=[], **kwargs) -> str:
-    """Define the LLM function: generate replies with Ollama"""
-    # Drop a possible model kwarg to avoid conflicts
+# ==============================================================================
+# LLM Functions
+# ==============================================================================
+
+async def ollama_llm_func(prompt, system_prompt=None, history_messages=[], **kwargs) -> str:
+    """Ollama LLM implementation"""
+    # Clean up kwargs
     kwargs.pop('model', None)
     kwargs.pop('hashing_kv', None)
-    kwargs.pop('enable_cot', None)  # drop unsupported parameter
+    kwargs.pop('enable_cot', None)
     keyword_extraction = kwargs.pop("keyword_extraction", False)
     if keyword_extraction:
         kwargs["format"] = "json"
 
     stream = kwargs.pop("stream", False)
     think = kwargs.pop("think", False)
-    # Debug: check the streaming flag
-    if stream:
-        logging.info("LLM called with stream=True")
-    else:
-        logging.info("LLM called with stream=False")
 
     messages = []
     if system_prompt:
@@ -41,30 +41,169 @@ async def llm_func(prompt, system_prompt=None, history_messages=[], **kwargs) -> str:
     client = ollama.AsyncClient(host=settings.LLM_BINDING_HOST)
     if stream:
         async def inner():
-            # Pass **kwargs through so top-level parameters like format take effect
             response = await client.chat(model=settings.LLM_MODEL, messages=messages, stream=True, think=think, **kwargs)
 
             async for chunk in response:
                 msg = chunk.get("message", {})
 
                 if "thinking" in msg and msg["thinking"]:
                     yield {"type": "thinking", "content": msg["thinking"]}
                 if "content" in msg and msg["content"]:
                     yield {"type": "content", "content": msg["content"]}
         return inner()
     else:
-        # Non-streaming requests disable think to reduce latency
         response = await client.chat(model=settings.LLM_MODEL, messages=messages, stream=False, think=False, **kwargs)
         return response["message"]["content"]
 
-async def embedding_func(texts: list[str]) -> np.ndarray:
-    """Define the embedding function: compute vectors with Ollama"""
+async def openai_llm_func(prompt, system_prompt=None, history_messages=[], **kwargs) -> str:
+    """OpenAI-compatible LLM implementation (for vLLM)"""
+    # Clean up kwargs
+    kwargs.pop('model', None)
+    kwargs.pop('hashing_kv', None)
+    kwargs.pop('enable_cot', None)
+    keyword_extraction = kwargs.pop("keyword_extraction", False)
+
+    # vLLM/OpenAI does not accept format="json" directly; it usually has to be hinted in the prompt or via response_format
+    # Keep it simple here: when JSON is needed, rely on the prompt (LightRAG prompts usually already include a JSON instruction)
+    if keyword_extraction:
+        kwargs["response_format"] = {"type": "json_object"}
+
+    stream = kwargs.pop("stream", False)
+    # The think parameter is DeepSeek-specific; the standard OpenAI API does not support it, so ignore it for now
+    kwargs.pop("think", None)
+
+    messages = []
+    if system_prompt:
+        messages.append({"role": "system", "content": system_prompt})
+    messages.extend(history_messages)
+    messages.append({"role": "user", "content": prompt})
+
+    url = f"{settings.LLM_BINDING_HOST}/chat/completions"
+    headers = {
+        "Content-Type": "application/json",
+        "Authorization": f"Bearer {settings.LLM_KEY}"
+    }
+
+    payload = {
+        "model": settings.LLM_MODEL,
+        "messages": messages,
+        "stream": stream,
+        **kwargs
+    }
+
+    if stream:
+        async def inner():
+            async with httpx.AsyncClient(timeout=120.0) as client:
+                async with client.stream("POST", url, headers=headers, json=payload) as response:
+                    response.raise_for_status()
+                    async for line in response.aiter_lines():
+                        if line.startswith("data: "):
+                            data_str = line[6:]
+                            if data_str.strip() == "[DONE]":
+                                break
+                            try:
+                                import json
+                                chunk = json.loads(data_str)
+                                delta = chunk['choices'][0]['delta']
+                                if 'content' in delta:
+                                    yield {"type": "content", "content": delta['content']}
+                                # vLLM may not emit a thinking field unless it is a DeepSeek model configured for it
+                            except:
+                                pass
+        return inner()
+    else:
+        async with httpx.AsyncClient(timeout=120.0) as client:
+            response = await client.post(url, headers=headers, json=payload)
+            response.raise_for_status()
+            result = response.json()
+            return result['choices'][0]['message']['content']
+
+async def llm_func(prompt, system_prompt=None, history_messages=[], **kwargs) -> str:
+    """LLM dispatch function"""
+    if settings.LLM_BINDING == "ollama":
+        return await ollama_llm_func(prompt, system_prompt, history_messages, **kwargs)
+    elif settings.LLM_BINDING in ["vllm", "openai", "custom"]:
+        return await openai_llm_func(prompt, system_prompt, history_messages, **kwargs)
+    else:
+        raise ValueError(f"Unsupported LLM_BINDING: {settings.LLM_BINDING}")
+
+# ==============================================================================
+# Embedding Functions
+# ==============================================================================
+
+async def ollama_embedding_func(texts: list[str]) -> np.ndarray:
     return await ollama_embed(
         texts,
         embed_model=settings.EMBEDDING_MODEL,
         host=settings.EMBEDDING_BINDING_HOST
     )
+
+async def tei_embedding_func(texts: list[str]) -> np.ndarray:
+    """TEI (Text Embeddings Inference) embedding implementation"""
+    url = f"{settings.EMBEDDING_BINDING_HOST}/embed"  # TEI standard endpoint
+    headers = {"Content-Type": "application/json"}
+    if settings.EMBEDDING_KEY and settings.EMBEDDING_KEY != "EMPTY":
+        headers["Authorization"] = f"Bearer {settings.EMBEDDING_KEY}"
+
+    payload = {
+        "inputs": texts,
+        "truncate": True  # TEI option; avoids errors on overly long inputs
+    }
+
+    async with httpx.AsyncClient(timeout=60.0) as client:
+        response = await client.post(url, headers=headers, json=payload)
+        response.raise_for_status()
+        # TEI returns: [[0.1, ...], [0.2, ...]]
+        embeddings = response.json()
+        return np.array(embeddings)
+
+async def embedding_func(texts: list[str]) -> np.ndarray:
+    """Embedding dispatch function"""
+    if settings.EMBEDDING_BINDING == "ollama":
+        return await ollama_embedding_func(texts)
+    elif settings.EMBEDDING_BINDING == "tei":
+        return await tei_embedding_func(texts)
+    else:
+        # Fall back to ollama by default, or raise an error
+        raise ValueError(f"Unsupported EMBEDDING_BINDING: {settings.EMBEDDING_BINDING}")
+
+# ==============================================================================
+# Rerank Functions
+# ==============================================================================
+
+async def tei_rerank_func(query: str, documents: list[str]) -> np.ndarray:
+    """TEI rerank implementation"""
+    if not documents:
+        return np.array([])
+
+    url = f"{settings.RERANK_BINDING_HOST}/rerank"
+    headers = {"Content-Type": "application/json"}
+    if settings.RERANK_KEY and settings.RERANK_KEY != "EMPTY":
+        headers["Authorization"] = f"Bearer {settings.RERANK_KEY}"
+
+    payload = {
+        "query": query,
+        "texts": documents,
+        "return_text": False
+    }
+
+    async with httpx.AsyncClient(timeout=30.0) as client:
+        response = await client.post(url, headers=headers, json=payload)
+        response.raise_for_status()
+        # TEI returns: [{"index": 0, "score": 0.99}, {"index": 1, "score": 0.5}]
+        results = response.json()
+
+        # LightRAG expects an array of scores in the same order as the input documents
+        # TEI returns results sorted by score, so restore the original order via the index field
+        scores = np.zeros(len(documents))
+        for res in results:
+            idx = res['index']
+            scores[idx] = res['score']
+
+        return scores
+
+# ==============================================================================
+# RAG Manager
+# ==============================================================================
+
 class RAGManager:
     def __init__(self, capacity: int = 3):
         self.capacity = capacity
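The rerank helper can be smoke-tested on its own before it is injected into a RAG instance. A sketch, assuming the TEI reranker from the config is running:

```python
import asyncio

from app.core.rag import tei_rerank_func

async def main() -> None:
    query = "How are TEI embeddings configured?"
    docs = [
        "TEI serves BAAI/bge-m3 embeddings over the /embed endpoint.",
        "vLLM exposes an OpenAI-compatible /v1/chat/completions API.",
    ]
    scores = await tei_rerank_func(query, docs)
    # One score per input document, restored to the original document order.
    for doc, score in zip(docs, scores):
        print(f"{score:.3f}  {doc}")

if __name__ == "__main__":
    asyncio.run(main())
```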
@@ -73,44 +212,44 @@ class RAGManager:
 
     async def get_rag(self, user_id: str) -> LightRAG:
         """Get the LightRAG instance for a given user (LRU cache)"""
-        # 1. Try the cache first
         with self.lock:
             if user_id in self.cache:
                 self.cache.move_to_end(user_id)
                 logging.debug(f"Cache hit for user: {user_id}")
                 return self.cache[user_id]
 
-        # 2. Cache miss, initialize a new instance
         logging.info(f"Initializing RAG instance for user: {user_id}")
         user_data_dir = os.path.join(settings.DATA_DIR, user_id)
 
-        # Strict isolation: the working directory must stay clean and never contain parent or other users' data
         if not os.path.exists(user_data_dir):
             os.makedirs(user_data_dir)
 
-        # Instantiate LightRAG
-        # Key fix: make sure working_dir points strictly at the per-user subdirectory
-        rag = LightRAG(
-            working_dir=user_data_dir,
-            llm_model_func=llm_func,
-            llm_model_name=settings.LLM_MODEL,
-            llm_model_max_async=1,
-            max_parallel_insert=1,
-            embedding_func=EmbeddingFunc(
+        # Prepare the parameters
+        rag_params = {
+            "working_dir": user_data_dir,
+            "llm_model_func": llm_func,
+            "llm_model_name": settings.LLM_MODEL,
+            "llm_model_max_async": 4,  # vLLM handles concurrency well, so this can be raised
+            "max_parallel_insert": 1,
+            "embedding_func": EmbeddingFunc(
                 embedding_dim=settings.EMBEDDING_DIM,
                 max_token_size=settings.MAX_TOKEN_SIZE,
                 func=embedding_func
             ),
-            embedding_func_max_async=4,
-            enable_llm_cache=True
-        )
+            "embedding_func_max_async": 8,  # TEI handles high concurrency
+            "enable_llm_cache": True
+        }
+
+        # If rerank is enabled, inject rerank_model_func
+        if settings.RERANK_ENABLED:
+            logging.info("Rerank enabled for RAG instance")
+            rag_params["rerank_model_func"] = tei_rerank_func
+
+        rag = LightRAG(**rag_params)
 
-        # Initialize the storages asynchronously
         await rag.initialize_storages()
 
-        # 3. Put it into the cache and handle eviction
         with self.lock:
-            # Double check in case another thread created it during initialization
             if user_id in self.cache:
                 self.cache.move_to_end(user_id)
                 return self.cache[user_id]
@@ -118,23 +257,19 @@ class RAGManager:
             self.cache[user_id] = rag
             self.cache.move_to_end(user_id)
 
-            # Check the capacity
             while len(self.cache) > self.capacity:
                 oldest_user, _ = self.cache.popitem(last=False)
-                logging.info(f"Evicting RAG instance for user: {oldest_user} due to capacity limit")
-                # Rely on Python GC to reclaim the resources here
+                logging.info(f"Evicting RAG instance for user: {oldest_user}")
 
         return rag
 
 def initialize_rag_manager():
-    """Initialize the global RAG manager"""
     global rag_manager
     rag_manager = RAGManager(capacity=settings.MAX_RAG_INSTANCES)
     logging.info(f"RAG Manager initialized with capacity: {settings.MAX_RAG_INSTANCES}")
     return rag_manager
 
 def get_rag_manager() -> RAGManager:
-    """Get the global RAG manager"""
     if rag_manager is None:
         return initialize_rag_manager()
     return rag_manager