diff --git a/src/f5_tts/model/backbones/dit.py b/src/f5_tts/model/backbones/dit.py index c3ae0ce..2af99a1 100644 --- a/src/f5_tts/model/backbones/dit.py +++ b/src/f5_tts/model/backbones/dit.py @@ -182,26 +182,16 @@ class DiT(nn.Module): return ckpt_forward - def clear_cache(self): - self.text_cond, self.text_uncond = None, None - - def forward( + def get_input_embed( self, - x: float["b n d"], # nosied input audio # noqa: F722 - cond: float["b n d"], # masked cond audio # noqa: F722 - text: int["b nt"], # text # noqa: F722 - time: float["b"] | float[""], # time step # noqa: F821 F722 - drop_audio_cond, # cfg for cond audio - drop_text, # cfg for text - mask: bool["b n"] | None = None, # noqa: F722 - cache=False, + x, # b n d + cond, # b n d + text, # b nt + drop_audio_cond: bool = False, + drop_text: bool = False, + cache: bool = True, ): - batch, seq_len = x.shape[0], x.shape[1] - if time.ndim == 0: - time = time.repeat(batch) - - # t: conditioning time, text: text, x: noised audio + cond audio + text - t = self.time_embed(time) + seq_len = x.shape[1] if cache: if drop_text: if self.text_uncond is None: @@ -213,8 +203,41 @@ class DiT(nn.Module): text_embed = self.text_cond else: text_embed = self.text_embed(text, seq_len, drop_text=drop_text) + x = self.input_embed(x, cond, text_embed, drop_audio_cond=drop_audio_cond) + return x + + def clear_cache(self): + self.text_cond, self.text_uncond = None, None + + def forward( + self, + x: float["b n d"], # nosied input audio # noqa: F722 + cond: float["b n d"], # masked cond audio # noqa: F722 + text: int["b nt"], # text # noqa: F722 + time: float["b"] | float[""], # time step # noqa: F821 F722 + mask: bool["b n"] | None = None, # noqa: F722 + drop_audio_cond: bool = False, # cfg for cond audio + drop_text: bool = False, # cfg for text + cfg_infer: bool = False, # cfg inference, pack cond & uncond forward + cache: bool = False, + ): + batch, seq_len = x.shape[0], x.shape[1] + if time.ndim == 0: + time = time.repeat(batch) + + # t: conditioning time, text: text, x: noised audio + cond audio + text + t = self.time_embed(time) + if cfg_infer: # pack cond & uncond forward: b n d -> 2b n d + x_cond = self.get_input_embed(x, cond, text, drop_audio_cond=False, drop_text=False, cache=cache) + x_uncond = self.get_input_embed(x, cond, text, drop_audio_cond=True, drop_text=True, cache=cache) + x = torch.cat((x_cond, x_uncond), dim=0) + t = torch.cat((t, t), dim=0) + mask = torch.cat((mask, mask), dim=0) if mask is not None else None + else: + x = self.get_input_embed(x, cond, text, drop_audio_cond=drop_audio_cond, drop_text=drop_text, cache=cache) + rope = self.rotary_embed.forward_from_seq_len(seq_len) if self.long_skip_connection is not None: diff --git a/src/f5_tts/model/backbones/mmdit.py b/src/f5_tts/model/backbones/mmdit.py index e8d18c4..cd56de9 100644 --- a/src/f5_tts/model/backbones/mmdit.py +++ b/src/f5_tts/model/backbones/mmdit.py @@ -141,26 +141,15 @@ class MMDiT(nn.Module): nn.init.constant_(self.proj_out.weight, 0) nn.init.constant_(self.proj_out.bias, 0) - def clear_cache(self): - self.text_cond, self.text_uncond = None, None - - def forward( + def get_input_embed( self, - x: float["b n d"], # nosied input audio # noqa: F722 - cond: float["b n d"], # masked cond audio # noqa: F722 - text: int["b nt"], # text # noqa: F722 - time: float["b"] | float[""], # time step # noqa: F821 F722 - drop_audio_cond, # cfg for cond audio - drop_text, # cfg for text - mask: bool["b n"] | None = None, # noqa: F722 - cache=False, + x, # b n d + cond, # b n d + text, # b nt + drop_audio_cond: bool = False, + drop_text: bool = False, + cache: bool = True, ): - batch = x.shape[0] - if time.ndim == 0: - time = time.repeat(batch) - - # t: conditioning (time), c: context (text + masked cond audio), x: noised input audio - t = self.time_embed(time) if cache: if drop_text: if self.text_uncond is None: @@ -174,6 +163,41 @@ class MMDiT(nn.Module): c = self.text_embed(text, drop_text=drop_text) x = self.audio_embed(x, cond, drop_audio_cond=drop_audio_cond) + return x, c + + def clear_cache(self): + self.text_cond, self.text_uncond = None, None + + def forward( + self, + x: float["b n d"], # nosied input audio # noqa: F722 + cond: float["b n d"], # masked cond audio # noqa: F722 + text: int["b nt"], # text # noqa: F722 + time: float["b"] | float[""], # time step # noqa: F821 F722 + mask: bool["b n"] | None = None, # noqa: F722 + drop_audio_cond: bool = False, # cfg for cond audio + drop_text: bool = False, # cfg for text + cfg_infer: bool = False, # cfg inference, pack cond & uncond forward + cache: bool = False, + ): + batch = x.shape[0] + if time.ndim == 0: + time = time.repeat(batch) + + # t: conditioning (time), c: context (text + masked cond audio), x: noised input audio + t = self.time_embed(time) + if cfg_infer: # pack cond & uncond forward: b n d -> 2b n d + x_cond, c_cond = self.get_input_embed(x, cond, text, drop_audio_cond=False, drop_text=False, cache=cache) + x_uncond, c_uncond = self.get_input_embed(x, cond, text, drop_audio_cond=True, drop_text=True, cache=cache) + x = torch.cat((x_cond, x_uncond), dim=0) + c = torch.cat((c_cond, c_uncond), dim=0) + t = torch.cat((t, t), dim=0) + mask = torch.cat((mask, mask), dim=0) if mask is not None else None + else: + x, c = self.get_input_embed( + x, cond, text, drop_audio_cond=drop_audio_cond, drop_text=drop_text, cache=cache + ) + seq_len = x.shape[1] text_len = text.shape[1] rope_audio = self.rotary_embed.forward_from_seq_len(seq_len) diff --git a/src/f5_tts/model/backbones/unett.py b/src/f5_tts/model/backbones/unett.py index 0192766..7b63c6e 100644 --- a/src/f5_tts/model/backbones/unett.py +++ b/src/f5_tts/model/backbones/unett.py @@ -178,26 +178,16 @@ class UNetT(nn.Module): self.norm_out = RMSNorm(dim) self.proj_out = nn.Linear(dim, mel_dim) - def clear_cache(self): - self.text_cond, self.text_uncond = None, None - - def forward( + def get_input_embed( self, - x: float["b n d"], # nosied input audio # noqa: F722 - cond: float["b n d"], # masked cond audio # noqa: F722 - text: int["b nt"], # text # noqa: F722 - time: float["b"] | float[""], # time step # noqa: F821 F722 - drop_audio_cond, # cfg for cond audio - drop_text, # cfg for text - mask: bool["b n"] | None = None, # noqa: F722 - cache=False, + x, # b n d + cond, # b n d + text, # b nt + drop_audio_cond: bool = False, + drop_text: bool = False, + cache: bool = True, ): - batch, seq_len = x.shape[0], x.shape[1] - if time.ndim == 0: - time = time.repeat(batch) - - # t: conditioning time, c: context (text + masked cond audio), x: noised input audio - t = self.time_embed(time) + seq_len = x.shape[1] if cache: if drop_text: if self.text_uncond is None: @@ -209,8 +199,41 @@ class UNetT(nn.Module): text_embed = self.text_cond else: text_embed = self.text_embed(text, seq_len, drop_text=drop_text) + x = self.input_embed(x, cond, text_embed, drop_audio_cond=drop_audio_cond) + return x + + def clear_cache(self): + self.text_cond, self.text_uncond = None, None + + def forward( + self, + x: float["b n d"], # nosied input audio # noqa: F722 + cond: float["b n d"], # masked cond audio # noqa: F722 + text: int["b nt"], # text # noqa: F722 + time: float["b"] | float[""], # time step # noqa: F821 F722 + mask: bool["b n"] | None = None, # noqa: F722 + drop_audio_cond: bool = False, # cfg for cond audio + drop_text: bool = False, # cfg for text + cfg_infer: bool = False, # cfg inference, pack cond & uncond forward + cache: bool = False, + ): + batch, seq_len = x.shape[0], x.shape[1] + if time.ndim == 0: + time = time.repeat(batch) + + # t: conditioning time, c: context (text + masked cond audio), x: noised input audio + t = self.time_embed(time) + if cfg_infer: # pack cond & uncond forward: b n d -> 2b n d + x_cond = self.get_input_embed(x, cond, text, drop_audio_cond=False, drop_text=False, cache=cache) + x_uncond = self.get_input_embed(x, cond, text, drop_audio_cond=True, drop_text=True, cache=cache) + x = torch.cat((x_cond, x_uncond), dim=0) + t = torch.cat((t, t), dim=0) + mask = torch.cat((mask, mask), dim=0) if mask is not None else None + else: + x = self.get_input_embed(x, cond, text, drop_audio_cond=drop_audio_cond, drop_text=drop_text, cache=cache) + # postfix time t to input x, [b n d] -> [b n+1 d] x = torch.cat([t.unsqueeze(1), x], dim=1) # pack t to x if mask is not None: diff --git a/src/f5_tts/model/cfm.py b/src/f5_tts/model/cfm.py index d2ec96d..ae06575 100644 --- a/src/f5_tts/model/cfm.py +++ b/src/f5_tts/model/cfm.py @@ -162,16 +162,31 @@ class CFM(nn.Module): # at each step, conditioning is fixed # step_cond = torch.where(cond_mask, cond, torch.zeros_like(cond)) - # predict flow - pred = self.transformer( - x=x, cond=step_cond, text=text, time=t, mask=mask, drop_audio_cond=False, drop_text=False, cache=True - ) + # predict flow (cond) if cfg_strength < 1e-5: + pred = self.transformer( + x=x, + cond=step_cond, + text=text, + time=t, + mask=mask, + drop_audio_cond=False, + drop_text=False, + cache=True, + ) return pred - null_pred = self.transformer( - x=x, cond=step_cond, text=text, time=t, mask=mask, drop_audio_cond=True, drop_text=True, cache=True + # predict flow (cond and uncond), for classifier-free guidance + pred_cfg = self.transformer( + x=x, + cond=step_cond, + text=text, + time=t, + mask=mask, + cfg_infer=True, + cache=True, ) + pred, null_pred = torch.chunk(pred_cfg, 2, dim=0) return pred + (pred - null_pred) * cfg_strength # noise input diff --git a/src/f5_tts/model/modules.py b/src/f5_tts/model/modules.py index 3a96664..1e2ee4a 100644 --- a/src/f5_tts/model/modules.py +++ b/src/f5_tts/model/modules.py @@ -443,7 +443,7 @@ class AttnProcessor: def __init__( self, pe_attn_head: int | None = None, # number of attention head to apply rope, None for all - attn_backend: str = "flash_attn", + attn_backend: str = "torch", # "torch" or "flash_attn" attn_mask_enabled: bool = True, ): if attn_backend == "flash_attn": @@ -655,7 +655,7 @@ class DiTBlock(nn.Module): dropout=0.1, qk_norm=None, pe_attn_head=None, - attn_backend="flash_attn", + attn_backend="torch", # "torch" or "flash_attn" attn_mask_enabled=True, ): super().__init__()