From 7abadc4c72adcc3e3e4ac5105981fe3fe5285bea Mon Sep 17 00:00:00 2001 From: SWivid Date: Sun, 26 Oct 2025 14:28:17 +0000 Subject: [PATCH] fix typo in eval scripts --- src/f5_tts/eval/eval_infer_batch.py | 8 +++++++- src/f5_tts/eval/eval_infer_batch.sh | 6 +++--- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/src/f5_tts/eval/eval_infer_batch.py b/src/f5_tts/eval/eval_infer_batch.py index cea5b7a..01fc15f 100644 --- a/src/f5_tts/eval/eval_infer_batch.py +++ b/src/f5_tts/eval/eval_infer_batch.py @@ -155,7 +155,13 @@ def main(): ckpt_path = ckpt_prefix + ".safetensors" else: print("Loading from self-organized training checkpoints rather than released pretrained.") - ckpt_path = rel_path + f"/{model_cfg.ckpts.save_dir}/model_{ckpt_step}.pt" + ckpt_prefix = rel_path + f"/{model_cfg.ckpts.save_dir}/model_{ckpt_step}" + if os.path.exists(ckpt_prefix + ".pt"): + ckpt_path = ckpt_prefix + ".pt" + elif os.path.exists(ckpt_prefix + ".safetensors"): + ckpt_path = ckpt_prefix + ".safetensors" + else: + raise ValueError("The checkpoint does not exist or cannot be found in given location.") dtype = torch.float32 if mel_spec_type == "bigvgan" else None model = load_checkpoint(model, ckpt_path, device, dtype=dtype, use_ema=use_ema) diff --git a/src/f5_tts/eval/eval_infer_batch.sh b/src/f5_tts/eval/eval_infer_batch.sh index a5b4f63..e7b81af 100644 --- a/src/f5_tts/eval/eval_infer_batch.sh +++ b/src/f5_tts/eval/eval_infer_batch.sh @@ -11,8 +11,8 @@ accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "E2TTS_Base" -c 12 accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "E2TTS_Base" -c 1200000 -t "ls_pc_test_clean" -o "midpoint" -ss 0 # e.g. evaluate F5-TTS 16 NFE result on Seed-TTS test-zh -python src/f5_tts/eval/eval_seedtts_testset.py -e wer -l zh --gen_wav_dir results/F5TTS_v1_Base_1250000/seedtts_test_zh/seed0_euler_nfe32_vocos_ss-1_cfg2.0_speed1.0 --gpu_nums 8 -python src/f5_tts/eval/eval_seedtts_testset.py -e sim -l zh --gen_wav_dir results/F5TTS_v1_Base_1250000/seedtts_test_zh/seed0_euler_nfe32_vocos_ss-1_cfg2.0_speed1.0 --gpu_nums 8 -python src/f5_tts/eval/eval_utmos.py --audio_dir results/F5TTS_v1_Base_1250000/seedtts_test_zh/seed0_euler_nfe32_vocos_ss-1_cfg2.0_speed1.0 +python src/f5_tts/eval/eval_seedtts_testset.py -e wer -l zh --gen_wav_dir results/F5TTS_v1_Base_1250000/seedtts_test_zh/seed0_euler_nfe16_vocos_ss-1_cfg2.0_speed1.0 --gpu_nums 8 +python src/f5_tts/eval/eval_seedtts_testset.py -e sim -l zh --gen_wav_dir results/F5TTS_v1_Base_1250000/seedtts_test_zh/seed0_euler_nfe16_vocos_ss-1_cfg2.0_speed1.0 --gpu_nums 8 +python src/f5_tts/eval/eval_utmos.py --audio_dir results/F5TTS_v1_Base_1250000/seedtts_test_zh/seed0_euler_nfe16_vocos_ss-1_cfg2.0_speed1.0 # etc.