fix no_ref_audio in cfm.py: zero out cond before step_cond is built, instead of after the mask setup

This commit is contained in:
SWivid
2025-01-07 22:01:43 +08:00
parent dc2d2d3b2f
commit 4872afef9f

View File

@@ -142,6 +142,9 @@ class CFM(nn.Module):
test_cond = F.pad(cond, (0, 0, cond_seq_len, max_duration - 2 * cond_seq_len), value=0.0)
cond = F.pad(cond, (0, 0, 0, max_duration - cond_seq_len), value=0.0)
if no_ref_audio:
cond = torch.zeros_like(cond)
cond_mask = F.pad(cond_mask, (0, max_duration - cond_mask.shape[-1]), value=False)
cond_mask = cond_mask.unsqueeze(-1)
step_cond = torch.where(
@@ -153,10 +156,6 @@ class CFM(nn.Module):
else: # save memory and speed up, as single inference need no mask currently
mask = None
# test for no ref audio
if no_ref_audio:
cond = torch.zeros_like(cond)
# neural ode
def fn(t, x):