From a17c5ae4350757c75fd9f2be181edba78b865389 Mon Sep 17 00:00:00 2001 From: SWivid Date: Wed, 22 Oct 2025 00:31:56 +0000 Subject: [PATCH] pytorch imple.: fix batch 1 inference from last commit --- src/f5_tts/model/backbones/dit.py | 27 +++++++++++++++------------ 1 file changed, 15 insertions(+), 12 deletions(-) diff --git a/src/f5_tts/model/backbones/dit.py b/src/f5_tts/model/backbones/dit.py index 1c4b9bb..cf64255 100644 --- a/src/f5_tts/model/backbones/dit.py +++ b/src/f5_tts/model/backbones/dit.py @@ -238,18 +238,21 @@ class DiT(nn.Module): audio_mask: bool["b n"] | None = None, # noqa: F722 ): if self.text_uncond is None or self.text_cond is None or not cache: - batch = x.shape[0] - seq_lens = audio_mask.sum(dim=1) - text_embed_list = [] - for i in range(batch): - text_embed_i = self.text_embed( - text[i].unsqueeze(0), - seq_lens[i].item(), - drop_text=drop_text, - audio_mask=audio_mask, - ) - text_embed_list.append(text_embed_i[0]) - text_embed = pad_sequence(text_embed_list, batch_first=True, padding_value=0) + if audio_mask is None: + text_embed = self.text_embed(text, x.shape[1], drop_text=drop_text, audio_mask=audio_mask) + else: + batch = x.shape[0] + seq_lens = audio_mask.sum(dim=1) + text_embed_list = [] + for i in range(batch): + text_embed_i = self.text_embed( + text[i].unsqueeze(0), + seq_lens[i].item(), + drop_text=drop_text, + audio_mask=audio_mask, + ) + text_embed_list.append(text_embed_i[0]) + text_embed = pad_sequence(text_embed_list, batch_first=True, padding_value=0) if cache: if drop_text: self.text_uncond = text_embed