pytorch imple.: fix batch 1 inference from last commit

This commit is contained in:
SWivid
2025-10-22 00:31:56 +00:00
parent a0b8fb5df2
commit a17c5ae435

View File

@@ -238,6 +238,9 @@ class DiT(nn.Module):
audio_mask: bool["b n"] | None = None, # noqa: F722
):
if self.text_uncond is None or self.text_cond is None or not cache:
if audio_mask is None:
text_embed = self.text_embed(text, x.shape[1], drop_text=drop_text, audio_mask=audio_mask)
else:
batch = x.shape[0]
seq_lens = audio_mask.sum(dim=1)
text_embed_list = []