mirror of
https://github.com/SWivid/F5-TTS.git
synced 2025-12-12 15:50:07 -08:00
Fix Average Upsampling
This commit is contained in:
@@ -51,43 +51,38 @@ class TextEmbedding(nn.Module):
|
|||||||
else:
|
else:
|
||||||
self.extra_modeling = False
|
self.extra_modeling = False
|
||||||
|
|
||||||
def average_upsample_text_by_mask(self, text, text_mask, audio_mask):
|
def average_upsample_text_by_mask(self, text, text_mask):
|
||||||
batch, text_len, text_dim = text.shape
|
batch, text_len, text_dim = text.shape
|
||||||
|
assert batch == 1
|
||||||
|
|
||||||
if audio_mask is None:
|
valid_mask = text_mask[0]
|
||||||
audio_mask = torch.ones_like(text_mask, dtype=torch.bool)
|
audio_len = text_len
|
||||||
valid_mask = audio_mask & text_mask
|
valid_len = valid_mask.sum().item()
|
||||||
audio_lens = audio_mask.sum(dim=1) # [batch]
|
|
||||||
valid_lens = valid_mask.sum(dim=1) # [batch]
|
if valid_len == 0:
|
||||||
|
return torch.zeros_like(text)
|
||||||
|
|
||||||
upsampled_text = torch.zeros_like(text)
|
upsampled_text = torch.zeros_like(text)
|
||||||
|
|
||||||
for i in range(batch):
|
valid_ind = torch.where(valid_mask)[0]
|
||||||
audio_len = audio_lens[i].item()
|
valid_data = text[0, valid_ind, :] # [valid_len, text_dim]
|
||||||
valid_len = valid_lens[i].item()
|
|
||||||
|
base_repeat = audio_len // valid_len
|
||||||
if valid_len == 0:
|
remainder = audio_len % valid_len
|
||||||
continue
|
|
||||||
|
indices = []
|
||||||
valid_ind = torch.where(valid_mask[i])[0]
|
for j in range(valid_len):
|
||||||
valid_data = text[i, valid_ind, :] # [valid_len, text_dim]
|
repeat_count = base_repeat + (1 if j >= valid_len - remainder else 0)
|
||||||
|
indices.extend([j] * repeat_count)
|
||||||
base_repeat = audio_len // valid_len
|
|
||||||
remainder = audio_len % valid_len
|
indices = torch.tensor(indices[:audio_len], device=text.device, dtype=torch.long)
|
||||||
|
upsampled = valid_data[indices] # [audio_len, text_dim]
|
||||||
indices = []
|
|
||||||
for j in range(valid_len):
|
upsampled_text[0, :audio_len, :] = upsampled
|
||||||
repeat_count = base_repeat + (1 if j >= valid_len - remainder else 0)
|
|
||||||
indices.extend([j] * repeat_count)
|
|
||||||
|
|
||||||
indices = torch.tensor(indices[:audio_len], device=text.device, dtype=torch.long)
|
|
||||||
upsampled = valid_data[indices] # [audio_len, text_dim]
|
|
||||||
|
|
||||||
upsampled_text[i, :audio_len, :] = upsampled
|
|
||||||
|
|
||||||
return upsampled_text
|
return upsampled_text
|
||||||
|
|
||||||
def forward(self, text: int["b nt"], seq_len, drop_text=False, audio_mask: bool["b n"] | None = None):
|
def forward(self, text: int["b nt"], seq_len, drop_text=False):
|
||||||
text = text + 1 # use 0 as filler token. preprocess of batch pad -1, see list_str_to_idx()
|
text = text + 1 # use 0 as filler token. preprocess of batch pad -1, see list_str_to_idx()
|
||||||
text = text[:, :seq_len] # curtail if character tokens are more than the mel spec tokens
|
text = text[:, :seq_len] # curtail if character tokens are more than the mel spec tokens
|
||||||
text = F.pad(text, (0, seq_len - text.shape[1]), value=0) # (opt.) if not self.average_upsampling:
|
text = F.pad(text, (0, seq_len - text.shape[1]), value=0) # (opt.) if not self.average_upsampling:
|
||||||
@@ -114,7 +109,7 @@ class TextEmbedding(nn.Module):
|
|||||||
text = self.text_blocks(text)
|
text = self.text_blocks(text)
|
||||||
|
|
||||||
if self.average_upsampling:
|
if self.average_upsampling:
|
||||||
text = self.average_upsample_text_by_mask(text, ~text_mask, audio_mask)
|
text = self.average_upsample_text_by_mask(text, ~text_mask)
|
||||||
|
|
||||||
return text
|
return text
|
||||||
|
|
||||||
@@ -247,17 +242,16 @@ class DiT(nn.Module):
|
|||||||
):
|
):
|
||||||
if self.text_uncond is None or self.text_cond is None or not cache:
|
if self.text_uncond is None or self.text_cond is None or not cache:
|
||||||
if audio_mask is None:
|
if audio_mask is None:
|
||||||
text_embed = self.text_embed(text, x.shape[1], drop_text=drop_text, audio_mask=audio_mask)
|
text_embed = self.text_embed(text, x.shape[1], drop_text=drop_text)
|
||||||
else:
|
else:
|
||||||
batch = x.shape[0]
|
batch = x.shape[0]
|
||||||
seq_lens = audio_mask.sum(dim=1)
|
seq_lens = audio_mask.sum(dim=1) # Calculate the actual sequence length for each sample
|
||||||
text_embed_list = []
|
text_embed_list = []
|
||||||
for i in range(batch):
|
for i in range(batch):
|
||||||
text_embed_i = self.text_embed(
|
text_embed_i = self.text_embed(
|
||||||
text[i].unsqueeze(0),
|
text[i].unsqueeze(0),
|
||||||
seq_lens[i].item(),
|
seq_len=seq_lens[i].item(),
|
||||||
drop_text=drop_text,
|
drop_text=drop_text,
|
||||||
audio_mask=audio_mask,
|
|
||||||
)
|
)
|
||||||
text_embed_list.append(text_embed_i[0])
|
text_embed_list.append(text_embed_i[0])
|
||||||
text_embed = pad_sequence(text_embed_list, batch_first=True, padding_value=0)
|
text_embed = pad_sequence(text_embed_list, batch_first=True, padding_value=0)
|
||||||
@@ -331,4 +325,4 @@ class DiT(nn.Module):
|
|||||||
x = self.norm_out(x, t)
|
x = self.norm_out(x, t)
|
||||||
output = self.proj_out(x)
|
output = self.proj_out(x)
|
||||||
|
|
||||||
return output
|
return output
|
||||||
Reference in New Issue
Block a user