Skip to content

Commit

Permalink
Improve typehints and docs in diffusers/models (huggingface#5391)
Browse files Browse the repository at this point in the history
* improvement: add typehints and docs to src/diffusers/models/attention_processor.py

* improvement: add typehints and docs to src/diffusers/models/vae.py

* improvement: add missing docs in src/diffusers/models/vq_model.py

* improvement: add typehints and docs to src/diffusers/models/transformer_temporal.py

* improvement: add typehints and docs to src/diffusers/models/t5_film_transformer.py

* improvement: add type hints to src/diffusers/models/unet_1d_blocks.py

* improvement: add missing type hints to src/diffusers/models/unet_2d_blocks.py

* fix: CI error (make fix-copies required)

* fix: CI error (make fix-copies required again)

---------

Co-authored-by: Dhruv Nair <[email protected]>
  • Loading branch information
a-r-r-o-w and DN6 authored Oct 25, 2023
1 parent 4578cad commit ffb0107
Show file tree
Hide file tree
Showing 10 changed files with 1,090 additions and 462 deletions.
386 changes: 314 additions & 72 deletions models/attention_processor.py

Large diffs are not rendered by default.

171 changes: 144 additions & 27 deletions models/t5_film_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Optional, Tuple

import torch
from torch import nn
Expand All @@ -23,6 +24,28 @@


class T5FilmDecoder(ModelMixin, ConfigMixin):
r"""
T5 style decoder with FiLM conditioning.
Args:
input_dims (`int`, *optional*, defaults to `128`):
The number of input dimensions.
targets_length (`int`, *optional*, defaults to `256`):
The length of the targets.
d_model (`int`, *optional*, defaults to `768`):
Size of the input hidden states.
num_layers (`int`, *optional*, defaults to `12`):
The number of `DecoderLayer`'s to use.
num_heads (`int`, *optional*, defaults to `12`):
The number of attention heads to use.
d_kv (`int`, *optional*, defaults to `64`):
Size of the key-value projection vectors.
d_ff (`int`, *optional*, defaults to `2048`):
The number of dimensions in the intermediate feed-forward layer of `DecoderLayer`'s.
dropout_rate (`float`, *optional*, defaults to `0.1`):
Dropout probability.
"""

@register_to_config
def __init__(
self,
Expand Down Expand Up @@ -63,7 +86,7 @@ def __init__(
self.post_dropout = nn.Dropout(p=dropout_rate)
self.spec_out = nn.Linear(d_model, input_dims, bias=False)

def encoder_decoder_mask(self, query_input, key_input):
def encoder_decoder_mask(self, query_input: torch.FloatTensor, key_input: torch.FloatTensor) -> torch.FloatTensor:
mask = torch.mul(query_input.unsqueeze(-1), key_input.unsqueeze(-2))
return mask.unsqueeze(-3)

Expand Down Expand Up @@ -125,7 +148,27 @@ def forward(self, encodings_and_masks, decoder_input_tokens, decoder_noise_time)


class DecoderLayer(nn.Module):
def __init__(self, d_model, d_kv, num_heads, d_ff, dropout_rate, layer_norm_epsilon=1e-6):
r"""
T5 decoder layer.
Args:
d_model (`int`):
Size of the input hidden states.
d_kv (`int`):
Size of the key-value projection vectors.
num_heads (`int`):
Number of attention heads.
d_ff (`int`):
Size of the intermediate feed-forward layer.
dropout_rate (`float`):
Dropout probability.
layer_norm_epsilon (`float`, *optional*, defaults to `1e-6`):
A small value used for numerical stability to avoid dividing by zero.
"""

def __init__(
self, d_model: int, d_kv: int, num_heads: int, d_ff: int, dropout_rate: float, layer_norm_epsilon: float = 1e-6
):
super().__init__()
self.layer = nn.ModuleList()

Expand All @@ -152,13 +195,13 @@ def __init__(self, d_model, d_kv, num_heads, d_ff, dropout_rate, layer_norm_epsi

def forward(
self,
hidden_states,
conditioning_emb=None,
attention_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
hidden_states: torch.FloatTensor,
conditioning_emb: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
encoder_decoder_position_bias=None,
):
) -> Tuple[torch.FloatTensor]:
hidden_states = self.layer[0](
hidden_states,
conditioning_emb=conditioning_emb,
Expand All @@ -183,7 +226,21 @@ def forward(


class T5LayerSelfAttentionCond(nn.Module):
def __init__(self, d_model, d_kv, num_heads, dropout_rate):
r"""
T5 style self-attention layer with conditioning.
Args:
d_model (`int`):
Size of the input hidden states.
d_kv (`int`):
Size of the key-value projection vectors.
num_heads (`int`):
Number of attention heads.
dropout_rate (`float`):
Dropout probability.
"""

def __init__(self, d_model: int, d_kv: int, num_heads: int, dropout_rate: float):
super().__init__()
self.layer_norm = T5LayerNorm(d_model)
self.FiLMLayer = T5FiLMLayer(in_features=d_model * 4, out_features=d_model)
Expand All @@ -192,10 +249,10 @@ def __init__(self, d_model, d_kv, num_heads, dropout_rate):

def forward(
self,
hidden_states,
conditioning_emb=None,
attention_mask=None,
):
hidden_states: torch.FloatTensor,
conditioning_emb: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
) -> torch.FloatTensor:
# pre_self_attention_layer_norm
normed_hidden_states = self.layer_norm(hidden_states)

Expand All @@ -211,18 +268,34 @@ def forward(


class T5LayerCrossAttention(nn.Module):
def __init__(self, d_model, d_kv, num_heads, dropout_rate, layer_norm_epsilon):
r"""
T5 style cross-attention layer.
Args:
d_model (`int`):
Size of the input hidden states.
d_kv (`int`):
Size of the key-value projection vectors.
num_heads (`int`):
Number of attention heads.
dropout_rate (`float`):
Dropout probability.
layer_norm_epsilon (`float`):
A small value used for numerical stability to avoid dividing by zero.
"""

def __init__(self, d_model: int, d_kv: int, num_heads: int, dropout_rate: float, layer_norm_epsilon: float):
super().__init__()
self.attention = Attention(query_dim=d_model, heads=num_heads, dim_head=d_kv, out_bias=False, scale_qk=False)
self.layer_norm = T5LayerNorm(d_model, eps=layer_norm_epsilon)
self.dropout = nn.Dropout(dropout_rate)

def forward(
self,
hidden_states,
key_value_states=None,
attention_mask=None,
):
hidden_states: torch.FloatTensor,
key_value_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
) -> torch.FloatTensor:
normed_hidden_states = self.layer_norm(hidden_states)
attention_output = self.attention(
normed_hidden_states,
Expand All @@ -234,14 +307,30 @@ def forward(


class T5LayerFFCond(nn.Module):
def __init__(self, d_model, d_ff, dropout_rate, layer_norm_epsilon):
r"""
T5 style feed-forward conditional layer.
Args:
d_model (`int`):
Size of the input hidden states.
d_ff (`int`):
Size of the intermediate feed-forward layer.
dropout_rate (`float`):
Dropout probability.
layer_norm_epsilon (`float`):
A small value used for numerical stability to avoid dividing by zero.
"""

def __init__(self, d_model: int, d_ff: int, dropout_rate: float, layer_norm_epsilon: float):
super().__init__()
self.DenseReluDense = T5DenseGatedActDense(d_model=d_model, d_ff=d_ff, dropout_rate=dropout_rate)
self.film = T5FiLMLayer(in_features=d_model * 4, out_features=d_model)
self.layer_norm = T5LayerNorm(d_model, eps=layer_norm_epsilon)
self.dropout = nn.Dropout(dropout_rate)

def forward(self, hidden_states, conditioning_emb=None):
def forward(
self, hidden_states: torch.FloatTensor, conditioning_emb: Optional[torch.FloatTensor] = None
) -> torch.FloatTensor:
forwarded_states = self.layer_norm(hidden_states)
if conditioning_emb is not None:
forwarded_states = self.film(forwarded_states, conditioning_emb)
Expand All @@ -252,15 +341,27 @@ def forward(self, hidden_states, conditioning_emb=None):


class T5DenseGatedActDense(nn.Module):
def __init__(self, d_model, d_ff, dropout_rate):
r"""
T5 style feed-forward layer with gated activations and dropout.
Args:
d_model (`int`):
Size of the input hidden states.
d_ff (`int`):
Size of the intermediate feed-forward layer.
dropout_rate (`float`):
Dropout probability.
"""

def __init__(self, d_model: int, d_ff: int, dropout_rate: float):
super().__init__()
self.wi_0 = nn.Linear(d_model, d_ff, bias=False)
self.wi_1 = nn.Linear(d_model, d_ff, bias=False)
self.wo = nn.Linear(d_ff, d_model, bias=False)
self.dropout = nn.Dropout(dropout_rate)
self.act = NewGELUActivation()

def forward(self, hidden_states):
def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
hidden_gelu = self.act(self.wi_0(hidden_states))
hidden_linear = self.wi_1(hidden_states)
hidden_states = hidden_gelu * hidden_linear
Expand All @@ -271,15 +372,25 @@ def forward(self, hidden_states):


class T5LayerNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
r"""
T5 style layer normalization module.
Args:
hidden_size (`int`):
Size of the input hidden states.
eps (`float`, `optional`, defaults to `1e-6`):
A small value used for numerical stability to avoid dividing by zero.
"""

def __init__(self, hidden_size: int, eps: float = 1e-6):
"""
Construct a layernorm module in the T5 style. No bias and no subtraction of mean.
"""
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps

def forward(self, hidden_states):
def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
# T5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean
# Square Layer Normalization https://arxiv.org/abs/1910.07467 thus variance is calculated
# w/o mean and there is no bias. Additionally we want to make sure that the accumulation for
Expand Down Expand Up @@ -307,14 +418,20 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:

class T5FiLMLayer(nn.Module):
"""
FiLM Layer
T5 style FiLM Layer.
Args:
in_features (`int`):
Number of input features.
out_features (`int`):
Number of output features.
"""

def __init__(self, in_features, out_features):
def __init__(self, in_features: int, out_features: int):
super().__init__()
self.scale_bias = nn.Linear(in_features, out_features * 2, bias=False)

def forward(self, x, conditioning_emb):
def forward(self, x: torch.FloatTensor, conditioning_emb: torch.FloatTensor) -> torch.FloatTensor:
emb = self.scale_bias(conditioning_emb)
scale, shift = torch.chunk(emb, 2, -1)
x = x * (1 + scale) + shift
Expand Down
28 changes: 16 additions & 12 deletions models/transformer_temporal.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Optional
from typing import Any, Dict, Optional

import torch
from torch import nn
Expand Down Expand Up @@ -48,11 +48,15 @@ class TransformerTemporalModel(ModelMixin, ConfigMixin):
num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
This is fixed during training since it is used to learn a number of position embeddings.
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward.
attention_bias (`bool`, *optional*):
Configure if the `TransformerBlock` attention should contain a bias parameter.
sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
This is fixed during training since it is used to learn a number of position embeddings.
activation_fn (`str`, *optional*, defaults to `"geglu"`):
Activation function to use in feed-forward. See `diffusers.models.activations.get_activation` for supported
activation functions.
norm_elementwise_affine (`bool`, *optional*):
Configure if the `TransformerBlock` should use learnable elementwise affine parameters for normalization.
double_self_attention (`bool`, *optional*):
Configure if each `TransformerBlock` should contain two self-attention layers.
"""
Expand Down Expand Up @@ -106,14 +110,14 @@ def __init__(

def forward(
self,
hidden_states,
encoder_hidden_states=None,
timestep=None,
class_labels=None,
num_frames=1,
cross_attention_kwargs=None,
hidden_states: torch.FloatTensor,
encoder_hidden_states: Optional[torch.LongTensor] = None,
timestep: Optional[torch.LongTensor] = None,
class_labels: torch.LongTensor = None,
num_frames: int = 1,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
):
) -> TransformerTemporalModelOutput:
"""
The [`TransformerTemporal`] forward method.
Expand All @@ -123,7 +127,7 @@ def forward(
encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
self-attention.
timestep ( `torch.long`, *optional*):
timestep ( `torch.LongTensor`, *optional*):
Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
Expand Down
Loading

0 comments on commit ffb0107

Please sign in to comment.