mirror of
https://github.com/SWivid/F5-TTS.git
synced 2025-12-12 07:40:43 -08:00
fix vocoder loading
This commit is contained in:
@@ -47,7 +47,7 @@ class F5TTS:
|
|||||||
self.load_ema_model(model_type, ckpt_file, vocab_file, ode_method, use_ema)
|
self.load_ema_model(model_type, ckpt_file, vocab_file, ode_method, use_ema)
|
||||||
|
|
||||||
def load_vocoder_model(self, local_path):
|
def load_vocoder_model(self, local_path):
|
||||||
self.vocos = load_vocoder(local_path is not None, local_path, self.device)
|
self.vocoder = load_vocoder(local_path is not None, local_path, self.device)
|
||||||
|
|
||||||
def load_ema_model(self, model_type, ckpt_file, vocab_file, ode_method, use_ema):
|
def load_ema_model(self, model_type, ckpt_file, vocab_file, ode_method, use_ema):
|
||||||
if model_type == "F5-TTS":
|
if model_type == "F5-TTS":
|
||||||
@@ -102,6 +102,7 @@ class F5TTS:
|
|||||||
ref_text,
|
ref_text,
|
||||||
gen_text,
|
gen_text,
|
||||||
self.ema_model,
|
self.ema_model,
|
||||||
|
self.vocoder,
|
||||||
show_info=show_info,
|
show_info=show_info,
|
||||||
progress=progress,
|
progress=progress,
|
||||||
target_rms=target_rms,
|
target_rms=target_rms,
|
||||||
|
|||||||
@@ -113,7 +113,7 @@ wave_path = Path(output_dir) / "infer_cli_out.wav"
|
|||||||
# spectrogram_path = Path(output_dir) / "infer_cli_out.png"
|
# spectrogram_path = Path(output_dir) / "infer_cli_out.png"
|
||||||
vocos_local_path = "../checkpoints/charactr/vocos-mel-24khz"
|
vocos_local_path = "../checkpoints/charactr/vocos-mel-24khz"
|
||||||
|
|
||||||
vocos = load_vocoder(is_local=args.load_vocoder_from_local, local_path=vocos_local_path)
|
vocoder = load_vocoder(is_local=args.load_vocoder_from_local, local_path=vocos_local_path)
|
||||||
|
|
||||||
|
|
||||||
# load models
|
# load models
|
||||||
@@ -175,7 +175,9 @@ def main_process(ref_audio, ref_text, text_gen, model_obj, remove_silence, speed
|
|||||||
ref_audio = voices[voice]["ref_audio"]
|
ref_audio = voices[voice]["ref_audio"]
|
||||||
ref_text = voices[voice]["ref_text"]
|
ref_text = voices[voice]["ref_text"]
|
||||||
print(f"Voice: {voice}")
|
print(f"Voice: {voice}")
|
||||||
audio, final_sample_rate, spectragram = infer_process(ref_audio, ref_text, gen_text, model_obj, speed=speed)
|
audio, final_sample_rate, spectragram = infer_process(
|
||||||
|
ref_audio, ref_text, gen_text, model_obj, vocoder, speed=speed
|
||||||
|
)
|
||||||
generated_audio_segments.append(audio)
|
generated_audio_segments.append(audio)
|
||||||
|
|
||||||
if generated_audio_segments:
|
if generated_audio_segments:
|
||||||
|
|||||||
@@ -37,7 +37,7 @@ from f5_tts.infer.utils_infer import (
|
|||||||
save_spectrogram,
|
save_spectrogram,
|
||||||
)
|
)
|
||||||
|
|
||||||
vocos = load_vocoder()
|
vocoder = load_vocoder()
|
||||||
|
|
||||||
|
|
||||||
# load models
|
# load models
|
||||||
@@ -94,6 +94,7 @@ def infer(
|
|||||||
ref_text,
|
ref_text,
|
||||||
gen_text,
|
gen_text,
|
||||||
ema_model,
|
ema_model,
|
||||||
|
vocoder,
|
||||||
cross_fade_duration=cross_fade_duration,
|
cross_fade_duration=cross_fade_duration,
|
||||||
speed=speed,
|
speed=speed,
|
||||||
show_info=show_info,
|
show_info=show_info,
|
||||||
|
|||||||
@@ -29,9 +29,6 @@ _ref_audio_cache = {}
|
|||||||
|
|
||||||
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
|
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
|
||||||
|
|
||||||
vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
|
|
||||||
|
|
||||||
|
|
||||||
# -----------------------------------------
|
# -----------------------------------------
|
||||||
|
|
||||||
target_sample_rate = 24000
|
target_sample_rate = 24000
|
||||||
@@ -263,6 +260,7 @@ def infer_process(
|
|||||||
ref_text,
|
ref_text,
|
||||||
gen_text,
|
gen_text,
|
||||||
model_obj,
|
model_obj,
|
||||||
|
vocoder,
|
||||||
show_info=print,
|
show_info=print,
|
||||||
progress=tqdm,
|
progress=tqdm,
|
||||||
target_rms=target_rms,
|
target_rms=target_rms,
|
||||||
@@ -287,6 +285,7 @@ def infer_process(
|
|||||||
ref_text,
|
ref_text,
|
||||||
gen_text_batches,
|
gen_text_batches,
|
||||||
model_obj,
|
model_obj,
|
||||||
|
vocoder,
|
||||||
progress=progress,
|
progress=progress,
|
||||||
target_rms=target_rms,
|
target_rms=target_rms,
|
||||||
cross_fade_duration=cross_fade_duration,
|
cross_fade_duration=cross_fade_duration,
|
||||||
@@ -307,6 +306,7 @@ def infer_batch_process(
|
|||||||
ref_text,
|
ref_text,
|
||||||
gen_text_batches,
|
gen_text_batches,
|
||||||
model_obj,
|
model_obj,
|
||||||
|
vocoder,
|
||||||
progress=tqdm,
|
progress=tqdm,
|
||||||
target_rms=0.1,
|
target_rms=0.1,
|
||||||
cross_fade_duration=0.15,
|
cross_fade_duration=0.15,
|
||||||
@@ -362,7 +362,7 @@ def infer_batch_process(
|
|||||||
generated = generated.to(torch.float32)
|
generated = generated.to(torch.float32)
|
||||||
generated = generated[:, ref_audio_len:, :]
|
generated = generated[:, ref_audio_len:, :]
|
||||||
generated_mel_spec = generated.permute(0, 2, 1)
|
generated_mel_spec = generated.permute(0, 2, 1)
|
||||||
generated_wave = vocos.decode(generated_mel_spec.cpu())
|
generated_wave = vocoder.decode(generated_mel_spec.cpu())
|
||||||
if rms < target_rms:
|
if rms < target_rms:
|
||||||
generated_wave = generated_wave * rms / target_rms
|
generated_wave = generated_wave * rms / target_rms
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user