ai-lightrag/app/core/rag.py

139 lines
5.0 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import logging
import os
import threading
from collections import OrderedDict
from typing import Optional
import numpy as np
import ollama
from lightrag import LightRAG
from lightrag.utils import EmbeddingFunc
from lightrag.llm.ollama import ollama_embed
from app.config import settings
# Global RAG manager: starts as None and is populated by
# initialize_rag_manager(), which get_rag_manager() calls on first use.
rag_manager: Optional["RAGManager"] = None
async def llm_func(prompt, system_prompt=None, history_messages=None, **kwargs) -> str:
    """Generate a reply from the configured Ollama chat model.

    Args:
        prompt: User message; always appended as the last chat message.
        system_prompt: Optional system message; placed first when truthy.
        history_messages: Optional earlier chat messages, inserted between
            the system prompt and the user prompt. ``None`` means no history.
        **kwargs: Extra options. ``model`` and ``hashing_kv`` are discarded;
            a truthy ``keyword_extraction`` forces ``format="json"``;
            ``stream`` selects streaming mode. Everything else is forwarded
            to ``client.chat`` unchanged.

    Returns:
        In non-streaming mode, the reply text (``response["message"]["content"]``).
        In streaming mode, an async generator yielding dicts of the form
        ``{"type": "thinking" | "content", "content": <text>}``. Note that
        the ``str`` annotation only describes the non-streaming case.
    """
    # ``model`` would clash with the explicit ``model=`` passed below;
    # ``hashing_kv`` is not forwarded to the chat call.
    kwargs.pop("model", None)
    kwargs.pop("hashing_kv", None)

    if kwargs.pop("keyword_extraction", False):
        kwargs["format"] = "json"

    stream = kwargs.pop("stream", False)
    # Debug aid: record which mode was requested.
    logging.info("LLM called with stream=%s", bool(stream))

    messages = []
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})
    if history_messages:
        messages.extend(history_messages)
    messages.append({"role": "user", "content": prompt})

    client = ollama.AsyncClient(host=settings.LLM_BINDING_HOST)

    if stream:
        async def inner():
            # Forward **kwargs so top-level options such as ``format`` apply.
            response = await client.chat(
                model=settings.LLM_MODEL,
                messages=messages,
                stream=True,
                **kwargs,
            )
            async for chunk in response:
                msg = chunk.get("message", {})
                if "thinking" in msg and msg["thinking"]:
                    yield {"type": "thinking", "content": msg["thinking"]}
                if "content" in msg and msg["content"]:
                    yield {"type": "content", "content": msg["content"]}

        return inner()

    # Non-streaming requests always disable thinking to reduce latency. Drop
    # any caller-supplied ``think`` so it cannot collide with the explicit
    # ``think=False`` below (a duplicate keyword would raise TypeError).
    kwargs.pop("think", None)
    response = await client.chat(
        model=settings.LLM_MODEL,
        messages=messages,
        stream=False,
        think=False,
        **kwargs,
    )
    return response["message"]["content"]
async def embedding_func(texts: list[str]) -> np.ndarray:
    """Embed *texts* with the configured Ollama embedding model.

    Delegates to ``ollama_embed`` using the model name and host taken from
    ``settings`` and returns its result unchanged.
    """
    embed_options = {
        "embed_model": settings.EMBEDDING_MODEL,
        "host": settings.EMBEDDING_BINDING_HOST,
    }
    vectors = await ollama_embed(texts, **embed_options)
    return vectors
