def average_upsample_text_by_mask(self, text, text_mask, audio_mask):
    """Stretch each sample's valid text embeddings evenly over its audio frames.

    For every batch element, the embeddings at positions where both masks are
    True are repeated so that together they span ``audio_len`` frames
    (``audio_len`` = number of True entries in that row of ``audio_mask``).
    Every valid token is repeated ``audio_len // valid_len`` times and the
    last ``audio_len % valid_len`` tokens get one extra repeat. The result is
    written to the first ``audio_len`` positions of the output row; all other
    positions stay zero.

    Args:
        text: embeddings; unpacked below as ``[batch, text_len, text_dim]``.
        text_mask: boolean ``[batch, text_len]``, True where the position
            holds a real (non-filler) token.
        audio_mask: boolean mask of the same shape, True on audio frames, or
            ``None`` to treat every position as an audio frame.
            NOTE(review): the output is written to ``[:audio_len]``, so this
            presumably expects a prefix-style mask -- confirm with callers.

    Returns:
        Tensor with the same shape, dtype and device as ``text``.
    """
    if audio_mask is None:
        audio_mask = torch.ones_like(text_mask, dtype=torch.bool)
    valid_mask = audio_mask & text_mask

    # One host transfer per length tensor instead of one `.item()` per sample.
    audio_lens = audio_mask.sum(dim=1).tolist()  # [batch]
    valid_lens = valid_mask.sum(dim=1).tolist()  # [batch]

    upsampled_text = torch.zeros_like(text)

    for i, (audio_len, valid_len) in enumerate(zip(audio_lens, valid_lens)):
        if valid_len == 0:
            # Nothing to spread out (e.g. dropped text): leave the row zeroed.
            continue

        valid_ind = torch.where(valid_mask[i])[0]
        valid_data = text[i, valid_ind, :]  # [valid_len, text_dim]

        base_repeat, remainder = divmod(audio_len, valid_len)

        # Per-token repeat counts; the trailing `remainder` tokens absorb the
        # leftover frames so the counts sum to exactly `audio_len`.
        repeats = torch.full((valid_len,), base_repeat, device=text.device, dtype=torch.long)
        repeats[valid_len - remainder :] += 1

        # `output_size` is known up front, which spares repeat_interleave from
        # having to read the sum of `repeats` back from the device.
        upsampled = valid_data.repeat_interleave(repeats, dim=0, output_size=audio_len)

        upsampled_text[i, :audio_len, :] = upsampled  # [audio_len, text_dim]

    return upsampled_text
InputEmbedding(mel_dim, text_dim, dim) @@ -190,19 +239,20 @@ class DiT(nn.Module): drop_audio_cond: bool = False, drop_text: bool = False, cache: bool = True, + audio_mask: bool["b n"] | None = None, # noqa: F722 ): seq_len = x.shape[1] if cache: if drop_text: if self.text_uncond is None: - self.text_uncond = self.text_embed(text, seq_len, drop_text=True) + self.text_uncond = self.text_embed(text, seq_len, drop_text=True, audio_mask=audio_mask) text_embed = self.text_uncond else: if self.text_cond is None: - self.text_cond = self.text_embed(text, seq_len, drop_text=False) + self.text_cond = self.text_embed(text, seq_len, drop_text=False, audio_mask=audio_mask) text_embed = self.text_cond else: - text_embed = self.text_embed(text, seq_len, drop_text=drop_text) + text_embed = self.text_embed(text, seq_len, drop_text=drop_text, audio_mask=audio_mask) x = self.input_embed(x, cond, text_embed, drop_audio_cond=drop_audio_cond) @@ -230,13 +280,19 @@ class DiT(nn.Module): # t: conditioning time, text: text, x: noised audio + cond audio + text t = self.time_embed(time) if cfg_infer: # pack cond & uncond forward: b n d -> 2b n d - x_cond = self.get_input_embed(x, cond, text, drop_audio_cond=False, drop_text=False, cache=cache) - x_uncond = self.get_input_embed(x, cond, text, drop_audio_cond=True, drop_text=True, cache=cache) + x_cond = self.get_input_embed( + x, cond, text, drop_audio_cond=False, drop_text=False, cache=cache, audio_mask=mask + ) + x_uncond = self.get_input_embed( + x, cond, text, drop_audio_cond=True, drop_text=True, cache=cache, audio_mask=mask + ) x = torch.cat((x_cond, x_uncond), dim=0) t = torch.cat((t, t), dim=0) mask = torch.cat((mask, mask), dim=0) if mask is not None else None else: - x = self.get_input_embed(x, cond, text, drop_audio_cond=drop_audio_cond, drop_text=drop_text, cache=cache) + x = self.get_input_embed( + x, cond, text, drop_audio_cond=drop_audio_cond, drop_text=drop_text, cache=cache, audio_mask=mask + ) rope = 
self.rotary_embed.forward_from_seq_len(seq_len)