From 0914170e98916b181bc84cb90213f5350abda42c Mon Sep 17 00:00:00 2001
From: Zhikang Niu <73390819+ZhikangNiu@users.noreply.github.com>
Date: Wed, 11 Jun 2025 12:14:32 +0800
Subject: [PATCH] Add flash_attn2 support attn_mask, minor fixes (#1066)

* add flash attn2 support

* update flash attn config in F5TTS

* fix minor bug of get the length of ref_mel

---------

Co-authored-by: SWivid
---
 src/f5_tts/configs/F5TTS_Base.yaml    |  2 +
 src/f5_tts/configs/F5TTS_Small.yaml   |  2 +
 src/f5_tts/configs/F5TTS_v1_Base.yaml |  2 +
 src/f5_tts/eval/eval_infer_batch.py   |  9 ++-
 src/f5_tts/eval/utils_eval.py         | 11 ++--
 src/f5_tts/model/backbones/dit.py     |  4 ++
 src/f5_tts/model/cfm.py               |  5 +-
 src/f5_tts/model/dataset.py           |  4 +-
 src/f5_tts/model/modules.py           | 95 +++++++++++++++++++++------
 src/f5_tts/model/utils.py             | 10 +++
 10 files changed, 111 insertions(+), 33 deletions(-)

diff --git a/src/f5_tts/configs/F5TTS_Base.yaml b/src/f5_tts/configs/F5TTS_Base.yaml
index 9a2eeb9..d177674 100644
--- a/src/f5_tts/configs/F5TTS_Base.yaml
+++ b/src/f5_tts/configs/F5TTS_Base.yaml
@@ -31,6 +31,8 @@ model:
     text_mask_padding: False
     conv_layers: 4
     pe_attn_head: 1
+    attn_backend: torch  # torch | flash_attn
+    attn_mask_enabled: False
     checkpoint_activations: False  # recompute activations and save memory for extra compute
   mel_spec:
     target_sample_rate: 24000
diff --git a/src/f5_tts/configs/F5TTS_Small.yaml b/src/f5_tts/configs/F5TTS_Small.yaml
index 1c4a6df..396f389 100644
--- a/src/f5_tts/configs/F5TTS_Small.yaml
+++ b/src/f5_tts/configs/F5TTS_Small.yaml
@@ -31,6 +31,8 @@ model:
     text_mask_padding: False
     conv_layers: 4
     pe_attn_head: 1
+    attn_backend: torch  # torch | flash_attn
+    attn_mask_enabled: False
     checkpoint_activations: False  # recompute activations and save memory for extra compute
   mel_spec:
     target_sample_rate: 24000
diff --git a/src/f5_tts/configs/F5TTS_v1_Base.yaml b/src/f5_tts/configs/F5TTS_v1_Base.yaml
index c7717fa..e931a01 100644
--- a/src/f5_tts/configs/F5TTS_v1_Base.yaml
+++ b/src/f5_tts/configs/F5TTS_v1_Base.yaml
@@ -32,6 +32,8 @@ model:
     qk_norm: null  # null | rms_norm
     conv_layers: 4
     pe_attn_head: null
+    attn_backend: torch  # torch | flash_attn
+    attn_mask_enabled: False
     checkpoint_activations: False  # recompute activations and save memory for extra compute
   mel_spec:
     target_sample_rate: 24000
diff --git a/src/f5_tts/eval/eval_infer_batch.py b/src/f5_tts/eval/eval_infer_batch.py
index a28226b..cea5b7a 100644
--- a/src/f5_tts/eval/eval_infer_batch.py
+++ b/src/f5_tts/eval/eval_infer_batch.py
@@ -148,10 +148,15 @@ def main():
         vocab_char_map=vocab_char_map,
     ).to(device)
 
-    ckpt_path = rel_path + f"/ckpts/{exp_name}/model_{ckpt_step}.pt"
-    if not os.path.exists(ckpt_path):
+    ckpt_prefix = rel_path + f"/ckpts/{exp_name}/model_{ckpt_step}"
+    if os.path.exists(ckpt_prefix + ".pt"):
+        ckpt_path = ckpt_prefix + ".pt"
+    elif os.path.exists(ckpt_prefix + ".safetensors"):
+        ckpt_path = ckpt_prefix + ".safetensors"
+    else:
         print("Loading from self-organized training checkpoints rather than released pretrained.")
         ckpt_path = rel_path + f"/{model_cfg.ckpts.save_dir}/model_{ckpt_step}.pt"
+
     dtype = torch.float32 if mel_spec_type == "bigvgan" else None
     model = load_checkpoint(model, ckpt_path, device, dtype=dtype, use_ema=use_ema)
 
diff --git a/src/f5_tts/eval/utils_eval.py b/src/f5_tts/eval/utils_eval.py
index f819cdc..c5ac834 100644
--- a/src/f5_tts/eval/utils_eval.py
+++ b/src/f5_tts/eval/utils_eval.py
@@ -126,8 +126,13 @@ def get_inference_prompt(
         else:
             text_list = text
 
+        # to mel spectrogram
+        ref_mel = mel_spectrogram(ref_audio)
+        ref_mel = ref_mel.squeeze(0)
+
         # Duration, mel frame length
-        ref_mel_len = ref_audio.shape[-1] // hop_length
+        ref_mel_len = ref_mel.shape[-1]
+
         if use_truth_duration:
             gt_audio, gt_sr = torchaudio.load(gt_wav)
             if gt_sr != target_sample_rate:
@@ -142,10 +147,6 @@ def get_inference_prompt(
             gen_text_len = len(gt_text.encode("utf-8"))
             total_mel_len = ref_mel_len + int(ref_mel_len / ref_text_len * gen_text_len / speed)
 
-        # to mel spectrogram
-        ref_mel = mel_spectrogram(ref_audio)
-        ref_mel = ref_mel.squeeze(0)
-
         # deal with batch
         assert infer_batch_size > 0, "infer_batch_size should be greater than 0."
         assert min_tokens <= total_mel_len <= max_tokens, (
diff --git a/src/f5_tts/model/backbones/dit.py b/src/f5_tts/model/backbones/dit.py
index 223daf3..c3ae0ce 100644
--- a/src/f5_tts/model/backbones/dit.py
+++ b/src/f5_tts/model/backbones/dit.py
@@ -116,6 +116,8 @@ class DiT(nn.Module):
         qk_norm=None,
         conv_layers=0,
         pe_attn_head=None,
+        attn_backend="torch",  # "torch" | "flash_attn"
+        attn_mask_enabled=False,
         long_skip_connection=False,
         checkpoint_activations=False,
     ):
@@ -145,6 +147,8 @@ class DiT(nn.Module):
                     dropout=dropout,
                     qk_norm=qk_norm,
                     pe_attn_head=pe_attn_head,
+                    attn_backend=attn_backend,
+                    attn_mask_enabled=attn_mask_enabled,
                 )
                 for _ in range(depth)
             ]
diff --git a/src/f5_tts/model/cfm.py b/src/f5_tts/model/cfm.py
index 15be1bb..d2ec96d 100644
--- a/src/f5_tts/model/cfm.py
+++ b/src/f5_tts/model/cfm.py
@@ -275,10 +275,9 @@ class CFM(nn.Module):
         else:
             drop_text = False
 
-        # if want rigorously mask out padding, record in collate_fn in dataset.py, and pass in here
-        # adding mask will use more memory, thus also need to adjust batchsampler with scaled down threshold for long sequences
+        # apply mask will use more memory; might adjust batchsize or batchsampler long sequence threshold
         pred = self.transformer(
-            x=φ, cond=cond, text=text, time=time, drop_audio_cond=drop_audio_cond, drop_text=drop_text
+            x=φ, cond=cond, text=text, time=time, drop_audio_cond=drop_audio_cond, drop_text=drop_text, mask=mask
         )
 
         # flow matching loss
diff --git a/src/f5_tts/model/dataset.py b/src/f5_tts/model/dataset.py
index fd6fb11..50448c3 100644
--- a/src/f5_tts/model/dataset.py
+++ b/src/f5_tts/model/dataset.py
@@ -312,7 +312,7 @@ def collate_fn(batch):
     max_mel_length = mel_lengths.amax()
 
     padded_mel_specs = []
-    for spec in mel_specs:  # TODO. maybe records mask for attention here
+    for spec in mel_specs:
         padding = (0, max_mel_length - spec.size(-1))
         padded_spec = F.pad(spec, padding, value=0)
         padded_mel_specs.append(padded_spec)
@@ -324,7 +324,7 @@ def collate_fn(batch):
 
     return dict(
         mel=mel_specs,
-        mel_lengths=mel_lengths,
+        mel_lengths=mel_lengths,  # records for padding mask
         text=text,
         text_lengths=text_lengths,
     )
diff --git a/src/f5_tts/model/modules.py b/src/f5_tts/model/modules.py
index 8e5c3c2..3a96664 100644
--- a/src/f5_tts/model/modules.py
+++ b/src/f5_tts/model/modules.py
@@ -6,6 +6,7 @@
 nt - text sequence
 nw - raw wave length
 d - dimension
 """
+# flake8: noqa
 
 from __future__ import annotations
@@ -19,6 +20,8 @@
 from librosa.filters import mel as librosa_mel_fn
 from torch import nn
 from x_transformers.x_transformers import apply_rotary_pos_emb
 
+from f5_tts.model.utils import is_package_available
+
 
 # raw wav to mel spec
@@ -175,7 +178,7 @@ class ConvPositionEmbedding(nn.Module):
             nn.Mish(),
         )
 
-    def forward(self, x: float["b n d"], mask: bool["b n"] | None = None):  # noqa: F722
+    def forward(self, x: float["b n d"], mask: bool["b n"] | None = None):
         if mask is not None:
             mask = mask[..., None]
             x = x.masked_fill(~mask, 0.0)
@@ -417,9 +420,9 @@ class Attention(nn.Module):
 
     def forward(
         self,
-        x: float["b n d"],  # noised input x  # noqa: F722
-        c: float["b n d"] = None,  # context c  # noqa: F722
-        mask: bool["b n"] | None = None,  # noqa: F722
+        x: float["b n d"],  # noised input x
+        c: float["b n d"] = None,  # context c
+        mask: bool["b n"] | None = None,
         rope=None,  # rotary position embedding for x
         c_rope=None,  # rotary position embedding for c
     ) -> torch.Tensor:
@@ -431,19 +434,30 @@ class Attention(nn.Module):
 
 # Attention processor
 
+if is_package_available("flash_attn"):
+    from flash_attn.bert_padding import pad_input, unpad_input
+    from flash_attn import flash_attn_varlen_func, flash_attn_func
+
 
 class AttnProcessor:
     def __init__(
         self,
         pe_attn_head: int | None = None,  # number of attention head to apply rope, None for all
+        attn_backend: str = "flash_attn",
+        attn_mask_enabled: bool = True,
     ):
+        if attn_backend == "flash_attn":
+            assert is_package_available("flash_attn"), "Please install flash-attn first."
+
         self.pe_attn_head = pe_attn_head
+        self.attn_backend = attn_backend
+        self.attn_mask_enabled = attn_mask_enabled
 
     def __call__(
        self,
         attn: Attention,
-        x: float["b n d"],  # noised input x  # noqa: F722
-        mask: bool["b n"] | None = None,  # noqa: F722
+        x: float["b n d"],  # noised input x
+        mask: bool["b n"] | None = None,
         rope=None,  # rotary position embedding
     ) -> torch.FloatTensor:
         batch_size = x.shape[0]
@@ -479,16 +493,40 @@ class AttnProcessor:
             query = apply_rotary_pos_emb(query, freqs, q_xpos_scale)
             key = apply_rotary_pos_emb(key, freqs, k_xpos_scale)
 
-        # mask. e.g. inference got a batch with different target durations, mask out the padding
-        if mask is not None:
-            attn_mask = mask
-            attn_mask = attn_mask.unsqueeze(1).unsqueeze(1)  # 'b n -> b 1 1 n'
-            attn_mask = attn_mask.expand(batch_size, attn.heads, query.shape[-2], key.shape[-2])
-        else:
-            attn_mask = None
+        if self.attn_backend == "torch":
+            # mask. e.g. inference got a batch with different target durations, mask out the padding
+            if self.attn_mask_enabled and mask is not None:
+                attn_mask = mask
+                attn_mask = attn_mask.unsqueeze(1).unsqueeze(1)  # 'b n -> b 1 1 n'
+                attn_mask = attn_mask.expand(batch_size, attn.heads, query.shape[-2], key.shape[-2])
+            else:
+                attn_mask = None
+            x = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False)
+            x = x.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+
+        elif self.attn_backend == "flash_attn":
+            query = query.transpose(1, 2)  # [b, h, n, d] -> [b, n, h, d]
+            key = key.transpose(1, 2)
+            value = value.transpose(1, 2)
+            if self.attn_mask_enabled and mask is not None:
+                query, indices, q_cu_seqlens, q_max_seqlen_in_batch, _ = unpad_input(query, mask)
+                key, _, k_cu_seqlens, k_max_seqlen_in_batch, _ = unpad_input(key, mask)
+                value, _, _, _, _ = unpad_input(value, mask)
+                x = flash_attn_varlen_func(
+                    query,
+                    key,
+                    value,
+                    q_cu_seqlens,
+                    k_cu_seqlens,
+                    q_max_seqlen_in_batch,
+                    k_max_seqlen_in_batch,
+                )
+                x = pad_input(x, indices, batch_size, q_max_seqlen_in_batch)
+                x = x.reshape(batch_size, -1, attn.heads * head_dim)
+            else:
+                x = flash_attn_func(query, key, value, dropout_p=0.0, causal=False)
+                x = x.reshape(batch_size, -1, attn.heads * head_dim)
 
-        x = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False)
-        x = x.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
         x = x.to(query.dtype)
 
         # linear proj
@@ -514,9 +552,9 @@ class JointAttnProcessor:
     def __call__(
         self,
         attn: Attention,
-        x: float["b n d"],  # noised input x  # noqa: F722
-        c: float["b nt d"] = None,  # context c, here text  # noqa: F722
-        mask: bool["b n"] | None = None,  # noqa: F722
+        x: float["b n d"],  # noised input x
+        c: float["b nt d"] = None,  # context c, here text
+        mask: bool["b n"] | None = None,
         rope=None,  # rotary position embedding for x
         c_rope=None,  # rotary position embedding for c
     ) -> torch.FloatTensor:
@@ -608,12 +646,27 @@ class JointAttnProcessor:
 
 
 class DiTBlock(nn.Module):
-    def __init__(self, dim, heads, dim_head, ff_mult=4, dropout=0.1, qk_norm=None, pe_attn_head=None):
+    def __init__(
+        self,
+        dim,
+        heads,
+        dim_head,
+        ff_mult=4,
+        dropout=0.1,
+        qk_norm=None,
+        pe_attn_head=None,
+        attn_backend="flash_attn",
+        attn_mask_enabled=True,
+    ):
         super().__init__()
 
         self.attn_norm = AdaLayerNorm(dim)
         self.attn = Attention(
-            processor=AttnProcessor(pe_attn_head=pe_attn_head),
+            processor=AttnProcessor(
+                pe_attn_head=pe_attn_head,
+                attn_backend=attn_backend,
+                attn_mask_enabled=attn_mask_enabled,
+            ),
             dim=dim,
             heads=heads,
             dim_head=dim_head,
@@ -724,7 +777,7 @@ class TimestepEmbedding(nn.Module):
         self.time_embed = SinusPositionEmbedding(freq_embed_dim)
         self.time_mlp = nn.Sequential(nn.Linear(freq_embed_dim, dim), nn.SiLU(), nn.Linear(dim, dim))
 
-    def forward(self, timestep: float["b"]):  # noqa: F821
+    def forward(self, timestep: float["b"]):
         time_hidden = self.time_embed(timestep)
         time_hidden = time_hidden.to(timestep.dtype)
         time = self.time_mlp(time_hidden)  # b d
diff --git a/src/f5_tts/model/utils.py b/src/f5_tts/model/utils.py
index 37d5178..c5c3829 100644
--- a/src/f5_tts/model/utils.py
+++ b/src/f5_tts/model/utils.py
@@ -35,6 +35,16 @@ def default(v, d):
     return v if exists(v) else d
 
 
+def is_package_available(package_name: str) -> bool:
+    try:
+        import importlib
+
+        package_exists = importlib.util.find_spec(package_name) is not None
+        return package_exists
+    except Exception:
+        return False
+
+
 # tensor helpers
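
Note (illustrative, not part of the patch): the flash_attn branch added to AttnProcessor.__call__ boils down to an unpad -> varlen-kernel -> re-pad round trip, which is why no dense attention-mask tensor is built on that path. Below is a minimal standalone sketch of that technique, assuming a flash-attn 2.x release whose unpad_input returns five values (as the patch unpacks it), CUDA tensors in fp16/bf16, and q/k/v shaped [batch, seq_len, heads, head_dim]; the helper name flash_masked_attention is hypothetical.

from flash_attn import flash_attn_func, flash_attn_varlen_func
from flash_attn.bert_padding import pad_input, unpad_input


def flash_masked_attention(q, k, v, mask=None):
    """mask: [batch, seq_len] bool, True for real tokens, False for padding."""
    batch_size, seq_len, n_heads, head_dim = q.shape

    if mask is None:
        # fixed-length kernel: every position attends to every position
        out = flash_attn_func(q, k, v, dropout_p=0.0, causal=False)
        return out.reshape(batch_size, seq_len, n_heads * head_dim)

    # strip padded positions, run the variable-length kernel, then re-pad
    q, indices, cu_q, max_q, _ = unpad_input(q, mask)
    k, _, cu_k, max_k, _ = unpad_input(k, mask)
    v, _, _, _, _ = unpad_input(v, mask)
    out = flash_attn_varlen_func(q, k, v, cu_q, cu_k, max_q, max_k)
    # the patch passes the batch's max seqlen here; identical to seq_len when,
    # as in collate_fn, every batch is padded to its longest sequence
    out = pad_input(out, indices, batch_size, seq_len)
    return out.reshape(batch_size, seq_len, n_heads * head_dim)

Enabling this path corresponds to setting attn_backend: flash_attn and attn_mask_enabled: True under the model arch config added above; with the mask disabled, flash_attn_func simply lets padded positions take part in attention, matching the previous unmasked behavior.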