mirror of
https://github.com/SWivid/F5-TTS.git
synced 2025-12-12 15:50:07 -08:00
fix bug
This commit is contained in:
@@ -35,9 +35,9 @@ python3 client_grpc.py --num-tasks $num_task --huggingface-dataset yuekai/seed_t
|
||||
### Benchmark Results
|
||||
Decoding on a single L20 GPU, using 26 different prompt_audio/target_text pairs.
|
||||
|
||||
| Model | Note | Concurrency | Avg Latency | RTF |
|
||||
|-------|-----------|-----------------------|---------|--|
|
||||
| F5-TTS Base (Vocos) | [Code Commit](https://github.com/yuekaizhang/sherpa/tree/329ab3c573252e835844bea38505c6b43e994cf4/triton/f5_tts) | 1 | 253 ms | 0.0394|
|
||||
| Model | Concurrency | Avg Latency | RTF |
|
||||
|-------|-------------|-----------------|--|
|
||||
| F5-TTS Base (Vocos) | 1 | 253 ms | 0.0394|
|
||||
|
||||
### Credits
|
||||
1. [F5-TTS-TRTLLM](https://github.com/Bigfishering/f5-tts-trtllm)
|
||||
@@ -245,6 +245,7 @@ async def send(
|
||||
model_name: str,
|
||||
padding_duration: int = None,
|
||||
audio_save_dir: str = "./",
|
||||
save_sample_rate: int = 16000,
|
||||
):
|
||||
total_duration = 0.0
|
||||
latency_data = []
|
||||
@@ -267,7 +268,9 @@ async def send(
|
||||
samples = np.zeros(
|
||||
(
|
||||
1,
|
||||
padding_duration * sample_rate * ((int(duration) // padding_duration) + 1),
|
||||
padding_duration
|
||||
* sample_rate
|
||||
* ((int(estimated_target_duration + duration) // padding_duration) + 1),
|
||||
),
|
||||
dtype=np.float32,
|
||||
)
|
||||
@@ -306,7 +309,7 @@ async def send(
|
||||
end = time.time() - start
|
||||
|
||||
audio_save_path = os.path.join(audio_save_dir, f"{item['target_audio_path']}.wav")
|
||||
sf.write(audio_save_path, audio, 16000, "PCM_16")
|
||||
sf.write(audio_save_path, audio, save_sample_rate, "PCM_16")
|
||||
|
||||
latency_data.append((end, estimated_target_duration))
|
||||
total_duration += estimated_target_duration
|
||||
@@ -413,7 +416,8 @@ async def main():
|
||||
log_interval=args.log_interval,
|
||||
model_name=args.model_name,
|
||||
audio_save_dir=args.log_dir,
|
||||
padding_duration=1.0,
|
||||
padding_duration=1,
|
||||
save_sample_rate=24000 if args.model_name == "f5_tts" else 16000,
|
||||
)
|
||||
)
|
||||
tasks.append(task)
|
||||
|
||||
@@ -158,7 +158,7 @@ class TritonPythonModel:
|
||||
return mel.transpose(1, 2)
|
||||
|
||||
def forward_vocoder(self, mel):
|
||||
mel = mel.to(torch.float32).contiguous()
|
||||
mel = mel.to(torch.float32).contiguous().cpu()
|
||||
input_tensor_0 = pb_utils.Tensor.from_dlpack("mel", to_dlpack(mel))
|
||||
|
||||
inference_request = pb_utils.InferenceRequest(
|
||||
|
||||
@@ -14,9 +14,9 @@
|
||||
|
||||
name: "f5_tts"
|
||||
backend: "python"
|
||||
max_batch_size: 1
|
||||
max_batch_size: 4
|
||||
dynamic_batching {
|
||||
max_queue_delay_microseconds: 1
|
||||
max_queue_delay_microseconds: 1000
|
||||
}
|
||||
parameters [
|
||||
{
|
||||
|
||||
@@ -30,8 +30,7 @@ class InputEmbedding(Module):
|
||||
self.proj = Linear(mel_dim * 2 + text_dim, out_dim)
|
||||
self.conv_pos_embed = ConvPositionEmbedding(dim=out_dim)
|
||||
|
||||
def forward(self, x, cond, drop_audio_cond=False):
|
||||
# if drop_audio_cond: # cfg for cond audio
|
||||
def forward(self, x, cond):
|
||||
x = self.proj(concat([x, cond], dim=-1))
|
||||
return self.conv_pos_embed(x) + x
|
||||
|
||||
@@ -41,9 +40,8 @@ class F5TTS(PretrainedModel):
|
||||
super().__init__(config)
|
||||
self.dtype = str_dtype_to_trt(config.dtype)
|
||||
|
||||
self.time_embed = TimestepEmbedding(config.hidden_size) # √
|
||||
text_dim = config.mel_dim if config.text_dim is None else config.text_dim
|
||||
self.input_embed = InputEmbedding(config.mel_dim, text_dim, config.hidden_size)
|
||||
self.time_embed = TimestepEmbedding(config.hidden_size)
|
||||
self.input_embed = InputEmbedding(config.mel_dim, config.text_dim, config.hidden_size)
|
||||
|
||||
self.dim = config.hidden_size
|
||||
self.depth = config.num_hidden_layers
|
||||
|
||||
@@ -93,7 +93,7 @@ class ConvPositionEmbedding(Module):
|
||||
self.conv1d2 = Conv1d(dim, dim, kernel_size, groups=groups, padding=kernel_size // 2)
|
||||
self.mish = Mish()
|
||||
|
||||
def forward(self, x, mask): # noqa: F722
|
||||
def forward(self, x, mask=None): # noqa: F722
|
||||
if default_net().plugin_config.remove_input_padding:
|
||||
x = unsqueeze(x, 0)
|
||||
x = permute(x, [0, 2, 1])
|
||||
|
||||
@@ -14,23 +14,22 @@ F5_TTS_HF_DOWNLOAD_PATH=./F5-TTS
|
||||
F5_TTS_TRT_LLM_CHECKPOINT_PATH=./trtllm_ckpt
|
||||
F5_TTS_TRT_LLM_ENGINE_PATH=./f5_trt_llm_engine
|
||||
|
||||
num_task=2
|
||||
log_dir=./log_concurrent_tasks_${num_task}
|
||||
vocoder_trt_engine_path=vocos_vocoder.plan
|
||||
model_repo=./model_repo
|
||||
|
||||
if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
|
||||
echo "Copying f5 tts trtllm files"
|
||||
python_package_path=/usr/local/lib/python3.12/dist-packages
|
||||
cp -r patch/* $python_package_path/tensorrt_llm/models
|
||||
echo "Downloading f5 tts from huggingface"
|
||||
huggingface-cli download SWivid/F5-TTS --local-dir $F5_TTS_HF_DOWNLOAD_PATH
|
||||
|
||||
fi
|
||||
|
||||
if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
|
||||
echo "Downloading f5 tts from huggingface"
|
||||
huggingface-cli download SWivid/F5-TTS --local-dir $F5_TTS_HF_DOWNLOAD_PATH
|
||||
echo "Converting checkpoint"
|
||||
python3 ./scripts/convert_checkpoint.py \
|
||||
--timm_ckpt "$F5_TTS_HF_DOWNLOAD_PATH/$model/model_1200000.pt" \
|
||||
--output_dir "$F5_TTS_TRT_LLM_CHECKPOINT_PATH" --model_name $model
|
||||
python_package_path=/usr/local/lib/python3.12/dist-packages
|
||||
cp -r patch/* $python_package_path/tensorrt_llm/models
|
||||
trtllm-build --checkpoint_dir $F5_TTS_TRT_LLM_CHECKPOINT_PATH \
|
||||
--max_batch_size 8 \
|
||||
--output_dir $F5_TTS_TRT_LLM_ENGINE_PATH --remove_input_padding disable
|
||||
@@ -58,5 +57,8 @@ fi
|
||||
|
||||
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
|
||||
echo "Testing triton server"
|
||||
num_task=1
|
||||
log_dir=./log_concurrent_tasks_${num_task}
|
||||
rm -r $log_dir
|
||||
python3 client_grpc.py --num-tasks $num_task --huggingface-dataset yuekai/seed_tts --split-name wenetspeech4tts --log-dir $log_dir
|
||||
fi
|
||||
Reference in New Issue
Block a user