From 2ba9224f6a841e157cdc5069c1e0a6fa830557dc Mon Sep 17 00:00:00 2001 From: Shashank Rajput <144760128+ShashankMosaicML@users.noreply.github.com> Date: Tue, 19 Dec 2023 16:57:14 -0800 Subject: [PATCH] Updating the Flash Attention version to fix cross entropy loss (#812) * .. * .. * .. * .. * .. --- llmfoundry/models/hf/hf_causal_lm.py | 2 +- llmfoundry/models/layers/attention.py | 4 +- .../models/layers/cross_entropy_loss.py | 173 ------------------ llmfoundry/models/mpt/configuration_mpt.py | 2 +- llmfoundry/models/mpt/modeling_mpt.py | 6 +- setup.py | 4 +- tests/models/test_model.py | 16 +- 7 files changed, 11 insertions(+), 196 deletions(-) delete mode 100644 llmfoundry/models/layers/cross_entropy_loss.py diff --git a/llmfoundry/models/hf/hf_causal_lm.py b/llmfoundry/models/hf/hf_causal_lm.py index fcac57d817..87b6080de7 100644 --- a/llmfoundry/models/hf/hf_causal_lm.py +++ b/llmfoundry/models/hf/hf_causal_lm.py @@ -100,7 +100,7 @@ def __init__(self, om_model_config: Union[DictConfig, if use_flash_attention_2 and not is_flash_v2_installed(): raise ValueError( 'use_flash_attention_2 is set to True, but flash-attention 2 is not installed. ' - + 'Please install flash_attn==2.3.2`.') + + 'Please install flash_attn==2.3.6`.') requested_attention_implementation = 'flash_attention_2' if use_flash_attention_2 else 'eager' config = AutoConfig.from_pretrained( diff --git a/llmfoundry/models/layers/attention.py b/llmfoundry/models/layers/attention.py index 86e49c315d..ca0515d9cc 100644 --- a/llmfoundry/models/layers/attention.py +++ b/llmfoundry/models/layers/attention.py @@ -227,7 +227,7 @@ def flash_attn_fn( from flash_attn import bert_padding, flash_attn_interface # type: ignore # yapf: disable # isort: skip except: raise RuntimeError( - 'Please install flash-attn==1.0.9 or flash-attn==2.3.2') + 'Please install flash-attn==1.0.9 or flash-attn==2.3.6') check_valid_inputs(query, key, value) @@ -344,7 +344,7 @@ def flash_attn_fn( window_size=(sliding_window_size, sliding_window_size)) else: raise RuntimeError( - 'flash-attn==1.0.9 or flash-attn==2.3.2 is required.') + 'flash-attn==1.0.9 or flash-attn==2.3.6 is required.') output = bert_padding.pad_input( rearrange(output_unpad, 'nnz h d -> nnz (h d)'), indices_q, batch_size, diff --git a/llmfoundry/models/layers/cross_entropy_loss.py b/llmfoundry/models/layers/cross_entropy_loss.py deleted file mode 100644 index e3b0931701..0000000000 --- a/llmfoundry/models/layers/cross_entropy_loss.py +++ /dev/null @@ -1,173 +0,0 @@ -# Copyright 2022 MosaicML LLM Foundry authors -# SPDX-License-Identifier: Apache-2.0 - -# Copied from https://github.com/Dao-AILab/flash-attention/blob/f1a73d074002226c42ce65a1df170ecff9f022c0/flash_attn/losses/cross_entropy.py -# type: ignore - -# Inspired by https://github.com/NVIDIA/apex/blob/master/apex/transformer/tensor_parallel/cross_entropy.py -# But we make it much faster: we compute the local loss and the LSE, and by exchanging the LSE and -# the losses we can get the global loss. There's no need to do it step by step -# (compute local max, exchange, compute exp, compute local sum, exchange, etc.) -# The original xentropy interface is here: https://github.com/NVIDIA/apex/blob/master/apex/contrib/xentropy/softmax_xentropy.py -import torch -import torch.nn as nn - -try: # This try...except is needed because hf transformers library requires it - import xentropy_cuda_lib -except Exception as e: - xentropy_cuda_lib = None - -# `all_gather_into_tensor` and `reduce_scatter_tensor` are new placeholders for -# `_all_gather_base` and `_reduce_scatter_base`. They require the most recent -# version of PyTorch. The following 2 lines are for backward compatibility with -# older PyTorch. -if 'all_gather_into_tensor' not in dir(torch.distributed): - torch.distributed.all_gather_into_tensor = torch.distributed._all_gather_base - - -class SoftmaxCrossEntropyLossFn(torch.autograd.Function): - - @staticmethod - def forward( - ctx, - logits, - labels, - smoothing=0.0, - ignored_index=-100, - inplace_backward=False, - process_group=None, - ): - """The forward function for softmax cross entropy loss. - - logits: (batch, vocab_size) - labels: (batch,) - If process_group is not None, we're doing Tensor Parallel: each process is responsible for - one part of the vocab. The loss needs to be aggregated across processes. - """ - batch, vocab_size = logits.shape - assert labels.shape == (batch,) - world_size = 1 if process_group is None else torch.distributed.get_world_size( - process_group) - ctx.total_classes = world_size * vocab_size - if world_size == 1: - losses, lse = xentropy_cuda_lib.forward(logits, labels, smoothing) - losses.masked_fill_(labels == ignored_index, 0) - labels_local = labels - else: - rank = torch.distributed.get_rank(process_group) - vocab_start_index, vocab_end_index = rank * vocab_size, ( - rank + 1) * vocab_size - - # Create a mask of valid vocab ids (1 means it needs to be masked). - labels_mask = (labels < vocab_start_index) | (labels >= - vocab_end_index) - ignored_mask = labels == ignored_index - labels_local = torch.where(ignored_mask, labels, - labels - vocab_start_index) - - # For tensor parallel cross entropy with smoothing, we want to pass in the total number - # of classes so that smoothing can be applied correctly. If total_classes=-1, use the - # last dimension of the input tensor. - losses, lse_local = xentropy_cuda_lib.forward( - logits, labels_local, smoothing, world_size * vocab_size) - assert lse_local.shape == (batch,) - assert losses.shape == (batch,) - losses.masked_fill_(ignored_mask, 0) - # For labels == ignored_index, the loss is always 0. - # If there's no smoothing, if labels are in the vocab of this partition, losses contains - # lse_local - predicted logit, and 0 otherwise. - # If there's smoothing=0.1, for labels in the vocab of this partition, losses contains - # 0.9 * (lse_local - predicted logit) + 0.1 * (lse_local - sum logit / total_classes) - # For labels not in the vocab of this partition, losses contains - # 0.1 * (lse_local - sum logit / total_classes). - - lse_allgather = torch.empty(world_size, - batch, - dtype=lse_local.dtype, - device=lse_local.device) - torch.distributed.all_gather_into_tensor(lse_allgather, - lse_local.contiguous(), - group=process_group) - handle_losses = torch.distributed.all_reduce( - losses, - op=torch.distributed.ReduceOp.SUM, - group=process_group, - async_op=True) - lse = torch.logsumexp(lse_allgather, dim=0) - # If there's no smoothing, the total losses are lse_local - predicted_logit, - # we just have to subtract the lse_local and add the lse (global). - # If there's smoothing=0.1, the total losses are - # 0.9 * (lse_local - predicted_logit) + 0.1 * (sum of all lse_local - sum logit / total_classes) - # We want 0.9 * (lse - predicted_logit) + 0.1 * (lse - sum logit / total_classes). - rank_per_sample = torch.div(labels, - vocab_size, - rounding_mode='floor') - lse_local = lse_allgather[ - rank_per_sample, - torch.arange(batch, device=lse_allgather.device)] - - handle_losses.wait() - if smoothing == 0.0: - losses += lse - lse_local - else: - losses += (1 - smoothing) * (lse - lse_local) + smoothing * ( - lse - lse_allgather.sum(dim=0)) - losses.masked_fill_(ignored_mask, 0) - - ctx.save_for_backward(logits, lse, labels_local) - ctx.smoothing = smoothing - ctx.ignored_index = ignored_index - ctx.inplace_backward = inplace_backward - return losses - - @staticmethod - def backward(ctx, grad_loss): - logits, lse, labels = ctx.saved_tensors - grad_loss = grad_loss.contiguous() - grad_loss.masked_fill_(labels == ctx.ignored_index, 0) - grad_logits = xentropy_cuda_lib.backward(grad_loss, logits, lse, labels, - ctx.smoothing, - ctx.inplace_backward, - ctx.total_classes) - return grad_logits, None, None, None, None, None, None - - -class CrossEntropyLoss(nn.Module): - - def __init__( - self, - ignore_index=-100, - reduction='mean', - label_smoothing=0.0, - inplace_backward=False, - process_group=None, - ): - super().__init__() - if xentropy_cuda_lib is None: - raise ValueError( - 'xentropy_cuda_lib is None, probably because importing xentropy_cuda_lib failed.' - ) - if reduction not in ['mean', 'none']: - raise NotImplementedError( - "Only support reduction = 'mean' or 'none'") - self.ignore_index = ignore_index - self.reduction = reduction - self.label_smoothing = label_smoothing - self.inplace_backward = inplace_backward - self.process_group = process_group - - def forward(self, input, target): - assert input.is_cuda and target.is_cuda - # SoftmaxCrossEntropyLoss implicitly casts to float - loss = SoftmaxCrossEntropyLossFn.apply( - input, - target, - self.label_smoothing, - self.ignore_index, - self.inplace_backward, - self.process_group, - ) - if self.reduction == 'mean': - return loss.sum() / (target != self.ignore_index).sum() - else: - return loss diff --git a/llmfoundry/models/mpt/configuration_mpt.py b/llmfoundry/models/mpt/configuration_mpt.py index 913c39d44f..6c4c286712 100644 --- a/llmfoundry/models/mpt/configuration_mpt.py +++ b/llmfoundry/models/mpt/configuration_mpt.py @@ -304,5 +304,5 @@ def _validate_config(self) -> None: from flash_attn.bert_padding import unpad_input, pad_input # type: ignore # yapf: disable # isort: skip except: raise ImportError( - 'In order to set `use_pad_tok_in_ffn=False`, please install flash-attn==1.0.9 or flash-attn==2.3.2' + 'In order to set `use_pad_tok_in_ffn=False`, please install flash-attn==1.0.9 or flash-attn==2.3.6' ) diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index c587edb723..8c134e2b9f 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -972,11 +972,7 @@ def __init__( loss_fn_config = om_model_config.get('loss_fn', 'fused_crossentropy') if loss_fn_config == 'fused_crossentropy': try: - # NOTE: The following is the original import statement from flash_attn library, which we have currently replaced with a copy pasted code from the same library's version 1.0.9. The reason is that using the CE loss from FA v2.3.2 results in an illegal memory access error at long sequence lengths (github issue: https://github.com/Dao-AILab/flash-attention/issues/714). - # from flash_attn.losses.cross_entropy import \ - # CrossEntropyLoss as FusedCrossEntropyLoss - # TODO: Once the problem with using FA v2's CE loss at longer sequence lengths is resolved (github issue: https://github.com/Dao-AILab/flash-attention/issues/714), revert back to directly importing the CE loss from FA library. - from llmfoundry.models.layers.cross_entropy_loss import \ + from flash_attn.losses.cross_entropy import \ CrossEntropyLoss as FusedCrossEntropyLoss self.loss_fn = FusedCrossEntropyLoss(ignore_index=-100) diff --git a/setup.py b/setup.py index 923705699c..c030fe3268 100644 --- a/setup.py +++ b/setup.py @@ -98,10 +98,8 @@ 'xentropy-cuda-lib@git+https://github.com/HazyResearch/flash-attention.git@v1.0.9#subdirectory=csrc/xentropy', ] extra_deps['gpu-flash2'] = [ - 'flash-attn==2.3.2', + 'flash-attn==2.3.6', 'mosaicml-turbo==0.0.4', - # PyPI does not support direct dependencies, so we remove this line before uploading from PyPI - 'xentropy-cuda-lib@git+https://github.com/HazyResearch/flash-attention.git@v2.3.2#subdirectory=csrc/xentropy', ] extra_deps['peft'] = [ diff --git a/tests/models/test_model.py b/tests/models/test_model.py index 3b2fc22ee3..2419dbfa41 100644 --- a/tests/models/test_model.py +++ b/tests/models/test_model.py @@ -422,22 +422,16 @@ def test_determinism(attn_impl: str, precision: torch.dtype, ffn_type: str, @pytest.mark.gpu -@pytest.mark.parametrize('ce_loss_implementation', - ['FA_v1_copied', 'FA_imported']) -def test_loss_fn(ce_loss_implementation: str): +def test_loss_fn(): """Tests the Fused CrossEntropy vs torch.nn.CrossEntropy loss function. We provide non-zero tolerances to account for small numerics differences between the two loss implementations. """ - if ce_loss_implementation == 'FA_imported': - try: - from flash_attn.losses.cross_entropy import CrossEntropyLoss as FusedCrossEntropyLoss # type: ignore # isort: skip - except: - pytest.skip('Fused cross entropy was not installed') - else: - from llmfoundry.models.layers.cross_entropy_loss import \ - CrossEntropyLoss as FusedCrossEntropyLoss + try: + from flash_attn.losses.cross_entropy import CrossEntropyLoss as FusedCrossEntropyLoss # type: ignore # isort: skip + except: + pytest.skip('Fused cross entropy was not installed') # run numerical test in pure fp32 from torch.backends import cuda, cudnn