From 293c480d4cf0bed442e02395dfd012c686d2ca59 Mon Sep 17 00:00:00 2001 From: yiyixu Date: Mon, 27 Nov 2023 09:18:47 +0000 Subject: [PATCH 01/20] remove kandinsky specific attention and attention processor --- src/diffusers/models/attention_processor.py | 49 ++------------- src/diffusers/models/unet_kandi3.py | 68 ++++----------------- 2 files changed, 18 insertions(+), 99 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 21eb3a32dc09..40a335527ace 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -16,7 +16,7 @@ import torch import torch.nn.functional as F -from torch import einsum, nn +from torch import nn from ..utils import USE_PEFT_BACKEND, deprecate, logging from ..utils.import_utils import is_xformers_available @@ -109,15 +109,17 @@ def __init__( residual_connection: bool = False, _from_deprecated_attn_block: bool = False, processor: Optional["AttnProcessor"] = None, + out_dim: int = None, ): super().__init__() - self.inner_dim = dim_head * heads + self.inner_dim = out_dim if out_dim is not None else dim_head * heads self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim self.upcast_attention = upcast_attention self.upcast_softmax = upcast_softmax self.rescale_output_factor = rescale_output_factor self.residual_connection = residual_connection self.dropout = dropout + self.out_dim = out_dim if out_dim is not None else query_dim # we make use of this private variable to know whether this class is loaded # with an deprecated state dict so that we can convert it on the fly @@ -126,7 +128,7 @@ def __init__( self.scale_qk = scale_qk self.scale = dim_head**-0.5 if self.scale_qk else 1.0 - self.heads = heads + self.heads = out_dim // dim_head if out_dim is not None else heads # for slice_size > 0 the attention score computation # is split across the batch axis to save memory # You can set slice_size with `set_attention_slice` @@ -193,7 +195,7 @@ def __init__( self.add_v_proj = linear_cls(added_kv_proj_dim, self.inner_dim) self.to_out = nn.ModuleList([]) - self.to_out.append(linear_cls(self.inner_dim, query_dim, bias=out_bias)) + self.to_out.append(linear_cls(self.inner_dim, self.out_dim, bias=out_bias)) self.to_out.append(nn.Dropout(dropout)) # set attention processor @@ -2219,44 +2221,6 @@ def __call__( return hidden_states -# TODO(Yiyi): This class should not exist, we can replace it with a normal attention processor I believe -# this way torch.compile and co. will work as well -class Kandi3AttnProcessor: - r""" - Default kandinsky3 proccesor for performing attention-related computations. 
- """ - - @staticmethod - def _reshape(hid_states, h): - b, n, f = hid_states.shape - d = f // h - return hid_states.unsqueeze(-1).reshape(b, n, h, d).permute(0, 2, 1, 3) - - def __call__( - self, - attn, - x, - context, - context_mask=None, - ): - query = self._reshape(attn.to_q(x), h=attn.num_heads) - key = self._reshape(attn.to_k(context), h=attn.num_heads) - value = self._reshape(attn.to_v(context), h=attn.num_heads) - - attention_matrix = einsum("b h i d, b h j d -> b h i j", query, key) - - if context_mask is not None: - max_neg_value = -torch.finfo(attention_matrix.dtype).max - context_mask = context_mask.unsqueeze(1).unsqueeze(1) - attention_matrix = attention_matrix.masked_fill(~(context_mask != 0), max_neg_value) - attention_matrix = (attention_matrix * attn.scale).softmax(dim=-1) - - out = einsum("b h i j, b h j d -> b h i d", attention_matrix, value) - out = out.permute(0, 2, 1, 3).reshape(out.shape[0], out.shape[2], -1) - out = attn.to_out[0](out) - return out - - LORA_ATTENTION_PROCESSORS = ( LoRAAttnProcessor, LoRAAttnProcessor2_0, @@ -2282,7 +2246,6 @@ def __call__( LoRAXFormersAttnProcessor, IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0, - Kandi3AttnProcessor, ) AttentionProcessor = Union[ diff --git a/src/diffusers/models/unet_kandi3.py b/src/diffusers/models/unet_kandi3.py index 42e25a942f7d..036d2b2195e1 100644 --- a/src/diffusers/models/unet_kandi3.py +++ b/src/diffusers/models/unet_kandi3.py @@ -3,13 +3,12 @@ from typing import Dict, Tuple, Union import torch -import torch.nn.functional as F import torch.utils.checkpoint from torch import nn from ..configuration_utils import ConfigMixin, register_to_config from ..utils import BaseOutput, logging -from .attention_processor import AttentionProcessor, Kandi3AttnProcessor +from .attention_processor import Attention, AttentionProcessor, AttnProcessor from .embeddings import TimestepEmbedding from .modeling_utils import ModelMixin @@ -30,7 +29,7 @@ def set_default_item(condition, item_1, item_2=None): return item_2 -# TODO(Yiyi): This class needs to be removed +# TODO(Yiyi): This class needs to be removed: either layer_1 or nn.identity def set_default_layer(condition, layer_1, args_1=[], kwargs_1={}, layer_2=torch.nn.Identity, args_2=[], kwargs_2={}): if condition: return layer_1(*args_1, **kwargs_1) @@ -223,7 +222,7 @@ def set_default_attn_processor(self): """ Disables custom attention processors and sets the default attention implementation. """ - self.set_attn_processor(Kandi3AttnProcessor()) + self.set_attn_processor(AttnProcessor()) def _set_gradient_checkpointing(self, module, value=False): if hasattr(module, "gradient_checkpointing"): @@ -412,6 +411,7 @@ def forward(self, x, time_embed, context=None, context_mask=None, image_mask=Non return x +# yiyi notes: should not have a seperate class here either class Kandinsky3ConditionalGroupNorm(nn.Module): def __init__(self, groups, normalized_shape, context_dim): super().__init__() @@ -431,49 +431,6 @@ def forward(self, x, context): return x -# TODO(Yiyi): This class should ideally not even exist, it slows everything needlessly down. 
I'm pretty -# sure we can delete it and instead just pass an attention_mask -class Attention(nn.Module): - def __init__(self, in_channels, out_channels, context_dim, head_dim=64): - super().__init__() - assert out_channels % head_dim == 0 - self.num_heads = out_channels // head_dim - self.scale = head_dim**-0.5 - - # to_q - self.to_q = nn.Linear(in_channels, out_channels, bias=False) - # to_k - self.to_k = nn.Linear(context_dim, out_channels, bias=False) - # to_v - self.to_v = nn.Linear(context_dim, out_channels, bias=False) - processor = Kandi3AttnProcessor() - self.set_processor(processor) - # to_out - self.to_out = nn.ModuleList([]) - self.to_out.append(nn.Linear(out_channels, out_channels, bias=False)) - - def set_processor(self, processor: "AttnProcessor"): # noqa: F821 - # if current processor is in `self._modules` and if passed `processor` is not, we need to - # pop `processor` from `self._modules` - if ( - hasattr(self, "processor") - and isinstance(self.processor, torch.nn.Module) - and not isinstance(processor, torch.nn.Module) - ): - logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}") - self._modules.pop("processor") - - self.processor = processor - - def forward(self, x, context, context_mask=None, image_mask=None): - return self.processor( - self, - x, - context=context, - context_mask=context_mask, - ) - - class Kandinsky3Block(nn.Module): def __init__(self, in_channels, out_channels, time_embed_dim, kernel_size=3, norm_groups=32, up_resolution=None): super().__init__() @@ -546,9 +503,10 @@ def forward(self, x, time_embed): class Kandinsky3AttentionPooling(nn.Module): def __init__(self, num_channels, context_dim, head_dim=64): super().__init__() - self.attention = Attention(context_dim, num_channels, context_dim, head_dim) + self.attention = Attention(context_dim, context_dim, dim_head=head_dim, out_dim=num_channels, out_bias=False) def forward(self, x, context, context_mask=None): + context_mask = context_mask.unsqueeze(1).to(dtype=context.dtype) context = self.attention(context.mean(dim=1, keepdim=True), context, context_mask) return x + context.squeeze(1) @@ -557,7 +515,9 @@ class Kandinsky3AttentionBlock(nn.Module): def __init__(self, num_channels, time_embed_dim, context_dim=None, norm_groups=32, head_dim=64, expansion_ratio=4): super().__init__() self.in_norm = Kandinsky3ConditionalGroupNorm(norm_groups, num_channels, time_embed_dim) - self.attention = Attention(num_channels, num_channels, context_dim or num_channels, head_dim) + self.attention = Attention( + num_channels, context_dim or num_channels, dim_head=head_dim, out_dim=num_channels, out_bias=False + ) hidden_channels = expansion_ratio * num_channels self.out_norm = Kandinsky3ConditionalGroupNorm(norm_groups, num_channels, time_embed_dim) @@ -572,14 +532,10 @@ def forward(self, x, time_embed, context=None, context_mask=None, image_mask=Non out = self.in_norm(x, time_embed) out = out.reshape(x.shape[0], -1, height * width).permute(0, 2, 1) context = context if context is not None else out + if context_mask is not None: + context_mask = context_mask.unsqueeze(1).to(dtype=context.dtype) - if image_mask is not None: - mask_height, mask_width = image_mask.shape[-2:] - kernel_size = (mask_height // height, mask_width // width) - image_mask = F.max_pool2d(image_mask, kernel_size, kernel_size) - image_mask = image_mask.reshape(image_mask.shape[0], -1) - - out = self.attention(out, context, context_mask, image_mask) + out = self.attention(out, context, context_mask) out = 
out.permute(0, 2, 1).unsqueeze(-1).reshape(out.shape[0], -1, height, width) x = x + out From eca5c18542d721619e404f6461b553feeb379adc Mon Sep 17 00:00:00 2001 From: yiyixu Date: Tue, 28 Nov 2023 01:10:26 +0000 Subject: [PATCH 02/20] remove set_default_layer and set_default_item --- src/diffusers/models/unet_kandi3.py | 118 ++++++++++++---------------- 1 file changed, 52 insertions(+), 66 deletions(-) diff --git a/src/diffusers/models/unet_kandi3.py b/src/diffusers/models/unet_kandi3.py index 036d2b2195e1..355ce47d624c 100644 --- a/src/diffusers/models/unet_kandi3.py +++ b/src/diffusers/models/unet_kandi3.py @@ -21,22 +21,6 @@ class Kandinsky3UNetOutput(BaseOutput): sample: torch.FloatTensor = None -# TODO(Yiyi): This class needs to be removed -def set_default_item(condition, item_1, item_2=None): - if condition: - return item_1 - else: - return item_2 - - -# TODO(Yiyi): This class needs to be removed: either layer_1 or nn.identity -def set_default_layer(condition, layer_1, args_1=[], kwargs_1={}, layer_2=torch.nn.Identity, args_2=[], kwargs_2={}): - if condition: - return layer_1(*args_1, **kwargs_1) - else: - return layer_2(*args_2, **kwargs_2) - - # TODO(Yiyi): This class should be removed and be replaced by Timesteps class SinusoidalPosEmb(nn.Module): def __init__(self, dim): @@ -105,7 +89,7 @@ def __init__( hidden_dims = [init_channels] + list(block_out_channels) in_out_dims = list(zip(hidden_dims[:-1], hidden_dims[1:])) - text_dims = [set_default_item(is_exist, cross_attention_dim) for is_exist in add_cross_attention] + text_dims = [cross_attention_dim if is_exist else None for is_exist in add_cross_attention] num_blocks = len(block_out_channels) * [layers_per_block] layer_params = [num_blocks, text_dims, add_self_attention] rev_layer_params = map(reversed, layer_params) @@ -117,7 +101,7 @@ def __init__( zip(in_out_dims, *layer_params) ): down_sample = level != (self.num_levels - 1) - cat_dims.append(set_default_item(level != (self.num_levels - 1), out_dim, 0)) + cat_dims.append(out_dim if level != (self.num_levels - 1) else 0) self.down_blocks.append( Kandinsky3DownSampleBlock( in_dim, @@ -289,7 +273,7 @@ def __init__( self_attention=True, ): super().__init__() - up_resolutions = [[None, set_default_item(up_sample, True), None, None]] + [[None] * 4] * (num_blocks - 1) + up_resolutions = [[None, True if up_sample else None, None, None]] + [[None] * 4] * (num_blocks - 1) hidden_channels = ( [(in_channels + cat_dim, in_channels)] + [(in_channels, in_channels)] * (num_blocks - 2) @@ -302,27 +286,27 @@ def __init__( self.self_attention = self_attention self.context_dim = context_dim - attentions.append( - set_default_layer( - self_attention, - Kandinsky3AttentionBlock, - (out_channels, time_embed_dim, None, groups, head_dim, expansion_ratio), - layer_2=nn.Identity, + if self_attention: + attentions.append( + Kandinsky3AttentionBlock(out_channels, time_embed_dim, None, groups, head_dim, expansion_ratio) ) - ) + else: + attentions.append(nn.Identity()) for (in_channel, out_channel), up_resolution in zip(hidden_channels, up_resolutions): resnets_in.append( Kandinsky3ResNetBlock(in_channel, in_channel, time_embed_dim, groups, compression_ratio, up_resolution) ) - attentions.append( - set_default_layer( - context_dim is not None, - Kandinsky3AttentionBlock, - (in_channel, time_embed_dim, context_dim, groups, head_dim, expansion_ratio), - layer_2=nn.Identity, + + if context_dim is not None: + attentions.append( + Kandinsky3AttentionBlock( + in_channel, time_embed_dim, context_dim, groups, 
head_dim, expansion_ratio + ) ) - ) + else: + attentions.append(nn.Identity()) + resnets_out.append( Kandinsky3ResNetBlock(in_channel, out_channel, time_embed_dim, groups, compression_ratio) ) @@ -366,29 +350,29 @@ def __init__( self.self_attention = self_attention self.context_dim = context_dim - attentions.append( - set_default_layer( - self_attention, - Kandinsky3AttentionBlock, - (in_channels, time_embed_dim, None, groups, head_dim, expansion_ratio), - layer_2=nn.Identity, + if self_attention: + attentions.append( + Kandinsky3AttentionBlock(in_channels, time_embed_dim, None, groups, head_dim, expansion_ratio) ) - ) + else: + attentions.append(nn.Identity()) - up_resolutions = [[None] * 4] * (num_blocks - 1) + [[None, None, set_default_item(down_sample, False), None]] + up_resolutions = [[None] * 4] * (num_blocks - 1) + [[None, None, False if down_sample else None, None]] hidden_channels = [(in_channels, out_channels)] + [(out_channels, out_channels)] * (num_blocks - 1) for (in_channel, out_channel), up_resolution in zip(hidden_channels, up_resolutions): resnets_in.append( Kandinsky3ResNetBlock(in_channel, out_channel, time_embed_dim, groups, compression_ratio) ) - attentions.append( - set_default_layer( - context_dim is not None, - Kandinsky3AttentionBlock, - (out_channel, time_embed_dim, context_dim, groups, head_dim, expansion_ratio), - layer_2=nn.Identity, + + if context_dim is not None: + attentions.append( + Kandinsky3AttentionBlock( + out_channel, time_embed_dim, context_dim, groups, head_dim, expansion_ratio + ) ) - ) + else: + attentions.append(nn.Identity()) + resnets_out.append( Kandinsky3ResNetBlock( out_channel, out_channel, time_embed_dim, groups, compression_ratio, up_resolution @@ -436,20 +420,18 @@ def __init__(self, in_channels, out_channels, time_embed_dim, kernel_size=3, nor super().__init__() self.group_norm = Kandinsky3ConditionalGroupNorm(norm_groups, in_channels, time_embed_dim) self.activation = nn.SiLU() - self.up_sample = set_default_layer( - up_resolution is not None and up_resolution, - nn.ConvTranspose2d, - (in_channels, in_channels), - {"kernel_size": 2, "stride": 2}, - ) + if up_resolution is not None and up_resolution: + self.up_sample = nn.ConvTranspose2d(in_channels, in_channels, kernel_size=2, stride=2) + else: + self.up_sample = nn.Identity() + padding = int(kernel_size > 1) self.projection = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=padding) - self.down_sample = set_default_layer( - up_resolution is not None and not up_resolution, - nn.Conv2d, - (out_channels, out_channels), - {"kernel_size": 2, "stride": 2}, - ) + + if up_resolution is not None and not up_resolution: + self.down_sample = nn.Conv2d(out_channels, out_channels, kernel_size=2, stride=2) + else: + self.down_sample = nn.Identity() def forward(self, x, time_embed): x = self.group_norm(x, time_embed) @@ -478,14 +460,18 @@ def __init__( ) ] ) - self.shortcut_up_sample = set_default_layer( - True in up_resolutions, nn.ConvTranspose2d, (in_channels, in_channels), {"kernel_size": 2, "stride": 2} + self.shortcut_up_sample = ( + nn.ConvTranspose2d(in_channels, in_channels, kernel_size=2, stride=2) + if True in up_resolutions + else nn.Identity() ) - self.shortcut_projection = set_default_layer( - in_channels != out_channels, nn.Conv2d, (in_channels, out_channels), {"kernel_size": 1} + self.shortcut_projection = ( + nn.Conv2d(in_channels, out_channels, kernel_size=1) if in_channels != out_channels else nn.Identity() ) - self.shortcut_down_sample = set_default_layer( - 
False in up_resolutions, nn.Conv2d, (out_channels, out_channels), {"kernel_size": 2, "stride": 2} + self.shortcut_down_sample = ( + nn.Conv2d(out_channels, out_channels, kernel_size=2, stride=2) + if False in up_resolutions + else nn.Identity() ) def forward(self, x, time_embed): From b8bb288001b14c08f7d90b4dccf614ef8c8e0245 Mon Sep 17 00:00:00 2001 From: yiyixu Date: Tue, 28 Nov 2023 01:55:23 +0000 Subject: [PATCH 03/20] remove SinusoidalPosEmb --- src/diffusers/models/unet_kandi3.py | 21 ++------------------- 1 file changed, 2 insertions(+), 19 deletions(-) diff --git a/src/diffusers/models/unet_kandi3.py b/src/diffusers/models/unet_kandi3.py index 355ce47d624c..257e9b2dd7a3 100644 --- a/src/diffusers/models/unet_kandi3.py +++ b/src/diffusers/models/unet_kandi3.py @@ -1,4 +1,3 @@ -import math from dataclasses import dataclass from typing import Dict, Tuple, Union @@ -9,7 +8,7 @@ from ..configuration_utils import ConfigMixin, register_to_config from ..utils import BaseOutput, logging from .attention_processor import Attention, AttentionProcessor, AttnProcessor -from .embeddings import TimestepEmbedding +from .embeddings import TimestepEmbedding, Timesteps from .modeling_utils import ModelMixin @@ -21,20 +20,6 @@ class Kandinsky3UNetOutput(BaseOutput): sample: torch.FloatTensor = None -# TODO(Yiyi): This class should be removed and be replaced by Timesteps -class SinusoidalPosEmb(nn.Module): - def __init__(self, dim): - super().__init__() - self.dim = dim - - def forward(self, x, type_tensor=None): - half_dim = self.dim // 2 - emb = math.log(10000) / (half_dim - 1) - emb = torch.exp(torch.arange(half_dim, device=x.device) * -emb) - emb = x[:, None] * emb[None, :] - return torch.cat((emb.sin(), emb.cos()), dim=-1) - - class Kandinsky3EncoderProj(nn.Module): def __init__(self, encoder_hid_dim, cross_attention_dim): super().__init__() @@ -70,9 +55,7 @@ def __init__( out_channels = in_channels init_channels = block_out_channels[0] // 2 - # TODO(Yiyi): Should be replaced with Timesteps class -> make sure that results are the same - # self.time_proj = Timesteps(init_channels, flip_sin_to_cos=False, downscale_freq_shift=1) - self.time_proj = SinusoidalPosEmb(init_channels) + self.time_proj = Timesteps(init_channels, flip_sin_to_cos=False, downscale_freq_shift=1) self.time_embedding = TimestepEmbedding( init_channels, From 84ce3d67d154db2a7454ae3265d83520f8e53f90 Mon Sep 17 00:00:00 2001 From: yiyixu Date: Tue, 28 Nov 2023 04:32:56 +0000 Subject: [PATCH 04/20] more --- src/diffusers/models/unet_kandi3.py | 35 ++++++++++++----------------- 1 file changed, 14 insertions(+), 21 deletions(-) diff --git a/src/diffusers/models/unet_kandi3.py b/src/diffusers/models/unet_kandi3.py index 257e9b2dd7a3..3eea77ccb841 100644 --- a/src/diffusers/models/unet_kandi3.py +++ b/src/diffusers/models/unet_kandi3.py @@ -196,12 +196,6 @@ def _set_gradient_checkpointing(self, module, value=False): module.gradient_checkpointing = value def forward(self, sample, timestep, encoder_hidden_states=None, encoder_attention_mask=None, return_dict=True): - # TODO(Yiyi): Clean up the following variables - these names should not be used - # but instead only the ones that we pass to forward - x = sample - context_mask = encoder_attention_mask - context = encoder_hidden_states - if not torch.is_tensor(timestep): dtype = torch.float32 if isinstance(timestep, float) else torch.int32 timestep = torch.tensor([timestep], dtype=dtype, device=sample.device) @@ -210,33 +204,33 @@ def forward(self, sample, timestep, encoder_hidden_states=None, 
encoder_attentio # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timestep = timestep.expand(sample.shape[0]) - time_embed_input = self.time_proj(timestep).to(x.dtype) + time_embed_input = self.time_proj(timestep).to(sample.dtype) time_embed = self.time_embedding(time_embed_input) - context = self.encoder_hid_proj(context) + encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states) - if context is not None: - time_embed = self.add_time_condition(time_embed, context, context_mask) + if encoder_hidden_states is not None: + time_embed = self.add_time_condition(time_embed, encoder_hidden_states, encoder_attention_mask) hidden_states = [] - x = self.conv_in(x) + sample = self.conv_in(sample) for level, down_sample in enumerate(self.down_blocks): - x = down_sample(x, time_embed, context, context_mask) + sample = down_sample(sample, time_embed, encoder_hidden_states, encoder_attention_mask) if level != self.num_levels - 1: - hidden_states.append(x) + hidden_states.append(sample) for level, up_sample in enumerate(self.up_blocks): if level != 0: - x = torch.cat([x, hidden_states.pop()], dim=1) - x = up_sample(x, time_embed, context, context_mask) + sample = torch.cat([sample, hidden_states.pop()], dim=1) + sample = up_sample(sample, time_embed, encoder_hidden_states, encoder_attention_mask) - x = self.conv_norm_out(x) - x = self.conv_act_out(x) - x = self.conv_out(x) + sample = self.conv_norm_out(sample) + sample = self.conv_act_out(sample) + sample = self.conv_out(sample) if not return_dict: - return (x,) - return Kandinsky3UNetOutput(sample=x) + return (sample,) + return Kandinsky3UNetOutput(sample=sample) class Kandinsky3UpSampleBlock(nn.Module): @@ -378,7 +372,6 @@ def forward(self, x, time_embed, context=None, context_mask=None, image_mask=Non return x -# yiyi notes: should not have a seperate class here either class Kandinsky3ConditionalGroupNorm(nn.Module): def __init__(self, groups, normalized_shape, context_dim): super().__init__() From 123dafcdea49c154a2dda7f7086d18bd981c6fb8 Mon Sep 17 00:00:00 2001 From: yiyixu Date: Wed, 29 Nov 2023 03:28:04 +0000 Subject: [PATCH 05/20] disable batch test --- tests/pipelines/kandinsky3/test_kandinsky3.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/pipelines/kandinsky3/test_kandinsky3.py b/tests/pipelines/kandinsky3/test_kandinsky3.py index 65297a36b157..82eed67f39d1 100644 --- a/tests/pipelines/kandinsky3/test_kandinsky3.py +++ b/tests/pipelines/kandinsky3/test_kandinsky3.py @@ -163,7 +163,7 @@ def test_float16_inference(self): super().test_float16_inference(expected_max_diff=1e-1) def test_inference_batch_single_identical(self): - super().test_inference_batch_single_identical(expected_max_diff=1e-2) + pass def test_model_cpu_offload_forward_pass(self): # TODO(Yiyi) - this test should work, skipped for time reasons for now From 145ddad6889f835e68368df88f01296811104e46 Mon Sep 17 00:00:00 2001 From: yiyixu Date: Wed, 29 Nov 2023 04:36:56 +0000 Subject: [PATCH 06/20] fix cpu model offload --- src/diffusers/pipelines/kandinsky3/kandinsky3_pipeline.py | 4 +++- tests/pipelines/kandinsky3/test_kandinsky3.py | 4 ---- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/src/diffusers/pipelines/kandinsky3/kandinsky3_pipeline.py b/src/diffusers/pipelines/kandinsky3/kandinsky3_pipeline.py index f116fb7894f0..a6a0e01084de 100644 --- a/src/diffusers/pipelines/kandinsky3/kandinsky3_pipeline.py +++ b/src/diffusers/pipelines/kandinsky3/kandinsky3_pipeline.py @@ -50,7 +50,7 @@ def 
remove_all_hooks(self): else: raise ImportError("Please install accelerate via `pip install accelerate`") - for model in [self.text_encoder, self.unet]: + for model in [self.text_encoder, self.unet, self.movq]: if model is not None: remove_hook_from_module(model, recurse=True) @@ -446,6 +446,8 @@ def __call__( if output_type == "pil": image = self.numpy_to_pil(image) + self.maybe_free_model_hooks() + if not return_dict: return (image,) diff --git a/tests/pipelines/kandinsky3/test_kandinsky3.py b/tests/pipelines/kandinsky3/test_kandinsky3.py index 82eed67f39d1..7b5691409383 100644 --- a/tests/pipelines/kandinsky3/test_kandinsky3.py +++ b/tests/pipelines/kandinsky3/test_kandinsky3.py @@ -165,10 +165,6 @@ def test_float16_inference(self): def test_inference_batch_single_identical(self): pass - def test_model_cpu_offload_forward_pass(self): - # TODO(Yiyi) - this test should work, skipped for time reasons for now - pass - @slow @require_torch_gpu From 5a2dd24f589357516bd5bb38e09f02749c0b93e3 Mon Sep 17 00:00:00 2001 From: yiyixu Date: Wed, 29 Nov 2023 04:40:43 +0000 Subject: [PATCH 07/20] take off last to-do --- src/diffusers/pipelines/kandinsky3/kandinsky3_pipeline.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/diffusers/pipelines/kandinsky3/kandinsky3_pipeline.py b/src/diffusers/pipelines/kandinsky3/kandinsky3_pipeline.py index a6a0e01084de..3494d3873ab4 100644 --- a/src/diffusers/pipelines/kandinsky3/kandinsky3_pipeline.py +++ b/src/diffusers/pipelines/kandinsky3/kandinsky3_pipeline.py @@ -397,8 +397,6 @@ def __call__( self.text_encoder_offload_hook.offload() # 7. Denoising loop - # TODO(Yiyi): Correct the following line and use correctly - # num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents From 89fdee4910e3d29e19bdaadf431b6255ff482815 Mon Sep 17 00:00:00 2001 From: yiyixu Date: Wed, 29 Nov 2023 04:44:02 +0000 Subject: [PATCH 08/20] another to-do --- .../pipelines/kandinsky3/kandinsky3img2img_pipeline.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/diffusers/pipelines/kandinsky3/kandinsky3img2img_pipeline.py b/src/diffusers/pipelines/kandinsky3/kandinsky3img2img_pipeline.py index b043110cf1d7..a4ea60291606 100644 --- a/src/diffusers/pipelines/kandinsky3/kandinsky3img2img_pipeline.py +++ b/src/diffusers/pipelines/kandinsky3/kandinsky3img2img_pipeline.py @@ -409,8 +409,6 @@ def __call__( self.text_encoder_offload_hook.offload() # 7. 
Denoising loop - # TODO(Yiyi): Correct the following line and use correctly - # num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents From e408cdf98816f503439306d61fcc2a98b5df41f9 Mon Sep 17 00:00:00 2001 From: yiyixu Date: Wed, 29 Nov 2023 09:25:00 +0000 Subject: [PATCH 09/20] refactor --- src/diffusers/models/attention_processor.py | 4 +- src/diffusers/models/unet_kandi3.py | 16 +- .../kandinsky3/kandinsky3img2img_pipeline.py | 2 + tests/pipelines/kandinsky3/test_kandinsky3.py | 2 +- .../kandinsky3/test_kandinsky3_img2img.py | 212 ++++++++++++++++++ tests/pipelines/test_pipelines_common.py | 4 + 6 files changed, 236 insertions(+), 4 deletions(-) create mode 100644 tests/pipelines/kandinsky3/test_kandinsky3_img2img.py diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 40a335527ace..12939f6eabf2 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -109,6 +109,7 @@ def __init__( residual_connection: bool = False, _from_deprecated_attn_block: bool = False, processor: Optional["AttnProcessor"] = None, + scale_mask_factor: float = 1.0, out_dim: int = None, ): super().__init__() @@ -120,6 +121,7 @@ def __init__( self.residual_connection = residual_connection self.dropout = dropout self.out_dim = out_dim if out_dim is not None else query_dim + self.scale_mask_factor = scale_mask_factor # we make use of this private variable to know whether this class is loaded # with an deprecated state dict so that we can convert it on the fly @@ -595,7 +597,7 @@ def get_attention_scores( beta = 0 else: baddbmm_input = attention_mask - beta = 1 + beta = self.scale_mask_factor attention_scores = torch.baddbmm( baddbmm_input, diff --git a/src/diffusers/models/unet_kandi3.py b/src/diffusers/models/unet_kandi3.py index 3eea77ccb841..5dbbd095d069 100644 --- a/src/diffusers/models/unet_kandi3.py +++ b/src/diffusers/models/unet_kandi3.py @@ -465,7 +465,14 @@ def forward(self, x, time_embed): class Kandinsky3AttentionPooling(nn.Module): def __init__(self, num_channels, context_dim, head_dim=64): super().__init__() - self.attention = Attention(context_dim, context_dim, dim_head=head_dim, out_dim=num_channels, out_bias=False) + self.attention = Attention( + context_dim, + context_dim, + dim_head=head_dim, + out_dim=num_channels, + out_bias=False, + scale_mask_factor=-60000.0, + ) def forward(self, x, context, context_mask=None): context_mask = context_mask.unsqueeze(1).to(dtype=context.dtype) @@ -478,7 +485,12 @@ def __init__(self, num_channels, time_embed_dim, context_dim=None, norm_groups=3 super().__init__() self.in_norm = Kandinsky3ConditionalGroupNorm(norm_groups, num_channels, time_embed_dim) self.attention = Attention( - num_channels, context_dim or num_channels, dim_head=head_dim, out_dim=num_channels, out_bias=False + num_channels, + context_dim or num_channels, + dim_head=head_dim, + out_dim=num_channels, + out_bias=False, + scale_mask_factor=-60000.0, ) hidden_channels = expansion_ratio * num_channels diff --git a/src/diffusers/pipelines/kandinsky3/kandinsky3img2img_pipeline.py b/src/diffusers/pipelines/kandinsky3/kandinsky3img2img_pipeline.py index a4ea60291606..d698c15d68fa 100644 --- a/src/diffusers/pipelines/kandinsky3/kandinsky3img2img_pipeline.py +++ 
b/src/diffusers/pipelines/kandinsky3/kandinsky3img2img_pipeline.py @@ -452,6 +452,8 @@ def __call__( if output_type == "pil": image = self.numpy_to_pil(image) + self.maybe_free_model_hooks() + if not return_dict: return (image,) diff --git a/tests/pipelines/kandinsky3/test_kandinsky3.py b/tests/pipelines/kandinsky3/test_kandinsky3.py index 7b5691409383..c163fe3102c4 100644 --- a/tests/pipelines/kandinsky3/test_kandinsky3.py +++ b/tests/pipelines/kandinsky3/test_kandinsky3.py @@ -163,7 +163,7 @@ def test_float16_inference(self): super().test_float16_inference(expected_max_diff=1e-1) def test_inference_batch_single_identical(self): - pass + super().test_inference_batch_single_identical(expected_max_diff=1e-2) @slow diff --git a/tests/pipelines/kandinsky3/test_kandinsky3_img2img.py b/tests/pipelines/kandinsky3/test_kandinsky3_img2img.py new file mode 100644 index 000000000000..a3c2697d0fbc --- /dev/null +++ b/tests/pipelines/kandinsky3/test_kandinsky3_img2img.py @@ -0,0 +1,212 @@ +# coding=utf-8 +# Copyright 2023 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import random +import unittest + +import numpy as np +import torch +from PIL import Image +from transformers import AutoTokenizer, T5EncoderModel + +from diffusers import ( + AutoPipelineForImage2Image, + Kandinsky3Img2ImgPipeline, + Kandinsky3UNet, + VQModel, +) +from diffusers.image_processor import VaeImageProcessor +from diffusers.schedulers.scheduling_ddpm import DDPMScheduler +from diffusers.utils.testing_utils import ( + enable_full_determinism, + floats_tensor, + load_image, + require_torch_gpu, + slow, +) + +from ..pipeline_params import ( + IMAGE_TO_IMAGE_IMAGE_PARAMS, + TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS, + TEXT_GUIDED_IMAGE_VARIATION_PARAMS, +) +from ..test_pipelines_common import PipelineTesterMixin + + +enable_full_determinism() + + +class Kandinsky3Img2ImgPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pipeline_class = Kandinsky3Img2ImgPipeline + params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS - {"height", "width"} + batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS + image_params = IMAGE_TO_IMAGE_IMAGE_PARAMS + test_xformers_attention = False + + @property + def dummy_movq_kwargs(self): + return { + "block_out_channels": [32, 64], + "down_block_types": ["DownEncoderBlock2D", "AttnDownEncoderBlock2D"], + "in_channels": 3, + "latent_channels": 4, + "layers_per_block": 1, + "norm_num_groups": 8, + "norm_type": "spatial", + "num_vq_embeddings": 12, + "out_channels": 3, + "up_block_types": [ + "AttnUpDecoderBlock2D", + "UpDecoderBlock2D", + ], + "vq_embed_dim": 4, + } + + @property + def dummy_movq(self): + torch.manual_seed(0) + model = VQModel(**self.dummy_movq_kwargs) + return model + + def get_dummy_components(self, time_cond_proj_dim=None): + torch.manual_seed(0) + unet = Kandinsky3UNet( + in_channels=4, + time_embedding_dim=4, + groups=2, + attention_head_dim=4, + layers_per_block=3, + block_out_channels=(32, 64), + cross_attention_dim=4, + encoder_hid_dim=32, + ) + 
scheduler = DDPMScheduler( + beta_start=0.00085, + beta_end=0.012, + steps_offset=1, + beta_schedule="squaredcos_cap_v2", + clip_sample=True, + thresholding=False, + ) + torch.manual_seed(0) + movq = self.dummy_movq + torch.manual_seed(0) + text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5") + + torch.manual_seed(0) + tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5") + + components = { + "unet": unet, + "scheduler": scheduler, + "movq": movq, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + } + return components + + def get_dummy_inputs(self, device, seed=0): + # create init_image + image = floats_tensor((1, 3, 64, 64), rng=random.Random(seed)).to(device) + image = image.cpu().permute(0, 2, 3, 1)[0] + init_image = Image.fromarray(np.uint8(image)).convert("RGB") + + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + inputs = { + "prompt": "A painting of a squirrel eating a burger", + "image": init_image, + "generator": generator, + "strength": 0.75, + "num_inference_steps": 10, + "guidance_scale": 6.0, + "output_type": "np", + } + return inputs + + def test_kandinsky3_img2img(self): + device = "cpu" + + components = self.get_dummy_components() + + pipe = self.pipeline_class(**components) + pipe = pipe.to(device) + + pipe.set_progress_bar_config(disable=None) + + output = pipe(**self.get_dummy_inputs(device)) + image = output.images + + image_slice = image[0, -3:, -3:, -1] + + assert image.shape == (1, 64, 64, 3) + + expected_slice = np.array( + [0.576259, 0.6132097, 0.41703486, 0.603196, 0.62062526, 0.4655338, 0.5434324, 0.5660727, 0.65433365] + ) + + assert ( + np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + ), f" expected_slice {expected_slice}, but got {image_slice.flatten()}" + + def test_float16_inference(self): + super().test_float16_inference(expected_max_diff=1e-1) + + def test_inference_batch_single_identical(self): + super().test_inference_batch_single_identical(expected_max_diff=1e-2) + + +@slow +@require_torch_gpu +class Kandinsky3Img2ImgPipelineIntegrationTests(unittest.TestCase): + def tearDown(self): + # clean up the VRAM after each test + super().tearDown() + gc.collect() + torch.cuda.empty_cache() + + def test_kandinskyV3_img2img(self): + pipe = AutoPipelineForImage2Image.from_pretrained( + "kandinsky-community/kandinsky-3", variant="fp16", torch_dtype=torch.float16 + ) + pipe.enable_model_cpu_offload() + pipe.set_progress_bar_config(disable=None) + + generator = torch.Generator(device="cpu").manual_seed(0) + + image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/kandinsky3/t2i.png" + ) + w, h = 512, 512 + image = image.resize((w, h), resample=Image.BICUBIC, reducing_gap=1) + prompt = "A painting of the inside of a subway train with tiny raccoons." 
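        # with strength=0.75 the init image is noised for 75% of the schedule before denoising,
        # so only its coarse layout is preserved in the output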
+ + image = pipe(prompt, image=image, strength=0.75, num_inference_steps=25, generator=generator).images[0] + + assert image.size == (512, 512) + + expected_image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/kandinsky3/i2i.png" + ) + + image_processor = VaeImageProcessor() + + image_np = image_processor.pil_to_numpy(image) + expected_image_np = image_processor.pil_to_numpy(expected_image) + + self.assertTrue(np.allclose(image_np, expected_image_np, atol=5e-2)) diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py index e11175921184..cac5ee442ae6 100644 --- a/tests/pipelines/test_pipelines_common.py +++ b/tests/pipelines/test_pipelines_common.py @@ -377,6 +377,10 @@ def test_save_load_local(self, expected_max_difference=5e-4): with CaptureLogger(logger) as cap_logger: pipe_loaded = self.pipeline_class.from_pretrained(tmpdir) + for component in pipe_loaded.components.values(): + if hasattr(component, "set_default_attn_processor"): + component.set_default_attn_processor() + for name in pipe_loaded.components.keys(): if name not in pipe_loaded._optional_components: assert name in str(cap_logger) From 7dced94116d007bf49ace508811275783feed3c9 Mon Sep 17 00:00:00 2001 From: yiyixu Date: Wed, 29 Nov 2023 10:08:06 +0000 Subject: [PATCH 10/20] add callback and latent output for text2img --- .../kandinsky3/kandinsky3_pipeline.py | 133 ++++++++++++++---- 1 file changed, 103 insertions(+), 30 deletions(-) diff --git a/src/diffusers/pipelines/kandinsky3/kandinsky3_pipeline.py b/src/diffusers/pipelines/kandinsky3/kandinsky3_pipeline.py index 3494d3873ab4..8af4716da1c8 100644 --- a/src/diffusers/pipelines/kandinsky3/kandinsky3_pipeline.py +++ b/src/diffusers/pipelines/kandinsky3/kandinsky3_pipeline.py @@ -1,4 +1,4 @@ -from typing import Callable, List, Optional, Union +from typing import Callable, Dict, List, Optional, Union import torch from transformers import T5EncoderModel, T5Tokenizer @@ -7,6 +7,7 @@ from ...models import Kandinsky3UNet, VQModel from ...schedulers import DDPMScheduler from ...utils import ( + deprecate, is_accelerate_available, logging, ) @@ -29,6 +30,13 @@ def downscale_height_and_width(height, width, scale_factor=8): class Kandinsky3Pipeline(DiffusionPipeline, LoraLoaderMixin): model_cpu_offload_seq = "text_encoder->unet->movq" + _callback_tensor_inputs = [ + "latents", + "prompt_embeds", + "negative_prompt_embeds", + "negative_attention_mask", + "attention_mask", + ] def __init__( self, @@ -228,14 +236,19 @@ def check_inputs( negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, ): - if (callback_steps is None) or ( - callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) - ): + if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): raise ValueError( f"`callback_steps` has to be a positive integer but is {callback_steps} of type" f" {type(callback_steps)}." 
) + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) if prompt is not None and prompt_embeds is not None: raise ValueError( @@ -263,6 +276,18 @@ def check_inputs( f" {negative_prompt_embeds.shape}." ) + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def num_timesteps(self): + return self._num_timesteps + @torch.no_grad() def __call__( self, @@ -278,9 +303,10 @@ def __call__( negative_prompt_embeds: Optional[torch.FloatTensor] = None, output_type: Optional[str] = "pil", return_dict: bool = True, - callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, - callback_steps: int = 1, latents=None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + **kwargs, ): """ Function invoked when calling the pipeline for generation. @@ -344,11 +370,44 @@ def __call__( `self.processor` in [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). """ + + callback = kwargs.pop("callback", None) + callback_steps = kwargs.pop("callback_steps", None) + + if callback is not None: + deprecate( + "callback", + "1.0.0", + "Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`", + ) + if callback_steps is not None: + deprecate( + "callback_steps", + "1.0.0", + "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`", + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + cut_context = True device = self._execution_device # 1. Check inputs. Raise error if not correct - self.check_inputs(prompt, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds) + self.check_inputs( + prompt, + callback_steps, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + callback_on_step_end_tensor_inputs, + ) + + self._guidance_scale = guidance_scale if prompt is not None and isinstance(prompt, str): batch_size = 1 @@ -357,15 +416,10 @@ def __call__( else: batch_size = prompt_embeds.shape[0] - # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) - # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` - # corresponds to doing no classifier free guidance. - do_classifier_free_guidance = guidance_scale > 1.0 - # 3. 
Encode input prompt prompt_embeds, negative_prompt_embeds, attention_mask, negative_attention_mask = self.encode_prompt( prompt, - do_classifier_free_guidance, + self.do_classifier_free_guidance, num_images_per_prompt=num_images_per_prompt, device=device, negative_prompt=negative_prompt, @@ -374,7 +428,7 @@ def __call__( _cut_context=cut_context, ) - if do_classifier_free_guidance: + if self.do_classifier_free_guidance: prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) attention_mask = torch.cat([negative_attention_mask, attention_mask]).bool() # 4. Prepare timesteps @@ -397,9 +451,11 @@ def __call__( self.text_encoder_offload_hook.offload() # 7. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): - latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents # predict the noise residual noise_pred = self.unet( @@ -410,7 +466,7 @@ def __call__( return_dict=False, )[0] - if do_classifier_free_guidance: + if self.do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = (guidance_scale + 1.0) * noise_pred_text - guidance_scale * noise_pred_uncond @@ -423,26 +479,43 @@ def __call__( latents, generator=generator, ).prev_sample - progress_bar.update() - if callback is not None and i % callback_steps == 0: - step_idx = i // getattr(self.scheduler, "order", 1) - callback(step_idx, t, latents) - # post-processing - image = self.movq.decode(latents, force_not_quantize=True)["sample"] + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + attention_mask = callback_outputs.pop("attention_mask", attention_mask) + negative_attention_mask = callback_outputs.pop("negative_attention_mask", negative_attention_mask) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) - if output_type not in ["pt", "np", "pil"]: + # post-processing + if output_type not in ["pt", "np", "pil", "latent"]: raise ValueError( - f"Only the output types `pt`, `pil` and `np` are supported not output_type={output_type}" + f"Only the output types `pt`, `pil`, `np` and `latent` are supported not output_type={output_type}" ) - if output_type in ["np", "pil"]: - image = image * 0.5 + 0.5 - image = image.clamp(0, 1) - image = image.cpu().permute(0, 2, 3, 1).float().numpy() + if not output_type == "latent": + image = self.movq.decode(latents, force_not_quantize=True)["sample"] + + if output_type in ["np", "pil"]: + image = image * 0.5 + 0.5 + image = image.clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).float().numpy() - if output_type == "pil": - image = self.numpy_to_pil(image) + if output_type == "pil": + image = self.numpy_to_pil(image) + else: + image = latents 
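        # Offload all models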
self.maybe_free_model_hooks() From 51fe17bcd8d76b2b16403f4bf15efcfa848eb494 Mon Sep 17 00:00:00 2001 From: yiyixu Date: Wed, 29 Nov 2023 10:29:14 +0000 Subject: [PATCH 11/20] refactor img2img --- .../kandinsky3/kandinsky3img2img_pipeline.py | 133 ++++++++++++++---- .../kandinsky3/test_kandinsky3_img2img.py | 4 + 2 files changed, 107 insertions(+), 30 deletions(-) diff --git a/src/diffusers/pipelines/kandinsky3/kandinsky3img2img_pipeline.py b/src/diffusers/pipelines/kandinsky3/kandinsky3img2img_pipeline.py index d698c15d68fa..5a27baeefb19 100644 --- a/src/diffusers/pipelines/kandinsky3/kandinsky3img2img_pipeline.py +++ b/src/diffusers/pipelines/kandinsky3/kandinsky3img2img_pipeline.py @@ -1,5 +1,5 @@ import inspect -from typing import Callable, List, Optional, Union +from typing import Callable, Dict, List, Optional, Union import numpy as np import PIL @@ -11,6 +11,7 @@ from ...models import Kandinsky3UNet, VQModel from ...schedulers import DDPMScheduler from ...utils import ( + deprecate, is_accelerate_available, logging, ) @@ -41,6 +42,13 @@ def prepare_image(pil_image): class Kandinsky3Img2ImgPipeline(DiffusionPipeline, LoraLoaderMixin): model_cpu_offload_seq = "text_encoder->unet->movq" + _callback_tensor_inputs = [ + "latents", + "prompt_embeds", + "negative_prompt_embeds", + "negative_attention_mask", + "attention_mask", + ] def __init__( self, @@ -299,15 +307,21 @@ def check_inputs( negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, ): - if (callback_steps is None) or ( - callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) - ): + if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): raise ValueError( f"`callback_steps` has to be a positive integer but is {callback_steps} of type" f" {type(callback_steps)}." ) + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + if prompt is not None and prompt_embeds is not None: raise ValueError( f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" @@ -334,6 +348,18 @@ def check_inputs( f" {negative_prompt_embeds.shape}." 
) + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def num_timesteps(self): + return self._num_timesteps + @torch.no_grad() def __call__( self, @@ -349,13 +375,46 @@ def __call__( negative_prompt_embeds: Optional[torch.FloatTensor] = None, output_type: Optional[str] = "pil", return_dict: bool = True, - callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, - callback_steps: int = 1, latents=None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + **kwargs, ): + callback = kwargs.pop("callback", None) + callback_steps = kwargs.pop("callback_steps", None) + + if callback is not None: + deprecate( + "callback", + "1.0.0", + "Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`", + ) + if callback_steps is not None: + deprecate( + "callback_steps", + "1.0.0", + "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`", + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + cut_context = True # 1. Check inputs. Raise error if not correct - self.check_inputs(prompt, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds) + self.check_inputs( + prompt, + callback_steps, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + callback_on_step_end_tensor_inputs, + ) + + self._guidance_scale = guidance_scale if prompt is not None and isinstance(prompt, str): batch_size = 1 @@ -366,15 +425,10 @@ def __call__( device = self._execution_device - # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) - # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` - # corresponds to doing no classifier free guidance. - do_classifier_free_guidance = guidance_scale > 1.0 - # 3. Encode input prompt prompt_embeds, negative_prompt_embeds, attention_mask, negative_attention_mask = self.encode_prompt( prompt, - do_classifier_free_guidance, + self.do_classifier_free_guidance, num_images_per_prompt=num_images_per_prompt, device=device, negative_prompt=negative_prompt, @@ -383,7 +437,7 @@ def __call__( _cut_context=cut_context, ) - if do_classifier_free_guidance: + if self.do_classifier_free_guidance: prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) attention_mask = torch.cat([negative_attention_mask, attention_mask]).bool() if not isinstance(image, list): @@ -409,9 +463,11 @@ def __call__( self.text_encoder_offload_hook.offload() # 7. 
Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): - latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents # predict the noise residual noise_pred = self.unet( @@ -420,7 +476,7 @@ def __call__( encoder_hidden_states=prompt_embeds, encoder_attention_mask=attention_mask, )[0] - if do_classifier_free_guidance: + if self.do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = (guidance_scale + 1.0) * noise_pred_text - guidance_scale * noise_pred_uncond @@ -432,25 +488,42 @@ def __call__( latents, generator=generator, ).prev_sample - progress_bar.update() - if callback is not None and i % callback_steps == 0: - step_idx = i // getattr(self.scheduler, "order", 1) - callback(step_idx, t, latents) - # post-processing - image = self.movq.decode(latents, force_not_quantize=True)["sample"] - if output_type not in ["pt", "np", "pil"]: + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + attention_mask = callback_outputs.pop("attention_mask", attention_mask) + negative_attention_mask = callback_outputs.pop("negative_attention_mask", negative_attention_mask) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + # post-processing + if output_type not in ["pt", "np", "pil", "latent"]: raise ValueError( - f"Only the output types `pt`, `pil` and `np` are supported not output_type={output_type}" + f"Only the output types `pt`, `pil`, `np` and `latent` are supported not output_type={output_type}" ) + if not output_type == "latent": + image = self.movq.decode(latents, force_not_quantize=True)["sample"] - if output_type in ["np", "pil"]: - image = image * 0.5 + 0.5 - image = image.clamp(0, 1) - image = image.cpu().permute(0, 2, 3, 1).float().numpy() + if output_type in ["np", "pil"]: + image = image * 0.5 + 0.5 + image = image.clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).float().numpy() - if output_type == "pil": - image = self.numpy_to_pil(image) + if output_type == "pil": + image = self.numpy_to_pil(image) + else: + image = latents self.maybe_free_model_hooks() diff --git a/tests/pipelines/kandinsky3/test_kandinsky3_img2img.py b/tests/pipelines/kandinsky3/test_kandinsky3_img2img.py index a3c2697d0fbc..fcd1ed0dc94d 100644 --- a/tests/pipelines/kandinsky3/test_kandinsky3_img2img.py +++ b/tests/pipelines/kandinsky3/test_kandinsky3_img2img.py @@ -42,6 +42,8 @@ IMAGE_TO_IMAGE_IMAGE_PARAMS, TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS, TEXT_GUIDED_IMAGE_VARIATION_PARAMS, + TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS, + TEXT_TO_IMAGE_IMAGE_PARAMS, ) from ..test_pipelines_common import PipelineTesterMixin @@ -54,6 +56,8 @@ class 
Kandinsky3Img2ImgPipelineFastTests(PipelineTesterMixin, unittest.TestCase) params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS - {"height", "width"} batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS image_params = IMAGE_TO_IMAGE_IMAGE_PARAMS + image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS + callback_cfg_params = TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS test_xformers_attention = False @property From f5cfa5ad9223e2b7e633f951289a26c5893f3fc8 Mon Sep 17 00:00:00 2001 From: yiyixu Date: Wed, 29 Nov 2023 10:42:35 +0000 Subject: [PATCH 12/20] change unet file name --- src/diffusers/models/__init__.py | 4 ++-- ...{unet_kandi3.py => unet_2d_condition_kandi3.py} | 14 ++++++++++++++ 2 files changed, 16 insertions(+), 2 deletions(-) rename src/diffusers/models/{unet_kandi3.py => unet_2d_condition_kandi3.py} (97%) diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index de2e2848b848..13f05f78b769 100644 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -36,7 +36,7 @@ _import_structure["unet_2d"] = ["UNet2DModel"] _import_structure["unet_2d_condition"] = ["UNet2DConditionModel"] _import_structure["unet_3d_condition"] = ["UNet3DConditionModel"] - _import_structure["unet_kandi3"] = ["Kandinsky3UNet"] + _import_structure["unet_2d_condition_kandi3"] = ["Kandinsky3UNet"] _import_structure["unet_motion_model"] = ["MotionAdapter", "UNetMotionModel"] _import_structure["vq_model"] = ["VQModel"] @@ -64,7 +64,7 @@ from .unet_2d import UNet2DModel from .unet_2d_condition import UNet2DConditionModel from .unet_3d_condition import UNet3DConditionModel - from .unet_kandi3 import Kandinsky3UNet + from .unet_2d_condition_kandi3 import Kandinsky3UNet from .unet_motion_model import MotionAdapter, UNetMotionModel from .vq_model import VQModel diff --git a/src/diffusers/models/unet_kandi3.py b/src/diffusers/models/unet_2d_condition_kandi3.py similarity index 97% rename from src/diffusers/models/unet_kandi3.py rename to src/diffusers/models/unet_2d_condition_kandi3.py index 5dbbd095d069..1ce831393c37 100644 --- a/src/diffusers/models/unet_kandi3.py +++ b/src/diffusers/models/unet_2d_condition_kandi3.py @@ -1,3 +1,17 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
 from dataclasses import dataclass
 from typing import Dict, Tuple, Union

From 334cd2eb64c97d57f02b9df8de0147c9fc22a7cd Mon Sep 17 00:00:00 2001
From: yiyixu
Date: Wed, 29 Nov 2023 11:09:21 +0000
Subject: [PATCH 13/20] add doc string

---
 src/diffusers/models/__init__.py                   |  4 +-
 .../kandinsky3/kandinsky3_pipeline.py              | 25 ++++++
 .../kandinsky3/kandinsky3img2img_pipeline.py       | 83 ++++++++++++++++++-
 3 files changed, 109 insertions(+), 3 deletions(-)

diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py
index 13f05f78b769..6346523a7004 100644
--- a/src/diffusers/models/__init__.py
+++ b/src/diffusers/models/__init__.py
@@ -35,8 +35,8 @@
     _import_structure["unet_1d"] = ["UNet1DModel"]
     _import_structure["unet_2d"] = ["UNet2DModel"]
     _import_structure["unet_2d_condition"] = ["UNet2DConditionModel"]
-    _import_structure["unet_3d_condition"] = ["UNet3DConditionModel"]
+    _import_structure["unet_2d_condition_kandi3"] = ["Kandinsky3UNet"]
+    _import_structure["unet_3d_condition"] = ["UNet3DConditionModel"]
     _import_structure["unet_motion_model"] = ["MotionAdapter", "UNetMotionModel"]
     _import_structure["vq_model"] = ["VQModel"]

@@ -63,8 +63,8 @@
     from .unet_1d import UNet1DModel
     from .unet_2d import UNet2DModel
     from .unet_2d_condition import UNet2DConditionModel
-    from .unet_3d_condition import UNet3DConditionModel
+    from .unet_2d_condition_kandi3 import Kandinsky3UNet
+    from .unet_3d_condition import UNet3DConditionModel
     from .unet_motion_model import MotionAdapter, UNetMotionModel
     from .vq_model import VQModel

diff --git a/src/diffusers/pipelines/kandinsky3/kandinsky3_pipeline.py b/src/diffusers/pipelines/kandinsky3/kandinsky3_pipeline.py
index 8af4716da1c8..dab0d15918b7 100644
--- a/src/diffusers/pipelines/kandinsky3/kandinsky3_pipeline.py
+++ b/src/diffusers/pipelines/kandinsky3/kandinsky3_pipeline.py
@@ -10,6 +10,7 @@
     deprecate,
     is_accelerate_available,
     logging,
+    replace_example_docstring,
 )
 from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
@@ -17,6 +18,23 @@

 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

+EXAMPLE_DOC_STRING = """
+    Examples:
+        ```py
+        >>> from diffusers import AutoPipelineForText2Image
+        >>> import torch

+        >>> pipe = AutoPipelineForText2Image.from_pretrained("kandinsky-community/kandinsky-3", variant="fp16", torch_dtype=torch.float16)
+        >>> pipe.enable_model_cpu_offload()

+        >>> prompt = "A photograph of the inside of a subway train. There are raccoons sitting on the seats. One of them is reading a newspaper. The window shows the city in the background."

+        >>> generator = torch.Generator(device="cpu").manual_seed(0)
+        >>> image = pipe(prompt, num_inference_steps=25, generator=generator).images[0]
+        ```
+"""
+

 def downscale_height_and_width(height, width, scale_factor=8):
     new_height = height // scale_factor**2
@@ -289,6 +307,7 @@ def num_timesteps(self):
         return self._num_timesteps

     @torch.no_grad()
+    @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(
         self,
         prompt: Union[str, List[str]] = None,
@@ -369,6 +388,12 @@ def __call__(
             A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
             `self.processor` in
             [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+
+        Examples:
+
+        Returns:
+            [`~pipelines.ImagePipelineOutput`] or `tuple`
+
         """
         callback = kwargs.pop("callback", None)

diff --git a/src/diffusers/pipelines/kandinsky3/kandinsky3img2img_pipeline.py b/src/diffusers/pipelines/kandinsky3/kandinsky3img2img_pipeline.py
index 5a27baeefb19..b830179196a5 100644
--- a/src/diffusers/pipelines/kandinsky3/kandinsky3img2img_pipeline.py
+++ b/src/diffusers/pipelines/kandinsky3/kandinsky3img2img_pipeline.py
@@ -14,6 +14,7 @@
     deprecate,
     is_accelerate_available,
     logging,
+    replace_example_docstring,
 )
 from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
@@ -21,6 +22,24 @@

 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

+EXAMPLE_DOC_STRING = """
+    Examples:
+        ```py
+        >>> from diffusers import AutoPipelineForImage2Image
+        >>> from diffusers.utils import load_image
+        >>> import torch

+        >>> pipe = AutoPipelineForImage2Image.from_pretrained("kandinsky-community/kandinsky-3", variant="fp16", torch_dtype=torch.float16)
+        >>> pipe.enable_model_cpu_offload()

+        >>> prompt = "A painting of the inside of a subway train with tiny raccoons."
+        >>> image = load_image("https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/kandinsky3/t2i.png")

+        >>> generator = torch.Generator(device="cpu").manual_seed(0)
+        >>> image = pipe(prompt, image=image, strength=0.75, num_inference_steps=25, generator=generator).images[0]
+        ```
+"""
+

 def downscale_height_and_width(height, width, scale_factor=8):
     new_height = height // scale_factor**2
@@ -361,6 +380,7 @@ def num_timesteps(self):
         return self._num_timesteps

     @torch.no_grad()
+    @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(
         self,
         prompt: Union[str, List[str]] = None,
@@ -375,11 +395,72 @@ def __call__(
         negative_prompt_embeds: Optional[torch.FloatTensor] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
-        latents=None,
         callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
         callback_on_step_end_tensor_inputs: List[str] = ["latents"],
         **kwargs,
     ):
+        """
+        Function invoked when calling the pipeline for generation.
+
+        Args:
+            prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
+                instead.
+            image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
+                `Image`, or tensor representing an image batch, that will be used as the starting point for the
+                process.
+            strength (`float`, *optional*, defaults to 0.8):
+                Indicates the extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a
+                starting point and more noise is added the higher the `strength`. The number of denoising steps depends
+                on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising
+                process runs for the full number of iterations specified in `num_inference_steps`. A value of 1
+                essentially ignores `image`.
+            num_inference_steps (`int`, *optional*, defaults to 50):
+                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+                expense of slower inference.
+            guidance_scale (`float`, *optional*, defaults to 3.0):
+                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+                `guidance_scale` is defined as `w` of equation 2 of [Imagen
+                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+                1`. A higher guidance scale encourages the model to generate images closely linked to the text
+                `prompt`, usually at the expense of lower image quality.
+            negative_prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts not to guide the image generation. If not defined, one has to pass
+                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+                less than `1`).
+            num_images_per_prompt (`int`, *optional*, defaults to 1):
+                The number of images to generate per prompt.
+            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+                to make generation deterministic.
+            prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+                provided, text embeddings will be generated from the `prompt` input argument.
+            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+                weighting. If not provided, negative_prompt_embeds will be generated from the `negative_prompt` input
+                argument.
+            output_type (`str`, *optional*, defaults to `"pil"`):
+                The output format of the generated image. Choose between
+                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
+            callback_on_step_end (`Callable`, *optional*):
+                A function that is called at the end of each denoising step during inference. The function is called
+                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+                callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+                `callback_on_step_end_tensor_inputs`.
+            callback_on_step_end_tensor_inputs (`List`, *optional*):
+                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+                will be passed as the `callback_kwargs` argument. You will only be able to include variables listed in
+                the `._callback_tensor_inputs` attribute of your pipeline class.
+
+        Examples:
+
+        Returns:
+            [`~pipelines.ImagePipelineOutput`] or `tuple`
+
+        """
         callback = kwargs.pop("callback", None)
         callback_steps = kwargs.pop("callback_steps", None)

From dbf5135de4c7244bbd4baace1e33ba2d64f456df Mon Sep 17 00:00:00 2001
From: yiyixu
Date: Wed, 29 Nov 2023 11:15:16 +0000
Subject: [PATCH 14/20] change pipeline file name

---
 src/diffusers/pipelines/kandinsky3/__init__.py             | 8 ++++----
 .../{kandinsky3_pipeline.py => pipeline_kandinsky3.py}     | 0
 ...img2img_pipeline.py => pipeline_kandinsky3_img2img.py}  | 0
 3 files changed, 4 insertions(+), 4 deletions(-)
 rename src/diffusers/pipelines/kandinsky3/{kandinsky3_pipeline.py => pipeline_kandinsky3.py} (100%)
 rename src/diffusers/pipelines/kandinsky3/{kandinsky3img2img_pipeline.py => pipeline_kandinsky3_img2img.py} (100%)

diff --git a/src/diffusers/pipelines/kandinsky3/__init__.py b/src/diffusers/pipelines/kandinsky3/__init__.py
index 4da3a83c0448..e8a3063141b5 100644
--- a/src/diffusers/pipelines/kandinsky3/__init__.py
+++ b/src/diffusers/pipelines/kandinsky3/__init__.py
@@ -21,8 +21,8 @@
     _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
 else:
-    _import_structure["kandinsky3_pipeline"] = ["Kandinsky3Pipeline"]
-    _import_structure["kandinsky3img2img_pipeline"] = ["Kandinsky3Img2ImgPipeline"]
+    _import_structure["pipeline_kandinsky3"] = ["Kandinsky3Pipeline"]
+    _import_structure["pipeline_kandinsky3_img2img"] = ["Kandinsky3Img2ImgPipeline"]

 if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
@@ -33,8 +33,8 @@
     except OptionalDependencyNotAvailable:
         from ...utils.dummy_torch_and_transformers_objects import *
     else:
-        from .kandinsky3_pipeline import Kandinsky3Pipeline
-        from .kandinsky3img2img_pipeline import Kandinsky3Img2ImgPipeline
+        from .pipeline_kandinsky3 import Kandinsky3Pipeline
+        from .pipeline_kandinsky3_img2img import Kandinsky3Img2ImgPipeline

 else:
     import sys

diff --git a/src/diffusers/pipelines/kandinsky3/kandinsky3_pipeline.py b/src/diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py
similarity index 100%
rename from src/diffusers/pipelines/kandinsky3/kandinsky3_pipeline.py
rename to src/diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py
diff --git a/src/diffusers/pipelines/kandinsky3/kandinsky3img2img_pipeline.py b/src/diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py
similarity index 100%
rename from src/diffusers/pipelines/kandinsky3/kandinsky3img2img_pipeline.py
rename to src/diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py

From 25c4e07b42ad77b4ca0edfcad15745438bf29e1e Mon Sep 17 00:00:00 2001
From: yiyixu
Date: Wed, 29 Nov 2023 12:04:14 +0000
Subject: [PATCH 15/20] fix failing test

---
 tests/pipelines/kandinsky3/test_kandinsky3_img2img.py | 9 +++++++++
 1 file changed, 9 insertions(+)

diff --git a/tests/pipelines/kandinsky3/test_kandinsky3_img2img.py b/tests/pipelines/kandinsky3/test_kandinsky3_img2img.py
index fcd1ed0dc94d..581251a81639 100644
--- a/tests/pipelines/kandinsky3/test_kandinsky3_img2img.py
+++ b/tests/pipelines/kandinsky3/test_kandinsky3_img2img.py
@@ -59,6 +59,15 @@ class Kandinsky3Img2ImgPipelineFastTests(PipelineTesterMixin, unittest.TestCase)
     image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
     callback_cfg_params = TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS
     test_xformers_attention = False
+    required_optional_params = frozenset(
+        [
+            "num_inference_steps",
+            "num_images_per_prompt",
+            "generator",
+            "output_type",
+            "return_dict",
+        ]
+    )

     @property
     def dummy_movq_kwargs(self):
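For context, the `callback_on_step_end` hook documented in PATCH 13 and wired into the denoising loop in PATCH 11 replaces the deprecated `callback`/`callback_steps` pair. A hypothetical caller-side sketch — `pipe` and `prompt` are placeholders, and the hook signature follows the docstring above:

```py
def log_latents(pipe, step, timestep, callback_kwargs):
    # Tensors named in `callback_on_step_end_tensor_inputs` arrive in
    # `callback_kwargs`; whatever the hook returns is popped back into the
    # loop, so entries may be modified before being returned.
    latents = callback_kwargs["latents"]
    print(f"step {step}: latents norm = {latents.norm().item():.2f}")
    return callback_kwargs

image = pipe(
    prompt,
    num_inference_steps=25,
    callback_on_step_end=log_latents,
    callback_on_step_end_tensor_inputs=["latents"],
).images[0]
```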
From dd198cb729d450890a5e36eab636c2188010f2e6 Mon Sep 17 00:00:00 2001
From: yiyixu
Date: Wed, 29 Nov 2023 17:34:58 +0000
Subject: [PATCH 16/20] rename unet file

---
 src/diffusers/models/__init__.py                        | 4 ++--
 .../{unet_2d_condition_kandi3.py => unet_kandinsky3.py} | 0
 2 files changed, 2 insertions(+), 2 deletions(-)
 rename src/diffusers/models/{unet_2d_condition_kandi3.py => unet_kandinsky3.py} (100%)

diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py
index 6346523a7004..103f014a2b89 100644
--- a/src/diffusers/models/__init__.py
+++ b/src/diffusers/models/__init__.py
@@ -35,8 +35,8 @@
     _import_structure["unet_1d"] = ["UNet1DModel"]
     _import_structure["unet_2d"] = ["UNet2DModel"]
     _import_structure["unet_2d_condition"] = ["UNet2DConditionModel"]
-    _import_structure["unet_2d_condition_kandi3"] = ["Kandinsky3UNet"]
     _import_structure["unet_3d_condition"] = ["UNet3DConditionModel"]
+    _import_structure["unet_kandinsky3"] = ["Kandinsky3UNet"]
     _import_structure["unet_motion_model"] = ["MotionAdapter", "UNetMotionModel"]
     _import_structure["vq_model"] = ["VQModel"]

@@ -63,8 +63,8 @@
     from .unet_1d import UNet1DModel
     from .unet_2d import UNet2DModel
     from .unet_2d_condition import UNet2DConditionModel
-    from .unet_2d_condition_kandi3 import Kandinsky3UNet
     from .unet_3d_condition import UNet3DConditionModel
+    from .unet_kandinsky3 import Kandinsky3UNet
     from .unet_motion_model import MotionAdapter, UNetMotionModel
     from .vq_model import VQModel

diff --git a/src/diffusers/models/unet_2d_condition_kandi3.py b/src/diffusers/models/unet_kandinsky3.py
similarity index 100%
rename from src/diffusers/models/unet_2d_condition_kandi3.py
rename to src/diffusers/models/unet_kandinsky3.py

From d60bc4e4c7d6f21cbe8d561319ccf788de553367 Mon Sep 17 00:00:00 2001
From: yiyixu
Date: Wed, 29 Nov 2023 18:41:53 +0000
Subject: [PATCH 17/20] testing prints

---
 src/diffusers/models/attention_processor.py |  4 +---
 src/diffusers/models/unet_kandinsky3.py     | 11 +++++++----
 2 files changed, 8 insertions(+), 7 deletions(-)

diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py
index 12939f6eabf2..40a335527ace 100644
--- a/src/diffusers/models/attention_processor.py
+++ b/src/diffusers/models/attention_processor.py
@@ -109,7 +109,6 @@ def __init__(
         residual_connection: bool = False,
         _from_deprecated_attn_block: bool = False,
         processor: Optional["AttnProcessor"] = None,
-        scale_mask_factor: float = 1.0,
         out_dim: int = None,
     ):
         super().__init__()
@@ -121,7 +120,6 @@ def __init__(
         self.residual_connection = residual_connection
         self.dropout = dropout
         self.out_dim = out_dim if out_dim is not None else query_dim
-        self.scale_mask_factor = scale_mask_factor

         # we make use of this private variable to know whether this class is loaded
         # with an deprecated state dict so that we can convert it on the fly
@@ -597,7 +595,7 @@ def get_attention_scores(
             beta = 0
         else:
             baddbmm_input = attention_mask
-            beta = self.scale_mask_factor
+            beta = 1

         attention_scores = torch.baddbmm(
             baddbmm_input,
diff --git a/src/diffusers/models/unet_kandinsky3.py b/src/diffusers/models/unet_kandinsky3.py
index 1ce831393c37..620220b9c9ed 100644
--- a/src/diffusers/models/unet_kandinsky3.py
+++ b/src/diffusers/models/unet_kandinsky3.py
@@ -210,6 +210,11 @@ def _set_gradient_checkpointing(self, module, value=False):
             module.gradient_checkpointing = value

     def forward(self, sample, timestep, encoder_hidden_states=None, encoder_attention_mask=None, return_dict=True):
+
+        if encoder_attention_mask is not None:
+            encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
+            encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
+
         if not torch.is_tensor(timestep):
             dtype = torch.float32 if isinstance(timestep, float) else torch.int32
             timestep = torch.tensor([timestep], dtype=dtype, device=sample.device)
@@ -485,11 +490,10 @@ def __init__(self, num_channels, context_dim, head_dim=64):
             dim_head=head_dim,
             out_dim=num_channels,
             out_bias=False,
-            scale_mask_factor=-60000.0,
         )

     def forward(self, x, context, context_mask=None):
-        context_mask = context_mask.unsqueeze(1).to(dtype=context.dtype)
+        context_mask = context_mask.to(dtype=context.dtype)
         context = self.attention(context.mean(dim=1, keepdim=True), context, context_mask)
         return x + context.squeeze(1)

@@ -504,7 +508,6 @@ def __init__(self, num_channels, time_embed_dim, context_dim=None, norm_groups=3
             dim_head=head_dim,
             out_dim=num_channels,
             out_bias=False,
-            scale_mask_factor=-60000.0,
         )

         hidden_channels = expansion_ratio * num_channels
@@ -521,7 +524,7 @@ def forward(self, x, time_embed, context=None, context_mask=None, image_mask=Non
         out = out.reshape(x.shape[0], -1, height * width).permute(0, 2, 1)
         context = context if context is not None else out
         if context_mask is not None:
-            context_mask = context_mask.unsqueeze(1).to(dtype=context.dtype)
+            context_mask = context_mask.to(dtype=context.dtype)
         out = self.attention(out, context, context_mask)
         out = out.permute(0, 2, 1).unsqueeze(-1).reshape(out.shape[0], -1, height, width)

From 1d170cc08add986c5a61cacc67be73e8f37a52fd Mon Sep 17 00:00:00 2001
From: yiyixu
Date: Wed, 29 Nov 2023 23:31:02 +0000
Subject: [PATCH 18/20] style

---
 src/diffusers/models/unet_kandinsky3.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/src/diffusers/models/unet_kandinsky3.py b/src/diffusers/models/unet_kandinsky3.py
index 620220b9c9ed..eef3287e5d99 100644
--- a/src/diffusers/models/unet_kandinsky3.py
+++ b/src/diffusers/models/unet_kandinsky3.py
@@ -210,11 +210,10 @@ def _set_gradient_checkpointing(self, module, value=False):
             module.gradient_checkpointing = value

     def forward(self, sample, timestep, encoder_hidden_states=None, encoder_attention_mask=None, return_dict=True):
-
         if encoder_attention_mask is not None:
             encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
             encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
-
+
         if not torch.is_tensor(timestep):

From c4eae7e6405bab60d66971de9c485216e3f92f89 Mon Sep 17 00:00:00 2001
From: yiyixu
Date: Thu, 30 Nov 2023 04:17:42 +0000
Subject: [PATCH 19/20] allow pass prompt_embeds

---
 .../kandinsky3/pipeline_kandinsky3.py         | 39 ++++++++++++++++++
 .../kandinsky3/pipeline_kandinsky3_img2img.py | 40 +++++++++++++++++++
 2 files changed, 79 insertions(+)

diff --git a/src/diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py b/src/diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py
index dab0d15918b7..4d14fc637b05 100644
--- a/src/diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py
+++ b/src/diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py
@@ -103,6 +103,8 @@ def encode_prompt(
         prompt_embeds: Optional[torch.FloatTensor] = None,
         negative_prompt_embeds: Optional[torch.FloatTensor] = None,
         _cut_context=False,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        negative_attention_mask: Optional[torch.FloatTensor] = None,
     ):
         r"""
         Encodes the prompt into text encoder hidden states.
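The mask handling above deserves a gloss: PATCH 17 drops the custom `scale_mask_factor` and instead has the UNet fold the binary `encoder_attention_mask` into a standard additive attention bias once, up front, which `get_attention_scores` then applies unscaled (`beta = 1`). A minimal standalone sketch of that conversion — the tensor names mirror the patch, the shapes and values are illustrative:

```py
import torch

# 1 = real token, 0 = padding, shape (batch, seq_len)
encoder_attention_mask = torch.tensor([[1, 1, 1, 0, 0]])

# Padded positions become a large negative bias; once this is added to the
# attention logits, softmax drives their weights to (numerically) zero.
bias = (1 - encoder_attention_mask.to(torch.float32)) * -10000.0
bias = bias.unsqueeze(1)  # (batch, 1, seq_len), broadcasts across query positions

print(bias)  # tensor([[[-0., -0., -0., -10000., -10000.]]])
```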
@@ -127,6 +129,10 @@ def encode_prompt(
             Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
             weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
             argument.
+            attention_mask (`torch.FloatTensor`, *optional*):
+                Pre-generated attention mask. Must be provided if `prompt_embeds` is passed directly.
+            negative_attention_mask (`torch.FloatTensor`, *optional*):
+                Pre-generated negative attention mask. Must be provided if `negative_prompt_embeds` is passed directly.
         """
         if prompt is not None and negative_prompt is not None:
             if type(prompt) is not type(negative_prompt):
@@ -255,6 +261,8 @@ def check_inputs(
         prompt_embeds=None,
         negative_prompt_embeds=None,
         callback_on_step_end_tensor_inputs=None,
+        attention_mask=None,
+        negative_attention_mask=None,
     ):
         if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
             raise ValueError(
@@ -293,6 +301,27 @@ def check_inputs(
                     f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                     f" {negative_prompt_embeds.shape}."
                 )
+        if negative_prompt_embeds is not None and negative_attention_mask is None:
+            raise ValueError("Please provide `negative_attention_mask` along with `negative_prompt_embeds`")
+
+        if negative_prompt_embeds is not None and negative_attention_mask is not None:
+            if negative_prompt_embeds.shape[:2] != negative_attention_mask.shape:
+                raise ValueError(
+                    "`negative_prompt_embeds` and `negative_attention_mask` must have the same batch_size and token length when passed directly, but"
+                    f" got: `negative_prompt_embeds` {negative_prompt_embeds.shape[:2]} != `negative_attention_mask`"
+                    f" {negative_attention_mask.shape}."
+                )
+
+        if prompt_embeds is not None and attention_mask is None:
+            raise ValueError("Please provide `attention_mask` along with `prompt_embeds`")
+
+        if prompt_embeds is not None and attention_mask is not None:
+            if prompt_embeds.shape[:2] != attention_mask.shape:
+                raise ValueError(
+                    "`prompt_embeds` and `attention_mask` must have the same batch_size and token length when passed directly, but"
+                    f" got: `prompt_embeds` {prompt_embeds.shape[:2]} != `attention_mask`"
+                    f" {attention_mask.shape}."
+                )

     @property
     def guidance_scale(self):
@@ -320,6 +349,8 @@ def __call__(
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
         prompt_embeds: Optional[torch.FloatTensor] = None,
         negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        negative_attention_mask: Optional[torch.FloatTensor] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
         latents=None,
@@ -369,6 +400,12 @@ def __call__(
                 Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                 weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                 argument.
+            attention_mask (`torch.FloatTensor`, *optional*):
+                Pre-generated attention mask. Must be provided if `prompt_embeds` is passed directly.
+            negative_attention_mask (`torch.FloatTensor`, *optional*):
+                Pre-generated negative attention mask. Must be provided if `negative_prompt_embeds` is passed directly.
             output_type (`str`, *optional*, defaults to `"pil"`):
                 The output format of the generate image. Choose between
                 [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
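The net effect of PATCH 19 is that embeddings and their masks now travel together through `check_inputs`, `encode_prompt`, and the denoising loop. A hypothetical end-to-end use of the new arguments — the model id follows the example docstring added in PATCH 13, but the four-tuple return of `encode_prompt` is an assumption based on the merged pipeline, not shown in this diff:

```py
import torch
from diffusers import AutoPipelineForText2Image

pipe = AutoPipelineForText2Image.from_pretrained(
    "kandinsky-community/kandinsky-3", variant="fp16", torch_dtype=torch.float16
)
pipe.enable_model_cpu_offload()

# Encode once, reuse on later calls without re-running the text encoder.
prompt_embeds, negative_prompt_embeds, attention_mask, negative_attention_mask = pipe.encode_prompt(
    "A photograph of a raccoon reading a newspaper",
    do_classifier_free_guidance=True,
)

# check_inputs() now rejects embeddings that arrive without their mask.
image = pipe(
    prompt_embeds=prompt_embeds,
    negative_prompt_embeds=negative_prompt_embeds,
    attention_mask=attention_mask,
    negative_attention_mask=negative_attention_mask,
    num_inference_steps=25,
).images[0]
```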
@@ -430,6 +465,8 @@ def __call__(
             prompt_embeds,
             negative_prompt_embeds,
             callback_on_step_end_tensor_inputs,
+            attention_mask,
+            negative_attention_mask,
         )

         self._guidance_scale = guidance_scale
@@ -451,6 +488,8 @@ def __call__(
             prompt_embeds=prompt_embeds,
             negative_prompt_embeds=negative_prompt_embeds,
             _cut_context=cut_context,
+            attention_mask=attention_mask,
+            negative_attention_mask=negative_attention_mask,
         )

         if self.do_classifier_free_guidance:
diff --git a/src/diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py b/src/diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py
index b830179196a5..edeb3955ec76 100644
--- a/src/diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py
+++ b/src/diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py
@@ -126,6 +126,8 @@ def encode_prompt(
         prompt_embeds: Optional[torch.FloatTensor] = None,
         negative_prompt_embeds: Optional[torch.FloatTensor] = None,
         _cut_context=False,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        negative_attention_mask: Optional[torch.FloatTensor] = None,
     ):
         r"""
         Encodes the prompt into text encoder hidden states.
@@ -150,6 +152,10 @@ def encode_prompt(
             Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
             weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
             argument.
+            attention_mask (`torch.FloatTensor`, *optional*):
+                Pre-generated attention mask. Must be provided if `prompt_embeds` is passed directly.
+            negative_attention_mask (`torch.FloatTensor`, *optional*):
+                Pre-generated negative attention mask. Must be provided if `negative_prompt_embeds` is passed directly.
         """
         if prompt is not None and negative_prompt is not None:
             if type(prompt) is not type(negative_prompt):
@@ -327,6 +333,8 @@ def check_inputs(
         prompt_embeds=None,
         negative_prompt_embeds=None,
         callback_on_step_end_tensor_inputs=None,
+        attention_mask=None,
+        negative_attention_mask=None,
     ):
         if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
             raise ValueError(
@@ -367,6 +375,28 @@ def check_inputs(
                     f" {negative_prompt_embeds.shape}."
                 )

+        if negative_prompt_embeds is not None and negative_attention_mask is None:
+            raise ValueError("Please provide `negative_attention_mask` along with `negative_prompt_embeds`")
+
+        if negative_prompt_embeds is not None and negative_attention_mask is not None:
+            if negative_prompt_embeds.shape[:2] != negative_attention_mask.shape:
+                raise ValueError(
+                    "`negative_prompt_embeds` and `negative_attention_mask` must have the same batch_size and token length when passed directly, but"
+                    f" got: `negative_prompt_embeds` {negative_prompt_embeds.shape[:2]} != `negative_attention_mask`"
+                    f" {negative_attention_mask.shape}."
+                )
+
+        if prompt_embeds is not None and attention_mask is None:
+            raise ValueError("Please provide `attention_mask` along with `prompt_embeds`")
+
+        if prompt_embeds is not None and attention_mask is not None:
+            if prompt_embeds.shape[:2] != attention_mask.shape:
+                raise ValueError(
+                    "`prompt_embeds` and `attention_mask` must have the same batch_size and token length when passed directly, but"
+                    f" got: `prompt_embeds` {prompt_embeds.shape[:2]} != `attention_mask`"
+                    f" {attention_mask.shape}."
+                )

     @property
     def guidance_scale(self):
         return self._guidance_scale
@@ -393,6 +423,8 @@ def __call__(
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
         prompt_embeds: Optional[torch.FloatTensor] = None,
         negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        negative_attention_mask: Optional[torch.FloatTensor] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
         callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
@@ -440,6 +472,10 @@ def __call__(
                 Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                 weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                 argument.
+            attention_mask (`torch.FloatTensor`, *optional*):
+                Pre-generated attention mask. Must be provided if `prompt_embeds` is passed directly.
+            negative_attention_mask (`torch.FloatTensor`, *optional*):
+                Pre-generated negative attention mask. Must be provided if `negative_prompt_embeds` is passed directly.
             output_type (`str`, *optional*, defaults to `"pil"`):
                 The output format of the generate image. Choose between
                 [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
@@ -493,6 +529,8 @@ def __call__(
             prompt_embeds,
             negative_prompt_embeds,
             callback_on_step_end_tensor_inputs,
+            attention_mask,
+            negative_attention_mask,
         )

         self._guidance_scale = guidance_scale
@@ -516,6 +554,8 @@ def __call__(
             prompt_embeds=prompt_embeds,
             negative_prompt_embeds=negative_prompt_embeds,
             _cut_context=cut_context,
+            attention_mask=attention_mask,
+            negative_attention_mask=negative_attention_mask,
         )

         if self.do_classifier_free_guidance:

From 226f755091b70ba2ac78bad075aaa46a4cdba26a Mon Sep 17 00:00:00 2001
From: yiyixu
Date: Thu, 30 Nov 2023 06:14:50 +0000
Subject: [PATCH 20/20] offload

---
 .../pipelines/kandinsky3/pipeline_kandinsky3_img2img.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py b/src/diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py
index edeb3955ec76..7f4164a04d1e 100644
--- a/src/diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py
+++ b/src/diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py
@@ -60,7 +60,7 @@ def prepare_image(pil_image):

 class Kandinsky3Img2ImgPipeline(DiffusionPipeline, LoraLoaderMixin):
-    model_cpu_offload_seq = "text_encoder->unet->movq"
+    model_cpu_offload_seq = "text_encoder->movq->unet->movq"
     _callback_tensor_inputs = [
         "latents",
         "prompt_embeds",