diff --git a/diffusion/model/__init__.py b/diffusion/model/__init__.py
index 5a0d275..e69de29 100755
--- a/diffusion/model/__init__.py
+++ b/diffusion/model/__init__.py
@@ -1 +0,0 @@
-from .nets import *
diff --git a/diffusion/model/nets/__init__.py b/diffusion/model/nets/__init__.py
index 375153d..a7ac826 100755
--- a/diffusion/model/nets/__init__.py
+++ b/diffusion/model/nets/__init__.py
@@ -1,5 +1,43 @@
-from .sana import *
-from .sana_multi_scale import *
-from .sana_multi_scale_adaln import *
-from .sana_U_shape import *
-from .sana_U_shape_multi_scale import *
+from .sana import (
+    Sana,
+    SanaBlock,
+    get_1d_sincos_pos_embed_from_grid,
+    get_2d_sincos_pos_embed,
+    get_2d_sincos_pos_embed_from_grid,
+)
+from .sana_multi_scale import (
+    SanaMS,
+    SanaMS_600M_P1_D28,
+    SanaMS_600M_P2_D28,
+    SanaMS_600M_P4_D28,
+    SanaMS_1600M_P1_D20,
+    SanaMS_1600M_P2_D20,
+    SanaMSBlock,
+)
+from .sana_multi_scale_adaln import (
+    SanaMSAdaLN,
+    SanaMSAdaLN_600M_P1_D28,
+    SanaMSAdaLN_600M_P2_D28,
+    SanaMSAdaLN_600M_P4_D28,
+    SanaMSAdaLN_1600M_P1_D20,
+    SanaMSAdaLN_1600M_P2_D20,
+    SanaMSAdaLNBlock,
+)
+from .sana_U_shape import (
+    SanaU,
+    SanaU_600M_P1_D28,
+    SanaU_600M_P2_D28,
+    SanaU_600M_P4_D28,
+    SanaU_1600M_P1_D20,
+    SanaU_1600M_P2_D20,
+    SanaUBlock,
+)
+from .sana_U_shape_multi_scale import (
+    SanaUMS,
+    SanaUMS_600M_P1_D28,
+    SanaUMS_600M_P2_D28,
+    SanaUMS_600M_P4_D28,
+    SanaUMS_1600M_P1_D20,
+    SanaUMS_1600M_P2_D20,
+    SanaUMSBlock,
+)
diff --git a/diffusion/model/nets/sana.py b/diffusion/model/nets/sana.py
index 76c318e..fc236f0 100755
--- a/diffusion/model/nets/sana.py
+++ b/diffusion/model/nets/sana.py
@@ -24,7 +24,6 @@
 
 from diffusion.model.builder import MODELS
 from diffusion.model.nets.basic_modules import DWMlp, GLUMBConv, MBConvPreGLU, Mlp
