import logging
import os

import ollama
import numpy as np

from lightrag import LightRAG
from lightrag.utils import EmbeddingFunc
from lightrag.llm.ollama import ollama_embed

from app.config import settings

# Global RAG instance
rag = None


async def llm_func(prompt, system_prompt=None, history_messages=[], **kwargs) -> str:
    """LLM function: generate replies with Ollama."""
    # Remove arguments that may already be present (model, hashing_kv) to avoid conflicts
    kwargs.pop("model", None)
    kwargs.pop("hashing_kv", None)
    keyword_extraction = kwargs.pop("keyword_extraction", False)
    if keyword_extraction:
        kwargs["format"] = "json"

    stream = kwargs.pop("stream", False)
    # Debug: check the streaming flag
    logging.info("LLM called with stream=%s", stream)

    messages = []
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})
    messages.extend(history_messages)
    messages.append({"role": "user", "content": prompt})

    client = ollama.AsyncClient(host=settings.LLM_BINDING_HOST)
    if stream:
        async def inner():
            # Pass **kwargs through so top-level options such as format take effect
            response = await client.chat(model=settings.LLM_MODEL, messages=messages, stream=True, **kwargs)

            async for chunk in response:
                msg = chunk.get("message", {})

                if "thinking" in msg and msg["thinking"]:
                    yield {"type": "thinking", "content": msg["thinking"]}
                if "content" in msg and msg["content"]:
                    yield {"type": "content", "content": msg["content"]}
        return inner()
    else:
        # Non-streaming requests skip thinking to reduce latency
        response = await client.chat(model=settings.LLM_MODEL, messages=messages, stream=False, think=False, **kwargs)
        return response["message"]["content"]


async def embedding_func(texts: list[str]) -> np.ndarray:
    """Embedding function: compute vectors with Ollama."""
    return await ollama_embed(
        texts,
        embed_model=settings.EMBEDDING_MODEL,
        host=settings.EMBEDDING_BINDING_HOST,
    )


async def initialize_rag():
    """Initialize the LightRAG instance."""
    global rag

    # Make sure the working directory exists
    if not os.path.exists(settings.DATA_DIR):
        os.makedirs(settings.DATA_DIR)

    print("Initializing LightRAG...")
    print(f"LLM: {settings.LLM_MODEL} @ {settings.LLM_BINDING_HOST}")
    print(f"Embedding: {settings.EMBEDDING_MODEL} @ {settings.EMBEDDING_BINDING_HOST}")

    rag = LightRAG(
        working_dir=settings.DATA_DIR,
        llm_model_func=llm_func,
        llm_model_name=settings.LLM_MODEL,
        llm_model_max_async=1,
        max_parallel_insert=1,
        embedding_func=EmbeddingFunc(
            embedding_dim=settings.EMBEDDING_DIM,
            max_token_size=settings.MAX_TOKEN_SIZE,
            func=embedding_func,
        ),
        embedding_func_max_async=4,
        enable_llm_cache=True,
    )

    print("Initializing LightRAG storages...")
    await rag.initialize_storages()
    print("LightRAG storages initialized")
    return rag


def get_rag():
    """Return the global RAG instance."""
    if rag is None:
        raise RuntimeError("RAG instance not initialized. Call initialize_rag() first.")
    return rag
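
For context, here is a minimal sketch of how this module could be wired into the rest of the service. It is an assumption-laden illustration, not part of the file above: it presumes a FastAPI app (suggested by the app.config settings import but not shown here), that this file lives at app/rag.py, and a hypothetical /ask endpoint; the query itself goes through LightRAG's aquery with a QueryParam.

# app/main.py (hypothetical): wiring the RAG module into a FastAPI app
from contextlib import asynccontextmanager

from fastapi import FastAPI
from lightrag import QueryParam

from app.rag import initialize_rag, get_rag  # assumes this file is app/rag.py


@asynccontextmanager
async def lifespan(app: FastAPI):
    # Build the global LightRAG instance once at startup
    await initialize_rag()
    yield


app = FastAPI(lifespan=lifespan)


@app.get("/ask")
async def ask(q: str):
    # Hypothetical endpoint: run a hybrid (vector + graph) query against LightRAG
    rag = get_rag()
    answer = await rag.aquery(q, param=QueryParam(mode="hybrid"))
    return {"answer": answer}

Because llm_func returns an async generator when stream=True, a streaming endpoint would instead iterate the chunks and forward them (for example via a server-sent events response); the non-streaming path shown here simply awaits the final string.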