diff --git a/.github/workflows/trigger-ci.yml b/.github/workflows/trigger-ci.yml index c25aa863ad..c7a85029e6 100644 --- a/.github/workflows/trigger-ci.yml +++ b/.github/workflows/trigger-ci.yml @@ -36,6 +36,9 @@ jobs: || github.actor == 'yaox12' || github.actor == 'huanghua1994' || github.actor == 'mgoldfarb-nvidia' + || github.actor == 'pggPL' + || github.actor == 'vasunvidia' + || github.actor == 'erhoo82' ) steps: - name: Check if comment is issued by authorized person diff --git a/.gitignore b/.gitignore index 898feb8f4f..6890911c14 100644 --- a/.gitignore +++ b/.gitignore @@ -39,3 +39,4 @@ develop-eggs/ dist/ downloads/ .pytest_cache/ +compile_commands.json diff --git a/3rdparty/cudnn-frontend b/3rdparty/cudnn-frontend index 2533f5e5c1..936021bfed 160000 --- a/3rdparty/cudnn-frontend +++ b/3rdparty/cudnn-frontend @@ -1 +1 @@ -Subproject commit 2533f5e5c1877fd76266133c1479ef1643ce3a8b +Subproject commit 936021bfed8c91dc416af1588b2c4eca631a9e45 diff --git a/build_tools/VERSION.txt b/build_tools/VERSION.txt index 1cac385c6c..0eed1a29ef 100644 --- a/build_tools/VERSION.txt +++ b/build_tools/VERSION.txt @@ -1 +1 @@ -1.11.0 +1.12.0 diff --git a/build_tools/pytorch.py b/build_tools/pytorch.py index 4563a0272a..9152229d2f 100644 --- a/build_tools/pytorch.py +++ b/build_tools/pytorch.py @@ -80,10 +80,10 @@ def setup_pytorch_extension( ) ) - if "80" in cuda_architectures: - nvcc_flags.extend(["-gencode", "arch=compute_80,code=sm_80"]) - if "90" in cuda_architectures: - nvcc_flags.extend(["-gencode", "arch=compute_90,code=sm_90"]) + for arch in cuda_architectures.split(";"): + if arch == "70": + continue # Already handled + nvcc_flags.extend(["-gencode", f"arch=compute_{arch},code=sm_{arch}"]) # Libraries library_dirs = [] diff --git a/docs/faq.rst b/docs/faq.rst new file mode 100644 index 0000000000..50b3a7481e --- /dev/null +++ b/docs/faq.rst @@ -0,0 +1,75 @@ +.. + Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + + See LICENSE for license information. + +Frequently Asked Questions (FAQ) +================================ + +FP8 checkpoint compatibility +---------------------------- + +Transformer Engine starts to support FP8 attention in 1.6. It stores the FP8 metadata, i.e. scaling factors and amax histories, under a `._extra_state` key in the checkpoint. As the FP8 attention support expands from one backend to multiple backends, the location of the `._extra_state` key has also shifted. + +Here, we take the `MultiheadAttention` module as an example. Its FP8 attention metadata in Transformer Engine 1.11 is stored as `core_attention._extra_state` as shown below. + +.. code-block:: python + + >>> from transformer_engine.pytorch import MultiheadAttention, fp8_model_init + >>> with fp8_model_init(enabled=True): + ... mha = MultiheadAttention( + ... hidden_size=1024, + ... num_attention_heads=16, + ... bias=True, + ... params_dtype=torch.bfloat16, + ... input_layernorm=False, + ... fuse_qkv_params=True, + ... attention_type="self", + ... qkv_weight_interleaved=True, + ... ).to(dtype=torch.bfloat16, device="cuda") + ... + >>> state_dict = mha.state_dict() + >>> print(state_dict.keys()) + odict_keys(['qkv.weight', 'qkv.bias', 'qkv._extra_state', 'core_attention._extra_state', 'proj.weight', 'proj.bias', 'proj._extra_state']) + +Here is a full list of the checkpoint save/load behaviors from all Transformer Engine versions. + +.. list-table:: + + * - **Version: <= 1.5** + + - Saves no FP8 metadata since FP8 attention is not supported + - Loading behavior for checkpoints created by the following versions: + + :<= 1.5: Loads no FP8 metadata + :> 1.5: Error: unexpected key + * - **Version: 1.6, 1.7** + + - Saves FP8 metadata to `core_attention.fused_attention._extra_state` + - Loading behavior for checkpoints created by the following versions: + + :<= 1.5: Initializes FP8 metadata to the default, i.e. 1s for scaling factors, and 0s for amaxes + :1.6, 1.7: Loads FP8 metadata from checkpoint + :>= 1.8: Error: unexpected key + * - **Version: >=1.8, <= 1.11** + + - Saves FP8 metadata to `core_attention._extra_state` + - Loading behavior for checkpoints created by the following versions: + + :<= 1.5: Initializes FP8 metadata to the default, i.e. 1s for scaling factors, and 0s for amaxes + :1.6, 1.7: This save/load combination relies on users to map the 1.6/1.7 key to the 1.8-1.11 key. Otherwise, it initializes FP8 metadata to the default, i.e. 1s for scaling factors, and 0s for amaxes. The mapping can be done, in this `MultiheadAttention` example, by + + .. code-block:: python + + >>> state_dict["core_attention._extra_state"] = \ + state_dict["core_attention.fused_attention._extra_state"] + >>> del state_dict["core_attention.fused_attention._extra_state"] + + :>= 1.8: Loads FP8 metadata from checkpoint + * - **Version: >=1.12** + + - Saves FP8 metadata to `core_attention._extra_state` + - Loading behavior for checkpoints created by the following versions: + + :<= 1.5: Initializes FP8 metadata to the default, i.e. 1s for scaling factors, and 0s for amaxes + :>= 1.6: Loads FP8 metadata from checkpoint diff --git a/docs/index.rst b/docs/index.rst index 47b8388dd2..38e095c239 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -30,6 +30,7 @@ Transformer Engine documentation installation examples/quickstart.ipynb + faq .. toctree:: :hidden: diff --git a/examples/README.md b/examples/README.md new file mode 100644 index 0000000000..6001bc2cf6 --- /dev/null +++ b/examples/README.md @@ -0,0 +1,43 @@ +# Examples + +We provide a variety of examples for deep learning frameworks including [PyTorch](https://github.com/pytorch/pytorch), [JAX](https://github.com/jax-ml/jax), and [PaddlePaddle](https://github.com/PaddlePaddle/Paddle). +Additionally, we offer [Jupyter notebook tutorials](https://github.com/NVIDIA/TransformerEngine/tree/main/docs/examples) and a selection of [third-party examples](#third-party). Please be aware that these third-party examples might need specific, older versions of dependencies to function properly. + +# PyTorch + +- [Accelerate Hugging Face Llama models with TE](https://github.com/NVIDIA/TransformerEngine/blob/main/docs/examples/te_llama/tutorial_accelerate_hf_llama_with_te.ipynb) + - Provides code examples and explanations for integrating TE with the LLaMA2 and LLaMA2 models. +- [PyTorch FSDP with FP8](https://github.com/NVIDIA/TransformerEngine/tree/main/examples/pytorch/fsdp) + - **Distributed Training**: How to set up and run distributed training using PyTorch’s FullyShardedDataParallel (FSDP) strategy. + - **TE Integration**: Instructions on integrating TE/FP8 with PyTorch for optimized performance. + - **Checkpointing**: Methods for applying activation checkpointing to manage memory usage during training. +- [Attention backends in TE](https://github.com/NVIDIA/TransformerEngine/blob/main/docs/examples/attention/attention.ipynb) + - **Attention Backends**: Describes various attention backends supported by Transformer Engine, including framework-native, fused, and flash-attention backends, and their performance benefits. + - **Flash vs. Non-Flash**: Compares the flash algorithm with the standard non-flash algorithm, highlighting memory and computational efficiency improvements. + - **Backend Selection**: Details the logic for selecting the most appropriate backend based on availability and performance, and provides user control options for backend selection. +- [Overlapping Communication with GEMM](https://github.com/NVIDIA/TransformerEngine/tree/main/examples/pytorch/comm_gemm_overlap) + - Training a TE module with GEMM and communication overlap, including various configurations and command-line arguments for customization. +- [Performance Optimizations](https://github.com/NVIDIA/TransformerEngine/blob/main/docs/examples/advanced_optimizations.ipynb) + - **Multi-GPU Training**: How to use TE with data, tensor, and sequence parallelism. + - **Gradient Accumulation Fusion**: Utilizing Tensor Cores to accumulate outputs directly into FP32 for better numerical accuracy. + - **FP8 Weight Caching**: Avoiding redundant FP8 casting during multiple gradient accumulation steps to improve efficiency. +- [Introduction to FP8](https://github.com/NVIDIA/TransformerEngine/blob/main/docs/examples/fp8_primer.ipynb) + - Overview of FP8 datatypes (E4M3, E5M2), mixed precision training, delayed scaling strategies, and code examples for FP8 configuration and usage. +- [TE Quickstart](https://github.com/NVIDIA/TransformerEngine/blob/main/docs/examples/quickstart.ipynb) + - Introduction to TE, building a Transformer Layer using PyTorch, and instructions on integrating TE modules like Linear and LayerNorm. +- [Basic MNIST Example](https://github.com/NVIDIA/TransformerEngine/tree/main/examples/pytorch/mnist) + +# JAX +- [Basic Transformer Encoder Example](https://github.com/NVIDIA/TransformerEngine/tree/main/examples/jax/encoder) + - Single GPU Training: Demonstrates setting up and training a Transformer model using a single GPU. + - Data Parallelism: Scale training across multiple GPUs using data parallelism. + - Model Parallelism: Divide a model across multiple GPUs for parallel training. + - Multiprocessing with Model Parallelism: Multiprocessing for model parallelism, including multi-node support and hardware affinity setup. +- [Basic MNIST Example](https://github.com/NVIDIA/TransformerEngine/tree/main/examples/jax/mnist) + +# PaddlePaddle +- [Basic MNIST Example](https://github.com/NVIDIA/TransformerEngine/tree/main/examples/paddle/mnist) + +# Third party +- [Hugging Face Accelerate + TE](https://github.com/huggingface/accelerate/tree/main/benchmarks/fp8/transformer_engine) + - Scripts for training with Accelerate and TE. Supports single GPU, and multi-GPU via DDP, FSDP, and DeepSpeed ZeRO 1-3. diff --git a/examples/jax/encoder/common.py b/examples/jax/encoder/common.py new file mode 100644 index 0000000000..dcbfafc467 --- /dev/null +++ b/examples/jax/encoder/common.py @@ -0,0 +1,14 @@ +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. +"""Shared functions for the encoder tests""" +from functools import lru_cache + +from transformer_engine.transformer_engine_jax import get_device_compute_capability + + +@lru_cache +def is_bf16_supported(): + """Return if BF16 has hardware supported""" + gpu_arch = get_device_compute_capability(0) + return gpu_arch >= 80 diff --git a/examples/jax/encoder/test_model_parallel_encoder.py b/examples/jax/encoder/test_model_parallel_encoder.py index 25d744887e..bafd9bd2fb 100644 --- a/examples/jax/encoder/test_model_parallel_encoder.py +++ b/examples/jax/encoder/test_model_parallel_encoder.py @@ -22,6 +22,8 @@ import transformer_engine.jax as te import transformer_engine.jax.flax as te_flax +from common import is_bf16_supported + DEVICE_DP_AXIS = "data" DEVICE_TP_AXIS = "model" NAMED_BROADCAST_AXIS = "my_broadcast_axis" @@ -434,6 +436,7 @@ def setUpClass(cls): """Run 3 epochs for testing""" cls.args = encoder_parser(["--epochs", "3"]) + @unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16") def test_te_bf16(self): """Test Transformer Engine with BF16""" actual = train_and_evaluate(self.args) @@ -446,6 +449,7 @@ def test_te_fp8(self): actual = train_and_evaluate(self.args) assert actual[0] < 0.45 and actual[1] > 0.79 + @unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16") def test_te_bf16_sp(self): """Test Transformer Engine with BF16 + SP""" self.args.enable_sp = True diff --git a/examples/jax/encoder/test_multigpu_encoder.py b/examples/jax/encoder/test_multigpu_encoder.py index 9d08254f4d..a4a19b43c2 100644 --- a/examples/jax/encoder/test_multigpu_encoder.py +++ b/examples/jax/encoder/test_multigpu_encoder.py @@ -22,6 +22,8 @@ import transformer_engine.jax as te import transformer_engine.jax.flax as te_flax +from common import is_bf16_supported + DEVICE_DP_AXIS = "data" PARAMS_KEY = "params" PARAMS_AXES_KEY = PARAMS_KEY + "_axes" @@ -402,6 +404,7 @@ def setUpClass(cls): """Run 3 epochs for testing""" cls.args = encoder_parser(["--epochs", "3"]) + @unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16") def test_te_bf16(self): """Test Transformer Engine with BF16""" actual = train_and_evaluate(self.args) diff --git a/examples/jax/encoder/test_multiprocessing_encoder.py b/examples/jax/encoder/test_multiprocessing_encoder.py index e581dbc3f9..f54deff69c 100644 --- a/examples/jax/encoder/test_multiprocessing_encoder.py +++ b/examples/jax/encoder/test_multiprocessing_encoder.py @@ -24,6 +24,8 @@ import transformer_engine.jax as te import transformer_engine.jax.flax as te_flax +from common import is_bf16_supported + os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" DEVICE_DP_AXIS = "data" DEVICE_TP_AXIS = "model" @@ -552,8 +554,9 @@ def encoder_parser(args): def query_gpu(q): """Query GPU info on the system""" gpu_has_fp8, reason = te.fp8.is_fp8_available() + gpu_has_bf16 = is_bf16_supported() num_gpu = len(jax.devices()) - q.put([num_gpu, gpu_has_fp8, reason]) + q.put([num_gpu, gpu_has_fp8, gpu_has_bf16, reason]) def unittest_query_gpu(): @@ -566,15 +569,15 @@ def unittest_query_gpu(): q = mp.Queue() p = mp.Process(target=query_gpu, args=(q,)) p.start() - num_gpu, gpu_has_fp8, reason = q.get() + num_gpu, gpu_has_fp8, gpu_has_bf16, reason = q.get() p.join() - return num_gpu, gpu_has_fp8, reason + return num_gpu, gpu_has_fp8, gpu_has_bf16, reason class TestEncoder(unittest.TestCase): """Encoder unittests""" - num_gpu, gpu_has_fp8, reason = unittest_query_gpu() + num_gpu, gpu_has_fp8, gpu_has_bf16, reason = unittest_query_gpu() def exec(self, use_fp8): """Run 3 epochs for testing""" @@ -598,6 +601,7 @@ def exec(self, use_fp8): return results + @unittest.skipIf(not gpu_has_bf16, "Device compute capability 8.0+ is required for BF16") def test_te_bf16(self): """Test Transformer Engine with BF16""" results = self.exec(False) diff --git a/examples/jax/encoder/test_single_gpu_encoder.py b/examples/jax/encoder/test_single_gpu_encoder.py index 363759afea..ac71fe4c0e 100644 --- a/examples/jax/encoder/test_single_gpu_encoder.py +++ b/examples/jax/encoder/test_single_gpu_encoder.py @@ -19,6 +19,8 @@ import transformer_engine.jax as te import transformer_engine.jax.flax as te_flax +from common import is_bf16_supported + PARAMS_KEY = "params" DROPOUT_KEY = "dropout" INPUT_KEY = "input_rng" @@ -321,6 +323,7 @@ def setUpClass(cls): """Run 4 epochs for testing""" cls.args = encoder_parser(["--epochs", "3"]) + @unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16") def test_te_bf16(self): """Test Transformer Engine with BF16""" actual = train_and_evaluate(self.args) diff --git a/pylintrc b/pylintrc index 9035de4f81..b80679d72c 100644 --- a/pylintrc +++ b/pylintrc @@ -8,7 +8,9 @@ extension-pkg-whitelist=flash_attn_2_cuda, extension-pkg-allow-list=transformer_engine.transformer_engine_jax disable=too-many-locals, + too-few-public-methods, too-many-public-methods, + too-many-positional-arguments, invalid-name, too-many-arguments, abstract-method, diff --git a/qa/L0_jax_lint/test.sh b/qa/L0_jax_lint/test.sh index afa2475f51..7bc84eef51 100644 --- a/qa/L0_jax_lint/test.sh +++ b/qa/L0_jax_lint/test.sh @@ -6,7 +6,7 @@ set -e : "${TE_PATH:=/opt/transformerengine}" -pip install cpplint==1.6.0 pylint==2.13.5 +pip install cpplint==1.6.0 pylint==3.3.1 if [ -z "${PYTHON_ONLY}" ] then cd $TE_PATH diff --git a/qa/L0_jax_unittest/test.sh b/qa/L0_jax_unittest/test.sh index db3aa31951..9efec6f2e5 100644 --- a/qa/L0_jax_unittest/test.sh +++ b/qa/L0_jax_unittest/test.sh @@ -18,7 +18,5 @@ pip install -r $TE_PATH/examples/jax/encoder/requirements.txt pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/mnist -# Make encoder tests to have run-to-run deterministic to have the stable CI results -export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops" pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder --ignore=$TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py diff --git a/qa/L0_paddle_lint/test.sh b/qa/L0_paddle_lint/test.sh index 44249e9f60..5c5379554f 100644 --- a/qa/L0_paddle_lint/test.sh +++ b/qa/L0_paddle_lint/test.sh @@ -6,7 +6,7 @@ set -e : "${TE_PATH:=/opt/transformerengine}" -pip install cpplint==1.6.0 pylint==2.13.5 +pip install cpplint==1.6.0 pylint==3.3.1 if [ -z "${PYTHON_ONLY}" ] then cd $TE_PATH diff --git a/qa/L0_paddle_wheel/test.sh b/qa/L0_paddle_wheel/test.sh index 30fbb1df1f..00653877b8 100644 --- a/qa/L0_paddle_wheel/test.sh +++ b/qa/L0_paddle_wheel/test.sh @@ -6,7 +6,11 @@ set -e : "${TE_PATH:=/opt/transformerengine}" -pip install wheel==0.44.0 pydantic +# Install dependencies +# Note: Need to install wheel locally since PaddlePaddle container +# already contains APT install. +pip install pydantic +pip install --user wheel==0.44.0 cd $TE_PATH pip uninstall -y transformer-engine transformer-engine-cu12 transformer-engine-paddle @@ -16,11 +20,11 @@ WHL_BASE="transformer_engine-${VERSION}" # Core wheel. NVTE_RELEASE_BUILD=1 python setup.py bdist_wheel -wheel unpack dist/* +python -m wheel unpack dist/* sed -i "s/Name: transformer-engine/Name: transformer-engine-cu12/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA" sed -i "s/Name: transformer_engine/Name: transformer_engine_cu12/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA" mv "${WHL_BASE}/${WHL_BASE}.dist-info" "${WHL_BASE}/transformer_engine_cu12-${VERSION}.dist-info" -wheel pack ${WHL_BASE} +python -m wheel pack ${WHL_BASE} rm dist/*.whl mv *.whl dist/ NVTE_RELEASE_BUILD=1 NVTE_BUILD_METAPACKAGE=1 python setup.py bdist_wheel diff --git a/qa/L0_pytorch_lint/test.sh b/qa/L0_pytorch_lint/test.sh index ea74427967..ac517976c7 100644 --- a/qa/L0_pytorch_lint/test.sh +++ b/qa/L0_pytorch_lint/test.sh @@ -6,7 +6,7 @@ set -e : "${TE_PATH:=/opt/transformerengine}" -pip install cpplint==1.6.0 pylint==2.13.5 +pip install cpplint==1.6.0 pylint==3.3.1 if [ -z "${PYTHON_ONLY}" ] then cd $TE_PATH diff --git a/qa/L0_pytorch_unittest/test.sh b/qa/L0_pytorch_unittest/test.sh index b69aed6648..17307574a9 100644 --- a/qa/L0_pytorch_unittest/test.sh +++ b/qa/L0_pytorch_unittest/test.sh @@ -6,21 +6,19 @@ set -e : ${TE_PATH:=/opt/transformerengine} -pip install pytest==8.2.1 onnxruntime==1.13.1 +pip install pytest==8.2.1 pytest -v -s $TE_PATH/tests/pytorch/test_sanity.py pytest -v -s $TE_PATH/tests/pytorch/test_recipe.py pytest -v -s $TE_PATH/tests/pytorch/test_deferred_init.py PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_numerics.py PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_cuda_graphs.py pytest -v -s $TE_PATH/tests/pytorch/test_jit.py -NVTE_TORCH_COMPILE=0 pytest -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn.py +NVTE_TORCH_COMPILE=0 NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=1 pytest -o log_cli=true --log-cli-level=INFO -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn.py pytest -v -s $TE_PATH/tests/pytorch/test_fused_rope.py -NVTE_TORCH_COMPILE=0 pytest -v -s $TE_PATH/tests/pytorch/test_onnx_export.py pytest -v -s $TE_PATH/tests/pytorch/test_float8tensor.py pytest -v -s $TE_PATH/tests/pytorch/test_torch_save_load.py pytest -v -s $TE_PATH/tests/pytorch/test_gqa.py pytest -v -s $TE_PATH/tests/pytorch/test_fused_optimizer.py pytest -v -s $TE_PATH/tests/pytorch/test_multi_tensor.py pytest -v -s $TE_PATH/tests/pytorch/test_fusible_ops.py -pytest -v -s $TE_PATH/tests/pytorch/test_fusible_ops_distributed.py pytest -v -s $TE_PATH/tests/pytorch/test_permutation.py diff --git a/qa/L1_pytorch_context_parallel_test/test.sh b/qa/L1_pytorch_context_parallel_test/test.sh deleted file mode 100644 index 7f3c289b36..0000000000 --- a/qa/L1_pytorch_context_parallel_test/test.sh +++ /dev/null @@ -1,10 +0,0 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. - -set -e - -: ${TE_PATH:=/opt/transformerengine} - -pip install pytest==7.2.0 onnxruntime==1.13.1 -pytest -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn_with_cp.py diff --git a/qa/L1_pytorch_distributed_unittest/test.sh b/qa/L1_pytorch_distributed_unittest/test.sh index 50394c33a9..c22ba221be 100644 --- a/qa/L1_pytorch_distributed_unittest/test.sh +++ b/qa/L1_pytorch_distributed_unittest/test.sh @@ -4,16 +4,10 @@ set -e -# pkg_resources is deprecated in setuptools 70+ and the packaging submodule -# has been removed from it. This is a temporary fix until upstream MLM fix. -pip install setuptools==69.5.1 - : ${TE_PATH:=/opt/transformerengine} -pytest -v -s $TE_PATH/tests/pytorch/distributed/test_comm_gemm_overlap.py -git clone https://github.com/NVIDIA/Megatron-LM.git -cd Megatron-LM -git checkout bcce6f54e075e3c3374ea67adefe54f3f2da2b07 -sed -i -e '1504,1505d' megatron/model/transformer.py -pytest -v -s $TE_PATH/tests/pytorch/distributed/test_convergence.py -python $TE_PATH/tests/pytorch/distributed/print_logs.py +pip install pytest==8.2.1 +pytest -v -s $TE_PATH/tests/pytorch/distributed/test_numerics.py +pytest -v -s $TE_PATH/tests/pytorch/distributed/test_comm_gemm_overlap.py +pytest -v -s $TE_PATH/tests/pytorch/distributed/test_fusible_ops.py +pytest -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn_with_cp.py diff --git a/qa/L1_pytorch_onnx_test/test.sh b/qa/L1_pytorch_onnx_test/test.sh new file mode 100644 index 0000000000..5a01468064 --- /dev/null +++ b/qa/L1_pytorch_onnx_test/test.sh @@ -0,0 +1,16 @@ +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +set -e + +: ${TE_PATH:=/opt/transformerengine} + +pip install pytest==8.2.1 onnxruntime==1.19.2 + +# Build custom ONNX Runtime operators +export CUSTOM_ORT_OPS_PATH=$TE_PATH/tests/pytorch/custom_ort_ops +bash $CUSTOM_ORT_OPS_PATH/build.sh + +# Run tests +NVTE_TORCH_COMPILE=0 pytest -v -s $TE_PATH/tests/pytorch/test_onnx_export.py diff --git a/qa/L3_pytorch_FA_versions_test/test.sh b/qa/L3_pytorch_FA_versions_test/test.sh new file mode 100644 index 0000000000..162ed85823 --- /dev/null +++ b/qa/L3_pytorch_FA_versions_test/test.sh @@ -0,0 +1,33 @@ +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +set -e + +: ${TE_PATH:=/opt/transformerengine} + +pip install pytest==8.2.1 + +# Limit parallel build jobs to avoid overwhelming system resources +export MAX_JOBS=4 + +# Iterate over Flash Attention versions +FA_versions=(2.1.1 2.3.0 2.4.0.post1 2.4.1 2.5.7 2.6.3 3.0.0b1) +for fa_version in "${FA_versions[@]}" +do + + # Build Flash Attention + if [ "${fa_version}" \< "3.0.0" ] + then + pip install flash-attn==${fa_version} + else + pip install "git+https://github.com/Dao-AILab/flash-attention.git#egg=flashattn-hopper&subdirectory=hopper" + python_path=`python -c "import site; print(site.getsitepackages()[0])"` + mkdir -p $python_path/flashattn_hopper + wget -P $python_path/flashattn_hopper https://raw.githubusercontent.com/Dao-AILab/flash-attention/main/hopper/flash_attn_interface.py + fi + + # Run tests + NVTE_TORCH_COMPILE=0 pytest -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn.py + +done diff --git a/qa/L3_pytorch_convergence_test/test.sh b/qa/L3_pytorch_convergence_test/test.sh new file mode 100644 index 0000000000..fca621f279 --- /dev/null +++ b/qa/L3_pytorch_convergence_test/test.sh @@ -0,0 +1,14 @@ +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +set -e + +: ${TE_PATH:=/opt/transformerengine} + +pip install prettytable +git clone https://github.com/NVIDIA/Megatron-LM.git +cd Megatron-LM +git checkout b3375a0e38c10e2300ef4be031f7dcabab52b448 +pytest -v -s $TE_PATH/tests/pytorch/distributed/test_convergence.py +python $TE_PATH/tests/pytorch/distributed/print_logs.py diff --git a/setup.py b/setup.py index 0b0639aea6..512defa619 100644 --- a/setup.py +++ b/setup.py @@ -93,7 +93,7 @@ def setup_requirements() -> Tuple[List[str], List[str], List[str]]: # Framework-specific requirements if not bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))): if "pytorch" in frameworks: - install_reqs.extend(["torch", "flash-attn>=2.0.6,<=2.6.3,!=2.0.9,!=2.1.0"]) + install_reqs.extend(["torch"]) test_reqs.extend(["numpy", "onnxruntime", "torchvision", "prettytable"]) if "jax" in frameworks: install_reqs.extend(["jax", "flax>=0.7.1"]) diff --git a/tests/jax/test_distributed_fused_attn.py b/tests/jax/test_distributed_fused_attn.py index 61d68aacae..23a26087d4 100644 --- a/tests/jax/test_distributed_fused_attn.py +++ b/tests/jax/test_distributed_fused_attn.py @@ -17,7 +17,13 @@ generate_collectives_count, compare_ops, ) -from utils import make_causal_mask, make_self_mask, assert_tree_like_allclose, assert_allclose +from utils import ( + make_causal_mask, + make_self_mask, + assert_tree_like_allclose, + assert_allclose, + print_debug_tensor_stats, +) from transformer_engine.jax import fp8_autocast from transformer_engine.jax.attention import ( is_fused_attn_kernel_available, @@ -31,6 +37,8 @@ inverse_reorder_causal_load_balancing, ) +# We will use the golden reference model from our non distributed attention test fixture. +from test_fused_attn import general_dot_product_attention, make_mask DTYPES = [jnp.float16, jnp.bfloat16] @@ -124,8 +132,10 @@ def test_self_attn( seqlen, seqlen, hidden, + None, # no window + False, # not context parallel ): - pytest.skip(f"No FusedAttn backwend found") + pytest.skip(f"No FusedAttn backend found") def target_func(qkv, bias, mask): return jnp.mean( @@ -257,8 +267,10 @@ def test_cross_attn( seqlen, seqlen, hidden, + None, # no window + False, # not context parallel ): - pytest.skip(f"No FusedAttn backwend found") + pytest.skip(f"No FusedAttn backend found") def target_func(q, kv, mask): return jnp.mean( @@ -323,18 +335,27 @@ def ref_func(query, kv, mask): ) -class TestDistributedContexParallelSelfAttn: +class TestDistributedContextParallelSelfAttn: def generate_inputs(self, shape, kv_groups: int, attn_mask_type: AttnMaskType, dtype): batch, seqlen, heads, hidden = shape + kv_shape = (batch, seqlen, heads // kv_groups, hidden) qkey, kkey, vkey = random.split(random.PRNGKey(1124), 3) q = random.normal(qkey, shape, dtype=dtype) k = random.normal(kkey, (batch, seqlen, heads // kv_groups, hidden), dtype=dtype) v = random.normal(vkey, (batch, seqlen, heads // kv_groups, hidden), dtype=dtype) - mask = None - if attn_mask_type == AttnMaskType.CAUSAL_MASK: - mask = make_causal_mask(batch, seqlen) + def gen_valid(bs, max_seqlen, pad_ratio): + pad_len = int(max_seqlen * pad_ratio) + valid_len = max_seqlen - pad_len + tokens = jnp.concatenate([jnp.ones((bs, valid_len)), jnp.zeros((bs, pad_len))], axis=-1) + return tokens, jnp.logical_not(tokens) + + from test_fused_attn import make_mask + + q_idx, _ = gen_valid(batch, seqlen, 0.0) + kv_idx, _ = gen_valid(batch, seqlen, 0.0) + mask = make_mask(q_idx, kv_idx, None, None, attn_mask_type) return q, k, v, mask @@ -378,7 +399,8 @@ def qkv_to_layout(self, q, k, v, qkv_layout): ], ) @pytest.mark.parametrize( - "load_balanced", [pytest.param(False, id="UNBALANCED"), pytest.param(True, id="BALANCED")] + "load_balanced", + [pytest.param(False, id="UNBALANCED"), pytest.param(True, id="BALANCED")], ) def test_contex_parallel_self_attn( self, @@ -396,61 +418,93 @@ def test_contex_parallel_self_attn( attn_bias_type = AttnBiasType.NO_BIAS dropout_prob = 0.0 is_training = True - scaling_factor = 1.0 dp_size, cp_size, tp_size = mesh_shape qkv_format = get_qkv_format(qkv_layout) - _, seqlen, num_head, hidden = data_shape + batch, seqlen, num_head, hidden = data_shape num_kv_heads = num_head // kv_groups + scaling_factor = 1.0 / np.sqrt(num_head) + + if not is_fused_attn_kernel_available( + dtype, + dtype, + qkv_layout, + attn_bias_type, + attn_mask_type, + dropout_prob, + num_head, + num_kv_heads, + seqlen, + seqlen, + hidden, + None, # no window + cp_size > 1, + ): + pytest.skip(f"No FusedAttn backend found") + + if dp_size > 1 and batch % dp_size != 0: + pytest.skip(f"Skipping {batch=} not a multiple of {dp_size=}") - # make sure the mesh evently divides cp and tp axis + # make sure the mesh even divides cp and tp axis if num_head % kv_groups != 0 or (num_head // kv_groups) % tp_size != 0: pytest.skip(f"Skipping {kv_groups=} not multiple of {data_shape=} or {tp_size=}") def target_func(q, k, v, mask): - return jnp.mean( - fused_attn( - self.qkv_to_layout(q, k, v, qkv_layout), - bias=None, - mask=mask, - seed=None, - attn_bias_type=attn_bias_type, - attn_mask_type=attn_mask_type, - qkv_layout=qkv_layout, - scaling_factor=scaling_factor, - dropout_probability=dropout_prob, - is_training=is_training, - context_parallel_causal_load_balanced=load_balanced, - ), + return fused_attn( + self.qkv_to_layout(q, k, v, qkv_layout), + None, # bias + mask, + None, # seed + attn_bias_type=attn_bias_type, + attn_mask_type=attn_mask_type, + qkv_layout=qkv_layout, + scaling_factor=scaling_factor, + dropout_probability=dropout_prob, + is_training=is_training, + context_parallel_causal_load_balanced=load_balanced, + context_parallel_axis="cp", ).astype(dtype) - def ref_func(q, k, v, mask, kv_groups): - q = jnp.squeeze(q) - k = jnp.squeeze(jnp.repeat(k, kv_groups, axis=2)) - v = jnp.squeeze(jnp.repeat(v, kv_groups, axis=2)) - output = dot_product_attention( + def ref_func(q, k, v, mask): + output = general_dot_product_attention( q, k, v, bias=None, mask=mask, - deterministic=is_training, + deterministic=not is_training, + scale_factor=scaling_factor, dropout_rate=dropout_prob, dropout_rng=None, dtype=jnp.float32, ) - return jnp.mean(output).astype(dtype) + return output.astype(dtype) + + def grad_func(func, *args, **kwargs): + # Gradient is small, use a gradient multiplier to amplify the gradient + _, max_seq_len, num_heads, _ = data_shape + gradient_multiplier = max_seq_len * num_heads + if attn_mask_type in [AttnMaskType.CAUSAL_MASK, AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK]: + gradient_multiplier /= 10 + ret_valid = func(*args, **kwargs) + return (jnp.mean(ret_valid, dtype=jnp.float32) * gradient_multiplier).astype(dtype) q, k, v, mask = self.generate_inputs(data_shape, kv_groups, attn_mask_type, dtype) + diff_argnums = (0, 1, 2) + # Single GPU (reference) - ref_func_jit = jax.jit(jax.value_and_grad(ref_func, argnums=[0, 1, 2]), static_argnums=[4]) - ref_fwd, ref_grads = ref_func_jit(q, k, v, mask, kv_groups) + ref_func_jit = jax.jit( + jax.value_and_grad( + lambda q, k, v, mask: grad_func(ref_func, q, k, v, mask), argnums=diff_argnums + ) + ) + ref_fwd, ref_grads = ref_func_jit(q, k, v, mask) # Multi GPU (function under test) devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape) mesh = Mesh(devices, mesh_axes) - with mesh, fp8_autocast(mesh_resource=mesh_resource): + with mesh, fp8_autocast(mesh_resource=mesh_resource, enabled=False): qkv_ps = PartitionSpec( mesh_resource.dp_resource, mesh_resource.cp_resource, @@ -478,7 +532,10 @@ def ref_func(q, k, v, mask, kv_groups): mask_ = jax.device_put(mask, device=mask_sharding) target_func_jit = jax.jit( - jax.value_and_grad(target_func, argnums=[0, 1, 2]), + jax.value_and_grad( + lambda q, k, v, mask: grad_func(target_func, q, k, v, mask), + argnums=diff_argnums, + ), in_shardings=[qkv_sharding, qkv_sharding, qkv_sharding, mask_sharding], out_shardings=(None, (qkv_sharding, qkv_sharding, qkv_sharding)), ) @@ -489,37 +546,25 @@ def ref_func(q, k, v, mask, kv_groups): target_dq, target_dk, target_dv = jax.tree.map(inverse_reorder, target_grads[0:3]) target_grads = (target_dq, target_dk, target_dv, *target_grads[3:]) - def _print_diffs(target, ref): - print("min: ", jnp.min(target), jnp.min(ref)) - print("max: ", jnp.max(target), jnp.max(ref)) - print("mean: ", jnp.mean(target), jnp.mean(ref)) - print("median: ", jnp.median(target), jnp.median(ref)) - print("std: ", jnp.std(target), jnp.std(ref)) - print("var: ", jnp.var(target), jnp.var(ref)) - print("max diff: ", jnp.max(jnp.abs(target - ref))) - has_diffs = False - try: - assert_allclose(target_fwd, ref_fwd, dtype=dtype) - except AssertionError as e: - has_diffs = True - print(f"target_fwd v. ref_fwd") - _print_diffs(target_fwd, ref_fwd) + print_debug_tensor_stats("target", target_fwd) + print_debug_tensor_stats("ref", ref_fwd) + print_debug_tensor_stats("diff", jnp.abs(target_fwd - ref_fwd)) + assert_allclose(target_fwd, ref_fwd, dtype=dtype) for i in range(len(target_grads)): if ref_grads[i] is None or target_grads[i] is None: # expect both none if one is assert target_grads[i] is None and ref_grads[i] is None else: - try: - assert_allclose(target_grads[i], ref_grads[i]) - except AssertionError as e: - has_diffs = True - print(f"target_grads[{i}] v. ref_grads[{i}]") - _print_diffs(target_grads[i], ref_grads[i]) - - assert has_diffs == False, "has_diffs != False" + print_debug_tensor_stats(f"target_grad[{i}]", target_grads[i]) + print_debug_tensor_stats(f"ref_grad[{i}]", ref_grads[i]) + print_debug_tensor_stats( + f"diff_grad[{i}]", jnp.abs(target_grads[i] - ref_grads[i]) + ) + + assert_allclose(target_grads[i], ref_grads[i], dtype=dtype) class TestReorderCausalLoadBalancing: diff --git a/tests/jax/test_fused_attn.py b/tests/jax/test_fused_attn.py index 2709eeab30..d4f92e940d 100644 --- a/tests/jax/test_fused_attn.py +++ b/tests/jax/test_fused_attn.py @@ -6,6 +6,7 @@ from dataclasses import dataclass from functools import partial from math import sqrt +from typing import Tuple, Optional import jax import jax.numpy as jnp @@ -27,6 +28,7 @@ fused_attn, fused_attn_thd, get_qkv_format, + make_swa_mask, ) from transformer_engine.jax.cpp_extensions import FusedAttnHelper from transformer_engine.transformer_engine_jax import ( @@ -123,6 +125,7 @@ def make_mask( segment_pad_q: ArrayLike, segment_pad_kv: ArrayLike, attn_mask_type: AttnMaskType, + window_size: Optional[Tuple[int, int]] = None, ) -> Array: """ Create attention mask based on mask type. A `True` value in the mask means @@ -140,6 +143,15 @@ def make_mask( segment_pad_q, segment_pad_kv, lambda x, y: jnp.logical_and(x != 1, y != 1) ) inv_mask = combine_masks(inv_pad_mask, inv_mask) + + if window_size is not None: + max_seqlen_q = inv_mask.shape[-2] + max_seqlen_kv = inv_mask.shape[-1] + inv_swa_mask = make_swa_mask(max_seqlen_q, max_seqlen_kv, window_size, attn_mask_type) + inv_swa_mask = jnp.broadcast_to(inv_swa_mask, inv_mask.shape) + # In inv_swa_mask and inv_mask 0 is masked out + inv_mask = jnp.where(inv_mask != 0, inv_swa_mask, inv_mask) + mask = jnp.logical_not(inv_mask) return mask @@ -274,6 +286,7 @@ class FusedAttnRunner: is_training: bool qkv_layout: QKVLayout bias_shape: BiasShape + window_size: Optional[Tuple[int, int]] = None # See https://docs.nvidia.com/deeplearning/cudnn/latest/release-notes.html#cudnn-9-4-0 for known issue # generating zero-length ragged tensors. This setting adjusts the test to avoid the zero-length cases. @@ -298,6 +311,11 @@ def _check_configs(self): if self.max_seqlen_q != self.max_seqlen_kv: pytest.skip("QKVPACKED layout requires max_seqlen_q and max_seqlen_kv to be equal.") + if self.max_seqlen_q > self.max_seqlen_kv and self.window_size is not None: + pytest.skip( + "seqlen_q > seqlen_kv is not supported with sliding window attention in cuDNN" + ) + self.backend = FusedAttnHelper( self.dtype, self.dtype, @@ -310,6 +328,7 @@ def _check_configs(self): self.max_seqlen_q, self.max_seqlen_kv, self.head_dim, + (-1, -1) if self.window_size is None else self.window_size, ).get_fused_attn_backend() if self.backend == NVTE_Fused_Attn_Backend.NVTE_No_Backend: pytest.skip("Unsupported inputs combination or device compute capability.") @@ -456,6 +475,7 @@ def generate_random_segment_ids( self.segment_pad_q, self.segment_pad_kv, self.attn_mask_type, + self.window_size, ) if get_qkv_format(self.qkv_layout) == QKVFormat.THD: @@ -500,6 +520,7 @@ def test_forward(self): "is_training": self.is_training, "qkv_layout": self.qkv_layout, "max_segments_per_seq": self._get_max_segments_per_sequence(), + "window_size": self.window_size, } # Convert the outputs to float32 for the elementwise comparison @@ -557,6 +578,7 @@ def grad_func(func, *args, **kwargs): "is_training": self.is_training, "qkv_layout": self.qkv_layout, "max_segments_per_seq": self._get_max_segments_per_sequence(), + "window_size": self.window_size, } # We can compute dBias only for the [1, h, s, s] layout @@ -668,7 +690,7 @@ def check_dqkv(primitive, reference, pad): pytest.param(4, 128, 128, 16, 16, 64, jnp.bfloat16, id="4-128-128-16-16-64-BF16-SELF"), pytest.param(4, 128, 128, 16, 16, 64, jnp.float16, id="4-128-128-16-16-64-FP16-SELF"), pytest.param(2, 2048, 2048, 12, 12, 64, jnp.bfloat16, id="2-2048-2048-12-12-64-BF16-SELF"), - pytest.param(4, 512, 128, 16, 16, 64, jnp.bfloat16, id="4-512-128-16-16-64-BF16-CROSS"), + pytest.param(4, 128, 256, 16, 16, 64, jnp.bfloat16, id="4-128-256-16-16-64-BF16-CROSS"), pytest.param( 2, 2048, @@ -677,7 +699,7 @@ def check_dqkv(primitive, reference, pad): 12, 64, jnp.bfloat16, - id="2-2048-1048-12-12-64-BF16-CROSS", + id="2-2048-1024-12-12-64-BF16-CROSS", ), pytest.param(4, 128, 128, 16, 8, 64, jnp.bfloat16, id="4-128-128-16-8-64-BF16-GQA"), pytest.param(2, 2048, 2048, 12, 6, 64, jnp.bfloat16, id="2-2048-2048-12-6-64-BF16-GQA"), @@ -690,6 +712,13 @@ def check_dqkv(primitive, reference, pad): pytest.param(0.1, id="DROP_0.1"), ], ) +@pytest.mark.parametrize( + "swa", + [ + pytest.param(False, id="NO_SWA"), + pytest.param(True, id="SWA"), + ], +) class TestFusedAttn: """ Fused attention tester @@ -717,12 +746,16 @@ def _test_forward( is_training, qkv_layout, bias_shape, + swa, ): """ Test forward with parameterized configs This test is not intended to run automatically during CI as it is time-consuming It is kept for development and debugging """ + window_size = None + if swa: + window_size = (s_kv // 10, 0) runner = FusedAttnRunner( b, s_q, @@ -737,6 +770,7 @@ def _test_forward( is_training, qkv_layout, bias_shape, + window_size, ) runner.test_forward() @@ -754,10 +788,14 @@ def test_backward( dtype, qkv_layout, bias_shape, + swa, ): """ Test backward with parameterized configs """ + window_size = None + if swa: + window_size = (s_kv // 10, 0) runner = FusedAttnRunner( b, s_q, @@ -772,5 +810,6 @@ def test_backward( True, qkv_layout, bias_shape, + window_size, ) runner.test_backward() diff --git a/tests/jax/test_layer.py b/tests/jax/test_layer.py index fa04382d59..3245bca676 100644 --- a/tests/jax/test_layer.py +++ b/tests/jax/test_layer.py @@ -4,7 +4,7 @@ """Test transformer_engine.jax.flax.TransformerLayer""" import os from functools import partial -from typing import Dict +from typing import Dict, Tuple import flax import jax @@ -61,6 +61,7 @@ def enable_fused_attn(): _KEY_OF_FLOAT32_ATTENTION_LOGITS = "float32_attention_logits" _KEY_OF_USE_BIAS = "use_bias" _KEY_OF_RELATIVE_EMBEDDING = "enable_relative_embedding" +_KEY_OF_WINDOW_SIZE = "window_size" BASE_ATTRS = { _KEY_OF_TRANSPOSE_BS: True, @@ -70,6 +71,7 @@ def enable_fused_attn(): _KEY_OF_INTERMEDIATE_DROPOUT: 0, _KEY_OF_SELF_ATTN_MASK_TYPE: "padding_causal", _KEY_OF_LAYERNORM_TYPE: "layernorm", + _KEY_OF_WINDOW_SIZE: (-1, -1), } ATTRS = [ @@ -193,6 +195,19 @@ def enable_fused_attn(): { _KEY_OF_MLP_ACTIVATIONS: (("relu", "relu")), }, + { + _KEY_OF_TRANSPOSE_BS: False, + _KEY_OF_RELATIVE_EMBEDDING: False, + _KEY_OF_SELF_ATTN_MASK_TYPE: "causal", + _KEY_OF_WINDOW_SIZE: (64, 0), # Left size must < DATA_SHAPE seqlen + _KEY_OF_FLOAT32_ATTENTION_LOGITS: True, + }, + { + _KEY_OF_TRANSPOSE_BS: False, + _KEY_OF_RELATIVE_EMBEDDING: False, + _KEY_OF_SELF_ATTN_MASK_TYPE: "padding", + _KEY_OF_WINDOW_SIZE: (2, 2), + }, ] ATTRS = [{**BASE_ATTRS, **attr} for attr in ATTRS] @@ -326,7 +341,7 @@ def generate_inputs(self, data_shape, dtype): padded_mask = jnp.zeros((batch, 1, seqlen, seqlen), dtype=jnp.uint8) causal_mask = jnp.triu(jnp.ones((batch, 1, seqlen, seqlen), dtype=jnp.uint8), k=1) - if self.attrs[_KEY_OF_SELF_ATTN_MASK_TYPE] in ["casual", "padding_causal"]: + if self.attrs[_KEY_OF_SELF_ATTN_MASK_TYPE] in ["causal", "padding_causal"]: mask = causal_mask else: mask = padded_mask @@ -379,7 +394,7 @@ def generate_inputs(self, data_shape, dtype): padded_mask = jnp.zeros((batch, 1, seqlen, seqlen), dtype=jnp.uint8) causal_mask = jnp.triu(jnp.ones((batch, 1, seqlen, seqlen), dtype=jnp.uint8), k=1) - if self.attrs[_KEY_OF_SELF_ATTN_MASK_TYPE] in ["casual", "padding_causal"]: + if self.attrs[_KEY_OF_SELF_ATTN_MASK_TYPE] in ["causal", "padding_causal"]: self_mask = causal_mask else: self_mask = padded_mask diff --git a/tests/jax/test_praxis_layers.py b/tests/jax/test_praxis_layers.py index ccab73088a..8ac8ecbe79 100644 --- a/tests/jax/test_praxis_layers.py +++ b/tests/jax/test_praxis_layers.py @@ -4,7 +4,7 @@ import os from functools import partial -from typing import Dict +from typing import Dict, Tuple import flax import jax @@ -645,6 +645,7 @@ class DotProductAttnAttr: NUM_GQA_GROUPS = "num_gqa_groups" TRANSPOSE_BS = "transpose_batch_sequence" SCALE_FACTOR = "scale_factor" + WINDOW_SIZE = "window_size" ATTRS = [ { ATTN_MASK_TYPE: "padding", @@ -681,6 +682,12 @@ class DotProductAttnAttr: TRANSPOSE_BS: False, SCALE_FACTOR: 1.0, }, + { + ATTN_MASK_TYPE: "causal", + TRANSPOSE_BS: False, + SCALE_FACTOR: 1.0, + WINDOW_SIZE: (64, 0), # Left size must <= S in DATA_SHAPE + }, ] @@ -707,6 +714,7 @@ def generate_praxis_p_and_flax_cls(self, dtype, attrs): num_gqa_groups = num_attention_heads attn_mask_type = attrs[DotProductAttnAttr.ATTN_MASK_TYPE] transpose_batch_sequence = attrs[DotProductAttnAttr.TRANSPOSE_BS] + window_size = attrs.get(DotProductAttnAttr.WINDOW_SIZE, None) praxis_p = pax_fiddle.Config( DotProductAttention, @@ -717,6 +725,7 @@ def generate_praxis_p_and_flax_cls(self, dtype, attrs): num_gqa_groups=num_gqa_groups, attn_mask_type=attn_mask_type, transpose_batch_sequence=transpose_batch_sequence, + window_size=window_size, ) flax_cls = partial( flax_DotProductAttention, @@ -726,6 +735,7 @@ def generate_praxis_p_and_flax_cls(self, dtype, attrs): num_gqa_groups=num_gqa_groups, attn_mask_type=attn_mask_type, transpose_batch_sequence=transpose_batch_sequence, + window_size=window_size, ) return praxis_p, flax_cls @@ -750,6 +760,7 @@ class MultiHeadAttnAttr: ENABLE_ROPE = "enable_rotary_pos_emb" ROPE_GROUP_METHOD = "rotary_pos_emb_group_method" LORA_SCOPE = "low_rank_adaptation_scope" + WINDOW_SIZE = "window_size" ATTRS = [ { USE_BIAS: True, @@ -858,6 +869,17 @@ class MultiHeadAttnAttr: LORA_SCOPE: "all", TRANSPOSE_BS: True, }, + { + USE_BIAS: True, + LN_TYPE: "layernorm", + ZERO_CEN: False, + ENABLE_ROPE: False, + ROPE_GROUP_METHOD: "consecutive", + ATTN_MASK_TYPE: "causal", + LORA_SCOPE: "all", + TRANSPOSE_BS: True, + WINDOW_SIZE: (64, 0), # Left size must <= S in DATA_SHAPE + }, ] @@ -899,6 +921,7 @@ def generate_praxis_p_and_flax_cls(self, dtype, attrs): scale_attn_logits = False scaled_query_init = True float32_logits = False + window_size = attrs.get(MultiHeadAttnAttr.WINDOW_SIZE, None) praxis_p = pax_fiddle.Config( MultiHeadAttention, @@ -923,6 +946,7 @@ def generate_praxis_p_and_flax_cls(self, dtype, attrs): scale_attn_logits=scale_attn_logits, scaled_query_init=scaled_query_init, float32_logits=float32_logits, + window_size=window_size, ) flax_cls = partial( flax_MultiHeadAttention, @@ -946,6 +970,7 @@ def generate_praxis_p_and_flax_cls(self, dtype, attrs): scale_attn_logits=scale_attn_logits, scaled_query_init=scaled_query_init, float32_logits=float32_logits, + window_size=window_size, ) return praxis_p, flax_cls @@ -983,6 +1008,7 @@ class TransformerLayerAttr: ENABLE_ROPE = "enable_rotary_pos_emb" ROPE_GROUP_METHOD = "rotary_pos_emb_group_method" LORA_SCOPE = "low_rank_adaptation_scope" + WINDOW_SIZE = "window_size" ATTRS = [ { USE_BIAS: True, @@ -1246,6 +1272,28 @@ class TransformerLayerAttr: TRANSPOSE_BS: False, LORA_SCOPE: "all", }, + { + USE_BIAS: True, + LN_TYPE: "layernorm", + ZERO_CEN: False, + ACTIVATION: ("relu",), + LYR_TYPE: TransformerLayerType.ENCODER, + ENABLE_ROPE: False, + ROPE_GROUP_METHOD: "consecutive", + TRANSPOSE_BS: False, + WINDOW_SIZE: (64, 0), # Left size must <= S in DATA_SHAPE + }, + { + USE_BIAS: True, + LN_TYPE: "layernorm", + ZERO_CEN: False, + ACTIVATION: ("relu",), + LYR_TYPE: TransformerLayerType.DECODER, + ENABLE_ROPE: False, + ROPE_GROUP_METHOD: "consecutive", + TRANSPOSE_BS: False, + WINDOW_SIZE: (64, 0), # Left size must <= S in DATA_SHAPE + }, ] @@ -1289,6 +1337,7 @@ def generate_praxis_p_and_flax_cls(self, dtype, attrs): ) drop_path = 0.0 transpose_batch_sequence = attrs[TransformerLayerAttr.TRANSPOSE_BS] + window_size = attrs.get(TransformerLayerAttr.WINDOW_SIZE, None) rel_embedding_init = RelativePositionBiases.generate_embedding_init( relative_embedding.embedding_init, @@ -1330,6 +1379,7 @@ def generate_praxis_p_and_flax_cls(self, dtype, attrs): relative_embedding=relative_embedding, drop_path=drop_path, transpose_batch_sequence=transpose_batch_sequence, + window_size=window_size, ) flax_cls = partial( flax_TransformerLayer, @@ -1358,6 +1408,7 @@ def generate_praxis_p_and_flax_cls(self, dtype, attrs): low_rank_adaptation_scope=low_rank_adaptation_scope, drop_path=drop_path, transpose_batch_sequence=transpose_batch_sequence, + window_size=window_size, ) return praxis_p, flax_cls diff --git a/tests/jax/utils.py b/tests/jax/utils.py index 798c2a82ba..78a6225e1f 100644 --- a/tests/jax/utils.py +++ b/tests/jax/utils.py @@ -7,6 +7,7 @@ import math import operator from typing import Any, Callable, Dict, Tuple, Sequence, Union, Iterable, Optional +import os import jax import jax.numpy as jnp @@ -18,6 +19,7 @@ from jax import nn as jax_nn from jax import random as jax_random +from transformer_engine.jax.attention import AttnMaskType, make_swa_mask from transformer_engine.jax.fp8 import DType as TEDType PRNGKey = Any @@ -29,6 +31,9 @@ ] Initializer = Callable[[PRNGKey, Shape, DType], Array] +# Enables verbose printing of tensor numerics for debug. +NVTE_DEBUG_NUMERICS = bool(int(os.getenv("NVTE_DEBUG_NUMERICS", 0))) + def is_devices_enough(required): """ @@ -902,6 +907,33 @@ def __call__(self, qlen, klen, bidirectional=True): return values[jnp.newaxis, ...] +def apply_swa_mask( + attn_mask_type: str, + original_mask: Array, + window_size: Tuple[int, int] = (-1, -1), +) -> Array: + """Apply the sliding window mask to a given mask""" + mask_map = { + "no_mask": AttnMaskType.NO_MASK, + "padding": AttnMaskType.PADDING_MASK, + "causal": AttnMaskType.CAUSAL_MASK, + "padding_causal": AttnMaskType.PADDING_CAUSAL_MASK, + "causal_bottom_right": AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK, + "padding_causal_bottom_right": AttnMaskType.PADDING_CAUSAL_BOTTOM_RIGHT_MASK, + } + _attn_mask_type = mask_map.get(attn_mask_type, None) + assert _attn_mask_type is not None + max_seqlen_q = original_mask.shape[-2] + max_seqlen_kv = original_mask.shape[-1] + swa_mask = make_swa_mask( + max_seqlen_q, max_seqlen_kv, window_size, _attn_mask_type, dtype=original_mask.dtype + ) + # In swa_mask and original_mask 0 is masked out + swa_mask_bcast = jnp.broadcast_to(swa_mask, original_mask.shape) + new_mask = jnp.where(original_mask == 1, swa_mask_bcast, original_mask) + return new_mask + + class EncoderLayer(nn.Module): """Transformer encoder layer.""" @@ -934,7 +966,8 @@ class EncoderLayer(nn.Module): fuse_qkv_params: bool = True fuse_mlp_wi: bool = True self_attn_bias_type: Any = None - self_attn_mask_type: Any = None + self_attn_mask_type: str = "no_mask" + window_size: Tuple[int, int] = (-1, -1) def __post_init__(self): if self.num_gqa_groups is None: @@ -943,7 +976,13 @@ def __post_init__(self): @nn.compact def __call__(self, inputs, encoder_mask=None, deterministic=False): - del self.self_attn_mask_type # dummy, just align to TE's impl + # Currently cuDNN backend only supports SWA for causal/padding_causal, follow this + encoder_mask = apply_swa_mask( + self.self_attn_mask_type, + encoder_mask, + self.window_size, + ) + # Relative position embedding as attention biases. sequence_dim = 0 if self.transpose_batch_sequence else 1 batch_dim = 1 - sequence_dim @@ -1087,7 +1126,8 @@ class DecoderLayer(nn.Module): fuse_qkv_params: bool = True fuse_mlp_wi: bool = True self_attn_bias_type: Any = None - self_attn_mask_type: Any = None + self_attn_mask_type: str = "no_mask" + window_size: Tuple[int, int] = (-1, -1) def __post_init__(self): if self.num_gqa_groups is None: @@ -1105,7 +1145,18 @@ def __call__( decode=False, max_decode_length=None, ): - del self.self_attn_mask_type # dummy, just align to TE's impl + decoder_mask = apply_swa_mask( + self.self_attn_mask_type, + decoder_mask, + self.window_size, + ) + + encoder_decoder_mask = apply_swa_mask( + "padding", + encoder_decoder_mask, + self.window_size, + ) + # Relative position embedding as attention biases. sequence_dim = 0 if self.transpose_batch_sequence else 1 batch_dim = 1 - sequence_dim @@ -1419,3 +1470,23 @@ def sync_params_values(dst, src, transformations, sep="/"): synced_dst = jax.tree_util.tree_unflatten(dst_tree_def, synced_dst_values) return jax.tree_util.tree_map(lambda x, y: x.reshape(y.shape), synced_dst, dst) + + +@functools.partial(jax.jit, static_argnums=[0, 2]) +def print_debug_tensor_stats(prefix, tensor, hist=False): + if NVTE_DEBUG_NUMERICS: + args = [ + jnp.mean(tensor), + jnp.min(tensor), + jnp.max(tensor), + jnp.cumprod(jnp.array(tensor.shape))[-1] if len(tensor.shape) >= 1 else 1, + jnp.count_nonzero(tensor), + ] + fmt = prefix + " mean={}, min={}, max={}, numel={}, nzcnt={}" + + if hist: + h = jnp.histogram(tensor.astype(jnp.float32), bins=10) + args += [h[0], h[1]] + fmt = fmt + "\n {}\n {}" + + jax.debug.print(fmt, *args) diff --git a/tests/pytorch/custom_ort_ops/.gitignore b/tests/pytorch/custom_ort_ops/.gitignore new file mode 100644 index 0000000000..d491fb774c --- /dev/null +++ b/tests/pytorch/custom_ort_ops/.gitignore @@ -0,0 +1,3 @@ +build +onnxruntime +libcustom_ort_ops.so diff --git a/tests/pytorch/custom_ort_ops/CMakeLists.txt b/tests/pytorch/custom_ort_ops/CMakeLists.txt new file mode 100644 index 0000000000..90fb3624c1 --- /dev/null +++ b/tests/pytorch/custom_ort_ops/CMakeLists.txt @@ -0,0 +1,29 @@ +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +cmake_minimum_required(VERSION 3.21) +project(custom_ort_ops LANGUAGES CXX) + +# Dependencies +find_package(CUDAToolkit REQUIRED) +set(ONNX_INCLUDE_DIR ${CMAKE_SOURCE_DIR}/onnxruntime/include) +if(NOT EXISTS "${ONNX_INCLUDE_DIR}") + message(FATAL_ERROR + "Could not find ONNX Runtime headers. " + "Please clone https://github.com/microsoft/onnxruntime " + "into TransformerEngine/tests/pytorch/onnx.") +endif() +include_directories(${ONNX_INCLUDE_DIR}) + +# Configure library +add_library(custom_ort_ops SHARED custom_op_library.cc) +target_link_libraries(custom_ort_ops PUBLIC CUDA::cudart) +target_include_directories(custom_ort_ops PUBLIC + ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) +target_include_directories(custom_ort_ops PRIVATE + ${ONNX_INCLUDE_DIR}/onnxruntime + ${ONNX_INCLUDE_DIR}/onnxruntime/core/session) + +# Install library +install(TARGETS custom_ort_ops DESTINATION .) diff --git a/tests/pytorch/custom_ort_ops/README.md b/tests/pytorch/custom_ort_ops/README.md new file mode 100644 index 0000000000..ca392805be --- /dev/null +++ b/tests/pytorch/custom_ort_ops/README.md @@ -0,0 +1,22 @@ +# Custom ONNX Runtime operators for Transformer Engine tests + +This directory contains code that builds custom ONNX operators for use +in Transformer Engine tests. It includes basic, non-performant +implementations of the FP8 quantization and dequantization operators +that are used when exporting Transformer Engine models to ONNX. + +For more information, see [the ONNX Runtime reference for custom +operators](https://onnxruntime.ai/docs/reference/operators/add-custom-op.html). +Much of the code has been adapted from [an ONNX Runtime +test](https://github.com/microsoft/onnxruntime/blob/de93f40240459953a6e3bbb86b6ad83eaeab681f/onnxruntime/test/testdata/custom_op_library/custom_op_library.cc). + +## Usage + +* Build the custom operators: +```bash +$ bash TransformerEngine/tests/pytorch/custom_ort_ops/build.sh +``` +* Run the ONNX export tests with pytest: +```bash +$ python -m pytest TransformerEngine/tests/pytorch/test_onnx_export.py +``` \ No newline at end of file diff --git a/tests/pytorch/custom_ort_ops/build.sh b/tests/pytorch/custom_ort_ops/build.sh new file mode 100644 index 0000000000..989da2f4ef --- /dev/null +++ b/tests/pytorch/custom_ort_ops/build.sh @@ -0,0 +1,17 @@ +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +set -ex + +: ${CUSTOM_ORT_OPS_PATH=$(dirname $(realpath $0))} +cd ${CUSTOM_ORT_OPS_PATH} + +# Download ONNX Runtime source +git clone --depth=1 -b rel-1.19.2 --single-branch https://github.com/microsoft/onnxruntime.git || true + +# Configure and build with CMake +mkdir -p build +cmake -S . -B build -DCMAKE_INSTALL_PREFIX=. +cmake --build build --verbose +cmake --install build --verbose diff --git a/tests/pytorch/custom_ort_ops/custom_op_library.cc b/tests/pytorch/custom_ort_ops/custom_op_library.cc new file mode 100755 index 0000000000..f46e897152 --- /dev/null +++ b/tests/pytorch/custom_ort_ops/custom_op_library.cc @@ -0,0 +1,102 @@ +/************************************************************************* + * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include "custom_op_library.h" + +#define ORT_API_MANUAL_INIT +#include "onnxruntime_c_api.h" +#include "onnxruntime_cxx_api.h" +#undef ORT_API_MANUAL_INIT + +#include +#include +#include +#include +#include + +#include "core/common/common.h" +#include "core/session/onnxruntime_lite_custom_op.h" +#include + +namespace { + +template +void Quantize(OrtKernelContext* context, + const Ort::Custom::Tensor& input, + const Ort::Custom::Tensor& scale_inv, + Ort::Custom::Tensor& output) { + auto raw_input = input.Data(); + auto raw_scale_inv = scale_inv.Data(); + auto raw_output = reinterpret_cast(output.Allocate(input.Shape())); + const auto rs = static_cast(raw_scale_inv[0]); + const size_t N = input.NumberOfElement(); + for (size_t i = 0; i < N; ++i) { + const auto x = static_cast(raw_input[i]); + raw_output[i] = static_cast(x / rs); + } +} + +template +void Dequantize(OrtKernelContext* context, + const Ort::Custom::Tensor& input, + const Ort::Custom::Tensor& scale_inv, + Ort::Custom::Tensor& output) { + auto raw_input = reinterpret_cast(input.Data()); + auto raw_scale_inv = scale_inv.Data(); + auto raw_output = output.Allocate(input.Shape()); + const auto rs = static_cast(raw_scale_inv[0]); + const size_t N = input.NumberOfElement(); + for (size_t i = 0; i < N; ++i) { + const auto x = rs * static_cast(raw_input[i]); + raw_output[i] = static_cast(x); + } +} + +static void AddOrtCustomOpDomainToContainer(Ort::CustomOpDomain&& domain) { + static std::vector ort_custom_op_domain_container; + static std::mutex ort_custom_op_domain_mutex; + std::lock_guard lock(ort_custom_op_domain_mutex); + ort_custom_op_domain_container.push_back(std::move(domain)); +} + +} // namespace + +OrtStatus* ORT_API_CALL RegisterCustomOps(OrtSessionOptions* options, const OrtApiBase* api) { + Ort::Global::api_ = api->GetApi(ORT_API_VERSION); + + // Namespace for custom ops + static const char* c_OpDomain = "trt"; + + // Construct custom ops + static const std::unique_ptr c_Quantize{ + Ort::Custom::CreateLiteCustomOp("TRT_FP8QuantizeLinear", + "CPUExecutionProvider", + Quantize) + }; + static const std::unique_ptr c_Dequantize{ + Ort::Custom::CreateLiteCustomOp("TRT_FP8DequantizeLinear", + "CPUExecutionProvider", + Dequantize<__nv_fp8_e4m3, float, float>) + }; + + // Register custom ops + OrtStatus* result = nullptr; + ORT_TRY { + Ort::CustomOpDomain domain{c_OpDomain}; + domain.Add(c_Quantize.get()); + domain.Add(c_Dequantize.get()); + Ort::UnownedSessionOptions session_options(options); + session_options.Add(domain); + AddOrtCustomOpDomainToContainer(std::move(domain)); + } + ORT_CATCH(const std::exception& e) { + ORT_HANDLE_EXCEPTION([&]() { + Ort::Status status{e}; + result = status.release(); + }); + } + return result; +} diff --git a/tests/pytorch/custom_ort_ops/custom_op_library.h b/tests/pytorch/custom_ort_ops/custom_op_library.h new file mode 100755 index 0000000000..7e4b8256bc --- /dev/null +++ b/tests/pytorch/custom_ort_ops/custom_op_library.h @@ -0,0 +1,18 @@ +/************************************************************************* + * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#pragma once +#include "onnxruntime/core/session/onnxruntime_c_api.h" + +#ifdef __cplusplus +extern "C" { +#endif + +ORT_EXPORT OrtStatus* ORT_API_CALL RegisterCustomOps(OrtSessionOptions* options, const OrtApiBase* api); + +#ifdef __cplusplus +} +#endif diff --git a/tests/pytorch/distributed/run_numerics.py b/tests/pytorch/distributed/run_numerics.py new file mode 100644 index 0000000000..5d2828454c --- /dev/null +++ b/tests/pytorch/distributed/run_numerics.py @@ -0,0 +1,727 @@ +#!/usr/bin/python3 + +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +import sys +import os +import argparse +from functools import wraps + +import transformer_engine.pytorch as te +import torch +from torch import nn +import torch.distributed as dist + +from transformer_engine.common.recipe import Format, DelayedScaling +from run_layer_with_overlap import _compare_tensors + +SEQ_LEN, BATCH_SIZE = 16, 16 +HIDDEN_SIZE = 64 +NR_HEADS = 4 +WORLD_RANK, WORLD_SIZE = None, None +NCCL_WORLD = None +LOSS_FN = nn.MSELoss() +FP8 = False + +# Fp8 recipe setup +fp8_format = Format.HYBRID +fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=32, amax_compute_algo="max") + + +def main(argv=None, namespace=None): + global WORLD_RANK, WORLD_SIZE, NCCL_WORLD, FP8 + + WORLD_RANK = int(os.getenv("RANK", "0")) + WORLD_SIZE = int(os.getenv("WORLD_SIZE", "1")) + LOCAL_RANK = int(os.getenv("LOCAL_RANK", "0")) + LOCAL_SIZE = int(os.getenv("LOCAL_WORLD_SIZE", "1")) + + assert WORLD_SIZE == LOCAL_SIZE # this test supports only 1 node + assert LOCAL_SIZE <= torch.cuda.device_count() + dist_init_kwargs = { + "backend": "nccl", + "rank": WORLD_RANK, + "world_size": WORLD_SIZE, + } + dist_init_kwargs["init_method"] = "env://" + dist_init_kwargs["device_id"] = torch.device(f"cuda:{LOCAL_RANK}") + assert dist.is_nccl_available() + torch.cuda.set_device(LOCAL_RANK) + dist.init_process_group(**dist_init_kwargs) + + NCCL_WORLD = dist.new_group(backend="nccl") + + WORLD_SIZE = dist.get_world_size() + + parser = argparse.ArgumentParser() + parser.add_argument("-l", "--layer-type", type=str) + parser.add_argument("--fp8", action="store_true", default=False) + args = parser.parse_args(argv, namespace) + + test_dict = [ + test_linear, + test_layernorm, + test_layernorm_linear, + test_layernorm_mlp, + test_transformer_layer, + ] + + FP8 = args.fp8 + + for test in test_dict: + test() + dist.destroy_process_group() + return 0 + + +def run_distributed_test(test_name=None): + def decorator(func): + @wraps(func) + def wrapper(*args, **kwargs): + name = test_name if test_name is not None else func.__name__ + + dist_print(f"Starting test {name} with args {args} and {kwargs}") + torch.cuda.set_device(WORLD_RANK) + torch.manual_seed(12345) + torch.cuda.manual_seed(12345) + func(*args, **kwargs) + + dist.barrier() + dist_print(f"Passed test {name}") + + return wrapper + + return decorator + + +def _gather(tensor, dim=0): + """ + Gathers tensors and concats them. Since torch.distributed.nn.functional.all_gather + multiplies gradients by WORLD_SIZE, those gradiedts are rescaled. + """ + + class HalfGradient(torch.autograd.Function): + @staticmethod + def forward(ctx, input): + return input # forward pass (identity) + + @staticmethod + def backward(ctx, grad_output): + return grad_output / WORLD_SIZE # gradient division by WORLD_SIZE + + tensor = HalfGradient.apply(tensor) + gathered = torch.distributed.nn.functional.all_gather(tensor, group=NCCL_WORLD) + return torch.cat(gathered, dim=dim) + + +def _constant(tensor): + return nn.init.constant_(tensor, 0.5) + + +def dist_print(msg, src=None, end="\n", error=False): + stream = sys.stderr if error else sys.stdout + if WORLD_RANK == (0 if src is None else src): + stream.write(f"[rank{WORLD_RANK}] {msg}{end}\n") + dist.barrier() + + +def _get_tolerances(dtype): + if FP8: + return {"rtol": 0.125, "atol": 0.0625} + + if dtype == torch.float16: + return {"rtol": 1e-3, "atol": 1e-5} + if dtype == torch.bfloat16: + return {"rtol": 1.6e-2, "atol": 1e-5} + if dtype == torch.float32: + return {"rtol": 1.3e-6, "atol": 1e-5} + raise ValueError(f"Unsupported dtype ({dtype})") + + +def _check_outputs(output_single_node, output_distributed): + numerics_failed = torch.tensor([0], dtype=torch.uint8, device="cuda") + + output_failed, output_info = _compare_tensors( + "outputs", + output_distributed, + output_single_node, + **_get_tolerances(output_single_node.dtype), + ) + if output_failed: + dist_print(output_info, src=WORLD_RANK, error=output_failed) + numerics_failed[0] = int(output_failed) + dist.all_reduce(numerics_failed, dist.ReduceOp.MAX, NCCL_WORLD) + if bool(numerics_failed.item()): + sys.exit(1) + + +def _match_param_sizes(dist_param, single_param): + """ + Adjust single_param to match the shape of dist_param + by slicing along dimensions where the shapes differ. + This function is typically used in a distributed setting + where single_param is a larger tensor that needs + to be partitioned among multiple processes. + + Args: + dist_param: Tensor representing the distributed output + with the desired shape for the current process. + single_param: Tensor representing the non-distributed output, + possibly larger than dist_param. + + Returns: + Tensor: Sliced version of single_param matching + the shape of dist_param for the current process. + """ + # Initialize indices for slicing with full slices for each dimension + indices = [slice(None)] * len(single_param.shape) + + # Iterate over each dimension to identify where shapes differ + for i in range(len(dist_param.shape)): + if dist_param.shape[i] != single_param.shape[i]: + # Calculate the start and end indices for slicing based on the world rank + start = WORLD_RANK * dist_param.shape[i] + end = (WORLD_RANK + 1) * dist_param.shape[i] + src_slice = slice(start, end) + + # Update the slicing indices for the current dimension + indices[i] = src_slice + + # Slice single_param to obtain the output matching dist_param's shape + to_output = single_param[tuple(indices)] + + return to_output + + +def _check_gradients(model_distributed, model_single, main_grad_check=False): + for i, ((name, param_d), param_s) in enumerate( + zip(model_distributed.named_parameters(), model_single.parameters()) + ): + numerics_failed = torch.tensor([0], dtype=torch.uint8, device="cuda") + grad_failed, grad_info = None, None + if main_grad_check: + param_s_grad = _match_param_sizes(param_d.main_grad, param_s.main_grad) + grad_failed, grad_info = _compare_tensors( + str(i), param_d.main_grad, param_s_grad, **_get_tolerances(param_s_grad.dtype) + ) + else: + param_s_grad = _match_param_sizes(param_d.grad, param_s.grad) + grad_failed, grad_info = _compare_tensors( + str(i), param_d.grad, param_s_grad, **_get_tolerances(param_s_grad.dtype) + ) + + if grad_failed: + dist_print(i) + dist_print(name) + dist_print(grad_info, src=WORLD_RANK, error=grad_failed) + numerics_failed[0] = int(grad_failed) + dist.all_reduce(numerics_failed, dist.ReduceOp.MAX, NCCL_WORLD) + if bool(numerics_failed.item()): + sys.exit(1) + + +def _copy_params(model_distributed, model_single): + for dist_param, single_param in zip(model_distributed.parameters(), model_single.parameters()): + with torch.no_grad(): + to_copy = single_param + for dim, _ in enumerate(dist_param.shape): + if dist_param.shape[dim] != single_param.shape[dim]: + src_slice = slice( + WORLD_RANK * dist_param.shape[dim], (WORLD_RANK + 1) * dist_param.shape[dim] + ) + indices = [slice(None)] * max(min(dim, len(dist_param.shape) - 1), 0) + indices.append(src_slice) + if dim < len(dist_param.shape) - 1: + indices.append(slice(None)) + to_copy = single_param[tuple(indices)] + dist_param.copy_(to_copy) + + +def _apply_models( + model_single_node, model_distributed, input_single_node, input_distributed, **kwargs +): + _alloc_main_grad(model_single_node, model_distributed) # for fuse_wgrad_accumulation=True + with te.fp8_autocast(enabled=FP8, fp8_recipe=fp8_recipe): + output_single_node = model_single_node(input_single_node, **kwargs) + with te.fp8_autocast(enabled=FP8, fp8_recipe=fp8_recipe, fp8_group=NCCL_WORLD): + output_distributed = model_distributed(input_distributed, **kwargs) + return output_single_node, output_distributed + + +def _loss_backward(output_single_node, output_distributed): + target = torch.randn_like(output_single_node) + LOSS_FN(output_single_node, target).backward() + LOSS_FN(output_distributed, target).backward() + + +def _alloc_main_grad(model_single_node, model_distributed): + for model in [model_single_node, model_distributed]: + for param in model.parameters(): + param.main_grad = torch.zeros_like(param, dtype=torch.float32) + + +############################################ +# Linear # +############################################ +@run_distributed_test() +def _test_linear(parallel_mode=None, sequence_parallel=False, **kwargs): + """Test the linear layer with specified parallel mode and sequence parallelization. + + Args: + parallel_mode (str): 'row' or 'column' parallelism. + sequence_parallel (bool): Enable sequence parallelism if True. + kwargs (dict): Additional arguments for the linear layer. + """ + # Set parameter data type + params_dtype = kwargs.get("params_dtype", torch.float32) + + # Create models + model_single_node = te.Linear(HIDDEN_SIZE, HIDDEN_SIZE, **kwargs) + model_distributed = te.Linear( + HIDDEN_SIZE, + HIDDEN_SIZE, + tp_size=WORLD_SIZE, + tp_group=NCCL_WORLD, + parallel_mode=parallel_mode, + sequence_parallel=sequence_parallel, + **kwargs, + ) + + # Synchronize parameters between models + _copy_params(model_distributed, model_single_node) + + # Prepare input tensors + input_single_node = torch.randn((BATCH_SIZE, HIDDEN_SIZE)).cuda().to(params_dtype) + + if parallel_mode == "row": + # Split input across GPUs for row parallelism + split_size = HIDDEN_SIZE // WORLD_SIZE + input_distributed = input_single_node[ + :, WORLD_RANK * split_size : (WORLD_RANK + 1) * split_size + ].clone() + elif parallel_mode == "column": + if sequence_parallel: + # Duplicate input for sequence parallelism + input_single_node = ( + torch.empty((WORLD_SIZE * BATCH_SIZE, HIDDEN_SIZE)).cuda().to(params_dtype) + ) + input_distributed = torch.randn((BATCH_SIZE, HIDDEN_SIZE)).cuda().to(params_dtype) + input_single_node = _gather(input_distributed, dim=0).detach() + else: + input_distributed = input_single_node.clone() + else: + raise ValueError(f"Invalid parallel_mode: {parallel_mode}") + + # Apply models + output_single_node, output_distributed = _apply_models( + model_single_node, model_distributed, input_single_node, input_distributed + ) + + if "return_bias" in kwargs: + output_single_node, bias_s = output_single_node + output_distributed, bias_d = output_distributed + if parallel_mode == "column": + bias_d = _gather(bias_d) + _check_outputs(bias_s, bias_d) + + # Gather outputs if necessary + if parallel_mode == "column" or (sequence_parallel and parallel_mode == "row"): + output_distributed = _gather(output_distributed, dim=1 if parallel_mode == "column" else 0) + + # Compute loss and backpropagate + _loss_backward(output_single_node, output_distributed) + + # Validate outputs and gradients + _check_outputs(output_single_node, output_distributed) + + # gradients in other cases need additional synchronization + if (parallel_mode == "column" or not sequence_parallel) and "return_bias" not in kwargs: + _check_gradients( + model_distributed, + model_single_node, + main_grad_check=("fuse_wgrad_accumulation" in kwargs), + ) + + +def test_linear(): + """Run linear layer tests with various configurations.""" + kwargs_list = [ + {}, + {"bias": False}, + {"init_method": _constant}, + {"fuse_wgrad_accumulation": True}, + {"return_bias": True}, + {"params_dtype": torch.float16}, + ] + for kwargs in kwargs_list: + for parallel_mode in ["column", "row"]: + for sequence_parallel in [False, True]: + _test_linear(parallel_mode, sequence_parallel, **kwargs) + + +############################################ +# LayerNorm # +############################################ + + +@run_distributed_test() +def _test_layernorm(kwargs): + """Test LayerNorm and RMSNorm with given arguments. + + Args: + kwargs (dict): Contains 'norm', 'basic_args', and 'distributed_args'. + """ + # Extract parameters + norm = kwargs["norm"] + basic_args = kwargs["basic_args"] + distributed_args = kwargs["distributed_args"] + params_dtype = basic_args.get("params_dtype", torch.float32) + + # Create models + model_single_node = norm(HIDDEN_SIZE, **basic_args) + model_distributed = norm(HIDDEN_SIZE, **{**basic_args, **distributed_args}) + + # Synchronize parameters between models + _copy_params(model_distributed, model_single_node) + + # Prepare input tensors + input_single_node = torch.randn((BATCH_SIZE, HIDDEN_SIZE), dtype=params_dtype).cuda() + input_distributed = input_single_node.clone() + + # Apply models + output_single_node, output_distributed = _apply_models( + model_single_node, model_distributed, input_single_node, input_distributed + ) + + # Compute loss and backpropagate + _loss_backward(output_single_node, output_distributed) + + # Validate outputs and gradients + _check_outputs(output_single_node, output_distributed) + _check_gradients(model_distributed, model_single_node) + + +def test_layernorm(): + """Run LayerNorm and RMSNorm tests with various configurations.""" + norms = [te.LayerNorm, te.RMSNorm] + + # Define basic arguments for the models + basic_args_list = [ + {"zero_centered_gamma": True}, + {"params_dtype": torch.float16}, + ] + + # Define distributed arguments + distributed_args_list = [ + {}, + {"sequence_parallel": True}, + ] + + # Generate combinations of norms and arguments + for norm in norms: + for basic_args in basic_args_list: + for distributed_args in distributed_args_list: + kwargs = { + "norm": norm, + "basic_args": basic_args, + "distributed_args": distributed_args, + } + _test_layernorm(kwargs) + + +############################################ +# LayerNormLinear # +############################################ + + +@run_distributed_test() +def _test_layernorm_linear(parallel_mode=None, sequence_parallel=False, **kwargs): + """Test the linear layer with specified parallel mode and sequence parallelization. + + Args: + parallel_mode (str): 'row' or 'column' parallelism. + sequence_parallel (bool): Enable sequence parallelism if True. + kwargs (dict): Additional arguments for the linear layer. + """ + # Set parameter data type + params_dtype = kwargs.get("params_dtype", torch.float32) + + # Create models + model_single_node = te.LayerNormLinear(HIDDEN_SIZE, HIDDEN_SIZE, **kwargs) + model_distributed = te.LayerNormLinear( + HIDDEN_SIZE, + HIDDEN_SIZE, + tp_size=WORLD_SIZE, + tp_group=NCCL_WORLD, + parallel_mode=parallel_mode, + sequence_parallel=sequence_parallel, + **kwargs, + ) + + # Synchronize parameters between models + _copy_params(model_distributed, model_single_node) + + # Prepare input tensors + input_single_node = torch.randn((SEQ_LEN, HIDDEN_SIZE)).cuda().to(params_dtype) + + if sequence_parallel: + # Duplicate input for sequence parallelism + input_single_node = torch.empty((WORLD_SIZE * SEQ_LEN, HIDDEN_SIZE)).cuda().to(params_dtype) + input_distributed = torch.randn((SEQ_LEN, HIDDEN_SIZE)).cuda().to(params_dtype) + input_single_node = _gather(input_distributed).detach() + else: + input_distributed = input_single_node.clone() + # Apply models + output_single_node, output_distributed = _apply_models( + model_single_node, model_distributed, input_single_node, input_distributed + ) + + if "return_layernorm_output" in kwargs: + output_single_node, norm_s = output_single_node + output_distributed, norm_d = output_distributed + if sequence_parallel: + norm_d = _gather(norm_d) + _check_outputs(norm_s, norm_d) + + if "return_bias" in kwargs: + output_single_node, bias_s = output_single_node + output_distributed, bias_d = output_distributed + if parallel_mode == "column": + bias_d = _gather(bias_d) + _check_outputs(bias_s, bias_d) + + # Gather outputs if necessary + if parallel_mode == "column" or (sequence_parallel and parallel_mode == "row"): + output_distributed = _gather(output_distributed, dim=1 if parallel_mode == "column" else 0) + + # Compute loss and backpropagate + _loss_backward(output_single_node, output_distributed) + + # Validate outputs and gradients + _check_outputs(output_single_node, output_distributed) + + # gradients in other cases need additional synchronization + if parallel_mode == "column" and not sequence_parallel and "return_bias" not in kwargs: + _check_gradients( + model_distributed, + model_single_node, + main_grad_check=("fuse_wgrad_accumulation" in kwargs), + ) + + +def test_layernorm_linear(): + kwargs_list = [ + {}, + {"bias": False}, + {"init_method": _constant}, + {"fuse_wgrad_accumulation": True}, + {"return_bias": True}, + {"params_dtype": torch.float16}, + {"zero_centered_gamma": False}, + {"return_layernorm_output": True}, + ] + for kwargs in kwargs_list: + for parallel_mode in ["column"]: + for sequence_parallel in [False, True]: + _test_layernorm_linear(parallel_mode, sequence_parallel, **kwargs) + + +############################################ +# LayerNormMLP # +############################################ + + +@run_distributed_test() +def _test_layernorm_mlp(set_parallel_mode=None, sequence_parallel=False, **kwargs): + """Test the LayerNormMLP with specified parallel mode and sequence parallelization. + + Args: + set_parallel_mode (bool): Enable parallel mode. + sequence_parallel (bool): Enable sequence parallelism if True. + kwargs (dict): Additional arguments for the linear layer. + """ + # Set parameter data type + params_dtype = kwargs.get("params_dtype", torch.float32) + FFN_HIDDEN_SIZE = ( + 64 if FP8 else 32 + ) # larger tensors lead to numerical failures with thight atol and rtol + + # Create models + model_single_node = te.LayerNormMLP(HIDDEN_SIZE, FFN_HIDDEN_SIZE, **kwargs) + model_distributed = te.LayerNormMLP( + HIDDEN_SIZE, + FFN_HIDDEN_SIZE, + tp_size=WORLD_SIZE, + tp_group=NCCL_WORLD, + set_parallel_mode=set_parallel_mode, + sequence_parallel=sequence_parallel, + **kwargs, + ) + + # Synchronize parameters between models + _copy_params(model_distributed, model_single_node) + + # Prepare input tensors + input_single_node = torch.randn((SEQ_LEN, HIDDEN_SIZE)).cuda().to(params_dtype) + + if sequence_parallel: + # Duplicate input for sequence parallelism + input_single_node = torch.empty((WORLD_SIZE * SEQ_LEN, HIDDEN_SIZE)).cuda().to(params_dtype) + input_distributed = torch.randn((SEQ_LEN, HIDDEN_SIZE)).cuda().to(params_dtype) + input_single_node = _gather(input_distributed).detach() + else: + input_distributed = input_single_node.clone() + # Apply models + output_single_node, output_distributed = _apply_models( + model_single_node, model_distributed, input_single_node, input_distributed + ) + + if "return_layernorm_output" in kwargs: + output_single_node, norm_s = output_single_node + output_distributed, norm_d = output_distributed + if sequence_parallel: + norm_d = _gather(norm_d) + _check_outputs(norm_s, norm_d) + + if "return_bias" in kwargs: + output_single_node, bias_s = output_single_node + output_distributed, bias_d = output_distributed + _check_outputs(bias_s, bias_d) + + if sequence_parallel: + output_distributed = _gather(output_distributed) + + # Compute loss and backpropagate + _loss_backward(output_single_node, output_distributed) + + # Validate outputs and gradients + _check_outputs(output_single_node, output_distributed) + + # gradients in other cases need additional synchronization + if not sequence_parallel and "return_bias" not in kwargs: + _check_gradients( + model_distributed, + model_single_node, + main_grad_check=("fuse_wgrad_accumulation" in kwargs), + ) + + +def test_layernorm_mlp(): + kwargs_list = [ + {}, + {"init_method": _constant}, + {"output_layer_init_method": _constant}, + {"normalization": "RMSNorm"}, + {"zero_centered_gamma": True}, + {"bias": False}, + {"params_dtype": torch.float16}, + {"activation": "relu"}, + {"fuse_wgrad_accumulation": True}, + {"return_bias": True}, + {"return_layernorm_output": True}, + ] + for kwargs in kwargs_list: + for set_parallel_mode in [True]: + for sequence_parallel in [False, True]: + _test_layernorm_mlp(set_parallel_mode, sequence_parallel, **kwargs) + + +############################################ +# TransformerLayer # +############################################ + + +@run_distributed_test() +def _test_transformer_layer_parallel(sequence_parallel=False, **kwargs): + params_dtype = kwargs.get("params_dtype", torch.float32) + FFN_HIDDEN_SIZE = ( + 64 if FP8 else 32 + ) # larger tensors lead to numerical failures with thight atol and rtol + + model_single_node = te.TransformerLayer( + HIDDEN_SIZE, FFN_HIDDEN_SIZE, NR_HEADS, attention_dropout=0, hidden_dropout=0, **kwargs + ) + model_distributed = te.TransformerLayer( + HIDDEN_SIZE, + FFN_HIDDEN_SIZE, + NR_HEADS, + tp_size=WORLD_SIZE, + tp_group=NCCL_WORLD, + set_parallel_mode=True, + sequence_parallel=sequence_parallel, + seq_length=WORLD_SIZE * SEQ_LEN if sequence_parallel else None, + attention_dropout=0, + hidden_dropout=0, + **kwargs, + ) + + _copy_params(model_distributed, model_single_node) + _alloc_main_grad(model_single_node, model_distributed) # for fuse_wgrad_accumulation=True + + input_single_node = ( + torch.randn((WORLD_SIZE * SEQ_LEN, BATCH_SIZE, HIDDEN_SIZE)).cuda().to(params_dtype) + ) + if sequence_parallel: + input_distributed = input_single_node[ + WORLD_RANK * SEQ_LEN : (WORLD_RANK + 1) * SEQ_LEN, :, : + ] + else: + input_distributed = input_single_node.clone().cuda() + + encoder_output = None + if "layer_type" in kwargs: + encoder_output = torch.randn((SEQ_LEN, BATCH_SIZE, HIDDEN_SIZE)).cuda() + + output_single_node, output_distributed = _apply_models( + model_single_node, + model_distributed, + input_single_node, + input_distributed, + encoder_output=encoder_output, + ) + + if sequence_parallel: + output_distributed = _gather(output_distributed) + + _loss_backward(output_single_node, output_distributed) + _check_outputs(output_single_node, output_distributed) + + # gradients in other cases need additional synchronization + if not sequence_parallel and "return_bias" not in kwargs: + _check_gradients( + model_distributed, + model_single_node, + main_grad_check=("fuse_wgrad_accumulation" in kwargs), + ) + + +def test_transformer_layer(): + kwargs_list = [ + {}, + {"num_gqa_groups": 4}, + {"init_method": _constant}, + {"output_layer_init_method": _constant}, + {"apply_residual_connection_post_layernorm": True}, + {"output_layernorm": True}, + {"parallel_attention_mlp": True}, + # {"layer_type": "decoder"}, + {"window_size": (2, 2)}, + {"normalization": "RMSNorm"}, + {"zero_centered_gamma": True}, + {"fuse_qkv_params": True}, + {"fuse_qkv_params": True, "fuse_wgrad_accumulation": True}, + {"qkv_weight_interleaved": False}, + {"bias": False}, + {"params_dtype": torch.float16}, + {"fuse_qkv_params": True}, + {"activation": "relu"}, + ] + for kwargs in kwargs_list: + for sequence_parallel in [False, True]: + _test_transformer_layer_parallel(sequence_parallel, **kwargs) + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/tests/pytorch/test_fusible_ops_distributed.py b/tests/pytorch/distributed/test_fusible_ops.py similarity index 100% rename from tests/pytorch/test_fusible_ops_distributed.py rename to tests/pytorch/distributed/test_fusible_ops.py diff --git a/tests/pytorch/distributed/test_numerics.py b/tests/pytorch/distributed/test_numerics.py new file mode 100644 index 0000000000..d0b445a505 --- /dev/null +++ b/tests/pytorch/distributed/test_numerics.py @@ -0,0 +1,55 @@ +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +import os +import subprocess +from pathlib import Path + +import pytest +import torch +from transformer_engine.pytorch.fp8 import FP8GlobalStateManager + +""" + Distributed numerics tests + + These tests test the numerical corectness of the TransformerEngine layers. + Tests are parametrized by the layer and fp8 precision. + One test consists of running multiple configurations from file run_numerics.py + Such design is due to the fact the initialization of one test is long + - 2 processes need to start and load torch and TE. Multiple configurations + are run in one test - this reduces the initialization overhead. + +""" + + +if torch.cuda.device_count() < 2: + pytest.skip("Distributed training needs at least 2 GPUs.") + +fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() + +TEST_ROOT = Path(__file__).parent.resolve() +NUM_PROCS: int = min(4, torch.cuda.device_count()) +LAUNCH_CMD = ["torchrun", f"--nproc_per_node={NUM_PROCS}"] + + +def _run_test(fp8): + test_path = TEST_ROOT / "run_numerics.py" + test_cmd = LAUNCH_CMD + [str(test_path)] + + if fp8: + test_cmd += ["--fp8"] + + result = subprocess.run(test_cmd, env=os.environ, capture_output=True, check=False) + if result.returncode != 0 or "NUMERICAL CHECK FAILED" in result.stderr.decode(): + raise AssertionError(result.stderr.decode()) + + +all_boolean = [True, False] + + +@pytest.mark.parametrize("fp8", all_boolean) +def test_distributed(fp8): + if fp8 and not fp8_available: + pytest.skip(reason_for_no_fp8) + _run_test(fp8) diff --git a/tests/pytorch/fused_attn/run_fused_attn_with_cp.py b/tests/pytorch/fused_attn/run_fused_attn_with_cp.py index 6c775fb127..15fb994050 100644 --- a/tests/pytorch/fused_attn/run_fused_attn_with_cp.py +++ b/tests/pytorch/fused_attn/run_fused_attn_with_cp.py @@ -59,6 +59,17 @@ def run_dpa_with_cp( cp_comm_ranks = range(world_size) assert rank in cp_comm_ranks cp_comm_group = dist.new_group(cp_comm_ranks, backend="nccl") + if cp_comm_type == "a2a+p2p": + assert ( + world_size % 2 == 0 + ), "Assuming CP size for A2A is 2, and CP size for P2P is (world_size // 2)!" + cp_comm_sub_ranks = [range(i * 2, (i + 1) * 2) for i in range(world_size // 2)] + cp_comm_sub_ranks += [range(i, world_size, 2) for i in range(2)] + cp_comm_sub_groups = [] + for sub_ranks in cp_comm_sub_ranks: + sub_group = dist.new_group(sub_ranks, backend="nccl") + if rank in sub_ranks: + cp_comm_sub_groups.append(sub_group) if dtype == "fp8": fp8_recipe = DelayedScaling(fp8_dpa=True) @@ -167,13 +178,6 @@ def run_dpa_with_cp( else: bias = None - # make sure all GPU ranks have same inputs - for x in [q, k, v, dout] + ([] if bias is None else [bias]): - dist.broadcast(x, 0, group=cp_comm_group) - if qkv_format == "thd": - for x in [cu_seqlens_q, cu_seqlens_q_padded, cu_seqlens_kv, cu_seqlens_kv_padded]: - dist.broadcast(x, 0, group=cp_comm_group) - # run core_attn without CP for x in [q, k, v]: x.requires_grad = True @@ -239,7 +243,10 @@ def run_dpa_with_cp( bias_ = bias_.index_select(2, seq_idx) bias_ = bias_.view(*bias_.shape[:2], -1, bias_.shape[-1]) core_attn.set_context_parallel_group( - cp_comm_group, cp_comm_ranks, torch.cuda.Stream(), cp_comm_type + cp_comm_sub_groups if cp_comm_type == "a2a+p2p" else cp_comm_group, + cp_comm_ranks, + torch.cuda.Stream(), + cp_comm_type, ) if dtype == "fp8": diff --git a/tests/pytorch/fused_attn/test_fused_attn.py b/tests/pytorch/fused_attn/test_fused_attn.py index d110dece53..4b4eecbf39 100644 --- a/tests/pytorch/fused_attn/test_fused_attn.py +++ b/tests/pytorch/fused_attn/test_fused_attn.py @@ -20,9 +20,8 @@ MultiheadAttention, RotaryPositionEmbedding, get_attention_backend, - _flash_attn_2_plus, _flash_attn_2_3_plus, - _flash_attn_3_plus, + _flash_attn_3_is_installed, check_set_window_size, AttentionParams, _attention_backends, @@ -249,7 +248,7 @@ def test_dot_product_attention( # Test backend availability window_size = (-1, -1) if swa: - window_size = tuple(torch.randint(0, config.max_seqlen_kv, [2], dtype=torch.int32).tolist()) + window_size = [2, 2] config.window_size = check_set_window_size(config.attn_mask_type, window_size) available_backends, fused_attn_backends = _get_attention_backends( config, @@ -1319,6 +1318,8 @@ def _error(a, b, name_a, name_b, atol, rtol, rmse_tol): logging.debug(name_a + " min {:.6f} max {:.6f}".format(a.min().item(), a.max().item())) logging.debug(name_b + " min {:.6f} max {:.6f}".format(b.min().item(), b.max().item())) try: + if a.dtype != b.dtype: + a = a.to(b.dtype) torch.testing.assert_close(a, b, atol=atol, rtol=rtol) except Exception as e: logging.debug(e) @@ -1351,7 +1352,7 @@ def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd, os.environ["NVTE_FP8_DPA_BWD"] = "1" if fp8_dpa_bwd else "0" config = model_configs_fp8_vs_f16[model] - if _flash_attn_3_plus and not is_training: + if _flash_attn_3_is_installed and not is_training: if RoPE: pytest.skip("Flash Attention doesn't support FP8 MHA with RoPE.") os.environ["NVTE_FLASH_ATTN"] = "1" @@ -1379,7 +1380,7 @@ def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd, rtol = 5e-1 rmse_tol = 0.15 logging.debug("========== {:^25s} ==========".format("forward output")) - if _flash_attn_3_plus and not is_training: + if _flash_attn_3_is_installed and not is_training: _error( flash_attn_fwd_fp8, fused_attn_fwd_f16, @@ -1532,7 +1533,7 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training): os.environ["NVTE_FP8_DPA_BWD"] = "1" if fp8_dpa_bwd else "0" os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "1" - if _flash_attn_3_plus and not is_training: + if _flash_attn_3_is_installed and not is_training: os.environ["NVTE_FLASH_ATTN"] = "1" os.environ["NVTE_FUSED_ATTN"] = "0" _attention_backends["backend_selection_requires_update"] = True @@ -1559,7 +1560,7 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training): rmse_tol = 0.1 bwd_names = ["dq", "dk", "dv"] logging.debug("========== {:^25s} ==========".format("forward output")) - if _flash_attn_3_plus and not is_training: + if _flash_attn_3_is_installed and not is_training: _error( flash_attn_fwd_fp8, fused_attn_fwd_f16, @@ -1854,13 +1855,6 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: """Get cuda rng tracker.""" return _DUMMY_CUDA_RNG_STATE_TRACKER - _DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker() - _DUMMY_CUDA_RNG_STATE_TRACKER.add("model-parallel-rng", seed) - - def get_dummy_cuda_rng_tracker(): - """Get cuda rng tracker.""" - return _DUMMY_CUDA_RNG_STATE_TRACKER - block = DotProductAttention( config.num_heads, config.head_dim_qk, diff --git a/tests/pytorch/fused_attn/test_fused_attn_with_cp.py b/tests/pytorch/fused_attn/test_fused_attn_with_cp.py index c1c18ffe47..ea30a4831f 100644 --- a/tests/pytorch/fused_attn/test_fused_attn_with_cp.py +++ b/tests/pytorch/fused_attn/test_fused_attn_with_cp.py @@ -36,8 +36,13 @@ } -def get_bash_arguments(**kwargs): - args = ["python", "-m", "torch.distributed.launch", "--nproc-per-node=2"] +def get_bash_arguments(num_gpus_per_node, **kwargs): + args = [ + "python", + "-m", + "torch.distributed.launch", + "--nproc-per-node=" + str(num_gpus_per_node), + ] te_path = os.getenv("TE_PATH", "/opt/transformerengine") script_path = os.path.join(te_path, "tests/pytorch/fused_attn/run_fused_attn_with_cp.py") args.append(script_path) @@ -51,20 +56,20 @@ def get_bash_arguments(**kwargs): @pytest.mark.parametrize("dtype", ["bf16", "fp16"]) @pytest.mark.parametrize("model", model_configs_flash_attn.keys()) @pytest.mark.parametrize("qkv_format", ["bshd", "sbhd", "thd"]) -@pytest.mark.parametrize("cp_comm_type", ["p2p", "all_gather", "a2a"]) +@pytest.mark.parametrize("cp_comm_type", ["p2p", "all_gather", "a2a", "a2a+p2p"]) def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): config = model_configs_flash_attn[model] - if cp_comm_type == "p2p" and config.window_size != (-1, 0) and config.window_size != (-1, -1): + if "p2p" in cp_comm_type and config.window_size != (-1, 0) and config.window_size != (-1, -1): pytest.skip("CP implementation with KV P2P does not support sliding window yet!") if cp_comm_type == "all_gather" and qkv_format == "thd": pytest.skip("CP implementation with KV all-gather does not support THD format yet!") if cp_comm_type == "all_gather" and config.attn_bias_type != "no_bias": pytest.skip("CP implementation with KV all-gather does not support bias yet!") - if cp_comm_type == "a2a" and qkv_format == "thd": + if "a2a" in cp_comm_type and qkv_format == "thd": pytest.skip("CP implementation with QKVO A2A does not support THD format yet!") - if cp_comm_type == "a2a" and config.attn_bias_type != "no_bias": + if "a2a" in cp_comm_type and config.attn_bias_type != "no_bias": pytest.skip("CP implementation with QKVO A2A does not support bias yet!") - if cp_comm_type == "a2a" and (config.num_heads % 2 != 0 or config.num_gqa_groups % 2 != 0): + if "a2a" in cp_comm_type and (config.num_heads % 2 != 0 or config.num_gqa_groups % 2 != 0): pytest.skip( f"CP implementation with QKVO A2A requires num_heads ({config.num_heads}) and" f" num_gqa_groups ({config.num_gqa_groups}) to be divisible by cp_size (2)!" @@ -72,6 +77,7 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): subprocess.run( get_bash_arguments( + num_gpus_per_node=4 if cp_comm_type == "a2a+p2p" else 2, dtype=dtype, model=model, qkv_format=qkv_format, @@ -106,7 +112,7 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): @pytest.mark.parametrize("dtype", ["bf16", "fp16", "fp8"]) @pytest.mark.parametrize("model", model_configs_fused_attn.keys()) @pytest.mark.parametrize("qkv_format", ["bshd", "sbhd", "thd"]) -@pytest.mark.parametrize("cp_comm_type", ["p2p", "all_gather", "a2a"]) +@pytest.mark.parametrize("cp_comm_type", ["p2p", "all_gather", "a2a", "a2a+p2p"]) def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type): if qkv_format == "thd" and get_device_compute_capability() < (9, 0): pytest.skip("THD format is only supported on sm90+!") @@ -122,7 +128,7 @@ def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type): pytest.skip("THD format does not support post_scale_bias yet!") if qkv_format == "thd" and cp_comm_type == "all_gather": pytest.skip("CP implementation with KV all-gather does not support THD format yet!") - if qkv_format == "thd" and cp_comm_type == "a2a": + if qkv_format == "thd" and "a2a" in cp_comm_type: pytest.skip("CP implementation with QKVO A2A does not support THD format yet!") if config.window_size != (-1, 0) and config.window_size != (-1, -1) and cp_comm_type != "a2a": pytest.skip( @@ -140,9 +146,9 @@ def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type): pytest.skip("FP8 attention cannot work with sliding window yet!") if cp_comm_type == "all_gather" and config.attn_bias_type != "no_bias": pytest.skip("CP implementation with KV all-gather does not support bias yet!") - if cp_comm_type == "a2a" and config.attn_bias_type != "no_bias": + if "a2a" in cp_comm_type and config.attn_bias_type != "no_bias": pytest.skip("CP implementation with QKVO A2A does not support bias yet!") - if cp_comm_type == "a2a" and (config.num_heads % 2 != 0 or config.num_gqa_groups % 2 != 0): + if "a2a" in cp_comm_type and (config.num_heads % 2 != 0 or config.num_gqa_groups % 2 != 0): pytest.skip( f"CP implementation with QKVO A2A requires num_heads ({config.num_heads}) and" f" num_gqa_groups ({config.num_gqa_groups}) to be divisible by cp_size (2)!" @@ -150,6 +156,7 @@ def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type): subprocess.run( get_bash_arguments( + num_gpus_per_node=4 if cp_comm_type == "a2a+p2p" else 2, dtype=dtype, model=model, qkv_format=qkv_format, diff --git a/tests/pytorch/libcustom_ort_fp8_qdq_ops.so b/tests/pytorch/libcustom_ort_fp8_qdq_ops.so deleted file mode 100755 index 61d9232e3a..0000000000 Binary files a/tests/pytorch/libcustom_ort_fp8_qdq_ops.so and /dev/null differ diff --git a/tests/pytorch/test_float8tensor.py b/tests/pytorch/test_float8tensor.py index fd204f58c4..51f4c695dc 100644 --- a/tests/pytorch/test_float8tensor.py +++ b/tests/pytorch/test_float8tensor.py @@ -89,7 +89,7 @@ def _test_quantize_dequantize( fp8_dtype=fp8_dtype, scale=torch.full([1], scale), ) - x_fp8 = x_fp8.from_float8().cpu() + x_fp8 = x_fp8.dequantize().cpu() # Check results torch.testing.assert_close(x_fp8, x_ref, **_tols[fp8_dtype]) @@ -144,7 +144,7 @@ def test_fp8_meta( fp8_meta=fp8_meta, fp8_meta_index=fp8_meta_index, ) - x_ref = x_fp8.from_float8() + x_ref = x_fp8.dequantize() assert list(x_fp8.size()) == dims, "Incorrect dims" assert x_fp8.dtype == dtype, "Incorrect nominal dtype" assert x_fp8.is_cuda, "Incorrect device" @@ -194,8 +194,8 @@ def test_basic_ops( fp8_dtype=fp8_dtype, scale=torch.full([1], scale), ) - x_ref = x_fp8.from_float8() - y_ref = y_fp8.from_float8() + x_ref = x_fp8.dequantize() + y_ref = y_fp8.dequantize() # Exact operations torch.testing.assert_close(-x_fp8, -x_ref, rtol=0, atol=0) @@ -237,23 +237,23 @@ def test_inplace_ops( fp8_dtype=fp8_dtype, scale=torch.full([1], scale), ) - x_ref = x_fp8.from_float8() - y_ref = y_fp8.from_float8() + x_ref = x_fp8.dequantize() + y_ref = y_fp8.dequantize() # In-place operations tols = _tols[fp8_dtype] x_fp8 += y_ref x_ref += y_ref torch.testing.assert_close(x_fp8, x_ref, **tols) - x_ref = x_fp8.from_float8() + x_ref = x_fp8.dequantize() x_fp8 -= y_fp8 x_ref -= y_fp8 torch.testing.assert_close(x_fp8, x_ref, **tols) - x_ref = x_fp8.from_float8() + x_ref = x_fp8.dequantize() x_fp8 *= 2 x_ref *= 2 torch.testing.assert_close(x_fp8, x_ref, **tols) - x_ref = x_fp8.from_float8() + x_ref = x_fp8.dequantize() # Make sure we are not trivially passing tests x_ref += 123 @@ -278,7 +278,7 @@ def test_transpose( fp8_dtype=fp8_dtype, scale=torch.full([1], scale), ) - x = x_fp8.from_float8() + x = x_fp8.dequantize() # Perform transpose x_fp8_t = x_fp8.transpose_2d() @@ -296,7 +296,7 @@ def test_transpose( # Caching test assert x_fp8._transpose_invalid, "Transpose cache must be invalid when not caching." x_fp8 += 0.5 - x = x_fp8.from_float8() + x = x_fp8.dequantize() x_fp8_t = Float8Tensor.make_like(x_fp8, data=x_fp8.transpose_2d(fill_cache=True)) x_t = x.transpose(0, 1) torch.testing.assert_close(x_fp8_t, x_t, **tols) @@ -305,7 +305,7 @@ def test_transpose( # Inplace update test x_fp8 += 0.5 assert not x_fp8._transpose_invalid, "Transpose cache reset incorrectly." - x = x_fp8.from_float8() + x = x_fp8.dequantize() x_fp8_t = Float8Tensor.make_like(x_fp8, data=x_fp8._transpose) x_t = x.transpose(0, 1) torch.testing.assert_close(x_fp8_t, x_t, **tols) @@ -326,7 +326,7 @@ def test_serialization( fp8_dtype=fp8_dtype, scale=torch.full([1], scale), ) - x_ref = x_fp8.from_float8() + x_ref = x_fp8.dequantize() # Serialize tensor byte_stream = io.BytesIO() @@ -351,3 +351,47 @@ def test_serialization( x_fp8._scale_inv.zero_() with pytest.raises(AssertionError): torch.testing.assert_close(x_fp8, x_ref, **tols) + + def test_set_data(self): + """Test directly setting .data attr""" + + # Initialize Float8Tensor + x0 = torch.zeros(4, dtype=torch.float32) + x = Float8Tensor.to_float8(x0) + assert isinstance(x, Float8Tensor) + assert x0.size() == x.size() == x._data.size() + assert x.dtype == torch.float32 + assert x.is_cuda and x._data.is_cuda + y = x.dequantize() + assert not isinstance(y, Float8Tensor) + assert x.size() == y.size() + assert x.dtype == y.dtype + assert x.device == y.device + + # Set data to plain tensor + x0 = torch.zeros((3, 2), dtype=torch.float16, device=x.device) + x.data = x0 + assert isinstance(x, Float8Tensor) + assert x0.size() == x.size() == x._data.size() + assert x0.dtype == x.dtype + assert x0.device == x.device == x._data.device + y = x.dequantize() + assert not isinstance(y, Float8Tensor) + assert x.size() == y.size() + assert x.dtype == y.dtype + assert x.device == y.device + + # Set data to Float8Tensor + x0 = Float8Tensor.to_float8(torch.zeros((4, 3, 1), dtype=torch.float32)) + x.data = x0 + assert isinstance(x, Float8Tensor) + assert x0.size() == x.size() == x._data.size() + assert x0.dtype == x.dtype + assert x0.device == x.device == x._data.device + assert x0._data is x._data + assert x0._scale_inv is x._scale_inv + y = x.dequantize() + assert not isinstance(y, Float8Tensor) + assert x.size() == y.size() + assert x.dtype == y.dtype + assert x.device == y.device diff --git a/tests/pytorch/test_fused_optimizer.py b/tests/pytorch/test_fused_optimizer.py index ee6739fbf6..d19fc5a521 100644 --- a/tests/pytorch/test_fused_optimizer.py +++ b/tests/pytorch/test_fused_optimizer.py @@ -3,9 +3,9 @@ # See LICENSE for license information. from itertools import product -import unittest import copy +import pytest import torch from torch import nn from torch.testing._internal.common_device_type import largeTensorTest @@ -19,14 +19,12 @@ fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() -class TestFusedOptimizer(unittest.TestCase): - def setUp(self, iters=7): +class TestFusedOptimizer: + + def setup_method(self, *, iters: int = 7) -> None: self.iters = iters torch.manual_seed(9876) - def tearDown(self): - pass - def gen_param_optim(self, tensors, options, tst_options=None): # Adding this to make backward compatible with existing tests. Just in @@ -88,8 +86,8 @@ def gen_single_type_test( class TestFusedAdam(TestFusedOptimizer): - def setUp(self): - super().setUp() + def setup_method(self) -> None: + super().setup_method() self.options = { "lr": 5e-4, "betas": (0.9, 0.999), @@ -111,7 +109,7 @@ def test_half(self): def test_bfloat16(self): self.gen_single_type_test(param_type=torch.bfloat16, skip_assert=True) - @unittest.skipIf(torch.cuda.device_count() < 2, "more than 1 GPU required") + @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="more than 1 GPU required") def test_multi_device(self): devices = ("cuda:0", "cuda:1") for current_dev, tensor_dev in product(devices, devices): @@ -176,7 +174,7 @@ def test_frozen_model(self): torch.testing.assert_close(ref_param, tst_param) - @unittest.skipIf(not is_bf16_compatible(), "bf16 if not supported") + @pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 if not supported") def test_bf16_model_weight_cast(self): dtype = torch.bfloat16 model = MultiheadAttention( @@ -214,7 +212,7 @@ def test_bf16_model_weight_cast(self): ref_params, model_params_to_fp32, rtol=1e-3, atol=1e-3, equal_nan=True ) - @unittest.skipIf(not fp8_available, reason=reason_for_no_fp8) + @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) def test_fp8_model_weight_cast(self): dtype = torch.bfloat16 with fp8_model_init(enabled=True): @@ -255,8 +253,9 @@ def test_fp8_model_weight_cast(self): class TestFusedSGD(TestFusedOptimizer): - def __init__(self, *args, **kwargs): - super(TestFusedSGD, self).__init__(*args, **kwargs) + + def setup_method(self) -> None: + super().setup_method() self.options = {"lr": 0.25, "momentum": 0.125} self.ref_optim = torch.optim.SGD self.fused_optim = te.optimizers.FusedSGD @@ -267,7 +266,7 @@ def test_float(self): def test_half(self): self.gen_single_type_test(param_type=torch.float16) - @unittest.skipIf(torch.cuda.device_count() < 2, "more than 1 GPU required") + @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="more than 1 GPU required") def test_multi_device(self): devices = ("cuda:0", "cuda:1") for current_dev, tensor_dev in product(devices, devices): @@ -308,9 +307,9 @@ def forward(self, x): return y -class AdamTest(unittest.TestCase): - def setUp(self, seed=0): - super().setUp() +class AdamTest: + + def setup_method(self, *, seed: int = 0) -> None: torch.manual_seed(seed) self.model = Model().cuda() @@ -321,7 +320,7 @@ def setUp(self, seed=0): params = [p for p in self.model.parameters() if p.requires_grad] self.optimizer = torch.optim.Adam(params, lr=self.lr) - def testGradScaler(self): + def test_grad_scaler(self): params_ = [p for p in self.model_.parameters() if p.requires_grad] optimizer_ = te.optimizers.FusedAdam(params_, lr=self.lr, capturable=False) scaler = torch.cuda.amp.GradScaler(enabled=True) @@ -372,7 +371,7 @@ def testGradScaler(self): self.model_.load_state_dict(copy.deepcopy(self.model.state_dict())) - def testGradScalerCapturable(self): + def test_grad_scaler_capturable(self): params_ = [p for p in self.model_.parameters() if p.requires_grad] optimizer_ = te.optimizers.FusedAdam(params_, lr=self.lr, capturable=True) scaler = torch.cuda.amp.GradScaler(enabled=True) @@ -423,7 +422,7 @@ def testGradScalerCapturable(self): self.model_.load_state_dict(copy.deepcopy(self.model.state_dict())) - def testGradScalerCapturableMaster(self): + def test_grad_scaler_capturable_master(self): # Cast conv layers to FP16 for m in self.model_.modules(): if m.__class__ in [torch.nn.Conv2d]: @@ -485,7 +484,7 @@ def testGradScalerCapturableMaster(self): self.model_.load_state_dict(copy.deepcopy(self.model.state_dict())) - def testNative(self): + def test_native(self): params_ = [p for p in self.model_.parameters() if p.requires_grad] optimizer_ = te.optimizers.FusedAdam(params_, lr=self.lr, capturable=False) @@ -531,7 +530,7 @@ def testNative(self): self.model_.load_state_dict(copy.deepcopy(self.model.state_dict())) @largeTensorTest("60GB", "cuda") - def testLargeTensor(self): + def test_large_tensor(self): t = torch.zeros(2359332864, dtype=torch.half, device="cuda") t2 = torch.zeros(2359332864, dtype=torch.half, device="cuda") grad = torch.randn_like(t) diff --git a/tests/pytorch/test_fused_rope.py b/tests/pytorch/test_fused_rope.py index d6ba66cbbc..81c4973756 100644 --- a/tests/pytorch/test_fused_rope.py +++ b/tests/pytorch/test_fused_rope.py @@ -1,17 +1,38 @@ # Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. +import math import pytest import torch -from typing import Callable, Dict, Tuple, Union +from typing import Callable, Tuple, Union from transformer_engine.pytorch.attention import ( RotaryPositionEmbedding, apply_rotary_pos_emb, ) +def _get_thd_freqs_on_this_cp_rank( + cp_rank: int, cp_size: int, x: torch.Tensor, freqs: torch.Tensor +) -> torch.Tensor: + if cp_size > 1: + cp_seg = x.size(0) // 2 + full_seqlen = cp_size * x.size(0) + return torch.cat( + [ + freqs[cp_rank * cp_seg : (cp_rank + 1) * cp_seg], + freqs[full_seqlen - (cp_rank + 1) * cp_seg : full_seqlen - cp_rank * cp_seg], + ] + ) + else: + return freqs[: x.size(0)] + + def apply_rotary_pos_emb_thd( - t: torch.Tensor, cu_seqlens: torch.Tensor, freqs: torch.Tensor + t: torch.Tensor, + cu_seqlens: torch.Tensor, + freqs: torch.Tensor, + cp_size: int = 1, + cp_rank: int = 0, ) -> torch.Tensor: """A baseline implementation of applying RoPE for `thd` format. @@ -24,20 +45,18 @@ def apply_rotary_pos_emb_thd( Returns: Tensor: Shape [t, h, d]. The input tensor after applying RoPE. """ + cu_seqlens = cu_seqlens // cp_size seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist() return torch.cat( - [apply_rotary_pos_emb(x.unsqueeze(1), freqs[: x.size(0)]) for x in torch.split(t, seqlens)] + [ + apply_rotary_pos_emb( + x.unsqueeze(1), _get_thd_freqs_on_this_cp_rank(cp_rank, cp_size, x, freqs) + ) + for x in torch.split(t, seqlens) + ] ).squeeze(1) -def get_tol(dtype: torch.dtype) -> Dict: - if dtype == torch.bfloat16: - return dict(atol=1e-2, rtol=1e-2) - elif dtype == torch.float16: - return dict(atol=1e-3, rtol=1e-3) - return dict(atol=1e-5, rtol=1.3e-6) - - # Gradient is a broadcasted scalar def _overlapping_grad(output: torch.Tensor) -> torch.Tensor: return output.sum() * 2 @@ -84,7 +103,11 @@ def test_fused_rope( emb = rotary_pos_emb(seq_length) # unfused - output_unfused = apply_rotary_pos_emb(t, emb, tensor_format=tensor_format, fused=False) + # The fused kernel computes in float32 internally, so we force the unfused func to use float32 + # for more accurate comparison + output_unfused = apply_rotary_pos_emb( + t.float(), emb, tensor_format=tensor_format, fused=False + ).to(dtype) loss_unfused = loss_func(output_unfused) loss_unfused.backward() grad_unfused = t.grad.detach().clone() @@ -102,8 +125,8 @@ def test_fused_rope( grad_fused = t.grad.detach().clone() t.grad = None - torch.testing.assert_close(output_fused, output_unfused, **get_tol(dtype)) - torch.testing.assert_close(grad_fused, grad_unfused, **get_tol(dtype)) + torch.testing.assert_close(output_fused, output_unfused) + torch.testing.assert_close(grad_fused, grad_unfused) assert output_fused.is_contiguous() @@ -112,22 +135,34 @@ def test_fused_rope( @pytest.mark.parametrize("rotary_percent", [0.5, 1.0]) @pytest.mark.parametrize("transpose", [None, (1, 2)]) @pytest.mark.parametrize("loss_func", [_overlapping_grad, _non_overlapping_grad]) +@pytest.mark.parametrize("cp_size", [1, 2, 3]) def test_fused_rope_thd( dtype: torch.dtype, hidden_size: int, rotary_percent: float, transpose: Union[Tuple, None], loss_func: Callable, + cp_size: int, ) -> None: device = torch.device("cuda:0") batch_size, head_num = 2, 64 - cu_seqlens = torch.tensor( - [0, 400, 542, 711, 727, 752, 1270, 1426, 1450, 1954, 2044, 2048], + cu_seqlens = [0, 400, 542, 711, 727, 752, 1270, 1426, 1450, 1954, 2044, 2048] + if cp_size > 1: + cu_seqlens_padded = [0] + for i in range(1, len(cu_seqlens)): + cu_seqlens_padded.append( + cu_seqlens_padded[i - 1] + + math.ceil((cu_seqlens[i] - cu_seqlens[i - 1]) / (cp_size * 2)) * (cp_size * 2) + ) + else: + cu_seqlens_padded = cu_seqlens + cu_seqlens_padded = torch.tensor( + cu_seqlens_padded, dtype=torch.int32, device=device, ) t = torch.rand( - (cu_seqlens[-1], head_num, hidden_size), + (cu_seqlens_padded[-1] // cp_size, head_num, hidden_size), dtype=dtype, device=device, ) @@ -136,23 +171,34 @@ def test_fused_rope_thd( t.requires_grad = True rotary_pos_emb = RotaryPositionEmbedding(hidden_size, rotary_percent) - emb = rotary_pos_emb(cu_seqlens[-1]) - - # unfused - output_unfused = apply_rotary_pos_emb_thd(t, cu_seqlens, emb) - loss_unfused = loss_func(output_unfused) - loss_unfused.backward() - grad_unfused = t.grad.detach().clone() - t.grad = None - - # fused - output_fused = apply_rotary_pos_emb( - t, emb, fused=True, tensor_format="thd", cu_seqlens=cu_seqlens - ) - loss_fused = loss_func(output_fused) - loss_fused.backward() - grad_fused = t.grad.detach().clone() - t.grad = None - - torch.testing.assert_close(output_fused, output_unfused, **get_tol(dtype)) - torch.testing.assert_close(grad_fused, grad_unfused, **get_tol(dtype)) + emb = rotary_pos_emb(cu_seqlens_padded[-1]) + + for cp_rank in range(cp_size): + # unfused + # The fused kernel computes in float32 internally, so we force the unfused func to use float32 + # for more accurate comparison + output_unfused = apply_rotary_pos_emb_thd( + t.float(), cu_seqlens_padded, emb, cp_size, cp_rank + ).to(dtype) + loss_unfused = loss_func(output_unfused) + loss_unfused.backward() + grad_unfused = t.grad.detach().clone() + t.grad = None + + # fused + output_fused = apply_rotary_pos_emb( + t, + emb, + fused=True, + tensor_format="thd", + cu_seqlens=cu_seqlens_padded, + cp_size=cp_size, + cp_rank=cp_rank, + ) + loss_fused = loss_func(output_fused) + loss_fused.backward() + grad_fused = t.grad.detach().clone() + t.grad = None + + torch.testing.assert_close(output_fused, output_unfused) + torch.testing.assert_close(grad_fused, grad_unfused) diff --git a/tests/pytorch/test_fusible_ops.py b/tests/pytorch/test_fusible_ops.py index e97dfe1efd..1d91683ae4 100644 --- a/tests/pytorch/test_fusible_ops.py +++ b/tests/pytorch/test_fusible_ops.py @@ -307,6 +307,128 @@ def test_fp8_scale_update( torch.testing.assert_close(x_scale, torch.full_like(x_scale, x_scale_ref)) torch.testing.assert_close(dy_scale, torch.full_like(dy_scale, dy_scale_ref)) + @pytest.mark.parametrize("init_dtype", _dtypes) + @pytest.mark.parametrize("final_dtype", _dtypes) + @pytest.mark.parametrize("fp8_weight", (False, True)) + def test_dtype_cast( + self, + *, + size: int = 16, + init_dtype: torch.dtype, + final_dtype: torch.dtype, + device: torch.device = "cuda", + fp8_weight: bool, + ) -> None: + """Check dtype cast functions""" + + # Skip invalid configurations + if fp8_weight: + if not fp8_available: + pytest.skip(reason_for_no_fp8) + if torch.device(device).type != "cuda": + pytest.skip("FP8 is only supported on CUDA devices") + + # Random data + dtype = torch.float32 + if torch.float16 in (init_dtype, final_dtype): + dtype = torch.float16 + if torch.bfloat16 in (init_dtype, final_dtype): + dtype = torch.bfloat16 + w_ref, w_test = make_reference_and_test_tensors( + (size, size), + test_dtype=dtype, + test_device=device, + test_is_fp8=fp8_weight, + ) + + # Construct operation + with te.fp8_model_init(enabled=fp8_weight): + op = te_ops.Linear(size, size, bias=False, device=device, dtype=init_dtype) + with torch.no_grad(): + op.weight.copy_(w_test) + del w_test + + # Cast operation dtype + if final_dtype == torch.float32: + op.float() + elif final_dtype == torch.float16: + op.half() + elif final_dtype == torch.bfloat16: + op.bfloat16() + + # Check weights + assert isinstance(op.weight, Float8Tensor) == fp8_weight + assert op.weight.dtype == final_dtype + w_test = op.weight.to(dtype=torch.float64, device="cpu") + torch.testing.assert_close(w_test, w_ref, rtol=0, atol=0) + + # Check forward and backward pass + x = torch.zeros( + (size, size), + dtype=init_dtype, + device=device, + requires_grad=True, + ) + y = op(x) + y.backward(torch.zeros_like(y)) + assert y.dtype == final_dtype + assert x.grad.dtype == init_dtype + assert op.weight.grad.dtype == final_dtype + + @pytest.mark.parametrize("model_dtype", _dtypes) + @pytest.mark.parametrize("autocast_dtype", _dtypes) + @pytest.mark.parametrize("fp8_compute", (False, True)) + def test_pyt_autocast( + self, + *, + size: int = 16, + model_dtype: torch.dtype, + autocast_dtype: torch.dtype, + device: torch.device = "cuda", + fp8_weight: bool = False, + fp8_compute: bool, + ) -> None: + """Test with PyTorch autocast""" + device = torch.device(device) + + # Skip invalid configurations + if fp8_weight or fp8_compute: + if not fp8_available: + pytest.skip(reason_for_no_fp8) + if torch.device(device).type != "cuda": + pytest.skip("FP8 is only supported on CUDA devices") + + # Construct operation + with te.fp8_model_init(enabled=fp8_weight): + op = te_ops.Linear(size, size, bias=False, device=device, dtype=model_dtype) + + # Check forward and backward pass + x = torch.zeros( + (size, size), + dtype=model_dtype, + device=device, + requires_grad=True, + ) + with te.fp8_autocast(enabled=fp8_compute): + with torch.autocast(device_type=device.type, dtype=autocast_dtype): + y = op(x) + y.backward(torch.zeros_like(y)) + assert y.dtype == autocast_dtype + assert x.grad.dtype == model_dtype + assert op.weight.grad.dtype == model_dtype + + # Check forward and backward pass (swapped context order) + if fp8_compute: + x.grad = None + op.weight.grad = None + with torch.autocast(device_type=device.type, dtype=autocast_dtype): + with te.fp8_autocast(enabled=fp8_compute): + y = op(x) + y.backward(torch.zeros_like(y)) + assert y.dtype == autocast_dtype + assert x.grad.dtype == model_dtype + assert op.weight.grad.dtype == model_dtype + class TestBasicOps: """Tests for individual operations""" diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index ad34b4996f..c0f45ada4e 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -1602,10 +1602,12 @@ def test_gpt_cuda_graph(dtype, bs, model): init_method = init_method_normal(sigma) output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers) - block = TransformerLayer( + block_args = ( config.hidden_size, 4 * config.hidden_size, config.num_attention_heads, + ) + block_kwargs = dict( layernorm_epsilon=config.eps, init_method=init_method, output_layer_init_method=output_layer_init_method, @@ -1617,7 +1619,11 @@ def test_gpt_cuda_graph(dtype, bs, model): output_layernorm=False, device="cuda", ) - graphed_block = copy.deepcopy(block) + block = TransformerLayer(*block_args, **block_kwargs) + graphed_block = TransformerLayer(*block_args, **block_kwargs) + with torch.no_grad(): + for param1, param2 in zip(block.parameters(), graphed_block.parameters()): + param2.copy_(param1) out, grads = _test_gpt_e2e_cuda_graph(block, bs, dtype, config, False) graphed_out, graphed_grads = _test_gpt_e2e_cuda_graph(graphed_block, bs, dtype, config, True) diff --git a/tests/pytorch/test_onnx_export.py b/tests/pytorch/test_onnx_export.py index bdc459cdcc..6a463b556a 100644 --- a/tests/pytorch/test_onnx_export.py +++ b/tests/pytorch/test_onnx_export.py @@ -72,7 +72,7 @@ assert OPSET >= TRILU_OPSET # Shared library implementing custom FP8 Q/DQ operators for ONNX Runtime (ORT). -ORT_CUSTOM_OPS_LIB = os.path.join(TESTS_DIR, "./libcustom_ort_fp8_qdq_ops.so") +ORT_CUSTOM_OPS_LIB = os.path.join(TESTS_DIR, "custom_ort_ops", "libcustom_ort_ops.so") fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() skip_FP8 = pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) @@ -85,7 +85,7 @@ @pytest.fixture() def seed_default_rng(): """Reseed the PRNG for test reproducibility""" - torch.random.seed() + torch.manual_seed(1234) @pytest.fixture() diff --git a/tests/pytorch/test_sanity.py b/tests/pytorch/test_sanity.py index 2bd512d56b..4f057c12fe 100644 --- a/tests/pytorch/test_sanity.py +++ b/tests/pytorch/test_sanity.py @@ -9,6 +9,7 @@ import torch import pytest import io +import os from transformer_engine.pytorch.fp8 import ( fp8_autocast, @@ -42,6 +43,7 @@ ) from transformer_engine.pytorch.module.base import get_workspace from test_onnx_export import create_meta +from test_numerics import reset_rng_states, dtype_tols # Only run FP8 tests on H100. fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() @@ -1004,20 +1006,50 @@ def test_sanity_fp8_gemm_with_unalignment(N, datatype): @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) -@pytest.mark.skipif(get_device_compute_capability() != (9, 0), reason="FP8 tests require Hopper.") +@pytest.mark.skipif(get_device_compute_capability() < (9, 0), reason="FP8 tests require Hopper.") @pytest.mark.skipif(get_cudnn_version() < (9, 3, 0), reason="cuDNN 9.3.0+ is required.") @pytest.mark.parametrize("model", ["large"]) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) def test_sanity_attention_extra_state(model, dtype): config = model_configs[model] + outputs = _run_attention_extra_state(dtype, config, checkpoint=False) + outputs_checkpoint = _run_attention_extra_state(dtype, config, checkpoint=True) + outputs_checkpoint_v1_6 = _run_attention_extra_state( + dtype, config, mimic_v1_6=True, checkpoint=True + ) + + # Check that results match + tols = dtype_tols(dtype) + if dtype in (torch.float16, torch.bfloat16): + tols.update(dict(rtol=2e-2, atol=2e-3)) + for i, (ref, test) in enumerate(zip(outputs, outputs_checkpoint)): + torch.testing.assert_close( + test, + ref, + **tols, + ) + for i, (ref, test) in enumerate(zip(outputs, outputs_checkpoint_v1_6)): + torch.testing.assert_close( + test, + ref, + **tols, + ) + + +def _run_attention_extra_state(dtype, config, checkpoint=False, mimic_v1_6=False): + steps = 10 + path = "checkpoint.pt" + fp8_enabled = True fp8_recipe = recipe.DelayedScaling( margin=0, fp8_format=recipe.Format.HYBRID, amax_history_len=1, amax_compute_algo="most_recent", - fp8_dpa=True, + fp8_dpa=fp8_enabled, fp8_mha=False, ) + + reset_rng_states() hidden_states = torch.randn( (config.seq_len, config.batch_size, config.hidden_size), dtype=dtype, @@ -1025,63 +1057,74 @@ def test_sanity_attention_extra_state(model, dtype): requires_grad=True, ) - with fp8_model_init(enabled=True): - block = TransformerLayer( - config.hidden_size, - 4 * config.hidden_size, - config.num_attention_heads, - fuse_qkv_params=True, - params_dtype=dtype, - device="cuda", - ) - with fp8_autocast(enabled=True, fp8_recipe=fp8_recipe): - output = block(hidden_states, is_first_microbatch=True) - loss = output.sum() - loss.backward() - - # call state_dict() - sd = block.state_dict() - - # check core_attention._extra_state - attn_extra_state = sd["self_attention.core_attention._extra_state"] - attn_extra_state.seek(0) - attn_extra_state = torch.load(attn_extra_state, map_location="cuda") - - # add random core_attention.fused_attention._extra_state - # it should not be loaded or cause any 'unexpected key' errors - random_state = {"a": 1, "b": 2} - fused_attn_extra_state = io.BytesIO() - torch.save(random_state, fused_attn_extra_state) - sd["self_attention.core_attention.fused_attention._extra_state"] = fused_attn_extra_state - - # save checkpoint - path = "./checkpoint.pt" - torch.save(sd, path) - - # reinit the model - del block - with fp8_model_init(enabled=True): - block_new = TransformerLayer( - config.hidden_size, - 4 * config.hidden_size, - config.num_attention_heads, - fuse_qkv_params=True, - params_dtype=dtype, - device="cuda", - ) - FP8GlobalStateManager.reset() + def get_model(dtype, config): + sigma = 0.023 + init_method = init_method_normal(sigma) + output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers) + + with fp8_model_init(enabled=fp8_enabled): + block = TransformerLayer( + config.hidden_size, + 4 * config.hidden_size, + config.num_attention_heads, + init_method=init_method, + output_layer_init_method=output_layer_init_method, + fuse_qkv_params=True, + params_dtype=dtype, + device="cuda", + ) + return block + + block = get_model(dtype, config) + for i in range(steps // 2): + with fp8_autocast(enabled=fp8_enabled, fp8_recipe=fp8_recipe): + output = block(hidden_states, None) + loss = output.sum() + loss.backward() + + if checkpoint: + sd = block.state_dict() + if mimic_v1_6: + sd["self_attention.core_attention.fused_attention._extra_state"] = sd[ + "self_attention.core_attention._extra_state" + ] + del sd["self_attention.core_attention._extra_state"] + torch.save(sd, path) + + param_grads = [] + for p in block.parameters(): + if p.requires_grad: + param_grads.append(p.grad.clone()) + + _cpu_rng_state_new = torch.get_rng_state() + _cuda_rng_state_new = torch.cuda.get_rng_state() + + del block + block = get_model(dtype, config) + block.load_state_dict(torch.load(path)) + torch.set_rng_state(_cpu_rng_state_new) + torch.cuda.set_rng_state(_cuda_rng_state_new) + + for p in block.parameters(): + if p.requires_grad: + p.grad = param_grads.pop(0) + + assert not param_grads, "Oops!" + + for i in range((steps + 1) // 2): + with fp8_autocast(enabled=fp8_enabled, fp8_recipe=fp8_recipe): + output = block(hidden_states, None) + loss = output.sum() + loss.backward() + + torch.cuda.synchronize() + + if os.path.exists(path): + os.remove(path) + + outputs = [output, hidden_states.grad] + for p in block.parameters(): + if p.requires_grad: + outputs.append(p.grad) - # load from checkpoint - block_new.load_state_dict(torch.load(path)) - - # check state_dict - sd_new = block_new.state_dict() - attn_extra_state_new = sd_new["self_attention.core_attention._extra_state"] - attn_extra_state_new.seek(0) - attn_extra_state_new = torch.load(attn_extra_state_new, map_location="cuda") - for k, v in attn_extra_state_new.items(): - if k != "extra_fp8_variables": - assert torch.equal(v, attn_extra_state[k]), f"{k} is not equal" - else: - for ek, ev in attn_extra_state_new["extra_fp8_variables"].items(): - assert ev == attn_extra_state["extra_fp8_variables"][ek], f"{ek} is not equal" + return outputs diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index 647d2c474d..cabb2e2aea 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -88,7 +88,6 @@ target_include_directories(transformer_engine PUBLIC # Configure dependencies target_link_libraries(transformer_engine PUBLIC CUDA::cublas - CUDA::cuda_driver CUDA::cudart) target_include_directories(transformer_engine PRIVATE ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu index 42fb779717..9eff62debf 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu @@ -75,9 +75,6 @@ void fused_attn_arbitrary_seqlen_fwd_impl( if (is_ragged) { NVTE_CHECK(is_padding, "Ragged QKV input requires padding or padding_causal mask!"); } - if (window_size_left == -1) { - window_size_left = s_q; - } auto cudnn_runtime_version = cudnnGetVersion(); try { @@ -221,8 +218,8 @@ void fused_attn_arbitrary_seqlen_fwd_impl( .set_causal_mask_bottom_right(is_bottom_right) .set_attn_scale(attn_scale); - if (cudnn_runtime_version >= 90200 && window_size_left != s_q) { - sdpa_options.set_sliding_window_length(window_size_left); + if (cudnn_runtime_version >= 90200 && window_size_left != -1) { + sdpa_options.set_sliding_window_length(window_size_left + 1); } sdpa_options.set_alibi_mask(is_alibi); @@ -407,9 +404,6 @@ void fused_attn_arbitrary_seqlen_bwd_impl( (mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK)); bool is_dropout = (dropout_probability != 0.0f); bool is_ragged = (nvte_get_qkv_format(layout) == NVTE_QKV_Format::NVTE_THD); - if (window_size_left == -1) { - window_size_left = s_q; - } auto cudnn_runtime_version = cudnnGetVersion(); const int device_id = cuda::current_device(); const int sm_arch_ = cuda::sm_arch(device_id); @@ -584,8 +578,8 @@ void fused_attn_arbitrary_seqlen_bwd_impl( .set_causal_mask_bottom_right(is_bottom_right) .set_attn_scale(attn_scale); - if (cudnn_runtime_version >= 90200 && window_size_left != s_q) { - sdpa_backward_options.set_sliding_window_length(window_size_left); + if (cudnn_runtime_version >= 90200 && window_size_left != -1) { + sdpa_backward_options.set_sliding_window_length(window_size_left + 1); } if (cudnn_runtime_version >= 90000) { diff --git a/transformer_engine/common/fused_rope/fused_rope.cu b/transformer_engine/common/fused_rope/fused_rope.cu index e7cf940a57..26f104d3ed 100644 --- a/transformer_engine/common/fused_rope/fused_rope.cu +++ b/transformer_engine/common/fused_rope/fused_rope.cu @@ -4,6 +4,7 @@ * See LICENSE for license information. ************************************************************************/ +#include #include #include @@ -15,11 +16,10 @@ namespace transformer_engine { template __device__ void fused_rope_block_forward(const scalar_t *src, const float *freqs, scalar_t *dst, - const int offset_block, const int offset_block_dst, - const int h, const int d, const int d2, const int stride_h, - const int stride_d, const int o_stride_h, - const int o_stride_d) { - int s_id = blockIdx.x; + const int s_id, const int offset_block, + const int offset_block_dst, const int h, const int d, + const int d2, const int stride_h, const int stride_d, + const int o_stride_h, const int o_stride_d) { #pragma unroll for (int d_id = threadIdx.x; d_id < d2; d_id += blockDim.x) { float v_cos, v_sin; @@ -52,11 +52,10 @@ __device__ void fused_rope_block_forward(const scalar_t *src, const float *freqs template __device__ void fused_rope_block_backward(const scalar_t *src, const float *freqs, scalar_t *dst, - const int offset_block, const int offset_block_dst, - const int h, const int d, const int d2, - const int stride_h, const int stride_d, + const int s_id, const int offset_block, + const int offset_block_dst, const int h, const int d, + const int d2, const int stride_h, const int stride_d, const int o_stride_h, const int o_stride_d) { - int s_id = blockIdx.x; #pragma unroll for (int d_id = threadIdx.x; d_id < d2; d_id += blockDim.x) { float v_cos = cosf(freqs[s_id * d2 + d_id]); @@ -97,8 +96,8 @@ __global__ void fused_rope_forward_kernel(const scalar_t *src, const float *freq int s_id = blockIdx.x, b_id = blockIdx.y; int offset_block = s_id * stride_s + b_id * stride_b; int offset_block_dst = s_id * o_stride_s + b_id * o_stride_b; - fused_rope_block_forward(src, freqs, dst, offset_block, offset_block_dst, h, d, d2, stride_h, - stride_d, o_stride_h, o_stride_d); + fused_rope_block_forward(src, freqs, dst, s_id, offset_block, offset_block_dst, h, d, d2, + stride_h, stride_d, o_stride_h, o_stride_d); } template @@ -111,40 +110,72 @@ __global__ void fused_rope_backward_kernel(const scalar_t *src, const float *fre int s_id = blockIdx.x, b_id = blockIdx.y; int offset_block = s_id * stride_s + b_id * stride_b; int offset_block_dst = s_id * o_stride_s + b_id * o_stride_b; - fused_rope_block_backward(src, freqs, dst, offset_block, offset_block_dst, h, d, d2, stride_h, - stride_d, o_stride_h, o_stride_d); + fused_rope_block_backward(src, freqs, dst, s_id, offset_block, offset_block_dst, h, d, d2, + stride_h, stride_d, o_stride_h, o_stride_d); } template __global__ void fused_rope_thd_forward_kernel(const scalar_t *src, const int *cu_seqlens, - const float *freqs, scalar_t *dst, const int h, - const int d, const int d2, const int stride_t, - const int stride_h, const int stride_d, - const int o_stride_t, const int o_stride_h, - const int o_stride_d) { + const float *freqs, scalar_t *dst, const int cp_size, + const int cp_rank, const int h, const int d, + const int d2, const int stride_t, const int stride_h, + const int stride_d, const int o_stride_t, + const int o_stride_h, const int o_stride_d) { int s_id = blockIdx.x, b_id = blockIdx.y; - int t_id = s_id + cu_seqlens[b_id]; - if (t_id >= cu_seqlens[b_id + 1]) return; + int start = cu_seqlens[b_id] / cp_size; + int end = cu_seqlens[b_id + 1] / cp_size; + int t_id = s_id + start; + if (t_id >= end) return; int offset_block = t_id * stride_t; int offset_block_dst = t_id * o_stride_t; - fused_rope_block_forward(src, freqs, dst, offset_block, offset_block_dst, h, d, d2, stride_h, - stride_d, o_stride_h, o_stride_d); + + int s_id_for_freqs; + if (cp_size > 1) { + int cur_seqlens = end - start; + assert(cur_seqlens % 2 == 0); + if (s_id < cur_seqlens / 2) { + s_id_for_freqs = s_id + cp_rank * cur_seqlens / 2; + } else { + s_id_for_freqs = + cur_seqlens * cp_size - (cp_rank + 1) * cur_seqlens / 2 + s_id - cur_seqlens / 2; + } + } else { + s_id_for_freqs = s_id; + } + fused_rope_block_forward(src, freqs, dst, s_id_for_freqs, offset_block, offset_block_dst, h, d, + d2, stride_h, stride_d, o_stride_h, o_stride_d); } template __global__ void fused_rope_thd_backward_kernel(const scalar_t *src, const int *cu_seqlens, - const float *freqs, scalar_t *dst, const int h, - const int d, const int d2, const int stride_t, - const int stride_h, const int stride_d, - const int o_stride_t, const int o_stride_h, - const int o_stride_d) { + const float *freqs, scalar_t *dst, const int cp_size, + const int cp_rank, const int h, const int d, + const int d2, const int stride_t, const int stride_h, + const int stride_d, const int o_stride_t, + const int o_stride_h, const int o_stride_d) { int s_id = blockIdx.x, b_id = blockIdx.y; - int t_id = s_id + cu_seqlens[b_id]; - if (t_id >= cu_seqlens[b_id + 1]) return; + int start = cu_seqlens[b_id] / cp_size; + int end = cu_seqlens[b_id + 1] / cp_size; + int t_id = s_id + start; + if (t_id >= end) return; int offset_block = t_id * stride_t; int offset_block_dst = t_id * o_stride_t; - fused_rope_block_backward(src, freqs, dst, offset_block, offset_block_dst, h, d, d2, stride_h, - stride_d, o_stride_h, o_stride_d); + + int s_id_for_freqs; + if (cp_size > 1) { + int cur_seqlens = end - start; + assert(cur_seqlens % 2 == 0); + if (s_id < cur_seqlens / 2) { + s_id_for_freqs = s_id + cp_rank * cur_seqlens / 2; + } else { + s_id_for_freqs = + cur_seqlens * cp_size - (cp_rank + 1) * cur_seqlens / 2 + s_id - cur_seqlens / 2; + } + } else { + s_id_for_freqs = s_id; + } + fused_rope_block_backward(src, freqs, dst, s_id_for_freqs, offset_block, offset_block_dst, h, d, + d2, stride_h, stride_d, o_stride_h, o_stride_d); } template @@ -182,35 +213,37 @@ void fused_rope_backward_launcher(const scalar_t *output_grads, const float *fre template void fused_rope_thd_forward_launcher(const scalar_t *input, const int *cu_seqlens, - const float *freqs, scalar_t *output, const int max_s, - const int b, const int h, const int d, const int d2, - const int stride_t, const int stride_h, const int stride_d, - const int o_stride_t, const int o_stride_h, - const int o_stride_d, cudaStream_t stream) { + const float *freqs, scalar_t *output, const int cp_size, + const int cp_rank, const int max_s, const int b, const int h, + const int d, const int d2, const int stride_t, + const int stride_h, const int stride_d, const int o_stride_t, + const int o_stride_h, const int o_stride_d, + cudaStream_t stream) { int warps_per_block = h < 16 ? 4 : 8; dim3 blocks(max_s, b); dim3 threads(THREADS_PER_WARP, warps_per_block); - fused_rope_thd_forward_kernel<<>>(input, cu_seqlens, freqs, output, h, - d, d2, stride_t, stride_h, stride_d, - o_stride_t, o_stride_h, o_stride_d); + fused_rope_thd_forward_kernel<<>>( + input, cu_seqlens, freqs, output, cp_size, cp_rank, h, d, d2, stride_t, stride_h, stride_d, + o_stride_t, o_stride_h, o_stride_d); NVTE_CHECK_CUDA(cudaGetLastError()); } template void fused_rope_thd_backward_launcher(const scalar_t *output_grads, const int *cu_seqlens, - const float *freqs, scalar_t *input_grads, const int max_s, - const int b, const int h, const int d, const int d2, - const int stride_t, const int stride_h, const int stride_d, - const int o_stride_t, const int o_stride_h, - const int o_stride_d, cudaStream_t stream) { + const float *freqs, scalar_t *input_grads, const int cp_size, + const int cp_rank, const int max_s, const int b, const int h, + const int d, const int d2, const int stride_t, + const int stride_h, const int stride_d, const int o_stride_t, + const int o_stride_h, const int o_stride_d, + cudaStream_t stream) { int warps_per_block = h < 16 ? 4 : 8; dim3 blocks(max_s, b); dim3 threads(THREADS_PER_WARP, warps_per_block); fused_rope_thd_backward_kernel<<>>( - output_grads, cu_seqlens, freqs, input_grads, h, d, d2, stride_t, stride_h, stride_d, - o_stride_t, o_stride_h, o_stride_d); + output_grads, cu_seqlens, freqs, input_grads, cp_size, cp_rank, h, d, d2, stride_t, stride_h, + stride_d, o_stride_t, o_stride_h, o_stride_d); NVTE_CHECK_CUDA(cudaGetLastError()); } @@ -243,33 +276,34 @@ void fused_rope_backward(const Tensor &output_grads, const Tensor &freqs, Tensor } void fused_rope_thd_forward(const Tensor &input, const Tensor &cu_seqlens, const Tensor &freqs, - Tensor *output, const int max_s, const int b, const int h, const int d, - const int d2, const int stride_t, const int stride_h, - const int stride_d, const int o_stride_t, const int o_stride_h, - const int o_stride_d, cudaStream_t stream) { + Tensor *output, const int cp_size, const int cp_rank, const int max_s, + const int b, const int h, const int d, const int d2, const int stride_t, + const int stride_h, const int stride_d, const int o_stride_t, + const int o_stride_h, const int o_stride_d, cudaStream_t stream) { TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( input.data.dtype, scalar_t, fused_rope_thd_forward_launcher(reinterpret_cast(input.data.dptr), reinterpret_cast(cu_seqlens.data.dptr), reinterpret_cast(freqs.data.dptr), - reinterpret_cast(output->data.dptr), max_s, b, h, - d, d2, stride_t, stride_h, stride_d, o_stride_t, o_stride_h, - o_stride_d, stream);); + reinterpret_cast(output->data.dptr), cp_size, + cp_rank, max_s, b, h, d, d2, stride_t, stride_h, stride_d, + o_stride_t, o_stride_h, o_stride_d, stream);); } void fused_rope_thd_backward(const Tensor &output_grads, const Tensor &cu_seqlens, - const Tensor &freqs, Tensor *input_grads, const int max_s, const int b, - const int h, const int d, const int d2, const int stride_t, - const int stride_h, const int stride_d, const int o_stride_t, - const int o_stride_h, const int o_stride_d, cudaStream_t stream) { + const Tensor &freqs, Tensor *input_grads, const int cp_size, + const int cp_rank, const int max_s, const int b, const int h, + const int d, const int d2, const int stride_t, const int stride_h, + const int stride_d, const int o_stride_t, const int o_stride_h, + const int o_stride_d, cudaStream_t stream) { TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( output_grads.data.dtype, scalar_t, fused_rope_thd_backward_launcher(reinterpret_cast(output_grads.data.dptr), reinterpret_cast(cu_seqlens.data.dptr), reinterpret_cast(freqs.data.dptr), - reinterpret_cast(input_grads->data.dptr), max_s, - b, h, d, d2, stride_t, stride_h, stride_d, o_stride_t, - o_stride_h, o_stride_d, stream);); + reinterpret_cast(input_grads->data.dptr), + cp_size, cp_rank, max_s, b, h, d, d2, stride_t, stride_h, + stride_d, o_stride_t, o_stride_h, o_stride_d, stream);); } } // end namespace transformer_engine @@ -302,30 +336,31 @@ void nvte_fused_rope_backward(const NVTETensor output_grads, const NVTETensor fr } void nvte_fused_rope_thd_forward(const NVTETensor input, const NVTETensor cu_seqlens, - const NVTETensor freqs, NVTETensor output, const int max_s, - const int b, const int h, const int d, const int d2, - const int stride_t, const int stride_h, const int stride_d, - const int o_stride_t, const int o_stride_h, const int o_stride_d, - cudaStream_t stream) { + const NVTETensor freqs, NVTETensor output, const int cp_size, + const int cp_rank, const int max_s, const int b, const int h, + const int d, const int d2, const int stride_t, const int stride_h, + const int stride_d, const int o_stride_t, const int o_stride_h, + const int o_stride_d, cudaStream_t stream) { NVTE_API_CALL(nvte_fused_rope_thd_forward); using namespace transformer_engine; - fused_rope_thd_forward( - *reinterpret_cast(input), *reinterpret_cast(cu_seqlens), - *reinterpret_cast(freqs), reinterpret_cast(output), max_s, b, h, d, - d2, stride_t, stride_h, stride_d, o_stride_t, o_stride_h, o_stride_d, stream); + fused_rope_thd_forward(*reinterpret_cast(input), + *reinterpret_cast(cu_seqlens), + *reinterpret_cast(freqs), + reinterpret_cast(output), cp_size, cp_rank, max_s, b, h, d, d2, + stride_t, stride_h, stride_d, o_stride_t, o_stride_h, o_stride_d, stream); } void nvte_fused_rope_thd_backward(const NVTETensor output_grads, const NVTETensor cu_seqlens, - const NVTETensor freqs, NVTETensor input_grads, const int max_s, - const int b, const int h, const int d, const int d2, - const int stride_t, const int stride_h, const int stride_d, - const int o_stride_t, const int o_stride_h, const int o_stride_d, - cudaStream_t stream) { + const NVTETensor freqs, NVTETensor input_grads, const int cp_size, + const int cp_rank, const int max_s, const int b, const int h, + const int d, const int d2, const int stride_t, const int stride_h, + const int stride_d, const int o_stride_t, const int o_stride_h, + const int o_stride_d, cudaStream_t stream) { NVTE_API_CALL(nvte_fused_rope_thd_backward); using namespace transformer_engine; - fused_rope_thd_backward(*reinterpret_cast(output_grads), - *reinterpret_cast(cu_seqlens), - *reinterpret_cast(freqs), - reinterpret_cast(input_grads), max_s, b, h, d, d2, stride_t, - stride_h, stride_d, o_stride_t, o_stride_h, o_stride_d, stream); + fused_rope_thd_backward( + *reinterpret_cast(output_grads), + *reinterpret_cast(cu_seqlens), *reinterpret_cast(freqs), + reinterpret_cast(input_grads), cp_size, cp_rank, max_s, b, h, d, d2, stride_t, + stride_h, stride_d, o_stride_t, o_stride_h, o_stride_d, stream); } diff --git a/transformer_engine/common/include/transformer_engine/fused_rope.h b/transformer_engine/common/include/transformer_engine/fused_rope.h index b92de88eca..b7b9b93881 100644 --- a/transformer_engine/common/include/transformer_engine/fused_rope.h +++ b/transformer_engine/common/include/transformer_engine/fused_rope.h @@ -72,6 +72,8 @@ void nvte_fused_rope_backward(const NVTETensor output_grads, const NVTETensor fr * \param[in] cu_seqlens The cumulative sum of sequence lengths tensor. * \param[in] freqs The freqs tensor. * \param[out] output Output tensor. + * \param[in] cp_size Context parallel world size. + * \param[in] cp_rank Context parallel rank. * \param[in] max_s Max sequence length. * \param[in] b Batch size. * \param[in] h Length of the h dimension of input. @@ -86,11 +88,11 @@ void nvte_fused_rope_backward(const NVTETensor output_grads, const NVTETensor fr * \param[in] stream CUDA stream used for the operation. */ void nvte_fused_rope_thd_forward(const NVTETensor input, const NVTETensor cu_seqlens, - const NVTETensor freqs, NVTETensor output, const int max_s, - const int b, const int h, const int d, const int d2, - const int stride_t, const int stride_h, const int stride_d, - const int o_stride_t, const int o_stride_h, const int o_stride_d, - cudaStream_t stream); + const NVTETensor freqs, NVTETensor output, const int cp_size, + const int cp_rank, const int max_s, const int b, const int h, + const int d, const int d2, const int stride_t, const int stride_h, + const int stride_d, const int o_stride_t, const int o_stride_h, + const int o_stride_d, cudaStream_t stream); /*! \brief Compute the backward of the fused rope in thd format. * @@ -98,6 +100,8 @@ void nvte_fused_rope_thd_forward(const NVTETensor input, const NVTETensor cu_seq * \param[in] cu_seqlens The cumulative sum of sequence lengths tensor. * \param[in] freqs The freqs tensor. * \param[out] input_grads Input gradient to calculate. + * \param[in] cp_size Context parallel world size. + * \param[in] cp_rank Context parallel rank. * \param[in] max_s Max sequence length. * \param[in] b Batch size. * \param[in] h Length of the h dimension of output_grads. @@ -112,11 +116,11 @@ void nvte_fused_rope_thd_forward(const NVTETensor input, const NVTETensor cu_seq * \param[in] stream CUDA stream used for the operation. */ void nvte_fused_rope_thd_backward(const NVTETensor output_grads, const NVTETensor cu_seqlens, - const NVTETensor freqs, NVTETensor input_grads, const int max_s, - const int b, const int h, const int d, const int d2, - const int stride_t, const int stride_h, const int stride_d, - const int o_stride_t, const int o_stride_h, const int o_stride_d, - cudaStream_t stream); + const NVTETensor freqs, NVTETensor input_grads, const int cp_size, + const int cp_rank, const int max_s, const int b, const int h, + const int d, const int d2, const int stride_t, const int stride_h, + const int stride_d, const int o_stride_t, const int o_stride_h, + const int o_stride_d, cudaStream_t stream); #ifdef __cplusplus } // extern "C" diff --git a/transformer_engine/jax/attention.py b/transformer_engine/jax/attention.py index 9b8279be25..b3b11bb9dd 100644 --- a/transformer_engine/jax/attention.py +++ b/transformer_engine/jax/attention.py @@ -86,6 +86,66 @@ def get_qkv_format(qkv_layout): return QKVFormat(nvte_get_qkv_format(qkv_layout.value)) +def make_swa_mask( + max_seqlen_q: int, + max_seqlen_kv: int, + window_size: Optional[Tuple[int, int]] = None, + attn_mask_type: AttnMaskType = AttnMaskType.NO_MASK, + dtype: jax.typing.DTypeLike = jnp.float32, +): + """ + Generate sliding window mask. `True` or `1` means keep the element. + + For `CAUSAL_BOTTOM_RIGHT_MASK` and `PADDING_CAUSAL_BOTTOM_RIGHT_MASK` mask type, + the sliding window diagonal is aligned to the bottom right corner, and for other + mask types, the top left corner. + + Parameters + ---------- + max_seqlen_q: int + Maximum sequence length for queries. + max_seqlen_kv: int + Maximum sequence length for keys and values. + window_size: Optional[Tuple[int, int]] = None + Sliding window size for local attention, where query at position i attends to keys + in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + + window_size[1]] inclusive. Negative number in window size means infinity window. + `None` means no sliding window. + attn_mask_type: AttnMaskType, default = AttnMaskType.NO_MASK + dtype: jax.typing.DTypeLike, default=jnp.float32 + The mask data type. + Returns + ---------- + swa_mask: jax.numpy.tensor + Matrix with shape [max_seqlen_q, max_seqlen_kv]. Elements with value 1 are the positions + that will get attention, value 0 are the masked out positions. + """ + swa_mask = jnp.ones((max_seqlen_q, max_seqlen_kv), dtype=dtype) + if window_size is None: + return swa_mask + bottom_right_masks = [ + AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK, + AttnMaskType.PADDING_CAUSAL_BOTTOM_RIGHT_MASK, + ] + left_window, right_window = window_size + if attn_mask_type in bottom_right_masks: + if left_window < 0: + left_window = max_seqlen_kv + if right_window < 0: + right_window = max_seqlen_kv + bottom_right_shift = max_seqlen_kv - max_seqlen_q + swa_mask = jnp.triu(swa_mask, k=-left_window + bottom_right_shift) + swa_mask = jnp.tril(swa_mask, k=right_window + bottom_right_shift) + else: + if left_window < 0: + left_window = max_seqlen_q + if right_window < 0: + right_window = max_seqlen_q + swa_mask = jnp.triu(swa_mask, k=-left_window) + swa_mask = jnp.tril(swa_mask, k=right_window) + return swa_mask + + def canonicalize_attn_mask_type(attn_mask_type: str): """Convert string attn_mask_type to AttnMaskType TE-JAX currently fall back to the padding version kernels for the libraries integration. @@ -129,23 +189,38 @@ def is_fused_attn_kernel_available( q_max_seqlen, kv_max_seqlen, head_dim, + window_size: Optional[Tuple[int, int]] = None, + is_context_parallel: bool = False, ): """ To check whether the fused attention kernel is supported """ - return tex.FusedAttnHelper( - q_dtype, - kv_dtype, - qkv_layout.value, - attn_bias_type.value, - attn_mask_type.value, - dropout_probability, - q_num_heads, - kv_num_heads, - q_max_seqlen, - kv_max_seqlen, - head_dim, - ).is_fused_attn_kernel_available() + + def make_helper(attn_mask_type): + return tex.FusedAttnHelper( + q_dtype, + kv_dtype, + qkv_layout.value, + attn_bias_type.value, + attn_mask_type.value, + dropout_probability, + q_num_heads, + kv_num_heads, + q_max_seqlen, + kv_max_seqlen, + head_dim, + (-1, -1) if window_size is None else window_size, + ) + + if not make_helper(attn_mask_type).is_fused_attn_kernel_available(): + return False + + # For context parallel need to check additional masking types + if is_context_parallel and attn_mask_type == AttnMaskType.CAUSAL_MASK: + if not make_helper(AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK).is_fused_attn_kernel_available(): + return False + + return True def _obtain_batch_and_max_seqlen(qkv, qkv_layout): @@ -167,73 +242,16 @@ def _obtain_batch_and_max_seqlen(qkv, qkv_layout): return batch, q_max_seqlen, kv_max_seqlen -def _reorder_causal_load_balancing(tensor, cp_size: int, tensor_format: QKVFormat, inverse: bool): - match tensor_format: - case QKVFormat.SBHD: - seq_dim = 0 - case QKVFormat.BSHD: - seq_dim = 1 - case _: - raise ValueError(f"{tensor_format=} is not supported for causal load balancing.") - - if cp_size == 1: - return tensor - - if cp_size % 2 != 0: - raise ValueError(f"{cp_size=} must be a multiple of 2.") - - # Need to ensure we have 2 pairs to swap for balancing between cp ranks - if tensor.shape[seq_dim] % (cp_size * 2) != 0: - raise ValueError(f"{tensor.shape=} is not a multiple of {cp_size*2=}") - - # [B, S, H, D] -> [B, 2*cp_size, S/2*cp_size, D] - # [S, B, H, D] -> [2*cp_size, S/2*cp_size, B, H, D] - ori_tensor_shape = tensor.shape - tensor = tensor.reshape( - ( - *ori_tensor_shape[:seq_dim], - 2 * cp_size, - ori_tensor_shape[seq_dim] // (2 * cp_size), - *ori_tensor_shape[seq_dim + 1 :], - ) - ) - - parts = [] - if not inverse: - for cp_rank in range(cp_size): - # [B, S, H, D]: [B, 2*cp_size, S/2*cp_size, H, D] -> [B, 2, S/2*cp_size, H, D] - # [S, B, H, D]: [2*cp_size, S/2*cp_size, B, H, D] -> [2, S/2*cp_size, B, H, D] - index = jnp.array([cp_rank, (2 * cp_size - cp_rank - 1)]) - parts.append(jnp.take(tensor, index, axis=seq_dim)) - else: - for cp_rank in range(cp_size // 2): - # [B, S, H, D]: [B, 2*cp_size, S/2*cp_size, H, D] -> [B, 2, S/2*cp_size, H, D] - # [S, B, H, D]: [2*cp_size, S/2*cp_size, B, H, D] -> [2, S/2*cp_size, B, H, D] - base = 4 * cp_rank - index = jnp.array([base, base + 2]) - parts.append(jnp.take(tensor, index, axis=seq_dim)) - for cp_rank in range(cp_size // 2): - # [B, S, H, D]: [B, 2*cp_size, S/2*cp_size, H, D] -> [B, 2, S/2*cp_size, H, D] - # [S, B, H, D]: [2*cp_size, S/2*cp_size, B, H, D] -> [2, S/2*cp_size, B, H, D] - base = 2 * cp_size - 1 - 4 * cp_rank - index = jnp.array([base, base - 2]) - parts.append(jnp.take(tensor, index, axis=seq_dim)) - - # [B, S, H, D]: [B, 2*cp_size, S/2*cp_size, H, D] - # [S, B, H, D]: [2*cp_size, S/2*cp_size, B, H, D] - combined = jnp.stack(parts, axis=seq_dim) - - return combined.reshape(ori_tensor_shape) - - def reorder_causal_load_balancing(tensor, cp_size: int, tensor_format: QKVFormat): """Reorders a tensor for load balancing the compute of causal attention.""" - return _reorder_causal_load_balancing(tensor, cp_size, tensor_format, False) + seq_dim = 1 if tensor_format == QKVFormat.BSHD else 0 + return tex.attention.reorder_causal_load_balancing(tensor, cp_size, seq_dim, False) def inverse_reorder_causal_load_balancing(tensor, cp_size: int, tensor_format: QKVFormat): """Inverse operation of `reorder_causal_load_balancing`.""" - return _reorder_causal_load_balancing(tensor, cp_size, tensor_format, True) + seq_dim = 1 if tensor_format == QKVFormat.BSHD else 0 + return tex.attention.reorder_causal_load_balancing(tensor, cp_size, seq_dim, True) def fused_attn( @@ -247,6 +265,7 @@ def fused_attn( scaling_factor: float, dropout_probability: float, is_training: bool, + window_size: Optional[Tuple[int, int]] = None, context_parallel_causal_load_balanced: bool = False, context_parallel_axis: str = "", ): @@ -275,6 +294,7 @@ def fused_attn( scaling_factor (float): Scaling factor for the attention scores. dropout_probability (float): Dropout probability to apply during attention. is_training (bool): Flag indicating whether the model is in training mode. + window_size (Optional[Tuple[int, int]]): Sliding window size. context_parallel_causal_load_balanced (bool): Indicates the sequences are ordered for causal mask load balancing when running context parallelism. context_parallel_axis (str): The name of the context parallel axis. @@ -332,6 +352,7 @@ def fused_attn( dropout_probability=dropout_probability, is_training=is_training, max_segments_per_seq=1, + window_size=window_size, context_parallel_causal_load_balanced=context_parallel_causal_load_balanced, context_parallel_axis=context_parallel_axis, ) @@ -354,6 +375,7 @@ def fused_attn_thd( dropout_probability: float, is_training: bool, max_segments_per_seq: int = 1, + window_size: Optional[Tuple[int, int]] = None, context_parallel_causal_load_balanced: bool = False, context_parallel_axis: str = "", ): @@ -394,6 +416,8 @@ def fused_attn_thd( Indicating the maximum number of segments inside a sequence. This parameter is to constrain the limit usage and need to be static during the e2e training. The XLA compile time and memory consumption is proportional to `max_segments_per_seq`. + window_size (Optional[Tuple[int, int]]): + Sliding window size. context_parallel_causal_load_balanced (bool): Indicates the sequences are ordered for causal mask load balancing when running context parallelism. context_parallel_axis (str): The name of the context parallel axis. @@ -451,6 +475,7 @@ def fused_attn_thd( dropout_probability=dropout_probability, is_training=is_training, max_segments_per_seq=max_segments_per_seq, + window_size=window_size, context_parallel_causal_load_balanced=context_parallel_causal_load_balanced, context_parallel_axis=context_parallel_axis, ) @@ -458,7 +483,7 @@ def fused_attn_thd( return output -@partial(jax.custom_vjp, nondiff_argnums=(7, 8, 9, 10, 11, 12, 13, 14, 15)) +@partial(jax.custom_vjp, nondiff_argnums=(7, 8, 9, 10, 11, 12, 13, 14, 15, 16)) def _fused_attn( qkv: Tuple[jnp.ndarray, ...], bias: Optional[jnp.ndarray], @@ -474,6 +499,7 @@ def _fused_attn( dropout_probability: float, is_training: bool, max_segments_per_seq: int, + window_size: Optional[Tuple[int, int]], context_parallel_causal_load_balanced: bool, context_parallel_axis: str, ): @@ -492,6 +518,7 @@ def _fused_attn( dropout_probability, is_training, max_segments_per_seq, + window_size, context_parallel_causal_load_balanced, context_parallel_axis, ) @@ -513,6 +540,7 @@ def _fused_attn_fwd_rule( dropout_probability, is_training, max_segments_per_seq, + window_size, context_parallel_causal_load_balanced, context_parallel_axis, ): @@ -531,6 +559,7 @@ def _fused_attn_fwd_rule( dropout_probability=dropout_probability, is_training=is_training, max_segments_per_seq=max_segments_per_seq, + window_size=window_size, context_parallel_causal_load_balanced=context_parallel_causal_load_balanced, context_parallel_axis=context_parallel_axis, ) @@ -558,6 +587,7 @@ def _fused_attn_bwd_rule( dropout_probability, is_training, max_segments_per_seq, + window_size, context_parallel_causal_load_balanced, context_parallel_axis, ctx, @@ -592,6 +622,7 @@ def _fused_attn_bwd_rule( dropout_probability=dropout_probability, is_training=is_training, max_segments_per_seq=max_segments_per_seq, + window_size=window_size, context_parallel_causal_load_balanced=context_parallel_causal_load_balanced, context_parallel_axis=context_parallel_axis, ) diff --git a/transformer_engine/jax/cpp_extensions/attention.py b/transformer_engine/jax/cpp_extensions/attention.py index d5b901c107..7246e961bd 100644 --- a/transformer_engine/jax/cpp_extensions/attention.py +++ b/transformer_engine/jax/cpp_extensions/attention.py @@ -63,6 +63,7 @@ "dropout_probability", "is_training", "max_segments_per_seq", + "window_size", "context_parallel_load_balanced", "cp_axis", ], @@ -80,6 +81,7 @@ class _FusedAttnConfig: dropout_probability: float is_training: bool max_segments_per_seq: int + window_size: Tuple[int, int] context_parallel_load_balanced: bool cp_axis: str @@ -101,6 +103,7 @@ class FusedAttnHelper: q_max_seqlen: int kv_max_seqlen: int head_dim: int + window_size: Tuple[int, int] def is_fused_attn_kernel_available(self): """Check if there is available fused attention kernel""" @@ -120,6 +123,8 @@ def get_fused_attn_backend(self): self.q_max_seqlen, self.kv_max_seqlen, self.head_dim, + self.window_size[0], + self.window_size[1], ) @staticmethod @@ -263,6 +268,7 @@ def abstract( q_max_seqlen, kv_max_seqlen, head_dim, + config.window_size, ).get_fused_attn_backend() if backend == NVTE_Fused_Attn_Backend.NVTE_F16_max512_seqlen: @@ -309,6 +315,8 @@ def abstract( jax_dtype_to_te_dtype(q_aval.dtype), config.is_training, config.max_segments_per_seq, + config.window_size[0], + config.window_size[1], ) wkspace_aval = q_aval.update( shape=wkspace_info[0], dtype=te_dtype_to_jax_dtype(wkspace_info[1]) @@ -388,6 +396,8 @@ def lowering( jax_dtype_to_te_dtype(wkspace_aval.dtype), config.is_training, not FusedAttnHelper.is_non_deterministic_allowed(), + config.window_size[0], + config.window_size[1], ) out = custom_caller(FusedAttnFwdPrimitive.name, args, opaque, has_side_effect=False) @@ -615,6 +625,8 @@ def abstract( config.is_training, deterministic, config.max_segments_per_seq, + config.window_size[0], + config.window_size[1], ) dq_aval = q_aval.update(shape=q_aval.shape, dtype=q_dtype) @@ -714,6 +726,8 @@ def lowering( jax_dtype_to_te_dtype(wkspace_aval.dtype), config.is_training, not FusedAttnHelper.is_non_deterministic_allowed(), + config.window_size[0], + config.window_size[1], ) out = custom_caller(FusedAttnBwdPrimitive.name, args, opaque, has_side_effect=False) @@ -897,6 +911,58 @@ def sharded_impl( register_primitive(FusedAttnBwdPrimitive) +def reorder_causal_load_balancing(tensor, cp_size: int, seq_dim: int, to_contiguous: bool): + """Reorders a tensor for load balancing the compute of causal attention.""" + if cp_size == 1: + return tensor + + if cp_size % 2 != 0: + raise ValueError(f"{cp_size=} must be a multiple of 2.") + + # Need to ensure we have 2 pairs to swap for balancing between cp ranks + if tensor.shape[seq_dim] % (cp_size * 2) != 0: + raise ValueError(f"{tensor.shape=} is not a multiple of {cp_size*2=}") + + # [B, S, H, D] -> [B, 2*cp_size, S/2*cp_size, D] + # [S, B, H, D] -> [2*cp_size, S/2*cp_size, B, H, D] + ori_tensor_shape = tensor.shape + tensor = tensor.reshape( + ( + *ori_tensor_shape[:seq_dim], + 2 * cp_size, + ori_tensor_shape[seq_dim] // (2 * cp_size), + *ori_tensor_shape[seq_dim + 1 :], + ) + ) + + parts = [] + if not to_contiguous: + for cp_rank in range(cp_size): + # [B, S, H, D]: [B, 2*cp_size, S/2*cp_size, H, D] -> [B, 2, S/2*cp_size, H, D] + # [S, B, H, D]: [2*cp_size, S/2*cp_size, B, H, D] -> [2, S/2*cp_size, B, H, D] + index = jnp.array([cp_rank, (2 * cp_size - cp_rank - 1)]) + parts.append(jnp.take(tensor, index, axis=seq_dim)) + else: + for cp_rank in range(cp_size // 2): + # [B, S, H, D]: [B, 2*cp_size, S/2*cp_size, H, D] -> [B, 2, S/2*cp_size, H, D] + # [S, B, H, D]: [2*cp_size, S/2*cp_size, B, H, D] -> [2, S/2*cp_size, B, H, D] + base = 4 * cp_rank + index = jnp.array([base, base + 2]) + parts.append(jnp.take(tensor, index, axis=seq_dim)) + for cp_rank in range(cp_size // 2): + # [B, S, H, D]: [B, 2*cp_size, S/2*cp_size, H, D] -> [B, 2, S/2*cp_size, H, D] + # [S, B, H, D]: [2*cp_size, S/2*cp_size, B, H, D] -> [2, S/2*cp_size, B, H, D] + base = 2 * cp_size - 1 - 4 * cp_rank + index = jnp.array([base, base - 2]) + parts.append(jnp.take(tensor, index, axis=seq_dim)) + + # [B, S, H, D]: [B, 2*cp_size, S/2*cp_size, H, D] + # [S, B, H, D]: [2*cp_size, S/2*cp_size, B, H, D] + combined = jnp.stack(parts, axis=seq_dim) + + return combined.reshape(ori_tensor_shape) + + @dataclass(frozen=True) class _FusedAttnCPWithAllGatherHelper: """Helper class to assist with running the all-gather strategy for CP attention.""" @@ -909,26 +975,30 @@ def check_supported(self): header = "Context parallel fused attention" allowed_layouts = [NVTE_QKV_Layout.NVTE_BSHD_BS2HD, NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD] - assert self.config.qkv_layout in allowed_layouts, ( - f"{header} only supports layouts: {','.join([str(x) for x in allowed_layouts])} got:" - f" {self.config.qkv_layout}" - ) + if self.config.qkv_layout not in allowed_layouts: + raise ValueError( + f"{header} only supports layouts:" + f" {','.join([str(x) for x in allowed_layouts])} got: {self.config.qkv_layout}" + ) - assert ( - self.config.attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS - ), f"{header} does not support bias got: {self.config.attn_bias_type}" + if self.config.attn_bias_type != NVTE_Bias_Type.NVTE_NO_BIAS: + raise ValueError(f"{header} does not support bias got: {self.config.attn_bias_type}") allowed_masks = [NVTE_Mask_Type.NVTE_NO_MASK, NVTE_Mask_Type.NVTE_CAUSAL_MASK] - assert self.config.attn_mask_type in allowed_masks, ( - f"{header} only supports masking types: " - f" {','.join([str(x) for x in allowed_masks])} got: {self.config.attn_mask_type}" - ) + if self.config.attn_mask_type not in allowed_masks: + raise ValueError( + f"{header} only supports masking types: " + f" {','.join([str(x) for x in allowed_masks])} got: {self.config.attn_mask_type}" + ) - assert self.config.max_segments_per_seq == 1, ( - f"{header} only supports max_segments_per_seq == 1 got:" - f" {self.config.max_segments_per_seq}" - ) - assert self.config.dropout_probability == 0.0, f"{header} does not support dropout" + if self.config.max_segments_per_seq != 1: + raise ValueError( + f"{header} only supports max_segments_per_seq == 1 got:" + f" {self.config.max_segments_per_seq}" + ) + + if self.config.dropout_probability != 0.0: + raise ValueError(f"{header} does not support dropout") def get_adjusted_mask(self): """Converts the mask for context parallelism.""" @@ -936,13 +1006,32 @@ def get_adjusted_mask(self): return NVTE_Mask_Type.NVTE_CAUSAL_BOTTOM_RIGHT_MASK return self.config.attn_mask_type + def get_step_config(self) -> _FusedAttnConfig: + """Returns a _FusedAttnConfig for single CP step call to fused attention.""" + return _FusedAttnConfig( + attn_bias_type=self.config.attn_bias_type, + attn_mask_type=self.get_adjusted_mask(), + qkv_layout=self.config.qkv_layout, + scaling_factor=self.config.scaling_factor, + dropout_probability=self.config.dropout_probability, + is_training=self.config.is_training, + max_segments_per_seq=self.config.max_segments_per_seq, + window_size=self.config.window_size, + context_parallel_load_balanced=self.config.context_parallel_load_balanced, + cp_axis=self.config.cp_axis, + ) + def all_gather_kv(self, k, v): """Performs a all-gather of k and v over context parallel ranks.""" def ag(x): - return lax_paral_op( + x = lax_paral_op( x, lax.all_gather, self.config.cp_axis, mesh=self.mesh, axis=1, tiled=True ) + if self.config.context_parallel_load_balanced: + cp_size = get_mesh_axis_size(self.config.cp_axis, self.mesh) + x = reorder_causal_load_balancing(x, cp_size, 1, to_contiguous=True) + return x match self.config.qkv_layout: case NVTE_QKV_Layout.NVTE_BSHD_BS2HD: @@ -956,6 +1045,10 @@ def reduce_scatter_dkv(self, dk, dv): """Performs a reduce-scatter of dk and dv over context parallel ranks.""" def rs(x): + if self.config.context_parallel_load_balanced: + cp_size = get_mesh_axis_size(self.config.cp_axis, self.mesh) + x = reorder_causal_load_balancing(x, cp_size, 1, to_contiguous=False) + return lax_paral_op( x, lax.psum_scatter, @@ -1042,6 +1135,9 @@ class FusedAttnCPWithAllGatherFwdPrimitive(FusedAttnFwdPrimitive): def partition(config, mesh, arg_infos, result_infos): # Call base implementation for non-context parallel mesh to avoid unecessary work. is_context_parallel = get_mesh_axis_size(config.cp_axis, mesh) > 1 + assert ( + not is_context_parallel or config.window_size[0] == -1 + ), "Sliding window attention is not supported when context parallelism is enabled" if not is_context_parallel: return FusedAttnFwdPrimitive.partition(config, mesh, arg_infos, result_infos) @@ -1057,7 +1153,6 @@ def partition(config, mesh, arg_infos, result_infos): out_shardings = (out_sharding, softmax_aux_sharding, rng_state_sharding) def impl(q, k, v, bias, q_seqlen, kv_seqlen, q_seq_offsets, k_seq_offsets, seed): - cp_size = get_mesh_axis_size(config.cp_axis, mesh) cp_rank = get_mesh_axis_rank(config.cp_axis, mesh) @@ -1099,7 +1194,7 @@ def _cross_attn(idx, q, k, v, bias, q_seqlen, kv_seqlen, seed): q_seq_offsets, k_seq_offsets, seed, - config=config, + config=helper.get_step_config(), ) results.append((output, softmax_aux, rng_state)) @@ -1136,6 +1231,9 @@ class FusedAttnCPWithAllGatherBwdPrimitive(FusedAttnBwdPrimitive): def partition(config, mesh, arg_infos, result_infos): # Call base implementation for non-context parallel mesh to avoid unecessary work. is_context_parallel = get_mesh_axis_size(config.cp_axis, mesh) > 1 + assert ( + not is_context_parallel or config.window_size[0] == -1 + ), "Sliding window attention is not supported when context parallelism is enabled" if not is_context_parallel: return FusedAttnBwdPrimitive.partition(config, mesh, arg_infos, result_infos) @@ -1213,7 +1311,7 @@ def _cross_attn_bwd( kv_seqlen_for_step, q_seq_offsets, k_seq_offsets, - config=config, + config=helper.get_step_config(), ) # pad dk/dv to be unsliced shape so we can reduce scatter over all ranks. @@ -1284,6 +1382,7 @@ def fused_attn_fwd( dropout_probability: float, is_training: bool, max_segments_per_seq: int, + window_size: Optional[Tuple[int, int]] = None, context_parallel_causal_load_balanced: bool = False, context_parallel_axis: str = "", ) -> jnp.ndarray: @@ -1314,6 +1413,11 @@ def fused_attn_fwd( scaling_factor (float): Scaling factor for the attention scores. dropout_probability (float): Dropout probability to apply during attention. is_training (bool): Flag indicating whether the model is in training mode. + max_segments_per_seq (int): + Indicating the maximum number of segments inside a sequence. This parameter is to + constrain the limit usage and need to be static during the e2e training. The XLA compile + time and memory consumption is proportional to `max_segments_per_seq`. + window_size (Optional[Tuple[int, int]]): Sliding window size. context_parallel_causal_load_balanced (bool): Indicates the sequences are ordered for causal mask load balancing when running context parallelism. context_parallel_axis (str): The name of the context parallel axis. @@ -1356,6 +1460,7 @@ def fused_attn_fwd( dropout_probability=dropout_probability, is_training=is_training, max_segments_per_seq=max_segments_per_seq, + window_size=(-1, -1) if window_size is None else window_size, context_parallel_load_balanced=context_parallel_causal_load_balanced, cp_axis=_maybe_context_parallel_axis(context_parallel_axis), ) @@ -1390,6 +1495,7 @@ def fused_attn_bwd( dropout_probability: float, is_training: bool, max_segments_per_seq: int, + window_size: Optional[Tuple[int, int]] = None, context_parallel_causal_load_balanced: bool = False, context_parallel_axis: str = "", ): @@ -1421,6 +1527,11 @@ def fused_attn_bwd( scaling_factor (float): Scaling factor for the attention scores. dropout_probability (float): Dropout probability to apply during attention. is_training (bool): Flag indicating whether the model is in training mode. + max_segments_per_seq (int): + Indicating the maximum number of segments inside a sequence. This parameter is to + constrain the limit usage and need to be static during the e2e training. The XLA compile + time and memory consumption is proportional to `max_segments_per_seq`. + window_size (Optional[Tuple[int, int]]): Sliding window size . context_parallel_causal_load_balanced (bool): Indicates the sequences are ordered for causal mask load balancing when running context parallelism. context_parallel_axis (str): The name of the context parallel axis. @@ -1466,6 +1577,7 @@ def fused_attn_bwd( dropout_probability=dropout_probability, is_training=is_training, max_segments_per_seq=max_segments_per_seq, + window_size=(-1, -1) if window_size is None else window_size, context_parallel_load_balanced=context_parallel_causal_load_balanced, cp_axis=_maybe_context_parallel_axis(context_parallel_axis), ) diff --git a/transformer_engine/jax/csrc/extensions.h b/transformer_engine/jax/csrc/extensions.h index b872370715..c233177e28 100644 --- a/transformer_engine/jax/csrc/extensions.h +++ b/transformer_engine/jax/csrc/extensions.h @@ -135,6 +135,8 @@ struct CustomCallFusedAttnDescriptor { DType wkspace_dtype; bool is_training; bool deterministic; + int64_t window_size_left; + int64_t window_size_right; }; pybind11::bytes PackCustomCallFusedAttnDescriptor( @@ -143,7 +145,7 @@ pybind11::bytes PackCustomCallFusedAttnDescriptor( size_t max_segments_per_seq, size_t wkspace_size, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, DType wkspace_dtype, bool is_training, - bool deterministic); + bool deterministic, int64_t window_size_left, int64_t window_size_right); // Transpose @@ -239,14 +241,15 @@ NVTE_Fused_Attn_Backend GetFusedAttnBackend(DType q_dtype, DType kv_dtype, NVTE_Mask_Type mask_type, float dropout_probability, size_t q_num_heads, size_t kv_num_heads, size_t q_max_seqlen, size_t kv_max_seqlen, - size_t head_dim); + size_t head_dim, int64_t window_size_left, + int64_t window_size_right); pybind11::tuple GetFusedAttnForwardWorkspaceSizes( size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen, size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, bool is_training, - size_t max_segments_per_seq); + size_t max_segments_per_seq, int64_t window_size_left, int64_t window_size_right); void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len); @@ -255,7 +258,8 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes( size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, bool is_training, - bool deterministic, size_t max_segments_per_seq); + bool deterministic, size_t max_segments_per_seq, int64_t window_size_left, + int64_t window_size_right); void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len); diff --git a/transformer_engine/jax/csrc/extensions/attention.cpp b/transformer_engine/jax/csrc/extensions/attention.cpp index 1d367f5cc1..90aa3f6e2b 100644 --- a/transformer_engine/jax/csrc/extensions/attention.cpp +++ b/transformer_engine/jax/csrc/extensions/attention.cpp @@ -15,11 +15,12 @@ NVTE_Fused_Attn_Backend GetFusedAttnBackend(DType q_dtype, DType kv_dtype, NVTE_Mask_Type mask_type, float dropout_probability, size_t q_attn_heads, size_t kv_attn_heads, size_t q_max_seqlen, size_t kv_max_seqlen, - size_t head_dim) { + size_t head_dim, int64_t window_size_left, + int64_t window_size_right) { auto backend = nvte_get_fused_attn_backend( static_cast(q_dtype), static_cast(kv_dtype), qkv_layout, bias_type, mask_type, dropout_probability, q_attn_heads, kv_attn_heads, q_max_seqlen, kv_max_seqlen, - head_dim, head_dim, -1, -1); + head_dim, head_dim, window_size_left, window_size_right); return backend; } @@ -105,7 +106,7 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes( size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, bool is_training, - size_t max_segments_per_seq) { + size_t max_segments_per_seq, int64_t window_size_left, int64_t window_size_right) { // For qkv_packed auto qkv_shape = std::vector{input_batch * q_max_seqlen, 3, attn_heads, head_dim}; auto qkv_tensor = TensorWrapper(nullptr, qkv_shape, dtype); @@ -155,27 +156,28 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes( TensorWrapper(nullptr, std::vector{num_segments + 1}, DType::kInt32); if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) { NVTE_CHECK(q_max_seqlen == kv_max_seqlen, "q_max_seqlen must equal to kv_max_seqlen"); - nvte_fused_attn_fwd_qkvpacked(qkv_tensor.data(), bias_tensor.data(), s_tensor.data(), - o_tensor.data(), &aux_output_tensors, - q_cu_seqlens_tensor.data(), ragged_offset_tensor.data(), - dummy_rng_state_tensor.data(), q_max_seqlen, is_training, - scaling_factor, dropout_probability, qkv_layout, bias_type, - mask_type, -1, -1, query_workspace_tensor.data(), nullptr); + nvte_fused_attn_fwd_qkvpacked( + qkv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(), + &aux_output_tensors, q_cu_seqlens_tensor.data(), ragged_offset_tensor.data(), + dummy_rng_state_tensor.data(), q_max_seqlen, is_training, scaling_factor, + dropout_probability, qkv_layout, bias_type, mask_type, window_size_left, + window_size_right, query_workspace_tensor.data(), nullptr); } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) { nvte_fused_attn_fwd_kvpacked( q_tensor.data(), kv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), ragged_offset_tensor.data(), ragged_offset_tensor.data(), dummy_rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training, scaling_factor, dropout_probability, qkv_layout, - bias_type, mask_type, -1, -1, query_workspace_tensor.data(), nullptr); + bias_type, mask_type, window_size_left, window_size_right, query_workspace_tensor.data(), + nullptr); } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) { - nvte_fused_attn_fwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), bias_tensor.data(), - s_tensor.data(), o_tensor.data(), &aux_output_tensors, - q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), - ragged_offset_tensor.data(), ragged_offset_tensor.data(), - dummy_rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training, - scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, -1, - -1, query_workspace_tensor.data(), nullptr); + nvte_fused_attn_fwd( + q_tensor.data(), k_tensor.data(), v_tensor.data(), bias_tensor.data(), s_tensor.data(), + o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(), + kv_cu_seqlens_tensor.data(), ragged_offset_tensor.data(), ragged_offset_tensor.data(), + dummy_rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training, scaling_factor, + dropout_probability, qkv_layout, bias_type, mask_type, window_size_left, + window_size_right, query_workspace_tensor.data(), nullptr); } else { NVTE_ERROR("Unsupported QKVLayout."); } @@ -223,6 +225,8 @@ void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, s auto dtype = descriptor.dtype; auto is_training = descriptor.is_training; auto max_segments_per_seq = descriptor.max_segments_per_seq; + auto window_size_left = descriptor.window_size_left; + auto window_size_right = descriptor.window_size_right; /* Input tensors */ auto q_shape = std::vector{input_batch * q_max_seqlen, attn_heads, head_dim}; @@ -269,7 +273,7 @@ void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, s auto backend = nvte_get_fused_attn_backend( static_cast(dtype), static_cast(dtype), qkv_layout, bias_type, mask_type, dropout_probability, attn_heads, num_gqa_groups, q_max_seqlen, kv_max_seqlen, - head_dim, head_dim, -1, -1); + head_dim, head_dim, window_size_left, window_size_right); PopulateRngStateAsync(rng_state, seed, q_max_seqlen, kv_max_seqlen, backend, stream); /* Auxiliary tensors (to be propagated to the backward pass later) */ @@ -288,12 +292,12 @@ void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, s auto qkv = buffers[0]; auto qkv_shape = std::vector{input_batch * q_max_seqlen, 3, attn_heads, head_dim}; auto qkv_tensor = TensorWrapper(qkv, qkv_shape, dtype); - nvte_fused_attn_fwd_qkvpacked(qkv_tensor.data(), bias_tensor.data(), s_tensor.data(), - o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(), - q_seq_offsets_tensor.data(), rng_state_tensor.data(), - q_max_seqlen, is_training, descriptor.scaling_factor, - dropout_probability, qkv_layout, bias_type, mask_type, -1, -1, - workspace_tensor.data(), stream); + nvte_fused_attn_fwd_qkvpacked( + qkv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(), + &aux_output_tensors, q_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(), + rng_state_tensor.data(), q_max_seqlen, is_training, descriptor.scaling_factor, + dropout_probability, qkv_layout, bias_type, mask_type, window_size_left, window_size_right, + workspace_tensor.data(), stream); } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) { auto q = buffers[0]; auto q_shape = std::vector{input_batch * q_max_seqlen, attn_heads, head_dim}; @@ -306,7 +310,7 @@ void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, s &aux_output_tensors, q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(), rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training, scaling_factor, dropout_probability, qkv_layout, - bias_type, mask_type, -1, -1, workspace_tensor.data(), stream); + bias_type, mask_type, window_size_left, window_size_right, workspace_tensor.data(), stream); } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) { auto q = buffers[0]; auto q_shape = std::vector{input_batch * q_max_seqlen, attn_heads, head_dim}; @@ -322,8 +326,8 @@ void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, s q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(), rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training, - scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, -1, - -1, workspace_tensor.data(), stream); + scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, + window_size_left, window_size_right, workspace_tensor.data(), stream); } else { NVTE_ERROR("Unsupported qkv_layout."); } @@ -336,7 +340,8 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes( size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, bool is_training, - bool deterministic, size_t max_segments_per_seq) { + bool deterministic, size_t max_segments_per_seq, int64_t window_size_left, + int64_t window_size_right) { // For qkv_packed auto qkv_shape = std::vector{input_batch * q_max_seqlen, 3, attn_heads, head_dim}; auto qkv_tensor = TensorWrapper(nullptr, qkv_shape, dtype); @@ -398,8 +403,8 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes( &aux_input_tensors, dqkv_tensor.data(), dbias_tensor.data(), q_cu_seqlens_tensor.data(), dummy_ragged_offset_tensor.data(), q_max_seqlen, scaling_factor, dropout_probability, qkv_layout, - bias_type, mask_type, -1, -1, deterministic, - query_workspace_tensor.data(), nullptr); + bias_type, mask_type, window_size_left, window_size_right, + deterministic, query_workspace_tensor.data(), nullptr); } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) { nvte_fused_attn_bwd_kvpacked( q_tensor.data(), kv_tensor.data(), output_tensor.data(), doutput_tensor.data(), @@ -408,8 +413,9 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes( &aux_input_tensors, dq_tensor.data(), dkv_tensor.data(), dbias_tensor.data(), q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(), q_max_seqlen, - kv_max_seqlen, scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, -1, - -1, deterministic, query_workspace_tensor.data(), nullptr); + kv_max_seqlen, scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, + window_size_left, window_size_right, deterministic, query_workspace_tensor.data(), + nullptr); } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) { nvte_fused_attn_bwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), output_tensor.data(), doutput_tensor.data(), @@ -419,8 +425,9 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes( dbias_tensor.data(), q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(), q_max_seqlen, kv_max_seqlen, - scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, -1, - -1, deterministic, query_workspace_tensor.data(), nullptr); + scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, + window_size_left, window_size_right, deterministic, + query_workspace_tensor.data(), nullptr); } else { NVTE_ERROR("Unsupported qkv_layout."); } @@ -470,6 +477,8 @@ void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque, auto dtype = descriptor.dtype; auto deterministic = descriptor.deterministic; auto max_segments_per_seq = descriptor.max_segments_per_seq; + auto window_size_left = descriptor.window_size_left; + auto window_size_right = descriptor.window_size_right; /* Input tensors */ auto output_shape = std::vector{input_batch * q_max_seqlen, attn_heads, head_dim}; @@ -513,7 +522,7 @@ void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque, auto backend = nvte_get_fused_attn_backend( static_cast(dtype), static_cast(dtype), qkv_layout, bias_type, mask_type, dropout_probability, attn_heads, num_gqa_groups, q_max_seqlen, kv_max_seqlen, - head_dim, head_dim, -1, -1); + head_dim, head_dim, window_size_left, window_size_right); PrepareFusedAttnBackwardAuxTensors(&aux_input_tensors, &descriptor, backend, softmax_aux, rng_state, bias); @@ -535,13 +544,14 @@ void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque, std::accumulate(qkv_shape.cbegin(), qkv_shape.cend(), 1, std::multiplies()); cudaMemsetAsync(dqkv, 0, dqkv_size * typeToSize(dtype), stream); } - nvte_fused_attn_bwd_qkvpacked( - qkv_tensor.data(), output_tensor.data(), doutput_tensor.data(), - s_tensor.data(), // not used for F16 - s_tensor.data(), // not used for F16 - &aux_input_tensors, dqkv_tensor.data(), dbias_tensor.data(), q_cu_seqlens_tensor.data(), - q_seq_offsets_tensor.data(), q_max_seqlen, scaling_factor, dropout_probability, qkv_layout, - bias_type, mask_type, -1, -1, deterministic, workspace_tensor.data(), stream); + nvte_fused_attn_bwd_qkvpacked(qkv_tensor.data(), output_tensor.data(), doutput_tensor.data(), + s_tensor.data(), // not used for F16 + s_tensor.data(), // not used for F16 + &aux_input_tensors, dqkv_tensor.data(), dbias_tensor.data(), + q_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(), + q_max_seqlen, scaling_factor, dropout_probability, qkv_layout, + bias_type, mask_type, window_size_left, window_size_right, + deterministic, workspace_tensor.data(), stream); } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) { auto q = buffers[0]; auto q_shape = std::vector{input_batch * q_max_seqlen, attn_heads, head_dim}; @@ -568,8 +578,8 @@ void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque, &aux_input_tensors, dq_tensor.data(), dkv_tensor.data(), dbias_tensor.data(), q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(), q_max_seqlen, kv_max_seqlen, scaling_factor, - dropout_probability, qkv_layout, bias_type, mask_type, -1, -1, deterministic, - workspace_tensor.data(), stream); + dropout_probability, qkv_layout, bias_type, mask_type, window_size_left, window_size_right, + deterministic, workspace_tensor.data(), stream); } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) { auto q = buffers[0]; auto q_shape = std::vector{input_batch * q_max_seqlen, attn_heads, head_dim}; @@ -604,8 +614,8 @@ void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque, dbias_tensor.data(), q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(), q_max_seqlen, kv_max_seqlen, scaling_factor, - dropout_probability, qkv_layout, bias_type, mask_type, -1, -1, - deterministic, workspace_tensor.data(), stream); + dropout_probability, qkv_layout, bias_type, mask_type, window_size_left, + window_size_right, deterministic, workspace_tensor.data(), stream); } else { NVTE_ERROR("Unsupported qkv_layout."); } diff --git a/transformer_engine/jax/csrc/extensions/packing.cpp b/transformer_engine/jax/csrc/extensions/packing.cpp index 128564db64..298478603b 100644 --- a/transformer_engine/jax/csrc/extensions/packing.cpp +++ b/transformer_engine/jax/csrc/extensions/packing.cpp @@ -69,11 +69,15 @@ pybind11::bytes PackCustomCallFusedAttnDescriptor( size_t max_segments_per_seq, size_t wkspace_size, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, DType wkspace_dtype, bool is_training, - bool deterministic) { - return PackOpaque(CustomCallFusedAttnDescriptor{ - input_batch, bias_batch, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups, bias_heads, - head_dim, max_segments_per_seq, wkspace_size, scaling_factor, dropout_probability, bias_type, - mask_type, qkv_layout, dtype, wkspace_dtype, is_training, deterministic}); + bool deterministic, int64_t window_size_left, int64_t window_size_right) { + return PackOpaque( + CustomCallFusedAttnDescriptor{input_batch, bias_batch, q_max_seqlen, + kv_max_seqlen, attn_heads, num_gqa_groups, + bias_heads, head_dim, max_segments_per_seq, + wkspace_size, scaling_factor, dropout_probability, + bias_type, mask_type, qkv_layout, + dtype, wkspace_dtype, is_training, + deterministic, window_size_left, window_size_right}); } } // namespace jax diff --git a/transformer_engine/jax/flax/transformer.py b/transformer_engine/jax/flax/transformer.py index c62c2bb77d..b91584219f 100644 --- a/transformer_engine/jax/flax/transformer.py +++ b/transformer_engine/jax/flax/transformer.py @@ -25,7 +25,7 @@ from .module import DenseGeneral, LayerNormDenseGeneral, LayerNormMLP from .module import LayerNorm, Softmax from ..attention import AttnBiasType, AttnMaskType, QKVLayout -from ..attention import is_fused_attn_kernel_available, canonicalize_attn_mask_type +from ..attention import is_fused_attn_kernel_available, make_swa_mask, canonicalize_attn_mask_type from ..attention import fused_attn from ..softmax import SoftmaxType from ..sharding import num_of_devices @@ -118,6 +118,7 @@ class _UnfusedDotProductAttention(nn.Module): # pylint: disable=too-few-public- float32_logits: bool = False scale_factor: Optional[float] = None transpose_batch_sequence: bool = True + window_size: Optional[Tuple[int, int]] = None @nn.compact def __call__( @@ -193,11 +194,27 @@ def __call__( if self.attn_bias_type == AttnBiasType.PRE_SCALE_BIAS: attn_weights += bias + def apply_swa_mask(attn_mask_type: AttnMaskType, original_mask: Array) -> Array: + """Apply the sliding window mask to a given mask""" + max_seqlen_q = original_mask.shape[-2] + max_seqlen_kv = original_mask.shape[-1] + swa_mask = make_swa_mask(max_seqlen_q, max_seqlen_kv, self.window_size, attn_mask_type) + # In swa_mask 0 is masked out, in original_mask 1 is masked out + swa_mask = 1 - swa_mask.astype(original_mask.dtype) + swa_mask_bcast = jnp.broadcast_to(swa_mask, original_mask.shape) + new_mask = jnp.where(original_mask == 0, swa_mask_bcast, original_mask) + return new_mask + def convert_to_softmax_type(attn_mask_type, mask): """Convert the attn_mask_type to SoftmaxType""" - # mask is ignored for no_mask and causal_mask - if attn_mask_type in [AttnMaskType.NO_MASK, AttnMaskType.CAUSAL_MASK]: + # mask is ignored for no_mask and causal_mask without sliding window + if attn_mask_type == AttnMaskType.NO_MASK: + mask = None + if attn_mask_type == AttnMaskType.CAUSAL_MASK and self.window_size is None: mask = None + if mask is not None: + mask = apply_swa_mask(attn_mask_type, mask) + # Currently cuDNN backend only supports SWA for causal/padding_causal, follow this if attn_mask_type in [AttnMaskType.CAUSAL_MASK, AttnMaskType.PADDING_CAUSAL_MASK]: return SoftmaxType.SCALED_UPPER_TRIANG_MASKED, mask if attn_mask_type in [AttnMaskType.NO_MASK, AttnMaskType.PADDING_MASK]: @@ -244,6 +261,7 @@ class _FusedDotProductAttention(nn.Module): # pylint: disable=too-few-public-me qkv_layout: QKVLayout = QKVLayout.BSHD_BSHD_BSHD scale_factor: Optional[float] = None transpose_batch_sequence: bool = False + window_size: Optional[Tuple[int, int]] = None @nn.compact def __call__( @@ -289,6 +307,7 @@ def __call__( scaling_factor=scale_factor, dropout_probability=self.attention_dropout, is_training=not deterministic, + window_size=self.window_size, ) elif self.qkv_layout == QKVLayout.BSHD_BS2HD: """kvpacked format, treat @@ -311,6 +330,7 @@ def __call__( scaling_factor=scale_factor, dropout_probability=self.attention_dropout, is_training=not deterministic, + window_size=self.window_size, ) elif self.qkv_layout == QKVLayout.BSHD_BSHD_BSHD: if self.transpose_batch_sequence: @@ -328,6 +348,7 @@ def __call__( scaling_factor=scale_factor, dropout_probability=self.attention_dropout, is_training=not deterministic, + window_size=self.window_size, ) else: raise ValueError(f"Unsupported {self.qkv_layout=}.") @@ -440,6 +461,8 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods Indicate whether the input tensors were switched axis of batch and sequence length dimension. if set to True, the input tensors should be in (seqlen, batch, ...), otherwise (batch, seqlen, ...). + window_size: Optional[Tuple[int, int]], default = None + Sliding window size. The default value is no sliding window. Optimization parameters ----------------------- @@ -459,6 +482,7 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods qkv_layout: str = "bshd_bshd_bshd" scale_factor: Optional[float] = None transpose_batch_sequence: bool = True + window_size: Optional[Tuple[int, int]] = None @nn.compact def __call__( @@ -532,6 +556,7 @@ def __call__( seqlen_q, seqlen_kv, self.head_dim, + self.window_size, ) use_fused_attn = enable_fused_attn and has_fused_attn_kernel @@ -577,6 +602,7 @@ def __call__( float32_logits=self.float32_logits, scale_factor=scale_factor, transpose_batch_sequence=self.transpose_batch_sequence, + window_size=self.window_size, )(query, key, value, mask, bias, dropout_rng=dropout_rng, deterministic=deterministic) else: x = _FusedDotProductAttention( @@ -587,6 +613,7 @@ def __call__( scale_factor=scale_factor, transpose_batch_sequence=self.transpose_batch_sequence, qkv_layout=qkv_layout, + window_size=self.window_size, )(query, key, value, mask, bias, dropout_rng=dropout_rng, deterministic=deterministic) return x @@ -856,6 +883,8 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods For fused attention backend, the accumulation is always float32 without the perf overhead. fuse_qkv: bool, default = None Deprecated. Please refer `fuse_qkv_params` + window_size: Optional[Tuple[int, int]], default = None + Sliding window size. Default value is no sliding window. """ head_dim: int @@ -886,6 +915,7 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods scale_attn_logits: bool = False scaled_query_init: bool = True float32_logits: bool = False + window_size: Optional[Tuple[int, int]] = None # Deprecated parameters num_heads: Optional[int] = None @@ -1280,6 +1310,7 @@ def generate_batch_seqlen_logical_axes(is_sharded_seq): qkv_layout=qkv_layout.name, scale_factor=scale_factor, transpose_batch_sequence=self.transpose_batch_sequence, + window_size=self.window_size, )(*dpa_args, mask, bias, deterministic=deterministic) x = x.reshape((x.shape[0], x.shape[1], x.shape[2] * x.shape[3])) @@ -1555,6 +1586,8 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods :math:`\frac{alpha}{rank} * lora\_output`. None means no scaling. enable_sequence_parallel: bool, default = False Whether to enable sequence parallelism to operations except dot. + window_size: Optional[Tuple[int, int]], default = None + Sliding window size. Default value is no sliding window. Optimization parameters ----------------------- @@ -1618,6 +1651,7 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods enable_sequence_parallel: bool = False scale_attn_logits: bool = False scaled_query_init: bool = True + window_size: Optional[Tuple[int, int]] = None def __post_init__(self): if self.mha_kernel_init is None: @@ -1771,6 +1805,7 @@ def generate_batch_seqlen_logical_axes(is_shared_seq=None): use_bias=self.use_bias, bias_init=self.bias_init, name=mha_name, + window_size=self.window_size, )(inputs, inputs, attention_mask, attn_bias, deterministic=deterministic, decode=decode) def hidden_dropout(x, deterministic): @@ -1848,6 +1883,7 @@ def hidden_dropout(x, deterministic): use_bias=self.use_bias, bias_init=self.bias_init, name="encoder_decoder_attention", + window_size=self.window_size, )(x, encoded, encoder_decoder_mask, deterministic=deterministic) y = with_sharding_constraint_by_logical_axes( diff --git a/transformer_engine/jax/praxis/transformer.py b/transformer_engine/jax/praxis/transformer.py index 2651144eee..f2ac802f10 100644 --- a/transformer_engine/jax/praxis/transformer.py +++ b/transformer_engine/jax/praxis/transformer.py @@ -80,6 +80,7 @@ class DotProductAttention(TransformerEngineBaseLayer): qkv_layout: str = "bshd_bshd_bshd" scale_factor: Optional[float] = None transpose_batch_sequence: bool = True + window_size: Optional[Tuple[int, int]] = None def setup(self) -> None: """setup""" @@ -102,6 +103,7 @@ def setup(self) -> None: qkv_layout=self.qkv_layout, scale_factor=self.scale_factor, transpose_batch_sequence=self.transpose_batch_sequence, + window_size=self.window_size, ) self.create_layer("dot_product_attention", dpa_cls) @@ -151,6 +153,7 @@ class MultiHeadAttention(TransformerEngineBaseLayer): scale_attn_logits: bool = False scaled_query_init: bool = True float32_logits: bool = False + window_size: Optional[Tuple[int, int]] = None # Deprecated parameters num_heads: Optional[int] = None @@ -233,6 +236,7 @@ def setup(self) -> None: scale_attn_logits=self.scale_attn_logits, scaled_query_init=self.scaled_query_init, float32_logits=self.float32_logits, + window_size=self.window_size, ) self.create_layer("multi_head_attn", mha_cls) @@ -292,6 +296,7 @@ class TransformerLayer(TransformerEngineBaseLayer): enable_sequence_parallel: bool = False scale_attn_logits: bool = False scaled_query_init: bool = True + window_size: Optional[Tuple[int, int]] = None def __post_init__(self): if self.num_gqa_groups is None: @@ -371,6 +376,7 @@ def setup(self) -> None: enable_sequence_parallel=self.enable_sequence_parallel, scale_attn_logits=self.scale_attn_logits, scaled_query_init=self.scaled_query_init, + window_size=self.window_size, ) self.create_layer("transformerlayer", transformerlayer_cls) diff --git a/transformer_engine/paddle/cpp_extensions.py b/transformer_engine/paddle/cpp_extensions.py index 7860da2496..281be66a8c 100644 --- a/transformer_engine/paddle/cpp_extensions.py +++ b/transformer_engine/paddle/cpp_extensions.py @@ -583,6 +583,7 @@ def fused_attn_fwd_qkvpacked( fused_attention_backend != FusedAttnBackend["No_Backend"] ), "Fused attention does not support this input combination." + rng_elts_per_thread = None # BF16/FP16 fused attention API from fmha_v1 apex if fused_attention_backend == FusedAttnBackend["F16_max512_seqlen"]: rng_elts_per_thread = ( @@ -773,6 +774,7 @@ def fused_attn_fwd_kvpacked( fused_attention_backend != FusedAttnBackend["No_Backend"] ), "Fused attention does not support this input combination." + rng_elts_per_thread = None # BF16/FP16 fused attention API from fmha_v1 apex if fused_attention_backend == FusedAttnBackend["F16_max512_seqlen"]: rng_elts_per_thread = ( @@ -982,6 +984,7 @@ def fused_attn_fwd( fused_attention_backend != FusedAttnBackend["No_Backend"] ), "Fused attention does not support this input combination." + rng_elts_per_thread = None # BF16/FP16 fused attention API from fmha_v1 apex if fused_attention_backend == FusedAttnBackend["F16_max512_seqlen"]: rng_elts_per_thread = ( diff --git a/transformer_engine/paddle/fp8_buffer.py b/transformer_engine/paddle/fp8_buffer.py index 8f14f3966e..a880ca8107 100644 --- a/transformer_engine/paddle/fp8_buffer.py +++ b/transformer_engine/paddle/fp8_buffer.py @@ -100,6 +100,7 @@ def _reduce_tensor_across_group_op_max(tensor, group, sync_op): self._dp_amax_reduce_interval = int(os.getenv("NVTE_DP_AMAX_REDUCE_INTERVAL", "1")) tp_amax_reduce = False + reduce_group = -1 # Set value that will raise error if not set. `None` is a valid group. if self._dp_amax_reduce_idx == 0: reduce_group = fp8_meta["fp8_group"] else: diff --git a/transformer_engine/paddle/layer/attention.py b/transformer_engine/paddle/layer/attention.py index 75a3513d14..3ff5a42ff5 100644 --- a/transformer_engine/paddle/layer/attention.py +++ b/transformer_engine/paddle/layer/attention.py @@ -1008,6 +1008,7 @@ def forward( else: raise ValueError(f"hidden_states should have 2 or 3 dimensions, got {input_dim}.") + layernorm_output = None if self.attention_type == "self": if self.input_layernorm: layernorm_qkv_outputs = self.layernorm_qkv( diff --git a/transformer_engine/paddle/layer/layernorm_mlp.py b/transformer_engine/paddle/layer/layernorm_mlp.py index 3958897be9..32f837183c 100644 --- a/transformer_engine/paddle/layer/layernorm_mlp.py +++ b/transformer_engine/paddle/layer/layernorm_mlp.py @@ -266,6 +266,8 @@ def _mlp_backward( accumulate_wgrad_into_param_main_grad, ) + dgelu_t = None + fc1_bgrad_ = None if activation == "gelu": # GELU Bwd dgelu, dgelu_t, fc1_bgrad_ = dgelu_cast_transpose_bgrad_fp8( diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index bacadf2cd5..be36b0375a 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -12,6 +12,7 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union import warnings import logging +import functools from dataclasses import dataclass, fields import numpy as np @@ -86,64 +87,125 @@ from transformer_engine.pytorch.graph import is_graph_capturing +# NVTE_DEBUG = 0/1 # disables/enables debug mode, default = 0 +_NVTE_DEBUG = int(os.getenv("NVTE_DEBUG", "0")) +# NVTE_DEBUG_LEVEL = 0/1/2 # enables more and more verbose debug mode, default = 0 +_NVTE_DEBUG_LEVEL = int(os.getenv("NVTE_DEBUG_LEVEL", "0")) +_log_level = _NVTE_DEBUG * _NVTE_DEBUG_LEVEL +_log_levels = {0: logging.WARNING, 1: logging.INFO, 2: logging.DEBUG} +_log_level = _log_levels[_log_level if _log_level in [0, 1, 2] else 2] +_formatter = logging.Formatter("[%(levelname)-8s | %(name)-19s]: %(message)s") +_stream_handler = logging.StreamHandler() +_stream_handler.setFormatter(_formatter) +fa_logger = logging.getLogger() +fa_logger.setLevel(_log_level) +if not fa_logger.hasHandlers(): + fa_logger.addHandler(_stream_handler) + + +@functools.lru_cache(maxsize=None) +def _get_supported_versions(version_min, version_max): + return ">= " + str(version_min) + ", " + "<= " + str(version_max) + + _NVTE_FLASH_ATTN = int(os.getenv("NVTE_FLASH_ATTN", "1")) _NVTE_FUSED_ATTN = int(os.getenv("NVTE_FUSED_ATTN", "1")) _NVTE_UNFUSED_ATTN = int(os.getenv("NVTE_UNFUSED_ATTN", "1")) -_flash_attn_version = PkgVersion(get_pkg_version("flash-attn")) -_flash_attn_version_required = PkgVersion("2.0.6") + +# Detect flash-attn v2 in the environment +_flash_attn_is_installed = False +_flash_attn_version = PkgVersion("0") +_flash_attn_version_required = PkgVersion("2.1.1") _flash_attn_max_version = PkgVersion("2.6.3") -_flash_attn_2_plus = _flash_attn_version >= PkgVersion("2") -_flash_attn_2_1_plus = _flash_attn_version >= PkgVersion("2.1") -_flash_attn_2_3_plus = _flash_attn_version >= PkgVersion("2.3") -_flash_attn_2_4_plus = _flash_attn_version >= PkgVersion("2.4") -_flash_attn_2_4_1_plus = _flash_attn_version >= PkgVersion("2.4.1") -_flash_attn_2_5_7_plus = _flash_attn_version >= PkgVersion("2.5.7") -_flash_attn_3_plus = False +_flash_attn_2_plus = False +_flash_attn_2_1_plus = False +_flash_attn_2_3_plus = False +_flash_attn_2_4_plus = False +_flash_attn_2_4_1_plus = False +_flash_attn_2_5_7_plus = False +_flash_attn_2_6_0_plus = False + +flash_attn_func = None +flash_attn_varlen_func = None +flash_attn_varlen_fwd = None +flash_attn_varlen_bwd = None +flash_attn_cuda_bwd = None + +try: + _flash_attn_version = PkgVersion(get_pkg_version("flash-attn")) +except PackageNotFoundError: + if torch.cuda.is_available() and get_device_compute_capability() >= (8, 0) and _NVTE_FLASH_ATTN: + fa_logger.debug( + "flash-attn v2 is not installed. To use, please install it by" + """ "pip install flash-attn".""", + ) +else: + if _flash_attn_version_required <= _flash_attn_version <= _flash_attn_max_version: + from flash_attn.flash_attn_interface import flash_attn_func, flash_attn_varlen_func + from flash_attn.flash_attn_interface import ( + _flash_attn_varlen_forward as flash_attn_varlen_fwd, + ) + from flash_attn.flash_attn_interface import ( + _flash_attn_varlen_backward as flash_attn_varlen_bwd, + ) + from flash_attn_2_cuda import varlen_bwd as flash_attn_cuda_bwd + + _flash_attn_is_installed = True + _flash_attn_2_plus = _flash_attn_version >= PkgVersion("2") + _flash_attn_2_1_plus = _flash_attn_version >= PkgVersion("2.1") + _flash_attn_2_3_plus = _flash_attn_version >= PkgVersion("2.3") + _flash_attn_2_4_plus = _flash_attn_version >= PkgVersion("2.4") + _flash_attn_2_4_1_plus = _flash_attn_version >= PkgVersion("2.4.1") + _flash_attn_2_5_7_plus = _flash_attn_version >= PkgVersion("2.5.7") + _flash_attn_2_6_0_plus = _flash_attn_version >= PkgVersion("2.6.0") + elif ( + torch.cuda.is_available() and get_device_compute_capability() >= (8, 0) and _NVTE_FLASH_ATTN + ): + fa_logger.warning( + "Supported flash-attn versions are %s. Found flash-attn %s.", + _get_supported_versions( + _flash_attn_version_required, + _flash_attn_max_version, + ), + _flash_attn_version, + ) + +# Detect flash-attn v3 in the environment +# This section will be removed when FA3 is released as a regular FA package, +# i.e. flashattn-hopper 3.0.0 as flash-attn 3.0.0 +_flash_attn_3_is_installed = False +_flash_attn_3_version = PkgVersion("0") +_flash_attn_3_0_0_beta = False _use_flash_attn_3 = False +_flash_attn_3_installation_steps = """\ +(1) pip install "git+https://github.com/Dao-AILab/flash-attention.git#egg=flashattn-hopper&subdirectory=hopper" +(2) python_path=`python -c "import site; print(site.getsitepackages()[0])"` +(3) mkdir -p $python_path/flashattn_hopper +(4) wget -P $python_path/flashattn_hopper https://raw.githubusercontent.com/Dao-AILab/flash-attention/main/hopper/flash_attn_interface.py""" try: - _flash_attn_v3_version = PkgVersion(get_pkg_version("flashattn-hopper")) - _flash_attn_3_plus = _flash_attn_v3_version >= PkgVersion("2.6.1") + _flash_attn_3_version = PkgVersion(get_pkg_version("flashattn-hopper")) except PackageNotFoundError: - if get_device_compute_capability() == (9, 0) and _NVTE_FLASH_ATTN: - warnings.warn( - "To use flash-attn v3, please use the following commands to install: \n" - """(1) pip install "git+https://github.com/Dao-AILab/flash-attention.git#egg=flashattn-hopper&subdirectory=hopper" \n""" - """(2) python_path=`python -c "import site; print(site.getsitepackages()[0])"` \n""" - """(3) mkdir -p $python_path/flashattn_hopper \n""" - """(4) wget -P $python_path/flashattn_hopper https://raw.githubusercontent.com/Dao-AILab/flash-attention/main/hopper/flash_attn_interface.py""" + if torch.cuda.is_available() and get_device_compute_capability() >= (9, 0) and _NVTE_FLASH_ATTN: + fa_logger.debug( + "flash-attn v3 is not installed. To use, please install it by \n%s", + _flash_attn_3_installation_steps, ) else: from flashattn_hopper.flash_attn_interface import flash_attn_func as flash_attn_func_v3 from flashattn_hopper.flash_attn_interface import ( flash_attn_varlen_func as flash_attn_varlen_func_v3, ) - from flashattn_hopper.flash_attn_interface import ( # pylint: disable=unused-import - _flash_attn_forward as _flash_attn_forward_v3, + from flashattn_hopper.flash_attn_interface import ( + _flash_attn_varlen_forward as flash_attn_varlen_fwd_v3, ) - from flashattn_hopper.flash_attn_interface import ( # pylint: disable=unused-import - _flash_attn_backward as _flash_attn_backward_v3, + from flashattn_hopper.flash_attn_interface import ( + _flash_attn_varlen_backward as flash_attn_varlen_bwd_v3, ) + _flash_attn_3_is_installed = True + _flash_attn_3_0_0_beta = PkgVersion("3.0.0b") < _flash_attn_3_version < PkgVersion("3.0.0") _use_flash_attn_3 = True -if _flash_attn_version >= _flash_attn_version_required: - from flash_attn.flash_attn_interface import flash_attn_func, flash_attn_varlen_func - from flash_attn.flash_attn_interface import _flash_attn_varlen_forward as _flash_attn_forward - from flash_attn.flash_attn_interface import _flash_attn_varlen_backward as _flash_attn_backward - from flash_attn_2_cuda import varlen_bwd as flash_attn_cuda_bwd - - -# NVTE_DEBUG = 0/1 # disables/enables debug mode, default = 0 -_NVTE_DEBUG = int(os.getenv("NVTE_DEBUG", "0")) -# NVTE_DEBUG_LEVEL = 0/1/2 # enables more and more verbose debug mode, default = 0 -_NVTE_DEBUG_LEVEL = int(os.getenv("NVTE_DEBUG_LEVEL", "0")) -_log_level = _NVTE_DEBUG * _NVTE_DEBUG_LEVEL -_log_levels = {0: logging.WARNING, 1: logging.INFO, 2: logging.DEBUG} -_log_level = _log_levels[_log_level if _log_level in [0, 1, 2] else 2] -_formatter = logging.Formatter("[%(levelname)-8s | %(name)-19s]: %(message)s") -_stream_handler = logging.StreamHandler() -_stream_handler.setFormatter(_formatter) - _attention_backends = { "attention_params": None, "use_flash_attention": None, @@ -251,6 +313,11 @@ class AttentionParams: __all__ = ["DotProductAttention", "InferenceParams", "MultiheadAttention"] +def maybe_contiguous(tensor: torch.Tensor) -> torch.Tensor: + """Make tensor contiguous if final stride is not 1.""" + return tensor.contiguous() if tensor.stride(-1) != 1 else tensor + + def get_attention_backend( attention_params: AttentionParams = None, ): @@ -309,10 +376,13 @@ def get_attention_backend( run_config = { "transformer_engine_version": te.__version__, "compute_capability": "sm" - + str( - (lambda x, y: x * 10 + y)(device_compute_capability[0], device_compute_capability[1]) + + str(10 * device_compute_capability[0] + device_compute_capability[1]), + "flash_attn_version": ( + str(_flash_attn_version) if _flash_attn_is_installed else "not installed" + ), + "flash_attn_3_version": ( + str(_flash_attn_3_version) if _flash_attn_3_is_installed else "not installed" ), - "flash_attn_version": _flash_attn_version, "cudnn_version": ".".join([str(i) for i in cudnn_version]), } attention_params_dict = { @@ -323,15 +393,17 @@ def get_attention_backend( run_config["NVTE_FP8_DPA_BWD"] = int(os.getenv("NVTE_FP8_DPA_BWD", "1")) logger.debug("Running with config=%s", run_config) + # The following sections check if `FlashAttention` supports the provided attention params, + # regardless of whether FA2 or FA3 is installed. If FA2 or FA3 is not installed but is + # necessary for performance/functionality, a warning will be issued to prompt users to + # install an appropriate FA version. + global _flash_attn_version_required, _flash_attn_max_version, _use_flash_attn_3 + # Filter: Environment variables - global _NVTE_FLASH_ATTN, _NVTE_FUSED_ATTN, _NVTE_UNFUSED_ATTN - _NVTE_FLASH_ATTN = int(os.getenv("NVTE_FLASH_ATTN", "1")) - _NVTE_FUSED_ATTN = int(os.getenv("NVTE_FUSED_ATTN", "1")) - _NVTE_UNFUSED_ATTN = int(os.getenv("NVTE_UNFUSED_ATTN", "1")) - use_flash_attention = _NVTE_FLASH_ATTN - use_fused_attention = _NVTE_FUSED_ATTN - use_unfused_attention = _NVTE_UNFUSED_ATTN - if not use_flash_attention: + use_flash_attention = int(os.getenv("NVTE_FLASH_ATTN", "1")) + use_fused_attention = int(os.getenv("NVTE_FUSED_ATTN", "1")) + use_unfused_attention = int(os.getenv("NVTE_UNFUSED_ATTN", "1")) + if not use_flash_attention and _flash_attn_is_installed: logger.debug("Disabling FlashAttention due to NVTE_FLASH_ATTN=0") if not use_fused_attention: logger.debug("Disabling FusedAttention due to NVTE_FUSED_ATTN=0") @@ -340,7 +412,7 @@ def get_attention_backend( # Filter: ONNX mode if is_in_onnx_export_mode(): - if use_flash_attention: + if use_flash_attention and _flash_attn_is_installed: logger.debug("Disabling FlashAttention due to ONNX mode") use_flash_attention = False if use_fused_attention: @@ -348,32 +420,31 @@ def get_attention_backend( use_fused_attention = False # Filter: Compute capability - global _flash_attn_3_plus, _use_flash_attn_3 if device_compute_capability < (8, 0): - if use_flash_attention: + if use_flash_attention and _flash_attn_is_installed: logger.debug("Disabling FlashAttention as it requires compute capability sm80+") - use_flash_attention = False + use_flash_attention = False if use_fused_attention: logger.debug("Disabling FusedAttention as it requires compute capability sm80+") use_fused_attention = False if device_compute_capability < (9, 0): - if use_flash_attention and _flash_attn_3_plus: + if use_flash_attention and _flash_attn_3_is_installed: logger.debug("Disabling FlashAttention 3 as it requires compute capability sm90+") - _use_flash_attn_3 = False + _use_flash_attn_3 = False # Filter: Data type if qkv_dtype not in [torch.bfloat16, torch.float16] or qkv_type not in [ torch.Tensor, Float8Tensor, ]: - if use_flash_attention: + if use_flash_attention and _flash_attn_is_installed: logger.debug( "Disabling FlashAttention due to unsupported QKV data type. " "Supported: qkv_dtype = {torch.bfloat16, torch.float16}. " "Found: qkv_dtype = %s.", qkv_dtype, ) - use_flash_attention = False + use_flash_attention = False if use_fused_attention: logger.debug( "Disabling FusedAttention due to unsupported QKV data type. " @@ -386,7 +457,8 @@ def get_attention_backend( # Filter: Execution type if fp8 and fp8_meta["recipe"].fp8_dpa: if use_flash_attention and not _use_flash_attn_3: - logger.debug("Disabling FlashAttention as FlashAttention 2 does not support FP8") + if _flash_attn_is_installed: + logger.debug("Disabling FlashAttention as FlashAttention 2 does not support FP8") use_flash_attention = False if use_flash_attention and _use_flash_attn_3 and is_training: logger.debug( @@ -399,22 +471,24 @@ def get_attention_backend( # Filter: Head dimension if use_flash_attention and head_dim_qk != head_dim_v: - logger.debug("Disabling FlashAttention as it does not support MLA.") + if _flash_attn_is_installed: + logger.debug("Disabling FlashAttention as it does not support MLA.") use_flash_attention = False if use_flash_attention and ( head_dim_qk > 256 or head_dim_qk % 8 != 0 or (head_dim_qk > 192 and device_compute_capability not in ((8, 0), (9, 0))) ): - logger.debug( - "Disabling FlashAttention due to unsupported head_dim_qk and head_dim_v. " - "Supported: head_dim_qk = head_dim_v, head_dim_qk %%8 = 0, " - "head_dim_qk <= 256 (>192 requires sm80/90). " - "Found: head_dim_qk = %s, head_dim_v = %s, on sm%s.", - head_dim_qk, - head_dim_v, - ".".join([str(i) for i in device_compute_capability]), - ) + if _flash_attn_is_installed: + logger.debug( + "Disabling FlashAttention due to unsupported head_dim_qk and head_dim_v. " + "Supported: head_dim_qk = head_dim_v, head_dim_qk %%8 = 0, " + "head_dim_qk <= 256 (>192 requires sm80/90). " + "Found: head_dim_qk = %s, head_dim_v = %s, on sm%s.", + head_dim_qk, + head_dim_v, + ".".join([str(i) for i in device_compute_capability]), + ) use_flash_attention = False qkv_layout_group = qkv_layout.replace("b", "").replace("s", "").replace("t", "") if use_fused_attention and head_dim_qk != head_dim_v and qkv_layout_group != "hd_hd_hd": @@ -431,17 +505,17 @@ def get_attention_backend( logger.debug("Disabling UnfusedDotProductAttention for qkv_format = thd") use_unfused_attention = False if use_flash_attention and pad_between_seqs: - logger.debug( - "Disabling FlashAttention for qkv_format = thd when there is " - "padding between sequences, i.e. [a, a, PAD, b, b, b, PAD, c, PAD]" - ) + if _flash_attn_is_installed: + logger.debug( + "Disabling FlashAttention for qkv_format = thd when there is " + "padding between sequences, i.e. [a, a, PAD, b, b, b, PAD, c, PAD]" + ) use_flash_attention = False # Filter: Dropout - if attention_dropout != 0.0 and use_flash_attention: - if _flash_attn_3_plus and _use_flash_attn_3: - logger.debug("Disabling FlashAttention 3 for dropout") - _use_flash_attn_3 = False + if attention_dropout != 0.0 and use_flash_attention and _use_flash_attn_3: + logger.debug("Disabling FlashAttention 3 for dropout") + _use_flash_attn_3 = False # Filter: Context parallelism # qkv_format | attn_mask_type | attn_bias_type | supported backends @@ -461,38 +535,40 @@ def get_attention_backend( ) use_unfused_attention = False if context_parallel and use_flash_attention: - if _flash_attn_3_plus and _use_flash_attn_3: - logger.debug("Disabling FlashAttention 3 for context parallelism") - _use_flash_attn_3 = False if fp8 and fp8_meta["recipe"].fp8_dpa: - logger.debug( - "Disabling FlashAttention as it does not support context parallelism with FP8" - ) + if _flash_attn_is_installed: + logger.debug( + "Disabling FlashAttention as it does not support context parallelism with FP8" + ) use_flash_attention = False if "bottom_right" in attn_mask_type: - logger.debug( - "Disabling FlashAttention as it does not support context parallelism with" - " causal_bottom_right masking" - ) + if _flash_attn_is_installed: + logger.debug( + "Disabling FlashAttention as it does not support context parallelism with" + " causal_bottom_right masking" + ) use_flash_attention = False elif "causal" in attn_mask_type and max_seqlen_q != max_seqlen_kv: - logger.debug( - "Disabling FlashAttention as it does not support context parallelism with causal" - " masking for cross-attention" - ) + if _flash_attn_is_installed: + logger.debug( + "Disabling FlashAttention as it does not support context parallelism with" + " causal masking for cross-attention" + ) use_flash_attention = False elif core_attention_bias_type not in ["no_bias", "post_scale_bias"]: - logger.debug( - "Disabling FlashAttention as it does not support context parallelism with bias type" - " of %s", - core_attention_bias_type, - ) + if _flash_attn_is_installed: + logger.debug( + "Disabling FlashAttention as it does not support context parallelism with bias" + " type of %s", + core_attention_bias_type, + ) use_flash_attention = False elif qkv_format == "thd" and core_attention_bias_type != "no_bias": - logger.debug( - "Disabling FlashAttention as it does not support context parallelism with attention" - " bias for THD format" - ) + if _flash_attn_is_installed: + logger.debug( + "Disabling FlashAttention as it does not support context parallelism with" + " attention bias for THD format" + ) use_flash_attention = False if context_parallel and use_fused_attention: @@ -548,7 +624,7 @@ def get_attention_backend( # arbitrary | One tensor in shape broadcastable to | UnfusedDotProductAttention # | [b, h, sq, skv] | if attn_mask_type == "arbitrary": - if use_flash_attention: + if use_flash_attention and _flash_attn_is_installed: logger.debug("Disabling FlashAttention for arbitrary mask") use_flash_attention = False if use_fused_attention: @@ -556,7 +632,7 @@ def get_attention_backend( use_fused_attention = False if ( use_flash_attention - and _flash_attn_3_plus + and _use_flash_attn_3 and attn_mask_type in ["causal", "padding_causal"] and max_seqlen_q != max_seqlen_kv ): @@ -568,28 +644,41 @@ def get_attention_backend( _use_flash_attn_3 = False if ( use_flash_attention - and _flash_attn_2_1_plus and attn_mask_type in ["causal", "padding_causal"] and max_seqlen_q != max_seqlen_kv ): - logger.warning( - "Disabling FlashAttention as it only supports bottom-right-diagonal " - "causal mask since flash-attn 2.1. See " - "https://github.com/Dao-AILab/flash-attention#21-change-behavior-of-causal-flag" - ) - use_flash_attention = False + if _flash_attn_2_1_plus: + logger.warning( + "Disabling FlashAttention as it only supports bottom-right-diagonal " + "causal mask since flash-attn 2.1. See " + "https://github.com/Dao-AILab/flash-attention#21-change-behavior-of-causal-flag" + ) + use_flash_attention = False + if not _flash_attn_is_installed: + _flash_attn_max_version = PkgVersion("2.1") if ( use_flash_attention - and not _flash_attn_2_1_plus and attn_mask_type in ["causal_bottom_right", "padding_causal_bottom_right"] and max_seqlen_q != max_seqlen_kv ): - logger.warning( - "Disabling FlashAttention as it only supports top-left-diagonal " - "causal mask before flash-attn 2.1. See " - "https://github.com/Dao-AILab/flash-attention#21-change-behavior-of-causal-flag" - ) - use_flash_attention = False + if not _flash_attn_is_installed: + _flash_attn_version_required = PkgVersion("2.1") + elif not _flash_attn_2_1_plus and not _use_flash_attn_3: + logger.warning( + "Disabling FlashAttention as it only supports top-left-diagonal " + "causal mask before flash-attn 2.1. See " + "https://github.com/Dao-AILab/flash-attention#21-change-behavior-of-causal-flag" + ) + use_flash_attention = False + if ( + use_flash_attention + and _use_flash_attn_3 + and fp8 + and fp8_meta["recipe"].fp8_dpa + and "padding" in attn_mask_type + ): + logger.debug("Disabling FlashAttention 3 for FP8 and padding masks") + _use_flash_attn_3 = False # Filter: Sliding window attention # backend | window_size | diagonal alignment @@ -633,24 +722,19 @@ def get_attention_backend( attn_mask_type, ) use_fused_attention = False - if ( - use_flash_attention - and (window_size[0] != -1 or window_size[1] not in [-1, 0]) - and _flash_attn_3_plus - ): - logger.debug( - "Disabling FlashAttention 3 as it does not support sliding window attention" - ) - _use_flash_attn_3 = False - if ( - use_flash_attention - and (window_size[0] != -1 or window_size[1] not in [-1, 0]) - and not _flash_attn_2_3_plus - ): - logger.debug( - "Disabling FlashAttention as sliding window attention requires flash-attn 2.3+" - ) - use_flash_attention = False + if use_flash_attention and (window_size[0] != -1 or window_size[1] not in [-1, 0]): + if _use_flash_attn_3: + logger.debug( + "Disabling FlashAttention 3 as it does not support sliding window attention" + ) + _use_flash_attn_3 = False + if not _flash_attn_is_installed: + _flash_attn_version_required = PkgVersion("2.3") + elif not _flash_attn_2_3_plus: + logger.debug( + "Disabling FlashAttention as sliding window attention requires flash-attn 2.3+" + ) + use_flash_attention = False # Filter: Attention bias # backend | bias types | ALiBi diagonal alignment @@ -662,18 +746,21 @@ def get_attention_backend( # UnfusedDotProductAttention | no_bias, pre/post_scale_bias | # | alibi/alibi_slopes | both; converts to a 'post_scale_bias' bias if use_flash_attention and core_attention_bias_type == "alibi": - if _flash_attn_3_plus and _use_flash_attn_3: + if _use_flash_attn_3: logger.debug("Disabling FlashAttention 3 for ALiBi") _use_flash_attn_3 = False - if not _flash_attn_2_4_plus: - logger.debug("Disabling FlashAttention for ALiBi") - use_flash_attention = False + if not _flash_attn_is_installed: + _flash_attn_version_required = PkgVersion("2.4") + elif not _flash_attn_2_4_plus: + logger.debug("Disabling FlashAttention as ALiBi requires flash-attn 2.4+") + use_flash_attention = False if use_flash_attention and ( core_attention_bias_type not in ["no_bias", "alibi"] or core_attention_bias_shape is not None ): - logger.debug("Disabling FlashAttention for pre/post_scale_bias") + if _flash_attn_is_installed: + logger.debug("Disabling FlashAttention for pre/post_scale_bias") use_flash_attention = False fu_core_attention_bias_type = core_attention_bias_type @@ -777,13 +864,16 @@ def get_attention_backend( # | otherwise: no # sub-backend 2 | no # UnfusedDotProductAttention | yes - if use_flash_attention and deterministic and not _flash_attn_2_4_1_plus: - logger.warning( - "Disabling FlashAttention as version <2.4.1 does not support deterministic " - "execution. To use FlashAttention with deterministic behavior, " - "please install flash-attn >= 2.4.1." - ) - use_flash_attention = False + if use_flash_attention and deterministic: + if not _flash_attn_is_installed: + _flash_attn_version_required = PkgVersion("2.4.1") + elif not _flash_attn_2_4_1_plus and not _use_flash_attn_3: + logger.warning( + "Disabling FlashAttention as version <2.4.1 does not support deterministic " + "execution. To use FlashAttention with deterministic behavior, " + "please install flash-attn >= 2.4.1." + ) + use_flash_attention = False if use_fused_attention and deterministic: if fused_attention_backend == FusedAttnBackend["FP8"] and is_training: logger.debug("Disabling FusedAttention for determinism reasons") @@ -802,6 +892,23 @@ def get_attention_backend( # All available backends available_backends = [use_flash_attention, use_fused_attention, use_unfused_attention] + + # `FusedAttention` and `FlashAttention` are faster backends than `UnfusedDotProductAttention`. + # When `FusedAttention` does not support the provided attention params, and `FlashAttention` + # does, we recommend users to install flash-attn if not installed already. + if not use_fused_attention and use_flash_attention and not _flash_attn_is_installed: + logger.warning( + "flash-attn may provide important feature support or performance improvement." + " Please install flash-attn %s.", + _get_supported_versions( + _flash_attn_version_required, + _flash_attn_max_version, + ), + ) + if use_flash_attention and not _flash_attn_is_installed: + use_flash_attention = False + available_backends[0] = False + logger.debug( "Available backends = {FlashAttention=%s, FusedAttention=%s%s," " UnfusedDotProductAttention=%s}", @@ -827,10 +934,6 @@ def get_attention_backend( "for performance reasons" ) use_flash_attention = False - - # Select FusedAttention for FP8 - # FA3 uses default scaling factors (i.e. 1) in FP8 execution, while FusedAttention takes - # scaling factors from `fp8_meta` and offers more accurate quantization/de-quantization if ( use_flash_attention and use_fused_attention @@ -838,8 +941,8 @@ def get_attention_backend( and _use_flash_attn_3 ): logger.debug( - "Disabling FlashAttention 3 to give FusedAttention preference as FusedAttention " - "supports more accurate scaling factors in FP8 execution" + "Disabling FlashAttention 3 to give FusedAttention preference for performance reasons " + "in FP8 execution" ) use_flash_attention = False @@ -1044,8 +1147,11 @@ def get_alibi( assert _alibi_cache["_alibi_slopes"] is not None, "ALiBi slopes can not be None!" if _alibi_cache["_alibi_slopes"].dim() == 1: slopes_shape = torch.Size([1, _alibi_cache["_alibi_slopes"].shape[0], 1, 1]) - if _alibi_cache["_alibi_slopes"].dim() == 2: + elif _alibi_cache["_alibi_slopes"].dim() == 2: slopes_shape = torch.Size([*_alibi_cache["_alibi_slopes"].shape[:], 1, 1]) + else: + raise ValueError("ALiBi slopes cannot exceed 2 dimensions.") + bias = torch.arange(max_seqlen_q, dtype=torch.int32, device="cuda").view( 1, 1, max_seqlen_q, 1 ) - torch.arange(max_seqlen_kv, dtype=torch.int32, device="cuda").view( @@ -1281,6 +1387,7 @@ class PackTensors(torch.autograd.Function): def forward( ctx, indices: torch.Tensor, *tensors: Tuple[torch.Tensor, ...] ) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]: + # pylint: disable=missing-function-docstring assert 1 <= len(tensors) <= 3, f"Packing {len(tensors)} tensors not supported." ctx.save_for_backward(indices) ctx.dim0 = tensors[0].shape[0] @@ -1292,6 +1399,7 @@ def forward( @staticmethod def backward(ctx, *grad_outputs: Tuple[torch.Tensor, ...]): + # pylint: disable=missing-function-docstring (indices,) = ctx.saved_tensors if len(grad_outputs) == 1: return None, unpack_tensor(indices, ctx.dim0, *grad_outputs) @@ -1312,11 +1420,13 @@ def forward( dim0: int, tensor: torch.Tensor, ) -> torch.Tensor: + # pylint: disable=missing-function-docstring ctx.save_for_backward(indices) return unpack_tensor(indices, dim0, tensor) @staticmethod def backward(ctx, grad_output): + # pylint: disable=missing-function-docstring (indices,) = ctx.saved_tensors return None, None, pack_tensor(indices, grad_output) @@ -1364,16 +1474,28 @@ def flash_attn_p2p_communicate( @jit_fuser -def flash_attn_fwd_out_correction(out, out_per_step, seq_dim, softmax_lse, softmax_lse_per_step): +def flash_attn_fwd_out_correction( + out: torch.Tensor, + out_per_step: torch.Tensor, + softmax_lse: torch.Tensor, + softmax_lse_per_step: torch.Tensor, + movedim_src: int, + movedim_dst: int, +): """Merge partial outputs of each step in Attention with context parallelism""" - softmax_lse_corrected_exp = torch.exp(softmax_lse_per_step - softmax_lse).movedim(2, seq_dim) + softmax_lse_corrected_exp = torch.exp(softmax_lse_per_step - softmax_lse).movedim( + movedim_src, movedim_dst + ) softmax_lse_corrected_exp = softmax_lse_corrected_exp.unsqueeze(-1) out_corrected = out_per_step * softmax_lse_corrected_exp out.add_(out_corrected) @jit_fuser -def flash_attn_fwd_softmax_lse_correction(softmax_lse, softmax_lse_per_step): +def flash_attn_fwd_softmax_lse_correction( + softmax_lse: torch.Tensor, + softmax_lse_per_step: torch.Tensor, +): """Merge softmax stats of each step in Attention with context parallelism""" max_scale = torch.max(softmax_lse, softmax_lse_per_step) min_scale = torch.min(softmax_lse, softmax_lse_per_step) @@ -1383,7 +1505,12 @@ def flash_attn_fwd_softmax_lse_correction(softmax_lse, softmax_lse_per_step): @jit_fuser def get_cu_seqlens_on_cp_rank( - cu_seqlens, cu_seqlens_padded_on_cp_rank, cp_size, cp_rank, first_half, second_half + cu_seqlens: torch.Tensor, + cu_seqlens_padded_on_cp_rank: torch.Tensor, + cp_size: int, + cp_rank: int, + first_half: bool, + second_half: bool, ): """Compute cu_seqlens of a context parallelism rank""" seqlens = cu_seqlens[1:] - cu_seqlens[:-1] @@ -1402,11 +1529,128 @@ def get_cu_seqlens_on_cp_rank( return cu_seqlens_on_cp_rank +@torch.compile +def get_seq_chunk_ids_for_reordering(cp_size, device, to_contiguous): + """ + Context parallelism assigns two discontiguous sequence chunks to each GPU for load balancing. + To make sure tokens are ordered correctly for compute, we need to reorder sequence chunks + before or after CP communications (e.g., all-gather, all-to-all). This function is to compute + sequence chunk ids for reordering. + """ + chunk_ids = torch.empty(2 * cp_size, dtype=torch.int32, device=device) + if to_contiguous: + for rank in range(cp_size): + chunk_ids[rank] = 2 * rank + chunk_ids[rank + cp_size] = 2 * cp_size - 2 * rank - 1 + else: + for rank in range(cp_size): + chunk_ids[2 * rank] = rank + chunk_ids[2 * rank + 1] = 2 * cp_size - rank - 1 + return chunk_ids + + +@torch.compile +def reorder_seq_chunks_for_a2a(x, chunk_ids_for_a2a, seq_dim, cp_size, before_attn): + """Reorder sequence chunk for A2A communication.""" + if before_attn: + # [cp, b, s, np//cp, hn] -> [b, cp, s, np//cp, hn] + # or [cp, s, b, np//cp, hn] -> [cp, s, b, np//cp, hn] + x = x.movedim(0, seq_dim).contiguous() + # [b, cp, s, np//cp, hn] -> [b, cp*2, s//2, np//cp, hn] + # or [cp, s, b, np//cp, hn] -> [cp*2, s//2, b, np//cp, hn] + x = x.view(*x.shape[:seq_dim], cp_size * 2, -1, *x.shape[(seq_dim + 2) :]) + # reorder the sequence chunks + x = torch.index_select(x, dim=seq_dim, index=chunk_ids_for_a2a) + else: + # [b, cp*2, s//2, np//cp, hn] -> [cp*2, b, s//2, np//cp, hn] + # or [cp*2, s//2, b, np//cp, hn] -> [cp*2, s//2, b, np//cp, hn] + x = x.movedim(seq_dim, 0).contiguous() + # reorder the sequence chunks + x = torch.index_select(x, dim=0, index=chunk_ids_for_a2a) + # [cp*2, b, s//2, np//cp, hn] -> [cp, 2, b, s//2, np//cp, hn] + # or [cp*2, s//2, b, np//cp, hn] -> [cp, 2, s//2, b, np//cp, hn] + x = x.view(cp_size, 2, *x.shape[1:]) + return x + + +def flash_attn_a2a_communicate( + a2a_inputs: Union[torch.Tensor, List[torch.Tensor]], + chunk_ids_for_a2a: torch.Tensor, + seq_dim: int, + cp_size: int, + cp_group: dist_group_type, + cp_stream: torch.cuda.Stream, + before_attn: bool, +) -> Union[torch.Tensor, List[torch.Tensor]]: + """A2A communication for context parallelism.""" + a2a_inputs = [a2a_inputs] if not isinstance(a2a_inputs, list) else a2a_inputs + a2a_outputs, a2a_reqs = [None] * len(a2a_inputs), [None] * len(a2a_inputs) + if before_attn: + for i in range(len(a2a_inputs) + 2): + if 0 < i < len(a2a_inputs) + 1: + a2a_outputs[i - 1] = torch.empty_like(a2a_inputs[i - 1]) + a2a_reqs[i - 1] = torch.distributed.all_to_all_single( + a2a_outputs[i - 1], a2a_inputs[i - 1], group=cp_group, async_op=True + ) + if i > 1: + with torch.cuda.stream(cp_stream): + a2a_reqs[i - 2].wait() + x = a2a_outputs[i - 2] + # reorder the sequence chunks + x = reorder_seq_chunks_for_a2a( + x, chunk_ids_for_a2a, seq_dim, cp_size, before_attn + ) + # [b, cp*2, s//2, np//cp, hn] -> [b, cp*s, np//cp, hn] + # or [cp*2, s//2, b, np//cp, hn] -> [cp*s, b, np//cp, hn] + a2a_outputs[i - 2] = x.view(*x.shape[:seq_dim], -1, *x.shape[(seq_dim + 2) :]) + if i < len(a2a_inputs): + x = a2a_inputs[i] + # [b, s, np, hn] -> [b, s, cp, np//cp, hn] + # or [s, b, np, hn] -> [s, b, cp, np//cp, hn] + x = x.view(*x.shape[:-2], cp_size, x.shape[-2] // cp_size, x.shape[-1]) + # [b, s, cp, np//cp, hn] -> [cp, b, s, np//cp, hn] + # or [s, b, cp, np//cp, hn] -> [cp, s, b, np//cp, hn] + a2a_inputs[i] = x.movedim(-3, 0).contiguous() + else: + for i in range(len(a2a_inputs) + 2): + if 0 < i < len(a2a_inputs) + 1: + a2a_outputs[i - 1] = torch.empty_like(a2a_inputs[i - 1]) + a2a_reqs[i - 1] = torch.distributed.all_to_all_single( + a2a_outputs[i - 1], a2a_inputs[i - 1], group=cp_group, async_op=True + ) + if i < len(a2a_inputs): + x = a2a_inputs[i] + # [b, cp*s, np//cp, hn] -> [b, cp*2, s//2, np//cp, hn] + # or [cp*s, b, np//cp, hn] -> [cp*2, s//2, b, np//cp, hn] + x = x.view(*x.shape[:seq_dim], cp_size * 2, -1, *x.shape[(seq_dim + 1) :]) + # reorder the sequence chunks + a2a_inputs[i] = reorder_seq_chunks_for_a2a( + x, chunk_ids_for_a2a, seq_dim, cp_size, before_attn + ) + if i > 1: + with torch.cuda.stream(cp_stream): + a2a_reqs[i - 2].wait() + x = a2a_outputs[i - 2] + # [cp, 2, b, s//2, np//cp, hn] -> [b, 2, s//2, cp, np//cp, hn] + # or [cp, 2, s//2, b, np//cp, hn] -> [2, s//2, b, cp, np//cp, hn] + x = x.movedim(0, -3).movedim(0, seq_dim).contiguous() + # [b, 2, s//2, cp, np//cp, hn] -> [b*s, np, hn] + # or [2, s//2, b, cp, np//cp, hn] -> [s*b, np, hn] + a2a_outputs[i - 2] = x.view(-1, x.shape[-3] * x.shape[-2], x.shape[-1]) + torch.cuda.current_stream().wait_stream(cp_stream) + return a2a_outputs[0] if len(a2a_inputs) == 1 else a2a_outputs + + class AttnFuncWithCPAndKVP2P(torch.autograd.Function): """ Attention implementation with context parallelism. Exchange KV between CP ranks with P2P in ring topology. Split attention compute into multiple steps, and overlap current-step compute with next-step communication. + + This implementation also supports hierarchical CP, which parallelizes attention + heads in low-level CP groups and parallelizes sequence dimension in high-level CP + groups. For more details, please refer to `LongVILA `_ + and `USP `_. """ @staticmethod @@ -1436,18 +1680,37 @@ def forward( cp_global_ranks, cp_stream, ): + # pylint: disable=missing-function-docstring if softmax_scale is None: softmax_scale = q.shape[-1] ** (-0.5) + if isinstance(cp_group, list): + assert ( + qkv_format != "thd" + ), f"{qkv_format} format is not supported with hierarchical CP implementation yet!" + assert attn_bias_type == "no_bias", ( + f"{attn_bias_type} bias type is not supported with hierarchical CP implementation" + " yet!" + ) + cp_group_a2a = cp_group[0] + cp_size_a2a = get_distributed_world_size(cp_group_a2a) + rank_a2a = get_distributed_rank(cp_group_a2a) + cp_group = cp_group[1] + else: + cp_group_a2a = None + cp_size_a2a = 1 + rank_a2a = 0 + cp_size = get_distributed_world_size(cp_group) rank = get_distributed_rank(cp_group) - send_dst = cp_global_ranks[(rank + 1) % cp_size] - recv_src = cp_global_ranks[(rank - 1) % cp_size] + send_dst = cp_global_ranks[(rank + 1) % cp_size * cp_size_a2a + rank_a2a] + recv_src = cp_global_ranks[(rank - 1) % cp_size * cp_size_a2a + rank_a2a] batch_p2p_comm = int(os.getenv("NVTE_BATCH_MHA_P2P_COMM", "0")) or (cp_size == 2) causal = "causal" in attn_mask_type padding = "padding" in attn_mask_type + seq_dim = None if qkv_format in ["bshd", "sbhd"]: seq_dim = qkv_format.index("s") qkv_layout = qkv_format + "_" + qkv_format[:-2] + "2" + qkv_format[-2:] @@ -1463,6 +1726,62 @@ def forward( cu_seqlens_q_per_step = [None for _ in range(cp_size)] cu_seqlens_kv_per_step = [None for _ in range(cp_size)] + fused_attn_qkv_dtype = None + fused_attn_backend = None + amax_per_step = None + if fp8: + if use_fused_attention: + fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) + fused_attn_qkv_dtype = fp8_dtype_forward + fused_attn_backend = FusedAttnBackend["FP8"] + if fp8_meta["recipe"].fp8_mha: + assert ( + isinstance(q, Float8Tensor) + and isinstance(k, Float8Tensor) + and isinstance(v, Float8Tensor) + ), "q/k/v must be Float8Tensors for FP8 MHA!" + fp8_meta["scaling_fwd"].scale_inv[META_QKV] = q._scale_inv + q_fp8, k_fp8, v_fp8 = q, k, v + q, k, v = q_fp8._data, k_fp8._data, v_fp8._data + else: + q_f16, k_f16, v_f16 = q, k, v + if cp_size_a2a == 1 or int(os.getenv("NVTE_FP8_DPA_BWD", "1")): + q = cast_to_fp8(q_f16, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward) + if int(os.getenv("NVTE_FP8_DPA_BWD", "1")): + k, v = [ + cast_to_fp8(x, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward) + for x in [k_f16, v_f16] + ] + fp8_meta_kwargs = {} + fp8_meta_kwargs["d_scale_qkv"] = fp8_meta["scaling_fwd"].scale_inv + fp8_meta_kwargs["d_scale_qkv_offset"] = META_QKV + fp8_meta_kwargs["d_scale_s"] = fp8_meta["scaling_fwd"].scale_inv + fp8_meta_kwargs["d_scale_s_offset"] = META_S + fp8_meta_kwargs["q_scale_s"] = fp8_meta["scaling_fwd"].scale + fp8_meta_kwargs["q_scale_s_offset"] = META_S + fp8_meta_kwargs["q_scale_o"] = fp8_meta["scaling_fwd"].scale + fp8_meta_kwargs["q_scale_o_offset"] = META_O_CP + amax_per_step = torch.zeros((2, cp_size), dtype=torch.float32, device=q.device) + else: + assert False, "FP8 is only supported with Fused Attention!" + else: + q_f16 = q + if use_fused_attention: + fp8_meta_kwargs = {} + fused_attn_qkv_dtype = TE_DType[q.dtype] + fused_attn_backend = FusedAttnBackend["F16_arbitrary_seqlen"] + + if cp_size_a2a > 1: + chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering(cp_size_a2a, q.device, True) + q, k, v = flash_attn_a2a_communicate( + [q, k, v], chunk_ids_for_a2a, seq_dim, cp_size_a2a, cp_group_a2a, cp_stream, True + ) + if not fp8: + q_f16 = q + elif not fp8_meta["recipe"].fp8_mha and not int(os.getenv("NVTE_FP8_DPA_BWD", "1")): + q_f16 = q + q = cast_to_fp8(q_f16, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward) + assert qkv_format == "thd" or ( q.shape[seq_dim] % 2 == 0 and k.shape[seq_dim] % 2 == 0 ), "Sequence length per GPU needs to be divisible by 2!" @@ -1497,69 +1816,41 @@ def forward( *attn_bias.shape[:-1], 2 * cp_size, attn_bias.shape[-1] // (2 * cp_size) ) assert q.shape[-1] % 8 == 0, "hidden size per attention head should be multiple of 8" - fa_optional_forward_kwargs = {} - if _flash_attn_2_3_plus: - fa_optional_forward_kwargs["window_size"] = (-1, 0) if causal else (-1, -1) - if _flash_attn_2_4_plus: - fa_optional_forward_kwargs["alibi_slopes"] = None - if _flash_attn_2_5_7_plus: - fa_optional_forward_kwargs["block_table"] = None + + softmax_lse_in_packed_format = not use_fused_attention and ( + _flash_attn_2_6_0_plus or _use_flash_attn_3 + ) + flash_attn_fwd = None + if not use_fused_attention: + fa_forward_kwargs = {"softmax_scale": softmax_scale} + if _use_flash_attn_3: + flash_attn_fwd = flash_attn_varlen_fwd_v3 + fa_forward_kwargs["window_size"] = (-1, 0) if causal else (-1, -1) + else: + flash_attn_fwd = flash_attn_varlen_fwd + fa_forward_kwargs["dropout_p"] = dropout_p + fa_forward_kwargs["return_softmax"] = False + if _flash_attn_2_3_plus: + fa_forward_kwargs["window_size"] = (-1, 0) if causal else (-1, -1) + if _flash_attn_2_4_plus: + fa_forward_kwargs["alibi_slopes"] = None + if _flash_attn_2_5_7_plus: + fa_forward_kwargs["block_table"] = None # Flash Attn inputs q_inputs = [None, None] kv_inputs = [None, None] attn_bias_inputs = [None, None] # Flash Attn outputs - out_per_step = [None for _ in range(cp_size)] - softmax_lse_per_step = [None for _ in range(cp_size)] - rng_states = [None for _ in range(cp_size)] - attn_biases = [None for _ in range(cp_size)] - - # create two streams to resolve wave quantization issue of Flash Attn in each step - flash_attn_streams = [torch.cuda.current_stream(), cp_stream] - # synchronize fwd results correction across steps - fwd_results_correction_done = torch.cuda.Event() - - if fp8: - if use_fused_attention: - fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) - fused_attn_qkv_dtype = fp8_dtype_forward - fused_attn_backend = FusedAttnBackend["FP8"] - if fp8_meta["recipe"].fp8_mha: - assert ( - isinstance(q, Float8Tensor) - and isinstance(k, Float8Tensor) - and isinstance(v, Float8Tensor) - ), "q/k/v must be Float8Tensors for FP8 MHA!" - fp8_meta["scaling_fwd"].scale_inv[META_QKV] = q._scale_inv - q_fp8, k_fp8, v_fp8 = q, k, v - q, k, v = q_fp8._data, k_fp8._data, v_fp8._data - else: - q_f16, k_f16, v_f16 = q, k, v - q = cast_to_fp8(q_f16, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward) - if int(os.getenv("NVTE_FP8_DPA_BWD", "1")): - k, v = [ - cast_to_fp8(x, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward) - for x in [k_f16, v_f16] - ] - fp8_meta_kwargs = {} - fp8_meta_kwargs["d_scale_qkv"] = fp8_meta["scaling_fwd"].scale_inv - fp8_meta_kwargs["d_scale_qkv_offset"] = META_QKV - fp8_meta_kwargs["d_scale_s"] = fp8_meta["scaling_fwd"].scale_inv - fp8_meta_kwargs["d_scale_s_offset"] = META_S - fp8_meta_kwargs["q_scale_s"] = fp8_meta["scaling_fwd"].scale - fp8_meta_kwargs["q_scale_s_offset"] = META_S - fp8_meta_kwargs["q_scale_o"] = fp8_meta["scaling_fwd"].scale - fp8_meta_kwargs["q_scale_o_offset"] = META_O_CP - amax_per_step = torch.zeros((2, cp_size), dtype=torch.float32, device=q.device) - else: - assert False, "FP8 is only supported with Fused Attention!" - else: - q_f16 = q - if use_fused_attention: - fp8_meta_kwargs = {} - fused_attn_qkv_dtype = TE_DType[q.dtype] - fused_attn_backend = FusedAttnBackend["F16_arbitrary_seqlen"] + out_per_step = [None for _ in range(cp_size)] + softmax_lse_per_step = [None for _ in range(cp_size)] + rng_states = [None for _ in range(cp_size)] + attn_biases = [None for _ in range(cp_size)] + + # create two streams to resolve wave quantization issue of Flash Attn in each step + flash_attn_streams = [torch.cuda.current_stream(), cp_stream] + # synchronize fwd results correction across steps + fwd_results_correction_done = torch.cuda.Event() p2p_comm_buffers = [None for _ in range(cp_size)] if use_fused_attention and qkv_format in ["bshd", "sbhd"]: @@ -1568,6 +1859,8 @@ def forward( p2p_comm_buffers[0] = torch.cat((k.unsqueeze(0), v.unsqueeze(0)), dim=0) send_recv_reqs = [[], []] + softmax_lse_ = None + out = None for i in range(cp_size + 1): if i < cp_size: with torch.cuda.stream(flash_attn_streams[i % 2]): @@ -1685,16 +1978,7 @@ def forward( q_inputs[i % 2] = q.view(-1, *q.shape[-2:]) # [2, b, 2, sk//2, np, hn] -> [2, b*sk, np, hn] kv_inputs[i % 2] = kv_inputs[i % 2].view(2, -1, *k.shape[-2:]) - ( - _, - _, - _, - _, - out_per_step[i], - softmax_lse_per_step[i], - _, - rng_states[i], - ) = _flash_attn_forward( + fa_outputs = flash_attn_fwd( q_inputs[i % 2], kv_inputs[i % 2][0], kv_inputs[i % 2][1], @@ -1702,12 +1986,13 @@ def forward( cu_seqlens_kv_per_step[i], max_seqlen_q, max_seqlen_kv, - dropout_p, - softmax_scale, causal=True, - return_softmax=False, - **fa_optional_forward_kwargs, + **fa_forward_kwargs, ) + out_per_step[i] = fa_outputs[4] + softmax_lse_per_step[i] = fa_outputs[5] + if not _use_flash_attn_3: + rng_states[i] = fa_outputs[7] elif i <= rank: if pad_between_seqs_q: cu_seqlens_q_per_step[i] = get_cu_seqlens_on_cp_rank( @@ -1797,18 +2082,9 @@ def forward( kv_inputs[i % 2] = kv_inputs[i % 2][:, :, 0, ...].contiguous() # [2, b, sk//2, np, hn] -> [2, b*sk//2, np, hn] kv_inputs[i % 2] = kv_inputs[i % 2].view(2, -1, *k.shape[-2:]) - if _flash_attn_2_3_plus: - fa_optional_forward_kwargs["window_size"] = (-1, -1) - ( - _, - _, - _, - _, - out_per_step[i], - softmax_lse_per_step[i], - _, - rng_states[i], - ) = _flash_attn_forward( + if _use_flash_attn_3 or _flash_attn_2_3_plus: + fa_forward_kwargs["window_size"] = (-1, -1) + fa_outputs = flash_attn_fwd( q_inputs[i % 2], kv_inputs[i % 2][0], kv_inputs[i % 2][1], @@ -1816,12 +2092,13 @@ def forward( cu_seqlens_kv_per_step[i], max_seqlen_q, max_seqlen_kv // 2, - dropout_p, - softmax_scale, causal=False, - return_softmax=False, - **fa_optional_forward_kwargs, + **fa_forward_kwargs, ) + out_per_step[i] = fa_outputs[4] + softmax_lse_per_step[i] = fa_outputs[5] + if not _use_flash_attn_3: + rng_states[i] = fa_outputs[7] else: if pad_between_seqs_q: cu_seqlens_q_per_step[i] = get_cu_seqlens_on_cp_rank( @@ -1920,18 +2197,9 @@ def forward( ) # [2, b, 2, sk//2, np, hn] -> [2, b*sk, np, hn] kv_inputs[i % 2] = kv_inputs[i % 2].view(2, -1, *k.shape[-2:]) - if _flash_attn_2_3_plus: - fa_optional_forward_kwargs["window_size"] = (-1, -1) - ( - _, - _, - _, - _, - out_per_step[i], - softmax_lse_per_step[i], - _, - rng_states[i], - ) = _flash_attn_forward( + if _use_flash_attn_3 or _flash_attn_2_3_plus: + fa_forward_kwargs["window_size"] = (-1, -1) + fa_outputs = flash_attn_fwd( q_inputs[i % 2], kv_inputs[i % 2][0], kv_inputs[i % 2][1], @@ -1939,12 +2207,13 @@ def forward( cu_seqlens_kv_per_step[i], max_seqlen_q // 2, max_seqlen_kv, - dropout_p, - softmax_scale, causal=False, - return_softmax=False, - **fa_optional_forward_kwargs, + **fa_forward_kwargs, ) + out_per_step[i] = fa_outputs[4] + softmax_lse_per_step[i] = fa_outputs[5] + if not _use_flash_attn_3: + rng_states[i] = fa_outputs[7] else: if pad_between_seqs_q: cu_seqlens_q_per_step[i] = get_cu_seqlens_on_cp_rank( @@ -2012,16 +2281,7 @@ def forward( q_inputs[i % 2] = q.view(-1, *q.shape[-2:]) # [2, b, sk, np, hn] -> [2, b*sk, np, hn] kv_inputs[i % 2] = kv_inputs[i % 2].view(2, -1, *k.shape[-2:]) - ( - _, - _, - _, - _, - out_per_step[i], - softmax_lse_per_step[i], - _, - rng_states[i], - ) = _flash_attn_forward( + fa_outputs = flash_attn_fwd( q_inputs[i % 2], kv_inputs[i % 2][0], kv_inputs[i % 2][1], @@ -2029,12 +2289,13 @@ def forward( cu_seqlens_kv_per_step[i], max_seqlen_q, max_seqlen_kv, - dropout_p, - softmax_scale, causal=False, - return_softmax=False, - **fa_optional_forward_kwargs, + **fa_forward_kwargs, ) + out_per_step[i] = fa_outputs[4] + softmax_lse_per_step[i] = fa_outputs[5] + if not _use_flash_attn_3: + rng_states[i] = fa_outputs[7] if i > 0: # wait until fwd restuls correction of last step is done @@ -2044,6 +2305,11 @@ def forward( if use_fused_attention: # [b, np, sq, 1] -> [b, np, sq] softmax_lse_per_step[i - 1].squeeze_(-1) + if qkv_format != "thd" and softmax_lse_in_packed_format: + # [np, t] -> [np, b, sq] + softmax_lse_per_step[i - 1] = softmax_lse_per_step[i - 1].view( + q.shape[-2], q.shape[0], -1 + ) with torch.cuda.stream(flash_attn_streams[(i - 1) % 2]): if fp8: @@ -2058,7 +2324,8 @@ def forward( out = torch.zeros_like(q if not fp8 else out_per_step[0]).view(q.shape) softmax_lse = torch.clone(softmax_lse_per_step[0]).to(torch.double) if causal and qkv_format != "thd": - # [b, np, sq] -> [b, np, 2, sq//2] + # [b, np, sq] -> [b, np, 2, sq//2] lse not in packed format + # [np, b, sq] -> [np, b, 2, sq//2] lse in packed format softmax_lse_ = softmax_lse.view( *softmax_lse.shape[:-1], 2, softmax_lse.shape[-1] // 2 ) @@ -2072,7 +2339,7 @@ def forward( softmax_lse, softmax_lse_per_step[i - 1], cu_seqlens_q_padded, - max_seqlen_q, + softmax_lse_in_packed_format, ) else: flash_attn_fwd_softmax_lse_correction( @@ -2086,8 +2353,11 @@ def forward( softmax_lse = softmax_lse.to(torch.float) for i in range(cp_size): + out_ = None if qkv_format == "bshd": - out_per_step[i] = out_per_step[i].view(out.shape[0], -1, *out.shape[-2:]) + out_per_step[i] = out_per_step[i].view( + out.shape[0], -1, *out.shape[-2:] + ) # pylint: disable=used-before-assignment out_ = out[:, 1, ...] elif qkv_format == "sbhd": out_per_step[i] = out_per_step[i].view(-1, *out.shape[-3:]) @@ -2098,9 +2368,10 @@ def forward( flash_attn_fwd_out_correction( out.view(*out_per_step[i].shape), out_per_step[i], - seq_dim, softmax_lse, softmax_lse_per_step[i], + 0 if softmax_lse_in_packed_format else 2, + 2 if softmax_lse_in_packed_format else seq_dim, ) elif qkv_format == "thd": tex.thd_out_correction( @@ -2110,15 +2381,17 @@ def forward( softmax_lse_per_step[i], cu_seqlens_q_padded, False, + softmax_lse_in_packed_format, ) else: if qkv_format in ["bshd", "sbhd"]: flash_attn_fwd_out_correction( out_, out_per_step[i], - seq_dim, softmax_lse_[..., 1, :], softmax_lse_per_step[i], + 0 if softmax_lse_in_packed_format else 2, + 2 if softmax_lse_in_packed_format else seq_dim, ) elif qkv_format == "thd": tex.thd_out_correction( @@ -2128,15 +2401,33 @@ def forward( softmax_lse_per_step[i], cu_seqlens_q_padded, True, + softmax_lse_in_packed_format, ) + if qkv_format != "thd" and softmax_lse_in_packed_format: + # [np, b, sq] -> [np, t] + softmax_lse = softmax_lse.view(softmax_lse.shape[0], -1) kv = p2p_comm_buffers[-1] - if use_fused_attention: - if qkv_format == "bshd": - out = out.view(out.shape[0], -1, *out.shape[-2:]) - elif qkv_format == "sbhd": - out = out.view(-1, *out.shape[-3:]) - else: + if qkv_format == "bshd": + out = out.view(out.shape[0], -1, *out.shape[-2:]) + ctx.batch_size = out.shape[0] + elif qkv_format == "sbhd": + out = out.view(-1, *out.shape[-3:]) + ctx.batch_size = out.shape[1] + + if cp_size_a2a > 1: + chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering(cp_size_a2a, out.device, False) + out = flash_attn_a2a_communicate( + out, chunk_ids_for_a2a, seq_dim, cp_size_a2a, cp_group_a2a, cp_stream, False + ) + if use_fused_attention: + if qkv_format == "bshd": + # [b*s, np, hn] -> [b, s, np, hn] + out = out.view(ctx.batch_size, -1, *out.shape[-2:]) + elif qkv_format == "sbhd": + # [s*b, np, hn] -> [s, b, np, hn] + out = out.view(-1, ctx.batch_size, *out.shape[-2:]) + elif not use_fused_attention: out = out.view(-1, *out.shape[-2:]) if fp8 and use_fused_attention: @@ -2144,6 +2435,7 @@ def forward( fp8_meta["scaling_fwd"].amax_history[0][META_S] = amax_cp_fwd[0] fp8_meta["scaling_fwd"].amax_history[0][META_O_CP] = amax_cp_fwd[1] + out_fp8 = None out_f16 = out.to(q_fp8.dtype if fp8 and fp8_meta["recipe"].fp8_mha else q_f16.dtype) if fp8 and (fp8_meta["recipe"].fp8_mha or int(os.getenv("NVTE_FP8_DPA_BWD", "1"))): out_fp8 = cast_to_fp8(out_f16, fp8_meta["scaling_fwd"], META_O, fp8_dtype_forward) @@ -2165,6 +2457,14 @@ def forward( fp8_fwd_scales = fp8_meta["scaling_fwd"].scale.clone() fp8_fwd_scale_invs = fp8_meta["scaling_fwd"].scale_inv.clone() elif fp8 and fp8_meta["recipe"].fp8_mha: + q_fp8 = Float8Tensor( + data=q, + fp8_meta=fp8_meta, + fp8_meta_forward=True, + fp8_meta_index=META_QKV, + fp8_dtype=fp8_dtype_forward, + dtype=q_fp8.dtype, + ) kv_fp8 = Float8Tensor( data=kv, fp8_meta=fp8_meta, @@ -2176,6 +2476,7 @@ def forward( q_save, kv_save, out_save = q_fp8, kv_fp8, out_f16 fp8_fwd_scales, fp8_fwd_scale_invs = None, None else: + q_f16 = q_f16.view(q.shape) q_save, kv_save, out_save = q_f16, kv, out_f16 fp8_fwd_scales, fp8_fwd_scale_invs = None, None @@ -2193,8 +2494,12 @@ def forward( *rng_states, *attn_biases, ) + ctx.cp_group_a2a = cp_group_a2a + ctx.cp_size_a2a = cp_size_a2a + ctx.rank_a2a = rank_a2a ctx.cp_group = cp_group ctx.cp_global_ranks = cp_global_ranks + ctx.cp_stream = cp_stream ctx.dropout_p = dropout_p ctx.total_tokens_kv = total_tokens_kv ctx.max_seqlen_q = max_seqlen_q @@ -2212,10 +2517,14 @@ def forward( @staticmethod def backward(ctx, dout): + # pylint: disable=missing-function-docstring + cp_size_a2a = ctx.cp_size_a2a + rank_a2a = ctx.rank_a2a + cp_size = get_distributed_world_size(ctx.cp_group) rank = get_distributed_rank(ctx.cp_group) - send_dst = ctx.cp_global_ranks[(rank - 1) % cp_size] - recv_src = ctx.cp_global_ranks[(rank + 1) % cp_size] + send_dst = ctx.cp_global_ranks[(rank - 1) % cp_size * cp_size_a2a + rank_a2a] + recv_src = ctx.cp_global_ranks[(rank + 1) % cp_size * cp_size_a2a + rank_a2a] batch_p2p_comm = int(os.getenv("NVTE_BATCH_MHA_P2P_COMM", "0")) or (cp_size == 2) (q, kv, out, softmax_lse, cu_seqlens_q_padded, cu_seqlens_kv_padded) = ctx.saved_tensors[:6] @@ -2227,7 +2536,10 @@ def backward(ctx, dout): causal = "causal" in ctx.attn_mask_type padding = "padding" in ctx.attn_mask_type + + seq_dim = None if ctx.qkv_format in ["bshd", "sbhd"]: + seq_dim = ctx.qkv_format.index("s") qkv_layout = ctx.qkv_format + "_" + ctx.qkv_format[:-2] + "2" + ctx.qkv_format[-2:] else: qkv_layout = ctx.qkv_format + "_" + ctx.qkv_format + "_" + ctx.qkv_format @@ -2243,11 +2555,16 @@ def backward(ctx, dout): ) else: attn_dbias = None + attn_dbias_ = None + + softmax_lse_in_packed_format = not ctx.use_fused_attention and ( + _flash_attn_2_6_0_plus or _use_flash_attn_3 + ) if causal: - if ctx.qkv_format == "thd": + if ctx.qkv_format == "thd" or softmax_lse_in_packed_format: softmax_lse_ = tex.thd_read_second_half_lse( - softmax_lse, cu_seqlens_q_padded, ctx.max_seqlen_q + softmax_lse, cu_seqlens_q_padded, softmax_lse_in_packed_format ) else: # [b, np, sq] -> [b, np, 2, sq//2] @@ -2262,6 +2579,12 @@ def backward(ctx, dout): # [b, np, sq] -> [b, np, sq, 1] softmax_lse.unsqueeze_(-1) + dout_dtype = dout.dtype + fused_attn_backend = None + fused_attn_qkv_dtype = None + fused_attn_dqkv_dtype = None + amax_per_step = None + dout_fp8_dtype = None if ctx.fp8: if ctx.use_fused_attention: fp8_dtype_forward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=True) @@ -2272,7 +2595,6 @@ def backward(ctx, dout): dq_fp8 = torch.empty((cp_size, *q.shape), dtype=q.dtype, device=q.device) dkv_fp8 = torch.empty((cp_size, *kv.shape), dtype=kv.dtype, device=kv.device) dkv_fp8_ = torch.empty_like(dkv_fp8) - dout_dtype = dout.dtype if ctx.fp8_meta["recipe"].fp8_mha: assert isinstance(dout, Float8Tensor), "dout must be Float8Tensors for FP8 MHA!" ctx.fp8_meta["scaling_bwd"].scale_inv[META_DO] = dout._scale_inv @@ -2296,7 +2618,13 @@ def backward(ctx, dout): assert False, "FP8 is only supported with Fused Attention!" else: if ctx.fp8_meta is not None and ctx.fp8_meta["recipe"].fp8_mha: - q, kv, dout = [x.from_float8(x.dtype) for x in [q, kv, dout]] + q, kv = [x.from_float8(x.dtype) for x in [q, kv]] + if cp_size_a2a == 1: + dout = dout.from_float8(dout_dtype) + else: + dout_fp8_dtype = dout._fp8_dtype + dout_scale_inv = dout._scale_inv + dout = dout._data dq = torch.empty_like(q) if ctx.qkv_format == "thd" and causal: dq[cu_seqlens_q_padded[-1] :].fill_(0) @@ -2308,18 +2636,50 @@ def backward(ctx, dout): if ctx.use_fused_attention: fp8_meta_kwargs = {} fused_attn_qkv_dtype = TE_DType[q.dtype] - fused_attn_dqkv_dtype = TE_DType[dout.dtype] + fused_attn_dqkv_dtype = TE_DType[dout_dtype] fused_attn_backend = FusedAttnBackend["F16_arbitrary_seqlen"] + if cp_size_a2a > 1: + if not ctx.use_fused_attention: + out = out.view(ctx.batch_size, -1, *out.shape[-2:]) + dout = dout.view(*out.shape) + chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering(cp_size_a2a, out.device, True) + out, dout = flash_attn_a2a_communicate( + [out, dout], + chunk_ids_for_a2a, + seq_dim, + cp_size_a2a, + ctx.cp_group_a2a, + ctx.cp_stream, + True, + ) + if not ctx.fp8 and ctx.fp8_meta is not None and ctx.fp8_meta["recipe"].fp8_mha: + dout = cast_from_fp8( + dout, + None, + None, + dout_fp8_dtype, + TE_DType[dout_dtype], + scale_inv=dout_scale_inv, # pylint: disable=used-before-assignment + ) + out = out.view(*q.shape) dout = dout.view(*q.shape) send_recv_reqs = [] - fa_optional_backward_kwargs = {} - if _flash_attn_2_4_plus: - fa_optional_backward_kwargs["alibi_slopes"] = None - if _flash_attn_2_4_1_plus: - fa_optional_backward_kwargs["deterministic"] = ctx.deterministic + flash_attn_bwd = None + if not ctx.use_fused_attention: + fa_backward_kwargs = {"softmax_scale": ctx.softmax_scale} + if _use_flash_attn_3: + flash_attn_bwd = flash_attn_varlen_bwd_v3 + fa_backward_kwargs["deterministic"] = ctx.deterministic + else: + flash_attn_bwd = flash_attn_varlen_bwd + fa_backward_kwargs["dropout_p"] = ctx.dropout_p + if _flash_attn_2_4_plus: + fa_backward_kwargs["alibi_slopes"] = None + if _flash_attn_2_4_1_plus: + fa_backward_kwargs["deterministic"] = ctx.deterministic for i in range(cp_size): # wait until KV is received @@ -2359,6 +2719,7 @@ def backward(ctx, dout): ) kv = p2p_comm_buffers[i % 2][0] + dk_, dv_ = None, None if ctx.fp8 and ctx.use_fused_attention: fp8_meta_kwargs["amax_dp"] = amax_per_step[0][i] fp8_meta_kwargs["amax_dqkv"] = amax_per_step[0][i] @@ -2428,9 +2789,11 @@ def backward(ctx, dout): # [b, 2, sq//2, np, hn] -> [b*sq, np, hn] out_ = out.view(-1, *out.shape[-2:]) dout_ = dout.view(-1, *dout.shape[-2:]) - if _flash_attn_2_3_plus: - fa_optional_backward_kwargs["window_size"] = (-1, 0) - _flash_attn_backward( + if _use_flash_attn_3 or _flash_attn_2_3_plus: + fa_backward_kwargs["window_size"] = (-1, 0) + if not _use_flash_attn_3: + fa_backward_kwargs["rng_state"] = rng_states[cp_size - i - 1] + flash_attn_bwd( dout_, q_, kv_[0], @@ -2444,11 +2807,8 @@ def backward(ctx, dout): cu_seqlens_kv_per_step[cp_size - i - 1], ctx.max_seqlen_q, ctx.max_seqlen_kv, - ctx.dropout_p, - ctx.softmax_scale, - True, - rng_state=rng_states[cp_size - i - 1], - **fa_optional_backward_kwargs, + causal=True, + **fa_backward_kwargs, ) elif i >= (cp_size - rank - 1): if ctx.use_fused_attention: @@ -2522,9 +2882,11 @@ def backward(ctx, dout): # [b, 2, sq//2, np, hn] -> [b*sq, np, hn] out_ = out.view(-1, *out.shape[-2:]) dout_ = dout.view(-1, *dout.shape[-2:]) - if _flash_attn_2_3_plus: - fa_optional_backward_kwargs["window_size"] = (-1, -1) - _flash_attn_backward( + if _use_flash_attn_3 or _flash_attn_2_3_plus: + fa_backward_kwargs["window_size"] = (-1, -1) + if not _use_flash_attn_3: + fa_backward_kwargs["rng_state"] = rng_states[cp_size - i - 1] + flash_attn_bwd( dout_, q_, kv_[0], @@ -2538,11 +2900,8 @@ def backward(ctx, dout): cu_seqlens_kv_per_step[cp_size - i - 1], ctx.max_seqlen_q, ctx.max_seqlen_kv // 2, - ctx.dropout_p, - ctx.softmax_scale, - False, - rng_state=rng_states[cp_size - i - 1], - **fa_optional_backward_kwargs, + causal=False, + **fa_backward_kwargs, ) else: if ctx.use_fused_attention: @@ -2622,9 +2981,11 @@ def backward(ctx, dout): # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn] -> [b*sq//2, np, hn] out_ = out[:, 1, ...].contiguous().view(-1, *out.shape[-2:]) dout_ = dout[:, 1, ...].contiguous().view(-1, *dout.shape[-2:]) - if _flash_attn_2_3_plus: - fa_optional_backward_kwargs["window_size"] = (-1, -1) - _flash_attn_backward( + if _use_flash_attn_3 or _flash_attn_2_3_plus: + fa_backward_kwargs["window_size"] = (-1, -1) + if not _use_flash_attn_3: + fa_backward_kwargs["rng_state"] = rng_states[cp_size - i - 1] + flash_attn_bwd( dout_, q_, kv_[0], @@ -2638,11 +2999,8 @@ def backward(ctx, dout): cu_seqlens_kv_per_step[cp_size - i - 1], ctx.max_seqlen_q // 2, ctx.max_seqlen_kv, - ctx.dropout_p, - ctx.softmax_scale, - False, - rng_state=rng_states[cp_size - i - 1], - **fa_optional_backward_kwargs, + causal=False, + **fa_backward_kwargs, ) else: if ctx.use_fused_attention: @@ -2686,9 +3044,11 @@ def backward(ctx, dout): # [b, sq, np, hn] -> [b*sq, np, hn] out_ = out.view(-1, *out.shape[-2:]) dout_ = dout.view(-1, *dout.shape[-2:]) - if _flash_attn_2_3_plus: - fa_optional_backward_kwargs["window_size"] = (-1, -1) - _flash_attn_backward( + if _use_flash_attn_3 or _flash_attn_2_3_plus: + fa_backward_kwargs["window_size"] = (-1, -1) + if not _use_flash_attn_3: + fa_backward_kwargs["rng_state"] = rng_states[cp_size - i - 1] + flash_attn_bwd( dout_, q_, kv_[0], @@ -2702,11 +3062,8 @@ def backward(ctx, dout): cu_seqlens_kv_per_step[cp_size - i - 1], ctx.max_seqlen_q, ctx.max_seqlen_kv, - ctx.dropout_p, - ctx.softmax_scale, - False, - rng_state=rng_states[cp_size - i - 1], - **fa_optional_backward_kwargs, + causal=False, + **fa_backward_kwargs, ) if ctx.fp8: @@ -2796,7 +3153,9 @@ def backward(ctx, dout): else: dkv = p2p_comm_buffers[(i + 1) % 2][1] if ctx.use_fused_attention: - dkv_ = torch.cat((dk_.unsqueeze(0), dv_.unsqueeze(0)), dim=0) + dkv_ = torch.cat( + (dk_.unsqueeze(0), dv_.unsqueeze(0)), dim=0 + ) # pylint: disable=used-before-assignment if ctx.qkv_format in ["bshd", "sbhd"]: # [b, 2, sk//2, 2, np, hn] -> [2, b, 2, sk//2, np, hn] or # [2, sk//2, b, 2, np, hn] -> [2, 2, sk//2, b, np, hn] @@ -2906,6 +3265,25 @@ def backward(ctx, dout): cast_to_fp8(x, ctx.fp8_meta["scaling_bwd"], META_DQKV, fp8_dtype_backward) for x in [dq, dkv] ] + dk, dv = dkv[0], dkv[1] + + if cp_size_a2a > 1: + chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering(cp_size_a2a, q.device, False) + dq, dk, dv = flash_attn_a2a_communicate( + [dq, dk, dv], + chunk_ids_for_a2a, + seq_dim, + cp_size_a2a, + ctx.cp_group_a2a, + ctx.cp_stream, + False, + ) + if ctx.qkv_format == "bshd": + dq, dk, dv = [x.view(ctx.batch_size, -1, *x.shape[-2:]) for x in [dq, dk, dv]] + elif ctx.qkv_format == "sbhd": + dq, dk, dv = [x.view(-1, ctx.batch_size, *x.shape[-2:]) for x in [dq, dk, dv]] + + if ctx.fp8 and ctx.fp8_meta["recipe"].fp8_mha: dq, dk, dv = [ Float8Tensor( data=x, @@ -2915,10 +3293,8 @@ def backward(ctx, dout): fp8_dtype=fp8_dtype_backward, dtype=dout_dtype, ) - for x in [dq, dkv[0], dkv[1]] + for x in [dq, dk, dv] ] - else: - dk, dv = dkv[0], dkv[1] if attn_dbias is not None: # [b, np, sq, 2*cp, sk//(2*cp)] -> [b, np, sq, sk] @@ -2951,26 +3327,6 @@ def backward(ctx, dout): ) -@torch.compile -def get_seq_chunk_ids_for_reordering(cp_size, device, to_contiguous): - """ - Context parallelism assigns two discontiguous sequence chunks to each GPU for load balancing. - To make sure tokens are ordered correctly for compute, we need to reorder sequence chunks - before or after CP communications (e.g., all-gather, all-to-all). This function is to compute - sequence chunk ids for reordering. - """ - chunk_ids = torch.empty(2 * cp_size, dtype=torch.int32, device=device) - if to_contiguous: - for rank in range(cp_size): - chunk_ids[rank] = 2 * rank - chunk_ids[rank + cp_size] = 2 * cp_size - 2 * rank - 1 - else: - for rank in range(cp_size): - chunk_ids[2 * rank] = rank - chunk_ids[2 * rank + 1] = 2 * cp_size - rank - 1 - return chunk_ids - - def get_kv_seq_info_after_all_gather( local_chunk_id, cp_size, max_seqlen_q, max_seqlen_kv, window_size, causal ): @@ -3027,6 +3383,7 @@ def forward( cp_group, cp_stream, ): + # pylint: disable=missing-function-docstring if softmax_scale is None: softmax_scale = q.shape[-1] ** (-0.5) @@ -3043,11 +3400,20 @@ def forward( assert ( use_fused_attention or _flash_attn_2_3_plus ), "Sliding window attention only can work with FusedAttention or FlashAttention >= 2.3!" - fa_optional_forward_kwargs = {} - if _flash_attn_2_4_plus: - fa_optional_forward_kwargs["alibi_slopes"] = None - if _flash_attn_2_5_7_plus: - fa_optional_forward_kwargs["block_table"] = None + + flash_attn_fwd = None + if not use_fused_attention: + fa_forward_kwargs = {"softmax_scale": softmax_scale} + if _use_flash_attn_3: + flash_attn_fwd = flash_attn_varlen_fwd_v3 + else: + flash_attn_fwd = flash_attn_varlen_fwd + fa_forward_kwargs["dropout_p"] = dropout_p + fa_forward_kwargs["return_softmax"] = False + if _flash_attn_2_4_plus: + fa_forward_kwargs["alibi_slopes"] = None + if _flash_attn_2_5_7_plus: + fa_forward_kwargs["block_table"] = None assert qkv_format != "thd", f"{qkv_format} format is not supported!" qkv_layout = qkv_format + "_" + qkv_format + "_" + qkv_format @@ -3097,7 +3463,8 @@ def forward( for i in range(len(local_seq_chunk_ids) + 1): if i < len(local_seq_chunk_ids): with torch.cuda.stream(flash_attn_streams[i]): - # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn] or [2, sq//2, b, np, hn] -> [sq//2, b, np, hn] + # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn] + # or [2, sq//2, b, np, hn] -> [sq//2, b, np, hn] q_ = q.select(seq_dim, i).contiguous() kv_seq_range_per_step[i], window_size_per_step[i] = ( get_kv_seq_info_after_all_gather( @@ -3144,23 +3511,22 @@ def forward( ) else: q_, k_, v_ = [x.view(-1, *x.shape[-2:]) for x in [q_, k_, v_]] - _, _, _, _, out_per_step[i], softmax_lse_per_step[i], _, rng_states[i] = ( - _flash_attn_forward( - q_, - k_, - v_, - cu_seqlens_q, - cu_seqlens_kv_per_step[i], - max_seqlen_q, - max_seqlen_kv_, - dropout_p, - softmax_scale, - causal=causal, - return_softmax=False, - window_size=window_size_per_step[i], - **fa_optional_forward_kwargs, - ) + fa_outputs = flash_attn_fwd( + q_, + k_, + v_, + cu_seqlens_q, + cu_seqlens_kv_per_step[i], + max_seqlen_q, + max_seqlen_kv_, + causal=causal, + window_size=window_size_per_step[i], + **fa_forward_kwargs, ) + out_per_step[i] = fa_outputs[4] + softmax_lse_per_step[i] = fa_outputs[5] + if not _use_flash_attn_3: + rng_states[i] = fa_outputs[7] if i > 0: with torch.cuda.stream(flash_attn_streams[i - 1]): @@ -3206,6 +3572,7 @@ def forward( @staticmethod def backward(ctx, dout): + # pylint: disable=missing-function-docstring cp_size = get_distributed_world_size(ctx.cp_group) rank = get_distributed_rank(ctx.cp_group) @@ -3250,16 +3617,25 @@ def backward(ctx, dout): local_seq_chunk_ids = [rank, 2 * cp_size - rank - 1] - fa_optional_backward_kwargs = {} - if _flash_attn_2_4_plus: - fa_optional_backward_kwargs["alibi_slopes"] = None - if _flash_attn_2_4_1_plus: - fa_optional_backward_kwargs["deterministic"] = ctx.deterministic + flash_attn_bwd = None + if not ctx.use_fused_attention: + fa_backward_kwargs = {"softmax_scale": ctx.softmax_scale} + if _use_flash_attn_3: + flash_attn_bwd = flash_attn_varlen_bwd_v3 + fa_backward_kwargs["deterministic"] = ctx.deterministic + else: + flash_attn_bwd = flash_attn_varlen_bwd + fa_backward_kwargs["dropout_p"] = ctx.dropout_p + if _flash_attn_2_4_plus: + fa_backward_kwargs["alibi_slopes"] = None + if _flash_attn_2_4_1_plus: + fa_backward_kwargs["deterministic"] = ctx.deterministic for i in range(len(local_seq_chunk_ids) + 1): if i < len(local_seq_chunk_ids): with torch.cuda.stream(flash_attn_streams[i]): - # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn] or [2, sq//2, b, np, hn] -> [sq//2, b, np, hn] + # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn] + # or [2, sq//2, b, np, hn] -> [sq//2, b, np, hn] q_ = q.select(seq_dim, i).contiguous() seq_start_idx, seq_end_idx = ( kv_seq_range_per_step[i][0], @@ -3303,7 +3679,9 @@ def backward(ctx, dout): dq_per_step[i], dk_per_step[i], dv_per_step[i] = [ torch.empty_like(x) for x in [q_, k_, v_] ] - _flash_attn_backward( + if not _use_flash_attn_3: + fa_backward_kwargs["rng_state"] = rng_states[i] + flash_attn_bwd( dout_, q_, k_, @@ -3317,12 +3695,9 @@ def backward(ctx, dout): cu_seqlens_kv_per_step[i], ctx.max_seqlen_q, max_seqlen_kv, - ctx.dropout_p, - ctx.softmax_scale, - "causal" in ctx.attn_mask_type, + causal="causal" in ctx.attn_mask_type, window_size=window_size_per_step[i], - rng_state=rng_states[i], - **fa_optional_backward_kwargs, + **fa_backward_kwargs, ) # [b*sq//2, np, hn] -> [b, sq//2, np, hn] dq_per_step[i] = dq_per_step[i].view(dq[:, i].shape) @@ -3396,88 +3771,6 @@ def backward(ctx, dout): ) -@torch.compile -def reorder_seq_chunks_for_a2a(x, chunk_ids_for_a2a, seq_dim, cp_size, before_attn): - """Reorder sequence chunk for A2A communication.""" - if before_attn: - # [cp, b, s, np//cp, hn] -> [b, cp, s, np//cp, hn] or [cp, s, b, np//cp, hn] -> [cp, s, b, np//cp, hn] - x = x.movedim(0, seq_dim).contiguous() - # [b, cp, s, np//cp, hn] -> [b, cp*2, s//2, np//cp, hn] or [cp, s, b, np//cp, hn] -> [cp*2, s//2, b, np//cp, hn] - x = x.view(*x.shape[:seq_dim], cp_size * 2, -1, *x.shape[(seq_dim + 2) :]) - # reorder the sequence chunks - x = torch.index_select(x, dim=seq_dim, index=chunk_ids_for_a2a) - else: - # [b, cp*2, s//2, np//cp, hn] -> [cp*2, b, s//2, np//cp, hn] or [cp*2, s//2, b, np//cp, hn] -> [cp*2, s//2, b, np//cp, hn] - x = x.movedim(seq_dim, 0).contiguous() - # reorder the sequence chunks - x = torch.index_select(x, dim=0, index=chunk_ids_for_a2a) - # [cp*2, b, s//2, np//cp, hn] -> [cp, 2, b, s//2, np//cp, hn] or [cp*2, s//2, b, np//cp, hn] -> [cp, 2, s//2, b, np//cp, hn] - x = x.view(cp_size, 2, *x.shape[1:]) - return x - - -def flash_attn_a2a_communicate( - a2a_inputs: Union[torch.Tensor, List[torch.Tensor]], - chunk_ids_for_a2a: torch.Tensor, - seq_dim: int, - cp_size: int, - cp_group: dist_group_type, - cp_stream: torch.cuda.Stream, - before_attn: bool, -) -> Union[torch.Tensor, List[torch.Tensor]]: - """A2A communication for context parallelism.""" - a2a_inputs = [a2a_inputs] if not isinstance(a2a_inputs, list) else a2a_inputs - a2a_outputs, a2a_reqs = [None] * len(a2a_inputs), [None] * len(a2a_inputs) - if before_attn: - for i in range(len(a2a_inputs) + 2): - if 0 < i < len(a2a_inputs) + 1: - a2a_outputs[i - 1] = torch.empty_like(a2a_inputs[i - 1]) - a2a_reqs[i - 1] = torch.distributed.all_to_all_single( - a2a_outputs[i - 1], a2a_inputs[i - 1], group=cp_group, async_op=True - ) - if i > 1: - with torch.cuda.stream(cp_stream): - a2a_reqs[i - 2].wait() - x = a2a_outputs[i - 2] - # reorder the sequence chunks - x = reorder_seq_chunks_for_a2a( - x, chunk_ids_for_a2a, seq_dim, cp_size, before_attn - ) - # [b, cp*2, s//2, np//cp, hn] -> [b, cp*s, np//cp, hn] or [cp*2, s//2, b, np//cp, hn] -> [cp*s, b, np//cp, hn] - a2a_outputs[i - 2] = x.view(*x.shape[:seq_dim], -1, *x.shape[(seq_dim + 2) :]) - if i < len(a2a_inputs): - x = a2a_inputs[i] - # [b, s, np, hn] -> [b, s, cp, np//cp, hn] or [s, b, np, hn] -> [s, b, cp, np//cp, hn] - x = x.view(*x.shape[:-2], cp_size, x.shape[-2] // cp_size, x.shape[-1]) - # [b, s, cp, np//cp, hn] -> [cp, b, s, np//cp, hn] or [s, b, cp, np//cp, hn] -> [cp, s, b, np//cp, hn] - a2a_inputs[i] = x.movedim(-3, 0).contiguous() - else: - for i in range(len(a2a_inputs) + 2): - if 0 < i < len(a2a_inputs) + 1: - a2a_outputs[i - 1] = torch.empty_like(a2a_inputs[i - 1]) - a2a_reqs[i - 1] = torch.distributed.all_to_all_single( - a2a_outputs[i - 1], a2a_inputs[i - 1], group=cp_group, async_op=True - ) - if i < len(a2a_inputs): - x = a2a_inputs[i] - # [b, cp*s, np//cp, hn] -> [b, cp*2, s//2, np//cp, hn] or [cp*s, b, np//cp, hn] -> [cp*2, s//2, b, np//cp, hn] - x = x.view(*x.shape[:seq_dim], cp_size * 2, -1, *x.shape[(seq_dim + 1) :]) - # reorder the sequence chunks - a2a_inputs[i] = reorder_seq_chunks_for_a2a( - x, chunk_ids_for_a2a, seq_dim, cp_size, before_attn - ) - if i > 1: - with torch.cuda.stream(cp_stream): - a2a_reqs[i - 2].wait() - x = a2a_outputs[i - 2] - # [cp, 2, b, s//2, np//cp, hn] -> [b, 2, s//2, cp, np//cp, hn] or [cp, 2, s//2, b, np//cp, hn] -> [2, s//2, b, cp, np//cp, hn] - x = x.movedim(0, -3).movedim(0, seq_dim).contiguous() - # [b, 2, s//2, cp, np//cp, hn] -> [b*s, np, hn] or [2, s//2, b, cp, np//cp, hn] -> [s*b, np, hn] - a2a_outputs[i - 2] = x.view(-1, x.shape[-3] * x.shape[-2], x.shape[-1]) - torch.cuda.current_stream().wait_stream(cp_stream) - return a2a_outputs[0] if len(a2a_inputs) == 1 else a2a_outputs - - class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function): """ Attention implementation with context parallelism. Like Ulysses, applying A2A to QKVO. @@ -3511,6 +3804,7 @@ def forward( cp_group, cp_stream, ): + # pylint: disable=missing-function-docstring if softmax_scale is None: softmax_scale = q.shape[-1] ** (-0.5) @@ -3527,13 +3821,23 @@ def forward( or use_fused_attention or _flash_attn_2_3_plus ), "Sliding window attention only can work with FusedAttention or FlashAttention >= 2.3!" - fa_optional_forward_kwargs = {} - if _flash_attn_2_3_plus: - fa_optional_forward_kwargs["window_size"] = window_size - if _flash_attn_2_4_plus: - fa_optional_forward_kwargs["alibi_slopes"] = None - if _flash_attn_2_5_7_plus: - fa_optional_forward_kwargs["block_table"] = None + + flash_attn_fwd = None + if not use_fused_attention: + fa_forward_kwargs = {"softmax_scale": softmax_scale} + if _use_flash_attn_3: + flash_attn_fwd = flash_attn_varlen_fwd_v3 + fa_forward_kwargs["window_size"] = window_size + else: + flash_attn_fwd = flash_attn_varlen_fwd + fa_forward_kwargs["dropout_p"] = dropout_p + fa_forward_kwargs["return_softmax"] = False + if _flash_attn_2_3_plus: + fa_forward_kwargs["window_size"] = window_size + if _flash_attn_2_4_plus: + fa_forward_kwargs["alibi_slopes"] = None + if _flash_attn_2_5_7_plus: + fa_forward_kwargs["block_table"] = None assert ( q.shape[-2] % cp_size == 0 and k.shape[-2] % cp_size == 0 @@ -3548,6 +3852,8 @@ def forward( q.shape[seq_dim] % 2 == 0 and k.shape[seq_dim] % 2 == 0 ), "Sequence length per GPU needs to be divisible by 2!" + fused_attn_backend = None + fused_attn_qkv_dtype = None if fp8: if use_fused_attention: fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) @@ -3628,16 +3934,7 @@ def forward( else: # [b, cp*s, np//cp, hn] -> [b*cp*s, np//cp, hn] q, k, v = [x.view(-1, *x.shape[-2:]) for x in [q, k, v]] - ( - _, - _, - _, - _, - out, - softmax_lse, - _, - rng_state, - ) = _flash_attn_forward( + fa_outputs = flash_attn_fwd( q, k, v, @@ -3645,12 +3942,11 @@ def forward( cu_seqlens_kv, max_seqlen_q, max_seqlen_kv, - dropout_p, - softmax_scale, causal=causal, - return_softmax=False, - **fa_optional_forward_kwargs, + **fa_forward_kwargs, ) + out, softmax_lse = fa_outputs[4], fa_outputs[5] + rng_state = fa_outputs[7] if not _use_flash_attn_3 else None aux_ctx_tensors = [softmax_lse, rng_state] # [b*cp*s, np//cp, hn] -> [b, cp*s, np//cp, hn] out = out.view(batch_size, -1, *out.shape[-2:]) @@ -3751,6 +4047,7 @@ def forward( @staticmethod def backward(ctx, dout): + # pylint: disable=missing-function-docstring cp_size = get_distributed_world_size(ctx.cp_group) q, k, v, out = ctx.saved_tensors[:4] @@ -3764,6 +4061,9 @@ def backward(ctx, dout): causal = "causal" in ctx.attn_mask_type seq_dim = ctx.qkv_format.index("s") + fused_attn_backend = None + fused_attn_dqkv_dtype = None + fused_attn_qkv_dtype = None if ctx.fp8: if ctx.use_fused_attention: fp8_dtype_forward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=True) @@ -3815,13 +4115,22 @@ def backward(ctx, dout): [out, dout], chunk_ids_for_a2a, seq_dim, cp_size, ctx.cp_group, ctx.cp_stream, True ) - fa_optional_backward_kwargs = {} - if _flash_attn_2_3_plus: - fa_optional_backward_kwargs["window_size"] = ctx.window_size - if _flash_attn_2_4_plus: - fa_optional_backward_kwargs["alibi_slopes"] = None - if _flash_attn_2_4_1_plus: - fa_optional_backward_kwargs["deterministic"] = ctx.deterministic + flash_attn_bwd = None + if not ctx.use_fused_attention: + fa_backward_kwargs = {"softmax_scale": ctx.softmax_scale} + if _use_flash_attn_3: + flash_attn_bwd = flash_attn_varlen_bwd_v3 + fa_backward_kwargs["window_size"] = ctx.window_size + fa_backward_kwargs["deterministic"] = ctx.deterministic + else: + flash_attn_bwd = flash_attn_varlen_bwd + fa_backward_kwargs["dropout_p"] = ctx.dropout_p + if _flash_attn_2_3_plus: + fa_backward_kwargs["window_size"] = ctx.window_size + if _flash_attn_2_4_plus: + fa_backward_kwargs["alibi_slopes"] = None + if _flash_attn_2_4_1_plus: + fa_backward_kwargs["deterministic"] = ctx.deterministic if ctx.use_fused_attention: dq, dk, dv, _ = fused_attn_bwd( @@ -3853,7 +4162,9 @@ def backward(ctx, dout): softmax_lse, rng_state = aux_ctx_tensors out, dout = [x.view(-1, *x.shape[-2:]) for x in [out, dout]] dq, dk, dv = [torch.empty_like(x) for x in [q, k, v]] - _flash_attn_backward( + if not _use_flash_attn_3: + fa_backward_kwargs["rng_state"] = rng_state + flash_attn_bwd( dout, q, k, @@ -3867,11 +4178,8 @@ def backward(ctx, dout): cu_seqlens_kv, ctx.max_seqlen_q, ctx.max_seqlen_kv, - ctx.dropout_p, - ctx.softmax_scale, - causal, - rng_state=rng_state, - **fa_optional_backward_kwargs, + causal=causal, + **fa_backward_kwargs, ) dq, dk, dv = [x.view(ctx.batch_size, -1, *x.shape[-2:]) for x in [dq, dk, dv]] @@ -3969,6 +4277,22 @@ def attn_forward_func_with_cp( Attention implementation with context parallelism. """ + if cp_comm_type == "a2a+p2p": + assert isinstance( + cp_group, list + ), "Hierarchical CP implementation needs multi-level CP groups!" + assert len(cp_group) == 2, "Current implementation only supports two-level CP groups!" + if get_distributed_world_size(cp_group[0]) == 1: + cp_group = cp_group[1] + cp_comm_type = "p2p" + elif get_distributed_world_size(cp_group[1]) == 1: + cp_group = cp_group[0] + cp_comm_type = "a2a" + else: + assert isinstance( + cp_group, dist_group_type + ), f"Unsupported process group for CP communication type {cp_comm_type}!" + assert qkv_format in [ "bshd", "sbhd", @@ -4023,7 +4347,7 @@ def attn_forward_func_with_cp( use_fused_attention, ] - if cp_comm_type == "p2p": + if cp_comm_type in ["p2p", "a2a+p2p"]: args += [fp8, fp8_meta, cp_group, cp_global_ranks, cp_stream] out = AttnFuncWithCPAndKVP2P.apply(*args) elif cp_comm_type == "all_gather": @@ -4051,6 +4375,7 @@ def __init__( rotary_percent: float = 1.0, seq_len_interpolation_factor: Optional[int] = None, pretrained_max_position_embeddings: Optional[int] = None, + rotary_base: float = 10000.0, ): """ Parameters @@ -4069,8 +4394,9 @@ def __init__( if rotary_percent < 1.0: dim = int(dim * rotary_percent) self.seq_len_interpolation_factor = seq_len_interpolation_factor + self.rotary_base = rotary_base inv_freq = 1.0 / ( - 10000 + self.rotary_base ** ( torch.arange(0, dim, 2, dtype=torch.float32, device=torch.cuda.current_device()) / dim @@ -4133,7 +4459,10 @@ def forward( freqs: torch.Tensor, tensor_format: str = "sbhd", cu_seqlens: Union[torch.Tensor, None] = None, + cp_size: int = 1, + cp_rank: int = 0, ) -> torch.Tensor: + # pylint: disable=missing-function-docstring if freqs.dtype != torch.float32: freqs = freqs.float() if tensor_format == "sbhd": @@ -4141,16 +4470,19 @@ def forward( elif tensor_format == "bshd": output = tex.fused_rope_forward(t.transpose(0, 1), freqs, True).transpose(0, 1) elif tensor_format == "thd": - output = tex.fused_rope_thd_forward(t, cu_seqlens, freqs) + output = tex.fused_rope_thd_forward(t, cu_seqlens, freqs, cp_size, cp_rank) else: raise ValueError(f"Unsupported tensor_format: {tensor_format}.") ctx.save_for_backward(freqs, cu_seqlens) ctx.tensor_format = tensor_format + ctx.cp_size = cp_size + ctx.cp_rank = cp_rank return output @staticmethod def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]: + # pylint: disable=missing-function-docstring freqs, cu_seqlens = ctx.saved_tensors if ctx.tensor_format == "sbhd": grad_input = tex.fused_rope_backward(grad_output, freqs, False) @@ -4159,11 +4491,13 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], grad_output.transpose(0, 1), freqs, True ).transpose(0, 1) elif ctx.tensor_format == "thd": - grad_input = tex.fused_rope_thd_backward(grad_output, cu_seqlens, freqs) + grad_input = tex.fused_rope_thd_backward( + grad_output, cu_seqlens, freqs, ctx.cp_size, ctx.cp_rank + ) else: raise ValueError(f"Unsupported tensor_format: {ctx.tensor_format}.") - return grad_input, None, None, None, None + return grad_input, None, None, None, None, None def _rotate_half(x: torch.Tensor) -> torch.Tensor: @@ -4181,6 +4515,8 @@ def apply_rotary_pos_emb( tensor_format: str = "sbhd", fused: bool = False, cu_seqlens: Union[torch.Tensor, None] = None, + cp_size: int = 1, + cp_rank: int = 0, ) -> torch.Tensor: """ Apply rotary positional embedding tensor to the input tensor. @@ -4201,12 +4537,17 @@ def apply_rotary_pos_emb( cu_seqlens: torch.Tensor, default = None. Cumulative sum of sequence lengths in a batch for `t`, with shape [b + 1] and dtype torch.int32. Only valid when `tensor_format` is 'thd'. + Should be `cu_seqlens_padded` when cp_size > 1. + cp_size: int, default = 1. + Context parallel world size. Only valid when `tensor_format` is 'thd' and `fused` is True. + cp_rank: int, default = 0. + Context parallel rank. Only valid when `tensor_format` is 'thd' and `fused` is True. """ if fused: assert ( tensor_format != "thd" or cu_seqlens is not None ), "cu_seqlens must not be None when tensor_format is 'thd'." - return FusedRoPEFunc.apply(t, freqs, tensor_format, cu_seqlens) + return FusedRoPEFunc.apply(t, freqs, tensor_format, cu_seqlens, cp_size, cp_rank) assert tensor_format in ("sbhd", "bshd"), ( "Only formats `sbhd` or `bshd` are supported for input tensor `t` " @@ -4248,6 +4589,7 @@ def forward( split_dim: int, split_size_or_sections: Union[int, List[int], Tuple[int]], ) -> Tuple[torch.Tensor, ...]: + # pylint: disable=missing-function-docstring ctx.split_dim = split_dim ctx.split_size_or_sections = split_size_or_sections if isinstance(mixed_x_layer, Float8Tensor): @@ -4266,6 +4608,7 @@ def forward( @staticmethod def backward(ctx, *grad_outputs): + # pylint: disable=missing-function-docstring assert len(grad_outputs) > 0, "No gradients received for backprop!" if isinstance(ctx.split_size_or_sections, (list, tuple)): @@ -4610,6 +4953,7 @@ def forward( key_layer: torch.Tensor, value_layer: torch.Tensor, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + # pylint: disable=missing-function-docstring # All inputs received are non-contiguous tensors. # The `query_layer` tensor is used to access the # full memory region of the QKV tensor. @@ -4627,6 +4971,7 @@ def backward( dk: torch.Tensor, dv: torch.Tensor, ) -> Tuple[Union[torch.Tensor, None], ...]: + # pylint: disable=missing-function-docstring dqkv = tex.fa_prepare_bwd(dq, dk, dv) dq, dk, dv = split_tensor_along_dim(dqkv, -1, 3) return dq, dk, dv @@ -4667,74 +5012,105 @@ def get_qkv_layout( `sbhd`: {`sb3hd`, `sbh3d`, `sbhd_sb2hd`, `sbhd_sbh2d`, `sbhd_sbhd_sbhd`} `bshd`: {`bs3hd`, `bsh3d`, `bshd_bs2hd`, `bshd_bsh2d`, `bshd_bshd_bshd`} `thd` : {`t3hd`, `th3d`, `thd_t2hd`, `thd_th2d`, `thd_thd_thd`} + q: torch.Tensor + Query tensor. It may be different from input `q` as we try to fit tensors to + a supported layout. + k: torch.Tensor + Key tensor. It may be different from input `k` as we try to fit tensors to + a supported layout. + v: torch.Tensor + Value tensor. It may be different from input `v` as we try to fit tensors to + a supported layout. """ check_last_dim_contiguous = all(x.stride(-1) == 1 for x in [q, k, v]) assert check_last_dim_contiguous, "q, k and v must have stride 1 in their last dimension!" def run_iteratively(q, k, v): + # check data pointers data_ptr = q.untyped_storage().data_ptr() check_ptrs_qkv = all(x.untyped_storage().data_ptr() == data_ptr for x in [q, k, v]) + check_ptrs_qk = all(x.untyped_storage().data_ptr() == data_ptr for x in [q, k]) data_ptr = k.untyped_storage().data_ptr() check_ptrs_kv = all(x.untyped_storage().data_ptr() == data_ptr for x in [k, v]) + # check tensor shapes + shape = q.shape + check_shapes_qkv = all(shape == x.shape for x in [q, k, v]) + shape = k.shape + check_shapes_kv = shape[:-1] == v.shape[:-1] + + # check tensor strides stride = q.stride() check_strides_qkv = all(stride == x.stride() for x in [q, k, v]) check_strides_kv = tuple(sk / k.shape[-1] for sk in k.stride()[:-1]) == tuple( sv / v.shape[-1] for sv in v.stride()[:-1] ) - shape = q.shape - check_shapes_qkv = all(shape == x.shape for x in [q, k, v]) - shape = k.shape - check_shapes_kv = shape[:-1] == v.shape[:-1] + # check tensor offsets for h3d and 3hd layouts + prod_h_d = q.shape[-1] * q.shape[-2] + check_3hd_offsets = all(x.storage_offset() == i * prod_h_d for i, x in enumerate([q, k, v])) + check_h3d_offsets = all( + x.storage_offset() == i * q.shape[-1] for i, x in enumerate([q, k, v]) + ) - last_dim_size = q.shape[-1] - check_last_dim_offsets_qkv = all( - i * last_dim_size == x.storage_offset() for i, x in enumerate([q, k, v]) + # check tensor offsets for hd_h2d and hd_2hd layouts + prod_all_dims = [np.prod(x.shape) for x in [q, k]] + offset = prod_all_dims[0] if check_ptrs_qkv else 0 + prod_h_d = k.shape[-1] * k.shape[-2] + check_2hd_offsets = all( + x.storage_offset() == (offset + i * prod_h_d) for i, x in enumerate([k, v]) ) - last_dim_size = k.shape[-1] - check_last_dim_offsets_kv = all( - i * last_dim_size == x.storage_offset() for i, x in enumerate([k, v]) + check_h2d_offsets = all( + x.storage_offset() == (offset + i * k.shape[-1]) for i, x in enumerate([k, v]) ) - last_two_dims_size = q.shape[-1] * q.shape[-2] - check_last_two_dims_offsets_qkv = all( - i * last_two_dims_size == x.storage_offset() for i, x in enumerate([q, k, v]) + # check tensor offsets for hd_hd_hd layouts + check_hd_offsets_qkv = ( + all(x.storage_offset() == sum(prod_all_dims[:i]) for i, x in enumerate([q, k, v])) + if check_ptrs_qkv + else all(x.storage_offset() == 0 for i, x in enumerate([q, k, v])) + ) + check_hd_offsets_qk = ( + all(x.storage_offset() == sum(prod_all_dims[:i]) for i, x in enumerate([q, k])) + if not check_ptrs_qkv and check_ptrs_qk + else all(x.storage_offset() == 0 for i, x in enumerate([q, k])) ) - last_two_dims_size = k.shape[-1] * k.shape[-2] - check_last_two_dims_offsets_kv = all( - i * last_two_dims_size == x.storage_offset() for i, x in enumerate([k, v]) + check_hd_offsets_kv = ( + all(x.storage_offset() == sum(prod_all_dims[1 : i + 1]) for i, x in enumerate([k, v])) + if not check_ptrs_qkv and check_ptrs_kv + else all(x.storage_offset() == 0 for i, x in enumerate([k, v])) ) - if ( - check_ptrs_qkv - and check_strides_qkv - and check_shapes_qkv - and check_last_two_dims_offsets_qkv - and not check_last_dim_offsets_qkv - ): + if check_ptrs_qkv and check_strides_qkv and check_shapes_qkv and check_3hd_offsets: # sb3hd, bs3hd, t3hd + # one chunk of memory, qkv, with q, k, v interleaved at dim=-3 in qkv qkv_layout = qkv_format[:-2] + "3" + qkv_format[-2:] - elif ( - check_ptrs_qkv and check_strides_qkv and check_shapes_qkv and check_last_dim_offsets_qkv - ): + elif check_ptrs_qkv and check_strides_qkv and check_shapes_qkv and check_h3d_offsets: # sbh3d, bsh3d, th3d + # one chunk of memory, qkv, with q, k, v interleaved at dim=-2 in qkv qkv_layout = qkv_format[:-1] + "3" + qkv_format[-1:] - elif ( - check_ptrs_kv - and check_strides_kv - and check_shapes_kv - and check_last_two_dims_offsets_kv - and not check_last_dim_offsets_kv - ): + elif check_ptrs_kv and check_strides_kv and check_shapes_kv and check_2hd_offsets: # sbhd_sb2hd, bshd_bs2hd, thd_t2hd + # two chunks of memory, q and kv, with k, v interleaved at dim=-3 in kv + # q and kv may be disjoint or consecutive in memory, and when consecutive, they may + # have the same data pointer, i.e. check_ptrs_qkv=True qkv_layout = qkv_format + "_" + qkv_format[:-2] + "2" + qkv_format[-2:] - elif check_ptrs_kv and check_strides_kv and check_shapes_kv and check_last_dim_offsets_kv: + elif check_ptrs_kv and check_strides_kv and check_shapes_kv and check_h2d_offsets: # sbhd_sbh2d, bshd_bsh2d, thd_th2d + # two chunks of memory, q and kv, with k, v interleaved at dim=-2 in kv + # q and kv may be disjoint or consecutive in memory, and when consecutive, they may + # have the same data pointer, i.e. check_ptrs_qkv=True qkv_layout = qkv_format + "_" + qkv_format[:-1] + "2" + qkv_format[-1:] - elif check_strides_kv and check_shapes_kv: + elif ( + check_strides_kv + and check_shapes_kv + and (check_hd_offsets_qkv or check_hd_offsets_kv or check_hd_offsets_qk) + ): # sbhd_sbhd_sbhd, bshd_bshd_bshd, thd_thd_thd + # three chunks of memory, q, k and v, which may be disjoint or consecutive, and + # when consecutive, they may have the same data pointer, i.e. check_ptrs_qkv=True or + # check_ptrs_qk=True or check_ptrs_kv=True qkv_layout = "_".join(list([qkv_format]) * 3) else: qkv_layout = "not_supported" @@ -4747,7 +5123,7 @@ def run_iteratively(q, k, v): q, k, v = [x.contiguous() for x in [q, k, v]] qkv_layout = run_iteratively(q, k, v) if qkv_layout == "not_supported": - raise Exception("The provided qkv memory layout is not supported!") + raise RuntimeError("The provided qkv memory layout is not supported!") return qkv_layout, q, k, v @@ -4813,12 +5189,13 @@ def __init__( ) -> None: super().__init__() - assert ( - _flash_attn_version >= _flash_attn_version_required - ), f"FlashAttention minimum version {_flash_attn_version_required} is required." - assert ( - _flash_attn_version <= _flash_attn_max_version - ), f"FlashAttention maximum version {_flash_attn_max_version} is supported." + if _flash_attn_is_installed: + assert ( + _flash_attn_version >= _flash_attn_version_required + ), f"FlashAttention minimum version {_flash_attn_version_required} is required." + assert ( + _flash_attn_version <= _flash_attn_max_version + ), f"FlashAttention maximum version {_flash_attn_max_version} is supported." self.softmax_scale = softmax_scale self.attention_dropout_ctx = attention_dropout_ctx @@ -4826,6 +5203,10 @@ def __init__( self.attention_type = attention_type self.layer_number = 1 if layer_number is None else layer_number self.deterministic = deterministic + self.logger = logging.getLogger("FlashAttention") + self.logger.setLevel(_log_level) + if not self.logger.hasHandlers(): + self.logger.addHandler(_stream_handler) def forward( self, @@ -4841,7 +5222,7 @@ def forward( attn_mask_type: str = "causal", window_size: Optional[Tuple[int, int]] = None, alibi_slopes: Optional[torch.Tensor] = None, - cp_group: Optional[dist_group_type] = None, + cp_group: Optional[Union[dist_group_type, List[dist_group_type]]] = None, cp_global_ranks: List[int] = None, cp_stream: torch.cuda.Stream = None, cp_comm_type: str = "p2p", @@ -4861,7 +5242,12 @@ def forward( qkv_layout in QKVLayouts ), f"FlashAttention does not support qkv_layout = {qkv_layout}!" - cp_size = 1 if cp_group is None else get_distributed_world_size(cp_group) + cp_size = 1 + if isinstance(cp_group, dist_group_type): + cp_size = get_distributed_world_size(cp_group) + elif isinstance(cp_group, list): + for group in cp_group: + cp_size *= get_distributed_world_size(group) context_parallel = cp_size > 1 qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()]) @@ -4879,20 +5265,23 @@ def forward( ) else: query_layer, key_layer, value_layer = [ - x.transpose(0, 1).contiguous() - for x in (query_layer, key_layer, value_layer) + x.transpose(0, 1) for x in (query_layer, key_layer, value_layer) ] - elif qkv_format in ["bshd", "thd"]: + if context_parallel: query_layer, key_layer, value_layer = [ x.contiguous() for x in (query_layer, key_layer, value_layer) ] else: if qkv_format == "sbhd": query_layer._data, key_layer._data, value_layer._data = [ - x.transpose(0, 1).contiguous() + x.transpose(0, 1) for x in (query_layer._data, key_layer._data, value_layer._data) ] - elif qkv_format in ["bshd", "thd"]: + query_layer, key_layer, value_layer = [ + Float8Tensor.make_like(x, data=x._data) + for x in (query_layer, key_layer, value_layer) + ] + if context_parallel: query_layer._data, key_layer._data, value_layer._data = [ x.contiguous() for x in (query_layer._data, key_layer._data, value_layer._data) ] @@ -5011,12 +5400,12 @@ def forward( fa_optional_forward_kwargs["alibi_slopes"] = alibi_slopes if _flash_attn_2_4_1_plus: fa_optional_forward_kwargs["deterministic"] = self.deterministic - if _flash_attn_2_5_7_plus: - fa_optional_forward_kwargs["block_table"] = None fa_optional_forward_args_thd = [] if qkv_format in ["bshd", "sbhd"] and "padding" not in attn_mask_type: func = flash_attn_func if not _use_flash_attn_3 else flash_attn_func_v3 else: + if _flash_attn_2_5_7_plus: + fa_optional_forward_kwargs["block_table"] = None func = ( flash_attn_varlen_func if not _use_flash_attn_3 @@ -5027,33 +5416,62 @@ def forward( fa_optional_forward_args_thd.append(max_seqlen_q) fa_optional_forward_args_thd.append(max_seqlen_kv) if _use_flash_attn_3: + fa_3_optional_forward_kwargs = {} + fa_3_optional_forward_kwargs["window_size"] = window_size + fa_3_optional_forward_kwargs["deterministic"] = self.deterministic + activation_dtype = query_layer.dtype if fp8: fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) - activation_dtype = query_layer.dtype torch_dtype = get_fp8_torch_dtype(fp8_meta["recipe"], fprop_tensor=True) + + def convert_to_torch_float8(tensor, dtype): + out = torch.Tensor().to(device=tensor.device, dtype=dtype) + out.set_( + tensor._data.untyped_storage(), + tensor._data.storage_offset(), + tensor._data.shape, + tensor._data.stride(), + ) + return out + if fp8_meta["recipe"].fp8_mha: assert all( isinstance(x, Float8Tensor) for x in [query_layer, key_layer, value_layer] ), "q/k/v must be Float8Tensors for FP8 MHA." fp8_meta["scaling_fwd"].scale_inv[META_QKV] = query_layer._scale_inv - query_layer, key_layer, value_layer = ( - x.to(activation_dtype).to(torch_dtype) - for x in [query_layer, key_layer, value_layer] - ) else: query_layer, key_layer, value_layer = ( - x.to(torch_dtype) for x in [query_layer, key_layer, value_layer] + Float8Tensor.to_float8(x, fp8_dtype=fp8_dtype_forward) + for x in [query_layer, key_layer, value_layer] ) - output, _ = func( - query_layer, - key_layer, - value_layer, - *fa_optional_forward_args_thd, - softmax_scale=self.softmax_scale, - causal="causal" in attn_mask_type, - deterministic=self.deterministic, - ) + fa_3_optional_forward_kwargs["descale_q"] = query_layer._scale_inv + fa_3_optional_forward_kwargs["descale_k"] = key_layer._scale_inv + fa_3_optional_forward_kwargs["descale_v"] = value_layer._scale_inv + query_layer, key_layer, value_layer = ( + convert_to_torch_float8(x, torch_dtype) + for x in [query_layer, key_layer, value_layer] + ) + try: + output, _ = func( + query_layer, + key_layer, + value_layer, + *fa_optional_forward_args_thd, + softmax_scale=self.softmax_scale, + causal="causal" in attn_mask_type, + **fa_3_optional_forward_kwargs, + ) + except TypeError as e: + if _flash_attn_3_0_0_beta: + e.args = ( + e.args[0] + + ". Please update your flash-attn v3 (beta) installation as it " + + "may have added more supported arguments to its API. \n" + + _flash_attn_3_installation_steps, + ) + e.args[1:] + raise + if fp8 and fp8_meta["recipe"].fp8_mha: output = cast_to_fp8( output, @@ -5087,14 +5505,14 @@ def forward( if qkv_format == "sbhd": # (bs)hd -> bs(hd) -> sb(hd) if fp8 and fp8_meta["recipe"].fp8_mha: - output.reshape(batch_size * max_seqlen_q // cp_size, -1).transpose_2d() - output = output.reshape(batch_size, max_seqlen_q // cp_size, -1) - else: - output = ( - output.view(batch_size, max_seqlen_q // cp_size, -1) + output = Float8Tensor.make_like( + output, + data=output._data.reshape(batch_size, max_seqlen_q // cp_size, -1) .transpose(0, 1) - .contiguous() + .contiguous(), ) + else: + output = output.view(batch_size, max_seqlen_q // cp_size, -1).transpose(0, 1) elif qkv_format == "bshd": # (bs)hd -> bs(hd) output = output.reshape(batch_size, max_seqlen_q // cp_size, -1) @@ -5102,7 +5520,7 @@ def forward( # thd -> t(hd) output = output.reshape(output.shape[0], -1) - return output + return output.contiguous() def _combine_tensors( @@ -5161,6 +5579,7 @@ def forward( fp8_meta, deterministic, ): + # pylint: disable=missing-function-docstring is_input_fp8 = False is_output_fp8 = fp8_meta["recipe"].fp8_mha if fp8: @@ -5314,6 +5733,7 @@ def forward( @staticmethod def backward(ctx, d_out): + # pylint: disable=missing-function-docstring if ctx.is_output_fp8: assert isinstance( d_out, Float8Tensor @@ -5333,12 +5753,12 @@ def backward(ctx, d_out): fwd_scale_invs, *aux_ctx_tensors, ) = ctx.saved_tensors + rest = [None] if not aux_ctx_tensors[0].is_contiguous(): aux_ctx_tensors[0] = aux_ctx_tensors[0].contiguous() if ctx.use_FAv2_bwd: softmax_lse, rng_state = aux_ctx_tensors dqkv = torch.empty_like(qkv) - maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x d_out, q, k, v, out = [ maybe_contiguous(x) for x in (d_out, qkv[:, 0], qkv[:, 1], qkv[:, 2], out) ] @@ -5549,6 +5969,7 @@ def forward( fp8_meta, deterministic, ): + # pylint: disable=missing-function-docstring is_input_fp8 = False is_output_fp8 = fp8_meta["recipe"].fp8_mha if fp8: @@ -5730,6 +6151,7 @@ def forward( @staticmethod def backward(ctx, d_out): + # pylint: disable=missing-function-docstring if ctx.is_output_fp8: assert isinstance( d_out, Float8Tensor @@ -5753,13 +6175,13 @@ def backward(ctx, d_out): fwd_scale_invs, *aux_ctx_tensors, ) = ctx.saved_tensors + rest = [None] if not aux_ctx_tensors[0].is_contiguous(): aux_ctx_tensors[0] = aux_ctx_tensors[0].contiguous() if ctx.use_FAv2_bwd: softmax_lse, rng_state = aux_ctx_tensors dq = torch.empty_like(q) dkv = torch.empty_like(kv) - maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x d_out, q, k, v, out = [maybe_contiguous(x) for x in (d_out, q, kv[:, 0], kv[:, 1], out)] flash_attn_cuda_bwd( d_out, @@ -6001,6 +6423,7 @@ def forward( fp8_meta, deterministic, ): + # pylint: disable=missing-function-docstring is_input_fp8 = False is_output_fp8 = fp8_meta["recipe"].fp8_mha if fp8: @@ -6266,6 +6689,7 @@ def forward( @staticmethod def backward(ctx, d_out): + # pylint: disable=missing-function-docstring if ctx.is_output_fp8: assert isinstance( d_out, Float8Tensor @@ -6293,12 +6717,12 @@ def backward(ctx, d_out): ) = ctx.saved_tensors if not aux_ctx_tensors[0].is_contiguous(): aux_ctx_tensors[0] = aux_ctx_tensors[0].contiguous() + rest = [None] if ctx.use_FAv2_bwd: softmax_lse, rng_state = aux_ctx_tensors dq = torch.empty_like(q) dk = torch.empty_like(k) dv = torch.empty_like(v) - maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x d_out, q, k, v, out = [maybe_contiguous(x) for x in (d_out, q, k, v, out)] flash_attn_cuda_bwd( d_out, @@ -6617,10 +7041,10 @@ def __init__( def remove_extra_states_check(self, incompatible_keys): # pylint: disable=unused-argument """ Temporarily remove fused_attention._extra_state as a missing key - or an unexpected key when loading TransformerEngine checkpoints. + or an unexpected key when loading Transformer Engine checkpoints. Please store FP8 metadata as DotProductAttention's _extra_state, rather than FusedAttention's _extra_state. This hook will be - phased out in TransformerEngine 2.0. + phased out in Transformer Engine 2.0. """ for key in incompatible_keys.missing_keys: if "fused_attention._extra_state" in key: @@ -6655,7 +7079,7 @@ def forward( core_attention_bias_type: str = "no_bias", core_attention_bias: Optional[torch.Tensor] = None, fast_zero_fill: bool = True, - cp_group: Optional[dist_group_type] = None, + cp_group: Optional[Union[dist_group_type, List[dist_group_type]]] = None, cp_global_ranks: List[int] = None, cp_stream: torch.cuda.Stream = None, cp_comm_type: str = "p2p", @@ -6677,7 +7101,12 @@ def forward( qkv_layout in QKVLayouts ), f"FusedAttention does not support qkv_layout = {qkv_layout}!" - cp_size = 1 if cp_group is None else get_distributed_world_size(cp_group) + cp_size = 1 + if isinstance(cp_group, dist_group_type): + cp_size = get_distributed_world_size(cp_group) + elif isinstance(cp_group, list): + for group in cp_group: + cp_size *= get_distributed_world_size(group) context_parallel = cp_size > 1 qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()]) @@ -6845,6 +7274,13 @@ class DotProductAttention(TransformerEngineBaseModule): and set the environment variable :attr:`NVTE_ALLOW_NONDETERMINISTIC_ALGO=0`. In order to disable`flash-attn` entirely, set :attr:`NVTE_FLASH_ATTN=0`. + .. note:: + + Transformer Engine stores the FP8 metadata under a `._extra_state` key when checkpointing. + As the FP8 attention support expands from one backend to multiple backends, the location + of that key has also shifted (see `FP8 checkpoint compatibility `_). + + Parameters ---------- num_attention_heads : int @@ -6873,7 +7309,7 @@ class DotProductAttention(TransformerEngineBaseModule): e.g. a different mask for training and inference. 1. For "`no_mask`", no attention mask is applied. 2. For "`causal`", "`causal_bottom_right`", or the causal mask in - "`padding_causal`" and "`padding_causal_bottom_right`", TransformerEngine + "`padding_causal`" and "`padding_causal_bottom_right`", Transformer Engine calculates and applies an upper triangular mask to the softmax input. No user input is needed. Causal masks without the "`bottom_right`" appendix align the diagonal line to the top left corner of the softmax matrix. With @@ -6923,8 +7359,11 @@ class DotProductAttention(TransformerEngineBaseModule): tensor parallel world size. tp_group : ProcessGroup, default = `None` tensor parallel process group. - cp_group : ProcessGroup, default = `None` + cp_group : Union[ProcessGroup, List[ProcessGroup]], default = `None` context parallel process group. + ProcessGroup is for cp_comm_type of "p2p", "all_gather", and "a2a". + List[ProcessGroup] is for cp_comm_type of "a2a+p2p", where cp_group[0] + and cp_group[1] are for a2a and p2p communications respectively. cp_global_ranks : list of global rank IDs, default = `None` global rank IDs of GPUs that are in cp_group. cp_stream : CUDA stream, default = `None` @@ -6932,15 +7371,18 @@ class DotProductAttention(TransformerEngineBaseModule): compute and communication overlapping. To address the wave quantization issue of each split step, we add an additional CUDA stream so that we can overlap two flash attention kernels. - cp_comm_type : str + cp_comm_type : str, default = `p2p` inter-gpu communication type for context parallelism. - Can be "p2p" or "all_gather" or "a2a". + Can be "p2p" or "all_gather" or "a2a" or "a2a+p2p". "p2p": Exchange KV chunks with P2P communications in ring topology. P2P is async and can be overlapped with attention compute. "all_gather": All-gather to get full sequence of KV before attention. The all-gather is not async, and cannot be overlapped. "a2a": Like DeepSpeed Ulysses, scatter attention heads across the CP group, and gather to get full sequence of QKV. + "a2a+p2p": hierarchical CP implementation. First applying a2a to QKV + across each CP sub-group (e.g., via NVLink), then exchanging KV with + p2p between sub-groups (e.g., via IBLink). """ def __init__( @@ -6958,7 +7400,7 @@ def __init__( tp_group: Optional[dist_group_type] = None, layer_number: Optional[int] = None, attention_type: str = "self", - cp_group: Optional[dist_group_type] = None, + cp_group: Optional[Union[dist_group_type, List[dist_group_type]]] = None, cp_global_ranks: List[int] = None, cp_stream: torch.cuda.Stream = None, cp_comm_type: str = "p2p", @@ -7080,8 +7522,8 @@ def __init__( def remove_extra_states_check(self, incompatible_keys): # pylint: disable=unused-argument """ Temporarily remove core_attention._extra_state as a missing key - when loading older TransformerEngine checkpoints. Will phase out - this hook in TransformerEngine 2.0. + when loading older Transformer Engine checkpoints. Will phase out + this hook in Transformer Engine 2.0. """ for key in incompatible_keys.missing_keys: if "core_attention._extra_state" in key: @@ -7089,6 +7531,28 @@ def remove_extra_states_check(self, incompatible_keys): # pylint: disable=unuse self.register_load_state_dict_post_hook(remove_extra_states_check) + def _load_from_state_dict( + self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs + ): + """ + This function helps to load Transformer Engine 1.6 and 1.7 checkpoints, where FP8 attention + metadata is stored under the `core_attention.fused_attention._extra_state` key and not the + `core_attention._extra_state` key. Please see `FP8 checkpoint compatibility + `_ for more details. + """ + fused_attn_key = False + dot_product_attn_key = False + for k in state_dict.keys(): + if "core_attention.fused_attention._extra_state" in k: + fused_attn_key = True + if "core_attention._extra_state" in k: + dot_product_attn_key = True + if fused_attn_key and not dot_product_attn_key: + prefix = prefix + "fused_attention." + super()._load_from_state_dict( + state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs + ) + def _checkpointed_attention_forward( self, attention_func: Callable, @@ -7113,7 +7577,7 @@ def custom_forward(*input_args, **input_kwargs): def set_context_parallel_group( self, - cp_group: Union[dist_group_type, None], + cp_group: Union[dist_group_type, List[dist_group_type], None], cp_global_ranks: List[int], cp_stream: torch.cuda.Stream, cp_comm_type: str = "p2p", @@ -7124,21 +7588,27 @@ def set_context_parallel_group( Parameters ---------- - cp_group : ProcessGroup + cp_group : Union[ProcessGroup, List[ProcessGroup]] context parallel process group. + ProcessGroup is for cp_comm_type of "p2p", "all_gather", and "a2a". + List[ProcessGroup] is for cp_comm_type of "a2a+p2p", where cp_group[0] + and cp_group[1] are for a2a and p2p communications respectively. cp_global_ranks : List[int] list of global ranks in the context group. cp_stream : torch.cuda.Stream cuda stream for context parallel execution. - cp_comm_type : str + cp_comm_type : str, default = `p2p` inter-gpu communication type for context parallelism. - Can be "p2p" or "all_gather" or "a2a". + Can be "p2p" or "all_gather" or "a2a" or "a2a+p2p". "p2p": Exchange KV chunks with P2P communications in ring topology. P2P is async and can be overlapped with attention compute. "all_gather": All-gather to get full sequence of KV before attention. The all-gather is not async, and cannot be overlapped. "a2a": Like DeepSpeed Ulysses, scatter attention heads across the CP group, and gather to get full sequence of QKV. + "a2a+p2p": hierarchical CP implementation. First applying a2a to QKV + across each CP sub-group (e.g., via NVLink), then exchanging KV with + p2p between sub-groups (e.g., via IBLink). """ self.cp_group = cp_group self.cp_global_ranks = cp_global_ranks @@ -7192,14 +7662,14 @@ def forward( Users can use environment variables :attr:`NVTE_FLASH_ATTN`, :attr:`NVTE_FUSED_ATTN`, and :attr:`NVTE_FUSED_ATTN_BACKEND` to control which DotProductAttention backend, - and FusedAttention backend if applicable, to use. TransformerEngine prioritizes + and FusedAttention backend if applicable, to use. Transformer Engine prioritizes FlashAttention over FusedAttention and over UnfusedDotProductAttention. If FusedAttention is being used, users can also choose to switch to flash-attn's implementation for backward by setting :attr:`NVTE_FUSED_ATTN_USE_FAv2_BWD=1` (default: 0), because of the performance differences between various versions of flash-attn and FusedAttention. Further, :attr:`NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT` can be used to enable (:attr:`1`) or disable (:attr:`0`) the workspace related - optimizations in FusedAttention. When unset, TransformerEngine determines the code path + optimizations in FusedAttention. When unset, Transformer Engine determines the code path based on its internal logic. These optimizations trade memory for performance and should be used with care. @@ -7448,7 +7918,12 @@ def forward( max_seqlen_kv = int((seqlens_kv.max().item() + 63) // 64 * 64) batch_size = len(cu_seqlens_q) - 1 - cp_size = 1 if self.cp_group is None else get_distributed_world_size(self.cp_group) + cp_size = 1 + if isinstance(self.cp_group, dist_group_type): + cp_size = get_distributed_world_size(self.cp_group) + elif isinstance(self.cp_group, list): + for group in self.cp_group: + cp_size *= get_distributed_world_size(group) context_parallel = cp_size > 1 if qkv_format in ["sbhd", "bshd"]: @@ -7458,7 +7933,7 @@ def forward( if qkv_format == "sbhd": max_seqlen_q, max_seqlen_kv = (query_layer.shape[0], key_layer.shape[0]) batch_size = query_layer.shape[1] - if qkv_format == "bshd": + else: max_seqlen_q, max_seqlen_kv = (query_layer.shape[1], key_layer.shape[1]) batch_size = query_layer.shape[0] max_seqlen_q *= cp_size @@ -7592,7 +8067,7 @@ def forward( fp8=self.fp8, fp8_meta=self.fp8_meta, ) - global _attention_backends, _flash_attn_3_plus, _use_flash_attn_3 + global _attention_backends, _use_flash_attn_3 if ( _attention_backends["attention_params"] is None or attention_params != _attention_backends["attention_params"] @@ -7600,7 +8075,7 @@ def forward( _attention_backends["attention_params"] = attention_params _attention_backends["backend_selection_requires_update"] = True if _attention_backends["backend_selection_requires_update"]: - _use_flash_attn_3 = _flash_attn_3_plus + _use_flash_attn_3 = _flash_attn_3_is_installed ( use_flash_attention, use_fused_attention, @@ -7611,7 +8086,7 @@ def forward( if use_flash_attention: self.logger.info( "Running with FlashAttention backend (version %s)", - _flash_attn_version if not _use_flash_attn_3 else _flash_attn_v3_version, + _flash_attn_version if not _use_flash_attn_3 else _flash_attn_3_version, ) elif use_fused_attention: self.logger.info( @@ -7767,7 +8242,7 @@ def forward( alibi_slopes=alibi_slopes, ) - raise Exception("No dot product attention support for the provided inputs!") + raise ValueError("No dot product attention support for the provided inputs!") class MultiheadAttention(torch.nn.Module): @@ -7960,6 +8435,8 @@ def __init__( self.params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype self.num_attention_heads = num_attention_heads self.return_bias = return_bias + self.cp_size = 1 + self.cp_rank = 0 kv_channels = kv_channels if kv_channels else (hidden_size // num_attention_heads) @@ -8121,6 +8598,7 @@ def __init__( def _allocate_memory( self, inference_max_sequence_len: int, batch_size: int, dtype: torch.dtype ) -> torch.Tensor: + """Allocates memory for KV cache.""" return torch.empty( inference_max_sequence_len, batch_size, @@ -8144,7 +8622,7 @@ def set_tensor_parallel_group(self, tp_group: Union[dist_group_type, None]) -> N def set_context_parallel_group( self, - cp_group: Union[dist_group_type, None], + cp_group: Union[dist_group_type, List[dist_group_type], None], cp_global_ranks: List[int], cp_stream: torch.cuda.Stream, cp_comm_type: str = "p2p", @@ -8155,22 +8633,43 @@ def set_context_parallel_group( Parameters ---------- - cp_group : ProcessGroup + cp_group : Union[ProcessGroup, List[ProcessGroup]] context parallel process group. + ProcessGroup is for cp_comm_type of "p2p", "all_gather", and "a2a". + List[ProcessGroup] is for cp_comm_type of "a2a+p2p", where cp_group[0] + and cp_group[1] are for a2a and p2p communications respectively. cp_global_ranks : List[int] list of global ranks in the context group. cp_stream : torch.cuda.Stream cuda stream for context parallel execution. - cp_comm_type : str + cp_comm_type : str, default = `p2p` inter-gpu communication type for context parallelism. - Can be "p2p" or "all_gather" or "a2a". + Can be "p2p" or "all_gather" or "a2a", "a2a+p2p". "p2p": Exchange KV chunks with P2P communications in ring topology. P2P is async and can be overlapped with attention compute. "all_gather": All-gather to get full sequence of KV before attention. The all-gather is not async, and cannot be overlapped. "a2a": Like DeepSpeed Ulysses, scatter attention heads across the CP group, and gather to get full sequence of QKV. + "a2a+p2p": hierarchical CP implementation. First applying a2a to QKV + across each CP sub-group (e.g., via NVLink), then exchanging KV with + p2p between sub-groups (e.g., via IBLink). """ + if isinstance(cp_group, dist_group_type): + self.cp_size = get_distributed_world_size(cp_group) + self.cp_rank = get_distributed_rank(cp_group) + elif isinstance(cp_group, list): + assert len(cp_group) == 2, "Current implementation only supports two-level CP groups!" + assert ( + cp_comm_type == "a2a+p2p" + ), "Only cp_comm_type of a2a+p2p requires hierarchical CP groups!" + cp_size_a2a = get_distributed_world_size(cp_group[0]) + cp_rank_a2a = get_distributed_rank(cp_group[0]) + cp_size_p2p = get_distributed_world_size(cp_group[1]) + cp_rank_p2p = get_distributed_rank(cp_group[1]) + self.cp_size = cp_size_a2a * cp_size_p2p + self.cp_rank = cp_size_a2a * cp_rank_p2p + cp_rank_a2a + # Deep iterate but skip self to avoid infinite recursion. for index, child in enumerate(self.modules()): if index == 0: @@ -8285,10 +8784,8 @@ def forward( window_size = check_set_window_size(attn_mask_type, window_size) if "padding" in attn_mask_type and attention_mask is not None: - for i, _ in enumerate(attention_mask): - assert ( - attention_mask[i].dtype == torch.bool - ), "Attention mask must be in boolean type!" + for mask in attention_mask: + assert mask.dtype == torch.bool, "Attention mask must be in boolean type!" assert ( core_attention_bias_type in AttnBiasTypes @@ -8330,6 +8827,7 @@ def forward( and FP8GlobalStateManager.get_fp8_recipe().fp8_mha ) + layernorm_output = None if self.attention_type == "self": # Attention heads [sq, b, h] --> [sq, b, ng * (np/ng + 2) * hn] if self.input_layernorm: @@ -8497,6 +8995,8 @@ def forward( sequence_length = key_layer.size(0) elif self.qkv_format == "bshd": sequence_length = key_layer.size(1) + else: + raise ValueError(f"QKV format {self.qkv_format} not supported for KV caching.") sequence_start = inference_params.sequence_len_offset sequence_end = sequence_start + sequence_length @@ -8504,8 +9004,24 @@ def forward( q_pos_emb = q_pos_emb[sequence_start:sequence_end, ...] k_pos_emb = k_pos_emb[sequence_start:sequence_end, ...] - query_layer = apply_rotary_pos_emb(query_layer, q_pos_emb, self.qkv_format, fused=True) - key_layer = apply_rotary_pos_emb(key_layer, k_pos_emb, self.qkv_format, fused=True) + query_layer = apply_rotary_pos_emb( + query_layer, + q_pos_emb, + self.qkv_format, + fused=True, + cu_seqlens=cu_seqlens_q, + cp_size=self.cp_size, + cp_rank=self.cp_rank, + ) + key_layer = apply_rotary_pos_emb( + key_layer, + k_pos_emb, + self.qkv_format, + fused=True, + cu_seqlens=cu_seqlens_kv, + cp_size=self.cp_size, + cp_rank=self.cp_rank, + ) # =========================== # Core attention computation diff --git a/transformer_engine/pytorch/cpp_extensions/_common.py b/transformer_engine/pytorch/cpp_extensions/_common.py index b9d7288dfa..8f7e72e268 100644 --- a/transformer_engine/pytorch/cpp_extensions/_common.py +++ b/transformer_engine/pytorch/cpp_extensions/_common.py @@ -78,10 +78,10 @@ def canonicalize_fp8_scales( scale_inv_offset = 0 # Pack tensors and offsets into dicts - tensors = dict(scale=scale, amax=amax, scale_inv=scale_inv) - offsets = dict( - scale_offset=scale_offset, - amax_offset=amax_offset, - scale_inv_offset=scale_inv_offset, - ) + tensors = {"scale": scale, "amax": amax, "scale_inv": scale_inv} + offsets = { + "scale_offset": scale_offset, + "amax_offset": amax_offset, + "scale_inv_offset": scale_inv_offset, + } return tensors, offsets diff --git a/transformer_engine/pytorch/cpp_extensions/fused_attn.py b/transformer_engine/pytorch/cpp_extensions/fused_attn.py index cd0ecbaa6c..1932e9feb2 100644 --- a/transformer_engine/pytorch/cpp_extensions/fused_attn.py +++ b/transformer_engine/pytorch/cpp_extensions/fused_attn.py @@ -240,13 +240,11 @@ def fused_attn_fwd_qkvpacked( rng_elts_per_thread = ( max_seqlen * max_seqlen + BACKEND_F16m512_FP8_THREADS_PER_CTA - 1 ) // BACKEND_F16m512_FP8_THREADS_PER_CTA - # BF16/FP16 fused attention API from fmha_v2 - if fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]: + elif fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]: rng_elts_per_thread = BACKEND_F16arb_ELTS_PER_THREADS - # FP8 fused attention API from fmha_v2 - if fused_attention_backend == FusedAttnBackend["FP8"]: + elif fused_attention_backend == FusedAttnBackend["FP8"]: rng_elts_per_thread = ( max_seqlen * max_seqlen + BACKEND_F16m512_FP8_THREADS_PER_CTA - 1 ) // BACKEND_F16m512_FP8_THREADS_PER_CTA @@ -259,6 +257,8 @@ def fused_attn_fwd_qkvpacked( assert q_scale_o is not None, "q_scale_o is required as an input for FP8 fused attention." assert amax_s is not None, "amax_s is required as an input for FP8 fused attention." assert amax_o is not None, "amax_o is required as an input for FP8 fused attention." + else: + raise ValueError(f"Unsupported backend {fused_attention_backend}") # execute kernel output_tensors = tex.fused_attn_fwd_qkvpacked( @@ -633,13 +633,11 @@ def fused_attn_fwd_kvpacked( rng_elts_per_thread = ( max_seqlen_q * max_seqlen_kv + BACKEND_F16m512_FP8_THREADS_PER_CTA - 1 ) // BACKEND_F16m512_FP8_THREADS_PER_CTA - # BF16/FP16 fused attention API from fmha_v2 - if fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]: + elif fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]: rng_elts_per_thread = BACKEND_F16arb_ELTS_PER_THREADS - # FP8 fused attention API from fmha_v2 - if fused_attention_backend == FusedAttnBackend["FP8"]: + elif fused_attention_backend == FusedAttnBackend["FP8"]: rng_elts_per_thread = ( max_seqlen_q * max_seqlen_q + BACKEND_F16m512_FP8_THREADS_PER_CTA - 1 ) // BACKEND_F16m512_FP8_THREADS_PER_CTA @@ -652,6 +650,8 @@ def fused_attn_fwd_kvpacked( assert q_scale_o is not None, "q_scale_o is required as an input for FP8 fused attention." assert amax_s is not None, "amax_s is required as an input for FP8 fused attention." assert amax_o is not None, "amax_o is required as an input for FP8 fused attention." + else: + raise ValueError(f"Unsupported backend {fused_attention_backend}") # execute kernel output_tensors = tex.fused_attn_fwd_kvpacked( @@ -1058,13 +1058,11 @@ def fused_attn_fwd( rng_elts_per_thread = ( max_seqlen_q * max_seqlen_kv + BACKEND_F16m512_FP8_THREADS_PER_CTA - 1 ) // BACKEND_F16m512_FP8_THREADS_PER_CTA - # BF16/FP16 fused attention API from fmha_v2 - if fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]: + elif fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]: rng_elts_per_thread = BACKEND_F16arb_ELTS_PER_THREADS - # FP8 fused attention API from fmha_v2 - if fused_attention_backend == FusedAttnBackend["FP8"]: + elif fused_attention_backend == FusedAttnBackend["FP8"]: rng_elts_per_thread = ( max_seqlen_q * max_seqlen_q + BACKEND_F16m512_FP8_THREADS_PER_CTA - 1 ) // BACKEND_F16m512_FP8_THREADS_PER_CTA @@ -1077,6 +1075,8 @@ def fused_attn_fwd( assert q_scale_o is not None, "q_scale_o is required as an input for FP8 fused attention." assert amax_s is not None, "amax_s is required as an input for FP8 fused attention." assert amax_o is not None, "amax_o is required as an input for FP8 fused attention." + else: + raise ValueError(f"Unsupported backend {fused_attention_backend}") # execute kernel output_tensors = tex.fused_attn_fwd( diff --git a/transformer_engine/pytorch/cpu_offload.py b/transformer_engine/pytorch/cpu_offload.py index 4e9c74d396..123758b0da 100644 --- a/transformer_engine/pytorch/cpu_offload.py +++ b/transformer_engine/pytorch/cpu_offload.py @@ -16,6 +16,11 @@ CPUOffloadEnabled = False +def is_cpu_offload_enabled() -> bool: + """Check if CPU offloading is currently enabled.""" + return CPUOffloadEnabled + + class CpuOffloadSavedTensorHook: """Contex-manager that executes a pair of pack/unpack hooks for saved tensors. @@ -156,6 +161,7 @@ class GroupCommitFunction(torch.autograd.Function): @staticmethod def forward(ctx, tensor, cpu_offload_handler): + # pylint: disable=missing-function-docstring cpu_offload_handler.on_group_commit_forward() ctx.cpu_offload_handler = cpu_offload_handler # return the identical tensor @@ -163,6 +169,7 @@ def forward(ctx, tensor, cpu_offload_handler): @staticmethod def backward(ctx, grad_output): + # pylint: disable=missing-function-docstring cpu_offload_handler = ctx.cpu_offload_handler cpu_offload_handler.on_group_commit_backward() return grad_output, None diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index c797208e06..c30e583178 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -412,10 +412,10 @@ at::Tensor fused_rope_backward(const at::Tensor &output_grads, const at::Tensor const bool transpose_output_memory); at::Tensor fused_rope_thd_forward(const at::Tensor &input, const at::Tensor &cu_seqlens, - const at::Tensor &freqs); + const at::Tensor &freqs, const int cp_size, const int cp_rank); at::Tensor fused_rope_thd_backward(const at::Tensor &output_grads, const at::Tensor &cu_seqlens, - const at::Tensor &freqs); + const at::Tensor &freqs, const int cp_size, const int cp_rank); /*************************************************************************************************** * Miscellaneous @@ -433,14 +433,14 @@ at::Tensor thd_read_half_tensor(const at::Tensor &tensor, const at::Tensor &cu_s int half_idx); void thd_second_half_lse_correction(at::Tensor lse, const at::Tensor &lse_per_step, - const at::Tensor &cu_seqlens, int total_tokens); + const at::Tensor &cu_seqlens, bool lse_packed); at::Tensor thd_read_second_half_lse(const at::Tensor &lse, const at::Tensor &cu_seqlens, - int total_tokens); + bool lse_packed); void thd_out_correction(at::Tensor out, const at::Tensor &out_per_step, const at::Tensor &lse, const at::Tensor &lse_per_step, const at::Tensor &cu_seqlens, - bool only_second_half); + bool only_second_half, bool lse_packed); void thd_grad_correction(at::Tensor grad, const at::Tensor &grad_per_step, const at::Tensor &cu_seqlens, const std::string &first_half, diff --git a/transformer_engine/pytorch/csrc/extensions/apply_rope.cu b/transformer_engine/pytorch/csrc/extensions/apply_rope.cu index c58ba91d5e..c0cd2e9920 100644 --- a/transformer_engine/pytorch/csrc/extensions/apply_rope.cu +++ b/transformer_engine/pytorch/csrc/extensions/apply_rope.cu @@ -121,7 +121,7 @@ at::Tensor fused_rope_backward(const at::Tensor &output_grads, const at::Tensor } at::Tensor fused_rope_thd_forward(const at::Tensor &input, const at::Tensor &cu_seqlens, - const at::Tensor &freqs) { + const at::Tensor &freqs, const int cp_size, const int cp_rank) { using namespace transformer_engine; TORCH_CHECK(input.dim() == 3, "expected 3D tensor"); TORCH_CHECK(cu_seqlens.dim() == 1, "expected 1D tensor"); @@ -165,14 +165,15 @@ at::Tensor fused_rope_thd_forward(const at::Tensor &input, const at::Tensor &cu_ auto output_cu = makeTransformerEngineTensor(output); nvte_fused_rope_thd_forward(input_cu.data(), cu_seqlens_cu.data(), freqs_cu.data(), - output_cu.data(), max_s, b, h, d, d2, stride_t, stride_h, stride_d, - o_stride_t, o_stride_h, o_stride_d, at::cuda::getCurrentCUDAStream()); + output_cu.data(), cp_size, cp_rank, max_s, b, h, d, d2, stride_t, + stride_h, stride_d, o_stride_t, o_stride_h, o_stride_d, + at::cuda::getCurrentCUDAStream()); return output; } at::Tensor fused_rope_thd_backward(const at::Tensor &output_grads, const at::Tensor &cu_seqlens, - const at::Tensor &freqs) { + const at::Tensor &freqs, const int cp_size, const int cp_rank) { using namespace transformer_engine; TORCH_CHECK(output_grads.dim() == 3, "expected 3D tensor"); TORCH_CHECK(cu_seqlens.dim() == 1, "expected 1D tensor"); @@ -214,8 +215,8 @@ at::Tensor fused_rope_thd_backward(const at::Tensor &output_grads, const at::Ten auto input_grads_cu = makeTransformerEngineTensor(input_grads); nvte_fused_rope_thd_backward(output_grads_cu.data(), cu_seqlens_cu.data(), freqs_cu.data(), - input_grads_cu.data(), max_s, b, h, d, d2, stride_t, stride_h, - stride_d, o_stride_t, o_stride_h, o_stride_d, + input_grads_cu.data(), cp_size, cp_rank, max_s, b, h, d, d2, + stride_t, stride_h, stride_d, o_stride_t, o_stride_h, o_stride_d, at::cuda::getCurrentCUDAStream()); return input_grads; diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cu b/transformer_engine/pytorch/csrc/extensions/attention.cu index b2968a688d..8088a2b8f1 100644 --- a/transformer_engine/pytorch/csrc/extensions/attention.cu +++ b/transformer_engine/pytorch/csrc/extensions/attention.cu @@ -1464,9 +1464,9 @@ at::Tensor thd_read_half_tensor(const at::Tensor &tensor, const at::Tensor &cu_s * Support THD format for Context Parallel: softmax_lse related operations **************************************************************************************************/ -template +template __global__ void thd_lse_kernel(lse_dtype *lse, float *half_lse, int *cu_seqlens, int batch, - int num_heads, int max_seqlen) { + int num_heads, int total_tokens) { extern __shared__ int cu_seqlens_s[]; for (int i = threadIdx.x; i <= batch; i += blockDim.x) { cu_seqlens_s[i] = cu_seqlens[i] / 2; @@ -1480,12 +1480,18 @@ __global__ void thd_lse_kernel(lse_dtype *lse, float *half_lse, int *cu_seqlens, for (int token_id = tid; token_id < num_total_tokens; token_id += num_threads) { int seq_id = binary_search(token_id, cu_seqlens_s, batch + 1); for (int head_id = blockIdx.y; head_id < num_heads; head_id += gridDim.y) { - size_t row = static_cast(seq_id) * num_heads + head_id; - int col = token_id - cu_seqlens_s[seq_id]; - int seq_len = cu_seqlens_s[seq_id + 1] - cu_seqlens_s[seq_id]; + size_t idx, half_idx; + if constexpr (lse_packed) { + idx = head_id * total_tokens + token_id + cu_seqlens_s[seq_id + 1]; + half_idx = head_id * total_tokens / 2 + token_id; + } else { + size_t row = static_cast(seq_id) * num_heads + head_id; + int col = token_id - cu_seqlens_s[seq_id]; + int seq_len = cu_seqlens_s[seq_id + 1] - cu_seqlens_s[seq_id]; - size_t idx = row * max_seqlen + col + seq_len; - size_t half_idx = row * max_seqlen / 2 + col; + idx = row * total_tokens + col + seq_len; + half_idx = row * total_tokens / 2 + col; + } Functor::run(lse, half_lse, idx, half_idx); } @@ -1504,32 +1510,53 @@ struct LseCorrectionFunctor { }; void thd_second_half_lse_correction(at::Tensor lse, const at::Tensor &lse_per_step, - const at::Tensor &cu_seqlens, int total_tokens) { + const at::Tensor &cu_seqlens, bool lse_packed) { NVTE_CHECK(lse.scalar_type() == at::ScalarType::Double); NVTE_CHECK(lse_per_step.scalar_type() == at::ScalarType::Float); NVTE_CHECK(cu_seqlens.scalar_type() == at::ScalarType::Int); - - NVTE_CHECK(lse.dim() == 3); - NVTE_CHECK(lse_per_step.dim() == 3); NVTE_CHECK(cu_seqlens.dim() == 1); - int batch = lse.size(0); - int num_heads = lse.size(1); - int max_seqlen = lse.size(2); + int batch, num_heads, total_tokens; + + if (lse_packed) { + NVTE_CHECK(lse.dim() == 2); + NVTE_CHECK(lse_per_step.dim() == 2); + + batch = cu_seqlens.size(0) - 1; + num_heads = lse.size(0); + total_tokens = lse.size(1); + + NVTE_CHECK(lse_per_step.size(0) == num_heads); + NVTE_CHECK(lse_per_step.size(1) == total_tokens / 2); + } else { + NVTE_CHECK(lse.dim() == 3); + NVTE_CHECK(lse_per_step.dim() == 3); + + batch = lse.size(0); + num_heads = lse.size(1); + total_tokens = lse.size(2); - NVTE_CHECK(lse_per_step.size(0) == batch); - NVTE_CHECK(lse_per_step.size(1) == num_heads); - NVTE_CHECK(lse_per_step.size(2) == max_seqlen / 2); - NVTE_CHECK(cu_seqlens.size(0) == batch + 1); + NVTE_CHECK(lse_per_step.size(0) == batch); + NVTE_CHECK(lse_per_step.size(1) == num_heads); + NVTE_CHECK(lse_per_step.size(2) == total_tokens / 2); + NVTE_CHECK(cu_seqlens.size(0) == batch + 1); + } constexpr unsigned int block = 256; unsigned int grid_x = (total_tokens / 2 + block - 1) / block; unsigned int grid_y = num_heads; dim3 grid = {grid_x, grid_y}; - thd_lse_kernel - <<>>( - lse.data_ptr(), lse_per_step.data_ptr(), cu_seqlens.data_ptr(), batch, - num_heads, max_seqlen); + if (lse_packed) { + thd_lse_kernel + <<>>( + lse.data_ptr(), lse_per_step.data_ptr(), cu_seqlens.data_ptr(), + batch, num_heads, total_tokens); + } else { + thd_lse_kernel + <<>>( + lse.data_ptr(), lse_per_step.data_ptr(), cu_seqlens.data_ptr(), + batch, num_heads, total_tokens); + } } struct ReadLseFunctor { @@ -1540,29 +1567,51 @@ struct ReadLseFunctor { }; at::Tensor thd_read_second_half_lse(const at::Tensor &lse, const at::Tensor &cu_seqlens, - int total_tokens) { + bool lse_packed) { NVTE_CHECK(lse.scalar_type() == at::ScalarType::Float); - NVTE_CHECK(lse.dim() == 3); NVTE_CHECK(cu_seqlens.scalar_type() == at::ScalarType::Int); NVTE_CHECK(cu_seqlens.dim() == 1); - int batch = lse.size(0); - int num_heads = lse.size(1); - int max_seqlen = lse.size(2); + int batch, num_heads, total_tokens; + std::vector shape; + + if (lse_packed) { + NVTE_CHECK(lse.dim() == 2); + + batch = cu_seqlens.size(0) - 1; + num_heads = lse.size(0); + total_tokens = lse.size(1); + + shape = {num_heads, total_tokens / 2}; + } else { + NVTE_CHECK(lse.dim() == 3); - NVTE_CHECK(cu_seqlens.size(0) == batch + 1); + batch = lse.size(0); + num_heads = lse.size(1); + total_tokens = lse.size(2); + + NVTE_CHECK(cu_seqlens.size(0) == batch + 1); + + shape = {batch, num_heads, total_tokens / 2}; + } - std::vector shape = {batch, num_heads, max_seqlen / 2}; at::Tensor half_lse = at::zeros(shape, at::CUDA(lse.scalar_type())); constexpr unsigned int block = 256; unsigned int grid_x = (total_tokens / 2 + block - 1) / block; unsigned int grid_y = num_heads; dim3 grid = {grid_x, grid_y}; - thd_lse_kernel - <<>>( - lse.data_ptr(), half_lse.data_ptr(), cu_seqlens.data_ptr(), batch, - num_heads, max_seqlen); + if (lse_packed) { + thd_lse_kernel + <<>>( + lse.data_ptr(), half_lse.data_ptr(), cu_seqlens.data_ptr(), batch, + num_heads, total_tokens); + } else { + thd_lse_kernel + <<>>( + lse.data_ptr(), half_lse.data_ptr(), cu_seqlens.data_ptr(), batch, + num_heads, total_tokens); + } return half_lse; } @@ -1571,10 +1620,10 @@ at::Tensor thd_read_second_half_lse(const at::Tensor &lse, const at::Tensor &cu_ * Support THD format for Context Parallel: Out correction in forward **************************************************************************************************/ -template +template __global__ void thd_out_correction_kernel(dtype *out, dtype *out_per_step, float *lse, float *lse_per_step, int *cu_seqlens, int batch, - int num_heads, int dim_per_head, int max_seqlen) { + int num_heads, int dim_per_head, int lse_seqlen) { extern __shared__ int cu_seqlens_s[]; for (int i = threadIdx.x; i <= batch; i += blockDim.x) { cu_seqlens_s[i] = cu_seqlens[i] / (only_second_half + 1); @@ -1592,11 +1641,16 @@ __global__ void thd_out_correction_kernel(dtype *out, dtype *out_per_step, float for (int head_id = blockIdx.y; head_id < num_heads; head_id += gridDim.y) { size_t idx, idx_per_step; - size_t row = static_cast(seq_id) * num_heads + head_id; - int col = token_id - cu_seqlens_s[seq_id]; - int seq_len = cu_seqlens_s[seq_id + 1] - cu_seqlens_s[seq_id]; - idx = row * max_seqlen + col + seq_len * only_second_half; - idx_per_step = row * max_seqlen / (only_second_half + 1) + col; + if constexpr (lse_packed) { + idx = head_id * lse_seqlen + token_id + cu_seqlens_s[seq_id + 1] * only_second_half; + idx_per_step = head_id * lse_seqlen / (only_second_half + 1) + token_id; + } else { + size_t row = static_cast(seq_id) * num_heads + head_id; + int col = token_id - cu_seqlens_s[seq_id]; + int seq_len = cu_seqlens_s[seq_id + 1] - cu_seqlens_s[seq_id]; + idx = row * lse_seqlen + col + seq_len * only_second_half; + idx_per_step = row * lse_seqlen / (only_second_half + 1) + col; + } float lse_corrected_exp = exp(lse_per_step[idx_per_step] - lse[idx]); idx = token_id + cu_seqlens_s[seq_id + 1] * only_second_half; @@ -1622,7 +1676,7 @@ __global__ void thd_out_correction_kernel(dtype *out, dtype *out_per_step, float template static void thd_out_correction_helper(at::Tensor out, const at::Tensor &out_per_step, const at::Tensor &lse, const at::Tensor &lse_per_step, - const at::Tensor &cu_seqlens) { + const at::Tensor &cu_seqlens, bool lse_packed) { NVTE_CHECK(out.scalar_type() == out_per_step.scalar_type()); NVTE_CHECK(lse.scalar_type() == at::ScalarType::Float); NVTE_CHECK(lse_per_step.scalar_type() == at::ScalarType::Float); @@ -1631,17 +1685,30 @@ static void thd_out_correction_helper(at::Tensor out, const at::Tensor &out_per_ int total_tokens = out.size(0); int num_heads = out.size(1); int dim_per_head = out.size(2); - int batch = lse.size(0); - int max_seqlen = lse.size(2); NVTE_CHECK(out_per_step.size(0) == total_tokens / (only_second_half + 1)); NVTE_CHECK(out_per_step.size(1) == num_heads); NVTE_CHECK(out_per_step.size(2) == dim_per_head); - NVTE_CHECK(lse.size(1) == num_heads); - NVTE_CHECK(lse_per_step.size(0) == batch); - NVTE_CHECK(lse_per_step.size(1) == num_heads); - NVTE_CHECK(lse_per_step.size(2) == max_seqlen / (only_second_half + 1)); - NVTE_CHECK(cu_seqlens.size(0) == batch + 1); + + int batch, lse_seqlen; + if (lse_packed) { + batch = cu_seqlens.size(0) - 1; + lse_seqlen = total_tokens; + + NVTE_CHECK(lse.size(0) == num_heads); + NVTE_CHECK(lse.size(1) == lse_seqlen); + NVTE_CHECK(lse_per_step.size(0) == num_heads); + NVTE_CHECK(lse_per_step.size(1) == lse_seqlen / (only_second_half + 1)); + } else { + batch = lse.size(0); + lse_seqlen = lse.size(2); + + NVTE_CHECK(lse.size(1) == num_heads); + NVTE_CHECK(lse_per_step.size(0) == batch); + NVTE_CHECK(lse_per_step.size(1) == num_heads); + NVTE_CHECK(lse_per_step.size(2) == lse_seqlen / (only_second_half + 1)); + NVTE_CHECK(cu_seqlens.size(0) == batch + 1); + } constexpr int tile = 16; constexpr int block = 512; @@ -1649,39 +1716,53 @@ static void thd_out_correction_helper(at::Tensor out, const at::Tensor &out_per_ (static_cast(total_tokens) / (only_second_half + 1) * tile + block - 1) / block; dim3 grid = {grid_x, (unsigned int)num_heads}; - thd_out_correction_kernel - <<>>( - out.data_ptr(), out_per_step.data_ptr(), lse.data_ptr(), - lse_per_step.data_ptr(), cu_seqlens.data_ptr(), batch, num_heads, - dim_per_head, max_seqlen); + if (lse_packed) { + thd_out_correction_kernel + <<>>( + out.data_ptr(), out_per_step.data_ptr(), lse.data_ptr(), + lse_per_step.data_ptr(), cu_seqlens.data_ptr(), batch, num_heads, + dim_per_head, lse_seqlen); + } else { + thd_out_correction_kernel + <<>>( + out.data_ptr(), out_per_step.data_ptr(), lse.data_ptr(), + lse_per_step.data_ptr(), cu_seqlens.data_ptr(), batch, num_heads, + dim_per_head, lse_seqlen); + } } void thd_out_correction(at::Tensor out, const at::Tensor &out_per_step, const at::Tensor &lse, const at::Tensor &lse_per_step, const at::Tensor &cu_seqlens, - bool only_second_half) { + bool only_second_half, bool lse_packed) { if (only_second_half) { if (out.scalar_type() == at::ScalarType::Half) { using dtype = at::Half; - thd_out_correction_helper(out, out_per_step, lse, lse_per_step, cu_seqlens); + thd_out_correction_helper(out, out_per_step, lse, lse_per_step, cu_seqlens, + lse_packed); } else if (out.scalar_type() == at::ScalarType::BFloat16) { using dtype = at::BFloat16; - thd_out_correction_helper(out, out_per_step, lse, lse_per_step, cu_seqlens); + thd_out_correction_helper(out, out_per_step, lse, lse_per_step, cu_seqlens, + lse_packed); } else if (out.scalar_type() == at::ScalarType::Float) { using dtype = float; - thd_out_correction_helper(out, out_per_step, lse, lse_per_step, cu_seqlens); + thd_out_correction_helper(out, out_per_step, lse, lse_per_step, cu_seqlens, + lse_packed); } else { NVTE_ERROR("Unsupported dtype of out\n"); } } else { if (out.scalar_type() == at::ScalarType::Half) { using dtype = at::Half; - thd_out_correction_helper(out, out_per_step, lse, lse_per_step, cu_seqlens); + thd_out_correction_helper(out, out_per_step, lse, lse_per_step, cu_seqlens, + lse_packed); } else if (out.scalar_type() == at::ScalarType::BFloat16) { using dtype = at::BFloat16; - thd_out_correction_helper(out, out_per_step, lse, lse_per_step, cu_seqlens); + thd_out_correction_helper(out, out_per_step, lse, lse_per_step, cu_seqlens, + lse_packed); } else if (out.scalar_type() == at::ScalarType::Float) { using dtype = float; - thd_out_correction_helper(out, out_per_step, lse, lse_per_step, cu_seqlens); + thd_out_correction_helper(out, out_per_step, lse, lse_per_step, cu_seqlens, + lse_packed); } else { NVTE_ERROR("Unsupported dtype of out\n"); } diff --git a/transformer_engine/pytorch/csrc/extensions/gemm.cu b/transformer_engine/pytorch/csrc/extensions/gemm.cu index ba9851e7e8..40b96a057f 100644 --- a/transformer_engine/pytorch/csrc/extensions/gemm.cu +++ b/transformer_engine/pytorch/csrc/extensions/gemm.cu @@ -15,10 +15,16 @@ void te_gemm(at::Tensor A, at::Tensor A_scale_inverse, transformer_engine::DType at::Tensor workspace, size_t workspaceSize, bool accumulate, bool use_split_accumulator, int math_sm_count) { using namespace transformer_engine; - if (A.data_ptr() == nullptr || B.data_ptr() == nullptr) { - if (D.data_ptr() != nullptr && !accumulate) D.zero_(); - if (bias.data_ptr() != nullptr) bias.zero_(); - if (pre_gelu_out.data_ptr() != nullptr) pre_gelu_out.zero_(); + if (A.numel() == 0 || B.numel() == 0) { + if (D.numel() != 0 && !accumulate) D.zero_(); + if (bias.numel() != 0 && grad) { + if (B.numel() == 0) { + bias.zero_(); + } else { + bias.copy_(B.sum(0)); + } + } + if (pre_gelu_out.numel() != 0) pre_gelu_out.zero_(); return; } @@ -109,10 +115,16 @@ void te_grouped_gemm(std::vector A, at::Tensor A_scale_inverse, int return tensor_wrappers.back().data(); }; for (size_t i = 0; i < A.size(); i++) { - if (A[i].data_ptr() == nullptr || B[i].data_ptr() == nullptr) { - if (D[i].data_ptr() != nullptr && !accumulate) D[i].zero_(); - if (bias[i].data_ptr() != nullptr) bias[i].zero_(); - if (pre_gelu_out[i].data_ptr() != nullptr) pre_gelu_out[i].zero_(); + if (A[i].numel() == 0 || B[i].numel() == 0) { + if (D[i].numel() != 0 && !accumulate) D[i].zero_(); + if (bias[i].numel() != 0 && grad) { + if (B[i].numel() == 0) { + bias[i].zero_(); + } else { + bias[i].copy_(B[i].sum(0)); + } + } + if (pre_gelu_out[i].numel() != 0) pre_gelu_out[i].zero_(); continue; } @@ -175,6 +187,8 @@ void te_grouped_gemm_single_output( void* d_i_ptr = reinterpret_cast(D.data_ptr()); for (size_t i = 0; i < A.size(); i++) { if (m_splits[i] == 0) continue; + NVTE_CHECK(A[i].data_ptr() != nullptr, "A[", i, "] must not be nullptr."); + NVTE_CHECK(B[i].data_ptr() != nullptr, "B[", i, "] must not be nullptr."); NVTE_CHECK(A[i].is_contiguous(), "A[", i, "] must be contiguous."); NVTE_CHECK(B[i].is_contiguous(), "B[", i, "] must be contiguous."); te_A.emplace_back(make_tensor( diff --git a/transformer_engine/pytorch/distributed.py b/transformer_engine/pytorch/distributed.py index e9fb11e3b9..490ac3b160 100644 --- a/transformer_engine/pytorch/distributed.py +++ b/transformer_engine/pytorch/distributed.py @@ -6,6 +6,7 @@ from __future__ import annotations from contextlib import contextmanager, AbstractContextManager, ContextDecorator +from functools import lru_cache from typing import Any, Callable, Dict, List, Optional, Tuple, Union import warnings @@ -125,6 +126,7 @@ def set_tensor_model_parallel_attributes( setattr(tensor, "partition_stride", stride) +@lru_cache def get_distributed_world_size(group: Optional[dist_group_type] = None) -> int: """Return world size for the distributed group.""" if not torch.distributed.is_initialized(): @@ -132,6 +134,7 @@ def get_distributed_world_size(group: Optional[dist_group_type] = None) -> int: return torch.distributed.get_world_size(group=group) +@lru_cache def get_distributed_rank(group: Optional[dist_group_type] = None) -> int: """Return my rank for the distributed group.""" assert torch.distributed.is_initialized(), "torch.distributed is not initialized." @@ -203,6 +206,8 @@ class activation_recompute_forward(AbstractContextManager, ContextDecorator): activations, followed by calculation of gradients using these values. """ + _is_first_fp8_module: List = [] + def __init__(self, activation_recompute: bool = False, recompute_phase: bool = False): super().__init__() self.activation_recompute = activation_recompute @@ -215,6 +220,15 @@ def __enter__(self): ) _FP8_ACTIVATION_RECOMPUTE_PHASE = self.recompute_phase + if self.activation_recompute and not self.recompute_phase: + activation_recompute_forward._is_first_fp8_module.append( + FP8GlobalStateManager.IS_FIRST_FP8_MODULE + ) + if self.activation_recompute and self.recompute_phase: + FP8GlobalStateManager.IS_FIRST_FP8_MODULE = ( + activation_recompute_forward._is_first_fp8_module.pop(0) + ) + def __exit__(self, *exc_details): global _FP8_ACTIVATION_RECOMPUTE_ENABLED, _FP8_ACTIVATION_RECOMPUTE_PHASE _FP8_ACTIVATION_RECOMPUTE_ENABLED = False @@ -749,11 +763,11 @@ def add(self, name: str, seed: int) -> None: """ # Check seed is not already used. if seed in self.seeds_: - raise Exception(f"seed {seed} already exists") + raise RuntimeError(f"seed {seed} already exists") self.seeds_.add(seed) # Check that state is not already defined. if name in self.states_: - raise Exception(f"cuda rng state {name} already exists") + raise RuntimeError(f"cuda rng state {name} already exists") if graph_safe_rng_available(): new_state = _get_cuda_rng_state(clone=True) @@ -783,7 +797,7 @@ def fork(self, name: str = "model-parallel-rng"): """ # Check if we have added the state if name not in self.states_: - raise Exception(f"cuda rng state {name} is not added") + raise KeyError(f"cuda rng state {name} is not added") # Get the reference to current rng state. orig_cuda_rng_state = _get_cuda_rng_state() # Set rng state to the desired one diff --git a/transformer_engine/pytorch/graph.py b/transformer_engine/pytorch/graph.py index e2642bc360..cba71e1326 100644 --- a/transformer_engine/pytorch/graph.py +++ b/transformer_engine/pytorch/graph.py @@ -61,6 +61,7 @@ def _make_graphed_callables( fp8_weight_caching: bool = False, sample_kwargs: Optional[SingleOrTuple[Dict[str, Any]]] = None, _order: Optional[List[int]] = None, + pool: Optional[Tuple[int, ...]] = None, ) -> SingleOrTuple[Callable]: """ Helper method for `make_graphed_callables` @@ -193,7 +194,7 @@ def _make_graphed_callables( fwd_graph.register_generator_state(state) bwd_graph.register_generator_state(state) - mempool = graph_pool_handle() + mempool = graph_pool_handle() if pool is None else pool # Warmup # Hopefully prevents cudnn benchmarking and other lazy-initialization cuda work @@ -357,6 +358,7 @@ class Graphed(torch.autograd.Function): @staticmethod def forward(ctx, skip_fp8_weight_update, *inputs): + # pylint: disable=missing-function-docstring # Set flag for whether to update FP8 weight updates ctx.is_first_module = FP8GlobalStateManager.is_first_fp8_module() @@ -376,6 +378,7 @@ def forward(ctx, skip_fp8_weight_update, *inputs): @staticmethod @torch.autograd.function.once_differentiable def backward(ctx, *grads): + # pylint: disable=missing-function-docstring # Replay backward graph assert len(grads) == len(static_grad_outputs) @@ -518,6 +521,7 @@ def make_graphed_callables( fp8_recipe: Optional[DelayedScaling] = None, fp8_weight_caching: bool = False, _order: Optional[List[int]] = None, + pool: Optional[Tuple[int, ...]] = None, ) -> Union[Callable, Tuple[Callable, ...]]: """ Make CUDA graph version of Transformer Engine modules @@ -541,6 +545,9 @@ def make_graphed_callables( and outputs are disconnected in compute graph. sample_kwargs: (tuple of) dict, optional Keyword arguments to callable(s) + pool: (tuple of) int, default = `None`, optional + An instance returned from function `torch.cuda.graph_pool_handle` that hints + this graph may share memory with the indicated pool. FP8-related parameters ---------------------- @@ -617,6 +624,7 @@ def forward_func(*args, **kwargs): fp8_weight_caching=fp8_weight_caching, sample_kwargs=sample_kwargs, _order=_order, + pool=pool, ) # Ensures warmup does not affect numerics for ops such as dropout. diff --git a/transformer_engine/pytorch/jit.py b/transformer_engine/pytorch/jit.py index 1646847162..ed08627e95 100644 --- a/transformer_engine/pytorch/jit.py +++ b/transformer_engine/pytorch/jit.py @@ -8,6 +8,8 @@ import torch +# pylint: disable=unnecessary-lambda-assignment + jit_fuser = torch.jit.script if torch.__version__ >= "2" and bool(int(os.getenv("NVTE_TORCH_COMPILE", "1"))): jit_fuser = torch.compile @@ -109,7 +111,7 @@ def dgelu_fused_(grad_output: torch.Tensor, inp: torch.Tensor) -> torch.Tensor: def bias_gelu_fused(inp: torch.Tensor, bias: torch.Tensor) -> torch.Tensor: """Disable native AMP for bias_gelu_fused_""" with torch.cuda.amp.autocast(enabled=False): - if bias.numel() != 0: + if bias is not None and bias.numel() != 0: return bias_gelu_fused_(inp, bias) return gelu_fused_(inp) @@ -119,7 +121,7 @@ def bgrad_dgelu_fused( ) -> Tuple[Optional[torch.Tensor], torch.Tensor]: """Disable native AMP for `bgrad_dgelu_fused_`""" with torch.cuda.amp.autocast(enabled=False): - if bias.numel() != 0: + if bias is not None and bias.numel() != 0: return bgrad_dgelu_fused_(grad_output, inp, bias) return None, dgelu_fused_(grad_output, inp) diff --git a/transformer_engine/pytorch/module/_common.py b/transformer_engine/pytorch/module/_common.py index 23a06e318f..21365398f3 100644 --- a/transformer_engine/pytorch/module/_common.py +++ b/transformer_engine/pytorch/module/_common.py @@ -124,6 +124,7 @@ def forward( dim: int, *tensors: Tuple[torch.Tensor, ...], ) -> torch.Tensor: + # pylint: disable=missing-function-docstring # Check first tensor if not tensors: @@ -192,6 +193,7 @@ def backward( ctx, grad_output: torch.Tensor, ) -> Tuple[Optional[torch.Tensor], ...]: + # pylint: disable=missing-function-docstring grad_inputs = [] for split_start, split_end in ctx.split_ranges: slices = [slice(None)] * grad_output.dim() diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 644af2c22c..bc4a06b4cb 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -11,7 +11,7 @@ import fcntl import struct from abc import ABC, abstractmethod -from typing import Dict, Generator, List, Optional, Tuple, Union +from typing import Any, Dict, Generator, List, Optional, Set, Tuple, Union from contextlib import contextmanager import torch @@ -406,6 +406,27 @@ def __init__(self) -> None: self.fsdp_wrapped = False self.fsdp_group = None self._fp8_workspaces: Dict[str, Float8Tensor] = {} + self.activation_dtype: Optional[torch.dtype] = None + + # Names of attributes that can be set quickly (see __setattr__ + # method) + _fast_setattr_names: Set[str] = { + "activation_dtype", + "fp8", + "fp8_initialized", + "fp8_calibration", + "fp8_parameters", + } + + def __setattr__(self, name: str, value: Any) -> None: + if name in TransformerEngineBaseModule._fast_setattr_names: + # torch.nn.Module has a custom __setattr__ that handles + # modules, parameters, and buffers. This is unnecessary + # overhead when setting plain attrs. + self.__dict__[name] = value + else: + # Default case + super().__setattr__(name, value) def adjust_amax_history_length(self, length: int, fwd: Optional[bool] = None) -> None: """Increase or decrease size of amax history based on given `length`. @@ -593,7 +614,7 @@ def set_activation_dtype(self, inp: torch.Tensor) -> None: return # All checks after this have already been performed once, thus skip - if hasattr(self, "activation_dtype") and self.activation_dtype == inp.dtype: + if self.activation_dtype == inp.dtype: return dtype = inp.dtype @@ -664,7 +685,6 @@ def init_fp8_metadata(self, num_gemms: int = 1) -> None: else: # If fp8 isn't enabled, turn off and return. self.fp8_initialized = False - return @contextmanager def prepare_forward( @@ -708,14 +728,12 @@ def prepare_forward( FP8GlobalStateManager.copy_forward_fp8_meta_tensors_for_recompute(self.fp8_meta) with torch.cuda.nvtx.range(self.__class__.__name__ + " forward"): - if not allow_non_contiguous: - yield inp.contiguous() - else: - yield inp + if not allow_non_contiguous and not inp.is_contiguous(): + inp = inp.contiguous() + yield inp if self.fp8 and in_fp8_activation_recompute_phase(): FP8GlobalStateManager.restore_fp8_meta_tensors(self.fp8_meta) - return def set_nccl_overlap_warning_if_tp(self) -> None: """When using TP, the NCCL communication needs to be scheduled diff --git a/transformer_engine/pytorch/module/fp8_padding.py b/transformer_engine/pytorch/module/fp8_padding.py index 60bac91353..16d40cf401 100644 --- a/transformer_engine/pytorch/module/fp8_padding.py +++ b/transformer_engine/pytorch/module/fp8_padding.py @@ -28,6 +28,7 @@ def forward( padded_m_splits: List[int], is_grad_enabled: bool, ) -> torch.Tensor: + # pylint: disable=missing-function-docstring # Make sure input dimensions are compatible in_features = inp.shape[-1] @@ -46,6 +47,7 @@ def forward( @staticmethod def backward(ctx, grad_output: torch.Tensor): + # pylint: disable=missing-function-docstring grad_input = None if ctx.requires_dgrad: diff --git a/transformer_engine/pytorch/module/fp8_unpadding.py b/transformer_engine/pytorch/module/fp8_unpadding.py index 6e08f849ef..d45abe0668 100644 --- a/transformer_engine/pytorch/module/fp8_unpadding.py +++ b/transformer_engine/pytorch/module/fp8_unpadding.py @@ -28,6 +28,7 @@ def forward( padded_m_splits: List[int], is_grad_enabled: bool, ) -> torch.Tensor: + # pylint: disable=missing-function-docstring inputmats = torch.split(inp.view(-1, inp.shape[-1]), padded_m_splits) out_ret = torch.cat( [grad_output_mat[: m_splits[i]] for i, grad_output_mat in enumerate(inputmats)], dim=0 @@ -42,6 +43,7 @@ def forward( @staticmethod def backward(ctx, grad_output: torch.Tensor): + # pylint: disable=missing-function-docstring grad_input = None if ctx.requires_dgrad: grad_output = grad_output.contiguous() diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index 14edd64249..08c5addcfc 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -39,8 +39,9 @@ from ..constants import GemmParallelModes, dist_group_type from ..jit import no_torch_dynamo from ..graph import is_graph_capturing -from ..float8_tensor import Float8Tensor +from ..tensor import Float8Tensor, QuantizedTensor from ..export import is_in_onnx_export_mode +from ..cpu_offload import is_cpu_offload_enabled __all__ = ["GroupedLinear"] @@ -69,6 +70,7 @@ def forward( weights_fp8: List[Union[Float8Tensor, None]], *weights_and_biases: Union[Float8Tensor, torch.Tensor, None], ) -> torch.Tensor: + # pylint: disable=missing-function-docstring num_gemms = len(m_splits) weights = weights_and_biases[:num_gemms] biases = weights_and_biases[num_gemms:] @@ -267,6 +269,7 @@ def forward( @staticmethod def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]: + # pylint: disable=missing-function-docstring with torch.cuda.nvtx.range("_GroupedLinear_backward"): ( inputmat_scale_inv, @@ -440,36 +443,38 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], clear_tensor_data(*inputmats) clear_tensor_data(*inputmats_t) - if not ctx.use_bias: - grad_biases = [None] * ctx.num_gemms - - def handle_custom_ddp_from_mcore(w, wgrad): - if w.requires_grad: - if ctx.fuse_wgrad_accumulation and hasattr(w, "grad_added_to_main_grad"): - w.grad_added_to_main_grad = True - if getattr(w, "zero_out_wgrad", False): - wgrad = torch.zeros( - w.main_grad.shape, - dtype=w.dtype, - device=torch.cuda.current_device(), - requires_grad=False, - ) + def handle_custom_ddp_from_mcore(w, wgrad): + if w.requires_grad: + if ctx.fuse_wgrad_accumulation and hasattr(w, "grad_added_to_main_grad"): + w.grad_added_to_main_grad = True + if getattr(w, "zero_out_wgrad", False): + wgrad = torch.zeros( + w.main_grad.shape, + dtype=w.dtype, + device=torch.cuda.current_device(), + requires_grad=False, + ) + else: + wgrad = torch.empty( + w.main_grad.shape, + dtype=w.dtype, + device=torch.cuda.current_device(), + requires_grad=False, + ) + elif ctx.fuse_wgrad_accumulation: + wgrad = None else: - wgrad = torch.empty( - w.main_grad.shape, - dtype=w.dtype, - device=torch.cuda.current_device(), - requires_grad=False, - ) - elif ctx.fuse_wgrad_accumulation: - wgrad = None + wgrad = None + return wgrad + + wgrad_list = [ + handle_custom_ddp_from_mcore(w, wgrad) for w, wgrad in zip(weights, wgrad_list) + ] else: - wgrad = None - return wgrad + wgrad_list = [None] * ctx.num_gemms - wgrad_list = [ - handle_custom_ddp_from_mcore(w, wgrad) for w, wgrad in zip(weights, wgrad_list) - ] + if not ctx.use_bias: + grad_biases = [None] * ctx.num_gemms if ctx.reduce_and_update_bwd_fp8_tensors and not is_graph_capturing(): FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False) @@ -640,7 +645,7 @@ def __init__( if self.primary_weights_in_fp8: self.init_fp8_metadata(num_gemms=self.num_gemms) - self.reset_parameters(defer_init=(device == "meta")) + self.reset_parameters(defer_init=device == "meta") # For RPL, bias has to be added after TP collectives # So it cannot be fused with the GEMM @@ -719,7 +724,7 @@ def forward( bias_tensors = [getattr(self, f"bias{i}") for i in range(self.num_gemms)] if not self.fp8: weight_tensors = [ - w.from_float8() if isinstance(w, Float8Tensor) else w for w in weight_tensors + w.dequantize() if isinstance(w, QuantizedTensor) else w for w in weight_tensors ] weight_tensors_fp8 = [None] * self.num_gemms @@ -746,8 +751,6 @@ def forward( skip_update_flag=skip_fp8_weight_update, ) - from ..cpu_offload import CPUOffloadEnabled - if torch.is_grad_enabled(): linear_fn = _GroupedLinear.apply args = [] @@ -763,7 +766,7 @@ def forward( self.fp8_calibration, self.fp8_meta, self.fuse_wgrad_accumulation, - CPUOffloadEnabled, + is_cpu_offload_enabled(), self.sequence_parallel, self.activation_dtype, self._offsets, diff --git a/transformer_engine/pytorch/module/layernorm.py b/transformer_engine/pytorch/module/layernorm.py index 292fcd06de..0c439ac417 100644 --- a/transformer_engine/pytorch/module/layernorm.py +++ b/transformer_engine/pytorch/module/layernorm.py @@ -12,7 +12,6 @@ from torch.nn import init import transformer_engine_torch as tex -from .base import TransformerEngineBaseModule from ..cpp_extensions import ( layernorm_fwd_inf, ) @@ -39,6 +38,7 @@ def forward( is_grad_enabled: bool, activation_dtype: torch.dtype, ) -> torch.Tensor: + # pylint: disable=missing-function-docstring # Make sure input dimensions are compatible in_features = ln_weight.numel() assert inp.is_cuda, "TransformerEngine needs CUDA." @@ -70,6 +70,7 @@ def forward( @staticmethod def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]: + # pylint: disable=missing-function-docstring inputmat, ln_weight, mu, rsigma = ctx.saved_tensors grad_output = grad_output.contiguous() d_ln_out = grad_output.view(inputmat.shape) @@ -143,8 +144,9 @@ def __init__( ) ) self.sequence_parallel = sequence_parallel + self.activation_dtype: Optional[torch.dtype] = None - self.reset_parameters(defer_init=(device == "meta")) + self.reset_parameters(defer_init=device == "meta") # These many SMs are subtracted from the total SM count when calling forward # and backward LayerNorm C APIs. These envvars can be used to prevent the LN @@ -185,9 +187,22 @@ def reset_parameters(self, defer_init=False) -> None: @no_torch_dynamo() def forward(self, inp: torch.Tensor) -> torch.Tensor: - """LayerNorm FWD""" + # pylint: disable=missing-function-docstring + # Set the activation type for AMP. - TransformerEngineBaseModule.set_activation_dtype(self, inp) + # Note: This will soon be deprecated with + # https://github.com/NVIDIA/TransformerEngine/pull/1033 + if torch.is_autocast_enabled(): + self.activation_dtype = torch.get_autocast_gpu_dtype() + elif self.activation_dtype != inp.dtype: + dtype = inp.dtype + for name, param in self.named_parameters(): + if param is not None: + assert dtype == param.dtype, ( + "Data types for parameters must match when outside of autocasted region. " + f" Found input dtype: {dtype} and {name!r} dtype: {param.dtype}" + ) + self.activation_dtype = dtype if torch.is_grad_enabled(): fwd_fn = _LayerNorm.apply diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 92030a7f7a..97006a0671 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -36,6 +36,7 @@ allreduce, reduce_scatter_along_first_dim, gather_along_first_dim, + in_fp8_activation_recompute_phase, _fsdp_scatter_tensors, _fsdp_gather_tensors, ) @@ -46,6 +47,7 @@ from ..float8_tensor import Float8Tensor from ..export import is_in_onnx_export_mode from ..tensor import QuantizedTensor +from ..cpu_offload import is_cpu_offload_enabled __all__ = ["LayerNormLinear"] @@ -93,9 +95,11 @@ def forward( fp8_output: bool, fsdp_group: Union[dist_group_type, None], ) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]: + # pylint: disable=missing-function-docstring # Make sure input dimensions are compatible - in_features = ln_weight.numel() - assert inp.shape[-1] == in_features, "GEMM not possible" + out_features, in_features = weight.shape + inp_shape = inp.shape + assert inp_shape[-1] == in_features, "GEMM not possible" inputmat = inp.view((-1, in_features)) if fp8: assert_dim_for_fp8_exec(inputmat) @@ -151,6 +155,7 @@ def forward( # Column Parallel Linear ln_out_gathered = False + ub_algo = None if ub_overlap_ag: ln_out_total = ub_obj_lnout.get_ubuf_output(1) if not return_layernorm_output: @@ -339,7 +344,7 @@ def forward( ctx.use_bias = use_bias ctx.sequence_parallel = sequence_parallel ctx.tensor_parallel = tensor_parallel - ctx.inp_shape = inp.shape + ctx.inp_shape = inp_shape ctx.parallel_mode = parallel_mode ctx.tp_group = tp_group ctx.tp_size = tp_size @@ -357,10 +362,10 @@ def forward( ctx.normalization = normalization ctx.reduce_and_update_bwd_fp8_tensors = False if ctx.fp8 and requires_grad(inp, ln_weight, ln_bias, weight, bias): - ctx.reduce_and_update_bwd_fp8_tensors = ( - ctx.reduce_and_update_bwd_fp8_tensors - or FP8GlobalStateManager.is_first_fp8_module() - ) + _first_fp8_module = FP8GlobalStateManager.IS_FIRST_FP8_MODULE + ctx.reduce_and_update_bwd_fp8_tensors = FP8GlobalStateManager.is_first_fp8_module() + if in_fp8_activation_recompute_phase(): + FP8GlobalStateManager.IS_FIRST_FP8_MODULE = _first_fp8_module # Row Parallel Linear if parallel_mode == "row" and sequence_parallel: @@ -369,7 +374,7 @@ def forward( out, _ = allreduce(out, tp_group) # [*, in_features] -> [*, out_features] except first dimension changes for SP - out = out.view(-1, *inp.shape[1:-1], out.shape[-1]) + out = out.view(-1, *inp_shape[1:-1], out_features) if return_layernorm_output: if return_layernorm_output_gathered: @@ -383,6 +388,7 @@ def forward( def backward( ctx, *grad_outputs: Tuple[torch.Tensor, ...] ) -> Tuple[Union[torch.Tensor, None], ...]: + # pylint: disable=missing-function-docstring if isinstance(grad_outputs[0], Float8Tensor): ctx.fp8_meta["scaling_bwd"].scale_inv[tex.FP8BwdTensors.GRAD_OUTPUT1] = grad_outputs[ 0 @@ -477,6 +483,7 @@ def backward( else: dgrad = torch.empty(dgrad_size, dtype=ctx.activation_dtype, device=weight.device) + rs_out = None if ctx.ub_bulk_dgrad: ub_algo = tex.UbufOverlapAlgo.BULK_OVERLAP_AG ub_obj = ub_obj_lnout @@ -574,6 +581,7 @@ def backward( elif ctx.parallel_mode == "column" and ctx.tensor_parallel: dgrad, handle = allreduce(dgrad, ctx.tp_group, async_op=True) + wgrad = None if weight.requires_grad: if ctx.fp8: # WGRAD @@ -676,6 +684,8 @@ def backward( if ctx.return_layernorm_output and not ctx.return_layernorm_output_gathered: dgrad = dgrad + grad_outputs[1].view_as(dgrad) + dgamma = None + dbeta = None if ctx.normalization == "LayerNorm": dgrad, dgamma, dbeta = tex.layernorm_bwd( dgrad, @@ -1055,7 +1065,7 @@ def __init__( if with_fp8_params: self.init_fp8_metadata() - self.reset_parameters(defer_init=(device == "meta")) + self.reset_parameters(defer_init=device == "meta") # For RPL, bias has to be added after TP collectives # So it cannot be fused with the GEMM @@ -1160,9 +1170,7 @@ def forward( unfused_weights = [w.dequantize() for w in unfused_weights] weight_tensor = _noop_cat(unfused_weights) if self.use_bias: - bias_tensor = _noop_cat( - [getattr(self, name) for name in self.bias_names], - ) + bias_tensor = _noop_cat([getattr(self, name) for name in self.bias_names]) else: bias_tensor = getattr(self, self.bias_names[0]) # Unused @@ -1190,8 +1198,6 @@ def forward( skip_update_flag=skip_fp8_weight_update, ) - from ..cpu_offload import CPUOffloadEnabled - if torch.is_grad_enabled(): fwd_fn = _LayerNormLinear.apply args = [] @@ -1212,7 +1218,7 @@ def forward( self.fp8_calibration, self.fp8_meta, self.fuse_wgrad_accumulation, - CPUOffloadEnabled, + is_cpu_offload_enabled(), self.tp_group, self.tp_size, self.sequence_parallel, diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 6d5609ccd2..966924a85c 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -43,6 +43,7 @@ reduce_scatter_along_first_dim, gather_along_first_dim, use_reentrant_activation_recompute, + in_fp8_activation_recompute_phase, _fsdp_scatter_tensors, _fsdp_gather_tensors, ) @@ -54,6 +55,7 @@ from ..graph import is_graph_capturing from ..float8_tensor import Float8Tensor from ._common import _apply_normalization +from ..cpu_offload import is_cpu_offload_enabled __all__ = ["LayerNormMLP"] @@ -122,9 +124,11 @@ def forward( gemm_gelu_fusion: bool, fsdp_group: Union[dist_group_type, None], ) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]: + # pylint: disable=missing-function-docstring # Make sure input dimensions are compatible in_features = ln_weight.numel() - assert inp.shape[-1] == in_features, "GEMM not possible" + inp_shape = inp.shape + assert inp_shape[-1] == in_features, "GEMM not possible" inputmat = inp.view((-1, in_features)) if fp8: assert_dim_for_fp8_exec(inputmat) @@ -171,6 +175,7 @@ def forward( # Column Parallel Linear ln_out_gathered = False + ub_algo_ag = None if ub_overlap_ag: ln_out_total = ub_obj_lnout.get_ubuf_output(1) ln_out = torch.empty_like(ln_out) @@ -239,23 +244,23 @@ def forward( activation_dtype, get_workspace(), ] - fp8_gemm_kwargs = dict( - bias=fc1_bias, - use_bias=use_fc1_bias, - use_split_accumulator=_2X_ACC_FPROP, - ub_algo=ub_algo_ag if ub_overlap_ag else None, - ub=ub_obj_lnout if ub_overlap_ag else None, - extra_output_tensor=ln_out if ub_overlap_ag else None, - ) + fp8_gemm_kwargs = { + "bias": fc1_bias, + "use_bias": use_fc1_bias, + "use_split_accumulator": _2X_ACC_FPROP, + "ub_algo": ub_algo_ag if ub_overlap_ag else None, + "ub": ub_obj_lnout if ub_overlap_ag else None, + "extra_output_tensor": ln_out if ub_overlap_ag else None, + } if gemm_gelu_fusion: fp8_gemm_args[8] = torch.uint8 # out_dtype fp8_gemm_kwargs.update( - dict( - gelu=True, - out_index=tex.FP8FwdTensors.GEMM2_INPUT, - fp8_meta_tensor=fp8_meta["scaling_fwd"], - D_dtype=fp8_dtype_forward, - ) + { + "gelu": True, + "out_index": tex.FP8FwdTensors.GEMM2_INPUT, + "fp8_meta_tensor": fp8_meta["scaling_fwd"], + "D_dtype": fp8_dtype_forward, + } ) fp8_gemm_out = tex.fp8_gemm(*fp8_gemm_args, **fp8_gemm_kwargs) if not is_grad_enabled: @@ -281,6 +286,9 @@ def forward( None, activation_dtype, ) + + rs_out = None + ub_algo_rs = None if ub_overlap_rs: ub_obj_fc2out = get_ub("fc2_fprop") fc2_out = ub_obj_fc2out.get_ubuf_output(1) @@ -433,7 +441,8 @@ def forward( ln_weight.weight_offloading = True fc1_weight.weight_offloading = True fc2_weight.weight_offloading = True - fc1_bias.weight_offloading = True + if fc1_bias is not None: + fc1_bias.weight_offloading = True inputmat.activation_offloading = True if normalization == "LayerNorm": @@ -487,7 +496,7 @@ def forward( ctx.use_fc2_bias = use_fc2_bias ctx.sequence_parallel = sequence_parallel ctx.tensor_parallel = tensor_parallel - ctx.inp_shape = inp.shape + ctx.inp_shape = inp_shape ctx.tp_group = tp_group ctx.tp_size = tp_size ctx.bias_gelu_nvfusion = bias_gelu_nvfusion @@ -508,7 +517,10 @@ def forward( if ctx.fp8 and requires_grad( inp, ln_weight, ln_bias, fc1_weight, fc2_weight, fc1_bias, fc2_bias ): + _first_fp8_module = FP8GlobalStateManager.IS_FIRST_FP8_MODULE ctx.reduce_and_update_bwd_fp8_tensors = FP8GlobalStateManager.is_first_fp8_module() + if in_fp8_activation_recompute_phase(): + FP8GlobalStateManager.IS_FIRST_FP8_MODULE = _first_fp8_module # Row Parallel Linear if ub_overlap_rs: @@ -519,11 +531,11 @@ def forward( fc2_out, _ = allreduce(fc2_out, tp_group) # [*, in_features] -> [*, out_features] except first dimension changes for SP - fc2_out = fc2_out.view(-1, *inp.shape[1:-1], fc2_out.shape[-1]) + fc2_out = fc2_out.view(-1, *inp_shape[1:-1], fc2_out.shape[-1]) if return_layernorm_output: if return_layernorm_output_gathered: - shape = list(inp.shape) + shape = list(inp_shape) shape[0] *= tp_size return fc2_out, ln_out_return.view(shape) return fc2_out, ln_out_return.view_as(inp) @@ -533,6 +545,7 @@ def forward( def backward( ctx, *grad_outputs: Tuple[torch.Tensor, ...] ) -> Tuple[Union[torch.Tensor, None], ...]: + # pylint: disable=missing-function-docstring with torch.cuda.nvtx.range("_LayerNormMLP_backward"): ( inputmat, @@ -596,6 +609,7 @@ def backward( if tp_world_size == 1: ctx.ub_overlap_ag = False + ub_algo = None if ctx.ub_overlap_ag: dim_size = list(grad_outputs[0].size()) dim_size[0] = dim_size[0] * tp_world_size @@ -637,6 +651,7 @@ def backward( else: accumulate_wgrad_into_param_main_grad = ctx.fuse_wgrad_accumulation + fc2_wgrad = None if ctx.fp8: fp8_dtype_forward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=True) fp8_dtype_backward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=False) @@ -771,6 +786,7 @@ def backward( ub_obj_dgrad.set_ubuf_scale_inv(meta_tensor.scale_inv[out_index]) # Set UB algo and UB obj for fc1_dgrad bulk/pipelined overlap + rs_out = None if ctx.ub_bulk_dgrad: ub_algo = tex.UbufOverlapAlgo.BULK_OVERLAP_AG ub_obj = ub_obj_lnout @@ -920,6 +936,7 @@ def backward( elif ctx.set_parallel_mode and ctx.tensor_parallel: fc1_dgrad, handle = allreduce(fc1_dgrad, ctx.tp_group, async_op=True) + fc1_wgrad = None if fc1_weight.requires_grad: if ctx.fp8: # FC1 WGRAD @@ -1023,6 +1040,8 @@ def backward( if ctx.return_layernorm_output and not ctx.return_layernorm_output_gathered: dgrad = dgrad + grad_outputs[1].view_as(dgrad) + dgamma = None + dbeta = None if ctx.normalization == "LayerNorm": dgrad, dgamma, dbeta = tex.layernorm_bwd( dgrad, @@ -1109,7 +1128,9 @@ def backward( dbeta, fc1_wgrad, None, # fc1_weight_fp8 - fc1_bias_grad if ctx.use_fc1_bias else None, + # Due to bias gelu nvfusion available in the bf16 case, fc1_bias_grad is calculated at + # different paths and this confused the linter. + fc1_bias_grad if ctx.use_fc1_bias else None, # pylint: disable=used-before-assignment None, # use_fc1_bias fc2_wgrad, None, # fc2_weight_fp8 @@ -1381,7 +1402,7 @@ def __init__( if with_fp8_params: self.init_fp8_metadata(num_gemms=2) - self.reset_parameters(defer_init=(device == "meta")) + self.reset_parameters(defer_init=device == "meta") # For RPL, bias has to be added after TP collectives # So it cannot be fused with the GEMM @@ -1471,7 +1492,9 @@ def forward( # Get weight tensors fc1_weight = self.fc1_weight + fc1_bias = self.fc1_bias fc2_weight = self.fc2_weight + fc2_bias = self.fc2_bias if not self.fp8: if isinstance(fc1_weight, Float8Tensor): fc1_weight = fc1_weight.from_float8() @@ -1524,8 +1547,6 @@ def forward( if self.bias_gelu_nvfusion and not use_reentrant_activation_recompute(): self.bias_gelu_nvfusion = False - from ..cpu_offload import CPUOffloadEnabled - if torch.is_grad_enabled(): fwd_fn = _LayerNormMLP.apply args = [] @@ -1538,11 +1559,11 @@ def forward( self.layer_norm_bias, fc1_weight, fc1_weight_fp8, - self.fc1_bias, + fc1_bias, self.use_bias, fc2_weight, fc2_weight_fp8, - self.fc2_bias, + fc2_bias, self.apply_bias and not self.gemm_bias_unfused_add, self.eps, is_first_microbatch, @@ -1550,7 +1571,7 @@ def forward( self.fp8_calibration, self.fp8_meta, self.fuse_wgrad_accumulation, - CPUOffloadEnabled, + is_cpu_offload_enabled(), self.tp_group, self.tp_size, self.sequence_parallel, @@ -1580,12 +1601,12 @@ def forward( out, ln_out = out if self.gemm_bias_unfused_add: - out = out + cast_if_needed(self.fc2_bias, self.activation_dtype) + out = out + cast_if_needed(fc2_bias, self.activation_dtype) if self.return_bias: if self.return_layernorm_output: - return out, cast_if_needed(self.fc2_bias, self.activation_dtype), ln_out - return out, cast_if_needed(self.fc2_bias, self.activation_dtype) + return out, cast_if_needed(fc2_bias, self.activation_dtype), ln_out + return out, cast_if_needed(fc2_bias, self.activation_dtype) if self.return_layernorm_output: return out, ln_out return out diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 8e19a65a28..403eef091f 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -33,6 +33,7 @@ allreduce, reduce_scatter_along_first_dim, gather_along_first_dim, + in_fp8_activation_recompute_phase, _fsdp_scatter_tensors, _fsdp_gather_tensors, ) @@ -48,6 +49,7 @@ from ..float8_tensor import Float8Tensor from ..export import is_in_onnx_export_mode from ..tensor import QuantizedTensor +from ..cpu_offload import is_cpu_offload_enabled __all__ = ["Linear"] @@ -84,11 +86,13 @@ def forward( fp8_output: bool, fsdp_group: Union[dist_group_type, None], ) -> torch.Tensor: + # pylint: disable=missing-function-docstring is_input_fp8 = isinstance(inp, Float8Tensor) # Make sure input dimensions are compatible - in_features = weight.shape[-1] - assert inp.shape[-1] == in_features, "GEMM not possible" + out_features, in_features = weight.shape + inp_shape = inp.shape + assert inp_shape[-1] == in_features, "GEMM not possible" inputmat = inp.view(-1, in_features) if fp8: assert_dim_for_fp8_exec(inputmat) @@ -175,12 +179,14 @@ def forward( activation_dtype, ) + ub_algo = None + rs_out = None if ub_overlap_rs: ub_obj_projout = get_ub(ub_name + "_fprop") out = ub_obj_projout.get_ubuf_output(1) dim_size = list(inputmat_total.size()) dim_size[0] = dim_size[0] // tp_world_size - dim_size[1] = weight_fp8.size(0) + dim_size[1] = out_features rs_out = torch.empty(dim_size, dtype=activation_dtype, device=inputmat_total.device) if ub_obj_projout.is_p2p_overlap(): if ub_obj_projout.is_atomic_gemm(): @@ -200,7 +206,7 @@ def forward( ub_obj_projout.set_ubuf_scale_inv(meta_tensor.scale_inv[proj_out_index]) else: dim_size = list(inputmat_total.size()) - dim_size[1] = weight_fp8.size(0) + dim_size[1] = out_features out = torch.empty(dim_size, dtype=proj_out_pttype, device=inputmat_total.device) _ = fp8_gemm( @@ -260,7 +266,7 @@ def forward( out = ub_obj_projout.get_ubuf_output(1) dim_size = list(inputmat_total.size()) dim_size[0] = dim_size[0] // get_distributed_world_size(tp_group) - dim_size[1] = weight.size(0) + dim_size[1] = out_features rs_out = torch.empty(dim_size, dtype=activation_dtype, device=inputmat_total.device) if ub_obj_projout.is_p2p_overlap(): ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS_P2P @@ -268,7 +274,7 @@ def forward( ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS else: dim_size = list(inputmat_total.size()) - dim_size[1] = weight.size(0) + dim_size[1] = out_features out = torch.empty(dim_size, dtype=activation_dtype, device=inputmat_total.device) _ = gemm( @@ -334,7 +340,7 @@ def forward( ctx.use_bias = use_bias ctx.sequence_parallel = sequence_parallel ctx.tensor_parallel = tensor_parallel - ctx.inp_shape = inp.shape + ctx.inp_shape = inp_shape ctx.parallel_mode = parallel_mode ctx.tp_group = tp_group ctx.ub_overlap_ag = ub_overlap_ag @@ -344,10 +350,10 @@ def forward( ctx.is_input_fp8 = is_input_fp8 ctx.reduce_and_update_bwd_fp8_tensors = False if ctx.fp8 and requires_grad(inp, weight, bias): - ctx.reduce_and_update_bwd_fp8_tensors = ( - ctx.reduce_and_update_bwd_fp8_tensors - or FP8GlobalStateManager.is_first_fp8_module() - ) + _first_fp8_module = FP8GlobalStateManager.IS_FIRST_FP8_MODULE + ctx.reduce_and_update_bwd_fp8_tensors = FP8GlobalStateManager.is_first_fp8_module() + if in_fp8_activation_recompute_phase(): + FP8GlobalStateManager.IS_FIRST_FP8_MODULE = _first_fp8_module # Row Parallel Linear if ub_overlap_rs: @@ -358,10 +364,11 @@ def forward( out, _ = allreduce(out, tp_group) # [*, in_features] -> [*, out_features] except first dimension changes for SP - return out.view(-1, *inp.shape[1:-1], out.shape[-1]) + return out.view(-1, *inp_shape[1:-1], out_features) @staticmethod def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]: + # pylint: disable=missing-function-docstring if isinstance(grad_output, Float8Tensor): ctx.fp8_meta["scaling_bwd"].scale_inv[ tex.FP8BwdTensors.GRAD_OUTPUT1 @@ -394,6 +401,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], tp_world_size = get_distributed_world_size(ctx.tp_group) ctx.ub_overlap_ag = False if tp_world_size == 1 else ctx.ub_overlap_ag + ub_algo = None if ctx.ub_overlap_ag: dim_size = list(grad_output.size()) dim_size[0] = dim_size[0] * tp_world_size @@ -505,6 +513,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], elif ctx.parallel_mode == "column" and ctx.tensor_parallel: dgrad, handle = allreduce(dgrad, ctx.tp_group, async_op=True) + wgrad = None if weight.requires_grad: if ctx.fp8: # WGRAD @@ -871,7 +880,7 @@ def __init__( if with_fp8_params: self.init_fp8_metadata() - self.reset_parameters(defer_init=(device == "meta")) + self.reset_parameters(defer_init=device == "meta") # For RPL, bias has to be added after TP collectives # So it cannot be fused with the GEMM @@ -952,9 +961,7 @@ def forward( unfused_weights = [w.dequantize() for w in unfused_weights] weight_tensor = _noop_cat(unfused_weights) if self.use_bias: - bias_tensor = _noop_cat( - [getattr(self, name) for name in self.bias_names], - ) + bias_tensor = _noop_cat([getattr(self, name) for name in self.bias_names]) else: bias_tensor = getattr(self, self.bias_names[0]) # Unused @@ -983,8 +990,6 @@ def forward( fsdp_group=self.fsdp_group, ) - from ..cpu_offload import CPUOffloadEnabled - if torch.is_grad_enabled(): linear_fn = _Linear.apply args = [] @@ -1002,7 +1007,7 @@ def forward( self.fp8_calibration, self.fp8_meta, self.fuse_wgrad_accumulation, - CPUOffloadEnabled, + is_cpu_offload_enabled(), self.tp_group, self.tp_size, self.sequence_parallel, diff --git a/transformer_engine/pytorch/module/rmsnorm.py b/transformer_engine/pytorch/module/rmsnorm.py index d5dc400206..fc6ec5746f 100644 --- a/transformer_engine/pytorch/module/rmsnorm.py +++ b/transformer_engine/pytorch/module/rmsnorm.py @@ -11,7 +11,6 @@ from torch.nn.parameter import Parameter from torch.nn import init -from .base import TransformerEngineBaseModule from .. import cpp_extensions as tex from ..jit import no_torch_dynamo from ..utils import cast_if_needed @@ -36,6 +35,7 @@ def forward( is_grad_enabled: bool, activation_dtype: torch.dtype, ) -> torch.Tensor: + # pylint: disable=missing-function-docstring # Make sure input dimensions are compatible in_features = rmsnorm_weight.numel() assert inp.is_cuda, "TransformerEngine needs CUDA." @@ -62,6 +62,7 @@ def forward( @staticmethod def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]: + # pylint: disable=missing-function-docstring inputmat, rmsnorm_weight, rsigma = ctx.saved_tensors grad_output = grad_output.contiguous() d_rmsnorm_out = grad_output.view(inputmat.shape) @@ -146,8 +147,9 @@ def __init__( ) ) self.sequence_parallel = sequence_parallel + self.activation_dtype: Optional[torch.dtype] = None - self.reset_parameters(defer_init=(device == "meta")) + self.reset_parameters(defer_init=device == "meta") # These many SMs are subtracted from the total SM count when calling forward # and backward RMSNorm C APIs. These envvars can be used to prevent the LN @@ -182,10 +184,22 @@ def reset_parameters(self, defer_init=False) -> None: @no_torch_dynamo() def forward(self, inp: torch.Tensor) -> torch.Tensor: - """RMSNorm FWD""" + # pylint: disable=missing-function-docstring # Set the activation type for AMP. - TransformerEngineBaseModule.set_activation_dtype(self, inp) + # Note: This will soon be deprecated with + # https://github.com/NVIDIA/TransformerEngine/pull/1033 + if torch.is_autocast_enabled(): + self.activation_dtype = torch.get_autocast_gpu_dtype() + elif self.activation_dtype != inp.dtype: + dtype = inp.dtype + for name, param in self.named_parameters(): + if param is not None: + assert dtype == param.dtype, ( + "Data types for parameters must match when outside of autocasted region. " + f" Found input dtype: {dtype} and {name!r} dtype: {param.dtype}" + ) + self.activation_dtype = dtype if torch.is_grad_enabled(): fwd_fn = _RMSNorm.apply diff --git a/transformer_engine/pytorch/ops/basic/basic_linear.py b/transformer_engine/pytorch/ops/basic/basic_linear.py index ce72dd8a55..859b1ba1d7 100644 --- a/transformer_engine/pytorch/ops/basic/basic_linear.py +++ b/transformer_engine/pytorch/ops/basic/basic_linear.py @@ -119,7 +119,6 @@ def __init__( dtype = canonicalize_dtype(dtype) if dtype not in (torch.float32, torch.float16, torch.bfloat16): raise ValueError(f"Supported dtypes are float32, float16, bfloat16 (got {dtype})") - self.dtype: torch.dtype = canonicalize_dtype(dtype) # Tensor parallel configuration self.tensor_parallel_mode: Optional[str] @@ -278,7 +277,8 @@ def reset_parameters(self) -> None: weight = self.weight if weight.device.type != "cuda" or is_float8_tensor(weight): weight = torch.empty_like(weight, device=self.device) - weight = weight.to(device=self.device, dtype=self.dtype) + else: + weight = weight.to(device=self.device) # Initialize values init_context = contextlib.nullcontext @@ -562,12 +562,12 @@ def _functional_forward( _wait_async(x_async) x_async = None if with_fp8_compute: - kwargs = dict( - accumulate=accumulate_into_out, - out=y, - bias=b, - use_bias=(b is not None), - ) + kwargs = { + "accumulate": accumulate_into_out, + "out": y, + "bias": b, + "use_bias": (b is not None), + } if with_fp8_output: if y._fp8_meta is None: # Hackily create FP8TensorMeta if needed @@ -584,12 +584,12 @@ def _functional_forward( fp8_meta = y._fp8_meta[fp8_meta_key] fp8_meta_index = y._fp8_meta_index kwargs.update( - dict( - out=y._data, - out_index=fp8_meta_index, - fp8_meta_tensor=fp8_meta, - D_dtype=y._fp8_dtype, - ) + { + "out": y._data, + "out_index": fp8_meta_index, + "fp8_meta_tensor": fp8_meta, + "D_dtype": y._fp8_dtype, + } ) fp8_gemm( w._data, @@ -936,10 +936,7 @@ def _functional_backward( _wait_async(dy_async) dy_async = None if with_fp8_compute: - kwargs = dict( - accumulate=accumulate_into_grad_input, - out=dx, - ) + kwargs = {"accumulate": accumulate_into_grad_input, "out": dx} if with_fp8_grad_input: if dx._fp8_meta is None: # Hackily create FP8TensorMeta if needed @@ -958,12 +955,12 @@ def _functional_backward( fp8_meta = dx._fp8_meta[fp8_meta_key] fp8_meta_index = dx._fp8_meta_index kwargs.update( - dict( - out=dx._data, - out_index=fp8_meta_index, - fp8_meta_tensor=fp8_meta, - D_dtype=dx._fp8_dtype, - ) + { + "out": dx._data, + "out_index": fp8_meta_index, + "fp8_meta_tensor": fp8_meta, + "D_dtype": dx._fp8_dtype, + } ) fp8_gemm( w.transpose_2d(), @@ -1082,12 +1079,17 @@ def op_forward( if prev_op is not None and prev_op.num_fp8_scales("grad_output") > 0: grad_input_fp8_meta = prev_op.get_fp8_meta("grad_output") + # Get autocast dtype if needed + dtype = None + if torch.is_autocast_enabled(): + dtype = torch.get_autocast_dtype("cuda") + # Linear forward output, x_local, _ = BasicLinear._functional_forward( input=input_, weight=self.weight, device=self.device, - dtype=self.dtype, + dtype=dtype, tensor_parallel_mode=self.tensor_parallel_mode, tensor_parallel_group=self.tensor_parallel_group, sequence_parallel=self.sequence_parallel, @@ -1103,6 +1105,7 @@ def op_forward( ctx.weight_fp8_meta = weight_fp8_meta ctx.grad_output_fp8_meta = grad_output_fp8_meta ctx.grad_input_fp8_meta = grad_input_fp8_meta + ctx.dtype = dtype ctx.input_dims = input_.size() ctx.input_requires_grad = input_.requires_grad ctx.weight_requires_grad = self.weight.requires_grad @@ -1143,7 +1146,7 @@ def op_backward( input_requires_grad=ctx.input_requires_grad, weight_requires_grad=ctx.weight_requires_grad, device=self.device, - dtype=self.dtype, + dtype=ctx.dtype, grad_weight=grad_weight, accumulate_into_grad_weight=accumulate_into_main_grad, tensor_parallel_mode=self.tensor_parallel_mode, diff --git a/transformer_engine/pytorch/ops/basic/bias.py b/transformer_engine/pytorch/ops/basic/bias.py index b8e8cc5e56..44a97b3b2d 100644 --- a/transformer_engine/pytorch/ops/basic/bias.py +++ b/transformer_engine/pytorch/ops/basic/bias.py @@ -62,9 +62,6 @@ def __init__( device = canonicalize_device(None) self.device: torch.device = device - # Bias tensor datatype - self.dtype: torch.dtype = canonicalize_dtype(dtype) - # Tensor parallel configuration tensor_parallel_size = 1 local_size = size @@ -88,7 +85,7 @@ def __init__( bias = torch.empty( local_size, device="meta", - dtype=dtype, + dtype=canonicalize_dtype(dtype), ) bias = torch.nn.Parameter(bias) self.bias: torch.nn.Parameter @@ -103,7 +100,8 @@ def reset_parameters(self) -> None: bias = self.bias if bias.device.type != "cuda": bias = torch.empty_like(bias, device=self.device) - bias = bias.to(device=self.device, dtype=self.dtype) + else: + bias = bias.to(device=self.device) # Initialize values bias.zero_() diff --git a/transformer_engine/pytorch/ops/fused/backward_linear_add.py b/transformer_engine/pytorch/ops/fused/backward_linear_add.py index 138eca3d96..123c560066 100644 --- a/transformer_engine/pytorch/ops/fused/backward_linear_add.py +++ b/transformer_engine/pytorch/ops/fused/backward_linear_add.py @@ -78,7 +78,7 @@ def fuser_backward( input_requires_grad=linear_op_ctx.input_requires_grad, weight_requires_grad=linear_op_ctx.weight_requires_grad, device=linear_op.device, - dtype=linear_op.dtype, + dtype=grad_input.dtype, grad_weight=grad_weight, accumulate_into_grad_weight=accumulate_into_main_grad, grad_input=grad_input, diff --git a/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py b/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py index 5fd52405e4..3afdc3a0c3 100644 --- a/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py +++ b/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py @@ -38,11 +38,7 @@ def __init__( ) -> None: # Basic operations that comprise this fused operation - op_idxs = dict( - linear=0, - bias=None, - activation=None, - ) + op_idxs = {"linear": 0, "bias": None, "activation": None} ops = [linear] if bias is not None: op_idxs["bias"] = len(ops) @@ -104,13 +100,18 @@ def fuser_forward( if prev_op is not None and prev_op.num_fp8_scales("grad_output") > 0: grad_input_fp8_meta = prev_op.get_fp8_meta("grad_output") + # Get autocast dtype if needed + dtype = None + if torch.is_autocast_enabled(): + dtype = torch.get_autocast_dtype("cuda") + # Linear forward output, x_local, _ = BasicLinear._functional_forward( input=input_, weight=linear_op.weight, bias=bias, device=linear_op.device, - dtype=linear_op.dtype, + dtype=dtype, tensor_parallel_mode=linear_op.tensor_parallel_mode, tensor_parallel_group=linear_op.tensor_parallel_group, sequence_parallel=linear_op.sequence_parallel, @@ -126,6 +127,7 @@ def fuser_forward( linear_op_ctx.weight_fp8_meta = weight_fp8_meta linear_op_ctx.grad_output_fp8_meta = grad_output_fp8_meta linear_op_ctx.grad_input_fp8_meta = grad_input_fp8_meta + linear_op_ctx.dtype = dtype linear_op_ctx.input_dims = input_.size() linear_op_ctx.input_requires_grad = input_.requires_grad linear_op_ctx.weight_requires_grad = linear_op.weight.requires_grad @@ -167,7 +169,7 @@ def fuse_forward_linear_bias_activation( # Row tensor-parallelism requires communication after the # GEMM continue - if op1.dtype not in (torch.float16, torch.bfloat16): + if op1.weight.dtype not in (torch.float16, torch.bfloat16): # cuBLAS only supports fused GEMM+bias+activation with # FP16 and BF16 output continue diff --git a/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py b/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py index 6ddee2849a..3d994d80f0 100644 --- a/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py +++ b/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py @@ -37,11 +37,7 @@ def __init__( ) -> None: # Basic operations that comprise this fused operation - op_idxs = dict( - linear=0, - bias=None, - add=None, - ) + op_idxs = {"linear": 0, "bias": None, "add": None} ops = [linear] if bias is not None: op_idxs["bias"] = len(ops) @@ -95,6 +91,11 @@ def fuser_forward( if prev_op is not None and prev_op.num_fp8_scales("grad_output") > 0: grad_input_fp8_meta = prev_op.get_fp8_meta("grad_output") + # Get autocast dtype if needed + dtype = None + if torch.is_autocast_enabled(): + dtype = torch.get_autocast_dtype("cuda") + # Linear forward output = basic_op_extra_inputs[self._op_idxs["add"]][0] output, x_local, _ = BasicLinear._functional_forward( @@ -102,7 +103,6 @@ def fuser_forward( weight=linear_op.weight, bias=bias, device=linear_op.device, - dtype=linear_op.dtype, out=output, accumulate_into_out=True, tensor_parallel_mode=linear_op.tensor_parallel_mode, @@ -120,6 +120,7 @@ def fuser_forward( linear_op_ctx.weight_fp8_meta = weight_fp8_meta linear_op_ctx.grad_output_fp8_meta = grad_output_fp8_meta linear_op_ctx.grad_input_fp8_meta = grad_input_fp8_meta + linear_op_ctx.dtype = dtype linear_op_ctx.input_dims = input_.size() linear_op_ctx.input_requires_grad = input_.requires_grad linear_op_ctx.weight_requires_grad = linear_op.weight.requires_grad diff --git a/transformer_engine/pytorch/ops/linear.py b/transformer_engine/pytorch/ops/linear.py index 13cec30fa2..daa5a6952e 100644 --- a/transformer_engine/pytorch/ops/linear.py +++ b/transformer_engine/pytorch/ops/linear.py @@ -91,24 +91,24 @@ def __init__( # Construct basic ops ops = [] - linear_kwargs = dict( - in_features=in_features, - out_features=out_features, - device=device, - dtype=dtype, - tensor_parallel_mode=tensor_parallel_mode, - tensor_parallel_group=tensor_parallel_group, - sequence_parallel=sequence_parallel, - rng_state_tracker_function=rng_state_tracker_function, - accumulate_into_main_grad=accumulate_into_main_grad, - ) - bias_kwargs = dict( - size=out_features, - device=device, - dtype=dtype, - tensor_parallel=(tensor_parallel_mode is not None), - tensor_parallel_group=tensor_parallel_group, - ) + linear_kwargs = { + "in_features": in_features, + "out_features": out_features, + "device": device, + "dtype": dtype, + "tensor_parallel_mode": tensor_parallel_mode, + "tensor_parallel_group": tensor_parallel_group, + "sequence_parallel": sequence_parallel, + "rng_state_tracker_function": rng_state_tracker_function, + "accumulate_into_main_grad": accumulate_into_main_grad, + } + bias_kwargs = { + "size": out_features, + "device": device, + "dtype": dtype, + "tensor_parallel": (tensor_parallel_mode is not None), + "tensor_parallel_group": tensor_parallel_group, + } if tensor_parallel_mode == "row": # Row TP: GEMM + bias + reduction linear_kwargs["in_features"] = local_in_features diff --git a/transformer_engine/pytorch/ops/op.py b/transformer_engine/pytorch/ops/op.py index 47c6567056..75905ad854 100644 --- a/transformer_engine/pytorch/ops/op.py +++ b/transformer_engine/pytorch/ops/op.py @@ -179,7 +179,6 @@ def __init__(self) -> None: def is_fused_op(self) -> bool: return False - # pylint: disable=no-self-use def num_fp8_scales( self, mode: str, # pylint: disable=unused-argument @@ -225,11 +224,11 @@ def _make_meta( } # Construct FP8 metadata for all tensor types - return dict( - input=_make_meta(self.num_fp8_scales("input"), True), - param=_make_meta(self.num_fp8_scales("param"), True), - grad_output=_make_meta(self.num_fp8_scales("grad_output"), False), - ) + return { + "input": _make_meta(self.num_fp8_scales("input"), True), + "param": _make_meta(self.num_fp8_scales("param"), True), + "grad_output": _make_meta(self.num_fp8_scales("grad_output"), False), + } @classmethod def _maybe_update_fp8_meta(cls, fp8_meta: Optional[dict[str, Any]]) -> None: diff --git a/transformer_engine/pytorch/ops/sequential.py b/transformer_engine/pytorch/ops/sequential.py index c5e25fe1f2..8d4fefb4c5 100644 --- a/transformer_engine/pytorch/ops/sequential.py +++ b/transformer_engine/pytorch/ops/sequential.py @@ -46,6 +46,7 @@ def __init__( self.append(module) def add_module(self, name: str, module: Optional[torch.nn.Module]) -> None: + # pylint: disable=missing-function-docstring self._module_groups = None super().add_module(name, module) diff --git a/transformer_engine/pytorch/optimizers/fused_adam.py b/transformer_engine/pytorch/optimizers/fused_adam.py index da0ba3328a..191c98745d 100644 --- a/transformer_engine/pytorch/optimizers/fused_adam.py +++ b/transformer_engine/pytorch/optimizers/fused_adam.py @@ -100,13 +100,13 @@ def __init__( # If the optimizer is capturable then LR should be a tensor (on GPU) lr = torch.tensor(lr, dtype=torch.float32) if capturable else lr - defaults = dict( - lr=lr, - bias_correction=bias_correction, - betas=betas, - eps=eps, - weight_decay=weight_decay, - ) + defaults = { + "lr": lr, + "bias_correction": bias_correction, + "betas": betas, + "eps": eps, + "weight_decay": weight_decay, + } super().__init__(params, defaults) self.adam_w_mode = 1 if adam_w_mode else 0 self.set_grad_none = set_grad_none @@ -135,6 +135,7 @@ def __init__( self.multi_tensor_adam_capturable_master = tex.multi_tensor_adam_capturable_master def zero_grad(self): + # pylint: disable=missing-function-docstring if self.set_grad_none: for group in self.param_groups: for p in group["params"]: diff --git a/transformer_engine/pytorch/optimizers/fused_sgd.py b/transformer_engine/pytorch/optimizers/fused_sgd.py index 6186f3f3ea..ee428d2417 100644 --- a/transformer_engine/pytorch/optimizers/fused_sgd.py +++ b/transformer_engine/pytorch/optimizers/fused_sgd.py @@ -91,13 +91,13 @@ def __init__( if weight_decay < 0.0: raise ValueError(f"Invalid weight_decay value: {weight_decay}") - defaults = dict( - lr=lr, - momentum=momentum, - dampening=dampening, - weight_decay=weight_decay, - nesterov=nesterov, - ) + defaults = { + "lr": lr, + "momentum": momentum, + "dampening": dampening, + "weight_decay": weight_decay, + "nesterov": nesterov, + } if nesterov and (momentum <= 0 or dampening != 0): raise ValueError("Nesterov momentum requires a momentum and zero dampening") super().__init__(params, defaults) @@ -120,6 +120,7 @@ def __setstate__(self, state): group.setdefault("nesterov", False) def zero_grad(self): + # pylint: disable=missing-function-docstring if self.set_grad_none: for group in self.param_groups: for p in group["params"]: diff --git a/transformer_engine/pytorch/permutation.py b/transformer_engine/pytorch/permutation.py index 9987db58e0..540bacbf84 100644 --- a/transformer_engine/pytorch/permutation.py +++ b/transformer_engine/pytorch/permutation.py @@ -32,6 +32,7 @@ def forward( num_out_tokens: int, max_token_num: int, ) -> Tuple[torch.Tensor, torch.Tensor]: + # pylint: disable=missing-function-docstring # Empty input check if not inp.numel(): return inp, torch.tensor([], device=inp.device) @@ -90,6 +91,7 @@ def backward( permuted_act_grad: torch.Tensor, _, ) -> Tuple[torch.Tensor, ...]: + # pylint: disable=missing-function-docstring # Empty input check if not permuted_act_grad.numel(): return permuted_act_grad, None, None, None @@ -130,6 +132,7 @@ def forward( row_id_map: torch.Tensor, probs: torch.Tensor, ) -> torch.Tensor: + # pylint: disable=missing-function-docstring # Empty input check if not inp.numel(): ctx.probs = probs @@ -188,6 +191,7 @@ def backward( ctx, unpermuted_act_grad: torch.Tensor, ) -> Tuple[torch.Tensor, None, torch.Tensor]: + # pylint: disable=missing-function-docstring # Empty input check if not unpermuted_act_grad.numel(): return unpermuted_act_grad, None, ctx.probs @@ -208,6 +212,7 @@ def backward( inp, row_id_map, probs = ctx.saved_tensors act_grad = None + prob_grad = None if ctx.needs_input_grad[0]: act_grad, prob_grad = tex.moe_unpermute_bwd( unpermuted_act_grad, inp, dtype, row_id_map, probs diff --git a/transformer_engine/pytorch/setup.py b/transformer_engine/pytorch/setup.py index 034e671150..c527ca83ef 100644 --- a/transformer_engine/pytorch/setup.py +++ b/transformer_engine/pytorch/setup.py @@ -56,7 +56,7 @@ description="Transformer acceleration library - Torch Lib", ext_modules=ext_modules, cmdclass={"build_ext": CMakeBuildExtension}, - install_requires=["torch", "flash-attn>=2.0.6,<=2.6.3,!=2.0.9,!=2.1.0"], + install_requires=["torch"], tests_require=["numpy", "onnxruntime", "torchvision"], ) if any(x in sys.argv for x in (".", "sdist", "bdist_wheel")): diff --git a/transformer_engine/pytorch/te_onnx_extensions.py b/transformer_engine/pytorch/te_onnx_extensions.py index 0fa9401163..9b4b2df145 100755 --- a/transformer_engine/pytorch/te_onnx_extensions.py +++ b/transformer_engine/pytorch/te_onnx_extensions.py @@ -146,89 +146,136 @@ def onnx_cast_from_fp8(g, inputs, scale_inv, fp8_tensor, itype, otype): @symbolic_helper.parse_args("v", "fs", "v", "v", "i", "i") -def onnx_fp8_gelu(g, inputs, scale, amax, scale_inv, fp8_tensor, otype): +def onnx_fp8_gelu(g, inp, scale, amax, scale_inv, fp8_tensor, otype): """ONNX graph for fp8_gelu""" # pylint: disable=unused-argument # TE computes GELU using float32 precision so wrap the GELU subgraph with # conversion to/from float32. - gelu = compute_in_fp32(g, inputs, torch.onnx.symbolic_opset9.gelu, "tanh") + dtype = get_TensorProtoDataType(inp) + if dtype != _type_utils.JitScalarType.FLOAT: + inp = g.op("Cast", inp, to_i=_C_onnx.TensorProtoDataType.FLOAT) + out = torch.onnx.symbolic_opset9.gelu(g, inp, "tanh") if scale: - gelu = quantize(g, gelu, scale, fp8_tensor) - return gelu + out = quantize(g, out, scale, fp8_tensor) + elif dtype != _type_utils.JitScalarType.FLOAT: + out = g.op("Cast", out, to_i=dtype) + return out @symbolic_helper.parse_args("v", "fs", "v", "v", "i", "i") -def onnx_fp8_relu(g, inputs, scale, amax, scale_inv, fp8_tensor, otype): +def onnx_fp8_relu(g, inp, scale, amax, scale_inv, fp8_tensor, otype): """ONNX graph for fp8_relu""" # pylint: disable=unused-argument - relu = compute_in_fp32(g, inputs, torch.onnx.symbolic_opset9.relu) + out = torch.onnx.symbolic_opset9.relu(g, inp) if scale: - relu = quantize(g, relu, scale, fp8_tensor) - return relu + out = quantize(g, out, scale, fp8_tensor) + return out @symbolic_helper.parse_args("v", "i") def onnx_swiglu(g: jit_utils.GraphContext, inp, dim): """ONNX graph for swiglu""" + + # Check dimensions dim_size = symbolic_helper._get_tensor_dim_size(inp, dim) if dim_size is not None: assert dim_size % 2 == 0 + # Perform compute in FP32 + dtype = get_TensorProtoDataType(inp) + if dtype != _type_utils.JitScalarType.FLOAT: + inp = g.op("Cast", inp, to_i=_C_onnx.TensorProtoDataType.FLOAT) first, second = g.op("Split", inp, axis_i=dim, outputs=2) - return g.op("Mul", g.op("Sigmoid", first), second) + out = g.op("Mul", g.op("Sigmoid", first), second) + if dtype != _type_utils.JitScalarType.FLOAT: + out = g.op("Cast", out, to_i=dtype) + return out @symbolic_helper.parse_args("v", "fs", "v", "v", "i", "i") -def onnx_fp8_swiglu(g, inputs, scale, amax, scale_inv, fp8_tensor, otype): +def onnx_fp8_swiglu(g, inp, scale, amax, scale_inv, fp8_tensor, otype): """ONNX graph for fp8_swiglu""" # pylint: disable=unused-argument - swiglu = compute_in_fp32(g, inputs, onnx_swiglu, 1) + dtype = get_TensorProtoDataType(inp) + if dtype != _type_utils.JitScalarType.FLOAT: + inp = g.op("Cast", inp, to_i=_C_onnx.TensorProtoDataType.FLOAT) + out = onnx_swiglu(g, inp, 1) if scale: - swiglu = quantize(g, swiglu, scale, fp8_tensor) - return swiglu + out = quantize(g, out, scale, fp8_tensor) + elif dtype != _type_utils.JitScalarType.FLOAT: + out = g.op("Cast", out, to_i=dtype) + return out @symbolic_helper.parse_args("v", "i") def onnx_reglu(g: jit_utils.GraphContext, inp, dim): """ONNX graph for reglu""" + + # Check dimensions dim_size = symbolic_helper._get_tensor_dim_size(inp, dim) if dim_size is not None: assert dim_size % 2 == 0 + # Perform compute in FP32 + dtype = get_TensorProtoDataType(inp) + if dtype != _type_utils.JitScalarType.FLOAT: + inp = g.op("Cast", inp, to_i=_C_onnx.TensorProtoDataType.FLOAT) first, second = g.op("Split", inp, axis_i=dim, outputs=2) - return g.op("Mul", g.op("Relu", first), second) + out = g.op("Mul", g.op("Relu", first), second) + if dtype != _type_utils.JitScalarType.FLOAT: + out = g.op("Cast", out, to_i=dtype) + return out @symbolic_helper.parse_args("v", "fs", "v", "v", "i", "i") -def onnx_fp8_reglu(g, inputs, scale, amax, scale_inv, fp8_tensor, otype): +def onnx_fp8_reglu(g, inp, scale, amax, scale_inv, fp8_tensor, otype): """ONNX graph for fp8_reglu""" # pylint: disable=unused-argument - reglu = compute_in_fp32(g, inputs, onnx_reglu, 1) + dtype = get_TensorProtoDataType(inp) + if dtype != _type_utils.JitScalarType.FLOAT: + inp = g.op("Cast", inp, to_i=_C_onnx.TensorProtoDataType.FLOAT) + out = onnx_reglu(g, inp, 1) if scale: - reglu = quantize(g, reglu, scale, fp8_tensor) - return reglu + out = quantize(g, out, scale, fp8_tensor) + elif dtype != _type_utils.JitScalarType.FLOAT: + out = g.op("Cast", out, to_i=dtype) + return out @symbolic_helper.parse_args("v", "i") def onnx_geglu(g: jit_utils.GraphContext, inp, dim): """ONNX graph for geglu""" + + # Check dimensions dim_size = symbolic_helper._get_tensor_dim_size(inp, dim) if dim_size is not None: assert dim_size % 2 == 0 + # Perform compute in FP32 + dtype = get_TensorProtoDataType(inp) + if dtype != _type_utils.JitScalarType.FLOAT: + inp = g.op("Cast", inp, to_i=_C_onnx.TensorProtoDataType.FLOAT) first, second = g.op("Split", inp, axis_i=dim, outputs=2) - first_gelu = torch.onnx.symbolic_opset9.gelu(g, first, "tanh") - return g.op("Mul", first_gelu, second) + first = torch.onnx.symbolic_opset9.gelu(g, first, "tanh") + out = g.op("Mul", first, second) + if dtype != _type_utils.JitScalarType.FLOAT: + out = g.op("Cast", out, to_i=dtype) + return out @symbolic_helper.parse_args("v", "fs", "v", "v", "i", "i") -def onnx_fp8_geglu(g, inputs, scale, amax, scale_inv, fp8_tensor, otype): +def onnx_fp8_geglu(g, inp, scale, amax, scale_inv, fp8_tensor, otype): """ONNX graph for fp8_geglu""" # pylint: disable=unused-argument - geglu = compute_in_fp32(g, inputs, onnx_geglu, 1) + dtype = get_TensorProtoDataType(inp) + if dtype != _type_utils.JitScalarType.FLOAT: + inp = g.op("Cast", inp, to_i=_C_onnx.TensorProtoDataType.FLOAT) + out = onnx_geglu(g, inp, 1) if scale: - geglu = quantize(g, geglu, scale, fp8_tensor) - return geglu + out = quantize(g, out, scale, fp8_tensor) + elif dtype != _type_utils.JitScalarType.FLOAT: + out = g.op("Cast", out, to_i=dtype) + return out @symbolic_helper.parse_args( @@ -394,7 +441,7 @@ def onnx_layernorm_fwd(g, inputs, weight, bias, eps, sm_margin, zero_centered_ga @symbolic_helper.parse_args("v", "v", "f", "fs", "v", "v", "i", "i", "i", "b") def onnx_rmsnorm_fwd_fp8( g, - inputs, + inp, weight, eps, scale, @@ -407,50 +454,54 @@ def onnx_rmsnorm_fwd_fp8( ): """ONNX graph for rmsnorm_fwd_fp8""" # pylint: disable=unused-argument - inp_dtype = get_TensorProtoDataType(inputs) - - if inp_dtype != get_TensorProtoDataType(weight): - weight = g.op("Cast", weight, to_i=inp_dtype) - - ln = onnx_rmsnorm_fwd(g, inputs, weight, eps, sm_margin, zero_centered_gamma) - fp8_ln = quantize(g, ln, scale, fp8_tensor) - return fp8_ln + dtype = get_TensorProtoDataType(inp) + if dtype != _type_utils.JitScalarType.FLOAT: + inp = g.op("Cast", inp, to_i=_C_onnx.TensorProtoDataType.FLOAT) + out = onnx_rmsnorm_fwd(g, inp, weight, eps, sm_margin, zero_centered_gamma) + out = quantize(g, out, scale, fp8_tensor) + return out @symbolic_helper.parse_args("v", "v", "f", "i", "b") -def onnx_rmsnorm_fwd(g, inputs, weight, eps, sm_margin, zero_centered_gamma): +def onnx_rmsnorm_fwd(g, inp, weight, eps, sm_margin, zero_centered_gamma): """ONNX graph for rmsnorm_fwd""" # pylint: disable=unused-argument - normalized_shape = torch.onnx.symbolic_helper._get_tensor_sizes(inputs) + # Check dimensions + normalized_shape = torch.onnx.symbolic_helper._get_tensor_sizes(inp) if normalized_shape is None: - ndim = torch.onnx.symbolic_helper._get_tensor_rank(inputs) + ndim = torch.onnx.symbolic_helper._get_tensor_rank(inp) assert ndim is not None normalized_shape = list(range(0, ndim)) # Normalization axis = 0, so normalized_shape uses all dims except dim = 0 normalized_shape = normalized_shape[1:] + axis = -len(normalized_shape) + + # Cast input tensors to FP32 if needed + dtype = get_TensorProtoDataType(inp) + if dtype != _type_utils.JitScalarType.FLOAT: + inp = g.op("Cast", inp, to_i=_C_onnx.TensorProtoDataType.FLOAT) + if get_TensorProtoDataType(weight) != _type_utils.JitScalarType.FLOAT: + weight = g.op("Cast", weight, to_i=_C_onnx.TensorProtoDataType.FLOAT) + # Adjust zero-centered weights if zero_centered_gamma: - inputs_dtype = inputs.type().dtype() - one = _ones_like(g, weight, inputs_dtype) + one = _ones_like(g, weight, torch.float32) weight = g.op("Add", weight, one) - axis = -len(normalized_shape) - - inputs_float = g.op("Cast", inputs, to_i=_C_onnx.TensorProtoDataType.FLOAT) - - sum_square = g.op("ReduceSumSquare", inputs_float, axes_i=[axis]) - shape = g.op("Shape", inputs_float, start_i=-1) + # Perform compute in FP32 + sum_square = g.op("ReduceSumSquare", inp, axes_i=[axis]) + shape = g.op("Shape", inp, start_i=-1) shape_f = g.op("Cast", shape, to_i=_C_onnx.TensorProtoDataType.FLOAT) mean_squared = g.op("Div", sum_square, shape_f) eps_tensor = g.op("ConstantOfShape", shape, value_t=torch.tensor([eps], dtype=torch.float32)) rms_squared = g.op("Add", mean_squared, eps_tensor) rms_eps = g.op("Sqrt", rms_squared) - normalized_input = g.op("Div", inputs_float, rms_eps) - result = g.op("Mul", weight, normalized_input) - result = g.op("Cast", result, to_i=get_TensorProtoDataType(inputs)) - - return result + normalized_input = g.op("Div", inp, rms_eps) + out = g.op("Mul", weight, normalized_input) + if dtype != _type_utils.JitScalarType.FLOAT: + out = g.op("Cast", out, to_i=dtype) + return out register_custom_op_symbolic("tex_ts::cast_to_fp8_ts", onnx_cast_to_fp8, VER) diff --git a/transformer_engine/pytorch/tensor/__init__.py b/transformer_engine/pytorch/tensor/__init__.py index 2bad862768..16b7f8b623 100644 --- a/transformer_engine/pytorch/tensor/__init__.py +++ b/transformer_engine/pytorch/tensor/__init__.py @@ -4,5 +4,44 @@ """Custom tensor classes""" +import torch + from .float8_tensor import Float8Tensor from .quantized_tensor import QuantizedTensor + +__all__ = ["Float8Tensor", "QuantizedTensor"] + + +def _make_module_cast_func(dtype): + """Make module cast function that can handle QuantizedTensor""" + cast_func_name = { + torch.float32: "float", + torch.float16: "half", + torch.bfloat16: "bfloat16", + }[dtype] + + def tensor_cast_func(tensor: torch.Tensor) -> torch.Tensor: + """Cast tensor dtype""" + if isinstance(tensor, Float8Tensor): + return Float8Tensor.make_like( + tensor, + data=tensor._data, + fp8_attrs=tensor._fp8_attrs, + dtype=dtype, + requires_grad=tensor.requires_grad, + ) + if tensor.is_floating_point(): + return getattr(tensor, cast_func_name)() + return tensor + + def module_cast_func(self: torch.nn.Module) -> torch.nn.Module: + """Cast module dtype""" + return self._apply(tensor_cast_func) + + return module_cast_func + + +# Monkey-patch module cast functions to handle QuantizedTensor +torch.nn.Module.float = _make_module_cast_func(torch.float32) +torch.nn.Module.half = _make_module_cast_func(torch.float16) +torch.nn.Module.bfloat16 = _make_module_cast_func(torch.bfloat16) diff --git a/transformer_engine/pytorch/tensor/float8_tensor.py b/transformer_engine/pytorch/tensor/float8_tensor.py index 610523a10d..110059d745 100644 --- a/transformer_engine/pytorch/tensor/float8_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_tensor.py @@ -49,7 +49,7 @@ def set_func(self, value: Any) -> None: def del_func(self) -> None: del self._fp8_attrs[name] - return dict(fget=get_func, fset=set_func, fdel=del_func) + return {"fget": get_func, "fset": set_func, "fdel": del_func} class _FromFloat8Func(torch.autograd.Function): @@ -61,6 +61,7 @@ def forward( tensor: Float8Tensor, dtype: Optional[torch.dtype] = None, ) -> torch.Tensor: + # pylint: disable=missing-function-docstring return tensor.dequantize(dtype=dtype) @staticmethod @@ -68,6 +69,7 @@ def backward( _ctx: torch.autograd.function.FunctionCtx, # unused grad: torch.Tensor, ) -> Tuple[Optional[torch.Tensor], ...]: + # pylint: disable=missing-function-docstring # Assume that we want gradients in full precision return grad, None @@ -112,6 +114,7 @@ def forward( scale_inv: Optional[torch.Tensor] = None, with_transpose_cache: bool = False, ) -> Float8Tensor: + # pylint: disable=missing-function-docstring # Tensor attributes dtype = tensor.dtype @@ -126,12 +129,9 @@ def forward( # Check scale if scale is None and fp8_meta is None: - scale = 1 + scale = torch.full([1], 1, dtype=torch.float32, device=device) if scale is not None: - if isinstance(scale, torch.Tensor): - scale = scale.to(device=device, dtype=torch.float32) - else: - scale = torch.full([1], scale, dtype=torch.float32, device=device) + scale = scale.to(device=device, dtype=torch.float32) # Check scale-inverse if scale_inv is None: @@ -170,6 +170,7 @@ def backward( _ctx: torch.autograd.function.FunctionCtx, # unused grad: torch.Tensor, ) -> Tuple[Optional[torch.Tensor], ...]: + # pylint: disable=missing-function-docstring # Assume that we want gradients in full precision return grad, None, None, None, None, None, None, None @@ -188,6 +189,7 @@ def forward( tensor: Float8Tensor, init_kwargs: Optional[Dict[str, Any]] = None, ) -> torch.Tensor: + # pylint: disable=missing-function-docstring # Return input tensor if constructor kwargs are not provided ctx.input_dtype = tensor.dtype @@ -195,15 +197,15 @@ def forward( return tensor # Construct new tensor if constructor kwargs are provided - default_kwargs = dict( - data=tensor._data, - fp8_meta=tensor._fp8_meta, - fp8_meta_forward=tensor._fp8_meta_forward, - fp8_meta_index=tensor._fp8_meta_index, - fp8_dtype=tensor._fp8_dtype, - fp8_scale_inv=tensor._scale_inv, - dtype=tensor.dtype, - ) + default_kwargs = { + "data": tensor._data, + "fp8_meta": tensor._fp8_meta, + "fp8_meta_forward": tensor._fp8_meta_forward, + "fp8_meta_index": tensor._fp8_meta_index, + "fp8_dtype": tensor._fp8_dtype, + "fp8_scale_inv": tensor._scale_inv, + "dtype": tensor.dtype, + } for key, val in default_kwargs.items(): if key not in init_kwargs: init_kwargs[key] = val @@ -211,6 +213,7 @@ def forward( @staticmethod def backward(ctx, grad): + # pylint: disable=missing-function-docstring return grad.to(ctx.input_dtype), None @@ -227,6 +230,7 @@ def forward( tensor: torch.Tensor, shape: Tuple[int] = None, ) -> torch.Tensor: + # pylint: disable=missing-function-docstring # Return input tensor if shape is not provided ctx.shape = tensor.shape @@ -246,6 +250,7 @@ def backward( ctx, grad: torch.Tensor, ) -> Tuple[Optional[torch.Tensor], ...]: + # pylint: disable=missing-function-docstring if isinstance(grad, Float8Tensor): dgrad = Float8Tensor.make_like( @@ -269,6 +274,7 @@ def forward( tensor: torch.Tensor, shape: Tuple[int] = None, ) -> torch.Tensor: + # pylint: disable=missing-function-docstring # Return input tensor if shape is not provided ctx.shape = tensor.shape @@ -288,6 +294,7 @@ def backward( ctx, grad: torch.Tensor, ) -> Tuple[Optional[torch.Tensor], ...]: + # pylint: disable=missing-function-docstring if isinstance(grad, Float8Tensor): dgrad = Float8Tensor.make_like( @@ -335,6 +342,18 @@ class Float8Tensor(QuantizedTensor): """ + _data: torch.Tensor + _fp8_attrs: Dict[str, Any] + _fp8_meta: Optional[Dict[str, Any]] + _fp8_meta_forward: bool + _fp8_meta_index: Optional[int] + _fp8_dtype: TE_DType + _scale_inv: torch.Tensor + + # FP8 transpose cache + _transpose: Optional[torch.Tensor] + _transpose_invalid: bool + def __new__( cls, *, @@ -346,6 +365,7 @@ def __new__( fp8_dtype: TE_DType = TE_DType.kFloat8E4M3, fp8_scale_inv: Optional[torch.Tensor] = None, dtype: torch.dtype = torch.float32, + requires_grad: bool = False, data_transpose: Optional[torch.Tensor] = None, ): @@ -367,16 +387,15 @@ def __new__( storage_offset=data.storage_offset(), dtype=dtype, layout=data.layout, - requires_grad=data.requires_grad, + requires_grad=requires_grad, device=data.device, ) - self._data: torch.Tensor = data + self._data = data # Initialize dict of class attributes # Note: We store FP8 attributes in a dictionary so we can # share them between tensors with the same data, e.g. detached # tensors. - self._fp8_attrs: dict if fp8_attrs is None: self._fp8_attrs = {} else: @@ -389,16 +408,16 @@ def __new__( "To initialize Float8Tensor with FP8 meta tensors, " "the FP8 meta tensor index must also be provided" ) - self._fp8_meta: Optional[Dict[str, Any]] = fp8_meta - self._fp8_meta_forward: bool = fp8_meta_forward - self._fp8_meta_index: Optional[int] = fp8_meta_index + self._fp8_meta = fp8_meta + self._fp8_meta_forward = fp8_meta_forward + self._fp8_meta_index = fp8_meta_index # FP8 dtype assert fp8_dtype in ( TE_DType.kFloat8E4M3, TE_DType.kFloat8E5M2, ), f"Unsupported fp8_dtype {fp8_dtype}." - self._fp8_dtype: TE_DType = fp8_dtype + self._fp8_dtype = fp8_dtype # FP8 scale-inverse if fp8_scale_inv is None and self._fp8_meta is not None: @@ -411,13 +430,6 @@ def __new__( raise ValueError( "Attempted to initialize Float8Tensor without specifying scale-inverse" ) - if not isinstance(fp8_scale_inv, torch.Tensor): - fp8_scale_inv = torch.full( - [1], - fp8_scale_inv, - dtype=torch.float32, - device=self._data.device, - ) if fp8_scale_inv.numel() != 1: raise ValueError( "Attempted to initialize Float8Tensor with invalid scale-inverse tensor" @@ -432,11 +444,11 @@ def __new__( device=self._data.device, dtype=torch.float32, ) - self._scale_inv: Optional[torch.Tensor] = fp8_scale_inv + self._scale_inv = fp8_scale_inv # FP8 transpose cache - self._transpose: Optional[Float8Tensor] = data_transpose - self._transpose_invalid: bool = self._transpose is None + self._transpose = data_transpose + self._transpose_invalid = self._transpose is None return self @@ -454,14 +466,14 @@ def make_like( See constructor for list of keyword arguments. """ - default_kwargs = dict( - fp8_meta=tensor._fp8_meta, - fp8_meta_forward=tensor._fp8_meta_forward, - fp8_meta_index=tensor._fp8_meta_index, - fp8_dtype=tensor._fp8_dtype, - fp8_scale_inv=tensor._scale_inv, - dtype=tensor.dtype, - ) + default_kwargs = { + "fp8_meta": tensor._fp8_meta, + "fp8_meta_forward": tensor._fp8_meta_forward, + "fp8_meta_index": tensor._fp8_meta_index, + "fp8_dtype": tensor._fp8_dtype, + "fp8_scale_inv": tensor._scale_inv, + "dtype": tensor.dtype, + } for key, val in default_kwargs.items(): if key not in kwargs: kwargs[key] = val @@ -476,7 +488,7 @@ def __repr__(self): ")" ) - def dequantize(self, dtype: Optional[torch.dtype] = None) -> torch.Tensor: + def dequantize(self, *, dtype: Optional[torch.dtype] = None) -> torch.Tensor: # Convert PyTorch dtype to TE dtype if dtype is None: @@ -602,11 +614,8 @@ def quantize_( # Make sure FP8 scaling factors are in expected format if scale is not None: - if isinstance(scale, torch.Tensor): - if not devices_match(scale.device, dst.device) or scale.dtype != torch.float32: - scale = scale.to(device=dst.device, dtype=torch.float32) - else: - scale = torch.full([1], scale, dtype=torch.float32, device=dst.device) + if not devices_match(scale.device, dst.device) or scale.dtype != torch.float32: + scale = scale.to(device=dst.device, dtype=torch.float32) if amax is not None: while amax.dim() < 2: amax = amax.unsqueeze(0) @@ -698,6 +707,7 @@ def to_float8( ) def detach(self) -> Float8Tensor: + # pylint: disable=missing-function-docstring return Float8Tensor.make_like( self, data=self._data, @@ -705,22 +715,25 @@ def detach(self) -> Float8Tensor: ) def clone(self) -> Float8Tensor: + # pylint: disable=missing-function-docstring data = self._data.detach().clone() data_transpose = None if self._transpose is not None: data_transpose = self._transpose.detach().clone() return _IdentityFunc.apply( self, - dict( - data=data, - data_transpose=data_transpose, - ), + { + "data": data, + "data_transpose": data_transpose, + }, ) def view(self, *shape: Tuple[int]) -> Float8Tensor: + # pylint: disable=missing-function-docstring return _ViewFunc.apply(self, shape) def reshape(self, *shape: Tuple[int]) -> Float8Tensor: + # pylint: disable=missing-function-docstring return _ReshapeFunc.apply(self, shape) def contiguous( @@ -780,23 +793,21 @@ def transpose_2d( fill_cache = False # Need to compute transpose if cache is invalid - need_compute = force_compute - if self._transpose is None: - need_compute = True - elif self._transpose_invalid: - need_compute = True - - # Need to apply transpose kernel if noop flag is applied - if noop_flag is not None: - need_compute = True + need_compute = ( + force_compute + or (self._transpose is None) + or self._transpose_invalid + or (noop_flag is not None) + ) # Return cached transpose if possible if not need_compute: + assert self._transpose is not None return self._transpose # Allocate output if needed data = self._data.contiguous().reshape(-1, self.size(-1)) - out = self._transpose + out: Optional[torch.Tensor] = self._transpose if out is None: out = torch.empty( (data.size(1), data.size(0)), @@ -947,14 +958,83 @@ def _get_data(self) -> Float8Tensor: """Get tensor data property""" return super().data + @torch.no_grad() def _set_data(self, tensor: torch.Tensor) -> None: """Set tensor data property - Cast tensor to FP8 and store in FP8 buffer. + Just takes FP8 data if setting from a Float8Tensor. Otherwise + casts to FP8. """ - with torch.no_grad(): - self.copy_(tensor) + + # Tensor device + new_device = tensor.device if tensor.is_cuda else self.device + + # Check whether grad is required + if self.requires_grad != tensor.requires_grad: + self.requires_grad_(requires_grad=tensor.requires_grad) + + # Just copy FP8 data if other tensor is Float8Tensor + if isinstance(tensor, Float8Tensor): + if ( # pylint: disable=too-many-boolean-expressions + self.size() != tensor.size() + or self.stride() != tensor.stride() + or self.storage_offset() != tensor.storage_offset() + or self.dtype != tensor.dtype + or self.layout != tensor.layout + or not devices_match(self.device, new_device) + ): + dummy_tensor = torch.Tensor._make_wrapper_subclass( + Float8Tensor, + tensor.size(), + strides=tensor.stride(), + storage_offset=tensor.storage_offset(), + dtype=tensor.dtype, + layout=tensor.layout, + requires_grad=tensor.requires_grad, + device=new_device, + ) + # pylint: disable=unnecessary-dunder-call + super(Float8Tensor, type(self)).data.__set__(self, dummy_tensor) + self._data = tensor._data + self._fp8_attrs = tensor._fp8_attrs + return + + # Reallocate FP8 data if needed + if ( + self.size() != tensor.size() + or self.stride() != tensor.stride() + or self.dtype != tensor.dtype + or self.layout != tensor.layout + or not devices_match(self.device, new_device) + ): + self._data = torch.empty_like( + tensor, + dtype=torch.uint8, + device=new_device, + ) + dummy_tensor = torch.Tensor._make_wrapper_subclass( + Float8Tensor, + self._data.size(), + strides=self._data.stride(), + storage_offset=self._data.storage_offset(), + dtype=tensor.dtype, + layout=self._data.layout, + requires_grad=tensor.requires_grad, + device=self._data.device, + ) + # pylint: disable=unnecessary-dunder-call + super(Float8Tensor, type(self)).data.__set__(self, dummy_tensor) + if self._transpose is not None: + self._transpose = torch.empty( + (self._data.size(-1), self._data.numel() // self._data.size(-1)), + dtype=torch.uint8, + device=self.device, + ) + self._transpose_invalid = True + + # Copy values from other tensor + self.quantize_(tensor) # Cast to FP8 when setting Float8Tensor.data data = property(_get_data, _set_data) diff --git a/transformer_engine/pytorch/tensor/quantized_tensor.py b/transformer_engine/pytorch/tensor/quantized_tensor.py index f890b0878a..92c95b56ca 100644 --- a/transformer_engine/pytorch/tensor/quantized_tensor.py +++ b/transformer_engine/pytorch/tensor/quantized_tensor.py @@ -20,6 +20,7 @@ def forward( tensor: QuantizedTensor, dtype: Optional[torch.dtype] = None, ) -> torch.Tensor: + # pylint: disable=missing-function-docstring return tensor.dequantize(dtype=dtype) @staticmethod @@ -27,6 +28,7 @@ def backward( _ctx: torch.autograd.function.FunctionCtx, # unused grad: torch.Tensor, ) -> Tuple[Optional[torch.Tensor], ...]: + # pylint: disable=missing-function-docstring return grad, None @@ -38,6 +40,7 @@ def forward( _ctx: torch.autograd.function.FunctionCtx, # unused tensor: QuantizedTensor, ) -> QuantizedTensor: + # pylint: disable=missing-function-docstring return tensor.detach() @staticmethod @@ -45,6 +48,7 @@ def backward( _ctx: torch.autograd.function.FunctionCtx, # unused grad: torch.Tensor, ) -> torch.Tensor: + # pylint: disable=missing-function-docstring return grad @@ -85,18 +89,23 @@ def __repr__(self) -> str: return f"{self.__class__.__name__}(data={self.dequantize(dtype=self.dtype)})" def float(self) -> torch.Tensor: + # pylint: disable=missing-function-docstring return _DequantizeFunc.apply(self, torch.float32) def bfloat16(self) -> torch.Tensor: + # pylint: disable=missing-function-docstring return _DequantizeFunc.apply(self, torch.bfloat16) def half(self) -> torch.Tensor: + # pylint: disable=missing-function-docstring return _DequantizeFunc.apply(self, torch.float16) def cpu(self) -> torch.Tensor: + # pylint: disable=missing-function-docstring return _DequantizeFunc.apply(self).cpu() def expand_as(self, other: torch.Tensor) -> torch.Tensor: + # pylint: disable=missing-function-docstring if other is self: # Note: expand_as is hackily used to create dummy autograd nodes # and access the backward graph (see diff --git a/transformer_engine/pytorch/transformer.py b/transformer_engine/pytorch/transformer.py index 020d262be2..ad5476450b 100644 --- a/transformer_engine/pytorch/transformer.py +++ b/transformer_engine/pytorch/transformer.py @@ -484,7 +484,7 @@ def reset_fp8_meta_tensors(self) -> None: def set_context_parallel_group( self, - cp_group: Union[dist_group_type, None], + cp_group: Union[dist_group_type, List[dist_group_type], None], cp_global_ranks: List[int], cp_stream: torch.cuda.Stream, cp_comm_type: str = "p2p", @@ -495,21 +495,27 @@ def set_context_parallel_group( Parameters ---------- - cp_group : ProcessGroup + cp_group : Union[ProcessGroup, List[ProcessGroup]] context parallel process group. + ProcessGroup is for cp_comm_type of "p2p", "all_gather", and "a2a". + List[ProcessGroup] is for cp_comm_type of "a2a+p2p", where cp_group[0] + and cp_group[1] are for a2a and p2p communications respectively. cp_global_ranks : List[int] list of global ranks in the context group. cp_stream : torch.cuda.Stream cuda stream for context parallel execution. - cp_comm_type : str + cp_comm_type : str, default = `p2p` inter-gpu communication type for context parallelism. - Can be "p2p" or "all_gather" or "a2a". + Can be "p2p" or "all_gather" or "a2a", or "a2a+p2p". "p2p": Exchange KV chunks with P2P communications in ring topology. P2P is async and can be overlapped with attention compute. "all_gather": All-gather to get full sequence of KV before attention. The all-gather is not async, and cannot be overlapped. "a2a": Like DeepSpeed Ulysses, scatter attention heads across the CP group, and gather to get full sequence of QKV. + "a2a+p2p": hierarchical CP implementation. First applying a2a to QKV + across each CP sub-group (e.g., via NVLink), then exchanging KV with + p2p between sub-groups (e.g., via IBLink). """ # Deep iterate but skip self to avoid infinite recursion. for index, child in enumerate(self.modules()): @@ -751,7 +757,7 @@ def forward( return output def _bias_dropout_add(self, hidden_state, bias, residual, drop_path=None): - if drop_path is None and bias.numel() != 0: + if drop_path is None and bias is not None and bias.numel() != 0: if self.bias_dropout_fusion: if self.training: bias_dropout_add_func = bias_dropout_add_fused_train @@ -763,7 +769,7 @@ def _bias_dropout_add(self, hidden_state, bias, residual, drop_path=None): with self.bias_dropout_add_exec_handler(): output = bias_dropout_add_func(hidden_state, bias, residual, self.hidden_dropout) else: - if bias.numel() != 0: + if bias is not None and bias.numel() != 0: hidden_state = hidden_state + bias out = torch.nn.functional.dropout( hidden_state, p=self.hidden_dropout, training=self.training diff --git a/transformer_engine/pytorch/utils.py b/transformer_engine/pytorch/utils.py index d5145455b8..947c642c2c 100644 --- a/transformer_engine/pytorch/utils.py +++ b/transformer_engine/pytorch/utils.py @@ -218,8 +218,12 @@ def safely_set_viewless_tensor_data(tensor: torch.Tensor, new_data_tensor: torch def cast_if_needed(tensor: torch.Tensor, dtype: torch.dtype) -> torch.Tensor: """Cast tensor to dtype""" + if tensor is None: + return None + if tensor.dtype == dtype: + return tensor with torch.enable_grad(): - return tensor if tensor is None or tensor.dtype == dtype else tensor.to(dtype) + return tensor.to(dtype=dtype) def check_dim_for_fp8_exec(tensor: torch.Tensor) -> bool: