Adding a fix for Cross Entropy Loss for long sequence lengths. (#795)
ShashankMosaicML authored Dec 12, 2023
1 parent 410d5c7 commit 96cf646
Showing 3 changed files with 189 additions and 6 deletions.
173 changes: 173 additions & 0 deletions llmfoundry/models/layers/cross_entropy_loss.py
@@ -0,0 +1,173 @@
# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

# Copied from https://github.com/Dao-AILab/flash-attention/blob/f1a73d074002226c42ce65a1df170ecff9f022c0/flash_attn/losses/cross_entropy.py
# type: ignore

# Inspired by https://github.com/NVIDIA/apex/blob/master/apex/transformer/tensor_parallel/cross_entropy.py
# But we make it much faster: we compute the local loss and the LSE, and by exchanging the LSE and
# the losses we can get the global loss. There's no need to do it step by step
# (compute local max, exchange, compute exp, compute local sum, exchange, etc.)
# The original xentropy interface is here: https://github.com/NVIDIA/apex/blob/master/apex/contrib/xentropy/softmax_xentropy.py
import torch
import torch.nn as nn

try:  # This try...except is needed because hf transformers library requires it
    import xentropy_cuda_lib
except Exception as e:
    xentropy_cuda_lib = None

# `all_gather_into_tensor` and `reduce_scatter_tensor` are new placeholders for
# `_all_gather_base` and `_reduce_scatter_base`. They require the most recent
# version of PyTorch. The following 2 lines are for backward compatibility with
# older PyTorch.
if 'all_gather_into_tensor' not in dir(torch.distributed):
    torch.distributed.all_gather_into_tensor = torch.distributed._all_gather_base


class SoftmaxCrossEntropyLossFn(torch.autograd.Function):

    @staticmethod
    def forward(
        ctx,
        logits,
        labels,
        smoothing=0.0,
        ignored_index=-100,
        inplace_backward=False,
        process_group=None,
    ):
        """The forward function for softmax cross entropy loss.
        logits: (batch, vocab_size)
        labels: (batch,)
        If process_group is not None, we're doing Tensor Parallel: each process is responsible for
        one part of the vocab. The loss needs to be aggregated across processes.
        """
        batch, vocab_size = logits.shape
        assert labels.shape == (batch,)
        world_size = 1 if process_group is None else torch.distributed.get_world_size(
            process_group)
        ctx.total_classes = world_size * vocab_size
        if world_size == 1:
            losses, lse = xentropy_cuda_lib.forward(logits, labels, smoothing)
            losses.masked_fill_(labels == ignored_index, 0)
            labels_local = labels
        else:
            rank = torch.distributed.get_rank(process_group)
            vocab_start_index, vocab_end_index = rank * vocab_size, (
                rank + 1) * vocab_size

            # Create a mask of valid vocab ids (1 means it needs to be masked).
            labels_mask = (labels < vocab_start_index) | (labels >=
                                                          vocab_end_index)
            ignored_mask = labels == ignored_index
            labels_local = torch.where(ignored_mask, labels,
                                       labels - vocab_start_index)

            # For tensor parallel cross entropy with smoothing, we want to pass in the total number
            # of classes so that smoothing can be applied correctly. If total_classes=-1, use the
            # last dimension of the input tensor.
            losses, lse_local = xentropy_cuda_lib.forward(
                logits, labels_local, smoothing, world_size * vocab_size)
            assert lse_local.shape == (batch,)
            assert losses.shape == (batch,)
            losses.masked_fill_(ignored_mask, 0)
            # For labels == ignored_index, the loss is always 0.
            # If there's no smoothing, if labels are in the vocab of this partition, losses contains
            # lse_local - predicted logit, and 0 otherwise.
            # If there's smoothing=0.1, for labels in the vocab of this partition, losses contains
            # 0.9 * (lse_local - predicted logit) + 0.1 * (lse_local - sum logit / total_classes)
            # For labels not in the vocab of this partition, losses contains
            # 0.1 * (lse_local - sum logit / total_classes).

            lse_allgather = torch.empty(world_size,
                                        batch,
                                        dtype=lse_local.dtype,
                                        device=lse_local.device)
            torch.distributed.all_gather_into_tensor(lse_allgather,
                                                     lse_local.contiguous(),
                                                     group=process_group)
            handle_losses = torch.distributed.all_reduce(
                losses,
                op=torch.distributed.ReduceOp.SUM,
                group=process_group,
                async_op=True)
            lse = torch.logsumexp(lse_allgather, dim=0)
            # If there's no smoothing, the total losses are lse_local - predicted_logit,
            # we just have to subtract the lse_local and add the lse (global).
            # If there's smoothing=0.1, the total losses are
            # 0.9 * (lse_local - predicted_logit) + 0.1 * (sum of all lse_local - sum logit / total_classes)
            # We want 0.9 * (lse - predicted_logit) + 0.1 * (lse - sum logit / total_classes).
            rank_per_sample = torch.div(labels,
                                        vocab_size,
                                        rounding_mode='floor')
            lse_local = lse_allgather[
                rank_per_sample,
                torch.arange(batch, device=lse_allgather.device)]

            handle_losses.wait()
            if smoothing == 0.0:
                losses += lse - lse_local
            else:
                losses += (1 - smoothing) * (lse - lse_local) + smoothing * (
                    lse - lse_allgather.sum(dim=0))
            losses.masked_fill_(ignored_mask, 0)

        ctx.save_for_backward(logits, lse, labels_local)
        ctx.smoothing = smoothing
        ctx.ignored_index = ignored_index
        ctx.inplace_backward = inplace_backward
        return losses

    @staticmethod
    def backward(ctx, grad_loss):
        logits, lse, labels = ctx.saved_tensors
        grad_loss = grad_loss.contiguous()
        grad_loss.masked_fill_(labels == ctx.ignored_index, 0)
        grad_logits = xentropy_cuda_lib.backward(grad_loss, logits, lse, labels,
                                                 ctx.smoothing,
                                                 ctx.inplace_backward,
                                                 ctx.total_classes)
        return grad_logits, None, None, None, None, None, None


class CrossEntropyLoss(nn.Module):

    def __init__(
        self,
        ignore_index=-100,
        reduction='mean',
        label_smoothing=0.0,
        inplace_backward=False,
        process_group=None,
    ):
        super().__init__()
        if xentropy_cuda_lib is None:
            raise ValueError(
                'xentropy_cuda_lib is None, probably because importing xentropy_cuda_lib failed.'
            )
        if reduction not in ['mean', 'none']:
            raise NotImplementedError(
                "Only support reduction = 'mean' or 'none'")
        self.ignore_index = ignore_index
        self.reduction = reduction
        self.label_smoothing = label_smoothing
        self.inplace_backward = inplace_backward
        self.process_group = process_group

    def forward(self, input, target):
        assert input.is_cuda and target.is_cuda
        # SoftmaxCrossEntropyLoss implicitly casts to float
        loss = SoftmaxCrossEntropyLossFn.apply(
            input,
            target,
            self.label_smoothing,
            self.ignore_index,
            self.inplace_backward,
            self.process_group,
        )
        if self.reduction == 'mean':
            return loss.sum() / (target != self.ignore_index).sum()
        else:
            return loss
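
Reader's note on the vendored file above: in the tensor-parallel path, each vocab shard computes a local loss and a local logsumexp (LSE), the local losses are summed with an all-reduce, and a final correction swaps the owning shard's LSE for the global LSE, as the in-code comments describe. The standalone sketch below is not part of this commit; variable names such as shard and eps are ours, and plain torch.logsumexp emulates what xentropy_cuda_lib.forward would return on each rank. It checks that the combination recovers the ordinary label-smoothed cross entropy.

import torch

torch.manual_seed(0)
batch, vocab, world_size, eps = 4, 12, 3, 0.1
shard = vocab // world_size
logits = torch.randn(batch, vocab, dtype=torch.float64)
labels = torch.randint(vocab, (batch,))

# Reference: label-smoothed cross entropy over the full vocabulary.
lse_full = torch.logsumexp(logits, dim=-1)
z_y = logits.gather(1, labels[:, None]).squeeze(1)
ref = (1 - eps) * (lse_full - z_y) + eps * (lse_full - logits.sum(dim=-1) / vocab)

# Per-shard pieces (emulating the losses/LSE each rank would produce),
# summed the way the all_reduce in forward() sums them.
losses = torch.zeros(batch, dtype=torch.float64)
lse_allgather = torch.empty(world_size, batch, dtype=torch.float64)
for rank in range(world_size):
    local = logits[:, rank * shard:(rank + 1) * shard]
    lse_local = torch.logsumexp(local, dim=-1)
    lse_allgather[rank] = lse_local
    in_shard = (labels >= rank * shard) & (labels < (rank + 1) * shard)
    z_y_local = torch.where(in_shard, z_y, lse_local)  # contributes 0 off-shard
    losses += (1 - eps) * (lse_local - z_y_local) \
        + eps * (lse_local - local.sum(dim=-1) / vocab)

# The post-exchange correction: replace each sample's owning-shard LSE with the
# global LSE, and likewise adjust the smoothing term.
lse = torch.logsumexp(lse_allgather, dim=0)
lse_owner = lse_allgather[labels // shard, torch.arange(batch)]
losses += (1 - eps) * (lse - lse_owner) + eps * (lse - lse_allgather.sum(dim=0))

assert torch.allclose(losses, ref)
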
6 changes: 5 additions & 1 deletion llmfoundry/models/mpt/modeling_mpt.py
@@ -972,7 +972,11 @@ def __init__(
         loss_fn_config = om_model_config.get('loss_fn', 'fused_crossentropy')
         if loss_fn_config == 'fused_crossentropy':
             try:
-                from flash_attn.losses.cross_entropy import \
+                # NOTE: The following is the original import statement from flash_attn library, which we have currently replaced with a copy pasted code from the same library's version 1.0.9. The reason is that using the CE loss from FA v2.3.2 results in an illegal memory access error at long sequence lengths (github issue: https://github.com/Dao-AILab/flash-attention/issues/714).
+                # from flash_attn.losses.cross_entropy import \
+                #     CrossEntropyLoss as FusedCrossEntropyLoss
+                # TODO: Once the problem with using FA v2's CE loss at longer sequence lengths is resolved (github issue: https://github.com/Dao-AILab/flash-attention/issues/714), revert back to directly importing the CE loss from FA library.
+                from llmfoundry.models.layers.cross_entropy_loss import \
                     CrossEntropyLoss as FusedCrossEntropyLoss
 
                 self.loss_fn = FusedCrossEntropyLoss(ignore_index=-100)
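
Context for the hunk above: the loss implementation is selected by the model config's loss_fn key, and this commit only swaps which module the fused implementation is imported from. The sketch below is a simplified, hypothetical rendering of that selection logic; the torch_crossentropy branch and the error handling are assumptions based on the surrounding file, not part of this diff, and the vendored FusedCrossEntropyLoss still requires the xentropy CUDA extension at construction time.

import torch.nn as nn


def build_loss_fn(loss_fn_config: str = 'fused_crossentropy') -> nn.Module:
    if loss_fn_config == 'fused_crossentropy':
        # Temporarily use the vendored copy of flash-attn v1.0.9's CE loss
        # instead of flash_attn.losses.cross_entropy (see issue #714 above).
        from llmfoundry.models.layers.cross_entropy_loss import \
            CrossEntropyLoss as FusedCrossEntropyLoss
        return FusedCrossEntropyLoss(ignore_index=-100)
    elif loss_fn_config == 'torch_crossentropy':
        return nn.CrossEntropyLoss(ignore_index=-100)
    raise ValueError(f'Unknown loss_fn config: {loss_fn_config}')
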
16 changes: 11 additions & 5 deletions tests/models/test_model.py
@@ -398,16 +398,22 @@ def test_determinism(attn_impl: str, precision: torch.dtype):
 
 
 @pytest.mark.gpu
-def test_loss_fn():
+@pytest.mark.parametrize('ce_loss_implementation',
+                         ['FA_v1_copied', 'FA_imported'])
+def test_loss_fn(ce_loss_implementation: str):
     """Tests the Fused CrossEntropy vs torch.nn.CrossEntropy loss function.
 
     We provide non-zero tolerances to account for small numerics differences
     between the two loss implementations.
     """
-    try:
-        from flash_attn.losses.cross_entropy import CrossEntropyLoss as FusedCrossEntropyLoss  # type: ignore # isort: skip
-    except:
-        pytest.skip('Fused cross entropy was not installed')
+    if ce_loss_implementation == 'FA_imported':
+        try:
+            from flash_attn.losses.cross_entropy import CrossEntropyLoss as FusedCrossEntropyLoss  # type: ignore # isort: skip
+        except:
+            pytest.skip('Fused cross entropy was not installed')
+    else:
+        from llmfoundry.models.layers.cross_entropy_loss import \
+            CrossEntropyLoss as FusedCrossEntropyLoss
 
     # run numerical test in pure fp32
     from torch.backends import cuda, cudnn
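
The parametrized test compares each fused implementation against torch.nn.CrossEntropyLoss with small non-zero tolerances, as its docstring says. The fused kernel needs a GPU, so the minimal sketch below only illustrates that tolerance-based comparison pattern, using a lower-precision torch computation as a stand-in for the second implementation; the tolerance values are illustrative, not the ones used in the test.

import torch
import torch.nn as nn

torch.manual_seed(0)
logits = torch.randn(8, 32000)
targets = torch.randint(32000, (8,))

ref = nn.CrossEntropyLoss()(logits, targets)                    # fp32 reference
approx = nn.CrossEntropyLoss()(logits.half().float(), targets)  # lower-precision stand-in

# Non-zero tolerances absorb small numerics differences between implementations.
torch.testing.assert_close(approx, ref, rtol=1e-2, atol=1e-2)
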
