# ai-lightrag/app/core/rag.py

import logging
import os

import numpy as np
import ollama
from lightrag import LightRAG
from lightrag.llm.ollama import ollama_embed
from lightrag.utils import EmbeddingFunc

from app.config import settings

# Global RAG instance
rag = None


async def llm_func(prompt, system_prompt=None, history_messages=None, **kwargs) -> str:
    """LLM function: generate a reply via Ollama."""
    # Drop parameters that would conflict with the explicit arguments below
    kwargs.pop("model", None)
    kwargs.pop("hashing_kv", None)
    keyword_extraction = kwargs.pop("keyword_extraction", False)
    if keyword_extraction:
        kwargs["format"] = "json"
    stream = kwargs.pop("stream", False)
    # Debug: record whether the caller requested streaming
    logging.info("LLM called with stream=%s", stream)

    messages = []
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})
    # Default is None rather than [] to avoid a shared mutable default argument
    messages.extend(history_messages or [])
    messages.append({"role": "user", "content": prompt})

    client = ollama.AsyncClient(host=settings.LLM_BINDING_HOST)
    if stream:
        async def inner():
            # Pass **kwargs through so top-level options such as "format" take effect
            response = await client.chat(
                model=settings.LLM_MODEL, messages=messages, stream=True, **kwargs
            )
            async for chunk in response:
                msg = chunk.get("message", {})
                if msg.get("thinking"):
                    yield {"type": "thinking", "content": msg["thinking"]}
                if msg.get("content"):
                    yield {"type": "content", "content": msg["content"]}

        return inner()
    else:
        # Disable "thinking" on non-streaming requests to reduce latency
        response = await client.chat(
            model=settings.LLM_MODEL, messages=messages, stream=False, think=False, **kwargs
        )
        return response["message"]["content"]
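
# A minimal usage sketch (not part of the original module): how a caller might
# consume the async generator returned by llm_func when stream=True. The prompt
# and the handling of each event are hypothetical.
#
#   gen = await llm_func("Summarize this document", stream=True)
#   async for event in gen:
#       if event["type"] == "thinking":
#           ...  # e.g. forward reasoning tokens to a debug channel
#       elif event["type"] == "content":
#           ...  # e.g. append to the user-visible reply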

async def embedding_func(texts: list[str]) -> np.ndarray:
    """Embedding function: compute vectors via Ollama."""
    return await ollama_embed(
        texts,
        embed_model=settings.EMBEDDING_MODEL,
        host=settings.EMBEDDING_BINDING_HOST,
    )
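
# A hedged example (assumption, not in the source): calling embedding_func
# directly; the expected shape follows from EMBEDDING_DIM in settings.
#
#   vecs = await embedding_func(["hello", "world"])
#   assert vecs.shape == (2, settings.EMBEDDING_DIM)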

async def initialize_rag():
    """Initialize the LightRAG instance."""
    global rag
    # Ensure the working directory exists
    os.makedirs(settings.DATA_DIR, exist_ok=True)

    print("Initializing LightRAG...")
    print(f"LLM: {settings.LLM_MODEL} @ {settings.LLM_BINDING_HOST}")
    print(f"Embedding: {settings.EMBEDDING_MODEL} @ {settings.EMBEDDING_BINDING_HOST}")

    rag = LightRAG(
        working_dir=settings.DATA_DIR,
        llm_model_func=llm_func,
        llm_model_name=settings.LLM_MODEL,
        llm_model_max_async=1,
        max_parallel_insert=1,
        embedding_func=EmbeddingFunc(
            embedding_dim=settings.EMBEDDING_DIM,
            max_token_size=settings.MAX_TOKEN_SIZE,
            func=embedding_func,
        ),
        embedding_func_max_async=4,
        enable_llm_cache=True,
    )

    print("Initializing LightRAG storages...")
    await rag.initialize_storages()
    print("LightRAG storages initialized")
    return rag

def get_rag():
    """Return the global RAG instance."""
    if rag is None:
        raise RuntimeError("RAG instance not initialized. Call initialize_rag() first.")
    return rag
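
# A hedged startup sketch (assumption; FastAPI is not imported by this module):
# wiring initialize_rag() into an application lifespan so that get_rag() is safe
# to call from request handlers.
#
#   from contextlib import asynccontextmanager
#   from fastapi import FastAPI
#   from app.core.rag import initialize_rag
#
#   @asynccontextmanager
#   async def lifespan(app: FastAPI):
#       await initialize_rag()
#       yield
#
#   app = FastAPI(lifespan=lifespan)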