Showing 105 changed files with 4,410 additions and 1,575 deletions.
@@ -39,3 +39,4 @@ develop-eggs/
dist/
downloads/
.pytest_cache/
compile_commands.json
Submodule cudnn-frontend updated (146 files)
@@ -1 +1 @@
1.11.0
1.12.0
@@ -0,0 +1,75 @@
..
    Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.

    See LICENSE for license information.

Frequently Asked Questions (FAQ)
================================

FP8 checkpoint compatibility
----------------------------
Transformer Engine added support for FP8 attention in version 1.6. It stores the FP8 metadata, i.e. scaling factors and amax histories, under a `._extra_state` key in the checkpoint. As FP8 attention support has expanded from one backend to multiple backends, the location of the `._extra_state` key has also shifted.

Here, we take the `MultiheadAttention` module as an example. In Transformer Engine 1.11, its FP8 attention metadata is stored as `core_attention._extra_state`, as shown below.
.. code-block:: python

    >>> import torch
    >>> from transformer_engine.pytorch import MultiheadAttention, fp8_model_init
    >>> with fp8_model_init(enabled=True):
    ...     mha = MultiheadAttention(
    ...         hidden_size=1024,
    ...         num_attention_heads=16,
    ...         bias=True,
    ...         params_dtype=torch.bfloat16,
    ...         input_layernorm=False,
    ...         fuse_qkv_params=True,
    ...         attention_type="self",
    ...         qkv_weight_interleaved=True,
    ...     ).to(dtype=torch.bfloat16, device="cuda")
    ...
    >>> state_dict = mha.state_dict()
    >>> print(state_dict.keys())
    odict_keys(['qkv.weight', 'qkv.bias', 'qkv._extra_state', 'core_attention._extra_state', 'proj.weight', 'proj.bias', 'proj._extra_state'])

Here is a full list of the checkpoint save/load behaviors across Transformer Engine versions.
.. list-table::

   * - **Version: <= 1.5**

       - Saves no FP8 metadata since FP8 attention is not supported
       - Loading behavior for checkpoints created by the following versions:

         :<= 1.5: Loads no FP8 metadata
         :> 1.5: Error: unexpected key

   * - **Version: 1.6, 1.7**

       - Saves FP8 metadata to `core_attention.fused_attention._extra_state`
       - Loading behavior for checkpoints created by the following versions:

         :<= 1.5: Initializes FP8 metadata to the default, i.e. 1s for scaling factors and 0s for amaxes
         :1.6, 1.7: Loads FP8 metadata from the checkpoint
         :>= 1.8: Error: unexpected key

   * - **Version: >= 1.8, <= 1.11**

       - Saves FP8 metadata to `core_attention._extra_state`
       - Loading behavior for checkpoints created by the following versions:

         :<= 1.5: Initializes FP8 metadata to the default, i.e. 1s for scaling factors and 0s for amaxes
         :1.6, 1.7: This save/load combination requires the user to map the 1.6/1.7 key to the 1.8-1.11 key; otherwise, FP8 metadata is initialized to the default, i.e. 1s for scaling factors and 0s for amaxes. In this `MultiheadAttention` example, the mapping can be done by

           .. code-block:: python

               >>> state_dict["core_attention._extra_state"] = \
                       state_dict["core_attention.fused_attention._extra_state"]
               >>> del state_dict["core_attention.fused_attention._extra_state"]

         :>= 1.8: Loads FP8 metadata from the checkpoint

   * - **Version: >= 1.12**

       - Saves FP8 metadata to `core_attention._extra_state`
       - Loading behavior for checkpoints created by the following versions:

         :<= 1.5: Initializes FP8 metadata to the default, i.e. 1s for scaling factors and 0s for amaxes
         :>= 1.6: Loads FP8 metadata from the checkpoint
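For example, a checkpoint saved by Transformer Engine 1.6 or 1.7 can be loaded by 1.8 or later by remapping the key before calling `load_state_dict`. A minimal sketch, where the checkpoint file name is an illustrative assumption and `mha` is the `MultiheadAttention` module from the example above:

.. code-block:: python

    import torch

    # "mha_checkpoint.pt" is an illustrative name for a checkpoint saved by TE 1.6/1.7.
    state_dict = torch.load("mha_checkpoint.pt")

    # Move the FP8 attention metadata to the key expected by TE >= 1.8.
    old_key = "core_attention.fused_attention._extra_state"
    new_key = "core_attention._extra_state"
    if old_key in state_dict:
        state_dict[new_key] = state_dict.pop(old_key)

    # `mha` is the module constructed in the earlier example.
    mha.load_state_dict(state_dict)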
@@ -30,6 +30,7 @@ Transformer Engine documentation

   installation
   examples/quickstart.ipynb
   faq

.. toctree::
   :hidden:
@@ -0,0 +1,43 @@
# Examples

We provide a variety of examples for deep learning frameworks including [PyTorch](https://github.com/pytorch/pytorch), [JAX](https://github.com/jax-ml/jax), and [PaddlePaddle](https://github.com/PaddlePaddle/Paddle).
Additionally, we offer [Jupyter notebook tutorials](https://github.com/NVIDIA/TransformerEngine/tree/main/docs/examples) and a selection of [third-party examples](#third-party). Please be aware that these third-party examples might need specific, older versions of dependencies to function properly.

# PyTorch
- [Accelerate Hugging Face Llama models with TE](https://github.com/NVIDIA/TransformerEngine/blob/main/docs/examples/te_llama/tutorial_accelerate_hf_llama_with_te.ipynb)
  - Provides code examples and explanations for integrating TE with LLaMA2 models.
- [PyTorch FSDP with FP8](https://github.com/NVIDIA/TransformerEngine/tree/main/examples/pytorch/fsdp)
  - **Distributed Training**: How to set up and run distributed training using PyTorch's FullyShardedDataParallel (FSDP) strategy.
  - **TE Integration**: Instructions on integrating TE/FP8 with PyTorch for optimized performance.
  - **Checkpointing**: Methods for applying activation checkpointing to manage memory usage during training.
- [Attention backends in TE](https://github.com/NVIDIA/TransformerEngine/blob/main/docs/examples/attention/attention.ipynb)
  - **Attention Backends**: Describes the various attention backends supported by Transformer Engine, including framework-native, fused, and flash-attention backends, and their performance benefits.
  - **Flash vs. Non-Flash**: Compares the flash algorithm with the standard non-flash algorithm, highlighting memory and computational efficiency improvements.
  - **Backend Selection**: Details the logic for selecting the most appropriate backend based on availability and performance, and provides user control options for backend selection.
- [Overlapping Communication with GEMM](https://github.com/NVIDIA/TransformerEngine/tree/main/examples/pytorch/comm_gemm_overlap)
  - Training a TE module with GEMM and communication overlap, including various configurations and command-line arguments for customization.
- [Performance Optimizations](https://github.com/NVIDIA/TransformerEngine/blob/main/docs/examples/advanced_optimizations.ipynb)
  - **Multi-GPU Training**: How to use TE with data, tensor, and sequence parallelism.
  - **Gradient Accumulation Fusion**: Utilizing Tensor Cores to accumulate outputs directly into FP32 for better numerical accuracy.
  - **FP8 Weight Caching**: Avoiding redundant FP8 casting during multiple gradient accumulation steps to improve efficiency.
- [Introduction to FP8](https://github.com/NVIDIA/TransformerEngine/blob/main/docs/examples/fp8_primer.ipynb)
  - Overview of FP8 datatypes (E4M3, E5M2), mixed precision training, delayed scaling strategies, and code examples for FP8 configuration and usage (a minimal configuration sketch appears at the end of this page).
- [TE Quickstart](https://github.com/NVIDIA/TransformerEngine/blob/main/docs/examples/quickstart.ipynb)
  - Introduction to TE, building a Transformer Layer using PyTorch, and instructions on integrating TE modules like Linear and LayerNorm.
- [Basic MNIST Example](https://github.com/NVIDIA/TransformerEngine/tree/main/examples/pytorch/mnist)
# JAX

- [Basic Transformer Encoder Example](https://github.com/NVIDIA/TransformerEngine/tree/main/examples/jax/encoder)
  - Single GPU Training: Demonstrates setting up and training a Transformer model using a single GPU.
  - Data Parallelism: Scale training across multiple GPUs using data parallelism.
  - Model Parallelism: Divide a model across multiple GPUs for parallel training.
  - Multiprocessing with Model Parallelism: Multiprocessing for model parallelism, including multi-node support and hardware affinity setup.
- [Basic MNIST Example](https://github.com/NVIDIA/TransformerEngine/tree/main/examples/jax/mnist)

# PaddlePaddle

- [Basic MNIST Example](https://github.com/NVIDIA/TransformerEngine/tree/main/examples/paddle/mnist)

# Third party

- [Hugging Face Accelerate + TE](https://github.com/huggingface/accelerate/tree/main/benchmarks/fp8/transformer_engine)
  - Scripts for training with Accelerate and TE. Supports single GPU, and multi-GPU via DDP, FSDP, and DeepSpeed ZeRO 1-3.
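To complement the Introduction to FP8 tutorial linked above, here is a minimal sketch of enabling FP8 with delayed scaling in TE's PyTorch API. The layer sizes and recipe settings are illustrative assumptions, and an FP8-capable GPU (e.g. Hopper or Ada) is assumed.

```python
import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling, Format

# Illustrative delayed-scaling recipe; HYBRID uses E4M3 for forward tensors and E5M2 for gradients.
fp8_recipe = DelayedScaling(fp8_format=Format.HYBRID, amax_history_len=16)

# A single TE linear layer; 768 is an arbitrary size that satisfies FP8 dimension requirements.
linear = te.Linear(768, 768, bias=True).cuda()
inp = torch.randn(32, 768, device="cuda")

# GEMMs inside this context run in FP8 using the recipe above.
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    out = linear(inp)

out.sum().backward()
```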
@@ -0,0 +1,14 @@
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Shared functions for the encoder tests"""
from functools import lru_cache

from transformer_engine.transformer_engine_jax import get_device_compute_capability


@lru_cache
def is_bf16_supported():
    """Return whether the hardware supports BF16 (compute capability >= 8.0)."""
    gpu_arch = get_device_compute_capability(0)
    return gpu_arch >= 80
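The encoder tests can use this helper to skip BF16 cases on GPUs without hardware support. A minimal usage sketch, assuming the module is importable as `common` and pytest is used; the test name and body are illustrative:

import pytest
from common import is_bf16_supported

@pytest.mark.skipif(not is_bf16_supported(), reason="BF16 requires compute capability >= 8.0")
def test_encoder_bf16():
    # The actual BF16 training/accuracy checks would go here.
    ...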