backward compatibility

This commit is contained in:
SWivid
2025-06-12 03:52:12 +08:00
parent b3ef4ed1d7
commit 8b0053ad0c

View File

@@ -443,7 +443,7 @@ class AttnProcessor:
def __init__(
self,
pe_attn_head: int | None = None, # number of attention head to apply rope, None for all
attn_backend: str = "flash_attn",
attn_backend: str = "torch", # "torch" or "flash_attn"
attn_mask_enabled: bool = True,
):
if attn_backend == "flash_attn":
@@ -655,7 +655,7 @@ class DiTBlock(nn.Module):
dropout=0.1,
qk_norm=None,
pe_attn_head=None,
attn_backend="flash_attn",
attn_backend="torch", # "torch" or "flash_attn"
attn_mask_enabled=True,
):
super().__init__()