25 Commits
1.0.8 ... 1.1.0

Author SHA1 Message Date
SWivid
25b3291715 Update README.md 2025-04-03 14:41:52 +08:00
SWivid
16c480a61d v1.1.0 Support GPU Deployment with Triton and TensorRT-LLM #944 2025-04-03 14:37:58 +08:00
SWivid
d9dfbe47cc Update README.md 2025-04-03 14:36:22 +08:00
Yushen CHEN
d1f6c95fe8 Merge pull request #944 from yuekaizhang/triton
Support GPU Deployment Solution with Triton and TensorRT-LLM
2025-04-03 13:42:37 +08:00
root
2428d01a56 remove empty lines 2025-04-03 05:25:29 +00:00
root
9401842930 add http client 2025-04-03 05:14:03 +00:00
root
eca56943ec fix docker compose issue 2025-04-03 04:31:33 +00:00
root
ae51cc3d34 fix bug 2025-04-03 04:25:43 +00:00
root
4681a1c177 remove annotation 2025-04-03 02:35:26 +00:00
root
5b178397e0 remove unused codes 2025-04-03 02:34:28 +00:00
Yuekai Zhang
2724f9f101 add Nvidia Triton TensorRT-LLM solution 2025-04-02 19:04:45 -07:00
SWivid
7258b09529 v1.0.10 support custom chat model 2025-03-31 21:15:26 +08:00
SWivid
784e3862b4 add microsoft/Phi-4-mini-instruct to chat model list #937 2025-03-31 21:14:39 +08:00
SWivid
6f6968b034 formatting 2025-03-31 19:45:38 +08:00
maximechen
9bd2d13be1 Merge branch 'huanglizhuo-feat/support-custom-chat-model' 2025-03-31 19:22:08 +08:00
maximechen
b7c41af9cd reorganize and distinguish behavior from local and space 2025-03-31 19:11:52 +08:00
huanglizhuo
eaa7fd8a01 Reapply pre-commit hooks 2025-03-29 20:58:42 +09:00
Yushen CHEN
f34465d118 v1.0.9 several fixes 2025-03-28 23:12:13 +08:00
lizhuo
393993321d fix: use pydantic<=2.10.6 to address dependency conflict with gradio-app #930 2025-03-28 23:10:41 +08:00
lizhuo
29d3326bed update: JA latest HF path in SHARED.md #928
* fix: update japanese latest hf path
* update the huggingface url
2025-03-28 22:36:17 +08:00
Zhikang Niu
67e43dc0fb Merge pull request #926 from huanglizhuo/fix/shared-file-path
fix the SHARED.md file path
2025-03-28 17:14:54 +08:00
huanglizhuo
8469025b1c fix the shared.md file path 2025-03-28 17:52:08 +09:00
Zhikang Niu
5bd8cd7aed update: better save last & per ckpt logic #924
Co-authored-by: Yushen CHEN <45333109+SWivid@users.noreply.github.com>
2025-03-28 13:53:12 +08:00
SWivid
7236536f9a update utils_infer.py 2025-03-25 17:24:20 +08:00
SWivid
6b7f6eefdc fix typo in trainer.py with 4ae5347282 formatting #909 2025-03-25 16:17:03 +08:00
27 changed files with 3301 additions and 44 deletions

View File

@@ -110,6 +110,9 @@ docker container run --rm -it --gpus=all --mount 'type=volume,source=f5-tts,targ
## Inference
- In order to achieve desired performance, take a moment to read [detailed guidance](src/f5_tts/infer).
- By properly searching the keywords of problem encountered, [issues](https://github.com/SWivid/F5-TTS/issues?q=is%3Aissue) are very helpful.
### 1. Gradio App
Currently supported features:
@@ -176,10 +179,18 @@ f5-tts_infer-cli -c custom.toml
f5-tts_infer-cli -c src/f5_tts/infer/examples/multi/story.toml
```
### 3. More instructions
### 3. Runtime
- In order to have better generation results, take a moment to read [detailed guidance](src/f5_tts/infer).
- The [Issues](https://github.com/SWivid/F5-TTS/issues?q=is%3Aissue) are very useful, please try to find the solution by properly searching the keywords of problem encountered. If no answer found, then feel free to open an issue.
Deployment solution with Triton and TensorRT-LLM.
#### Benchmark Results
Decoding on a single L20 GPU, using 26 different prompt_audio & target_text pairs.
| Model | Concurrency | Avg Latency | RTF |
|-------|-------------|----------------|-------|
| F5-TTS Base (Vocos) | 1 | 253 ms | 0.0394|
See [detailed instructions](src/f5_tts/runtime/triton_trtllm/README.md) for more information.
## Training
@@ -231,6 +242,7 @@ Note: Some model components have linting exceptions for E722 to accommodate tens
- [mrfakename](https://x.com/realmrfakename) huggingface space demo ~
- [f5-tts-mlx](https://github.com/lucasnewman/f5-tts-mlx/tree/main) Implementation with MLX framework by [Lucas Newman](https://github.com/lucasnewman)
- [F5-TTS-ONNX](https://github.com/DakeQQ/F5-TTS-ONNX) ONNX Runtime version by [DakeQQ](https://github.com/DakeQQ)
- [Yuekai Zhang](https://github.com/yuekaizhang) Triton and TensorRT-LLM support ~
## Citation
If our work and codebase is useful for you, please cite as:

View File

@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project]
name = "f5-tts"
version = "1.0.8"
version = "1.1.0"
description = "F5-TTS: A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching"
readme = "README.md"
license = {text = "MIT License"}
@@ -26,6 +26,7 @@ dependencies = [
"librosa",
"matplotlib",
"numpy<=1.26.4",
"pydantic<=2.10.6",
"pydub",
"pypinyin",
"safetensors",

View File

@@ -24,7 +24,7 @@ Currently supported features:
- Basic TTS with Chunk Inference
- Multi-Style / Multi-Speaker Generation
- Voice Chat powered by Qwen2.5-3B-Instruct
- [Custom inference with more language support](src/f5_tts/infer/SHARED.md)
- [Custom inference with more language support](SHARED.md)
The CLI command `f5-tts_infer-gradio` is equivalent to `python src/f5_tts/infer/infer_gradio.py`, which launches a Gradio app (web interface) for inference.

View File

@@ -137,11 +137,11 @@ Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "
#### F5-TTS Base @ ja @ Jmica
|Model|🤗Hugging Face|Data (Hours)|Model License|
|:---:|:------------:|:-----------:|:-------------:|
|F5-TTS Base|[ckpt & vocab](https://huggingface.co/Jmica/F5TTS/tree/main/JA_25498980)|[Emilia 1.7k JA](https://huggingface.co/datasets/amphion/Emilia-Dataset/tree/fc71e07) & [Galgame Dataset 5.4k](https://huggingface.co/datasets/OOPPEENN/Galgame_Dataset)|cc-by-nc-4.0|
|F5-TTS Base|[ckpt & vocab](https://huggingface.co/Jmica/F5TTS/tree/main/JA_21999120)|[Emilia 1.7k JA](https://huggingface.co/datasets/amphion/Emilia-Dataset/tree/fc71e07) & [Galgame Dataset 5.4k](https://huggingface.co/datasets/OOPPEENN/Galgame_Dataset)|cc-by-nc-4.0|
```bash
Model: hf://Jmica/F5TTS/JA_25498980/model_25498980.pt
Vocab: hf://Jmica/F5TTS/JA_25498980/vocab_updated.txt
Model: hf://Jmica/F5TTS/JA_21999120/model_21999120.pt
Vocab: hf://Jmica/F5TTS/JA_21999120/vocab_japanese.txt
Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "text_mask_padding": False, "conv_layers": 4, "pe_attn_head": 1}
```
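For illustration, a hedged sketch of loading the updated checkpoint through the Python API; the `F5TTS` constructor arguments below are assumptions based on recent releases, so verify them against your installed version:

```python
# Hedged sketch: load the updated JA checkpoint via the Python API.
# Assumes f5_tts.api.F5TTS accepts ckpt_file/vocab_file keyword arguments,
# and that hf:// paths are resolved with cached_path as in the infer CLI.
from cached_path import cached_path
from f5_tts.api import F5TTS

tts = F5TTS(
    ckpt_file=str(cached_path("hf://Jmica/F5TTS/JA_21999120/model_21999120.pt")),
    vocab_file=str(cached_path("hf://Jmica/F5TTS/JA_21999120/vocab_japanese.txt")),
)
```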

View File

@@ -1,6 +1,7 @@
# ruff: noqa: E402
# Above allows ruff to ignore E402: module level import not at top of file
import gc
import json
import re
import tempfile
@@ -11,6 +12,7 @@ import click
import gradio as gr
import numpy as np
import soundfile as sf
import torch
import torchaudio
from cached_path import cached_path
from transformers import AutoModelForCausalLM, AutoTokenizer
@@ -550,35 +552,50 @@ Have a conversation with an AI using your reference voice!
"""
)
if not USING_SPACES:
load_chat_model_btn = gr.Button("Load Chat Model", variant="primary")
chat_model_name_list = [
"Qwen/Qwen2.5-3B-Instruct",
"microsoft/Phi-4-mini-instruct",
]
chat_interface_container = gr.Column(visible=False)
@gpu_decorator
def load_chat_model(chat_model_name):
show_info = gr.Info
global chat_model_state, chat_tokenizer_state
if chat_model_state is not None:
chat_model_state = None
chat_tokenizer_state = None
gc.collect()
torch.cuda.empty_cache()
@gpu_decorator
def load_chat_model():
global chat_model_state, chat_tokenizer_state
if chat_model_state is None:
show_info = gr.Info
show_info("Loading chat model...")
model_name = "Qwen/Qwen2.5-3B-Instruct"
chat_model_state = AutoModelForCausalLM.from_pretrained(
model_name, torch_dtype="auto", device_map="auto"
)
chat_tokenizer_state = AutoTokenizer.from_pretrained(model_name)
show_info("Chat model loaded.")
show_info(f"Loading chat model: {chat_model_name}")
chat_model_state = AutoModelForCausalLM.from_pretrained(chat_model_name, torch_dtype="auto", device_map="auto")
chat_tokenizer_state = AutoTokenizer.from_pretrained(chat_model_name)
show_info(f"Chat model {chat_model_name} loaded successfully!")
return gr.update(visible=False), gr.update(visible=True)
return gr.update(visible=False), gr.update(visible=True)
load_chat_model_btn.click(load_chat_model, outputs=[load_chat_model_btn, chat_interface_container])
if USING_SPACES:
load_chat_model(chat_model_name_list[0])
else:
chat_interface_container = gr.Column()
chat_model_name_input = gr.Dropdown(
choices=chat_model_name_list,
value=chat_model_name_list[0],
label="Chat Model Name",
info="Enter the name of a HuggingFace chat model",
allow_custom_value=not USING_SPACES,
)
load_chat_model_btn = gr.Button("Load Chat Model", variant="primary", visible=not USING_SPACES)
chat_interface_container = gr.Column(visible=USING_SPACES)
if chat_model_state is None:
model_name = "Qwen/Qwen2.5-3B-Instruct"
chat_model_state = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", device_map="auto")
chat_tokenizer_state = AutoTokenizer.from_pretrained(model_name)
chat_model_name_input.change(
lambda: gr.update(visible=True),
None,
load_chat_model_btn,
show_progress="hidden",
)
load_chat_model_btn.click(
load_chat_model, inputs=[chat_model_name_input], outputs=[load_chat_model_btn, chat_interface_container]
)
with chat_interface_container:
with gr.Row():

View File

@@ -21,7 +21,7 @@ import numpy as np
import torch
import torchaudio
import tqdm
from huggingface_hub import snapshot_download, hf_hub_download
from huggingface_hub import hf_hub_download
from pydub import AudioSegment, silence
from transformers import pipeline
from vocos import Vocos
@@ -128,11 +128,12 @@ def load_vocoder(vocoder_name="vocos", is_local=False, local_path="", device=dev
except ImportError:
print("You need to follow the README to init submodule and change the BigVGAN source code.")
if is_local:
"""download from https://huggingface.co/nvidia/bigvgan_v2_24khz_100band_256x/tree/main"""
# download generator from https://huggingface.co/nvidia/bigvgan_v2_24khz_100band_256x/tree/main
vocoder = bigvgan.BigVGAN.from_pretrained(local_path, use_cuda_kernel=False)
else:
local_path = snapshot_download(repo_id="nvidia/bigvgan_v2_24khz_100band_256x", cache_dir=hf_cache_dir)
vocoder = bigvgan.BigVGAN.from_pretrained(local_path, use_cuda_kernel=False)
vocoder = bigvgan.BigVGAN.from_pretrained(
"nvidia/bigvgan_v2_24khz_100band_256x", use_cuda_kernel=False, cache_dir=hf_cache_dir
)
vocoder.remove_weight_norm()
vocoder = vocoder.eval().to(device)
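# Usage sketch, assuming the BigVGAN submodule is initialized per the README:
# vocoder = load_vocoder(vocoder_name="bigvgan", is_local=False)
# with the change above, BigVGAN.from_pretrained downloads and caches the
# HuggingFace weights itself, so the separate snapshot_download step is no longer needed.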

View File

@@ -270,7 +270,7 @@ class CFM(nn.Module):
else:
drop_text = False
# if want rigourously mask out padding, record in collate_fn in dataset.py, and pass in here
# if want rigorously mask out padding, record in collate_fn in dataset.py, and pass in here
# adding mask will use more memory, thus also need to adjust batchsampler with scaled down threshold for long sequences
pred = self.transformer(
x=φ, cond=cond, text=text, time=time, drop_audio_cond=drop_audio_cond, drop_text=drop_text

View File

@@ -51,7 +51,7 @@ class Trainer:
mel_spec_type: str = "vocos", # "vocos" | "bigvgan"
is_local_vocoder: bool = False, # use local path vocoder
local_vocoder_path: str = "", # local vocoder path
cfg_dict: dict = dict(), # training config
model_cfg_dict: dict = dict(), # training config
):
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
@@ -73,8 +73,8 @@ class Trainer:
else:
init_kwargs = {"wandb": {"resume": "allow", "name": wandb_run_name}}
if not cfg_dict:
cfg_dict = {
if not model_cfg_dict:
model_cfg_dict = {
"epochs": epochs,
"learning_rate": learning_rate,
"num_warmup_updates": num_warmup_updates,
@@ -85,11 +85,11 @@ class Trainer:
"max_grad_norm": max_grad_norm,
"noise_scheduler": noise_scheduler,
}
cfg_dict["gpus"] = self.accelerator.num_processes
model_cfg_dict["gpus"] = self.accelerator.num_processes
self.accelerator.init_trackers(
project_name=wandb_project,
init_kwargs=init_kwargs,
config=cfg_dict,
config=model_cfg_dict,
)
elif self.logger == "tensorboard":
@@ -395,6 +395,9 @@ class Trainer:
self.writer.add_scalar("loss", loss.item(), global_update)
self.writer.add_scalar("lr", self.scheduler.get_last_lr()[0], global_update)
if global_update % self.last_per_updates == 0 and self.accelerator.sync_gradients:
self.save_checkpoint(global_update, last=True)
if global_update % self.save_per_updates == 0 and self.accelerator.sync_gradients:
self.save_checkpoint(global_update)
@@ -430,9 +433,6 @@ class Trainer:
)
self.model.train()
if global_update % self.last_per_updates == 0 and self.accelerator.sync_gradients:
self.save_checkpoint(global_update, last=True)
self.save_checkpoint(global_update, last=True)
self.accelerator.end_training()

View File

@@ -0,0 +1,3 @@
FROM nvcr.io/nvidia/tritonserver:24.12-py3
RUN pip install tritonclient[grpc] tensorrt-llm==0.16.0 torchaudio==2.5.1 jieba pypinyin librosa vocos
WORKDIR /workspace

View File

@@ -0,0 +1,47 @@
## Triton Inference Serving Best Practice for F5-TTS
### Quick Start
Directly launch the service using docker compose.
```sh
# TODO: support F5TTS_v1_Base
MODEL=F5TTS_Base docker compose up
```
### Build Image
Build the docker image from scratch.
```sh
docker build . -f Dockerfile.server -t soar97/triton-f5-tts:24.12
```
### Create Docker Container
```sh
your_mount_dir=/mnt:/mnt
docker run -it --name "f5-server" --gpus all --net host -v $your_mount_dir --shm-size=2g soar97/triton-f5-tts:24.12
```
### Export Models to TensorRT-LLM and Launch Server
Inside the docker container, build the TensorRT-LLM engines following the official TensorRT-LLM workflow (see the [Whisper example](https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/whisper) for reference).
```sh
bash run.sh 0 4 F5TTS_Base
```
### HTTP Client
```sh
python3 client_http.py
```
### Benchmark using Dataset
```sh
num_task=2
python3 client_grpc.py --num-tasks $num_task --huggingface-dataset yuekai/seed_tts --split-name wenetspeech4tts
```
### Benchmark Results
Decoding on a single L20 GPU, using 26 different prompt_audio/target_text pairs.
| Model | Concurrency | Avg Latency | RTF |
|-------|-------------|----------------|-------|
| F5-TTS Base (Vocos) | 1 | 253 ms | 0.0394|
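RTF (real-time factor) here is decoding time divided by the duration of the generated audio, so the two reported numbers together imply roughly 6.4 s of audio per request; a quick sanity check:

```python
# Sanity-check the reported benchmark numbers: RTF = processing_time / audio_duration.
avg_latency_s = 0.253  # 253 ms average latency per request
rtf = 0.0394           # reported real-time factor
print(f"implied audio length per request: {avg_latency_s / rtf:.2f} s")  # ~6.42 s
```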
### Credits
1. [F5-TTS-TRTLLM](https://github.com/Bigfishering/f5-tts-trtllm)

View File

@@ -0,0 +1,470 @@
#!/usr/bin/env python3
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
# 2023 Nvidia (authors: Yuekai Zhang)
# 2023 Recurrent.ai (authors: Songtao Shi)
# See LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script loads a dataset from HuggingFace and sends it to the server
for decoding, in parallel.
Usage:
num_task=2
# For offline F5-TTS
python3 client_grpc.py \
--server-addr localhost \
--model-name f5_tts \
--num-tasks $num_task \
--huggingface-dataset yuekai/seed_tts \
--split-name test_zh \
--log-dir ./log_concurrent_tasks_${num_task}
# For offline Spark-TTS-0.5B
python3 client_grpc.py \
--server-addr localhost \
--model-name spark_tts \
--num-tasks $num_task \
--huggingface-dataset yuekai/seed_tts \
--split-name wenetspeech4tts \
--log-dir ./log_concurrent_tasks_${num_task}
"""
import argparse
import asyncio
import json
import os
import time
import types
from pathlib import Path
import numpy as np
import soundfile as sf
import tritonclient
import tritonclient.grpc.aio as grpcclient
from tritonclient.utils import np_to_triton_dtype
def write_triton_stats(stats, summary_file):
with open(summary_file, "w") as summary_f:
model_stats = stats["model_stats"]
# write a note explaining that the log comes from triton_client.get_inference_statistics(), for better readability
summary_f.write(
"This log is parsed from triton_client.get_inference_statistics(), for better human readability. \n"
)
summary_f.write("To learn more about the log, please refer to: \n")
summary_f.write("1. https://github.com/triton-inference-server/server/blob/main/docs/user_guide/metrics.md \n")
summary_f.write("2. https://github.com/triton-inference-server/server/issues/5374 \n\n")
summary_f.write(
"To better improve throughput, we always would like let requests wait in the queue for a while, and then execute them with a larger batch size. \n"
)
summary_f.write(
"However, there is a trade-off between the increased queue time and the increased batch size. \n"
)
summary_f.write(
"You may change 'max_queue_delay_microseconds' and 'preferred_batch_size' in the model configuration file to achieve this. \n"
)
summary_f.write(
"See https://github.com/triton-inference-server/server/blob/main/docs/user_guide/model_configuration.md#delayed-batching for more details. \n\n"
)
for model_state in model_stats:
if "last_inference" not in model_state:
continue
summary_f.write(f"model name is {model_state['name']} \n")
model_inference_stats = model_state["inference_stats"]
total_queue_time_s = int(model_inference_stats["queue"]["ns"]) / 1e9
total_infer_time_s = int(model_inference_stats["compute_infer"]["ns"]) / 1e9
total_input_time_s = int(model_inference_stats["compute_input"]["ns"]) / 1e9
total_output_time_s = int(model_inference_stats["compute_output"]["ns"]) / 1e9
summary_f.write(
f"queue time {total_queue_time_s:<5.2f} s, compute infer time {total_infer_time_s:<5.2f} s, compute input time {total_input_time_s:<5.2f} s, compute output time {total_output_time_s:<5.2f} s \n" # noqa
)
model_batch_stats = model_state["batch_stats"]
for batch in model_batch_stats:
batch_size = int(batch["batch_size"])
compute_input = batch["compute_input"]
compute_output = batch["compute_output"]
compute_infer = batch["compute_infer"]
batch_count = int(compute_infer["count"])
assert compute_infer["count"] == compute_output["count"] == compute_input["count"]
compute_infer_time_ms = int(compute_infer["ns"]) / 1e6
compute_input_time_ms = int(compute_input["ns"]) / 1e6
compute_output_time_ms = int(compute_output["ns"]) / 1e6
summary_f.write(
f"execuate inference with batch_size {batch_size:<2} total {batch_count:<5} times, total_infer_time {compute_infer_time_ms:<9.2f} ms, avg_infer_time {compute_infer_time_ms:<9.2f}/{batch_count:<5}={compute_infer_time_ms / batch_count:.2f} ms, avg_infer_time_per_sample {compute_infer_time_ms:<9.2f}/{batch_count:<5}/{batch_size}={compute_infer_time_ms / batch_count / batch_size:.2f} ms \n" # noqa
)
summary_f.write(
f"input {compute_input_time_ms:<9.2f} ms, avg {compute_input_time_ms / batch_count:.2f} ms, " # noqa
)
summary_f.write(
f"output {compute_output_time_ms:<9.2f} ms, avg {compute_output_time_ms / batch_count:.2f} ms \n" # noqa
)
def get_args():
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument(
"--server-addr",
type=str,
default="localhost",
help="Address of the server",
)
parser.add_argument(
"--server-port",
type=int,
default=8001,
help="Grpc port of the triton server, default is 8001",
)
parser.add_argument(
"--reference-audio",
type=str,
default=None,
help="Path to a single audio file. It can't be specified at the same time with --manifest-dir",
)
parser.add_argument(
"--reference-text",
type=str,
default="",
help="",
)
parser.add_argument(
"--target-text",
type=str,
default="",
help="",
)
parser.add_argument(
"--huggingface-dataset",
type=str,
default="yuekai/seed_tts",
help="dataset name in huggingface dataset hub",
)
parser.add_argument(
"--split-name",
type=str,
default="wenetspeech4tts",
choices=["wenetspeech4tts", "test_zh", "test_en", "test_hard"],
help="dataset split name, default is 'test'",
)
parser.add_argument(
"--manifest-path",
type=str,
default=None,
help="Path to the manifest dir which includes wav.scp trans.txt files.",
)
parser.add_argument(
"--model-name",
type=str,
default="f5_tts",
choices=["f5_tts", "spark_tts"],
help="triton model_repo module name to request: transducer for k2, attention_rescoring for wenet offline, streaming_wenet for wenet streaming, infer_pipeline for paraformer large offline",
)
parser.add_argument(
"--num-tasks",
type=int,
default=1,
help="Number of concurrent tasks for sending",
)
parser.add_argument(
"--log-interval",
type=int,
default=5,
help="Controls how frequently we print the log.",
)
parser.add_argument(
"--compute-wer",
action="store_true",
default=False,
help="""True to compute WER.
""",
)
parser.add_argument(
"--log-dir",
type=str,
required=False,
default="./tmp",
help="log directory",
)
parser.add_argument(
"--batch-size",
type=int,
default=1,
help="Inference batch_size per request for offline mode.",
)
return parser.parse_args()
def load_audio(wav_path, target_sample_rate=16000):
assert target_sample_rate == 16000, "hard-coded to 16 kHz in the server"
if isinstance(wav_path, dict):
waveform = wav_path["array"]
sample_rate = wav_path["sampling_rate"]
else:
waveform, sample_rate = sf.read(wav_path)
if sample_rate != target_sample_rate:
from scipy.signal import resample
num_samples = int(len(waveform) * (target_sample_rate / sample_rate))
waveform = resample(waveform, num_samples)
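# e.g. resampling a 96000-sample 48 kHz clip to 16 kHz gives
# num_samples = int(96000 * (16000 / 48000)) = 32000 samples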
return waveform, target_sample_rate
async def send(
manifest_item_list: list,
name: str,
triton_client: tritonclient.grpc.aio.InferenceServerClient,
protocol_client: types.ModuleType,
log_interval: int,
model_name: str,
padding_duration: int = None,
audio_save_dir: str = "./",
save_sample_rate: int = 16000,
):
total_duration = 0.0
latency_data = []
task_id = int(name[5:])
print(f"manifest_item_list: {manifest_item_list}")
for i, item in enumerate(manifest_item_list):
if i % log_interval == 0:
print(f"{name}: {i}/{len(manifest_item_list)}")
waveform, sample_rate = load_audio(item["audio_filepath"], target_sample_rate=16000)
duration = len(waveform) / sample_rate
lengths = np.array([[len(waveform)]], dtype=np.int32)
reference_text, target_text = item["reference_text"], item["target_text"]
estimated_target_duration = duration / len(reference_text) * len(target_text)
if padding_duration:
# pad up to the next multiple of padding_duration seconds
samples = np.zeros(
(
1,
padding_duration
* sample_rate
* ((int(estimated_target_duration + duration) // padding_duration) + 1),
),
dtype=np.float32,
)
samples[0, : len(waveform)] = waveform
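# e.g. with padding_duration=1, a 3.2 s prompt and an estimated 4.1 s target
# are padded to (int(4.1 + 3.2) // 1 + 1) = 8 seconds of samples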
else:
samples = waveform
samples = samples.reshape(1, -1).astype(np.float32)
inputs = [
protocol_client.InferInput("reference_wav", samples.shape, np_to_triton_dtype(samples.dtype)),
protocol_client.InferInput("reference_wav_len", lengths.shape, np_to_triton_dtype(lengths.dtype)),
protocol_client.InferInput("reference_text", [1, 1], "BYTES"),
protocol_client.InferInput("target_text", [1, 1], "BYTES"),
]
inputs[0].set_data_from_numpy(samples)
inputs[1].set_data_from_numpy(lengths)
input_data_numpy = np.array([reference_text], dtype=object)
input_data_numpy = input_data_numpy.reshape((1, 1))
inputs[2].set_data_from_numpy(input_data_numpy)
input_data_numpy = np.array([target_text], dtype=object)
input_data_numpy = input_data_numpy.reshape((1, 1))
inputs[3].set_data_from_numpy(input_data_numpy)
outputs = [protocol_client.InferRequestedOutput("waveform")]
sequence_id = 100000000 + i + task_id * 10
start = time.time()
response = await triton_client.infer(model_name, inputs, request_id=str(sequence_id), outputs=outputs)
audio = response.as_numpy("waveform").reshape(-1)
end = time.time() - start
audio_save_path = os.path.join(audio_save_dir, f"{item['target_audio_path']}.wav")
sf.write(audio_save_path, audio, save_sample_rate, "PCM_16")
latency_data.append((end, estimated_target_duration))
total_duration += estimated_target_duration
return total_duration, latency_data
def load_manifests(manifest_path):
with open(manifest_path, "r") as f:
manifest_list = []
for line in f:
assert len(line.strip().split("|")) == 4
utt, prompt_text, prompt_wav, gt_text = line.strip().split("|")
utt = Path(utt).stem
# gt_wav = os.path.join(os.path.dirname(manifest_path), "wavs", utt + ".wav")
if not os.path.isabs(prompt_wav):
prompt_wav = os.path.join(os.path.dirname(manifest_path), prompt_wav)
manifest_list.append(
{
"audio_filepath": prompt_wav,
"reference_text": prompt_text,
"target_text": gt_text,
"target_audio_path": utt,
}
)
return manifest_list
def split_data(data, k):
n = len(data)
if n < k:
print(f"Warning: the length of the input list ({n}) is less than k ({k}). Setting k to {n}.")
k = n
quotient = n // k
remainder = n % k
result = []
start = 0
for i in range(k):
if i < remainder:
end = start + quotient + 1
else:
end = start + quotient
result.append(data[start:end])
start = end
return result
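# e.g. split_data(list(range(5)), 2) -> [[0, 1, 2], [3, 4]]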
async def main():
args = get_args()
url = f"{args.server_addr}:{args.server_port}"
triton_client = grpcclient.InferenceServerClient(url=url, verbose=False)
protocol_client = grpcclient
if args.reference_audio:
args.num_tasks = 1
args.log_interval = 1
manifest_item_list = [
{
"reference_text": args.reference_text,
"target_text": args.target_text,
"audio_filepath": args.reference_audio,
"target_audio_path": "test",
}
]
elif args.huggingface_dataset:
import datasets
dataset = datasets.load_dataset(
args.huggingface_dataset,
split=args.split_name,
trust_remote_code=True,
)
manifest_item_list = []
for i in range(len(dataset)):
manifest_item_list.append(
{
"audio_filepath": dataset[i]["prompt_audio"],
"reference_text": dataset[i]["prompt_text"],
"target_audio_path": dataset[i]["id"],
"target_text": dataset[i]["target_text"],
}
)
else:
manifest_item_list = load_manifests(args.manifest_path)
args.num_tasks = min(args.num_tasks, len(manifest_item_list))
manifest_item_list = split_data(manifest_item_list, args.num_tasks)
os.makedirs(args.log_dir, exist_ok=True)
tasks = []
start_time = time.time()
for i in range(args.num_tasks):
task = asyncio.create_task(
send(
manifest_item_list[i],
name=f"task-{i}",
triton_client=triton_client,
protocol_client=protocol_client,
log_interval=args.log_interval,
model_name=args.model_name,
audio_save_dir=args.log_dir,
padding_duration=1,
save_sample_rate=24000 if args.model_name == "f5_tts" else 16000,
)
)
tasks.append(task)
ans_list = await asyncio.gather(*tasks)
end_time = time.time()
elapsed = end_time - start_time
total_duration = 0.0
latency_data = []
for ans in ans_list:
total_duration += ans[0]
latency_data += ans[1]
rtf = elapsed / total_duration
s = f"RTF: {rtf:.4f}\n"
s += f"total_duration: {total_duration:.3f} seconds\n"
s += f"({total_duration / 3600:.2f} hours)\n"
s += f"processing time: {elapsed:.3f} seconds ({elapsed / 3600:.2f} hours)\n"
latency_list = [chunk_end for (chunk_end, chunk_duration) in latency_data]
latency_ms = sum(latency_list) / float(len(latency_list)) * 1000.0
latency_variance = np.var(latency_list, dtype=np.float64) * 1000.0
s += f"latency_variance: {latency_variance:.2f}\n"
s += f"latency_50_percentile_ms: {np.percentile(latency_list, 50) * 1000.0:.2f}\n"
s += f"latency_90_percentile_ms: {np.percentile(latency_list, 90) * 1000.0:.2f}\n"
s += f"latency_95_percentile_ms: {np.percentile(latency_list, 95) * 1000.0:.2f}\n"
s += f"latency_99_percentile_ms: {np.percentile(latency_list, 99) * 1000.0:.2f}\n"
s += f"average_latency_ms: {latency_ms:.2f}\n"
print(s)
if args.manifest_path:
name = Path(args.manifest_path).stem
elif args.split_name:
name = args.split_name
with open(f"{args.log_dir}/rtf-{name}.txt", "w") as f:
f.write(s)
stats = await triton_client.get_inference_statistics(model_name="", as_json=True)
write_triton_stats(stats, f"{args.log_dir}/stats_summary-{name}.txt")
metadata = await triton_client.get_model_config(model_name=args.model_name, as_json=True)
with open(f"{args.log_dir}/model_config-{name}.json", "w") as f:
json.dump(metadata, f, indent=4)
if __name__ == "__main__":
asyncio.run(main())

View File

@@ -0,0 +1,142 @@
# Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of NVIDIA CORPORATION nor the names of its
# contributors may be used to endorse or promote products derived
# from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import requests
import soundfile as sf
import numpy as np
import argparse
def get_args():
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument(
"--server-url",
type=str,
default="localhost:8000",
help="Address of the server",
)
parser.add_argument(
"--reference-audio",
type=str,
default="../../infer/examples/basic/basic_ref_en.wav",
help="Path to a single audio file. It can't be specified at the same time with --manifest-dir",
)
parser.add_argument(
"--reference-text",
type=str,
default="Some call me nature, others call me mother nature.",
help="",
)
parser.add_argument(
"--target-text",
type=str,
default="I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring.",
help="",
)
parser.add_argument(
"--model-name",
type=str,
default="f5_tts",
choices=["f5_tts", "spark_tts"],
help="triton model_repo module name to request",
)
parser.add_argument(
"--output-audio",
type=str,
default="output.wav",
help="Path to save the output audio",
)
return parser.parse_args()
def prepare_request(
samples,
reference_text,
target_text,
sample_rate=16000,
audio_save_dir: str = "./",
):
assert len(samples.shape) == 1, "samples should be 1D"
lengths = np.array([[len(samples)]], dtype=np.int32)
samples = samples.reshape(1, -1).astype(np.float32)
data = {
"inputs": [
{"name": "reference_wav", "shape": samples.shape, "datatype": "FP32", "data": samples.tolist()},
{
"name": "reference_wav_len",
"shape": lengths.shape,
"datatype": "INT32",
"data": lengths.tolist(),
},
{"name": "reference_text", "shape": [1, 1], "datatype": "BYTES", "data": [reference_text]},
{"name": "target_text", "shape": [1, 1], "datatype": "BYTES", "data": [target_text]},
]
}
return data
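# Illustration: for a 3-sample reference waveform the resulting KServe v2 JSON body holds
# inputs[0] = {"name": "reference_wav", "shape": (1, 3), "datatype": "FP32", "data": [[0.0, 0.0, 0.0]]}
# inputs[1] = {"name": "reference_wav_len", "shape": (1, 1), "datatype": "INT32", "data": [[3]]}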
def load_audio(wav_path, target_sample_rate=16000):
assert target_sample_rate == 16000, "hard-coded to 16 kHz in the server"
if isinstance(wav_path, dict):
samples = wav_path["array"]
sample_rate = wav_path["sampling_rate"]
else:
samples, sample_rate = sf.read(wav_path)
if sample_rate != target_sample_rate:
from scipy.signal import resample
num_samples = int(len(samples) * (target_sample_rate / sample_rate))
samples = resample(samples, num_samples)
return samples, target_sample_rate
if __name__ == "__main__":
args = get_args()
server_url = args.server_url
if not server_url.startswith(("http://", "https://")):
server_url = f"http://{server_url}"
url = f"{server_url}/v2/models/{args.model_name}/infer"
samples, sr = load_audio(args.reference_audio)
assert sr == 16000, "sample rate hardcoded in server"
samples = np.array(samples, dtype=np.float32)
data = prepare_request(samples, args.reference_text, args.target_text)
rsp = requests.post(
url, headers={"Content-Type": "application/json"}, json=data, verify=False, params={"request_id": "0"}
)
result = rsp.json()
audio = result["outputs"][0]["data"]
audio = np.array(audio, dtype=np.float32)
sf.write(args.output_audio, audio, 24000, "PCM_16")

View File

@@ -0,0 +1,20 @@
services:
tts:
image: soar97/triton-f5-tts:24.12
shm_size: '1gb'
ports:
- "8000:8000"
- "8001:8001"
- "8002:8002"
environment:
- PYTHONIOENCODING=utf-8
- MODEL_ID=${MODEL_ID}
deploy:
resources:
reservations:
devices:
- driver: nvidia
device_ids: ['0']
capabilities: [gpu]
command: >
/bin/bash -c "pip install vocos && rm -rf F5-TTS && git clone https://github.com/SWivid/F5-TTS.git && cd F5-TTS/src/f5_tts/runtime/triton_trtllm/ && bash run.sh 0 4 $MODEL"

View File

@@ -0,0 +1,431 @@
import tensorrt as trt
import os
import math
import time
from typing import List, Optional
from functools import wraps
import tensorrt_llm
from tensorrt_llm._utils import str_dtype_to_torch, trt_dtype_to_torch
from tensorrt_llm.logger import logger
from tensorrt_llm.runtime.session import Session
import torch
import torch.nn as nn
import torch.nn.functional as F
def remove_tensor_padding(input_tensor, input_tensor_lengths=None):
# Audio tensor case: batch, seq_len, feature_len
# position_ids case: batch, seq_len
assert input_tensor_lengths is not None, "input_tensor_lengths must be provided for 3D input_tensor"
# Initialize a list to collect valid sequences
valid_sequences = []
for i in range(input_tensor.shape[0]):
valid_length = input_tensor_lengths[i]
valid_sequences.append(input_tensor[i, :valid_length])
# Concatenate all valid sequences along the batch dimension
output_tensor = torch.cat(valid_sequences, dim=0).contiguous()
return output_tensor
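# e.g. an input of shape (2, 5, 3) with input_tensor_lengths = [3, 2] returns a
# (5, 3) tensor: the valid prefixes of both sequences concatenated along dim 0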
class TextEmbedding(nn.Module):
def __init__(self, text_num_embeds, text_dim, conv_layers=0, conv_mult=2, precompute_max_pos=4096):
super().__init__()
self.text_embed = nn.Embedding(text_num_embeds + 1, text_dim) # use 0 as filler token
self.register_buffer("freqs_cis", precompute_freqs_cis(text_dim, precompute_max_pos), persistent=False)
self.text_blocks = nn.Sequential(*[ConvNeXtV2Block(text_dim, text_dim * conv_mult) for _ in range(conv_layers)])
def forward(self, text):
# only keep tensors with value not -1
text_mask = text != -1
text_pad_cut_off_index = text_mask.sum(dim=1).max()
text = text[:, :text_pad_cut_off_index]
text = self.text_embed(text)
text = text + self.freqs_cis[: text.shape[1], :]
for block in self.text_blocks:
text = block(text)
# padding text to the original length
# text shape: B,seq_len,C
# pad at the second dimension
text = F.pad(text, (0, 0, 0, text_mask.shape[1] - text.shape[1], 0, 0), value=0)
return text
class GRN(nn.Module):
def __init__(self, dim):
super().__init__()
self.gamma = nn.Parameter(torch.zeros(1, 1, dim))
self.beta = nn.Parameter(torch.zeros(1, 1, dim))
def forward(self, x):
Gx = torch.norm(x, p=2, dim=1, keepdim=True)
Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6)
return self.gamma * (x * Nx) + self.beta + x
class ConvNeXtV2Block(nn.Module):
def __init__(
self,
dim: int,
intermediate_dim: int,
dilation: int = 1,
):
super().__init__()
padding = (dilation * (7 - 1)) // 2
self.dwconv = nn.Conv1d(
dim, dim, kernel_size=7, padding=padding, groups=dim, dilation=dilation
) # depthwise conv
self.norm = nn.LayerNorm(dim, eps=1e-6)
self.pwconv1 = nn.Linear(dim, intermediate_dim) # pointwise/1x1 convs, implemented with linear layers
self.act = nn.GELU()
self.grn = GRN(intermediate_dim)
self.pwconv2 = nn.Linear(intermediate_dim, dim)
def forward(self, x: torch.Tensor) -> torch.Tensor:
residual = x
x = x.transpose(1, 2) # b n d -> b d n
x = self.dwconv(x)
x = x.transpose(1, 2) # b d n -> b n d
x = self.norm(x)
x = self.pwconv1(x)
x = self.act(x)
x = self.grn(x)
x = self.pwconv2(x)
return residual + x
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, theta_rescale_factor=1.0):
# proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
# has some connection to NTK literature
# https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
# https://github.com/lucidrains/rotary-embedding-torch/blob/main/rotary_embedding_torch/rotary_embedding_torch.py
theta *= theta_rescale_factor ** (dim / (dim - 2))
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
t = torch.arange(end, device=freqs.device) # type: ignore
freqs = torch.outer(t, freqs).float() # type: ignore
freqs_cos = torch.cos(freqs) # real part
freqs_sin = torch.sin(freqs) # imaginary part
return torch.cat([freqs_cos, freqs_sin], dim=-1)
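# e.g. precompute_freqs_cis(512, 4096) returns a (4096, 512) tensor holding
# 256 cos features followed by 256 sin features per position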
def load_checkpoint(ckpt_path, use_ema=True):
checkpoint = torch.load(ckpt_path, weights_only=True)
if use_ema:
checkpoint["model_state_dict"] = {
k.replace("ema_model.", ""): v
for k, v in checkpoint["ema_model_state_dict"].items()
if k not in ["initted", "step"]
}
dict_state = checkpoint["model_state_dict"]
text_embed_dict = {}
for key in dict_state.keys():
# transformer.text_embed.text_embed.weight -> text_embed.weight
if "text_embed" in key:
text_embed_dict[key.replace("transformer.text_embed.", "")] = dict_state[key]
return text_embed_dict
class F5TTS(object):
def __init__(
self,
config,
debug_mode=True,
stream: Optional[torch.cuda.Stream] = None,
tllm_model_dir: Optional[str] = None,
model_path: Optional[str] = None,
vocab_size: Optional[int] = None,
):
self.dtype = config["pretrained_config"]["dtype"]
rank = tensorrt_llm.mpi_rank()
world_size = config["pretrained_config"]["mapping"]["world_size"]
cp_size = config["pretrained_config"]["mapping"]["cp_size"]
tp_size = config["pretrained_config"]["mapping"]["tp_size"]
pp_size = config["pretrained_config"]["mapping"]["pp_size"]
assert pp_size == 1
self.mapping = tensorrt_llm.Mapping(
world_size=world_size, rank=rank, cp_size=cp_size, tp_size=tp_size, pp_size=1, gpus_per_node=1
)
local_rank = rank % self.mapping.gpus_per_node
self.device = torch.device(f"cuda:{local_rank}")
torch.cuda.set_device(self.device)
self.stream = stream
if self.stream is None:
self.stream = torch.cuda.Stream(self.device)
torch.cuda.set_stream(self.stream)
engine_file = os.path.join(tllm_model_dir, f"rank{rank}.engine")
logger.info(f"Loading engine from {engine_file}")
with open(engine_file, "rb") as f:
engine_buffer = f.read()
assert engine_buffer is not None
self.session = Session.from_serialized_engine(engine_buffer)
self.debug_mode = debug_mode
self.inputs = {}
self.outputs = {}
self.buffer_allocated = False
expected_tensor_names = ["noise", "cond", "time", "rope_cos", "rope_sin", "input_lengths", "denoised"]
found_tensor_names = [self.session.engine.get_tensor_name(i) for i in range(self.session.engine.num_io_tensors)]
if not self.debug_mode and set(expected_tensor_names) != set(found_tensor_names):
logger.error(
f"The following expected tensors are not found: {set(expected_tensor_names).difference(set(found_tensor_names))}"
)
logger.error(
f"Those tensors in engine are not expected: {set(found_tensor_names).difference(set(expected_tensor_names))}"
)
logger.error(f"Expected tensor names: {expected_tensor_names}")
logger.error(f"Found tensor names: {found_tensor_names}")
raise RuntimeError("Tensor names in engine are not the same as expected.")
if self.debug_mode:
self.debug_tensors = list(set(found_tensor_names) - set(expected_tensor_names))
self.max_mel_len = 4096
self.text_embedding = TextEmbedding(
text_num_embeds=vocab_size, text_dim=512, conv_layers=4, precompute_max_pos=self.max_mel_len
).to(self.device)
self.text_embedding.load_state_dict(load_checkpoint(model_path), strict=True)
self.target_audio_sample_rate = 24000
self.target_rms = 0.15 # target rms for audio
self.n_fft = 1024
self.win_length = 1024
self.hop_length = 256
self.n_mel_channels = 100
# self.max_mel_len = 3000
self.head_dim = 64
self.base_rescale_factor = 1.0
self.interpolation_factor = 1.0
base = 10000.0 * self.base_rescale_factor ** (self.head_dim / (self.head_dim - 2))
inv_freq = 1.0 / (base ** (torch.arange(0, self.head_dim, 2).float() / self.head_dim))
freqs = torch.outer(torch.arange(self.max_mel_len, dtype=torch.float32), inv_freq) / self.interpolation_factor
self.freqs = freqs.repeat_interleave(2, dim=-1).unsqueeze(0)
self.rope_cos = self.freqs.cos().half()
self.rope_sin = self.freqs.sin().half()
self.nfe_steps = 16
t = torch.linspace(0, 1, self.nfe_steps + 1, dtype=torch.float32)
time_step = t + (-1.0) * (torch.cos(torch.pi * 0.5 * t) - 1 + t)
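# algebraically t + (-1) * (cos(pi/2 * t) - 1 + t) = 1 - cos(pi/2 * t),
# i.e. the sway-sampling time schedule of F5-TTS with sway coefficient -1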
delta_t = torch.diff(time_step)
# WAR: hard coding 256 here
tmp_dim = 256
time_expand = torch.zeros((1, self.nfe_steps, tmp_dim), dtype=torch.float32)
half_dim = tmp_dim // 2
emb_factor = math.log(10000) / (half_dim - 1)
emb_factor = 1000.0 * torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb_factor)
for i in range(self.nfe_steps):
emb = time_step[i] * emb_factor
time_expand[:, i, :] = torch.cat((emb.sin(), emb.cos()), dim=-1)
self.time_expand = time_expand.to(self.device)
self.delta_t = torch.cat((delta_t, delta_t), dim=0).contiguous().to(self.device)
def _tensor_dtype(self, name):
# return torch dtype given tensor name for convenience
dtype = trt_dtype_to_torch(self.session.engine.get_tensor_dtype(name))
return dtype
def _setup(self, batch_size, seq_len):
for i in range(self.session.engine.num_io_tensors):
name = self.session.engine.get_tensor_name(i)
if self.session.engine.get_tensor_mode(name) == trt.TensorIOMode.OUTPUT:
shape = list(self.session.engine.get_tensor_shape(name))
shape[0] = batch_size
shape[1] = seq_len
self.outputs[name] = torch.empty(shape, dtype=self._tensor_dtype(name), device=self.device)
self.buffer_allocated = True
def cuda_stream_guard(func):
"""Sync external stream and set current stream to the one bound to the session. Reset on exit."""
@wraps(func)
def wrapper(self, *args, **kwargs):
external_stream = torch.cuda.current_stream()
if external_stream != self.stream:
external_stream.synchronize()
torch.cuda.set_stream(self.stream)
ret = func(self, *args, **kwargs)
if external_stream != self.stream:
self.stream.synchronize()
torch.cuda.set_stream(external_stream)
return ret
return wrapper
@cuda_stream_guard
def forward(
self,
noise: torch.Tensor,
cond: torch.Tensor,
time_expand: torch.Tensor,
rope_cos: torch.Tensor,
rope_sin: torch.Tensor,
input_lengths: torch.Tensor,
delta_t: torch.Tensor,
use_perf: bool = False,
):
if use_perf:
torch.cuda.nvtx.range_push("flow matching")
cfg_strength = 2.0
batch_size = noise.shape[0]
half_batch = batch_size // 2
noise_half = noise[:half_batch] # Store the initial half of noise
input_type = str_dtype_to_torch(self.dtype)
# Keep a copy of the initial tensors
cond = cond.to(input_type)
rope_cos = rope_cos.to(input_type)
rope_sin = rope_sin.to(input_type)
input_lengths = input_lengths.to(str_dtype_to_torch("int32"))
# Instead of iteratively updating noise within a single model context,
# we'll do a single forward pass for each iteration with fresh context setup
for i in range(self.nfe_steps):
# Re-setup the buffers for clean execution
self._setup(batch_size, noise.shape[1])
if not self.buffer_allocated:
raise RuntimeError("Buffer not allocated, please call setup first!")
# Re-create combined noises for this iteration
current_noise = torch.cat([noise_half, noise_half], dim=0).to(input_type)
# Get time step for this iteration
current_time = time_expand[:, i].to(input_type)
# Create fresh input dictionary for this iteration
current_inputs = {
"noise": current_noise,
"cond": cond,
"time": current_time,
"rope_cos": rope_cos,
"rope_sin": rope_sin,
"input_lengths": input_lengths,
}
# Update inputs and set shapes
self.inputs.clear() # Clear previous inputs
self.inputs.update(**current_inputs)
self.session.set_shapes(self.inputs)
if use_perf:
torch.cuda.nvtx.range_push(f"execute {i}")
ok = self.session.run(self.inputs, self.outputs, self.stream.cuda_stream)
assert ok, "Failed to execute model"
# self.session.context.execute_async_v3(self.stream.cuda_stream)
if use_perf:
torch.cuda.nvtx.range_pop()
# Process results
t_scale = delta_t[i].unsqueeze(0).to(input_type)
# Extract predictions
pred_cond = self.outputs["denoised"][:half_batch]
pred_uncond = self.outputs["denoised"][half_batch:]
# Apply classifier-free guidance with safeguards
guidance = pred_cond + (pred_cond - pred_uncond) * cfg_strength
# Calculate update for noise
noise_half = noise_half + guidance * t_scale
if use_perf:
torch.cuda.nvtx.range_pop()
return noise_half
def sample(
self,
text_pad_sequence: torch.Tensor,
ref_mel_batch: torch.Tensor,
ref_mel_len_batch: torch.Tensor,
estimated_reference_target_mel_len: List[int],
remove_input_padding: bool = False,
use_perf: bool = False,
):
if use_perf:
torch.cuda.nvtx.range_push("text embedding")
batch = text_pad_sequence.shape[0]
max_seq_len = ref_mel_batch.shape[1]
text_pad_sequence_drop = torch.cat(
(text_pad_sequence, torch.zeros((1, text_pad_sequence.shape[1]), dtype=torch.int32).to(self.device)), dim=0
)
text_embedding_drop_list = []
for i in range(batch + 1):
text_embedding_drop_list.append(self.text_embedding(text_pad_sequence_drop[i].unsqueeze(0).to(self.device)))
text_embedding_drop_condition = torch.cat(text_embedding_drop_list, dim=0)
text_embedding = text_embedding_drop_condition[:-1]
# text_embedding_drop B,T,C batch should be the same
text_embedding_drop = text_embedding_drop_condition[-1].unsqueeze(0).repeat(batch, 1, 1)
noise = torch.randn_like(ref_mel_batch).to(self.device)
rope_cos = self.rope_cos[:, :max_seq_len, :].float().repeat(batch, 1, 1)
rope_sin = self.rope_sin[:, :max_seq_len, :].float().repeat(batch, 1, 1)
cat_mel_text = torch.cat((ref_mel_batch, text_embedding), dim=-1)
cat_mel_text_drop = torch.cat(
(
torch.zeros((batch, max_seq_len, self.n_mel_channels), dtype=torch.float32).to(self.device),
text_embedding_drop,
),
dim=-1,
)
time_expand = self.time_expand.repeat(2 * batch, 1, 1).contiguous()
# Convert estimated_reference_target_mel_len to tensor
input_lengths = torch.tensor(estimated_reference_target_mel_len, dtype=torch.int32)
# combine above along the batch dimension
inputs = {
"noise": torch.cat((noise, noise), dim=0).contiguous(),
"cond": torch.cat((cat_mel_text, cat_mel_text_drop), dim=0).contiguous(),
"time_expand": time_expand,
"rope_cos": torch.cat((rope_cos, rope_cos), dim=0).contiguous(),
"rope_sin": torch.cat((rope_sin, rope_sin), dim=0).contiguous(),
"input_lengths": torch.cat((input_lengths, input_lengths), dim=0).contiguous(),
"delta_t": self.delta_t,
}
if use_perf and remove_input_padding:
torch.cuda.nvtx.range_push("remove input padding")
if remove_input_padding:
max_seq_len = inputs["cond"].shape[1]
inputs["noise"] = remove_tensor_padding(inputs["noise"], inputs["input_lengths"])
inputs["cond"] = remove_tensor_padding(inputs["cond"], inputs["input_lengths"])
# for time_expand, convert from B,D to B,T,D by repeat
inputs["time_expand"] = inputs["time_expand"].unsqueeze(1).repeat(1, max_seq_len, 1, 1)
inputs["time_expand"] = remove_tensor_padding(inputs["time_expand"], inputs["input_lengths"])
inputs["rope_cos"] = remove_tensor_padding(inputs["rope_cos"], inputs["input_lengths"])
inputs["rope_sin"] = remove_tensor_padding(inputs["rope_sin"], inputs["input_lengths"])
if use_perf and remove_input_padding:
torch.cuda.nvtx.range_pop()
for key in inputs:
inputs[key] = inputs[key].to(self.device)
if use_perf:
torch.cuda.nvtx.range_pop()
start_time = time.time()
denoised = self.forward(**inputs, use_perf=use_perf)
cost_time = time.time() - start_time
if use_perf and remove_input_padding:
torch.cuda.nvtx.range_push("remove input padding output")
if remove_input_padding:
denoised_list = []
start_idx = 0
for i in range(batch):
denoised_list.append(denoised[start_idx : start_idx + inputs["input_lengths"][i]])
start_idx += inputs["input_lengths"][i]
if use_perf and remove_input_padding:
torch.cuda.nvtx.range_pop()
return denoised_list, cost_time
return denoised, cost_time

View File

@@ -0,0 +1,275 @@
# Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of NVIDIA CORPORATION nor the names of its
# contributors may be used to endorse or promote products derived
# from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import json
import torch
from torch.nn.utils.rnn import pad_sequence
import torch.nn.functional as F
from torch.utils.dlpack import from_dlpack, to_dlpack
import torchaudio
import jieba
import triton_python_backend_utils as pb_utils
from pypinyin import Style, lazy_pinyin
import os
from f5_tts_trtllm import F5TTS
def get_tokenizer(vocab_file_path: str):
"""
tokenizer - "pinyin" do g2p for only chinese characters, need .txt vocab_file
- "char" for char-wise tokenizer, need .txt vocab_file
- "byte" for utf-8 tokenizer
- "custom" if you're directly passing in a path to the vocab.txt you want to use
vocab_size - if use "pinyin", all available pinyin types, common alphabets (also those with accent) and symbols
- if use "char", derived from unfiltered character & symbol counts of custom dataset
- if use "byte", set to 256 (unicode byte range)
"""
with open(vocab_file_path, "r", encoding="utf-8") as f:
vocab_char_map = {}
for i, char in enumerate(f):
vocab_char_map[char[:-1]] = i
vocab_size = len(vocab_char_map)
return vocab_char_map, vocab_size
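# e.g. a vocab.txt containing the two lines "a" and "b" yields
# vocab_char_map = {"a": 0, "b": 1} and vocab_size = 2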
def convert_char_to_pinyin(reference_target_texts_list, polyphone=True):
final_reference_target_texts_list = []
custom_trans = str.maketrans(
{";": ",", "“": '"', "”": '"', "‘": "'", "’": "'"}
) # add custom trans here, to address oov
def is_chinese(c):
return "\u3100" <= c <= "\u9fff" # common chinese characters
for text in reference_target_texts_list:
char_list = []
text = text.translate(custom_trans)
for seg in jieba.cut(text):
seg_byte_len = len(bytes(seg, "UTF-8"))
if seg_byte_len == len(seg): # if pure alphabets and symbols
if char_list and seg_byte_len > 1 and char_list[-1] not in " :'\"":
char_list.append(" ")
char_list.extend(seg)
elif polyphone and seg_byte_len == 3 * len(seg): # if pure east asian characters
seg_ = lazy_pinyin(seg, style=Style.TONE3, tone_sandhi=True)
for i, c in enumerate(seg):
if is_chinese(c):
char_list.append(" ")
char_list.append(seg_[i])
else: # if mixed characters, alphabets and symbols
for c in seg:
if ord(c) < 256:
char_list.extend(c)
elif is_chinese(c):
char_list.append(" ")
char_list.extend(lazy_pinyin(c, style=Style.TONE3, tone_sandhi=True))
else:
char_list.append(c)
final_reference_target_texts_list.append(char_list)
return final_reference_target_texts_list
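# e.g. convert_char_to_pinyin(["你好 world"]) yields roughly
# [[" ", "ni2", " ", "hao3", " ", "w", "o", "r", "l", "d"]]:
# Chinese characters become space-prefixed pinyin tokens, ASCII passes through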
def list_str_to_idx(
text: list[str] | list[list[str]],
vocab_char_map: dict[str, int], # {char: idx}
padding_value=-1,
): # noqa: F722
list_idx_tensors = [torch.tensor([vocab_char_map.get(c, 0) for c in t]) for t in text] # pinyin or char style
return list_idx_tensors
class TritonPythonModel:
def initialize(self, args):
self.use_perf = True
self.device = torch.device("cuda")
self.target_audio_sample_rate = 24000
self.target_rms = 0.15 # target rms for audio
self.n_fft = 1024
self.win_length = 1024
self.hop_length = 256
self.n_mel_channels = 100
self.max_mel_len = 3000
self.head_dim = 64
parameters = json.loads(args["model_config"])["parameters"]
for key, value in parameters.items():
parameters[key] = value["string_value"]
self.vocab_char_map, self.vocab_size = get_tokenizer(parameters["vocab_file"])
self.reference_sample_rate = int(parameters["reference_audio_sample_rate"])
self.resampler = torchaudio.transforms.Resample(self.reference_sample_rate, self.target_audio_sample_rate)
self.tllm_model_dir = parameters["tllm_model_dir"]
config_file = os.path.join(self.tllm_model_dir, "config.json")
with open(config_file) as f:
config = json.load(f)
self.model = F5TTS(
config,
debug_mode=False,
tllm_model_dir=self.tllm_model_dir,
model_path=parameters["model_path"],
vocab_size=self.vocab_size,
)
self.vocoder = parameters["vocoder"]
assert self.vocoder in ["vocos", "bigvgan"]
if self.vocoder == "vocos":
self.mel_stft = torchaudio.transforms.MelSpectrogram(
sample_rate=self.target_audio_sample_rate,
n_fft=self.n_fft,
win_length=self.win_length,
hop_length=self.hop_length,
n_mels=self.n_mel_channels,
power=1,
center=True,
normalized=False,
norm=None,
).to(self.device)
self.compute_mel_fn = self.get_vocos_mel_spectrogram
elif self.vocoder == "bigvgan":
self.compute_mel_fn = self.get_bigvgan_mel_spectrogram
def get_vocos_mel_spectrogram(self, waveform):
mel = self.mel_stft(waveform)
mel = mel.clamp(min=1e-5).log()
return mel.transpose(1, 2)
def forward_vocoder(self, mel):
mel = mel.to(torch.float32).contiguous().cpu()
input_tensor_0 = pb_utils.Tensor.from_dlpack("mel", to_dlpack(mel))
inference_request = pb_utils.InferenceRequest(
model_name="vocoder", requested_output_names=["waveform"], inputs=[input_tensor_0]
)
inference_response = inference_request.exec()
if inference_response.has_error():
raise pb_utils.TritonModelException(inference_response.error().message())
else:
waveform = pb_utils.get_output_tensor_by_name(inference_response, "waveform")
waveform = torch.utils.dlpack.from_dlpack(waveform.to_dlpack()).cpu()
return waveform
def execute(self, requests):
(
reference_text_list,
target_text_list,
reference_target_texts_list,
estimated_reference_target_mel_len,
reference_mel_len,
) = [], [], [], [], []
mel_features_list = []
if self.use_perf:
torch.cuda.nvtx.range_push("preprocess")
for request in requests:
wav_tensor = pb_utils.get_input_tensor_by_name(request, "reference_wav")
wav_lens = pb_utils.get_input_tensor_by_name(request, "reference_wav_len")
reference_text = pb_utils.get_input_tensor_by_name(request, "reference_text").as_numpy()
reference_text = reference_text[0][0].decode("utf-8")
reference_text_list.append(reference_text)
target_text = pb_utils.get_input_tensor_by_name(request, "target_text").as_numpy()
target_text = target_text[0][0].decode("utf-8")
target_text_list.append(target_text)
text = reference_text + target_text
reference_target_texts_list.append(text)
wav = from_dlpack(wav_tensor.to_dlpack())
wav_len = from_dlpack(wav_lens.to_dlpack())
wav_len = wav_len.squeeze()
assert wav.shape[0] == 1, "Only support batch size 1 for now."
wav = wav[:, :wav_len]
ref_rms = torch.sqrt(torch.mean(torch.square(wav)))
if ref_rms < self.target_rms:
wav = wav * self.target_rms / ref_rms
if self.reference_sample_rate != self.target_audio_sample_rate:
wav = self.resampler(wav)
wav = wav.to(self.device)
if self.use_perf:
torch.cuda.nvtx.range_push("compute_mel")
mel_features = self.compute_mel_fn(wav)
if self.use_perf:
torch.cuda.nvtx.range_pop()
mel_features_list.append(mel_features)
reference_mel_len.append(mel_features.shape[1])
estimated_reference_target_mel_len.append(
int(mel_features.shape[1] * (1 + len(target_text) / len(reference_text)))
)
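# e.g. a 2 s reference (~188 mel frames at 24 kHz with hop 256) whose target text is
# twice the reference text length is budgeted 188 * (1 + 2) = 564 total frames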
max_seq_len = min(max(estimated_reference_target_mel_len), self.max_mel_len)
batch = len(requests)
mel_features = torch.zeros((batch, max_seq_len, self.n_mel_channels), dtype=torch.float16).to(self.device)
for i, mel in enumerate(mel_features_list):
mel_features[i, : mel.shape[1], :] = mel
reference_mel_len_tensor = torch.LongTensor(reference_mel_len).to(self.device)
pinyin_list = convert_char_to_pinyin(reference_target_texts_list, polyphone=True)
text_pad_sequence = list_str_to_idx(pinyin_list, self.vocab_char_map)
for i, item in enumerate(text_pad_sequence):
text_pad_sequence[i] = F.pad(
item, (0, estimated_reference_target_mel_len[i] - len(item)), mode="constant", value=-1
)
text_pad_sequence[i] += 1 # WAR: 0 is reserved for padding token, hard coding in F5-TTS
text_pad_sequence = pad_sequence(text_pad_sequence, padding_value=-1, batch_first=True).to(self.device)
text_pad_sequence = F.pad(
text_pad_sequence, (0, max_seq_len - text_pad_sequence.shape[1]), mode="constant", value=-1
)
if self.use_perf:
torch.cuda.nvtx.range_pop()
denoised, cost_time = self.model.sample(
text_pad_sequence,
mel_features,
reference_mel_len_tensor,
estimated_reference_target_mel_len,
remove_input_padding=False,
use_perf=self.use_perf,
)
if self.use_perf:
torch.cuda.nvtx.range_push("vocoder")
responses = []
for i in range(batch):
ref_mel_len = reference_mel_len[i]
estimated_mel_len = estimated_reference_target_mel_len[i]
denoised_one_item = denoised[i, ref_mel_len:estimated_mel_len, :].unsqueeze(0).transpose(1, 2)
audio = self.forward_vocoder(denoised_one_item)
rms = torch.sqrt(torch.mean(torch.square(audio)))
if rms < self.target_rms:
audio = audio * self.target_rms / rms
audio = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(audio))
inference_response = pb_utils.InferenceResponse(output_tensors=[audio])
responses.append(inference_response)
if self.use_perf:
torch.cuda.nvtx.range_pop()
return responses
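
A standalone illustration (with made-up values) of the duration heuristic used in `execute` above: the target mel length is estimated by scaling the reference mel length by the ratio of total text length to reference text length.

```python
# Illustration only; mirrors the estimate in execute() with made-up values.
ref_mel_len = 512  # frames in the reference mel
reference_text = "Some call me nature."
target_text = "Others call me mother nature."
estimated_len = int(ref_mel_len * (1 + len(target_text) / len(reference_text)))
print(estimated_len)  # reference frames plus estimated generated frames
```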

View File

@@ -0,0 +1,81 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
name: "f5_tts"
backend: "python"
max_batch_size: 4
dynamic_batching {
max_queue_delay_microseconds: 1000
}
parameters [
{
key: "vocab_file"
value: { string_value: "${vocab}"}
},
{
key: "model_path",
value: {string_value:"${model}"}
},
{
key: "tllm_model_dir",
value: {string_value:"${trtllm}"}
},
{
key: "reference_audio_sample_rate",
value: {string_value:"16000"}
},
{
key: "vocoder",
value: {string_value:"${vocoder}"}
}
]
input [
{
name: "reference_wav"
data_type: TYPE_FP32
dims: [-1]
optional: True
},
{
name: "reference_wav_len"
data_type: TYPE_INT32
dims: [1]
optional: True
},
{
name: "reference_text"
data_type: TYPE_STRING
dims: [1]
},
{
name: "target_text"
data_type: TYPE_STRING
dims: [1]
}
]
output [
{
name: "waveform"
data_type: TYPE_FP32
dims: [ -1 ]
}
]
instance_group [
{
count: 1
kind: KIND_GPU
}
]
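
For reference, a minimal Python client sketch matching the inputs and outputs declared above. It assumes a Triton server reachable at localhost:8000 and a 16 kHz mono reference wav; the file name and texts are illustrative. The repo's client_http.py and client_grpc.py are the authoritative clients; this only shows the tensor shapes implied by the config.

```python
import numpy as np
import soundfile as sf
import tritonclient.http as httpclient

wav, sr = sf.read("ref.wav", dtype="float32")  # 16 kHz mono assumed
client = httpclient.InferenceServerClient(url="localhost:8000")

inputs = [
    httpclient.InferInput("reference_wav", [1, len(wav)], "FP32"),
    httpclient.InferInput("reference_wav_len", [1, 1], "INT32"),
    httpclient.InferInput("reference_text", [1, 1], "BYTES"),
    httpclient.InferInput("target_text", [1, 1], "BYTES"),
]
inputs[0].set_data_from_numpy(wav[None, :])
inputs[1].set_data_from_numpy(np.array([[len(wav)]], dtype=np.int32))
inputs[2].set_data_from_numpy(np.array([[b"reference transcript"]], dtype=object))
inputs[3].set_data_from_numpy(np.array([[b"text to synthesize"]], dtype=object))

result = client.infer("f5_tts", inputs)
waveform = result.as_numpy("waveform")  # float32 PCM (24 kHz for the Vocos vocoder)
```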

View File

@@ -0,0 +1,32 @@
name: "vocoder"
backend: "tensorrt"
default_model_filename: "vocoder.plan"
max_batch_size: 4
input [
{
name: "mel"
data_type: TYPE_FP32
dims: [ 100, -1 ]
}
]
output [
{
name: "waveform"
data_type: TYPE_FP32
dims: [ -1 ]
}
]
dynamic_batching {
preferred_batch_size: [1, 2, 4]
max_queue_delay_microseconds: 1
}
instance_group [
{
count: 1
kind: KIND_GPU
}
]
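
The vocoder.plan engine referenced above is built from the exported ONNX vocoder; a trtexec sketch under assumed shape bounds (batch up to max_batch_size=4, 100 mel bins, variable frame count). The repo's scripts/export_vocos_trt.sh is the authoritative recipe.

```bash
trtexec --onnx=vocos_vocoder.onnx --saveEngine=vocoder.plan \
    --minShapes=mel:1x100x1 --optShapes=mel:4x100x500 --maxShapes=mel:4x100x3000
```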

View File

@@ -0,0 +1,198 @@
# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .baichuan.model import BaichuanForCausalLM
from .bert.model import (
BertForQuestionAnswering,
BertForSequenceClassification,
BertModel,
RobertaForQuestionAnswering,
RobertaForSequenceClassification,
RobertaModel,
)
from .bloom.model import BloomForCausalLM, BloomModel
from .chatglm.config import ChatGLMConfig
from .chatglm.model import ChatGLMForCausalLM, ChatGLMModel
from .cogvlm.config import CogVLMConfig
from .cogvlm.model import CogVLMForCausalLM
from .commandr.model import CohereForCausalLM
from .dbrx.config import DbrxConfig
from .dbrx.model import DbrxForCausalLM
from .deepseek_v1.model import DeepseekForCausalLM
from .deepseek_v2.model import DeepseekV2ForCausalLM
from .dit.model import DiT
from .eagle.model import EagleForCausalLM
from .enc_dec.model import DecoderModel, EncoderModel, WhisperEncoder
from .falcon.config import FalconConfig
from .falcon.model import FalconForCausalLM, FalconModel
from .gemma.config import GEMMA2_ARCHITECTURE, GEMMA_ARCHITECTURE, GemmaConfig
from .gemma.model import GemmaForCausalLM
from .gpt.config import GPTConfig
from .gpt.model import GPTForCausalLM, GPTModel
from .gptj.config import GPTJConfig
from .gptj.model import GPTJForCausalLM, GPTJModel
from .gptneox.model import GPTNeoXForCausalLM, GPTNeoXModel
from .grok.model import GrokForCausalLM
from .llama.config import LLaMAConfig
from .llama.model import LLaMAForCausalLM, LLaMAModel
from .mamba.model import MambaForCausalLM
from .medusa.config import MedusaConfig
from .medusa.model import MedusaForCausalLm
from .mllama.model import MLLaMAModel
from .modeling_utils import PretrainedConfig, PretrainedModel, SpeculativeDecodingMode
from .mpt.model import MPTForCausalLM, MPTModel
from .nemotron_nas.model import DeciLMForCausalLM
from .opt.model import OPTForCausalLM, OPTModel
from .phi3.model import Phi3ForCausalLM, Phi3Model
from .phi.model import PhiForCausalLM, PhiModel
from .qwen.model import QWenForCausalLM
from .recurrentgemma.model import RecurrentGemmaForCausalLM
from .redrafter.model import ReDrafterForCausalLM
from .f5tts.model import F5TTS
__all__ = [
"BertModel",
"BertForQuestionAnswering",
"BertForSequenceClassification",
"RobertaModel",
"RobertaForQuestionAnswering",
"RobertaForSequenceClassification",
"BloomModel",
"BloomForCausalLM",
"DiT",
"DeepseekForCausalLM",
"FalconConfig",
"DeepseekV2ForCausalLM",
"FalconForCausalLM",
"FalconModel",
"GPTConfig",
"GPTModel",
"GPTForCausalLM",
"OPTForCausalLM",
"OPTModel",
"LLaMAConfig",
"LLaMAForCausalLM",
"LLaMAModel",
"MedusaConfig",
"MedusaForCausalLm",
"ReDrafterForCausalLM",
"GPTJConfig",
"GPTJModel",
"GPTJForCausalLM",
"GPTNeoXModel",
"GPTNeoXForCausalLM",
"PhiModel",
"PhiConfig",
"Phi3Model",
"Phi3Config",
"PhiForCausalLM",
"Phi3ForCausalLM",
"ChatGLMConfig",
"ChatGLMForCausalLM",
"ChatGLMModel",
"BaichuanForCausalLM",
"QWenConfigQWenForCausalLM",
"QWenModel",
"EncoderModel",
"DecoderModel",
"PretrainedConfig",
"PretrainedModel",
"WhisperEncoder",
"MambaForCausalLM",
"MambaConfig",
"MPTForCausalLM",
"MPTModel",
"SkyworkForCausalLM",
"GemmaConfig",
"GemmaForCausalLM",
"DbrxConfig",
"DbrxForCausalLM",
"RecurrentGemmaForCausalLM",
"CogVLMConfig",
"CogVLMForCausalLM",
"EagleForCausalLM",
"SpeculativeDecodingMode",
"CohereForCausalLM",
"MLLaMAModel",
"F5TTS",
]
MODEL_MAP = {
"GPT2LMHeadModel": GPTForCausalLM,
"GPT2LMHeadCustomModel": GPTForCausalLM,
"GPTBigCodeForCausalLM": GPTForCausalLM,
"Starcoder2ForCausalLM": GPTForCausalLM,
"FuyuForCausalLM": GPTForCausalLM,
"Kosmos2ForConditionalGeneration": GPTForCausalLM,
"JAISLMHeadModel": GPTForCausalLM,
"GPTForCausalLM": GPTForCausalLM,
"NemotronForCausalLM": GPTForCausalLM,
"OPTForCausalLM": OPTForCausalLM,
"BloomForCausalLM": BloomForCausalLM,
"RWForCausalLM": FalconForCausalLM,
"FalconForCausalLM": FalconForCausalLM,
"PhiForCausalLM": PhiForCausalLM,
"Phi3ForCausalLM": Phi3ForCausalLM,
"Phi3VForCausalLM": Phi3ForCausalLM,
"Phi3SmallForCausalLM": Phi3ForCausalLM,
"PhiMoEForCausalLM": Phi3ForCausalLM,
"MambaForCausalLM": MambaForCausalLM,
"GPTNeoXForCausalLM": GPTNeoXForCausalLM,
"GPTJForCausalLM": GPTJForCausalLM,
"MPTForCausalLM": MPTForCausalLM,
"GLMModel": ChatGLMForCausalLM,
"ChatGLMModel": ChatGLMForCausalLM,
"ChatGLMForCausalLM": ChatGLMForCausalLM,
"LlamaForCausalLM": LLaMAForCausalLM,
"ExaoneForCausalLM": LLaMAForCausalLM,
"MistralForCausalLM": LLaMAForCausalLM,
"MixtralForCausalLM": LLaMAForCausalLM,
"ArcticForCausalLM": LLaMAForCausalLM,
"Grok1ModelForCausalLM": GrokForCausalLM,
"InternLMForCausalLM": LLaMAForCausalLM,
"InternLM2ForCausalLM": LLaMAForCausalLM,
"MedusaForCausalLM": MedusaForCausalLm,
"ReDrafterForCausalLM": ReDrafterForCausalLM,
"BaichuanForCausalLM": BaichuanForCausalLM,
"BaiChuanForCausalLM": BaichuanForCausalLM,
"SkyworkForCausalLM": LLaMAForCausalLM,
GEMMA_ARCHITECTURE: GemmaForCausalLM,
GEMMA2_ARCHITECTURE: GemmaForCausalLM,
"QWenLMHeadModel": QWenForCausalLM,
"QWenForCausalLM": QWenForCausalLM,
"Qwen2ForCausalLM": QWenForCausalLM,
"Qwen2MoeForCausalLM": QWenForCausalLM,
"Qwen2ForSequenceClassification": QWenForCausalLM,
"Qwen2VLForConditionalGeneration": QWenForCausalLM,
"WhisperEncoder": WhisperEncoder,
"EncoderModel": EncoderModel,
"DecoderModel": DecoderModel,
"DbrxForCausalLM": DbrxForCausalLM,
"RecurrentGemmaForCausalLM": RecurrentGemmaForCausalLM,
"CogVLMForCausalLM": CogVLMForCausalLM,
"DiT": DiT,
"DeepseekForCausalLM": DeepseekForCausalLM,
"DeciLMForCausalLM": DeciLMForCausalLM,
"DeepseekV2ForCausalLM": DeepseekV2ForCausalLM,
"EagleForCausalLM": EagleForCausalLM,
"CohereForCausalLM": CohereForCausalLM,
"MllamaForConditionalGeneration": MLLaMAModel,
"BertForQuestionAnswering": BertForQuestionAnswering,
"BertForSequenceClassification": BertForSequenceClassification,
"BertModel": BertModel,
"RobertaModel": RobertaModel,
"RobertaForQuestionAnswering": RobertaForQuestionAnswering,
"RobertaForSequenceClassification": RobertaForSequenceClassification,
"F5TTS": F5TTS,
}
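
MODEL_MAP is how TensorRT-LLM resolves a model class from a converted checkpoint; a minimal sketch of the dispatch, assuming a checkpoint directory produced by the convert_checkpoint.py script below:

```python
import json

# The converter writes "architecture": "F5TTS" into config.json,
# which MODEL_MAP resolves to the F5TTS class at engine-build time.
with open("trtllm_ckpt/config.json") as f:
    architecture = json.load(f)["architecture"]
model_cls = MODEL_MAP[architecture]  # -> F5TTS
```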

View File

@@ -0,0 +1,225 @@
from __future__ import annotations
import sys
import os
import tensorrt as trt
from collections import OrderedDict
from ..._utils import str_dtype_to_trt
from ...plugin import current_all_reduce_helper
from ..modeling_utils import PretrainedConfig, PretrainedModel
from ...functional import Tensor, concat
from ...module import Module, ModuleList
from tensorrt_llm._common import default_net
from ...layers import Linear
from .modules import (
TimestepEmbedding,
ConvPositionEmbedding,
DiTBlock,
AdaLayerNormZero_Final,
)
current_file_path = os.path.abspath(__file__)
parent_dir = os.path.dirname(current_file_path)
sys.path.append(parent_dir)
class InputEmbedding(Module):
def __init__(self, mel_dim, text_dim, out_dim):
super().__init__()
self.proj = Linear(mel_dim * 2 + text_dim, out_dim)
self.conv_pos_embed = ConvPositionEmbedding(dim=out_dim)
def forward(self, x, cond):
x = self.proj(concat([x, cond], dim=-1))
return self.conv_pos_embed(x) + x
class F5TTS(PretrainedModel):
def __init__(self, config: PretrainedConfig):
super().__init__(config)
self.dtype = str_dtype_to_trt(config.dtype)
self.time_embed = TimestepEmbedding(config.hidden_size)
self.input_embed = InputEmbedding(config.mel_dim, config.text_dim, config.hidden_size)
self.dim = config.hidden_size
self.depth = config.num_hidden_layers
self.transformer_blocks = ModuleList(
[
DiTBlock(
dim=self.dim,
heads=config.num_attention_heads,
dim_head=config.dim_head,
ff_mult=config.ff_mult,
dropout=config.dropout,
)
for _ in range(self.depth)
]
)
self.norm_out = AdaLayerNormZero_Final(config.hidden_size) # final modulation
self.proj_out = Linear(config.hidden_size, config.mel_dim)
def forward(
self,
noise, # noised input audio
cond, # masked cond audio
time, # time step
rope_cos,
rope_sin,
input_lengths,
scale=1.0,
):
t = self.time_embed(time)
x = self.input_embed(noise, cond)
for block in self.transformer_blocks:
x = block(x, t, rope_cos=rope_cos, rope_sin=rope_sin, input_lengths=input_lengths, scale=scale)
denoise = self.proj_out(self.norm_out(x, t))
denoise.mark_output("denoised", self.dtype)
return denoise
def prepare_inputs(self, **kwargs):
max_batch_size = kwargs["max_batch_size"]
batch_size_range = [2, 2, max_batch_size]
mel_size = 100
max_seq_len = 3000
num_frames_range = [200, 2 * max_seq_len, max_seq_len * max_batch_size]
hidden_size = 512
concat_feature_dim = mel_size + hidden_size
freq_embed_dim = 256
head_dim = 64
mapping = self.config.mapping
if mapping.tp_size > 1:
current_all_reduce_helper().set_workspace_tensor(mapping, 1)
if default_net().plugin_config.remove_input_padding:
noise = Tensor(
name="noise",
dtype=self.dtype,
shape=[-1, mel_size],
dim_range=OrderedDict(
[
("num_frames", [num_frames_range]),
("n_mels", [mel_size]),
]
),
)
cond = Tensor(
name="cond",
dtype=self.dtype,
shape=[-1, concat_feature_dim],
dim_range=OrderedDict(
[
("num_frames", [num_frames_range]),
("embeded_length", [concat_feature_dim]),
]
),
)
time = Tensor(
name="time",
dtype=self.dtype,
shape=[-1, freq_embed_dim],
dim_range=OrderedDict(
[
("num_frames", [num_frames_range]),
("freq_dim", [freq_embed_dim]),
]
),
)
rope_cos = Tensor(
name="rope_cos",
dtype=self.dtype,
shape=[-1, head_dim],
dim_range=OrderedDict(
[
("num_frames", [num_frames_range]),
("head_dim", [head_dim]),
]
),
)
rope_sin = Tensor(
name="rope_sin",
dtype=self.dtype,
shape=[-1, head_dim],
dim_range=OrderedDict(
[
("num_frames", [num_frames_range]),
("head_dim", [head_dim]),
]
),
)
else:
noise = Tensor(
name="noise",
dtype=self.dtype,
shape=[-1, -1, mel_size],
dim_range=OrderedDict(
[
("batch_size", [batch_size_range]),
("max_duratuion", [[100, max_seq_len // 2, max_seq_len]]),
("n_mels", [mel_size]),
]
),
)
cond = Tensor(
name="cond",
dtype=self.dtype,
shape=[-1, -1, concat_feature_dim],
dim_range=OrderedDict(
[
("batch_size", [batch_size_range]),
("max_duratuion", [[100, max_seq_len // 2, max_seq_len]]),
("embeded_length", [concat_feature_dim]),
]
),
)
time = Tensor(
name="time",
dtype=self.dtype,
shape=[-1, freq_embed_dim],
dim_range=OrderedDict(
[
("batch_size", [batch_size_range]),
("freq_dim", [freq_embed_dim]),
]
),
)
rope_cos = Tensor(
name="rope_cos",
dtype=self.dtype,
shape=[-1, -1, head_dim],
dim_range=OrderedDict(
[
("batch_size", [batch_size_range]),
("max_duratuion", [[100, max_seq_len // 2, max_seq_len]]),
("head_dim", [head_dim]),
]
),
)
rope_sin = Tensor(
name="rope_sin",
dtype=self.dtype,
shape=[-1, -1, head_dim],
dim_range=OrderedDict(
[
("batch_size", [batch_size_range]),
("max_duratuion", [[100, max_seq_len // 2, max_seq_len]]),
("head_dim", [head_dim]),
]
),
)
input_lengths = Tensor(
name="input_lengths",
dtype=trt.int32,
shape=[-1],
dim_range=OrderedDict([("batch_size", [batch_size_range])]),
)
return {
"noise": noise,
"cond": cond,
"time": time,
"rope_cos": rope_cos,
"rope_sin": rope_sin,
"input_lengths": input_lengths,
}
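
Each dim_range triple above follows TensorRT's (min, opt, max) optimization-profile convention; a small illustration of the values baked in when building with max_batch_size=8 (matching the trtllm-build invocation in run.sh below):

```python
# Illustration only: the (min, opt, max) profile values used above.
max_batch_size = 8
max_seq_len = 3000
batch_size_range = [2, 2, max_batch_size]  # [2, 2, 8]
num_frames_range = [200, 2 * max_seq_len, max_seq_len * max_batch_size]  # [200, 6000, 24000]
```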

View File

@@ -0,0 +1,410 @@
from __future__ import annotations
import math
from typing import Optional
import torch
import torch.nn.functional as F
import numpy as np
from tensorrt_llm._common import default_net
from ..._utils import trt_dtype_to_np, str_dtype_to_trt
from ...functional import (
Tensor,
chunk,
concat,
constant,
expand,
shape,
silu,
slice,
permute,
expand_mask,
expand_dims_like,
unsqueeze,
matmul,
softmax,
squeeze,
cast,
gelu,
)
from ...functional import expand_dims, view, bert_attention
from ...layers import LayerNorm, Linear, Conv1d, Mish, RowLinear, ColumnLinear
from ...module import Module
class FeedForward(Module):
def __init__(self, dim, dim_out=None, mult=4, dropout=0.0):
super().__init__()
inner_dim = int(dim * mult)
dim_out = dim_out if dim_out is not None else dim
self.project_in = Linear(dim, inner_dim)
self.ff = Linear(inner_dim, dim_out)
def forward(self, x):
return self.ff(gelu(self.project_in(x)))
class AdaLayerNormZero(Module):
def __init__(self, dim):
super().__init__()
self.linear = Linear(dim, dim * 6)
self.norm = LayerNorm(dim, elementwise_affine=False, eps=1e-6)
def forward(self, x, emb=None):
emb = self.linear(silu(emb))
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = chunk(emb, 6, dim=1)
x = self.norm(x)
ones = constant(np.ones(1, dtype=np.float32)).cast(x.dtype)
if default_net().plugin_config.remove_input_padding:
x = x * (ones + scale_msa) + shift_msa
else:
x = x * (ones + unsqueeze(scale_msa, 1)) + unsqueeze(shift_msa, 1)
return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
class AdaLayerNormZero_Final(Module):
def __init__(self, dim):
super().__init__()
self.linear = Linear(dim, dim * 2)
self.norm = LayerNorm(dim, elementwise_affine=False, eps=1e-6)
def forward(self, x, emb):
emb = self.linear(silu(emb))
scale, shift = chunk(emb, 2, dim=1)
ones = constant(np.ones(1, dtype=np.float32)).cast(x.dtype)
if default_net().plugin_config.remove_input_padding:
x = self.norm(x) * (ones + scale) + shift
else:
x = self.norm(x) * unsqueeze((ones + scale), 1)
x = x + unsqueeze(shift, 1)
return x
class ConvPositionEmbedding(Module):
def __init__(self, dim, kernel_size=31, groups=16):
super().__init__()
assert kernel_size % 2 != 0
self.conv1d1 = Conv1d(dim, dim, kernel_size, groups=groups, padding=kernel_size // 2)
self.conv1d2 = Conv1d(dim, dim, kernel_size, groups=groups, padding=kernel_size // 2)
self.mish = Mish()
def forward(self, x, mask=None): # noqa: F722
if default_net().plugin_config.remove_input_padding:
x = unsqueeze(x, 0)
x = permute(x, [0, 2, 1])
x = self.mish(self.conv1d2(self.mish(self.conv1d1(x))))
out = permute(x, [0, 2, 1])
if default_net().plugin_config.remove_input_padding:
out = squeeze(out, 0)
return out
class Attention(Module):
def __init__(
self,
processor: AttnProcessor,
dim: int,
heads: int = 16,
dim_head: int = 64,
dropout: float = 0.0,
context_dim: Optional[int] = None, # if not None -> joint attention
context_pre_only=None,
):
super().__init__()
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError("Attention equires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
self.processor = processor
self.dim = dim # hidden_size
self.heads = heads
self.inner_dim = dim_head * heads
self.dropout = dropout
self.attention_head_size = dim_head
self.context_dim = context_dim
self.context_pre_only = context_pre_only
self.tp_size = 1
self.num_attention_heads = heads // self.tp_size
self.num_attention_kv_heads = heads // self.tp_size # 8
self.dtype = str_dtype_to_trt("float32")
self.attention_hidden_size = self.attention_head_size * self.num_attention_heads
self.to_q = ColumnLinear(
dim,
self.tp_size * self.num_attention_heads * self.attention_head_size,
bias=True,
dtype=self.dtype,
tp_group=None,
tp_size=self.tp_size,
)
self.to_k = ColumnLinear(
dim,
self.tp_size * self.num_attention_heads * self.attention_head_size,
bias=True,
dtype=self.dtype,
tp_group=None,
tp_size=self.tp_size,
)
self.to_v = ColumnLinear(
dim,
self.tp_size * self.num_attention_heads * self.attention_head_size,
bias=True,
dtype=self.dtype,
tp_group=None,
tp_size=self.tp_size,
)
if self.context_dim is not None:
self.to_k_c = Linear(context_dim, self.inner_dim)
self.to_v_c = Linear(context_dim, self.inner_dim)
if self.context_pre_only is not None:
self.to_q_c = Linear(context_dim, self.inner_dim)
self.to_out = RowLinear(
self.tp_size * self.num_attention_heads * self.attention_head_size,
dim,
bias=True,
dtype=self.dtype,
tp_group=None,
tp_size=self.tp_size,
)
if self.context_pre_only is not None and not self.context_pre_only:
self.to_out_c = Linear(self.inner_dim, dim)
def forward(
self,
x, # noised input x
rope_cos,
rope_sin,
input_lengths,
c=None, # context c
scale=1.0,
rope=None,
c_rope=None, # rotary position embedding for c
) -> torch.Tensor:
if c is not None:
return self.processor(self, x, c=c, input_lengths=input_lengths, scale=scale, rope=rope, c_rope=c_rope)
else:
return self.processor(
self, x, rope_cos=rope_cos, rope_sin=rope_sin, input_lengths=input_lengths, scale=scale
)
def rotate_every_two_3dim(tensor: Tensor) -> Tensor:
shape_tensor = concat(
[shape(tensor, i) / 2 if i == (tensor.ndim() - 1) else shape(tensor, i) for i in range(tensor.ndim())]
)
if default_net().plugin_config.remove_input_padding:
assert tensor.ndim() == 2
x1 = slice(tensor, [0, 0], shape_tensor, [1, 2])
x2 = slice(tensor, [0, 1], shape_tensor, [1, 2])
x1 = expand_dims(x1, 2)
x2 = expand_dims(x2, 2)
zero = constant(np.ascontiguousarray(np.zeros([1], dtype=trt_dtype_to_np(tensor.dtype))))
x2 = zero - x2
x = concat([x2, x1], 2)
out = view(x, concat([shape(x, 0), shape(x, 1) * 2]))
else:
assert tensor.ndim() == 3
x1 = slice(tensor, [0, 0, 0], shape_tensor, [1, 1, 2])
x2 = slice(tensor, [0, 0, 1], shape_tensor, [1, 1, 2])
x1 = expand_dims(x1, 3)
x2 = expand_dims(x2, 3)
zero = constant(np.ascontiguousarray(np.zeros([1], dtype=trt_dtype_to_np(tensor.dtype))))
x2 = zero - x2
x = concat([x2, x1], 3)
out = view(x, concat([shape(x, 0), shape(x, 1), shape(x, 2) * 2]))
return out
def apply_rotary_pos_emb_3dim(x, rope_cos, rope_sin):
if default_net().plugin_config.remove_input_padding:
rot_dim = shape(rope_cos, -1) # 64
new_t_shape = concat([shape(x, 0), rot_dim]) # (-1, 64)
x_ = slice(x, [0, 0], new_t_shape, [1, 1])
end_dim = shape(x, -1) - shape(rope_cos, -1)
new_t_unrotated_shape = concat([shape(x, 0), end_dim]) # (-1, 960)
x_unrotated = slice(x, concat([0, rot_dim]), new_t_unrotated_shape, [1, 1])
out = concat([x_ * rope_cos + rotate_every_two_3dim(x_) * rope_sin, x_unrotated], dim=-1)
else:
rot_dim = shape(rope_cos, 2) # 64
new_t_shape = concat([shape(x, 0), shape(x, 1), rot_dim]) # (2, -1, 64)
x_ = slice(x, [0, 0, 0], new_t_shape, [1, 1, 1])
end_dim = shape(x, 2) - shape(rope_cos, 2)
new_t_unrotated_shape = concat([shape(x, 0), shape(x, 1), end_dim]) # (2, -1, 960)
x_unrotated = slice(x, concat([0, 0, rot_dim]), new_t_unrotated_shape, [1, 1, 1])
out = concat([x_ * rope_cos + rotate_every_two_3dim(x_) * rope_sin, x_unrotated], dim=-1)
return out
class AttnProcessor:
def __init__(self):
pass
def __call__(
self,
attn,
x, # noised input x
rope_cos,
rope_sin,
input_lengths,
scale=1.0,
rope=None,
) -> torch.FloatTensor:
query = attn.to_q(x)
key = attn.to_k(x)
value = attn.to_v(x)
# k,v,q all (2,1226,1024)
query = apply_rotary_pos_emb_3dim(query, rope_cos, rope_sin)
key = apply_rotary_pos_emb_3dim(key, rope_cos, rope_sin)
# attention
inner_dim = key.shape[-1]
norm_factor = math.sqrt(attn.attention_head_size)
q_scaling = 1.0 / norm_factor
mask = None
if not default_net().plugin_config.remove_input_padding:
N = shape(x, 1)
B = shape(x, 0)
seq_len_2d = concat([1, N])
max_position_embeddings = 4096
# create position ids
position_ids_buffer = constant(np.expand_dims(np.arange(max_position_embeddings).astype(np.int32), 0))
tmp_position_ids = slice(position_ids_buffer, starts=[0, 0], sizes=seq_len_2d)
tmp_position_ids = expand(tmp_position_ids, concat([B, N])) # BxL
tmp_input_lengths = unsqueeze(input_lengths, 1) # Bx1
tmp_input_lengths = expand(tmp_input_lengths, concat([B, N])) # BxL
mask = tmp_position_ids < tmp_input_lengths # BxL
mask = mask.cast("int32")
if default_net().plugin_config.bert_attention_plugin:
qkv = concat([query, key, value], dim=-1)
# TRT plugin mode
assert input_lengths is not None
if default_net().plugin_config.remove_input_padding:
qkv = qkv.view(concat([-1, 3 * inner_dim]))
max_input_length = constant(
np.zeros(
[
2048,
],
dtype=np.int32,
)
)
else:
max_input_length = None
context = bert_attention(
qkv,
input_lengths,
attn.num_attention_heads,
attn.attention_head_size,
q_scaling=q_scaling,
max_input_length=max_input_length,
)
else:
assert not default_net().plugin_config.remove_input_padding
def transpose_for_scores(x):
new_x_shape = concat([shape(x, 0), shape(x, 1), attn.num_attention_heads, attn.attention_head_size])
y = x.view(new_x_shape)
y = y.transpose(1, 2)
return y
def transpose_for_scores_k(x):
new_x_shape = concat([shape(x, 0), shape(x, 1), attn.num_attention_heads, attn.attention_head_size])
y = x.view(new_x_shape)
y = y.permute([0, 2, 3, 1])
return y
query = transpose_for_scores(query)
key = transpose_for_scores_k(key)
value = transpose_for_scores(value)
attention_scores = matmul(query, key, use_fp32_acc=False)
if mask is not None:
attention_mask = expand_mask(mask, shape(query, 2))
attention_mask = cast(attention_mask, attention_scores.dtype)
attention_scores = attention_scores + attention_mask
attention_probs = softmax(attention_scores, dim=-1)
context = matmul(attention_probs, value, use_fp32_acc=False).transpose(1, 2)
context = context.view(concat([shape(context, 0), shape(context, 1), attn.attention_hidden_size]))
context = attn.to_out(context)
if mask is not None:
mask = mask.view(concat([shape(mask, 0), shape(mask, 1), 1]))
mask = expand_dims_like(mask, context)
mask = cast(mask, context.dtype)
context = context * mask
return context
# DiT Block
class DiTBlock(Module):
def __init__(self, dim, heads, dim_head, ff_mult=2, dropout=0.1):
super().__init__()
self.attn_norm = AdaLayerNormZero(dim)
self.attn = Attention(
processor=AttnProcessor(),
dim=dim,
heads=heads,
dim_head=dim_head,
dropout=dropout,
)
self.ff_norm = LayerNorm(dim, elementwise_affine=False, eps=1e-6)
self.ff = FeedForward(dim=dim, mult=ff_mult, dropout=dropout)
def forward(
self, x, t, rope_cos, rope_sin, input_lengths, scale=1.0, rope=None
): # x: noised input, t: time embedding
# pre-norm & modulation for attention input
norm, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.attn_norm(x, emb=t)
# attention
# norm ----> (2,1226,1024)
attn_output = self.attn(x=norm, rope_cos=rope_cos, rope_sin=rope_sin, input_lengths=input_lengths, scale=scale)
# process attention output for input x
if default_net().plugin_config.remove_input_padding:
x = x + gate_msa * attn_output
else:
x = x + unsqueeze(gate_msa, 1) * attn_output
ones = constant(np.ones(1, dtype=np.float32)).cast(x.dtype)
if default_net().plugin_config.remove_input_padding:
norm = self.ff_norm(x) * (ones + scale_mlp) + shift_mlp
else:
norm = self.ff_norm(x) * (ones + unsqueeze(scale_mlp, 1)) + unsqueeze(shift_mlp, 1)
# norm = self.ff_norm(x) * (ones + scale_mlp) + shift_mlp
ff_output = self.ff(norm)
if default_net().plugin_config.remove_input_padding:
x = x + gate_mlp * ff_output
else:
x = x + unsqueeze(gate_mlp, 1) * ff_output
return x
class TimestepEmbedding(Module):
def __init__(self, dim, freq_embed_dim=256, dtype=None):
super().__init__()
# self.time_embed = SinusPositionEmbedding(freq_embed_dim)
self.mlp1 = Linear(freq_embed_dim, dim, bias=True, dtype=dtype)
self.mlp2 = Linear(dim, dim, bias=True, dtype=dtype)
def forward(self, timestep):
t_freq = self.mlp1(timestep)
t_freq = silu(t_freq)
t_emb = self.mlp2(t_freq)
return t_emb
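
Note that TimestepEmbedding consumes a precomputed frequency embedding (the SinusPositionEmbedding line is commented out, so the sinusoid is computed outside the engine). A sketch of a standard DiT-style sinusoidal embedding that could produce the `timestep` input; the exact scale and sin/cos ordering are assumptions, not taken from the repo:

```python
import math
import torch

def sinus_time_embedding(t: torch.Tensor, freq_embed_dim: int = 256) -> torch.Tensor:
    # Hypothetical stand-in for the commented-out SinusPositionEmbedding.
    half = freq_embed_dim // 2
    freqs = torch.exp(-math.log(10000.0) * torch.arange(half, dtype=torch.float32) / half)
    args = 1000.0 * t.float()[:, None] * freqs[None, :]
    return torch.cat([args.sin(), args.cos()], dim=-1)  # shape [batch, freq_embed_dim]
```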

View File

@@ -0,0 +1,70 @@
stage=$1
stop_stage=$2
model=$3 # F5TTS_Base
if [ -z "$model" ]; then
echo "Model is none"
exit 1
fi
echo "Start stage: $stage, Stop stage: $stop_stage, Model: $model"
export CUDA_VISIBLE_DEVICES=0
F5_TTS_HF_DOWNLOAD_PATH=./F5-TTS
F5_TTS_TRT_LLM_CHECKPOINT_PATH=./trtllm_ckpt
F5_TTS_TRT_LLM_ENGINE_PATH=./f5_trt_llm_engine
vocoder_trt_engine_path=vocos_vocoder.plan
model_repo=./model_repo
if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
echo "Downloading f5 tts from huggingface"
huggingface-cli download SWivid/F5-TTS --local-dir $F5_TTS_HF_DOWNLOAD_PATH
fi
if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
echo "Converting checkpoint"
python3 ./scripts/convert_checkpoint.py \
--timm_ckpt "$F5_TTS_HF_DOWNLOAD_PATH/$model/model_1200000.pt" \
--output_dir "$F5_TTS_TRT_LLM_CHECKPOINT_PATH" --model_name $model
python_package_path=/usr/local/lib/python3.12/dist-packages
cp -r patch/* $python_package_path/tensorrt_llm/models
trtllm-build --checkpoint_dir $F5_TTS_TRT_LLM_CHECKPOINT_PATH \
--max_batch_size 8 \
--output_dir $F5_TTS_TRT_LLM_ENGINE_PATH --remove_input_padding disable
fi
if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
echo "Exporting vocos vocoder"
onnx_vocoder_path=vocos_vocoder.onnx
python3 scripts/export_vocoder_to_onnx.py --vocoder vocos --output-path $onnx_vocoder_path
bash scripts/export_vocos_trt.sh $onnx_vocoder_path $vocoder_trt_engine_path
fi
if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
echo "Building triton server"
rm -rf $model_repo
cp -r ./model_repo_f5_tts $model_repo
python3 scripts/fill_template.py -i $model_repo/f5_tts/config.pbtxt vocab:$F5_TTS_HF_DOWNLOAD_PATH/$model/vocab.txt,model:$F5_TTS_HF_DOWNLOAD_PATH/$model/model_1200000.pt,trtllm:$F5_TTS_TRT_LLM_ENGINE_PATH,vocoder:vocos
cp $vocoder_trt_engine_path $model_repo/vocoder/1/vocoder.plan
fi
if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
echo "Starting triton server"
tritonserver --model-repository=$model_repo
fi
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
echo "Testing triton server"
num_task=1
log_dir=./log_concurrent_tasks_${num_task}
rm -rf $log_dir
python3 client_grpc.py --num-tasks $num_task --huggingface-dataset yuekai/seed_tts --split-name wenetspeech4tts --log-dir $log_dir
fi
if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
echo "Testing http client"
audio=../../infer/examples/basic/basic_ref_en.wav
reference_text="Some call me nature, others call me mother nature."
target_text="I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring."
python3 client_http.py --reference-audio $audio --reference-text "$reference_text" --target-text "$target_text"
fi
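
Typical invocation, assuming the script is saved as run.sh: run everything from download through the HTTP test, or a single stage by passing equal start and stop stages.

```bash
bash run.sh 0 6 F5TTS_Base   # all stages
bash run.sh 4 4 F5TTS_Base   # only start the Triton server
```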

View File

@@ -0,0 +1,247 @@
# Modified from https://github.com/echocatzh/conv-stft/blob/master/conv_stft/conv_stft.py
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# MIT License
# Copyright (c) 2020 Shimin Zhang
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
import torch as th
import torch.nn.functional as F
from scipy.signal import check_COLA, get_window
support_clp_op = None
if th.__version__ >= "1.7.0":
from torch.fft import rfft as fft
support_clp_op = True
else:
from torch import rfft as fft
class STFT(th.nn.Module):
def __init__(
self,
win_len=1024,
win_hop=512,
fft_len=1024,
enframe_mode="continue",
win_type="hann",
win_sqrt=False,
pad_center=True,
):
"""
Implementation of STFT using 1D convolution and 1D transposed convolution.
Framing the signal is implemented in two modes, `break` and `continue`:
`break` is a kaldi-like framing; `continue` is a librosa-like framing.
More information about `perfect reconstruction`:
1. https://ww2.mathworks.cn/help/signal/ref/stft.html
2. https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.get_window.html
Args:
win_len (int): Number of points in one frame. Defaults to 1024.
win_hop (int): Frame stride (hop size) in samples. Defaults to 512.
fft_len (int): Number of DFT points. Defaults to 1024.
enframe_mode (str, optional): `break` or `continue`. Defaults to 'continue'.
win_type (str, optional): The type of window to create. Defaults to 'hann'.
win_sqrt (bool, optional): Use the square root of the window. Defaults to False.
pad_center (bool, optional): Pad input for `perfect reconstruction`. Defaults to True.
"""
super(STFT, self).__init__()
assert enframe_mode in ["break", "continue"]
assert fft_len >= win_len
self.win_len = win_len
self.win_hop = win_hop
self.fft_len = fft_len
self.mode = enframe_mode
self.win_type = win_type
self.win_sqrt = win_sqrt
self.pad_center = pad_center
self.pad_amount = self.fft_len // 2
en_k, fft_k, ifft_k, ola_k = self.__init_kernel__()
self.register_buffer("en_k", en_k)
self.register_buffer("fft_k", fft_k)
self.register_buffer("ifft_k", ifft_k)
self.register_buffer("ola_k", ola_k)
def __init_kernel__(self):
"""
Generate enframe_kernel, fft_kernel, ifft_kernel and overlap-add kernel.
** enframe_kernel: a conv1d layer built from an identity matrix.
** fft_kernel: a linear layer used for matrix multiplication. In fact,
enframe_kernel and fft_kernel could be combined, but they are kept
separate for readability.
** ifft_kernel: the pseudo-inverse (pinv) of fft_kernel.
** ola_kernel: the overlap-add kernel, like enframe_kernel but transposed.
Returns:
tuple: four kernels.
"""
enframed_kernel = th.eye(self.fft_len)[:, None, :]
if support_clp_op:
tmp = fft(th.eye(self.fft_len))
fft_kernel = th.stack([tmp.real, tmp.imag], dim=2)
else:
fft_kernel = fft(th.eye(self.fft_len), 1)
if self.mode == "break":
enframed_kernel = th.eye(self.win_len)[:, None, :]
fft_kernel = fft_kernel[: self.win_len]
fft_kernel = th.cat((fft_kernel[:, :, 0], fft_kernel[:, :, 1]), dim=1)
ifft_kernel = th.pinverse(fft_kernel)[:, None, :]
window = get_window(self.win_type, self.win_len)
self.perfect_reconstruct = check_COLA(window, self.win_len, self.win_len - self.win_hop)
window = th.FloatTensor(window)
if self.mode == "continue":
left_pad = (self.fft_len - self.win_len) // 2
right_pad = left_pad + (self.fft_len - self.win_len) % 2
window = F.pad(window, (left_pad, right_pad))
if self.win_sqrt:
self.padded_window = window
window = th.sqrt(window)
else:
self.padded_window = window**2
fft_kernel = fft_kernel.T * window
ifft_kernel = ifft_kernel * window
ola_kernel = th.eye(self.fft_len)[: self.win_len, None, :]
if self.mode == "continue":
ola_kernel = th.eye(self.fft_len)[:, None, : self.fft_len]
return enframed_kernel, fft_kernel, ifft_kernel, ola_kernel
def is_perfect(self):
"""
Whether the parameters win_len, win_hop and win_sqrt
satisfy the constant overlap-add (COLA) constraint.
Returns:
bool: True if the parameters satisfy COLA.
"""
return self.perfect_reconstruct and self.pad_center
def transform(self, inputs, return_type="complex"):
"""Take input data (audio) to STFT domain.
Args:
inputs (tensor): Tensor of floats, with shape (num_batch, num_samples)
return_type (str, optional): return (mag, phase) when `magphase`,
return (real, imag) when `realimag` and complex(real, imag) when `complex`.
Defaults to 'complex'.
Returns:
tuple: (mag, phase) when `magphase`, (real, imag) when `realimag`,
or complex(real, imag) when `complex`; each element has shape
[num_batch, num_frequencies, num_frames]
"""
assert return_type in ["magphase", "realimag", "complex"]
if inputs.dim() == 2:
inputs = th.unsqueeze(inputs, 1)
self.num_samples = inputs.size(-1)
if self.pad_center:
inputs = F.pad(inputs, (self.pad_amount, self.pad_amount), mode="reflect")
enframe_inputs = F.conv1d(inputs, self.en_k, stride=self.win_hop)
outputs = th.transpose(enframe_inputs, 1, 2)
outputs = F.linear(outputs, self.fft_k)
outputs = th.transpose(outputs, 1, 2)
dim = self.fft_len // 2 + 1
real = outputs[:, :dim, :]
imag = outputs[:, dim:, :]
if return_type == "realimag":
return real, imag
elif return_type == "complex":
assert support_clp_op
return th.complex(real, imag)
else:
mags = th.sqrt(real**2 + imag**2)
phase = th.atan2(imag, real)
return mags, phase
def inverse(self, input1, input2=None, input_type="magphase"):
"""Call the inverse STFT (iSTFT), given tensors produced
by the `transform` function.
Args:
input1 (tensors): Magnitude/Real-part of STFT with shape
[num_batch, num_frequencies, num_frames]
input2 (tensors): Phase/Imag-part of STFT with shape
[num_batch, num_frequencies, num_frames]
input_type (str, optional): Mathematical meaning of the input tensors.
Defaults to 'magphase'.
Returns:
tensor: Reconstructed audio given magnitude and phase, of
shape [num_batch, num_samples]
"""
assert input_type in ["magphase", "realimag"]
if input_type == "realimag":
real, imag = None, None
if support_clp_op and th.is_complex(input1):
real, imag = input1.real, input1.imag
else:
real, imag = input1, input2
else:
real = input1 * th.cos(input2)
imag = input1 * th.sin(input2)
inputs = th.cat([real, imag], dim=1)
outputs = F.conv_transpose1d(inputs, self.ifft_k, stride=self.win_hop)
t = (self.padded_window[None, :, None]).repeat(1, 1, inputs.size(-1))
t = t.to(inputs.device)
coff = F.conv_transpose1d(t, self.ola_k, stride=self.win_hop)
num_frames = input1.size(-1)
num_samples = num_frames * self.win_hop
rm_start, rm_end = self.pad_amount, self.pad_amount + num_samples
outputs = outputs[..., rm_start:rm_end]
coff = coff[..., rm_start:rm_end]
coffidx = th.where(coff > 1e-8)
outputs[coffidx] = outputs[coffidx] / (coff[coffidx])
return outputs.squeeze(dim=1)
def forward(self, inputs):
"""Take input data (audio) to STFT domain and then back to audio.
Args:
inputs (tensor): Tensor of floats, with shape [num_batch, num_samples]
Returns:
tensor: Reconstructed audio given magnitude and phase.
Of shape [num_batch, num_samples]
"""
mag, phase = self.transform(inputs)
rec_wav = self.inverse(mag, phase)
return rec_wav
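
A quick round-trip sketch with the STFT module above (parameters illustrative); reconstruction is near-perfect when `is_perfect()` holds:

```python
import torch as th

stft = STFT(win_len=1024, win_hop=256, fft_len=1024)
x = th.randn(1, 24000)                                  # ~1 s of audio at 24 kHz
mag, phase = stft.transform(x, return_type="magphase")  # analyse
y = stft.inverse(mag, phase)                            # resynthesize
print(stft.is_perfect(), x.shape, y.shape)
```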

View File

@@ -0,0 +1,359 @@
import argparse
import json
import os
import re
import time
import traceback
from concurrent.futures import ThreadPoolExecutor, as_completed
import safetensors.torch
import torch
from tensorrt_llm import str_dtype_to_torch
from tensorrt_llm.mapping import Mapping
from tensorrt_llm.models.convert_utils import split, split_matrix_tp
def split_q_tp(v, n_head, n_hidden, tensor_parallel, rank):
split_v = split(v, tensor_parallel, rank, dim=1)
return split_v.contiguous()
def split_q_bias_tp(v, n_head, n_hidden, tensor_parallel, rank):
split_v = split(v, tensor_parallel, rank, dim=0)
return split_v.contiguous()
FACEBOOK_DIT_NAME_MAPPING = {
"^time_embed.time_mlp.0.weight$": "time_embed.mlp1.weight",
"^time_embed.time_mlp.0.bias$": "time_embed.mlp1.bias",
"^time_embed.time_mlp.2.weight$": "time_embed.mlp2.weight",
"^time_embed.time_mlp.2.bias$": "time_embed.mlp2.bias",
"^input_embed.conv_pos_embed.conv1d.0.weight$": "input_embed.conv_pos_embed.conv1d1.weight",
"^input_embed.conv_pos_embed.conv1d.0.bias$": "input_embed.conv_pos_embed.conv1d1.bias",
"^input_embed.conv_pos_embed.conv1d.2.weight$": "input_embed.conv_pos_embed.conv1d2.weight",
"^input_embed.conv_pos_embed.conv1d.2.bias$": "input_embed.conv_pos_embed.conv1d2.bias",
"^transformer_blocks.0.attn.to_out.0.weight$": "transformer_blocks.0.attn.to_out.weight",
"^transformer_blocks.0.attn.to_out.0.bias$": "transformer_blocks.0.attn.to_out.bias",
"^transformer_blocks.1.attn.to_out.0.weight$": "transformer_blocks.1.attn.to_out.weight",
"^transformer_blocks.1.attn.to_out.0.bias$": "transformer_blocks.1.attn.to_out.bias",
"^transformer_blocks.2.attn.to_out.0.weight$": "transformer_blocks.2.attn.to_out.weight",
"^transformer_blocks.2.attn.to_out.0.bias$": "transformer_blocks.2.attn.to_out.bias",
"^transformer_blocks.3.attn.to_out.0.weight$": "transformer_blocks.3.attn.to_out.weight",
"^transformer_blocks.3.attn.to_out.0.bias$": "transformer_blocks.3.attn.to_out.bias",
"^transformer_blocks.4.attn.to_out.0.weight$": "transformer_blocks.4.attn.to_out.weight",
"^transformer_blocks.4.attn.to_out.0.bias$": "transformer_blocks.4.attn.to_out.bias",
"^transformer_blocks.5.attn.to_out.0.weight$": "transformer_blocks.5.attn.to_out.weight",
"^transformer_blocks.5.attn.to_out.0.bias$": "transformer_blocks.5.attn.to_out.bias",
"^transformer_blocks.6.attn.to_out.0.weight$": "transformer_blocks.6.attn.to_out.weight",
"^transformer_blocks.6.attn.to_out.0.bias$": "transformer_blocks.6.attn.to_out.bias",
"^transformer_blocks.7.attn.to_out.0.weight$": "transformer_blocks.7.attn.to_out.weight",
"^transformer_blocks.7.attn.to_out.0.bias$": "transformer_blocks.7.attn.to_out.bias",
"^transformer_blocks.8.attn.to_out.0.weight$": "transformer_blocks.8.attn.to_out.weight",
"^transformer_blocks.8.attn.to_out.0.bias$": "transformer_blocks.8.attn.to_out.bias",
"^transformer_blocks.9.attn.to_out.0.weight$": "transformer_blocks.9.attn.to_out.weight",
"^transformer_blocks.9.attn.to_out.0.bias$": "transformer_blocks.9.attn.to_out.bias",
"^transformer_blocks.10.attn.to_out.0.weight$": "transformer_blocks.10.attn.to_out.weight",
"^transformer_blocks.10.attn.to_out.0.bias$": "transformer_blocks.10.attn.to_out.bias",
"^transformer_blocks.11.attn.to_out.0.weight$": "transformer_blocks.11.attn.to_out.weight",
"^transformer_blocks.11.attn.to_out.0.bias$": "transformer_blocks.11.attn.to_out.bias",
"^transformer_blocks.12.attn.to_out.0.weight$": "transformer_blocks.12.attn.to_out.weight",
"^transformer_blocks.12.attn.to_out.0.bias$": "transformer_blocks.12.attn.to_out.bias",
"^transformer_blocks.13.attn.to_out.0.weight$": "transformer_blocks.13.attn.to_out.weight",
"^transformer_blocks.13.attn.to_out.0.bias$": "transformer_blocks.13.attn.to_out.bias",
"^transformer_blocks.14.attn.to_out.0.weight$": "transformer_blocks.14.attn.to_out.weight",
"^transformer_blocks.14.attn.to_out.0.bias$": "transformer_blocks.14.attn.to_out.bias",
"^transformer_blocks.15.attn.to_out.0.weight$": "transformer_blocks.15.attn.to_out.weight",
"^transformer_blocks.15.attn.to_out.0.bias$": "transformer_blocks.15.attn.to_out.bias",
"^transformer_blocks.16.attn.to_out.0.weight$": "transformer_blocks.16.attn.to_out.weight",
"^transformer_blocks.16.attn.to_out.0.bias$": "transformer_blocks.16.attn.to_out.bias",
"^transformer_blocks.17.attn.to_out.0.weight$": "transformer_blocks.17.attn.to_out.weight",
"^transformer_blocks.17.attn.to_out.0.bias$": "transformer_blocks.17.attn.to_out.bias",
"^transformer_blocks.18.attn.to_out.0.weight$": "transformer_blocks.18.attn.to_out.weight",
"^transformer_blocks.18.attn.to_out.0.bias$": "transformer_blocks.18.attn.to_out.bias",
"^transformer_blocks.19.attn.to_out.0.weight$": "transformer_blocks.19.attn.to_out.weight",
"^transformer_blocks.19.attn.to_out.0.bias$": "transformer_blocks.19.attn.to_out.bias",
"^transformer_blocks.20.attn.to_out.0.weight$": "transformer_blocks.20.attn.to_out.weight",
"^transformer_blocks.20.attn.to_out.0.bias$": "transformer_blocks.20.attn.to_out.bias",
"^transformer_blocks.21.attn.to_out.0.weight$": "transformer_blocks.21.attn.to_out.weight",
"^transformer_blocks.21.attn.to_out.0.bias$": "transformer_blocks.21.attn.to_out.bias",
"^transformer_blocks.0.ff.ff.0.0.weight$": "transformer_blocks.0.ff.project_in.weight",
"^transformer_blocks.0.ff.ff.0.0.bias$": "transformer_blocks.0.ff.project_in.bias",
"^transformer_blocks.0.ff.ff.2.weight$": "transformer_blocks.0.ff.ff.weight",
"^transformer_blocks.0.ff.ff.2.bias$": "transformer_blocks.0.ff.ff.bias",
"^transformer_blocks.1.ff.ff.0.0.weight$": "transformer_blocks.1.ff.project_in.weight",
"^transformer_blocks.1.ff.ff.0.0.bias$": "transformer_blocks.1.ff.project_in.bias",
"^transformer_blocks.1.ff.ff.2.weight$": "transformer_blocks.1.ff.ff.weight",
"^transformer_blocks.1.ff.ff.2.bias$": "transformer_blocks.1.ff.ff.bias",
"^transformer_blocks.2.ff.ff.0.0.weight$": "transformer_blocks.2.ff.project_in.weight",
"^transformer_blocks.2.ff.ff.0.0.bias$": "transformer_blocks.2.ff.project_in.bias",
"^transformer_blocks.2.ff.ff.2.weight$": "transformer_blocks.2.ff.ff.weight",
"^transformer_blocks.2.ff.ff.2.bias$": "transformer_blocks.2.ff.ff.bias",
"^transformer_blocks.3.ff.ff.0.0.weight$": "transformer_blocks.3.ff.project_in.weight",
"^transformer_blocks.3.ff.ff.0.0.bias$": "transformer_blocks.3.ff.project_in.bias",
"^transformer_blocks.3.ff.ff.2.weight$": "transformer_blocks.3.ff.ff.weight",
"^transformer_blocks.3.ff.ff.2.bias$": "transformer_blocks.3.ff.ff.bias",
"^transformer_blocks.4.ff.ff.0.0.weight$": "transformer_blocks.4.ff.project_in.weight",
"^transformer_blocks.4.ff.ff.0.0.bias$": "transformer_blocks.4.ff.project_in.bias",
"^transformer_blocks.4.ff.ff.2.weight$": "transformer_blocks.4.ff.ff.weight",
"^transformer_blocks.4.ff.ff.2.bias$": "transformer_blocks.4.ff.ff.bias",
"^transformer_blocks.5.ff.ff.0.0.weight$": "transformer_blocks.5.ff.project_in.weight",
"^transformer_blocks.5.ff.ff.0.0.bias$": "transformer_blocks.5.ff.project_in.bias",
"^transformer_blocks.5.ff.ff.2.weight$": "transformer_blocks.5.ff.ff.weight",
"^transformer_blocks.5.ff.ff.2.bias$": "transformer_blocks.5.ff.ff.bias",
"^transformer_blocks.6.ff.ff.0.0.weight$": "transformer_blocks.6.ff.project_in.weight",
"^transformer_blocks.6.ff.ff.0.0.bias$": "transformer_blocks.6.ff.project_in.bias",
"^transformer_blocks.6.ff.ff.2.weight$": "transformer_blocks.6.ff.ff.weight",
"^transformer_blocks.6.ff.ff.2.bias$": "transformer_blocks.6.ff.ff.bias",
"^transformer_blocks.7.ff.ff.0.0.weight$": "transformer_blocks.7.ff.project_in.weight",
"^transformer_blocks.7.ff.ff.0.0.bias$": "transformer_blocks.7.ff.project_in.bias",
"^transformer_blocks.7.ff.ff.2.weight$": "transformer_blocks.7.ff.ff.weight",
"^transformer_blocks.7.ff.ff.2.bias$": "transformer_blocks.7.ff.ff.bias",
"^transformer_blocks.8.ff.ff.0.0.weight$": "transformer_blocks.8.ff.project_in.weight",
"^transformer_blocks.8.ff.ff.0.0.bias$": "transformer_blocks.8.ff.project_in.bias",
"^transformer_blocks.8.ff.ff.2.weight$": "transformer_blocks.8.ff.ff.weight",
"^transformer_blocks.8.ff.ff.2.bias$": "transformer_blocks.8.ff.ff.bias",
"^transformer_blocks.9.ff.ff.0.0.weight$": "transformer_blocks.9.ff.project_in.weight",
"^transformer_blocks.9.ff.ff.0.0.bias$": "transformer_blocks.9.ff.project_in.bias",
"^transformer_blocks.9.ff.ff.2.weight$": "transformer_blocks.9.ff.ff.weight",
"^transformer_blocks.9.ff.ff.2.bias$": "transformer_blocks.9.ff.ff.bias",
"^transformer_blocks.10.ff.ff.0.0.weight$": "transformer_blocks.10.ff.project_in.weight",
"^transformer_blocks.10.ff.ff.0.0.bias$": "transformer_blocks.10.ff.project_in.bias",
"^transformer_blocks.10.ff.ff.2.weight$": "transformer_blocks.10.ff.ff.weight",
"^transformer_blocks.10.ff.ff.2.bias$": "transformer_blocks.10.ff.ff.bias",
"^transformer_blocks.11.ff.ff.0.0.weight$": "transformer_blocks.11.ff.project_in.weight",
"^transformer_blocks.11.ff.ff.0.0.bias$": "transformer_blocks.11.ff.project_in.bias",
"^transformer_blocks.11.ff.ff.2.weight$": "transformer_blocks.11.ff.ff.weight",
"^transformer_blocks.11.ff.ff.2.bias$": "transformer_blocks.11.ff.ff.bias",
"^transformer_blocks.12.ff.ff.0.0.weight$": "transformer_blocks.12.ff.project_in.weight",
"^transformer_blocks.12.ff.ff.0.0.bias$": "transformer_blocks.12.ff.project_in.bias",
"^transformer_blocks.12.ff.ff.2.weight$": "transformer_blocks.12.ff.ff.weight",
"^transformer_blocks.12.ff.ff.2.bias$": "transformer_blocks.12.ff.ff.bias",
"^transformer_blocks.13.ff.ff.0.0.weight$": "transformer_blocks.13.ff.project_in.weight",
"^transformer_blocks.13.ff.ff.0.0.bias$": "transformer_blocks.13.ff.project_in.bias",
"^transformer_blocks.13.ff.ff.2.weight$": "transformer_blocks.13.ff.ff.weight",
"^transformer_blocks.13.ff.ff.2.bias$": "transformer_blocks.13.ff.ff.bias",
"^transformer_blocks.14.ff.ff.0.0.weight$": "transformer_blocks.14.ff.project_in.weight",
"^transformer_blocks.14.ff.ff.0.0.bias$": "transformer_blocks.14.ff.project_in.bias",
"^transformer_blocks.14.ff.ff.2.weight$": "transformer_blocks.14.ff.ff.weight",
"^transformer_blocks.14.ff.ff.2.bias$": "transformer_blocks.14.ff.ff.bias",
"^transformer_blocks.15.ff.ff.0.0.weight$": "transformer_blocks.15.ff.project_in.weight",
"^transformer_blocks.15.ff.ff.0.0.bias$": "transformer_blocks.15.ff.project_in.bias",
"^transformer_blocks.15.ff.ff.2.weight$": "transformer_blocks.15.ff.ff.weight",
"^transformer_blocks.15.ff.ff.2.bias$": "transformer_blocks.15.ff.ff.bias",
"^transformer_blocks.16.ff.ff.0.0.weight$": "transformer_blocks.16.ff.project_in.weight",
"^transformer_blocks.16.ff.ff.0.0.bias$": "transformer_blocks.16.ff.project_in.bias",
"^transformer_blocks.16.ff.ff.2.weight$": "transformer_blocks.16.ff.ff.weight",
"^transformer_blocks.16.ff.ff.2.bias$": "transformer_blocks.16.ff.ff.bias",
"^transformer_blocks.17.ff.ff.0.0.weight$": "transformer_blocks.17.ff.project_in.weight",
"^transformer_blocks.17.ff.ff.0.0.bias$": "transformer_blocks.17.ff.project_in.bias",
"^transformer_blocks.17.ff.ff.2.weight$": "transformer_blocks.17.ff.ff.weight",
"^transformer_blocks.17.ff.ff.2.bias$": "transformer_blocks.17.ff.ff.bias",
"^transformer_blocks.18.ff.ff.0.0.weight$": "transformer_blocks.18.ff.project_in.weight",
"^transformer_blocks.18.ff.ff.0.0.bias$": "transformer_blocks.18.ff.project_in.bias",
"^transformer_blocks.18.ff.ff.2.weight$": "transformer_blocks.18.ff.ff.weight",
"^transformer_blocks.18.ff.ff.2.bias$": "transformer_blocks.18.ff.ff.bias",
"^transformer_blocks.19.ff.ff.0.0.weight$": "transformer_blocks.19.ff.project_in.weight",
"^transformer_blocks.19.ff.ff.0.0.bias$": "transformer_blocks.19.ff.project_in.bias",
"^transformer_blocks.19.ff.ff.2.weight$": "transformer_blocks.19.ff.ff.weight",
"^transformer_blocks.19.ff.ff.2.bias$": "transformer_blocks.19.ff.ff.bias",
"^transformer_blocks.20.ff.ff.0.0.weight$": "transformer_blocks.20.ff.project_in.weight",
"^transformer_blocks.20.ff.ff.0.0.bias$": "transformer_blocks.20.ff.project_in.bias",
"^transformer_blocks.20.ff.ff.2.weight$": "transformer_blocks.20.ff.ff.weight",
"^transformer_blocks.20.ff.ff.2.bias$": "transformer_blocks.20.ff.ff.bias",
"^transformer_blocks.21.ff.ff.0.0.weight$": "transformer_blocks.21.ff.project_in.weight",
"^transformer_blocks.21.ff.ff.0.0.bias$": "transformer_blocks.21.ff.project_in.bias",
"^transformer_blocks.21.ff.ff.2.weight$": "transformer_blocks.21.ff.ff.weight",
"^transformer_blocks.21.ff.ff.2.bias$": "transformer_blocks.21.ff.ff.bias",
}
def parse_arguments():
parser = argparse.ArgumentParser()
parser.add_argument(
"--model_name",
type=str,
default="F5TTS_Base",
choices=[
"F5TTS_Base",
],
) # TODO: support F5TTS_v1_Base
parser.add_argument("--timm_ckpt", type=str, default="./ckpts/model_1200000.pt")
parser.add_argument(
"--output_dir", type=str, default="./tllm_checkpoint", help="The path to save the TensorRT-LLM checkpoint"
)
parser.add_argument("--hidden_size", type=int, default=1024, help="The hidden size of DiT")
parser.add_argument("--depth", type=int, default=22, help="The number of DiTBlock layers")
parser.add_argument("--num_heads", type=int, default=16, help="The number of heads of attention module")
parser.add_argument("--cfg_scale", type=float, default=4.0)
parser.add_argument("--tp_size", type=int, default=1, help="N-way tensor parallelism size")
parser.add_argument("--cp_size", type=int, default=1, help="Context parallelism size")
parser.add_argument("--pp_size", type=int, default=1, help="N-way pipeline parallelism size")
parser.add_argument("--dtype", type=str, default="float16", choices=["float32", "bfloat16", "float16"])
parser.add_argument("--fp8_linear", action="store_true", help="Whether use FP8 for linear layers")
parser.add_argument(
"--workers", type=int, default=1, help="The number of workers for converting checkpoint in parallel"
)
args = parser.parse_args()
return args
def convert_timm_dit(args, mapping, dtype="float32"):
weights = {}
tik = time.time()
torch_dtype = str_dtype_to_torch(dtype)
tensor_parallel = mapping.tp_size
model_params = dict(torch.load(args.timm_ckpt))
model_params = {
k: v for k, v in model_params["ema_model_state_dict"].items() if k.startswith("ema_model.transformer")
}
prefix = "ema_model.transformer."
model_params = {key[len(prefix) :] if key.startswith(prefix) else key: value for key, value in model_params.items()}
timm_to_trtllm_name = FACEBOOK_DIT_NAME_MAPPING
def get_trtllm_name(timm_name):
for k, v in timm_to_trtllm_name.items():
m = re.match(k, timm_name)
if m is not None:
if "*" in v:
v = v.replace("*", m.groups()[0])
return v
return timm_name
for name, param in model_params.items():
if name == "input_embed.conv_pos_embed.conv1d.0.weight" or name == "input_embed.conv_pos_embed.conv1d.2.weight":
weights[get_trtllm_name(name)] = param.contiguous().to(torch_dtype).unsqueeze(-1)
else:
weights[get_trtllm_name(name)] = param.contiguous().to(torch_dtype)
assert len(weights) == len(model_params)
# new_prefix = 'f5_transformer.'
new_prefix = ""
weights = {new_prefix + key: value for key, value in weights.items()}
import math
scale_factor = math.pow(64, -0.25)
for k, v in weights.items():
if re.match("^transformer_blocks.*.attn.to_k.weight$", k):
weights[k] *= scale_factor
weights[k] = split_q_tp(v, args.num_heads, args.hidden_size, tensor_parallel, mapping.tp_rank)
elif re.match("^transformer_blocks.*.attn.to_k.bias$", k):
weights[k] *= scale_factor
weights[k] = split_q_bias_tp(v, args.num_heads, args.hidden_size, tensor_parallel, mapping.tp_rank)
elif re.match("^transformer_blocks.*.attn.to_q.weight$", k):
weights[k] = split_q_tp(v, args.num_heads, args.hidden_size, tensor_parallel, mapping.tp_rank)
weights[k] *= scale_factor
elif re.match("^transformer_blocks.*.attn.to_q.bias$", k):
weights[k] = split_q_bias_tp(v, args.num_heads, args.hidden_size, tensor_parallel, mapping.tp_rank)
weights[k] *= scale_factor
elif re.match("^transformer_blocks.*.attn.to_v.weight$", k):
weights[k] = split_q_tp(v, args.num_heads, args.hidden_size, tensor_parallel, mapping.tp_rank)
elif re.match("^transformer_blocks.*.attn.to_v.bias$", k):
weights[k] = split_q_bias_tp(v, args.num_heads, args.hidden_size, tensor_parallel, mapping.tp_rank)
elif re.match("^transformer_blocks.*.attn.to_out.weight$", k):
weights[k] = split_matrix_tp(v, tensor_parallel, mapping.tp_rank, dim=1)
tok = time.time()
t = time.strftime("%H:%M:%S", time.gmtime(tok - tik))
print(f"Weights loaded. Total time: {t}")
return weights
def save_config(args):
if not os.path.exists(args.output_dir):
os.makedirs(args.output_dir)
config = {
"architecture": "F5TTS",
"dtype": args.dtype,
"hidden_size": 1024,
"num_hidden_layers": 22,
"num_attention_heads": 16,
"dim_head": 64,
"dropout": 0.1,
"ff_mult": 2,
"mel_dim": 100,
"text_num_embeds": 256,
"text_dim": 512,
"conv_layers": 4,
"long_skip_connection": False,
"mapping": {
"world_size": args.cp_size * args.tp_size * args.pp_size,
"cp_size": args.cp_size,
"tp_size": args.tp_size,
"pp_size": args.pp_size,
},
}
if args.fp8_linear:
config["quantization"] = {
"quant_algo": "FP8",
# TODO: add support for exclude modules.
# 'exclude_modules': "*final_layer*",
}
with open(os.path.join(args.output_dir, "config.json"), "w") as f:
json.dump(config, f, indent=4)
def convert_and_save(args, rank):
if rank == 0:
save_config(args)
mapping = Mapping(
world_size=args.cp_size * args.tp_size * args.pp_size,
rank=rank,
cp_size=args.cp_size,
tp_size=args.tp_size,
pp_size=args.pp_size,
)
weights = convert_timm_dit(args, mapping, dtype=args.dtype)
safetensors.torch.save_file(weights, os.path.join(args.output_dir, f"rank{rank}.safetensors"))
def execute(workers, func, args):
if workers == 1:
for rank, f in enumerate(func):
f(args, rank)
else:
with ThreadPoolExecutor(max_workers=workers) as p:
futures = [p.submit(f, args, rank) for rank, f in enumerate(func)]
exceptions = []
for future in as_completed(futures):
try:
future.result()
except Exception as e:
traceback.print_exc()
exceptions.append(e)
assert len(exceptions) == 0, "Checkpoint conversion failed, please check error log."
def main():
args = parse_arguments()
world_size = args.cp_size * args.tp_size * args.pp_size
assert args.pp_size == 1, "PP is not supported yet."
tik = time.time()
if args.timm_ckpt is None:
return
print("start execute")
execute(args.workers, [covert_and_save] * world_size, args)
tok = time.time()
t = time.strftime("%H:%M:%S", time.gmtime(tok - tik))
print(f"Total time of converting checkpoints: {t}")
if __name__ == "__main__":
main()
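
Matching the invocation in run.sh stage 1, for example:

```bash
python3 scripts/convert_checkpoint.py \
    --timm_ckpt ./F5-TTS/F5TTS_Base/model_1200000.pt \
    --output_dir ./trtllm_ckpt --model_name F5TTS_Base
```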

View File

@@ -0,0 +1,137 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download
from conv_stft import STFT
from vocos import Vocos
import argparse
opset_version = 17


def get_args():
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument(
        "--vocoder",
        type=str,
        default="vocos",
        choices=["vocos", "bigvgan"],
        help="Vocoder to export",
    )
    parser.add_argument(
        "--output-path",
        type=str,
        default="./vocos_vocoder.onnx",
        help="Output path",
    )
    return parser.parse_args()


class ISTFTHead(nn.Module):
    def __init__(self, n_fft: int, hop_length: int):
        super().__init__()
        self.out = None  # projection layer, taken over from the original Vocos head
        self.stft = STFT(fft_len=n_fft, win_hop=hop_length, win_len=n_fft)

    def forward(self, x: torch.Tensor):
        x = self.out(x).transpose(1, 2)
        mag, p = x.chunk(2, dim=1)  # predicted log-magnitude and phase
        mag = torch.exp(mag)
        mag = torch.clip(mag, max=1e2)  # safety clip to keep the exp from overflowing
        # mag * e^(i*p) as real/imaginary parts, then inverse STFT back to a waveform
        real = mag * torch.cos(p)
        imag = mag * torch.sin(p)
        audio = self.stft.inverse(input1=real, input2=imag, input_type="realimag")
        return audio


class VocosVocoder(nn.Module):
    def __init__(self, vocos_vocoder):
        super(VocosVocoder, self).__init__()
        self.vocos_vocoder = vocos_vocoder
        # Replace the stock ISTFT head with the ONNX-exportable one, reusing its weights.
        istft_head_out = self.vocos_vocoder.head.out
        n_fft = self.vocos_vocoder.head.istft.n_fft
        hop_length = self.vocos_vocoder.head.istft.hop_length
        istft_head_for_export = ISTFTHead(n_fft, hop_length)
        istft_head_for_export.out = istft_head_out
        self.vocos_vocoder.head = istft_head_for_export

    def forward(self, mel):
        waveform = self.vocos_vocoder.decode(mel)
        return waveform


def export_VocosVocoder(vocos_vocoder, output_path, verbose):
    vocos_vocoder = VocosVocoder(vocos_vocoder).cuda()
    vocos_vocoder.eval()
    dummy_batch_size = 8
    dummy_input_length = 500
    dummy_mel = torch.randn(dummy_batch_size, 100, dummy_input_length).cuda()
    with torch.no_grad():
        dummy_waveform = vocos_vocoder(mel=dummy_mel)
    print(f"Dummy waveform shape: {dummy_waveform.shape}")
    dummy_input = dummy_mel
    torch.onnx.export(
        vocos_vocoder,
        dummy_input,
        output_path,
        opset_version=opset_version,
        do_constant_folding=True,
        input_names=["mel"],
        output_names=["waveform"],
        dynamic_axes={
            "mel": {0: "batch_size", 2: "input_length"},
            "waveform": {0: "batch_size", 1: "output_length"},
        },
        verbose=verbose,
    )
    print(f"Exported to {output_path}")


def load_vocoder(vocoder_name="vocos", is_local=False, local_path="", device="cpu", hf_cache_dir=None):
    if vocoder_name == "vocos":
        # vocoder = Vocos.from_pretrained("charactr/vocos-mel-24khz").to(device)
        if is_local:
            print(f"Load vocos from local path {local_path}")
            config_path = f"{local_path}/config.yaml"
            model_path = f"{local_path}/pytorch_model.bin"
        else:
            print("Download Vocos from huggingface charactr/vocos-mel-24khz")
            repo_id = "charactr/vocos-mel-24khz"
            config_path = hf_hub_download(repo_id=repo_id, cache_dir=hf_cache_dir, filename="config.yaml")
            model_path = hf_hub_download(repo_id=repo_id, cache_dir=hf_cache_dir, filename="pytorch_model.bin")
        vocoder = Vocos.from_hparams(config_path)
        state_dict = torch.load(model_path, map_location="cpu", weights_only=True)
        vocoder.load_state_dict(state_dict)
        vocoder = vocoder.eval().to(device)
    elif vocoder_name == "bigvgan":
        raise NotImplementedError("BigVGAN is not supported yet")
    return vocoder


if __name__ == "__main__":
    args = get_args()
    vocoder = load_vocoder(vocoder_name=args.vocoder, device="cpu", hf_cache_dir=None)
    if args.vocoder == "vocos":
        export_VocosVocoder(vocoder, args.output_path, verbose=False)
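
# Example invocation (the script file name is illustrative):
#   python3 export_vocoder_to_onnx.py --vocoder vocos --output-path ./vocos_vocoder.onnx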

View File

@@ -0,0 +1,43 @@
#!/bin/bash
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
TRTEXEC="/usr/src/tensorrt/bin/trtexec"
ONNX_PATH=$1
ENGINE_PATH=$2
echo "ONNX_PATH: $ONNX_PATH"
echo "ENGINE_PATH: $ENGINE_PATH"
# trtexec builds FP32 engines by default, so no extra precision flag is passed below.
PRECISION="fp32"
MIN_BATCH_SIZE=1
OPT_BATCH_SIZE=1
MAX_BATCH_SIZE=8
MIN_INPUT_LENGTH=1
OPT_INPUT_LENGTH=1000
MAX_INPUT_LENGTH=3000
MEL_MIN_SHAPE="${MIN_BATCH_SIZE}x100x${MIN_INPUT_LENGTH}"
MEL_OPT_SHAPE="${OPT_BATCH_SIZE}x100x${OPT_INPUT_LENGTH}"
MEL_MAX_SHAPE="${MAX_BATCH_SIZE}x100x${MAX_INPUT_LENGTH}"
${TRTEXEC} \
--minShapes="mel:${MEL_MIN_SHAPE}" \
--optShapes="mel:${MEL_OPT_SHAPE}" \
--maxShapes="mel:${MEL_MAX_SHAPE}" \
--onnx=${ONNX_PATH} \
--saveEngine=${ENGINE_PATH}
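
# Example invocation (file names are illustrative):
#   bash export_vocos_trt.sh ./vocos_vocoder.onnx ./vocos_vocoder.plan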

View File

@@ -0,0 +1,36 @@
#!/usr/bin/env python3
from argparse import ArgumentParser
from string import Template


def main(file_path, substitutions, in_place, participant_ids):
    with open(file_path) as f:
        pbtxt = Template(f.read())
    sub_dict = {"max_queue_size": 0}
    sub_dict["participant_ids"] = participant_ids
    for sub in substitutions.split(","):
        key, value = sub.split(":", 1)  # split on the first colon only, so values may contain colons
        sub_dict[key] = value
    pbtxt = pbtxt.safe_substitute(sub_dict)
    if in_place:
        with open(file_path, "w") as f:
            f.write(pbtxt)
    else:
        print(pbtxt)


if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument("file_path", help="path of the .pbtxt to modify")
    parser.add_argument(
        "substitutions",
        help="substitutions to perform, in the format variable_name_1:value_1,variable_name_2:value_2...",
    )
    parser.add_argument("--in_place", "-i", action="store_true", help="do the operation in-place")
    parser.add_argument("--participant_ids", help="Participant IDs for the model", default="")
    args = parser.parse_args()
    main(**vars(args))
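
# Example invocation (config path and variable names are illustrative only):
#   python3 fill_template.py -i model_repo/f5_tts/config.pbtxt model_dir:./ckpt,vocoder:vocos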