mirror of
https://github.com/SWivid/F5-TTS.git
synced 2025-12-12 15:50:07 -08:00
Compare commits
16 Commits
65ada48a62
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9ae46c8360 | ||
|
|
3eecd94baa | ||
|
|
d9a69452ce | ||
|
|
bc15df2b57 | ||
|
|
9b2357a1b9 | ||
|
|
1dcb4e10f7 | ||
|
|
529d856133 | ||
|
|
7abadc4c72 | ||
|
|
e67d50841e | ||
|
|
6b07fb03b2 | ||
|
|
a051a68552 | ||
|
|
f2a4f8581f | ||
|
|
a17c5ae435 | ||
|
|
a0b8fb5df2 | ||
|
|
c8bfc3aa3d | ||
|
|
8d3ec72159 |
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
|
||||
|
||||
[project]
|
||||
name = "f5-tts"
|
||||
version = "1.1.9"
|
||||
version = "1.1.10"
|
||||
description = "F5-TTS: A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching"
|
||||
readme = "README.md"
|
||||
license = {text = "MIT License"}
|
||||
@@ -22,18 +22,19 @@ dependencies = [
|
||||
"ema_pytorch>=0.5.2",
|
||||
"gradio>=5.0.0",
|
||||
"hydra-core>=1.3.0",
|
||||
"jieba",
|
||||
"librosa",
|
||||
"matplotlib",
|
||||
"numpy<=1.26.4; python_version<='3.10'",
|
||||
"pydantic<=2.10.6",
|
||||
"pydub",
|
||||
"pypinyin",
|
||||
"rjieba",
|
||||
"safetensors",
|
||||
"soundfile",
|
||||
"tomli",
|
||||
"torch>=2.0.0",
|
||||
"torchaudio>=2.0.0",
|
||||
"torchcodec",
|
||||
"torchdiffeq",
|
||||
"tqdm>=4.65.0",
|
||||
"transformers",
|
||||
|
||||
@@ -154,8 +154,8 @@ if __name__ == "__main__":
|
||||
|
||||
wav, sr, spec = f5tts.infer(
|
||||
ref_file=str(files("f5_tts").joinpath("infer/examples/basic/basic_ref_en.wav")),
|
||||
ref_text="some call me nature, others call me mother nature.",
|
||||
gen_text="""I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring. Respect me and I'll nurture you; ignore me and you shall face the consequences.""",
|
||||
ref_text="Some call me nature, others call me mother nature.",
|
||||
gen_text="I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring.",
|
||||
file_wave=str(files("f5_tts").joinpath("../../tests/api_out.wav")),
|
||||
file_spec=str(files("f5_tts").joinpath("../../tests/api_out.png")),
|
||||
seed=None,
|
||||
|
||||
@@ -14,16 +14,20 @@ pip install -e .[eval]
|
||||
1. *Seed-TTS testset*: Download from [seed-tts-eval](https://github.com/BytedanceSpeech/seed-tts-eval).
|
||||
2. *LibriSpeech test-clean*: Download from [OpenSLR](http://www.openslr.org/12/).
|
||||
3. Unzip the downloaded datasets and place them in the `data/` directory.
|
||||
4. Update the path for *LibriSpeech test-clean* data in `src/f5_tts/eval/eval_infer_batch.py`
|
||||
5. Our filtered LibriSpeech-PC 4-10s subset: `data/librispeech_pc_test_clean_cross_sentence.lst`
|
||||
4. Our filtered LibriSpeech-PC 4-10s subset: `data/librispeech_pc_test_clean_cross_sentence.lst`
|
||||
|
||||
### Batch Inference for Test Set
|
||||
|
||||
To run batch inference for evaluations, execute the following commands:
|
||||
|
||||
```bash
|
||||
# batch inference for evaluations
|
||||
accelerate config # if not set before
|
||||
# if not setup accelerate config yet
|
||||
accelerate config
|
||||
|
||||
# if only perform inference
|
||||
bash src/f5_tts/eval/eval_infer_batch.sh --infer-only
|
||||
|
||||
# if inference and with corresponding evaluation, setup the following tools first
|
||||
bash src/f5_tts/eval/eval_infer_batch.sh
|
||||
```
|
||||
|
||||
@@ -35,9 +39,13 @@ bash src/f5_tts/eval/eval_infer_batch.sh
|
||||
2. English ASR Model: [Faster-Whisper](https://huggingface.co/Systran/faster-whisper-large-v3)
|
||||
3. WavLM Model: Download from [Google Drive](https://drive.google.com/file/d/1-aE1NfzpRCLxA4GUxX9ITI3F9LlbtEGP/view).
|
||||
|
||||
Then update in the following scripts with the paths you put evaluation model ckpts to.
|
||||
> [!NOTE]
|
||||
> ASR model will be automatically downloaded if `--local` not set for evaluation scripts.
|
||||
> Otherwise, you should update the `asr_ckpt_dir` path values in `eval_librispeech_test_clean.py` or `eval_seedtts_testset.py`.
|
||||
>
|
||||
> WavLM model must be downloaded and your `wavlm_ckpt_dir` path updated in `eval_librispeech_test_clean.py` and `eval_seedtts_testset.py`.
|
||||
|
||||
### Objective Evaluation
|
||||
### Objective Evaluation Examples
|
||||
|
||||
Update the path with your batch-inferenced results, and carry out WER / SIM / UTMOS evaluations:
|
||||
```bash
|
||||
@@ -50,3 +58,6 @@ python src/f5_tts/eval/eval_librispeech_test_clean.py --eval_task sim --gen_wav_
|
||||
# Evaluation [UTMOS]. --ext: Audio extension
|
||||
python src/f5_tts/eval/eval_utmos.py --audio_dir <WAV_DIR> --ext wav
|
||||
```
|
||||
|
||||
> [!NOTE]
|
||||
> Evaluation results can also be found in `_*_results.jsonl` files saved in `<GEN_WAV_DIR>`/`<WAV_DIR>`.
|
||||
|
||||
@@ -48,6 +48,11 @@ def main():
|
||||
parser.add_argument("-ss", "--swaysampling", default=-1, type=float)
|
||||
|
||||
parser.add_argument("-t", "--testset", required=True)
|
||||
parser.add_argument(
|
||||
"-p", "--librispeech_test_clean_path", default=f"{rel_path}/data/LibriSpeech/test-clean", type=str
|
||||
)
|
||||
|
||||
parser.add_argument("--local", action="store_true", help="Use local vocoder checkpoint directory")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
@@ -83,7 +88,7 @@ def main():
|
||||
|
||||
if testset == "ls_pc_test_clean":
|
||||
metalst = rel_path + "/data/librispeech_pc_test_clean_cross_sentence.lst"
|
||||
librispeech_test_clean_path = "<SOME_PATH>/LibriSpeech/test-clean" # test-clean path
|
||||
librispeech_test_clean_path = args.librispeech_test_clean_path
|
||||
metainfo = get_librispeech_test_clean_metainfo(metalst, librispeech_test_clean_path)
|
||||
|
||||
elif testset == "seedtts_test_zh":
|
||||
@@ -121,7 +126,7 @@ def main():
|
||||
)
|
||||
|
||||
# Vocoder model
|
||||
local = False
|
||||
local = args.local
|
||||
if mel_spec_type == "vocos":
|
||||
vocoder_local_path = "../checkpoints/charactr/vocos-mel-24khz"
|
||||
elif mel_spec_type == "bigvgan":
|
||||
@@ -155,7 +160,13 @@ def main():
|
||||
ckpt_path = ckpt_prefix + ".safetensors"
|
||||
else:
|
||||
print("Loading from self-organized training checkpoints rather than released pretrained.")
|
||||
ckpt_path = rel_path + f"/{model_cfg.ckpts.save_dir}/model_{ckpt_step}.pt"
|
||||
ckpt_prefix = rel_path + f"/{model_cfg.ckpts.save_dir}/model_{ckpt_step}"
|
||||
if os.path.exists(ckpt_prefix + ".pt"):
|
||||
ckpt_path = ckpt_prefix + ".pt"
|
||||
elif os.path.exists(ckpt_prefix + ".safetensors"):
|
||||
ckpt_path = ckpt_prefix + ".safetensors"
|
||||
else:
|
||||
raise ValueError("The checkpoint does not exist or cannot be found in given location.")
|
||||
|
||||
dtype = torch.float32 if mel_spec_type == "bigvgan" else None
|
||||
model = load_checkpoint(model, ckpt_path, device, dtype=dtype, use_ema=use_ema)
|
||||
|
||||
@@ -1,18 +1,116 @@
|
||||
#!/bin/bash
|
||||
set -e
|
||||
export PYTHONWARNINGS="ignore::UserWarning,ignore::FutureWarning"
|
||||
|
||||
# e.g. F5-TTS, 16 NFE
|
||||
accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "F5TTS_v1_Base" -t "seedtts_test_zh" -nfe 16
|
||||
accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "F5TTS_v1_Base" -t "seedtts_test_en" -nfe 16
|
||||
accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "F5TTS_v1_Base" -t "ls_pc_test_clean" -nfe 16
|
||||
# Configuration parameters
|
||||
MODEL_NAME="F5TTS_v1_Base"
|
||||
SEEDS=(0 1 2)
|
||||
CKPTSTEPS=(1250000)
|
||||
TASKS=("seedtts_test_zh" "seedtts_test_en" "ls_pc_test_clean")
|
||||
LS_TEST_CLEAN_PATH="data/LibriSpeech/test-clean"
|
||||
GPUS="[0,1,2,3,4,5,6,7]"
|
||||
OFFLINE_MODE=false
|
||||
|
||||
# e.g. Vanilla E2 TTS, 32 NFE
|
||||
accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "E2TTS_Base" -c 1200000 -t "seedtts_test_zh" -o "midpoint" -ss 0
|
||||
accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "E2TTS_Base" -c 1200000 -t "seedtts_test_en" -o "midpoint" -ss 0
|
||||
accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "E2TTS_Base" -c 1200000 -t "ls_pc_test_clean" -o "midpoint" -ss 0
|
||||
# Parse arguments
|
||||
if [ $OFFLINE_MODE = true ]; then
|
||||
LOCAL="--local"
|
||||
else
|
||||
LOCAL=""
|
||||
fi
|
||||
INFER_ONLY=false
|
||||
while [[ $# -gt 0 ]]; do
|
||||
case $1 in
|
||||
--infer-only)
|
||||
INFER_ONLY=true
|
||||
shift
|
||||
;;
|
||||
*)
|
||||
echo "======== Unknown parameter: $1"
|
||||
exit 1
|
||||
;;
|
||||
esac
|
||||
done
|
||||
|
||||
# e.g. evaluate F5-TTS 16 NFE result on Seed-TTS test-zh
|
||||
python src/f5_tts/eval/eval_seedtts_testset.py -e wer -l zh --gen_wav_dir results/F5TTS_v1_Base_1250000/seedtts_test_zh/seed0_euler_nfe32_vocos_ss-1_cfg2.0_speed1.0 --gpu_nums 8
|
||||
python src/f5_tts/eval/eval_seedtts_testset.py -e sim -l zh --gen_wav_dir results/F5TTS_v1_Base_1250000/seedtts_test_zh/seed0_euler_nfe32_vocos_ss-1_cfg2.0_speed1.0 --gpu_nums 8
|
||||
python src/f5_tts/eval/eval_utmos.py --audio_dir results/F5TTS_v1_Base_1250000/seedtts_test_zh/seed0_euler_nfe32_vocos_ss-1_cfg2.0_speed1.0
|
||||
echo "======== Starting F5-TTS batch evaluation task..."
|
||||
if [ "$INFER_ONLY" = true ]; then
|
||||
echo "======== Mode: Execute infer tasks only"
|
||||
else
|
||||
echo "======== Mode: Execute full pipeline (infer + eval)"
|
||||
fi
|
||||
|
||||
# etc.
|
||||
# Function: Execute eval tasks
|
||||
execute_eval_tasks() {
|
||||
local ckptstep=$1
|
||||
local seed=$2
|
||||
local task_name=$3
|
||||
|
||||
local gen_wav_dir="results/${MODEL_NAME}_${ckptstep}/${task_name}/seed${seed}_euler_nfe32_vocos_ss-1_cfg2.0_speed1.0"
|
||||
|
||||
echo ">>>>>>>> Starting eval task: ckptstep=${ckptstep}, seed=${seed}, task=${task_name}"
|
||||
|
||||
case $task_name in
|
||||
"seedtts_test_zh")
|
||||
python src/f5_tts/eval/eval_seedtts_testset.py -e wer -l zh -g "$gen_wav_dir" -n "$GPUS" $LOCAL
|
||||
python src/f5_tts/eval/eval_seedtts_testset.py -e sim -l zh -g "$gen_wav_dir" -n "$GPUS" $LOCAL
|
||||
python src/f5_tts/eval/eval_utmos.py --audio_dir "$gen_wav_dir"
|
||||
;;
|
||||
"seedtts_test_en")
|
||||
python src/f5_tts/eval/eval_seedtts_testset.py -e wer -l en -g "$gen_wav_dir" -n "$GPUS" $LOCAL
|
||||
python src/f5_tts/eval/eval_seedtts_testset.py -e sim -l en -g "$gen_wav_dir" -n "$GPUS" $LOCAL
|
||||
python src/f5_tts/eval/eval_utmos.py --audio_dir "$gen_wav_dir"
|
||||
;;
|
||||
"ls_pc_test_clean")
|
||||
python src/f5_tts/eval/eval_librispeech_test_clean.py -e wer -g "$gen_wav_dir" -n "$GPUS" -p "$LS_TEST_CLEAN_PATH" $LOCAL
|
||||
python src/f5_tts/eval/eval_librispeech_test_clean.py -e sim -g "$gen_wav_dir" -n "$GPUS" -p "$LS_TEST_CLEAN_PATH" $LOCAL
|
||||
python src/f5_tts/eval/eval_utmos.py --audio_dir "$gen_wav_dir"
|
||||
;;
|
||||
esac
|
||||
|
||||
echo ">>>>>>>> Completed eval task: ckptstep=${ckptstep}, seed=${seed}, task=${task_name}"
|
||||
}
|
||||
|
||||
# Main execution loop
|
||||
for ckptstep in "${CKPTSTEPS[@]}"; do
|
||||
echo "======== Processing ckptstep: ${ckptstep}"
|
||||
|
||||
for seed in "${SEEDS[@]}"; do
|
||||
echo "-------- Processing seed: ${seed}"
|
||||
|
||||
# Store eval task PIDs for current seed (if not infer-only mode)
|
||||
if [ "$INFER_ONLY" = false ]; then
|
||||
declare -a eval_pids
|
||||
fi
|
||||
|
||||
# Execute each infer task sequentially
|
||||
for task in "${TASKS[@]}"; do
|
||||
echo ">>>>>>>> Executing infer task: accelerate launch src/f5_tts/eval/eval_infer_batch.py -s ${seed} -n \"${MODEL_NAME}\" -t \"${task}\" -c ${ckptstep} $LOCAL"
|
||||
|
||||
# Execute infer task (foreground execution, wait for completion)
|
||||
accelerate launch src/f5_tts/eval/eval_infer_batch.py -s ${seed} -n "${MODEL_NAME}" -t "${task}" -c ${ckptstep} -p "${LS_TEST_CLEAN_PATH}" $LOCAL
|
||||
|
||||
# If not infer-only mode, launch corresponding eval task
|
||||
if [ "$INFER_ONLY" = false ]; then
|
||||
# Launch corresponding eval task (background execution, non-blocking for next infer)
|
||||
execute_eval_tasks $ckptstep $seed $task &
|
||||
eval_pids+=($!)
|
||||
fi
|
||||
done
|
||||
|
||||
# If not infer-only mode, wait for all eval tasks of current seed to complete
|
||||
if [ "$INFER_ONLY" = false ]; then
|
||||
echo ">>>>>>>> All infer tasks for seed ${seed} completed, waiting for corresponding eval tasks to finish..."
|
||||
|
||||
for pid in "${eval_pids[@]}"; do
|
||||
wait $pid
|
||||
done
|
||||
|
||||
unset eval_pids # Clean up array
|
||||
fi
|
||||
echo "-------- All eval tasks for seed ${seed} completed"
|
||||
done
|
||||
|
||||
echo "======== Completed ckptstep: ${ckptstep}"
|
||||
echo
|
||||
done
|
||||
|
||||
echo "======== All tasks completed!"
|
||||
18
src/f5_tts/eval/eval_infer_batch_example.sh
Normal file
18
src/f5_tts/eval/eval_infer_batch_example.sh
Normal file
@@ -0,0 +1,18 @@
|
||||
#!/bin/bash
|
||||
|
||||
# e.g. F5-TTS, 16 NFE
|
||||
accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "F5TTS_v1_Base" -t "seedtts_test_zh" -nfe 16
|
||||
accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "F5TTS_v1_Base" -t "seedtts_test_en" -nfe 16
|
||||
accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "F5TTS_v1_Base" -t "ls_pc_test_clean" -nfe 16 -p data/LibriSpeech/test-clean
|
||||
|
||||
# e.g. Vanilla E2 TTS, 32 NFE
|
||||
accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "E2TTS_Base" -c 1200000 -t "seedtts_test_zh" -o "midpoint" -ss 0
|
||||
accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "E2TTS_Base" -c 1200000 -t "seedtts_test_en" -o "midpoint" -ss 0
|
||||
accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "E2TTS_Base" -c 1200000 -t "ls_pc_test_clean" -o "midpoint" -ss 0 -p data/LibriSpeech/test-clean
|
||||
|
||||
# e.g. evaluate F5-TTS 16 NFE result on Seed-TTS test-zh
|
||||
python src/f5_tts/eval/eval_seedtts_testset.py -e wer -l zh --gen_wav_dir results/F5TTS_v1_Base_1250000/seedtts_test_zh/seed0_euler_nfe16_vocos_ss-1_cfg2.0_speed1.0 --gpu_nums 8
|
||||
python src/f5_tts/eval/eval_seedtts_testset.py -e sim -l zh --gen_wav_dir results/F5TTS_v1_Base_1250000/seedtts_test_zh/seed0_euler_nfe16_vocos_ss-1_cfg2.0_speed1.0 --gpu_nums 8
|
||||
python src/f5_tts/eval/eval_utmos.py --audio_dir results/F5TTS_v1_Base_1250000/seedtts_test_zh/seed0_euler_nfe16_vocos_ss-1_cfg2.0_speed1.0
|
||||
|
||||
# etc.
|
||||
@@ -1,6 +1,7 @@
|
||||
# Evaluate with Librispeech test-clean, ~3s prompt to generate 4-10s audio (the way of valle/voicebox evaluation)
|
||||
|
||||
import argparse
|
||||
import ast
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
@@ -25,11 +26,26 @@ def get_args():
|
||||
parser.add_argument("-l", "--lang", type=str, default="en")
|
||||
parser.add_argument("-g", "--gen_wav_dir", type=str, required=True)
|
||||
parser.add_argument("-p", "--librispeech_test_clean_path", type=str, required=True)
|
||||
parser.add_argument("-n", "--gpu_nums", type=int, default=8, help="Number of GPUs to use")
|
||||
parser.add_argument(
|
||||
"-n", "--gpu_nums", type=str, default="8", help="Number of GPUs to use (e.g., 8) or GPU list (e.g., [0,1,2,3])"
|
||||
)
|
||||
parser.add_argument("--local", action="store_true", help="Use local custom checkpoint directory")
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def parse_gpu_nums(gpu_nums_str):
|
||||
try:
|
||||
if gpu_nums_str.startswith("[") and gpu_nums_str.endswith("]"):
|
||||
gpu_list = ast.literal_eval(gpu_nums_str)
|
||||
if isinstance(gpu_list, list):
|
||||
return gpu_list
|
||||
return list(range(int(gpu_nums_str)))
|
||||
except (ValueError, SyntaxError):
|
||||
raise argparse.ArgumentTypeError(
|
||||
f"Invalid GPU specification: {gpu_nums_str}. Use a number (e.g., 8) or a list (e.g., [0,1,2,3])"
|
||||
)
|
||||
|
||||
|
||||
def main():
|
||||
args = get_args()
|
||||
eval_task = args.eval_task
|
||||
@@ -38,7 +54,7 @@ def main():
|
||||
gen_wav_dir = args.gen_wav_dir
|
||||
metalst = rel_path + "/data/librispeech_pc_test_clean_cross_sentence.lst"
|
||||
|
||||
gpus = list(range(args.gpu_nums))
|
||||
gpus = parse_gpu_nums(args.gpu_nums)
|
||||
test_set = get_librispeech_test(metalst, gen_wav_dir, gpus, librispeech_test_clean_path)
|
||||
|
||||
## In LibriSpeech, some speakers utilized varying voice characteristics for different characters in the book,
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
# Evaluate with Seed-TTS testset
|
||||
|
||||
import argparse
|
||||
import ast
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
@@ -24,11 +25,26 @@ def get_args():
|
||||
parser.add_argument("-e", "--eval_task", type=str, default="wer", choices=["sim", "wer"])
|
||||
parser.add_argument("-l", "--lang", type=str, default="en", choices=["zh", "en"])
|
||||
parser.add_argument("-g", "--gen_wav_dir", type=str, required=True)
|
||||
parser.add_argument("-n", "--gpu_nums", type=int, default=8, help="Number of GPUs to use")
|
||||
parser.add_argument(
|
||||
"-n", "--gpu_nums", type=str, default="8", help="Number of GPUs to use (e.g., 8) or GPU list (e.g., [0,1,2,3])"
|
||||
)
|
||||
parser.add_argument("--local", action="store_true", help="Use local custom checkpoint directory")
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def parse_gpu_nums(gpu_nums_str):
|
||||
try:
|
||||
if gpu_nums_str.startswith("[") and gpu_nums_str.endswith("]"):
|
||||
gpu_list = ast.literal_eval(gpu_nums_str)
|
||||
if isinstance(gpu_list, list):
|
||||
return gpu_list
|
||||
return list(range(int(gpu_nums_str)))
|
||||
except (ValueError, SyntaxError):
|
||||
raise argparse.ArgumentTypeError(
|
||||
f"Invalid GPU specification: {gpu_nums_str}. Use a number (e.g., 8) or a list (e.g., [0,1,2,3])"
|
||||
)
|
||||
|
||||
|
||||
def main():
|
||||
args = get_args()
|
||||
eval_task = args.eval_task
|
||||
@@ -38,7 +54,7 @@ def main():
|
||||
|
||||
# NOTE. paraformer-zh result will be slightly different according to the number of gpus, cuz batchsize is different
|
||||
# zh 1.254 seems a result of 4 workers wer_seed_tts
|
||||
gpus = list(range(args.gpu_nums))
|
||||
gpus = parse_gpu_nums(args.gpu_nums)
|
||||
test_set = get_seed_tts_test(metalst, gen_wav_dir, gpus)
|
||||
|
||||
local = args.local
|
||||
|
||||
@@ -395,14 +395,21 @@ def run_sim(args):
|
||||
wav1, sr1 = torchaudio.load(gen_wav)
|
||||
wav2, sr2 = torchaudio.load(prompt_wav)
|
||||
|
||||
resample1 = torchaudio.transforms.Resample(orig_freq=sr1, new_freq=16000)
|
||||
resample2 = torchaudio.transforms.Resample(orig_freq=sr2, new_freq=16000)
|
||||
wav1 = resample1(wav1)
|
||||
wav2 = resample2(wav2)
|
||||
|
||||
if use_gpu:
|
||||
wav1 = wav1.cuda(device)
|
||||
wav2 = wav2.cuda(device)
|
||||
|
||||
if sr1 != 16000:
|
||||
resample1 = torchaudio.transforms.Resample(orig_freq=sr1, new_freq=16000)
|
||||
if use_gpu:
|
||||
resample1 = resample1.cuda(device)
|
||||
wav1 = resample1(wav1)
|
||||
if sr2 != 16000:
|
||||
resample2 = torchaudio.transforms.Resample(orig_freq=sr2, new_freq=16000)
|
||||
if use_gpu:
|
||||
resample2 = resample2.cuda(device)
|
||||
wav2 = resample2(wav2)
|
||||
|
||||
with torch.no_grad():
|
||||
emb1 = model(wav1)
|
||||
emb2 = model(wav2)
|
||||
|
||||
@@ -6,12 +6,14 @@ nt - text sequence
|
||||
nw - raw wave length
|
||||
d - dimension
|
||||
"""
|
||||
# ruff: noqa: F722 F821
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
from x_transformers.x_transformers import RotaryEmbedding
|
||||
|
||||
from f5_tts.model.modules import (
|
||||
@@ -20,7 +22,6 @@ from f5_tts.model.modules import (
|
||||
ConvPositionEmbedding,
|
||||
DiTBlock,
|
||||
TimestepEmbedding,
|
||||
get_pos_embed_indices,
|
||||
precompute_freqs_cis,
|
||||
)
|
||||
|
||||
@@ -42,7 +43,7 @@ class TextEmbedding(nn.Module):
|
||||
|
||||
if conv_layers > 0:
|
||||
self.extra_modeling = True
|
||||
self.precompute_max_pos = 4096 # ~44s of 24khz audio
|
||||
self.precompute_max_pos = 8192 # 8192 is ~87.38s of 24khz audio; 4096 is ~43.69s of 24khz audio
|
||||
self.register_buffer("freqs_cis", precompute_freqs_cis(text_dim, self.precompute_max_pos), persistent=False)
|
||||
self.text_blocks = nn.Sequential(
|
||||
*[ConvNeXtV2Block(text_dim, text_dim * conv_mult) for _ in range(conv_layers)]
|
||||
@@ -50,33 +51,29 @@ class TextEmbedding(nn.Module):
|
||||
else:
|
||||
self.extra_modeling = False
|
||||
|
||||
def average_upsample_text_by_mask(self, text, text_mask, audio_mask):
|
||||
def average_upsample_text_by_mask(self, text, text_mask):
|
||||
batch, text_len, text_dim = text.shape
|
||||
|
||||
if audio_mask is None:
|
||||
audio_mask = torch.ones_like(text_mask, dtype=torch.bool)
|
||||
valid_mask = audio_mask & text_mask
|
||||
audio_lens = audio_mask.sum(dim=1) # [batch]
|
||||
valid_lens = valid_mask.sum(dim=1) # [batch]
|
||||
audio_len = text_len # cuz text already padded to same length as audio sequence
|
||||
text_lens = text_mask.sum(dim=1) # [batch]
|
||||
|
||||
upsampled_text = torch.zeros_like(text)
|
||||
|
||||
for i in range(batch):
|
||||
audio_len = audio_lens[i].item()
|
||||
valid_len = valid_lens[i].item()
|
||||
text_len = text_lens[i].item()
|
||||
|
||||
if valid_len == 0:
|
||||
if text_len == 0:
|
||||
continue
|
||||
|
||||
valid_ind = torch.where(valid_mask[i])[0]
|
||||
valid_data = text[i, valid_ind, :] # [valid_len, text_dim]
|
||||
valid_ind = torch.where(text_mask[i])[0]
|
||||
valid_data = text[i, valid_ind, :] # [text_len, text_dim]
|
||||
|
||||
base_repeat = audio_len // valid_len
|
||||
remainder = audio_len % valid_len
|
||||
base_repeat = audio_len // text_len
|
||||
remainder = audio_len % text_len
|
||||
|
||||
indices = []
|
||||
for j in range(valid_len):
|
||||
repeat_count = base_repeat + (1 if j >= valid_len - remainder else 0)
|
||||
for j in range(text_len):
|
||||
repeat_count = base_repeat + (1 if j >= text_len - remainder else 0)
|
||||
indices.extend([j] * repeat_count)
|
||||
|
||||
indices = torch.tensor(indices[:audio_len], device=text.device, dtype=torch.long)
|
||||
@@ -86,11 +83,10 @@ class TextEmbedding(nn.Module):
|
||||
|
||||
return upsampled_text
|
||||
|
||||
def forward(self, text: int["b nt"], seq_len, drop_text=False, audio_mask: bool["b n"] | None = None): # noqa: F722
|
||||
def forward(self, text: int["b nt"], seq_len, drop_text=False):
|
||||
text = text + 1 # use 0 as filler token. preprocess of batch pad -1, see list_str_to_idx()
|
||||
text = text[:, :seq_len] # curtail if character tokens are more than the mel spec tokens
|
||||
batch, text_len = text.shape[0], text.shape[1]
|
||||
text = F.pad(text, (0, seq_len - text_len), value=0) # (opt.) if not self.average_upsampling:
|
||||
text = F.pad(text, (0, seq_len - text.shape[1]), value=0) # (opt.) if not self.average_upsampling:
|
||||
if self.mask_padding:
|
||||
text_mask = text == 0
|
||||
|
||||
@@ -102,10 +98,7 @@ class TextEmbedding(nn.Module):
|
||||
# possible extra modeling
|
||||
if self.extra_modeling:
|
||||
# sinus pos emb
|
||||
batch_start = torch.zeros((batch,), device=text.device, dtype=torch.long)
|
||||
pos_idx = get_pos_embed_indices(batch_start, seq_len, max_pos=self.precompute_max_pos)
|
||||
text_pos_embed = self.freqs_cis[pos_idx]
|
||||
text = text + text_pos_embed
|
||||
text = text + self.freqs_cis[:seq_len, :]
|
||||
|
||||
# convnextv2 blocks
|
||||
if self.mask_padding:
|
||||
@@ -117,7 +110,7 @@ class TextEmbedding(nn.Module):
|
||||
text = self.text_blocks(text)
|
||||
|
||||
if self.average_upsampling:
|
||||
text = self.average_upsample_text_by_mask(text, ~text_mask, audio_mask)
|
||||
text = self.average_upsample_text_by_mask(text, ~text_mask)
|
||||
|
||||
return text
|
||||
|
||||
@@ -131,12 +124,19 @@ class InputEmbedding(nn.Module):
|
||||
self.proj = nn.Linear(mel_dim * 2 + text_dim, out_dim)
|
||||
self.conv_pos_embed = ConvPositionEmbedding(dim=out_dim)
|
||||
|
||||
def forward(self, x: float["b n d"], cond: float["b n d"], text_embed: float["b n d"], drop_audio_cond=False): # noqa: F722
|
||||
def forward(
|
||||
self,
|
||||
x: float["b n d"],
|
||||
cond: float["b n d"],
|
||||
text_embed: float["b n d"],
|
||||
drop_audio_cond=False,
|
||||
audio_mask: bool["b n"] | None = None,
|
||||
):
|
||||
if drop_audio_cond: # cfg for cond audio
|
||||
cond = torch.zeros_like(cond)
|
||||
|
||||
x = self.proj(torch.cat((x, cond, text_embed), dim=-1))
|
||||
x = self.conv_pos_embed(x) + x
|
||||
x = self.conv_pos_embed(x, mask=audio_mask) + x
|
||||
return x
|
||||
|
||||
|
||||
@@ -239,22 +239,36 @@ class DiT(nn.Module):
|
||||
drop_audio_cond: bool = False,
|
||||
drop_text: bool = False,
|
||||
cache: bool = True,
|
||||
audio_mask: bool["b n"] | None = None, # noqa: F722
|
||||
audio_mask: bool["b n"] | None = None,
|
||||
):
|
||||
seq_len = x.shape[1]
|
||||
if self.text_uncond is None or self.text_cond is None or not cache:
|
||||
if audio_mask is None:
|
||||
text_embed = self.text_embed(text, x.shape[1], drop_text=drop_text)
|
||||
else:
|
||||
batch = x.shape[0]
|
||||
seq_lens = audio_mask.sum(dim=1) # Calculate the actual sequence length for each sample
|
||||
text_embed_list = []
|
||||
for i in range(batch):
|
||||
text_embed_i = self.text_embed(
|
||||
text[i].unsqueeze(0),
|
||||
seq_len=seq_lens[i].item(),
|
||||
drop_text=drop_text,
|
||||
)
|
||||
text_embed_list.append(text_embed_i[0])
|
||||
text_embed = pad_sequence(text_embed_list, batch_first=True, padding_value=0)
|
||||
if cache:
|
||||
if drop_text:
|
||||
self.text_uncond = text_embed
|
||||
else:
|
||||
self.text_cond = text_embed
|
||||
|
||||
if cache:
|
||||
if drop_text:
|
||||
if self.text_uncond is None:
|
||||
self.text_uncond = self.text_embed(text, seq_len, drop_text=True, audio_mask=audio_mask)
|
||||
text_embed = self.text_uncond
|
||||
else:
|
||||
if self.text_cond is None:
|
||||
self.text_cond = self.text_embed(text, seq_len, drop_text=False, audio_mask=audio_mask)
|
||||
text_embed = self.text_cond
|
||||
else:
|
||||
text_embed = self.text_embed(text, seq_len, drop_text=drop_text, audio_mask=audio_mask)
|
||||
|
||||
x = self.input_embed(x, cond, text_embed, drop_audio_cond=drop_audio_cond)
|
||||
x = self.input_embed(x, cond, text_embed, drop_audio_cond=drop_audio_cond, audio_mask=audio_mask)
|
||||
|
||||
return x
|
||||
|
||||
@@ -263,11 +277,11 @@ class DiT(nn.Module):
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: float["b n d"], # nosied input audio # noqa: F722
|
||||
cond: float["b n d"], # masked cond audio # noqa: F722
|
||||
text: int["b nt"], # text # noqa: F722
|
||||
time: float["b"] | float[""], # time step # noqa: F821 F722
|
||||
mask: bool["b n"] | None = None, # noqa: F722
|
||||
x: float["b n d"], # nosied input audio
|
||||
cond: float["b n d"], # masked cond audio
|
||||
text: int["b nt"], # text
|
||||
time: float["b"] | float[""], # time step
|
||||
mask: bool["b n"] | None = None,
|
||||
drop_audio_cond: bool = False, # cfg for cond audio
|
||||
drop_text: bool = False, # cfg for text
|
||||
cfg_infer: bool = False, # cfg inference, pack cond & uncond forward
|
||||
|
||||
@@ -6,6 +6,7 @@ nt - text sequence
|
||||
nw - raw wave length
|
||||
d - dimension
|
||||
"""
|
||||
# ruff: noqa: F722 F821
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
@@ -36,7 +37,7 @@ class TextEmbedding(nn.Module):
|
||||
self.precompute_max_pos = 1024
|
||||
self.register_buffer("freqs_cis", precompute_freqs_cis(out_dim, self.precompute_max_pos), persistent=False)
|
||||
|
||||
def forward(self, text: int["b nt"], drop_text=False) -> int["b nt d"]: # noqa: F722
|
||||
def forward(self, text: int["b nt"], drop_text=False) -> int["b nt d"]:
|
||||
text = text + 1 # use 0 as filler token. preprocess of batch pad -1, see list_str_to_idx()
|
||||
if self.mask_padding:
|
||||
text_mask = text == 0
|
||||
@@ -69,7 +70,7 @@ class AudioEmbedding(nn.Module):
|
||||
self.linear = nn.Linear(2 * in_dim, out_dim)
|
||||
self.conv_pos_embed = ConvPositionEmbedding(out_dim)
|
||||
|
||||
def forward(self, x: float["b n d"], cond: float["b n d"], drop_audio_cond=False): # noqa: F722
|
||||
def forward(self, x: float["b n d"], cond: float["b n d"], drop_audio_cond=False):
|
||||
if drop_audio_cond:
|
||||
cond = torch.zeros_like(cond)
|
||||
x = torch.cat((x, cond), dim=-1)
|
||||
@@ -170,11 +171,11 @@ class MMDiT(nn.Module):
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: float["b n d"], # nosied input audio # noqa: F722
|
||||
cond: float["b n d"], # masked cond audio # noqa: F722
|
||||
text: int["b nt"], # text # noqa: F722
|
||||
time: float["b"] | float[""], # time step # noqa: F821 F722
|
||||
mask: bool["b n"] | None = None, # noqa: F722
|
||||
x: float["b n d"], # nosied input audio
|
||||
cond: float["b n d"], # masked cond audio
|
||||
text: int["b nt"], # text
|
||||
time: float["b"] | float[""], # time step
|
||||
mask: bool["b n"] | None = None,
|
||||
drop_audio_cond: bool = False, # cfg for cond audio
|
||||
drop_text: bool = False, # cfg for text
|
||||
cfg_infer: bool = False, # cfg inference, pack cond & uncond forward
|
||||
|
||||
@@ -6,6 +6,7 @@ nt - text sequence
|
||||
nw - raw wave length
|
||||
d - dimension
|
||||
"""
|
||||
# ruff: noqa: F722 F821
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
@@ -49,7 +50,7 @@ class TextEmbedding(nn.Module):
|
||||
else:
|
||||
self.extra_modeling = False
|
||||
|
||||
def forward(self, text: int["b nt"], seq_len, drop_text=False): # noqa: F722
|
||||
def forward(self, text: int["b nt"], seq_len, drop_text=False):
|
||||
text = text + 1 # use 0 as filler token. preprocess of batch pad -1, see list_str_to_idx()
|
||||
text = text[:, :seq_len] # curtail if character tokens are more than the mel spec tokens
|
||||
batch, text_len = text.shape[0], text.shape[1]
|
||||
@@ -91,7 +92,7 @@ class InputEmbedding(nn.Module):
|
||||
self.proj = nn.Linear(mel_dim * 2 + text_dim, out_dim)
|
||||
self.conv_pos_embed = ConvPositionEmbedding(dim=out_dim)
|
||||
|
||||
def forward(self, x: float["b n d"], cond: float["b n d"], text_embed: float["b n d"], drop_audio_cond=False): # noqa: F722
|
||||
def forward(self, x: float["b n d"], cond: float["b n d"], text_embed: float["b n d"], drop_audio_cond=False):
|
||||
if drop_audio_cond: # cfg for cond audio
|
||||
cond = torch.zeros_like(cond)
|
||||
|
||||
@@ -215,11 +216,11 @@ class UNetT(nn.Module):
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: float["b n d"], # nosied input audio # noqa: F722
|
||||
cond: float["b n d"], # masked cond audio # noqa: F722
|
||||
text: int["b nt"], # text # noqa: F722
|
||||
time: float["b"] | float[""], # time step # noqa: F821 F722
|
||||
mask: bool["b n"] | None = None, # noqa: F722
|
||||
x: float["b n d"], # nosied input audio
|
||||
cond: float["b n d"], # masked cond audio
|
||||
text: int["b nt"], # text
|
||||
time: float["b"] | float[""], # time step
|
||||
mask: bool["b n"] | None = None,
|
||||
drop_audio_cond: bool = False, # cfg for cond audio
|
||||
drop_text: bool = False, # cfg for text
|
||||
cfg_infer: bool = False, # cfg inference, pack cond & uncond forward
|
||||
|
||||
@@ -6,6 +6,7 @@ nt - text sequence
|
||||
nw - raw wave length
|
||||
d - dimension
|
||||
"""
|
||||
# ruff: noqa: F722 F821
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
@@ -82,17 +83,17 @@ class CFM(nn.Module):
|
||||
@torch.no_grad()
|
||||
def sample(
|
||||
self,
|
||||
cond: float["b n d"] | float["b nw"], # noqa: F722
|
||||
text: int["b nt"] | list[str], # noqa: F722
|
||||
duration: int | int["b"], # noqa: F821
|
||||
cond: float["b n d"] | float["b nw"],
|
||||
text: int["b nt"] | list[str],
|
||||
duration: int | int["b"],
|
||||
*,
|
||||
lens: int["b"] | None = None, # noqa: F821
|
||||
lens: int["b"] | None = None,
|
||||
steps=32,
|
||||
cfg_strength=1.0,
|
||||
sway_sampling_coef=None,
|
||||
seed: int | None = None,
|
||||
max_duration=4096,
|
||||
vocoder: Callable[[float["b d n"]], float["b nw"]] | None = None, # noqa: F722
|
||||
vocoder: Callable[[float["b d n"]], float["b nw"]] | None = None,
|
||||
use_epss=True,
|
||||
no_ref_audio=False,
|
||||
duplicate_test=False,
|
||||
@@ -229,10 +230,10 @@ class CFM(nn.Module):
|
||||
|
||||
def forward(
|
||||
self,
|
||||
inp: float["b n d"] | float["b nw"], # mel or raw wave # noqa: F722
|
||||
text: int["b nt"] | list[str], # noqa: F722
|
||||
inp: float["b n d"] | float["b nw"], # mel or raw wave
|
||||
text: int["b nt"] | list[str],
|
||||
*,
|
||||
lens: int["b"] | None = None, # noqa: F821
|
||||
lens: int["b"] | None = None,
|
||||
noise_scheduler: str | None = None,
|
||||
):
|
||||
# handle raw wave
|
||||
@@ -252,10 +253,9 @@ class CFM(nn.Module):
|
||||
assert text.shape[0] == batch
|
||||
|
||||
# lens and mask
|
||||
if not exists(lens):
|
||||
if not exists(lens): # if lens not acquired by trainer from collate_fn
|
||||
lens = torch.full((batch,), seq_len, device=device)
|
||||
|
||||
mask = lens_to_mask(lens, length=seq_len) # useless here, as collate_fn will pad to max length in batch
|
||||
mask = lens_to_mask(lens, length=seq_len)
|
||||
|
||||
# get a random span to mask out for training conditionally
|
||||
frac_lengths = torch.zeros((batch,), device=self.device).float().uniform_(*self.frac_lengths_mask)
|
||||
|
||||
@@ -6,7 +6,7 @@ nt - text sequence
|
||||
nw - raw wave length
|
||||
d - dimension
|
||||
"""
|
||||
# flake8: noqa
|
||||
# ruff: noqa: F722 F821
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
@@ -177,20 +177,23 @@ class ConvPositionEmbedding(nn.Module):
|
||||
nn.Conv1d(dim, dim, kernel_size, groups=groups, padding=kernel_size // 2),
|
||||
nn.Mish(),
|
||||
)
|
||||
self.layer_need_mask_idx = [i for i, layer in enumerate(self.conv1d) if isinstance(layer, nn.Conv1d)]
|
||||
|
||||
def forward(self, x: float["b n d"], mask: bool["b n"] | None = None):
|
||||
if mask is not None:
|
||||
mask = mask[..., None]
|
||||
x = x.masked_fill(~mask, 0.0)
|
||||
|
||||
x = x.permute(0, 2, 1)
|
||||
x = self.conv1d(x)
|
||||
out = x.permute(0, 2, 1)
|
||||
mask = mask.unsqueeze(1) # [B 1 N]
|
||||
x = x.permute(0, 2, 1) # [B D N]
|
||||
|
||||
if mask is not None:
|
||||
out = out.masked_fill(~mask, 0.0)
|
||||
x = x.masked_fill(~mask, 0.0)
|
||||
for i, block in enumerate(self.conv1d):
|
||||
x = block(x)
|
||||
if mask is not None and i in self.layer_need_mask_idx:
|
||||
x = x.masked_fill(~mask, 0.0)
|
||||
|
||||
return out
|
||||
x = x.permute(0, 2, 1) # [B N D]
|
||||
|
||||
return x
|
||||
|
||||
|
||||
# rotary positional embedding related
|
||||
@@ -435,8 +438,8 @@ class Attention(nn.Module):
|
||||
# Attention processor
|
||||
|
||||
if is_package_available("flash_attn"):
|
||||
from flash_attn import flash_attn_func, flash_attn_varlen_func
|
||||
from flash_attn.bert_padding import pad_input, unpad_input
|
||||
from flash_attn import flash_attn_varlen_func, flash_attn_func
|
||||
|
||||
|
||||
class AttnProcessor:
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
# ruff: noqa: F722 F821
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
@@ -5,7 +7,7 @@ import random
|
||||
from collections import defaultdict
|
||||
from importlib.resources import files
|
||||
|
||||
import jieba
|
||||
import rjieba
|
||||
import torch
|
||||
from pypinyin import Style, lazy_pinyin
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
@@ -48,7 +50,7 @@ def is_package_available(package_name: str) -> bool:
|
||||
# tensor helpers
|
||||
|
||||
|
||||
def lens_to_mask(t: int["b"], length: int | None = None) -> bool["b n"]: # noqa: F722 F821
|
||||
def lens_to_mask(t: int["b"], length: int | None = None) -> bool["b n"]:
|
||||
if not exists(length):
|
||||
length = t.amax()
|
||||
|
||||
@@ -56,7 +58,7 @@ def lens_to_mask(t: int["b"], length: int | None = None) -> bool["b n"]: # noqa
|
||||
return seq[None, :] < t[:, None]
|
||||
|
||||
|
||||
def mask_from_start_end_indices(seq_len: int["b"], start: int["b"], end: int["b"]): # noqa: F722 F821
|
||||
def mask_from_start_end_indices(seq_len: int["b"], start: int["b"], end: int["b"]):
|
||||
max_seq_len = seq_len.max().item()
|
||||
seq = torch.arange(max_seq_len, device=start.device).long()
|
||||
start_mask = seq[None, :] >= start[:, None]
|
||||
@@ -64,7 +66,7 @@ def mask_from_start_end_indices(seq_len: int["b"], start: int["b"], end: int["b"
|
||||
return start_mask & end_mask
|
||||
|
||||
|
||||
def mask_from_frac_lengths(seq_len: int["b"], frac_lengths: float["b"]): # noqa: F722 F821
|
||||
def mask_from_frac_lengths(seq_len: int["b"], frac_lengths: float["b"]):
|
||||
lengths = (frac_lengths * seq_len).long()
|
||||
max_start = seq_len - lengths
|
||||
|
||||
@@ -75,7 +77,7 @@ def mask_from_frac_lengths(seq_len: int["b"], frac_lengths: float["b"]): # noqa
|
||||
return mask_from_start_end_indices(seq_len, start, end)
|
||||
|
||||
|
||||
def maybe_masked_mean(t: float["b n d"], mask: bool["b n"] = None) -> float["b d"]: # noqa: F722
|
||||
def maybe_masked_mean(t: float["b n d"], mask: bool["b n"] = None) -> float["b d"]:
|
||||
if not exists(mask):
|
||||
return t.mean(dim=1)
|
||||
|
||||
@@ -87,7 +89,7 @@ def maybe_masked_mean(t: float["b n d"], mask: bool["b n"] = None) -> float["b d
|
||||
|
||||
|
||||
# simple utf-8 tokenizer, since paper went character based
|
||||
def list_str_to_tensor(text: list[str], padding_value=-1) -> int["b nt"]: # noqa: F722
|
||||
def list_str_to_tensor(text: list[str], padding_value=-1) -> int["b nt"]:
|
||||
list_tensors = [torch.tensor([*bytes(t, "UTF-8")]) for t in text] # ByT5 style
|
||||
text = pad_sequence(list_tensors, padding_value=padding_value, batch_first=True)
|
||||
return text
|
||||
@@ -98,7 +100,7 @@ def list_str_to_idx(
|
||||
text: list[str] | list[list[str]],
|
||||
vocab_char_map: dict[str, int], # {char: idx}
|
||||
padding_value=-1,
|
||||
) -> int["b nt"]: # noqa: F722
|
||||
) -> int["b nt"]:
|
||||
list_idx_tensors = [torch.tensor([vocab_char_map.get(c, 0) for c in t]) for t in text] # pinyin or char style
|
||||
text = pad_sequence(list_idx_tensors, padding_value=padding_value, batch_first=True)
|
||||
return text
|
||||
@@ -144,10 +146,6 @@ def get_tokenizer(dataset_name, tokenizer: str = "pinyin"):
|
||||
|
||||
|
||||
def convert_char_to_pinyin(text_list, polyphone=True):
|
||||
if jieba.dt.initialized is False:
|
||||
jieba.default_logger.setLevel(50) # CRITICAL
|
||||
jieba.initialize()
|
||||
|
||||
final_text_list = []
|
||||
custom_trans = str.maketrans(
|
||||
{";": ",", "“": '"', "”": '"', "‘": "'", "’": "'"}
|
||||
@@ -161,7 +159,7 @@ def convert_char_to_pinyin(text_list, polyphone=True):
|
||||
for text in text_list:
|
||||
char_list = []
|
||||
text = text.translate(custom_trans)
|
||||
for seg in jieba.cut(text):
|
||||
for seg in rjieba.cut(text):
|
||||
seg_byte_len = len(bytes(seg, "UTF-8"))
|
||||
if seg_byte_len == len(seg): # if pure alphabets and symbols
|
||||
if char_list and seg_byte_len > 1 and char_list[-1] not in " :'\"":
|
||||
|
||||
3
src/f5_tts/runtime/triton_trtllm/.gitignore
vendored
Normal file
3
src/f5_tts/runtime/triton_trtllm/.gitignore
vendored
Normal file
@@ -0,0 +1,3 @@
|
||||
# runtime/triton_trtllm related
|
||||
model.cache
|
||||
model_repo/
|
||||
@@ -1,3 +1,3 @@
|
||||
FROM nvcr.io/nvidia/tritonserver:24.12-py3
|
||||
RUN pip install tritonclient[grpc] tensorrt-llm==0.16.0 torchaudio==2.5.1 jieba pypinyin librosa vocos
|
||||
RUN pip install tritonclient[grpc] tensorrt-llm==0.16.0 torchaudio==2.5.1 rjieba pypinyin librosa vocos
|
||||
WORKDIR /workspace
|
||||
@@ -1,59 +1,68 @@
|
||||
## Triton Inference Serving Best Practice for F5-TTS
|
||||
|
||||
### Quick Start
|
||||
Directly launch the service using docker compose.
|
||||
### Setup
|
||||
#### Option 1: Quick Start
|
||||
```sh
|
||||
# TODO: support F5TTS_v1_Base
|
||||
MODEL=F5TTS_Base docker compose up
|
||||
# Directly launch the service using docker compose
|
||||
MODEL=F5TTS_v1_Base docker compose up
|
||||
```
|
||||
|
||||
### Build Image
|
||||
Build the docker image from scratch.
|
||||
#### Option 2: Build from scratch
|
||||
```sh
|
||||
# Build the docker image
|
||||
docker build . -f Dockerfile.server -t soar97/triton-f5-tts:24.12
|
||||
```
|
||||
|
||||
### Create Docker Container
|
||||
```sh
|
||||
# Create Docker Container
|
||||
your_mount_dir=/mnt:/mnt
|
||||
docker run -it --name "f5-server" --gpus all --net host -v $your_mount_dir --shm-size=2g soar97/triton-f5-tts:24.12
|
||||
```
|
||||
|
||||
### Export Models to TensorRT-LLM and Launch Server
|
||||
Inside docker container, we would follow the official guide of TensorRT-LLM to build qwen and whisper TensorRT-LLM engines. See [here](https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/whisper).
|
||||
### Build TensorRT-LLM Engines and Launch Server
|
||||
Inside docker container, we would follow the official guide of TensorRT-LLM to build qwen and whisper TensorRT-LLM engines. See [here](https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/models/core/whisper).
|
||||
```sh
|
||||
bash run.sh 0 4 F5TTS_Base
|
||||
# F5TTS_v1_Base | F5TTS_Base | F5TTS_v1_Small | F5TTS_Small
|
||||
bash run.sh 0 4 F5TTS_v1_Base
|
||||
```
|
||||
> [!NOTE]
|
||||
> If use custom checkpoint, set `ckpt_file` and `vocab_file` in `run.sh`.
|
||||
> Remember to used matched model version (`F5TTS_v1_*` for v1, `F5TTS_*` for v0).
|
||||
>
|
||||
> If use checkpoint of different structure, see `scripts/convert_checkpoint.py`, and perform modification if necessary.
|
||||
|
||||
> [!IMPORTANT]
|
||||
> If train or finetune with fp32, add `--dtype float32` flag when converting checkpoint in `run.sh` phase 1.
|
||||
|
||||
### HTTP Client
|
||||
```sh
|
||||
python3 client_http.py
|
||||
```
|
||||
|
||||
### Benchmark using Client-Server Mode
|
||||
### Benchmarking
|
||||
#### Using Client-Server Mode
|
||||
```sh
|
||||
# bash run.sh 5 5 F5TTS_v1_Base
|
||||
num_task=2
|
||||
python3 client_grpc.py --num-tasks $num_task --huggingface-dataset yuekai/seed_tts --split-name wenetspeech4tts
|
||||
```
|
||||
|
||||
### Benchmark using Offline TRT-LLM Mode
|
||||
#### Using Offline TRT-LLM Mode
|
||||
```sh
|
||||
# bash run.sh 7 7 F5TTS_v1_Base
|
||||
batch_size=1
|
||||
split_name=wenetspeech4tts
|
||||
backend_type=trt
|
||||
log_dir=./log_benchmark_batch_size_${batch_size}_${split_name}_${backend_type}
|
||||
log_dir=./tests/benchmark_batch_size_${batch_size}_${split_name}_${backend_type}
|
||||
rm -r $log_dir
|
||||
ln -s model_repo_f5_tts/f5_tts/1/f5_tts_trtllm.py ./
|
||||
torchrun --nproc_per_node=1 \
|
||||
benchmark.py --output-dir $log_dir \
|
||||
--batch-size $batch_size \
|
||||
--enable-warmup \
|
||||
--split-name $split_name \
|
||||
--model-path $F5_TTS_HF_DOWNLOAD_PATH/$model/model_1200000.pt \
|
||||
--vocab-file $F5_TTS_HF_DOWNLOAD_PATH/$model/vocab.txt \
|
||||
--vocoder-trt-engine-path $vocoder_trt_engine_path \
|
||||
--model-path $ckpt_file \
|
||||
--vocab-file $vocab_file \
|
||||
--vocoder-trt-engine-path $VOCODER_TRT_ENGINE_PATH \
|
||||
--backend-type $backend_type \
|
||||
--tllm-model-dir $F5_TTS_TRT_LLM_ENGINE_PATH || exit 1
|
||||
--tllm-model-dir $TRTLLM_ENGINE_DIR || exit 1
|
||||
```
|
||||
|
||||
### Benchmark Results
|
||||
@@ -66,4 +75,5 @@ Decoding on a single L20 GPU, using 26 different prompt_audio & target_text pair
|
||||
| F5-TTS Base (Vocos) | 1 (Batch_size) | - | 0.1467 | Offline Pytorch |
|
||||
|
||||
### Credits
|
||||
1. [F5-TTS-TRTLLM](https://github.com/Bigfishering/f5-tts-trtllm)
|
||||
1. [Yuekai Zhang](https://github.com/yuekaizhang)
|
||||
2. [F5-TTS-TRTLLM](https://github.com/Bigfishering/f5-tts-trtllm)
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
# Copyright (c) 2024 Tsinghua Univ. (authors: Xingchen Song)
|
||||
# 2025 (authors: Yuekai Zhang)
|
||||
# 2025 (authors: Yuekai Zhang)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@@ -19,39 +19,45 @@ benchmark.py --output-dir $log_dir \
|
||||
--batch-size $batch_size \
|
||||
--enable-warmup \
|
||||
--split-name $split_name \
|
||||
--model-path $F5_TTS_HF_DOWNLOAD_PATH/$model/model_1200000.pt \
|
||||
--vocab-file $F5_TTS_HF_DOWNLOAD_PATH/$model/vocab.txt \
|
||||
--model-path $CKPT_DIR/$model/model_1200000.pt \
|
||||
--vocab-file $CKPT_DIR/$model/vocab.txt \
|
||||
--vocoder-trt-engine-path $vocoder_trt_engine_path \
|
||||
--backend-type $backend_type \
|
||||
--tllm-model-dir $F5_TTS_TRT_LLM_ENGINE_PATH || exit 1
|
||||
--tllm-model-dir $TRTLLM_ENGINE_DIR || exit 1
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import importlib
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from typing import Dict, List, Union
|
||||
|
||||
import datasets
|
||||
import jieba
|
||||
import tensorrt as trt
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn.functional as F
|
||||
import torchaudio
|
||||
from datasets import load_dataset
|
||||
from f5_tts_trtllm import F5TTS
|
||||
from huggingface_hub import hf_hub_download
|
||||
from pypinyin import Style, lazy_pinyin
|
||||
from tensorrt_llm._utils import trt_dtype_to_torch
|
||||
from tensorrt_llm.logger import logger
|
||||
from tensorrt_llm.runtime.session import Session, TensorInfo
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
from torch.utils.data import DataLoader, DistributedSampler
|
||||
from tqdm import tqdm
|
||||
from vocos import Vocos
|
||||
|
||||
|
||||
sys.path.append(f"{os.path.dirname(os.path.abspath(__file__))}/../../../../src/")
|
||||
|
||||
from f5_tts.eval.utils_eval import padded_mel_batch
|
||||
from f5_tts.model.modules import get_vocos_mel_spectrogram
|
||||
from f5_tts.model.utils import convert_char_to_pinyin, get_tokenizer, list_str_to_idx
|
||||
|
||||
|
||||
F5TTS = importlib.import_module("model_repo_f5_tts.f5_tts.1.f5_tts_trtllm").F5TTS
|
||||
|
||||
torch.manual_seed(0)
|
||||
|
||||
|
||||
@@ -111,22 +117,20 @@ def get_args():
|
||||
return args
|
||||
|
||||
|
||||
def padded_mel_batch(ref_mels, max_seq_len):
|
||||
padded_ref_mels = []
|
||||
for mel in ref_mels:
|
||||
# pad along the last dimension
|
||||
padded_ref_mel = F.pad(mel, (0, 0, 0, max_seq_len - mel.shape[0]), value=0)
|
||||
padded_ref_mels.append(padded_ref_mel)
|
||||
padded_ref_mels = torch.stack(padded_ref_mels)
|
||||
return padded_ref_mels
|
||||
|
||||
|
||||
def data_collator(batch, vocab_char_map, device="cuda", use_perf=False):
|
||||
if use_perf:
|
||||
torch.cuda.nvtx.range_push("data_collator")
|
||||
target_sample_rate = 24000
|
||||
target_rms = 0.1
|
||||
ids, ref_mel_list, ref_mel_len_list, estimated_reference_target_mel_len, reference_target_texts_list = (
|
||||
(
|
||||
ids,
|
||||
ref_rms_list,
|
||||
ref_mel_list,
|
||||
ref_mel_len_list,
|
||||
estimated_reference_target_mel_len,
|
||||
reference_target_texts_list,
|
||||
) = (
|
||||
[],
|
||||
[],
|
||||
[],
|
||||
[],
|
||||
@@ -148,6 +152,7 @@ def data_collator(batch, vocab_char_map, device="cuda", use_perf=False):
|
||||
)
|
||||
ref_audio_org = torch.from_numpy(ref_audio_org).unsqueeze(0).float()
|
||||
ref_rms = torch.sqrt(torch.mean(torch.square(ref_audio_org)))
|
||||
ref_rms_list.append(ref_rms)
|
||||
if ref_rms < target_rms:
|
||||
ref_audio_org = ref_audio_org * target_rms / ref_rms
|
||||
|
||||
@@ -159,40 +164,31 @@ def data_collator(batch, vocab_char_map, device="cuda", use_perf=False):
|
||||
|
||||
if use_perf:
|
||||
torch.cuda.nvtx.range_push(f"mel_spectrogram {i}")
|
||||
ref_mel = mel_spectrogram(ref_audio, vocoder="vocos", device="cuda")
|
||||
ref_audio = ref_audio.to("cuda")
|
||||
ref_mel = get_vocos_mel_spectrogram(ref_audio).squeeze(0)
|
||||
if use_perf:
|
||||
torch.cuda.nvtx.range_pop()
|
||||
ref_mel = ref_mel.squeeze()
|
||||
ref_mel_len = ref_mel.shape[0]
|
||||
assert ref_mel.shape[1] == 100
|
||||
ref_mel_len = ref_mel.shape[-1]
|
||||
assert ref_mel.shape[0] == 100
|
||||
|
||||
ref_mel_list.append(ref_mel)
|
||||
ref_mel_len_list.append(ref_mel_len)
|
||||
|
||||
estimated_reference_target_mel_len.append(
|
||||
int(ref_mel.shape[0] * (1 + len(target_text.encode("utf-8")) / len(prompt_text.encode("utf-8"))))
|
||||
int(ref_mel_len * (1 + len(target_text.encode("utf-8")) / len(prompt_text.encode("utf-8"))))
|
||||
)
|
||||
|
||||
max_seq_len = max(estimated_reference_target_mel_len)
|
||||
ref_mel_batch = padded_mel_batch(ref_mel_list, max_seq_len)
|
||||
ref_mel_batch = padded_mel_batch(ref_mel_list)
|
||||
ref_mel_len_batch = torch.LongTensor(ref_mel_len_list)
|
||||
|
||||
pinyin_list = convert_char_to_pinyin(reference_target_texts_list, polyphone=True)
|
||||
text_pad_sequence = list_str_to_idx(pinyin_list, vocab_char_map)
|
||||
|
||||
for i, item in enumerate(text_pad_sequence):
|
||||
text_pad_sequence[i] = F.pad(
|
||||
item, (0, estimated_reference_target_mel_len[i] - len(item)), mode="constant", value=-1
|
||||
)
|
||||
text_pad_sequence[i] += 1 # WAR: 0 is reserved for padding token, hard coding in F5-TTS
|
||||
text_pad_sequence = pad_sequence(text_pad_sequence, padding_value=-1, batch_first=True).to(device)
|
||||
text_pad_sequence = F.pad(
|
||||
text_pad_sequence, (0, max_seq_len - text_pad_sequence.shape[1]), mode="constant", value=-1
|
||||
)
|
||||
if use_perf:
|
||||
torch.cuda.nvtx.range_pop()
|
||||
return {
|
||||
"ids": ids,
|
||||
"ref_rms_list": ref_rms_list,
|
||||
"ref_mel_batch": ref_mel_batch,
|
||||
"ref_mel_len_batch": ref_mel_len_batch,
|
||||
"text_pad_sequence": text_pad_sequence,
|
||||
@@ -216,72 +212,6 @@ def init_distributed():
|
||||
return world_size, local_rank, rank
|
||||
|
||||
|
||||
def get_tokenizer(vocab_file_path: str):
|
||||
"""
|
||||
tokenizer - "pinyin" do g2p for only chinese characters, need .txt vocab_file
|
||||
- "char" for char-wise tokenizer, need .txt vocab_file
|
||||
- "byte" for utf-8 tokenizer
|
||||
- "custom" if you're directly passing in a path to the vocab.txt you want to use
|
||||
vocab_size - if use "pinyin", all available pinyin types, common alphabets (also those with accent) and symbols
|
||||
- if use "char", derived from unfiltered character & symbol counts of custom dataset
|
||||
- if use "byte", set to 256 (unicode byte range)
|
||||
"""
|
||||
with open(vocab_file_path, "r", encoding="utf-8") as f:
|
||||
vocab_char_map = {}
|
||||
for i, char in enumerate(f):
|
||||
vocab_char_map[char[:-1]] = i
|
||||
vocab_size = len(vocab_char_map)
|
||||
return vocab_char_map, vocab_size
|
||||
|
||||
|
||||
def convert_char_to_pinyin(reference_target_texts_list, polyphone=True):
|
||||
final_reference_target_texts_list = []
|
||||
custom_trans = str.maketrans(
|
||||
{";": ",", "“": '"', "”": '"', "‘": "'", "’": "'"}
|
||||
) # add custom trans here, to address oov
|
||||
|
||||
def is_chinese(c):
|
||||
return "\u3100" <= c <= "\u9fff" # common chinese characters
|
||||
|
||||
for text in reference_target_texts_list:
|
||||
char_list = []
|
||||
text = text.translate(custom_trans)
|
||||
for seg in jieba.cut(text):
|
||||
seg_byte_len = len(bytes(seg, "UTF-8"))
|
||||
if seg_byte_len == len(seg): # if pure alphabets and symbols
|
||||
if char_list and seg_byte_len > 1 and char_list[-1] not in " :'\"":
|
||||
char_list.append(" ")
|
||||
char_list.extend(seg)
|
||||
elif polyphone and seg_byte_len == 3 * len(seg): # if pure east asian characters
|
||||
seg_ = lazy_pinyin(seg, style=Style.TONE3, tone_sandhi=True)
|
||||
for i, c in enumerate(seg):
|
||||
if is_chinese(c):
|
||||
char_list.append(" ")
|
||||
char_list.append(seg_[i])
|
||||
else: # if mixed characters, alphabets and symbols
|
||||
for c in seg:
|
||||
if ord(c) < 256:
|
||||
char_list.extend(c)
|
||||
elif is_chinese(c):
|
||||
char_list.append(" ")
|
||||
char_list.extend(lazy_pinyin(c, style=Style.TONE3, tone_sandhi=True))
|
||||
else:
|
||||
char_list.append(c)
|
||||
final_reference_target_texts_list.append(char_list)
|
||||
|
||||
return final_reference_target_texts_list
|
||||
|
||||
|
||||
def list_str_to_idx(
|
||||
text: Union[List[str], List[List[str]]],
|
||||
vocab_char_map: Dict[str, int], # {char: idx}
|
||||
padding_value=-1,
|
||||
):
|
||||
list_idx_tensors = [torch.tensor([vocab_char_map.get(c, 0) for c in t]) for t in text] # pinyin or char style
|
||||
# text = pad_sequence(list_idx_tensors, padding_value=padding_value, batch_first=True)
|
||||
return list_idx_tensors
|
||||
|
||||
|
||||
def load_vocoder(
|
||||
vocoder_name="vocos", is_local=False, local_path="", device="cuda", hf_cache_dir=None, vocoder_trt_engine_path=None
|
||||
):
|
||||
@@ -316,29 +246,11 @@ def load_vocoder(
|
||||
return vocoder
|
||||
|
||||
|
||||
def mel_spectrogram(waveform, vocoder="vocos", device="cuda"):
|
||||
if vocoder == "vocos":
|
||||
mel_stft = torchaudio.transforms.MelSpectrogram(
|
||||
sample_rate=24000,
|
||||
n_fft=1024,
|
||||
win_length=1024,
|
||||
hop_length=256,
|
||||
n_mels=100,
|
||||
power=1,
|
||||
center=True,
|
||||
normalized=False,
|
||||
norm=None,
|
||||
).to(device)
|
||||
mel = mel_stft(waveform.to(device))
|
||||
mel = mel.clamp(min=1e-5).log()
|
||||
return mel.transpose(1, 2)
|
||||
|
||||
|
||||
class VocosTensorRT:
|
||||
def __init__(self, engine_path="./vocos_vocoder.plan", stream=None):
|
||||
TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
|
||||
trt.init_libnvinfer_plugins(TRT_LOGGER, namespace="")
|
||||
logger.info(f"Loading vae engine from {engine_path}")
|
||||
logger.info(f"Loading vocoder engine from {engine_path}")
|
||||
self.engine_path = engine_path
|
||||
with open(engine_path, "rb") as f:
|
||||
engine_buffer = f.read()
|
||||
@@ -368,34 +280,37 @@ def main():
|
||||
world_size, local_rank, rank = init_distributed()
|
||||
device = torch.device(f"cuda:{local_rank}")
|
||||
|
||||
vocab_char_map, vocab_size = get_tokenizer(args.vocab_file)
|
||||
vocab_char_map, vocab_size = get_tokenizer(args.vocab_file, "custom")
|
||||
|
||||
tllm_model_dir = args.tllm_model_dir
|
||||
config_file = os.path.join(tllm_model_dir, "config.json")
|
||||
with open(config_file) as f:
|
||||
config = json.load(f)
|
||||
with open(os.path.join(tllm_model_dir, "config.json")) as f:
|
||||
tllm_model_config = json.load(f)
|
||||
if args.backend_type == "trt":
|
||||
model = F5TTS(
|
||||
config, debug_mode=False, tllm_model_dir=tllm_model_dir, model_path=args.model_path, vocab_size=vocab_size
|
||||
tllm_model_config,
|
||||
debug_mode=False,
|
||||
tllm_model_dir=tllm_model_dir,
|
||||
model_path=args.model_path,
|
||||
vocab_size=vocab_size,
|
||||
)
|
||||
elif args.backend_type == "pytorch":
|
||||
import sys
|
||||
|
||||
sys.path.append(f"{os.path.dirname(os.path.abspath(__file__))}/../../../../src/")
|
||||
from f5_tts.infer.utils_infer import load_model
|
||||
from f5_tts.model import DiT
|
||||
|
||||
F5TTS_model_cfg = dict(
|
||||
dim=1024,
|
||||
depth=22,
|
||||
heads=16,
|
||||
ff_mult=2,
|
||||
text_dim=512,
|
||||
conv_layers=4,
|
||||
pe_attn_head=1,
|
||||
text_mask_padding=False,
|
||||
pretrained_config = tllm_model_config["pretrained_config"]
|
||||
pt_model_config = dict(
|
||||
dim=pretrained_config["hidden_size"],
|
||||
depth=pretrained_config["num_hidden_layers"],
|
||||
heads=pretrained_config["num_attention_heads"],
|
||||
ff_mult=pretrained_config["ff_mult"],
|
||||
text_dim=pretrained_config["text_dim"],
|
||||
text_mask_padding=pretrained_config["text_mask_padding"],
|
||||
conv_layers=pretrained_config["conv_layers"],
|
||||
pe_attn_head=pretrained_config["pe_attn_head"],
|
||||
# attn_backend="flash_attn",
|
||||
# attn_mask_enabled=True,
|
||||
)
|
||||
model = load_model(DiT, F5TTS_model_cfg, args.model_path)
|
||||
model = load_model(DiT, pt_model_config, args.model_path)
|
||||
|
||||
vocoder = load_vocoder(
|
||||
vocoder_name=args.vocoder, device=device, vocoder_trt_engine_path=args.vocoder_trt_engine_path
|
||||
@@ -445,20 +360,23 @@ def main():
|
||||
ref_mels, ref_mel_lens = batch["ref_mel_batch"].to(device), batch["ref_mel_len_batch"].to(device)
|
||||
text_pad_seq = batch["text_pad_sequence"].to(device)
|
||||
total_mel_lens = batch["estimated_reference_target_mel_len"]
|
||||
cond_pad_seq = F.pad(ref_mels, (0, 0, 0, max(total_mel_lens) - ref_mels.shape[1], 0, 0))
|
||||
if args.backend_type == "trt":
|
||||
_ = model.sample(
|
||||
text_pad_seq, ref_mels, ref_mel_lens, total_mel_lens, remove_input_padding=args.remove_input_padding
|
||||
text_pad_seq,
|
||||
cond_pad_seq,
|
||||
ref_mel_lens,
|
||||
total_mel_lens,
|
||||
remove_input_padding=args.remove_input_padding,
|
||||
)
|
||||
elif args.backend_type == "pytorch":
|
||||
total_mel_lens = torch.tensor(total_mel_lens, device=device)
|
||||
with torch.inference_mode():
|
||||
text_pad_seq -= 1
|
||||
text_pad_seq[text_pad_seq == -2] = -1
|
||||
total_mel_lens = torch.tensor(total_mel_lens, device=device)
|
||||
generated, _ = model.sample(
|
||||
cond=ref_mels,
|
||||
text=text_pad_seq,
|
||||
duration=total_mel_lens,
|
||||
steps=16,
|
||||
steps=32,
|
||||
cfg_strength=2.0,
|
||||
sway_sampling_coef=-1,
|
||||
)
|
||||
@@ -478,13 +396,13 @@ def main():
|
||||
ref_mels, ref_mel_lens = batch["ref_mel_batch"].to(device), batch["ref_mel_len_batch"].to(device)
|
||||
text_pad_seq = batch["text_pad_sequence"].to(device)
|
||||
total_mel_lens = batch["estimated_reference_target_mel_len"]
|
||||
|
||||
cond_pad_seq = F.pad(ref_mels, (0, 0, 0, max(total_mel_lens) - ref_mels.shape[1], 0, 0))
|
||||
if args.use_perf:
|
||||
torch.cuda.nvtx.range_pop()
|
||||
if args.backend_type == "trt":
|
||||
generated, cost_time = model.sample(
|
||||
text_pad_seq,
|
||||
ref_mels,
|
||||
cond_pad_seq,
|
||||
ref_mel_lens,
|
||||
total_mel_lens,
|
||||
remove_input_padding=args.remove_input_padding,
|
||||
@@ -494,20 +412,20 @@ def main():
|
||||
total_mel_lens = torch.tensor(total_mel_lens, device=device)
|
||||
with torch.inference_mode():
|
||||
start_time = time.time()
|
||||
text_pad_seq -= 1
|
||||
text_pad_seq[text_pad_seq == -2] = -1
|
||||
generated, _ = model.sample(
|
||||
cond=ref_mels,
|
||||
text=text_pad_seq,
|
||||
duration=total_mel_lens,
|
||||
lens=ref_mel_lens,
|
||||
steps=16,
|
||||
steps=32,
|
||||
cfg_strength=2.0,
|
||||
sway_sampling_coef=-1,
|
||||
)
|
||||
cost_time = time.time() - start_time
|
||||
decoding_time += cost_time
|
||||
vocoder_start_time = time.time()
|
||||
target_rms = 0.1
|
||||
target_sample_rate = 24000
|
||||
for i, gen in enumerate(generated):
|
||||
gen = gen[ref_mel_lens[i] : total_mel_lens[i], :].unsqueeze(0)
|
||||
gen_mel_spec = gen.permute(0, 2, 1).to(torch.float32)
|
||||
@@ -519,13 +437,10 @@ def main():
|
||||
torch.cuda.nvtx.range_pop()
|
||||
else:
|
||||
generated_wave = vocoder(gen_mel_spec).squeeze(0).cpu()
|
||||
target_rms = 0.1
|
||||
target_sample_rate = 24_000
|
||||
# if ref_rms_list[i] < target_rms:
|
||||
# generated_wave = generated_wave * ref_rms_list[i] / target_rms
|
||||
rms = torch.sqrt(torch.mean(torch.square(generated_wave)))
|
||||
if rms < target_rms:
|
||||
generated_wave = generated_wave * target_rms / rms
|
||||
|
||||
if batch["ref_rms_list"][i] < target_rms:
|
||||
generated_wave = generated_wave * batch["ref_rms_list"][i] / target_rms
|
||||
|
||||
utt = batch["ids"][i]
|
||||
torchaudio.save(
|
||||
f"{args.output_dir}/{utt}.wav",
|
||||
|
||||
@@ -30,15 +30,6 @@ python3 client_grpc.py \
|
||||
--huggingface-dataset yuekai/seed_tts \
|
||||
--split-name test_zh \
|
||||
--log-dir ./log_concurrent_tasks_${num_task}
|
||||
|
||||
# For offline Spark-TTS-0.5B
|
||||
python3 client_grpc.py \
|
||||
--server-addr localhost \
|
||||
--model-name spark_tts \
|
||||
--num-tasks $num_task \
|
||||
--huggingface-dataset yuekai/seed_tts \
|
||||
--split-name wenetspeech4tts \
|
||||
--log-dir ./log_concurrent_tasks_${num_task}
|
||||
"""
|
||||
|
||||
import argparse
|
||||
@@ -176,8 +167,7 @@ def get_args():
|
||||
"--model-name",
|
||||
type=str,
|
||||
default="f5_tts",
|
||||
choices=["f5_tts", "spark_tts"],
|
||||
help="triton model_repo module name to request: transducer for k2, attention_rescoring for wenet offline, streaming_wenet for wenet streaming, infer_pipeline for paraformer large offline",
|
||||
help="triton model_repo module name to request",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
@@ -206,7 +196,7 @@ def get_args():
|
||||
"--log-dir",
|
||||
type=str,
|
||||
required=False,
|
||||
default="./tmp",
|
||||
default="./tests/client_grpc",
|
||||
help="log directory",
|
||||
)
|
||||
|
||||
@@ -230,8 +220,7 @@ def load_audio(wav_path, target_sample_rate=24000):
|
||||
if sample_rate != target_sample_rate:
|
||||
from scipy.signal import resample
|
||||
|
||||
num_samples = int(len(waveform) * (target_sample_rate / sample_rate))
|
||||
waveform = resample(waveform, num_samples)
|
||||
waveform = resample(waveform, int(len(waveform) * (target_sample_rate / sample_rate)))
|
||||
return waveform, target_sample_rate
|
||||
|
||||
|
||||
|
||||
@@ -24,6 +24,7 @@
|
||||
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
import argparse
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
import requests
|
||||
@@ -65,33 +66,32 @@ def get_args():
|
||||
"--model-name",
|
||||
type=str,
|
||||
default="f5_tts",
|
||||
choices=["f5_tts", "spark_tts"],
|
||||
help="triton model_repo module name to request",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--output-audio",
|
||||
type=str,
|
||||
default="output.wav",
|
||||
default="tests/client_http.wav",
|
||||
help="Path to save the output audio",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def prepare_request(
|
||||
samples,
|
||||
waveform,
|
||||
reference_text,
|
||||
target_text,
|
||||
sample_rate=24000,
|
||||
audio_save_dir: str = "./",
|
||||
):
|
||||
assert len(samples.shape) == 1, "samples should be 1D"
|
||||
lengths = np.array([[len(samples)]], dtype=np.int32)
|
||||
samples = samples.reshape(1, -1).astype(np.float32)
|
||||
assert len(waveform.shape) == 1, "waveform should be 1D"
|
||||
lengths = np.array([[len(waveform)]], dtype=np.int32)
|
||||
waveform = waveform.reshape(1, -1).astype(np.float32)
|
||||
|
||||
data = {
|
||||
"inputs": [
|
||||
{"name": "reference_wav", "shape": samples.shape, "datatype": "FP32", "data": samples.tolist()},
|
||||
{"name": "reference_wav", "shape": waveform.shape, "datatype": "FP32", "data": waveform.tolist()},
|
||||
{
|
||||
"name": "reference_wav_len",
|
||||
"shape": lengths.shape,
|
||||
@@ -109,16 +109,15 @@ def prepare_request(
|
||||
def load_audio(wav_path, target_sample_rate=24000):
|
||||
assert target_sample_rate == 24000, "hard coding in server"
|
||||
if isinstance(wav_path, dict):
|
||||
samples = wav_path["array"]
|
||||
waveform = wav_path["array"]
|
||||
sample_rate = wav_path["sampling_rate"]
|
||||
else:
|
||||
samples, sample_rate = sf.read(wav_path)
|
||||
waveform, sample_rate = sf.read(wav_path)
|
||||
if sample_rate != target_sample_rate:
|
||||
from scipy.signal import resample
|
||||
|
||||
num_samples = int(len(samples) * (target_sample_rate / sample_rate))
|
||||
samples = resample(samples, num_samples)
|
||||
return samples, target_sample_rate
|
||||
waveform = resample(waveform, int(len(waveform) * (target_sample_rate / sample_rate)))
|
||||
return waveform, target_sample_rate
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
@@ -128,11 +127,11 @@ if __name__ == "__main__":
|
||||
server_url = f"http://{server_url}"
|
||||
|
||||
url = f"{server_url}/v2/models/{args.model_name}/infer"
|
||||
samples, sr = load_audio(args.reference_audio)
|
||||
waveform, sr = load_audio(args.reference_audio)
|
||||
assert sr == 24000, "sample rate hardcoded in server"
|
||||
|
||||
samples = np.array(samples, dtype=np.float32)
|
||||
data = prepare_request(samples, args.reference_text, args.target_text)
|
||||
waveform = np.array(waveform, dtype=np.float32)
|
||||
data = prepare_request(waveform, args.reference_text, args.target_text)
|
||||
|
||||
rsp = requests.post(
|
||||
url, headers={"Content-Type": "application/json"}, json=data, verify=False, params={"request_id": "0"}
|
||||
@@ -140,4 +139,5 @@ if __name__ == "__main__":
|
||||
result = rsp.json()
|
||||
audio = result["outputs"][0]["data"]
|
||||
audio = np.array(audio, dtype=np.float32)
|
||||
os.makedirs(os.path.dirname(args.output_audio), exist_ok=True)
|
||||
sf.write(args.output_audio, audio, 24000, "PCM_16")
|
||||
|
||||
@@ -12,6 +12,7 @@ import torch.nn.functional as F
|
||||
from tensorrt_llm._utils import str_dtype_to_torch, trt_dtype_to_torch
|
||||
from tensorrt_llm.logger import logger
|
||||
from tensorrt_llm.runtime.session import Session
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
|
||||
|
||||
def remove_tensor_padding(input_tensor, input_tensor_lengths=None):
|
||||
@@ -32,26 +33,35 @@ def remove_tensor_padding(input_tensor, input_tensor_lengths=None):
|
||||
|
||||
|
||||
class TextEmbedding(nn.Module):
|
||||
def __init__(self, text_num_embeds, text_dim, conv_layers=0, conv_mult=2, precompute_max_pos=4096):
|
||||
def __init__(
|
||||
self, text_num_embeds, text_dim, mask_padding=True, conv_layers=0, conv_mult=2, precompute_max_pos=4096
|
||||
):
|
||||
super().__init__()
|
||||
self.text_embed = nn.Embedding(text_num_embeds + 1, text_dim) # use 0 as filler token
|
||||
self.mask_padding = mask_padding
|
||||
self.register_buffer("freqs_cis", precompute_freqs_cis(text_dim, precompute_max_pos), persistent=False)
|
||||
self.text_blocks = nn.Sequential(*[ConvNeXtV2Block(text_dim, text_dim * conv_mult) for _ in range(conv_layers)])
|
||||
|
||||
def forward(self, text):
|
||||
# only keep tensors with value not -1
|
||||
text_mask = text != -1
|
||||
text_pad_cut_off_index = text_mask.sum(dim=1).max()
|
||||
def forward(self, text, seq_len, drop_text=False):
|
||||
text = text + 1
|
||||
text = text[:, :seq_len] # curtail if character tokens are more than the mel spec tokens
|
||||
text = F.pad(text, (0, seq_len - text.shape[1]), value=0)
|
||||
if self.mask_padding:
|
||||
text_mask = text == 0
|
||||
|
||||
if drop_text: # cfg for text
|
||||
text = torch.zeros_like(text)
|
||||
|
||||
text = self.text_embed(text) # b n -> b n d
|
||||
text = text + self.freqs_cis[:seq_len, :]
|
||||
if self.mask_padding:
|
||||
text = text.masked_fill(text_mask.unsqueeze(-1).expand(-1, -1, text.size(-1)), 0.0)
|
||||
for block in self.text_blocks:
|
||||
text = block(text)
|
||||
text = text.masked_fill(text_mask.unsqueeze(-1).expand(-1, -1, text.size(-1)), 0.0)
|
||||
else:
|
||||
text = self.text_blocks(text)
|
||||
|
||||
text = text[:, :text_pad_cut_off_index]
|
||||
text = self.text_embed(text)
|
||||
text = text + self.freqs_cis[: text.shape[1], :]
|
||||
for block in self.text_blocks:
|
||||
text = block(text)
|
||||
# padding text to the original length
|
||||
# text shape: B,seq_len,C
|
||||
# pad at the second dimension
|
||||
text = F.pad(text, (0, 0, 0, text_mask.shape[1] - text.shape[1], 0, 0), value=0)
|
||||
return text
|
||||
|
||||
|
||||
@@ -112,20 +122,33 @@ def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, theta_resca
|
||||
return torch.cat([freqs_cos, freqs_sin], dim=-1)
|
||||
|
||||
|
||||
def load_checkpoint(ckpt_path, use_ema=True):
|
||||
checkpoint = torch.load(ckpt_path, weights_only=True)
|
||||
def get_text_embed_dict(ckpt_path, use_ema=True):
|
||||
ckpt_type = ckpt_path.split(".")[-1]
|
||||
if ckpt_type == "safetensors":
|
||||
from safetensors.torch import load_file
|
||||
|
||||
checkpoint = load_file(ckpt_path)
|
||||
else:
|
||||
checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=True)
|
||||
|
||||
if use_ema:
|
||||
if ckpt_type == "safetensors":
|
||||
checkpoint = {"ema_model_state_dict": checkpoint}
|
||||
checkpoint["model_state_dict"] = {
|
||||
k.replace("ema_model.", ""): v
|
||||
for k, v in checkpoint["ema_model_state_dict"].items()
|
||||
if k not in ["initted", "step"]
|
||||
}
|
||||
dict_state = checkpoint["model_state_dict"]
|
||||
else:
|
||||
if ckpt_type == "safetensors":
|
||||
checkpoint = {"model_state_dict": checkpoint}
|
||||
model_params = checkpoint["model_state_dict"]
|
||||
|
||||
text_embed_dict = {}
|
||||
for key in dict_state.keys():
|
||||
for key in model_params.keys():
|
||||
# transformer.text_embed.text_embed.weight -> text_embed.weight
|
||||
if "text_embed" in key:
|
||||
text_embed_dict[key.replace("transformer.text_embed.", "")] = dict_state[key]
|
||||
text_embed_dict[key.replace("transformer.text_embed.", "")] = model_params[key]
|
||||
return text_embed_dict
|
||||
|
||||
|
||||
@@ -194,18 +217,16 @@ class F5TTS(object):
|
||||
|
||||
self.max_mel_len = 4096
|
||||
self.text_embedding = TextEmbedding(
|
||||
text_num_embeds=vocab_size, text_dim=512, conv_layers=4, precompute_max_pos=self.max_mel_len
|
||||
text_num_embeds=vocab_size,
|
||||
text_dim=config["pretrained_config"]["text_dim"],
|
||||
mask_padding=config["pretrained_config"]["text_mask_padding"],
|
||||
conv_layers=config["pretrained_config"]["conv_layers"],
|
||||
precompute_max_pos=self.max_mel_len,
|
||||
).to(self.device)
|
||||
self.text_embedding.load_state_dict(load_checkpoint(model_path), strict=True)
|
||||
self.text_embedding.load_state_dict(get_text_embed_dict(model_path), strict=True)
|
||||
|
||||
self.target_audio_sample_rate = 24000
|
||||
self.target_rms = 0.15 # target rms for audio
|
||||
self.n_fft = 1024
|
||||
self.win_length = 1024
|
||||
self.hop_length = 256
|
||||
self.n_mel_channels = 100
|
||||
# self.max_mel_len = 3000
|
||||
self.head_dim = 64
|
||||
self.n_mel_channels = config["pretrained_config"]["mel_dim"]
|
||||
self.head_dim = config["pretrained_config"]["dim_head"]
|
||||
self.base_rescale_factor = 1.0
|
||||
self.interpolation_factor = 1.0
|
||||
base = 10000.0 * self.base_rescale_factor ** (self.head_dim / (self.head_dim - 2))
|
||||
@@ -214,14 +235,23 @@ class F5TTS(object):
|
||||
self.freqs = freqs.repeat_interleave(2, dim=-1).unsqueeze(0)
|
||||
self.rope_cos = self.freqs.cos().half()
|
||||
self.rope_sin = self.freqs.sin().half()
|
||||
self.nfe_steps = 16
|
||||
t = torch.linspace(0, 1, self.nfe_steps + 1, dtype=torch.float32)
|
||||
time_step = t + (-1.0) * (torch.cos(torch.pi * 0.5 * t) - 1 + t)
|
||||
|
||||
self.nfe_steps = 32
|
||||
epss = {
|
||||
5: [0, 2, 4, 8, 16, 32],
|
||||
6: [0, 2, 4, 6, 8, 16, 32],
|
||||
7: [0, 2, 4, 6, 8, 16, 24, 32],
|
||||
10: [0, 2, 4, 6, 8, 12, 16, 20, 24, 28, 32],
|
||||
12: [0, 2, 4, 6, 8, 10, 12, 14, 16, 20, 24, 28, 32],
|
||||
16: [0, 1, 2, 3, 4, 5, 6, 7, 8, 10, 12, 14, 16, 20, 24, 28, 32],
|
||||
}
|
||||
t = 1 / 32 * torch.tensor(epss.get(self.nfe_steps, list(range(self.nfe_steps + 1))), dtype=torch.float32)
|
||||
time_step = 1 - torch.cos(torch.pi * t / 2)
|
||||
delta_t = torch.diff(time_step)
|
||||
# WAR: hard coding 256 here
|
||||
tmp_dim = 256
|
||||
time_expand = torch.zeros((1, self.nfe_steps, tmp_dim), dtype=torch.float32)
|
||||
half_dim = tmp_dim // 2
|
||||
|
||||
freq_embed_dim = 256 # Warning: hard coding 256 here
|
||||
time_expand = torch.zeros((1, self.nfe_steps, freq_embed_dim), dtype=torch.float32)
|
||||
half_dim = freq_embed_dim // 2
|
||||
emb_factor = math.log(10000) / (half_dim - 1)
|
||||
emb_factor = 1000.0 * torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb_factor)
|
||||
for i in range(self.nfe_steps):
|
||||
@@ -344,7 +374,7 @@ class F5TTS(object):
|
||||
def sample(
|
||||
self,
|
||||
text_pad_sequence: torch.Tensor,
|
||||
ref_mel_batch: torch.Tensor,
|
||||
cond_pad_sequence: torch.Tensor,
|
||||
ref_mel_len_batch: torch.Tensor,
|
||||
estimated_reference_target_mel_len: List[int],
|
||||
remove_input_padding: bool = False,
|
||||
@@ -353,26 +383,43 @@ class F5TTS(object):
|
||||
if use_perf:
|
||||
torch.cuda.nvtx.range_push("text embedding")
|
||||
batch = text_pad_sequence.shape[0]
|
||||
max_seq_len = ref_mel_batch.shape[1]
|
||||
max_seq_len = cond_pad_sequence.shape[1]
|
||||
|
||||
text_pad_sequence_drop = torch.cat(
|
||||
(text_pad_sequence, torch.zeros((1, text_pad_sequence.shape[1]), dtype=torch.int32).to(self.device)), dim=0
|
||||
# get text_embed one by one to avoid misalignment
|
||||
text_and_drop_embedding_list = []
|
||||
for i in range(batch):
|
||||
text_embedding_i = self.text_embedding(
|
||||
text_pad_sequence[i].unsqueeze(0).to(self.device),
|
||||
estimated_reference_target_mel_len[i],
|
||||
drop_text=False,
|
||||
)
|
||||
text_embedding_drop_i = self.text_embedding(
|
||||
text_pad_sequence[i].unsqueeze(0).to(self.device),
|
||||
estimated_reference_target_mel_len[i],
|
||||
drop_text=True,
|
||||
)
|
||||
text_and_drop_embedding_list.extend([text_embedding_i[0], text_embedding_drop_i[0]])
|
||||
|
||||
# pad separately computed text_embed to form batch with max_seq_len
|
||||
text_and_drop_embedding = pad_sequence(
|
||||
text_and_drop_embedding_list,
|
||||
batch_first=True,
|
||||
padding_value=0,
|
||||
)
|
||||
text_embedding = text_and_drop_embedding[0::2]
|
||||
text_embedding_drop = text_and_drop_embedding[1::2]
|
||||
|
||||
text_embedding_drop_list = []
|
||||
for i in range(batch + 1):
|
||||
text_embedding_drop_list.append(self.text_embedding(text_pad_sequence_drop[i].unsqueeze(0).to(self.device)))
|
||||
text_embedding_drop_condition = torch.cat(text_embedding_drop_list, dim=0)
|
||||
|
||||
text_embedding = text_embedding_drop_condition[:-1]
|
||||
# text_embedding_drop B,T,C batch should be the same
|
||||
text_embedding_drop = text_embedding_drop_condition[-1].unsqueeze(0).repeat(batch, 1, 1)
|
||||
|
||||
noise = torch.randn_like(ref_mel_batch).to(self.device)
|
||||
noise = torch.randn_like(cond_pad_sequence).to(self.device)
|
||||
rope_cos = self.rope_cos[:, :max_seq_len, :].float().repeat(batch, 1, 1)
|
||||
rope_sin = self.rope_sin[:, :max_seq_len, :].float().repeat(batch, 1, 1)
|
||||
|
||||
cat_mel_text = torch.cat((ref_mel_batch, text_embedding), dim=-1)
|
||||
cat_mel_text = torch.cat(
|
||||
(
|
||||
cond_pad_sequence,
|
||||
text_embedding,
|
||||
),
|
||||
dim=-1,
|
||||
)
|
||||
cat_mel_text_drop = torch.cat(
|
||||
(
|
||||
torch.zeros((batch, max_seq_len, self.n_mel_channels), dtype=torch.float32).to(self.device),
|
||||
|
||||
@@ -26,9 +26,8 @@
|
||||
import json
|
||||
import os
|
||||
|
||||
import jieba
|
||||
import rjieba
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torchaudio
|
||||
import triton_python_backend_utils as pb_utils
|
||||
from f5_tts_trtllm import F5TTS
|
||||
@@ -67,7 +66,7 @@ def convert_char_to_pinyin(reference_target_texts_list, polyphone=True):
|
||||
for text in reference_target_texts_list:
|
||||
char_list = []
|
||||
text = text.translate(custom_trans)
|
||||
for seg in jieba.cut(text):
|
||||
for seg in rjieba.cut(text):
|
||||
seg_byte_len = len(bytes(seg, "UTF-8"))
|
||||
if seg_byte_len == len(seg): # if pure alphabets and symbols
|
||||
if char_list and seg_byte_len > 1 and char_list[-1] not in " :'\"":
|
||||
@@ -99,7 +98,8 @@ def list_str_to_idx(
|
||||
padding_value=-1,
|
||||
): # noqa: F722
|
||||
list_idx_tensors = [torch.tensor([vocab_char_map.get(c, 0) for c in t]) for t in text] # pinyin or char style
|
||||
return list_idx_tensors
|
||||
text = pad_sequence(list_idx_tensors, padding_value=padding_value, batch_first=True)
|
||||
return text
|
||||
|
||||
|
||||
class TritonPythonModel:
|
||||
@@ -107,13 +107,12 @@ class TritonPythonModel:
|
||||
self.use_perf = True
|
||||
self.device = torch.device("cuda")
|
||||
self.target_audio_sample_rate = 24000
|
||||
self.target_rms = 0.15 # target rms for audio
|
||||
self.target_rms = 0.1 # least rms when inference, normalize to if lower
|
||||
self.n_fft = 1024
|
||||
self.win_length = 1024
|
||||
self.hop_length = 256
|
||||
self.n_mel_channels = 100
|
||||
self.max_mel_len = 3000
|
||||
self.head_dim = 64
|
||||
self.max_mel_len = 4096
|
||||
|
||||
parameters = json.loads(args["model_config"])["parameters"]
|
||||
for key, value in parameters.items():
|
||||
@@ -181,7 +180,8 @@ class TritonPythonModel:
|
||||
reference_target_texts_list,
|
||||
estimated_reference_target_mel_len,
|
||||
reference_mel_len,
|
||||
) = [], [], [], [], []
|
||||
reference_rms_list,
|
||||
) = [], [], [], [], [], []
|
||||
mel_features_list = []
|
||||
if self.use_perf:
|
||||
torch.cuda.nvtx.range_push("preprocess")
|
||||
@@ -208,6 +208,7 @@ class TritonPythonModel:
|
||||
ref_rms = torch.sqrt(torch.mean(torch.square(wav)))
|
||||
if ref_rms < self.target_rms:
|
||||
wav = wav * self.target_rms / ref_rms
|
||||
reference_rms_list.append(ref_rms)
|
||||
if self.reference_sample_rate != self.target_audio_sample_rate:
|
||||
wav = self.resampler(wav)
|
||||
wav = wav.to(self.device)
|
||||
@@ -228,7 +229,7 @@ class TritonPythonModel:
|
||||
max_seq_len = min(max(estimated_reference_target_mel_len), self.max_mel_len)
|
||||
|
||||
batch = len(requests)
|
||||
mel_features = torch.zeros((batch, max_seq_len, self.n_mel_channels), dtype=torch.float16).to(self.device)
|
||||
mel_features = torch.zeros((batch, max_seq_len, self.n_mel_channels), dtype=torch.float32).to(self.device)
|
||||
for i, mel in enumerate(mel_features_list):
|
||||
mel_features[i, : mel.shape[1], :] = mel
|
||||
|
||||
@@ -237,15 +238,6 @@ class TritonPythonModel:
|
||||
pinyin_list = convert_char_to_pinyin(reference_target_texts_list, polyphone=True)
|
||||
text_pad_sequence = list_str_to_idx(pinyin_list, self.vocab_char_map)
|
||||
|
||||
for i, item in enumerate(text_pad_sequence):
|
||||
text_pad_sequence[i] = F.pad(
|
||||
item, (0, estimated_reference_target_mel_len[i] - len(item)), mode="constant", value=-1
|
||||
)
|
||||
text_pad_sequence[i] += 1 # WAR: 0 is reserved for padding token, hard coding in F5-TTS
|
||||
text_pad_sequence = pad_sequence(text_pad_sequence, padding_value=-1, batch_first=True).to(self.device)
|
||||
text_pad_sequence = F.pad(
|
||||
text_pad_sequence, (0, max_seq_len - text_pad_sequence.shape[1]), mode="constant", value=-1
|
||||
)
|
||||
if self.use_perf:
|
||||
torch.cuda.nvtx.range_pop()
|
||||
|
||||
@@ -262,13 +254,12 @@ class TritonPythonModel:
|
||||
|
||||
responses = []
|
||||
for i in range(batch):
|
||||
ref_me_len = reference_mel_len[i]
|
||||
ref_mel_len = reference_mel_len[i]
|
||||
estimated_mel_len = estimated_reference_target_mel_len[i]
|
||||
denoised_one_item = denoised[i, ref_me_len:estimated_mel_len, :].unsqueeze(0).transpose(1, 2)
|
||||
denoised_one_item = denoised[i, ref_mel_len:estimated_mel_len, :].unsqueeze(0).transpose(1, 2)
|
||||
audio = self.forward_vocoder(denoised_one_item)
|
||||
rms = torch.sqrt(torch.mean(torch.square(audio)))
|
||||
if rms < self.target_rms:
|
||||
audio = audio * self.target_rms / rms
|
||||
if reference_rms_list[i] < self.target_rms:
|
||||
audio = audio * reference_rms_list[i] / self.target_rms
|
||||
|
||||
audio = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(audio))
|
||||
inference_response = pb_utils.InferenceResponse(output_tensors=[audio])
|
||||
|
||||
@@ -4,11 +4,20 @@ import os
|
||||
import sys
|
||||
from collections import OrderedDict
|
||||
|
||||
import numpy as np
|
||||
import tensorrt as trt
|
||||
from tensorrt_llm._common import default_net
|
||||
|
||||
from ..._utils import str_dtype_to_trt
|
||||
from ...functional import Tensor, concat
|
||||
from ...functional import (
|
||||
Tensor,
|
||||
concat,
|
||||
constant,
|
||||
expand,
|
||||
shape,
|
||||
slice,
|
||||
unsqueeze,
|
||||
)
|
||||
from ...layers import Linear
|
||||
from ...module import Module, ModuleList
|
||||
from ...plugin import current_all_reduce_helper
|
||||
@@ -27,9 +36,9 @@ class InputEmbedding(Module):
|
||||
self.proj = Linear(mel_dim * 2 + text_dim, out_dim)
|
||||
self.conv_pos_embed = ConvPositionEmbedding(dim=out_dim)
|
||||
|
||||
def forward(self, x, cond):
|
||||
def forward(self, x, cond, mask=None):
|
||||
x = self.proj(concat([x, cond], dim=-1))
|
||||
return self.conv_pos_embed(x) + x
|
||||
return self.conv_pos_embed(x, mask=mask) + x
|
||||
|
||||
|
||||
class F5TTS(PretrainedModel):
|
||||
@@ -50,6 +59,7 @@ class F5TTS(PretrainedModel):
|
||||
dim_head=config.dim_head,
|
||||
ff_mult=config.ff_mult,
|
||||
dropout=config.dropout,
|
||||
pe_attn_head=config.pe_attn_head,
|
||||
)
|
||||
for _ in range(self.depth)
|
||||
]
|
||||
@@ -68,10 +78,26 @@ class F5TTS(PretrainedModel):
|
||||
input_lengths,
|
||||
scale=1.0,
|
||||
):
|
||||
if default_net().plugin_config.remove_input_padding:
|
||||
mask = None
|
||||
else:
|
||||
N = shape(noise, 1)
|
||||
B = shape(noise, 0)
|
||||
seq_len_2d = concat([1, N])
|
||||
max_position_embeddings = 4096
|
||||
# create position ids
|
||||
position_ids_buffer = constant(np.expand_dims(np.arange(max_position_embeddings).astype(np.int32), 0))
|
||||
tmp_position_ids = slice(position_ids_buffer, starts=[0, 0], sizes=seq_len_2d)
|
||||
tmp_position_ids = expand(tmp_position_ids, concat([B, N])) # [B, N]
|
||||
tmp_input_lengths = unsqueeze(input_lengths, 1) # [B, 1]
|
||||
tmp_input_lengths = expand(tmp_input_lengths, concat([B, N])) # [B, N]
|
||||
mask = tmp_position_ids < tmp_input_lengths # [B, N]
|
||||
mask = mask.cast("int32")
|
||||
|
||||
t = self.time_embed(time)
|
||||
x = self.input_embed(noise, cond)
|
||||
x = self.input_embed(noise, cond, mask=mask)
|
||||
for block in self.transformer_blocks:
|
||||
x = block(x, t, rope_cos=rope_cos, rope_sin=rope_sin, input_lengths=input_lengths, scale=scale)
|
||||
x = block(x, t, rope_cos=rope_cos, rope_sin=rope_sin, input_lengths=input_lengths, scale=scale, mask=mask)
|
||||
denoise = self.proj_out(self.norm_out(x, t))
|
||||
denoise.mark_output("denoised", self.dtype)
|
||||
return denoise
|
||||
@@ -79,13 +105,12 @@ class F5TTS(PretrainedModel):
|
||||
def prepare_inputs(self, **kwargs):
|
||||
max_batch_size = kwargs["max_batch_size"]
|
||||
batch_size_range = [2, 2, max_batch_size]
|
||||
mel_size = 100
|
||||
max_seq_len = 3000
|
||||
num_frames_range = [200, 2 * max_seq_len, max_seq_len * max_batch_size]
|
||||
hidden_size = 512
|
||||
concat_feature_dim = mel_size + hidden_size
|
||||
freq_embed_dim = 256
|
||||
head_dim = 64
|
||||
mel_size = self.config.mel_dim
|
||||
max_seq_len = 3000 # 4096
|
||||
num_frames_range = [mel_size * 2, max_seq_len * 2, max_seq_len * max_batch_size]
|
||||
concat_feature_dim = mel_size + self.config.text_dim
|
||||
freq_embed_dim = 256 # Warning: hard coding 256 here
|
||||
head_dim = self.config.dim_head
|
||||
mapping = self.config.mapping
|
||||
if mapping.tp_size > 1:
|
||||
current_all_reduce_helper().set_workspace_tensor(mapping, 1)
|
||||
|
||||
@@ -16,7 +16,6 @@ from ...functional import (
|
||||
chunk,
|
||||
concat,
|
||||
constant,
|
||||
expand,
|
||||
expand_dims,
|
||||
expand_dims_like,
|
||||
expand_mask,
|
||||
@@ -95,15 +94,24 @@ class ConvPositionEmbedding(Module):
|
||||
self.conv1d2 = Conv1d(dim, dim, kernel_size, groups=groups, padding=kernel_size // 2)
|
||||
self.mish = Mish()
|
||||
|
||||
def forward(self, x, mask=None): # noqa: F722
|
||||
def forward(self, x, mask=None):
|
||||
if default_net().plugin_config.remove_input_padding:
|
||||
x = unsqueeze(x, 0)
|
||||
x = permute(x, [0, 2, 1])
|
||||
x = self.mish(self.conv1d2(self.mish(self.conv1d1(x))))
|
||||
out = permute(x, [0, 2, 1])
|
||||
if mask is not None:
|
||||
mask = mask.view(concat([shape(mask, 0), 1, shape(mask, 1)])) # [B 1 N]
|
||||
mask = expand_dims_like(mask, x) # [B D N]
|
||||
mask = cast(mask, x.dtype)
|
||||
x = permute(x, [0, 2, 1]) # [B D N]
|
||||
|
||||
if mask is not None:
|
||||
x = self.mish(self.conv1d2(self.mish(self.conv1d1(x * mask) * mask)) * mask)
|
||||
else:
|
||||
x = self.mish(self.conv1d2(self.mish(self.conv1d1(x))))
|
||||
|
||||
x = permute(x, [0, 2, 1]) # [B N D]
|
||||
if default_net().plugin_config.remove_input_padding:
|
||||
out = squeeze(out, 0)
|
||||
return out
|
||||
x = squeeze(x, 0)
|
||||
return x
|
||||
|
||||
|
||||
class Attention(Module):
|
||||
@@ -185,6 +193,7 @@ class Attention(Module):
|
||||
rope_cos,
|
||||
rope_sin,
|
||||
input_lengths,
|
||||
mask=None,
|
||||
c=None, # context c
|
||||
scale=1.0,
|
||||
rope=None,
|
||||
@@ -227,29 +236,52 @@ def rotate_every_two_3dim(tensor: Tensor) -> Tensor:
|
||||
return out
|
||||
|
||||
|
||||
def apply_rotary_pos_emb_3dim(x, rope_cos, rope_sin):
|
||||
if default_net().plugin_config.remove_input_padding:
|
||||
rot_dim = shape(rope_cos, -1) # 64
|
||||
new_t_shape = concat([shape(x, 0), rot_dim]) # (-1, 64)
|
||||
x_ = slice(x, [0, 0], new_t_shape, [1, 1])
|
||||
end_dim = shape(x, -1) - shape(rope_cos, -1)
|
||||
new_t_unrotated_shape = concat([shape(x, 0), end_dim]) # (2, -1, 960)
|
||||
x_unrotated = slice(x, concat([0, rot_dim]), new_t_unrotated_shape, [1, 1])
|
||||
out = concat([x_ * rope_cos + rotate_every_two_3dim(x_) * rope_sin, x_unrotated], dim=-1)
|
||||
else:
|
||||
rot_dim = shape(rope_cos, 2) # 64
|
||||
new_t_shape = concat([shape(x, 0), shape(x, 1), rot_dim]) # (2, -1, 64)
|
||||
x_ = slice(x, [0, 0, 0], new_t_shape, [1, 1, 1])
|
||||
end_dim = shape(x, 2) - shape(rope_cos, 2)
|
||||
new_t_unrotated_shape = concat([shape(x, 0), shape(x, 1), end_dim]) # (2, -1, 960)
|
||||
x_unrotated = slice(x, concat([0, 0, rot_dim]), new_t_unrotated_shape, [1, 1, 1])
|
||||
out = concat([x_ * rope_cos + rotate_every_two_3dim(x_) * rope_sin, x_unrotated], dim=-1)
|
||||
def apply_rotary_pos_emb_3dim(x, rope_cos, rope_sin, pe_attn_head):
|
||||
full_dim = x.size(-1)
|
||||
head_dim = rope_cos.size(-1) # attn head dim, e.g. 64
|
||||
if pe_attn_head is None:
|
||||
pe_attn_head = full_dim // head_dim
|
||||
rotated_dim = head_dim * pe_attn_head
|
||||
|
||||
rotated_and_unrotated_list = []
|
||||
|
||||
if default_net().plugin_config.remove_input_padding: # for [N, D] input
|
||||
new_t_shape = concat([shape(x, 0), head_dim]) # (2, -1, 64)
|
||||
|
||||
for i in range(pe_attn_head):
|
||||
x_slice_i = slice(x, [0, i * 64], new_t_shape, [1, 1])
|
||||
x_rotated_i = x_slice_i * rope_cos + rotate_every_two_3dim(x_slice_i) * rope_sin
|
||||
rotated_and_unrotated_list.append(x_rotated_i)
|
||||
|
||||
new_t_unrotated_shape = concat([shape(x, 0), full_dim - rotated_dim]) # (2, -1, 1024 - 64 * pe_attn_head)
|
||||
x_unrotated = slice(x, concat([0, rotated_dim]), new_t_unrotated_shape, [1, 1])
|
||||
rotated_and_unrotated_list.append(x_unrotated)
|
||||
|
||||
else: # for [B, N, D] input
|
||||
new_t_shape = concat([shape(x, 0), shape(x, 1), head_dim]) # (2, -1, 64)
|
||||
|
||||
for i in range(pe_attn_head):
|
||||
x_slice_i = slice(x, [0, 0, i * 64], new_t_shape, [1, 1, 1])
|
||||
x_rotated_i = x_slice_i * rope_cos + rotate_every_two_3dim(x_slice_i) * rope_sin
|
||||
rotated_and_unrotated_list.append(x_rotated_i)
|
||||
|
||||
new_t_unrotated_shape = concat(
|
||||
[shape(x, 0), shape(x, 1), full_dim - rotated_dim]
|
||||
) # (2, -1, 1024 - 64 * pe_attn_head)
|
||||
x_unrotated = slice(x, concat([0, 0, rotated_dim]), new_t_unrotated_shape, [1, 1, 1])
|
||||
rotated_and_unrotated_list.append(x_unrotated)
|
||||
|
||||
out = concat(rotated_and_unrotated_list, dim=-1)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class AttnProcessor:
|
||||
def __init__(self):
|
||||
pass
|
||||
def __init__(
|
||||
self,
|
||||
pe_attn_head: Optional[int] = None, # number of attention head to apply rope, None for all
|
||||
):
|
||||
self.pe_attn_head = pe_attn_head
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
@@ -260,32 +292,21 @@ class AttnProcessor:
|
||||
input_lengths,
|
||||
scale=1.0,
|
||||
rope=None,
|
||||
mask=None,
|
||||
) -> torch.FloatTensor:
|
||||
query = attn.to_q(x)
|
||||
key = attn.to_k(x)
|
||||
value = attn.to_v(x)
|
||||
# k,v,q all (2,1226,1024)
|
||||
query = apply_rotary_pos_emb_3dim(query, rope_cos, rope_sin)
|
||||
key = apply_rotary_pos_emb_3dim(key, rope_cos, rope_sin)
|
||||
query = apply_rotary_pos_emb_3dim(query, rope_cos, rope_sin, self.pe_attn_head)
|
||||
key = apply_rotary_pos_emb_3dim(key, rope_cos, rope_sin, self.pe_attn_head)
|
||||
|
||||
# attention
|
||||
inner_dim = key.shape[-1]
|
||||
norm_factor = math.sqrt(attn.attention_head_size)
|
||||
q_scaling = 1.0 / norm_factor
|
||||
mask = None
|
||||
if not default_net().plugin_config.remove_input_padding:
|
||||
N = shape(x, 1)
|
||||
B = shape(x, 0)
|
||||
seq_len_2d = concat([1, N])
|
||||
max_position_embeddings = 4096
|
||||
# create position ids
|
||||
position_ids_buffer = constant(np.expand_dims(np.arange(max_position_embeddings).astype(np.int32), 0))
|
||||
tmp_position_ids = slice(position_ids_buffer, starts=[0, 0], sizes=seq_len_2d)
|
||||
tmp_position_ids = expand(tmp_position_ids, concat([B, N])) # BxL
|
||||
tmp_input_lengths = unsqueeze(input_lengths, 1) # Bx1
|
||||
tmp_input_lengths = expand(tmp_input_lengths, concat([B, N])) # BxL
|
||||
mask = tmp_position_ids < tmp_input_lengths # BxL
|
||||
mask = mask.cast("int32")
|
||||
if default_net().plugin_config.remove_input_padding:
|
||||
mask = None
|
||||
|
||||
if default_net().plugin_config.bert_attention_plugin:
|
||||
qkv = concat([query, key, value], dim=-1)
|
||||
@@ -354,12 +375,12 @@ class AttnProcessor:
|
||||
|
||||
# DiT Block
|
||||
class DiTBlock(Module):
|
||||
def __init__(self, dim, heads, dim_head, ff_mult=2, dropout=0.1):
|
||||
def __init__(self, dim, heads, dim_head, ff_mult=2, dropout=0.1, pe_attn_head=None):
|
||||
super().__init__()
|
||||
|
||||
self.attn_norm = AdaLayerNormZero(dim)
|
||||
self.attn = Attention(
|
||||
processor=AttnProcessor(),
|
||||
processor=AttnProcessor(pe_attn_head=pe_attn_head),
|
||||
dim=dim,
|
||||
heads=heads,
|
||||
dim_head=dim_head,
|
||||
@@ -370,14 +391,15 @@ class DiTBlock(Module):
|
||||
self.ff = FeedForward(dim=dim, mult=ff_mult, dropout=dropout)
|
||||
|
||||
def forward(
|
||||
self, x, t, rope_cos, rope_sin, input_lengths, scale=1.0, rope=ModuleNotFoundError
|
||||
self, x, t, rope_cos, rope_sin, input_lengths, scale=1.0, rope=ModuleNotFoundError, mask=None
|
||||
): # x: noised input, t: time embedding
|
||||
# pre-norm & modulation for attention input
|
||||
norm, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.attn_norm(x, emb=t)
|
||||
# attention
|
||||
# norm ----> (2,1226,1024)
|
||||
attn_output = self.attn(x=norm, rope_cos=rope_cos, rope_sin=rope_sin, input_lengths=input_lengths, scale=scale)
|
||||
|
||||
attn_output = self.attn(
|
||||
x=norm, rope_cos=rope_cos, rope_sin=rope_sin, input_lengths=input_lengths, scale=scale, mask=mask
|
||||
)
|
||||
# process attention output for input x
|
||||
if default_net().plugin_config.remove_input_padding:
|
||||
x = x + gate_msa * attn_output
|
||||
|
||||
@@ -1,24 +0,0 @@
|
||||
accelerate>=0.33.0
|
||||
bitsandbytes>0.37.0
|
||||
cached_path
|
||||
click
|
||||
datasets
|
||||
ema_pytorch>=0.5.2
|
||||
gradio>=3.45.2
|
||||
hydra-core>=1.3.0
|
||||
jieba
|
||||
librosa
|
||||
matplotlib
|
||||
numpy<=1.26.4
|
||||
pydub
|
||||
pypinyin
|
||||
safetensors
|
||||
soundfile
|
||||
tomli
|
||||
torch>=2.0.0
|
||||
# torchaudio>=2.0.0
|
||||
torchdiffeq
|
||||
tqdm>=4.65.0
|
||||
transformers
|
||||
x_transformers>=1.31.14
|
||||
packaging>=24.2
|
||||
@@ -1,64 +1,66 @@
|
||||
stage=$1
|
||||
stop_stage=$2
|
||||
model=$3 # F5TTS_Base
|
||||
model=$3 # F5TTS_v1_Base | F5TTS_Base | F5TTS_v1_Small | F5TTS_Small
|
||||
if [ -z "$model" ]; then
|
||||
echo "Model is none, using default model F5TTS_Base"
|
||||
model=F5TTS_Base
|
||||
model=F5TTS_v1_Base
|
||||
fi
|
||||
echo "Start stage: $stage, Stop stage: $stop_stage, Model: $model"
|
||||
export CUDA_VISIBLE_DEVICES=0
|
||||
|
||||
F5_TTS_HF_DOWNLOAD_PATH=./F5-TTS
|
||||
F5_TTS_TRT_LLM_CHECKPOINT_PATH=./trtllm_ckpt
|
||||
F5_TTS_TRT_LLM_ENGINE_PATH=./f5_trt_llm_engine
|
||||
CKPT_DIR=../../../../ckpts
|
||||
TRTLLM_CKPT_DIR=$CKPT_DIR/$model/trtllm_ckpt
|
||||
TRTLLM_ENGINE_DIR=$CKPT_DIR/$model/trtllm_engine
|
||||
|
||||
vocoder_trt_engine_path=vocos_vocoder.plan
|
||||
model_repo=./model_repo
|
||||
VOCODER_ONNX_PATH=$CKPT_DIR/vocos_vocoder.onnx
|
||||
VOCODER_TRT_ENGINE_PATH=$CKPT_DIR/vocos_vocoder.plan
|
||||
MODEL_REPO=./model_repo
|
||||
|
||||
if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
|
||||
echo "Downloading f5 tts from huggingface"
|
||||
huggingface-cli download SWivid/F5-TTS --local-dir $F5_TTS_HF_DOWNLOAD_PATH
|
||||
|
||||
echo "Downloading F5-TTS from huggingface"
|
||||
huggingface-cli download SWivid/F5-TTS $model/model_*.* $model/vocab.txt --local-dir $CKPT_DIR
|
||||
fi
|
||||
|
||||
ckpt_file=$(ls $CKPT_DIR/$model/model_*.* 2>/dev/null | sort -V | tail -1) # default select latest update
|
||||
vocab_file=$CKPT_DIR/$model/vocab.txt
|
||||
|
||||
if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
|
||||
echo "Converting checkpoint"
|
||||
python3 ./scripts/convert_checkpoint.py \
|
||||
--timm_ckpt "$F5_TTS_HF_DOWNLOAD_PATH/$model/model_1200000.pt" \
|
||||
--output_dir "$F5_TTS_TRT_LLM_CHECKPOINT_PATH" --model_name $model
|
||||
python3 scripts/convert_checkpoint.py \
|
||||
--pytorch_ckpt $ckpt_file \
|
||||
--output_dir $TRTLLM_CKPT_DIR --model_name $model
|
||||
python_package_path=/usr/local/lib/python3.12/dist-packages
|
||||
cp -r patch/* $python_package_path/tensorrt_llm/models
|
||||
trtllm-build --checkpoint_dir $F5_TTS_TRT_LLM_CHECKPOINT_PATH \
|
||||
trtllm-build --checkpoint_dir $TRTLLM_CKPT_DIR \
|
||||
--max_batch_size 8 \
|
||||
--output_dir $F5_TTS_TRT_LLM_ENGINE_PATH --remove_input_padding disable
|
||||
--output_dir $TRTLLM_ENGINE_DIR --remove_input_padding disable
|
||||
fi
|
||||
|
||||
if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
|
||||
echo "Exporting vocos vocoder"
|
||||
onnx_vocoder_path=vocos_vocoder.onnx
|
||||
python3 scripts/export_vocoder_to_onnx.py --vocoder vocos --output-path $onnx_vocoder_path
|
||||
bash scripts/export_vocos_trt.sh $onnx_vocoder_path $vocoder_trt_engine_path
|
||||
python3 scripts/export_vocoder_to_onnx.py --vocoder vocos --output-path $VOCODER_ONNX_PATH
|
||||
bash scripts/export_vocos_trt.sh $VOCODER_ONNX_PATH $VOCODER_TRT_ENGINE_PATH
|
||||
fi
|
||||
|
||||
if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
|
||||
echo "Building triton server"
|
||||
rm -r $model_repo
|
||||
cp -r ./model_repo_f5_tts $model_repo
|
||||
python3 scripts/fill_template.py -i $model_repo/f5_tts/config.pbtxt vocab:$F5_TTS_HF_DOWNLOAD_PATH/$model/vocab.txt,model:$F5_TTS_HF_DOWNLOAD_PATH/$model/model_1200000.pt,trtllm:$F5_TTS_TRT_LLM_ENGINE_PATH,vocoder:vocos
|
||||
cp $vocoder_trt_engine_path $model_repo/vocoder/1/vocoder.plan
|
||||
rm -r $MODEL_REPO
|
||||
cp -r ./model_repo_f5_tts $MODEL_REPO
|
||||
python3 scripts/fill_template.py -i $MODEL_REPO/f5_tts/config.pbtxt vocab:$vocab_file,model:$ckpt_file,trtllm:$TRTLLM_ENGINE_DIR,vocoder:vocos
|
||||
cp $VOCODER_TRT_ENGINE_PATH $MODEL_REPO/vocoder/1/vocoder.plan
|
||||
fi
|
||||
|
||||
if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
|
||||
echo "Starting triton server"
|
||||
tritonserver --model-repository=$model_repo
|
||||
tritonserver --model-repository=$MODEL_REPO
|
||||
fi
|
||||
|
||||
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
|
||||
echo "Testing triton server"
|
||||
num_task=1
|
||||
log_dir=./log_concurrent_tasks_${num_task}
|
||||
split_name=wenetspeech4tts
|
||||
log_dir=./tests/client_grpc_${model}_concurrent_${num_task}_${split_name}
|
||||
rm -r $log_dir
|
||||
python3 client_grpc.py --num-tasks $num_task --huggingface-dataset yuekai/seed_tts --split-name wenetspeech4tts --log-dir $log_dir
|
||||
python3 client_grpc.py --num-tasks $num_task --huggingface-dataset yuekai/seed_tts --split-name $split_name --log-dir $log_dir
|
||||
fi
|
||||
|
||||
if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
|
||||
@@ -66,45 +68,45 @@ if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
|
||||
audio=../../infer/examples/basic/basic_ref_en.wav
|
||||
reference_text="Some call me nature, others call me mother nature."
|
||||
target_text="I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring."
|
||||
python3 client_http.py --reference-audio $audio --reference-text "$reference_text" --target-text "$target_text"
|
||||
python3 client_http.py --reference-audio $audio --reference-text "$reference_text" --target-text "$target_text" --output-audio "./tests/client_http_$model.wav"
|
||||
fi
|
||||
|
||||
if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
|
||||
echo "TRT-LLM: offline decoding benchmark test"
|
||||
batch_size=1
|
||||
batch_size=2
|
||||
split_name=wenetspeech4tts
|
||||
backend_type=trt
|
||||
log_dir=./log_benchmark_batch_size_${batch_size}_${split_name}_${backend_type}
|
||||
log_dir=./tests/benchmark_${model}_batch_size_${batch_size}_${split_name}_${backend_type}
|
||||
rm -r $log_dir
|
||||
ln -s model_repo_f5_tts/f5_tts/1/f5_tts_trtllm.py ./
|
||||
torchrun --nproc_per_node=1 \
|
||||
benchmark.py --output-dir $log_dir \
|
||||
--batch-size $batch_size \
|
||||
--enable-warmup \
|
||||
--split-name $split_name \
|
||||
--model-path $F5_TTS_HF_DOWNLOAD_PATH/$model/model_1200000.pt \
|
||||
--vocab-file $F5_TTS_HF_DOWNLOAD_PATH/$model/vocab.txt \
|
||||
--vocoder-trt-engine-path $vocoder_trt_engine_path \
|
||||
--model-path $ckpt_file \
|
||||
--vocab-file $vocab_file \
|
||||
--vocoder-trt-engine-path $VOCODER_TRT_ENGINE_PATH \
|
||||
--backend-type $backend_type \
|
||||
--tllm-model-dir $F5_TTS_TRT_LLM_ENGINE_PATH || exit 1
|
||||
--tllm-model-dir $TRTLLM_ENGINE_DIR || exit 1
|
||||
fi
|
||||
|
||||
if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
|
||||
echo "Native Pytorch: offline decoding benchmark test"
|
||||
pip install -r requirements-pytorch.txt
|
||||
batch_size=1
|
||||
if ! python3 -c "import f5_tts" &> /dev/null; then
|
||||
pip install -e ../../../../
|
||||
fi
|
||||
batch_size=1 # set attn_mask_enabled=True if batching in actual use case
|
||||
split_name=wenetspeech4tts
|
||||
backend_type=pytorch
|
||||
log_dir=./log_benchmark_batch_size_${batch_size}_${split_name}_${backend_type}
|
||||
log_dir=./tests/benchmark_${model}_batch_size_${batch_size}_${split_name}_${backend_type}
|
||||
rm -r $log_dir
|
||||
ln -s model_repo_f5_tts/f5_tts/1/f5_tts_trtllm.py ./
|
||||
torchrun --nproc_per_node=1 \
|
||||
benchmark.py --output-dir $log_dir \
|
||||
--batch-size $batch_size \
|
||||
--split-name $split_name \
|
||||
--enable-warmup \
|
||||
--model-path $F5_TTS_HF_DOWNLOAD_PATH/$model/model_1200000.pt \
|
||||
--vocab-file $F5_TTS_HF_DOWNLOAD_PATH/$model/vocab.txt \
|
||||
--model-path $ckpt_file \
|
||||
--vocab-file $vocab_file \
|
||||
--backend-type $backend_type \
|
||||
--tllm-model-dir $F5_TTS_TRT_LLM_ENGINE_PATH || exit 1
|
||||
--tllm-model-dir $TRTLLM_ENGINE_DIR || exit 1
|
||||
fi
|
||||
@@ -23,168 +23,12 @@ def split_q_bias_tp(v, n_head, n_hidden, tensor_parallel, rank):
|
||||
return split_v.contiguous()
|
||||
|
||||
|
||||
FACEBOOK_DIT_NAME_MAPPING = {
|
||||
"^time_embed.time_mlp.0.weight$": "time_embed.mlp1.weight",
|
||||
"^time_embed.time_mlp.0.bias$": "time_embed.mlp1.bias",
|
||||
"^time_embed.time_mlp.2.weight$": "time_embed.mlp2.weight",
|
||||
"^time_embed.time_mlp.2.bias$": "time_embed.mlp2.bias",
|
||||
"^input_embed.conv_pos_embed.conv1d.0.weight$": "input_embed.conv_pos_embed.conv1d1.weight",
|
||||
"^input_embed.conv_pos_embed.conv1d.0.bias$": "input_embed.conv_pos_embed.conv1d1.bias",
|
||||
"^input_embed.conv_pos_embed.conv1d.2.weight$": "input_embed.conv_pos_embed.conv1d2.weight",
|
||||
"^input_embed.conv_pos_embed.conv1d.2.bias$": "input_embed.conv_pos_embed.conv1d2.bias",
|
||||
"^transformer_blocks.0.attn.to_out.0.weight$": "transformer_blocks.0.attn.to_out.weight",
|
||||
"^transformer_blocks.0.attn.to_out.0.bias$": "transformer_blocks.0.attn.to_out.bias",
|
||||
"^transformer_blocks.1.attn.to_out.0.weight$": "transformer_blocks.1.attn.to_out.weight",
|
||||
"^transformer_blocks.1.attn.to_out.0.bias$": "transformer_blocks.1.attn.to_out.bias",
|
||||
"^transformer_blocks.2.attn.to_out.0.weight$": "transformer_blocks.2.attn.to_out.weight",
|
||||
"^transformer_blocks.2.attn.to_out.0.bias$": "transformer_blocks.2.attn.to_out.bias",
|
||||
"^transformer_blocks.3.attn.to_out.0.weight$": "transformer_blocks.3.attn.to_out.weight",
|
||||
"^transformer_blocks.3.attn.to_out.0.bias$": "transformer_blocks.3.attn.to_out.bias",
|
||||
"^transformer_blocks.4.attn.to_out.0.weight$": "transformer_blocks.4.attn.to_out.weight",
|
||||
"^transformer_blocks.4.attn.to_out.0.bias$": "transformer_blocks.4.attn.to_out.bias",
|
||||
"^transformer_blocks.5.attn.to_out.0.weight$": "transformer_blocks.5.attn.to_out.weight",
|
||||
"^transformer_blocks.5.attn.to_out.0.bias$": "transformer_blocks.5.attn.to_out.bias",
|
||||
"^transformer_blocks.6.attn.to_out.0.weight$": "transformer_blocks.6.attn.to_out.weight",
|
||||
"^transformer_blocks.6.attn.to_out.0.bias$": "transformer_blocks.6.attn.to_out.bias",
|
||||
"^transformer_blocks.7.attn.to_out.0.weight$": "transformer_blocks.7.attn.to_out.weight",
|
||||
"^transformer_blocks.7.attn.to_out.0.bias$": "transformer_blocks.7.attn.to_out.bias",
|
||||
"^transformer_blocks.8.attn.to_out.0.weight$": "transformer_blocks.8.attn.to_out.weight",
|
||||
"^transformer_blocks.8.attn.to_out.0.bias$": "transformer_blocks.8.attn.to_out.bias",
|
||||
"^transformer_blocks.9.attn.to_out.0.weight$": "transformer_blocks.9.attn.to_out.weight",
|
||||
"^transformer_blocks.9.attn.to_out.0.bias$": "transformer_blocks.9.attn.to_out.bias",
|
||||
"^transformer_blocks.10.attn.to_out.0.weight$": "transformer_blocks.10.attn.to_out.weight",
|
||||
"^transformer_blocks.10.attn.to_out.0.bias$": "transformer_blocks.10.attn.to_out.bias",
|
||||
"^transformer_blocks.11.attn.to_out.0.weight$": "transformer_blocks.11.attn.to_out.weight",
|
||||
"^transformer_blocks.11.attn.to_out.0.bias$": "transformer_blocks.11.attn.to_out.bias",
|
||||
"^transformer_blocks.12.attn.to_out.0.weight$": "transformer_blocks.12.attn.to_out.weight",
|
||||
"^transformer_blocks.12.attn.to_out.0.bias$": "transformer_blocks.12.attn.to_out.bias",
|
||||
"^transformer_blocks.13.attn.to_out.0.weight$": "transformer_blocks.13.attn.to_out.weight",
|
||||
"^transformer_blocks.13.attn.to_out.0.bias$": "transformer_blocks.13.attn.to_out.bias",
|
||||
"^transformer_blocks.14.attn.to_out.0.weight$": "transformer_blocks.14.attn.to_out.weight",
|
||||
"^transformer_blocks.14.attn.to_out.0.bias$": "transformer_blocks.14.attn.to_out.bias",
|
||||
"^transformer_blocks.15.attn.to_out.0.weight$": "transformer_blocks.15.attn.to_out.weight",
|
||||
"^transformer_blocks.15.attn.to_out.0.bias$": "transformer_blocks.15.attn.to_out.bias",
|
||||
"^transformer_blocks.16.attn.to_out.0.weight$": "transformer_blocks.16.attn.to_out.weight",
|
||||
"^transformer_blocks.16.attn.to_out.0.bias$": "transformer_blocks.16.attn.to_out.bias",
|
||||
"^transformer_blocks.17.attn.to_out.0.weight$": "transformer_blocks.17.attn.to_out.weight",
|
||||
"^transformer_blocks.17.attn.to_out.0.bias$": "transformer_blocks.17.attn.to_out.bias",
|
||||
"^transformer_blocks.18.attn.to_out.0.weight$": "transformer_blocks.18.attn.to_out.weight",
|
||||
"^transformer_blocks.18.attn.to_out.0.bias$": "transformer_blocks.18.attn.to_out.bias",
|
||||
"^transformer_blocks.19.attn.to_out.0.weight$": "transformer_blocks.19.attn.to_out.weight",
|
||||
"^transformer_blocks.19.attn.to_out.0.bias$": "transformer_blocks.19.attn.to_out.bias",
|
||||
"^transformer_blocks.20.attn.to_out.0.weight$": "transformer_blocks.20.attn.to_out.weight",
|
||||
"^transformer_blocks.20.attn.to_out.0.bias$": "transformer_blocks.20.attn.to_out.bias",
|
||||
"^transformer_blocks.21.attn.to_out.0.weight$": "transformer_blocks.21.attn.to_out.weight",
|
||||
"^transformer_blocks.21.attn.to_out.0.bias$": "transformer_blocks.21.attn.to_out.bias",
|
||||
"^transformer_blocks.0.ff.ff.0.0.weight$": "transformer_blocks.0.ff.project_in.weight",
|
||||
"^transformer_blocks.0.ff.ff.0.0.bias$": "transformer_blocks.0.ff.project_in.bias",
|
||||
"^transformer_blocks.0.ff.ff.2.weight$": "transformer_blocks.0.ff.ff.weight",
|
||||
"^transformer_blocks.0.ff.ff.2.bias$": "transformer_blocks.0.ff.ff.bias",
|
||||
"^transformer_blocks.1.ff.ff.0.0.weight$": "transformer_blocks.1.ff.project_in.weight",
|
||||
"^transformer_blocks.1.ff.ff.0.0.bias$": "transformer_blocks.1.ff.project_in.bias",
|
||||
"^transformer_blocks.1.ff.ff.2.weight$": "transformer_blocks.1.ff.ff.weight",
|
||||
"^transformer_blocks.1.ff.ff.2.bias$": "transformer_blocks.1.ff.ff.bias",
|
||||
"^transformer_blocks.2.ff.ff.0.0.weight$": "transformer_blocks.2.ff.project_in.weight",
|
||||
"^transformer_blocks.2.ff.ff.0.0.bias$": "transformer_blocks.2.ff.project_in.bias",
|
||||
"^transformer_blocks.2.ff.ff.2.weight$": "transformer_blocks.2.ff.ff.weight",
|
||||
"^transformer_blocks.2.ff.ff.2.bias$": "transformer_blocks.2.ff.ff.bias",
|
||||
"^transformer_blocks.3.ff.ff.0.0.weight$": "transformer_blocks.3.ff.project_in.weight",
|
||||
"^transformer_blocks.3.ff.ff.0.0.bias$": "transformer_blocks.3.ff.project_in.bias",
|
||||
"^transformer_blocks.3.ff.ff.2.weight$": "transformer_blocks.3.ff.ff.weight",
|
||||
"^transformer_blocks.3.ff.ff.2.bias$": "transformer_blocks.3.ff.ff.bias",
|
||||
"^transformer_blocks.4.ff.ff.0.0.weight$": "transformer_blocks.4.ff.project_in.weight",
|
||||
"^transformer_blocks.4.ff.ff.0.0.bias$": "transformer_blocks.4.ff.project_in.bias",
|
||||
"^transformer_blocks.4.ff.ff.2.weight$": "transformer_blocks.4.ff.ff.weight",
|
||||
"^transformer_blocks.4.ff.ff.2.bias$": "transformer_blocks.4.ff.ff.bias",
|
||||
"^transformer_blocks.5.ff.ff.0.0.weight$": "transformer_blocks.5.ff.project_in.weight",
|
||||
"^transformer_blocks.5.ff.ff.0.0.bias$": "transformer_blocks.5.ff.project_in.bias",
|
||||
"^transformer_blocks.5.ff.ff.2.weight$": "transformer_blocks.5.ff.ff.weight",
|
||||
"^transformer_blocks.5.ff.ff.2.bias$": "transformer_blocks.5.ff.ff.bias",
|
||||
"^transformer_blocks.6.ff.ff.0.0.weight$": "transformer_blocks.6.ff.project_in.weight",
|
||||
"^transformer_blocks.6.ff.ff.0.0.bias$": "transformer_blocks.6.ff.project_in.bias",
|
||||
"^transformer_blocks.6.ff.ff.2.weight$": "transformer_blocks.6.ff.ff.weight",
|
||||
"^transformer_blocks.6.ff.ff.2.bias$": "transformer_blocks.6.ff.ff.bias",
|
||||
"^transformer_blocks.7.ff.ff.0.0.weight$": "transformer_blocks.7.ff.project_in.weight",
|
||||
"^transformer_blocks.7.ff.ff.0.0.bias$": "transformer_blocks.7.ff.project_in.bias",
|
||||
"^transformer_blocks.7.ff.ff.2.weight$": "transformer_blocks.7.ff.ff.weight",
|
||||
"^transformer_blocks.7.ff.ff.2.bias$": "transformer_blocks.7.ff.ff.bias",
|
||||
"^transformer_blocks.8.ff.ff.0.0.weight$": "transformer_blocks.8.ff.project_in.weight",
|
||||
"^transformer_blocks.8.ff.ff.0.0.bias$": "transformer_blocks.8.ff.project_in.bias",
|
||||
"^transformer_blocks.8.ff.ff.2.weight$": "transformer_blocks.8.ff.ff.weight",
|
||||
"^transformer_blocks.8.ff.ff.2.bias$": "transformer_blocks.8.ff.ff.bias",
|
||||
"^transformer_blocks.9.ff.ff.0.0.weight$": "transformer_blocks.9.ff.project_in.weight",
|
||||
"^transformer_blocks.9.ff.ff.0.0.bias$": "transformer_blocks.9.ff.project_in.bias",
|
||||
"^transformer_blocks.9.ff.ff.2.weight$": "transformer_blocks.9.ff.ff.weight",
|
||||
"^transformer_blocks.9.ff.ff.2.bias$": "transformer_blocks.9.ff.ff.bias",
|
||||
"^transformer_blocks.10.ff.ff.0.0.weight$": "transformer_blocks.10.ff.project_in.weight",
|
||||
"^transformer_blocks.10.ff.ff.0.0.bias$": "transformer_blocks.10.ff.project_in.bias",
|
||||
"^transformer_blocks.10.ff.ff.2.weight$": "transformer_blocks.10.ff.ff.weight",
|
||||
"^transformer_blocks.10.ff.ff.2.bias$": "transformer_blocks.10.ff.ff.bias",
|
||||
"^transformer_blocks.11.ff.ff.0.0.weight$": "transformer_blocks.11.ff.project_in.weight",
|
||||
"^transformer_blocks.11.ff.ff.0.0.bias$": "transformer_blocks.11.ff.project_in.bias",
|
||||
"^transformer_blocks.11.ff.ff.2.weight$": "transformer_blocks.11.ff.ff.weight",
|
||||
"^transformer_blocks.11.ff.ff.2.bias$": "transformer_blocks.11.ff.ff.bias",
|
||||
"^transformer_blocks.12.ff.ff.0.0.weight$": "transformer_blocks.12.ff.project_in.weight",
|
||||
"^transformer_blocks.12.ff.ff.0.0.bias$": "transformer_blocks.12.ff.project_in.bias",
|
||||
"^transformer_blocks.12.ff.ff.2.weight$": "transformer_blocks.12.ff.ff.weight",
|
||||
"^transformer_blocks.12.ff.ff.2.bias$": "transformer_blocks.12.ff.ff.bias",
|
||||
"^transformer_blocks.13.ff.ff.0.0.weight$": "transformer_blocks.13.ff.project_in.weight",
|
||||
"^transformer_blocks.13.ff.ff.0.0.bias$": "transformer_blocks.13.ff.project_in.bias",
|
||||
"^transformer_blocks.13.ff.ff.2.weight$": "transformer_blocks.13.ff.ff.weight",
|
||||
"^transformer_blocks.13.ff.ff.2.bias$": "transformer_blocks.13.ff.ff.bias",
|
||||
"^transformer_blocks.14.ff.ff.0.0.weight$": "transformer_blocks.14.ff.project_in.weight",
|
||||
"^transformer_blocks.14.ff.ff.0.0.bias$": "transformer_blocks.14.ff.project_in.bias",
|
||||
"^transformer_blocks.14.ff.ff.2.weight$": "transformer_blocks.14.ff.ff.weight",
|
||||
"^transformer_blocks.14.ff.ff.2.bias$": "transformer_blocks.14.ff.ff.bias",
|
||||
"^transformer_blocks.15.ff.ff.0.0.weight$": "transformer_blocks.15.ff.project_in.weight",
|
||||
"^transformer_blocks.15.ff.ff.0.0.bias$": "transformer_blocks.15.ff.project_in.bias",
|
||||
"^transformer_blocks.15.ff.ff.2.weight$": "transformer_blocks.15.ff.ff.weight",
|
||||
"^transformer_blocks.15.ff.ff.2.bias$": "transformer_blocks.15.ff.ff.bias",
|
||||
"^transformer_blocks.16.ff.ff.0.0.weight$": "transformer_blocks.16.ff.project_in.weight",
|
||||
"^transformer_blocks.16.ff.ff.0.0.bias$": "transformer_blocks.16.ff.project_in.bias",
|
||||
"^transformer_blocks.16.ff.ff.2.weight$": "transformer_blocks.16.ff.ff.weight",
|
||||
"^transformer_blocks.16.ff.ff.2.bias$": "transformer_blocks.16.ff.ff.bias",
|
||||
"^transformer_blocks.17.ff.ff.0.0.weight$": "transformer_blocks.17.ff.project_in.weight",
|
||||
"^transformer_blocks.17.ff.ff.0.0.bias$": "transformer_blocks.17.ff.project_in.bias",
|
||||
"^transformer_blocks.17.ff.ff.2.weight$": "transformer_blocks.17.ff.ff.weight",
|
||||
"^transformer_blocks.17.ff.ff.2.bias$": "transformer_blocks.17.ff.ff.bias",
|
||||
"^transformer_blocks.18.ff.ff.0.0.weight$": "transformer_blocks.18.ff.project_in.weight",
|
||||
"^transformer_blocks.18.ff.ff.0.0.bias$": "transformer_blocks.18.ff.project_in.bias",
|
||||
"^transformer_blocks.18.ff.ff.2.weight$": "transformer_blocks.18.ff.ff.weight",
|
||||
"^transformer_blocks.18.ff.ff.2.bias$": "transformer_blocks.18.ff.ff.bias",
|
||||
"^transformer_blocks.19.ff.ff.0.0.weight$": "transformer_blocks.19.ff.project_in.weight",
|
||||
"^transformer_blocks.19.ff.ff.0.0.bias$": "transformer_blocks.19.ff.project_in.bias",
|
||||
"^transformer_blocks.19.ff.ff.2.weight$": "transformer_blocks.19.ff.ff.weight",
|
||||
"^transformer_blocks.19.ff.ff.2.bias$": "transformer_blocks.19.ff.ff.bias",
|
||||
"^transformer_blocks.20.ff.ff.0.0.weight$": "transformer_blocks.20.ff.project_in.weight",
|
||||
"^transformer_blocks.20.ff.ff.0.0.bias$": "transformer_blocks.20.ff.project_in.bias",
|
||||
"^transformer_blocks.20.ff.ff.2.weight$": "transformer_blocks.20.ff.ff.weight",
|
||||
"^transformer_blocks.20.ff.ff.2.bias$": "transformer_blocks.20.ff.ff.bias",
|
||||
"^transformer_blocks.21.ff.ff.0.0.weight$": "transformer_blocks.21.ff.project_in.weight",
|
||||
"^transformer_blocks.21.ff.ff.0.0.bias$": "transformer_blocks.21.ff.project_in.bias",
|
||||
"^transformer_blocks.21.ff.ff.2.weight$": "transformer_blocks.21.ff.ff.weight",
|
||||
"^transformer_blocks.21.ff.ff.2.bias$": "transformer_blocks.21.ff.ff.bias",
|
||||
}
|
||||
|
||||
|
||||
def parse_arguments():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--model_name",
|
||||
type=str,
|
||||
default="F5TTS_Base",
|
||||
choices=[
|
||||
"F5TTS_Base",
|
||||
],
|
||||
) # TODO: support F5TTS_v1_Base
|
||||
parser.add_argument("--timm_ckpt", type=str, default="./ckpts/model_1200000.pt")
|
||||
parser.add_argument("--pytorch_ckpt", type=str, default="./ckpts/model_last.pt")
|
||||
parser.add_argument(
|
||||
"--output_dir", type=str, default="./tllm_checkpoint", help="The path to save the TensorRT-LLM checkpoint"
|
||||
)
|
||||
parser.add_argument("--hidden_size", type=int, default=1024, help="The hidden size of DiT")
|
||||
parser.add_argument("--depth", type=int, default=22, help="The number of DiTBlock layers")
|
||||
parser.add_argument("--num_heads", type=int, default=16, help="The number of heads of attention module")
|
||||
parser.add_argument("--cfg_scale", type=float, default=4.0)
|
||||
parser.add_argument("--tp_size", type=int, default=1, help="N-way tensor parallelism size")
|
||||
parser.add_argument("--cp_size", type=int, default=1, help="Context parallelism size")
|
||||
parser.add_argument("--pp_size", type=int, default=1, help="N-way pipeline parallelism size")
|
||||
@@ -193,33 +37,119 @@ def parse_arguments():
|
||||
parser.add_argument(
|
||||
"--workers", type=int, default=1, help="The number of workers for converting checkpoint in parallel"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model_name",
|
||||
type=str,
|
||||
default="F5TTS_Custom",
|
||||
choices=[
|
||||
"F5TTS_v1_Base",
|
||||
"F5TTS_Base",
|
||||
"F5TTS_v1_Small",
|
||||
"F5TTS_Small",
|
||||
], # if set, overwrite the below hyperparams
|
||||
)
|
||||
parser.add_argument("--hidden_size", type=int, default=1024, help="The hidden size of DiT")
|
||||
parser.add_argument("--depth", type=int, default=22, help="The number of DiTBlock layers")
|
||||
parser.add_argument("--num_heads", type=int, default=16, help="The number of heads of attention module")
|
||||
parser.add_argument("--dim_head", type=int, default=64, help="The dimension of attention head")
|
||||
parser.add_argument("--ff_mult", type=int, default=2, help="The FFN intermediate dimension multiplier")
|
||||
parser.add_argument("--text_dim", type=int, default=512, help="The output dimension of text encoder")
|
||||
parser.add_argument(
|
||||
"--text_mask_padding",
|
||||
type=lambda x: x.lower() == "true",
|
||||
choices=[True, False],
|
||||
default=True,
|
||||
help="Whether apply padding mask for conv layers in text encoder",
|
||||
)
|
||||
parser.add_argument("--conv_layers", type=int, default=4, help="The number of conv layers of text encoder")
|
||||
parser.add_argument("--pe_attn_head", type=int, default=None, help="The number of attn head that apply pos emb")
|
||||
args = parser.parse_args()
|
||||
|
||||
# overwrite if --model_name ordered
|
||||
if args.model_name == "F5TTS_v1_Base":
|
||||
args.hidden_size = 1024
|
||||
args.depth = 22
|
||||
args.num_heads = 16
|
||||
args.dim_head = 64
|
||||
args.ff_mult = 2
|
||||
args.text_dim = 512
|
||||
args.text_mask_padding = True
|
||||
args.conv_layers = 4
|
||||
args.pe_attn_head = None
|
||||
elif args.model_name == "F5TTS_Base":
|
||||
args.hidden_size = 1024
|
||||
args.depth = 22
|
||||
args.num_heads = 16
|
||||
args.dim_head = 64
|
||||
args.ff_mult = 2
|
||||
args.text_dim = 512
|
||||
args.text_mask_padding = False
|
||||
args.conv_layers = 4
|
||||
args.pe_attn_head = 1
|
||||
elif args.model_name == "F5TTS_v1_Small":
|
||||
args.hidden_size = 768
|
||||
args.depth = 18
|
||||
args.num_heads = 12
|
||||
args.dim_head = 64
|
||||
args.ff_mult = 2
|
||||
args.text_dim = 512
|
||||
args.text_mask_padding = True
|
||||
args.conv_layers = 4
|
||||
args.pe_attn_head = None
|
||||
elif args.model_name == "F5TTS_Small":
|
||||
args.hidden_size = 768
|
||||
args.depth = 18
|
||||
args.num_heads = 12
|
||||
args.dim_head = 64
|
||||
args.ff_mult = 2
|
||||
args.text_dim = 512
|
||||
args.text_mask_padding = False
|
||||
args.conv_layers = 4
|
||||
args.pe_attn_head = 1
|
||||
|
||||
return args
|
||||
|
||||
|
||||
def convert_timm_dit(args, mapping, dtype="float32"):
|
||||
def convert_pytorch_dit_to_trtllm_weight(args, mapping, dtype="float32", use_ema=True):
|
||||
weights = {}
|
||||
tik = time.time()
|
||||
torch_dtype = str_dtype_to_torch(dtype)
|
||||
tensor_parallel = mapping.tp_size
|
||||
|
||||
model_params = dict(torch.load(args.timm_ckpt))
|
||||
model_params = {
|
||||
k: v for k, v in model_params["ema_model_state_dict"].items() if k.startswith("ema_model.transformer")
|
||||
ckpt_path = args.pytorch_ckpt
|
||||
ckpt_type = ckpt_path.split(".")[-1]
|
||||
if ckpt_type == "safetensors":
|
||||
from safetensors.torch import load_file
|
||||
|
||||
model_params = load_file(ckpt_path)
|
||||
else:
|
||||
ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=True)
|
||||
model_params = ckpt["ema_model_state_dict"] if use_ema else ckpt["model_state_dict"]
|
||||
|
||||
prefix = "ema_model.transformer." if use_ema else "transformer."
|
||||
if any(k.startswith(prefix) for k in model_params.keys()):
|
||||
model_params = {
|
||||
key[len(prefix) :] if key.startswith(prefix) else key: value
|
||||
for key, value in model_params.items()
|
||||
if key.startswith(prefix)
|
||||
}
|
||||
|
||||
pytorch_to_trtllm_name = {
|
||||
r"^time_embed\.time_mlp\.0\.(weight|bias)$": r"time_embed.mlp1.\1",
|
||||
r"^time_embed\.time_mlp\.2\.(weight|bias)$": r"time_embed.mlp2.\1",
|
||||
r"^input_embed\.conv_pos_embed\.conv1d\.0\.(weight|bias)$": r"input_embed.conv_pos_embed.conv1d1.\1",
|
||||
r"^input_embed\.conv_pos_embed\.conv1d\.2\.(weight|bias)$": r"input_embed.conv_pos_embed.conv1d2.\1",
|
||||
r"^transformer_blocks\.(\d+)\.attn\.to_out\.0\.(weight|bias)$": r"transformer_blocks.\1.attn.to_out.\2",
|
||||
r"^transformer_blocks\.(\d+)\.ff\.ff\.0\.0\.(weight|bias)$": r"transformer_blocks.\1.ff.project_in.\2",
|
||||
r"^transformer_blocks\.(\d+)\.ff\.ff\.2\.(weight|bias)$": r"transformer_blocks.\1.ff.ff.\2",
|
||||
}
|
||||
prefix = "ema_model.transformer."
|
||||
model_params = {key[len(prefix) :] if key.startswith(prefix) else key: value for key, value in model_params.items()}
|
||||
|
||||
timm_to_trtllm_name = FACEBOOK_DIT_NAME_MAPPING
|
||||
|
||||
def get_trtllm_name(timm_name):
|
||||
for k, v in timm_to_trtllm_name.items():
|
||||
m = re.match(k, timm_name)
|
||||
if m is not None:
|
||||
if "*" in v:
|
||||
v = v.replace("*", m.groups()[0])
|
||||
return v
|
||||
return timm_name
|
||||
def get_trtllm_name(pytorch_name):
|
||||
for pytorch_name_pattern, trtllm_name_replacement in pytorch_to_trtllm_name.items():
|
||||
trtllm_name_if_matched = re.sub(pytorch_name_pattern, trtllm_name_replacement, pytorch_name)
|
||||
if trtllm_name_if_matched != pytorch_name:
|
||||
return trtllm_name_if_matched
|
||||
return pytorch_name
|
||||
|
||||
weights = dict()
|
||||
for name, param in model_params.items():
|
||||
@@ -230,7 +160,7 @@ def convert_timm_dit(args, mapping, dtype="float32"):
|
||||
|
||||
assert len(weights) == len(model_params)
|
||||
|
||||
# new_prefix = 'f5_transformer.'
|
||||
# new_prefix = "f5_transformer."
|
||||
new_prefix = ""
|
||||
weights = {new_prefix + key: value for key, value in weights.items()}
|
||||
import math
|
||||
@@ -272,19 +202,19 @@ def save_config(args):
|
||||
if not os.path.exists(args.output_dir):
|
||||
os.makedirs(args.output_dir)
|
||||
config = {
|
||||
"architecture": "F5TTS",
|
||||
"architecture": "F5TTS", # set the same as in ../patch/__init__.py
|
||||
"dtype": args.dtype,
|
||||
"hidden_size": 1024,
|
||||
"num_hidden_layers": 22,
|
||||
"num_attention_heads": 16,
|
||||
"dim_head": 64,
|
||||
"dropout": 0.1,
|
||||
"ff_mult": 2,
|
||||
"hidden_size": args.hidden_size,
|
||||
"num_hidden_layers": args.depth,
|
||||
"num_attention_heads": args.num_heads,
|
||||
"dim_head": args.dim_head,
|
||||
"dropout": 0.0, # inference-only
|
||||
"ff_mult": args.ff_mult,
|
||||
"mel_dim": 100,
|
||||
"text_num_embeds": 256,
|
||||
"text_dim": 512,
|
||||
"conv_layers": 4,
|
||||
"long_skip_connection": False,
|
||||
"text_dim": args.text_dim,
|
||||
"text_mask_padding": args.text_mask_padding,
|
||||
"conv_layers": args.conv_layers,
|
||||
"pe_attn_head": args.pe_attn_head,
|
||||
"mapping": {
|
||||
"world_size": args.cp_size * args.tp_size * args.pp_size,
|
||||
"cp_size": args.cp_size,
|
||||
@@ -296,7 +226,7 @@ def save_config(args):
|
||||
config["quantization"] = {
|
||||
"quant_algo": "FP8",
|
||||
# TODO: add support for exclude modules.
|
||||
# 'exclude_modules': "*final_layer*",
|
||||
# "exclude_modules": "*final_layer*",
|
||||
}
|
||||
|
||||
with open(os.path.join(args.output_dir, "config.json"), "w") as f:
|
||||
@@ -315,7 +245,7 @@ def covert_and_save(args, rank):
|
||||
pp_size=args.pp_size,
|
||||
)
|
||||
|
||||
weights = convert_timm_dit(args, mapping, dtype=args.dtype)
|
||||
weights = convert_pytorch_dit_to_trtllm_weight(args, mapping, dtype=args.dtype)
|
||||
|
||||
safetensors.torch.save_file(weights, os.path.join(args.output_dir, f"rank{rank}.safetensors"))
|
||||
|
||||
@@ -344,9 +274,9 @@ def main():
|
||||
assert args.pp_size == 1, "PP is not supported yet."
|
||||
|
||||
tik = time.time()
|
||||
if args.timm_ckpt is None:
|
||||
if args.pytorch_ckpt is None:
|
||||
return
|
||||
print("start execute")
|
||||
print("Start execute")
|
||||
execute(args.workers, [covert_and_save] * world_size, args)
|
||||
|
||||
tok = time.time()
|
||||
|
||||
@@ -13,6 +13,8 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Manual installation of TensorRT, in case not using NVIDIA NGC:
|
||||
# https://docs.nvidia.com/deeplearning/tensorrt/latest/installing-tensorrt/installing.html#downloading-tensorrt
|
||||
TRTEXEC="/usr/src/tensorrt/bin/trtexec"
|
||||
|
||||
ONNX_PATH=$1
|
||||
@@ -28,7 +30,7 @@ MAX_BATCH_SIZE=8
|
||||
|
||||
MIN_INPUT_LENGTH=1
|
||||
OPT_INPUT_LENGTH=1000
|
||||
MAX_INPUT_LENGTH=3000
|
||||
MAX_INPUT_LENGTH=3000 # 4096
|
||||
|
||||
MEL_MIN_SHAPE="${MIN_BATCH_SIZE}x100x${MIN_INPUT_LENGTH}"
|
||||
MEL_OPT_SHAPE="${OPT_BATCH_SIZE}x100x${OPT_INPUT_LENGTH}"
|
||||
@@ -40,4 +42,3 @@ ${TRTEXEC} \
|
||||
--maxShapes="mel:${MEL_MAX_SHAPE}" \
|
||||
--onnx=${ONNX_PATH} \
|
||||
--saveEngine=${ENGINE_PATH}
|
||||
|
||||
|
||||
32
src/f5_tts/scripts/count_max_epoch_precise.py
Normal file
32
src/f5_tts/scripts/count_max_epoch_precise.py
Normal file
@@ -0,0 +1,32 @@
|
||||
import math
|
||||
|
||||
from torch.utils.data import SequentialSampler
|
||||
|
||||
from f5_tts.model.dataset import DynamicBatchSampler, load_dataset
|
||||
|
||||
|
||||
train_dataset = load_dataset("Emilia_ZH_EN", "pinyin")
|
||||
sampler = SequentialSampler(train_dataset)
|
||||
|
||||
gpus = 8
|
||||
batch_size_per_gpu = 38400
|
||||
max_samples_per_gpu = 64
|
||||
max_updates = 1250000
|
||||
|
||||
batch_sampler = DynamicBatchSampler(
|
||||
sampler,
|
||||
batch_size_per_gpu,
|
||||
max_samples=max_samples_per_gpu,
|
||||
random_seed=666,
|
||||
drop_residual=False,
|
||||
)
|
||||
|
||||
print(
|
||||
f"One epoch has {len(batch_sampler) / gpus} updates if gpus={gpus}, with "
|
||||
f"batch_size_per_gpu={batch_size_per_gpu} (frames) & "
|
||||
f"max_samples_per_gpu={max_samples_per_gpu}."
|
||||
)
|
||||
print(
|
||||
f"If gpus={gpus}, for max_updates={max_updates} "
|
||||
f"should set epoch={math.ceil(max_updates / len(batch_sampler) * gpus)}."
|
||||
)
|
||||
@@ -225,5 +225,5 @@ if __name__ == "__main__":
|
||||
# bad zh asr cnt 230435 (samples)
|
||||
# bad eh asr cnt 37217 (samples)
|
||||
|
||||
# vocab size may be slightly different due to jieba tokenizer and pypinyin (e.g. way of polyphoneme)
|
||||
# vocab size may be slightly different due to rjieba tokenizer and pypinyin (e.g. way of polyphoneme)
|
||||
# please be careful if using pretrained model, make sure the vocab.txt is same
|
||||
|
||||
@@ -122,5 +122,5 @@ if __name__ == "__main__":
|
||||
# - - 1459 (polyphone)
|
||||
# char vocab size 5264 5219 5042
|
||||
|
||||
# vocab size may be slightly different due to jieba tokenizer and pypinyin (e.g. way of polyphoneme)
|
||||
# vocab size may be slightly different due to rjieba tokenizer and pypinyin (e.g. way of polyphoneme)
|
||||
# please be careful if using pretrained model, make sure the vocab.txt is same
|
||||
|
||||
Reference in New Issue
Block a user