import logging
import os

import numpy as np
import ollama
from lightrag import LightRAG
from lightrag.llm.ollama import ollama_embed
from lightrag.utils import EmbeddingFunc

from app.config import settings

# Global RAG instance
rag = None


async def llm_func(prompt, system_prompt=None, history_messages=None, **kwargs):
    """LLM function: generate a reply via Ollama.

    Returns a str for non-streaming calls, or an async generator of
    {"type": "thinking" | "content", "content": str} events when stream=True.
    """
    # Drop parameters that would conflict with the explicit arguments below
    kwargs.pop("model", None)
    kwargs.pop("hashing_kv", None)
    keyword_extraction = kwargs.pop("keyword_extraction", False)
    if keyword_extraction:
        kwargs["format"] = "json"
    stream = kwargs.pop("stream", False)
    # Debug: log whether streaming was requested
    logging.info("LLM called with stream=%s", stream)

    messages = []
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})
    messages.extend(history_messages or [])
    messages.append({"role": "user", "content": prompt})

    client = ollama.AsyncClient(host=settings.LLM_BINDING_HOST)
    if stream:
        async def inner():
            # Pass **kwargs through so top-level options such as format take effect
            response = await client.chat(
                model=settings.LLM_MODEL, messages=messages, stream=True, **kwargs
            )
            async for chunk in response:
                msg = chunk.get("message", {})
                if msg.get("thinking"):
                    yield {"type": "thinking", "content": msg["thinking"]}
                if msg.get("content"):
                    yield {"type": "content", "content": msg["content"]}

        return inner()
    else:
        # Non-streaming request: disable thinking to reduce latency
        response = await client.chat(
            model=settings.LLM_MODEL, messages=messages, stream=False, think=False, **kwargs
        )
        return response["message"]["content"]


async def embedding_func(texts: list[str]) -> np.ndarray:
    """Embedding function: compute vectors via Ollama."""
    return await ollama_embed(
        texts,
        embed_model=settings.EMBEDDING_MODEL,
        host=settings.EMBEDDING_BINDING_HOST,
    )


async def initialize_rag():
    """Initialize the global LightRAG instance."""
    global rag
    # Make sure the working directory exists
    os.makedirs(settings.DATA_DIR, exist_ok=True)

    print("Initializing LightRAG...")
    print(f"LLM: {settings.LLM_MODEL} @ {settings.LLM_BINDING_HOST}")
    print(f"Embedding: {settings.EMBEDDING_MODEL} @ {settings.EMBEDDING_BINDING_HOST}")

    rag = LightRAG(
        working_dir=settings.DATA_DIR,
        llm_model_func=llm_func,
        llm_model_name=settings.LLM_MODEL,
        llm_model_max_async=1,
        max_parallel_insert=1,
        embedding_func=EmbeddingFunc(
            embedding_dim=settings.EMBEDDING_DIM,
            max_token_size=settings.MAX_TOKEN_SIZE,
            func=embedding_func,
        ),
        embedding_func_max_async=4,
        enable_llm_cache=True,
    )

    print("Initializing LightRAG storages...")
    await rag.initialize_storages()
    print("LightRAG storages initialized")
    return rag


def get_rag():
    """Return the global RAG instance."""
    if rag is None:
        raise RuntimeError("RAG instance not initialized. Call initialize_rag() first.")
    return rag
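

# Usage sketch (illustrative, not part of the original module): consuming
# llm_func in streaming mode. Per the definition above, awaiting the call
# with stream=True returns an async generator of
# {"type": "thinking" | "content", "content": str} events; the helper name
# _print_stream is hypothetical.
async def _print_stream(prompt: str) -> None:
    gen = await llm_func(prompt, stream=True)
    async for event in gen:
        # Skip "thinking" events here; print answer tokens as they arrive
        if event["type"] == "content":
            print(event["content"], end="", flush=True)
    print()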
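

# Usage sketch (illustrative; assumes LightRAG's aquery/QueryParam API and an
# already-populated working directory): run the module directly to smoke-test
# initialization and a single hybrid-mode query. The question text is a
# placeholder.
if __name__ == "__main__":
    import asyncio

    from lightrag import QueryParam

    async def _demo():
        rag_instance = await initialize_rag()
        answer = await rag_instance.aquery(
            "What does this knowledge base cover?",
            param=QueryParam(mode="hybrid"),
        )
        print(answer)

    asyncio.run(_demo())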