diff --git a/src/f5_tts/api.py b/src/f5_tts/api.py index 41fc667..823067d 100644 --- a/src/f5_tts/api.py +++ b/src/f5_tts/api.py @@ -15,6 +15,9 @@ from f5_tts.infer.utils_infer import ( infer_process, remove_silence_for_generated_wav, save_spectrogram, + preprocess_ref_audio_text, + target_sample_rate, + hop_length, ) @@ -31,10 +34,8 @@ class F5TTS: ): # Initialize parameters self.final_wave = None - self.target_sample_rate = 24000 - self.n_mel_channels = 100 - self.hop_length = 256 - self.target_rms = 0.1 + self.target_sample_rate = target_sample_rate + self.hop_length = hop_length self.seed = -1 # Set device @@ -97,6 +98,9 @@ class F5TTS: seed = random.randint(0, sys.maxsize) seed_everything(seed) self.seed = seed + + ref_file, ref_text = preprocess_ref_audio_text(ref_file, ref_text, device=self.device) + wav, sr, spect = infer_process( ref_file, ref_text, diff --git a/src/f5_tts/train/finetune_gradio.py b/src/f5_tts/train/finetune_gradio.py index fe835b2..007dad8 100644 --- a/src/f5_tts/train/finetune_gradio.py +++ b/src/f5_tts/train/finetune_gradio.py @@ -1216,7 +1216,7 @@ def infer(project, file_checkpoint, exp_name, ref_text, ref_audio, gen_text, nfe else: device_test = None - if last_checkpoint != file_checkpoint or last_device != device_test or last_ema != use_ema: + if last_checkpoint != file_checkpoint or last_device != device_test or last_ema != use_ema or tts_api is None: if last_checkpoint != file_checkpoint: last_checkpoint = file_checkpoint