Skip to content

Commit

Permalink
add network_merge_n_models option
Browse files Browse the repository at this point in the history
  • Loading branch information
kohya-ss committed Oct 9, 2023
1 parent 8b79e3b commit f611726
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 35 deletions.
51 changes: 33 additions & 18 deletions gen_img_diffusers.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,10 +65,13 @@
import diffusers
import numpy as np
import torch

try:
import intel_extension_for_pytorch as ipex

if torch.xpu.is_available():
from library.ipex import ipex_init

ipex_init()
except Exception:
pass
Expand Down Expand Up @@ -954,7 +957,7 @@ def __call__(
text_emb_last = torch.stack(text_emb_last)
else:
text_emb_last = text_embeddings

for i, t in enumerate(tqdm(timesteps)):
# expand the latents if we are doing classifier free guidance
latent_model_input = latents.repeat((num_latent_input, 1, 1, 1))
Expand Down Expand Up @@ -2363,12 +2366,19 @@ def __getattr__(self, item):
network_default_muls = []
network_pre_calc = args.network_pre_calc

# merge関連の引数を統合する
if args.network_merge:
network_merge = len(args.network_module) # all networks are merged
elif args.network_merge_n_models:
network_merge = args.network_merge_n_models
else:
network_merge = None

for i, network_module in enumerate(args.network_module):
print("import network module:", network_module)
imported_module = importlib.import_module(network_module)

network_mul = 1.0 if args.network_mul is None or len(args.network_mul) <= i else args.network_mul[i]
network_default_muls.append(network_mul)

net_kwargs = {}
if args.network_args and i < len(args.network_args):
Expand All @@ -2379,31 +2389,32 @@ def __getattr__(self, item):
key, value = net_arg.split("=")
net_kwargs[key] = value

if args.network_weights and i < len(args.network_weights):
network_weight = args.network_weights[i]
print("load network weights from:", network_weight)
if args.network_weights is None or len(args.network_weights) <= i:
raise ValueError("No weight. Weight is required.")

network_weight = args.network_weights[i]
print("load network weights from:", network_weight)

if model_util.is_safetensors(network_weight) and args.network_show_meta:
from safetensors.torch import safe_open
if model_util.is_safetensors(network_weight) and args.network_show_meta:
from safetensors.torch import safe_open

with safe_open(network_weight, framework="pt") as f:
metadata = f.metadata()
if metadata is not None:
print(f"metadata for: {network_weight}: {metadata}")
with safe_open(network_weight, framework="pt") as f:
metadata = f.metadata()
if metadata is not None:
print(f"metadata for: {network_weight}: {metadata}")

network, weights_sd = imported_module.create_network_from_weights(
network_mul, network_weight, vae, text_encoder, unet, for_inference=True, **net_kwargs
)
else:
raise ValueError("No weight. Weight is required.")
network, weights_sd = imported_module.create_network_from_weights(
network_mul, network_weight, vae, text_encoder, unet, for_inference=True, **net_kwargs
)
if network is None:
return

mergeable = network.is_mergeable()
if args.network_merge and not mergeable:
if network_merge and not mergeable:
print("network is not mergiable. ignore merge option.")

if not args.network_merge or not mergeable:
if not mergeable or i >= network_merge:
# not merging
network.apply_to(text_encoder, unet)
info = network.load_state_dict(weights_sd, False) # network.load_weightsを使うようにするとよい
print(f"weights are loaded: {info}")
Expand All @@ -2417,6 +2428,7 @@ def __getattr__(self, item):
network.backup_weights()

networks.append(network)
network_default_muls.append(network_mul)
else:
network.merge_to(text_encoder, unet, weights_sd, dtype, device)

Expand Down Expand Up @@ -3367,6 +3379,9 @@ def setup_parser() -> argparse.ArgumentParser:
"--network_args", type=str, default=None, nargs="*", help="additional arguments for network (key=value) / ネットワークへの追加の引数"
)
parser.add_argument("--network_show_meta", action="store_true", help="show metadata of network model / ネットワークモデルのメタデータを表示する")
parser.add_argument(
"--network_merge_n_models", type=int, default=None, help="merge this number of networks / この数だけネットワークをマージする"
)
parser.add_argument("--network_merge", action="store_true", help="merge network weights to original model / ネットワークの重みをマージする")
parser.add_argument(
"--network_pre_calc", action="store_true", help="pre-calculate network for generation / ネットワークのあらかじめ計算して生成する"
Expand Down
50 changes: 33 additions & 17 deletions sdxl_gen_img.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,13 @@
import diffusers
import numpy as np
import torch

try:
import intel_extension_for_pytorch as ipex

if torch.xpu.is_available():
from library.ipex import ipex_init

ipex_init()
except Exception:
pass
Expand Down Expand Up @@ -1534,12 +1537,20 @@ def __getattr__(self, item):
network_default_muls = []
network_pre_calc = args.network_pre_calc

# merge関連の引数を統合する
if args.network_merge:
network_merge = len(args.network_module) # all networks are merged
elif args.network_merge_n_models:
network_merge = args.network_merge_n_models
else:
network_merge = None
print(f"network_merge: {network_merge}")

for i, network_module in enumerate(args.network_module):
print("import network module:", network_module)
imported_module = importlib.import_module(network_module)

network_mul = 1.0 if args.network_mul is None or len(args.network_mul) <= i else args.network_mul[i]
network_default_muls.append(network_mul)

net_kwargs = {}
if args.network_args and i < len(args.network_args):
Expand All @@ -1550,31 +1561,32 @@ def __getattr__(self, item):
key, value = net_arg.split("=")
net_kwargs[key] = value

if args.network_weights and i < len(args.network_weights):
network_weight = args.network_weights[i]
print("load network weights from:", network_weight)
if args.network_weights is None or len(args.network_weights) <= i:
raise ValueError("No weight. Weight is required.")

if model_util.is_safetensors(network_weight) and args.network_show_meta:
from safetensors.torch import safe_open
network_weight = args.network_weights[i]
print("load network weights from:", network_weight)

with safe_open(network_weight, framework="pt") as f:
metadata = f.metadata()
if metadata is not None:
print(f"metadata for: {network_weight}: {metadata}")
if model_util.is_safetensors(network_weight) and args.network_show_meta:
from safetensors.torch import safe_open

network, weights_sd = imported_module.create_network_from_weights(
network_mul, network_weight, vae, [text_encoder1, text_encoder2], unet, for_inference=True, **net_kwargs
)
else:
raise ValueError("No weight. Weight is required.")
with safe_open(network_weight, framework="pt") as f:
metadata = f.metadata()
if metadata is not None:
print(f"metadata for: {network_weight}: {metadata}")

network, weights_sd = imported_module.create_network_from_weights(
network_mul, network_weight, vae, [text_encoder1, text_encoder2], unet, for_inference=True, **net_kwargs
)
if network is None:
return

mergeable = network.is_mergeable()
if args.network_merge and not mergeable:
if network_merge and not mergeable:
print("network is not mergiable. ignore merge option.")

if not args.network_merge or not mergeable:
if not mergeable or i >= network_merge:
# not merging
network.apply_to([text_encoder1, text_encoder2], unet)
info = network.load_state_dict(weights_sd, False) # network.load_weightsを使うようにするとよい
print(f"weights are loaded: {info}")
Expand All @@ -1588,6 +1600,7 @@ def __getattr__(self, item):
network.backup_weights()

networks.append(network)
network_default_muls.append(network_mul)
else:
network.merge_to([text_encoder1, text_encoder2], unet, weights_sd, dtype, device)

Expand Down Expand Up @@ -2615,6 +2628,9 @@ def setup_parser() -> argparse.ArgumentParser:
"--network_args", type=str, default=None, nargs="*", help="additional arguments for network (key=value) / ネットワークへの追加の引数"
)
parser.add_argument("--network_show_meta", action="store_true", help="show metadata of network model / ネットワークモデルのメタデータを表示する")
parser.add_argument(
"--network_merge_n_models", type=int, default=None, help="merge this number of networks / この数だけネットワークをマージする"
)
parser.add_argument("--network_merge", action="store_true", help="merge network weights to original model / ネットワークの重みをマージする")
parser.add_argument(
"--network_pre_calc", action="store_true", help="pre-calculate network for generation / ネットワークのあらかじめ計算して生成する"
Expand Down

0 comments on commit f611726

Please sign in to comment.