695 lines
28 KiB
Python
695 lines
28 KiB
Python
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Liu Yue)
|
||
#
|
||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
# you may not use this file except in compliance with the License.
|
||
# You may obtain a copy of the License at
|
||
#
|
||
# http://www.apache.org/licenses/LICENSE-2.0
|
||
#
|
||
# Unless required by applicable law or agreed to in writing, software
|
||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
# See the License for the specific language governing permissions and
|
||
# limitations under the License.
|
||
import os
|
||
import sys
|
||
import argparse
|
||
import logging
|
||
import io
|
||
import torch
|
||
import torchaudio
|
||
import librosa
|
||
import numpy as np
|
||
import time
|
||
from urllib.parse import quote
|
||
from typing import Optional, List, Dict, Callable
|
||
from fastapi import FastAPI, UploadFile, File, Form, HTTPException
|
||
from fastapi.responses import StreamingResponse
|
||
from fastapi.middleware.cors import CORSMiddleware
|
||
import uvicorn
|
||
|
||
# Absolute path of the directory that contains this file.
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
# Put <ROOT_DIR>/third_party/Matcha-TTS on the import path so modules in that
# directory can be imported by the code loaded below.
sys.path.append(f'{ROOT_DIR}/third_party/Matcha-TTS')
|
||
|
||
# Fetch the wetext model ahead of time so that it does not have to be
# downloaded during initialization (supports offline environments).
def _preload_wetext_model():
    """Make sure the wetext model files are present in the local ModelScope cache.

    Intended to run before CosyVoice2 is imported. When the four expected FST
    files already exist under ``~/.cache/modelscope/hub/pengzhendong/wetext``
    nothing is downloaded, so the server can start without network access.
    Otherwise ``modelscope.snapshot_download`` is attempted; any failure is
    logged as a warning and swallowed so that startup continues.

    Returns:
        None. The function is best-effort and never raises.
    """
    try:
        # Imported only to find out whether wetext is installed.
        import wetext  # noqa: F401
    except ImportError:
        # wetext is not installed, so there is nothing to preload
        # (per the original note, ttsfrd is presumably used instead).
        return

    # NOTE(review): the logging calls below run before this file configures
    # logging explicitly; the stdlib installs a default root handler on first
    # use -- confirm that later logging configuration still takes effect.
    try:
        hub_root = os.path.expanduser('~/.cache/modelscope/hub')
        wetext_cache_dir = os.path.join(hub_root, 'pengzhendong', 'wetext')

        # Files whose presence is taken as proof that the model is cached.
        expected_files = (
            os.path.join(wetext_cache_dir, language, 'tn', fst_name)
            for language in ('zh', 'en')
            for fst_name in ('tagger.fst', 'verbalizer.fst')
        )
        if all(map(os.path.exists, expected_files)):
            logging.info(f'wetext 模型已存在于缓存: {wetext_cache_dir}')
            return

        # Not cached yet: try to download (only succeeds when online).
        logging.info('正在下载 wetext 模型(如果已存在则使用缓存)...')
        from modelscope import snapshot_download
        downloaded_dir = snapshot_download("pengzhendong/wetext")
        logging.info(f'wetext 模型已就绪: {downloaded_dir}')
    except Exception as e:
        # Deliberate best-effort: a failed download (e.g. offline) must not
        # abort startup; the problem is only reported here.
        logging.warning(f'无法预先下载 wetext 模型: {e},将在初始化时尝试下载')


# Preload the wetext model before CosyVoice2 is imported.
_preload_wetext_model()
|
||
|
||
from cosyvoice.cli.cosyvoice import CosyVoice2
|
||
from cosyvoice.utils.file_utils import load_wav
|
||
from cosyvoice.utils.common import set_all_random_seed
|
||
|
||
# Only let warnings and above through from the 'matplotlib' logger.
logging.getLogger('matplotlib').setLevel(logging.WARNING)

# The FastAPI application object that all route decorators below attach to.
app = FastAPI(title="CosyVoice API Server", version="1.0.0")

