From 3389762996ad3d21572fc6d587e4c7d53f3f4924 Mon Sep 17 00:00:00 2001
From: YiYi Xu
Date: Fri, 1 Dec 2023 07:14:22 -1000
Subject: [PATCH] [Kandinsky 3.0] Follow-up TODOs (#5944)

clean-up kandinsky 3.0
---
 src/diffusers/models/__init__.py              |   4 +-
 src/diffusers/models/attention_processor.py   |  49 +---
 .../{unet_kandi3.py => unet_kandinsky3.py}    | 264 +++++++-----
 .../pipelines/kandinsky3/__init__.py          |   8 +-
 ...ky3_pipeline.py => pipeline_kandinsky3.py} | 203 +++++++++++---
 ...line.py => pipeline_kandinsky3_img2img.py} | 262 ++++++++++++++---
 tests/pipelines/kandinsky3/test_kandinsky3.py |   4 -
 .../kandinsky3/test_kandinsky3_img2img.py     | 225 +++++++++++++++
 tests/pipelines/test_pipelines_common.py      |   4 +
 9 files changed, 744 insertions(+), 279 deletions(-)
 rename src/diffusers/models/{unet_kandi3.py => unet_kandinsky3.py} (69%)
 rename src/diffusers/pipelines/kandinsky3/{kandinsky3_pipeline.py => pipeline_kandinsky3.py} (70%)
 rename src/diffusers/pipelines/kandinsky3/{kandinsky3img2img_pipeline.py => pipeline_kandinsky3_img2img.py} (59%)
 create mode 100644 tests/pipelines/kandinsky3/test_kandinsky3_img2img.py

diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py
index 839045001bb06..1b76b4e033413 100644
--- a/src/diffusers/models/__init__.py
+++ b/src/diffusers/models/__init__.py
@@ -42,7 +42,7 @@
     _import_structure["unet_2d"] = ["UNet2DModel"]
     _import_structure["unet_2d_condition"] = ["UNet2DConditionModel"]
     _import_structure["unet_3d_condition"] = ["UNet3DConditionModel"]
-    _import_structure["unet_kandi3"] = ["Kandinsky3UNet"]
+    _import_structure["unet_kandinsky3"] = ["Kandinsky3UNet"]
     _import_structure["unet_motion_model"] = ["MotionAdapter", "UNetMotionModel"]
     _import_structure["unet_spatio_temporal_condition"] = ["UNetSpatioTemporalConditionModel"]
     _import_structure["vq_model"] = ["VQModel"]
@@ -72,7 +72,7 @@
     from .unet_2d import UNet2DModel
     from .unet_2d_condition import UNet2DConditionModel
     from .unet_3d_condition import UNet3DConditionModel
-    from .unet_kandi3 import Kandinsky3UNet
+    from .unet_kandinsky3 import Kandinsky3UNet
    from .unet_motion_model import MotionAdapter, UNetMotionModel
     from .unet_spatio_temporal_condition import UNetSpatioTemporalConditionModel
     from .vq_model import VQModel
diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py
index 21eb3a32dc091..40a335527ace0 100644
--- a/src/diffusers/models/attention_processor.py
+++ b/src/diffusers/models/attention_processor.py
@@ -16,7 +16,7 @@
 import torch
 import torch.nn.functional as F
-from torch import einsum, nn
+from torch import nn
 
 from ..utils import USE_PEFT_BACKEND, deprecate, logging
 from ..utils.import_utils import is_xformers_available
@@ -109,15 +109,17 @@ def __init__(
         residual_connection: bool = False,
         _from_deprecated_attn_block: bool = False,
         processor: Optional["AttnProcessor"] = None,
+        out_dim: int = None,
     ):
         super().__init__()
-        self.inner_dim = dim_head * heads
+        self.inner_dim = out_dim if out_dim is not None else dim_head * heads
         self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
         self.upcast_attention = upcast_attention
         self.upcast_softmax = upcast_softmax
         self.rescale_output_factor = rescale_output_factor
         self.residual_connection = residual_connection
         self.dropout = dropout
+        self.out_dim = out_dim if out_dim is not None else query_dim
 
         # we make use of this private variable to know whether this class is loaded
         # with an deprecated state dict so that we can convert it on the fly
@@ -126,7 +128,7 @@ def __init__(
         self.scale_qk = scale_qk
         self.scale = dim_head**-0.5 if self.scale_qk else 1.0
 
-        self.heads = heads
+        self.heads = out_dim // dim_head if out_dim is not None else heads
         # for slice_size > 0 the attention score computation
         # is split across the batch axis to save memory
         # You can set slice_size with `set_attention_slice`
@@ -193,7 +195,7 @@ def __init__(
             self.add_v_proj = linear_cls(added_kv_proj_dim, self.inner_dim)
 
         self.to_out = nn.ModuleList([])
-        self.to_out.append(linear_cls(self.inner_dim, query_dim, bias=out_bias))
+        self.to_out.append(linear_cls(self.inner_dim, self.out_dim, bias=out_bias))
         self.to_out.append(nn.Dropout(dropout))
 
         # set attention processor
@@ -2219,44 +2221,6 @@ def __call__(
         return hidden_states
 
 
-# TODO(Yiyi): This class should not exist, we can replace it with a normal attention processor I believe
-# this way torch.compile and co. will work as well
-class Kandi3AttnProcessor:
-    r"""
-    Default kandinsky3 proccesor for performing attention-related computations.
-    """
-
-    @staticmethod
-    def _reshape(hid_states, h):
-        b, n, f = hid_states.shape
-        d = f // h
-        return hid_states.unsqueeze(-1).reshape(b, n, h, d).permute(0, 2, 1, 3)
-
-    def __call__(
-        self,
-        attn,
-        x,
-        context,
-        context_mask=None,
-    ):
-        query = self._reshape(attn.to_q(x), h=attn.num_heads)
-        key = self._reshape(attn.to_k(context), h=attn.num_heads)
-        value = self._reshape(attn.to_v(context), h=attn.num_heads)
-
-        attention_matrix = einsum("b h i d, b h j d -> b h i j", query, key)
-
-        if context_mask is not None:
-            max_neg_value = -torch.finfo(attention_matrix.dtype).max
-            context_mask = context_mask.unsqueeze(1).unsqueeze(1)
-            attention_matrix = attention_matrix.masked_fill(~(context_mask != 0), max_neg_value)
-        attention_matrix = (attention_matrix * attn.scale).softmax(dim=-1)
-
-        out = einsum("b h i j, b h j d -> b h i d", attention_matrix, value)
-        out = out.permute(0, 2, 1, 3).reshape(out.shape[0], out.shape[2], -1)
-        out = attn.to_out[0](out)
-        return out
-
-
 LORA_ATTENTION_PROCESSORS = (
     LoRAAttnProcessor,
     LoRAAttnProcessor2_0,
@@ -2282,7 +2246,6 @@ def __call__(
     LoRAXFormersAttnProcessor,
     IPAdapterAttnProcessor,
     IPAdapterAttnProcessor2_0,
-    Kandi3AttnProcessor,
 )
 
 AttentionProcessor = Union[
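Note: the `out_dim` extension above is what lets the stock `Attention` block stand in for the
deleted `Kandi3AttnProcessor`. A minimal sketch of the idea — the sizes below are toy values,
not from this commit:

```py
import torch
from diffusers.models.attention_processor import Attention

context_dim, num_channels, head_dim = 32, 64, 8  # arbitrary toy sizes

attn = Attention(
    query_dim=context_dim,
    cross_attention_dim=context_dim,
    dim_head=head_dim,
    out_dim=num_channels,  # output projects to num_channels instead of query_dim
    out_bias=False,
)

context = torch.randn(2, 10, context_dim)
# attend in context_dim space, emit num_channels features -> shape (2, 1, 64)
pooled = attn(context.mean(dim=1, keepdim=True), context)
print(pooled.shape)
```

Because the number of heads is derived as `out_dim // dim_head`, the module keeps standard
processors, so `torch.compile` and the xformers/PT2.0 paths work unchanged.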
diff --git a/src/diffusers/models/unet_kandi3.py b/src/diffusers/models/unet_kandinsky3.py
similarity index 69%
rename from src/diffusers/models/unet_kandi3.py
rename to src/diffusers/models/unet_kandinsky3.py
index 42e25a942f7df..eef3287e5d99a 100644
--- a/src/diffusers/models/unet_kandi3.py
+++ b/src/diffusers/models/unet_kandinsky3.py
@@ -1,16 +1,28 @@
-import math
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 from dataclasses import dataclass
 from typing import Dict, Tuple, Union
 
 import torch
-import torch.nn.functional as F
 import torch.utils.checkpoint
 from torch import nn
 
 from ..configuration_utils import ConfigMixin, register_to_config
 from ..utils import BaseOutput, logging
-from .attention_processor import AttentionProcessor, Kandi3AttnProcessor
-from .embeddings import TimestepEmbedding
+from .attention_processor import Attention, AttentionProcessor, AttnProcessor
+from .embeddings import TimestepEmbedding, Timesteps
 from .modeling_utils import ModelMixin
@@ -22,36 +34,6 @@ class Kandinsky3UNetOutput(BaseOutput):
     sample: torch.FloatTensor = None
 
 
-# TODO(Yiyi): This class needs to be removed
-def set_default_item(condition, item_1, item_2=None):
-    if condition:
-        return item_1
-    else:
-        return item_2
-
-
-# TODO(Yiyi): This class needs to be removed
-def set_default_layer(condition, layer_1, args_1=[], kwargs_1={}, layer_2=torch.nn.Identity, args_2=[], kwargs_2={}):
-    if condition:
-        return layer_1(*args_1, **kwargs_1)
-    else:
-        return layer_2(*args_2, **kwargs_2)
-
-
-# TODO(Yiyi): This class should be removed and be replaced by Timesteps
-class SinusoidalPosEmb(nn.Module):
-    def __init__(self, dim):
-        super().__init__()
-        self.dim = dim
-
-    def forward(self, x, type_tensor=None):
-        half_dim = self.dim // 2
-        emb = math.log(10000) / (half_dim - 1)
-        emb = torch.exp(torch.arange(half_dim, device=x.device) * -emb)
-        emb = x[:, None] * emb[None, :]
-        return torch.cat((emb.sin(), emb.cos()), dim=-1)
-
-
 class Kandinsky3EncoderProj(nn.Module):
     def __init__(self, encoder_hid_dim, cross_attention_dim):
         super().__init__()
@@ -87,9 +69,7 @@ def __init__(
         out_channels = in_channels
         init_channels = block_out_channels[0] // 2
-        # TODO(Yiyi): Should be replaced with Timesteps class -> make sure that results are the same
-        # self.time_proj = Timesteps(init_channels, flip_sin_to_cos=False, downscale_freq_shift=1)
-        self.time_proj = SinusoidalPosEmb(init_channels)
+        self.time_proj = Timesteps(init_channels, flip_sin_to_cos=False, downscale_freq_shift=1)
 
         self.time_embedding = TimestepEmbedding(
             init_channels,
@@ -106,7 +86,7 @@ def __init__(
         hidden_dims = [init_channels] + list(block_out_channels)
         in_out_dims = list(zip(hidden_dims[:-1], hidden_dims[1:]))
-        text_dims = [set_default_item(is_exist, cross_attention_dim) for is_exist in add_cross_attention]
+        text_dims = [cross_attention_dim if is_exist else None for is_exist in add_cross_attention]
         num_blocks = len(block_out_channels) * [layers_per_block]
         layer_params = [num_blocks, text_dims, add_self_attention]
         rev_layer_params = map(reversed, layer_params)
@@ -118,7 +98,7 @@ def __init__(
             zip(in_out_dims, *layer_params)
         ):
             down_sample = level != (self.num_levels - 1)
-            cat_dims.append(set_default_item(level != (self.num_levels - 1), out_dim, 0))
+            cat_dims.append(out_dim if level != (self.num_levels - 1) else 0)
             self.down_blocks.append(
                 Kandinsky3DownSampleBlock(
                     in_dim,
@@ -223,18 +203,16 @@ def set_default_attn_processor(self):
         """
         Disables custom attention processors and sets the default attention implementation.
         """
-        self.set_attn_processor(Kandi3AttnProcessor())
+        self.set_attn_processor(AttnProcessor())
 
     def _set_gradient_checkpointing(self, module, value=False):
         if hasattr(module, "gradient_checkpointing"):
             module.gradient_checkpointing = value
 
     def forward(self, sample, timestep, encoder_hidden_states=None, encoder_attention_mask=None, return_dict=True):
-        # TODO(Yiyi): Clean up the following variables - these names should not be used
-        # but instead only the ones that we pass to forward
-        x = sample
-        context_mask = encoder_attention_mask
-        context = encoder_hidden_states
+        if encoder_attention_mask is not None:
+            encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
+            encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
 
         if not torch.is_tensor(timestep):
             dtype = torch.float32 if isinstance(timestep, float) else torch.int32
@@ -244,33 +222,33 @@ def forward(self, sample, timestep, encoder_hidden_states=None, encoder_attentio
         # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
         timestep = timestep.expand(sample.shape[0])
-        time_embed_input = self.time_proj(timestep).to(x.dtype)
+        time_embed_input = self.time_proj(timestep).to(sample.dtype)
         time_embed = self.time_embedding(time_embed_input)
 
-        context = self.encoder_hid_proj(context)
+        encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
 
-        if context is not None:
-            time_embed = self.add_time_condition(time_embed, context, context_mask)
+        if encoder_hidden_states is not None:
+            time_embed = self.add_time_condition(time_embed, encoder_hidden_states, encoder_attention_mask)
 
         hidden_states = []
-        x = self.conv_in(x)
+        sample = self.conv_in(sample)
         for level, down_sample in enumerate(self.down_blocks):
-            x = down_sample(x, time_embed, context, context_mask)
+            sample = down_sample(sample, time_embed, encoder_hidden_states, encoder_attention_mask)
             if level != self.num_levels - 1:
-                hidden_states.append(x)
+                hidden_states.append(sample)
 
         for level, up_sample in enumerate(self.up_blocks):
             if level != 0:
-                x = torch.cat([x, hidden_states.pop()], dim=1)
-            x = up_sample(x, time_embed, context, context_mask)
+                sample = torch.cat([sample, hidden_states.pop()], dim=1)
+            sample = up_sample(sample, time_embed, encoder_hidden_states, encoder_attention_mask)
 
-        x = self.conv_norm_out(x)
-        x = self.conv_act_out(x)
-        x = self.conv_out(x)
+        sample = self.conv_norm_out(sample)
+        sample = self.conv_act_out(sample)
+        sample = self.conv_out(sample)
 
         if not return_dict:
-            return (x,)
-        return Kandinsky3UNetOutput(sample=x)
+            return (sample,)
+        return Kandinsky3UNetOutput(sample=sample)
 
 
 class Kandinsky3UpSampleBlock(nn.Module):
@@ -290,7 +268,7 @@ def __init__(
         self_attention=True,
     ):
         super().__init__()
-        up_resolutions = [[None, set_default_item(up_sample, True), None, None]] + [[None] * 4] * (num_blocks - 1)
+        up_resolutions = [[None, True if up_sample else None, None, None]] + [[None] * 4] * (num_blocks - 1)
         hidden_channels = (
             [(in_channels + cat_dim, in_channels)]
             + [(in_channels, in_channels)] * (num_blocks - 2)
@@ -303,27 +281,27 @@ def __init__(
         self.self_attention = self_attention
         self.context_dim = context_dim
 
-        attentions.append(
-            set_default_layer(
-                self_attention,
-                Kandinsky3AttentionBlock,
-                (out_channels, time_embed_dim, None, groups, head_dim, expansion_ratio),
-                layer_2=nn.Identity,
+        if self_attention:
+            attentions.append(
+                Kandinsky3AttentionBlock(out_channels, time_embed_dim, None, groups, head_dim, expansion_ratio)
             )
-        )
+        else:
+            attentions.append(nn.Identity())
 
         for (in_channel, out_channel), up_resolution in zip(hidden_channels, up_resolutions):
             resnets_in.append(
                 Kandinsky3ResNetBlock(in_channel, in_channel, time_embed_dim, groups, compression_ratio, up_resolution)
             )
-            attentions.append(
-                set_default_layer(
-                    context_dim is not None,
-                    Kandinsky3AttentionBlock,
-                    (in_channel, time_embed_dim, context_dim, groups, head_dim, expansion_ratio),
-                    layer_2=nn.Identity,
+
+            if context_dim is not None:
+                attentions.append(
+                    Kandinsky3AttentionBlock(
+                        in_channel, time_embed_dim, context_dim, groups, head_dim, expansion_ratio
+                    )
                 )
-            )
+            else:
+                attentions.append(nn.Identity())
+
             resnets_out.append(
                 Kandinsky3ResNetBlock(in_channel, out_channel, time_embed_dim, groups, compression_ratio)
             )
@@ -367,29 +345,29 @@ def __init__(
         self.self_attention = self_attention
         self.context_dim = context_dim
 
-        attentions.append(
-            set_default_layer(
-                self_attention,
-                Kandinsky3AttentionBlock,
-                (in_channels, time_embed_dim, None, groups, head_dim, expansion_ratio),
-                layer_2=nn.Identity,
+        if self_attention:
+            attentions.append(
+                Kandinsky3AttentionBlock(in_channels, time_embed_dim, None, groups, head_dim, expansion_ratio)
             )
-        )
+        else:
+            attentions.append(nn.Identity())
 
-        up_resolutions = [[None] * 4] * (num_blocks - 1) + [[None, None, set_default_item(down_sample, False), None]]
+        up_resolutions = [[None] * 4] * (num_blocks - 1) + [[None, None, False if down_sample else None, None]]
         hidden_channels = [(in_channels, out_channels)] + [(out_channels, out_channels)] * (num_blocks - 1)
         for (in_channel, out_channel), up_resolution in zip(hidden_channels, up_resolutions):
             resnets_in.append(
                 Kandinsky3ResNetBlock(in_channel, out_channel, time_embed_dim, groups, compression_ratio)
             )
-            attentions.append(
-                set_default_layer(
-                    context_dim is not None,
-                    Kandinsky3AttentionBlock,
-                    (out_channel, time_embed_dim, context_dim, groups, head_dim, expansion_ratio),
-                    layer_2=nn.Identity,
+
+            if context_dim is not None:
+                attentions.append(
+                    Kandinsky3AttentionBlock(
+                        out_channel, time_embed_dim, context_dim, groups, head_dim, expansion_ratio
+                    )
                 )
-            )
+            else:
+                attentions.append(nn.Identity())
+
             resnets_out.append(
                 Kandinsky3ResNetBlock(
                     out_channel, out_channel, time_embed_dim, groups, compression_ratio, up_resolution
                 )
@@ -431,68 +409,23 @@ def forward(self, x, context):
         return x
 
 
-# TODO(Yiyi): This class should ideally not even exist, it slows everything needlessly down. I'm pretty
-# sure we can delete it and instead just pass an attention_mask
-class Attention(nn.Module):
-    def __init__(self, in_channels, out_channels, context_dim, head_dim=64):
-        super().__init__()
-        assert out_channels % head_dim == 0
-        self.num_heads = out_channels // head_dim
-        self.scale = head_dim**-0.5
-
-        # to_q
-        self.to_q = nn.Linear(in_channels, out_channels, bias=False)
-        # to_k
-        self.to_k = nn.Linear(context_dim, out_channels, bias=False)
-        # to_v
-        self.to_v = nn.Linear(context_dim, out_channels, bias=False)
-        processor = Kandi3AttnProcessor()
-        self.set_processor(processor)
-        # to_out
-        self.to_out = nn.ModuleList([])
-        self.to_out.append(nn.Linear(out_channels, out_channels, bias=False))
-
-    def set_processor(self, processor: "AttnProcessor"):  # noqa: F821
-        # if current processor is in `self._modules` and if passed `processor` is not, we need to
-        # pop `processor` from `self._modules`
-        if (
-            hasattr(self, "processor")
-            and isinstance(self.processor, torch.nn.Module)
-            and not isinstance(processor, torch.nn.Module)
-        ):
-            logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}")
-            self._modules.pop("processor")
-
-        self.processor = processor
-
-    def forward(self, x, context, context_mask=None, image_mask=None):
-        return self.processor(
-            self,
-            x,
-            context=context,
-            context_mask=context_mask,
-        )
-
-
 class Kandinsky3Block(nn.Module):
     def __init__(self, in_channels, out_channels, time_embed_dim, kernel_size=3, norm_groups=32, up_resolution=None):
         super().__init__()
         self.group_norm = Kandinsky3ConditionalGroupNorm(norm_groups, in_channels, time_embed_dim)
         self.activation = nn.SiLU()
-        self.up_sample = set_default_layer(
-            up_resolution is not None and up_resolution,
-            nn.ConvTranspose2d,
-            (in_channels, in_channels),
-            {"kernel_size": 2, "stride": 2},
-        )
+        if up_resolution is not None and up_resolution:
+            self.up_sample = nn.ConvTranspose2d(in_channels, in_channels, kernel_size=2, stride=2)
+        else:
+            self.up_sample = nn.Identity()
+
         padding = int(kernel_size > 1)
         self.projection = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=padding)
-        self.down_sample = set_default_layer(
-            up_resolution is not None and not up_resolution,
-            nn.Conv2d,
-            (out_channels, out_channels),
-            {"kernel_size": 2, "stride": 2},
-        )
+
+        if up_resolution is not None and not up_resolution:
+            self.down_sample = nn.Conv2d(out_channels, out_channels, kernel_size=2, stride=2)
+        else:
+            self.down_sample = nn.Identity()
 
     def forward(self, x, time_embed):
         x = self.group_norm(x, time_embed)
@@ -521,14 +454,18 @@ def __init__(
                 )
             ]
         )
-        self.shortcut_up_sample = set_default_layer(
-            True in up_resolutions, nn.ConvTranspose2d, (in_channels, in_channels), {"kernel_size": 2, "stride": 2}
+        self.shortcut_up_sample = (
+            nn.ConvTranspose2d(in_channels, in_channels, kernel_size=2, stride=2)
+            if True in up_resolutions
+            else nn.Identity()
         )
-        self.shortcut_projection = set_default_layer(
-            in_channels != out_channels, nn.Conv2d, (in_channels, out_channels), {"kernel_size": 1}
+        self.shortcut_projection = (
+            nn.Conv2d(in_channels, out_channels, kernel_size=1) if in_channels != out_channels else nn.Identity()
        )
-        self.shortcut_down_sample = set_default_layer(
-            False in up_resolutions, nn.Conv2d, (out_channels, out_channels), {"kernel_size": 2, "stride": 2}
+        self.shortcut_down_sample = (
+            nn.Conv2d(out_channels, out_channels, kernel_size=2, stride=2)
+            if False in up_resolutions
+            else nn.Identity()
         )
 
     def forward(self, x, time_embed):
@@ -546,9 +483,16 @@ def forward(self, x, time_embed):
 class Kandinsky3AttentionPooling(nn.Module):
     def __init__(self, num_channels, context_dim, head_dim=64):
         super().__init__()
-        self.attention = Attention(context_dim, num_channels, context_dim, head_dim)
+        self.attention = Attention(
+            context_dim,
+            context_dim,
+            dim_head=head_dim,
+            out_dim=num_channels,
+            out_bias=False,
+        )
 
     def forward(self, x, context, context_mask=None):
+        context_mask = context_mask.to(dtype=context.dtype)
         context = self.attention(context.mean(dim=1, keepdim=True), context, context_mask)
         return x + context.squeeze(1)
@@ -557,7 +501,13 @@ class Kandinsky3AttentionBlock(nn.Module):
     def __init__(self, num_channels, time_embed_dim, context_dim=None, norm_groups=32, head_dim=64, expansion_ratio=4):
         super().__init__()
         self.in_norm = Kandinsky3ConditionalGroupNorm(norm_groups, num_channels, time_embed_dim)
-        self.attention = Attention(num_channels, num_channels, context_dim or num_channels, head_dim)
+        self.attention = Attention(
+            num_channels,
+            context_dim or num_channels,
+            dim_head=head_dim,
+            out_dim=num_channels,
+            out_bias=False,
+        )
 
         hidden_channels = expansion_ratio * num_channels
         self.out_norm = Kandinsky3ConditionalGroupNorm(norm_groups, num_channels, time_embed_dim)
@@ -572,14 +522,10 @@ def forward(self, x, time_embed, context=None, context_mask=None, image_mask=Non
         out = self.in_norm(x, time_embed)
         out = out.reshape(x.shape[0], -1, height * width).permute(0, 2, 1)
         context = context if context is not None else out
+        if context_mask is not None:
+            context_mask = context_mask.to(dtype=context.dtype)
 
-        if image_mask is not None:
-            mask_height, mask_width = image_mask.shape[-2:]
-            kernel_size = (mask_height // height, mask_width // width)
-            image_mask = F.max_pool2d(image_mask, kernel_size, kernel_size)
-            image_mask = image_mask.reshape(image_mask.shape[0], -1)
-
-        out = self.attention(out, context, context_mask, image_mask)
+        out = self.attention(out, context, context_mask)
         out = out.permute(0, 2, 1).unsqueeze(-1).reshape(out.shape[0], -1, height, width)
         x = x + out
diff --git a/src/diffusers/pipelines/kandinsky3/__init__.py b/src/diffusers/pipelines/kandinsky3/__init__.py
index 4da3a83c04481..e8a3063141b5e 100644
--- a/src/diffusers/pipelines/kandinsky3/__init__.py
+++ b/src/diffusers/pipelines/kandinsky3/__init__.py
@@ -21,8 +21,8 @@
     _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
 else:
-    _import_structure["kandinsky3_pipeline"] = ["Kandinsky3Pipeline"]
-    _import_structure["kandinsky3img2img_pipeline"] = ["Kandinsky3Img2ImgPipeline"]
+    _import_structure["pipeline_kandinsky3"] = ["Kandinsky3Pipeline"]
+    _import_structure["pipeline_kandinsky3_img2img"] = ["Kandinsky3Img2ImgPipeline"]
 
 
 if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
@@ -33,8 +33,8 @@
     except OptionalDependencyNotAvailable:
         from ...utils.dummy_torch_and_transformers_objects import *
     else:
-        from .kandinsky3_pipeline import Kandinsky3Pipeline
-        from .kandinsky3img2img_pipeline import Kandinsky3Img2ImgPipeline
+        from .pipeline_kandinsky3 import Kandinsky3Pipeline
+        from .pipeline_kandinsky3_img2img import Kandinsky3Img2ImgPipeline
 else:
     import sys
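Note: the reworked `Kandinsky3UNet.forward` above converts the binary `encoder_attention_mask`
(1 = keep, 0 = pad) into an additive bias once, up front, so the stock `AttnProcessor` can apply
it. A standalone sketch of that convention with toy tensors — nothing below comes from the
patch itself:

```py
import torch

sample_dtype = torch.float32
encoder_attention_mask = torch.tensor([[1, 1, 1, 0, 0]])  # (batch, seq_len)

# 0 -> -10000.0 (suppressed after softmax), 1 -> 0.0 (kept)
bias = (1 - encoder_attention_mask.to(sample_dtype)) * -10000.0
bias = bias.unsqueeze(1)  # broadcast over the query axis: (batch, 1, seq_len)

scores = torch.randn(1, 4, 5)  # toy (batch, query_len, key_len) attention scores
probs = (scores + bias).softmax(dim=-1)
print(probs[0, 0])  # the last two keys receive ~0 probability
```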
diff --git a/src/diffusers/pipelines/kandinsky3/kandinsky3_pipeline.py b/src/diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py
similarity index 70%
rename from src/diffusers/pipelines/kandinsky3/kandinsky3_pipeline.py
rename to src/diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py
index f116fb7894f05..4d14fc637b05b 100644
--- a/src/diffusers/pipelines/kandinsky3/kandinsky3_pipeline.py
+++ b/src/diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py
@@ -1,4 +1,4 @@
-from typing import Callable, List, Optional, Union
+from typing import Callable, Dict, List, Optional, Union
 
 import torch
 from transformers import T5EncoderModel, T5Tokenizer
@@ -7,8 +7,10 @@
 from ...models import Kandinsky3UNet, VQModel
 from ...schedulers import DDPMScheduler
 from ...utils import (
+    deprecate,
     is_accelerate_available,
     logging,
+    replace_example_docstring,
 )
 from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
@@ -16,6 +18,23 @@
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
+EXAMPLE_DOC_STRING = """
+    Examples:
+        ```py
+        >>> from diffusers import AutoPipelineForText2Image
+        >>> import torch
+
+        >>> pipe = AutoPipelineForText2Image.from_pretrained("kandinsky-community/kandinsky-3", variant="fp16", torch_dtype=torch.float16)
+        >>> pipe.enable_model_cpu_offload()
+
+        >>> prompt = "A photograph of the inside of a subway train. There are raccoons sitting on the seats. One of them is reading a newspaper. The window shows the city in the background."
+
+        >>> generator = torch.Generator(device="cpu").manual_seed(0)
+        >>> image = pipe(prompt, num_inference_steps=25, generator=generator).images[0]
+        ```
+"""
+
 
 def downscale_height_and_width(height, width, scale_factor=8):
     new_height = height // scale_factor**2
@@ -29,6 +48,13 @@ def downscale_height_and_width(height, width, scale_factor=8):
 
 class Kandinsky3Pipeline(DiffusionPipeline, LoraLoaderMixin):
     model_cpu_offload_seq = "text_encoder->unet->movq"
+    _callback_tensor_inputs = [
+        "latents",
+        "prompt_embeds",
+        "negative_prompt_embeds",
+        "negative_attention_mask",
+        "attention_mask",
+    ]
 
     def __init__(
         self,
@@ -50,7 +76,7 @@ def remove_all_hooks(self):
         else:
             raise ImportError("Please install accelerate via `pip install accelerate`")
 
-        for model in [self.text_encoder, self.unet]:
+        for model in [self.text_encoder, self.unet, self.movq]:
             if model is not None:
                 remove_hook_from_module(model, recurse=True)
 
@@ -77,6 +103,8 @@ def encode_prompt(
         prompt_embeds: Optional[torch.FloatTensor] = None,
         negative_prompt_embeds: Optional[torch.FloatTensor] = None,
         _cut_context=False,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        negative_attention_mask: Optional[torch.FloatTensor] = None,
     ):
         r"""
         Encodes the prompt into text encoder hidden states.
@@ -101,6 +129,10 @@ def encode_prompt(
                 Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                 weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                 argument.
+            attention_mask (`torch.FloatTensor`, *optional*):
+                Pre-generated attention mask. Must provide if passing `prompt_embeds` directly.
+            negative_attention_mask (`torch.FloatTensor`, *optional*):
+                Pre-generated negative attention mask. Must provide if passing `negative_prompt_embeds` directly.
         """
         if prompt is not None and negative_prompt is not None:
             if type(prompt) is not type(negative_prompt):
@@ -228,14 +260,21 @@ def check_inputs(
         negative_prompt=None,
         prompt_embeds=None,
         negative_prompt_embeds=None,
+        callback_on_step_end_tensor_inputs=None,
+        attention_mask=None,
+        negative_attention_mask=None,
     ):
-        if (callback_steps is None) or (
-            callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
-        ):
+        if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
             raise ValueError(
                 f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
                 f" {type(callback_steps)}."
             )
+        if callback_on_step_end_tensor_inputs is not None and not all(
+            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+        ):
+            raise ValueError(
+                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+            )
 
         if prompt is not None and prompt_embeds is not None:
             raise ValueError(
@@ -262,8 +301,42 @@ def check_inputs(
                     f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                     f" {negative_prompt_embeds.shape}."
                 )
+        if negative_prompt_embeds is not None and negative_attention_mask is None:
+            raise ValueError("Please provide `negative_attention_mask` along with `negative_prompt_embeds`")
+
+        if negative_prompt_embeds is not None and negative_attention_mask is not None:
+            if negative_prompt_embeds.shape[:2] != negative_attention_mask.shape:
+                raise ValueError(
+                    "`negative_prompt_embeds` and `negative_attention_mask` must have the same batch_size and token length when passed directly, but"
+                    f" got: `negative_prompt_embeds` {negative_prompt_embeds.shape[:2]} != `negative_attention_mask`"
+                    f" {negative_attention_mask.shape}."
+                )
+
+        if prompt_embeds is not None and attention_mask is None:
+            raise ValueError("Please provide `attention_mask` along with `prompt_embeds`")
+
+        if prompt_embeds is not None and attention_mask is not None:
+            if prompt_embeds.shape[:2] != attention_mask.shape:
+                raise ValueError(
+                    "`prompt_embeds` and `attention_mask` must have the same batch_size and token length when passed directly, but"
+                    f" got: `prompt_embeds` {prompt_embeds.shape[:2]} != `attention_mask`"
+                    f" {attention_mask.shape}."
+                )
+
+    @property
+    def guidance_scale(self):
+        return self._guidance_scale
+
+    @property
+    def do_classifier_free_guidance(self):
+        return self._guidance_scale > 1
+
+    @property
+    def num_timesteps(self):
+        return self._num_timesteps
 
     @torch.no_grad()
+    @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(
         self,
         prompt: Union[str, List[str]] = None,
@@ -276,11 +349,14 @@ def __call__(
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
         prompt_embeds: Optional[torch.FloatTensor] = None,
         negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        negative_attention_mask: Optional[torch.FloatTensor] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
-        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
-        callback_steps: int = 1,
         latents=None,
+        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+        **kwargs,
     ):
         """
         Function invoked when calling the pipeline for generation.
@@ -324,6 +400,10 @@ def __call__(
                 Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                 weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                 argument.
+            attention_mask (`torch.FloatTensor`, *optional*):
+                Pre-generated attention mask. Must provide if passing `prompt_embeds` directly.
+            negative_attention_mask (`torch.FloatTensor`, *optional*):
+                Pre-generated negative attention mask. Must provide if passing `negative_prompt_embeds` directly.
             output_type (`str`, *optional*, defaults to `"pil"`):
                 The output format of the generate image. Choose between
                 [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
@@ -343,12 +423,53 @@ def __call__(
                 A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                 `self.processor` in
                 [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+
+        Examples:
+
+        Returns:
+            [`~pipelines.ImagePipelineOutput`] or `tuple`
+
+        """
+        callback = kwargs.pop("callback", None)
+        callback_steps = kwargs.pop("callback_steps", None)
+
+        if callback is not None:
+            deprecate(
+                "callback",
+                "1.0.0",
+                "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
+            )
+        if callback_steps is not None:
+            deprecate(
+                "callback_steps",
+                "1.0.0",
+                "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
+            )
+
+        if callback_on_step_end_tensor_inputs is not None and not all(
+            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+        ):
+            raise ValueError(
+                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+            )
+
         cut_context = True
         device = self._execution_device
 
         # 1. Check inputs. Raise error if not correct
-        self.check_inputs(prompt, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds)
+        self.check_inputs(
+            prompt,
+            callback_steps,
+            negative_prompt,
+            prompt_embeds,
+            negative_prompt_embeds,
+            callback_on_step_end_tensor_inputs,
+            attention_mask,
+            negative_attention_mask,
+        )
+
+        self._guidance_scale = guidance_scale
 
         if prompt is not None and isinstance(prompt, str):
             batch_size = 1
@@ -357,24 +478,21 @@ def __call__(
         else:
             batch_size = prompt_embeds.shape[0]
 
-        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
-        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
-        # corresponds to doing no classifier free guidance.
-        do_classifier_free_guidance = guidance_scale > 1.0
-
         # 3. Encode input prompt
         prompt_embeds, negative_prompt_embeds, attention_mask, negative_attention_mask = self.encode_prompt(
             prompt,
-            do_classifier_free_guidance,
+            self.do_classifier_free_guidance,
            num_images_per_prompt=num_images_per_prompt,
             device=device,
             negative_prompt=negative_prompt,
             prompt_embeds=prompt_embeds,
             negative_prompt_embeds=negative_prompt_embeds,
             _cut_context=cut_context,
+            attention_mask=attention_mask,
+            negative_attention_mask=negative_attention_mask,
         )
 
-        if do_classifier_free_guidance:
+        if self.do_classifier_free_guidance:
             prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
             attention_mask = torch.cat([negative_attention_mask, attention_mask]).bool()
         # 4. Prepare timesteps
@@ -397,11 +515,11 @@ def __call__(
             self.text_encoder_offload_hook.offload()
 
         # 7. Denoising loop
-        # TODO(Yiyi): Correct the following line and use correctly
-        # num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+        self._num_timesteps = len(timesteps)
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
-                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
 
                 # predict the noise residual
                 noise_pred = self.unet(
@@ -412,7 +530,7 @@ def __call__(
                     return_dict=False,
                 )[0]
 
-                if do_classifier_free_guidance:
+                if self.do_classifier_free_guidance:
                     noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
 
                     noise_pred = (guidance_scale + 1.0) * noise_pred_text - guidance_scale * noise_pred_uncond
@@ -425,26 +543,45 @@ def __call__(
                     latents,
                     generator=generator,
                 ).prev_sample
-                progress_bar.update()
-                if callback is not None and i % callback_steps == 0:
-                    step_idx = i // getattr(self.scheduler, "order", 1)
-                    callback(step_idx, t, latents)
 
-        # post-processing
-        image = self.movq.decode(latents, force_not_quantize=True)["sample"]
+                if callback_on_step_end is not None:
+                    callback_kwargs = {}
+                    for k in callback_on_step_end_tensor_inputs:
+                        callback_kwargs[k] = locals()[k]
+                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+                    latents = callback_outputs.pop("latents", latents)
+                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+                    attention_mask = callback_outputs.pop("attention_mask", attention_mask)
+                    negative_attention_mask = callback_outputs.pop("negative_attention_mask", negative_attention_mask)
 
-        if output_type not in ["pt", "np", "pil"]:
+                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+                    progress_bar.update()
+                    if callback is not None and i % callback_steps == 0:
+                        step_idx = i // getattr(self.scheduler, "order", 1)
+                        callback(step_idx, t, latents)
+
+        # post-processing
+        if output_type not in ["pt", "np", "pil", "latent"]:
             raise ValueError(
-                f"Only the output types `pt`, `pil` and `np` are supported not output_type={output_type}"
+                f"Only the output types `pt`, `pil`, `np` and `latent` are supported not output_type={output_type}"
             )
 
-        if output_type in ["np", "pil"]:
-            image = image * 0.5 + 0.5
-            image = image.clamp(0, 1)
-            image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+        if not output_type == "latent":
+            image = self.movq.decode(latents, force_not_quantize=True)["sample"]
+
+            if output_type in ["np", "pil"]:
+                image = image * 0.5 + 0.5
+                image = image.clamp(0, 1)
+                image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+
+                if output_type == "pil":
+                    image = self.numpy_to_pil(image)
+        else:
+            image = latents
 
-        if output_type == "pil":
-            image = self.numpy_to_pil(image)
+        self.maybe_free_model_hooks()
 
         if not return_dict:
             return (image,)
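Note: a hedged usage sketch of the `callback_on_step_end` API that replaces the deprecated
`callback`/`callback_steps` arguments above. Only the argument names and tensor-input keys come
from this commit; the inspection logic, prompt, and use of a GPU are illustrative:

```py
import torch
from diffusers import AutoPipelineForText2Image

pipe = AutoPipelineForText2Image.from_pretrained(
    "kandinsky-community/kandinsky-3", variant="fp16", torch_dtype=torch.float16
)
pipe.enable_model_cpu_offload()

def on_step_end(pipeline, step, timestep, callback_kwargs):
    # Tensors named in `callback_on_step_end_tensor_inputs` arrive here; returning
    # them (possibly modified) feeds them back into the denoising loop.
    latents = callback_kwargs["latents"]
    print(f"step {step}: latents norm = {latents.float().norm().item():.2f}")
    return {"latents": latents}

image = pipe(
    "A photograph of a raccoon reading a newspaper",
    num_inference_steps=25,
    callback_on_step_end=on_step_end,
    callback_on_step_end_tensor_inputs=["latents"],
).images[0]
```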
diff --git a/src/diffusers/pipelines/kandinsky3/kandinsky3img2img_pipeline.py b/src/diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py
similarity index 59%
rename from src/diffusers/pipelines/kandinsky3/kandinsky3img2img_pipeline.py
rename to src/diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py
index b043110cf1d7f..7f4164a04d1ed 100644
--- a/src/diffusers/pipelines/kandinsky3/kandinsky3img2img_pipeline.py
+++ b/src/diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py
@@ -1,5 +1,5 @@
 import inspect
-from typing import Callable, List, Optional, Union
+from typing import Callable, Dict, List, Optional, Union
 
 import numpy as np
 import PIL
@@ -11,8 +11,10 @@
 from ...models import Kandinsky3UNet, VQModel
 from ...schedulers import DDPMScheduler
 from ...utils import (
+    deprecate,
     is_accelerate_available,
     logging,
+    replace_example_docstring,
 )
 from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
@@ -20,6 +22,24 @@
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
+EXAMPLE_DOC_STRING = """
+    Examples:
+        ```py
+        >>> from diffusers import AutoPipelineForImage2Image
+        >>> from diffusers.utils import load_image
+        >>> import torch
+
+        >>> pipe = AutoPipelineForImage2Image.from_pretrained("kandinsky-community/kandinsky-3", variant="fp16", torch_dtype=torch.float16)
+        >>> pipe.enable_model_cpu_offload()
+
+        >>> prompt = "A painting of the inside of a subway train with tiny raccoons."
+        >>> image = load_image("https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/kandinsky3/t2i.png")
+
+        >>> generator = torch.Generator(device="cpu").manual_seed(0)
+        >>> image = pipe(prompt, image=image, strength=0.75, num_inference_steps=25, generator=generator).images[0]
+        ```
+"""
+
 
 def downscale_height_and_width(height, width, scale_factor=8):
     new_height = height // scale_factor**2
@@ -40,7 +60,14 @@ def prepare_image(pil_image):
 
 class Kandinsky3Img2ImgPipeline(DiffusionPipeline, LoraLoaderMixin):
-    model_cpu_offload_seq = "text_encoder->unet->movq"
+    model_cpu_offload_seq = "text_encoder->movq->unet->movq"
+    _callback_tensor_inputs = [
+        "latents",
+        "prompt_embeds",
+        "negative_prompt_embeds",
+        "negative_attention_mask",
+        "attention_mask",
+    ]
 
     def __init__(
         self,
@@ -99,6 +126,8 @@ def encode_prompt(
         prompt_embeds: Optional[torch.FloatTensor] = None,
         negative_prompt_embeds: Optional[torch.FloatTensor] = None,
         _cut_context=False,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        negative_attention_mask: Optional[torch.FloatTensor] = None,
     ):
         r"""
         Encodes the prompt into text encoder hidden states.
@@ -123,6 +152,10 @@ def encode_prompt(
                 Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                 weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                 argument.
+            attention_mask (`torch.FloatTensor`, *optional*):
+                Pre-generated attention mask. Must provide if passing `prompt_embeds` directly.
+            negative_attention_mask (`torch.FloatTensor`, *optional*):
+                Pre-generated negative attention mask. Must provide if passing `negative_prompt_embeds` directly.
         """
         if prompt is not None and negative_prompt is not None:
             if type(prompt) is not type(negative_prompt):
@@ -299,15 +332,23 @@ def check_inputs(
         negative_prompt=None,
         prompt_embeds=None,
         negative_prompt_embeds=None,
+        callback_on_step_end_tensor_inputs=None,
+        attention_mask=None,
+        negative_attention_mask=None,
     ):
-        if (callback_steps is None) or (
-            callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
-        ):
+        if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
             raise ValueError(
                 f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
                 f" {type(callback_steps)}."
             )
+        if callback_on_step_end_tensor_inputs is not None and not all(
+            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+        ):
+            raise ValueError(
+                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+            )
+
         if prompt is not None and prompt_embeds is not None:
             raise ValueError(
                 f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
@@ -334,7 +375,42 @@ def check_inputs(
                     f" {negative_prompt_embeds.shape}."
                 )
 
+        if negative_prompt_embeds is not None and negative_attention_mask is None:
+            raise ValueError("Please provide `negative_attention_mask` along with `negative_prompt_embeds`")
+
+        if negative_prompt_embeds is not None and negative_attention_mask is not None:
+            if negative_prompt_embeds.shape[:2] != negative_attention_mask.shape:
+                raise ValueError(
+                    "`negative_prompt_embeds` and `negative_attention_mask` must have the same batch_size and token length when passed directly, but"
+                    f" got: `negative_prompt_embeds` {negative_prompt_embeds.shape[:2]} != `negative_attention_mask`"
+                    f" {negative_attention_mask.shape}."
+                )
+
+        if prompt_embeds is not None and attention_mask is None:
+            raise ValueError("Please provide `attention_mask` along with `prompt_embeds`")
+
+        if prompt_embeds is not None and attention_mask is not None:
+            if prompt_embeds.shape[:2] != attention_mask.shape:
+                raise ValueError(
+                    "`prompt_embeds` and `attention_mask` must have the same batch_size and token length when passed directly, but"
+                    f" got: `prompt_embeds` {prompt_embeds.shape[:2]} != `attention_mask`"
+                    f" {attention_mask.shape}."
+                )
+
+    @property
+    def guidance_scale(self):
+        return self._guidance_scale
+
+    @property
+    def do_classifier_free_guidance(self):
+        return self._guidance_scale > 1
+
+    @property
+    def num_timesteps(self):
+        return self._num_timesteps
+
     @torch.no_grad()
+    @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(
         self,
         prompt: Union[str, List[str]] = None,
@@ -347,15 +423,117 @@ def __call__(
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
         prompt_embeds: Optional[torch.FloatTensor] = None,
         negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        negative_attention_mask: Optional[torch.FloatTensor] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
-        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
-        callback_steps: int = 1,
-        latents=None,
+        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+        **kwargs,
     ):
+        """
+        Function invoked when calling the pipeline for generation.
+
+        Args:
+            prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
+                instead.
+            image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
+                `Image`, or tensor representing an image batch, that will be used as the starting point for the
+                process.
+            strength (`float`, *optional*, defaults to 0.8):
+                Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a
+                starting point and more noise is added the higher the `strength`. The number of denoising steps depends
+                on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising
+                process runs for the full number of iterations specified in `num_inference_steps`. A value of 1
+                essentially ignores `image`.
+            num_inference_steps (`int`, *optional*, defaults to 50):
+                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+                expense of slower inference.
+            guidance_scale (`float`, *optional*, defaults to 3.0):
+                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+                `guidance_scale` is defined as `w` of equation 2. of [Imagen
+                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+                usually at the expense of lower image quality.
+            negative_prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts not to guide the image generation. If not defined, one has to pass
+                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+                less than `1`).
+            num_images_per_prompt (`int`, *optional*, defaults to 1):
+                The number of images to generate per prompt.
+            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+                to make generation deterministic.
+            prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+                provided, text embeddings will be generated from `prompt` input argument.
+            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+                argument.
+            attention_mask (`torch.FloatTensor`, *optional*):
+                Pre-generated attention mask. Must provide if passing `prompt_embeds` directly.
+            negative_attention_mask (`torch.FloatTensor`, *optional*):
+                Pre-generated negative attention mask. Must provide if passing `negative_prompt_embeds` directly.
+            output_type (`str`, *optional*, defaults to `"pil"`):
+                The output format of the generated image. Choose between
+                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
+            callback_on_step_end (`Callable`, *optional*):
+                A function that is called at the end of each denoising step during inference. The function is called
+                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+                callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+                `callback_on_step_end_tensor_inputs`.
+            callback_on_step_end_tensor_inputs (`List`, *optional*):
+                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+                `._callback_tensor_inputs` attribute of your pipeline class.
+
+        Examples:
+
+        Returns:
+            [`~pipelines.ImagePipelineOutput`] or `tuple`
+
+        """
+        callback = kwargs.pop("callback", None)
+        callback_steps = kwargs.pop("callback_steps", None)
+
+        if callback is not None:
+            deprecate(
+                "callback",
+                "1.0.0",
+                "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
+            )
+        if callback_steps is not None:
+            deprecate(
+                "callback_steps",
+                "1.0.0",
+                "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
+            )
+
+        if callback_on_step_end_tensor_inputs is not None and not all(
+            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+        ):
+            raise ValueError(
+                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+            )
+
         cut_context = True
 
         # 1. Check inputs. Raise error if not correct
-        self.check_inputs(prompt, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds)
+        self.check_inputs(
+            prompt,
+            callback_steps,
+            negative_prompt,
+            prompt_embeds,
+            negative_prompt_embeds,
+            callback_on_step_end_tensor_inputs,
+            attention_mask,
+            negative_attention_mask,
+        )
+
+        self._guidance_scale = guidance_scale
 
         if prompt is not None and isinstance(prompt, str):
             batch_size = 1
@@ -366,24 +544,21 @@ def __call__(
 
         device = self._execution_device
 
-        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
-        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
-        # corresponds to doing no classifier free guidance.
-        do_classifier_free_guidance = guidance_scale > 1.0
-
         # 3. Encode input prompt
         prompt_embeds, negative_prompt_embeds, attention_mask, negative_attention_mask = self.encode_prompt(
             prompt,
-            do_classifier_free_guidance,
+            self.do_classifier_free_guidance,
            num_images_per_prompt=num_images_per_prompt,
             device=device,
             negative_prompt=negative_prompt,
             prompt_embeds=prompt_embeds,
             negative_prompt_embeds=negative_prompt_embeds,
             _cut_context=cut_context,
+            attention_mask=attention_mask,
+            negative_attention_mask=negative_attention_mask,
        )
 
-        if do_classifier_free_guidance:
+        if self.do_classifier_free_guidance:
             prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
             attention_mask = torch.cat([negative_attention_mask, attention_mask]).bool()
         if not isinstance(image, list):
@@ -409,11 +584,11 @@ def __call__(
             self.text_encoder_offload_hook.offload()
 
         # 7. Denoising loop
-        # TODO(Yiyi): Correct the following line and use correctly
-        # num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+        self._num_timesteps = len(timesteps)
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
-                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
 
                 # predict the noise residual
                 noise_pred = self.unet(
@@ -422,7 +597,7 @@ def __call__(
                     encoder_hidden_states=prompt_embeds,
                     encoder_attention_mask=attention_mask,
                 )[0]
-                if do_classifier_free_guidance:
+                if self.do_classifier_free_guidance:
                     noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
 
                     noise_pred = (guidance_scale + 1.0) * noise_pred_text - guidance_scale * noise_pred_uncond
@@ -434,25 +609,44 @@ def __call__(
                     latents,
                     generator=generator,
                 ).prev_sample
-                progress_bar.update()
-                if callback is not None and i % callback_steps == 0:
-                    step_idx = i // getattr(self.scheduler, "order", 1)
-                    callback(step_idx, t, latents)
 
-        # post-processing
-        image = self.movq.decode(latents, force_not_quantize=True)["sample"]
-        if output_type not in ["pt", "np", "pil"]:
+                if callback_on_step_end is not None:
+                    callback_kwargs = {}
+                    for k in callback_on_step_end_tensor_inputs:
+                        callback_kwargs[k] = locals()[k]
+                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+                    latents = callback_outputs.pop("latents", latents)
+                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+                    attention_mask = callback_outputs.pop("attention_mask", attention_mask)
+                    negative_attention_mask = callback_outputs.pop("negative_attention_mask", negative_attention_mask)
+
+                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+                    progress_bar.update()
+                    if callback is not None and i % callback_steps == 0:
+                        step_idx = i // getattr(self.scheduler, "order", 1)
+                        callback(step_idx, t, latents)
+
+        # post-processing
+        if output_type not in ["pt", "np", "pil", "latent"]:
             raise ValueError(
-                f"Only the output types `pt`, `pil` and `np` are supported not output_type={output_type}"
+                f"Only the output types `pt`, `pil`, `np` and `latent` are supported not output_type={output_type}"
             )
+        if not output_type == "latent":
+            image = self.movq.decode(latents, force_not_quantize=True)["sample"]
+
+            if output_type in ["np", "pil"]:
+                image = image * 0.5 + 0.5
+                image = image.clamp(0, 1)
+                image = image.cpu().permute(0, 2, 3, 1).float().numpy()
 
-        if output_type in ["np", "pil"]:
-            image = image * 0.5 + 0.5
-            image = image.clamp(0, 1)
-            image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+                if output_type == "pil":
+                    image = self.numpy_to_pil(image)
+        else:
+            image = latents
 
-        if output_type == "pil":
-            image = self.numpy_to_pil(image)
+        self.maybe_free_model_hooks()
 
         if not return_dict:
             return (image,)
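Note: both pipelines now enforce that the text attention masks accompany any embeddings passed
directly. A sketch of the round trip — the parameter names come from `encode_prompt` and
`check_inputs` above, but the standalone usage, prompt, and CUDA device are assumptions:

```py
import torch
from diffusers import Kandinsky3Pipeline

pipe = Kandinsky3Pipeline.from_pretrained(
    "kandinsky-community/kandinsky-3", variant="fp16", torch_dtype=torch.float16
).to("cuda")

# encode_prompt now returns the masks alongside the embeddings
prompt_embeds, negative_prompt_embeds, attention_mask, negative_attention_mask = pipe.encode_prompt(
    "a raccoon astronaut", do_classifier_free_guidance=True, device="cuda"
)

# reusing cached embeddings requires forwarding the masks too, or check_inputs raises
image = pipe(
    prompt_embeds=prompt_embeds,
    negative_prompt_embeds=negative_prompt_embeds,
    attention_mask=attention_mask,
    negative_attention_mask=negative_attention_mask,
    num_inference_steps=25,
).images[0]
```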
diff --git a/tests/pipelines/kandinsky3/test_kandinsky3.py b/tests/pipelines/kandinsky3/test_kandinsky3.py
index 65297a36b1575..c163fe3102c41 100644
--- a/tests/pipelines/kandinsky3/test_kandinsky3.py
+++ b/tests/pipelines/kandinsky3/test_kandinsky3.py
@@ -165,10 +165,6 @@ def test_float16_inference(self):
     def test_inference_batch_single_identical(self):
         super().test_inference_batch_single_identical(expected_max_diff=1e-2)
 
-    def test_model_cpu_offload_forward_pass(self):
-        # TODO(Yiyi) - this test should work, skipped for time reasons for now
-        pass
-
 
 @slow
 @require_torch_gpu
diff --git a/tests/pipelines/kandinsky3/test_kandinsky3_img2img.py b/tests/pipelines/kandinsky3/test_kandinsky3_img2img.py
new file mode 100644
index 0000000000000..581251a816394
--- /dev/null
+++ b/tests/pipelines/kandinsky3/test_kandinsky3_img2img.py
@@ -0,0 +1,225 @@
+# coding=utf-8
+# Copyright 2023 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import gc
+import random
+import unittest
+
+import numpy as np
+import torch
+from PIL import Image
+from transformers import AutoTokenizer, T5EncoderModel
+
+from diffusers import (
+    AutoPipelineForImage2Image,
+    Kandinsky3Img2ImgPipeline,
+    Kandinsky3UNet,
+    VQModel,
+)
+from diffusers.image_processor import VaeImageProcessor
+from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
+from diffusers.utils.testing_utils import (
+    enable_full_determinism,
+    floats_tensor,
+    load_image,
+    require_torch_gpu,
+    slow,
+)
+
+from ..pipeline_params import (
+    IMAGE_TO_IMAGE_IMAGE_PARAMS,
+    TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS,
+    TEXT_GUIDED_IMAGE_VARIATION_PARAMS,
+    TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS,
+    TEXT_TO_IMAGE_IMAGE_PARAMS,
+)
+from ..test_pipelines_common import PipelineTesterMixin
+
+
+enable_full_determinism()
+
+
+class Kandinsky3Img2ImgPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+    pipeline_class = Kandinsky3Img2ImgPipeline
+    params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS - {"height", "width"}
+    batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS
+    image_params = IMAGE_TO_IMAGE_IMAGE_PARAMS
+    image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
+    callback_cfg_params = TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS
+    test_xformers_attention = False
+    required_optional_params = frozenset(
+        [
+            "num_inference_steps",
+            "num_images_per_prompt",
+            "generator",
+            "output_type",
+            "return_dict",
+        ]
+    )
+
+    @property
+    def dummy_movq_kwargs(self):
+        return {
+            "block_out_channels": [32, 64],
+            "down_block_types": ["DownEncoderBlock2D", "AttnDownEncoderBlock2D"],
+            "in_channels": 3,
+            "latent_channels": 4,
+            "layers_per_block": 1,
+            "norm_num_groups": 8,
+            "norm_type": "spatial",
+            "num_vq_embeddings": 12,
+            "out_channels": 3,
+            "up_block_types": [
+                "AttnUpDecoderBlock2D",
+                "UpDecoderBlock2D",
+            ],
+            "vq_embed_dim": 4,
+        }
+
+    @property
+    def dummy_movq(self):
+        torch.manual_seed(0)
+        model = VQModel(**self.dummy_movq_kwargs)
+        return model
+
+    def get_dummy_components(self, time_cond_proj_dim=None):
+        torch.manual_seed(0)
+        unet = Kandinsky3UNet(
+            in_channels=4,
+            time_embedding_dim=4,
+            groups=2,
+            attention_head_dim=4,
+            layers_per_block=3,
+            block_out_channels=(32, 64),
+            cross_attention_dim=4,
+            encoder_hid_dim=32,
+        )
+        scheduler = DDPMScheduler(
+            beta_start=0.00085,
+            beta_end=0.012,
+            steps_offset=1,
+            beta_schedule="squaredcos_cap_v2",
+            clip_sample=True,
+            thresholding=False,
+        )
+        torch.manual_seed(0)
+        movq = self.dummy_movq
+        torch.manual_seed(0)
+        text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
+
+        torch.manual_seed(0)
+        tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
+
+        components = {
+            "unet": unet,
+            "scheduler": scheduler,
+            "movq": movq,
+            "text_encoder": text_encoder,
+            "tokenizer": tokenizer,
+        }
+        return components
+
+    def get_dummy_inputs(self, device, seed=0):
+        # create init_image
+        image = floats_tensor((1, 3, 64, 64), rng=random.Random(seed)).to(device)
+        image = image.cpu().permute(0, 2, 3, 1)[0]
+        init_image = Image.fromarray(np.uint8(image)).convert("RGB")
+
+        if str(device).startswith("mps"):
+            generator = torch.manual_seed(seed)
+        else:
+            generator = torch.Generator(device=device).manual_seed(seed)
+        inputs = {
+            "prompt": "A painting of a squirrel eating a burger",
+            "image": init_image,
+            "generator": generator,
+            "strength": 0.75,
+            "num_inference_steps": 10,
+            "guidance_scale": 6.0,
+            "output_type": "np",
+        }
+        return inputs
+
+    def test_kandinsky3_img2img(self):
+        device = "cpu"
+
+        components = self.get_dummy_components()
+
+        pipe = self.pipeline_class(**components)
+        pipe = pipe.to(device)
+
+        pipe.set_progress_bar_config(disable=None)
+
+        output = pipe(**self.get_dummy_inputs(device))
+        image = output.images
+
+        image_slice = image[0, -3:, -3:, -1]
+
+        assert image.shape == (1, 64, 64, 3)
+
+        expected_slice = np.array(
+            [0.576259, 0.6132097, 0.41703486, 0.603196, 0.62062526, 0.4655338, 0.5434324, 0.5660727, 0.65433365]
+        )
+
+        assert (
+            np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+        ), f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
+
+    def test_float16_inference(self):
+        super().test_float16_inference(expected_max_diff=1e-1)
+
+    def test_inference_batch_single_identical(self):
+        super().test_inference_batch_single_identical(expected_max_diff=1e-2)
+
+
+@slow
+@require_torch_gpu
+class Kandinsky3Img2ImgPipelineIntegrationTests(unittest.TestCase):
+    def tearDown(self):
+        # clean up the VRAM after each test
+        super().tearDown()
+        gc.collect()
+        torch.cuda.empty_cache()
+
+    def test_kandinskyV3_img2img(self):
+        pipe = AutoPipelineForImage2Image.from_pretrained(
+            "kandinsky-community/kandinsky-3", variant="fp16", torch_dtype=torch.float16
+        )
+        pipe.enable_model_cpu_offload()
+        pipe.set_progress_bar_config(disable=None)
+
+        generator = torch.Generator(device="cpu").manual_seed(0)
+
+        image = load_image(
+            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/kandinsky3/t2i.png"
+        )
+        w, h = 512, 512
+        image = image.resize((w, h), resample=Image.BICUBIC, reducing_gap=1)
+        prompt = "A painting of the inside of a subway train with tiny raccoons."
+
+        image = pipe(prompt, image=image, strength=0.75, num_inference_steps=25, generator=generator).images[0]
+
+        assert image.size == (512, 512)
+
+        expected_image = load_image(
+            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/kandinsky3/i2i.png"
+        )
+
+        image_processor = VaeImageProcessor()
+
+        image_np = image_processor.pil_to_numpy(image)
+        expected_image_np = image_processor.pil_to_numpy(expected_image)
+
+        self.assertTrue(np.allclose(image_np, expected_image_np, atol=5e-2))
diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py
index e11175921184d..cac5ee442ae66 100644
--- a/tests/pipelines/test_pipelines_common.py
+++ b/tests/pipelines/test_pipelines_common.py
@@ -377,6 +377,10 @@ def test_save_load_local(self, expected_max_difference=5e-4):
         with CaptureLogger(logger) as cap_logger:
             pipe_loaded = self.pipeline_class.from_pretrained(tmpdir)
 
+        for component in pipe_loaded.components.values():
+            if hasattr(component, "set_default_attn_processor"):
+                component.set_default_attn_processor()
+
         for name in pipe_loaded.components.keys():
             if name not in pipe_loaded._optional_components:
                 assert name in str(cap_logger)
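Note: a sketch of the `output_type="latent"` path both `__call__` methods gain in this commit,
which skips the `movq` decode. The manual decode below mirrors the pipeline's own
post-processing; the model id comes from the docstring examples, while the prompt and full-GPU
placement are illustrative:

```py
import torch
from diffusers import AutoPipelineForText2Image

pipe = AutoPipelineForText2Image.from_pretrained(
    "kandinsky-community/kandinsky-3", variant="fp16", torch_dtype=torch.float16
).to("cuda")

# `.images` holds the raw latents tensor when output_type="latent"
latents = pipe("a raccoon astronaut", num_inference_steps=25, output_type="latent").images

# decode manually with the movq, exactly as the pipeline would
with torch.no_grad():
    image = pipe.movq.decode(latents, force_not_quantize=True)["sample"]
print(image.shape)
```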