fix: adjust the knowledge base query over HTTP
parent 41ddbcde2e
commit be9801898e
@@ -1,6 +1,7 @@
 import json
 import logging
 import os
+import time
 from fastapi import APIRouter, UploadFile, File, HTTPException, Body, Header, Depends
 from fastapi.responses import StreamingResponse
 from pydantic import BaseModel
@@ -115,16 +116,25 @@ async def query_knowledge_base(
     """
     try:
         # Build the query parameters
-        param = QueryParam(
+        context_param = QueryParam(
             mode=request.mode,
             top_k=request.top_k,
-            stream=request.stream,
+            only_need_context=True,
             enable_rerank=settings.RERANK_ENABLED
         )
 
+        # Run the context retrieval
+        context_resp = await rag.aquery(request.query, param=context_param)
+
+        # Determine whether the retrieval produced a hit
+        rag_status = "miss"
+        has_context = False
+        if context_resp and "[no-context]" not in context_resp and "None" not in context_resp:
+            rag_status = "hit"
+            has_context = True
+
         # Handle streaming output (SSE protocol - OpenAI-compatible format)
         if request.stream:
-            import time
             async def stream_generator():
                 chat_id = f"chatcmpl-{secrets.token_hex(12)}"
                 created_time = int(time.time())
@@ -155,37 +165,15 @@ async def query_knowledge_base(
                     }
                     return f"data: {json.dumps(chunk, ensure_ascii=False)}\n\n"
 
-                # 1. Send the retrieval status (as part of the thinking output)
-                yield openai_chunk(reasoning_content="1. Retrieving context...\n")
-
-                context_param = QueryParam(
-                    mode=request.mode,
-                    top_k=request.top_k,
-                    only_need_context=True,
-                    enable_rerank=settings.RERANK_ENABLED
-                )
-
-                # Fetch the context
-                context_resp = await rag.aquery(request.query, param=context_param)
-                logging.info(f"Context Response: {context_resp}")
-
-                # Determine the retrieval status
-                has_context = False
-                if context_resp and "[no-context]" not in context_resp and "None" not in context_resp:
-                    has_context = True
-
-                # Check whether thinking mode is enabled
-                think = request.think
-
                 if has_context:
                     yield openai_chunk(
-                        reasoning_content=f"2. Context retrieved (length: {len(context_resp)} characters).\n",
-                        extra_delta={"x_rag_status": "hit"}
+                        reasoning_content=f"1. Context retrieved (length: {len(context_resp)} characters).\n",
+                        extra_delta={"x_rag_status": rag_status}
                     )
                 else:
                     yield openai_chunk(
-                        reasoning_content="2. No relevant context found\n",
-                        extra_delta={"x_rag_status": "miss"}
+                        reasoning_content="No relevant context found\n",
+                        extra_delta={"x_rag_status": rag_status}
                     )
 
                 # If RAG-only mode is enabled and no context was found, finish immediately
@@ -195,9 +183,11 @@ async def query_knowledge_base(
                         return
 
                     yield openai_chunk(reasoning_content=" (will fall back to the LLM's own knowledge)\n")
-                    think = False
 
-                yield openai_chunk(reasoning_content="3. Generating answer...\n")
+                    # No context found; disable thinking mode
+                    request.think = False
+
+                yield openai_chunk(reasoning_content="2. Generating answer...\n")
 
                 # 2. Generate the answer
                 sys_prompt = CUSTOM_RAG_RESPONSE_PROMPT.format(
@@ -211,7 +201,7 @@ async def query_knowledge_base(
                     request.query,
                     system_prompt=sys_prompt,
                     stream=True,
-                    think=think,
+                    think=request.think,
                     hashing_kv=rag.llm_response_cache
                 )
 
@@ -231,9 +221,56 @@ async def query_knowledge_base(
             # Use the text/event-stream Content-Type
             return StreamingResponse(stream_generator(), media_type="text/event-stream")
 
-        # Handle non-streaming output
-        response = await rag.aquery(request.query, param=param)
-        return {"response": response}
+        # Handle non-streaming output (OpenAI-compatible format)
+        # Generate the answer according to the retrieval result
+        final_answer = ""
+
+        if not has_context and request.only_rag:
+            # Strict (RAG-only) mode with no retrieval hit: finish immediately
+            final_answer = "No relevant knowledge base content was found."
+        else:
+            # Normal generation
+            sys_prompt = CUSTOM_RAG_RESPONSE_PROMPT.format(
+                context_data=context_resp,
+                response_type="Multiple Paragraphs",
+                user_prompt=""
+            )
+
+            # Call the LLM to generate the answer
+            stream_resp = await llm_func(
+                request.query,
+                system_prompt=sys_prompt,
+                stream=False,
+                think=request.think,
+                hashing_kv=rag.llm_response_cache
+            )
+
+            # A non-streaming LLM call returns the result directly
+            final_answer = stream_resp
+
+        return {
+            "id": f"chatcmpl-{secrets.token_hex(12)}",
+            "object": "chat.completion",
+            "created": int(time.time()),
+            "model": settings.LLM_MODEL,
+            "choices": [
+                {
+                    "index": 0,
+                    "message": {
+                        "role": "assistant",
+                        "content": final_answer,
+                        # Extension field: non-standard, but may be useful to some clients, or could go into usage/metadata
+                        "x_rag_status": rag_status
+                    },
+                    "finish_reason": "stop"
+                }
+            ],
+            "usage": {
+                "prompt_tokens": -1,
+                "completion_tokens": len(final_answer),
+                "total_tokens": -1
+            }
+        }
 
     except Exception as e:
         logging.error(f"Query failed: {str(e)}")
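
For reference, a minimal client sketch of how the new SSE branch might be consumed. The host and route path are assumptions (this endpoint's mount point is not shown in the diff), the JSON field names simply mirror the QueryRequest attributes the handler reads, and the chunk layout assumes openai_chunk emits standard chat.completion.chunk frames whose choices[0].delta carries reasoning_content, content, and extra_delta fields such as x_rag_status.

import json
import requests  # assumed HTTP client; any client with streaming support works

payload = {
    "query": "What does the knowledge base say about X?",
    "mode": "hybrid",    # mirrors request.mode
    "top_k": 10,         # mirrors request.top_k
    "stream": True,      # exercise the SSE branch
    "only_rag": False,   # mirrors request.only_rag
    "think": True,       # mirrors request.think
}

# Hypothetical host and path; adjust to wherever the router is mounted.
with requests.post("http://localhost:8000/query", json=payload, stream=True) as resp:
    resp.raise_for_status()
    for raw in resp.iter_lines(decode_unicode=True):
        if not raw or not raw.startswith("data: "):
            continue
        data = raw[len("data: "):]
        if data.strip() == "[DONE]":  # sentinel, if the server emits one
            break
        chunk = json.loads(data)
        delta = chunk["choices"][0]["delta"]
        if "x_rag_status" in delta:
            print(f"[rag status: {delta['x_rag_status']}]")
        # Reasoning text and answer text arrive in separate delta fields.
        print(delta.get("reasoning_content") or delta.get("content") or "", end="")

With "stream": False the same endpoint now returns the chat.completion object assembled at the end of the handler, so a non-streaming client would read choices[0].message.content and the non-standard x_rag_status field instead.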