# Enable cross-origin requests.
# NOTE(review): wildcard origins together with allow_credentials=True is very
# permissive -- confirm this is intended for the deployment environment.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Module-level state shared by the request handlers.
# Loaded model instance; stays None until the __main__ block assigns it.
cosyvoice = None
# Sample rate passed to load_wav() when reading prompt audio in inference().
prompt_sr = 16000
# Peak-amplitude ceiling applied by postprocess().
max_val = 0.8
# In-memory speaker cache: speaker_id -> extracted features.
spk2info: Dict[str, Dict[str, torch.Tensor]] = {}
# On-disk location of the speaker cache; set in the __main__ block.
spk2info_path: Optional[str] = None
|
||
|
||
|
||
def postprocess(speech, top_db=60, hop_length=220, win_length=440):
    """Trim silence from a waveform, cap its peak, and append a silent tail.

    Args:
        speech: Waveform to clean up. It must support ``.abs().max()`` and be
            concatenable with a ``(1, N)`` torch tensor along ``dim=1``.
        top_db: Threshold forwarded to ``librosa.effects.trim``.
        hop_length: Hop length forwarded to ``librosa.effects.trim``.
        win_length: Forwarded to ``librosa.effects.trim`` as ``frame_length``.

    Returns:
        The trimmed waveform, rescaled so its peak does not exceed the
        module-level ``max_val``, followed by 0.2 s of zeros at
        ``cosyvoice.sample_rate``.
    """
    trimmed, _ = librosa.effects.trim(
        speech,
        top_db=top_db,
        frame_length=win_length,
        hop_length=hop_length,
    )
    peak = trimmed.abs().max()
    if peak > max_val:
        # Scale down so that the loudest sample sits exactly at max_val.
        trimmed = trimmed / peak * max_val
    silent_tail = torch.zeros(1, int(cosyvoice.sample_rate * 0.2))
    return torch.concat([trimmed, silent_tail], dim=1)
|
||
|
||
|
||
def generate_wav_stream(model_output, sample_rate, stream_mode=False):
    """Yield WAV-encoded bytes for the audio produced by the model.

    Args:
        model_output: Iterable of dicts whose ``'tts_speech'`` entry is a
            tensor of float samples (``.numpy()`` is called on it).
        sample_rate: Sample rate written into the WAV data.
        stream_mode: When True, every model chunk is encoded and yielded on
            its own, so each yielded item is a complete WAV container with its
            own header. When False, all chunks are concatenated and a single
            WAV payload is yielded (nothing is yielded if there were no
            chunks).

    Yields:
        ``bytes`` objects containing 16-bit PCM WAV data.
    """

    def _encode(samples):
        # Clip before casting: a sample at or above +1.0 scales to 32768,
        # which does not fit in int16 and would otherwise overflow into a
        # loud click instead of saturating.
        scaled = np.clip(samples * (2 ** 15), -(2 ** 15), 2 ** 15 - 1)
        pcm = scaled.astype(np.int16)
        buffer = io.BytesIO()
        torchaudio.save(buffer, torch.from_numpy(pcm).unsqueeze(0), sample_rate, format='wav')
        return buffer.getvalue()

    if stream_mode:
        # Streaming: hand every chunk back as soon as the model produces it.
        for item in model_output:
            yield _encode(item['tts_speech'].numpy().flatten())
        return

    # Non-streaming: gather everything first, then encode once.
    chunks = [item['tts_speech'].numpy().flatten() for item in model_output]
    if chunks:
        yield _encode(np.concatenate(chunks))
|
||
|
||
|
||
def build_content_disposition(filename: str) -> str:
    """Build a Content-Disposition header value that supports non-ASCII names.

    The value carries two forms of the name: a plain ASCII ``filename="..."``
    fallback and a percent-encoded ``filename*=UTF-8''...`` form.

    Args:
        filename: Desired download name; may contain any Unicode text and may
            originate from user input.

    Returns:
        The header value, starting with ``attachment;``.
    """
    ascii_only = filename.encode('ascii', errors='ignore').decode('ascii')
    # Drop characters that would terminate the quoted-string early or inject
    # extra header lines (double quote, backslash, control characters).
    safe_ascii = ''.join(
        ch for ch in ascii_only if ch.isprintable() and ch not in '"\\'
    ) or 'download.wav'
    # safe='' so that '/' is percent-encoded too; it is not a legal bare
    # character in the extended (filename*) value.
    quoted = quote(filename, safe='')
    return f'attachment; filename="{safe_ascii}"; filename*=UTF-8\'\'{quoted}'
|
||
|
||
|
||
def load_speaker_info():
    """Populate the module-level ``spk2info`` cache from disk.

    The cache is reset to an empty dict when the model is not loaded, when
    ``spk2info_path`` is unset, or when the cache file does not exist.
    Otherwise the file is read with ``torch.load`` and mapped onto
    ``cosyvoice.frontend.device``.

    NOTE(review): ``torch.load`` may unpickle arbitrary objects depending on
    the torch version and defaults -- only point ``spk2info_path`` at trusted
    files.
    """
    global spk2info

    if cosyvoice is None or spk2info_path is None:
        spk2info = {}
        return

    if not os.path.exists(spk2info_path):
        spk2info = {}
        logging.info("未找到音色缓存文件,初始化为空")
        return

    logging.info(f"加载音色缓存: {spk2info_path}")
    spk2info = torch.load(spk2info_path, map_location=cosyvoice.frontend.device)
    logging.info(f"✓ 已加载 {len(spk2info)} 个音色特征")
