From 6bd30ba74827a4e8f392ec8a1ba90335425c6b9a Mon Sep 17 00:00:00 2001 From: Miguel Farinha <101428614+mlfarinha@users.noreply.github.com> Date: Fri, 13 Dec 2024 16:17:15 +0000 Subject: [PATCH] Allow image resolutions multiple of 8 instead of 64 in SVD pipeline (#6646) allow resolutions not multiple of 64 in SVD Co-authored-by: Miguel Farinha Co-authored-by: hlky --- src/diffusers/models/unets/unet_3d_blocks.py | 6 +++-- .../unets/unet_spatio_temporal_condition.py | 23 +++++++++++++++++++ 2 files changed, 27 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/unets/unet_3d_blocks.py b/src/diffusers/models/unets/unet_3d_blocks.py index 9c9fd7555899..195f7601dd54 100644 --- a/src/diffusers/models/unets/unet_3d_blocks.py +++ b/src/diffusers/models/unets/unet_3d_blocks.py @@ -1375,6 +1375,7 @@ def forward( res_hidden_states_tuple: Tuple[torch.Tensor, ...], temb: Optional[torch.Tensor] = None, image_only_indicator: Optional[torch.Tensor] = None, + upsample_size: Optional[int] = None, ) -> torch.Tensor: for resnet in self.resnets: # pop res hidden states @@ -1415,7 +1416,7 @@ def custom_forward(*inputs): if self.upsamplers is not None: for upsampler in self.upsamplers: - hidden_states = upsampler(hidden_states) + hidden_states = upsampler(hidden_states, upsample_size) return hidden_states @@ -1485,6 +1486,7 @@ def forward( temb: Optional[torch.Tensor] = None, encoder_hidden_states: Optional[torch.Tensor] = None, image_only_indicator: Optional[torch.Tensor] = None, + upsample_size: Optional[int] = None, ) -> torch.Tensor: for resnet, attn in zip(self.resnets, self.attentions): # pop res hidden states @@ -1533,6 +1535,6 @@ def custom_forward(*inputs): if self.upsamplers is not None: for upsampler in self.upsamplers: - hidden_states = upsampler(hidden_states) + hidden_states = upsampler(hidden_states, upsample_size) return hidden_states diff --git a/src/diffusers/models/unets/unet_spatio_temporal_condition.py 
b/src/diffusers/models/unets/unet_spatio_temporal_condition.py index 9fb975bc32d9..308b9e01c587 100644 --- a/src/diffusers/models/unets/unet_spatio_temporal_condition.py +++ b/src/diffusers/models/unets/unet_spatio_temporal_condition.py @@ -382,6 +382,20 @@ def forward( If `return_dict` is True, an [`~models.unet_slatio_temporal.UNetSpatioTemporalConditionOutput`] is returned, otherwise a `tuple` is returned where the first element is the sample tensor. """ + # By default samples have to be AT least a multiple of the overall upsampling factor. + # The overall upsampling factor is equal to 2 ** (# num of upsampling layers). + # However, the upsampling interpolation output size can be forced to fit any upsampling size + # on the fly if necessary. + default_overall_up_factor = 2**self.num_upsamplers + + # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` + forward_upsample_size = False + upsample_size = None + + if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): + logger.info("Forward upsample size to force interpolation output size.") + forward_upsample_size = True + # 1. time timesteps = timestep if not torch.is_tensor(timesteps): @@ -457,15 +471,23 @@ def forward( # 5. 
up for i, upsample_block in enumerate(self.up_blocks): + is_final_block = i == len(self.up_blocks) - 1 + res_samples = down_block_res_samples[-len(upsample_block.resnets) :] down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] + # if we have not reached the final block and need to forward the + # upsample size, we do it here + if not is_final_block and forward_upsample_size: + upsample_size = down_block_res_samples[-1].shape[2:] + if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: sample = upsample_block( hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, encoder_hidden_states=encoder_hidden_states, + upsample_size=upsample_size, image_only_indicator=image_only_indicator, ) else: @@ -473,6 +495,7 @@ def forward( hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, + upsample_size=upsample_size, image_only_indicator=image_only_indicator, )