ai-lightrag/app/core/rag.py

282 lines
11 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import asyncio
import json
import logging
import os
import threading
from collections import OrderedDict
from typing import AsyncIterator, List, Optional, Union

import httpx
import numpy as np
import ollama
from lightrag import LightRAG
from lightrag.llm.ollama import ollama_embed
from lightrag.utils import EmbeddingFunc

from app.config import settings
# Module-wide RAGManager singleton. It starts unset; get_rag_manager() creates it
# lazily, or initialize_rag_manager() (re)creates it explicitly.
rag_manager: Optional["RAGManager"] = None
# ==============================================================================
# LLM Functions
# ==============================================================================
async def ollama_llm_func(
    prompt,
    system_prompt=None,
    history_messages=None,
    **kwargs,
) -> Union[str, AsyncIterator[dict]]:
    """Run a chat completion against Ollama using ``settings.LLM_MODEL``.

    Args:
        prompt: User message for this turn.
        system_prompt: Optional system message placed before the history.
        history_messages: Earlier chat messages (role/content dicts) inserted
            between the system prompt and *prompt*; ``None`` means no history.
        **kwargs: Extra options. ``model``, ``hashing_kv`` and ``enable_cot``
            are discarded; ``keyword_extraction=True`` requests JSON output
            (``format="json"``); ``stream`` selects streaming; ``think``
            (default ``False``) is forwarded as Ollama's ``think`` flag in both
            modes. Everything else is passed to ``AsyncClient.chat``.

    Returns:
        The assistant message content when not streaming. When streaming, an
        async generator yielding ``{"type": "thinking" | "content",
        "content": str}`` dicts as chunks arrive.
    """
    # Drop LightRAG-side options that are not forwarded to the Ollama client.
    kwargs.pop("model", None)
    kwargs.pop("hashing_kv", None)
    kwargs.pop("enable_cot", None)
    if kwargs.pop("keyword_extraction", False):
        kwargs["format"] = "json"
    stream = kwargs.pop("stream", False)
    think = kwargs.pop("think", False)

    messages = []
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})
    # `or []` keeps an explicit None (or the new default) safe for extend().
    messages.extend(history_messages or [])
    messages.append({"role": "user", "content": prompt})

    # NOTE(review): a new client is created per call and never explicitly
    # closed; consider reusing one client if connection churn becomes visible.
    client = ollama.AsyncClient(host=settings.LLM_BINDING_HOST)

    if stream:
        async def inner():
            response = await client.chat(
                model=settings.LLM_MODEL,
                messages=messages,
                stream=True,
                think=think,
                **kwargs,
            )
            async for chunk in response:
                msg = chunk.get("message", {})
                if "thinking" in msg and msg["thinking"]:
                    yield {"type": "thinking", "content": msg["thinking"]}
                if "content" in msg and msg["content"]:
                    yield {"type": "content", "content": msg["content"]}

        return inner()

    # Honour the caller's `think` here too. It previously was popped and then
    # ignored in favour of a hard-coded False; the default is still False.
    response = await client.chat(
        model=settings.LLM_MODEL,
        messages=messages,
        stream=False,
        think=think,
        **kwargs,
    )
    return response["message"]["content"]