|
||
|
||
|
||
def save_speaker_info():
    """Write the in-memory speaker cache to ``spk2info_path`` with ``torch.save``.

    Raises:
        RuntimeError: If ``spk2info_path`` has not been set.
    """
    target_path = spk2info_path
    if target_path is None:
        raise RuntimeError("spk2info_path 未初始化,无法保存音色信息")
    torch.save(spk2info, target_path)
    logging.info(f"音色缓存已保存到 {target_path}")
|
||
|
||
|
||
def save_uploaded_file(upload_file: UploadFile) -> str:
    """Copy an uploaded file to a temporary file and return its path.

    The temporary file keeps the extension of the uploaded file name (if any)
    and is created with ``delete=False``, so the caller is responsible for
    removing it.

    Args:
        upload_file: The upload to persist; its ``file`` object is read from
            the current position.

    Returns:
        Path of the newly created temporary file.
    """
    import shutil
    import tempfile

    # The client may omit the file name, in which case it is None.
    suffix = os.path.splitext(upload_file.filename or "")[1]
    with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
        # Copy in chunks rather than loading the whole upload into memory.
        shutil.copyfileobj(upload_file.file, tmp)
    return tmp.name
|
||
|
||
|
||
def extract_speaker_features(speaker_id: str, prompt_wav_path: str, prompt_text: str, force_regenerate: bool = False) -> bool:
    """Extract a speaker's embedding / speech_feat / speech_token and cache them.

    The results are stored in the module-level ``spk2info`` dict under
    ``speaker_id`` together with ``prompt_text``. Nothing is written to disk.

    Args:
        speaker_id: Key under which the features are cached.
        prompt_wav_path: Path of the reference audio, loaded at 16 kHz.
        prompt_text: Transcript stored alongside the features.
        force_regenerate: Re-extract even if ``speaker_id`` is already cached.

    Returns:
        True if the speaker is cached afterwards (including the case where it
        already was and extraction was skipped); False if extraction raised.

    Raises:
        RuntimeError: If the model has not been loaded.
    """
    global spk2info

    if cosyvoice is None:
        raise RuntimeError("模型未加载,无法提取音色特征")

    already_cached = speaker_id in spk2info
    if already_cached and not force_regenerate:
        logging.info(f"音色 {speaker_id} 已存在,跳过提取")
        return True

    try:
        logging.info(f"开始提取音色 {speaker_id} 特征")
        wav_16k = load_wav(prompt_wav_path, 16000)
        spk_embedding = cosyvoice.frontend._extract_spk_embedding(wav_16k)
        # speech_feat is extracted from audio resampled to the model's rate,
        # while the embedding and speech tokens use the 16 kHz audio.
        to_model_rate = torchaudio.transforms.Resample(orig_freq=16000, new_freq=cosyvoice.sample_rate)
        wav_model_rate = to_model_rate(wav_16k)
        feat, _ = cosyvoice.frontend._extract_speech_feat(wav_model_rate)
        token, _ = cosyvoice.frontend._extract_speech_token(wav_16k)
        spk2info[speaker_id] = {
            "embedding": spk_embedding,
            "speech_feat": feat,
            "speech_token": token,
            "prompt_text": prompt_text,
        }
        logging.info(f"音色 {speaker_id} 特征提取完成")
        return True
    except Exception as e:
        # Failures are reported through the return value, not re-raised.
        logging.error(f"提取音色特征失败: {e}")
        return False
