Add Ascend NPU support for SDXL fine-tuning and fix the model saving bug when using DeepSpeed. (#7816)

* Add Ascend NPU support for SDXL fine-tuning and fix the model saving bug when using DeepSpeed.

* Fix the code quality check.

* Decouple the NPU flash attention and make it an independent module.

* Add docs and unit tests for NPU flash attention.

---------

Co-authored-by: mhh001 <[email protected]>
Co-authored-by: Sayak Paul <[email protected]>
3 people authored May 3, 2024
1 parent 3e35628 commit 5823736
Showing 7 changed files with 261 additions and 12 deletions.
3 changes: 3 additions & 0 deletions docs/source/en/api/attnprocessor.md
@@ -55,3 +55,6 @@ An attention processor is a class for applying different types of attention mechanisms.

## XFormersAttnProcessor
[[autodoc]] models.attention_processor.XFormersAttnProcessor

## AttnProcessorNPU
[[autodoc]] models.attention_processor.AttnProcessorNPU
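
For context, a minimal sketch of how the new processor can be applied manually outside the training scripts (the checkpoint and dtype are illustrative, and it assumes an Ascend NPU environment with `torch_npu` installed):

```python
import torch

from diffusers import UNet2DConditionModel
from diffusers.models.attention_processor import AttnProcessorNPU
from diffusers.utils.import_utils import is_torch_npu_available

# Illustrative checkpoint; any SDXL UNet works the same way.
unet = UNet2DConditionModel.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet", torch_dtype=torch.float16
)

if is_torch_npu_available():
    # AttnProcessorNPU raises ImportError when torch_npu is missing, so guard the swap.
    # fp16/bf16 inputs take the fused torch_npu path; fp32 falls back to SDPA.
    unet.set_attn_processor(AttnProcessorNPU())
```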
19 changes: 16 additions & 3 deletions examples/controlnet/train_controlnet_sdxl.py
@@ -32,7 +32,7 @@
import transformers
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import ProjectConfiguration, set_seed
from accelerate.utils import DistributedType, ProjectConfiguration, set_seed
from datasets import load_dataset
from huggingface_hub import create_repo, upload_folder
from packaging import version
@@ -53,7 +53,7 @@
from diffusers.optimization import get_scheduler
from diffusers.utils import check_min_version, is_wandb_available, make_image_grid
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.import_utils import is_torch_npu_available, is_xformers_available
from diffusers.utils.torch_utils import is_compiled_module


@@ -64,6 +64,8 @@
check_min_version("0.28.0.dev0")

logger = get_logger(__name__)
if is_torch_npu_available():
torch.npu.config.allow_internal_format = False


def log_validation(vae, unet, controlnet, args, accelerator, weight_dtype, step, is_final_validation=False):
@@ -471,6 +473,9 @@ def parse_args(input_args=None):
parser.add_argument(
"--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
)
parser.add_argument(
"--enable_npu_flash_attention", action="store_true", help="Whether or not to use npu flash attention."
)
parser.add_argument(
"--set_grads_to_none",
action="store_true",
@@ -936,6 +941,13 @@ def load_model_hook(models, input_dir):
text_encoder_two.requires_grad_(False)
controlnet.train()

if args.enable_npu_flash_attention:
if is_torch_npu_available():
logger.info("npu flash attention enabled.")
unet.enable_npu_flash_attention()
else:
raise ValueError("npu flash attention requires torch_npu extensions and is supported only on npu devices.")

if args.enable_xformers_memory_efficient_attention:
if is_xformers_available():
import xformers
@@ -1235,7 +1247,8 @@ def compute_embeddings(batch, proportion_empty_prompts, text_encoders, tokenizer
progress_bar.update(1)
global_step += 1

if accelerator.is_main_process:
# DeepSpeed requires saving weights on every device; saving weights only on the main process would cause issues.
if accelerator.distributed_type == DistributedType.DEEPSPEED or accelerator.is_main_process:
if global_step % args.checkpointing_steps == 0:
# _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
if args.checkpoints_total_limit is not None:
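The checkpointing guard added above is the core of the DeepSpeed fix; a condensed, standalone sketch of the pattern (output directory and step counts are illustrative):

```python
import os

from accelerate import Accelerator
from accelerate.utils import DistributedType

accelerator = Accelerator()
global_step = 500             # illustrative
checkpointing_steps = 500     # illustrative
output_dir = "./checkpoints"  # illustrative

# Under DeepSpeed (e.g. ZeRO), optimizer/parameter state is sharded across ranks, so every
# process must join the save; gating on `is_main_process` alone can drop shards or hang.
# In all other setups, only the main process saves.
if accelerator.distributed_type == DistributedType.DEEPSPEED or accelerator.is_main_process:
    if global_step % checkpointing_steps == 0:
        save_path = os.path.join(output_dir, f"checkpoint-{global_step}")
        accelerator.save_state(save_path)
```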
19 changes: 16 additions & 3 deletions examples/text_to_image/train_text_to_image_lora_sdxl.py
@@ -32,7 +32,7 @@
import transformers
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
from accelerate.utils import DistributedDataParallelKwargs, DistributedType, ProjectConfiguration, set_seed
from datasets import load_dataset
from huggingface_hub import create_repo, upload_folder
from packaging import version
@@ -60,14 +60,16 @@
is_wandb_available,
)
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.import_utils import is_torch_npu_available, is_xformers_available
from diffusers.utils.torch_utils import is_compiled_module


# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.28.0.dev0")

logger = get_logger(__name__)
if is_torch_npu_available():
torch.npu.config.allow_internal_format = False


def save_model_card(
@@ -419,6 +421,9 @@ def parse_args(input_args=None):
parser.add_argument(
"--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
)
parser.add_argument(
"--enable_npu_flash_attention", action="store_true", help="Whether or not to use npu flash attention."
)
parser.add_argument("--noise_offset", type=float, default=0, help="The scale of noise offset.")
parser.add_argument(
"--rank",
@@ -623,6 +628,13 @@ def main(args):
text_encoder_one.to(accelerator.device, dtype=weight_dtype)
text_encoder_two.to(accelerator.device, dtype=weight_dtype)

if args.enable_npu_flash_attention:
if is_torch_npu_available():
logger.info("npu flash attention enabled.")
unet.enable_npu_flash_attention()
else:
raise ValueError("npu flash attention requires torch_npu extensions and is supported only on npu devices.")

if args.enable_xformers_memory_efficient_attention:
if is_xformers_available():
import xformers
@@ -1149,7 +1161,8 @@ def compute_time_ids(original_size, crops_coords_top_left):
accelerator.log({"train_loss": train_loss}, step=global_step)
train_loss = 0.0

if accelerator.is_main_process:
# DeepSpeed requires saving weights on every device; saving weights only on the main process would cause issues.
if accelerator.distributed_type == DistributedType.DEEPSPEED or accelerator.is_main_process:
if global_step % args.checkpointing_steps == 0:
# _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
if args.checkpoints_total_limit is not None:
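Both training scripts wire up the new `--enable_npu_flash_attention` flag the same way; a condensed sketch of just the new pieces (the illustrative checkpoint stands in for `args.pretrained_model_name_or_path`):

```python
import argparse

from diffusers import UNet2DConditionModel
from diffusers.utils.import_utils import is_torch_npu_available

parser = argparse.ArgumentParser()
parser.add_argument(
    "--enable_npu_flash_attention", action="store_true", help="Whether or not to use npu flash attention."
)
args = parser.parse_args()

# Illustrative checkpoint; the scripts load the UNet from args.pretrained_model_name_or_path.
unet = UNet2DConditionModel.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet")

if args.enable_npu_flash_attention:
    if is_torch_npu_available():
        unet.enable_npu_flash_attention()
    else:
        raise ValueError("npu flash attention requires torch_npu extensions and is supported only on npu devices.")
```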
14 changes: 11 additions & 3 deletions src/diffusers/models/activations.py
@@ -18,8 +18,12 @@
from torch import nn

from ..utils import deprecate
from ..utils.import_utils import is_torch_npu_available


if is_torch_npu_available():
import torch_npu

ACTIVATION_FUNCTIONS = {
"swish": nn.SiLU(),
"silu": nn.SiLU(),
@@ -98,9 +102,13 @@ def forward(self, hidden_states, *args, **kwargs):
if len(args) > 0 or kwargs.get("scale", None) is not None:
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
deprecate("scale", "1.0.0", deprecation_message)

hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1)
return hidden_states * self.gelu(gate)
hidden_states = self.proj(hidden_states)
if is_torch_npu_available():
# Using torch_npu.npu_geglu is faster and saves memory on NPU.
return torch_npu.npu_geglu(hidden_states, dim=-1, approximate=1)[0]
else:
hidden_states, gate = hidden_states.chunk(2, dim=-1)
return hidden_states * self.gelu(gate)


class ApproximateGELU(nn.Module):
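For reference, the non-NPU branch above is the standard GEGLU: one projection whose output is split into a value half and a gate half. A small standalone sketch of that reference path (the fused `torch_npu.npu_geglu` call is expected to match it up to the GELU approximation used, which is not verified here):

```python
import torch
import torch.nn as nn


class ReferenceGEGLU(nn.Module):
    """Plain GEGLU: project to 2 * dim_out, then gate one half with GELU of the other."""

    def __init__(self, dim_in: int, dim_out: int):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out * 2)
        self.gelu = nn.GELU()

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1)
        return hidden_states * self.gelu(gate)


x = torch.randn(2, 16, 320)
print(ReferenceGEGLU(320, 1280)(x).shape)  # torch.Size([2, 16, 1280])
```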
132 changes: 131 additions & 1 deletion src/diffusers/models/attention_processor.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import math
from importlib import import_module
from typing import Callable, List, Optional, Union

@@ -21,13 +22,15 @@

from ..image_processor import IPAdapterMaskProcessor
from ..utils import deprecate, logging
from ..utils.import_utils import is_xformers_available
from ..utils.import_utils import is_torch_npu_available, is_xformers_available
from ..utils.torch_utils import maybe_allow_in_graph
from .lora import LoRALinearLayer


logger = logging.get_logger(__name__) # pylint: disable=invalid-name

if is_torch_npu_available():
import torch_npu

if is_xformers_available():
import xformers
@@ -209,6 +212,23 @@ def __init__(
)
self.set_processor(processor)

def set_use_npu_flash_attention(self, use_npu_flash_attention: bool) -> None:
r"""
Set whether to use npu flash attention from `torch_npu` or not.
"""
if use_npu_flash_attention:
processor = AttnProcessorNPU()
else:
# set attention processor
# We use the AttnProcessor2_0 by default when torch 2.x is used which uses
# torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
# but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
processor = (
AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor()
)
self.set_processor(processor)

def set_use_memory_efficient_attention_xformers(
self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
) -> None:
@@ -1207,6 +1227,116 @@ def __call__(
return hidden_states


class AttnProcessorNPU:

r"""
Processor for implementing flash attention using torch_npu. Torch_npu supports only fp16 and bf16 data types. If
fp32 is used, F.scaled_dot_product_attention will be used for computation, but the acceleration effect on NPU is
not significant.
"""

def __init__(self):
if not is_torch_npu_available():
raise ImportError("AttnProcessorNPU requires torch_npu extensions and is supported only on npu devices.")

def __call__(
self,
attn: Attention,
hidden_states: torch.FloatTensor,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
temb: Optional[torch.FloatTensor] = None,
*args,
**kwargs,
) -> torch.FloatTensor:
if len(args) > 0 or kwargs.get("scale", None) is not None:
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
deprecate("scale", "1.0.0", deprecation_message)

residual = hidden_states
if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)

input_ndim = hidden_states.ndim

if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)

if attention_mask is not None:
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
# scaled_dot_product_attention expects attention_mask shape to be
# (batch, heads, source_length, target_length)
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

query = attn.to_q(hidden_states)

if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)

inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads

query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

# the output of sdp = (batch, num_heads, seq_len, head_dim)
if query.dtype in (torch.float16, torch.bfloat16):
hidden_states = torch_npu.npu_fusion_attention(
query,
key,
value,
attn.heads,
input_layout="BNSD",
pse=None,
atten_mask=attention_mask,
scale=1.0 / math.sqrt(query.shape[-1]),
pre_tockens=65536,
next_tockens=65536,
keep_prob=1.0,
sync=False,
inner_precise=0,
)[0]
else:
# TODO: add support for attn.scale when we move to Torch 2.1
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)

hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)

# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)

if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

if attn.residual_connection:
hidden_states = hidden_states + residual

hidden_states = hidden_states / attn.rescale_output_factor

return hidden_states


class AttnProcessor2_0:
r"""
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
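One detail worth noting in `AttnProcessorNPU` is the dtype gate: only fp16/bf16 tensors take the fused `torch_npu.npu_fusion_attention` path, while fp32 falls back to PyTorch SDPA. A minimal sketch of that fallback branch and the (batch, heads, seq, head_dim) layout both paths use; shapes are illustrative:

```python
import torch
import torch.nn.functional as F

batch_size, seq_len, heads, head_dim = 2, 77, 8, 64


def to_bnsd(x: torch.Tensor) -> torch.Tensor:
    # (batch, seq, heads * head_dim) -> (batch, heads, seq, head_dim), i.e. the "BNSD"
    # layout passed to npu_fusion_attention and accepted by scaled_dot_product_attention.
    return x.view(batch_size, -1, heads, head_dim).transpose(1, 2)


query = to_bnsd(torch.randn(batch_size, seq_len, heads * head_dim))
key = to_bnsd(torch.randn(batch_size, seq_len, heads * head_dim))
value = to_bnsd(torch.randn(batch_size, seq_len, heads * head_dim))

if query.dtype in (torch.float16, torch.bfloat16):
    # On an Ascend NPU this branch would call torch_npu.npu_fusion_attention with
    # scale=1.0 / sqrt(head_dim); omitted here since torch_npu may not be installed.
    pass
else:
    out = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
    out = out.transpose(1, 2).reshape(batch_size, -1, heads * head_dim)
    print(out.shape)  # torch.Size([2, 77, 512])
```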
30 changes: 30 additions & 0 deletions src/diffusers/models/modeling_utils.py
@@ -272,6 +272,36 @@ def disable_gradient_checkpointing(self) -> None:
if self._supports_gradient_checkpointing:
self.apply(partial(self._set_gradient_checkpointing, value=False))

def set_use_npu_flash_attention(self, valid: bool) -> None:
r"""
Set whether to use npu flash attention from `torch_npu` for all attention modules in the model.
"""

def fn_recursive_set_npu_flash_attention(module: torch.nn.Module):
if hasattr(module, "set_use_npu_flash_attention"):
module.set_use_npu_flash_attention(valid)

for child in module.children():
fn_recursive_set_npu_flash_attention(child)

for module in self.children():
if isinstance(module, torch.nn.Module):
fn_recursive_set_npu_flash_attention(module)

def enable_npu_flash_attention(self) -> None:
r"""
Enable npu flash attention from `torch_npu`.
"""
self.set_use_npu_flash_attention(True)

def disable_npu_flash_attention(self) -> None:
r"""
Disable npu flash attention from `torch_npu`.
"""
self.set_use_npu_flash_attention(False)

def set_use_memory_efficient_attention_xformers(
self, valid: bool, attention_op: Optional[Callable] = None
) -> None:
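The model-level switch added above simply walks the module tree and flips each attention block's processor; a short usage sketch (again assuming an Ascend NPU environment with `torch_npu`, and an illustrative checkpoint):

```python
from diffusers import UNet2DConditionModel
from diffusers.models.attention_processor import AttnProcessorNPU

unet = UNet2DConditionModel.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet")

# Recursively calls set_use_npu_flash_attention(True) on every child module that supports it.
unet.enable_npu_flash_attention()
assert all(isinstance(p, AttnProcessorNPU) for p in unet.attn_processors.values())

# Restores the default processor (AttnProcessor2_0 when SDPA is available and scale_qk is
# left at its default, AttnProcessor otherwise).
unet.disable_npu_flash_attention()
```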