fix vocoder loading

This commit is contained in:
SWivid
2024-10-30 03:16:09 +08:00
parent da1b40968a
commit 381ea0c82c
4 changed files with 12 additions and 8 deletions

View File

@@ -47,7 +47,7 @@ class F5TTS:
self.load_ema_model(model_type, ckpt_file, vocab_file, ode_method, use_ema)
def load_vocoder_model(self, local_path):
self.vocos = load_vocoder(local_path is not None, local_path, self.device)
self.vocoder = load_vocoder(local_path is not None, local_path, self.device)
def load_ema_model(self, model_type, ckpt_file, vocab_file, ode_method, use_ema):
if model_type == "F5-TTS":
@@ -102,6 +102,7 @@ class F5TTS:
ref_text,
gen_text,
self.ema_model,
self.vocoder,
show_info=show_info,
progress=progress,
target_rms=target_rms,

View File

@@ -113,7 +113,7 @@ wave_path = Path(output_dir) / "infer_cli_out.wav"
# spectrogram_path = Path(output_dir) / "infer_cli_out.png"
vocos_local_path = "../checkpoints/charactr/vocos-mel-24khz"
vocos = load_vocoder(is_local=args.load_vocoder_from_local, local_path=vocos_local_path)
vocoder = load_vocoder(is_local=args.load_vocoder_from_local, local_path=vocos_local_path)
# load models
@@ -175,7 +175,9 @@ def main_process(ref_audio, ref_text, text_gen, model_obj, remove_silence, speed
ref_audio = voices[voice]["ref_audio"]
ref_text = voices[voice]["ref_text"]
print(f"Voice: {voice}")
audio, final_sample_rate, spectragram = infer_process(ref_audio, ref_text, gen_text, model_obj, speed=speed)
audio, final_sample_rate, spectragram = infer_process(
ref_audio, ref_text, gen_text, model_obj, vocoder, speed=speed
)
generated_audio_segments.append(audio)
if generated_audio_segments:

View File

@@ -37,7 +37,7 @@ from f5_tts.infer.utils_infer import (
save_spectrogram,
)
vocos = load_vocoder()
vocoder = load_vocoder()
# load models
@@ -94,6 +94,7 @@ def infer(
ref_text,
gen_text,
ema_model,
vocoder,
cross_fade_duration=cross_fade_duration,
speed=speed,
show_info=show_info,

View File

@@ -29,9 +29,6 @@ _ref_audio_cache = {}
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
# -----------------------------------------
target_sample_rate = 24000
@@ -263,6 +260,7 @@ def infer_process(
ref_text,
gen_text,
model_obj,
vocoder,
show_info=print,
progress=tqdm,
target_rms=target_rms,
@@ -287,6 +285,7 @@ def infer_process(
ref_text,
gen_text_batches,
model_obj,
vocoder,
progress=progress,
target_rms=target_rms,
cross_fade_duration=cross_fade_duration,
@@ -307,6 +306,7 @@ def infer_batch_process(
ref_text,
gen_text_batches,
model_obj,
vocoder,
progress=tqdm,
target_rms=0.1,
cross_fade_duration=0.15,
@@ -362,7 +362,7 @@ def infer_batch_process(
generated = generated.to(torch.float32)
generated = generated[:, ref_audio_len:, :]
generated_mel_spec = generated.permute(0, 2, 1)
generated_wave = vocos.decode(generated_mel_spec.cpu())
generated_wave = vocoder.decode(generated_mel_spec.cpu())
if rms < target_rms:
generated_wave = generated_wave * rms / target_rms