|
||
|
||
|
||
def tts_with_cached_features(
    tts_text: str,
    speaker_id: str,
    prompt_text: Optional[str] = "",
    stream: bool = False,
    speed: float = 1.0,
    text_frontend: bool = True,
    progress_callback: Optional[Callable[[int, int], None]] = None,
):
    """Synthesize speech using a speaker's cached features.

    This is a generator: nothing (including the checks below) runs until the
    first item is requested. The text is split into segments by
    ``cosyvoice.frontend.text_normalize`` and every segment is passed to
    ``cosyvoice.model.tts`` together with the cached speaker features.

    Args:
        tts_text: Text to synthesize.
        speaker_id: Key of the cached speaker in ``spk2info``.
        prompt_text: Prompt transcript; when empty/None the transcript cached
            with the speaker is used, falling back to an empty string.
        stream: Forwarded to ``cosyvoice.model.tts``.
        speed: Forwarded to ``cosyvoice.model.tts``.
        text_frontend: Forwarded to ``text_normalize``.
        progress_callback: Optional ``callback(current, total)`` invoked once
            per segment (1-based) right before that segment is synthesized.

    Yields:
        Whatever ``cosyvoice.model.tts`` yields for each segment, in order.

    Raises:
        RuntimeError: If the model has not been loaded.
        ValueError: If ``speaker_id`` is not cached.
    """
    if cosyvoice is None:
        raise RuntimeError("模型未加载")
    if speaker_id not in spk2info:
        raise ValueError(f"音色 {speaker_id} 不存在或未缓存")

    speaker_info = spk2info[speaker_id]
    if not prompt_text:
        # Fall back to the cached transcript; never hand None to the tokenizer.
        prompt_text = speaker_info.get("prompt_text") or ""

    frontend = cosyvoice.frontend
    segments = list(frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend))
    if not segments:
        segments = [tts_text]
    total_segments = len(segments)

    # Everything derived from the prompt/speaker is identical for every
    # segment, so it is computed once instead of once per loop iteration.
    prompt_text_token, prompt_text_token_len = frontend._extract_text_token(prompt_text)
    speech_token_len = torch.tensor(
        [speaker_info["speech_token"].shape[1]],
        dtype=torch.int32,
    ).to(frontend.device)
    speech_feat_len = torch.tensor(
        [speaker_info["speech_feat"].shape[1]],
        dtype=torch.int32,
    ).to(frontend.device)
    prompt_inputs = {
        "prompt_text": prompt_text_token,
        "prompt_text_len": prompt_text_token_len,
        "llm_prompt_speech_token": speaker_info["speech_token"],
        "llm_prompt_speech_token_len": speech_token_len,
        "flow_prompt_speech_token": speaker_info["speech_token"],
        "flow_prompt_speech_token_len": speech_token_len,
        "prompt_speech_feat": speaker_info["speech_feat"],
        "prompt_speech_feat_len": speech_feat_len,
        "llm_embedding": speaker_info["embedding"],
        "flow_embedding": speaker_info["embedding"],
    }

    for idx, chunk_text in enumerate(segments, start=1):
        text_token, text_token_len = frontend._extract_text_token(chunk_text)
        model_input = {"text": text_token, "text_len": text_token_len, **prompt_inputs}
        if progress_callback:
            progress_callback(idx, total_segments)
        yield from cosyvoice.model.tts(**model_input, stream=stream, speed=speed)
|
||
|
||
@app.get("/")
async def root():
    """Root path: return the server name, version and a short endpoint summary."""
    endpoint_summary = {
        "/inference": "统一推理接口,支持所有模式",
        "/list_speakers": "获取可用的预训练音色列表",
        "/health": "健康检查",
    }
    return {
        "name": "CosyVoice API Server",
        "version": "1.0.0",
        "endpoints": endpoint_summary,
    }
|
||
|
||
|
||
@app.get("/health")
async def health():
    """Health check: always reports "healthy" plus whether the model is loaded."""
    model_loaded = cosyvoice is not None
    return {"status": "healthy", "model_loaded": model_loaded}
|
||
|
||
|
||
@app.get("/list_speakers")
async def list_speakers():
    """Return the speakers reported by ``cosyvoice.list_available_spks()``.

    Responds with HTTP 503 when the model is not loaded. A falsy result from
    the model is returned as an empty list.
    """
    if cosyvoice is None:
        raise HTTPException(status_code=503, detail="模型未加载")
    available = cosyvoice.list_available_spks()
    return {"speakers": available or []}
|
||
|
||
|
||
@app.get("/api/speakers", response_model=List[str])
async def list_cached_speakers():
    """Return the IDs of all speakers whose features are cached in memory.

    Responds with HTTP 500 when the model is not loaded.
    """
    if cosyvoice is None:
        raise HTTPException(status_code=500, detail="模型未加载")
    cached_ids = list(spk2info)
    logging.info(f"当前缓存音色数量: {len(spk2info)}")
    return cached_ids
|
||
|
||
|
||
@app.get("/api/speakers/info")
async def get_cached_speakers_info():
    """Describe every cached speaker: which features exist and their shapes.

    Responds with HTTP 500 when the model is not loaded. Shapes are reported
    as strings; a missing feature is reported as ``None``.
    """
    if cosyvoice is None:
        raise HTTPException(status_code=500, detail="模型未加载")

    def shape_or_none(features, key):
        # str(tensor.shape) when the feature is cached, otherwise None.
        return str(features[key].shape) if key in features else None

    speakers_info = [
        {
            "speaker_id": speaker_id,
            "has_embedding": "embedding" in features,
            "has_speech_feat": "speech_feat" in features,
            "has_speech_token": "speech_token" in features,
            "embedding_shape": shape_or_none(features, "embedding"),
            "speech_feat_shape": shape_or_none(features, "speech_feat"),
            "speech_token_shape": shape_or_none(features, "speech_token"),
        }
        for speaker_id, features in spk2info.items()
    ]
    return {"total_count": len(speakers_info), "speakers": speakers_info}
