mirror of https://github.com/SWivid/F5-TTS.git
fix vocoder loading
@@ -47,7 +47,7 @@ class F5TTS:
         self.load_ema_model(model_type, ckpt_file, vocab_file, ode_method, use_ema)

     def load_vocoder_model(self, local_path):
-        self.vocos = load_vocoder(local_path is not None, local_path, self.device)
+        self.vocoder = load_vocoder(local_path is not None, local_path, self.device)

     def load_ema_model(self, model_type, ckpt_file, vocab_file, ode_method, use_ema):
         if model_type == "F5-TTS":
@@ -102,6 +102,7 @@ class F5TTS:
             ref_text,
             gen_text,
             self.ema_model,
+            self.vocoder,
             show_info=show_info,
             progress=progress,
             target_rms=target_rms,
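
Taken together, the two hunks above change the API class so that the loaded vocoder lives on self.vocoder and is handed to inference explicitly, right after the EMA model. A condensed sketch of that wiring, assuming load_vocoder and infer_process come from f5_tts.infer.utils_infer as in the hunks further down; anything not shown in the diff (how self.device and self.ema_model get set) is illustrative:

from f5_tts.infer.utils_infer import load_vocoder, infer_process


class F5TTSSketch:
    def load_vocoder_model(self, local_path):
        # Renamed attribute: the vocoder handle is now self.vocoder.
        self.vocoder = load_vocoder(local_path is not None, local_path, self.device)

    def infer(self, ref_audio, ref_text, gen_text, **kwargs):
        # The vocoder is forwarded as the argument right after the EMA model.
        return infer_process(
            ref_audio,
            ref_text,
            gen_text,
            self.ema_model,
            self.vocoder,
            **kwargs,
        )
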
@@ -113,7 +113,7 @@ wave_path = Path(output_dir) / "infer_cli_out.wav"
 # spectrogram_path = Path(output_dir) / "infer_cli_out.png"
 vocos_local_path = "../checkpoints/charactr/vocos-mel-24khz"

-vocos = load_vocoder(is_local=args.load_vocoder_from_local, local_path=vocos_local_path)
+vocoder = load_vocoder(is_local=args.load_vocoder_from_local, local_path=vocos_local_path)


 # load models
@@ -175,7 +175,9 @@ def main_process(ref_audio, ref_text, text_gen, model_obj, remove_silence, speed
         ref_audio = voices[voice]["ref_audio"]
         ref_text = voices[voice]["ref_text"]
         print(f"Voice: {voice}")
-        audio, final_sample_rate, spectragram = infer_process(ref_audio, ref_text, gen_text, model_obj, speed=speed)
+        audio, final_sample_rate, spectragram = infer_process(
+            ref_audio, ref_text, gen_text, model_obj, vocoder, speed=speed
+        )
         generated_audio_segments.append(audio)

     if generated_audio_segments:
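
The CLI hunks follow the same pattern: the vocoder is loaded once at module scope (now under the name vocoder) and passed into every infer_process call. A minimal caller sketch under the same assumptions — the import path and the positional order ref_audio, ref_text, gen_text, model_obj, vocoder are taken from the hunks; the model object and the concrete values are placeholders:

from f5_tts.infer.utils_infer import load_vocoder, infer_process

# Load once and reuse across all voices/segments.
vocoder = load_vocoder(is_local=False, local_path="../checkpoints/charactr/vocos-mel-24khz")


def synthesize(ref_audio, ref_text, gen_text, model_obj, speed=1.0):
    # infer_process now takes the vocoder as the argument after model_obj.
    audio, sample_rate, spectrogram = infer_process(
        ref_audio, ref_text, gen_text, model_obj, vocoder, speed=speed
    )
    return audio, sample_rate, spectrogram
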
@@ -37,7 +37,7 @@ from f5_tts.infer.utils_infer import (
     save_spectrogram,
 )

-vocos = load_vocoder()
+vocoder = load_vocoder()


 # load models
@@ -94,6 +94,7 @@ def infer(
         ref_text,
         gen_text,
         ema_model,
+        vocoder,
         cross_fade_duration=cross_fade_duration,
         speed=speed,
         show_info=show_info,
@@ -29,9 +29,6 @@ _ref_audio_cache = {}

 device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"

-vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
-
-
 # -----------------------------------------

 target_sample_rate = 24000
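
This hunk removes the module-level Vocos instance from utils_infer: importing the module no longer instantiates a vocoder as a side effect; callers construct one (typically via load_vocoder, as in the earlier hunks) and inject it into infer_process / infer_batch_process. For code that relied on the old global, the explicit equivalent is just the deleted line moved to the caller:

from vocos import Vocos

# Formerly a module-level global in utils_infer; now created by the caller
# and passed into infer_process(..., model_obj, vocoder, ...) explicitly.
vocoder = Vocos.from_pretrained("charactr/vocos-mel-24khz")
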
@@ -263,6 +260,7 @@ def infer_process(
     ref_text,
     gen_text,
     model_obj,
+    vocoder,
     show_info=print,
     progress=tqdm,
     target_rms=target_rms,
@@ -287,6 +285,7 @@ def infer_process(
         ref_text,
         gen_text_batches,
         model_obj,
+        vocoder,
         progress=progress,
         target_rms=target_rms,
         cross_fade_duration=cross_fade_duration,
@@ -307,6 +306,7 @@ def infer_batch_process(
     ref_text,
     gen_text_batches,
     model_obj,
+    vocoder,
     progress=tqdm,
     target_rms=0.1,
     cross_fade_duration=0.15,
@@ -362,7 +362,7 @@ def infer_batch_process(
         generated = generated.to(torch.float32)
         generated = generated[:, ref_audio_len:, :]
         generated_mel_spec = generated.permute(0, 2, 1)
-        generated_wave = vocos.decode(generated_mel_spec.cpu())
+        generated_wave = vocoder.decode(generated_mel_spec.cpu())
         if rms < target_rms:
             generated_wave = generated_wave * rms / target_rms
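
The final hunk is where the injected vocoder is actually used: the generated mel frames, with the reference-prompt frames stripped off, are permuted to (batch, n_mels, frames) and decoded to a waveform, then scaled back down if the reference audio had been RMS-normalized upward before inference. Restated as a self-contained helper — the variable names and the default target_rms=0.1 come from the surrounding hunks; the shape comments are an assumption about the model output, and vocoder is anything exposing a Vocos-style decode():

import torch


def decode_generated(vocoder, generated, ref_audio_len, rms, target_rms=0.1):
    # generated: (batch, frames, n_mels) mel output from the TTS model
    generated = generated.to(torch.float32)
    generated = generated[:, ref_audio_len:, :]      # drop the reference-prompt frames
    generated_mel_spec = generated.permute(0, 2, 1)  # -> (batch, n_mels, frames)
    generated_wave = vocoder.decode(generated_mel_spec.cpu())
    if rms < target_rms:
        # undo the gain applied when a quiet reference clip was normalized up to target_rms
        generated_wave = generated_wave * rms / target_rms
    return generated_wave, generated_mel_spec
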