From 381ea0c82c1877a11facd5d9149b3de29cb98c7f Mon Sep 17 00:00:00 2001
From: SWivid
Date: Wed, 30 Oct 2024 03:16:09 +0800
Subject: [PATCH] fix vocoder loading

---
 src/f5_tts/api.py                | 3 ++-
 src/f5_tts/infer/infer_cli.py    | 6 ++++--
 src/f5_tts/infer/infer_gradio.py | 3 ++-
 src/f5_tts/infer/utils_infer.py  | 8 ++++----
 4 files changed, 12 insertions(+), 8 deletions(-)

diff --git a/src/f5_tts/api.py b/src/f5_tts/api.py
index 3eccc0d..41fc667 100644
--- a/src/f5_tts/api.py
+++ b/src/f5_tts/api.py
@@ -47,7 +47,7 @@ class F5TTS:
         self.load_ema_model(model_type, ckpt_file, vocab_file, ode_method, use_ema)

     def load_vocoder_model(self, local_path):
-        self.vocos = load_vocoder(local_path is not None, local_path, self.device)
+        self.vocoder = load_vocoder(local_path is not None, local_path, self.device)

     def load_ema_model(self, model_type, ckpt_file, vocab_file, ode_method, use_ema):
         if model_type == "F5-TTS":
@@ -102,6 +102,7 @@ class F5TTS:
             ref_text,
             gen_text,
             self.ema_model,
+            self.vocoder,
             show_info=show_info,
             progress=progress,
             target_rms=target_rms,
diff --git a/src/f5_tts/infer/infer_cli.py b/src/f5_tts/infer/infer_cli.py
index f33cedf..1d9b319 100644
--- a/src/f5_tts/infer/infer_cli.py
+++ b/src/f5_tts/infer/infer_cli.py
@@ -113,7 +113,7 @@ wave_path = Path(output_dir) / "infer_cli_out.wav"
 # spectrogram_path = Path(output_dir) / "infer_cli_out.png"
 vocos_local_path = "../checkpoints/charactr/vocos-mel-24khz"

-vocos = load_vocoder(is_local=args.load_vocoder_from_local, local_path=vocos_local_path)
+vocoder = load_vocoder(is_local=args.load_vocoder_from_local, local_path=vocos_local_path)


 # load models
@@ -175,7 +175,9 @@ def main_process(ref_audio, ref_text, text_gen, model_obj, remove_silence, speed
         ref_audio = voices[voice]["ref_audio"]
         ref_text = voices[voice]["ref_text"]
         print(f"Voice: {voice}")
-        audio, final_sample_rate, spectragram = infer_process(ref_audio, ref_text, gen_text, model_obj, speed=speed)
+        audio, final_sample_rate, spectragram = infer_process(
+            ref_audio, ref_text, gen_text, model_obj, vocoder, speed=speed
+        )
         generated_audio_segments.append(audio)

     if generated_audio_segments:
diff --git a/src/f5_tts/infer/infer_gradio.py b/src/f5_tts/infer/infer_gradio.py
index 85d2550..4c37989 100644
--- a/src/f5_tts/infer/infer_gradio.py
+++ b/src/f5_tts/infer/infer_gradio.py
@@ -37,7 +37,7 @@ from f5_tts.infer.utils_infer import (
     save_spectrogram,
 )

-vocos = load_vocoder()
+vocoder = load_vocoder()


 # load models
@@ -94,6 +94,7 @@ def infer(
         ref_text,
         gen_text,
         ema_model,
+        vocoder,
         cross_fade_duration=cross_fade_duration,
         speed=speed,
         show_info=show_info,
diff --git a/src/f5_tts/infer/utils_infer.py b/src/f5_tts/infer/utils_infer.py
index 34f56bc..be48b5a 100644
--- a/src/f5_tts/infer/utils_infer.py
+++ b/src/f5_tts/infer/utils_infer.py
@@ -29,9 +29,6 @@ _ref_audio_cache = {}

 device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"

-vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
-
-
 # -----------------------------------------

 target_sample_rate = 24000
@@ -263,6 +260,7 @@ def infer_process(
     ref_text,
     gen_text,
     model_obj,
+    vocoder,
     show_info=print,
     progress=tqdm,
     target_rms=target_rms,
@@ -287,6 +285,7 @@ def infer_process(
         ref_text,
         gen_text_batches,
         model_obj,
+        vocoder,
         progress=progress,
         target_rms=target_rms,
         cross_fade_duration=cross_fade_duration,
@@ -307,6 +306,7 @@ def infer_batch_process(
     ref_text,
     gen_text_batches,
     model_obj,
+    vocoder,
     progress=tqdm,
     target_rms=0.1,
     cross_fade_duration=0.15,
@@ -362,7 +362,7 @@ def infer_batch_process(
             generated = generated.to(torch.float32)
             generated = generated[:, ref_audio_len:, :]
             generated_mel_spec = generated.permute(0, 2, 1)
-            generated_wave = vocos.decode(generated_mel_spec.cpu())
+            generated_wave = vocoder.decode(generated_mel_spec.cpu())

             if rms < target_rms:
                 generated_wave = generated_wave * rms / target_rms
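Note for downstream callers: the vocoder is no longer a module-level Vocos global in
utils_infer.py and must be passed to infer_process explicitly. Below is a minimal usage
sketch against the patched API; the synthesize wrapper and its arguments are illustrative
and not part of this patch, and model_obj is the already-loaded F5-TTS/E2-TTS model
object, whose loading (unchanged here) is omitted.

    from f5_tts.infer.utils_infer import load_vocoder, infer_process

    def synthesize(model_obj, ref_audio_path, ref_text, gen_text, speed=1.0):
        # Load the vocoder once and reuse it; pass is_local=True and a local_path
        # to mirror the offline usage in infer_cli.py.
        vocoder = load_vocoder()
        # infer_process now takes the vocoder right after model_obj.
        audio, sample_rate, spectrogram = infer_process(
            ref_audio_path, ref_text, gen_text, model_obj, vocoder, speed=speed
        )
        return audio, sample_rate, spectrogram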