refactor: remove global params, set vocos as the default vocoder, and add a dtype check

This commit is contained in:
ZhikangNiu
2024-11-01 14:17:22 +08:00
parent b180961782
commit 18e1ab508f
5 changed files with 28 additions and 17 deletions

View File

@@ -45,7 +45,7 @@ git clone https://github.com/SWivid/F5-TTS.git
cd F5-TTS
pip install -e .
# Init submodule(optional, if you want to change the vocoder from vocos to bigvgan)
# Init submodule (optional, if you want to change the vocoder from vocos to bigvgan)
# git submodule update --init --recursive
# pip install -e .
```

View File

@@ -32,7 +32,6 @@ n_mel_channels = 100
hop_length = 256
win_length = 1024
n_fft = 1024
mel_spec_type = "bigvgan" # 'vocos' or 'bigvgan'
target_rms = 0.1
@@ -49,6 +48,7 @@ def main():
parser.add_argument("-d", "--dataset", default="Emilia_ZH_EN")
parser.add_argument("-n", "--expname", required=True)
parser.add_argument("-c", "--ckptstep", default=1200000, type=int)
parser.add_argument("-m", "--mel_spec_type", default="vocos", type=str, choices=["bigvgan", "vocos"])
parser.add_argument("-nfe", "--nfestep", default=32, type=int)
parser.add_argument("-o", "--odemethod", default="euler")
@@ -63,6 +63,7 @@ def main():
exp_name = args.expname
ckpt_step = args.ckptstep
ckpt_path = rel_path + f"/ckpts/{exp_name}/model_{ckpt_step}.pt"
mel_spec_type = args.mel_spec_type
nfe_step = args.nfestep
ode_method = args.odemethod
@@ -101,7 +102,7 @@ def main():
output_dir = (
f"{rel_path}/"
f"results/{exp_name}_{ckpt_step}/{testset}/"
f"seed{seed}_{ode_method}_nfe{nfe_step}"
f"seed{seed}_{ode_method}_nfe{nfe_step}_{mel_spec_type}"
f"{f'_ss{sway_sampling_coef}' if sway_sampling_coef else ''}"
f"_cfg{cfg_strength}_speed{speed}"
f"{'_gt-dur' if use_truth_duration else ''}"
@@ -155,10 +156,10 @@ def main():
supports_fp16 = device == "cuda" and torch.cuda.get_device_properties(device).major >= 6
if supports_fp16 and mel_spec_type == "vocos":
dtype = torch.float16
else:
elif mel_spec_type == "bigvgan":
dtype = torch.float32
model = load_checkpoint(model, ckpt_path, device, dtype, use_ema=use_ema)
model = load_checkpoint(model, ckpt_path, device, dtype=dtype, use_ema=use_ema)
if not os.path.exists(output_dir) and accelerator.is_main_process:
os.makedirs(output_dir)

View File

@@ -154,7 +154,7 @@ elif model == "E2-TTS":
print(f"Using {model}...")
ema_model = load_model(model_cls, model_cfg, ckpt_file, args.vocoder_name, vocab_file)
ema_model = load_model(model_cls, model_cfg, ckpt_file, mel_spec_type=args.vocoder_name, vocab_file=vocab_file)
def main_process(ref_audio, ref_text, text_gen, model_obj, mel_spec_type, remove_silence, speed):
@@ -192,7 +192,7 @@ def main_process(ref_audio, ref_text, text_gen, model_obj, mel_spec_type, remove
ref_text = voices[voice]["ref_text"]
print(f"Voice: {voice}")
audio, final_sample_rate, spectragram = infer_process(
ref_audio, ref_text, gen_text, model_obj, vocoder, mel_spec_type, speed=speed
ref_audio, ref_text, gen_text, model_obj, vocoder, mel_spec_type=mel_spec_type, speed=speed
)
generated_audio_segments.append(audio)

View File

@@ -18,7 +18,7 @@ n_mel_channels = 100
hop_length = 256
win_length = 1024
n_fft = 1024
mel_spec_type = "bigvgan" # 'vocos' or 'bigvgan'
mel_spec_type = "vocos" # 'vocos' or 'bigvgan'
target_rms = 0.1
tokenizer = "pinyin"
@@ -114,10 +114,10 @@ model = CFM(
supports_fp16 = device == "cuda" and torch.cuda.get_device_properties(device).major >= 6
if supports_fp16 and mel_spec_type == "vocos":
dtype = torch.float16
else:
elif mel_spec_type == "bigvgan":
dtype = torch.float32
model = load_checkpoint(model, ckpt_path, device, dtype, use_ema=use_ema)
model = load_checkpoint(model, ckpt_path, device, dtype=dtype, use_ema=use_ema)
# Audio
audio, sr = torchaudio.load(audio_to_edit)

View File

@@ -40,7 +40,6 @@ n_mel_channels = 100
hop_length = 256
win_length = 1024
n_fft = 1024
mel_spec_type = "bigvgan" # 'vocos' or 'bigvgan'
target_rms = 0.1
cross_fade_duration = 0.15
ode_method = "euler"
@@ -133,6 +132,10 @@ def initialize_asr_pipeline(device=device):
def load_checkpoint(model, ckpt_path, device, dtype, use_ema=True):
if dtype is None:
dtype = (
torch.float16 if device == "cuda" and torch.cuda.get_device_properties(device).major >= 6 else torch.float32
)
model = model.to(dtype)
ckpt_type = ckpt_path.split(".")[-1]
@@ -169,7 +172,14 @@ def load_checkpoint(model, ckpt_path, device, dtype, use_ema=True):
def load_model(
model_cls, model_cfg, ckpt_path, mel_spec_type, vocab_file="", ode_method=ode_method, use_ema=True, device=device
model_cls,
model_cfg,
ckpt_path,
mel_spec_type="vocos",
vocab_file="",
ode_method=ode_method,
use_ema=True,
device=device,
):
if vocab_file == "":
vocab_file = str(files("f5_tts").joinpath("infer/examples/vocab.txt"))
@@ -199,10 +209,10 @@ def load_model(
supports_fp16 = device == "cuda" and torch.cuda.get_device_properties(device).major >= 6
if supports_fp16 and mel_spec_type == "vocos":
dtype = torch.float16
else:
elif mel_spec_type == "bigvgan":
dtype = torch.float32
model = load_checkpoint(model, ckpt_path, device, dtype, use_ema=use_ema)
model = load_checkpoint(model, ckpt_path, device, dtype=dtype, use_ema=use_ema)
return model
@@ -297,7 +307,7 @@ def infer_process(
gen_text,
model_obj,
vocoder,
mel_spec_type,
mel_spec_type="vocos",
show_info=print,
progress=tqdm,
target_rms=target_rms,
@@ -323,7 +333,7 @@ def infer_process(
gen_text_batches,
model_obj,
vocoder,
mel_spec_type,
mel_spec_type=mel_spec_type,
progress=progress,
target_rms=target_rms,
cross_fade_duration=cross_fade_duration,
@@ -345,7 +355,7 @@ def infer_batch_process(
gen_text_batches,
model_obj,
vocoder,
mel_spec_type,
mel_spec_type="vocos",
progress=tqdm,
target_rms=0.1,
cross_fade_duration=0.15,