Release v1.12

Commit 7a7225c
ptrendx committed Nov 18, 2024
2 parents c27ee60 + 7f2afaa

Showing 105 changed files with 4,410 additions and 1,575 deletions.
3 changes: 3 additions & 0 deletions .github/workflows/trigger-ci.yml
@@ -36,6 +36,9 @@ jobs:
|| github.actor == 'yaox12'
|| github.actor == 'huanghua1994'
|| github.actor == 'mgoldfarb-nvidia'
|| github.actor == 'pggPL'
|| github.actor == 'vasunvidia'
|| github.actor == 'erhoo82'
)
steps:
- name: Check if comment is issued by authorized person
1 change: 1 addition & 0 deletions .gitignore
@@ -39,3 +39,4 @@ develop-eggs/
dist/
downloads/
.pytest_cache/
compile_commands.json
2 changes: 1 addition & 1 deletion 3rdparty/cudnn-frontend
Submodule cudnn-frontend updated 146 files
2 changes: 1 addition & 1 deletion build_tools/VERSION.txt
@@ -1 +1 @@
1.11.0
1.12.0
8 changes: 4 additions & 4 deletions build_tools/pytorch.py
@@ -80,10 +80,10 @@ def setup_pytorch_extension(
)
)

if "80" in cuda_architectures:
nvcc_flags.extend(["-gencode", "arch=compute_80,code=sm_80"])
if "90" in cuda_architectures:
nvcc_flags.extend(["-gencode", "arch=compute_90,code=sm_90"])
for arch in cuda_architectures.split(";"):
if arch == "70":
continue # Already handled
nvcc_flags.extend(["-gencode", f"arch=compute_{arch},code=sm_{arch}"])

# Libraries
library_dirs = []
75 changes: 75 additions & 0 deletions docs/faq.rst
@@ -0,0 +1,75 @@
..
Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
See LICENSE for license information.

Frequently Asked Questions (FAQ)
================================

FP8 checkpoint compatibility
----------------------------

Transformer Engine added FP8 attention support in version 1.6. It stores the FP8 metadata, i.e. scaling factors and amax histories, under a `._extra_state` key in the checkpoint. As FP8 attention support has expanded from one backend to multiple backends, the location of the `._extra_state` key has also shifted.

Here, we take the `MultiheadAttention` module as an example. In Transformer Engine 1.11, its FP8 attention metadata is stored as `core_attention._extra_state`, as shown below.

.. code-block:: python

>>> import torch
>>> from transformer_engine.pytorch import MultiheadAttention, fp8_model_init
>>> with fp8_model_init(enabled=True):
... mha = MultiheadAttention(
... hidden_size=1024,
... num_attention_heads=16,
... bias=True,
... params_dtype=torch.bfloat16,
... input_layernorm=False,
... fuse_qkv_params=True,
... attention_type="self",
... qkv_weight_interleaved=True,
... ).to(dtype=torch.bfloat16, device="cuda")
...
>>> state_dict = mha.state_dict()
>>> print(state_dict.keys())
odict_keys(['qkv.weight', 'qkv.bias', 'qkv._extra_state', 'core_attention._extra_state', 'proj.weight', 'proj.bias', 'proj._extra_state'])
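
To write and restore these keys as an actual checkpoint file, the standard PyTorch flow applies. The following is a minimal sketch (the file name is illustrative):

.. code-block:: python

>>> import torch
>>> # Save the checkpoint, including the ._extra_state keys holding FP8 metadata
>>> torch.save(state_dict, "mha_fp8_checkpoint.pt")
>>> # Restore both the parameters and the FP8 metadata
>>> mha.load_state_dict(torch.load("mha_fp8_checkpoint.pt"))
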
Here is a full list of the checkpoint save/load behaviors from all Transformer Engine versions.

.. list-table::

* - **Version: <= 1.5**

- Saves no FP8 metadata since FP8 attention is not supported
- Loading behavior for checkpoints created by the following versions:

:<= 1.5: Loads no FP8 metadata
:> 1.5: Error: unexpected key
* - **Version: 1.6, 1.7**

- Saves FP8 metadata to `core_attention.fused_attention._extra_state`
- Loading behavior for checkpoints created by the following versions:

:<= 1.5: Initializes FP8 metadata to the default, i.e. 1s for scaling factors, and 0s for amaxes
:1.6, 1.7: Loads FP8 metadata from checkpoint
:>= 1.8: Error: unexpected key
* - **Version: >=1.8, <= 1.11**

- Saves FP8 metadata to `core_attention._extra_state`
- Loading behavior for checkpoints created by the following versions:

:<= 1.5: Initializes FP8 metadata to the default, i.e. 1s for scaling factors, and 0s for amaxes
:1.6, 1.7: This save/load combination requires users to map the 1.6/1.7 key to the 1.8-1.11 key. Otherwise, FP8 metadata is initialized to the default, i.e. 1s for scaling factors and 0s for amaxes. In this `MultiheadAttention` example, the mapping can be done by

.. code-block:: python

>>> state_dict["core_attention._extra_state"] = \
state_dict["core_attention.fused_attention._extra_state"]
>>> del state_dict["core_attention.fused_attention._extra_state"]
:>= 1.8: Loads FP8 metadata from checkpoint
* - **Version: >=1.12**

- Saves FP8 metadata to `core_attention._extra_state`
- Loading behavior for checkpoints created by the following versions:

:<= 1.5: Initializes FP8 metadata to the default, i.e. 1s for scaling factors, and 0s for amaxes
:>= 1.6: Loads FP8 metadata from checkpoint
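
The mapping shown above can also be applied conditionally before loading, so the same code handles checkpoints saved by 1.6/1.7 as well as by later versions. This is a minimal sketch reusing the `mha` module and `state_dict` from the earlier example:

.. code-block:: python

>>> old_key = "core_attention.fused_attention._extra_state"
>>> new_key = "core_attention._extra_state"
>>> if old_key in state_dict:  # checkpoint saved by Transformer Engine 1.6/1.7
...     state_dict[new_key] = state_dict.pop(old_key)
...
>>> mha.load_state_dict(state_dict)
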
1 change: 1 addition & 0 deletions docs/index.rst
@@ -30,6 +30,7 @@ Transformer Engine documentation

installation
examples/quickstart.ipynb
faq

.. toctree::
:hidden:
43 changes: 43 additions & 0 deletions examples/README.md
@@ -0,0 +1,43 @@
# Examples

We provide a variety of examples for deep learning frameworks including [PyTorch](https://github.com/pytorch/pytorch), [JAX](https://github.com/jax-ml/jax), and [PaddlePaddle](https://github.com/PaddlePaddle/Paddle).
Additionally, we offer [Jupyter notebook tutorials](https://github.com/NVIDIA/TransformerEngine/tree/main/docs/examples) and a selection of [third-party examples](#third-party). Please be aware that these third-party examples might need specific, older versions of dependencies to function properly.

# PyTorch

- [Accelerate Hugging Face Llama models with TE](https://github.com/NVIDIA/TransformerEngine/blob/main/docs/examples/te_llama/tutorial_accelerate_hf_llama_with_te.ipynb)
- Provides code examples and explanations for integrating TE with LLaMA2 models.
- [PyTorch FSDP with FP8](https://github.com/NVIDIA/TransformerEngine/tree/main/examples/pytorch/fsdp)
- **Distributed Training**: How to set up and run distributed training using PyTorch’s FullyShardedDataParallel (FSDP) strategy.
- **TE Integration**: Instructions on integrating TE/FP8 with PyTorch for optimized performance.
- **Checkpointing**: Methods for applying activation checkpointing to manage memory usage during training.
- [Attention backends in TE](https://github.com/NVIDIA/TransformerEngine/blob/main/docs/examples/attention/attention.ipynb)
- **Attention Backends**: Describes various attention backends supported by Transformer Engine, including framework-native, fused, and flash-attention backends, and their performance benefits.
- **Flash vs. Non-Flash**: Compares the flash algorithm with the standard non-flash algorithm, highlighting memory and computational efficiency improvements.
- **Backend Selection**: Details the logic for selecting the most appropriate backend based on availability and performance, and provides user control options for backend selection; a short sketch of environment-variable control follows this list.
- [Overlapping Communication with GEMM](https://github.com/NVIDIA/TransformerEngine/tree/main/examples/pytorch/comm_gemm_overlap)
- Training a TE module with GEMM and communication overlap, including various configurations and command-line arguments for customization.
- [Performance Optimizations](https://github.com/NVIDIA/TransformerEngine/blob/main/docs/examples/advanced_optimizations.ipynb)
- **Multi-GPU Training**: How to use TE with data, tensor, and sequence parallelism.
- **Gradient Accumulation Fusion**: Utilizing Tensor Cores to accumulate outputs directly into FP32 for better numerical accuracy.
- **FP8 Weight Caching**: Avoiding redundant FP8 casting during multiple gradient accumulation steps to improve efficiency.
- [Introduction to FP8](https://github.com/NVIDIA/TransformerEngine/blob/main/docs/examples/fp8_primer.ipynb)
- Overview of FP8 datatypes (E4M3, E5M2), mixed precision training, delayed scaling strategies, and code examples for FP8 configuration and usage.
- [TE Quickstart](https://github.com/NVIDIA/TransformerEngine/blob/main/docs/examples/quickstart.ipynb)
- Introduction to TE, building a Transformer Layer using PyTorch, and instructions on integrating TE modules like Linear and LayerNorm.
- [Basic MNIST Example](https://github.com/NVIDIA/TransformerEngine/tree/main/examples/pytorch/mnist)
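
As a quick illustration of the backend-selection controls referenced above, attention backends in TE can typically be toggled through environment variables set before TE modules are constructed. This is a hedged sketch; see the attention tutorial for the variables supported by your TE version:

```python
import os

# Disable the flash-attention backend and allow the cuDNN fused-attention backend.
# These must be set before TE modules are built.
os.environ["NVTE_FLASH_ATTN"] = "0"
os.environ["NVTE_FUSED_ATTN"] = "1"

import transformer_engine.pytorch as te

# Build a TE layer as usual; the attention backend is chosen according to the flags above.
layer = te.TransformerLayer(hidden_size=1024, ffn_hidden_size=4096, num_attention_heads=16)
```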

# JAX
- [Basic Transformer Encoder Example](https://github.com/NVIDIA/TransformerEngine/tree/main/examples/jax/encoder)
- Single GPU Training: Demonstrates setting up and training a Transformer model using a single GPU.
- Data Parallelism: Scale training across multiple GPUs using data parallelism.
- Model Parallelism: Divide a model across multiple GPUs for parallel training.
- Multiprocessing with Model Parallelism: Multiprocessing for model parallelism, including multi-node support and hardware affinity setup.
- [Basic MNIST Example](https://github.com/NVIDIA/TransformerEngine/tree/main/examples/jax/mnist)

# PaddlePaddle
- [Basic MNIST Example](https://github.com/NVIDIA/TransformerEngine/tree/main/examples/paddle/mnist)

# Third party
- [Hugging Face Accelerate + TE](https://github.com/huggingface/accelerate/tree/main/benchmarks/fp8/transformer_engine)
- Scripts for training with Accelerate and TE. Supports single-GPU and multi-GPU training via DDP, FSDP, and DeepSpeed ZeRO 1-3.
14 changes: 14 additions & 0 deletions examples/jax/encoder/common.py
@@ -0,0 +1,14 @@
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Shared functions for the encoder tests"""
from functools import lru_cache

from transformer_engine.transformer_engine_jax import get_device_compute_capability


@lru_cache
def is_bf16_supported():
"""Return if BF16 has hardware supported"""
gpu_arch = get_device_compute_capability(0)
return gpu_arch >= 80
4 changes: 4 additions & 0 deletions examples/jax/encoder/test_model_parallel_encoder.py
@@ -22,6 +22,8 @@
import transformer_engine.jax as te
import transformer_engine.jax.flax as te_flax

from common import is_bf16_supported

DEVICE_DP_AXIS = "data"
DEVICE_TP_AXIS = "model"
NAMED_BROADCAST_AXIS = "my_broadcast_axis"
@@ -434,6 +436,7 @@ def setUpClass(cls):
"""Run 3 epochs for testing"""
cls.args = encoder_parser(["--epochs", "3"])

@unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16")
def test_te_bf16(self):
"""Test Transformer Engine with BF16"""
actual = train_and_evaluate(self.args)
@@ -446,6 +449,7 @@ def test_te_fp8(self):
actual = train_and_evaluate(self.args)
assert actual[0] < 0.45 and actual[1] > 0.79

@unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16")
def test_te_bf16_sp(self):
"""Test Transformer Engine with BF16 + SP"""
self.args.enable_sp = True
3 changes: 3 additions & 0 deletions examples/jax/encoder/test_multigpu_encoder.py
@@ -22,6 +22,8 @@
import transformer_engine.jax as te
import transformer_engine.jax.flax as te_flax

from common import is_bf16_supported

DEVICE_DP_AXIS = "data"
PARAMS_KEY = "params"
PARAMS_AXES_KEY = PARAMS_KEY + "_axes"
@@ -402,6 +404,7 @@ def setUpClass(cls):
"""Run 3 epochs for testing"""
cls.args = encoder_parser(["--epochs", "3"])

@unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16")
def test_te_bf16(self):
"""Test Transformer Engine with BF16"""
actual = train_and_evaluate(self.args)
12 changes: 8 additions & 4 deletions examples/jax/encoder/test_multiprocessing_encoder.py
@@ -24,6 +24,8 @@
import transformer_engine.jax as te
import transformer_engine.jax.flax as te_flax

from common import is_bf16_supported

os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
DEVICE_DP_AXIS = "data"
DEVICE_TP_AXIS = "model"
@@ -552,8 +554,9 @@ def encoder_parser(args):
def query_gpu(q):
"""Query GPU info on the system"""
gpu_has_fp8, reason = te.fp8.is_fp8_available()
gpu_has_bf16 = is_bf16_supported()
num_gpu = len(jax.devices())
q.put([num_gpu, gpu_has_fp8, reason])
q.put([num_gpu, gpu_has_fp8, gpu_has_bf16, reason])


def unittest_query_gpu():
@@ -566,15 +569,15 @@ def unittest_query_gpu():
q = mp.Queue()
p = mp.Process(target=query_gpu, args=(q,))
p.start()
num_gpu, gpu_has_fp8, reason = q.get()
num_gpu, gpu_has_fp8, gpu_has_bf16, reason = q.get()
p.join()
return num_gpu, gpu_has_fp8, reason
return num_gpu, gpu_has_fp8, gpu_has_bf16, reason


class TestEncoder(unittest.TestCase):
"""Encoder unittests"""

num_gpu, gpu_has_fp8, reason = unittest_query_gpu()
num_gpu, gpu_has_fp8, gpu_has_bf16, reason = unittest_query_gpu()

def exec(self, use_fp8):
"""Run 3 epochs for testing"""
@@ -598,6 +601,7 @@ def exec(self, use_fp8):

return results

@unittest.skipIf(not gpu_has_bf16, "Device compute capability 8.0+ is required for BF16")
def test_te_bf16(self):
"""Test Transformer Engine with BF16"""
results = self.exec(False)
3 changes: 3 additions & 0 deletions examples/jax/encoder/test_single_gpu_encoder.py
@@ -19,6 +19,8 @@
import transformer_engine.jax as te
import transformer_engine.jax.flax as te_flax

from common import is_bf16_supported

PARAMS_KEY = "params"
DROPOUT_KEY = "dropout"
INPUT_KEY = "input_rng"
@@ -321,6 +323,7 @@ def setUpClass(cls):
"""Run 4 epochs for testing"""
cls.args = encoder_parser(["--epochs", "3"])

@unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16")
def test_te_bf16(self):
"""Test Transformer Engine with BF16"""
actual = train_and_evaluate(self.args)
2 changes: 2 additions & 0 deletions pylintrc
@@ -8,7 +8,9 @@ extension-pkg-whitelist=flash_attn_2_cuda,
extension-pkg-allow-list=transformer_engine.transformer_engine_jax

disable=too-many-locals,
too-few-public-methods,
too-many-public-methods,
too-many-positional-arguments,
invalid-name,
too-many-arguments,
abstract-method,
2 changes: 1 addition & 1 deletion qa/L0_jax_lint/test.sh
@@ -6,7 +6,7 @@ set -e

: "${TE_PATH:=/opt/transformerengine}"

pip install cpplint==1.6.0 pylint==2.13.5
pip install cpplint==1.6.0 pylint==3.3.1
if [ -z "${PYTHON_ONLY}" ]
then
cd $TE_PATH
2 changes: 0 additions & 2 deletions qa/L0_jax_unittest/test.sh
@@ -18,7 +18,5 @@ pip install -r $TE_PATH/examples/jax/encoder/requirements.txt

pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/mnist

# Make encoder tests to have run-to-run deterministic to have the stable CI results
export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops"
pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder --ignore=$TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py
pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py
2 changes: 1 addition & 1 deletion qa/L0_paddle_lint/test.sh
@@ -6,7 +6,7 @@ set -e

: "${TE_PATH:=/opt/transformerengine}"

pip install cpplint==1.6.0 pylint==2.13.5
pip install cpplint==1.6.0 pylint==3.3.1
if [ -z "${PYTHON_ONLY}" ]
then
cd $TE_PATH
10 changes: 7 additions & 3 deletions qa/L0_paddle_wheel/test.sh
@@ -6,7 +6,11 @@ set -e

: "${TE_PATH:=/opt/transformerengine}"

pip install wheel==0.44.0 pydantic
# Install dependencies
# Note: wheel must be installed locally (--user) because the PaddlePaddle
# container already ships an APT-installed copy.
pip install pydantic
pip install --user wheel==0.44.0

cd $TE_PATH
pip uninstall -y transformer-engine transformer-engine-cu12 transformer-engine-paddle
@@ -16,11 +20,11 @@ WHL_BASE="transformer_engine-${VERSION}"

# Core wheel.
NVTE_RELEASE_BUILD=1 python setup.py bdist_wheel
wheel unpack dist/*
python -m wheel unpack dist/*
sed -i "s/Name: transformer-engine/Name: transformer-engine-cu12/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA"
sed -i "s/Name: transformer_engine/Name: transformer_engine_cu12/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA"
mv "${WHL_BASE}/${WHL_BASE}.dist-info" "${WHL_BASE}/transformer_engine_cu12-${VERSION}.dist-info"
wheel pack ${WHL_BASE}
python -m wheel pack ${WHL_BASE}
rm dist/*.whl
mv *.whl dist/
NVTE_RELEASE_BUILD=1 NVTE_BUILD_METAPACKAGE=1 python setup.py bdist_wheel
Expand Down
2 changes: 1 addition & 1 deletion qa/L0_pytorch_lint/test.sh
@@ -6,7 +6,7 @@ set -e

: "${TE_PATH:=/opt/transformerengine}"

pip install cpplint==1.6.0 pylint==2.13.5
pip install cpplint==1.6.0 pylint==3.3.1
if [ -z "${PYTHON_ONLY}" ]
then
cd $TE_PATH
6 changes: 2 additions & 4 deletions qa/L0_pytorch_unittest/test.sh
@@ -6,21 +6,19 @@ set -e

: ${TE_PATH:=/opt/transformerengine}

pip install pytest==8.2.1 onnxruntime==1.13.1
pip install pytest==8.2.1
pytest -v -s $TE_PATH/tests/pytorch/test_sanity.py
pytest -v -s $TE_PATH/tests/pytorch/test_recipe.py
pytest -v -s $TE_PATH/tests/pytorch/test_deferred_init.py
PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_numerics.py
PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_cuda_graphs.py
pytest -v -s $TE_PATH/tests/pytorch/test_jit.py
NVTE_TORCH_COMPILE=0 pytest -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn.py
NVTE_TORCH_COMPILE=0 NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=1 pytest -o log_cli=true --log-cli-level=INFO -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn.py
pytest -v -s $TE_PATH/tests/pytorch/test_fused_rope.py
NVTE_TORCH_COMPILE=0 pytest -v -s $TE_PATH/tests/pytorch/test_onnx_export.py
pytest -v -s $TE_PATH/tests/pytorch/test_float8tensor.py
pytest -v -s $TE_PATH/tests/pytorch/test_torch_save_load.py
pytest -v -s $TE_PATH/tests/pytorch/test_gqa.py
pytest -v -s $TE_PATH/tests/pytorch/test_fused_optimizer.py
pytest -v -s $TE_PATH/tests/pytorch/test_multi_tensor.py
pytest -v -s $TE_PATH/tests/pytorch/test_fusible_ops.py
pytest -v -s $TE_PATH/tests/pytorch/test_fusible_ops_distributed.py
pytest -v -s $TE_PATH/tests/pytorch/test_permutation.py
10 changes: 0 additions & 10 deletions qa/L1_pytorch_context_parallel_test/test.sh

This file was deleted.
