update DiT checkpoint_activations and fix #399 #400

ZhikangNiu
2024-12-16 13:24:58 +08:00
parent c6e96d0c83
commit 7c84e91a00
3 changed files with 18 additions and 1 deletion

View File

@@ -28,6 +28,7 @@ model:
     ff_mult: 2
     text_dim: 512
     conv_layers: 4
+    checkpoint_activations: False  # recompute activations to save memory, at the cost of extra compute
   mel_spec:
     target_sample_rate: 24000
     n_mel_channels: 100

View File

@@ -28,6 +28,7 @@ model:
     ff_mult: 2
     text_dim: 512
     conv_layers: 4
+    checkpoint_activations: False  # recompute activations to save memory, at the cost of extra compute
   mel_spec:
     target_sample_rate: 24000
     n_mel_channels: 100
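For context, a minimal sketch (not the repo's training entry point) of how the new flag could be read from the YAML and handed to the model. It assumes the flag sits in the model.arch block next to ff_mult and conv_layers, as in these hunks; the DiT(...) call in the trailing comment is an illustrative placeholder, not the real constructor signature.

import yaml  # PyYAML

# Hypothetical excerpt mirroring the config hunks above; only the keys shown there are from the diff.
cfg_text = """
model:
  arch:
    ff_mult: 2
    text_dim: 512
    conv_layers: 4
    checkpoint_activations: False
  mel_spec:
    target_sample_rate: 24000
    n_mel_channels: 100
"""

cfg = yaml.safe_load(cfg_text)
# Default to False so existing configs without the key keep the old behavior.
use_ckpt = cfg["model"]["arch"].get("checkpoint_activations", False)
print(use_ckpt)  # -> False

# The flag is then forwarded to the backbone, which now accepts it (see the DiT hunk below), e.g.:
# model = DiT(..., conv_layers=cfg["model"]["arch"]["conv_layers"], checkpoint_activations=use_ckpt)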

View File

@@ -105,6 +105,7 @@ class DiT(nn.Module):
         text_dim=None,
         conv_layers=0,
         long_skip_connection=False,
+        checkpoint_activations=False,
     ):
         super().__init__()
@@ -127,6 +128,17 @@ class DiT(nn.Module):
         self.norm_out = AdaLayerNormZero_Final(dim)  # final modulation
         self.proj_out = nn.Linear(dim, mel_dim)
 
+        self.checkpoint_activations = checkpoint_activations
+
+    def ckpt_wrapper(self, module):
+        """Code from https://github.com/chuanyangjin/fast-DiT/blob/1a8ecce58f346f877749f2dc67cdb190d295e4dc/models.py#L233-L237"""
+
+        def ckpt_forward(*inputs):
+            outputs = module(*inputs)
+            return outputs
+
+        return ckpt_forward
+
     def forward(
         self,
         x: float["b n d"],  # noised input audio  # noqa: F722
@@ -152,7 +164,10 @@ class DiT(nn.Module):
         residual = x
 
         for block in self.transformer_blocks:
-            x = block(x, t, mask=mask, rope=rope)
+            if self.checkpoint_activations:
+                x = torch.utils.checkpoint.checkpoint(self.ckpt_wrapper(block), x, t, mask, rope)
+            else:
+                x = block(x, t, mask=mask, rope=rope)
 
         if self.long_skip_connection is not None:
             x = self.long_skip_connection(torch.cat((x, residual), dim=-1))
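For reference, a self-contained sketch of the checkpointing pattern the forward pass now uses: torch.utils.checkpoint.checkpoint frees a block's intermediate activations after the forward pass and recomputes them during backward, trading extra compute for lower memory. ToyBlock, its sizes, and use_reentrant=False are assumptions made up for this demo and are not part of the repository (the diff calls checkpoint with its default arguments); the wrapper mirrors ckpt_wrapper above, and mask/rope are passed positionally because older, reentrant checkpoint implementations only forward positional arguments to the wrapped callable.

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class ToyBlock(nn.Module):
    """Stand-in with the same call shape as a transformer block here: forward(x, t, mask=None, rope=None)."""

    def __init__(self, dim=16):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, x, t, mask=None, rope=None):
        # Toy computation; real blocks would apply attention/conv with mask and rope.
        return self.proj(x) + t.unsqueeze(1)

def ckpt_wrapper(module):
    # checkpoint() re-runs this callable during backward, so inputs are passed positionally.
    def ckpt_forward(*inputs):
        return module(*inputs)

    return ckpt_forward

block = ToyBlock()
x = torch.randn(2, 8, 16, requires_grad=True)  # (batch, seq, dim), sizes are arbitrary
t = torch.randn(2, 16)

# Activations inside the block are freed after forward and recomputed in backward.
y = checkpoint(ckpt_wrapper(block), x, t, None, None, use_reentrant=False)
y.sum().backward()
print(x.grad.shape)  # torch.Size([2, 8, 16])

With checkpoint_activations: True in the config, the same trade-off is applied to every transformer block in the loop above.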