From 6d6d86260b9b97fa79d84a6115d696312905d993 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 23 Nov 2023 19:40:48 +0900 Subject: [PATCH] add Deep Shrink --- gen_img_diffusers.py | 76 ++++++++++++++++++++++++++++++++++ library/original_unet.py | 68 ++++++++++++++++++++++++++++++- library/sdxl_original_unet.py | 70 ++++++++++++++++++++++++++++++- sdxl_gen_img.py | 77 +++++++++++++++++++++++++++++++++++ 4 files changed, 288 insertions(+), 3 deletions(-) diff --git a/gen_img_diffusers.py b/gen_img_diffusers.py index a596a0494..7661538c6 100644 --- a/gen_img_diffusers.py +++ b/gen_img_diffusers.py @@ -2501,6 +2501,10 @@ def __getattr__(self, item): if args.diffusers_xformers: pipe.enable_xformers_memory_efficient_attention() + # Deep Shrink + if args.ds_depth_1 is not None: + unet.set_deep_shrink(args.ds_depth_1, args.ds_timesteps_1, args.ds_depth_2, args.ds_timesteps_2, args.ds_ratio) + # Extended Textual Inversion および Textual Inversionを処理する if args.XTI_embeddings: diffusers.models.UNet2DConditionModel.forward = unet_forward_XTI @@ -3085,6 +3089,13 @@ def process_batch(batch: List[BatchData], highres_fix, highres_1st=False): clip_prompt = None network_muls = None + # Deep Shrink + ds_depth_1 = None # means no override + ds_timesteps_1 = args.ds_timesteps_1 + ds_depth_2 = args.ds_depth_2 + ds_timesteps_2 = args.ds_timesteps_2 + ds_ratio = args.ds_ratio + prompt_args = raw_prompt.strip().split(" --") prompt = prompt_args[0] print(f"prompt {prompt_index+1}/{len(prompt_list)}: {prompt}") @@ -3156,10 +3167,51 @@ def process_batch(batch: List[BatchData], highres_fix, highres_1st=False): print(f"network mul: {network_muls}") continue + # Deep Shrink + m = re.match(r"dsd1 ([\d\.]+)", parg, re.IGNORECASE) + if m: # deep shrink depth 1 + ds_depth_1 = int(m.group(1)) + print(f"deep shrink depth 1: {ds_depth_1}") + continue + + m = re.match(r"dst1 ([\d\.]+)", parg, re.IGNORECASE) + if m: # deep shrink timesteps 1 + ds_timesteps_1 = int(m.group(1)) + ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override + print(f"deep shrink timesteps 1: {ds_timesteps_1}") + continue + + m = re.match(r"dsd2 ([\d\.]+)", parg, re.IGNORECASE) + if m: # deep shrink depth 2 + ds_depth_2 = int(m.group(1)) + ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override + print(f"deep shrink depth 2: {ds_depth_2}") + continue + + m = re.match(r"dst2 ([\d\.]+)", parg, re.IGNORECASE) + if m: # deep shrink timesteps 2 + ds_timesteps_2 = int(m.group(1)) + ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override + print(f"deep shrink timesteps 2: {ds_timesteps_2}") + continue + + m = re.match(r"dsr ([\d\.]+)", parg, re.IGNORECASE) + if m: # deep shrink ratio + ds_ratio = float(m.group(1)) + ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override + print(f"deep shrink ratio: {ds_ratio}") + continue + except ValueError as ex: print(f"Exception in parsing / 解析エラー: {parg}") print(ex) + # override Deep Shrink + if ds_depth_1 is not None: + if ds_depth_1 < 0: + ds_depth_1 = args.ds_depth_1 or 3 + unet.set_deep_shrink(ds_depth_1, ds_timesteps_1, ds_depth_2, ds_timesteps_2, ds_ratio) + # prepare seed if seeds is not None: # given in prompt # 数が足りないなら前のをそのまま使う @@ -3509,6 +3561,30 @@ def setup_parser() -> argparse.ArgumentParser: # "--control_net_image_path", type=str, default=None, nargs="*", help="image for ControlNet guidance / ControlNetでガイドに使う画像" # ) + # Deep Shrink + parser.add_argument( + "--ds_depth_1", + type=int, + default=None, + help="Enable Deep Shrink with this depth 1, valid values are 0 to 3 / Deep Shrinkをこのdepthで有効にする", + ) + parser.add_argument( + "--ds_timesteps_1", + type=int, + default=650, + help="Apply Deep Shrink depth 1 until this timesteps / Deep Shrink depth 1を適用するtimesteps", + ) + parser.add_argument("--ds_depth_2", type=int, default=None, help="Deep Shrink depth 2 / Deep Shrinkのdepth 2") + parser.add_argument( + "--ds_timesteps_2", + type=int, + default=650, + help="Apply Deep Shrink depth 2 until this timesteps / Deep Shrink depth 2を適用するtimesteps", + ) + parser.add_argument( + "--ds_ratio", type=float, default=0.5, help="Deep Shrink ratio for downsampling / Deep Shrinkのdownsampling比率" + ) + return parser diff --git a/library/original_unet.py b/library/original_unet.py index 240b85951..0454f13f1 100644 --- a/library/original_unet.py +++ b/library/original_unet.py @@ -361,6 +361,23 @@ def get_timestep_embedding( return emb +# Deep Shrink: We do not common this function, because minimize dependencies. +def resize_like(x, target, mode="bicubic", align_corners=False): + org_dtype = x.dtype + if org_dtype == torch.bfloat16: + x = x.to(torch.float32) + + if x.shape[-2:] != target.shape[-2:]: + if mode == "nearest": + x = F.interpolate(x, size=target.shape[-2:], mode=mode) + else: + x = F.interpolate(x, size=target.shape[-2:], mode=mode, align_corners=align_corners) + + if org_dtype == torch.bfloat16: + x = x.to(org_dtype) + return x + + class SampleOutput: def __init__(self, sample): self.sample = sample @@ -1130,6 +1147,11 @@ def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_si # pop res hidden states res_hidden_states = res_hidden_states_tuple[-1] res_hidden_states_tuple = res_hidden_states_tuple[:-1] + + # Deep Shrink + if res_hidden_states.shape[-2:] != hidden_states.shape[-2:]: + hidden_states = resize_like(hidden_states, res_hidden_states) + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) if self.training and self.gradient_checkpointing: @@ -1221,6 +1243,11 @@ def forward( # pop res hidden states res_hidden_states = res_hidden_states_tuple[-1] res_hidden_states_tuple = res_hidden_states_tuple[:-1] + + # Deep Shrink + if res_hidden_states.shape[-2:] != hidden_states.shape[-2:]: + hidden_states = resize_like(hidden_states, res_hidden_states) + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) if self.training and self.gradient_checkpointing: @@ -1417,6 +1444,31 @@ def __init__( self.conv_act = nn.SiLU() self.conv_out = nn.Conv2d(BLOCK_OUT_CHANNELS[0], OUT_CHANNELS, kernel_size=3, padding=1) + # Deep Shrink + self.ds_depth_1 = None + self.ds_depth_2 = None + self.ds_timesteps_1 = None + self.ds_timesteps_2 = None + self.ds_ratio = None + + def set_deep_shrink(self, ds_depth_1, ds_timesteps_1=650, ds_depth_2=None, ds_timesteps_2=None, ds_ratio=0.5): + if ds_depth_1 is None: + print("Deep Shrink is disabled.") + self.ds_depth_1 = None + self.ds_timesteps_1 = None + self.ds_depth_2 = None + self.ds_timesteps_2 = None + self.ds_ratio = None + else: + print( + f"Deep Shrink is enabled: [depth={ds_depth_1}/{ds_depth_2}, timesteps={ds_timesteps_1}/{ds_timesteps_2}, ratio={ds_ratio}]" + ) + self.ds_depth_1 = ds_depth_1 + self.ds_timesteps_1 = ds_timesteps_1 + self.ds_depth_2 = ds_depth_2 if ds_depth_2 is not None else -1 + self.ds_timesteps_2 = ds_timesteps_2 if ds_timesteps_2 is not None else 1000 + self.ds_ratio = ds_ratio + # region diffusers compatibility def prepare_config(self): self.config = SimpleNamespace() @@ -1519,9 +1571,21 @@ def forward( # 2. pre-process sample = self.conv_in(sample) - # 3. down down_block_res_samples = (sample,) - for downsample_block in self.down_blocks: + for depth, downsample_block in enumerate(self.down_blocks): + # Deep Shrink + if self.ds_depth_1 is not None: + if (depth == self.ds_depth_1 and timesteps[0] >= self.ds_timesteps_1) or ( + self.ds_depth_2 is not None + and depth == self.ds_depth_2 + and timesteps[0] < self.ds_timesteps_1 + and timesteps[0] >= self.ds_timesteps_2 + ): + org_dtype = sample.dtype + if org_dtype == torch.bfloat16: + sample = sample.to(torch.float32) + sample = F.interpolate(sample, scale_factor=self.ds_ratio, mode="bicubic", align_corners=False).to(org_dtype) + # downblockはforwardで必ずencoder_hidden_statesを受け取るようにしても良さそうだけど、 # まあこちらのほうがわかりやすいかもしれない if downsample_block.has_cross_attention: diff --git a/library/sdxl_original_unet.py b/library/sdxl_original_unet.py index 26a0af319..d51dfdbcc 100644 --- a/library/sdxl_original_unet.py +++ b/library/sdxl_original_unet.py @@ -266,6 +266,23 @@ def get_timestep_embedding( return emb +# Deep Shrink: We do not common this function, because minimize dependencies. +def resize_like(x, target, mode="bicubic", align_corners=False): + org_dtype = x.dtype + if org_dtype == torch.bfloat16: + x = x.to(torch.float32) + + if x.shape[-2:] != target.shape[-2:]: + if mode == "nearest": + x = F.interpolate(x, size=target.shape[-2:], mode=mode) + else: + x = F.interpolate(x, size=target.shape[-2:], mode=mode, align_corners=align_corners) + + if org_dtype == torch.bfloat16: + x = x.to(org_dtype) + return x + + class GroupNorm32(nn.GroupNorm): def forward(self, x): if self.weight.dtype != torch.float32: @@ -996,6 +1013,31 @@ def __init__( [GroupNorm32(32, self.model_channels), nn.SiLU(), nn.Conv2d(self.model_channels, self.out_channels, 3, padding=1)] ) + # Deep Shrink + self.ds_depth_1 = None + self.ds_depth_2 = None + self.ds_timesteps_1 = None + self.ds_timesteps_2 = None + self.ds_ratio = None + + def set_deep_shrink(self, ds_depth_1, ds_timesteps_1=650, ds_depth_2=None, ds_timesteps_2=None, ds_ratio=0.5): + if ds_depth_1 is None: + print("Deep Shrink is disabled.") + self.ds_depth_1 = None + self.ds_timesteps_1 = None + self.ds_depth_2 = None + self.ds_timesteps_2 = None + self.ds_ratio = None + else: + print( + f"Deep Shrink is enabled: [depth={ds_depth_1}/{ds_depth_2}, timesteps={ds_timesteps_1}/{ds_timesteps_2}, ratio={ds_ratio}]" + ) + self.ds_depth_1 = ds_depth_1 + self.ds_timesteps_1 = ds_timesteps_1 + self.ds_depth_2 = ds_depth_2 if ds_depth_2 is not None else -1 + self.ds_timesteps_2 = ds_timesteps_2 if ds_timesteps_2 is not None else 1000 + self.ds_ratio = ds_ratio + # region diffusers compatibility def prepare_config(self): self.config = SimpleNamespace() @@ -1077,16 +1119,42 @@ def call_module(module, h, emb, context): # h = x.type(self.dtype) h = x - for module in self.input_blocks: + + for depth, module in enumerate(self.input_blocks): + # Deep Shrink + if self.ds_depth_1 is not None: + if (depth == self.ds_depth_1 and timesteps[0] >= self.ds_timesteps_1) or ( + self.ds_depth_2 is not None + and depth == self.ds_depth_2 + and timesteps[0] < self.ds_timesteps_1 + and timesteps[0] >= self.ds_timesteps_2 + ): + # print("downsample", h.shape, self.ds_ratio) + org_dtype = h.dtype + if org_dtype == torch.bfloat16: + h = h.to(torch.float32) + h = F.interpolate(h, scale_factor=self.ds_ratio, mode="bicubic", align_corners=False).to(org_dtype) + h = call_module(module, h, emb, context) hs.append(h) h = call_module(self.middle_block, h, emb, context) for module in self.output_blocks: + # Deep Shrink + if self.ds_depth_1 is not None: + if hs[-1].shape[-2:] != h.shape[-2:]: + # print("upsample", h.shape, hs[-1].shape) + h = resize_like(h, hs[-1]) + h = torch.cat([h, hs.pop()], dim=1) h = call_module(module, h, emb, context) + # Deep Shrink: in case of depth 0 + if self.ds_depth_1 == 0 and h.shape[-2:] != x.shape[-2:]: + # print("upsample", h.shape, x.shape) + h = resize_like(h, x) + h = h.type(x.dtype) h = call_module(self.out, h, emb, context) diff --git a/sdxl_gen_img.py b/sdxl_gen_img.py index c31ae0072..a61fb7a89 100755 --- a/sdxl_gen_img.py +++ b/sdxl_gen_img.py @@ -1696,6 +1696,10 @@ def __getattr__(self, item): if args.diffusers_xformers: pipe.enable_xformers_memory_efficient_attention() + # Deep Shrink + if args.ds_depth_1 is not None: + unet.set_deep_shrink(args.ds_depth_1, args.ds_timesteps_1, args.ds_depth_2, args.ds_timesteps_2, args.ds_ratio) + # Textual Inversionを処理する if args.textual_inversion_embeddings: token_ids_embeds1 = [] @@ -2286,6 +2290,13 @@ def scale_and_round(x): clip_prompt = None network_muls = None + # Deep Shrink + ds_depth_1 = None # means no override + ds_timesteps_1 = args.ds_timesteps_1 + ds_depth_2 = args.ds_depth_2 + ds_timesteps_2 = args.ds_timesteps_2 + ds_ratio = args.ds_ratio + prompt_args = raw_prompt.strip().split(" --") prompt = prompt_args[0] print(f"prompt {prompt_index+1}/{len(prompt_list)}: {prompt}") @@ -2393,10 +2404,51 @@ def scale_and_round(x): print(f"network mul: {network_muls}") continue + # Deep Shrink + m = re.match(r"dsd1 ([\d\.]+)", parg, re.IGNORECASE) + if m: # deep shrink depth 1 + ds_depth_1 = int(m.group(1)) + print(f"deep shrink depth 1: {ds_depth_1}") + continue + + m = re.match(r"dst1 ([\d\.]+)", parg, re.IGNORECASE) + if m: # deep shrink timesteps 1 + ds_timesteps_1 = int(m.group(1)) + ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override + print(f"deep shrink timesteps 1: {ds_timesteps_1}") + continue + + m = re.match(r"dsd2 ([\d\.]+)", parg, re.IGNORECASE) + if m: # deep shrink depth 2 + ds_depth_2 = int(m.group(1)) + ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override + print(f"deep shrink depth 2: {ds_depth_2}") + continue + + m = re.match(r"dst2 ([\d\.]+)", parg, re.IGNORECASE) + if m: # deep shrink timesteps 2 + ds_timesteps_2 = int(m.group(1)) + ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override + print(f"deep shrink timesteps 2: {ds_timesteps_2}") + continue + + m = re.match(r"dsr ([\d\.]+)", parg, re.IGNORECASE) + if m: # deep shrink ratio + ds_ratio = float(m.group(1)) + ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override + print(f"deep shrink ratio: {ds_ratio}") + continue + except ValueError as ex: print(f"Exception in parsing / 解析エラー: {parg}") print(ex) + # override Deep Shrink + if ds_depth_1 is not None: + if ds_depth_1 < 0: + ds_depth_1 = args.ds_depth_1 or 3 + unet.set_deep_shrink(ds_depth_1, ds_timesteps_1, ds_depth_2, ds_timesteps_2, ds_ratio) + # prepare seed if seeds is not None: # given in prompt # 数が足りないなら前のをそのまま使う @@ -2734,6 +2786,31 @@ def setup_parser() -> argparse.ArgumentParser: default=None, help="enable CLIP Vision Conditioning for img2img with this strength / img2imgでCLIP Vision Conditioningを有効にしてこのstrengthで処理する", ) + + # Deep Shrink + parser.add_argument( + "--ds_depth_1", + type=int, + default=None, + help="Enable Deep Shrink with this depth 1, valid values are 0 to 8 / Deep Shrinkをこのdepthで有効にする", + ) + parser.add_argument( + "--ds_timesteps_1", + type=int, + default=650, + help="Apply Deep Shrink depth 1 until this timesteps / Deep Shrink depth 1を適用するtimesteps", + ) + parser.add_argument("--ds_depth_2", type=int, default=None, help="Deep Shrink depth 2 / Deep Shrinkのdepth 2") + parser.add_argument( + "--ds_timesteps_2", + type=int, + default=650, + help="Apply Deep Shrink depth 2 until this timesteps / Deep Shrink depth 2を適用するtimesteps", + ) + parser.add_argument( + "--ds_ratio", type=float, default=0.5, help="Deep Shrink ratio for downsampling / Deep Shrinkのdownsampling比率" + ) + # # parser.add_argument( # "--control_net_image_path", type=str, default=None, nargs="*", help="image for ControlNet guidance / ControlNetでガイドに使う画像" # )