From ffb010755eadab9907d7392da470847b724c6e26 Mon Sep 17 00:00:00 2001 From: Aryan V S Date: Wed, 25 Oct 2023 20:49:15 +0530 Subject: [PATCH] Improve typehints and docs in `diffusers/models` (#5391) * improvement: add typehints and docs to src/diffusers/models/attention_processor.py * improvement: add typehints and docs to src/diffusers/models/vae.py * improvement: add missing docs in src/diffusers/models/vq_model.py * improvement: add typehints and docs to src/diffusers/models/transformer_temporal.py * improvement: add typehints and docs to src/diffusers/models/t5_film_transformer.py * improvement: add type hints to src/diffusers/models/unet_1d_blocks.py * improvement: add missing type hints to src/diffusers/models/unet_2d_blocks.py * fix: CI error (make fix-copies required) * fix: CI error (make fix-copies required again) --------- Co-authored-by: Dhruv Nair --- models/attention_processor.py | 386 ++++++++++++--- models/t5_film_transformer.py | 171 +++++-- models/transformer_temporal.py | 28 +- models/unet_1d_blocks.py | 166 ++++--- models/unet_2d_blocks.py | 457 ++++++++++-------- models/vae.py | 227 +++++++-- models/vq_model.py | 10 +- .../alt_diffusion/pipeline_alt_diffusion.py | 1 - .../pipeline_alt_diffusion_img2img.py | 1 - .../versatile_diffusion/modeling_text_unet.py | 105 ++-- 10 files changed, 1090 insertions(+), 462 deletions(-) diff --git a/models/attention_processor.py b/models/attention_processor.py index 89ecf143be22..efed305a0e96 100644 --- a/models/attention_processor.py +++ b/models/attention_processor.py @@ -40,14 +40,50 @@ class Attention(nn.Module): A cross attention layer. Parameters: - query_dim (`int`): The number of channels in the query. + query_dim (`int`): + The number of channels in the query. cross_attention_dim (`int`, *optional*): The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`. - heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention. - dim_head (`int`, *optional*, defaults to 64): The number of channels in each head. - dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + heads (`int`, *optional*, defaults to 8): + The number of heads to use for multi-head attention. + dim_head (`int`, *optional*, defaults to 64): + The number of channels in each head. + dropout (`float`, *optional*, defaults to 0.0): + The dropout probability to use. bias (`bool`, *optional*, defaults to False): Set to `True` for the query, key, and value linear layers to contain a bias parameter. + upcast_attention (`bool`, *optional*, defaults to False): + Set to `True` to upcast the attention computation to `float32`. + upcast_softmax (`bool`, *optional*, defaults to False): + Set to `True` to upcast the softmax computation to `float32`. + cross_attention_norm (`str`, *optional*, defaults to `None`): + The type of normalization to use for the cross attention. Can be `None`, `layer_norm`, or `group_norm`. + cross_attention_norm_num_groups (`int`, *optional*, defaults to 32): + The number of groups to use for the group norm in the cross attention. + added_kv_proj_dim (`int`, *optional*, defaults to `None`): + The number of channels to use for the added key and value projections. If `None`, no projection is used. + norm_num_groups (`int`, *optional*, defaults to `None`): + The number of groups to use for the group norm in the attention. + spatial_norm_dim (`int`, *optional*, defaults to `None`): + The number of channels to use for the spatial normalization. 
+ out_bias (`bool`, *optional*, defaults to `True`): + Set to `True` to use a bias in the output linear layer. + scale_qk (`bool`, *optional*, defaults to `True`): + Set to `True` to scale the query and key by `1 / sqrt(dim_head)`. + only_cross_attention (`bool`, *optional*, defaults to `False`): + Set to `True` to only use cross attention and not added_kv_proj_dim. Can only be set to `True` if + `added_kv_proj_dim` is not `None`. + eps (`float`, *optional*, defaults to 1e-5): + An additional value added to the denominator in group normalization that is used for numerical stability. + rescale_output_factor (`float`, *optional*, defaults to 1.0): + A factor to rescale the output by dividing it with this value. + residual_connection (`bool`, *optional*, defaults to `False`): + Set to `True` to add the residual connection to the output. + _from_deprecated_attn_block (`bool`, *optional*, defaults to `False`): + Set to `True` if the attention block is loaded from a deprecated state dict. + processor (`AttnProcessor`, *optional*, defaults to `None`): + The attention processor to use. If `None`, defaults to `AttnProcessor2_0` if `torch 2.x` is used and + `AttnProcessor` otherwise. """ def __init__( @@ -57,7 +93,7 @@ def __init__( heads: int = 8, dim_head: int = 64, dropout: float = 0.0, - bias=False, + bias: bool = False, upcast_attention: bool = False, upcast_softmax: bool = False, cross_attention_norm: Optional[str] = None, @@ -71,7 +107,7 @@ def __init__( eps: float = 1e-5, rescale_output_factor: float = 1.0, residual_connection: bool = False, - _from_deprecated_attn_block=False, + _from_deprecated_attn_block: bool = False, processor: Optional["AttnProcessor"] = None, ): super().__init__() @@ -172,7 +208,17 @@ def __init__( def set_use_memory_efficient_attention_xformers( self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None - ): + ) -> None: + r""" + Set whether to use memory efficient attention from `xformers` or not. + + Args: + use_memory_efficient_attention_xformers (`bool`): + Whether to use memory efficient attention from `xformers` or not. + attention_op (`Callable`, *optional*): + The attention operation to use. Defaults to `None` which uses the default attention operation from + `xformers`. + """ is_lora = hasattr(self, "processor") and isinstance( self.processor, LORA_ATTENTION_PROCESSORS, @@ -294,7 +340,14 @@ def set_use_memory_efficient_attention_xformers( self.set_processor(processor) - def set_attention_slice(self, slice_size): + def set_attention_slice(self, slice_size: int) -> None: + r""" + Set the slice size for attention computation. + + Args: + slice_size (`int`): + The slice size for attention computation. + """ if slice_size is not None and slice_size > self.sliceable_head_dim: raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.") @@ -315,7 +368,16 @@ def set_attention_slice(self, slice_size): self.set_processor(processor) - def set_processor(self, processor: "AttnProcessor", _remove_lora=False): + def set_processor(self, processor: "AttnProcessor", _remove_lora: bool = False) -> None: + r""" + Set the attention processor to use. + + Args: + processor (`AttnProcessor`): + The attention processor to use. + _remove_lora (`bool`, *optional*, defaults to `False`): + Set to `True` to remove LoRA layers from the model. 
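+
+        Example (an illustrative sketch, assuming a PyTorch 2.x environment where `AttnProcessor2_0`
+        is available):
+
+        ```py
+        attn = Attention(query_dim=64)
+        # replace whatever processor is currently set with the default PyTorch 2.x processor
+        attn.set_processor(AttnProcessor2_0())
+        ```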
+ """ if hasattr(self, "processor") and _remove_lora and self.to_q.lora_layer is not None: deprecate( "set_processor to offload LoRA", @@ -342,6 +404,16 @@ def set_processor(self, processor: "AttnProcessor", _remove_lora=False): self.processor = processor def get_processor(self, return_deprecated_lora: bool = False) -> "AttentionProcessor": + r""" + Get the attention processor in use. + + Args: + return_deprecated_lora (`bool`, *optional*, defaults to `False`): + Set to `True` to return the deprecated LoRA attention processor. + + Returns: + "AttentionProcessor": The attention processor in use. + """ if not return_deprecated_lora: return self.processor @@ -421,7 +493,29 @@ def get_processor(self, return_deprecated_lora: bool = False) -> "AttentionProce return lora_processor - def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, **cross_attention_kwargs): + def forward( + self, + hidden_states: torch.FloatTensor, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + **cross_attention_kwargs, + ) -> torch.Tensor: + r""" + The forward method of the `Attention` class. + + Args: + hidden_states (`torch.Tensor`): + The hidden states of the query. + encoder_hidden_states (`torch.Tensor`, *optional*): + The hidden states of the encoder. + attention_mask (`torch.Tensor`, *optional*): + The attention mask to use. If `None`, no mask is applied. + **cross_attention_kwargs: + Additional keyword arguments to pass along to the cross attention. + + Returns: + `torch.Tensor`: The output of the attention layer. + """ # The `Attention` class can call different attention processors / attention functions # here we simply pass along all tensors to the selected processor class # For standard processors that are defined here, `**cross_attention_kwargs` is empty @@ -433,14 +527,36 @@ def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None **cross_attention_kwargs, ) - def batch_to_head_dim(self, tensor): + def batch_to_head_dim(self, tensor: torch.Tensor) -> torch.Tensor: + r""" + Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size // heads, seq_len, dim * heads]`. `heads` + is the number of heads initialized while constructing the `Attention` class. + + Args: + tensor (`torch.Tensor`): The tensor to reshape. + + Returns: + `torch.Tensor`: The reshaped tensor. + """ head_size = self.heads batch_size, seq_len, dim = tensor.shape tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim) tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size) return tensor - def head_to_batch_dim(self, tensor, out_dim=3): + def head_to_batch_dim(self, tensor: torch.Tensor, out_dim: int = 3) -> torch.Tensor: + r""" + Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size, seq_len, heads, dim // heads]` `heads` is + the number of heads initialized while constructing the `Attention` class. + + Args: + tensor (`torch.Tensor`): The tensor to reshape. + out_dim (`int`, *optional*, defaults to `3`): The output dimension of the tensor. If `3`, the tensor is + reshaped to `[batch_size * heads, seq_len, dim // heads]`. + + Returns: + `torch.Tensor`: The reshaped tensor. 
+ """ head_size = self.heads batch_size, seq_len, dim = tensor.shape tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size) @@ -451,7 +567,20 @@ def head_to_batch_dim(self, tensor, out_dim=3): return tensor - def get_attention_scores(self, query, key, attention_mask=None): + def get_attention_scores( + self, query: torch.Tensor, key: torch.Tensor, attention_mask: torch.Tensor = None + ) -> torch.Tensor: + r""" + Compute the attention scores. + + Args: + query (`torch.Tensor`): The query tensor. + key (`torch.Tensor`): The key tensor. + attention_mask (`torch.Tensor`, *optional*): The attention mask to use. If `None`, no mask is applied. + + Returns: + `torch.Tensor`: The attention probabilities/scores. + """ dtype = query.dtype if self.upcast_attention: query = query.float() @@ -485,7 +614,25 @@ def get_attention_scores(self, query, key, attention_mask=None): return attention_probs - def prepare_attention_mask(self, attention_mask, target_length, batch_size, out_dim=3): + def prepare_attention_mask( + self, attention_mask: torch.Tensor, target_length: int, batch_size: int, out_dim: int = 3 + ) -> torch.Tensor: + r""" + Prepare the attention mask for the attention computation. + + Args: + attention_mask (`torch.Tensor`): + The attention mask to prepare. + target_length (`int`): + The target length of the attention mask. This is the length of the attention mask after padding. + batch_size (`int`): + The batch size, which is used to repeat the attention mask. + out_dim (`int`, *optional*, defaults to `3`): + The output dimension of the attention mask. Can be either `3` or `4`. + + Returns: + `torch.Tensor`: The prepared attention mask. + """ head_size = self.heads if attention_mask is None: return attention_mask @@ -514,7 +661,17 @@ def prepare_attention_mask(self, attention_mask, target_length, batch_size, out_ return attention_mask - def norm_encoder_hidden_states(self, encoder_hidden_states): + def norm_encoder_hidden_states(self, encoder_hidden_states: torch.Tensor) -> torch.Tensor: + r""" + Normalize the encoder hidden states. Requires `self.norm_cross` to be specified when constructing the + `Attention` class. + + Args: + encoder_hidden_states (`torch.Tensor`): Hidden states of the encoder. + + Returns: + `torch.Tensor`: The normalized encoder hidden states. 
+ """ assert self.norm_cross is not None, "self.norm_cross must be defined to call self.norm_encoder_hidden_states" if isinstance(self.norm_cross, nn.LayerNorm): @@ -542,12 +699,12 @@ class AttnProcessor: def __call__( self, attn: Attention, - hidden_states, - encoder_hidden_states=None, - attention_mask=None, - temb=None, - scale=1.0, - ): + hidden_states: torch.FloatTensor, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + temb: Optional[torch.FloatTensor] = None, + scale: float = 1.0, + ) -> torch.Tensor: residual = hidden_states args = () if USE_PEFT_BACKEND else (scale,) @@ -624,12 +781,12 @@ class CustomDiffusionAttnProcessor(nn.Module): def __init__( self, - train_kv=True, - train_q_out=True, - hidden_size=None, - cross_attention_dim=None, - out_bias=True, - dropout=0.0, + train_kv: bool = True, + train_q_out: bool = True, + hidden_size: Optional[int] = None, + cross_attention_dim: Optional[int] = None, + out_bias: bool = True, + dropout: float = 0.0, ): super().__init__() self.train_kv = train_kv @@ -648,7 +805,13 @@ def __init__( self.to_out_custom_diffusion.append(nn.Linear(hidden_size, hidden_size, bias=out_bias)) self.to_out_custom_diffusion.append(nn.Dropout(dropout)) - def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None): + def __call__( + self, + attn: Attention, + hidden_states: torch.FloatTensor, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + ) -> torch.Tensor: batch_size, sequence_length, _ = hidden_states.shape attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) if self.train_q_out: @@ -707,7 +870,14 @@ class AttnAddedKVProcessor: encoder. """ - def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0): + def __call__( + self, + attn: Attention, + hidden_states: torch.FloatTensor, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + scale: float = 1.0, + ) -> torch.Tensor: residual = hidden_states hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2) batch_size, sequence_length, _ = hidden_states.shape @@ -767,7 +937,14 @@ def __init__(self): "AttnAddedKVProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0." 
) - def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0): + def __call__( + self, + attn: Attention, + hidden_states: torch.FloatTensor, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + scale: float = 1.0, + ) -> torch.Tensor: residual = hidden_states hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2) batch_size, sequence_length, _ = hidden_states.shape @@ -833,7 +1010,13 @@ class XFormersAttnAddedKVProcessor: def __init__(self, attention_op: Optional[Callable] = None): self.attention_op = attention_op - def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None): + def __call__( + self, + attn: Attention, + hidden_states: torch.FloatTensor, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + ) -> torch.Tensor: residual = hidden_states hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2) batch_size, sequence_length, _ = hidden_states.shape @@ -906,7 +1089,7 @@ def __call__( attention_mask: Optional[torch.FloatTensor] = None, temb: Optional[torch.FloatTensor] = None, scale: float = 1.0, - ): + ) -> torch.FloatTensor: residual = hidden_states args = () if USE_PEFT_BACKEND else (scale,) @@ -986,12 +1169,12 @@ def __init__(self): def __call__( self, attn: Attention, - hidden_states, - encoder_hidden_states=None, - attention_mask=None, - temb=None, + hidden_states: torch.FloatTensor, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + temb: Optional[torch.FloatTensor] = None, scale: float = 1.0, - ): + ) -> torch.FloatTensor: residual = hidden_states if attn.spatial_norm is not None: @@ -1091,12 +1274,12 @@ class CustomDiffusionXFormersAttnProcessor(nn.Module): def __init__( self, - train_kv=True, - train_q_out=False, - hidden_size=None, - cross_attention_dim=None, - out_bias=True, - dropout=0.0, + train_kv: bool = True, + train_q_out: bool = False, + hidden_size: Optional[int] = None, + cross_attention_dim: Optional[int] = None, + out_bias: bool = True, + dropout: float = 0.0, attention_op: Optional[Callable] = None, ): super().__init__() @@ -1117,7 +1300,13 @@ def __init__( self.to_out_custom_diffusion.append(nn.Linear(hidden_size, hidden_size, bias=out_bias)) self.to_out_custom_diffusion.append(nn.Dropout(dropout)) - def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None): + def __call__( + self, + attn: Attention, + hidden_states: torch.FloatTensor, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + ) -> torch.FloatTensor: batch_size, sequence_length, _ = ( hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape ) @@ -1197,12 +1386,12 @@ class CustomDiffusionAttnProcessor2_0(nn.Module): def __init__( self, - train_kv=True, - train_q_out=True, - hidden_size=None, - cross_attention_dim=None, - out_bias=True, - dropout=0.0, + train_kv: bool = True, + train_q_out: bool = True, + hidden_size: Optional[int] = None, + cross_attention_dim: Optional[int] = None, + out_bias: bool = True, + dropout: float = 0.0, ): super().__init__() self.train_kv = train_kv @@ -1221,7 +1410,13 @@ def __init__( self.to_out_custom_diffusion.append(nn.Linear(hidden_size, hidden_size, 
bias=out_bias)) self.to_out_custom_diffusion.append(nn.Dropout(dropout)) - def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None): + def __call__( + self, + attn: Attention, + hidden_states: torch.FloatTensor, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + ) -> torch.FloatTensor: batch_size, sequence_length, _ = hidden_states.shape attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) if self.train_q_out: @@ -1290,10 +1485,16 @@ class SlicedAttnProcessor: `attention_head_dim` must be a multiple of the `slice_size`. """ - def __init__(self, slice_size): + def __init__(self, slice_size: int): self.slice_size = slice_size - def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None): + def __call__( + self, + attn: Attention, + hidden_states: torch.FloatTensor, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + ) -> torch.FloatTensor: residual = hidden_states input_ndim = hidden_states.ndim @@ -1374,7 +1575,14 @@ class SlicedAttnAddedKVProcessor: def __init__(self, slice_size): self.slice_size = slice_size - def __call__(self, attn: "Attention", hidden_states, encoder_hidden_states=None, attention_mask=None, temb=None): + def __call__( + self, + attn: "Attention", + hidden_states: torch.FloatTensor, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + temb: Optional[torch.FloatTensor] = None, + ) -> torch.FloatTensor: residual = hidden_states if attn.spatial_norm is not None: @@ -1448,20 +1656,26 @@ def __call__(self, attn: "Attention", hidden_states, encoder_hidden_states=None, class SpatialNorm(nn.Module): """ - Spatially conditioned normalization as defined in https://arxiv.org/abs/2209.09002 + Spatially conditioned normalization as defined in https://arxiv.org/abs/2209.09002. + + Args: + f_channels (`int`): + The number of channels for input to group normalization layer, and output of the spatial norm layer. + zq_channels (`int`): + The number of channels for the quantized vector as described in the paper. """ def __init__( self, - f_channels, - zq_channels, + f_channels: int, + zq_channels: int, ): super().__init__() self.norm_layer = nn.GroupNorm(num_channels=f_channels, num_groups=32, eps=1e-6, affine=True) self.conv_y = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0) self.conv_b = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0) - def forward(self, f, zq): + def forward(self, f: torch.FloatTensor, zq: torch.FloatTensor) -> torch.FloatTensor: f_size = f.shape[-2:] zq = F.interpolate(zq, size=f_size, mode="nearest") norm_f = self.norm_layer(f) @@ -1483,9 +1697,18 @@ class LoRAAttnProcessor(nn.Module): The dimension of the LoRA update matrices. network_alpha (`int`, *optional*): Equivalent to `alpha` but it's usage is specific to Kohya (A1111) style LoRAs. + kwargs (`dict`): + Additional keyword arguments to pass to the `LoRALinearLayer` layers. 
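+
+    Example (an illustrative sketch; the sizes are placeholder values):
+
+    ```py
+    processor = LoRAAttnProcessor(hidden_size=320, cross_attention_dim=768, rank=4)
+    ```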
""" - def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha=None, **kwargs): + def __init__( + self, + hidden_size: int, + cross_attention_dim: Optional[int] = None, + rank: int = 4, + network_alpha: Optional[int] = None, + **kwargs, + ): super().__init__() self.hidden_size = hidden_size @@ -1512,7 +1735,7 @@ def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha= self.to_v_lora = LoRALinearLayer(cross_attention_dim or v_hidden_size, v_hidden_size, v_rank, network_alpha) self.to_out_lora = LoRALinearLayer(out_hidden_size, out_hidden_size, out_rank, network_alpha) - def __call__(self, attn: Attention, hidden_states, *args, **kwargs): + def __call__(self, attn: Attention, hidden_states: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor: self_cls_name = self.__class__.__name__ deprecate( self_cls_name, @@ -1547,9 +1770,18 @@ class LoRAAttnProcessor2_0(nn.Module): The dimension of the LoRA update matrices. network_alpha (`int`, *optional*): Equivalent to `alpha` but it's usage is specific to Kohya (A1111) style LoRAs. + kwargs (`dict`): + Additional keyword arguments to pass to the `LoRALinearLayer` layers. """ - def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha=None, **kwargs): + def __init__( + self, + hidden_size: int, + cross_attention_dim: Optional[int] = None, + rank: int = 4, + network_alpha: Optional[int] = None, + **kwargs, + ): super().__init__() if not hasattr(F, "scaled_dot_product_attention"): raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") @@ -1578,7 +1810,7 @@ def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha= self.to_v_lora = LoRALinearLayer(cross_attention_dim or v_hidden_size, v_hidden_size, v_rank, network_alpha) self.to_out_lora = LoRALinearLayer(out_hidden_size, out_hidden_size, out_rank, network_alpha) - def __call__(self, attn: Attention, hidden_states, *args, **kwargs): + def __call__(self, attn: Attention, hidden_states: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor: self_cls_name = self.__class__.__name__ deprecate( self_cls_name, @@ -1617,16 +1849,17 @@ class LoRAXFormersAttnProcessor(nn.Module): operator. network_alpha (`int`, *optional*): Equivalent to `alpha` but it's usage is specific to Kohya (A1111) style LoRAs. - + kwargs (`dict`): + Additional keyword arguments to pass to the `LoRALinearLayer` layers. """ def __init__( self, - hidden_size, - cross_attention_dim, - rank=4, + hidden_size: int, + cross_attention_dim: int, + rank: int = 4, attention_op: Optional[Callable] = None, - network_alpha=None, + network_alpha: Optional[int] = None, **kwargs, ): super().__init__() @@ -1656,7 +1889,7 @@ def __init__( self.to_v_lora = LoRALinearLayer(cross_attention_dim or v_hidden_size, v_hidden_size, v_rank, network_alpha) self.to_out_lora = LoRALinearLayer(out_hidden_size, out_hidden_size, out_rank, network_alpha) - def __call__(self, attn: Attention, hidden_states, *args, **kwargs): + def __call__(self, attn: Attention, hidden_states: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor: self_cls_name = self.__class__.__name__ deprecate( self_cls_name, @@ -1689,10 +1922,19 @@ class LoRAAttnAddedKVProcessor(nn.Module): The number of channels in the `encoder_hidden_states`. rank (`int`, defaults to 4): The dimension of the LoRA update matrices. - + network_alpha (`int`, *optional*): + Equivalent to `alpha` but it's usage is specific to Kohya (A1111) style LoRAs. 
+ kwargs (`dict`): + Additional keyword arguments to pass to the `LoRALinearLayer` layers. """ - def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha=None): + def __init__( + self, + hidden_size: int, + cross_attention_dim: Optional[int] = None, + rank: int = 4, + network_alpha: Optional[int] = None, + ): super().__init__() self.hidden_size = hidden_size @@ -1706,7 +1948,7 @@ def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha= self.to_v_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha) self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha) - def __call__(self, attn: Attention, hidden_states, *args, **kwargs): + def __call__(self, attn: Attention, hidden_states: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor: self_cls_name = self.__class__.__name__ deprecate( self_cls_name, @@ -1764,7 +2006,7 @@ def __call__(self, attn: Attention, hidden_states, *args, **kwargs): CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor, CustomDiffusionAttnProcessor2_0, - # depraceted + # deprecated LoRAAttnProcessor, LoRAAttnProcessor2_0, LoRAXFormersAttnProcessor, diff --git a/models/t5_film_transformer.py b/models/t5_film_transformer.py index 1c41e656a9db..26ff3f6b8127 100644 --- a/models/t5_film_transformer.py +++ b/models/t5_film_transformer.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import math +from typing import Optional, Tuple import torch from torch import nn @@ -23,6 +24,28 @@ class T5FilmDecoder(ModelMixin, ConfigMixin): + r""" + T5 style decoder with FiLM conditioning. + + Args: + input_dims (`int`, *optional*, defaults to `128`): + The number of input dimensions. + targets_length (`int`, *optional*, defaults to `256`): + The length of the targets. + d_model (`int`, *optional*, defaults to `768`): + Size of the input hidden states. + num_layers (`int`, *optional*, defaults to `12`): + The number of `DecoderLayer`'s to use. + num_heads (`int`, *optional*, defaults to `12`): + The number of attention heads to use. + d_kv (`int`, *optional*, defaults to `64`): + Size of the key-value projection vectors. + d_ff (`int`, *optional*, defaults to `2048`): + The number of dimensions in the intermediate feed-forward layer of `DecoderLayer`'s. + dropout_rate (`float`, *optional*, defaults to `0.1`): + Dropout probability. + """ + @register_to_config def __init__( self, @@ -63,7 +86,7 @@ def __init__( self.post_dropout = nn.Dropout(p=dropout_rate) self.spec_out = nn.Linear(d_model, input_dims, bias=False) - def encoder_decoder_mask(self, query_input, key_input): + def encoder_decoder_mask(self, query_input: torch.FloatTensor, key_input: torch.FloatTensor) -> torch.FloatTensor: mask = torch.mul(query_input.unsqueeze(-1), key_input.unsqueeze(-2)) return mask.unsqueeze(-3) @@ -125,7 +148,27 @@ def forward(self, encodings_and_masks, decoder_input_tokens, decoder_noise_time) class DecoderLayer(nn.Module): - def __init__(self, d_model, d_kv, num_heads, d_ff, dropout_rate, layer_norm_epsilon=1e-6): + r""" + T5 decoder layer. + + Args: + d_model (`int`): + Size of the input hidden states. + d_kv (`int`): + Size of the key-value projection vectors. + num_heads (`int`): + Number of attention heads. + d_ff (`int`): + Size of the intermediate feed-forward layer. + dropout_rate (`float`): + Dropout probability. 
+ layer_norm_epsilon (`float`, *optional*, defaults to `1e-6`): + A small value used for numerical stability to avoid dividing by zero. + """ + + def __init__( + self, d_model: int, d_kv: int, num_heads: int, d_ff: int, dropout_rate: float, layer_norm_epsilon: float = 1e-6 + ): super().__init__() self.layer = nn.ModuleList() @@ -152,13 +195,13 @@ def __init__(self, d_model, d_kv, num_heads, d_ff, dropout_rate, layer_norm_epsi def forward( self, - hidden_states, - conditioning_emb=None, - attention_mask=None, - encoder_hidden_states=None, - encoder_attention_mask=None, + hidden_states: torch.FloatTensor, + conditioning_emb: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, encoder_decoder_position_bias=None, - ): + ) -> Tuple[torch.FloatTensor]: hidden_states = self.layer[0]( hidden_states, conditioning_emb=conditioning_emb, @@ -183,7 +226,21 @@ def forward( class T5LayerSelfAttentionCond(nn.Module): - def __init__(self, d_model, d_kv, num_heads, dropout_rate): + r""" + T5 style self-attention layer with conditioning. + + Args: + d_model (`int`): + Size of the input hidden states. + d_kv (`int`): + Size of the key-value projection vectors. + num_heads (`int`): + Number of attention heads. + dropout_rate (`float`): + Dropout probability. + """ + + def __init__(self, d_model: int, d_kv: int, num_heads: int, dropout_rate: float): super().__init__() self.layer_norm = T5LayerNorm(d_model) self.FiLMLayer = T5FiLMLayer(in_features=d_model * 4, out_features=d_model) @@ -192,10 +249,10 @@ def __init__(self, d_model, d_kv, num_heads, dropout_rate): def forward( self, - hidden_states, - conditioning_emb=None, - attention_mask=None, - ): + hidden_states: torch.FloatTensor, + conditioning_emb: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + ) -> torch.FloatTensor: # pre_self_attention_layer_norm normed_hidden_states = self.layer_norm(hidden_states) @@ -211,7 +268,23 @@ def forward( class T5LayerCrossAttention(nn.Module): - def __init__(self, d_model, d_kv, num_heads, dropout_rate, layer_norm_epsilon): + r""" + T5 style cross-attention layer. + + Args: + d_model (`int`): + Size of the input hidden states. + d_kv (`int`): + Size of the key-value projection vectors. + num_heads (`int`): + Number of attention heads. + dropout_rate (`float`): + Dropout probability. + layer_norm_epsilon (`float`): + A small value used for numerical stability to avoid dividing by zero. 
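+
+    The attention output is combined with the input through dropout and a residual connection:
+    `hidden_states + dropout(attention(layer_norm(hidden_states)))`.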
+ """ + + def __init__(self, d_model: int, d_kv: int, num_heads: int, dropout_rate: float, layer_norm_epsilon: float): super().__init__() self.attention = Attention(query_dim=d_model, heads=num_heads, dim_head=d_kv, out_bias=False, scale_qk=False) self.layer_norm = T5LayerNorm(d_model, eps=layer_norm_epsilon) @@ -219,10 +292,10 @@ def __init__(self, d_model, d_kv, num_heads, dropout_rate, layer_norm_epsilon): def forward( self, - hidden_states, - key_value_states=None, - attention_mask=None, - ): + hidden_states: torch.FloatTensor, + key_value_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + ) -> torch.FloatTensor: normed_hidden_states = self.layer_norm(hidden_states) attention_output = self.attention( normed_hidden_states, @@ -234,14 +307,30 @@ def forward( class T5LayerFFCond(nn.Module): - def __init__(self, d_model, d_ff, dropout_rate, layer_norm_epsilon): + r""" + T5 style feed-forward conditional layer. + + Args: + d_model (`int`): + Size of the input hidden states. + d_ff (`int`): + Size of the intermediate feed-forward layer. + dropout_rate (`float`): + Dropout probability. + layer_norm_epsilon (`float`): + A small value used for numerical stability to avoid dividing by zero. + """ + + def __init__(self, d_model: int, d_ff: int, dropout_rate: float, layer_norm_epsilon: float): super().__init__() self.DenseReluDense = T5DenseGatedActDense(d_model=d_model, d_ff=d_ff, dropout_rate=dropout_rate) self.film = T5FiLMLayer(in_features=d_model * 4, out_features=d_model) self.layer_norm = T5LayerNorm(d_model, eps=layer_norm_epsilon) self.dropout = nn.Dropout(dropout_rate) - def forward(self, hidden_states, conditioning_emb=None): + def forward( + self, hidden_states: torch.FloatTensor, conditioning_emb: Optional[torch.FloatTensor] = None + ) -> torch.FloatTensor: forwarded_states = self.layer_norm(hidden_states) if conditioning_emb is not None: forwarded_states = self.film(forwarded_states, conditioning_emb) @@ -252,7 +341,19 @@ def forward(self, hidden_states, conditioning_emb=None): class T5DenseGatedActDense(nn.Module): - def __init__(self, d_model, d_ff, dropout_rate): + r""" + T5 style feed-forward layer with gated activations and dropout. + + Args: + d_model (`int`): + Size of the input hidden states. + d_ff (`int`): + Size of the intermediate feed-forward layer. + dropout_rate (`float`): + Dropout probability. + """ + + def __init__(self, d_model: int, d_ff: int, dropout_rate: float): super().__init__() self.wi_0 = nn.Linear(d_model, d_ff, bias=False) self.wi_1 = nn.Linear(d_model, d_ff, bias=False) @@ -260,7 +361,7 @@ def __init__(self, d_model, d_ff, dropout_rate): self.dropout = nn.Dropout(dropout_rate) self.act = NewGELUActivation() - def forward(self, hidden_states): + def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor: hidden_gelu = self.act(self.wi_0(hidden_states)) hidden_linear = self.wi_1(hidden_states) hidden_states = hidden_gelu * hidden_linear @@ -271,7 +372,17 @@ def forward(self, hidden_states): class T5LayerNorm(nn.Module): - def __init__(self, hidden_size, eps=1e-6): + r""" + T5 style layer normalization module. + + Args: + hidden_size (`int`): + Size of the input hidden states. + eps (`float`, `optional`, defaults to `1e-6`): + A small value used for numerical stability to avoid dividing by zero. + """ + + def __init__(self, hidden_size: int, eps: float = 1e-6): """ Construct a layernorm module in the T5 style. No bias and no subtraction of mean. 
""" @@ -279,7 +390,7 @@ def __init__(self, hidden_size, eps=1e-6): self.weight = nn.Parameter(torch.ones(hidden_size)) self.variance_epsilon = eps - def forward(self, hidden_states): + def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor: # T5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean # Square Layer Normalization https://arxiv.org/abs/1910.07467 thus variance is calculated # w/o mean and there is no bias. Additionally we want to make sure that the accumulation for @@ -307,14 +418,20 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: class T5FiLMLayer(nn.Module): """ - FiLM Layer + T5 style FiLM Layer. + + Args: + in_features (`int`): + Number of input features. + out_features (`int`): + Number of output features. """ - def __init__(self, in_features, out_features): + def __init__(self, in_features: int, out_features: int): super().__init__() self.scale_bias = nn.Linear(in_features, out_features * 2, bias=False) - def forward(self, x, conditioning_emb): + def forward(self, x: torch.FloatTensor, conditioning_emb: torch.FloatTensor) -> torch.FloatTensor: emb = self.scale_bias(conditioning_emb) scale, shift = torch.chunk(emb, 2, -1) x = x * (1 + scale) + shift diff --git a/models/transformer_temporal.py b/models/transformer_temporal.py index d59284875736..55c9e6968a32 100644 --- a/models/transformer_temporal.py +++ b/models/transformer_temporal.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from dataclasses import dataclass -from typing import Optional +from typing import Any, Dict, Optional import torch from torch import nn @@ -48,11 +48,15 @@ class TransformerTemporalModel(ModelMixin, ConfigMixin): num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use. - sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**). - This is fixed during training since it is used to learn a number of position embeddings. - activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward. attention_bias (`bool`, *optional*): Configure if the `TransformerBlock` attention should contain a bias parameter. + sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**). + This is fixed during training since it is used to learn a number of position embeddings. + activation_fn (`str`, *optional*, defaults to `"geglu"`): + Activation function to use in feed-forward. See `diffusers.models.activations.get_activation` for supported + activation functions. + norm_elementwise_affine (`bool`, *optional*): + Configure if the `TransformerBlock` should use learnable elementwise affine parameters for normalization. double_self_attention (`bool`, *optional*): Configure if each `TransformerBlock` should contain two self-attention layers. 
""" @@ -106,14 +110,14 @@ def __init__( def forward( self, - hidden_states, - encoder_hidden_states=None, - timestep=None, - class_labels=None, - num_frames=1, - cross_attention_kwargs=None, + hidden_states: torch.FloatTensor, + encoder_hidden_states: Optional[torch.LongTensor] = None, + timestep: Optional[torch.LongTensor] = None, + class_labels: torch.LongTensor = None, + num_frames: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, return_dict: bool = True, - ): + ) -> TransformerTemporalModelOutput: """ The [`TransformerTemporal`] forward method. @@ -123,7 +127,7 @@ def forward( encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*): Conditional embeddings for cross attention layer. If not given, cross-attention defaults to self-attention. - timestep ( `torch.long`, *optional*): + timestep ( `torch.LongTensor`, *optional*): Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`. class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*): Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in diff --git a/models/unet_1d_blocks.py b/models/unet_1d_blocks.py index 84ae48e0f8c4..74a2f1681ead 100644 --- a/models/unet_1d_blocks.py +++ b/models/unet_1d_blocks.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import math +from typing import Optional, Tuple, Union import torch import torch.nn.functional as F @@ -24,17 +25,17 @@ class DownResnetBlock1D(nn.Module): def __init__( self, - in_channels, - out_channels=None, - num_layers=1, - conv_shortcut=False, - temb_channels=32, - groups=32, - groups_out=None, - non_linearity=None, - time_embedding_norm="default", - output_scale_factor=1.0, - add_downsample=True, + in_channels: int, + out_channels: Optional[int] = None, + num_layers: int = 1, + conv_shortcut: bool = False, + temb_channels: int = 32, + groups: int = 32, + groups_out: Optional[int] = None, + non_linearity: Optional[str] = None, + time_embedding_norm: str = "default", + output_scale_factor: float = 1.0, + add_downsample: bool = True, ): super().__init__() self.in_channels = in_channels @@ -65,7 +66,7 @@ def __init__( if add_downsample: self.downsample = Downsample1D(out_channels, use_conv=True, padding=1) - def forward(self, hidden_states, temb=None): + def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor: output_states = () hidden_states = self.resnets[0](hidden_states, temb) @@ -86,16 +87,16 @@ def forward(self, hidden_states, temb=None): class UpResnetBlock1D(nn.Module): def __init__( self, - in_channels, - out_channels=None, - num_layers=1, - temb_channels=32, - groups=32, - groups_out=None, - non_linearity=None, - time_embedding_norm="default", - output_scale_factor=1.0, - add_upsample=True, + in_channels: int, + out_channels: Optional[int] = None, + num_layers: int = 1, + temb_channels: int = 32, + groups: int = 32, + groups_out: Optional[int] = None, + non_linearity: Optional[str] = None, + time_embedding_norm: str = "default", + output_scale_factor: float = 1.0, + add_upsample: bool = True, ): super().__init__() self.in_channels = in_channels @@ -125,7 +126,12 @@ def __init__( if add_upsample: self.upsample = Upsample1D(out_channels, use_conv_transpose=True) - def forward(self, hidden_states, res_hidden_states_tuple=None, temb=None): + def forward( + self, + 
hidden_states: torch.FloatTensor, + res_hidden_states_tuple: Optional[Tuple[torch.FloatTensor, ...]] = None, + temb: Optional[torch.FloatTensor] = None, + ) -> torch.FloatTensor: if res_hidden_states_tuple is not None: res_hidden_states = res_hidden_states_tuple[-1] hidden_states = torch.cat((hidden_states, res_hidden_states), dim=1) @@ -144,7 +150,7 @@ def forward(self, hidden_states, res_hidden_states_tuple=None, temb=None): class ValueFunctionMidBlock1D(nn.Module): - def __init__(self, in_channels, out_channels, embed_dim): + def __init__(self, in_channels: int, out_channels: int, embed_dim: int): super().__init__() self.in_channels = in_channels self.out_channels = out_channels @@ -155,7 +161,7 @@ def __init__(self, in_channels, out_channels, embed_dim): self.res2 = ResidualTemporalBlock1D(in_channels // 2, in_channels // 4, embed_dim=embed_dim) self.down2 = Downsample1D(out_channels // 4, use_conv=True) - def forward(self, x, temb=None): + def forward(self, x: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor: x = self.res1(x, temb) x = self.down1(x) x = self.res2(x, temb) @@ -166,13 +172,13 @@ def forward(self, x, temb=None): class MidResTemporalBlock1D(nn.Module): def __init__( self, - in_channels, - out_channels, - embed_dim, + in_channels: int, + out_channels: int, + embed_dim: int, num_layers: int = 1, add_downsample: bool = False, add_upsample: bool = False, - non_linearity=None, + non_linearity: Optional[str] = None, ): super().__init__() self.in_channels = in_channels @@ -203,7 +209,7 @@ def __init__( if self.upsample and self.downsample: raise ValueError("Block cannot downsample and upsample") - def forward(self, hidden_states, temb): + def forward(self, hidden_states: torch.FloatTensor, temb: torch.FloatTensor) -> torch.FloatTensor: hidden_states = self.resnets[0](hidden_states, temb) for resnet in self.resnets[1:]: hidden_states = resnet(hidden_states, temb) @@ -217,14 +223,14 @@ def forward(self, hidden_states, temb): class OutConv1DBlock(nn.Module): - def __init__(self, num_groups_out, out_channels, embed_dim, act_fn): + def __init__(self, num_groups_out: int, out_channels: int, embed_dim: int, act_fn: str): super().__init__() self.final_conv1d_1 = nn.Conv1d(embed_dim, embed_dim, 5, padding=2) self.final_conv1d_gn = nn.GroupNorm(num_groups_out, embed_dim) self.final_conv1d_act = get_activation(act_fn) self.final_conv1d_2 = nn.Conv1d(embed_dim, out_channels, 1) - def forward(self, hidden_states, temb=None): + def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor: hidden_states = self.final_conv1d_1(hidden_states) hidden_states = rearrange_dims(hidden_states) hidden_states = self.final_conv1d_gn(hidden_states) @@ -235,7 +241,7 @@ def forward(self, hidden_states, temb=None): class OutValueFunctionBlock(nn.Module): - def __init__(self, fc_dim, embed_dim, act_fn="mish"): + def __init__(self, fc_dim: int, embed_dim: int, act_fn: str = "mish"): super().__init__() self.final_block = nn.ModuleList( [ @@ -245,7 +251,7 @@ def __init__(self, fc_dim, embed_dim, act_fn="mish"): ] ) - def forward(self, hidden_states, temb): + def forward(self, hidden_states: torch.FloatTensor, temb: torch.FloatTensor) -> torch.FloatTensor: hidden_states = hidden_states.view(hidden_states.shape[0], -1) hidden_states = torch.cat((hidden_states, temb), dim=-1) for layer in self.final_block: @@ -275,14 +281,14 @@ def forward(self, hidden_states, temb): class Downsample1d(nn.Module): - def __init__(self, 
kernel="linear", pad_mode="reflect"): + def __init__(self, kernel: str = "linear", pad_mode: str = "reflect"): super().__init__() self.pad_mode = pad_mode kernel_1d = torch.tensor(_kernels[kernel]) self.pad = kernel_1d.shape[0] // 2 - 1 self.register_buffer("kernel", kernel_1d) - def forward(self, hidden_states): + def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor: hidden_states = F.pad(hidden_states, (self.pad,) * 2, self.pad_mode) weight = hidden_states.new_zeros([hidden_states.shape[1], hidden_states.shape[1], self.kernel.shape[0]]) indices = torch.arange(hidden_states.shape[1], device=hidden_states.device) @@ -292,14 +298,14 @@ def forward(self, hidden_states): class Upsample1d(nn.Module): - def __init__(self, kernel="linear", pad_mode="reflect"): + def __init__(self, kernel: str = "linear", pad_mode: str = "reflect"): super().__init__() self.pad_mode = pad_mode kernel_1d = torch.tensor(_kernels[kernel]) * 2 self.pad = kernel_1d.shape[0] // 2 - 1 self.register_buffer("kernel", kernel_1d) - def forward(self, hidden_states, temb=None): + def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor: hidden_states = F.pad(hidden_states, ((self.pad + 1) // 2,) * 2, self.pad_mode) weight = hidden_states.new_zeros([hidden_states.shape[1], hidden_states.shape[1], self.kernel.shape[0]]) indices = torch.arange(hidden_states.shape[1], device=hidden_states.device) @@ -309,7 +315,7 @@ def forward(self, hidden_states, temb=None): class SelfAttention1d(nn.Module): - def __init__(self, in_channels, n_head=1, dropout_rate=0.0): + def __init__(self, in_channels: int, n_head: int = 1, dropout_rate: float = 0.0): super().__init__() self.channels = in_channels self.group_norm = nn.GroupNorm(1, num_channels=in_channels) @@ -329,7 +335,7 @@ def transpose_for_scores(self, projection: torch.Tensor) -> torch.Tensor: new_projection = projection.view(new_projection_shape).permute(0, 2, 1, 3) return new_projection - def forward(self, hidden_states): + def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor: residual = hidden_states batch, channel_dim, seq = hidden_states.shape @@ -367,7 +373,7 @@ def forward(self, hidden_states): class ResConvBlock(nn.Module): - def __init__(self, in_channels, mid_channels, out_channels, is_last=False): + def __init__(self, in_channels: int, mid_channels: int, out_channels: int, is_last: bool = False): super().__init__() self.is_last = is_last self.has_conv_skip = in_channels != out_channels @@ -384,7 +390,7 @@ def __init__(self, in_channels, mid_channels, out_channels, is_last=False): self.group_norm_2 = nn.GroupNorm(1, out_channels) self.gelu_2 = nn.GELU() - def forward(self, hidden_states): + def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor: residual = self.conv_skip(hidden_states) if self.has_conv_skip else hidden_states hidden_states = self.conv_1(hidden_states) @@ -401,7 +407,7 @@ def forward(self, hidden_states): class UNetMidBlock1D(nn.Module): - def __init__(self, mid_channels, in_channels, out_channels=None): + def __init__(self, mid_channels: int, in_channels: int, out_channels: Optional[int] = None): super().__init__() out_channels = in_channels if out_channels is None else out_channels @@ -429,7 +435,7 @@ def __init__(self, mid_channels, in_channels, out_channels=None): self.attentions = nn.ModuleList(attentions) self.resnets = nn.ModuleList(resnets) - def forward(self, hidden_states, temb=None): + def forward(self, hidden_states: torch.FloatTensor, 
temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor: hidden_states = self.down(hidden_states) for attn, resnet in zip(self.attentions, self.resnets): hidden_states = resnet(hidden_states) @@ -441,7 +447,7 @@ def forward(self, hidden_states, temb=None): class AttnDownBlock1D(nn.Module): - def __init__(self, out_channels, in_channels, mid_channels=None): + def __init__(self, out_channels: int, in_channels: int, mid_channels: Optional[int] = None): super().__init__() mid_channels = out_channels if mid_channels is None else mid_channels @@ -460,7 +466,7 @@ def __init__(self, out_channels, in_channels, mid_channels=None): self.attentions = nn.ModuleList(attentions) self.resnets = nn.ModuleList(resnets) - def forward(self, hidden_states, temb=None): + def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor: hidden_states = self.down(hidden_states) for resnet, attn in zip(self.resnets, self.attentions): @@ -471,7 +477,7 @@ def forward(self, hidden_states, temb=None): class DownBlock1D(nn.Module): - def __init__(self, out_channels, in_channels, mid_channels=None): + def __init__(self, out_channels: int, in_channels: int, mid_channels: Optional[int] = None): super().__init__() mid_channels = out_channels if mid_channels is None else mid_channels @@ -484,7 +490,7 @@ def __init__(self, out_channels, in_channels, mid_channels=None): self.resnets = nn.ModuleList(resnets) - def forward(self, hidden_states, temb=None): + def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor: hidden_states = self.down(hidden_states) for resnet in self.resnets: @@ -494,7 +500,7 @@ def forward(self, hidden_states, temb=None): class DownBlock1DNoSkip(nn.Module): - def __init__(self, out_channels, in_channels, mid_channels=None): + def __init__(self, out_channels: int, in_channels: int, mid_channels: Optional[int] = None): super().__init__() mid_channels = out_channels if mid_channels is None else mid_channels @@ -506,7 +512,7 @@ def __init__(self, out_channels, in_channels, mid_channels=None): self.resnets = nn.ModuleList(resnets) - def forward(self, hidden_states, temb=None): + def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor: hidden_states = torch.cat([hidden_states, temb], dim=1) for resnet in self.resnets: hidden_states = resnet(hidden_states) @@ -515,7 +521,7 @@ def forward(self, hidden_states, temb=None): class AttnUpBlock1D(nn.Module): - def __init__(self, in_channels, out_channels, mid_channels=None): + def __init__(self, in_channels: int, out_channels: int, mid_channels: Optional[int] = None): super().__init__() mid_channels = out_channels if mid_channels is None else mid_channels @@ -534,7 +540,12 @@ def __init__(self, in_channels, out_channels, mid_channels=None): self.resnets = nn.ModuleList(resnets) self.up = Upsample1d(kernel="cubic") - def forward(self, hidden_states, res_hidden_states_tuple, temb=None): + def forward( + self, + hidden_states: torch.FloatTensor, + res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], + temb: Optional[torch.FloatTensor] = None, + ) -> torch.FloatTensor: res_hidden_states = res_hidden_states_tuple[-1] hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) @@ -548,7 +559,7 @@ def forward(self, hidden_states, res_hidden_states_tuple, temb=None): class UpBlock1D(nn.Module): - def __init__(self, in_channels, out_channels, mid_channels=None): + def __init__(self, 
in_channels: int, out_channels: int, mid_channels: Optional[int] = None): super().__init__() mid_channels = in_channels if mid_channels is None else mid_channels @@ -561,7 +572,12 @@ def __init__(self, in_channels, out_channels, mid_channels=None): self.resnets = nn.ModuleList(resnets) self.up = Upsample1d(kernel="cubic") - def forward(self, hidden_states, res_hidden_states_tuple, temb=None): + def forward( + self, + hidden_states: torch.FloatTensor, + res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], + temb: Optional[torch.FloatTensor] = None, + ) -> torch.FloatTensor: res_hidden_states = res_hidden_states_tuple[-1] hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) @@ -574,7 +590,7 @@ def forward(self, hidden_states, res_hidden_states_tuple, temb=None): class UpBlock1DNoSkip(nn.Module): - def __init__(self, in_channels, out_channels, mid_channels=None): + def __init__(self, in_channels: int, out_channels: int, mid_channels: Optional[int] = None): super().__init__() mid_channels = in_channels if mid_channels is None else mid_channels @@ -586,7 +602,12 @@ def __init__(self, in_channels, out_channels, mid_channels=None): self.resnets = nn.ModuleList(resnets) - def forward(self, hidden_states, res_hidden_states_tuple, temb=None): + def forward( + self, + hidden_states: torch.FloatTensor, + res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], + temb: Optional[torch.FloatTensor] = None, + ) -> torch.FloatTensor: res_hidden_states = res_hidden_states_tuple[-1] hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) @@ -596,7 +617,20 @@ def forward(self, hidden_states, res_hidden_states_tuple, temb=None): return hidden_states -def get_down_block(down_block_type, num_layers, in_channels, out_channels, temb_channels, add_downsample): +DownBlockType = Union[DownResnetBlock1D, DownBlock1D, AttnDownBlock1D, DownBlock1DNoSkip] +MidBlockType = Union[MidResTemporalBlock1D, ValueFunctionMidBlock1D, UNetMidBlock1D] +OutBlockType = Union[OutConv1DBlock, OutValueFunctionBlock] +UpBlockType = Union[UpResnetBlock1D, UpBlock1D, AttnUpBlock1D, UpBlock1DNoSkip] + + +def get_down_block( + down_block_type: str, + num_layers: int, + in_channels: int, + out_channels: int, + temb_channels: int, + add_downsample: bool, +) -> DownBlockType: if down_block_type == "DownResnetBlock1D": return DownResnetBlock1D( in_channels=in_channels, @@ -614,7 +648,9 @@ def get_down_block(down_block_type, num_layers, in_channels, out_channels, temb_ raise ValueError(f"{down_block_type} does not exist.") -def get_up_block(up_block_type, num_layers, in_channels, out_channels, temb_channels, add_upsample): +def get_up_block( + up_block_type: str, num_layers: int, in_channels: int, out_channels: int, temb_channels: int, add_upsample: bool +) -> UpBlockType: if up_block_type == "UpResnetBlock1D": return UpResnetBlock1D( in_channels=in_channels, @@ -632,7 +668,15 @@ def get_up_block(up_block_type, num_layers, in_channels, out_channels, temb_chan raise ValueError(f"{up_block_type} does not exist.") -def get_mid_block(mid_block_type, num_layers, in_channels, mid_channels, out_channels, embed_dim, add_downsample): +def get_mid_block( + mid_block_type: str, + num_layers: int, + in_channels: int, + mid_channels: int, + out_channels: int, + embed_dim: int, + add_downsample: bool, +) -> MidBlockType: if mid_block_type == "MidResTemporalBlock1D": return MidResTemporalBlock1D( num_layers=num_layers, @@ -648,7 +692,9 @@ def get_mid_block(mid_block_type, num_layers, in_channels, mid_channels, out_cha raise 
ValueError(f"{mid_block_type} does not exist.") -def get_out_block(*, out_block_type, num_groups_out, embed_dim, out_channels, act_fn, fc_dim): +def get_out_block( + *, out_block_type: str, num_groups_out: int, embed_dim: int, out_channels: int, act_fn: str, fc_dim: int +) -> Optional[OutBlockType]: if out_block_type == "OutConv1DBlock": return OutConv1DBlock(num_groups_out, out_channels, embed_dim, act_fn) elif out_block_type == "ValueFunction": diff --git a/models/unet_2d_blocks.py b/models/unet_2d_blocks.py index cfaedd717bef..e404cef224ff 100644 --- a/models/unet_2d_blocks.py +++ b/models/unet_2d_blocks.py @@ -32,31 +32,31 @@ def get_down_block( - down_block_type, - num_layers, - in_channels, - out_channels, - temb_channels, - add_downsample, - resnet_eps, - resnet_act_fn, - transformer_layers_per_block=1, - num_attention_heads=None, - resnet_groups=None, - cross_attention_dim=None, - downsample_padding=None, - dual_cross_attention=False, - use_linear_projection=False, - only_cross_attention=False, - upcast_attention=False, - resnet_time_scale_shift="default", - attention_type="default", - resnet_skip_time_act=False, - resnet_out_scale_factor=1.0, - cross_attention_norm=None, - attention_head_dim=None, - downsample_type=None, - dropout=0.0, + down_block_type: str, + num_layers: int, + in_channels: int, + out_channels: int, + temb_channels: int, + add_downsample: bool, + resnet_eps: float, + resnet_act_fn: str, + transformer_layers_per_block: int = 1, + num_attention_heads: Optional[int] = None, + resnet_groups: Optional[int] = None, + cross_attention_dim: Optional[int] = None, + downsample_padding: Optional[int] = None, + dual_cross_attention: bool = False, + use_linear_projection: bool = False, + only_cross_attention: bool = False, + upcast_attention: bool = False, + resnet_time_scale_shift: str = "default", + attention_type: str = "default", + resnet_skip_time_act: bool = False, + resnet_out_scale_factor: float = 1.0, + cross_attention_norm: Optional[str] = None, + attention_head_dim: Optional[int] = None, + downsample_type: Optional[str] = None, + dropout: float = 0.0, ): # If attn head dim is not defined, we default it to the number of heads if attention_head_dim is None: @@ -241,33 +241,33 @@ def get_down_block( def get_up_block( - up_block_type, - num_layers, - in_channels, - out_channels, - prev_output_channel, - temb_channels, - add_upsample, - resnet_eps, - resnet_act_fn, - resolution_idx=None, - transformer_layers_per_block=1, - num_attention_heads=None, - resnet_groups=None, - cross_attention_dim=None, - dual_cross_attention=False, - use_linear_projection=False, - only_cross_attention=False, - upcast_attention=False, - resnet_time_scale_shift="default", - attention_type="default", - resnet_skip_time_act=False, - resnet_out_scale_factor=1.0, - cross_attention_norm=None, - attention_head_dim=None, - upsample_type=None, - dropout=0.0, -): + up_block_type: str, + num_layers: int, + in_channels: int, + out_channels: int, + prev_output_channel: int, + temb_channels: int, + add_upsample: bool, + resnet_eps: float, + resnet_act_fn: str, + resolution_idx: Optional[int] = None, + transformer_layers_per_block: int = 1, + num_attention_heads: Optional[int] = None, + resnet_groups: Optional[int] = None, + cross_attention_dim: Optional[int] = None, + dual_cross_attention: bool = False, + use_linear_projection: bool = False, + only_cross_attention: bool = False, + upcast_attention: bool = False, + resnet_time_scale_shift: str = "default", + attention_type: str = "default", + 
resnet_skip_time_act: bool = False, + resnet_out_scale_factor: float = 1.0, + cross_attention_norm: Optional[str] = None, + attention_head_dim: Optional[int] = None, + upsample_type: Optional[str] = None, + dropout: float = 0.0, +) -> nn.Module: # If attn head dim is not defined, we default it to the number of heads if attention_head_dim is None: logger.warn( @@ -498,7 +498,7 @@ def __init__(self, in_channels: int, out_channels: int, act_fn: str): ) self.fuse = nn.ReLU() - def forward(self, x): + def forward(self, x: torch.FloatTensor) -> torch.FloatTensor: return self.fuse(self.conv(x) + self.skip(x)) @@ -546,8 +546,8 @@ def __init__( attn_groups: Optional[int] = None, resnet_pre_norm: bool = True, add_attention: bool = True, - attention_head_dim=1, - output_scale_factor=1.0, + attention_head_dim: int = 1, + output_scale_factor: float = 1.0, ): super().__init__() resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) @@ -617,7 +617,7 @@ def __init__( self.attentions = nn.ModuleList(attentions) self.resnets = nn.ModuleList(resnets) - def forward(self, hidden_states, temb=None): + def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor: hidden_states = self.resnets[0](hidden_states, temb) for attn, resnet in zip(self.attentions, self.resnets[1:]): if attn is not None: @@ -640,13 +640,13 @@ def __init__( resnet_act_fn: str = "swish", resnet_groups: int = 32, resnet_pre_norm: bool = True, - num_attention_heads=1, - output_scale_factor=1.0, - cross_attention_dim=1280, - dual_cross_attention=False, - use_linear_projection=False, - upcast_attention=False, - attention_type="default", + num_attention_heads: int = 1, + output_scale_factor: float = 1.0, + cross_attention_dim: int = 1280, + dual_cross_attention: bool = False, + use_linear_projection: bool = False, + upcast_attention: bool = False, + attention_type: str = "default", ): super().__init__() @@ -785,12 +785,12 @@ def __init__( resnet_act_fn: str = "swish", resnet_groups: int = 32, resnet_pre_norm: bool = True, - attention_head_dim=1, - output_scale_factor=1.0, - cross_attention_dim=1280, - skip_time_act=False, - only_cross_attention=False, - cross_attention_norm=None, + attention_head_dim: int = 1, + output_scale_factor: float = 1.0, + cross_attention_dim: int = 1280, + skip_time_act: bool = False, + only_cross_attention: bool = False, + cross_attention_norm: Optional[str] = None, ): super().__init__() @@ -866,7 +866,7 @@ def forward( attention_mask: Optional[torch.FloatTensor] = None, cross_attention_kwargs: Optional[Dict[str, Any]] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - ): + ) -> torch.FloatTensor: cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} lora_scale = cross_attention_kwargs.get("scale", 1.0) @@ -910,10 +910,10 @@ def __init__( resnet_act_fn: str = "swish", resnet_groups: int = 32, resnet_pre_norm: bool = True, - attention_head_dim=1, - output_scale_factor=1.0, - downsample_padding=1, - downsample_type="conv", + attention_head_dim: int = 1, + output_scale_factor: float = 1.0, + downsample_padding: int = 1, + downsample_type: str = "conv", ): super().__init__() resnets = [] @@ -989,7 +989,13 @@ def __init__( else: self.downsamplers = None - def forward(self, hidden_states, temb=None, upsample_size=None, cross_attention_kwargs=None): + def forward( + self, + hidden_states: torch.FloatTensor, + temb: Optional[torch.FloatTensor] = None, + upsample_size: Optional[int] 
= None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]: cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} lora_scale = cross_attention_kwargs.get("scale", 1.0) @@ -1028,16 +1034,16 @@ def __init__( resnet_act_fn: str = "swish", resnet_groups: int = 32, resnet_pre_norm: bool = True, - num_attention_heads=1, - cross_attention_dim=1280, - output_scale_factor=1.0, - downsample_padding=1, - add_downsample=True, - dual_cross_attention=False, - use_linear_projection=False, - only_cross_attention=False, - upcast_attention=False, - attention_type="default", + num_attention_heads: int = 1, + cross_attention_dim: int = 1280, + output_scale_factor: float = 1.0, + downsample_padding: int = 1, + add_downsample: bool = True, + dual_cross_attention: bool = False, + use_linear_projection: bool = False, + only_cross_attention: bool = False, + upcast_attention: bool = False, + attention_type: str = "default", ): super().__init__() resnets = [] @@ -1114,8 +1120,8 @@ def forward( attention_mask: Optional[torch.FloatTensor] = None, cross_attention_kwargs: Optional[Dict[str, Any]] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - additional_residuals=None, - ): + additional_residuals: Optional[torch.FloatTensor] = None, + ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]: output_states = () lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 @@ -1188,9 +1194,9 @@ def __init__( resnet_act_fn: str = "swish", resnet_groups: int = 32, resnet_pre_norm: bool = True, - output_scale_factor=1.0, - add_downsample=True, - downsample_padding=1, + output_scale_factor: float = 1.0, + add_downsample: bool = True, + downsample_padding: int = 1, ): super().__init__() resnets = [] @@ -1227,7 +1233,9 @@ def __init__( self.gradient_checkpointing = False - def forward(self, hidden_states, temb=None, scale: float = 1.0): + def forward( + self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, scale: float = 1.0 + ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]: output_states = () for resnet in self.resnets: @@ -1273,9 +1281,9 @@ def __init__( resnet_act_fn: str = "swish", resnet_groups: int = 32, resnet_pre_norm: bool = True, - output_scale_factor=1.0, - add_downsample=True, - downsample_padding=1, + output_scale_factor: float = 1.0, + add_downsample: bool = True, + downsample_padding: int = 1, ): super().__init__() resnets = [] @@ -1310,7 +1318,7 @@ def __init__( else: self.downsamplers = None - def forward(self, hidden_states, scale: float = 1.0): + def forward(self, hidden_states: torch.FloatTensor, scale: float = 1.0) -> torch.FloatTensor: for resnet in self.resnets: hidden_states = resnet(hidden_states, temb=None, scale=scale) @@ -1333,10 +1341,10 @@ def __init__( resnet_act_fn: str = "swish", resnet_groups: int = 32, resnet_pre_norm: bool = True, - attention_head_dim=1, - output_scale_factor=1.0, - add_downsample=True, - downsample_padding=1, + attention_head_dim: int = 1, + output_scale_factor: float = 1.0, + add_downsample: bool = True, + downsample_padding: int = 1, ): super().__init__() resnets = [] @@ -1393,7 +1401,7 @@ def __init__( else: self.downsamplers = None - def forward(self, hidden_states, scale: float = 1.0): + def forward(self, hidden_states: torch.FloatTensor, scale: float = 1.0) -> torch.FloatTensor: for resnet, attn in zip(self.resnets, self.attentions): hidden_states = 
resnet(hidden_states, temb=None, scale=scale) cross_attention_kwargs = {"scale": scale} @@ -1418,9 +1426,9 @@ def __init__( resnet_time_scale_shift: str = "default", resnet_act_fn: str = "swish", resnet_pre_norm: bool = True, - attention_head_dim=1, - output_scale_factor=np.sqrt(2.0), - add_downsample=True, + attention_head_dim: int = 1, + output_scale_factor: float = np.sqrt(2.0), + add_downsample: bool = True, ): super().__init__() self.attentions = nn.ModuleList([]) @@ -1487,7 +1495,13 @@ def __init__( self.downsamplers = None self.skip_conv = None - def forward(self, hidden_states, temb=None, skip_sample=None, scale: float = 1.0): + def forward( + self, + hidden_states: torch.FloatTensor, + temb: Optional[torch.FloatTensor] = None, + skip_sample: Optional[torch.FloatTensor] = None, + scale: float = 1.0, + ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...], torch.FloatTensor]: output_states = () for resnet, attn in zip(self.resnets, self.attentions): @@ -1520,9 +1534,9 @@ def __init__( resnet_time_scale_shift: str = "default", resnet_act_fn: str = "swish", resnet_pre_norm: bool = True, - output_scale_factor=np.sqrt(2.0), - add_downsample=True, - downsample_padding=1, + output_scale_factor: float = np.sqrt(2.0), + add_downsample: bool = True, + downsample_padding: int = 1, ): super().__init__() self.resnets = nn.ModuleList([]) @@ -1568,7 +1582,13 @@ def __init__( self.downsamplers = None self.skip_conv = None - def forward(self, hidden_states, temb=None, skip_sample=None, scale: float = 1.0): + def forward( + self, + hidden_states: torch.FloatTensor, + temb: Optional[torch.FloatTensor] = None, + skip_sample: Optional[torch.FloatTensor] = None, + scale: float = 1.0, + ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...], torch.FloatTensor]: output_states = () for resnet in self.resnets: @@ -1600,9 +1620,9 @@ def __init__( resnet_act_fn: str = "swish", resnet_groups: int = 32, resnet_pre_norm: bool = True, - output_scale_factor=1.0, - add_downsample=True, - skip_time_act=False, + output_scale_factor: float = 1.0, + add_downsample: bool = True, + skip_time_act: bool = False, ): super().__init__() resnets = [] @@ -1651,7 +1671,9 @@ def __init__( self.gradient_checkpointing = False - def forward(self, hidden_states, temb=None, scale: float = 1.0): + def forward( + self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, scale: float = 1.0 + ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]: output_states = () for resnet in self.resnets: @@ -1698,13 +1720,13 @@ def __init__( resnet_act_fn: str = "swish", resnet_groups: int = 32, resnet_pre_norm: bool = True, - attention_head_dim=1, - cross_attention_dim=1280, - output_scale_factor=1.0, - add_downsample=True, - skip_time_act=False, - only_cross_attention=False, - cross_attention_norm=None, + attention_head_dim: int = 1, + cross_attention_dim: int = 1280, + output_scale_factor: float = 1.0, + add_downsample: bool = True, + skip_time_act: bool = False, + only_cross_attention: bool = False, + cross_attention_norm: Optional[str] = None, ): super().__init__() @@ -1788,7 +1810,7 @@ def forward( attention_mask: Optional[torch.FloatTensor] = None, cross_attention_kwargs: Optional[Dict[str, Any]] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - ): + ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]: output_states = () cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} @@ -1856,7 +1878,7 @@ def __init__( resnet_eps: float = 1e-5, 
resnet_act_fn: str = "gelu", resnet_group_size: int = 32, - add_downsample=False, + add_downsample: bool = False, ): super().__init__() resnets = [] @@ -1891,7 +1913,9 @@ def __init__( self.gradient_checkpointing = False - def forward(self, hidden_states, temb=None, scale: float = 1.0): + def forward( + self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, scale: float = 1.0 + ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]: output_states = () for resnet in self.resnets: @@ -1933,7 +1957,7 @@ def __init__( dropout: float = 0.0, num_layers: int = 4, resnet_group_size: int = 32, - add_downsample=True, + add_downsample: bool = True, attention_head_dim: int = 64, add_self_attention: bool = False, resnet_eps: float = 1e-5, @@ -1996,7 +2020,7 @@ def forward( attention_mask: Optional[torch.FloatTensor] = None, cross_attention_kwargs: Optional[Dict[str, Any]] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - ): + ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]: output_states = () lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 @@ -2065,9 +2089,9 @@ def __init__( resnet_act_fn: str = "swish", resnet_groups: int = 32, resnet_pre_norm: bool = True, - attention_head_dim=1, - output_scale_factor=1.0, - upsample_type="conv", + attention_head_dim: int = 1, + output_scale_factor: float = 1.0, + upsample_type: str = "conv", ): super().__init__() resnets = [] @@ -2142,7 +2166,14 @@ def __init__( self.resolution_idx = resolution_idx - def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, scale: float = 1.0): + def forward( + self, + hidden_states: torch.FloatTensor, + res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], + temb: Optional[torch.FloatTensor] = None, + upsample_size: Optional[int] = None, + scale: float = 1.0, + ) -> torch.FloatTensor: for resnet, attn in zip(self.resnets, self.attentions): # pop res hidden states res_hidden_states = res_hidden_states_tuple[-1] @@ -2170,7 +2201,7 @@ def __init__( out_channels: int, prev_output_channel: int, temb_channels: int, - resolution_idx: int = None, + resolution_idx: Optional[int] = None, dropout: float = 0.0, num_layers: int = 1, transformer_layers_per_block: Union[int, Tuple[int]] = 1, @@ -2179,15 +2210,15 @@ def __init__( resnet_act_fn: str = "swish", resnet_groups: int = 32, resnet_pre_norm: bool = True, - num_attention_heads=1, - cross_attention_dim=1280, - output_scale_factor=1.0, - add_upsample=True, - dual_cross_attention=False, - use_linear_projection=False, - only_cross_attention=False, - upcast_attention=False, - attention_type="default", + num_attention_heads: int = 1, + cross_attention_dim: int = 1280, + output_scale_factor: float = 1.0, + add_upsample: bool = True, + dual_cross_attention: bool = False, + use_linear_projection: bool = False, + only_cross_attention: bool = False, + upcast_attention: bool = False, + attention_type: str = "default", ): super().__init__() resnets = [] @@ -2264,7 +2295,7 @@ def forward( upsample_size: Optional[int] = None, attention_mask: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - ): + ) -> torch.FloatTensor: lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 is_freeu_enabled = ( getattr(self, "s1", None) @@ -2343,7 +2374,7 @@ def __init__( prev_output_channel: int, out_channels: int, temb_channels: int, - resolution_idx: int = None, + 
resolution_idx: Optional[int] = None, dropout: float = 0.0, num_layers: int = 1, resnet_eps: float = 1e-6, @@ -2351,8 +2382,8 @@ def __init__( resnet_act_fn: str = "swish", resnet_groups: int = 32, resnet_pre_norm: bool = True, - output_scale_factor=1.0, - add_upsample=True, + output_scale_factor: float = 1.0, + add_upsample: bool = True, ): super().__init__() resnets = [] @@ -2386,7 +2417,14 @@ def __init__( self.gradient_checkpointing = False self.resolution_idx = resolution_idx - def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, scale: float = 1.0): + def forward( + self, + hidden_states: torch.FloatTensor, + res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], + temb: Optional[torch.FloatTensor] = None, + upsample_size: Optional[int] = None, + scale: float = 1.0, + ) -> torch.FloatTensor: is_freeu_enabled = ( getattr(self, "s1", None) and getattr(self, "s2", None) @@ -2444,7 +2482,7 @@ def __init__( self, in_channels: int, out_channels: int, - resolution_idx: int = None, + resolution_idx: Optional[int] = None, dropout: float = 0.0, num_layers: int = 1, resnet_eps: float = 1e-6, @@ -2452,9 +2490,9 @@ def __init__( resnet_act_fn: str = "swish", resnet_groups: int = 32, resnet_pre_norm: bool = True, - output_scale_factor=1.0, - add_upsample=True, - temb_channels=None, + output_scale_factor: float = 1.0, + add_upsample: bool = True, + temb_channels: Optional[int] = None, ): super().__init__() resnets = [] @@ -2486,7 +2524,9 @@ def __init__( self.resolution_idx = resolution_idx - def forward(self, hidden_states, temb=None, scale: float = 1.0): + def forward( + self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, scale: float = 1.0 + ) -> torch.FloatTensor: for resnet in self.resnets: hidden_states = resnet(hidden_states, temb=temb, scale=scale) @@ -2502,7 +2542,7 @@ def __init__( self, in_channels: int, out_channels: int, - resolution_idx: int = None, + resolution_idx: Optional[int] = None, dropout: float = 0.0, num_layers: int = 1, resnet_eps: float = 1e-6, @@ -2510,10 +2550,10 @@ def __init__( resnet_act_fn: str = "swish", resnet_groups: int = 32, resnet_pre_norm: bool = True, - attention_head_dim=1, - output_scale_factor=1.0, - add_upsample=True, - temb_channels=None, + attention_head_dim: int = 1, + output_scale_factor: float = 1.0, + add_upsample: bool = True, + temb_channels: Optional[int] = None, ): super().__init__() resnets = [] @@ -2568,7 +2608,9 @@ def __init__( self.resolution_idx = resolution_idx - def forward(self, hidden_states, temb=None, scale: float = 1.0): + def forward( + self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, scale: float = 1.0 + ) -> torch.FloatTensor: for resnet, attn in zip(self.resnets, self.attentions): hidden_states = resnet(hidden_states, temb=temb, scale=scale) cross_attention_kwargs = {"scale": scale} @@ -2588,16 +2630,16 @@ def __init__( prev_output_channel: int, out_channels: int, temb_channels: int, - resolution_idx: int = None, + resolution_idx: Optional[int] = None, dropout: float = 0.0, num_layers: int = 1, resnet_eps: float = 1e-6, resnet_time_scale_shift: str = "default", resnet_act_fn: str = "swish", resnet_pre_norm: bool = True, - attention_head_dim=1, - output_scale_factor=np.sqrt(2.0), - add_upsample=True, + attention_head_dim: int = 1, + output_scale_factor: float = np.sqrt(2.0), + add_upsample: bool = True, ): super().__init__() self.attentions = nn.ModuleList([]) @@ -2675,7 +2717,14 @@ def __init__( self.resolution_idx = 
resolution_idx - def forward(self, hidden_states, res_hidden_states_tuple, temb=None, skip_sample=None, scale: float = 1.0): + def forward( + self, + hidden_states: torch.FloatTensor, + res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], + temb: Optional[torch.FloatTensor] = None, + skip_sample=None, + scale: float = 1.0, + ) -> Tuple[torch.FloatTensor, torch.FloatTensor]: for resnet in self.resnets: # pop res hidden states res_hidden_states = res_hidden_states_tuple[-1] @@ -2711,16 +2760,16 @@ def __init__( prev_output_channel: int, out_channels: int, temb_channels: int, - resolution_idx: int = None, + resolution_idx: Optional[int] = None, dropout: float = 0.0, num_layers: int = 1, resnet_eps: float = 1e-6, resnet_time_scale_shift: str = "default", resnet_act_fn: str = "swish", resnet_pre_norm: bool = True, - output_scale_factor=np.sqrt(2.0), - add_upsample=True, - upsample_padding=1, + output_scale_factor: float = np.sqrt(2.0), + add_upsample: bool = True, + upsample_padding: int = 1, ): super().__init__() self.resnets = nn.ModuleList([]) @@ -2776,7 +2825,14 @@ def __init__( self.resolution_idx = resolution_idx - def forward(self, hidden_states, res_hidden_states_tuple, temb=None, skip_sample=None, scale: float = 1.0): + def forward( + self, + hidden_states: torch.FloatTensor, + res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], + temb: Optional[torch.FloatTensor] = None, + skip_sample=None, + scale: float = 1.0, + ) -> Tuple[torch.FloatTensor, torch.FloatTensor]: for resnet in self.resnets: # pop res hidden states res_hidden_states = res_hidden_states_tuple[-1] @@ -2809,7 +2865,7 @@ def __init__( prev_output_channel: int, out_channels: int, temb_channels: int, - resolution_idx: int = None, + resolution_idx: Optional[int] = None, dropout: float = 0.0, num_layers: int = 1, resnet_eps: float = 1e-6, @@ -2817,9 +2873,9 @@ def __init__( resnet_act_fn: str = "swish", resnet_groups: int = 32, resnet_pre_norm: bool = True, - output_scale_factor=1.0, - add_upsample=True, - skip_time_act=False, + output_scale_factor: float = 1.0, + add_upsample: bool = True, + skip_time_act: bool = False, ): super().__init__() resnets = [] @@ -2871,7 +2927,14 @@ def __init__( self.gradient_checkpointing = False self.resolution_idx = resolution_idx - def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, scale: float = 1.0): + def forward( + self, + hidden_states: torch.FloatTensor, + res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], + temb: Optional[torch.FloatTensor] = None, + upsample_size: Optional[int] = None, + scale: float = 1.0, + ) -> torch.FloatTensor: for resnet in self.resnets: # pop res hidden states res_hidden_states = res_hidden_states_tuple[-1] @@ -2911,7 +2974,7 @@ def __init__( out_channels: int, prev_output_channel: int, temb_channels: int, - resolution_idx: int = None, + resolution_idx: Optional[int] = None, dropout: float = 0.0, num_layers: int = 1, resnet_eps: float = 1e-6, @@ -2919,13 +2982,13 @@ def __init__( resnet_act_fn: str = "swish", resnet_groups: int = 32, resnet_pre_norm: bool = True, - attention_head_dim=1, - cross_attention_dim=1280, - output_scale_factor=1.0, - add_upsample=True, - skip_time_act=False, - only_cross_attention=False, - cross_attention_norm=None, + attention_head_dim: int = 1, + cross_attention_dim: int = 1280, + output_scale_factor: float = 1.0, + add_upsample: bool = True, + skip_time_act: bool = False, + only_cross_attention: bool = False, + cross_attention_norm: Optional[str] = None, ): super().__init__() resnets 
= [] @@ -3013,7 +3076,7 @@ def forward( attention_mask: Optional[torch.FloatTensor] = None, cross_attention_kwargs: Optional[Dict[str, Any]] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - ): + ) -> torch.FloatTensor: cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} lora_scale = cross_attention_kwargs.get("scale", 1.0) @@ -3082,7 +3145,7 @@ def __init__( resnet_eps: float = 1e-5, resnet_act_fn: str = "gelu", resnet_group_size: Optional[int] = 32, - add_upsample=True, + add_upsample: bool = True, ): super().__init__() resnets = [] @@ -3120,7 +3183,14 @@ def __init__( self.gradient_checkpointing = False self.resolution_idx = resolution_idx - def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, scale: float = 1.0): + def forward( + self, + hidden_states: torch.FloatTensor, + res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], + temb: Optional[torch.FloatTensor] = None, + upsample_size: Optional[int] = None, + scale: float = 1.0, + ) -> torch.FloatTensor: res_hidden_states_tuple = res_hidden_states_tuple[-1] if res_hidden_states_tuple is not None: hidden_states = torch.cat([hidden_states, res_hidden_states_tuple], dim=1) @@ -3164,7 +3234,7 @@ def __init__( resnet_eps: float = 1e-5, resnet_act_fn: str = "gelu", resnet_group_size: int = 32, - attention_head_dim=1, # attention dim_head + attention_head_dim: int = 1, # attention dim_head cross_attention_dim: int = 768, add_upsample: bool = True, upcast_attention: bool = False, @@ -3248,7 +3318,7 @@ def forward( upsample_size: Optional[int] = None, attention_mask: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - ): + ) -> torch.FloatTensor: res_hidden_states_tuple = res_hidden_states_tuple[-1] if res_hidden_states_tuple is not None: hidden_states = torch.cat([hidden_states, res_hidden_states_tuple], dim=1) @@ -3310,11 +3380,18 @@ class KAttentionBlock(nn.Module): attention_head_dim (`int`): The number of channels in each head. dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention. - activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. - num_embeds_ada_norm (: - obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`. - attention_bias (: - obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter. + attention_bias (`bool`, *optional*, defaults to `False`): + Configure if the attention layers should contain a bias parameter. + upcast_attention (`bool`, *optional*, defaults to `False`): + Set to `True` to upcast the attention computation to `float32`. + temb_channels (`int`, *optional*, defaults to 768): + The number of channels in the token embedding. + add_self_attention (`bool`, *optional*, defaults to `False`): + Set to `True` to add self-attention to the block. + cross_attention_norm (`str`, *optional*, defaults to `None`): + The type of normalization to use for the cross attention. Can be `None`, `layer_norm`, or `group_norm`. + group_size (`int`, *optional*, defaults to 32): + The number of groups to separate the channels into for group normalization. 
""" def __init__( @@ -3360,10 +3437,10 @@ def __init__( cross_attention_norm=cross_attention_norm, ) - def _to_3d(self, hidden_states, height, weight): + def _to_3d(self, hidden_states: torch.FloatTensor, height: int, weight: int) -> torch.FloatTensor: return hidden_states.permute(0, 2, 3, 1).reshape(hidden_states.shape[0], height * weight, -1) - def _to_4d(self, hidden_states, height, weight): + def _to_4d(self, hidden_states: torch.FloatTensor, height: int, weight: int) -> torch.FloatTensor: return hidden_states.permute(0, 2, 1).reshape(hidden_states.shape[0], -1, height, weight) def forward( @@ -3376,7 +3453,7 @@ def forward( attention_mask: Optional[torch.FloatTensor] = None, cross_attention_kwargs: Optional[Dict[str, Any]] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - ): + ) -> torch.FloatTensor: cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} # 1. Self-Attention diff --git a/models/vae.py b/models/vae.py index 36983eefc01f..da08bc360942 100644 --- a/models/vae.py +++ b/models/vae.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from dataclasses import dataclass -from typing import Optional +from typing import Optional, Tuple import numpy as np import torch @@ -27,7 +27,7 @@ @dataclass class DecoderOutput(BaseOutput): - """ + r""" Output of decoding method. Args: @@ -39,16 +39,39 @@ class DecoderOutput(BaseOutput): class Encoder(nn.Module): + r""" + The `Encoder` layer of a variational autoencoder that encodes its input into a latent representation. + + Args: + in_channels (`int`, *optional*, defaults to 3): + The number of input channels. + out_channels (`int`, *optional*, defaults to 3): + The number of output channels. + down_block_types (`Tuple[str, ...]`, *optional*, defaults to `("DownEncoderBlock2D",)`): + The types of down blocks to use. See `~diffusers.models.unet_2d_blocks.get_down_block` for available + options. + block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`): + The number of output channels for each block. + layers_per_block (`int`, *optional*, defaults to 2): + The number of layers per block. + norm_num_groups (`int`, *optional*, defaults to 32): + The number of groups for normalization. + act_fn (`str`, *optional*, defaults to `"silu"`): + The activation function to use. See `~diffusers.models.activations.get_activation` for available options. + double_z (`bool`, *optional*, defaults to `True`): + Whether to double the number of output channels for the last block. + """ + def __init__( self, - in_channels=3, - out_channels=3, - down_block_types=("DownEncoderBlock2D",), - block_out_channels=(64,), - layers_per_block=2, - norm_num_groups=32, - act_fn="silu", - double_z=True, + in_channels: int = 3, + out_channels: int = 3, + down_block_types: Tuple[str, ...] = ("DownEncoderBlock2D",), + block_out_channels: Tuple[int, ...] 
= (64,), + layers_per_block: int = 2, + norm_num_groups: int = 32, + act_fn: str = "silu", + double_z: bool = True, ): super().__init__() self.layers_per_block = layers_per_block @@ -107,7 +130,8 @@ def __init__( self.gradient_checkpointing = False - def forward(self, x): + def forward(self, x: torch.FloatTensor) -> torch.FloatTensor: + r"""The forward method of the `Encoder` class.""" sample = x sample = self.conv_in(sample) @@ -152,16 +176,38 @@ def custom_forward(*inputs): class Decoder(nn.Module): + r""" + The `Decoder` layer of a variational autoencoder that decodes its latent representation into an output sample. + + Args: + in_channels (`int`, *optional*, defaults to 3): + The number of input channels. + out_channels (`int`, *optional*, defaults to 3): + The number of output channels. + up_block_types (`Tuple[str, ...]`, *optional*, defaults to `("UpDecoderBlock2D",)`): + The types of up blocks to use. See `~diffusers.models.unet_2d_blocks.get_up_block` for available options. + block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`): + The number of output channels for each block. + layers_per_block (`int`, *optional*, defaults to 2): + The number of layers per block. + norm_num_groups (`int`, *optional*, defaults to 32): + The number of groups for normalization. + act_fn (`str`, *optional*, defaults to `"silu"`): + The activation function to use. See `~diffusers.models.activations.get_activation` for available options. + norm_type (`str`, *optional*, defaults to `"group"`): + The normalization type to use. Can be either `"group"` or `"spatial"`. + """ + def __init__( self, - in_channels=3, - out_channels=3, - up_block_types=("UpDecoderBlock2D",), - block_out_channels=(64,), - layers_per_block=2, - norm_num_groups=32, - act_fn="silu", - norm_type="group", # group, spatial + in_channels: int = 3, + out_channels: int = 3, + up_block_types: Tuple[str, ...] = ("UpDecoderBlock2D",), + block_out_channels: Tuple[int, ...] = (64,), + layers_per_block: int = 2, + norm_num_groups: int = 32, + act_fn: str = "silu", + norm_type: str = "group", # group, spatial ): super().__init__() self.layers_per_block = layers_per_block @@ -227,7 +273,8 @@ def __init__( self.gradient_checkpointing = False - def forward(self, z, latent_embeds=None): + def forward(self, z: torch.FloatTensor, latent_embeds: Optional[torch.FloatTensor] = None) -> torch.FloatTensor: + r"""The forward method of the `Decoder` class.""" sample = z sample = self.conv_in(sample) @@ -283,6 +330,16 @@ def custom_forward(*inputs): class UpSample(nn.Module): + r""" + The `UpSample` layer of a variational autoencoder that upsamples its input. + + Args: + in_channels (`int`, *optional*, defaults to 3): + The number of input channels. + out_channels (`int`, *optional*, defaults to 3): + The number of output channels. 
+ """ + def __init__( self, in_channels: int, @@ -294,6 +351,7 @@ def __init__( self.deconv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1) def forward(self, x: torch.FloatTensor) -> torch.FloatTensor: + r"""The forward method of the `UpSample` class.""" x = torch.relu(x) x = self.deconv(x) return x @@ -342,6 +400,7 @@ def __init__( self.layers = nn.Sequential(*layers) def forward(self, x: torch.FloatTensor, mask=None) -> torch.FloatTensor: + r"""The forward method of the `MaskConditionEncoder` class.""" out = {} for l in range(len(self.layers)): layer = self.layers[l] @@ -352,19 +411,38 @@ def forward(self, x: torch.FloatTensor, mask=None) -> torch.FloatTensor: class MaskConditionDecoder(nn.Module): - """The `MaskConditionDecoder` should be used in combination with [`AsymmetricAutoencoderKL`] to enhance the model's - decoder with a conditioner on the mask and masked image.""" + r"""The `MaskConditionDecoder` should be used in combination with [`AsymmetricAutoencoderKL`] to enhance the model's + decoder with a conditioner on the mask and masked image. + + Args: + in_channels (`int`, *optional*, defaults to 3): + The number of input channels. + out_channels (`int`, *optional*, defaults to 3): + The number of output channels. + up_block_types (`Tuple[str, ...]`, *optional*, defaults to `("UpDecoderBlock2D",)`): + The types of up blocks to use. See `~diffusers.models.unet_2d_blocks.get_up_block` for available options. + block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`): + The number of output channels for each block. + layers_per_block (`int`, *optional*, defaults to 2): + The number of layers per block. + norm_num_groups (`int`, *optional*, defaults to 32): + The number of groups for normalization. + act_fn (`str`, *optional*, defaults to `"silu"`): + The activation function to use. See `~diffusers.models.activations.get_activation` for available options. + norm_type (`str`, *optional*, defaults to `"group"`): + The normalization type to use. Can be either `"group"` or `"spatial"`. + """ def __init__( self, - in_channels=3, - out_channels=3, - up_block_types=("UpDecoderBlock2D",), - block_out_channels=(64,), - layers_per_block=2, - norm_num_groups=32, - act_fn="silu", - norm_type="group", # group, spatial + in_channels: int = 3, + out_channels: int = 3, + up_block_types: Tuple[str, ...] = ("UpDecoderBlock2D",), + block_out_channels: Tuple[int, ...] = (64,), + layers_per_block: int = 2, + norm_num_groups: int = 32, + act_fn: str = "silu", + norm_type: str = "group", # group, spatial ): super().__init__() self.layers_per_block = layers_per_block @@ -437,7 +515,14 @@ def __init__( self.gradient_checkpointing = False - def forward(self, z, image=None, mask=None, latent_embeds=None): + def forward( + self, + z: torch.FloatTensor, + image: Optional[torch.FloatTensor] = None, + mask: Optional[torch.FloatTensor] = None, + latent_embeds: Optional[torch.FloatTensor] = None, + ) -> torch.FloatTensor: + r"""The forward method of the `MaskConditionDecoder` class.""" sample = z sample = self.conv_in(sample) @@ -539,7 +624,14 @@ class VectorQuantizer(nn.Module): # backwards compatibility we use the buggy version by default, but you can # specify legacy=False to fix it. 
def __init__( - self, n_e, vq_embed_dim, beta, remap=None, unknown_index="random", sane_index_shape=False, legacy=True + self, + n_e: int, + vq_embed_dim: int, + beta: float, + remap=None, + unknown_index: str = "random", + sane_index_shape: bool = False, + legacy: bool = True, ): super().__init__() self.n_e = n_e @@ -553,6 +645,7 @@ def __init__( self.remap = remap if self.remap is not None: self.register_buffer("used", torch.tensor(np.load(self.remap))) + self.used: torch.Tensor self.re_embed = self.used.shape[0] self.unknown_index = unknown_index # "random" or "extra" or integer if self.unknown_index == "extra": @@ -567,7 +660,7 @@ def __init__( self.sane_index_shape = sane_index_shape - def remap_to_used(self, inds): + def remap_to_used(self, inds: torch.LongTensor) -> torch.LongTensor: ishape = inds.shape assert len(ishape) > 1 inds = inds.reshape(ishape[0], -1) @@ -581,7 +674,7 @@ def remap_to_used(self, inds): new[unknown] = self.unknown_index return new.reshape(ishape) - def unmap_to_all(self, inds): + def unmap_to_all(self, inds: torch.LongTensor) -> torch.LongTensor: ishape = inds.shape assert len(ishape) > 1 inds = inds.reshape(ishape[0], -1) @@ -591,7 +684,7 @@ def unmap_to_all(self, inds): back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds) return back.reshape(ishape) - def forward(self, z): + def forward(self, z: torch.FloatTensor) -> Tuple[torch.FloatTensor, torch.FloatTensor, Tuple]: # reshape z -> (batch, height, width, channel) and flatten z = z.permute(0, 2, 3, 1).contiguous() z_flattened = z.view(-1, self.vq_embed_dim) @@ -610,7 +703,7 @@ def forward(self, z): loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * torch.mean((z_q - z.detach()) ** 2) # preserve gradients - z_q = z + (z_q - z).detach() + z_q: torch.FloatTensor = z + (z_q - z).detach() # reshape back to match original input shape z_q = z_q.permute(0, 3, 1, 2).contiguous() @@ -625,7 +718,7 @@ def forward(self, z): return z_q, loss, (perplexity, min_encodings, min_encoding_indices) - def get_codebook_entry(self, indices, shape): + def get_codebook_entry(self, indices: torch.LongTensor, shape: Tuple[int, ...]) -> torch.FloatTensor: # shape specifying (batch, height, width, channel) if self.remap is not None: indices = indices.reshape(shape[0], -1) # add batch axis @@ -633,7 +726,7 @@ def get_codebook_entry(self, indices, shape): indices = indices.reshape(-1) # flatten again # get quantized latent vectors - z_q = self.embedding(indices) + z_q: torch.FloatTensor = self.embedding(indices) if shape is not None: z_q = z_q.view(shape) @@ -644,7 +737,7 @@ def get_codebook_entry(self, indices, shape): class DiagonalGaussianDistribution(object): - def __init__(self, parameters, deterministic=False): + def __init__(self, parameters: torch.Tensor, deterministic: bool = False): self.parameters = parameters self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) self.logvar = torch.clamp(self.logvar, -30.0, 20.0) @@ -664,7 +757,7 @@ def sample(self, generator: Optional[torch.Generator] = None) -> torch.FloatTens x = self.mean + self.std * sample return x - def kl(self, other=None): + def kl(self, other: "DiagonalGaussianDistribution" = None) -> torch.Tensor: if self.deterministic: return torch.Tensor([0.0]) else: @@ -680,23 +773,40 @@ def kl(self, other=None): dim=[1, 2, 3], ) - def nll(self, sample, dims=[1, 2, 3]): + def nll(self, sample: torch.Tensor, dims: Tuple[int, ...] 
= [1, 2, 3]) -> torch.Tensor: if self.deterministic: return torch.Tensor([0.0]) logtwopi = np.log(2.0 * np.pi) return 0.5 * torch.sum(logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, dim=dims) - def mode(self): + def mode(self) -> torch.Tensor: return self.mean class EncoderTiny(nn.Module): + r""" + The `EncoderTiny` layer is a simpler version of the `Encoder` layer. + + Args: + in_channels (`int`): + The number of input channels. + out_channels (`int`): + The number of output channels. + num_blocks (`Tuple[int, ...]`): + Each value of the tuple represents a Conv2d layer followed by `value` number of `AutoencoderTinyBlock`'s to + use. + block_out_channels (`Tuple[int, ...]`): + The number of output channels for each block. + act_fn (`str`): + The activation function to use. See `~diffusers.models.activations.get_activation` for available options. + """ + def __init__( self, in_channels: int, out_channels: int, - num_blocks: int, - block_out_channels: int, + num_blocks: Tuple[int, ...], + block_out_channels: Tuple[int, ...], act_fn: str, ): super().__init__() @@ -718,7 +828,8 @@ def __init__( self.layers = nn.Sequential(*layers) self.gradient_checkpointing = False - def forward(self, x): + def forward(self, x: torch.FloatTensor) -> torch.FloatTensor: + r"""The forward method of the `EncoderTiny` class.""" if self.training and self.gradient_checkpointing: def create_custom_forward(module): @@ -740,12 +851,31 @@ def custom_forward(*inputs): class DecoderTiny(nn.Module): + r""" + The `DecoderTiny` layer is a simpler version of the `Decoder` layer. + + Args: + in_channels (`int`): + The number of input channels. + out_channels (`int`): + The number of output channels. + num_blocks (`Tuple[int, ...]`): + Each value of the tuple represents a Conv2d layer followed by `value` number of `AutoencoderTinyBlock`'s to + use. + block_out_channels (`Tuple[int, ...]`): + The number of output channels for each block. + upsampling_scaling_factor (`int`): + The scaling factor to use for upsampling. + act_fn (`str`): + The activation function to use. See `~diffusers.models.activations.get_activation` for available options. + """ + def __init__( self, in_channels: int, out_channels: int, - num_blocks: int, - block_out_channels: int, + num_blocks: Tuple[int, ...], + block_out_channels: Tuple[int, ...], upsampling_scaling_factor: int, act_fn: str, ): @@ -772,7 +902,8 @@ def __init__( self.layers = nn.Sequential(*layers) self.gradient_checkpointing = False - def forward(self, x): + def forward(self, x: torch.FloatTensor) -> torch.FloatTensor: + r"""The forward method of the `DecoderTiny` class.""" # Clamp. x = torch.tanh(x / 3) * 3 diff --git a/models/vq_model.py b/models/vq_model.py index 0c15300af213..08ad122c3891 100644 --- a/models/vq_model.py +++ b/models/vq_model.py @@ -53,10 +53,12 @@ class VQModel(ModelMixin, ConfigMixin): Tuple of upsample block types. block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`): Tuple of block output channels. + layers_per_block (`int`, *optional*, defaults to `1`): Number of layers per block. act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. latent_channels (`int`, *optional*, defaults to `3`): Number of channels in the latent space. sample_size (`int`, *optional*, defaults to `32`): Sample input size. num_vq_embeddings (`int`, *optional*, defaults to `256`): Number of codebook vectors in the VQ-VAE. + norm_num_groups (`int`, *optional*, defaults to `32`): Number of groups for normalization layers. 
vq_embed_dim (`int`, *optional*): Hidden dim of codebook vectors in the VQ-VAE. scaling_factor (`float`, *optional*, defaults to `0.18215`): The component-wise standard deviation of the trained latent space computed using the first batch of the @@ -65,6 +67,8 @@ class VQModel(ModelMixin, ConfigMixin): diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1 / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper. + norm_type (`str`, *optional*, defaults to `"group"`): + Type of normalization layer to use. Can be one of `"group"` or `"spatial"`. """ @register_to_config @@ -72,9 +76,9 @@ def __init__( self, in_channels: int = 3, out_channels: int = 3, - down_block_types: Tuple[str] = ("DownEncoderBlock2D",), - up_block_types: Tuple[str] = ("UpDecoderBlock2D",), - block_out_channels: Tuple[int] = (64,), + down_block_types: Tuple[str, ...] = ("DownEncoderBlock2D",), + up_block_types: Tuple[str, ...] = ("UpDecoderBlock2D",), + block_out_channels: Tuple[int, ...] = (64,), layers_per_block: int = 1, act_fn: str = "silu", latent_channels: int = 3, diff --git a/pipelines/alt_diffusion/pipeline_alt_diffusion.py b/pipelines/alt_diffusion/pipeline_alt_diffusion.py index 3c24db1fdc94..bf267f0ff1af 100644 --- a/pipelines/alt_diffusion/pipeline_alt_diffusion.py +++ b/pipelines/alt_diffusion/pipeline_alt_diffusion.py @@ -106,7 +106,6 @@ class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraL feature_extractor ([`~transformers.CLIPImageProcessor`]): A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. """ - model_cpu_offload_seq = "text_encoder->unet->vae" _optional_components = ["safety_checker", "feature_extractor"] _exclude_from_cpu_offload = ["safety_checker"] diff --git a/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py b/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py index d9acf9daf2a6..a28c3da2694b 100644 --- a/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py +++ b/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py @@ -134,7 +134,6 @@ class AltDiffusionImg2ImgPipeline( feature_extractor ([`~transformers.CLIPImageProcessor`]): A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. 
""" - model_cpu_offload_seq = "text_encoder->unet->vae" _optional_components = ["safety_checker", "feature_extractor"] _exclude_from_cpu_offload = ["safety_checker"] diff --git a/pipelines/versatile_diffusion/modeling_text_unet.py b/pipelines/versatile_diffusion/modeling_text_unet.py index d936666d6139..9066c47c56c6 100644 --- a/pipelines/versatile_diffusion/modeling_text_unet.py +++ b/pipelines/versatile_diffusion/modeling_text_unet.py @@ -1508,9 +1508,9 @@ def __init__( resnet_act_fn: str = "swish", resnet_groups: int = 32, resnet_pre_norm: bool = True, - output_scale_factor=1.0, - add_downsample=True, - downsample_padding=1, + output_scale_factor: float = 1.0, + add_downsample: bool = True, + downsample_padding: int = 1, ): super().__init__() resnets = [] @@ -1547,7 +1547,9 @@ def __init__( self.gradient_checkpointing = False - def forward(self, hidden_states, temb=None, scale: float = 1.0): + def forward( + self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, scale: float = 1.0 + ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]: output_states = () for resnet in self.resnets: @@ -1596,16 +1598,16 @@ def __init__( resnet_act_fn: str = "swish", resnet_groups: int = 32, resnet_pre_norm: bool = True, - num_attention_heads=1, - cross_attention_dim=1280, - output_scale_factor=1.0, - downsample_padding=1, - add_downsample=True, - dual_cross_attention=False, - use_linear_projection=False, - only_cross_attention=False, - upcast_attention=False, - attention_type="default", + num_attention_heads: int = 1, + cross_attention_dim: int = 1280, + output_scale_factor: float = 1.0, + downsample_padding: int = 1, + add_downsample: bool = True, + dual_cross_attention: bool = False, + use_linear_projection: bool = False, + only_cross_attention: bool = False, + upcast_attention: bool = False, + attention_type: str = "default", ): super().__init__() resnets = [] @@ -1682,8 +1684,8 @@ def forward( attention_mask: Optional[torch.FloatTensor] = None, cross_attention_kwargs: Optional[Dict[str, Any]] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - additional_residuals=None, - ): + additional_residuals: Optional[torch.FloatTensor] = None, + ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]: output_states = () lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 @@ -1751,7 +1753,7 @@ def __init__( prev_output_channel: int, out_channels: int, temb_channels: int, - resolution_idx: int = None, + resolution_idx: Optional[int] = None, dropout: float = 0.0, num_layers: int = 1, resnet_eps: float = 1e-6, @@ -1759,8 +1761,8 @@ def __init__( resnet_act_fn: str = "swish", resnet_groups: int = 32, resnet_pre_norm: bool = True, - output_scale_factor=1.0, - add_upsample=True, + output_scale_factor: float = 1.0, + add_upsample: bool = True, ): super().__init__() resnets = [] @@ -1794,7 +1796,14 @@ def __init__( self.gradient_checkpointing = False self.resolution_idx = resolution_idx - def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, scale: float = 1.0): + def forward( + self, + hidden_states: torch.FloatTensor, + res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], + temb: Optional[torch.FloatTensor] = None, + upsample_size: Optional[int] = None, + scale: float = 1.0, + ) -> torch.FloatTensor: is_freeu_enabled = ( getattr(self, "s1", None) and getattr(self, "s2", None) @@ -1855,7 +1864,7 @@ def __init__( out_channels: int, prev_output_channel: int, temb_channels: int, 
- resolution_idx: int = None, + resolution_idx: Optional[int] = None, dropout: float = 0.0, num_layers: int = 1, transformer_layers_per_block: Union[int, Tuple[int]] = 1, @@ -1864,15 +1873,15 @@ def __init__( resnet_act_fn: str = "swish", resnet_groups: int = 32, resnet_pre_norm: bool = True, - num_attention_heads=1, - cross_attention_dim=1280, - output_scale_factor=1.0, - add_upsample=True, - dual_cross_attention=False, - use_linear_projection=False, - only_cross_attention=False, - upcast_attention=False, - attention_type="default", + num_attention_heads: int = 1, + cross_attention_dim: int = 1280, + output_scale_factor: float = 1.0, + add_upsample: bool = True, + dual_cross_attention: bool = False, + use_linear_projection: bool = False, + only_cross_attention: bool = False, + upcast_attention: bool = False, + attention_type: str = "default", ): super().__init__() resnets = [] @@ -1949,7 +1958,7 @@ def forward( upsample_size: Optional[int] = None, attention_mask: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - ): + ) -> torch.FloatTensor: lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 is_freeu_enabled = ( getattr(self, "s1", None) @@ -2066,8 +2075,8 @@ def __init__( attn_groups: Optional[int] = None, resnet_pre_norm: bool = True, add_attention: bool = True, - attention_head_dim=1, - output_scale_factor=1.0, + attention_head_dim: int = 1, + output_scale_factor: float = 1.0, ): super().__init__() resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) @@ -2138,7 +2147,7 @@ def __init__( self.attentions = nn.ModuleList(attentions) self.resnets = nn.ModuleList(resnets) - def forward(self, hidden_states, temb=None): + def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor: hidden_states = self.resnets[0](hidden_states, temb) for attn, resnet in zip(self.attentions, self.resnets[1:]): if attn is not None: @@ -2162,13 +2171,13 @@ def __init__( resnet_act_fn: str = "swish", resnet_groups: int = 32, resnet_pre_norm: bool = True, - num_attention_heads=1, - output_scale_factor=1.0, - cross_attention_dim=1280, - dual_cross_attention=False, - use_linear_projection=False, - upcast_attention=False, - attention_type="default", + num_attention_heads: int = 1, + output_scale_factor: float = 1.0, + cross_attention_dim: int = 1280, + dual_cross_attention: bool = False, + use_linear_projection: bool = False, + upcast_attention: bool = False, + attention_type: str = "default", ): super().__init__() @@ -2308,12 +2317,12 @@ def __init__( resnet_act_fn: str = "swish", resnet_groups: int = 32, resnet_pre_norm: bool = True, - attention_head_dim=1, - output_scale_factor=1.0, - cross_attention_dim=1280, - skip_time_act=False, - only_cross_attention=False, - cross_attention_norm=None, + attention_head_dim: int = 1, + output_scale_factor: float = 1.0, + cross_attention_dim: int = 1280, + skip_time_act: bool = False, + only_cross_attention: bool = False, + cross_attention_norm: Optional[str] = None, ): super().__init__() @@ -2389,7 +2398,7 @@ def forward( attention_mask: Optional[torch.FloatTensor] = None, cross_attention_kwargs: Optional[Dict[str, Any]] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - ): + ) -> torch.FloatTensor: cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} lora_scale = cross_attention_kwargs.get("scale", 1.0)
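The typed factories earlier in this patch (`get_down_block`, `get_mid_block`, `get_up_block`, `get_out_block`) all follow the same string-dispatch pattern, and the new `Union` aliases such as `DownBlockType` simply enumerate what each branch can return. A stripped-down sketch of that pattern, using stand-in block classes rather than the real diffusers ones:

```python
from typing import Union

import torch.nn as nn


# Stand-in blocks; the real factories dispatch to DownBlock1D, AttnDownBlock1D, etc.
class PlainBlock(nn.Module):
    pass


class AttnBlock(nn.Module):
    pass


BlockType = Union[PlainBlock, AttnBlock]


def get_block(block_type: str) -> BlockType:
    # Mirrors the `raise ValueError(f"{...} does not exist.")` fall-through
    # used by get_down_block / get_mid_block / get_up_block.
    if block_type == "PlainBlock":
        return PlainBlock()
    if block_type == "AttnBlock":
        return AttnBlock()
    raise ValueError(f"{block_type} does not exist.")


block = get_block("AttnBlock")  # isinstance(block, AttnBlock)
```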
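The `kl` signature typed above wraps the closed-form divergence for a diagonal Gaussian; with `other=None` it reduces `0.5 * (mean**2 + var - 1 - logvar)` over the non-batch dimensions, exactly as in `DiagonalGaussianDistribution.kl`. A self-contained numerical check, with illustrative tensor shapes:

```python
import torch

# The parameters tensor concatenates [mean, logvar] along the channel dim,
# matching how DiagonalGaussianDistribution splits its input.
parameters = torch.randn(2, 8, 4, 4)
mean, logvar = torch.chunk(parameters, 2, dim=1)
logvar = torch.clamp(logvar, -30.0, 20.0)
var = logvar.exp()

kl = 0.5 * torch.sum(mean.pow(2) + var - 1.0 - logvar, dim=[1, 2, 3])
print(kl.shape)  # torch.Size([2]) -- one KL value per batch element
```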
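`DecoderTiny.forward`, typed above, opens with `x = torch.tanh(x / 3) * 3`, a soft clamp that confines latents to (-3, 3) while staying differentiable, unlike a hard `torch.clamp`. A quick comparison:

```python
import torch

x = torch.tensor([-10.0, -1.0, 0.0, 1.0, 10.0])
soft = torch.tanh(x / 3) * 3  # smooth, strictly inside (-3, 3), nonzero gradient
hard = torch.clamp(x, -3.0, 3.0)  # flat (zero gradient) outside [-3, 3]
print(soft, hard)
```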