Add Parallel Cross Entropy #2017

Merged
2 changes: 1 addition & 1 deletion optimum/fx/parallelization/decomp.py
@@ -197,7 +197,7 @@ def run(self, *args, **kwargs):
def decompose_and_functionalize(
graph_module: GraphModule,
decomposition_table: Dict[torch._ops.OperatorBase, Callable] = core_aten_decompositions(),
leaf_function_targets: List[Callable] = [F.scaled_dot_product_attention],
leaf_function_targets: List[Callable] = [F.scaled_dot_product_attention, F.cross_entropy],
) -> Callable:
"""
API to decompose and functionalize a high-level graph module.
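Keeping `F.cross_entropy` in `leaf_function_targets` means the decomposition pass leaves the call as a single node instead of breaking it into log-softmax/NLL sub-ops, so later passes can still recognize and replace it. A minimal illustration of that single-node representation in torch.fx (the toy module below is illustrative, not part of the PR):

```python
# Sketch: torch.fx records F.cross_entropy as one call_function node; keeping it
# as a leaf target preserves that node through decomposition so it can later be
# matched and swapped for the vocab-parallel loss.
import torch
import torch.nn.functional as F
from torch.fx import symbolic_trace

class TinyHead(torch.nn.Module):
    def forward(self, logits, labels):
        return F.cross_entropy(logits, labels)

gm = symbolic_trace(TinyHead())
print([(n.op, n.target) for n in gm.graph.nodes])
# -> includes ('call_function', <function cross_entropy ...>)
```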
35 changes: 27 additions & 8 deletions optimum/fx/parallelization/op_registry/op_handlers.py
@@ -19,7 +19,7 @@
from torch.fx import Node

from ..core import Config
from ..utils import is_activation, is_embedding, is_linear
from ..utils import is_activation, is_cross_entropy, is_cross_entropy_parallel_compatible, is_embedding, is_linear


class Registry:
@@ -334,7 +334,16 @@ def propagate(self) -> List[int]:
ndim = arg.meta["val"].ndim
slice_dim = (slice_dim + ndim) % ndim
if slice_dim == axis:
# slice on the parallel axis is not allowed
# slice on the parallel axis is not allowed, except when the slice is a no-op
start, stop, step = 0, arg.meta["val"].shape[axis], 1
if len(self.node.args) > 2:
start = self.node.args[2]
if len(self.node.args) > 3:
stop = self.node.args[3]
if len(self.node.args) > 4:
step = self.node.args[4]
if start == 0 and stop >= arg.meta["val"].shape[axis] and step == 1:
return [axis]
return []
return [axis]

@@ -404,12 +413,12 @@ def propagate(self) -> List[int]:
if self.node.op in ["placeholder", "get_attr"]:
return [None]
elif self.node.op == "output":
for node in self.node.all_input_nodes:
# TODO: allow parallelized nodes in output, and append comm ops in graph to all-gather
# parallelized output if instructed
if self.extract_axis(node) is not None:
return []
return [None]
# we don't check whether the output is parallelized right now, because if the output is a loss
# produced by sharded cross entropy, it is already unsharded.
# TODO: append all-gather comm ops before all parallelized output nodes if instructed.
input_arg = self.node.all_input_nodes[0]
axis = self.extract_axis(input_arg)
return [axis]
elif is_linear(self.node):
input_arg = self.node.all_input_nodes[0]
axis = self.extract_axis(input_arg)
@@ -438,6 +447,16 @@ def propagate(self) -> List[int]:
return [1, None] if self.config.enable_sequence_parallel else [None]
else:
return []
elif is_cross_entropy(self.node):
logits = self.node.all_input_nodes[0]
axis = self.extract_axis(logits)
if axis is None or (
is_cross_entropy_parallel_compatible(self.node) and axis == logits.meta["val"].ndim - 1
):
# for cross entropy, the input logits parallel axis can only be the last axis or None
return [None]
else:
return []
elif is_activation(self.node):
return UnaryOpParallelAxisPropagateHandler(self.node, self.meta_key, self.config).propagate()

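The no-op-slice rule above can be stated on its own: a slice along the parallel axis is only accepted when it keeps the entire axis. A small standalone sketch (the helper name is illustrative, not part of the PR):

```python
from typing import Optional

# A slice on the sharded axis only preserves the parallel layout when it covers
# the whole axis: start == 0, stop reaches (or exceeds) the size, and step == 1.
def slice_is_nop(dim_size: int, start: int = 0, stop: Optional[int] = None, step: int = 1) -> bool:
    if stop is None:
        stop = dim_size
    return start == 0 and stop >= dim_size and step == 1

assert slice_is_nop(1024)                # x[:] -> parallel axis can propagate
assert slice_is_nop(1024, 0, 2**31 - 1)  # "slice to the end" with an oversized stop, still a no-op
assert not slice_is_nop(1024, 0, 512)    # real slice on the sharded axis -> rejected
```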
1 change: 1 addition & 0 deletions optimum/fx/parallelization/parallel_layers/__init__.py
@@ -14,3 +14,4 @@
# limitations under the License.
from .embedding import VocabParallelEmbedding
from .linear import ColumnParallelLinear, RowParallelLinear
from .loss import VocabParallelCrossEntropyLoss, sharded_cross_entropy_wrapper_fn
163 changes: 163 additions & 0 deletions optimum/fx/parallelization/parallel_layers/loss.py
@@ -0,0 +1,163 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import wraps
from typing import Optional

import torch
import torch.distributed as dist
import torch.nn as nn

from ..core import ParallelExecutionCtx


# Adapted from https://github.com/huggingface/nanotron/blob/main/src/nanotron/parallel/tensor_parallel/functional.py
class _ShardedCrossEntropy(torch.autograd.Function):
@staticmethod
def forward(
ctx,
sharded_logits: torch.Tensor, # (batch_size, length, sharded_hidden_size)
target: torch.Tensor, # (batch_size, length)
group: dist.ProcessGroup,
):
# Maximum value along last dimension across all GPUs.
logits_max = torch.max(sharded_logits, dim=-1)[0]
dist.all_reduce(logits_max, op=dist.ReduceOp.MAX, group=group)
# Subtract the maximum value.
sharded_logits = sharded_logits - logits_max.unsqueeze(dim=-1)

# Get the shard's indices
sharded_hidden_size = sharded_logits.shape[-1]
rank = dist.get_rank(group)
start_index = rank * sharded_hidden_size
end_index = start_index + sharded_hidden_size

# Create a mask for target ids that fall outside this shard's range (True means masked out).
target_mask = (target < start_index) | (target >= end_index)
masked_target = target.clone() - start_index
masked_target[target_mask] = 0

# Get predicted-logits = logits[target].
# For simplicity, we convert logits to a 2-D tensor with size
# [*, shard-size] and target to a 1-D tensor of size [*].
logits_2d = sharded_logits.view(-1, sharded_hidden_size)
masked_target_1d = masked_target.view(-1)
arange_1d = torch.arange(start=0, end=logits_2d.shape[0], device=logits_2d.device)
predicted_logits_1d = logits_2d[arange_1d, masked_target_1d]
if predicted_logits_1d.is_contiguous():
predicted_logits_1d = predicted_logits_1d.clone()
else:
predicted_logits_1d = predicted_logits_1d.contiguous()
predicted_logits = predicted_logits_1d.view_as(target)
predicted_logits[target_mask] = 0.0
# All reduce is needed to get the chunks from other GPUs.
dist.all_reduce(predicted_logits, op=dist.ReduceOp.SUM, group=group)

# Sum of exponential of logits along vocab dimension across all GPUs.
exp_logits = sharded_logits
torch.exp(sharded_logits, out=exp_logits)
sum_exp_logits = exp_logits.sum(dim=-1)
dist.all_reduce(sum_exp_logits, op=dist.ReduceOp.SUM, group=group)

# Loss = log(sum(exp(logits))) - predicted-logit.
loss = torch.log(sum_exp_logits) - predicted_logits

# Normalize and optionally smooth logits
exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1))

# Store softmax, target-mask and masked-target for backward pass.
ctx.save_for_backward(exp_logits, target_mask, masked_target_1d)

return loss

@staticmethod
def backward(ctx, grad_output: torch.Tensor):
# Retrieve tensors from the forward path.
softmax, target_mask, masked_target_1d = ctx.saved_tensors

# All the inputs have softmax as their gradient.
grad_input = softmax
# For simplicity, work with the 2D gradient.
sharded_hidden_size = softmax.size()[-1]
grad_2d = grad_input.view(-1, sharded_hidden_size)

# Add the gradient from matching classes.
arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=grad_2d.device)
grad_2d[arange_1d, masked_target_1d] -= 1.0 - target_mask.view(-1).float()

# Finally elementwise multiplication with the output gradients.
grad_input.mul_(grad_output.unsqueeze(dim=-1))

return grad_input, None, None


def sharded_cross_entropy(sharded_logits: torch.Tensor, target: torch.Tensor, process_group: dist.ProcessGroup):
return _ShardedCrossEntropy.apply(sharded_logits, target, process_group)


def sharded_cross_entropy_wrapper_fn(process_group: dist.ProcessGroup):
@wraps(sharded_cross_entropy)
def wrapper(
sharded_logits: torch.Tensor,
target: torch.Tensor,
weight: Optional[torch.Tensor] = None,
size_average: Optional[bool] = None,
ignore_index: int = -100,
reduce: Optional[bool] = None,
reduction: str = "mean",
label_smoothing: float = 0.0,
):
if weight is not None or ignore_index != -100 or label_smoothing != 0.0:
raise ValueError(
"Does not support weighted mode, index ignoring and label smoothing in current parallel cross entropy implementation."
)
loss: torch.Tensor = sharded_cross_entropy(sharded_logits, target, process_group)

if size_average is not None or reduce is not None:
size_average = True if size_average is None else size_average
reduce = True if reduce is None else reduce

if size_average and reduce:
reduction = "mean"
elif reduce:
reduction = "sum"
else:
reduction = "none"

if reduction == "mean":
return loss.mean()
elif reduction == "sum":
return loss.sum()
return loss

return wrapper


class VocabParallelCrossEntropyLoss(nn.Module):
"""
A simple vocab-parallel cross entropy implementation which does not yet support weighted mode, ignore_index, or label smoothing.
"""

def __init__(self, ctx: ParallelExecutionCtx, reduction: str = "mean") -> None:
super(VocabParallelCrossEntropyLoss, self).__init__()
self.process_group = ctx.tp_group
self.reduction = reduction

def forward(self, sharded_logits: torch.Tensor, target: torch.Tensor):
loss: torch.Tensor = _ShardedCrossEntropy.apply(sharded_logits, target, self.process_group)
if self.reduction == "mean":
return loss.mean()
elif self.reduction == "sum":
return loss.sum()
return loss
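For reference, the math implemented by `_ShardedCrossEntropy` collapses to ordinary cross entropy when there is a single shard: the forward computes `log(sum(exp(logits))) - logits[target]` and the backward produces `softmax(logits) - one_hot(target)`. A single-process sanity sketch (not part of the PR; no process group needed):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(4, 10, requires_grad=True)   # (batch, vocab); a single "shard"
target = torch.randint(0, 10, (4,))

# Forward: loss = log(sum(exp(logits))) - logits[target], with max-subtraction for stability.
shifted = logits - logits.max(dim=-1, keepdim=True).values
predicted = shifted.gather(1, target.unsqueeze(1)).squeeze(1)
loss = torch.log(torch.exp(shifted).sum(dim=-1)) - predicted
torch.testing.assert_close(loss, F.cross_entropy(logits, target, reduction="none"))

# Backward: d loss / d logits = softmax(logits) - one_hot(target), which is what
# the backward above computes shard-locally before the target-mask correction.
loss.sum().backward()
with torch.no_grad():
    expected_grad = torch.softmax(logits, dim=-1) - F.one_hot(target, num_classes=10).float()
torch.testing.assert_close(logits.grad, expected_grad)
```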
45 changes: 44 additions & 1 deletion optimum/fx/parallelization/passes.py
@@ -26,8 +26,15 @@
from .decomp import decompose_and_functionalize
from .distributed import scatter
from .op_registry import REGISTRY, FallbackParallelAxisPropagateHandler
from .parallel_layers import ColumnParallelLinear, RowParallelLinear, VocabParallelEmbedding
from .parallel_layers import (
ColumnParallelLinear,
RowParallelLinear,
VocabParallelCrossEntropyLoss,
VocabParallelEmbedding,
sharded_cross_entropy_wrapper_fn,
)
from .utils import (
is_cross_entropy,
is_embedding,
is_linear,
is_shape_consumer,
@@ -273,6 +280,11 @@ def run(self, graph_module: GraphModule, ctx: ParallelExecutionCtx, config: Conf
info["sequence_parallel"] = False
self.place_marker_per_node(node, info)

elif is_cross_entropy(node):
axis_before = ParallelAxisSolverPass.get_stored_field_info(node.args[0], "parallel_axis")
if axis_before is not None:
self.place_marker_per_node(node, {"axis": "vocab"})

return graph_module


@@ -343,6 +355,35 @@ def handle_embedding(node: Node, ctx: ParallelExecutionCtx) -> None:
layer_cache[key] = new_mod
setattr(parent_mod, field, new_mod)

@staticmethod
def handle_cross_entropy(node: Node, ctx: ParallelExecutionCtx) -> None:
axis = ParallelLayerAnnotatePass.get_stored_field_info(node, field="axis")
if axis is None:
return

assert axis in {"vocab"}, "Only support parallelization on vocab dim for now."
if node.op == "call_module":
graph_module = node.graph.owning_module
prefix_and_field = node.target.rsplit(".", maxsplit=1)
if len(prefix_and_field) == 2:
parent_mod = graph_module.get_submodule(prefix_and_field[0])
field = prefix_and_field[1]
else:
parent_mod = graph_module
field = node.target

mod: nn.CrossEntropyLoss = graph_module.get_submodule(node.target)
key, layer_cache = node.target, ctx.parallel_layer_cache
if key in layer_cache:
new_mod = layer_cache[key]
else:
assert ctx.compile_times == 0, "illegal path for recompilation"
new_mod = VocabParallelCrossEntropyLoss(ctx, reduction=mod.reduction)
layer_cache[key] = new_mod
setattr(parent_mod, field, new_mod)
else:
node.target = sharded_cross_entropy_wrapper_fn(process_group=ctx.tp_group)

@staticmethod
def handle_hard_coded_axis_param(node: Node, ctx: ParallelExecutionCtx) -> None:
def extract_shape_from_node(node: Node) -> List[Any]:
@@ -384,6 +425,8 @@ def run(self, graph_module: GraphModule, ctx: ParallelExecutionCtx, config: Conf
self.handle_linear(node, ctx)
elif is_embedding(node):
self.handle_embedding(node, ctx)
elif is_cross_entropy(node):
self.handle_cross_entropy(node, ctx)
# correct the attention head num in parallel setting
elif is_shape_consumer(node):
self.handle_hard_coded_axis_param(node, ctx)
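One detail worth noting in `handle_cross_entropy`: for `call_module` nodes the fully qualified target is split into a parent path and an attribute name so the `nn.CrossEntropyLoss` leaf can be swapped in place, while `call_function` nodes simply get their target rebound to the sharded wrapper. A tiny sketch of the path-splitting step (the target string is made up):

```python
# An fx `call_module` target like "model.loss_fn" is split into the parent module
# path and the attribute name, so setattr(parent, field, new_mod) can replace the
# leaf module in place.
target = "model.loss_fn"
prefix_and_field = target.rsplit(".", maxsplit=1)
if len(prefix_and_field) == 2:
    parent_path, field = prefix_and_field        # ("model", "loss_fn")
else:
    parent_path, field = "", target              # module hangs directly off the root graph module
print(parent_path, field)                        # -> model loss_fn
```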
34 changes: 34 additions & 0 deletions optimum/fx/parallelization/utils.py
@@ -82,6 +82,40 @@ def is_shape_generator(node: Node) -> bool:
return node.op == "call_method" and node.target == "size"


def is_cross_entropy(node: Node) -> bool:
if node.op == "call_function":
return node.target is F.cross_entropy
elif node.op == "call_module":
mod = node.graph.owning_module
return isinstance(mod.get_submodule(node.target), nn.CrossEntropyLoss)
return False


def is_cross_entropy_parallel_compatible(node: Node) -> bool:
"""
For now `VocabParallelCrossEntropyLoss` does not support weighted mode, ignore_index, or label smoothing.
"""
if node.op == "call_function":
weight = node.kwargs.get("weight", None)
ignore_index = node.kwargs.get("ignore_index", -100)
label_smoothing = node.kwargs.get("label_smoothing", 0.0)
if len(node.args) > 2 and weight is None:
weight = node.args[2]
if len(node.args) > 4 and ignore_index == -100:
ignore_index = node.args[4]
if len(node.args) > 7 and label_smoothing == 0.0:
label_smoothing = node.args[7]

return weight is None and ignore_index == -100 and label_smoothing == 0.0

elif node.op == "call_module":
mod: nn.CrossEntropyLoss = node.graph.owning_module.get_submodule(node.target)
weight, label_smoothing, ignore_index = mod.weight, mod.label_smoothing, mod.ignore_index
return weight is None and ignore_index == -100 and label_smoothing == 0.0

return False


def stable_topological_sort(graph: Graph):
def _args(n: torch.fx.Node) -> List[torch.fx.node.Argument]:
args: List[torch.fx.node.Argument] = []
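As a rough illustration of what `is_cross_entropy_parallel_compatible` accepts, only calls that stick to the defaults can be routed to the sharded loss (the tensors below are made up for the example):

```python
import torch
import torch.nn.functional as F

logits = torch.randn(8, 32000)            # (tokens, vocab)
target = torch.randint(0, 32000, (8,))

# Compatible: default arguments only, so the solver may keep the logits sharded
# along the vocab axis and let the pass swap in the sharded loss.
F.cross_entropy(logits, target)

# Not compatible: class weights, a custom ignore_index, or label smoothing keep
# the node on the stock implementation with unsharded logits.
F.cross_entropy(logits, target, label_smoothing=0.1)
```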