-from diffusion.model.nets.fastlinear.modules import TritonLiteMLA, TritonMBConvPreGLU
 from diffusion.model.nets.sana_blocks import (
     Attention,
     CaptionEmbedder,
@@ -39,8 +38,15 @@
 from diffusion.model.norms import RMSNorm
 from diffusion.model.utils import auto_grad_checkpoint, to_2tuple
 from diffusion.utils.dist_utils import get_rank
+from diffusion.utils.import_utils import is_triton_module_available
 from diffusion.utils.logger import get_root_logger
 
+_triton_modules_available = False
+if is_triton_module_available():
+    from diffusion.model.nets.fastlinear.modules import TritonLiteMLA, TritonMBConvPreGLU
+
+    _triton_modules_available = True
+
 
 class SanaBlock(nn.Module):
     """
@@ -78,6 +84,10 @@ def __init__(
             self_num_heads = hidden_size // linear_head_dim
             self.attn = LiteLA(hidden_size, hidden_size, heads=self_num_heads, eps=1e-8, qk_norm=qk_norm)
         elif attn_type == "triton_linear":
+            if not _triton_modules_available:
+                raise ValueError(
+                    f"{attn_type} type is not available due to _triton_modules_available={_triton_modules_available}."
+                )
             # linear self attention with triton kernel fusion
             # TODO: Here the num_heads set to 36 for tmp used
             self_num_heads = hidden_size // linear_head_dim
@@ -123,6 +133,10 @@ def __init__(
                 act=("silu", "silu", None),
             )
         elif ffn_type == "triton_mbconvpreglu":
+            if not _triton_modules_available:
+                raise ValueError(
+                    f"{ffn_type} type is not available due to _triton_modules_available={_triton_modules_available}."
+                )
             self.mlp = TritonMBConvPreGLU(
                 in_dim=hidden_size,
                 out_dim=hidden_size,
diff --git a/diffusion/model/nets/sana_U_shape.py b/diffusion/model/nets/sana_U_shape.py
index d16b895..35f5cbf 100644
--- a/diffusion/model/nets/sana_U_shape.py
+++ b/diffusion/model/nets/sana_U_shape.py
@@ -23,7 +23,6 @@
 
 from diffusion.model.builder import MODELS
 from diffusion.model.nets.basic_modules import DWMlp, GLUMBConv, MBConvPreGLU, Mlp
-from diffusion.model.nets.fastlinear.modules import TritonLiteMLA
 from diffusion.model.nets.sana import Sana, get_2d_sincos_pos_embed
 from diffusion.model.nets.sana_blocks import (
     Attention,
@@ -37,9 +36,16 @@
     t2i_modulate,
 )
 from diffusion.model.norms import RMSNorm
-from diffusion.model.utils import auto_grad_checkpoint, to_2tuple
+from diffusion.model.utils import auto_grad_checkpoint
+from diffusion.utils.import_utils import is_triton_module_available
 from diffusion.utils.logger import get_root_logger
 
+_triton_modules_available = False
+if is_triton_module_available():
+    from diffusion.model.nets.fastlinear.modules import TritonLiteMLA
+
+    _triton_modules_available = True
+
 
 class SanaUBlock(nn.Module):
     """
@@ -77,6 +83,10 @@ def __init__(
             self_num_heads = hidden_size // 32
             self.attn = LiteLA(hidden_size, hidden_size, heads=self_num_heads, eps=1e-8, qk_norm=qk_norm)
         elif attn_type == "triton_linear":
+            if not _triton_modules_available:
+                raise ValueError(
+                    f"{attn_type} type is not available due to _triton_modules_available={_triton_modules_available}."
+                )
             # linear self attention with triton kernel fusion
             # TODO: Here the num_heads set to 36 for tmp used
             self_num_heads = hidden_size // 32
@@ -343,27 +353,27 @@ def dtype(self):
 #                                   SanaU Configs                                  #
 #################################################################################
 @MODELS.register_module()
-def SanaMSU_600M_P1_D28(**kwargs):
+def SanaU_600M_P1_D28(**kwargs):
     return SanaU(depth=28, hidden_size=1152, patch_size=1, num_heads=16, **kwargs)
 
 
 @MODELS.register_module()
-def SanaMSU_600M_P2_D28(**kwargs):
+def SanaU_600M_P2_D28(**kwargs):
     return SanaU(depth=28, hidden_size=1152, patch_size=2, num_heads=16, **kwargs)
 
 
 @MODELS.register_module()
-def SanaMSU_600M_P4_D28(**kwargs):
+def SanaU_600M_P4_D28(**kwargs):
     return SanaU(depth=28, hidden_size=1152, patch_size=4, num_heads=16, **kwargs)
 
 
 @MODELS.register_module()
-def SanaMSU_P1_D20(**kwargs):
+def SanaU_1600M_P1_D20(**kwargs):
     # 20 layers, 1648.48M
     return SanaU(depth=20, hidden_size=2240, patch_size=1, num_heads=20, **kwargs)
 
 
 @MODELS.register_module()
-def SanaMSU_P2_D20(**kwargs):
+def SanaU_1600M_P2_D20(**kwargs):
     # 28 layers, 1648.48M
     return SanaU(depth=20, hidden_size=2240, patch_size=2, num_heads=20, **kwargs)
diff --git a/diffusion/model/nets/sana_U_shape_multi_scale.py b/diffusion/model/nets/sana_U_shape_multi_scale.py
index 927fc1b..13dfd4a 100644
--- a/diffusion/model/nets/sana_U_shape_multi_scale.py
+++ b/diffusion/model/nets/sana_U_shape_multi_scale.py
@@ -21,7 +21,6 @@
 
 from diffusion.model.builder import MODELS
 from diffusion.model.nets.basic_modules import DWMlp, GLUMBConv, MBConvPreGLU, Mlp
-from diffusion.model.nets.fastlinear.modules import TritonLiteMLA, TritonLiteMLAFwd
 from diffusion.model.nets.sana import Sana, get_2d_sincos_pos_embed
 from diffusion.model.nets.sana_blocks import (
     Attention,
@@ -30,11 +29,17 @@
     LiteLA,
     MultiHeadCrossAttention,
     PatchEmbedMS,
-    SizeEmbedder,
     T2IFinalLayer,
     t2i_modulate,
 )
-from diffusion.model.utils import auto_grad_checkpoint, to_2tuple
+from diffusion.model.utils import auto_grad_checkpoint
+from diffusion.utils.import_utils import is_triton_module_available
+
+_triton_modules_available = False
+if is_triton_module_available():
+    from diffusion.model.nets.fastlinear.modules import TritonLiteMLA
+
+    _triton_modules_available = True
 
 
 class SanaUMSBlock(nn.Module):
@@ -74,6 +79,10 @@ def __init__(
             self_num_heads = hidden_size // 32
             self.attn = LiteLA(hidden_size, hidden_size, heads=self_num_heads, eps=1e-8, qk_norm=qk_norm)
         elif attn_type == "triton_linear":
+            if not _triton_modules_available:
+                raise ValueError(
+                    f"{attn_type} type is not available due to _triton_modules_available={_triton_modules_available}."
+                )
             # linear self attention with triton kernel fusion
             self_num_heads = hidden_size // 32
             self.attn = TritonLiteMLA(hidden_size, num_heads=self_num_heads, eps=1e-8)
@@ -365,12 +374,12 @@ def SanaUMS_600M_P4_D28(**kwargs):
 
 
 @MODELS.register_module()
-def SanaUMS_P1_D20(**kwargs):
+def SanaUMS_1600M_P1_D20(**kwargs):
     # 20 layers, 1648.48M
     return SanaUMS(depth=20, hidden_size=2240, patch_size=1, num_heads=20, **kwargs)
 
 
 @MODELS.register_module()
-def SanaUMS_P2_D20(**kwargs):
+def SanaUMS_1600M_P2_D20(**kwargs):
     # 28 layers, 1648.48M
     return SanaUMS(depth=20, hidden_size=2240, patch_size=2, num_heads=20, **kwargs)
diff --git a/diffusion/model/nets/sana_blocks.py b/diffusion/model/nets/sana_blocks.py
index b5e5597..b50e64b 100755
--- a/diffusion/model/nets/sana_blocks.py
+++ b/diffusion/model/nets/sana_blocks.py
@@ -22,7 +22,6 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-import xformers.ops
 from einops import rearrange
 from timm.models.vision_transformer import Attention as Attention_
 from timm.models.vision_transformer import Mlp
@@ -30,6 +29,13 @@
 
 from diffusion.model.norms import RMSNorm
 from diffusion.model.utils import get_same_padding, to_2tuple
+from diffusion.utils.import_utils import is_xformers_available
+
+_xformers_available = False
+if is_xformers_available():
+    import xformers.ops
+
+    _xformers_available = True
 
 
 def modulate(x, shift, scale):
@@ -65,17 +71,27 @@ def __init__(self, d_model, num_heads, attn_drop=0.0, proj_drop=0.0, qk_norm=Fal
     def forward(self, x, cond, mask=None):
         # query: img tokens; key/value: condition; mask: if padding tokens
         B, N, C = x.shape
+        first_dim = 1 if _xformers_available else B
 
         q = self.q_linear(x)
-        kv = self.kv_linear(cond).view(1, -1, 2, C)
+        kv = self.kv_linear(cond).view(first_dim, -1, 2, C)
         k, v = kv.unbind(2)
-        q = self.q_norm(q).view(1, -1, self.num_heads, self.head_dim)
-        k = self.k_norm(k).view(1, -1, self.num_heads, self.head_dim)
-        v = v.view(1, -1, self.num_heads, self.head_dim)
-        attn_bias = None
-        if mask is not None:
-            attn_bias = xformers.ops.fmha.BlockDiagonalMask.from_seqlens([N] * B, mask)
-        x = xformers.ops.memory_efficient_attention(q, k, v, p=self.attn_drop.p, attn_bias=attn_bias)
+        q = self.q_norm(q).view(first_dim, -1, self.num_heads, self.head_dim)
+        k = self.k_norm(k).view(first_dim, -1, self.num_heads, self.head_dim)
+        v = v.view(first_dim, -1, self.num_heads, self.head_dim)
+
+        if _xformers_available:
+            attn_bias = None
+            if mask is not None:
+                attn_bias = xformers.ops.fmha.BlockDiagonalMask.from_seqlens([N] * B, mask)
+            x = xformers.ops.memory_efficient_attention(q, k, v, p=self.attn_drop.p, attn_bias=attn_bias)
+        else:
+            q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
+            if mask is not None and mask.ndim == 2:
+                mask = (1 - mask.to(x.dtype)) * -10000.0
+                mask = mask[:, None, None].repeat(1, self.num_heads, 1, 1)
+            x = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
+            x = x.transpose(1, 2)
 
         x = x.view(B, -1, C)
         x = self.proj(x)
@@ -347,7 +363,15 @@ def forward(self, x, mask=None, HW=None, block_id=None):
             attn_bias = torch.zeros([B * self.num_heads, q.shape[1], k.shape[1]], dtype=q.dtype, device=q.device)
             attn_bias.masked_fill_(mask.squeeze(1).repeat(self.num_heads, 1, 1) == 0, float("-inf"))
 
-        x = xformers.ops.memory_efficient_attention(q, k, v, p=self.attn_drop.p, attn_bias=attn_bias)
+        if _xformers_available:
+            x = xformers.ops.memory_efficient_attention(q, k, v, p=self.attn_drop.p, attn_bias=attn_bias)
+        else:
+            q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
+            if mask is not None and mask.ndim == 2:
+                mask = (1 - mask.to(x.dtype)) * -10000.0
+                mask = mask[:, None, None].repeat(1, self.num_heads, 1, 1)
+            x = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
+            x = x.transpose(1, 2)
 
         x = x.view(B, N, C)
         x = self.proj(x)
diff --git a/diffusion/model/nets/sana_multi_scale.py b/diffusion/model/nets/sana_multi_scale.py
index 078d82b..b79d230 100755
--- a/diffusion/model/nets/sana_multi_scale.py
+++ b/diffusion/model/nets/sana_multi_scale.py
@@ -21,7 +21,6 @@
 
 from diffusion.model.builder import MODELS
 from diffusion.model.nets.basic_modules import DWMlp, GLUMBConv, MBConvPreGLU, Mlp
-from diffusion.model.nets.fastlinear.modules import TritonLiteMLA, TritonMBConvPreGLU
 from diffusion.model.nets.sana import Sana, get_2d_sincos_pos_embed
 from diffusion.model.nets.sana_blocks import (
     Attention,
@@ -34,6 +33,17 @@
     t2i_modulate,
 )
 from diffusion.model.utils import auto_grad_checkpoint
+from diffusion.utils.import_utils import is_triton_module_available, is_xformers_available
+
+_triton_modules_available = False
+if is_triton_module_available():
+    from diffusion.model.nets.fastlinear.modules import TritonLiteMLA, TritonMBConvPreGLU
+
+    _triton_modules_available = True
+
+_xformers_available = False
+if is_xformers_available():
+    _xformers_available = True
 
 
 class SanaMSBlock(nn.Module):
@@ -74,6 +84,10 @@ def __init__(
             self_num_heads = hidden_size // linear_head_dim
             self.attn = LiteLA(hidden_size, hidden_size, heads=self_num_heads, eps=1e-8, qk_norm=qk_norm)
         elif attn_type == "triton_linear":
+            if not _triton_modules_available:
+                raise ValueError(
+                    f"{attn_type} type is not available due to _triton_modules_available={_triton_modules_available}."
+                )
             # linear self attention with triton kernel fusion
             self_num_heads = hidden_size // linear_head_dim
             self.attn = TritonLiteMLA(hidden_size, num_heads=self_num_heads, eps=1e-8)
@@ -108,6 +122,10 @@ def __init__(
                 dilation=2,
             )
         elif ffn_type == "triton_mbconvpreglu":
+            if not _triton_modules_available:
+                raise ValueError(
+                    f"{ffn_type} type is not available due to _triton_modules_available={_triton_modules_available}."
+                )
             self.mlp = TritonMBConvPreGLU(
                 in_dim=hidden_size,
                 out_dim=hidden_size,
@@ -292,14 +310,19 @@ def forward(self, x, timestep, y, mask=None, data_info=None, **kwargs):
             y = self.attention_y_norm(y)
 
         if mask is not None:
-            if mask.shape[0] != y.shape[0]:
-                mask = mask.repeat(y.shape[0] // mask.shape[0], 1)
+            mask = mask.repeat(y.shape[0] // mask.shape[0], 1) if mask.shape[0] != y.shape[0] else mask
             mask = mask.squeeze(1).squeeze(1)
-            y = y.squeeze(1).masked_select(mask.unsqueeze(-1) != 0).view(1, -1, x.shape[-1])
-            y_lens = mask.sum(dim=1).tolist()
-        else:
+            if _xformers_available:
+                y = y.squeeze(1).masked_select(mask.unsqueeze(-1) != 0).view(1, -1, x.shape[-1])
+                y_lens = mask.sum(dim=1).tolist()
+            else:
+                y_lens = mask
+        elif _xformers_available:
             y_lens = [y.shape[2]] * y.shape[0]
             y = y.squeeze(1).view(1, -1, x.shape[-1])
+        else:
+            raise ValueError(f"Attention type is not available due to _xformers_available={_xformers_available}.")
+
         for block in self.blocks:
             x = auto_grad_checkpoint(
                 block, x, y, t0, y_lens, (self.h, self.w), **kwargs
diff --git a/diffusion/model/nets/sana_multi_scale_adaln.py b/diffusion/model/nets/sana_multi_scale_adaln.py
index 8c2ce8a..7564740 100644
--- a/diffusion/model/nets/sana_multi_scale_adaln.py
+++ b/diffusion/model/nets/sana_multi_scale_adaln.py
@@ -21,7 +21,6 @@
 
 from diffusion.model.builder import MODELS
 from diffusion.model.nets.basic_modules import DWMlp, GLUMBConv, MBConvPreGLU, Mlp
-from diffusion.model.nets.fastlinear.modules import TritonLiteMLA, TritonLiteMLAFwd
 from diffusion.model.nets.sana import Sana, get_2d_sincos_pos_embed
 from diffusion.model.nets.sana_blocks import (
     Attention,
@@ -34,7 +33,14 @@
     T2IFinalLayer,
     modulate,
 )
-from diffusion.model.utils import auto_grad_checkpoint, to_2tuple
+from diffusion.model.utils import auto_grad_checkpoint
+from diffusion.utils.import_utils import is_triton_module_available
+
+_triton_modules_available = False
+if is_triton_module_available():
+    from diffusion.model.nets.fastlinear.modules import TritonLiteMLA
+
+    _triton_modules_available = True
 
 
 class SanaMSAdaLNBlock(nn.Module):
@@ -73,6 +79,10 @@ def __init__(
             self_num_heads = hidden_size // 32
             self.attn = LiteLA(hidden_size, hidden_size, heads=self_num_heads, eps=1e-8, qk_norm=qk_norm)
         elif attn_type == "triton_linear":
+            if not _triton_modules_available:
+                raise ValueError(
+                    f"{attn_type} type is not available due to _triton_modules_available={_triton_modules_available}."
+                )
             # linear self attention with triton kernel fusion
             self_num_heads = hidden_size // 32
             self.attn = TritonLiteMLA(hidden_size, num_heads=self_num_heads, eps=1e-8)
@@ -371,12 +381,12 @@ def SanaMSAdaLN_600M_P4_D28(**kwargs):
 
 
 @MODELS.register_module()
-def SanaMSAdaLN_P1_D20(**kwargs):
+def SanaMSAdaLN_1600M_P1_D20(**kwargs):
     # 20 layers, 1648.48M
     return SanaMSAdaLN(depth=20, hidden_size=2240, patch_size=1, num_heads=20, **kwargs)
 
 
 @MODELS.register_module()
-def SanaMSAdaLN_P2_D20(**kwargs):
+def SanaMSAdaLN_1600M_P2_D20(**kwargs):
     # 28 layers, 1648.48M
     return SanaMSAdaLN(depth=20, hidden_size=2240, patch_size=2, num_heads=20, **kwargs)
diff --git a/diffusion/model/nets/sana_others.py b/diffusion/model/nets/sana_others.py
index 519a011..4faacca 100644
--- a/diffusion/model/nets/sana_others.py
+++ b/diffusion/model/nets/sana_others.py
@@ -20,8 +20,14 @@
 from timm.models.layers import DropPath
 
 from diffusion.model.nets.basic_modules import DWMlp, MBConvPreGLU, Mlp
-from diffusion.model.nets.fastlinear.modules import TritonLiteMLA
 from diffusion.model.nets.sana_blocks import Attention, FlashAttention, MultiHeadCrossAttention, t2i_modulate
+from diffusion.utils.import_utils import is_triton_module_available
+
+_triton_modules_available = False
+if is_triton_module_available():
+    from diffusion.model.nets.fastlinear.modules import TritonLiteMLA
+
+    _triton_modules_available = True
 
 
 class SanaMSPABlock(nn.Module):
diff --git a/diffusion/utils/import_utils.py b/diffusion/utils/import_utils.py
new file mode 100644
index 0000000..6d0466f
--- /dev/null
+++ b/diffusion/utils/import_utils.py
@@ -0,0 +1,38 @@
+import importlib.util
+import logging
+import warnings
+
+import importlib_metadata
+from packaging import version
+
+logger = logging.getLogger(__name__)
+
+_xformers_available = importlib.util.find_spec("xformers") is not None
+try:
+    if _xformers_available:
+        _xformers_version = importlib_metadata.version("xformers")
+        _torch_version = importlib_metadata.version("torch")
+        if version.Version(_torch_version) < version.Version("1.12"):
+            raise ValueError("xformers is installed but requires PyTorch >= 1.12")
+        logger.debug(f"Successfully imported xformers version {_xformers_version}")
+except importlib_metadata.PackageNotFoundError:
+    _xformers_available = False
+
+_triton_modules_available = importlib.util.find_spec("triton") is not None
+try:
+    if _triton_modules_available:
+        _triton_version = importlib_metadata.version("triton")
+        if version.Version(_triton_version) < version.Version("3.0.0"):
+            raise ValueError("triton is installed but requires Triton >= 3.0.0")
+        logger.debug(f"Successfully imported triton version {_triton_version}")
+except ImportError:
+    _triton_modules_available = False
+    warnings.warn("TritonLiteMLA and TritonMBConvPreGLU with `triton` is not available on your platform.")
+
+
+def is_xformers_available():
+    return _xformers_available
+
+
+def is_triton_module_available():
+    return _triton_modules_available
diff --git a/scripts/inference.py b/scripts/inference.py
index 1c9a24e..52d21fd 100755
--- a/scripts/inference.py
+++ b/scripts/inference.py
@@ -35,7 +35,7 @@
 warnings.filterwarnings("ignore")  # ignore warning
 
 from diffusion import DPMS, FlowEuler, SASolverSampler
-from diffusion.data.datasets.utils import *
+from diffusion.data.datasets.utils import ASPECT_RATIO_512_TEST, ASPECT_RATIO_1024_TEST, ASPECT_RATIO_2048_TEST
 from diffusion.model.builder import build_model, get_tokenizer_and_text_encoder, get_vae, vae_decode
 from diffusion.model.utils import prepare_prompt_ar
 from diffusion.utils.config import SanaConfig
diff --git a/scripts/inference_dpg.py b/scripts/inference_dpg.py
index 5f6805a..bde68c9 100644
--- a/scripts/inference_dpg.py
+++ b/scripts/inference_dpg.py
@@ -34,7 +34,7 @@
 warnings.filterwarnings("ignore")  # ignore warning
 
 from diffusion import DPMS, FlowEuler, SASolverSampler
-from diffusion.data.datasets.utils import *
+from diffusion.data.datasets.utils import ASPECT_RATIO_512_TEST, ASPECT_RATIO_1024_TEST, ASPECT_RATIO_2048_TEST
 from diffusion.model.builder import build_model, get_tokenizer_and_text_encoder, get_vae, vae_decode
 from diffusion.model.utils import prepare_prompt_ar
 from diffusion.utils.config import SanaConfig
diff --git a/scripts/inference_geneval.py b/scripts/inference_geneval.py
index 05cb920..07cef9f 100644
--- a/scripts/inference_geneval.py
+++ b/scripts/inference_geneval.py
@@ -36,7 +36,7 @@
 warnings.filterwarnings("ignore")  # ignore warning
 
 from diffusion import DPMS, FlowEuler, SASolverSampler
-from diffusion.data.datasets.utils import *
+from diffusion.data.datasets.utils import ASPECT_RATIO_512_TEST, ASPECT_RATIO_1024_TEST, ASPECT_RATIO_2048_TEST
 from diffusion.model.builder import build_model, get_tokenizer_and_text_encoder, get_vae, vae_decode
 from diffusion.model.utils import prepare_prompt_ar
 from diffusion.utils.config import SanaConfig
diff --git a/scripts/inference_image_reward.py b/scripts/inference_image_reward.py
index 972541c..fb3e1a5 100644
--- a/scripts/inference_image_reward.py
+++ b/scripts/inference_image_reward.py
@@ -33,7 +33,7 @@
 from tqdm import tqdm
 
 from diffusion import DPMS, FlowEuler, SASolverSampler
-from diffusion.data.datasets.utils import *
+from diffusion.data.datasets.utils import ASPECT_RATIO_512_TEST, ASPECT_RATIO_1024_TEST, ASPECT_RATIO_2048_TEST
 from diffusion.model.builder import build_model, get_tokenizer_and_text_encoder, get_vae, vae_decode
 from diffusion.model.utils import prepare_prompt_ar
 from diffusion.utils.config import SanaConfig
diff --git a/scripts/interface.py b/scripts/interface.py
index e64c0f9..d81199c 100755
--- a/scripts/interface.py
+++ b/scripts/interface.py
@@ -34,7 +34,7 @@
 from asset.examples import examples
 
 from diffusion import DPMS, FlowEuler, SASolverSampler
-from diffusion.data.datasets.utils import *
+from diffusion.data.datasets.utils import ASPECT_RATIO_512_TEST, ASPECT_RATIO_1024_TEST, ASPECT_RATIO_2048_TEST
 from diffusion.model.builder import build_model, get_tokenizer_and_text_encoder, get_vae, vae_decode
 from diffusion.model.utils import prepare_prompt_ar, resize_and_crop_tensor
 from diffusion.utils.config import SanaConfig