Add flash_attn2 support attn_mask, minor fixes (#1066)

* add flash attn2 support
* update flash attn config in F5TTS
* fix minor bug of get the length of ref_mel

---------

Co-authored-by: SWivid <swivid@qq.com>
This commit is contained in:
Zhikang Niu
2025-06-11 12:14:32 +08:00
committed by GitHub
parent c6ebad0220
commit 0914170e98
10 changed files with 111 additions and 33 deletions

View File

@@ -31,6 +31,8 @@ model:
text_mask_padding: False
conv_layers: 4
pe_attn_head: 1
attn_backend: torch # torch | flash_attn
attn_mask_enabled: False
checkpoint_activations: False # recompute activations and save memory for extra compute
mel_spec:
target_sample_rate: 24000

View File

@@ -31,6 +31,8 @@ model:
text_mask_padding: False
conv_layers: 4
pe_attn_head: 1
attn_backend: torch # torch | flash_attn
attn_mask_enabled: False
checkpoint_activations: False # recompute activations and save memory for extra compute
mel_spec:
target_sample_rate: 24000

View File

@@ -32,6 +32,8 @@ model:
qk_norm: null # null | rms_norm
conv_layers: 4
pe_attn_head: null
attn_backend: torch # torch | flash_attn
attn_mask_enabled: False
checkpoint_activations: False # recompute activations and save memory for extra compute
mel_spec:
target_sample_rate: 24000

View File

@@ -148,10 +148,15 @@ def main():
vocab_char_map=vocab_char_map,
).to(device)
ckpt_path = rel_path + f"/ckpts/{exp_name}/model_{ckpt_step}.pt"
if not os.path.exists(ckpt_path):
ckpt_prefix = rel_path + f"/ckpts/{exp_name}/model_{ckpt_step}"
if os.path.exists(ckpt_prefix + ".pt"):
ckpt_path = ckpt_prefix + ".pt"
elif os.path.exists(ckpt_prefix + ".safetensors"):
ckpt_path = ckpt_prefix + ".safetensors"
else:
print("Loading from self-organized training checkpoints rather than released pretrained.")
ckpt_path = rel_path + f"/{model_cfg.ckpts.save_dir}/model_{ckpt_step}.pt"
dtype = torch.float32 if mel_spec_type == "bigvgan" else None
model = load_checkpoint(model, ckpt_path, device, dtype=dtype, use_ema=use_ema)

View File

@@ -126,8 +126,13 @@ def get_inference_prompt(
else:
text_list = text
# to mel spectrogram
ref_mel = mel_spectrogram(ref_audio)
ref_mel = ref_mel.squeeze(0)
# Duration, mel frame length
ref_mel_len = ref_audio.shape[-1] // hop_length
ref_mel_len = ref_mel.shape[-1]
if use_truth_duration:
gt_audio, gt_sr = torchaudio.load(gt_wav)
if gt_sr != target_sample_rate:
@@ -142,10 +147,6 @@ def get_inference_prompt(
gen_text_len = len(gt_text.encode("utf-8"))
total_mel_len = ref_mel_len + int(ref_mel_len / ref_text_len * gen_text_len / speed)
# to mel spectrogram
ref_mel = mel_spectrogram(ref_audio)
ref_mel = ref_mel.squeeze(0)
# deal with batch
assert infer_batch_size > 0, "infer_batch_size should be greater than 0."
assert min_tokens <= total_mel_len <= max_tokens, (

View File

@@ -116,6 +116,8 @@ class DiT(nn.Module):
qk_norm=None,
conv_layers=0,
pe_attn_head=None,
attn_backend="torch", # "torch" | "flash_attn"
attn_mask_enabled=False,
long_skip_connection=False,
checkpoint_activations=False,
):
@@ -145,6 +147,8 @@ class DiT(nn.Module):
dropout=dropout,
qk_norm=qk_norm,
pe_attn_head=pe_attn_head,
attn_backend=attn_backend,
attn_mask_enabled=attn_mask_enabled,
)
for _ in range(depth)
]

