From bdddc20d68a7441cccfcf0009528fdd59403b94a Mon Sep 17 00:00:00 2001 From: Kohya S Date: Wed, 30 Oct 2024 12:51:49 +0900 Subject: [PATCH] support SD3.5M --- library/sd3_models.py | 128 +++++++++++++++++++++++-------------- library/sd3_train_utils.py | 7 ++ library/sd3_utils.py | 13 ++-- sd3_train.py | 8 +-- sd3_train_network.py | 1 + 5 files changed, 99 insertions(+), 58 deletions(-) diff --git a/library/sd3_models.py b/library/sd3_models.py index 5d09f74e8..840f91869 100644 --- a/library/sd3_models.py +++ b/library/sd3_models.py @@ -51,7 +51,7 @@ class SD3Params: pos_embed_max_size: int adm_in_channels: int qk_norm: Optional[str] - x_block_self_attn_layers: List[int] + x_block_self_attn_layers: list[int] context_embedder_in_features: int context_embedder_out_features: int model_type: str @@ -510,6 +510,7 @@ def __init__( scale_mod_only: bool = False, swiglu: bool = False, qk_norm: Optional[str] = None, + x_block_self_attn: bool = False, **block_kwargs, ): super().__init__() @@ -519,13 +520,14 @@ def __init__( self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) else: self.norm1 = RMSNorm(hidden_size, elementwise_affine=False, eps=1e-6) - self.attn = AttentionLinears( - dim=hidden_size, - num_heads=num_heads, - qkv_bias=qkv_bias, - pre_only=pre_only, - qk_norm=qk_norm, - ) + self.attn = AttentionLinears(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, pre_only=pre_only, qk_norm=qk_norm) + + self.x_block_self_attn = x_block_self_attn + if self.x_block_self_attn: + assert not pre_only + assert not scale_mod_only + self.attn2 = AttentionLinears(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, pre_only=False, qk_norm=qk_norm) + if not pre_only: if not rmsnorm: self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) @@ -546,7 +548,9 @@ def __init__( multiple_of=256, ) self.scale_mod_only = scale_mod_only - if not scale_mod_only: + if self.x_block_self_attn: + n_mods = 9 + elif not scale_mod_only: n_mods = 6 if not pre_only else 2 else: n_mods = 4 if not pre_only else 1 @@ -556,63 +560,64 @@ def __init__( def pre_attention(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor: if not self.pre_only: if not self.scale_mod_only: - ( - shift_msa, - scale_msa, - gate_msa, - shift_mlp, - scale_mlp, - gate_mlp, - ) = self.adaLN_modulation( - c - ).chunk(6, dim=-1) + (shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp) = self.adaLN_modulation(c).chunk(6, dim=-1) else: shift_msa = None shift_mlp = None - ( - scale_msa, - gate_msa, - scale_mlp, - gate_mlp, - ) = self.adaLN_modulation( - c - ).chunk(4, dim=-1) + (scale_msa, gate_msa, scale_mlp, gate_mlp) = self.adaLN_modulation(c).chunk(4, dim=-1) qkv = self.attn.pre_attention(modulate(self.norm1(x), shift_msa, scale_msa)) - return qkv, ( - x, - gate_msa, - shift_mlp, - scale_mlp, - gate_mlp, - ) + return qkv, (x, gate_msa, shift_mlp, scale_mlp, gate_mlp) else: if not self.scale_mod_only: - ( - shift_msa, - scale_msa, - ) = self.adaLN_modulation( - c - ).chunk(2, dim=-1) + (shift_msa, scale_msa) = self.adaLN_modulation(c).chunk(2, dim=-1) else: shift_msa = None scale_msa = self.adaLN_modulation(c) qkv = self.attn.pre_attention(modulate(self.norm1(x), shift_msa, scale_msa)) return qkv, None + def pre_attention_x(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor: + assert self.x_block_self_attn + (shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp, shift_msa2, scale_msa2, gate_msa2) = self.adaLN_modulation( + c + ).chunk(9, dim=1) + x_norm = self.norm1(x) + qkv = 
self.attn.pre_attention(modulate(x_norm, shift_msa, scale_msa)) + qkv2 = self.attn2.pre_attention(modulate(x_norm, shift_msa2, scale_msa2)) + return qkv, qkv2, (x, gate_msa, shift_mlp, scale_mlp, gate_mlp, gate_msa2) + def post_attention(self, attn, x, gate_msa, shift_mlp, scale_mlp, gate_mlp): assert not self.pre_only x = x + gate_msa.unsqueeze(1) * self.attn.post_attention(attn) x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp)) return x + def post_attention_x(self, attn, attn2, x, gate_msa, shift_mlp, scale_mlp, gate_mlp, gate_msa2, attn1_dropout: float = 0.0): + assert not self.pre_only + if attn1_dropout > 0.0: + # Use torch.bernoulli to implement dropout, only dropout the batch dimension + attn1_dropout = torch.bernoulli(torch.full((attn.size(0), 1, 1), 1 - attn1_dropout, device=attn.device)) + attn_ = gate_msa.unsqueeze(1) * self.attn.post_attention(attn) * attn1_dropout + else: + attn_ = gate_msa.unsqueeze(1) * self.attn.post_attention(attn) + x = x + attn_ + attn2_ = gate_msa2.unsqueeze(1) * self.attn2.post_attention(attn2) + x = x + attn2_ + mlp_ = gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp)) + x = x + mlp_ + return x + # JointBlock + block_mixing in mmdit.py class MMDiTBlock(nn.Module): def __init__(self, *args, **kwargs): super().__init__() pre_only = kwargs.pop("pre_only") + x_block_self_attn = kwargs.pop("x_block_self_attn") + self.context_block = SingleDiTBlock(*args, pre_only=pre_only, **kwargs) - self.x_block = SingleDiTBlock(*args, pre_only=False, **kwargs) + self.x_block = SingleDiTBlock(*args, pre_only=False, x_block_self_attn=x_block_self_attn, **kwargs) + self.head_dim = self.x_block.attn.head_dim self.mode = self.x_block.attn_mode self.gradient_checkpointing = False @@ -622,7 +627,11 @@ def enable_gradient_checkpointing(self): def _forward(self, context, x, c): ctx_qkv, ctx_intermediate = self.context_block.pre_attention(context, c) - x_qkv, x_intermediate = self.x_block.pre_attention(x, c) + + if self.x_block.x_block_self_attn: + x_qkv, x_qkv2, x_intermediates = self.x_block.pre_attention_x(x, c) + else: + x_qkv, x_intermediates = self.x_block.pre_attention(x, c) ctx_len = ctx_qkv[0].size(1) @@ -634,11 +643,18 @@ def _forward(self, context, x, c): ctx_attn_out = attn[:, :ctx_len] x_attn_out = attn[:, ctx_len:] - x = self.x_block.post_attention(x_attn_out, *x_intermediate) + if self.x_block.x_block_self_attn: + x_q2, x_k2, x_v2 = x_qkv2 + attn2 = attention(x_q2, x_k2, x_v2, self.x_block.attn2.num_heads) + x = self.x_block.post_attention_x(x_attn_out, attn2, *x_intermediates) + else: + x = self.x_block.post_attention(x_attn_out, *x_intermediates) + if not self.context_block.pre_only: context = self.context_block.post_attention(ctx_attn_out, *ctx_intermediate) else: context = None + return context, x def forward(self, *args, **kwargs): @@ -678,7 +694,9 @@ def __init__( pos_embed_max_size: Optional[int] = None, num_patches=None, qk_norm: Optional[str] = None, + x_block_self_attn_layers: Optional[list[int]] = [], qkv_bias: bool = True, + pos_emb_random_crop_rate: float = 0.0, model_type: str = "sd3m", ): super().__init__() @@ -691,6 +709,8 @@ def __init__( self.pos_embed_scaling_factor = pos_embed_scaling_factor self.pos_embed_offset = pos_embed_offset self.pos_embed_max_size = pos_embed_max_size + self.x_block_self_attn_layers = x_block_self_attn_layers + self.pos_emb_random_crop_rate = pos_emb_random_crop_rate self.gradient_checkpointing = use_checkpoint # hidden_size = default(hidden_size, 64 * 
depth) @@ -751,6 +771,7 @@ def __init__( scale_mod_only=scale_mod_only, swiglu=swiglu, qk_norm=qk_norm, + x_block_self_attn=(i in self.x_block_self_attn_layers), ) for i in range(depth) ] @@ -832,7 +853,10 @@ def _basic_init(module): nn.init.constant_(self.final_layer.linear.weight, 0) nn.init.constant_(self.final_layer.linear.bias, 0) - def cropped_pos_embed(self, h, w, device=None): + def set_pos_emb_random_crop_rate(self, rate: float): + self.pos_emb_random_crop_rate = rate + + def cropped_pos_embed(self, h, w, device=None, random_crop: bool = False): p = self.x_embedder.patch_size # patched size h = (h + 1) // p @@ -842,8 +866,14 @@ def cropped_pos_embed(self, h, w, device=None): assert self.pos_embed_max_size is not None assert h <= self.pos_embed_max_size, (h, self.pos_embed_max_size) assert w <= self.pos_embed_max_size, (w, self.pos_embed_max_size) - top = (self.pos_embed_max_size - h) // 2 - left = (self.pos_embed_max_size - w) // 2 + + if not random_crop: + top = (self.pos_embed_max_size - h) // 2 + left = (self.pos_embed_max_size - w) // 2 + else: + top = torch.randint(0, self.pos_embed_max_size - h + 1, (1,)).item() + left = torch.randint(0, self.pos_embed_max_size - w + 1, (1,)).item() + spatial_pos_embed = self.pos_embed.reshape( 1, self.pos_embed_max_size, @@ -896,9 +926,12 @@ def forward( t: (N,) tensor of diffusion timesteps y: (N, D) tensor of class labels """ + pos_emb_random_crop = ( + False if self.pos_emb_random_crop_rate == 0.0 else torch.rand(1).item() < self.pos_emb_random_crop_rate + ) B, C, H, W = x.shape - x = self.x_embedder(x) + self.cropped_pos_embed(H, W, device=x.device).to(dtype=x.dtype) + x = self.x_embedder(x) + self.cropped_pos_embed(H, W, device=x.device, random_crop=pos_emb_random_crop).to(dtype=x.dtype) c = self.t_embedder(t, dtype=x.dtype) # (N, D) if y is not None and self.y_embedder is not None: y = self.y_embedder(y) # (N, D) @@ -977,6 +1010,7 @@ def create_sd3_mmdit(params: SD3Params, attn_mode: str = "torch") -> MMDiT: depth=params.depth, mlp_ratio=4, qk_norm=params.qk_norm, + x_block_self_attn_layers=params.x_block_self_attn_layers, num_patches=params.num_patches, attn_mode=attn_mode, model_type=params.model_type, diff --git a/library/sd3_train_utils.py b/library/sd3_train_utils.py index 1702e81c2..86f0c9c04 100644 --- a/library/sd3_train_utils.py +++ b/library/sd3_train_utils.py @@ -239,6 +239,13 @@ def add_sd3_training_arguments(parser: argparse.ArgumentParser): default=0.0, help="Dropout rate for T5 encoder, default is 0.0 / T5エンコーダのドロップアウト率、デフォルトは0.0", ) + parser.add_argument( + "--pos_emb_random_crop_rate", + type=float, + default=0.0, + help="Random crop rate for positional embeddings, default is 0.0. 
Only for SD3.5M" + " / 位置埋め込みのランダムクロップ率、デフォルトは0.0。SD3.5M以外では予期しない動作になります", + ) # copy from Diffusers parser.add_argument( diff --git a/library/sd3_utils.py b/library/sd3_utils.py index 71e50de36..1861dfbc2 100644 --- a/library/sd3_utils.py +++ b/library/sd3_utils.py @@ -41,20 +41,21 @@ def analyze_state_dict_state(state_dict: Dict, prefix: str = ""): # x_block_self_attn_layers.append(int(key.split(".x_block.attn2.ln_k.weight")[0].split(".")[-1])) x_block_self_attn_layers = [] - re_attn = re.compile(r".(\d+).x_block.attn2.ln_k.weight") + re_attn = re.compile(r"\.(\d+)\.x_block\.attn2\.ln_k\.weight") for key in list(state_dict.keys()): - m = re_attn.match(key) + m = re_attn.search(key) if m: x_block_self_attn_layers.append(int(m.group(1))) - assert len(x_block_self_attn_layers) == 0, "x_block_self_attn_layers is not supported" - context_embedder_in_features = context_shape[1] context_embedder_out_features = context_shape[0] - # only supports 3-5-large and 3-medium + # only supports 3-5-large, medium or 3-medium if qk_norm is not None: - model_type = "3-5-large" + if len(x_block_self_attn_layers) == 0: + model_type = "3-5-large" + else: + model_type = "3-5-medium" else: model_type = "3-medium" diff --git a/sd3_train.py b/sd3_train.py index cdac945e6..df2736901 100644 --- a/sd3_train.py +++ b/sd3_train.py @@ -353,17 +353,15 @@ def train(args): accelerator.wait_for_everyone() # load MMDIT - mmdit = sd3_utils.load_mmdit( - sd3_state_dict, - model_dtype, - "cpu", - ) + mmdit = sd3_utils.load_mmdit(sd3_state_dict, model_dtype, "cpu") # attn_mode = "xformers" if args.xformers else "torch" # assert ( # attn_mode == "torch" # ), f"attn_mode {attn_mode} is not supported yet. Please use `--sdpa` instead of `--xformers`. / attn_mode {attn_mode} はサポートされていません。`--xformers`の代わりに`--sdpa`を使ってください。" + mmdit.set_pos_emb_random_crop_rate(args.pos_emb_random_crop_rate) + if args.gradient_checkpointing: mmdit.enable_gradient_checkpointing() diff --git a/sd3_train_network.py b/sd3_train_network.py index 3506404ae..3d2a75710 100644 --- a/sd3_train_network.py +++ b/sd3_train_network.py @@ -65,6 +65,7 @@ def load_target_model(self, args, weight_dtype, accelerator): ) mmdit = sd3_utils.load_mmdit(state_dict, loading_dtype, "cpu") self.model_type = mmdit.model_type + mmdit.set_pos_emb_random_crop_rate(args.pos_emb_random_crop_rate) if args.fp8_base: # check dtype of model
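
Notes for reviewers: the sketches below are standalone illustrations of the changes above; none of this text is part of the patch, and all class/variable names in the examples are made up for illustration.

1) Dual self-attention blocks. For the layers listed in x_block_self_attn_layers, the SD3.5M blocks add a second image-only self-attention branch (attn2) and widen the adaLN modulation from 6 to 9 chunks (shift/scale/gate for attn1, for the MLP, and for attn2), which is what pre_attention_x/post_attention_x implement. A minimal sketch of how the two branches combine, assuming toy sizes, a plain scaled_dot_product_attention in place of AttentionLinears, and with the joint context+image attention of the real MMDiTBlock collapsed to image-only attention for brevity:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class DualAttnBlockSketch(nn.Module):
        # Illustrative only; sizes and module names do not match the library.
        def __init__(self, hidden_size=64, num_heads=4):
            super().__init__()
            self.num_heads = num_heads
            self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
            self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
            self.qkv1, self.proj1 = nn.Linear(hidden_size, hidden_size * 3), nn.Linear(hidden_size, hidden_size)
            self.qkv2, self.proj2 = nn.Linear(hidden_size, hidden_size * 3), nn.Linear(hidden_size, hidden_size)
            self.mlp = nn.Sequential(nn.Linear(hidden_size, hidden_size * 4), nn.GELU(), nn.Linear(hidden_size * 4, hidden_size))
            # 9 modulation tensors instead of 6: the extra 3 drive the second attention branch
            self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, hidden_size * 9))

        def _attn(self, x, qkv, proj):
            B, L, D = x.shape
            q, k, v = qkv(x).chunk(3, dim=-1)
            q, k, v = (t.view(B, L, self.num_heads, -1).transpose(1, 2) for t in (q, k, v))
            out = F.scaled_dot_product_attention(q, k, v)
            return proj(out.transpose(1, 2).reshape(B, L, D))

        def forward(self, x, c):
            (shift1, scale1, gate1, shift_mlp, scale_mlp, gate_mlp,
             shift2, scale2, gate2) = self.adaLN_modulation(c).chunk(9, dim=-1)
            # both attention branches read the same pre-residual x_norm, as in pre_attention_x
            x_norm = self.norm1(x)
            # branch 1 (in the real block this input is concatenated with the context tokens for joint attention)
            x = x + gate1.unsqueeze(1) * self._attn(x_norm * (1 + scale1.unsqueeze(1)) + shift1.unsqueeze(1), self.qkv1, self.proj1)
            # branch 2: the extra image-only self-attention added for SD3.5M
            x = x + gate2.unsqueeze(1) * self._attn(x_norm * (1 + scale2.unsqueeze(1)) + shift2.unsqueeze(1), self.qkv2, self.proj2)
            x = x + gate_mlp.unsqueeze(1) * self.mlp(self.norm2(x) * (1 + scale_mlp.unsqueeze(1)) + shift_mlp.unsqueeze(1))
            return x

    x = torch.randn(2, 16, 64)   # (batch, image tokens, hidden)
    c = torch.randn(2, 64)       # pooled conditioning vector
    print(DualAttnBlockSketch()(x, c).shape)  # torch.Size([2, 16, 64])

Residuals are applied sequentially (attn1, then attn2, then MLP), matching post_attention_x; the optional attn1_dropout there only zeroes the first branch per batch element via torch.bernoulli.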
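
2) Random crop of positional embeddings. With --pos_emb_random_crop_rate > 0, each forward pass takes the random branch with that probability (torch.rand(1).item() < rate) and slices a random h x w window out of the pos_embed_max_size grid instead of the centered one. A minimal sketch of just the cropping logic, with an arbitrary dummy grid and embedding width:

    import torch

    def crop_pos_embed(pos_embed, max_size, h, w, random_crop=False):
        # pos_embed: (1, max_size * max_size, dim), laid out row-major over a max_size x max_size grid
        if random_crop:
            top = torch.randint(0, max_size - h + 1, (1,)).item()
            left = torch.randint(0, max_size - w + 1, (1,)).item()
        else:
            top = (max_size - h) // 2
            left = (max_size - w) // 2
        grid = pos_embed.reshape(1, max_size, max_size, -1)
        return grid[:, top : top + h, left : left + w, :].reshape(1, h * w, -1)

    pos_embed = torch.randn(1, 192 * 192, 16)            # dummy table standing in for self.pos_embed
    centered = crop_pos_embed(pos_embed, 192, 32, 48)     # rate 0.0 (the default): deterministic center crop
    jittered = crop_pos_embed(pos_embed, 192, 32, 48, random_crop=True)
    print(centered.shape, jittered.shape)                 # torch.Size([1, 1536, 16]) for both

The training scripts only change which window is selected, not its size, so latent patching and the rest of the forward pass are unaffected.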
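
3) Detecting the SD3.5M layers from the state dict. The regex in analyze_state_dict_state now escapes the dots and uses re.search instead of re.match: checkpoint keys look like joint_blocks.12.x_block.attn2.ln_k.weight (possibly with a longer prefix in front), so a pattern starting with \.(\d+) anchored at the start of the string could never match. A quick self-contained check, with made-up key names following that pattern; a non-empty result combined with qk_norm is what the patch classifies as "3-5-medium":

    import re

    keys = [
        "joint_blocks.0.x_block.attn.qkv.weight",
        "joint_blocks.12.x_block.attn2.ln_k.weight",
        "joint_blocks.13.x_block.attn2.ln_k.weight",
    ]

    re_attn = re.compile(r"\.(\d+)\.x_block\.attn2\.ln_k\.weight")
    layers = sorted(int(m.group(1)) for k in keys if (m := re_attn.search(k)))
    print(layers)  # [12, 13]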