mirror of
https://github.com/SWivid/F5-TTS.git
synced 2025-12-12 07:40:43 -08:00
pytorch imple.: fix batch 1 inference from last commit
This commit is contained in:
@@ -238,18 +238,21 @@ class DiT(nn.Module):
|
|||||||
audio_mask: bool["b n"] | None = None, # noqa: F722
|
audio_mask: bool["b n"] | None = None, # noqa: F722
|
||||||
):
|
):
|
||||||
if self.text_uncond is None or self.text_cond is None or not cache:
|
if self.text_uncond is None or self.text_cond is None or not cache:
|
||||||
batch = x.shape[0]
|
if audio_mask is None:
|
||||||
seq_lens = audio_mask.sum(dim=1)
|
text_embed = self.text_embed(text, x.shape[1], drop_text=drop_text, audio_mask=audio_mask)
|
||||||
text_embed_list = []
|
else:
|
||||||
for i in range(batch):
|
batch = x.shape[0]
|
||||||
text_embed_i = self.text_embed(
|
seq_lens = audio_mask.sum(dim=1)
|
||||||
text[i].unsqueeze(0),
|
text_embed_list = []
|
||||||
seq_lens[i].item(),
|
for i in range(batch):
|
||||||
drop_text=drop_text,
|
text_embed_i = self.text_embed(
|
||||||
audio_mask=audio_mask,
|
text[i].unsqueeze(0),
|
||||||
)
|
seq_lens[i].item(),
|
||||||
text_embed_list.append(text_embed_i[0])
|
drop_text=drop_text,
|
||||||
text_embed = pad_sequence(text_embed_list, batch_first=True, padding_value=0)
|
audio_mask=audio_mask,
|
||||||
|
)
|
||||||
|
text_embed_list.append(text_embed_i[0])
|
||||||
|
text_embed = pad_sequence(text_embed_list, batch_first=True, padding_value=0)
|
||||||
if cache:
|
if cache:
|
||||||
if drop_text:
|
if drop_text:
|
||||||
self.text_uncond = text_embed
|
self.text_uncond = text_embed
|
||||||
|
|||||||
Reference in New Issue
Block a user