Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve typehints and docs in diffusers/models #5391

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
386 changes: 314 additions & 72 deletions src/diffusers/models/attention_processor.py

Large diffs are not rendered by default.

171 changes: 144 additions & 27 deletions src/diffusers/models/t5_film_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Optional, Tuple

import torch
from torch import nn
Expand All @@ -23,6 +24,28 @@


class T5FilmDecoder(ModelMixin, ConfigMixin):
r"""
T5 style decoder with FiLM conditioning.

Args:
input_dims (`int`, *optional*, defaults to `128`):
The number of input dimensions.
targets_length (`int`, *optional*, defaults to `256`):
The length of the targets.
d_model (`int`, *optional*, defaults to `768`):
Size of the input hidden states.
num_layers (`int`, *optional*, defaults to `12`):
The number of `DecoderLayer`'s to use.
num_heads (`int`, *optional*, defaults to `12`):
The number of attention heads to use.
d_kv (`int`, *optional*, defaults to `64`):
Size of the key-value projection vectors.
d_ff (`int`, *optional*, defaults to `2048`):
The number of dimensions in the intermediate feed-forward layer of `DecoderLayer`'s.
dropout_rate (`float`, *optional*, defaults to `0.1`):
Dropout probability.
"""

@register_to_config
def __init__(
self,
Expand Down Expand Up @@ -63,7 +86,7 @@ def __init__(
self.post_dropout = nn.Dropout(p=dropout_rate)
self.spec_out = nn.Linear(d_model, input_dims, bias=False)

def encoder_decoder_mask(self, query_input, key_input):
def encoder_decoder_mask(self, query_input: torch.FloatTensor, key_input: torch.FloatTensor) -> torch.FloatTensor:
mask = torch.mul(query_input.unsqueeze(-1), key_input.unsqueeze(-2))
return mask.unsqueeze(-3)

Expand Down Expand Up @@ -125,7 +148,27 @@ def forward(self, encodings_and_masks, decoder_input_tokens, decoder_noise_time)


class DecoderLayer(nn.Module):
def __init__(self, d_model, d_kv, num_heads, d_ff, dropout_rate, layer_norm_epsilon=1e-6):
r"""
T5 decoder layer.

Args:
d_model (`int`):
Size of the input hidden states.
d_kv (`int`):
Size of the key-value projection vectors.
num_heads (`int`):
Number of attention heads.
d_ff (`int`):
Size of the intermediate feed-forward layer.
dropout_rate (`float`):
Dropout probability.
layer_norm_epsilon (`float`, *optional*, defaults to `1e-6`):
A small value used for numerical stability to avoid dividing by zero.
"""

def __init__(
self, d_model: int, d_kv: int, num_heads: int, d_ff: int, dropout_rate: float, layer_norm_epsilon: float = 1e-6
):
super().__init__()
self.layer = nn.ModuleList()

Expand All @@ -152,13 +195,13 @@ def __init__(self, d_model, d_kv, num_heads, d_ff, dropout_rate, layer_norm_epsi

def forward(
self,
hidden_states,
conditioning_emb=None,
attention_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
hidden_states: torch.FloatTensor,
conditioning_emb: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
encoder_decoder_position_bias=None,
):
) -> Tuple[torch.FloatTensor]:
hidden_states = self.layer[0](
hidden_states,
conditioning_emb=conditioning_emb,
Expand All @@ -183,7 +226,21 @@ def forward(


class T5LayerSelfAttentionCond(nn.Module):
def __init__(self, d_model, d_kv, num_heads, dropout_rate):
r"""
T5 style self-attention layer with conditioning.

Args:
d_model (`int`):
Size of the input hidden states.
d_kv (`int`):
Size of the key-value projection vectors.
num_heads (`int`):
Number of attention heads.
dropout_rate (`float`):
Dropout probability.
"""

def __init__(self, d_model: int, d_kv: int, num_heads: int, dropout_rate: float):
super().__init__()
self.layer_norm = T5LayerNorm(d_model)
self.FiLMLayer = T5FiLMLayer(in_features=d_model * 4, out_features=d_model)
Expand All @@ -192,10 +249,10 @@ def __init__(self, d_model, d_kv, num_heads, dropout_rate):

def forward(
self,
hidden_states,
conditioning_emb=None,
attention_mask=None,
):
hidden_states: torch.FloatTensor,
conditioning_emb: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
) -> torch.FloatTensor:
# pre_self_attention_layer_norm
normed_hidden_states = self.layer_norm(hidden_states)

Expand All @@ -211,18 +268,34 @@ def forward(


class T5LayerCrossAttention(nn.Module):
def __init__(self, d_model, d_kv, num_heads, dropout_rate, layer_norm_epsilon):
r"""
T5 style cross-attention layer.

Args:
d_model (`int`):
Size of the input hidden states.
d_kv (`int`):
Size of the key-value projection vectors.
num_heads (`int`):
Number of attention heads.
dropout_rate (`float`):
Dropout probability.
layer_norm_epsilon (`float`):
A small value used for numerical stability to avoid dividing by zero.
"""

def __init__(self, d_model: int, d_kv: int, num_heads: int, dropout_rate: float, layer_norm_epsilon: float):
super().__init__()
self.attention = Attention(query_dim=d_model, heads=num_heads, dim_head=d_kv, out_bias=False, scale_qk=False)
self.layer_norm = T5LayerNorm(d_model, eps=layer_norm_epsilon)
self.dropout = nn.Dropout(dropout_rate)

def forward(
self,
hidden_states,
key_value_states=None,
attention_mask=None,
):
hidden_states: torch.FloatTensor,
key_value_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
) -> torch.FloatTensor:
normed_hidden_states = self.layer_norm(hidden_states)
attention_output = self.attention(
normed_hidden_states,
Expand All @@ -234,14 +307,30 @@ def forward(


class T5LayerFFCond(nn.Module):
def __init__(self, d_model, d_ff, dropout_rate, layer_norm_epsilon):
r"""
T5 style feed-forward conditional layer.

Args:
d_model (`int`):
Size of the input hidden states.
d_ff (`int`):
Size of the intermediate feed-forward layer.
dropout_rate (`float`):
Dropout probability.
layer_norm_epsilon (`float`):
A small value used for numerical stability to avoid dividing by zero.
"""

def __init__(self, d_model: int, d_ff: int, dropout_rate: float, layer_norm_epsilon: float):
super().__init__()
self.DenseReluDense = T5DenseGatedActDense(d_model=d_model, d_ff=d_ff, dropout_rate=dropout_rate)
self.film = T5FiLMLayer(in_features=d_model * 4, out_features=d_model)
self.layer_norm = T5LayerNorm(d_model, eps=layer_norm_epsilon)
self.dropout = nn.Dropout(dropout_rate)

def forward(self, hidden_states, conditioning_emb=None):
def forward(
self, hidden_states: torch.FloatTensor, conditioning_emb: Optional[torch.FloatTensor] = None
) -> torch.FloatTensor:
forwarded_states = self.layer_norm(hidden_states)
if conditioning_emb is not None:
forwarded_states = self.film(forwarded_states, conditioning_emb)
Expand All @@ -252,15 +341,27 @@ def forward(self, hidden_states, conditioning_emb=None):


class T5DenseGatedActDense(nn.Module):
def __init__(self, d_model, d_ff, dropout_rate):
r"""
T5 style feed-forward layer with gated activations and dropout.

Args:
d_model (`int`):
Size of the input hidden states.
d_ff (`int`):
Size of the intermediate feed-forward layer.
dropout_rate (`float`):
Dropout probability.
"""

def __init__(self, d_model: int, d_ff: int, dropout_rate: float):
super().__init__()
self.wi_0 = nn.Linear(d_model, d_ff, bias=False)
self.wi_1 = nn.Linear(d_model, d_ff, bias=False)
self.wo = nn.Linear(d_ff, d_model, bias=False)
self.dropout = nn.Dropout(dropout_rate)
self.act = NewGELUActivation()

def forward(self, hidden_states):
def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
hidden_gelu = self.act(self.wi_0(hidden_states))
hidden_linear = self.wi_1(hidden_states)
hidden_states = hidden_gelu * hidden_linear
Expand All @@ -271,15 +372,25 @@ def forward(self, hidden_states):


class T5LayerNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
r"""
T5 style layer normalization module.

Args:
hidden_size (`int`):
Size of the input hidden states.
eps (`float`, `optional`, defaults to `1e-6`):
A small value used for numerical stability to avoid dividing by zero.
"""

def __init__(self, hidden_size: int, eps: float = 1e-6):
"""
Construct a layernorm module in the T5 style. No bias and no subtraction of mean.
"""
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps

def forward(self, hidden_states):
def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
# T5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean
# Square Layer Normalization https://arxiv.org/abs/1910.07467 thus variance is calculated
# w/o mean and there is no bias. Additionally we want to make sure that the accumulation for
Expand Down Expand Up @@ -307,14 +418,20 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:

class T5FiLMLayer(nn.Module):
"""
FiLM Layer
T5 style FiLM Layer.

Args:
in_features (`int`):
Number of input features.
out_features (`int`):
Number of output features.
"""

def __init__(self, in_features, out_features):
def __init__(self, in_features: int, out_features: int):
super().__init__()
self.scale_bias = nn.Linear(in_features, out_features * 2, bias=False)

def forward(self, x, conditioning_emb):
def forward(self, x: torch.FloatTensor, conditioning_emb: torch.FloatTensor) -> torch.FloatTensor:
emb = self.scale_bias(conditioning_emb)
scale, shift = torch.chunk(emb, 2, -1)
x = x * (1 + scale) + shift
Expand Down
28 changes: 16 additions & 12 deletions src/diffusers/models/transformer_temporal.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Optional
from typing import Any, Dict, Optional

import torch
from torch import nn
Expand Down Expand Up @@ -48,11 +48,15 @@ class TransformerTemporalModel(ModelMixin, ConfigMixin):
num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
This is fixed during training since it is used to learn a number of position embeddings.
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward.
attention_bias (`bool`, *optional*):
Configure if the `TransformerBlock` attention should contain a bias parameter.
sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
This is fixed during training since it is used to learn a number of position embeddings.
activation_fn (`str`, *optional*, defaults to `"geglu"`):
Activation function to use in feed-forward. See `diffusers.models.activations.get_activation` for supported
activation functions.
norm_elementwise_affine (`bool`, *optional*):
Configure if the `TransformerBlock` should use learnable elementwise affine parameters for normalization.
double_self_attention (`bool`, *optional*):
Configure if each `TransformerBlock` should contain two self-attention layers.
"""
Expand Down Expand Up @@ -106,14 +110,14 @@ def __init__(

def forward(
self,
hidden_states,
encoder_hidden_states=None,
timestep=None,
class_labels=None,
num_frames=1,
cross_attention_kwargs=None,
hidden_states: torch.FloatTensor,
encoder_hidden_states: Optional[torch.LongTensor] = None,
timestep: Optional[torch.LongTensor] = None,
class_labels: torch.LongTensor = None,
num_frames: int = 1,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
):
) -> TransformerTemporalModelOutput:
"""
The [`TransformerTemporal`] forward method.

Expand All @@ -123,7 +127,7 @@ def forward(
encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
self-attention.
timestep ( `torch.long`, *optional*):
timestep ( `torch.LongTensor`, *optional*):
Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
Expand Down
Loading