async def openai_llm_func(
    prompt,
    system_prompt=None,
    history_messages=None,
    **kwargs,
) -> Union[str, AsyncIterator[dict]]:
    """Run a chat completion against an OpenAI-compatible server (e.g. vLLM).

    Posts to ``{settings.LLM_BINDING_HOST}/chat/completions`` with
    ``settings.LLM_MODEL`` and a bearer token from ``settings.LLM_KEY``.

    Args:
        prompt: User message for this turn.
        system_prompt: Optional system message placed before the history.
        history_messages: Earlier chat messages (role/content dicts) inserted
            between the system prompt and *prompt*; ``None`` means no history.
        **kwargs: Extra options.
            - ``model``, ``hashing_kv`` and ``enable_cot`` are discarded.
            - ``keyword_extraction=True`` requests
              ``response_format={"type": "json_object"}``.
            - ``stream`` selects streaming.
            - ``think``, when given, is sent as
              ``chat_template_kwargs.enable_thinking`` (Qwen3-style chat
              templates) and merged into any caller-supplied
              ``chat_template_kwargs``.
            - Everything else is merged into the request payload.

    Returns:
        The assistant message content when not streaming. When streaming, an
        async generator yielding ``{"type": "thinking" | "content",
        "content": str}`` dicts.

    Raises:
        httpx.HTTPStatusError: If the server answers with an error status.
    """
    # Drop LightRAG-side options that are not part of the request payload.
    kwargs.pop("model", None)
    kwargs.pop("hashing_kv", None)
    kwargs.pop("enable_cot", None)
    if kwargs.pop("keyword_extraction", False):
        # OpenAI-style APIs have no `format="json"`; JSON mode is response_format.
        kwargs["response_format"] = {"type": "json_object"}
    stream = kwargs.pop("stream", False)
    think = kwargs.pop("think", None)
    if think is not None:
        # Only send the template switch when asked. Strict OpenAI endpoints
        # reject unknown fields, and for Qwen3 templates an absent
        # `enable_thinking` behaves like the previously sent `None`.
        kwargs["chat_template_kwargs"] = {
            **(kwargs.get("chat_template_kwargs") or {}),
            "enable_thinking": think,
        }

    messages = []
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})
    messages.extend(history_messages or [])
    messages.append({"role": "user", "content": prompt})

    url = f"{settings.LLM_BINDING_HOST}/chat/completions"
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {settings.LLM_KEY}",
    }
    payload = {
        "model": settings.LLM_MODEL,
        "messages": messages,
        "stream": stream,
        **kwargs,
    }

    if stream:
        async def inner():
            async with httpx.AsyncClient(timeout=120.0) as client:
                async with client.stream("POST", url, headers=headers, json=payload) as response:
                    response.raise_for_status()
                    async for line in response.aiter_lines():
                        # Server-sent events: payload lines start with "data:".
                        if not line.startswith("data:"):
                            continue
                        data_str = line[len("data:"):].strip()
                        if data_str == "[DONE]":
                            break
                        # Keep the try narrow and the yields outside it. A bare
                        # `except` around `yield` also swallows GeneratorExit /
                        # CancelledError thrown in when the consumer stops early.
                        try:
                            delta = json.loads(data_str)["choices"][0]["delta"]
                        except (ValueError, KeyError, IndexError, TypeError):
                            # Malformed line, or a chunk without choices
                            # (e.g. a usage-only chunk): nothing to emit.
                            continue
                        if not isinstance(delta, dict):
                            continue
                        # Reasoning text, when the server emits it separately.
                        # Field names assumed from vLLM reasoning parsers -
                        # confirm against the server version in use.
                        thinking = delta.get("reasoning_content") or delta.get("reasoning")
                        if thinking:
                            yield {"type": "thinking", "content": thinking}
                        # `content` may be null in role/tool chunks; skip those.
                        content = delta.get("content")
                        if content:
                            yield {"type": "content", "content": content}

        return inner()

    async with httpx.AsyncClient(timeout=120.0) as client:
        response = await client.post(url, headers=headers, json=payload)
        response.raise_for_status()
        result = response.json()
    return result["choices"][0]["message"]["content"]
async def llm_func(
    prompt,
    system_prompt=None,
    history_messages=None,
    **kwargs,
) -> Union[str, AsyncIterator[dict]]:
    """Dispatch an LLM call to the backend selected by ``settings.LLM_BINDING``.

    ``"ollama"`` goes to :func:`ollama_llm_func`; ``"vllm"``, ``"openai"`` and
    ``"custom"`` go to :func:`openai_llm_func`. Arguments are forwarded
    unchanged, except that a missing/``None`` *history_messages* becomes an
    empty list.

    Returns:
        Whatever the selected backend returns (a string, or an async generator
        of chunk dicts when ``stream=True``).

    Raises:
        ValueError: If ``settings.LLM_BINDING`` names an unsupported backend.
    """
    # Avoid a shared mutable default; backends always receive a real list.
    history = [] if history_messages is None else history_messages
    binding = settings.LLM_BINDING
    if binding == "ollama":
        return await ollama_llm_func(prompt, system_prompt, history, **kwargs)
    if binding in ("vllm", "openai", "custom"):
        return await openai_llm_func(prompt, system_prompt, history, **kwargs)
    raise ValueError(f"Unsupported LLM_BINDING: {binding}")
# ==============================================================================
# Embedding Functions
# ==============================================================================
async def ollama_embedding_func(texts: list[str]) -> np.ndarray:
    """Embed *texts* via LightRAG's ``ollama_embed`` helper.

    Uses ``settings.EMBEDDING_MODEL`` on the Ollama host at
    ``settings.EMBEDDING_BINDING_HOST`` and returns the helper's result as-is.
    """
    model_name = settings.EMBEDDING_MODEL
    ollama_host = settings.EMBEDDING_BINDING_HOST
    vectors = await ollama_embed(texts, embed_model=model_name, host=ollama_host)
    return vectors
