mirror of
https://github.com/SWivid/F5-TTS.git
synced 2025-12-12 07:40:43 -08:00
formatting
This commit is contained in:
@@ -66,18 +66,18 @@ class TextEmbedding(nn.Module):
|
|||||||
|
|
||||||
valid_ind = torch.where(valid_mask)[0]
|
valid_ind = torch.where(valid_mask)[0]
|
||||||
valid_data = text[0, valid_ind, :] # [valid_len, text_dim]
|
valid_data = text[0, valid_ind, :] # [valid_len, text_dim]
|
||||||
|
|
||||||
base_repeat = audio_len // valid_len
|
base_repeat = audio_len // valid_len
|
||||||
remainder = audio_len % valid_len
|
remainder = audio_len % valid_len
|
||||||
|
|
||||||
indices = []
|
indices = []
|
||||||
for j in range(valid_len):
|
for j in range(valid_len):
|
||||||
repeat_count = base_repeat + (1 if j >= valid_len - remainder else 0)
|
repeat_count = base_repeat + (1 if j >= valid_len - remainder else 0)
|
||||||
indices.extend([j] * repeat_count)
|
indices.extend([j] * repeat_count)
|
||||||
|
|
||||||
indices = torch.tensor(indices[:audio_len], device=text.device, dtype=torch.long)
|
indices = torch.tensor(indices[:audio_len], device=text.device, dtype=torch.long)
|
||||||
upsampled = valid_data[indices] # [audio_len, text_dim]
|
upsampled = valid_data[indices] # [audio_len, text_dim]
|
||||||
|
|
||||||
upsampled_text[0, :audio_len, :] = upsampled
|
upsampled_text[0, :audio_len, :] = upsampled
|
||||||
|
|
||||||
return upsampled_text
|
return upsampled_text
|
||||||
@@ -245,7 +245,7 @@ class DiT(nn.Module):
|
|||||||
text_embed = self.text_embed(text, x.shape[1], drop_text=drop_text)
|
text_embed = self.text_embed(text, x.shape[1], drop_text=drop_text)
|
||||||
else:
|
else:
|
||||||
batch = x.shape[0]
|
batch = x.shape[0]
|
||||||
seq_lens = audio_mask.sum(dim=1) # Calculate the actual sequence length for each sample
|
seq_lens = audio_mask.sum(dim=1) # Calculate the actual sequence length for each sample
|
||||||
text_embed_list = []
|
text_embed_list = []
|
||||||
for i in range(batch):
|
for i in range(batch):
|
||||||
text_embed_i = self.text_embed(
|
text_embed_i = self.text_embed(
|
||||||
@@ -325,4 +325,4 @@ class DiT(nn.Module):
|
|||||||
x = self.norm_out(x, t)
|
x = self.norm_out(x, t)
|
||||||
output = self.proj_out(x)
|
output = self.proj_out(x)
|
||||||
|
|
||||||
return output
|
return output
|
||||||
|
|||||||
Reference in New Issue
Block a user