mirror of
https://github.com/SWivid/F5-TTS.git
synced 2025-12-12 07:40:43 -08:00
set attn related default value for unet-t backbone: #1192
This commit is contained in:
@@ -120,6 +120,8 @@ class UNetT(nn.Module):
|
||||
qk_norm=None,
|
||||
conv_layers=0,
|
||||
pe_attn_head=None,
|
||||
attn_backend="torch", # "torch" | "flash_attn"
|
||||
attn_mask_enabled=False,
|
||||
skip_connect_type: Literal["add", "concat", "none"] = "concat",
|
||||
):
|
||||
super().__init__()
|
||||
@@ -150,7 +152,11 @@ class UNetT(nn.Module):
|
||||
|
||||
attn_norm = RMSNorm(dim)
|
||||
attn = Attention(
|
||||
processor=AttnProcessor(pe_attn_head=pe_attn_head),
|
||||
processor=AttnProcessor(
|
||||
pe_attn_head=pe_attn_head,
|
||||
attn_backend=attn_backend,
|
||||
attn_mask_enabled=attn_mask_enabled,
|
||||
),
|
||||
dim=dim,
|
||||
heads=heads,
|
||||
dim_head=dim_head,
|
||||
|
||||
Reference in New Issue
Block a user