async def tei_embedding_func(texts: list[str]) -> np.ndarray:
    """Embed *texts* with a Text Embeddings Inference (TEI) server.

    Sends a single POST to ``{settings.EMBEDDING_BINDING_HOST}/embed`` with
    ``truncate=True``, so over-long inputs are cut server-side instead of
    failing. A bearer token is attached unless ``settings.EMBEDDING_KEY`` is
    empty or the placeholder ``"EMPTY"``.

    Args:
        texts: Strings to embed.

    Returns:
        ``np.array`` built from TEI's list-of-vectors response (one row per
        input text). For an empty *texts*, an empty array of shape
        ``(0, settings.EMBEDDING_DIM)``.

    Raises:
        httpx.HTTPStatusError: If TEI answers with an error status.
    """
    if not texts:
        # Nothing to embed. Skip the round trip (TEI validates that `inputs`
        # is non-empty) and return an empty matrix with the configured width.
        return np.empty((0, settings.EMBEDDING_DIM))

    url = f"{settings.EMBEDDING_BINDING_HOST}/embed"
    headers = {"Content-Type": "application/json"}
    if settings.EMBEDDING_KEY and settings.EMBEDDING_KEY != "EMPTY":
        headers["Authorization"] = f"Bearer {settings.EMBEDDING_KEY}"
    payload = {
        "inputs": texts,
        "truncate": True,  # TEI option: truncate instead of rejecting long inputs
    }
    async with httpx.AsyncClient(timeout=60.0) as client:
        response = await client.post(url, headers=headers, json=payload)
        response.raise_for_status()
        # TEI responds with [[0.1, ...], [0.2, ...]].
        embeddings = response.json()
    return np.array(embeddings)
async def embedding_func(texts: list[str]) -> np.ndarray:
    """Route an embedding request to the backend named by ``settings.EMBEDDING_BINDING``.

    Supported values are ``"ollama"`` and ``"tei"``. Any other value raises
    ``ValueError``; there is no silent fallback to another backend.
    """
    backends = {
        "ollama": ollama_embedding_func,
        "tei": tei_embedding_func,
    }
    backend = backends.get(settings.EMBEDDING_BINDING)
    if backend is None:
        raise ValueError(f"Unsupported EMBEDDING_BINDING: {settings.EMBEDDING_BINDING}")
    return await backend(texts)
# ==============================================================================
# Rerank Functions
# ==============================================================================
async def tei_rerank_func(query: str, documents: list[str], top_n: int = 10) -> list[dict]:
    """Rerank *documents* against *query* using a TEI ``/rerank`` endpoint.

    Args:
        query: The search query.
        documents: Candidate texts; result indices refer to positions here.
        top_n: Maximum number of results to return. TEI takes no such
            parameter, so the cut is applied locally after sorting by score.
            ``None`` or a non-positive value returns every result.

    Returns:
        A list of ``{"index": int, "relevance_score": float}`` dicts, best
        first. This is the shape LightRAG uses to map scores back to its
        chunks. Returns an empty list when *documents* is empty.

    Raises:
        httpx.HTTPStatusError: If TEI answers with an error status.
    """
    if not documents:
        return []

    url = f"{settings.RERANK_BINDING_HOST}/rerank"
    headers = {"Content-Type": "application/json"}
    if settings.RERANK_KEY and settings.RERANK_KEY != "EMPTY":
        headers["Authorization"] = f"Bearer {settings.RERANK_KEY}"
    payload = {
        "query": query,
        "texts": documents,
        "return_text": False,
        # Same policy as the /embed call: truncate over-long chunks instead of
        # failing the whole rerank request.
        "truncate": True,
    }
    async with httpx.AsyncClient(timeout=30.0) as client:
        response = await client.post(url, headers=headers, json=payload)
        response.raise_for_status()
        # TEI responds with [{"index": 0, "score": 0.99}, {"index": 1, "score": 0.5}].
        results = response.json()

    # Sort explicitly (stable, so ties keep TEI's order) so the top_n cut keeps
    # the best-scoring documents regardless of the response order.
    ranked = sorted(
        ({"index": res["index"], "relevance_score": res["score"]} for res in results),
        key=lambda item: item["relevance_score"],
        reverse=True,
    )
    if top_n is not None and top_n > 0:
        ranked = ranked[:top_n]
    return ranked
