root
2025-04-03 04:25:43 +00:00
parent 4681a1c177
commit ae51cc3d34
7 changed files with 26 additions and 22 deletions

View File

@@ -35,9 +35,9 @@ python3 client_grpc.py --num-tasks $num_task --huggingface-dataset yuekai/seed_t
### Benchmark Results
Decoding on a single L20 GPU, using 26 different prompt_audio/target_text pairs.
| Model | Note | Concurrency | Avg Latency | RTF |
|-------|-----------|-----------------------|---------|--|
| F5-TTS Base (Vocos) | [Code Commit](https://github.com/yuekaizhang/sherpa/tree/329ab3c573252e835844bea38505c6b43e994cf4/triton/f5_tts) | 1 | 253 ms | 0.0394|
| Model | Concurrency | Avg Latency | RTF |
|-------|-------------|-----------------|--|
| F5-TTS Base (Vocos) | 1 | 253 ms | 0.0394|
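For context, `Avg Latency` is the mean per-request processing time and `RTF` (real-time factor) is processing time divided by the duration of the audio produced. A minimal sketch of how both can be derived from the `(latency, audio_duration)` pairs that `client_grpc.py` collects; the numbers and the exact aggregation are illustrative, not the client's literal output:

```python
# Illustrative (latency_seconds, generated_audio_seconds) pairs, one per request,
# matching the latency_data tuples appended in client_grpc.py.
latency_data = [(0.25, 6.3), (0.27, 7.1), (0.24, 5.8)]

total_processing = sum(t for t, _ in latency_data)
total_audio = sum(d for _, d in latency_data)

avg_latency_ms = 1000 * total_processing / len(latency_data)
rtf = total_processing / total_audio  # well below 1.0 => faster than real time

print(f"Avg Latency: {avg_latency_ms:.0f} ms, RTF: {rtf:.4f}")
```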
### Credits
1. [F5-TTS-TRTLLM](https://github.com/Bigfishering/f5-tts-trtllm)

View File

@@ -245,6 +245,7 @@ async def send(
    model_name: str,
    padding_duration: int = None,
    audio_save_dir: str = "./",
    save_sample_rate: int = 16000,
):
    total_duration = 0.0
    latency_data = []
@@ -267,7 +268,9 @@ async def send(
            samples = np.zeros(
                (
                    1,
                    padding_duration * sample_rate * ((int(duration) // padding_duration) + 1),
                    padding_duration
                    * sample_rate
                    * ((int(estimated_target_duration + duration) // padding_duration) + 1),
                ),
                dtype=np.float32,
            )
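The sizing change above budgets for the reference audio plus the estimated synthesized audio before rounding up to the next whole `padding_duration`. A worked sketch of that arithmetic, with illustrative durations and an assumed 16 kHz reference sample rate:

```python
import numpy as np

sample_rate = 16000              # assumed reference-audio sample rate
padding_duration = 1             # seconds, as passed from main()
duration = 3.2                   # reference audio length (s), illustrative
estimated_target_duration = 5.7  # predicted output length (s), illustrative

# Round the combined duration up to the next whole padding_duration, then
# convert to a sample count for the zero-padded buffer.
padded_len = padding_duration * sample_rate * (
    (int(estimated_target_duration + duration) // padding_duration) + 1
)
samples = np.zeros((1, padded_len), dtype=np.float32)
print(samples.shape)  # (1, 144000): 9 s of buffer for 8.9 s of audio
```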
@@ -306,7 +309,7 @@ async def send(
        end = time.time() - start
        audio_save_path = os.path.join(audio_save_dir, f"{item['target_audio_path']}.wav")
        sf.write(audio_save_path, audio, 16000, "PCM_16")
        sf.write(audio_save_path, audio, save_sample_rate, "PCM_16")
        latency_data.append((end, estimated_target_duration))
        total_duration += estimated_target_duration
@@ -413,7 +416,8 @@ async def main():
                log_interval=args.log_interval,
                model_name=args.model_name,
                audio_save_dir=args.log_dir,
                padding_duration=1.0,
                padding_duration=1,
                save_sample_rate=24000 if args.model_name == "f5_tts" else 16000,
            )
        )
        tasks.append(task)
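The new `save_sample_rate` argument writes F5-TTS output at the vocoder's native 24 kHz instead of a hard-coded 16 kHz, which would otherwise make the saved files play back slow and pitched down. A minimal sketch of the selection and write, with a placeholder waveform:

```python
import numpy as np
import soundfile as sf

model_name = "f5_tts"
audio = np.zeros(24000, dtype=np.float32)  # 1 s placeholder waveform

# 24 kHz for F5-TTS (Vocos vocoder), 16 kHz for the other served models.
save_sample_rate = 24000 if model_name == "f5_tts" else 16000
sf.write("sample.wav", audio, save_sample_rate, "PCM_16")
```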

View File

@@ -158,7 +158,7 @@ class TritonPythonModel:
        return mel.transpose(1, 2)
    def forward_vocoder(self, mel):
        mel = mel.to(torch.float32).contiguous()
        mel = mel.to(torch.float32).contiguous().cpu()
        input_tensor_0 = pb_utils.Tensor.from_dlpack("mel", to_dlpack(mel))
        inference_request = pb_utils.InferenceRequest(
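The added `.cpu()` moves the mel to host memory before it is wrapped for the BLS request to the vocoder model. A minimal sketch of that DLPack handoff outside Triton; the shape is illustrative, and inside the Python backend the capsule is handed to `pb_utils.Tensor.from_dlpack` as shown above:

```python
import torch
from torch.utils.dlpack import to_dlpack, from_dlpack

mel = torch.randn(1, 100, 256).half()           # illustrative (batch, n_mels, frames)
mel = mel.to(torch.float32).contiguous().cpu()  # cast, pack, move to host

capsule = to_dlpack(mel)          # zero-copy view of the host tensor
restored = from_dlpack(capsule)   # what the receiving side reconstructs
assert restored.dtype == torch.float32 and restored.device.type == "cpu"
```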

View File

@@ -14,9 +14,9 @@
name: "f5_tts"
backend: "python"
max_batch_size: 1
max_batch_size: 4
dynamic_batching {
  max_queue_delay_microseconds: 1
  max_queue_delay_microseconds: 1000
}
parameters [
  {
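With `max_batch_size: 4` and a 1000 µs queue delay, the scheduler may hold a request for up to 1 ms so that concurrent requests can be merged into a batch of up to four. A minimal sketch for checking the deployed settings from a client, assuming the server is reachable on the default gRPC port:

```python
import tritonclient.grpc as grpcclient

client = grpcclient.InferenceServerClient("localhost:8001")
cfg = client.get_model_config("f5_tts").config
print(cfg.max_batch_size)                                 # expect 4
print(cfg.dynamic_batching.max_queue_delay_microseconds)  # expect 1000
client.close()
```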

View File

@@ -30,8 +30,7 @@ class InputEmbedding(Module):
        self.proj = Linear(mel_dim * 2 + text_dim, out_dim)
        self.conv_pos_embed = ConvPositionEmbedding(dim=out_dim)
    def forward(self, x, cond, drop_audio_cond=False):
        # if drop_audio_cond: # cfg for cond audio
    def forward(self, x, cond):
        x = self.proj(concat([x, cond], dim=-1))
        return self.conv_pos_embed(x) + x
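The simplified `forward(self, x, cond)` drops the `drop_audio_cond` flag, presumably because the unconditional branch of classifier-free guidance is handled outside this module in the TRT-LLM port. A rough PyTorch sketch of what the embedding computes; dimensions, the split of content between `x` and `cond`, and the `ConvPositionEmbedding` stand-in are assumptions for illustration:

```python
import torch
import torch.nn as nn

class InputEmbeddingSketch(nn.Module):
    def __init__(self, mel_dim=100, text_dim=512, out_dim=1024):
        super().__init__()
        self.proj = nn.Linear(mel_dim * 2 + text_dim, out_dim)
        # Stand-in for ConvPositionEmbedding: grouped 1-D convs over time.
        self.conv_pos_embed = nn.Sequential(
            nn.Conv1d(out_dim, out_dim, 31, padding=15, groups=16), nn.Mish(),
            nn.Conv1d(out_dim, out_dim, 31, padding=15, groups=16), nn.Mish(),
        )

    def forward(self, x, cond):
        # x: (batch, frames, mel_dim + text_dim), cond: (batch, frames, mel_dim)
        x = self.proj(torch.cat([x, cond], dim=-1))
        pos = self.conv_pos_embed(x.transpose(1, 2)).transpose(1, 2)
        return pos + x  # residual positional embedding

emb = InputEmbeddingSketch()
x = torch.randn(2, 50, 100 + 512)   # noisy mel frames with text features appended
cond = torch.randn(2, 50, 100)      # reference mel condition
print(emb(x, cond).shape)           # torch.Size([2, 50, 1024])
```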
@@ -41,9 +40,8 @@ class F5TTS(PretrainedModel):
        super().__init__(config)
        self.dtype = str_dtype_to_trt(config.dtype)
        self.time_embed = TimestepEmbedding(config.hidden_size) # √
        text_dim = config.mel_dim if config.text_dim is None else config.text_dim
        self.input_embed = InputEmbedding(config.mel_dim, text_dim, config.hidden_size)
        self.time_embed = TimestepEmbedding(config.hidden_size)
        self.input_embed = InputEmbedding(config.mel_dim, config.text_dim, config.hidden_size)
        self.dim = config.hidden_size
        self.depth = config.num_hidden_layers

View File

@@ -93,7 +93,7 @@ class ConvPositionEmbedding(Module):
        self.conv1d2 = Conv1d(dim, dim, kernel_size, groups=groups, padding=kernel_size // 2)
        self.mish = Mish()
    def forward(self, x, mask): # noqa: F722
    def forward(self, x, mask=None): # noqa: F722
        if default_net().plugin_config.remove_input_padding:
            x = unsqueeze(x, 0)
        x = permute(x, [0, 2, 1])

View File

@@ -14,23 +14,22 @@ F5_TTS_HF_DOWNLOAD_PATH=./F5-TTS
F5_TTS_TRT_LLM_CHECKPOINT_PATH=./trtllm_ckpt
F5_TTS_TRT_LLM_ENGINE_PATH=./f5_trt_llm_engine
num_task=2
log_dir=./log_concurrent_tasks_${num_task}
vocoder_trt_engine_path=vocos_vocoder.plan
model_repo=./model_repo
if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
echo "Copying f5 tts trtllm files"
python_package_path=/usr/local/lib/python3.12/dist-packages
cp -r patch/* $python_package_path/tensorrt_llm/models
echo "Downloading f5 tts from huggingface"
huggingface-cli download SWivid/F5-TTS --local-dir $F5_TTS_HF_DOWNLOAD_PATH
fi
if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
echo "Downloading f5 tts from huggingface"
huggingface-cli download SWivid/F5-TTS --local-dir $F5_TTS_HF_DOWNLOAD_PATH
echo "Converting checkpoint"
python3 ./scripts/convert_checkpoint.py \
--timm_ckpt "$F5_TTS_HF_DOWNLOAD_PATH/$model/model_1200000.pt" \
--output_dir "$F5_TTS_TRT_LLM_CHECKPOINT_PATH" --model_name $model
python_package_path=/usr/local/lib/python3.12/dist-packages
cp -r patch/* $python_package_path/tensorrt_llm/models
trtllm-build --checkpoint_dir $F5_TTS_TRT_LLM_CHECKPOINT_PATH \
--max_batch_size 8 \
--output_dir $F5_TTS_TRT_LLM_ENGINE_PATH --remove_input_padding disable
@@ -58,5 +57,8 @@ fi
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
echo "Testing triton server"
num_task=1
log_dir=./log_concurrent_tasks_${num_task}
rm -r $log_dir
python3 client_grpc.py --num-tasks $num_task --huggingface-dataset yuekai/seed_tts --split-name wenetspeech4tts --log-dir $log_dir
fi