Relative import and xformers/triton dependency fix (#37)
* 1. Update the xformers import logic;
2. Change all `import *` statements to explicit package imports

Signed-off-by: lawrence-cj <[email protected]>

* 1. Change all `import *` statements to explicit imports;
2. Fix bugs;
3. Add `triton` and `xformers` availability checks to import_utils.py

* Support `F.scaled_dot_product_attention` as a fallback when `xformers` is not installed

* pre-commit

Signed-off-by: lawrence-cj <[email protected]>

---------

Signed-off-by: lawrence-cj <[email protected]>
Co-authored-by: Ligeng Zhu <[email protected]>
lawrence-cj and Lyken17 authored Nov 24, 2024
1 parent fcc872a commit fa267d5
Showing 15 changed files with 216 additions and 45 deletions.
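
The commit message's third bullet adds `triton` and `xformers` availability checks to import_utils.py, but that file's hunk is not shown below. As a minimal sketch, assuming the helpers only probe whether the packages are importable (the function names come from the diffs; the bodies are a guess, not the commit's code):

```python
# Hypothetical sketch of the diffusion/utils/import_utils.py helpers that the
# diffs below import; the commit's real implementation may add version checks
# or caching.
import importlib.util


def is_xformers_available() -> bool:
    # True when the xformers package can be found in the current environment.
    return importlib.util.find_spec("xformers") is not None


def is_triton_module_available() -> bool:
    # True when triton (and thus the fused-kernel wrappers built on it) is importable.
    return importlib.util.find_spec("triton") is not None
```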
1 change: 0 additions & 1 deletion diffusion/model/__init__.py
@@ -1 +0,0 @@
from .nets import *
48 changes: 43 additions & 5 deletions diffusion/model/nets/__init__.py
@@ -1,5 +1,43 @@
from .sana import *
from .sana_multi_scale import *
from .sana_multi_scale_adaln import *
from .sana_U_shape import *
from .sana_U_shape_multi_scale import *
from .sana import (
Sana,
SanaBlock,
get_1d_sincos_pos_embed_from_grid,
get_2d_sincos_pos_embed,
get_2d_sincos_pos_embed_from_grid,
)
from .sana_multi_scale import (
SanaMS,
SanaMS_600M_P1_D28,
SanaMS_600M_P2_D28,
SanaMS_600M_P4_D28,
SanaMS_1600M_P1_D20,
SanaMS_1600M_P2_D20,
SanaMSBlock,
)
from .sana_multi_scale_adaln import (
SanaMSAdaLN,
SanaMSAdaLN_600M_P1_D28,
SanaMSAdaLN_600M_P2_D28,
SanaMSAdaLN_600M_P4_D28,
SanaMSAdaLN_1600M_P1_D20,
SanaMSAdaLN_1600M_P2_D20,
SanaMSAdaLNBlock,
)
from .sana_U_shape import (
SanaU,
SanaU_600M_P1_D28,
SanaU_600M_P2_D28,
SanaU_600M_P4_D28,
SanaU_1600M_P1_D20,
SanaU_1600M_P2_D20,
SanaUBlock,
)
from .sana_U_shape_multi_scale import (
SanaUMS,
SanaUMS_600M_P1_D28,
SanaUMS_600M_P2_D28,
SanaUMS_600M_P4_D28,
SanaUMS_1600M_P1_D20,
SanaUMS_1600M_P2_D20,
SanaUMSBlock,
)
16 changes: 15 additions & 1 deletion diffusion/model/nets/sana.py
@@ -24,7 +24,6 @@

from diffusion.model.builder import MODELS
from diffusion.model.nets.basic_modules import DWMlp, GLUMBConv, MBConvPreGLU, Mlp
from diffusion.model.nets.fastlinear.modules import TritonLiteMLA, TritonMBConvPreGLU
from diffusion.model.nets.sana_blocks import (
Attention,
CaptionEmbedder,
@@ -39,8 +38,15 @@
from diffusion.model.norms import RMSNorm
from diffusion.model.utils import auto_grad_checkpoint, to_2tuple
from diffusion.utils.dist_utils import get_rank
from diffusion.utils.import_utils import is_triton_module_available
from diffusion.utils.logger import get_root_logger

_triton_modules_available = False
if is_triton_module_available():
from diffusion.model.nets.fastlinear.modules import TritonLiteMLA, TritonMBConvPreGLU

_triton_modules_available = True


class SanaBlock(nn.Module):
"""
@@ -78,6 +84,10 @@ def __init__(
self_num_heads = hidden_size // linear_head_dim
self.attn = LiteLA(hidden_size, hidden_size, heads=self_num_heads, eps=1e-8, qk_norm=qk_norm)
elif attn_type == "triton_linear":
if not _triton_modules_available:
raise ValueError(
f"{attn_type} type is not available due to _triton_modules_available={_triton_modules_available}."
)
# linear self attention with triton kernel fusion
# TODO: Here the num_heads set to 36 for tmp used
self_num_heads = hidden_size // linear_head_dim
@@ -123,6 +133,10 @@ def __init__(
act=("silu", "silu", None),
)
elif ffn_type == "triton_mbconvpreglu":
if not _triton_modules_available:
raise ValueError(
f"{ffn_type} type is not available due to _triton_modules_available={_triton_modules_available}."
)
self.mlp = TritonMBConvPreGLU(
in_dim=hidden_size,
out_dim=hidden_size,
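
A hypothetical usage sketch (not part of the commit) of the fail-fast guard added above: constructing a block with a triton-backed `attn_type` on a machine without triton now raises a `ValueError` at build time instead of crashing on the old top-level `fastlinear` import. The `SanaBlock` arguments shown are assumed for illustration.

```python
# Assumed constructor arguments; only attn_type="triton_linear" matters here.
from diffusion.model.nets.sana import SanaBlock
from diffusion.utils.import_utils import is_triton_module_available

try:
    block = SanaBlock(hidden_size=1152, num_heads=16, attn_type="triton_linear")
except ValueError as err:
    # Without triton installed, the new guard reports the unavailable attn_type.
    print(f"triton available: {is_triton_module_available()}; error: {err}")
```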
24 changes: 17 additions & 7 deletions diffusion/model/nets/sana_U_shape.py
@@ -23,7 +23,6 @@

from diffusion.model.builder import MODELS
from diffusion.model.nets.basic_modules import DWMlp, GLUMBConv, MBConvPreGLU, Mlp
from diffusion.model.nets.fastlinear.modules import TritonLiteMLA
from diffusion.model.nets.sana import Sana, get_2d_sincos_pos_embed
from diffusion.model.nets.sana_blocks import (
Attention,
@@ -37,9 +36,16 @@
t2i_modulate,
)
from diffusion.model.norms import RMSNorm
from diffusion.model.utils import auto_grad_checkpoint, to_2tuple
from diffusion.model.utils import auto_grad_checkpoint
from diffusion.utils.import_utils import is_triton_module_available
from diffusion.utils.logger import get_root_logger

_triton_modules_available = False
if is_triton_module_available():
from diffusion.model.nets.fastlinear.modules import TritonLiteMLA

_triton_modules_available = True


class SanaUBlock(nn.Module):
"""
@@ -77,6 +83,10 @@ def __init__(
self_num_heads = hidden_size // 32
self.attn = LiteLA(hidden_size, hidden_size, heads=self_num_heads, eps=1e-8, qk_norm=qk_norm)
elif attn_type == "triton_linear":
if not _triton_modules_available:
raise ValueError(
f"{attn_type} type is not available due to _triton_modules_available={_triton_modules_available}."
)
# linear self attention with triton kernel fusion
# TODO: Here the num_heads set to 36 for tmp used
self_num_heads = hidden_size // 32
@@ -343,27 +353,27 @@ def dtype(self):
# SanaU Configs #
#################################################################################
@MODELS.register_module()
def SanaMSU_600M_P1_D28(**kwargs):
def SanaU_600M_P1_D28(**kwargs):
return SanaU(depth=28, hidden_size=1152, patch_size=1, num_heads=16, **kwargs)


@MODELS.register_module()
def SanaMSU_600M_P2_D28(**kwargs):
def SanaU_600M_P2_D28(**kwargs):
return SanaU(depth=28, hidden_size=1152, patch_size=2, num_heads=16, **kwargs)


@MODELS.register_module()
def SanaMSU_600M_P4_D28(**kwargs):
def SanaU_600M_P4_D28(**kwargs):
return SanaU(depth=28, hidden_size=1152, patch_size=4, num_heads=16, **kwargs)


@MODELS.register_module()
def SanaMSU_P1_D20(**kwargs):
def SanaU_1600M_P1_D20(**kwargs):
# 20 layers, 1648.48M
return SanaU(depth=20, hidden_size=2240, patch_size=1, num_heads=20, **kwargs)


@MODELS.register_module()
def SanaMSU_P2_D20(**kwargs):
def SanaU_1600M_P2_D20(**kwargs):
# 20 layers, 1648.48M
return SanaU(depth=20, hidden_size=2240, patch_size=2, num_heads=20, **kwargs)
19 changes: 14 additions & 5 deletions diffusion/model/nets/sana_U_shape_multi_scale.py
@@ -21,7 +21,6 @@

from diffusion.model.builder import MODELS
from diffusion.model.nets.basic_modules import DWMlp, GLUMBConv, MBConvPreGLU, Mlp
from diffusion.model.nets.fastlinear.modules import TritonLiteMLA, TritonLiteMLAFwd
from diffusion.model.nets.sana import Sana, get_2d_sincos_pos_embed
from diffusion.model.nets.sana_blocks import (
Attention,
@@ -30,11 +29,17 @@
LiteLA,
MultiHeadCrossAttention,
PatchEmbedMS,
SizeEmbedder,
T2IFinalLayer,
t2i_modulate,
)
from diffusion.model.utils import auto_grad_checkpoint, to_2tuple
from diffusion.model.utils import auto_grad_checkpoint
from diffusion.utils.import_utils import is_triton_module_available

_triton_modules_available = False
if is_triton_module_available():
from diffusion.model.nets.fastlinear.modules import TritonLiteMLA

_triton_modules_available = True


class SanaUMSBlock(nn.Module):
@@ -74,6 +79,10 @@ def __init__(
self_num_heads = hidden_size // 32
self.attn = LiteLA(hidden_size, hidden_size, heads=self_num_heads, eps=1e-8, qk_norm=qk_norm)
elif attn_type == "triton_linear":
if not _triton_modules_available:
raise ValueError(
f"{attn_type} type is not available due to _triton_modules_available={_triton_modules_available}."
)
# linear self attention with triton kernel fusion
self_num_heads = hidden_size // 32
self.attn = TritonLiteMLA(hidden_size, num_heads=self_num_heads, eps=1e-8)
@@ -365,12 +374,12 @@ def SanaUMS_600M_P4_D28(**kwargs):


@MODELS.register_module()
def SanaUMS_P1_D20(**kwargs):
def SanaUMS_1600M_P1_D20(**kwargs):
# 20 layers, 1648.48M
return SanaUMS(depth=20, hidden_size=2240, patch_size=1, num_heads=20, **kwargs)


@MODELS.register_module()
def SanaUMS_P2_D20(**kwargs):
def SanaUMS_1600M_P2_D20(**kwargs):
# 20 layers, 1648.48M
return SanaUMS(depth=20, hidden_size=2240, patch_size=2, num_heads=20, **kwargs)
44 changes: 34 additions & 10 deletions diffusion/model/nets/sana_blocks.py
@@ -22,14 +22,20 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import xformers.ops
from einops import rearrange
from timm.models.vision_transformer import Attention as Attention_
from timm.models.vision_transformer import Mlp
from transformers import AutoModelForCausalLM

from diffusion.model.norms import RMSNorm
from diffusion.model.utils import get_same_padding, to_2tuple
from diffusion.utils.import_utils import is_xformers_available

_xformers_available = False
if is_xformers_available():
import xformers.ops

_xformers_available = True


def modulate(x, shift, scale):
@@ -65,17 +71,27 @@ def __init__(self, d_model, num_heads, attn_drop=0.0, proj_drop=0.0, qk_norm=Fal
def forward(self, x, cond, mask=None):
# query: img tokens; key/value: condition; mask: if padding tokens
B, N, C = x.shape
first_dim = 1 if _xformers_available else B

q = self.q_linear(x)
kv = self.kv_linear(cond).view(1, -1, 2, C)
kv = self.kv_linear(cond).view(first_dim, -1, 2, C)
k, v = kv.unbind(2)
q = self.q_norm(q).view(1, -1, self.num_heads, self.head_dim)
k = self.k_norm(k).view(1, -1, self.num_heads, self.head_dim)
v = v.view(1, -1, self.num_heads, self.head_dim)
attn_bias = None
if mask is not None:
attn_bias = xformers.ops.fmha.BlockDiagonalMask.from_seqlens([N] * B, mask)
x = xformers.ops.memory_efficient_attention(q, k, v, p=self.attn_drop.p, attn_bias=attn_bias)
q = self.q_norm(q).view(first_dim, -1, self.num_heads, self.head_dim)
k = self.k_norm(k).view(first_dim, -1, self.num_heads, self.head_dim)
v = v.view(first_dim, -1, self.num_heads, self.head_dim)

if _xformers_available:
attn_bias = None
if mask is not None:
attn_bias = xformers.ops.fmha.BlockDiagonalMask.from_seqlens([N] * B, mask)
x = xformers.ops.memory_efficient_attention(q, k, v, p=self.attn_drop.p, attn_bias=attn_bias)
else:
q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
if mask is not None and mask.ndim == 2:
mask = (1 - mask.to(x.dtype)) * -10000.0
mask = mask[:, None, None].repeat(1, self.num_heads, 1, 1)
x = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
x = x.transpose(1, 2)

x = x.view(B, -1, C)
x = self.proj(x)
@@ -347,7 +363,15 @@ def forward(self, x, mask=None, HW=None, block_id=None):
attn_bias = torch.zeros([B * self.num_heads, q.shape[1], k.shape[1]], dtype=q.dtype, device=q.device)
attn_bias.masked_fill_(mask.squeeze(1).repeat(self.num_heads, 1, 1) == 0, float("-inf"))

x = xformers.ops.memory_efficient_attention(q, k, v, p=self.attn_drop.p, attn_bias=attn_bias)
if _xformers_available:
x = xformers.ops.memory_efficient_attention(q, k, v, p=self.attn_drop.p, attn_bias=attn_bias)
else:
q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
if mask is not None and mask.ndim == 2:
mask = (1 - mask.to(x.dtype)) * -10000.0
mask = mask[:, None, None].repeat(1, self.num_heads, 1, 1)
x = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
x = x.transpose(1, 2)

x = x.view(B, N, C)
x = self.proj(x)
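
For reference, a standalone sketch of the layout and mask handling behind the new fallback branches above: `xformers.ops.memory_efficient_attention` consumes `(B, N, heads, head_dim)` tensors with an optional `attn_bias`, while `F.scaled_dot_product_attention` expects `(B, heads, N, head_dim)` plus an additive `attn_mask`, hence the transposes and the `-10000.0` bias. Shapes and values are dummies; only the tensor manipulation mirrors the diff.

```python
import torch
import torch.nn.functional as F

B, N, heads, head_dim = 2, 4, 3, 8
q = torch.randn(B, N, heads, head_dim)                  # xformers-style layout
k, v = torch.randn_like(q), torch.randn_like(q)
key_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])   # 1 = keep, 0 = padding

# Turn the 2D key-padding mask into a large negative additive bias that
# broadcasts over heads and query positions, as the fallback branch does.
bias = (1 - key_mask.float()) * -10000.0
bias = bias[:, None, None].repeat(1, heads, 1, 1)       # (B, heads, 1, N)

out = F.scaled_dot_product_attention(
    q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2),
    attn_mask=bias, dropout_p=0.0, is_causal=False,
).transpose(1, 2)                                       # back to (B, N, heads, head_dim)
print(out.shape)  # torch.Size([2, 4, 3, 8])
```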
35 changes: 29 additions & 6 deletions diffusion/model/nets/sana_multi_scale.py
@@ -21,7 +21,6 @@

from diffusion.model.builder import MODELS
from diffusion.model.nets.basic_modules import DWMlp, GLUMBConv, MBConvPreGLU, Mlp
from diffusion.model.nets.fastlinear.modules import TritonLiteMLA, TritonMBConvPreGLU
from diffusion.model.nets.sana import Sana, get_2d_sincos_pos_embed
from diffusion.model.nets.sana_blocks import (
Attention,
@@ -34,6 +33,17 @@
t2i_modulate,
)
from diffusion.model.utils import auto_grad_checkpoint
from diffusion.utils.import_utils import is_triton_module_available, is_xformers_available

_triton_modules_available = False
if is_triton_module_available():
from diffusion.model.nets.fastlinear.modules import TritonLiteMLA, TritonMBConvPreGLU

_triton_modules_available = True

_xformers_available = False
if is_xformers_available():
_xformers_available = True


class SanaMSBlock(nn.Module):
@@ -74,6 +84,10 @@ def __init__(
self_num_heads = hidden_size // linear_head_dim
self.attn = LiteLA(hidden_size, hidden_size, heads=self_num_heads, eps=1e-8, qk_norm=qk_norm)
elif attn_type == "triton_linear":
if not _triton_modules_available:
raise ValueError(
f"{attn_type} type is not available due to _triton_modules_available={_triton_modules_available}."
)
# linear self attention with triton kernel fusion
self_num_heads = hidden_size // linear_head_dim
self.attn = TritonLiteMLA(hidden_size, num_heads=self_num_heads, eps=1e-8)
@@ -108,6 +122,10 @@ def __init__(
dilation=2,
)
elif ffn_type == "triton_mbconvpreglu":
if not _triton_modules_available:
raise ValueError(
f"{ffn_type} type is not available due to _triton_modules_available={_triton_modules_available}."
)
self.mlp = TritonMBConvPreGLU(
in_dim=hidden_size,
out_dim=hidden_size,
@@ -292,14 +310,19 @@ def forward(self, x, timestep, y, mask=None, data_info=None, **kwargs):
y = self.attention_y_norm(y)

if mask is not None:
if mask.shape[0] != y.shape[0]:
mask = mask.repeat(y.shape[0] // mask.shape[0], 1)
mask = mask.repeat(y.shape[0] // mask.shape[0], 1) if mask.shape[0] != y.shape[0] else mask
mask = mask.squeeze(1).squeeze(1)
y = y.squeeze(1).masked_select(mask.unsqueeze(-1) != 0).view(1, -1, x.shape[-1])
y_lens = mask.sum(dim=1).tolist()
else:
if _xformers_available:
y = y.squeeze(1).masked_select(mask.unsqueeze(-1) != 0).view(1, -1, x.shape[-1])
y_lens = mask.sum(dim=1).tolist()
else:
y_lens = mask
elif _xformers_available:
y_lens = [y.shape[2]] * y.shape[0]
y = y.squeeze(1).view(1, -1, x.shape[-1])
else:
raise ValueError(f"Attention type is not available due to _xformers_available={_xformers_available}.")

for block in self.blocks:
x = auto_grad_checkpoint(
block, x, y, t0, y_lens, (self.h, self.w), **kwargs
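
The `forward()` change above also splits how the caption embeddings `y` and their padding `mask` are prepared: with xformers, valid tokens are packed into one flat sequence and the per-sample lengths (`y_lens`) later build a `BlockDiagonalMask`; without it, the padded batch and the mask tensor itself are handed to the SDPA path. A small illustrative sketch with dummy shapes (B=2, 5 text tokens, C=8) that mirrors the tensor ops in the diff:

```python
import torch

B, n_txt, C = 2, 5, 8
y = torch.randn(B, 1, n_txt, C)                           # caption embeddings
mask = torch.tensor([[1, 1, 1, 0, 0], [1, 1, 0, 0, 0]])   # 1 = valid token

# xformers path: pack every valid token into one flat sequence and keep the
# per-sample lengths for xformers' block-diagonal attention bias.
y_packed = y.squeeze(1).masked_select(mask.unsqueeze(-1) != 0).view(1, -1, C)
y_lens = mask.sum(dim=1).tolist()                         # [3, 2]

# SDPA fallback: keep the padded batch and pass the mask tensor downstream,
# where it is converted into an additive attention bias.
y_padded, y_lens_fallback = y, mask

print(y_packed.shape, y_lens, y_padded.shape)             # (1, 5, 8) [3, 2] (2, 1, 5, 8)
```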
(Diffs for the remaining changed files are not shown.)
