Set attn-related default values for the UNetT backbone: #1192

This commit is contained in:
SWivid
2025-10-09 06:51:25 +00:00
parent 77d3ec623b
commit 65ada48a62

View File

@@ -120,6 +120,8 @@ class UNetT(nn.Module):
qk_norm=None,
conv_layers=0,
pe_attn_head=None,
attn_backend="torch", # "torch" | "flash_attn"
attn_mask_enabled=False,
skip_connect_type: Literal["add", "concat", "none"] = "concat",
):
super().__init__()
@@ -150,7 +152,11 @@ class UNetT(nn.Module):
attn_norm = RMSNorm(dim)
attn = Attention(
processor=AttnProcessor(pe_attn_head=pe_attn_head),
processor=AttnProcessor(
pe_attn_head=pe_attn_head,
attn_backend=attn_backend,
attn_mask_enabled=attn_mask_enabled,
),
dim=dim,
heads=heads,
dim_head=dim_head,