diff --git a/app/api/routes.py b/app/api/routes.py
index 43735b7..24f0ee1 100644
--- a/app/api/routes.py
+++ b/app/api/routes.py
@@ -1,6 +1,7 @@
 import json
 import logging
 import os
+import time
 from fastapi import APIRouter, UploadFile, File, HTTPException, Body, Header, Depends
 from fastapi.responses import StreamingResponse
 from pydantic import BaseModel
@@ -115,16 +116,25 @@ async def query_knowledge_base(
     """
     try:
         # Build the query parameters
-        param = QueryParam(
+        context_param = QueryParam(
             mode=request.mode,
             top_k=request.top_k,
-            stream=request.stream,
+            only_need_context=True,
             enable_rerank=settings.RERANK_ENABLED
         )
+
+        # Run context retrieval
+        context_resp = await rag.aquery(request.query, param=context_param)
+
+        # Determine the retrieval hit status
+        rag_status = "miss"
+        has_context = False
+        if context_resp and "[no-context]" not in context_resp and "None" not in context_resp:
+            rag_status = "hit"
+            has_context = True
 
         # Handle streaming output (SSE protocol, OpenAI-compatible format)
         if request.stream:
-            import time
             async def stream_generator():
                 chat_id = f"chatcmpl-{secrets.token_hex(12)}"
                 created_time = int(time.time())
@@ -155,37 +165,15 @@ async def query_knowledge_base(
                     }
                     return f"data: {json.dumps(chunk, ensure_ascii=False)}\n\n"
 
-                # 1. Send the retrieval status (as part of the reasoning trace)
-                yield openai_chunk(reasoning_content="1. Retrieving context...\n")
-
-                context_param = QueryParam(
-                    mode=request.mode,
-                    top_k=request.top_k,
-                    only_need_context=True,
-                    enable_rerank=settings.RERANK_ENABLED
-                )
-
-                # Fetch the context
-                context_resp = await rag.aquery(request.query, param=context_param)
-                logging.info(f"Context Response: {context_resp}")
-
-                # Determine the retrieval status
-                has_context = False
-                if context_resp and "[no-context]" not in context_resp and "None" not in context_resp:
-                    has_context = True
-
-                # Check whether thinking mode is enabled
-                think = request.think
-
                 if has_context:
                     yield openai_chunk(
-                        reasoning_content=f"2. Context retrieved (length: {len(context_resp)} characters).\n",
-                        extra_delta={"x_rag_status": "hit"}
+                        reasoning_content=f"1. Context retrieved (length: {len(context_resp)} characters).\n",
+                        extra_delta={"x_rag_status": rag_status}
                     )
                 else:
                     yield openai_chunk(
-                        reasoning_content="2. No relevant context found\n",
-                        extra_delta={"x_rag_status": "miss"}
+                        reasoning_content="No relevant context found\n",
+                        extra_delta={"x_rag_status": rag_status}
                     )
 
                 # If RAG-only mode is enabled and no context was found, end immediately
@@ -195,9 +183,11 @@ async def query_knowledge_base(
                         return
 
                     yield openai_chunk(reasoning_content=" (will rely on the LLM's own knowledge)\n")
-                    think = False
 
-                yield openai_chunk(reasoning_content="3. Generating the answer...\n")
+                    # No context was found, so disable thinking mode
+                    request.think = False
+
+                yield openai_chunk(reasoning_content="2. Generating the answer...\n")
 
                 # 2. Generate the answer
                 sys_prompt = CUSTOM_RAG_RESPONSE_PROMPT.format(
@@ -211,7 +201,7 @@ async def query_knowledge_base(
                     request.query,
                     system_prompt=sys_prompt,
                     stream=True,
-                    think=think,
+                    think=request.think,
                     hashing_kv=rag.llm_response_cache
                 )
 
@@ -231,9 +221,56 @@ async def query_knowledge_base(
             # Use the text/event-stream Content-Type
             return StreamingResponse(stream_generator(), media_type="text/event-stream")
 
-        # Handle non-streaming output
-        response = await rag.aquery(request.query, param=param)
-        return {"response": response}
+        # Handle non-streaming output (OpenAI-compatible format)
+        # Generate the answer according to the configured policy
+        final_answer = ""
+
+        if not has_context and request.only_rag:
+            # Strict mode and nothing retrieved: end immediately
+            final_answer = "No relevant knowledge base content found."
+        else:
+            # Normal generation
+            sys_prompt = CUSTOM_RAG_RESPONSE_PROMPT.format(
+                context_data=context_resp,
+                response_type="Multiple Paragraphs",
+                user_prompt=""
+            )
+
+            # Call the LLM to generate the answer
+            stream_resp = await llm_func(
+                request.query,
+                system_prompt=sys_prompt,
+                stream=False,
+                think=request.think,
+                hashing_kv=rag.llm_response_cache
+            )
+
+            # A non-streaming LLM call returns the result directly
+            final_answer = stream_resp
+
+        return {
+            "id": f"chatcmpl-{secrets.token_hex(12)}",
+            "object": "chat.completion",
+            "created": int(time.time()),
+            "model": settings.LLM_MODEL,
+            "choices": [
+                {
+                    "index": 0,
+                    "message": {
+                        "role": "assistant",
+                        "content": final_answer,
+                        # Extension field: non-standard, but potentially useful to some clients; could also go in usage/metadata
+                        "x_rag_status": rag_status
+                    },
+                    "finish_reason": "stop"
+                }
+            ],
+            "usage": {
+                "prompt_tokens": -1,
+                "completion_tokens": len(final_answer),  # character count, not a true token count
+                "total_tokens": -1
+            }
+        }
     except Exception as e:
         logging.error(f"Query failed: {str(e)}")
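
Note for reviewers: the body of the openai_chunk helper sits outside the hunks above; only its final return line and its call sites are visible. The following is a minimal sketch of the shape those lines imply, assuming the content and finish_reason parameters (neither is confirmed by the patch):

def openai_chunk(content=None, reasoning_content=None,
                 extra_delta=None, finish_reason=None):
    # Assemble one OpenAI-compatible chat.completion.chunk SSE frame.
    delta = {}
    if content is not None:
        delta["content"] = content
    if reasoning_content is not None:
        delta["reasoning_content"] = reasoning_content
    if extra_delta:
        delta.update(extra_delta)  # e.g. {"x_rag_status": "hit"}
    chunk = {
        "id": chat_id,            # closed over from stream_generator
        "object": "chat.completion.chunk",
        "created": created_time,  # closed over from stream_generator
        "model": settings.LLM_MODEL,
        "choices": [
            {"index": 0, "delta": delta, "finish_reason": finish_reason}
        ],
    }
    return f"data: {json.dumps(chunk, ensure_ascii=False)}\n\n"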
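
For reference, a client-side sketch that exercises both response modes. The request fields (query, mode, top_k, stream) mirror the request model used by the route; the endpoint URL and the "data: [DONE]" terminator are assumptions, since neither appears in the hunks:

import asyncio
import json

import httpx

API_URL = "http://localhost:8000/query"  # hypothetical path; adjust to the router's prefix

async def ask(query: str, stream: bool = True) -> str:
    payload = {"query": query, "mode": "hybrid", "top_k": 10, "stream": stream}
    async with httpx.AsyncClient(timeout=None) as client:
        if not stream:
            # Non-streaming: a single OpenAI-style chat.completion object
            msg = (await client.post(API_URL, json=payload)).json()["choices"][0]["message"]
            print("rag status:", msg.get("x_rag_status"))
            return msg["content"]

        # Streaming: parse the "data: {...}" SSE frames emitted by stream_generator
        answer = []
        async with client.stream("POST", API_URL, json=payload) as resp:
            async for line in resp.aiter_lines():
                if not line.startswith("data: "):
                    continue
                body = line[len("data: "):]
                if body.strip() == "[DONE]":  # conventional terminator, if the server sends one
                    break
                delta = json.loads(body)["choices"][0]["delta"]
                if "x_rag_status" in delta:
                    print("rag status:", delta["x_rag_status"])
                if delta.get("reasoning_content"):
                    print(delta["reasoning_content"], end="")
                if delta.get("content"):
                    answer.append(delta["content"])
        return "".join(answer)

if __name__ == "__main__":
    print(asyncio.run(ask("What does the knowledge base say about indexing?")))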