mirror of
https://github.com/SWivid/F5-TTS.git
synced 2025-12-12 07:40:43 -08:00
backward compatibility
This commit is contained in:
@@ -443,7 +443,7 @@ class AttnProcessor:
|
||||
def __init__(
|
||||
self,
|
||||
pe_attn_head: int | None = None, # number of attention head to apply rope, None for all
|
||||
attn_backend: str = "flash_attn",
|
||||
attn_backend: str = "torch", # "torch" or "flash_attn"
|
||||
attn_mask_enabled: bool = True,
|
||||
):
|
||||
if attn_backend == "flash_attn":
|
||||
@@ -655,7 +655,7 @@ class DiTBlock(nn.Module):
|
||||
dropout=0.1,
|
||||
qk_norm=None,
|
||||
pe_attn_head=None,
|
||||
attn_backend="flash_attn",
|
||||
attn_backend="torch", # "torch" or "flash_attn"
|
||||
attn_mask_enabled=True,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
Reference in New Issue
Block a user