View File

@@ -275,10 +275,9 @@ class CFM(nn.Module):
else:
drop_text = False
# if want rigorously mask out padding, record in collate_fn in dataset.py, and pass in here
# adding mask will use more memory, thus also need to adjust batchsampler with scaled down threshold for long sequences
# apply mask will use more memory; might adjust batchsize or batchsampler long sequence threshold
pred = self.transformer(
x=φ, cond=cond, text=text, time=time, drop_audio_cond=drop_audio_cond, drop_text=drop_text
x=φ, cond=cond, text=text, time=time, drop_audio_cond=drop_audio_cond, drop_text=drop_text, mask=mask
)
# flow matching loss

View File

@@ -312,7 +312,7 @@ def collate_fn(batch):
max_mel_length = mel_lengths.amax()
padded_mel_specs = []
for spec in mel_specs: # TODO. maybe records mask for attention here
for spec in mel_specs:
padding = (0, max_mel_length - spec.size(-1))
padded_spec = F.pad(spec, padding, value=0)
padded_mel_specs.append(padded_spec)
@@ -324,7 +324,7 @@ def collate_fn(batch):
return dict(
mel=mel_specs,
mel_lengths=mel_lengths,
mel_lengths=mel_lengths, # records for padding mask
text=text,
text_lengths=text_lengths,
)

View File

