diff --git a/docs/source/en/api/attnprocessor.md b/docs/source/en/api/attnprocessor.md
index ab89d4d260f0..f586e9b08f2c 100644
--- a/docs/source/en/api/attnprocessor.md
+++ b/docs/source/en/api/attnprocessor.md
@@ -55,3 +55,6 @@ An attention processor is a class for applying different types of attention mech
 
 ## XFormersAttnProcessor
 [[autodoc]] models.attention_processor.XFormersAttnProcessor
+
+## AttnProcessorNPU
+[[autodoc]] models.attention_processor.AttnProcessorNPU
diff --git a/examples/controlnet/train_controlnet_sdxl.py b/examples/controlnet/train_controlnet_sdxl.py
index 62192521a323..288a1e3fb612 100644
--- a/examples/controlnet/train_controlnet_sdxl.py
+++ b/examples/controlnet/train_controlnet_sdxl.py
@@ -32,7 +32,7 @@
 import transformers
 from accelerate import Accelerator
 from accelerate.logging import get_logger
-from accelerate.utils import ProjectConfiguration, set_seed
+from accelerate.utils import DistributedType, ProjectConfiguration, set_seed
 from datasets import load_dataset
 from huggingface_hub import create_repo, upload_folder
 from packaging import version
@@ -53,7 +53,7 @@
 from diffusers.optimization import get_scheduler
 from diffusers.utils import check_min_version, is_wandb_available, make_image_grid
 from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
-from diffusers.utils.import_utils import is_xformers_available
+from diffusers.utils.import_utils import is_torch_npu_available, is_xformers_available
 from diffusers.utils.torch_utils import is_compiled_module
 
 
@@ -64,6 +64,8 @@
 check_min_version("0.28.0.dev0")
 
 logger = get_logger(__name__)
+if is_torch_npu_available():
+    torch.npu.config.allow_internal_format = False
 
 
 def log_validation(vae, unet, controlnet, args, accelerator, weight_dtype, step, is_final_validation=False):
@@ -471,6 +473,9 @@ def parse_args(input_args=None):
     parser.add_argument(
         "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
     )
+    parser.add_argument(
+        "--enable_npu_flash_attention", action="store_true", help="Whether or not to use npu flash attention."
+    )
     parser.add_argument(
         "--set_grads_to_none",
         action="store_true",
@@ -936,6 +941,13 @@ def load_model_hook(models, input_dir):
     text_encoder_two.requires_grad_(False)
     controlnet.train()
 
+    if args.enable_npu_flash_attention:
+        if is_torch_npu_available():
+            logger.info("npu flash attention enabled.")
+            unet.enable_npu_flash_attention()
+        else:
+            raise ValueError("npu flash attention requires torch_npu extensions and is supported only on npu devices.")
+
     if args.enable_xformers_memory_efficient_attention:
         if is_xformers_available():
             import xformers
@@ -1235,7 +1247,8 @@ def compute_embeddings(batch, proportion_empty_prompts, text_encoders, tokenizer
                 progress_bar.update(1)
                 global_step += 1
 
-                if accelerator.is_main_process:
+                # DeepSpeed requires saving weights on every device; saving weights only on the main process would cause issues.
+                if accelerator.distributed_type == DistributedType.DEEPSPEED or accelerator.is_main_process:
                     if global_step % args.checkpointing_steps == 0:
                         # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
                         if args.checkpoints_total_limit is not None:
diff --git a/examples/text_to_image/train_text_to_image_lora_sdxl.py b/examples/text_to_image/train_text_to_image_lora_sdxl.py
index 0a6a70de2dc7..3604e755c62a 100644
--- a/examples/text_to_image/train_text_to_image_lora_sdxl.py
+++ b/examples/text_to_image/train_text_to_image_lora_sdxl.py
@@ -32,7 +32,7 @@
 import transformers
 from accelerate import Accelerator
 from accelerate.logging import get_logger
-from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
+from accelerate.utils import DistributedDataParallelKwargs, DistributedType, ProjectConfiguration, set_seed
 from datasets import load_dataset
 from huggingface_hub import create_repo, upload_folder
 from packaging import version
@@ -60,7 +60,7 @@
     is_wandb_available,
 )
 from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
-from diffusers.utils.import_utils import is_xformers_available
+from diffusers.utils.import_utils import is_torch_npu_available, is_xformers_available
 from diffusers.utils.torch_utils import is_compiled_module
 
 
@@ -68,6 +68,8 @@
 check_min_version("0.28.0.dev0")
 
 logger = get_logger(__name__)
+if is_torch_npu_available():
+    torch.npu.config.allow_internal_format = False
 
 
 def save_model_card(
@@ -419,6 +421,9 @@ def parse_args(input_args=None):
     parser.add_argument(
         "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
     )
+    parser.add_argument(
+        "--enable_npu_flash_attention", action="store_true", help="Whether or not to use npu flash attention."
+    )
     parser.add_argument("--noise_offset", type=float, default=0, help="The scale of noise offset.")
     parser.add_argument(
         "--rank",
@@ -623,6 +628,13 @@ def main(args):
     text_encoder_one.to(accelerator.device, dtype=weight_dtype)
     text_encoder_two.to(accelerator.device, dtype=weight_dtype)
 
+    if args.enable_npu_flash_attention:
+        if is_torch_npu_available():
+            logger.info("npu flash attention enabled.")
+            unet.enable_npu_flash_attention()
+        else:
+            raise ValueError("npu flash attention requires torch_npu extensions and is supported only on npu devices.")
+
     if args.enable_xformers_memory_efficient_attention:
         if is_xformers_available():
             import xformers
@@ -1149,7 +1161,8 @@ def compute_time_ids(original_size, crops_coords_top_left):
                 accelerator.log({"train_loss": train_loss}, step=global_step)
                 train_loss = 0.0
 
-                if accelerator.is_main_process:
+                # DeepSpeed requires saving weights on every device; saving weights only on the main process would cause issues.
+                if accelerator.distributed_type == DistributedType.DEEPSPEED or accelerator.is_main_process:
                     if global_step % args.checkpointing_steps == 0:
                         # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
                         if args.checkpoints_total_limit is not None:
diff --git a/src/diffusers/models/activations.py b/src/diffusers/models/activations.py
index cec83bdded9e..f94b6c8d6d06 100644
--- a/src/diffusers/models/activations.py
+++ b/src/diffusers/models/activations.py
@@ -18,8 +18,12 @@
 from torch import nn
 
 from ..utils import deprecate
+from ..utils.import_utils import is_torch_npu_available
 
 
+if is_torch_npu_available():
+    import torch_npu
+
 ACTIVATION_FUNCTIONS = {
     "swish": nn.SiLU(),
     "silu": nn.SiLU(),
@@ -98,9 +102,13 @@ def forward(self, hidden_states, *args, **kwargs):
         if len(args) > 0 or kwargs.get("scale", None) is not None:
             deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
             deprecate("scale", "1.0.0", deprecation_message)
-
-        hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1)
-        return hidden_states * self.gelu(gate)
+        hidden_states = self.proj(hidden_states)
+        if is_torch_npu_available():
+            # using torch_npu.npu_geglu can run faster and save memory on NPU.
+            return torch_npu.npu_geglu(hidden_states, dim=-1, approximate=1)[0]
+        else:
+            hidden_states, gate = hidden_states.chunk(2, dim=-1)
+            return hidden_states * self.gelu(gate)
 
 
 class ApproximateGELU(nn.Module):
diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py
index 429807989296..ea1c987e95c6 100644
--- a/src/diffusers/models/attention_processor.py
+++ b/src/diffusers/models/attention_processor.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import inspect
+import math
 from importlib import import_module
 from typing import Callable, List, Optional, Union
 
@@ -21,13 +22,15 @@
 
 from ..image_processor import IPAdapterMaskProcessor
 from ..utils import deprecate, logging
-from ..utils.import_utils import is_xformers_available
+from ..utils.import_utils import is_torch_npu_available, is_xformers_available
 from ..utils.torch_utils import maybe_allow_in_graph
 from .lora import LoRALinearLayer
 
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
+if is_torch_npu_available():
+    import torch_npu
 
 if is_xformers_available():
     import xformers
@@ -209,6 +212,23 @@ def __init__(
             )
         self.set_processor(processor)
 
+    def set_use_npu_flash_attention(self, use_npu_flash_attention: bool) -> None:
+        r"""
+        Set whether to use npu flash attention from `torch_npu` or not.
+
+        """
+        if use_npu_flash_attention:
+            processor = AttnProcessorNPU()
+        else:
+            # set attention processor
+            # We use the AttnProcessor2_0 by default when torch 2.x is used which uses
+            # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
+            # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
+            processor = (
+                AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor()
+            )
+        self.set_processor(processor)
+
     def set_use_memory_efficient_attention_xformers(
         self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
     ) -> None:
@@ -1207,6 +1227,116 @@ def __call__(
         return hidden_states
 
 
+class AttnProcessorNPU:
+
+    r"""
+    Processor for implementing flash attention using torch_npu. Torch_npu supports only fp16 and bf16 data types. If
+    fp32 is used, F.scaled_dot_product_attention will be used for computation, but the acceleration effect on NPU is
+    not significant.
+
+    """
+
+    def __init__(self):
+        if not is_torch_npu_available():
+            raise ImportError("AttnProcessorNPU requires torch_npu extensions and is supported only on npu devices.")
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.FloatTensor,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        temb: Optional[torch.FloatTensor] = None,
+        *args,
+        **kwargs,
+    ) -> torch.FloatTensor:
+        if len(args) > 0 or kwargs.get("scale", None) is not None:
+            deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
+            deprecate("scale", "1.0.0", deprecation_message)
+
+        residual = hidden_states
+        if attn.spatial_norm is not None:
+            hidden_states = attn.spatial_norm(hidden_states, temb)
+
+        input_ndim = hidden_states.ndim
+
+        if input_ndim == 4:
+            batch_size, channel, height, width = hidden_states.shape
+            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+        batch_size, sequence_length, _ = (
+            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+        )
+
+        if attention_mask is not None:
+            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+            # scaled_dot_product_attention expects attention_mask shape to be
+            # (batch, heads, source_length, target_length)
+            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+        if attn.group_norm is not None:
+            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+        query = attn.to_q(hidden_states)
+
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+        elif attn.norm_cross:
+            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+        key = attn.to_k(encoder_hidden_states)
+        value = attn.to_v(encoder_hidden_states)
+
+        inner_dim = key.shape[-1]
+        head_dim = inner_dim // attn.heads
+
+        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        # the output of sdp = (batch, num_heads, seq_len, head_dim)
+        if query.dtype in (torch.float16, torch.bfloat16):
+            hidden_states = torch_npu.npu_fusion_attention(
+                query,
+                key,
+                value,
+                attn.heads,
+                input_layout="BNSD",
+                pse=None,
+                atten_mask=attention_mask,
+                scale=1.0 / math.sqrt(query.shape[-1]),
+                pre_tockens=65536,
+                next_tockens=65536,
+                keep_prob=1.0,
+                sync=False,
+                inner_precise=0,
+            )[0]
+        else:
+            # TODO: add support for attn.scale when we move to Torch 2.1
+            hidden_states = F.scaled_dot_product_attention(
+                query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+            )
+
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+        hidden_states = hidden_states.to(query.dtype)
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+
+        if input_ndim == 4:
+            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+        if attn.residual_connection:
+            hidden_states = hidden_states + residual
+
+        hidden_states = hidden_states / attn.rescale_output_factor
+
+        return hidden_states
+
+
 class AttnProcessor2_0:
     r"""
     Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py
index a8518ca3ff7f..373f5453aa23 100644
--- a/src/diffusers/models/modeling_utils.py
+++ b/src/diffusers/models/modeling_utils.py
@@ -272,6 +272,36 @@ def disable_gradient_checkpointing(self) -> None:
         if self._supports_gradient_checkpointing:
             self.apply(partial(self._set_gradient_checkpointing, value=False))
 
+    def set_use_npu_flash_attention(self, valid: bool) -> None:
+        r"""
+        Set the switch for the npu flash attention.
+        """
+
+        def fn_recursive_set_npu_flash_attention(module: torch.nn.Module):
+            if hasattr(module, "set_use_npu_flash_attention"):
+                module.set_use_npu_flash_attention(valid)
+
+            for child in module.children():
+                fn_recursive_set_npu_flash_attention(child)
+
+        for module in self.children():
+            if isinstance(module, torch.nn.Module):
+                fn_recursive_set_npu_flash_attention(module)
+
+    def enable_npu_flash_attention(self) -> None:
+        r"""
+        Enable npu flash attention from torch_npu.
+
+        """
+        self.set_use_npu_flash_attention(True)
+
+    def disable_npu_flash_attention(self) -> None:
+        r"""
+        Disable npu flash attention from torch_npu.
+
+        """
+        self.set_use_npu_flash_attention(False)
+
     def set_use_memory_efficient_attention_xformers(
         self, valid: bool, attention_op: Optional[Callable] = None
     ) -> None:
diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py
index d9e70c6dd784..59369b509876 100644
--- a/tests/models/test_modeling_common.py
+++ b/tests/models/test_modeling_common.py
@@ -30,9 +30,14 @@
 from requests.exceptions import HTTPError
 
 from diffusers.models import UNet2DConditionModel
-from diffusers.models.attention_processor import AttnProcessor, AttnProcessor2_0, XFormersAttnProcessor
+from diffusers.models.attention_processor import (
+    AttnProcessor,
+    AttnProcessor2_0,
+    AttnProcessorNPU,
+    XFormersAttnProcessor,
+)
 from diffusers.training_utils import EMAModel
-from diffusers.utils import is_xformers_available, logging
+from diffusers.utils import is_torch_npu_available, is_xformers_available, logging
 from diffusers.utils.testing_utils import (
     CaptureLogger,
     get_python_version,
@@ -300,6 +305,53 @@ def test_getattr_is_correct(self):
 
         assert str(error.exception) == f"'{type(model).__name__}' object has no attribute 'does_not_exist'"
 
+    @unittest.skipIf(
+        torch_device != "npu" or not is_torch_npu_available(),
+        reason="torch npu flash attention is only available with NPU and `torch_npu` installed",
+    )
+    def test_set_torch_npu_flash_attn_processor_determinism(self):
+        torch.use_deterministic_algorithms(False)
+        if self.forward_requires_fresh_args:
+            model = self.model_class(**self.init_dict)
+        else:
+            init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+            model = self.model_class(**init_dict)
+        model.to(torch_device)
+
+        if not hasattr(model, "set_attn_processor"):
+            # If the model does not have `set_attn_processor`, skip the test.
+            return
+
+        model.set_default_attn_processor()
+        assert all(type(proc) == AttnProcessor for proc in model.attn_processors.values())
+        with torch.no_grad():
+            if self.forward_requires_fresh_args:
+                output = model(**self.inputs_dict(0))[0]
+            else:
+                output = model(**inputs_dict)[0]
+
+        model.enable_npu_flash_attention()
+        assert all(type(proc) == AttnProcessorNPU for proc in model.attn_processors.values())
+        with torch.no_grad():
+            if self.forward_requires_fresh_args:
+                output_2 = model(**self.inputs_dict(0))[0]
+            else:
+                output_2 = model(**inputs_dict)[0]
+
+        model.set_attn_processor(AttnProcessorNPU())
+        assert all(type(proc) == AttnProcessorNPU for proc in model.attn_processors.values())
+        with torch.no_grad():
+            if self.forward_requires_fresh_args:
+                output_3 = model(**self.inputs_dict(0))[0]
+            else:
+                output_3 = model(**inputs_dict)[0]
+
+        torch.use_deterministic_algorithms(True)
+
+        assert torch.allclose(output, output_2, atol=self.base_precision)
+        assert torch.allclose(output, output_3, atol=self.base_precision)
+        assert torch.allclose(output_2, output_3, atol=self.base_precision)
+
     @unittest.skipIf(
         torch_device != "cuda" or not is_xformers_available(),
         reason="XFormers attention is only available with CUDA and `xformers` installed",
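
For reference, the sketch below is an editor's illustration (not part of the diff) of how the new model-level switches are expected to be used. It assumes an Ascend NPU environment with `torch_npu` installed; the checkpoint name is only an example.

    # Minimal usage sketch; assumes `torch_npu` is installed and an NPU device is visible.
    import torch
    import torch_npu  # noqa: F401  # registers the "npu" device; assumed available in this environment
    from diffusers import UNet2DConditionModel

    unet = UNet2DConditionModel.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet", torch_dtype=torch.float16
    )
    unet.to("npu")

    # Route every attention block through AttnProcessorNPU (torch_npu fused flash attention).
    unet.enable_npu_flash_attention()

    # Restore the default processors (AttnProcessor / AttnProcessor2_0).
    unet.disable_npu_flash_attention()

The training scripts expose the same switch through the new `--enable_npu_flash_attention` flag, which calls `unet.enable_npu_flash_attention()` and raises a ValueError when `torch_npu` is not available.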