mirror of
https://github.com/SWivid/F5-TTS.git
synced 2025-12-12 15:50:07 -08:00
Update DiT checkpoint_activations and fix #399, #400
This commit is contained in:
@@ -28,6 +28,7 @@ model:
|
||||
ff_mult: 2
|
||||
text_dim: 512
|
||||
conv_layers: 4
|
||||
checkpoint_activations: False # recompute activations instead of storing them: saves memory at the cost of extra compute
|
||||
mel_spec:
|
||||
target_sample_rate: 24000
|
||||
n_mel_channels: 100
|
||||
|
||||
@@ -28,6 +28,7 @@ model:
|
||||
ff_mult: 2
|
||||
text_dim: 512
|
||||
conv_layers: 4
|
||||
checkpoint_activations: False # recompute activations instead of storing them: saves memory at the cost of extra compute
|
||||
mel_spec:
|
||||
target_sample_rate: 24000
|
||||
n_mel_channels: 100
|
||||
|
||||
@@ -105,6 +105,7 @@ class DiT(nn.Module):
|
||||
text_dim=None,
|
||||
conv_layers=0,
|
||||
long_skip_connection=False,
|
||||
checkpoint_activations=False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -127,6 +128,17 @@ class DiT(nn.Module):
|
||||
self.norm_out = AdaLayerNormZero_Final(dim) # final modulation
|
||||
self.proj_out = nn.Linear(dim, mel_dim)
|
||||
|
||||
self.checkpoint_activations = checkpoint_activations
|
||||
|
||||
def ckpt_wrapper(self, module):
    """Return a plain function that forwards its positional arguments to ``module``.

    The returned callable is what ``forward`` hands to
    ``torch.utils.checkpoint.checkpoint`` together with the block inputs.

    Adapted from
    https://github.com/chuanyangjin/fast-DiT/blob/1a8ecce58f346f877749f2dc67cdb190d295e4dc/models.py#L233-L237

    Args:
        module: The callable (a transformer block in ``forward``) to invoke.

    Returns:
        A function ``ckpt_forward(*inputs)`` that returns ``module(*inputs)``.
    """

    def ckpt_forward(*args):
        # Positional pass-through only; the result is returned unchanged.
        return module(*args)

    return ckpt_forward
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: float["b n d"], # noised input audio # noqa: F722
|
||||
@@ -152,7 +164,10 @@ class DiT(nn.Module):
|
||||
residual = x
|
||||
|
||||
for block in self.transformer_blocks:
|
||||
x = block(x, t, mask=mask, rope=rope)
|
||||
if self.checkpoint_activations:
|
||||
x = torch.utils.checkpoint.checkpoint(self.ckpt_wrapper(block), x, t, mask, rope)
|
||||
else:
|
||||
x = block(x, t, mask=mask, rope=rope)
|
||||
|
||||
if self.long_skip_connection is not None:
|
||||
x = self.long_skip_connection(torch.cat((x, residual), dim=-1))
|
||||
|
||||
Reference in New Issue
Block a user