diff --git a/src/f5_tts/model/backbones/dit.py b/src/f5_tts/model/backbones/dit.py
index 4e49b7b..47b0f1d 100644
--- a/src/f5_tts/model/backbones/dit.py
+++ b/src/f5_tts/model/backbones/dit.py
@@ -66,18 +66,18 @@ class TextEmbedding(nn.Module):
 
                 valid_ind = torch.where(valid_mask)[0]
                 valid_data = text[0, valid_ind, :]  # [valid_len, text_dim]
-                
+
                 base_repeat = audio_len // valid_len
                 remainder = audio_len % valid_len
-                
+
                 indices = []
                 for j in range(valid_len):
                     repeat_count = base_repeat + (1 if j >= valid_len - remainder else 0)
                     indices.extend([j] * repeat_count)
-                
+
                 indices = torch.tensor(indices[:audio_len], device=text.device, dtype=torch.long)
                 upsampled = valid_data[indices]  # [audio_len, text_dim]
-                
+
                 upsampled_text[0, :audio_len, :] = upsampled
 
         return upsampled_text
@@ -245,7 +245,7 @@ class DiT(nn.Module):
             text_embed = self.text_embed(text, x.shape[1], drop_text=drop_text)
         else:
             batch = x.shape[0]
-            seq_lens = audio_mask.sum(dim=1) # Calculate the actual sequence length for each sample
+            seq_lens = audio_mask.sum(dim=1)  # Calculate the actual sequence length for each sample
             text_embed_list = []
             for i in range(batch):
                 text_embed_i = self.text_embed(
@@ -325,4 +325,4 @@ class DiT(nn.Module):
         x = self.norm_out(x, t)
         output = self.proj_out(x)
 
-        return output
\ No newline at end of file
+        return output