From 88b93c7cc32a5845187865a26d8d352bd7eef300 Mon Sep 17 00:00:00 2001 From: junsong Date: Sun, 24 Nov 2024 02:03:16 -0800 Subject: [PATCH 1/4] 1. xformer import logi update; 2. changing all import * to specific package import Signed-off-by: lawrence-cj --- diffusion/dpm_solver.py | 2 +- diffusion/iddpm.py | 2 +- diffusion/model/nets/__init__.py | 48 +++++++++++++++++-- diffusion/model/nets/sana.py | 6 ++- diffusion/model/nets/sana_U_shape.py | 16 ++++--- .../model/nets/sana_U_shape_multi_scale.py | 10 ++-- diffusion/model/nets/sana_blocks.py | 35 +++++++++++--- diffusion/model/nets/sana_multi_scale.py | 6 ++- .../model/nets/sana_multi_scale_adaln.py | 6 ++- diffusion/model/nets/sana_others.py | 5 ++ diffusion/sa_sampler.py | 2 +- diffusion/utils/import_utils.py | 21 ++++++++ 12 files changed, 133 insertions(+), 26 deletions(-) create mode 100644 diffusion/utils/import_utils.py diff --git a/diffusion/dpm_solver.py b/diffusion/dpm_solver.py index ff8f1a8..d35ae18 100755 --- a/diffusion/dpm_solver.py +++ b/diffusion/dpm_solver.py @@ -16,7 +16,7 @@ import torch -from .model import gaussian_diffusion as gd +from .model.gaussian_diffusion import gaussian_diffusion as gd from .model.dpm_solver import DPM_Solver, NoiseScheduleFlow, NoiseScheduleVP, model_wrapper diff --git a/diffusion/iddpm.py b/diffusion/iddpm.py index 919a822..0b9bfd2 100755 --- a/diffusion/iddpm.py +++ b/diffusion/iddpm.py @@ -20,7 +20,7 @@ # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py from diffusion.model.respace import SpacedDiffusion, space_timesteps -from .model import gaussian_diffusion as gd +from .model.gaussian_diffusion import gaussian_diffusion as gd def Scheduler( diff --git a/diffusion/model/nets/__init__.py b/diffusion/model/nets/__init__.py index 375153d..9a7eaa1 100755 --- a/diffusion/model/nets/__init__.py +++ b/diffusion/model/nets/__init__.py @@ -1,5 +1,43 @@ -from .sana import * -from .sana_multi_scale import * -from .sana_multi_scale_adaln import * -from .sana_U_shape import * -from .sana_U_shape_multi_scale import * +from .sana import ( + Sana, + SanaBlock, + get_2d_sincos_pos_embed, + get_2d_sincos_pos_embed_from_grid, + get_1d_sincos_pos_embed_from_grid, +) +from .sana_multi_scale import ( + SanaMSBlock, + SanaMS, + SanaMS_600M_P1_D28, + SanaMS_600M_P2_D28, + SanaMS_600M_P4_D28, + SanaMS_1600M_P1_D20, + SanaMS_1600M_P2_D20, +) +from .sana_multi_scale_adaln import ( + SanaMSAdaLNBlock, + SanaMSAdaLN, + SanaMSAdaLN_600M_P1_D28, + SanaMSAdaLN_600M_P2_D28, + SanaMSAdaLN_600M_P4_D28, + SanaMSAdaLN_1600M_P1_D20, + SanaMSAdaLN_1600M_P2_D20, +) +from .sana_U_shape import ( + SanaUBlock, + SanaU, + SanaU_600M_P1_D28, + SanaU_600M_P2_D28, + SanaU_600M_P4_D28, + SanaU_1600M_P1_D20, + SanaU_1600M_P2_D20, +) +from .sana_U_shape_multi_scale import ( + SanaUMSBlock, + SanaUMS, + SanaUMS_600M_P1_D28, + SanaUMS_600M_P2_D28, + SanaUMS_600M_P4_D28, + SanaUMS_1600M_P1_D20, + SanaUMS_1600M_P2_D20, +) diff --git a/diffusion/model/nets/sana.py b/diffusion/model/nets/sana.py index 76c318e..63e6b97 100755 --- a/diffusion/model/nets/sana.py +++ b/diffusion/model/nets/sana.py @@ -24,7 +24,11 @@ from diffusion.model.builder import MODELS from diffusion.model.nets.basic_modules import DWMlp, GLUMBConv, MBConvPreGLU, Mlp -from diffusion.model.nets.fastlinear.modules import TritonLiteMLA, TritonMBConvPreGLU +try: + from diffusion.model.nets.fastlinear.modules import TritonLiteMLA, TritonMBConvPreGLU +except ImportError: + import warnings + warnings.warn("TritonLiteMLA and 
TritonMBConvPreGLU with `triton` is not available on your platform.") from diffusion.model.nets.sana_blocks import ( Attention, CaptionEmbedder, diff --git a/diffusion/model/nets/sana_U_shape.py b/diffusion/model/nets/sana_U_shape.py index d16b895..2480f03 100644 --- a/diffusion/model/nets/sana_U_shape.py +++ b/diffusion/model/nets/sana_U_shape.py @@ -23,7 +23,11 @@ from diffusion.model.builder import MODELS from diffusion.model.nets.basic_modules import DWMlp, GLUMBConv, MBConvPreGLU, Mlp -from diffusion.model.nets.fastlinear.modules import TritonLiteMLA +try: + from diffusion.model.nets.fastlinear.modules import TritonLiteMLA +except ImportError: + import warnings + warnings.warn("TritonLiteMLA with `triton` is not available on your platform.") from diffusion.model.nets.sana import Sana, get_2d_sincos_pos_embed from diffusion.model.nets.sana_blocks import ( Attention, @@ -343,27 +347,27 @@ def dtype(self): # SanaU Configs # ################################################################################# @MODELS.register_module() -def SanaMSU_600M_P1_D28(**kwargs): +def SanaU_600M_P1_D28(**kwargs): return SanaU(depth=28, hidden_size=1152, patch_size=1, num_heads=16, **kwargs) @MODELS.register_module() -def SanaMSU_600M_P2_D28(**kwargs): +def SanaU_600M_P2_D28(**kwargs): return SanaU(depth=28, hidden_size=1152, patch_size=2, num_heads=16, **kwargs) @MODELS.register_module() -def SanaMSU_600M_P4_D28(**kwargs): +def SanaU_600M_P4_D28(**kwargs): return SanaU(depth=28, hidden_size=1152, patch_size=4, num_heads=16, **kwargs) @MODELS.register_module() -def SanaMSU_P1_D20(**kwargs): +def SanaU_1600M_P1_D20(**kwargs): # 20 layers, 1648.48M return SanaU(depth=20, hidden_size=2240, patch_size=1, num_heads=20, **kwargs) @MODELS.register_module() -def SanaMSU_P2_D20(**kwargs): +def SanaU_1600M_P2_D20(**kwargs): # 28 layers, 1648.48M return SanaU(depth=20, hidden_size=2240, patch_size=2, num_heads=20, **kwargs) diff --git a/diffusion/model/nets/sana_U_shape_multi_scale.py b/diffusion/model/nets/sana_U_shape_multi_scale.py index 927fc1b..df9f940 100644 --- a/diffusion/model/nets/sana_U_shape_multi_scale.py +++ b/diffusion/model/nets/sana_U_shape_multi_scale.py @@ -21,7 +21,11 @@ from diffusion.model.builder import MODELS from diffusion.model.nets.basic_modules import DWMlp, GLUMBConv, MBConvPreGLU, Mlp -from diffusion.model.nets.fastlinear.modules import TritonLiteMLA, TritonLiteMLAFwd +try: + from diffusion.model.nets.fastlinear.modules import TritonLiteMLA +except ImportError: + import warnings + warnings.warn("TritonLiteMLA with `triton` is not available on your platform.") from diffusion.model.nets.sana import Sana, get_2d_sincos_pos_embed from diffusion.model.nets.sana_blocks import ( Attention, @@ -365,12 +369,12 @@ def SanaUMS_600M_P4_D28(**kwargs): @MODELS.register_module() -def SanaUMS_P1_D20(**kwargs): +def SanaUMS_1600M_P1_D20(**kwargs): # 20 layers, 1648.48M return SanaUMS(depth=20, hidden_size=2240, patch_size=1, num_heads=20, **kwargs) @MODELS.register_module() -def SanaUMS_P2_D20(**kwargs): +def SanaUMS_1600M_P2_D20(**kwargs): # 28 layers, 1648.48M return SanaUMS(depth=20, hidden_size=2240, patch_size=2, num_heads=20, **kwargs) diff --git a/diffusion/model/nets/sana_blocks.py b/diffusion/model/nets/sana_blocks.py index b5e5597..f756f0d 100755 --- a/diffusion/model/nets/sana_blocks.py +++ b/diffusion/model/nets/sana_blocks.py @@ -22,7 +22,7 @@ import torch import torch.nn as nn import torch.nn.functional as F -import xformers.ops +from diffusion.utils.import_utils import 
is_xformers_available from einops import rearrange from timm.models.vision_transformer import Attention as Attention_ from timm.models.vision_transformer import Mlp @@ -32,6 +32,12 @@ from diffusion.model.utils import get_same_padding, to_2tuple +_xformers_available = False +if is_xformers_available(): + import xformers.ops + _xformers_available = True + + def modulate(x, shift, scale): return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) @@ -72,10 +78,19 @@ def forward(self, x, cond, mask=None): q = self.q_norm(q).view(1, -1, self.num_heads, self.head_dim) k = self.k_norm(k).view(1, -1, self.num_heads, self.head_dim) v = v.view(1, -1, self.num_heads, self.head_dim) - attn_bias = None - if mask is not None: - attn_bias = xformers.ops.fmha.BlockDiagonalMask.from_seqlens([N] * B, mask) - x = xformers.ops.memory_efficient_attention(q, k, v, p=self.attn_drop.p, attn_bias=attn_bias) + + if _xformers_available: + attn_bias = None + if mask is not None: + attn_bias = xformers.ops.fmha.BlockDiagonalMask.from_seqlens([N] * B, mask) + x = xformers.ops.memory_efficient_attention(q, k, v, p=self.attn_drop.p, attn_bias=attn_bias) + else: + q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2) + if mask is not None and mask.ndim == 2: + mask = (1 - mask.to(x.dtype)) * -10000.0 + mask = mask[:, None, None].repeat(1, self.num_heads, 1, 1) + x = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False) + x = x.transpose(1, 2) x = x.view(B, -1, C) x = self.proj(x) @@ -347,7 +362,15 @@ def forward(self, x, mask=None, HW=None, block_id=None): attn_bias = torch.zeros([B * self.num_heads, q.shape[1], k.shape[1]], dtype=q.dtype, device=q.device) attn_bias.masked_fill_(mask.squeeze(1).repeat(self.num_heads, 1, 1) == 0, float("-inf")) - x = xformers.ops.memory_efficient_attention(q, k, v, p=self.attn_drop.p, attn_bias=attn_bias) + if _xformers_available: + x = xformers.ops.memory_efficient_attention(q, k, v, p=self.attn_drop.p, attn_bias=attn_bias) + else: + q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2) + if mask is not None and mask.ndim == 2: + mask = (1 - mask.to(x.dtype)) * -10000.0 + mask = mask[:, None, None].repeat(1, self.num_heads, 1, 1) + x = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False) + x = x.transpose(1, 2) x = x.view(B, N, C) x = self.proj(x) diff --git a/diffusion/model/nets/sana_multi_scale.py b/diffusion/model/nets/sana_multi_scale.py index 078d82b..7db1328 100755 --- a/diffusion/model/nets/sana_multi_scale.py +++ b/diffusion/model/nets/sana_multi_scale.py @@ -21,7 +21,11 @@ from diffusion.model.builder import MODELS from diffusion.model.nets.basic_modules import DWMlp, GLUMBConv, MBConvPreGLU, Mlp -from diffusion.model.nets.fastlinear.modules import TritonLiteMLA, TritonMBConvPreGLU +try: + from diffusion.model.nets.fastlinear.modules import TritonLiteMLA, TritonMBConvPreGLU +except ImportError: + import warnings + warnings.warn("TritonLiteMLA and TritonMBConvPreGLU with `triton` is not available on your platform.") from diffusion.model.nets.sana import Sana, get_2d_sincos_pos_embed from diffusion.model.nets.sana_blocks import ( Attention, diff --git a/diffusion/model/nets/sana_multi_scale_adaln.py b/diffusion/model/nets/sana_multi_scale_adaln.py index 8c2ce8a..d915c4d 100644 --- a/diffusion/model/nets/sana_multi_scale_adaln.py +++ b/diffusion/model/nets/sana_multi_scale_adaln.py @@ -21,7 +21,11 @@ from diffusion.model.builder import MODELS from diffusion.model.nets.basic_modules import 
DWMlp, GLUMBConv, MBConvPreGLU, Mlp -from diffusion.model.nets.fastlinear.modules import TritonLiteMLA, TritonLiteMLAFwd +try: + from diffusion.model.nets.fastlinear.modules import TritonLiteMLA +except ImportError: + import warnings + warnings.warn("TritonLiteMLA with `triton` is not available on your platform.") from diffusion.model.nets.sana import Sana, get_2d_sincos_pos_embed from diffusion.model.nets.sana_blocks import ( Attention, diff --git a/diffusion/model/nets/sana_others.py b/diffusion/model/nets/sana_others.py index 519a011..74caf1b 100644 --- a/diffusion/model/nets/sana_others.py +++ b/diffusion/model/nets/sana_others.py @@ -21,6 +21,11 @@ from diffusion.model.nets.basic_modules import DWMlp, MBConvPreGLU, Mlp from diffusion.model.nets.fastlinear.modules import TritonLiteMLA +try: + from diffusion.model.nets.fastlinear.modules import TritonLiteMLA +except ImportError: + import warnings + warnings.warn("TritonLiteMLA with `triton` is not available on your platform.") from diffusion.model.nets.sana_blocks import Attention, FlashAttention, MultiHeadCrossAttention, t2i_modulate diff --git a/diffusion/sa_sampler.py b/diffusion/sa_sampler.py index 79d7f21..dee556f 100755 --- a/diffusion/sa_sampler.py +++ b/diffusion/sa_sampler.py @@ -21,7 +21,7 @@ from diffusion.model.sa_solver import NoiseScheduleVP, SASolver, model_wrapper -from .model import gaussian_diffusion as gd +from .model.gaussian_diffusion import gaussian_diffusion as gd class SASolverSampler: diff --git a/diffusion/utils/import_utils.py b/diffusion/utils/import_utils.py new file mode 100644 index 0000000..562160d --- /dev/null +++ b/diffusion/utils/import_utils.py @@ -0,0 +1,21 @@ +import importlib.util +import importlib_metadata +from packaging import version +import logging + +logger = logging.getLogger(__name__) + +_xformers_available = importlib.util.find_spec("xformers") is not None +try: + if _xformers_available: + _xformers_version = importlib_metadata.version("xformers") + _torch_version = importlib_metadata.version("torch") + if version.Version(_torch_version) < version.Version("1.12"): + raise ValueError("xformers is installed but requires PyTorch >= 1.12") + logger.debug(f"Successfully imported xformers version {_xformers_version}") +except importlib_metadata.PackageNotFoundError: + _xformers_available = False + + +def is_xformers_available(): + return _xformers_available \ No newline at end of file From 6265a23c53aca52f7434b2ab9c1c2a39f26ae602 Mon Sep 17 00:00:00 2001 From: junsong Date: Sun, 24 Nov 2024 02:46:38 -0800 Subject: [PATCH 2/4] 1. change all import * to specific package; 2. fix bugs 3. 
add `triton` and `xformers` checking into import_utils.py --- diffusion/dpm_solver.py | 2 +- diffusion/iddpm.py | 2 +- diffusion/model/__init__.py | 1 - diffusion/model/nets/sana.py | 15 ++++++++++----- diffusion/model/nets/sana_U_shape.py | 8 ++++++++ diffusion/model/nets/sana_U_shape_multi_scale.py | 12 +++++++----- diffusion/model/nets/sana_multi_scale.py | 15 ++++++++++----- diffusion/model/nets/sana_multi_scale_adaln.py | 12 ++++++++++-- diffusion/sa_sampler.py | 2 +- diffusion/utils/import_utils.py | 16 +++++++++++++++- scripts/inference.py | 2 +- scripts/inference_dpg.py | 2 +- scripts/inference_geneval.py | 2 +- scripts/inference_image_reward.py | 2 +- scripts/interface.py | 2 +- 15 files changed, 68 insertions(+), 27 deletions(-) diff --git a/diffusion/dpm_solver.py b/diffusion/dpm_solver.py index d35ae18..ff8f1a8 100755 --- a/diffusion/dpm_solver.py +++ b/diffusion/dpm_solver.py @@ -16,7 +16,7 @@ import torch -from .model.gaussian_diffusion import gaussian_diffusion as gd +from .model import gaussian_diffusion as gd from .model.dpm_solver import DPM_Solver, NoiseScheduleFlow, NoiseScheduleVP, model_wrapper diff --git a/diffusion/iddpm.py b/diffusion/iddpm.py index 0b9bfd2..919a822 100755 --- a/diffusion/iddpm.py +++ b/diffusion/iddpm.py @@ -20,7 +20,7 @@ # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py from diffusion.model.respace import SpacedDiffusion, space_timesteps -from .model.gaussian_diffusion import gaussian_diffusion as gd +from .model import gaussian_diffusion as gd def Scheduler( diff --git a/diffusion/model/__init__.py b/diffusion/model/__init__.py index 5a0d275..e69de29 100755 --- a/diffusion/model/__init__.py +++ b/diffusion/model/__init__.py @@ -1 +0,0 @@ -from .nets import * diff --git a/diffusion/model/nets/sana.py b/diffusion/model/nets/sana.py index 63e6b97..d14bb3d 100755 --- a/diffusion/model/nets/sana.py +++ b/diffusion/model/nets/sana.py @@ -24,11 +24,6 @@ from diffusion.model.builder import MODELS from diffusion.model.nets.basic_modules import DWMlp, GLUMBConv, MBConvPreGLU, Mlp -try: - from diffusion.model.nets.fastlinear.modules import TritonLiteMLA, TritonMBConvPreGLU -except ImportError: - import warnings - warnings.warn("TritonLiteMLA and TritonMBConvPreGLU with `triton` is not available on your platform.") from diffusion.model.nets.sana_blocks import ( Attention, CaptionEmbedder, @@ -43,8 +38,14 @@ from diffusion.model.norms import RMSNorm from diffusion.model.utils import auto_grad_checkpoint, to_2tuple from diffusion.utils.dist_utils import get_rank +from diffusion.utils.import_utils import is_triton_module_available from diffusion.utils.logger import get_root_logger +_triton_modules_available = False +if is_triton_module_available(): + from diffusion.model.nets.fastlinear.modules import TritonLiteMLA, TritonMBConvPreGLU + _triton_modules_available = True + class SanaBlock(nn.Module): """ @@ -82,6 +83,8 @@ def __init__( self_num_heads = hidden_size // linear_head_dim self.attn = LiteLA(hidden_size, hidden_size, heads=self_num_heads, eps=1e-8, qk_norm=qk_norm) elif attn_type == "triton_linear": + if not _triton_modules_available: + raise ValueError(f"{attn_type} type is not available due to _triton_modules_available={_triton_modules_available}.") # linear self attention with triton kernel fusion # TODO: Here the num_heads set to 36 for tmp used self_num_heads = hidden_size // linear_head_dim @@ -127,6 +130,8 @@ def __init__( act=("silu", "silu", None), ) elif ffn_type == "triton_mbconvpreglu": 
+ if not _triton_modules_available: + raise ValueError(f"{ffn_type} type is not available due to _triton_modules_available={_triton_modules_available}.") self.mlp = TritonMBConvPreGLU( in_dim=hidden_size, out_dim=hidden_size, diff --git a/diffusion/model/nets/sana_U_shape.py b/diffusion/model/nets/sana_U_shape.py index 2480f03..21fdbda 100644 --- a/diffusion/model/nets/sana_U_shape.py +++ b/diffusion/model/nets/sana_U_shape.py @@ -43,6 +43,12 @@ from diffusion.model.norms import RMSNorm from diffusion.model.utils import auto_grad_checkpoint, to_2tuple from diffusion.utils.logger import get_root_logger +from diffusion.utils.import_utils import is_triton_module_available + +_triton_modules_available = False +if is_triton_module_available(): + from diffusion.model.nets.fastlinear.modules import TritonLiteMLA, TritonMBConvPreGLU + _triton_modules_available = True class SanaUBlock(nn.Module): @@ -81,6 +87,8 @@ def __init__( self_num_heads = hidden_size // 32 self.attn = LiteLA(hidden_size, hidden_size, heads=self_num_heads, eps=1e-8, qk_norm=qk_norm) elif attn_type == "triton_linear": + if not _triton_modules_available: + raise ValueError(f"{attn_type} type is not available due to _triton_modules_available={_triton_modules_available}.") # linear self attention with triton kernel fusion # TODO: Here the num_heads set to 36 for tmp used self_num_heads = hidden_size // 32 diff --git a/diffusion/model/nets/sana_U_shape_multi_scale.py b/diffusion/model/nets/sana_U_shape_multi_scale.py index df9f940..c4cfb86 100644 --- a/diffusion/model/nets/sana_U_shape_multi_scale.py +++ b/diffusion/model/nets/sana_U_shape_multi_scale.py @@ -21,11 +21,6 @@ from diffusion.model.builder import MODELS from diffusion.model.nets.basic_modules import DWMlp, GLUMBConv, MBConvPreGLU, Mlp -try: - from diffusion.model.nets.fastlinear.modules import TritonLiteMLA -except ImportError: - import warnings - warnings.warn("TritonLiteMLA with `triton` is not available on your platform.") from diffusion.model.nets.sana import Sana, get_2d_sincos_pos_embed from diffusion.model.nets.sana_blocks import ( Attention, @@ -39,7 +34,12 @@ t2i_modulate, ) from diffusion.model.utils import auto_grad_checkpoint, to_2tuple +from diffusion.utils.import_utils import is_triton_module_available +_triton_modules_available = False +if is_triton_module_available(): + from diffusion.model.nets.fastlinear.modules import TritonLiteMLA, TritonMBConvPreGLU + _triton_modules_available = True class SanaUMSBlock(nn.Module): """ @@ -78,6 +78,8 @@ def __init__( self_num_heads = hidden_size // 32 self.attn = LiteLA(hidden_size, hidden_size, heads=self_num_heads, eps=1e-8, qk_norm=qk_norm) elif attn_type == "triton_linear": + if not _triton_modules_available: + raise ValueError(f"{attn_type} type is not available due to _triton_modules_available={_triton_modules_available}.") # linear self attention with triton kernel fusion self_num_heads = hidden_size // 32 self.attn = TritonLiteMLA(hidden_size, num_heads=self_num_heads, eps=1e-8) diff --git a/diffusion/model/nets/sana_multi_scale.py b/diffusion/model/nets/sana_multi_scale.py index 7db1328..5fc2c22 100755 --- a/diffusion/model/nets/sana_multi_scale.py +++ b/diffusion/model/nets/sana_multi_scale.py @@ -21,11 +21,6 @@ from diffusion.model.builder import MODELS from diffusion.model.nets.basic_modules import DWMlp, GLUMBConv, MBConvPreGLU, Mlp -try: - from diffusion.model.nets.fastlinear.modules import TritonLiteMLA, TritonMBConvPreGLU -except ImportError: - import warnings - warnings.warn("TritonLiteMLA and 
TritonMBConvPreGLU with `triton` is not available on your platform.") from diffusion.model.nets.sana import Sana, get_2d_sincos_pos_embed from diffusion.model.nets.sana_blocks import ( Attention, @@ -38,6 +33,12 @@ t2i_modulate, ) from diffusion.model.utils import auto_grad_checkpoint +from diffusion.utils.import_utils import is_triton_module_available + +_triton_modules_available = False +if is_triton_module_available(): + from diffusion.model.nets.fastlinear.modules import TritonLiteMLA, TritonMBConvPreGLU + _triton_modules_available = True class SanaMSBlock(nn.Module): @@ -78,6 +79,8 @@ def __init__( self_num_heads = hidden_size // linear_head_dim self.attn = LiteLA(hidden_size, hidden_size, heads=self_num_heads, eps=1e-8, qk_norm=qk_norm) elif attn_type == "triton_linear": + if not _triton_modules_available: + raise ValueError(f"{attn_type} type is not available due to _triton_modules_available={_triton_modules_available}.") # linear self attention with triton kernel fusion self_num_heads = hidden_size // linear_head_dim self.attn = TritonLiteMLA(hidden_size, num_heads=self_num_heads, eps=1e-8) @@ -112,6 +115,8 @@ def __init__( dilation=2, ) elif ffn_type == "triton_mbconvpreglu": + if not _triton_modules_available: + raise ValueError(f"{ffn_type} type is not available due to _triton_modules_available={_triton_modules_available}.") self.mlp = TritonMBConvPreGLU( in_dim=hidden_size, out_dim=hidden_size, diff --git a/diffusion/model/nets/sana_multi_scale_adaln.py b/diffusion/model/nets/sana_multi_scale_adaln.py index d915c4d..569350f 100644 --- a/diffusion/model/nets/sana_multi_scale_adaln.py +++ b/diffusion/model/nets/sana_multi_scale_adaln.py @@ -39,6 +39,12 @@ modulate, ) from diffusion.model.utils import auto_grad_checkpoint, to_2tuple +from diffusion.utils.import_utils import is_triton_module_available + +_triton_modules_available = False +if is_triton_module_available(): + from diffusion.model.nets.fastlinear.modules import TritonLiteMLA, TritonMBConvPreGLU + _triton_modules_available = True class SanaMSAdaLNBlock(nn.Module): @@ -77,6 +83,8 @@ def __init__( self_num_heads = hidden_size // 32 self.attn = LiteLA(hidden_size, hidden_size, heads=self_num_heads, eps=1e-8, qk_norm=qk_norm) elif attn_type == "triton_linear": + if not _triton_modules_available: + raise ValueError(f"{attn_type} type is not available due to _triton_modules_available={_triton_modules_available}.") # linear self attention with triton kernel fusion self_num_heads = hidden_size // 32 self.attn = TritonLiteMLA(hidden_size, num_heads=self_num_heads, eps=1e-8) @@ -375,12 +383,12 @@ def SanaMSAdaLN_600M_P4_D28(**kwargs): @MODELS.register_module() -def SanaMSAdaLN_P1_D20(**kwargs): +def SanaMSAdaLN_1600M_P1_D20(**kwargs): # 20 layers, 1648.48M return SanaMSAdaLN(depth=20, hidden_size=2240, patch_size=1, num_heads=20, **kwargs) @MODELS.register_module() -def SanaMSAdaLN_P2_D20(**kwargs): +def SanaMSAdaLN_1600M_P2_D20(**kwargs): # 28 layers, 1648.48M return SanaMSAdaLN(depth=20, hidden_size=2240, patch_size=2, num_heads=20, **kwargs) diff --git a/diffusion/sa_sampler.py b/diffusion/sa_sampler.py index dee556f..79d7f21 100755 --- a/diffusion/sa_sampler.py +++ b/diffusion/sa_sampler.py @@ -21,7 +21,7 @@ from diffusion.model.sa_solver import NoiseScheduleVP, SASolver, model_wrapper -from .model.gaussian_diffusion import gaussian_diffusion as gd +from .model import gaussian_diffusion as gd class SASolverSampler: diff --git a/diffusion/utils/import_utils.py b/diffusion/utils/import_utils.py index 562160d..7c011ee 100644 --- 
a/diffusion/utils/import_utils.py +++ b/diffusion/utils/import_utils.py @@ -2,6 +2,7 @@ import importlib_metadata from packaging import version import logging +import warnings logger = logging.getLogger(__name__) @@ -16,6 +17,19 @@ except importlib_metadata.PackageNotFoundError: _xformers_available = False +_triton_modules_available = importlib.util.find_spec("triton") is not None +try: + if _triton_modules_available: + _triton_version = importlib_metadata.version("triton") + if version.Version(_triton_version) < version.Version("3.0.0"): + raise ValueError("triton is installed but requires Triton >= 3.0.0") + logger.debug(f"Successfully imported triton version {_triton_version}") +except ImportError: + _triton_modules_available = False + warnings.warn("TritonLiteMLA and TritonMBConvPreGLU with `triton` is not available on your platform.") def is_xformers_available(): - return _xformers_available \ No newline at end of file + return _xformers_available + +def is_triton_module_available(): + return _triton_modules_available diff --git a/scripts/inference.py b/scripts/inference.py index 1c9a24e..52d21fd 100755 --- a/scripts/inference.py +++ b/scripts/inference.py @@ -35,7 +35,7 @@ warnings.filterwarnings("ignore") # ignore warning from diffusion import DPMS, FlowEuler, SASolverSampler -from diffusion.data.datasets.utils import * +from diffusion.data.datasets.utils import ASPECT_RATIO_512_TEST, ASPECT_RATIO_1024_TEST, ASPECT_RATIO_2048_TEST from diffusion.model.builder import build_model, get_tokenizer_and_text_encoder, get_vae, vae_decode from diffusion.model.utils import prepare_prompt_ar from diffusion.utils.config import SanaConfig diff --git a/scripts/inference_dpg.py b/scripts/inference_dpg.py index 5f6805a..bde68c9 100644 --- a/scripts/inference_dpg.py +++ b/scripts/inference_dpg.py @@ -34,7 +34,7 @@ warnings.filterwarnings("ignore") # ignore warning from diffusion import DPMS, FlowEuler, SASolverSampler -from diffusion.data.datasets.utils import * +from diffusion.data.datasets.utils import ASPECT_RATIO_512_TEST, ASPECT_RATIO_1024_TEST, ASPECT_RATIO_2048_TEST from diffusion.model.builder import build_model, get_tokenizer_and_text_encoder, get_vae, vae_decode from diffusion.model.utils import prepare_prompt_ar from diffusion.utils.config import SanaConfig diff --git a/scripts/inference_geneval.py b/scripts/inference_geneval.py index 05cb920..07cef9f 100644 --- a/scripts/inference_geneval.py +++ b/scripts/inference_geneval.py @@ -36,7 +36,7 @@ warnings.filterwarnings("ignore") # ignore warning from diffusion import DPMS, FlowEuler, SASolverSampler -from diffusion.data.datasets.utils import * +from diffusion.data.datasets.utils import ASPECT_RATIO_512_TEST, ASPECT_RATIO_1024_TEST, ASPECT_RATIO_2048_TEST from diffusion.model.builder import build_model, get_tokenizer_and_text_encoder, get_vae, vae_decode from diffusion.model.utils import prepare_prompt_ar from diffusion.utils.config import SanaConfig diff --git a/scripts/inference_image_reward.py b/scripts/inference_image_reward.py index 972541c..fb3e1a5 100644 --- a/scripts/inference_image_reward.py +++ b/scripts/inference_image_reward.py @@ -33,7 +33,7 @@ from tqdm import tqdm from diffusion import DPMS, FlowEuler, SASolverSampler -from diffusion.data.datasets.utils import * +from diffusion.data.datasets.utils import ASPECT_RATIO_512_TEST, ASPECT_RATIO_1024_TEST, ASPECT_RATIO_2048_TEST from diffusion.model.builder import build_model, get_tokenizer_and_text_encoder, get_vae, vae_decode from diffusion.model.utils import prepare_prompt_ar 
from diffusion.utils.config import SanaConfig diff --git a/scripts/interface.py b/scripts/interface.py index e64c0f9..d81199c 100755 --- a/scripts/interface.py +++ b/scripts/interface.py @@ -34,7 +34,7 @@ from asset.examples import examples from diffusion import DPMS, FlowEuler, SASolverSampler -from diffusion.data.datasets.utils import * +from diffusion.data.datasets.utils import ASPECT_RATIO_512_TEST, ASPECT_RATIO_1024_TEST, ASPECT_RATIO_2048_TEST from diffusion.model.builder import build_model, get_tokenizer_and_text_encoder, get_vae, vae_decode from diffusion.model.utils import prepare_prompt_ar, resize_and_crop_tensor from diffusion.utils.config import SanaConfig From cc5991b48427e72f924829e1bfe8db3ac52d41eb Mon Sep 17 00:00:00 2001 From: junsong Date: Sun, 24 Nov 2024 03:08:38 -0800 Subject: [PATCH 3/4] support `F.scaled_dot_product_attention` without `xformers` --- diffusion/model/nets/sana_blocks.py | 9 +++++---- diffusion/model/nets/sana_multi_scale.py | 21 +++++++++++++++------ 2 files changed, 20 insertions(+), 10 deletions(-) diff --git a/diffusion/model/nets/sana_blocks.py b/diffusion/model/nets/sana_blocks.py index f756f0d..2121f3b 100755 --- a/diffusion/model/nets/sana_blocks.py +++ b/diffusion/model/nets/sana_blocks.py @@ -71,13 +71,14 @@ def __init__(self, d_model, num_heads, attn_drop=0.0, proj_drop=0.0, qk_norm=Fal def forward(self, x, cond, mask=None): # query: img tokens; key/value: condition; mask: if padding tokens B, N, C = x.shape + first_dim = 1 if _xformers_available else B q = self.q_linear(x) - kv = self.kv_linear(cond).view(1, -1, 2, C) + kv = self.kv_linear(cond).view(first_dim, -1, 2, C) k, v = kv.unbind(2) - q = self.q_norm(q).view(1, -1, self.num_heads, self.head_dim) - k = self.k_norm(k).view(1, -1, self.num_heads, self.head_dim) - v = v.view(1, -1, self.num_heads, self.head_dim) + q = self.q_norm(q).view(first_dim, -1, self.num_heads, self.head_dim) + k = self.k_norm(k).view(first_dim, -1, self.num_heads, self.head_dim) + v = v.view(first_dim, -1, self.num_heads, self.head_dim) if _xformers_available: attn_bias = None diff --git a/diffusion/model/nets/sana_multi_scale.py b/diffusion/model/nets/sana_multi_scale.py index 5fc2c22..7546e49 100755 --- a/diffusion/model/nets/sana_multi_scale.py +++ b/diffusion/model/nets/sana_multi_scale.py @@ -33,13 +33,17 @@ t2i_modulate, ) from diffusion.model.utils import auto_grad_checkpoint -from diffusion.utils.import_utils import is_triton_module_available +from diffusion.utils.import_utils import is_triton_module_available, is_xformers_available _triton_modules_available = False if is_triton_module_available(): from diffusion.model.nets.fastlinear.modules import TritonLiteMLA, TritonMBConvPreGLU _triton_modules_available = True +_xformers_available = False +if is_xformers_available(): + import xformers.ops + _xformers_available = True class SanaMSBlock(nn.Module): """ @@ -301,14 +305,19 @@ def forward(self, x, timestep, y, mask=None, data_info=None, **kwargs): y = self.attention_y_norm(y) if mask is not None: - if mask.shape[0] != y.shape[0]: - mask = mask.repeat(y.shape[0] // mask.shape[0], 1) + mask = mask.repeat(y.shape[0] // mask.shape[0], 1) if mask.shape[0] != y.shape[0] else mask mask = mask.squeeze(1).squeeze(1) - y = y.squeeze(1).masked_select(mask.unsqueeze(-1) != 0).view(1, -1, x.shape[-1]) - y_lens = mask.sum(dim=1).tolist() - else: + if _xformers_available: + y = y.squeeze(1).masked_select(mask.unsqueeze(-1) != 0).view(1, -1, x.shape[-1]) + y_lens = mask.sum(dim=1).tolist() + else: + y_lens = mask + elif 
_xformers_available: y_lens = [y.shape[2]] * y.shape[0] y = y.squeeze(1).view(1, -1, x.shape[-1]) + else: + raise ValueError(f"{attn_type} type is not available due to _xformers_available={_xformers_available}.") + for block in self.blocks: x = auto_grad_checkpoint( block, x, y, t0, y_lens, (self.h, self.w), **kwargs From e55cbc7d724d8bee60c6204c2548f7be8c5704bd Mon Sep 17 00:00:00 2001 From: Ligeng Zhu Date: Sun, 24 Nov 2024 19:21:25 +0800 Subject: [PATCH 4/4] pre-commit Signed-off-by: lawrence-cj --- diffusion/model/nets/__init__.py | 20 +++++++++---------- diffusion/model/nets/sana.py | 11 +++++++--- diffusion/model/nets/sana_U_shape.py | 16 +++++++-------- .../model/nets/sana_U_shape_multi_scale.py | 11 ++++++---- diffusion/model/nets/sana_blocks.py | 4 ++-- diffusion/model/nets/sana_multi_scale.py | 15 +++++++++----- .../model/nets/sana_multi_scale_adaln.py | 14 ++++++------- diffusion/model/nets/sana_others.py | 13 ++++++------ diffusion/utils/import_utils.py | 7 +++++-- 9 files changed, 62 insertions(+), 49 deletions(-) diff --git a/diffusion/model/nets/__init__.py b/diffusion/model/nets/__init__.py index 9a7eaa1..a7ac826 100755 --- a/diffusion/model/nets/__init__.py +++ b/diffusion/model/nets/__init__.py @@ -1,43 +1,43 @@ from .sana import ( - Sana, - SanaBlock, - get_2d_sincos_pos_embed, - get_2d_sincos_pos_embed_from_grid, + Sana, + SanaBlock, get_1d_sincos_pos_embed_from_grid, + get_2d_sincos_pos_embed, + get_2d_sincos_pos_embed_from_grid, ) from .sana_multi_scale import ( - SanaMSBlock, - SanaMS, - SanaMS_600M_P1_D28, + SanaMS, + SanaMS_600M_P1_D28, SanaMS_600M_P2_D28, SanaMS_600M_P4_D28, SanaMS_1600M_P1_D20, SanaMS_1600M_P2_D20, + SanaMSBlock, ) from .sana_multi_scale_adaln import ( - SanaMSAdaLNBlock, SanaMSAdaLN, SanaMSAdaLN_600M_P1_D28, SanaMSAdaLN_600M_P2_D28, SanaMSAdaLN_600M_P4_D28, SanaMSAdaLN_1600M_P1_D20, SanaMSAdaLN_1600M_P2_D20, + SanaMSAdaLNBlock, ) from .sana_U_shape import ( - SanaUBlock, SanaU, SanaU_600M_P1_D28, SanaU_600M_P2_D28, SanaU_600M_P4_D28, SanaU_1600M_P1_D20, SanaU_1600M_P2_D20, + SanaUBlock, ) from .sana_U_shape_multi_scale import ( - SanaUMSBlock, SanaUMS, SanaUMS_600M_P1_D28, SanaUMS_600M_P2_D28, SanaUMS_600M_P4_D28, SanaUMS_1600M_P1_D20, SanaUMS_1600M_P2_D20, + SanaUMSBlock, ) diff --git a/diffusion/model/nets/sana.py b/diffusion/model/nets/sana.py index d14bb3d..fc236f0 100755 --- a/diffusion/model/nets/sana.py +++ b/diffusion/model/nets/sana.py @@ -43,7 +43,8 @@ _triton_modules_available = False if is_triton_module_available(): - from diffusion.model.nets.fastlinear.modules import TritonLiteMLA, TritonMBConvPreGLU + from diffusion.model.nets.fastlinear.modules import TritonLiteMLA, TritonMBConvPreGLU + _triton_modules_available = True @@ -84,7 +85,9 @@ def __init__( self.attn = LiteLA(hidden_size, hidden_size, heads=self_num_heads, eps=1e-8, qk_norm=qk_norm) elif attn_type == "triton_linear": if not _triton_modules_available: - raise ValueError(f"{attn_type} type is not available due to _triton_modules_available={_triton_modules_available}.") + raise ValueError( + f"{attn_type} type is not available due to _triton_modules_available={_triton_modules_available}." 
+ ) # linear self attention with triton kernel fusion # TODO: Here the num_heads set to 36 for tmp used self_num_heads = hidden_size // linear_head_dim @@ -131,7 +134,9 @@ def __init__( ) elif ffn_type == "triton_mbconvpreglu": if not _triton_modules_available: - raise ValueError(f"{ffn_type} type is not available due to _triton_modules_available={_triton_modules_available}.") + raise ValueError( + f"{ffn_type} type is not available due to _triton_modules_available={_triton_modules_available}." + ) self.mlp = TritonMBConvPreGLU( in_dim=hidden_size, out_dim=hidden_size, diff --git a/diffusion/model/nets/sana_U_shape.py b/diffusion/model/nets/sana_U_shape.py index 21fdbda..35f5cbf 100644 --- a/diffusion/model/nets/sana_U_shape.py +++ b/diffusion/model/nets/sana_U_shape.py @@ -23,11 +23,6 @@ from diffusion.model.builder import MODELS from diffusion.model.nets.basic_modules import DWMlp, GLUMBConv, MBConvPreGLU, Mlp -try: - from diffusion.model.nets.fastlinear.modules import TritonLiteMLA -except ImportError: - import warnings - warnings.warn("TritonLiteMLA with `triton` is not available on your platform.") from diffusion.model.nets.sana import Sana, get_2d_sincos_pos_embed from diffusion.model.nets.sana_blocks import ( Attention, @@ -41,13 +36,14 @@ t2i_modulate, ) from diffusion.model.norms import RMSNorm -from diffusion.model.utils import auto_grad_checkpoint, to_2tuple -from diffusion.utils.logger import get_root_logger +from diffusion.model.utils import auto_grad_checkpoint from diffusion.utils.import_utils import is_triton_module_available +from diffusion.utils.logger import get_root_logger _triton_modules_available = False if is_triton_module_available(): - from diffusion.model.nets.fastlinear.modules import TritonLiteMLA, TritonMBConvPreGLU + from diffusion.model.nets.fastlinear.modules import TritonLiteMLA + _triton_modules_available = True @@ -88,7 +84,9 @@ def __init__( self.attn = LiteLA(hidden_size, hidden_size, heads=self_num_heads, eps=1e-8, qk_norm=qk_norm) elif attn_type == "triton_linear": if not _triton_modules_available: - raise ValueError(f"{attn_type} type is not available due to _triton_modules_available={_triton_modules_available}.") + raise ValueError( + f"{attn_type} type is not available due to _triton_modules_available={_triton_modules_available}." + ) # linear self attention with triton kernel fusion # TODO: Here the num_heads set to 36 for tmp used self_num_heads = hidden_size // 32 diff --git a/diffusion/model/nets/sana_U_shape_multi_scale.py b/diffusion/model/nets/sana_U_shape_multi_scale.py index c4cfb86..13dfd4a 100644 --- a/diffusion/model/nets/sana_U_shape_multi_scale.py +++ b/diffusion/model/nets/sana_U_shape_multi_scale.py @@ -29,18 +29,19 @@ LiteLA, MultiHeadCrossAttention, PatchEmbedMS, - SizeEmbedder, T2IFinalLayer, t2i_modulate, ) -from diffusion.model.utils import auto_grad_checkpoint, to_2tuple +from diffusion.model.utils import auto_grad_checkpoint from diffusion.utils.import_utils import is_triton_module_available _triton_modules_available = False if is_triton_module_available(): - from diffusion.model.nets.fastlinear.modules import TritonLiteMLA, TritonMBConvPreGLU + from diffusion.model.nets.fastlinear.modules import TritonLiteMLA + _triton_modules_available = True + class SanaUMSBlock(nn.Module): """ A SanaU block with global shared adaptive layer norm (adaLN-single) conditioning and U-shaped model. 
@@ -79,7 +80,9 @@ def __init__( self.attn = LiteLA(hidden_size, hidden_size, heads=self_num_heads, eps=1e-8, qk_norm=qk_norm) elif attn_type == "triton_linear": if not _triton_modules_available: - raise ValueError(f"{attn_type} type is not available due to _triton_modules_available={_triton_modules_available}.") + raise ValueError( + f"{attn_type} type is not available due to _triton_modules_available={_triton_modules_available}." + ) # linear self attention with triton kernel fusion self_num_heads = hidden_size // 32 self.attn = TritonLiteMLA(hidden_size, num_heads=self_num_heads, eps=1e-8) diff --git a/diffusion/model/nets/sana_blocks.py b/diffusion/model/nets/sana_blocks.py index 2121f3b..b50e64b 100755 --- a/diffusion/model/nets/sana_blocks.py +++ b/diffusion/model/nets/sana_blocks.py @@ -22,7 +22,6 @@ import torch import torch.nn as nn import torch.nn.functional as F -from diffusion.utils.import_utils import is_xformers_available from einops import rearrange from timm.models.vision_transformer import Attention as Attention_ from timm.models.vision_transformer import Mlp @@ -30,11 +29,12 @@ from diffusion.model.norms import RMSNorm from diffusion.model.utils import get_same_padding, to_2tuple - +from diffusion.utils.import_utils import is_xformers_available _xformers_available = False if is_xformers_available(): import xformers.ops + _xformers_available = True diff --git a/diffusion/model/nets/sana_multi_scale.py b/diffusion/model/nets/sana_multi_scale.py index 7546e49..b79d230 100755 --- a/diffusion/model/nets/sana_multi_scale.py +++ b/diffusion/model/nets/sana_multi_scale.py @@ -37,14 +37,15 @@ _triton_modules_available = False if is_triton_module_available(): - from diffusion.model.nets.fastlinear.modules import TritonLiteMLA, TritonMBConvPreGLU + from diffusion.model.nets.fastlinear.modules import TritonLiteMLA, TritonMBConvPreGLU + _triton_modules_available = True _xformers_available = False if is_xformers_available(): - import xformers.ops _xformers_available = True + class SanaMSBlock(nn.Module): """ A Sana block with global shared adaptive layer norm zero (adaLN-Zero) conditioning. @@ -84,7 +85,9 @@ def __init__( self.attn = LiteLA(hidden_size, hidden_size, heads=self_num_heads, eps=1e-8, qk_norm=qk_norm) elif attn_type == "triton_linear": if not _triton_modules_available: - raise ValueError(f"{attn_type} type is not available due to _triton_modules_available={_triton_modules_available}.") + raise ValueError( + f"{attn_type} type is not available due to _triton_modules_available={_triton_modules_available}." + ) # linear self attention with triton kernel fusion self_num_heads = hidden_size // linear_head_dim self.attn = TritonLiteMLA(hidden_size, num_heads=self_num_heads, eps=1e-8) @@ -120,7 +123,9 @@ def __init__( ) elif ffn_type == "triton_mbconvpreglu": if not _triton_modules_available: - raise ValueError(f"{ffn_type} type is not available due to _triton_modules_available={_triton_modules_available}.") + raise ValueError( + f"{ffn_type} type is not available due to _triton_modules_available={_triton_modules_available}." 
+ ) self.mlp = TritonMBConvPreGLU( in_dim=hidden_size, out_dim=hidden_size, @@ -316,7 +321,7 @@ def forward(self, x, timestep, y, mask=None, data_info=None, **kwargs): y_lens = [y.shape[2]] * y.shape[0] y = y.squeeze(1).view(1, -1, x.shape[-1]) else: - raise ValueError(f"{attn_type} type is not available due to _xformers_available={_xformers_available}.") + raise ValueError(f"Attention type is not available due to _xformers_available={_xformers_available}.") for block in self.blocks: x = auto_grad_checkpoint( diff --git a/diffusion/model/nets/sana_multi_scale_adaln.py b/diffusion/model/nets/sana_multi_scale_adaln.py index 569350f..7564740 100644 --- a/diffusion/model/nets/sana_multi_scale_adaln.py +++ b/diffusion/model/nets/sana_multi_scale_adaln.py @@ -21,11 +21,6 @@ from diffusion.model.builder import MODELS from diffusion.model.nets.basic_modules import DWMlp, GLUMBConv, MBConvPreGLU, Mlp -try: - from diffusion.model.nets.fastlinear.modules import TritonLiteMLA -except ImportError: - import warnings - warnings.warn("TritonLiteMLA with `triton` is not available on your platform.") from diffusion.model.nets.sana import Sana, get_2d_sincos_pos_embed from diffusion.model.nets.sana_blocks import ( Attention, @@ -38,12 +33,13 @@ T2IFinalLayer, modulate, ) -from diffusion.model.utils import auto_grad_checkpoint, to_2tuple +from diffusion.model.utils import auto_grad_checkpoint from diffusion.utils.import_utils import is_triton_module_available _triton_modules_available = False if is_triton_module_available(): - from diffusion.model.nets.fastlinear.modules import TritonLiteMLA, TritonMBConvPreGLU + from diffusion.model.nets.fastlinear.modules import TritonLiteMLA + _triton_modules_available = True @@ -84,7 +80,9 @@ def __init__( self.attn = LiteLA(hidden_size, hidden_size, heads=self_num_heads, eps=1e-8, qk_norm=qk_norm) elif attn_type == "triton_linear": if not _triton_modules_available: - raise ValueError(f"{attn_type} type is not available due to _triton_modules_available={_triton_modules_available}.") + raise ValueError( + f"{attn_type} type is not available due to _triton_modules_available={_triton_modules_available}." 
+ ) # linear self attention with triton kernel fusion self_num_heads = hidden_size // 32 self.attn = TritonLiteMLA(hidden_size, num_heads=self_num_heads, eps=1e-8) diff --git a/diffusion/model/nets/sana_others.py b/diffusion/model/nets/sana_others.py index 74caf1b..4faacca 100644 --- a/diffusion/model/nets/sana_others.py +++ b/diffusion/model/nets/sana_others.py @@ -20,13 +20,14 @@ from timm.models.layers import DropPath from diffusion.model.nets.basic_modules import DWMlp, MBConvPreGLU, Mlp -from diffusion.model.nets.fastlinear.modules import TritonLiteMLA -try: - from diffusion.model.nets.fastlinear.modules import TritonLiteMLA -except ImportError: - import warnings - warnings.warn("TritonLiteMLA with `triton` is not available on your platform.") from diffusion.model.nets.sana_blocks import Attention, FlashAttention, MultiHeadCrossAttention, t2i_modulate +from diffusion.utils.import_utils import is_triton_module_available + +_triton_modules_available = False +if is_triton_module_available(): + from diffusion.model.nets.fastlinear.modules import TritonLiteMLA + + _triton_modules_available = True class SanaMSPABlock(nn.Module): diff --git a/diffusion/utils/import_utils.py b/diffusion/utils/import_utils.py index 7c011ee..6d0466f 100644 --- a/diffusion/utils/import_utils.py +++ b/diffusion/utils/import_utils.py @@ -1,9 +1,10 @@ import importlib.util -import importlib_metadata -from packaging import version import logging import warnings +import importlib_metadata +from packaging import version + logger = logging.getLogger(__name__) _xformers_available = importlib.util.find_spec("xformers") is not None @@ -28,8 +29,10 @@ _triton_modules_available = False warnings.warn("TritonLiteMLA and TritonMBConvPreGLU with `triton` is not available on your platform.") + def is_xformers_available(): return _xformers_available + def is_triton_module_available(): return _triton_modules_available
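
Note (not part of the patches above): a minimal, illustrative sketch of the pattern these commits converge on — probe the optional dependency once at import time (as diffusion/utils/import_utils.py does) and fall back to F.scaled_dot_product_attention when xformers is missing (as sana_blocks.py does). The helper name attention_with_fallback and the __main__ smoke test below are hypothetical stand-ins for illustration only; the torch and xformers calls are the only real APIs assumed.

import importlib.util

import torch
import torch.nn.functional as F

# Probe the optional dependency once at import time, mirroring import_utils.py.
_xformers_available = importlib.util.find_spec("xformers") is not None
if _xformers_available:
    import xformers.ops


def attention_with_fallback(q, k, v):
    # q, k, v are (B, N, num_heads, head_dim), the layout produced by the
    # .view(...) calls in MultiHeadCrossAttention.
    if _xformers_available:
        # xformers' memory_efficient_attention consumes the (B, N, H, D) layout directly.
        return xformers.ops.memory_efficient_attention(q, k, v)
    # F.scaled_dot_product_attention expects (B, H, N, D), so transpose around it.
    q, k, v = (t.transpose(1, 2) for t in (q, k, v))
    out = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=False)
    return out.transpose(1, 2)


if __name__ == "__main__":
    q = torch.randn(2, 16, 4, 32)
    k = torch.randn(2, 16, 4, 32)
    v = torch.randn(2, 16, 4, 32)
    print(attention_with_fallback(q, k, v).shape)  # torch.Size([2, 16, 4, 32])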