Skip to content

Commit

Permalink
Allow image resolutions multiple of 8 instead of 64 in SVD pipeline (#…
Browse files Browse the repository at this point in the history
…6646)

allow resolutions not multiple of 64 in SVD

Co-authored-by: Miguel Farinha <[email protected]>
Co-authored-by: hlky <[email protected]>
  • Loading branch information
3 people authored Dec 13, 2024
1 parent cef0e36 commit 6bd30ba
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 2 deletions.
6 changes: 4 additions & 2 deletions src/diffusers/models/unets/unet_3d_blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -1375,6 +1375,7 @@ def forward(
res_hidden_states_tuple: Tuple[torch.Tensor, ...],
temb: Optional[torch.Tensor] = None,
image_only_indicator: Optional[torch.Tensor] = None,
upsample_size: Optional[int] = None,
) -> torch.Tensor:
for resnet in self.resnets:
# pop res hidden states
Expand Down Expand Up @@ -1415,7 +1416,7 @@ def custom_forward(*inputs):

if self.upsamplers is not None:
for upsampler in self.upsamplers:
hidden_states = upsampler(hidden_states)
hidden_states = upsampler(hidden_states, upsample_size)

return hidden_states

Expand Down Expand Up @@ -1485,6 +1486,7 @@ def forward(
temb: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
image_only_indicator: Optional[torch.Tensor] = None,
upsample_size: Optional[int] = None,
) -> torch.Tensor:
for resnet, attn in zip(self.resnets, self.attentions):
# pop res hidden states
Expand Down Expand Up @@ -1533,6 +1535,6 @@ def custom_forward(*inputs):

if self.upsamplers is not None:
for upsampler in self.upsamplers:
hidden_states = upsampler(hidden_states)
hidden_states = upsampler(hidden_states, upsample_size)

return hidden_states
23 changes: 23 additions & 0 deletions src/diffusers/models/unets/unet_spatio_temporal_condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,6 +382,20 @@ def forward(
If `return_dict` is True, an [`~models.unet_slatio_temporal.UNetSpatioTemporalConditionOutput`] is
returned, otherwise a `tuple` is returned where the first element is the sample tensor.
"""
# By default samples have to be AT least a multiple of the overall upsampling factor.
# The overall upsampling factor is equal to 2 ** (# num of upsampling layears).
# However, the upsampling interpolation output size can be forced to fit any upsampling size
# on the fly if necessary.
default_overall_up_factor = 2**self.num_upsamplers

# upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
forward_upsample_size = False
upsample_size = None

if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
logger.info("Forward upsample size to force interpolation output size.")
forward_upsample_size = True

# 1. time
timesteps = timestep
if not torch.is_tensor(timesteps):
Expand Down Expand Up @@ -457,22 +471,31 @@ def forward(

# 5. up
for i, upsample_block in enumerate(self.up_blocks):
is_final_block = i == len(self.up_blocks) - 1

res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]

# if we have not reached the final block and need to forward the
# upsample size, we do it here
if not is_final_block and forward_upsample_size:
upsample_size = down_block_res_samples[-1].shape[2:]

if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
sample = upsample_block(
hidden_states=sample,
temb=emb,
res_hidden_states_tuple=res_samples,
encoder_hidden_states=encoder_hidden_states,
upsample_size=upsample_size,
image_only_indicator=image_only_indicator,
)
else:
sample = upsample_block(
hidden_states=sample,
temb=emb,
res_hidden_states_tuple=res_samples,
upsample_size=upsample_size,
image_only_indicator=image_only_indicator,
)

Expand Down

0 comments on commit 6bd30ba

Please sign in to comment.