@@ -6,6 +6,7 @@ nt - text sequence
nw - raw wave length
d - dimension
"""
# flake8: noqa
from __future__ import annotations
@@ -19,6 +20,8 @@ from librosa.filters import mel as librosa_mel_fn
from torch import nn
from x_transformers.x_transformers import apply_rotary_pos_emb
from f5_tts.model.utils import is_package_available
# raw wav to mel spec
@@ -175,7 +178,7 @@ class ConvPositionEmbedding(nn.Module):
nn.Mish(),
)
def forward(self, x: float["b n d"], mask: bool["b n"] | None = None): # noqa: F722
def forward(self, x: float["b n d"], mask: bool["b n"] | None = None):
if mask is not None:
mask = mask[..., None]
x = x.masked_fill(~mask, 0.0)
@@ -417,9 +420,9 @@ class Attention(nn.Module):
def forward(
self,
x: float["b n d"], # noised input x # noqa: F722
c: float["b n d"] = None, # context c # noqa: F722
mask: bool["b n"] | None = None, # noqa: F722
x: float["b n d"], # noised input x
c: float["b n d"] = None, # context c
mask: bool["b n"] | None = None,
rope=None, # rotary position embedding for x
c_rope=None, # rotary position embedding for c
) -> torch.Tensor:
@@ -431,19 +434,30 @@ class Attention(nn.Module):
# Attention processor
if is_package_available("flash_attn"):
from flash_attn.bert_padding import pad_input, unpad_input
from flash_attn import flash_attn_varlen_func, flash_attn_func
class AttnProcessor:
def __init__(
self,
pe_attn_head: int | None = None, # number of attention head to apply rope, None for all
attn_backend: str = "flash_attn",
attn_mask_enabled: bool = True,
):
if attn_backend == "flash_attn":
assert is_package_available("flash_attn"), "Please install flash-attn first."
self.pe_attn_head = pe_attn_head
self.attn_backend = attn_backend
self.attn_mask_enabled = attn_mask_enabled
def __call__(
self,
attn: Attention,
x: float["b n d"], # noised input x # noqa: F722
mask: bool["b n"] | None = None, # noqa: F722
x: float["b n d"], # noised input x
mask: bool["b n"] | None = None,
rope=None, # rotary position embedding
) -> torch.FloatTensor:
batch_size = x.shape[0]
@@ -479,16 +493,40 @@ class AttnProcessor:
query = apply_rotary_pos_emb(query, freqs, q_xpos_scale)
key = apply_rotary_pos_emb(key, freqs, k_xpos_scale)
# mask. e.g. inference got a batch with different target durations, mask out the padding
if mask is not None:
attn_mask = mask
attn_mask = attn_mask.unsqueeze(1).unsqueeze(1) # 'b n -> b 1 1 n'
attn_mask = attn_mask.expand(batch_size, attn.heads, query.shape[-2], key.shape[-2])
else:
attn_mask = None
if self.attn_backend == "torch":
# mask. e.g. inference got a batch with different target durations, mask out the padding
if self.attn_mask_enabled and mask is not None:
attn_mask = mask
attn_mask = attn_mask.unsqueeze(1).unsqueeze(1) # 'b n -> b 1 1 n'
attn_mask = attn_mask.expand(batch_size, attn.heads, query.shape[-2], key.shape[-2])
else:
attn_mask = None
x = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False)
x = x.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
elif self.attn_backend == "flash_attn":
query = query.transpose(1, 2) # [b, h, n, d] -> [b, n, h, d]
key = key.transpose(1, 2)
value = value.transpose(1, 2)
if self.attn_mask_enabled and mask is not None:
query, indices, q_cu_seqlens, q_max_seqlen_in_batch, _ = unpad_input(query, mask)
key, _, k_cu_seqlens, k_max_seqlen_in_batch, _ = unpad_input(key, mask)
value, _, _, _, _ = unpad_input(value, mask)
x = flash_attn_varlen_func(
query,
key,
value,
q_cu_seqlens,
k_cu_seqlens,
q_max_seqlen_in_batch,
k_max_seqlen_in_batch,
)
x = pad_input(x, indices, batch_size, q_max_seqlen_in_batch)
x = x.reshape(batch_size, -1, attn.heads * head_dim)
else:
x = flash_attn_func(query, key, value, dropout_p=0.0, causal=False)
x = x.reshape(batch_size, -1, attn.heads * head_dim)
x = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False)
x = x.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
x = x.to(query.dtype)
# linear proj
@@ -514,9 +552,9 @@ class JointAttnProcessor:
def __call__(
self,
attn: Attention,
x: float["b n d"], # noised input x # noqa: F722
c: float["b nt d"] = None, # context c, here text # noqa: F722
mask: bool["b n"] | None = None, # noqa: F722
x: float["b n d"], # noised input x
c: float["b nt d"] = None, # context c, here text
mask: bool["b n"] | None = None,
rope=None, # rotary position embedding for x
c_rope=None, # rotary position embedding for c
) -> torch.FloatTensor:
@@ -608,12 +646,27 @@ class JointAttnProcessor:
class DiTBlock(nn.Module):
def __init__(self, dim, heads, dim_head, ff_mult=4, dropout=0.1, qk_norm=None, pe_attn_head=None):
def __init__(
self,
dim,
heads,
dim_head,
ff_mult=4,
dropout=0.1,
qk_norm=None,
pe_attn_head=None,
attn_backend="flash_attn",
attn_mask_enabled=True,
):
super().__init__()
self.attn_norm = AdaLayerNorm(dim)
self.attn = Attention(
processor=AttnProcessor(pe_attn_head=pe_attn_head),
processor=AttnProcessor(
pe_attn_head=pe_attn_head,
attn_backend=attn_backend,
attn_mask_enabled=attn_mask_enabled,
),
dim=dim,
heads=heads,
dim_head=dim_head,
@@ -724,7 +777,7 @@ class TimestepEmbedding(nn.Module):
self.time_embed = SinusPositionEmbedding(freq_embed_dim)
self.time_mlp = nn.Sequential(nn.Linear(freq_embed_dim, dim), nn.SiLU(), nn.Linear(dim, dim))
def forward(self, timestep: float["b"]): # noqa: F821
def forward(self, timestep: float["b"]):
time_hidden = self.time_embed(timestep)
time_hidden = time_hidden.to(timestep.dtype)
time = self.time_mlp(time_hidden) # b d

View File

@@ -35,6 +35,16 @@ def default(v, d):
return v if exists(v) else d
def is_package_available(package_name: str) -> bool:
try:
import importlib
package_exists = importlib.util.find_spec(package_name) is not None
return package_exists
except Exception:
return False
# tensor helpers