from fastapi import FastAPI, WebSocket, WebSocketDisconnect, Request, HTTPException
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from starlette.status import HTTP_422_UNPROCESSABLE_ENTITY
from pydantic_settings import BaseSettings
from pydantic import BaseModel, Field
from funasr import AutoModel
import numpy as np
import argparse
import uvicorn
from urllib.parse import parse_qs
from loguru import logger
import sys
import re
import json
import traceback
import time
from modelscope import AutoModelForCausalLM, AutoTokenizer
import asyncio
import tempfile
import os

logger.remove()
log_format = "{time:YYYY-MM-DD HH:mm:ss} [{level}] {file}:{line} - {message}"
logger.add(sys.stdout, format=log_format, level="DEBUG", filter=lambda record: record["level"].no < 40)
logger.add(sys.stderr, format=log_format, level="ERROR", filter=lambda record: record["level"].no >= 40)


class Config(BaseSettings):
    chunk_size_ms: int = Field(300, description="Chunk size in milliseconds")
    sample_rate: int = Field(16000, description="Sample rate in Hz")
    model_chat: str = Field("Qwen/Qwen3-0.6B", description="Chat model")
    # Optional prompt addition: when the user's input is incomplete, ambiguous,
    # or the model is unsure how to answer, it may reply "没听懂".
    sys_set: str = Field(
        "你是一个感情丰富的聊天机器人,你名字叫蓝哥,回答内容必须是中文,只回答中文,不要有颜文字,也不要有表情符号,每次的回答内容必须超过20个文字",
        description="System prompt",
    )
    tts_rate: str = Field("+20%", description="TTS rate")
    tts_pitch: str = Field("+10Hz", description="TTS pitch")
    voice: str = Field("zh-CN-XiaoxiaoNeural", description="TTS voice")


config = Config()

# Conversation history; the system prompt stays at index 0.
messages = [
    {"role": "system", "content": config.sys_set}
]

model_asr = AutoModel(
    model="iic/SenseVoiceSmall",
    trust_remote_code=True,
    remote_code="./model.py",
    device="cuda:0",
    disable_update=True,
)

model_vad = AutoModel(
    model="fsmn-vad",
    model_revision="v2.0.4",
    disable_pbar=True,
    max_end_silence_time=800,  # previously 500
    # speech_noise_thres=0.6,
    disable_update=True,
)

# Load the tokenizer and the chat model.
tokenizer = AutoTokenizer.from_pretrained(config.model_chat)
model_qw = AutoModelForCausalLM.from_pretrained(
    config.model_chat,
    torch_dtype="auto",
    device_map="auto",
)


def asr(audio, lang, cache, use_itn=False):
    start_time = time.time()
    result = model_asr.generate(
        input=audio,
        cache=cache,
        language=lang.strip(),
        use_itn=use_itn,
        batch_size_s=60,
    )
    elapsed_time = time.time() - start_time
    logger.debug(f"asr elapsed: {elapsed_time * 1000:.2f} milliseconds")
    return result


app = FastAPI()

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


# Define the response model.
class TranscriptionResponse(BaseModel):
    code: int
    info: str
    data: str


@app.exception_handler(Exception)
async def custom_exception_handler(request: Request, exc: Exception):
    logger.opt(exception=exc).error("Exception occurred")
    if isinstance(exc, HTTPException):
        status_code = exc.status_code
        message = exc.detail
    elif isinstance(exc, RequestValidationError):
        status_code = HTTP_422_UNPROCESSABLE_ENTITY
        message = "Validation error: " + str(exc.errors())
    else:
        status_code = 500
        message = "Internal server error: " + str(exc)
    return JSONResponse(
        status_code=status_code,
        content=TranscriptionResponse(code=status_code, info=message, data="").model_dump(),
    )


async def text_to_speech_stream_with_wav(text: str, websocket: WebSocket):
    # Create a temporary MP3 file for the edge-tts output.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as temp_mp3:
        temp_mp3_path = temp_mp3.name
    try:
        # Start edge-tts, writing the synthesized speech to the temporary MP3 file.
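        # edge-tts is invoked through its CLI here: --rate and --pitch take
        # signed offsets such as "+20%" and "+10Hz", and --write-media saves
        # the MP3 to the given path instead of streaming it to stdout.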
        edge_tts_process = await asyncio.create_subprocess_exec(
            "edge-tts",
            "--text", text,
            "--voice", config.voice,
            "--rate", config.tts_rate,
            "--pitch", config.tts_pitch,
            "--write-media", temp_mp3_path,
            stderr=asyncio.subprocess.PIPE,
        )
        await edge_tts_process.wait()
        if edge_tts_process.returncode != 0:
            error_msg = await edge_tts_process.stderr.read()
            raise RuntimeError(f"TTS generation failed: {error_msg.decode()}")

        # Decode the MP3 to raw PCM with ffmpeg.
        ffmpeg_process = await asyncio.create_subprocess_exec(
            "ffmpeg",
            "-i", temp_mp3_path,
            "-f", "s16le",      # raw 16-bit PCM output
            "-ar", "16000",     # 16 kHz sample rate
            "-ac", "1",         # mono
            "-loglevel", "warning",
            "pipe:1",
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE,
        )

        # Stream the PCM data to the front end.
        while True:
            chunk = await ffmpeg_process.stdout.read(4096)
            if not chunk:
                break
            await websocket.send_bytes(chunk)

        await ffmpeg_process.wait()
        if ffmpeg_process.returncode != 0:
            error_msg = await ffmpeg_process.stderr.read()
            raise RuntimeError(f"FFmpeg conversion failed: {error_msg.decode()}")

        # Send an EOF marker carrying the WAV header parameters.
        await websocket.send_json({
            "type": "EOF",
            "sampleRate": 16000,
            "numChannels": 1,
            "bitDepth": 16,
        })
    finally:
        # Clean up the temporary MP3 file.
        try:
            if os.path.exists(temp_mp3_path):
                os.unlink(temp_mp3_path)
        except Exception as e:
            logger.warning(f"Failed to remove temporary file: {e}")


async def text_to_speech_stream(text: str, websocket: WebSocket):
    process = await asyncio.create_subprocess_exec(
        "edge-tts",
        "--text", text,
        "--voice", config.voice,
        "--rate", config.tts_rate,
        "--pitch", config.tts_pitch,
        stdout=asyncio.subprocess.PIPE,  # read the audio directly from stdout
        stderr=asyncio.subprocess.PIPE,
    )
    # Read the audio stream straight from the process stdout and forward it.
    while chunk := await process.stdout.read(1024):  # 1 KB chunks
        await websocket.send_bytes(chunk)
    # Wait for the process to exit.
    await process.wait()


@app.websocket("/ws/transcribe")
async def websocket_endpoint(websocket: WebSocket):
    global messages
    try:
        query_params = parse_qs(websocket.scope['query_string'].decode())
        sv = query_params.get('sv', ['false'])[0].lower() in ['true', '1', 't', 'y', 'yes']
        lang = query_params.get('lang', ['auto'])[0].lower()

        await websocket.accept()
        chunk_size = int(config.chunk_size_ms * config.sample_rate / 1000)
        audio_buffer = np.array([], dtype=np.float32)
        audio_vad = np.array([], dtype=np.float32)

        cache = {}
        cache_asr = {}
        last_vad_beg = last_vad_end = -1
        offset = 0
        buffer = b""

        while True:
            data = await websocket.receive_bytes()
            # logger.info(f"received {len(data)} bytes")
            buffer += data
            if len(buffer) < 2:
                continue

            # Convert the complete int16 samples to float32 in [-1, 1]; keep any
            # trailing odd byte in the buffer for the next message.
            audio_buffer = np.append(
                audio_buffer,
                np.frombuffer(buffer[:len(buffer) - (len(buffer) % 2)], dtype=np.int16).astype(np.float32) / 32767.0,
            )
            buffer = buffer[len(buffer) - (len(buffer) % 2):]

            while len(audio_buffer) >= chunk_size:
                chunk = audio_buffer[:chunk_size]
                audio_buffer = audio_buffer[chunk_size:]
                audio_vad = np.append(audio_vad, chunk)

                res = model_vad.generate(input=chunk, cache=cache, is_final=False, chunk_size=config.chunk_size_ms)
                if len(res[0]["value"]):
                    vad_segments = res[0]["value"]
                    for segment in vad_segments:
                        if segment[0] > -1:  # speech begin
                            last_vad_beg = segment[0]
                        if segment[1] > -1:  # speech end
                            last_vad_end = segment[1]
                        if last_vad_beg > -1 and last_vad_end > -1:
                            last_vad_beg -= offset
                            last_vad_end -= offset
                            offset += last_vad_end
                            beg = int(last_vad_beg * config.sample_rate / 1000)
                            end = int(last_vad_end * config.sample_rate / 1000)
                            logger.info(f"[vad segment] audio_len: {end - beg}")
                            result = asr(audio_vad[beg:end], lang.strip(), cache_asr, True)
                            logger.info(f"asr response: {result}")
                            audio_vad = audio_vad[end:]
                            last_vad_beg = last_vad_end = -1

                            if result is not None:
                                # Strip SenseVoice special tags such as <|zh|> from the transcript.
                                say = re.sub(r"<\|.*?\|>", "", result[0]['text'])
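                                # Hand the cleaned transcript to the chat model: it is
                                # appended as a user turn and run through the Qwen3 chat
                                # template with thinking disabled; the history is reset
                                # to just the system prompt after each reply below.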
                                messages.append({"role": "user", "content": say})
                                text = tokenizer.apply_chat_template(
                                    messages,
                                    tokenize=False,
                                    add_generation_prompt=True,
                                    enable_thinking=False,
                                )
                                model_inputs = tokenizer([text], return_tensors="pt").to(model_qw.device)

                                # Conduct text completion.
                                generated_ids = model_qw.generate(
                                    **model_inputs,
                                    max_new_tokens=32768,
                                )
                                output_ids = generated_ids[0][len(model_inputs.input_ids[0]):].tolist()

                                # Parse out the thinking content: rindex of token 151668 (</think>).
                                try:
                                    index = len(output_ids) - output_ids[::-1].index(151668)
                                except ValueError:
                                    index = 0
                                content = tokenizer.decode(output_ids[index:], skip_special_tokens=True).strip("\n")

                                # Reset messages, keeping the system message and dropping
                                # user/assistant turns (assumes the first entry is the
                                # system message).
                                if len(messages) > 0 and messages[0]["role"] == "system":
                                    messages = [messages[0]]  # keep only the system message
                                else:
                                    messages = []  # no system message, clear everything

                                response = TranscriptionResponse(
                                    code=0,
                                    info=json.dumps(result[0], ensure_ascii=False),
                                    data=content,
                                )
                                # Send the chat reply (text).
                                await websocket.send_json(response.model_dump())
                                # Send the TTS audio stream.
                                await text_to_speech_stream_with_wav(content, websocket)

                # logger.debug(f'last_vad_beg: {last_vad_beg}; last_vad_end: {last_vad_end} len(audio_vad): {len(audio_vad)}')

    except WebSocketDisconnect:
        logger.info("WebSocket disconnected")
    except Exception as e:
        logger.error(f"Unexpected error: {e}\nCall stack:\n{traceback.format_exc()}")
        await websocket.close()
    finally:
        cache.clear()
        logger.info("Cleaned up resources after WebSocket disconnect")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run the FastAPI app with a specified port.")
    parser.add_argument('--port', type=int, default=27004, help='Port number to run the FastAPI app on.')
    # parser.add_argument('--certfile', type=str, default='path_to_your_SSL_certificate_file.crt', help='SSL certificate file')
    # parser.add_argument('--keyfile', type=str, default='path_to_your_SSL_certificate_file.key', help='SSL key file')
    args = parser.parse_args()

    # uvicorn.run(app, host="0.0.0.0", port=args.port, ssl_certfile=args.certfile, ssl_keyfile=args.keyfile)
    uvicorn.run(app, host="0.0.0.0", port=args.port)
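
# A minimal client sketch (an assumption, not part of the tested code above):
# it streams raw 16-bit / 16 kHz mono PCM to /ws/transcribe, prints each JSON
# reply, and receives the TTS PCM until the EOF marker arrives. It assumes the
# third-party `websockets` package; the input file name is hypothetical.
#
#   import asyncio, json, websockets
#
#   async def main():
#       uri = "ws://localhost:27004/ws/transcribe?lang=auto"
#       async with websockets.connect(uri) as ws:
#           with open("speech_16k_mono.pcm", "rb") as f:   # hypothetical file
#               while chunk := f.read(9600):               # 300 ms at 16 kHz, 16-bit mono
#                   await ws.send(chunk)
#                   await asyncio.sleep(0.3)               # pace like a live microphone
#           while True:
#               msg = await ws.recv()
#               if isinstance(msg, bytes):
#                   pass  # raw TTS PCM chunk; buffer it for playback
#               else:
#                   reply = json.loads(msg)
#                   print(reply)
#                   if reply.get("type") == "EOF":
#                       break
#
#   asyncio.run(main())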