|
||
|
||
|
||
@app.post("/api/speakers/add")
async def add_speaker(
    speaker_id: str = Form(..., description="音色ID"),
    prompt_text: str = Form(..., description="参考文本"),
    prompt_wav: UploadFile = File(..., description="参考音频文件"),
    force_regenerate: bool = Form(False, description="是否强制重新生成"),
):
    """Upload reference audio and cache the extracted speaker features.

    The features are kept in memory only; ``/api/speakers/save`` persists them.

    Responses:
        200 with a summary dict on success.
        400 if the speaker already exists and ``force_regenerate`` is false.
        500 if the model is not loaded or extraction fails.
    """
    if cosyvoice is None:
        raise HTTPException(status_code=500, detail="模型未加载")

    if not force_regenerate and speaker_id in spk2info:
        raise HTTPException(
            status_code=400,
            detail=f"音色 {speaker_id} 已存在,可设置 force_regenerate=true 重新生成",
        )

    uploaded_path = None
    try:
        uploaded_path = save_uploaded_file(prompt_wav)
        extracted = extract_speaker_features(
            speaker_id=speaker_id,
            prompt_wav_path=uploaded_path,
            prompt_text=prompt_text,
            force_regenerate=force_regenerate,
        )
        if extracted:
            return {
                "success": True,
                "speaker_id": speaker_id,
                "message": f"音色 {speaker_id} 特征已缓存,如需持久化请调用 /api/speakers/save",
                "cached_features": ["embedding", "speech_feat", "speech_token"],
            }
        raise HTTPException(status_code=500, detail="特征提取失败")
    except HTTPException:
        # Already a proper HTTP error: pass it through untouched.
        raise
    except Exception as e:
        logging.error(f"添加音色失败: {e}")
        raise HTTPException(status_code=500, detail=f"添加音色失败: {e}")
    finally:
        # The temporary copy of the upload is never kept.
        if uploaded_path and os.path.exists(uploaded_path):
            os.remove(uploaded_path)
|
||
|
||
|
||
@app.post("/api/speakers/save")
async def save_cached_speakers():
    """Persist the in-memory speaker cache to disk via ``save_speaker_info``.

    Responds with HTTP 500 when the model is not loaded or saving fails;
    otherwise returns the file path and the list of saved speaker IDs.
    """
    if cosyvoice is None:
        raise HTTPException(status_code=500, detail="模型未加载")

    try:
        save_speaker_info()
        saved_ids = list(spk2info.keys())
        return {
            "success": True,
            "file_path": spk2info_path,
            "total_speakers": len(spk2info),
            "speakers": saved_ids,
        }
    except Exception as e:
        logging.error(f"保存音色缓存失败: {e}")
        raise HTTPException(status_code=500, detail=f"保存失败: {e}")
|
||
|
||
|
||
@app.delete("/api/speakers/{speaker_id}")
async def delete_cached_speaker(speaker_id: str):
    """Remove one speaker from the in-memory cache.

    The on-disk cache is not touched; ``/api/speakers/save`` must be called to
    sync it. Responds with HTTP 500 when the model is not loaded and 404 when
    the speaker is unknown.
    """
    if cosyvoice is None:
        raise HTTPException(status_code=500, detail="模型未加载")
    if speaker_id not in spk2info:
        raise HTTPException(status_code=404, detail=f"音色 {speaker_id} 不存在")

    spk2info.pop(speaker_id)
    logging.info(f"音色 {speaker_id} 已从缓存移除")
    return {
        "success": True,
        "message": f"音色 {speaker_id} 已删除,如需同步磁盘请调用 /api/speakers/save",
    }
|
||
|
||
|
||
@app.post("/api/speakers/regenerate/{speaker_id}")
async def regenerate_cached_speaker(
    speaker_id: str,
    prompt_text: str = Form(...),
    prompt_wav: UploadFile = File(...),
):
    """Re-extract the features of an already cached speaker from new audio.

    Responds with HTTP 500 when the model is not loaded or extraction fails,
    and 404 when the speaker is not in the cache.
    """
    if cosyvoice is None:
        raise HTTPException(status_code=500, detail="模型未加载")
    if speaker_id not in spk2info:
        raise HTTPException(status_code=404, detail=f"音色 {speaker_id} 不存在")

    uploaded_path = None
    try:
        uploaded_path = save_uploaded_file(prompt_wav)
        regenerated = extract_speaker_features(
            speaker_id=speaker_id,
            prompt_wav_path=uploaded_path,
            prompt_text=prompt_text,
            force_regenerate=True,
        )
        if regenerated:
            return {"success": True, "message": f"音色 {speaker_id} 特征已更新"}
        raise HTTPException(status_code=500, detail="特征重新生成失败")
    except HTTPException:
        # Already a proper HTTP error: pass it through untouched.
        raise
    except Exception as e:
        logging.error(f"重新生成音色失败: {e}")
        raise HTTPException(status_code=500, detail=f"重新生成失败: {e}")
    finally:
        # The temporary copy of the upload is never kept.
        if uploaded_path and os.path.exists(uploaded_path):
            os.remove(uploaded_path)
