import json
import logging
import os
import threading
import httpx
import numpy as np
import ollama
from collections import OrderedDict
from typing import Optional, List, Union

from lightrag import LightRAG
from lightrag.utils import EmbeddingFunc
from lightrag.llm.ollama import ollama_embed
from app.config import settings

# Global RAG manager
rag_manager = None

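# Settings consumed by this module (all read from app.config.settings; the
# concrete values live in the project configuration, not here):
#   LLM_BINDING, LLM_BINDING_HOST, LLM_MODEL, LLM_KEY, LLM_MODEL_MAX_ASYNC
#   EMBEDDING_BINDING, EMBEDDING_BINDING_HOST, EMBEDDING_MODEL, EMBEDDING_KEY,
#   EMBEDDING_DIM, MAX_TOKEN_SIZE
#   RERANK_ENABLED, RERANK_BINDING_HOST, RERANK_KEY
#   COSINE_THRESHOLD, DATA_DIR, MAX_RAG_INSTANCES
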
# ==============================================================================
# LLM Functions
# ==============================================================================

async def ollama_llm_func(prompt, system_prompt=None, history_messages=[], **kwargs) -> str:
    """Ollama LLM implementation."""
    # Parameter cleanup: strip kwargs the Ollama client should not receive
    kwargs.pop('model', None)
    kwargs.pop('hashing_kv', None)
    kwargs.pop('enable_cot', None)
    keyword_extraction = kwargs.pop("keyword_extraction", False)
    if keyword_extraction:
        kwargs["format"] = "json"

    stream = kwargs.pop("stream", False)
    think = kwargs.pop("think", False)

    messages = []
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})
    messages.extend(history_messages)
    messages.append({"role": "user", "content": prompt})

    client = ollama.AsyncClient(host=settings.LLM_BINDING_HOST)
    if stream:
        async def inner():
            response = await client.chat(model=settings.LLM_MODEL, messages=messages, stream=True, think=think, **kwargs)
            async for chunk in response:
                msg = chunk.get("message", {})
                if "thinking" in msg and msg["thinking"]:
                    yield {"type": "thinking", "content": msg["thinking"]}
                if "content" in msg and msg["content"]:
                    yield {"type": "content", "content": msg["content"]}
        return inner()
    else:
        response = await client.chat(model=settings.LLM_MODEL, messages=messages, stream=False, think=False, **kwargs)
        return response["message"]["content"]

async def openai_llm_func(prompt, system_prompt=None, history_messages=[], **kwargs) -> str:
    """OpenAI-compatible LLM implementation (suitable for vLLM)."""
    # Parameter cleanup
    kwargs.pop('model', None)
    kwargs.pop('hashing_kv', None)
    kwargs.pop('enable_cot', None)
    keyword_extraction = kwargs.pop("keyword_extraction", False)

    # vLLM/OpenAI does not support format="json" directly; you normally either
    # instruct the model in the prompt or use response_format. LightRAG's
    # prompts usually already contain JSON instructions, so we additionally
    # request a JSON object via response_format here.
    if keyword_extraction:
        kwargs["response_format"] = {"type": "json_object"}

    stream = kwargs.pop("stream", False)
    # The "think" parameter is DeepSeek-specific and not supported by the
    # standard OpenAI interface; instead we pass Qwen3's chat_template_kwargs
    # to enable/disable thinking mode.
    think = kwargs.pop("think", None)
    kwargs["chat_template_kwargs"] = {"enable_thinking": think}

    messages = []
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})
    messages.extend(history_messages)
    messages.append({"role": "user", "content": prompt})

    url = f"{settings.LLM_BINDING_HOST}/chat/completions"
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {settings.LLM_KEY}"
    }

    payload = {
        "model": settings.LLM_MODEL,
        "messages": messages,
        "stream": stream,
        **kwargs
    }

    if stream:
        async def inner():
            async with httpx.AsyncClient(timeout=120.0) as client:
                async with client.stream("POST", url, headers=headers, json=payload) as response:
                    response.raise_for_status()
                    async for line in response.aiter_lines():
                        if line.startswith("data: "):
                            data_str = line[6:]
                            if data_str.strip() == "[DONE]":
                                break
                            try:
                                chunk = json.loads(data_str)
                                delta = chunk['choices'][0]['delta']
                                if 'content' in delta:
                                    yield {"type": "content", "content": delta['content']}
                                # vLLM may not return a "thinking" field unless a
                                # DeepSeek model is served and configured for it.
                            except (json.JSONDecodeError, KeyError, IndexError):
                                # Skip malformed or empty SSE chunks
                                continue
        return inner()
    else:
        async with httpx.AsyncClient(timeout=120.0) as client:
            response = await client.post(url, headers=headers, json=payload)
            response.raise_for_status()
            result = response.json()
            return result['choices'][0]['message']['content']

async def llm_func(prompt, system_prompt=None, history_messages=[], **kwargs) -> str:
    """LLM dispatch function: route to the implementation selected by LLM_BINDING."""
    if settings.LLM_BINDING == "ollama":
        return await ollama_llm_func(prompt, system_prompt, history_messages, **kwargs)
    elif settings.LLM_BINDING in ["vllm", "openai", "custom"]:
        return await openai_llm_func(prompt, system_prompt, history_messages, **kwargs)
    else:
        raise ValueError(f"Unsupported LLM_BINDING: {settings.LLM_BINDING}")

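# Illustrative sketch (not used elsewhere in this module): how a caller can
# consume llm_func. With stream=True the binding implementations above return
# an async generator of {"type": "thinking"|"content", "content": str} chunks;
# without it they return the complete response string.
async def _example_collect_stream(prompt: str) -> str:
    """Example only: collect a streamed llm_func response into one string."""
    parts = []
    chunks = await llm_func(prompt, stream=True)
    async for chunk in chunks:
        if chunk["type"] == "content":
            parts.append(chunk["content"])
    return "".join(parts)
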
# ==============================================================================
# Embedding Functions
# ==============================================================================

async def ollama_embedding_func(texts: list[str]) -> np.ndarray:
    return await ollama_embed(
        texts,
        embed_model=settings.EMBEDDING_MODEL,
        host=settings.EMBEDDING_BINDING_HOST
    )

async def tei_embedding_func(texts: list[str]) -> np.ndarray:
    """TEI (Text Embeddings Inference) embedding implementation."""
    url = f"{settings.EMBEDDING_BINDING_HOST}/embed"  # Standard TEI endpoint
    headers = {"Content-Type": "application/json"}
    if settings.EMBEDDING_KEY and settings.EMBEDDING_KEY != "EMPTY":
        headers["Authorization"] = f"Bearer {settings.EMBEDDING_KEY}"

    payload = {
        "inputs": texts,
        "truncate": True  # TEI option: prevents errors on over-long inputs
    }

    async with httpx.AsyncClient(timeout=60.0) as client:
        response = await client.post(url, headers=headers, json=payload)
        response.raise_for_status()
        # TEI returns: [[0.1, ...], [0.2, ...]]
        embeddings = response.json()
        return np.array(embeddings)

async def embedding_func(texts: list[str]) -> np.ndarray:
    """Embedding dispatch function: route to the backend selected by EMBEDDING_BINDING."""
    if settings.EMBEDDING_BINDING == "ollama":
        return await ollama_embedding_func(texts)
    elif settings.EMBEDDING_BINDING == "tei":
        return await tei_embedding_func(texts)
    else:
        # Unsupported binding: we could fall back to ollama, but raise instead
        raise ValueError(f"Unsupported EMBEDDING_BINDING: {settings.EMBEDDING_BINDING}")

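# Illustrative check (example only; assumes the configured embedding backend is
# reachable and EMBEDDING_DIM matches the served model): the dispatcher returns
# one row per input text, which is what LightRAG's EmbeddingFunc expects.
#
#     import asyncio
#     vecs = asyncio.run(embedding_func(["hello", "world"]))
#     assert vecs.shape == (2, settings.EMBEDDING_DIM)
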
# ==============================================================================
# Rerank Functions
# ==============================================================================

async def tei_rerank_func(query: str, documents: list[str], top_n: int = 10) -> list[dict]:
    """TEI rerank implementation."""
    if not documents:
        return []

    url = f"{settings.RERANK_BINDING_HOST}/rerank"
    headers = {"Content-Type": "application/json"}
    if settings.RERANK_KEY and settings.RERANK_KEY != "EMPTY":
        headers["Authorization"] = f"Bearer {settings.RERANK_KEY}"

    # TEI does not support a top_n parameter; we either truncate manually or ignore it
    payload = {
        "query": query,
        "texts": documents,
        "return_text": False
    }

    async with httpx.AsyncClient(timeout=30.0) as client:
        response = await client.post(url, headers=headers, json=payload)
        response.raise_for_status()
        # TEI returns: [{"index": 0, "score": 0.99}, {"index": 1, "score": 0.5}]
        results = response.json()

    # LightRAG expects a list of dicts containing "index" and "relevance_score",
    # so it can map results back to the original documents and sort them.
    formatted_results = []
    for res in results:
        formatted_results.append({
            "index": res['index'],
            "relevance_score": res['score']
        })

    return formatted_results

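# Note (example only): top_n is accepted above for interface compatibility but
# is not forwarded to TEI. If truncation is wanted, the formatted results can
# be sorted by score and sliced client-side, e.g.:
#
#     ranked = sorted(formatted_results, key=lambda r: r["relevance_score"], reverse=True)
#     formatted_results = ranked[:top_n]
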
# ==============================================================================
# RAG Manager
# ==============================================================================

class RAGManager:
    def __init__(self, capacity: int = 3):
        self.capacity = capacity
        self.cache = OrderedDict()
        self.lock = threading.Lock()

    async def get_rag(self, user_id: str) -> LightRAG:
        """Get the LightRAG instance for the given user (LRU cache)."""
        with self.lock:
            if user_id in self.cache:
                self.cache.move_to_end(user_id)
                logging.debug(f"Cache hit for user: {user_id}")
                return self.cache[user_id]

        logging.info(f"Initializing RAG instance for user: {user_id}")
        user_data_dir = os.path.join(settings.DATA_DIR, user_id)

        if not os.path.exists(user_data_dir):
            os.makedirs(user_data_dir)

        # Prepare constructor arguments
        rag_params = {
            "working_dir": user_data_dir,
            "llm_model_func": llm_func,
            "llm_model_name": settings.LLM_MODEL,
            "llm_model_max_async": settings.LLM_MODEL_MAX_ASYNC,  # vLLM handles concurrency well; can be raised
            "max_parallel_insert": 1,
            "embedding_func": EmbeddingFunc(
                embedding_dim=settings.EMBEDDING_DIM,
                max_token_size=settings.MAX_TOKEN_SIZE,
                func=embedding_func
            ),
            "embedding_func_max_async": 8,  # TEI handles concurrency well
            "enable_llm_cache": True,
            "cosine_threshold": settings.COSINE_THRESHOLD
        }

        # If rerank is enabled, inject rerank_model_func
        if settings.RERANK_ENABLED:
            logging.info("Rerank enabled for RAG instance")
            rag_params["rerank_model_func"] = tei_rerank_func

        rag = LightRAG(**rag_params)

        await rag.initialize_storages()

        with self.lock:
            # Another request may have built an instance for this user while we
            # were initializing; prefer the cached one.
            if user_id in self.cache:
                self.cache.move_to_end(user_id)
                return self.cache[user_id]

            self.cache[user_id] = rag
            self.cache.move_to_end(user_id)

            # Evict least-recently-used instances beyond capacity
            while len(self.cache) > self.capacity:
                oldest_user, _ = self.cache.popitem(last=False)
                logging.info(f"Evicting RAG instance for user: {oldest_user}")

        return rag

def initialize_rag_manager():
    global rag_manager
    rag_manager = RAGManager(capacity=settings.MAX_RAG_INSTANCES)
    logging.info(f"RAG Manager initialized with capacity: {settings.MAX_RAG_INSTANCES}")
    return rag_manager

def get_rag_manager() -> RAGManager:
    if rag_manager is None:
        return initialize_rag_manager()
    return rag_manager
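
# ==============================================================================
# Usage sketch (illustrative only, not wired into the application): obtain a
# per-user RAG instance from the manager and query it. Assumes LightRAG's
# QueryParam/aquery API; the user id, question, and query mode are made up.
# ==============================================================================
if __name__ == "__main__":
    import asyncio
    from lightrag import QueryParam

    async def _demo():
        manager = get_rag_manager()
        rag = await manager.get_rag("demo_user")
        answer = await rag.aquery(
            "What do my documents say about X?",
            param=QueryParam(mode="hybrid"),
        )
        print(answer)

    asyncio.run(_demo())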