From 016a6b8f2c54c27e53369891a6277be538063390 Mon Sep 17 00:00:00 2001
From: Josh Achiam
Date: Fri, 30 Sep 2022 00:54:40 -0700
Subject: [PATCH] Allow resolutions that are not multiples of 64 (#505)

* Allow resolutions that are not multiples of 64

* ran black

* fix bug

* add test

* more explanation

* more comments

Co-authored-by: Patrick von Platen
---
 models/resnet.py            | 10 ++++++--
 models/unet_2d_condition.py | 48 ++++++++++++++++++++++++++++++++-----
 models/unet_blocks.py       |  7 +++---
 3 files changed, 54 insertions(+), 11 deletions(-)

diff --git a/models/resnet.py b/models/resnet.py
index 2a6b2971aae5..49ff7d6bfa45 100644
--- a/models/resnet.py
+++ b/models/resnet.py
@@ -34,12 +34,18 @@ def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_chann
         else:
             self.Conv2d_0 = conv
 
-    def forward(self, hidden_states):
+    def forward(self, hidden_states, output_size=None):
         assert hidden_states.shape[1] == self.channels
+
         if self.use_conv_transpose:
             return self.conv(hidden_states)
 
-        hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
+        # if `output_size` is passed we force the interpolation output
+        # size and do not make use of `scale_factor=2`
+        if output_size is None:
+            hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
+        else:
+            hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest")
 
         # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
         if self.use_conv:
diff --git a/models/unet_2d_condition.py b/models/unet_2d_condition.py
index 3ea8829b48e1..04453e0645af 100644
--- a/models/unet_2d_condition.py
+++ b/models/unet_2d_condition.py
@@ -7,7 +7,7 @@
 
 from ..configuration_utils import ConfigMixin, register_to_config
 from ..modeling_utils import ModelMixin
-from ..utils import BaseOutput
+from ..utils import BaseOutput, logging
 from .embeddings import TimestepEmbedding, Timesteps
 from .unet_blocks import (
     CrossAttnDownBlock2D,
@@ -20,6 +20,9 @@
 )
 
 
+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+
 @dataclass
 class UNet2DConditionOutput(BaseOutput):
     """
@@ -145,15 +148,25 @@ def __init__(
             resnet_groups=norm_num_groups,
         )
 
+        # count how many layers upsample the images
+        self.num_upsamplers = 0
+
         # up
         reversed_block_out_channels = list(reversed(block_out_channels))
         output_channel = reversed_block_out_channels[0]
        for i, up_block_type in enumerate(up_block_types):
+            is_final_block = i == len(block_out_channels) - 1
+
             prev_output_channel = output_channel
             output_channel = reversed_block_out_channels[i]
             input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
 
-            is_final_block = i == len(block_out_channels) - 1
+            # add an upsample block for all layers except the final one
+            if not is_final_block:
+                add_upsample = True
+                self.num_upsamplers += 1
+            else:
+                add_upsample = False
 
             up_block = get_up_block(
                 up_block_type,
@@ -162,7 +175,7 @@
                 out_channels=output_channel,
                 prev_output_channel=prev_output_channel,
                 temb_channels=time_embed_dim,
-                add_upsample=not is_final_block,
+                add_upsample=add_upsample,
                 resnet_eps=norm_eps,
                 resnet_act_fn=act_fn,
                 resnet_groups=norm_num_groups,
@@ -223,6 +236,20 @@ def forward(
             [`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When
             returning a tuple, the first element is the sample tensor.
         """
+        # By default, samples have to be at least a multiple of the overall upsampling factor.
+        # The overall upsampling factor is equal to 2 ** (# of upsampling layers).
+ # However, the upsampling interpolation output size can be forced to fit any upsampling size + # on the fly if necessary. + default_overall_up_factor = 2**self.num_upsamplers + + # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` + forward_upsample_size = False + upsample_size = None + + if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): + logger.info("Forward upsample size to force interpolation output size.") + forward_upsample_size = True + # 0. center input if necessary if self.config.center_input_sample: sample = 2 * sample - 1.0 @@ -262,20 +289,29 @@ def forward( sample = self.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states) # 5. up - for upsample_block in self.up_blocks: + for i, upsample_block in enumerate(self.up_blocks): + is_final_block = i == len(self.up_blocks) - 1 + res_samples = down_block_res_samples[-len(upsample_block.resnets) :] down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] + # if we have not reached the final block and need to forward the + # upsample size, we do it here + if not is_final_block and forward_upsample_size: + upsample_size = down_block_res_samples[-1].shape[2:] + if hasattr(upsample_block, "attentions") and upsample_block.attentions is not None: sample = upsample_block( hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, encoder_hidden_states=encoder_hidden_states, + upsample_size=upsample_size, ) else: - sample = upsample_block(hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples) - + sample = upsample_block( + hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size + ) # 6. post-process # make sure hidden states is in float32 # when running in half-precision diff --git a/models/unet_blocks.py b/models/unet_blocks.py index f42389b98562..a17b1d2a5333 100644 --- a/models/unet_blocks.py +++ b/models/unet_blocks.py @@ -1126,6 +1126,7 @@ def forward( res_hidden_states_tuple, temb=None, encoder_hidden_states=None, + upsample_size=None, ): for resnet, attn in zip(self.resnets, self.attentions): # pop res hidden states @@ -1151,7 +1152,7 @@ def custom_forward(*inputs): if self.upsamplers is not None: for upsampler in self.upsamplers: - hidden_states = upsampler(hidden_states) + hidden_states = upsampler(hidden_states, upsample_size) return hidden_states @@ -1204,7 +1205,7 @@ def __init__( self.gradient_checkpointing = False - def forward(self, hidden_states, res_hidden_states_tuple, temb=None): + def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None): for resnet in self.resnets: # pop res hidden states res_hidden_states = res_hidden_states_tuple[-1] @@ -1225,7 +1226,7 @@ def custom_forward(*inputs): if self.upsamplers is not None: for upsampler in self.upsamplers: - hidden_states = upsampler(hidden_states) + hidden_states = upsampler(hidden_states, upsample_size) return hidden_states
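---

A minimal sketch of the shape mismatch this patch addresses (not part of the patch itself). It assumes only PyTorch; the 33x33 skip tensor and the stride-2 convolution arithmetic below are hypothetical stand-ins for the UNet's internal shapes when the input resolution is not a multiple of 2**num_upsamplers:

    import torch
    import torch.nn.functional as F

    # A skip connection saved on the way down, with an odd spatial size.
    skip = torch.randn(1, 64, 33, 33)
    # The tensor one level below it, after a stride-2 / kernel-3 / padding-1
    # downsampling conv: (33 + 2*1 - 3) // 2 + 1 = 17.
    hidden = torch.randn(1, 64, 17, 17)

    # Default path (scale_factor=2): 17 -> 34, which no longer matches the
    # 33x33 skip, so concatenating them in the up block would fail.
    up_default = F.interpolate(hidden, scale_factor=2.0, mode="nearest")
    assert up_default.shape[-2:] == (34, 34)

    # Forced path (what `output_size` / `upsample_size` enables): match the
    # skip connection's spatial shape exactly.
    up_forced = F.interpolate(hidden, size=skip.shape[2:], mode="nearest")
    assert up_forced.shape[-2:] == (33, 33)
    torch.cat([up_forced, skip], dim=1)  # concatenation now succeeds

With the default four-up-block configuration, three blocks carry upsamplers, so `default_overall_up_factor` is 2**3 = 8; any latent whose height or width is not divisible by 8 now triggers the forwarded `upsample_size`, which is read off the matching skip connection (`down_block_res_samples[-1].shape[2:]`) rather than computed by doubling. Combined with the VAE's 8x spatial compression, this is what relaxes the previous multiple-of-64 constraint on pixel resolutions.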