diff --git a/optimum/fx/parallelization/decomp.py b/optimum/fx/parallelization/decomp.py
index 26258d451bf..5410818e929 100644
--- a/optimum/fx/parallelization/decomp.py
+++ b/optimum/fx/parallelization/decomp.py
@@ -197,7 +197,7 @@ def run(self, *args, **kwargs):
 def decompose_and_functionalize(
     graph_module: GraphModule,
     decomposition_table: Dict[torch._ops.OperatorBase, Callable] = core_aten_decompositions(),
-    leaf_function_targets: List[Callable] = [F.scaled_dot_product_attention],
+    leaf_function_targets: List[Callable] = [F.scaled_dot_product_attention, F.cross_entropy],
 ) -> Callable:
     """
     API to decompose and functionalize a high-level graph module.
diff --git a/optimum/fx/parallelization/op_registry/op_handlers.py b/optimum/fx/parallelization/op_registry/op_handlers.py
index 56b8fc16bc0..4a9c55e3764 100644
--- a/optimum/fx/parallelization/op_registry/op_handlers.py
+++ b/optimum/fx/parallelization/op_registry/op_handlers.py
@@ -19,7 +19,7 @@
 from torch.fx import Node
 
 from ..core import Config
-from ..utils import is_activation, is_embedding, is_linear
+from ..utils import is_activation, is_cross_entropy, is_cross_entropy_parallel_compatible, is_embedding, is_linear
 
 
 class Registry:
@@ -334,7 +334,16 @@ def propagate(self) -> List[int]:
         ndim = arg.meta["val"].ndim
         slice_dim = (slice_dim + ndim) % ndim
         if slice_dim == axis:
-            # slice on the parallel axis is not allowed
+            # slice on the parallel axis is not allowed, except when it is a no-op
+            start, stop, step = 0, arg.meta["val"].shape[axis], 1
+            if len(self.node.args) > 2:
+                start = self.node.args[2]
+            if len(self.node.args) > 3:
+                stop = self.node.args[3]
+            if len(self.node.args) > 4:
+                step = self.node.args[4]
+            if start == 0 and stop >= arg.meta["val"].shape[axis] and step == 1:
+                return [axis]
             return []
         return [axis]
 
@@ -404,12 +413,12 @@ def propagate(self) -> List[int]:
         if self.node.op in ["placeholder", "get_attr"]:
             return [None]
         elif self.node.op == "output":
-            for node in self.node.all_input_nodes:
-                # TODO: allow parallelized nodes in output, and append comm ops in graph tp all-gather
-                # parallelized output if intructed
-                if self.extract_axis(node) is not None:
-                    return []
-            return [None]
+            # We don't check whether the output is parallelized here, because if the output is the loss,
+            # it must not be parallelized as long as it comes from a sharded cross entropy.
+            # TODO: append all-gather comm ops before all parallelized output nodes if instructed.
+            input_arg = self.node.all_input_nodes[0]
+            axis = self.extract_axis(input_arg)
+            return [axis]
         elif is_linear(self.node):
             input_arg = self.node.all_input_nodes[0]
             axis = self.extract_axis(input_arg)
@@ -438,6 +447,16 @@ def propagate(self) -> List[int]:
                 return [1, None] if self.config.enable_sequence_parallel else [None]
             else:
                 return []
+        elif is_cross_entropy(self.node):
+            logits = self.node.all_input_nodes[0]
+            axis = self.extract_axis(logits)
+            if axis is None or (
+                is_cross_entropy_parallel_compatible(self.node) and axis == logits.meta["val"].ndim - 1
+            ):
+                # for cross entropy, the parallel axis of the input logits can only be the last axis or None
+                return [None]
+            else:
+                return []
         elif is_activation(self.node):
             return UnaryOpParallelAxisPropagateHandler(self.node, self.meta_key, self.config).propagate()
 
diff --git a/optimum/fx/parallelization/parallel_layers/__init__.py b/optimum/fx/parallelization/parallel_layers/__init__.py
index 9bfb13afdf6..474ae7f7eef 100644
--- a/optimum/fx/parallelization/parallel_layers/__init__.py
+++ b/optimum/fx/parallelization/parallel_layers/__init__.py
@@ -14,3 +14,4 @@
 # limitations under the License.
 from .embedding import VocabParallelEmbedding
 from .linear import ColumnParallelLinear, RowParallelLinear
+from .loss import VocabParallelCrossEntropyLoss, sharded_cross_entropy_wrapper_fn
diff --git a/optimum/fx/parallelization/parallel_layers/loss.py b/optimum/fx/parallelization/parallel_layers/loss.py
new file mode 100644
index 00000000000..0a11e33c08e
--- /dev/null
+++ b/optimum/fx/parallelization/parallel_layers/loss.py
@@ -0,0 +1,163 @@
+# coding=utf-8
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from functools import wraps
+from typing import Optional
+
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+
+from ..core import ParallelExecutionCtx
+
+
+# Adapted from https://github.com/huggingface/nanotron/blob/main/src/nanotron/parallel/tensor_parallel/functional.py
+class _ShardedCrossEntropy(torch.autograd.Function):
+    @staticmethod
+    def forward(
+        ctx,
+        sharded_logits: torch.Tensor,  # (batch_size, length, sharded_hidden_size)
+        target: torch.Tensor,  # (batch_size, length)
+        group: dist.ProcessGroup,
+    ):
+        # Maximum value along last dimension across all GPUs.
+        logits_max = torch.max(sharded_logits, dim=-1)[0]
+        dist.all_reduce(logits_max, op=dist.ReduceOp.MAX, group=group)
+        # Subtract the maximum value.
+        sharded_logits = sharded_logits - logits_max.unsqueeze(dim=-1)
+
+        # Get the shard's indices
+        sharded_hidden_size = sharded_logits.shape[-1]
+        rank = dist.get_rank(group)
+        start_index = rank * sharded_hidden_size
+        end_index = start_index + sharded_hidden_size
+
+        # Create a mask of target ids that fall outside this shard (True means it needs to be masked).
+        target_mask = (target < start_index) | (target >= end_index)
+        masked_target = target.clone() - start_index
+        masked_target[target_mask] = 0
+
+        # Get predicted-logits = logits[target].
+        # For simplicity, we convert logits to a 2-D tensor with size
+        # [*, shard-size] and target to a 1-D tensor of size [*].
+        logits_2d = sharded_logits.view(-1, sharded_hidden_size)
+        masked_target_1d = masked_target.view(-1)
+        arange_1d = torch.arange(start=0, end=logits_2d.shape[0], device=logits_2d.device)
+        predicted_logits_1d = logits_2d[arange_1d, masked_target_1d]
+        if predicted_logits_1d.is_contiguous():
+            predicted_logits_1d = predicted_logits_1d.clone()
+        else:
+            predicted_logits_1d = predicted_logits_1d.contiguous()
+        predicted_logits = predicted_logits_1d.view_as(target)
+        predicted_logits[target_mask] = 0.0
+        # All reduce is needed to get the chunks from other GPUs.
+        dist.all_reduce(predicted_logits, op=dist.ReduceOp.SUM, group=group)
+
+        # Sum of exponential of logits along vocab dimension across all GPUs.
+        exp_logits = sharded_logits
+        torch.exp(sharded_logits, out=exp_logits)
+        sum_exp_logits = exp_logits.sum(dim=-1)
+        dist.all_reduce(sum_exp_logits, op=dist.ReduceOp.SUM, group=group)
+
+        # Loss = log(sum(exp(logits))) - predicted-logit.
+        loss = torch.log(sum_exp_logits) - predicted_logits
+
+        # Normalize the logits to recover the softmax, which is needed for the backward pass.
+        exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1))
+
+        # Store softmax, target-mask and masked-target for backward pass.
+        ctx.save_for_backward(exp_logits, target_mask, masked_target_1d)
+
+        return loss
+
+    @staticmethod
+    def backward(ctx, grad_output: torch.Tensor):
+        # Retrieve tensors from the forward path.
+        softmax, target_mask, masked_target_1d = ctx.saved_tensors
+
+        # All the inputs have softmax as their gradient.
+        grad_input = softmax
+        # For simplicity, work with the 2D gradient.
+        sharded_hidden_size = softmax.size()[-1]
+        grad_2d = grad_input.view(-1, sharded_hidden_size)
+
+        # Add the gradient from matching classes.
+        arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=grad_2d.device)
+        grad_2d[arange_1d, masked_target_1d] -= 1.0 - target_mask.view(-1).float()
+
+        # Finally elementwise multiplication with the output gradients.
+        grad_input.mul_(grad_output.unsqueeze(dim=-1))
+
+        return grad_input, None, None
+
+
+def sharded_cross_entropy(sharded_logits: torch.Tensor, target: torch.Tensor, process_group: dist.ProcessGroup):
+    return _ShardedCrossEntropy.apply(sharded_logits, target, process_group)
+
+
+def sharded_cross_entropy_wrapper_fn(process_group: dist.ProcessGroup):
+    @wraps(sharded_cross_entropy)
+    def wrapper(
+        sharded_logits: torch.Tensor,
+        target: torch.Tensor,
+        weight: Optional[torch.Tensor] = None,
+        size_average: Optional[bool] = None,
+        ignore_index: int = -100,
+        reduce: Optional[bool] = None,
+        reduction: str = "mean",
+        label_smoothing: float = 0.0,
+    ):
+        if weight is not None or ignore_index != -100 or label_smoothing != 0.0:
+            raise ValueError(
+                "Weighted mode, ignore_index and label smoothing are not supported by the current parallel cross entropy implementation."
+            )
+        loss: torch.Tensor = sharded_cross_entropy(sharded_logits, target, process_group)
+
+        if size_average is not None or reduce is not None:
+            size_average = True if size_average is None else size_average
+            reduce = True if reduce is None else reduce
+
+            if size_average and reduce:
+                reduction = "mean"
+            elif reduce:
+                reduction = "sum"
+            else:
+                reduction = "none"
+
+        if reduction == "mean":
+            return loss.mean()
+        elif reduction == "sum":
+            return loss.sum()
+        return loss
+
+    return wrapper
+
+
+class VocabParallelCrossEntropyLoss(nn.Module):
+    """
+    Simple parallel cross entropy implementation which does not support weighted mode, ignore_index or label smoothing yet.
+    """
+
+    def __init__(self, ctx: ParallelExecutionCtx, reduction: str = "mean") -> None:
+        super(VocabParallelCrossEntropyLoss, self).__init__()
+        self.process_group = ctx.tp_group
+        self.reduction = reduction
+
+    def forward(self, sharded_logits: torch.Tensor, target: torch.Tensor):
+        loss: torch.Tensor = _ShardedCrossEntropy.apply(sharded_logits, target, self.process_group)
+        if self.reduction == "mean":
+            return loss.mean()
+        elif self.reduction == "sum":
+            return loss.sum()
+        return loss
diff --git a/optimum/fx/parallelization/passes.py b/optimum/fx/parallelization/passes.py
index 14b652fff73..90155263281 100644
--- a/optimum/fx/parallelization/passes.py
+++ b/optimum/fx/parallelization/passes.py
@@ -26,8 +26,15 @@
 from .decomp import decompose_and_functionalize
 from .distributed import scatter
 from .op_registry import REGISTRY, FallbackParallelAxisPropagateHandler
-from .parallel_layers import ColumnParallelLinear, RowParallelLinear, VocabParallelEmbedding
+from .parallel_layers import (
+    ColumnParallelLinear,
+    RowParallelLinear,
+    VocabParallelCrossEntropyLoss,
+    VocabParallelEmbedding,
+    sharded_cross_entropy_wrapper_fn,
+)
 from .utils import (
+    is_cross_entropy,
     is_embedding,
     is_linear,
     is_shape_consumer,
@@ -273,6 +280,11 @@ def run(self, graph_module: GraphModule, ctx: ParallelExecutionCtx, config: Conf
                     info["sequence_parallel"] = False
                     self.place_marker_per_node(node, info)
 
+            elif is_cross_entropy(node):
+                axis_before = ParallelAxisSolverPass.get_stored_field_info(node.args[0], "parallel_axis")
+                if axis_before is not None:
+                    self.place_marker_per_node(node, {"axis": "vocab"})
+
         return graph_module
 
 
@@ -343,6 +355,35 @@ def handle_embedding(node: Node, ctx: ParallelExecutionCtx) -> None:
             layer_cache[key] = new_mod
         setattr(parent_mod, field, new_mod)
 
+    @staticmethod
+    def handle_cross_entropy(node: Node, ctx: ParallelExecutionCtx) -> None:
+        axis = ParallelLayerAnnotatePass.get_stored_field_info(node, field="axis")
+        if axis is None:
+            return
+
+        assert axis in {"vocab"}, "Only support parallelization on vocab dim for now."
+ if node.op == "call_module": + graph_module = node.graph.owning_module + prefix_and_field = node.target.rsplit(".", maxsplit=1) + if len(prefix_and_field) == 2: + parent_mod = graph_module.get_submodule(prefix_and_field[0]) + field = prefix_and_field[1] + else: + parent_mod = graph_module + field = node.target + + mod: nn.CrossEntropyLoss = graph_module.get_submodule(node.target) + key, layer_cache = node.target, ctx.parallel_layer_cache + if key in layer_cache: + new_mod = layer_cache[key] + else: + assert ctx.compile_times == 0, "illegal path for recompilation" + new_mod = VocabParallelCrossEntropyLoss(ctx, reduction=mod.reduction) + layer_cache[key] = new_mod + setattr(parent_mod, field, new_mod) + else: + node.target = sharded_cross_entropy_wrapper_fn(process_group=ctx.tp_group) + @staticmethod def handle_hard_coded_axis_param(node: Node, ctx: ParallelExecutionCtx) -> None: def extract_shape_from_node(node: Node) -> List[Any]: @@ -384,6 +425,8 @@ def run(self, graph_module: GraphModule, ctx: ParallelExecutionCtx, config: Conf self.handle_linear(node, ctx) elif is_embedding(node): self.handle_embedding(node, ctx) + elif is_cross_entropy(node): + self.handle_cross_entropy(node, ctx) # correct the attention head num in parallel setting elif is_shape_consumer(node): self.handle_hard_coded_axis_param(node, ctx) diff --git a/optimum/fx/parallelization/utils.py b/optimum/fx/parallelization/utils.py index b7b1ccd41c8..3074638737f 100644 --- a/optimum/fx/parallelization/utils.py +++ b/optimum/fx/parallelization/utils.py @@ -82,6 +82,40 @@ def is_shape_generator(node: Node) -> bool: return node.op == "call_method" and node.target == "size" +def is_cross_entropy(node: Node) -> bool: + if node.op == "call_function": + return node.target is F.cross_entropy + elif node.op == "call_module": + mod = node.graph.owning_module + return isinstance(mod.get_submodule(node.target), nn.CrossEntropyLoss) + return False + + +def is_cross_entropy_parallel_compatible(node: Node) -> bool: + """ + For now `VocabParallelCrossEntropyLoss` does not support weighted mode, index ignoring and label smoothing. 
+ """ + if node.op == "call_function": + weight = node.kwargs.get("weight", None) + ignore_index = node.kwargs.get("ignore_index", -100) + label_smoothing = node.kwargs.get("label_smoothing", 0.0) + if len(node.args) > 2 and weight is None: + weight = node.args[2] + if len(node.args) > 4 and ignore_index == -100: + ignore_index = node.args[4] + if len(node.args) > 7 and label_smoothing == 0.0: + label_smoothing = node.args[7] + + return weight is None and ignore_index == -100 and label_smoothing == 0.0 + + elif node.op == "call_module": + mod: nn.CrossEntropyLoss = node.graph.owning_module.get_submodule(node.target) + weight, label_smoothing, ignore_index = mod.weight, mod.label_smoothing, mod.ignore_index + return weight is None and ignore_index == -100 and label_smoothing == 0.0 + + return False + + def stable_topological_sort(graph: Graph): def _args(n: torch.fx.Node) -> List[torch.fx.node.Argument]: args: List[torch.fx.node.Argument] = [] diff --git a/tests/fx/parallelization/test_tensor_parallel.py b/tests/fx/parallelization/test_tensor_parallel.py index 9626fccec3b..8a00393c4d7 100644 --- a/tests/fx/parallelization/test_tensor_parallel.py +++ b/tests/fx/parallelization/test_tensor_parallel.py @@ -36,6 +36,7 @@ "output_attentions": False, "output_hidden_states": False, "tie_word_embeddings": True, + "return_dict": True, } DUMMY_MODELS_TO_TEST = ( @@ -64,11 +65,10 @@ def prepare_dummy_inputs( seq_len: int = 10, device: Union[str, torch.device] = "cuda", ): - return { - "input_ids": torch.randint(low=1, high=model_config.vocab_size, size=(batch_size, seq_len), device=device), - "attention_mask": torch.ones((batch_size, seq_len), dtype=torch.int64, device=device), - "position_ids": torch.arange(0, seq_len, device=device).unsqueeze(0).expand(batch_size, -1), - } + input_ids = torch.randint(low=1, high=model_config.vocab_size, size=(batch_size, seq_len), device=device) + attention_mask = torch.ones((batch_size, seq_len), dtype=torch.int64, device=device) + labels = input_ids.clone() + return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels} def run_test_all_rank_results_match(rank: int, world_size: int, model_id: str, model_kwargs: Dict[str, Any]): @@ -82,8 +82,8 @@ def run_test_all_rank_results_match(rank: int, world_size: int, model_id: str, m model = parallelize_model(model_id, ctx, skip_load_weights=True, **model_kwargs) inputs = prepare_dummy_inputs(model.config) - logits = model(**inputs)[0] - tensors = gather_at_main_process(tensor=logits, group=tp_group, rank=rank, world_size=world_size) + loss = model(**inputs).loss + tensors = gather_at_main_process(tensor=loss, group=tp_group, rank=rank, world_size=world_size) # check results at main worker process if rank == 0: @@ -145,7 +145,7 @@ def run_test_parallel_results_matches_non_parallel( inputs = prepare_dummy_inputs(model.config) set_seed(SEED) - logits = model(**inputs)[0] + loss = model(**inputs).loss torch._dynamo.reset() del model @@ -154,9 +154,9 @@ def run_test_parallel_results_matches_non_parallel( set_seed(SEED) ctx = ParallelExecutionCtx(tp_group=tp_group, current_device=device) model = parallelize_model(model_id, ctx, skip_load_weights=True, **model_kwargs) - parallel_logits = model(**inputs)[0] + parallel_loss = model(**inputs).loss - torch.testing.assert_close(logits.cpu(), parallel_logits.cpu(), rtol=1e-4, atol=1e-4) + torch.testing.assert_close(loss.cpu(), parallel_loss.cpu(), rtol=1e-4, atol=1e-4) dist.barrier(tp_group) tearDown()