pytorch implementation: fix batch-1 inference from last commit

SWivid
2025-10-22 00:31:56 +00:00
parent a0b8fb5df2
commit a17c5ae435


@@ -238,18 +238,21 @@ class DiT(nn.Module):
         audio_mask: bool["b n"] | None = None,  # noqa: F722
     ):
         if self.text_uncond is None or self.text_cond is None or not cache:
-            batch = x.shape[0]
-            seq_lens = audio_mask.sum(dim=1)
-            text_embed_list = []
-            for i in range(batch):
-                text_embed_i = self.text_embed(
-                    text[i].unsqueeze(0),
-                    seq_lens[i].item(),
-                    drop_text=drop_text,
-                    audio_mask=audio_mask,
-                )
-                text_embed_list.append(text_embed_i[0])
-            text_embed = pad_sequence(text_embed_list, batch_first=True, padding_value=0)
+            if audio_mask is None:
+                text_embed = self.text_embed(text, x.shape[1], drop_text=drop_text, audio_mask=audio_mask)
+            else:
+                batch = x.shape[0]
+                seq_lens = audio_mask.sum(dim=1)
+                text_embed_list = []
+                for i in range(batch):
+                    text_embed_i = self.text_embed(
+                        text[i].unsqueeze(0),
+                        seq_lens[i].item(),
+                        drop_text=drop_text,
+                        audio_mask=audio_mask,
+                    )
+                    text_embed_list.append(text_embed_i[0])
+                text_embed = pad_sequence(text_embed_list, batch_first=True, padding_value=0)
         if cache:
             if drop_text:
                 self.text_uncond = text_embed
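
For context, a minimal standalone sketch of the control flow this fix introduces. The function `build_text_embed` and its `text_embed` argument are hypothetical stand-ins for the module's `self.text_embed`; the point is the new `audio_mask is None` fast path, which restores batch-1 inference (the previous code called `audio_mask.sum(dim=1)` unconditionally, which fails when no mask is passed).

from torch.nn.utils.rnn import pad_sequence

def build_text_embed(text_embed, text, x, drop_text, audio_mask=None):
    # Sketch of the branch added by this commit; `text_embed` stands in
    # for DiT's self.text_embed and is assumed to return a (1, n, d) tensor.
    if audio_mask is None:
        # Batch-1 inference: no padding mask, so embed the whole batch
        # at the full sequence length in one call.
        return text_embed(text, x.shape[1], drop_text=drop_text, audio_mask=audio_mask)
    # Batched path: samples have different true lengths, so embed each
    # sample at its own length, then re-pad to a common length.
    seq_lens = audio_mask.sum(dim=1)
    embeds = [
        text_embed(text[i].unsqueeze(0), seq_lens[i].item(),
                   drop_text=drop_text, audio_mask=audio_mask)[0]
        for i in range(x.shape[0])
    ]
    return pad_sequence(embeds, batch_first=True, padding_value=0)

Embedding per sample keeps each sequence's text embedding aligned to its true audio length before pad_sequence restores a rectangular batch.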