Relative import xformer triton depend fix #37

Merged · 4 commits · Nov 24, 2024
1 change: 0 additions & 1 deletion diffusion/model/__init__.py
@@ -1 +0,0 @@
from .nets import *
48 changes: 43 additions & 5 deletions diffusion/model/nets/__init__.py
@@ -1,5 +1,43 @@
from .sana import *
from .sana_multi_scale import *
from .sana_multi_scale_adaln import *
from .sana_U_shape import *
from .sana_U_shape_multi_scale import *
from .sana import (
Sana,
SanaBlock,
get_1d_sincos_pos_embed_from_grid,
get_2d_sincos_pos_embed,
get_2d_sincos_pos_embed_from_grid,
)
from .sana_multi_scale import (
SanaMS,
SanaMS_600M_P1_D28,
SanaMS_600M_P2_D28,
SanaMS_600M_P4_D28,
SanaMS_1600M_P1_D20,
SanaMS_1600M_P2_D20,
SanaMSBlock,
)
from .sana_multi_scale_adaln import (
SanaMSAdaLN,
SanaMSAdaLN_600M_P1_D28,
SanaMSAdaLN_600M_P2_D28,
SanaMSAdaLN_600M_P4_D28,
SanaMSAdaLN_1600M_P1_D20,
SanaMSAdaLN_1600M_P2_D20,
SanaMSAdaLNBlock,
)
from .sana_U_shape import (
SanaU,
SanaU_600M_P1_D28,
SanaU_600M_P2_D28,
SanaU_600M_P4_D28,
SanaU_1600M_P1_D20,
SanaU_1600M_P2_D20,
SanaUBlock,
)
from .sana_U_shape_multi_scale import (
SanaUMS,
SanaUMS_600M_P1_D28,
SanaUMS_600M_P2_D28,
SanaUMS_600M_P4_D28,
SanaUMS_1600M_P1_D20,
SanaUMS_1600M_P2_D20,
SanaUMSBlock,
)
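
Note (not part of the diff): replacing the star imports with explicit re-exports keeps the public names importable from the package, so downstream code such as the following continues to work; only names missing from the lists above would now fail to import.

# Example usage of the explicit re-exports listed above.
from diffusion.model.nets import SanaMS, SanaMS_600M_P2_D28, SanaMSBlock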
16 changes: 15 additions & 1 deletion diffusion/model/nets/sana.py
@@ -24,7 +24,6 @@

from diffusion.model.builder import MODELS
from diffusion.model.nets.basic_modules import DWMlp, GLUMBConv, MBConvPreGLU, Mlp
from diffusion.model.nets.fastlinear.modules import TritonLiteMLA, TritonMBConvPreGLU
from diffusion.model.nets.sana_blocks import (
Attention,
CaptionEmbedder,
@@ -39,8 +38,15 @@
from diffusion.model.norms import RMSNorm
from diffusion.model.utils import auto_grad_checkpoint, to_2tuple
from diffusion.utils.dist_utils import get_rank
from diffusion.utils.import_utils import is_triton_module_available
from diffusion.utils.logger import get_root_logger

_triton_modules_available = False
if is_triton_module_available():
from diffusion.model.nets.fastlinear.modules import TritonLiteMLA, TritonMBConvPreGLU

_triton_modules_available = True


class SanaBlock(nn.Module):
"""
@@ -78,6 +84,10 @@ def __init__(
self_num_heads = hidden_size // linear_head_dim
self.attn = LiteLA(hidden_size, hidden_size, heads=self_num_heads, eps=1e-8, qk_norm=qk_norm)
elif attn_type == "triton_linear":
if not _triton_modules_available:
raise ValueError(
f"{attn_type} type is not available due to _triton_modules_available={_triton_modules_available}."
)
# linear self attention with triton kernel fusion
# TODO: Here the num_heads set to 36 for tmp used
self_num_heads = hidden_size // linear_head_dim
@@ -123,6 +133,10 @@ def __init__(
act=("silu", "silu", None),
)
elif ffn_type == "triton_mbconvpreglu":
if not _triton_modules_available:
raise ValueError(
f"{ffn_type} type is not available due to _triton_modules_available={_triton_modules_available}."
)
self.mlp = TritonMBConvPreGLU(
in_dim=hidden_size,
out_dim=hidden_size,
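
For context: the guards above import is_triton_module_available from diffusion/utils/import_utils.py, which is not shown in this diff. A minimal sketch of what such availability checks could look like, assuming they simply probe for the installed packages (hypothetical, not the PR's actual implementation):

# Hypothetical sketch of diffusion/utils/import_utils.py; the real module is not part of this diff.
import importlib.util

def is_xformers_available() -> bool:
    # True if the xformers package can be resolved in the current environment.
    return importlib.util.find_spec("xformers") is not None

def is_triton_module_available() -> bool:
    # The fused TritonLiteMLA / TritonMBConvPreGLU kernels need triton itself to be installed.
    return importlib.util.find_spec("triton") is not None

With checks like these, the Triton-backed modules are only imported when the dependency exists, and requesting attn_type="triton_linear" or ffn_type="triton_mbconvpreglu" without it raises the ValueError added above instead of failing at import time.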
24 changes: 17 additions & 7 deletions diffusion/model/nets/sana_U_shape.py
@@ -23,7 +23,6 @@

from diffusion.model.builder import MODELS
from diffusion.model.nets.basic_modules import DWMlp, GLUMBConv, MBConvPreGLU, Mlp
from diffusion.model.nets.fastlinear.modules import TritonLiteMLA
from diffusion.model.nets.sana import Sana, get_2d_sincos_pos_embed
from diffusion.model.nets.sana_blocks import (
Attention,
@@ -37,9 +36,16 @@
t2i_modulate,
)
from diffusion.model.norms import RMSNorm
from diffusion.model.utils import auto_grad_checkpoint, to_2tuple
from diffusion.model.utils import auto_grad_checkpoint
from diffusion.utils.import_utils import is_triton_module_available
from diffusion.utils.logger import get_root_logger

_triton_modules_available = False
if is_triton_module_available():
from diffusion.model.nets.fastlinear.modules import TritonLiteMLA

_triton_modules_available = True


class SanaUBlock(nn.Module):
"""
@@ -77,6 +83,10 @@ def __init__(
self_num_heads = hidden_size // 32
self.attn = LiteLA(hidden_size, hidden_size, heads=self_num_heads, eps=1e-8, qk_norm=qk_norm)
elif attn_type == "triton_linear":
if not _triton_modules_available:
raise ValueError(
f"{attn_type} type is not available due to _triton_modules_available={_triton_modules_available}."
)
# linear self attention with triton kernel fusion
# TODO: Here the num_heads set to 36 for tmp used
self_num_heads = hidden_size // 32
@@ -343,27 +353,27 @@ def dtype(self):
# SanaU Configs #
#################################################################################
@MODELS.register_module()
def SanaMSU_600M_P1_D28(**kwargs):
def SanaU_600M_P1_D28(**kwargs):
return SanaU(depth=28, hidden_size=1152, patch_size=1, num_heads=16, **kwargs)


@MODELS.register_module()
def SanaMSU_600M_P2_D28(**kwargs):
def SanaU_600M_P2_D28(**kwargs):
return SanaU(depth=28, hidden_size=1152, patch_size=2, num_heads=16, **kwargs)


@MODELS.register_module()
def SanaMSU_600M_P4_D28(**kwargs):
def SanaU_600M_P4_D28(**kwargs):
return SanaU(depth=28, hidden_size=1152, patch_size=4, num_heads=16, **kwargs)


@MODELS.register_module()
def SanaMSU_P1_D20(**kwargs):
def SanaU_1600M_P1_D20(**kwargs):
# 20 layers, 1648.48M
return SanaU(depth=20, hidden_size=2240, patch_size=1, num_heads=20, **kwargs)


@MODELS.register_module()
def SanaMSU_P2_D20(**kwargs):
def SanaU_1600M_P2_D20(**kwargs):
# 20 layers, 1648.48M
return SanaU(depth=20, hidden_size=2240, patch_size=2, num_heads=20, **kwargs)
19 changes: 14 additions & 5 deletions diffusion/model/nets/sana_U_shape_multi_scale.py
@@ -21,7 +21,6 @@

from diffusion.model.builder import MODELS
from diffusion.model.nets.basic_modules import DWMlp, GLUMBConv, MBConvPreGLU, Mlp
from diffusion.model.nets.fastlinear.modules import TritonLiteMLA, TritonLiteMLAFwd
from diffusion.model.nets.sana import Sana, get_2d_sincos_pos_embed
from diffusion.model.nets.sana_blocks import (
Attention,
@@ -30,11 +29,17 @@
LiteLA,
MultiHeadCrossAttention,
PatchEmbedMS,
SizeEmbedder,
T2IFinalLayer,
t2i_modulate,
)
from diffusion.model.utils import auto_grad_checkpoint, to_2tuple
from diffusion.model.utils import auto_grad_checkpoint
from diffusion.utils.import_utils import is_triton_module_available

_triton_modules_available = False
if is_triton_module_available():
from diffusion.model.nets.fastlinear.modules import TritonLiteMLA

_triton_modules_available = True


class SanaUMSBlock(nn.Module):
@@ -74,6 +79,10 @@ def __init__(
self_num_heads = hidden_size // 32
self.attn = LiteLA(hidden_size, hidden_size, heads=self_num_heads, eps=1e-8, qk_norm=qk_norm)
elif attn_type == "triton_linear":
if not _triton_modules_available:
raise ValueError(
f"{attn_type} type is not available due to _triton_modules_available={_triton_modules_available}."
)
# linear self attention with triton kernel fusion
self_num_heads = hidden_size // 32
self.attn = TritonLiteMLA(hidden_size, num_heads=self_num_heads, eps=1e-8)
@@ -365,12 +374,12 @@ def SanaUMS_600M_P4_D28(**kwargs):


@MODELS.register_module()
def SanaUMS_P1_D20(**kwargs):
def SanaUMS_1600M_P1_D20(**kwargs):
# 20 layers, 1648.48M
return SanaUMS(depth=20, hidden_size=2240, patch_size=1, num_heads=20, **kwargs)


@MODELS.register_module()
def SanaUMS_P2_D20(**kwargs):
def SanaUMS_1600M_P2_D20(**kwargs):
# 20 layers, 1648.48M
return SanaUMS(depth=20, hidden_size=2240, patch_size=2, num_heads=20, **kwargs)
44 changes: 34 additions & 10 deletions diffusion/model/nets/sana_blocks.py
@@ -22,14 +22,20 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import xformers.ops
from einops import rearrange
from timm.models.vision_transformer import Attention as Attention_
from timm.models.vision_transformer import Mlp
from transformers import AutoModelForCausalLM

from diffusion.model.norms import RMSNorm
from diffusion.model.utils import get_same_padding, to_2tuple
from diffusion.utils.import_utils import is_xformers_available

_xformers_available = False
if is_xformers_available():
import xformers.ops

_xformers_available = True


def modulate(x, shift, scale):
@@ -65,17 +71,27 @@ def __init__(self, d_model, num_heads, attn_drop=0.0, proj_drop=0.0, qk_norm=Fal
def forward(self, x, cond, mask=None):
# query: img tokens; key/value: condition; mask: if padding tokens
B, N, C = x.shape
first_dim = 1 if _xformers_available else B

q = self.q_linear(x)
kv = self.kv_linear(cond).view(1, -1, 2, C)
kv = self.kv_linear(cond).view(first_dim, -1, 2, C)
k, v = kv.unbind(2)
q = self.q_norm(q).view(1, -1, self.num_heads, self.head_dim)
k = self.k_norm(k).view(1, -1, self.num_heads, self.head_dim)
v = v.view(1, -1, self.num_heads, self.head_dim)
attn_bias = None
if mask is not None:
attn_bias = xformers.ops.fmha.BlockDiagonalMask.from_seqlens([N] * B, mask)
x = xformers.ops.memory_efficient_attention(q, k, v, p=self.attn_drop.p, attn_bias=attn_bias)
q = self.q_norm(q).view(first_dim, -1, self.num_heads, self.head_dim)
k = self.k_norm(k).view(first_dim, -1, self.num_heads, self.head_dim)
v = v.view(first_dim, -1, self.num_heads, self.head_dim)

if _xformers_available:
attn_bias = None
if mask is not None:
attn_bias = xformers.ops.fmha.BlockDiagonalMask.from_seqlens([N] * B, mask)
x = xformers.ops.memory_efficient_attention(q, k, v, p=self.attn_drop.p, attn_bias=attn_bias)
else:
q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
if mask is not None and mask.ndim == 2:
mask = (1 - mask.to(x.dtype)) * -10000.0
mask = mask[:, None, None].repeat(1, self.num_heads, 1, 1)
x = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
x = x.transpose(1, 2)

x = x.view(B, -1, C)
x = self.proj(x)
@@ -347,7 +363,15 @@ def forward(self, x, mask=None, HW=None, block_id=None):
attn_bias = torch.zeros([B * self.num_heads, q.shape[1], k.shape[1]], dtype=q.dtype, device=q.device)
attn_bias.masked_fill_(mask.squeeze(1).repeat(self.num_heads, 1, 1) == 0, float("-inf"))

x = xformers.ops.memory_efficient_attention(q, k, v, p=self.attn_drop.p, attn_bias=attn_bias)
if _xformers_available:
x = xformers.ops.memory_efficient_attention(q, k, v, p=self.attn_drop.p, attn_bias=attn_bias)
else:
q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
if mask is not None and mask.ndim == 2:
mask = (1 - mask.to(x.dtype)) * -10000.0
mask = mask[:, None, None].repeat(1, self.num_heads, 1, 1)
x = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
x = x.transpose(1, 2)

x = x.view(B, N, C)
x = self.proj(x)
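
The attention changes above follow a guard-and-fallback pattern: use xformers.ops.memory_efficient_attention when xformers is installed, otherwise fall back to torch's scaled_dot_product_attention. A standalone sketch of the pattern, simplified from the PR's code (illustrative only; masking omitted):

# Minimal sketch of the xformers / SDPA fallback used above.
import torch
import torch.nn.functional as F

try:
    import xformers.ops
    _xformers_available = True
except ImportError:
    _xformers_available = False

def attention(q, k, v, dropout_p=0.0):
    # q, k, v: (batch, seq_len, num_heads, head_dim), the layout xformers expects.
    if _xformers_available:
        return xformers.ops.memory_efficient_attention(q, k, v, p=dropout_p)
    # SDPA expects (batch, num_heads, seq_len, head_dim), so transpose in and out.
    q, k, v = (t.transpose(1, 2) for t in (q, k, v))
    out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
    return out.transpose(1, 2)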
35 changes: 29 additions & 6 deletions diffusion/model/nets/sana_multi_scale.py
@@ -21,7 +21,6 @@

from diffusion.model.builder import MODELS
from diffusion.model.nets.basic_modules import DWMlp, GLUMBConv, MBConvPreGLU, Mlp
from diffusion.model.nets.fastlinear.modules import TritonLiteMLA, TritonMBConvPreGLU
from diffusion.model.nets.sana import Sana, get_2d_sincos_pos_embed
from diffusion.model.nets.sana_blocks import (
Attention,
@@ -34,6 +33,17 @@
t2i_modulate,
)
from diffusion.model.utils import auto_grad_checkpoint
from diffusion.utils.import_utils import is_triton_module_available, is_xformers_available

_triton_modules_available = False
if is_triton_module_available():
from diffusion.model.nets.fastlinear.modules import TritonLiteMLA, TritonMBConvPreGLU

_triton_modules_available = True

_xformers_available = False
if is_xformers_available():
_xformers_available = True


class SanaMSBlock(nn.Module):
@@ -74,6 +84,10 @@ def __init__(
self_num_heads = hidden_size // linear_head_dim
self.attn = LiteLA(hidden_size, hidden_size, heads=self_num_heads, eps=1e-8, qk_norm=qk_norm)
elif attn_type == "triton_linear":
if not _triton_modules_available:
raise ValueError(
f"{attn_type} type is not available due to _triton_modules_available={_triton_modules_available}."
)
# linear self attention with triton kernel fusion
self_num_heads = hidden_size // linear_head_dim
self.attn = TritonLiteMLA(hidden_size, num_heads=self_num_heads, eps=1e-8)
@@ -108,6 +122,10 @@ def __init__(
dilation=2,
)
elif ffn_type == "triton_mbconvpreglu":
if not _triton_modules_available:
raise ValueError(
f"{ffn_type} type is not available due to _triton_modules_available={_triton_modules_available}."
)
self.mlp = TritonMBConvPreGLU(
in_dim=hidden_size,
out_dim=hidden_size,
@@ -292,14 +310,19 @@ def forward(self, x, timestep, y, mask=None, data_info=None, **kwargs):
y = self.attention_y_norm(y)

if mask is not None:
if mask.shape[0] != y.shape[0]:
mask = mask.repeat(y.shape[0] // mask.shape[0], 1)
mask = mask.repeat(y.shape[0] // mask.shape[0], 1) if mask.shape[0] != y.shape[0] else mask
mask = mask.squeeze(1).squeeze(1)
y = y.squeeze(1).masked_select(mask.unsqueeze(-1) != 0).view(1, -1, x.shape[-1])
y_lens = mask.sum(dim=1).tolist()
else:
if _xformers_available:
y = y.squeeze(1).masked_select(mask.unsqueeze(-1) != 0).view(1, -1, x.shape[-1])
y_lens = mask.sum(dim=1).tolist()
else:
y_lens = mask
elif _xformers_available:
y_lens = [y.shape[2]] * y.shape[0]
y = y.squeeze(1).view(1, -1, x.shape[-1])
else:
raise ValueError(f"Attention type is not available due to _xformers_available={_xformers_available}.")

for block in self.blocks:
x = auto_grad_checkpoint(
block, x, y, t0, y_lens, (self.h, self.w), **kwargs
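
The forward() change above keeps two mask representations: with xformers, captions are flattened into one sequence and y_lens holds per-sample token counts (later turned into a BlockDiagonalMask); without xformers, the 0/1 padding mask is passed through and converted to an additive bias inside the attention blocks. A small worked example of the two forms (shapes chosen for illustration):

# Two captions padded to length 5; 1 = real token, 0 = padding.
import torch

mask = torch.tensor([[1, 1, 1, 0, 0],
                     [1, 1, 0, 0, 0]])

# xformers path: per-sample token counts, e.g. for BlockDiagonalMask.from_seqlens.
y_lens = mask.sum(dim=1).tolist()       # [3, 2]

# SDPA fallback path: large negative additive bias at padded positions.
bias = (1 - mask.float()) * -10000.0    # shape (2, 5); 0.0 for real tokens, -10000.0 for padding
bias = bias[:, None, None]              # (2, 1, 1, 5); broadcasts over heads and query positions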