From 555ed624b5a6fff3cf5d9c4b6a61b67bc758de84 Mon Sep 17 00:00:00 2001 From: neurowelt Date: Wed, 27 Sep 2023 22:55:00 +0800 Subject: [PATCH 1/9] Update Fourier & remove unused imports --- free_lunch_utils.py | 32 ++++++++++++++++++++------------ 1 file changed, 20 insertions(+), 12 deletions(-) diff --git a/free_lunch_utils.py b/free_lunch_utils.py index 79d79cd..c884bd1 100644 --- a/free_lunch_utils.py +++ b/free_lunch_utils.py @@ -1,8 +1,7 @@ import torch import torch.fft as fft -from diffusers.models.unet_2d_condition import logger from diffusers.utils import is_torch_version -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, Optional, Tuple def isinstance_str(x: object, cls_name: str): @@ -20,26 +19,35 @@ def isinstance_str(x: object, cls_name: str): return False -def Fourier_filter(x, threshold, scale): - dtype = x.dtype - x = x.type(torch.float32) +def Fourier_filter(x_in, threshold, scale): + """ + Updated Fourier filter based on: + https://github.com/huggingface/diffusers/pull/5164#issuecomment-1732638706 + """ + + x = x_in + B, C, H, W = x.shape + + # Non-power of 2 images must be float32 + if (W & (W - 1)) != 0 or (H & (H - 1)) != 0: + x = x.to(dtype=torch.float32) + # FFT x_freq = fft.fftn(x, dim=(-2, -1)) x_freq = fft.fftshift(x_freq, dim=(-2, -1)) - + B, C, H, W = x_freq.shape - mask = torch.ones((B, C, H, W)).cuda() + mask = torch.ones((B, C, H, W), device=x.device) - crow, ccol = H // 2, W //2 - mask[..., crow - threshold:crow + threshold, ccol - threshold:ccol + threshold] = scale + crow, ccol = H // 2, W // 2 + mask[..., crow - threshold : crow + threshold, ccol - threshold : ccol + threshold] = scale x_freq = x_freq * mask # IFFT x_freq = fft.ifftshift(x_freq, dim=(-2, -1)) x_filtered = fft.ifftn(x_freq, dim=(-2, -1)).real - - x_filtered = x_filtered.type(dtype) - return x_filtered + + return x_filtered.to(dtype=x_in.dtype) def register_upblock2d(model): From 58a8edbde1cac8561d53a694a9522cfc3fb2fe25 Mon Sep 17 00:00:00 2001 From: neurowelt Date: Wed, 27 Sep 2023 23:35:21 +0800 Subject: [PATCH 2/9] Add UNet3D register methods --- free_lunch_utils.py | 258 ++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 251 insertions(+), 7 deletions(-) diff --git a/free_lunch_utils.py b/free_lunch_utils.py index c884bd1..7a2168a 100644 --- a/free_lunch_utils.py +++ b/free_lunch_utils.py @@ -1,7 +1,9 @@ +from typing import Any, Dict, Optional, Tuple + import torch import torch.fft as fft from diffusers.utils import is_torch_version -from typing import Any, Dict, Optional, Tuple +from diffusers.models.unet_2d_condition import logger def isinstance_str(x: object, cls_name: str): @@ -51,13 +53,24 @@ def Fourier_filter(x_in, threshold, scale): def register_upblock2d(model): + """ + Register UpBlock2D for UNet2DCondition. 
+ """ + def up_forward(self): - def forward(hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None): + def forward( + hidden_states, + res_hidden_states_tuple, + temb=None, + upsample_size=None + ): + logger.debug(f"in upblock2d, hidden states shape: {hidden_states.shape}") + for resnet in self.resnets: # pop res hidden states res_hidden_states = res_hidden_states_tuple[-1] res_hidden_states_tuple = res_hidden_states_tuple[:-1] - #print(f"in upblock2d, hidden states shape: {hidden_states.shape}") + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) if self.training and self.gradient_checkpointing: @@ -93,13 +106,24 @@ def custom_forward(*inputs): def register_free_upblock2d(model, b1=1.2, b2=1.4, s1=0.9, s2=0.2): + """ + Register UpBlock2D with FreeU for UNet2DCondition. + """ + def up_forward(self): - def forward(hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None): + def forward( + hidden_states, + res_hidden_states_tuple, + temb=None, + upsample_size=None + ): + logger.debug(f"in free upblock2d, hidden states shape: {hidden_states.shape}") + for resnet in self.resnets: # pop res hidden states res_hidden_states = res_hidden_states_tuple[-1] res_hidden_states_tuple = res_hidden_states_tuple[:-1] - #print(f"in free upblock2d, hidden states shape: {hidden_states.shape}") + # --------------- FreeU code ----------------------- # Only operate on the first two stages if hidden_states.shape[1] == 1280: @@ -149,6 +173,10 @@ def custom_forward(*inputs): def register_crossattn_upblock2d(model): + """ + Register CrossAttn UpBlock2D for UNet2DCondition. + """ + def up_forward(self): def forward( hidden_states: torch.FloatTensor, @@ -160,9 +188,10 @@ def forward( attention_mask: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, ): + logger.debug(f"in crossatten upblock2d, hidden states shape: {hidden_states.shape}") + for resnet, attn in zip(self.resnets, self.attentions): # pop res hidden states - #print(f"in crossatten upblock2d, hidden states shape: {hidden_states.shape}") res_hidden_states = res_hidden_states_tuple[-1] res_hidden_states_tuple = res_hidden_states_tuple[:-1] hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) @@ -221,6 +250,10 @@ def custom_forward(*inputs): def register_free_crossattn_upblock2d(model, b1=1.2, b2=1.4, s1=0.9, s2=0.2): + """ + Register CrossAttn UpBlock2D with FreeU for UNet2DCondition. + """ + def up_forward(self): def forward( hidden_states: torch.FloatTensor, @@ -232,9 +265,10 @@ def forward( attention_mask: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, ): + logger.debug(f"in free crossatten upblock2d, hidden states shape: {hidden_states.shape}") + for resnet, attn in zip(self.resnets, self.attentions): # pop res hidden states - #print(f"in free crossatten upblock2d, hidden states shape: {hidden_states.shape}") res_hidden_states = res_hidden_states_tuple[-1] res_hidden_states_tuple = res_hidden_states_tuple[:-1] @@ -300,6 +334,216 @@ def custom_forward(*inputs): for i, upsample_block in enumerate(model.unet.up_blocks): if isinstance_str(upsample_block, "CrossAttnUpBlock2D"): + upsample_block.forward = up_forward(upsample_block) + setattr(upsample_block, 'b1', b1) + setattr(upsample_block, 'b2', b2) + setattr(upsample_block, 's1', s1) + setattr(upsample_block, 's2', s2) + + +def register_upblock3d(model): + """ + Register UpBlock3D for UNet3DCondition. 
+ """ + + def up_forward(self): + def forward( + hidden_states, + res_hidden_states_tuple, + temb=None, + upsample_size=None, + num_frames=1 + ): + + logger.debug(f"in upblock3d, hidden states shape: {hidden_states.shape}") + + for resnet, temp_conv in zip(self.resnets, self.temp_convs): + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + hidden_states = resnet(hidden_states, temb) + hidden_states = temp_conv(hidden_states, num_frames=num_frames) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + + return hidden_states + + return forward + + for i, upsample_block in enumerate(model.unet.up_blocks): + if isinstance_str(upsample_block, "UpBlock3D"): + upsample_block.forward = up_forward(upsample_block) + + +def register_free_upblock3d(model, b1=1.2, b2=1.4, s1=0.9, s2=0.2): + """ + Register UpBlock3D with FreeU for UNet3DCondition. + """ + + def up_forward(self): + def forward( + hidden_states, + res_hidden_states_tuple, + temb=None, + upsample_size=None, + num_frames=1 + ): + + logger.debug(f"in free upblock3d, hidden states shape: {hidden_states.shape}") + + for resnet, temp_conv in zip(self.resnets, self.temp_convs): + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + + # --------------- FreeU code ----------------------- + # Only operate on the first two stages + if hidden_states.shape[1] == 1280: + hidden_states[:,:640] = hidden_states[:,:640] * self.b1 + res_hidden_states = Fourier_filter(res_hidden_states, threshold=1, scale=self.s1) + if hidden_states.shape[1] == 640: + hidden_states[:,:320] = hidden_states[:,:320] * self.b2 + res_hidden_states = Fourier_filter(res_hidden_states, threshold=1, scale=self.s2) + # --------------------------------------------------------- + + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + hidden_states = resnet(hidden_states, temb) + hidden_states = temp_conv(hidden_states, num_frames=num_frames) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + + return hidden_states + + return forward + + for i, upsample_block in enumerate(model.unet.up_blocks): + if isinstance_str(upsample_block, "UpBlock3D"): + upsample_block.forward = up_forward(upsample_block) + setattr(upsample_block, 'b1', b1) + setattr(upsample_block, 'b2', b2) + setattr(upsample_block, 's1', s1) + setattr(upsample_block, 's2', s2) + + +def register_crossattn_upblock3d(model): + """ + Register CrossAttn UpBlock3D for UNet3DCondition. 
+ """ + + def up_forward(self): + def forward( + hidden_states: torch.FloatTensor, + res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], + temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + upsample_size: Optional[int] = None, + attention_mask: Optional[torch.FloatTensor] = None, + num_frames: int = 1 + ): + logger.debug(f"in crossatten upblock3d, hidden states shape: {hidden_states.shape}") + + for resnet, temp_conv, attn, temp_attn in zip( + self.resnets, self.temp_convs, self.attentions, self.temp_attentions + ): + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + hidden_states = resnet(hidden_states, temb) + hidden_states = temp_conv(hidden_states, num_frames=num_frames) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, + )[0] + hidden_states = temp_attn( + hidden_states, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs, return_dict=False + )[0] + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + + return hidden_states + + return forward + + for i, upsample_block in enumerate(model.unet.up_blocks): + if isinstance_str(upsample_block, "CrossAttnUpBlock3D"): + upsample_block.forward = up_forward(upsample_block) + + +def register_free_crossattn_upblock3d(model, b1=1.2, b2=1.4, s1=0.9, s2=0.2): + """ + Register CrossAttn UpBlock3D with FreeU for UNet3DCondition. + """ + + def up_forward(self): + def forward( + hidden_states: torch.FloatTensor, + res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], + temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + upsample_size: Optional[int] = None, + attention_mask: Optional[torch.FloatTensor] = None, + num_frames: int = 1 + ): + logger.debug(f"in free crossatten upblock3d, hidden states shape: {hidden_states.shape}") + + for resnet, temp_conv, attn, temp_attn in zip( + self.resnets, self.temp_convs, self.attentions, self.temp_attentions + ): + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + + # --------------- FreeU code ----------------------- + # Only operate on the first two stages + if hidden_states.shape[1] == 1280: + hidden_states[:,:640] = hidden_states[:,:640] * self.b1 + res_hidden_states = Fourier_filter(res_hidden_states, threshold=1, scale=self.s1) + if hidden_states.shape[1] == 640: + hidden_states[:,:320] = hidden_states[:,:320] * self.b2 + res_hidden_states = Fourier_filter(res_hidden_states, threshold=1, scale=self.s2) + # --------------------------------------------------------- + + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + hidden_states = resnet(hidden_states, temb) + hidden_states = temp_conv(hidden_states, num_frames=num_frames) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, + )[0] + hidden_states = temp_attn( + hidden_states, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs, return_dict=False + )[0] + 
+ if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + + return hidden_states + + return forward + + for i, upsample_block in enumerate(model.unet.up_blocks): + if isinstance_str(upsample_block, "CrossAttnUpBlock3D"): upsample_block.forward = up_forward(upsample_block) setattr(upsample_block, 'b1', b1) setattr(upsample_block, 'b2', b2) From d63f4837aa5851cec46f80d22136625ac3186d0b Mon Sep 17 00:00:00 2001 From: neurowelt Date: Wed, 27 Sep 2023 23:41:56 +0800 Subject: [PATCH 3/9] Add scale to UpBlock2D --- free_lunch_utils.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/free_lunch_utils.py b/free_lunch_utils.py index 7a2168a..eac0de8 100644 --- a/free_lunch_utils.py +++ b/free_lunch_utils.py @@ -62,7 +62,8 @@ def forward( hidden_states, res_hidden_states_tuple, temb=None, - upsample_size=None + upsample_size=None, + scale: float = 1.0 ): logger.debug(f"in upblock2d, hidden states shape: {hidden_states.shape}") @@ -90,11 +91,11 @@ def custom_forward(*inputs): create_custom_forward(resnet), hidden_states, temb ) else: - hidden_states = resnet(hidden_states, temb) + hidden_states = resnet(hidden_states, temb, scale=scale) if self.upsamplers is not None: for upsampler in self.upsamplers: - hidden_states = upsampler(hidden_states, upsample_size) + hidden_states = upsampler(hidden_states, upsample_size, scale=scale) return hidden_states @@ -115,7 +116,8 @@ def forward( hidden_states, res_hidden_states_tuple, temb=None, - upsample_size=None + upsample_size=None, + scale: float = 1.0 ): logger.debug(f"in free upblock2d, hidden states shape: {hidden_states.shape}") @@ -153,11 +155,11 @@ def custom_forward(*inputs): create_custom_forward(resnet), hidden_states, temb ) else: - hidden_states = resnet(hidden_states, temb) + hidden_states = resnet(hidden_states, temb, scale=scale) if self.upsamplers is not None: for upsampler in self.upsamplers: - hidden_states = upsampler(hidden_states, upsample_size) + hidden_states = upsampler(hidden_states, upsample_size, scale=scale) return hidden_states From e74e6f52df69a23900bc96fadf5ad6bfc8d12abd Mon Sep 17 00:00:00 2001 From: neurowelt Date: Wed, 27 Sep 2023 23:45:04 +0800 Subject: [PATCH 4/9] Add lora_scale to CrossAttn UpBlock2D --- free_lunch_utils.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/free_lunch_utils.py b/free_lunch_utils.py index eac0de8..39f996b 100644 --- a/free_lunch_utils.py +++ b/free_lunch_utils.py @@ -192,6 +192,8 @@ def forward( ): logger.debug(f"in crossatten upblock2d, hidden states shape: {hidden_states.shape}") + lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 + for resnet, attn in zip(self.resnets, self.attentions): # pop res hidden states res_hidden_states = res_hidden_states_tuple[-1] @@ -228,7 +230,7 @@ def custom_forward(*inputs): **ckpt_kwargs, )[0] else: - hidden_states = resnet(hidden_states, temb) + hidden_states = resnet(hidden_states, temb, scale=lora_scale) hidden_states = attn( hidden_states, encoder_hidden_states=encoder_hidden_states, @@ -240,7 +242,7 @@ def custom_forward(*inputs): if self.upsamplers is not None: for upsampler in self.upsamplers: - hidden_states = upsampler(hidden_states, upsample_size) + hidden_states = upsampler(hidden_states, upsample_size, scale=lora_scale) return hidden_states @@ -269,6 +271,8 @@ def forward( ): logger.debug(f"in free crossatten upblock2d, hidden states shape: {hidden_states.shape}") + 
lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 + for resnet, attn in zip(self.resnets, self.attentions): # pop res hidden states res_hidden_states = res_hidden_states_tuple[-1] @@ -316,7 +320,7 @@ def custom_forward(*inputs): **ckpt_kwargs, )[0] else: - hidden_states = resnet(hidden_states, temb) + hidden_states = resnet(hidden_states, temb, scale=lora_scale) hidden_states = attn( hidden_states, encoder_hidden_states=encoder_hidden_states, @@ -328,7 +332,7 @@ def custom_forward(*inputs): if self.upsamplers is not None: for upsampler in self.upsamplers: - hidden_states = upsampler(hidden_states, upsample_size) + hidden_states = upsampler(hidden_states, upsample_size, scale=lora_scale) return hidden_states From 5b87da6de953cf1d4c4cbbb247e3528f2f96263d Mon Sep 17 00:00:00 2001 From: neurowelt Date: Wed, 27 Sep 2023 23:45:24 +0800 Subject: [PATCH 5/9] Update README --- README.md | 30 ++++++++++++++++++++++++++++-- 1 file changed, 28 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 7874947..7db4414 100644 --- a/README.md +++ b/README.md @@ -3,9 +3,12 @@ The offical code https://github.com/ChenyangSi/FreeU ## Usage + +### Image Pipelines + ```python -from diffusers import StableDiffusionPipeline import torch +from diffusers import StableDiffusionPipeline from .free_lunch_utils import register_free_upblock2d, register_free_crossattn_upblock2d model_id = "runwayml/stable-diffusion-v1-5" @@ -23,5 +26,28 @@ image = pipe(prompt).images[0] image.save("astronaut_rides_horse.png") ``` +### Video Pipelines + +```python +import torch +from diffusers import TextToVideoSDPipeline +from diffusers.utils import export_to_video +from .free_lunch_utils import register_free_upblock3d, register_free_crossattn_upblock3d + +model_id = "cerspense/zeroscope_v2_576w" +pipe = TextToVideoSDPipeline.from_pretrained(model_id, torch_dtype=torch.float16) +pipe = pipe.to("cuda") + +# -------- freeu block registration +register_free_upblock3d(pipe, b1=1.2, b2=1.4, s1=0.9, s2=0.2) +register_free_crossattn_upblock3d(pipe, b1=1.2, b2=1.4, s1=0.9, s2=0.2) +# -------- freeu block registration + +prompt = "an astronaut riding a horse on mars" +video_frames = pipe(prompt, height=320, width=576, num_frames=30).frames + +export_to_video(video_frames, "astronaut_rides_horse.mp4") +``` + Note that it is supported and tested on diffusers v0.19.3. -If you are using the latest diffusers, it is recommended to use the corresponding branch, but it has not been tested. +If you are using the latest diffusers, it is recommended to use the corresponding branch, but it has not been tested. 
\ No newline at end of file

From b3b3f47fd9dc2d13f7db022413d2237672ea6eb2 Mon Sep 17 00:00:00 2001
From: neurowelt
Date: Wed, 27 Sep 2023 23:50:33 +0800
Subject: [PATCH 6/9] Update imports and clean up

---
 .gitignore  | 3 +++
 __init__.py | 6 +++++-
 2 files changed, 8 insertions(+), 1 deletion(-)
 create mode 100644 .gitignore

diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..9c6d5e8
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,3 @@
+# MacOS
+# =====
+.DS_Store
\ No newline at end of file
diff --git a/__init__.py b/__init__.py
index 3608dfa..f2bd7dd 100644
--- a/__init__.py
+++ b/__init__.py
@@ -1 +1,5 @@
-from .free_lunch_utils import register_upblock2d, register_free_upblock2d, register_crossattn_upblock2d, register_free_crossattn_upblock2d
\ No newline at end of file
+from .free_lunch_utils import (
+    register_upblock2d, register_free_upblock2d,
+    register_crossattn_upblock2d, register_free_crossattn_upblock2d,
+    register_upblock3d, register_free_upblock3d
+)
\ No newline at end of file

From 5e18775dc8f0a60ca568e81b1a9fe4f148e672fa Mon Sep 17 00:00:00 2001
From: neurowelt
Date: Thu, 28 Sep 2023 00:18:55 +0800
Subject: [PATCH 7/9] Update README

---
 README.md | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/README.md b/README.md
index 7db4414..1f0671c 100644
--- a/README.md
+++ b/README.md
@@ -49,5 +49,9 @@ video_frames = pipe(prompt, height=320, width=576, num_frames=30).frames
 export_to_video(video_frames, "astronaut_rides_horse.mp4")
 ```
 
+#### 28/09/223
+Current version was successfully run on diffusers v0.21.2.
+
+#### 26/09/23
 Note that it is supported and tested on diffusers v0.19.3.
 If you are using the latest diffusers, it is recommended to use the corresponding branch, but it has not been tested.
\ No newline at end of file

From f2d8024c598635edf6419dafe6ef0fff86b34ae2 Mon Sep 17 00:00:00 2001
From: neurowelt
Date: Thu, 28 Sep 2023 00:20:08 +0800
Subject: [PATCH 8/9] Update README

---
 README.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/README.md b/README.md
index 1f0671c..a492227 100644
--- a/README.md
+++ b/README.md
@@ -49,7 +49,7 @@ video_frames = pipe(prompt, height=320, width=576, num_frames=30).frames
 export_to_video(video_frames, "astronaut_rides_horse.mp4")
 ```
 
-#### 28/09/223
+#### 28/09/23
 Current version was successfully run on diffusers v0.21.2.
#### 26/09/23 From 951cf56f57b2e54aca62ba29569589d92def5056 Mon Sep 17 00:00:00 2001 From: neurowelt Date: Thu, 28 Sep 2023 10:51:51 +0800 Subject: [PATCH 9/9] Change logger to unet specific --- free_lunch_utils.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/free_lunch_utils.py b/free_lunch_utils.py index 39f996b..8763505 100644 --- a/free_lunch_utils.py +++ b/free_lunch_utils.py @@ -3,7 +3,8 @@ import torch import torch.fft as fft from diffusers.utils import is_torch_version -from diffusers.models.unet_2d_condition import logger +from diffusers.models.unet_2d_condition import logger as logger2d +from diffusers.models.unet_3d_condition import logger as logger3d def isinstance_str(x: object, cls_name: str): @@ -65,7 +66,7 @@ def forward( upsample_size=None, scale: float = 1.0 ): - logger.debug(f"in upblock2d, hidden states shape: {hidden_states.shape}") + logger2d.debug(f"in upblock2d, hidden states shape: {hidden_states.shape}") for resnet in self.resnets: # pop res hidden states @@ -119,7 +120,7 @@ def forward( upsample_size=None, scale: float = 1.0 ): - logger.debug(f"in free upblock2d, hidden states shape: {hidden_states.shape}") + logger2d.debug(f"in free upblock2d, hidden states shape: {hidden_states.shape}") for resnet in self.resnets: # pop res hidden states @@ -190,7 +191,7 @@ def forward( attention_mask: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, ): - logger.debug(f"in crossatten upblock2d, hidden states shape: {hidden_states.shape}") + logger2d.debug(f"in crossatten upblock2d, hidden states shape: {hidden_states.shape}") lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 @@ -269,7 +270,7 @@ def forward( attention_mask: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, ): - logger.debug(f"in free crossatten upblock2d, hidden states shape: {hidden_states.shape}") + logger2d.debug(f"in free crossatten upblock2d, hidden states shape: {hidden_states.shape}") lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 @@ -361,7 +362,7 @@ def forward( num_frames=1 ): - logger.debug(f"in upblock3d, hidden states shape: {hidden_states.shape}") + logger3d.debug(f"in upblock3d, hidden states shape: {hidden_states.shape}") for resnet, temp_conv in zip(self.resnets, self.temp_convs): # pop res hidden states @@ -400,7 +401,7 @@ def forward( num_frames=1 ): - logger.debug(f"in free upblock3d, hidden states shape: {hidden_states.shape}") + logger3d.debug(f"in free upblock3d, hidden states shape: {hidden_states.shape}") for resnet, temp_conv in zip(self.resnets, self.temp_convs): # pop res hidden states @@ -455,7 +456,7 @@ def forward( attention_mask: Optional[torch.FloatTensor] = None, num_frames: int = 1 ): - logger.debug(f"in crossatten upblock3d, hidden states shape: {hidden_states.shape}") + logger3d.debug(f"in crossatten upblock3d, hidden states shape: {hidden_states.shape}") for resnet, temp_conv, attn, temp_attn in zip( self.resnets, self.temp_convs, self.attentions, self.temp_attentions @@ -507,7 +508,7 @@ def forward( attention_mask: Optional[torch.FloatTensor] = None, num_frames: int = 1 ): - logger.debug(f"in free crossatten upblock3d, hidden states shape: {hidden_states.shape}") + logger3d.debug(f"in free crossatten upblock3d, hidden states shape: {hidden_states.shape}") for resnet, temp_conv, attn, temp_attn in zip( self.resnets, self.temp_convs, 
self.attentions, self.temp_attentions
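
## Appendix: illustrative sketches

The three snippets below distill what the series implements. They are sketches for the reader, not part of the patches, and every name or value that does not appear above is an illustrative assumption. First, the `Fourier_filter` rewrite from PATCH 1/9: the mask is now created on the input's device instead of via a hardcoded `.cuda()`, and inputs whose spatial dimensions are not powers of two are promoted to float32 before the FFT (per the diffusers discussion linked in the patched docstring, half-precision FFT support is limited to power-of-two sizes). A minimal, self-contained mirror of the patched function follows; the smoke-test shape and the `threshold`/`scale` values are illustrative, not taken from the patches:

```python
import torch
import torch.fft as fft


def fourier_filter(x_in: torch.Tensor, threshold: int, scale: float) -> torch.Tensor:
    """Scale the centered low-frequency band of a feature map, as in PATCH 1/9."""
    x = x_in
    B, C, H, W = x.shape

    # Non-power-of-two spatial sizes must run the FFT in float32
    if (W & (W - 1)) != 0 or (H & (H - 1)) != 0:
        x = x.to(dtype=torch.float32)

    # FFT, with the zero-frequency component shifted to the center
    x_freq = fft.fftshift(fft.fftn(x, dim=(-2, -1)), dim=(-2, -1))

    # Build the mask on the input's device rather than assuming CUDA
    mask = torch.ones((B, C, H, W), device=x.device)
    crow, ccol = H // 2, W // 2
    mask[..., crow - threshold : crow + threshold, ccol - threshold : ccol + threshold] = scale

    # Inverse FFT, then restore the caller's dtype
    x_filtered = fft.ifftn(fft.ifftshift(x_freq * mask, dim=(-2, -1)), dim=(-2, -1)).real
    return x_filtered.to(dtype=x_in.dtype)


# Smoke test on a non-power-of-two input (shape and values are illustrative)
skip = torch.randn(1, 640, 40, 72)
out = fourier_filter(skip, threshold=1, scale=0.9)
assert out.shape == skip.shape and out.dtype == skip.dtype
```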
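
Next, the registration pattern shared by every `register_*` helper in the series: build a replacement `forward` as a closure over the block instance, assign it in place of the original, and attach the FreeU constants with `setattr`. Below is a toy version of that pattern; `ToyUpBlock` and `register_free_toy_block` are hypothetical stand-ins, the real helpers gate the scaling on channel width (the first 640 of 1280 channels with `b1`, the first 320 of 640 with `b2`) which is collapsed into a single halving here, and `fourier_filter` is the sketch above:

```python
import torch


class ToyUpBlock:
    """Hypothetical stand-in for a diffusers up-block; illustration only."""

    def forward(self, hidden_states, res_hidden_states):
        return hidden_states + res_hidden_states


def register_free_toy_block(block, b1=1.2, s1=0.9):
    def free_forward(hidden_states, res_hidden_states):
        # Backbone scaling: amplify part of the backbone channels
        half = hidden_states.shape[1] // 2
        hidden_states[:, :half] = hidden_states[:, :half] * block.b1
        # Skip scaling: damp low frequencies in the skip features
        res_hidden_states = fourier_filter(res_hidden_states, threshold=1, scale=block.s1)
        return hidden_states + res_hidden_states

    # Swap the forward for the closure and attach the constants,
    # as the register_free_* helpers do for each matching up_block
    block.forward = free_forward
    setattr(block, "b1", b1)
    setattr(block, "s1", s1)


block = ToyUpBlock()
register_free_toy_block(block)
out = block.forward(torch.randn(1, 4, 8, 8), torch.randn(1, 4, 8, 8))
assert out.shape == (1, 4, 8, 8)
```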
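
Finally, PATCH 3/9 and PATCH 4/9 thread a LoRA `scale` through the patched forwards so that the `resnet(...)` and `upsampler(...)` calls match the newer diffusers signatures; the cross-attention variants recover it from `cross_attention_kwargs`. That one-line idiom, pulled out with a couple of checks (`resolve_lora_scale` is a name invented for this sketch, not a function in the patches):

```python
from typing import Any, Dict, Optional


def resolve_lora_scale(cross_attention_kwargs: Optional[Dict[str, Any]]) -> float:
    # Same fallback as in PATCH 3/9 and PATCH 4/9: default to 1.0 when
    # the kwargs dict is absent or carries no "scale" entry
    if cross_attention_kwargs is None:
        return 1.0
    return cross_attention_kwargs.get("scale", 1.0)


assert resolve_lora_scale(None) == 1.0
assert resolve_lora_scale({}) == 1.0
assert resolve_lora_scale({"scale": 0.5}) == 0.5
```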