fix: 调整 HTTP 方式的知识库 query
This commit is contained in:
parent
41ddbcde2e
commit
be9801898e
|
|
@ -1,6 +1,7 @@
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
import time
|
||||||
from fastapi import APIRouter, UploadFile, File, HTTPException, Body, Header, Depends
|
from fastapi import APIRouter, UploadFile, File, HTTPException, Body, Header, Depends
|
||||||
from fastapi.responses import StreamingResponse
|
from fastapi.responses import StreamingResponse
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
@ -115,16 +116,25 @@ async def query_knowledge_base(
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
# 构造查询参数
|
# 构造查询参数
|
||||||
param = QueryParam(
|
context_param = QueryParam(
|
||||||
mode=request.mode,
|
mode=request.mode,
|
||||||
top_k=request.top_k,
|
top_k=request.top_k,
|
||||||
stream=request.stream,
|
only_need_context=True,
|
||||||
enable_rerank=settings.RERANK_ENABLED
|
enable_rerank=settings.RERANK_ENABLED
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# 执行上下文检索
|
||||||
|
context_resp = await rag.aquery(request.query, param=context_param)
|
||||||
|
|
||||||
|
# 判断检索命中状态
|
||||||
|
rag_status = "miss"
|
||||||
|
has_context = False
|
||||||
|
if context_resp and "[no-context]" not in context_resp and "None" not in context_resp:
|
||||||
|
rag_status = "hit"
|
||||||
|
has_context = True
|
||||||
|
|
||||||
# 处理流式输出 (SSE 协议 - OpenAI 兼容格式)
|
# 处理流式输出 (SSE 协议 - OpenAI 兼容格式)
|
||||||
if request.stream:
|
if request.stream:
|
||||||
import time
|
|
||||||
async def stream_generator():
|
async def stream_generator():
|
||||||
chat_id = f"chatcmpl-{secrets.token_hex(12)}"
|
chat_id = f"chatcmpl-{secrets.token_hex(12)}"
|
||||||
created_time = int(time.time())
|
created_time = int(time.time())
|
||||||
|
|
@ -155,37 +165,15 @@ async def query_knowledge_base(
|
||||||
}
|
}
|
||||||
return f"data: {json.dumps(chunk, ensure_ascii=False)}\n\n"
|
return f"data: {json.dumps(chunk, ensure_ascii=False)}\n\n"
|
||||||
|
|
||||||
# 1. 发送检索状态 (作为思考过程的一部分)
|
|
||||||
yield openai_chunk(reasoning_content="1. 上下文检索中...\n")
|
|
||||||
|
|
||||||
context_param = QueryParam(
|
|
||||||
mode=request.mode,
|
|
||||||
top_k=request.top_k,
|
|
||||||
only_need_context=True,
|
|
||||||
enable_rerank=settings.RERANK_ENABLED
|
|
||||||
)
|
|
||||||
|
|
||||||
# 获取上下文
|
|
||||||
context_resp = await rag.aquery(request.query, param=context_param)
|
|
||||||
logging.info(f"Context Response: {context_resp}")
|
|
||||||
|
|
||||||
# 判断检索状态
|
|
||||||
has_context = False
|
|
||||||
if context_resp and "[no-context]" not in context_resp and "None" not in context_resp:
|
|
||||||
has_context = True
|
|
||||||
|
|
||||||
# 判断是否开启think
|
|
||||||
think = request.think
|
|
||||||
|
|
||||||
if has_context:
|
if has_context:
|
||||||
yield openai_chunk(
|
yield openai_chunk(
|
||||||
reasoning_content=f"2. 上下文已检索 (长度: {len(context_resp)} 字符).\n",
|
reasoning_content=f"1. 上下文已检索 (长度: {len(context_resp)} 字符).\n",
|
||||||
extra_delta={"x_rag_status": "hit"}
|
extra_delta={"x_rag_status": rag_status}
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
yield openai_chunk(
|
yield openai_chunk(
|
||||||
reasoning_content="2. 未找到相关上下文\n",
|
reasoning_content="未找到相关上下文\n",
|
||||||
extra_delta={"x_rag_status": "miss"}
|
extra_delta={"x_rag_status": rag_status}
|
||||||
)
|
)
|
||||||
|
|
||||||
# 如果开启了仅RAG模式且未找到上下文,则直接结束
|
# 如果开启了仅RAG模式且未找到上下文,则直接结束
|
||||||
|
|
@ -195,9 +183,11 @@ async def query_knowledge_base(
|
||||||
return
|
return
|
||||||
|
|
||||||
yield openai_chunk(reasoning_content=" (将依赖 LLM 自身知识)\n")
|
yield openai_chunk(reasoning_content=" (将依赖 LLM 自身知识)\n")
|
||||||
think = False
|
|
||||||
|
|
||||||
yield openai_chunk(reasoning_content="3. 答案生成中...\n")
|
# 未找到上下文,关闭思考模式
|
||||||
|
request.think = False
|
||||||
|
|
||||||
|
yield openai_chunk(reasoning_content="2. 答案生成中...\n")
|
||||||
|
|
||||||
# 2. 生成答案
|
# 2. 生成答案
|
||||||
sys_prompt = CUSTOM_RAG_RESPONSE_PROMPT.format(
|
sys_prompt = CUSTOM_RAG_RESPONSE_PROMPT.format(
|
||||||
|
|
@ -211,7 +201,7 @@ async def query_knowledge_base(
|
||||||
request.query,
|
request.query,
|
||||||
system_prompt=sys_prompt,
|
system_prompt=sys_prompt,
|
||||||
stream=True,
|
stream=True,
|
||||||
think=think,
|
think=request.think,
|
||||||
hashing_kv=rag.llm_response_cache
|
hashing_kv=rag.llm_response_cache
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -231,9 +221,56 @@ async def query_knowledge_base(
|
||||||
# 使用 text/event-stream Content-Type
|
# 使用 text/event-stream Content-Type
|
||||||
return StreamingResponse(stream_generator(), media_type="text/event-stream")
|
return StreamingResponse(stream_generator(), media_type="text/event-stream")
|
||||||
|
|
||||||
# 处理普通输出
|
# 处理普通输出 (OpenAI 兼容格式)
|
||||||
response = await rag.aquery(request.query, param=param)
|
# 根据策略生成回答
|
||||||
return {"response": response}
|
final_answer = ""
|
||||||
|
|
||||||
|
if not has_context and request.only_rag:
|
||||||
|
# 严格模式且未检索到,直接结束
|
||||||
|
final_answer = "未找到相关知识库内容。"
|
||||||
|
else:
|
||||||
|
# 正常生成
|
||||||
|
sys_prompt = CUSTOM_RAG_RESPONSE_PROMPT.format(
|
||||||
|
context_data=context_resp,
|
||||||
|
response_type="Multiple Paragraphs",
|
||||||
|
user_prompt=""
|
||||||
|
)
|
||||||
|
|
||||||
|
# 调用 LLM 生成
|
||||||
|
stream_resp = await llm_func(
|
||||||
|
request.query,
|
||||||
|
system_prompt=sys_prompt,
|
||||||
|
stream=False,
|
||||||
|
think=request.think,
|
||||||
|
hashing_kv=rag.llm_response_cache
|
||||||
|
)
|
||||||
|
|
||||||
|
# 非流式调用 LLM 直接返回结果
|
||||||
|
final_answer = stream_resp
|
||||||
|
|
||||||
|
return {
|
||||||
|
"id": f"chatcmpl-{secrets.token_hex(12)}",
|
||||||
|
"object": "chat.completion",
|
||||||
|
"created": int(time.time()),
|
||||||
|
"model": settings.LLM_MODEL,
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"index": 0,
|
||||||
|
"message": {
|
||||||
|
"role": "assistant",
|
||||||
|
"content": final_answer,
|
||||||
|
# 扩展字段:非标准,但在某些客户端可能有用,或者放入 usage/metadata
|
||||||
|
"x_rag_status": rag_status
|
||||||
|
},
|
||||||
|
"finish_reason": "stop"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"usage": {
|
||||||
|
"prompt_tokens": -1,
|
||||||
|
"completion_tokens": len(final_answer),
|
||||||
|
"total_tokens": -1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.error(f"查询失败: {str(e)}")
|
logging.error(f"查询失败: {str(e)}")
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue