[refactor embeddings] pixart-alpha (huggingface#6212)
pixart-alpha

Co-authored-by: yiyixuxu <yixu310@gmail.com>
yiyixuxu authored Dec 19, 2023
1 parent bf40d7d commit 3e71a20
Showing 4 changed files with 17 additions and 31 deletions.
35 changes: 8 additions & 27 deletions src/diffusers/models/embeddings.py
@@ -729,7 +729,7 @@ def forward(
        return objs


-class CombinedTimestepSizeEmbeddings(nn.Module):
+class PixArtAlphaCombinedTimestepSizeEmbeddings(nn.Module):
    """
    For PixArt-Alpha.
@@ -746,45 +746,27 @@ def __init__(self, embedding_dim, size_emb_dim, use_additional_conditions: bool

        self.use_additional_conditions = use_additional_conditions
        if use_additional_conditions:
-            self.use_additional_conditions = True
            self.additional_condition_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
            self.resolution_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=size_emb_dim)
            self.aspect_ratio_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=size_emb_dim)

-    def apply_condition(self, size: torch.Tensor, batch_size: int, embedder: nn.Module):
-        if size.ndim == 1:
-            size = size[:, None]
-
-        if size.shape[0] != batch_size:
-            size = size.repeat(batch_size // size.shape[0], 1)
-            if size.shape[0] != batch_size:
-                raise ValueError(f"`batch_size` should be {size.shape[0]} but found {batch_size}.")
-
-        current_batch_size, dims = size.shape[0], size.shape[1]
-        size = size.reshape(-1)
-        size_freq = self.additional_condition_proj(size).to(size.dtype)
-
-        size_emb = embedder(size_freq)
-        size_emb = size_emb.reshape(current_batch_size, dims * self.outdim)
-        return size_emb
-
    def forward(self, timestep, resolution, aspect_ratio, batch_size, hidden_dtype):
        timesteps_proj = self.time_proj(timestep)
        timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype))  # (N, D)

        if self.use_additional_conditions:
-            resolution = self.apply_condition(resolution, batch_size=batch_size, embedder=self.resolution_embedder)
-            aspect_ratio = self.apply_condition(
-                aspect_ratio, batch_size=batch_size, embedder=self.aspect_ratio_embedder
-            )
-            conditioning = timesteps_emb + torch.cat([resolution, aspect_ratio], dim=1)
+            resolution_emb = self.additional_condition_proj(resolution.flatten()).to(hidden_dtype)
+            resolution_emb = self.resolution_embedder(resolution_emb).reshape(batch_size, -1)
+            aspect_ratio_emb = self.additional_condition_proj(aspect_ratio.flatten()).to(hidden_dtype)
+            aspect_ratio_emb = self.aspect_ratio_embedder(aspect_ratio_emb).reshape(batch_size, -1)
+            conditioning = timesteps_emb + torch.cat([resolution_emb, aspect_ratio_emb], dim=1)
        else:
            conditioning = timesteps_emb

        return conditioning


-class CaptionProjection(nn.Module):
+class PixArtAlphaTextProjection(nn.Module):
    """
    Projects caption embeddings. Also handles dropout for classifier-free guidance.
@@ -796,9 +778,8 @@ def __init__(self, in_features, hidden_size, num_tokens=120):
        self.linear_1 = nn.Linear(in_features=in_features, out_features=hidden_size, bias=True)
        self.act_1 = nn.GELU(approximate="tanh")
        self.linear_2 = nn.Linear(in_features=hidden_size, out_features=hidden_size, bias=True)
-        self.register_buffer("y_embedding", nn.Parameter(torch.randn(num_tokens, in_features) / in_features**0.5))

-    def forward(self, caption, force_drop_ids=None):
+    def forward(self, caption):
        hidden_states = self.linear_1(caption)
        hidden_states = self.act_1(hidden_states)
        hidden_states = self.linear_2(hidden_states)
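For context, the renamed `PixArtAlphaCombinedTimestepSizeEmbeddings` can be exercised on its own. The sketch below is not part of the commit; it assumes a diffusers checkout that includes this change, an `embedding_dim` of 1152 with `size_emb_dim = embedding_dim // 3` (the values `AdaLayerNormSingle` passes, see below), and float32 hidden states. With `use_additional_conditions=True`, the resolution and aspect-ratio tensors are flattened, embedded per element, and reshaped back to the batch dimension, so the concatenated size embeddings line up with the timestep embedding.

```python
# Minimal sketch, not from the commit: exercising the refactored embedding module.
import torch

from diffusers.models.embeddings import PixArtAlphaCombinedTimestepSizeEmbeddings

batch_size, embedding_dim = 2, 1152  # assumed values for illustration
emb = PixArtAlphaCombinedTimestepSizeEmbeddings(
    embedding_dim, size_emb_dim=embedding_dim // 3, use_additional_conditions=True
)

timestep = torch.randint(0, 1000, (batch_size,))
resolution = torch.tensor([[1024.0, 1024.0]] * batch_size)  # (B, 2): height, width
aspect_ratio = torch.tensor([[1.0]] * batch_size)           # (B, 1)

conditioning = emb(
    timestep, resolution, aspect_ratio, batch_size=batch_size, hidden_dtype=torch.float32
)
# (B, embedding_dim): timestep embedding plus the concatenated
# resolution and aspect-ratio embeddings (2 * size_emb_dim + size_emb_dim).
print(conditioning.shape)  # torch.Size([2, 1152])
```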
4 changes: 2 additions & 2 deletions src/diffusers/models/normalization.py
@@ -20,7 +20,7 @@
import torch.nn.functional as F

from .activations import get_activation
-from .embeddings import CombinedTimestepLabelEmbeddings, CombinedTimestepSizeEmbeddings
+from .embeddings import CombinedTimestepLabelEmbeddings, PixArtAlphaCombinedTimestepSizeEmbeddings


class AdaLayerNorm(nn.Module):
@@ -91,7 +91,7 @@ class AdaLayerNormSingle(nn.Module):
    def __init__(self, embedding_dim: int, use_additional_conditions: bool = False):
        super().__init__()

-        self.emb = CombinedTimestepSizeEmbeddings(
+        self.emb = PixArtAlphaCombinedTimestepSizeEmbeddings(
            embedding_dim, size_emb_dim=embedding_dim // 3, use_additional_conditions=use_additional_conditions
        )

4 changes: 2 additions & 2 deletions src/diffusers/models/transformer_2d.py
@@ -22,7 +22,7 @@
from ..models.embeddings import ImagePositionalEmbeddings
from ..utils import USE_PEFT_BACKEND, BaseOutput, deprecate, is_torch_version
from .attention import BasicTransformerBlock
-from .embeddings import CaptionProjection, PatchEmbed
+from .embeddings import PatchEmbed, PixArtAlphaTextProjection
from .lora import LoRACompatibleConv, LoRACompatibleLinear
from .modeling_utils import ModelMixin
from .normalization import AdaLayerNormSingle
@@ -235,7 +235,7 @@ def __init__(

        self.caption_projection = None
        if caption_channels is not None:
-            self.caption_projection = CaptionProjection(in_features=caption_channels, hidden_size=inner_dim)
+            self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim)

        self.gradient_checkpointing = False

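The caption projection rename is mechanical, but the call site above pins down the expected shapes: caption embeddings of width `caption_channels` are projected to the transformer's `inner_dim`, and `forward` no longer accepts `force_drop_ids`. A rough usage sketch, not from the commit; the 4096/1152 widths and the token count are illustrative assumptions, not values taken from the diff.

```python
# Sketch only: caption_channels, inner_dim, and token count are assumed values.
import torch

from diffusers.models.embeddings import PixArtAlphaTextProjection

caption_channels, inner_dim = 4096, 1152
proj = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim)

caption = torch.randn(2, 120, caption_channels)  # (batch, num_tokens, caption_channels)
hidden_states = proj(caption)                    # linear -> GELU(tanh) -> linear over the last dim
print(hidden_states.shape)                       # torch.Size([2, 120, 1152])
```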
5 changes: 5 additions & 0 deletions src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py
@@ -853,6 +853,11 @@ def __call__(
            aspect_ratio = torch.tensor([float(height / width)]).repeat(batch_size * num_images_per_prompt, 1)
            resolution = resolution.to(dtype=prompt_embeds.dtype, device=device)
            aspect_ratio = aspect_ratio.to(dtype=prompt_embeds.dtype, device=device)
+
+            if do_classifier_free_guidance:
+                resolution = torch.cat([resolution, resolution], dim=0)
+                aspect_ratio = torch.cat([aspect_ratio, aspect_ratio], dim=0)
+
            added_cond_kwargs = {"resolution": resolution, "aspect_ratio": aspect_ratio}

        # 7. Denoising loop
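The pipeline change is needed because, under classifier-free guidance, the latent input to the transformer is the negative and positive batch stacked along dim 0, so every per-sample condition has to be stacked the same way before the embedding module reshapes it by batch size. A toy illustration, not from the commit; the latent shape is an arbitrary assumption.

```python
# Toy illustration, not from the commit: the latent shape is a placeholder.
import torch

batch_size, num_images_per_prompt = 1, 1
do_classifier_free_guidance = True
height = width = 1024

resolution = torch.tensor([height, width]).repeat(batch_size * num_images_per_prompt, 1)
aspect_ratio = torch.tensor([float(height / width)]).repeat(batch_size * num_images_per_prompt, 1)

latents = torch.randn(batch_size * num_images_per_prompt, 4, 128, 128)  # assumed latent shape
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents

if do_classifier_free_guidance:
    resolution = torch.cat([resolution, resolution], dim=0)
    aspect_ratio = torch.cat([aspect_ratio, aspect_ratio], dim=0)

# The added conditions now match the doubled batch, so the size embeddings can be
# reshaped with batch_size = latent_model_input.shape[0].
assert latent_model_input.shape[0] == resolution.shape[0] == aspect_ratio.shape[0]
```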
