diff --git a/app/api/routes.py b/app/api/routes.py
index d221903..cba5952 100644
--- a/app/api/routes.py
+++ b/app/api/routes.py
@@ -34,6 +34,7 @@ class QueryRequest(BaseModel):
     mode: str = "hybrid"  # options: naive, local, global, hybrid
     top_k: int = 5
     stream: bool = False
+    think: bool = False
 
 class IngestResponse(BaseModel):
     filename: str
@@ -111,6 +112,7 @@ async def query_knowledge_base(
         request.query,
         system_prompt=sys_prompt,
         stream=True,
+        think=request.think,
         hashing_kv=rag.llm_response_cache
     )
 
diff --git a/app/core/rag.py b/app/core/rag.py
index 8968f46..9581cd6 100644
--- a/app/core/rag.py
+++ b/app/core/rag.py
@@ -19,11 +19,13 @@ async def llm_func(prompt, system_prompt=None, history_messages=[], **kwargs) ->
     # Remove the model argument if present, to avoid conflicts
     kwargs.pop('model', None)
     kwargs.pop('hashing_kv', None)
+    kwargs.pop('enable_cot', None)  # remove an unsupported argument
 
     keyword_extraction = kwargs.pop("keyword_extraction", False)
     if keyword_extraction:
         kwargs["format"] = "json"
     stream = kwargs.pop("stream", False)
+    think = kwargs.pop("think", False)
     # Debug: check the streaming flag
     if stream:
         logging.info("LLM called with stream=True")
@@ -40,7 +42,7 @@ async def llm_func(prompt, system_prompt=None, history_messages=[], **kwargs) ->
     if stream:
         async def inner():
             # Pass parameters through via **kwargs so top-level options such as format take effect
-            response = await client.chat(model=settings.LLM_MODEL, messages=messages, stream=True, **kwargs)
+            response = await client.chat(model=settings.LLM_MODEL, messages=messages, stream=True, think=think, **kwargs)
             async for chunk in response:
                 msg = chunk.get("message", {})
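
For a quick end-to-end check, a request like the sketch below should reach `llm_func` with `think=True`. This is a minimal sketch under stated assumptions: the `POST /query` path, the port, and the JSON response handling are guesses, since the route decorator and response model are not part of this diff; only the body fields (`query`, `mode`, `top_k`, `stream`, `think`) come from `QueryRequest` and the handler shown above.

```python
# Hypothetical smoke test for the new `think` flag.
# Assumptions (not shown in the diff): the handler is mounted at POST /query
# on localhost:8000 and returns JSON. Body fields mirror QueryRequest.
import asyncio

import httpx


async def ask(question: str) -> None:
    payload = {
        "query": question,
        "mode": "hybrid",
        "top_k": 5,
        "stream": False,
        "think": True,  # new field introduced by this change
    }
    async with httpx.AsyncClient(base_url="http://localhost:8000") as client:
        resp = await client.post("/query", json=payload)
        resp.raise_for_status()
        print(resp.json())


if __name__ == "__main__":
    asyncio.run(ask("Summarize the ingested documents."))
```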