pytorch impl.: fix batch-1 inference from last commit

Author: SWivid
Date: 2025-10-22 00:31:56 +00:00
Parent: a0b8fb5df2
Commit: a17c5ae435


@@ -238,18 +238,21 @@ class DiT(nn.Module):
         audio_mask: bool["b n"] | None = None,  # noqa: F722
     ):
         if self.text_uncond is None or self.text_cond is None or not cache:
-            batch = x.shape[0]
-            seq_lens = audio_mask.sum(dim=1)
-            text_embed_list = []
-            for i in range(batch):
-                text_embed_i = self.text_embed(
-                    text[i].unsqueeze(0),
-                    seq_lens[i].item(),
-                    drop_text=drop_text,
-                    audio_mask=audio_mask,
-                )
-                text_embed_list.append(text_embed_i[0])
-            text_embed = pad_sequence(text_embed_list, batch_first=True, padding_value=0)
+            if audio_mask is None:
+                text_embed = self.text_embed(text, x.shape[1], drop_text=drop_text, audio_mask=audio_mask)
+            else:
+                batch = x.shape[0]
+                seq_lens = audio_mask.sum(dim=1)
+                text_embed_list = []
+                for i in range(batch):
+                    text_embed_i = self.text_embed(
+                        text[i].unsqueeze(0),
+                        seq_lens[i].item(),
+                        drop_text=drop_text,
+                        audio_mask=audio_mask,
+                    )
+                    text_embed_list.append(text_embed_i[0])
+                text_embed = pad_sequence(text_embed_list, batch_first=True, padding_value=0)
         if cache:
             if drop_text:
                 self.text_uncond = text_embed
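
For context: in the previous commit the masked, per-sample embedding path ran unconditionally, so batch-1 inference, which passes audio_mask=None, crashed on audio_mask.sum(dim=1) (AttributeError on None). The fix guards that path and falls back to a single full-length text_embed call when no mask is present. Below is a minimal, runnable sketch of the repaired control flow; FakeTextEmbed and embed_text are hypothetical stand-ins for the module's real text_embed submodule and the surrounding forward logic, not the actual DiT code:

import torch
from torch.nn.utils.rnn import pad_sequence


class FakeTextEmbed:
    """Stub: returns a (batch, seq_len, dim) embedding of the requested length."""

    dim = 4

    def __call__(self, text, seq_len, drop_text=False, audio_mask=None):
        return torch.zeros(text.shape[0], seq_len, self.dim)


def embed_text(text, x, audio_mask, text_embed, drop_text=False):
    if audio_mask is None:
        # Batch-1 / unpadded inference: before the fix, control fell through
        # to audio_mask.sum(dim=1) below and raised AttributeError on None.
        return text_embed(text, x.shape[1], drop_text=drop_text, audio_mask=audio_mask)
    # Padded batches: embed each sample at its true length, then re-pad.
    seq_lens = audio_mask.sum(dim=1)
    text_embed_list = [
        text_embed(
            text[i].unsqueeze(0),
            seq_lens[i].item(),
            drop_text=drop_text,
            audio_mask=audio_mask,
        )[0]
        for i in range(x.shape[0])
    ]
    return pad_sequence(text_embed_list, batch_first=True, padding_value=0)


x = torch.randn(1, 10, 4)                     # (batch=1, audio frames, dim)
text = torch.zeros(1, 6, dtype=torch.long)    # dummy token ids
out = embed_text(text, x, audio_mask=None, text_embed=FakeTextEmbed())
print(out.shape)  # torch.Size([1, 10, 4])

The per-sample loop embeds each utterance at its true length before re-padding, so padding frames never leak into the text embedding; the None guard simply restores the direct single-call path when no padding exists.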