|
||
|
||
|
||
@app.post("/api/synthesis/sft")
async def synthesis_sft_cached(
    text: str = Form(..., description="要合成的文本"),
    speaker_id: str = Form(..., description="缓存音色ID"),
    prompt_text: str = Form("", description="可选提示文本,默认使用缓存中保存的文本"),
    stream: bool = Form(False, description="是否流式返回"),
    speed: float = Form(1.0, description="语速调节"),
):
    """Synthesize speech with a cached speaker's features.

    Args:
        text: Text to synthesize; must not be blank.
        speaker_id: ID of a speaker present in the in-memory cache.
        prompt_text: Optional prompt transcript; the cached one is used when
            empty.
        stream: When true the audio is returned chunk by chunk and the
            ``X-Progress-Total`` header carries the number of text segments.
        speed: Speed factor forwarded to the model.

    Responses:
        200 ``audio/wav`` with ``X-Sample-Rate`` (and ``X-Processing-Time``
        for non-streaming requests).
        400 for blank text, 404 for an unknown speaker or a ``ValueError``
        during synthesis, 500 when the model is not loaded or synthesis fails.
    """
    if cosyvoice is None:
        raise HTTPException(status_code=500, detail="模型未加载")
    if not text.strip():
        raise HTTPException(status_code=400, detail="文本不能为空")
    if speaker_id not in spk2info:
        raise HTTPException(status_code=404, detail=f"音色 {speaker_id} 不存在")

    start_time = time.perf_counter()
    # Mutable holder so the nested callback can publish the segment count.
    progress_total = {"total": 0}

    def log_progress(current: int, total: int):
        progress_total["total"] = total
        logging.info(f"SFT缓存合成进度 - 音色:{speaker_id} {current}/{total}")

    try:
        logging.info(f"SFT缓存合成开始 - 音色:{speaker_id}, 文本长度:{len(text)}")
        model_output = tts_with_cached_features(
            tts_text=text,
            speaker_id=speaker_id,
            prompt_text=prompt_text,
            stream=stream,
            speed=speed,
            progress_callback=log_progress,
        )

        headers = {
            "Content-Disposition": build_content_disposition(f"synthesis_{speaker_id}.wav"),
            "X-Sample-Rate": str(cosyvoice.sample_rate),
        }

        if stream:
            from fastapi.concurrency import run_in_threadpool

            wav_chunks = generate_wav_stream(model_output, cosyvoice.sample_rate, stream_mode=True)
            # The synthesis generator is lazy: the segment total is only known
            # once it has started running. Produce the first chunk (off the
            # event loop) before building the headers, otherwise
            # X-Progress-Total would always be "0". This also turns early
            # synthesis errors into a proper HTTP error response.
            first_chunk = await run_in_threadpool(next, wav_chunks, None)
            headers["X-Progress-Total"] = str(progress_total["total"])

            def timed_stream():
                try:
                    if first_chunk is not None:
                        yield first_chunk
                    yield from wav_chunks
                finally:
                    elapsed = time.perf_counter() - start_time
                    logging.info(f"SFT缓存合成完成(流式) - 音色:{speaker_id}, 耗时:{elapsed:.3f}s")

            return StreamingResponse(timed_stream(), media_type="audio/wav", headers=headers)

        audio_bytes = b"".join(
            generate_wav_stream(model_output, cosyvoice.sample_rate, stream_mode=False)
        )
        elapsed = time.perf_counter() - start_time
        headers["X-Processing-Time"] = f"{elapsed:.3f}"
        logging.info(f"SFT缓存合成完成 - 音色:{speaker_id}, 耗时:{elapsed:.3f}s")
        return StreamingResponse(
            iter([audio_bytes]),
            media_type="audio/wav",
            headers=headers,
        )
    except ValueError as e:
        raise HTTPException(status_code=404, detail=str(e))
    except Exception as e:
        logging.error(f"SFT 缓存合成失败: {e}")
        raise HTTPException(status_code=500, detail=f"合成失败: {e}")
