Merge branch 'sd3_5_support' of https://github.com/kohya-ss/sd-scripts into sd3_5_support
kohya-ss committed Oct 30, 2024
2 parents bdddc20 + b502f58 commit 8c3c825
Showing 3 changed files with 74 additions and 22 deletions.
11 changes: 11 additions & 0 deletions library/flux_utils.py
@@ -54,6 +54,10 @@ def analyze_checkpoint_state(ckpt_path: str) -> Tuple[bool, bool, Tuple[int, int
with safe_open(ckpt_path, framework="pt") as f:
keys.extend(f.keys())

# if the keys have an annoying prefix, remove it
if keys[0].startswith("model.diffusion_model."):
keys = [key.replace("model.diffusion_model.", "") for key in keys]

is_diffusers = "transformer_blocks.0.attn.add_k_proj.bias" in keys
is_schnell = not ("guidance_in.in_layer.bias" in keys or "time_text_embed.guidance_embedder.linear_1.bias" in keys)

@@ -122,6 +126,13 @@ def load_flow_model(
sd = convert_diffusers_sd_to_bfl(sd, num_double_blocks, num_single_blocks)
logger.info("Converted Diffusers to BFL")

# if the keys have an annoying prefix, remove it
for key in list(sd.keys()):
new_key = key.replace("model.diffusion_model.", "")
if new_key == key:
break  # the model doesn't have the annoying prefix
sd[new_key] = sd.pop(key)

info = model.load_state_dict(sd, strict=False, assign=True)
logger.info(f"Loaded Flux: {info}")
return is_schnell, model
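
Both hunks normalize the same quirk: some checkpoints store every tensor under a "model.diffusion_model." prefix, which would otherwise defeat both the key probing in analyze_checkpoint_state and load_state_dict. A minimal sketch of the stripping logic, using hypothetical key names:

sd = {
    "model.diffusion_model.img_in.weight": 0,  # illustrative keys and values
    "model.diffusion_model.img_in.bias": 1,
}
for key in list(sd.keys()):
    new_key = key.replace("model.diffusion_model.", "")
    if new_key == key:
        break  # no prefix on this checkpoint; nothing to rename
    sd[new_key] = sd.pop(key)

assert set(sd) == {"img_in.weight", "img_in.bias"}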
25 changes: 19 additions & 6 deletions networks/lora_sd3.py
@@ -307,6 +307,7 @@ def create_modules(
target_replace_modules: List[str],
filter: Optional[str] = None,
default_dim: Optional[int] = None,
include_conv2d_if_filter: bool = False,
) -> List[LoRAModule]:
prefix = (
self.LORA_PREFIX_SD3
@@ -332,8 +333,11 @@
lora_name = prefix + "." + (name + "." if name else "") + child_name
lora_name = lora_name.replace(".", "_")

- if filter is not None and not filter in lora_name:
-     continue
+ force_incl_conv2d = False
+ if filter is not None:
+     if not filter in lora_name:
+         continue
+     force_incl_conv2d = include_conv2d_if_filter

dim = None
alpha = None
@@ -373,6 +377,10 @@
elif self.conv_lora_dim is not None:
dim = self.conv_lora_dim
alpha = self.conv_alpha
elif force_incl_conv2d:
# x_embedder
dim = default_dim if default_dim is not None else self.lora_dim
alpha = self.alpha

if dim is None or dim == 0:
# log which modules were skipped
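
For context: x_embedder in SD3's MMDiT is a Conv2d (the patchify layer), so unless conv LoRA dims are configured it would fall through these checks and be skipped. The new include_conv2d_if_filter flag forces a rank for it when a filter explicitly targets it. A simplified sketch of the decision order (pick_dim is hypothetical and omits the per-module overrides):

def pick_dim(is_linear, conv_lora_dim, lora_dim, default_dim, force_incl_conv2d):
    # returns the LoRA rank to use for a module, or None to skip it
    if is_linear:
        return default_dim if default_dim is not None else lora_dim
    if conv_lora_dim is not None:  # conv2d, conv ranks configured
        return conv_lora_dim
    if force_incl_conv2d:  # e.g. x_embedder under its filter
        return default_dim if default_dim is not None else lora_dim
    return None  # other conv2d modules are skipped

assert pick_dim(False, None, 16, 64, True) == 64       # x_embedder is kept
assert pick_dim(False, None, 16, None, False) is None  # ordinary conv2d is skipped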
@@ -428,15 +436,20 @@ def create_modules(
for filter, in_dim in zip(
[
"context_embedder",
"t_embedder",
"_t_embedder", # don't use "t_embedder" because it's used in "context_embedder"
"x_embedder",
"y_embedder",
"final_layer_adaLN_modulation",
"final_layer_linear",
],
self.emb_dims,
):
- loras, _ = create_modules(True, None, unet, None, filter=filter, default_dim=in_dim)
+ # x_embedder is conv2d, so we need to include it
+ loras, _ = create_modules(
+     True, None, unet, None, filter=filter, default_dim=in_dim, include_conv2d_if_filter=filter == "x_embedder"
+ )
# if len(loras) > 0:
# logger.info(f"create LoRA for {filter}: {len(loras)} modules.")
self.unet_loras.extend(loras)

logger.info(f"create LoRA for SD3 MMDiT: {len(self.unet_loras)} modules.")
@@ -540,8 +553,8 @@ def state_dict(self, destination=None, prefix="", keep_vars=False):
down_weight = torch.cat(down_weights, dim=0) # (rank, split_dim) * 3 -> (rank*3, sum of split_dim)

# merge up weight (sum of split_dim, rank*3)
- qkv_dim, rank = up_weights[0].size()
- split_dim = qkv_dim // 3
+ split_dim, rank = up_weights[0].size()
+ qkv_dim = split_dim * 3
up_weight = torch.zeros((qkv_dim, down_weight.size(0)), device=down_weight.device, dtype=down_weight.dtype)
i = 0
for j in range(3):
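
The state_dict fix above corrects the shape unpacking when split q/k/v LoRA weights are merged for saving: each up weight is (split_dim, rank), not (qkv_dim, rank), and the three are packed block-diagonally so the merged LoRA reproduces q, k, and v independently. A rough sketch with dummy tensors (sizes are illustrative, not SD3's real dimensions):

import torch

rank, split_dim = 4, 8
up_weights = [torch.randn(split_dim, rank) for _ in range(3)]  # q, k, v

split_dim, rank = up_weights[0].size()  # the fixed unpacking
qkv_dim = split_dim * 3
up_weight = torch.zeros((qkv_dim, rank * 3))
i = 0
for j in range(3):
    up_weight[i : i + split_dim, j * rank : (j + 1) * rank] = up_weights[j]
    i += split_dim

assert up_weight.shape == (24, 12)  # (qkv_dim, rank * 3)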
60 changes: 44 additions & 16 deletions sd3_minimal_inference.py
@@ -10,11 +10,13 @@

import torch
from safetensors.torch import safe_open, load_file
import torch.amp
from tqdm import tqdm
from PIL import Image
from transformers import CLIPTextModelWithProjection, T5EncoderModel

from library.device_utils import init_ipex, get_preferred_device
from networks import lora_sd3

init_ipex()

@@ -104,7 +106,8 @@ def do_sample(
x_c_nc = torch.cat([x, x], dim=0)
# print(x_c_nc.shape, timestep.shape, c_crossattn.shape, y.shape)

- model_output = mmdit(x_c_nc, timestep, context=c_crossattn, y=y)
+ with torch.autocast(device_type=device.type, dtype=dtype):
+     model_output = mmdit(x_c_nc, timestep, context=c_crossattn, y=y)
model_output = model_output.float()
batched = model_sampling.calculate_denoised(sigma_hat, model_output, x)
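
The autocast change keeps only the MMDiT forward in reduced precision; the .float() cast restores fp32 before calculate_denoised. A standalone sketch of the pattern, with a Linear standing in for the transformer:

import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = torch.bfloat16  # bf16 autocast works on both CPU and CUDA

model = torch.nn.Linear(16, 16).to(device)
x = torch.randn(2, 16, device=device)

with torch.autocast(device_type=device.type, dtype=dtype):
    model_output = model(x)  # forward runs in reduced precision where supported
model_output = model_output.float()  # downstream scheduler math stays in fp32
assert model_output.dtype == torch.float32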

@@ -153,7 +156,7 @@ def generate_image(
clip_g.to(device)
t5xxl.to(device)

- with torch.no_grad():
+ with torch.autocast(device_type=device.type, dtype=mmdit.dtype), torch.no_grad():
tokens_and_masks = tokenize_strategy.tokenize(prompt)
lg_out, t5_out, pooled, l_attn_mask, g_attn_mask, t5_attn_mask = encoding_strategy.encode_tokens(
tokenize_strategy, [clip_l, clip_g, t5xxl], tokens_and_masks, args.apply_lg_attn_mask, args.apply_t5_attn_mask
@@ -233,13 +236,14 @@ def generate_image(
parser.add_argument("--bf16", action="store_true")
parser.add_argument("--seed", type=int, default=1)
parser.add_argument("--steps", type=int, default=50)
- # parser.add_argument(
- #     "--lora_weights",
- #     type=str,
- #     nargs="*",
- #     default=[],
- #     help="LoRA weights, only supports networks.lora, each argument is a `path;multiplier` (semi-colon separated)",
- # )
+ parser.add_argument(
+     "--lora_weights",
+     type=str,
+     nargs="*",
+     default=[],
+     help="LoRA weights, only supports networks.lora_sd3, each argument is a `path;multiplier` (semi-colon separated)",
+ )
+ parser.add_argument("--merge_lora_weights", action="store_true", help="Merge LoRA weights to model")
parser.add_argument("--width", type=int, default=target_width)
parser.add_argument("--height", type=int, default=target_height)
parser.add_argument("--interactive", action="store_true")
@@ -294,6 +298,30 @@ def generate_image(
tokenize_strategy = strategy_sd3.Sd3TokenizeStrategy(args.t5xxl_token_length)
encoding_strategy = strategy_sd3.Sd3TextEncodingStrategy()

# LoRA
lora_models: list[lora_sd3.LoRANetwork] = []
for weights_file in args.lora_weights:
if ";" in weights_file:
weights_file, multiplier = weights_file.split(";")
multiplier = float(multiplier)
else:
multiplier = 1.0

weights_sd = load_file(weights_file)
module = lora_sd3
lora_model, _ = module.create_network_from_weights(multiplier, None, vae, [clip_l, clip_g, t5xxl], mmdit, weights_sd, True)

if args.merge_lora_weights:
lora_model.merge_to([clip_l, clip_g, t5xxl], mmdit, weights_sd)
else:
lora_model.apply_to([clip_l, clip_g, t5xxl], mmdit)
info = lora_model.load_state_dict(weights_sd, strict=True)
logger.info(f"Loaded LoRA weights from {weights_file}: {info}")
lora_model.eval()
lora_model.to(device)

lora_models.append(lora_model)

if not args.interactive:
generate_image(
mmdit,
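
Each --lora_weights entry is a path;multiplier pair, with the multiplier defaulting to 1.0 when omitted. A helper equivalent to the inline parsing above (parse_lora_arg is hypothetical, not part of the script):

def parse_lora_arg(entry: str) -> tuple[str, float]:
    # "path;multiplier" -> (path, multiplier); multiplier defaults to 1.0
    if ";" in entry:
        path, mult = entry.split(";")
        return path, float(mult)
    return entry, 1.0

assert parse_lora_arg("sd3_lora.safetensors;0.8") == ("sd3_lora.safetensors", 0.8)
assert parse_lora_arg("sd3_lora.safetensors") == ("sd3_lora.safetensors", 1.0)

Note the design choice: --merge_lora_weights bakes the deltas into the base weights (no runtime overhead, but fixed strength), while the default apply_to path keeps the LoRA modules live so their multipliers can still be changed interactively.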
@@ -344,13 +372,13 @@ def generate_image(
steps = int(opt[1:].strip())
elif opt.startswith("d"):
seed = int(opt[1:].strip())
# elif opt.startswith("m"):
# mutipliers = opt[1:].strip().split(",")
# if len(mutipliers) != len(lora_models):
# logger.error(f"Invalid number of multipliers, expected {len(lora_models)}")
# continue
# for i, lora_model in enumerate(lora_models):
# lora_model.set_multiplier(float(mutipliers[i]))
elif opt.startswith("m"):
mutipliers = opt[1:].strip().split(",")
if len(mutipliers) != len(lora_models):
logger.error(f"Invalid number of multipliers, expected {len(lora_models)}")
continue
for i, lora_model in enumerate(lora_models):
lora_model.set_multiplier(float(mutipliers[i]))
elif opt.startswith("n"):
negative_prompt = opt[1:].strip()
if negative_prompt == "-":
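
The restored m option only makes sense in the non-merged path, where the LoRA networks are still attached and set_multiplier takes effect; input like m0.5,1.2 assigns one multiplier per loaded network, in load order. The parsing is a simple split:

opt = "m0.5,1.2"  # example interactive input
multipliers = [float(m) for m in opt[1:].strip().split(",")]
assert multipliers == [0.5, 1.2]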
