support SD3.5M
kohya-ss committed Oct 30, 2024
1 parent 7555486 commit bdddc20
Showing 5 changed files with 99 additions and 58 deletions.
128 changes: 81 additions & 47 deletions library/sd3_models.py
@@ -51,7 +51,7 @@ class SD3Params:
pos_embed_max_size: int
adm_in_channels: int
qk_norm: Optional[str]
x_block_self_attn_layers: List[int]
x_block_self_attn_layers: list[int]
context_embedder_in_features: int
context_embedder_out_features: int
model_type: str
@@ -510,6 +510,7 @@ def __init__(
scale_mod_only: bool = False,
swiglu: bool = False,
qk_norm: Optional[str] = None,
x_block_self_attn: bool = False,
**block_kwargs,
):
super().__init__()
@@ -519,13 +520,14 @@ def __init__(
self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
else:
self.norm1 = RMSNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.attn = AttentionLinears(
dim=hidden_size,
num_heads=num_heads,
qkv_bias=qkv_bias,
pre_only=pre_only,
qk_norm=qk_norm,
)
self.attn = AttentionLinears(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, pre_only=pre_only, qk_norm=qk_norm)

self.x_block_self_attn = x_block_self_attn
if self.x_block_self_attn:
assert not pre_only
assert not scale_mod_only
self.attn2 = AttentionLinears(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, pre_only=False, qk_norm=qk_norm)

if not pre_only:
if not rmsnorm:
self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
@@ -546,7 +548,9 @@ def __init__(
multiple_of=256,
)
self.scale_mod_only = scale_mod_only
if not scale_mod_only:
if self.x_block_self_attn:
n_mods = 9
elif not scale_mod_only:
n_mods = 6 if not pre_only else 2
else:
n_mods = 4 if not pre_only else 1
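
When x_block_self_attn is enabled, the block needs nine modulation tensors instead of the usual six: a (shift, scale, gate) triple each for the joint attention, the extra self-attention (attn2), and the MLP. A minimal sketch with toy dimensions, where a bare nn.Linear stands in for the block's adaLN_modulation module:

```python
# Toy sketch of the 9-way adaLN modulation split used when x_block_self_attn is on
# (hypothetical sizes; a bare nn.Linear stands in for adaLN_modulation).
import torch
import torch.nn as nn

hidden_size = 16
c = torch.randn(2, hidden_size)                 # conditioning vector, shape (N, D)
adaln = nn.Linear(hidden_size, 9 * hidden_size)

mods = adaln(c).chunk(9, dim=-1)
(shift_msa, scale_msa, gate_msa,
 shift_mlp, scale_mlp, gate_mlp,
 shift_msa2, scale_msa2, gate_msa2) = mods
print([m.shape for m in mods])                  # nine tensors of shape (2, 16)
```
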
@@ -556,63 +560,64 @@ def __init__(
def pre_attention(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
if not self.pre_only:
if not self.scale_mod_only:
(
shift_msa,
scale_msa,
gate_msa,
shift_mlp,
scale_mlp,
gate_mlp,
) = self.adaLN_modulation(
c
).chunk(6, dim=-1)
(shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp) = self.adaLN_modulation(c).chunk(6, dim=-1)
else:
shift_msa = None
shift_mlp = None
(
scale_msa,
gate_msa,
scale_mlp,
gate_mlp,
) = self.adaLN_modulation(
c
).chunk(4, dim=-1)
(scale_msa, gate_msa, scale_mlp, gate_mlp) = self.adaLN_modulation(c).chunk(4, dim=-1)
qkv = self.attn.pre_attention(modulate(self.norm1(x), shift_msa, scale_msa))
return qkv, (
x,
gate_msa,
shift_mlp,
scale_mlp,
gate_mlp,
)
return qkv, (x, gate_msa, shift_mlp, scale_mlp, gate_mlp)
else:
if not self.scale_mod_only:
(
shift_msa,
scale_msa,
) = self.adaLN_modulation(
c
).chunk(2, dim=-1)
(shift_msa, scale_msa) = self.adaLN_modulation(c).chunk(2, dim=-1)
else:
shift_msa = None
scale_msa = self.adaLN_modulation(c)
qkv = self.attn.pre_attention(modulate(self.norm1(x), shift_msa, scale_msa))
return qkv, None

def pre_attention_x(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
assert self.x_block_self_attn
(shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp, shift_msa2, scale_msa2, gate_msa2) = self.adaLN_modulation(
c
).chunk(9, dim=1)
x_norm = self.norm1(x)
qkv = self.attn.pre_attention(modulate(x_norm, shift_msa, scale_msa))
qkv2 = self.attn2.pre_attention(modulate(x_norm, shift_msa2, scale_msa2))
return qkv, qkv2, (x, gate_msa, shift_mlp, scale_mlp, gate_mlp, gate_msa2)

def post_attention(self, attn, x, gate_msa, shift_mlp, scale_mlp, gate_mlp):
assert not self.pre_only
x = x + gate_msa.unsqueeze(1) * self.attn.post_attention(attn)
x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
return x

def post_attention_x(self, attn, attn2, x, gate_msa, shift_mlp, scale_mlp, gate_mlp, gate_msa2, attn1_dropout: float = 0.0):
assert not self.pre_only
if attn1_dropout > 0.0:
# Use torch.bernoulli to implement dropout, dropping only along the batch dimension (whole samples are kept or zeroed)
attn1_dropout = torch.bernoulli(torch.full((attn.size(0), 1, 1), 1 - attn1_dropout, device=attn.device))
attn_ = gate_msa.unsqueeze(1) * self.attn.post_attention(attn) * attn1_dropout
else:
attn_ = gate_msa.unsqueeze(1) * self.attn.post_attention(attn)
x = x + attn_
attn2_ = gate_msa2.unsqueeze(1) * self.attn2.post_attention(attn2)
x = x + attn2_
mlp_ = gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
x = x + mlp_
return x


# JointBlock + block_mixing in mmdit.py
class MMDiTBlock(nn.Module):
def __init__(self, *args, **kwargs):
super().__init__()
pre_only = kwargs.pop("pre_only")
x_block_self_attn = kwargs.pop("x_block_self_attn")

self.context_block = SingleDiTBlock(*args, pre_only=pre_only, **kwargs)
self.x_block = SingleDiTBlock(*args, pre_only=False, **kwargs)
self.x_block = SingleDiTBlock(*args, pre_only=False, x_block_self_attn=x_block_self_attn, **kwargs)

self.head_dim = self.x_block.attn.head_dim
self.mode = self.x_block.attn_mode
self.gradient_checkpointing = False
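
Two things are worth noting in the new SingleDiTBlock paths: pre_attention_x builds a second qkv set from the same normalized input using its own shift/scale pair, and post_attention_x can drop the first attention branch per sample. The attn1_dropout mask comes from torch.bernoulli over the batch dimension only, so a random subset of samples loses the whole attn contribution while the rest keep it intact. A small illustrative sketch with toy tensors, not the repository code:

```python
# Per-sample dropout of the attn1 branch, as done in post_attention_x
# (toy stand-ins for the real gate and attention outputs).
import torch

B, L, D = 4, 8, 16
attn_out = torch.randn(B, L, D)        # stand-in for attn.post_attention(attn)
attn1_dropout = 0.5

keep = torch.bernoulli(torch.full((B, 1, 1), 1 - attn1_dropout))
attn_ = attn_out * keep                # each sample is kept or zeroed as a whole
print(keep.flatten())                  # e.g. tensor([1., 0., 1., 1.])
```
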
@@ -622,7 +627,11 @@ def enable_gradient_checkpointing(self):

def _forward(self, context, x, c):
ctx_qkv, ctx_intermediate = self.context_block.pre_attention(context, c)
x_qkv, x_intermediate = self.x_block.pre_attention(x, c)

if self.x_block.x_block_self_attn:
x_qkv, x_qkv2, x_intermediates = self.x_block.pre_attention_x(x, c)
else:
x_qkv, x_intermediates = self.x_block.pre_attention(x, c)

ctx_len = ctx_qkv[0].size(1)

Expand All @@ -634,11 +643,18 @@ def _forward(self, context, x, c):
ctx_attn_out = attn[:, :ctx_len]
x_attn_out = attn[:, ctx_len:]

x = self.x_block.post_attention(x_attn_out, *x_intermediate)
if self.x_block.x_block_self_attn:
x_q2, x_k2, x_v2 = x_qkv2
attn2 = attention(x_q2, x_k2, x_v2, self.x_block.attn2.num_heads)
x = self.x_block.post_attention_x(x_attn_out, attn2, *x_intermediates)
else:
x = self.x_block.post_attention(x_attn_out, *x_intermediates)

if not self.context_block.pre_only:
context = self.context_block.post_attention(ctx_attn_out, *ctx_intermediate)
else:
context = None

return context, x

def forward(self, *args, **kwargs):
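
In _forward, the context and image streams are concatenated along the sequence axis, attended jointly, and split back at ctx_len; when x_block_self_attn is enabled, the image tokens additionally pass through the image-only attn2, whose output post_attention_x adds back in. A simplified single-head sketch of that flow, using toy shapes and torch's scaled_dot_product_attention rather than the library's attention helper:

```python
# Joint attention over [context; x] plus the SD3.5M-style extra self-attention
# over the image tokens only (single head, toy dimensions).
import torch
import torch.nn.functional as F

B, ctx_len, x_len, D = 2, 8, 16, 32
ctx_q, ctx_k, ctx_v = (torch.randn(B, ctx_len, D) for _ in range(3))
x_q, x_k, x_v = (torch.randn(B, x_len, D) for _ in range(3))

# joint attention over the concatenated sequence, then split back
q = torch.cat([ctx_q, x_q], dim=1)
k = torch.cat([ctx_k, x_k], dim=1)
v = torch.cat([ctx_v, x_v], dim=1)
attn = F.scaled_dot_product_attention(q, k, v)
ctx_attn_out, x_attn_out = attn[:, :ctx_len], attn[:, ctx_len:]

# extra path: plain self-attention over the image tokens only
attn2 = F.scaled_dot_product_attention(x_q, x_k, x_v)
print(ctx_attn_out.shape, x_attn_out.shape, attn2.shape)
```
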
@@ -678,7 +694,9 @@ def __init__(
pos_embed_max_size: Optional[int] = None,
num_patches=None,
qk_norm: Optional[str] = None,
x_block_self_attn_layers: Optional[list[int]] = [],
qkv_bias: bool = True,
pos_emb_random_crop_rate: float = 0.0,
model_type: str = "sd3m",
):
super().__init__()
@@ -691,6 +709,8 @@ def __init__(
self.pos_embed_scaling_factor = pos_embed_scaling_factor
self.pos_embed_offset = pos_embed_offset
self.pos_embed_max_size = pos_embed_max_size
self.x_block_self_attn_layers = x_block_self_attn_layers
self.pos_emb_random_crop_rate = pos_emb_random_crop_rate
self.gradient_checkpointing = use_checkpoint

# hidden_size = default(hidden_size, 64 * depth)
@@ -751,6 +771,7 @@ def __init__(
scale_mod_only=scale_mod_only,
swiglu=swiglu,
qk_norm=qk_norm,
x_block_self_attn=(i in self.x_block_self_attn_layers),
)
for i in range(depth)
]
@@ -832,7 +853,10 @@ def _basic_init(module):
nn.init.constant_(self.final_layer.linear.weight, 0)
nn.init.constant_(self.final_layer.linear.bias, 0)

def cropped_pos_embed(self, h, w, device=None):
def set_pos_emb_random_crop_rate(self, rate: float):
self.pos_emb_random_crop_rate = rate

def cropped_pos_embed(self, h, w, device=None, random_crop: bool = False):
p = self.x_embedder.patch_size
# patched size
h = (h + 1) // p
@@ -842,8 +866,14 @@ def cropped_pos_embed(self, h, w, device=None):
assert self.pos_embed_max_size is not None
assert h <= self.pos_embed_max_size, (h, self.pos_embed_max_size)
assert w <= self.pos_embed_max_size, (w, self.pos_embed_max_size)
top = (self.pos_embed_max_size - h) // 2
left = (self.pos_embed_max_size - w) // 2

if not random_crop:
top = (self.pos_embed_max_size - h) // 2
left = (self.pos_embed_max_size - w) // 2
else:
top = torch.randint(0, self.pos_embed_max_size - h + 1, (1,)).item()
left = torch.randint(0, self.pos_embed_max_size - w + 1, (1,)).item()

spatial_pos_embed = self.pos_embed.reshape(
1,
self.pos_embed_max_size,
@@ -896,9 +926,12 @@ def forward(
t: (N,) tensor of diffusion timesteps
y: (N, D) tensor of class labels
"""
pos_emb_random_crop = (
False if self.pos_emb_random_crop_rate == 0.0 else torch.rand(1).item() < self.pos_emb_random_crop_rate
)

B, C, H, W = x.shape
x = self.x_embedder(x) + self.cropped_pos_embed(H, W, device=x.device).to(dtype=x.dtype)
x = self.x_embedder(x) + self.cropped_pos_embed(H, W, device=x.device, random_crop=pos_emb_random_crop).to(dtype=x.dtype)
c = self.t_embedder(t, dtype=x.dtype) # (N, D)
if y is not None and self.y_embedder is not None:
y = self.y_embedder(y) # (N, D)
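
cropped_pos_embed treats the learned positional embedding as a (pos_embed_max_size x pos_embed_max_size) grid and cuts out an (h x w) window. By default the window is centered; with --pos_emb_random_crop_rate, forward occasionally requests a random top-left corner instead, which is the SD3.5M-oriented augmentation this commit adds. A toy-sized sketch of the cropping logic:

```python
# Center vs. random crop of the positional-embedding grid (toy sizes).
import torch

max_size, h, w, D = 8, 4, 6, 3
pos_embed = torch.randn(1, max_size * max_size, D)
grid = pos_embed.reshape(1, max_size, max_size, D)

random_crop = True                       # in the model, sampled per step from the crop rate
if not random_crop:
    top = (max_size - h) // 2            # default: center crop
    left = (max_size - w) // 2
else:
    top = torch.randint(0, max_size - h + 1, (1,)).item()
    left = torch.randint(0, max_size - w + 1, (1,)).item()

window = grid[:, top : top + h, left : left + w, :].reshape(1, h * w, D)
print(window.shape)                      # torch.Size([1, 24, 3])
```
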
@@ -977,6 +1010,7 @@ def create_sd3_mmdit(params: SD3Params, attn_mode: str = "torch") -> MMDiT:
depth=params.depth,
mlp_ratio=4,
qk_norm=params.qk_norm,
x_block_self_attn_layers=params.x_block_self_attn_layers,
num_patches=params.num_patches,
attn_mode=attn_mode,
model_type=params.model_type,
7 changes: 7 additions & 0 deletions library/sd3_train_utils.py
@@ -239,6 +239,13 @@ def add_sd3_training_arguments(parser: argparse.ArgumentParser):
default=0.0,
help="Dropout rate for T5 encoder, default is 0.0 / T5エンコーダのドロップアウト率、デフォルトは0.0",
)
parser.add_argument(
"--pos_emb_random_crop_rate",
type=float,
default=0.0,
help="Random crop rate for positional embeddings, default is 0.0. Only for SD3.5M"
" / 位置埋め込みのランダムクロップ率、デフォルトは0.0。SD3.5M以外では予期しない動作になります",
)
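
For reference, a minimal sketch of how the new flag is meant to reach the model: the parser below is a stripped-down, hypothetical stand-in for add_sd3_training_arguments, and the commented call mirrors the set_pos_emb_random_crop_rate hook the commit adds to the training scripts:

```python
# Hypothetical, stripped-down wiring of --pos_emb_random_crop_rate.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--pos_emb_random_crop_rate", type=float, default=0.0)
args = parser.parse_args(["--pos_emb_random_crop_rate", "0.5"])

# after the MMDiT has been loaded (mmdit is assumed to exist at this point):
# mmdit.set_pos_emb_random_crop_rate(args.pos_emb_random_crop_rate)
print(args.pos_emb_random_crop_rate)  # 0.5
```
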

# copy from Diffusers
parser.add_argument(
13 changes: 7 additions & 6 deletions library/sd3_utils.py
@@ -41,20 +41,21 @@ def analyze_state_dict_state(state_dict: Dict, prefix: str = ""):

# x_block_self_attn_layers.append(int(key.split(".x_block.attn2.ln_k.weight")[0].split(".")[-1]))
x_block_self_attn_layers = []
re_attn = re.compile(r".(\d+).x_block.attn2.ln_k.weight")
re_attn = re.compile(r"\.(\d+)\.x_block\.attn2\.ln_k\.weight")
for key in list(state_dict.keys()):
m = re_attn.match(key)
m = re_attn.search(key)
if m:
x_block_self_attn_layers.append(int(m.group(1)))

assert len(x_block_self_attn_layers) == 0, "x_block_self_attn_layers is not supported"

context_embedder_in_features = context_shape[1]
context_embedder_out_features = context_shape[0]

# only supports 3-5-large and 3-medium
# only supports 3-5-large, 3-5-medium or 3-medium
if qk_norm is not None:
model_type = "3-5-large"
if len(x_block_self_attn_layers) == 0:
model_type = "3-5-large"
else:
model_type = "3-5-medium"
else:
model_type = "3-medium"

8 changes: 3 additions & 5 deletions sd3_train.py
@@ -353,17 +353,15 @@ def train(args):
accelerator.wait_for_everyone()

# load MMDIT
mmdit = sd3_utils.load_mmdit(
sd3_state_dict,
model_dtype,
"cpu",
)
mmdit = sd3_utils.load_mmdit(sd3_state_dict, model_dtype, "cpu")

# attn_mode = "xformers" if args.xformers else "torch"
# assert (
# attn_mode == "torch"
# ), f"attn_mode {attn_mode} is not supported yet. Please use `--sdpa` instead of `--xformers`. / attn_mode {attn_mode} はサポートされていません。`--xformers`の代わりに`--sdpa`を使ってください。"

mmdit.set_pos_emb_random_crop_rate(args.pos_emb_random_crop_rate)

if args.gradient_checkpointing:
mmdit.enable_gradient_checkpointing()

1 change: 1 addition & 0 deletions sd3_train_network.py
@@ -65,6 +65,7 @@ def load_target_model(self, args, weight_dtype, accelerator):
)
mmdit = sd3_utils.load_mmdit(state_dict, loading_dtype, "cpu")
self.model_type = mmdit.model_type
mmdit.set_pos_emb_random_crop_rate(args.pos_emb_random_crop_rate)

if args.fp8_base:
# check dtype of model