|
||
|
||
|
||
@app.post("/inference")
async def inference(
    tts_text: str = Form(..., description="需要合成的文本"),
    mode: str = Form(..., description="推理模式: 'sft'(预训练音色), 'zero_shot'(3s极速复刻), 'cross_lingual'(跨语种复刻), 'instruct'(自然语言控制)"),
    spk_id: Optional[str] = Form(None, description="预训练音色ID(sft和instruct模式需要)"),
    prompt_text: Optional[str] = Form(None, description="prompt文本(zero_shot模式需要)"),
    prompt_wav: Optional[UploadFile] = File(None, description="prompt音频文件(zero_shot和cross_lingual模式需要)"),
    instruct_text: Optional[str] = Form(None, description="instruct文本(instruct模式需要)"),
    seed: Optional[int] = Form(0, description="随机种子,0表示随机"),
    stream: bool = Form(False, description="是否流式推理"),
    speed: float = Form(1.0, description="速度调节(0.5-2.0,仅支持非流式推理)", ge=0.5, le=2.0)
):
    """
    Unified inference endpoint covering every synthesis mode.

    Modes:
    - sft: pretrained speaker; requires spk_id
    - zero_shot: voice cloning from a short prompt; requires prompt_text and prompt_wav
    - cross_lingual: cross-lingual cloning; requires prompt_wav
    - instruct: natural-language control; requires spk_id and instruct_text

    Responses:
        200 ``audio/wav`` stream with an ``X-Sample-Rate`` header.
        400 for invalid/missing parameters or prompt audio rejected by an
        assertion while loading, 503 when the model is not loaded, 500 for
        any other failure while setting up inference.
    """
    if cosyvoice is None:
        raise HTTPException(status_code=503, detail="模型未加载")

    # Parameter validation
    if mode == 'sft':
        if not spk_id:
            raise HTTPException(status_code=400, detail="sft模式需要提供spk_id")
        available_spks = cosyvoice.list_available_spks()
        if not available_spks or spk_id not in available_spks:
            raise HTTPException(status_code=400, detail=f"无效的spk_id: {spk_id},可用音色: {available_spks}")

    elif mode == 'zero_shot':
        if not prompt_text:
            raise HTTPException(status_code=400, detail="zero_shot模式需要提供prompt_text")
        if not prompt_wav:
            raise HTTPException(status_code=400, detail="zero_shot模式需要提供prompt_wav")

    elif mode == 'cross_lingual':
        if not prompt_wav:
            raise HTTPException(status_code=400, detail="cross_lingual模式需要提供prompt_wav")
        if cosyvoice.instruct is True:
            raise HTTPException(status_code=400, detail="当前模型不支持cross_lingual模式,请使用非Instruct模型")

    elif mode == 'instruct':
        if not spk_id:
            raise HTTPException(status_code=400, detail="instruct模式需要提供spk_id")
        if not instruct_text:
            raise HTTPException(status_code=400, detail="instruct模式需要提供instruct_text")
        if cosyvoice.instruct is False:
            raise HTTPException(status_code=400, detail="当前模型不支持instruct模式,请使用CosyVoice-300M-Instruct模型")
        available_spks = cosyvoice.list_available_spks()
        if not available_spks or spk_id not in available_spks:
            raise HTTPException(status_code=400, detail=f"无效的spk_id: {spk_id},可用音色: {available_spks}")

    else:
        raise HTTPException(status_code=400, detail=f"无效的模式: {mode},支持的模式: sft, zero_shot, cross_lingual, instruct")

    # Seed the RNGs only for a positive seed; 0 (or a missing value) means
    # "leave it random". seed is Optional, so guard against None.
    if seed is not None and seed > 0:
        set_all_random_seed(seed)

    # Speed adjustment is not available in streaming mode.
    if stream and speed != 1.0:
        raise HTTPException(status_code=400, detail="流式推理模式下速度必须为1.0")

    try:
        logging.info(f'get {mode} inference request')

        if mode in ('zero_shot', 'cross_lingual'):
            # Read the uploaded prompt into memory and load it from there.
            prompt_buffer = io.BytesIO(await prompt_wav.read())
            try:
                prompt_speech_16k = postprocess(load_wav(prompt_buffer, prompt_sr))
            except AssertionError as e:
                raise HTTPException(status_code=400, detail=str(e)) from e

        if mode == 'sft':
            model_output = cosyvoice.inference_sft(tts_text, spk_id, stream=stream, speed=speed)
        elif mode == 'zero_shot':
            model_output = cosyvoice.inference_zero_shot(tts_text, prompt_text, prompt_speech_16k, stream=stream, speed=speed)
        elif mode == 'cross_lingual':
            model_output = cosyvoice.inference_cross_lingual(tts_text, prompt_speech_16k, stream=stream, speed=speed)
        else:  # instruct
            model_output = cosyvoice.inference_instruct(tts_text, spk_id, instruct_text, stream=stream, speed=speed)

        # Return the audio as a stream.
        return StreamingResponse(
            generate_wav_stream(model_output, cosyvoice.sample_rate, stream_mode=stream),
            media_type="audio/wav",
            headers={
                "Content-Disposition": "attachment; filename=output.wav",
                "X-Sample-Rate": str(cosyvoice.sample_rate)
            }
        )

    except HTTPException:
        # Keep deliberate HTTP errors (e.g. the 400 above) instead of letting
        # the generic handler below rewrite them as 500.
        raise
    except Exception as e:
        logging.error(f"推理错误: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"推理失败: {str(e)}") from e
|
||
|
||
|
||
@app.post("/inference_sft")
async def inference_sft(
    tts_text: str = Form(...),
    spk_id: str = Form(...),
    seed: Optional[int] = Form(0),
    stream: bool = Form(False),
    speed: float = Form(1.0, ge=0.5, le=2.0)
):
    """Pretrained-speaker endpoint; delegates to ``inference`` with mode 'sft'."""
    # inference() is called as a plain coroutine here, so its Form(...)
    # defaults are not resolved by FastAPI; every value it reads in this mode
    # is passed explicitly.
    response = await inference(
        tts_text=tts_text,
        mode='sft',
        spk_id=spk_id,
        seed=seed,
        stream=stream,
        speed=speed,
    )
    return response
|
||
|
||
|
||
@app.post("/inference_zero_shot")
async def inference_zero_shot(
    tts_text: str = Form(...),
    prompt_text: str = Form(...),
    prompt_wav: UploadFile = File(...),
    seed: Optional[int] = Form(0),
    stream: bool = Form(False),
    speed: float = Form(1.0, ge=0.5, le=2.0)
):
    """Zero-shot cloning endpoint; delegates to ``inference`` with mode 'zero_shot'."""
    # inference() is called as a plain coroutine here, so its Form(...)
    # defaults are not resolved by FastAPI; every value it reads in this mode
    # is passed explicitly.
    response = await inference(
        tts_text=tts_text,
        mode='zero_shot',
        prompt_text=prompt_text,
        prompt_wav=prompt_wav,
        seed=seed,
        stream=stream,
        speed=speed,
    )
    return response
|
||
|
||
|
||
@app.post("/inference_cross_lingual")
async def inference_cross_lingual(
    tts_text: str = Form(...),
    prompt_wav: UploadFile = File(...),
    seed: Optional[int] = Form(0),
    stream: bool = Form(False),
    speed: float = Form(1.0, ge=0.5, le=2.0)
):
    """Cross-lingual cloning endpoint; delegates to ``inference`` with mode 'cross_lingual'."""
    # inference() is called as a plain coroutine here, so its Form(...)
    # defaults are not resolved by FastAPI; every value it reads in this mode
    # is passed explicitly.
    response = await inference(
        tts_text=tts_text,
        mode='cross_lingual',
        prompt_wav=prompt_wav,
        seed=seed,
        stream=stream,
        speed=speed,
    )
    return response
|
||
|
||
|
||
@app.post("/inference_instruct")
async def inference_instruct(
    tts_text: str = Form(...),
    spk_id: str = Form(...),
    instruct_text: str = Form(...),
    seed: Optional[int] = Form(0),
    stream: bool = Form(False),
    speed: float = Form(1.0, ge=0.5, le=2.0)
):
    """Natural-language-control endpoint; delegates to ``inference`` with mode 'instruct'."""
    # inference() is called as a plain coroutine here, so its Form(...)
    # defaults are not resolved by FastAPI; every value it reads in this mode
    # is passed explicitly.
    response = await inference(
        tts_text=tts_text,
        mode='instruct',
        spk_id=spk_id,
        instruct_text=instruct_text,
        seed=seed,
        stream=stream,
        speed=speed,
    )
    return response
|
||
|
||
|
||
if __name__ == '__main__':
    # Command-line options: bind address, port and model location.
    arg_parser = argparse.ArgumentParser(description='CosyVoice API Server')
    arg_parser.add_argument('--port', type=int, default=8000, help='服务器端口')
    arg_parser.add_argument('--host', type=str, default='0.0.0.0', help='服务器地址')
    arg_parser.add_argument(
        '--model_dir',
        type=str,
        default='pretrain/CosyVoice2-0.5B',
        help='模型路径(本地路径或modelscope repo id)',
    )
    args = arg_parser.parse_args()

    # Load the model (only CosyVoice2 is supported). A failure here aborts
    # startup with a RuntimeError.
    logging.info(f"正在加载 CosyVoice2 模型: {args.model_dir}")
    try:
        cosyvoice = CosyVoice2(args.model_dir, load_jit=False, load_trt=False, fp16=False)
        logging.info("成功加载 CosyVoice2 模型")
    except Exception as e:
        logging.error(f"模型加载失败: {e}")
        raise RuntimeError(f'无法加载 CosyVoice2 模型: {e}')

    # The speaker cache lives next to the model files.
    spk2info_path = os.path.join(args.model_dir, 'spk2info.pt')
    load_speaker_info()
    if not spk2info:
        logging.info("当前无缓存音色,可通过 /api/speakers/add 添加")
    else:
        # Log at most the first five cached speaker IDs.
        preview_ids = list(spk2info.keys())[:5]
        more_marker = '...' if len(spk2info) > 5 else ''
        logging.info(f"已加载 {len(spk2info)} 个缓存音色: {preview_ids}{more_marker}")

    prompt_sr = 16000
    logging.info(f"API服务器启动在 http://{args.host}:{args.port}")
    logging.info(f"可用音色: {cosyvoice.list_available_spks()}")

    uvicorn.run(app, host=args.host, port=args.port)
|
||
|