16 Commits

Author SHA1 Message Date
SWivid
9ae46c8360 Replace jieba pkg with rjieba - a jieba-rs Python binding 2025-11-28 13:08:07 +00:00
SWivid
3eecd94baa support back avg upsampling for batch, cover up non-mask case 2025-11-09 11:56:03 +00:00
SWivid
d9a69452ce formatting 2025-11-09 18:25:30 +08:00
Yushen CHEN
bc15df2b57 Merge pull request #1212 from QingyuLiu0521/fix/AverageUpsampling
Fix Average Upsampling conflict logic, introduced from the previous batch inference fix.
2025-11-09 18:23:38 +08:00
QingyuLiu0521
9b2357a1b9 Fix Average Upsampling 2025-11-08 18:39:06 -05:00
Yushen CHEN
1dcb4e10f7 Add torchcodec dependency to pyproject.toml 2025-11-03 16:44:11 +08:00
SWivid
529d856133 clean-up eval scripts 2025-10-27 14:38:57 +00:00
SWivid
7abadc4c72 fix typo in eval scripts 2025-10-26 14:28:17 +00:00
SWivid
e67d50841e runtime trtllm: fix batch inference skipping last words in shorter sentences #1039 #1179 2025-10-24 09:12:08 +00:00
SWivid
6b07fb03b2 clean-up ruff lint 2025-10-24 08:30:55 +00:00
SWivid
a051a68552 pytorch imple. fix batch inference skipping last words in shorter sentences issue #1039 #1179 2025-10-24 05:50:25 +00:00
Yushen CHEN
f2a4f8581f Update runtime README 2025-10-22 08:37:32 +08:00
SWivid
a17c5ae435 pytorch imple.: fix batch 1 inference from last commit 2025-10-22 00:31:56 +00:00
SWivid
a0b8fb5df2 runtime trtllm: minor fixes. pytorch: update text_embedding logic to correct v0 batching. 2025-10-22 00:19:45 +00:00
SWivid
c8bfc3aa3d runtime trtllm: support v1 and custom 2025-10-21 22:02:25 +00:00
SWivid
8d3ec72159 runtime trtllm: clean-up v0 code, several fixes. 2025-10-20 10:30:58 +00:00
32 changed files with 859 additions and 721 deletions

View File

@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project]
name = "f5-tts"
version = "1.1.9"
version = "1.1.10"
description = "F5-TTS: A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching"
readme = "README.md"
license = {text = "MIT License"}
@@ -22,18 +22,19 @@ dependencies = [
"ema_pytorch>=0.5.2",
"gradio>=5.0.0",
"hydra-core>=1.3.0",
"jieba",
"librosa",
"matplotlib",
"numpy<=1.26.4; python_version<='3.10'",
"pydantic<=2.10.6",
"pydub",
"pypinyin",
"rjieba",
"safetensors",
"soundfile",
"tomli",
"torch>=2.0.0",
"torchaudio>=2.0.0",
"torchcodec",
"torchdiffeq",
"tqdm>=4.65.0",
"transformers",

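A quick post-install check for the dependency changes above (a minimal sketch, not part of the repo; assumes the package was reinstalled with `pip install -e .`):

```python
# Post-install sanity check (sketch): the swapped and added dependencies should import cleanly.
import importlib

for pkg in ("rjieba", "torchcodec", "torchaudio"):
    module = importlib.import_module(pkg)
    print(pkg, getattr(module, "__version__", "installed"))
```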
View File

@@ -154,8 +154,8 @@ if __name__ == "__main__":
wav, sr, spec = f5tts.infer(
ref_file=str(files("f5_tts").joinpath("infer/examples/basic/basic_ref_en.wav")),
ref_text="some call me nature, others call me mother nature.",
gen_text="""I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring. Respect me and I'll nurture you; ignore me and you shall face the consequences.""",
ref_text="Some call me nature, others call me mother nature.",
gen_text="I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring.",
file_wave=str(files("f5_tts").joinpath("../../tests/api_out.wav")),
file_spec=str(files("f5_tts").joinpath("../../tests/api_out.png")),
seed=None,

View File

@@ -14,16 +14,20 @@ pip install -e .[eval]
1. *Seed-TTS testset*: Download from [seed-tts-eval](https://github.com/BytedanceSpeech/seed-tts-eval).
2. *LibriSpeech test-clean*: Download from [OpenSLR](http://www.openslr.org/12/).
3. Unzip the downloaded datasets and place them in the `data/` directory.
4. Update the path for *LibriSpeech test-clean* data in `src/f5_tts/eval/eval_infer_batch.py`
5. Our filtered LibriSpeech-PC 4-10s subset: `data/librispeech_pc_test_clean_cross_sentence.lst`
4. Our filtered LibriSpeech-PC 4-10s subset: `data/librispeech_pc_test_clean_cross_sentence.lst`
### Batch Inference for Test Set
To run batch inference for evaluations, execute the following commands:
```bash
# batch inference for evaluations
accelerate config # if not set before
# if accelerate config has not been set up yet
accelerate config
# to run inference only
bash src/f5_tts/eval/eval_infer_batch.sh --infer-only
# to run inference plus the corresponding evaluation, set up the following tools first
bash src/f5_tts/eval/eval_infer_batch.sh
```
@@ -35,9 +39,13 @@ bash src/f5_tts/eval/eval_infer_batch.sh
2. English ASR Model: [Faster-Whisper](https://huggingface.co/Systran/faster-whisper-large-v3)
3. WavLM Model: Download from [Google Drive](https://drive.google.com/file/d/1-aE1NfzpRCLxA4GUxX9ITI3F9LlbtEGP/view).
Then update in the following scripts with the paths you put evaluation model ckpts to.
> [!NOTE]
> The ASR model is downloaded automatically if `--local` is not set for the evaluation scripts.
> Otherwise, update the `asr_ckpt_dir` path values in `eval_librispeech_test_clean.py` or `eval_seedtts_testset.py`.
>
> The WavLM model must be downloaded manually and the `wavlm_ckpt_dir` path updated in `eval_librispeech_test_clean.py` and `eval_seedtts_testset.py`.
### Objective Evaluation
### Objective Evaluation Examples
Update the path to your batch inference results, then carry out the WER / SIM / UTMOS evaluations:
```bash
@@ -50,3 +58,6 @@ python src/f5_tts/eval/eval_librispeech_test_clean.py --eval_task sim --gen_wav_
# Evaluation [UTMOS]. --ext: Audio extension
python src/f5_tts/eval/eval_utmos.py --audio_dir <WAV_DIR> --ext wav
```
> [!NOTE]
> Evaluation results can also be found in `_*_results.jsonl` files saved in `<GEN_WAV_DIR>`/`<WAV_DIR>`.
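On the `_*_results.jsonl` note above, a minimal sketch for collecting those result lines; it assumes only that each line is a standalone JSON object and that outputs live under `results/` (field names vary by eval script):

```python
# Sketch: gather per-run result lines from the _*_results.jsonl files mentioned above.
# Assumes each line is a standalone JSON object; field names depend on the eval script.
import glob
import json

for path in glob.glob("results/**/_*_results.jsonl", recursive=True):
    with open(path, encoding="utf-8") as f:
        for line in f:
            record = json.loads(line)
            print(path, record)
```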

View File

@@ -48,6 +48,11 @@ def main():
parser.add_argument("-ss", "--swaysampling", default=-1, type=float)
parser.add_argument("-t", "--testset", required=True)
parser.add_argument(
"-p", "--librispeech_test_clean_path", default=f"{rel_path}/data/LibriSpeech/test-clean", type=str
)
parser.add_argument("--local", action="store_true", help="Use local vocoder checkpoint directory")
args = parser.parse_args()
@@ -83,7 +88,7 @@ def main():
if testset == "ls_pc_test_clean":
metalst = rel_path + "/data/librispeech_pc_test_clean_cross_sentence.lst"
librispeech_test_clean_path = "<SOME_PATH>/LibriSpeech/test-clean" # test-clean path
librispeech_test_clean_path = args.librispeech_test_clean_path
metainfo = get_librispeech_test_clean_metainfo(metalst, librispeech_test_clean_path)
elif testset == "seedtts_test_zh":
@@ -121,7 +126,7 @@ def main():
)
# Vocoder model
local = False
local = args.local
if mel_spec_type == "vocos":
vocoder_local_path = "../checkpoints/charactr/vocos-mel-24khz"
elif mel_spec_type == "bigvgan":
@@ -155,7 +160,13 @@ def main():
ckpt_path = ckpt_prefix + ".safetensors"
else:
print("Loading from self-organized training checkpoints rather than released pretrained.")
ckpt_path = rel_path + f"/{model_cfg.ckpts.save_dir}/model_{ckpt_step}.pt"
ckpt_prefix = rel_path + f"/{model_cfg.ckpts.save_dir}/model_{ckpt_step}"
if os.path.exists(ckpt_prefix + ".pt"):
ckpt_path = ckpt_prefix + ".pt"
elif os.path.exists(ckpt_prefix + ".safetensors"):
ckpt_path = ckpt_prefix + ".safetensors"
else:
raise ValueError("The checkpoint does not exist or cannot be found in given location.")
dtype = torch.float32 if mel_spec_type == "bigvgan" else None
model = load_checkpoint(model, ckpt_path, device, dtype=dtype, use_ema=use_ema)

View File

@@ -1,18 +1,116 @@
#!/bin/bash
set -e
export PYTHONWARNINGS="ignore::UserWarning,ignore::FutureWarning"
# e.g. F5-TTS, 16 NFE
accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "F5TTS_v1_Base" -t "seedtts_test_zh" -nfe 16
accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "F5TTS_v1_Base" -t "seedtts_test_en" -nfe 16
accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "F5TTS_v1_Base" -t "ls_pc_test_clean" -nfe 16
# Configuration parameters
MODEL_NAME="F5TTS_v1_Base"
SEEDS=(0 1 2)
CKPTSTEPS=(1250000)
TASKS=("seedtts_test_zh" "seedtts_test_en" "ls_pc_test_clean")
LS_TEST_CLEAN_PATH="data/LibriSpeech/test-clean"
GPUS="[0,1,2,3,4,5,6,7]"
OFFLINE_MODE=false
# e.g. Vanilla E2 TTS, 32 NFE
accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "E2TTS_Base" -c 1200000 -t "seedtts_test_zh" -o "midpoint" -ss 0
accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "E2TTS_Base" -c 1200000 -t "seedtts_test_en" -o "midpoint" -ss 0
accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "E2TTS_Base" -c 1200000 -t "ls_pc_test_clean" -o "midpoint" -ss 0
# Parse arguments
if [ $OFFLINE_MODE = true ]; then
LOCAL="--local"
else
LOCAL=""
fi
INFER_ONLY=false
while [[ $# -gt 0 ]]; do
case $1 in
--infer-only)
INFER_ONLY=true
shift
;;
*)
echo "======== Unknown parameter: $1"
exit 1
;;
esac
done
# e.g. evaluate F5-TTS 16 NFE result on Seed-TTS test-zh
python src/f5_tts/eval/eval_seedtts_testset.py -e wer -l zh --gen_wav_dir results/F5TTS_v1_Base_1250000/seedtts_test_zh/seed0_euler_nfe32_vocos_ss-1_cfg2.0_speed1.0 --gpu_nums 8
python src/f5_tts/eval/eval_seedtts_testset.py -e sim -l zh --gen_wav_dir results/F5TTS_v1_Base_1250000/seedtts_test_zh/seed0_euler_nfe32_vocos_ss-1_cfg2.0_speed1.0 --gpu_nums 8
python src/f5_tts/eval/eval_utmos.py --audio_dir results/F5TTS_v1_Base_1250000/seedtts_test_zh/seed0_euler_nfe32_vocos_ss-1_cfg2.0_speed1.0
echo "======== Starting F5-TTS batch evaluation task..."
if [ "$INFER_ONLY" = true ]; then
echo "======== Mode: Execute infer tasks only"
else
echo "======== Mode: Execute full pipeline (infer + eval)"
fi
# etc.
# Function: Execute eval tasks
execute_eval_tasks() {
local ckptstep=$1
local seed=$2
local task_name=$3
local gen_wav_dir="results/${MODEL_NAME}_${ckptstep}/${task_name}/seed${seed}_euler_nfe32_vocos_ss-1_cfg2.0_speed1.0"
echo ">>>>>>>> Starting eval task: ckptstep=${ckptstep}, seed=${seed}, task=${task_name}"
case $task_name in
"seedtts_test_zh")
python src/f5_tts/eval/eval_seedtts_testset.py -e wer -l zh -g "$gen_wav_dir" -n "$GPUS" $LOCAL
python src/f5_tts/eval/eval_seedtts_testset.py -e sim -l zh -g "$gen_wav_dir" -n "$GPUS" $LOCAL
python src/f5_tts/eval/eval_utmos.py --audio_dir "$gen_wav_dir"
;;
"seedtts_test_en")
python src/f5_tts/eval/eval_seedtts_testset.py -e wer -l en -g "$gen_wav_dir" -n "$GPUS" $LOCAL
python src/f5_tts/eval/eval_seedtts_testset.py -e sim -l en -g "$gen_wav_dir" -n "$GPUS" $LOCAL
python src/f5_tts/eval/eval_utmos.py --audio_dir "$gen_wav_dir"
;;
"ls_pc_test_clean")
python src/f5_tts/eval/eval_librispeech_test_clean.py -e wer -g "$gen_wav_dir" -n "$GPUS" -p "$LS_TEST_CLEAN_PATH" $LOCAL
python src/f5_tts/eval/eval_librispeech_test_clean.py -e sim -g "$gen_wav_dir" -n "$GPUS" -p "$LS_TEST_CLEAN_PATH" $LOCAL
python src/f5_tts/eval/eval_utmos.py --audio_dir "$gen_wav_dir"
;;
esac
echo ">>>>>>>> Completed eval task: ckptstep=${ckptstep}, seed=${seed}, task=${task_name}"
}
# Main execution loop
for ckptstep in "${CKPTSTEPS[@]}"; do
echo "======== Processing ckptstep: ${ckptstep}"
for seed in "${SEEDS[@]}"; do
echo "-------- Processing seed: ${seed}"
# Store eval task PIDs for current seed (if not infer-only mode)
if [ "$INFER_ONLY" = false ]; then
declare -a eval_pids
fi
# Execute each infer task sequentially
for task in "${TASKS[@]}"; do
echo ">>>>>>>> Executing infer task: accelerate launch src/f5_tts/eval/eval_infer_batch.py -s ${seed} -n \"${MODEL_NAME}\" -t \"${task}\" -c ${ckptstep} $LOCAL"
# Execute infer task (foreground execution, wait for completion)
accelerate launch src/f5_tts/eval/eval_infer_batch.py -s ${seed} -n "${MODEL_NAME}" -t "${task}" -c ${ckptstep} -p "${LS_TEST_CLEAN_PATH}" $LOCAL
# If not infer-only mode, launch corresponding eval task
if [ "$INFER_ONLY" = false ]; then
# Launch corresponding eval task (background execution, non-blocking for next infer)
execute_eval_tasks $ckptstep $seed $task &
eval_pids+=($!)
fi
done
# If not infer-only mode, wait for all eval tasks of current seed to complete
if [ "$INFER_ONLY" = false ]; then
echo ">>>>>>>> All infer tasks for seed ${seed} completed, waiting for corresponding eval tasks to finish..."
for pid in "${eval_pids[@]}"; do
wait $pid
done
unset eval_pids # Clean up array
fi
echo "-------- All eval tasks for seed ${seed} completed"
done
echo "======== Completed ckptstep: ${ckptstep}"
echo
done
echo "======== All tasks completed!"

View File

@@ -0,0 +1,18 @@
#!/bin/bash
# e.g. F5-TTS, 16 NFE
accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "F5TTS_v1_Base" -t "seedtts_test_zh" -nfe 16
accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "F5TTS_v1_Base" -t "seedtts_test_en" -nfe 16
accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "F5TTS_v1_Base" -t "ls_pc_test_clean" -nfe 16 -p data/LibriSpeech/test-clean
# e.g. Vanilla E2 TTS, 32 NFE
accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "E2TTS_Base" -c 1200000 -t "seedtts_test_zh" -o "midpoint" -ss 0
accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "E2TTS_Base" -c 1200000 -t "seedtts_test_en" -o "midpoint" -ss 0
accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "E2TTS_Base" -c 1200000 -t "ls_pc_test_clean" -o "midpoint" -ss 0 -p data/LibriSpeech/test-clean
# e.g. evaluate F5-TTS 16 NFE result on Seed-TTS test-zh
python src/f5_tts/eval/eval_seedtts_testset.py -e wer -l zh --gen_wav_dir results/F5TTS_v1_Base_1250000/seedtts_test_zh/seed0_euler_nfe16_vocos_ss-1_cfg2.0_speed1.0 --gpu_nums 8
python src/f5_tts/eval/eval_seedtts_testset.py -e sim -l zh --gen_wav_dir results/F5TTS_v1_Base_1250000/seedtts_test_zh/seed0_euler_nfe16_vocos_ss-1_cfg2.0_speed1.0 --gpu_nums 8
python src/f5_tts/eval/eval_utmos.py --audio_dir results/F5TTS_v1_Base_1250000/seedtts_test_zh/seed0_euler_nfe16_vocos_ss-1_cfg2.0_speed1.0
# etc.

View File

@@ -1,6 +1,7 @@
# Evaluate with Librispeech test-clean, ~3s prompt to generate 4-10s audio (the way of valle/voicebox evaluation)
import argparse
import ast
import json
import os
import sys
@@ -25,11 +26,26 @@ def get_args():
parser.add_argument("-l", "--lang", type=str, default="en")
parser.add_argument("-g", "--gen_wav_dir", type=str, required=True)
parser.add_argument("-p", "--librispeech_test_clean_path", type=str, required=True)
parser.add_argument("-n", "--gpu_nums", type=int, default=8, help="Number of GPUs to use")
parser.add_argument(
"-n", "--gpu_nums", type=str, default="8", help="Number of GPUs to use (e.g., 8) or GPU list (e.g., [0,1,2,3])"
)
parser.add_argument("--local", action="store_true", help="Use local custom checkpoint directory")
return parser.parse_args()
def parse_gpu_nums(gpu_nums_str):
try:
if gpu_nums_str.startswith("[") and gpu_nums_str.endswith("]"):
gpu_list = ast.literal_eval(gpu_nums_str)
if isinstance(gpu_list, list):
return gpu_list
return list(range(int(gpu_nums_str)))
except (ValueError, SyntaxError):
raise argparse.ArgumentTypeError(
f"Invalid GPU specification: {gpu_nums_str}. Use a number (e.g., 8) or a list (e.g., [0,1,2,3])"
)
def main():
args = get_args()
eval_task = args.eval_task
@@ -38,7 +54,7 @@ def main():
gen_wav_dir = args.gen_wav_dir
metalst = rel_path + "/data/librispeech_pc_test_clean_cross_sentence.lst"
gpus = list(range(args.gpu_nums))
gpus = parse_gpu_nums(args.gpu_nums)
test_set = get_librispeech_test(metalst, gen_wav_dir, gpus, librispeech_test_clean_path)
## In LibriSpeech, some speakers utilized varying voice characteristics for different characters in the book,

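The `--gpu_nums` argument above now accepts either a GPU count or an explicit list; a small usage sketch mirroring the `parse_gpu_nums` helper shown in the diff:

```python
# Behavior of the new --gpu_nums parsing (copy of the helper shown above, for illustration).
import argparse
import ast

def parse_gpu_nums(gpu_nums_str):
    try:
        if gpu_nums_str.startswith("[") and gpu_nums_str.endswith("]"):
            gpu_list = ast.literal_eval(gpu_nums_str)
            if isinstance(gpu_list, list):
                return gpu_list
        return list(range(int(gpu_nums_str)))
    except (ValueError, SyntaxError):
        raise argparse.ArgumentTypeError(
            f"Invalid GPU specification: {gpu_nums_str}. Use a number (e.g., 8) or a list (e.g., [0,1,2,3])"
        )

print(parse_gpu_nums("4"))        # -> [0, 1, 2, 3]
print(parse_gpu_nums("[0,2,5]"))  # -> [0, 2, 5]
```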
View File

@@ -1,6 +1,7 @@
# Evaluate with Seed-TTS testset
import argparse
import ast
import json
import os
import sys
@@ -24,11 +25,26 @@ def get_args():
parser.add_argument("-e", "--eval_task", type=str, default="wer", choices=["sim", "wer"])
parser.add_argument("-l", "--lang", type=str, default="en", choices=["zh", "en"])
parser.add_argument("-g", "--gen_wav_dir", type=str, required=True)
parser.add_argument("-n", "--gpu_nums", type=int, default=8, help="Number of GPUs to use")
parser.add_argument(
"-n", "--gpu_nums", type=str, default="8", help="Number of GPUs to use (e.g., 8) or GPU list (e.g., [0,1,2,3])"
)
parser.add_argument("--local", action="store_true", help="Use local custom checkpoint directory")
return parser.parse_args()
def parse_gpu_nums(gpu_nums_str):
try:
if gpu_nums_str.startswith("[") and gpu_nums_str.endswith("]"):
gpu_list = ast.literal_eval(gpu_nums_str)
if isinstance(gpu_list, list):
return gpu_list
return list(range(int(gpu_nums_str)))
except (ValueError, SyntaxError):
raise argparse.ArgumentTypeError(
f"Invalid GPU specification: {gpu_nums_str}. Use a number (e.g., 8) or a list (e.g., [0,1,2,3])"
)
def main():
args = get_args()
eval_task = args.eval_task
@@ -38,7 +54,7 @@ def main():
# NOTE. paraformer-zh result will be slightly different according to the number of gpus, cuz batchsize is different
# zh 1.254 seems a result of 4 workers wer_seed_tts
gpus = list(range(args.gpu_nums))
gpus = parse_gpu_nums(args.gpu_nums)
test_set = get_seed_tts_test(metalst, gen_wav_dir, gpus)
local = args.local

View File

@@ -395,14 +395,21 @@ def run_sim(args):
wav1, sr1 = torchaudio.load(gen_wav)
wav2, sr2 = torchaudio.load(prompt_wav)
resample1 = torchaudio.transforms.Resample(orig_freq=sr1, new_freq=16000)
resample2 = torchaudio.transforms.Resample(orig_freq=sr2, new_freq=16000)
wav1 = resample1(wav1)
wav2 = resample2(wav2)
if use_gpu:
wav1 = wav1.cuda(device)
wav2 = wav2.cuda(device)
if sr1 != 16000:
resample1 = torchaudio.transforms.Resample(orig_freq=sr1, new_freq=16000)
if use_gpu:
resample1 = resample1.cuda(device)
wav1 = resample1(wav1)
if sr2 != 16000:
resample2 = torchaudio.transforms.Resample(orig_freq=sr2, new_freq=16000)
if use_gpu:
resample2 = resample2.cuda(device)
wav2 = resample2(wav2)
with torch.no_grad():
emb1 = model(wav1)
emb2 = model(wav2)

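The SIM fix above only constructs a resampler when the source rate differs from 16 kHz and moves it to the same device as the waveform; a standalone sketch of that pattern (`to_16k` is an assumed helper name, not repo code):

```python
# Sketch of the resample-only-when-needed logic from run_sim above.
import torch
import torchaudio

def to_16k(wav: torch.Tensor, sr: int, device=None) -> torch.Tensor:
    if sr == 16000:
        return wav.to(device) if device is not None else wav
    resample = torchaudio.transforms.Resample(orig_freq=sr, new_freq=16000)
    if device is not None:
        wav = wav.to(device)
        resample = resample.to(device)
    return resample(wav)
```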
View File

@@ -6,12 +6,14 @@ nt - text sequence
nw - raw wave length
d - dimension
"""
# ruff: noqa: F722 F821
from __future__ import annotations
import torch
import torch.nn.functional as F
from torch import nn
from torch.nn.utils.rnn import pad_sequence
from x_transformers.x_transformers import RotaryEmbedding
from f5_tts.model.modules import (
@@ -20,7 +22,6 @@ from f5_tts.model.modules import (
ConvPositionEmbedding,
DiTBlock,
TimestepEmbedding,
get_pos_embed_indices,
precompute_freqs_cis,
)
@@ -42,7 +43,7 @@ class TextEmbedding(nn.Module):
if conv_layers > 0:
self.extra_modeling = True
self.precompute_max_pos = 4096 # ~44s of 24khz audio
self.precompute_max_pos = 8192 # 8192 is ~87.38s of 24khz audio; 4096 is ~43.69s of 24khz audio
self.register_buffer("freqs_cis", precompute_freqs_cis(text_dim, self.precompute_max_pos), persistent=False)
self.text_blocks = nn.Sequential(
*[ConvNeXtV2Block(text_dim, text_dim * conv_mult) for _ in range(conv_layers)]
@@ -50,33 +51,29 @@ class TextEmbedding(nn.Module):
else:
self.extra_modeling = False
def average_upsample_text_by_mask(self, text, text_mask, audio_mask):
def average_upsample_text_by_mask(self, text, text_mask):
batch, text_len, text_dim = text.shape
if audio_mask is None:
audio_mask = torch.ones_like(text_mask, dtype=torch.bool)
valid_mask = audio_mask & text_mask
audio_lens = audio_mask.sum(dim=1) # [batch]
valid_lens = valid_mask.sum(dim=1) # [batch]
audio_len = text_len # because text is already padded to the same length as the audio sequence
text_lens = text_mask.sum(dim=1) # [batch]
upsampled_text = torch.zeros_like(text)
for i in range(batch):
audio_len = audio_lens[i].item()
valid_len = valid_lens[i].item()
text_len = text_lens[i].item()
if valid_len == 0:
if text_len == 0:
continue
valid_ind = torch.where(valid_mask[i])[0]
valid_data = text[i, valid_ind, :] # [valid_len, text_dim]
valid_ind = torch.where(text_mask[i])[0]
valid_data = text[i, valid_ind, :] # [text_len, text_dim]
base_repeat = audio_len // valid_len
remainder = audio_len % valid_len
base_repeat = audio_len // text_len
remainder = audio_len % text_len
indices = []
for j in range(valid_len):
repeat_count = base_repeat + (1 if j >= valid_len - remainder else 0)
for j in range(text_len):
repeat_count = base_repeat + (1 if j >= text_len - remainder else 0)
indices.extend([j] * repeat_count)
indices = torch.tensor(indices[:audio_len], device=text.device, dtype=torch.long)
@@ -86,11 +83,10 @@ class TextEmbedding(nn.Module):
return upsampled_text
def forward(self, text: int["b nt"], seq_len, drop_text=False, audio_mask: bool["b n"] | None = None): # noqa: F722
def forward(self, text: int["b nt"], seq_len, drop_text=False):
text = text + 1 # use 0 as filler token. preprocess of batch pad -1, see list_str_to_idx()
text = text[:, :seq_len] # curtail if character tokens are more than the mel spec tokens
batch, text_len = text.shape[0], text.shape[1]
text = F.pad(text, (0, seq_len - text_len), value=0) # (opt.) if not self.average_upsampling:
text = F.pad(text, (0, seq_len - text.shape[1]), value=0) # (opt.) if not self.average_upsampling:
if self.mask_padding:
text_mask = text == 0
@@ -102,10 +98,7 @@ class TextEmbedding(nn.Module):
# possible extra modeling
if self.extra_modeling:
# sinus pos emb
batch_start = torch.zeros((batch,), device=text.device, dtype=torch.long)
pos_idx = get_pos_embed_indices(batch_start, seq_len, max_pos=self.precompute_max_pos)
text_pos_embed = self.freqs_cis[pos_idx]
text = text + text_pos_embed
text = text + self.freqs_cis[:seq_len, :]
# convnextv2 blocks
if self.mask_padding:
@@ -117,7 +110,7 @@ class TextEmbedding(nn.Module):
text = self.text_blocks(text)
if self.average_upsampling:
text = self.average_upsample_text_by_mask(text, ~text_mask, audio_mask)
text = self.average_upsample_text_by_mask(text, ~text_mask)
return text
@@ -131,12 +124,19 @@ class InputEmbedding(nn.Module):
self.proj = nn.Linear(mel_dim * 2 + text_dim, out_dim)
self.conv_pos_embed = ConvPositionEmbedding(dim=out_dim)
def forward(self, x: float["b n d"], cond: float["b n d"], text_embed: float["b n d"], drop_audio_cond=False): # noqa: F722
def forward(
self,
x: float["b n d"],
cond: float["b n d"],
text_embed: float["b n d"],
drop_audio_cond=False,
audio_mask: bool["b n"] | None = None,
):
if drop_audio_cond: # cfg for cond audio
cond = torch.zeros_like(cond)
x = self.proj(torch.cat((x, cond, text_embed), dim=-1))
x = self.conv_pos_embed(x) + x
x = self.conv_pos_embed(x, mask=audio_mask) + x
return x
@@ -239,22 +239,36 @@ class DiT(nn.Module):
drop_audio_cond: bool = False,
drop_text: bool = False,
cache: bool = True,
audio_mask: bool["b n"] | None = None, # noqa: F722
audio_mask: bool["b n"] | None = None,
):
seq_len = x.shape[1]
if self.text_uncond is None or self.text_cond is None or not cache:
if audio_mask is None:
text_embed = self.text_embed(text, x.shape[1], drop_text=drop_text)
else:
batch = x.shape[0]
seq_lens = audio_mask.sum(dim=1) # Calculate the actual sequence length for each sample
text_embed_list = []
for i in range(batch):
text_embed_i = self.text_embed(
text[i].unsqueeze(0),
seq_len=seq_lens[i].item(),
drop_text=drop_text,
)
text_embed_list.append(text_embed_i[0])
text_embed = pad_sequence(text_embed_list, batch_first=True, padding_value=0)
if cache:
if drop_text:
self.text_uncond = text_embed
else:
self.text_cond = text_embed
if cache:
if drop_text:
if self.text_uncond is None:
self.text_uncond = self.text_embed(text, seq_len, drop_text=True, audio_mask=audio_mask)
text_embed = self.text_uncond
else:
if self.text_cond is None:
self.text_cond = self.text_embed(text, seq_len, drop_text=False, audio_mask=audio_mask)
text_embed = self.text_cond
else:
text_embed = self.text_embed(text, seq_len, drop_text=drop_text, audio_mask=audio_mask)
x = self.input_embed(x, cond, text_embed, drop_audio_cond=drop_audio_cond)
x = self.input_embed(x, cond, text_embed, drop_audio_cond=drop_audio_cond, audio_mask=audio_mask)
return x
@@ -263,11 +277,11 @@ class DiT(nn.Module):
def forward(
self,
x: float["b n d"], # nosied input audio # noqa: F722
cond: float["b n d"], # masked cond audio # noqa: F722
text: int["b nt"], # text # noqa: F722
time: float["b"] | float[""], # time step # noqa: F821 F722
mask: bool["b n"] | None = None, # noqa: F722
x: float["b n d"], # nosied input audio
cond: float["b n d"], # masked cond audio
text: int["b nt"], # text
time: float["b"] | float[""], # time step
mask: bool["b n"] | None = None,
drop_audio_cond: bool = False, # cfg for cond audio
drop_text: bool = False, # cfg for text
cfg_infer: bool = False, # cfg inference, pack cond & uncond forward

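To make the reworked `average_upsample_text_by_mask` concrete: each valid text position is repeated `audio_len // text_len` times, and the `audio_len % text_len` remainder goes to the trailing positions. A tiny numeric sketch of that index layout (illustration only):

```python
# Illustration of the average-upsampling index pattern used above:
# distribute text_len positions over audio_len frames, giving the trailing
# positions one extra repeat when the division is not exact.
def upsample_indices(text_len: int, audio_len: int) -> list[int]:
    base_repeat, remainder = divmod(audio_len, text_len)
    indices = []
    for j in range(text_len):
        repeat_count = base_repeat + (1 if j >= text_len - remainder else 0)
        indices.extend([j] * repeat_count)
    return indices[:audio_len]

print(upsample_indices(text_len=3, audio_len=8))  # [0, 0, 1, 1, 1, 2, 2, 2]
```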
View File

@@ -6,6 +6,7 @@ nt - text sequence
nw - raw wave length
d - dimension
"""
# ruff: noqa: F722 F821
from __future__ import annotations
@@ -36,7 +37,7 @@ class TextEmbedding(nn.Module):
self.precompute_max_pos = 1024
self.register_buffer("freqs_cis", precompute_freqs_cis(out_dim, self.precompute_max_pos), persistent=False)
def forward(self, text: int["b nt"], drop_text=False) -> int["b nt d"]: # noqa: F722
def forward(self, text: int["b nt"], drop_text=False) -> int["b nt d"]:
text = text + 1 # use 0 as filler token. preprocess of batch pad -1, see list_str_to_idx()
if self.mask_padding:
text_mask = text == 0
@@ -69,7 +70,7 @@ class AudioEmbedding(nn.Module):
self.linear = nn.Linear(2 * in_dim, out_dim)
self.conv_pos_embed = ConvPositionEmbedding(out_dim)
def forward(self, x: float["b n d"], cond: float["b n d"], drop_audio_cond=False): # noqa: F722
def forward(self, x: float["b n d"], cond: float["b n d"], drop_audio_cond=False):
if drop_audio_cond:
cond = torch.zeros_like(cond)
x = torch.cat((x, cond), dim=-1)
@@ -170,11 +171,11 @@ class MMDiT(nn.Module):
def forward(
self,
x: float["b n d"], # nosied input audio # noqa: F722
cond: float["b n d"], # masked cond audio # noqa: F722
text: int["b nt"], # text # noqa: F722
time: float["b"] | float[""], # time step # noqa: F821 F722
mask: bool["b n"] | None = None, # noqa: F722
x: float["b n d"], # nosied input audio
cond: float["b n d"], # masked cond audio
text: int["b nt"], # text
time: float["b"] | float[""], # time step
mask: bool["b n"] | None = None,
drop_audio_cond: bool = False, # cfg for cond audio
drop_text: bool = False, # cfg for text
cfg_infer: bool = False, # cfg inference, pack cond & uncond forward

View File

@@ -6,6 +6,7 @@ nt - text sequence
nw - raw wave length
d - dimension
"""
# ruff: noqa: F722 F821
from __future__ import annotations
@@ -49,7 +50,7 @@ class TextEmbedding(nn.Module):
else:
self.extra_modeling = False
def forward(self, text: int["b nt"], seq_len, drop_text=False): # noqa: F722
def forward(self, text: int["b nt"], seq_len, drop_text=False):
text = text + 1 # use 0 as filler token. preprocess of batch pad -1, see list_str_to_idx()
text = text[:, :seq_len] # curtail if character tokens are more than the mel spec tokens
batch, text_len = text.shape[0], text.shape[1]
@@ -91,7 +92,7 @@ class InputEmbedding(nn.Module):
self.proj = nn.Linear(mel_dim * 2 + text_dim, out_dim)
self.conv_pos_embed = ConvPositionEmbedding(dim=out_dim)
def forward(self, x: float["b n d"], cond: float["b n d"], text_embed: float["b n d"], drop_audio_cond=False): # noqa: F722
def forward(self, x: float["b n d"], cond: float["b n d"], text_embed: float["b n d"], drop_audio_cond=False):
if drop_audio_cond: # cfg for cond audio
cond = torch.zeros_like(cond)
@@ -215,11 +216,11 @@ class UNetT(nn.Module):
def forward(
self,
x: float["b n d"], # nosied input audio # noqa: F722
cond: float["b n d"], # masked cond audio # noqa: F722
text: int["b nt"], # text # noqa: F722
time: float["b"] | float[""], # time step # noqa: F821 F722
mask: bool["b n"] | None = None, # noqa: F722
x: float["b n d"], # nosied input audio
cond: float["b n d"], # masked cond audio
text: int["b nt"], # text
time: float["b"] | float[""], # time step
mask: bool["b n"] | None = None,
drop_audio_cond: bool = False, # cfg for cond audio
drop_text: bool = False, # cfg for text
cfg_infer: bool = False, # cfg inference, pack cond & uncond forward

View File

@@ -6,6 +6,7 @@ nt - text sequence
nw - raw wave length
d - dimension
"""
# ruff: noqa: F722 F821
from __future__ import annotations
@@ -82,17 +83,17 @@ class CFM(nn.Module):
@torch.no_grad()
def sample(
self,
cond: float["b n d"] | float["b nw"], # noqa: F722
text: int["b nt"] | list[str], # noqa: F722
duration: int | int["b"], # noqa: F821
cond: float["b n d"] | float["b nw"],
text: int["b nt"] | list[str],
duration: int | int["b"],
*,
lens: int["b"] | None = None, # noqa: F821
lens: int["b"] | None = None,
steps=32,
cfg_strength=1.0,
sway_sampling_coef=None,
seed: int | None = None,
max_duration=4096,
vocoder: Callable[[float["b d n"]], float["b nw"]] | None = None, # noqa: F722
vocoder: Callable[[float["b d n"]], float["b nw"]] | None = None,
use_epss=True,
no_ref_audio=False,
duplicate_test=False,
@@ -229,10 +230,10 @@ class CFM(nn.Module):
def forward(
self,
inp: float["b n d"] | float["b nw"], # mel or raw wave # noqa: F722
text: int["b nt"] | list[str], # noqa: F722
inp: float["b n d"] | float["b nw"], # mel or raw wave
text: int["b nt"] | list[str],
*,
lens: int["b"] | None = None, # noqa: F821
lens: int["b"] | None = None,
noise_scheduler: str | None = None,
):
# handle raw wave
@@ -252,10 +253,9 @@ class CFM(nn.Module):
assert text.shape[0] == batch
# lens and mask
if not exists(lens):
if not exists(lens): # if lens not acquired by trainer from collate_fn
lens = torch.full((batch,), seq_len, device=device)
mask = lens_to_mask(lens, length=seq_len) # useless here, as collate_fn will pad to max length in batch
mask = lens_to_mask(lens, length=seq_len)
# get a random span to mask out for training conditionally
frac_lengths = torch.zeros((batch,), device=self.device).float().uniform_(*self.frac_lengths_mask)

View File

@@ -6,7 +6,7 @@ nt - text sequence
nw - raw wave length
d - dimension
"""
# flake8: noqa
# ruff: noqa: F722 F821
from __future__ import annotations
@@ -177,20 +177,23 @@ class ConvPositionEmbedding(nn.Module):
nn.Conv1d(dim, dim, kernel_size, groups=groups, padding=kernel_size // 2),
nn.Mish(),
)
self.layer_need_mask_idx = [i for i, layer in enumerate(self.conv1d) if isinstance(layer, nn.Conv1d)]
def forward(self, x: float["b n d"], mask: bool["b n"] | None = None):
if mask is not None:
mask = mask[..., None]
x = x.masked_fill(~mask, 0.0)
x = x.permute(0, 2, 1)
x = self.conv1d(x)
out = x.permute(0, 2, 1)
mask = mask.unsqueeze(1) # [B 1 N]
x = x.permute(0, 2, 1) # [B D N]
if mask is not None:
out = out.masked_fill(~mask, 0.0)
x = x.masked_fill(~mask, 0.0)
for i, block in enumerate(self.conv1d):
x = block(x)
if mask is not None and i in self.layer_need_mask_idx:
x = x.masked_fill(~mask, 0.0)
return out
x = x.permute(0, 2, 1) # [B N D]
return x
# rotary positional embedding related
@@ -435,8 +438,8 @@ class Attention(nn.Module):
# Attention processor
if is_package_available("flash_attn"):
from flash_attn import flash_attn_func, flash_attn_varlen_func
from flash_attn.bert_padding import pad_input, unpad_input
from flash_attn import flash_attn_varlen_func, flash_attn_func
class AttnProcessor:

View File
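On the `ConvPositionEmbedding` change earlier in this file: the mask is now re-applied after every `nn.Conv1d`, since a kernel wider than 1 writes non-zero values into padded frames. A small sketch showing the effect (illustration only, not repo code):

```python
# Sketch of why the mask is re-applied after each Conv1d above: without it, zeroed
# padding frames pick up non-zero values once a convolution with kernel_size > 1 runs,
# and those values would leak into the next layer.
import torch
import torch.nn as nn

torch.manual_seed(0)
conv = nn.Conv1d(4, 4, kernel_size=3, padding=1)

x = torch.randn(1, 4, 10)          # [B, D, N]
mask = torch.zeros(1, 1, 10, dtype=torch.bool)
mask[..., :6] = True               # only the first 6 frames are valid

x = x.masked_fill(~mask, 0.0)
y = conv(x)
print(y[..., 6:].abs().sum())      # non-zero: padded positions were contaminated
y = y.masked_fill(~mask, 0.0)      # re-masking restores zeros before the next layer
print(y[..., 6:].abs().sum())      # zero again
```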

@@ -1,3 +1,5 @@
# ruff: noqa: F722 F821
from __future__ import annotations
import os
@@ -5,7 +7,7 @@ import random
from collections import defaultdict
from importlib.resources import files
import jieba
import rjieba
import torch
from pypinyin import Style, lazy_pinyin
from torch.nn.utils.rnn import pad_sequence
@@ -48,7 +50,7 @@ def is_package_available(package_name: str) -> bool:
# tensor helpers
def lens_to_mask(t: int["b"], length: int | None = None) -> bool["b n"]: # noqa: F722 F821
def lens_to_mask(t: int["b"], length: int | None = None) -> bool["b n"]:
if not exists(length):
length = t.amax()
@@ -56,7 +58,7 @@ def lens_to_mask(t: int["b"], length: int | None = None) -> bool["b n"]: # noqa
return seq[None, :] < t[:, None]
def mask_from_start_end_indices(seq_len: int["b"], start: int["b"], end: int["b"]): # noqa: F722 F821
def mask_from_start_end_indices(seq_len: int["b"], start: int["b"], end: int["b"]):
max_seq_len = seq_len.max().item()
seq = torch.arange(max_seq_len, device=start.device).long()
start_mask = seq[None, :] >= start[:, None]
@@ -64,7 +66,7 @@ def mask_from_start_end_indices(seq_len: int["b"], start: int["b"], end: int["b"
return start_mask & end_mask
def mask_from_frac_lengths(seq_len: int["b"], frac_lengths: float["b"]): # noqa: F722 F821
def mask_from_frac_lengths(seq_len: int["b"], frac_lengths: float["b"]):
lengths = (frac_lengths * seq_len).long()
max_start = seq_len - lengths
@@ -75,7 +77,7 @@ def mask_from_frac_lengths(seq_len: int["b"], frac_lengths: float["b"]): # noqa
return mask_from_start_end_indices(seq_len, start, end)
def maybe_masked_mean(t: float["b n d"], mask: bool["b n"] = None) -> float["b d"]: # noqa: F722
def maybe_masked_mean(t: float["b n d"], mask: bool["b n"] = None) -> float["b d"]:
if not exists(mask):
return t.mean(dim=1)
@@ -87,7 +89,7 @@ def maybe_masked_mean(t: float["b n d"], mask: bool["b n"] = None) -> float["b d
# simple utf-8 tokenizer, since paper went character based
def list_str_to_tensor(text: list[str], padding_value=-1) -> int["b nt"]: # noqa: F722
def list_str_to_tensor(text: list[str], padding_value=-1) -> int["b nt"]:
list_tensors = [torch.tensor([*bytes(t, "UTF-8")]) for t in text] # ByT5 style
text = pad_sequence(list_tensors, padding_value=padding_value, batch_first=True)
return text
@@ -98,7 +100,7 @@ def list_str_to_idx(
text: list[str] | list[list[str]],
vocab_char_map: dict[str, int], # {char: idx}
padding_value=-1,
) -> int["b nt"]: # noqa: F722
) -> int["b nt"]:
list_idx_tensors = [torch.tensor([vocab_char_map.get(c, 0) for c in t]) for t in text] # pinyin or char style
text = pad_sequence(list_idx_tensors, padding_value=padding_value, batch_first=True)
return text
@@ -144,10 +146,6 @@ def get_tokenizer(dataset_name, tokenizer: str = "pinyin"):
def convert_char_to_pinyin(text_list, polyphone=True):
if jieba.dt.initialized is False:
jieba.default_logger.setLevel(50) # CRITICAL
jieba.initialize()
final_text_list = []
custom_trans = str.maketrans(
{";": ",", "": '"', "": '"', "": "'", "": "'"}
@@ -161,7 +159,7 @@ def convert_char_to_pinyin(text_list, polyphone=True):
for text in text_list:
char_list = []
text = text.translate(custom_trans)
for seg in jieba.cut(text):
for seg in rjieba.cut(text):
seg_byte_len = len(bytes(seg, "UTF-8"))
if seg_byte_len == len(seg): # if pure alphabets and symbols
if char_list and seg_byte_len > 1 and char_list[-1] not in " :'\"":

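On the jieba to rjieba swap above: rjieba needs no lazy `initialize()` or logger setup (the block removed above), and `rjieba.cut` returns a plain list of tokens. A minimal sketch assuming only that entry point:

```python
# Sketch of the segmentation call after the swap: rjieba.cut() returns a list of string tokens.
import rjieba

text = "F5-TTS支持中英混合文本"
for seg in rjieba.cut(text):
    # the byte-length check mirrors the pure-ASCII test used in convert_char_to_pinyin above
    print(seg, len(bytes(seg, "UTF-8")) == len(seg))
```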
View File

@@ -0,0 +1,3 @@
# runtime/triton_trtllm related
model.cache
model_repo/

View File

@@ -1,3 +1,3 @@
FROM nvcr.io/nvidia/tritonserver:24.12-py3
RUN pip install tritonclient[grpc] tensorrt-llm==0.16.0 torchaudio==2.5.1 jieba pypinyin librosa vocos
RUN pip install tritonclient[grpc] tensorrt-llm==0.16.0 torchaudio==2.5.1 rjieba pypinyin librosa vocos
WORKDIR /workspace

View File

@@ -1,59 +1,68 @@
## Triton Inference Serving Best Practice for F5-TTS
### Quick Start
Directly launch the service using docker compose.
### Setup
#### Option 1: Quick Start
```sh
# TODO: support F5TTS_v1_Base
MODEL=F5TTS_Base docker compose up
# Directly launch the service using docker compose
MODEL=F5TTS_v1_Base docker compose up
```
### Build Image
Build the docker image from scratch.
#### Option 2: Build from scratch
```sh
# Build the docker image
docker build . -f Dockerfile.server -t soar97/triton-f5-tts:24.12
```
### Create Docker Container
```sh
# Create Docker Container
your_mount_dir=/mnt:/mnt
docker run -it --name "f5-server" --gpus all --net host -v $your_mount_dir --shm-size=2g soar97/triton-f5-tts:24.12
```
### Export Models to TensorRT-LLM and Launch Server
Inside docker container, we would follow the official guide of TensorRT-LLM to build qwen and whisper TensorRT-LLM engines. See [here](https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/whisper).
### Build TensorRT-LLM Engines and Launch Server
Inside the docker container, follow the official TensorRT-LLM guide to build the qwen and whisper TensorRT-LLM engines. See [here](https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/models/core/whisper).
```sh
bash run.sh 0 4 F5TTS_Base
# F5TTS_v1_Base | F5TTS_Base | F5TTS_v1_Small | F5TTS_Small
bash run.sh 0 4 F5TTS_v1_Base
```
> [!NOTE]
> If using a custom checkpoint, set `ckpt_file` and `vocab_file` in `run.sh`.
> Remember to use the matching model version (`F5TTS_v1_*` for v1, `F5TTS_*` for v0).
>
> If using a checkpoint with a different structure, see `scripts/convert_checkpoint.py` and modify it if necessary.
> [!IMPORTANT]
> If the model was trained or finetuned in fp32, add the `--dtype float32` flag when converting the checkpoint in `run.sh` phase 1.
### HTTP Client
```sh
python3 client_http.py
```
### Benchmark using Client-Server Mode
### Benchmarking
#### Using Client-Server Mode
```sh
# bash run.sh 5 5 F5TTS_v1_Base
num_task=2
python3 client_grpc.py --num-tasks $num_task --huggingface-dataset yuekai/seed_tts --split-name wenetspeech4tts
```
### Benchmark using Offline TRT-LLM Mode
#### Using Offline TRT-LLM Mode
```sh
# bash run.sh 7 7 F5TTS_v1_Base
batch_size=1
split_name=wenetspeech4tts
backend_type=trt
log_dir=./log_benchmark_batch_size_${batch_size}_${split_name}_${backend_type}
log_dir=./tests/benchmark_batch_size_${batch_size}_${split_name}_${backend_type}
rm -r $log_dir
ln -s model_repo_f5_tts/f5_tts/1/f5_tts_trtllm.py ./
torchrun --nproc_per_node=1 \
benchmark.py --output-dir $log_dir \
--batch-size $batch_size \
--enable-warmup \
--split-name $split_name \
--model-path $F5_TTS_HF_DOWNLOAD_PATH/$model/model_1200000.pt \
--vocab-file $F5_TTS_HF_DOWNLOAD_PATH/$model/vocab.txt \
--vocoder-trt-engine-path $vocoder_trt_engine_path \
--model-path $ckpt_file \
--vocab-file $vocab_file \
--vocoder-trt-engine-path $VOCODER_TRT_ENGINE_PATH \
--backend-type $backend_type \
--tllm-model-dir $F5_TTS_TRT_LLM_ENGINE_PATH || exit 1
--tllm-model-dir $TRTLLM_ENGINE_DIR || exit 1
```
### Benchmark Results
@@ -66,4 +75,5 @@ Decoding on a single L20 GPU, using 26 different prompt_audio & target_text pair
| F5-TTS Base (Vocos) | 1 (Batch_size) | - | 0.1467 | Offline Pytorch |
### Credits
1. [F5-TTS-TRTLLM](https://github.com/Bigfishering/f5-tts-trtllm)
1. [Yuekai Zhang](https://github.com/yuekaizhang)
2. [F5-TTS-TRTLLM](https://github.com/Bigfishering/f5-tts-trtllm)

View File

@@ -1,5 +1,5 @@
# Copyright (c) 2024 Tsinghua Univ. (authors: Xingchen Song)
# 2025 authors: Yuekai Zhang
# 2025 (authors: Yuekai Zhang)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -19,39 +19,45 @@ benchmark.py --output-dir $log_dir \
--batch-size $batch_size \
--enable-warmup \
--split-name $split_name \
--model-path $F5_TTS_HF_DOWNLOAD_PATH/$model/model_1200000.pt \
--vocab-file $F5_TTS_HF_DOWNLOAD_PATH/$model/vocab.txt \
--model-path $CKPT_DIR/$model/model_1200000.pt \
--vocab-file $CKPT_DIR/$model/vocab.txt \
--vocoder-trt-engine-path $vocoder_trt_engine_path \
--backend-type $backend_type \
--tllm-model-dir $F5_TTS_TRT_LLM_ENGINE_PATH || exit 1
--tllm-model-dir $TRTLLM_ENGINE_DIR || exit 1
"""
import argparse
import importlib
import json
import os
import sys
import time
from typing import Dict, List, Union
import datasets
import jieba
import tensorrt as trt
import torch
import torch.distributed as dist
import torch.nn.functional as F
import torchaudio
from datasets import load_dataset
from f5_tts_trtllm import F5TTS
from huggingface_hub import hf_hub_download
from pypinyin import Style, lazy_pinyin
from tensorrt_llm._utils import trt_dtype_to_torch
from tensorrt_llm.logger import logger
from tensorrt_llm.runtime.session import Session, TensorInfo
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, DistributedSampler
from tqdm import tqdm
from vocos import Vocos
sys.path.append(f"{os.path.dirname(os.path.abspath(__file__))}/../../../../src/")
from f5_tts.eval.utils_eval import padded_mel_batch
from f5_tts.model.modules import get_vocos_mel_spectrogram
from f5_tts.model.utils import convert_char_to_pinyin, get_tokenizer, list_str_to_idx
F5TTS = importlib.import_module("model_repo_f5_tts.f5_tts.1.f5_tts_trtllm").F5TTS
torch.manual_seed(0)
@@ -111,22 +117,20 @@ def get_args():
return args
def padded_mel_batch(ref_mels, max_seq_len):
padded_ref_mels = []
for mel in ref_mels:
# pad along the last dimension
padded_ref_mel = F.pad(mel, (0, 0, 0, max_seq_len - mel.shape[0]), value=0)
padded_ref_mels.append(padded_ref_mel)
padded_ref_mels = torch.stack(padded_ref_mels)
return padded_ref_mels
def data_collator(batch, vocab_char_map, device="cuda", use_perf=False):
if use_perf:
torch.cuda.nvtx.range_push("data_collator")
target_sample_rate = 24000
target_rms = 0.1
ids, ref_mel_list, ref_mel_len_list, estimated_reference_target_mel_len, reference_target_texts_list = (
(
ids,
ref_rms_list,
ref_mel_list,
ref_mel_len_list,
estimated_reference_target_mel_len,
reference_target_texts_list,
) = (
[],
[],
[],
[],
@@ -148,6 +152,7 @@ def data_collator(batch, vocab_char_map, device="cuda", use_perf=False):
)
ref_audio_org = torch.from_numpy(ref_audio_org).unsqueeze(0).float()
ref_rms = torch.sqrt(torch.mean(torch.square(ref_audio_org)))
ref_rms_list.append(ref_rms)
if ref_rms < target_rms:
ref_audio_org = ref_audio_org * target_rms / ref_rms
@@ -159,40 +164,31 @@ def data_collator(batch, vocab_char_map, device="cuda", use_perf=False):
if use_perf:
torch.cuda.nvtx.range_push(f"mel_spectrogram {i}")
ref_mel = mel_spectrogram(ref_audio, vocoder="vocos", device="cuda")
ref_audio = ref_audio.to("cuda")
ref_mel = get_vocos_mel_spectrogram(ref_audio).squeeze(0)
if use_perf:
torch.cuda.nvtx.range_pop()
ref_mel = ref_mel.squeeze()
ref_mel_len = ref_mel.shape[0]
assert ref_mel.shape[1] == 100
ref_mel_len = ref_mel.shape[-1]
assert ref_mel.shape[0] == 100
ref_mel_list.append(ref_mel)
ref_mel_len_list.append(ref_mel_len)
estimated_reference_target_mel_len.append(
int(ref_mel.shape[0] * (1 + len(target_text.encode("utf-8")) / len(prompt_text.encode("utf-8"))))
int(ref_mel_len * (1 + len(target_text.encode("utf-8")) / len(prompt_text.encode("utf-8"))))
)
max_seq_len = max(estimated_reference_target_mel_len)
ref_mel_batch = padded_mel_batch(ref_mel_list, max_seq_len)
ref_mel_batch = padded_mel_batch(ref_mel_list)
ref_mel_len_batch = torch.LongTensor(ref_mel_len_list)
pinyin_list = convert_char_to_pinyin(reference_target_texts_list, polyphone=True)
text_pad_sequence = list_str_to_idx(pinyin_list, vocab_char_map)
for i, item in enumerate(text_pad_sequence):
text_pad_sequence[i] = F.pad(
item, (0, estimated_reference_target_mel_len[i] - len(item)), mode="constant", value=-1
)
text_pad_sequence[i] += 1 # WAR: 0 is reserved for padding token, hard coding in F5-TTS
text_pad_sequence = pad_sequence(text_pad_sequence, padding_value=-1, batch_first=True).to(device)
text_pad_sequence = F.pad(
text_pad_sequence, (0, max_seq_len - text_pad_sequence.shape[1]), mode="constant", value=-1
)
if use_perf:
torch.cuda.nvtx.range_pop()
return {
"ids": ids,
"ref_rms_list": ref_rms_list,
"ref_mel_batch": ref_mel_batch,
"ref_mel_len_batch": ref_mel_len_batch,
"text_pad_sequence": text_pad_sequence,
@@ -216,72 +212,6 @@ def init_distributed():
return world_size, local_rank, rank
def get_tokenizer(vocab_file_path: str):
"""
tokenizer - "pinyin" do g2p for only chinese characters, need .txt vocab_file
- "char" for char-wise tokenizer, need .txt vocab_file
- "byte" for utf-8 tokenizer
- "custom" if you're directly passing in a path to the vocab.txt you want to use
vocab_size - if use "pinyin", all available pinyin types, common alphabets (also those with accent) and symbols
- if use "char", derived from unfiltered character & symbol counts of custom dataset
- if use "byte", set to 256 (unicode byte range)
"""
with open(vocab_file_path, "r", encoding="utf-8") as f:
vocab_char_map = {}
for i, char in enumerate(f):
vocab_char_map[char[:-1]] = i
vocab_size = len(vocab_char_map)
return vocab_char_map, vocab_size
def convert_char_to_pinyin(reference_target_texts_list, polyphone=True):
final_reference_target_texts_list = []
custom_trans = str.maketrans(
{";": ",", "": '"', "": '"', "": "'", "": "'"}
) # add custom trans here, to address oov
def is_chinese(c):
return "\u3100" <= c <= "\u9fff" # common chinese characters
for text in reference_target_texts_list:
char_list = []
text = text.translate(custom_trans)
for seg in jieba.cut(text):
seg_byte_len = len(bytes(seg, "UTF-8"))
if seg_byte_len == len(seg): # if pure alphabets and symbols
if char_list and seg_byte_len > 1 and char_list[-1] not in " :'\"":
char_list.append(" ")
char_list.extend(seg)
elif polyphone and seg_byte_len == 3 * len(seg): # if pure east asian characters
seg_ = lazy_pinyin(seg, style=Style.TONE3, tone_sandhi=True)
for i, c in enumerate(seg):
if is_chinese(c):
char_list.append(" ")
char_list.append(seg_[i])
else: # if mixed characters, alphabets and symbols
for c in seg:
if ord(c) < 256:
char_list.extend(c)
elif is_chinese(c):
char_list.append(" ")
char_list.extend(lazy_pinyin(c, style=Style.TONE3, tone_sandhi=True))
else:
char_list.append(c)
final_reference_target_texts_list.append(char_list)
return final_reference_target_texts_list
def list_str_to_idx(
text: Union[List[str], List[List[str]]],
vocab_char_map: Dict[str, int], # {char: idx}
padding_value=-1,
):
list_idx_tensors = [torch.tensor([vocab_char_map.get(c, 0) for c in t]) for t in text] # pinyin or char style
# text = pad_sequence(list_idx_tensors, padding_value=padding_value, batch_first=True)
return list_idx_tensors
def load_vocoder(
vocoder_name="vocos", is_local=False, local_path="", device="cuda", hf_cache_dir=None, vocoder_trt_engine_path=None
):
@@ -316,29 +246,11 @@ def load_vocoder(
return vocoder
def mel_spectrogram(waveform, vocoder="vocos", device="cuda"):
if vocoder == "vocos":
mel_stft = torchaudio.transforms.MelSpectrogram(
sample_rate=24000,
n_fft=1024,
win_length=1024,
hop_length=256,
n_mels=100,
power=1,
center=True,
normalized=False,
norm=None,
).to(device)
mel = mel_stft(waveform.to(device))
mel = mel.clamp(min=1e-5).log()
return mel.transpose(1, 2)
class VocosTensorRT:
def __init__(self, engine_path="./vocos_vocoder.plan", stream=None):
TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
trt.init_libnvinfer_plugins(TRT_LOGGER, namespace="")
logger.info(f"Loading vae engine from {engine_path}")
logger.info(f"Loading vocoder engine from {engine_path}")
self.engine_path = engine_path
with open(engine_path, "rb") as f:
engine_buffer = f.read()
@@ -368,34 +280,37 @@ def main():
world_size, local_rank, rank = init_distributed()
device = torch.device(f"cuda:{local_rank}")
vocab_char_map, vocab_size = get_tokenizer(args.vocab_file)
vocab_char_map, vocab_size = get_tokenizer(args.vocab_file, "custom")
tllm_model_dir = args.tllm_model_dir
config_file = os.path.join(tllm_model_dir, "config.json")
with open(config_file) as f:
config = json.load(f)
with open(os.path.join(tllm_model_dir, "config.json")) as f:
tllm_model_config = json.load(f)
if args.backend_type == "trt":
model = F5TTS(
config, debug_mode=False, tllm_model_dir=tllm_model_dir, model_path=args.model_path, vocab_size=vocab_size
tllm_model_config,
debug_mode=False,
tllm_model_dir=tllm_model_dir,
model_path=args.model_path,
vocab_size=vocab_size,
)
elif args.backend_type == "pytorch":
import sys
sys.path.append(f"{os.path.dirname(os.path.abspath(__file__))}/../../../../src/")
from f5_tts.infer.utils_infer import load_model
from f5_tts.model import DiT
F5TTS_model_cfg = dict(
dim=1024,
depth=22,
heads=16,
ff_mult=2,
text_dim=512,
conv_layers=4,
pe_attn_head=1,
text_mask_padding=False,
pretrained_config = tllm_model_config["pretrained_config"]
pt_model_config = dict(
dim=pretrained_config["hidden_size"],
depth=pretrained_config["num_hidden_layers"],
heads=pretrained_config["num_attention_heads"],
ff_mult=pretrained_config["ff_mult"],
text_dim=pretrained_config["text_dim"],
text_mask_padding=pretrained_config["text_mask_padding"],
conv_layers=pretrained_config["conv_layers"],
pe_attn_head=pretrained_config["pe_attn_head"],
# attn_backend="flash_attn",
# attn_mask_enabled=True,
)
model = load_model(DiT, F5TTS_model_cfg, args.model_path)
model = load_model(DiT, pt_model_config, args.model_path)
vocoder = load_vocoder(
vocoder_name=args.vocoder, device=device, vocoder_trt_engine_path=args.vocoder_trt_engine_path
@@ -445,20 +360,23 @@ def main():
ref_mels, ref_mel_lens = batch["ref_mel_batch"].to(device), batch["ref_mel_len_batch"].to(device)
text_pad_seq = batch["text_pad_sequence"].to(device)
total_mel_lens = batch["estimated_reference_target_mel_len"]
cond_pad_seq = F.pad(ref_mels, (0, 0, 0, max(total_mel_lens) - ref_mels.shape[1], 0, 0))
if args.backend_type == "trt":
_ = model.sample(
text_pad_seq, ref_mels, ref_mel_lens, total_mel_lens, remove_input_padding=args.remove_input_padding
text_pad_seq,
cond_pad_seq,
ref_mel_lens,
total_mel_lens,
remove_input_padding=args.remove_input_padding,
)
elif args.backend_type == "pytorch":
total_mel_lens = torch.tensor(total_mel_lens, device=device)
with torch.inference_mode():
text_pad_seq -= 1
text_pad_seq[text_pad_seq == -2] = -1
total_mel_lens = torch.tensor(total_mel_lens, device=device)
generated, _ = model.sample(
cond=ref_mels,
text=text_pad_seq,
duration=total_mel_lens,
steps=16,
steps=32,
cfg_strength=2.0,
sway_sampling_coef=-1,
)
@@ -478,13 +396,13 @@ def main():
ref_mels, ref_mel_lens = batch["ref_mel_batch"].to(device), batch["ref_mel_len_batch"].to(device)
text_pad_seq = batch["text_pad_sequence"].to(device)
total_mel_lens = batch["estimated_reference_target_mel_len"]
cond_pad_seq = F.pad(ref_mels, (0, 0, 0, max(total_mel_lens) - ref_mels.shape[1], 0, 0))
if args.use_perf:
torch.cuda.nvtx.range_pop()
if args.backend_type == "trt":
generated, cost_time = model.sample(
text_pad_seq,
ref_mels,
cond_pad_seq,
ref_mel_lens,
total_mel_lens,
remove_input_padding=args.remove_input_padding,
@@ -494,20 +412,20 @@ def main():
total_mel_lens = torch.tensor(total_mel_lens, device=device)
with torch.inference_mode():
start_time = time.time()
text_pad_seq -= 1
text_pad_seq[text_pad_seq == -2] = -1
generated, _ = model.sample(
cond=ref_mels,
text=text_pad_seq,
duration=total_mel_lens,
lens=ref_mel_lens,
steps=16,
steps=32,
cfg_strength=2.0,
sway_sampling_coef=-1,
)
cost_time = time.time() - start_time
decoding_time += cost_time
vocoder_start_time = time.time()
target_rms = 0.1
target_sample_rate = 24000
for i, gen in enumerate(generated):
gen = gen[ref_mel_lens[i] : total_mel_lens[i], :].unsqueeze(0)
gen_mel_spec = gen.permute(0, 2, 1).to(torch.float32)
@@ -519,13 +437,10 @@ def main():
torch.cuda.nvtx.range_pop()
else:
generated_wave = vocoder(gen_mel_spec).squeeze(0).cpu()
target_rms = 0.1
target_sample_rate = 24_000
# if ref_rms_list[i] < target_rms:
# generated_wave = generated_wave * ref_rms_list[i] / target_rms
rms = torch.sqrt(torch.mean(torch.square(generated_wave)))
if rms < target_rms:
generated_wave = generated_wave * target_rms / rms
if batch["ref_rms_list"][i] < target_rms:
generated_wave = generated_wave * batch["ref_rms_list"][i] / target_rms
utt = batch["ids"][i]
torchaudio.save(
f"{args.output_dir}/{utt}.wav",

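The `ref_rms_list` plumbing above implements a loudness round trip: quiet references are boosted to `target_rms` before inference, and the generated wave is scaled back by the stored reference RMS afterwards. A short sketch of that logic (stand-in tensors, not repo code):

```python
# Sketch of the reference-RMS round trip now done in benchmark.py.
import torch

target_rms = 0.1
ref_audio = 0.02 * torch.randn(1, 24000)             # a quiet reference clip (illustrative)
ref_rms = torch.sqrt(torch.mean(torch.square(ref_audio)))
if ref_rms < target_rms:
    ref_audio = ref_audio * target_rms / ref_rms      # boost before mel extraction / inference

generated_wave = torch.randn(24000)                   # stand-in for the vocoder output
if ref_rms < target_rms:
    generated_wave = generated_wave * ref_rms / target_rms  # restore the original loudness
```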
View File

@@ -30,15 +30,6 @@ python3 client_grpc.py \
--huggingface-dataset yuekai/seed_tts \
--split-name test_zh \
--log-dir ./log_concurrent_tasks_${num_task}
# For offline Spark-TTS-0.5B
python3 client_grpc.py \
--server-addr localhost \
--model-name spark_tts \
--num-tasks $num_task \
--huggingface-dataset yuekai/seed_tts \
--split-name wenetspeech4tts \
--log-dir ./log_concurrent_tasks_${num_task}
"""
import argparse
@@ -176,8 +167,7 @@ def get_args():
"--model-name",
type=str,
default="f5_tts",
choices=["f5_tts", "spark_tts"],
help="triton model_repo module name to request: transducer for k2, attention_rescoring for wenet offline, streaming_wenet for wenet streaming, infer_pipeline for paraformer large offline",
help="triton model_repo module name to request",
)
parser.add_argument(
@@ -206,7 +196,7 @@ def get_args():
"--log-dir",
type=str,
required=False,
default="./tmp",
default="./tests/client_grpc",
help="log directory",
)
@@ -230,8 +220,7 @@ def load_audio(wav_path, target_sample_rate=24000):
if sample_rate != target_sample_rate:
from scipy.signal import resample
num_samples = int(len(waveform) * (target_sample_rate / sample_rate))
waveform = resample(waveform, num_samples)
waveform = resample(waveform, int(len(waveform) * (target_sample_rate / sample_rate)))
return waveform, target_sample_rate

View File

@@ -24,6 +24,7 @@
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import argparse
import os
import numpy as np
import requests
@@ -65,33 +66,32 @@ def get_args():
"--model-name",
type=str,
default="f5_tts",
choices=["f5_tts", "spark_tts"],
help="triton model_repo module name to request",
)
parser.add_argument(
"--output-audio",
type=str,
default="output.wav",
default="tests/client_http.wav",
help="Path to save the output audio",
)
return parser.parse_args()
def prepare_request(
samples,
waveform,
reference_text,
target_text,
sample_rate=24000,
audio_save_dir: str = "./",
):
assert len(samples.shape) == 1, "samples should be 1D"
lengths = np.array([[len(samples)]], dtype=np.int32)
samples = samples.reshape(1, -1).astype(np.float32)
assert len(waveform.shape) == 1, "waveform should be 1D"
lengths = np.array([[len(waveform)]], dtype=np.int32)
waveform = waveform.reshape(1, -1).astype(np.float32)
data = {
"inputs": [
{"name": "reference_wav", "shape": samples.shape, "datatype": "FP32", "data": samples.tolist()},
{"name": "reference_wav", "shape": waveform.shape, "datatype": "FP32", "data": waveform.tolist()},
{
"name": "reference_wav_len",
"shape": lengths.shape,
@@ -109,16 +109,15 @@ def prepare_request(
def load_audio(wav_path, target_sample_rate=24000):
assert target_sample_rate == 24000, "hard coding in server"
if isinstance(wav_path, dict):
samples = wav_path["array"]
waveform = wav_path["array"]
sample_rate = wav_path["sampling_rate"]
else:
samples, sample_rate = sf.read(wav_path)
waveform, sample_rate = sf.read(wav_path)
if sample_rate != target_sample_rate:
from scipy.signal import resample
num_samples = int(len(samples) * (target_sample_rate / sample_rate))
samples = resample(samples, num_samples)
return samples, target_sample_rate
waveform = resample(waveform, int(len(waveform) * (target_sample_rate / sample_rate)))
return waveform, target_sample_rate
if __name__ == "__main__":
@@ -128,11 +127,11 @@ if __name__ == "__main__":
server_url = f"http://{server_url}"
url = f"{server_url}/v2/models/{args.model_name}/infer"
samples, sr = load_audio(args.reference_audio)
waveform, sr = load_audio(args.reference_audio)
assert sr == 24000, "sample rate hardcoded in server"
samples = np.array(samples, dtype=np.float32)
data = prepare_request(samples, args.reference_text, args.target_text)
waveform = np.array(waveform, dtype=np.float32)
data = prepare_request(waveform, args.reference_text, args.target_text)
rsp = requests.post(
url, headers={"Content-Type": "application/json"}, json=data, verify=False, params={"request_id": "0"}
@@ -140,4 +139,5 @@ if __name__ == "__main__":
result = rsp.json()
audio = result["outputs"][0]["data"]
audio = np.array(audio, dtype=np.float32)
os.makedirs(os.path.dirname(args.output_audio), exist_ok=True)
sf.write(args.output_audio, audio, 24000, "PCM_16")
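For context on `prepare_request` above, a sketch of the full KServe v2 HTTP payload; the audio inputs match the hunk, while the text input names (`reference_text`, `target_text`) and shapes are assumptions here, since that part of the payload is truncated in the diff:

```python
# Sketch of the v2 HTTP inference request built by client_http.py (text input names assumed).
import numpy as np
import requests

waveform = np.zeros(24000, dtype=np.float32)           # 1 s of silence as a stand-in reference
lengths = np.array([[len(waveform)]], dtype=np.int32)
data = {
    "inputs": [
        {"name": "reference_wav", "shape": [1, len(waveform)], "datatype": "FP32",
         "data": waveform.reshape(1, -1).tolist()},
        {"name": "reference_wav_len", "shape": lengths.shape, "datatype": "INT32",
         "data": lengths.tolist()},
        {"name": "reference_text", "shape": [1, 1], "datatype": "BYTES", "data": ["some reference text"]},
        {"name": "target_text", "shape": [1, 1], "datatype": "BYTES", "data": ["some target text"]},
    ]
}
rsp = requests.post(
    "http://localhost:8000/v2/models/f5_tts/infer",
    headers={"Content-Type": "application/json"},
    json=data,
    params={"request_id": "0"},
)
print(rsp.status_code)
```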

View File

@@ -12,6 +12,7 @@ import torch.nn.functional as F
from tensorrt_llm._utils import str_dtype_to_torch, trt_dtype_to_torch
from tensorrt_llm.logger import logger
from tensorrt_llm.runtime.session import Session
from torch.nn.utils.rnn import pad_sequence
def remove_tensor_padding(input_tensor, input_tensor_lengths=None):
@@ -32,26 +33,35 @@ def remove_tensor_padding(input_tensor, input_tensor_lengths=None):
class TextEmbedding(nn.Module):
def __init__(self, text_num_embeds, text_dim, conv_layers=0, conv_mult=2, precompute_max_pos=4096):
def __init__(
self, text_num_embeds, text_dim, mask_padding=True, conv_layers=0, conv_mult=2, precompute_max_pos=4096
):
super().__init__()
self.text_embed = nn.Embedding(text_num_embeds + 1, text_dim) # use 0 as filler token
self.mask_padding = mask_padding
self.register_buffer("freqs_cis", precompute_freqs_cis(text_dim, precompute_max_pos), persistent=False)
self.text_blocks = nn.Sequential(*[ConvNeXtV2Block(text_dim, text_dim * conv_mult) for _ in range(conv_layers)])
def forward(self, text):
# only keep tensors with value not -1
text_mask = text != -1
text_pad_cut_off_index = text_mask.sum(dim=1).max()
def forward(self, text, seq_len, drop_text=False):
text = text + 1
text = text[:, :seq_len] # curtail if character tokens are more than the mel spec tokens
text = F.pad(text, (0, seq_len - text.shape[1]), value=0)
if self.mask_padding:
text_mask = text == 0
if drop_text: # cfg for text
text = torch.zeros_like(text)
text = self.text_embed(text) # b n -> b n d
text = text + self.freqs_cis[:seq_len, :]
if self.mask_padding:
text = text.masked_fill(text_mask.unsqueeze(-1).expand(-1, -1, text.size(-1)), 0.0)
for block in self.text_blocks:
text = block(text)
text = text.masked_fill(text_mask.unsqueeze(-1).expand(-1, -1, text.size(-1)), 0.0)
else:
text = self.text_blocks(text)
text = text[:, :text_pad_cut_off_index]
text = self.text_embed(text)
text = text + self.freqs_cis[: text.shape[1], :]
for block in self.text_blocks:
text = block(text)
# padding text to the original length
# text shape: B,seq_len,C
# pad at the second dimension
text = F.pad(text, (0, 0, 0, text_mask.shape[1] - text.shape[1], 0, 0), value=0)
return text
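A toy illustration of the mask_padding path above (values are made up): padding positions become index 0 after the +1 shift and are zeroed both before and after each conv block, so padding cannot leak into neighbouring frames.

import torch

text = torch.tensor([[3, 7, 5, 0, 0]])        # one sequence, last two positions are padding (0)
emb = torch.nn.Embedding(10, 4)(text)         # b n -> b n d
pad_mask = (text == 0).unsqueeze(-1)          # b n 1, broadcasts over the feature dim
emb = emb.masked_fill(pad_mask, 0.0)          # zero padded frames before the conv blocks
# ... apply a ConvNeXtV2Block here, then re-zero so the convolution's spill-over is removed:
emb = emb.masked_fill(pad_mask, 0.0)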
@@ -112,20 +122,33 @@ def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, theta_resca
return torch.cat([freqs_cos, freqs_sin], dim=-1)
def load_checkpoint(ckpt_path, use_ema=True):
checkpoint = torch.load(ckpt_path, weights_only=True)
def get_text_embed_dict(ckpt_path, use_ema=True):
ckpt_type = ckpt_path.split(".")[-1]
if ckpt_type == "safetensors":
from safetensors.torch import load_file
checkpoint = load_file(ckpt_path)
else:
checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=True)
if use_ema:
if ckpt_type == "safetensors":
checkpoint = {"ema_model_state_dict": checkpoint}
checkpoint["model_state_dict"] = {
k.replace("ema_model.", ""): v
for k, v in checkpoint["ema_model_state_dict"].items()
if k not in ["initted", "step"]
}
dict_state = checkpoint["model_state_dict"]
else:
if ckpt_type == "safetensors":
checkpoint = {"model_state_dict": checkpoint}
model_params = checkpoint["model_state_dict"]
text_embed_dict = {}
for key in dict_state.keys():
for key in model_params.keys():
# transformer.text_embed.text_embed.weight -> text_embed.weight
if "text_embed" in key:
text_embed_dict[key.replace("transformer.text_embed.", "")] = dict_state[key]
text_embed_dict[key.replace("transformer.text_embed.", "")] = model_params[key]
return text_embed_dict
@@ -194,18 +217,16 @@ class F5TTS(object):
self.max_mel_len = 4096
self.text_embedding = TextEmbedding(
text_num_embeds=vocab_size, text_dim=512, conv_layers=4, precompute_max_pos=self.max_mel_len
text_num_embeds=vocab_size,
text_dim=config["pretrained_config"]["text_dim"],
mask_padding=config["pretrained_config"]["text_mask_padding"],
conv_layers=config["pretrained_config"]["conv_layers"],
precompute_max_pos=self.max_mel_len,
).to(self.device)
self.text_embedding.load_state_dict(load_checkpoint(model_path), strict=True)
self.text_embedding.load_state_dict(get_text_embed_dict(model_path), strict=True)
self.target_audio_sample_rate = 24000
self.target_rms = 0.15 # target rms for audio
self.n_fft = 1024
self.win_length = 1024
self.hop_length = 256
self.n_mel_channels = 100
# self.max_mel_len = 3000
self.head_dim = 64
self.n_mel_channels = config["pretrained_config"]["mel_dim"]
self.head_dim = config["pretrained_config"]["dim_head"]
self.base_rescale_factor = 1.0
self.interpolation_factor = 1.0
base = 10000.0 * self.base_rescale_factor ** (self.head_dim / (self.head_dim - 2))
@@ -214,14 +235,23 @@ class F5TTS(object):
self.freqs = freqs.repeat_interleave(2, dim=-1).unsqueeze(0)
self.rope_cos = self.freqs.cos().half()
self.rope_sin = self.freqs.sin().half()
self.nfe_steps = 16
t = torch.linspace(0, 1, self.nfe_steps + 1, dtype=torch.float32)
time_step = t + (-1.0) * (torch.cos(torch.pi * 0.5 * t) - 1 + t)
self.nfe_steps = 32
epss = {
5: [0, 2, 4, 8, 16, 32],
6: [0, 2, 4, 6, 8, 16, 32],
7: [0, 2, 4, 6, 8, 16, 24, 32],
10: [0, 2, 4, 6, 8, 12, 16, 20, 24, 28, 32],
12: [0, 2, 4, 6, 8, 10, 12, 14, 16, 20, 24, 28, 32],
16: [0, 1, 2, 3, 4, 5, 6, 7, 8, 10, 12, 14, 16, 20, 24, 28, 32],
}
t = 1 / 32 * torch.tensor(epss.get(self.nfe_steps, list(range(self.nfe_steps + 1))), dtype=torch.float32)
time_step = 1 - torch.cos(torch.pi * t / 2)
delta_t = torch.diff(time_step)
# WAR: hard coding 256 here
tmp_dim = 256
time_expand = torch.zeros((1, self.nfe_steps, tmp_dim), dtype=torch.float32)
half_dim = tmp_dim // 2
freq_embed_dim = 256 # Warning: hard coding 256 here
time_expand = torch.zeros((1, self.nfe_steps, freq_embed_dim), dtype=torch.float32)
half_dim = freq_embed_dim // 2
emb_factor = math.log(10000) / (half_dim - 1)
emb_factor = 1000.0 * torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb_factor)
for i in range(self.nfe_steps):
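To visualize the new sampling schedule, a small sketch reproducing the cosine-warped time grid for the 16-step entry of epss (printed values only illustrate the shape; the default nfe_steps=32 falls back to a uniform grid):

import torch

steps_16 = [0, 1, 2, 3, 4, 5, 6, 7, 8, 10, 12, 14, 16, 20, 24, 28, 32]
t = torch.tensor(steps_16, dtype=torch.float32) / 32   # non-uniform grid on [0, 1]
time_step = 1 - torch.cos(torch.pi * t / 2)            # cosine warp, steps denser near t=0
delta_t = torch.diff(time_step)                        # per-step sizes used for the ODE integration
print(time_step)
print(delta_t)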
@@ -344,7 +374,7 @@ class F5TTS(object):
def sample(
self,
text_pad_sequence: torch.Tensor,
ref_mel_batch: torch.Tensor,
cond_pad_sequence: torch.Tensor,
ref_mel_len_batch: torch.Tensor,
estimated_reference_target_mel_len: List[int],
remove_input_padding: bool = False,
@@ -353,26 +383,43 @@ class F5TTS(object):
if use_perf:
torch.cuda.nvtx.range_push("text embedding")
batch = text_pad_sequence.shape[0]
max_seq_len = ref_mel_batch.shape[1]
max_seq_len = cond_pad_sequence.shape[1]
text_pad_sequence_drop = torch.cat(
(text_pad_sequence, torch.zeros((1, text_pad_sequence.shape[1]), dtype=torch.int32).to(self.device)), dim=0
# get text_embed one by one to avoid misalignment
text_and_drop_embedding_list = []
for i in range(batch):
text_embedding_i = self.text_embedding(
text_pad_sequence[i].unsqueeze(0).to(self.device),
estimated_reference_target_mel_len[i],
drop_text=False,
)
text_embedding_drop_i = self.text_embedding(
text_pad_sequence[i].unsqueeze(0).to(self.device),
estimated_reference_target_mel_len[i],
drop_text=True,
)
text_and_drop_embedding_list.extend([text_embedding_i[0], text_embedding_drop_i[0]])
# pad separately computed text_embed to form batch with max_seq_len
text_and_drop_embedding = pad_sequence(
text_and_drop_embedding_list,
batch_first=True,
padding_value=0,
)
text_embedding = text_and_drop_embedding[0::2]
text_embedding_drop = text_and_drop_embedding[1::2]
text_embedding_drop_list = []
for i in range(batch + 1):
text_embedding_drop_list.append(self.text_embedding(text_pad_sequence_drop[i].unsqueeze(0).to(self.device)))
text_embedding_drop_condition = torch.cat(text_embedding_drop_list, dim=0)
text_embedding = text_embedding_drop_condition[:-1]
# text_embedding_drop B,T,C batch should be the same
text_embedding_drop = text_embedding_drop_condition[-1].unsqueeze(0).repeat(batch, 1, 1)
noise = torch.randn_like(ref_mel_batch).to(self.device)
noise = torch.randn_like(cond_pad_sequence).to(self.device)
rope_cos = self.rope_cos[:, :max_seq_len, :].float().repeat(batch, 1, 1)
rope_sin = self.rope_sin[:, :max_seq_len, :].float().repeat(batch, 1, 1)
cat_mel_text = torch.cat((ref_mel_batch, text_embedding), dim=-1)
cat_mel_text = torch.cat(
(
cond_pad_sequence,
text_embedding,
),
dim=-1,
)
cat_mel_text_drop = torch.cat(
(
torch.zeros((batch, max_seq_len, self.n_mel_channels), dtype=torch.float32).to(self.device),

View File

@@ -26,9 +26,8 @@
import json
import os
import jieba
import rjieba
import torch
import torch.nn.functional as F
import torchaudio
import triton_python_backend_utils as pb_utils
from f5_tts_trtllm import F5TTS
@@ -67,7 +66,7 @@ def convert_char_to_pinyin(reference_target_texts_list, polyphone=True):
for text in reference_target_texts_list:
char_list = []
text = text.translate(custom_trans)
for seg in jieba.cut(text):
for seg in rjieba.cut(text):
seg_byte_len = len(bytes(seg, "UTF-8"))
if seg_byte_len == len(seg): # if pure alphabets and symbols
if char_list and seg_byte_len > 1 and char_list[-1] not in " :'\"":
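rjieba is used here in place of jieba for word segmentation; a minimal check, assuming the rjieba package from PyPI (the sample sentence is illustrative and exact segment boundaries depend on the bundled dictionary):

import rjieba

mixed = "今天天气不错, let's go outside"
segments = rjieba.cut(mixed)   # returns a list of word segments, mixing Chinese words and ASCII runs
print(segments)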
@@ -99,7 +98,8 @@ def list_str_to_idx(
padding_value=-1,
): # noqa: F722
list_idx_tensors = [torch.tensor([vocab_char_map.get(c, 0) for c in t]) for t in text] # pinyin or char style
return list_idx_tensors
text = pad_sequence(list_idx_tensors, padding_value=padding_value, batch_first=True)
return text
class TritonPythonModel:
@@ -107,13 +107,12 @@ class TritonPythonModel:
self.use_perf = True
self.device = torch.device("cuda")
self.target_audio_sample_rate = 24000
self.target_rms = 0.15 # target rms for audio
self.target_rms = 0.1 # minimum rms at inference; normalize reference up to this value if lower
self.n_fft = 1024
self.win_length = 1024
self.hop_length = 256
self.n_mel_channels = 100
self.max_mel_len = 3000
self.head_dim = 64
self.max_mel_len = 4096
parameters = json.loads(args["model_config"])["parameters"]
for key, value in parameters.items():
@@ -181,7 +180,8 @@ class TritonPythonModel:
reference_target_texts_list,
estimated_reference_target_mel_len,
reference_mel_len,
) = [], [], [], [], []
reference_rms_list,
) = [], [], [], [], [], []
mel_features_list = []
if self.use_perf:
torch.cuda.nvtx.range_push("preprocess")
@@ -208,6 +208,7 @@ class TritonPythonModel:
ref_rms = torch.sqrt(torch.mean(torch.square(wav)))
if ref_rms < self.target_rms:
wav = wav * self.target_rms / ref_rms
reference_rms_list.append(ref_rms)
if self.reference_sample_rate != self.target_audio_sample_rate:
wav = self.resampler(wav)
wav = wav.to(self.device)
@@ -228,7 +229,7 @@ class TritonPythonModel:
max_seq_len = min(max(estimated_reference_target_mel_len), self.max_mel_len)
batch = len(requests)
mel_features = torch.zeros((batch, max_seq_len, self.n_mel_channels), dtype=torch.float16).to(self.device)
mel_features = torch.zeros((batch, max_seq_len, self.n_mel_channels), dtype=torch.float32).to(self.device)
for i, mel in enumerate(mel_features_list):
mel_features[i, : mel.shape[1], :] = mel
@@ -237,15 +238,6 @@ class TritonPythonModel:
pinyin_list = convert_char_to_pinyin(reference_target_texts_list, polyphone=True)
text_pad_sequence = list_str_to_idx(pinyin_list, self.vocab_char_map)
for i, item in enumerate(text_pad_sequence):
text_pad_sequence[i] = F.pad(
item, (0, estimated_reference_target_mel_len[i] - len(item)), mode="constant", value=-1
)
text_pad_sequence[i] += 1 # WAR: 0 is reserved for padding token, hard coding in F5-TTS
text_pad_sequence = pad_sequence(text_pad_sequence, padding_value=-1, batch_first=True).to(self.device)
text_pad_sequence = F.pad(
text_pad_sequence, (0, max_seq_len - text_pad_sequence.shape[1]), mode="constant", value=-1
)
if self.use_perf:
torch.cuda.nvtx.range_pop()
@@ -262,13 +254,12 @@ class TritonPythonModel:
responses = []
for i in range(batch):
ref_me_len = reference_mel_len[i]
ref_mel_len = reference_mel_len[i]
estimated_mel_len = estimated_reference_target_mel_len[i]
denoised_one_item = denoised[i, ref_me_len:estimated_mel_len, :].unsqueeze(0).transpose(1, 2)
denoised_one_item = denoised[i, ref_mel_len:estimated_mel_len, :].unsqueeze(0).transpose(1, 2)
audio = self.forward_vocoder(denoised_one_item)
rms = torch.sqrt(torch.mean(torch.square(audio)))
if rms < self.target_rms:
audio = audio * self.target_rms / rms
if reference_rms_list[i] < self.target_rms:
audio = audio * reference_rms_list[i] / self.target_rms
audio = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(audio))
inference_response = pb_utils.InferenceResponse(output_tensors=[audio])
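The loudness handling now records each reference clip's RMS and scales the vocoder output back with it, instead of re-normalizing generated audio up to target_rms. A small sketch of the round trip (tensors are illustrative):

import torch

target_rms = 0.1
ref_wav = torch.randn(24000) * 0.02                # a quiet reference clip
ref_rms = ref_wav.square().mean().sqrt()
if ref_rms < target_rms:
    ref_wav = ref_wav * target_rms / ref_rms       # boost the conditioning audio for inference

generated = torch.randn(24000) * 0.1               # stand-in for the vocoder output
if ref_rms < target_rms:
    generated = generated * ref_rms / target_rms   # restore the reference clip's original loudness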

View File

@@ -4,11 +4,20 @@ import os
import sys
from collections import OrderedDict
import numpy as np
import tensorrt as trt
from tensorrt_llm._common import default_net
from ..._utils import str_dtype_to_trt
from ...functional import Tensor, concat
from ...functional import (
Tensor,
concat,
constant,
expand,
shape,
slice,
unsqueeze,
)
from ...layers import Linear
from ...module import Module, ModuleList
from ...plugin import current_all_reduce_helper
@@ -27,9 +36,9 @@ class InputEmbedding(Module):
self.proj = Linear(mel_dim * 2 + text_dim, out_dim)
self.conv_pos_embed = ConvPositionEmbedding(dim=out_dim)
def forward(self, x, cond):
def forward(self, x, cond, mask=None):
x = self.proj(concat([x, cond], dim=-1))
return self.conv_pos_embed(x) + x
return self.conv_pos_embed(x, mask=mask) + x
class F5TTS(PretrainedModel):
@@ -50,6 +59,7 @@ class F5TTS(PretrainedModel):
dim_head=config.dim_head,
ff_mult=config.ff_mult,
dropout=config.dropout,
pe_attn_head=config.pe_attn_head,
)
for _ in range(self.depth)
]
@@ -68,10 +78,26 @@ class F5TTS(PretrainedModel):
input_lengths,
scale=1.0,
):
if default_net().plugin_config.remove_input_padding:
mask = None
else:
N = shape(noise, 1)
B = shape(noise, 0)
seq_len_2d = concat([1, N])
max_position_embeddings = 4096
# create position ids
position_ids_buffer = constant(np.expand_dims(np.arange(max_position_embeddings).astype(np.int32), 0))
tmp_position_ids = slice(position_ids_buffer, starts=[0, 0], sizes=seq_len_2d)
tmp_position_ids = expand(tmp_position_ids, concat([B, N])) # [B, N]
tmp_input_lengths = unsqueeze(input_lengths, 1) # [B, 1]
tmp_input_lengths = expand(tmp_input_lengths, concat([B, N])) # [B, N]
mask = tmp_position_ids < tmp_input_lengths # [B, N]
mask = mask.cast("int32")
t = self.time_embed(time)
x = self.input_embed(noise, cond)
x = self.input_embed(noise, cond, mask=mask)
for block in self.transformer_blocks:
x = block(x, t, rope_cos=rope_cos, rope_sin=rope_sin, input_lengths=input_lengths, scale=scale)
x = block(x, t, rope_cos=rope_cos, rope_sin=rope_sin, input_lengths=input_lengths, scale=scale, mask=mask)
denoise = self.proj_out(self.norm_out(x, t))
denoise.mark_output("denoised", self.dtype)
return denoise
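The mask assembled above with TensorRT-LLM functional ops is an ordinary length-based padding mask; an equivalent plain-PyTorch sketch (names and sizes are illustrative):

import torch

input_lengths = torch.tensor([5, 3])                        # valid frames per batch item
B, N = input_lengths.shape[0], 6                            # padded sequence length
position_ids = torch.arange(N).unsqueeze(0).expand(B, N)    # [B, N]
mask = (position_ids < input_lengths.unsqueeze(1)).int()    # 1 for real frames, 0 for padding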
@@ -79,13 +105,12 @@ class F5TTS(PretrainedModel):
def prepare_inputs(self, **kwargs):
max_batch_size = kwargs["max_batch_size"]
batch_size_range = [2, 2, max_batch_size]
mel_size = 100
max_seq_len = 3000
num_frames_range = [200, 2 * max_seq_len, max_seq_len * max_batch_size]
hidden_size = 512
concat_feature_dim = mel_size + hidden_size
freq_embed_dim = 256
head_dim = 64
mel_size = self.config.mel_dim
max_seq_len = 3000 # 4096
num_frames_range = [mel_size * 2, max_seq_len * 2, max_seq_len * max_batch_size]
concat_feature_dim = mel_size + self.config.text_dim
freq_embed_dim = 256 # Warning: hard coding 256 here
head_dim = self.config.dim_head
mapping = self.config.mapping
if mapping.tp_size > 1:
current_all_reduce_helper().set_workspace_tensor(mapping, 1)

View File

@@ -16,7 +16,6 @@ from ...functional import (
chunk,
concat,
constant,
expand,
expand_dims,
expand_dims_like,
expand_mask,
@@ -95,15 +94,24 @@ class ConvPositionEmbedding(Module):
self.conv1d2 = Conv1d(dim, dim, kernel_size, groups=groups, padding=kernel_size // 2)
self.mish = Mish()
def forward(self, x, mask=None): # noqa: F722
def forward(self, x, mask=None):
if default_net().plugin_config.remove_input_padding:
x = unsqueeze(x, 0)
x = permute(x, [0, 2, 1])
x = self.mish(self.conv1d2(self.mish(self.conv1d1(x))))
out = permute(x, [0, 2, 1])
if mask is not None:
mask = mask.view(concat([shape(mask, 0), 1, shape(mask, 1)])) # [B 1 N]
mask = expand_dims_like(mask, x) # [B D N]
mask = cast(mask, x.dtype)
x = permute(x, [0, 2, 1]) # [B D N]
if mask is not None:
x = self.mish(self.conv1d2(self.mish(self.conv1d1(x * mask) * mask)) * mask)
else:
x = self.mish(self.conv1d2(self.mish(self.conv1d1(x))))
x = permute(x, [0, 2, 1]) # [B N D]
if default_net().plugin_config.remove_input_padding:
out = squeeze(out, 0)
return out
x = squeeze(x, 0)
return x
class Attention(Module):
@@ -185,6 +193,7 @@ class Attention(Module):
rope_cos,
rope_sin,
input_lengths,
mask=None,
c=None, # context c
scale=1.0,
rope=None,
@@ -227,29 +236,52 @@ def rotate_every_two_3dim(tensor: Tensor) -> Tensor:
return out
def apply_rotary_pos_emb_3dim(x, rope_cos, rope_sin):
if default_net().plugin_config.remove_input_padding:
rot_dim = shape(rope_cos, -1) # 64
new_t_shape = concat([shape(x, 0), rot_dim]) # (-1, 64)
x_ = slice(x, [0, 0], new_t_shape, [1, 1])
end_dim = shape(x, -1) - shape(rope_cos, -1)
new_t_unrotated_shape = concat([shape(x, 0), end_dim]) # (2, -1, 960)
x_unrotated = slice(x, concat([0, rot_dim]), new_t_unrotated_shape, [1, 1])
out = concat([x_ * rope_cos + rotate_every_two_3dim(x_) * rope_sin, x_unrotated], dim=-1)
else:
rot_dim = shape(rope_cos, 2) # 64
new_t_shape = concat([shape(x, 0), shape(x, 1), rot_dim]) # (2, -1, 64)
x_ = slice(x, [0, 0, 0], new_t_shape, [1, 1, 1])
end_dim = shape(x, 2) - shape(rope_cos, 2)
new_t_unrotated_shape = concat([shape(x, 0), shape(x, 1), end_dim]) # (2, -1, 960)
x_unrotated = slice(x, concat([0, 0, rot_dim]), new_t_unrotated_shape, [1, 1, 1])
out = concat([x_ * rope_cos + rotate_every_two_3dim(x_) * rope_sin, x_unrotated], dim=-1)
def apply_rotary_pos_emb_3dim(x, rope_cos, rope_sin, pe_attn_head):
full_dim = x.size(-1)
head_dim = rope_cos.size(-1) # attn head dim, e.g. 64
if pe_attn_head is None:
pe_attn_head = full_dim // head_dim
rotated_dim = head_dim * pe_attn_head
rotated_and_unrotated_list = []
if default_net().plugin_config.remove_input_padding: # for [N, D] input
new_t_shape = concat([shape(x, 0), head_dim]) # (2, -1, 64)
for i in range(pe_attn_head):
x_slice_i = slice(x, [0, i * 64], new_t_shape, [1, 1])
x_rotated_i = x_slice_i * rope_cos + rotate_every_two_3dim(x_slice_i) * rope_sin
rotated_and_unrotated_list.append(x_rotated_i)
new_t_unrotated_shape = concat([shape(x, 0), full_dim - rotated_dim]) # (2, -1, 1024 - 64 * pe_attn_head)
x_unrotated = slice(x, concat([0, rotated_dim]), new_t_unrotated_shape, [1, 1])
rotated_and_unrotated_list.append(x_unrotated)
else: # for [B, N, D] input
new_t_shape = concat([shape(x, 0), shape(x, 1), head_dim]) # (2, -1, 64)
for i in range(pe_attn_head):
x_slice_i = slice(x, [0, 0, i * 64], new_t_shape, [1, 1, 1])
x_rotated_i = x_slice_i * rope_cos + rotate_every_two_3dim(x_slice_i) * rope_sin
rotated_and_unrotated_list.append(x_rotated_i)
new_t_unrotated_shape = concat(
[shape(x, 0), shape(x, 1), full_dim - rotated_dim]
) # (2, -1, 1024 - 64 * pe_attn_head)
x_unrotated = slice(x, concat([0, 0, rotated_dim]), new_t_unrotated_shape, [1, 1, 1])
rotated_and_unrotated_list.append(x_unrotated)
out = concat(rotated_and_unrotated_list, dim=-1)
return out
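pe_attn_head restricts rotary position embedding to the first few attention heads (per the converter defaults below, F5TTS_Base/Small use 1, while the v1 models rotate all heads). A plain-PyTorch sketch of the idea for one projected tensor (shapes are illustrative):

import torch

def rotate_every_two(x):
    x1, x2 = x[..., 0::2], x[..., 1::2]
    return torch.stack((-x2, x1), dim=-1).flatten(-2)

B, N, D, head_dim, pe_attn_head = 2, 10, 256, 64, 1
x = torch.randn(B, N, D)                                    # query or key projection
rope_cos = torch.randn(B, N, head_dim)                      # precomputed cos, repeated per pair
rope_sin = torch.randn(B, N, head_dim)

parts = []
for i in range(pe_attn_head):                               # rotate only the first pe_attn_head heads
    xs = x[..., i * head_dim:(i + 1) * head_dim]
    parts.append(xs * rope_cos + rotate_every_two(xs) * rope_sin)
parts.append(x[..., pe_attn_head * head_dim:])              # remaining heads pass through unrotated
out = torch.cat(parts, dim=-1)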
class AttnProcessor:
def __init__(self):
pass
def __init__(
self,
pe_attn_head: Optional[int] = None, # number of attention head to apply rope, None for all
):
self.pe_attn_head = pe_attn_head
def __call__(
self,
@@ -260,32 +292,21 @@ class AttnProcessor:
input_lengths,
scale=1.0,
rope=None,
mask=None,
) -> torch.FloatTensor:
query = attn.to_q(x)
key = attn.to_k(x)
value = attn.to_v(x)
# k,v,q all (2,1226,1024)
query = apply_rotary_pos_emb_3dim(query, rope_cos, rope_sin)
key = apply_rotary_pos_emb_3dim(key, rope_cos, rope_sin)
query = apply_rotary_pos_emb_3dim(query, rope_cos, rope_sin, self.pe_attn_head)
key = apply_rotary_pos_emb_3dim(key, rope_cos, rope_sin, self.pe_attn_head)
# attention
inner_dim = key.shape[-1]
norm_factor = math.sqrt(attn.attention_head_size)
q_scaling = 1.0 / norm_factor
mask = None
if not default_net().plugin_config.remove_input_padding:
N = shape(x, 1)
B = shape(x, 0)
seq_len_2d = concat([1, N])
max_position_embeddings = 4096
# create position ids
position_ids_buffer = constant(np.expand_dims(np.arange(max_position_embeddings).astype(np.int32), 0))
tmp_position_ids = slice(position_ids_buffer, starts=[0, 0], sizes=seq_len_2d)
tmp_position_ids = expand(tmp_position_ids, concat([B, N])) # BxL
tmp_input_lengths = unsqueeze(input_lengths, 1) # Bx1
tmp_input_lengths = expand(tmp_input_lengths, concat([B, N])) # BxL
mask = tmp_position_ids < tmp_input_lengths # BxL
mask = mask.cast("int32")
if default_net().plugin_config.remove_input_padding:
mask = None
if default_net().plugin_config.bert_attention_plugin:
qkv = concat([query, key, value], dim=-1)
@@ -354,12 +375,12 @@ class AttnProcessor:
# DiT Block
class DiTBlock(Module):
def __init__(self, dim, heads, dim_head, ff_mult=2, dropout=0.1):
def __init__(self, dim, heads, dim_head, ff_mult=2, dropout=0.1, pe_attn_head=None):
super().__init__()
self.attn_norm = AdaLayerNormZero(dim)
self.attn = Attention(
processor=AttnProcessor(),
processor=AttnProcessor(pe_attn_head=pe_attn_head),
dim=dim,
heads=heads,
dim_head=dim_head,
@@ -370,14 +391,15 @@ class DiTBlock(Module):
self.ff = FeedForward(dim=dim, mult=ff_mult, dropout=dropout)
def forward(
self, x, t, rope_cos, rope_sin, input_lengths, scale=1.0, rope=ModuleNotFoundError
self, x, t, rope_cos, rope_sin, input_lengths, scale=1.0, rope=ModuleNotFoundError, mask=None
): # x: noised input, t: time embedding
# pre-norm & modulation for attention input
norm, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.attn_norm(x, emb=t)
# attention
# norm ----> (2,1226,1024)
attn_output = self.attn(x=norm, rope_cos=rope_cos, rope_sin=rope_sin, input_lengths=input_lengths, scale=scale)
attn_output = self.attn(
x=norm, rope_cos=rope_cos, rope_sin=rope_sin, input_lengths=input_lengths, scale=scale, mask=mask
)
# process attention output for input x
if default_net().plugin_config.remove_input_padding:
x = x + gate_msa * attn_output

View File

@@ -1,24 +0,0 @@
accelerate>=0.33.0
bitsandbytes>0.37.0
cached_path
click
datasets
ema_pytorch>=0.5.2
gradio>=3.45.2
hydra-core>=1.3.0
jieba
librosa
matplotlib
numpy<=1.26.4
pydub
pypinyin
safetensors
soundfile
tomli
torch>=2.0.0
# torchaudio>=2.0.0
torchdiffeq
tqdm>=4.65.0
transformers
x_transformers>=1.31.14
packaging>=24.2

View File

@@ -1,64 +1,66 @@
stage=$1
stop_stage=$2
model=$3 # F5TTS_Base
model=$3 # F5TTS_v1_Base | F5TTS_Base | F5TTS_v1_Small | F5TTS_Small
if [ -z "$model" ]; then
echo "Model is none, using default model F5TTS_v1_Base"
model=F5TTS_Base
model=F5TTS_v1_Base
fi
echo "Start stage: $stage, Stop stage: $stop_stage, Model: $model"
export CUDA_VISIBLE_DEVICES=0
F5_TTS_HF_DOWNLOAD_PATH=./F5-TTS
F5_TTS_TRT_LLM_CHECKPOINT_PATH=./trtllm_ckpt
F5_TTS_TRT_LLM_ENGINE_PATH=./f5_trt_llm_engine
CKPT_DIR=../../../../ckpts
TRTLLM_CKPT_DIR=$CKPT_DIR/$model/trtllm_ckpt
TRTLLM_ENGINE_DIR=$CKPT_DIR/$model/trtllm_engine
vocoder_trt_engine_path=vocos_vocoder.plan
model_repo=./model_repo
VOCODER_ONNX_PATH=$CKPT_DIR/vocos_vocoder.onnx
VOCODER_TRT_ENGINE_PATH=$CKPT_DIR/vocos_vocoder.plan
MODEL_REPO=./model_repo
if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
echo "Downloading f5 tts from huggingface"
huggingface-cli download SWivid/F5-TTS --local-dir $F5_TTS_HF_DOWNLOAD_PATH
echo "Downloading F5-TTS from huggingface"
huggingface-cli download SWivid/F5-TTS $model/model_*.* $model/vocab.txt --local-dir $CKPT_DIR
fi
ckpt_file=$(ls $CKPT_DIR/$model/model_*.* 2>/dev/null | sort -V | tail -1) # default select latest update
vocab_file=$CKPT_DIR/$model/vocab.txt
if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
echo "Converting checkpoint"
python3 ./scripts/convert_checkpoint.py \
--timm_ckpt "$F5_TTS_HF_DOWNLOAD_PATH/$model/model_1200000.pt" \
--output_dir "$F5_TTS_TRT_LLM_CHECKPOINT_PATH" --model_name $model
python3 scripts/convert_checkpoint.py \
--pytorch_ckpt $ckpt_file \
--output_dir $TRTLLM_CKPT_DIR --model_name $model
python_package_path=/usr/local/lib/python3.12/dist-packages
cp -r patch/* $python_package_path/tensorrt_llm/models
trtllm-build --checkpoint_dir $F5_TTS_TRT_LLM_CHECKPOINT_PATH \
trtllm-build --checkpoint_dir $TRTLLM_CKPT_DIR \
--max_batch_size 8 \
--output_dir $F5_TTS_TRT_LLM_ENGINE_PATH --remove_input_padding disable
--output_dir $TRTLLM_ENGINE_DIR --remove_input_padding disable
fi
if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
echo "Exporting vocos vocoder"
onnx_vocoder_path=vocos_vocoder.onnx
python3 scripts/export_vocoder_to_onnx.py --vocoder vocos --output-path $onnx_vocoder_path
bash scripts/export_vocos_trt.sh $onnx_vocoder_path $vocoder_trt_engine_path
python3 scripts/export_vocoder_to_onnx.py --vocoder vocos --output-path $VOCODER_ONNX_PATH
bash scripts/export_vocos_trt.sh $VOCODER_ONNX_PATH $VOCODER_TRT_ENGINE_PATH
fi
if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
echo "Building triton server"
rm -r $model_repo
cp -r ./model_repo_f5_tts $model_repo
python3 scripts/fill_template.py -i $model_repo/f5_tts/config.pbtxt vocab:$F5_TTS_HF_DOWNLOAD_PATH/$model/vocab.txt,model:$F5_TTS_HF_DOWNLOAD_PATH/$model/model_1200000.pt,trtllm:$F5_TTS_TRT_LLM_ENGINE_PATH,vocoder:vocos
cp $vocoder_trt_engine_path $model_repo/vocoder/1/vocoder.plan
rm -r $MODEL_REPO
cp -r ./model_repo_f5_tts $MODEL_REPO
python3 scripts/fill_template.py -i $MODEL_REPO/f5_tts/config.pbtxt vocab:$vocab_file,model:$ckpt_file,trtllm:$TRTLLM_ENGINE_DIR,vocoder:vocos
cp $VOCODER_TRT_ENGINE_PATH $MODEL_REPO/vocoder/1/vocoder.plan
fi
if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
echo "Starting triton server"
tritonserver --model-repository=$model_repo
tritonserver --model-repository=$MODEL_REPO
fi
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
echo "Testing triton server"
num_task=1
log_dir=./log_concurrent_tasks_${num_task}
split_name=wenetspeech4tts
log_dir=./tests/client_grpc_${model}_concurrent_${num_task}_${split_name}
rm -r $log_dir
python3 client_grpc.py --num-tasks $num_task --huggingface-dataset yuekai/seed_tts --split-name wenetspeech4tts --log-dir $log_dir
python3 client_grpc.py --num-tasks $num_task --huggingface-dataset yuekai/seed_tts --split-name $split_name --log-dir $log_dir
fi
if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
@@ -66,45 +68,45 @@ if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
audio=../../infer/examples/basic/basic_ref_en.wav
reference_text="Some call me nature, others call me mother nature."
target_text="I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring."
python3 client_http.py --reference-audio $audio --reference-text "$reference_text" --target-text "$target_text"
python3 client_http.py --reference-audio $audio --reference-text "$reference_text" --target-text "$target_text" --output-audio "./tests/client_http_$model.wav"
fi
if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
echo "TRT-LLM: offline decoding benchmark test"
batch_size=1
batch_size=2
split_name=wenetspeech4tts
backend_type=trt
log_dir=./log_benchmark_batch_size_${batch_size}_${split_name}_${backend_type}
log_dir=./tests/benchmark_${model}_batch_size_${batch_size}_${split_name}_${backend_type}
rm -r $log_dir
ln -s model_repo_f5_tts/f5_tts/1/f5_tts_trtllm.py ./
torchrun --nproc_per_node=1 \
benchmark.py --output-dir $log_dir \
--batch-size $batch_size \
--enable-warmup \
--split-name $split_name \
--model-path $F5_TTS_HF_DOWNLOAD_PATH/$model/model_1200000.pt \
--vocab-file $F5_TTS_HF_DOWNLOAD_PATH/$model/vocab.txt \
--vocoder-trt-engine-path $vocoder_trt_engine_path \
--model-path $ckpt_file \
--vocab-file $vocab_file \
--vocoder-trt-engine-path $VOCODER_TRT_ENGINE_PATH \
--backend-type $backend_type \
--tllm-model-dir $F5_TTS_TRT_LLM_ENGINE_PATH || exit 1
--tllm-model-dir $TRTLLM_ENGINE_DIR || exit 1
fi
if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
echo "Native Pytorch: offline decoding benchmark test"
pip install -r requirements-pytorch.txt
batch_size=1
if ! python3 -c "import f5_tts" &> /dev/null; then
pip install -e ../../../../
fi
batch_size=1 # set attn_mask_enabled=True if batching in actual use case
split_name=wenetspeech4tts
backend_type=pytorch
log_dir=./log_benchmark_batch_size_${batch_size}_${split_name}_${backend_type}
log_dir=./tests/benchmark_${model}_batch_size_${batch_size}_${split_name}_${backend_type}
rm -r $log_dir
ln -s model_repo_f5_tts/f5_tts/1/f5_tts_trtllm.py ./
torchrun --nproc_per_node=1 \
benchmark.py --output-dir $log_dir \
--batch-size $batch_size \
--split-name $split_name \
--enable-warmup \
--model-path $F5_TTS_HF_DOWNLOAD_PATH/$model/model_1200000.pt \
--vocab-file $F5_TTS_HF_DOWNLOAD_PATH/$model/vocab.txt \
--model-path $ckpt_file \
--vocab-file $vocab_file \
--backend-type $backend_type \
--tllm-model-dir $F5_TTS_TRT_LLM_ENGINE_PATH || exit 1
--tllm-model-dir $TRTLLM_ENGINE_DIR || exit 1
fi

View File

@@ -23,168 +23,12 @@ def split_q_bias_tp(v, n_head, n_hidden, tensor_parallel, rank):
return split_v.contiguous()
FACEBOOK_DIT_NAME_MAPPING = {
"^time_embed.time_mlp.0.weight$": "time_embed.mlp1.weight",
"^time_embed.time_mlp.0.bias$": "time_embed.mlp1.bias",
"^time_embed.time_mlp.2.weight$": "time_embed.mlp2.weight",
"^time_embed.time_mlp.2.bias$": "time_embed.mlp2.bias",
"^input_embed.conv_pos_embed.conv1d.0.weight$": "input_embed.conv_pos_embed.conv1d1.weight",
"^input_embed.conv_pos_embed.conv1d.0.bias$": "input_embed.conv_pos_embed.conv1d1.bias",
"^input_embed.conv_pos_embed.conv1d.2.weight$": "input_embed.conv_pos_embed.conv1d2.weight",
"^input_embed.conv_pos_embed.conv1d.2.bias$": "input_embed.conv_pos_embed.conv1d2.bias",
"^transformer_blocks.0.attn.to_out.0.weight$": "transformer_blocks.0.attn.to_out.weight",
"^transformer_blocks.0.attn.to_out.0.bias$": "transformer_blocks.0.attn.to_out.bias",
"^transformer_blocks.1.attn.to_out.0.weight$": "transformer_blocks.1.attn.to_out.weight",
"^transformer_blocks.1.attn.to_out.0.bias$": "transformer_blocks.1.attn.to_out.bias",
"^transformer_blocks.2.attn.to_out.0.weight$": "transformer_blocks.2.attn.to_out.weight",
"^transformer_blocks.2.attn.to_out.0.bias$": "transformer_blocks.2.attn.to_out.bias",
"^transformer_blocks.3.attn.to_out.0.weight$": "transformer_blocks.3.attn.to_out.weight",
"^transformer_blocks.3.attn.to_out.0.bias$": "transformer_blocks.3.attn.to_out.bias",
"^transformer_blocks.4.attn.to_out.0.weight$": "transformer_blocks.4.attn.to_out.weight",
"^transformer_blocks.4.attn.to_out.0.bias$": "transformer_blocks.4.attn.to_out.bias",
"^transformer_blocks.5.attn.to_out.0.weight$": "transformer_blocks.5.attn.to_out.weight",
"^transformer_blocks.5.attn.to_out.0.bias$": "transformer_blocks.5.attn.to_out.bias",
"^transformer_blocks.6.attn.to_out.0.weight$": "transformer_blocks.6.attn.to_out.weight",
"^transformer_blocks.6.attn.to_out.0.bias$": "transformer_blocks.6.attn.to_out.bias",
"^transformer_blocks.7.attn.to_out.0.weight$": "transformer_blocks.7.attn.to_out.weight",
"^transformer_blocks.7.attn.to_out.0.bias$": "transformer_blocks.7.attn.to_out.bias",
"^transformer_blocks.8.attn.to_out.0.weight$": "transformer_blocks.8.attn.to_out.weight",
"^transformer_blocks.8.attn.to_out.0.bias$": "transformer_blocks.8.attn.to_out.bias",
"^transformer_blocks.9.attn.to_out.0.weight$": "transformer_blocks.9.attn.to_out.weight",
"^transformer_blocks.9.attn.to_out.0.bias$": "transformer_blocks.9.attn.to_out.bias",
"^transformer_blocks.10.attn.to_out.0.weight$": "transformer_blocks.10.attn.to_out.weight",
"^transformer_blocks.10.attn.to_out.0.bias$": "transformer_blocks.10.attn.to_out.bias",
"^transformer_blocks.11.attn.to_out.0.weight$": "transformer_blocks.11.attn.to_out.weight",
"^transformer_blocks.11.attn.to_out.0.bias$": "transformer_blocks.11.attn.to_out.bias",
"^transformer_blocks.12.attn.to_out.0.weight$": "transformer_blocks.12.attn.to_out.weight",
"^transformer_blocks.12.attn.to_out.0.bias$": "transformer_blocks.12.attn.to_out.bias",
"^transformer_blocks.13.attn.to_out.0.weight$": "transformer_blocks.13.attn.to_out.weight",
"^transformer_blocks.13.attn.to_out.0.bias$": "transformer_blocks.13.attn.to_out.bias",
"^transformer_blocks.14.attn.to_out.0.weight$": "transformer_blocks.14.attn.to_out.weight",
"^transformer_blocks.14.attn.to_out.0.bias$": "transformer_blocks.14.attn.to_out.bias",
"^transformer_blocks.15.attn.to_out.0.weight$": "transformer_blocks.15.attn.to_out.weight",
"^transformer_blocks.15.attn.to_out.0.bias$": "transformer_blocks.15.attn.to_out.bias",
"^transformer_blocks.16.attn.to_out.0.weight$": "transformer_blocks.16.attn.to_out.weight",
"^transformer_blocks.16.attn.to_out.0.bias$": "transformer_blocks.16.attn.to_out.bias",
"^transformer_blocks.17.attn.to_out.0.weight$": "transformer_blocks.17.attn.to_out.weight",
"^transformer_blocks.17.attn.to_out.0.bias$": "transformer_blocks.17.attn.to_out.bias",
"^transformer_blocks.18.attn.to_out.0.weight$": "transformer_blocks.18.attn.to_out.weight",
"^transformer_blocks.18.attn.to_out.0.bias$": "transformer_blocks.18.attn.to_out.bias",
"^transformer_blocks.19.attn.to_out.0.weight$": "transformer_blocks.19.attn.to_out.weight",
"^transformer_blocks.19.attn.to_out.0.bias$": "transformer_blocks.19.attn.to_out.bias",
"^transformer_blocks.20.attn.to_out.0.weight$": "transformer_blocks.20.attn.to_out.weight",
"^transformer_blocks.20.attn.to_out.0.bias$": "transformer_blocks.20.attn.to_out.bias",
"^transformer_blocks.21.attn.to_out.0.weight$": "transformer_blocks.21.attn.to_out.weight",
"^transformer_blocks.21.attn.to_out.0.bias$": "transformer_blocks.21.attn.to_out.bias",
"^transformer_blocks.0.ff.ff.0.0.weight$": "transformer_blocks.0.ff.project_in.weight",
"^transformer_blocks.0.ff.ff.0.0.bias$": "transformer_blocks.0.ff.project_in.bias",
"^transformer_blocks.0.ff.ff.2.weight$": "transformer_blocks.0.ff.ff.weight",
"^transformer_blocks.0.ff.ff.2.bias$": "transformer_blocks.0.ff.ff.bias",
"^transformer_blocks.1.ff.ff.0.0.weight$": "transformer_blocks.1.ff.project_in.weight",
"^transformer_blocks.1.ff.ff.0.0.bias$": "transformer_blocks.1.ff.project_in.bias",
"^transformer_blocks.1.ff.ff.2.weight$": "transformer_blocks.1.ff.ff.weight",
"^transformer_blocks.1.ff.ff.2.bias$": "transformer_blocks.1.ff.ff.bias",
"^transformer_blocks.2.ff.ff.0.0.weight$": "transformer_blocks.2.ff.project_in.weight",
"^transformer_blocks.2.ff.ff.0.0.bias$": "transformer_blocks.2.ff.project_in.bias",
"^transformer_blocks.2.ff.ff.2.weight$": "transformer_blocks.2.ff.ff.weight",
"^transformer_blocks.2.ff.ff.2.bias$": "transformer_blocks.2.ff.ff.bias",
"^transformer_blocks.3.ff.ff.0.0.weight$": "transformer_blocks.3.ff.project_in.weight",
"^transformer_blocks.3.ff.ff.0.0.bias$": "transformer_blocks.3.ff.project_in.bias",
"^transformer_blocks.3.ff.ff.2.weight$": "transformer_blocks.3.ff.ff.weight",
"^transformer_blocks.3.ff.ff.2.bias$": "transformer_blocks.3.ff.ff.bias",
"^transformer_blocks.4.ff.ff.0.0.weight$": "transformer_blocks.4.ff.project_in.weight",
"^transformer_blocks.4.ff.ff.0.0.bias$": "transformer_blocks.4.ff.project_in.bias",
"^transformer_blocks.4.ff.ff.2.weight$": "transformer_blocks.4.ff.ff.weight",
"^transformer_blocks.4.ff.ff.2.bias$": "transformer_blocks.4.ff.ff.bias",
"^transformer_blocks.5.ff.ff.0.0.weight$": "transformer_blocks.5.ff.project_in.weight",
"^transformer_blocks.5.ff.ff.0.0.bias$": "transformer_blocks.5.ff.project_in.bias",
"^transformer_blocks.5.ff.ff.2.weight$": "transformer_blocks.5.ff.ff.weight",
"^transformer_blocks.5.ff.ff.2.bias$": "transformer_blocks.5.ff.ff.bias",
"^transformer_blocks.6.ff.ff.0.0.weight$": "transformer_blocks.6.ff.project_in.weight",
"^transformer_blocks.6.ff.ff.0.0.bias$": "transformer_blocks.6.ff.project_in.bias",
"^transformer_blocks.6.ff.ff.2.weight$": "transformer_blocks.6.ff.ff.weight",
"^transformer_blocks.6.ff.ff.2.bias$": "transformer_blocks.6.ff.ff.bias",
"^transformer_blocks.7.ff.ff.0.0.weight$": "transformer_blocks.7.ff.project_in.weight",
"^transformer_blocks.7.ff.ff.0.0.bias$": "transformer_blocks.7.ff.project_in.bias",
"^transformer_blocks.7.ff.ff.2.weight$": "transformer_blocks.7.ff.ff.weight",
"^transformer_blocks.7.ff.ff.2.bias$": "transformer_blocks.7.ff.ff.bias",
"^transformer_blocks.8.ff.ff.0.0.weight$": "transformer_blocks.8.ff.project_in.weight",
"^transformer_blocks.8.ff.ff.0.0.bias$": "transformer_blocks.8.ff.project_in.bias",
"^transformer_blocks.8.ff.ff.2.weight$": "transformer_blocks.8.ff.ff.weight",
"^transformer_blocks.8.ff.ff.2.bias$": "transformer_blocks.8.ff.ff.bias",
"^transformer_blocks.9.ff.ff.0.0.weight$": "transformer_blocks.9.ff.project_in.weight",
"^transformer_blocks.9.ff.ff.0.0.bias$": "transformer_blocks.9.ff.project_in.bias",
"^transformer_blocks.9.ff.ff.2.weight$": "transformer_blocks.9.ff.ff.weight",
"^transformer_blocks.9.ff.ff.2.bias$": "transformer_blocks.9.ff.ff.bias",
"^transformer_blocks.10.ff.ff.0.0.weight$": "transformer_blocks.10.ff.project_in.weight",
"^transformer_blocks.10.ff.ff.0.0.bias$": "transformer_blocks.10.ff.project_in.bias",
"^transformer_blocks.10.ff.ff.2.weight$": "transformer_blocks.10.ff.ff.weight",
"^transformer_blocks.10.ff.ff.2.bias$": "transformer_blocks.10.ff.ff.bias",
"^transformer_blocks.11.ff.ff.0.0.weight$": "transformer_blocks.11.ff.project_in.weight",
"^transformer_blocks.11.ff.ff.0.0.bias$": "transformer_blocks.11.ff.project_in.bias",
"^transformer_blocks.11.ff.ff.2.weight$": "transformer_blocks.11.ff.ff.weight",
"^transformer_blocks.11.ff.ff.2.bias$": "transformer_blocks.11.ff.ff.bias",
"^transformer_blocks.12.ff.ff.0.0.weight$": "transformer_blocks.12.ff.project_in.weight",
"^transformer_blocks.12.ff.ff.0.0.bias$": "transformer_blocks.12.ff.project_in.bias",
"^transformer_blocks.12.ff.ff.2.weight$": "transformer_blocks.12.ff.ff.weight",
"^transformer_blocks.12.ff.ff.2.bias$": "transformer_blocks.12.ff.ff.bias",
"^transformer_blocks.13.ff.ff.0.0.weight$": "transformer_blocks.13.ff.project_in.weight",
"^transformer_blocks.13.ff.ff.0.0.bias$": "transformer_blocks.13.ff.project_in.bias",
"^transformer_blocks.13.ff.ff.2.weight$": "transformer_blocks.13.ff.ff.weight",
"^transformer_blocks.13.ff.ff.2.bias$": "transformer_blocks.13.ff.ff.bias",
"^transformer_blocks.14.ff.ff.0.0.weight$": "transformer_blocks.14.ff.project_in.weight",
"^transformer_blocks.14.ff.ff.0.0.bias$": "transformer_blocks.14.ff.project_in.bias",
"^transformer_blocks.14.ff.ff.2.weight$": "transformer_blocks.14.ff.ff.weight",
"^transformer_blocks.14.ff.ff.2.bias$": "transformer_blocks.14.ff.ff.bias",
"^transformer_blocks.15.ff.ff.0.0.weight$": "transformer_blocks.15.ff.project_in.weight",
"^transformer_blocks.15.ff.ff.0.0.bias$": "transformer_blocks.15.ff.project_in.bias",
"^transformer_blocks.15.ff.ff.2.weight$": "transformer_blocks.15.ff.ff.weight",
"^transformer_blocks.15.ff.ff.2.bias$": "transformer_blocks.15.ff.ff.bias",
"^transformer_blocks.16.ff.ff.0.0.weight$": "transformer_blocks.16.ff.project_in.weight",
"^transformer_blocks.16.ff.ff.0.0.bias$": "transformer_blocks.16.ff.project_in.bias",
"^transformer_blocks.16.ff.ff.2.weight$": "transformer_blocks.16.ff.ff.weight",
"^transformer_blocks.16.ff.ff.2.bias$": "transformer_blocks.16.ff.ff.bias",
"^transformer_blocks.17.ff.ff.0.0.weight$": "transformer_blocks.17.ff.project_in.weight",
"^transformer_blocks.17.ff.ff.0.0.bias$": "transformer_blocks.17.ff.project_in.bias",
"^transformer_blocks.17.ff.ff.2.weight$": "transformer_blocks.17.ff.ff.weight",
"^transformer_blocks.17.ff.ff.2.bias$": "transformer_blocks.17.ff.ff.bias",
"^transformer_blocks.18.ff.ff.0.0.weight$": "transformer_blocks.18.ff.project_in.weight",
"^transformer_blocks.18.ff.ff.0.0.bias$": "transformer_blocks.18.ff.project_in.bias",
"^transformer_blocks.18.ff.ff.2.weight$": "transformer_blocks.18.ff.ff.weight",
"^transformer_blocks.18.ff.ff.2.bias$": "transformer_blocks.18.ff.ff.bias",
"^transformer_blocks.19.ff.ff.0.0.weight$": "transformer_blocks.19.ff.project_in.weight",
"^transformer_blocks.19.ff.ff.0.0.bias$": "transformer_blocks.19.ff.project_in.bias",
"^transformer_blocks.19.ff.ff.2.weight$": "transformer_blocks.19.ff.ff.weight",
"^transformer_blocks.19.ff.ff.2.bias$": "transformer_blocks.19.ff.ff.bias",
"^transformer_blocks.20.ff.ff.0.0.weight$": "transformer_blocks.20.ff.project_in.weight",
"^transformer_blocks.20.ff.ff.0.0.bias$": "transformer_blocks.20.ff.project_in.bias",
"^transformer_blocks.20.ff.ff.2.weight$": "transformer_blocks.20.ff.ff.weight",
"^transformer_blocks.20.ff.ff.2.bias$": "transformer_blocks.20.ff.ff.bias",
"^transformer_blocks.21.ff.ff.0.0.weight$": "transformer_blocks.21.ff.project_in.weight",
"^transformer_blocks.21.ff.ff.0.0.bias$": "transformer_blocks.21.ff.project_in.bias",
"^transformer_blocks.21.ff.ff.2.weight$": "transformer_blocks.21.ff.ff.weight",
"^transformer_blocks.21.ff.ff.2.bias$": "transformer_blocks.21.ff.ff.bias",
}
def parse_arguments():
parser = argparse.ArgumentParser()
parser.add_argument(
"--model_name",
type=str,
default="F5TTS_Base",
choices=[
"F5TTS_Base",
],
) # TODO: support F5TTS_v1_Base
parser.add_argument("--timm_ckpt", type=str, default="./ckpts/model_1200000.pt")
parser.add_argument("--pytorch_ckpt", type=str, default="./ckpts/model_last.pt")
parser.add_argument(
"--output_dir", type=str, default="./tllm_checkpoint", help="The path to save the TensorRT-LLM checkpoint"
)
parser.add_argument("--hidden_size", type=int, default=1024, help="The hidden size of DiT")
parser.add_argument("--depth", type=int, default=22, help="The number of DiTBlock layers")
parser.add_argument("--num_heads", type=int, default=16, help="The number of heads of attention module")
parser.add_argument("--cfg_scale", type=float, default=4.0)
parser.add_argument("--tp_size", type=int, default=1, help="N-way tensor parallelism size")
parser.add_argument("--cp_size", type=int, default=1, help="Context parallelism size")
parser.add_argument("--pp_size", type=int, default=1, help="N-way pipeline parallelism size")
@@ -193,33 +37,119 @@ def parse_arguments():
parser.add_argument(
"--workers", type=int, default=1, help="The number of workers for converting checkpoint in parallel"
)
parser.add_argument(
"--model_name",
type=str,
default="F5TTS_Custom",
choices=[
"F5TTS_v1_Base",
"F5TTS_Base",
"F5TTS_v1_Small",
"F5TTS_Small",
], # if set, overwrite the below hyperparams
)
parser.add_argument("--hidden_size", type=int, default=1024, help="The hidden size of DiT")
parser.add_argument("--depth", type=int, default=22, help="The number of DiTBlock layers")
parser.add_argument("--num_heads", type=int, default=16, help="The number of heads of attention module")
parser.add_argument("--dim_head", type=int, default=64, help="The dimension of attention head")
parser.add_argument("--ff_mult", type=int, default=2, help="The FFN intermediate dimension multiplier")
parser.add_argument("--text_dim", type=int, default=512, help="The output dimension of text encoder")
parser.add_argument(
"--text_mask_padding",
type=lambda x: x.lower() == "true",
choices=[True, False],
default=True,
help="Whether to apply padding mask for conv layers in text encoder",
)
parser.add_argument("--conv_layers", type=int, default=4, help="The number of conv layers of text encoder")
parser.add_argument("--pe_attn_head", type=int, default=None, help="The number of attention heads to apply rotary position embedding to (None for all)")
args = parser.parse_args()
# overwrite if --model_name ordered
if args.model_name == "F5TTS_v1_Base":
args.hidden_size = 1024
args.depth = 22
args.num_heads = 16
args.dim_head = 64
args.ff_mult = 2
args.text_dim = 512
args.text_mask_padding = True
args.conv_layers = 4
args.pe_attn_head = None
elif args.model_name == "F5TTS_Base":
args.hidden_size = 1024
args.depth = 22
args.num_heads = 16
args.dim_head = 64
args.ff_mult = 2
args.text_dim = 512
args.text_mask_padding = False
args.conv_layers = 4
args.pe_attn_head = 1
elif args.model_name == "F5TTS_v1_Small":
args.hidden_size = 768
args.depth = 18
args.num_heads = 12
args.dim_head = 64
args.ff_mult = 2
args.text_dim = 512
args.text_mask_padding = True
args.conv_layers = 4
args.pe_attn_head = None
elif args.model_name == "F5TTS_Small":
args.hidden_size = 768
args.depth = 18
args.num_heads = 12
args.dim_head = 64
args.ff_mult = 2
args.text_dim = 512
args.text_mask_padding = False
args.conv_layers = 4
args.pe_attn_head = 1
return args
def convert_timm_dit(args, mapping, dtype="float32"):
def convert_pytorch_dit_to_trtllm_weight(args, mapping, dtype="float32", use_ema=True):
weights = {}
tik = time.time()
torch_dtype = str_dtype_to_torch(dtype)
tensor_parallel = mapping.tp_size
model_params = dict(torch.load(args.timm_ckpt))
model_params = {
k: v for k, v in model_params["ema_model_state_dict"].items() if k.startswith("ema_model.transformer")
ckpt_path = args.pytorch_ckpt
ckpt_type = ckpt_path.split(".")[-1]
if ckpt_type == "safetensors":
from safetensors.torch import load_file
model_params = load_file(ckpt_path)
else:
ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=True)
model_params = ckpt["ema_model_state_dict"] if use_ema else ckpt["model_state_dict"]
prefix = "ema_model.transformer." if use_ema else "transformer."
if any(k.startswith(prefix) for k in model_params.keys()):
model_params = {
key[len(prefix) :] if key.startswith(prefix) else key: value
for key, value in model_params.items()
if key.startswith(prefix)
}
pytorch_to_trtllm_name = {
r"^time_embed\.time_mlp\.0\.(weight|bias)$": r"time_embed.mlp1.\1",
r"^time_embed\.time_mlp\.2\.(weight|bias)$": r"time_embed.mlp2.\1",
r"^input_embed\.conv_pos_embed\.conv1d\.0\.(weight|bias)$": r"input_embed.conv_pos_embed.conv1d1.\1",
r"^input_embed\.conv_pos_embed\.conv1d\.2\.(weight|bias)$": r"input_embed.conv_pos_embed.conv1d2.\1",
r"^transformer_blocks\.(\d+)\.attn\.to_out\.0\.(weight|bias)$": r"transformer_blocks.\1.attn.to_out.\2",
r"^transformer_blocks\.(\d+)\.ff\.ff\.0\.0\.(weight|bias)$": r"transformer_blocks.\1.ff.project_in.\2",
r"^transformer_blocks\.(\d+)\.ff\.ff\.2\.(weight|bias)$": r"transformer_blocks.\1.ff.ff.\2",
}
prefix = "ema_model.transformer."
model_params = {key[len(prefix) :] if key.startswith(prefix) else key: value for key, value in model_params.items()}
timm_to_trtllm_name = FACEBOOK_DIT_NAME_MAPPING
def get_trtllm_name(timm_name):
for k, v in timm_to_trtllm_name.items():
m = re.match(k, timm_name)
if m is not None:
if "*" in v:
v = v.replace("*", m.groups()[0])
return v
return timm_name
def get_trtllm_name(pytorch_name):
for pytorch_name_pattern, trtllm_name_replacement in pytorch_to_trtllm_name.items():
trtllm_name_if_matched = re.sub(pytorch_name_pattern, trtllm_name_replacement, pytorch_name)
if trtllm_name_if_matched != pytorch_name:
return trtllm_name_if_matched
return pytorch_name
weights = dict()
for name, param in model_params.items():
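The regex table above replaces the long per-layer FACEBOOK_DIT_NAME_MAPPING with backreference-based re.sub rules; a quick sanity check of one mapping (the key name is taken from the table):

import re

pattern = r"^transformer_blocks\.(\d+)\.attn\.to_out\.0\.(weight|bias)$"
replacement = r"transformer_blocks.\1.attn.to_out.\2"
print(re.sub(pattern, replacement, "transformer_blocks.21.attn.to_out.0.weight"))
# -> transformer_blocks.21.attn.to_out.weight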
@@ -230,7 +160,7 @@ def convert_timm_dit(args, mapping, dtype="float32"):
assert len(weights) == len(model_params)
# new_prefix = 'f5_transformer.'
# new_prefix = "f5_transformer."
new_prefix = ""
weights = {new_prefix + key: value for key, value in weights.items()}
import math
@@ -272,19 +202,19 @@ def save_config(args):
if not os.path.exists(args.output_dir):
os.makedirs(args.output_dir)
config = {
"architecture": "F5TTS",
"architecture": "F5TTS", # set the same as in ../patch/__init__.py
"dtype": args.dtype,
"hidden_size": 1024,
"num_hidden_layers": 22,
"num_attention_heads": 16,
"dim_head": 64,
"dropout": 0.1,
"ff_mult": 2,
"hidden_size": args.hidden_size,
"num_hidden_layers": args.depth,
"num_attention_heads": args.num_heads,
"dim_head": args.dim_head,
"dropout": 0.0, # inference-only
"ff_mult": args.ff_mult,
"mel_dim": 100,
"text_num_embeds": 256,
"text_dim": 512,
"conv_layers": 4,
"long_skip_connection": False,
"text_dim": args.text_dim,
"text_mask_padding": args.text_mask_padding,
"conv_layers": args.conv_layers,
"pe_attn_head": args.pe_attn_head,
"mapping": {
"world_size": args.cp_size * args.tp_size * args.pp_size,
"cp_size": args.cp_size,
@@ -296,7 +226,7 @@ def save_config(args):
config["quantization"] = {
"quant_algo": "FP8",
# TODO: add support for exclude modules.
# 'exclude_modules': "*final_layer*",
# "exclude_modules": "*final_layer*",
}
with open(os.path.join(args.output_dir, "config.json"), "w") as f:
@@ -315,7 +245,7 @@ def covert_and_save(args, rank):
pp_size=args.pp_size,
)
weights = convert_timm_dit(args, mapping, dtype=args.dtype)
weights = convert_pytorch_dit_to_trtllm_weight(args, mapping, dtype=args.dtype)
safetensors.torch.save_file(weights, os.path.join(args.output_dir, f"rank{rank}.safetensors"))
@@ -344,9 +274,9 @@ def main():
assert args.pp_size == 1, "PP is not supported yet."
tik = time.time()
if args.timm_ckpt is None:
if args.pytorch_ckpt is None:
return
print("start execute")
print("Start execute")
execute(args.workers, [covert_and_save] * world_size, args)
tok = time.time()

View File

@@ -13,6 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# Manual installation of TensorRT, in case not using NVIDIA NGC:
# https://docs.nvidia.com/deeplearning/tensorrt/latest/installing-tensorrt/installing.html#downloading-tensorrt
TRTEXEC="/usr/src/tensorrt/bin/trtexec"
ONNX_PATH=$1
@@ -28,7 +30,7 @@ MAX_BATCH_SIZE=8
MIN_INPUT_LENGTH=1
OPT_INPUT_LENGTH=1000
MAX_INPUT_LENGTH=3000
MAX_INPUT_LENGTH=3000 # 4096
MEL_MIN_SHAPE="${MIN_BATCH_SIZE}x100x${MIN_INPUT_LENGTH}"
MEL_OPT_SHAPE="${OPT_BATCH_SIZE}x100x${OPT_INPUT_LENGTH}"
@@ -40,4 +42,3 @@ ${TRTEXEC} \
--maxShapes="mel:${MEL_MAX_SHAPE}" \
--onnx=${ONNX_PATH} \
--saveEngine=${ENGINE_PATH}

View File

@@ -0,0 +1,32 @@
import math
from torch.utils.data import SequentialSampler
from f5_tts.model.dataset import DynamicBatchSampler, load_dataset
train_dataset = load_dataset("Emilia_ZH_EN", "pinyin")
sampler = SequentialSampler(train_dataset)
gpus = 8
batch_size_per_gpu = 38400
max_samples_per_gpu = 64
max_updates = 1250000
batch_sampler = DynamicBatchSampler(
sampler,
batch_size_per_gpu,
max_samples=max_samples_per_gpu,
random_seed=666,
drop_residual=False,
)
print(
f"One epoch has {len(batch_sampler) / gpus} updates if gpus={gpus}, with "
f"batch_size_per_gpu={batch_size_per_gpu} (frames) & "
f"max_samples_per_gpu={max_samples_per_gpu}."
)
print(
f"If gpus={gpus}, for max_updates={max_updates} "
f"should set epoch={math.ceil(max_updates / len(batch_sampler) * gpus)}."
)

View File

@@ -225,5 +225,5 @@ if __name__ == "__main__":
# bad zh asr cnt 230435 (samples)
# bad eh asr cnt 37217 (samples)
# vocab size may be slightly different due to jieba tokenizer and pypinyin (e.g. way of polyphoneme)
# vocab size may be slightly different due to rjieba tokenizer and pypinyin (e.g. way of polyphoneme)
# please be careful if using pretrained model, make sure the vocab.txt is same

View File

@@ -122,5 +122,5 @@ if __name__ == "__main__":
# - - 1459 (polyphone)
# char vocab size 5264 5219 5042
# vocab size may be slightly different due to jieba tokenizer and pypinyin (e.g. way of polyphoneme)
# vocab size may be slightly different due to rjieba tokenizer and pypinyin (e.g. way of polyphoneme)
# please be careful if using pretrained model, make sure the vocab.txt is same