Commit

fix conflict

zhenglongjiepheonix committed Sep 20, 2024
2 parents 3a1a195 + 2fb5ea5 commit b5b371f
Showing 37 changed files with 1,672 additions and 882 deletions.
31 changes: 31 additions & 0 deletions README.md
@@ -268,3 +268,34 @@
```

You can find more examples in the [documentation](https://huggingface.co/docs/optimum/onnxruntime/usage_guides/trainer) and in the [examples](https://github.com/huggingface/optimum/tree/main/examples/onnxruntime/training).


### Quanto

[Quanto](https://github.com/huggingface/optimum-quanto) is a PyTorch quantization backend.

You can quantize a model either with the Python API or with `optimum-cli`.

```python
from transformers import AutoModelForCausalLM
from optimum.quanto import QuantizedModelForCausalLM, qint4

model = AutoModelForCausalLM.from_pretrained('meta-llama/Meta-Llama-3.1-8B')
qmodel = QuantizedModelForCausalLM.quantize(model, weights=qint4, exclude='lm_head')
```
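
Activations can be quantized as well through the lower-level API, reusing `model` from the snippet above (a minimal sketch; `quantize` and `freeze` are the module-level helpers shipped with optimum-quanto):

```python
from optimum.quanto import freeze, qint8, quantize

# Quantize both weights and activations in place; calibration is recommended
# before freezing when activations are quantized.
quantize(model, weights=qint8, activations=qint8)
freeze(model)
```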

The quantized model can be saved using `save_pretrained`:

```python
qmodel.save_pretrained('./Llama-3.1-8B-quantized')
```

It can later be reloaded using `from_pretrained`:

```python
from optimum.quanto import QuantizedModelForCausalLM

qmodel = QuantizedModelForCausalLM.from_pretrained('Llama-3.1-8B-quantized')
```

You can see more details and [examples](https://github.com/huggingface/optimum-quanto/tree/main/examples) in the [Quanto](https://github.com/huggingface/optimum-quanto) repository.
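
A minimal inference sketch following the reload (assuming the quantized wrapper delegates `generate` to the underlying Transformers model; the tokenizer is loaded from the original checkpoint):

```python
from transformers import AutoTokenizer
from optimum.quanto import QuantizedModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained('meta-llama/Meta-Llama-3.1-8B')
qmodel = QuantizedModelForCausalLM.from_pretrained('Llama-3.1-8B-quantized')

inputs = tokenizer('The capital of France is', return_tensors='pt')
# Assumes generation is forwarded to the wrapped model.
outputs = qmodel.generate(**inputs, max_new_tokens=16)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```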
2 changes: 1 addition & 1 deletion docs/source/bettertransformer/overview.mdx
@@ -24,7 +24,7 @@
We provide an integration with these optimizations out of the box in 🤗 Optimum, so that you can convert any supported 🤗 Transformers model to use the optimized paths & `scaled_dot_product_attention` function when relevant.

<Tip warning={true}>
- PyTorch-native `scaled_dot_product_attention` is slowly being natively [made default and integrated in 🤗 Transformers](https://huggingface.co/docs/transformers/perf_infer_gpu_one#flashattention-and-memory-efficient-attention-through-pytorchs-scaleddotproductattention). For models that do support SDPA in Transformers, we deprecate BetterTransformer and recommend you to use directly Transformers and PyTorc latest version for the attention optimizations (Flash Attention, memory-efficient attention) through SDPA.
+ PyTorch-native `scaled_dot_product_attention` is slowly being natively [made default and integrated in 🤗 Transformers](https://huggingface.co/docs/transformers/perf_infer_gpu_one#flashattention-and-memory-efficient-attention-through-pytorchs-scaleddotproductattention). For models that do support SDPA in Transformers, we deprecate BetterTransformer and recommend using the latest versions of Transformers and PyTorch directly for the attention optimizations (Flash Attention, memory-efficient attention) through SDPA.
</Tip>
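
For SDPA-capable models, opting into the native attention path happens directly at load time in Transformers; a minimal sketch (assuming a recent `transformers` release, model name illustrative):

```python
from transformers import AutoModelForCausalLM

# Use PyTorch's scaled_dot_product_attention directly,
# no BetterTransformer conversion needed.
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    attn_implementation="sdpa",
)
```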

<Tip warning={true}>
@@ -39,7 +39,7 @@ torchrun --nproc_per_node=NUM_GPUS_YOU_HAVE run_image_classification.py \
--per_device_eval_batch_size 32 \
--logging_strategy steps \
--logging_steps 10 \
-  --evaluation_strategy epoch \
+  --eval_strategy epoch \
--seed 1337
```
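
The flag rename above mirrors the corresponding `TrainingArguments` rename; a minimal sketch (assuming transformers v4.41 or later, where `evaluation_strategy` is deprecated in favor of `eval_strategy`):

```python
from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="image-classification-output",  # illustrative path
    eval_strategy="epoch",
)
```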

2 changes: 1 addition & 1 deletion optimum/exporters/tasks.py
@@ -308,9 +308,9 @@ class TasksManager:
"image-feature-extraction": "feature-extraction",
# for backward compatibility and testing (where
# model task and model type are still the same)
"lcm": "text-to-image",
"stable-diffusion": "text-to-image",
"stable-diffusion-xl": "text-to-image",
"latent-consistency": "text-to-image",
}

_CUSTOM_CLASSES = {
22 changes: 20 additions & 2 deletions optimum/fx/parallelization/backend/base.py
@@ -13,16 +13,23 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABC, abstractmethod
- from typing import Optional, Tuple
+ from typing import Optional, Tuple, Union

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from torch.fx import GraphModule

from ..core import Config, ParallelExecutionCtx, ParameterMeta
from ..distributed import scatter
- from ..parallel_layers import ColumnParallelLinear, RowParallelLinear, VocabParallelEmbedding
+ from ..parallel_layers import (
+     ColumnParallelLinear,
+     RowParallelLinear,
+     VocabParallelEmbedding,
+     VocabParallelCrossEntropyLoss,
+     sharded_cross_entropy_wrapper_fn,
+ )
from ..passes import (
ParallelAxisSolverPass,
ParallelLayerAnnotatePass,
@@ -64,6 +71,17 @@ def create_parallel_embedding(
    ) -> nn.Module:
        raise NotImplementedError

+   @abstractmethod
+   def create_parallel_cross_entropy(
+       self,
+       mod_or_fn: Union[nn.CrossEntropyLoss, F.cross_entropy],
+       parallel_ctx: "ParallelExecutionCtx",
+   ):
+       if isinstance(mod_or_fn, nn.CrossEntropyLoss):
+           return VocabParallelCrossEntropyLoss(ctx=parallel_ctx, reduction=mod_or_fn.reduction)
+       else:
+           return sharded_cross_entropy_wrapper_fn(process_group=parallel_ctx.tp_group)

    def pre_process(self, graph_module: GraphModule, ctx: "ParallelExecutionCtx", config: "Config") -> GraphModule:
        """
        Mark tie information right before we run passes because dynamo tracing will alter the parameter name while our
2 changes: 1 addition & 1 deletion optimum/fx/parallelization/decomp.py
@@ -197,7 +197,7 @@ def run(self, *args, **kwargs):
def decompose_and_functionalize(
    graph_module: GraphModule,
    decomposition_table: Dict[torch._ops.OperatorBase, Callable] = core_aten_decompositions(),
-   leaf_function_targets: List[Callable] = [F.scaled_dot_product_attention],
+   leaf_function_targets: List[Callable] = [F.scaled_dot_product_attention, F.cross_entropy],
) -> Callable:
    """
    API to decompose and functionalize a high-level graph module.
35 changes: 27 additions & 8 deletions optimum/fx/parallelization/op_registry/op_handlers.py
@@ -19,7 +19,7 @@
from torch.fx import Node

from ..core import Config
- from ..utils import is_activation, is_embedding, is_linear
+ from ..utils import is_activation, is_cross_entropy, is_cross_entropy_parallel_compatible, is_embedding, is_linear


class Registry:
@@ -334,7 +334,16 @@ def propagate(self) -> List[int]:
        ndim = arg.meta["val"].ndim
        slice_dim = (slice_dim + ndim) % ndim
        if slice_dim == axis:
-           # slice on the parallel axis is not allowed
+           # slice on the parallel axis is not allowed, unless it is a no-op
+           start, stop, step = 0, arg.meta["val"].shape[axis], 1
+           if len(self.node.args) > 2:
+               start = self.node.args[2]
+           if len(self.node.args) > 3:
+               stop = self.node.args[3]
+           if len(self.node.args) > 4:
+               step = self.node.args[4]
+           if start == 0 and stop >= arg.meta["val"].shape[axis] and step == 1:
+               return [axis]
            return []
        return [axis]
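
To illustrate the no-op case the handler now accepts (illustrative, not from the diff): tracers typically emit `aten.slice` with `end` set to `INT64_MAX`, so a slice with `start == 0`, `end >= dim size` and `step == 1` leaves the tensor untouched, and the parallel axis can safely pass through it:

```python
import torch

x = torch.randn(4, 8)
# aten.slice.Tensor(self, dim, start, end, step): a full-range slice is a no-op.
y = torch.ops.aten.slice.Tensor(x, 1, 0, 9223372036854775807, 1)
assert torch.equal(x, y)
```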

@@ -404,12 +413,12 @@ def propagate(self) -> List[int]:
        if self.node.op in ["placeholder", "get_attr"]:
            return [None]
        elif self.node.op == "output":
-           for node in self.node.all_input_nodes:
-               # TODO: allow parallelized nodes in output, and append comm ops in graph to all-gather
-               # parallelized output if instructed
-               if self.extract_axis(node) is not None:
-                   return []
-           return [None]
+           # We do not check whether the output is parallelized here, because if the output is a loss
+           # it must not be parallelized as long as it comes from the sharded cross entropy.
+           # TODO: append all-gather comm ops before all parallelized output nodes if instructed.
+           input_arg = self.node.all_input_nodes[0]
+           axis = self.extract_axis(input_arg)
+           return [axis]
        elif is_linear(self.node):
            input_arg = self.node.all_input_nodes[0]
            axis = self.extract_axis(input_arg)
@@ -438,6 +447,16 @@ def propagate(self) -> List[int]:
                return [1, None] if self.config.enable_sequence_parallel else [None]
            else:
                return []
+       elif is_cross_entropy(self.node):
+           logits = self.node.all_input_nodes[0]
+           axis = self.extract_axis(logits)
+           if axis is None or (
+               is_cross_entropy_parallel_compatible(self.node) and axis == logits.meta["val"].ndim - 1
+           ):
+               # for cross entropy, the input logits parallel axis can only be the last axis or None
+               return [None]
+           else:
+               return []
        elif is_activation(self.node):
            return UnaryOpParallelAxisPropagateHandler(self.node, self.meta_key, self.config).propagate()

1 change: 1 addition & 0 deletions optimum/fx/parallelization/parallel_layers/__init__.py
@@ -14,3 +14,4 @@
# limitations under the License.
from .embedding import VocabParallelEmbedding
from .linear import ColumnParallelLinear, RowParallelLinear
+ from .loss import VocabParallelCrossEntropyLoss, sharded_cross_entropy_wrapper_fn
163 changes: 163 additions & 0 deletions optimum/fx/parallelization/parallel_layers/loss.py
@@ -0,0 +1,163 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import wraps
from typing import Optional

import torch
import torch.distributed as dist
import torch.nn as nn

from ..core import ParallelExecutionCtx


# Adapted from https://github.com/huggingface/nanotron/blob/main/src/nanotron/parallel/tensor_parallel/functional.py
class _ShardedCrossEntropy(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        sharded_logits: torch.Tensor,  # (batch_size, length, sharded_hidden_size)
        target: torch.Tensor,  # (batch_size, length)
        group: dist.ProcessGroup,
    ):
        # Maximum value along last dimension across all GPUs.
        logits_max = torch.max(sharded_logits, dim=-1)[0]
        dist.all_reduce(logits_max, op=dist.ReduceOp.MAX, group=group)
        # Subtract the maximum value.
        sharded_logits = sharded_logits - logits_max.unsqueeze(dim=-1)

        # Get the shard's indices.
        sharded_hidden_size = sharded_logits.shape[-1]
        rank = dist.get_rank(group)
        start_index = rank * sharded_hidden_size
        end_index = start_index + sharded_hidden_size

        # Create a mask of valid ids (1 means it needs to be masked).
        target_mask = (target < start_index) | (target >= end_index)
        masked_target = target.clone() - start_index
        masked_target[target_mask] = 0

        # Get predicted-logits = logits[target].
        # For simplicity, we convert logits to a 2-D tensor with size
        # [*, shard-size] and target to a 1-D tensor of size [*].
        logits_2d = sharded_logits.view(-1, sharded_hidden_size)
        masked_target_1d = masked_target.view(-1)
        arange_1d = torch.arange(start=0, end=logits_2d.shape[0], device=logits_2d.device)
        predicted_logits_1d = logits_2d[arange_1d, masked_target_1d]
        if predicted_logits_1d.is_contiguous():
            predicted_logits_1d = predicted_logits_1d.clone()
        else:
            predicted_logits_1d = predicted_logits_1d.contiguous()
        predicted_logits = predicted_logits_1d.view_as(target)
        predicted_logits[target_mask] = 0.0
        # All-reduce is needed to get the chunks from the other GPUs.
        dist.all_reduce(predicted_logits, op=dist.ReduceOp.SUM, group=group)

        # Sum of exponential of logits along vocab dimension across all GPUs.
        exp_logits = sharded_logits
        torch.exp(sharded_logits, out=exp_logits)
        sum_exp_logits = exp_logits.sum(dim=-1)
        dist.all_reduce(sum_exp_logits, op=dist.ReduceOp.SUM, group=group)

        # Loss = log(sum(exp(logits))) - predicted-logit.
        loss = torch.log(sum_exp_logits) - predicted_logits

        # Normalize the logits to obtain the softmax.
        exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1))

        # Store softmax, target-mask and masked-target for backward pass.
        ctx.save_for_backward(exp_logits, target_mask, masked_target_1d)

        return loss

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        # Retrieve tensors from the forward path.
        softmax, target_mask, masked_target_1d = ctx.saved_tensors

        # All the inputs have softmax as their gradient.
        grad_input = softmax
        # For simplicity, work with the 2-D gradient.
        sharded_hidden_size = softmax.size()[-1]
        grad_2d = grad_input.view(-1, sharded_hidden_size)

        # Add the gradient from matching classes.
        arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=grad_2d.device)
        grad_2d[arange_1d, masked_target_1d] -= 1.0 - target_mask.view(-1).float()

        # Finally, elementwise multiplication with the output gradients.
        grad_input.mul_(grad_output.unsqueeze(dim=-1))

        return grad_input, None, None


def sharded_cross_entropy(sharded_logits: torch.Tensor, target: torch.Tensor, process_group: dist.ProcessGroup):
    return _ShardedCrossEntropy.apply(sharded_logits, target, process_group)


def sharded_cross_entropy_wrapper_fn(process_group: dist.ProcessGroup):
    @wraps(sharded_cross_entropy)
    def wrapper(
        sharded_logits: torch.Tensor,
        target: torch.Tensor,
        weight: Optional[torch.Tensor] = None,
        size_average: Optional[bool] = None,
        ignore_index: int = -100,
        reduce: Optional[bool] = None,
        reduction: str = "mean",
        label_smoothing: float = 0.0,
    ):
        if weight is not None or ignore_index != -100 or label_smoothing != 0.0:
            raise ValueError(
                "The current parallel cross entropy implementation does not support "
                "weighted mode, ignore_index, or label smoothing."
            )
        loss: torch.Tensor = sharded_cross_entropy(sharded_logits, target, process_group)

        if size_average is not None or reduce is not None:
            size_average = True if size_average is None else size_average
            reduce = True if reduce is None else reduce

            if size_average and reduce:
                reduction = "mean"
            elif reduce:
                reduction = "sum"
            else:
                reduction = "none"

        if reduction == "mean":
            return loss.mean()
        elif reduction == "sum":
            return loss.sum()
        return loss

    return wrapper


class VocabParallelCrossEntropyLoss(nn.Module):
    """
    Simple parallel cross entropy implementation which does not support weighted mode and label smoothing yet.
    """

    def __init__(self, ctx: ParallelExecutionCtx, reduction: str = "mean") -> None:
        super(VocabParallelCrossEntropyLoss, self).__init__()
        self.process_group = ctx.tp_group
        self.reduction = reduction

    def forward(self, sharded_logits: torch.Tensor, target: torch.Tensor):
        loss: torch.Tensor = _ShardedCrossEntropy.apply(sharded_logits, target, self.process_group)
        if self.reduction == "mean":
            return loss.mean()
        elif self.reduction == "sum":
            return loss.sum()
        return loss
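
As a single-process sanity check (illustrative, not part of the diff): with `world_size == 1` the shard spans the full vocabulary, so the sharded loss should match `F.cross_entropy`:

```python
import os

import torch
import torch.distributed as dist
import torch.nn.functional as F

from optimum.fx.parallelization.parallel_layers.loss import sharded_cross_entropy

# Single-rank gloo group, so the "shard" covers the whole vocab.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

logits = torch.randn(2, 5, 11)  # (batch_size, length, vocab_size)
target = torch.randint(0, 11, (2, 5))

loss = sharded_cross_entropy(logits, target, dist.group.WORLD).mean()
reference = F.cross_entropy(logits.view(-1, 11), target.view(-1))
assert torch.allclose(loss, reference, atol=1e-6)

dist.destroy_process_group()
```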