# ai_test/app.py
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, Request, HTTPException
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from starlette.status import HTTP_422_UNPROCESSABLE_ENTITY
from pydantic_settings import BaseSettings
from pydantic import BaseModel, Field
from funasr import AutoModel
import numpy as np
import argparse
import uvicorn
from urllib.parse import parse_qs
from loguru import logger
import sys
import re
import json
import traceback
import time
from modelscope import AutoModelForCausalLM, AutoTokenizer
import asyncio
import tempfile
import os
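
# Route DEBUG/INFO/WARNING to stdout and ERROR and above to stderr.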
logger.remove()
log_format = "{time:YYYY-MM-DD HH:mm:ss} [{level}] {file}:{line} - {message}"
logger.add(sys.stdout, format=log_format, level="DEBUG", filter=lambda record: record["level"].no < 40)
logger.add(sys.stderr, format=log_format, level="ERROR", filter=lambda record: record["level"].no >= 40)
class Config(BaseSettings):
    chunk_size_ms: int = Field(300, description="Chunk size in milliseconds")
    sample_rate: int = Field(16000, description="Sample rate in Hz")
    model_chat: str = Field("Qwen/Qwen3-0.6B", description="Chat model name")
    # System prompt (Chinese): "You are an emotionally expressive chatbot named
    # Lan Ge. Always answer in Chinese only, with no kaomoji or emoji, and every
    # reply must be longer than 20 characters."
    sys_set: str = Field(
        "你是一个感情丰富的聊天机器人,你名字叫蓝哥,回答内容必须是中文,只回答中文,不要有颜文字,也不要有表情符号,每次的回答内容必须超过20个文字",
        description="System prompt",
    )  # Optional addition: "When the user's input is incomplete, ambiguous, or you are unsure how to answer, reply: '没听懂' (didn't catch that)."
    tts_rate: str = Field("+20%", description="TTS speaking rate")
    tts_pitch: str = Field("+10Hz", description="TTS pitch")
    voice: str = Field("zh-CN-XiaoxiaoNeural", description="TTS voice")

config = Config()
messages = [
    {"role": "system", "content": config.sys_set}
]
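
# SenseVoiceSmall does the speech-to-text; fsmn-vad segments the incoming
# stream into utterances before they are handed to the recognizer.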
model_asr = AutoModel(
    model="iic/SenseVoiceSmall",
    trust_remote_code=True,
    remote_code="./model.py",
    device="cuda:0",
    disable_update=True,
)
model_vad = AutoModel(
    model="fsmn-vad",
    model_revision="v2.0.4",
    disable_pbar=True,
    max_end_silence_time=800,  # was 500
    # speech_noise_thres=0.6,
    disable_update=True,
)
# Load the tokenizer and the chat model.
tokenizer = AutoTokenizer.from_pretrained(config.model_chat)
model_qw = AutoModelForCausalLM.from_pretrained(
    config.model_chat,
    torch_dtype="auto",
    device_map="auto",
)

def asr(audio, lang, cache, use_itn=False):
    """Run SenseVoice ASR on an audio segment and log the elapsed time."""
    start_time = time.time()
    result = model_asr.generate(
        input=audio,
        cache=cache,
        language=lang.strip(),
        use_itn=use_itn,
        batch_size_s=60,
    )
    elapsed_time = time.time() - start_time
    logger.debug(f"asr elapsed: {elapsed_time * 1000:.2f} milliseconds")
    return result
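
# Permissive CORS so browser clients from any origin can open the WebSocket.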
app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

@app.exception_handler(Exception)
async def custom_exception_handler(request: Request, exc: Exception):
    # loguru does not understand the stdlib `exc_info` kwarg; attach the
    # exception via opt() so the traceback is actually logged.
    logger.opt(exception=exc).error("Exception occurred")
    if isinstance(exc, HTTPException):
        status_code = exc.status_code
        message = exc.detail
    elif isinstance(exc, RequestValidationError):
        status_code = HTTP_422_UNPROCESSABLE_ENTITY
        message = "Validation error: " + str(exc.errors())
    else:
        status_code = 500
        message = "Internal server error: " + str(exc)
    return JSONResponse(
        status_code=status_code,
        content=TranscriptionResponse(
            code=status_code,
            info=message,  # the model's field is `info`, not `msg`
            data="",
        ).model_dump(),
    )

# Define the response model shared by the exception handler and the WebSocket endpoint.
class TranscriptionResponse(BaseModel):
    code: int
    info: str
    data: str
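
# TTS pipeline: edge-tts renders the reply to a temporary MP3, ffmpeg converts
# it to raw 16 kHz mono 16-bit PCM, and the PCM is streamed over the WebSocket,
# followed by a JSON "EOF" message describing the format.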
async def text_to_speech_stream_with_wav(text: str, websocket: WebSocket):
    # Create a temporary MP3 file (edge-tts output).
    with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as temp_mp3:
        temp_mp3_path = temp_mp3.name
    try:
        # Run edge-tts, writing to the temporary MP3 file.
        edge_tts_process = await asyncio.create_subprocess_exec(
            "edge-tts",
            "--text", text,
            "--voice", config.voice,
            "--rate", config.tts_rate,
            "--pitch", config.tts_pitch,
            "--write-media", temp_mp3_path,
            stderr=asyncio.subprocess.PIPE,
        )
        await edge_tts_process.wait()
        if edge_tts_process.returncode != 0:
            error_msg = await edge_tts_process.stderr.read()
            raise RuntimeError(f"TTS generation failed: {error_msg.decode()}")
        ffmpeg_process = await asyncio.create_subprocess_exec(
            "ffmpeg",
            "-i", temp_mp3_path,
            "-f", "s16le",    # output raw 16-bit PCM
            "-ar", "16000",   # 16 kHz sample rate
            "-ac", "1",       # mono
            "-loglevel", "warning",
            "pipe:1",
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE,
        )
        # Stream the PCM data to the frontend.
        while True:
            chunk = await ffmpeg_process.stdout.read(4096)
            if not chunk:
                break
            await websocket.send_bytes(chunk)
        await ffmpeg_process.wait()
        if ffmpeg_process.returncode != 0:
            error_msg = await ffmpeg_process.stderr.read()
            raise RuntimeError(f"FFmpeg conversion failed: {error_msg.decode()}")
        # Send an EOF marker carrying the WAV format parameters.
        await websocket.send_json({
            "type": "EOF",
            "sampleRate": 16000,
            "numChannels": 1,
            "bitDepth": 16,
        })
    finally:
        # Clean up the temporary MP3 file.
        try:
            if os.path.exists(temp_mp3_path):
                os.unlink(temp_mp3_path)
        except Exception as e:
            logger.warning(f"Failed to remove temporary file: {e}")

# Simpler alternative (currently unused): forward edge-tts output without WAV conversion.
async def text_to_speech_stream(text: str, websocket: WebSocket):
    process = await asyncio.create_subprocess_exec(
        "edge-tts",
        "--text", text,
        "--voice", config.voice,
        "--rate", config.tts_rate,
        "--pitch", config.tts_pitch,
        stdout=asyncio.subprocess.PIPE,  # read the audio directly from stdout
        stderr=asyncio.subprocess.PIPE,
    )
    # Read the audio stream from the process stdout and send it in 1 KiB chunks.
    while chunk := await process.stdout.read(1024):
        await websocket.send_bytes(chunk)
    # Wait for the process to exit.
    await process.wait()
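
# /ws/transcribe: the client streams raw 16 kHz mono 16-bit PCM; the server runs
# streaming VAD, transcribes each finished utterance, generates a Qwen3 reply,
# sends it as JSON, then streams the spoken reply back as audio.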
@app.websocket("/ws/transcribe")
async def websocket_endpoint(websocket: WebSocket):
    global messages
    try:
        query_params = parse_qs(websocket.scope['query_string'].decode())
        sv = query_params.get('sv', ['false'])[0].lower() in ['true', '1', 't', 'y', 'yes']  # parsed but not used below
        lang = query_params.get('lang', ['auto'])[0].lower()
        await websocket.accept()
        chunk_size = int(config.chunk_size_ms * config.sample_rate / 1000)
        audio_buffer = np.array([], dtype=np.float32)
        audio_vad = np.array([], dtype=np.float32)
        cache = {}
        cache_asr = {}
        last_vad_beg = last_vad_end = -1
        offset = 0
        buffer = b""
        while True:
            data = await websocket.receive_bytes()
            # logger.info(f"received {len(data)} bytes")
            buffer += data
            if len(buffer) < 2:
                continue
            # Convert complete int16 samples to float32 in [-1, 1]; keep any
            # trailing odd byte for the next message.
            audio_buffer = np.append(
                audio_buffer,
                np.frombuffer(buffer[:len(buffer) - (len(buffer) % 2)], dtype=np.int16).astype(np.float32) / 32767.0
            )
            buffer = buffer[len(buffer) - (len(buffer) % 2):]
            while len(audio_buffer) >= chunk_size:
                chunk = audio_buffer[:chunk_size]
                audio_buffer = audio_buffer[chunk_size:]
                audio_vad = np.append(audio_vad, chunk)
                res = model_vad.generate(input=chunk, cache=cache, is_final=False, chunk_size=config.chunk_size_ms)
                if len(res[0]["value"]):
                    vad_segments = res[0]["value"]
                    for segment in vad_segments:
                        if segment[0] > -1:  # speech begin
                            last_vad_beg = segment[0]
                        if segment[1] > -1:  # speech end
                            last_vad_end = segment[1]
                        if last_vad_beg > -1 and last_vad_end > -1:
                            # VAD timestamps are in ms relative to the stream start;
                            # rebase them against the audio already consumed.
                            last_vad_beg -= offset
                            last_vad_end -= offset
                            offset += last_vad_end
                            beg = int(last_vad_beg * config.sample_rate / 1000)
                            end = int(last_vad_end * config.sample_rate / 1000)
                            logger.info(f"[vad segment] audio_len: {end - beg}")
                            result = asr(audio_vad[beg:end], lang.strip(), cache_asr, True)
                            logger.info(f"asr response: {result}")
                            audio_vad = audio_vad[end:]
                            last_vad_beg = last_vad_end = -1
                            if result is not None:
                                # Strip SenseVoice special tags such as <|zh|> from the transcript.
                                say = re.sub(r"<\|.*?\|>", "", result[0]['text'])
                                messages.append({"role": "user", "content": say})
                                text = tokenizer.apply_chat_template(
                                    messages,
                                    tokenize=False,
                                    add_generation_prompt=True,
                                    enable_thinking=False
                                )
                                model_inputs = tokenizer([text], return_tensors="pt").to(model_qw.device)
                                # Conduct text completion.
                                generated_ids = model_qw.generate(
                                    **model_inputs,
                                    max_new_tokens=32768
                                )
                                output_ids = generated_ids[0][len(model_inputs.input_ids[0]):].tolist()
                                # Parse out any thinking content.
                                try:
                                    # rindex of token 151668 (</think>)
                                    index = len(output_ids) - output_ids[::-1].index(151668)
                                except ValueError:
                                    index = 0
                                content = tokenizer.decode(output_ids[index:], skip_special_tokens=True).strip("\n")
                                # Reset messages: keep the system message, drop the
                                # user and assistant turns.
                                if len(messages) > 0 and messages[0]["role"] == "system":
                                    messages = [messages[0]]  # keep only the system message
                                else:
                                    messages = []  # no system message, clear everything
                                response = TranscriptionResponse(
                                    code=0,
                                    info=json.dumps(result[0], ensure_ascii=False),
                                    data=content
                                )
                                # Send the chat reply (text).
                                await websocket.send_json(response.model_dump())
                                # Then stream the TTS audio.
                                await text_to_speech_stream_with_wav(content, websocket)
                # logger.debug(f'last_vad_beg: {last_vad_beg}; last_vad_end: {last_vad_end} len(audio_vad): {len(audio_vad)}')
    except WebSocketDisconnect:
        logger.info("WebSocket disconnected")
    except Exception as e:
        logger.error(f"Unexpected error: {e}\nCall stack:\n{traceback.format_exc()}")
        await websocket.close()
    finally:
        cache.clear()
        logger.info("Cleaned up resources after WebSocket disconnect")

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run the FastAPI app with a specified port.")
    parser.add_argument('--port', type=int, default=27004, help='Port number to run the FastAPI app on.')
    # parser.add_argument('--certfile', type=str, default='path_to_your_SSL_certificate_file.crt', help='SSL certificate file')
    # parser.add_argument('--keyfile', type=str, default='path_to_your_SSL_certificate_file.key', help='SSL key file')
    args = parser.parse_args()
    # uvicorn.run(app, host="0.0.0.0", port=args.port, ssl_certfile=args.certfile, ssl_keyfile=args.keyfile)
    uvicorn.run(app, host="0.0.0.0", port=args.port)
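
# Usage:
#   python app.py --port 27004
# then connect a client to ws://<host>:27004/ws/transcribe?lang=auto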