-
Notifications
You must be signed in to change notification settings - Fork 538
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
7 changed files
with
302 additions
and
34 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,173 @@ | ||
# Copyright 2022 MosaicML LLM Foundry authors | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
# Copied from https://github.com/Dao-AILab/flash-attention/blob/f1a73d074002226c42ce65a1df170ecff9f022c0/flash_attn/losses/cross_entropy.py | ||
# type: ignore | ||
|
||
# Inspired by https://github.com/NVIDIA/apex/blob/master/apex/transformer/tensor_parallel/cross_entropy.py | ||
# But we make it much faster: we compute the local loss and the LSE, and by exchanging the LSE and | ||
# the losses we can get the global loss. There's no need to do it step by step | ||
# (compute local max, exchange, compute exp, compute local sum, exchange, etc.) | ||
# The original xentropy interface is here: https://github.com/NVIDIA/apex/blob/master/apex/contrib/xentropy/softmax_xentropy.py | ||
import torch | ||
import torch.nn as nn | ||
|
||
try: # This try...except is needed because hf transformers library requires it | ||
import xentropy_cuda_lib | ||
except Exception as e: | ||
xentropy_cuda_lib = None | ||
|
||
# `all_gather_into_tensor` and `reduce_scatter_tensor` are new placeholders for | ||
# `_all_gather_base` and `_reduce_scatter_base`. They require the most recent | ||
# version of PyTorch. The following 2 lines are for backward compatibility with | ||
# older PyTorch. | ||
if 'all_gather_into_tensor' not in dir(torch.distributed): | ||
torch.distributed.all_gather_into_tensor = torch.distributed._all_gather_base | ||
|
||
|
||
class SoftmaxCrossEntropyLossFn(torch.autograd.Function): | ||
|
||
@staticmethod | ||
def forward( | ||
ctx, | ||
logits, | ||
labels, | ||
smoothing=0.0, | ||
ignored_index=-100, | ||
inplace_backward=False, | ||
process_group=None, | ||
): | ||
"""The forward function for softmax cross entropy loss. | ||
logits: (batch, vocab_size) | ||
labels: (batch,) | ||
If process_group is not None, we're doing Tensor Parallel: each process is responsible for | ||
one part of the vocab. The loss needs to be aggregated across processes. | ||
""" | ||
batch, vocab_size = logits.shape | ||
assert labels.shape == (batch,) | ||
world_size = 1 if process_group is None else torch.distributed.get_world_size( | ||
process_group) | ||
ctx.total_classes = world_size * vocab_size | ||
if world_size == 1: | ||
losses, lse = xentropy_cuda_lib.forward(logits, labels, smoothing) | ||
losses.masked_fill_(labels == ignored_index, 0) | ||
labels_local = labels | ||
else: | ||
rank = torch.distributed.get_rank(process_group) | ||
vocab_start_index, vocab_end_index = rank * vocab_size, ( | ||
rank + 1) * vocab_size | ||
|
||
# Create a mask of valid vocab ids (1 means it needs to be masked). | ||
labels_mask = (labels < vocab_start_index) | (labels >= | ||
vocab_end_index) | ||
ignored_mask = labels == ignored_index | ||
labels_local = torch.where(ignored_mask, labels, | ||
labels - vocab_start_index) | ||
|
||
# For tensor parallel cross entropy with smoothing, we want to pass in the total number | ||
# of classes so that smoothing can be applied correctly. If total_classes=-1, use the | ||
# last dimension of the input tensor. | ||
losses, lse_local = xentropy_cuda_lib.forward( | ||
logits, labels_local, smoothing, world_size * vocab_size) | ||
assert lse_local.shape == (batch,) | ||
assert losses.shape == (batch,) | ||
losses.masked_fill_(ignored_mask, 0) | ||
# For labels == ignored_index, the loss is always 0. | ||
# If there's no smoothing, if labels are in the vocab of this partition, losses contains | ||
# lse_local - predicted logit, and 0 otherwise. | ||
# If there's smoothing=0.1, for labels in the vocab of this partition, losses contains | ||
# 0.9 * (lse_local - predicted logit) + 0.1 * (lse_local - sum logit / total_classes) | ||
# For labels not in the vocab of this partition, losses contains | ||
# 0.1 * (lse_local - sum logit / total_classes). | ||
|
||
lse_allgather = torch.empty(world_size, | ||
batch, | ||
dtype=lse_local.dtype, | ||
device=lse_local.device) | ||
torch.distributed.all_gather_into_tensor(lse_allgather, | ||
lse_local.contiguous(), | ||
group=process_group) | ||
handle_losses = torch.distributed.all_reduce( | ||
losses, | ||
op=torch.distributed.ReduceOp.SUM, | ||
group=process_group, | ||
async_op=True) | ||
lse = torch.logsumexp(lse_allgather, dim=0) | ||
# If there's no smoothing, the total losses are lse_local - predicted_logit, | ||
# we just have to subtract the lse_local and add the lse (global). | ||
# If there's smoothing=0.1, the total losses are | ||
# 0.9 * (lse_local - predicted_logit) + 0.1 * (sum of all lse_local - sum logit / total_classes) | ||
# We want 0.9 * (lse - predicted_logit) + 0.1 * (lse - sum logit / total_classes). | ||
rank_per_sample = torch.div(labels, | ||
vocab_size, | ||
rounding_mode='floor') | ||
lse_local = lse_allgather[ | ||
rank_per_sample, | ||
torch.arange(batch, device=lse_allgather.device)] | ||
|
||
handle_losses.wait() | ||
if smoothing == 0.0: | ||
losses += lse - lse_local | ||
else: | ||
losses += (1 - smoothing) * (lse - lse_local) + smoothing * ( | ||
lse - lse_allgather.sum(dim=0)) | ||
losses.masked_fill_(ignored_mask, 0) | ||
|
||
ctx.save_for_backward(logits, lse, labels_local) | ||
ctx.smoothing = smoothing | ||
ctx.ignored_index = ignored_index | ||
ctx.inplace_backward = inplace_backward | ||
return losses | ||
|
||
@staticmethod | ||
def backward(ctx, grad_loss): | ||
logits, lse, labels = ctx.saved_tensors | ||
grad_loss = grad_loss.contiguous() | ||
grad_loss.masked_fill_(labels == ctx.ignored_index, 0) | ||
grad_logits = xentropy_cuda_lib.backward(grad_loss, logits, lse, labels, | ||
ctx.smoothing, | ||
ctx.inplace_backward, | ||
ctx.total_classes) | ||
return grad_logits, None, None, None, None, None, None | ||
|
||
|
||
class CrossEntropyLoss(nn.Module): | ||
|
||
def __init__( | ||
self, | ||
ignore_index=-100, | ||
reduction='mean', | ||
label_smoothing=0.0, | ||
inplace_backward=False, | ||
process_group=None, | ||
): | ||
super().__init__() | ||
if xentropy_cuda_lib is None: | ||
raise ValueError( | ||
'xentropy_cuda_lib is None, probably because importing xentropy_cuda_lib failed.' | ||
) | ||
if reduction not in ['mean', 'none']: | ||
raise NotImplementedError( | ||
"Only support reduction = 'mean' or 'none'") | ||
self.ignore_index = ignore_index | ||
self.reduction = reduction | ||
self.label_smoothing = label_smoothing | ||
self.inplace_backward = inplace_backward | ||
self.process_group = process_group | ||
|
||
def forward(self, input, target): | ||
assert input.is_cuda and target.is_cuda | ||
# SoftmaxCrossEntropyLoss implicitly casts to float | ||
loss = SoftmaxCrossEntropyLossFn.apply( | ||
input, | ||
target, | ||
self.label_smoothing, | ||
self.ignore_index, | ||
self.inplace_backward, | ||
self.process_group, | ||
) | ||
if self.reduction == 'mean': | ||
return loss.sum() / (target != self.ignore_index).sum() | ||
else: | ||
return loss |
Oops, something went wrong.