From f6117263649078680661e319b9a469268a87b43b Mon Sep 17 00:00:00 2001 From: Kohya S Date: Mon, 9 Oct 2023 21:41:50 +0900 Subject: [PATCH] add network_merge_n_models option --- gen_img_diffusers.py | 51 ++++++++++++++++++++++++++++---------------- sdxl_gen_img.py | 50 ++++++++++++++++++++++++++++--------------- 2 files changed, 66 insertions(+), 35 deletions(-) diff --git a/gen_img_diffusers.py b/gen_img_diffusers.py index 0ec683a23..820028347 100644 --- a/gen_img_diffusers.py +++ b/gen_img_diffusers.py @@ -65,10 +65,13 @@ import diffusers import numpy as np import torch + try: import intel_extension_for_pytorch as ipex + if torch.xpu.is_available(): from library.ipex import ipex_init + ipex_init() except Exception: pass @@ -954,7 +957,7 @@ def __call__( text_emb_last = torch.stack(text_emb_last) else: text_emb_last = text_embeddings - + for i, t in enumerate(tqdm(timesteps)): # expand the latents if we are doing classifier free guidance latent_model_input = latents.repeat((num_latent_input, 1, 1, 1)) @@ -2363,12 +2366,19 @@ def __getattr__(self, item): network_default_muls = [] network_pre_calc = args.network_pre_calc + # unify the merge-related arguments into one count: the first network_merge networks are merged + if args.network_merge: + network_merge = len(args.network_module) # all networks are merged + elif args.network_merge_n_models: + network_merge = args.network_merge_n_models + else: + network_merge = 0 # merge nothing (must be an int: it is compared with the loop index i below) + for i, network_module in enumerate(args.network_module): print("import network module:", network_module) imported_module = importlib.import_module(network_module) network_mul = 1.0 if args.network_mul is None or len(args.network_mul) <= i else args.network_mul[i] - network_default_muls.append(network_mul) net_kwargs = {} if args.network_args and i < len(args.network_args): @@ -2379,31 +2389,32 @@ def __getattr__(self, item): key, value = net_arg.split("=") net_kwargs[key] = value - if args.network_weights and i < len(args.network_weights): - network_weight = args.network_weights[i] - print("load network weights from:", 
network_weight) + if args.network_weights is None or len(args.network_weights) <= i: + raise ValueError("No weight. Weight is required.") + + network_weight = args.network_weights[i] + print("load network weights from:", network_weight) - if model_util.is_safetensors(network_weight) and args.network_show_meta: - from safetensors.torch import safe_open + if model_util.is_safetensors(network_weight) and args.network_show_meta: + from safetensors.torch import safe_open - with safe_open(network_weight, framework="pt") as f: - metadata = f.metadata() - if metadata is not None: - print(f"metadata for: {network_weight}: {metadata}") + with safe_open(network_weight, framework="pt") as f: + metadata = f.metadata() + if metadata is not None: + print(f"metadata for: {network_weight}: {metadata}") - network, weights_sd = imported_module.create_network_from_weights( - network_mul, network_weight, vae, text_encoder, unet, for_inference=True, **net_kwargs - ) - else: - raise ValueError("No weight. Weight is required.") + network, weights_sd = imported_module.create_network_from_weights( + network_mul, network_weight, vae, text_encoder, unet, for_inference=True, **net_kwargs + ) if network is None: return mergeable = network.is_mergeable() - if args.network_merge and not mergeable: + if network_merge and not mergeable: print("network is not mergiable. 
ignore merge option.") - if not args.network_merge or not mergeable: + if not mergeable or i >= network_merge: + # not merging network.apply_to(text_encoder, unet) info = network.load_state_dict(weights_sd, False) # network.load_weightsを使うようにするとよい print(f"weights are loaded: {info}") @@ -2417,6 +2428,7 @@ def __getattr__(self, item): network.backup_weights() networks.append(network) + network_default_muls.append(network_mul) else: network.merge_to(text_encoder, unet, weights_sd, dtype, device) @@ -3367,6 +3379,9 @@ def setup_parser() -> argparse.ArgumentParser: "--network_args", type=str, default=None, nargs="*", help="additional arguments for network (key=value) / ネットワークへの追加の引数" ) parser.add_argument("--network_show_meta", action="store_true", help="show metadata of network model / ネットワークモデルのメタデータを表示する") + parser.add_argument( + "--network_merge_n_models", type=int, default=None, help="merge this number of networks / この数だけネットワークをマージする" + ) parser.add_argument("--network_merge", action="store_true", help="merge network weights to original model / ネットワークの重みをマージする") parser.add_argument( "--network_pre_calc", action="store_true", help="pre-calculate network for generation / ネットワークのあらかじめ計算して生成する" diff --git a/sdxl_gen_img.py b/sdxl_gen_img.py index ab2b6b3d6..2d652bc82 100755 --- a/sdxl_gen_img.py +++ b/sdxl_gen_img.py @@ -17,10 +17,13 @@ import diffusers import numpy as np import torch + try: import intel_extension_for_pytorch as ipex + if torch.xpu.is_available(): from library.ipex import ipex_init + ipex_init() except Exception: pass @@ -1534,12 +1537,20 @@ def __getattr__(self, item): network_default_muls = [] network_pre_calc = args.network_pre_calc + # unify the merge-related arguments into one count: the first network_merge networks are merged + if args.network_merge: + network_merge = len(args.network_module) # all networks are merged + elif args.network_merge_n_models: + network_merge = args.network_merge_n_models + else: + network_merge = 0 # merge nothing (must be an int: it is compared with the loop index i below) + print(f"network_merge: {network_merge}") + for i, network_module in 
enumerate(args.network_module): print("import network module:", network_module) imported_module = importlib.import_module(network_module) network_mul = 1.0 if args.network_mul is None or len(args.network_mul) <= i else args.network_mul[i] - network_default_muls.append(network_mul) net_kwargs = {} if args.network_args and i < len(args.network_args): @@ -1550,31 +1561,32 @@ def __getattr__(self, item): key, value = net_arg.split("=") net_kwargs[key] = value - if args.network_weights and i < len(args.network_weights): - network_weight = args.network_weights[i] - print("load network weights from:", network_weight) + if args.network_weights is None or len(args.network_weights) <= i: + raise ValueError("No weight. Weight is required.") - if model_util.is_safetensors(network_weight) and args.network_show_meta: - from safetensors.torch import safe_open + network_weight = args.network_weights[i] + print("load network weights from:", network_weight) - with safe_open(network_weight, framework="pt") as f: - metadata = f.metadata() - if metadata is not None: - print(f"metadata for: {network_weight}: {metadata}") + if model_util.is_safetensors(network_weight) and args.network_show_meta: + from safetensors.torch import safe_open - network, weights_sd = imported_module.create_network_from_weights( - network_mul, network_weight, vae, [text_encoder1, text_encoder2], unet, for_inference=True, **net_kwargs - ) - else: - raise ValueError("No weight. Weight is required.") + with safe_open(network_weight, framework="pt") as f: + metadata = f.metadata() + if metadata is not None: + print(f"metadata for: {network_weight}: {metadata}") + + network, weights_sd = imported_module.create_network_from_weights( + network_mul, network_weight, vae, [text_encoder1, text_encoder2], unet, for_inference=True, **net_kwargs + ) if network is None: return mergeable = network.is_mergeable() - if args.network_merge and not mergeable: + if network_merge and not mergeable: print("network is not mergiable. 
ignore merge option.") - if not args.network_merge or not mergeable: + if not mergeable or i >= network_merge: + # not merging network.apply_to([text_encoder1, text_encoder2], unet) info = network.load_state_dict(weights_sd, False) # network.load_weightsを使うようにするとよい print(f"weights are loaded: {info}") @@ -1588,6 +1600,7 @@ def __getattr__(self, item): network.backup_weights() networks.append(network) + network_default_muls.append(network_mul) else: network.merge_to([text_encoder1, text_encoder2], unet, weights_sd, dtype, device) @@ -2615,6 +2628,9 @@ def setup_parser() -> argparse.ArgumentParser: "--network_args", type=str, default=None, nargs="*", help="additional arguments for network (key=value) / ネットワークへの追加の引数" ) parser.add_argument("--network_show_meta", action="store_true", help="show metadata of network model / ネットワークモデルのメタデータを表示する") + parser.add_argument( + "--network_merge_n_models", type=int, default=None, help="merge this number of networks / この数だけネットワークをマージする" + ) parser.add_argument("--network_merge", action="store_true", help="merge network weights to original model / ネットワークの重みをマージする") parser.add_argument( "--network_pre_calc", action="store_true", help="pre-calculate network for generation / ネットワークのあらかじめ計算して生成する"