class RAGManager:
    """Bounded LRU cache of per-user LightRAG instances.

    Instances are keyed by ``user_id``. When more than ``capacity`` instances
    are cached, the least recently used ones are dropped from the cache.
    """

    def __init__(self, capacity: int = 3):
        """Create an empty cache holding at most ``capacity`` instances."""
        self.capacity = capacity
        # Insertion order doubles as recency order: most recent at the end.
        self.cache = OrderedDict()
        # Guards ``self.cache``; it is never held across an ``await``.
        self.lock = threading.Lock()

    @staticmethod
    def _user_data_dir(base_dir: str, user_id: str) -> str:
        """Return ``base_dir/user_id`` after checking it stays inside ``base_dir``.

        Raises:
            ValueError: if the joined path is ``base_dir`` itself or lies
                outside it (e.g. ``user_id`` is empty, ``"."``, contains
                ``".."`` segments that climb out, or is an absolute path).
        """
        candidate = os.path.join(base_dir, user_id)
        base_abs = os.path.abspath(base_dir)
        candidate_abs = os.path.abspath(candidate)
        try:
            inside = os.path.commonpath([base_abs, candidate_abs]) == base_abs
        except ValueError:
            # commonpath raises when the paths cannot be compared
            # (e.g. different drives); treat that as "outside".
            inside = False
        if not inside or candidate_abs == base_abs:
            raise ValueError(f"Invalid user_id for data directory: {user_id!r}")
        return candidate

    async def get_rag(self, user_id: str) -> "LightRAG":
        """Return the LightRAG instance for ``user_id``, creating it on a miss.

        Args:
            user_id: Cache key; also the name of the user's sub-directory
                under ``settings.DATA_DIR``.

        Returns:
            The cached instance, or a newly created and initialized one.

        Raises:
            ValueError: if ``user_id`` would map to a directory outside of
                (or equal to) ``settings.DATA_DIR``.
        """
        # 1. Fast path: serve from the cache and mark as most recently used.
        with self.lock:
            if user_id in self.cache:
                self.cache.move_to_end(user_id)
                logging.debug(f"Cache hit for user: {user_id}")
                return self.cache[user_id]

        # 2. Cache miss: build a new instance.
        # Strict isolation: the working directory must be the user's own
        # sub-directory, never the parent or another user's directory.
        user_data_dir = self._user_data_dir(settings.DATA_DIR, user_id)
        logging.info(f"Initializing RAG instance for user: {user_id}")
        # exist_ok avoids failing when the directory appears between a
        # separate existence check and the creation call.
        os.makedirs(user_data_dir, exist_ok=True)

        rag = LightRAG(
            working_dir=user_data_dir,
            llm_model_func=llm_func,
            llm_model_name=settings.LLM_MODEL,
            llm_model_max_async=1,
            max_parallel_insert=1,
            embedding_func=EmbeddingFunc(
                embedding_dim=settings.EMBEDDING_DIM,
                max_token_size=settings.MAX_TOKEN_SIZE,
                func=embedding_func,
            ),
            embedding_func_max_async=4,
            enable_llm_cache=True,
        )
        # Storage initialization is asynchronous and runs outside the lock.
        await rag.initialize_storages()

        # 3. Publish to the cache and evict if over capacity.
        with self.lock:
            # Re-check: another caller may have cached an instance for this
            # user while this one was awaiting initialization. Prefer the
            # cached one; the instance built here is simply dropped.
            # NOTE(review): the dropped duplicate is not explicitly cleaned
            # up -- confirm whether LightRAG needs an explicit finalize.
            if user_id in self.cache:
                self.cache.move_to_end(user_id)
                return self.cache[user_id]

            self.cache[user_id] = rag
            self.cache.move_to_end(user_id)

            while len(self.cache) > self.capacity:
                oldest_user, _ = self.cache.popitem(last=False)
                logging.info(f"Evicting RAG instance for user: {oldest_user} due to capacity limit")
                # Evicted instances are only dereferenced here; resource
                # reclamation is left to Python's garbage collector.
        return rag
def initialize_rag_manager():
    """Create the global RAG manager, store it in ``rag_manager``, and return it.

    The capacity is read from ``settings.MAX_RAG_INSTANCES``. Any previously
    stored manager is replaced.
    """
    global rag_manager
    manager = RAGManager(capacity=settings.MAX_RAG_INSTANCES)
    rag_manager = manager
    logging.info(f"RAG Manager initialized with capacity: {settings.MAX_RAG_INSTANCES}")
    return manager
def get_rag_manager() -> RAGManager:
    """Return the global RAG manager, initializing it on first use."""
    manager = rag_manager
    if manager is not None:
        return manager
    return initialize_rag_manager()