mirror of
https://github.com/SWivid/F5-TTS.git
synced 2025-12-12 07:40:43 -08:00
clean-up eval scripts
This commit is contained in:
@@ -14,16 +14,20 @@ pip install -e .[eval]
|
|||||||
1. *Seed-TTS testset*: Download from [seed-tts-eval](https://github.com/BytedanceSpeech/seed-tts-eval).
|
1. *Seed-TTS testset*: Download from [seed-tts-eval](https://github.com/BytedanceSpeech/seed-tts-eval).
|
||||||
2. *LibriSpeech test-clean*: Download from [OpenSLR](http://www.openslr.org/12/).
|
2. *LibriSpeech test-clean*: Download from [OpenSLR](http://www.openslr.org/12/).
|
||||||
3. Unzip the downloaded datasets and place them in the `data/` directory.
|
3. Unzip the downloaded datasets and place them in the `data/` directory.
|
||||||
4. Update the path for *LibriSpeech test-clean* data in `src/f5_tts/eval/eval_infer_batch.py`
|
4. Our filtered LibriSpeech-PC 4-10s subset: `data/librispeech_pc_test_clean_cross_sentence.lst`
|
||||||
5. Our filtered LibriSpeech-PC 4-10s subset: `data/librispeech_pc_test_clean_cross_sentence.lst`
|
|
||||||
|
|
||||||
### Batch Inference for Test Set
|
### Batch Inference for Test Set
|
||||||
|
|
||||||
To run batch inference for evaluations, execute the following commands:
|
To run batch inference for evaluations, execute the following commands:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# batch inference for evaluations
|
# if accelerate has not been configured yet
|
||||||
accelerate config # if not set before
|
accelerate config
|
||||||
|
|
||||||
|
# to perform inference only
|
||||||
|
bash src/f5_tts/eval/eval_infer_batch.sh --infer-only
|
||||||
|
|
||||||
|
# to run inference together with the corresponding evaluation, set up the following tools first
|
||||||
bash src/f5_tts/eval/eval_infer_batch.sh
|
bash src/f5_tts/eval/eval_infer_batch.sh
|
||||||
```
|
```
|
||||||
|
|
||||||
@@ -35,9 +39,13 @@ bash src/f5_tts/eval/eval_infer_batch.sh
|
|||||||
2. English ASR Model: [Faster-Whisper](https://huggingface.co/Systran/faster-whisper-large-v3)
|
2. English ASR Model: [Faster-Whisper](https://huggingface.co/Systran/faster-whisper-large-v3)
|
||||||
3. WavLM Model: Download from [Google Drive](https://drive.google.com/file/d/1-aE1NfzpRCLxA4GUxX9ITI3F9LlbtEGP/view).
|
3. WavLM Model: Download from [Google Drive](https://drive.google.com/file/d/1-aE1NfzpRCLxA4GUxX9ITI3F9LlbtEGP/view).
|
||||||
|
|
||||||
Then update in the following scripts with the paths you put evaluation model ckpts to.
|
> [!NOTE]
|
||||||
|
> The ASR model will be downloaded automatically if `--local` is not set for the evaluation scripts.
|
||||||
|
> Otherwise, you should update the `asr_ckpt_dir` path values in `eval_librispeech_test_clean.py` or `eval_seedtts_testset.py`.
|
||||||
|
>
|
||||||
|
> The WavLM model must be downloaded and the `wavlm_ckpt_dir` path updated in `eval_librispeech_test_clean.py` and `eval_seedtts_testset.py`.
|
||||||
|
|
||||||
### Objective Evaluation
|
### Objective Evaluation Examples
|
||||||
|
|
||||||
Update the path with your batch-inferenced results, and carry out WER / SIM / UTMOS evaluations:
|
Update the path with your batch-inferenced results, and carry out WER / SIM / UTMOS evaluations:
|
||||||
```bash
|
```bash
|
||||||
@@ -50,3 +58,6 @@ python src/f5_tts/eval/eval_librispeech_test_clean.py --eval_task sim --gen_wav_
|
|||||||
# Evaluation [UTMOS]. --ext: Audio extension
|
# Evaluation [UTMOS]. --ext: Audio extension
|
||||||
python src/f5_tts/eval/eval_utmos.py --audio_dir <WAV_DIR> --ext wav
|
python src/f5_tts/eval/eval_utmos.py --audio_dir <WAV_DIR> --ext wav
|
||||||
```
|
```
|
||||||
|
|
||||||
|
> [!NOTE]
|
||||||
|
> Evaluation results can also be found in `_*_results.jsonl` files saved in `<GEN_WAV_DIR>`/`<WAV_DIR>`.
|
||||||
|
|||||||
@@ -48,6 +48,11 @@ def main():
|
|||||||
parser.add_argument("-ss", "--swaysampling", default=-1, type=float)
|
parser.add_argument("-ss", "--swaysampling", default=-1, type=float)
|
||||||
|
|
||||||
parser.add_argument("-t", "--testset", required=True)
|
parser.add_argument("-t", "--testset", required=True)
|
||||||
|
parser.add_argument(
|
||||||
|
"-p", "--librispeech_test_clean_path", default=f"{rel_path}/data/LibriSpeech/test-clean", type=str
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument("--local", action="store_true", help="Use local vocoder checkpoint directory")
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
@@ -83,7 +88,7 @@ def main():
|
|||||||
|
|
||||||
if testset == "ls_pc_test_clean":
|
if testset == "ls_pc_test_clean":
|
||||||
metalst = rel_path + "/data/librispeech_pc_test_clean_cross_sentence.lst"
|
metalst = rel_path + "/data/librispeech_pc_test_clean_cross_sentence.lst"
|
||||||
librispeech_test_clean_path = "<SOME_PATH>/LibriSpeech/test-clean" # test-clean path
|
librispeech_test_clean_path = args.librispeech_test_clean_path
|
||||||
metainfo = get_librispeech_test_clean_metainfo(metalst, librispeech_test_clean_path)
|
metainfo = get_librispeech_test_clean_metainfo(metalst, librispeech_test_clean_path)
|
||||||
|
|
||||||
elif testset == "seedtts_test_zh":
|
elif testset == "seedtts_test_zh":
|
||||||
@@ -121,7 +126,7 @@ def main():
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Vocoder model
|
# Vocoder model
|
||||||
local = False
|
local = args.local
|
||||||
if mel_spec_type == "vocos":
|
if mel_spec_type == "vocos":
|
||||||
vocoder_local_path = "../checkpoints/charactr/vocos-mel-24khz"
|
vocoder_local_path = "../checkpoints/charactr/vocos-mel-24khz"
|
||||||
elif mel_spec_type == "bigvgan":
|
elif mel_spec_type == "bigvgan":
|
||||||
|
|||||||
@@ -1,18 +1,116 @@
|
|||||||
#!/bin/bash
|
#!/bin/bash
|
||||||
|
set -e
|
||||||
|
export PYTHONWARNINGS="ignore::UserWarning,ignore::FutureWarning"
|
||||||
|
|
||||||
# e.g. F5-TTS, 16 NFE
|
# Configuration parameters
|
||||||
accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "F5TTS_v1_Base" -t "seedtts_test_zh" -nfe 16
|
MODEL_NAME="F5TTS_v1_Base"
|
||||||
accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "F5TTS_v1_Base" -t "seedtts_test_en" -nfe 16
|
SEEDS=(0 1 2)
|
||||||
accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "F5TTS_v1_Base" -t "ls_pc_test_clean" -nfe 16
|
CKPTSTEPS=(1250000)
|
||||||
|
TASKS=("seedtts_test_zh" "seedtts_test_en" "ls_pc_test_clean")
|
||||||
|
LS_TEST_CLEAN_PATH="data/LibriSpeech/test-clean"
|
||||||
|
GPUS="[0,1,2,3,4,5,6,7]"
|
||||||
|
OFFLINE_MODE=false
|
||||||
|
|
||||||
# e.g. Vanilla E2 TTS, 32 NFE
|
# Parse arguments
|
||||||
accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "E2TTS_Base" -c 1200000 -t "seedtts_test_zh" -o "midpoint" -ss 0
|
if [ $OFFLINE_MODE = true ]; then
|
||||||
accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "E2TTS_Base" -c 1200000 -t "seedtts_test_en" -o "midpoint" -ss 0
|
LOCAL="--local"
|
||||||
accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "E2TTS_Base" -c 1200000 -t "ls_pc_test_clean" -o "midpoint" -ss 0
|
else
|
||||||
|
LOCAL=""
|
||||||
|
fi
|
||||||
|
INFER_ONLY=false
|
||||||
|
while [[ $# -gt 0 ]]; do
|
||||||
|
case $1 in
|
||||||
|
--infer-only)
|
||||||
|
INFER_ONLY=true
|
||||||
|
shift
|
||||||
|
;;
|
||||||
|
*)
|
||||||
|
echo "======== Unknown parameter: $1"
|
||||||
|
exit 1
|
||||||
|
;;
|
||||||
|
esac
|
||||||
|
done
|
||||||
|
|
||||||
# e.g. evaluate F5-TTS 16 NFE result on Seed-TTS test-zh
|
echo "======== Starting F5-TTS batch evaluation task..."
|
||||||
python src/f5_tts/eval/eval_seedtts_testset.py -e wer -l zh --gen_wav_dir results/F5TTS_v1_Base_1250000/seedtts_test_zh/seed0_euler_nfe16_vocos_ss-1_cfg2.0_speed1.0 --gpu_nums 8
|
if [ "$INFER_ONLY" = true ]; then
|
||||||
python src/f5_tts/eval/eval_seedtts_testset.py -e sim -l zh --gen_wav_dir results/F5TTS_v1_Base_1250000/seedtts_test_zh/seed0_euler_nfe16_vocos_ss-1_cfg2.0_speed1.0 --gpu_nums 8
|
echo "======== Mode: Execute infer tasks only"
|
||||||
python src/f5_tts/eval/eval_utmos.py --audio_dir results/F5TTS_v1_Base_1250000/seedtts_test_zh/seed0_euler_nfe16_vocos_ss-1_cfg2.0_speed1.0
|
else
|
||||||
|
echo "======== Mode: Execute full pipeline (infer + eval)"
|
||||||
|
fi
|
||||||
|
|
||||||
# etc.
|
# Function: Execute eval tasks
|
||||||
|
# Function: Execute eval tasks
#
# Runs the WER, SIM and UTMOS evaluations for one generated result directory.
#
# Arguments:
#   $1  checkpoint step used for inference
#   $2  random seed used for inference
#   $3  task name: seedtts_test_zh | seedtts_test_en | ls_pc_test_clean
#
# Reads the globals MODEL_NAME, GPUS, LS_TEST_CLEAN_PATH and LOCAL.
# Returns non-zero if the task name is not recognised.
execute_eval_tasks() {
    local ckptstep=$1
    local seed=$2
    local task_name=$3

    # NOTE(review): the suffix (euler / nfe32 / vocos / ss-1 / cfg2.0 / speed1.0) is hard-coded;
    # presumably it mirrors the directory written by eval_infer_batch.py with its default
    # settings - confirm and update it if other inference flags are passed.
    local gen_wav_dir="results/${MODEL_NAME}_${ckptstep}/${task_name}/seed${seed}_euler_nfe32_vocos_ss-1_cfg2.0_speed1.0"

    echo ">>>>>>>> Starting eval task: ckptstep=${ckptstep}, seed=${seed}, task=${task_name}"

    # $LOCAL is left unquoted on purpose: when it is empty it must expand to no argument at all.
    case "$task_name" in
        "seedtts_test_zh")
            python src/f5_tts/eval/eval_seedtts_testset.py -e wer -l zh -g "$gen_wav_dir" -n "$GPUS" $LOCAL
            python src/f5_tts/eval/eval_seedtts_testset.py -e sim -l zh -g "$gen_wav_dir" -n "$GPUS" $LOCAL
            python src/f5_tts/eval/eval_utmos.py --audio_dir "$gen_wav_dir"
            ;;
        "seedtts_test_en")
            python src/f5_tts/eval/eval_seedtts_testset.py -e wer -l en -g "$gen_wav_dir" -n "$GPUS" $LOCAL
            python src/f5_tts/eval/eval_seedtts_testset.py -e sim -l en -g "$gen_wav_dir" -n "$GPUS" $LOCAL
            python src/f5_tts/eval/eval_utmos.py --audio_dir "$gen_wav_dir"
            ;;
        "ls_pc_test_clean")
            python src/f5_tts/eval/eval_librispeech_test_clean.py -e wer -g "$gen_wav_dir" -n "$GPUS" -p "$LS_TEST_CLEAN_PATH" $LOCAL
            python src/f5_tts/eval/eval_librispeech_test_clean.py -e sim -g "$gen_wav_dir" -n "$GPUS" -p "$LS_TEST_CLEAN_PATH" $LOCAL
            python src/f5_tts/eval/eval_utmos.py --audio_dir "$gen_wav_dir"
            ;;
        *)
            # Fail loudly instead of reporting "Completed" for a task that never ran.
            echo ">>>>>>>> Unknown eval task: ${task_name}" >&2
            return 1
            ;;
    esac

    echo ">>>>>>>> Completed eval task: ckptstep=${ckptstep}, seed=${seed}, task=${task_name}"
}
|
||||||
|
|
||||||
|
# Main execution loop
|
||||||
|
# For every checkpoint step and seed: run the infer tasks one after another in the
# foreground and, unless --infer-only was given, start the matching eval task in the
# background as soon as its inference has finished.
for ckptstep in "${CKPTSTEPS[@]}"; do
    echo "======== Processing ckptstep: ${ckptstep}"

    for seed in "${SEEDS[@]}"; do
        echo "-------- Processing seed: ${seed}"

        # PIDs of the background eval jobs of the current seed (stays empty in infer-only mode)
        eval_pids=()

        # Execute each infer task sequentially
        for task in "${TASKS[@]}"; do
            # Log exactly the command that is executed below, including the -p argument
            echo ">>>>>>>> Executing infer task: accelerate launch src/f5_tts/eval/eval_infer_batch.py -s ${seed} -n \"${MODEL_NAME}\" -t \"${task}\" -c ${ckptstep} -p \"${LS_TEST_CLEAN_PATH}\" $LOCAL"

            # Execute infer task (foreground execution, wait for completion).
            # $LOCAL is left unquoted on purpose: when empty it must expand to no argument.
            accelerate launch src/f5_tts/eval/eval_infer_batch.py -s "${seed}" -n "${MODEL_NAME}" -t "${task}" -c "${ckptstep}" -p "${LS_TEST_CLEAN_PATH}" $LOCAL

            # Launch corresponding eval task (background execution, non-blocking for next infer)
            if [ "$INFER_ONLY" = false ]; then
                execute_eval_tasks "$ckptstep" "$seed" "$task" &
                eval_pids+=($!)
            fi
        done

        if [ "$INFER_ONLY" = false ]; then
            echo ">>>>>>>> All infer tasks for seed ${seed} completed, waiting for corresponding eval tasks to finish..."

            # `wait <pid>` returns that job's exit status, so a failed eval aborts the script under `set -e`
            for pid in "${eval_pids[@]}"; do
                wait "$pid"
            done

            echo "-------- All eval tasks for seed ${seed} completed"
        else
            # No eval ran in infer-only mode, so do not claim that it completed
            echo "-------- All infer tasks for seed ${seed} completed"
        fi
    done

    echo "======== Completed ckptstep: ${ckptstep}"
    echo
done
|
||||||
|
|
||||||
|
echo "======== All tasks completed!"
|
||||||
18
src/f5_tts/eval/eval_infer_batch_example.sh
Normal file
18
src/f5_tts/eval/eval_infer_batch_example.sh
Normal file
@@ -0,0 +1,18 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
# e.g. F5-TTS, 16 NFE
|
||||||
|
accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "F5TTS_v1_Base" -t "seedtts_test_zh" -nfe 16
|
||||||
|
accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "F5TTS_v1_Base" -t "seedtts_test_en" -nfe 16
|
||||||
|
accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "F5TTS_v1_Base" -t "ls_pc_test_clean" -nfe 16 -p data/LibriSpeech/test-clean
|
||||||
|
|
||||||
|
# e.g. Vanilla E2 TTS, 32 NFE
|
||||||
|
accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "E2TTS_Base" -c 1200000 -t "seedtts_test_zh" -o "midpoint" -ss 0
|
||||||
|
accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "E2TTS_Base" -c 1200000 -t "seedtts_test_en" -o "midpoint" -ss 0
|
||||||
|
accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "E2TTS_Base" -c 1200000 -t "ls_pc_test_clean" -o "midpoint" -ss 0 -p data/LibriSpeech/test-clean
|
||||||
|
|
||||||
|
# e.g. evaluate F5-TTS 16 NFE result on Seed-TTS test-zh
|
||||||
|
python src/f5_tts/eval/eval_seedtts_testset.py -e wer -l zh --gen_wav_dir results/F5TTS_v1_Base_1250000/seedtts_test_zh/seed0_euler_nfe16_vocos_ss-1_cfg2.0_speed1.0 --gpu_nums 8
|
||||||
|
python src/f5_tts/eval/eval_seedtts_testset.py -e sim -l zh --gen_wav_dir results/F5TTS_v1_Base_1250000/seedtts_test_zh/seed0_euler_nfe16_vocos_ss-1_cfg2.0_speed1.0 --gpu_nums 8
|
||||||
|
python src/f5_tts/eval/eval_utmos.py --audio_dir results/F5TTS_v1_Base_1250000/seedtts_test_zh/seed0_euler_nfe16_vocos_ss-1_cfg2.0_speed1.0
|
||||||
|
|
||||||
|
# etc.
|
||||||
@@ -1,6 +1,7 @@
|
|||||||
# Evaluate with Librispeech test-clean, ~3s prompt to generate 4-10s audio (the way of valle/voicebox evaluation)
|
# Evaluate with Librispeech test-clean, ~3s prompt to generate 4-10s audio (the way of valle/voicebox evaluation)
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
|
import ast
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
@@ -25,11 +26,26 @@ def get_args():
|
|||||||
parser.add_argument("-l", "--lang", type=str, default="en")
|
parser.add_argument("-l", "--lang", type=str, default="en")
|
||||||
parser.add_argument("-g", "--gen_wav_dir", type=str, required=True)
|
parser.add_argument("-g", "--gen_wav_dir", type=str, required=True)
|
||||||
parser.add_argument("-p", "--librispeech_test_clean_path", type=str, required=True)
|
parser.add_argument("-p", "--librispeech_test_clean_path", type=str, required=True)
|
||||||
parser.add_argument("-n", "--gpu_nums", type=int, default=8, help="Number of GPUs to use")
|
parser.add_argument(
|
||||||
|
"-n", "--gpu_nums", type=str, default="8", help="Number of GPUs to use (e.g., 8) or GPU list (e.g., [0,1,2,3])"
|
||||||
|
)
|
||||||
parser.add_argument("--local", action="store_true", help="Use local custom checkpoint directory")
|
parser.add_argument("--local", action="store_true", help="Use local custom checkpoint directory")
|
||||||
return parser.parse_args()
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
def parse_gpu_nums(gpu_nums_str):
    """Convert the ``--gpu_nums`` command-line value into a list of GPU indices.

    Two spellings are accepted (surrounding whitespace is ignored):

    * a GPU count such as ``"8"``, expanded to ``[0, 1, ..., 7]``;
    * an explicit list literal such as ``"[0,1,2,3]"``, returned as written.

    Args:
        gpu_nums_str: Raw string taken from the command line.

    Returns:
        A non-empty list of non-negative ``int`` GPU indices.

    Raises:
        argparse.ArgumentTypeError: If the value is neither a positive integer
            nor a non-empty list of non-negative integers.
    """
    error_message = (
        f"Invalid GPU specification: {gpu_nums_str}. Use a number (e.g., 8) or a list (e.g., [0,1,2,3])"
    )
    spec = gpu_nums_str.strip()

    try:
        if spec.startswith("[") and spec.endswith("]"):
            # literal_eval only accepts Python literals, so arbitrary expressions are rejected.
            gpu_list = ast.literal_eval(spec)
        else:
            gpu_list = list(range(int(spec)))
    except (ValueError, SyntaxError, TypeError) as err:
        # TypeError covers literals such as "[{[]: 1}]" (unhashable dict key inside literal_eval).
        raise argparse.ArgumentTypeError(error_message) from err

    # An empty list (from "[]", "0" or a negative count) leaves no GPU to run on, and
    # non-integer entries cannot be device indices. bool is excluded explicitly because
    # it is a subclass of int.
    is_valid = (
        isinstance(gpu_list, list)
        and len(gpu_list) > 0
        and all(isinstance(gpu, int) and not isinstance(gpu, bool) and gpu >= 0 for gpu in gpu_list)
    )
    if not is_valid:
        raise argparse.ArgumentTypeError(error_message)

    return gpu_list
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
args = get_args()
|
args = get_args()
|
||||||
eval_task = args.eval_task
|
eval_task = args.eval_task
|
||||||
@@ -38,7 +54,7 @@ def main():
|
|||||||
gen_wav_dir = args.gen_wav_dir
|
gen_wav_dir = args.gen_wav_dir
|
||||||
metalst = rel_path + "/data/librispeech_pc_test_clean_cross_sentence.lst"
|
metalst = rel_path + "/data/librispeech_pc_test_clean_cross_sentence.lst"
|
||||||
|
|
||||||
gpus = list(range(args.gpu_nums))
|
gpus = parse_gpu_nums(args.gpu_nums)
|
||||||
test_set = get_librispeech_test(metalst, gen_wav_dir, gpus, librispeech_test_clean_path)
|
test_set = get_librispeech_test(metalst, gen_wav_dir, gpus, librispeech_test_clean_path)
|
||||||
|
|
||||||
## In LibriSpeech, some speakers utilized varying voice characteristics for different characters in the book,
|
## In LibriSpeech, some speakers utilized varying voice characteristics for different characters in the book,
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
# Evaluate with Seed-TTS testset
|
# Evaluate with Seed-TTS testset
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
|
import ast
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
@@ -24,11 +25,26 @@ def get_args():
|
|||||||
parser.add_argument("-e", "--eval_task", type=str, default="wer", choices=["sim", "wer"])
|
parser.add_argument("-e", "--eval_task", type=str, default="wer", choices=["sim", "wer"])
|
||||||
parser.add_argument("-l", "--lang", type=str, default="en", choices=["zh", "en"])
|
parser.add_argument("-l", "--lang", type=str, default="en", choices=["zh", "en"])
|
||||||
parser.add_argument("-g", "--gen_wav_dir", type=str, required=True)
|
parser.add_argument("-g", "--gen_wav_dir", type=str, required=True)
|
||||||
parser.add_argument("-n", "--gpu_nums", type=int, default=8, help="Number of GPUs to use")
|
parser.add_argument(
|
||||||
|
"-n", "--gpu_nums", type=str, default="8", help="Number of GPUs to use (e.g., 8) or GPU list (e.g., [0,1,2,3])"
|
||||||
|
)
|
||||||
parser.add_argument("--local", action="store_true", help="Use local custom checkpoint directory")
|
parser.add_argument("--local", action="store_true", help="Use local custom checkpoint directory")
|
||||||
return parser.parse_args()
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
def parse_gpu_nums(gpu_nums_str):
    """Convert the ``--gpu_nums`` command-line value into a list of GPU indices.

    Two spellings are accepted (surrounding whitespace is ignored):

    * a GPU count such as ``"8"``, expanded to ``[0, 1, ..., 7]``;
    * an explicit list literal such as ``"[0,1,2,3]"``, returned as written.

    Args:
        gpu_nums_str: Raw string taken from the command line.

    Returns:
        A non-empty list of non-negative ``int`` GPU indices.

    Raises:
        argparse.ArgumentTypeError: If the value is neither a positive integer
            nor a non-empty list of non-negative integers.
    """
    error_message = (
        f"Invalid GPU specification: {gpu_nums_str}. Use a number (e.g., 8) or a list (e.g., [0,1,2,3])"
    )
    spec = gpu_nums_str.strip()

    try:
        if spec.startswith("[") and spec.endswith("]"):
            # literal_eval only accepts Python literals, so arbitrary expressions are rejected.
            gpu_list = ast.literal_eval(spec)
        else:
            gpu_list = list(range(int(spec)))
    except (ValueError, SyntaxError, TypeError) as err:
        # TypeError covers literals such as "[{[]: 1}]" (unhashable dict key inside literal_eval).
        raise argparse.ArgumentTypeError(error_message) from err

    # An empty list (from "[]", "0" or a negative count) leaves no GPU to run on, and
    # non-integer entries cannot be device indices. bool is excluded explicitly because
    # it is a subclass of int.
    is_valid = (
        isinstance(gpu_list, list)
        and len(gpu_list) > 0
        and all(isinstance(gpu, int) and not isinstance(gpu, bool) and gpu >= 0 for gpu in gpu_list)
    )
    if not is_valid:
        raise argparse.ArgumentTypeError(error_message)

    return gpu_list
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
args = get_args()
|
args = get_args()
|
||||||
eval_task = args.eval_task
|
eval_task = args.eval_task
|
||||||
@@ -38,7 +54,7 @@ def main():
|
|||||||
|
|
||||||
# NOTE. paraformer-zh result will be slightly different according to the number of gpus, cuz batchsize is different
|
# NOTE. paraformer-zh result will be slightly different according to the number of gpus, cuz batchsize is different
|
||||||
# zh 1.254 seems a result of 4 workers wer_seed_tts
|
# zh 1.254 seems a result of 4 workers wer_seed_tts
|
||||||
gpus = list(range(args.gpu_nums))
|
gpus = parse_gpu_nums(args.gpu_nums)
|
||||||
test_set = get_seed_tts_test(metalst, gen_wav_dir, gpus)
|
test_set = get_seed_tts_test(metalst, gen_wav_dir, gpus)
|
||||||
|
|
||||||
local = args.local
|
local = args.local
|
||||||
|
|||||||
@@ -395,14 +395,21 @@ def run_sim(args):
|
|||||||
wav1, sr1 = torchaudio.load(gen_wav)
|
wav1, sr1 = torchaudio.load(gen_wav)
|
||||||
wav2, sr2 = torchaudio.load(prompt_wav)
|
wav2, sr2 = torchaudio.load(prompt_wav)
|
||||||
|
|
||||||
resample1 = torchaudio.transforms.Resample(orig_freq=sr1, new_freq=16000)
|
|
||||||
resample2 = torchaudio.transforms.Resample(orig_freq=sr2, new_freq=16000)
|
|
||||||
wav1 = resample1(wav1)
|
|
||||||
wav2 = resample2(wav2)
|
|
||||||
|
|
||||||
if use_gpu:
|
if use_gpu:
|
||||||
wav1 = wav1.cuda(device)
|
wav1 = wav1.cuda(device)
|
||||||
wav2 = wav2.cuda(device)
|
wav2 = wav2.cuda(device)
|
||||||
|
|
||||||
|
if sr1 != 16000:
|
||||||
|
resample1 = torchaudio.transforms.Resample(orig_freq=sr1, new_freq=16000)
|
||||||
|
if use_gpu:
|
||||||
|
resample1 = resample1.cuda(device)
|
||||||
|
wav1 = resample1(wav1)
|
||||||
|
if sr2 != 16000:
|
||||||
|
resample2 = torchaudio.transforms.Resample(orig_freq=sr2, new_freq=16000)
|
||||||
|
if use_gpu:
|
||||||
|
resample2 = resample2.cuda(device)
|
||||||
|
wav2 = resample2(wav2)
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
emb1 = model(wav1)
|
emb1 = model(wav1)
|
||||||
emb2 = model(wav2)
|
emb2 = model(wav2)
|
||||||
|
|||||||
32
src/f5_tts/scripts/count_max_epoch_precise.py
Normal file
32
src/f5_tts/scripts/count_max_epoch_precise.py
Normal file
@@ -0,0 +1,32 @@
|
|||||||
|
import math
|
||||||
|
|
||||||
|
from torch.utils.data import SequentialSampler
|
||||||
|
|
||||||
|
from f5_tts.model.dataset import DynamicBatchSampler, load_dataset
|
||||||
|
|
||||||
|
|
||||||
|
# Training setup to evaluate; edit these values to match the planned run.
gpus = 8
batch_size_per_gpu = 38400  # reported below in frames
max_samples_per_gpu = 64
max_updates = 1250000

# Build the batch sampler over the whole dataset, visited in sequential order.
train_dataset = load_dataset("Emilia_ZH_EN", "pinyin")
sampler = SequentialSampler(train_dataset)
batch_sampler = DynamicBatchSampler(
    sampler,
    batch_size_per_gpu,
    max_samples=max_samples_per_gpu,
    random_seed=666,
    drop_residual=False,
)

# Number of batches the sampler yields; dividing by the GPU count gives updates per epoch.
num_batches = len(batch_sampler)

epoch_summary = (
    f"One epoch has {num_batches / gpus} updates if gpus={gpus}, with "
    f"batch_size_per_gpu={batch_size_per_gpu} (frames) & "
    f"max_samples_per_gpu={max_samples_per_gpu}."
)
print(epoch_summary)

# Round up so that the configured number of epochs covers at least max_updates updates.
epoch_advice = (
    f"If gpus={gpus}, for max_updates={max_updates} "
    f"should set epoch={math.ceil(max_updates / num_batches * gpus)}."
)
print(epoch_advice)
|
||||||
Reference in New Issue
Block a user