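"""Real-time voice chat server.

Receives 16 kHz mono 16-bit PCM audio over a WebSocket, segments it with
FSMN-VAD, transcribes each utterance with SenseVoice, generates a Chinese
reply with Qwen3, and streams the reply back as synthesized speech
(edge-tts -> ffmpeg -> raw PCM).
"""

# A minimal client sketch (illustrative only; assumes the `websockets` package
# and a `pcm_chunks` iterable of raw 16 kHz mono s16le bytes):
#
#   import asyncio, websockets
#
#   async def main():
#       uri = "ws://localhost:27004/ws/transcribe?lang=auto"
#       async with websockets.connect(uri) as ws:
#           for chunk in pcm_chunks:
#               await ws.send(chunk)          # raw PCM in
#           reply = await ws.recv()           # JSON text reply
#           audio = await ws.recv()           # then PCM audio frames ...
#
#   asyncio.run(main())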
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, Request, HTTPException
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from starlette.status import HTTP_422_UNPROCESSABLE_ENTITY
from pydantic_settings import BaseSettings
from pydantic import BaseModel, Field
from funasr import AutoModel
import numpy as np
import argparse
import uvicorn
from urllib.parse import parse_qs
from loguru import logger
import sys
import re
import json
import traceback
import time
from modelscope import AutoModelForCausalLM, AutoTokenizer
import asyncio
import tempfile
import os

logger.remove()
log_format = "{time:YYYY-MM-DD HH:mm:ss} [{level}] {file}:{line} - {message}"
# Records below ERROR (loguru level no. 40) go to stdout; ERROR and above go to stderr
logger.add(sys.stdout, format=log_format, level="DEBUG", filter=lambda record: record["level"].no < 40)
logger.add(sys.stderr, format=log_format, level="ERROR", filter=lambda record: record["level"].no >= 40)


class Config(BaseSettings):
    chunk_size_ms: int = Field(300, description="Chunk size in milliseconds")
    sample_rate: int = Field(16000, description="Sample rate in Hz")
    model_chat: str = Field("Qwen/Qwen3-0.6B", description="Chat model ID")
    # System prompt (Chinese): "You are an emotionally expressive chatbot named
    # 蓝哥 (Lan Ge). Answers must be in Chinese, Chinese only, with no kaomoji
    # and no emoji, and every reply must be longer than 20 characters."
    # Possible prompt addition: "If the user's input is incomplete or ambiguous,
    # or you are not sure how to answer, you may reply: '没听懂'."
    sys_set: str = Field("你是一个感情丰富的聊天机器人,你名字叫蓝哥,回答内容必须是中文,只回答中文,不要有颜文字,也不要有表情符号,每次的回答内容必须超过20个文字", description="System prompt")
    tts_rate: str = Field("+20%", description="TTS speaking rate")
    tts_pitch: str = Field("+10Hz", description="TTS pitch shift")
    voice: str = Field("zh-CN-XiaoxiaoNeural", description="edge-tts voice")


config = Config()
messages = [
    {"role": "system", "content": config.sys_set}
]

# SenseVoice ASR model (applied to VAD-segmented audio)
model_asr = AutoModel(
    model="iic/SenseVoiceSmall",
    trust_remote_code=True,
    remote_code="./model.py",
    device="cuda:0",
    disable_update=True
)

# FSMN streaming VAD model (segments incoming audio into utterances)
model_vad = AutoModel(
    model="fsmn-vad",
    model_revision="v2.0.4",
    disable_pbar=True,
    max_end_silence_time=800,  # was 500
    # speech_noise_thres=0.6,
    disable_update=True,
)


# Load the tokenizer and the chat model
tokenizer = AutoTokenizer.from_pretrained(config.model_chat)
model_qw = AutoModelForCausalLM.from_pretrained(
    config.model_chat,
    torch_dtype="auto",
    device_map="auto"
)


def asr(audio, lang, cache, use_itn=False):
    """Run SenseVoice on an audio segment and log the elapsed time."""
    start_time = time.time()
    result = model_asr.generate(
        input=audio,
        cache=cache,
        language=lang.strip(),
        use_itn=use_itn,
        batch_size_s=60,
    )
    end_time = time.time()
    elapsed_time = end_time - start_time
    logger.debug(f"asr elapsed: {elapsed_time * 1000:.2f} milliseconds")
    return result


app = FastAPI()

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


# Register for RequestValidationError explicitly; FastAPI's built-in handler
# would otherwise intercept it before the generic Exception handler runs.
@app.exception_handler(RequestValidationError)
@app.exception_handler(Exception)
async def custom_exception_handler(request: Request, exc: Exception):
    # loguru takes no exc_info kwarg; attach the exception via opt()
    logger.opt(exception=exc).error("Exception occurred")
    if isinstance(exc, HTTPException):
        status_code = exc.status_code
        message = exc.detail
        data = ""
    elif isinstance(exc, RequestValidationError):
        status_code = HTTP_422_UNPROCESSABLE_ENTITY
        message = "Validation error: " + str(exc.errors())
        data = ""
    else:
        status_code = 500
        message = "Internal server error: " + str(exc)
        data = ""

    return JSONResponse(
        status_code=status_code,
        content=TranscriptionResponse(
            code=status_code,
            info=message,  # the model's field is `info`, not `msg`
            data=data
        ).model_dump()
    )


# Define the response model
class TranscriptionResponse(BaseModel):
    code: int
    info: str
    data: str


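# TTS pipeline: edge-tts renders the reply to a temporary MP3, ffmpeg decodes it
# to raw 16 kHz mono 16-bit PCM, the PCM is streamed to the client in chunks, and
# a final JSON "EOF" message carries the parameters the client needs to build a
# WAV header.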
async def text_to_speech_stream_with_wav(text: str, websocket: WebSocket):
    # Create a temporary MP3 file for the edge-tts output
    with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as temp_mp3:
        temp_mp3_path = temp_mp3.name

    try:
        # Run edge-tts, writing to the temporary MP3 file
        edge_tts_process = await asyncio.create_subprocess_exec(
            "edge-tts",
            "--text", text,
            "--voice", config.voice,
            "--rate", config.tts_rate,
            "--pitch", config.tts_pitch,
            "--write-media", temp_mp3_path,
            stderr=asyncio.subprocess.PIPE,
        )
        await edge_tts_process.wait()
        if edge_tts_process.returncode != 0:
            error_msg = await edge_tts_process.stderr.read()
            raise RuntimeError(f"TTS generation failed: {error_msg.decode()}")

        ffmpeg_process = await asyncio.create_subprocess_exec(
            "ffmpeg",
            "-i", temp_mp3_path,
            "-f", "s16le",    # raw 16-bit PCM output
            "-ar", "16000",   # 16 kHz sample rate
            "-ac", "1",       # mono
            "-loglevel", "warning",
            "pipe:1",
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE,
        )

        # Stream the PCM data to the client
        while True:
            chunk = await ffmpeg_process.stdout.read(4096)
            if not chunk:
                break
            await websocket.send_bytes(chunk)

        await ffmpeg_process.wait()
        if ffmpeg_process.returncode != 0:
            error_msg = await ffmpeg_process.stderr.read()
            raise RuntimeError(f"FFmpeg conversion failed: {error_msg.decode()}")

        # Send an EOF marker carrying the WAV header parameters
        await websocket.send_json({
            "type": "EOF",
            "sampleRate": 16000,
            "numChannels": 1,
            "bitDepth": 16,
        })

    finally:
        # Clean up the temporary MP3 file
        try:
            if os.path.exists(temp_mp3_path):
                os.unlink(temp_mp3_path)
        except Exception as e:
            logger.warning(f"Failed to clean up temporary file: {e}")

async def text_to_speech_stream(text: str, websocket: WebSocket):
    """Stream edge-tts output straight from stdout to the client (no ffmpeg)."""
    process = await asyncio.create_subprocess_exec(
        "edge-tts",
        "--text", text,
        "--voice", config.voice,
        "--rate", config.tts_rate,
        "--pitch", config.tts_pitch,
        stdout=asyncio.subprocess.PIPE,  # read audio directly from stdout
        stderr=asyncio.subprocess.PIPE,
    )

    # Forward the audio stream from the process stdout to the WebSocket
    while chunk := await process.stdout.read(1024):  # 1 KiB chunks
        await websocket.send_bytes(chunk)

    # Wait for the process to finish
    await process.wait()


@app.websocket("/ws/transcribe")
async def websocket_endpoint(websocket: WebSocket):
    global messages
    try:
        query_params = parse_qs(websocket.scope['query_string'].decode())
        sv = query_params.get('sv', ['false'])[0].lower() in ['true', '1', 't', 'y', 'yes']  # parsed but not used below
        lang = query_params.get('lang', ['auto'])[0].lower()

        await websocket.accept()
        # Samples per VAD chunk, e.g. 300 ms * 16000 Hz / 1000 = 4800 samples
        chunk_size = int(config.chunk_size_ms * config.sample_rate / 1000)
        audio_buffer = np.array([], dtype=np.float32)
        audio_vad = np.array([], dtype=np.float32)

        cache = {}
        cache_asr = {}
        last_vad_beg = last_vad_end = -1
        offset = 0

        buffer = b""
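        # Receive loop: incoming bytes are little-endian 16-bit PCM. Any odd
        # trailing byte stays in `buffer` until the next packet; complete samples
        # are normalized to float32 in [-1, 1] and fed to the VAD in fixed chunks.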
        while True:
            data = await websocket.receive_bytes()
            # logger.info(f"received {len(data)} bytes")

            buffer += data
            if len(buffer) < 2:
                continue

            audio_buffer = np.append(
                audio_buffer,
                np.frombuffer(buffer[:len(buffer) - (len(buffer) % 2)], dtype=np.int16).astype(np.float32) / 32767.0
            )
            buffer = buffer[len(buffer) - (len(buffer) % 2):]

            while len(audio_buffer) >= chunk_size:
                chunk = audio_buffer[:chunk_size]
                audio_buffer = audio_buffer[chunk_size:]
                audio_vad = np.append(audio_vad, chunk)
                res = model_vad.generate(input=chunk, cache=cache, is_final=False, chunk_size=config.chunk_size_ms)
                if len(res[0]["value"]):
                    vad_segments = res[0]["value"]
                    for segment in vad_segments:
                        if segment[0] > -1:  # speech begin
                            last_vad_beg = segment[0]
                        if segment[1] > -1:  # speech end
                            last_vad_end = segment[1]
                        if last_vad_beg > -1 and last_vad_end > -1:
                            # VAD timestamps are ms from stream start; `offset`
                            # maps them into the current audio_vad buffer
                            last_vad_beg -= offset
                            last_vad_end -= offset
                            offset += last_vad_end
                            beg = int(last_vad_beg * config.sample_rate / 1000)
                            end = int(last_vad_end * config.sample_rate / 1000)
                            logger.info(f"[vad segment] audio_len: {end - beg}")
                            result = asr(audio_vad[beg:end], lang.strip(), cache_asr, True)
                            logger.info(f"asr response: {result}")
                            audio_vad = audio_vad[end:]
                            last_vad_beg = last_vad_end = -1
                            if result is not None:
                                # Strip SenseVoice tags such as <|zh|> from the transcript
                                say = re.sub(r"<\|.*?\|>", "", result[0]['text'])
                                messages.append({"role": "user", "content": say})

                                text = tokenizer.apply_chat_template(
                                    messages,
                                    tokenize=False,
                                    add_generation_prompt=True,
                                    enable_thinking=False
                                )
                                model_inputs = tokenizer([text], return_tensors="pt").to(model_qw.device)
                                # Conduct text completion
                                generated_ids = model_qw.generate(
                                    **model_inputs,
                                    max_new_tokens=32768
                                )
                                output_ids = generated_ids[0][len(model_inputs.input_ids[0]):].tolist()
                                # Parse out any thinking content
                                try:
                                    # rindex finding 151668 (</think>)
                                    index = len(output_ids) - output_ids[::-1].index(151668)
                                except ValueError:
                                    index = 0
                                content = tokenizer.decode(output_ids[index:], skip_special_tokens=True).strip("\n")
                                # Reset messages: keep the system message, drop user/assistant turns
                                if len(messages) > 0 and messages[0]["role"] == "system":
                                    messages = [messages[0]]  # keep only the system message
                                else:
                                    messages = []  # no system message, clear everything
                                response = TranscriptionResponse(
                                    code=0,
                                    info=json.dumps(result[0], ensure_ascii=False),
                                    data=content
                                )

                                # Send the chat reply (text)
                                await websocket.send_json(response.model_dump())

                                # Send the TTS audio stream
                                await text_to_speech_stream_with_wav(content, websocket)

                # logger.debug(f'last_vad_beg: {last_vad_beg}; last_vad_end: {last_vad_end} len(audio_vad): {len(audio_vad)}')

    except WebSocketDisconnect:
        logger.info("WebSocket disconnected")
    except Exception as e:
        logger.error(f"Unexpected error: {e}\nCall stack:\n{traceback.format_exc()}")
        await websocket.close()
    finally:
        cache.clear()
        logger.info("Cleaned up resources after WebSocket disconnect")

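# Assumed invocation (script name illustrative): python server.py --port 27004
# Clients then connect to ws://<host>:<port>/ws/transcribe?lang=auto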
if __name__ == "__main__":
|
||
parser = argparse.ArgumentParser(description="Run the FastAPI app with a specified port.")
|
||
parser.add_argument('--port', type=int, default=27004, help='Port number to run the FastAPI app on.')
|
||
# parser.add_argument('--certfile', type=str, default='path_to_your_SSL_certificate_file.crt', help='SSL certificate file')
|
||
# parser.add_argument('--keyfile', type=str, default='path_to_your_SSL_certificate_file.key', help='SSL key file')
|
||
args = parser.parse_args()
|
||
# uvicorn.run(app, host="0.0.0.0", port=args.port, ssl_certfile=args.certfile, ssl_keyfile=args.keyfile)
|
||
uvicorn.run(app, host="0.0.0.0", port=args.port)
|