mirror of https://github.com/SWivid/F5-TTS.git (synced 2025-12-12 07:40:43 -08:00)
Add flash_attn2 support attn_mask, minor fixes (#1066)
* add flash attn2 support
* update flash attn config in F5TTS
* fix minor bug in getting the length of ref_mel

Co-authored-by: SWivid <swivid@qq.com>
@@ -31,6 +31,8 @@ model:
     text_mask_padding: False
     conv_layers: 4
     pe_attn_head: 1
+    attn_backend: torch  # torch | flash_attn
+    attn_mask_enabled: False
     checkpoint_activations: False  # recompute activations and save memory for extra compute
   mel_spec:
     target_sample_rate: 24000

@@ -31,6 +31,8 @@ model:
     text_mask_padding: False
     conv_layers: 4
     pe_attn_head: 1
+    attn_backend: torch  # torch | flash_attn
+    attn_mask_enabled: False
     checkpoint_activations: False  # recompute activations and save memory for extra compute
   mel_spec:
     target_sample_rate: 24000

@@ -32,6 +32,8 @@ model:
     qk_norm: null  # null | rms_norm
     conv_layers: 4
     pe_attn_head: null
+    attn_backend: torch  # torch | flash_attn
+    attn_mask_enabled: False
     checkpoint_activations: False  # recompute activations and save memory for extra compute
   mel_spec:
     target_sample_rate: 24000

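The two new keys choose the attention implementation and whether the padding mask is passed into attention. A minimal sketch of flipping them programmatically with PyYAML; the config file name and the exact model/arch nesting are assumptions (the diff only shows the keys sitting next to conv_layers and pe_attn_head under model:):

import yaml

# Hypothetical config path; real file names are not shown in this diff.
with open("F5TTS_Base.yaml") as f:
    cfg = yaml.safe_load(f)

# Nesting under model -> arch is assumed from the surrounding keys.
cfg["model"]["arch"]["attn_backend"] = "flash_attn"  # shipped default: torch
cfg["model"]["arch"]["attn_mask_enabled"] = True     # shipped default: False

with open("F5TTS_Base.local.yaml", "w") as f:
    yaml.safe_dump(cfg, f)
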
@@ -148,10 +148,15 @@ def main():
         vocab_char_map=vocab_char_map,
     ).to(device)

-    ckpt_path = rel_path + f"/ckpts/{exp_name}/model_{ckpt_step}.pt"
-    if not os.path.exists(ckpt_path):
+    ckpt_prefix = rel_path + f"/ckpts/{exp_name}/model_{ckpt_step}"
+    if os.path.exists(ckpt_prefix + ".pt"):
+        ckpt_path = ckpt_prefix + ".pt"
+    elif os.path.exists(ckpt_prefix + ".safetensors"):
+        ckpt_path = ckpt_prefix + ".safetensors"
+    else:
         print("Loading from self-organized training checkpoints rather than released pretrained.")
         ckpt_path = rel_path + f"/{model_cfg.ckpts.save_dir}/model_{ckpt_step}.pt"

     dtype = torch.float32 if mel_spec_type == "bigvgan" else None
     model = load_checkpoint(model, ckpt_path, device, dtype=dtype, use_ema=use_ema)

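The new lookup prefers a released .pt checkpoint, then a .safetensors one, and only then falls back to the self-organized training-run path. Restated as a standalone helper, purely as a sketch (the function name and signature are illustrative, not part of the repo):

import os

def resolve_ckpt(ckpt_prefix: str, fallback_path: str) -> str:
    # Try released checkpoints first: <prefix>.pt, then <prefix>.safetensors.
    for ext in (".pt", ".safetensors"):
        if os.path.exists(ckpt_prefix + ext):
            return ckpt_prefix + ext
    # Otherwise fall back to a self-organized training checkpoint.
    print("Loading from self-organized training checkpoints rather than released pretrained.")
    return fallback_path
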
@@ -126,8 +126,13 @@ def get_inference_prompt(
         else:
             text_list = text

+        # to mel spectrogram
+        ref_mel = mel_spectrogram(ref_audio)
+        ref_mel = ref_mel.squeeze(0)
+
         # Duration, mel frame length
-        ref_mel_len = ref_audio.shape[-1] // hop_length
+        ref_mel_len = ref_mel.shape[-1]

         if use_truth_duration:
             gt_audio, gt_sr = torchaudio.load(gt_wav)
             if gt_sr != target_sample_rate:

@@ -142,10 +147,6 @@ def get_inference_prompt(
             gen_text_len = len(gt_text.encode("utf-8"))
             total_mel_len = ref_mel_len + int(ref_mel_len / ref_text_len * gen_text_len / speed)

-        # to mel spectrogram
-        ref_mel = mel_spectrogram(ref_audio)
-        ref_mel = ref_mel.squeeze(0)
-
         # deal with batch
         assert infer_batch_size > 0, "infer_batch_size should be greater than 0."
         assert min_tokens <= total_mel_len <= max_tokens, (

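The "minor bug" from the commit message is the ref_mel length: deriving it from the raw audio as audio_len // hop_length can disagree by one frame with the number of frames the mel transform actually emits, so the mel is now computed first and its own last dimension is used for the duration math. An illustrative check, assuming a centered STFT such as torchaudio's default (not necessarily the exact mel settings used in the repo):

import torch
import torchaudio

hop_length, n_fft, sr = 256, 1024, 24000
ref_audio = torch.randn(1, sr * 3)  # 3 s of dummy audio

mel = torchaudio.transforms.MelSpectrogram(
    sample_rate=sr, n_fft=n_fft, hop_length=hop_length, n_mels=100
)(ref_audio)

print(ref_audio.shape[-1] // hop_length)  # 281
print(mel.shape[-1])                      # 282 -- one extra frame from center padding
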
@@ -116,6 +116,8 @@ class DiT(nn.Module):
         qk_norm=None,
         conv_layers=0,
         pe_attn_head=None,
+        attn_backend="torch",  # "torch" | "flash_attn"
+        attn_mask_enabled=False,
         long_skip_connection=False,
         checkpoint_activations=False,
     ):

@@ -145,6 +147,8 @@ class DiT(nn.Module):
                     dropout=dropout,
                     qk_norm=qk_norm,
                     pe_attn_head=pe_attn_head,
+                    attn_backend=attn_backend,
+                    attn_mask_enabled=attn_mask_enabled,
                 )
                 for _ in range(depth)
             ]

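For reference, a hedged construction sketch showing where the two new arguments land when building the backbone directly. The import path and the other hyperparameter values are assumed (taken from typical F5-TTS configs, not from this diff):

from f5_tts.model import DiT  # import path assumed

dit = DiT(
    dim=1024,
    depth=22,
    heads=16,
    dim_head=64,
    ff_mult=2,
    text_dim=512,
    conv_layers=4,
    pe_attn_head=1,
    attn_backend="flash_attn",   # new: "torch" | "flash_attn"; needs the flash-attn package
    attn_mask_enabled=True,      # new: pass the padding mask into attention
)
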
@@ -275,10 +275,9 @@ class CFM(nn.Module):
         else:
             drop_text = False

-        # if want rigorously mask out padding, record in collate_fn in dataset.py, and pass in here
-        # adding mask will use more memory, thus also need to adjust batchsampler with scaled down threshold for long sequences
+        # apply mask will use more memory; might adjust batchsize or batchsampler long sequence threshold
         pred = self.transformer(
-            x=φ, cond=cond, text=text, time=time, drop_audio_cond=drop_audio_cond, drop_text=drop_text
+            x=φ, cond=cond, text=text, time=time, drop_audio_cond=drop_audio_cond, drop_text=drop_text, mask=mask
         )

         # flow matching loss

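The rewritten comment is about cost: with the torch backend, the padding mask gets materialized as a (batch, heads, n, n) boolean tensor before scaled_dot_product_attention, so memory grows quadratically with sequence length. A quick back-of-the-envelope with illustrative numbers only:

b, heads, n = 8, 16, 3000        # ~32 s per item at 24 kHz with hop 256 (assumed)
mask_bytes = b * heads * n * n   # torch.bool stores one byte per element
print(f"{mask_bytes / 1024**3:.2f} GiB")  # ~1.07 GiB for the mask alone

Hence the suggestion to scale down the batch size, or the batch sampler's long-sequence threshold, when masking is enabled with the torch backend.
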
@@ -312,7 +312,7 @@ def collate_fn(batch):
     max_mel_length = mel_lengths.amax()

     padded_mel_specs = []
-    for spec in mel_specs:  # TODO. maybe records mask for attention here
+    for spec in mel_specs:
         padding = (0, max_mel_length - spec.size(-1))
         padded_spec = F.pad(spec, padding, value=0)
         padded_mel_specs.append(padded_spec)

@@ -324,7 +324,7 @@ def collate_fn(batch):

     return dict(
         mel=mel_specs,
-        mel_lengths=mel_lengths,
+        mel_lengths=mel_lengths,  # records for padding mask
         text=text,
         text_lengths=text_lengths,
     )

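mel_lengths is what downstream code turns into the padding mask. A sketch of that conversion, mirroring a lens_to_mask-style helper (the exact utility the repo uses is an assumption here):

import torch

def lengths_to_mask(lengths: torch.Tensor, max_len: int | None = None) -> torch.Tensor:
    # True marks real frames, False marks the zero-padded tail.
    max_len = max_len if max_len is not None else int(lengths.max())
    seq = torch.arange(max_len, device=lengths.device)  # (n,)
    return seq[None, :] < lengths[:, None]               # (b, n) bool

mel_lengths = torch.tensor([812, 650, 731])
mask = lengths_to_mask(mel_lengths)  # (3, 812); rows 2 and 3 end in False
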
@@ -6,6 +6,7 @@ nt - text sequence
 nw - raw wave length
 d - dimension
 """
+# flake8: noqa

 from __future__ import annotations

@@ -19,6 +20,8 @@ from librosa.filters import mel as librosa_mel_fn
 from torch import nn
 from x_transformers.x_transformers import apply_rotary_pos_emb

+from f5_tts.model.utils import is_package_available
+

 # raw wav to mel spec

@@ -175,7 +178,7 @@ class ConvPositionEmbedding(nn.Module):
             nn.Mish(),
         )

-    def forward(self, x: float["b n d"], mask: bool["b n"] | None = None):  # noqa: F722
+    def forward(self, x: float["b n d"], mask: bool["b n"] | None = None):
         if mask is not None:
             mask = mask[..., None]
             x = x.masked_fill(~mask, 0.0)

@@ -417,9 +420,9 @@ class Attention(nn.Module):

     def forward(
         self,
-        x: float["b n d"],  # noised input x  # noqa: F722
-        c: float["b n d"] = None,  # context c  # noqa: F722
-        mask: bool["b n"] | None = None,  # noqa: F722
+        x: float["b n d"],  # noised input x
+        c: float["b n d"] = None,  # context c
+        mask: bool["b n"] | None = None,
         rope=None,  # rotary position embedding for x
         c_rope=None,  # rotary position embedding for c
     ) -> torch.Tensor:

@@ -431,19 +434,30 @@

 # Attention processor

+if is_package_available("flash_attn"):
+    from flash_attn.bert_padding import pad_input, unpad_input
+    from flash_attn import flash_attn_varlen_func, flash_attn_func


 class AttnProcessor:
     def __init__(
         self,
         pe_attn_head: int | None = None,  # number of attention head to apply rope, None for all
+        attn_backend: str = "flash_attn",
+        attn_mask_enabled: bool = True,
     ):
+        if attn_backend == "flash_attn":
+            assert is_package_available("flash_attn"), "Please install flash-attn first."
+
         self.pe_attn_head = pe_attn_head
+        self.attn_backend = attn_backend
+        self.attn_mask_enabled = attn_mask_enabled

     def __call__(
         self,
         attn: Attention,
-        x: float["b n d"],  # noised input x  # noqa: F722
-        mask: bool["b n"] | None = None,  # noqa: F722
+        x: float["b n d"],  # noised input x
+        mask: bool["b n"] | None = None,
         rope=None,  # rotary position embedding
     ) -> torch.FloatTensor:
         batch_size = x.shape[0]

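Usage sketch for the updated processor (the module path is assumed to be f5_tts.model.modules): the backend is validated once at construction time, so a missing flash-attn package fails fast instead of at the first forward pass.

from f5_tts.model.modules import AttnProcessor  # module path assumed

proc = AttnProcessor(pe_attn_head=1, attn_backend="torch", attn_mask_enabled=False)
# AttnProcessor(attn_backend="flash_attn") raises AssertionError
# ("Please install flash-attn first.") when the package is missing.
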
@@ -479,16 +493,40 @@ class AttnProcessor:
             query = apply_rotary_pos_emb(query, freqs, q_xpos_scale)
             key = apply_rotary_pos_emb(key, freqs, k_xpos_scale)

-        # mask. e.g. inference got a batch with different target durations, mask out the padding
-        if mask is not None:
-            attn_mask = mask
-            attn_mask = attn_mask.unsqueeze(1).unsqueeze(1)  # 'b n -> b 1 1 n'
-            attn_mask = attn_mask.expand(batch_size, attn.heads, query.shape[-2], key.shape[-2])
-        else:
-            attn_mask = None
+        if self.attn_backend == "torch":
+            # mask. e.g. inference got a batch with different target durations, mask out the padding
+            if self.attn_mask_enabled and mask is not None:
+                attn_mask = mask
+                attn_mask = attn_mask.unsqueeze(1).unsqueeze(1)  # 'b n -> b 1 1 n'
+                attn_mask = attn_mask.expand(batch_size, attn.heads, query.shape[-2], key.shape[-2])
+            else:
+                attn_mask = None
+            x = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False)
+            x = x.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)

+        elif self.attn_backend == "flash_attn":
+            query = query.transpose(1, 2)  # [b, h, n, d] -> [b, n, h, d]
+            key = key.transpose(1, 2)
+            value = value.transpose(1, 2)
+            if self.attn_mask_enabled and mask is not None:
+                query, indices, q_cu_seqlens, q_max_seqlen_in_batch, _ = unpad_input(query, mask)
+                key, _, k_cu_seqlens, k_max_seqlen_in_batch, _ = unpad_input(key, mask)
+                value, _, _, _, _ = unpad_input(value, mask)
+                x = flash_attn_varlen_func(
+                    query,
+                    key,
+                    value,
+                    q_cu_seqlens,
+                    k_cu_seqlens,
+                    q_max_seqlen_in_batch,
+                    k_max_seqlen_in_batch,
+                )
+                x = pad_input(x, indices, batch_size, q_max_seqlen_in_batch)
+                x = x.reshape(batch_size, -1, attn.heads * head_dim)
+            else:
+                x = flash_attn_func(query, key, value, dropout_p=0.0, causal=False)
+                x = x.reshape(batch_size, -1, attn.heads * head_dim)

-        x = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False)
-        x = x.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
         x = x.to(query.dtype)

         # linear proj

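The varlen path above removes padded tokens, runs flash attention over the packed sequences, then scatters the output back into the padded layout. The same call pattern in isolation, as a hedged sketch: it requires the flash-attn package and a CUDA device, and the five-value unpad_input unpacking follows the hunk above (it assumes a flash-attn 2.x release that returns five values).

import torch
from flash_attn import flash_attn_varlen_func
from flash_attn.bert_padding import pad_input, unpad_input

b, n, h, d = 2, 512, 16, 64
device, dtype = "cuda", torch.float16

q = torch.randn(b, n, h, d, device=device, dtype=dtype)  # [b, n, h, d], as after the transposes
k = torch.randn_like(q)
v = torch.randn_like(q)
lengths = torch.tensor([512, 300], device=device)
mask = torch.arange(n, device=device)[None, :] < lengths[:, None]  # (b, n) bool

q_u, indices, cu_q, max_q, _ = unpad_input(q, mask)  # drop padded tokens
k_u, _, cu_k, max_k, _ = unpad_input(k, mask)
v_u, *_ = unpad_input(v, mask)

out = flash_attn_varlen_func(q_u, k_u, v_u, cu_q, cu_k, max_q, max_k)
out = pad_input(out, indices, b, max_q)  # back to (b, n, h, d)
out = out.reshape(b, -1, h * d)          # merge heads, as the processor does
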
@@ -514,9 +552,9 @@ class JointAttnProcessor:
     def __call__(
         self,
         attn: Attention,
-        x: float["b n d"],  # noised input x  # noqa: F722
-        c: float["b nt d"] = None,  # context c, here text  # noqa: F722
-        mask: bool["b n"] | None = None,  # noqa: F722
+        x: float["b n d"],  # noised input x
+        c: float["b nt d"] = None,  # context c, here text
+        mask: bool["b n"] | None = None,
         rope=None,  # rotary position embedding for x
         c_rope=None,  # rotary position embedding for c
     ) -> torch.FloatTensor:

@@ -608,12 +646,27 @@


 class DiTBlock(nn.Module):
-    def __init__(self, dim, heads, dim_head, ff_mult=4, dropout=0.1, qk_norm=None, pe_attn_head=None):
+    def __init__(
+        self,
+        dim,
+        heads,
+        dim_head,
+        ff_mult=4,
+        dropout=0.1,
+        qk_norm=None,
+        pe_attn_head=None,
+        attn_backend="flash_attn",
+        attn_mask_enabled=True,
+    ):
         super().__init__()

         self.attn_norm = AdaLayerNorm(dim)
         self.attn = Attention(
-            processor=AttnProcessor(pe_attn_head=pe_attn_head),
+            processor=AttnProcessor(
+                pe_attn_head=pe_attn_head,
+                attn_backend=attn_backend,
+                attn_mask_enabled=attn_mask_enabled,
+            ),
             dim=dim,
             heads=heads,
             dim_head=dim_head,

@@ -724,7 +777,7 @@ class TimestepEmbedding(nn.Module):
         self.time_embed = SinusPositionEmbedding(freq_embed_dim)
         self.time_mlp = nn.Sequential(nn.Linear(freq_embed_dim, dim), nn.SiLU(), nn.Linear(dim, dim))

-    def forward(self, timestep: float["b"]):  # noqa: F821
+    def forward(self, timestep: float["b"]):
         time_hidden = self.time_embed(timestep)
         time_hidden = time_hidden.to(timestep.dtype)
         time = self.time_mlp(time_hidden)  # b d

@@ -35,6 +35,16 @@ def default(v, d):
     return v if exists(v) else d


+def is_package_available(package_name: str) -> bool:
+    try:
+        import importlib
+
+        package_exists = importlib.util.find_spec(package_name) is not None
+        return package_exists
+    except Exception:
+        return False
+
+
 # tensor helpers

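Usage sketch for the new helper: it lets optional backends be probed without importing them, which is how the flash-attn import in modules.py is guarded.

from f5_tts.model.utils import is_package_available

if is_package_available("flash_attn"):
    from flash_attn import flash_attn_func  # only imported when actually installed
else:
    flash_attn_func = None  # callers fall back to the torch SDPA path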