fix: adjust the HTTP-based knowledge base query

fuzhongyun 2026-01-26 14:40:36 +08:00
parent 41ddbcde2e
commit be9801898e
1 changed file with 72 additions and 35 deletions


@@ -1,6 +1,7 @@
 import json
 import logging
 import os
+import time
 from fastapi import APIRouter, UploadFile, File, HTTPException, Body, Header, Depends
 from fastapi.responses import StreamingResponse
 from pydantic import BaseModel
@@ -115,16 +116,25 @@ async def query_knowledge_base(
     """
     try:
         # Build the query parameters
-        param = QueryParam(
+        context_param = QueryParam(
             mode=request.mode,
             top_k=request.top_k,
-            stream=request.stream,
+            only_need_context=True,
             enable_rerank=settings.RERANK_ENABLED
         )
+        # Run the context retrieval
+        context_resp = await rag.aquery(request.query, param=context_param)
+        # Determine whether retrieval hit anything
+        rag_status = "miss"
+        has_context = False
+        if context_resp and "[no-context]" not in context_resp and "None" not in context_resp:
+            rag_status = "hit"
+            has_context = True
 
         # Handle streaming output (SSE protocol - OpenAI-compatible format)
         if request.stream:
-            import time
             async def stream_generator():
                 chat_id = f"chatcmpl-{secrets.token_hex(12)}"
                 created_time = int(time.time())
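
The hit/miss classification introduced in this hunk keys off sentinel substrings in the retrieved context rather than a structured flag. A minimal standalone sketch of the same check, assuming the retriever returns either a context string or a "[no-context]" marker (the helper name and the standalone framing are illustrative, not part of the commit):

    from typing import Optional, Tuple

    def classify_rag_status(context_resp: Optional[str]) -> Tuple[str, bool]:
        # Mirrors the check above: empty results and sentinel markers count as a miss.
        if context_resp and "[no-context]" not in context_resp and "None" not in context_resp:
            return "hit", True
        return "miss", False

    # classify_rag_status("[no-context]")            -> ("miss", False)
    # classify_rag_status("Entities: ... Text: ...") -> ("hit", True)
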
@@ -155,37 +165,15 @@ async def query_knowledge_base(
                     }
                     return f"data: {json.dumps(chunk, ensure_ascii=False)}\n\n"
 
-                # 1. Send the retrieval status (as part of the reasoning process)
-                yield openai_chunk(reasoning_content="1. 上下文检索中...\n")
-
-                context_param = QueryParam(
-                    mode=request.mode,
-                    top_k=request.top_k,
-                    only_need_context=True,
-                    enable_rerank=settings.RERANK_ENABLED
-                )
-                # Fetch the context
-                context_resp = await rag.aquery(request.query, param=context_param)
-                logging.info(f"Context Response: {context_resp}")
-
-                # Determine the retrieval status
-                has_context = False
-                if context_resp and "[no-context]" not in context_resp and "None" not in context_resp:
-                    has_context = True
-
-                # Check whether think mode is enabled
-                think = request.think
-
                 if has_context:
                     yield openai_chunk(
-                        reasoning_content=f"2. 上下文已检索 (长度: {len(context_resp)} 字符).\n",
-                        extra_delta={"x_rag_status": "hit"}
+                        reasoning_content=f"1. 上下文已检索 (长度: {len(context_resp)} 字符).\n",
+                        extra_delta={"x_rag_status": rag_status}
                     )
                 else:
                     yield openai_chunk(
-                        reasoning_content="2. 未找到相关上下文\n",
-                        extra_delta={"x_rag_status": "miss"}
+                        reasoning_content="未找到相关上下文\n",
+                        extra_delta={"x_rag_status": rag_status}
                     )
 
                 # If only-RAG mode is enabled and no context was found, end immediately
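
The openai_chunk helper itself is only partially visible in this diff (its closing brace and return line appear at the top of this hunk). For orientation, a plausible standalone equivalent following the OpenAI chat.completion.chunk convention that the comments reference; the exact field layout is an assumption, not taken from the commit:

    import json

    def sse_chunk(chat_id: str, created: int, model: str, delta: dict) -> str:
        # Assumed OpenAI-compatible chunk envelope; extra fields such as x_rag_status are merged into the delta.
        chunk = {
            "id": chat_id,
            "object": "chat.completion.chunk",
            "created": created,
            "model": model,
            "choices": [{"index": 0, "delta": delta, "finish_reason": None}],
        }
        return f"data: {json.dumps(chunk, ensure_ascii=False)}\n\n"

    # sse_chunk("chatcmpl-abc123", 1769409636, "my-llm", {"reasoning_content": "1. ...", "x_rag_status": "hit"})
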
@@ -195,9 +183,11 @@ async def query_knowledge_base(
                         return
                     yield openai_chunk(reasoning_content=" (将依赖 LLM 自身知识)\n")
-                    think = False
 
-                yield openai_chunk(reasoning_content="3. 答案生成中...\n")
+                    # No context was found, so turn off thinking mode
+                    request.think = False
+
+                yield openai_chunk(reasoning_content="2. 答案生成中...\n")
 
                 # 2. Generate the answer
                 sys_prompt = CUSTOM_RAG_RESPONSE_PROMPT.format(
@@ -211,7 +201,7 @@ async def query_knowledge_base(
                     request.query,
                     system_prompt=sys_prompt,
                     stream=True,
-                    think=think,
+                    think=request.think,
                     hashing_kv=rag.llm_response_cache
                 )
@@ -231,9 +221,56 @@ async def query_knowledge_base(
             # Use the text/event-stream Content-Type
             return StreamingResponse(stream_generator(), media_type="text/event-stream")
 
-        # Handle plain (non-streaming) output
-        response = await rag.aquery(request.query, param=param)
-        return {"response": response}
+        # Handle plain (non-streaming) output (OpenAI-compatible format)
+        # Generate the answer according to the retrieval result
+        final_answer = ""
+        if not has_context and request.only_rag:
+            # Strict mode with nothing retrieved: return directly
+            final_answer = "未找到相关知识库内容。"
+        else:
+            # Normal generation
+            sys_prompt = CUSTOM_RAG_RESPONSE_PROMPT.format(
+                context_data=context_resp,
+                response_type="Multiple Paragraphs",
+                user_prompt=""
+            )
+            # Call the LLM to generate the answer
+            stream_resp = await llm_func(
+                request.query,
+                system_prompt=sys_prompt,
+                stream=False,
+                think=request.think,
+                hashing_kv=rag.llm_response_cache
+            )
+            # A non-streaming LLM call returns the result directly
+            final_answer = stream_resp
+
+        return {
+            "id": f"chatcmpl-{secrets.token_hex(12)}",
+            "object": "chat.completion",
+            "created": int(time.time()),
+            "model": settings.LLM_MODEL,
+            "choices": [
+                {
+                    "index": 0,
+                    "message": {
+                        "role": "assistant",
+                        "content": final_answer,
+                        # Extension field: non-standard, but may be useful to some clients; could also go into usage/metadata
+                        "x_rag_status": rag_status
+                    },
+                    "finish_reason": "stop"
+                }
+            ],
+            "usage": {
+                "prompt_tokens": -1,
+                "completion_tokens": len(final_answer),
+                "total_tokens": -1
+            }
+        }
     except Exception as e:
         logging.error(f"查询失败: {str(e)}")