mirror of
https://github.com/SWivid/F5-TTS.git
synced 2025-12-12 07:40:43 -08:00
Batch cfg DiT forward
This commit is contained in:
@@ -185,23 +185,7 @@ class DiT(nn.Module):
|
||||
def clear_cache(self):
|
||||
self.text_cond, self.text_uncond = None, None
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: float["b n d"], # nosied input audio # noqa: F722
|
||||
cond: float["b n d"], # masked cond audio # noqa: F722
|
||||
text: int["b nt"], # text # noqa: F722
|
||||
time: float["b"] | float[""], # time step # noqa: F821 F722
|
||||
drop_audio_cond, # cfg for cond audio
|
||||
drop_text, # cfg for text
|
||||
mask: bool["b n"] | None = None, # noqa: F722
|
||||
cache=False,
|
||||
):
|
||||
batch, seq_len = x.shape[0], x.shape[1]
|
||||
if time.ndim == 0:
|
||||
time = time.repeat(batch)
|
||||
|
||||
# t: conditioning time, text: text, x: noised audio + cond audio + text
|
||||
t = self.time_embed(time)
|
||||
def get_text_embed(self, text, seq_len, drop_text, cache):
|
||||
if cache:
|
||||
if drop_text:
|
||||
if self.text_uncond is None:
|
||||
@@ -213,7 +197,41 @@ class DiT(nn.Module):
|
||||
text_embed = self.text_cond
|
||||
else:
|
||||
text_embed = self.text_embed(text, seq_len, drop_text=drop_text)
|
||||
x = self.input_embed(x, cond, text_embed, drop_audio_cond=drop_audio_cond)
|
||||
return text_embed
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: float["b n d"], # nosied input audio # noqa: F722
|
||||
cond: float["b n d"], # masked cond audio # noqa: F722
|
||||
text: int["b nt"], # text # noqa: F722
|
||||
time: float["b"] | float[""], # time step # noqa: F821 F722
|
||||
drop_audio_cond, # cfg for cond audio
|
||||
drop_text, # cfg for text
|
||||
batch_cfg=False, # batch cfg compute
|
||||
mask: bool["b n"] | None = None, # noqa: F722
|
||||
cache=False,
|
||||
):
|
||||
batch, seq_len = x.shape[0], x.shape[1]
|
||||
if time.ndim == 0:
|
||||
time = time.repeat(batch)
|
||||
|
||||
# t: conditioning time, text: text, x: noised audio + cond audio + text
|
||||
t = self.time_embed(time)
|
||||
if batch_cfg:
|
||||
text_embed_cond = self.get_text_embed(
|
||||
text, seq_len, drop_text=False, cache=cache
|
||||
)
|
||||
text_embed_uncond = self.get_text_embed(
|
||||
text, seq_len, drop_text=True, cache=cache
|
||||
)
|
||||
x_cond = self.input_embed(x, cond, text_embed_cond, drop_audio_cond=False)
|
||||
x_uncond = self.input_embed(
|
||||
x, cond, text_embed_uncond, drop_audio_cond=True
|
||||
)
|
||||
x = torch.cat((x_cond, x_uncond), dim=0)
|
||||
else:
|
||||
text_embed = self.get_text_embed(text, seq_len, drop_text, cache)
|
||||
x = self.input_embed(x, cond, text_embed, drop_audio_cond=drop_audio_cond)
|
||||
|
||||
rope = self.rotary_embed.forward_from_seq_len(seq_len)
|
||||
|
||||
|
||||
@@ -163,15 +163,31 @@ class CFM(nn.Module):
|
||||
# step_cond = torch.where(cond_mask, cond, torch.zeros_like(cond))
|
||||
|
||||
# predict flow
|
||||
pred = self.transformer(
|
||||
x=x, cond=step_cond, text=text, time=t, mask=mask, drop_audio_cond=False, drop_text=False, cache=True
|
||||
)
|
||||
if cfg_strength < 1e-5:
|
||||
pred = self.transformer(
|
||||
x=x,
|
||||
cond=step_cond,
|
||||
text=text,
|
||||
time=t,
|
||||
mask=mask,
|
||||
drop_audio_cond=False,
|
||||
drop_text=False,
|
||||
cache=True,
|
||||
)
|
||||
return pred
|
||||
|
||||
null_pred = self.transformer(
|
||||
x=x, cond=step_cond, text=text, time=t, mask=mask, drop_audio_cond=True, drop_text=True, cache=True
|
||||
pred_and_null = self.transformer(
|
||||
x=x,
|
||||
cond=step_cond,
|
||||
text=text,
|
||||
time=t,
|
||||
mask=mask,
|
||||
drop_audio_cond=False,
|
||||
drop_text=False,
|
||||
batch_cfg=True,
|
||||
cache=True,
|
||||
)
|
||||
pred, null_pred = torch.chunk(pred_and_null, 2, dim=0)
|
||||
return pred + (pred - null_pred) * cfg_strength
|
||||
|
||||
# noise input
|
||||
|
||||
Reference in New Issue
Block a user