diff --git a/src/f5_tts/model/backbones/dit.py b/src/f5_tts/model/backbones/dit.py
index 1c4b9bb..cf64255 100644
--- a/src/f5_tts/model/backbones/dit.py
+++ b/src/f5_tts/model/backbones/dit.py
@@ -238,18 +238,21 @@ class DiT(nn.Module):
         audio_mask: bool["b n"] | None = None,  # noqa: F722
     ):
         if self.text_uncond is None or self.text_cond is None or not cache:
-            batch = x.shape[0]
-            seq_lens = audio_mask.sum(dim=1)
-            text_embed_list = []
-            for i in range(batch):
-                text_embed_i = self.text_embed(
-                    text[i].unsqueeze(0),
-                    seq_lens[i].item(),
-                    drop_text=drop_text,
-                    audio_mask=audio_mask,
-                )
-                text_embed_list.append(text_embed_i[0])
-            text_embed = pad_sequence(text_embed_list, batch_first=True, padding_value=0)
+            if audio_mask is None:
+                text_embed = self.text_embed(text, x.shape[1], drop_text=drop_text, audio_mask=audio_mask)
+            else:
+                batch = x.shape[0]
+                seq_lens = audio_mask.sum(dim=1)
+                text_embed_list = []
+                for i in range(batch):
+                    text_embed_i = self.text_embed(
+                        text[i].unsqueeze(0),
+                        seq_lens[i].item(),
+                        drop_text=drop_text,
+                        audio_mask=audio_mask,
+                    )
+                    text_embed_list.append(text_embed_i[0])
+                text_embed = pad_sequence(text_embed_list, batch_first=True, padding_value=0)
         if cache:
             if drop_text:
                 self.text_uncond = text_embed