# ==============================================================================
# RAG Manager
# ==============================================================================
class RAGManager:
    """Per-user LRU cache of initialized LightRAG instances.

    Each user gets a working directory ``<settings.DATA_DIR>/<user_id>``. At
    most ``capacity`` instances are cached; when that is exceeded, the least
    recently used one is dropped from the cache.
    """

    def __init__(self, capacity: int = 3):
        self.capacity = capacity
        # user_id -> LightRAG, ordered from least to most recently used.
        self.cache: "OrderedDict[str, LightRAG]" = OrderedDict()
        # Guards `cache` and `_init_locks`; only held for short, non-awaiting sections.
        self.lock = threading.Lock()
        # user_id -> asyncio.Lock that serializes first-time initialization for
        # that user. Without it, concurrent requests could each build and
        # initialize a LightRAG over the same working directory.
        self._init_locks: dict[str, asyncio.Lock] = {}

    async def get_rag(self, user_id: str) -> LightRAG:
        """Return the LightRAG instance for *user_id*, creating it on first use.

        Raises:
            ValueError: If *user_id* would place the working directory outside
                ``settings.DATA_DIR``.
        """
        rag = self._lookup(user_id)
        if rag is not None:
            logging.debug("Cache hit for user: %s", user_id)
            return rag

        with self.lock:
            init_lock = self._init_locks.setdefault(user_id, asyncio.Lock())
        async with init_lock:
            # Re-check: another request may have finished initializing this
            # user while we were waiting for the lock.
            rag = self._lookup(user_id)
            if rag is not None:
                return rag
            rag = await self._create_rag(user_id)
            return self._insert(user_id, rag)

    def _lookup(self, user_id: str) -> Optional[LightRAG]:
        """Return the cached instance for *user_id* (marking it most recent), or None."""
        with self.lock:
            rag = self.cache.get(user_id)
            if rag is not None:
                self.cache.move_to_end(user_id)
            return rag

    def _insert(self, user_id: str, rag: LightRAG) -> LightRAG:
        """Cache *rag* as most recent, evict beyond capacity, and return the cached instance."""
        with self.lock:
            existing = self.cache.get(user_id)
            if existing is not None:
                # Safety net for a rare eviction race: keep the instance that is
                # already cached and discard the freshly built one.
                self.cache.move_to_end(user_id)
                return existing
            self.cache[user_id] = rag  # new keys are appended, i.e. most recent
            # max(..., 0) prevents popitem() on an empty cache if capacity < 0.
            while len(self.cache) > max(self.capacity, 0):
                oldest_user, _ = self.cache.popitem(last=False)
                self._init_locks.pop(oldest_user, None)
                logging.info("Evicting RAG instance for user: %s", oldest_user)
            return rag

    @staticmethod
    def _user_data_dir(user_id: str) -> str:
        """Return (creating it if needed) the working directory for *user_id*.

        Raises:
            ValueError: If *user_id* (e.g. ``"../x"`` or an absolute path)
                would escape ``settings.DATA_DIR``.
        """
        user_data_dir = os.path.join(settings.DATA_DIR, user_id)
        base = os.path.abspath(settings.DATA_DIR)
        if os.path.commonpath([base, os.path.abspath(user_data_dir)]) != base:
            raise ValueError(f"Invalid user_id for data directory: {user_id!r}")
        # exist_ok avoids the check-then-create race between concurrent requests.
        os.makedirs(user_data_dir, exist_ok=True)
        return user_data_dir

    async def _create_rag(self, user_id: str) -> LightRAG:
        """Build a LightRAG for *user_id* and await its storage initialization."""
        logging.info("Initializing RAG instance for user: %s", user_id)
        rag_params = {
            "working_dir": self._user_data_dir(user_id),
            "llm_model_func": llm_func,
            "llm_model_name": settings.LLM_MODEL,
            "llm_model_max_async": 4,  # original note: vLLM handles concurrency well; can be raised
            "max_parallel_insert": 1,
            "embedding_func": EmbeddingFunc(
                embedding_dim=settings.EMBEDDING_DIM,
                max_token_size=settings.MAX_TOKEN_SIZE,
                func=embedding_func,
            ),
            "embedding_func_max_async": 8,  # original note: TEI handles concurrency well
            "enable_llm_cache": True,
            "cosine_threshold": settings.COSINE_THRESHOLD,
        }
        # Inject the rerank function only when reranking is enabled.
        if settings.RERANK_ENABLED:
            logging.info("Rerank enabled for RAG instance")
            rag_params["rerank_model_func"] = tei_rerank_func
        rag = LightRAG(**rag_params)
        await rag.initialize_storages()
        return rag
def initialize_rag_manager():
    """(Re)create the module-wide RAGManager, sized by settings.MAX_RAG_INSTANCES.

    Replaces any existing manager in the ``rag_manager`` global, logs the
    chosen capacity, and returns the new manager.
    """
    global rag_manager
    capacity = settings.MAX_RAG_INSTANCES
    rag_manager = RAGManager(capacity=capacity)
    logging.info(f"RAG Manager initialized with capacity: {capacity}")
    return rag_manager
def get_rag_manager() -> RAGManager:
    """Return the module-wide RAGManager, creating it on first use."""
    if rag_manager is not None:
        return rag_manager
    return initialize_rag_manager()