# ai_test/app.py

from fastapi import FastAPI, WebSocket, WebSocketDisconnect, Request, HTTPException
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from starlette.status import HTTP_422_UNPROCESSABLE_ENTITY
from pydantic_settings import BaseSettings
from pydantic import BaseModel, Field
from funasr import AutoModel
import numpy as np
import argparse
import uvicorn
from urllib.parse import parse_qs
from loguru import logger
import sys
import re
import json
import traceback
import time
from modelscope import AutoModelForCausalLM, AutoTokenizer
import asyncio
import tempfile
import os

# Route records below ERROR (level no < 40) to stdout, ERROR and above to stderr.
logger.remove()
log_format = "{time:YYYY-MM-DD HH:mm:ss} [{level}] {file}:{line} - {message}"
logger.add(sys.stdout, format=log_format, level="DEBUG", filter=lambda record: record["level"].no < 40)
logger.add(sys.stderr, format=log_format, level="ERROR", filter=lambda record: record["level"].no >= 40)

class Config(BaseSettings):
    chunk_size_ms: int = Field(300, description="Chunk size in milliseconds")
    sample_rate: int = Field(16000, description="Sample rate in Hz")
    model_chat: str = Field("Qwen/Qwen3-0.6B", description="Chat model ID")
    # Chinese system prompt: "You are a precise article-summarization expert" with
    # length (50-100 characters), grounding, style, and no-refusal requirements.
    # Kept in Chinese because the model is instructed to answer in Chinese.
    sys_set: str = Field(
        default=(
            "你是一个精准的文章总结专家。你的任务是提取并总结用户提供的文章或片段的核心内容。"
            "## 核心要求"
            "- 总结结果长度为50-100个字,根据内容复杂度灵活调整"
            "- 完全基于提供的文章内容生成总结,不添加任何未在文章中出现的信息"
            "- 确保总结包含文章的关键信息点和主要结论"
            "- 即使文章内容较复杂或专业,也必须尝试提取核心要点进行总结"
            "- 直接输出总结结果,不包含任何引言、前缀或解释"
            "## 格式与风格"
            "- 使用客观、中立的第三人称陈述语气"
            "- 使用清晰简洁的中文表达"
            "- 保持逻辑连贯性,确保句与句之间有合理过渡"
            "- 避免重复使用相同的表达方式或句式结构"
            "## 注意事项"
            "- 绝对不输出'无法生成'、'无法总结'、'内容不足'等拒绝回应的词语"
            "- 不要照抄或参考示例中的任何内容,确保总结完全基于用户提供的新文章"
            "- 对于任何文本都尽最大努力提取重点并总结,无论长度或复杂度"
            "## 以下是用户给出的文章相关信息:"
        ),
        description="System prompt for the summarization chat model",
    )
    tts_rate: str = Field("+20%", description="TTS rate")
    tts_pitch: str = Field("+10Hz", description="TTS pitch")
    voice: str = Field("zh-CN-XiaoxiaoNeural", description="TTS voice")
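
# pydantic-settings also reads these fields from environment variables (field
# names matched case-insensitively), so a hypothetical launch-time override
# looks like:
#   CHUNK_SIZE_MS=200 VOICE=zh-CN-YunxiNeural python app.py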
config = Config()

# Conversation history seeded with the system prompt; reset after each reply.
messages = [
    {"role": "system", "content": config.sys_set}
]

# Streaming ASR model (SenseVoice).
model_asr = AutoModel(
    model="iic/SenseVoiceSmall",
    trust_remote_code=True,
    remote_code="./model.py",
    device="cuda:0",
    disable_update=True,
)

# Streaming VAD model used to segment the incoming audio.
model_vad = AutoModel(
    model="fsmn-vad",
    model_revision="v2.0.4",
    disable_pbar=True,
    max_end_silence_time=800,  # was 500
    # speech_noise_thres=0.6,
    disable_update=True,
)

# Load the tokenizer and the chat model.
tokenizer = AutoTokenizer.from_pretrained(config.model_chat)
model_qw = AutoModelForCausalLM.from_pretrained(
    config.model_chat,
    torch_dtype="auto",
    device_map="auto",
)

def asr(audio, lang, cache, use_itn=False):
    """Run SenseVoice on a float32 audio segment and log the latency."""
    start_time = time.time()
    result = model_asr.generate(
        input=audio,
        cache=cache,
        language=lang.strip(),
        use_itn=use_itn,
        batch_size_s=60,
    )
    elapsed_time = time.time() - start_time
    logger.debug(f"asr elapsed: {elapsed_time * 1000:.2f} milliseconds")
    return result

app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

@app.exception_handler(Exception)
async def custom_exception_handler(request: Request, exc: Exception):
    logger.opt(exception=exc).error("Exception occurred")
    if isinstance(exc, HTTPException):
        status_code = exc.status_code
        message = exc.detail
    elif isinstance(exc, RequestValidationError):
        status_code = HTTP_422_UNPROCESSABLE_ENTITY
        message = "Validation error: " + str(exc.errors())
    else:
        status_code = 500
        message = "Internal server error: " + str(exc)
    return JSONResponse(
        status_code=status_code,
        content=TranscriptionResponse(
            code=status_code,
            info=message,  # the model's field is `info`, not `msg`
            data="",
        ).model_dump(),
    )

# Define the response model, used for both HTTP errors and WebSocket text frames.
class TranscriptionResponse(BaseModel):
    code: int
    info: str
    data: str
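    # Example of a successful frame as serialized to the socket
    # (values are illustrative):
    #   {"code": 0, "info": "<raw ASR result JSON>", "data": "<summary text>"}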

async def text_to_speech_stream_with_wav(text: str, websocket: WebSocket):
    # Create a temporary MP3 file (edge-tts output).
    with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as temp_mp3:
        temp_mp3_path = temp_mp3.name
    try:
        # Run edge-tts, writing to the temporary MP3 file.
        edge_tts_process = await asyncio.create_subprocess_exec(
            "edge-tts",
            "--text", text,
            "--voice", config.voice,
            "--rate", config.tts_rate,
            "--pitch", config.tts_pitch,
            "--write-media", temp_mp3_path,
            stderr=asyncio.subprocess.PIPE,
        )
        await edge_tts_process.wait()
        if edge_tts_process.returncode != 0:
            error_msg = await edge_tts_process.stderr.read()
            raise RuntimeError(f"TTS generation failed: {error_msg.decode()}")
        # Decode the MP3 to raw PCM with ffmpeg.
        ffmpeg_process = await asyncio.create_subprocess_exec(
            "ffmpeg",
            "-i", temp_mp3_path,
            "-f", "s16le",   # raw 16-bit PCM output
            "-ar", "16000",  # 16 kHz sample rate
            "-ac", "1",      # mono
            "-loglevel", "warning",
            "pipe:1",
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE,
        )
        # Stream the PCM data to the client.
        while True:
            chunk = await ffmpeg_process.stdout.read(4096)
            if not chunk:
                break
            await websocket.send_bytes(chunk)
        await ffmpeg_process.wait()
        if ffmpeg_process.returncode != 0:
            error_msg = await ffmpeg_process.stderr.read()
            raise RuntimeError(f"FFmpeg conversion failed: {error_msg.decode()}")
        # Send an EOF marker carrying the WAV header parameters.
        await websocket.send_json({
            "type": "EOF",
            "sampleRate": 16000,
            "numChannels": 1,
            "bitDepth": 16,
        })
    finally:
        # Clean up the temporary MP3 file.
        try:
            if os.path.exists(temp_mp3_path):
                os.unlink(temp_mp3_path)
        except Exception as e:
            logger.warning(f"Failed to remove temporary file: {e}")

# Alternative TTS path: forwards the edge-tts output (MP3) as-is, without the
# ffmpeg PCM decode; not currently called by the endpoint below.
async def text_to_speech_stream(text: str, websocket: WebSocket):
    process = await asyncio.create_subprocess_exec(
        "edge-tts",
        "--text", text,
        "--voice", config.voice,
        "--rate", config.tts_rate,
        "--pitch", config.tts_pitch,
        stdout=asyncio.subprocess.PIPE,  # read the audio stream directly from stdout
        stderr=asyncio.subprocess.PIPE,
    )
    # Forward the audio from the process stdout in 1 KiB chunks.
    while chunk := await process.stdout.read(1024):
        await websocket.send_bytes(chunk)
    # Wait for the process to exit.
    await process.wait()
@app.websocket("/ws/transcribe")
async def websocket_endpoint(websocket: WebSocket):
try:
query_params = parse_qs(websocket.scope['query_string'].decode())
sv = query_params.get('sv', ['false'])[0].lower() in ['true', '1', 't', 'y', 'yes']
lang = query_params.get('lang', ['auto'])[0].lower()
await websocket.accept()
chunk_size = int(config.chunk_size_ms * config.sample_rate / 1000)
audio_buffer = np.array([], dtype=np.float32)
audio_vad = np.array([], dtype=np.float32)
cache = {}
cache_asr = {}
last_vad_beg = last_vad_end = -1
offset = 0
buffer = b""
while True:
data = await websocket.receive_bytes()
# logger.info(f"received {len(data)} bytes")
buffer += data
if len(buffer) < 2:
continue
audio_buffer = np.append(
audio_buffer,
np.frombuffer(buffer[:len(buffer) - (len(buffer) % 2)], dtype=np.int16).astype(np.float32) / 32767.0
)
buffer = buffer[len(buffer) - (len(buffer) % 2):]
while len(audio_buffer) >= chunk_size:
chunk = audio_buffer[:chunk_size]
audio_buffer = audio_buffer[chunk_size:]
audio_vad = np.append(audio_vad, chunk)
res = model_vad.generate(input=chunk, cache=cache, is_final=False, chunk_size=config.chunk_size_ms)
if len(res[0]["value"]):
vad_segments = res[0]["value"]
for segment in vad_segments:
if segment[0] > -1: # speech begin
last_vad_beg = segment[0]
if segment[1] > -1: # speech end
last_vad_end = segment[1]
if last_vad_beg > -1 and last_vad_end > -1:
last_vad_beg -= offset
last_vad_end -= offset
offset += last_vad_end
beg = int(last_vad_beg * config.sample_rate / 1000)
end = int(last_vad_end * config.sample_rate / 1000)
logger.info(f"[vad segment] audio_len: {end - beg}")
result = asr(audio_vad[beg:end], lang.strip(), cache_asr, True)
logger.info(f"asr response: {result}")
audio_vad = audio_vad[end:]
last_vad_beg = last_vad_end = -1
if result is not None:
global messages
say=re.sub(r"<\|.*?\|>", "", result[0]['text'])
messages.append({"role": "user", "content": say})
text = tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True,
enable_thinking=False
)
model_inputs = tokenizer([text], return_tensors="pt").to(model_qw.device)
# conduct text completion
generated_ids = model_qw.generate(
**model_inputs,
max_new_tokens=32768
)
output_ids = generated_ids[0][len(model_inputs.input_ids[0]):].tolist()
# parsing thinking content
try:
# rindex finding 151668 (</think>)
index = len(output_ids) - output_ids[::-1].index(151668)
except ValueError:
index = 0
content = tokenizer.decode(output_ids[index:], skip_special_tokens=True).strip("\n")
# 6. 重置 messages保留系统消息清空用户和助手消息
# 假设第一条是系统消息,需要保留
if len(messages) > 0 and messages[0]["role"] == "system":
messages = [messages[0]] # 只保留系统消息
else:
messages = [] # 如果没有系统消息,直接清空
response = TranscriptionResponse(
code=0,
info=json.dumps(result[0], ensure_ascii=False),
data=content
)
# 发送聊天回复(文字)
await websocket.send_json(response.model_dump())
# 发送 TTS 音频流
await text_to_speech_stream_with_wav(content, websocket)
# logger.debug(f'last_vad_beg: {last_vad_beg}; last_vad_end: {last_vad_end} len(audio_vad): {len(audio_vad)}')
except WebSocketDisconnect:
logger.info("WebSocket disconnected")
except Exception as e:
logger.error(f"Unexpected error: {e}\nCall stack:\n{traceback.format_exc()}")
await websocket.close()
finally:
cache.clear()
logger.info("Cleaned up resources after WebSocket disconnect")
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Run the FastAPI app with a specified port.")
parser.add_argument('--port', type=int, default=27004, help='Port number to run the FastAPI app on.')
# parser.add_argument('--certfile', type=str, default='path_to_your_SSL_certificate_file.crt', help='SSL certificate file')
# parser.add_argument('--keyfile', type=str, default='path_to_your_SSL_certificate_file.key', help='SSL key file')
args = parser.parse_args()
# uvicorn.run(app, host="0.0.0.0", port=args.port, ssl_certfile=args.certfile, ssl_keyfile=args.keyfile)
uvicorn.run(app, host="0.0.0.0", port=args.port)