From 3e81bd6b6729bf8b553b30cbce23e21698287039 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Mon, 9 Oct 2023 23:07:14 +0900 Subject: [PATCH] fix network_merge, add regional mask as color code --- gen_img_diffusers.py | 21 ++++++++++++++++++--- sdxl_gen_img.py | 21 ++++++++++++++++++--- 2 files changed, 36 insertions(+), 6 deletions(-) diff --git a/gen_img_diffusers.py b/gen_img_diffusers.py index 820028347..a596a0494 100644 --- a/gen_img_diffusers.py +++ b/gen_img_diffusers.py @@ -2372,7 +2372,7 @@ def __getattr__(self, item): elif args.network_merge_n_models: network_merge = args.network_merge_n_models else: - network_merge = None + network_merge = 0 for i, network_module in enumerate(args.network_module): print("import network module:", network_module) @@ -2724,9 +2724,18 @@ def resize_images(imgs, size): size = None for i, network in enumerate(networks): - if i < 3: + if (i < 3 and args.network_regional_mask_max_color_codes is None) or i < args.network_regional_mask_max_color_codes: np_mask = np.array(mask_images[0]) - np_mask = np_mask[:, :, i] + + if args.network_regional_mask_max_color_codes: + # カラーコードでマスクを指定する + ch0 = (i + 1) & 1 + ch1 = ((i + 1) >> 1) & 1 + ch2 = ((i + 1) >> 2) & 1 + np_mask = np.all(np_mask == np.array([ch0, ch1, ch2]) * 255, axis=2) + np_mask = np_mask.astype(np.uint8) * 255 + else: + np_mask = np_mask[:, :, i] size = np_mask.shape else: np_mask = np.full(size, 255, dtype=np.uint8) @@ -3386,6 +3395,12 @@ def setup_parser() -> argparse.ArgumentParser: parser.add_argument( "--network_pre_calc", action="store_true", help="pre-calculate network for generation / ネットワークのあらかじめ計算して生成する" ) + parser.add_argument( + "--network_regional_mask_max_color_codes", + type=int, + default=None, + help="max color codes for regional mask (default is None, mask by channel) / regional maskの最大色数(デフォルトはNoneでチャンネルごとのマスク)", + ) parser.add_argument( "--textual_inversion_embeddings", type=str, diff --git a/sdxl_gen_img.py b/sdxl_gen_img.py index 2d652bc82..c31ae0072 100755 --- a/sdxl_gen_img.py +++ b/sdxl_gen_img.py @@ -1543,7 +1543,7 @@ def __getattr__(self, item): elif args.network_merge_n_models: network_merge = args.network_merge_n_models else: - network_merge = None + network_merge = 0 print(f"network_merge: {network_merge}") for i, network_module in enumerate(args.network_module): @@ -1877,9 +1877,18 @@ def resize_images(imgs, size): size = None for i, network in enumerate(networks): - if i < 3: + if (i < 3 and args.network_regional_mask_max_color_codes is None) or i < args.network_regional_mask_max_color_codes: np_mask = np.array(mask_images[0]) - np_mask = np_mask[:, :, i] + + if args.network_regional_mask_max_color_codes: + # カラーコードでマスクを指定する + ch0 = (i + 1) & 1 + ch1 = ((i + 1) >> 1) & 1 + ch2 = ((i + 1) >> 2) & 1 + np_mask = np.all(np_mask == np.array([ch0, ch1, ch2]) * 255, axis=2) + np_mask = np_mask.astype(np.uint8) * 255 + else: + np_mask = np_mask[:, :, i] size = np_mask.shape else: np_mask = np.full(size, 255, dtype=np.uint8) @@ -2635,6 +2644,12 @@ def setup_parser() -> argparse.ArgumentParser: parser.add_argument( "--network_pre_calc", action="store_true", help="pre-calculate network for generation / ネットワークのあらかじめ計算して生成する" ) + parser.add_argument( + "--network_regional_mask_max_color_codes", + type=int, + default=None, + help="max color codes for regional mask (default is None, mask by channel) / regional maskの最大色数(デフォルトはNoneでチャンネルごとのマスク)", + ) parser.add_argument( "--textual_inversion_embeddings", type=str,