Skip to content

Commit

Permalink
Flux latents fix (#9929)
Browse files Browse the repository at this point in the history
* update

* update

* update

* update

* update

* update

---------

Co-authored-by: Sayak Paul <[email protected]>
  • Loading branch information
DN6 and sayakpaul authored Nov 20, 2024
1 parent 637e230 commit f6f7afa
Show file tree
Hide file tree
Showing 12 changed files with 219 additions and 61 deletions.
22 changes: 14 additions & 8 deletions src/diffusers/pipelines/flux/pipeline_flux.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,9 @@ def __init__(
self.vae_scale_factor = (
2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
# Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
# by the patch size. So the vae scale factor is multiplied by the patch size to account for this
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
self.tokenizer_max_length = (
self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
)
Expand Down Expand Up @@ -386,9 +388,9 @@ def check_inputs(
callback_on_step_end_tensor_inputs=None,
max_sequence_length=None,
):
if height % self.vae_scale_factor != 0 or width % self.vae_scale_factor != 0:
raise ValueError(
f"`height` and `width` have to be divisible by {self.vae_scale_factor} but are {height} and {width}."
if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
logger.warning(
f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
)

if callback_on_step_end_tensor_inputs is not None and not all(
Expand Down Expand Up @@ -451,8 +453,10 @@ def _pack_latents(latents, batch_size, num_channels_latents, height, width):
def _unpack_latents(latents, height, width, vae_scale_factor):
batch_size, num_patches, channels = latents.shape

height = height // vae_scale_factor
width = width // vae_scale_factor
# VAE applies 8x compression on images but we must also account for packing which requires
# latent height and width to be divisible by 2.
height = 2 * (int(height) // (vae_scale_factor * 2))
width = 2 * (int(width) // (vae_scale_factor * 2))

latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
latents = latents.permute(0, 3, 1, 4, 2, 5)
Expand Down Expand Up @@ -501,8 +505,10 @@ def prepare_latents(
generator,
latents=None,
):
height = int(height) // self.vae_scale_factor
width = int(width) // self.vae_scale_factor
# VAE applies 8x compression on images but we must also account for packing which requires
# latent height and width to be divisible by 2.
height = 2 * (int(height) // (self.vae_scale_factor * 2))
width = 2 * (int(width) // (self.vae_scale_factor * 2))

shape = (batch_size, num_channels_latents, height, width)

Expand Down
22 changes: 14 additions & 8 deletions src/diffusers/pipelines/flux/pipeline_flux_controlnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,9 @@ def __init__(
self.vae_scale_factor = (
2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
# Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
# by the patch size. So the vae scale factor is multiplied by the patch size to account for this
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
self.tokenizer_max_length = (
self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
)
Expand Down Expand Up @@ -410,9 +412,9 @@ def check_inputs(
callback_on_step_end_tensor_inputs=None,
max_sequence_length=None,
):
if height % self.vae_scale_factor != 0 or width % self.vae_scale_factor != 0:
raise ValueError(
f"`height` and `width` have to be divisible by {self.vae_scale_factor} but are {height} and {width}."
if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
logger.warning(
f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
)

if callback_on_step_end_tensor_inputs is not None and not all(
Expand Down Expand Up @@ -478,8 +480,10 @@ def _pack_latents(latents, batch_size, num_channels_latents, height, width):
def _unpack_latents(latents, height, width, vae_scale_factor):
batch_size, num_patches, channels = latents.shape

height = height // vae_scale_factor
width = width // vae_scale_factor
# VAE applies 8x compression on images but we must also account for packing which requires
# latent height and width to be divisible by 2.
height = 2 * (int(height) // (vae_scale_factor * 2))
width = 2 * (int(width) // (vae_scale_factor * 2))

latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
latents = latents.permute(0, 3, 1, 4, 2, 5)
Expand All @@ -500,8 +504,10 @@ def prepare_latents(
generator,
latents=None,
):
height = int(height) // self.vae_scale_factor
width = int(width) // self.vae_scale_factor
# VAE applies 8x compression on images but we must also account for packing which requires
# latent height and width to be divisible by 2.
height = 2 * (int(height) // (self.vae_scale_factor * 2))
width = 2 * (int(width) // (self.vae_scale_factor * 2))

shape = (batch_size, num_channels_latents, height, width)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,9 @@ def __init__(
self.vae_scale_factor = (
2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
# Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
# by the patch size. So the vae scale factor is multiplied by the patch size to account for this
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
self.tokenizer_max_length = (
self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
)
Expand Down Expand Up @@ -453,9 +455,9 @@ def check_inputs(
if strength < 0 or strength > 1:
raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")

if height % self.vae_scale_factor != 0 or width % self.vae_scale_factor != 0:
raise ValueError(
f"`height` and `width` have to be divisible by {self.vae_scale_factor} but are {height} and {width}."
if height % self.vae_scale_factor * 2 != 0 or width % self.vae_scale_factor * 2 != 0:
logger.warning(
f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
)

if callback_on_step_end_tensor_inputs is not None and not all(
Expand Down Expand Up @@ -521,8 +523,10 @@ def _pack_latents(latents, batch_size, num_channels_latents, height, width):
def _unpack_latents(latents, height, width, vae_scale_factor):
batch_size, num_patches, channels = latents.shape

height = height // vae_scale_factor
width = width // vae_scale_factor
# VAE applies 8x compression on images but we must also account for packing which requires
# latent height and width to be divisible by 2.
height = 2 * (int(height) // (vae_scale_factor * 2))
width = 2 * (int(width) // (vae_scale_factor * 2))

latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
latents = latents.permute(0, 3, 1, 4, 2, 5)
Expand Down Expand Up @@ -551,9 +555,10 @@ def prepare_latents(
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)

height = int(height) // self.vae_scale_factor
width = int(width) // self.vae_scale_factor

# VAE applies 8x compression on images but we must also account for packing which requires
# latent height and width to be divisible by 2.
height = 2 * (int(height) // (self.vae_scale_factor * 2))
width = 2 * (int(width) // (self.vae_scale_factor * 2))
shape = (batch_size, num_channels_latents, height, width)
latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)

Expand Down Expand Up @@ -873,7 +878,6 @@ def __call__(
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)

latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)

latents, latent_image_ids = self.prepare_latents(
init_image,
latent_timestep,
Expand Down
32 changes: 19 additions & 13 deletions src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,9 +233,11 @@ def __init__(
self.vae_scale_factor = (
2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
# Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
# by the patch size. So the vae scale factor is multiplied by the patch size to account for this
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
self.mask_processor = VaeImageProcessor(
vae_scale_factor=self.vae_scale_factor,
vae_scale_factor=self.vae_scale_factor * 2,
vae_latent_channels=self.vae.config.latent_channels,
do_normalize=False,
do_binarize=True,
Expand Down Expand Up @@ -467,9 +469,9 @@ def check_inputs(
if strength < 0 or strength > 1:
raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")

if height % self.vae_scale_factor != 0 or width % self.vae_scale_factor != 0:
raise ValueError(
f"`height` and `width` have to be divisible by {self.vae_scale_factor} but are {height} and {width}."
if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
logger.warning(
f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
)

if callback_on_step_end_tensor_inputs is not None and not all(
Expand Down Expand Up @@ -548,8 +550,10 @@ def _pack_latents(latents, batch_size, num_channels_latents, height, width):
def _unpack_latents(latents, height, width, vae_scale_factor):
batch_size, num_patches, channels = latents.shape

height = height // vae_scale_factor
width = width // vae_scale_factor
# VAE applies 8x compression on images but we must also account for packing which requires
# latent height and width to be divisible by 2.
height = 2 * (int(height) // (vae_scale_factor * 2))
width = 2 * (int(width) // (vae_scale_factor * 2))

latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
latents = latents.permute(0, 3, 1, 4, 2, 5)
Expand Down Expand Up @@ -578,9 +582,10 @@ def prepare_latents(
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)

height = int(height) // self.vae_scale_factor
width = int(width) // self.vae_scale_factor

# VAE applies 8x compression on images but we must also account for packing which requires
# latent height and width to be divisible by 2.
height = 2 * (int(height) // (self.vae_scale_factor * 2))
width = 2 * (int(width) // (self.vae_scale_factor * 2))
shape = (batch_size, num_channels_latents, height, width)
latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)

Expand Down Expand Up @@ -624,8 +629,10 @@ def prepare_mask_latents(
device,
generator,
):
height = int(height) // self.vae_scale_factor
width = int(width) // self.vae_scale_factor
# VAE applies 8x compression on images but we must also account for packing which requires
# latent height and width to be divisible by 2.
height = 2 * (int(height) // (self.vae_scale_factor * 2))
width = 2 * (int(width) // (self.vae_scale_factor * 2))
# resize the mask to latents shape as we concatenate the mask to the latents
# we do that before converting to dtype to avoid breaking in case we're using cpu_offload
# and half precision
Expand Down Expand Up @@ -663,7 +670,6 @@ def prepare_mask_latents(

# aligning device to prevent device errors when concating it with the latent model input
masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)

masked_image_latents = self._pack_latents(
masked_image_latents,
batch_size,
Expand Down
23 changes: 14 additions & 9 deletions src/diffusers/pipelines/flux/pipeline_flux_img2img.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,9 @@ def __init__(
self.vae_scale_factor = (
2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
# Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
# by the patch size. So the vae scale factor is multiplied by the patch size to account for this
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
self.tokenizer_max_length = (
self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
)
Expand Down Expand Up @@ -437,9 +439,9 @@ def check_inputs(
if strength < 0 or strength > 1:
raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")

if height % self.vae_scale_factor != 0 or width % self.vae_scale_factor != 0:
raise ValueError(
f"`height` and `width` have to be divisible by {self.vae_scale_factor} but are {height} and {width}."
if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
logger.warning(
f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
)

if callback_on_step_end_tensor_inputs is not None and not all(
Expand Down Expand Up @@ -505,8 +507,10 @@ def _pack_latents(latents, batch_size, num_channels_latents, height, width):
def _unpack_latents(latents, height, width, vae_scale_factor):
batch_size, num_patches, channels = latents.shape

height = height // vae_scale_factor
width = width // vae_scale_factor
# VAE applies 8x compression on images but we must also account for packing which requires
# latent height and width to be divisible by 2.
height = 2 * (int(height) // (vae_scale_factor * 2))
width = 2 * (int(width) // (vae_scale_factor * 2))

latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
latents = latents.permute(0, 3, 1, 4, 2, 5)
Expand Down Expand Up @@ -534,9 +538,10 @@ def prepare_latents(
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)

height = int(height) // self.vae_scale_factor
width = int(width) // self.vae_scale_factor

# VAE applies 8x compression on images but we must also account for packing which requires
# latent height and width to be divisible by 2.
height = 2 * (int(height) // (self.vae_scale_factor * 2))
width = 2 * (int(width) // (self.vae_scale_factor * 2))
shape = (batch_size, num_channels_latents, height, width)
latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)

Expand Down
Loading

0 comments on commit f6f7afa

Please sign in to comment.