From 29e9e278a331eeba8b44d84a718a4bc49a50dd5b Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Thu, 20 Oct 2022 11:23:53 -0700 Subject: [PATCH 01/68] Change version to 0.2.0 Signed-off-by: Przemek Tredak --- VERSION | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/VERSION b/VERSION index 3f69f33a74..0ea3a944b3 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -0.2.0dev +0.2.0 From 73166c4e3f6cf0e754045ba22ff461ef96453aeb Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Wed, 23 Nov 2022 09:45:36 -0800 Subject: [PATCH 02/68] Full activation recompute checkpointing bug fix (#31) fix checkpoint loading bug for FAR Signed-off-by: Kirthi Shankar Sivamani Signed-off-by: Kirthi Shankar Sivamani --- transformer_engine/pytorch/fp8.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/transformer_engine/pytorch/fp8.py b/transformer_engine/pytorch/fp8.py index e25a413d4f..8fafdafa3e 100644 --- a/transformer_engine/pytorch/fp8.py +++ b/transformer_engine/pytorch/fp8.py @@ -69,13 +69,13 @@ def get_global_fp8_recompute_buffer() -> Dict[str, List[torch.Tensor]]: return _fp8_tensors_recompute_buffer -def set_global_fp8_recompute_buffer(buffer: List[Deque[torch.Tensor]]) -> None: +def set_global_fp8_recompute_buffer(buffer: List[Deque[List[torch.Tensor]]]) -> None: """Sets global fp8 recompute buffer.""" global _fp8_tensors_recompute_buffer # Map all tensors back to GPU. for index, deck in enumerate(buffer): - buffer[index] = deque([tensor.cuda() for tensor in deck]) + buffer[index] = deque([[t.cuda() for t in tensors] for tensors in deck]) _fp8_tensors_recompute_buffer = buffer @@ -118,11 +118,11 @@ def copy_forward_fp8_meta_tensors_for_recompute(fp8_meta: Dict[str, Any]) -> Non global _fp8_tensors_recompute_buffer buffer_position_key = "global_fp8_buffer_pos_fwd_recompute" - to_copy = ( + to_copy = [ fp8_meta["scaling_fwd"].amax_history.clone(), fp8_meta["scaling_fwd"].scale.clone(), fp8_meta["scaling_fwd"].scale_inv.clone(), - ) + ] if buffer_position_key in fp8_meta: _fp8_tensors_recompute_buffer[fp8_meta[buffer_position_key]].append(to_copy) From 126232df4e87cea7a46278ebb23f47397315d0c0 Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Tue, 31 Jan 2023 10:09:48 -0800 Subject: [PATCH 03/68] Address steady memory increase and bloated checkpoints (#63) Signed-off-by: Kirthi Shankar Sivamani --- transformer_engine/pytorch/fp8.py | 18 +----------------- transformer_engine/pytorch/module.py | 8 +++----- 2 files changed, 4 insertions(+), 22 deletions(-) diff --git a/transformer_engine/pytorch/fp8.py b/transformer_engine/pytorch/fp8.py index fd05358a93..e4cce98931 100644 --- a/transformer_engine/pytorch/fp8.py +++ b/transformer_engine/pytorch/fp8.py @@ -5,7 +5,7 @@ """FP8 utilies for TransformerEngine""" from contextlib import contextmanager from collections import deque -from typing import Callable, List, Optional, Dict, Any, Tuple, Union, Deque +from typing import Callable, List, Optional, Dict, Any, Tuple, Union import torch import transformer_engine_extensions as tex @@ -64,22 +64,6 @@ def set_global_fp8_buffer(buffer: Dict[str, List[torch.Tensor]]) -> None: _global_fp8_buffer = buffer -def get_global_fp8_recompute_buffer() -> Dict[str, List[torch.Tensor]]: - """Returns global fp8 recompute buffer.""" - return _fp8_tensors_recompute_buffer - - -def set_global_fp8_recompute_buffer(buffer: List[Deque[List[torch.Tensor]]]) -> None: - """Sets global fp8 recompute buffer.""" - global _fp8_tensors_recompute_buffer - - # Map all tensors back to GPU. 
- for index, deck in enumerate(buffer): - buffer[index] = deque([[t.cuda() for t in tensors] for tensors in deck]) - - _fp8_tensors_recompute_buffer = buffer - - def setup_amax_forward_global_reduce_func(f: Callable) -> None: """Sets up the function to call during autocast exit.""" global _amax_forward_global_reduce_func diff --git a/transformer_engine/pytorch/module.py b/transformer_engine/pytorch/module.py index 0a6cae3b4a..ada798c374 100644 --- a/transformer_engine/pytorch/module.py +++ b/transformer_engine/pytorch/module.py @@ -32,8 +32,6 @@ amax_and_scale_update, get_global_fp8_buffer, set_global_fp8_buffer, - get_global_fp8_recompute_buffer, - set_global_fp8_recompute_buffer, set_amax_buffer_key_deletion, delete_key_from_amax_buffer, copy_forward_fp8_meta_tensors_for_recompute, @@ -201,7 +199,6 @@ def get_extra_state(self) -> Union[List[Any], None]: state["scale_bwd"] = self.fp8_meta["scaling_bwd"].scale state["amax_history_bwd"] = self.fp8_meta["scaling_bwd"].amax_history state["global_fp8_buffer"] = get_global_fp8_buffer() - state["global_fp8_recompute_buffer"] = get_global_fp8_recompute_buffer() # Store other pickelable values. extra = {} @@ -254,11 +251,11 @@ def set_extra_state(self, state: Union[List[Any], None]) -> None: # Restore global FP8 buffer states. set_global_fp8_buffer(state["global_fp8_buffer"]) - set_global_fp8_recompute_buffer(state["global_fp8_recompute_buffer"]) - # Load extra items. self.fp8_meta.update(state["extra_fp8_variables"]) self.fp8_meta["recipe"].amax_history_len = state["amax_history_fwd"].shape[0] + if "global_fp8_buffer_pos_fwd_recompute" in self.fp8_meta: + del self.fp8_meta["global_fp8_buffer_pos_fwd_recompute"] # Initialize before loading. self.init_fp8_meta_tensors() @@ -433,6 +430,7 @@ def prepare_forward( # Activation recomputation is used and this is the first forward phase. if ( self.fp8 + and self.training and is_fp8_activation_recompute_enabled() and not in_fp8_activation_recompute_phase() ): From ce58fc2fe786776fef43fcf1a3bb1baaf09ee03a Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Tue, 28 Feb 2023 23:05:43 -0800 Subject: [PATCH 04/68] 3rd party acknowledgements (#82) add 3rd party acknowledgements Signed-off-by: Kirthi Shankar Sivamani --- Acknowledgements.txt | 140 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 140 insertions(+) create mode 100644 Acknowledgements.txt diff --git a/Acknowledgements.txt b/Acknowledgements.txt new file mode 100644 index 0000000000..7eec81a9ce --- /dev/null +++ b/Acknowledgements.txt @@ -0,0 +1,140 @@ +This software includes third-party components under the following licenses: + +======================== +GoogleTest + +Copyright 2008, Google Inc. +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + + * Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above +copyright notice, this list of conditions and the following disclaimer +in the documentation and/or other materials provided with the +distribution. + * Neither the name of Google Inc. nor the names of its +contributors may be used to endorse or promote products derived from +this software without specific prior written permission. 
+ +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +======================== +pybind11 + +Copyright (c) 2016 Wenzel Jakob , All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +3. Neither the name of the copyright holder nor the names of its contributors + may be used to endorse or promote products derived from this software + without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +Please also refer to the file CONTRIBUTING.md, which clarifies licensing of +external contributions to this project including patches, pull requests, etc. + +======================== +PyTorch + +Copyright (c) 2016- Facebook, Inc (Adam Paszke) +Copyright (c) 2014- Facebook, Inc (Soumith Chintala) +Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert) +Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu) +Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu) +Copyright (c) 2011-2013 NYU (Clement Farabet) +Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston) +Copyright (c) 2006 Idiap Research Institute (Samy Bengio) +Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz) + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + +2. 
Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + +3. Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories America + and IDIAP Research Institute nor the names of its contributors may be + used to endorse or promote products derived from this software without + specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +POSSIBILITY OF SUCH DAMAGE. + +======================== +FlashAttn + +Copyright (c) 2022, the respective contributors, as shown by the AUTHORS file. +All rights reserved. + +All contributions by Nvidia: +Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +* Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +* Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +* Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
From 4c358916450c74d03a882e1eda572dd380cfd527 Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Thu, 2 Mar 2023 10:56:33 -0800 Subject: [PATCH 05/68] Fix unfused QKV params case; stack vs interleave option (#83) * fix qkv weight unfused path Signed-off-by: Kirthi Shankar Sivamani * fix non FA non interleaved case Signed-off-by: Kirthi Shankar Sivamani --------- Signed-off-by: Kirthi Shankar Sivamani --- transformer_engine/pytorch/transformer.py | 79 +++++++++++++++++------ transformer_engine/pytorch/utils.py | 9 ++- 2 files changed, 63 insertions(+), 25 deletions(-) diff --git a/transformer_engine/pytorch/transformer.py b/transformer_engine/pytorch/transformer.py index c0989f9c93..046dda20b2 100644 --- a/transformer_engine/pytorch/transformer.py +++ b/transformer_engine/pytorch/transformer.py @@ -24,7 +24,7 @@ from transformer_engine.pytorch.utils import ( divide, attention_mask_func, - split_tensor_along_last_dim, + split_tensor_along_dim, cast_if_needed, get_default_init_method, ) @@ -126,11 +126,11 @@ def forward( ) # [sq, b, np, hn] -> [sq, b * np, hn] - query_layer = query_layer.view( + query_layer = query_layer.reshape( output_size[2], output_size[0] * output_size[1], -1 ) # [sk, b, np, hn] -> [sk, b * np, hn] - key_layer = key_layer.view(output_size[3], output_size[0] * output_size[1], -1) + key_layer = key_layer.reshape(output_size[3], output_size[0] * output_size[1], -1) # preallocting result tensor: [b * np, sq, sk] matmul_result = torch.empty( @@ -171,7 +171,7 @@ def forward( ) # change view [sk, b * np, hn] - value_layer = value_layer.view( + value_layer = value_layer.reshape( value_layer.size(0), output_size[0] * output_size[1], -1 ) @@ -504,6 +504,7 @@ def __init__( set_parallel_mode: bool = False, fuse_qkv_params: bool = False, zero_centered_gamma: bool = False, + qkv_weight_interleaved: bool = True, ) -> None: super().__init__() self.layer_number = (layer_number,) @@ -515,6 +516,10 @@ def __init__( self.params_dtype = params_dtype self.init_method = init_method + if not fuse_qkv_params: + qkv_weight_interleaved = False + self.qkv_weight_interleaved = qkv_weight_interleaved + assert ( attention_type in AttnTypes ), f"attention_type {attention_type} not supported" @@ -703,16 +708,28 @@ def forward( is_first_microbatch=is_first_microbatch, ) - # [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn] - new_tensor_shape = mixed_x_layer.size()[:-1] + ( - self.num_attention_heads_per_partition, - 3 * self.hidden_size_per_attention_head, - ) + if self.qkv_weight_interleaved: + # [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn] + new_tensor_shape = mixed_x_layer.size()[:-1] + ( + self.num_attention_heads_per_partition, + 3 * self.hidden_size_per_attention_head, + ) + # split along last dimension + split_dim = -1 + else: + # [sq, b, (np * 3 * hn)] --> [sq, b, 3 * np, hn] + new_tensor_shape = mixed_x_layer.size()[:-1] + ( + 3 * self.num_attention_heads_per_partition, + self.hidden_size_per_attention_head, + ) + # split along second last dimension + split_dim = -2 + mixed_x_layer = mixed_x_layer.view(*new_tensor_shape) - # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn] - query_layer, key_layer, value_layer = split_tensor_along_last_dim( - mixed_x_layer, 3 + # mixed_x_layer --> 3 [sq, b, np, hn] + query_layer, key_layer, value_layer = split_tensor_along_dim( + mixed_x_layer, split_dim, 3 ) else: # Attention heads [sk, b, h] --> [sk, b, (np * 2 * hn)] @@ -721,15 +738,27 @@ def forward( is_first_microbatch=is_first_microbatch, ) - # [sk, b, (np * 2 * hn)] --> [sk, b, np, 2 * hn] - 
new_tensor_shape = mixed_kv_layer.size()[:-1] + ( - self.num_attention_heads_per_partition, - 2 * self.hidden_size_per_attention_head, - ) + if self.qkv_weight_interleaved: + # [sq, b, (np * 2 * hn)] --> [sq, b, np, 2 * hn] + new_tensor_shape = mixed_kv_layer.size()[:-1] + ( + self.num_attention_heads_per_partition, + 2 * self.hidden_size_per_attention_head, + ) + # split along last dimension + split_dim = -1 + else: + # [sq, b, (np * 2 * hn)] --> [sq, b, 2 * np, hn] + new_tensor_shape = mixed_kv_layer.size()[:-1] + ( + 2 * self.num_attention_heads_per_partition, + self.hidden_size_per_attention_head, + ) + # split along second last dimension + split_dim = -2 + mixed_kv_layer = mixed_kv_layer.view(*new_tensor_shape) - # [sk, b, np, 2 * hn] --> 2 [sk, b, np, hn] - (key_layer, value_layer) = split_tensor_along_last_dim(mixed_kv_layer, 2) + # mixed_kv_layer --> 2 [sk, b, np, hn] + key_layer, value_layer = split_tensor_along_dim(mixed_kv_layer, split_dim, 2) # Attention head [sq, b, h] --> [sq, b, hp] if self.input_layernorm: @@ -863,7 +892,12 @@ class TransformerLayer(torch.nn.Module): .. math:: y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \varepsilon}} * (1 + \gamma) + \beta - + qkv_weight_interleaved : bool, default = `True` + if set to `False`, the QKV weight is interpreted as a concatenation of + query, key, and value weights along the `0th` dimension. The default + interpretation is that the individual `q`, `k`, and `v` weights for each + attention head are interleaved. This parameter is set to `False` when + using :attr:`fuse_qkv_params=False`. Parallelism parameters ---------------------- set_parallel_mode : bool, default = `False` @@ -938,6 +972,7 @@ def __init__( set_parallel_mode: bool = False, fuse_qkv_params: bool = False, zero_centered_gamma: bool = False, + qkv_weight_interleaved: bool = True, ) -> None: super().__init__() @@ -958,6 +993,9 @@ def __init__( not fuse_wgrad_accumulation ), "Gradient accumulation fusion requires single QKV parameter." + if not fuse_qkv_params: + qkv_weight_interleaved = False + self.kv_channels = ( kv_channels if kv_channels else (hidden_size // num_attention_heads) ) @@ -995,6 +1033,7 @@ def __init__( "set_parallel_mode": set_parallel_mode, "fuse_qkv_params": fuse_qkv_params, "zero_centered_gamma": zero_centered_gamma, + "qkv_weight_interleaved" : qkv_weight_interleaved, } self.self_attention = MultiHeadAttention( diff --git a/transformer_engine/pytorch/utils.py b/transformer_engine/pytorch/utils.py index a71891b8e9..9f1ddaa2b2 100644 --- a/transformer_engine/pytorch/utils.py +++ b/transformer_engine/pytorch/utils.py @@ -78,8 +78,8 @@ def divide(numerator: int, denominator: int) -> int: return numerator // denominator -def split_tensor_along_last_dim( - tensor: torch.Tensor, num_partitions: int, contiguous_split_chunks: bool = False +def split_tensor_along_dim( + tensor: torch.Tensor, dim: int, num_partitions: int, contiguous_split_chunks: bool = False ) -> Tuple[torch.Tensor, ...]: """Split a tensor along its last dimension. Arguments: @@ -89,10 +89,9 @@ def split_tensor_along_last_dim( in memory. """ # Get the size and dimension. - last_dim = tensor.dim() - 1 - last_dim_size = divide(tensor.size()[last_dim], num_partitions) + split_size = divide(tensor.size()[dim], num_partitions) # Split. - tensor_list = torch.split(tensor, last_dim_size, dim=last_dim) + tensor_list = torch.split(tensor, split_size, dim=dim) # Note: torch.split does not create contiguous tensors by default. 
if contiguous_split_chunks: return tuple(chunk.contiguous() for chunk in tensor_list) From bb1203894d4cf5007e00a8004bb1b10740cfbee5 Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Tue, 7 Mar 2023 09:26:18 -0800 Subject: [PATCH 06/68] Fix flash attention (#84) * ignore self attention mask for causal type Signed-off-by: Kirthi Shankar Sivamani * further relax checks to run FA, update docs Signed-off-by: Kirthi Shankar Sivamani * fix pytorch softmax path Signed-off-by: Kirthi Shankar Sivamani * fixes Signed-off-by: Kirthi Shankar Sivamani * minimum ampere requirement for fa Signed-off-by: Kirthi Shankar Sivamani --------- Signed-off-by: Kirthi Shankar Sivamani --- tests/test_onnx_export.py | 1 + transformer_engine/pytorch/softmax.py | 13 +++++++ transformer_engine/pytorch/transformer.py | 46 ++++++++++++++--------- transformer_engine/pytorch/utils.py | 7 ++++ 4 files changed, 49 insertions(+), 18 deletions(-) diff --git a/tests/test_onnx_export.py b/tests/test_onnx_export.py index f43899c33f..7d905612b4 100644 --- a/tests/test_onnx_export.py +++ b/tests/test_onnx_export.py @@ -793,6 +793,7 @@ def test_export_core_attention( if attn_mask_type is None: attn_mask_type = 'causal' + inp = (query_layer, key_layer, value_layer) model = te.transformer.DotProductAttention( num_attention_heads=num_attention_heads, kv_channels=kv_channels, diff --git a/transformer_engine/pytorch/softmax.py b/transformer_engine/pytorch/softmax.py index 8bdb3e1c82..775f3fedd9 100644 --- a/transformer_engine/pytorch/softmax.py +++ b/transformer_engine/pytorch/softmax.py @@ -16,6 +16,15 @@ THREADS_PER_BLOCK = 128 +_default_causal_mask = {} + +def _get_default_causal_mask(sq: int) -> torch.Tensor: + """Return the causal upper triangular mask for softmax input""" + if sq not in _default_causal_mask: + _default_causal_mask[sq] = torch.triu(torch.ones(sq, sq, device="cuda"), diagonal=1).bool() + return _default_causal_mask[sq] + + class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function): """ Fused operation which performs following three operations in sequence @@ -274,6 +283,10 @@ def forward_torch_softmax( if self.scale is not None: inp = inp * self.scale + + if self.attn_mask_type == "causal": + mask = _get_default_causal_mask(inp.size()[2]) + mask_output = self.mask_func(inp, mask) if mask is not None else inp probs = torch.nn.Softmax(dim=-1)(mask_output) diff --git a/transformer_engine/pytorch/transformer.py b/transformer_engine/pytorch/transformer.py index 046dda20b2..a9a3b84aa0 100644 --- a/transformer_engine/pytorch/transformer.py +++ b/transformer_engine/pytorch/transformer.py @@ -27,6 +27,7 @@ split_tensor_along_dim, cast_if_needed, get_default_init_method, + get_device_compute_capability, ) from transformer_engine.pytorch.constants import ( AttnMaskTypes, @@ -220,9 +221,6 @@ def __init__( assert ( attn_mask_type == "causal" ), 'FlashAttention currently only supports causal attention mask.' - assert ( - attention_softmax_in_fp32 - ), 'FlashAttention currently only supports softmax compute in fp32.' self.attn_causal_mask = attn_mask_type == "causal" self.norm_factor = norm_factor @@ -230,6 +228,7 @@ def __init__( self.attention_dropout = attention_dropout self.layer_number = layer_number self.apply_query_key_layer_scaling = apply_query_key_layer_scaling + self.attention_softmax_in_fp32 = attention_softmax_in_fp32 def forward( self, @@ -287,6 +286,11 @@ class DotProductAttention(torch.nn.Module): representation subspaces as described in the paper: `Attention Is All You Need `_. + .. 
note:: + + Argument :attr:`attention_mask` will be ignored in the `forward` call when + :attr:`attn_mask_type` is set to `"causal"`. + .. warning:: For the default attention mechanism, this module executes a non-deterministic version of @@ -303,15 +307,6 @@ class DotProductAttention(torch.nn.Module): number of key-value channels. attention_dropout: float, default = 0.0 dropout probability for the dropout op during multi-head attention. - layer_number: int, default = `None` - layer number of the current `DotProductAttention` when multiple such modules - are concatenated, for instance in consecutive transformer blocks. - apply_query_key_layer_scaling: bool, default = `False` - apply query-key layer scaling during BMM1 - by a factor of `layer_number` - attention_softmax_in_fp32: bool, default = `True` - if set to `False`, softmax is executed in - the dtype of activation tensors. attn_mask_type: {'causal', 'padding'}, default = `causal` type of attention mask passed into softmax operation. @@ -371,9 +366,8 @@ def __init__( self.use_flash_attention = ( int(os.getenv("NVTE_FLASH_ATTN", "1")) - and attention_softmax_in_fp32 and attn_mask_type == "causal" - and not apply_query_key_layer_scaling + and get_device_compute_capability() >= 8.0 ) attn_kwargs = { @@ -422,6 +416,11 @@ def forward( """ Dot Product Attention Layer. + .. note:: + + Argument :attr:`attention_mask` will be ignored when :attr:`attn_mask_type` + is set to `"causal"`. + .. note:: Input tensors :attr:`query_layer`, :attr:`key_layer`, and :attr:`value_layer` @@ -448,8 +447,7 @@ def forward( """ use_flash_attention = self.use_flash_attention - if (attention_mask is not None - or query_layer.dtype not in [torch.bfloat16, torch.float16] + if (query_layer.dtype not in [torch.bfloat16, torch.float16] or key_layer.dtype not in [torch.bfloat16, torch.float16] or value_layer.dtype not in [torch.bfloat16, torch.float16] ): @@ -515,6 +513,7 @@ def __init__( self.return_layernorm_output = return_layernorm_output self.params_dtype = params_dtype self.init_method = init_method + self.attn_mask_type = attn_mask_type if not fuse_qkv_params: qkv_weight_interleaved = False @@ -658,7 +657,7 @@ def forward( """MultiHeadAttention FWD""" # hidden_states: [sq, b, h] - if attention_mask is not None: + if self.attn_mask_type != "causal" and attention_mask is not None: assert ( attention_mask.dtype == torch.bool ), "Attention mask must be a boolean tensor" @@ -836,6 +835,11 @@ class TransformerLayer(torch.nn.Module): TransformerLayer is made up of an attention block and a feedforward network (MLP). This standard layer is based on the paper "Attention Is All You Need". + .. note:: + + Argument :attr:`attention_mask` will be ignored in the `forward` call when + :attr:`self_attn_mask_type` is set to `"causal"`. + Parameters ---------- hidden_size : int @@ -983,6 +987,7 @@ def __init__( self.apply_residual_connection_post_layernorm = ( apply_residual_connection_post_layernorm ) + self.self_attn_mask_type = self_attn_mask_type assert ( self_attn_mask_type in AttnMaskTypes ), f"self_attn_mask_type {self_attn_mask_type} not supported" @@ -1129,6 +1134,11 @@ def forward( """ Transformer Layer: attention block and a feedforward network (MLP) + .. note:: + + Argument :attr:`attention_mask` will be ignored when :attr:`self_attn_mask_type` + is set to `"causal"`. 
+ Parameters ---------- hidden_states : torch.Tensor @@ -1163,7 +1173,7 @@ def forward( hidden_states = hidden_states.contiguous() - if attention_mask is not None: + if self.self_attn_mask_type != "causal" and attention_mask is not None: assert ( attention_mask.dtype == torch.bool ), "Attention mask must be a boolean tensor" diff --git a/transformer_engine/pytorch/utils.py b/transformer_engine/pytorch/utils.py index 9f1ddaa2b2..798bcfb332 100644 --- a/transformer_engine/pytorch/utils.py +++ b/transformer_engine/pytorch/utils.py @@ -8,6 +8,13 @@ import torch +def get_device_compute_capability() -> float: + """Returns the cuda compute capability of current GPU""" + major = torch.cuda.get_device_properties(torch.cuda.current_device()).major + minor = torch.cuda.get_device_properties(torch.cuda.current_device()).minor + return major + minor / 10 + + def attention_mask_func( attention_scores: torch.Tensor, attention_mask: torch.Tensor ) -> torch.Tensor: From f18e6773d9ed1aca1f497f6a2d3a927a21a372ea Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Fri, 24 Feb 2023 17:54:09 -0800 Subject: [PATCH 07/68] fix bug in non-FP8 nvfuser path (#81) Signed-off-by: Kirthi Shankar Sivamani --- transformer_engine/pytorch/module.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/module.py b/transformer_engine/pytorch/module.py index a5c247926a..22434ab887 100644 --- a/transformer_engine/pytorch/module.py +++ b/transformer_engine/pytorch/module.py @@ -2204,7 +2204,7 @@ def forward( gelu=not bias_gelu_nvfusion, ) - if bias_gelu_nvfusion and is_grad_enabled: + if bias_gelu_nvfusion: fc1_out, _, _ = fc1_outputs gelu_out = bias_gelu_fused(fc1_out, fc1_bias) else: From f4955d3a510cab9e40ac63ffa180d9e6702ad603 Mon Sep 17 00:00:00 2001 From: Przemyslaw Tredak Date: Mon, 20 Mar 2023 17:17:38 -0700 Subject: [PATCH 08/68] Add SECURITY.md (#110) Signed-off-by: Przemek Tredak --- SECURITY.md | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) create mode 100644 SECURITY.md diff --git a/SECURITY.md b/SECURITY.md new file mode 100644 index 0000000000..35edb61b01 --- /dev/null +++ b/SECURITY.md @@ -0,0 +1,24 @@ +## Security + +NVIDIA is dedicated to the security and trust of our software products and services, including all source code repositories managed through our organization. + +If you need to report a security issue, please use the appropriate contact points outlined below. **Please do not report security vulnerabilities through GitHub/GitLab.** + +## Reporting Potential Security Vulnerability in an NVIDIA Product + +To report a potential security vulnerability in any NVIDIA product: +- Web: [Security Vulnerability Submission Form](https://www.nvidia.com/object/submit-security-vulnerability.html) +- E-Mail: psirt@nvidia.com + - We encourage you to use the following PGP key for secure email communication: [NVIDIA public PGP Key for communication](https://www.nvidia.com/en-us/security/pgp-key) + - Please include the following information: + - Product/Driver name and version/branch that contains the vulnerability + - Type of vulnerability (code execution, denial of service, buffer overflow, etc.) 
+ - Instructions to reproduce the vulnerability + - Proof-of-concept or exploit code + - Potential impact of the vulnerability, including how an attacker could exploit the vulnerability + +While NVIDIA currently does not have a bug bounty program, we do offer acknowledgement when an externally reported security issue is addressed under our coordinated vulnerability disclosure policy. Please visit our [Product Security Incident Response Team (PSIRT)](https://www.nvidia.com/en-us/security/psirt-policies/) policies page for more information. + +## NVIDIA Product Security + +For all security-related concerns, please visit NVIDIA's Product Security portal at https://www.nvidia.com/en-us/security From e5ab21131c3d185823229b4f86cc3d54a3b39edf Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Wed, 22 Mar 2023 00:41:49 -0700 Subject: [PATCH 09/68] Catch FA internal error with compute capability 8.6 (#113) FA doesn't support compute 8.6 with head_dim>64 Signed-off-by: Kirthi Shankar Sivamani --- transformer_engine/pytorch/transformer.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/transformer.py b/transformer_engine/pytorch/transformer.py index 1869228c2e..cbd0622947 100644 --- a/transformer_engine/pytorch/transformer.py +++ b/transformer_engine/pytorch/transformer.py @@ -353,10 +353,11 @@ def __init__( norm_factor = math.sqrt(self.hidden_size_per_attention_head) + self.device_compute_capability = get_device_compute_capability() self.use_flash_attention = ( int(os.getenv("NVTE_FLASH_ATTN", "1")) and attn_mask_type == "causal" - and get_device_compute_capability() >= 8.0 + and self.device_compute_capability >= 8.0 ) attn_kwargs = { @@ -437,6 +438,7 @@ def forward( if (query_layer.dtype not in [torch.bfloat16, torch.float16] or key_layer.dtype not in [torch.bfloat16, torch.float16] or value_layer.dtype not in [torch.bfloat16, torch.float16] + or (self.device_compute_capability == 8.6 and key_layer.shape[-1] > 64) ): use_flash_attention = False From 7e8c3e69da100e485895e44ec9c1699cb1add629 Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Tue, 28 Mar 2023 09:42:26 -0700 Subject: [PATCH 10/68] Fix usage of return_bias argument (#114) * fix usage of return_bias argument Signed-off-by: Kirthi Shankar Sivamani * review comments Signed-off-by: Kirthi Shankar Sivamani --------- Signed-off-by: Kirthi Shankar Sivamani --- transformer_engine/pytorch/module.py | 28 +++++++++++------------ transformer_engine/pytorch/transformer.py | 4 ++-- 2 files changed, 16 insertions(+), 16 deletions(-) diff --git a/transformer_engine/pytorch/module.py b/transformer_engine/pytorch/module.py index 4b67f1b91a..4e012be58c 100644 --- a/transformer_engine/pytorch/module.py +++ b/transformer_engine/pytorch/module.py @@ -1123,6 +1123,7 @@ def __init__( self.fuse_wgrad_accumulation = fuse_wgrad_accumulation self.use_bias = bias self.return_bias = return_bias + self.apply_bias = bias and not return_bias self.return_layernorm_output = return_layernorm_output self.parameters_split = parameters_split self.zero_centered_gamma = zero_centered_gamma @@ -1187,7 +1188,7 @@ def __init__( stride=1, ) - if self.use_bias or self.return_bias: + if self.use_bias: self.register_buffer("bias_tensor", torch.empty( self.out_features, @@ -1229,7 +1230,7 @@ def __init__( stride=1, ) - if self.use_bias or self.return_bias: + if self.use_bias: self.register_parameter( bname, Parameter(self.bias_tensor[i * split_size : (i+1) * split_size]) ) @@ -1246,9 +1247,8 @@ def __init__( # For 
RPL, bias has to be added after TP collectives # So it cannot be fused with the GEMM - if self.parallel_mode == "row" and self.use_bias: + if self.parallel_mode == "row" and self.apply_bias: self.gemm_bias_unfused_add = True - self.use_bias = False else: self.gemm_bias_unfused_add = False @@ -1331,7 +1331,7 @@ def forward( self.weight1_fp8 if self.fp8 else None, self.weight1_t_fp8 if self.fp8 else None, bias_tensor, - self.use_bias, + self.apply_bias and not self.gemm_bias_unfused_add, self.eps, is_first_microbatch, self.fp8, @@ -1776,6 +1776,7 @@ def __init__( self.fuse_wgrad_accumulation = fuse_wgrad_accumulation self.use_bias = bias self.return_bias = return_bias + self.apply_bias = bias and not return_bias self.parameters_split = parameters_split if tp_group is None: @@ -1819,7 +1820,7 @@ def __init__( stride=1, ) - if self.use_bias or self.return_bias: + if self.use_bias: self.register_buffer("bias_tensor", torch.empty( self.out_features, @@ -1861,7 +1862,7 @@ def __init__( stride=1, ) - if self.use_bias or self.return_bias: + if self.use_bias: self.register_parameter( bname, Parameter(self.bias_tensor[i * split_size : (i+1) * split_size]) ) @@ -1878,9 +1879,8 @@ def __init__( # For RPL, bias has to be added after TP collectives # So it cannot be fused with the GEMM - if self.parallel_mode == "row" and self.use_bias: + if self.parallel_mode == "row" and self.apply_bias: self.gemm_bias_unfused_add = True - self.use_bias = False else: self.gemm_bias_unfused_add = False @@ -1946,7 +1946,7 @@ def forward( self.weight1_t_fp8 if self.fp8 else None, inp, bias_tensor, - self.use_bias, + self.apply_bias and not self.gemm_bias_unfused_add, is_first_microbatch, self.fp8, self.fp8_calibration, @@ -2667,6 +2667,7 @@ def __init__( self.fuse_wgrad_accumulation = fuse_wgrad_accumulation self.use_bias = bias self.return_bias = return_bias + self.apply_bias = bias and not return_bias self.return_layernorm_output = return_layernorm_output self.bias_gelu_nvfusion = bool(int(os.getenv("NVTE_BIAS_GELU_NVFUSION", "1"))) self.set_parallel_mode = set_parallel_mode @@ -2759,7 +2760,7 @@ def __init__( stride=1, ) - if self.use_bias or self.return_bias: + if self.use_bias: self.fc2_bias = Parameter( torch.empty( hidden_size, device=torch.cuda.current_device(), dtype=params_dtype @@ -2770,9 +2771,8 @@ def __init__( # For RPL, bias has to be added after TP collectives # So it cannot be fused with the GEMM - if self.set_parallel_mode and self.use_bias: + if self.set_parallel_mode and self.apply_bias: self.gemm_bias_unfused_add = True - self.use_bias = False else: self.gemm_bias_unfused_add = False @@ -2845,7 +2845,7 @@ def forward( self.weight2_fp8 if self.fp8 else None, self.weight2_t_fp8 if self.fp8 else None, self.fc2_bias, - self.use_bias, + self.apply_bias and not self.gemm_bias_unfused_add, self.eps, is_first_microbatch, self.fp8, diff --git a/transformer_engine/pytorch/transformer.py b/transformer_engine/pytorch/transformer.py index cbd0622947..774c9fd11e 100644 --- a/transformer_engine/pytorch/transformer.py +++ b/transformer_engine/pytorch/transformer.py @@ -607,7 +607,7 @@ def __init__( hidden_size, hidden_size, init_method=output_layer_init_method, - bias=False, + bias=True, return_bias=True, parallel_mode="row" if set_parallel_mode else None, **common_gemm_kwargs, @@ -1059,7 +1059,7 @@ def __init__( get_rng_state_tracker=get_rng_state_tracker, init_method=init_method, output_layer_init_method=output_layer_init_method, - bias=False, + bias=True, return_bias=True, sequence_parallel=self.sequence_parallel, 
params_dtype=params_dtype, From 626da0deca4b77cfe1e0ad2de970d39938f43210 Mon Sep 17 00:00:00 2001 From: Tim Moon <4406448+timmoon10@users.noreply.github.com> Date: Tue, 28 Mar 2023 09:44:15 -0700 Subject: [PATCH 11/68] Fix zombie process when querying TE install path (#121) * Remove zombie process from querying TE install path Co-authored-by: Naman Goyal Signed-off-by: Tim Moon * Fix FA version checking Signed-off-by: Kirthi Shankar Sivamani * fix unused import error Signed-off-by: Kirthi Shankar Sivamani * Fix lint warning Signed-off-by: Tim Moon --------- Signed-off-by: Tim Moon Signed-off-by: Kirthi Shankar Sivamani Co-authored-by: Naman Goyal Co-authored-by: Kirthi Shankar Sivamani --- transformer_engine/common/__init__.py | 24 +++++++++++------------ transformer_engine/pytorch/transformer.py | 4 ++-- 2 files changed, 13 insertions(+), 15 deletions(-) diff --git a/transformer_engine/common/__init__.py b/transformer_engine/common/__init__.py index 791ba793a8..7dfcdc96bb 100644 --- a/transformer_engine/common/__init__.py +++ b/transformer_engine/common/__init__.py @@ -3,25 +3,23 @@ # See LICENSE for license information. """FW agnostic user-end APIs""" +import ctypes +import os +import platform +import subprocess def get_te_path(): - """Find TE path using pip""" + """Find Transformer Engine install path using pip""" - import os - - te_info = ( - os.popen("pip show transformer_engine").read().replace("\n", ":").split(":") - ) - return te_info[te_info.index("Location") + 1].strip() + command = ["pip", "show", "transformer_engine"] + result = subprocess.run(command, capture_output=True, check=True, text=True) + result = result.stdout.replace("\n", ":").split(":") + return result[result.index("Location")+1].strip() def _load_library(): - """Load TE .so""" - - import os - import ctypes - import platform + """Load shared library with Transformer Engine C extensions""" system = platform.system() if system == "Linux": @@ -31,7 +29,7 @@ def _load_library(): elif system == "Windows": extension = "dll" else: - raise "Unsupported operating system " + system + "." + raise RuntimeError(f"Unsupported operating system ({system})") lib_name = "libtransformer_engine." 
+ extension dll_path = get_te_path() dll_path = os.path.join(dll_path, lib_name) diff --git a/transformer_engine/pytorch/transformer.py b/transformer_engine/pytorch/transformer.py index 774c9fd11e..fa00fb86fc 100644 --- a/transformer_engine/pytorch/transformer.py +++ b/transformer_engine/pytorch/transformer.py @@ -4,9 +4,9 @@ """Transformer.""" import os -import re import math import warnings +from importlib.metadata import version from contextlib import nullcontext from typing import Any, Callable, Optional, Tuple, Union @@ -42,7 +42,7 @@ checkpoint, ) -_flash_attn_version = re.search("Version: (.*)", os.popen("pip show flash_attn").read()).group(1) +_flash_attn_version = version("flash-attn") warnings.filterwarnings("module", category=DeprecationWarning, module="transformer") From 084b1e54a5d5bc84e380cbebde18d53d0243fc5a Mon Sep 17 00:00:00 2001 From: Jeng Bai-Cheng Date: Wed, 29 Mar 2023 01:39:20 +0800 Subject: [PATCH 12/68] [JAX] Add TE examples (#108) * refactor JAX examples Signed-off-by: Ryan Jeng * fix doc-string Signed-off-by: Ryan Jeng * add dp example Signed-off-by: Ryan Jeng * refactor Signed-off-by: Ryan Jeng * fix params_axes_pspec Signed-off-by: Ryan Jeng * Add model parallel example and refactor Update readme Signed-off-by: Ryan Jeng * align code and readme Signed-off-by: Ryan Jeng * update verification Signed-off-by: Ryan Jeng * add mask Signed-off-by: Ryan Jeng * num_gpu is configurable Signed-off-by: Ryan Jeng * update readme Signed-off-by: Ryan Jeng * update readme Signed-off-by: Ryan Jeng * solvepylint issue Signed-off-by: Ryan Jeng * ignore markdown and txt file from license check Signed-off-by: Ryan Jeng * Update README.md Signed-off-by: Ryan Jeng * add flax into requirements.txt Signed-off-by: Ryan Jeng --------- Signed-off-by: Ryan Jeng --- examples/jax/README.md | 7 + examples/jax/encoder/README.md | 69 +++ examples/jax/encoder/requirements.txt | 4 + .../encoder/test_model_parallel_encoder.py | 441 ++++++++++++++++++ examples/jax/encoder/test_multigpu_encoder.py | 420 +++++++++++++++++ .../encoder/test_single_gpu_bf16_training.py | 75 --- .../jax/encoder/test_single_gpu_encoder.py | 344 ++++++++++++++ .../encoder/test_single_gpu_fp8_training.py | 99 ---- examples/jax/mnist/README.md | 34 ++ examples/jax/mnist/requirements.txt | 3 + examples/jax/mnist/test_single_gpu_mnist.py | 311 ++++++++++++ qa/L0_jax_unittest/test.sh | 3 + qa/L0_license/config.json | 4 +- qa/L0_license/copyright_checker.py | 1 + tests/jax/test_mnist.py | 227 --------- transformer_engine/jax/module.py | 8 +- transformer_engine/jax/transformer.py | 4 +- 17 files changed, 1646 insertions(+), 408 deletions(-) create mode 100644 examples/jax/README.md create mode 100644 examples/jax/encoder/README.md create mode 100644 examples/jax/encoder/requirements.txt create mode 100644 examples/jax/encoder/test_model_parallel_encoder.py create mode 100644 examples/jax/encoder/test_multigpu_encoder.py delete mode 100644 examples/jax/encoder/test_single_gpu_bf16_training.py create mode 100644 examples/jax/encoder/test_single_gpu_encoder.py delete mode 100644 examples/jax/encoder/test_single_gpu_fp8_training.py create mode 100644 examples/jax/mnist/README.md create mode 100644 examples/jax/mnist/requirements.txt create mode 100644 examples/jax/mnist/test_single_gpu_mnist.py delete mode 100644 tests/jax/test_mnist.py diff --git a/examples/jax/README.md b/examples/jax/README.md new file mode 100644 index 0000000000..d2c98f15c2 --- /dev/null +++ b/examples/jax/README.md @@ -0,0 +1,7 @@ +# Transformer Engine Examples 
# + +This folder contains simple examples introducing Transformer Engine and FP8 training usage. + +**Examples Outline** +* MNIST training: Training MNIST dataset is a good start point to learn how use Transformer Engine and enable FP8 training +* Encoder training: The encoder examples introduce more about how to scale up training on multiple GPUs with Transformer Engine \ No newline at end of file diff --git a/examples/jax/encoder/README.md b/examples/jax/encoder/README.md new file mode 100644 index 0000000000..388f2f40c6 --- /dev/null +++ b/examples/jax/encoder/README.md @@ -0,0 +1,69 @@ +# Basic Transformer Encoder Example with Optional FP8 # + +This example uses Transformer Encoder to demonstrate the Transformer Engine usage. And more focus on scaling up training on multiple GPUs. Highly recommend studying the [MNIST example of the Transformer Engine](/examples/jax/mnist) before reading this example. The Transformer Engine is built on top of [Flax](https://github.com/google/flax). Thus, examples use `pjit` to set up multiple GPU training. The basic pjit usage can be referred to [Scale up Flax Modules on multiple devices with pjit](https://flax.readthedocs.io/en/latest/guides/flax_on_pjit.html). + +## Single GPU ## + +1. Setup dataset: This is done by using the `tfds` library to download the GLUE/CoLA dataset and using `nltk` to tokenize the sentences. This example focuses on Transformer Engine usage. Thus, a simple algorithm is used to convert tokens to INT32 tensors as input to the embedding layer. The `get_datasets` and `data_preprocess` routines are used for this purpose. + +2. Define model: The `Net` class is a small Transformer Encoder model for sentence classification. The Transformer Engine provides `te.TransformerLayer` as encoder block and `te.DenseGeneral`. The structure of encoder block can be referred to [Scaling Up Models and Data with t5x and seqio](https://arxiv.org/abs/2203.17189) + +3. Build training loop: The `train_and_evaluate` is the main routine to initialize the model and start training and evaluating. Use `fp8_autocast` context manager to enable FP8 training and check `var_collect` if the variable collection contains `Float8`. + +4. Training process: In `train_step`, combine the FP8 metadata and latest model parameters into var_collect as a frozen dictionary and fill it to the gradient function. And then, call `te.update_fp8_metas` to update FP8 metadata. The number of training steps to update FP8 metadata can be customized. In this example, it is updated every step. + +5. Evaluating process: Same as the training process, the FP8 metadata needs to be in var_collect and fill it into a loss function, if enabling FP8 computing. + +### Run ### + +```bash +python test_single_gpu_encoder.py +python test_single_gpu_encoder.py --use-fp8 +``` + +## Multiple GPU with Data Parallelism ## + +1. The data parallelism (DP) divides a mini-batch for multiple devices, and each device has complete model parameters. In this example, the first dimension of input tensor is `batch_size` which is 64 by default, and uses 8 GPUs to train the model, so each device takes 8 sentences at once. The "dividing" is called "sharding" in the JAX documents. + +2. In order to let JAX know how to do sharding, the `device_mesh` needs to be defined and each axis need to be named. A common way to annotate axis names is `data` which means the mesh dimension used for data-parallel sharding of the batch dimension of inputs and activations. 
And the first argument of `te.ShardingResource` is the name of the device axis which is used for data parallelism. + +3. On the model side, the logical axis of each weight tensor of the model can be named. The `te.TransformerLayer` has the default names, which are stored in `abs_var_collect`, a collection of variables returned by `jax.eval_shape(encoder.init, ...)`. The key index is `params_axes`. The `te.DenseGeneral` doesn't have the default named axis because it is generic. Also, data-parallel sharding doesn't need to divide weight tensor, so named axis is not required for this case. But te.DenseGeneral is based on [XLA custom-call](https://www.tensorflow.org/xla/custom_call) and [xmap](https://jax.readthedocs.io/en/latest/notebooks/xmap_tutorial.html), the `sharding_type` must be set to map weights and xmap correctly. + +4. The next is to create sharding rules, mapping the device axis to the logical axis. The `te.extend_logical_axis_rules` under fp8_autocast will return a list of pairs of the mapping, such as `(('batch', 'data'), ...)`. The first is the logical axis and second is the device axis. + +5. Refer structure of `abs_var_collect['params']` and `abs_var_collect['params_axes']` to set up `PartitionSpec` for pjit. All logical axes should be replaced by device axes. If the value of PartitionSpec is None, that means no sharding, broadcasting the data to every device. Note that the `params_axes` attribute is provided by Transformer Engine. The Flax's module doesn't have it, such as `nn.Embed`. For nn.Embed, assigning an empty PartitionSpec is fine because each device has its own embedding layer in DP mode. The `get_params_pspec` routine is used for this purpose. Because each device has a complete model in DP mode, all values of PartitionSpec in params_pspec should be None. This will be different in the model parallelism example. + +6. Fill in `params_pspec` and `encoder.init` to pjit to get a compiled function, `pjit_encoder_init `, and use it to initialize the model, so JAX now can know how to do the sharding. + +7. The `train_step` and `eval_step` also needs to be compiled by pjit. Thus, every input and output argument has to be set up `PartitionSpec` if the argument contains a tensor. For instance, the `input_pspec` is `PartitionSpec('data', None)` because the input shape is (batch size, sequence length). Then, the rest of the workflow is similar to the previous example. + +### Run ### + +```bash +python test_multigpu_encoder.py +python test_multigpu_encoder.py --use-fp8 +``` + +## Multiple GPU with Model Parallelism ## + +1. The model parallelism as known as tensor parallelism (TP) divides a model for multiple devices, and each device has part of model parameters. This example inherits previous DP example, but divides a model to two devices. + +2. To set up device mesh for TP, adding a new named axis called `model`, which is used for sharding parameters of the model across devices. This example divides the model to two parts (`num_gpu_tp = 2`). One device only has half of the model. + +3. On the model side, The `te.TransformerLayer` doesn't need additional settings because it has the default axis name already. It will be divided by `DEVICE_TP_AXIS` when model initialization. The first `te.DenseGeneral` is divided by columns and second one is divided by rows for TP. Because `te.DenseGeneral` doesn't have the default named axis, the names must be set manually by passing `kernel_axes` and `bias_axes` arguments. Then, the rest of the workflow is similar to the previous example. 
+ +4. The tips for debugging TP: + * Use [inspect_array_sharding](https://jax.readthedocs.io/en/latest/_autosummary/jax.debug.inspect_array_sharding.html) or [visualize_array_sharding](https://jax.readthedocs.io/en/latest/_autosummary/jax.debug.visualize_array_sharding.html) to check the shape of activations and weights. + * Check the shape of device buffer of weight tensor. For instance, `var_collect['params']['DenseGeneral_0']['kernel'].device_buffers[device_id].shape`. The `device_id` is an integer. If a weight tensor's shape is (256, 256) and you intend to divide it for two devices by second dimension, then the shape returned by device_buffers should be (256, 128). + * Dump XLA HLO by setting `XLA_FLAGS` and see whether it contains unexpected `all-gather` operations or not. + ```python + import os + os.environ['XLA_FLAGS'] = "--xla_dump_hlo_as_proto --xla_dump_hlo_as_text --xla_dump_hlo_as_html --xla_dump_to=" + ``` + +### Run ### + +```bash +python test_model_parallel_encoder.py +python test_model_parallel_encoder.py --use-fp8 +``` diff --git a/examples/jax/encoder/requirements.txt b/examples/jax/encoder/requirements.txt new file mode 100644 index 0000000000..bc1b755cb9 --- /dev/null +++ b/examples/jax/encoder/requirements.txt @@ -0,0 +1,4 @@ +flax +nltk +optax +tensorflow-datasets diff --git a/examples/jax/encoder/test_model_parallel_encoder.py b/examples/jax/encoder/test_model_parallel_encoder.py new file mode 100644 index 0000000000..10c880710e --- /dev/null +++ b/examples/jax/encoder/test_model_parallel_encoder.py @@ -0,0 +1,441 @@ +# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. +""" Encoder training on multi-GPU with tesnor parallelism""" +import argparse +import unittest +from functools import partial + +import jax +import jax.numpy as jnp +import nltk +import numpy as np +import optax +import tensorflow_datasets as tfds +from cuda import cudart +from flax import linen as nn +from flax.core.frozen_dict import FrozenDict +from flax.training import train_state +from jax.experimental import mesh_utils +from jax.experimental.pjit import pjit + +import transformer_engine.jax as te + +DEVICE_DP_AXIS = 'data' +DEVICE_TP_AXIS = 'model' +NAMED_BROADCAST_AXIS = 'my_broadcast_axis' +NAMED_TP_AXIS = 'my_tp_axis' +PARAMS_KEY = 'params' +PARAMS_AXES_KEY = PARAMS_KEY + '_axes' +DROPOUT_KEY = 'dropout' +INPUT_KEY = 'input_rng' + + +def check_num_gpu(desired_num_gpu): + """Check if the number of GPUs are correct.""" + actual_num_gpu = len(jax.local_devices()) + assert actual_num_gpu == desired_num_gpu, f"Number of GPUs is mismatch. 
" \ + f"{desired_num_gpu} GPUs are assigned, but the actual number of GPUs is {actual_num_gpu}" + + +def gpu_has_fp8(): + """Check if the GPU has FP8.""" + cudaSuccess = cudart.cudaError_t.cudaSuccess + ret, gpu_id = cudart.cudaGetDevice() + assert ret == cudaSuccess + flag = cudart.cudaDeviceAttr.cudaDevAttrComputeCapabilityMajor + _, major = cudart.cudaDeviceGetAttribute(flag, gpu_id) + flag = cudart.cudaDeviceAttr.cudaDevAttrComputeCapabilityMinor + _, minor = cudart.cudaDeviceGetAttribute(flag, gpu_id) + sm_arch = major * 10 + minor + return sm_arch >= 89 + + +class Net(nn.Module): + """NLP Encoder""" + num_embed: int + + @nn.compact + def __call__(self, x, mask, disable_dropout=False): + x = nn.Embed(num_embeddings=self.num_embed, features=256, dtype=jnp.bfloat16)(x) + + te_Encoder = partial(te.TransformerLayer, + hidden_size=256, + mlp_hidden_size=1024, + num_attention_heads=8, + hidden_dropout=0.1, + attention_dropout=0.1, + dropout_rng_name=DROPOUT_KEY, + layer_type=te.TransformerLayerType.ENCODER, + enable_relative_embedding=False, + dtype=jnp.bfloat16) + x = te_Encoder()(x, attention_mask=mask, deterministic=disable_dropout) + + x = x.reshape(x.shape[0], -1) + + x = te.DenseGeneral(features=256, + kernel_axes=(NAMED_BROADCAST_AXIS, NAMED_TP_AXIS), + bias_axes=(NAMED_TP_AXIS,), + sharding_type=te.ShardingType.DP_TP_COL, + dtype=jnp.bfloat16)(x) + + x = te.DenseGeneral(features=256, + kernel_axes=(NAMED_TP_AXIS, NAMED_BROADCAST_AXIS), + bias_axes=(NAMED_BROADCAST_AXIS,), + sharding_type=te.ShardingType.DP_TP_ROW, + dtype=jnp.bfloat16)(x) + + x = nn.Dense(features=2, dtype=jnp.bfloat16)(x) + return x + + +def train_step(state, inputs, masks, labels, var_collect, rngs, use_fp8): + """Computes gradients, loss and accuracy for a single batch.""" + + def loss_fn(var_collect, disable_dropout=False): + logits = state.apply_fn(var_collect, inputs, masks, disable_dropout, rngs=rngs) + one_hot = jax.nn.one_hot(labels, 2) + loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot)) + return loss, logits + + var_collect = FrozenDict({**var_collect, PARAMS_KEY: state.params}) + grad_fn = jax.value_and_grad(loss_fn, has_aux=True) + (loss, logits), grads = grad_fn(var_collect) + accuracy = jnp.mean(jnp.argmax(logits, -1) == labels) + + var_collect, grads = grads.pop(PARAMS_KEY) + state = state.apply_gradients(grads=grads) + if use_fp8: + var_collect = te.update_fp8_metas(var_collect) + + return state, loss, accuracy, var_collect + + +def train_epoch(state, train_ds, batch_size, rngs, var_collect, use_fp8, train_fn): + """Train for a single epoch.""" + train_ds_size = len(train_ds['sentence']) + steps_per_epoch = train_ds_size // batch_size + perms = jax.random.permutation(rngs[INPUT_KEY], train_ds_size) + perms = perms[:steps_per_epoch * batch_size] # skip incomplete batch + perms = perms.reshape((steps_per_epoch, batch_size)) + epoch_loss = [] + epoch_accuracy = [] + + for perm in perms: + batch_inputs = train_ds['sentence'][perm, ...] + batch_masks = train_ds['mask'][perm, ...] + batch_labels = train_ds['label'][perm, ...] 
+ state, loss, accuracy, var_collect = train_fn(state, batch_inputs, batch_masks, + batch_labels, var_collect, rngs, use_fp8) + epoch_loss.append(loss) + epoch_accuracy.append(accuracy) + + avg_loss = np.mean(epoch_loss) + avg_accuracy = np.mean(epoch_accuracy) + return state, avg_loss, avg_accuracy, var_collect + + +def eval_step(state, inputs, masks, labels, var_collect): + """Computes loss and accuracy for a single batch.""" + + def loss_fn(var_collect, disable_dropout=False): + logits = state.apply_fn(var_collect, inputs, masks, disable_dropout) + one_hot = jax.nn.one_hot(labels, 2) + loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot)) + return loss, logits + + var_collect = FrozenDict({**var_collect, PARAMS_KEY: state.params}) + loss, logits = loss_fn(var_collect, disable_dropout=True) + accuracy = jnp.mean(jnp.argmax(logits, -1) == labels) + return loss, accuracy + + +def eval_model(state, test_ds, batch_size, var_collect, eval_fn): + """Evaluation loop.""" + test_ds_size = len(test_ds['sentence']) + num_steps = test_ds_size // batch_size + valid_size = num_steps * batch_size + all_loss = [] + all_accuracy = [] + + for batch_start in range(0, valid_size, batch_size): + batch_end = batch_start + batch_size + batch_inputs = test_ds['sentence'][batch_start:batch_end] + batch_masks = test_ds['mask'][batch_start:batch_end] + batch_labels = test_ds['label'][batch_start:batch_end] + loss, accuracy = eval_fn(state, batch_inputs, batch_masks, batch_labels, var_collect) + all_loss.append(loss) + all_accuracy.append(accuracy) + + avg_loss = np.mean(all_loss) + avg_accuracy = np.mean(all_accuracy) + return avg_loss, avg_accuracy + + +def data_preprocess(dataset, vocab, word_id, max_seq_len): + """Convert tokens to numbers.""" + nltk.download('punkt') + dataset_size = len(dataset['sentence']) + output = np.zeros((dataset_size, max_seq_len), dtype=np.int32) + mask_3d = np.empty((dataset_size, max_seq_len, max_seq_len), dtype=np.uint8) + + for j, sentence in enumerate(dataset['sentence']): + tokens = nltk.word_tokenize(sentence.decode("utf-8")) + tensor = output[j] + mask_1d = np.zeros((1, max_seq_len), dtype=np.uint8) + + for i, word in enumerate(tokens): + if i >= max_seq_len: + break + + if word not in vocab: + vocab[word] = word_id + tensor[i] = word_id + word_id = word_id + 1 + else: + tensor[i] = vocab[word] + + mask_1d[0, i] = 1 + + mask_2d = mask_3d[j] + np.dot(mask_1d.T, mask_1d, out=mask_2d) + np.subtract(1, mask_2d, out=mask_2d) + + dataset['sentence'] = output + dataset['label'] = dataset['label'].astype(np.float32) + dataset['mask'] = mask_3d.reshape((dataset_size, 1, max_seq_len, max_seq_len)) + return dataset, vocab, word_id + + +def get_datasets(max_seq_len): + """Load GLUE train and test datasets into memory.""" + vocab = {} + word_id = 0 + dataset = 'glue/cola' + train_ds = tfds.as_numpy(tfds.load(dataset, split='train', batch_size=-1)) + train_ds, vocab, word_id = data_preprocess(train_ds, vocab, word_id, max_seq_len) + test_ds = tfds.as_numpy(tfds.load(dataset, split='validation', batch_size=-1)) + test_ds, vocab, word_id = data_preprocess(test_ds, vocab, word_id, max_seq_len) + return train_ds, test_ds, word_id + + +def check_fp8(state, var_collect, inputs, masks, labels): + "Check if model includes FP8." 
+ rngs = {DROPOUT_KEY: jax.random.PRNGKey(0)} + assert "Float8" in str( + jax.make_jaxpr(train_step, static_argnums=6)(state, inputs, masks, labels, var_collect, + rngs, True)) + + +def get_params_pspec(sharding_rules, abs_var_collect): + """Refer params to create params partition spec""" + rules_dict = {} + for key, value in sharding_rules: + rules_dict[key] = value + + def to_device_axis(logical_axis): + partitions = [rules_dict[key] for key in logical_axis] + return jax.sharding.PartitionSpec(*partitions) + + params_axes = abs_var_collect.get(PARAMS_AXES_KEY, {}) + params_axes_pspec = jax.tree_map(to_device_axis, nn.partitioning.get_axis_names(params_axes)) + params_pspec = jax.tree_map(lambda x: jax.sharding.PartitionSpec(), abs_var_collect[PARAMS_KEY]) + params_pspec = FrozenDict({**params_pspec, **params_axes_pspec}) + return params_pspec + + +def get_state_pspec(state, params_pspec): + """Refer params_pspec to create state partition spec""" + + def replace_params(x): + return params_pspec if isinstance(x, FrozenDict) else None + + state_pspec = jax.tree_map(replace_params, state, is_leaf=lambda x: isinstance(x, FrozenDict)) + return state_pspec + + +def train_and_evaluate(args): + """Execute model training and evaluation loop.""" + print(args) + check_num_gpu(args.num_gpu) + + if args.use_fp8: + assert gpu_has_fp8(), "GPU needs to support FP8." + + num_gpu_tp = 2 + if args.num_gpu % num_gpu_tp == 0: + num_gpu_dp = args.num_gpu // num_gpu_tp + else: + num_gpu_dp = 1 + num_gpu_tp = 1 + + assert args.batch_size % num_gpu_dp == 0, f"Batch size needs to be multiple of {num_gpu_dp}" + assert args.test_batch_size % num_gpu_dp == 0, \ + f"Test batch size needs to be multiple of {num_gpu_dp}" + + device_mesh = mesh_utils.create_device_mesh((num_gpu_dp, num_gpu_tp)) + with jax.sharding.Mesh(devices=device_mesh, axis_names=(DEVICE_DP_AXIS, DEVICE_TP_AXIS)): + + rng = jax.random.PRNGKey(args.seed) + rng, params_rng = jax.random.split(rng) + rng, dropout_rng = jax.random.split(rng) + init_rngs = {PARAMS_KEY: params_rng, DROPOUT_KEY: dropout_rng} + + input_shape = [args.batch_size, args.max_seq_len] + mask_shape = [args.batch_size, 1, args.max_seq_len, args.max_seq_len] + label_shape = [args.batch_size] + + with te.fp8_autocast(args.use_fp8, + sharding_resource=te.ShardingResource(DEVICE_DP_AXIS, DEVICE_TP_AXIS)): + train_ds, test_ds, num_embed = get_datasets(args.max_seq_len) + encoder = Net(num_embed) + inputs = jnp.zeros(input_shape, dtype=jnp.int32) + masks = jnp.zeros(mask_shape, dtype=jnp.uint8) + abs_var_collect = jax.eval_shape(encoder.init, init_rngs, inputs, masks) + + customized_rules = ((NAMED_BROADCAST_AXIS, None), (NAMED_TP_AXIS, DEVICE_TP_AXIS)) + sharding_rules = te.extend_logical_axis_rules(tuple()) + customized_rules + params_pspec = get_params_pspec(sharding_rules, abs_var_collect) + inputs_pspec = jax.sharding.PartitionSpec(DEVICE_DP_AXIS, None) + masks_pspec = jax.sharding.PartitionSpec(DEVICE_DP_AXIS, None, None, None) + + in_shardings = (None, inputs_pspec, masks_pspec) + out_shardings = FrozenDict({key: params_pspec if key is PARAMS_KEY else None \ + for key in abs_var_collect}) + pjit_encoder_init = pjit(encoder.init, in_shardings, out_shardings) + var_collect = pjit_encoder_init(init_rngs, inputs, masks) + + optimizer = optax.adamw(args.lr) + var_collect, params = var_collect.pop(PARAMS_KEY) + state = train_state.TrainState.create(apply_fn=encoder.apply, + params=params, + tx=optimizer) + state_pspec = get_state_pspec(state, params_pspec) + labels_pspec = 
jax.sharding.PartitionSpec(DEVICE_DP_AXIS,) + + in_shardings = (state_pspec, inputs_pspec, masks_pspec, labels_pspec, None, None) + out_shardings = (state_pspec, None, None, None) + pjit_train_step = pjit(train_step, in_shardings, out_shardings, static_argnums=(6,)) + + in_shardings = (state_pspec, inputs_pspec, masks_pspec, labels_pspec, None) + out_shardings = (None, None) + pjit_eval_step = pjit(eval_step, in_shardings, out_shardings) + + if args.use_fp8: + labels = jnp.zeros(label_shape, dtype=jnp.bfloat16) + check_fp8(state, var_collect, inputs, masks, labels) + + if args.dry_run: + labels = jnp.zeros(label_shape, dtype=jnp.bfloat16) + rngs = {DROPOUT_KEY: dropout_rng} + pjit_train_step(state, inputs, masks, labels, var_collect, rngs, args.use_fp8) + print("PASSED") + return None + + for epoch in range(1, args.epochs + 1): + rng, input_rng = jax.random.split(rng) + rng, dropout_rng = jax.random.split(rng) + rngs = {INPUT_KEY: input_rng, DROPOUT_KEY: dropout_rng} + + state, train_loss, train_accuracy, var_collect = train_epoch( + state, train_ds, args.batch_size, rngs, var_collect, args.use_fp8, + pjit_train_step) + + test_loss, test_accuracy = eval_model(state, test_ds, args.test_batch_size, + var_collect, pjit_eval_step) + + print(f"Epoch: {epoch:>2} " + f"Train Loss: {train_loss:.6f} " + f"Train Accuracy: {train_accuracy:.6f} " + f"Test Loss: {test_loss:.6f} " + f"Test Accuracy: {test_accuracy:.6f} ") + + return [train_loss, train_accuracy, test_loss, test_accuracy] + + +def encoder_parser(args): + """Training settings.""" + parser = argparse.ArgumentParser(description="JAX Encoder Example") + parser.add_argument( + "--num-gpu", + type=int, + default=8, + metavar="N", + help="number of GPUs (default: 8)", + ) + parser.add_argument( + "--batch-size", + type=int, + default=64, + metavar="N", + help="input batch size for training (default: 64)", + ) + parser.add_argument( + "--test-batch-size", + type=int, + default=64, + metavar="N", + help="input batch size for testing (default: 64)", + ) + parser.add_argument( + "--max-seq-len", + type=int, + default=32, + metavar="N", + help="maximum sequence length (default: 32)", + ) + parser.add_argument( + "--epochs", + type=int, + default=3, + metavar="N", + help="number of epochs to train (default: 3)", + ) + parser.add_argument( + "--lr", + type=float, + default=0.0001, + metavar="LR", + help="learning rate (default: 0.0001)", + ) + parser.add_argument( + "--dry-run", + action="store_true", + default=False, + help="quickly check a single pass", + ) + parser.add_argument("--seed", type=int, default=1, metavar="S", help="random seed (default: 1)") + parser.add_argument("--use-fp8", + action="store_true", + default=False, + help="Use FP8 for inference and training without recalibration") + + return parser.parse_args(args) + + +class TestEncoder(unittest.TestCase): + """Encoder unittests""" + + @classmethod + def setUpClass(cls): + """Run 3 epochs for testing""" + num_gpu = len(jax.local_devices()) + if num_gpu % 2 != 0: + num_gpu = 1 + cls.args = encoder_parser(["--epochs", "3", "--num-gpu", str(num_gpu)]) + + def test_te_bf16(self): + """Test Transformer Engine with BF16""" + actual = train_and_evaluate(self.args) + assert actual[0] < 0.45 and actual[1] > 0.79 + + @unittest.skipIf(not gpu_has_fp8(), reason='GPU capability is not enough to run FP8') + def test_te_fp8(self): + """Test Transformer Engine with FP8""" + self.args.use_fp8 = True + actual = train_and_evaluate(self.args) + assert actual[0] < 0.45 and actual[1] > 0.79 + + +if __name__ 
== "__main__": + train_and_evaluate(encoder_parser(None)) diff --git a/examples/jax/encoder/test_multigpu_encoder.py b/examples/jax/encoder/test_multigpu_encoder.py new file mode 100644 index 0000000000..9cb420b0c8 --- /dev/null +++ b/examples/jax/encoder/test_multigpu_encoder.py @@ -0,0 +1,420 @@ +# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. +""" Encoder training on multi-GPU with data parallelism""" +import argparse +import unittest +from functools import partial + +import jax +import jax.numpy as jnp +import nltk +import numpy as np +import optax +import tensorflow_datasets as tfds +from cuda import cudart +from flax import linen as nn +from flax.core.frozen_dict import FrozenDict +from flax.training import train_state +from jax.experimental import mesh_utils +from jax.experimental.pjit import pjit + +import transformer_engine.jax as te + +DEVICE_DP_AXIS = 'data' +PARAMS_KEY = 'params' +PARAMS_AXES_KEY = PARAMS_KEY + '_axes' +DROPOUT_KEY = 'dropout' +INPUT_KEY = 'input_rng' + + +def check_num_gpu(desired_num_gpu): + """Check if the number of GPUs are correct.""" + actual_num_gpu = len(jax.local_devices()) + assert actual_num_gpu == desired_num_gpu, f"Number of GPUs is mismatch. " \ + f"{desired_num_gpu} GPUs are assigned, but the actual number of GPUs is {actual_num_gpu}" + + +def gpu_has_fp8(): + """Check if the GPU has FP8.""" + cudaSuccess = cudart.cudaError_t.cudaSuccess + ret, gpu_id = cudart.cudaGetDevice() + assert ret == cudaSuccess + flag = cudart.cudaDeviceAttr.cudaDevAttrComputeCapabilityMajor + _, major = cudart.cudaDeviceGetAttribute(flag, gpu_id) + flag = cudart.cudaDeviceAttr.cudaDevAttrComputeCapabilityMinor + _, minor = cudart.cudaDeviceGetAttribute(flag, gpu_id) + sm_arch = major * 10 + minor + return sm_arch >= 89 + + +class Net(nn.Module): + """NLP Encoder""" + num_embed: int + + @nn.compact + def __call__(self, x, mask, disable_dropout=False): + x = nn.Embed(num_embeddings=self.num_embed, features=256, dtype=jnp.bfloat16)(x) + + te_Encoder = partial(te.TransformerLayer, + hidden_size=256, + mlp_hidden_size=1024, + num_attention_heads=8, + hidden_dropout=0.1, + attention_dropout=0.1, + dropout_rng_name=DROPOUT_KEY, + layer_type=te.TransformerLayerType.ENCODER, + enable_relative_embedding=False, + dtype=jnp.bfloat16) + x = te_Encoder()(x, attention_mask=mask, deterministic=disable_dropout) + + x = x.reshape(x.shape[0], -1) + + x = te.DenseGeneral(features=256, sharding_type=te.ShardingType.DP, dtype=jnp.bfloat16)(x) + + x = te.DenseGeneral(features=256, sharding_type=te.ShardingType.DP, dtype=jnp.bfloat16)(x) + + x = nn.Dense(features=2, dtype=jnp.bfloat16)(x) + return x + + +def train_step(state, inputs, masks, labels, var_collect, rngs, use_fp8): + """Computes gradients, loss and accuracy for a single batch.""" + + def loss_fn(var_collect, disable_dropout=False): + logits = state.apply_fn(var_collect, inputs, masks, disable_dropout, rngs=rngs) + one_hot = jax.nn.one_hot(labels, 2) + loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot)) + return loss, logits + + var_collect = FrozenDict({**var_collect, PARAMS_KEY: state.params}) + grad_fn = jax.value_and_grad(loss_fn, has_aux=True) + (loss, logits), grads = grad_fn(var_collect) + accuracy = jnp.mean(jnp.argmax(logits, -1) == labels) + + var_collect, grads = grads.pop(PARAMS_KEY) + state = state.apply_gradients(grads=grads) + if use_fp8: + var_collect = te.update_fp8_metas(var_collect) + + return state, loss, accuracy, 
var_collect + + +def train_epoch(state, train_ds, batch_size, rngs, var_collect, use_fp8, train_fn): + """Train for a single epoch.""" + train_ds_size = len(train_ds['sentence']) + steps_per_epoch = train_ds_size // batch_size + perms = jax.random.permutation(rngs[INPUT_KEY], train_ds_size) + perms = perms[:steps_per_epoch * batch_size] # skip incomplete batch + perms = perms.reshape((steps_per_epoch, batch_size)) + epoch_loss = [] + epoch_accuracy = [] + + for perm in perms: + batch_inputs = train_ds['sentence'][perm, ...] + batch_masks = train_ds['mask'][perm, ...] + batch_labels = train_ds['label'][perm, ...] + state, loss, accuracy, var_collect = train_fn(state, batch_inputs, batch_masks, + batch_labels, var_collect, rngs, use_fp8) + epoch_loss.append(loss) + epoch_accuracy.append(accuracy) + + avg_loss = np.mean(epoch_loss) + avg_accuracy = np.mean(epoch_accuracy) + return state, avg_loss, avg_accuracy, var_collect + + +def eval_step(state, inputs, masks, labels, var_collect): + """Computes loss and accuracy for a single batch.""" + + def loss_fn(var_collect, disable_dropout=False): + logits = state.apply_fn(var_collect, inputs, masks, disable_dropout) + one_hot = jax.nn.one_hot(labels, 2) + loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot)) + return loss, logits + + var_collect = FrozenDict({**var_collect, PARAMS_KEY: state.params}) + loss, logits = loss_fn(var_collect, disable_dropout=True) + accuracy = jnp.mean(jnp.argmax(logits, -1) == labels) + return loss, accuracy + + +def eval_model(state, test_ds, batch_size, var_collect, eval_fn): + """Evaluation loop.""" + test_ds_size = len(test_ds['sentence']) + num_steps = test_ds_size // batch_size + valid_size = num_steps * batch_size + all_loss = [] + all_accuracy = [] + + for batch_start in range(0, valid_size, batch_size): + batch_end = batch_start + batch_size + batch_inputs = test_ds['sentence'][batch_start:batch_end] + batch_masks = test_ds['mask'][batch_start:batch_end] + batch_labels = test_ds['label'][batch_start:batch_end] + loss, accuracy = eval_fn(state, batch_inputs, batch_masks, batch_labels, var_collect) + all_loss.append(loss) + all_accuracy.append(accuracy) + + avg_loss = np.mean(all_loss) + avg_accuracy = np.mean(all_accuracy) + return avg_loss, avg_accuracy + + +def data_preprocess(dataset, vocab, word_id, max_seq_len): + """Convert tokens to numbers.""" + nltk.download('punkt') + dataset_size = len(dataset['sentence']) + output = np.zeros((dataset_size, max_seq_len), dtype=np.int32) + mask_3d = np.empty((dataset_size, max_seq_len, max_seq_len), dtype=np.uint8) + + for j, sentence in enumerate(dataset['sentence']): + tokens = nltk.word_tokenize(sentence.decode("utf-8")) + tensor = output[j] + mask_1d = np.zeros((1, max_seq_len), dtype=np.uint8) + + for i, word in enumerate(tokens): + if i >= max_seq_len: + break + + if word not in vocab: + vocab[word] = word_id + tensor[i] = word_id + word_id = word_id + 1 + else: + tensor[i] = vocab[word] + + mask_1d[0, i] = 1 + + mask_2d = mask_3d[j] + np.dot(mask_1d.T, mask_1d, out=mask_2d) + np.subtract(1, mask_2d, out=mask_2d) + + dataset['sentence'] = output + dataset['label'] = dataset['label'].astype(np.float32) + dataset['mask'] = mask_3d.reshape((dataset_size, 1, max_seq_len, max_seq_len)) + return dataset, vocab, word_id + + +def get_datasets(max_seq_len): + """Load GLUE train and test datasets into memory.""" + vocab = {} + word_id = 0 + dataset = 'glue/cola' + train_ds = tfds.as_numpy(tfds.load(dataset, split='train', batch_size=-1)) + train_ds, 
vocab, word_id = data_preprocess(train_ds, vocab, word_id, max_seq_len) + test_ds = tfds.as_numpy(tfds.load(dataset, split='validation', batch_size=-1)) + test_ds, vocab, word_id = data_preprocess(test_ds, vocab, word_id, max_seq_len) + return train_ds, test_ds, word_id + + +def check_fp8(state, var_collect, inputs, masks, labels): + "Check if model includes FP8." + rngs = {DROPOUT_KEY: jax.random.PRNGKey(0)} + assert "Float8" in str( + jax.make_jaxpr(train_step, static_argnums=6)(state, inputs, masks, labels, var_collect, + rngs, True)) + + +def get_params_pspec(sharding_rules, abs_var_collect): + """Refer params to create params partition spec""" + rules_dict = {} + for key, value in sharding_rules: + rules_dict[key] = value + + def to_device_axis(logical_axis): + partitions = [rules_dict[key] for key in logical_axis] + return jax.sharding.PartitionSpec(*partitions) + + params_axes = abs_var_collect.get(PARAMS_AXES_KEY, {}) + params_axes_pspec = jax.tree_map(to_device_axis, nn.partitioning.get_axis_names(params_axes)) + params_pspec = jax.tree_map(lambda x: jax.sharding.PartitionSpec(), abs_var_collect[PARAMS_KEY]) + params_pspec = FrozenDict({**params_pspec, **params_axes_pspec}) + return params_pspec + + +def get_state_pspec(state, params_pspec): + """Refer params_pspec to create state partition spec""" + + def replace_params(x): + return params_pspec if isinstance(x, FrozenDict) else None + + state_pspec = jax.tree_map(replace_params, state, is_leaf=lambda x: isinstance(x, FrozenDict)) + return state_pspec + + +def train_and_evaluate(args): + """Execute model training and evaluation loop.""" + print(args) + check_num_gpu(args.num_gpu) + assert args.batch_size % args.num_gpu == 0, f"Batch size needs to be multiple of {args.num_gpu}" + assert args.test_batch_size % args.num_gpu == 0, \ + f"Test batch size needs to be multiple of {args.num_gpu}" + + if args.use_fp8: + assert gpu_has_fp8(), "GPU needs to support FP8." 
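`get_params_pspec` above resolves each parameter's logical axis names through the sharding rules into a `jax.sharding.PartitionSpec`. A minimal standalone sketch of that lookup, with made-up logical axes (`'embed'`, `'mlp'`) and a made-up mesh axis (`'model'`):

```python
# Minimal sketch of the logical-axis -> mesh-axis lookup done by get_params_pspec.
# The axis names below are illustrative only.
import jax

sharding_rules = (('embed', None), ('mlp', 'model'))    # (logical axis, mesh axis)
rules = dict(sharding_rules)

def to_device_axis(logical_axes):
    return jax.sharding.PartitionSpec(*(rules[name] for name in logical_axes))

# A kernel annotated with logical axes ('embed', 'mlp') ends up replicated along
# its first dimension and sharded over the 'model' mesh axis along the second.
print(to_device_axis(('embed', 'mlp')))    # PartitionSpec(None, 'model')
```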
+ + device_mesh = mesh_utils.create_device_mesh((args.num_gpu,)) + with jax.sharding.Mesh(devices=device_mesh, axis_names=(DEVICE_DP_AXIS,)): + + rng = jax.random.PRNGKey(args.seed) + rng, params_rng = jax.random.split(rng) + rng, dropout_rng = jax.random.split(rng) + init_rngs = {PARAMS_KEY: params_rng, DROPOUT_KEY: dropout_rng} + + input_shape = [args.batch_size, args.max_seq_len] + mask_shape = [args.batch_size, 1, args.max_seq_len, args.max_seq_len] + label_shape = [args.batch_size] + + with te.fp8_autocast(args.use_fp8, sharding_resource=te.ShardingResource(DEVICE_DP_AXIS)): + train_ds, test_ds, num_embed = get_datasets(args.max_seq_len) + encoder = Net(num_embed) + inputs = jnp.zeros(input_shape, dtype=jnp.int32) + masks = jnp.zeros(mask_shape, dtype=jnp.uint8) + abs_var_collect = jax.eval_shape(encoder.init, init_rngs, inputs, masks) + + sharding_rules = te.extend_logical_axis_rules(tuple()) + params_pspec = get_params_pspec(sharding_rules, abs_var_collect) + inputs_pspec = jax.sharding.PartitionSpec(DEVICE_DP_AXIS, None) + masks_pspec = jax.sharding.PartitionSpec(DEVICE_DP_AXIS, None, None, None) + + in_shardings = (None, inputs_pspec, masks_pspec) + out_shardings = FrozenDict({key: params_pspec if key is PARAMS_KEY else None \ + for key in abs_var_collect}) + pjit_encoder_init = pjit(encoder.init, in_shardings, out_shardings) + var_collect = pjit_encoder_init(init_rngs, inputs, masks) + + optimizer = optax.adamw(args.lr) + var_collect, params = var_collect.pop(PARAMS_KEY) + state = train_state.TrainState.create(apply_fn=encoder.apply, + params=params, + tx=optimizer) + state_pspec = get_state_pspec(state, params_pspec) + labels_pspec = jax.sharding.PartitionSpec(DEVICE_DP_AXIS,) + + in_shardings = (state_pspec, inputs_pspec, masks_pspec, labels_pspec, None, None) + out_shardings = (state_pspec, None, None, None) + pjit_train_step = pjit(train_step, in_shardings, out_shardings, static_argnums=(6,)) + + in_shardings = (state_pspec, inputs_pspec, masks_pspec, labels_pspec, None) + out_shardings = (None, None) + pjit_eval_step = pjit(eval_step, in_shardings, out_shardings) + + if args.use_fp8: + labels = jnp.zeros(label_shape, dtype=jnp.bfloat16) + check_fp8(state, var_collect, inputs, masks, labels) + + if args.dry_run: + labels = jnp.zeros(label_shape, dtype=jnp.bfloat16) + rngs = {DROPOUT_KEY: dropout_rng} + pjit_train_step(state, inputs, masks, labels, var_collect, rngs, args.use_fp8) + print("PASSED") + return None + + for epoch in range(1, args.epochs + 1): + rng, input_rng = jax.random.split(rng) + rng, dropout_rng = jax.random.split(rng) + rngs = {INPUT_KEY: input_rng, DROPOUT_KEY: dropout_rng} + + state, train_loss, train_accuracy, var_collect = train_epoch( + state, train_ds, args.batch_size, rngs, var_collect, args.use_fp8, + pjit_train_step) + + test_loss, test_accuracy = eval_model(state, test_ds, args.test_batch_size, + var_collect, pjit_eval_step) + + print(f"Epoch: {epoch:>2} " + f"Train Loss: {train_loss:.6f} " + f"Train Accuracy: {train_accuracy:.6f} " + f"Test Loss: {test_loss:.6f} " + f"Test Accuracy: {test_accuracy:.6f} ") + + return [train_loss, train_accuracy, test_loss, test_accuracy] + + +def encoder_parser(args): + """Training settings.""" + parser = argparse.ArgumentParser(description="JAX Encoder Example") + parser.add_argument( + "--num-gpu", + type=int, + default=8, + metavar="N", + help="number of GPUs (default: 8)", + ) + parser.add_argument( + "--batch-size", + type=int, + default=64, + metavar="N", + help="input batch size for training (default: 64)", 
+ ) + parser.add_argument( + "--test-batch-size", + type=int, + default=64, + metavar="N", + help="input batch size for testing (default: 64)", + ) + parser.add_argument( + "--max-seq-len", + type=int, + default=32, + metavar="N", + help="maximum sequence length (default: 32)", + ) + parser.add_argument( + "--epochs", + type=int, + default=3, + metavar="N", + help="number of epochs to train (default: 3)", + ) + parser.add_argument( + "--lr", + type=float, + default=0.0001, + metavar="LR", + help="learning rate (default: 0.0001)", + ) + parser.add_argument( + "--dry-run", + action="store_true", + default=False, + help="quickly check a single pass", + ) + parser.add_argument("--seed", type=int, default=1, metavar="S", help="random seed (default: 1)") + parser.add_argument("--use-fp8", + action="store_true", + default=False, + help="Use FP8 for inference and training without recalibration") + + return parser.parse_args(args) + + +class TestEncoder(unittest.TestCase): + """Encoder unittests""" + + @classmethod + def setUpClass(cls): + """Run 3 epochs for testing""" + num_gpu = len(jax.local_devices()) + if num_gpu % 2 != 0: + num_gpu = 1 + cls.args = encoder_parser(["--epochs", "3", "--num-gpu", str(num_gpu)]) + + def test_te_bf16(self): + """Test Transformer Engine with BF16""" + actual = train_and_evaluate(self.args) + assert actual[0] < 0.45 and actual[1] > 0.79 + + @unittest.skipIf(not gpu_has_fp8(), reason='GPU capability is not enough to run FP8') + def test_te_fp8(self): + """Test Transformer Engine with FP8""" + self.args.use_fp8 = True + actual = train_and_evaluate(self.args) + assert actual[0] < 0.45 and actual[1] > 0.79 + + +if __name__ == "__main__": + train_and_evaluate(encoder_parser(None)) diff --git a/examples/jax/encoder/test_single_gpu_bf16_training.py b/examples/jax/encoder/test_single_gpu_bf16_training.py deleted file mode 100644 index 122f2aa599..0000000000 --- a/examples/jax/encoder/test_single_gpu_bf16_training.py +++ /dev/null @@ -1,75 +0,0 @@ -# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. 
-""" Encoder with BF16 Training on single GPU""" -import jax -import jax.numpy as jnp -import optax -from flax.core.frozen_dict import FrozenDict -from flax.training import train_state - -import transformer_engine.jax as te - -PARAMS_KEY = 'params' - -BATCH = 32 -SEQLEN = 512 -HIDDEN = 1024 - - -def network(): - """NLP Encoder""" - encoder = te.TransformerLayer(hidden_size=HIDDEN, - mlp_hidden_size=4 * HIDDEN, - hidden_dropout=0.0, - attention_dropout=0.0, - layernorm_type='rmsnorm', - mlp_activations=('gelu', 'linear'), - layer_type=te.TransformerLayerType.ENCODER, - transpose_batch_sequence=True, - dtype=jnp.bfloat16) - return encoder - - -def synthesis_data(data_rng): - """Dataset generator""" - return jax.random.normal(data_rng, [SEQLEN, BATCH, HIDDEN], jnp.bfloat16) - - -def train_step(batch, state, others): - """Training function.""" - - def loss_fn(collections): - logits = state.apply_fn(collections, batch) - loss = jnp.mean(logits) - return loss - - grad_fn = jax.value_and_grad(loss_fn) - loss, grads = grad_fn(FrozenDict({PARAMS_KEY: state.params, **others})) - grads, params_grads = grads.pop(PARAMS_KEY) - state = state.apply_gradients(grads=params_grads) - return loss, state, others - - -def test_encoder(): - """Encoder example""" - rng = jax.random.PRNGKey(0) - rng, init_rng, data_rng = jax.random.split(rng, 3) - inputs = synthesis_data(data_rng) - - encoder = network() - variables = jax.jit(encoder.init)(init_rng, inputs) - variables, params = variables.pop(PARAMS_KEY) - optimizer = optax.sgd(0.001, 0.9) - state = train_state.TrainState.create(apply_fn=encoder.apply, params=params, tx=optimizer) - jitted_train_step = jax.jit(train_step) - - for i in range(5): - rng, data_rng = jax.random.split(rng) - inputs = synthesis_data(data_rng) - loss, state, variables = jitted_train_step(inputs, state, variables) - print(f"Step {i} - Loss: {loss}") - - -if __name__ == "__main__": - test_encoder() diff --git a/examples/jax/encoder/test_single_gpu_encoder.py b/examples/jax/encoder/test_single_gpu_encoder.py new file mode 100644 index 0000000000..bac1469b5b --- /dev/null +++ b/examples/jax/encoder/test_single_gpu_encoder.py @@ -0,0 +1,344 @@ +# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. 
+""" Encoder training on single GPU""" +import argparse +import os +import unittest +from functools import partial + +import jax +import jax.numpy as jnp +import nltk +import numpy as np +import optax +import tensorflow_datasets as tfds +from cuda import cudart +from flax import linen as nn +from flax.core.frozen_dict import FrozenDict +from flax.training import train_state + +import transformer_engine.jax as te + +PARAMS_KEY = 'params' +DROPOUT_KEY = 'dropout' +INPUT_KEY = 'input_rng' + + +def gpu_has_fp8(): + """Check if the GPU has FP8.""" + cudaSuccess = cudart.cudaError_t.cudaSuccess + ret, gpu_id = cudart.cudaGetDevice() + assert ret == cudaSuccess + flag = cudart.cudaDeviceAttr.cudaDevAttrComputeCapabilityMajor + _, major = cudart.cudaDeviceGetAttribute(flag, gpu_id) + flag = cudart.cudaDeviceAttr.cudaDevAttrComputeCapabilityMinor + _, minor = cudart.cudaDeviceGetAttribute(flag, gpu_id) + sm_arch = major * 10 + minor + return sm_arch >= 89 + + +class Net(nn.Module): + """NLP Encoder""" + num_embed: int + + @nn.compact + def __call__(self, x, mask, disable_dropout=False): + x = nn.Embed(num_embeddings=self.num_embed, features=256, dtype=jnp.bfloat16)(x) + + te_Encoder = partial(te.TransformerLayer, + hidden_size=256, + mlp_hidden_size=1024, + num_attention_heads=8, + hidden_dropout=0.1, + attention_dropout=0.1, + dropout_rng_name=DROPOUT_KEY, + layer_type=te.TransformerLayerType.ENCODER, + enable_relative_embedding=False, + dtype=jnp.bfloat16) + x = te_Encoder()(x, attention_mask=mask, deterministic=disable_dropout) + + x = x.reshape(x.shape[0], -1) + + x = te.DenseGeneral(features=256, dtype=jnp.bfloat16)(x) + + x = te.DenseGeneral(features=256, dtype=jnp.bfloat16)(x) + + x = nn.Dense(features=2, dtype=jnp.bfloat16)(x) + return x + + +@partial(jax.jit, static_argnums=6) +def train_step(state, inputs, masks, labels, var_collect, rngs, use_fp8): + """Computes gradients, loss and accuracy for a single batch.""" + + def loss_fn(var_collect, disable_dropout=False): + logits = state.apply_fn(var_collect, inputs, masks, disable_dropout, rngs=rngs) + one_hot = jax.nn.one_hot(labels, 2) + loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot)) + return loss, logits + + var_collect = FrozenDict({**var_collect, PARAMS_KEY: state.params}) + grad_fn = jax.value_and_grad(loss_fn, has_aux=True) + (loss, logits), grads = grad_fn(var_collect) + accuracy = jnp.mean(jnp.argmax(logits, -1) == labels) + + var_collect, grads = grads.pop(PARAMS_KEY) + state = state.apply_gradients(grads=grads) + if use_fp8: + var_collect = te.update_fp8_metas(var_collect) + + return state, loss, accuracy, var_collect + + +def train_epoch(state, train_ds, batch_size, rngs, var_collect, use_fp8): + """Train for a single epoch.""" + train_ds_size = len(train_ds['sentence']) + steps_per_epoch = train_ds_size // batch_size + perms = jax.random.permutation(rngs[INPUT_KEY], train_ds_size) + perms = perms[:steps_per_epoch * batch_size] # skip incomplete batch + perms = perms.reshape((steps_per_epoch, batch_size)) + epoch_loss = [] + epoch_accuracy = [] + + for perm in perms: + batch_inputs = train_ds['sentence'][perm, ...] + batch_masks = train_ds['mask'][perm, ...] + batch_labels = train_ds['label'][perm, ...] 
+ state, loss, accuracy, var_collect = train_step(state, batch_inputs, batch_masks, + batch_labels, var_collect, rngs, use_fp8) + epoch_loss.append(loss) + epoch_accuracy.append(accuracy) + + avg_loss = np.mean(epoch_loss) + avg_accuracy = np.mean(epoch_accuracy) + return state, avg_loss, avg_accuracy, var_collect + + +@jax.jit +def eval_step(state, inputs, masks, labels, var_collect): + """Computes loss and accuracy for a single batch.""" + + def loss_fn(var_collect, disable_dropout=False): + logits = state.apply_fn(var_collect, inputs, masks, disable_dropout) + one_hot = jax.nn.one_hot(labels, 2) + loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot)) + return loss, logits + + var_collect = FrozenDict({**var_collect, PARAMS_KEY: state.params}) + loss, logits = loss_fn(var_collect, disable_dropout=True) + accuracy = jnp.mean(jnp.argmax(logits, -1) == labels) + return loss, accuracy + + +def eval_model(state, test_ds, batch_size, var_collect): + """Evaluation loop.""" + test_ds_size = len(test_ds['sentence']) + num_steps = test_ds_size // batch_size + valid_size = num_steps * batch_size + all_loss = [] + all_accuracy = [] + + for batch_start in range(0, valid_size, batch_size): + batch_end = batch_start + batch_size + batch_inputs = test_ds['sentence'][batch_start:batch_end] + batch_masks = test_ds['mask'][batch_start:batch_end] + batch_labels = test_ds['label'][batch_start:batch_end] + loss, accuracy = eval_step(state, batch_inputs, batch_masks, batch_labels, var_collect) + all_loss.append(loss) + all_accuracy.append(accuracy) + + avg_loss = np.mean(all_loss) + avg_accuracy = np.mean(all_accuracy) + return avg_loss, avg_accuracy + + +def data_preprocess(dataset, vocab, word_id, max_seq_len): + """Convert tokens to numbers.""" + nltk.download('punkt') + dataset_size = len(dataset['sentence']) + output = np.zeros((dataset_size, max_seq_len), dtype=np.int32) + mask_3d = np.empty((dataset_size, max_seq_len, max_seq_len), dtype=np.uint8) + + for j, sentence in enumerate(dataset['sentence']): + tokens = nltk.word_tokenize(sentence.decode("utf-8")) + tensor = output[j] + mask_1d = np.zeros((1, max_seq_len), dtype=np.uint8) + + for i, word in enumerate(tokens): + if i >= max_seq_len: + break + + if word not in vocab: + vocab[word] = word_id + tensor[i] = word_id + word_id = word_id + 1 + else: + tensor[i] = vocab[word] + + mask_1d[0, i] = 1 + + mask_2d = mask_3d[j] + np.dot(mask_1d.T, mask_1d, out=mask_2d) + np.subtract(1, mask_2d, out=mask_2d) + + dataset['sentence'] = output + dataset['label'] = dataset['label'].astype(np.float32) + dataset['mask'] = mask_3d.reshape((dataset_size, 1, max_seq_len, max_seq_len)) + return dataset, vocab, word_id + + +def get_datasets(max_seq_len): + """Load GLUE train and test datasets into memory.""" + vocab = {} + word_id = 0 + dataset = 'glue/cola' + train_ds = tfds.as_numpy(tfds.load(dataset, split='train', batch_size=-1)) + train_ds, vocab, word_id = data_preprocess(train_ds, vocab, word_id, max_seq_len) + test_ds = tfds.as_numpy(tfds.load(dataset, split='validation', batch_size=-1)) + test_ds, vocab, word_id = data_preprocess(test_ds, vocab, word_id, max_seq_len) + return train_ds, test_ds, word_id + + +def check_fp8(state, var_collect, inputs, masks, labels): + "Check if model includes FP8." 
+ rngs = {DROPOUT_KEY: jax.random.PRNGKey(0)} + assert "Float8" in str( + jax.make_jaxpr(train_step, static_argnums=6)(state, inputs, masks, labels, var_collect, + rngs, True)) + + +def train_and_evaluate(args): + """Execute model training and evaluation loop.""" + os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false" + print(args) + + if args.use_fp8: + assert gpu_has_fp8(), "GPU needs to support FP8." + + rng = jax.random.PRNGKey(args.seed) + rng, params_rng = jax.random.split(rng) + rng, dropout_rng = jax.random.split(rng) + init_rngs = {PARAMS_KEY: params_rng, DROPOUT_KEY: dropout_rng} + + input_shape = [args.batch_size, args.max_seq_len] + mask_shape = [args.batch_size, 1, args.max_seq_len, args.max_seq_len] + label_shape = [args.batch_size] + + with te.fp8_autocast(enabled=args.use_fp8): + train_ds, test_ds, num_embed = get_datasets(args.max_seq_len) + encoder = Net(num_embed) + inputs = jnp.zeros(input_shape, dtype=jnp.int32) + masks = jnp.zeros(mask_shape, dtype=jnp.uint8) + var_collect = encoder.init(init_rngs, inputs, masks) + tx = optax.adamw(args.lr) + state = train_state.TrainState.create(apply_fn=encoder.apply, + params=var_collect[PARAMS_KEY], + tx=tx) + + if args.use_fp8: + labels = jnp.zeros(label_shape, dtype=jnp.bfloat16) + check_fp8(state, var_collect, inputs, masks, labels) + + if args.dry_run: + labels = jnp.zeros(label_shape, dtype=jnp.bfloat16) + rngs = {DROPOUT_KEY: dropout_rng} + train_step(state, inputs, masks, labels, var_collect, rngs, args.use_fp8) + print("PASSED") + return None + + for epoch in range(1, args.epochs + 1): + rng, input_rng = jax.random.split(rng) + rng, dropout_rng = jax.random.split(rng) + rngs = {INPUT_KEY: input_rng, DROPOUT_KEY: dropout_rng} + + state, train_loss, train_accuracy, var_collect = train_epoch( + state, train_ds, args.batch_size, rngs, var_collect, args.use_fp8) + + test_loss, test_accuracy = eval_model(state, test_ds, args.test_batch_size, var_collect) + + print(f"Epoch: {epoch:>2} " + f"Train Loss: {train_loss:.6f} " + f"Train Accuracy: {train_accuracy:.6f} " + f"Test Loss: {test_loss:.6f} " + f"Test Accuracy: {test_accuracy:.6f} ") + + return [train_loss, train_accuracy, test_loss, test_accuracy] + + +def encoder_parser(args): + """Training settings.""" + parser = argparse.ArgumentParser(description="JAX Encoder Example") + parser.add_argument( + "--batch-size", + type=int, + default=64, + metavar="N", + help="input batch size for training (default: 64)", + ) + parser.add_argument( + "--test-batch-size", + type=int, + default=64, + metavar="N", + help="input batch size for testing (default: 64)", + ) + parser.add_argument( + "--max-seq-len", + type=int, + default=32, + metavar="N", + help="maximum sequence length (default: 32)", + ) + parser.add_argument( + "--epochs", + type=int, + default=3, + metavar="N", + help="number of epochs to train (default: 3)", + ) + parser.add_argument( + "--lr", + type=float, + default=0.0001, + metavar="LR", + help="learning rate (default: 0.0001)", + ) + parser.add_argument( + "--dry-run", + action="store_true", + default=False, + help="quickly check a single pass", + ) + parser.add_argument("--seed", type=int, default=1, metavar="S", help="random seed (default: 1)") + parser.add_argument("--use-fp8", + action="store_true", + default=False, + help="Use FP8 for inference and training without recalibration") + + return parser.parse_args(args) + + +class TestEncoder(unittest.TestCase): + """Encoder unittests""" + + @classmethod + def setUpClass(cls): + """Run 4 epochs for testing""" + cls.args = 
encoder_parser(["--epochs", "3"]) + + def test_te_bf16(self): + """Test Transformer Engine with BF16""" + actual = train_and_evaluate(self.args) + assert actual[0] < 0.45 and actual[1] > 0.79 + + @unittest.skipIf(not gpu_has_fp8(), reason='GPU capability is not enough to run FP8') + def test_te_fp8(self): + """Test Transformer Engine with FP8""" + self.args.use_fp8 = True + actual = train_and_evaluate(self.args) + assert actual[0] < 0.45 and actual[1] > 0.79 + + +if __name__ == "__main__": + train_and_evaluate(encoder_parser(None)) diff --git a/examples/jax/encoder/test_single_gpu_fp8_training.py b/examples/jax/encoder/test_single_gpu_fp8_training.py deleted file mode 100644 index f03b43250a..0000000000 --- a/examples/jax/encoder/test_single_gpu_fp8_training.py +++ /dev/null @@ -1,99 +0,0 @@ -# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. -""" Encoder with FP8 Training on single GPU""" -import jax -import jax.numpy as jnp -import optax -from cuda import cudart -from flax.core.frozen_dict import FrozenDict -from flax.training import train_state - -import transformer_engine.jax as te -from transformer_engine.jax.fp8 import FP8Helper -from transformer_engine.common.recipe import Format as FP8Format -from transformer_engine.common.recipe import DelayedScaling - -PARAMS_KEY = 'params' - -BATCH = 32 -SEQLEN = 512 -HIDDEN = 1024 - - -def gpu_has_fp8(): - """GPU arch has to support FP8""" - cudaSuccess = cudart.cudaError_t.cudaSuccess - ret, gpu_id = cudart.cudaGetDevice() - assert ret == cudaSuccess - flag = cudart.cudaDeviceAttr.cudaDevAttrComputeCapabilityMajor - _, major = cudart.cudaDeviceGetAttribute(flag, gpu_id) - flag = cudart.cudaDeviceAttr.cudaDevAttrComputeCapabilityMinor - _, minor = cudart.cudaDeviceGetAttribute(flag, gpu_id) - sm_arch = major * 10 + minor - return sm_arch >= 89 - - -def network(): - """NLP Encoder""" - encoder = te.TransformerLayer(hidden_size=HIDDEN, - mlp_hidden_size=4 * HIDDEN, - hidden_dropout=0.0, - attention_dropout=0.0, - layernorm_type='rmsnorm', - mlp_activations=('gelu', 'linear'), - layer_type=te.TransformerLayerType.ENCODER, - transpose_batch_sequence=True, - dtype=jnp.bfloat16) - return encoder - - -def synthesis_data(data_rng): - """Dataset generator""" - return jax.random.normal(data_rng, [SEQLEN, BATCH, HIDDEN], jnp.bfloat16) - - -def train_step(batch, state, others): - """Training function.""" - - def loss_fn(collections): - logits = state.apply_fn(collections, batch) - loss = jnp.mean(logits) - return loss - - grad_fn = jax.value_and_grad(loss_fn) - loss, grads = grad_fn(FrozenDict({PARAMS_KEY: state.params, **others})) - grads, params_grads = grads.pop(PARAMS_KEY) - state = state.apply_gradients(grads=params_grads) - others = FP8Helper.update_fp8_metas(grads) - return loss, state, others - - -def test_encoder(): - """Encoder example""" - if gpu_has_fp8() is False: - print("GPU doesn't support FP8") - return - - rng = jax.random.PRNGKey(0) - rng, init_rng, data_rng = jax.random.split(rng, 3) - inputs = synthesis_data(data_rng) - optimizer = optax.sgd(0.001, 0.9) - - with te.fp8_autocast(enabled=True, fp8_recipe=DelayedScaling(fp8_format=FP8Format.HYBRID)): - encoder = network() - variables = jax.jit(encoder.init)(init_rng, inputs) - variables, params = variables.pop(PARAMS_KEY) - state = train_state.TrainState.create(apply_fn=encoder.apply, params=params, tx=optimizer) - jitted_train_step = jax.jit(train_step) - assert "fp8" in str(jax.make_jaxpr(jitted_train_step)(inputs, 
state, variables)) - - for i in range(5): - rng, data_rng = jax.random.split(rng) - inputs = synthesis_data(data_rng) - loss, state, variables = jitted_train_step(inputs, state, variables) - print(f"Step {i} - Loss: {loss}") - - -if __name__ == "__main__": - test_encoder() diff --git a/examples/jax/mnist/README.md b/examples/jax/mnist/README.md new file mode 100644 index 0000000000..51e4f45f5f --- /dev/null +++ b/examples/jax/mnist/README.md @@ -0,0 +1,34 @@ +# Basic MNIST Example with Optional FP8 # + +This example uses MNIST training to demonstrate the Transformer Engine usage. The Transformer Engine is built on top of [Flax](https://github.com/google/flax), a neural network library and ecosystem for JAX. Thus, the Transformer Engine is free to interoperate with other Flax modules. The basic Flax usage can be referred to [Flax Basics](https://flax.readthedocs.io/en/latest/guides/flax_basics.html). + +1. Setup dataset: The first step is to prepare the dataset. This is done by using the `tfds` library to download the MNIST dataset and perform image preprocessing. The `get_datasets` routine is used for this purpose. + +2. Define model: The `Net` class is a small CNN model for image classification. It has an option to switch between using `nn.Dense` provided by Flax and `te.DenseGeneral` provided by the Transformer Engine. This allows for easy comparison between the two libraries. + +3. Build training loop: The `train_and_evaluate` is the main routine to initialize the model and start training and evaluating. For FP8 training, the key is `te.fp8_autocast` context manager. If fp8_autocast is enabled, it will cast all `te.DenseGeneral` to FP8 precision. The `var_collect` is a collection including needed information for model training, such as parameters and FP8 metadata, which is necessary for correct casting of BF16 tensors into FP8 tensors at runtime. If fp8_autocast is turned on and print var_collect, you will see FP8 metadata inside, such as `fp8_meta_collection` section. The training and evaluating with FP8 have to be done under fp8_autocast. If not, then fp8_autocast will deconstruct the FP8 metadata, and the model will fall back to higher floating point precision, such as BF16 in this example. To check if FP8 is enabled, use the `check_fp8` routine. If model initialization with FP8 works fine, the string returned by jax.make_jaxpr should include the `Float8` keyword. + +4. Training process: In `apply_model`, the main difference between normal Flax usage and this example is, with FP8 training, the FP8 metadata has to be filled into the gradient function `grad_fn`. Otherwise, the Transformer Engine doesn't know how to cast the BF16 tensor into FP8 tensor at runtime correctly. The FP8 metadata doesn't belong in model parameters (`state.params`), so we need to manually combine the metadata and latest model parameters into var_collect as a frozen dictionary and fill it to the gradient function. After getting loss and gradient, we also need to call `te.update_fp8_metas` to update FP8 metadata in the `update_model` routine. The number of training steps to update FP8 metadata can be customized. In this example, it is updated every step. + +5. Evaluating process: The evaluating process is the same as the training process. Need to ensure FP8 metadata is inside var_collect and fill it into loss function. + +6. Additional options: The `te.fp8_autocast` context manager has additional options + * FP8 Recipe: control FP8 training behavior. 
See the [FP8 tutorial](https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/examples/fp8_primer.html) for a detailed explanation of FP8 recipes and the supported options. **Noted** that FP8 metadata is now the responsibility of the user to update (i.e., manually calling `te.update_fp8_metas`). The JAX version of Transformer Engine cannot update FP8 metadata on its own. + * Sharding Resource: tell Transformer Engine how to make data parallelism and tensor parallelism. We will introduce it more in Encoder examples. + +## Run ## + +1. Use Flax to train MNIST with BF16 as usual +```bash +python test_single_gpu_mnist.py +``` + +2. Use `te.DenseGeneral` provided by Transformer Engine to train MNIST with BF16 +```bash +python test_single_gpu_mnist.py --use-te +``` + +3. Use `te.DenseGeneral` provided by Transformer Engine to train MNIST and enable FP8 training and evaluation. +```bash +python test_single_gpu_mnist.py --use-fp8 +``` diff --git a/examples/jax/mnist/requirements.txt b/examples/jax/mnist/requirements.txt new file mode 100644 index 0000000000..b5b1aca343 --- /dev/null +++ b/examples/jax/mnist/requirements.txt @@ -0,0 +1,3 @@ +flax +optax +tensorflow-datasets diff --git a/examples/jax/mnist/test_single_gpu_mnist.py b/examples/jax/mnist/test_single_gpu_mnist.py new file mode 100644 index 0000000000..0b16dd8b98 --- /dev/null +++ b/examples/jax/mnist/test_single_gpu_mnist.py @@ -0,0 +1,311 @@ +# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. +""" MNIST training on single GPU""" +import argparse +import os +import unittest +from functools import partial + +import jax +import jax.numpy as jnp +import numpy as np +import optax +import tensorflow_datasets as tfds +from cuda import cudart +from flax import linen as nn +from flax.core.frozen_dict import FrozenDict +from flax.training import train_state + +import transformer_engine.jax as te + +IMAGE_H = 28 +IMAGE_W = 28 +IMAGE_C = 1 +PARAMS_KEY = 'params' +DROPOUT_KEY = 'dropout' +INPUT_KEY = 'input_rng' + + +def gpu_has_fp8(): + """Check if the GPU has FP8.""" + cudaSuccess = cudart.cudaError_t.cudaSuccess + ret, gpu_id = cudart.cudaGetDevice() + assert ret == cudaSuccess + flag = cudart.cudaDeviceAttr.cudaDevAttrComputeCapabilityMajor + _, major = cudart.cudaDeviceGetAttribute(flag, gpu_id) + flag = cudart.cudaDeviceAttr.cudaDevAttrComputeCapabilityMinor + _, minor = cudart.cudaDeviceGetAttribute(flag, gpu_id) + sm_arch = major * 10 + minor + return sm_arch >= 89 + + +class Net(nn.Module): + """CNN model for MNIST.""" + use_te: bool = False + + @nn.compact + def __call__(self, x, disable_dropout=False): + if self.use_te: + nn_Dense = te.DenseGeneral + else: + nn_Dense = nn.Dense + + x = nn.Conv(features=32, kernel_size=(3, 3), strides=1, dtype=jnp.bfloat16)(x) + x = nn.relu(x) + x = nn.Conv(features=64, kernel_size=(3, 3), strides=1, dtype=jnp.bfloat16)(x) + x = nn.relu(x) + x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2)) + x = nn.Dropout(rate=0.25)(x, deterministic=disable_dropout) + x = x.reshape(x.shape[0], -1) + x = nn_Dense(features=128, dtype=jnp.bfloat16)(x) + x = nn.relu(x) + x = nn.Dropout(rate=0.5)(x, deterministic=disable_dropout) + x = nn_Dense(features=16, dtype=jnp.bfloat16)(x) + x = nn.Dense(features=10, dtype=jnp.bfloat16)(x) + return x + + +@jax.jit +def apply_model(state, images, labels, var_collect, rngs=None): + """Computes gradients, loss and accuracy for a single batch.""" + + def loss_fn(var_collect, disable_dropout=False): 
+ logits = state.apply_fn(var_collect, images, disable_dropout, rngs=rngs) + one_hot = jax.nn.one_hot(labels, 10) + loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot)) + return loss, logits + + var_collect = FrozenDict({**var_collect, PARAMS_KEY: state.params}) + + if rngs is not None: + grad_fn = jax.value_and_grad(loss_fn, has_aux=True) + (loss, logits), grads = grad_fn(var_collect) + else: + loss, logits = loss_fn(var_collect, disable_dropout=True) + grads = None + + accuracy = jnp.mean(jnp.argmax(logits, -1) == labels) + return grads, loss, accuracy + + +@partial(jax.jit, static_argnums=2) +def update_model(state, grads, use_fp8): + """Update model params and FP8 meta.""" + state = state.apply_gradients(grads=grads[PARAMS_KEY]) + if use_fp8: + grads = te.update_fp8_metas(grads) + return state, grads + + +def train_epoch(state, train_ds, batch_size, rngs, var_collect, use_fp8): + """Train for a single epoch.""" + train_ds_size = len(train_ds['image']) + steps_per_epoch = train_ds_size // batch_size + perms = jax.random.permutation(rngs[INPUT_KEY], train_ds_size) + perms = perms[:steps_per_epoch * batch_size] # skip incomplete batch + perms = perms.reshape((steps_per_epoch, batch_size)) + epoch_loss = [] + epoch_accuracy = [] + + for perm in perms: + batch_images = train_ds['image'][perm, ...] + batch_labels = train_ds['label'][perm, ...] + grads, loss, accuracy = apply_model(state, batch_images, batch_labels, var_collect, rngs) + state, var_collect = update_model(state, grads, use_fp8) + epoch_loss.append(loss) + epoch_accuracy.append(accuracy) + + avg_loss = np.mean(epoch_loss) + avg_accuracy = np.mean(epoch_accuracy) + return state, avg_loss, avg_accuracy, var_collect + + +def eval_model(state, test_ds, batch_size, var_collect): + """Evaluation loop.""" + test_ds_size = len(test_ds['image']) + num_steps = test_ds_size // batch_size + valid_size = num_steps * batch_size + all_loss = [] + all_accuracy = [] + + for batch_start in range(0, valid_size, batch_size): + batch_end = batch_start + batch_size + batch_images = test_ds['image'][batch_start:batch_end] + batch_labels = test_ds['label'][batch_start:batch_end] + _, loss, accuracy = apply_model(state, batch_images, batch_labels, var_collect) + all_loss.append(loss) + all_accuracy.append(accuracy) + + avg_loss = np.mean(all_loss) + avg_accuracy = np.mean(all_accuracy) + return avg_loss, avg_accuracy + + +def get_datasets(): + """Load MNIST train and test datasets into memory.""" + ds_builder = tfds.builder('mnist') + ds_builder.download_and_prepare() + train_ds = tfds.as_numpy(ds_builder.as_dataset(split='train', batch_size=-1)) + test_ds = tfds.as_numpy(ds_builder.as_dataset(split='test', batch_size=-1)) + train_ds['image'] = jnp.float32(train_ds['image']) / 255. + test_ds['image'] = jnp.float32(test_ds['image']) / 255. + return train_ds, test_ds + + +def check_fp8(state, var_collect, input_shape, label_shape): + "Check if model includes FP8." + assert "Float8" in str( + jax.make_jaxpr(apply_model)(state, jnp.empty(input_shape, dtype=jnp.bfloat16), + jnp.empty(label_shape, dtype=jnp.bfloat16), var_collect)) + + +def train_and_evaluate(args): + """Execute model training and evaluation loop.""" + os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false" + print(args) + + if args.use_fp8: + assert gpu_has_fp8(), "GPU needs to support FP8." 
+ args.use_te = True + + train_ds, test_ds = get_datasets() + rng = jax.random.PRNGKey(args.seed) + rng, params_rng = jax.random.split(rng) + rng, dropout_rng = jax.random.split(rng) + init_rngs = {PARAMS_KEY: params_rng, DROPOUT_KEY: dropout_rng} + + input_shape = [args.batch_size, IMAGE_H, IMAGE_W, IMAGE_C] + label_shape = [args.batch_size] + + with te.fp8_autocast(enabled=args.use_fp8): + cnn = Net(args.use_te) + var_collect = cnn.init(init_rngs, jnp.empty(input_shape, dtype=jnp.bfloat16)) + tx = optax.sgd(args.lr, args.momentum) + state = train_state.TrainState.create(apply_fn=cnn.apply, + params=var_collect[PARAMS_KEY], + tx=tx) + + if args.use_fp8: + check_fp8(state, var_collect, input_shape, label_shape) + + if args.dry_run: + apply_model(state, jnp.empty(input_shape, dtype=jnp.bfloat16), + jnp.empty(label_shape, dtype=jnp.bfloat16), var_collect, + {DROPOUT_KEY: dropout_rng}) + print("PASSED") + return None + + for epoch in range(1, args.epochs + 1): + rng, input_rng = jax.random.split(rng) + rng, dropout_rng = jax.random.split(rng) + rngs = {INPUT_KEY: input_rng, DROPOUT_KEY: dropout_rng} + + state, train_loss, train_accuracy, var_collect = train_epoch( + state, train_ds, args.batch_size, rngs, var_collect, args.use_fp8) + test_loss, test_accuracy = eval_model(state, test_ds, args.test_batch_size, var_collect) + + print(f"Epoch: {epoch:>2} " + f"Train Loss: {train_loss:.6f} " + f"Train Accuracy: {train_accuracy:.6f} " + f"Test Loss: {test_loss:.6f} " + f"Test Accuracy: {test_accuracy:.6f} ") + + return [train_loss, train_accuracy, test_loss, test_accuracy] + + +def mnist_parser(args): + """Training settings.""" + parser = argparse.ArgumentParser(description="JAX MNIST Example") + parser.add_argument( + "--batch-size", + type=int, + default=64, + metavar="N", + help="input batch size for training (default: 64)", + ) + parser.add_argument( + "--test-batch-size", + type=int, + default=800, + metavar="N", + help="input batch size for testing (default: 800)", + ) + parser.add_argument( + "--epochs", + type=int, + default=10, + metavar="N", + help="number of epochs to train (default: 10)", + ) + parser.add_argument( + "--lr", + type=float, + default=0.01, + metavar="LR", + help="learning rate (default: 0.01)", + ) + parser.add_argument( + "--momentum", + type=float, + default=0.9, + metavar="M", + help="Momentum (default: 0.9)", + ) + parser.add_argument( + "--dry-run", + action="store_true", + default=False, + help="quickly check a single pass", + ) + parser.add_argument("--seed", type=int, default=1, metavar="S", help="random seed (default: 1)") + parser.add_argument("--use-fp8", + action="store_true", + default=False, + help="Use FP8 for inference and training without recalibration. 
" \ + "It also enables Transformer Engine implicitly.") + parser.add_argument("--use-te", + action="store_true", + default=False, + help="Use Transformer Engine") + + return parser.parse_args(args) + + +class TestMNIST(unittest.TestCase): + """MNIST unittests""" + + @classmethod + def setUpClass(cls): + """Run MNIST without Transformer Engine""" + cls.args = mnist_parser(["--epochs", "5"]) + + @staticmethod + def verify(actual): + """Check If loss and accuracy match target""" + desired_traing_loss = 0.055 + desired_traing_accuracy = 0.98 + desired_test_loss = 0.035 + desired_test_accuracy = 0.098 + assert actual[0] < desired_traing_loss + assert actual[1] > desired_traing_accuracy + assert actual[2] < desired_test_loss + assert actual[3] > desired_test_accuracy + + def test_te_bf16(self): + """Test Transformer Engine with BF16""" + self.args.use_te = True + self.args.use_fp8 = False + actual = train_and_evaluate(self.args) + self.verify(actual) + + @unittest.skipIf(not gpu_has_fp8(), reason='GPU capability is not enough to run FP8') + def test_te_fp8(self): + """Test Transformer Engine with FP8""" + self.args.use_fp8 = True + actual = train_and_evaluate(self.args) + self.verify(actual) + + +if __name__ == "__main__": + train_and_evaluate(mnist_parser(None)) diff --git a/qa/L0_jax_unittest/test.sh b/qa/L0_jax_unittest/test.sh index c040e973bf..247a388edb 100644 --- a/qa/L0_jax_unittest/test.sh +++ b/qa/L0_jax_unittest/test.sh @@ -6,4 +6,7 @@ set -xe : ${TE_PATH:=/opt/transformerengine} pytest -Wignore -v $TE_PATH/tests/jax + +pip install -r $TE_PATH/examples/jax/mnist/requirements.txt +pip install -r $TE_PATH/examples/jax/encoder/requirements.txt pytest -Wignore -v $TE_PATH/examples/jax diff --git a/qa/L0_license/config.json b/qa/L0_license/config.json index f9a93a70f5..ad47393434 100644 --- a/qa/L0_license/config.json +++ b/qa/L0_license/config.json @@ -17,7 +17,9 @@ "VERSION", "Doxyfile", "pylintrc", - ".json" + ".json", + ".md", + ".txt" ], "exclude_copyright": [], "copyright_only": false diff --git a/qa/L0_license/copyright_checker.py b/qa/L0_license/copyright_checker.py index c2f462e690..cd80b957da 100644 --- a/qa/L0_license/copyright_checker.py +++ b/qa/L0_license/copyright_checker.py @@ -69,6 +69,7 @@ def get_file_type(path): "txt": ["txt"], "cfg": ["cfg"], "sh": ["sh"], + "md": ["md"], } tmp = path.split(".") for filetype, ext_list in ext.items(): diff --git a/tests/jax/test_mnist.py b/tests/jax/test_mnist.py deleted file mode 100644 index ce5d9e4d8c..0000000000 --- a/tests/jax/test_mnist.py +++ /dev/null @@ -1,227 +0,0 @@ -# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. 
- -import os -import tempfile -import unittest -from functools import partial - -import jax -import jax.numpy as jnp -import numpy as np -import optax -import tensorflow_datasets as tfds -from flax import linen as nn -from flax.training import train_state - -from transformer_engine.common.recipe import Format as FP8Format -from transformer_engine.jax import DenseGeneral -from transformer_engine.jax.fp8 import FP8Helper -from utils import is_fp8_supported - - -class MLPNN(nn.Module): - - use_fp8_dense: bool = True - - @nn.compact - def __call__(self, x): - x = x.reshape((x.shape[0], -1)) # flatten - x = nn.Dense(features=512)(x) - x = nn.relu(x) - - features = [256, 256, 128] - for feature in features: - x = DenseGeneral(features=feature, transpose_batch_sequence=False, - dtype=jnp.bfloat16, use_bias=True)(x) \ - if self.use_fp8_dense else nn.Dense(features=feature)(x) - x = nn.relu(x) - - x = nn.Dense(features=10, use_bias=True)(x) - return x - - -def cross_entropy_loss(*, logits, labels): - labels_onehot = jax.nn.one_hot(labels, num_classes=10) - return optax.softmax_cross_entropy(logits=logits, labels=labels_onehot).mean() - - -def compute_metrics(*, logits, labels): - loss = cross_entropy_loss(logits=logits, labels=labels) - accuracy = jnp.mean(jnp.argmax(logits, -1) == labels) - metrics = { - 'loss': loss, - 'accuracy': accuracy, - } - return metrics - - -def get_datasets(): - """Load MNIST train and test datasets into memory.""" - ds_builder = tfds.builder('mnist', data_dir="/tmp/tensorflow-datasets/downloads") - ds_builder.download_and_prepare() - train_ds = tfds.as_numpy(ds_builder.as_dataset(split='train', batch_size=-1)) - test_ds = tfds.as_numpy(ds_builder.as_dataset(split='test', batch_size=-1)) - train_ds['image'] = jnp.float32(train_ds['image']) / 255. - test_ds['image'] = jnp.float32(test_ds['image']) / 255. - return train_ds, test_ds - - -def create_train_state(rng, learning_rate, momentum, use_fp8_dense): - """Creates initial `TrainState`.""" - cnn = MLPNN(use_fp8_dense=use_fp8_dense) - variables = cnn.init(rng, jnp.ones([32, 28, 28, 1])) - tx = optax.sgd(learning_rate, momentum) - return train_state.TrainState.create(apply_fn=cnn.apply, params=variables['params'], - tx=tx), variables - - -@partial(jax.jit, static_argnums=(3,)) -def train_step(state, others, batch, use_fp8_dense): - """Train for a single step.""" - - def loss_fn(collections): - logits = MLPNN(use_fp8_dense=use_fp8_dense).apply(collections, batch['image']) - loss = cross_entropy_loss(logits=logits, labels=batch['label']) - return loss, logits - - grad_fn = jax.value_and_grad(loss_fn, has_aux=True) - (_, logits), grads = grad_fn(others) - state = state.apply_gradients(grads=grads['params']) - metrics = compute_metrics(logits=logits, labels=batch['label']) - return state, metrics, grads - - -def train_epoch(state, variables, train_ds, batch_size, rng, use_fp8_dense): - """Train for a single epoch.""" - train_ds_size = len(train_ds['image']) - steps_per_epoch = train_ds_size // batch_size - perms = jax.random.permutation(rng, train_ds_size) - perms = perms[:steps_per_epoch * batch_size] # skip incomplete batch - perms = perms.reshape((steps_per_epoch, batch_size)) - batch_metrics = [] - for idx, perm in enumerate(perms): - idx = idx + 1 - batch = {k: v[perm, ...] 
for k, v in train_ds.items()} - state, metrics, grads = train_step(state, variables, batch, use_fp8_dense) - - updated_coll = {'params': state.params} - if use_fp8_dense: - updated_coll[FP8Helper.FP8_COLLECTION_NAME] \ - = grads[FP8Helper.FP8_COLLECTION_NAME] - variables = FP8Helper.update_collections(updated_coll, variables) - batch_metrics.append(metrics) - - if use_fp8_dense: - variables = FP8Helper.update_fp8_metas(variables) - - return state, variables - - -@partial(jax.jit, static_argnums=(2,)) -def eval_step(variables, batch, use_fp8_dense): - logits = MLPNN(use_fp8_dense=use_fp8_dense).apply(variables, batch['image']) - return compute_metrics(logits=logits, labels=batch['label']) - - -def eval_model(variables, test_ds, batch_size, use_fp8_dense): - test_ds_size = len(test_ds['image']) - steps_per_epoch = test_ds_size // batch_size - perms = np.arange(0, test_ds_size) - perms = perms[:steps_per_epoch * batch_size] # skip incomplete batch - perms = perms.reshape((steps_per_epoch, batch_size)) - total_summary = {'correct': 0, 'loss': 0, 'total': 0} - for _, perm in enumerate(perms): - batch = {k: v[perm, ...] for k, v in test_ds.items()} - metrics = eval_step(variables, batch, use_fp8_dense) - metrics = jax.device_get(metrics) - summary = jax.tree_map(lambda x: x.item(), metrics) - total_summary['correct'] += summary['accuracy'] * batch_size - total_summary['loss'] += summary['loss'] * batch_size - total_summary['total'] += batch_size - return total_summary['loss']/total_summary['total'], \ - total_summary['correct']/total_summary['total'] - - -class TestMnist(unittest.TestCase): - - def setUp(self) -> None: - os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false" - self.learning_rate = 0.1 - self.momentum = 0.9 - - self.num_epochs = 5 - self.batch_size = 512 - self.train_ds, self.test_ds = get_datasets() - - self.margin = 0.0 - self.num_fp8_layers = 3 - self.fp8_meta_update_interval = 1 - self.temp_file = tempfile.NamedTemporaryFile() # pylint: disable=consider-using-with - self.fp8_ckpt_path = self.temp_file.name - - self.seed = 0 - - acc_bfp16_ = self._mnist_baseline_runner() - acc_rtol = 0.005 - self.target_accuracy = acc_bfp16_ * (1. - acc_rtol) - - def tearDown(self): - self.temp_file.close() - - @unittest.skipIf(not is_fp8_supported(), reason='GPU capability is not enough to run FP8') - def test_mnist_e4m3(self): - self._mnist_test_runner(FP8Format.E4M3) - - @unittest.skipIf(not is_fp8_supported(), reason='GPU capability is not enough to run FP8') - def test_mnist_hybrid(self): - self._mnist_test_runner(FP8Format.HYBRID) - - # Skip for now due to lack bf16 in TE.Format - # def test_mnist_bfloa16(self): - # self._mnist_test_runner(FP8Format.BFLOAT16) - - def _mnist_baseline_runner(self): - rng = jax.random.PRNGKey(self.seed) - rng, init_rng = jax.random.split(rng) - - state, variables = create_train_state(init_rng, self.learning_rate, self.momentum, False) - del init_rng - - _, accuracy = self._train_model(state, variables, self.num_epochs, rng, False) - return accuracy - - def _mnist_test_runner(self, fp8_format): - FP8Helper.initialize(margin=self.margin, fp8_format=fp8_format) - - rng = jax.random.PRNGKey(self.seed) - rng, init_rng = jax.random.split(rng) - - state, variables = create_train_state(init_rng, self.learning_rate, self.momentum, True) - del init_rng - - _, test_accuracy = self._train_model(state, variables, self.num_epochs, rng, True) - - self.assertGreater( - test_accuracy, self.target_accuracy, - f"Convergence test failed on MNIST with FP8Fomat.{fp8_format.name}. 
" - f"Test Accuracy {test_accuracy:.4f} is lower than target {self.target_accuracy:.4f}") - - FP8Helper.finalize() - - def _train_model(self, state, variables, epochs, rng, use_fp8_dense): - max_test_acc = 0.0 - for _ in range(0, epochs): - rng, input_rng = jax.random.split(rng) - - state, variables = train_epoch(state, variables, self.train_ds, self.batch_size, - input_rng, use_fp8_dense) - - _, test_accuracy = eval_model(variables, self.test_ds, self.batch_size, use_fp8_dense) - max_test_acc = test_accuracy if test_accuracy > max_test_acc else max_test_acc - return state, max_test_acc - - -if __name__ == '__main__': - unittest.main() diff --git a/transformer_engine/jax/module.py b/transformer_engine/jax/module.py index 33029b049d..61dee42475 100644 --- a/transformer_engine/jax/module.py +++ b/transformer_engine/jax/module.py @@ -219,7 +219,7 @@ class LayerNorm(nn.Module): ----------------------- dtype : jax.numpy.dtype, default = jax.numpy.float32 the data type used to allocate the initial parameters. - transpose_batch_sequence : bool, default = True + transpose_batch_sequence : bool, default = False Indicate whether the input tensors were switched axis of batch and sequence length dimension. If set to True, the input tensors should be in (seqlen, batch, hidden), otherwise (batch, seqlen, hidden). @@ -233,7 +233,7 @@ class LayerNorm(nn.Module): bias_init: Initializer = nn.initializers.zeros bias_axes: Tuple[str, ...] = ('embed',) dtype: DType = jnp.float32 - transpose_batch_sequence: bool = True + transpose_batch_sequence: bool = False sharding_type: ShardingType = ShardingType.SINGLE @nn.compact @@ -358,12 +358,12 @@ class DenseGeneral(TransformerEngineBase): features: Union[Iterable[int], int] kernel_init: Initializer = None kernel_axes: Tuple[str, ...] = () - use_bias: bool = False + use_bias: bool = True bias_init: Initializer = nn.initializers.zeros bias_axes: Tuple[str, ...] = () axis: Union[Iterable[int], int] = -1 dtype: DType = jnp.float32 - transpose_batch_sequence: bool = True + transpose_batch_sequence: bool = False sharding_type: ShardingType = ShardingType.SINGLE def __post_init__(self): diff --git a/transformer_engine/jax/transformer.py b/transformer_engine/jax/transformer.py index 0a5dfce147..69b1325df0 100644 --- a/transformer_engine/jax/transformer.py +++ b/transformer_engine/jax/transformer.py @@ -720,7 +720,7 @@ class TransformerLayer(nn.Module): If set to True, `TransformerLayer` module exposes a single fused parameter for query-key-value for self-attention and key-value for cross-attention. - transpose_batch_sequence : bool, default = True + transpose_batch_sequence : bool, default = False Indicate whether the input tensors were switched axis of batch and sequence length dimension. if set to True, the input tensors should be in (seqlen, batch, hidden), otherwise (batch, seqlen, hidden). @@ -755,7 +755,7 @@ class TransformerLayer(nn.Module): dtype: DType = jnp.float32 drop_path: float = 0.0 fuse_qkv_params: bool = True - transpose_batch_sequence: bool = True + transpose_batch_sequence: bool = False scale_attn_logits: bool = False scaled_query_init: bool = True From bc0e44848fc83aa422f4777e377abc5cf8bc2474 Mon Sep 17 00:00:00 2001 From: Ming-Xu Huang Date: Wed, 29 Mar 2023 08:43:32 +0800 Subject: [PATCH 13/68] Fix Bugs of TE/JAX (#119) * Support transpose_bs when decoded=True Signed-off-by: Ming-Xu Huang Signed-off-by: Ming Huang * Fix Bugs, 1. Fix missing dropout_dims in LayerNormMLP. 2. Fix broadcast issues in decoded. 
Signed-off-by: Ming-Xu Huang Signed-off-by: Ming Huang * Fix wrong masks in decoded. Signed-off-by: Ming-Xu Huang Signed-off-by: Ming Huang * Fixed wrong assert condition in TransformerLayer Signed-off-by: Ming Huang * Fix amax is not set as 0 in each step. Signed-off-by: Ming Huang * Enhance rules conflict checking and docs. Signed-off-by: Ming Huang * fix code formatting. Signed-off-by: Ming Huang --------- Signed-off-by: Ming-Xu Huang Signed-off-by: Ming Huang --- tests/jax/test_sharding.py | 10 ++++-- tests/jax/utils.py | 3 +- transformer_engine/jax/fp8.py | 2 +- transformer_engine/jax/module.py | 8 +++-- transformer_engine/jax/transformer.py | 49 +++++++++++++++------------ 5 files changed, 43 insertions(+), 29 deletions(-) diff --git a/tests/jax/test_sharding.py b/tests/jax/test_sharding.py index e572d2162a..458e10ffac 100644 --- a/tests/jax/test_sharding.py +++ b/tests/jax/test_sharding.py @@ -38,9 +38,13 @@ def _get_sharding_resource(mesh_names, sharding_type): ((4,), ("tp",), ShardingType.TP_ROW), ((2, 2), ("dp", "tp"), ShardingType.DP_TP_COL), ((2, 2), ("dp", "tp"), ShardingType.DP_TP_ROW)] -LOGICAL_RULES = [[(('a1', None), ('a2', 'ma2')), False], - [(('a1', None), ('a2', 'ma2'), ('a3', ('ma31', 'ma32'))), True], - [(('a1', None), ('a2', 'ma2'), ('batch', 'batch_1200234')), True]] +LOGICAL_RULES = [ + [(('a1', None), ('a2', 'ma2')), False], + [(('a1', None), ('a2', 'ma2'), ('a3', ('ma31', 'ma32'))), True], + [(('a1', None), ('a2', 'ma2'), ('a3', 'ma31'), ('a3', 'ma32')), False], + [(('a1', None), ('a2', 'ma2'), ('batch', 'batch_1200234')), True], + [(('a1', None), ('a2', 'ma2'), ('a2', 'ma1'), ('batch', 'model'), ('batch', 'data')), True], +] SRS = [ ShardingResource(), ShardingResource('data', None), diff --git a/tests/jax/utils.py b/tests/jax/utils.py index c8a1b0e402..bbd0b1392f 100644 --- a/tests/jax/utils.py +++ b/tests/jax/utils.py @@ -321,8 +321,9 @@ def __call__(self, inputs, deterministic: bool = False): # Take elementwise product of above intermediate activations. x = functools.reduce(operator.mul, activations) + dropout_broadcast_dims = (0,) if self.transpose_batch_sequence else (1,) # Apply dropout and final dense output projection. - x = nn.Dropout(rate=self.intermediate_dropout_rate, broadcast_dims=(-2,))( + x = nn.Dropout(rate=self.intermediate_dropout_rate, broadcast_dims=dropout_broadcast_dims)( x, deterministic=deterministic) # Broadcast along length. if self.transpose_batch_sequence: x = nn_partitioning.with_sharding_constraint(x, ('length', 'batch', 'mlp')) diff --git a/transformer_engine/jax/fp8.py b/transformer_engine/jax/fp8.py index 106f8e310f..906f7d273b 100644 --- a/transformer_engine/jax/fp8.py +++ b/transformer_engine/jax/fp8.py @@ -190,7 +190,7 @@ def update_amax_history(amax_buffers: jnp.ndarray) -> jnp.ndarray: Update the amax history """ updated_amax_buffers = jnp.roll(amax_buffers, -1, 1) - updated_amax_buffers.at[:, 0].set(0) + updated_amax_buffers = updated_amax_buffers.at[:, 0].set(0) return updated_amax_buffers @staticmethod diff --git a/transformer_engine/jax/module.py b/transformer_engine/jax/module.py index 61dee42475..2cb0bfea0a 100644 --- a/transformer_engine/jax/module.py +++ b/transformer_engine/jax/module.py @@ -683,6 +683,8 @@ class LayerNormMLP(TransformerEngineBase): Each activation has its own transformation layer. intermediate_dropout_rate: float, default = 0.1 Dropout probability for the dropout op after the :attr:`activations`. 
+ intermediate_hidden_dropout_dims: Sequence[int], default = () + Dimensions that will share the same dropout mask for hidden axis: Union[Iterable[int], int], default = -1 An integer tuple with axes to apply the transformation on. @@ -716,6 +718,7 @@ class LayerNormMLP(TransformerEngineBase): return_layernorm_output: bool = True activations: Sequence[Union[str, Callable]] = ('relu',) intermediate_dropout_rate: float = 0.1 + intermediate_hidden_dropout_dims: Sequence[int] = () axis: Union[Iterable[int], int] = -1 dtype: DType = jnp.float32 transpose_batch_sequence: bool = True @@ -912,8 +915,9 @@ def fp8_meta_generator(): z = functools.reduce(operator.mul, activations) z = jnp.reshape(z, (*z.shape[:-2], -1)) - z = nn.Dropout(rate=self.intermediate_dropout_rate, broadcast_dims=(-2,))( - z, deterministic=deterministic) # Broadcast along length. + z = nn.Dropout(rate=self.intermediate_dropout_rate, + broadcast_dims=self.intermediate_hidden_dropout_dims)( + z, deterministic=deterministic) # DenseGeneral 2 hidden_size = inputs.shape[-1] diff --git a/transformer_engine/jax/transformer.py b/transformer_engine/jax/transformer.py index 69b1325df0..51ead9ceba 100644 --- a/transformer_engine/jax/transformer.py +++ b/transformer_engine/jax/transformer.py @@ -53,6 +53,10 @@ def extend_logical_axis_rules(rules: LogicalRules) -> LogicalRules: .. warning:: Please make sure ShardingResource is set via fp8_autocast before calling this function. + .. note:: + This function is only needed when using TransformerLayer. For other modules, such as + DenseGeneral, please properly set axes of kernels and bias. + Parameters ---------- rules : Sequence[Tuple[str, Union[str, None]]] @@ -73,10 +77,12 @@ def extend_logical_axis_rules(rules: LogicalRules) -> LogicalRules: f"Thie axis_name should be str, but got {type(key)}." assert isinstance(val, str) or (val is None), \ f"Thie mesh_axis_name should be str or None, but got {type(val)}." - rules_map[key] = val + if key in rules_map: + rules_map[key].append(val) + else: + rules_map[key] = [val] gsr = global_shard_resource() - te_logical_axis_rules = (('batch', gsr.dp_resource), ('embed', None), ('mlp', gsr.tp_resource), ('heads', gsr.tp_resource), ('kv', None), ('qkv_dim', None), ('kv_dim', None), ('joined_kv', gsr.tp_resource), ('act', None), @@ -87,7 +93,7 @@ def extend_logical_axis_rules(rules: LogicalRules) -> LogicalRules: key = item[0] val = item[1] if key in rules_map: - assert rules_map[key] == val, \ + assert len(rules_map[key]) == 1 and rules_map[key][0] == val, \ f"The rule diverged between TE and given rule." \ f"Axis:{key} map to {rules_map[key]} in the given" \ f" rules, but {val} in TE's rules." 
@@ -447,21 +453,22 @@ def kv_init(key, shape, dtype): if decode: is_initialized = self.has_variable('cache', 'cached_key') - # TODO (Ming Huang): Check performance on GPU withou swap dimensions # pylint: disable=fixme - def swap_dims(x): - return x[:-3] + tuple(x[i] for i in [-2, -1, -3]) - - cached_key = self.variable('cache', 'cached_key', jnp.zeros, swap_dims(key.shape), - key.dtype) - cached_value = self.variable('cache', 'cached_value', jnp.zeros, swap_dims(value.shape), + cached_key = self.variable('cache', 'cached_key', jnp.zeros, key.shape, key.dtype) + cached_value = self.variable('cache', 'cached_value', jnp.zeros, value.shape, value.dtype) cache_index = self.variable('cache', 'cache_index', lambda: jnp.array(0, dtype=jnp.int32)) if is_initialized: - batch, num_heads, head_dim, length = cached_key.value.shape + if self.transpose_batch_sequence: + length, batch, num_heads, head_dim = cached_key.value.shape + expected_shape = (1, batch, num_heads, head_dim) + one_hot_indices_shape = (length, 1, 1, 1) + else: + batch, length, num_heads, head_dim = cached_key.value.shape + expected_shape = (batch, 1, num_heads, head_dim) + one_hot_indices_shape = (1, length, 1, 1) # Sanity shape check of cached key against input query. - expected_shape = (batch, 1, num_heads, head_dim) if expected_shape != query.shape: raise ValueError( 'Autoregressive cache shape error, ' @@ -469,19 +476,15 @@ def swap_dims(x): cur_index = cache_index.value one_hot_indices = jax_nn.one_hot(cur_index, length, dtype=key.dtype) - one_token_key = jnp.moveaxis(key, -3, -1) - one_token_value = jnp.moveaxis(value, -3, -1) - key = cached_key.value + one_token_key * one_hot_indices - value = cached_value.value + one_token_value * one_hot_indices + one_hot_indices = jnp.reshape(one_hot_indices, one_hot_indices_shape) + key = cached_key.value + key * one_hot_indices + value = cached_value.value + value * one_hot_indices cached_key.value = key cached_value.value = value cache_index.value = cache_index.value + 1 - key = jnp.moveaxis(key, -1, -3) - value = jnp.moveaxis(value, -1, -3) - mask = combine_masks( - mask, jnp.broadcast_to(jnp.arange(length) <= cur_index, (batch, 1, 1, length))) + mask, jnp.broadcast_to(jnp.arange(length) > cur_index, (batch, 1, 1, length))) if bias is not None: bias = dynamic_vector_slice_in_dim(jnp.squeeze(bias, axis=0), @@ -889,10 +892,11 @@ def hidden_dropout(x, deterministic): assert isinstance(self.hidden_dropout_dims, Sequence) x_shape_len = len(x.shape) for dims in self.hidden_dropout_dims: - assert -x_shape_len < dims < x_shape_len + assert -x_shape_len <= dims < x_shape_len return nn.Dropout(rate=self.hidden_dropout, - broadcast_dims=self.hidden_dropout_dims)(x, deterministic) + broadcast_dims=self.hidden_dropout_dims)(x, + deterministic=deterministic) x = hidden_dropout(x, deterministic) if self.drop_path > 0.0: @@ -944,6 +948,7 @@ def hidden_dropout(x, deterministic): intermediate_dim=self.mlp_hidden_size, activations=self.mlp_activations, intermediate_dropout_rate=self.hidden_dropout, + intermediate_hidden_dropout_dims=self.hidden_dropout_dims, dtype=self.dtype, scale_axes=('embed',), kernel_init=self.mlp_kernel_init, From 78c375d297970ab561f351faac2423fe9bdcb00a Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Wed, 29 Mar 2023 17:38:25 -0700 Subject: [PATCH 14/68] Change FP8 recipe defaults (#112) * Change FP8 recipe defaults Signed-off-by: Kirthi Shankar Sivamani * Increase default amax history length Signed-off-by: Kirthi Shankar Sivamani * Always check history size 
Signed-off-by: Kirthi Shankar Sivamani * no amax history for onnx export Signed-off-by: Kirthi Shankar Sivamani * revert onnx export test changes Signed-off-by: Kirthi Shankar Sivamani * Fix indices in onnx test Co-authored-by: Neta Zmora Signed-off-by: Kirthi Shankar Sivamani --------- Signed-off-by: Kirthi Shankar Sivamani Co-authored-by: Neta Zmora --- tests/pytorch/test_onnx_export.py | 17 ++++++++++------- transformer_engine/common/recipe.py | 8 ++++---- transformer_engine/pytorch/module.py | 25 +++++++++++++++++++------ 3 files changed, 33 insertions(+), 17 deletions(-) diff --git a/tests/pytorch/test_onnx_export.py b/tests/pytorch/test_onnx_export.py index e72d1cae59..40486057f4 100644 --- a/tests/pytorch/test_onnx_export.py +++ b/tests/pytorch/test_onnx_export.py @@ -92,12 +92,15 @@ def to_numpy(tensor): return tensor.cpu().numpy() -def set_layer_scale(module: torch.nn.Module, scale: float): - module.fp8_init() +def set_layer_scale(module: torch.nn.Module, scale: float, num_gemms: int): + """Initialize the FP8 quantization scales in module""" + NB_SCALES_PER_GEMM = 3 # One scale per: input, weights, and output GEMM tensors. + nb_total_scales = num_gemms * NB_SCALES_PER_GEMM + module.fp8_init(num_gemms) module.fp8_meta["scaling_fwd"].scale = torch.ones( - 2, dtype=torch.float32, device="cuda") / scale + nb_total_scales, dtype=torch.float32, device="cuda") / scale module.fp8_meta["scaling_fwd"].scale_inv = torch.ones( - 2, dtype=torch.float32, device="cuda") * scale + nb_total_scales, dtype=torch.float32, device="cuda") * scale def te_infer(model: torch.nn.Module, inps: Union[Tuple[torch.tensor], torch.tensor], is_fp8: bool): @@ -649,7 +652,7 @@ def forward(self, inp): precision ).to(device='cuda') if use_fp8: - set_layer_scale(model.linear, scale_factor) + set_layer_scale(model.linear, scale_factor, num_gemms=1) do_export(model, inp, fname, use_fp8) if precision in (torch.bfloat16, ): @@ -707,7 +710,7 @@ def test_export_layernorm_linear( zero_centered_gamma=zero_centered_gamma, ).to(device='cuda') if use_fp8: - set_layer_scale(model, scale_factor) + set_layer_scale(model, scale_factor, num_gemms=1) do_export(model, inp, fname, use_fp8) if not use_fp8: validate_result(fname, inp, model, atol=1e-3) @@ -763,7 +766,7 @@ def test_export_layernorm_mlp( zero_centered_gamma=zero_centered_gamma, ).to(device='cuda') if use_fp8: - set_layer_scale(model, scale_factor) + set_layer_scale(model, scale_factor, num_gemms=2) do_export(model, inp, fname, use_fp8) if not use_fp8: validate_result(fname, inp, model, atol=1e-3) diff --git a/transformer_engine/common/recipe.py b/transformer_engine/common/recipe.py index 583b47d80c..3bb5320475 100644 --- a/transformer_engine/common/recipe.py +++ b/transformer_engine/common/recipe.py @@ -66,10 +66,10 @@ class DelayedScaling: fp8_format : {Format.E4M3, Format.HYBRID}, default = Format.HYBRID Controls the FP8 data format used during forward and backward pass. - amax_history_len : int, default = 1 + amax_history_len : int, default = 1024 The length of the amax history window used for scaling factor computation. - amax_compute_algo : {'max', 'most_recent', Callable}, default = 'most_recent' + amax_compute_algo : {'max', 'most_recent', Callable}, default = 'max' Algorithm used for choosing the `amax` value for the scaling factor computation. 
There are 2 predefined choices: `max` chooses the largest `amax` in the history @@ -125,8 +125,8 @@ def scaling_factor_compute(amax: Tensor, margin: int = 0 interval: int = 1 fp8_format: Format = Format.HYBRID - amax_history_len: int = 1 - amax_compute_algo: Union[Literal["max", "most_recent"], Callable] = "most_recent" + amax_history_len: int = 1024 + amax_compute_algo: Union[Literal["max", "most_recent"], Callable] = "max" override_linear_precision: _OverrideLinearPrecision = _OverrideLinearPrecision() scaling_factor_compute_algo: Optional[Callable] = None reduce_amax: bool = True diff --git a/transformer_engine/pytorch/module.py b/transformer_engine/pytorch/module.py index 4e012be58c..516081c7b2 100644 --- a/transformer_engine/pytorch/module.py +++ b/transformer_engine/pytorch/module.py @@ -13,6 +13,7 @@ import numpy as np import torch +import torch.nn.functional as F from torch.nn.parameter import Parameter from torch.nn import init @@ -187,6 +188,23 @@ def __init__(self) -> None: def set_meta_tensor(self, fwd: bool) -> None: """Init scales and amaxes for fwd | bwd.""" fp8_meta_tensor_key = "scaling_fwd" if fwd else "scaling_bwd" + + if self.fp8_meta_tensors_initialized: + # Handle changed amax history size. + curr_len = self.fp8_meta[fp8_meta_tensor_key].amax_history.shape[0] + need_len = self.fp8_meta["recipe"].amax_history_len + if need_len < curr_len: + self.fp8_meta[fp8_meta_tensor_key].amax_history = ( + self.fp8_meta[fp8_meta_tensor_key] + .amax_history[: self.fp8_meta["recipe"].amax_history_len].clone() + ) + elif need_len > curr_len: + extra_rows = need_len - curr_len + self.fp8_meta[fp8_meta_tensor_key].amax_history = F.pad( + self.fp8_meta[fp8_meta_tensor_key].amax_history, pad=(0, 0, 0, extra_rows) + ) + return + # Max. number of fp8 tensors per GEMM = 3 (input, weight, output) for fwd and # 2 (grad_output and grad_input) for bwd num_fp8_tensors = ( @@ -222,12 +240,9 @@ def set_meta_tensor(self, fwd: bool) -> None: def init_fp8_meta_tensors(self) -> None: """Init scales and amaxes.""" - # Checkpoint loaded - if self.fp8_meta_tensors_initialized: - return - self.set_meta_tensor(True) self.set_meta_tensor(False) + self.fp8_meta_tensors_initialized = True def get_extra_state(self) -> torch.Tensor: """Save before checkpointing.""" @@ -280,7 +295,6 @@ def set_extra_state(self, state: torch.Tensor) -> None: self.fp8_meta["scaling_fwd"].amax_history.copy_(amax_history_fwd) self.fp8_meta["scaling_bwd"].scale.copy_(scale_bwd) self.fp8_meta["scaling_bwd"].amax_history.copy_(amax_history_bwd) - self.fp8_meta_tensors_initialized = True # Restore global FP8 buffer state. 
set_global_fp8_buffer(state[4]) @@ -310,7 +324,6 @@ def set_extra_state(self, state: torch.Tensor) -> None: self.fp8_meta["scaling_fwd"].amax_history.copy_(state["amax_history_fwd"]) self.fp8_meta["scaling_bwd"].scale.copy_(state["scale_bwd"]) self.fp8_meta["scaling_bwd"].amax_history.copy_(state["amax_history_bwd"]) - self.fp8_meta_tensors_initialized = True def set_activation_dtype(self, inp: torch.Tensor) -> None: """Get activation data type for AMP.""" From a7537155847907d2a27b330d94d21e526ddbf20e Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Tue, 4 Apr 2023 09:56:44 -0700 Subject: [PATCH 15/68] Add FP8 support for Ada (#129) * Add FP8 support for Ada Signed-off-by: Kirthi Shankar Sivamani * Fixes Signed-off-by: Kirthi Shankar Sivamani * better message Signed-off-by: Kirthi Shankar Sivamani * lint fixes Signed-off-by: Kirthi Shankar Sivamani * Address review comments Signed-off-by: Kirthi Shankar Sivamani * better message for no fp8 Signed-off-by: Kirthi Shankar Sivamani * same thing for onnx test Signed-off-by: Kirthi Shankar Sivamani * fix Signed-off-by: Kirthi Shankar Sivamani * Fix CI and review Signed-off-by: Kirthi Shankar Sivamani --------- Signed-off-by: Kirthi Shankar Sivamani --- tests/pytorch/test_onnx_export.py | 35 +++++++++---------- tests/pytorch/test_sanity.py | 24 ++++++------- transformer_engine/CMakeLists.txt | 2 +- transformer_engine/pytorch/csrc/common.h | 1 + transformer_engine/pytorch/csrc/extensions.cu | 8 +++++ transformer_engine/pytorch/fp8.py | 31 ++++++++++++++-- 6 files changed, 67 insertions(+), 34 deletions(-) diff --git a/tests/pytorch/test_onnx_export.py b/tests/pytorch/test_onnx_export.py index 40486057f4..9f2308f5e4 100644 --- a/tests/pytorch/test_onnx_export.py +++ b/tests/pytorch/test_onnx_export.py @@ -31,6 +31,7 @@ import transformer_engine.pytorch.cpp_extensions as texcpp import transformer_engine.pytorch.softmax as softmax_defs from transformer_engine.pytorch.utils import get_default_init_method +from transformer_engine.pytorch.fp8 import is_fp8_available # Directory where generated ONNX test models are stored. @@ -46,10 +47,8 @@ OPSET = 15 assert OPSET >= TRILU_OPSET -skip_FP8 = pytest.mark.skipif( - torch.cuda.get_device_properties(torch.cuda.current_device()).major < 9, - reason="Device compute capability 9.x required for FP8 execution.", -) +fp8_available, reason_for_no_fp8 = is_fp8_available() +skip_FP8 = pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) def create_fp8_recipe(): return recipe.DelayedScaling(margin=0, interval=1, fp8_format=recipe.Format.E4M3) @@ -346,8 +345,8 @@ def test_export_gemm( scale_factors ): # Skip FP8 tests on non-hopper devices - if use_fp8 and torch.cuda.get_device_properties(torch.cuda.current_device()).major < 9: - pytest.skip("Device compute capability 9.x required for FP8 execution.") + if use_fp8 and not fp8_available: + pytest.skip(reason_for_no_fp8) class TestFP8_GEMM(nn.Module): def __init__(self, precision, use_bias, gelu, scale_factors): @@ -467,8 +466,8 @@ def test_export_layernorm( zero_centered_gamma: bool ): # Skip FP8 tests on non-hopper devices - if use_fp8 and torch.cuda.get_device_properties(torch.cuda.current_device()).major < 9: - pytest.skip("Device compute capability 9.x required for FP8 execution.") + if use_fp8 and not fp8_available: + pytest.skip(reason_for_no_fp8) # Set dimensions (these are arbitrary). 
inp_shape = [64, 32] @@ -608,8 +607,8 @@ def test_export_linear( precision: torch.dtype ): # Skip FP8 tests on non-hopper devices - if use_fp8 and torch.cuda.get_device_properties(torch.cuda.current_device()).major < 9: - pytest.skip("Device compute capability 9.x required for FP8 execution.") + if use_fp8 and not fp8_available: + pytest.skip(reason_for_no_fp8) # Set dimensions (these are arbitrary). in_features = 64 @@ -686,8 +685,8 @@ def test_export_layernorm_linear( zero_centered_gamma: bool ): # Skip FP8 tests on non-hopper devices - if use_fp8 and torch.cuda.get_device_properties(torch.cuda.current_device()).major < 9: - pytest.skip("Device compute capability 9.x required for FP8 execution.") + if use_fp8 and not fp8_available: + pytest.skip(reason_for_no_fp8) # Set dimensions (these are arbitrary). in_features = 64 @@ -741,8 +740,8 @@ def test_export_layernorm_mlp( zero_centered_gamma: bool ): # Skip FP8 tests on non-hopper devices - if use_fp8 and torch.cuda.get_device_properties(torch.cuda.current_device()).major < 9: - pytest.skip("Device compute capability 9.x required for FP8 execution.") + if use_fp8 and not fp8_available: + pytest.skip(reason_for_no_fp8) # Set dimensions (these are arbitrary). in_features = 64 @@ -861,8 +860,8 @@ def test_export_multihead_attention( fuse_qkv_params: bool ): # Skip FP8 tests on non-hopper devices - if use_fp8 and torch.cuda.get_device_properties(torch.cuda.current_device()).major < 9: - pytest.skip("Device compute capability 9.x required for FP8 execution.") + if use_fp8 and not fp8_available: + pytest.skip(reason_for_no_fp8) hidden_size = 256 sequence_length = 128 @@ -938,8 +937,8 @@ def test_export_transformer_layer( zero_centered_gamma: bool ): # Skip FP8 tests on non-hopper devices - if use_fp8 and torch.cuda.get_device_properties(torch.cuda.current_device()).major < 9: - pytest.skip("Device compute capability 9.x required for FP8 execution.") + if use_fp8 and not fp8_available: + pytest.skip(reason_for_no_fp8) # Layer configuration hidden_size = 64 diff --git a/tests/pytorch/test_sanity.py b/tests/pytorch/test_sanity.py index 3ff0f66bc9..3af50f59c3 100644 --- a/tests/pytorch/test_sanity.py +++ b/tests/pytorch/test_sanity.py @@ -5,7 +5,7 @@ import torch import pytest -from transformer_engine.pytorch.fp8 import fp8_autocast +from transformer_engine.pytorch.fp8 import fp8_autocast, is_fp8_available from transformer_engine.pytorch.utils import ( init_method_normal, scaled_init_method_normal, @@ -19,7 +19,7 @@ from transformer_engine.common import recipe # Only run FP8 tests on H100. 
-fp8_available = torch.cuda.get_device_properties(torch.cuda.current_device()).major >= 9 +fp8_available, reason_for_no_fp8 = is_fp8_available() def custom_amax_to_scale( @@ -263,7 +263,7 @@ def _test_sanity_common(block, bs, dtype, config, fp8_recipe, skip_wgrad): @pytest.mark.parametrize("zero_centered_gamma", all_boolean) def test_sanity_layernorm_linear(dtype, bs, fp8_recipe, model, skip_wgrad, zero_centered_gamma): if fp8_recipe is not None and not fp8_available: - pytest.skip("FP8 device not available.") + pytest.skip(reason_for_no_fp8) config = model_configs[model] @@ -291,7 +291,7 @@ def test_sanity_layernorm_linear(dtype, bs, fp8_recipe, model, skip_wgrad, zero_ @pytest.mark.parametrize("skip_wgrad", all_boolean) def test_sanity_linear(dtype, bs, fp8_recipe, model, skip_wgrad): if fp8_recipe is not None and not fp8_available: - pytest.skip("FP8 device not available.") + pytest.skip(reason_for_no_fp8) config = model_configs[model] @@ -316,7 +316,7 @@ def test_sanity_linear(dtype, bs, fp8_recipe, model, skip_wgrad): @pytest.mark.parametrize("zero_centered_gamma", all_boolean) def test_sanity_layernorm_mlp(dtype, bs, fp8_recipe, model, skip_wgrad, zero_centered_gamma): if fp8_recipe is not None and not fp8_available: - pytest.skip("FP8 device not available.") + pytest.skip(reason_for_no_fp8) config = model_configs[model] @@ -347,7 +347,7 @@ def test_sanity_layernorm_mlp(dtype, bs, fp8_recipe, model, skip_wgrad, zero_cen @pytest.mark.parametrize("zero_centered_gamma", all_boolean) def test_sanity_gpt(dtype, bs, fp8_recipe, model, skip_wgrad, zero_centered_gamma): if fp8_recipe is not None and not fp8_available: - pytest.skip("FP8 device not available.") + pytest.skip(reason_for_no_fp8) config = model_configs[model] @@ -385,7 +385,7 @@ def test_sanity_gpt(dtype, bs, fp8_recipe, model, skip_wgrad, zero_centered_gamm @pytest.mark.parametrize("zero_centered_gamma", all_boolean) def test_sanity_bert(dtype, bs, fp8_recipe, model, skip_wgrad, zero_centered_gamma): if fp8_recipe is not None and not fp8_available: - pytest.skip("FP8 device not available.") + pytest.skip(reason_for_no_fp8) config = model_configs[model] @@ -423,7 +423,7 @@ def test_sanity_bert(dtype, bs, fp8_recipe, model, skip_wgrad, zero_centered_gam @pytest.mark.parametrize("zero_centered_gamma", all_boolean) def test_sanity_T5(dtype, bs, fp8_recipe, model, skip_wgrad, zero_centered_gamma): if fp8_recipe is not None and not fp8_available: - pytest.skip("FP8 device not available.") + pytest.skip(reason_for_no_fp8) config = model_configs[model] @@ -461,7 +461,7 @@ def test_sanity_T5(dtype, bs, fp8_recipe, model, skip_wgrad, zero_centered_gamma @pytest.mark.parametrize("skip_wgrad", all_boolean) def test_sanity_amp_and_nvfuser(dtype, bs, fp8_recipe, model, skip_wgrad): if fp8_recipe is not None and not fp8_available: - pytest.skip("FP8 device not available.") + pytest.skip(reason_for_no_fp8) config = model_configs[model] @@ -495,7 +495,7 @@ def test_sanity_amp_and_nvfuser(dtype, bs, fp8_recipe, model, skip_wgrad): @pytest.mark.parametrize("skip_wgrad", all_boolean) def test_sanity_drop_path(dtype, bs, fp8_recipe, model, skip_wgrad): if fp8_recipe is not None and not fp8_available: - pytest.skip("FP8 device not available.") + pytest.skip(reason_for_no_fp8) config = model_configs[model] @@ -532,7 +532,7 @@ def test_sanity_drop_path(dtype, bs, fp8_recipe, model, skip_wgrad): @pytest.mark.parametrize("skip_wgrad", all_boolean) def test_sanity_fused_qkv_params(dtype, bs, fp8_recipe, model, skip_wgrad): if fp8_recipe is not None and 
not fp8_available: - pytest.skip("FP8 device not available.") + pytest.skip(reason_for_no_fp8) config = model_configs[model] @@ -570,7 +570,7 @@ def test_sanity_fused_qkv_params(dtype, bs, fp8_recipe, model, skip_wgrad): @pytest.mark.parametrize("zero_centered_gamma", all_boolean) def test_sanity_gradient_accumulation_fusion(dtype, bs, fp8_recipe, model, skip_wgrad, zero_centered_gamma): if fp8_recipe is not None and not fp8_available: - pytest.skip("FP8 device not available.") + pytest.skip(reason_for_no_fp8) config = model_configs[model] diff --git a/transformer_engine/CMakeLists.txt b/transformer_engine/CMakeLists.txt index d3ee61ac66..c6977e5ece 100644 --- a/transformer_engine/CMakeLists.txt +++ b/transformer_engine/CMakeLists.txt @@ -5,7 +5,7 @@ cmake_minimum_required(VERSION 3.18) if(NOT DEFINED CMAKE_CUDA_ARCHITECTURES) - set(CMAKE_CUDA_ARCHITECTURES 70 80 90) + set(CMAKE_CUDA_ARCHITECTURES 70 80 89 90) endif() diff --git a/transformer_engine/pytorch/csrc/common.h b/transformer_engine/pytorch/csrc/common.h index 67b47dcdcc..f6c9898601 100644 --- a/transformer_engine/pytorch/csrc/common.h +++ b/transformer_engine/pytorch/csrc/common.h @@ -23,6 +23,7 @@ #include #include #include +#include #include #include #include diff --git a/transformer_engine/pytorch/csrc/extensions.cu b/transformer_engine/pytorch/csrc/extensions.cu index 47f1eb465e..ec99ad403f 100644 --- a/transformer_engine/pytorch/csrc/extensions.cu +++ b/transformer_engine/pytorch/csrc/extensions.cu @@ -830,6 +830,11 @@ at::Tensor scaled_upper_triang_masked_softmax_backward(at::Tensor output_grads_, } +size_t get_cublasLt_version() { + return cublasLtGetVersion(); +} + + PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { // Softmax functions m.def("scaled_softmax_forward", &scaled_softmax_forward, "Scaled Softmax FWD"); @@ -862,6 +867,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("fp8_transpose", &fp8_transpose, "Transpose with FP8 I/O"); m.def("fp8_gelu", &fp8_gelu, "GeLU with FP8 output"); + // Misc + m.def("get_cublasLt_version", &get_cublasLt_version, "Get cublasLt version"); + // Data structures py::class_(m, "FP8TensorMeta") .def(py::init<>()) diff --git a/transformer_engine/pytorch/fp8.py b/transformer_engine/pytorch/fp8.py index 98d35df363..ed9e10ae0d 100644 --- a/transformer_engine/pytorch/fp8.py +++ b/transformer_engine/pytorch/fp8.py @@ -12,6 +12,7 @@ from transformer_engine.common.recipe import DelayedScaling, Format from .constants import dist_group_type +from .utils import get_device_compute_capability _FP8_ENABLED = False _FP8_CALIBRATION = False @@ -26,6 +27,29 @@ _amax_forward_global_reduce_func = None _buffer_delete_key_fwd = None _buffer_delete_key_bwd = None +_is_fp8_available = None +_reason_for_no_fp8 = "" + + +def _check_fp8_support() -> Tuple[bool, str]: + """Return if fp8 support is available""" + if get_device_compute_capability() >= 9.0: # hopper and above + return True, "" + if get_device_compute_capability() < 8.9: # pre-ada + return False, "Device compute capability 8.9 or higher required for FP8 execution." + if tex.get_cublasLt_version() < 120103: + return False, "CublasLt version 12.1.3.x or higher required for FP8 execution on Ada." + if float(torch.version.cuda) < 12.1: + return False, "Cuda version 12.1 or higher required for FP8 execution on Ada." 
+ return True, "" + + +def is_fp8_available() -> Tuple[bool, str]: + """Return if fp8 support is available""" + global _is_fp8_available, _reason_for_no_fp8 + if _is_fp8_available is None: + _is_fp8_available, _reason_for_no_fp8 = _check_fp8_support() + return _is_fp8_available, _reason_for_no_fp8 def get_meta_tensor_key(forward: bool = True) -> str: @@ -253,9 +277,8 @@ def fp8_autocast( _FP8_AUTOCAST_DEPTH += 1 if enabled: - assert ( - torch.cuda.get_device_properties(torch.cuda.current_device()).major >= 9 - ), "Device compute capability 9.x required for FP8 execution." + fp8_available, reason_for_no_fp8 = is_fp8_available() + assert fp8_available, reason_for_no_fp8 yield finally: _FP8_ENABLED,_FP8_CALIBRATION, _FP8_RECIPE, _FP8_DISTRIBUTED_GROUP = fp8_state @@ -290,10 +313,12 @@ def is_fp8_enabled() -> bool: """Is FP8 enabled""" return _FP8_ENABLED + def is_fp8_calibration() -> bool: """Is FP8 calibration""" return _FP8_CALIBRATION + def is_first_fp8_module(): """Returns `True` only the first time when called multiple times from within the same `fp8_autocast` context. From 770e968b073c4712f03bcc1a84eb564bf7067997 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fr=C3=A9d=C3=A9ric=20Bastien?= Date: Wed, 5 Apr 2023 11:30:14 -0400 Subject: [PATCH 16/68] Update installation instruction for JAX and add some dependencies. (#117) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Update installation instructio for JAX and add some depenencies. Signed-off-by: Frederic Bastien * Bring back support for none pip installed pybind11. Signed-off-by: Frederic Bastien * Apply suggestions from code review Co-authored-by: Kirthi Shankar Sivamani Signed-off-by: Frédéric Bastien * Changes following review. Signed-off-by: Frederic Bastien * Change order to make it more clear. Signed-off-by: Frederic Bastien * Add other reviers suggestion. Signed-off-by: Frederic Bastien * pybind11 is needed for all FW. Signed-off-by: Frederic Bastien * Add flax as a dep Signed-off-by: Frederic Bastien * Update README.rst Co-authored-by: Kirthi Shankar Sivamani Signed-off-by: Frédéric Bastien --------- Signed-off-by: Frederic Bastien Signed-off-by: Frédéric Bastien Co-authored-by: Kirthi Shankar Sivamani --- README.rst | 26 ++++++++++++++++++++++---- docs/installation.rst | 8 +++++--- setup.py | 27 ++++++++++++++++++++++++--- 3 files changed, 51 insertions(+), 10 deletions(-) diff --git a/README.rst b/README.rst index 8a042194a0..8bccb56912 100644 --- a/README.rst +++ b/README.rst @@ -131,13 +131,31 @@ Transformer Engine comes preinstalled in the pyTorch container on From source ^^^^^^^^^^^ -Clone the repository and inside it type: +For JAX, pybind11 must be installed: .. code-block:: bash - NVTE_FRAMEWORK=all pip install . # Building with all frameworks. - NVTE_FRAMEWORK=pytorch pip install . # Building with pyTorch only. - NVTE_FRAMEWORK=jax pip install . # Building with JAX only. + pip install pybind11 + +Then, you can install this optional dependency: + +.. code-block:: bash + + pip install ninja + +Install TE (optionally specifying the framework): + +.. code-block:: bash + + git clone https://github.com/NVIDIA/TransformerEngine.git + cd TransformerEngine + + # Execute one of the following command + NVTE_FRAMEWORK=all pip install . # Build TE for all supported frameworks. + NVTE_FRAMEWORK=pytorch pip install . # Build TE for PyTorch only. + NVTE_FRAMEWORK=jax pip install . # Build TE for JAX only. + +If the framework is not explicitly specified, TE will be built for PyTorch only. 
User Guide ---------- diff --git a/docs/installation.rst b/docs/installation.rst index 263d3ed760..0c12b6b79e 100644 --- a/docs/installation.rst +++ b/docs/installation.rst @@ -29,9 +29,11 @@ pip - from GitHub Additional Prerequisites ^^^^^^^^^^^^^^^^^^^^^^^^ -1. `CMake `__ version 3.18 or later -2. `pyTorch `__ with GPU support -3. `Ninja `__ +1. `CMake `__ version 3.18 or later. +2. [For pyTorch support] `pyTorch `__ with GPU support. +3. [For JAX support] `JAX `__ with GPU support, version >= 0.4.7. +4. `pybind11`: `pip install pybind11`. +5. [Optional] `Ninja `__: `pip install ninja`. Installation (stable release) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ diff --git a/setup.py b/setup.py index 4b45cfd7de..46c5106794 100644 --- a/setup.py +++ b/setup.py @@ -161,11 +161,16 @@ def install_requires(): class JaxBuilder(FrameworkBuilderBase): def cmake_flags(self): - return ["-DENABLE_JAX=ON"] + p = [d for d in sys.path if 'dist-packages' in d][0] + return ["-DENABLE_JAX=ON", "-DCMAKE_PREFIX_PATH="+p] def run(self, extensions): print("Building jax extensions!") + def install_requires(): + # TODO: find a way to install pybind11 and ninja directly. + return ['cmake', 'flax'] + ext_modules = [] dlfw_builder_funcs = [] @@ -195,8 +200,13 @@ def run(self, extensions): if framework in ("all", "jax"): dlfw_builder_funcs.append(JaxBuilder) + # Trigger a better error when pybind11 isn't present. + # Sadly, if pybind11 was installed with `apt -y install pybind11-dev` + # This doesn't install a python packages. So the line bellow is too strict. + # When it fail, we need to detect if cmake will find pybind11. + # import pybind11 -dlfw_install_requires = [] +dlfw_install_requires = ['pydantic'] for builder in dlfw_builder_funcs: dlfw_install_requires = dlfw_install_requires + builder.install_requires() @@ -257,10 +267,16 @@ def build_extensions(self) -> None: build_dir = os.path.abspath(build_dir) cmake_args = [ - "-GNinja", "-DCMAKE_BUILD_TYPE=" + config, "-DCMAKE_LIBRARY_OUTPUT_DIRECTORY_{}={}".format(config.upper(), build_dir), ] + try: + import ninja + except ImportError: + pass + else: + cmake_args.append("-GNinja") + cmake_args = cmake_args + self.dlfw_flags cmake_build_args = ["--config", config] @@ -384,5 +400,10 @@ def get_outputs(self): ext_modules=ext_modules, cmdclass={"build_ext": TEBuildExtension}, install_requires=dlfw_install_requires, + extras_require={ + 'test': ['pytest', + 'tensorflow_datasets'], + 'test_pytest': ['onnxruntime',], + }, license_files=("LICENSE",), ) From ee87982096355b860beacd1eae7057715b51e989 Mon Sep 17 00:00:00 2001 From: Sangkug Lym Date: Tue, 18 Apr 2023 09:07:32 -0700 Subject: [PATCH 17/68] Amax reduction interval (#154) * amax reduction internval Signed-off-by: Sangkug Lym Skip TP-domain only AMAX reduction when TP-group is not initialized Signed-off-by: Sangkug Lym * Update transformer_engine/pytorch/fp8.py Co-authored-by: Kirthi Shankar Sivamani Signed-off-by: Sangkug Lym * check TP group initialized Signed-off-by: Sangkug Lym fix Signed-off-by: Sangkug Lym --------- Signed-off-by: Sangkug Lym Co-authored-by: Kirthi Shankar Sivamani --- transformer_engine/pytorch/fp8.py | 33 ++++++++++++++++++- transformer_engine/pytorch/module.py | 48 ++++++++++++++++++++++++---- 2 files changed, 73 insertions(+), 8 deletions(-) diff --git a/transformer_engine/pytorch/fp8.py b/transformer_engine/pytorch/fp8.py index 4304c8cd8f..07cad012ec 100644 --- a/transformer_engine/pytorch/fp8.py +++ b/transformer_engine/pytorch/fp8.py @@ -3,6 +3,7 @@ # See LICENSE for license information. 
"""FP8 utilities for TransformerEngine""" +import os from contextlib import contextmanager from collections import deque from typing import Callable, List, Optional, Dict, Any, Tuple, Union @@ -30,6 +31,9 @@ _amax_reduce_handle_fwd = None _is_fp8_available = None _reason_for_no_fp8 = "" +_dp_amax_reduce_interval = None +_dp_amax_reduce_forward_idx = 0 +_dp_amax_reduce_backward_idx = 0 def _check_fp8_support() -> Tuple[bool, str]: @@ -545,6 +549,8 @@ def reduce_tensor_across_group_op_max( def global_amax_reduction( fp8_meta: Dict[str, Any], + tp_group: dist_group_type, + tp_size: int, forward: bool = True, ) -> None: """Concatenate, reduce, and split amaxes in the global buffer.""" @@ -555,12 +561,37 @@ def global_amax_reduction( if amax_buffer_key not in _global_fp8_buffer: return None + # Reduce AMAX in DP-domain at an interval. + global _dp_amax_reduce_interval, _dp_amax_reduce_forward_idx, _dp_amax_reduce_backward_idx + if _dp_amax_reduce_interval is None: + _dp_amax_reduce_interval = int(os.getenv("NVTE_DP_AMAX_REDUCE_INTERVAL", "1")) + + tp_amax_reduce = False + if forward: + if _dp_amax_reduce_forward_idx == 0: + reduce_group = fp8_meta["fp8_group"] + else: + tp_amax_reduce = True + _dp_amax_reduce_forward_idx = (_dp_amax_reduce_forward_idx + 1) % _dp_amax_reduce_interval + else: + if _dp_amax_reduce_backward_idx == 0: + reduce_group = fp8_meta["fp8_group"] + else: + tp_amax_reduce = True + _dp_amax_reduce_backward_idx = (_dp_amax_reduce_backward_idx + 1) % _dp_amax_reduce_interval + + if tp_amax_reduce: + if tp_size > 1: + reduce_group = tp_group + else: + return None + chunk_sizes = [x.numel() for x in _global_fp8_buffer[amax_buffer_key]] contiguous_amax = torch.cat(_global_fp8_buffer[amax_buffer_key]) wait_handle = reduce_tensor_across_group_op_max( contiguous_amax, - fp8_meta["fp8_group"], + reduce_group, fp8_meta["async_amax_reduction"], ) diff --git a/transformer_engine/pytorch/module.py b/transformer_engine/pytorch/module.py index 7c25619485..dff37497d6 100644 --- a/transformer_engine/pytorch/module.py +++ b/transformer_engine/pytorch/module.py @@ -105,7 +105,13 @@ def get_workspace() -> torch.Tensor: return _cublas_workspace @contextmanager -def _prepare_backward(fp8: bool, fp8_meta: Dict[str, Any], name: str = "") -> None: +def _prepare_backward( + fp8: bool, + fp8_meta: Dict[str, Any], + tp_group: dist_group_type, + tp_size: int, + name: str = "" +) -> None: """Checks and prep for BWD.""" if fp8: global _amax_reduce_handle_bwd @@ -132,7 +138,12 @@ def _prepare_backward(fp8: bool, fp8_meta: Dict[str, Any], name: str = "") -> N if fp8 and fp8_meta["recipe"].reduce_amax: if fp8_meta["first_module"]: - _amax_reduce_handle_bwd = global_amax_reduction(fp8_meta, forward=False) + _amax_reduce_handle_bwd = global_amax_reduction( + fp8_meta, + tp_group, + tp_size, + forward=False + ) delete_key_from_amax_buffer(forward=False) @@ -186,7 +197,6 @@ def __init__(self) -> None: self.fp8_meta["recipe"] = get_default_fp8_recipe() self.fp8_meta_tensors_initialized = False self.tp_group = None - self.tp_group_initialized = False self.tp_size = 1 self.sequence_parallel = False self.fp8_weight_shapes = [] @@ -541,7 +551,13 @@ def prepare_forward( if self.fp8 and self.training and self.fp8_meta["recipe"].reduce_amax: set_fp8_context_id(self.fp8_meta["autocast_id_fwd"]) - reduce_func = partial(global_amax_reduction, self.fp8_meta, forward=True) + reduce_func = partial( + global_amax_reduction, + self.fp8_meta, + self.tp_group, + self.tp_size, + forward=True + ) 
setup_amax_forward_global_reduce_func(reduce_func) def set_nccl_overlap_warning_if_tp(self) -> None: @@ -692,6 +708,7 @@ def forward( fp8_meta: Dict[str, Any], fuse_wgrad_accumulation: bool, tp_group: Union[dist_group_type, None], + tp_size: int, sequence_parallel: bool, tensor_parallel: bool, activation_dtype: torch.dtype, @@ -867,6 +884,7 @@ def forward( ctx.inp_shape = inp.shape ctx.parallel_mode = parallel_mode ctx.tp_group = tp_group + ctx.tp_size = tp_size ctx.return_layernorm_output = return_layernorm_output ctx.bwd_ln_sm_margin = bwd_ln_sm_margin ctx.zero_centered_gamma = zero_centered_gamma @@ -890,7 +908,9 @@ def forward( def backward( ctx, *grad_outputs: Tuple[torch.Tensor, ...] ) -> Tuple[Union[torch.Tensor, None], ...]: - with _prepare_backward(ctx.fp8, ctx.fp8_meta, name="_LayerNormLinear"): + with _prepare_backward( + ctx.fp8, ctx.fp8_meta, ctx.tp_group, ctx.tp_size, name="_LayerNormLinear" + ): ( inputmat, ln_weight, @@ -1065,6 +1085,7 @@ def backward( None, None, None, + None, ) @@ -1381,6 +1402,7 @@ def forward( self.fp8_meta, self.fuse_wgrad_accumulation, self.tp_group, + self.tp_size, self.sequence_parallel, self.tp_size > 1, self.activation_dtype, @@ -1427,6 +1449,7 @@ def forward( fp8_meta: Dict[str, Any], fuse_wgrad_accumulation: bool, tp_group: Union[dist_group_type, None], + tp_size: int, sequence_parallel: bool, tensor_parallel: bool, activation_dtype: torch.dtype, @@ -1563,6 +1586,7 @@ def forward( ctx.inp_shape = inp.shape ctx.parallel_mode = parallel_mode ctx.tp_group = tp_group + ctx.tp_size = tp_size ctx.requires_dgrad = inp.requires_grad # Row Parallel Linear @@ -1579,7 +1603,9 @@ def forward( def backward( ctx, grad_output: torch.Tensor ) -> Tuple[Union[torch.Tensor, None], ...]: - with _prepare_backward(ctx.fp8, ctx.fp8_meta, name="_Linear"): + with _prepare_backward( + ctx.fp8, ctx.fp8_meta, ctx.tp_group, ctx.tp_size, name="_Linear" + ): ( inputmat, inputmat_t, @@ -1730,6 +1756,7 @@ def backward( None, None, None, + None, ) @@ -1995,6 +2022,7 @@ def forward( self.fp8_meta, self.fuse_wgrad_accumulation, self.tp_group, + self.tp_size, self.sequence_parallel, self.tp_size > 1, self.activation_dtype, @@ -2039,6 +2067,7 @@ def forward( fp8_meta: Dict[str, Any], fuse_wgrad_accumulation: bool, tp_group: Union[dist_group_type, None], + tp_size: int, sequence_parallel: bool, tensor_parallel: bool, activation_dtype: torch.dtype, @@ -2282,6 +2311,7 @@ def forward( ctx.tensor_parallel = tensor_parallel ctx.inp_shape = inp.shape ctx.tp_group = tp_group + ctx.tp_size = tp_size ctx.bias_gelu_nvfusion = bias_gelu_nvfusion ctx.return_layernorm_output = return_layernorm_output ctx.set_parallel_mode = set_parallel_mode @@ -2307,7 +2337,9 @@ def forward( def backward( ctx, *grad_outputs: Tuple[torch.Tensor, ...] 
) -> Tuple[Union[torch.Tensor, None], ...]: - with _prepare_backward(ctx.fp8, ctx.fp8_meta, name="_LayerNormMLP"): + with _prepare_backward( + ctx.fp8, ctx.fp8_meta, ctx.tp_group, ctx.tp_size, name="_LayerNormMLP" + ): ( inputmat, ln_weight, @@ -2610,6 +2642,7 @@ def backward( None, None, None, + None, ) @@ -2904,6 +2937,7 @@ def forward( self.fp8_meta, self.fuse_wgrad_accumulation, self.tp_group, + self.tp_size, self.sequence_parallel, self.tp_size > 1, self.activation_dtype, From e64fc3be6a7dacf21e992ec4f1ddd5ea6fb6ce21 Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Wed, 19 Apr 2023 10:52:31 -0700 Subject: [PATCH 18/68] TP communication overlap with userbuffers (#147) * Port initial changes Co-authored-by: Sangkug Lym Co-authored-by: Vasudevan Rengasamy Signed-off-by: Kirthi Shankar Sivamani * readd FA include for PyTorch Signed-off-by: Kirthi Shankar Sivamani * Re-enable sm_70 + cleanup Signed-off-by: Kirthi Shankar Sivamani * LICENSE, cleanup header Signed-off-by: Kirthi Shankar Sivamani * 5k -> 173 errors Signed-off-by: Kirthi Shankar Sivamani * license and fixes in userbuffers-host Signed-off-by: Kirthi Shankar Sivamani * next round fixes Signed-off-by: Kirthi Shankar Sivamani * final cpp cleanup Signed-off-by: Kirthi Shankar Sivamani * pylinting Signed-off-by: Kirthi Shankar Sivamani * fix from linting Signed-off-by: Kirthi Shankar Sivamani * Turn off default async amax reduction (#148) Signed-off-by: Kirthi Shankar Sivamani * remove unused code path Signed-off-by: Sangkug Lym * cleanup Macros Signed-off-by: Sangkug Lym * fix conflict resolution bug Signed-off-by: Sangkug Lym * Fix gencode flags in setup (#145) * Fix gencode flags based on cuda version Signed-off-by: Kirthi Shankar Sivamani * review suggestions Signed-off-by: Kirthi Shankar Sivamani * revert append_nvcc_threads change Signed-off-by: Kirthi Shankar Sivamani --------- Signed-off-by: Kirthi Shankar Sivamani * Change overlap config dict error message Signed-off-by: Sangkug Lym * simplify ub initialization Signed-off-by: Sangkug Lym * lint Signed-off-by: Kirthi Shankar Sivamani * fix sanity imports Signed-off-by: Kirthi Shankar Sivamani * cpplint Signed-off-by: Kirthi Shankar Sivamani * fix TensorFlow build Signed-off-by: Kirthi Shankar Sivamani * fix TE macros in public header Signed-off-by: Kirthi Shankar Sivamani * fix lint Signed-off-by: Kirthi Shankar Sivamani * More fixes Signed-off-by: Kirthi Shankar Sivamani * compiles with and w/o MPI Signed-off-by: Kirthi Shankar Sivamani * fixes for python side annotations for conditional compile Signed-off-by: Kirthi Shankar Sivamani * link gdrAPI only when MPI found Signed-off-by: Kirthi Shankar Sivamani * fix comments for dummy var Signed-off-by: Kirthi Shankar Sivamani * Fix linking Signed-off-by: Kirthi Shankar Sivamani * Review comments Signed-off-by: Kirthi Shankar Sivamani * load MPI before TE Signed-off-by: Kirthi Shankar Sivamani * Add Py side argument checks Signed-off-by: Kirthi Shankar Sivamani * remove unused code and catch silent failures Signed-off-by: Kirthi Shankar Sivamani * Fix cpp tests Signed-off-by: Kirthi Shankar Sivamani * fix find_lib path for tests Signed-off-by: Kirthi Shankar Sivamani --------- Signed-off-by: Kirthi Shankar Sivamani Signed-off-by: Sangkug Lym Co-authored-by: Sangkug Lym Co-authored-by: Vasudevan Rengasamy --- qa/L0_cppunittest/test.sh | 7 +- qa/L0_jax_lint/CPPLINT.cfg | 1 + qa/L0_lint/CPPLINT.cfg | 1 + qa/L0_tensorflow_lint/CPPLINT.cfg | 1 + setup.py | 30 +- tests/cpp/CMakeLists.txt | 6 + 
tests/cpp/operator/CMakeLists.txt | 8 +- transformer_engine/__init__.py | 1 - transformer_engine/common/CMakeLists.txt | 58 +- transformer_engine/common/__init__.py | 23 + .../comm_gemm_overlap/userbuffers-host.cpp | 464 +++++ .../common/comm_gemm_overlap/userbuffers.cu | 1734 +++++++++++++++++ .../common/gemm/cublaslt_gemm.cu | 11 + .../common/include/transformer_engine/gemm.h | 2 + .../include/transformer_engine/userbuffers.h | 227 +++ transformer_engine/jax/csrc/modules.cpp | 2 +- transformer_engine/pytorch/cpp_extensions.py | 89 +- .../pytorch/csrc/comm_gemm_overlap.h | 579 ++++++ transformer_engine/pytorch/csrc/extensions.cu | 187 +- transformer_engine/pytorch/csrc/extensions.h | 34 +- transformer_engine/pytorch/csrc/ts_fp8_op.cpp | 3 +- transformer_engine/pytorch/module.py | 517 ++++- transformer_engine/pytorch/transformer.py | 38 + .../tensorflow/csrc/extensions.cu | 2 +- 24 files changed, 3942 insertions(+), 83 deletions(-) create mode 100644 transformer_engine/common/comm_gemm_overlap/userbuffers-host.cpp create mode 100644 transformer_engine/common/comm_gemm_overlap/userbuffers.cu create mode 100644 transformer_engine/common/include/transformer_engine/userbuffers.h create mode 100644 transformer_engine/pytorch/csrc/comm_gemm_overlap.h diff --git a/qa/L0_cppunittest/test.sh b/qa/L0_cppunittest/test.sh index 55406c2089..73a27a1fcd 100644 --- a/qa/L0_cppunittest/test.sh +++ b/qa/L0_cppunittest/test.sh @@ -4,11 +4,16 @@ set -e +# Find TE : ${TE_PATH:=/opt/transformerengine} TE_LIB_PATH=`pip show transformer-engine | grep Location | cut -d ' ' -f 2` export LD_LIBRARY_PATH=$TE_LIB_PATH:$LD_LIBRARY_PATH +# Find MPI +MPI_HOME=${MPI_HOME:-/usr/local/mpi} +NVTE_MPI_INCLUDE="$MPI_HOME/lib" + cd $TE_PATH/tests/cpp -cmake -GNinja -Bbuild . +cmake -GNinja -Bbuild -DNVTE_MPI_INCLUDE=$NVTE_MPI_INCLUDE . 
cmake --build build ctest --test-dir build -j4 diff --git a/qa/L0_jax_lint/CPPLINT.cfg b/qa/L0_jax_lint/CPPLINT.cfg index 9eb7b734bb..a2a06602c1 100644 --- a/qa/L0_jax_lint/CPPLINT.cfg +++ b/qa/L0_jax_lint/CPPLINT.cfg @@ -14,3 +14,4 @@ filter=-build/namespaces filter=-readability/todo filter=-build/header_guard filter=-build/include +filter=-build/c++11 diff --git a/qa/L0_lint/CPPLINT.cfg b/qa/L0_lint/CPPLINT.cfg index 9eb7b734bb..a2a06602c1 100644 --- a/qa/L0_lint/CPPLINT.cfg +++ b/qa/L0_lint/CPPLINT.cfg @@ -14,3 +14,4 @@ filter=-build/namespaces filter=-readability/todo filter=-build/header_guard filter=-build/include +filter=-build/c++11 diff --git a/qa/L0_tensorflow_lint/CPPLINT.cfg b/qa/L0_tensorflow_lint/CPPLINT.cfg index 9eb7b734bb..a2a06602c1 100644 --- a/qa/L0_tensorflow_lint/CPPLINT.cfg +++ b/qa/L0_tensorflow_lint/CPPLINT.cfg @@ -14,3 +14,4 @@ filter=-build/namespaces filter=-readability/todo filter=-build/header_guard filter=-build/include +filter=-build/c++11 diff --git a/setup.py b/setup.py index 55552294e4..decdce51a4 100644 --- a/setup.py +++ b/setup.py @@ -19,7 +19,11 @@ path = os.path.dirname(os.path.realpath(__file__)) with open(path + "/VERSION", "r") as f: te_version = f.readline() + CUDA_HOME = os.environ.get("CUDA_HOME", "/usr/local/cuda") +MPI_HOME = os.environ.get("MPI_HOME", "/usr/local/mpi") +NVTE_MPI_FOUND = os.path.exists(MPI_HOME) +NVTE_MPI_INCLUDE = os.path.join(MPI_HOME, "include") def get_cuda_bare_metal_version(cuda_dir): raw_output = subprocess.check_output( @@ -51,7 +55,7 @@ def extra_gencodes(cc_flag): def extra_compiler_flags(): - return [ + extra_flags = [ "-O3", "-gencode", "arch=compute_70,code=sm_70", @@ -66,6 +70,9 @@ def extra_compiler_flags(): "--expt-extended-lambda", "--use_fast_math", ] + if NVTE_MPI_FOUND: + extra_flags.append("-DNVTE_MPI_FOUND") + return extra_flags cc_flag = [] @@ -76,12 +83,6 @@ def make_abs_path(l): return [os.path.join(path, p) for p in l] -include_dirs = [ - "transformer_engine/common/include", - "transformer_engine/pytorch/csrc", -] -include_dirs = make_abs_path(include_dirs) - pytorch_sources = [ "transformer_engine/pytorch/csrc/extensions.cu", "transformer_engine/pytorch/csrc/common.cu", @@ -100,6 +101,14 @@ def make_abs_path(l): framework = os.environ.get("NVTE_FRAMEWORK", "pytorch") +include_dirs = [ + "transformer_engine/common/include", + "transformer_engine/pytorch/csrc", +] +if (framework in ("all", "pytorch")) and NVTE_MPI_FOUND: + include_dirs.append(NVTE_MPI_INCLUDE) +include_dirs = make_abs_path(include_dirs) + args = sys.argv.copy() for s in args: if s.startswith("--framework="): @@ -155,10 +164,16 @@ def run(self, extensions): print("Building pyTorch extensions!") self.pytorch_build_extensions.run() + def cmake_flags(self): + if not NVTE_MPI_FOUND: + return [] + return ["-DNVTE_MPI_FOUND=1", f"-DNVTE_MPI_INCLUDE={NVTE_MPI_INCLUDE}"] + @staticmethod def install_requires(): return ["flash-attn>=1.0.2",] + class TensorFlowBuilder(FrameworkBuilderBase): def cmake_flags(self): p = [d for d in sys.path if 'dist-packages' in d][0] @@ -167,6 +182,7 @@ def cmake_flags(self): def run(self, extensions): print("Building TensorFlow extensions!") + class JaxBuilder(FrameworkBuilderBase): def cmake_flags(self): p = [d for d in sys.path if 'dist-packages' in d][0] diff --git a/tests/cpp/CMakeLists.txt b/tests/cpp/CMakeLists.txt index 75a9d13a20..631b356fec 100644 --- a/tests/cpp/CMakeLists.txt +++ b/tests/cpp/CMakeLists.txt @@ -27,6 +27,12 @@ if(NOT DEFINED TE_LIB_PATH) endif() find_library(TE_LIB NAMES 
transformer_engine PATHS ${TE_LIB_PATH} ENV TE_LIB_PATH REQUIRED) + +if(EXISTS ${NVTE_MPI_INCLUDE}) + find_library(MPI_LIB NAMES mpi PATHS ${NVTE_MPI_INCLUDE} REQUIRED) + message(STATUS "Found MPI library: ${MPI_LIB}") +endif() + message(STATUS "Found transformer_engine library: ${TE_LIB}") include_directories(../../transformer_engine/common/include) include_directories(${CMAKE_SOURCE_DIR}) diff --git a/tests/cpp/operator/CMakeLists.txt b/tests/cpp/operator/CMakeLists.txt index d720798db5..a77cf98a73 100644 --- a/tests/cpp/operator/CMakeLists.txt +++ b/tests/cpp/operator/CMakeLists.txt @@ -17,7 +17,13 @@ add_executable(test_operator test_multi_cast_transpose.cu ../test_common.cu) -target_link_libraries(test_operator PUBLIC CUDA::cudart GTest::gtest_main ${TE_LIB}) +list(APPEND test_operator_LINKER_LIBS CUDA::cudart GTest::gtest_main ${TE_LIB}) + +if(EXISTS ${NVTE_MPI_INCLUDE}) + list(APPEND test_operator_LINKER_LIBS ${MPI_LIB}) +endif() + +target_link_libraries(test_operator PUBLIC ${test_operator_LINKER_LIBS}) target_compile_options(test_operator PRIVATE -O2) include(GoogleTest) diff --git a/transformer_engine/__init__.py b/transformer_engine/__init__.py index bbe18df6db..6d89b9aad5 100644 --- a/transformer_engine/__init__.py +++ b/transformer_engine/__init__.py @@ -5,7 +5,6 @@ """Top level package""" from . import common - try: from . import pytorch except ImportError as e: diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index cee3cad71d..7459f77e4f 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -1,35 +1,55 @@ # Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. -add_library(transformer_engine SHARED - transformer_engine.cpp - transpose/cast_transpose.cu - transpose/transpose.cu - transpose/cast_transpose_fusion.cu - transpose/transpose_fusion.cu - transpose/multi_cast_transpose.cu - activation/gelu.cu - gemm/cublaslt_gemm.cu - layer_norm/ln_api.cpp - layer_norm/ln_bwd_semi_cuda_kernel.cu - layer_norm/ln_fwd_cuda_kernel.cu - rmsnorm/rmsnorm_api.cpp - rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu - rmsnorm/rmsnorm_fwd_cuda_kernel.cu - util/cast.cu - fused_softmax/scaled_masked_softmax.cu - fused_softmax/scaled_upper_triang_masked_softmax.cu) + +set(transformer_engine_SOURCES) +list(APPEND transformer_engine_SOURCES transformer_engine.cpp + transpose/cast_transpose.cu + transpose/transpose.cu + transpose/cast_transpose_fusion.cu + transpose/transpose_fusion.cu + transpose/multi_cast_transpose.cu + activation/gelu.cu + gemm/cublaslt_gemm.cu + layer_norm/ln_api.cpp + layer_norm/ln_bwd_semi_cuda_kernel.cu + layer_norm/ln_fwd_cuda_kernel.cu + rmsnorm/rmsnorm_api.cpp + rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu + rmsnorm/rmsnorm_fwd_cuda_kernel.cu + util/cast.cu + fused_softmax/scaled_masked_softmax.cu + fused_softmax/scaled_upper_triang_masked_softmax.cu) + +if(NVTE_MPI_FOUND) + list(APPEND transformer_engine_SOURCES comm_gemm_overlap/userbuffers.cu + comm_gemm_overlap/userbuffers-host.cpp) +endif() + +add_library(transformer_engine SHARED ${transformer_engine_SOURCES}) target_include_directories(transformer_engine PUBLIC "${CMAKE_CURRENT_SOURCE_DIR}/include") list(APPEND transformer_engine_LINKER_LIBS CUDA::cublas CUDA::cudart CUDA::nvToolsExt) -target_link_libraries(transformer_engine PUBLIC ${transformer_engine_LINKER_LIBS}) +if(NVTE_MPI_FOUND) + list(APPEND transformer_engine_LINKER_LIBS gdrapi) +endif() 
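+# With NVTE_MPI_FOUND set, the userbuffers sources above are built into
+# libtransformer_engine and the library is additionally linked against gdrcopy
+# (gdrapi). Both cache variables are normally injected by setup.py's
+# cmake_flags(); an equivalent manual configure would be, assuming a default
+# /usr/local/mpi install (paths illustrative):
+#   cmake -DNVTE_MPI_FOUND=1 -DNVTE_MPI_INCLUDE=/usr/local/mpi/include ..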
+target_link_libraries(transformer_engine PUBLIC ${transformer_engine_LINKER_LIBS}) target_include_directories(transformer_engine PRIVATE ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) set_source_files_properties(fused_softmax/scaled_masked_softmax.cu fused_softmax/scaled_upper_triang_masked_softmax.cu PROPERTIES COMPILE_OPTIONS "--use_fast_math") + +if(NVTE_MPI_FOUND) + set_source_files_properties(comm_gemm_overlap/userbuffers.cu + comm_gemm_overlap/userbuffers-host.cpp + PROPERTIES + INCLUDE_DIRECTORIES ${NVTE_MPI_INCLUDE} + COMPILE_OPTIONS "$<$:-maxrregcount=64>") +endif() + set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-relaxed-constexpr") set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -O3") diff --git a/transformer_engine/common/__init__.py b/transformer_engine/common/__init__.py index 7dfcdc96bb..0a8924f8ed 100644 --- a/transformer_engine/common/__init__.py +++ b/transformer_engine/common/__init__.py @@ -37,4 +37,27 @@ def _load_library(): return ctypes.CDLL(dll_path, mode=ctypes.RTLD_GLOBAL) +def _load_mpi(): + """Load MPI shared library""" + + system = platform.system() + if system == "Linux": + extension = "so" + elif system == "Darwin": + extension = "dylib" + elif system == "Windows": + extension = "dll" + else: + raise RuntimeError(f"Unsupported operating system ({system})") + lib_name = "libmpi." + extension + MPI_HOME = os.environ.get("MPI_HOME", "/usr/local/mpi") + NVTE_MPI_FOUND = os.path.exists(MPI_HOME) + dll_path = os.path.join(MPI_HOME, "lib", lib_name) + + if NVTE_MPI_FOUND: + return ctypes.CDLL(dll_path, mode=ctypes.RTLD_GLOBAL) + return None + + +_TE_LIB_CTYPES = _load_mpi() _TE_LIB_CTYPES = _load_library() diff --git a/transformer_engine/common/comm_gemm_overlap/userbuffers-host.cpp b/transformer_engine/common/comm_gemm_overlap/userbuffers-host.cpp new file mode 100644 index 0000000000..14928ed5a1 --- /dev/null +++ b/transformer_engine/common/comm_gemm_overlap/userbuffers-host.cpp @@ -0,0 +1,464 @@ +/************************************************************************* + * Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. 
+ ************************************************************************/ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +static int oob_bcast(void *comm_context, void *buf, int size, int root) { + MPI_Bcast(buf, size, MPI_BYTE, root, + (reinterpret_cast(comm_context))->comm_inter); + return 0; +} + +static int oob_barrier(void *comm_context) { + MPI_Barrier((reinterpret_cast(comm_context))->comm_inter); + return 0; +} + +static int oob_gather(void *comm_context, int root, void *sbuf, void *rbuf, int len) { + MPI_Gather(sbuf, len, MPI_BYTE, rbuf, len, MPI_BYTE, root, + (reinterpret_cast(comm_context))->comm_inter); + return 0; +} + +int stringCmp(const void *a, const void *b) { return strcmp((const char *)a, (const char *)b); } + +#define CUDACHECK(cmd) \ + do { \ + cudaError_t e = cmd; \ + if (e != cudaSuccess) { \ + printf("Failed: Cuda error %s:%d '%s'\n", __FILE__, __LINE__, cudaGetErrorString(e)); \ + exit(EXIT_FAILURE); \ + } \ + } while (0) + +int pipe_rank(communicator *comm, int step) { + int mynode = comm->myrank / comm->nvsize; + int mylocal = comm->nvrank; + int numlocal = comm->nvsize; + + int newlocal1 = mylocal + step * comm->ar_nvsize * comm->ar2_nvsize; + int newlocal = (numlocal + (newlocal1 % numlocal)) % numlocal; + int newnode = mynode; + newnode += (newlocal1 - newlocal) / numlocal * comm->num_nodes * comm->num2_nodes; + int allnodes = comm->nranks / comm->nvsize; + newnode = (allnodes + (newnode % allnodes)) % allnodes; + return newnode * numlocal + newlocal; +} + +int create_communicator_grouped2(communicator **comm, int pipegpus, int pipenodes, int tensorgpus, + int tensornodes) { + *comm = reinterpret_cast(malloc(sizeof(communicator))); + + int myrank, nranks, cur_dev, ndev; + MPI_Comm_rank(MPI_COMM_WORLD, &myrank); + MPI_Comm_size(MPI_COMM_WORLD, &nranks); + (*comm)->nranks = nranks; + (*comm)->myrank = myrank; + (*comm)->free_region = 0; + (*comm)->launch_mode = NVTE_LAUNCH_GPU | NVTE_LAUNCH_CPU; + + cudaDeviceProp device_prop; + CUDACHECK(cudaGetDevice(&cur_dev)); + CUDACHECK(cudaGetDeviceCount(&ndev)); + CUDACHECK(cudaGetDeviceProperties(&device_prop, cur_dev)); + (*comm)->sm_arch = device_prop.major; + // (*comm)->use_rr_kernel = device_prop.major == 8; + (*comm)->use_rr_kernel = 0; + (*comm)->push = 1; + (*comm)->use_ce = 0; + (*comm)->cga_size = 2; + for (int i = 0; i < userbuffers_op_types; i++) (*comm)->basecounter[i] = 0; + (*comm)->head = 0; + (*comm)->tail = 0; + (*comm)->activeproxy = 1; + (*comm)->active_nreqs = 0; + for (int i = 0; i < userbuffers_op_types; i++) (*comm)->active_req[i].active = -1; + + int ret = 0; + // split communicator + char host_name[MPI_MAX_PROCESSOR_NAME]; + char(*host_names)[MPI_MAX_PROCESSOR_NAME]; + int namelen, bytes, color, my_node, mylocal, numlocal, num_nodes; + int rank = (*comm)->myrank, size = (*comm)->nranks; + MPI_Get_processor_name(host_name, &namelen); + bytes = size * sizeof(char[MPI_MAX_PROCESSOR_NAME]); + host_names = (char(*)[MPI_MAX_PROCESSOR_NAME])malloc(bytes); + strcpy(host_names[rank], host_name); // NOLINT(*) + for (int n = 0; n < size; n++) + MPI_Bcast(&(host_names[n]), MPI_MAX_PROCESSOR_NAME, MPI_CHAR, n, MPI_COMM_WORLD); + qsort(host_names, size, sizeof(char[MPI_MAX_PROCESSOR_NAME]), stringCmp); + + color = 0; + for (int n = 0; n < size; n++) { + if (n > 0 && strcmp(host_names[n - 1], host_names[n])) color++; + if (strcmp(host_name, host_names[n]) == 0) break; + } + free(host_names); + + 
MPI_Comm_split(MPI_COMM_WORLD, color, rank, &(*comm)->comm_intra); + // find intranode numbers and make internode communicator + // figure out mylocal + MPI_Comm_rank((*comm)->comm_intra, &mylocal); + MPI_Comm_size((*comm)->comm_intra, &numlocal); + (*comm)->nvrank = mylocal; + (*comm)->nvsize = numlocal; + + cpu_set_t cpuset; + CPU_ZERO(&cpuset); + int core; + if (mylocal == 0) core = 50; + if (mylocal == 1) core = 58; + if (mylocal == 2) core = 18; + if (mylocal == 3) core = 26; + if (mylocal == 4) core = 114; + if (mylocal == 5) core = 122; + if (mylocal == 6) core = 82; + if (mylocal == 7) core = 90; + + CPU_SET(core, &cpuset); + if (!getenv("NVTE_NODOUBLE")) { + if (core > 128) + CPU_SET(core - 128, &cpuset); + else + CPU_SET(core + 128, &cpuset); + } + if (getenv("NVTE_DOPIN")) pthread_setaffinity_np(pthread_self(), sizeof(cpu_set_t), &cpuset); + + if (ndev == numlocal) { // all visible devices + if (cur_dev != mylocal) + printf("%d: device used %d[%d] ,resetting device to %d\n", rank, cur_dev, ndev, mylocal); + CUDACHECK(cudaSetDevice(mylocal)); + } + (*comm)->mydev = cur_dev; + // FIXME need to check that numlocal is multiple of pipegpus x tensorgpus + // ar1 is data + int divgpus = pipegpus * tensorgpus; + int datagpus = numlocal / divgpus; + (*comm)->ar_nvsize = datagpus; + (*comm)->ar_firstgpu = mylocal - ((mylocal / tensorgpus) % datagpus) * tensorgpus; + (*comm)->ar_nvrank = (mylocal - (*comm)->ar_firstgpu) / tensorgpus; + // ar2 is tensor + (*comm)->ar2_nvsize = tensorgpus; + (*comm)->ar2_firstgpu = mylocal - mylocal % tensorgpus; + (*comm)->ar2_nvrank = mylocal - (*comm)->ar2_firstgpu; + // ar2 has step equal to ar_nvsize + int allnodes = nranks / numlocal; + int mynode = myrank / numlocal; + int datanodes = allnodes / pipenodes / tensornodes; + int pipenodegroup_id = myrank / numlocal / (datanodes * tensornodes); + + (*comm)->pipe_id = pipegpus * pipenodegroup_id + mylocal / (datagpus * tensorgpus); + + CUDACHECK(cudaFree(0)); + int datanodegroup_id = + myrank / numlocal / datanodes; // data reduction group node belongs, equals 0 for all if both + // pipenodes=1 and tensornodes=1 + // mpi communicator only needed for SHARP which is always allreduce1/data-parallel + MPI_Comm_split(MPI_COMM_WORLD, mylocal + numlocal * datanodegroup_id, rank, &(*comm)->comm_inter); + // different rails from same group are in different subcommunicators + + MPI_Comm_size((*comm)->comm_inter, &num_nodes); + MPI_Comm_rank((*comm)->comm_inter, &my_node); + (*comm)->first_node = mynode - my_node; + (*comm)->num_nodes = num_nodes; + (*comm)->my_node = my_node; + + (*comm)->num2_nodes = tensornodes; + (*comm)->my2_node = (mynode / datanodes) % tensornodes; + (*comm)->first2_node = mynode - (*comm)->my2_node * datanodes; + + char *ib_dev_list; + int ZIONROCE = getenv("NVTE_ZIONROCE") ? atoi(getenv("NVTE_ZIONROCE")) : 0; + int ROCE = getenv("NVTE_ROCE") ? 
atoi(getenv("NVTE_ROCE")) : 0; + if (ZIONROCE) ROCE = 1; + int DGX_H100 = device_prop.major == 9; + + switch (mylocal) { + case 0:ib_dev_list = "mlx5_0:1"; break; // NOLINT(*) + case 1:ib_dev_list = (char*)(DGX_H100?"mlx5_3:1":"mlx5_1:1"); break; // NOLINT(*) + case 2:ib_dev_list = (char*)(ZIONROCE?"mlx5_4:1":DGX_H100?"mlx5_4:1":"mlx5_2:1"); break; // NOLINT(*) + case 3:ib_dev_list = (char*)(DGX_H100?"mlx5_5:1":"mlx5_3:1"); break; // NOLINT(*) + case 4:ib_dev_list = (char*)(DGX_H100?"mlx5_6:1":"mlx5_6:1"); break; // NOLINT(*) + case 5:ib_dev_list = (char*)(DGX_H100?"mlx5_9:1":"mlx5_7:1"); break; // NOLINT(*) + case 6:ib_dev_list = (char*)(ZIONROCE?"mlx5_10:1":DGX_H100?"mlx5_10:1":"mlx5_8:1"); break; // NOLINT(*) + case 7:ib_dev_list = (char*)(DGX_H100?"mlx5_11:1":"mlx5_9:1"); break; // NOLINT(*) + default: break; + } + + (*comm)->fifo = reinterpret_cast(malloc(sizeof(ub_request) * NVTE_MAX_REQUESTS)); + (*comm)->nblocks = 8; + (*comm)->alignblock = 1024 * 512; + (*comm)->minblock = 1024 * 2 * 1024; + (*comm)->asyncblocks = 16; + + CUDACHECK(cudaMallocHost((void **)&(*comm)->hostflags, // NOLINT(*) + (NVTE_MAX_SMS + 100) * sizeof(int))); + for (int i = 0; i < 100 + NVTE_MAX_SMS; i++) (*comm)->hostflags[i] = 0; + _mm_mfence(); + sleep(1); + + // init_p2p_transport(); + (*comm)->ibnvsize = (*comm)->nvsize; + +#define NBUF 2 +#define LOCALSIZE 4 * (NVTE_REG0_OFFSET(*comm) + NVTE_REG0_FLAGS + NVTE_REG0_COMMBUFFER * NBUF) + // peer pointers + op flags + comm buffer + + CUDACHECK(cudaMalloc(&(*comm)->gpu_ptrs, LOCALSIZE)); // flags and pointers, no block data yet + CUDACHECK(cudaMemset((*comm)->gpu_ptrs, 0, LOCALSIZE)); + CUDACHECK(cudaDeviceSynchronize()); + register_user_buffer_collective(&((*comm)->gpu_ptrs), LOCALSIZE, *comm); // will use handler 0 + CUDACHECK(cudaMalloc(&(*comm)->send_id, (*comm)->nranks * sizeof(int))); + CUDACHECK(cudaMalloc(&(*comm)->recv_id, NVTE_MAX_REGIONS * (*comm)->nranks * sizeof(int))); + CUDACHECK(cudaMemset((*comm)->send_id, 0, (*comm)->nranks * sizeof(int))); + CUDACHECK(cudaMemset((*comm)->recv_id, 0, NVTE_MAX_REGIONS * (*comm)->nranks * sizeof(int))); + (*comm)->sms = 16; + (*comm)->threads = 1024; + +#define GPU_PAGE_SHIFT 16 +#define GPU_PAGE_SIZE (1UL << GPU_PAGE_SHIFT) +#define GPU_PAGE_OFFSET (GPU_PAGE_SIZE - 1) +#define GPU_PAGE_MASK (~GPU_PAGE_OFFSET) + CUDACHECK(cudaMalloc(&(*comm)->flags, 2 * GPU_PAGE_SIZE)); + unsigned int flag = 1; + // cuPointerSetAttribute(&flag, CU_POINTER_ATTRIBUTE_SYNC_MEMOPS, (CUdeviceptr)(*comm)->flags); + CUDACHECK(cudaMemset((*comm)->flags, 0, 2 * GPU_PAGE_SIZE)); + (*comm)->flags = + reinterpret_cast(((CUdeviceptr)(*comm)->flags + GPU_PAGE_SIZE - 1) & GPU_PAGE_MASK); + + using namespace std; + (*comm)->g = gdr_open(); + if ((*comm)->g == NULL) { + fprintf(stderr, "gdrcopy open failed\n"); + return -1; + } + gdr_mh_t mh; + ret = gdr_pin_buffer((*comm)->g, (CUdeviceptr)(*comm)->flags, GPU_PAGE_SIZE, 0, 0, &mh); + if (ret) { + fprintf(stderr, "gdr_pin_buffer failed\n"); + return -1; + } + ret = gdr_map((*comm)->g, mh, (void **)&((*comm)->map_flags), GPU_PAGE_SIZE); // NOLINT(*) + + if (ret) { + fprintf(stderr, "gdr_map failed\n"); + return -1; + } + sched_param param; + pthread_attr_t attr; + pthread_attr_init(&attr); + pthread_attr_getschedparam(&attr, ¶m); + param.sched_priority = sched_get_priority_max(SCHED_FIFO); + + pthread_attr_setschedparam(&attr, ¶m); + + if (getenv("NVTE_UBDEBUG")) + printf("%d/%d:(%d x %d): DP %d x %d TP %d x %d, DPGROUP %dx%d TPGROUP %dx%d PIPE_ID %d/%d\n", + myrank, nranks, myrank / numlocal, 
myrank % numlocal, (*comm)->my_node, + (*comm)->ar_nvrank, (*comm)->my2_node, (*comm)->ar2_nvrank, (*comm)->num_nodes, + (*comm)->ar_nvsize, (*comm)->num2_nodes, (*comm)->ar2_nvsize, (*comm)->pipe_id, + pipegpus * pipenodes); + fflush(NULL); + + return 0; +} +int create_communicator_grouped(communicator **comm, int pipegpus, int pipenodes) { + return create_communicator_grouped2(comm, pipegpus, pipenodes, 1, 1); +} + +int create_communicator(communicator **comm) { + return create_communicator_grouped2(comm, 1, 1, 1, 1); +} + +void destroy_communicator(communicator *comm) { + comm->activeproxy = 0; + if (!comm->myrank && getenv("NVTE_UBDEBUG")) + printf("waiting for userbuffers proxy thread to exit()\n"); + gdr_close(comm->g); +} + +int register_user_buffer_collective(void **gpubuff, size_t bytes, communicator *comm, bool alloc) { + if (comm->free_region > NVTE_MAX_REGIONS) return -1; + int hndl = comm->free_region; + // printf("%d register %d size %lld\n",comm->myrank,hndl,bytes);fflush(NULL); + comm->peer_ptr[hndl] = reinterpret_cast(malloc(sizeof(void *) * (comm->nvsize))); + + if (alloc) { + CUDACHECK(cudaMalloc(gpubuff, bytes)); + } + assert(comm->nvsize <= 8); + cudaIpcMemHandle_t *memhndl = + reinterpret_cast(malloc(sizeof(cudaIpcMemHandle_t) * (comm->nvsize))); + + CUDACHECK(cudaIpcGetMemHandle(&memhndl[comm->nvrank], *gpubuff)); + + MPI_Allgather(&memhndl[comm->nvrank], sizeof(cudaIpcMemHandle_t), MPI_BYTE, memhndl, + sizeof(cudaIpcMemHandle_t), MPI_BYTE, comm->comm_intra); + + for (int i = 0; i < comm->nvsize; i++) + if (i != comm->nvrank) + CUDACHECK(cudaIpcOpenMemHandle((void **)&(comm->peer_ptr[hndl][i]), // NOLINT(*) + memhndl[i], cudaIpcMemLazyEnablePeerAccess)); + comm->peer_ptr[hndl][comm->nvrank] = *gpubuff; + CUDACHECK(cudaDeviceSynchronize()); + + CUDACHECK( + cudaMemcpy(reinterpret_cast(comm->gpu_ptrs) + (hndl * comm->nvsize * sizeof(void *)), + comm->peer_ptr[hndl], comm->nvsize * sizeof(void *), cudaMemcpyHostToDevice)); + + CUDACHECK(cudaDeviceSynchronize()); + free(memhndl); + + comm->mem_ptr[hndl] = *gpubuff; + return comm->free_region++; +} + +int allreduce_userbuff_inplace_gpu(const int handler, const int offset, const int elements, + const int blocksize, communicator *comm, cudaStream_t stream); + +int allreduce2_userbuff_inplace_gpu(const int maxcredit, const int handler, const int offset, + const int elements, const int blocksize, communicator *comm, + cudaStream_t stream, int op); + +int reducescatter2_userbuff_inplace_gpu(const int maxcredit, const int handler, const int offset, + const int elements, const int blocksize, communicator *comm, + cudaStream_t stream, int op); + +int allgather2_userbuff_inplace_gpu(const int maxcredit, const int handler, const int offset, + const int elements, const int blocksize, communicator *comm, + cudaStream_t stream, int op); + +void allreduce_nonsharp_inplace(const int handler, const int offset, const int elements, + communicator *comm, cudaStream_t stream, int op) { + if (elements < 64) NVTE_ERROR("Userbuffer comm for given config not implemented."); + // if(comm->myrank==0) fprintf(stderr,"AR2(%d) user call launch_mode=%d\n",op,comm->launch_mode); + const int ar_nvsize = op == userbuffers_allreduceop_nonsharp ? comm->ar_nvsize : comm->ar2_nvsize; + int blocksize = elements * 2; + int maxcredit = 0; + const int num_nodes = op == userbuffers_allreduceop_nonsharp ? 
comm->num_nodes : comm->num2_nodes; + blocksize = (comm->nblocks - 1 + (comm->alignblock - 1 + elements * 2) / comm->alignblock) / + comm->nblocks; // FIXME TUNING + blocksize *= comm->alignblock; + if (blocksize < comm->minblock) blocksize = comm->minblock; + + maxcredit = (elements * 2 + blocksize - 1) / blocksize; + // if(maxcredit>4) maxcredit=4; + // if(maxcredit>4 && ar_nvsize==1) maxcredit=4; + size_t peerblock = sizeof(int) * NVTE_REG0_COMMBUFFER / maxcredit; // max size we can fit + if (blocksize > peerblock * ar_nvsize) blocksize = peerblock * ar_nvsize; + // blocksize=elements*2; + int sms = allreduce2_userbuff_inplace_gpu(maxcredit, handler, offset, elements, blocksize, comm, + stream, op); + + if (num_nodes > 1 && comm->launch_mode & NVTE_LAUNCH_CPU) { + if (!sms) return; + comm->fifo[comm->head].optype = op; + comm->fifo[comm->head].basecounter = comm->basecounter[op]; + comm->fifo[comm->head].blocksize = blocksize; + comm->fifo[comm->head].maxcredit = maxcredit; + comm->fifo[comm->head].handler = handler; + comm->fifo[comm->head].offset = offset; + comm->fifo[comm->head].elements = elements; + + int newhead = (comm->head + 1) & (NVTE_MAX_REQUESTS - 1); + while (newhead == comm->tail) { + } + comm->head = newhead; + + comm->basecounter[op] += (elements * 2 + blocksize - 1) / blocksize; + } +} + +void allreduce2_userbuff_inplace(const int handler, const int offset, const int elements, + communicator *comm, cudaStream_t stream) { + allreduce_nonsharp_inplace(handler, offset, elements, comm, stream, + userbuffers_allreduceop_nonsharp2); +} + +void allreduce_userbuff_inplace(const int handler, const int offset, const int elements, + communicator *comm, cudaStream_t stream) { + if (elements < 64) NVTE_ERROR("Userbuffer comm for given config not implemented."); + allreduce_nonsharp_inplace(handler, offset, elements, comm, stream, + userbuffers_allreduceop_nonsharp); + return; +} + +void reducescatter_userbuff_inplace(const int handler, const int offset, const int elements, + communicator *comm, cudaStream_t stream) { + if (elements < 64) NVTE_ERROR("Userbuffer comm for given config not implemented."); + + int op = userbuffers_allreduceop_nonsharp; + const int ar_nvsize = op == userbuffers_allreduceop_nonsharp ? comm->ar_nvsize : comm->ar2_nvsize; + int blocksize = elements * 2; + int maxcredit = 0; + + const int num_nodes = op == userbuffers_allreduceop_nonsharp ? 
comm->num_nodes : comm->num2_nodes; + blocksize = (comm->nblocks - 1 + (comm->alignblock - 1 + elements * 2) / comm->alignblock) / + comm->nblocks; // FIXME TUNING + blocksize *= comm->alignblock; + if (blocksize < comm->minblock) blocksize = comm->minblock; + + maxcredit = (elements * 2 + blocksize - 1) / blocksize; + size_t peerblock = sizeof(int) * NVTE_REG0_COMMBUFFER / maxcredit; // max size we can fit + if (blocksize > peerblock * ar_nvsize) blocksize = peerblock * ar_nvsize; + + int sms = reducescatter2_userbuff_inplace_gpu(maxcredit, handler, offset, elements, blocksize, + comm, stream, op); + + if (num_nodes > 1 && comm->launch_mode & NVTE_LAUNCH_CPU) { + if (!sms) return; + comm->fifo[comm->head].optype = op; + comm->fifo[comm->head].basecounter = comm->basecounter[op]; + comm->fifo[comm->head].blocksize = blocksize; + comm->fifo[comm->head].maxcredit = maxcredit; + comm->fifo[comm->head].handler = handler; + comm->fifo[comm->head].offset = offset; + comm->fifo[comm->head].elements = elements; + + int newhead = (comm->head + 1) & (NVTE_MAX_REQUESTS - 1); + while (newhead == comm->tail) { + } + comm->head = newhead; + + comm->basecounter[op] += (elements * 2 + blocksize - 1) / blocksize; + } +} + +void allgather_userbuff_inplace(const int handler, const int offset, const int elements, + communicator *comm, cudaStream_t stream) { + if (elements < 64) NVTE_ERROR("Userbuffer comm for given config not implemented."); + int op = userbuffers_allreduceop_nonsharp; + const int ar_nvsize = op == userbuffers_allreduceop_nonsharp ? comm->ar_nvsize : comm->ar2_nvsize; + int blocksize = elements * 2; + int maxcredit = 0; + + const int num_nodes = op == userbuffers_allreduceop_nonsharp ? comm->num_nodes : comm->num2_nodes; + blocksize = (comm->nblocks - 1 + (comm->alignblock - 1 + elements * 2) / comm->alignblock) / + comm->nblocks; // FIXME TUNING + blocksize *= comm->alignblock; + if (blocksize < comm->minblock) blocksize = comm->minblock; + + maxcredit = (elements * 2 + blocksize - 1) / blocksize; + size_t peerblock = sizeof(int) * NVTE_REG0_COMMBUFFER / maxcredit; // max size we can fit + if (blocksize > peerblock * ar_nvsize) blocksize = peerblock * ar_nvsize; + + int sms = allgather2_userbuff_inplace_gpu(maxcredit, handler, offset, elements, blocksize, comm, + stream, op); +} diff --git a/transformer_engine/common/comm_gemm_overlap/userbuffers.cu b/transformer_engine/common/comm_gemm_overlap/userbuffers.cu new file mode 100644 index 0000000000..684771801b --- /dev/null +++ b/transformer_engine/common/comm_gemm_overlap/userbuffers.cu @@ -0,0 +1,1734 @@ +/************************************************************************* + * Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. 
+ ************************************************************************/ + +#include +#include +#if __CUDA_ARCH__ >= 800 +#include +#define half nv_bfloat16 +#else +#include +#endif +#include +#include +#include + +#define MAX_THREADS 1024 +#define TIMEOUT 200000000000ull + +#define CUDACHECK(cmd) \ + do { \ + cudaError_t e = cmd; \ + if (e != cudaSuccess) { \ + printf("Failed: Cuda error %s:%d '%s'\n", __FILE__, __LINE__, cudaGetErrorString(e)); \ + exit(EXIT_FAILURE); \ + } \ + } while (0) + +template +__global__ void __launch_bounds__(MAX_THREADS) + userbuffers_fp16_sum_inplace_gpu_rw(const int op, const int flagoffset, const int firstrank, + const int myrank, const int gpustep, const int lineoffset, + const int numlines, void **commbuff, const int handleridx) { + __shared__ int4 *userptr[RANKS]; + int *flagptr, physgpu, targetgpu, *myptr; + int *reduceidptr, reduce_id; + // if(blockIdx.x==0 && threadIdx.x==0) printf("%d/%d(phys %d gpustep %d firstrank %d):RRkernel(d) + // start, size %lld\n",myrank,RANKS,gpustep*myrank+firstrank,gpustep,firstrank,numlines*16ull); + if (threadIdx.x < RANKS) { + physgpu = myrank * gpustep + firstrank; + targetgpu = threadIdx.x * gpustep + firstrank; + const int blockflagoffset = NVTE_MAX_NVLINK * 2 * blockIdx.x; + myptr = (reinterpret_cast(commbuff[physgpu])) + flagoffset; + reduceidptr = myptr - NVTE_MAX_OPS; // +op; + reduce_id = (*reduceidptr) + 1; + flagptr = (reinterpret_cast(commbuff[targetgpu])) + flagoffset + blockflagoffset; + myptr += blockflagoffset; + + flagptr[physgpu] = reduce_id; + volatile int *flag = (volatile int *)&(myptr[targetgpu]); + userptr[threadIdx.x] = reinterpret_cast(commbuff[targetgpu + handleridx]); + clock_t s = clock64(); + while (*flag < reduce_id) { + if (clock64() - s > TIMEOUT) { + printf("NVONLY RSBAR:SM %d [%d]:expecting %d got %d\n", blockIdx.x, threadIdx.x, reduce_id, + *flag); + break; + } + } + reduce_id++; + } + __syncthreads(); + + int warp = blockIdx.x + (threadIdx.x >> 5); + int dest[RANKS]; +#pragma unroll + for (int i = 0; i < RANKS; i++) dest[i] = (i + myrank + warp) & (RANKS - 1); + + __syncthreads(); + for (int line = threadIdx.x + blockDim.x * (myrank + RANKS * blockIdx.x); line < numlines; + line += blockDim.x * gridDim.x * RANKS) { + int4 val[RANKS]; + +#pragma unroll + for (int i = 0; i < RANKS; i++) { + // int dest = (i+myrank+warp)&(RANKS-1); + val[i] = userptr[dest[i]][lineoffset + line]; + } + + int4 sum = val[0]; + half *s = reinterpret_cast(&sum); + +#pragma unroll + for (int i = 1; i < RANKS; i++) { + half *x = reinterpret_cast(&val[i]); +#pragma unroll + for (int j = 0; j < 8; j++) s[j] += x[j]; + } +#pragma unroll + for (int i = 0; i < RANKS; i++) { + // int dest = (i+myrank+warp)&(RANKS-1); + userptr[dest[i]][lineoffset + line] = sum; + } + } + + __syncthreads(); + if (threadIdx.x == 0) __threadfence_system(); + __syncthreads(); + + if (threadIdx.x < RANKS) { + flagptr[physgpu] = reduce_id; + volatile int *flag = (volatile int *)&myptr[targetgpu]; + clock_t s = clock64(); + while (*flag < reduce_id) { + if (clock64() - s > 2ull * TIMEOUT) { + printf("NVONLY AGBAR:SM %d [%d]:expecting %d got %d\n", blockIdx.x, threadIdx.x, reduce_id, + *flag); + break; + } + } + } + if (threadIdx.x == 0 && blockIdx.x == 0) *reduceidptr = reduce_id; +} // fp16 inplace reduce kernel (Volta,Hopper) + +template +__global__ void __launch_bounds__(MAX_THREADS) + userbuffers_fp16_sum_inplace_gpu_rr(const int op, const int flagoffset, const int firstrank, + const int myrank, const int gpustep, const int 
lineoffset, + const int numlines, void **commbuff, const int handleridx) { + __shared__ int4 *userptr[RANKS]; + int *flagptr, physgpu, targetgpu, *myptr; + int *reduceidptr, reduce_id; + if (threadIdx.x < RANKS) { + physgpu = myrank * gpustep + firstrank; + targetgpu = threadIdx.x * gpustep + firstrank; + const int blockflagoffset = NVTE_MAX_NVLINK * 2 * blockIdx.x; + myptr = (reinterpret_cast(commbuff[physgpu])) + flagoffset; + reduceidptr = myptr - NVTE_MAX_OPS; // +op; + reduce_id = (*reduceidptr) + 1; + flagptr = (reinterpret_cast(commbuff[targetgpu])) + flagoffset + blockflagoffset; + myptr += blockflagoffset; + + flagptr[physgpu] = reduce_id; + volatile int *flag = (volatile int *)&(myptr[targetgpu]); + userptr[threadIdx.x] = reinterpret_cast(commbuff[targetgpu + handleridx]); + clock_t s = clock64(); + while (*flag < reduce_id) { + if (clock64() - s > TIMEOUT) { + printf("NVONLY RSBAR:SM %d [%d]:expecting %d got %d\n", blockIdx.x, threadIdx.x, reduce_id, + *flag); + break; + } + } + reduce_id++; + } + __syncthreads(); + + int warp = blockIdx.x + (threadIdx.x >> 5); + int dest[RANKS]; +#pragma unroll + for (int i = 0; i < RANKS; i++) dest[i] = (i + myrank + warp) & (RANKS - 1); + + __syncthreads(); + for (int line = threadIdx.x + blockDim.x * (myrank + RANKS * blockIdx.x); line < numlines; + line += blockDim.x * gridDim.x * RANKS) { + int4 val[RANKS]; + +#pragma unroll + for (int i = 0; i < RANKS; i++) { + val[i] = userptr[dest[i]][lineoffset + line]; + } + + int4 sum = val[0]; + half *s = reinterpret_cast(&sum); + +#pragma unroll + for (int i = 1; i < RANKS; i++) { + half *x = reinterpret_cast(&val[i]); +#pragma unroll + for (int j = 0; j < 8; j++) s[j] += x[j]; + } + + userptr[myrank][lineoffset + line] = sum; + } + __syncthreads(); + if (threadIdx.x == 0) __threadfence(); + __syncthreads(); + + if (threadIdx.x < RANKS) { + flagptr[physgpu] = reduce_id; + volatile int *flag = (volatile int *)&myptr[targetgpu]; + clock_t s = clock64(); + while (*flag < reduce_id) { + if (clock64() - s > 2ull * TIMEOUT) { + printf("NVONLY AGBAR:SM %d [%d]:expecting %d got %d\n", blockIdx.x, threadIdx.x, reduce_id, + *flag); + break; + } + } + } + + int skipmy = 0; +#pragma unroll + for (int i = 0; i < RANKS; i++) { + int dst = (i + warp + myrank) & (RANKS - 1); + if (dst == myrank) { + skipmy++; + continue; + } + dest[i - skipmy] = dst; + } + __syncthreads(); + + for (int line = threadIdx.x + blockDim.x * RANKS * blockIdx.x; line < numlines; + line += blockDim.x * gridDim.x * RANKS) { + int4 val[RANKS - 1]; + +#pragma unroll + for (int i = 0; i < RANKS - 1; i++) { + val[i] = userptr[dest[i]][lineoffset + line + blockDim.x * dest[i]]; + } + +#pragma unroll + for (int i = 0; i < RANKS - 1; i++) { + userptr[myrank][lineoffset + line + blockDim.x * dest[i]] = val[i]; + } + } + if (threadIdx.x == 0 && blockIdx.x == 0) *reduceidptr = reduce_id; +} // fp16 inplace reduce kernel (Ampere) + +template +__global__ void __launch_bounds__(MAX_THREADS) + userbuffers_fp16_sum_inplace_gpu_rr_rs(const int op, const int flagoffset, const int firstrank, + const int myrank, const int gpustep, + const int mylineoffset, const int totallines, + void **commbuff, const int handleridx) { + __shared__ int4 *userptr[RANKS]; + int *flagptr, physgpu, targetgpu, *myptr; + int *reduceidptr, reduce_id; + if (threadIdx.x < RANKS) { + physgpu = myrank * gpustep + firstrank; + targetgpu = threadIdx.x * gpustep + firstrank; + const int blockflagoffset = NVTE_MAX_NVLINK * 2 * blockIdx.x; + myptr = (reinterpret_cast(commbuff[physgpu])) + 
flagoffset; + reduceidptr = myptr - NVTE_MAX_OPS; // +op; + reduce_id = (*reduceidptr) + 1; + flagptr = (reinterpret_cast(commbuff[targetgpu])) + flagoffset + blockflagoffset; + myptr += blockflagoffset; + + flagptr[physgpu] = reduce_id; + volatile int *flag = (volatile int *)&(myptr[targetgpu]); + userptr[threadIdx.x] = reinterpret_cast(commbuff[targetgpu + handleridx]); + clock_t s = clock64(); + while (*flag < reduce_id) { + if (clock64() - s > TIMEOUT) { + printf("NVONLY RSBAR:SM %d [%d]:expecting %d got %d\n", blockIdx.x, threadIdx.x, reduce_id, + *flag); + break; + } + } + } + __syncthreads(); + + int warp = blockIdx.x + (threadIdx.x >> 5); + int dest[RANKS]; +#pragma unroll + for (int i = 0; i < RANKS; i++) dest[i] = (i + myrank + warp) & (RANKS - 1); + + __syncthreads(); + for (int line = threadIdx.x + blockDim.x * blockIdx.x; line < totallines; + line += blockDim.x * gridDim.x) { + int4 val[RANKS]; + +#pragma unroll + for (int i = 0; i < RANKS; i++) { + val[i] = userptr[dest[i]][mylineoffset + line]; + } + + int4 sum = val[0]; + half *s = reinterpret_cast(&sum); + +#pragma unroll + for (int i = 1; i < RANKS; i++) { + half *x = reinterpret_cast(&val[i]); +#pragma unroll + for (int j = 0; j < 8; j++) s[j] += x[j]; + } + + userptr[myrank][mylineoffset + line] = sum; + } + + if (threadIdx.x == 0 && blockIdx.x == 0) *reduceidptr = reduce_id; +} // fp16 inplace reduce-scatter kernel + +template +__global__ void __launch_bounds__(MAX_THREADS) + userbuffers_fp16_sum_inplace_gpu_rr_rs_oop(const int op, const int flagoffset, + const int firstrank, const int myrank, + const int gpustep, const int mylineoffset, + const int totallines, const int rowlines, + const int skiplines, void **commbuff, + const int handleridx, void *outbuf) { + __shared__ int4 *userptr[RANKS]; + int *flagptr, physgpu, targetgpu, *myptr; + int *reduceidptr, reduce_id; + if (threadIdx.x < RANKS) { + physgpu = myrank * gpustep + firstrank; + targetgpu = threadIdx.x * gpustep + firstrank; + const int blockflagoffset = NVTE_MAX_NVLINK * 2 * blockIdx.x; + myptr = (reinterpret_cast(commbuff[physgpu])) + flagoffset; + reduceidptr = myptr - NVTE_MAX_OPS; // +op; + reduce_id = (*reduceidptr) + 1; + flagptr = (reinterpret_cast(commbuff[targetgpu])) + flagoffset + blockflagoffset; + myptr += blockflagoffset; + + flagptr[physgpu] = reduce_id; + volatile int *flag = (volatile int *)&(myptr[targetgpu]); + userptr[threadIdx.x] = reinterpret_cast(commbuff[targetgpu + handleridx]); + clock_t s = clock64(); + while (*flag < reduce_id) { + if (clock64() - s > TIMEOUT) { + printf("NVONLY RSBAR:SM %d [%d]:expecting %d got %d\n", blockIdx.x, threadIdx.x, reduce_id, + *flag); + break; + } + } + } + __syncthreads(); + + int warp = blockIdx.x + (threadIdx.x >> 5); + int dest[RANKS]; +#pragma unroll + for (int i = 0; i < RANKS; i++) dest[i] = (i + myrank + warp) & (RANKS - 1); + + __syncthreads(); + for (int line = threadIdx.x + blockDim.x * blockIdx.x; line < totallines; + line += blockDim.x * gridDim.x) { + int4 val[RANKS]; + +#pragma unroll + for (int i = 0; i < RANKS; i++) { + val[i] = userptr[dest[i]][mylineoffset + line]; + } + + int4 sum = val[0]; + half *s = reinterpret_cast(&sum); + +#pragma unroll + for (int i = 1; i < RANKS; i++) { + half *x = reinterpret_cast(&val[i]); +#pragma unroll + for (int j = 0; j < 8; j++) s[j] += x[j]; + } + + (reinterpret_cast(outbuf))[(line / rowlines) * skiplines + (line % rowlines)] = sum; + } + + if (threadIdx.x == 0 && blockIdx.x == 0) *reduceidptr = reduce_id; +} // fp16 reduce-scatter kernel (out of 
place) + +template +__global__ void __launch_bounds__(MAX_THREADS) + userbuffers_fp16_sum_inplace_gpu_rr_ag(const int op, const int flagoffset, const int firstrank, + const int myrank, const int gpustep, + const int mylineoffset, const int totallines, + void **commbuff, const int handleridx) { + __shared__ int4 *userptr[RANKS]; + int *flagptr, physgpu, targetgpu, *myptr; + int *reduceidptr, reduce_id; + if (threadIdx.x < RANKS) { + physgpu = myrank * gpustep + firstrank; + targetgpu = threadIdx.x * gpustep + firstrank; + const int blockflagoffset = NVTE_MAX_NVLINK * 2 * blockIdx.x; + myptr = (reinterpret_cast(commbuff[physgpu])) + flagoffset; + reduceidptr = myptr - NVTE_MAX_OPS; // +op; + reduce_id = (*reduceidptr) + 1; + flagptr = (reinterpret_cast(commbuff[targetgpu])) + flagoffset + blockflagoffset; + myptr += blockflagoffset; + + flagptr[physgpu] = reduce_id; + volatile int *flag = (volatile int *)&(myptr[targetgpu]); + userptr[threadIdx.x] = reinterpret_cast(commbuff[targetgpu + handleridx]); + clock_t s = clock64(); + } + + int warp = blockIdx.x + (threadIdx.x >> 5); + int dest[RANKS]; + + int skipmy = 0; +#pragma unroll + for (int i = 0; i < RANKS; i++) { + int dst = (i + warp + myrank) & (RANKS - 1); + if (dst == myrank) { + skipmy++; + continue; + } + dest[i - skipmy] = dst; + } + __syncthreads(); + + for (int line = threadIdx.x + blockDim.x * blockIdx.x; line < totallines; + line += blockDim.x * gridDim.x) { + int4 val[RANKS - 1]; + +#pragma unroll + for (int i = 0; i < RANKS - 1; i++) { + val[i] = userptr[dest[i]][mylineoffset + line + totallines * dest[i]]; + } + +#pragma unroll + for (int i = 0; i < RANKS - 1; i++) { + userptr[myrank][mylineoffset + line + totallines * dest[i]] = val[i]; + } + } + if (threadIdx.x == 0 && blockIdx.x == 0) *reduceidptr = reduce_id; +} // fp16 inplace reduce kernel (Ampere) + +template +__global__ void __launch_bounds__(MAX_THREADS) + userbuffers_fp16_sum_inplace_gpu_rw_ag(const int op, const int flagoffset, const int firstrank, + const int myrank, const int gpustep, + const int mylineoffset, const int totallines, + void **commbuff, const int handleridx) { + __shared__ int4 *userptr[RANKS]; + int *flagptr, physgpu, targetgpu, *myptr; + int *reduceidptr, reduce_id; + int4 *localptr; + if (threadIdx.x < RANKS) { + physgpu = myrank * gpustep + firstrank; + targetgpu = threadIdx.x * gpustep + firstrank; + const int blockflagoffset = NVTE_MAX_NVLINK * 2 * blockIdx.x; + myptr = (reinterpret_cast(commbuff[physgpu])) + flagoffset; + reduceidptr = myptr - NVTE_MAX_OPS; // +op; + reduce_id = (*reduceidptr) + 1; + flagptr = (reinterpret_cast(commbuff[targetgpu])) + flagoffset + blockflagoffset; + myptr += blockflagoffset; + userptr[threadIdx.x] = reinterpret_cast(commbuff[targetgpu + handleridx]); + reduce_id++; + } + __syncthreads(); + localptr = userptr[myrank]; + + int warp = blockIdx.x + (threadIdx.x >> 5); + int dest[RANKS - 1]; + int skipmy = 0; +#pragma unroll + for (int i = 0; i < RANKS; i++) { + int dst = (i + warp + myrank) & (RANKS - 1); + if (dst == myrank) { + skipmy++; + continue; + } + dest[i - skipmy] = dst; + } +#define UNROLLAG 4 + __syncthreads(); + const int loop_step0 = blockDim.x * gridDim.x; + const int loop_step = loop_step0 * UNROLLAG; + const int start_elem = threadIdx.x + blockDim.x * blockIdx.x; + const int end_elem = max(start_elem, totallines); + const int aligned_elem = ((end_elem - start_elem) / loop_step) * loop_step; + const int end_aligned = start_elem + aligned_elem; + + for (int line = start_elem; line < end_aligned; 
line += loop_step) { + int4 val[UNROLLAG]; +#pragma unroll + for (int j = 0; j < UNROLLAG; j++) val[j] = localptr[mylineoffset + line + loop_step0 * j]; + +#pragma unroll + for (int j = 0; j < UNROLLAG; j++) +#pragma unroll + for (int i = 0; i < RANKS - 1; i++) { + userptr[dest[i]][mylineoffset + line + j * loop_step0] = val[j]; + } + } + + for (int line = end_aligned; line < end_elem; line += loop_step0) { + int4 sum = localptr[mylineoffset + line]; +#pragma unroll + for (int i = 0; i < RANKS - 1; i++) { + userptr[dest[i]][mylineoffset + line] = sum; + } + } + + __syncthreads(); + if (threadIdx.x == 0) __threadfence_system(); + __syncthreads(); + + if (threadIdx.x < RANKS) { + flagptr[physgpu] = reduce_id; + volatile int *flag = (volatile int *)&myptr[targetgpu]; + clock_t s = clock64(); + while (*flag < reduce_id) { + if (clock64() - s > 2ull * TIMEOUT) { + printf("NVONLY AGBAR:SM %d [%d]:expecting %d got %d\n", blockIdx.x, threadIdx.x, reduce_id, + *flag); + break; + } + } + } + if (threadIdx.x == 0 && blockIdx.x == 0) *reduceidptr = reduce_id; +} // fp16 inplace allgather kernel (Volta,Hopper) + +template +__global__ void __launch_bounds__(MAX_THREADS) + userbuffers_fp16_sum_inplace_gpu_rr_blocked(const int op, const int flagoffset, + const int firstrank, const int myrank, + const int lineoffset, const int numlines, + void **commbuff, const int handleridx, + const int peerblocklines, int *hostflags, + int *gpuflag, const int numblocks) { + const int basecounter = gpuflag[NVTE_GF_STATE + op]; + +#define REDUCETHREADS (blockDim.x - 32) + + if (threadIdx.x < 32) { + int *flagptr; + if (threadIdx.x < RANKS) { + if (!blockIdx.x) { + flagptr = reinterpret_cast(commbuff[threadIdx.x + firstrank]); + flagptr[flagoffset + myrank + firstrank] = basecounter; + } + volatile int *flag = (volatile int *)&((reinterpret_cast( + commbuff[myrank + firstrank]))[flagoffset + threadIdx.x + firstrank]); + while (*flag < basecounter) { + } + } + __syncthreads(); + + int startblock = 0, endblock = numblocks; + + for (int nblock = 0; nblock < endblock; nblock++) { + asm volatile("bar.sync 13, %0;" ::"r"(REDUCETHREADS + 32)); + + if (threadIdx.x == 0) { + __threadfence(); + if (blockIdx.x) gpuflag[op * NVTE_MAX_SMS * 2 + blockIdx.x] = nblock + basecounter + 1; + } else if (blockIdx.x == 0) { + int expecting = (basecounter + nblock + 1); + if (threadIdx.x < gridDim.x) + while (((volatile int *)gpuflag)[op * NVTE_MAX_SMS * 2 + threadIdx.x] < expecting) { + } + } + if (!blockIdx.x) { + asm volatile("bar.sync 15, %0;" ::"r"(32)); + if (!threadIdx.x) hostflags[0] = nblock + basecounter + 1; + } + } + + int cachedflag = basecounter; + +#define ALLGATHERFLAG NVTE_GF_IBSHARPDONE + + if (blockIdx.x == 0 && threadIdx.x < RANKS) { + while (cachedflag < basecounter + numblocks) { + int newflag = ((volatile int *)gpuflag)[ALLGATHERFLAG]; + if (newflag == cachedflag) continue; + cachedflag = newflag; + flagptr[flagoffset + myrank + 32 + firstrank] = cachedflag; + } + } + + if (blockIdx.x == 0 && threadIdx.x == 0) gpuflag[NVTE_GF_STATE + op] = basecounter + numblocks; + } else { + const int warp = blockIdx.x + (threadIdx.x >> 5); + int4 *userptr[RANKS]; + int4 *userptrmyrank; +#pragma unroll + for (int i = 0; i < RANKS; i++) + userptr[i] = reinterpret_cast( + commbuff[((i + myrank + warp) & (RANKS - 1)) + handleridx + firstrank]); + userptrmyrank = reinterpret_cast(commbuff[myrank + handleridx + firstrank]); + __syncthreads(); + + int blocklineoffset = 0; + + while (blocklineoffset < numlines) { + const int remainder = 
min(numlines - blocklineoffset, peerblocklines * RANKS); + const int blocklines = remainder / RANKS; + const int blockstart = lineoffset + blocklineoffset + blocklines * myrank; + + for (int line = threadIdx.x - 32 + REDUCETHREADS * blockIdx.x; line < blocklines; + line += REDUCETHREADS * gridDim.x) { + int4 val[RANKS]; + +#pragma unroll + for (int i = 0; i < RANKS; i++) { + val[i] = userptr[i][blockstart + line]; + } + + int4 sum = val[0]; + half *s = reinterpret_cast(&sum); + +#pragma unroll + for (int i = 1; i < RANKS; i++) { + half *x = reinterpret_cast(&val[i]); +#pragma unroll + for (int j = 0; j < sizeof(int4) / sizeof(half); j++) s[j] += x[j]; + } + + userptrmyrank[blockstart + line] = sum; + } // single block loop + + asm volatile("bar.sync 13, %0;" ::"r"(REDUCETHREADS + 32)); + + blocklineoffset += peerblocklines * RANKS; + } // block loop NVLINK-REDUCESCATTER + const int nwarps = (REDUCETHREADS >> 5) / (RANKS - 1); + const int myblockDim = nwarps << 5; + const int mywarp = ((threadIdx.x - 32) >> 5) / (RANKS - 1); + const int maxthreadIdx = myblockDim * (RANKS - 1) + 32; + const int mydest = (myrank + 1 + ((threadIdx.x - 32) >> 5) % (RANKS - 1)) & (RANKS - 1); + const int mythreadIdx = (mywarp << 5) + (threadIdx.x & 31); + volatile int *flag = (volatile int *)&((reinterpret_cast( + commbuff[myrank + firstrank]))[flagoffset + mydest + 32 + firstrank]); + + int4 *userptrmydest = userptr[((RANKS << 10) + mydest - myrank - warp) & (RANKS - 1)]; + + blocklineoffset = 0; + int gathercounter = basecounter + 1; + while (blocklineoffset < numlines) { + const int remainder = min(numlines - blocklineoffset, peerblocklines * RANKS); + const int blocklines = remainder / RANKS; + const int blockstart = lineoffset + blocklineoffset; + +#define UNROLL 6 + int4 *myptr = &userptrmyrank[blockstart + blocklines * mydest]; + int4 *peerptr = &userptrmydest[blockstart + blocklines * mydest]; + + if (threadIdx.x < maxthreadIdx) { + const int start_elem = mythreadIdx + myblockDim * blockIdx.x; + const int end_elem = max(start_elem, blocklines); + const int aligned_elem = ((end_elem - start_elem) / (myblockDim * gridDim.x * UNROLL)) * + (myblockDim * gridDim.x * UNROLL); + const int end_aligned = start_elem + aligned_elem; + + if (mythreadIdx == 0) { + while (*flag < gathercounter) { + } + gathercounter++; + } + + asm volatile("bar.sync %0, %1;" ::"r"(1 + mydest), "r"(myblockDim)); + + for (int line = start_elem; line < end_aligned; line += myblockDim * gridDim.x * UNROLL) { + int4 val[UNROLL]; +#pragma unroll + for (int i = 0; i < UNROLL; i++) val[i] = peerptr[line + i * myblockDim * gridDim.x]; +#pragma unroll + for (int i = 0; i < UNROLL; i++) myptr[line + i * myblockDim * gridDim.x] = val[i]; + } + for (int line = end_aligned; line < end_elem; line += myblockDim * gridDim.x) + myptr[line] = peerptr[line]; + } + blocklineoffset += peerblocklines * RANKS; + } // block loop for NVLINK-ALLGATHER + } // worker warps else block +} // fp16 inplace reduce kernel with SHARP / in blocks + +// threadfence and SMs sync to SM0 +#define SMBAR(offset, block) \ + asm volatile("bar.sync 13, %0;" ::"r"(blockDim.x)); \ + if (threadIdx.x == 0) { \ + __threadfence_system(); \ + if (blockIdx.x) gpuflag[offset + blockIdx.x] = block + basecounter + 1; \ + } else if (blockIdx.x == 0) { \ + int expecting = (basecounter + block + 1); \ + if (threadIdx.x < gridDim.x) \ + while (((volatile int *)gpuflag)[offset + threadIdx.x] < expecting) { \ + } \ + } \ + if (blockIdx.x == 0) asm volatile("bar.sync 15, %0;" ::"r"(32)); + 
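+// Overview of the blocked2 kernel below: each block of the allreduce is split
+// into an intra-node NVLink reduce-scatter, an inter-node reduce that folds in
+// the IB data staged in the communicator's scratch buffer, and a final NVLink
+// all-gather. The first warp acts as a control warp: it publishes per-block
+// progress to hostflags (NVRSDONE / NVREDUCEDONE) for the CPU proxy and spins
+// on the per-IB-rank IBRS / IBAG flags, while the remaining warps perform the
+// actual reductions and copies. headstart controls how many blocks the NVLink
+// stage may run ahead of the IB stage.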
+template +__global__ void __launch_bounds__(MAX_THREADS) userbuffers_fp16_sum_inplace_gpu_rr_blocked2( + const int op, const int maxcredit, const int headstart, const int myibrank, const int ibranks, + const int commbufoffset, const int flagoffset, const int firstrank, const int myrank, + const int gpustep, const int lineoffset, const int numlines, void **commbuff, + const int handleridx, const int peerblocklines, int *hostflags, int *gpuflag, + const int numblocks) { + const int basecounter = gpuflag[NVTE_GF_STATE + op]; + if (threadIdx.x < 32) { + int *flagptr; + volatile int *localflag = (volatile int *)&( + ((int *)commbuff[gpustep * myrank + firstrank])[flagoffset]); // NOLINT(*) + // initial intranode barrier - once + if (threadIdx.x < RANKS) { + if (!blockIdx.x) { + flagptr = reinterpret_cast(commbuff[gpustep * threadIdx.x + firstrank]); + flagptr[flagoffset + gpustep * myrank + firstrank] = basecounter; + } + volatile int *flag = &localflag[gpustep * threadIdx.x + firstrank]; + while (*flag < basecounter) { + } + } + __syncthreads(); + + for (int nblock = 0; nblock < numblocks + headstart; nblock++) { + if (nblock < numblocks) { + // RS happens here + SMBAR(op * 2 * NVTE_MAX_SMS, nblock); + if (!blockIdx.x && !threadIdx.x) + hostflags[NVTE_HF_NVRSDONE + (op & 1)] = nblock + basecounter + 1; + } + + if (nblock >= headstart) { + for (int ibflag = threadIdx.x; ibflag < ibranks; ibflag += 32) + if (ibflag != myibrank) + while (localflag[NVTE_REG0_IBRS + ibflag] < basecounter + nblock - headstart + 1) { + } + asm volatile("bar.sync 13, %0;" ::"r"(blockDim.x)); + // REDUCE happens here + SMBAR(op * 2 * NVTE_MAX_SMS + NVTE_MAX_SMS, nblock - headstart); + if (!blockIdx.x && !threadIdx.x) + hostflags[NVTE_HF_NVREDUCEDONE + (op & 1)] = nblock + basecounter + 1 - headstart; + } + } + // final part doing NVAG based on responses from NIC-RMW:IBAG + + if (blockIdx.x == 0) { + for (int nblock = 0; nblock < numblocks; nblock++) { + const int expected = basecounter + nblock + 1; + for (int ibflag = threadIdx.x; ibflag < ibranks; ibflag += 32) + if (ibflag != myibrank) + while (localflag[NVTE_REG0_IBAG + ibflag] < expected) { + } + asm volatile("bar.sync 15, %0;" ::"r"(32)); + if (threadIdx.x < RANKS) + flagptr[flagoffset + gpustep * myrank + NVTE_MAX_NVLINK + firstrank] = expected; + } + } + + if (blockIdx.x == 0 && threadIdx.x == 0) gpuflag[NVTE_GF_STATE + op] = basecounter + numblocks; + } else { // sync warp + // reducethreads + const int warp = blockIdx.x + (threadIdx.x >> 5); + int4 *userptr[RANKS]; + int4 *userptrmyrank; +#pragma unroll + for (int i = 0; i < RANKS; i++) + userptr[i] = reinterpret_cast( + commbuff[((i + myrank + warp) & (RANKS - 1)) * gpustep + handleridx + firstrank]); + userptrmyrank = reinterpret_cast(commbuff[gpustep * myrank + handleridx + firstrank]); + int4 *internalbuf = reinterpret_cast(commbuff[myrank * gpustep + firstrank] + + commbufoffset * sizeof(int)); + __syncthreads(); + + int blocklineoffset = 0, rblocklineoffset = 0; + + for (int nblock = 0; nblock < numblocks + headstart; nblock++) { + // NVRS part(only first numblocks steps) + if (blocklineoffset < numlines) { + const int remainder = min(numlines - blocklineoffset, peerblocklines * RANKS); + const int blocklines = remainder / RANKS; + const int blockstart = lineoffset + blocklineoffset + blocklines * myrank; + if (RANKS > 1) { + for (int line = threadIdx.x - 32 + REDUCETHREADS * blockIdx.x; line < blocklines; + line += REDUCETHREADS * gridDim.x) { + int4 val[RANKS]; + +#pragma unroll + for (int i = 0; i < 
RANKS; i++) { + val[i] = userptr[i][blockstart + line]; + } + + int4 sum = val[0]; + half *s = reinterpret_cast(&sum); + +#pragma unroll + for (int i = 1; i < RANKS; i++) { + half *x = reinterpret_cast(&val[i]); +#pragma unroll + for (int j = 0; j < sizeof(int4) / sizeof(half); j++) s[j] += x[j]; + } + + userptrmyrank[blockstart + line] = sum; + } // single block loop + } + + asm volatile("bar.sync 13, %0;" ::"r"(REDUCETHREADS + 32)); + blocklineoffset += peerblocklines * RANKS; + } + if (nblock >= headstart) { +#define UNROLLRS 2 + const int remainder = min(numlines - rblocklineoffset, peerblocklines * RANKS); + const int blocklines = remainder / RANKS; + rblocklineoffset += peerblocklines * RANKS; + const int ibblocklines = blocklines / ibranks; + int4 *tempbufptr = &internalbuf[((nblock - headstart) % maxcredit) * peerblocklines]; + const int tempstart = lineoffset + (nblock - headstart) * peerblocklines * RANKS + + myrank * blocklines + ibblocklines * myibrank; + + asm volatile("bar.sync 13, %0;" ::"r"(REDUCETHREADS + 32)); + + for (int line = threadIdx.x - 32 + REDUCETHREADS * blockIdx.x; line < ibblocklines; + line += REDUCETHREADS * gridDim.x) { + int4 val[UNROLLRS]; + +#pragma unroll + for (int i = 0; i < UNROLLRS; i++) + val[i] = i == myibrank ? userptrmyrank[tempstart + line] + : tempbufptr[i * ibblocklines + line]; + + int4 sum = val[0]; + half *s = reinterpret_cast(&sum); + + for (int i = 0; i < ibranks - UNROLLRS; i++) { + val[i % UNROLLRS] = i == myibrank ? userptrmyrank[tempstart + line] + : tempbufptr[i * ibblocklines + line]; + half *x = reinterpret_cast(&val[(i + 1) % UNROLLRS]); +#pragma unroll + for (int j = 0; j < 16; j++) s[j] += x[j]; + } +#pragma unroll + for (int i = 1; i < UNROLLRS; i++) { + half *x = reinterpret_cast(&val[i]); +#pragma unroll + for (int j = 0; j < 16; j++) s[j] += x[j]; + } + userptrmyrank[tempstart + line] = sum; + } + + asm volatile("bar.sync 13, %0;" ::"r"(REDUCETHREADS + 32)); + } + } // nblock loop NVLINK-REDUCESCATTER + IBREDUCE LOCAL COMPUTE + + if (RANKS != 1) { + const int nwarps = (REDUCETHREADS >> 5) / (RANKS - 1); + const int myblockDim = nwarps << 5; + const int mywarp = ((threadIdx.x - 32) >> 5) / (RANKS - 1); + const int maxthreadIdx = myblockDim * (RANKS - 1) + 32; + const int mydest = (myrank + 1 + ((threadIdx.x - 32) >> 5) % (RANKS - 1)) & (RANKS - 1); + const int mythreadIdx = (mywarp << 5) + (threadIdx.x & 31); + volatile int *flag = (volatile int *)&((reinterpret_cast( + commbuff[gpustep * myrank + firstrank]))[flagoffset + gpustep * mydest + NVTE_MAX_NVLINK + + firstrank]); + + int4 *userptrmydest = userptr[((RANKS << 10) + mydest - myrank - warp) & (RANKS - 1)]; + + blocklineoffset = 0; + int gathercounter = basecounter + 1; + while (blocklineoffset < numlines) { + const int remainder = min(numlines - blocklineoffset, peerblocklines * RANKS); + const int blocklines = remainder / RANKS; + const int blockstart = lineoffset + blocklineoffset; + +#define UNROLL 6 + int4 *myptr = &userptrmyrank[blockstart + blocklines * mydest]; + int4 *peerptr = &userptrmydest[blockstart + blocklines * mydest]; + + if (threadIdx.x < maxthreadIdx) { + const int start_elem = mythreadIdx + myblockDim * blockIdx.x; + const int end_elem = max(start_elem, blocklines); + const int aligned_elem = ((end_elem - start_elem) / (myblockDim * gridDim.x * UNROLL)) * + (myblockDim * gridDim.x * UNROLL); + const int end_aligned = start_elem + aligned_elem; + + if (mythreadIdx == 0) { + while (*flag < gathercounter) { + } + gathercounter++; + } + + asm 
volatile("bar.sync %0, %1;" ::"r"(1 + mydest), "r"(myblockDim)); + + for (int line = start_elem; line < end_aligned; line += myblockDim * gridDim.x * UNROLL) { + int4 val[UNROLL]; +#pragma unroll + for (int i = 0; i < UNROLL; i++) val[i] = peerptr[line + i * myblockDim * gridDim.x]; +#pragma unroll + for (int i = 0; i < UNROLL; i++) myptr[line + i * myblockDim * gridDim.x] = val[i]; + } + for (int line = end_aligned; line < end_elem; line += myblockDim * gridDim.x) + myptr[line] = peerptr[line]; + } + blocklineoffset += peerblocklines * RANKS; + } // block loop for NVLINK-ALLGATHER + } // RANKS!=1 + } // worker warps else block +} // fp16 inplace reduce kernel with SHARP / in blocks + +template +__global__ void __launch_bounds__(MAX_THREADS) userbuffers_fp16_sum_inplace_gpu_rr_blocked2_rs( + const int op, const int maxcredit, const int headstart, const int myibrank, const int ibranks, + const int commbufoffset, const int flagoffset, const int firstrank, const int myrank, + const int gpustep, const int lineoffset, const int numlines, void **commbuff, + const int handleridx, const int peerblocklines, int *hostflags, int *gpuflag, + const int numblocks) { + const int basecounter = gpuflag[NVTE_GF_STATE + op]; + if (threadIdx.x < 32) { + int *flagptr; + volatile int *localflag = (volatile int *)&( + ((int *)commbuff[gpustep * myrank + firstrank])[flagoffset]); // NOLINT(*) + // initial intranode barrier - once + if (threadIdx.x < RANKS) { + if (!blockIdx.x) { + flagptr = reinterpret_cast(commbuff[gpustep * threadIdx.x + firstrank]); + flagptr[flagoffset + gpustep * myrank + firstrank] = basecounter; + } + volatile int *flag = &localflag[gpustep * threadIdx.x + firstrank]; + while (*flag < basecounter) { + } + } + __syncthreads(); + + for (int nblock = 0; nblock < numblocks + headstart; nblock++) { + if (nblock < numblocks) { + // RS happens here + SMBAR(op * 2 * NVTE_MAX_SMS, nblock); + if (!blockIdx.x && !threadIdx.x) + hostflags[NVTE_HF_NVRSDONE + (op & 1)] = nblock + basecounter + 1; + } + + if (nblock >= headstart) { + for (int ibflag = threadIdx.x; ibflag < ibranks; ibflag += 32) + if (ibflag != myibrank) + while (localflag[NVTE_REG0_IBRS + ibflag] < basecounter + nblock - headstart + 1) { + } + asm volatile("bar.sync 13, %0;" ::"r"(blockDim.x)); + // REDUCE happens here + SMBAR(op * 2 * NVTE_MAX_SMS + NVTE_MAX_SMS, nblock - headstart); + } + } + } else { // sync warp + // reducethreads + const int warp = blockIdx.x + (threadIdx.x >> 5); + int4 *userptr[RANKS]; + int4 *userptrmyrank; +#pragma unroll + for (int i = 0; i < RANKS; i++) + userptr[i] = reinterpret_cast( + commbuff[((i + myrank + warp) & (RANKS - 1)) * gpustep + handleridx + firstrank]); + userptrmyrank = reinterpret_cast(commbuff[gpustep * myrank + handleridx + firstrank]); + int4 *internalbuf = reinterpret_cast(commbuff[myrank * gpustep + firstrank] + + commbufoffset * sizeof(int)); + __syncthreads(); + + int blocklineoffset = 0, rblocklineoffset = 0; + + for (int nblock = 0; nblock < numblocks + headstart; nblock++) { + // NVRS part(only first numblocks steps) + if (blocklineoffset < numlines) { + const int remainder = min(numlines - blocklineoffset, peerblocklines * RANKS); + const int blocklines = remainder / RANKS; + const int blockstart = lineoffset + blocklineoffset + blocklines * myrank; + if (RANKS > 1) { + for (int line = threadIdx.x - 32 + REDUCETHREADS * blockIdx.x; line < blocklines; + line += REDUCETHREADS * gridDim.x) { + int4 val[RANKS]; + +#pragma unroll + for (int i = 0; i < RANKS; i++) { + val[i] = 
userptr[i][blockstart + line]; + } + + int4 sum = val[0]; + half *s = reinterpret_cast(&sum); + +#pragma unroll + for (int i = 1; i < RANKS; i++) { + half *x = reinterpret_cast(&val[i]); +#pragma unroll + for (int j = 0; j < sizeof(int4) / sizeof(half); j++) s[j] += x[j]; + } + + userptrmyrank[blockstart + line] = sum; + } // single block loop + } + + asm volatile("bar.sync 13, %0;" ::"r"(REDUCETHREADS + 32)); + blocklineoffset += peerblocklines * RANKS; + } + if (nblock >= headstart) { +#define UNROLLRS 2 + const int remainder = min(numlines - rblocklineoffset, peerblocklines * RANKS); + const int blocklines = remainder / RANKS; + rblocklineoffset += peerblocklines * RANKS; + const int ibblocklines = blocklines / ibranks; + int4 *tempbufptr = &internalbuf[((nblock - headstart) % maxcredit) * peerblocklines]; + const int tempstart = lineoffset + (nblock - headstart) * peerblocklines * RANKS + + myrank * blocklines + ibblocklines * myibrank; + // if(threadIdx.x==32) printf("[%d] block%d thread %d offset %d line %d ibblocklines %d ptr + // %lx commbufoffset + // %d\n",myrank,blockIdx.x,threadIdx.x,tempstart,0,ibblocklines,(void*)&tempbufptr[(1-myibrank)*ibblocklines],(1-myibrank)*ibblocklines*16); + + asm volatile("bar.sync 13, %0;" ::"r"(REDUCETHREADS + 32)); + + for (int line = threadIdx.x - 32 + REDUCETHREADS * blockIdx.x; line < ibblocklines; + line += REDUCETHREADS * gridDim.x) { + int4 val[UNROLLRS]; + +#pragma unroll + for (int i = 0; i < UNROLLRS; i++) + val[i] = i == myibrank ? userptrmyrank[tempstart + line] + : tempbufptr[i * ibblocklines + line]; + + int4 sum = val[0]; + half *s = reinterpret_cast(&sum); + + for (int i = 0; i < ibranks - UNROLLRS; i++) { + val[i % UNROLLRS] = i == myibrank ? userptrmyrank[tempstart + line] + : tempbufptr[i * ibblocklines + line]; + half *x = reinterpret_cast(&val[(i + 1) % UNROLLRS]); +#pragma unroll + for (int j = 0; j < 16; j++) s[j] += x[j]; + } +#pragma unroll + for (int i = 1; i < UNROLLRS; i++) { + half *x = reinterpret_cast(&val[i]); +#pragma unroll + for (int j = 0; j < 16; j++) s[j] += x[j]; + } + userptrmyrank[tempstart + line] = sum; + } + + asm volatile("bar.sync 13, %0;" ::"r"(REDUCETHREADS + 32)); + } + } // nblock loop NVLINK-REDUCESCATTER + IBREDUCE LOCAL COMPUTE + } // worker warps else block +} // fp16 inplace reduce kernel with SHARP / in blocks + +template +__global__ void __launch_bounds__(MAX_THREADS) userbuffers_fp16_sum_inplace_gpu_rr_blocked2_ag( + const int op, const int maxcredit, const int headstart, const int myibrank, const int ibranks, + const int commbufoffset, const int flagoffset, const int firstrank, const int myrank, + const int gpustep, const int lineoffset, const int numlines, void **commbuff, + const int handleridx, const int peerblocklines, int *hostflags, int *gpuflag, + const int numblocks) { + const int basecounter = gpuflag[NVTE_GF_STATE + op]; + if (threadIdx.x < 32) { + int *flagptr; + volatile int *localflag = (volatile int *)&( + ((int *)commbuff[gpustep * myrank + firstrank])[flagoffset]); // NOLINT(*) + if (threadIdx.x < RANKS) { + if (!blockIdx.x) { + flagptr = reinterpret_cast(commbuff[gpustep * threadIdx.x + firstrank]); + } + } + __syncthreads(); + if (!blockIdx.x && !threadIdx.x) + hostflags[NVTE_HF_NVREDUCEDONE + (op & 1)] = numblocks + basecounter; + // tell CPU proxy all blocks are done and ready for NVAG + + // final part doing NVAG based on responses from NIC-RMW:IBAG + + if (blockIdx.x == 0) { + for (int nblock = 0; nblock < numblocks; nblock++) { + const int expected = basecounter + 
nblock + 1; + for (int ibflag = threadIdx.x; ibflag < ibranks; ibflag += 32) + if (ibflag != myibrank) + while (localflag[NVTE_REG0_IBAG + ibflag] < expected) { + } + asm volatile("bar.sync 15, %0;" ::"r"(32)); + if (threadIdx.x < RANKS) + flagptr[flagoffset + gpustep * myrank + NVTE_MAX_NVLINK + firstrank] = expected; + } + } + + if (blockIdx.x == 0 && threadIdx.x == 0) gpuflag[NVTE_GF_STATE + op] = basecounter + numblocks; + } else { // sync warp + // reducethreads + const int warp = blockIdx.x + (threadIdx.x >> 5); + int4 *userptr[RANKS]; + int4 *userptrmyrank; +#pragma unroll + for (int i = 0; i < RANKS; i++) + userptr[i] = reinterpret_cast( + commbuff[((i + myrank + warp) & (RANKS - 1)) * gpustep + handleridx + firstrank]); + userptrmyrank = reinterpret_cast(commbuff[gpustep * myrank + handleridx + firstrank]); + __syncthreads(); + + int blocklineoffset = 0, rblocklineoffset = 0; + + if (RANKS != 1) { + const int nwarps = (REDUCETHREADS >> 5) / (RANKS - 1); + const int myblockDim = nwarps << 5; + const int mywarp = ((threadIdx.x - 32) >> 5) / (RANKS - 1); + const int maxthreadIdx = myblockDim * (RANKS - 1) + 32; + const int mydest = (myrank + 1 + ((threadIdx.x - 32) >> 5) % (RANKS - 1)) & (RANKS - 1); + const int mythreadIdx = (mywarp << 5) + (threadIdx.x & 31); + volatile int *flag = (volatile int *)&((reinterpret_cast( + commbuff[gpustep * myrank + firstrank]))[flagoffset + gpustep * mydest + NVTE_MAX_NVLINK + + firstrank]); + + int4 *userptrmydest = userptr[((RANKS << 10) + mydest - myrank - warp) & (RANKS - 1)]; + + blocklineoffset = 0; + int gathercounter = basecounter + 1; + while (blocklineoffset < numlines) { + const int remainder = min(numlines - blocklineoffset, peerblocklines * RANKS); + const int blocklines = remainder / RANKS; + const int blockstart = lineoffset + blocklineoffset; + +#define UNROLL 6 + int4 *myptr = &userptrmyrank[blockstart + blocklines * mydest]; + int4 *peerptr = &userptrmydest[blockstart + blocklines * mydest]; + + if (threadIdx.x < maxthreadIdx) { + const int start_elem = mythreadIdx + myblockDim * blockIdx.x; + const int end_elem = max(start_elem, blocklines); + const int aligned_elem = ((end_elem - start_elem) / (myblockDim * gridDim.x * UNROLL)) * + (myblockDim * gridDim.x * UNROLL); + const int end_aligned = start_elem + aligned_elem; + + if (mythreadIdx == 0) { + while (*flag < gathercounter) { + } + gathercounter++; + } + + asm volatile("bar.sync %0, %1;" ::"r"(1 + mydest), "r"(myblockDim)); + + for (int line = start_elem; line < end_aligned; line += myblockDim * gridDim.x * UNROLL) { + int4 val[UNROLL]; +#pragma unroll + for (int i = 0; i < UNROLL; i++) val[i] = peerptr[line + i * myblockDim * gridDim.x]; +#pragma unroll + for (int i = 0; i < UNROLL; i++) myptr[line + i * myblockDim * gridDim.x] = val[i]; + } + for (int line = end_aligned; line < end_elem; line += myblockDim * gridDim.x) + myptr[line] = peerptr[line]; + } + blocklineoffset += peerblocklines * RANKS; + } // block loop for NVLINK-ALLGATHER + } // RANKS!=1 + } // worker warps else block +} // fp16 inplace reduce kernel with SHARP / in blocks + +__global__ void userbuffers_fp16_sum_inplace_gpu_null(const int op, int *hostflags, int *gpuflag, + int numblocks) { + const int basecounter = gpuflag[NVTE_GF_STATE + op] + numblocks; + hostflags[0] = basecounter; + gpuflag[NVTE_GF_STATE + op] = basecounter; + while (((volatile int *)gpuflag)[NVTE_GF_IBSHARPDONE] < basecounter) { + } +} + +#define callranks_block(x) \ + if (comm->ar_nvsize == x) \ + 
userbuffers_fp16_sum_inplace_gpu_rr_blocked<<>>( \ + userbuffers_allreduceop_sharp, NVTE_REG0_OFFSET(comm), comm->ar_firstgpu, comm->ar_nvrank, \ + offset / 8, elements / 8, reinterpret_cast(comm->gpu_ptrs), \ + handler * comm->nvsize, blocksize / sizeof(int4) / comm->ar_nvsize, \ + reinterpret_cast(comm->hostflags), comm->flags, \ + (elements * 2 + blocksize - 1) / blocksize); + +#define callranks2_block(x) \ + if (ar_nvsize == x) { \ + int numblocks = (elements * 2 + blocksize - 1) / blocksize; \ + int headstart = numblocks - 1; /*<3?numblocks-1:3;*/ \ + if (headstart > maxcredit) headstart = maxcredit; \ + if (x == 1) headstart = maxcredit; \ + if (headstart > numblocks) headstart = numblocks; \ + if (headstart == 0) headstart = 1; \ + userbuffers_fp16_sum_inplace_gpu_rr_blocked2<<>>( \ + op, maxcredit, headstart, my_node, num_nodes, \ + NVTE_REG0_OFFSET(comm) + NVTE_REG0_FLAGS + \ + (op == userbuffers_allreduceop_nonsharp ? NVTE_REG0_COMMBUFFER : 0), \ + NVTE_REG0_OFFSET(comm) + NVTE_REG0_OPFLAGS * op, ar_firstgpu, ar_nvrank, ar_step, \ + offset / 8, elements / 8, reinterpret_cast(comm->gpu_ptrs), \ + handler * comm->nvsize, blocksize / sizeof(int4) / ar_nvsize, \ + reinterpret_cast(comm->hostflags), comm->flags, numblocks); \ + } + +#define callranks2_block_rs(x) \ + if (ar_nvsize == x) { \ + int numblocks = (elements * 2 + blocksize - 1) / blocksize; \ + int headstart = numblocks - 1; /*<3?numblocks-1:3;*/ \ + if (headstart > maxcredit) headstart = maxcredit; \ + if (x == 1) headstart = maxcredit; \ + if (headstart > numblocks) headstart = numblocks; \ + if (headstart == 0) headstart = 1; \ + userbuffers_fp16_sum_inplace_gpu_rr_blocked2_rs<<>>( \ + op, maxcredit, headstart, my_node, num_nodes, \ + NVTE_REG0_OFFSET(comm) + NVTE_REG0_FLAGS + \ + (op == userbuffers_allreduceop_nonsharp ? NVTE_REG0_COMMBUFFER : 0), \ + NVTE_REG0_OFFSET(comm) + NVTE_REG0_OPFLAGS * op, ar_firstgpu, ar_nvrank, ar_step, \ + offset / 8, elements / 8, reinterpret_cast(comm->gpu_ptrs), \ + handler * comm->nvsize, blocksize / sizeof(int4) / ar_nvsize, \ + reinterpret_cast(comm->hostflags), comm->flags, numblocks); \ + } + +#define callranks2_block_ag(x) \ + if (ar_nvsize == x) { \ + int numblocks = (elements * 2 + blocksize - 1) / blocksize; \ + int headstart = numblocks - 1; /*<3?numblocks-1:3;*/ \ + if (headstart > maxcredit) headstart = maxcredit; \ + if (x == 1) headstart = maxcredit; \ + if (headstart > numblocks) headstart = numblocks; \ + if (headstart == 0) headstart = 1; \ + userbuffers_fp16_sum_inplace_gpu_rr_blocked2_ag<<>>( \ + op, maxcredit, headstart, my_node, num_nodes, \ + NVTE_REG0_OFFSET(comm) + NVTE_REG0_FLAGS + \ + (op == userbuffers_allreduceop_nonsharp ? NVTE_REG0_COMMBUFFER : 0), \ + NVTE_REG0_OFFSET(comm) + NVTE_REG0_OPFLAGS * op, ar_firstgpu, ar_nvrank, ar_step, \ + offset / 8, elements / 8, reinterpret_cast(comm->gpu_ptrs), \ + handler * comm->nvsize, blocksize / sizeof(int4) / ar_nvsize, \ + reinterpret_cast(comm->hostflags), comm->flags, numblocks); \ + } + +#define callranks(x) \ + if (ar_nvsize == x) { \ + int arg1 = op - NVTE_MAX_OPS, \ + arg2 = NVTE_REG0_OFFSET(comm) - \ + (op == userbuffers_allreduceop_nonsharp ? 
2 : 1) * NVTE_REG0_SINGLENODE + \
+               NVTE_MAX_OPS, \
+        arg3 = ar_firstgpu, arg4 = ar_nvrank, arg5 = ar_step, arg6 = offset / 8, \
+        arg7 = elements / 8; \
+    void **arg8 = reinterpret_cast<void **>(comm->gpu_ptrs); \
+    int arg9 = handler * comm->nvsize; \
+    void *kernelArgs[] = {reinterpret_cast<void *>(&arg1), reinterpret_cast<void *>(&arg2), \
+                          reinterpret_cast<void *>(&arg3), reinterpret_cast<void *>(&arg4), \
+                          reinterpret_cast<void *>(&arg5), reinterpret_cast<void *>(&arg6), \
+                          reinterpret_cast<void *>(&arg7), reinterpret_cast<void *>(&arg8), \
+                          reinterpret_cast<void *>(&arg9)}; \
+    CUDACHECK(cudaLaunchKernelExC( \
+        &cfg, \
+        reinterpret_cast<void *>(comm->use_rr_kernel ? userbuffers_fp16_sum_inplace_gpu_rr \
+                                                     : userbuffers_fp16_sum_inplace_gpu_rw), \
+        kernelArgs)); \
+  }
+
+#define SETUP_LAUNCH_CONFIG(sms, threads, stream) \
+  cudaLaunchConfig_t cfg = {sms, threads, 0, stream, NULL, 0}; \
+  cudaLaunchAttribute attribute_ub[2]; \
+  attribute_ub[1].id = cudaLaunchAttributeClusterDimension; \
+  attribute_ub[1].val.clusterDim.x = sms % comm->cga_size == 0 ? comm->cga_size : 1; \
+  attribute_ub[1].val.clusterDim.y = 1; \
+  attribute_ub[1].val.clusterDim.z = 1; \
+  attribute_ub[0].id = cudaLaunchAttributeCooperative; \
+  cfg.attrs = attribute_ub; \
+  cfg.numAttrs = comm->sm_arch >= 9 ? 2 : 1;
+
+int allreduce_userbuff_inplace_gpu(const int handler, const int offset, const int elements,
+                                   const int blocksize, communicator *comm, cudaStream_t stream) {
+  // schedule GPU kernel only
+  // CPU/SHARP part is responsibility of caller
+  const int ar_step = comm->ar2_nvsize;
+  const int op = userbuffers_allreduceop_nonsharp;
+  const int ar_nvsize = comm->nvsize;
+  const int ar_firstgpu = comm->ar_firstgpu;
+  const int ar_nvrank = comm->ar_nvrank;
+  if (elements < 8) return 0;
+  int sms = comm->sms;
+  int warps = comm->threads / 32;
+  if (warps < comm->ar_nvsize) warps = comm->ar_nvsize;
+
+  if (comm->launch_mode & NVTE_LAUNCH_GPU) {
+    if (comm->ar_nvsize == 1)
+      userbuffers_fp16_sum_inplace_gpu_null<<<1, 1, 0, stream>>>(
+          userbuffers_allreduceop_sharp, reinterpret_cast<int *>(comm->hostflags), comm->flags,
+          (elements * 2 + blocksize - 1) / blocksize);
+    callranks_block(2) callranks_block(4) callranks_block(8)
+  }
+  return sms;
+}
+
+int allreduce2_userbuff_inplace_gpu(const int maxcredit, const int handler, const int offset,
+                                    const int elements, const int blocksize, communicator *comm,
+                                    cudaStream_t stream, int op) {
+  // schedule GPU kernel only
+  // CPU/SHARP part is responsibility of caller
+  const int num_nodes = op == userbuffers_allreduceop_nonsharp ? comm->num_nodes : comm->num2_nodes;
+  const int my_node = op == userbuffers_allreduceop_nonsharp ? comm->my_node : comm->my2_node;
+  const int ar_firstgpu =
+      op == userbuffers_allreduceop_nonsharp ? comm->ar_firstgpu : comm->ar2_firstgpu;
+  const int ar_step = op == userbuffers_allreduceop_nonsharp2 ? 1 : comm->ar2_nvsize;
+  const int ar_nvsize = op == userbuffers_allreduceop_nonsharp ? comm->ar_nvsize : comm->ar2_nvsize;
+  const int ar_nvrank = op == userbuffers_allreduceop_nonsharp ? comm->ar_nvrank : comm->ar2_nvrank;
+
+  if (elements < 8) return 0;
+  int sms = ar_nvsize == 1 ?
2 : comm->sms; + int warps = comm->threads / 32; + if (warps < ar_nvsize) warps = ar_nvsize; + if (num_nodes > 1) { + callranks2_block(1) callranks2_block(2) callranks2_block(4) callranks2_block(8) + } else { + SETUP_LAUNCH_CONFIG(sms, warps * 32, stream); + callranks(2) callranks(4) callranks(8) + } + return sms; +} + +#define callranks_ag(x) \ + if (ar_nvsize == x) { \ + int arg1 = op - NVTE_MAX_OPS, \ + arg2 = NVTE_REG0_OFFSET(comm) - \ + (op == userbuffers_allreduceop_nonsharp ? 2 : 1) * NVTE_REG0_SINGLENODE + \ + NVTE_MAX_OPS, \ + arg3 = ar_firstgpu, arg4 = ar_nvrank, arg5 = ar_step, arg7 = elements / 8 / x, \ + arg6 = offset / 8 + (comm->use_rr_kernel ? 0 : arg4 * arg7); \ + void **arg8 = reinterpret_cast(comm->gpu_ptrs); \ + int arg9 = handler * comm->nvsize; \ + void *kernelArgs[] = {reinterpret_cast(&arg1), reinterpret_cast(&arg2), \ + reinterpret_cast(&arg3), reinterpret_cast(&arg4), \ + reinterpret_cast(&arg5), reinterpret_cast(&arg6), \ + reinterpret_cast(&arg7), reinterpret_cast(&arg8), \ + reinterpret_cast(&arg9)}; \ + CUDACHECK(cudaLaunchKernelExC( \ + &cfg, \ + reinterpret_cast(comm->use_rr_kernel ? userbuffers_fp16_sum_inplace_gpu_rr_ag \ + : userbuffers_fp16_sum_inplace_gpu_rw_ag), \ + kernelArgs)); \ + } + +#define callranks_rs(x) \ + if (ar_nvsize == x) { \ + int arg1 = op - NVTE_MAX_OPS, \ + arg2 = NVTE_REG0_OFFSET(comm) - \ + (op == userbuffers_allreduceop_nonsharp ? 2 : 1) * NVTE_REG0_SINGLENODE + \ + NVTE_MAX_OPS, \ + arg3 = ar_firstgpu, arg4 = ar_nvrank, arg5 = ar_step, arg7 = elements / 8 / x, \ + arg6 = offset / 8 + arg4 * arg7; \ + void **arg8 = reinterpret_cast(comm->gpu_ptrs); \ + int arg9 = handler * comm->nvsize; \ + void *kernelArgs[] = {reinterpret_cast(&arg1), reinterpret_cast(&arg2), \ + reinterpret_cast(&arg3), reinterpret_cast(&arg4), \ + reinterpret_cast(&arg5), reinterpret_cast(&arg6), \ + reinterpret_cast(&arg7), reinterpret_cast(&arg8), \ + reinterpret_cast(&arg9)}; \ + CUDACHECK(cudaLaunchKernelExC( \ + &cfg, reinterpret_cast(userbuffers_fp16_sum_inplace_gpu_rr_rs), kernelArgs)); \ + } + +#define callranks_rs_oop(x) \ + if (ar_nvsize == x) { \ + int arg1 = op - NVTE_MAX_OPS, \ + arg2 = NVTE_REG0_OFFSET(comm) - \ + (op == userbuffers_allreduceop_nonsharp ? 2 : 1) * NVTE_REG0_SINGLENODE + \ + NVTE_MAX_OPS, \ + arg3 = ar_firstgpu, arg4 = ar_nvrank, arg5 = ar_step, arg7 = elements / 8 / x, \ + arg6 = offset / 8 + arg4 * arg7, arg8 = rowelements / 8, arg9 = strideelements / 8; \ + void **arg10 = reinterpret_cast(comm->gpu_ptrs); \ + int arg11 = handler * comm->nvsize; \ + void *arg12 = output; \ + void *kernelArgs[] = {reinterpret_cast(&arg1), reinterpret_cast(&arg2), \ + reinterpret_cast(&arg3), reinterpret_cast(&arg4), \ + reinterpret_cast(&arg5), reinterpret_cast(&arg6), \ + reinterpret_cast(&arg7), reinterpret_cast(&arg8), \ + reinterpret_cast(&arg9), reinterpret_cast(&arg10), \ + reinterpret_cast(&arg11), reinterpret_cast(&arg12)}; \ + CUDACHECK(cudaLaunchKernelExC( \ + &cfg, reinterpret_cast(userbuffers_fp16_sum_inplace_gpu_rr_rs_oop), \ + kernelArgs)); \ + } + +int reducescatter2_userbuff_inplace_gpu(const int maxcredit, const int handler, const int offset, + const int elements, const int blocksize, communicator *comm, + cudaStream_t stream, int op) { + // schedule GPU kernel only + // CPU/SHARP part is responsibility of caller + + const int num_nodes = op == userbuffers_allreduceop_nonsharp ? comm->num_nodes : comm->num2_nodes; + const int my_node = op == userbuffers_allreduceop_nonsharp ? 
comm->my_node : comm->my2_node; + const int ar_firstgpu = + op == userbuffers_allreduceop_nonsharp ? comm->ar_firstgpu : comm->ar2_firstgpu; + const int ar_step = op == userbuffers_allreduceop_nonsharp2 ? 1 : comm->ar2_nvsize; + const int ar_nvsize = op == userbuffers_allreduceop_nonsharp ? comm->ar_nvsize : comm->ar2_nvsize; + const int ar_nvrank = op == userbuffers_allreduceop_nonsharp ? comm->ar_nvrank : comm->ar2_nvrank; + + if (elements < 8) return 0; + int sms = ar_nvsize == 1 ? 2 : comm->sms; + int warps = comm->threads / 32; + if (warps < ar_nvsize) warps = ar_nvsize; + + if (num_nodes > 1) { + callranks2_block_rs(1) callranks2_block_rs(2) callranks2_block_rs(4) callranks2_block_rs(8) + } else { + SETUP_LAUNCH_CONFIG(sms, warps * 32, stream); + callranks_rs(2) callranks_rs(4) callranks_rs(8) + } + return sms; +} + +int allgather2_userbuff_inplace_gpu(const int maxcredit, const int handler, const int offset, + const int elements, const int blocksize, communicator *comm, + cudaStream_t stream, int op) { + // schedule GPU kernel only + // CPU/SHARP part is responsibility of caller + + const int num_nodes = op == userbuffers_allreduceop_nonsharp ? comm->num_nodes : comm->num2_nodes; + const int my_node = op == userbuffers_allreduceop_nonsharp ? comm->my_node : comm->my2_node; + const int ar_firstgpu = + op == userbuffers_allreduceop_nonsharp ? comm->ar_firstgpu : comm->ar2_firstgpu; + const int ar_step = op == userbuffers_allreduceop_nonsharp2 ? 1 : comm->ar2_nvsize; + const int ar_nvsize = op == userbuffers_allreduceop_nonsharp ? comm->ar_nvsize : comm->ar2_nvsize; + const int ar_nvrank = op == userbuffers_allreduceop_nonsharp ? comm->ar_nvrank : comm->ar2_nvrank; + + if (elements < 8) return 0; + int sms = ar_nvsize == 1 ? 2 : comm->sms; + int warps = comm->threads / 32; + if (warps < ar_nvsize) warps = ar_nvsize; + + if (num_nodes > 1) { + callranks2_block_ag(1) callranks2_block_ag(2) callranks2_block_ag(4) callranks2_block_ag(8) + } else { + SETUP_LAUNCH_CONFIG(sms, warps * 32, stream); + callranks_ag(2) callranks_ag(4) callranks_ag(8) + } + return sms; +} + +void allgather2_userbuff_inplace(const int handler, const int offset, const int elements, + communicator *comm, cudaStream_t stream) { + const int op = userbuffers_allreduceop_nonsharp2; + const int blocksize = elements * 2; + const int ar_firstgpu = + op == userbuffers_allreduceop_nonsharp ? comm->ar_firstgpu : comm->ar2_firstgpu; + const int ar_step = op == userbuffers_allreduceop_nonsharp2 ? 1 : comm->ar2_nvsize; + const int ar_nvsize = op == userbuffers_allreduceop_nonsharp ? comm->ar_nvsize : comm->ar2_nvsize; + const int ar_nvrank = op == userbuffers_allreduceop_nonsharp ? comm->ar_nvrank : comm->ar2_nvrank; + + if (elements < 64) return; + int sms = ar_nvsize == 1 ? 2 : comm->sms; + int warps = comm->threads / 32; + if (warps < ar_nvsize) warps = ar_nvsize; + + SETUP_LAUNCH_CONFIG(sms, warps * 32, stream); + callranks_ag(2) callranks_ag(4) callranks_ag(8) +} + +void allgather2_userbuff_inplace_sliced(const int handler, const int offset, const int elements, + communicator *comm, const int slice_id, const int nslices, + cudaStream_t stream) { + const int op = userbuffers_allreduceop_nonsharp2; + const int ar_nvrank = op == userbuffers_allreduceop_nonsharp ? comm->ar_nvrank : comm->ar2_nvrank; + const int ar_nvsize = op == userbuffers_allreduceop_nonsharp ? 
comm->ar_nvsize : comm->ar2_nvsize; + int peerelements = elements / ar_nvsize; + int saverrkernel = comm->use_rr_kernel; + comm->use_rr_kernel = 0; + allgather2_userbuff_inplace( + handler, offset + ar_nvrank * peerelements * (nslices - 1) + slice_id * peerelements, + elements, comm, stream); + comm->use_rr_kernel = saverrkernel; +} + +void reducescatter2_userbuff_inplace(const int handler, const int offset, const int elements, + communicator *comm, cudaStream_t stream) { + const int op = userbuffers_allreduceop_nonsharp2; + const int blocksize = elements * 2; + const int ar_firstgpu = + op == userbuffers_allreduceop_nonsharp ? comm->ar_firstgpu : comm->ar2_firstgpu; + const int ar_step = op == userbuffers_allreduceop_nonsharp2 ? 1 : comm->ar2_nvsize; + const int ar_nvsize = op == userbuffers_allreduceop_nonsharp ? comm->ar_nvsize : comm->ar2_nvsize; + const int ar_nvrank = op == userbuffers_allreduceop_nonsharp ? comm->ar_nvrank : comm->ar2_nvrank; + + if (elements < 64) return; + int sms = ar_nvsize == 1 ? 2 : comm->sms; + int warps = comm->threads / 32; + if (warps < ar_nvsize) warps = ar_nvsize; + + SETUP_LAUNCH_CONFIG(sms, warps * 32, stream); + callranks_rs(2) callranks_rs(4) callranks_rs(8) +} +void reducescatter2_userbuff_stridedoutput(void *output, const int handler, const int offset, + const int rowelements, const int colelements, + const int strideelements, communicator *comm, + cudaStream_t stream) { + const int elements = rowelements * colelements; + const int op = userbuffers_allreduceop_nonsharp2; + const int blocksize = elements * 2; + const int ar_firstgpu = + op == userbuffers_allreduceop_nonsharp ? comm->ar_firstgpu : comm->ar2_firstgpu; + const int ar_step = op == userbuffers_allreduceop_nonsharp2 ? 1 : comm->ar2_nvsize; + const int ar_nvsize = op == userbuffers_allreduceop_nonsharp ? comm->ar_nvsize : comm->ar2_nvsize; + const int ar_nvrank = op == userbuffers_allreduceop_nonsharp ? comm->ar_nvrank : comm->ar2_nvrank; + + if (elements < 64) return; + int sms = ar_nvsize == 1 ? 
2 : comm->sms; + int warps = comm->threads / 32; + if (warps < ar_nvsize) warps = ar_nvsize; + + SETUP_LAUNCH_CONFIG(sms, warps * 32, stream); + callranks_rs_oop(2) callranks_rs_oop(4) callranks_rs_oop(8) +} +void reducescatter2_userbuff(void *output, const int handler, const int offset, const int elements, + communicator *comm, cudaStream_t stream) { + reducescatter2_userbuff_stridedoutput(output, handler, offset, elements, 1, 0, comm, stream); +} + +__global__ void kuserbuffers_pullsend(int myrank, int peer, int *send_id, int *flagptr) { + atomicAdd(flagptr, 1); +} + +__global__ void kuserbuffers_inc(int *id) { + const int signal_id = (*id) + 1; + *id = signal_id; +} + +__global__ void kuserbuffers_proxysend(int *id, int *hostflag) { + const int signal_id = (*id) + 1; + *hostflag = signal_id; + *id = signal_id; +} + +__global__ void kuserbuffers_dummy(void) {} + +__global__ void __launch_bounds__(MAX_THREADS) + kuserbuffers_pullrecv(int myrank, int peer, int *recv_id, int *flagptr, int4 *srcptr, + int4 *dstptr, const int lines) { +#define UNROLLCOPY 8 + const int start_elem = threadIdx.x + blockDim.x * blockIdx.x; + const int end_elem = lines; + const int aligned_elem = (end_elem - start_elem) & (~(blockDim.x * gridDim.x * UNROLLCOPY - 1)); + const int end_aligned = start_elem + aligned_elem; + + if (threadIdx.x == 0) { + const int signal_id = (*recv_id) + 1; + volatile int *flag = (volatile int *)flagptr; + clock_t s = clock64(); + while (*flag < signal_id) { + if (clock64() - s > TIMEOUT) { + printf("[%d from %d] pullrecv: expected %d, stuck with %d\n", myrank, peer, signal_id, + *flag); + break; + } + } + if (lines == 0) { + *recv_id = signal_id; + return; + } // otherwise need an extra kernel + } + __syncthreads(); + + if (end_elem <= start_elem) return; + + for (int line = start_elem; line < end_aligned; line += blockDim.x * gridDim.x * UNROLLCOPY) { + int4 val[UNROLLCOPY]; +#pragma unroll + for (int i = 0; i < UNROLLCOPY; i++) val[i] = srcptr[line + i * blockDim.x * gridDim.x]; +#pragma unroll + for (int i = 0; i < UNROLLCOPY; i++) dstptr[line + i * blockDim.x * gridDim.x] = val[i]; + } + for (int line = end_aligned; line < end_elem; line += blockDim.x * gridDim.x) + dstptr[line] = srcptr[line]; +} + +__global__ void __launch_bounds__(MAX_THREADS) + kuserbuffers_pushsend(int *send_id, int *flagptr, int4 *srcptr, int4 *dstptr, const int lines) { + if (lines) { + const int start_elem = threadIdx.x + blockDim.x * blockIdx.x; + const int end_elem = lines; + const int aligned_elem = + ((end_elem - start_elem) & (~(blockDim.x * gridDim.x * UNROLLCOPY - 1))); + const int end_aligned = start_elem + aligned_elem; + if (end_elem > start_elem) { + for (int line = start_elem; line < end_aligned; line += blockDim.x * gridDim.x * UNROLLCOPY) { + int4 val[UNROLLCOPY]; +#pragma unroll + for (int i = 0; i < UNROLLCOPY; i++) val[i] = srcptr[line + i * blockDim.x * gridDim.x]; +#pragma unroll + for (int i = 0; i < UNROLLCOPY; i++) dstptr[line + i * blockDim.x * gridDim.x] = val[i]; + } + for (int line = end_aligned; line < end_elem; line += blockDim.x * gridDim.x) + dstptr[line] = srcptr[line]; + } + __syncthreads(); + if (threadIdx.x) return; + __threadfence_system(); + atomicAdd(flagptr, 1); // otherwise need local SM sync before sending flag + } else { // 0 bytes and 1 SM only + atomicAdd(flagptr, 1); + } +} + +__global__ void kuserbuffers_pushrecv(int myrank, int peer, int *recv_id, int *flagptr, int adder) { + const int signal_id = (*recv_id) + adder; + *recv_id = signal_id; + volatile int *flag 
= (volatile int *)flagptr;
+  if (*flag >= signal_id) return;
+  clock_t s = clock64();
+  while (*flag < signal_id) {
+    if (clock64() - s > TIMEOUT) {
+      printf("[%d from %d] pushrecv: expected %d, stuck with %d\n", myrank, peer, signal_id, *flag);
+      return;
+    }
+  }
+}
+
+#define CUDACHECK(cmd) \
+  do { \
+    cudaError_t e = cmd; \
+    if (e != cudaSuccess) { \
+      printf("Failed: Cuda error %s:%d '%s'\n", __FILE__, __LINE__, cudaGetErrorString(e)); \
+      exit(EXIT_FAILURE); \
+    } \
+  } while (0)
+
+#define INTRANODE(peer) ((peer / comm->nvsize) == (comm->myrank / comm->nvsize))
+
+void userbuffers_send(const int srchandler, const size_t srcoffset, const int dsthandler,
+                      const size_t dstoffset, const size_t bytes, communicator *comm,
+                      const int peer, cudaStream_t stream) {
+  int peerlocal = peer % comm->nvsize;
+  void *flagptr =
+      (comm->peer_ptr[0][peerlocal]) +
+      ((NVTE_REG0_OFFSET(comm) + NVTE_REG0_RECV + comm->myrank * NVTE_MAX_REGIONS + dsthandler) *
+       sizeof(int));
+  bool signalonly = (bytes / 16 == 0) || (comm->use_ce != 0);
+  bool intranode = INTRANODE(peer);
+  if (!intranode && (comm->launch_mode & NVTE_LAUNCH_CPU)) {
+    comm->fifo[comm->head].optype = userbuffers_sendop;
+    comm->fifo[comm->head].basecounter = comm->basecounter[userbuffers_sendop];
+    comm->fifo[comm->head].handler = srchandler;
+    comm->fifo[comm->head].offset = srcoffset;
+    comm->fifo[comm->head].handler2 = dsthandler;
+    comm->fifo[comm->head].offset2 = dstoffset;
+    comm->fifo[comm->head].elements = bytes;
+    comm->fifo[comm->head].peer = peer;
+
+    int newhead = (comm->head + 1) & (NVTE_MAX_REQUESTS - 1);
+    while (newhead == comm->tail) {
+    }
+    comm->head = newhead;
+    comm->basecounter[userbuffers_sendop] += 1;
+  }
+  if (!intranode && (comm->launch_mode & NVTE_LAUNCH_GPU)) {
+    kuserbuffers_proxysend<<<1, 1, 0, stream>>>(&(comm->flags[NVTE_GF_STATE + userbuffers_sendop]),
+                                                comm->hostflags + userbuffers_sendop);
+    return;
+  }
+  if (!(comm->launch_mode & NVTE_LAUNCH_GPU)) return;
+  if (comm->push == 0) {
+    kuserbuffers_pullsend<<<1, 1, 0, stream>>>(comm->myrank, peer, &(comm->send_id[peer]),
+                                               reinterpret_cast<int *>(flagptr));
+  } else {
+    void *srcptr = (comm->mem_ptr[srchandler]) + srcoffset;
+    void *dstptr = (comm->peer_ptr[dsthandler][peerlocal]) + dstoffset;
+
+    if (comm->use_ce)
+      CUDACHECK(cudaMemcpyAsync(dstptr, srcptr, bytes, cudaMemcpyDeviceToDevice, stream));
+    SETUP_LAUNCH_CONFIG(signalonly ? 1 : comm->sms, signalonly ? 1 : 1024, stream);
+    int *arg1 = &comm->send_id[peer], *arg2 = reinterpret_cast<int *>(flagptr);
+    int4 *arg3 = reinterpret_cast<int4 *>(srcptr), *arg4 = reinterpret_cast<int4 *>(dstptr);
+    int arg5 = signalonly ?
0 : bytes / 16; + void *kernelArgs[] = {reinterpret_cast(&arg1), reinterpret_cast(&arg2), + reinterpret_cast(&arg3), reinterpret_cast(&arg4), + reinterpret_cast(&arg5)}; + CUDACHECK( + cudaLaunchKernelExC(&cfg, reinterpret_cast(kuserbuffers_pushsend), kernelArgs)); + } +} + +__global__ void __launch_bounds__(MAX_THREADS) + kuserbuffers_alltoall(void **baseflagptrs, int flagoffset, int4 *basesrcptr, void **dstptrs, + size_t dstoffset, const int lines, const int myrank) { + if (blockIdx.x == myrank) return; + int4 *dstptr = reinterpret_cast(dstptrs[blockIdx.x] + dstoffset); + int *flagptr = reinterpret_cast(baseflagptrs[blockIdx.x] + flagoffset); + const size_t myblockoffset = blockIdx.x * lines; + int4 *srcptr = basesrcptr + myblockoffset; + dstptr += myblockoffset; + + if (lines) { + const int start_elem = threadIdx.x; + const int end_elem = lines; + const int aligned_elem = ((end_elem - start_elem) & (~(blockDim.x * UNROLLCOPY - 1))); + const int end_aligned = start_elem + aligned_elem; + if (end_elem > start_elem) { + for (int line = start_elem; line < end_aligned; line += blockDim.x * UNROLLCOPY) { + int4 val[UNROLLCOPY]; +#pragma unroll + for (int i = 0; i < UNROLLCOPY; i++) val[i] = srcptr[line + i * blockDim.x]; +#pragma unroll + for (int i = 0; i < UNROLLCOPY; i++) dstptr[line + i * blockDim.x] = val[i]; + } + for (int line = end_aligned; line < end_elem; line += blockDim.x) dstptr[line] = srcptr[line]; + } + __syncthreads(); + if (threadIdx.x) return; + __threadfence_system(); + atomicAdd(flagptr, 1); + + } else { + atomicAdd(flagptr, 1); + } +} + +void userbuffers_alltoall_send(const int srchandler, const size_t srcoffset, const int dsthandler, + const size_t dstoffset, const size_t bytes, communicator *comm, + cudaStream_t stream) { + if (comm->launch_mode & NVTE_LAUNCH_CPU) { + comm->fifo[comm->head].optype = userbuffers_alltoall; + comm->fifo[comm->head].basecounter = comm->basecounter[userbuffers_alltoall]; + comm->fifo[comm->head].handler = srchandler; + comm->fifo[comm->head].offset = srcoffset; + comm->fifo[comm->head].handler2 = dsthandler; + comm->fifo[comm->head].offset2 = dstoffset; + comm->fifo[comm->head].elements = bytes; + + int newhead = (comm->head + 1) & (NVTE_MAX_REQUESTS - 1); + while (newhead == comm->tail) { + } + comm->head = newhead; + comm->basecounter[userbuffers_alltoall] += 1; + } + if (comm->launch_mode & NVTE_LAUNCH_GPU) + kuserbuffers_proxysend<<<1, 1, 0, stream>>>( + &(comm->flags[NVTE_GF_STATE + userbuffers_alltoall]), + comm->hostflags + userbuffers_alltoall); +} + +void userbuffers_recv(const int srchandler, const size_t srcoffset, const int dsthandler, + const size_t dstoffset, const size_t bytes, communicator *comm, + const int peer, cudaStream_t stream) { + int peerlocal = peer % comm->nvsize; + void *flagptr = + (comm->mem_ptr[0]) + + ((NVTE_REG0_OFFSET(comm) + NVTE_REG0_RECV + peer * NVTE_MAX_REGIONS + dsthandler) * + sizeof(int)); + bool signalonly = (bytes / 16 == 0) || (comm->use_ce != 0); + bool intranode = INTRANODE(peer); + if (!(comm->launch_mode & NVTE_LAUNCH_GPU)) return; + if (comm->push == 0 && intranode) { + void *dstptr = (comm->mem_ptr[dsthandler]) + dstoffset; + void *srcptr = (comm->peer_ptr[srchandler][peerlocal]) + srcoffset; + + kuserbuffers_pullrecv<<sms, signalonly ? 1 : 1024, 0, stream>>>( + comm->myrank, peer, &(comm->recv_id[peer * NVTE_MAX_REGIONS + dsthandler]), + reinterpret_cast(flagptr), reinterpret_cast(srcptr), + reinterpret_cast(dstptr), signalonly ? 
0 : bytes / 16); + if (!signalonly) + kuserbuffers_inc<<<1, 1, 0, stream>>>(&(comm->recv_id[peer * NVTE_MAX_REGIONS + dsthandler])); + if (comm->use_ce) { + CUDACHECK(cudaMemcpyAsync(dstptr, srcptr, bytes, cudaMemcpyDeviceToDevice, stream)); + } + } else { + kuserbuffers_pushrecv<<<1, 1, 0, stream>>>( + comm->myrank, peer, &comm->recv_id[peer * NVTE_MAX_REGIONS + dsthandler], + reinterpret_cast(flagptr), signalonly || !intranode ? 1 : comm->sms); + } +} + +void userbuffers_alltoall_recv(communicator *comm, cudaStream_t stream) { + void *flagptr = + (comm->mem_ptr[0]) + + ((NVTE_REG0_OFFSET(comm) + NVTE_REG0_OPFLAGS * userbuffers_alltoall) * sizeof(int)); + + if (!(comm->launch_mode & NVTE_LAUNCH_GPU)) return; + kuserbuffers_pushrecv<<<1, 1, 0, stream>>>(comm->myrank, -1, reinterpret_cast(flagptr + 4), + reinterpret_cast(flagptr), comm->nranks - 1); +} diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index 1e28cec70e..a216799a5c 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -49,6 +49,7 @@ void cublas_gemm(const Tensor *inputA, size_t workspaceSize, bool accumulate, bool use_split_accumulator, + int math_sm_count, cudaStream_t stream ) { void *A = inputA->data.dptr; @@ -124,6 +125,13 @@ void cublas_gemm(const Tensor *inputA, &transa, sizeof(transa))); NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(transb))); + // Set math SM count + if (math_sm_count != 0) { + NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute( + operationDesc, CUBLASLT_MATMUL_DESC_SM_COUNT_TARGET, + &math_sm_count, sizeof(math_sm_count))); + } + // set fp8 attributes -- input and output types should already be set to fp8 as appropriate // Note: gelu fusion isn't available right now, and we don't need @@ -227,6 +235,7 @@ void cublas_gemm(const Tensor *inputA, if (returnedResults == 0) throw std::runtime_error("Unable to find any suitable algorithms"); // D = alpha * (A * B) + beta * C + NVTE_CHECK_CUBLAS(cublasLtMatmul(handle, operationDesc, static_cast(&one), /* alpha */ @@ -266,6 +275,7 @@ void nvte_cublas_gemm(const NVTETensor A, NVTETensor workspace, bool accumulate, bool use_split_accumulator, + int math_sm_count, cudaStream_t stream) { NVTE_API_CALL(nvte_cublas_gemm); using namespace transformer_engine; @@ -308,5 +318,6 @@ void nvte_cublas_gemm(const NVTETensor A, grad, wspace->data.dptr, wspace->data.shape[0], accumulate, use_split_accumulator, + math_sm_count, stream); } diff --git a/transformer_engine/common/include/transformer_engine/gemm.h b/transformer_engine/common/include/transformer_engine/gemm.h index 035f467adb..8cd549b658 100644 --- a/transformer_engine/common/include/transformer_engine/gemm.h +++ b/transformer_engine/common/include/transformer_engine/gemm.h @@ -36,6 +36,7 @@ extern "C" { * \param[out] workspace Workspace tensor. * \param[in] accumulate Whether to accumulate the result into the D matrix. * \param[in] use_split_accumulator Whether to use split accumulator in the FP8 GEMM. + * \param[in] math_sm_count Number of GPU SMs to use (default=0: use cuBLAS heuristics) * \param[in] stream CUDA stream used for the operation. 
*/ void nvte_cublas_gemm(const NVTETensor A, @@ -49,6 +50,7 @@ void nvte_cublas_gemm(const NVTETensor A, NVTETensor workspace, bool accumulate, bool use_split_accumulator, + int math_sm_count, cudaStream_t stream ); diff --git a/transformer_engine/common/include/transformer_engine/userbuffers.h b/transformer_engine/common/include/transformer_engine/userbuffers.h new file mode 100644 index 0000000000..cd5b1ec382 --- /dev/null +++ b/transformer_engine/common/include/transformer_engine/userbuffers.h @@ -0,0 +1,227 @@ +/************************************************************************* + * Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#ifndef TRANSFORMER_ENGINE_USERBUFFERS_H_ +#define TRANSFORMER_ENGINE_USERBUFFERS_H_ + +#include +#include +#include "cuda_runtime.h" +#include +#include +#include "gdrapi.h" +#include + +#define NVTE_MAX_REGIONS 16 +#define NVTE_MAX_SMS 32 +#define NVTE_MAX_OPS 32 +#define NVTE_MAX_PEERS 8192 +#define NVTE_MAX_REQUESTS 1024 +#define NVTE_LAUNCH_GPU 1 +#define NVTE_LAUNCH_CPU 2 +#define NVTE_MAX_NVLINK 8 + +// region 0 flag offsets +#define NVTE_REG0_OPFLAGS 1024 +#define NVTE_REG0_RECV (NVTE_REG0_OPFLAGS * userbuffers_op_types) +#define NVTE_REG0_SINGLENODE (2 * NVTE_MAX_NVLINK * NVTE_MAX_SMS + NVTE_MAX_OPS) +#define NVTE_REG0_OFFSET(comm) ((2 * NVTE_MAX_REGIONS) * NVTE_MAX_NVLINK \ + + NVTE_REG0_SINGLENODE * 2 + NVTE_MAX_PEERS) +#define NVTE_REG0_COMMBUFFER 0 +#define NVTE_REG0_FLAGS (NVTE_REG0_RECV + NVTE_MAX_PEERS * NVTE_MAX_REGIONS) +#define NVTE_REG0_IBRS 32 +#define NVTE_REG0_IBAG 512 +#undef NVTE_REG0_COMMBUFFER +#define NVTE_REG0_COMMBUFFER (1024 * 1024 * 16) + +// gpuflags map offsets +#define NVTE_GF_STATE 16000 +#define NVTE_GF_IBSHARPDONE 0 +#define NVTE_HF_NVRSDONE (userbuffers_op_types + 1) +#define NVTE_HF_NVREDUCEDONE (userbuffers_op_types + 3) +#define NVTE_MAX_SHARP 16 + +typedef struct ub_request { + int optype; + int blocksize; + int basecounter; + int elements; + int handler; + int handler2; + size_t offset; + size_t offset2; + int peer; + // ----execution states + int active, maxcredit; + int nblock, numblocks, unconfirmed_ib_in_flight; +} ub_request; + +enum req_type { + userbuffers_allreduceop_sharp, + userbuffers_sendop, + userbuffers_allreduceop_nonsharp, + userbuffers_allreduceop_nonsharp2, + userbuffers_alltoall, + userbuffers_op_types +}; + +struct communicator { + int myrank, nranks; // global job communicator + int nvrank, nvsize; // single node comm_intra + int free_region; + + int launch_mode; + + void *gpu_ptrs; + int sms, threads; + int use_rr_kernel; // Whether to use RR (or RW) for NVLink-only kernel + int cga_size; + int push, use_ce; + + void *mem_ptr[NVTE_MAX_REGIONS]; + void **peer_ptr[NVTE_MAX_REGIONS]; + int ar_nvsize, ar_firstgpu, + ar_nvrank; // number of gpus(and first gpu in a group) of gpus per node in reduction subgroup + // (_splitar init used) would be equal to (nvsize,0) for regular comm_create + int ar2_nvsize, ar2_firstgpu, ar2_nvrank; // with ar_nvsize as a step + int pipe_id; // which allreduce set of groups (pipeline rank in range of 0..pipeline_size) + int sm_arch; + int num_nodes, my_node, + first_node; // comm_inter communicator, per-rail allreduce (might have subset of nodes) + int num2_nodes, my2_node, first2_node; // with num_nodes as a stride + // max value for running block counters in hostflags + int basecounter[userbuffers_op_types]; 
// NOLINT(*)
+
+  int *hostflags;
+  int *flags, *map_flags;
+  gdr_t g;
+
+  struct sharp_coll_context *sharp_coll_context;
+  struct sharp_coll_comm *sharp_coll_comm;
+  void *mem_mr[NVTE_MAX_REGIONS];
+
+  ub_request *fifo;
+  volatile int activeproxy;
+  int nblocks, alignblock, minblock, asyncblocks, active_nreqs;
+  ub_request active_req[userbuffers_op_types];  // NOLINT(*)
+  int padding[7];
+  volatile int head;
+  int padding2[15];
+  volatile int tail;
+
+  MPI_Request mpihndl[NVTE_MAX_SHARP];
+  MPI_Comm comm_inter,  // reduction group communicator (subset of the nodes) along GPU rail
+      comm_intra;       // full intranode communicator (all ndev GPUs)
+  int ibnvsize;  // can be used to fake a smaller or larger NVLink domain to use IB instead of
+                 // NVLink, or to force MNNVL
+  int *send_id, *recv_id;
+  int mydev;
+};
+typedef struct communicator communicator;
+
+int create_communicator(communicator **comm);
+/* creates the communicator and allocates all internal buffers if necessary */
+
+int create_communicator_grouped(communicator **comm, int pipegpus, int pipenodes);
+int create_communicator_grouped2(communicator **comm, int pipegpus, int pipenodes, int tensorgpus,
+                                 int tensornodes);
+/* creates a communicator with
+    allreduce1 happening in datagpus x datanodes groups,
+    allreduce2 happening in tensorgpus x tensornodes groups,
+      where num_nodes = pipenodes x tensornodes x datanodes
+            nvlink_size = pipegpus x tensorgpus x datagpus
+ */
+
+// int check_user_buffer_registration(void* gpubuff, int bytes, communicator* comm, size_t* offset);
+/*
+  local call, does not communicate between peers;
+  returns the handler if the buffer is already registered, or -1 if not.
+  the returned offset is the offset of gpubuff relative to the registered buffer.
+*/
+
+int pipe_rank(communicator *comm,
+              int step);  // helper function to walk across allreduce1 x allreduce2 groups;
+                          // data-parallel and tensor-parallel positions within the data and
+                          // tensor groups are preserved
+
+int register_user_buffer_collective(void **gpubuff, size_t bytes, communicator *comm,
+                                    bool alloc = false);
+/* returns the handler and registers buffers. Assumed to be collective, i.e.
you use same groups and + dont mix buffers for different operations returns -1 if cant register (too many preregistered + regions already) if alloc==true will allocate memory and fill the pointers (required for NVL + SHARP and NSO/MNNVL) +*/ + +void allreduce_userbuff_inplace(const int handler, const int offset, const int elements, + communicator *comm, cudaStream_t stream = 0); +// for DP distributed optimizer, only nonSHARP multinode is implemented & calls must come in pairs +// ordered +void allgather_userbuff_inplace(const int handler, const int offset, const int elements, + communicator *comm, cudaStream_t stream = 0); +void reducescatter_userbuff_inplace(const int handler, const int offset, const int elements, + communicator *comm, cudaStream_t stream = 0); + +void allreduce2_userbuff_inplace(const int handler, const int offset, const int elements, + communicator *comm, cudaStream_t stream = 0); +// for TP-parallelism, only single node is implemented +void allgather2_userbuff_inplace(const int handler, const int offset, const int elements, + communicator *comm, cudaStream_t stream = 0); +void allgather2_userbuff_inplace_sliced(const int handler, const int offset, const int elements, + communicator *comm, const int slice_id, const int nslices, + cudaStream_t stream = 0); +/* +each Rank input is +allgather2_userbuff_inplace: offset+myrank*elements +allgather2_userbuff_inplace_sliced: offset+myrank*elements*nslices+slice_id*elements + +equivalent codes would be: +for(int slice=0;slice torch.Tensor: """TN layout GEMM with fp8 inputs.""" @@ -55,7 +58,7 @@ def fp8_gemm( out_dtype = TE_DType[out.dtype] if D_dtype is None else D_dtype - _ = torch.ops.tex_ts.te_gemm_ts( + args = ( A, A_scale_inv, A_fp8_tensor, @@ -77,8 +80,29 @@ def fp8_gemm( workspace, workspace.shape[0], accumulate, - use_split_accumulator, - ) + use_split_accumulator) + fn = torch.ops.tex_ts.te_gemm_ts + if ub_algo is not None: + assert ub is not None, 'ub object is None!' + if ub_algo == tex.UbufOverlapAlgo.BULK_OVERLAP_AG: + fn = ub.bulk_overlap + args = tuple(args + (1,)) + elif ub_algo == tex.UbufOverlapAlgo.BULK_OVERLAP_RS: + fn = ub.bulk_overlap + args = tuple(args + (0,)) + elif ub_algo == tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG: + fn = ub.split_overlap_ag + extra_output_tensor = ( + empty_tensor if extra_output_tensor is None else extra_output_tensor + ) + args = tuple(args + (extra_output_tensor,)) + elif ub_algo == tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS: + fn = ub.split_overlap_rs + assert ( + extra_output_tensor is not None + ), 'SPLIT_PIPELINED_RS requires extra output tensor' + args = tuple(args + (True, extra_output_tensor,)) + _ = fn(*args) if return_output: if gelu: @@ -102,6 +126,9 @@ def gemm( out: Optional[torch.Tensor] = None, bias: Optional[torch.Tensor] = None, use_bias: bool = False, + ub_algo: tex.UbufOverlapAlgo = None, + ub: tex.UbufCommOverlap = None, + extra_output_tensor: torch.Tensor = None, ) -> Tuple[Union[torch.Tensor, None], ...]: """Non FP8 GEMM.""" @@ -142,7 +169,7 @@ def gemm( else: bias_dtype = output_dtype - _ = torch.ops.tex_ts.te_gemm_ts( + args = ( A, empty_tensor, fp8_index, @@ -166,6 +193,28 @@ def gemm( accumulate, False, # use_split_accumulator ) + fn = torch.ops.tex_ts.te_gemm_ts + if ub_algo is not None: + assert ub is not None, 'ub object is None!' 
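+        # Dispatch to the userbuffers comm+GEMM overlap object instead of the
+        # plain te_gemm op: bulk overlap appends the communication type expected
+        # by the C++ side (1 = all-gather, 0 = reduce-scatter), while the split
+        # pipelined variants append the extra output tensor that the overlap
+        # implementation fills in (see comm_gemm_overlap.h).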
+ if ub_algo == tex.UbufOverlapAlgo.BULK_OVERLAP_AG: + fn = ub.bulk_overlap + args = tuple(args + (1,)) + elif ub_algo == tex.UbufOverlapAlgo.BULK_OVERLAP_RS: + fn = ub.bulk_overlap + args = tuple(args + (0,)) + elif ub_algo == tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG: + fn = ub.split_overlap_ag + extra_output_tensor = ( + empty_tensor if extra_output_tensor is None else extra_output_tensor + ) + args = tuple(args + (extra_output_tensor,)) + elif ub_algo == tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS: + fn = ub.split_overlap_rs + assert ( + extra_output_tensor is not None + ), 'SPLIT_PIPELINED_RS requires extra output tensor' + args = tuple(args + (False, extra_output_tensor,)) + _ = fn(*args) if return_output: return out, grad_bias, gelu_input @@ -283,9 +332,25 @@ def layernorm_fwd_fp8( fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors], otype: tex.DType, sm_margin: int, - zero_centered_gamma: bool + zero_centered_gamma: bool, + ln_out: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """LayerNorm with FP8 output""" + if ln_out is not None: + return tex.layernorm_fwd_fp8_noalloc( + inp, + weight, + bias, + eps, + fp8_meta_tensor.scale[fp8_tensor], + ln_out, + fp8_meta_tensor.amax_history[0][fp8_tensor], + fp8_meta_tensor.scale_inv[fp8_tensor], + otype, + sm_margin, + zero_centered_gamma + ) + return tex.layernorm_fwd_fp8( inp, weight, @@ -351,8 +416,20 @@ def cast_to_fp8( fp8_meta_tensor: tex.FP8TensorMeta, fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors], otype: tex.DType, -) -> torch.Tensor: + out: Optional[torch.Tensor] = None, +) -> Optional[torch.Tensor]: """Cast input to FP8""" + + if out is not None: + tex.cast_to_fp8_noalloc( + inp, + fp8_meta_tensor.scale[fp8_tensor], + out, + fp8_meta_tensor.amax_history[0][fp8_tensor], + fp8_meta_tensor.scale_inv[fp8_tensor], + otype + ) + return None return torch.ops.tex_ts.cast_to_fp8_ts( inp, fp8_meta_tensor.scale, diff --git a/transformer_engine/pytorch/csrc/comm_gemm_overlap.h b/transformer_engine/pytorch/csrc/comm_gemm_overlap.h new file mode 100644 index 0000000000..18863a7858 --- /dev/null +++ b/transformer_engine/pytorch/csrc/comm_gemm_overlap.h @@ -0,0 +1,579 @@ +/************************************************************************* + * Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. 
+ ************************************************************************/ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#define HALF_BYTES 2 + +#define CHECK_CUDA(call) \ + do { \ + cudaError_t status_ = call; \ + if (status_ != cudaSuccess) { \ + fprintf(stderr, "CUDA Error at line %d: %s\n", __LINE__, cudaGetErrorString(status_)); \ + exit(1); \ + } \ + } while (0) + +namespace ubuf { + +enum class COMM_TYPE { RS = 0, AG = 1 }; + +enum class UBOverlapAlgo { + BULK_OVERLAP_AG = 0, + BULK_OVERLAP_RS = 1, + SPLIT_PIPELINED_AG = 2, + SPLIT_PIPELINED_RS = 3 +}; + +struct UbufCommOverlap : torch::CustomClassHolder { + communicator *_ub_comm; + int _tp_id; + int _tp_size; + int _num_splits; + int _math_sms; + int _ub_reg; + void *_ubuf_ptr; + torch::Tensor _ubuf; + torch::Tensor output_tensor; + at::cuda::CUDAStream _stream_comm = at::cuda::getStreamFromPool(true); + std::vector _stream_compute; + cudaEvent_t _start_compute, _stop_compute, _start_d2dcopy, _start_comm, _stop_comm; + + UbufCommOverlap(torch::Tensor sample, int rank, int tp_size, int num_comm_sm, int comm_cga_size, + int num_splits, bool set_sm_margin, int num_max_streams) { + // Initialize userbuf communicator + create_communicator_grouped2(&_ub_comm, 1, 1, tp_size, 1); + _ub_comm->use_ce = 0; + _ub_comm->sms = num_comm_sm; + _ub_comm->cga_size = comm_cga_size; + + // Allocate and register extra userbuffers + int ubuf_bytes = sample.numel() * sample.element_size(); + _ub_reg = register_user_buffer_collective(reinterpret_cast(&_ubuf_ptr), ubuf_bytes, + _ub_comm, true); + _ubuf = torch::from_blob(_ubuf_ptr, {sample.size(0), sample.size(1)}, sample.options()); + + at::cuda::CUDAStream stream_main = at::cuda::getDefaultCUDAStream(); + for (int i = 0; i < std::min(num_max_streams, num_splits); i++) { + cudaStream_t stream; + cudaStreamCreateWithPriority(&stream, cudaStreamNonBlocking, -1); + _stream_compute.push_back( + at::cuda::getStreamFromExternal(stream, stream_main.device_index())); + } + + _num_splits = num_splits; + _tp_size = tp_size; + _tp_id = (rank % tp_size); + + // Set the number of SMs for GEMM with margin + cudaDeviceProp prop; + cudaGetDeviceProperties(&prop, 0); + _math_sms = (set_sm_margin) ? 
prop.multiProcessorCount - num_comm_sm : prop.multiProcessorCount; + + output_tensor = torch::Tensor(); + // CUDA event creation + cudaEventCreateWithFlags(&_start_compute, 0); + cudaEventCreateWithFlags(&_stop_compute, 0); + cudaEventCreateWithFlags(&_start_d2dcopy, 0); + cudaEventCreateWithFlags(&_start_comm, 0); + cudaEventCreateWithFlags(&_stop_comm, 0); + } + + /* + ** Bulk GEMM + COMM + ** This function assumes the communication input is pre-copied to _ubuf + */ + std::vector bulk_overlap( + at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor, + transformer_engine::DType A_type, bool transa, at::Tensor B, at::Tensor B_scale_inverse, + int64_t B_fp8_tensor, transformer_engine::DType B_type, bool transb, at::Tensor D, + at::Tensor D_scale, transformer_engine::DType D_type, at::Tensor D_amax, at::Tensor bias, + transformer_engine::DType bias_type, at::Tensor pre_gelu_out, bool grad, at::Tensor workspace, + size_t workspaceSize, bool accumulate, bool use_split_accumulator, int comm_type) { + // Get the current userbuf offset + char *ubuf_wt_ptr = reinterpret_cast(_ubuf.data_ptr()); + int comm_elements = (_ubuf.numel() / 2) * _ubuf.element_size(); // UBUF uses 2Byte element size + COMM_TYPE _comm_type = static_cast(comm_type); + if (_comm_type == COMM_TYPE::RS) { + ubuf_wt_ptr += _ubuf.numel() / _tp_size * _tp_id * _ubuf.element_size(); + } + + // Catch up the default torch stream + at::cuda::CUDAStream stream_main = at::cuda::getDefaultCUDAStream(); + CHECK_CUDA(cudaEventRecord(_start_comm, (cudaStream_t)stream_main)); + CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_comm, _start_comm, 0)); + + // Communication: AG and RS + if (_comm_type == COMM_TYPE::AG) { + allgather2_userbuff_inplace(_ub_reg, 0, comm_elements, _ub_comm, (cudaStream_t)_stream_comm); + } else if (_comm_type == COMM_TYPE::RS) { + reducescatter2_userbuff_inplace(_ub_reg, 0, comm_elements, _ub_comm, + (cudaStream_t)_stream_comm); + } else { + NVTE_ERROR("Not supported communication type."); + } + + if (A_scale_inverse.numel()) A_scale_inverse = A_scale_inverse[A_fp8_tensor]; + + if (B_scale_inverse.numel()) B_scale_inverse = B_scale_inverse[B_fp8_tensor]; + + assert(pre_gelu_out.numel() == 0); + te_gemm(A, A_scale_inverse, A_type, transa, B, B_scale_inverse, B_type, transb, D, D_scale, + D_type, D_amax, bias, bias_type, pre_gelu_out, grad, workspace, workspaceSize, + accumulate, use_split_accumulator, _math_sms); + + CHECK_CUDA(cudaEventRecord(_stop_comm, (cudaStream_t)_stream_comm)); + CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_comm, 0)); + + // Generate output tensor from userbuf data pointer + int output_c_dim0 = (_comm_type == COMM_TYPE::AG) ? 
_ubuf.size(0) : _ubuf.size(0) / _tp_size; + int output_c_dim1 = _ubuf.size(1); + output_tensor = torch::from_blob(ubuf_wt_ptr, {output_c_dim0, output_c_dim1}, _ubuf.options()); + + return {D, output_tensor}; + } // bulk_overlap + + /* + ** Split FPROP GEMM + ReduceScatter + */ + void split_overlap_rs(at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor, + transformer_engine::DType A_type, bool transa, at::Tensor B, + at::Tensor B_scale_inverse, int64_t B_fp8_tensor, + transformer_engine::DType B_type, bool transb, at::Tensor D, + at::Tensor D_scale, transformer_engine::DType D_type, at::Tensor D_amax, + at::Tensor bias, transformer_engine::DType bias_type, + at::Tensor pre_gelu_out, bool grad, at::Tensor workspace, + size_t workspaceSize, bool accumulate, bool use_split_accumulator, + bool gemm_overlap, at::Tensor rs_output) { + // Get GEMM dimensions + int m = A.size(0); + int k = A.size(1); + int n = B.size(0); + int m_chunk = m / _num_splits; + int input_a_chunk_size = m_chunk * k; + int output_chunk_size = n * m_chunk; + int workspace_size_chunk = workspaceSize / _stream_compute.size(); + + // Get input, output, and workspace data pointers + char *input_a_chunk_ptr = reinterpret_cast(A.data_ptr()); + char *output_buf_chunk_ptr = reinterpret_cast(_ubuf.data_ptr()); + char *workspace_ptr = reinterpret_cast(workspace.data_ptr()); + + char *rs_output_ptr = reinterpret_cast(rs_output.data_ptr()); + int ubuf_offset = 0; + + // Catch up the default torch stream + at::cuda::CUDAStream stream_main = at::cuda::getDefaultCUDAStream(); + CHECK_CUDA(cudaEventRecord(_start_compute, stream_main)); + for (int i = 0; i < _stream_compute.size(); i++) { + CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_compute[i], _start_compute, 0)); + } + + if (A_scale_inverse.numel()) A_scale_inverse = A_scale_inverse[A_fp8_tensor]; + + if (B_scale_inverse.numel()) B_scale_inverse = B_scale_inverse[B_fp8_tensor]; + + assert(pre_gelu_out.numel() == 0); + + if (gemm_overlap) { + torch::Tensor input_a_chunk = torch::from_blob(input_a_chunk_ptr, {m_chunk, k}, A.options()); + torch::Tensor output_chunk = + torch::from_blob(output_buf_chunk_ptr, {n, m_chunk}, _ubuf.options()); + torch::Tensor workspace_chunk = + torch::from_blob(workspace_ptr, {workspace_size_chunk}, workspace.options()); + at::cuda::setCurrentCUDAStream(_stream_compute[0]); + te_gemm(input_a_chunk, A_scale_inverse, A_type, transa, B, B_scale_inverse, B_type, transb, + output_chunk, D_scale, D_type, D_amax, bias, bias_type, pre_gelu_out, grad, + workspace_chunk, workspace_size_chunk, accumulate, use_split_accumulator, _math_sms); + + for (int i = 1; i < _num_splits; i++) { + input_a_chunk_ptr += input_a_chunk_size * B.element_size(); + output_buf_chunk_ptr += output_chunk_size * _ubuf.element_size(); + + torch::Tensor input_a_chunk = + torch::from_blob(input_a_chunk_ptr, {m_chunk, k}, A.options()); + torch::Tensor output_chunk = + torch::from_blob(output_buf_chunk_ptr, {n, m_chunk}, _ubuf.options()); + torch::Tensor workspace_chunk = + torch::from_blob(workspace_ptr + (i % _stream_compute.size()) * workspace_size_chunk, + {workspace_size_chunk}, workspace.options()); + at::cuda::setCurrentCUDAStream(_stream_compute[i % _stream_compute.size()]); + te_gemm(input_a_chunk, A_scale_inverse, A_type, transa, B, B_scale_inverse, B_type, transb, + output_chunk, D_scale, D_type, D_amax, bias, bias_type, pre_gelu_out, grad, + workspace_chunk, workspace_size_chunk, accumulate, use_split_accumulator, + _math_sms); + + CHECK_CUDA(cudaEventRecord( + 
_start_comm, (cudaStream_t)_stream_compute[(i - 1) % _stream_compute.size()])); + CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_comm, _start_comm, 0)); + + // Communication chunk + reducescatter2_userbuff_stridedoutput(rs_output_ptr, _ub_reg, (i - 1) * output_chunk_size, + m_chunk, n, m, _ub_comm, (cudaStream_t)_stream_comm); + + rs_output_ptr += m_chunk * _ubuf.element_size(); + } + int last_compute_stream_id = + (_num_splits + _stream_compute.size() - 1) % _stream_compute.size(); + CHECK_CUDA( + cudaEventRecord(_start_comm, (cudaStream_t)_stream_compute[last_compute_stream_id])); + CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_comm, _start_comm, 0)); + + // Communication chunk + reducescatter2_userbuff_stridedoutput(rs_output_ptr, _ub_reg, + (_num_splits - 1) * output_chunk_size, m_chunk, n, m, + _ub_comm, (cudaStream_t)_stream_comm); + } else { + for (int i = 0; i < _num_splits; i++) { + torch::Tensor input_a_chunk = + torch::from_blob(input_a_chunk_ptr, {m_chunk, k}, A.options()); + torch::Tensor output_chunk = + torch::from_blob(output_buf_chunk_ptr, {n, m_chunk}, _ubuf.options()); + torch::Tensor workspace_chunk = + torch::from_blob(workspace_ptr + (i % _stream_compute.size()) * workspace_size_chunk, + {workspace_size_chunk}, workspace.options()); + at::cuda::setCurrentCUDAStream(_stream_compute[i % _stream_compute.size()]); + te_gemm(input_a_chunk, A_scale_inverse, A_type, transa, B, B_scale_inverse, B_type, transb, + output_chunk, D_scale, D_type, D_amax, bias, bias_type, pre_gelu_out, grad, + workspace_chunk, workspace_size_chunk, accumulate, use_split_accumulator, + _math_sms); + + CHECK_CUDA(cudaEventRecord(_start_comm, + (cudaStream_t)_stream_compute[i % _stream_compute.size()])); + CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_comm, _start_comm, 0)); + + // Communication chunk + reducescatter2_userbuff_stridedoutput(rs_output_ptr, _ub_reg, i * output_chunk_size, + m_chunk, n, m, _ub_comm, (cudaStream_t)_stream_comm); + + rs_output_ptr += m_chunk * _ubuf.element_size(); + input_a_chunk_ptr += input_a_chunk_size * B.element_size(); + output_buf_chunk_ptr += output_chunk_size * _ubuf.element_size(); + } + } + int last_compute_stream_id = + (_num_splits + _stream_compute.size() - 1) % _stream_compute.size(); + CHECK_CUDA( + cudaEventRecord(_stop_compute, (cudaStream_t)_stream_compute[last_compute_stream_id])); + CHECK_CUDA(cudaEventRecord(_stop_comm, (cudaStream_t)_stream_comm)); + CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_compute, 0)); + CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_comm, 0)); + at::cuda::setCurrentCUDAStream(stream_main); + + return; + } // split_overlap_rs + + /* + ** Helper function to copy input to _ubuf + */ + void copy_input_to_ubuf(torch::Tensor input, int comm_type) { + char *ubuf_ptr = reinterpret_cast(_ubuf.data_ptr()); + COMM_TYPE _comm_type = static_cast(comm_type); + if (_comm_type == COMM_TYPE::AG) { + if ((input.numel() * _tp_size) != _ubuf.numel() || + input.element_size() != _ubuf.element_size()) { + NVTE_ERROR("input and ubuf size do not match!"); + } + ubuf_ptr += _ubuf.numel() / _tp_size * _tp_id * _ubuf.element_size(); + } else { + if (input.numel() != _ubuf.numel() || input.element_size() != _ubuf.element_size()) { + NVTE_ERROR("input and ubuf size do not match!"); + } + } + + at::cuda::CUDAStream stream_main = at::cuda::getDefaultCUDAStream(); + CHECK_CUDA(cudaEventRecord(_start_d2dcopy, (cudaStream_t)stream_main)); + CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_comm, 
_start_d2dcopy, 0)); + CHECK_CUDA(cudaMemcpyAsync(ubuf_ptr, input.data_ptr(), input.numel() * input.element_size(), + cudaMemcpyDeviceToDevice, (cudaStream_t)_stream_comm)); + } + + torch::Tensor &get_ubuf_output(int comm_type) { + char *ubuf_wt_ptr = reinterpret_cast(_ubuf.data_ptr()); + COMM_TYPE _comm_type = static_cast(comm_type); + if (_comm_type != COMM_TYPE::AG && _comm_type != COMM_TYPE::RS) NVTE_ERROR("Invalid comm_type"); + if (_comm_type == COMM_TYPE::RS) + ubuf_wt_ptr += _ubuf.numel() / _tp_size * _tp_id * _ubuf.element_size(); + int output_c_dim0 = (_comm_type == COMM_TYPE::AG) ? _ubuf.size(0) : _ubuf.size(0) / _tp_size; + int output_c_dim1 = _ubuf.size(1); + output_tensor = torch::from_blob(ubuf_wt_ptr, {output_c_dim0, output_c_dim1}, _ubuf.options()); + return output_tensor; + } +}; // UbufCommOverlap + +struct UbufP2PCommOverlap : torch::CustomClassHolder { + communicator *_ub_comm; + int _tp_id; + int _tp_size; + int _ub_reg; + int _next_rank, _prev_rank, _rank, _rank_round_tp; + int _aggregate2; + int _math_sms; + void *_ubuf_ptr; + torch::Tensor _ubuf; + std::vector _ubufs; + at::cuda::CUDAStream _stream_comm = at::cuda::getStreamFromPool(true); + std::vector _stream_compute; + cudaEvent_t _start_compute, _stop_compute, _start_comm, _stop_comm, _start_accum, _stop_accum; + + UbufP2PCommOverlap(torch::Tensor sample, int rank, int tp_size, bool aggregate2, + int num_max_streams) { + // Initialize userbuf communicator + create_communicator_grouped2(&_ub_comm, 1, 1, tp_size, 1); + _ub_comm->use_ce = 1; + _ub_comm->sms = 1; + _ub_comm->cga_size = 1; + + // Create workspace tensor with userbuffer + int ubuf_bytes = sample.numel() * sample.element_size(); + int ubuf_chunk_bytes = ubuf_bytes / tp_size; + _ub_reg = register_user_buffer_collective(reinterpret_cast(&_ubuf_ptr), ubuf_bytes, + _ub_comm, true); + _ubuf = torch::from_blob(_ubuf_ptr, {sample.size(0), sample.size(1)}, sample.options()); + + // Create tensor chunks for easy management + char *ubuf_byte_ptr = reinterpret_cast(_ubuf.data_ptr()); + for (int i = 0; i < tp_size; i++) { + torch::Tensor ubuf_chunk = torch::from_blob( + ubuf_byte_ptr, {sample.size(0) / tp_size, sample.size(1)}, sample.options()); + _ubufs.push_back(ubuf_chunk); + ubuf_byte_ptr += ubuf_chunk_bytes; + } + + at::cuda::CUDAStream stream_main = at::cuda::getDefaultCUDAStream(); + for (int i = 0; i < std::min(num_max_streams, tp_size); i++) { + cudaStream_t stream; + cudaStreamCreateWithPriority(&stream, cudaStreamNonBlocking, -1); + _stream_compute.push_back( + at::cuda::getStreamFromExternal(stream, stream_main.device_index())); + } + + // Set the number of SMs for GEMM with margin + cudaDeviceProp prop; + cudaGetDeviceProperties(&prop, 0); + _math_sms = prop.multiProcessorCount; + + _tp_size = tp_size; + _aggregate2 = aggregate2; + + _rank = rank; + _tp_id = (rank % tp_size); + _rank_round_tp = (rank / tp_size) * tp_size; + _next_rank = (tp_size + rank + 1) % tp_size + _rank_round_tp; + _prev_rank = (tp_size + rank + -1) % tp_size + _rank_round_tp; + + // CUDA event creation + cudaEventCreateWithFlags(&_start_compute, 0); + cudaEventCreateWithFlags(&_stop_compute, 0); + cudaEventCreateWithFlags(&_start_comm, 0); + cudaEventCreateWithFlags(&_stop_comm, 0); + cudaEventCreateWithFlags(&_start_accum, 0); + cudaEventCreateWithFlags(&_stop_accum, 0); + } + + /* + ** Split AllGather + GEMM using P2P communication + ** This function assumes the input_b is pre-copied to _ubufs[rank_id]. 
This is needed to have AG + *outputs + ** in each rank to be in the contiguous memory space after all ring exchange phases. + */ + torch::Tensor split_overlap_ag(at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor, + transformer_engine::DType A_type, bool transa, at::Tensor B, + at::Tensor B_scale_inverse, int64_t B_fp8_tensor, + transformer_engine::DType B_type, bool transb, at::Tensor D, + at::Tensor D_scale, transformer_engine::DType D_type, + at::Tensor D_amax, at::Tensor bias, + transformer_engine::DType bias_type, at::Tensor pre_gelu_out, + bool grad, at::Tensor workspace, size_t workspaceSize, + bool accumulate, bool use_split_accumulator, at::Tensor B_copy) { + // Get GEMM dimensions between TN and NN input layouts + const int m = (transa) ? A.size(0) : A.size(1); + const int k = (transa) ? A.size(1) : A.size(0); + const int n_chunk = _ubufs[0].size(0); + + // Get communication and GEMM output chunk sizes + const int comm_bytes = _ubufs[0].numel() * _ubufs[0].element_size(); + const int output_chunk_bytes = (n_chunk * m) * HALF_BYTES; + + // Get output and workspace data pointers + char *output_ptr = reinterpret_cast(D.data_ptr()); + char *workspace_ptr = reinterpret_cast(workspace.data_ptr()); + int workspace_size_chunk = workspaceSize / _stream_compute.size(); + + if (A_scale_inverse.numel()) A_scale_inverse = A_scale_inverse[A_fp8_tensor]; + + if (B_scale_inverse.numel()) B_scale_inverse = B_scale_inverse[B_fp8_tensor]; + + at::cuda::CUDAStream stream_main = at::cuda::getDefaultCUDAStream(); + CHECK_CUDA(cudaEventRecord(_start_compute, (cudaStream_t)stream_main)); + + assert(pre_gelu_out.numel() == 0); + if (_aggregate2) { + // Catch up the default torch stream + CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_comm, _start_compute, 0)); + + const int num_steps = _tp_size / 2; + char *input_b_ptr = reinterpret_cast(_ubuf.data_ptr()); + + // Initial 1X input chunk exchange between neighboring peers + int send_chunk_id = _tp_id; + int recv_chunk_id = (_tp_id % 2 == 0) ? _tp_id + 1 : _tp_id - 1; + int send_offset = comm_bytes * send_chunk_id; + int recv_offset = comm_bytes * recv_chunk_id; + int peer_rank = (_tp_id % 2 == 0) ? _next_rank : _prev_rank; + userbuffers_send(_ub_reg, send_offset, _ub_reg, send_offset, comm_bytes, _ub_comm, peer_rank, + (cudaStream_t)_stream_comm); + userbuffers_recv(_ub_reg, recv_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm, peer_rank, + (cudaStream_t)_stream_comm); + CHECK_CUDA(cudaEventRecord(_start_compute, (cudaStream_t)_stream_comm)); + CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_compute[0], _start_compute, 0)); + + int local_rank_round2 = (_tp_id % 2 == 0) ? 
_tp_id : _tp_id - 1; + const int next_rank = (_tp_size + _tp_id + 2) % _tp_size + _rank_round_tp; + const int prev_rank = (_tp_size + _tp_id - 2) % _tp_size + _rank_round_tp; + + // Ring exchange of 2X inputs chunks + for (int i = 0; i < num_steps; i++) { + send_chunk_id = (_tp_size + local_rank_round2 - i * 2) % _tp_size; + recv_chunk_id = (_tp_size + local_rank_round2 - i * 2 - 2) % _tp_size; + send_offset = comm_bytes * send_chunk_id; + recv_offset = comm_bytes * recv_chunk_id; + + // GEMM + torch::Tensor input_b_chunk = + torch::from_blob(input_b_ptr + send_offset, {n_chunk * 2, k}, _ubuf.options()); + torch::Tensor output_chunk = torch::from_blob( + output_ptr + (send_chunk_id * output_chunk_bytes), {n_chunk * 2, m}, D.options()); + torch::Tensor workspace_chunk = + torch::from_blob(workspace_ptr + (i % _stream_compute.size()) * workspace_size_chunk, + {workspace_size_chunk}, workspace.options()); + at::cuda::setCurrentCUDAStream(_stream_compute[i % _stream_compute.size()]); + te_gemm(A, A_scale_inverse, A_type, transa, input_b_chunk, B_scale_inverse, B_type, transb, + output_chunk, D_scale, D_type, D_amax, bias, bias_type, pre_gelu_out, grad, + workspace_chunk, workspace_size_chunk, accumulate, use_split_accumulator, + _math_sms); + + if (i < num_steps - 1) { + // P2P communication + userbuffers_send(_ub_reg, send_offset, _ub_reg, send_offset, comm_bytes * 2, _ub_comm, + next_rank, (cudaStream_t)_stream_comm); + userbuffers_recv(_ub_reg, recv_offset, _ub_reg, recv_offset, comm_bytes * 2, _ub_comm, + prev_rank, (cudaStream_t)_stream_comm); + CHECK_CUDA(cudaEventRecord(_stop_comm, (cudaStream_t)_stream_comm)); + CHECK_CUDA(cudaStreamWaitEvent( + (cudaStream_t)_stream_compute[(i + 1) % _stream_compute.size()], _stop_comm, 0)); + } else if (B_copy.numel() > 0) { + assert(B_copy.numel() == _ubufs[_tp_id].numel()); + assert(B_copy.element_size() == _ubufs[_tp_id].element_size()); + CHECK_CUDA(cudaMemcpyAsync(B_copy.data_ptr(), _ubufs[_tp_id].data_ptr(), + _ubufs[_tp_id].numel() * _ubufs[_tp_id].element_size(), + cudaMemcpyDeviceToDevice, (cudaStream_t)_stream_comm)); + } + } + at::cuda::setCurrentCUDAStream(stream_main); + int last_compute_stream_id = + (num_steps + _stream_compute.size() - 1) % _stream_compute.size(); + CHECK_CUDA( + cudaEventRecord(_stop_compute, (cudaStream_t)_stream_compute[last_compute_stream_id])); + } else { + // Catch up the default torch stream + CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_comm, _start_compute, 0)); + CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_compute[0], _start_compute, 0)); + + for (int i = 0; i < _tp_size; i++) { + // Set the userbuffer id. Buffer under send is the input for the current GEMM chunk + // The initial input chunk is stored _ubuf[rank]. 
This is to have the AG output in all ranks + // to be contiguous after the ring exchanges + int send_chunk_id = (_tp_size + _tp_id - i) % _tp_size; + int recv_chunk_id = (_tp_size + _tp_id - i - 1) % _tp_size; + int send_offset = comm_bytes * send_chunk_id; + int recv_offset = comm_bytes * recv_chunk_id; + + // GEMM + torch::Tensor output_chunk = torch::from_blob( + output_ptr + (send_chunk_id * output_chunk_bytes), {n_chunk, m}, D.options()); + torch::Tensor workspace_chunk = + torch::from_blob(workspace_ptr + (i % _stream_compute.size()) * workspace_size_chunk, + {workspace_size_chunk}, workspace.options()); + at::cuda::setCurrentCUDAStream(_stream_compute[i % _stream_compute.size()]); + te_gemm(A, A_scale_inverse, A_type, transa, _ubufs[send_chunk_id], B_scale_inverse, B_type, + transb, output_chunk, D_scale, D_type, D_amax, bias, bias_type, pre_gelu_out, grad, + workspace_chunk, workspace_size_chunk, accumulate, use_split_accumulator, + _math_sms); + + if (i < _tp_size - 1) { + // P2P communication + userbuffers_send(_ub_reg, send_offset, _ub_reg, send_offset, comm_bytes, _ub_comm, + _next_rank, (cudaStream_t)_stream_comm); + userbuffers_recv(_ub_reg, recv_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm, + _prev_rank, (cudaStream_t)_stream_comm); + CHECK_CUDA(cudaEventRecord(_stop_comm, (cudaStream_t)_stream_comm)); + CHECK_CUDA(cudaStreamWaitEvent( + (cudaStream_t)_stream_compute[(i + 1) % _stream_compute.size()], _stop_comm, 0)); + } else if (B_copy.numel() > 0) { + assert(B_copy.numel() == _ubufs[_tp_id].numel()); + assert(B_copy.element_size() == _ubufs[_tp_id].element_size()); + CHECK_CUDA(cudaMemcpyAsync(B_copy.data_ptr(), _ubufs[_tp_id].data_ptr(), + _ubufs[_tp_id].numel() * _ubufs[_tp_id].element_size(), + cudaMemcpyDeviceToDevice, (cudaStream_t)_stream_comm)); + } + } + at::cuda::setCurrentCUDAStream(stream_main); + int last_compute_stream_id = (_tp_size + _stream_compute.size() - 1) % _stream_compute.size(); + CHECK_CUDA( + cudaEventRecord(_stop_compute, (cudaStream_t)_stream_compute[last_compute_stream_id])); + } + CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_compute, 0)); + CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_comm, _stop_compute, 0)); + + return D; + } // split_overlap_ag + + /* + ** Copy input to _ubufs[0] + */ + void copy_input_to_ubuf(torch::Tensor input, bool chunk) { + at::cuda::CUDAStream stream_main = at::cuda::getDefaultCUDAStream(); + if (chunk) { + // Copy input to the target ubuf chunk by rank offset + if (input.numel() != _ubufs[0].numel() || input.element_size() != _ubufs[0].element_size()) { + NVTE_ERROR("input and ubuf size do not match!"); + } + CHECK_CUDA(cudaMemcpyAsync(_ubufs[_tp_id].data_ptr(), input.data_ptr(), + input.numel() * input.element_size(), cudaMemcpyDeviceToDevice, + (cudaStream_t)stream_main)); + } else { + if (input.numel() != _ubuf.numel() || input.element_size() != _ubuf.element_size()) { + NVTE_ERROR("input and ubuf size do not match!"); + } + CHECK_CUDA(cudaMemcpyAsync(_ubuf.data_ptr(), input.data_ptr(), + input.numel() * input.element_size(), cudaMemcpyDeviceToDevice, + (cudaStream_t)stream_main)); + } + } + torch::Tensor get_ubuf_output(int comm_type) { + char *ubuf_wt_ptr = reinterpret_cast(_ubuf.data_ptr()); + COMM_TYPE _comm_type = static_cast(comm_type); + if (_comm_type != COMM_TYPE::AG && _comm_type != COMM_TYPE::RS) NVTE_ERROR("Invalid comm_type"); + if (_comm_type == COMM_TYPE::RS) + ubuf_wt_ptr += _ubuf.numel() / _tp_size * _tp_id * _ubuf.element_size(); + int output_c_dim0 = (_comm_type 
== COMM_TYPE::AG) ? _ubuf.size(0) : _ubuf.size(0) / _tp_size; + int output_c_dim1 = _ubuf.size(1); + return torch::from_blob(ubuf_wt_ptr, {output_c_dim0, output_c_dim1}, _ubuf.options()); + } +}; // UbufP2PCommOverlap + +} // namespace ubuf diff --git a/transformer_engine/pytorch/csrc/extensions.cu b/transformer_engine/pytorch/csrc/extensions.cu index ede0a5ef6c..e34c79d980 100644 --- a/transformer_engine/pytorch/csrc/extensions.cu +++ b/transformer_engine/pytorch/csrc/extensions.cu @@ -5,7 +5,9 @@ ************************************************************************/ #include "extensions.h" - +#ifdef NVTE_MPI_FOUND +#include "comm_gemm_overlap.h" +#endif // NVTE_MPI_FOUND void te_gemm(at::Tensor A, at::Tensor A_scale_inverse, @@ -26,7 +28,8 @@ void te_gemm(at::Tensor A, at::Tensor workspace, size_t workspaceSize, bool accumulate, - bool use_split_accumulator + bool use_split_accumulator, + int math_sm_count ) { using namespace transformer_engine; auto te_A = makeTransformerEngineTensor(A.data_ptr(), @@ -70,6 +73,7 @@ void te_gemm(at::Tensor A, te_workspace.data(), accumulate, use_split_accumulator, + math_sm_count, at::cuda::getCurrentCUDAStream()); } @@ -536,6 +540,67 @@ std::vector layernorm_fwd_fp8(const at::Tensor &input, } +std::vector layernorm_fwd_fp8_noalloc(const at::Tensor &input, + const at::Tensor &weight, + const at::Tensor &bias, + float eps, + at::Tensor scale, + at::Tensor ln_out, + at::Tensor amax, + at::Tensor scale_inv, + transformer_engine::DType otype, + const int sm_margin, + const bool zero_centered_gamma +) { + using namespace transformer_engine; + + size_t N = static_cast(input.size(0)); + size_t H = static_cast(input.size(1)); + + DType itype = GetTransformerEngineDType(input.scalar_type()); + + auto mu = at::empty({static_cast(N)}, at::CUDA(at::kFloat)); + auto rsigma = at::empty({static_cast(N)}, at::CUDA(at::kFloat)); + auto input_cu = makeTransformerEngineTensor(input); + auto gamma_cu = makeTransformerEngineTensor(weight); + auto beta_cu = makeTransformerEngineTensor(bias); + auto z_cu = makeTransformerEngineTensor(ln_out.data_ptr(), {N, H}, otype, + amax.data_ptr(), scale.data_ptr(), + scale_inv.data_ptr()); + auto mu_cu = makeTransformerEngineTensor(mu); + auto rsigma_cu = makeTransformerEngineTensor(rsigma); + transformer_engine::TensorWrapper workspace, barrier; + + // This call populates workspace and barrier tensors with the required config + const auto func = zero_centered_gamma ? 
nvte_layernorm1p_fwd : nvte_layernorm_fwd; + func(input_cu.data(), gamma_cu.data(), beta_cu.data(), eps, z_cu.data(), + mu_cu.data(), rsigma_cu.data(), at::cuda::getCurrentCUDAStream(), + at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin, + workspace.data(), barrier.data()); + + // Fill workspace and barrier + auto workspace_data = allocateSpace(workspace.shape(), + workspace.dtype()); + auto barrier_data = allocateSpace(barrier.shape(), + barrier.dtype(), + true); + workspace = makeTransformerEngineTensor(workspace_data.data_ptr(), + workspace.shape(), + workspace.dtype()); + barrier = makeTransformerEngineTensor(barrier_data.data_ptr(), + barrier.shape(), + barrier.dtype()); + + // Actual call to fwd kernel + func(input_cu.data(), gamma_cu.data(), beta_cu.data(), eps, z_cu.data(), + mu_cu.data(), rsigma_cu.data(), at::cuda::getCurrentCUDAStream(), + at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin, + workspace.data(), barrier.data()); + + return {ln_out, mu, rsigma}; +} + + at::Tensor layernorm_fwd_fp8_inf(const at::Tensor &input, const at::Tensor &weight, const at::Tensor &bias, @@ -609,6 +674,61 @@ std::vector layernorm_fwd(const at::Tensor &input, } +std::vector layernorm_fwd_noalloc(const at::Tensor &input, + const at::Tensor &weight, + const at::Tensor &bias, + at::Tensor ln_out, + float eps, + const int sm_margin, + const bool zero_centered_gamma +) { + using namespace transformer_engine; + + size_t N = static_cast(input.size(0)); + size_t H = static_cast(input.size(1)); + + DType itype = GetTransformerEngineDType(input.scalar_type()); + + auto mu = at::empty({static_cast(N)}, at::CUDA(at::kFloat)); + auto rsigma = at::empty({static_cast(N)}, at::CUDA(at::kFloat)); + auto input_cu = makeTransformerEngineTensor(input); + auto gamma_cu = makeTransformerEngineTensor(weight); + auto beta_cu = makeTransformerEngineTensor(bias); + auto z_cu = makeTransformerEngineTensor(ln_out); + auto mu_cu = makeTransformerEngineTensor(mu); + auto rsigma_cu = makeTransformerEngineTensor(rsigma); + transformer_engine::TensorWrapper workspace, barrier; + + // This call populates workspace and barrier tensors with the required config + const auto func = zero_centered_gamma ? 
nvte_layernorm1p_fwd : nvte_layernorm_fwd; + func(input_cu.data(), gamma_cu.data(), beta_cu.data(), eps, z_cu.data(), + mu_cu.data(), rsigma_cu.data(), at::cuda::getCurrentCUDAStream(), + at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin, + workspace.data(), barrier.data()); + + // Fill workspace and barrier + auto workspace_data = allocateSpace(workspace.shape(), + workspace.dtype()); + auto barrier_data = allocateSpace(barrier.shape(), + barrier.dtype(), + true); + workspace = makeTransformerEngineTensor(workspace_data.data_ptr(), + workspace.shape(), + workspace.dtype()); + barrier = makeTransformerEngineTensor(barrier_data.data_ptr(), + barrier.shape(), + barrier.dtype()); + + // Actual call to fwd kernel + func(input_cu.data(), gamma_cu.data(), beta_cu.data(), eps, z_cu.data(), + mu_cu.data(), rsigma_cu.data(), at::cuda::getCurrentCUDAStream(), + at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin, + workspace.data(), barrier.data()); + + return {ln_out, mu, rsigma}; +} + + at::Tensor layernorm_fwd_inf(const at::Tensor &input, const at::Tensor &weight, const at::Tensor &bias, @@ -646,6 +766,29 @@ at::Tensor cast_to_fp8(const at::Tensor &input, } +void cast_to_fp8_noalloc(const at::Tensor &input, + const at::Tensor &scale, + at::Tensor output, + at::Tensor amax, + at::Tensor scale_inv, + transformer_engine::DType otype +) { + using namespace transformer_engine; + size_t N = static_cast(input.size(0)); + size_t H = static_cast(input.size(1)); + + auto input_cu = makeTransformerEngineTensor(input); + auto output_cu = makeTransformerEngineTensor(output.data_ptr(), {N, H}, otype, + amax.data_ptr(), scale.data_ptr(), + scale_inv.data_ptr()); + + nvte_fp8_quantize(input_cu.data(), output_cu.data(), + at::cuda::getCurrentCUDAStream()); + + return; +} + + at::Tensor cast_from_fp8(const at::Tensor &input, const at::Tensor &scale_inv, transformer_engine::DType itype, @@ -878,6 +1021,17 @@ size_t get_cublasLt_version() { } +bool userbuf_comm_available() { // TODO(ksivamani) check on python side +#ifdef NVTE_MPI_FOUND + return true; +#else + return false; +#endif +} + +void placeholder() {} // TODO(ksivamani) clean this up + + PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { // Softmax functions m.def("scaled_softmax_forward", &scaled_softmax_forward, "Scaled Softmax FWD"); @@ -895,8 +1049,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { // Other granular functions m.def("layernorm_fwd_fp8", &layernorm_fwd_fp8, "LN FWD FP8"); + m.def("layernorm_fwd_fp8_noalloc", &layernorm_fwd_fp8_noalloc, "LN FWD FP8"); m.def("layernorm_bwd", &layernorm_bwd, "LN BWD"); m.def("layernorm_fwd", &layernorm_fwd, "LN FWD"); + m.def("layernorm_fwd_noalloc", &layernorm_fwd_noalloc, "LN FWD"); m.def("fused_cast_transpose", &fused_cast_transpose, "Fused Cast + Transpose"); m.def("fused_cast_transpose_bgrad", &fused_cast_transpose_bgrad, "Fused Cast + Transpose + BGRAD"); @@ -907,6 +1063,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("fused_multi_cast_transpose", &fused_multi_cast_transpose, "Fused Multi-tensor Cast + Transpose"); m.def("cast_to_fp8", &cast_to_fp8, "Cast to FP8"); + m.def("cast_to_fp8_noalloc", &cast_to_fp8_noalloc, "Cast to FP8"); m.def("cast_from_fp8", &cast_from_fp8, "Cast from FP8"); m.def("te_gemm", &te_gemm, "CublasLt GEMM"); m.def("fp8_transpose", &fp8_transpose, "Transpose with FP8 I/O"); @@ -914,6 +1071,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { // Misc m.def("get_cublasLt_version", &get_cublasLt_version, "Get cublasLt version"); + 
m.def("userbuf_comm_available", &userbuf_comm_available, "If userbuf backend is available"); // Data structures py::class_(m, "FP8TensorMeta") @@ -922,6 +1080,31 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { .def_readwrite("scale_inv", &transformer_engine::FP8TensorMeta::scale_inv) .def_readwrite("amax_history", &transformer_engine::FP8TensorMeta::amax_history); +#ifdef NVTE_MPI_FOUND + py::enum_(m, "UbufOverlapAlgo") + .value("BULK_OVERLAP_AG", ubuf::UBOverlapAlgo::BULK_OVERLAP_AG) + .value("BULK_OVERLAP_RS", ubuf::UBOverlapAlgo::BULK_OVERLAP_RS) + .value("SPLIT_PIPELINED_RS", ubuf::UBOverlapAlgo::SPLIT_PIPELINED_RS) + .value("SPLIT_PIPELINED_AG", ubuf::UBOverlapAlgo::SPLIT_PIPELINED_AG); + + py::class_(m, "UbufCommOverlap") + .def(py::init()) + .def("bulk_overlap", &ubuf::UbufCommOverlap::bulk_overlap) + .def("split_overlap_rs", &ubuf::UbufCommOverlap::split_overlap_rs) + .def("copy_input_to_ubuf", &ubuf::UbufCommOverlap::copy_input_to_ubuf) + .def("get_ubuf_output", &ubuf::UbufCommOverlap::get_ubuf_output); + + py::class_(m, "UbufP2PCommOverlap") + .def(py::init()) + .def("split_overlap_ag", &ubuf::UbufP2PCommOverlap::split_overlap_ag) + .def("copy_input_to_ubuf", &ubuf::UbufP2PCommOverlap::copy_input_to_ubuf) + .def("get_ubuf_output", &ubuf::UbufP2PCommOverlap::get_ubuf_output); +#else // NVTE_MPI_FOUND + m.def("UbufOverlapAlgo", &placeholder, "Dummy function for python side annotations"); + m.def("UbufCommOverlap", &placeholder, "Dummy function for python side annotations"); + m.def("UbufP2PCommOverlap", &placeholder, "Dummy function for python side annotations"); +#endif // NVTE_MPI_FOUND + py::enum_(m, "DType", py::module_local()) .value("kByte", transformer_engine::DType::kByte) .value("kInt32", transformer_engine::DType::kInt32) diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 99849c15fe..6be404226e 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -26,7 +26,8 @@ void te_gemm(at::Tensor A, at::Tensor workspace, size_t workspaceSize, bool accumulate, - bool use_split_accumulator + bool use_split_accumulator, + int math_sm_count ); @@ -111,6 +112,19 @@ std::vector layernorm_fwd_fp8(const at::Tensor &input, const bool zero_centered_gamma ); +std::vector layernorm_fwd_fp8_noalloc(const at::Tensor &input, + const at::Tensor &weight, + const at::Tensor &bias, + float eps, + at::Tensor scale, + at::Tensor ln_out, + at::Tensor amax, + at::Tensor scale_inv, + transformer_engine::DType otype, + const int sm_margin, + const bool zero_centered_gamma +); + at::Tensor layernorm_fwd_fp8_inf(const at::Tensor &input, const at::Tensor &weight, const at::Tensor &bias, @@ -130,6 +144,15 @@ std::vector layernorm_fwd(const at::Tensor &input, const bool zero_centered_gamma ); +std::vector layernorm_fwd_noalloc(const at::Tensor &input, + const at::Tensor &weight, + const at::Tensor &bias, + at::Tensor ln_out, + float eps, + const int sm_margin, + const bool zero_centered_gamma +); + at::Tensor layernorm_fwd_inf(const at::Tensor &input, const at::Tensor &weight, const at::Tensor &bias, @@ -145,6 +168,15 @@ at::Tensor cast_to_fp8(const at::Tensor &input, ); +void cast_to_fp8_noalloc(const at::Tensor &input, + const at::Tensor &scale, + at::Tensor output, + at::Tensor amax, + at::Tensor scale_inv, + transformer_engine::DType otype +); + + at::Tensor cast_from_fp8(const at::Tensor &input, const at::Tensor &scale_inv, transformer_engine::DType itype, diff --git 
a/transformer_engine/pytorch/csrc/ts_fp8_op.cpp b/transformer_engine/pytorch/csrc/ts_fp8_op.cpp index b0085de04e..e3d1ef4d7b 100755 --- a/transformer_engine/pytorch/csrc/ts_fp8_op.cpp +++ b/transformer_engine/pytorch/csrc/ts_fp8_op.cpp @@ -121,7 +121,8 @@ at::Tensor te_gemm_ts(at::Tensor A, workspace, workspaceSize_arg, accumulate_arg, - use_split_accumulator_arg); + use_split_accumulator_arg, + 0); return D; } diff --git a/transformer_engine/pytorch/module.py b/transformer_engine/pytorch/module.py index dff37497d6..3e0a868047 100644 --- a/transformer_engine/pytorch/module.py +++ b/transformer_engine/pytorch/module.py @@ -85,6 +85,8 @@ _2X_ACC_DGRAD = True _2X_ACC_WGRAD = True _cublas_workspace = None +_ub_communicators = None +_NUM_MAX_UB_STREAMS = 3 _amax_reduce_handle_bwd = None @@ -147,6 +149,105 @@ def _prepare_backward( delete_key_from_amax_buffer(forward=False) +def initialize_ub( + shape: list, + tp_size: int, + use_fp8: bool = False, + ub_cfgs: Optional[dict] = None +) -> None: + """Initialize communicators for TP comm overlap using userbuffers.""" + global _ub_communicators + assert _ub_communicators is None, "UB communicators are already initialized." + _ub_communicators = {} + rank_id = torch.distributed.get_rank() + + # Increase the workspace by the number of maximum concurrent streams + global _cublas_workspace + _cublas_workspace = get_workspace().repeat(_NUM_MAX_UB_STREAMS) + + # Default buffer precision: AllGather buffers use fp8 when using fp8 recipe + fp8_buf = [ + "qkv_fprop", "qkv_dgrad", "proj_dgrad", "fc1_fprop", "fc1_dgrad", "fc2_dgrad" + ] + # Default overlap methods for layers + methods = { + "ring_exchange":["qkv_fprop", "fc1_fprop", "proj_dgrad", "fc2_dgrad"], + "pipeline":["proj_fprop", "fc2_fprop"], + "bulk":["qkv_dgrad", "qkv_wgrad", "fc1_dgrad", "fc1_wgrad"], + } + + def get_method(name): + for method, names in methods.items(): + if name in names: + return method + raise KeyError(f"Given layer name {name} does not exist.") + + def add_ub( + name: str, + method: str, + num_sm: int = 16, + cga_size: int = 2, + set_sm_margin: int = 0, + num_splits: int = 4, + aggregate: int = 0, + ) -> None: + dtype = torch.uint8 if (use_fp8 and name in fp8_buf) else torch.bfloat16 + sample_buffer = torch.empty(shape, dtype=dtype, device='cuda') + if method == 'ring_exchange': + ub_obj = tex.UbufP2PCommOverlap( + sample_buffer, # Sample userbuffer + rank_id, # Rank id + tp_size, # TP size + aggregate, # Aggregate 2X GEMM chunks + _NUM_MAX_UB_STREAMS, # Max concurrent GEMM streams + ) + else: + ub_obj = tex.UbufCommOverlap( + sample_buffer, # Sample userbuffer + rank_id, # Rank id + tp_size, # TP size + num_sm, # Number of communication SMs + cga_size, # CGA cluster size + num_splits, # Number of communication splits + set_sm_margin, # Set SM margin + _NUM_MAX_UB_STREAMS, # Max concurrent GEMM streams + ) + _ub_communicators[name] = ub_obj + + for name in (methods["ring_exchange"]+methods["pipeline"]+methods["bulk"]): + if ub_cfgs is not None and name in ub_cfgs: + ub_cfg = ub_cfgs[name] + method = ub_cfg["method"] if "method" in ub_cfg else get_method(name) + num_sm = ub_cfg["num_sm"] if "num_sm" in ub_cfg else 16 + cga_size = ub_cfg["cga_size"] if "cga_size" in ub_cfg else 2 + num_splits = ub_cfg["num_splits"] if "num_splits" in ub_cfg else 0 + set_sm_margin = ub_cfg["set_sm_margin"] if "set_sm_margin" in ub_cfg else 0 + aggregate = ub_cfg["aggregate"] if "aggregate" in ub_cfg else 0 + add_ub( + name, + method, + num_sm, + cga_size, + set_sm_margin, + num_splits, + aggregate + 
) + else: + method = get_method(name) + if method == "pipeline": + add_ub(name, method) + else: + add_ub(name, method, num_splits=0) + + +def get_ub(name: str): + """Get userbuffer communicator corresponding to give key.""" + global _ub_communicators + assert _ub_communicators is not None, "UB manager is not initialized." + assert name in _ub_communicators, f"UB for {name} is not registered." + return _ub_communicators[name] + + class _NoopCat(torch.autograd.Function): """This class is a no-op replacement for `torch.cat`.""" @@ -596,9 +697,13 @@ def grad_output_preprocess( # No-FP8 case: bgrad is fused with wgrad for this case. if not ctx.fp8: if gather_grad_output: - grad_output_mat, _ = gather_along_first_dim( - grad_output_mat, ctx.tp_group - ) + if not ctx.ub_split_ag: + grad_output_mat, _ = gather_along_first_dim( + grad_output_mat, ctx.tp_group + ) + else: + ctx.ub_obj_gradout.copy_input_to_ubuf(grad_output, True) + grad_output_mat = ctx.ub_obj_gradout.get_ubuf_output(1) return grad_output_mat, None, None, None fp8_dtype_backward = get_fp8_te_dtype( @@ -610,6 +715,9 @@ def grad_output_preprocess( gather_grad_output and ctx.fp8_meta["recipe"].override_linear_precision.wgrad ): + assert ( + not ctx.ub_split_ag + ), "override_linear_precision.wgrad not supported with ub_split_ag" grad_output_mat, _ = gather_along_first_dim(grad_output_mat, ctx.tp_group) # FP8 case with gather: unfused bgrad, cast, transpose for efficient gather elif gather_grad_output: @@ -617,14 +725,23 @@ def grad_output_preprocess( grad_bias = grad_output_mat.sum(dim=0) else: grad_bias = None - grad_output_c = cast_to_fp8( + if ctx.ub_split_ag: + grad_output_c = ctx.ub_obj_gradout.get_ubuf_output(0) + else: + grad_output_c = torch.empty_like(grad_output_mat, dtype=torch.uint8) + cast_to_fp8( grad_output_mat, ctx.fp8_meta["scaling_bwd"], tex.FP8BwdTensors.GRAD_OUTPUT1, fp8_dtype_backward, + out=grad_output_c, ) - grad_output_c, _ = gather_along_first_dim(grad_output_c, ctx.tp_group) - grad_output_t = tex.fp8_transpose(grad_output_c, fp8_dtype_backward) + if not ctx.ub_split_ag: + grad_output_c, _ = gather_along_first_dim(grad_output_c, ctx.tp_group) + grad_output_t = tex.fp8_transpose(grad_output_c, fp8_dtype_backward) + else: + grad_output_c = ctx.ub_obj_gradout.get_ubuf_output(1) + grad_output_t = None return grad_output_mat, grad_output_c, grad_output_t, grad_bias @@ -718,6 +835,9 @@ def forward( fwd_ln_sm_margin: int, bwd_ln_sm_margin: int, zero_centered_gamma: bool, + ub_bulk_wgrad: bool, + ub_bulk_dgrad: bool, + ub_split_ag: bool, ) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]: # Make sure input dimensions are compatible in_features = ln_weight.numel() @@ -733,16 +853,26 @@ def forward( inputmat = cast_if_needed(inputmat, activation_dtype) ln_weight = cast_if_needed(ln_weight, activation_dtype) ln_bias = cast_if_needed(ln_bias, activation_dtype) - # If residual connection is after LN, we need `ln_out` # tensor in higher precision, this comes at the cost # of an extra fp8 cast. 
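For orientation, the initialize_ub/get_ub helpers defined above are meant to be called once per process before the first forward pass; the following is a minimal usage sketch, assuming torch.distributed is already initialized and using illustrative values for seq_len, batch_size, hidden_size and tp_size (none of these names come from this patch):

    import transformer_engine.pytorch.module as te_module  # module patched above

    seq_len, batch_size, hidden_size, tp_size = 2048, 4, 12288, 8  # illustrative only

    # Register userbuffer communicators for every overlap-enabled GEMM. The sample
    # shape is the flattened (seq*batch, hidden) activation that the buffers hold.
    te_module.initialize_ub(
        shape=[seq_len * batch_size, hidden_size],
        tp_size=tp_size,
        use_fp8=True,  # AllGather buffers of fp8-eligible layers are allocated as uint8
        ub_cfgs={"proj_fprop": {"num_sm": 16, "num_splits": 4}},  # optional per-layer overrides
    )

    # Modules later fetch their communicator by layer name, e.g.:
    ub_qkv_fprop = te_module.get_ub("qkv_fprop")
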
+ if ub_split_ag: + tp_world_size = get_distributed_world_size(tp_group) + if tp_world_size == 1 or (not is_grad_enabled) or return_layernorm_output: + ub_split_ag = False + if ub_split_ag: + dim_size = list(inputmat.size()) + dim_size[0] = dim_size[0] * tp_world_size + ub_obj_lnout = get_ub("qkv_fprop") + ln_out = ub_obj_lnout.get_ubuf_output(0) if fp8: fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) if not return_layernorm_output: if is_grad_enabled: - ln_out, mu, rsigma = layernorm_fwd_fp8( + if not ub_split_ag: + ln_out = torch.empty_like(inputmat, dtype=torch.uint8) + _, mu, rsigma = layernorm_fwd_fp8( inputmat, ln_weight, ln_bias, @@ -752,6 +882,7 @@ def forward( fp8_dtype_forward, fwd_ln_sm_margin, zero_centered_gamma, + ln_out = ln_out ) else: mu = rsigma = None @@ -783,17 +914,25 @@ def forward( ) else: if is_grad_enabled: - ln_out, mu, rsigma = tex.layernorm_fwd( - inputmat, ln_weight, ln_bias, eps, fwd_ln_sm_margin, zero_centered_gamma - ) + if ub_split_ag: + _, mu, rsigma = tex.layernorm_fwd_noalloc( + inputmat, ln_weight, ln_bias, ln_out, eps, + fwd_ln_sm_margin, zero_centered_gamma + ) + else: + ln_out, mu, rsigma = tex.layernorm_fwd( + inputmat, ln_weight, ln_bias, eps, fwd_ln_sm_margin, zero_centered_gamma + ) else: ln_out, mu, rsigma = layernorm_fwd_inf( inputmat, ln_weight, ln_bias, eps, zero_centered_gamma ), None, None ln_out_return = ln_out - # Column Parallel Linear - if parallel_mode == "column" and sequence_parallel: + if ub_split_ag: + ln_out_total = ub_obj_lnout.get_ubuf_output(1) + ln_out = torch.empty_like(ln_out) + elif parallel_mode == "column" and sequence_parallel: ln_out_total, _ = gather_along_first_dim(ln_out, tp_group) else: ln_out_total = ln_out @@ -838,6 +977,9 @@ def forward( bias=bias, use_bias=use_bias, use_split_accumulator=_2X_ACC_FPROP, + ub_algo=tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG if ub_split_ag else None, + ub=ub_obj_lnout if ub_split_ag else None, + extra_output_tensor=ln_out if ub_split_ag else None, ) else: # Cast for native AMP @@ -859,6 +1001,9 @@ def forward( get_workspace(), bias=bias, use_bias=use_bias, + ub_algo=tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG if ub_split_ag else None, + ub=ub_obj_lnout if ub_split_ag else None, + extra_output_tensor=ln_out if ub_split_ag else None, ) if is_grad_enabled: @@ -888,6 +1033,8 @@ def forward( ctx.return_layernorm_output = return_layernorm_output ctx.bwd_ln_sm_margin = bwd_ln_sm_margin ctx.zero_centered_gamma = zero_centered_gamma + ctx.ub_bulk_wgrad = ub_bulk_wgrad + ctx.ub_bulk_dgrad = ub_bulk_dgrad ctx.requires_dgrad = inp.requires_grad # Row Parallel Linear @@ -922,6 +1069,15 @@ def backward( fwd_scale_inverses, ) = ctx.saved_tensors + if ctx.ub_bulk_dgrad: + tp_world_size = get_distributed_world_size(ctx.tp_group) + if tp_world_size == 1: + ctx.ub_bulk_dgrad = False + if ctx.ub_bulk_dgrad: + dim_size = list(ln_out.size()) + dim_size[0] = dim_size[0] * tp_world_size + ub_obj_lnout = get_ub("qkv_dgrad") + ub_obj_lnout.copy_input_to_ubuf(ln_out, 1) ( grad_output, grad_output_c, @@ -931,9 +1087,14 @@ def backward( ctx, grad_outputs[0], ctx.parallel_mode == "row" ) + if ctx.ub_bulk_wgrad: + tp_world_size = get_distributed_world_size(ctx.tp_group) + if tp_world_size == 1: + ctx.ub_bulk_wgrad = False + # Column Parallel Linear # Overlap input AG with dgrad - if ctx.parallel_mode == "column" and ctx.sequence_parallel: + if (not ctx.ub_bulk_dgrad) and ctx.parallel_mode == "column" and ctx.sequence_parallel: ln_out_total, handle = gather_along_first_dim( ln_out, ctx.tp_group, 
async_op=True ) @@ -947,6 +1108,15 @@ def backward( else: accumulate_wgrad_into_param_main_grad = ctx.fuse_wgrad_accumulation + + dgrad_size = list(grad_output.size()) + dgrad_size[1] = weight.size(1) + if ctx.ub_bulk_wgrad: # allocate dgrad output + ub_obj_dgrad = get_ub("qkv_wgrad") + dgrad = ub_obj_dgrad.get_ubuf_output(1) # AllGather output + else: + dgrad = torch.empty (dgrad_size, dtype=ctx.activation_dtype, device=weight.device) + if ctx.fp8: fp8_dtype_forward = get_fp8_te_dtype( ctx.fp8_meta["recipe"], fprop_tensor=True @@ -956,7 +1126,7 @@ def backward( ) # DGRAD: Evaluated unconditionally to feed into Linear backward - dgrad = fp8_gemm( + _ = fp8_gemm( weight_t_fp8, fwd_scale_inverses, tex.FP8FwdTensors.GEMM1_WEIGHT, @@ -967,25 +1137,35 @@ def backward( fp8_dtype_backward, ctx.activation_dtype, get_workspace(), + out=dgrad, use_split_accumulator=_2X_ACC_DGRAD, + ub_algo=tex.UbufOverlapAlgo.BULK_OVERLAP_AG if ctx.ub_bulk_dgrad else None, + ub=ub_obj_lnout if ctx.ub_bulk_dgrad else None ) else: # DGRAD: Evaluated unconditionally to feed into Linear backward - dgrad, _, _ = gemm( + _, _, _ = gemm( weight, grad_output, ctx.activation_dtype, get_workspace(), + out=dgrad, layout="NN", grad=True, + ub_algo=tex.UbufOverlapAlgo.BULK_OVERLAP_AG if ctx.ub_bulk_dgrad else None, + ub=ub_obj_lnout if ctx.ub_bulk_dgrad else None ) + if ctx.ub_bulk_dgrad: + ln_out_total = ub_obj_lnout.get_ubuf_output(1) # Overlap dgrad-RS/AR with wgrad if ctx.parallel_mode == "column" and ctx.sequence_parallel: - handle.wait() - dgrad, handle = reduce_scatter_along_first_dim( - dgrad, ctx.tp_group, async_op=True - ) + if not ctx.ub_bulk_dgrad: + handle.wait() + if not ctx.ub_bulk_wgrad: + dgrad, handle = reduce_scatter_along_first_dim( + dgrad, ctx.tp_group, async_op=True + ) elif ctx.parallel_mode == "column" and ctx.tensor_parallel: dgrad, handle = allreduce(dgrad, ctx.tp_group, async_op=True) @@ -1008,6 +1188,9 @@ def backward( accumulate=accumulate_wgrad_into_param_main_grad, out=weight.main_grad if ctx.fuse_wgrad_accumulation else None, use_split_accumulator=_2X_ACC_WGRAD, + ub_algo=tex.UbufOverlapAlgo.BULK_OVERLAP_RS + if ctx.ub_bulk_wgrad else None, + ub=ub_obj_dgrad if ctx.ub_bulk_wgrad else None ) else: ln_out_total_c = cast_from_fp8( @@ -1026,6 +1209,9 @@ def backward( grad=True, accumulate=accumulate_wgrad_into_param_main_grad, out=weight.main_grad if ctx.fuse_wgrad_accumulation else None, + ub_algo=tex.UbufOverlapAlgo.BULK_OVERLAP_RS + if ctx.ub_bulk_wgrad else None, + ub=ub_obj_dgrad if ctx.ub_bulk_wgrad else None ) else: # WGRAD @@ -1039,10 +1225,15 @@ def backward( use_bias=ctx.use_bias, accumulate=accumulate_wgrad_into_param_main_grad, out=weight.main_grad if ctx.fuse_wgrad_accumulation else None, + ub_algo=tex.UbufOverlapAlgo.BULK_OVERLAP_RS if ctx.ub_bulk_wgrad else None, + ub=ub_obj_dgrad if ctx.ub_bulk_wgrad else None ) + + if ctx.ub_bulk_wgrad: + dgrad = ub_obj_dgrad.get_ubuf_output(0) # Reduce-scatter output # Column Parallel Linear - if ctx.parallel_mode == "column" and ctx.tensor_parallel and handle is not None: + elif ctx.parallel_mode == "column" and ctx.tensor_parallel and handle is not None: handle.wait() # LayerNorm gradient @@ -1086,6 +1277,9 @@ def backward( None, None, None, + None, + None, + None, ) @@ -1179,6 +1373,9 @@ def __init__( skip_weight_param_allocation: bool = False, parameters_split: Optional[Tuple[str, ...]] = None, zero_centered_gamma: bool = False, + ub_bulk_wgrad: bool = False, + ub_bulk_dgrad: bool = False, + ub_split_ag: bool = False, ) -> None: 
super().__init__() self.in_features = in_features @@ -1190,6 +1387,14 @@ def __init__( self.return_layernorm_output = return_layernorm_output self.parameters_split = parameters_split self.zero_centered_gamma = zero_centered_gamma + self.ub_bulk_wgrad = ub_bulk_wgrad + self.ub_bulk_dgrad = ub_bulk_dgrad + self.ub_split_ag = ub_split_ag + + if ub_bulk_wgrad or ub_bulk_dgrad or ub_split_ag: + assert ( + tex.userbuf_comm_available() + ), "Userbuffer communication backend not available." if tp_group is None: self.tp_size = tp_size @@ -1308,6 +1513,7 @@ def __init__( self.fp8_weight_shapes.append(torch.Size((self.out_features, self.in_features))) + # For RPL, bias has to be added after TP collectives # So it cannot be fused with the GEMM if self.parallel_mode == "row" and self.apply_bias: @@ -1412,6 +1618,9 @@ def forward( self.fwd_ln_sm_margin, self.bwd_ln_sm_margin, self.zero_centered_gamma, + self.ub_bulk_wgrad, + self.ub_bulk_dgrad, + self.ub_split_ag, ) out = fwd_fn(*args) @@ -1455,6 +1664,8 @@ def forward( activation_dtype: torch.dtype, parallel_mode: Union[str, None], is_grad_enabled: bool, + ub_split_rs: bool, + ub_split_ag: bool, ) -> torch.Tensor: # Make sure input dimensions are compatible in_features = weight.shape[-1] @@ -1466,6 +1677,10 @@ def forward( update_fp8_weights = is_first_microbatch is None or is_first_microbatch + if ub_split_rs: + tp_world_size = get_distributed_world_size(tp_group) + if tp_world_size == 1: + ub_split_rs = False # Cast for native AMP inputmat = cast_if_needed(inputmat, activation_dtype) inputmat_no_fp8 = inputmat @@ -1529,7 +1744,19 @@ def forward( fp8_dtype_forward, ) - out = fp8_gemm( + if ub_split_rs: + ub_obj_projout = get_ub("proj_fprop") + out = ub_obj_projout.get_ubuf_output(1) + dim_size = list(inputmat_total.size()) + dim_size[0] = dim_size[0] // tp_world_size + dim_size[1] = weight.size(0) + rs_out = torch.empty(dim_size, dtype=activation_dtype, device=inputmat_total.device) + else: + dim_size = list(inputmat_total.size()) + dim_size[1] = weight.size(0) + out = torch.empty(dim_size, dtype=activation_dtype, device=inputmat_total.device) + + _ = fp8_gemm( weight_fp8, fp8_meta["scaling_fwd"].scale_inv, tex.FP8FwdTensors.GEMM1_WEIGHT, @@ -1543,6 +1770,10 @@ def forward( bias=bias, use_bias=use_bias, use_split_accumulator=_2X_ACC_FPROP, + out=out, + ub_algo=tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS if ub_split_rs else None, + ub=ub_obj_projout if ub_split_rs else None, + extra_output_tensor=rs_out if ub_split_rs else None, ) else: # Cast for native AMP @@ -1557,13 +1788,29 @@ def forward( fp8_meta["scaling_fwd"].amax_history[0][tex.FP8FwdTensors.GEMM1_WEIGHT] = \ torch.amax(weight).float() - out, _, _ = gemm( + if ub_split_rs: + ub_obj_projout = get_ub("proj_fprop") + out = ub_obj_projout.get_ubuf_output(1) + dim_size = list(inputmat_total.size()) + dim_size[0] = dim_size[0] // tp_world_size + dim_size[1] = weight.size(0) + rs_out = torch.empty(dim_size, dtype=activation_dtype, device=inputmat_total.device) + else: + dim_size = list(inputmat_total.size()) + dim_size[1] = weight.size(0) + out = torch.empty(dim_size, dtype=activation_dtype, device=inputmat_total.device) + + _, _, _ = gemm( weight, inputmat_total, activation_dtype, get_workspace(), bias=bias, use_bias=use_bias, + out=out, + ub_algo=tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS if ub_split_rs else None, + ub=ub_obj_projout if ub_split_rs else None, + extra_output_tensor=rs_out if ub_split_rs else None, ) if is_grad_enabled: @@ -1586,11 +1833,14 @@ def forward( ctx.inp_shape = inp.shape 
ctx.parallel_mode = parallel_mode ctx.tp_group = tp_group + ctx.ub_split_ag = ub_split_ag ctx.tp_size = tp_size ctx.requires_dgrad = inp.requires_grad # Row Parallel Linear - if parallel_mode == "row" and sequence_parallel: + if ub_split_rs: + out = rs_out + elif parallel_mode == "row" and sequence_parallel: out, _ = reduce_scatter_along_first_dim(out, tp_group) elif parallel_mode == "row" and tensor_parallel: out, _ = allreduce(out, tp_group) @@ -1614,6 +1864,14 @@ def backward( fwd_scale_inverses, ) = ctx.saved_tensors + if ctx.ub_split_ag: + tp_world_size = get_distributed_world_size(ctx.tp_group) + if tp_world_size == 1: + ctx.ub_split_ag = False + if ctx.ub_split_ag: + dim_size = list(grad_output.size()) + dim_size[0] = dim_size[0] * tp_world_size + ctx.ub_obj_gradout = get_ub("proj_dgrad") ( grad_output, grad_output_c, @@ -1667,6 +1925,8 @@ def backward( ctx.activation_dtype, get_workspace(), use_split_accumulator=_2X_ACC_DGRAD, + ub_algo=tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG if ctx.ub_split_ag else None, + ub=ctx.ub_obj_gradout if ctx.ub_split_ag else None, ) else: dgrad, _, _ = gemm( @@ -1676,6 +1936,8 @@ def backward( get_workspace(), layout="NN", grad=True, + ub_algo=tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG if ctx.ub_split_ag else None, + ub=ctx.ub_obj_gradout if ctx.ub_split_ag else None, ) # Overlap dgrad-RS/AR with wgrad @@ -1691,6 +1953,8 @@ def backward( if ctx.fp8: # WGRAD if not ctx.fp8_meta["recipe"].override_linear_precision.wgrad: + if ctx.ub_split_ag: + grad_output_t = tex.fp8_transpose(grad_output_c, fp8_dtype_backward) wgrad = fp8_gemm( inputmat_t_total, fwd_scale_inverses, @@ -1757,6 +2021,8 @@ def backward( None, None, None, + None, + None, ) @@ -1838,6 +2104,8 @@ def __init__( parallel_mode: Optional[str] = None, skip_weight_param_allocation: bool = False, parameters_split: Optional[Tuple[str, ...]] = None, + ub_split_rs: bool = False, + ub_split_ag: bool = False, ) -> None: super().__init__() self.in_features = in_features @@ -1847,6 +2115,13 @@ def __init__( self.return_bias = return_bias self.apply_bias = bias and not return_bias self.parameters_split = parameters_split + self.ub_split_rs = ub_split_rs + self.ub_split_ag = ub_split_ag + + if ub_split_rs or ub_split_ag: + assert ( + tex.userbuf_comm_available() + ), "Userbuffer communication backend not available." if tp_group is None: self.tp_size = tp_size @@ -2028,6 +2303,8 @@ def forward( self.activation_dtype, self.parallel_mode, torch.is_grad_enabled(), + self.ub_split_rs, + self.ub_split_ag, ) out = linear_fn(*args) @@ -2078,6 +2355,10 @@ def forward( fwd_ln_sm_margin: int, bwd_ln_sm_margin: int, zero_centered_gamma: bool, + ub_bulk_wgrad: bool, + ub_bulk_dgrad: bool, + ub_split_rs: bool, + ub_split_ag: bool, ) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]: # Make sure input dimensions are compatible in_features = ln_weight.numel() @@ -2094,6 +2375,18 @@ def forward( ln_weight = cast_if_needed(ln_weight, activation_dtype) ln_bias = cast_if_needed(ln_bias, activation_dtype) + if ub_split_ag: + tp_world_size = get_distributed_world_size(tp_group) + if tp_world_size == 1 or (not is_grad_enabled) or return_layernorm_output: + ub_split_ag = False + if ub_split_ag: + ub_obj_lnout = get_ub("fc1_fprop") + ln_out = ub_obj_lnout.get_ubuf_output(0) + if ub_split_rs: + tp_world_size = get_distributed_world_size(tp_group) + if tp_world_size == 1: + ub_split_rs = False + # If residual connection is after LN, we need `ln_out` # tensor in higher precision, this comes at the cost # of an extra fp8 cast. 
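The row-parallel projections above (proj_fprop, fc2_fprop) all follow the same split-pipelined reduce-scatter pattern: the GEMM streams its output chunks into the registered userbuffer while the communicator reduce-scatters finished chunks, and the caller reads the final result from a separately allocated tensor. A condensed sketch of that call pattern, simplified from _Linear.forward above (fp8, bias and error handling omitted; the gemm import path is an assumption):

    import torch
    import transformer_engine_extensions as tex
    from transformer_engine.pytorch.module import get_ub, get_workspace
    from transformer_engine.pytorch.cpp_extensions import gemm  # import path assumed

    def row_parallel_proj(weight, inputmat_total, tp_world_size, activation_dtype):
        ub_obj = get_ub("proj_fprop")
        out = ub_obj.get_ubuf_output(1)   # full-size GEMM output lives in the userbuffer
        rs_out = torch.empty(             # reduce-scattered result returned to the caller
            (inputmat_total.size(0) // tp_world_size, weight.size(0)),
            dtype=activation_dtype, device=inputmat_total.device)
        gemm(weight, inputmat_total, activation_dtype, get_workspace(),
             out=out,
             ub_algo=tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS,
             ub=ub_obj,
             extra_output_tensor=rs_out)
        return rs_out  # already reduce-scattered; no separate collective is needed

The extra output tensor exists because the userbuffer may be uint8 or bfloat16 depending on how it was registered, while the reduce-scattered result must come back in the activation dtype expected by the caller.
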
@@ -2101,7 +2394,9 @@ def forward( fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) if not return_layernorm_output: if is_grad_enabled: - ln_out, mu, rsigma = layernorm_fwd_fp8( + if not ub_split_ag: + ln_out = torch.empty_like(inputmat, dtype=torch.uint8) + _, mu, rsigma = layernorm_fwd_fp8( inputmat, ln_weight, ln_bias, @@ -2111,6 +2406,7 @@ def forward( fp8_dtype_forward, fwd_ln_sm_margin, zero_centered_gamma, + ln_out = ln_out, ) else: ln_out = layernorm_fwd_fp8_inf( @@ -2135,9 +2431,15 @@ def forward( ) else: if is_grad_enabled: - ln_out, mu, rsigma = tex.layernorm_fwd( - inputmat, ln_weight, ln_bias, eps, fwd_ln_sm_margin, zero_centered_gamma - ) + if ub_split_ag: + _, mu, rsigma = tex.layernorm_fwd_noalloc( + inputmat, ln_weight, ln_bias, ln_out, eps, + fwd_ln_sm_margin, zero_centered_gamma + ) + else: + ln_out, mu, rsigma = tex.layernorm_fwd( + inputmat, ln_weight, ln_bias, eps, fwd_ln_sm_margin, zero_centered_gamma + ) else: ln_out, mu, rsigma = layernorm_fwd_inf( inputmat, ln_weight, ln_bias, eps, zero_centered_gamma @@ -2145,7 +2447,10 @@ def forward( ln_out_return = ln_out # Column Parallel Linear - if set_parallel_mode and sequence_parallel: + if ub_split_ag: + ln_out_total = ub_obj_lnout.get_ubuf_output(1) + ln_out = torch.empty_like(ln_out) + elif set_parallel_mode and sequence_parallel: ln_out_total, _ = gather_along_first_dim(ln_out, tp_group) else: ln_out_total = ln_out @@ -2208,6 +2513,9 @@ def forward( bias=fc1_bias, use_bias=use_fc1_bias, use_split_accumulator=_2X_ACC_FPROP, + ub_algo=tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG if ub_split_ag else None, + ub=ub_obj_lnout if ub_split_ag else None, + extra_output_tensor=ln_out if ub_split_ag else None, ) gelu_out = fp8_gelu( @@ -2217,7 +2525,19 @@ def forward( fp8_dtype_forward, ) - fc2_out = fp8_gemm( + if ub_split_rs: + ub_obj_fc2out = get_ub("fc2_fprop") + fc2_out = ub_obj_fc2out.get_ubuf_output(1) + dim_size = list(gelu_out.size()) + dim_size[0] = dim_size[0] // tp_world_size + dim_size[1] = fc2_weight.size(0) + rs_out = torch.empty(dim_size, dtype=activation_dtype, device=gelu_out.device) + else: + dim_size = list(gelu_out.size()) + dim_size[1] = fc2_weight.size(0) + fc2_out = torch.empty(dim_size, dtype=activation_dtype, device=gelu_out.device) + + _ = fp8_gemm( fc2_weight_fp8, fp8_meta["scaling_fwd"].scale_inv, tex.FP8FwdTensors.GEMM2_WEIGHT, @@ -2231,6 +2551,10 @@ def forward( bias=fc2_bias, use_bias=use_fc2_bias, use_split_accumulator=_2X_ACC_FPROP, + out=fc2_out, + ub_algo=tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS if ub_split_rs else None, + ub=ub_obj_fc2out if ub_split_rs else None, + extra_output_tensor=rs_out if ub_split_rs else None, ) else: # Cast for native AMP @@ -2259,6 +2583,9 @@ def forward( bias=fc1_bias, use_bias=(not bias_gelu_nvfusion) and use_fc1_bias, gelu=not bias_gelu_nvfusion, + ub_algo=tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG if ub_split_ag else None, + ub=ub_obj_lnout if ub_split_ag else None, + extra_output_tensor=ln_out if ub_split_ag else None, ) if bias_gelu_nvfusion: @@ -2276,14 +2603,30 @@ def forward( fp8_meta["scaling_fwd"].amax_history[0][tex.FP8FwdTensors.GEMM2_WEIGHT] = \ torch.amax(fc2_weight).float() - fc2_out, _, _ = gemm( + if ub_split_rs: + ub_obj_fc2out = get_ub("fc2_fprop") + fc2_out = ub_obj_fc2out.get_ubuf_output(1) + dim_size = list(gelu_out.size()) + dim_size[0] = dim_size[0] // tp_world_size + dim_size[1] = fc2_weight.size(0) + rs_out = torch.empty(dim_size, dtype=activation_dtype, device=gelu_out.device) + else: + dim_size = list(gelu_out.size()) + 
dim_size[1] = fc2_weight.size(0) + fc2_out = torch.empty(dim_size, dtype=activation_dtype, device=gelu_out.device) + _, _, _ = gemm( fc2_weight, gelu_out, activation_dtype, get_workspace(), bias=fc2_bias, use_bias=use_fc2_bias, + out=fc2_out, + ub_algo=tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS if ub_split_rs else None, + ub=ub_obj_fc2out if ub_split_rs else None, + extra_output_tensor=rs_out if ub_split_rs else None, ) + if is_grad_enabled: ctx.save_for_backward( inputmat, @@ -2317,10 +2660,15 @@ def forward( ctx.set_parallel_mode = set_parallel_mode ctx.bwd_ln_sm_margin = bwd_ln_sm_margin ctx.zero_centered_gamma = zero_centered_gamma + ctx.ub_bulk_wgrad = ub_bulk_wgrad + ctx.ub_bulk_dgrad = ub_bulk_dgrad + ctx.ub_split_ag = ub_split_ag ctx.requires_dgrad = inp.requires_grad # Row Parallel Linear - if set_parallel_mode and sequence_parallel: + if ub_split_rs: + fc2_out = rs_out + elif set_parallel_mode and sequence_parallel: fc2_out, _ = reduce_scatter_along_first_dim(fc2_out, tp_group) elif set_parallel_mode and tensor_parallel: fc2_out, _ = allreduce(fc2_out, tp_group) @@ -2356,6 +2704,24 @@ def backward( fwd_scale_inverses, ) = ctx.saved_tensors + if ctx.ub_bulk_dgrad: + tp_world_size = get_distributed_world_size(ctx.tp_group) + if tp_world_size == 1: + ctx.ub_bulk_dgrad = False + if ctx.ub_bulk_dgrad: + dim_size = list(ln_out.size()) + dim_size[0] = dim_size[0] * tp_world_size + ub_obj_lnout = get_ub("fc1_dgrad") + ub_obj_lnout.copy_input_to_ubuf(ln_out, 1) + if ctx.ub_split_ag: + tp_world_size = get_distributed_world_size(ctx.tp_group) + if tp_world_size == 1: + ctx.ub_split_ag = False + if ctx.ub_split_ag: + dim_size = list(grad_outputs[0].size()) + dim_size[0] = dim_size[0] * tp_world_size + ctx.ub_obj_gradout = get_ub("fc2_dgrad") + ctx.use_bias = ctx.use_fc2_bias # For grad_output_preprocess ( grad_output, @@ -2365,10 +2731,13 @@ def backward( ) = TransformerEngineBaseModule.grad_output_preprocess( ctx, grad_outputs[0], True ) - + if ctx.ub_bulk_wgrad: + tp_world_size = get_distributed_world_size(ctx.tp_group) + if tp_world_size == 1: + ctx.ub_bulk_wgrad = False # Column Parallel Linear # Overlap input AG with dgrad - if ctx.set_parallel_mode and ctx.sequence_parallel: + if (not ctx.ub_bulk_dgrad) and ctx.set_parallel_mode and ctx.sequence_parallel: ln_out_total, handle = gather_along_first_dim( ln_out, ctx.tp_group, async_op=True ) @@ -2403,8 +2772,11 @@ def backward( ctx.activation_dtype, get_workspace(), use_split_accumulator=_2X_ACC_DGRAD, + ub_algo=tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG if ctx.ub_split_ag else None, + ub=ctx.ub_obj_gradout if ctx.ub_split_ag else None, ) - + if ctx.ub_split_ag: + grad_output_t = tex.fp8_transpose(grad_output_c, fp8_dtype_backward) # FC2 WGRAD if not ctx.fp8_meta["recipe"].override_linear_precision.wgrad: if fc2_weight.requires_grad: @@ -2469,8 +2841,17 @@ def backward( ) dgelu_t = None + fc1_dgrad_size = list(dgelu.size()) + fc1_dgrad_size[1] = fc1_weight.size(1) + if ctx.ub_bulk_wgrad: # allocate dgrad output + ub_obj_dgrad = get_ub("fc1_wgrad") + fc1_dgrad = ub_obj_dgrad.get_ubuf_output(1) # AllGather output + else: + fc1_dgrad = torch.empty( + fc1_dgrad_size, dtype=ctx.activation_dtype, device=fc1_weight.device + ) # FC1 DGRAD: Unconditional - fc1_dgrad = fp8_gemm( + _ = fp8_gemm( fc1_weight_t_fp8, fwd_scale_inverses, tex.FP8FwdTensors.GEMM1_WEIGHT, @@ -2481,7 +2862,10 @@ def backward( fp8_dtype_backward, ctx.activation_dtype, get_workspace(), + out=fc1_dgrad, use_split_accumulator=_2X_ACC_DGRAD, + ub_algo=tex.UbufOverlapAlgo.BULK_OVERLAP_AG 
if ctx.ub_bulk_dgrad else None, + ub=ub_obj_lnout if ctx.ub_bulk_dgrad else None ) else: # FC2 DGRAD; Unconditional @@ -2494,6 +2878,8 @@ def backward( gelu=not ctx.bias_gelu_nvfusion, grad=True, gelu_input=fc1_out, + ub_algo=tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG if ctx.ub_split_ag else None, + ub=ctx.ub_obj_gradout if ctx.ub_split_ag else None, ) # FC2 WGRAD @@ -2515,22 +2901,38 @@ def backward( else: dgelu = fc2_dgrad + fc1_dgrad_size = list(dgelu.size()) + fc1_dgrad_size[1] = fc1_weight.size(1) + if ctx.ub_bulk_wgrad: # allocate dgrad output + ub_obj_dgrad = get_ub("fc1_wgrad") + fc1_dgrad = ub_obj_dgrad.get_ubuf_output(1) # AllGather output + else: + fc1_dgrad = torch.empty( + fc1_dgrad_size, dtype=ctx.activation_dtype, device=fc1_weight.device + ) # FC1 DGRAD: Unconditional - fc1_dgrad, _, _ = gemm( + _, _, _ = gemm( fc1_weight, dgelu, ctx.activation_dtype, get_workspace(), + out=fc1_dgrad, layout="NN", grad=True, + ub_algo=tex.UbufOverlapAlgo.BULK_OVERLAP_AG if ctx.ub_bulk_dgrad else None, + ub=ub_obj_lnout if ctx.ub_bulk_dgrad else None ) + if ctx.ub_bulk_dgrad: + ln_out_total = ub_obj_lnout.get_ubuf_output(1) # Overlap dgrad-RS/AR with wgrad if ctx.set_parallel_mode and ctx.sequence_parallel: - handle.wait() - fc1_dgrad, handle = reduce_scatter_along_first_dim( - fc1_dgrad, ctx.tp_group, async_op=True - ) + if not ctx.ub_bulk_dgrad: + handle.wait() + if not ctx.ub_bulk_wgrad: + fc1_dgrad, handle = reduce_scatter_along_first_dim( + fc1_dgrad, ctx.tp_group, async_op=True + ) elif ctx.set_parallel_mode and ctx.tensor_parallel: fc1_dgrad, handle = allreduce(fc1_dgrad, ctx.tp_group, async_op=True) @@ -2555,6 +2957,9 @@ def backward( if ctx.fuse_wgrad_accumulation else None, use_split_accumulator=_2X_ACC_WGRAD, + ub_algo=tex.UbufOverlapAlgo.BULK_OVERLAP_RS + if ctx.ub_bulk_wgrad else None, + ub=ub_obj_dgrad if ctx.ub_bulk_wgrad else None, ) else: ln_out_total_c = cast_from_fp8( @@ -2575,6 +2980,9 @@ def backward( out=fc1_weight.main_grad if ctx.fuse_wgrad_accumulation else None, + ub_algo=tex.UbufOverlapAlgo.BULK_OVERLAP_RS + if ctx.ub_bulk_wgrad else None, + ub=ub_obj_dgrad if ctx.ub_bulk_wgrad else None, ) else: # FC1 WGRAD @@ -2588,6 +2996,8 @@ def backward( use_bias=not ctx.bias_gelu_nvfusion, accumulate=accumulate_wgrad_into_param_main_grad, out=fc1_weight.main_grad if ctx.fuse_wgrad_accumulation else None, + ub_algo=tex.UbufOverlapAlgo.BULK_OVERLAP_RS if ctx.ub_bulk_wgrad else None, + ub=ub_obj_dgrad if ctx.ub_bulk_wgrad else None ) if ctx.bias_gelu_nvfusion: @@ -2596,7 +3006,9 @@ def backward( fc1_wgrad, fc1_bias_grad, _ = fc1_wgrad_outputs # Column Parallel Linear - if ctx.set_parallel_mode and ctx.tensor_parallel and handle is not None: + if ctx.ub_bulk_wgrad: + fc1_dgrad = ub_obj_dgrad.get_ubuf_output(0) # Reduce-scatter output + elif ctx.set_parallel_mode and ctx.tensor_parallel and handle is not None: handle.wait() # LayerNorm gradient @@ -2643,6 +3055,10 @@ def backward( None, None, None, + None, + None, + None, + None, ) @@ -2741,6 +3157,10 @@ def __init__( micro_batch_size: Optional[int] = None, set_parallel_mode: bool = False, zero_centered_gamma: bool = False, + ub_bulk_wgrad: bool = False, + ub_bulk_dgrad: bool = False, + ub_split_rs: bool = False, + ub_split_ag: bool = False, ) -> None: super().__init__() @@ -2752,6 +3172,15 @@ def __init__( self.bias_gelu_nvfusion = bool(int(os.getenv("NVTE_BIAS_GELU_NVFUSION", "1"))) self.set_parallel_mode = set_parallel_mode self.zero_centered_gamma = zero_centered_gamma + self.ub_bulk_wgrad = ub_bulk_wgrad + self.ub_bulk_dgrad = 
ub_bulk_dgrad + self.ub_split_rs = ub_split_rs + self.ub_split_ag = ub_split_ag + + if ub_bulk_wgrad or ub_bulk_dgrad or ub_split_rs or ub_split_ag: + assert ( + tex.userbuf_comm_available() + ), "Userbuffer communication backend not available." if tp_group is None: self.tp_size = tp_size @@ -2948,6 +3377,10 @@ def forward( self.fwd_ln_sm_margin, self.bwd_ln_sm_margin, self.zero_centered_gamma, + self.ub_bulk_wgrad, + self.ub_bulk_dgrad, + self.ub_split_rs, + self.ub_split_ag, ) out = fwd_fn(*args) diff --git a/transformer_engine/pytorch/transformer.py b/transformer_engine/pytorch/transformer.py index 83582e2aae..52d303e8f4 100644 --- a/transformer_engine/pytorch/transformer.py +++ b/transformer_engine/pytorch/transformer.py @@ -15,6 +15,7 @@ from flash_attn.flash_attn_interface import flash_attn_unpadded_func +import transformer_engine_extensions as tex from transformer_engine.pytorch.module import LayerNormLinear, Linear, LayerNormMLP, LayerNorm from transformer_engine.pytorch.jit import ( set_jit_fusion_options, @@ -495,6 +496,10 @@ def __init__( fuse_qkv_params: bool = False, zero_centered_gamma: bool = False, qkv_weight_interleaved: bool = True, + ub_bulk_wgrad: bool = False, + ub_bulk_dgrad: bool = False, + ub_split_rs: bool = False, + ub_split_ag: bool = False, bias: bool = True, ) -> None: super().__init__() @@ -547,6 +552,9 @@ def __init__( return_layernorm_output=return_layernorm_output, parameters_split=("query_", "key_", "value_") if not fuse_qkv_params else None, zero_centered_gamma=zero_centered_gamma, + ub_bulk_wgrad=ub_bulk_wgrad, + ub_bulk_dgrad=ub_bulk_dgrad, + ub_split_ag=ub_split_ag, **common_gemm_kwargs, ) else: @@ -572,6 +580,9 @@ def __init__( parallel_mode=qkv_parallel_mode, return_layernorm_output=return_layernorm_output, zero_centered_gamma=zero_centered_gamma, + ub_bulk_wgrad=ub_bulk_wgrad, + ub_bulk_dgrad=ub_bulk_dgrad, + ub_split_ag=ub_split_ag, **common_gemm_kwargs, ) else: @@ -616,6 +627,8 @@ def __init__( bias=bias, return_bias=True, parallel_mode="row" if set_parallel_mode else None, + ub_split_rs=ub_split_rs, + ub_split_ag=ub_split_ag, **common_gemm_kwargs, ) @@ -911,6 +924,12 @@ class TransformerLayer(torch.nn.Module): `set_tensor_parallel_group(tp_group)` method on the initialized module before the forward pass to supply the tensor parallel group needed for tensor and sequence parallel collectives. + ub_bulk_wgrad: bool, default = False + Bulk overlap UserBuffer ReduceScatter | WGRAD GEMM + ub_bulk_dgrad: bool, default = False + Bulk overlap UserBuffer AllGather | DGRAD GEMM + ub_split_ag: bool, default = False + Split pipelined overlap UserBuffer AllGather -> GEMM Optimization parameters ----------------------- @@ -970,6 +989,7 @@ def __init__( fuse_qkv_params: bool = False, zero_centered_gamma: bool = False, qkv_weight_interleaved: bool = True, + ub_tp_comm_overlap: bool = False, bias: bool = True, ) -> None: super().__init__() @@ -980,6 +1000,16 @@ def __init__( category=DeprecationWarning, ) + if ub_tp_comm_overlap: + assert ( + tex.userbuf_comm_available() + ), "Userbuffer communication backend not available." 
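The bulk-overlap flags threaded through the backward passes above (ub_bulk_dgrad, ub_bulk_wgrad) pair each backward GEMM with an independent collective on a shared userbuffer: DGRAD runs while ln_out is all-gathered, and WGRAD runs while dgrad is reduce-scattered. A condensed sketch of that sequence, simplified from the LayerNormLinear/LayerNormMLP backward wiring (fp8, gradient accumulation and the tp_world_size==1 fallbacks omitted; the gemm import path and its 3-tuple return are assumptions based on the call sites above):

    import transformer_engine_extensions as tex
    from transformer_engine.pytorch.module import get_ub, get_workspace
    from transformer_engine.pytorch.cpp_extensions import gemm  # import path assumed

    def column_parallel_backward(weight, grad_output, ln_out, activation_dtype):
        ub_ag = get_ub("qkv_dgrad")          # bulk AllGather communicator for ln_out
        ub_rs = get_ub("qkv_wgrad")          # bulk ReduceScatter communicator for dgrad
        ub_ag.copy_input_to_ubuf(ln_out, 1)  # stage the local ln_out shard for gathering
        dgrad = ub_rs.get_ubuf_output(1)     # DGRAD writes into the RS communicator's buffer

        # DGRAD GEMM; the all-gather of ln_out runs concurrently on the comm stream.
        gemm(weight, grad_output, activation_dtype, get_workspace(),
             out=dgrad, layout="NN", grad=True,
             ub_algo=tex.UbufOverlapAlgo.BULK_OVERLAP_AG, ub=ub_ag)
        ln_out_total = ub_ag.get_ubuf_output(1)  # gathered activations, input to WGRAD

        # WGRAD GEMM; the reduce-scatter of dgrad runs concurrently on the comm stream.
        wgrad, grad_bias, _ = gemm(ln_out_total, grad_output, activation_dtype,
                                   get_workspace(), layout="NT", grad=True,
                                   ub_algo=tex.UbufOverlapAlgo.BULK_OVERLAP_RS, ub=ub_rs)
        dgrad = ub_rs.get_ubuf_output(0)         # reduce-scattered dgrad shard
        return dgrad, wgrad, grad_bias

Writing dgrad directly into the ReduceScatter communicator's buffer is what lets the WGRAD-phase reduce-scatter consume it without an extra copy.
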
+ + ub_tp_comm_overlap = ub_tp_comm_overlap and bool(int(os.getenv("NVTE_UB_OVERLAP", "1"))) + ub_bulk_wgrad = ub_tp_comm_overlap and bool(int(os.getenv("NVTE_UB_BULK_WGRAD", "1"))) + ub_bulk_dgrad = ub_tp_comm_overlap and bool(int(os.getenv("NVTE_UB_BULK_DGRAD", "1"))) + ub_split_ag = ub_tp_comm_overlap and bool(int(os.getenv("NVTE_UB_SPLIT_AG", "1"))) + ub_split_rs = ub_tp_comm_overlap and bool(int(os.getenv("NVTE_UB_SPLIT_RS", "1"))) bias_dropout_fusion = bool(int(os.getenv("NVTE_BIAS_DROPOUT_FUSION", "1"))) self.layer_number = layer_number self.output_layernorm = output_layernorm @@ -1037,6 +1067,10 @@ def __init__( "fuse_qkv_params": fuse_qkv_params, "zero_centered_gamma": zero_centered_gamma, "qkv_weight_interleaved" : qkv_weight_interleaved, + "ub_bulk_wgrad" : ub_bulk_wgrad, + "ub_bulk_dgrad" : ub_bulk_dgrad, + "ub_split_ag" : ub_split_ag, + "ub_split_rs" : ub_split_rs, } self.self_attention = MultiHeadAttention( @@ -1080,6 +1114,10 @@ def __init__( micro_batch_size=micro_batch_size, set_parallel_mode=set_parallel_mode, zero_centered_gamma=zero_centered_gamma, + ub_bulk_wgrad=ub_bulk_wgrad, + ub_bulk_dgrad=ub_bulk_dgrad, + ub_split_rs=ub_split_rs, + ub_split_ag=ub_split_ag, ) self.hidden_dropout = hidden_dropout diff --git a/transformer_engine/tensorflow/csrc/extensions.cu b/transformer_engine/tensorflow/csrc/extensions.cu index aa2ad0b3ba..8cda79a7ed 100644 --- a/transformer_engine/tensorflow/csrc/extensions.cu +++ b/transformer_engine/tensorflow/csrc/extensions.cu @@ -568,7 +568,7 @@ py::object TFE_Py_TeGemm_wrapper( nvte_cublas_gemm(a_tensor.data(), b_tensor.data(), d_tensor.data(), bias_tensor.data(), gelu_input_tensor.data(), transa, transb, grad, workspace_tensor.data(), accumulate, - use_split_accumulate, stream); + use_split_accumulate, 0, stream); auto d_eager = CreateTensor(d_ptr, d_shape, otype); if (use_gelu && !grad) { From 7bf886d1e9cfc23146f0d6da4db7edfcabad3338 Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Fri, 21 Apr 2023 00:25:12 -0700 Subject: [PATCH 19/68] Move userbuffer to PyTorch (#162) * Initial refactor; linker error Signed-off-by: Kirthi Shankar Sivamani * Fix linking issue and make mpi conditional Signed-off-by: Kirthi Shankar Sivamani * Fix TF/JAX build Signed-off-by: Kirthi Shankar Sivamani * Use max SMs at the last RS chunk in pipelined overlap Co-authored-by: Sangkug Lym Signed-off-by: Kirthi Shankar Sivamani * lint Signed-off-by: Kirthi Shankar Sivamani * Make userbuffers support opt-in Decouple userbuffers from MPI. Refactor MPI handling in build system. Standardize names to "userbuffers". 
Signed-off-by: Tim Moon * Lint Signed-off-by: Tim Moon --------- Signed-off-by: Kirthi Shankar Sivamani Signed-off-by: Tim Moon Co-authored-by: Sangkug Lym Co-authored-by: Tim Moon --- qa/L0_cppunittest/test.sh | 6 +- setup.py | 22 +++--- tests/cpp/CMakeLists.txt | 7 +- tests/cpp/operator/CMakeLists.txt | 4 - transformer_engine/CMakeLists.txt | 4 +- transformer_engine/common/CMakeLists.txt | 74 ++++++++----------- transformer_engine/common/__init__.py | 15 ++-- .../pytorch/csrc/comm_gemm_overlap.h | 13 +++- transformer_engine/pytorch/csrc/extensions.cu | 12 +-- .../pytorch/csrc/userbuffers/CMakeLists.txt | 33 +++++++++ .../csrc/userbuffers}/userbuffers-host.cpp | 17 +++-- .../csrc/userbuffers}/userbuffers.cu | 2 +- .../csrc/userbuffers}/userbuffers.h | 2 +- 13 files changed, 117 insertions(+), 94 deletions(-) create mode 100644 transformer_engine/pytorch/csrc/userbuffers/CMakeLists.txt rename transformer_engine/{common/comm_gemm_overlap => pytorch/csrc/userbuffers}/userbuffers-host.cpp (96%) rename transformer_engine/{common/comm_gemm_overlap => pytorch/csrc/userbuffers}/userbuffers.cu (99%) rename transformer_engine/{common/include/transformer_engine => pytorch/csrc/userbuffers}/userbuffers.h (99%) diff --git a/qa/L0_cppunittest/test.sh b/qa/L0_cppunittest/test.sh index 73a27a1fcd..6333f33fb1 100644 --- a/qa/L0_cppunittest/test.sh +++ b/qa/L0_cppunittest/test.sh @@ -9,11 +9,7 @@ set -e TE_LIB_PATH=`pip show transformer-engine | grep Location | cut -d ' ' -f 2` export LD_LIBRARY_PATH=$TE_LIB_PATH:$LD_LIBRARY_PATH -# Find MPI -MPI_HOME=${MPI_HOME:-/usr/local/mpi} -NVTE_MPI_INCLUDE="$MPI_HOME/lib" - cd $TE_PATH/tests/cpp -cmake -GNinja -Bbuild -DNVTE_MPI_INCLUDE=$NVTE_MPI_INCLUDE . +cmake -GNinja -Bbuild . cmake --build build ctest --test-dir build -j4 diff --git a/setup.py b/setup.py index decdce51a4..cb0c37fe3a 100644 --- a/setup.py +++ b/setup.py @@ -21,9 +21,10 @@ te_version = f.readline() CUDA_HOME = os.environ.get("CUDA_HOME", "/usr/local/cuda") -MPI_HOME = os.environ.get("MPI_HOME", "/usr/local/mpi") -NVTE_MPI_FOUND = os.path.exists(MPI_HOME) -NVTE_MPI_INCLUDE = os.path.join(MPI_HOME, "include") +NVTE_WITH_USERBUFFERS = int(os.environ.get("NVTE_WITH_USERBUFFERS", "0")) +if NVTE_WITH_USERBUFFERS: + MPI_HOME = os.environ.get("MPI_HOME", "") + assert MPI_HOME, "MPI_HOME must be set if NVTE_WITH_USERBUFFERS=1" def get_cuda_bare_metal_version(cuda_dir): raw_output = subprocess.check_output( @@ -70,8 +71,8 @@ def extra_compiler_flags(): "--expt-extended-lambda", "--use_fast_math", ] - if NVTE_MPI_FOUND: - extra_flags.append("-DNVTE_MPI_FOUND") + if NVTE_WITH_USERBUFFERS: + extra_flags.append("-DNVTE_WITH_USERBUFFERS") return extra_flags @@ -105,8 +106,9 @@ def make_abs_path(l): "transformer_engine/common/include", "transformer_engine/pytorch/csrc", ] -if (framework in ("all", "pytorch")) and NVTE_MPI_FOUND: - include_dirs.append(NVTE_MPI_INCLUDE) +if NVTE_WITH_USERBUFFERS: + if MPI_HOME: + include_dirs.append(os.path.join(MPI_HOME, "include")) include_dirs = make_abs_path(include_dirs) args = sys.argv.copy() @@ -165,9 +167,7 @@ def run(self, extensions): self.pytorch_build_extensions.run() def cmake_flags(self): - if not NVTE_MPI_FOUND: - return [] - return ["-DNVTE_MPI_FOUND=1", f"-DNVTE_MPI_INCLUDE={NVTE_MPI_INCLUDE}"] + return [] @staticmethod def install_requires(): @@ -338,6 +338,8 @@ def __init__(self, *args, **kwargs) -> None: self.dlfw_builder.append(functor(*args, **kwargs)) flags = [] + if NVTE_WITH_USERBUFFERS: + flags.append('-DNVTE_WITH_USERBUFFERS=ON') for builder in 
self.dlfw_builder: flags = flags + builder.cmake_flags() diff --git a/tests/cpp/CMakeLists.txt b/tests/cpp/CMakeLists.txt index 631b356fec..8bdfb89df2 100644 --- a/tests/cpp/CMakeLists.txt +++ b/tests/cpp/CMakeLists.txt @@ -19,7 +19,7 @@ add_subdirectory(../../3rdparty/googletest ${PROJECT_BINARY_DIR}/googletest) enable_testing() -include_directories(${gtest_SOURCE_DIR}/include ${gtest_SOURCE_DIR}) +include_directories(${gtest_SOURCE_DIR}/include ${gtest_SOURCE_DIR}) if(NOT DEFINED TE_LIB_PATH) execute_process(COMMAND bash -c "pip show transformer-engine | grep Location | cut -d ' ' -f 2 | tr -d '\n'" @@ -28,11 +28,6 @@ endif() find_library(TE_LIB NAMES transformer_engine PATHS ${TE_LIB_PATH} ENV TE_LIB_PATH REQUIRED) -if(EXISTS ${NVTE_MPI_INCLUDE}) - find_library(MPI_LIB NAMES mpi PATHS ${NVTE_MPI_INCLUDE} REQUIRED) - message(STATUS "Found MPI library: ${MPI_LIB}") -endif() - message(STATUS "Found transformer_engine library: ${TE_LIB}") include_directories(../../transformer_engine/common/include) include_directories(${CMAKE_SOURCE_DIR}) diff --git a/tests/cpp/operator/CMakeLists.txt b/tests/cpp/operator/CMakeLists.txt index a77cf98a73..65a7ccaebd 100644 --- a/tests/cpp/operator/CMakeLists.txt +++ b/tests/cpp/operator/CMakeLists.txt @@ -19,10 +19,6 @@ add_executable(test_operator list(APPEND test_operator_LINKER_LIBS CUDA::cudart GTest::gtest_main ${TE_LIB}) -if(EXISTS ${NVTE_MPI_INCLUDE}) - list(APPEND test_operator_LINKER_LIBS ${MPI_LIB}) -endif() - target_link_libraries(test_operator PUBLIC ${test_operator_LINKER_LIBS}) target_compile_options(test_operator PRIVATE -O2) diff --git a/transformer_engine/CMakeLists.txt b/transformer_engine/CMakeLists.txt index a03cd42806..336f41be70 100644 --- a/transformer_engine/CMakeLists.txt +++ b/transformer_engine/CMakeLists.txt @@ -8,7 +8,6 @@ if(NOT DEFINED CMAKE_CUDA_ARCHITECTURES) set(CMAKE_CUDA_ARCHITECTURES 70 80 89 90) endif() - set(CMAKE_CXX_STANDARD 17) set(CMAKE_CUDA_STANDARD 17) set(CMAKE_CUDA_STANDARD_REQUIRED ON) @@ -26,6 +25,9 @@ find_package(Python COMPONENTS Interpreter Development REQUIRED) include_directories(${PROJECT_SOURCE_DIR}) add_subdirectory(common) +if(NVTE_WITH_USERBUFFERS) + add_subdirectory(pytorch/csrc/userbuffers) +endif() option(ENABLE_JAX "Enable JAX in the building workflow." OFF) if(ENABLE_JAX) diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index 7459f77e4f..c5bc6bb0f1 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -2,54 +2,42 @@ # # See LICENSE for license information. 
+# Configure Transformer Engine library set(transformer_engine_SOURCES) -list(APPEND transformer_engine_SOURCES transformer_engine.cpp - transpose/cast_transpose.cu - transpose/transpose.cu - transpose/cast_transpose_fusion.cu - transpose/transpose_fusion.cu - transpose/multi_cast_transpose.cu - activation/gelu.cu - gemm/cublaslt_gemm.cu - layer_norm/ln_api.cpp - layer_norm/ln_bwd_semi_cuda_kernel.cu - layer_norm/ln_fwd_cuda_kernel.cu - rmsnorm/rmsnorm_api.cpp - rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu - rmsnorm/rmsnorm_fwd_cuda_kernel.cu - util/cast.cu - fused_softmax/scaled_masked_softmax.cu - fused_softmax/scaled_upper_triang_masked_softmax.cu) - -if(NVTE_MPI_FOUND) - list(APPEND transformer_engine_SOURCES comm_gemm_overlap/userbuffers.cu - comm_gemm_overlap/userbuffers-host.cpp) -endif() - +list(APPEND transformer_engine_SOURCES + transformer_engine.cpp + transpose/cast_transpose.cu + transpose/transpose.cu + transpose/cast_transpose_fusion.cu + transpose/transpose_fusion.cu + transpose/multi_cast_transpose.cu + activation/gelu.cu + gemm/cublaslt_gemm.cu + layer_norm/ln_api.cpp + layer_norm/ln_bwd_semi_cuda_kernel.cu + layer_norm/ln_fwd_cuda_kernel.cu + rmsnorm/rmsnorm_api.cpp + rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu + rmsnorm/rmsnorm_fwd_cuda_kernel.cu + util/cast.cu + fused_softmax/scaled_masked_softmax.cu + fused_softmax/scaled_upper_triang_masked_softmax.cu) add_library(transformer_engine SHARED ${transformer_engine_SOURCES}) - -target_include_directories(transformer_engine PUBLIC "${CMAKE_CURRENT_SOURCE_DIR}/include") - -list(APPEND transformer_engine_LINKER_LIBS CUDA::cublas CUDA::cudart CUDA::nvToolsExt) -if(NVTE_MPI_FOUND) - list(APPEND transformer_engine_LINKER_LIBS gdrapi) -endif() - -target_link_libraries(transformer_engine PUBLIC ${transformer_engine_LINKER_LIBS}) -target_include_directories(transformer_engine PRIVATE ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) - +target_include_directories(transformer_engine PUBLIC + "${CMAKE_CURRENT_SOURCE_DIR}/include") + +# Configure dependencies +target_link_libraries(transformer_engine PUBLIC + CUDA::cublas + CUDA::cudart + CUDA::nvToolsExt) +target_include_directories(transformer_engine PRIVATE + ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) + +# Compiler options set_source_files_properties(fused_softmax/scaled_masked_softmax.cu fused_softmax/scaled_upper_triang_masked_softmax.cu PROPERTIES COMPILE_OPTIONS "--use_fast_math") - -if(NVTE_MPI_FOUND) - set_source_files_properties(comm_gemm_overlap/userbuffers.cu - comm_gemm_overlap/userbuffers-host.cpp - PROPERTIES - INCLUDE_DIRECTORIES ${NVTE_MPI_INCLUDE} - COMPILE_OPTIONS "$<$:-maxrregcount=64>") -endif() - set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-relaxed-constexpr") set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -O3") diff --git a/transformer_engine/common/__init__.py b/transformer_engine/common/__init__.py index 0a8924f8ed..220bec7003 100644 --- a/transformer_engine/common/__init__.py +++ b/transformer_engine/common/__init__.py @@ -37,8 +37,8 @@ def _load_library(): return ctypes.CDLL(dll_path, mode=ctypes.RTLD_GLOBAL) -def _load_mpi(): - """Load MPI shared library""" +def _load_userbuffers(): + """Load shared library with userbuffers""" system = platform.system() if system == "Linux": @@ -49,15 +49,14 @@ def _load_mpi(): extension = "dll" else: raise RuntimeError(f"Unsupported operating system ({system})") - lib_name = "libmpi." 
+ extension - MPI_HOME = os.environ.get("MPI_HOME", "/usr/local/mpi") - NVTE_MPI_FOUND = os.path.exists(MPI_HOME) - dll_path = os.path.join(MPI_HOME, "lib", lib_name) + lib_name = "libtransformer_engine_userbuffers." + extension + dll_path = get_te_path() + dll_path = os.path.join(dll_path, lib_name) - if NVTE_MPI_FOUND: + if os.path.exists(dll_path): return ctypes.CDLL(dll_path, mode=ctypes.RTLD_GLOBAL) return None -_TE_LIB_CTYPES = _load_mpi() _TE_LIB_CTYPES = _load_library() +_UB_LIB_CTYPES = _load_userbuffers() diff --git a/transformer_engine/pytorch/csrc/comm_gemm_overlap.h b/transformer_engine/pytorch/csrc/comm_gemm_overlap.h index 18863a7858..1e8b96f46b 100644 --- a/transformer_engine/pytorch/csrc/comm_gemm_overlap.h +++ b/transformer_engine/pytorch/csrc/comm_gemm_overlap.h @@ -14,9 +14,10 @@ #include #include #include -#include +#include "userbuffers/userbuffers.h" #define HALF_BYTES 2 +#define UB_MAX_SM 32 #define CHECK_CUDA(call) \ do { \ @@ -174,6 +175,7 @@ struct UbufCommOverlap : torch::CustomClassHolder { char *rs_output_ptr = reinterpret_cast(rs_output.data_ptr()); int ubuf_offset = 0; + int ori_sms = _ub_comm->sms; // Catch up the default torch stream at::cuda::CUDAStream stream_main = at::cuda::getDefaultCUDAStream(); @@ -232,7 +234,8 @@ struct UbufCommOverlap : torch::CustomClassHolder { cudaEventRecord(_start_comm, (cudaStream_t)_stream_compute[last_compute_stream_id])); CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_comm, _start_comm, 0)); - // Communication chunk + // Last communication chunk with max SM + _ub_comm->sms = UB_MAX_SM; reducescatter2_userbuff_stridedoutput(rs_output_ptr, _ub_reg, (_num_splits - 1) * output_chunk_size, m_chunk, n, m, _ub_comm, (cudaStream_t)_stream_comm); @@ -255,7 +258,10 @@ struct UbufCommOverlap : torch::CustomClassHolder { (cudaStream_t)_stream_compute[i % _stream_compute.size()])); CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_comm, _start_comm, 0)); - // Communication chunk + // Communication chunk. 
Uses MAX_SM at the last chunk + if (i == _num_splits-1) { + _ub_comm->sms = UB_MAX_SM; + } reducescatter2_userbuff_stridedoutput(rs_output_ptr, _ub_reg, i * output_chunk_size, m_chunk, n, m, _ub_comm, (cudaStream_t)_stream_comm); @@ -264,6 +270,7 @@ struct UbufCommOverlap : torch::CustomClassHolder { output_buf_chunk_ptr += output_chunk_size * _ubuf.element_size(); } } + _ub_comm->sms = ori_sms; int last_compute_stream_id = (_num_splits + _stream_compute.size() - 1) % _stream_compute.size(); CHECK_CUDA( diff --git a/transformer_engine/pytorch/csrc/extensions.cu b/transformer_engine/pytorch/csrc/extensions.cu index e34c79d980..23330efbf0 100644 --- a/transformer_engine/pytorch/csrc/extensions.cu +++ b/transformer_engine/pytorch/csrc/extensions.cu @@ -5,9 +5,9 @@ ************************************************************************/ #include "extensions.h" -#ifdef NVTE_MPI_FOUND +#ifdef NVTE_WITH_USERBUFFERS #include "comm_gemm_overlap.h" -#endif // NVTE_MPI_FOUND +#endif // NVTE_WITH_USERBUFFERS void te_gemm(at::Tensor A, at::Tensor A_scale_inverse, @@ -1022,7 +1022,7 @@ size_t get_cublasLt_version() { bool userbuf_comm_available() { // TODO(ksivamani) check on python side -#ifdef NVTE_MPI_FOUND +#ifdef NVTE_WITH_USERBUFFERS return true; #else return false; @@ -1080,7 +1080,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { .def_readwrite("scale_inv", &transformer_engine::FP8TensorMeta::scale_inv) .def_readwrite("amax_history", &transformer_engine::FP8TensorMeta::amax_history); -#ifdef NVTE_MPI_FOUND +#ifdef NVTE_WITH_USERBUFFERS py::enum_(m, "UbufOverlapAlgo") .value("BULK_OVERLAP_AG", ubuf::UBOverlapAlgo::BULK_OVERLAP_AG) .value("BULK_OVERLAP_RS", ubuf::UBOverlapAlgo::BULK_OVERLAP_RS) @@ -1099,11 +1099,11 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { .def("split_overlap_ag", &ubuf::UbufP2PCommOverlap::split_overlap_ag) .def("copy_input_to_ubuf", &ubuf::UbufP2PCommOverlap::copy_input_to_ubuf) .def("get_ubuf_output", &ubuf::UbufP2PCommOverlap::get_ubuf_output); -#else // NVTE_MPI_FOUND +#else // NVTE_WITH_USERBUFFERS m.def("UbufOverlapAlgo", &placeholder, "Dummy function for python side annotations"); m.def("UbufCommOverlap", &placeholder, "Dummy function for python side annotations"); m.def("UbufP2PCommOverlap", &placeholder, "Dummy function for python side annotations"); -#endif // NVTE_MPI_FOUND +#endif // NVTE_WITH_USERBUFFERS py::enum_(m, "DType", py::module_local()) .value("kByte", transformer_engine::DType::kByte) diff --git a/transformer_engine/pytorch/csrc/userbuffers/CMakeLists.txt b/transformer_engine/pytorch/csrc/userbuffers/CMakeLists.txt new file mode 100644 index 0000000000..fde8632ec6 --- /dev/null +++ b/transformer_engine/pytorch/csrc/userbuffers/CMakeLists.txt @@ -0,0 +1,33 @@ +# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. 
+ +# Configure userbuffers library +add_library(transformer_engine_userbuffers SHARED + userbuffers.cu + userbuffers-host.cpp) +target_include_directories(transformer_engine_userbuffers PUBLIC + "${CMAKE_CURRENT_SOURCE_DIR}") + +# Configure dependencies +find_package(MPI REQUIRED) +find_library(GDRCOPY_LIBRARY gdrapi + HINTS "${GDRCOPY_LIBRARY_DIR}" "$ENV{GDRCOPY_LIBRARY_DIR}") +if(NOT GDRCOPY_LIBRARY) + message(FATAL_ERROR "Could not find GDRCopy, please set GDRCOPY_LIBRARY_DIR") +endif() +message(STATUS "Found GDRCopy: ${GDRCOPY_LIBRARY}") +target_link_libraries(transformer_engine_userbuffers PUBLIC + CUDA::cudart + MPI::MPI_CXX + ${GDRCOPY_LIBRARY}) +target_include_directories(transformer_engine_userbuffers PRIVATE + ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) + +# Compiler options +set_source_files_properties(userbuffers.cu + userbuffers-host.cpp + PROPERTIES + COMPILE_OPTIONS "$<$:-maxrregcount=64>") +set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-relaxed-constexpr") +set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -O3") diff --git a/transformer_engine/common/comm_gemm_overlap/userbuffers-host.cpp b/transformer_engine/pytorch/csrc/userbuffers/userbuffers-host.cpp similarity index 96% rename from transformer_engine/common/comm_gemm_overlap/userbuffers-host.cpp rename to transformer_engine/pytorch/csrc/userbuffers/userbuffers-host.cpp index 14928ed5a1..59afc4b452 100644 --- a/transformer_engine/common/comm_gemm_overlap/userbuffers-host.cpp +++ b/transformer_engine/pytorch/csrc/userbuffers/userbuffers-host.cpp @@ -13,12 +13,11 @@ #include #include #include -#include -#include #include #include #include #include +#include "userbuffers.h" static int oob_bcast(void *comm_context, void *buf, int size, int root) { MPI_Bcast(buf, size, MPI_BYTE, root, @@ -48,6 +47,12 @@ int stringCmp(const void *a, const void *b) { return strcmp((const char *)a, (co } \ } while (0) +#define NVTE_UB_ERROR(x) \ + do { \ + throw std::runtime_error(std::string(__FILE__ ":") + std::to_string(__LINE__) + \ + " in function " + __func__ + ": " + x); \ + } while (false) + int pipe_rank(communicator *comm, int step) { int mynode = comm->myrank / comm->nvsize; int mylocal = comm->nvrank; @@ -347,7 +352,7 @@ int allgather2_userbuff_inplace_gpu(const int maxcredit, const int handler, cons void allreduce_nonsharp_inplace(const int handler, const int offset, const int elements, communicator *comm, cudaStream_t stream, int op) { - if (elements < 64) NVTE_ERROR("Userbuffer comm for given config not implemented."); + if (elements < 64) NVTE_UB_ERROR("Userbuffer comm for given config not implemented."); // if(comm->myrank==0) fprintf(stderr,"AR2(%d) user call launch_mode=%d\n",op,comm->launch_mode); const int ar_nvsize = op == userbuffers_allreduceop_nonsharp ? 
comm->ar_nvsize : comm->ar2_nvsize; int blocksize = elements * 2; @@ -394,7 +399,7 @@ void allreduce2_userbuff_inplace(const int handler, const int offset, const int void allreduce_userbuff_inplace(const int handler, const int offset, const int elements, communicator *comm, cudaStream_t stream) { - if (elements < 64) NVTE_ERROR("Userbuffer comm for given config not implemented."); + if (elements < 64) NVTE_UB_ERROR("Userbuffer comm for given config not implemented."); allreduce_nonsharp_inplace(handler, offset, elements, comm, stream, userbuffers_allreduceop_nonsharp); return; @@ -402,7 +407,7 @@ void allreduce_userbuff_inplace(const int handler, const int offset, const int e void reducescatter_userbuff_inplace(const int handler, const int offset, const int elements, communicator *comm, cudaStream_t stream) { - if (elements < 64) NVTE_ERROR("Userbuffer comm for given config not implemented."); + if (elements < 64) NVTE_UB_ERROR("Userbuffer comm for given config not implemented."); int op = userbuffers_allreduceop_nonsharp; const int ar_nvsize = op == userbuffers_allreduceop_nonsharp ? comm->ar_nvsize : comm->ar2_nvsize; @@ -443,7 +448,7 @@ void reducescatter_userbuff_inplace(const int handler, const int offset, const i void allgather_userbuff_inplace(const int handler, const int offset, const int elements, communicator *comm, cudaStream_t stream) { - if (elements < 64) NVTE_ERROR("Userbuffer comm for given config not implemented."); + if (elements < 64) NVTE_UB_ERROR("Userbuffer comm for given config not implemented."); int op = userbuffers_allreduceop_nonsharp; const int ar_nvsize = op == userbuffers_allreduceop_nonsharp ? comm->ar_nvsize : comm->ar2_nvsize; int blocksize = elements * 2; diff --git a/transformer_engine/common/comm_gemm_overlap/userbuffers.cu b/transformer_engine/pytorch/csrc/userbuffers/userbuffers.cu similarity index 99% rename from transformer_engine/common/comm_gemm_overlap/userbuffers.cu rename to transformer_engine/pytorch/csrc/userbuffers/userbuffers.cu index 684771801b..9144e9e739 100644 --- a/transformer_engine/common/comm_gemm_overlap/userbuffers.cu +++ b/transformer_engine/pytorch/csrc/userbuffers/userbuffers.cu @@ -14,7 +14,7 @@ #endif #include #include -#include +#include "userbuffers.h" #define MAX_THREADS 1024 #define TIMEOUT 200000000000ull diff --git a/transformer_engine/common/include/transformer_engine/userbuffers.h b/transformer_engine/pytorch/csrc/userbuffers/userbuffers.h similarity index 99% rename from transformer_engine/common/include/transformer_engine/userbuffers.h rename to transformer_engine/pytorch/csrc/userbuffers/userbuffers.h index cd5b1ec382..1d4c1d4024 100644 --- a/transformer_engine/common/include/transformer_engine/userbuffers.h +++ b/transformer_engine/pytorch/csrc/userbuffers/userbuffers.h @@ -8,7 +8,7 @@ #define TRANSFORMER_ENGINE_USERBUFFERS_H_ #include -#include +#include // TODO (tym): Removing will remove PyT extension dependence on MPI #include "cuda_runtime.h" #include #include From ac5d44ecf7cdcf9896f04f7326ce9514b4f39aeb Mon Sep 17 00:00:00 2001 From: cyanguwa Date: Fri, 21 Apr 2023 16:22:39 -0700 Subject: [PATCH 20/68] Add FP8 fused attention (#155) * Add FP8 fused attention to TE for PyTorch Signed-off-by: Charlene Yang * add license for cudnn-frontend, modify installation requirements, and refactor some headers for aesthetics Signed-off-by: Charlene Yang * add c api docs for fused attention Signed-off-by: Charlene Yang * add exception for unsupported precision/sequence length combinations Signed-off-by: Charlene Yang * fix 
installation requirement for non fused attn use cases Signed-off-by: Charlene Yang * fix docs for fused-attn Signed-off-by: Kirthi Shankar Sivamani * prefix enums with NVTE_ and replace old MHA_Matrix with NVTE_QKV_Matrix Signed-off-by: Charlene Yang * minor fixes based on PR comments Signed-off-by: Charlene Yang * fix description for kvpacked fwd Signed-off-by: Charlene Yang * fix description of Bias in C api Signed-off-by: Charlene Yang * minor fixes for cudnn requirement and description for QKV tensors Signed-off-by: Charlene Yang * fix QKV layout description and support matrix for C api Signed-off-by: Charlene Yang * add asserts to cpp_extensions for qkv layout/bias type/attn mask type Signed-off-by: Charlene Yang * fix typo precision Signed-off-by: Charlene Yang --------- Signed-off-by: Charlene Yang Signed-off-by: Kirthi Shankar Sivamani Co-authored-by: Charlene Yang Co-authored-by: Kirthi Shankar Sivamani --- .github/workflows/build.yml | 6 + .gitmodules | 3 + 3rdparty/cudnn-frontend | 1 + Acknowledgements.txt | 22 + docs/api/c/fused_attn.rst | 9 + docs/api/c/index.rst | 1 + docs/installation.rst | 2 + setup.py | 1 + tests/cpp/test_common.cu | 1 + tests/cpp/test_common.h | 8 + transformer_engine/CMakeLists.txt | 2 + transformer_engine/cmake/FindCUDNN.cmake | 78 + transformer_engine/common/CMakeLists.txt | 7 +- .../common/fused_attn/fused_attn.cpp | 232 ++ .../common/fused_attn/fused_attn_fp8.cu | 2138 +++++++++++++++++ .../common/fused_attn/fused_attn_fp8.h | 46 + transformer_engine/common/fused_attn/utils.cu | 167 ++ transformer_engine/common/fused_attn/utils.h | 90 + .../include/transformer_engine/fused_attn.h | 262 ++ .../include/transformer_engine/logging.h | 9 + .../transformer_engine/transformer_engine.h | 35 +- .../common/transformer_engine.cpp | 13 + transformer_engine/pytorch/constants.py | 2 +- transformer_engine/pytorch/cpp_extensions.py | 730 +++++- transformer_engine/pytorch/csrc/common.cu | 13 + transformer_engine/pytorch/csrc/common.h | 15 + transformer_engine/pytorch/csrc/extensions.cu | 756 +++++- transformer_engine/pytorch/csrc/extensions.h | 90 +- transformer_engine/pytorch/module.py | 6 +- 29 files changed, 4720 insertions(+), 25 deletions(-) create mode 160000 3rdparty/cudnn-frontend create mode 100644 docs/api/c/fused_attn.rst create mode 100644 transformer_engine/cmake/FindCUDNN.cmake create mode 100644 transformer_engine/common/fused_attn/fused_attn.cpp create mode 100644 transformer_engine/common/fused_attn/fused_attn_fp8.cu create mode 100644 transformer_engine/common/fused_attn/fused_attn_fp8.h create mode 100644 transformer_engine/common/fused_attn/utils.cu create mode 100644 transformer_engine/common/fused_attn/utils.h create mode 100644 transformer_engine/common/include/transformer_engine/fused_attn.h diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index ff64f1de72..24d87c0416 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -17,6 +17,8 @@ jobs: steps: - name: 'Checkout' uses: actions/checkout@v3 + with: + submodules: recursive - name: 'Build' run: | mkdir -p wheelhouse && \ @@ -41,6 +43,8 @@ jobs: steps: - name: 'Checkout' uses: actions/checkout@v3 + with: + submodules: recursive - name: 'Build' run: | pip install ninja pybind11 && \ @@ -66,6 +70,8 @@ jobs: steps: - name: 'Checkout' uses: actions/checkout@v3 + with: + submodules: recursive - name: 'Build' run: | pip install ninja pybind11 && \ diff --git a/.gitmodules b/.gitmodules index 85675ac0bc..21492db5ef 100644 --- a/.gitmodules +++ 
b/.gitmodules @@ -1,3 +1,6 @@ [submodule "3rdparty/googletest"] path = 3rdparty/googletest url = https://github.com/google/googletest.git +[submodule "3rdparty/cudnn-frontend"] + path = 3rdparty/cudnn-frontend + url = https://github.com/NVIDIA/cudnn-frontend.git diff --git a/3rdparty/cudnn-frontend b/3rdparty/cudnn-frontend new file mode 160000 index 0000000000..e7f64390e9 --- /dev/null +++ b/3rdparty/cudnn-frontend @@ -0,0 +1 @@ +Subproject commit e7f64390e9bb4a3db622ffe11c973834f572b609 diff --git a/Acknowledgements.txt b/Acknowledgements.txt index 7eec81a9ce..ad11acc047 100644 --- a/Acknowledgements.txt +++ b/Acknowledgements.txt @@ -138,3 +138,25 @@ SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +======================== +cudnn-frontend + +Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. diff --git a/docs/api/c/fused_attn.rst b/docs/api/c/fused_attn.rst new file mode 100644 index 0000000000..c2384b7e12 --- /dev/null +++ b/docs/api/c/fused_attn.rst @@ -0,0 +1,9 @@ +.. + Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + + See LICENSE for license information. + +fused_attn.h +============ + +.. doxygenfile:: fused_attn.h diff --git a/docs/api/c/index.rst b/docs/api/c/index.rst index 0f83b8dc02..f98a419088 100644 --- a/docs/api/c/index.rst +++ b/docs/api/c/index.rst @@ -17,6 +17,7 @@ directly from C/C++, without Python. activation.h cast.h gemm.h + fused_attn.h layer_norm.h softmax.h transformer_engine.h diff --git a/docs/installation.rst b/docs/installation.rst index 088d65f9ca..9aded82d0f 100644 --- a/docs/installation.rst +++ b/docs/installation.rst @@ -14,6 +14,8 @@ Prerequisites 1. Linux x86_64 2. `CUDA 11.8 `__ 3. |driver link|_ supporting CUDA 11.8 or later. +4. `cuDNN 8 `__ or later. +5. For FP8 fused attention, `CUDA 12.1 `__ or later, |driver link|_ supporting CUDA 12.1 or later, and `cuDNN 8.9 `__ or later. 
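For orientation, the dispatch added in fused_attn.cpp below only routes to the new cuDNN kernels when the QKV data type is FP8 and the maximum sequence length is at most 512, and that path needs cuDNN 8.9 or newer; BF16/FP16 and sequences longer than 512 are still marked TBD in this patch. A rough sketch of that support check as a hypothetical Python helper (not part of the patch):

    def fp8_fused_attn_supported(qkv_dtype: str, max_seqlen: int, cudnn_version: int) -> bool:
        # Mirrors the C-level dispatch: FP8 QKV, max_seqlen <= 512, and
        # cuDNN >= 8.9 (version encoded as 8900); anything else currently
        # raises an NVTE error.
        return (
            qkv_dtype in ("Float8E4M3", "Float8E5M2")
            and max_seqlen <= 512
            and cudnn_version >= 8900
        )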
Transformer Engine in NGC Containers diff --git a/setup.py b/setup.py index cb0c37fe3a..b88e4fbcc4 100644 --- a/setup.py +++ b/setup.py @@ -105,6 +105,7 @@ def make_abs_path(l): include_dirs = [ "transformer_engine/common/include", "transformer_engine/pytorch/csrc", + "3rdparty/cudnn-frontend/include", ] if NVTE_WITH_USERBUFFERS: if MPI_HOME: diff --git a/tests/cpp/test_common.cu b/tests/cpp/test_common.cu index 151eddb9f9..bbb25bb2fc 100644 --- a/tests/cpp/test_common.cu +++ b/tests/cpp/test_common.cu @@ -42,6 +42,7 @@ const std::string &typeName(DType type) { static const std::unordered_map name_map = { {DType::kByte, "byte"}, {DType::kInt32, "int32"}, + {DType::kInt64, "int64"}, {DType::kFloat32, "float32"}, {DType::kFloat16, "float16"}, {DType::kBFloat16, "bfloat16"}, diff --git a/tests/cpp/test_common.h b/tests/cpp/test_common.h index f35d494c8d..7278f1827b 100644 --- a/tests/cpp/test_common.h +++ b/tests/cpp/test_common.h @@ -44,6 +44,7 @@ struct BytesToType<8> { using byte = uint8_t; using int32 = int32_t; +using int64 = int64_t; using fp32 = float; using fp16 = half; using bf16 = nv_bfloat16; @@ -54,6 +55,7 @@ template struct TypeInfo{ using types = std::tuple + $ +) + +target_link_libraries( + CUDNN::cudnn_all + INTERFACE + CUDNN::cudnn_adv_train + CUDNN::cudnn_ops_train + CUDNN::cudnn_cnn_train + CUDNN::cudnn_adv_infer + CUDNN::cudnn_cnn_infer + CUDNN::cudnn_ops_infer + CUDNN::cudnn +) + diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index c5bc6bb0f1..7b844540ae 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -12,6 +12,9 @@ list(APPEND transformer_engine_SOURCES transpose/transpose_fusion.cu transpose/multi_cast_transpose.cu activation/gelu.cu + fused_attn/fused_attn_fp8.cu + fused_attn/fused_attn.cpp + fused_attn/utils.cu gemm/cublaslt_gemm.cu layer_norm/ln_api.cpp layer_norm/ln_bwd_semi_cuda_kernel.cu @@ -30,9 +33,11 @@ target_include_directories(transformer_engine PUBLIC target_link_libraries(transformer_engine PUBLIC CUDA::cublas CUDA::cudart - CUDA::nvToolsExt) + CUDA::nvToolsExt + CUDNN::cudnn) target_include_directories(transformer_engine PRIVATE ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) +target_include_directories(transformer_engine PRIVATE "${CMAKE_SOURCE_DIR}/../3rdparty/cudnn-frontend/include") # Compiler options set_source_files_properties(fused_softmax/scaled_masked_softmax.cu diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp new file mode 100644 index 0000000000..17b6505038 --- /dev/null +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -0,0 +1,232 @@ +/************************************************************************* + * Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. 
+ ************************************************************************/ + +#include "transformer_engine/fused_attn.h" +#include "../common.h" +#include "utils.h" +#include "fused_attn_fp8.h" + +// NVTE fused attention FWD FP8 with packed QKV +void nvte_fused_attn_fwd_qkvpacked( + const NVTETensor QKV, + const NVTETensor Bias, + NVTETensor S, + NVTETensor O, + NVTETensorPack* Aux_Output_Tensors, + const NVTETensor cu_seqlens, + const NVTETensor rng_state, + size_t max_seqlen, + bool is_training, float attn_scale, float dropout, + NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, + NVTE_Mask_Type attn_mask_type, + NVTETensor workspace, + cudaStream_t stream) { + NVTE_API_CALL(nvte_flash_attn_fwd_qkvpacked); + using namespace transformer_engine; + const Tensor *input_cu_seqlens = reinterpret_cast(cu_seqlens); + const Tensor *input_rng_state = reinterpret_cast(rng_state); + const Tensor *input_QKV = reinterpret_cast(QKV); + const Tensor *input_Bias = reinterpret_cast(Bias); + Tensor *input_output_S = reinterpret_cast(S); + Tensor *output_O = reinterpret_cast(O); + Tensor *wkspace = reinterpret_cast(workspace); + + // QKV shape is [total_seqs, 3, h, d] + size_t b = input_cu_seqlens->data.shape[0] - 1; + size_t h = input_QKV->data.shape[2]; + size_t d = input_QKV->data.shape[3]; + const DType QKV_type = input_QKV->data.dtype; + + if (((QKV_type == DType::kFloat8E4M3) || (QKV_type == DType::kFloat8E5M2)) + && (max_seqlen <= 512)) { +#if (CUDNN_VERSION >= 8900) + auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle(); + // FP8 API doesn't use input_Bias, bias_type or attn_mask_type + fused_attn_fwd_fp8_qkvpacked( + b, max_seqlen, h, d, + is_training, attn_scale, dropout, qkv_layout, + input_QKV, input_output_S, output_O, + Aux_Output_Tensors, + input_cu_seqlens, + input_rng_state, + wkspace, stream, handle); +#else + NVTE_ERROR("cuDNN 8.9 is required to run FP8 fused attention. \n"); +#endif + } else if (((QKV_type == DType::kFloat16) || (QKV_type == DType::kBFloat16)) + && (max_seqlen <= 512)) { + NVTE_ERROR("TBD: No support for BF16/FP16 fused attention currently. \n"); + } else if (max_seqlen > 512) { + NVTE_ERROR("TBD: No support for fused attention with >512 seqlence length currently. \n"); + } else { + NVTE_ERROR("Invalid combination of data type and sequence length! 
\n"); + } +} +// NVTE fused attention BWD FP8 with packed QKV +void nvte_fused_attn_bwd_qkvpacked( + const NVTETensor QKV, + const NVTETensor dBias, + const NVTETensor O, + const NVTETensor dO, + const NVTETensor S, + NVTETensor dP, + const NVTETensorPack* Aux_CTX_Tensors, + NVTETensor dQKV, + const NVTETensor cu_seqlens, + size_t max_seqlen, + float attn_scale, float dropout, + NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, + NVTE_Mask_Type attn_mask_type, + NVTETensor workspace, + cudaStream_t stream) { + NVTE_API_CALL(nvte_flash_attn_bwd_qkvpacked); + using namespace transformer_engine; + const Tensor *input_cu_seqlens = reinterpret_cast(cu_seqlens); + const Tensor *input_QKV = reinterpret_cast(QKV); + const Tensor *input_dBias = reinterpret_cast(dBias); + const Tensor *input_O = reinterpret_cast(O); + const Tensor *input_dO = reinterpret_cast(dO); + const Tensor *input_S = reinterpret_cast(S); + Tensor *input_output_dP = reinterpret_cast(dP); + Tensor *output_dQKV = reinterpret_cast(dQKV); + Tensor *wkspace = reinterpret_cast(workspace); + + // QKV shape is [total_seqs, 3, h, d] + size_t b = input_cu_seqlens->data.shape[0] - 1; + size_t h = input_QKV->data.shape[2]; + size_t d = input_QKV->data.shape[3]; + const DType QKV_type = input_QKV->data.dtype; + + if (((QKV_type == DType::kFloat8E4M3) || (QKV_type == DType::kFloat8E5M2)) + && (max_seqlen <= 512)) { +#if (CUDNN_VERSION >= 8900) + // Aux_CTX_Tensors contain [M, ZInv, rng_state] generated by the forward pass + const Tensor *input_M = reinterpret_cast(Aux_CTX_Tensors->tensors[0]); + const Tensor *input_ZInv = reinterpret_cast(Aux_CTX_Tensors->tensors[1]); + const Tensor *input_rng_state = reinterpret_cast(Aux_CTX_Tensors->tensors[2]); + auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle(); + // FP8 API doesn't use input_dBias, bias_type or attn_mask_type + fused_attn_bwd_fp8_qkvpacked( + b, max_seqlen, h, d, + attn_scale, dropout, qkv_layout, + input_QKV, input_O, input_dO, + input_M, input_ZInv, + input_S, input_output_dP, + output_dQKV, + input_cu_seqlens, + input_rng_state, + wkspace, stream, handle); +#else + NVTE_ERROR("cuDNN 8.9 is required to run FP8 fused attention. \n"); +#endif + } else if (((QKV_type == DType::kFloat16) || (QKV_type == DType::kBFloat16)) + && (max_seqlen <= 512)) { + NVTE_ERROR("TBD: No support for BF16/FP16 fused attention currently. \n"); + } else if (max_seqlen > 512) { + NVTE_ERROR("TBD: No support for fused attention with >512 seqlence length currently. \n"); + } else { + NVTE_ERROR("Invalid combination of data type and sequence length! 
\n"); + } +} +// NVTE fused attention FWD FP8 with packed KV +void nvte_fused_attn_fwd_kvpacked( + const NVTETensor Q, + const NVTETensor KV, + const NVTETensor Bias, + NVTETensor S, + NVTETensor O, + NVTETensorPack* Aux_Output_Tensors, + const NVTETensor cu_seqlens_q, + const NVTETensor cu_seqlens_kv, + const NVTETensor rng_state, + size_t max_seqlen_q, size_t max_seqlen_kv, + bool is_training, float attn_scale, float dropout, + NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, + NVTE_Mask_Type attn_mask_type, + NVTETensor workspace, + cudaStream_t stream) { + NVTE_API_CALL(nvte_flash_attn_fwd_kvpacked); + using namespace transformer_engine; + const Tensor *input_cu_seqlens_q = reinterpret_cast(cu_seqlens_q); + const Tensor *input_cu_seqlens_kv = reinterpret_cast(cu_seqlens_kv); + const Tensor *input_rng_state = reinterpret_cast(rng_state); + const Tensor *input_Q = reinterpret_cast(Q); + const Tensor *input_KV = reinterpret_cast(KV); + const Tensor *input_Bias = reinterpret_cast(Bias); + Tensor *input_output_S = reinterpret_cast(S); + Tensor *output_O = reinterpret_cast(O); + Tensor *wkspace = reinterpret_cast(workspace); + + // Q shape is [total_seqs, h, d] + size_t b = input_cu_seqlens_q->data.shape[0] - 1; + size_t h = input_Q->data.shape[1]; + size_t d = input_Q->data.shape[2]; + const DType QKV_type = input_Q->data.dtype; + + if (((QKV_type == DType::kFloat8E4M3) || (QKV_type == DType::kFloat8E5M2)) + && (max_seqlen_q <= 512) && (max_seqlen_kv <= 512)) { + NVTE_ERROR("The FP8 fused attention API only supports packed QKV input. \n"); + } else if (((QKV_type == DType::kFloat16) || (QKV_type == DType::kBFloat16)) + && (max_seqlen_q <= 512) && (max_seqlen_kv <= 512)) { + NVTE_ERROR("TBD: No support for BF16/FP16 fused attention currently. \n"); + } else if ((max_seqlen_q > 512) || (max_seqlen_kv > 512)) { + NVTE_ERROR("TBD: No support for fused attention with >512 seqlence length currently. \n"); + } else { + NVTE_ERROR("Invalid combination of data type and sequence length! 
\n"); + } +} +// NVTE fused attention BWD FP8 with packed KV +void nvte_fused_attn_bwd_kvpacked( + const NVTETensor Q, + const NVTETensor KV, + const NVTETensor dBias, + const NVTETensor O, + const NVTETensor dO, + const NVTETensor S, + NVTETensor dP, + const NVTETensorPack* Aux_CTX_Tensors, + NVTETensor dQ, + NVTETensor dKV, + const NVTETensor cu_seqlens_q, + const NVTETensor cu_seqlens_kv, + size_t max_seqlen_q, size_t max_seqlen_kv, + float attn_scale, float dropout, + NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, + NVTE_Mask_Type attn_mask_type, + NVTETensor workspace, + cudaStream_t stream) { + NVTE_API_CALL(nvte_flash_attn_bwd_kvpacked); + using namespace transformer_engine; + const Tensor *input_cu_seqlens_q = reinterpret_cast(cu_seqlens_q); + const Tensor *input_cu_seqlens_kv = reinterpret_cast(cu_seqlens_kv); + const Tensor *input_Q = reinterpret_cast(Q); + const Tensor *input_KV = reinterpret_cast(KV); + const Tensor *input_dBias = reinterpret_cast(dBias); + const Tensor *input_O = reinterpret_cast(O); + const Tensor *input_dO = reinterpret_cast(dO); + const Tensor *input_S = reinterpret_cast(S); + Tensor *input_output_dP = reinterpret_cast(dP); + Tensor *output_dQ = reinterpret_cast(dQ); + Tensor *output_dKV = reinterpret_cast(dKV); + Tensor *wkspace = reinterpret_cast(workspace); + + // Q shape is [total_seqs, h, d] + size_t b = input_cu_seqlens_q->data.shape[0] - 1; + size_t h = input_Q->data.shape[1]; + size_t d = input_Q->data.shape[2]; + const DType QKV_type = input_Q->data.dtype; + if (((QKV_type == DType::kFloat8E4M3) || (QKV_type == DType::kFloat8E5M2)) + && (max_seqlen_q <= 512) && (max_seqlen_kv <= 512)) { + NVTE_ERROR("The FP8 fused attention API only supports packed QKV input. \n"); + } else if (((QKV_type == DType::kFloat16) || (QKV_type == DType::kBFloat16)) + && (max_seqlen_q <= 512) && (max_seqlen_kv <= 512)) { + NVTE_ERROR("TBD: No support for BF16/FP16 fused attention currently. \n"); + } else if ((max_seqlen_q > 512) || (max_seqlen_kv > 512)) { + NVTE_ERROR("TBD: No support for fused attention with >512 seqlence length currently. \n"); + } else { + NVTE_ERROR("Invalid combination of data type and sequence length! \n"); + } +} diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.cu b/transformer_engine/common/fused_attn/fused_attn_fp8.cu new file mode 100644 index 0000000000..633f46c51f --- /dev/null +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.cu @@ -0,0 +1,2138 @@ +/************************************************************************* + * Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. 
+ ************************************************************************/ + +#include "transformer_engine/fused_attn.h" +#include "../common.h" +#include "utils.h" +#include "fused_attn_fp8.h" + +namespace transformer_engine { +namespace fused_attn { + +using namespace transformer_engine; + +#if (CUDNN_VERSION >= 8900) +std::unordered_map tensor_name_to_uid = { + {"Q", 1}, + {"K", 2}, + {"V", 3}, + {"O", 4}, + {"S", 5}, + {"B", 6}, + {"DROPOUT_SCALE", 7}, + {"S_CONST", 8}, + {"MNK_OVERRIDE", 9}, + {"dQ", 11}, + {"dK", 12}, + {"dV", 13}, + {"dO", 14}, + {"MASK_VAL", 15}, + {"dS", 16}, + {"O_SEQLEN", 17}, + {"M", 18}, + {"Z", 19}, + {"descaleQ", 20}, + {"descaleK", 21}, + {"descaleV", 22}, + {"descaleS", 23}, + {"scaleS", 24}, + {"amaxS", 25}, + {"amaxO", 26}, + {"QKV_RAGGED", 27}, + {"O_RAGGED", 28}, + {"K_TRANSPOSE", 29}, + {"AttnScale", 30}, + {"scaleO", 31}, + {"Z_INV", 32}, + {"descaleO", 33}, + {"descaledO", 34}, + {"descaledS", 35}, + {"descaledQ", 36}, + {"descaledK", 37}, + {"descaledV", 38}, + {"scaledS", 39}, + {"scaledQ", 40}, + {"scaledK", 41}, + {"scaledV", 42}, + {"amaxdS", 43}, + {"amaxdQ", 44}, + {"amaxdK", 45}, + {"amaxdV", 46}, + {"V_TRANSPOSE", 47}, + {"AttnScale_dS_K", 48}, + {"AttnScale_dSTranspose_Q", 49}, + {"DROPOUT_SCALE_dOVt_OdO", 50}, + {"DROPOUT_OFFSET", 51}, + {"DROPOUT_SEED", 52}, + {"VIRTUAL", 80} +}; + +bool allowAllConfig(cudnnBackendDescriptor_t engine_config) { + (void)engine_config; + return false; +} + +static cudnn_frontend::Tensor tensor_create( + cudnnDataType_t type, int64_t id, + int64_t const * dim, int64_t const * stride, + bool is_virtual, bool is_value) { + int nbDims = 4; + auto tensor_created = cudnn_frontend::TensorBuilder() + .setDim(nbDims, dim) + .setStride(nbDims, stride) + .setId(id) + .setAlignment(16) // 16B alignment is needed to run a tensor core engine + .setDataType(type) + .setVirtual(is_virtual) + .setByValue(is_value) + .build(); + return tensor_created; +} + +static cudnn_frontend::Tensor tensor_create_with_offset( + cudnnDataType_t type, int64_t id, + int64_t const * dim, int64_t const * stride, + bool is_virtual, bool is_value, + std::shared_ptr raggedOffset) { + int nbDims = 4; + auto tensor_created = cudnn_frontend::TensorBuilder() + .setDim(nbDims, dim) + .setStride(nbDims, stride) + .setId(id) + .setAlignment(16) // 16B alignment is needed to run a tensor core engine + .setDataType(type) + .setVirtual(is_virtual) + .setByValue(is_value) + .setRaggedOffset(raggedOffset) + .build(); + return tensor_created; +} + +static cudnn_frontend::PointWiseDesc pw_desc_create( + cudnnDataType_t type, cudnnPointwiseMode_t mode) { + auto pw_desc_created = cudnn_frontend::PointWiseDescBuilder() + .setMode(mode) + .setComputeType(type) + .build(); + return pw_desc_created; +} + +static cudnn_frontend::Operation unary_pw_op_create( + cudnn_frontend::Tensor const &xDesc, + cudnn_frontend::Tensor const &yDesc, + cudnn_frontend::PointWiseDesc const &pwDesc) { + auto pw_op_created = cudnn_frontend::OperationBuilder( + CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) + .setxDesc(xDesc) + .setyDesc(yDesc) + .setpwDesc(pwDesc) + .build(); + return pw_op_created; +} + +static cudnn_frontend::Operation binary_pw_op_create( + cudnn_frontend::Tensor const &xDesc, + cudnn_frontend::Tensor const &bDesc, + cudnn_frontend::Tensor const &yDesc, + cudnn_frontend::PointWiseDesc const &pwDesc) { + auto pw_op_created = cudnn_frontend::OperationBuilder( + CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) + .setxDesc(xDesc) + .setbDesc(bDesc) + .setyDesc(yDesc) + 
.setpwDesc(pwDesc) + .build(); + return pw_op_created; +} + +static cudnn_frontend::Operation ternary_pw_op_create( + cudnn_frontend::Tensor const &xDesc, + cudnn_frontend::Tensor const &bDesc, + cudnn_frontend::Tensor const &tDesc, + cudnn_frontend::Tensor const &yDesc, + cudnn_frontend::PointWiseDesc const &pwDesc) { + auto pw_op_created = cudnn_frontend::OperationBuilder( + CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) + .setxDesc(xDesc) + .setbDesc(bDesc) + .settDesc(tDesc) + .setyDesc(yDesc) + .setpwDesc(pwDesc) + .build(); + return pw_op_created; +} + +static cudnn_frontend::Tensor createAmax( + const std::string& amax_tensor_name, + const cudnn_frontend::Tensor& prevBlockOutputTensor, + std::vector* ops) { + int64_t amax_dim[4] = {1, 1, 1, 1}; + int64_t amax_stride[4] = {1, 1, 1, 1}; + auto amaxTensor = tensor_create( + CUDNN_DATA_FLOAT, tensor_name_to_uid[amax_tensor_name], + amax_dim, amax_stride, false, false); + + // Define the amax descriptor + auto reductionDesc = cudnn_frontend::ReductionDescBuilder() + .setMathPrecision(CUDNN_DATA_FLOAT) + .setReductionOp(CUDNN_REDUCE_TENSOR_AMAX) + .build(); + + // Create a reduction amax Node + auto reduction_op = cudnn_frontend::OperationBuilder( + CUDNN_BACKEND_OPERATION_REDUCTION_DESCRIPTOR) + .setxDesc(prevBlockOutputTensor) + .setyDesc(amaxTensor) + .setreductionDesc(reductionDesc) + .build(); + ops->push_back(std::move(reduction_op)); + return amaxTensor; +} + +static cudnn_frontend::Tensor createScale( + const cudnn_frontend::Tensor& prevBlockOutputTensor, + const std::string& scale_tensor_name, + cudnnDataType_t tensorType, + bool isOutputVirtual, bool isScaleByValue, + std::vector* ops, + const std::string& output_tensor_name ="") { + int64_t scale_dim[4] = {1, 1, 1, 1}; + int64_t scale_stride[4] = {1, 1, 1, 1}; + + int64_t output_dim[4]; + int64_t output_stride[4]; + + for (int i = 0; i < 4; i++) { + output_dim[i] = prevBlockOutputTensor.getDim()[i]; + output_stride[i] = prevBlockOutputTensor.getStride()[i]; + } + + auto scaleTensor = tensor_create( + CUDNN_DATA_FLOAT, tensor_name_to_uid[scale_tensor_name], + scale_dim, scale_stride, false, isScaleByValue); // is by value + + int64_t outputUID = isOutputVirtual ? tensor_name_to_uid["VIRTUAL"] + + tensor_name_to_uid[scale_tensor_name] + 5000 : + tensor_name_to_uid[output_tensor_name]; + auto afterScaleKTensor = tensor_create( + tensorType, outputUID, output_dim, + output_stride, isOutputVirtual, false); // is virtual + + // Define the scale descriptor + auto scaleDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_MUL); + + // Create a Scale Node + auto scale_op = binary_pw_op_create( + prevBlockOutputTensor, scaleTensor, afterScaleKTensor, scaleDesc); + + ops->push_back(std::move(scale_op)); + return afterScaleKTensor; +} + +static cudnn_frontend::Tensor createScale( + const cudnn_frontend::Tensor& prevBlockOutputTensor, + const cudnn_frontend::Tensor& scaleTensor, + cudnnDataType_t tensorType, + bool isOutputVirtual, bool isScaleByValue, + std::vector* ops, + int UID_offset, const std::string& output_tensor_name ="") { + int64_t output_dim[4]; + int64_t output_stride[4]; + for (int i = 0; i < 4; i++) { + output_dim[i] = prevBlockOutputTensor.getDim()[i]; + output_stride[i] = prevBlockOutputTensor.getStride()[i]; + } + + int64_t outputUID = isOutputVirtual ? 
+ tensor_name_to_uid["VIRTUAL"] + UID_offset : + tensor_name_to_uid[output_tensor_name]; + auto afterScaleTensor = tensor_create( + tensorType, outputUID, output_dim, + output_stride, isOutputVirtual, false); // is virtual + + // Define the scale descriptor + auto scaleDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_MUL); + + // Create a Scale Node + auto scale_op = binary_pw_op_create( + prevBlockOutputTensor, scaleTensor, afterScaleTensor, scaleDesc); + + ops->push_back(std::move(scale_op)); + return afterScaleTensor; +} + +static cudnn_frontend::Tensor createScaleWithOffset( + const cudnn_frontend::Tensor& prevBlockOutputTensor, + const std::string& scale_tensor_name, + cudnnDataType_t tensorType, + bool isOutputVirtual, + bool isScaleByValue, + std::vector* ops, + std::shared_ptr offsetTensor, + const std::string& output_tensor_name ="") { + int64_t scale_dim[4] = {1, 1, 1, 1}; + int64_t scale_stride[4] = {1, 1, 1, 1}; + + int64_t output_dim[4]; + int64_t output_stride[4]; + // If output tensor is dQ, dK, or dV, we need to generate QKV interleaved strides + if (output_tensor_name == "dQ" || output_tensor_name == "dK" || output_tensor_name == "dV") { + for (int i = 0; i < 4; i++) { + output_dim[i] = prevBlockOutputTensor.getDim()[i]; + } + generateMatrixStrides(output_dim[0], output_dim[1], output_dim[2], + 0 /*s_kv = 0 for placeholder*/, + output_dim[3], output_stride, + NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED, NVTE_QKV_Matrix::NVTE_Q_Matrix); + } else { + // Otherwise output dim and stride should be the same as prev block dim and stride + for (int i = 0; i < 4; i++) { + output_dim[i] = prevBlockOutputTensor.getDim()[i]; + output_stride[i] = prevBlockOutputTensor.getStride()[i]; + } + } + + auto scaleTensor = tensor_create( + CUDNN_DATA_FLOAT, tensor_name_to_uid[scale_tensor_name], + scale_dim, scale_stride, false, isScaleByValue); // is by value + + cudnnDataType_t outputDataType = isOutputVirtual ? CUDNN_DATA_FLOAT : tensorType; + int64_t outputUID = isOutputVirtual ? 
+ tensor_name_to_uid["VIRTUAL"] + tensor_name_to_uid[scale_tensor_name] + 7000 : + tensor_name_to_uid[output_tensor_name]; + auto afterScaleTensor = tensor_create_with_offset( + outputDataType, outputUID, output_dim, + output_stride, isOutputVirtual, false, offsetTensor); // is virtual + + // Define the scale descriptor + auto scaleDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_MUL); + + // Create a Scale Node + auto scale_op = binary_pw_op_create( + prevBlockOutputTensor, scaleTensor, afterScaleTensor, scaleDesc); + + ops->push_back(std::move(scale_op)); + return afterScaleTensor; +} + +static cudnn_frontend::Tensor createSoftmaxForward( + int64_t b, int64_t h, int64_t s_q, int64_t s_kv, + std::vector* ops, + const cudnn_frontend::Tensor& prevBlockOutputTensor, + bool isTraining) { + int64_t afterBMM1_dim[4] = {b, h, s_q, s_kv}; + int64_t afterBMM1_stride[4] = {h * s_q * s_kv, s_q * s_kv, s_kv, 1}; + + int64_t afterReduction_dim[4] = {b, h, s_q, 1}; + int64_t afterReduction_stride[4] = {h * s_q, s_q, 1, 1}; + + // max (x) (M tensor) + auto afterMaxReductionTensor = tensor_create( + CUDNN_DATA_FLOAT, tensor_name_to_uid["M"], + afterReduction_dim, afterReduction_stride, + !isTraining, false); // not virtual if training is true, + // virtual if training is false + // x - max(x) + auto afterSubtractionTensor = tensor_create( + CUDNN_DATA_FLOAT, tensor_name_to_uid["VIRTUAL"] + 151, + afterBMM1_dim, afterBMM1_stride, true, false); // is virtual + // e^(x - max(x)) + auto afterExponentTensor = tensor_create( + CUDNN_DATA_FLOAT, tensor_name_to_uid["VIRTUAL"] + 152, + afterBMM1_dim, afterBMM1_stride, true, false); // is virtual; + // sum (e^(x - max(x))) (Z tensor) + auto zTensor = tensor_create( + CUDNN_DATA_FLOAT, tensor_name_to_uid["Z"], + afterReduction_dim, afterReduction_stride, true, false); // is virtual + // 1 / sum (e^(x - max(x))) (Z_INV tensor) + auto zInvTensor = tensor_create( + CUDNN_DATA_FLOAT, tensor_name_to_uid["Z_INV"], + afterReduction_dim, afterReduction_stride, + !isTraining, false); // not virtual if training is true, + // virtual if training is false + // Final softmax output (After exponent * Z_INV) + auto beforeDropoutTensor = tensor_create( + CUDNN_DATA_FLOAT, tensor_name_to_uid["VIRTUAL"] + 153, + afterBMM1_dim, afterBMM1_stride, true, false); // is virtual + + // Define the reduction descriptor + auto reductionMaxDesc = cudnn_frontend::ReductionDescBuilder() + .setComputeType(CUDNN_DATA_FLOAT) + .setReductionOp(CUDNN_REDUCE_TENSOR_MAX) + .build(); + + // Create a reduction max Node + auto reductionMax_op = cudnn_frontend::OperationBuilder( + CUDNN_BACKEND_OPERATION_REDUCTION_DESCRIPTOR) + .setxDesc(prevBlockOutputTensor) + .setyDesc(afterMaxReductionTensor) + .setreductionDesc(reductionMaxDesc) + .build(); + + // Define the subtract descriptor + auto subtractDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_SUB); + + // Create a subtract Node + auto subtract_op = binary_pw_op_create( + prevBlockOutputTensor, afterMaxReductionTensor, + afterSubtractionTensor, subtractDesc); + + // Define the exponent descriptor + auto exponentDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_EXP); + + // Create a exponent Node + auto exponent_op = unary_pw_op_create( + afterSubtractionTensor, afterExponentTensor, exponentDesc); + + // Define the reduction descriptor + auto reductionAddDesc = cudnn_frontend::ReductionDescBuilder() + .setComputeType(CUDNN_DATA_FLOAT) + .setReductionOp(CUDNN_REDUCE_TENSOR_ADD) + .build(); + + // Create a reduction add Node + auto 
reductionAdd_op = cudnn_frontend::OperationBuilder( + CUDNN_BACKEND_OPERATION_REDUCTION_DESCRIPTOR) + .setxDesc(afterExponentTensor) + .setyDesc(zTensor) + .setreductionDesc(reductionAddDesc) + .build(); + + // Define the reciprocal descriptor + auto reciprocalDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_RECIPROCAL); + + // Create a reciprocal Node + auto reciprocal_op = unary_pw_op_create(zTensor, zInvTensor, reciprocalDesc); + + // Define the pw multiply descriptor + auto multiplyDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_MUL); + + // Create a multiply Node + auto mutliply_op = binary_pw_op_create( + afterExponentTensor, zInvTensor, beforeDropoutTensor, multiplyDesc); + + ops->push_back(std::move(reductionMax_op)); + ops->push_back(std::move(subtract_op)); + ops->push_back(std::move(exponent_op)); + ops->push_back(std::move(reductionAdd_op)); + ops->push_back(std::move(reciprocal_op)); + ops->push_back(std::move(mutliply_op)); + + return beforeDropoutTensor; +} + +static cudnn_frontend::Tensor createDropoutForward( + int64_t b, int64_t h, int64_t s_q, int64_t s_kv, + double probability, + std::vector* ops, + const cudnn_frontend::Tensor& beforeDropoutTensor) { + cudnn_frontend::throw_if(ops->size() == 0, + "Dropout DAG constructed incorrectly as the first one", + CUDNN_STATUS_BAD_PARAM); + + int64_t afterBMM1_dim[4] = {b, h, s_q, s_kv}; + int64_t afterBMM1_stride[4] = {h * s_q * s_kv, s_q * s_kv, s_kv, 1}; + + int64_t scale_dim[4] = {1, 1, 1, 1}; + int64_t scale_stride[4] = {1, 1, 1, 1}; + + // Mask for the dropout + auto dropoutMaskTensor = tensor_create( + CUDNN_DATA_FLOAT, tensor_name_to_uid["VIRTUAL"] + 250, + afterBMM1_dim, afterBMM1_stride, true, false); // is virtual + auto dropoutSeedTensor = tensor_create( + CUDNN_DATA_INT64, tensor_name_to_uid["DROPOUT_SEED"], + scale_dim, scale_stride, false, false); // is by value + auto dropoutOffsetTensor = tensor_create( + CUDNN_DATA_INT64, tensor_name_to_uid["DROPOUT_OFFSET"], + scale_dim, scale_stride, false, false); // is by value + + // After dropout tensor befor scale + auto beforeDropoutScaleTensor = cudnn_frontend::TensorBuilder() + .setDim(4, afterBMM1_dim) + .setStride(4, afterBMM1_stride) + .setId(tensor_name_to_uid["VIRTUAL"] + 201) + .setAlignment(16) // 16B alignment is needed to run a tensor core engine + .setDataType(CUDNN_DATA_FLOAT) + .setVirtual(true) + .setByValue(false) + .setReorderType(cudnn_frontend::cudnnBackendTensorReordering_t:: + CUDNN_TENSOR_REORDERING_F16x16) + .build(); + // Scale after dropout + auto scaleDropoutTensor = tensor_create( + CUDNN_DATA_FLOAT, tensor_name_to_uid["DROPOUT_SCALE"], + scale_dim, scale_stride, false, true); // is by value + // After Scale + auto afterDropout_before_quan_S = tensor_create( + CUDNN_DATA_FLOAT, tensor_name_to_uid["VIRTUAL"] + 202, + afterBMM1_dim, afterBMM1_stride, true, false); // is virtual + + // Define the reduction descriptor + auto rngDesc = cudnn_frontend::RngDescBuilder() + .setRngDistribution(CUDNN_RNG_DISTRIBUTION_BERNOULLI) + .setBernoulliDistProbability(1.0 - probability) + .build(); + + // Create a rng Node + auto rng_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_RNG_DESCRIPTOR) + .setyDesc(dropoutMaskTensor) + .setSeedDesc(dropoutSeedTensor) + .setOffsetDesc(dropoutOffsetTensor) + .setRngDesc(rngDesc) + .build(); + + + // Define the multiply mask descriptor + auto maskMulDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_MUL); + + // Create a multiply mask Node + auto maskMul_op = binary_pw_op_create( + 
beforeDropoutTensor, dropoutMaskTensor, + beforeDropoutScaleTensor, maskMulDesc); + + // Define the multiply scale descriptor + auto scaleMulDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_MUL); + + // Create a multiply mask Node + auto scaleMul_op = binary_pw_op_create( + beforeDropoutScaleTensor, scaleDropoutTensor, + afterDropout_before_quan_S, scaleMulDesc); + + ops->push_back(std::move(rng_op)); + ops->push_back(std::move(maskMul_op)); + ops->push_back(std::move(scaleMul_op)); + + return afterDropout_before_quan_S; +} + +static cudnn_frontend::Tensor createDropoutBackward( + int64_t b, int64_t h, int64_t s_q, int64_t s_kv, + double probability, + std::vector* ops, + const cudnn_frontend::Tensor& beforeDropoutTensor, + const cudnn_frontend::Tensor& dropoutMaskTensor) { + cudnn_frontend::throw_if(ops->size() == 0, + "Dropout DAG constructed incorrectly as the first one", + CUDNN_STATUS_BAD_PARAM); + + int64_t afterBMM1_dim[4] = {b, h, s_q, s_kv}; + int64_t afterBMM1_stride[4] = {h * s_q * s_kv, s_q * s_kv, s_kv, 1}; + + int64_t scale_dim[4] = {1, 1, 1, 1}; + int64_t scale_stride[4] = {1, 1, 1, 1}; + + auto dropoutSeedTensor = tensor_create( + CUDNN_DATA_INT64, tensor_name_to_uid["DROPOUT_SEED"], + scale_dim, scale_stride, false, false); // is by value + auto dropoutOffsetTensor = tensor_create( + CUDNN_DATA_INT64, tensor_name_to_uid["DROPOUT_OFFSET"], + scale_dim, scale_stride, false, false); // is by value + + // After dropout tensor befor scale + auto beforeDropoutScaleTensor = cudnn_frontend::TensorBuilder() + .setDim(4, afterBMM1_dim) + .setStride(4, afterBMM1_stride) + .setId(tensor_name_to_uid["VIRTUAL"] + 201) + .setAlignment(16) // 16B alignment is needed to run a tensor core engine + .setDataType(CUDNN_DATA_FLOAT) + .setVirtual(true) + .setByValue(false) + .setReorderType(cudnn_frontend::cudnnBackendTensorReordering_t:: + CUDNN_TENSOR_REORDERING_F16x16) + .build(); + // Scale after dropout (1 / (1 - p)) + auto scaleDropoutTensor = tensor_create( + CUDNN_DATA_FLOAT, tensor_name_to_uid["DROPOUT_SCALE"], + scale_dim, scale_stride, false, true); // is by value + // After Scale + auto afterDropout_before_quan_S = tensor_create( + CUDNN_DATA_FLOAT, tensor_name_to_uid["VIRTUAL"] + 202, + afterBMM1_dim, afterBMM1_stride, true, false); // is virtual + + // Define the reduction descriptor + auto rngDesc = cudnn_frontend::RngDescBuilder() + .setRngDistribution(CUDNN_RNG_DISTRIBUTION_BERNOULLI) + .setBernoulliDistProbability(1.0 - probability) + .build(); + + // Create a rng Node + auto rng_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_RNG_DESCRIPTOR) + .setyDesc(dropoutMaskTensor) + .setSeedDesc(dropoutSeedTensor) + .setOffsetDesc(dropoutOffsetTensor) + .setRngDesc(rngDesc) + .build(); + + // Define the multiply mask descriptor + auto maskMulDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_MUL); + + // Create a multiply mask Node + auto maskMul_op = binary_pw_op_create( + beforeDropoutTensor, dropoutMaskTensor, + beforeDropoutScaleTensor, maskMulDesc); + + // Define the multiply scale descriptor + auto scaleMulDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_MUL); + + // Create a multiply mask Node + auto scaleMul_op = binary_pw_op_create( + beforeDropoutScaleTensor, scaleDropoutTensor, + afterDropout_before_quan_S, scaleMulDesc); + + ops->push_back(std::move(rng_op)); + ops->push_back(std::move(maskMul_op)); + ops->push_back(std::move(scaleMul_op)); + + return afterDropout_before_quan_S; +} + +static cudnn_frontend::Tensor createSoftmaxBackward( + 
int64_t b, int64_t h, int64_t s_q, int64_t s_kv, + std::vector* ops, + const cudnn_frontend::Tensor& dyTensor) { + cudnn_frontend::throw_if(ops->size() == 0, + "Softmax backward constructed incorrectly as the first one", + CUDNN_STATUS_BAD_PARAM); + + int64_t dx_dim[4] = {b, h, s_q, s_kv}; + int64_t dx_stride[4] = {h * s_q * s_kv, s_q * s_kv, s_kv, 1}; + + int64_t M_Z_dim[4] = {b, h, s_q, 1}; + int64_t M_Z_stride[4] = {h * s_q, s_q, 1, 1}; + + // Creating all tensors + auto MTensor = tensor_create( + CUDNN_DATA_FLOAT, tensor_name_to_uid["M"], + M_Z_dim, M_Z_stride, false, false); // not virtual + auto ZInvTensor = tensor_create( + CUDNN_DATA_FLOAT, tensor_name_to_uid["Z_INV"], + M_Z_dim, M_Z_stride, false, false); // not virtual + auto dxAfterSubtractionTensor = tensor_create( + CUDNN_DATA_FLOAT, tensor_name_to_uid["VIRTUAL"] + 252, + dx_dim, dx_stride, true, false); // is virtual + auto dxAfterExponentiation = tensor_create( + CUDNN_DATA_FLOAT, tensor_name_to_uid["VIRTUAL"] + 253, + dx_dim, dx_stride, true, false); // is virtual + auto dxBeforeDropout_QKt_Tensor = tensor_create( + CUDNN_DATA_FLOAT, tensor_name_to_uid["VIRTUAL"] + 254, + dx_dim, dx_stride, true, false); // is virtual + + // Creating all ops + // sub (dy - M) + auto subtractionDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_SUB); + auto subtractionOp = binary_pw_op_create( + dyTensor, MTensor, dxAfterSubtractionTensor, subtractionDesc); + + // Define the exponent descriptor + auto exponentDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_EXP); + + // Create a exponent Node. (exp(dy - M)) + auto exponentOp = unary_pw_op_create( + dxAfterSubtractionTensor, dxAfterExponentiation, exponentDesc); + + // Define the pw multiply descriptor + auto multiplyDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_MUL); + + // Create a multiply Node + auto mutliplyOp = binary_pw_op_create( + dxAfterExponentiation, ZInvTensor, dxBeforeDropout_QKt_Tensor, multiplyDesc); + + ops->push_back(std::move(subtractionOp)); + ops->push_back(std::move(exponentOp)); + ops->push_back(std::move(mutliplyOp)); + + return dxBeforeDropout_QKt_Tensor; +} + +static cudnn_frontend::Tensor createQKBMM( + int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d, + NVTE_QKV_Layout layout, + cudnnDataType_t tensorType, + std::vector* ops, + const cudnn_frontend::Tensor &qTensor, + const cudnn_frontend::Tensor &kTensor, + const cudnn_frontend::Tensor &mnkOverride, + std::shared_ptr QKVRaggedOffsetTensor) { + // Creates the necessary tensor descriptors + int64_t k_transpose_dim[4] = {b, h, d, s_kv}; + int64_t k_transpose_stride[4]; + generateMatrixStrides( + b, h, s_q, s_kv, d, + k_transpose_stride, layout, NVTE_QKV_Matrix::NVTE_K_Matrix_Transpose); + + int64_t s_dim[4] = {b, h, s_q, s_kv}; + int64_t s_stride[4]; + generateMatrixStrides(b, h, s_q, s_kv, d, s_stride, layout, NVTE_QKV_Matrix::NVTE_S_Matrix); + + auto kTransposeTensor = tensor_create_with_offset( + tensorType, tensor_name_to_uid["K_TRANSPOSE"], + k_transpose_dim, k_transpose_stride, + false, false, QKVRaggedOffsetTensor); // is virtual + + // First GEMM output + auto afterQKTensor = tensor_create( + CUDNN_DATA_FLOAT, tensor_name_to_uid["VIRTUAL"] + 1, + s_dim, s_stride, true, false); // is virtual + + // Define the matmul desc + auto matmulDesc = cudnn_frontend::MatMulDescBuilder() + .setComputeType(CUDNN_DATA_FLOAT) + .setPaddingValue(-2000000) + .build(); + + // Create reshape node for K -> K.T + auto reshape_op = cudnn_frontend::OperationBuilder( + 
CUDNN_BACKEND_OPERATION_RESHAPE_DESCRIPTOR) + .setxDesc(kTensor) + .setyDesc(kTransposeTensor) + .build(); + + // Create a matmul Node + auto matmulOp = cudnn_frontend::OperationBuilder( + CUDNN_BACKEND_OPERATION_MATMUL_DESCRIPTOR) + .setaMatDesc(qTensor) + .setbMatDesc(kTransposeTensor) + .setcMatDesc(afterQKTensor) + .setmOverrideDesc(mnkOverride) + .setnOverrideDesc(mnkOverride) + .setmatmulDesc(matmulDesc) + .build(); + + ops->push_back(std::move(reshape_op)); + ops->push_back(std::move(matmulOp)); + + return afterQKTensor; +} + +static cudnn_frontend::Tensor createSVBMM( + int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d, + NVTE_QKV_Layout layout, + cudnnDataType_t tensorType, + std::vector* ops, + const cudnn_frontend::Tensor &softmaxTensor, + const cudnn_frontend::Tensor &mnkOverride, + std::shared_ptr QKVRaggedOffsetTensor) { + cudnn_frontend::throw_if(ops->size() == 0, + "BMM2 op constructed incorrectly as the first one", + CUDNN_STATUS_BAD_PARAM); + + int64_t v_dim[4] = {b, h, s_kv, d}; + int64_t v_stride[4]; + generateMatrixStrides(b, h, s_q, s_kv, d, v_stride, layout, NVTE_QKV_Matrix::NVTE_V_Matrix); + + int64_t o_dim[4] = {b, h, s_q, d}; + int64_t o_stride[4]; + generateMatrixStrides(b, h, s_q, s_kv, d, o_stride, layout, NVTE_QKV_Matrix::NVTE_O_Matrix); + + auto vTensor = tensor_create_with_offset( + tensorType, tensor_name_to_uid["V"], + v_dim, v_stride, false, false, QKVRaggedOffsetTensor); + // Second fprop GEMM output + auto oTensor = tensor_create( + tensorType, tensor_name_to_uid["VIRTUAL"] + 300, + o_dim, o_stride, true, false); // is virtual + + // Define the matmul desc + auto matmulDesc = cudnn_frontend::MatMulDescBuilder() + .setComputeType(CUDNN_DATA_FLOAT) + .build(); + + // Create a matmul Node + auto matmulOp = cudnn_frontend::OperationBuilder( + CUDNN_BACKEND_OPERATION_MATMUL_DESCRIPTOR) + .setaMatDesc(softmaxTensor) + .setbMatDesc(vTensor) + .setcMatDesc(oTensor) + .setmOverrideDesc(mnkOverride) + .setkOverrideDesc(mnkOverride) + .setmatmulDesc(matmulDesc) + .build(); + + ops->push_back(std::move(matmulOp)); + + return oTensor; +} + +static cudnn_frontend::Tensor createSdOBMM( + int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d, + cudnnDataType_t tensorType, + std::vector* ops, + const cudnn_frontend::Tensor &softmaxTensor, + const cudnn_frontend::Tensor &dOTensor, + const cudnn_frontend::Tensor &mnkOverride) { + cudnn_frontend::throw_if(ops->size() == 0, + "BMM2 op constructed incorrectly as the first one", + CUDNN_STATUS_BAD_PARAM); + + int64_t s_dim_transpose[4] = {b, h, s_kv, s_q}; + int64_t s_stride_transpose[4] = {h * s_kv * s_q, s_kv * s_q, 1, s_kv}; + + int64_t v_dim[4] = {b, h, s_kv, d}; + int64_t v_stride[4] = {h * s_kv * d, d, h * d, 1}; + + auto sTransposeTensor = tensor_create( + tensorType, tensor_name_to_uid["VIRTUAL"] + 499, + s_dim_transpose, s_stride_transpose, + true, false); // is virtual + // S.T * dO + auto dVTensor_before_dequan_S = tensor_create( + CUDNN_DATA_FLOAT, tensor_name_to_uid["VIRTUAL"] + 500, + v_dim, v_stride, + true, false); // is virtual + + // Create reshape node for softmax -> softmax.T + auto reshape_op = cudnn_frontend::OperationBuilder( + CUDNN_BACKEND_OPERATION_RESHAPE_DESCRIPTOR) + .setxDesc(softmaxTensor) + .setyDesc(sTransposeTensor) + .build(); + + // Define the matmul desc + auto matmulDesc = cudnn_frontend::MatMulDescBuilder() + .setComputeType(CUDNN_DATA_FLOAT) + .setPaddingValue(0) + .build(); + + // Create a matmul Node + auto matmulOp = cudnn_frontend::OperationBuilder( + 
CUDNN_BACKEND_OPERATION_MATMUL_DESCRIPTOR) + .setaMatDesc(sTransposeTensor) + .setbMatDesc(dOTensor) + .setcMatDesc(dVTensor_before_dequan_S) + .setmOverrideDesc(mnkOverride) + .setkOverrideDesc(mnkOverride) + .setmatmulDesc(matmulDesc) + .build(); + + ops->push_back(std::move(reshape_op)); + ops->push_back(std::move(matmulOp)); + + return dVTensor_before_dequan_S; +} + +static cudnn_frontend::Tensor createdOVBMM( + int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d, + NVTE_QKV_Layout layout, + cudnnDataType_t tensorType, + std::vector* ops, + const cudnn_frontend::Tensor &dOTensor, + const cudnn_frontend::Tensor &mnkOverride, + std::shared_ptr QKVRaggedOffsetTensor) { + // Creates the necessary tensor descriptors + int64_t v_dim[4] = {b, h, s_kv, d}; + int64_t v_stride[4]; + generateMatrixStrides(b, h, s_q, s_kv, d, v_stride, layout, NVTE_QKV_Matrix::NVTE_V_Matrix); + + int64_t v_transpose_dim[4] = {b, h, d, s_kv}; + int64_t v_transpose_stride[4]; + v_transpose_stride[0] = v_stride[0]; + v_transpose_stride[1] = v_stride[1]; + v_transpose_stride[2] = v_stride[3]; + v_transpose_stride[3] = v_stride[2]; + + int64_t s_dim[4] = {b, h, s_q, s_kv}; + int64_t s_stride[4]; + generateMatrixStrides(b, h, s_q, s_kv, d, s_stride, layout, NVTE_QKV_Matrix::NVTE_S_Matrix); + + auto vTensor = tensor_create_with_offset( + tensorType, tensor_name_to_uid["V"], + v_dim, v_stride, + false, false, QKVRaggedOffsetTensor); + auto vTransposeTensor = tensor_create_with_offset( + tensorType, tensor_name_to_uid["V_TRANSPOSE"], + v_transpose_dim, v_transpose_stride, + false, false, QKVRaggedOffsetTensor); // is virtual + + // dO * V.T + auto afterdOVTensor = tensor_create( + CUDNN_DATA_FLOAT, tensor_name_to_uid["VIRTUAL"] + 600, + s_dim, s_stride, true, false); // is virtual + + // Define the matmul desc + auto matmulDesc = cudnn_frontend::MatMulDescBuilder() + .setComputeType(CUDNN_DATA_FLOAT) + .setPaddingValue(0) + .build(); + + // Create reshape node for V -> V.T + auto reshape_op = cudnn_frontend::OperationBuilder( + CUDNN_BACKEND_OPERATION_RESHAPE_DESCRIPTOR) + .setxDesc(vTensor) + .setyDesc(vTransposeTensor) + .build(); + + // Create a matmul Node + auto matmulOp = cudnn_frontend::OperationBuilder( + CUDNN_BACKEND_OPERATION_MATMUL_DESCRIPTOR) + .setaMatDesc(dOTensor) + .setbMatDesc(vTransposeTensor) + .setcMatDesc(afterdOVTensor) + .setmOverrideDesc(mnkOverride) + .setnOverrideDesc(mnkOverride) + .setmatmulDesc(matmulDesc) + .build(); + + ops->push_back(std::move(reshape_op)); + ops->push_back(std::move(matmulOp)); + + return afterdOVTensor; +} + +static cudnn_frontend::Tensor createdOAndORowReductionChain( + int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d, + NVTE_QKV_Layout layout, + std::vector* ops, + const cudnn_frontend::Tensor &O_after_dequan, + const cudnn_frontend::Tensor &dO_after_dequan, + const cudnn_frontend::Tensor &dropoutScale_dOVt_OdO_Tensor) { + int64_t o_dim[4] = {b, h, s_q, d}; + int64_t o_stride[4]; + generateMatrixStrides(b, h, s_q, s_kv, d, o_stride, layout, NVTE_QKV_Matrix::NVTE_O_Matrix); + int64_t o_dim_row_sum[4] = {b, h, s_q, 1}; + int64_t o_dim_row_sum_stride[4] = {s_q * h, s_q, 1, 1}; + + auto O_dO_after_pointwise_multiply = tensor_create( + CUDNN_DATA_FLOAT, tensor_name_to_uid["VIRTUAL"] + 700, + o_dim, o_stride, true, false); // is virtual + auto O_dO_after_dropout_scale = tensor_create( + CUDNN_DATA_FLOAT, tensor_name_to_uid["VIRTUAL"] + 701, + o_dim, o_stride, true, false); // is virtual + auto O_dO_after_rowsum = tensor_create( + CUDNN_DATA_FLOAT, 
tensor_name_to_uid["VIRTUAL"] + 702, + o_dim_row_sum, o_dim_row_sum_stride, true, false); // is virtual + + // Define the pw multiply descriptor + auto multiplyDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_MUL); + + // Create a multiply Node + auto mutliply_op = binary_pw_op_create( + O_after_dequan, dO_after_dequan, + O_dO_after_pointwise_multiply, multiplyDesc); + + // Create multiply node with dropout scale + auto dropout_scale_multiply_op = binary_pw_op_create( + O_dO_after_pointwise_multiply, dropoutScale_dOVt_OdO_Tensor, + O_dO_after_dropout_scale, multiplyDesc); + + // Define the reduction descriptor + auto reductionAddDesc = cudnn_frontend::ReductionDescBuilder() + .setComputeType(CUDNN_DATA_FLOAT) + .setReductionOp(CUDNN_REDUCE_TENSOR_ADD) + .build(); + + // Create a reduction add Node + auto reductionAdd_op = cudnn_frontend::OperationBuilder( + CUDNN_BACKEND_OPERATION_REDUCTION_DESCRIPTOR) + .setxDesc(O_dO_after_dropout_scale) + .setyDesc(O_dO_after_rowsum) + .setreductionDesc(reductionAddDesc) + .build(); + + ops->push_back(std::move(mutliply_op)); + ops->push_back(std::move(dropout_scale_multiply_op)); + ops->push_back(std::move(reductionAdd_op)); + + return O_dO_after_rowsum; +} + +static cudnn_frontend::Tensor createBiasSubtractionSoftmaxMulChain( + int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d, + NVTE_QKV_Layout layout, + std::vector* ops, + const cudnn_frontend::Tensor &dS_after_dropout, + const cudnn_frontend::Tensor &AfterDropout_before_quan_S, + const cudnn_frontend::Tensor &O_dO_after_rowsum, + const cudnn_frontend::Tensor &attnScale) { + int64_t o_dim[4] = {b, h, s_q, s_kv}; + int64_t o_stride[4]; + generateMatrixStrides(b, h, s_q, s_kv, d, o_stride, layout, NVTE_QKV_Matrix::NVTE_S_Matrix); + auto dS_minus_O_dO = tensor_create( + CUDNN_DATA_FLOAT, tensor_name_to_uid["VIRTUAL"] + 800, + o_dim, o_stride, true, false); // is virtual + auto AfterAttnScale_before_dS = tensor_create( + CUDNN_DATA_FLOAT, tensor_name_to_uid["VIRTUAL"] + 801, + o_dim, o_stride, true, false); // is virtual + auto S_mul_dS_minus_O_dO = tensor_create( + CUDNN_DATA_FLOAT, tensor_name_to_uid["VIRTUAL"] + 802, + o_dim, o_stride, true, false); // is virtual + + // Define the pw subtraction descriptor + auto subDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_SUB); + + // Create a subtraction Node + auto sub_op = binary_pw_op_create( + dS_after_dropout, O_dO_after_rowsum, dS_minus_O_dO, subDesc); + + // Define the pw multiplication descriptor + auto multiplyDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_MUL); + + // dS_minus_O_dO * attnScale + auto mutliply_attn_scale_op = binary_pw_op_create( + dS_minus_O_dO, attnScale, + AfterAttnScale_before_dS, multiplyDesc); + + // AfterDropout_before_quan_S * AfterAttnScale_before_dS + auto mutliply_op = binary_pw_op_create( + AfterDropout_before_quan_S, AfterAttnScale_before_dS, + S_mul_dS_minus_O_dO, multiplyDesc); + + ops->push_back(std::move(sub_op)); + ops->push_back(std::move(mutliply_attn_scale_op)); + ops->push_back(std::move(mutliply_op)); + + return S_mul_dS_minus_O_dO; +} + +static cudnn_frontend::Tensor createdSKBMM( + int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d, + std::vector* ops, + const cudnn_frontend::Tensor &dSTensor, + const cudnn_frontend::Tensor &kTensor, + const cudnn_frontend::Tensor &mnkOverride) { + // Creates the necessary tensor descriptors + int64_t after_dSK_dim[4] = {b, h, s_kv, d}; + int64_t after_dSK_stride[4] = {h * s_kv * d, d, h * d, 1}; + // dS * K + auto After_dS_K = 
tensor_create( + CUDNN_DATA_FLOAT, tensor_name_to_uid["VIRTUAL"] + 875, + after_dSK_dim, after_dSK_stride, true, false); // is virtual + + // Define the matmul desc + auto matmulDesc = cudnn_frontend::MatMulDescBuilder() + .setComputeType(CUDNN_DATA_FLOAT) + .setPaddingValue(0) + .build(); + + // Create a matmul Node + auto matmulOp = cudnn_frontend::OperationBuilder( + CUDNN_BACKEND_OPERATION_MATMUL_DESCRIPTOR) + .setaMatDesc(dSTensor) + .setbMatDesc(kTensor) + .setcMatDesc(After_dS_K) + .setmOverrideDesc(mnkOverride) + .setkOverrideDesc(mnkOverride) + .setmatmulDesc(matmulDesc) + .build(); + + ops->push_back(std::move(matmulOp)); + + return After_dS_K; +} + +static cudnn_frontend::Tensor createdSQBMM( + int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d, + NVTE_QKV_Layout layout, + std::vector* ops, + const cudnn_frontend::Tensor &dSTensor, + const cudnn_frontend::Tensor &qTensor, + const cudnn_frontend::Tensor &mnkOverride) { + // Creates the necessary tensor descriptors + int64_t dS_stride[4]; + generateMatrixStrides(b, h, s_q, s_kv, d, dS_stride, layout, NVTE_QKV_Matrix::NVTE_S_Matrix); + + int64_t dS_transpose_dim[4] = {b, h, s_kv, s_q}; + int64_t dS_transpose_stride[4]; + dS_transpose_stride[0] = dS_stride[0]; + dS_transpose_stride[1] = dS_stride[1]; + dS_transpose_stride[2] = dS_stride[3]; + dS_transpose_stride[3] = dS_stride[2]; + + int64_t after_dSTranspose_Q_dim[4] = {b, h, s_kv, d}; + int64_t after_dSTranspose_Q_stride[4] = {h * s_kv * d, d, h * d, 1}; + + auto dSTransposeTensor = tensor_create( + CUDNN_DATA_FP8_E5M2, tensor_name_to_uid["VIRTUAL"] + 650, + dS_transpose_dim, dS_transpose_stride, true, false); // is virtual + + // dS.T * Q + auto After_dSTranspose_Q = tensor_create( + CUDNN_DATA_FLOAT, tensor_name_to_uid["VIRTUAL"] + 651, + after_dSTranspose_Q_dim, after_dSTranspose_Q_stride, + true, false); // is virtual + + // Create reshape node for V -> V.T + auto reshape_op = cudnn_frontend::OperationBuilder( + CUDNN_BACKEND_OPERATION_RESHAPE_DESCRIPTOR) + .setxDesc(dSTensor) + .setyDesc(dSTransposeTensor) + .build(); + + // Define the matmul desc + auto matmulDesc = cudnn_frontend::MatMulDescBuilder() + .setComputeType(CUDNN_DATA_FLOAT) + .setPaddingValue(0) + .build(); + + // Create a matmul Node + auto matmulOp = cudnn_frontend::OperationBuilder( + CUDNN_BACKEND_OPERATION_MATMUL_DESCRIPTOR) + .setaMatDesc(dSTransposeTensor) + .setbMatDesc(qTensor) + .setcMatDesc(After_dSTranspose_Q) + .setmOverrideDesc(mnkOverride) + .setkOverrideDesc(mnkOverride) + .setmatmulDesc(matmulDesc) + .build(); + + ops->push_back(std::move(reshape_op)); + ops->push_back(std::move(matmulOp)); + + return After_dSTranspose_Q; +} + +// fused attention FWD FP8 +void fa_fwd_fp8(int64_t b, int64_t s_q, int64_t s_kv, int64_t h, int64_t d, + bool isTraining, float attnScale, + float dropoutProbability, NVTE_QKV_Layout layout, + void* devPtrQ, void* devPtrK, void* devPtrV, + void* devPtrM, void* devPtrZInv, + void* devPtrO, + void* devPtrDescaleQ, void* devPtrDescaleK, void* devPtrDescaleV, + void* devPtrDescaleS, void* devPtrScaleS, void* devPtrScaleO, + void* devPtrAmaxO, void* devPtrAmaxS, + void* devPtrcuSeqlensQ, void* devPtrcuSeqlensKV, + void* devPtrDropoutSeed, void* devPtrDropoutOffset, + cudnnDataType_t tensorType, + void* workspace_ptr, + size_t* workspace_size, + cudaStream_t stream, + cudnnHandle_t handle_) { + try { + NVTE_CHECK_CUDNN(cudnnSetStream(handle_, stream)); + + FADescriptor descriptor{ + b, h, s_q, s_kv, d, + attnScale, isTraining, dropoutProbability, layout, tensorType}; 
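Both fa_fwd_fp8 and fa_bwd_fp8 reuse execution plans through a static std::map keyed by this FADescriptor (whose operator< is defined via std::tie in utils.h later in this patch), so graph construction and heuristics only run on the first call for a given problem shape. A minimal, self-contained sketch of that descriptor-keyed cache pattern — with stand-in Descriptor and Plan types rather than the real cudnn_frontend classes, since building an actual plan needs a cuDNN handle — looks like this:

    // Illustrative sketch only: "Descriptor" and "Plan" stand in for
    // FADescriptor and cudnn_frontend::ExecutionPlan from the patch.
    #include <cstdint>
    #include <map>
    #include <string>
    #include <tuple>

    struct Descriptor {
        std::int64_t b, h, s_q, s_kv, d;
        float attnScale;
        bool isTraining;
        // std::map needs a strict weak ordering; std::tie provides it
        // lexicographically over all fields, as in FADescriptor::operator<.
        bool operator<(const Descriptor& rhs) const {
            return std::tie(b, h, s_q, s_kv, d, attnScale, isTraining) <
                   std::tie(rhs.b, rhs.h, rhs.s_q, rhs.s_kv, rhs.d,
                            rhs.attnScale, rhs.isTraining);
        }
    };

    struct Plan { std::string tag; };  // placeholder for an execution plan

    Plan get_plan(const Descriptor& desc) {
        static std::map<Descriptor, Plan> cache;  // one entry per problem shape
        auto it = cache.find(desc);
        if (it != cache.end()) {
            return it->second;  // cache hit: reuse the previously built plan
        }
        Plan plan{"built once per shape"};  // expensive graph build goes here
        cache.insert({desc, plan});
        return plan;
    }

In the patch itself, the miss path below builds the operation graph, queries the "heuristics_instant" list, constructs the ExecutionPlan from the first supported engine config, and inserts it into the cache before returning.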
+ + using CacheType = std::map; + static CacheType fa_fprop_cache; + + // Get plan from cache if cache is available, otherwise create one + auto get_plan = [&](CacheType &cache, const FADescriptor &descriptor) { + // If hit, return + auto it = cache.find(descriptor); + if (it != cache.end()) { + auto plan = it->second; + return plan; + } + + // Otherwise, build the op_graph and the plan. Then update cache + std::vector all_ops; + std::vector ops; + + cudnn_frontend::throw_if(dropoutProbability != 0.0f && !isTraining, + "Dropout probability should be 0.0f for inference mode", + CUDNN_STATUS_BAD_PARAM); + cudnn_frontend::throw_if(dropoutProbability == 1.0f, + "Dropout probability cannot be 1.0", + CUDNN_STATUS_BAD_PARAM); + + int64_t raggedDim[4] = {b + 1, 1, 1, 1}; + int64_t raggedStride[4] = {1, 1, 1, 1}; + // Create offset tensors + auto QKVOffsetTensor = tensor_create( + CUDNN_DATA_INT32, tensor_name_to_uid["QKV_RAGGED"], + raggedDim, raggedStride, false, false); + auto ORaggedOffsetTensor = tensor_create( + CUDNN_DATA_INT32, tensor_name_to_uid["O_RAGGED"], + raggedDim, raggedStride, false, false); + + int64_t seqlen_dim[4] = {b, 1, 1, 1}; + int64_t seqlen_stride[4] = {1, 1, 1, 1}; + // Create override tensors + auto seqlenMNKTensor = tensor_create( + CUDNN_DATA_INT32, tensor_name_to_uid["MNK_OVERRIDE"], + seqlen_dim, seqlen_stride, false, false); + + // Create shared ptrs to ragged offset tensors + // for multiple tensors to use ragged offset + std::shared_ptr QKVRaggedOffsetTensorPtr = + std::make_shared(std::move(QKVOffsetTensor)); + std::shared_ptr ORaggedOffsetTensorPtr = + std::make_shared(std::move(ORaggedOffsetTensor)); + + // Create Q and K tensors that are used in different places + int64_t q_dim[4] = {b, h, s_q, d}; + int64_t q_stride[4]; + generateMatrixStrides(b, h, s_q, s_kv, d, q_stride, layout, + NVTE_QKV_Matrix::NVTE_Q_Matrix); + + int64_t k_dim[4] = {b, h, s_kv, d}; + int64_t k_stride[4]; + generateMatrixStrides(b, h, s_q, s_kv, d, k_stride, layout, + NVTE_QKV_Matrix::NVTE_K_Matrix); + + auto qTensor = tensor_create_with_offset( + tensorType, tensor_name_to_uid["Q"], + q_dim, q_stride, false, false, + QKVRaggedOffsetTensorPtr); + auto kTensor = tensor_create_with_offset( + tensorType, tensor_name_to_uid["K"], + k_dim, k_stride, false, false, + QKVRaggedOffsetTensorPtr); + + // Q * K.T + auto afterQKTensor = createQKBMM( + b, h, s_q, s_kv, d, layout, tensorType, + &ops, qTensor, kTensor, + seqlenMNKTensor, QKVRaggedOffsetTensorPtr); + + // QK.T * attn scale + auto AfterAttnScale_before_dequan_Q_tensor = createScale( + afterQKTensor, // input tensor + "AttnScale", // scale tensor + CUDNN_DATA_FLOAT, // output tensor type + true, // output is virtual + true, // scale is by value + &ops); + + // QK.T * attn scale * dequant_Q + auto AfterAttnScale_before_dequan_K_tensor = createScale( + AfterAttnScale_before_dequan_Q_tensor, // input tensor + "descaleQ", // scale tensor + CUDNN_DATA_FLOAT, // output tensor type + true, // output is virtual + false, // scale is by value + &ops); + + // QK.T * attn scale * dequant_Q * dequant_K + auto AfterAttnScale_tensor = createScale( + AfterAttnScale_before_dequan_K_tensor, // input tensor + "descaleK", // scale tensor + CUDNN_DATA_FLOAT, // output tensor type + true, // output is virtual + false, // scale is by value + &ops); + + auto BeforeDropoutTensor = createSoftmaxForward( + b, h, s_q, s_kv, &ops, + AfterAttnScale_tensor, isTraining); + + auto AfterDropout_before_quan_S = createDropoutForward( + b, h, s_q, s_kv, dropoutProbability, + 
&ops, BeforeDropoutTensor); + + // Amax for S + createAmax("amaxS", BeforeDropoutTensor, &ops); + + // After softmax * dropout * scale S -> fp8 input to next bmm with V + auto AfterMultiplyDropout = createScale( + AfterDropout_before_quan_S, // input tensor + "scaleS", // scale tensor + tensorType, // output tensor type + true, // output is virtual + false, // scale is by value + &ops); + + // After softmax * Dropout * V + auto OTensor_before_dequan_S_tensor = createSVBMM( + b, h, s_q, s_kv, d, layout, tensorType, + &ops, AfterMultiplyDropout, + seqlenMNKTensor, QKVRaggedOffsetTensorPtr); + + // O * dequant_S + auto OTensor_before_dequan_V_tensor = createScale( + OTensor_before_dequan_S_tensor, // input tensor + "descaleS", // scale tensor + CUDNN_DATA_FLOAT, // output tensor type + true, // output is virtual + false, // scale is by value + &ops); + + // O * dequant_S * dequant_V + auto OTensor_before_quan_O_tensor = createScale( + OTensor_before_dequan_V_tensor, // input tensor + "descaleV", // scale tensor + CUDNN_DATA_FLOAT, // output tensor type + true, // output is virtual + false, // scale is by value + &ops); + + // O * dequant_S * dequant_V * scale O + auto OTensor = createScaleWithOffset( + OTensor_before_quan_O_tensor, // input tensor + "scaleO", // scale tensor + tensorType, // output tensor type + false, // output not virtual + false, // scale is by value + &ops, + ORaggedOffsetTensorPtr, // ragged offset + "O"); + + // Amax for O + createAmax("amaxO", OTensor_before_quan_O_tensor, &ops); + + for (unsigned int i = 0; i < ops.size(); i++) { + all_ops.push_back(&ops[i]); + } + + // Create an Operation Graph + auto opGraph = cudnn_frontend::OperationGraphBuilder() + .setHandle(handle_) + .setOperationGraph(all_ops.size(), all_ops.data()) + .build(); + + cudnn_frontend::EngineConfigList filtered_configs; + auto statuses = cudnn_frontend::get_heuristics_list<1>( + {"heuristics_instant"}, opGraph, + allowAllConfig, filtered_configs, true); + + if (filtered_configs.size() == 0) { + cudnn_frontend::set_error_and_throw_exception( + nullptr, + CUDNN_STATUS_NOT_SUPPORTED, + "run_mha_fprop: No config returned by the heuristics"); + } + + auto plan = cudnn_frontend::ExecutionPlanBuilder() + .setHandle(handle_) + .setEngineConfig(filtered_configs[0], opGraph.getTag()) + .build(); + cache.insert({descriptor, plan}); + return plan; + }; // end of get_plan + + auto plan = get_plan(fa_fprop_cache, descriptor); + size_t wkspace_size = static_cast(plan.getWorkspaceSize()); + + // Exit to request upper level API to allocate memory if needed + if (workspace_ptr == nullptr) { + *workspace_size = wkspace_size + ((b + 1) * 2 + b) * sizeof(int32_t); + return; + } + + int32_t* qkv_ragged_offset = reinterpret_cast( + reinterpret_cast(workspace_ptr) + wkspace_size); + int32_t* o_ragged_offset = reinterpret_cast( + reinterpret_cast(workspace_ptr) + + wkspace_size + (b + 1) * sizeof(int32_t)); + int32_t* actual_seqlens_q = reinterpret_cast( + reinterpret_cast(workspace_ptr) + + wkspace_size + (b + 1) * 2 * sizeof(int32_t)); + // FP8 currently only supports self-attention, so doesn't use devPtrcuSeqlensKV + dim3 blockDims(128); + dim3 gridDims((b + blockDims.x)/blockDims.x); + cu_seqlens_to_offsets<<>>( + b, h, d, reinterpret_cast(devPtrcuSeqlensQ), + actual_seqlens_q, qkv_ragged_offset, o_ragged_offset); + void* devPtrQKVRaggedOffset = reinterpret_cast(qkv_ragged_offset); + void* devPtrORaggedOffset = reinterpret_cast(o_ragged_offset); + void* devPtrMNKOverride = reinterpret_cast(actual_seqlens_q); + + float 
dropoutScale = 1.0f/(1.0f - dropoutProbability); + + std::set> data_ptrs; + data_ptrs.emplace(std::pair(tensor_name_to_uid["Q"], devPtrQ)); + data_ptrs.emplace(std::pair(tensor_name_to_uid["K"], devPtrK)); + data_ptrs.emplace(std::pair(tensor_name_to_uid["K_TRANSPOSE"], devPtrK)); + data_ptrs.emplace(std::pair(tensor_name_to_uid["V"], devPtrV)); + data_ptrs.emplace(std::pair(tensor_name_to_uid["AttnScale"], &attnScale)); + data_ptrs.emplace(std::pair( + tensor_name_to_uid["DROPOUT_SCALE"], &dropoutScale)); + data_ptrs.emplace(std::pair( + tensor_name_to_uid["DROPOUT_SEED"], devPtrDropoutSeed)); + data_ptrs.emplace(std::pair( + tensor_name_to_uid["DROPOUT_OFFSET"], devPtrDropoutOffset)); + data_ptrs.emplace(std::pair( + tensor_name_to_uid["O"], devPtrO)); + data_ptrs.emplace(std::pair( + tensor_name_to_uid["descaleQ"], devPtrDescaleQ)); + data_ptrs.emplace(std::pair( + tensor_name_to_uid["descaleK"], devPtrDescaleK)); + data_ptrs.emplace(std::pair( + tensor_name_to_uid["descaleV"], devPtrDescaleV)); + data_ptrs.emplace(std::pair( + tensor_name_to_uid["descaleS"], devPtrDescaleS)); + data_ptrs.emplace(std::pair( + tensor_name_to_uid["scaleS"], devPtrScaleS)); + data_ptrs.emplace(std::pair( + tensor_name_to_uid["scaleO"], devPtrScaleO)); + data_ptrs.emplace(std::pair( + tensor_name_to_uid["amaxO"], devPtrAmaxO)); + data_ptrs.emplace(std::pair( + tensor_name_to_uid["amaxS"], devPtrAmaxS)); + data_ptrs.emplace(std::pair( + tensor_name_to_uid["QKV_RAGGED"], devPtrQKVRaggedOffset)); + data_ptrs.emplace(std::pair( + tensor_name_to_uid["O_RAGGED"], devPtrORaggedOffset)); + data_ptrs.emplace(std::pair( + tensor_name_to_uid["MNK_OVERRIDE"], devPtrMNKOverride)); + + // If training, then we need to write out M and Z_INV + if (isTraining) { + data_ptrs.emplace(std::pair( + tensor_name_to_uid["M"], devPtrM)); + data_ptrs.emplace(std::pair( + tensor_name_to_uid["Z_INV"], devPtrZInv)); + } + + auto variantPack = cudnn_frontend::VariantPackBuilder() + .setWorkspacePointer(workspace_ptr) + .setDataPointers(data_ptrs) + .build(); + cudnnStatus_t status = cudnnBackendExecute( + handle_, plan.get_raw_desc(), variantPack.get_raw_desc()); + + cudnn_frontend::throw_if( + [status]() { return (status != CUDNN_STATUS_SUCCESS); }, + "Plan execute error", status); + } catch (cudnn_frontend::cudnnException& e) { + struct cudaDeviceProp prop; + NVTE_CHECK_CUDA(cudaGetDeviceProperties(&prop, 0)); + + // This example is only for GH100 cards (cudnn Version >= 8900) + if (!((prop.major == 9 && prop.minor == 0 && CUDNN_VERSION >= 8900)) + && (e.getCudnnStatus() == CUDNN_STATUS_ARCH_MISMATCH + || e.getCudnnStatus() == CUDNN_STATUS_NOT_SUPPORTED)) { + std::cout << "Example is only supported for GH100 (cuDNN >= 8900) GPUs" << std::endl; + } else { + std::cout << "[ERROR] Exception " << e.what() << std::endl; + } + } +} + +// fused attention BWD FP8 +void fa_bwd_fp8(int64_t b, int64_t s_q, int64_t s_kv, int64_t h, int64_t d, + float attnScale, float dropoutProbability, NVTE_QKV_Layout layout, + void* devPtrQ, void* devPtrK, void* devPtrV, + void* devPtrM, void* devPtrZInv, + void* devPtrO, void* devPtrdO, + void* devPtrdQ, void* devPtrdK, void* devPtrdV, + void* devPtrDescaleQ, void* devPtrDescaleK, void* devPtrDescaleV, + void* devPtrDescaleO, void* devPtrDescaledO, + void* devPtrDescaleS, void* devPtrDescaledS, + void* devPtrScaleS, void* devPtrScaledS, + void* devPtrScaledQ, void* devPtrScaledK, void* devPtrScaledV, + void* devPtrAmaxdS, + void* devPtrAmaxdQ, void* devPtrAmaxdK, void* devPtrAmaxdV, + void* devPtrcuSeqlensQ, 
void* devPtrcuSeqlensKV, + void* devPtrDropoutSeed, void* devPtrDropoutOffset, + cudnnDataType_t tensorType, + void* workspace_ptr, + size_t* workspace_size, + cudaStream_t stream, + cudnnHandle_t handle_) { + try { + NVTE_CHECK_CUDNN(cudnnSetStream(handle_, stream)); + + FADescriptor descriptor{ + b, h, s_q, s_kv, d, + attnScale, false, dropoutProbability, layout, tensorType}; + + using CacheType = std::map; + static CacheType fa_bprop_cache; + + // Get plan from cache if cache is available, otherwise create one + auto get_plan = [&](CacheType &cache, const FADescriptor &descriptor) { + // If hit, return + auto it = cache.find(descriptor); + if (it != cache.end()) { + auto plan = it->second; + return plan; + } + + // Otherwise, build the op_graph and the plan. Then update cache + std::vector all_ops; + std::vector ops; + + cudnn_frontend::throw_if(dropoutProbability == 1.0f, + "Dropout probability cannot be 1.0", + CUDNN_STATUS_BAD_PARAM); + + int64_t raggedDim[4] = {b + 1, 1, 1, 1}; + int64_t raggedStride[4] = {1, 1, 1, 1}; + // Create offset tensors + auto QKVOffsetTensor = tensor_create( + CUDNN_DATA_INT32, tensor_name_to_uid["QKV_RAGGED"], + raggedDim, raggedStride, false, false); + auto ORaggedOffsetTensor = tensor_create( + CUDNN_DATA_INT32, tensor_name_to_uid["O_RAGGED"], + raggedDim, raggedStride, false, false); + + // Create shared ptrs to ragged offset tensors for multiple tensors + std::shared_ptr QKVRaggedOffsetTensorPtr = + std::make_shared(std::move(QKVOffsetTensor)); + std::shared_ptr ORaggedOffsetTensorPtr = + std::make_shared(std::move(ORaggedOffsetTensor)); + + // Create Q and K tensors that are used in different places + int64_t q_dim[4] = {b, h, s_q, d}; + int64_t q_stride[4]; + generateMatrixStrides(b, h, s_q, s_kv, d, q_stride, layout, + NVTE_QKV_Matrix::NVTE_Q_Matrix); + + int64_t k_dim[4] = {b, h, s_kv, d}; + int64_t k_stride[4]; + generateMatrixStrides(b, h, s_q, s_kv, d, k_stride, layout, + NVTE_QKV_Matrix::NVTE_K_Matrix); + + auto qTensor = tensor_create_with_offset( + tensorType, tensor_name_to_uid["Q"], + q_dim, q_stride, false, false, QKVRaggedOffsetTensorPtr); + auto kTensor = tensor_create_with_offset( + tensorType, tensor_name_to_uid["K"], + k_dim, k_stride, false, false, QKVRaggedOffsetTensorPtr); + + int64_t scale_dim[4] = {1, 1, 1, 1}; + int64_t scale_stride[4] = {1, 1, 1, 1}; + + // Create attnScale tensor for multiple ops to use + auto attnScaleTensor = tensor_create( + CUDNN_DATA_FLOAT, tensor_name_to_uid["AttnScale"], + scale_dim, scale_stride, false, true); // is by value + + // Create descale Q K dO dS global tensors since they are used in multiple places + auto descaleQTensor = tensor_create( + CUDNN_DATA_FLOAT, tensor_name_to_uid["descaleQ"], + scale_dim, scale_stride, false, false); + auto descaleKTensor = tensor_create( + CUDNN_DATA_FLOAT, tensor_name_to_uid["descaleK"], + scale_dim, scale_stride, false, false); + auto descaledOTensor = tensor_create( + CUDNN_DATA_FLOAT, tensor_name_to_uid["descaledO"], + scale_dim, scale_stride, false, false); + auto descaledSTensor = tensor_create( + CUDNN_DATA_FLOAT, tensor_name_to_uid["descaledS"], + scale_dim, scale_stride, false, false); + + int64_t seqlen_dim[4] = {b, 1, 1, 1}; + int64_t seqlen_stride[4] = {1, 1, 1, 1}; + // Create MNK override tensor + auto seqlenMNKTensor = tensor_create( + CUDNN_DATA_INT32, tensor_name_to_uid["MNK_OVERRIDE"], + seqlen_dim, seqlen_stride, false, false); + + int64_t O_dim[4] = {b, h, s_q, d}; + int64_t O_stride[4]; + generateMatrixStrides(b, h, s_q, s_kv, d, O_stride, 
layout, + NVTE_QKV_Matrix::NVTE_O_Matrix); + // Create O and loss tensor + auto OTensor = tensor_create_with_offset( + tensorType, tensor_name_to_uid["O"], + O_dim, O_stride, false, false, ORaggedOffsetTensorPtr); + // dO is used in multiple places and E5M2 + auto dOTensor = tensor_create_with_offset( + CUDNN_DATA_FP8_E5M2, tensor_name_to_uid["dO"], + O_dim, O_stride, false, false, ORaggedOffsetTensorPtr); + + // Q * K.T + auto afterQKTensor = createQKBMM( + b, h, s_q, s_kv, d, layout, tensorType, + &ops, qTensor, kTensor, + seqlenMNKTensor, QKVRaggedOffsetTensorPtr); + + // QK.T * attn scale + auto AfterAttnScale_before_dequan_Q_tensor = createScale( + afterQKTensor, // input tensor + attnScaleTensor, // scale tensor + CUDNN_DATA_FLOAT, // output tensor type + true, // output is virtual + true, // scale is by value + &ops, + 1999 /*UID offset*/); + + // QK.T * attn scale * dequant_Q + auto AfterAttnScale_before_dequan_K_tensor = createScale( + AfterAttnScale_before_dequan_Q_tensor, // input tensor + descaleQTensor, // scale tensor + CUDNN_DATA_FLOAT, // output tensor type + true, // output is virtual + false, // scale is by value + &ops, + 2000 /*UID offset*/); + + // QK.T * attn scale * dequant_Q * dequant_K + auto AfterAttnScale_tensor = createScale( + AfterAttnScale_before_dequan_K_tensor, // input tensor + descaleKTensor, // scale tensor + CUDNN_DATA_FLOAT, // output tensor type + true, // output is virtual + false, // scale is by value + &ops, + 2001 /*UID offset*/); + + auto beforeDropout_QKt_Tensor = createSoftmaxBackward( + b, h, s_q, s_kv, &ops, AfterAttnScale_tensor); + + int64_t afterBMM1_dim[4] = {b, h, s_q, s_kv}; + int64_t afterBMM1_stride[4] = {h * s_q * s_kv, s_q * s_kv, s_kv, 1}; + + // mask for the dropout. Used in different places + auto dropoutMaskTensor = tensor_create( + CUDNN_DATA_FLOAT, tensor_name_to_uid["VIRTUAL"] + 200, + afterBMM1_dim, afterBMM1_stride, true, false); // is virtual + + auto AfterDropout_before_quan_S = createDropoutBackward( + b, h, s_q, s_kv, dropoutProbability, + &ops, beforeDropout_QKt_Tensor, dropoutMaskTensor); + + // After softmax * scale S -> fp8 input to next bmm with V + auto AfterMultiply = createScale( + AfterDropout_before_quan_S, // input tensor + "scaleS", // scale tensor + tensorType, // output tensor type + true, // output is virtual + false, // scale is by value + &ops); + + // After softmax * dO + auto dVTensor_before_dequan_S = createSdOBMM( + b, h, s_q, s_kv, d, tensorType, + &ops, AfterMultiply, dOTensor, seqlenMNKTensor); + + // O * dequant_S + auto dVTensor_before_dequan_dO = createScale( + dVTensor_before_dequan_S, // input tensor + "descaleS", // scale tensor + CUDNN_DATA_FLOAT, // output tensor type + true, // output is virtual + false, // scale is by value + &ops); + + // O * dequant_S * dequant_dO + auto dVTensor_before_quan_dV = createScale( + dVTensor_before_dequan_dO, // input tensor + descaledOTensor, // scale tensor + CUDNN_DATA_FLOAT, // output tensor type + true, // output is virtual + false, // scale is by value + &ops, + 2002 /*UID offset*/); + + // O * dequant_S * dequant_dO * scale dV + auto dVTensor = createScaleWithOffset( + dVTensor_before_quan_dV, // input tensor + "scaledV", // scale tensor + CUDNN_DATA_FP8_E5M2, // output tensor type + false, // output not virtual + false, // scale is by value + &ops, + QKVRaggedOffsetTensorPtr, // ragged offset + "dV" /*Output tensor name*/); + + // Amax for dV + createAmax("amaxdV", dVTensor_before_quan_dV, &ops); + + auto dS_before_dequan_dO_Tensor = createdOVBMM( 
+ b, h, s_q, s_kv, d, layout, tensorType, + &ops, dOTensor, seqlenMNKTensor, QKVRaggedOffsetTensorPtr); + + // dS * dequant_dO + auto dS_before_dequan_V = createScale( + dS_before_dequan_dO_Tensor, // input tensor + descaledOTensor, // scale tensor + CUDNN_DATA_FLOAT, // output tensor type + true, // output is virtual + false, // scale is by value + &ops, + 2003 /*UID offset*/); + + // O * dequant_S * dequant_dV + auto dS_after_dequan = createScale( + dS_before_dequan_V, // input tensor + "descaleV", // scale tensor + CUDNN_DATA_FLOAT, // output tensor type + true, // output is virtual + false, // scale is by value + &ops); + + // RNG Multiply + auto beforeDropoutScale_dOVt_Tensor = tensor_create( + CUDNN_DATA_FLOAT, tensor_name_to_uid["VIRTUAL"] + 350, + afterBMM1_dim, afterBMM1_stride, true, false); // is virtual + // After dropout mask and scale + auto dS_after_dropout = tensor_create( + CUDNN_DATA_FLOAT, tensor_name_to_uid["VIRTUAL"] + 351, + afterBMM1_dim, afterBMM1_stride, true, false); // is virtual + + // Define the multiply mask descriptor + auto mulDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_MUL); + + // Create a multiply mask Node + auto maskMul_op = binary_pw_op_create( + dS_after_dequan, dropoutMaskTensor, + beforeDropoutScale_dOVt_Tensor, mulDesc); + + ops.push_back(std::move(maskMul_op)); + + // scale after dropout for dO and O chain + auto dropoutScale_dOVt_OdO_Tensor = tensor_create( + tensorType, tensor_name_to_uid["DROPOUT_SCALE_dOVt_OdO"], + scale_dim, scale_stride, false, true); // is by value + + // Create a multiply dropout scale Node + auto mul_dropout_scale_op = binary_pw_op_create( + beforeDropoutScale_dOVt_Tensor, + dropoutScale_dOVt_OdO_Tensor, + dS_after_dropout, mulDesc); + + ops.push_back(std::move(mul_dropout_scale_op)); + + // O * dequant_O + auto O_after_dequan_Tensor = createScale(OTensor, // input tensor + "descaleO", // scale tensor + CUDNN_DATA_FLOAT, // output tensor type + true, // output is virtual + false, // scale is by value + &ops); + + // dO * dequant_dO + auto dO_after_dequan_Tensor = createScale(dOTensor, // input tensor + descaledOTensor, // scale tensor + CUDNN_DATA_FLOAT, // output tensor type + true, // output is virtual + false, // scale is by value + &ops, + 2004 /*UID offset*/); + + // row reduction sum[(dO * dequant_dO) * (O * dequant_O) * (1 - p)] + auto O_dO_after_rowsum = createdOAndORowReductionChain( + b, h, s_q, s_kv, d, layout, + &ops, O_after_dequan_Tensor, + dO_after_dequan_Tensor, dropoutScale_dOVt_OdO_Tensor); + + // (dS_after_dropout - O_dO_after_rowsum) * AfterDropout_before_quan_S * attnScale + auto S_mul_dS_minus_O_dO = createBiasSubtractionSoftmaxMulChain( + b, h, s_q, s_kv, d, layout, + &ops, dS_after_dropout, + AfterDropout_before_quan_S, O_dO_after_rowsum, + attnScaleTensor); + + + // S_mul_dS_minus_O_dO * scaledS + auto S_mul_dS_minus_O_dO_after_quan_dS = createScale( + S_mul_dS_minus_O_dO, // input tensor + "scaledS", // scale tensor + CUDNN_DATA_FP8_E5M2, // output tensor type + true, // output is virtual + false, // scale is by value + &ops); + + // Amax for dS + createAmax("amaxdS", S_mul_dS_minus_O_dO, &ops); + + // dS @ K + auto After_dS_K = createdSKBMM( + b, h, s_q, s_kv, d, &ops, + S_mul_dS_minus_O_dO_after_quan_dS, + kTensor, seqlenMNKTensor); + + // (dS * K) * descale dS + auto After_dS_K_before_dequan_K = createScale( + After_dS_K, // input tensor + descaledSTensor, // scale tensor + CUDNN_DATA_FLOAT, // output tensor type + true, // output is virtual + false, // scale is by value + &ops, + 
2006 /*UID offset*/); + + // (dS * K) * descale dS * descale K + auto After_dS_K_before_quan_dQ = createScale( + After_dS_K_before_dequan_K, // input tensor + descaleKTensor, // scale tensor + CUDNN_DATA_FLOAT, // output tensor type + true, // output is virtual + false, // scale is by value + &ops, + 2007 /*UID offset*/); + + // (dS * K) * descale dS * descale K * scale dQ + auto dQ = createScaleWithOffset( + After_dS_K_before_quan_dQ, // input tensor + "scaledQ", // scale tensor + CUDNN_DATA_FP8_E5M2, // output tensor type + false, // output not virtual + false, // scale is by value + &ops, + QKVRaggedOffsetTensorPtr, // ragged offset + "dQ"); + + // Amax for dQ + createAmax("amaxdQ", After_dS_K_before_quan_dQ, &ops); + + // dS.T @ Q + auto After_dSTranspose_Q = createdSQBMM( + b, h, s_q, s_kv, d, layout, &ops, + S_mul_dS_minus_O_dO_after_quan_dS, + qTensor, seqlenMNKTensor); + + // (dS.T * Q) * descale dS + auto After_dSTranspose_Q_before_dequan_Q = createScale( + After_dSTranspose_Q, // input tensor + descaledSTensor, // scale tensor + CUDNN_DATA_FLOAT, // output tensor type + true, // output is virtual + false, // scale is by value + &ops, + 2009 /*UID offset*/); + + // (dS.T * Q) * descale dS * descale Q + auto After_dSTranspose_Q_before_quan_dK = createScale( + After_dSTranspose_Q_before_dequan_Q, // input tensor + descaleQTensor, // scale tensor + CUDNN_DATA_FLOAT, // output tensor type + true, // output is virtual + false, // scale is by value + &ops, + 2010 /*UID offset*/); + + // (dS.T * Q) * descale dS * descale Q * scale dK + auto dK = createScaleWithOffset( + After_dSTranspose_Q_before_quan_dK, // input tensor + "scaledK", // scale tensor + CUDNN_DATA_FP8_E5M2, // output tensor type + false, // output not virtual + false, // scale is by value + &ops, + QKVRaggedOffsetTensorPtr, // ragged offset + "dK"); + + // Amax for dK + createAmax("amaxdK", After_dSTranspose_Q_before_quan_dK, &ops); + + for (unsigned int i = 0; i < ops.size(); i++) { + all_ops.push_back(&ops[i]); + } + + // Create an Operation Graph + auto opGraph = cudnn_frontend::OperationGraphBuilder() + .setHandle(handle_) + .setOperationGraph(all_ops.size(), all_ops.data()) + .build(); + + cudnn_frontend::EngineConfigList filtered_configs; + auto statuses = cudnn_frontend::get_heuristics_list<1>( + {"heuristics_instant"}, opGraph, + allowAllConfig, filtered_configs, true); + + if (filtered_configs.size() == 0) { + cudnn_frontend::set_error_and_throw_exception( + nullptr, + CUDNN_STATUS_NOT_SUPPORTED, + "run_mha_bprop: No config returned by the heuristics"); + } + + auto plan = cudnn_frontend::ExecutionPlanBuilder() + .setHandle(handle_) + .setEngineConfig(filtered_configs[0], opGraph.getTag()) + .build(); + cache.insert({descriptor, plan}); + return plan; + }; + + auto plan = get_plan(fa_bprop_cache, descriptor); + size_t wkspace_size = static_cast(plan.getWorkspaceSize()); + + // Exit to request upper level API to allocate memory if needed + if (workspace_ptr == nullptr) { + *workspace_size = wkspace_size + ((b + 1) * 2 + b) * sizeof(int32_t); + return; + } + + int32_t* qkv_ragged_offset = reinterpret_cast( + reinterpret_cast(workspace_ptr) + wkspace_size); + int32_t* o_ragged_offset = reinterpret_cast( + reinterpret_cast(workspace_ptr) + + wkspace_size + (b + 1) * sizeof(int32_t)); + int32_t* actual_seqlens_q = reinterpret_cast( + reinterpret_cast(workspace_ptr) + + wkspace_size + (b + 1) * 2 * sizeof(int32_t)); + // FP8 currently only supports self-attention, so doesn't use devPtrcuSeqlensKV + dim3 blockDims(128); 
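Both entry points ask the caller for ((b + 1) * 2 + b) * sizeof(int32_t) extra bytes past the cuDNN workspace: b + 1 QKV ragged offsets, b + 1 O ragged offsets, and b actual sequence lengths, all filled by the cu_seqlens_to_offsets kernel defined in utils.cu further down in this patch. A host-side reference of that conversion is sketched below, purely for illustration (the helper name is hypothetical and not part of the patch):

    #include <cstdint>
    #include <vector>

    // CPU reference for what cu_seqlens_to_offsets computes on the GPU.
    // cu_seqlens_q holds b + 1 cumulative lengths; QKV is packed as
    // [total_seqs, 3, h, d] and O as [total_seqs, h, d], so the ragged
    // offsets are element offsets of each sequence's first token.
    void cu_seqlens_to_offsets_host(
            std::int64_t b, std::int64_t h, std::int64_t d,
            const std::vector<std::int32_t>& cu_seqlens_q,
            std::vector<std::int32_t>* actual_seqlens_q,
            std::vector<std::int32_t>* qkv_ragged_offset,
            std::vector<std::int32_t>* o_ragged_offset) {
        actual_seqlens_q->resize(b);
        qkv_ragged_offset->resize(b + 1);
        o_ragged_offset->resize(b + 1);
        for (std::int64_t i = 0; i < b; ++i) {
            (*actual_seqlens_q)[i] = cu_seqlens_q[i + 1] - cu_seqlens_q[i];
        }
        for (std::int64_t i = 0; i <= b; ++i) {
            (*qkv_ragged_offset)[i] =
                static_cast<std::int32_t>(cu_seqlens_q[i] * 3 * h * d);
            (*o_ragged_offset)[i] =
                static_cast<std::int32_t>(cu_seqlens_q[i] * h * d);
        }
    }

The launch configuration that follows, (b + blockDims.x) / blockDims.x blocks of 128 threads, rounds up so that all b + 1 offset entries are covered by the kernel.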
+ dim3 gridDims((b + blockDims.x)/blockDims.x); + cu_seqlens_to_offsets<<>>( + b, h, d, reinterpret_cast(devPtrcuSeqlensQ), + actual_seqlens_q, qkv_ragged_offset, o_ragged_offset); + void* devPtrQKVRaggedOffset = reinterpret_cast(qkv_ragged_offset); + void* devPtrORaggedOffset = reinterpret_cast(o_ragged_offset); + void* devPtrMNKOverride = reinterpret_cast(actual_seqlens_q); + + std::set> data_ptrs; + float dropoutScale = 1.0f/(1.0f - dropoutProbability); + float dropoutScale_dOVt_OdO = 1.0f - dropoutProbability; + data_ptrs.emplace(std::pair(tensor_name_to_uid["Q"], devPtrQ)); + data_ptrs.emplace(std::pair(tensor_name_to_uid["K"], devPtrK)); + data_ptrs.emplace(std::pair( + tensor_name_to_uid["K_TRANSPOSE"], devPtrK)); + data_ptrs.emplace(std::pair(tensor_name_to_uid["V"], devPtrV)); + data_ptrs.emplace(std::pair( + tensor_name_to_uid["V_TRANSPOSE"], devPtrV)); + data_ptrs.emplace(std::pair(tensor_name_to_uid["dQ"], devPtrdQ)); + data_ptrs.emplace(std::pair(tensor_name_to_uid["dK"], devPtrdK)); + data_ptrs.emplace(std::pair(tensor_name_to_uid["dV"], devPtrdV)); + data_ptrs.emplace(std::pair(tensor_name_to_uid["dO"], devPtrdO)); + data_ptrs.emplace(std::pair( + tensor_name_to_uid["AttnScale"], &attnScale)); + data_ptrs.emplace(std::pair( + tensor_name_to_uid["DROPOUT_SCALE"], &dropoutScale)); + data_ptrs.emplace(std::pair( + tensor_name_to_uid["DROPOUT_SCALE_dOVt_OdO"], + &dropoutScale_dOVt_OdO)); + data_ptrs.emplace(std::pair( + tensor_name_to_uid["DROPOUT_SEED"], devPtrDropoutSeed)); + data_ptrs.emplace(std::pair( + tensor_name_to_uid["DROPOUT_OFFSET"], devPtrDropoutOffset)); + data_ptrs.emplace(std::pair(tensor_name_to_uid["M"], devPtrM)); + data_ptrs.emplace(std::pair(tensor_name_to_uid["Z_INV"], devPtrZInv)); + data_ptrs.emplace(std::pair(tensor_name_to_uid["O"], devPtrO)); + data_ptrs.emplace(std::pair( + tensor_name_to_uid["descaleQ"], devPtrDescaleQ)); + data_ptrs.emplace(std::pair( + tensor_name_to_uid["descaleK"], devPtrDescaleK)); + data_ptrs.emplace(std::pair( + tensor_name_to_uid["descaleV"], devPtrDescaleV)); + data_ptrs.emplace(std::pair( + tensor_name_to_uid["descaleS"], devPtrDescaleS)); + data_ptrs.emplace(std::pair( + tensor_name_to_uid["descaledS"], devPtrDescaledS)); + data_ptrs.emplace(std::pair( + tensor_name_to_uid["descaleO"], devPtrDescaleO)); + data_ptrs.emplace(std::pair( + tensor_name_to_uid["descaledO"], devPtrDescaledO)); + data_ptrs.emplace(std::pair( + tensor_name_to_uid["scaleS"], devPtrScaleS)); + data_ptrs.emplace(std::pair( + tensor_name_to_uid["scaledS"], devPtrScaledS)); + data_ptrs.emplace(std::pair( + tensor_name_to_uid["scaledQ"], devPtrScaledQ)); + data_ptrs.emplace(std::pair( + tensor_name_to_uid["scaledK"], devPtrScaledK)); + data_ptrs.emplace(std::pair( + tensor_name_to_uid["scaledV"], devPtrScaledV)); + data_ptrs.emplace(std::pair( + tensor_name_to_uid["amaxdS"], devPtrAmaxdS)); + data_ptrs.emplace(std::pair( + tensor_name_to_uid["amaxdQ"], devPtrAmaxdQ)); + data_ptrs.emplace(std::pair( + tensor_name_to_uid["amaxdK"], devPtrAmaxdK)); + data_ptrs.emplace(std::pair( + tensor_name_to_uid["amaxdV"], devPtrAmaxdV)); + data_ptrs.emplace(std::pair( + tensor_name_to_uid["QKV_RAGGED"], devPtrQKVRaggedOffset)); + data_ptrs.emplace(std::pair( + tensor_name_to_uid["O_RAGGED"], devPtrORaggedOffset)); + data_ptrs.emplace(std::pair( + tensor_name_to_uid["MNK_OVERRIDE"], devPtrMNKOverride)); + + auto variantPack = cudnn_frontend::VariantPackBuilder() + .setWorkspacePointer(workspace_ptr) + .setDataPointers(data_ptrs) + .build(); + cudnnStatus_t status = 
cudnnBackendExecute( + handle_, plan.get_raw_desc(), variantPack.get_raw_desc()); + + cudnn_frontend::throw_if( + [status]() { return (status != CUDNN_STATUS_SUCCESS); }, + "Plan execute error", status); + } catch (cudnn_frontend::cudnnException& e) { + struct cudaDeviceProp prop; + NVTE_CHECK_CUDA(cudaGetDeviceProperties(&prop, 0)); + + // This example is only for GH100 cards (cudnn Version >= 8900) + if (!((prop.major == 9 && prop.minor == 0 && CUDNN_VERSION >= 8900)) + && (e.getCudnnStatus() == CUDNN_STATUS_ARCH_MISMATCH + || e.getCudnnStatus() == CUDNN_STATUS_NOT_SUPPORTED)) { + std::cout << "Example is only supported for GH100 (cuDNN >= 8900) GPUs" << std::endl; + } else { + std::cout << "[ERROR] Exception " << e.what() << std::endl; + } + } +} + +#endif + +} // namespace fused_attn + +#if (CUDNN_VERSION >= 8900) +// fused attention FWD FP8 with packed QKV +void fused_attn_fwd_fp8_qkvpacked( + size_t b, size_t max_seqlen, + size_t h, size_t d, + bool is_training, float attn_scale, + float p_dropout, NVTE_QKV_Layout qkv_layout, + const Tensor *input_QKV, + Tensor *input_output_S, + Tensor *output_O, + NVTETensorPack* Aux_Output_Tensors, + const Tensor *cu_seqlens, + const Tensor *rng_state, + Tensor *workspace, + cudaStream_t stream, + cudnnHandle_t handle) { + using namespace transformer_engine; + // QKV shape is [total_seqs, 3, h, d] + void* devPtrQKV = input_QKV->data.dptr; + void* devPtrQ = reinterpret_cast(devPtrQKV); + void* devPtrK = reinterpret_cast(reinterpret_cast(devPtrQKV) + h * d); + void* devPtrV = reinterpret_cast(reinterpret_cast(devPtrQKV) + 2 * h * d); + void* devPtrDescaleQ = input_QKV->scale_inv.dptr; + void* devPtrDescaleK = input_QKV->scale_inv.dptr; + void* devPtrDescaleV = input_QKV->scale_inv.dptr; + + void* devPtrO = output_O->data.dptr; + void* devPtrAmaxO = output_O->amax.dptr; + void* devPtrScaleO = output_O->scale.dptr; + + void* devPtrM = nullptr; + void* devPtrZInv = nullptr; + if (Aux_Output_Tensors->size == 0) { + if (is_training) { + Aux_Output_Tensors->size = 2; + Tensor *output_M = reinterpret_cast(Aux_Output_Tensors->tensors[0]); + Tensor *output_ZInv = reinterpret_cast(Aux_Output_Tensors->tensors[1]); + output_M->data.dptr = nullptr; + output_M->data.shape = {b, h, max_seqlen, 1}; + output_M->data.dtype = DType::kFloat32; + output_ZInv->data.dptr = nullptr; + output_ZInv->data.shape = {b, h, max_seqlen, 1}; + output_ZInv->data.dtype = DType::kFloat32; + } + } else if (Aux_Output_Tensors->size == 2) { + Tensor *output_M = reinterpret_cast(Aux_Output_Tensors->tensors[0]); + Tensor *output_ZInv = reinterpret_cast(Aux_Output_Tensors->tensors[1]); + devPtrM = output_M->data.dptr; + devPtrZInv = output_ZInv->data.dptr; + } + + void* devPtrAmaxS = input_output_S->amax.dptr; + void* devPtrScaleS = input_output_S->scale.dptr; + void* devPtrDescaleS = input_output_S->scale_inv.dptr; + + void* devPtrcuSeqlens = reinterpret_cast( + reinterpret_cast(cu_seqlens->data.dptr)); + void* devPtrDropoutSeed = reinterpret_cast( + reinterpret_cast(rng_state->data.dptr)); + void* devPtrDropoutOffset = reinterpret_cast( + reinterpret_cast(rng_state->data.dptr) + 1); + + const DType QKV_type = input_QKV->data.dtype; + size_t workspace_size = 0; + + fused_attn::fa_fwd_fp8( + b, max_seqlen, max_seqlen, h, d, + is_training, attn_scale, p_dropout, qkv_layout, + devPtrQ, devPtrK, devPtrV, + devPtrM, devPtrZInv, + devPtrO, + devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV, + devPtrDescaleS, devPtrScaleS, devPtrScaleO, + devPtrAmaxO, devPtrAmaxS, + devPtrcuSeqlens, 
devPtrcuSeqlens, + devPtrDropoutSeed, devPtrDropoutOffset, + get_cudnn_dtype(QKV_type), + workspace->data.dptr, &workspace_size, stream, handle); + + if (workspace_size > 0) { + if (workspace->data.dptr == nullptr) { + workspace->data.shape = { workspace_size }; + workspace->data.dtype = DType::kByte; + return; + } + } else if (workspace_size == 0) { + workspace->data.shape = { 1 }; + workspace->data.dtype = DType::kByte; + return; + } +} +// fused attention BWD FP8 with packed QKV +void fused_attn_bwd_fp8_qkvpacked( + size_t b, size_t max_seqlen, + size_t h, size_t d, + float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, + const Tensor *input_QKV, + const Tensor *input_O, + const Tensor *input_dO, + const Tensor *input_M, + const Tensor *input_ZInv, + const Tensor *input_S, + Tensor *input_output_dP, + const Tensor *output_dQKV, + const Tensor *cu_seqlens, + const Tensor *rng_state, + Tensor *workspace, + cudaStream_t stream, + cudnnHandle_t handle) { + using namespace transformer_engine; + // QKV shape is [total_seqs, 3, h, d] + void* devPtrQKV = input_QKV->data.dptr; + void* devPtrQ = reinterpret_cast(devPtrQKV); + void* devPtrK = reinterpret_cast(reinterpret_cast(devPtrQKV) + h * d); + void* devPtrV = reinterpret_cast(reinterpret_cast(devPtrQKV) + 2 * h * d); + void* devPtrDescaleQ = input_QKV->scale_inv.dptr; + void* devPtrDescaleK = input_QKV->scale_inv.dptr; + void* devPtrDescaleV = input_QKV->scale_inv.dptr; + + void* devPtrO = input_O->data.dptr; + void* devPtrDescaleO = input_O->scale_inv.dptr; + void* devPtrdO = input_dO->data.dptr; + void* devPtrDescaledO = input_dO->scale_inv.dptr; + + void* devPtrM = input_M->data.dptr; + void* devPtrZInv = input_ZInv->data.dptr; + + void* devPtrScaleS = input_S->scale.dptr; + void* devPtrDescaleS = input_S->scale_inv.dptr; + void* devPtrAmaxdS = input_output_dP->amax.dptr; + void* devPtrScaledS = input_output_dP->scale.dptr; + void* devPtrDescaledS = input_output_dP->scale_inv.dptr; + + // dQKV shape is [total_seqs, 3, h, d] + void* devPtrdQKV = output_dQKV->data.dptr; + void* devPtrdQ = reinterpret_cast(devPtrdQKV); + void* devPtrdK = reinterpret_cast(reinterpret_cast(devPtrdQKV) + h * d); + void* devPtrdV = reinterpret_cast(reinterpret_cast(devPtrdQKV) + 2 * h * d); + void* devPtrAmaxdQ = output_dQKV->amax.dptr; + void* devPtrAmaxdK = output_dQKV->amax.dptr; + void* devPtrAmaxdV = output_dQKV->amax.dptr; + void* devPtrScaledQ = output_dQKV->scale.dptr; + void* devPtrScaledK = output_dQKV->scale.dptr; + void* devPtrScaledV = output_dQKV->scale.dptr; + + void* devPtrcuSeqlens = reinterpret_cast( + reinterpret_cast(cu_seqlens->data.dptr)); + void* devPtrDropoutSeed = reinterpret_cast( + reinterpret_cast(rng_state->data.dptr)); + void* devPtrDropoutOffset = reinterpret_cast( + reinterpret_cast(rng_state->data.dptr) + 1); + + const DType QKV_type = input_QKV->data.dtype; + size_t workspace_size = 0; + + fused_attn::fa_bwd_fp8( + b, max_seqlen, max_seqlen, h, d, + attn_scale, p_dropout, qkv_layout, + devPtrQ, devPtrK, devPtrV, + devPtrM, devPtrZInv, + devPtrO, devPtrdO, + devPtrdQ, devPtrdK, devPtrdV, + devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV, + devPtrDescaleO, devPtrDescaledO, + devPtrDescaleS, devPtrDescaledS, + devPtrScaleS, devPtrScaledS, + devPtrScaledQ, devPtrScaledK, devPtrScaledV, + devPtrAmaxdS, + devPtrAmaxdQ, devPtrAmaxdK, devPtrAmaxdV, + devPtrcuSeqlens, devPtrcuSeqlens, + devPtrDropoutSeed, devPtrDropoutOffset, + get_cudnn_dtype(QKV_type), + workspace->data.dptr, &workspace_size, stream, handle); + + if 
(workspace_size > 0) { + if (workspace->data.dptr == nullptr) { + workspace->data.shape = { workspace_size }; + workspace->data.dtype = DType::kByte; + return; + } + } else if (workspace_size == 0) { + workspace->data.shape = { 1 }; + workspace->data.dtype = DType::kByte; + return; + } +} +#endif // end of CUDNN>=8900 +} // namespace transformer_engine diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.h b/transformer_engine/common/fused_attn/fused_attn_fp8.h new file mode 100644 index 0000000000..928e128737 --- /dev/null +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.h @@ -0,0 +1,46 @@ +/************************************************************************* + * Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include "transformer_engine/transformer_engine.h" + +namespace transformer_engine { +#if (CUDNN_VERSION >= 8900) +// fused attention FWD FP8 with packed QKV +void fused_attn_fwd_fp8_qkvpacked( + size_t b, size_t max_seqlen, + size_t h, size_t d, + bool is_training, float attn_scale, + float p_dropout, NVTE_QKV_Layout qkv_layout, + const Tensor *input_QKV, + Tensor *input_output_S, + Tensor *output_O, + NVTETensorPack* Aux_Output_Tensors, + const Tensor *cu_seqlens, + const Tensor *rng_state, + Tensor *workspace, + cudaStream_t stream, + cudnnHandle_t handle); + +// fused attention BWD FP8 with packed QKV +void fused_attn_bwd_fp8_qkvpacked( + size_t b, size_t max_seqlen, + size_t h, size_t d, + float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, + const Tensor *input_QKV, + const Tensor *input_O, + const Tensor *input_dO, + const Tensor *input_M, + const Tensor *input_ZInv, + const Tensor *input_S, + Tensor *input_output_dP, + const Tensor *output_dQKV, + const Tensor *cu_seqlens, + const Tensor *rng_state, + Tensor *workspace, + cudaStream_t stream, + cudnnHandle_t handle); +#endif // end of CUDNN>=8900 +} // namespace transformer_engine diff --git a/transformer_engine/common/fused_attn/utils.cu b/transformer_engine/common/fused_attn/utils.cu new file mode 100644 index 0000000000..5b0b03cb3e --- /dev/null +++ b/transformer_engine/common/fused_attn/utils.cu @@ -0,0 +1,167 @@ +/************************************************************************* + * Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. 
+ ************************************************************************/ + +#include "transformer_engine/fused_attn.h" +#include "../common.h" +#include "utils.h" + +namespace transformer_engine { +namespace fused_attn { + +using namespace transformer_engine; + +// get matrix strides based on matrix type +void generateMatrixStrides( + int64_t b, int64_t h, + int64_t s_q, int64_t s_kv, + int64_t d, int64_t* strideA, + NVTE_QKV_Layout layout, NVTE_QKV_Matrix matrix) { + constexpr int batch_dim_idx = 0; + constexpr int head_dim_idx = 1; + constexpr int seqlen_dim_idx = 2; + constexpr int hidden_dim_idx = 3; + + constexpr int seqlen_transpose_dim_idx = 3; + constexpr int hidden_transpose_dim_idx = 2; + + constexpr int seqlen_q_dim_idx = 2; + constexpr int seqlen_kv_dim_idx = 3; + + switch (matrix) { + case NVTE_QKV_Matrix::NVTE_Q_Matrix: + if (layout == NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED) { + strideA[hidden_dim_idx] = 1; + strideA[seqlen_dim_idx] = 3 * h * d; + strideA[head_dim_idx] = d; + strideA[batch_dim_idx] = s_q * 3 * h * d; + } else { + strideA[hidden_dim_idx] = 1; + strideA[seqlen_dim_idx] = h * d; + strideA[head_dim_idx] = d; + strideA[batch_dim_idx] = s_q * h * d; + } + break; + case NVTE_QKV_Matrix::NVTE_K_Matrix: + if (layout == NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED) { + strideA[seqlen_dim_idx] = 3 * h * d; + strideA[hidden_dim_idx] = 1; + strideA[head_dim_idx] = d; + strideA[batch_dim_idx] = s_kv * 3 * h * d; + } else if (layout == NVTE_QKV_Layout::NVTE_KV_INTERLEAVED) { + strideA[seqlen_transpose_dim_idx] = 2 * h * d; + strideA[hidden_transpose_dim_idx] = 1; + strideA[head_dim_idx] = d; + strideA[batch_dim_idx] = s_kv * 2 * h * d; + } else { + strideA[seqlen_transpose_dim_idx] = h * d; + strideA[hidden_transpose_dim_idx] = 1; + strideA[head_dim_idx] = d; + strideA[batch_dim_idx] = s_kv * h * d; + } + break; + case NVTE_QKV_Matrix::NVTE_K_Matrix_Transpose: + if (layout == NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED) { + strideA[seqlen_transpose_dim_idx] = 3 * h * d; + strideA[hidden_transpose_dim_idx] = 1; + strideA[head_dim_idx] = d; + strideA[batch_dim_idx] = s_kv * 3 * h * d; + } else if (layout == NVTE_QKV_Layout::NVTE_KV_INTERLEAVED) { + strideA[seqlen_transpose_dim_idx] = 2 * h * d; + strideA[hidden_transpose_dim_idx] = 1; + strideA[head_dim_idx] = d; + strideA[batch_dim_idx] = s_kv * 2 * h * d; + } else { + strideA[seqlen_transpose_dim_idx] = h * d; + strideA[hidden_transpose_dim_idx] = 1; + strideA[head_dim_idx] = d; + strideA[batch_dim_idx] = s_kv * h * d; + } + break; + case NVTE_QKV_Matrix::NVTE_V_Matrix: + if (layout == NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED) { + strideA[hidden_dim_idx] = 1; + strideA[seqlen_dim_idx] = 3 * h * d; + strideA[head_dim_idx] = d; + strideA[batch_dim_idx] = s_kv * 3 * h * d; + } else if (layout == NVTE_QKV_Layout::NVTE_KV_INTERLEAVED) { + strideA[hidden_dim_idx] = 1; + strideA[seqlen_dim_idx] = 2* h * d; + strideA[head_dim_idx] = d; + strideA[batch_dim_idx] = s_kv * 2 * h * d; + } else { + strideA[hidden_dim_idx] = 1; + strideA[seqlen_dim_idx] = h * d; + strideA[head_dim_idx] = d; + strideA[batch_dim_idx] = s_kv * h * d; + } + break; + case NVTE_QKV_Matrix::NVTE_V_Matrix_Transpose: + if (layout == NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED) { + strideA[hidden_transpose_dim_idx] = 1; + strideA[seqlen_transpose_dim_idx] = 3 * h * d; + strideA[head_dim_idx] = d; + strideA[batch_dim_idx] = s_kv * 3 * h * d; + } else if (layout == NVTE_QKV_Layout::NVTE_KV_INTERLEAVED) { + strideA[hidden_transpose_dim_idx] = 1; + strideA[seqlen_transpose_dim_idx] = 2* h 
* d; + strideA[head_dim_idx] = d; + strideA[batch_dim_idx] = s_kv * 2 * h * d; + } else { + strideA[hidden_transpose_dim_idx] = 1; + strideA[seqlen_transpose_dim_idx] = h * d; + strideA[head_dim_idx] = d; + strideA[batch_dim_idx] = s_kv * h * d; + } + break; + case NVTE_QKV_Matrix::NVTE_S_Matrix: + strideA[seqlen_kv_dim_idx] = 1; + strideA[seqlen_q_dim_idx] = s_kv; + strideA[head_dim_idx] = s_q * s_kv; + strideA[batch_dim_idx] = h * s_q * s_kv; + break; + case NVTE_QKV_Matrix::NVTE_O_Matrix: + strideA[seqlen_kv_dim_idx] = 1; + strideA[seqlen_q_dim_idx] = h * d; + strideA[head_dim_idx] = d; + strideA[batch_dim_idx] = s_q * h * d; + break; + } +} + +// convert cu_seqlens_q to qkv/o_ragged_offset and actual_seqlens_q +__global__ void cu_seqlens_to_offsets(size_t b, size_t h, size_t d, + int32_t *cu_seqlens_q, int32_t *actual_seqlens_q, + int32_t *qkv_ragged_offset, int32_t *o_ragged_offset) { + size_t tid = blockIdx.x * blockDim.x + threadIdx.x; + if (tid < b) { + actual_seqlens_q[tid] = cu_seqlens_q[tid + 1] - cu_seqlens_q[tid]; + } + if (tid < b + 1) { + qkv_ragged_offset[tid] = cu_seqlens_q[tid] * 3 * h * d; + o_ragged_offset[tid] = cu_seqlens_q[tid] * h * d; + } +} +} // namespace fused_attn + +// get cuDNN data type +cudnnDataType_t get_cudnn_dtype(const transformer_engine::DType t) { + using namespace transformer_engine; + switch (t) { + case DType::kFloat16: + return CUDNN_DATA_HALF; + case DType::kFloat32: + return CUDNN_DATA_FLOAT; + case DType::kBFloat16: + return CUDNN_DATA_BFLOAT16; + case DType::kFloat8E4M3: + return CUDNN_DATA_FP8_E4M3; + case DType::kFloat8E5M2: + return CUDNN_DATA_FP8_E5M2; + default: + NVTE_ERROR("Invalid cuDNN data type. \n"); + } +} +} // namespace transformer_engine diff --git a/transformer_engine/common/fused_attn/utils.h b/transformer_engine/common/fused_attn/utils.h new file mode 100644 index 0000000000..371a19990e --- /dev/null +++ b/transformer_engine/common/fused_attn/utils.h @@ -0,0 +1,90 @@ +/************************************************************************* + * Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. 
+ ************************************************************************/ + +#ifndef TRANSFORMER_ENGINE_FUSED_ATTN_UTILS_H_ +#define TRANSFORMER_ENGINE_FUSED_ATTN_UTILS_H_ + +#include "transformer_engine/transformer_engine.h" +#include + +namespace transformer_engine { +namespace fused_attn { + +using namespace transformer_engine; + +enum NVTE_QKV_Matrix { + NVTE_Q_Matrix = 0, // queries + NVTE_K_Matrix = 1, // keys + NVTE_K_Matrix_Transpose = 2, // keys transposed + NVTE_V_Matrix = 3, // values + NVTE_V_Matrix_Transpose = 4, // value matrix transposed + NVTE_S_Matrix = 5, // output of GEMM1 + NVTE_O_Matrix = 6, // final output +}; + +void generateMatrixStrides( + int64_t b, int64_t h, + int64_t s_q, int64_t s_kv, + int64_t d, int64_t* strideA, + NVTE_QKV_Layout layout, NVTE_QKV_Matrix matrix); + +struct FADescriptor { + std::int64_t b; + std::int64_t h; + std::int64_t s_q; + std::int64_t s_kv; + std::int64_t d; + float attnScale; + bool isTraining; + float dropoutProbability; + NVTE_QKV_Layout layout; + cudnnDataType_t tensor_type; + + bool operator<(const FADescriptor &rhs) const { + return std::tie(b, h, s_q, s_kv, d, + attnScale, isTraining, dropoutProbability, + layout, tensor_type) < std::tie( + rhs.b, rhs.h, rhs.s_q, rhs.s_kv, rhs.d, + rhs.attnScale, rhs.isTraining, + rhs.dropoutProbability, rhs.layout, rhs.tensor_type); + } +}; + +__global__ void cu_seqlens_to_offsets(size_t b, size_t h, size_t d, + int32_t *cu_seqlens_q, int32_t *actual_seqlens_q, + int32_t *qkv_ragged_offset, int32_t *o_ragged_offset); + +} // namespace fused_attn + +cudnnDataType_t get_cudnn_dtype(const transformer_engine::DType t); + +class cudnnExecutionPlanManager { + public: + static cudnnExecutionPlanManager &Instance() { + static thread_local cudnnExecutionPlanManager instance; + return instance; + } + + cudnnHandle_t GetCudnnHandle() { + static thread_local std::once_flag flag; + std::call_once(flag, [&] { cudnnCreate(&handle_); }); + return handle_; + } + + ~cudnnExecutionPlanManager() { + static thread_local std::once_flag flag; + std::call_once(flag, [&] { + if (handle_ != nullptr) { + cudnnDestroy(handle_); + }}); + } + + private: + cudnnHandle_t handle_ = nullptr; +}; +} // namespace transformer_engine + +#endif diff --git a/transformer_engine/common/include/transformer_engine/fused_attn.h b/transformer_engine/common/include/transformer_engine/fused_attn.h new file mode 100644 index 0000000000..bb9262de18 --- /dev/null +++ b/transformer_engine/common/include/transformer_engine/fused_attn.h @@ -0,0 +1,262 @@ +/************************************************************************* + * Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#ifndef TRANSFORMER_ENGINE_FUSED_ATTN_FP8_H_ +#define TRANSFORMER_ENGINE_FUSED_ATTN_FP8_H_ + +#include "transformer_engine.h" + +#ifdef __cplusplus +extern "C" { +#endif + +enum NVTE_QKV_Layout { +/*!< separate Q, K, V tensors: + Q: [total_seqs_q, num_heads, head_dim] + | Q Q Q ... Q + | \___________ _____________/ + total_seqs_q <| \/ + | num_heads * head_dim + K: [total_seqs_kv, num_heads, head_dim] + | K K K ... K + | \___________ _____________/ + total_seqs_kv <| \/ + | num_heads * head_dim + V: [total_seqs_kv, num_heads, head_dim] + | V V V ... 
V + | \___________ _____________/ + total_seqs_kv <| \/ + | num_heads * head_dim + */ + NVTE_NOT_INTERLEAVED = 0, + +/*!< packed QKV tensor: + QKV: [total_seqs, 3, num_heads, head_dim] + | Q Q Q ... Q K K K ... K V V V ... V + | \___________ _____________/ + total_seqs <| \/ + | num_heads * head_dim + */ + NVTE_QKV_INTERLEAVED = 1, + +/*!< Q and packed KV tensor: + Q: [total_seqs_q, num_heads, head_dim] + | Q Q Q ... Q + | \___________ _____________/ + total_seqs_q <| \/ + | num_heads * head_dim + KV: [total_seqs_kv, 2, num_heads, head_dim] + | K K K ... K V V V ... V + | \___________ _____________/ + total_seqs_kv <| \/ + | num_heads * head_dim + */ + NVTE_KV_INTERLEAVED = 2 +}; + +enum NVTE_Bias_Type { + NVTE_NO_BIAS = 0, /*!< no bias */ + NVTE_PRE_SCALE_BIAS = 1, /*!< bias before scale */ + NVTE_POST_SCALE_BIAS = 2 /*!< bias after scale */ +}; + +enum NVTE_Mask_Type { + NVTE_PADDING_MASK = 0, /*!< padding attention mask */ + NVTE_CAUSAL_MASK = 1, /*!< causal attention mask */ + NVTE_NO_MASK = 2 /*!< no masking */ +}; + +/*! \brief Compute dot product attention with packed QKV input. + * + * Computes: + * - P = Q * K.T + Bias + * - S = ScaleMaskSoftmax(P) + * - D = Dropout(S) + * - O = D * V.T + * + * Support Matrix: + * | precision | qkv layout | bias | mask | sequence length | head_dim | + * | FP8 | QKV_INTERLEAVED | NO_BIAS | PADDING | <= 512 | 64 | + * + * + * \param[in] QKV The QKV tensor in packed format, + * [total_seqs, 3, num_heads, head_dim]. + * \param[in] Bias The Bias tensor. + * \param[in,out] S The S tensor. + * \param[out] O The output O tensor. + * \param[out] Aux_Output_Tensors Auxiliary output tensors when training, e.g. M, ZInv. + * \param[in] cu_seqlens Accumulative sequence lengths, [batch_size + 1]. + * \param[in] rng_state Seed and offset of CUDA random number generator. + * \param[in] max_seqlen Max sequence length used for computing, + * it may be >= max(cu_seqlens). + * \param[in] is_training Whether this is in training mode or inference. + * \param[in] attn_scale Scaling factor for Q * K.T. + * \param[in] dropout Dropout probability. + * \param[in] qkv_layout QKV tensor's layout. + * \param[in] bias_type Bias type. + * \param[in] attn_mask_type Attention mask type. + * \param[in] workspace Workspace tensor. + * \param[in] stream CUDA stream used for this operation. + */ +void nvte_fused_attn_fwd_qkvpacked( + const NVTETensor QKV, + const NVTETensor Bias, + NVTETensor S, + NVTETensor O, + NVTETensorPack* Aux_Output_Tensors, + const NVTETensor cu_seqlens, + const NVTETensor rng_state, + size_t max_seqlen, + bool is_training, float attn_scale, float dropout, + NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, + NVTE_Mask_Type attn_mask_type, + NVTETensor workspace, + cudaStream_t stream); + +/*! \brief Compute the backward of the dot product attention with packed QKV input. + * + * Support Matrix: + * | precision | qkv layout | bias | mask | sequence length | head_dim | + * | FP8 | QKV_INTERLEAVED | NO_BIAS | PADDING | <= 512 | 64 | + * + * + * \param[in] QKV The QKV tensor in packed format, + * [total_seqs, 3, num_heads, head_dim]. + * \param[in] dBias The gradient of the Bias tensor. + * \param[in] O The O tensor from forward. + * \param[in] dO The gradient of the O tensor. + * \param[in] S The S tensor. + * \param[in,out] dP The gradient of the P tensor. + * \param[in] Aux_CTX_Tensors Auxiliary tensors from forward when in training mode. + * \param[out] dQKV The gradient of the QKV tensor. 
+ * \param[in] cu_seqlens Accumulative sequence lengths, [batch_size + 1]. + * \param[in] rng_state Seed and offset of CUDA random number generator. + * \param[in] max_seqlen Max sequence length used for computing, + * it may be >= max(cu_seqlens). + * \param[in] attn_scale Scaling factor for Q * K.T. + * \param[in] dropout Dropout probability. + * \param[in] qkv_layout QKV tensor's layout. + * \param[in] bias_type Bias type. + * \param[in] attn_mask_type Attention mask type. + * \param[in] workspace Workspace tensor. + * \param[in] stream CUDA stream used for this operation. + */ +void nvte_fused_attn_bwd_qkvpacked( + const NVTETensor QKV, + const NVTETensor dBias, + const NVTETensor O, + const NVTETensor dO, + const NVTETensor S, + NVTETensor dP, + const NVTETensorPack* Aux_CTX_Tensors, + NVTETensor dQKV, + const NVTETensor cu_seqlens, + size_t max_seqlen, + float attn_scale, float dropout, + NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, + NVTE_Mask_Type attn_mask_type, + NVTETensor workspace, + cudaStream_t stream); + +/*! \brief Compute dot product attention with packed KV input. + * + * Computes: + * - P = Q * K.T + Bias + * - S = ScaleMaskSoftmax(P) + * - D = Dropout(S) + * - O = D * V.T + * + * \param[in] Q The Q tensor, [total_seqs_q, num_heads, head_dim]. + * \param[in] KV The KV tensor, [total_seqs_kv, 2, num_heads, head_dim]. + * \param[in] Bias The Bias tensor. + * \param[in,out] S The S tensor. + * \param[out] O The output O tensor. + * \param[out] Aux_Output_Tensors Auxiliary output tensors when training, e.g. M, ZInv. + * \param[in] cu_seqlens_q Accumulative sequence lengths for Q, [batch_size + 1]. + * \param[in] cu_seqlens_kv Accumulative sequence lengths for KV, [batch_size + 1]. + * \param[in] rng_state Seed and offset of CUDA random number generator. + * \param[in] max_seqlen_q Max sequence length used for computing for Q. + * it may be >= max(cu_seqlens_q). + * \param[in] max_seqlen_kv Max sequence length used for computing for KV. + * it may be >= max(cu_seqlens_kv). + * \param[in] is_training Whether this is in training mode or inference. + * \param[in] attn_scale Scaling factor for Q * K.T. + * \param[in] dropout Dropout probability. + * \param[in] qkv_layout QKV tensor's layout. + * \param[in] bias_type Bias type. + * \param[in] attn_mask_type Attention mask type. + * \param[in] workspace Workspace tensor. + * \param[in] stream CUDA stream used for this operation. + */ +void nvte_fused_attn_fwd_kvpacked( + const NVTETensor Q, + const NVTETensor KV, + const NVTETensor Bias, + NVTETensor S, + NVTETensor O, + NVTETensorPack* Aux_Output_Tensors, + const NVTETensor cu_seqlens_q, + const NVTETensor cu_seqlens_kv, + const NVTETensor rng_state, + size_t max_seqlen_q, size_t max_seqlen_kv, + bool is_training, float attn_scale, float dropout, + NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, + NVTE_Mask_Type attn_mask_type, + NVTETensor workspace, + cudaStream_t stream); + +/*! \brief Compute the backward of the dot product attention with packed KV input. + * + * \param[in] Q The Q tensor, [total_seqs_q, num_heads, head_dim]. + * \param[in] KV The KV tensor, [total_seqs_kv, 2, num_heads, head_dim]. + * \param[in] dBias The gradient of the Bias tensor. + * \param[in] O The O tensor from forward. + * \param[in] dO The gradient of the O tensor. + * \param[in] S The S tensor. + * \param[in,out] dP The gradient of the P tensor. + * \param[in] Aux_CTX_Tensors Auxiliary tensors from forward when in training mode. + * \param[out] dQ The gradient of the Q tensor. 
+ * \param[out] dKV The gradient of the KV tensor. + * \param[in] cu_seqlens_q Accumulative sequence lengths for Q, [batch_size + 1]. + * \param[in] cu_seqlens_kv Accumulative sequence lengths for KV, [batch_size + 1]. + * \param[in] rng_state Seed and offset of CUDA random number generator. + * \param[in] max_seqlen_q Max sequence length used for computing for Q. + * it may be >= max(cu_seqlens_q). + * \param[in] max_seqlen_kv Max sequence length used for computing for KV. + * it may be >= max(cu_seqlens_kv). + * \param[in] attn_scale Scaling factor for Q * K.T. + * \param[in] dropout Dropout probability. + * \param[in] qkv_layout QKV tensor's layout. + * \param[in] bias_type Bias type. + * \param[in] attn_mask_type Attention mask type. + * \param[in] workspace Workspace tensor. + * \param[in] stream CUDA stream used for this operation. + */ +void nvte_fused_attn_bwd_kvpacked( + const NVTETensor Q, + const NVTETensor KV, + const NVTETensor dBias, + const NVTETensor O, + const NVTETensor dO, + const NVTETensor S, + NVTETensor dP, + const NVTETensorPack* Aux_CTX_Tensors, + NVTETensor dQ, + NVTETensor dKV, + const NVTETensor cu_seqlens_q, + const NVTETensor cu_seqlens_kv, + size_t max_seqlen_q, size_t max_seqlen_kv, + float attn_scale, float dropout, + NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, + NVTE_Mask_Type attn_mask_type, + NVTETensor workspace, + cudaStream_t stream); + +#ifdef __cplusplus +} // extern "C" +#endif + +#endif diff --git a/transformer_engine/common/include/transformer_engine/logging.h b/transformer_engine/common/include/transformer_engine/logging.h index 36fd614f59..d488274579 100644 --- a/transformer_engine/common/include/transformer_engine/logging.h +++ b/transformer_engine/common/include/transformer_engine/logging.h @@ -9,6 +9,7 @@ #include #include +#include #include #include @@ -39,10 +40,18 @@ inline void check_cublas_(cublasStatus_t status) { } } +inline void check_cudnn_(cudnnStatus_t status) { + if ( status != CUDNN_STATUS_SUCCESS ) { + NVTE_ERROR("CUDNN Error: " + std::string(cudnnGetErrorString(status))); + } +} + } // namespace #define NVTE_CHECK_CUDA(ans) { check_cuda_(ans); } #define NVTE_CHECK_CUBLAS(ans) { check_cublas_(ans); } +#define NVTE_CHECK_CUDNN(ans) { check_cudnn_(ans); } + #endif // TRANSFORMER_ENGINE_LOGGING_H_ diff --git a/transformer_engine/common/include/transformer_engine/transformer_engine.h b/transformer_engine/common/include/transformer_engine/transformer_engine.h index 0f17a4926a..72383c36bc 100644 --- a/transformer_engine/common/include/transformer_engine/transformer_engine.h +++ b/transformer_engine/common/include/transformer_engine/transformer_engine.h @@ -24,11 +24,12 @@ extern "C" { enum NVTEDType { kNVTEByte = 0, /*!< Byte */ kNVTEInt32 = 1, /*!< 32-bit integer */ - kNVTEFloat32 = 2, /*!< 32-bit float */ - kNVTEFloat16 = 3, /*!< 16-bit float (E5M10) */ - kNVTEBFloat16 = 4, /*!< 16-bit bfloat (E8M7) */ - kNVTEFloat8E4M3 = 5, /*!< 8-bit float (E4M3) */ - kNVTEFloat8E5M2 = 6, /*!< 8-bit float (E5M2) */ + kNVTEInt64 = 2, /*!< 32-bit integer */ + kNVTEFloat32 = 3, /*!< 32-bit float */ + kNVTEFloat16 = 4, /*!< 16-bit float (E5M10) */ + kNVTEBFloat16 = 5, /*!< 16-bit bfloat (E8M7) */ + kNVTEFloat8E4M3 = 6, /*!< 8-bit float (E4M3) */ + kNVTEFloat8E5M2 = 7, /*!< 8-bit float (E5M2) */ kNVTENumTypes /*!< Number of supported types */ }; @@ -129,6 +130,19 @@ float *nvte_tensor_scale(const NVTETensor tensor); */ float *nvte_tensor_scale_inv(const NVTETensor tensor); +struct NVTETensorPack { + static const int MAX_SIZE = 10; /*!< we 
expect <10 matrices in auxiliary outputs */ + NVTETensor tensors[MAX_SIZE]; /*!< wrappers to tensors, do not hold memory */ + size_t size = 0; /*!< actual size of the tensor pack, 0 <= size <= MAX_SIZE */ +}; + +/*! \brief Create NVTETensors in NVTETensorPack. + */ +void nvte_tensor_pack_create(NVTETensorPack* pack); + +/*! \brief Destroy NVTETensors in NVTETensorPack. + */ +void nvte_tensor_pack_destroy(NVTETensorPack* pack); #ifdef __cplusplus } // extern "C" @@ -146,11 +160,12 @@ namespace transformer_engine { enum class DType { kByte = 0, kInt32 = 1, - kFloat32 = 2, - kFloat16 = 3, - kBFloat16 = 4, - kFloat8E4M3 = 5, - kFloat8E5M2 = 6, + kInt64 = 2, + kFloat32 = 3, + kFloat16 = 4, + kBFloat16 = 5, + kFloat8E4M3 = 6, + kFloat8E5M2 = 7, kNumTypes }; diff --git a/transformer_engine/common/transformer_engine.cpp b/transformer_engine/common/transformer_engine.cpp index 679d1e93c4..708712ff9a 100644 --- a/transformer_engine/common/transformer_engine.cpp +++ b/transformer_engine/common/transformer_engine.cpp @@ -133,3 +133,16 @@ float *nvte_tensor_scale_inv(const NVTETensor tensor) { "Tensor's inverse of scale must have Float32 type!"); return reinterpret_cast(t.scale_inv.dptr); } + +void nvte_tensor_pack_create(NVTETensorPack* pack) { + for (int i = 0; i < pack->MAX_SIZE; i++) { + pack->tensors[i] = reinterpret_cast(new transformer_engine::Tensor); + } +} + +void nvte_tensor_pack_destroy(NVTETensorPack* pack) { + for (int i = 0; i < pack->MAX_SIZE; i++) { + auto *t = reinterpret_cast(pack->tensors[i]); + delete t; + } +} diff --git a/transformer_engine/pytorch/constants.py b/transformer_engine/pytorch/constants.py index 271c70fcab..cc8b063245 100644 --- a/transformer_engine/pytorch/constants.py +++ b/transformer_engine/pytorch/constants.py @@ -14,7 +14,7 @@ with enum in transformer_engine.h """ TE_DType = { - torch.int8: tex.DType.kByte, + torch.uint8: tex.DType.kByte, torch.int32: tex.DType.kInt32, torch.float32: tex.DType.kFloat32, torch.half: tex.DType.kFloat16, diff --git a/transformer_engine/pytorch/cpp_extensions.py b/transformer_engine/pytorch/cpp_extensions.py index fae64445f0..1353f1513e 100644 --- a/transformer_engine/pytorch/cpp_extensions.py +++ b/transformer_engine/pytorch/cpp_extensions.py @@ -3,11 +3,735 @@ # See LICENSE for license information. """TE FP8 extensions and GEMMs""" -from typing import Optional, Tuple, Union +import math +from typing import Optional, Tuple, List, Union import torch import transformer_engine_extensions as tex from .constants import TE_DType +TORCH_DType = { + tex.DType.kFloat8E4M3: torch.uint8, + tex.DType.kFloat8E5M2: torch.uint8, + tex.DType.kFloat16: torch.half, + tex.DType.kBFloat16: torch.bfloat16, + tex.DType.kFloat32: torch.float32, + tex.DType.kInt32: torch.int32, +} + +def check_tensor(x: torch.Tensor): + """Check tensor properties.""" + assert (x.is_cuda and x.is_contiguous() + ), "Tensor should be a GPU tensor and contiguous." 
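+
+# Example (illustrative only): FP8 tensors are carried in torch.uint8 storage,
+# so callers typically map the TE dtype through TORCH_DType before allocating
+# and validating inputs, e.g.
+#
+#     qkv_dtype = tex.DType.kFloat8E4M3
+#     qkv = torch.empty(total_seqs, 3, num_heads, head_dim,
+#                       dtype=TORCH_DType[qkv_dtype], device="cuda")
+#     check_tensor(qkv)   # must be a CUDA tensor and contiguous
+#
+# where total_seqs, num_heads and head_dim are placeholder sizes.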
+ +def check_qkv(qkv: torch.Tensor, dtype: torch.dtype): + """Check tensor properties.""" + check_tensor(qkv) + assert (qkv.dtype is dtype + and qkv.dim() == 4 + and qkv.shape[1] == 3 + ), """QKV should be in [total_seqs, 3, num_heads, head_dim] shape + and {dtype} dtype.""" + +def check_q(q: torch.Tensor, dtype: torch.dtype): + """Check tensor properties.""" + check_tensor(q) + assert (q.dtype is dtype + and q.dim() == 3 + ), """Q should be in [total_seqs, num_heads, head_dim] shape + and {dtype} dtype.""" + +def check_kv(kv: torch.Tensor, dtype: torch.dtype): + """Check tensor properties.""" + check_tensor(kv) + assert (kv.dtype is dtype + and kv.dim() == 4 + and kv.shape[1] == 2 + ), """KV should be in [total_seqs, 2, num_heads, head_dim] shape + and {dtype} dtype.""" + +def check_o(o: torch.Tensor, dtype: torch.dtype): + """Check tensor properties.""" + check_tensor(o) + assert (o.dtype is dtype + and o.dim() == 3 + ), """O and dO should be in [total_seqs, num_heads, head_dim] shape + and {dtype} dtype.""" + +def check_stats(stats: torch.Tensor, b: int, h: int, s: int): + """Check tensor properties.""" + check_tensor(stats) + assert (stats.dtype is torch.float32 + and stats.dim() == 4 + and stats.shape == torch.Size([b, h, s, 1]) + ), """M and ZInv should be in [batch_size, num_heads, max_seqlen_q, 1] + shape and float32 dtype.""" + +def check_cu_seqlens(cu_seqlens: torch.Tensor): + """Check tensor properties.""" + check_tensor(cu_seqlens) + assert (cu_seqlens.dtype is torch.int32 + and cu_seqlens.dim() == 1 + ), """cu_seqlens should be in [batch_size +1] shape and int32 dtype.""" + +def check_scalar(scalar: torch.Tensor): + """Check tensor properties.""" + check_tensor(scalar) + assert (scalar.dtype is torch.float32 + and scalar.dim() <= 1 + and scalar.numel() == 1 + ), "amax/scale/descale tensors should be scalars in float32 dtype." + +def check_rng_state(rng_state: torch.Tensor): + """Check tensor properties.""" + check_tensor(rng_state) + assert (rng_state.dtype is torch.int64 + and rng_state.numel() == 2 + ), "rng_state should be [seed, offset] and in int64 dtype." + +def fused_attn_fwd_qkvpacked( + is_training: bool, + max_seqlen: int, + cu_seqlens: torch.Tensor, + qkv: torch.Tensor, + qkv_dtype: tex.DType, + bias: torch.Tensor = None, + d_scale_qkv: torch.Tensor = None, + q_scale_s: torch.Tensor = None, + q_scale_o: torch.Tensor = None, + amax_s: torch.Tensor = None, + amax_o: torch.Tensor = None, + attn_scale: float = None, + dropout: float = 0.0, + set_zero: bool = True, + qkv_layout: str = "qkv_interleaved", + bias_type: str = "no_bias", + attn_mask_type: str = "padding", + rng_gen: torch.Generator = None, +) -> Tuple[Union[torch.Tensor, None], ...]: + """Fused Attention FWD for packed QKV input. 
+ + Parameters + ---------- + is_training: bool + if True, runs training and produces auxiliary tensors aux_ctx_tensors + for the backward; if False, runs inference and doesn't produce aux_ctx_tensors + max_seqlen: int + max sequence length for QKV, used for padding; may be larger than max(cu_seqlens) + cu_seqlens: torch.Tensor + accumulative sequence lengths for QKV; shape [batch_size + 1] + qkv: torch.Tensor + input tensor QKV; + shape [total_seqs, 3, num_heads, head_dim], where total_seqs = cu_seqlens[-1] + qkv_dtype: tex.DType + data type of QKV; in tex.DType, not torch.dtype + bias: torch.Tensor, default = None + input tensor Bias; + shape [total_seqs, num_heads, head_dim], where total_seqs = cu_seqlens[-1] + d_scale_qkv: torch.Tensor, default = None + input tensor for the dequantization of QKV in FP8 computations + q_scale_s: torch.Tensor, default = None + input tensor for the quantization of S in FP8 computations, S = Softmax(Q * K.T) + q_scale_o: torch.Tensor, default = None + input tensor for the quantization of O in FP8 computations + amax_s: torch.Tensor, default = None + output tensor, amax of S, used by the next iteration in FP8 computations + amax_o: torch.Tensor, default = None + output tensor, amax of O, used by the next iteration in FP8 computations + attn_scale: float, default = None + if not None, use attn_scale as the attention scale for Q*K.T BMM; + if None, use 1.0/sqrt(head_dim) as the default + dropout: float, default = 0.0 + dropout probability, 0.0 means no dropout, 1.0 means no output; + dropout must be 0.0 if is_training is False + set_zero: bool, default = True + if True, initializes the output tensor O to zero using the mha_fill method; + if False, doesn't initialize O after its allocation + qkv_layout: str, default = "qkv_interleaved" + layout of QKV; {"qkv_interleaved", "kv_interleaved", "not_interleaved"} + bias_type: str, default = "no_bias" + type of the bias; {"no_bias", "pre_scale_bias", "post_scale_bias"} + attn_mask_type: str, default = "padding" + type of the attention mask; {"padding", "causal", "no_mask"} + rng_gen: torch.Generator, default = None + random number generator; + if None, uses the default CUDA generator from PyTorch; otherwise, uses rng_gen + + Returns + ---------- + o: torch.Tensor + output tensor O, of the attention calculation; same data type as QKV; + shape [total_seqs, num_heads, head_dim], where total_seqs = cu_seqlens[-1] + aux_ctx_tensors: List[torch.Tensor] + auxiliary output tensors used for the backward; + if is_training is True, aux_ctx_tensors = [M, ZInv, rng_state] + if is_training is False, aux_ctx_tensors = [rng_state] + M: torch.Tensor + max(Q*K.T) + shape [batch_size, num_heads, max_seqlen, 1], dtype float32 + ZInv: torch.Tensor + 1/sum(e^(x - max(x))), where x=Q*K.T + shape [batch_size, num_heads, max_seqlen, 1], dtype float32 + rng_state: torch.Tensor + state of the random number generator; + [seed, offset], dtype uint64 + """ + + check_cu_seqlens(cu_seqlens) + b = cu_seqlens.numel() - 1 + qkv_type = TORCH_DType[qkv_dtype] + check_qkv(qkv, qkv_type) + + total_seqs = qkv.size(0) + h = qkv.size(2) + d = qkv.size(3) + + if attn_scale is None: + attn_scale = 1.0 / math.sqrt(d) + + # FP8 fused attention API + if (qkv_type is torch.uint8) and (max_seqlen <= 512) and (d == 64): + assert (qkv_layout == "qkv_interleaved" + and bias_type == "no_bias" + and attn_mask_type == "padding" + ), """The FP8 fused attention API currently only supports qkv_interleaved layout, + no_bias type, and padding attention mask type.""" + assert 
(d_scale_qkv is not None), "d_scale_qkv is required for the FP8 API." + assert (q_scale_s is not None), "q_scale_s is required for the FP8 API." + assert (q_scale_o is not None), "q_scale_o is required for the FP8 API." + assert (amax_s is not None), "amax_s is required for the FP8 API." + assert (amax_o is not None), "amax_o is required for the FP8 API." + check_scalar(d_scale_qkv) + check_scalar(q_scale_s) + check_scalar(q_scale_o) + check_scalar(amax_s) + check_scalar(amax_o) + + # BF16/FP16 fused attention API from fmha_v2 + elif (qkv_type is torch.bfloat16 or qkv_type is torch.float16) and (max_seqlen > 512): + # add BF/FP16 support for >512 sequence length + assert False, "The BF16/FP16 support for >512 sequence length is coming!" + + # BF16/FP16 fused attention API from fmha_v1 apex + elif (qkv_type is torch.bfloat16 or qkv_type is torch.float16) and (max_seqlen <= 512): + # add BF/FP16 support for <=512 sequence length + assert False, "The BF16/FP16 support for <=512 sequence length is coming!" + + else: + assert False, "No support for this dtype and max_seqlen combination." + + # execute kernel + output_tensors = tex.fused_attn_fwd_qkvpacked( + b, max_seqlen, total_seqs, h, d, + is_training, attn_scale, dropout, set_zero, qkv_layout, bias_type, attn_mask_type, + cu_seqlens, + qkv, + qkv_dtype, + d_scale_qkv, + q_scale_s, + q_scale_o, + amax_s, + amax_o, + bias, + rng_gen, + ) + + return output_tensors[0], output_tensors[1:] + + +def fused_attn_bwd_qkvpacked( + max_seqlen: int, + cu_seqlens: torch.Tensor, + qkv: torch.Tensor, + o: torch.Tensor, + d_o: torch.Tensor, + qkv_dtype: tex.DType, + aux_ctx_tensors: List[torch.Tensor] = None, + d_bias: torch.Tensor = None, + d_scale_qkv: torch.Tensor = None, + d_scale_s: torch.Tensor = None, + d_scale_o: torch.Tensor = None, + d_scale_do: torch.Tensor = None, + q_scale_s: torch.Tensor = None, + q_scale_dp: torch.Tensor = None, + q_scale_dqkv: torch.Tensor = None, + amax_dp: torch.Tensor = None, + amax_dqkv: torch.Tensor = None, + attn_scale: float = None, + dropout: float = 0.0, + set_zero: bool = True, + qkv_layout: str = "qkv_interleaved", + bias_type: str = "no_bias", + attn_mask_type: str = "padding", +) -> Tuple[Union[torch.Tensor, None], ...]: + """Fused Attention BWD for packed QKV input. + + Parameters + ---------- + max_seqlen: int + max sequence length for QKV, used for padding; may be larger than max(cu_seqlens_q) + cu_seqlens: torch.Tensor + accumulative sequence lengths for QKV; shape [batch_size + 1] + qkv: torch.Tensor + input tensor QKV; + shape [total_seqs, 3, num_heads, head_dim], where total_seqs = cu_seqlens[-1] + o: torch.Tensor + input tensor O (output of forward); + shape [total_seqs, num_heads, head_dim], where total_seqs = cu_seqlens[-1] + d_o: torch.Tensor + input tensor dO (gradient of O); + shape [total_seqs, num_heads, head_dim], where total_seqs = cu_seqlens[-1] + qkv_dtype: tex.DType + data type of QKV; in tex.DType, not torch.dtype + aux_ctx_tensors: List[torch.Tensor] + auxiliary output tensors of the forward pass when its is_training is True, + e.g. 
aux_ctx_tensors = [M, ZInv, rng_state] + d_bias: torch.Tensor, default = None + input tensor Bias; + shape [total_seqs, num_heads, head_dim], where total_seqs = cu_seqlens[-1] + d_scale_qkv: torch.Tensor, default = None + input tensor for the dequantization of QKV in FP8 computations + d_scale_s: torch.Tensor, default = None + input tensor for the dequantization of S in FP8 computations, S = Softmax(Q * K.T) + d_scale_o: torch.Tensor, default = None + input tensor for the dequantization of O in FP8 computations + d_scale_do: torch.Tensor, default = None + input tensor for the dequantization of dO in FP8 computations + q_scale_s: torch.Tensor, default = None + input tensor for the quantization of S in FP8 computations + q_scale_dp: torch.Tensor, default = None + input tensor for the quantization of dP in FP8 computations, P = Q * K.T + q_scale_dqkv: torch.Tensor, default = None + input tensor for the quantization of dQKV in FP8 computations + amax_dp: torch.Tensor, default = None + output tensor, amax of dP, used by the next iteration in FP8 computations + amax_dqkv: torch.Tensor, default = None + output tensor, amax of dQKV, used by the next iteration in FP8 computations + attn_scale: float, default = None + if not None, use attn_scale as the attention scale for Q*K.T BMM; + if None, use 1.0/sqrt(head_dim) as the default + dropout: float, default = 0.0 + dropout probability, 0.0 means no dropout, 1.0 means no output; + dropout must be 0.0 if is_training is False + set_zero: bool, default = True + if True, initializes the output tensor O to zero using the mha_fill method; + if False, doesn't initialize O after its allocation + qkv_layout: str, default = "qkv_interleaved" + layout of QKV; {"qkv_interleaved", "kv_interleaved", "not_interleaved"} + bias_type: str, default = "no_bias" + type of the bias; {"no_bias", "pre_scale_bias", "post_scale_bias"} + attn_mask_type: str, default = "padding" + type of the attention mask; {"padding", "causal", "no_mask"} + + Returns + ---------- + d_qkv: torch.Tensor + gradient tensor of QKV; same data type and shape as QKV + """ + + check_cu_seqlens(cu_seqlens) + b = cu_seqlens.numel() - 1 + qkv_type = TORCH_DType[qkv_dtype] + check_qkv(qkv, qkv_type) + check_o(o, qkv_type) + check_o(d_o, qkv_type) + + total_seqs = qkv.size(0) + h = qkv.size(2) + d = qkv.size(3) + + if attn_scale is None: + attn_scale = 1.0 / math.sqrt(d) + + assert (len(aux_ctx_tensors) >= 1 + ), "aux_ctx_tensors must contain rng_state as its last element." + rng_state = aux_ctx_tensors[-1] + check_rng_state(rng_state) + + # FP8 fused attention API + if (qkv_type is torch.uint8) and (max_seqlen <= 512) and d == 64: + assert (qkv_layout == "qkv_interleaved" + and bias_type == "no_bias" + and attn_mask_type == "padding" + ), """The FP8 fused attention API currently only supports qkv_interleaved layout, + no_bias type, and padding attention mask type.""" + assert (d_scale_qkv is not None), "d_scale_qkv is required for the FP8 API." + assert (d_scale_s is not None), "d_scale_s is required for the FP8 API." + assert (d_scale_o is not None), "d_scale_o is required for the FP8 API." + assert (d_scale_do is not None), "d_scale_do is required for the FP8 API." + assert (q_scale_s is not None), "q_scale_s is required for the FP8 API." + assert (q_scale_dp is not None), "q_scale_dp is required for the FP8 API." + assert (q_scale_dqkv is not None), "q_scale_dqkv is required for the FP8 API." + assert (amax_dp is not None), "amax_dp is required for the FP8 API." 
+ assert (amax_dqkv is not None), "amax_dqkv is required for the FP8 API." + assert (len(aux_ctx_tensors) == 3 + ), "aux_ctx_tensors is required to be [M, ZInv, rng_state] for the FP8 API." + check_scalar(d_scale_qkv) + check_scalar(d_scale_s) + check_scalar(d_scale_o) + check_scalar(d_scale_do) + check_scalar(q_scale_s) + check_scalar(q_scale_dp) + check_scalar(q_scale_dqkv) + check_scalar(amax_dp) + check_scalar(amax_dqkv) + m, z_inv = aux_ctx_tensors[:2] + check_stats(m, b, h, max_seqlen) + check_stats(z_inv, b, h, max_seqlen) + + # BF16/FP16 fused attention API from fmha_v2 + elif (qkv_type is torch.bfloat16 or qkv_type is torch.float16) and (max_seqlen > 512): + # add BF/FP16 support for >512 sequence length + assert False, "The BF16/FP16 support for >512 sequence length is coming!" + + # BF16/FP16 fused attention API from fmha_v1 apex + elif (qkv_type is torch.bfloat16 or qkv_type is torch.float16) and (max_seqlen <= 512): + # add BF/FP16 support for <=512 sequence length + assert False, "The BF16/FP16 support for <=512 sequence length is coming!" + + else: + assert False, "No support for this dtype and max_seqlen combination." + + # execute kernel + output_tensors = tex.fused_attn_bwd_qkvpacked( + b, max_seqlen, total_seqs, h, d, + attn_scale, dropout, set_zero, qkv_layout, bias_type, attn_mask_type, + cu_seqlens, + qkv, o, d_o, + qkv_dtype, + aux_ctx_tensors, + d_scale_qkv, d_scale_s, d_scale_o, d_scale_do, + q_scale_s, q_scale_dp, q_scale_dqkv, + amax_dp, amax_dqkv, + d_bias, + ) + + return output_tensors[0] + + +def fused_attn_fwd_kvpacked( + is_training: bool, + max_seqlen_q: int, + max_seqlen_kv: int, + cu_seqlens_q: torch.Tensor, + cu_seqlens_kv: torch.Tensor, + q: torch.Tensor, + kv: torch.Tensor, + qkv_dtype: tex.DType, + bias: torch.Tensor = None, + d_scale_qkv: torch.Tensor = None, + q_scale_s: torch.Tensor = None, + q_scale_o: torch.Tensor = None, + amax_s: torch.Tensor = None, + amax_o: torch.Tensor = None, + attn_scale: float = None, + dropout: float = 0.0, + set_zero: bool = True, + qkv_layout: str = "qkv_interleaved", + bias_type: str = "no_bias", + attn_mask_type: str = "padding", + rng_gen: torch.Generator = None, +) -> Tuple[Union[torch.Tensor, None], ...]: + """Fused Attention FWD for packed KV input. 
+ + Parameters + ---------- + is_training: bool + if True, runs training and produces auxiliary tensors aux_ctx_tensors + for the backward; if False, runs inference and doesn't produce aux_ctx_tensors + max_seqlen_q: int + max sequence length for Q, used for padding; may be larger than max(cu_seqlens_q) + max_seqlen_kv: int + max sequence length for KV, used for padding; may be larger than max(cu_seqlens_kv) + cu_seqlens_q: torch.Tensor + accumulative sequence lengths for Q; shape [batch_size + 1] + cu_seqlens_kv: torch.Tensor + accumulative sequence lengths for KV; shape [batch_size + 1] + q: torch.Tensor + input tensor Q; + shape [total_seqs_q, num_heads, head_dim], where total_seqs_q = cu_seqlens_q[-1] + kv: torch.Tensor + packed input tensor KV; + shape [total_seqs_kv, 2, num_heads, head_dim], + where total_seqs_kv = cu_seqlens_kv[-1] + qkv_dtype: tex.DType + data type of QKV; in tex.DType, not torch.dtype + bias: torch.Tensor, default = None + input tensor Bias; + shape [total_seqs_q, num_heads, head_dim], where total_seqs_q = cu_seqlens_q[-1] + d_scale_qkv: torch.Tensor, default = None + input tensor for the dequantization of QKV in FP8 computations + q_scale_s: torch.Tensor, default = None + input tensor for the quantization of S in FP8 computations, S = Softmax(Q * K.T) + q_scale_o: torch.Tensor, default = None + input tensor for the quantization of O in FP8 computations + amax_s: torch.Tensor, default = None + output tensor, amax of S, used by the next iteration in FP8 computations + amax_o: torch.Tensor, default = None + output tensor, amax of O, used by the next iteration in FP8 computations + attn_scale: float, default = None + if not None, use attn_scale as the attention scale for Q*K.T BMM; + if None, use 1.0/sqrt(head_dim) as the default + dropout: float, default = 0.0 + dropout probability, 0.0 means no dropout, 1.0 means no output; + dropout must be 0.0 if is_training is False + set_zero: bool, default = True + if True, initializes the output tensor O to zero using the mha_fill method; + if False, doesn't initialize O after its allocation + qkv_layout: str, default = "qkv_interleaved" + layout of QKV; {"qkv_interleaved", "kv_interleaved", "not_interleaved"} + bias_type: str, default = "no_bias" + type of the bias; {"no_bias", "pre_scale_bias", "post_scale_bias"} + attn_mask_type: str, default = "padding" + type of the attention mask; {"padding", "causal", "no_mask"} + rng_gen: torch.Generator, default = None + random number generator; + if None, uses the default CUDA generator from PyTorch; otherwise, uses rng_gen + + Returns + ---------- + o: torch.Tensor + output tensor O, of the attention calculation; same data type as QKV; + shape [total_seqs, num_heads, head_dim], where total_seqs = cu_seqlens[-1] + aux_ctx_tensors: List[torch.Tensor] + auxiliary output tensors used for the backward; + if is_training is True, aux_ctx_tensors = [M, ZInv, rng_state] + if is_training is False, aux_ctx_tensors = [rng_state] + M: torch.Tensor + max(Q*K.T) + shape [batch_size, num_heads, max_seqlen, 1], dtype float32 + ZInv: torch.Tensor + 1/sum(e^(x - max(x))), where x=Q*K.T + shape [batch_size, num_heads, max_seqlen, 1], dtype float32 + rng_state: torch.Tensor + state of the random number generator; + [seed, offset], dtype uint64 + """ + + check_cu_seqlens(cu_seqlens_q) + check_cu_seqlens(cu_seqlens_kv) + assert (cu_seqlens_q.numel() == cu_seqlens_kv.numel() + ), "cu_seqlens_q and cu_seqlens_kv must have the same length." 
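+    # cu_seqlens_* hold cumulative sequence lengths, e.g. per-sequence lengths
+    # [3, 5, 4] give cu_seqlens = [0, 3, 8, 12]; hence batch_size is
+    # cu_seqlens_q.numel() - 1 and cu_seqlens_q[-1] equals total_seqs_q.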
+ b = cu_seqlens_q.numel() - 1 + qkv_type = TORCH_DType[qkv_dtype] + check_q(q, qkv_type) + check_kv(kv, qkv_type) + + assert (q.size(1) == kv.size(2) + and q.size(2) == kv.size(3) + ), "Q and KV must have the same num_heads and head_dim." + total_seqs_q = q.size(0) + total_seqs_kv = kv.size(0) + h = q.size(1) + d = q.size(2) + + if attn_scale is None: + attn_scale = 1.0 / math.sqrt(d) + + # FP8 fused attention API + if (qkv_type is torch.uint8) and (max_seqlen_q <= 512) and (max_seqlen_kv <= 512) \ + and (d == 64): + assert False, "The FP8 fused attention API currently only supports packed QKV input." + + # BF16/FP16 fused attention API from fmha_v2 + elif (qkv_type is torch.bfloat16 or qkv_type is torch.float16) \ + and (max_seqlen_q > 512) and (max_seqlen_kv > 512): + # add BF/FP16 support for >512 sequence length + assert False, "The BF16/FP16 support for >512 sequence length is coming!" + + # BF16/FP16 fused attention API from fmha_v1 apex + elif (qkv_type is torch.bfloat16 or qkv_type is torch.float16) \ + and (max_seqlen_q <= 512) and (max_seqlen_kv <= 512): + # add BF/FP16 support for <=512 sequence length + assert False, "The BF16/FP16 support for <=512 sequence length is coming!" + + else: + assert False, "No support for this dtype and max_seqlen combination." + + # execute kernel + output_tensors = tex.fused_attn_fwd_kvpacked( + b, max_seqlen_q, max_seqlen_kv, total_seqs_q, total_seqs_kv, h, d, + is_training, attn_scale, dropout, set_zero, qkv_layout, bias_type, attn_mask_type, + cu_seqlens_q, cu_seqlens_kv, + q, kv, + qkv_dtype, + d_scale_qkv, + q_scale_s, + q_scale_o, + amax_s, + amax_o, + bias, + rng_gen, + ) + + return output_tensors[0], output_tensors[1:] + + +def fused_attn_bwd_kvpacked( + max_seqlen_q: int, + max_seqlen_kv: int, + cu_seqlens_q: torch.Tensor, + cu_seqlens_kv: torch.Tensor, + q: torch.Tensor, + kv: torch.Tensor, + o: torch.Tensor, + d_o: torch.Tensor, + qkv_dtype: tex.DType, + aux_ctx_tensors: List[torch.Tensor] = None, + d_bias: torch.Tensor = None, + d_scale_qkv: torch.Tensor = None, + d_scale_s: torch.Tensor = None, + d_scale_o: torch.Tensor = None, + d_scale_do: torch.Tensor = None, + q_scale_s: torch.Tensor = None, + q_scale_dp: torch.Tensor = None, + q_scale_dqkv: torch.Tensor = None, + amax_dp: torch.Tensor = None, + amax_dqkv: torch.Tensor = None, + attn_scale: float = None, + dropout: float = 0.0, + set_zero: bool = True, + qkv_layout: str = "qkv_interleaved", + bias_type: str = "no_bias", + attn_mask_type: str = "padding", +) -> Tuple[Union[torch.Tensor, None], ...]: + """Fused Attention BWD for packed KV input. 
+ + Parameters + ---------- + max_seqlen_q: int + max sequence length for Q, used for padding; may be larger than max(cu_seqlens_q) + max_seqlen_kv: int + max sequence length for KV, used for padding; may be larger than max(cu_seqlens_kv) + cu_seqlens_q: torch.Tensor + accumulative sequence lengths for Q; shape [batch_size + 1] + cu_seqlens_kv: torch.Tensor + accumulative sequence lengths for KV; shape [batch_size + 1] + q: torch.Tensor + input tensor Q; + shape [total_seqs_q, num_heads, head_dim], where total_seqs_q = cu_seqlens_q[-1] + kv: torch.Tensor + packed input tensor KV; + shape [total_seqs_kv, 2, num_heads, head_dim], + where total_seqs_kv = cu_seqlens_kv[-1] + o: torch.Tensor + input tensor O (output of forward); + shape [total_seqs_q, num_heads, head_dim], where total_seqs_q = cu_seqlens_q[-1] + d_o: torch.Tensor + input tensor dO (gradient of O); + shape [total_seqs_q, num_heads, head_dim], where total_seqs_q = cu_seqlens_q[-1] + qkv_dtype: tex.DType + data type of QKV; in tex.DType, not torch.dtype + aux_ctx_tensors: List[torch.Tensor] + auxiliary output tensors of the forward pass when its is_training is True, + e.g. aux_ctx_tensors = [M, ZInv, rng_state] + bias: torch.Tensor, default = None + input tensor Bias; + shape [total_seqs_q, num_heads, head_dim], where total_seqs_q = cu_seqlens_q[-1] + d_scale_qkv: torch.Tensor, default = None + input tensor for the dequantization of QKV in FP8 computations + d_scale_s: torch.Tensor, default = None + input tensor for the dequantization of S in FP8 computations, S = Softmax(Q * K.T) + d_scale_o: torch.Tensor, default = None + input tensor for the dequantization of O in FP8 computations + d_scale_do: torch.Tensor, default = None + input tensor for the dequantization of dO in FP8 computations + q_scale_s: torch.Tensor, default = None + input tensor for the quantization of S in FP8 computations + q_scale_dp: torch.Tensor, default = None + input tensor for the quantization of dP in FP8 computations, P = Q * K.T + q_scale_dqkv: torch.Tensor, default = None + input tensor for the quantization of dQKV in FP8 computations + amax_dp: torch.Tensor, default = None + output tensor, amax of dP, used by the next iteration in FP8 computations, + P = Q * K.T + amax_dqkv: torch.Tensor, default = None + output tensor, amax of dQKV, used by the next iteration in FP8 computations + attn_scale: float, default = None + if not None, use attn_scale as the attention scale for Q*K.T BMM; + if None, use 1.0/sqrt(head_dim) as the default + dropout: float, default = 0.0 + dropout probability, 0.0 means no dropout, 1.0 means no output; + dropout must be 0.0 if is_training is False + set_zero: bool, default = True + if True, initializes the output tensor O to zero using the mha_fill method; + if False, doesn't initialize O after its allocation + qkv_layout: str, default = "qkv_interleaved" + layout of QKV; {"qkv_interleaved", "kv_interleaved", "not_interleaved"} + bias_type: str, default = "no_bias" + type of the bias; {"no_bias", "pre_scale_bias", "post_scale_bias"} + attn_mask_type: str, default = "padding" + type of the attention mask; {"padding", "causal", "no_mask"} + + Returns + ---------- + d_q: torch.Tensor + gradient tensor of Q; same data type and shape as Q + d_kv: torch.Tensor + gradient tensor of KV; same data type and shape as KV + """ + + check_cu_seqlens(cu_seqlens_q) + check_cu_seqlens(cu_seqlens_kv) + assert (cu_seqlens_q.numel() == cu_seqlens_kv.numel() + ), "cu_seqlens_q and cu_seqlens_kv must have the same length." 
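+    # As in the forward pass, sizes are derived from Q/KV, and O/dO are checked
+    # against the same [total_seqs_q, num_heads, head_dim] layout and QKV dtype.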
+ b = cu_seqlens_q.numel() - 1 + qkv_type = TORCH_DType[qkv_dtype] + check_q(q, qkv_type) + check_kv(kv, qkv_type) + check_o(o, qkv_type) + check_o(d_o, qkv_type) + + assert (q.size(1) == kv.size(2) + and q.size(2) == kv.size(3) + ), "Q and KV must have the same num_heads and head_dim." + total_seqs_q = q.size(0) + total_seqs_kv = q.size(0) + h = q.size(1) + d = q.size(2) + + if attn_scale is None: + attn_scale = 1.0 / math.sqrt(d) + + assert (len(aux_ctx_tensors) >= 1 + ), "aux_ctx_tensors must contain rng_state as its last element." + rng_state = aux_ctx_tensors[-1] + check_rng_state(rng_state) + + # FP8 fused attention API + if (qkv_type is torch.uint8) and (max_seqlen_q <= 512) and (max_seqlen_kv <= 512) \ + and d == 64: + assert False, "The FP8 fused attention API currently only supports packed QKV input." + + ############### BF16/FP16 fused attention API from fmha_v2 ################ + elif (qkv_type is torch.bfloat16 or qkv_type is torch.float16) \ + and (max_seqlen_q > 512) and (max_seqlen_kv > 512): + # add BF/FP16 support for >512 sequence length + assert False, "The BF16/FP16 support for >512 sequence length is coming!" + + ############### BF16/FP16 fused attention API from fmha_v1 apex ################ + elif (qkv_type is torch.bfloat16 or qkv_type is torch.float16) \ + and (max_seqlen_q <= 512) and (max_seqlen_kv <= 512): + # add BF/FP16 support for <=512 sequence length + assert False, "The BF16/FP16 support for <=512 sequence length is coming!" + + else: + assert False, "No support for this dtype and max_seqlen combination." + + # execute kernel + output_tensors = tex.fused_attn_bwd_kvpacked( + b, max_seqlen_q, max_seqlen_kv, total_seqs_q, total_seqs_kv, h, d, + attn_scale, dropout, set_zero, qkv_layout, bias_type, attn_mask_type, + cu_seqlens_q, cu_seqlens_kv, + q, kv, o, d_o, + qkv_dtype, + aux_ctx_tensors, + d_scale_qkv, d_scale_s, d_scale_o, d_scale_do, + q_scale_s, q_scale_dp, q_scale_dqkv, + amax_dp, amax_dqkv, + d_bias, + ) + + return output_tensors def fp8_gemm( A: torch.Tensor, @@ -233,9 +957,9 @@ def fp8_cast_transpose_fused( return_outputs = False if cast_out is None or transpose_out is None: - cast_out = torch.empty_like(inp, dtype=torch.int8) + cast_out = torch.empty_like(inp, dtype=torch.uint8) transpose_out = torch.empty( - inp.shape[1], inp.shape[0], device="cuda", dtype=torch.int8 + inp.shape[1], inp.shape[0], device="cuda", dtype=torch.uint8 ) return_outputs = True diff --git a/transformer_engine/pytorch/csrc/common.cu b/transformer_engine/pytorch/csrc/common.cu index 2146118382..1d20607940 100644 --- a/transformer_engine/pytorch/csrc/common.cu +++ b/transformer_engine/pytorch/csrc/common.cu @@ -88,6 +88,19 @@ size_t product(const std::vector &shape) { } +at::Tensor allocateSpace(const std::vector& shape, + const transformer_engine::DType type, + bool init_to_zeros) { + std::vector shape_int64(shape.begin(), shape.end()); + c10::IntArrayRef ar_shape(shape_int64); + if (init_to_zeros) { + return at::zeros(ar_shape, at::CUDA(GetATenDType(type))); + } else { + return at::empty(ar_shape, at::CUDA(GetATenDType(type))); + } +} + + at::Tensor allocateSpace(const NVTEShape &shape, const transformer_engine::DType type, bool init_to_zeros) { diff --git a/transformer_engine/pytorch/csrc/common.h b/transformer_engine/pytorch/csrc/common.h index f6c9898601..1d59fc7c43 100644 --- a/transformer_engine/pytorch/csrc/common.h +++ b/transformer_engine/pytorch/csrc/common.h @@ -15,9 +15,15 @@ #include #include #include +#include #include #include #include +#include +#include 
+#include +#include +#include #include #include #include @@ -101,6 +107,12 @@ inline transformer_engine::DType GetTransformerEngineDType(at::ScalarType t) { return transformer_engine::DType::kBFloat16; case at::kBool: return transformer_engine::DType::kByte; + case torch::kByte: + return transformer_engine::DType::kByte; + case torch::kInt32: + return transformer_engine::DType::kInt32; + case torch::kInt64: + return transformer_engine::DType::kInt64; default: NVTE_ERROR("Invalid type"); } @@ -141,6 +153,9 @@ transformer_engine::TensorWrapper makeTransformerEngineTensor(at::Tensor tensor, size_t product(const std::vector &shape); +at::Tensor allocateSpace(const std::vector& shape, + const transformer_engine::DType type, + bool init_to_zeros); at::Tensor allocateSpace(const NVTEShape &shape, const transformer_engine::DType type, diff --git a/transformer_engine/pytorch/csrc/extensions.cu b/transformer_engine/pytorch/csrc/extensions.cu index 23330efbf0..75d4abd031 100644 --- a/transformer_engine/pytorch/csrc/extensions.cu +++ b/transformer_engine/pytorch/csrc/extensions.cu @@ -9,6 +9,742 @@ #include "comm_gemm_overlap.h" #endif // NVTE_WITH_USERBUFFERS +constexpr int block_size = 512; +constexpr int ctas_per_sm = 4; + +// convert QKV layout to enum +NVTE_QKV_Layout get_nvte_qkv_layout(const std::string qkv_layout) { + if (qkv_layout == "not_interleaved") { + return NVTE_QKV_Layout::NVTE_NOT_INTERLEAVED; + } else if (qkv_layout == "qkv_interleaved") { + return NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED; + } else if (qkv_layout == "kv_interleaved") { + return NVTE_QKV_Layout::NVTE_KV_INTERLEAVED; + } else { + NVTE_ERROR("Invalid QKV layout. \n"); + } +} + +// convert bias type to enum +NVTE_Bias_Type get_nvte_bias_type(const std::string bias_type) { + if (bias_type == "no_bias") { + return NVTE_Bias_Type::NVTE_NO_BIAS; + } else if (bias_type == "pre_scale_bias") { + return NVTE_Bias_Type::NVTE_PRE_SCALE_BIAS; + } else if (bias_type == "post_scale_bias") { + return NVTE_Bias_Type::NVTE_POST_SCALE_BIAS; + } else { + NVTE_ERROR("Invalid bias type. \n"); + } +} + +// convert attn mask type to enum +NVTE_Mask_Type get_nvte_mask_type(const std::string mask_type) { + if (mask_type == "padding") { + return NVTE_Mask_Type::NVTE_PADDING_MASK; + } else if (mask_type == "causal") { + return NVTE_Mask_Type::NVTE_CAUSAL_MASK; + } else if (mask_type == "no_mask") { + return NVTE_Mask_Type::NVTE_NO_MASK; + } else { + NVTE_ERROR("Invalid attention mask type. 
\n"); + } +} + +// fast zero-fills of tensors +template +__global__ void __launch_bounds__(block_size) mha_fill_kernel(scalar_t* out_tensor, + const int32_t* const start_row, + const size_t num_rows) { + size_t row_stride = gridDim.y * blockDim.x; + size_t row_index = blockIdx.x + static_cast(start_row[0]); + size_t col_index = blockIdx.y * blockDim.x + threadIdx.x; + while (row_index < num_rows) { + out_tensor[row_index*row_stride + col_index] = 0; + row_index += gridDim.x; + } +} + +// fast zero-fills of tensors +void mha_fill(const at::Tensor &self, const at::Tensor &start_index) { + auto max_tokens = self.size(0); + auto self_2d = self.view({max_tokens, -1}); + auto fcd_size = self_2d.size(1); + TORCH_CHECK(self.is_contiguous(), "input not contiguous"); + TORCH_CHECK(fcd_size % block_size == 0, "input size not aligned to block size"); + const int num_mp = at::cuda::getCurrentDeviceProperties()->multiProcessorCount; + uint64_t num_blk_y = (uint64_t)(fcd_size / block_size); + uint64_t num_blk_x = (uint64_t)((num_mp * ctas_per_sm + num_blk_y - 1) / num_blk_y); + dim3 dim_grid(num_blk_x, num_blk_y); + dim3 dim_block(block_size); + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2( + at::ScalarType::Half, at::ScalarType::BFloat16, + self_2d.scalar_type(), "mha_fill", [&]() { + mha_fill_kernel<<>>( + self_2d.data_ptr(), + static_cast(start_index.data_ptr()), + max_tokens); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + }); +} + +// extract seed and offset from PhiloxCudaState +__global__ void unpack(at::PhiloxCudaState arg, int64_t* rng_state_ptr) { + if (arg.captured_) { + rng_state_ptr[0] = static_cast(*arg.seed_.ptr); + rng_state_ptr[1] = static_cast( + *(arg.offset_.ptr) + static_cast(arg.offset_intragraph_)); + } else { + rng_state_ptr[0] = static_cast(arg.seed_.val); + rng_state_ptr[1] = static_cast(arg.offset_.val); + } +} + +// extract PhiloxCudaState from CUDA random number generator +at::PhiloxCudaState init_philox_state( + at::CUDAGeneratorImpl* gen, + size_t max_seq_len, + size_t threads_per_cta) { + at::PhiloxCudaState philox_args; + size_t elts_per_thread = (max_seq_len * max_seq_len + threads_per_cta - 1)/threads_per_cta; + std::lock_guard lock(gen->mutex_); + philox_args = gen->philox_cuda_state(elts_per_thread); + return philox_args; +} + +// fused attention FWD with packed QKV +std::vector fused_attn_fwd_qkvpacked( + size_t b, size_t max_seqlen, size_t total_seqs, + size_t h, size_t d, + bool is_training, float attn_scale, float p_dropout, bool set_zero, + std::string qkv_layout, std::string bias_type, std::string attn_mask_type, + const at::Tensor cu_seqlens, + const at::Tensor QKV, + const transformer_engine::DType qkv_type, + const c10::optional descale_QKV, + const c10::optional scale_S, + const c10::optional scale_O, + c10::optional amax_S, + c10::optional amax_O, + const c10::optional Bias, + const c10::optional rng_gen) { + using namespace transformer_engine; + + // create output tensor O + auto options = torch::TensorOptions().dtype(GetATenDType(qkv_type)).device(torch::kCUDA); + auto O = torch::empty({static_cast(total_seqs), + static_cast(h), static_cast(d)}, options); + if (set_zero) { + mha_fill(O, cu_seqlens.index({torch::indexing::Slice(-1, torch::indexing::None)})); + } + + // construct NVTE tensors + TensorWrapper te_QKV, te_S, te_O, te_Bias, te_cu_seqlens; + if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) { + // FP8 + if ((!descale_QKV.has_value()) || (!scale_S.has_value()) || (!scale_O.has_value()) + || (!amax_S.has_value()) || (!amax_O.has_value())) { + 
std::string err_tensors = "descale_QKV, scale_S, scale_O, amax_S and amax_O"; + NVTE_ERROR(err_tensors + std::string("are required for FP8 operation. \n")); + } + te_QKV = makeTransformerEngineTensor(QKV.data_ptr(), {total_seqs, 3, h, d}, + qkv_type, nullptr, nullptr, descale_QKV.value().data_ptr()); + at::Tensor descale_S = torch::empty_like(scale_S.value()); + te_S = makeTransformerEngineTensor(nullptr, {0}, + DType::kFloat32, amax_S.value().data_ptr(), + scale_S.value().data_ptr(), descale_S.data_ptr()); + te_O = makeTransformerEngineTensor(O.data_ptr(), {total_seqs, h, d}, + qkv_type, amax_O.value().data_ptr(), scale_O.value().data_ptr(), nullptr); + } else if (qkv_type == DType::kBFloat16 || qkv_type == DType::kFloat16) { + // BF16 or FP16 + te_QKV = makeTransformerEngineTensor(QKV.data_ptr(), {total_seqs, 3, h, d}, + qkv_type, nullptr, nullptr, nullptr); + te_S = makeTransformerEngineTensor(nullptr, {0}, + DType::kFloat32, nullptr, nullptr, nullptr); + te_O = makeTransformerEngineTensor(O.data_ptr(), {total_seqs, h, d}, + qkv_type, nullptr, nullptr, nullptr); + } else { + NVTE_ERROR("Fused attention only supports FP8 and BF16/FP16 data types. \n"); + } + if (Bias.has_value()) { + auto bias_shape = Bias.value().sizes().vec(); + std::vector shape{bias_shape.begin(), bias_shape.end()}; + te_Bias = makeTransformerEngineTensor(Bias.value().data_ptr(), shape, + DType::kFloat32, nullptr, nullptr, nullptr); + } + te_cu_seqlens = makeTransformerEngineTensor(cu_seqlens.data_ptr(), {b+1}, + DType::kInt32, nullptr, nullptr, nullptr); + + // convert strings to enums + NVTE_QKV_Layout qkv_layout_enum = get_nvte_qkv_layout(qkv_layout); + NVTE_Bias_Type bias_type_enum = get_nvte_bias_type(bias_type); + NVTE_Mask_Type attn_mask_type_enum = get_nvte_mask_type(attn_mask_type); + + // extract random number generator seed and offset + auto gen = at::get_generator_or_default( + rng_gen, at::cuda::detail::getDefaultCUDAGenerator()); + size_t threads_per_cta = 128; + at::PhiloxCudaState philox_args = init_philox_state(gen, max_seqlen, threads_per_cta); + auto rng_state = torch::empty({2}, options.dtype(torch::kInt64)); + unpack<<<1, 1, 0, at::cuda::getCurrentCUDAStream()>>>( + philox_args, static_cast(rng_state.data_ptr())); + auto te_rng_state = makeTransformerEngineTensor(rng_state); + + // create auxiliary output tensors + // if training, tensors are [M, ZInv] + NVTETensorPack nvte_aux_tensor_pack; + nvte_tensor_pack_create(&nvte_aux_tensor_pack); + + // create workspace + TensorWrapper workspace; + + // populate tensors with appropriate shapes and dtypes + nvte_fused_attn_fwd_qkvpacked( + te_QKV.data(), + te_Bias.data(), + te_S.data(), + te_O.data(), + &nvte_aux_tensor_pack, + te_cu_seqlens.data(), + te_rng_state.data(), + max_seqlen, + is_training, attn_scale, p_dropout, + qkv_layout_enum, bias_type_enum, attn_mask_type_enum, + workspace.data(), + at::cuda::getCurrentCUDAStream()); + + // allocate memory for workspace and auxiliary output tensors + auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype()); + workspace = makeTransformerEngineTensor( + workspace_data.data_ptr(), + workspace.shape(), workspace.dtype()); + + // output_tensors = [O, nvte_aux_tensor_pack.tensors, rng_state] + std::vector output_tensors; + output_tensors.push_back(O); + // nvte_aux_tensor_pack.size is 0 if inference + for (size_t i = 0; i < nvte_aux_tensor_pack.size; ++i) { + auto tensor = reinterpret_cast(nvte_aux_tensor_pack.tensors[i]); + // allocate memory for nvte_aux_tensor_pack.tensors + auto 
output_tensor = allocateSpace(tensor->data.shape, tensor->data.dtype, false); + output_tensors.push_back(output_tensor); + tensor->data.dptr = output_tensor.data_ptr(); + } + if (is_training) { + output_tensors.push_back(rng_state); + } + + // execute the kernel + nvte_fused_attn_fwd_qkvpacked( + te_QKV.data(), + te_Bias.data(), + te_S.data(), + te_O.data(), + &nvte_aux_tensor_pack, + te_cu_seqlens.data(), + te_rng_state.data(), + max_seqlen, + is_training, attn_scale, p_dropout, + qkv_layout_enum, bias_type_enum, attn_mask_type_enum, + workspace.data(), + at::cuda::getCurrentCUDAStream()); + + // destroy tensor wrappers, but not allocated memory + nvte_tensor_pack_destroy(&nvte_aux_tensor_pack); + + // if training, [O, M, ZInv, rng_state]; if inference, [O] + return output_tensors; +} + +// fused attention BWD with packed QKV +std::vector fused_attn_bwd_qkvpacked( + size_t b, size_t max_seqlen, size_t total_seqs, + size_t h, size_t d, + float attn_scale, float p_dropout, bool set_zero, + std::string qkv_layout, std::string bias_type, std::string attn_mask_type, + const at::Tensor cu_seqlens, + const at::Tensor QKV, + const at::Tensor O, + const at::Tensor dO, + const transformer_engine::DType qkv_type, + const std::vector Aux_CTX_Tensors, + const c10::optional descale_QKV, + const c10::optional descale_S, + const c10::optional descale_O, + const c10::optional descale_dO, + const c10::optional scale_S, + const c10::optional scale_dP, + const c10::optional scale_dQKV, + c10::optional amax_dP, + c10::optional amax_dQKV, + const c10::optional dBias) { + using namespace transformer_engine; + + // create output tensor dQKV + at::Tensor dQKV = torch::empty_like(QKV); + if (set_zero) { + mha_fill(dQKV, cu_seqlens.index({torch::indexing::Slice(-1, torch::indexing::None)})); + } + + // construct NVTE tensors + TensorWrapper te_QKV, te_O, te_dO, te_S, te_dP, te_dQKV, te_dBias; + if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) { + // FP8 + if ((!descale_QKV.has_value()) || (!descale_S.has_value()) + || (!descale_O.has_value()) || (!descale_dO.has_value()) + || (!scale_S.has_value()) || (!scale_dP.has_value()) + || (!scale_dQKV.has_value()) + || (!amax_dP.has_value()) || (!amax_dQKV.has_value())) { + std::string err_tensors = "descale_QKV, descale_S, descale_O, scale_S, scale_dP, "; + err_tensors = err_tensors + std::string("scale_dQKV, amax_dP and amax_dQKV"); + NVTE_ERROR(err_tensors + std::string("are required for FP8 operation. 
\n")); + } + te_QKV = makeTransformerEngineTensor(QKV.data_ptr(), {total_seqs, 3, h, d}, + qkv_type, nullptr, nullptr, descale_QKV.value().data_ptr()); + te_O = makeTransformerEngineTensor(O.data_ptr(), {total_seqs, h, d}, + qkv_type, nullptr, nullptr, descale_O.value().data_ptr()); + te_dO = makeTransformerEngineTensor(dO.data_ptr(), {total_seqs, h, d}, + qkv_type, nullptr, nullptr, descale_dO.value().data_ptr()); + te_S = makeTransformerEngineTensor(nullptr, {0}, + DType::kFloat32, + nullptr, scale_S.value().data_ptr(), descale_S.value().data_ptr()); + at::Tensor descale_dP = torch::empty_like(scale_dP.value()); + te_dP = makeTransformerEngineTensor(nullptr, {0}, + DType::kFloat32, amax_dP.value().data_ptr(), scale_dP.value().data_ptr(), + descale_dP.data_ptr()); + te_dQKV = makeTransformerEngineTensor(dQKV.data_ptr(), {total_seqs, 3, h, d}, + qkv_type, + amax_dQKV.value().data_ptr(), scale_dQKV.value().data_ptr(), nullptr); + } else if (qkv_type == DType::kBFloat16 || qkv_type == DType::kFloat16) { + // BF16 or FP16 + te_QKV = makeTransformerEngineTensor(QKV.data_ptr(), {total_seqs, 3, h, d}, + qkv_type, nullptr, nullptr, nullptr); + te_O = makeTransformerEngineTensor(O.data_ptr(), {total_seqs, h, d}, + qkv_type, nullptr, nullptr, nullptr); + te_dO = makeTransformerEngineTensor(dO.data_ptr(), {total_seqs, h, d}, + qkv_type, nullptr, nullptr, nullptr); + te_S = makeTransformerEngineTensor(nullptr, {0}, + DType::kFloat32, nullptr, nullptr, nullptr); + te_dP = makeTransformerEngineTensor(nullptr, {0}, + DType::kFloat32, nullptr, nullptr, nullptr); + te_dQKV = makeTransformerEngineTensor(dQKV.data_ptr(), {total_seqs, 3, h, d}, + qkv_type, nullptr, nullptr, nullptr); + } else { + NVTE_ERROR("Fused attention only supports FP8 and BF16/FP16 data types. 
\n"); + } + if (dBias.has_value()) { + auto bias_shape = dBias.value().sizes().vec(); + std::vector shape{bias_shape.begin(), bias_shape.end()}; + te_dBias = makeTransformerEngineTensor( + dBias.value().data_ptr(), shape, DType::kFloat32, + nullptr, nullptr, nullptr); + } + + // convert strings to enums + NVTE_QKV_Layout qkv_layout_enum = get_nvte_qkv_layout(qkv_layout); + NVTE_Bias_Type bias_type_enum = get_nvte_bias_type(bias_type); + NVTE_Mask_Type attn_mask_type_enum = get_nvte_mask_type(attn_mask_type); + + // convert auxiliary tensors from forward into NVTETensors + // aux_ctx_tensors are [M, ZInv, rng_state] + NVTETensorPack nvte_aux_tensor_pack; + nvte_tensor_pack_create(&nvte_aux_tensor_pack); + nvte_aux_tensor_pack.size = Aux_CTX_Tensors.size(); + for (size_t i = 0; i < nvte_aux_tensor_pack.size; ++i) { + auto tensor = reinterpret_cast(nvte_aux_tensor_pack.tensors[i]); + tensor->data.dptr = Aux_CTX_Tensors[i].data_ptr(); + std::vector tmp(Aux_CTX_Tensors[i].sizes().vec()); + tensor->data.shape = std::vector(tmp.begin(), tmp.end()); + tensor->data.dtype = GetTransformerEngineDType(Aux_CTX_Tensors[i].scalar_type()); + } + + // create cu_seqlens tensorwrappers + TensorWrapper te_cu_seqlens; + te_cu_seqlens = makeTransformerEngineTensor(cu_seqlens.data_ptr(), {b+1}, + DType::kInt32, nullptr, nullptr, nullptr); + + // create workspace + TensorWrapper workspace; + + // populate tensors with appropriate shapes and dtypes + nvte_fused_attn_bwd_qkvpacked( + te_QKV.data(), + te_dBias.data(), + te_O.data(), + te_dO.data(), + te_S.data(), + te_dP.data(), + &nvte_aux_tensor_pack, + te_dQKV.data(), + te_cu_seqlens.data(), + max_seqlen, + attn_scale, p_dropout, + qkv_layout_enum, bias_type_enum, attn_mask_type_enum, + workspace.data(), + at::cuda::getCurrentCUDAStream()); + + // allocate memory for workspace + auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype()); + workspace = makeTransformerEngineTensor( + workspace_data.data_ptr(), + workspace.shape(), workspace.dtype()); + + // execute kernel + nvte_fused_attn_bwd_qkvpacked( + te_QKV.data(), + te_dBias.data(), + te_O.data(), + te_dO.data(), + te_S.data(), + te_dP.data(), + &nvte_aux_tensor_pack, + te_dQKV.data(), + te_cu_seqlens.data(), + max_seqlen, + attn_scale, p_dropout, + qkv_layout_enum, bias_type_enum, attn_mask_type_enum, + workspace.data(), + at::cuda::getCurrentCUDAStream()); + + // destroy tensor wrappers + nvte_tensor_pack_destroy(&nvte_aux_tensor_pack); + + return {dQKV}; +} + +// fused attention FWD with packed KV +std::vector fused_attn_fwd_kvpacked( + size_t b, size_t max_seqlen_q, size_t max_seqlen_kv, + size_t total_seqs_q, size_t total_seqs_kv, + size_t h, size_t d, + bool is_training, float attn_scale, float p_dropout, bool set_zero, + std::string qkv_layout, std::string bias_type, std::string attn_mask_type, + const at::Tensor cu_seqlens_q, + const at::Tensor cu_seqlens_kv, + const at::Tensor Q, + const at::Tensor KV, + const transformer_engine::DType qkv_type, + const c10::optional descale_QKV, + const c10::optional scale_S, + const c10::optional scale_O, + c10::optional amax_S, + c10::optional amax_O, + const c10::optional Bias, + const c10::optional rng_gen) { + using namespace transformer_engine; + + // create output tensor O + auto options = torch::TensorOptions().dtype(GetATenDType(qkv_type)).device(torch::kCUDA); + auto O = torch::empty({static_cast(total_seqs_q), + static_cast(h), static_cast(d)}, options); + if (set_zero) { + mha_fill(O, cu_seqlens_q.index({torch::indexing::Slice(-1, 
torch::indexing::None)})); + } + + // construct NVTE tensors + TensorWrapper te_Q, te_KV, te_S, te_O, te_Bias, te_cu_seqlens_q, te_cu_seqlens_kv; + if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) { + // FP8 + if ((!descale_QKV.has_value()) || (!scale_S.has_value()) || (!scale_O.has_value()) + || (!amax_S.has_value()) || (!amax_O.has_value())) { + std::string err_tensors = "descale_QKV, scale_S, scale_O, amax_S and amax_O"; + NVTE_ERROR(err_tensors + std::string("are required for FP8 operation. \n")); + } + te_Q = makeTransformerEngineTensor(Q.data_ptr(), {total_seqs_q, h, d}, + qkv_type, nullptr, nullptr, descale_QKV.value().data_ptr()); + te_KV = makeTransformerEngineTensor(KV.data_ptr(), {total_seqs_kv, 2, h, d}, + qkv_type, nullptr, nullptr, descale_QKV.value().data_ptr()); + at::Tensor descale_S = torch::empty_like(scale_S.value()); + te_S = makeTransformerEngineTensor(nullptr, {0}, + DType::kFloat32, amax_S.value().data_ptr(), + scale_S.value().data_ptr(), descale_S.data_ptr()); + te_O = makeTransformerEngineTensor(O.data_ptr(), {total_seqs_q, h, d}, + qkv_type, amax_O.value().data_ptr(), scale_O.value().data_ptr(), nullptr); + } else if (qkv_type == DType::kBFloat16 || qkv_type == DType::kFloat16) { + // BF16 or FP16 + te_Q = makeTransformerEngineTensor(Q.data_ptr(), {total_seqs_q, h, d}, + qkv_type, nullptr, nullptr, nullptr); + te_KV = makeTransformerEngineTensor(KV.data_ptr(), {total_seqs_kv, 2, h, d}, + qkv_type, nullptr, nullptr, nullptr); + te_S = makeTransformerEngineTensor(nullptr, {0}, + DType::kFloat32, nullptr, nullptr, nullptr); + te_O = makeTransformerEngineTensor(O.data_ptr(), {total_seqs_q, h, d}, + qkv_type, nullptr, nullptr, nullptr); + } else { + NVTE_ERROR("Fused attention only supports FP8 and BF16/FP16 data types. 
\n"); + } + if (Bias.has_value()) { + auto bias_shape = Bias.value().sizes().vec(); + std::vector shape{bias_shape.begin(), bias_shape.end()}; + te_Bias = makeTransformerEngineTensor(Bias.value().data_ptr(), shape, + DType::kFloat32, nullptr, nullptr, nullptr); + } + te_cu_seqlens_q = makeTransformerEngineTensor(cu_seqlens_q.data_ptr(), {b+1}, + DType::kInt32, nullptr, nullptr, nullptr); + te_cu_seqlens_kv = makeTransformerEngineTensor(cu_seqlens_kv.data_ptr(), {b+1}, + DType::kInt32, nullptr, nullptr, nullptr); + + // convert strings to enums + NVTE_QKV_Layout qkv_layout_enum = get_nvte_qkv_layout(qkv_layout); + NVTE_Bias_Type bias_type_enum = get_nvte_bias_type(bias_type); + NVTE_Mask_Type attn_mask_type_enum = get_nvte_mask_type(attn_mask_type); + + // extract rng seed and offset + auto gen = at::get_generator_or_default( + rng_gen, at::cuda::detail::getDefaultCUDAGenerator()); + size_t threads_per_cta = 128; + at::PhiloxCudaState philox_args = init_philox_state( + gen, max(max_seqlen_q, max_seqlen_kv), threads_per_cta); + auto rng_state = torch::empty({2}, options.dtype(torch::kInt64)); + unpack<<<1, 1, 0, at::cuda::getCurrentCUDAStream()>>>( + philox_args, static_cast(rng_state.data_ptr())); + auto te_rng_state = makeTransformerEngineTensor(rng_state); + + // create auxiliary output tensors + // if training, tensors are [M, ZInv] + NVTETensorPack nvte_aux_tensor_pack; + nvte_tensor_pack_create(&nvte_aux_tensor_pack); + + // create workspace + TensorWrapper workspace; + + // populate tensors with appropriate shapes and dtypes + nvte_fused_attn_fwd_kvpacked( + te_Q.data(), + te_KV.data(), + te_Bias.data(), + te_S.data(), + te_O.data(), + &nvte_aux_tensor_pack, + te_cu_seqlens_q.data(), + te_cu_seqlens_kv.data(), + te_rng_state.data(), + max_seqlen_q, max_seqlen_kv, + is_training, attn_scale, p_dropout, + qkv_layout_enum, bias_type_enum, attn_mask_type_enum, + workspace.data(), + at::cuda::getCurrentCUDAStream()); + + // allocate memory for workspace and auxiliary output tensors + auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype()); + workspace = makeTransformerEngineTensor( + workspace_data.data_ptr(), + workspace.shape(), workspace.dtype()); + + // output_tensors = [O, nvte_aux_tensor_pack.tensors, rng_state] + std::vector output_tensors; + output_tensors.push_back(O); + // nvte_aux_tensor_pack.size is 0 if inference + for (size_t i = 0; i < nvte_aux_tensor_pack.size; ++i) { + auto tensor = reinterpret_cast(nvte_aux_tensor_pack.tensors[i]); + // allocate memory for nvte_aux_tensor_pack.tensors + auto output_tensor = allocateSpace(tensor->data.shape, tensor->data.dtype, false); + output_tensors.push_back(output_tensor); + tensor->data.dptr = output_tensor.data_ptr(); + } + if (is_training) { + output_tensors.push_back(rng_state); + } + + // execute the kernel + nvte_fused_attn_fwd_kvpacked( + te_Q.data(), + te_KV.data(), + te_Bias.data(), + te_S.data(), + te_O.data(), + &nvte_aux_tensor_pack, + te_cu_seqlens_q.data(), + te_cu_seqlens_kv.data(), + te_rng_state.data(), + max_seqlen_q, max_seqlen_kv, + is_training, attn_scale, p_dropout, + qkv_layout_enum, bias_type_enum, attn_mask_type_enum, + workspace.data(), + at::cuda::getCurrentCUDAStream()); + + // destroy tensor wrappers, but not allocated memory + nvte_tensor_pack_destroy(&nvte_aux_tensor_pack); + + // if training, [O, M, ZInv, rng_state]; if inference, [O] + return output_tensors; +} + +// fused attention BWD with packed KV +std::vector fused_attn_bwd_kvpacked( + size_t b, size_t max_seqlen_q, size_t 
max_seqlen_kv, + size_t total_seqs_q, size_t total_seqs_kv, + size_t h, size_t d, + float attn_scale, float p_dropout, bool set_zero, + std::string qkv_layout, std::string bias_type, std::string attn_mask_type, + const at::Tensor cu_seqlens_q, + const at::Tensor cu_seqlens_kv, + const at::Tensor Q, + const at::Tensor KV, + const at::Tensor O, + const at::Tensor dO, + const transformer_engine::DType qkv_type, + const std::vector Aux_CTX_Tensors, + const c10::optional descale_QKV, + const c10::optional descale_S, + const c10::optional descale_O, + const c10::optional descale_dO, + const c10::optional scale_S, + const c10::optional scale_dP, + const c10::optional scale_dQKV, + c10::optional amax_dP, + c10::optional amax_dQKV, + const c10::optional dBias) { + using namespace transformer_engine; + + // create output tensors dQ and dKV + at::Tensor dQ = torch::empty_like(Q); + at::Tensor dKV = torch::empty_like(KV); + if (set_zero) { + mha_fill(dQ, cu_seqlens_q.index({torch::indexing::Slice(-1, torch::indexing::None)})); + mha_fill(dKV, cu_seqlens_kv.index({torch::indexing::Slice(-1, torch::indexing::None)})); + } + + // construct NVTE tensors + TensorWrapper te_Q, te_KV, te_O, te_dO, te_S, te_dP, te_dQ, te_dKV, te_dBias; + if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) { + // FP8 + if ((!descale_QKV.has_value()) || (!descale_S.has_value()) + || (!descale_O.has_value()) || (!descale_dO.has_value()) + || (!scale_S.has_value()) || (!scale_dP.has_value()) + || (!scale_dQKV.has_value()) + || (!amax_dP.has_value()) || (!amax_dQKV.has_value())) { + std::string err_tensors = "descale_QKV, descale_S, descale_O, scale_S, scale_dP, "; + err_tensors = err_tensors + std::string("scale_dQKV, amax_dP and amax_dQKV"); + NVTE_ERROR(err_tensors + std::string("are required for FP8 operation. 
\n")); + } + te_Q = makeTransformerEngineTensor(Q.data_ptr(), {total_seqs_q, h, d}, + qkv_type, nullptr, nullptr, descale_QKV.value().data_ptr()); + te_KV = makeTransformerEngineTensor(KV.data_ptr(), {total_seqs_kv, 2, h, d}, + qkv_type, nullptr, nullptr, descale_QKV.value().data_ptr()); + te_O = makeTransformerEngineTensor(O.data_ptr(), {total_seqs_q, h, d}, + qkv_type, nullptr, nullptr, descale_O.value().data_ptr()); + te_dO = makeTransformerEngineTensor(dO.data_ptr(), {total_seqs_q, h, d}, + qkv_type, nullptr, nullptr, descale_dO.value().data_ptr()); + te_S = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32, nullptr, + scale_S.value().data_ptr(), descale_S.value().data_ptr()); + at::Tensor descale_dP = torch::empty_like(scale_dP.value()); + te_dP = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32, + amax_dP.value().data_ptr(), scale_dP.value().data_ptr(), + descale_dP.data_ptr()); + te_dQ = makeTransformerEngineTensor(dQ.data_ptr(), {total_seqs_q, h, d}, qkv_type, + amax_dQKV.value().data_ptr(), scale_dQKV.value().data_ptr(), nullptr); + te_dKV = makeTransformerEngineTensor(dKV.data_ptr(), {total_seqs_kv, 2, h, d}, qkv_type, + amax_dQKV.value().data_ptr(), scale_dQKV.value().data_ptr(), nullptr); + } else if (qkv_type == DType::kBFloat16 || qkv_type == DType::kFloat16) { + // BF16 or FP16 + te_Q = makeTransformerEngineTensor(Q.data_ptr(), {total_seqs_q, h, d}, + qkv_type, nullptr, nullptr, nullptr); + te_KV = makeTransformerEngineTensor(KV.data_ptr(), {total_seqs_kv, 2, h, d}, + qkv_type, nullptr, nullptr, nullptr); + te_O = makeTransformerEngineTensor(O.data_ptr(), {total_seqs_q, h, d}, + qkv_type, nullptr, nullptr, nullptr); + te_dO = makeTransformerEngineTensor(dO.data_ptr(), {total_seqs_q, h, d}, + qkv_type, nullptr, nullptr, nullptr); + te_S = makeTransformerEngineTensor(nullptr, {0}, + DType::kFloat32, nullptr, nullptr, nullptr); + te_dP = makeTransformerEngineTensor(nullptr, {0}, + DType::kFloat32, nullptr, nullptr, nullptr); + te_dQ = makeTransformerEngineTensor(dQ.data_ptr(), {total_seqs_q, h, d}, + qkv_type, nullptr, nullptr, nullptr); + te_dKV = makeTransformerEngineTensor(dKV.data_ptr(), {total_seqs_kv, 2, h, d}, + qkv_type, nullptr, nullptr, nullptr); + } else { + NVTE_ERROR("Fused attention only supports FP8 and BF16/FP16 data types. 
\n"); + } + if (dBias.has_value()) { + auto bias_shape = dBias.value().sizes().vec(); + std::vector shape{bias_shape.begin(), bias_shape.end()}; + te_dBias = makeTransformerEngineTensor( + dBias.value().data_ptr(), shape, DType::kFloat32, + nullptr, nullptr, nullptr); + } + + // create cu_seqlens tensorwrappers + TensorWrapper te_cu_seqlens_q, te_cu_seqlens_kv; + te_cu_seqlens_q = makeTransformerEngineTensor(cu_seqlens_q.data_ptr(), {b+1}, + DType::kInt32, nullptr, nullptr, nullptr); + te_cu_seqlens_kv = makeTransformerEngineTensor(cu_seqlens_kv.data_ptr(), {b+1}, + DType::kInt32, nullptr, nullptr, nullptr); + + // convert strings to enums + NVTE_QKV_Layout qkv_layout_enum = get_nvte_qkv_layout(qkv_layout); + NVTE_Bias_Type bias_type_enum = get_nvte_bias_type(bias_type); + NVTE_Mask_Type attn_mask_type_enum = get_nvte_mask_type(attn_mask_type); + + // convert auxiliary tensors from forward to NVTETensors + // aux_ctx_tensors are [M, ZInv, rng_state] + NVTETensorPack nvte_aux_tensor_pack; + nvte_tensor_pack_create(&nvte_aux_tensor_pack); + nvte_aux_tensor_pack.size = Aux_CTX_Tensors.size(); + for (size_t i = 0; i < nvte_aux_tensor_pack.size; ++i) { + auto tensor = reinterpret_cast(nvte_aux_tensor_pack.tensors[i]); + tensor->data.dptr = Aux_CTX_Tensors[i].data_ptr(); + std::vector tmp(Aux_CTX_Tensors[i].sizes().vec()); + tensor->data.shape = std::vector(tmp.begin(), tmp.end()); + tensor->data.dtype = GetTransformerEngineDType(Aux_CTX_Tensors[i].scalar_type()); + } + + // create workspace + TensorWrapper workspace; + + // populate tensors with appropriate shapes and dtypes + nvte_fused_attn_bwd_kvpacked( + te_Q.data(), + te_KV.data(), + te_dBias.data(), + te_O.data(), + te_dO.data(), + te_S.data(), + te_dP.data(), + &nvte_aux_tensor_pack, + te_dQ.data(), + te_dKV.data(), + te_cu_seqlens_q.data(), + te_cu_seqlens_kv.data(), + max_seqlen_q, max_seqlen_kv, + attn_scale, p_dropout, + qkv_layout_enum, bias_type_enum, attn_mask_type_enum, + workspace.data(), + at::cuda::getCurrentCUDAStream()); + + // allocate memory for workspace + auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype()); + workspace = makeTransformerEngineTensor( + workspace_data.data_ptr(), + workspace.shape(), workspace.dtype()); + + // execute kernel + nvte_fused_attn_bwd_kvpacked( + te_Q.data(), + te_KV.data(), + te_dBias.data(), + te_O.data(), + te_dO.data(), + te_S.data(), + te_dP.data(), + &nvte_aux_tensor_pack, + te_dQ.data(), + te_dKV.data(), + te_cu_seqlens_q.data(), + te_cu_seqlens_kv.data(), + max_seqlen_q, max_seqlen_kv, + attn_scale, p_dropout, + qkv_layout_enum, bias_type_enum, attn_mask_type_enum, + workspace.data(), + at::cuda::getCurrentCUDAStream()); + + // destroy tensor wrappers + nvte_tensor_pack_destroy(&nvte_aux_tensor_pack); + + return {dQ, dKV}; +} + void te_gemm(at::Tensor A, at::Tensor A_scale_inverse, transformer_engine::DType A_type, @@ -749,13 +1485,13 @@ at::Tensor cast_to_fp8(const at::Tensor &input, transformer_engine::DType otype ) { using namespace transformer_engine; - size_t N = static_cast(input.size(0)); - size_t H = static_cast(input.size(1)); + auto input_shape = input.sizes().vec(); + std::vector shape{input_shape.begin(), input_shape.end()}; auto output = at::empty_like(input, at::CUDA(GetATenDType(otype))); auto input_cu = makeTransformerEngineTensor(input); - auto output_cu = makeTransformerEngineTensor(output.data_ptr(), {N, H}, otype, + auto output_cu = makeTransformerEngineTensor(output.data_ptr(), shape, otype, amax.data_ptr(), scale.data_ptr(), 
scale_inv.data_ptr()); @@ -795,12 +1531,12 @@ at::Tensor cast_from_fp8(const at::Tensor &input, transformer_engine::DType otype ) { using namespace transformer_engine; - size_t N = static_cast(input.size(0)); - size_t H = static_cast(input.size(1)); + auto input_shape = input.sizes().vec(); + std::vector shape{input_shape.begin(), input_shape.end()}; auto output = at::empty_like(input, at::CUDA(GetATenDType(otype))); - auto input_cu = makeTransformerEngineTensor(input.data_ptr(), {N, H}, itype, + auto input_cu = makeTransformerEngineTensor(input.data_ptr(), shape, itype, nullptr, nullptr, scale_inv.data_ptr()); auto output_cu = makeTransformerEngineTensor(output); @@ -1066,6 +1802,14 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("cast_to_fp8_noalloc", &cast_to_fp8_noalloc, "Cast to FP8"); m.def("cast_from_fp8", &cast_from_fp8, "Cast from FP8"); m.def("te_gemm", &te_gemm, "CublasLt GEMM"); + m.def("fused_attn_fwd_qkvpacked", &fused_attn_fwd_qkvpacked, + "Fused Attention FP8/BF16/FP16 FWD with packed QKV"); + m.def("fused_attn_bwd_qkvpacked", &fused_attn_bwd_qkvpacked, + "Fused Attention FP8/BF16/FP16 BWD with packed QKV"); + m.def("fused_attn_fwd_kvpacked", &fused_attn_fwd_kvpacked, + "Fused Attention FP8/BF16/FP16 FWD with packed KV"); + m.def("fused_attn_bwd_kvpacked", &fused_attn_bwd_kvpacked, + "Fused Attention FP8/BF16/FP16 BWD with packed KV"); m.def("fp8_transpose", &fp8_transpose, "Transpose with FP8 I/O"); m.def("fp8_gelu", &fp8_gelu, "GeLU with FP8 output"); diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 6be404226e..561ba417e6 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -5,7 +5,95 @@ ************************************************************************/ #include "common.h" - +#include "../common.h" + +NVTE_QKV_Layout get_nvte_qkv_layout(const std::string qkv_layout); + +NVTE_Bias_Type get_nvte_bias_type(const std::string bias_type); + +NVTE_Mask_Type get_nvte_mask_type(const std::string mask_type); + +std::vector fused_attn_fwd_qkvpacked( + size_t b, size_t max_seqlen, size_t total_seqs, + size_t h, size_t d, + bool is_training, float attn_scale, float p_dropout, bool set_zero, + std::string qkv_layout, std::string bias_type, std::string attn_mask_type, + const at::Tensor cu_seqlens, + const at::Tensor QKV, + const transformer_engine::DType qkv_type, + const c10::optional descale_QKV, + const c10::optional scale_S, + const c10::optional scale_O, + c10::optional amax_S, + c10::optional amax_O, + const c10::optional Bias, + const c10::optional rng_gen); + +std::vector fused_attn_bwd_qkvpacked( + size_t b, size_t max_seqlen, size_t total_seqs, + size_t h, size_t d, + float attn_scale, float p_dropout, bool set_zero, + std::string qkv_layout, std::string bias_type, std::string attn_mask_type, + const at::Tensor cu_seqlens, + const at::Tensor QKV, + const at::Tensor O, + const at::Tensor dO, + const transformer_engine::DType qkv_type, + const std::vector Aux_CTX_Tensors, + const c10::optional descale_QKV, + const c10::optional descale_S, + const c10::optional descale_O, + const c10::optional descale_dO, + const c10::optional scale_S, + const c10::optional scale_dP, + const c10::optional scale_dQKV, + c10::optional amax_dP, + c10::optional amax_dQKV, + const c10::optional dBias); + +std::vector fused_attn_fwd_kvpacked( + size_t b, size_t max_seqlen_q, size_t max_seqlen_kv, + size_t total_seqs_q, size_t total_seqs_kv, + size_t h, size_t d, + bool 
is_training, float attn_scale, float p_dropout, bool set_zero, + std::string qkv_layout, std::string bias_type, std::string attn_mask_type, + const at::Tensor cu_seqlens_q, + const at::Tensor cu_seqlens_kv, + const at::Tensor Q, + const at::Tensor KV, + const transformer_engine::DType qkv_type, + const c10::optional descale_QKV, + const c10::optional scale_S, + const c10::optional scale_O, + c10::optional amax_S, + c10::optional amax_O, + const c10::optional Bias, + const c10::optional rng_gen); + +std::vector fused_attn_bwd_kvpacked( + size_t b, size_t max_seqlen_q, size_t max_seqlen_kv, + size_t total_seqs_q, size_t total_seqs_kv, + size_t h, size_t d, + float attn_scale, float p_dropout, bool set_zero, + std::string qkv_layout, std::string bias_type, std::string attn_mask_type, + const at::Tensor cu_seqlens_q, + const at::Tensor cu_seqlens_kv, + const at::Tensor Q, + const at::Tensor KV, + const at::Tensor O, + const at::Tensor dO, + const transformer_engine::DType qkv_type, + const std::vector Aux_CTX_Tensors, + const c10::optional descale_QKV, + const c10::optional descale_S, + const c10::optional descale_O, + const c10::optional descale_dO, + const c10::optional scale_S, + const c10::optional scale_dP, + const c10::optional scale_dQKV, + c10::optional amax_dP, + c10::optional amax_dQKV, + const c10::optional dBias); void te_gemm(at::Tensor A, at::Tensor A_scale_inverse, diff --git a/transformer_engine/pytorch/module.py b/transformer_engine/pytorch/module.py index 3e0a868047..07805088b2 100644 --- a/transformer_engine/pytorch/module.py +++ b/transformer_engine/pytorch/module.py @@ -102,7 +102,7 @@ def get_workspace() -> torch.Tensor: global _cublas_workspace if _cublas_workspace is None: _cublas_workspace = torch.empty( - get_cublas_workspace_size_bytes(), dtype=torch.int8, device="cuda" + get_cublas_workspace_size_bytes(), dtype=torch.uint8, device="cuda" ) return _cublas_workspace @@ -520,7 +520,7 @@ def set_fp8_weights(self) -> None: torch.empty( shape, device=torch.cuda.current_device(), - dtype=torch.int8, + dtype=torch.uint8, ), ) setattr( @@ -530,7 +530,7 @@ def set_fp8_weights(self) -> None: shape[1], shape[0], device=torch.cuda.current_device(), - dtype=torch.int8, + dtype=torch.uint8, ), ) From e1ef756590fb3e73043c1f17b1f6783d9b40b016 Mon Sep 17 00:00:00 2001 From: Sangkug Lym Date: Fri, 21 Apr 2023 16:22:56 -0700 Subject: [PATCH 21/68] zero inter-node communication buffer (#163) Signed-off-by: Sangkug Lym Co-authored-by: Kirthi Shankar Sivamani --- transformer_engine/pytorch/csrc/userbuffers/userbuffers.h | 2 -- 1 file changed, 2 deletions(-) diff --git a/transformer_engine/pytorch/csrc/userbuffers/userbuffers.h b/transformer_engine/pytorch/csrc/userbuffers/userbuffers.h index 1d4c1d4024..d6ec23c40d 100644 --- a/transformer_engine/pytorch/csrc/userbuffers/userbuffers.h +++ b/transformer_engine/pytorch/csrc/userbuffers/userbuffers.h @@ -34,8 +34,6 @@ #define NVTE_REG0_FLAGS (NVTE_REG0_RECV + NVTE_MAX_PEERS * NVTE_MAX_REGIONS) #define NVTE_REG0_IBRS 32 #define NVTE_REG0_IBAG 512 -#undef NVTE_REG0_COMMBUFFER -#define NVTE_REG0_COMMBUFFER (1024 * 1024 * 16) // gpuflags map offsets #define NVTE_GF_STATE 16000 From 9d90eb477974182c196b556a1fea79b81c368603 Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Fri, 21 Apr 2023 16:40:15 -0700 Subject: [PATCH 22/68] Remove userbuf docs (#164) Signed-off-by: Kirthi Shankar Sivamani --- transformer_engine/pytorch/transformer.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/transformer_engine/pytorch/transformer.py 
b/transformer_engine/pytorch/transformer.py index 52d303e8f4..dfa28846af 100644 --- a/transformer_engine/pytorch/transformer.py +++ b/transformer_engine/pytorch/transformer.py @@ -924,12 +924,6 @@ class TransformerLayer(torch.nn.Module): `set_tensor_parallel_group(tp_group)` method on the initialized module before the forward pass to supply the tensor parallel group needed for tensor and sequence parallel collectives. - ub_bulk_wgrad: bool, default = False - Bulk overlap UserBuffer ReduceScatter | WGRAD GEMM - ub_bulk_dgrad: bool, default = False - Bulk overlap UserBuffer AllGather | DGRAD GEMM - ub_split_ag: bool, default = False - Split pipelined overlap UserBuffer AllGather -> GEMM Optimization parameters ----------------------- From 71488dbec80899d2ce5e1730b08a6feb9451f0ec Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Thu, 27 Apr 2023 17:09:24 -0700 Subject: [PATCH 23/68] Faster split of QKV for FlashAttention (#166) * Faster split of QKV for FlashAttention Signed-off-by: Przemek Tredak * CI fixes Signed-off-by: Przemek Tredak * Fix Signed-off-by: Przemek Tredak * review comments Signed-off-by: Kirthi Shankar Sivamani * Message with assert Co-authored-by: Przemyslaw Tredak Signed-off-by: Kirthi Shankar Sivamani * Review comments Signed-off-by: Kirthi Shankar Sivamani * review Signed-off-by: Kirthi Shankar Sivamani * fix misalignment error Signed-off-by: Kirthi Shankar Sivamani * make clarifying comment and check strides Signed-off-by: Kirthi Shankar Sivamani --------- Signed-off-by: Przemek Tredak Signed-off-by: Kirthi Shankar Sivamani Co-authored-by: Kirthi Shankar Sivamani Signed-off-by: Kirthi Shankar Sivamani --- transformer_engine/pytorch/csrc/extensions.cu | 171 ++++++++++++++++++ transformer_engine/pytorch/transformer.py | 119 +++++++++++- 2 files changed, 284 insertions(+), 6 deletions(-) diff --git a/transformer_engine/pytorch/csrc/extensions.cu b/transformer_engine/pytorch/csrc/extensions.cu index 75d4abd031..4cb6c50c34 100644 --- a/transformer_engine/pytorch/csrc/extensions.cu +++ b/transformer_engine/pytorch/csrc/extensions.cu @@ -1767,6 +1767,175 @@ bool userbuf_comm_available() { // TODO(ksivamani) check on python side void placeholder() {} // TODO(ksivamani) clean this up +namespace flash_attention { + +constexpr int warp_size = 32; +constexpr int type_size = 2; // FP16 or BF16 +constexpr int nvec = sizeof(uint64_t) / type_size; +constexpr int load_size = warp_size * nvec; +constexpr int block_size = 512; + +template +__launch_bounds__(block_size) +__global__ void prepare_kernel_fwd(const T *qkvi, + T *qkv, + const size_t B, + const size_t S, + const size_t Z, + const size_t W) { + const int warpid = (blockDim.x * blockIdx.x + threadIdx.x) / warp_size; + const int id_in_warp = threadIdx.x % warp_size; + const size_t offset_input = blockIdx.y * W + warpid * 3 * W * Z + id_in_warp * nvec; + const T *my_input = qkvi + offset_input; + + const size_t s = warpid / B; + if (s >= S) return; + + const size_t b = warpid % B; + + const size_t offset_output = blockIdx.y * B * S * Z * W + + (s + b * S) * W * Z + + id_in_warp * nvec; + + T *my_output = qkv + offset_output; + + for (int i = 0; i < Z; ++i) { + uint64_t *out = reinterpret_cast(my_output + i * load_size); + *out = *reinterpret_cast(my_input + i * load_size * 3); + } +} + +template +__launch_bounds__(block_size) +__global__ void prepare_kernel_bwd(const T *q, const T *k, const T *v, + T *qkv, const size_t B, const size_t S, + const size_t Z, const size_t W) { + const T *input = blockIdx.y == 0 ? 
q : (blockIdx.y == 1 ? k : v); + + const int warpid = (blockDim.x * blockIdx.x + threadIdx.x) / warp_size; + const int id_in_warp = threadIdx.x % warp_size; + const size_t offset_input = warpid * W * Z + id_in_warp * nvec; + const T *my_input = input + offset_input; + + const size_t b = warpid / S; + if (b >= B) return; + + const size_t s = warpid % S; + + const size_t offset_output = (b + s * B) * 3 * W * Z + + id_in_warp * nvec + blockIdx.y * W; + + T *my_output = qkv + offset_output; + + for (int i = 0; i < Z; ++i) { + uint64_t *out = reinterpret_cast(my_output + i * load_size * 3); + *out = *reinterpret_cast(my_input + i * load_size); + } +} + +} // namespace flash_attention + +at::Tensor fa_prepare_fwd(at::Tensor qkvi) { + NVTE_CHECK(qkvi.dim() == 4, "Expected 4-dim tensor."); + NVTE_CHECK(qkvi.scalar_type() == at::ScalarType::Half || + qkvi.scalar_type() == at::ScalarType::BFloat16); + NVTE_CHECK(qkvi.size(3) % flash_attention::load_size == 0); + NVTE_CHECK(qkvi.size(3) == flash_attention::load_size); + NVTE_CHECK(qkvi.stride(3) == 1, "Wrong stride."); + NVTE_CHECK(qkvi.stride(2) == 3 * qkvi.size(3), "Wrong stride."); + NVTE_CHECK(qkvi.stride(1) == 3 * qkvi.size(3) * qkvi.size(2), "Wrong stride."); + NVTE_CHECK(qkvi.stride(0) == 3 * qkvi.size(3) * qkvi.size(2) * qkvi.size(1), "Wrong stride."); + + // [s, b, n, h * 3] -> [3, b, s, n, h] + std::vector shape = {3, qkvi.size(1), qkvi.size(0), qkvi.size(2), qkvi.size(3)}; + at::Tensor qkv = at::empty(shape, at::CUDA(qkvi.scalar_type())); + + size_t warps = qkvi.size(0) * qkvi.size(1); + size_t warps_per_block = flash_attention::block_size / flash_attention::warp_size; + size_t blocks = (warps + warps_per_block - 1) / warps_per_block; + dim3 grid(blocks, 3); + int threads = flash_attention::block_size; + if (qkvi.scalar_type() == at::ScalarType::Half) { + using dtype = at::Half; + flash_attention::prepare_kernel_fwd<<>>( + qkvi.data_ptr(), + qkv.data_ptr(), + shape[1], + shape[2], + shape[3], + shape[4]); + } else { + using dtype = at::BFloat16; + flash_attention::prepare_kernel_fwd<<>>( + qkvi.data_ptr(), + qkv.data_ptr(), + shape[1], + shape[2], + shape[3], + shape[4]); + } + + return qkv; +} + +at::Tensor fa_prepare_bwd(at::Tensor q, at::Tensor k, at::Tensor v) { + NVTE_CHECK(q.is_contiguous()); + NVTE_CHECK(k.is_contiguous()); + NVTE_CHECK(v.is_contiguous()); + NVTE_CHECK(q.dim() == 4, "Expected 4-dim tensor."); + NVTE_CHECK(k.dim() == 4, "Expected 4-dim tensor."); + NVTE_CHECK(v.dim() == 4, "Expected 4-dim tensor."); + NVTE_CHECK(q.scalar_type() == at::ScalarType::Half || + q.scalar_type() == at::ScalarType::BFloat16); + NVTE_CHECK(k.scalar_type() == q.scalar_type()); + NVTE_CHECK(v.scalar_type() == q.scalar_type()); + NVTE_CHECK(q.size(3) % flash_attention::load_size == 0); + NVTE_CHECK(q.size(3) == flash_attention::load_size); + NVTE_CHECK(k.size(3) % flash_attention::load_size == 0); + NVTE_CHECK(k.size(3) == flash_attention::load_size); + NVTE_CHECK(v.size(3) % flash_attention::load_size == 0); + NVTE_CHECK(v.size(3) == flash_attention::load_size); + + // 3 x [s, b, n, h] -> [b, s, n, 3 * h] + + std::vector shape = {q.size(1), q.size(0), q.size(2), 3 * q.size(3)}; + at::Tensor qkv = at::empty(shape, at::CUDA(q.scalar_type())); + + size_t warps = q.size(0) * q.size(1); + size_t warps_per_block = flash_attention::block_size / flash_attention::warp_size; + size_t blocks = (warps + warps_per_block - 1) / warps_per_block; + dim3 grid(blocks, 3); + int threads = flash_attention::block_size; + if (q.scalar_type() == at::ScalarType::Half) 
{ + using dtype = at::Half; + flash_attention::prepare_kernel_bwd<<>>( + q.data_ptr(), + k.data_ptr(), + v.data_ptr(), + qkv.data_ptr(), + q.size(0), + q.size(1), + q.size(2), + q.size(3)); + } else { + using dtype = at::BFloat16; + flash_attention::prepare_kernel_bwd<<>>( + q.data_ptr(), + k.data_ptr(), + v.data_ptr(), + qkv.data_ptr(), + q.size(0), + q.size(1), + q.size(2), + q.size(3)); + } + + return qkv; +} PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { // Softmax functions @@ -1812,6 +1981,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { "Fused Attention FP8/BF16/FP16 BWD with packed KV"); m.def("fp8_transpose", &fp8_transpose, "Transpose with FP8 I/O"); m.def("fp8_gelu", &fp8_gelu, "GeLU with FP8 output"); + m.def("fa_prepare_fwd", &fa_prepare_fwd, "Prepare QKV for Flash Attention"); + m.def("fa_prepare_bwd", &fa_prepare_bwd, "Backward of QKV preparation for Flash Attention"); // Misc m.def("get_cublasLt_version", &get_cublasLt_version, "Get cublasLt version"); diff --git a/transformer_engine/pytorch/transformer.py b/transformer_engine/pytorch/transformer.py index dfa28846af..7071378b61 100644 --- a/transformer_engine/pytorch/transformer.py +++ b/transformer_engine/pytorch/transformer.py @@ -77,6 +77,48 @@ def forward(self, hidden_state: torch.Tensor) -> torch.Tensor: output = hidden_state.div(keep_prob) * random_tensor return output +class _SplitLastDim(torch.autograd.Function): + """""" + + @staticmethod + def forward(ctx, + mixed_x_layer: torch.Tensor, + num_parts: int + ) -> Tuple[torch.Tensor, ...]: + return split_tensor_along_dim(mixed_x_layer, -1, num_parts) + + @staticmethod + def backward(ctx, + *grad_outputs): + assert len(grad_outputs) > 0, "No gradients received for backprop!" + + noop_ok = True + strides = grad_outputs[0].stride() + data_ptr = grad_outputs[0].untyped_storage().data_ptr() + shape = grad_outputs[0].shape + last_dim_size = grad_outputs[0].shape[-1] + for i, tensor in enumerate(grad_outputs): + if (tensor.stride() != strides or + tensor.shape != shape or + tensor.untyped_storage().data_ptr() != data_ptr or + tensor.storage_offset() != i * last_dim_size): + noop_ok = False + break + + if noop_ok: + ret = torch.Tensor().to(grad_outputs[0].dtype) + ret = torch.Tensor().to(device=grad_outputs[0].device, + dtype=grad_outputs[0].dtype) + new_shape = list(shape) + new_shape[-1] = new_shape[-1] * len(grad_outputs) + ret.set_(grad_outputs[0].untyped_storage(), + grad_outputs[0].storage_offset(), + new_shape, + grad_outputs[0].stride() + ) + return ret, None + + return torch.cat(grad_outputs, dim = -1), None class UnfusedDotProductAttention(torch.nn.Module): """Parallel attention w/o QKV and Proj Gemms @@ -204,6 +246,56 @@ def forward( return context_layer +class _PrepareQKVForFA(torch.autograd.Function): + """This class converts QKV from interleaved (s, b, ...) layout + to separate contiguous q, k, v tensors in (b, s, ...) layout.""" + + @staticmethod + def forward(ctx, + query_layer: torch.Tensor, + key_layer: torch.Tensor, + value_layer: torch.Tensor + ) -> torch.Tensor: + # All inputs received are non-contiguous tensors. + # The `query_layer` tensor is used to access the + # full memory region of the QKV tensor. 
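+        # `query_layer` is only a strided view into that region: its strides span
+        # the full [sq, b, np, 3 * hn] projection output, so fa_prepare_fwd can read
+        # Q, K and V through it and emit one contiguous [3, b, sq, np, hn] tensor,
+        # which is then split along dim 0 and squeezed into the (b, sq, np, hn)
+        # layout expected by flash-attn.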
+ qkv = tex.fa_prepare_fwd(query_layer) + q, k, v = split_tensor_along_dim(qkv, 0, 3) + query_layer = torch.squeeze(q, 0) + key_layer = torch.squeeze(k, 0) + value_layer = torch.squeeze(v, 0) + return query_layer, key_layer, value_layer + + @staticmethod + def backward(ctx, + dq: torch.Tensor, + dk: torch.Tensor, + dv: torch.Tensor + ) -> Tuple[Union[torch.Tensor, None], ...]: + dqkv = tex.fa_prepare_bwd(dq, dk, dv) + dq, dk, dv = split_tensor_along_dim(dqkv, -1, 3) + return dq, dk, dv + +def _check_if_interleaved(q, k, v): + data_ptr = q.untyped_storage().data_ptr() + check_ptrs = all(x.untyped_storage().data_ptr() == data_ptr for x in [q, k, v]) + if not check_ptrs: + return False + + stride = q.stride() + check_strides = all(stride == x.stride() for x in [q, k, v]) + if not check_strides: + return False + + shape = q.shape + check_shapes = all(shape == x.shape for x in [q, k, v]) + if not check_shapes: + return False + + last_dim_size = shape[-1] + check_offsets = all(i * last_dim_size == x.storage_offset() + for i, x in enumerate([q, k, v])) + return check_offsets class FlashAttention(torch.nn.Module): """Dot product attention implementation by using the flash-attn package. @@ -252,8 +344,17 @@ def forward( attention_mask is None ), 'FlashAttention currently does not support external attention mask.' - query_layer, key_layer, value_layer = [x.transpose(0,1).contiguous() - for x in (query_layer, key_layer, value_layer)] + # For now just 128, will make it more general in the future + + if (query_layer.shape[-1] == 128 and + query_layer.shape[0] * query_layer.shape[1] >= 512 and + _check_if_interleaved(query_layer, key_layer, value_layer)): + query_layer, key_layer, value_layer = _PrepareQKVForFA.apply(query_layer, + key_layer, + value_layer) + else: + query_layer, key_layer, value_layer = [x.transpose(0,1).contiguous() + for x in (query_layer, key_layer, value_layer)] batch_size, seqlen = query_layer.shape[0], query_layer.shape[1] @@ -731,9 +832,12 @@ def forward( mixed_x_layer = mixed_x_layer.view(*new_tensor_shape) # mixed_x_layer --> 3 [sq, b, np, hn] - query_layer, key_layer, value_layer = split_tensor_along_dim( - mixed_x_layer, split_dim, 3 - ) + if split_dim == -1 and not is_in_onnx_export_mode(): + query_layer, key_layer, value_layer = _SplitLastDim.apply(mixed_x_layer, 3) + else: + query_layer, key_layer, value_layer = split_tensor_along_dim( + mixed_x_layer, split_dim, 3 + ) else: # Attention heads [sk, b, h] --> [sk, b, (np * 2 * hn)] mixed_kv_layer = self.key_value( @@ -761,7 +865,10 @@ def forward( mixed_kv_layer = mixed_kv_layer.view(*new_tensor_shape) # mixed_kv_layer --> 2 [sk, b, np, hn] - key_layer, value_layer = split_tensor_along_dim(mixed_kv_layer, split_dim, 2) + if split_dim == -1 and not is_in_onnx_export_mode(): + key_layer, value_layer = _SplitLastDim.apply(mixed_kv_layer, 2) + else: + key_layer, value_layer = split_tensor_along_dim(mixed_kv_layer, split_dim, 2) # Attention head [sq, b, h] --> [sq, b, hp] if self.input_layernorm: From 87706dc6a65e7d5e44acf801527ceb898e990ecd Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Thu, 27 Apr 2023 17:09:59 -0700 Subject: [PATCH 24/68] Remove the nonexistent parameter from fused attention documentation (#181) * Remove the nonexistent parameter from fused attention documentation Signed-off-by: Przemek Tredak * Remove the second instance Signed-off-by: Przemek Tredak --------- Signed-off-by: Przemek Tredak Signed-off-by: Kirthi Shankar Sivamani --- .../common/include/transformer_engine/fused_attn.h | 2 -- 1 
file changed, 2 deletions(-) diff --git a/transformer_engine/common/include/transformer_engine/fused_attn.h b/transformer_engine/common/include/transformer_engine/fused_attn.h index bb9262de18..967fc62724 100644 --- a/transformer_engine/common/include/transformer_engine/fused_attn.h +++ b/transformer_engine/common/include/transformer_engine/fused_attn.h @@ -133,7 +133,6 @@ void nvte_fused_attn_fwd_qkvpacked( * \param[in] Aux_CTX_Tensors Auxiliary tensors from forward when in training mode. * \param[out] dQKV The gradient of the QKV tensor. * \param[in] cu_seqlens Accumulative sequence lengths, [batch_size + 1]. - * \param[in] rng_state Seed and offset of CUDA random number generator. * \param[in] max_seqlen Max sequence length used for computing, * it may be >= max(cu_seqlens). * \param[in] attn_scale Scaling factor for Q * K.T. @@ -222,7 +221,6 @@ void nvte_fused_attn_fwd_kvpacked( * \param[out] dKV The gradient of the KV tensor. * \param[in] cu_seqlens_q Accumulative sequence lengths for Q, [batch_size + 1]. * \param[in] cu_seqlens_kv Accumulative sequence lengths for KV, [batch_size + 1]. - * \param[in] rng_state Seed and offset of CUDA random number generator. * \param[in] max_seqlen_q Max sequence length used for computing for Q. * it may be >= max(cu_seqlens_q). * \param[in] max_seqlen_kv Max sequence length used for computing for KV. From 2ce7f0c8b06498a41eb90192bef28021b46ffb26 Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Thu, 27 Apr 2023 17:12:07 -0700 Subject: [PATCH 25/68] Re-add support for PyTorch version 1.x (#180) Signed-off-by: Kirthi Shankar Sivamani --- transformer_engine/pytorch/transformer.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/transformer_engine/pytorch/transformer.py b/transformer_engine/pytorch/transformer.py index 7071378b61..fae4ff595d 100644 --- a/transformer_engine/pytorch/transformer.py +++ b/transformer_engine/pytorch/transformer.py @@ -94,13 +94,13 @@ def backward(ctx, noop_ok = True strides = grad_outputs[0].stride() - data_ptr = grad_outputs[0].untyped_storage().data_ptr() + data_ptr = grad_outputs[0].storage().data_ptr() shape = grad_outputs[0].shape last_dim_size = grad_outputs[0].shape[-1] for i, tensor in enumerate(grad_outputs): if (tensor.stride() != strides or tensor.shape != shape or - tensor.untyped_storage().data_ptr() != data_ptr or + tensor.storage().data_ptr() != data_ptr or tensor.storage_offset() != i * last_dim_size): noop_ok = False break @@ -111,7 +111,7 @@ def backward(ctx, dtype=grad_outputs[0].dtype) new_shape = list(shape) new_shape[-1] = new_shape[-1] * len(grad_outputs) - ret.set_(grad_outputs[0].untyped_storage(), + ret.set_(grad_outputs[0].storage(), grad_outputs[0].storage_offset(), new_shape, grad_outputs[0].stride() @@ -277,8 +277,8 @@ def backward(ctx, return dq, dk, dv def _check_if_interleaved(q, k, v): - data_ptr = q.untyped_storage().data_ptr() - check_ptrs = all(x.untyped_storage().data_ptr() == data_ptr for x in [q, k, v]) + data_ptr = q.storage().data_ptr() + check_ptrs = all(x.storage().data_ptr() == data_ptr for x in [q, k, v]) if not check_ptrs: return False From 00707bbd13429d40ee1eec0f11b09c9cff743b83 Mon Sep 17 00:00:00 2001 From: Ming-Xu Huang Date: Fri, 28 Apr 2023 10:04:35 +0800 Subject: [PATCH 26/68] [JAX] Adjust Module Structure. (#169) * Adjust Module Structure. 1. Collect Flax related modules to a sub-folder, flax. 2. Add a function to unify scale_init for zero-centered-gamma LN. 
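For example, user code migrates as follows (minimal sketch, matching the
updated examples in this patch):

    import transformer_engine.jax as te

    layer = te.DenseGeneral(features=256)       # still works, emits DeprecationWarning
    layer = te.flax.DenseGeneral(features=256)  # preferred spelling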
Signed-off-by: Ming Huang * Make changes be compatible to previous versions. Signed-off-by: Ming Huang * Adapt jax/examples to the new module structure. Signed-off-by: Ming Huang * Update jax/docs and Add deprecated warning. Signed-off-by: Ming Huang * Update README Signed-off-by: Ming Huang * Adding deprecated_wrapper Signed-off-by: Ming Huang * Adding deprecated warning to flax modules which imported via transformer_engine.jax Signed-off-by: Ming Huang * Fix CI errors and update docs. Signed-off-by: Ming Huang * Removing unnecessary deprecated warning in docs. Signed-off-by: Ming Huang * Implementing __iter__ to DeprecatedEnum. Signed-off-by: Ming Huang --------- Signed-off-by: Ming Huang Co-authored-by: Kirthi Shankar Sivamani --- README.rst | 5 +- docs/api/jax.rst | 25 +++++---- .../encoder/test_model_parallel_encoder.py | 24 ++++----- examples/jax/encoder/test_multigpu_encoder.py | 8 +-- .../jax/encoder/test_single_gpu_encoder.py | 6 +-- examples/jax/mnist/test_single_gpu_mnist.py | 2 +- tests/jax/test_layer.py | 2 +- tests/jax/test_sharding.py | 2 +- transformer_engine/common/utils.py | 53 +++++++++++++++++++ transformer_engine/jax/__init__.py | 41 ++++++++++++-- transformer_engine/jax/flax/__init__.py | 9 ++++ transformer_engine/jax/{ => flax}/module.py | 46 ++++++++-------- .../jax/{ => flax}/transformer.py | 6 +-- 13 files changed, 162 insertions(+), 67 deletions(-) create mode 100644 transformer_engine/common/utils.py create mode 100644 transformer_engine/jax/flax/__init__.py rename transformer_engine/jax/{ => flax}/module.py (97%) rename transformer_engine/jax/{ => flax}/transformer.py (99%) diff --git a/README.rst b/README.rst index fe576f3498..6964f219d0 100644 --- a/README.rst +++ b/README.rst @@ -69,6 +69,9 @@ pyTorch JAX ^^^ +Flax +~~~~ + .. code-block:: python import jax @@ -90,7 +93,7 @@ JAX # Enable autocasting for the forward pass with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe): - model = te.DenseGeneral(features=HIDDEN) + model = te.flax.DenseGeneral(features=HIDDEN) def loss_fn(params, other_vars, inp): out = model.apply({'params':params, **other_vars}, inp) diff --git a/docs/api/jax.rst b/docs/api/jax.rst index e049c70e50..13b276c3a1 100644 --- a/docs/api/jax.rst +++ b/docs/api/jax.rst @@ -9,34 +9,33 @@ Jax .. autoapiclass:: transformer_engine.jax.MajorShardingType .. autoapiclass:: transformer_engine.jax.ShardingType .. autoapiclass:: transformer_engine.jax.TransformerLayerType +.. autoapiclass:: transformer_engine.jax.ShardingResource(dp_resource=None, tp_resource=None) -.. autoapiclass:: transformer_engine.jax.ShardingResource(dp_resource=None, tp_resource=None) +.. autoapifunction:: transformer_engine.jax.fp8_autocast +.. autoapifunction:: transformer_engine.jax.update_collections +.. autoapifunction:: transformer_engine.jax.update_fp8_metas -.. autoapiclass:: transformer_engine.jax.LayerNorm(epsilon=1e-6, layernorm_type='layernorm', **kwargs) +.. autoapiclass:: transformer_engine.jax.flax.LayerNorm(epsilon=1e-6, layernorm_type='layernorm', **kwargs) :members: __call__ -.. autoapiclass:: transformer_engine.jax.DenseGeneral(features, layernorm_type='layernorm', use_bias=False, **kwargs) +.. autoapiclass:: transformer_engine.jax.flax.DenseGeneral(features, layernorm_type='layernorm', use_bias=False, **kwargs) :members: __call__ -.. autoapiclass:: transformer_engine.jax.LayerNormDenseGeneral(features, layernorm_type='layernorm', epsilon=1e-6, use_bias=False, **kwargs) +.. 
autoapiclass:: transformer_engine.jax.flax.LayerNormDenseGeneral(features, layernorm_type='layernorm', epsilon=1e-6, use_bias=False, **kwargs) :members: __call__ -.. autoapiclass:: transformer_engine.jax.LayerNormMLP(intermediate_dim=2048, layernorm_type='layernorm', epsilon=1e-6, use_bias=False, **kwargs) +.. autoapiclass:: transformer_engine.jax.flax.LayerNormMLP(intermediate_dim=2048, layernorm_type='layernorm', epsilon=1e-6, use_bias=False, **kwargs) :members: __call__ -.. autoapiclass:: transformer_engine.jax.RelativePositionBiases(num_buckets, max_distance, num_heads, **kwargs) +.. autoapiclass:: transformer_engine.jax.flax.RelativePositionBiases(num_buckets, max_distance, num_heads, **kwargs) :members: __call__ -.. autoapiclass:: transformer_engine.jax.MultiHeadAttention(head_dim, num_heads, **kwargs) +.. autoapiclass:: transformer_engine.jax.flax.MultiHeadAttention(head_dim, num_heads, **kwargs) :members: __call__ -.. autoapiclass:: transformer_engine.jax.TransformerLayer(hidden_size=512, mlp_hidden_size=2048, num_attention_heads=8, **kwargs) +.. autoapiclass:: transformer_engine.jax.flax.TransformerLayer(hidden_size=512, mlp_hidden_size=2048, num_attention_heads=8, **kwargs) :members: __call__ - -.. autoapifunction:: transformer_engine.jax.extend_logical_axis_rules -.. autoapifunction:: transformer_engine.jax.fp8_autocast -.. autoapifunction:: transformer_engine.jax.update_collections -.. autoapifunction:: transformer_engine.jax.update_fp8_metas \ No newline at end of file +.. autoapifunction:: transformer_engine.jax.flax.extend_logical_axis_rules diff --git a/examples/jax/encoder/test_model_parallel_encoder.py b/examples/jax/encoder/test_model_parallel_encoder.py index 10c880710e..ff09f1b84e 100644 --- a/examples/jax/encoder/test_model_parallel_encoder.py +++ b/examples/jax/encoder/test_model_parallel_encoder.py @@ -59,7 +59,7 @@ class Net(nn.Module): def __call__(self, x, mask, disable_dropout=False): x = nn.Embed(num_embeddings=self.num_embed, features=256, dtype=jnp.bfloat16)(x) - te_Encoder = partial(te.TransformerLayer, + te_Encoder = partial(te.flax.TransformerLayer, hidden_size=256, mlp_hidden_size=1024, num_attention_heads=8, @@ -73,17 +73,17 @@ def __call__(self, x, mask, disable_dropout=False): x = x.reshape(x.shape[0], -1) - x = te.DenseGeneral(features=256, - kernel_axes=(NAMED_BROADCAST_AXIS, NAMED_TP_AXIS), - bias_axes=(NAMED_TP_AXIS,), - sharding_type=te.ShardingType.DP_TP_COL, - dtype=jnp.bfloat16)(x) - - x = te.DenseGeneral(features=256, - kernel_axes=(NAMED_TP_AXIS, NAMED_BROADCAST_AXIS), - bias_axes=(NAMED_BROADCAST_AXIS,), - sharding_type=te.ShardingType.DP_TP_ROW, - dtype=jnp.bfloat16)(x) + x = te.flax.DenseGeneral(features=256, + kernel_axes=(NAMED_BROADCAST_AXIS, NAMED_TP_AXIS), + bias_axes=(NAMED_TP_AXIS,), + sharding_type=te.ShardingType.DP_TP_COL, + dtype=jnp.bfloat16)(x) + + x = te.flax.DenseGeneral(features=256, + kernel_axes=(NAMED_TP_AXIS, NAMED_BROADCAST_AXIS), + bias_axes=(NAMED_BROADCAST_AXIS,), + sharding_type=te.ShardingType.DP_TP_ROW, + dtype=jnp.bfloat16)(x) x = nn.Dense(features=2, dtype=jnp.bfloat16)(x) return x diff --git a/examples/jax/encoder/test_multigpu_encoder.py b/examples/jax/encoder/test_multigpu_encoder.py index 9cb420b0c8..5f06ddf879 100644 --- a/examples/jax/encoder/test_multigpu_encoder.py +++ b/examples/jax/encoder/test_multigpu_encoder.py @@ -56,7 +56,7 @@ class Net(nn.Module): def __call__(self, x, mask, disable_dropout=False): x = nn.Embed(num_embeddings=self.num_embed, features=256, dtype=jnp.bfloat16)(x) - te_Encoder = 
partial(te.TransformerLayer, + te_Encoder = partial(te.flax.TransformerLayer, hidden_size=256, mlp_hidden_size=1024, num_attention_heads=8, @@ -70,9 +70,11 @@ def __call__(self, x, mask, disable_dropout=False): x = x.reshape(x.shape[0], -1) - x = te.DenseGeneral(features=256, sharding_type=te.ShardingType.DP, dtype=jnp.bfloat16)(x) + x = te.flax.DenseGeneral(features=256, sharding_type=te.ShardingType.DP, + dtype=jnp.bfloat16)(x) - x = te.DenseGeneral(features=256, sharding_type=te.ShardingType.DP, dtype=jnp.bfloat16)(x) + x = te.flax.DenseGeneral(features=256, sharding_type=te.ShardingType.DP, + dtype=jnp.bfloat16)(x) x = nn.Dense(features=2, dtype=jnp.bfloat16)(x) return x diff --git a/examples/jax/encoder/test_single_gpu_encoder.py b/examples/jax/encoder/test_single_gpu_encoder.py index bac1469b5b..ea6c0abd51 100644 --- a/examples/jax/encoder/test_single_gpu_encoder.py +++ b/examples/jax/encoder/test_single_gpu_encoder.py @@ -46,7 +46,7 @@ class Net(nn.Module): def __call__(self, x, mask, disable_dropout=False): x = nn.Embed(num_embeddings=self.num_embed, features=256, dtype=jnp.bfloat16)(x) - te_Encoder = partial(te.TransformerLayer, + te_Encoder = partial(te.flax.TransformerLayer, hidden_size=256, mlp_hidden_size=1024, num_attention_heads=8, @@ -60,9 +60,9 @@ def __call__(self, x, mask, disable_dropout=False): x = x.reshape(x.shape[0], -1) - x = te.DenseGeneral(features=256, dtype=jnp.bfloat16)(x) + x = te.flax.DenseGeneral(features=256, dtype=jnp.bfloat16)(x) - x = te.DenseGeneral(features=256, dtype=jnp.bfloat16)(x) + x = te.flax.DenseGeneral(features=256, dtype=jnp.bfloat16)(x) x = nn.Dense(features=2, dtype=jnp.bfloat16)(x) return x diff --git a/examples/jax/mnist/test_single_gpu_mnist.py b/examples/jax/mnist/test_single_gpu_mnist.py index 0b16dd8b98..3b8e2d0bd9 100644 --- a/examples/jax/mnist/test_single_gpu_mnist.py +++ b/examples/jax/mnist/test_single_gpu_mnist.py @@ -47,7 +47,7 @@ class Net(nn.Module): @nn.compact def __call__(self, x, disable_dropout=False): if self.use_te: - nn_Dense = te.DenseGeneral + nn_Dense = te.flax.DenseGeneral else: nn_Dense = nn.Dense diff --git a/tests/jax/test_layer.py b/tests/jax/test_layer.py index 1522fa198b..c959f7abcf 100644 --- a/tests/jax/test_layer.py +++ b/tests/jax/test_layer.py @@ -10,7 +10,7 @@ import pytest from transformer_engine.common.recipe import Format -from transformer_engine.jax import TransformerLayer, TransformerLayerType +from transformer_engine.jax.flax import TransformerLayer, TransformerLayerType from transformer_engine.jax.fp8 import FP8Helper from utils import assert_allclose, is_fp8_supported from utils import DecoderLayer as RefDecoderLayer diff --git a/tests/jax/test_sharding.py b/tests/jax/test_sharding.py index 458e10ffac..cd135752c0 100644 --- a/tests/jax/test_sharding.py +++ b/tests/jax/test_sharding.py @@ -7,7 +7,7 @@ import pytest from jax.experimental import maps -from transformer_engine.jax import extend_logical_axis_rules +from transformer_engine.jax.flax import extend_logical_axis_rules from transformer_engine.jax.sharding import get_dot_sharding_meta from transformer_engine.jax.sharding import get_elementwise_sharding_meta from transformer_engine.jax.sharding import get_fp8_meta_sharding_meta diff --git a/transformer_engine/common/utils.py b/transformer_engine/common/utils.py new file mode 100644 index 0000000000..cf35108673 --- /dev/null +++ b/transformer_engine/common/utils.py @@ -0,0 +1,53 @@ +# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# +# See LICENSE for license information. +"""The utilities for Transformer Engine""" +import inspect +import warnings +from enum import Enum + +warnings.simplefilter('default') + + +class DeprecatedEnum: # pylint: disable=too-few-public-methods + """DeprecatedEnum""" + + def __init__(self, enum_cls, msg): + self.enum_cls = enum_cls + self.msg = msg + + def __iter__(self): + return iter(list(self.enum_cls.__members__.values())) + + def __getattr__(self, name): + if name in self.enum_cls.__members__: + warnings.warn(self.msg, DeprecationWarning) + return self.enum_cls.__members__[name] + raise AttributeError(f"{self.enum_cls} does not contain {name}") + + +def deprecate_wrapper(obj, msg): + """Deprecate wrapper""" + if inspect.isclass(obj): + if issubclass(obj, Enum): + return DeprecatedEnum(obj, msg) + + class DeprecatedCls(obj): # pylint: disable=too-few-public-methods + """DeprecatedCls""" + + def __init__(self, *args, **kwargs): + warnings.warn(msg, DeprecationWarning) + super().__init__(*args, **kwargs) + + return DeprecatedCls + + if inspect.isfunction(obj): + + def deprecated(*args, **kwargs): + warnings.warn(msg, DeprecationWarning) + return obj(*args, **kwargs) + + return deprecated + + raise NotImplementedError( + f"deprecate_cls_wrapper only support Class and Function, but got {type(obj)}.") diff --git a/transformer_engine/jax/__init__.py b/transformer_engine/jax/__init__.py index 750a34fb5b..9b7c2f224f 100644 --- a/transformer_engine/jax/__init__.py +++ b/transformer_engine/jax/__init__.py @@ -2,10 +2,41 @@ # # See LICENSE for license information. """Transformer Engine bindings for JAX""" + +from . import flax from .fp8 import fp8_autocast, update_collections, update_fp8_metas, get_delayed_scaling -from .module import DenseGeneral, LayerNorm -from .module import LayerNormDenseGeneral, LayerNormMLP, TransformerEngineBase -from .transformer import extend_logical_axis_rules -from .transformer import MultiHeadAttention, RelativePositionBiases -from .transformer import TransformerLayer, TransformerLayerType from .sharding import MajorShardingType, ShardingResource, ShardingType +from ..common.utils import deprecate_wrapper + +extend_logical_axis_rules = deprecate_wrapper( + flax.extend_logical_axis_rules, + "extend_logical_axis_rules is moving to transformer_engine.jax.flax module") +DenseGeneral = deprecate_wrapper(flax.DenseGeneral, + "DenseGeneral is moving to transformer_engine.jax.flax module") +LayerNorm = deprecate_wrapper(flax.LayerNorm, + "LayerNorm is moving to transformer_engine.jax.flax module") +LayerNormDenseGeneral = deprecate_wrapper( + flax.LayerNormDenseGeneral, + "LayerNormDenseGeneral is moving to transformer_engine.jax.flax module") +LayerNormMLP = deprecate_wrapper(flax.LayerNormMLP, + "LayerNormMLP is moving to transformer_engine.jax.flax module") +TransformerEngineBase = deprecate_wrapper( + flax.TransformerEngineBase, + "TransformerEngineBase is moving to transformer_engine.jax.flax module") +MultiHeadAttention = deprecate_wrapper( + flax.MultiHeadAttention, "MultiHeadAttention is moving to transformer_engine.jax.flax module") +RelativePositionBiases = deprecate_wrapper( + flax.RelativePositionBiases, + "RelativePositionBiases is moving to transformer_engine.jax.flax module") +TransformerLayer = deprecate_wrapper( + flax.TransformerLayer, "TransformerLayer is moving to transformer_engine.jax.flax module") +TransformerLayerType = deprecate_wrapper( + flax.TransformerLayerType, + "TransformerLayerType is moving to transformer_engine.jax.flax module") + 
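+# The deprecated aliases above keep old imports working: calling or instantiating
+# one of them (or accessing a member of the wrapped enum) emits a
+# DeprecationWarning that points users to the transformer_engine.jax.flax module.
+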
+__all__ = [ + 'fp8_autocast', 'update_collections', 'update_fp8_metas', 'get_delayed_scaling', + 'MajorShardingType', 'ShardingResource', 'ShardingType', 'flax', 'DenseGeneral', 'LayerNorm', + 'LayerNormDenseGeneral', 'LayerNormMLP', 'TransformerEngineBase', 'MultiHeadAttention', + 'RelativePositionBiases', 'TransformerLayer', 'TransformerLayerType' +] diff --git a/transformer_engine/jax/flax/__init__.py b/transformer_engine/jax/flax/__init__.py new file mode 100644 index 0000000000..5dd8f9bdf1 --- /dev/null +++ b/transformer_engine/jax/flax/__init__.py @@ -0,0 +1,9 @@ +# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. +"""Transformer Engine bindings for JAX""" +from .module import DenseGeneral, LayerNorm +from .module import LayerNormDenseGeneral, LayerNormMLP, TransformerEngineBase +from .transformer import extend_logical_axis_rules +from .transformer import MultiHeadAttention, RelativePositionBiases +from .transformer import TransformerLayer, TransformerLayerType diff --git a/transformer_engine/jax/module.py b/transformer_engine/jax/flax/module.py similarity index 97% rename from transformer_engine/jax/module.py rename to transformer_engine/jax/flax/module.py index af96b95ada..f9924c600f 100644 --- a/transformer_engine/jax/module.py +++ b/transformer_engine/jax/flax/module.py @@ -16,15 +16,15 @@ from jax import nn as jax_nn from jax import random as jax_random -from .dot import fp8_dot -from .fp8 import FP8GemmPackage, FP8Helper -from .layernorm import canonicalize_layernorm_type -from .layernorm import layernorm, layernorm_fp8_dot -from .mlp import fp8_ln_mlp, geglu -from .sharding import infer_sharding_type -from .softmax import is_softmax_kernel_available -from .sharding import MajorShardingType, ShardingType -from .softmax import softmax, SoftmaxType +from ..dot import fp8_dot +from ..fp8 import FP8GemmPackage, FP8Helper +from ..layernorm import canonicalize_layernorm_type +from ..layernorm import layernorm, layernorm_fp8_dot +from ..mlp import fp8_ln_mlp, geglu +from ..sharding import infer_sharding_type +from ..softmax import is_softmax_kernel_available +from ..sharding import MajorShardingType, ShardingType +from ..softmax import softmax, SoftmaxType PRNGKey = Any Shape = Tuple[int, ...] 
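A note on the compatibility shim above: the aliases re-exported from transformer_engine/jax/__init__.py are produced by the deprecate_wrapper helper added in transformer_engine/common/utils.py. Classes are subclassed so that instantiation emits a DeprecationWarning, plain functions are wrapped, and Enum types are proxied through DeprecatedEnum (Enums with members cannot be subclassed). A small sketch of how a wrapped function behaves for callers; old_api and new_api are made-up names used only for illustration:

    import warnings

    # deprecate_wrapper is the helper introduced by this patch.
    from transformer_engine.common.utils import deprecate_wrapper

    def new_api(x):
        return x + 1

    old_api = deprecate_wrapper(new_api, "old_api has moved; call new_api instead")

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        assert old_api(1) == 2    # the wrapped function still works as before
        assert issubclass(caught[-1].category, DeprecationWarning)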
@@ -46,6 +46,13 @@ def _canonicalize_tuple(x): return (x,) +def _obtain_default_layernorm_scale_init_if_need(original_init, zero_centered_gamma): + if original_init is None: + if not zero_centered_gamma: + return nn.initializers.ones + return nn.initializers.zeros + + def _create_layernorm_parameters(layernorm_type, shape, scale_init, scale_axes, bias_init, bias_axes, dtype): scale = nn_partitioning.param_with_axes('scale', @@ -250,11 +257,8 @@ class LayerNorm(nn.Module): sharding_type: ShardingType = ShardingType.SINGLE def __post_init__(self): - if self.scale_init is None: - if not self.zero_centered_gamma: - self.scale_init = nn.initializers.ones - else: - self.scale_init = nn.initializers.zeros + self.scale_init = _obtain_default_layernorm_scale_init_if_need( + self.scale_init, self.zero_centered_gamma) super().__post_init__() @nn.compact @@ -549,11 +553,8 @@ class LayerNormDenseGeneral(TransformerEngineBase): def __post_init__(self): if self.kernel_init is None: self.kernel_init = nn.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal') - if self.scale_init is None: - if not self.zero_centered_gamma: - self.scale_init = nn.initializers.ones - else: - self.scale_init = nn.initializers.zeros + self.scale_init = _obtain_default_layernorm_scale_init_if_need( + self.scale_init, self.zero_centered_gamma) super().__post_init__() @nn.compact @@ -781,11 +782,8 @@ class LayerNormMLP(TransformerEngineBase): def __post_init__(self): if self.kernel_init is None: self.kernel_init = nn.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal') - if self.scale_init is None: - if not self.zero_centered_gamma: - self.scale_init = nn.initializers.ones - else: - self.scale_init = nn.initializers.zeros + self.scale_init = _obtain_default_layernorm_scale_init_if_need( + self.scale_init, self.zero_centered_gamma) super().__post_init__() @nn.compact diff --git a/transformer_engine/jax/transformer.py b/transformer_engine/jax/flax/transformer.py similarity index 99% rename from transformer_engine/jax/transformer.py rename to transformer_engine/jax/flax/transformer.py index 2ec33cf5b6..aaecab7b51 100644 --- a/transformer_engine/jax/transformer.py +++ b/transformer_engine/jax/flax/transformer.py @@ -18,9 +18,9 @@ from .module import DenseGeneral, LayerNormDenseGeneral, LayerNormMLP from .module import LayerNorm, Softmax -from .softmax import SoftmaxType -from .sharding import infer_major_sharding_type, infer_sharding_type -from .sharding import global_shard_resource, ShardingType +from ..softmax import SoftmaxType +from ..sharding import infer_major_sharding_type, infer_sharding_type +from ..sharding import global_shard_resource, ShardingType PRNGKey = Any Shape = Tuple[int, ...] From 550da28957304219a25deffcef007e88ec86ba10 Mon Sep 17 00:00:00 2001 From: cyanguwa <8636796+cyanguwa@users.noreply.github.com> Date: Sat, 29 Apr 2023 10:41:14 -0700 Subject: [PATCH 27/68] Correct cuDNN version requirement (#184) correct cuDNN version requirement Signed-off-by: Charlene Yang --- docs/installation.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/installation.rst b/docs/installation.rst index 9aded82d0f..89f9fd549d 100644 --- a/docs/installation.rst +++ b/docs/installation.rst @@ -14,7 +14,7 @@ Prerequisites 1. Linux x86_64 2. `CUDA 11.8 `__ 3. |driver link|_ supporting CUDA 11.8 or later. -4. `cuDNN 8 `__ or later. +4. `cuDNN 8.1 `__ or later. 5. For FP8 fused attention, `CUDA 12.1 `__ or later, |driver link|_ supporting CUDA 12.1 or later, and `cuDNN 8.9 `__ or later. 
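One detail worth flagging in the _obtain_default_layernorm_scale_init_if_need helper introduced by the jax.flax refactor above: as it appears in this hunk, the function only returns a value when original_init is None, so the unconditional reassignment in __post_init__ would replace a caller-supplied scale_init with None. A defensive variant that preserves an explicit initializer could look like the sketch below; _default_layernorm_scale_init is a hypothetical name, not the committed helper:

    from flax import linen as nn

    def _default_layernorm_scale_init(original_init, zero_centered_gamma):
        # Keep an explicitly provided initializer untouched.
        if original_init is not None:
            return original_init
        # With zero_centered_gamma the learned scale is an offset from 1,
        # so it starts at zeros; the conventional parameterization starts at ones.
        return nn.initializers.zeros if zero_centered_gamma else nn.initializers.ones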
From d3d419117f28af637968c7c1f175656eb72ec94d Mon Sep 17 00:00:00 2001 From: Sangkug Lym Date: Tue, 2 May 2023 07:20:52 -0700 Subject: [PATCH 28/68] Use separate streams for pushsend/recv kernels in UB p2p exchanges (#188) * using different strems for pushsend and pushrecv Signed-off-by: Sangkug Lym * fix stream dependency Signed-off-by: Sangkug Lym * add wait from main_stream to memcpy stream Signed-off-by: Sangkug Lym --------- Signed-off-by: Sangkug Lym Co-authored-by: Kirthi Shankar Sivamani --- .../pytorch/csrc/comm_gemm_overlap.h | 53 +++++++++++-------- 1 file changed, 30 insertions(+), 23 deletions(-) diff --git a/transformer_engine/pytorch/csrc/comm_gemm_overlap.h b/transformer_engine/pytorch/csrc/comm_gemm_overlap.h index 1e8b96f46b..5dd71e4758 100644 --- a/transformer_engine/pytorch/csrc/comm_gemm_overlap.h +++ b/transformer_engine/pytorch/csrc/comm_gemm_overlap.h @@ -332,9 +332,10 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder { void *_ubuf_ptr; torch::Tensor _ubuf; std::vector _ubufs; - at::cuda::CUDAStream _stream_comm = at::cuda::getStreamFromPool(true); + at::cuda::CUDAStream _stream_send = at::cuda::getStreamFromPool(true); + at::cuda::CUDAStream _stream_recv = at::cuda::getStreamFromPool(true); std::vector _stream_compute; - cudaEvent_t _start_compute, _stop_compute, _start_comm, _stop_comm, _start_accum, _stop_accum; + cudaEvent_t _start_compute, _stop_compute, _stop_send, _stop_recv; UbufP2PCommOverlap(torch::Tensor sample, int rank, int tp_size, bool aggregate2, int num_max_streams) { @@ -385,10 +386,8 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder { // CUDA event creation cudaEventCreateWithFlags(&_start_compute, 0); cudaEventCreateWithFlags(&_stop_compute, 0); - cudaEventCreateWithFlags(&_start_comm, 0); - cudaEventCreateWithFlags(&_stop_comm, 0); - cudaEventCreateWithFlags(&_start_accum, 0); - cudaEventCreateWithFlags(&_stop_accum, 0); + cudaEventCreateWithFlags(&_stop_send, 0); + cudaEventCreateWithFlags(&_stop_recv, 0); } /* @@ -430,7 +429,8 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder { assert(pre_gelu_out.numel() == 0); if (_aggregate2) { // Catch up the default torch stream - CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_comm, _start_compute, 0)); + CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_send, _start_compute, 0)); + CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_recv, _start_compute, 0)); const int num_steps = _tp_size / 2; char *input_b_ptr = reinterpret_cast(_ubuf.data_ptr()); @@ -442,11 +442,12 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder { int recv_offset = comm_bytes * recv_chunk_id; int peer_rank = (_tp_id % 2 == 0) ? _next_rank : _prev_rank; userbuffers_send(_ub_reg, send_offset, _ub_reg, send_offset, comm_bytes, _ub_comm, peer_rank, - (cudaStream_t)_stream_comm); + (cudaStream_t)_stream_send); userbuffers_recv(_ub_reg, recv_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm, peer_rank, - (cudaStream_t)_stream_comm); - CHECK_CUDA(cudaEventRecord(_start_compute, (cudaStream_t)_stream_comm)); - CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_compute[0], _start_compute, 0)); + (cudaStream_t)_stream_recv); + CHECK_CUDA(cudaEventRecord(_stop_recv, (cudaStream_t)_stream_recv)); + CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_send, _stop_recv, 0)); + CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_compute[0], _stop_recv, 0)); int local_rank_round2 = (_tp_id % 2 == 0) ? 
_tp_id : _tp_id - 1; const int next_rank = (_tp_size + _tp_id + 2) % _tp_size + _rank_round_tp; @@ -476,18 +477,21 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder { if (i < num_steps - 1) { // P2P communication userbuffers_send(_ub_reg, send_offset, _ub_reg, send_offset, comm_bytes * 2, _ub_comm, - next_rank, (cudaStream_t)_stream_comm); + next_rank, (cudaStream_t)_stream_send); userbuffers_recv(_ub_reg, recv_offset, _ub_reg, recv_offset, comm_bytes * 2, _ub_comm, - prev_rank, (cudaStream_t)_stream_comm); - CHECK_CUDA(cudaEventRecord(_stop_comm, (cudaStream_t)_stream_comm)); + prev_rank, (cudaStream_t)_stream_recv); + CHECK_CUDA(cudaEventRecord(_stop_recv, (cudaStream_t)_stream_recv)); + CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_send, _stop_recv, 0)); CHECK_CUDA(cudaStreamWaitEvent( - (cudaStream_t)_stream_compute[(i + 1) % _stream_compute.size()], _stop_comm, 0)); + (cudaStream_t)_stream_compute[(i + 1) % _stream_compute.size()], _stop_recv, 0)); } else if (B_copy.numel() > 0) { assert(B_copy.numel() == _ubufs[_tp_id].numel()); assert(B_copy.element_size() == _ubufs[_tp_id].element_size()); CHECK_CUDA(cudaMemcpyAsync(B_copy.data_ptr(), _ubufs[_tp_id].data_ptr(), _ubufs[_tp_id].numel() * _ubufs[_tp_id].element_size(), - cudaMemcpyDeviceToDevice, (cudaStream_t)_stream_comm)); + cudaMemcpyDeviceToDevice, (cudaStream_t)_stream_send)); + CHECK_CUDA(cudaEventRecord(_stop_send, (cudaStream_t)_stream_send)); + CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_send, 0)); } } at::cuda::setCurrentCUDAStream(stream_main); @@ -497,7 +501,8 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder { cudaEventRecord(_stop_compute, (cudaStream_t)_stream_compute[last_compute_stream_id])); } else { // Catch up the default torch stream - CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_comm, _start_compute, 0)); + CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_send, _start_compute, 0)); + CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_recv, _start_compute, 0)); CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_compute[0], _start_compute, 0)); for (int i = 0; i < _tp_size; i++) { @@ -524,18 +529,21 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder { if (i < _tp_size - 1) { // P2P communication userbuffers_send(_ub_reg, send_offset, _ub_reg, send_offset, comm_bytes, _ub_comm, - _next_rank, (cudaStream_t)_stream_comm); + _next_rank, (cudaStream_t)_stream_send); userbuffers_recv(_ub_reg, recv_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm, - _prev_rank, (cudaStream_t)_stream_comm); - CHECK_CUDA(cudaEventRecord(_stop_comm, (cudaStream_t)_stream_comm)); + _prev_rank, (cudaStream_t)_stream_recv); + CHECK_CUDA(cudaEventRecord(_stop_recv, (cudaStream_t)_stream_recv)); + CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_send, _stop_recv, 0)); CHECK_CUDA(cudaStreamWaitEvent( - (cudaStream_t)_stream_compute[(i + 1) % _stream_compute.size()], _stop_comm, 0)); + (cudaStream_t)_stream_compute[(i + 1) % _stream_compute.size()], _stop_recv, 0)); } else if (B_copy.numel() > 0) { assert(B_copy.numel() == _ubufs[_tp_id].numel()); assert(B_copy.element_size() == _ubufs[_tp_id].element_size()); CHECK_CUDA(cudaMemcpyAsync(B_copy.data_ptr(), _ubufs[_tp_id].data_ptr(), _ubufs[_tp_id].numel() * _ubufs[_tp_id].element_size(), - cudaMemcpyDeviceToDevice, (cudaStream_t)_stream_comm)); + cudaMemcpyDeviceToDevice, (cudaStream_t)_stream_send)); + CHECK_CUDA(cudaEventRecord(_stop_send, (cudaStream_t)_stream_send)); + CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, 
_stop_send, 0)); } } at::cuda::setCurrentCUDAStream(stream_main); @@ -544,7 +552,6 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder { cudaEventRecord(_stop_compute, (cudaStream_t)_stream_compute[last_compute_stream_id])); } CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_compute, 0)); - CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_comm, _stop_compute, 0)); return D; } // split_overlap_ag From 8e5f00f203ee518961bfb8febb017a2ffcc1d6b3 Mon Sep 17 00:00:00 2001 From: Shriya Palsamudram <69161273+ShriyaPalsamudram@users.noreply.github.com> Date: Wed, 10 May 2023 13:22:10 -0400 Subject: [PATCH 29/68] Shriya/tp overlap patch (#205) userbuffer pushsend/recv fix with atomicAdd_system Signed-off-by: Sangkug Lym Co-authored-by: Sangkug Lym --- transformer_engine/pytorch/csrc/userbuffers/userbuffers.cu | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/transformer_engine/pytorch/csrc/userbuffers/userbuffers.cu b/transformer_engine/pytorch/csrc/userbuffers/userbuffers.cu index 9144e9e739..2c8e9dc61d 100644 --- a/transformer_engine/pytorch/csrc/userbuffers/userbuffers.cu +++ b/transformer_engine/pytorch/csrc/userbuffers/userbuffers.cu @@ -1551,7 +1551,7 @@ __global__ void __launch_bounds__(MAX_THREADS) __threadfence_system(); atomicAdd(flagptr, 1); // otherwise need local SM sync before sending flag } else { // 0 bytes and 1 SM only - atomicAdd(flagptr, 1); + atomicAdd_system(flagptr, 1); } } @@ -1561,7 +1561,7 @@ __global__ void kuserbuffers_pushrecv(int myrank, int peer, int *recv_id, int *f volatile int *flag = (volatile int *)flagptr; if (*flag >= signal_id) return; clock_t s = clock64(); - while (*flag < signal_id) { + while (atomicAdd_system(flagptr, 0) < signal_id) { if (clock64() - s > TIMEOUT) { printf("%d from %d] pushrecv: expected %d, stuck with %d\n", myrank, peer, signal_id, *flag); return; From f92c430e56c7f74de389a2a55f79d186b06ceeb5 Mon Sep 17 00:00:00 2001 From: cyanguwa <8636796+cyanguwa@users.noreply.github.com> Date: Mon, 22 May 2023 13:55:33 -0700 Subject: [PATCH 30/68] Relax checks for attn_mask_type in FlashAttention (#226) * relax attn mask type checks for FlashAttention Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * disable flash attn if mask tensor is not None Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix the logic for flash attn Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * minor fix for lint Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --------- Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- transformer_engine/pytorch/attention.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 41b4d5fcd4..29e6412b02 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -281,9 +281,6 @@ def __init__( assert ( _flash_attn_version >= _flash_attn_version_required ), f"FlashAttention minimum version {_flash_attn_version_required} is required." - assert ( - attn_mask_type == "causal" - ), 'FlashAttention currently only supports causal attention mask.' 
self.attn_causal_mask = attn_mask_type == "causal" self.norm_factor = norm_factor @@ -296,7 +293,6 @@ def forward( query_layer: torch.Tensor, key_layer: torch.Tensor, value_layer: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, ) -> torch.Tensor: """flash-attn fprop""" @@ -308,9 +304,6 @@ def forward( assert ( query_layer.is_cuda and key_layer.is_cuda and value_layer.is_cuda ), 'FlashAttention currently only supports CUDA tensors.' - assert ( - attention_mask is None - ), 'FlashAttention currently does not support external attention mask.' # For now just 128, will make it more general in the future @@ -428,7 +421,6 @@ def __init__( self.device_compute_capability = get_device_compute_capability() self.use_flash_attention = ( int(os.getenv("NVTE_FLASH_ATTN", "1")) - and attn_mask_type == "causal" and self.device_compute_capability >= 8.0 ) @@ -437,6 +429,7 @@ def __init__( "attention_dropout_ctx": attention_dropout_ctx, "attn_mask_type": attn_mask_type, } + self.attn_mask_type = attn_mask_type if self.use_flash_attention: self.flash_attention = FlashAttention(norm_factor, **attn_kwargs) @@ -514,6 +507,9 @@ def forward( ): use_flash_attention = False + if self.attn_mask_type == "padding" and attention_mask is not None: + use_flash_attention = False + if is_in_onnx_export_mode(): use_flash_attention = False From 06cacd205e317d9ce804a87b686ada89e967912d Mon Sep 17 00:00:00 2001 From: zlsh80826 Date: Tue, 23 May 2023 13:14:32 +0800 Subject: [PATCH 31/68] Jax bug fixes for the dot product attention (#236) * Unfused scale+softmax if bias is present Signed-off-by: Reese Wang * WAR a causal masking + no_bias bug and add the unittests Signed-off-by: Reese Wang * Fix the optional args (bias) sharding Signed-off-by: Reese Wang * Disable fused attn in JAX by default, enable it with NVTE_USE_FUSED_ATTN Signed-off-by: Reese Wang * Add thread local for the plan cache Signed-off-by: Reese Wang * Rename dbeta to dbias for the readability Signed-off-by: Reese Wang * Add scaled softmax with dropout test cases Signed-off-by: Reese Wang * Updated NVTE_FUSED_ATTN variable name Signed-off-by: Reese Wang --------- Signed-off-by: Reese Wang --- tests/jax/test_fused_attn.py | 67 ++++++++++++------- tests/jax/test_layer.py | 6 ++ .../fused_attn_fp16_bf16_max_seqlen_512.cu | 8 +-- .../common/fused_attn/fused_attn_fp8.cu | 4 +- transformer_engine/jax/flax/transformer.py | 17 ++++- transformer_engine/jax/sharding.py | 4 +- 6 files changed, 71 insertions(+), 35 deletions(-) diff --git a/tests/jax/test_fused_attn.py b/tests/jax/test_fused_attn.py index fb333275bb..2504960705 100644 --- a/tests/jax/test_fused_attn.py +++ b/tests/jax/test_fused_attn.py @@ -113,7 +113,7 @@ def customcall_cross_fused_attn(q, kv, q_token, kv_token, dropout_rng, **kwargs) reason="Fused attention kernel is not supported.") class TestSelfFusedAttnMax512(): - def set_input(self, b, s, h, d, dtype, attn_mask_type, pad_ratio): + def set_input(self, b, s, h, d, dtype, attn_mask_type, pad_ratio, with_bias): key = jax.random.PRNGKey(0) subkeys = jax.random.split(key, 2) @@ -125,7 +125,8 @@ def set_input(self, b, s, h, d, dtype, attn_mask_type, pad_ratio): min_val, max_val = -1, 1 self.qkv = jax.random.uniform(subkeys[0], qkv_shape, dtype, min_val, max_val) - self.bias = jax.random.uniform(subkeys[1], bias_shape, dtype, min_val, max_val) + self.bias = jax.random.uniform(subkeys[1], bias_shape, dtype, min_val, + max_val) if with_bias else None self.q_token = jnp.concatenate((jnp.ones((b, self.valid_len)), jnp.zeros((b, pad_len))), axis=-1) 
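For context on what these tests check: the fused kernel is compared against a plain JAX reference in which the attention logits are scaled, an optional post-scale bias is added, and padded keys (derived from the q/kv token vectors) are masked out before the softmax. Because the bias is added after scaling, the scale cannot be folded into the softmax when a bias is present, which is what the flax/transformer.py change later in this patch makes explicit. A minimal sketch of such a reference, with illustrative shapes and names rather than the repository's test utilities (dropout is omitted, since the tests only compare bit-exactly at dropout 0):

    import jax
    import jax.numpy as jnp

    def reference_self_attention(qkv, bias, q_token, kv_token, scale):
        # qkv: [b, s, 3, h, d]; bias: broadcastable to [b, h, s, s] or None;
        # q_token/kv_token: [b, s] with 1 for valid tokens and 0 for padding.
        query, key, value = qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2]
        logits = jnp.einsum('bqhd,bkhd->bhqk', query, key) * scale
        if bias is not None:          # post-scale bias
            logits = logits + bias
        mask = jnp.einsum('bq,bk->bqk', q_token, kv_token)[:, None, :, :]
        logits = jnp.where(mask > 0, logits, jnp.finfo(logits.dtype).min)
        weights = jax.nn.softmax(logits, axis=-1)
        # Output shape [b, s, h, d], matching the fused kernel's layout.
        return jnp.einsum('bhqk,bkhd->bqhd', weights, value)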
@@ -133,8 +134,8 @@ def set_input(self, b, s, h, d, dtype, attn_mask_type, pad_ratio): self.scaling_factor = 1. / math.sqrt(d) self.dropout_probability = 0. - self.dropout_rng = jax.random.PRNGKey(0) - self.attn_bias_type = AttnBiasType.POST_SCALE_BIAS + self.dropout_rng = jax.random.PRNGKey(0) if self.dropout_probability > 0 else None + self.attn_bias_type = AttnBiasType.NO_BIAS if self.bias is None else AttnBiasType.POST_SCALE_BIAS # deterministic = not is_training self.deterministic = False @@ -143,9 +144,17 @@ def set_input(self, b, s, h, d, dtype, attn_mask_type, pad_ratio): @pytest.mark.parametrize('attn_mask_type', [AttnMaskType.PADDING_MASK, AttnMaskType.CAUSAL_MASK]) @pytest.mark.parametrize('pad_ratio', PAD_RATIO) - def test_forward(self, b, s, h, d, dtype, attn_mask_type, pad_ratio): + @pytest.mark.parametrize('with_bias', [True, False]) + def test_forward(self, b, s, h, d, dtype, attn_mask_type, pad_ratio, with_bias): - self.set_input(b, s, h, d, dtype=dtype, attn_mask_type=attn_mask_type, pad_ratio=pad_ratio) + self.set_input(b, + s, + h, + d, + dtype=dtype, + attn_mask_type=attn_mask_type, + pad_ratio=pad_ratio, + with_bias=with_bias) primitive_out = customcall_self_fused_attn(self.qkv, self.bias, @@ -183,8 +192,16 @@ def test_forward(self, b, s, h, d, dtype, attn_mask_type, pad_ratio): [AttnMaskType.PADDING_MASK, AttnMaskType.CAUSAL_MASK]) @pytest.mark.parametrize('dtype', DTYPES) @pytest.mark.parametrize('pad_ratio', PAD_RATIO) - def test_forward_backward(self, b, s, h, d, dtype, attn_mask_type, pad_ratio): - self.set_input(b, s, h, d, dtype=dtype, attn_mask_type=attn_mask_type, pad_ratio=pad_ratio) + @pytest.mark.parametrize('with_bias', [True, False]) + def test_forward_backward(self, b, s, h, d, dtype, attn_mask_type, pad_ratio, with_bias): + self.set_input(b, + s, + h, + d, + dtype=dtype, + attn_mask_type=attn_mask_type, + pad_ratio=pad_ratio, + with_bias=with_bias) def grad_func(fused_attn_max_512_func, *args, **kwargs): # Gradient is small, use a gradient multiplier to amplify the graident @@ -221,11 +238,11 @@ def grad_func(fused_attn_max_512_func, *args, **kwargs): (0, 1))) primitive_out, (primitive_dqkv, - primitive_dbeta) = jitted_primitive(self.qkv, self.bias, self.q_token, + primitive_dbias) = jitted_primitive(self.qkv, self.bias, self.q_token, self.kv_token, self.dropout_rng) reference_out, (reference_dqkv, - reference_dbeta) = jitted_reference(self.qkv, self.bias, self.q_token, + reference_dbias) = jitted_reference(self.qkv, self.bias, self.q_token, self.kv_token, self.dropout_rng) np.testing.assert_allclose(jnp.asarray(primitive_out, np.float32), @@ -261,20 +278,22 @@ def grad_func(fused_attn_max_512_func, *args, **kwargs): # Padded part should be 0s assert jnp.allclose(invalid_primitive_dqkv, jnp.zeros_like(invalid_primitive_dqkv)) - # dbeta valid part - np.testing.assert_allclose( - jnp.asarray(primitive_dbeta[:, :, :self.valid_len, :self.valid_len], np.float32), - jnp.asarray(reference_dbeta[:, :, :self.valid_len, :self.valid_len], np.float32), - rtol=1e-4, - atol=3e-5) - - # dbeta padded part - np.testing.assert_allclose( - jnp.asarray(primitive_dbeta[:, :, self.valid_len:, self.valid_len:], np.float32), - jnp.asarray(reference_dbeta[:, :, self.valid_len:, self.valid_len:], np.float32)) - - assert jnp.allclose(primitive_dbeta[:, :, self.valid_len:, self.valid_len:], - jnp.zeros_like(primitive_dbeta[:, :, self.valid_len:, self.valid_len:])) + if self.attn_bias_type != AttnBiasType.NO_BIAS: + # dbias valid part + np.testing.assert_allclose( + 
jnp.asarray(primitive_dbias[:, :, :self.valid_len, :self.valid_len], np.float32), + jnp.asarray(reference_dbias[:, :, :self.valid_len, :self.valid_len], np.float32), + rtol=1e-4, + atol=3e-5) + + # dbias padded part + np.testing.assert_allclose( + jnp.asarray(primitive_dbias[:, :, self.valid_len:, self.valid_len:], np.float32), + jnp.asarray(reference_dbias[:, :, self.valid_len:, self.valid_len:], np.float32)) + + assert jnp.allclose( + primitive_dbias[:, :, self.valid_len:, self.valid_len:], + jnp.zeros_like(primitive_dbias[:, :, self.valid_len:, self.valid_len:])) @pytest.mark.skipif(not is_fused_attn_kernel_available(), diff --git a/tests/jax/test_layer.py b/tests/jax/test_layer.py index 9cce15aa70..30143e5f75 100644 --- a/tests/jax/test_layer.py +++ b/tests/jax/test_layer.py @@ -102,6 +102,12 @@ def compare_frozen_dict(ref_fd, test_fd, rtol=1e-05, atol=1e-08): _KEY_OF_DROPOUT_RATE: 0.0, _KEY_OF_MLP_ACTIVATIONS: (('gelu', 'linear')), _KEY_OF_FUSE_MLP_WI: True +}, { + _KEY_OF_SCALE_ATTN_LOGITS: True, + _KEY_OF_LAYERNORM_TYPE: 'rmsnorm', + _KEY_OF_DROPOUT_RATE: 0.8, + _KEY_OF_MLP_ACTIVATIONS: (('gelu', 'linear')), + _KEY_OF_FUSE_MLP_WI: True }, { _KEY_OF_TRANSPOSE_BS: False, _KEY_OF_SCALE_ATTN_LOGITS: True, diff --git a/transformer_engine/common/fused_attn/fused_attn_fp16_bf16_max_seqlen_512.cu b/transformer_engine/common/fused_attn/fused_attn_fp16_bf16_max_seqlen_512.cu index c01018137b..53f4f72636 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp16_bf16_max_seqlen_512.cu +++ b/transformer_engine/common/fused_attn/fused_attn_fp16_bf16_max_seqlen_512.cu @@ -327,7 +327,6 @@ static cudnn_frontend::Tensor createSoftmaxForward( // NOLINTNEXTLINE(runtime/references) std::vector &ops, cudnn_frontend::Tensor const &prevBlockOutputTensor) { - int64_t afterBMM1_dim[4] = {b, h, s_q, s_kv}; int64_t afterBMM1_stride[4] = {h * s_q * s_kv, s_q * s_kv, s_kv, 1}; @@ -645,7 +644,7 @@ void fused_attn_max_512_fwd_impl(int64_t b, int64_t h, int64_t s_q, int64_t s_kv mask_type, tensorType}; using CacheType = std::map; - static CacheType fmha_fprop_cache; + static thread_local CacheType fmha_fprop_cache; bool enable_dropout = (dropout_probability != 0.0f); @@ -668,7 +667,8 @@ void fused_attn_max_512_fwd_impl(int64_t b, int64_t h, int64_t s_q, int64_t s_kv createScale(b, h, s_q, s_kv, d, layout, tensorType, ops); // if bias, we need to memset the S buffer to correctly computate dbias - auto zero_s = (bias_type != NVTE_Bias_Type::NVTE_NO_BIAS); + auto zero_s = (bias_type != NVTE_Bias_Type::NVTE_NO_BIAS) || + (mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK); auto bmm1_output = createBMM1(b, h, s_q, s_kv, d, layout, tensorType, zero_s, ops); NVTE_CHECK(bias_type != NVTE_Bias_Type::NVTE_PRE_SCALE_BIAS, @@ -814,7 +814,7 @@ void fused_attn_max_512_bwd_impl(int64_t b, int64_t h, int64_t s_q, int64_t s_kv layout, bias_type, mask_type, tensorType}; using CacheType = std::map; - static CacheType fmha_bprop_cache; + static thread_local CacheType fmha_bprop_cache; auto get_plan = [&](CacheType &cache, const FADescriptor &descriptor) { auto it = cache.find(descriptor); diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.cu b/transformer_engine/common/fused_attn/fused_attn_fp8.cu index be483b8af5..768ac8eb20 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.cu +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.cu @@ -1016,7 +1016,7 @@ void fa_fwd_fp8(int64_t b, int64_t s_q, int64_t s_kv, int64_t h, int64_t d, NVTE_Bias_Type::NVTE_NO_BIAS, NVTE_Mask_Type::NVTE_PADDING_MASK, 
tensorType}; using CacheType = std::map; - static CacheType fa_fprop_cache; + static thread_local CacheType fa_fprop_cache; // Get plan from cache if cache is available, otherwise create one auto get_plan = [&](CacheType &cache, const FADescriptor &descriptor) { @@ -1332,7 +1332,7 @@ void fa_bwd_fp8(int64_t b, int64_t s_q, int64_t s_kv, int64_t h, int64_t d, NVTE_Bias_Type::NVTE_NO_BIAS, NVTE_Mask_Type::NVTE_PADDING_MASK, tensorType}; using CacheType = std::map; - static CacheType fa_bprop_cache; + static thread_local CacheType fa_bprop_cache; // Get plan from cache if cache is available, otherwise create one auto get_plan = [&](CacheType &cache, const FADescriptor &descriptor) { diff --git a/transformer_engine/jax/flax/transformer.py b/transformer_engine/jax/flax/transformer.py index a6b9f92b6f..3b4a61f3aa 100644 --- a/transformer_engine/jax/flax/transformer.py +++ b/transformer_engine/jax/flax/transformer.py @@ -7,6 +7,7 @@ import functools from enum import Enum from math import sqrt +import os from typing import Any, Callable, Optional, Sequence, Tuple, Union import warnings @@ -165,8 +166,17 @@ def core_attention(query: Array, else: attn_weights = jnp.einsum('bqhd,bkhd->bhqk', query, key) + # When a bias is present, the computation is performed as Softmax(attn_weights * scale + bias). + # In this case, the scale can not fused into the Softmax module. + if bias is not None: + attn_weights = attn_weights * scale_factor + fused_scale_factor = 1. + else: + # If no bias, the scale can be fused into Softmax module + fused_scale_factor = scale_factor + attn_weights = Softmax(softmax_type=softmax_type, - scale_factor=scale_factor, + scale_factor=fused_scale_factor, sharding_type=softmax_sharding_type)(attn_weights, mask, bias) if not deterministic and dropout_rate > 0.: @@ -360,12 +370,13 @@ def kv_init(key, shape, dtype): q_seqlen = inputs_q.shape[0] if self.transpose_batch_sequence else inputs_q.shape[1] kv_seqlen = inputs_kv.shape[0] if self.transpose_batch_sequence else inputs_kv.shape[1] fused_attn_supported_seqlen = [128, 256, 384, 512] + enable_fused_attn = int(os.getenv("NVTE_FUSED_ATTN", "0")) use_fused_attn = not decode and not self.transpose_batch_sequence and self.fuse_qkv and \ self.dropout_rate == 0 and canonicalize_dtype in [jnp.bfloat16, jnp.float16] and \ q_seqlen in fused_attn_supported_seqlen and kv_seqlen in fused_attn_supported_seqlen \ - and is_fused_attn_kernel_available() + and is_fused_attn_kernel_available() and enable_fused_attn - if not use_fused_attn: + if enable_fused_attn and not use_fused_attn: reason = "" if decode: reason += f"decode=False is required but got {decode}, " diff --git a/transformer_engine/jax/sharding.py b/transformer_engine/jax/sharding.py index 939072cfd4..f93a3c0983 100644 --- a/transformer_engine/jax/sharding.py +++ b/transformer_engine/jax/sharding.py @@ -386,7 +386,7 @@ def _get_dptp_sharding_meta(input_shapes: Tuple[Tuple[int, ...]], for input_shape, dp_dim, tp_dim in zip(input_shapes, input_dp_dims, input_tp_dims): in_axis = {} - if dp_dim is not None: + if dp_dim is not None and input_shape is not None: in_axis[dp_dim] = dp_axis_name assert input_shape[dp_dim] % dp_size == 0, \ f"The dimension of batch in input_shape should be a multiple of " \ @@ -398,7 +398,7 @@ def _get_dptp_sharding_meta(input_shapes: Tuple[Tuple[int, ...]], if tp_dim is not None and tp_dim >= dp_dim: tp_dim = tp_dim + 1 - if tp_dim is not None: + if tp_dim is not None and input_shape is not None: in_axis[tp_dim] = tp_axis_name assert input_shape[tp_dim] % tp_size == 
0, \ f"The dimension of tensor parallel in input_shape should be a multiple of " \ From 84a4a7504221e671efdf9d582d994250c3cdf465 Mon Sep 17 00:00:00 2001 From: zlsh80826 Date: Tue, 20 Jun 2023 23:44:44 +0800 Subject: [PATCH 32/68] [JAX] Add self_attn_mask_type and replace attn_type (#273) * Add self_attn_mask_type and replace attn_type Signed-off-by: Reese Wang * Refine the keyword style for the better readability Signed-off-by: Reese Wang * Replace attn_type with attn_mask_type in praxis transformer Signed-off-by: Reese Wang * Fix typos Signed-off-by: Reese Wang --------- Signed-off-by: Reese Wang Co-authored-by: Kirthi Shankar Sivamani --- tests/jax/test_praxis_layers.py | 13 ++- transformer_engine/jax/flax/transformer.py | 105 +++++++++++++++---- transformer_engine/jax/praxis/transformer.py | 12 ++- 3 files changed, 96 insertions(+), 34 deletions(-) diff --git a/tests/jax/test_praxis_layers.py b/tests/jax/test_praxis_layers.py index 3adec948bd..de44b3a163 100644 --- a/tests/jax/test_praxis_layers.py +++ b/tests/jax/test_praxis_layers.py @@ -20,7 +20,6 @@ from transformer_engine.jax.flax import RelativePositionBiases as flax_RelativePositionBiases from transformer_engine.jax.flax import TransformerLayer as flax_TransformerLayer from transformer_engine.jax.flax.module import Softmax -from transformer_engine.jax.flax.transformer import AttentionType from transformer_engine.jax.fp8 import FP8Helper, is_fp8_available from transformer_engine.jax.praxis import LayerNorm from transformer_engine.jax.praxis import FusedSoftmax, LayerNorm @@ -666,32 +665,32 @@ class MultiHeadAttnAttr: USE_BIAS: True, LN_TYPE: 'layernorm', ZERO_CEN: False, - ATTN_TYPE: AttentionType.PADDING + ATTN_TYPE: 'padding' }, { USE_BIAS: True, LN_TYPE: 'layernorm', ZERO_CEN: True, - ATTN_TYPE: AttentionType.PADDING + ATTN_TYPE: 'padding' }, { USE_BIAS: True, LN_TYPE: 'rmsnorm', ZERO_CEN: False, - ATTN_TYPE: AttentionType.PADDING + ATTN_TYPE: 'padding' }, { USE_BIAS: True, LN_TYPE: 'layernorm', ZERO_CEN: False, - ATTN_TYPE: AttentionType.CAUSAL + ATTN_TYPE: 'causal' }, { USE_BIAS: True, LN_TYPE: 'layernorm', ZERO_CEN: True, - ATTN_TYPE: AttentionType.CAUSAL + ATTN_TYPE: 'causal' }, { USE_BIAS: True, LN_TYPE: 'rmsnorm', ZERO_CEN: False, - ATTN_TYPE: AttentionType.CAUSAL + ATTN_TYPE: 'causal' }] diff --git a/transformer_engine/jax/flax/transformer.py b/transformer_engine/jax/flax/transformer.py index c8a949c90e..563b15d526 100644 --- a/transformer_engine/jax/flax/transformer.py +++ b/transformer_engine/jax/flax/transformer.py @@ -197,17 +197,16 @@ def core_attention(query: Array, dynamic_vector_slice_in_dim = vmap(lax.dynamic_slice_in_dim, in_axes=(None, 0, None, None)) -class AttentionType(Enum): - """TransformerLayerType.""" - PADDING = AttnMaskType.PADDING_MASK - CAUSAL = AttnMaskType.CAUSAL_MASK - - class MultiHeadAttention(nn.Module): r""" Multi-head Attention (MHA), including Query, Key, Value and Output projection. + .. warning:: + + Argument :attr:`attn_type` is deprecated and superseded by :attr:`attn_mask_type`. + :attr:`attn_type` is ignored in version 0.10 and will be fully removed in version 0.11. + Parameters ---------- head_dim : int @@ -245,8 +244,11 @@ class MultiHeadAttention(nn.Module): Indicate if apply residual connection with the output of layer normalization. output_layernorm : bool, default = False Indicate if apply a layer normalization at the end of MHA. - attn_type: AttentionType, defult = AttentionType.PADDING - Indicate the format of the attention mask in the core attention. 
+ attn_type: Any, defult = None + *Deprecated*, will be ignored in v0.10 and be fully removed in v0.11. + Please use `attn_mask_type` to config the attention mask. + attn_mask_type: {'causal', 'padding'}, default = 'causal' + Type of attention mask passed into softmax operation. Optimization parameters ----------------------- @@ -282,7 +284,9 @@ class MultiHeadAttention(nn.Module): bias_init: Initializer = nn.initializers.zeros apply_residual_connection_post_layernorm: bool = False output_layernorm: bool = False - attn_type: AttentionType = AttentionType.PADDING + # TODO(rewang): remove attn_type and the related doc after v0.11 + attn_type: Any = None + attn_mask_type: str = 'causal' dtype: DType = jnp.float32 fuse_qkv: bool = True transpose_batch_sequence: bool = True @@ -293,6 +297,14 @@ class MultiHeadAttention(nn.Module): def __post_init__(self): if self.kernel_init is None: self.kernel_init = nn.initializers.variance_scaling(1.0, 'fan_in', 'normal') + # TODO(rewang): remove attn_type after v0.11 + if self.attn_type is not None: + warnings.warn( + "The 'attn_type' argument in the 'MultiHeadAttention' is" + " deprecated in version 0.10 and will be removed in version 0.11." + " Passing value in attn_type will be ignored, please use `attn_mask_type`" + " to config the attention mask type.", + category=DeprecationWarning) super().__post_init__() @nn.compact @@ -570,9 +582,23 @@ def kv_init(key, shape, dtype): if use_fused_attn: assert mask is not None and mask.ndim == 4 # (b, 1, s_q, s_kv) assert not self.transpose_batch_sequence + # TODO(rewang): make it configurable for pre_scale_bias attn_bias_type = AttnBiasType.NO_BIAS if bias is None else AttnBiasType.POST_SCALE_BIAS + def canonicalize_attn_mask_type(attn_mask_type): + """ + Convert the string to AttnMaskType + """ + if attn_mask_type == 'causal': + return AttnMaskType.CAUSAL_MASK + if attn_mask_type == 'padding': + return AttnMaskType.PADDING_MASK + raise ValueError(f"Unsupported {attn_mask_type=}, " + "supported attn_mask_type = {'causal', 'padding'}") + + attn_mask_type = canonicalize_attn_mask_type(self.attn_mask_type) + if inputs_q is inputs_kv: qkv_proj = qkv_proj.reshape((*qkv_proj.shape[:-1], self.num_heads, self.head_dim)) qkv_sharding_constraint = ('batch', 'length', 'qkv_dim', 'heads', 'kv') @@ -583,7 +609,7 @@ def kv_init(key, shape, dtype): mask, dropout_rng, attn_bias_type=attn_bias_type, - attn_mask_type=self.attn_type.value, + attn_mask_type=attn_mask_type, scaling_factor=scale_factor, dropout_probability=self.dropout_rate, is_training=not deterministic, @@ -602,18 +628,27 @@ def kv_init(key, shape, dtype): mask, dropout_rng, attn_bias_type=attn_bias_type, - attn_mask_type=self.attn_type.value, + attn_mask_type=attn_mask_type, scaling_factor=scale_factor, dropout_probability=self.dropout_rate, is_training=not deterministic, sharding_type=first_sharding_type) else: - softmax_type = SoftmaxType.SCALED - if self.attn_type is AttentionType.PADDING: - if mask is not None: - softmax_type = SoftmaxType.SCALED_MASKED - else: - softmax_type = SoftmaxType.SCALED_UPPER_TRIANG_MASKED + + def convert_to_softmax_type(attn_mask_type, mask): + """ + Convert the string to SoftmaxType + """ + if attn_mask_type == 'causal': + return SoftmaxType.SCALED_UPPER_TRIANG_MASKED + if attn_mask_type == 'padding': + if mask is not None: + return SoftmaxType.SCALED_MASKED + return SoftmaxType.SCALED + raise ValueError(f"Unsupported {attn_mask_type=}, " + "supported attn_mask_type = {'causal', 'padding'}") + + softmax_type = 
convert_to_softmax_type(self.attn_mask_type, mask) x = core_attention(query, key, @@ -765,6 +800,18 @@ class TransformerLayer(nn.Module): an attention block and a feedforward network (MLP). This standard layer is based on the paper “Attention Is All You Need”. + .. warning:: + + Argument :attr:`self_attn_mask_type` is introduced in version 0.10. + Starting from version 0.11, the default value will be `"causal"`. + However, to ensure compatibility with earlier versions, before 0.11, + the default value will be `"padding"` for the encoder and `"causal"` for the decoder. + + .. note:: + + Argument :attr:`attention_mask` will be ignored when + :attr:`self_attn_mask_type` is set to `"causal"`. + Parameters ---------- hidden_size: int, default = 512 @@ -825,6 +872,8 @@ class TransformerLayer(nn.Module): If set to TransformerLayerType.DECODER, an additional cross-attention block is added after self-attention.this can be used for structures like `T5` Transformer in conjunction with the TransformerLayerType.ENCODER option. + self_attn_mask_type: {'causal', 'padding'}, default = 'causal' + Type of attention mask passed into softmax operation. enable_relative_embedding: bool, default = True Whether to enable relative embedding as shifting of attention logits. relative_embedding: flax.linen.Module, default = None @@ -878,6 +927,7 @@ class TransformerLayer(nn.Module): output_layernorm: bool = False float32_attention_logits: bool = False layer_type: TransformerLayerType = TransformerLayerType.ENCODER + self_attn_mask_type: str = None # TODO(rewang): default to 'causal' after 0.11 enable_relative_embedding: bool = True relative_embedding: nn.Module = None dtype: DType = jnp.float32 @@ -893,6 +943,19 @@ def __post_init__(self): if self.mlp_kernel_init is None: self.mlp_kernel_init = nn.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal') + # TODO(rewang): default to 'causal' in 0.11 (also updated the doc after 0.11) + if self.self_attn_mask_type is None: + warnings.warn( + "The 'self_attn_mask_type' argument in the 'TransformerLayer' is" + " introduced in version 0.10. Starting from version 0.11, the default" + " value will be 'causal'. However, to ensure compatibility with earlier" + " versions, before 0.11, the default value will be 'padding' for the" + " encoder and 'causal' for the decoder.", + category=FutureWarning) + if self.layer_type == TransformerLayerType.ENCODER: + self.self_attn_mask_type = 'padding' + else: + self.self_attn_mask_type = 'causal' super().__post_init__() @nn.compact @@ -975,16 +1038,12 @@ def __call__(self, assert inputs.ndim == 3 - self_attn_type = None # Make name be the exactly same as T5X, since names would affect # RNGKey during init and apply. Myabe no need in the feature. if self.layer_type == TransformerLayerType.ENCODER: mha_name = 'attention' - self_attn_type = AttentionType.PADDING else: mha_name = 'self_attention' - self_attn_type = AttentionType.CAUSAL - assert self_attn_type is not None # [batch, length, emb_dim] -> [batch, length, emb_dim] x, residual = MultiHeadAttention( @@ -1002,7 +1061,7 @@ def __call__(self, zero_centered_gamma=self.zero_centered_gamma, apply_residual_connection_post_layernorm=self.apply_residual_connection_post_layernorm, output_layernorm=self.output_layernorm, - attn_type=self_attn_type, + attn_mask_type=self.self_attn_mask_type, fuse_qkv=self.fuse_qkv_params, kernel_init=self.mha_kernel_init, use_bias=self.use_bias, @@ -1049,7 +1108,7 @@ def hidden_dropout(x, deterministic): apply_residual_connection_post_layernorm=self. 
apply_residual_connection_post_layernorm, output_layernorm=False, # Must do LayerNorm before MHA. - attn_type=AttentionType.PADDING, + attn_mask_type='padding', float32_logits=self.float32_attention_logits, scale_attn_logits=self.scale_attn_logits, scaled_query_init=self.scaled_query_init, diff --git a/transformer_engine/jax/praxis/transformer.py b/transformer_engine/jax/praxis/transformer.py index 32facd04aa..1260c266b5 100644 --- a/transformer_engine/jax/praxis/transformer.py +++ b/transformer_engine/jax/praxis/transformer.py @@ -5,14 +5,14 @@ Praxis Modules related Transformer """ from functools import partial -from typing import Optional, Sequence, Tuple +from typing import Any, Optional, Sequence, Tuple from praxis import pax_fiddle from praxis.base_layer import WeightInit from praxis.pytypes import JTensor from .module import TransformerEngineBaseLayer -from ..flax.transformer import AttentionType, TransformerLayerType +from ..flax.transformer import TransformerLayerType from ..flax.transformer import MultiHeadAttention as flax_MultiHeadAttention from ..flax.transformer import RelativePositionBiases as flax_RelativePositionBiases from ..flax.transformer import TransformerLayer as flax_TransformerLayer @@ -73,7 +73,9 @@ class MultiHeadAttention(TransformerEngineBaseLayer): bias_init: WeightInit = WeightInit.Constant(0.0) apply_residual_connection_post_layernorm: bool = False output_layernorm: bool = False - attn_type: AttentionType = AttentionType.PADDING + # TODO(rewang): remove attn_type and the related doc after v0.11 + attn_type: Any = None + attn_mask_type: str = 'causal' fuse_qkv: bool = True transpose_batch_sequence: bool = True scale_attn_logits: bool = False @@ -99,7 +101,7 @@ def setup(self) -> None: bias_init=TransformerEngineBaseLayer.generate_params_init("bias", self.bias_init), apply_residual_connection_post_layernorm=self.apply_residual_connection_post_layernorm, output_layernorm=self.output_layernorm, - attn_type=self.attn_type, + attn_mask_type=self.attn_mask_type, fuse_qkv=self.fuse_qkv, transpose_batch_sequence=self.transpose_batch_sequence, scale_attn_logits=self.scale_attn_logits, @@ -145,6 +147,7 @@ class TransformerLayer(TransformerEngineBaseLayer): output_layernorm: bool = False float32_attention_logits: bool = False layer_type: TransformerLayerType = TransformerLayerType.ENCODER + self_attn_mask_type: str = None # TODO(rewang): default to 'causal' after 0.11 enable_relative_embedding: bool = True relative_embedding: pax_fiddle.Config[RelativePositionBiases] = pax_fiddle.template_field(None) drop_path: float = 0.0 @@ -201,6 +204,7 @@ def setup(self) -> None: output_layernorm=self.output_layernorm, float32_attention_logits=self.float32_attention_logits, layer_type=self.layer_type, + self_attn_mask_type=self.self_attn_mask_type, enable_relative_embedding=self.enable_relative_embedding, relative_embedding=relative_embedding_flax_module, drop_path=self.drop_path, From 4244ba91390a41a849a4188cc2c9a434609045dc Mon Sep 17 00:00:00 2001 From: zlsh80826 Date: Wed, 21 Jun 2023 02:45:27 +0800 Subject: [PATCH 33/68] Support dropout for the fused attention when max seqlen <= 512 (#227) * Enable fused attention dropout Signed-off-by: Reese Wang * Cast the uint32 key/counter to int64 Signed-off-by: Reese Wang * Update dropout support in fused attention docs Signed-off-by: Reese Wang * Revise devPtrCuSeqlen* to align the naming Signed-off-by: Reese Wang * Support different Jax PRNG impls Signed-off-by: Reese Wang * Revert CastAsync since it is not used Signed-off-by: Reese 
Wang * Implement is_training for 16-bit fused attn Signed-off-by: Reese Wang * Add fused attn with dropout sanity unit tests Signed-off-by: Reese Wang * Enhance the comments readability and rng_state checker Signed-off-by: Reese Wang * Change the attention dropout shape to align other frameworks Signed-off-by: Reese Wang * Make encoder tests deterministic Signed-off-by: Reese Wang * Change the default seed for the jax encoder tests Signed-off-by: Reese Wang * Maintain offset in TE Signed-off-by: Reese Wang * Enhance the resource safety Signed-off-by: Reese Wang * Revert rng_state type to allow only i64 Signed-off-by: Reese Wang * Handle the corner case for elts_per_threads calculation Signed-off-by: Reese Wang * Populate rng state by kernels Signed-off-by: Reese Wang * Rename rng_state as seed in cpp_extensions Signed-off-by: Reese Wang * Update the attention dropout comment Signed-off-by: Reese Wang --------- Signed-off-by: Reese Wang Co-authored-by: Kirthi Shankar Sivamani --- .../encoder/test_model_parallel_encoder.py | 2 +- examples/jax/encoder/test_multigpu_encoder.py | 2 +- .../encoder/test_multiprocessing_encoder.py | 2 +- .../jax/encoder/test_single_gpu_encoder.py | 2 +- qa/L0_jax_unittest/test.sh | 7 +- tests/jax/test_fused_attn.py | 138 +++++++++++---- tests/jax/utils.py | 2 - .../fused_attn_fp16_bf16_max_seqlen_512.cu | 161 ++++++++++-------- transformer_engine/common/fused_attn/utils.cu | 4 + .../include/transformer_engine/fused_attn.h | 8 +- transformer_engine/jax/CMakeLists.txt | 2 +- transformer_engine/jax/cpp_extensions.py | 81 ++++++--- transformer_engine/jax/csrc/modules.cpp | 48 ++++-- .../jax/csrc/{utils.cpp => utils.cu} | 18 ++ transformer_engine/jax/csrc/utils.h | 24 +++ transformer_engine/jax/flax/transformer.py | 19 ++- transformer_engine/jax/fused_attn.py | 39 +++-- 17 files changed, 375 insertions(+), 184 deletions(-) rename transformer_engine/jax/csrc/{utils.cpp => utils.cu} (52%) diff --git a/examples/jax/encoder/test_model_parallel_encoder.py b/examples/jax/encoder/test_model_parallel_encoder.py index 0a2af0623e..4a26244fff 100644 --- a/examples/jax/encoder/test_model_parallel_encoder.py +++ b/examples/jax/encoder/test_model_parallel_encoder.py @@ -377,7 +377,7 @@ def encoder_parser(args): default=False, help="quickly check a single pass", ) - parser.add_argument("--seed", type=int, default=1, metavar="S", help="random seed (default: 1)") + parser.add_argument("--seed", type=int, default=0, metavar="S", help="random seed (default: 0)") parser.add_argument("--use-fp8", action="store_true", default=False, diff --git a/examples/jax/encoder/test_multigpu_encoder.py b/examples/jax/encoder/test_multigpu_encoder.py index 48f858af58..ef3837c8d4 100644 --- a/examples/jax/encoder/test_multigpu_encoder.py +++ b/examples/jax/encoder/test_multigpu_encoder.py @@ -359,7 +359,7 @@ def encoder_parser(args): default=False, help="quickly check a single pass", ) - parser.add_argument("--seed", type=int, default=1, metavar="S", help="random seed (default: 1)") + parser.add_argument("--seed", type=int, default=0, metavar="S", help="random seed (default: 0)") parser.add_argument("--use-fp8", action="store_true", default=False, diff --git a/examples/jax/encoder/test_multiprocessing_encoder.py b/examples/jax/encoder/test_multiprocessing_encoder.py index 61e5bda9df..a21346458c 100644 --- a/examples/jax/encoder/test_multiprocessing_encoder.py +++ b/examples/jax/encoder/test_multiprocessing_encoder.py @@ -459,7 +459,7 @@ def encoder_parser(args): default=False, help="quickly check a 
single pass", ) - parser.add_argument("--seed", type=int, default=1, metavar="S", help="random seed (default: 1)") + parser.add_argument("--seed", type=int, default=0, metavar="S", help="random seed (default: 0)") parser.add_argument("--use-fp8", action="store_true", default=False, diff --git a/examples/jax/encoder/test_single_gpu_encoder.py b/examples/jax/encoder/test_single_gpu_encoder.py index 3db264daf7..62798eed82 100644 --- a/examples/jax/encoder/test_single_gpu_encoder.py +++ b/examples/jax/encoder/test_single_gpu_encoder.py @@ -294,7 +294,7 @@ def encoder_parser(args): default=False, help="quickly check a single pass", ) - parser.add_argument("--seed", type=int, default=1, metavar="S", help="random seed (default: 1)") + parser.add_argument("--seed", type=int, default=0, metavar="S", help="random seed (default: 0)") parser.add_argument("--use-fp8", action="store_true", default=False, diff --git a/qa/L0_jax_unittest/test.sh b/qa/L0_jax_unittest/test.sh index 62242ba075..72d2817456 100644 --- a/qa/L0_jax_unittest/test.sh +++ b/qa/L0_jax_unittest/test.sh @@ -9,5 +9,10 @@ pytest -Wignore -v $TE_PATH/tests/jax pip install -r $TE_PATH/examples/jax/mnist/requirements.txt pip install -r $TE_PATH/examples/jax/encoder/requirements.txt -pytest -Wignore -v $TE_PATH/examples/jax --ignore=$TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py + +pytest -Wignore -v $TE_PATH/examples/jax/mnist + +# Make encoder tests to have run-to-run deterministic to have the stable CI results +export XLA_FLAGS="--xla_gpu_deterministic_ops" +pytest -Wignore -v $TE_PATH/examples/jax/encoder --ignore=$TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py pytest -Wignore -v $TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py diff --git a/tests/jax/test_fused_attn.py b/tests/jax/test_fused_attn.py index 2504960705..8e4d59a9e2 100644 --- a/tests/jax/test_fused_attn.py +++ b/tests/jax/test_fused_attn.py @@ -54,6 +54,7 @@ def jax_self_fused_attn(qkv, bias, q_token, kv_token, dropout_rng, **kwargs): value, bias=bias, mask=mask, + deterministic=not kwargs['is_training'], dropout_rate=kwargs['dropout_probability'], dropout_rng=dropout_rng, dtype=qkv.dtype) @@ -78,6 +79,7 @@ def jax_cross_fused_attn(q, kv, q_token, kv_token, dropout_rng, **kwargs): value, bias=None, mask=mask, + deterministic=not kwargs['is_training'], dropout_rate=kwargs['dropout_probability'], dropout_rng=dropout_rng, dtype=q.dtype) @@ -113,7 +115,8 @@ def customcall_cross_fused_attn(q, kv, q_token, kv_token, dropout_rng, **kwargs) reason="Fused attention kernel is not supported.") class TestSelfFusedAttnMax512(): - def set_input(self, b, s, h, d, dtype, attn_mask_type, pad_ratio, with_bias): + def set_input(self, b, s, h, d, *, attn_bias_type, attn_mask_type, dropout_probability, dtype, + is_training, pad_ratio): key = jax.random.PRNGKey(0) subkeys = jax.random.split(key, 2) @@ -125,6 +128,8 @@ def set_input(self, b, s, h, d, dtype, attn_mask_type, pad_ratio, with_bias): min_val, max_val = -1, 1 self.qkv = jax.random.uniform(subkeys[0], qkv_shape, dtype, min_val, max_val) + + with_bias = attn_bias_type != AttnBiasType.NO_BIAS self.bias = jax.random.uniform(subkeys[1], bias_shape, dtype, min_val, max_val) if with_bias else None @@ -133,28 +138,81 @@ def set_input(self, b, s, h, d, dtype, attn_mask_type, pad_ratio, with_bias): self.kv_token = self.q_token self.scaling_factor = 1. / math.sqrt(d) - self.dropout_probability = 0. 
+ self.dropout_probability = dropout_probability self.dropout_rng = jax.random.PRNGKey(0) if self.dropout_probability > 0 else None - self.attn_bias_type = AttnBiasType.NO_BIAS if self.bias is None else AttnBiasType.POST_SCALE_BIAS - # deterministic = not is_training - self.deterministic = False + self.attn_bias_type = attn_bias_type + self.is_training = is_training @pytest.mark.parametrize('b, s, h, d', SELF_CASES) - @pytest.mark.parametrize('dtype', DTYPES) + @pytest.mark.parametrize('attn_bias_type', [AttnBiasType.NO_BIAS, AttnBiasType.POST_SCALE_BIAS]) @pytest.mark.parametrize('attn_mask_type', [AttnMaskType.PADDING_MASK, AttnMaskType.CAUSAL_MASK]) + @pytest.mark.parametrize('dropout_probability', [0., 0.1]) + @pytest.mark.parametrize('dtype', DTYPES) + @pytest.mark.parametrize('is_training', [True, False]) @pytest.mark.parametrize('pad_ratio', PAD_RATIO) - @pytest.mark.parametrize('with_bias', [True, False]) - def test_forward(self, b, s, h, d, dtype, attn_mask_type, pad_ratio, with_bias): + def test_sanity(self, b, s, h, d, attn_bias_type, attn_mask_type, dropout_probability, dtype, + is_training, pad_ratio): + + def grad_func(func, *args, **kwargs): + # Keep only valid result for the gradient + # fused_attn_max_512 output has shape (b, s, h, d) + valid_ret, _ = jnp.split(func(*args, **kwargs), (self.valid_len,), axis=1) + return jnp.mean(valid_ret, dtype=jnp.float32).astype(dtype) self.set_input(b, s, h, d, + attn_bias_type=attn_bias_type, + attn_mask_type=attn_mask_type, + dropout_probability=dropout_probability, dtype=dtype, + is_training=is_training, + pad_ratio=pad_ratio) + + kwargs = { + 'attn_bias_type': self.attn_bias_type, + 'attn_mask_type': attn_mask_type, + 'scaling_factor': self.scaling_factor, + 'dropout_probability': self.dropout_probability, + 'is_training': self.is_training + } + + jitted_primitive = jit( + value_and_grad( + lambda qkv, bias, q_token, kv_token, dropout_rng: grad_func( + customcall_self_fused_attn, qkv, bias, q_token, kv_token, dropout_rng, **kwargs + ), (0, 1))) + + primitive_out, (primitive_dqkv, + primitive_dbias) = jitted_primitive(self.qkv, self.bias, self.q_token, + self.kv_token, self.dropout_rng) + + @pytest.mark.parametrize('b, s, h, d', SELF_CASES) + @pytest.mark.parametrize('attn_bias_type', [AttnBiasType.NO_BIAS, AttnBiasType.POST_SCALE_BIAS]) + @pytest.mark.parametrize('attn_mask_type', + [AttnMaskType.PADDING_MASK, AttnMaskType.CAUSAL_MASK]) + @pytest.mark.parametrize('dropout_probability', [0., 0.1]) + @pytest.mark.parametrize('dtype', DTYPES) + @pytest.mark.parametrize('is_training', [True, False]) + @pytest.mark.parametrize('pad_ratio', PAD_RATIO) + def test_forward(self, b, s, h, d, attn_bias_type, attn_mask_type, dropout_probability, dtype, + is_training, pad_ratio): + # dropout can't get the bitmatch result + if is_training and dropout_probability > 0.: + return + + self.set_input(b, + s, + h, + d, + attn_bias_type=attn_bias_type, attn_mask_type=attn_mask_type, - pad_ratio=pad_ratio, - with_bias=with_bias) + dropout_probability=dropout_probability, + dtype=dtype, + is_training=is_training, + pad_ratio=pad_ratio) primitive_out = customcall_self_fused_attn(self.qkv, self.bias, @@ -165,7 +223,7 @@ def test_forward(self, b, s, h, d, dtype, attn_mask_type, pad_ratio, with_bias): attn_mask_type=attn_mask_type, scaling_factor=self.scaling_factor, dropout_probability=self.dropout_probability, - is_training=not self.deterministic) + is_training=self.is_training) reference_out = jax_self_fused_attn(self.qkv, self.bias, @@ -174,7 +232,8 @@ 
def test_forward(self, b, s, h, d, dtype, attn_mask_type, pad_ratio, with_bias): self.dropout_rng, attn_mask_type=attn_mask_type, scaling_factor=self.scaling_factor, - dropout_probability=self.dropout_probability) + dropout_probability=self.dropout_probability, + is_training=self.is_training) ref_valid, _ = jnp.split(reference_out, (self.valid_len,), axis=1) pri_valid, pri_invalid = jnp.split(primitive_out, (self.valid_len,), axis=1) @@ -188,20 +247,25 @@ def test_forward(self, b, s, h, d, dtype, attn_mask_type, pad_ratio, with_bias): jnp.zeros_like(pri_invalid, jnp.float32)) @pytest.mark.parametrize('b, s, h, d', SELF_CASES) + @pytest.mark.parametrize('attn_bias_type', [AttnBiasType.NO_BIAS, AttnBiasType.POST_SCALE_BIAS]) @pytest.mark.parametrize('attn_mask_type', [AttnMaskType.PADDING_MASK, AttnMaskType.CAUSAL_MASK]) + @pytest.mark.parametrize('dropout_probability', [0.]) # dropout can't get the bitmatch result @pytest.mark.parametrize('dtype', DTYPES) + @pytest.mark.parametrize('is_training', [True]) # backward is only used when is_training @pytest.mark.parametrize('pad_ratio', PAD_RATIO) - @pytest.mark.parametrize('with_bias', [True, False]) - def test_forward_backward(self, b, s, h, d, dtype, attn_mask_type, pad_ratio, with_bias): + def test_forward_backward(self, b, s, h, d, attn_bias_type, attn_mask_type, dropout_probability, + dtype, is_training, pad_ratio): self.set_input(b, s, h, d, - dtype=dtype, + attn_bias_type=attn_bias_type, attn_mask_type=attn_mask_type, - pad_ratio=pad_ratio, - with_bias=with_bias) + dropout_probability=dropout_probability, + dtype=dtype, + is_training=is_training, + pad_ratio=pad_ratio) def grad_func(fused_attn_max_512_func, *args, **kwargs): # Gradient is small, use a gradient multiplier to amplify the graident @@ -221,7 +285,7 @@ def grad_func(fused_attn_max_512_func, *args, **kwargs): 'attn_mask_type': attn_mask_type, 'scaling_factor': self.scaling_factor, 'dropout_probability': self.dropout_probability, - 'is_training': not self.deterministic + 'is_training': self.is_training } # Use FP16/BF16 to sum the results may cause overflow, use FP32 for the summation @@ -300,7 +364,8 @@ def grad_func(fused_attn_max_512_func, *args, **kwargs): reason="Fused attention kernel is not supported.") class TestCrossFusedAttnMax512(): - def set_input(self, b, s_q, s_kv, h, d, dtype, attn_mask_type, pad_ratio): + def set_input(self, b, s_q, s_kv, h, d, *, attn_mask_type, dropout_probability, dtype, + is_training, pad_ratio): key = jax.random.PRNGKey(0) subkeys = jax.random.split(key, 2) @@ -321,25 +386,32 @@ def set_input(self, b, s_q, s_kv, h, d, dtype, attn_mask_type, pad_ratio): (b, kv_pad_len))), axis=-1) self.scaling_factor = 1. / math.sqrt(d) - self.dropout_probability = 0. 
- self.dropout_rng = jax.random.PRNGKey(0) + self.dropout_probability = dropout_probability + self.dropout_rng = jax.random.PRNGKey(0) if self.dropout_probability > 0 else None self.attn_bias_type = AttnBiasType.NO_BIAS - # deterministic = not is_training - self.deterministic = False + self.is_training = is_training @pytest.mark.parametrize('b, s_q, s_kv, h, d', CROSS_CASES) @pytest.mark.parametrize('attn_mask_type', [AttnMaskType.PADDING_MASK]) + @pytest.mark.parametrize('dropout_probability', [0., 0.1]) @pytest.mark.parametrize('dtype', DTYPES) + @pytest.mark.parametrize('is_training', [True, False]) @pytest.mark.parametrize('pad_ratio', PAD_RATIO) - def test_forward(self, b, s_q, s_kv, h, d, dtype, attn_mask_type, pad_ratio): + def test_forward(self, b, s_q, s_kv, h, d, attn_mask_type, dropout_probability, dtype, + is_training, pad_ratio): + # dropout can't get the bitmatch result + if is_training and dropout_probability > 0.: + return self.set_input(b, s_q, s_kv, h, d, - dtype=dtype, attn_mask_type=attn_mask_type, + dropout_probability=dropout_probability, + dtype=dtype, + is_training=is_training, pad_ratio=pad_ratio) primitive_out = customcall_cross_fused_attn(self.q, @@ -351,7 +423,7 @@ def test_forward(self, b, s_q, s_kv, h, d, dtype, attn_mask_type, pad_ratio): attn_mask_type=attn_mask_type, scaling_factor=self.scaling_factor, dropout_probability=self.dropout_probability, - is_training=not self.deterministic) + is_training=self.is_training) reference_out = jax_cross_fused_attn(self.q, self.kv, @@ -360,7 +432,8 @@ def test_forward(self, b, s_q, s_kv, h, d, dtype, attn_mask_type, pad_ratio): self.dropout_rng, attn_mask_type=attn_mask_type, scaling_factor=self.scaling_factor, - dropout_probability=self.dropout_probability) + dropout_probability=self.dropout_probability, + is_training=self.is_training) ref_valid, _ = jnp.split(reference_out, (self.q_valid_len,), axis=1) pri_valid, pri_invalid = jnp.split(primitive_out, (self.q_valid_len,), axis=1) @@ -375,16 +448,21 @@ def test_forward(self, b, s_q, s_kv, h, d, dtype, attn_mask_type, pad_ratio): @pytest.mark.parametrize('b, s_q, s_kv, h, d', CROSS_CASES) @pytest.mark.parametrize('attn_mask_type', [AttnMaskType.PADDING_MASK]) + @pytest.mark.parametrize('dropout_probability', [0.]) # dropout can't get the bitmatch result @pytest.mark.parametrize('dtype', DTYPES) + @pytest.mark.parametrize('is_training', [True]) # backward is only used when is_training @pytest.mark.parametrize('pad_ratio', PAD_RATIO) - def test_forward_backward(self, b, s_q, s_kv, h, d, dtype, attn_mask_type, pad_ratio): + def test_forward_backward(self, b, s_q, s_kv, h, d, attn_mask_type, dropout_probability, dtype, + is_training, pad_ratio): self.set_input(b, s_q, s_kv, h, d, - dtype=dtype, attn_mask_type=attn_mask_type, + dropout_probability=dropout_probability, + dtype=dtype, + is_training=is_training, pad_ratio=pad_ratio) def grad_func(fused_attn_max_512_func, *args, **kwargs): @@ -405,7 +483,7 @@ def grad_func(fused_attn_max_512_func, *args, **kwargs): 'attn_mask_type': attn_mask_type, 'scaling_factor': self.scaling_factor, 'dropout_probability': self.dropout_probability, - 'is_training': not self.deterministic + 'is_training': self.is_training } # Use FP16/BF16 to sum the results may cause overflow, use FP32 for the summation diff --git a/tests/jax/utils.py b/tests/jax/utils.py index dc5ef2bb13..893a5afcbe 100644 --- a/tests/jax/utils.py +++ b/tests/jax/utils.py @@ -167,9 +167,7 @@ def dot_product_attention(query: Array, # T5 broadcasts along the "length" dim, but 
unclear which one that # corresponds to in positional dimensions here, assuming query dim. dropout_shape = list(attn_weights.shape) - dropout_shape[-2] = 1 keep = jax_random.bernoulli(dropout_rng, keep_prob, dropout_shape) - keep = jnp.broadcast_to(keep, attn_weights.shape) multiplier = (keep.astype(attn_weights.dtype) / jnp.asarray(keep_prob, dtype=dtype)) attn_weights = attn_weights * multiplier diff --git a/transformer_engine/common/fused_attn/fused_attn_fp16_bf16_max_seqlen_512.cu b/transformer_engine/common/fused_attn/fused_attn_fp16_bf16_max_seqlen_512.cu index 53f4f72636..e8906b31c4 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp16_bf16_max_seqlen_512.cu +++ b/transformer_engine/common/fused_attn/fused_attn_fp16_bf16_max_seqlen_512.cu @@ -22,7 +22,7 @@ #define O_ID 4 #define S_ID 5 #define B_ID 6 -#define D_CONST_ID 7 +#define DROPOUT_CONST_ID 7 #define S_CONST_ID 8 #define Q_SEQLEN_ID 9 #define K_SEQLEN_ID 10 @@ -33,6 +33,8 @@ #define MASK_VAL_ID 15 #define dS_ID 16 #define dBias_ID 17 +#define DROPOUT_SEED_ID 18 +#define DROPOUT_OFFSET_ID 19 #define VIRTUAL_ID 20 @@ -333,8 +335,7 @@ static cudnn_frontend::Tensor createSoftmaxForward( int64_t afterReduction_dim[4] = {b, h, s_q, 1}; int64_t afterReduction_stride[4] = {h * s_q, s_q, 1, 1}; - cudnnDataType_t softmaxOutputType = - (enable_dropout || softmax_output_virtual) ? CUDNN_DATA_FLOAT : tensorType; + cudnnDataType_t softmaxOutputType = enable_dropout ? CUDNN_DATA_FLOAT : tensorType; uint64_t softmaxOutputName = softmax_output_virtual ? VIRTUAL_ID + 154 : S_ID; // max (x) @@ -427,7 +428,7 @@ static cudnn_frontend::Tensor createSoftmaxForward( } static cudnn_frontend::Tensor createDropout(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, - int64_t d, int64_t seed, double probability, + int64_t d, double probability, cudnnDataType_t tensorType, // NOLINTNEXTLINE(runtime/references) std::vector &ops, @@ -460,8 +461,9 @@ static cudnn_frontend::Tensor createDropout(int64_t b, int64_t h, int64_t s_q, i .setReorderType(reorder_type) .build(); // scale after dropout - auto scaleDropoutTensor = tensor_create(tensorType, D_CONST_ID, scale_dim, scale_stride, false, - true); // is by value + auto scaleDropoutTensor = + tensor_create(tensorType, DROPOUT_CONST_ID, scale_dim, scale_stride, false, + true); // is by value // after Scale auto afterScaleTensor = tensor_create(tensorType, VIRTUAL_ID + 201, afterBMM1_dim, afterBMM1_stride, true, false); // is virtual @@ -472,10 +474,16 @@ static cudnn_frontend::Tensor createDropout(int64_t b, int64_t h, int64_t s_q, i .setBernoulliDistProbability(1.0 - probability) .build(); + auto dropoutSeed = + tensor_create(CUDNN_DATA_INT64, DROPOUT_SEED_ID, scale_dim, scale_stride, false, false); + auto dropoutOffset = + tensor_create(CUDNN_DATA_INT64, DROPOUT_OFFSET_ID, scale_dim, scale_stride, false, false); + // Create a rng Node. 
auto rng_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_RNG_DESCRIPTOR) .setyDesc(dropoutMaskTensor) - .setSeed(seed) + .setSeedDesc(dropoutSeed) + .setOffsetDesc(dropoutOffset) .setRngDesc(rngDesc) .build(); @@ -624,16 +632,14 @@ static cudnn_frontend::Tensor createSoftmaxBackward(int64_t b, int64_t h, int64_ return dxTensor; } -void fused_attn_max_512_fwd_impl(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d, - bool is_training, float scaling_factor, float dropout_probability, - NVTE_QKV_Layout layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type mask_type, void *devPtrQ, void *devPtrK, - void *devPtrV, void *devPtrS, void *devPtrO, void *devPtrBias, - void *devCuSeqlenQ, void *devCuSeqlenK, void *workspace, - size_t *workspace_size, cudnnDataType_t tensorType, - cudaStream_t stream, cudnnHandle_t handle) { +void fused_attn_max_512_fwd_impl( + int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d, bool is_training, + float scaling_factor, float dropout_probability, NVTE_QKV_Layout layout, + NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, void *devPtrQ, void *devPtrK, void *devPtrV, + void *devPtrS, void *devPtrO, void *devPtrBias, void *devPtrCuSeqlenQ, void *devPtrCuSeqlenKV, + void *devPtrDropoutSeed, void *devPtrDropoutOffset, void *workspace, size_t *workspace_size, + cudnnDataType_t tensorType, cudaStream_t stream, cudnnHandle_t handle) { try { - constexpr int64_t seed = 0; // TODO(rewang): replace this with device seed/offset NVTE_CHECK_CUDNN(cudnnSetStream(handle, stream)); FADescriptor descriptor{b, h, @@ -646,10 +652,13 @@ void fused_attn_max_512_fwd_impl(int64_t b, int64_t h, int64_t s_q, int64_t s_kv using CacheType = std::map; static thread_local CacheType fmha_fprop_cache; - bool enable_dropout = (dropout_probability != 0.0f); + // softmax auxiliary is only used in the training mode + bool enable_dropout = is_training && (dropout_probability != 0.0f); - NVTE_CHECK(!enable_dropout, - "dropout probability > 0 in fused_attn_max_512 has not been implemented."); + // two conditions that make softmax auxiliary in virtual + // 1. inference mode (not is_training) + // 2. 
dropout enabled: the auxiliary becomes the dropout output + bool softmax_output_virtual = !is_training || enable_dropout; // Get plan from cache if cache is available, otherwise create one auto get_plan = [&](CacheType &cache, const FADescriptor &descriptor) { @@ -667,8 +676,10 @@ void fused_attn_max_512_fwd_impl(int64_t b, int64_t h, int64_t s_q, int64_t s_kv createScale(b, h, s_q, s_kv, d, layout, tensorType, ops); // if bias, we need to memset the S buffer to correctly computate dbias + // WAR: causal_mask without bias needs memset the S buffer + // inference mode doesn't need the S auxiliary auto zero_s = (bias_type != NVTE_Bias_Type::NVTE_NO_BIAS) || - (mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK); + (mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK) && is_training; auto bmm1_output = createBMM1(b, h, s_q, s_kv, d, layout, tensorType, zero_s, ops); NVTE_CHECK(bias_type != NVTE_Bias_Type::NVTE_PRE_SCALE_BIAS, @@ -683,14 +694,12 @@ void fused_attn_max_512_fwd_impl(int64_t b, int64_t h, int64_t s_q, int64_t s_kv NVTE_CHECK(dropout_probability != 1.0f, "Dropout probability cannot be 1.0."); - // TODO(rewang): check whether devPtrS can be removed - bool softmax_output_virtual = enable_dropout; // || devPtrS == nullptr; auto softmax_output = createSoftmaxForward(b, h, s_q, s_kv, d, layout, enable_dropout, softmax_output_virtual, tensorType, ops, mask_output); - if (dropout_probability != 0.0f) { - auto dropout_output = createDropout(b, h, s_q, s_kv, d, seed, dropout_probability, + if (enable_dropout) { + auto dropout_output = createDropout(b, h, s_q, s_kv, d, dropout_probability, tensorType, ops, softmax_output); createBMM2(b, h, s_q, s_kv, d, layout, tensorType, ops, dropout_output); } else { @@ -741,9 +750,10 @@ void fused_attn_max_512_fwd_impl(int64_t b, int64_t h, int64_t s_q, int64_t s_kv void *devActualSeqlenQ = static_cast(workspace) + plan_workspace_size; void *devActualSeqlenK = static_cast(devActualSeqlenQ) + b * sizeof(int32_t); cu_seqlens_to_actual_seqlens<<>>( - b, static_cast(devCuSeqlenQ), - static_cast(devCuSeqlenK), static_cast(devActualSeqlenQ), - static_cast(devActualSeqlenK)); + b, static_cast(devPtrCuSeqlenQ), + static_cast(devPtrCuSeqlenKV), + static_cast(devActualSeqlenQ), static_cast(devActualSeqlenK)); + NVTE_CHECK_CUDA(cudaGetLastError()); // change this if you have access to float_min float negInfinity = -1.0E+10; @@ -758,16 +768,17 @@ void fused_attn_max_512_fwd_impl(int64_t b, int64_t h, int64_t s_q, int64_t s_kv data_ptrs.insert(std::pair(K_SEQLEN_ID, devActualSeqlenK)); data_ptrs.insert(std::pair(MASK_VAL_ID, &negInfinity)); + __half half_cast_scaling_factor{scaling_factor}; + __nv_bfloat16 bfloat_cast_scaling_factor{scaling_factor}; + if (tensorType == CUDNN_DATA_FLOAT) { data_ptrs.insert(std::pair(S_CONST_ID, &scaling_factor)); } else if (tensorType == CUDNN_DATA_HALF) { - __half cast_scaling_factor{scaling_factor}; - data_ptrs.insert(std::pair(S_CONST_ID, &cast_scaling_factor)); + data_ptrs.insert(std::pair(S_CONST_ID, &half_cast_scaling_factor)); } else if (tensorType == CUDNN_DATA_BFLOAT16) { - __nv_bfloat16 cast_scaling_factor{scaling_factor}; - data_ptrs.insert(std::pair(S_CONST_ID, &cast_scaling_factor)); + data_ptrs.insert(std::pair(S_CONST_ID, &bfloat_cast_scaling_factor)); } else { - std::cerr << "Not supported tensorType." 
<< std::endl; + NVTE_ERROR("Unsupported tensor type."); } data_ptrs.insert(std::pair(O_ID, devPtrO)); @@ -776,12 +787,30 @@ void fused_attn_max_512_fwd_impl(int64_t b, int64_t h, int64_t s_q, int64_t s_kv data_ptrs.insert(std::pair(B_ID, devPtrBias)); } - if (devPtrS != nullptr) { + // if enable_dropout, S is the result after dropout + // if not enable dropout, S is the result after softmax + if (enable_dropout || !softmax_output_virtual) { data_ptrs.insert(std::pair(S_ID, devPtrS)); } + __half half_cast_scale_dropout{scale_dropout}; + __nv_bfloat16 bfloat16_cast_scale_dropout{scale_dropout}; + if (enable_dropout) { - data_ptrs.insert(std::pair(D_CONST_ID, &scale_dropout)); + // TODO(rewang): make a util func + if (tensorType == CUDNN_DATA_FLOAT) { + data_ptrs.insert(std::pair(DROPOUT_CONST_ID, &scale_dropout)); + } else if (tensorType == CUDNN_DATA_HALF) { + data_ptrs.insert( + std::pair(DROPOUT_CONST_ID, &half_cast_scale_dropout)); + } else if (tensorType == CUDNN_DATA_BFLOAT16) { + data_ptrs.insert( + std::pair(DROPOUT_CONST_ID, &bfloat16_cast_scale_dropout)); + } else { + NVTE_ERROR("Unsupported tensor type."); + } + data_ptrs.insert(std::pair(DROPOUT_SEED_ID, devPtrDropoutSeed)); + data_ptrs.insert(std::pair(DROPOUT_OFFSET_ID, devPtrDropoutOffset)); } auto variantPack = cudnn_frontend::VariantPackBuilder() @@ -802,7 +831,7 @@ void fused_attn_max_512_bwd_impl(int64_t b, int64_t h, int64_t s_q, int64_t s_kv NVTE_Bias_Type bias_type, void *devPtrQ, void *devPtrK, void *devPtrV, void *devPtrS, void *devPtrdQ, void *devPtrdK, void *devPtrdV, void *devPtrdO, void *devPtrdS, void *devPtrdBias, - void *devCuSeqlenQ, void *devCuSeqlenK, void *workspace, + void *devPtrCuSeqlenQ, void *devPtrCuSeqlenKV, void *workspace, size_t *workspace_size, cudnnDataType_t tensorType, cudaStream_t stream, cudnnHandle_t handle) { try { @@ -915,7 +944,7 @@ void fused_attn_max_512_bwd_impl(int64_t b, int64_t h, int64_t s_q, int64_t s_kv ops.push_back(std::move(reshape_op)); // scale dropout - auto dropoutScaleTensor = tensor_create(CUDNN_DATA_FLOAT, D_CONST_ID, scale_dim, + auto dropoutScaleTensor = tensor_create(CUDNN_DATA_FLOAT, DROPOUT_CONST_ID, scale_dim, scale_stride, false, true); // is by value auto pAfterScaleTensor = tensor_create(tensorType, VIRTUAL_ID + 301, p_transpose_dim, p_transpose_stride, true, false); @@ -1160,9 +1189,10 @@ void fused_attn_max_512_bwd_impl(int64_t b, int64_t h, int64_t s_q, int64_t s_kv void *devActualSeqlenQ = static_cast(workspace) + plan_workspace_size; void *devActualSeqlenK = static_cast(devActualSeqlenQ) + b * sizeof(int32_t); cu_seqlens_to_actual_seqlens<<>>( - b, static_cast(devCuSeqlenQ), - static_cast(devCuSeqlenK), static_cast(devActualSeqlenQ), - static_cast(devActualSeqlenK)); + b, static_cast(devPtrCuSeqlenQ), + static_cast(devPtrCuSeqlenKV), + static_cast(devActualSeqlenQ), static_cast(devActualSeqlenK)); + NVTE_CHECK_CUDA(cudaGetLastError()); std::set> data_ptrs; // add all the data pointers to be used in the variant pack @@ -1183,13 +1213,10 @@ void fused_attn_max_512_bwd_impl(int64_t b, int64_t h, int64_t s_q, int64_t s_kv data_ptrs.insert(std::pair(dBias_ID, devPtrdBias)); } - NVTE_CHECK(dropout_probability == 0.f, - "dropout probability > 0 in fused_attn_max_512 has not been implemented."); - float zeroVal = 0.0f; float dropoutScale = 1.0f / (1.0f - dropout_probability); - data_ptrs.insert(std::pair(D_CONST_ID, &dropoutScale)); + data_ptrs.insert(std::pair(DROPOUT_CONST_ID, &dropoutScale)); data_ptrs.insert(std::pair(S_CONST_ID, &scaling_factor)); 
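The backward variant pack above hands the kernel a dropout scale of 1.0f / (1.0f - dropout_probability), the usual inverted-dropout convention; the JAX reference path applies the same scaling to the softmax weights. A small plain-JAX sketch of that convention (illustrative only, not the cuDNN graph built here):

import jax
import jax.numpy as jnp

def inverted_dropout(attn_weights, rng, dropout_probability):
    # Keep each weight with probability (1 - p) and rescale the survivors by
    # 1 / (1 - p) so the expected value of the weights is unchanged.
    keep_prob = 1.0 - dropout_probability
    keep = jax.random.bernoulli(rng, keep_prob, attn_weights.shape)
    return attn_weights * keep.astype(attn_weights.dtype) / keep_prob

weights = jnp.ones((2, 4, 8, 8), dtype=jnp.float32)  # (b, h, s_q, s_kv)
out = inverted_dropout(weights, jax.random.PRNGKey(0), dropout_probability=0.1)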
data_ptrs.insert(std::pair(MASK_VAL_ID, &zeroVal)); @@ -1216,8 +1243,6 @@ void fused_attn_max_512_fwd_qkvpacked( Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) { using namespace transformer_engine; - // Only is_training is verified - NVTE_CHECK(is_training, "is_training=False is not implemented in fused_attn_max_512."); NVTE_CHECK(qkv_layout == NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED, "qkv_layout must be NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED."); @@ -1246,23 +1271,22 @@ void fused_attn_max_512_fwd_qkvpacked( devPtrS = output_S->data.dptr; } - void *devCuSeqlen = cu_seqlens->data.dptr; + void *devPtrCuSeqlen = cu_seqlens->data.dptr; - // TODO(rewang): dropout seed - // void* devPtrDropoutSeed = reinterpret_cast( - // reinterpret_cast(rng_state->data.dptr)); - // void* devPtrDropoutOffset = reinterpret_cast( - // reinterpret_cast(rng_state->data.dptr) + 1); + const DType rng_state_type = rng_state->data.dtype; + NVTE_CHECK(rng_state_type == DType::kInt64); + void *devPtrDropoutSeed = rng_state->data.dptr; + void *devPtrDropoutOffset = + static_cast(static_cast(rng_state->data.dptr) + 1); const DType QKV_type = input_QKV->data.dtype; size_t workspace_size = 0; - // TODO(rewang): replace CPU seed - fused_attn_max_512_fwd_impl(batch, num_head, max_seqlen, max_seqlen, head_dim, is_training, - attn_scale, p_dropout, qkv_layout, bias_type, mask_type, devPtrQ, - devPtrK, devPtrV, devPtrS, devPtrO, devPtrBias, devCuSeqlen, - devCuSeqlen, workspace->data.dptr, &workspace_size, - get_cudnn_dtype(QKV_type), stream, handle); + fused_attn_max_512_fwd_impl( + batch, num_head, max_seqlen, max_seqlen, head_dim, is_training, attn_scale, p_dropout, + qkv_layout, bias_type, mask_type, devPtrQ, devPtrK, devPtrV, devPtrS, devPtrO, devPtrBias, + devPtrCuSeqlen, devPtrCuSeqlen, devPtrDropoutSeed, devPtrDropoutOffset, + workspace->data.dptr, &workspace_size, get_cudnn_dtype(QKV_type), stream, handle); if (workspace_size > 0) { if (workspace->data.dptr == nullptr) { @@ -1288,8 +1312,6 @@ void fused_attn_max_512_fwd_kvpacked(size_t batch, size_t q_max_seqlen, size_t k Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) { using namespace transformer_engine; - // Only is_training is verified - NVTE_CHECK(is_training, "is_training=False is not implemented in fused_attn_max_512."); NVTE_CHECK(qkv_layout == NVTE_QKV_Layout::NVTE_KV_INTERLEAVED, "qkv_layout must be NVTE_QKV_Layout::NVTE_KV_INTERLEAVED."); NVTE_CHECK(bias_type == NVTE_Bias_Type::NVTE_NO_BIAS || @@ -1328,20 +1350,19 @@ void fused_attn_max_512_fwd_kvpacked(size_t batch, size_t q_max_seqlen, size_t k void *devQCuSeqlen = q_cu_seqlens->data.dptr; void *devKVCuSeqlen = kv_cu_seqlens->data.dptr; - // TODO(rewang): dropout seed - // void* devPtrDropoutSeed = reinterpret_cast( - // reinterpret_cast(rng_state->data.dptr)); - // void* devPtrDropoutOffset = reinterpret_cast( - // reinterpret_cast(rng_state->data.dptr) + 1); + const DType rng_state_type = rng_state->data.dtype; + NVTE_CHECK(rng_state_type == DType::kInt64); + void *devPtrDropoutSeed = rng_state->data.dptr; + void *devPtrDropoutOffset = + static_cast(static_cast(rng_state->data.dptr) + 1); size_t workspace_size = 0; - // TODO(rewang): replace CPU seed - fused_attn_max_512_fwd_impl(batch, num_head, q_max_seqlen, kv_max_seqlen, head_dim, is_training, - attn_scale, p_dropout, qkv_layout, bias_type, mask_type, devPtrQ, - devPtrK, devPtrV, devPtrS, devPtrO, devPtrBias, devQCuSeqlen, - devKVCuSeqlen, workspace->data.dptr, &workspace_size, - get_cudnn_dtype(q_type), stream, handle); + 
fused_attn_max_512_fwd_impl( + batch, num_head, q_max_seqlen, kv_max_seqlen, head_dim, is_training, attn_scale, p_dropout, + qkv_layout, bias_type, mask_type, devPtrQ, devPtrK, devPtrV, devPtrS, devPtrO, devPtrBias, + devQCuSeqlen, devKVCuSeqlen, devPtrDropoutSeed, devPtrDropoutOffset, workspace->data.dptr, + &workspace_size, get_cudnn_dtype(q_type), stream, handle); if (workspace_size > 0) { if (workspace->data.dptr == nullptr) { diff --git a/transformer_engine/common/fused_attn/utils.cu b/transformer_engine/common/fused_attn/utils.cu index 5ae4b42c16..cae42bafa0 100644 --- a/transformer_engine/common/fused_attn/utils.cu +++ b/transformer_engine/common/fused_attn/utils.cu @@ -256,6 +256,10 @@ __global__ void cu_seqlens_to_actual_seqlens(size_t b, cudnnDataType_t get_cudnn_dtype(const transformer_engine::DType t) { using namespace transformer_engine; switch (t) { + case DType::kInt32: + return CUDNN_DATA_INT32; + case DType::kInt64: + return CUDNN_DATA_INT64; case DType::kFloat16: return CUDNN_DATA_HALF; case DType::kFloat32: diff --git a/transformer_engine/common/include/transformer_engine/fused_attn.h b/transformer_engine/common/include/transformer_engine/fused_attn.h index 6311da2465..ed6dd4c041 100644 --- a/transformer_engine/common/include/transformer_engine/fused_attn.h +++ b/transformer_engine/common/include/transformer_engine/fused_attn.h @@ -106,7 +106,7 @@ enum NVTE_Mask_Type { \verbatim | precision | qkv layout | bias | mask | dropout | sequence length | head_dim | | FP8 | QKV_INTERLEAVED | NO_BIAS | PADDING | Yes | <= 512 | 64 | - | FP16/BF16 | QKV_INTERLEAVED | NO_BIAS/POST_SCALE_BIAS | PADDING/CAUSAL | No | <= 512 | 64 | + | FP16/BF16 | QKV_INTERLEAVED | NO_BIAS/POST_SCALE_BIAS | PADDING/CAUSAL | Yes | <= 512 | 64 | \endverbatim * * \param[in] QKV The QKV tensor in packed format, @@ -149,7 +149,7 @@ void nvte_fused_attn_fwd_qkvpacked( \verbatim | precision | qkv layout | bias | mask | dropout | sequence length | head_dim | | FP8 | QKV_INTERLEAVED | NO_BIAS | PADDING | Yes | <= 512 | 64 | - | FP16/BF16 | QKV_INTERLEAVED | NO_BIAS/POST_SCALE_BIAS | PADDING/CAUSAL | No | <= 512 | 64 | + | FP16/BF16 | QKV_INTERLEAVED | NO_BIAS/POST_SCALE_BIAS | PADDING/CAUSAL | Yes | <= 512 | 64 | \endverbatim * * \param[in] QKV The QKV tensor in packed format, @@ -200,7 +200,7 @@ void nvte_fused_attn_bwd_qkvpacked( * Support Matrix: \verbatim | precision | qkv layout | bias | mask | dropout | sequence length | head_dim | - | FP16/BF16 | QKV_INTERLEAVED | NO_BIAS/POST_SCALE_BIAS | PADDING/CAUSAL | No | <= 512 | 64 | + | FP16/BF16 | QKV_INTERLEAVED | NO_BIAS/POST_SCALE_BIAS | PADDING/CAUSAL | Yes | <= 512 | 64 | \endverbatim * * \param[in] Q The Q tensor, [total_seqs_q, num_heads, head_dim]. @@ -247,7 +247,7 @@ void nvte_fused_attn_fwd_kvpacked( * Support Matrix: \verbatim | precision | qkv layout | bias | mask | dropout | sequence length | head_dim | - | FP16/BF16 | QKV_INTERLEAVED | NO_BIAS/POST_SCALE_BIAS | PADDING/CAUSAL | No | <= 512 | 64 | + | FP16/BF16 | QKV_INTERLEAVED | NO_BIAS/POST_SCALE_BIAS | PADDING/CAUSAL | Yes | <= 512 | 64 | \endverbatim * * \param[in] Q The Q tensor, [total_seqs_q, num_heads, head_dim]. 
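The forward wrappers above read rng_state as two consecutive int64 values, a Philox seed followed by an offset, and the offset is advanced per launch (see PopulateRngStateAsync and FusedAttnOffsetManager later in this patch) so repeated calls draw fresh dropout masks. A host-side sketch of that bookkeeping (illustrative NumPy model, not the actual device-side code):

import numpy as np

class RngStateSketch:
    # Models the {seed, offset} pair handed to the fused-attention kernels.
    def __init__(self, seed: int):
        self.seed = seed
        self.offset = 0

    def next_state(self, q_max_seqlen: int, kv_max_seqlen: int,
                   threads_per_cta: int = 128) -> np.ndarray:
        state = np.array([self.seed, self.offset], dtype=np.int64)
        # Advance by the number of draws a launch may consume, mirroring the
        # increment used by PopulateRngStateAsync.
        increment = (q_max_seqlen * kv_max_seqlen + threads_per_cta - 1) // threads_per_cta
        self.offset += increment
        return state

tracker = RngStateSketch(seed=1234)
first = tracker.next_state(512, 512)   # [1234, 0]
second = tracker.next_state(512, 512)  # offset advanced by ceil(512*512/128) = 2048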
diff --git a/transformer_engine/jax/CMakeLists.txt b/transformer_engine/jax/CMakeLists.txt index 9e8efa2c60..cf9a48244d 100644 --- a/transformer_engine/jax/CMakeLists.txt +++ b/transformer_engine/jax/CMakeLists.txt @@ -6,7 +6,7 @@ pybind11_add_module( transformer_engine_jax ${CMAKE_CURRENT_SOURCE_DIR}/csrc/extensions.cpp ${CMAKE_CURRENT_SOURCE_DIR}/csrc/modules.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/csrc/utils.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/csrc/utils.cu ) target_link_libraries(transformer_engine_jax PRIVATE CUDA::cudart CUDA::cublas CUDA::cublasLt transformer_engine) diff --git a/transformer_engine/jax/cpp_extensions.py b/transformer_engine/jax/cpp_extensions.py index 566b95ff63..b8dc0447c7 100644 --- a/transformer_engine/jax/cpp_extensions.py +++ b/transformer_engine/jax/cpp_extensions.py @@ -8,6 +8,8 @@ from typing import Tuple from functools import partial, reduce import operator +import warnings + import numpy as np from jaxlib.hlo_helpers import custom_call import jax.numpy as jnp @@ -1679,7 +1681,7 @@ def lowering(ctx, grad_outputs, softmax_outputs, *, scale_factor): grad_outputs, softmax_outputs, scale_factor) - return out # out is iterable already + return out # out is iterable already _scaled_softmax_bwd_p = register_primitive(ScaledSoftmaxBwdPrimitive) @@ -1828,7 +1830,7 @@ def lowering(ctx, grad_outputs, softmax_outputs, *, scale_factor): grad_outputs, softmax_outputs, scale_factor) - return out # out is iterable already + return out # out is iterable already _scaled_masked_softmax_bwd_p = register_primitive(ScaledMaskedSoftmaxBwdPrimitive) @@ -1962,7 +1964,7 @@ def lowering(ctx, grad_outputs, softmax_outputs, *, scale_factor): ScaledUpperTriangMaskedSoftmaxBwdPrimitive.name, ctx, grad_outputs, softmax_outputs, scale_factor) - return out # out is iterable already + return out # out is iterable already _scaled_upper_triang_masked_softmax_bwd_p = \ register_primitive(ScaledUpperTriangMaskedSoftmaxBwdPrimitive) @@ -1979,6 +1981,27 @@ def scaled_upper_triang_masked_softmax_bwd(grad_outputs: jnp.ndarray, softmax_ou scale_factor=scale_factor) +def _check_seed(seed, dropout_probability, is_training): + # Jax can't bind None, create a dummy tensor for None + if seed is None: + dropout_enabled = dropout_probability > 0 and is_training + assert not dropout_enabled, "seed is not allowed to be None when dropout is enabled." + seed = jnp.zeros(2, dtype=jnp.uint32) + + if seed.dtype != jnp.uint32: + warnings.warn( + f"Requested {seed.dtype=} is not available, and will be " + f"casted to dtype uint32. 
" + f"Please use threefry/rbg/unsafe_rbg PRNG implementations to remove this warning.") + seed = seed.astype(jnp.uint32) + + assert seed.dtype == jnp.uint32 + # Only the first 2 u32 elements are taken + assert seed.size >= 2 + + return seed + + class SelfFusedAttnMax512FwdPrimitive(BasePrimitive): """ Self Fused Attention Max Seqlen 512 Forward Primitive @@ -1991,7 +2014,7 @@ def abstract( qkv, bias, cu_seqlen, # pylint: disable=unused-argument - rng_state, # pylint: disable=unused-argument + seed, # pylint: disable=unused-argument *, attn_bias_type, # pylint: disable=unused-argument attn_mask_type, # pylint: disable=unused-argument @@ -2020,8 +2043,8 @@ def abstract( ) @staticmethod - def lowering(ctx, qkv, bias, cu_seqlen, rng_state, *, attn_bias_type, attn_mask_type, - scaling_factor, dropout_probability, is_training): + def lowering(ctx, qkv, bias, cu_seqlen, seed, *, attn_bias_type, attn_mask_type, scaling_factor, + dropout_probability, is_training): """ Self fused attention max seqlen 512 fwd lowering rules """ @@ -2036,8 +2059,8 @@ def lowering(ctx, qkv, bias, cu_seqlen, rng_state, *, attn_bias_type, attn_mask_ ir_cu_seqlen_type = ir.RankedTensorType(cu_seqlen.type) ir_cu_seqlen_shape = ir_cu_seqlen_type.shape - ir_rng_state_type = ir.RankedTensorType(rng_state.type) - ir_rng_state_shape = ir_rng_state_type.shape + ir_seed_type = ir.RankedTensorType(seed.type) + ir_seed_shape = ir_seed_type.shape batch, max_seqlen, nqkv, num_head, head_dim = ir_qkv_shape assert nqkv == 3 @@ -2049,8 +2072,8 @@ def lowering(ctx, qkv, bias, cu_seqlen, rng_state, *, attn_bias_type, attn_mask_ ir.RankedTensorType.get(output_shape, ir_qkv_type.element_type), ir.RankedTensorType.get(softmax_aux_shape, ir_qkv_type.element_type) ] - operands = [qkv, bias, cu_seqlen, rng_state] - operand_shapes = [ir_qkv_shape, ir_bias_shape, ir_cu_seqlen_shape, ir_rng_state_shape] + operands = [qkv, bias, cu_seqlen, seed] + operand_shapes = [ir_qkv_shape, ir_bias_shape, ir_cu_seqlen_shape, ir_seed_shape] args = CustomCallArgsWrapper(out_types, operands, operand_shapes) opaque = transformer_engine_jax.pack_fused_attn_descriptor( @@ -2069,23 +2092,22 @@ def lowering(ctx, qkv, bias, cu_seqlen, rng_state, *, attn_bias_type, attn_mask_ def self_fused_attn_max_512_fwd(qkv: jnp.ndarray, bias: jnp.ndarray, cu_seqlen: jnp.ndarray, - rng_state: jnp.ndarray, attn_bias_type: NVTE_Bias_Type, + seed: jnp.ndarray, attn_bias_type: NVTE_Bias_Type, attn_mask_type: NVTE_Mask_Type, scaling_factor: float, dropout_probability: float, is_training: bool): """ Wrapper for TE self fused attention max seqlen 512 fwd Return BMM1 -> (PreBias) -> ScaleMaskSoftmax -> (PostBias) -> (Dropout) -> BMM2 """ - # Jax can't bind None, create a dummy tensor for None - if rng_state is None: - rng_state = jnp.zeros(2, dtype=jnp.int32) + seed = _check_seed(seed, dropout_probability, is_training) + if bias is None: assert attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS bias = jnp.zeros(0, dtype=qkv.dtype) return _self_fused_attn_max_512_fwd_p.bind(qkv, bias, cu_seqlen, - rng_state, + seed, attn_bias_type=attn_bias_type, attn_mask_type=attn_mask_type, scaling_factor=scaling_factor, @@ -2161,6 +2183,9 @@ def lowering(ctx, qkv, softmax_aux, doutput, cu_seqlen, *, attn_bias_type, attn_ operand_shapes = [ir_qkv_shape, ir_softmax_aux_shape, ir_doutput_shape, ir_cu_seqlen_shape] args = CustomCallArgsWrapper(out_types, operands, operand_shapes) + + # the dropout elements are encoded in the forward auxiliary tensor + # so seed is not needed in backward opaque = 
transformer_engine_jax.pack_fused_attn_descriptor( batch, num_head, max_seqlen, max_seqlen, head_dim, scaling_factor, dropout_probability, attn_bias_type, attn_mask_type, jax_dtype_to_te_dtype(qkv_aval.dtype), is_training) @@ -2208,7 +2233,7 @@ def abstract( kv, q_cu_seqlen, kv_cu_seqlen, - rng_state, # pylint: disable=unused-argument + seed, # pylint: disable=unused-argument *, attn_bias_type, # pylint: disable=unused-argument attn_mask_type, # pylint: disable=unused-argument @@ -2243,8 +2268,8 @@ def abstract( ) @staticmethod - def lowering(ctx, q, kv, q_cu_seqlen, kv_cu_seqlen, rng_state, *, attn_bias_type, - attn_mask_type, scaling_factor, dropout_probability, is_training): + def lowering(ctx, q, kv, q_cu_seqlen, kv_cu_seqlen, seed, *, attn_bias_type, attn_mask_type, + scaling_factor, dropout_probability, is_training): """ Cross fused attention max seqlen 512 fwd lowering rules """ @@ -2260,8 +2285,8 @@ def lowering(ctx, q, kv, q_cu_seqlen, kv_cu_seqlen, rng_state, *, attn_bias_type ir_q_cu_seqlen_shape = ir.RankedTensorType(q_cu_seqlen.type).shape ir_kv_cu_seqlen_shape = ir.RankedTensorType(kv_cu_seqlen.type).shape - ir_rng_state_type = ir.RankedTensorType(rng_state.type) - ir_rng_state_shape = ir_rng_state_type.shape + ir_seed_type = ir.RankedTensorType(seed.type) + ir_seed_shape = ir_seed_type.shape batch, q_max_seqlen, num_head, head_dim = ir_q_shape kv_max_seqlen = ir_kv_shape[1] @@ -2273,9 +2298,9 @@ def lowering(ctx, q, kv, q_cu_seqlen, kv_cu_seqlen, rng_state, *, attn_bias_type ir.RankedTensorType.get(output_shape, ir_q_type.element_type), ir.RankedTensorType.get(softmax_aux_shape, ir_q_type.element_type) ] - operands = [q, kv, q_cu_seqlen, kv_cu_seqlen, rng_state] + operands = [q, kv, q_cu_seqlen, kv_cu_seqlen, seed] operand_shapes = [ - ir_q_shape, ir_kv_shape, ir_q_cu_seqlen_shape, ir_kv_cu_seqlen_shape, ir_rng_state_shape + ir_q_shape, ir_kv_shape, ir_q_cu_seqlen_shape, ir_kv_cu_seqlen_shape, ir_seed_shape ] args = CustomCallArgsWrapper(out_types, operands, operand_shapes) @@ -2296,7 +2321,7 @@ def lowering(ctx, q, kv, q_cu_seqlen, kv_cu_seqlen, rng_state, *, attn_bias_type def cross_fused_attn_max_512_fwd(q: jnp.ndarray, kv: jnp.ndarray, q_cu_seqlen: jnp.ndarray, - kv_cu_seqlen: jnp.ndarray, rng_state: jnp.ndarray, + kv_cu_seqlen: jnp.ndarray, seed: jnp.ndarray, attn_bias_type: NVTE_Bias_Type, attn_mask_type: NVTE_Mask_Type, scaling_factor: float, dropout_probability: float, is_training: bool): @@ -2304,14 +2329,13 @@ def cross_fused_attn_max_512_fwd(q: jnp.ndarray, kv: jnp.ndarray, q_cu_seqlen: j Wrapper for TE cross fused attention max seqlen 512 fwd Return BMM1 -> (PreBias) -> ScaleMaskSoftmax -> (PostBias) -> (Dropout) -> BMM2 """ - # Jax can't bind None, create a dummy tensor for None - if rng_state is None: - rng_state = jnp.zeros(2, dtype=jnp.int32) + seed = _check_seed(seed, dropout_probability, is_training) + return _cross_fused_attn_max_512_fwd_p.bind(q, kv, q_cu_seqlen, kv_cu_seqlen, - rng_state, + seed, attn_bias_type=attn_bias_type, attn_mask_type=attn_mask_type, scaling_factor=scaling_factor, @@ -2391,6 +2415,9 @@ def lowering(ctx, q, kv, softmax_aux, doutput, q_cu_seqlen, kv_cu_seqlen, *, att ] args = CustomCallArgsWrapper(out_types, operands, operand_shapes) + + # the dropout elements are encoded in the forward auxiliary tensor + # so seed is not needed in backward opaque = transformer_engine_jax.pack_fused_attn_descriptor( batch, num_head, q_max_seqlen, kv_max_seqlen, head_dim, scaling_factor, dropout_probability, attn_bias_type, attn_mask_type, diff --git 
a/transformer_engine/jax/csrc/modules.cpp b/transformer_engine/jax/csrc/modules.cpp index b1c9d5d21a..d6d3caf4ba 100644 --- a/transformer_engine/jax/csrc/modules.cpp +++ b/transformer_engine/jax/csrc/modules.cpp @@ -749,7 +749,7 @@ void SelfFusedAttnMax512Forward(cudaStream_t stream, void **buffers, const char void *qkv = buffers[0]; void *bias = buffers[1]; void *cu_seqlens = buffers[2]; - void *rng_state = buffers[3]; + void *seed = buffers[3]; // output void *output = buffers[4]; @@ -778,30 +778,37 @@ void SelfFusedAttnMax512Forward(cudaStream_t stream, void **buffers, const char auto cu_seqlens_tensor = TensorWrapper(cu_seqlens, std::vector{batch + 1}, DType::kInt32); - auto rng_state_tensor = TensorWrapper(rng_state, std::vector{1}, DType::kInt64); + + auto dummy_rng_state_tensor = TensorWrapper(nullptr, std::vector{2}, DType::kInt64); NVTETensorPack aux_output_tensors; nvte_tensor_pack_create(&aux_output_tensors); TensorWrapper query_workspace_tensor; - nvte_fused_attn_fwd_qkvpacked(qkv_tensor.data(), bias_tensor.data(), s_tensor.data(), - o_tensor.data(), &aux_output_tensors, cu_seqlens_tensor.data(), - rng_state_tensor.data(), q_max_seqlen, descriptor.is_training, - descriptor.scaling_factor, descriptor.dropout_probability, - NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED, descriptor.bias_type, - descriptor.mask_type, query_workspace_tensor.data(), stream); + nvte_fused_attn_fwd_qkvpacked( + qkv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(), + &aux_output_tensors, cu_seqlens_tensor.data(), dummy_rng_state_tensor.data(), q_max_seqlen, + descriptor.is_training, descriptor.scaling_factor, descriptor.dropout_probability, + NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED, descriptor.bias_type, descriptor.mask_type, + query_workspace_tensor.data(), stream); auto *output_s = reinterpret_cast(aux_output_tensors.tensors[0]); output_s->data.dptr = softmax_aux; - size_t workspace_size = + // fused attn workspace + workspace for rng_state + auto plan_workspace_size = query_workspace_tensor.shape().data[0] * typeToSize(query_workspace_tensor.dtype()); - auto *workspace = cublasLtMetaManager::Instance().GetWorkspace(workspace_size); - + auto rng_workspace_size = 2 * sizeof(int64_t); + auto total_workspace_size = plan_workspace_size + rng_workspace_size; + auto *workspace = cublasLtMetaManager::Instance().GetWorkspace(total_workspace_size); auto workspace_tensor = TensorWrapper(workspace, query_workspace_tensor.shape(), query_workspace_tensor.dtype()); + auto rng_state = static_cast(workspace) + plan_workspace_size; + auto rng_state_tensor = TensorWrapper(rng_state, std::vector{2}, DType::kInt64); + PopulateRngStateAsync(rng_state, seed, q_max_seqlen, kv_max_seqlen, stream); + nvte_fused_attn_fwd_qkvpacked(qkv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(), &aux_output_tensors, cu_seqlens_tensor.data(), rng_state_tensor.data(), q_max_seqlen, descriptor.is_training, @@ -907,7 +914,7 @@ void CrossFusedAttnMax512Forward(cudaStream_t stream, void **buffers, const char void *kv = buffers[1]; void *q_cu_seqlens = buffers[2]; void *kv_cu_seqlens = buffers[3]; - void *rng_state = buffers[4]; + void *seed = buffers[4]; // output void *output = buffers[5]; @@ -939,7 +946,8 @@ void CrossFusedAttnMax512Forward(cudaStream_t stream, void **buffers, const char TensorWrapper(q_cu_seqlens, std::vector{batch + 1}, DType::kInt32); auto kv_cu_seqlens_tensor = TensorWrapper(kv_cu_seqlens, std::vector{batch + 1}, DType::kInt32); - auto rng_state_tensor = TensorWrapper(rng_state, std::vector{1}, 
DType::kInt64); + + auto dummy_rng_state_tensor = TensorWrapper(nullptr, std::vector{2}, DType::kInt64); NVTETensorPack aux_output_tensors; nvte_tensor_pack_create(&aux_output_tensors); @@ -949,7 +957,7 @@ void CrossFusedAttnMax512Forward(cudaStream_t stream, void **buffers, const char nvte_fused_attn_fwd_kvpacked( q_tensor.data(), kv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), - rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, descriptor.is_training, + dummy_rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, descriptor.is_training, descriptor.scaling_factor, descriptor.dropout_probability, NVTE_QKV_Layout::NVTE_KV_INTERLEAVED, descriptor.bias_type, descriptor.mask_type, query_workspace_tensor.data(), stream); @@ -957,13 +965,19 @@ void CrossFusedAttnMax512Forward(cudaStream_t stream, void **buffers, const char auto *output_s = reinterpret_cast(aux_output_tensors.tensors[0]); output_s->data.dptr = softmax_aux; - size_t workspace_size = + // fused attn workspace + workspace for rng_state + auto plan_workspace_size = query_workspace_tensor.shape().data[0] * typeToSize(query_workspace_tensor.dtype()); - auto *workspace = cublasLtMetaManager::Instance().GetWorkspace(workspace_size); - + auto rng_workspace_size = 2 * sizeof(int64_t); + auto total_workspace_size = plan_workspace_size + rng_workspace_size; + auto *workspace = cublasLtMetaManager::Instance().GetWorkspace(total_workspace_size); auto workspace_tensor = TensorWrapper(workspace, query_workspace_tensor.shape(), query_workspace_tensor.dtype()); + auto rng_state = static_cast(workspace) + plan_workspace_size; + auto rng_state_tensor = TensorWrapper(rng_state, std::vector{2}, DType::kInt64); + PopulateRngStateAsync(rng_state, seed, q_max_seqlen, kv_max_seqlen, stream); + nvte_fused_attn_fwd_kvpacked( q_tensor.data(), kv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), diff --git a/transformer_engine/jax/csrc/utils.cpp b/transformer_engine/jax/csrc/utils.cu similarity index 52% rename from transformer_engine/jax/csrc/utils.cpp rename to transformer_engine/jax/csrc/utils.cu index f8440e2625..0970076838 100644 --- a/transformer_engine/jax/csrc/utils.cpp +++ b/transformer_engine/jax/csrc/utils.cu @@ -32,5 +32,23 @@ int GetDeviceComputeCapability(int gpu_id) { return gpu_arch; } +__global__ void populate_rng_state_kernel(int64_t *rng_state_dst, const int64_t *const seed, + int64_t offset) { + int tid = blockIdx.x * blockDim.x + threadIdx.x; + if (tid > 0) return; + rng_state_dst[0] = seed[0]; + rng_state_dst[1] = offset; +} + +void PopulateRngStateAsync(void *rng_state_dst, const void *const seed, size_t q_max_seqlen, + size_t kv_max_seqlen, cudaStream_t stream) { + constexpr int threads_per_cta = 128; + const size_t increment = (q_max_seqlen * kv_max_seqlen + threads_per_cta - 1) / threads_per_cta; + auto offset = FusedAttnOffsetManager::Instance().GetAndUpdateOffset(increment); + populate_rng_state_kernel<<<1, 1, 0, stream>>>(reinterpret_cast(rng_state_dst), + reinterpret_cast(seed), offset); + NVTE_CHECK_CUDA(cudaGetLastError()); +} + } // namespace jax } // namespace transformer_engine diff --git a/transformer_engine/jax/csrc/utils.h b/transformer_engine/jax/csrc/utils.h index 448c6706c7..baa014d6cb 100644 --- a/transformer_engine/jax/csrc/utils.h +++ b/transformer_engine/jax/csrc/utils.h @@ -21,6 +21,9 @@ namespace jax { int 
GetCudaRuntimeVersion(); int GetDeviceComputeCapability(int gpu_id); +void PopulateRngStateAsync(void *rng_state_dst, const void *const seed, size_t q_max_seqlen, + size_t kv_max_seqlen, cudaStream_t stream); + class cublasLtMetaManager { public: static cublasLtMetaManager &Instance() { @@ -93,6 +96,27 @@ class cudaDevicePropertiesManager { cudaDeviceProp prop_; }; +class FusedAttnOffsetManager { + public: + static FusedAttnOffsetManager &Instance() { + static thread_local FusedAttnOffsetManager instance; + return instance; + } + + size_t GetAndUpdateOffset(size_t increment) { + size_t ret = offset_; + offset_ += increment; + return ret; + } + + FusedAttnOffsetManager(FusedAttnOffsetManager const &) = delete; + void operator=(FusedAttnOffsetManager const &) = delete; + + private: + FusedAttnOffsetManager() {} + size_t offset_ = 0; +}; + } // namespace jax } // namespace transformer_engine diff --git a/transformer_engine/jax/flax/transformer.py b/transformer_engine/jax/flax/transformer.py index 563b15d526..14ad7f02e8 100644 --- a/transformer_engine/jax/flax/transformer.py +++ b/transformer_engine/jax/flax/transformer.py @@ -11,6 +11,7 @@ from typing import Any, Callable, Optional, Sequence, Tuple, Union import warnings +import jax import jax.numpy as jnp import numpy as np from flax import linen as nn @@ -182,9 +183,8 @@ def core_attention(query: Array, if not deterministic and dropout_rate > 0.: keep_prob = 1.0 - dropout_rate dropout_shape = list(attn_weights.shape) - dropout_shape[-2] = 1 + # TODO(rewang): add attention dropout broadcast dimension arguments for users keep = jax_random.bernoulli(dropout_rng, keep_prob, dropout_shape) - keep = jnp.broadcast_to(keep, attn_weights.shape) multiplier = (keep.astype(attn_weights.dtype) / jnp.asarray(keep_prob, dtype=dtype)) attn_weights = attn_weights * multiplier @@ -384,7 +384,7 @@ def kv_init(key, shape, dtype): fused_attn_supported_seqlen = [128, 256, 384, 512] enable_fused_attn = int(os.getenv("NVTE_FUSED_ATTN", "0")) use_fused_attn = not decode and not self.transpose_batch_sequence and self.fuse_qkv and \ - self.dropout_rate == 0 and canonicalize_dtype in [jnp.bfloat16, jnp.float16] and \ + canonicalize_dtype in [jnp.bfloat16, jnp.float16] and \ q_seqlen in fused_attn_supported_seqlen and kv_seqlen in fused_attn_supported_seqlen \ and is_fused_attn_kernel_available() and (self.head_dim == 64) and enable_fused_attn @@ -397,9 +397,6 @@ def kv_init(key, shape, dtype): f"but got {self.transpose_batch_sequence}, " if not self.fuse_qkv: reason += f"fuse_qkv=True is required but got {self.fuse_qkv}, " - if self.dropout_rate != 0: - # TODO(rewang): add dropout support - reason += f"no dropout is required but got dropout_rate={self.dropout_rate}, " if canonicalize_dtype not in [jnp.bfloat16, jnp.float16]: reason += f"dtype in [BF16, FP16] is required " \ f"but got dtype={canonicalize_dtype}, " @@ -583,6 +580,12 @@ def kv_init(key, shape, dtype): assert mask is not None and mask.ndim == 4 # (b, 1, s_q, s_kv) assert not self.transpose_batch_sequence + seed = None + if dropout_rng is not None: + seed = jax.random.split(dropout_rng, len(jax.devices())) + # ensure the old key never used + del dropout_rng + # TODO(rewang): make it configurable for pre_scale_bias attn_bias_type = AttnBiasType.NO_BIAS if bias is None else AttnBiasType.POST_SCALE_BIAS @@ -607,7 +610,7 @@ def canonicalize_attn_mask_type(attn_mask_type): x = self_fused_attn(qkv_proj, bias, mask, - dropout_rng, + seed, attn_bias_type=attn_bias_type, attn_mask_type=attn_mask_type, 
scaling_factor=scale_factor, @@ -626,7 +629,7 @@ def canonicalize_attn_mask_type(attn_mask_type): x = cross_fused_attn(query, kv_proj, mask, - dropout_rng, + seed, attn_bias_type=attn_bias_type, attn_mask_type=attn_mask_type, scaling_factor=scale_factor, diff --git a/transformer_engine/jax/fused_attn.py b/transformer_engine/jax/fused_attn.py index 3eb516e3bb..ce34ca2670 100644 --- a/transformer_engine/jax/fused_attn.py +++ b/transformer_engine/jax/fused_attn.py @@ -46,7 +46,7 @@ class AttnMaskType(Enum): def self_fused_attn(qkv: jnp.ndarray, bias: jnp.ndarray, mask: jnp.ndarray, - rng_state: jnp.ndarray, + seed: jnp.ndarray, attn_bias_type: AttnBiasType, attn_mask_type: AttnMaskType, scaling_factor: float, @@ -63,7 +63,7 @@ def self_fused_attn(qkv: jnp.ndarray, output = _self_fused_attn_max_512(qkv, bias, mask, - rng_state, + seed, attn_bias_type=attn_bias_type, attn_mask_type=attn_mask_type, scaling_factor=scaling_factor, @@ -73,13 +73,13 @@ def self_fused_attn(qkv: jnp.ndarray, dp_axis_name = "batch" tp_axis_name = "model" - inputs = [qkv, bias, mask, rng_state] + inputs = [qkv, bias, mask, seed] batch, seqlen, _, num_head, head_dim = qkv.shape output_shape = [batch, seqlen, num_head, head_dim] sharding_meta = get_fused_attn_sharding_meta( sharding_type, [x.shape if x is not None else None for x in inputs], [output_shape], - dp_dims=([0, None, 0, None], [0]), - tp_dims=([3, 1, None, None], [2]), + dp_dims=([0, None, 0, 0], [0]), + tp_dims=([3, 1, None, 0], [2]), dp_axis_name=dp_axis_name, tp_axis_name=tp_axis_name) @@ -104,13 +104,13 @@ def self_fused_attn(qkv: jnp.ndarray, @partial(jax.custom_vjp, nondiff_argnums=(4, 5, 6, 7, 8)) def _self_fused_attn_max_512(qkv: jnp.ndarray, bias: jnp.ndarray, mask: jnp.ndarray, - rng_state: jnp.ndarray, attn_bias_type: AttnBiasType, + seed: jnp.ndarray, attn_bias_type: AttnBiasType, attn_mask_type: AttnMaskType, scaling_factor: float, dropout_probability: float, is_training: bool): output, _ = _self_fused_attn_max_512_fwd(qkv, bias, mask, - rng_state, + seed, attn_bias_type=attn_bias_type, attn_mask_type=attn_mask_type, scaling_factor=scaling_factor, @@ -119,7 +119,7 @@ def _self_fused_attn_max_512(qkv: jnp.ndarray, bias: jnp.ndarray, mask: jnp.ndar return output -def _self_fused_attn_max_512_fwd(qkv, bias, mask, rng_state, attn_bias_type, attn_mask_type, +def _self_fused_attn_max_512_fwd(qkv, bias, mask, seed, attn_bias_type, attn_mask_type, scaling_factor, dropout_probability, is_training): seqlen = jnp.sum(mask[:, :, :, 0] == 0, axis=(-1, -2), dtype=jnp.int32) @@ -129,7 +129,7 @@ def _self_fused_attn_max_512_fwd(qkv, bias, mask, rng_state, attn_bias_type, att output, softmax_aux = self_fused_attn_max_512_fwd(qkv, bias, cu_seqlen, - rng_state, + seed, attn_bias_type=attn_bias_type.value, attn_mask_type=attn_mask_type.value, scaling_factor=scaling_factor, @@ -163,7 +163,7 @@ def _self_fused_attn_max_512_bwd(attn_bias_type, attn_mask_type, scaling_factor, def cross_fused_attn(q: jnp.ndarray, kv: jnp.ndarray, mask: jnp.ndarray, - rng_state: jnp.ndarray, + seed: jnp.ndarray, attn_bias_type: AttnBiasType, attn_mask_type: AttnMaskType, scaling_factor: float, @@ -180,7 +180,7 @@ def cross_fused_attn(q: jnp.ndarray, output = _cross_fused_attn_max_512(q, kv, mask, - rng_state, + seed, attn_bias_type=attn_bias_type, attn_mask_type=attn_mask_type, scaling_factor=scaling_factor, @@ -190,7 +190,7 @@ def cross_fused_attn(q: jnp.ndarray, dp_axis_name = "batch" tp_axis_name = "model" - inputs = [q, kv, mask, rng_state] + inputs = [q, kv, mask, seed] output_shape = 
q.shape sharding_meta = get_fused_attn_sharding_meta( sharding_type, [x.shape if x is not None else None for x in inputs], [output_shape], @@ -219,15 +219,14 @@ def cross_fused_attn(q: jnp.ndarray, @partial(jax.custom_vjp, nondiff_argnums=(4, 5, 6, 7, 8)) -def _cross_fused_attn_max_512(q: jnp.ndarray, kv: jnp.ndarray, mask: jnp.ndarray, - rng_state: jnp.ndarray, attn_bias_type: AttnBiasType, - attn_mask_type: AttnMaskType, scaling_factor: float, - dropout_probability: float, is_training: bool): +def _cross_fused_attn_max_512(q: jnp.ndarray, kv: jnp.ndarray, mask: jnp.ndarray, seed: jnp.ndarray, + attn_bias_type: AttnBiasType, attn_mask_type: AttnMaskType, + scaling_factor: float, dropout_probability: float, is_training: bool): output, _ = _cross_fused_attn_max_512_fwd(q, kv, mask, - rng_state, + seed, attn_bias_type=attn_bias_type, attn_mask_type=attn_mask_type, scaling_factor=scaling_factor, @@ -236,8 +235,8 @@ def _cross_fused_attn_max_512(q: jnp.ndarray, kv: jnp.ndarray, mask: jnp.ndarray return output -def _cross_fused_attn_max_512_fwd(q, kv, mask, rng_state, attn_bias_type, attn_mask_type, - scaling_factor, dropout_probability, is_training): +def _cross_fused_attn_max_512_fwd(q, kv, mask, seed, attn_bias_type, attn_mask_type, scaling_factor, + dropout_probability, is_training): q_seqlen = jnp.sum(mask[:, :, :, 0] == 0, axis=(-1, -2), dtype=jnp.int32) q_cu_seqlen = jnp.cumsum(q_seqlen) @@ -251,7 +250,7 @@ def _cross_fused_attn_max_512_fwd(q, kv, mask, rng_state, attn_bias_type, attn_m kv, q_cu_seqlen, kv_cu_seqlen, - rng_state, + seed, attn_bias_type=attn_bias_type.value, attn_mask_type=attn_mask_type.value, scaling_factor=scaling_factor, From 92eabc339e159c50cda00fdd2de356ed43aba115 Mon Sep 17 00:00:00 2001 From: cyanguwa <8636796+cyanguwa@users.noreply.github.com> Date: Thu, 22 Jun 2023 11:41:36 -0700 Subject: [PATCH 34/68] Add long sequence support for fused attention (#237) * add long sequence support and unify three backends for fused attention Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * update cudnn-frontend to v0.9.1 Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * replace cpu_float2half_rn with __float2half_rn Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix backend selection and NVTEDType Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * minor fixes Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix ci Signed-off-by: Kirthi Shankar Sivamani * make cudnn plan caches thread_local Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix CI Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * replace cuDNN throw with NVTE_CHECK Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix replacement of cuDNN throw with NVTE_CHECK Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * force dropout probablity to 0 in inference mode Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * change negInfinity to be consistent with m512 fused attn Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * remove float2half conversion for scale_dropout Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * add back runtime api for sm detection Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * add gemm3 to enums FP8Fwd/BwdTensors Signed-off-by: Charlene Yang 
<8636796+cyanguwa@users.noreply.github.com> * change dropout from no to yes for fmha_v1 Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * remove output_rng_state in m512 kernels Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix elts_per_thread calculation in kvpacked fwd Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * remove dropout=0.0 restriction for m512 fused attn Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * remove output_rng_state completely from m512 kernels Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --------- Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Signed-off-by: Kirthi Shankar Sivamani Co-authored-by: Kirthi Shankar Sivamani --- 3rdparty/cudnn-frontend | 2 +- tests/pytorch/test_fused_attn.py | 626 ++++++++ transformer_engine/common/CMakeLists.txt | 3 +- .../common/fused_attn/fused_attn.cpp | 355 +++-- .../fused_attn_f16_arbitrary_seqlen.cu | 1304 +++++++++++++++++ .../fused_attn_f16_arbitrary_seqlen.h | 44 + ...512.cu => fused_attn_f16_max512_seqlen.cu} | 46 +- ...n_512.h => fused_attn_f16_max512_seqlen.h} | 8 +- .../common/fused_attn/fused_attn_fp8.cu | 34 +- .../common/fused_attn/fused_attn_fp8.h | 6 +- transformer_engine/common/fused_attn/utils.cu | 1 - .../include/transformer_engine/fused_attn.h | 224 +-- transformer_engine/pytorch/attention.py | 410 +++++- transformer_engine/pytorch/constants.py | 2 +- .../pytorch/cpp_extensions/fused_attn.py | 470 +++--- transformer_engine/pytorch/csrc/common.h | 12 +- transformer_engine/pytorch/csrc/extensions.cu | 198 ++- transformer_engine/pytorch/csrc/extensions.h | 51 +- transformer_engine/pytorch/transformer.py | 15 + 19 files changed, 3172 insertions(+), 639 deletions(-) create mode 100644 tests/pytorch/test_fused_attn.py create mode 100644 transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu create mode 100644 transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h rename transformer_engine/common/fused_attn/{fused_attn_fp16_bf16_max_seqlen_512.cu => fused_attn_f16_max512_seqlen.cu} (98%) rename transformer_engine/common/fused_attn/{fused_attn_fp16_bf16_max_seqlen_512.h => fused_attn_f16_max512_seqlen.h} (91%) diff --git a/3rdparty/cudnn-frontend b/3rdparty/cudnn-frontend index e7f64390e9..a4f05c1edc 160000 --- a/3rdparty/cudnn-frontend +++ b/3rdparty/cudnn-frontend @@ -1 +1 @@ -Subproject commit e7f64390e9bb4a3db622ffe11c973834f572b609 +Subproject commit a4f05c1edcef453f5fd52f96218c29c7d420e511 diff --git a/tests/pytorch/test_fused_attn.py b/tests/pytorch/test_fused_attn.py new file mode 100644 index 0000000000..831c2d7c79 --- /dev/null +++ b/tests/pytorch/test_fused_attn.py @@ -0,0 +1,626 @@ +# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. 
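The new test below exercises one DotProductAttention module under three backends by toggling the NVTE_FLASH_ATTN and NVTE_FUSED_ATTN environment variables before each run; leaving both at "0" falls back to the unfused path. A condensed sketch of that selection logic (hypothetical helper name, same flags as the test):

import os

def select_attention_backend(backend: str) -> None:
    # Both flags default to off; at most one is switched on per run.
    os.environ["NVTE_FLASH_ATTN"] = "0"
    os.environ["NVTE_FUSED_ATTN"] = "0"
    if backend == "FlashAttention":
        os.environ["NVTE_FLASH_ATTN"] = "1"
    elif backend == "FusedAttention":
        os.environ["NVTE_FUSED_ATTN"] = "1"
    # "UnfusedDotProductAttention" keeps both flags at "0".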
+ +import torch +import pytest + +from transformer_engine.pytorch.utils import ( + init_method_normal, + scaled_init_method_normal, +) +from transformer_engine.pytorch import TransformerLayer +from transformer_engine.pytorch.attention import DotProductAttention +import os + +class ModelConfig: + def __init__( + self, num_layers, hidden_size, num_attention_heads, head_dim, seq_len, + dropout_p, attn_mask_type, + ): + self.num_layers = num_layers + self.hidden_size = hidden_size + self.num_attention_heads = num_attention_heads + self.head_dim = head_dim + assert (hidden_size == num_attention_heads * head_dim + ), """hidden_size must be = num_heads x head_dim.""" + self.seq_len = seq_len + self.dropout_p = dropout_p + self.attn_mask_type = attn_mask_type + +model_configs = { + "test1": ModelConfig(1, 1024, 16, 64, 128, 0.0, "causal"), + "test2": ModelConfig(1, 1024, 16, 64, 2048, 0.0, "causal"), + "test3": ModelConfig(1, 2048, 16, 128, 128, 0.0, "causal"), + "test4": ModelConfig(1, 2048, 16, 128, 2048, 0.0, "causal"), + "test5": ModelConfig(1, 1024, 16, 64, 128, 0.0, "no_mask"), +} + +param_types = [torch.float16] +if torch.cuda.is_bf16_supported(): + param_types.append(torch.bfloat16) + +batch_sizes = [1, 2] + +@pytest.mark.parametrize("dtype", param_types) +@pytest.mark.parametrize("bs", batch_sizes) +@pytest.mark.parametrize("model", model_configs.keys()) +def test_dot_product_attention(dtype, bs, model): + """Test DotProductAttention module with three backends, + FlashAttention, FusedAttention and UnfusedDotProductAttention""" + + config = model_configs[model] + + flash_attn_fwd, flash_attn_bwd = _run_dot_product_attention( + dtype, bs, config, "FlashAttention") + fused_attn_fwd, fused_attn_bwd = _run_dot_product_attention( + dtype, bs, config, "FusedAttention") + unfused_attn_fwd, unfused_attn_bwd = _run_dot_product_attention( + dtype, bs, config, "UnfusedDotProductAttention") + + atol, rtol = (2.5e-2, 2.5e-2) if dtype == torch.bfloat16 else (2.5e-3, 2.5e-3) + assert torch.allclose(fused_attn_fwd, flash_attn_fwd, atol = atol, rtol = rtol) + assert torch.allclose(fused_attn_bwd, flash_attn_bwd, atol = atol, rtol = rtol) + assert torch.allclose(fused_attn_fwd, unfused_attn_fwd, atol = atol, rtol = rtol) + assert torch.allclose(fused_attn_bwd, unfused_attn_bwd, atol = atol, rtol = rtol) + +def _run_dot_product_attention(dtype, bs, config, backend): + + torch.manual_seed(1234) + torch.cuda.manual_seed(1234) + os.environ["NVTE_FLASH_ATTN"] = "0" + os.environ["NVTE_FUSED_ATTN"] = "0" + if backend == "FlashAttention": + os.environ["NVTE_FLASH_ATTN"] = "1" + if backend == "FusedAttention": + os.environ["NVTE_FUSED_ATTN"] = "1" + + inp = 0.1 * torch.randn( + config.seq_len, bs, 3, config.num_attention_heads, config.head_dim, + dtype = dtype).cuda() + inp.requires_grad=True + seqlens = torch.empty(bs, dtype = torch.int32).cuda() + seqlens.fill_(config.seq_len) + cu_seqlens = torch.zeros(bs + 1, device = inp.device, dtype = torch.int32) + cu_seqlens[1:] = torch.cumsum(seqlens, dim = 0) + op_grad = 0.001 * torch.randint(0, 200, ( + config.seq_len, bs, config.num_attention_heads * config.head_dim + ), dtype = dtype).cuda() + + block = ( + DotProductAttention( + config.num_attention_heads, + config.head_dim, + attention_dropout = config.dropout_p, + attn_mask_type = config.attn_mask_type, + sequence_parallel = False, + tp_size = 1, + get_rng_state_tracker = None, + tp_group = None, + layer_number = 1, + attention_type = "self" + ).to(dtype = dtype).cuda() + ) + + q = inp[:, :,0,:,:] + k = inp[:, 
:,1,:,:] + v = inp[:, :,2,:,:] + op = block(q, k, v) + op.backward(op_grad) + + return op, inp.grad + +@pytest.mark.parametrize("dtype", param_types) +@pytest.mark.parametrize("bs", batch_sizes) +@pytest.mark.parametrize("model", model_configs.keys()) +def test_transformer_layer(dtype, bs, model): + """Test TransformerLayer module when its DotProductAttention is enabled with + FlashAttention, FusedAttention, or UnfusedDotProductAttention backend""" + + config = model_configs[model] + + flash_attn_fwd, flash_attn_bwd = _run_transformer_layer( + dtype, bs, config, "FlashAttention") + fused_attn_fwd, fused_attn_bwd = _run_transformer_layer( + dtype, bs, config, "FusedAttention") + unfused_attn_fwd, unfused_attn_bwd = _run_transformer_layer( + dtype, bs, config, "UnfusedDotProductAttention") + + atol, rtol = (5e-1, 5e-1) if dtype == torch.bfloat16 else (5e-1, 5e-1) + assert torch.allclose(fused_attn_fwd, flash_attn_fwd, atol = atol, rtol = rtol) + assert torch.allclose(fused_attn_bwd, flash_attn_bwd, atol = atol, rtol = rtol) + assert torch.allclose(fused_attn_fwd, unfused_attn_fwd, atol = atol, rtol = rtol) + assert torch.allclose(fused_attn_bwd, unfused_attn_bwd, atol = atol, rtol = rtol) + +def _run_transformer_layer(dtype, bs, config, backend): + + torch.manual_seed(1234) + torch.cuda.manual_seed(1234) + os.environ["NVTE_FLASH_ATTN"] = "0" + os.environ["NVTE_FUSED_ATTN"] = "0" + if backend == "FlashAttention": + os.environ["NVTE_FLASH_ATTN"] = "1" + if backend == "FusedAttention": + os.environ["NVTE_FUSED_ATTN"] = "1" + + inp = 0.1 * torch.randn( + config.seq_len, bs, config.num_attention_heads * config.head_dim, + dtype = dtype).cuda() + inp.requires_grad=True + seqlens = torch.empty(bs, dtype = torch.int32).cuda() + seqlens.fill_(config.seq_len) + cu_seqlens = torch.zeros(bs + 1, device = inp.device, dtype = torch.int32) + cu_seqlens[1:] = torch.cumsum(seqlens, dim = 0) + op_grad = 0.001 * torch.randint(0, 200, ( + config.seq_len, bs, config.num_attention_heads * config.head_dim + ), dtype = dtype).cuda() + + sigma = 0.02 + init_method = init_method_normal(sigma) + output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers) + + layer_number = 1 + drop_path_rate = 0.0 + drop_path_rates = [ + rate.item() for rate in torch.linspace(0, drop_path_rate, config.num_layers)] + + block = ( + TransformerLayer( + config.hidden_size, + 4 * config.hidden_size, + config.num_attention_heads, + layernorm_epsilon = 1e-5, + hidden_dropout = 0.0, + attention_dropout = config.dropout_p, + init_method = init_method, + output_layer_init_method = output_layer_init_method, + layer_number = layer_number, + kv_channels = config.head_dim, + self_attn_mask_type = config.attn_mask_type, + tp_group = None, + tp_size = 1, + params_dtype = dtype, + get_rng_state_tracker = None, + fuse_wgrad_accumulation = False, + seq_length = config.seq_len, + micro_batch_size = bs, + sequence_parallel = False, + apply_residual_connection_post_layernorm = False, + output_layernorm = False, + layer_type = "encoder", + drop_path_rate = drop_path_rates[layer_number - 1], + set_parallel_mode = True, + fuse_qkv_params = True, + zero_centered_gamma = False, + qkv_weight_interleaved = False, + ub_tp_comm_overlap = False, + bias = True, + ) + .to(dtype = dtype) + .cuda() + ) + + op = block(inp) + op.backward(op_grad) + + return op, inp.grad + +model_configs_fp8 = { + "test1": ModelConfig(1, 1024, 16, 64, 512, 0.0, "no_mask"), +} +batch_sizes_fp8 = [1, 4] +param_types_fp8 = [torch.float16] + 
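The helpers above repeatedly build cumulative sequence lengths from a fixed per-batch length and slice the packed QKV tensor along its third dimension. A minimal sketch of both patterns (illustrative shapes; the tests place these tensors on the GPU):

import torch

def make_cu_seqlens(batch_size: int, seq_len: int, device: str = "cpu") -> torch.Tensor:
    seqlens = torch.full((batch_size,), seq_len, dtype=torch.int32, device=device)
    cu_seqlens = torch.zeros(batch_size + 1, dtype=torch.int32, device=device)
    cu_seqlens[1:] = torch.cumsum(seqlens, dim=0)
    return cu_seqlens  # e.g. [0, s, 2*s, ...]

# Packed QKV of shape (seq_len, batch, 3, num_heads, head_dim) is split into
# the three attention inputs along dim 2.
qkv = torch.randn(128, 2, 3, 16, 64)
q, k, v = qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2]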
+@pytest.mark.parametrize("dtype", param_types_fp8) +@pytest.mark.parametrize("bs", batch_sizes_fp8) +@pytest.mark.parametrize("model", model_configs_fp8.keys()) +def test_dpa_fp8(dtype, bs, model): + """Test DotProductAttention module with FP8, + using cpp_extensions import fused_attn_fwd/bwd_qkvpacked and UnfusedDotProductAttention""" + + config = model_configs_fp8[model] + + fused_attn_fwd, fused_attn_bwd = _run_dpa_fp8( + dtype, bs, config, "FusedAttention") + unfused_attn_fwd, unfused_attn_bwd = _run_dpa_fp8_ref( + dtype, bs, config, "UnfusedDotProductAttention") + + atol, rtol = (5e-2, 1e-1) + assert torch.allclose(fused_attn_fwd, unfused_attn_fwd, atol = atol, rtol = rtol) + assert torch.allclose(fused_attn_bwd, unfused_attn_bwd, atol = atol, rtol = rtol) + +def _run_dpa_fp8(dtype, bs, config, backend): + + torch.manual_seed(1234) + torch.cuda.manual_seed(1234) + os.environ["NVTE_FLASH_ATTN"] = "0" + os.environ["NVTE_FUSED_ATTN"] = "0" + + inp = 0.01 * torch.randn( + bs * config.seq_len, config.num_attention_heads * config.head_dim, + dtype = dtype).cuda() + inp.requires_grad=True + seqlens = torch.empty(bs, dtype = torch.int32).cuda() + seqlens.fill_(config.seq_len) + cu_seqlens = torch.zeros(bs + 1, device = inp.device, dtype = torch.int32) + cu_seqlens[1:] = torch.cumsum(seqlens, dim = 0) + op_grad = 0.001 * torch.randint(0, 200, ( + bs * config.seq_len, config.num_attention_heads * config.head_dim + ), dtype = dtype).cuda() + torch.save(op_grad, 'op_grad.pt') + + fp8_recipe = recipe.DelayedScaling( + margin=0, + interval=1, + fp8_format=recipe.Format.HYBRID, + amax_history_len=1, + amax_compute_algo="most_recent", + ) + + dpa = DPA_FP8(config).to(dtype = torch.float16).cuda() + with fp8_autocast(enabled=True, fp8_recipe=fp8_recipe): + op = dpa(inp, cu_seqlens, config.seq_len) + op.backward(op_grad) + + context = torch.load("ctx.pt") + dqkv = torch.load('dqkv.pt') + return (context.view(bs, config.seq_len, -1).transpose(0,1), + dqkv.view(bs, config.seq_len, 3, config.num_attention_heads, config.head_dim).transpose(0,1).contiguous()) + +def _run_dpa_fp8_ref(dtype, bs, config, backend): + + os.environ["NVTE_FLASH_ATTN"] = "0" + os.environ["NVTE_FUSED_ATTN"] = "0" + if backend == "FlashAttention": + os.environ["NVTE_FLASH_ATTN"] = "1" + if backend == "FusedAttention": + os.environ["NVTE_FUSED_ATTN"] = "1" + + inp = torch.load('qkv.pt').cuda() + inp.requires_grad=True + seqlens = torch.empty(bs, dtype = torch.int32).cuda() + seqlens.fill_(config.seq_len) + cu_seqlens = torch.zeros(bs + 1, device = inp.device, dtype = torch.int32) + cu_seqlens[1:] = torch.cumsum(seqlens, dim = 0) + op_grad = torch.load('op_grad.pt').cuda().view(bs, config.seq_len, -1).transpose(0,1) + + block = ( + DotProductAttention( + config.num_attention_heads, + config.head_dim, + attention_dropout = config.dropout_p, + attn_mask_type = config.attn_mask_type, + sequence_parallel = False, + tp_size = 1, + get_rng_state_tracker = None, + tp_group = None, + layer_number = 1, + attention_type = "self" + ).to(dtype = dtype).cuda() + ) + + q = inp[:, :,0,:,:] + k = inp[:, :,1,:,:] + v = inp[:, :,2,:,:] + op = block(q, k, v) + op.backward(op_grad) + torch.save(op,'ctx_ref.pt') + torch.save(inp.grad,'dqkv_ref.pt') + + return op, inp.grad + +from torch.nn.parameter import Parameter +import transformer_engine.pytorch.cpp_extensions as ext +import transformer_engine_extensions as tex +import transformer_engine.pytorch.fp8 as fp8 +from transformer_engine.pytorch import fp8_autocast +from 
transformer_engine.pytorch.module.base import TransformerEngineBaseModule, _prepare_backward +from transformer_engine.common import recipe +from typing import Union, Dict, Any, Tuple, List +from transformer_engine.pytorch.cpp_extensions.fused_attn import ( + fused_attn_fwd_qkvpacked, + fused_attn_bwd_qkvpacked, + FusedAttnBackend) + +_CUBLASLT_WORKSPACE_SIZE_BYTES = 33_554_432 # 32MiB +_2X_ACC_FPROP = False +_2X_ACC_DGRAD = False +_2X_ACC_WGRAD = False + +META_QKV = tex.FP8FwdTensors.GEMM1_OUTPUT +META_O = tex.FP8FwdTensors.GEMM2_INPUT +META_DO = tex.FP8BwdTensors.GRAD_INPUT2 +META_DQKV = tex.FP8BwdTensors.GRAD_OUTPUT1 + +META_S = tex.FP8FwdTensors.GEMM3_WEIGHT +META_DS = tex.FP8BwdTensors.GRAD_INPUT3 + +class _dpa_fp8(torch.autograd.Function): + @staticmethod + def forward( + ctx, + inp: torch.Tensor, + qkv_weight: torch.Tensor, + qkv_bias: torch.Tensor, + cu_seqlens: torch.Tensor, + num_attention_heads: int, + p_dropout: float, + max_s: int, + fast_zero_fill: bool, + fp8_meta: Dict[str, Any], + workspace: torch.Tensor, + is_training: bool, + ) -> torch.Tensor: + + assert inp.dim() == 2 + in_features = qkv_weight.shape[-1] + h = num_attention_heads + d = in_features // h + b = cu_seqlens.numel() - 1 + is_nl = False + if b < 4 and b > 1: + max_s = 512 + is_nl = True + + fp8_dtype_forward = fp8.get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) + + inputmat, inputmat_t = ext.fp8_cast_transpose_fused( + inp, + fp8_meta["scaling_fwd"], + tex.FP8FwdTensors.GEMM1_INPUT, + fp8_dtype_forward, + ) + + qkv_weight_fp8, qkv_weight_t_fp8 = ext.fp8_cast_transpose_fused( + qkv_weight, + fp8_meta["scaling_fwd"], + tex.FP8FwdTensors.GEMM1_WEIGHT, + fp8_dtype_forward, + ) + + M = None + ZInv = None + philox_unpacked = None + + qkv_out = ext.fp8_gemm( + qkv_weight_fp8, + fp8_meta["scaling_fwd"].scale_inv, + tex.FP8FwdTensors.GEMM1_WEIGHT, + fp8_dtype_forward, + inputmat, + fp8_meta["scaling_fwd"].scale_inv, + tex.FP8FwdTensors.GEMM1_INPUT, + fp8_dtype_forward, + torch.uint8, + workspace, + bias=qkv_bias, + use_bias=True, + out_index = META_QKV, + fp8_meta_tensor = fp8_meta["scaling_fwd"], + use_split_accumulator=_2X_ACC_FPROP, + D_dtype=fp8_dtype_forward, + ) + qkv_out = qkv_out.view(-1, 3, h, d) + qkv_out_fp16 = ext.cast_from_fp8(qkv_out, fp8_meta["scaling_fwd"], + META_QKV, fp8_dtype_forward, + tex.DType.kFloat16).view(b, max_s, 3, h, d).transpose(0,1).contiguous() + torch.save(qkv_out_fp16, 'qkv.pt') + + # FMHA + context_, aux_ctx_tensors, *rest = fused_attn_fwd_qkvpacked( + is_training, + max_s, + cu_seqlens, + qkv_out, + fp8_dtype_forward, + FusedAttnBackend["FP8"], + None, + fp8_meta["scaling_fwd"].scale_inv[META_QKV], + fp8_meta["scaling_fwd"].scale[META_S], + fp8_meta["scaling_fwd"].scale[META_O], + fp8_meta["scaling_fwd"].amax_history[0][META_S], + fp8_meta["scaling_fwd"].amax_history[0][META_O], + attn_scale = None, + dropout = p_dropout, + fast_zero_fill = fast_zero_fill, + qkv_layout = "qkv_interleaved", + attn_bias_type = "no_bias", + attn_mask_type = "padding", + rng_gen = None, + ) + M, ZInv, philox_unpacked = aux_ctx_tensors + + context = context_.view(-1, in_features) + context_t = tex.fp8_transpose(context, fp8_dtype_forward) + + ctx.save_for_backward( + inputmat_t, qkv_weight_t_fp8, workspace, + qkv_out, + context_, context_t, + fp8_meta["scaling_fwd"].scale, + fp8_meta["scaling_fwd"].scale_inv, + ) + ctx.aux_ctx_tensors = aux_ctx_tensors + ctx.fp8_meta = fp8_meta + ctx.cu_seqlens = cu_seqlens + ctx.p_dropout = p_dropout + ctx.max_s = max_s + ctx.fast_zero_fill = fast_zero_fill + 
ctx.is_nl = is_nl + ctx.hidden_size = in_features + ctx.num_attention_heads = num_attention_heads + + context_fp16 = ext.cast_from_fp8(context, fp8_meta["scaling_fwd"], + META_O, fp8_dtype_forward, tex.DType.kFloat16) + torch.save(context_fp16, 'ctx.pt') + return context_fp16 + + + @staticmethod + def backward( + ctx, grad_output: torch.Tensor + ) -> Tuple[Union[torch.Tensor, None], ...]: + + with _prepare_backward(True, ctx.fp8_meta, None, 1, name="_DPA"): + ( + inputmat_t, + qkv_weight_t_fp8, + workspace, + qkv_out, + context, context_t, + fwd_scales, + fwd_scale_inverses, + ) = ctx.saved_tensors + fp8_dtype_forward = fp8.get_fp8_te_dtype( + ctx.fp8_meta["recipe"], fprop_tensor=True + ) + fp8_dtype_backward = fp8.get_fp8_te_dtype( + ctx.fp8_meta["recipe"], fprop_tensor=False + ) + + proj_dgrad = ext.cast_to_fp8( + grad_output, ctx.fp8_meta["scaling_bwd"], META_DO, fp8_dtype_backward + ) + + dqkv, *rest = fused_attn_bwd_qkvpacked( + ctx.max_s, + ctx.cu_seqlens, + qkv_out, + context, + proj_dgrad.view_as(context), + fp8_dtype_forward, + ctx.aux_ctx_tensors, + FusedAttnBackend["FP8"], + fwd_scale_inverses[META_QKV], # d_scale_qkv, + fwd_scale_inverses[META_S], # d_scale_s, + fwd_scale_inverses[META_O], # d_scale_o, + ctx.fp8_meta['scaling_bwd'].scale_inv[META_DO], # d_scale_do + fwd_scales[META_S], # q_scale_s + ctx.fp8_meta['scaling_bwd'].scale[META_DS], # q_scale_ds + ctx.fp8_meta['scaling_bwd'].scale[META_DQKV], # q_scale_dqkv + ctx.fp8_meta['scaling_bwd'].amax_history[0][META_DS], # amax_ds + ctx.fp8_meta['scaling_bwd'].amax_history[0][META_DQKV], # amax_dqkv + None, + ctx.p_dropout, + ctx.fast_zero_fill, + "qkv_interleaved", + "no_bias", + "padding", + ) + + dqkv_grad_output_c = dqkv.view(-1, 3*ctx.hidden_size) + dqkv_grad_output_c_fp16 = ext.cast_from_fp8(dqkv_grad_output_c, + ctx.fp8_meta["scaling_bwd"], META_DQKV, + fp8_dtype_backward, tex.DType.kFloat16) + torch.save(dqkv_grad_output_c_fp16, 'dqkv.pt') + + qkv_bgrad, dqkv_grad_output_t = ext.fp8_transpose_bgrad_fused( + dqkv_grad_output_c, + ctx.fp8_meta["scaling_bwd"], + META_DQKV, + fp8_dtype_backward, + torch.float16, + ) + + # QKV DGRAD + qkv_dgrad = ext.fp8_gemm( + qkv_weight_t_fp8, + fwd_scale_inverses, + tex.FP8FwdTensors.GEMM1_WEIGHT, + fp8_dtype_forward, + dqkv_grad_output_c, + ctx.fp8_meta["scaling_bwd"].scale_inv, + META_DQKV, + fp8_dtype_backward, + torch.float16, + workspace, + use_split_accumulator=_2X_ACC_DGRAD, + ) + # QKV WGRAD + qkv_wgrad = ext.fp8_gemm( + inputmat_t, + fwd_scale_inverses, + tex.FP8FwdTensors.GEMM1_INPUT, + fp8_dtype_forward, + dqkv_grad_output_t, + ctx.fp8_meta["scaling_bwd"].scale_inv, + META_DQKV, + fp8_dtype_backward, + torch.float16, + workspace, + use_split_accumulator=_2X_ACC_WGRAD, + ) + + return (qkv_dgrad, + qkv_wgrad, + qkv_bgrad, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None) + +class DPA_FP8(TransformerEngineBaseModule): + def __init__( + self, + config, + params_dtype: torch.dtype = torch.float32): + super().__init__() + self.p_dropout = config.dropout_p + self.h = config.num_attention_heads + self.hidden_size = config.hidden_size + self.head_dim = config.head_dim + self.fast_zero_fill = True + + self.qkv_weight = Parameter( + torch.empty( + self.hidden_size * 3, + self.hidden_size, + device=torch.cuda.current_device(), + dtype=params_dtype, + ) + ) + self.fp8_weight_shapes.append(self.qkv_weight.shape) + self.qkv_bias = Parameter( + torch.empty( + self.hidden_size * 3, + device=torch.cuda.current_device(), + dtype=params_dtype, + ) + ) + with 
torch.no_grad(): + self.qkv_bias.zero_() + self.qkv_weight.fill_(1.0) + self.workspace = torch.empty( + _CUBLASLT_WORKSPACE_SIZE_BYTES, dtype=torch.int8, device="cuda" + ) + + def forward( + self, inp: torch.Tensor, + cu_seqlens, max_s, + ) -> torch.Tensor: + with self.prepare_forward(inp, None, num_gemms=3) as inp: + out = _dpa_fp8.apply( + inp, + self.qkv_weight, + self.qkv_bias, + cu_seqlens, + self.h, + self.p_dropout, + max_s, + self.fast_zero_fill, + self.fp8_meta, + self.workspace, + self.training) + return out + + def get_fp8_weights_scratchpad( + self, + is_first_microbatch: Union[bool, None], + ) -> List[torch.Tensor]: + """Needs override.""" diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index a7653355db..481e1677ee 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -12,9 +12,10 @@ list(APPEND transformer_engine_SOURCES transpose/transpose_fusion.cu transpose/multi_cast_transpose.cu activation/gelu.cu + fused_attn/fused_attn_f16_max512_seqlen.cu + fused_attn/fused_attn_f16_arbitrary_seqlen.cu activation/relu.cu activation/swiglu.cu - fused_attn/fused_attn_fp16_bf16_max_seqlen_512.cu fused_attn/fused_attn_fp8.cu fused_attn/fused_attn.cpp fused_attn/utils.cu diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index f1846c49d5..25f62cad09 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -7,8 +7,80 @@ #include "transformer_engine/fused_attn.h" #include "../common.h" #include "utils.h" -#include "fused_attn_fp16_bf16_max_seqlen_512.h" +#include "fused_attn_f16_max512_seqlen.h" +#include "fused_attn_f16_arbitrary_seqlen.h" #include "fused_attn_fp8.h" +#include "../util/cuda_runtime.h" + +// select a backend for fused attention +NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( + NVTEDType q_dtype, + NVTEDType kv_dtype, + NVTE_QKV_Layout qkv_layout, + NVTE_Bias_Type bias_type, + NVTE_Mask_Type attn_mask_type, + float dropout, size_t max_seqlen_q, + size_t max_seqlen_kv, size_t head_dim) { + using namespace transformer_engine; + NVTE_Fused_Attn_Backend backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend; + const int device_id = cuda::current_device(); + const int sm_arch_ = cuda::sm_arch(device_id); + NVTE_CHECK(q_dtype == kv_dtype, "Q and KV must have the same data type."); + if ((q_dtype == NVTEDType::kNVTEFloat8E4M3) || (q_dtype == NVTEDType::kNVTEFloat8E5M2) + && (sm_arch_ >= 90) + && (max_seqlen_q == max_seqlen_kv) + && (max_seqlen_q <= 512) + && (head_dim == 64) + && (bias_type == NVTE_Bias_Type::NVTE_NO_BIAS) + && (attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK) + && (qkv_layout == NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED)) { + backend = NVTE_Fused_Attn_Backend::NVTE_FP8; + } else if ((q_dtype == NVTEDType::kNVTEFloat16) || (q_dtype == NVTEDType::kNVTEBFloat16)) { + bool flag_m512 = false; + bool flag_arb = false; + if ((sm_arch_ >= 80) + && (head_dim == 64) + && ((bias_type == NVTE_Bias_Type::NVTE_NO_BIAS) + || (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS)) + && ((attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK) + || (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK) + || (attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK)) + && ((qkv_layout == NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED) + || (qkv_layout == NVTE_QKV_Layout::NVTE_KV_INTERLEAVED))) { + flag_m512 = true; + } + if ((sm_arch_ >= 80) + && (max_seqlen_q == max_seqlen_kv) + 
&& ((head_dim == 64) || (head_dim == 128)) + && (bias_type == NVTE_Bias_Type::NVTE_NO_BIAS) + && (attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK) + && (qkv_layout == NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED)) { + flag_arb = true; + } + if (((max_seqlen_q > 512) || (max_seqlen_kv > 512)) + && (flag_arb == true)) { + backend = NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen; + } + if ((max_seqlen_q <= 512) && (max_seqlen_kv <= 512)) { + if (flag_m512 == true) { + backend = NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen; + } else if ((flag_m512 == false) && (flag_arb == true)) { + backend = NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen; + } + } + const char* env_backend = std::getenv("NVTE_FUSED_ATTN_BACKEND"); + if ((max_seqlen_q <= 512) && (max_seqlen_kv <= 512) + && (flag_arb == true) + && (env_backend != nullptr) + && (std::string(env_backend) == std::to_string( + NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen))) { + backend = NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen; + } + } else { + backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend; + } + return backend; +} // NVTE fused attention FWD FP8 with packed QKV void nvte_fused_attn_fwd_qkvpacked( @@ -16,7 +88,7 @@ void nvte_fused_attn_fwd_qkvpacked( const NVTETensor Bias, NVTETensor S, NVTETensor O, - NVTETensorPack* Aux_Output_Tensors, + NVTETensorPack* Aux_CTX_Tensors, const NVTETensor cu_seqlens, const NVTETensor rng_state, size_t max_seqlen, @@ -43,54 +115,56 @@ void nvte_fused_attn_fwd_qkvpacked( size_t d = input_QKV->data.shape[ndim - 1]; auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle(); - const DType QKV_type = input_QKV->data.dtype; + const NVTEDType QKV_type = static_cast(input_QKV->data.dtype); - if (((QKV_type == DType::kFloat8E4M3) || (QKV_type == DType::kFloat8E5M2)) - && (max_seqlen <= 512)) { + NVTE_Fused_Attn_Backend fused_attention_backend = + nvte_get_fused_attn_backend( + QKV_type, QKV_type, + qkv_layout, bias_type, attn_mask_type, + dropout, max_seqlen, max_seqlen, d); + + if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { +#if (CUDNN_VERSION >= 8901) + fused_attn_max_512_fwd_qkvpacked( + b, max_seqlen, h, d, + is_training, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, + input_QKV, input_Bias, output_O, + Aux_CTX_Tensors, + input_cu_seqlens, + input_rng_state, + wkspace, stream, handle); +#else + NVTE_ERROR("cuDNN 8.9.1 is required for BF16/FP16 fused attention with max_seqlen<=512. \n"); +#endif + } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) { #if (CUDNN_VERSION >= 8900) - // FP8 API doesn't use input_Bias, bias_type or attn_mask_type - fused_attn_fwd_fp8_qkvpacked( + fused_attn_arbitrary_seqlen_fwd_qkvpacked( + b, max_seqlen, h, d, + is_training, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, + input_QKV, input_Bias, output_O, + Aux_CTX_Tensors, + input_cu_seqlens, + input_rng_state, + wkspace, stream, handle); +#else + NVTE_ERROR( + "cuDNN 8.9.0 is required for BF16/FP16 fused attention with arbitrary sequence length. \n"); +#endif + } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_FP8) { +#if (CUDNN_VERSION >= 8900) + fused_attn_fp8_fwd_qkvpacked( b, max_seqlen, h, d, is_training, attn_scale, dropout, qkv_layout, input_QKV, input_output_S, output_O, - Aux_Output_Tensors, + Aux_CTX_Tensors, input_cu_seqlens, input_rng_state, wkspace, stream, handle); #else - NVTE_ERROR("cuDNN 8.9 is required to run FP8 fused attention. 
\n"); + NVTE_ERROR("cuDNN 8.9.0 is required for FP8 fused attention. \n"); #endif - } else if (((QKV_type == DType::kFloat16) || (QKV_type == DType::kBFloat16)) - && (max_seqlen <= 512)) { -#if (CUDNN_VERSION >= 8901) - fused_attn_max_512_fwd_qkvpacked( - b, - max_seqlen, - h, - d, - is_training, - attn_scale, - dropout, - qkv_layout, - bias_type, - attn_mask_type, - input_QKV, - input_Bias, - output_O, - Aux_Output_Tensors, - input_cu_seqlens, - input_rng_state, - wkspace, - stream, - handle); -#else - NVTE_ERROR( - "cuDNN 8.9.1 is required to run BF16/FP16 fused attention with max_seqlen<=512. \n"); -#endif - } else if (max_seqlen > 512) { - NVTE_ERROR("TBD: No support for fused attention with >512 seqlence length currently. \n"); } else { - NVTE_ERROR("Invalid combination of data type and sequence length! \n"); + NVTE_ERROR("Invalid combination of data type and sequence length for fused attention. \n"); } } // NVTE fused attention BWD FP8 with packed QKV @@ -130,18 +204,52 @@ void nvte_fused_attn_bwd_qkvpacked( size_t d = input_QKV->data.shape[ndim - 1]; auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle(); - const DType QKV_type = input_QKV->data.dtype; + const NVTEDType QKV_type = static_cast(input_QKV->data.dtype); - if (((QKV_type == DType::kFloat8E4M3) || (QKV_type == DType::kFloat8E5M2)) - && (max_seqlen <= 512)) { + NVTE_Fused_Attn_Backend fused_attention_backend = + nvte_get_fused_attn_backend( + QKV_type, QKV_type, + qkv_layout, bias_type, attn_mask_type, + dropout, max_seqlen, max_seqlen, d); + + if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { +#if (CUDNN_VERSION >= 8901) + Tensor *output_S = reinterpret_cast(Aux_CTX_Tensors->tensors[0]); + fused_attn_max_512_bwd_qkvpacked( + b, max_seqlen, h, d, + attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, + input_QKV, input_dO, + output_S, + output_dQKV, output_dBias, + input_cu_seqlens, + wkspace, stream, handle); +#else + NVTE_ERROR("cuDNN 8.9.1 is required for BF16/FP16 fused attention with max_seqlen<=512. \n"); +#endif + } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) { +#if (CUDNN_VERSION >= 8900) + Tensor *output_S = reinterpret_cast(Aux_CTX_Tensors->tensors[0]); + const Tensor *input_rng_state = reinterpret_cast(Aux_CTX_Tensors->tensors[1]); + fused_attn_arbitrary_seqlen_bwd_qkvpacked( + b, max_seqlen, h, d, + attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, + input_QKV, input_O, input_dO, + output_S, + output_dQKV, output_dBias, + input_cu_seqlens, input_rng_state, + wkspace, stream, handle); +#else + const char *err_msg = + "cuDNN 8.9.0 is required for BF16/FP16 fused attention " + "with arbitrary sequence length. 
\n"; + NVTE_ERROR(err_msg); +#endif + } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_FP8) { #if (CUDNN_VERSION >= 8900) - // Aux_CTX_Tensors contain [M, ZInv, rng_state] generated by the forward pass const Tensor *input_M = reinterpret_cast(Aux_CTX_Tensors->tensors[0]); const Tensor *input_ZInv = reinterpret_cast(Aux_CTX_Tensors->tensors[1]); const Tensor *input_rng_state = reinterpret_cast(Aux_CTX_Tensors->tensors[2]); - - // FP8 API doesn't use input_dBias, bias_type or attn_mask_type - fused_attn_bwd_fp8_qkvpacked( + fused_attn_fp8_bwd_qkvpacked( b, max_seqlen, h, d, attn_scale, dropout, qkv_layout, input_QKV, input_O, input_dO, @@ -152,38 +260,10 @@ void nvte_fused_attn_bwd_qkvpacked( input_rng_state, wkspace, stream, handle); #else - NVTE_ERROR("cuDNN 8.9 is required to run FP8 fused attention. \n"); -#endif - } else if (((QKV_type == DType::kFloat16) || (QKV_type == DType::kBFloat16)) - && (max_seqlen <= 512)) { -#if (CUDNN_VERSION >= 8901) - fused_attn_max_512_bwd_qkvpacked( - b, - max_seqlen, - h, - d, - attn_scale, - dropout, - qkv_layout, - bias_type, - attn_mask_type, - input_QKV, - input_dO, - Aux_CTX_Tensors, - output_dQKV, - output_dBias, - input_cu_seqlens, - wkspace, - stream, - handle); -#else - NVTE_ERROR( - "cuDNN 8.9.1 is required to run BF16/FP16 fused attention with max_seqlen<=512. \n"); + NVTE_ERROR("cuDNN 8.9.0 is required for FP8 fused attention. \n"); #endif - } else if (max_seqlen > 512) { - NVTE_ERROR("TBD: No support for fused attention with >512 seqlence length currently. \n"); } else { - NVTE_ERROR("Invalid combination of data type and sequence length! \n"); + NVTE_ERROR("Invalid combination of data type and sequence length for fused attention. \n"); } } // NVTE fused attention FWD FP8 with packed KV @@ -193,7 +273,7 @@ void nvte_fused_attn_fwd_kvpacked( const NVTETensor Bias, NVTETensor S, NVTETensor O, - NVTETensorPack* Aux_Output_Tensors, + NVTETensorPack* Aux_CTX_Tensors, const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, const NVTETensor rng_state, @@ -223,45 +303,37 @@ void nvte_fused_attn_fwd_kvpacked( size_t d = input_Q->data.shape[ndim - 1]; auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle(); - const DType QKV_type = input_Q->data.dtype; + const NVTEDType Q_type = static_cast(input_Q->data.dtype); + const NVTEDType KV_type = static_cast(input_KV->data.dtype); - if (((QKV_type == DType::kFloat8E4M3) || (QKV_type == DType::kFloat8E5M2)) - && (max_seqlen_q <= 512) && (max_seqlen_kv <= 512)) { - NVTE_ERROR("The FP8 fused attention API only supports packed QKV input. 
\n"); - } else if (((QKV_type == DType::kFloat16) || (QKV_type == DType::kBFloat16)) - && (max_seqlen_q <= 512) && (max_seqlen_kv <= 512)) { + NVTE_Fused_Attn_Backend fused_attention_backend = + nvte_get_fused_attn_backend( + Q_type, KV_type, + qkv_layout, bias_type, attn_mask_type, + dropout, max_seqlen_q, max_seqlen_kv, d); + + if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { #if (CUDNN_VERSION >= 8901) - fused_attn_max_512_fwd_kvpacked( - b, - max_seqlen_q, - max_seqlen_kv, - h, - d, - is_training, - attn_scale, - dropout, - qkv_layout, - bias_type, - attn_mask_type, - input_Q, - input_KV, - input_Bias, - output_O, - Aux_Output_Tensors, - input_cu_seqlens_q, - input_cu_seqlens_kv, - input_rng_state, - wkspace, - stream, - handle); + fused_attn_max_512_fwd_kvpacked( + b, max_seqlen_q, max_seqlen_kv, h, d, + is_training, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, + input_Q, input_KV, input_Bias, output_O, + Aux_CTX_Tensors, + input_cu_seqlens_q, input_cu_seqlens_kv, + input_rng_state, + wkspace, stream, handle); #else - NVTE_ERROR( - "cuDNN 8.9.1 is required to run BF16/FP16 fused attention with max_seqlen<=512. \n"); + NVTE_ERROR("cuDNN 8.9.1 is required for BF16/FP16 fused attention with max_seqlen<=512. \n"); #endif - } else if ((max_seqlen_q > 512) || (max_seqlen_kv > 512)) { - NVTE_ERROR("TBD: No support for fused attention with >512 seqlence length currently. \n"); + } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) { + const char* err_msg = + "The FP16/BF16 fused attention (arbitrary seqlen) currently " + "only supports packed QKV input.\n"; + NVTE_ERROR(err_msg); + } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_FP8) { + NVTE_ERROR("The FP8 fused attention API only supports packed QKV input. \n"); } else { - NVTE_ERROR("Invalid combination of data type and sequence length! \n"); + NVTE_ERROR("Invalid combination of data type and sequence length for fused attention. \n"); } } // NVTE fused attention BWD FP8 with packed KV @@ -307,44 +379,37 @@ void nvte_fused_attn_bwd_kvpacked( size_t d = input_Q->data.shape[ndim - 1]; auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle(); - const DType QKV_type = input_Q->data.dtype; + const NVTEDType Q_type = static_cast(input_Q->data.dtype); + const NVTEDType KV_type = static_cast(input_KV->data.dtype); - if (((QKV_type == DType::kFloat8E4M3) || (QKV_type == DType::kFloat8E5M2)) - && (max_seqlen_q <= 512) && (max_seqlen_kv <= 512)) { - NVTE_ERROR("The FP8 fused attention API only supports packed QKV input. 
\n"); - } else if (((QKV_type == DType::kFloat16) || (QKV_type == DType::kBFloat16)) - && (max_seqlen_q <= 512) && (max_seqlen_kv <= 512)) { + NVTE_Fused_Attn_Backend fused_attention_backend = + nvte_get_fused_attn_backend( + Q_type, KV_type, + qkv_layout, bias_type, attn_mask_type, + dropout, max_seqlen_q, max_seqlen_kv, d); + + if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { #if (CUDNN_VERSION >= 8901) - fused_attn_max_512_bwd_kvpacked( - b, - max_seqlen_q, - max_seqlen_kv, - h, - d, - attn_scale, - dropout, - qkv_layout, - bias_type, - attn_mask_type, - input_Q, - input_KV, - input_dO, - Aux_CTX_Tensors, - output_dQ, - output_dKV, - output_dBias, - input_cu_seqlens_q, - input_cu_seqlens_kv, - wkspace, - stream, - handle); + Tensor *output_S = reinterpret_cast(Aux_CTX_Tensors->tensors[0]); + fused_attn_max_512_bwd_kvpacked( + b, max_seqlen_q, max_seqlen_kv, h, d, + attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, + input_Q, input_KV, input_dO, + output_S, + output_dQ, output_dKV, output_dBias, + input_cu_seqlens_q, input_cu_seqlens_kv, + wkspace, stream, handle); #else - NVTE_ERROR( - "cuDNN 8.9.1 is required to run BF16/FP16 fused attention with max_seqlen<=512. \n"); + NVTE_ERROR("cuDNN 8.9.1 is required for BF16/FP16 fused attention with max_seqlen<=512. \n"); #endif - } else if ((max_seqlen_q > 512) || (max_seqlen_kv > 512)) { - NVTE_ERROR("TBD: No support for fused attention with >512 seqlence length currently. \n"); + } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) { + const char* err_msg = + "The FP16/BF16 fused attention (arbitrary seqlen) currently " + "only supports packed QKV input.\n"; + NVTE_ERROR(err_msg); + } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_FP8) { + NVTE_ERROR("The FP8 fused attention API only supports packed QKV input. \n"); } else { - NVTE_ERROR("Invalid combination of data type and sequence length! \n"); + NVTE_ERROR("Invalid combination of data type and sequence length for fused attention. \n"); } } diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu new file mode 100644 index 0000000000..88e006fb4e --- /dev/null +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu @@ -0,0 +1,1304 @@ +/************************************************************************* + * Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. 
+ ************************************************************************/ + +#include "fused_attn_f16_arbitrary_seqlen.h" + +#include +#include +#include +#include +#include + +#include "../common.h" +#include "utils.h" + +#if (CUDNN_VERSION >= 8900) +#define Q_ID 1 +#define K_ID 2 +#define V_ID 3 +#define O_ID 4 +#define S_ID 5 +#define B_ID 6 +#define D_CONST_ID 7 +#define S_CONST_ID 8 +#define Q_SEQLEN_ID 9 +#define K_SEQLEN_ID 10 +#define dQ_ID 11 +#define dK_ID 12 +#define dV_ID 13 +#define dO_ID 14 +#define MASK_VAL_ID 15 +#define dS_ID 16 +#define D_SEED_ID 17 +#define D_OFFSET_ID 18 +#define S_STATS_ID 19 +#define S_SUM_ID 20 +#define SCALE_PROB 21 +#define K_TRANSPOSE_ID 22 +#define dQ_ACCUM_ID 23 + +#define VIRTUAL_ID 30 + +namespace transformer_engine { +namespace fused_attn { + +static cudnn_frontend::Tensor +createScale(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d, + NVTE_QKV_Layout layout, cudnnDataType_t tensorType, + const cudnn_frontend::Tensor& sTensor, + std::vector* ops) { + // scale + int64_t scale_dim[4] = {1, 1, 1, 1}; + int64_t scale_stride[4] = {1, 1, 1, 1}; + + int64_t s_dim[4] = {b, h, s_q, s_kv}; + int64_t s_stride[4]; + generateMatrixStrides(b, h, s_q, s_kv, d, s_stride, layout, NVTE_QKV_Matrix::NVTE_S_Matrix); + + auto scaleTensor = tensor_create( + tensorType, S_CONST_ID, scale_dim, + scale_stride, false, true); // is by value + auto sScaleTensor = tensor_create( + tensorType, VIRTUAL_ID + 2000, s_dim, + s_stride, true, false); // is virtual + + // Define the scale descriptor + auto scaleDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_MUL); + + // Create a scale node + auto scale_op = binary_pw_op_create(sTensor, scaleTensor, sScaleTensor, scaleDesc); + + ops->push_back(std::move(scale_op)); + return sScaleTensor; +} + +static cudnn_frontend::Tensor +createQKBMM(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d, + NVTE_QKV_Layout layout, cudnnDataType_t tensorType, + std::vector* ops) { + // Creates the necessary tensor descriptors + int64_t q_dim[4] = {b, h, s_q, d}; + int64_t q_stride[4]; + generateMatrixStrides(b, h, s_q, s_kv, d, q_stride, layout, NVTE_QKV_Matrix::NVTE_Q_Matrix); + + int64_t k_dim[4] = {b, h, d, s_kv}; + int64_t k_stride[4]; + generateMatrixStrides( + b, h, s_q, s_kv, d, k_stride, layout, NVTE_QKV_Matrix::NVTE_K_Matrix_Transpose); + + int64_t s_dim[4] = {b, h, s_q, s_kv}; + int64_t s_stride[4]; + generateMatrixStrides(b, h, s_q, s_kv, d, s_stride, layout, NVTE_QKV_Matrix::NVTE_S_Matrix); + + auto qTensor = tensor_create(tensorType, Q_ID, q_dim, q_stride, false, false); + auto kTransposeTensor = tensor_create( + tensorType, K_ID, k_dim, k_stride, false, false); // is virtual + // first GEMM output + auto sTensor = tensor_create( + CUDNN_DATA_FLOAT, VIRTUAL_ID + 1, s_dim, s_stride, true, false); // is virtual + + // Define the matmul 1 desc + auto matmul_1_Desc = cudnn_frontend::MatMulDescBuilder() + .setComputeType(CUDNN_DATA_FLOAT) + .build(); + + // Create a matmul 1 node + auto matmul_op1 = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_MATMUL_DESCRIPTOR) + .setaMatDesc(qTensor) + .setbMatDesc(kTransposeTensor) + .setcMatDesc(sTensor) + .setmatmulDesc(matmul_1_Desc) + .build(); + + ops->push_back(std::move(matmul_op1)); + + return sTensor; +} + +static cudnn_frontend::Tensor +createCausalMask(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d, + NVTE_QKV_Layout layout, cudnnDataType_t tensorType, + std::vector* ops, + const cudnn_frontend::Tensor& prevBlockOutputTensor) { + 
CUDNN_FRONTEND_UNUSED(d); + CUDNN_FRONTEND_UNUSED(layout); + CUDNN_FRONTEND_UNUSED(tensorType); + + NVTE_CHECK(ops->size() != 0, "Padding Mask constructed incorrectly as the first one"); + + // subtraction output + int64_t afterBMM1_dim[4] = {b, h, s_q, s_kv}; + int64_t afterBMM1_stride[4] = {h * s_q * s_kv, s_q * s_kv, s_kv, 1}; + + int64_t maskVal_dim[4] = {1, 1, 1, 1}; + int64_t maskVal_stride[4] = {1, 1, 1, 1}; + + // mask value to put in the masked pixels + auto maskValTensor = tensor_create( + CUDNN_DATA_FLOAT, MASK_VAL_ID, maskVal_dim, + maskVal_stride, false, true); // is by value + // gen index row output + auto rowIndexTensor = tensor_create( + CUDNN_DATA_FLOAT, VIRTUAL_ID + 100, afterBMM1_dim, + afterBMM1_stride, true, false); // is virtual + // gen index column output + auto columnIndexTensor = tensor_create( + CUDNN_DATA_FLOAT, VIRTUAL_ID + 101, afterBMM1_dim, + afterBMM1_stride, true, false); // is virtual + // create causal mask (row >= col) + auto causalMaskTensor = tensor_create( + CUDNN_DATA_BOOLEAN, VIRTUAL_ID + 106, afterBMM1_dim, + afterBMM1_stride, true, false); // is virtual + + // output after masking + auto maskOutputTensor = tensor_create( + CUDNN_DATA_FLOAT, VIRTUAL_ID + 107, afterBMM1_dim, + afterBMM1_stride, true, false); // is virtual + + // Define the gen index for row descriptor + auto genIndexRowDesc = cudnn_frontend::PointWiseDescBuilder() + .setMode(CUDNN_POINTWISE_GEN_INDEX) + .setAxis(2) + .setComputeType(CUDNN_DATA_FLOAT) + .build(); + + // Create a gen index node + auto genIndexRow_op = unary_pw_op_create( + prevBlockOutputTensor, rowIndexTensor, genIndexRowDesc); + + // Define the gen index for row descriptor + auto genIndexColumnDesc = cudnn_frontend::PointWiseDescBuilder() + .setMode(CUDNN_POINTWISE_GEN_INDEX) + .setAxis(3) + .setComputeType(CUDNN_DATA_FLOAT) + .build(); + + // Create a gen index node + auto genIndexColumn_op = unary_pw_op_create( + prevBlockOutputTensor, columnIndexTensor, genIndexColumnDesc); + + // Define the greater than equal to comparison descriptor + auto rowGreaterColDesc = pw_desc_create(CUDNN_DATA_BOOLEAN, CUDNN_POINTWISE_CMP_GE); + + // Create a greater than equal to node + auto rowGreaterCol_op = binary_pw_op_create( + rowIndexTensor, columnIndexTensor, causalMaskTensor, rowGreaterColDesc); + + // Define the binary select to perform masking descriptor + auto maskDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_BINARY_SELECT); + + // Create a binary select node + auto mask_op = ternary_pw_op_create( + prevBlockOutputTensor, maskValTensor, + causalMaskTensor, maskOutputTensor, maskDesc); + + ops->push_back(std::move(genIndexRow_op)); + ops->push_back(std::move(genIndexColumn_op)); + ops->push_back(std::move(rowGreaterCol_op)); + ops->push_back(std::move(mask_op)); + + return maskOutputTensor; +} + +static cudnn_frontend::Tensor +createSoftmaxForward(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, bool isTraining, + std::vector* ops, + const cudnn_frontend::Tensor& sAfterMaskTensor) { + int64_t afterBMM1_dim[4] = {b, h, s_q, s_kv}; + int64_t afterBMM1_stride[4] = {h * s_q * s_kv, s_q * s_kv, s_kv, 1}; + + int64_t afterReduction_dim[4] = {b, h, s_q, 1}; + int64_t afterReduction_stride[4] = {h * s_q, s_q, 1, 1}; + + // max (x) + auto afterMaxReductionTensor = tensor_create( + CUDNN_DATA_FLOAT, VIRTUAL_ID + 150, afterReduction_dim, + afterReduction_stride, true, false); // is virtual + + // x - max(x) + auto afterSubtractionTensor = tensor_create( + CUDNN_DATA_FLOAT, VIRTUAL_ID + 151, afterBMM1_dim, + 
afterBMM1_stride, true, false); // is virtual + + // e^(x - max(x)) + auto afterExponentTensor = tensor_create( + CUDNN_DATA_FLOAT, VIRTUAL_ID + 152, afterBMM1_dim, + afterBMM1_stride, true, false); // is virtual; + + // sum (e^(x - max(x))) + auto afterAddReductionTensor = tensor_create( + CUDNN_DATA_FLOAT, VIRTUAL_ID + 153, afterReduction_dim, + afterReduction_stride, true, false); // is virtual + + // log (sum (e^(x - max(x)))) + auto afterLogLTensor = tensor_create( + CUDNN_DATA_FLOAT, VIRTUAL_ID + 154, afterReduction_dim, + afterReduction_stride, true, false); + + // M + log (sum (e^(x - max(x)))) + auto softmaxStatsTensor = tensor_create( + CUDNN_DATA_FLOAT, S_STATS_ID, afterReduction_dim, + afterReduction_stride, !isTraining, false); + // not virtual if training is true, virtual if training is false + + // divide (e/ sum(e)) + auto afterSoftmaxTensor = cudnn_frontend::TensorBuilder() + .setDim(4, afterBMM1_dim) + .setStride(4, afterBMM1_stride) + .setId(VIRTUAL_ID + 156) + .setAlignment(16) // 16B alignment is needed to run a tensor core engine + .setDataType(CUDNN_DATA_FLOAT) + .setVirtual(true) + .setByValue(false) + .setReorderType( + cudnn_frontend::cudnnBackendTensorReordering_t::CUDNN_TENSOR_REORDERING_F16x16) + .build(); + + // Define the reduction descriptor + auto reductionMaxDesc = cudnn_frontend::ReductionDescBuilder() + .setComputeType(CUDNN_DATA_FLOAT) + .setReductionOp(CUDNN_REDUCE_TENSOR_MAX) + .build(); + + // Create a reduction max node + auto reductionMax_op = cudnn_frontend::OperationBuilder( + CUDNN_BACKEND_OPERATION_REDUCTION_DESCRIPTOR) + .setxDesc(sAfterMaskTensor) + .setyDesc(afterMaxReductionTensor) + .setreductionDesc(reductionMaxDesc) + .build(); + + // Define the subtract descriptor + auto subtractDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_SUB); + + // Create a subtract node + auto subtract_op = binary_pw_op_create( + sAfterMaskTensor, afterMaxReductionTensor, + afterSubtractionTensor, subtractDesc); + + // Define the exponent descriptor + auto exponentDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_EXP); + + // Create a exponent node + auto exponent_op = unary_pw_op_create( + afterSubtractionTensor, afterExponentTensor, exponentDesc); + + // Define the reduction descriptor + auto reductionAddDesc = cudnn_frontend::ReductionDescBuilder() + .setComputeType(CUDNN_DATA_FLOAT) + .setReductionOp(CUDNN_REDUCE_TENSOR_ADD) + .build(); + + // Create a reduction add node + auto reductionAdd_op = cudnn_frontend::OperationBuilder( + CUDNN_BACKEND_OPERATION_REDUCTION_DESCRIPTOR) + .setxDesc(afterExponentTensor) + .setyDesc(afterAddReductionTensor) + .setreductionDesc(reductionAddDesc) + .build(); + + // Create log descriptor + auto logDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_LOG); + + // Create log node + auto log_op = unary_pw_op_create(afterAddReductionTensor, afterLogLTensor, logDesc); + + // Create add descriptor + auto addDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_ADD); + + // Create add node + auto add_op = binary_pw_op_create( + afterMaxReductionTensor, afterLogLTensor, + softmaxStatsTensor, addDesc); + + // Define the division descriptor + auto divisionDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_DIV); + + // Create a subtract node + auto division_op = binary_pw_op_create( + afterExponentTensor, afterAddReductionTensor, + afterSoftmaxTensor, divisionDesc); + + ops->push_back(std::move(reductionMax_op)); + ops->push_back(std::move(subtract_op)); + ops->push_back(std::move(exponent_op)); + 
ops->push_back(std::move(reductionAdd_op)); + ops->push_back(std::move(log_op)); + ops->push_back(std::move(add_op)); + ops->push_back(std::move(division_op)); + + return afterSoftmaxTensor; +} + +static cudnn_frontend::Tensor +createDropoutForward(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d, + double probability, cudnnDataType_t tensorType, + std::vector* ops, + const cudnn_frontend::Tensor& afterSoftmaxTensor) { + CUDNN_FRONTEND_UNUSED(d); + + NVTE_CHECK(ops->size() != 0, "Dropout DAG constructed incorrectly as the first one"); + + int64_t afterBMM1_dim[4] = {b, h, s_q, s_kv}; + int64_t afterBMM1_stride[4] = {h * s_q * s_kv, s_q * s_kv, s_kv, 1}; + + int64_t scale_dim[4] = {1, 1, 1, 1}; + int64_t scale_stride[4] = {1, 1, 1, 1}; + + auto dropoutSeed = tensor_create( + CUDNN_DATA_INT64, D_SEED_ID, scale_dim, + scale_stride, false, false); // not virtual + auto dropoutOffset = tensor_create( + CUDNN_DATA_INT64, D_OFFSET_ID, scale_dim, + scale_stride, false, false); // not virtual + + // mask for the dropout + auto dropoutMaskTensor = tensor_create( + CUDNN_DATA_FLOAT, VIRTUAL_ID + 200, afterBMM1_dim, + afterBMM1_stride, true, false); // is virtual + // after dropout tensor + auto afterDropoutTensor = cudnn_frontend::TensorBuilder() + .setDim(4, afterBMM1_dim) + .setStride(4, afterBMM1_stride) + .setId(VIRTUAL_ID + 201) + .setAlignment(16) // 16B alignment is needed to run a tensor core engine + .setDataType(tensorType) + .setVirtual(true) + .setByValue(false) + .setReorderType( + cudnn_frontend::cudnnBackendTensorReordering_t::CUDNN_TENSOR_REORDERING_F16x16) + .build(); + // scale after dropout + auto scaleDropoutTensor = tensor_create( + tensorType, D_CONST_ID, scale_dim, + scale_stride, false, true); // is by value + // after Scale + auto afterScaleTensor = tensor_create( + tensorType, VIRTUAL_ID + 202, afterBMM1_dim, + afterBMM1_stride, true, false); // is virtual + + // Define the reduction descriptor + auto rngDesc = cudnn_frontend::RngDescBuilder() + .setRngDistribution(CUDNN_RNG_DISTRIBUTION_BERNOULLI) + .setBernoulliDistProbability(1.0 - probability) + .build(); + + // Create a rng node + auto rng_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_RNG_DESCRIPTOR) + .setyDesc(dropoutMaskTensor) + .setSeedDesc(dropoutSeed) + .setOffsetDesc(dropoutOffset) + .setRngDesc(rngDesc) + .build(); + + // Define the multiply mask descriptor + auto maskMulDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_MUL); + + // Create a multiply mask node + auto maskMul_op = binary_pw_op_create( + afterSoftmaxTensor, dropoutMaskTensor, + afterDropoutTensor, maskMulDesc); + + // Define the multiply scale descriptor + auto scaleMulDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_MUL); + + // Create a multiply scale node + auto scaleMul_op = binary_pw_op_create( + afterDropoutTensor, scaleDropoutTensor, + afterScaleTensor, scaleMulDesc); + + ops->push_back(std::move(rng_op)); + ops->push_back(std::move(maskMul_op)); + ops->push_back(std::move(scaleMul_op)); + + return afterScaleTensor; +} + +static cudnn_frontend::Tensor +createDropoutBackward(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d, + double probability, cudnnDataType_t tensorType, + std::vector* ops, + const cudnn_frontend::Tensor& afterSoftmaxTensor, + const cudnn_frontend::Tensor& dropoutMaskTensor) { + CUDNN_FRONTEND_UNUSED(d); + + NVTE_CHECK(ops->size() != 0, "Dropout DAG constructed incorrectly as the first one"); + + int64_t afterBMM1_dim[4] = {b, h, s_q, s_kv}; + int64_t afterBMM1_stride[4] = 
{h * s_q * s_kv, s_q * s_kv, s_kv, 1}; + + int64_t scale_dim[4] = {1, 1, 1, 1}; + int64_t scale_stride[4] = {1, 1, 1, 1}; + + auto dropoutSeed = tensor_create( + CUDNN_DATA_INT64, D_SEED_ID, scale_dim, + scale_stride, false, false); // not virtual + auto dropoutOffset = tensor_create( + CUDNN_DATA_INT64, D_OFFSET_ID, scale_dim, + scale_stride, false, false); // not virtual + + // after dropout tensor + auto afterDropoutTensor = cudnn_frontend::TensorBuilder() + .setDim(4, afterBMM1_dim) + .setStride(4, afterBMM1_stride) + .setId(VIRTUAL_ID + 201) + .setAlignment(16) // 16B alignment is needed to run a tensor core engine + .setDataType(tensorType) + .setVirtual(true) + .setByValue(false) + .setReorderType( + cudnn_frontend::cudnnBackendTensorReordering_t::CUDNN_TENSOR_REORDERING_F16x16) + .build(); + // scale after dropout + auto scaleDropoutTensor = tensor_create( + tensorType, D_CONST_ID, scale_dim, + scale_stride, false, true); // is by value + // after Scale + auto afterScaleTensor = tensor_create( + tensorType, VIRTUAL_ID + 202, afterBMM1_dim, + afterBMM1_stride, true, false); // is virtual + + // Define the reduction descriptor + auto rngDesc = cudnn_frontend::RngDescBuilder() + .setRngDistribution(CUDNN_RNG_DISTRIBUTION_BERNOULLI) + .setBernoulliDistProbability(1.0 - probability) + .build(); + + // Create a rng node + auto rng_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_RNG_DESCRIPTOR) + .setyDesc(dropoutMaskTensor) + .setSeedDesc(dropoutSeed) + .setOffsetDesc(dropoutOffset) + .setRngDesc(rngDesc) + .build(); + + // Define the multiply mask descriptor + auto maskMulDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_MUL); + + // Create a multiply mask node + auto maskMul_op = binary_pw_op_create( + afterSoftmaxTensor, dropoutMaskTensor, + afterDropoutTensor, maskMulDesc); + + // Define the multiply scale descriptor + auto scaleMulDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_MUL); + + // Create a multiply scale node + auto scaleMul_op = binary_pw_op_create( + afterDropoutTensor, scaleDropoutTensor, + afterScaleTensor, scaleMulDesc); + + ops->push_back(std::move(rng_op)); + ops->push_back(std::move(maskMul_op)); + ops->push_back(std::move(scaleMul_op)); + + return afterScaleTensor; +} + +static void +createSVBMM(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d, + NVTE_QKV_Layout layout, cudnnDataType_t tensorType, + std::vector* ops, + cudnn_frontend::Tensor const &afterScaleDropoutTensor) { + NVTE_CHECK(ops->size() != 0, "BMM2 op constructed incorrectly as the first one"); + + int64_t v_dim[4] = {b, h, s_kv, d}; + int64_t v_stride[4]; + generateMatrixStrides(b, h, s_q, s_kv, d, v_stride, layout, NVTE_QKV_Matrix::NVTE_V_Matrix); + + int64_t o_dim[4] = {b, h, s_q, d}; + int64_t o_stride[4]; + generateMatrixStrides(b, h, s_q, s_kv, d, o_stride, layout, NVTE_QKV_Matrix::NVTE_O_Matrix); + + auto vTensor = tensor_create(tensorType, V_ID, v_dim, v_stride, false, false); + // second GEMM output + auto oTensor = tensor_create(tensorType, O_ID, o_dim, o_stride, false, false); + + // Define the matmul 2 desc + auto matmul_2_Desc = cudnn_frontend::MatMulDescBuilder() + .setComputeType(CUDNN_DATA_FLOAT) + .build(); + + // Create a matmul 2 node + auto matmul_op2 = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_MATMUL_DESCRIPTOR) + .setaMatDesc(afterScaleDropoutTensor) + .setbMatDesc(vTensor) + .setcMatDesc(oTensor) + .setmatmulDesc(matmul_2_Desc) + .build(); + + ops->push_back(std::move(matmul_op2)); +} + +void 
fused_attn_arbitrary_seqlen_fwd_impl( + int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d, + bool is_training, float scaling_factor, float dropout_probability, + NVTE_QKV_Layout layout, + void *devPtrQ, void *devPtrK, void *devPtrV, + void *devPtrSoftmaxStats, void *devPtrO, + void* devPtrDropoutSeed, void* devPtrDropoutOffset, + cudnnDataType_t tensorType, + void *workspace, size_t *workspace_size, + cudaStream_t stream, cudnnHandle_t handle) { + try { + NVTE_CHECK_CUDNN(cudnnSetStream(handle, stream)); + + if (!is_training) { + dropout_probability == 0.0f; + } + + FADescriptor descriptor{b, h, + s_q, s_kv, + d, scaling_factor, + is_training, dropout_probability, + layout, NVTE_Bias_Type::NVTE_NO_BIAS, + NVTE_Mask_Type::NVTE_CAUSAL_MASK, tensorType}; + + using CacheType = std::map; + static thread_local CacheType fmha_fprop_cache; + + // Get plan from cache if cache is available, otherwise create one + auto get_plan = [&](CacheType &cache, const FADescriptor &descriptor) { + // if hit, return + auto it = cache.find(descriptor); + if (it != cache.end()) { + auto plan = it->second; + return plan; + } + + // otherwise, build the op_graph and the plan. Then update cache + std::vector all_ops; + std::vector ops; + + // Q * K^T + auto sTensor = createQKBMM(b, h, s_q, s_kv, d, layout, tensorType, &ops); + + // Q * K^T * bmmScale + auto sScaleTensor = createScale( + b, h, s_q, s_kv, d, layout, CUDNN_DATA_FLOAT, sTensor, &ops); + + // Causual mask + auto sAfterMaskTensor = createCausalMask( + b, h, s_q, s_kv, d, layout, tensorType, &ops, sScaleTensor); + + NVTE_CHECK(dropout_probability != 1.0f, + "Dropout probability cannot be 1.0"); + + auto softmax_output = createSoftmaxForward( + b, h, s_q, s_kv, is_training, &ops, sAfterMaskTensor); + + // Dropout(softmax) + auto dropout_output = createDropoutForward( + b, h, s_q, s_kv, d, + dropout_probability, tensorType, &ops, softmax_output); + createSVBMM(b, h, s_q, s_kv, d, layout, tensorType, &ops, dropout_output); + + for (unsigned int i = 0; i < ops.size(); i++) { + all_ops.push_back(&ops[i]); + } + + // Create an Operation Graph + auto opGraph = cudnn_frontend::OperationGraphBuilder() + .setHandle(handle) + .setOperationGraph(all_ops.size(), all_ops.data()) + .build(); + + cudnn_frontend::EngineConfigList filtered_configs; + auto statuses = cudnn_frontend::get_heuristics_list<1>( + {"heuristics_instant"}, opGraph, allowAllConfig, + filtered_configs, true); + + if (filtered_configs.size() == 0) { + cudnn_frontend::set_error_and_throw_exception( + nullptr, + CUDNN_STATUS_NOT_SUPPORTED, + "run_mha_fprop: No config returned by the heuristics"); + } + + auto plan = cudnn_frontend::ExecutionPlanBuilder() + .setHandle(handle) + .setEngineConfig(filtered_configs[0], opGraph.getTag()) + .build(); + + cache.insert({descriptor, plan}); + return plan; + }; + + auto plan = get_plan(fmha_fprop_cache, descriptor); + + auto plan_workspace_size = plan.getWorkspaceSize(); + + // Exit to request upper level API to allocate memory if needed + if (workspace == nullptr) { + *workspace_size = plan_workspace_size; + return; + } + + std::set> data_ptrs; + // Add all the data pointers to be used in the variant pack + float negInfinity = -1.0E+10f; + float scale_dropout = 1.0f/(1.0f - dropout_probability); + + data_ptrs.insert(std::pair(Q_ID, devPtrQ)); + data_ptrs.insert(std::pair(K_ID, devPtrK)); + data_ptrs.insert(std::pair(V_ID, devPtrV)); + data_ptrs.insert(std::pair(MASK_VAL_ID, &negInfinity)); + data_ptrs.insert(std::pair(S_CONST_ID, &scaling_factor)); + 
data_ptrs.insert(std::pair(O_ID, devPtrO)); + data_ptrs.insert(std::pair(D_SEED_ID, devPtrDropoutSeed)); + data_ptrs.insert(std::pair(D_OFFSET_ID, devPtrDropoutOffset)); + data_ptrs.insert(std::pair(D_CONST_ID, &scale_dropout)); + + // If training mode, we write out softmax stats + if (is_training) { + data_ptrs.insert(std::pair(S_STATS_ID, devPtrSoftmaxStats)); + } + + auto variantPack = cudnn_frontend::VariantPackBuilder() + .setWorkspacePointer(workspace) + .setDataPointers(data_ptrs) + .build(); + + NVTE_CHECK_CUDNN( + cudnnBackendExecute(handle, plan.get_raw_desc(), variantPack.get_raw_desc())); + } catch (cudnn_frontend::cudnnException &e) { + NVTE_ERROR(e.what()); + } +} + +void fused_attn_arbitrary_seqlen_bwd_impl( + int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d, + float scaling_factor, float dropout_probability, NVTE_QKV_Layout layout, + void* devPtrQ, void* devPtrKTranspose, void* devPtrVTranspose, + void* devPtrO, void* devPtrSoftmaxStats, + void* devPtrdQ, void* devPtrdK, void* devPtrdV, void* devPtrdO, + void* devPtrDropoutSeed, void* devPtrDropoutOffset, + cudnnDataType_t tensorType, void *workspace, size_t *workspace_size, + cudaStream_t stream, cudnnHandle_t handle) { + try { + NVTE_CHECK_CUDNN(cudnnSetStream(handle, stream)); + + FADescriptor descriptor{b, h, + s_q, s_kv, + d, scaling_factor, + true, dropout_probability, + layout, NVTE_Bias_Type::NVTE_NO_BIAS, + NVTE_Mask_Type::NVTE_CAUSAL_MASK, tensorType}; + + using CacheType = std::map; + static thread_local CacheType fmha_bprop_cache; + + auto get_plan = [&](CacheType &cache, const FADescriptor &descriptor) { + auto it = cache.find(descriptor); + if (it != cache.end()) { + return it->second; + } + + std::vector all_ops; + std::vector ops; + + // Creates the necessary tensor descriptors + int64_t q_dim[4] = {b, h, s_q, d}; + int64_t q_stride[4]; + generateMatrixStrides( + b, h, s_q, s_kv, d, q_stride, + layout, NVTE_QKV_Matrix::NVTE_Q_Matrix); + + int64_t k_transpose_dim[4] = {b, h, d, s_kv}; + int64_t k_transpose_stride[4]; + generateMatrixStrides( + b, h, s_q, s_kv, d, k_transpose_stride, + layout, NVTE_QKV_Matrix::NVTE_K_Matrix_Transpose); + + int64_t v_transpose_dim[4] = {b, h, d, s_kv}; + int64_t v_transpose_stride[4]; + generateMatrixStrides( + b, h, s_q, s_kv, d, v_transpose_stride, + layout, NVTE_QKV_Matrix::NVTE_V_Matrix_Transpose); + + int64_t p_dim[4] = {b, h, s_q, s_kv}; + int64_t p_stride[4]; + generateMatrixStrides( + b, h, s_q, s_kv, d, p_stride, + layout, NVTE_QKV_Matrix::NVTE_S_Matrix); + + int64_t p_transpose_dim[4] = {b, h, s_kv, s_q}; + int64_t p_transpose_stride[4]; + p_transpose_stride[0] = p_stride[0]; + p_transpose_stride[1] = p_stride[1]; + p_transpose_stride[2] = p_stride[3]; + p_transpose_stride[3] = p_stride[2]; + + int64_t o_dim[4] = {b, h, s_q, d}; + int64_t o_stride[4]; + generateMatrixStrides( + b, h, s_q, s_kv, d, o_stride, + layout, NVTE_QKV_Matrix::NVTE_O_Matrix); + + int64_t scale_dim[4] = {1, 1, 1, 1}; + int64_t scale_stride[4] = {1, 1, 1, 1}; + + /******************************************************************************* + * Dot product dO * O */ + + // output and gradient of the output + auto oTensor = tensor_create(tensorType, O_ID, o_dim, o_stride, false, false); + auto dOTensor = tensor_create(tensorType, dO_ID, o_dim, o_stride, false, false); + + auto dotProductTensor = tensor_create( + CUDNN_DATA_FLOAT, VIRTUAL_ID, o_dim, + o_stride, true, false); // is virtual + + // Create pointwise mul + auto multiplyDesc = pw_desc_create(CUDNN_DATA_FLOAT, 
CUDNN_POINTWISE_MUL); + + // do * O + auto dotProductOp = binary_pw_op_create( + dOTensor, oTensor, dotProductTensor, multiplyDesc); + ops.push_back(std::move(dotProductOp)); + + /******************************************************************************* + * Reduction(dO * O) */ + + int64_t reduction_dim[4] = {b, h, s_q, 1}; + int64_t reduction_stride[4] = {h * s_q, s_q, 1, 1}; + + // reduction(dO * O) + auto afterReductionTensor = tensor_create( + CUDNN_DATA_FLOAT, VIRTUAL_ID + 1, reduction_dim, + reduction_stride, true, false); // is virtual + auto reductionMaxDesc = cudnn_frontend::ReductionDescBuilder() + .setComputeType(CUDNN_DATA_FLOAT) + .setReductionOp(CUDNN_REDUCE_TENSOR_MAX) + .build(); + + // Create a reduction max node + auto reductionMax_op = cudnn_frontend::OperationBuilder( + CUDNN_BACKEND_OPERATION_REDUCTION_DESCRIPTOR) + .setxDesc(dotProductTensor) + .setyDesc(afterReductionTensor) + .setreductionDesc(reductionMaxDesc) + .build(); + ops.push_back(std::move(reductionMax_op)); + + + /******************************************************************************* + * reduction(dO * O) * scale prob -> softmaxSum */ + + auto softmaxSumTensor = tensor_create( + CUDNN_DATA_FLOAT, S_SUM_ID, reduction_dim, + reduction_stride, false, false); // not virtual + auto scaleProbTensor = tensor_create( + CUDNN_DATA_FLOAT, SCALE_PROB, scale_dim, + scale_stride, false, true); // not virtual + auto softmaxSumOp = binary_pw_op_create( + afterReductionTensor, scaleProbTensor, + softmaxSumTensor, multiplyDesc); + ops.push_back(std::move(softmaxSumOp)); + + /******************************************************************************* + * Q @ K.T -> P */ + + // Inputs from fprop + auto qTensor = tensor_create(tensorType, Q_ID, q_dim, q_stride, false, false); + auto kTransposeTensor = tensor_create( + tensorType, K_ID, k_transpose_dim, + k_transpose_stride, false, false); + auto pTensor = tensor_create( + CUDNN_DATA_FLOAT, VIRTUAL_ID + 2, p_dim, + p_stride, true, false); // is virtual + + // matmul to calculate dvTensor + auto matmul_0_Desc = cudnn_frontend::MatMulDescBuilder() + .setComputeType(CUDNN_DATA_FLOAT) + .build(); + + auto matmul_op0 = cudnn_frontend::OperationBuilder( + CUDNN_BACKEND_OPERATION_MATMUL_DESCRIPTOR) + .setaMatDesc(qTensor) + .setbMatDesc(kTransposeTensor) + .setcMatDesc(pTensor) + .setmatmulDesc(matmul_0_Desc) + .build(); + + ops.push_back(std::move(matmul_op0)); + + /******************************************************************************* + * P * bmmScale -> pAfterScale */ + + auto bmmScaleTensor = tensor_create( + CUDNN_DATA_FLOAT, S_CONST_ID, scale_dim, + scale_stride, false, true); // not virtual and by value + auto pAfterScaleTensor = tensor_create( + CUDNN_DATA_FLOAT, VIRTUAL_ID + 2000, p_dim, + p_stride, true, false); // virtual + auto scaleOp = binary_pw_op_create( + pTensor, bmmScaleTensor, pAfterScaleTensor, multiplyDesc); + ops.push_back(std::move(scaleOp)); + + /******************************************************************************* + * Causal masking -> pAfterMaskTensor */ + + auto pAfterMaskTensor = createCausalMask( + b, h, s_q, s_kv, d, layout, tensorType, &ops, pAfterScaleTensor); + + /******************************************************************************* + * pAfterMaskTensor - softmaxStats -> pAfterSubtract */ + + auto pAfterSubtractTensor = tensor_create( + CUDNN_DATA_FLOAT, VIRTUAL_ID + 3, p_dim, + p_stride, true, false); // is virtual + auto softmaxStatsTensor = tensor_create( + CUDNN_DATA_FLOAT, S_STATS_ID, 
reduction_dim, + reduction_stride, false, false); // not virtual + auto subtractDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_SUB); + auto subtract_op = binary_pw_op_create( + pAfterMaskTensor, softmaxStatsTensor, + pAfterSubtractTensor, subtractDesc); + ops.push_back(std::move(subtract_op)); + + /******************************************************************************* + * e^(pAfterSubtract) -> pAfterSoftmax */ + + auto pAfterSoftmaxTensor = tensor_create( + CUDNN_DATA_FLOAT, VIRTUAL_ID + 4, p_dim, + p_stride, true, false); // is virtual + auto expDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_EXP); + auto exp_op = unary_pw_op_create( + pAfterSubtractTensor, pAfterSoftmaxTensor, expDesc); + ops.push_back(std::move(exp_op)); + + /******************************************************************************* + * Dropout -> afterScaleDropout */ + + auto dropoutMaskTensor = tensor_create( + CUDNN_DATA_FLOAT, VIRTUAL_ID + 5, p_dim, + p_stride, true, false); // is virtual + auto afterScaleDropoutTensor = createDropoutBackward( + b, h, s_q, s_kv, d, dropout_probability, tensorType, + &ops, pAfterSoftmaxTensor, dropoutMaskTensor); + + /******************************************************************************* + * afterScaleDropout -> sTransposeTensor */ + + auto sTransposeTensor = tensor_create( + tensorType, VIRTUAL_ID + 6, p_transpose_dim, + p_transpose_stride, true, false); // is virtual + auto reshape_op = cudnn_frontend::OperationBuilder( + CUDNN_BACKEND_OPERATION_RESHAPE_DESCRIPTOR) + .setxDesc(afterScaleDropoutTensor) + .setyDesc(sTransposeTensor) + .build(); + ops.push_back(std::move(reshape_op)); + + // Outputs of bprop + int64_t dqkv_dim[4] = {b, h, s_kv, d}; + int64_t dqkv_stride[4]; + generateMatrixStrides( + b, h, s_q, s_kv, d, dqkv_stride, + layout, NVTE_QKV_Matrix::NVTE_Q_Matrix); + + // Outputs of backprop + auto dQTensor = tensor_create(tensorType, dQ_ID, dqkv_dim, dqkv_stride, false, false); + auto dKTensor = tensor_create(tensorType, dK_ID, dqkv_dim, dqkv_stride, false, false); + auto dVTensor = tensor_create(tensorType, dV_ID, dqkv_dim, dqkv_stride, false, false); + // not virtual + + /******************************************************************************* + * sTransposeTensor @ dO -> dV */ + + auto matmul_1_Desc = cudnn_frontend::MatMulDescBuilder() + .setComputeType(CUDNN_DATA_FLOAT) + .build(); + + auto matmul_op1 = cudnn_frontend::OperationBuilder( + CUDNN_BACKEND_OPERATION_MATMUL_DESCRIPTOR) + .setaMatDesc(sTransposeTensor) + .setbMatDesc(dOTensor) + .setcMatDesc(dVTensor) + .setmatmulDesc(matmul_1_Desc) + .build(); + + ops.push_back(std::move(matmul_op1)); + + /******************************************************************************* + * dO @ V.T -> dS */ + + auto vTransposeTensor = tensor_create( + tensorType, V_ID, v_transpose_dim, + v_transpose_stride, false, false); + auto dSTensor = tensor_create( + CUDNN_DATA_FLOAT, VIRTUAL_ID + 7, p_dim, + p_stride, true, false); // is virtual + + auto matmul_2_Desc = cudnn_frontend::MatMulDescBuilder() + .setComputeType(CUDNN_DATA_FLOAT) + .build(); + + auto matmul_op2 = cudnn_frontend::OperationBuilder( + CUDNN_BACKEND_OPERATION_MATMUL_DESCRIPTOR) + .setaMatDesc(dOTensor) + .setbMatDesc(vTransposeTensor) + .setcMatDesc(dSTensor) + .setmatmulDesc(matmul_2_Desc) + .build(); + + ops.push_back(std::move(matmul_op2)); + + /******************************************************************************* + * dS * dropoutMask -> dSAfterDropout */ + + auto dSAfterDropoutTensor = tensor_create( 
+ CUDNN_DATA_FLOAT, VIRTUAL_ID + 8, p_dim, + p_stride, true, false); // is virtual + auto multiply_op = binary_pw_op_create( + dSTensor, dropoutMaskTensor, + dSAfterDropoutTensor, multiplyDesc); + ops.push_back(std::move(multiply_op)); + + /******************************************************************************* + * dSAfterDropout - softmaxSum -> dsAfterSubtract */ + + auto dsAfterSubtractTensor = tensor_create( + CUDNN_DATA_FLOAT, VIRTUAL_ID + 9, p_dim, + p_stride, true, false); // is virtual + auto subtract_op2 = binary_pw_op_create( + dSAfterDropoutTensor, softmaxSumTensor, + dsAfterSubtractTensor, subtractDesc); + ops.push_back(std::move(subtract_op2)); + + /******************************************************************************* + * dsAfterSubtract * afterSoftmax -> dP */ + + auto dPTensor = tensor_create( + CUDNN_DATA_FLOAT, VIRTUAL_ID + 10, p_dim, + p_stride, true, false); // is virtual + auto multiply_op2 = binary_pw_op_create( + dsAfterSubtractTensor, pAfterSoftmaxTensor, + dPTensor, multiplyDesc); + ops.push_back(std::move(multiply_op2)); + + /******************************************************************************* + * dP * scaleDropout -> dPAfterDropoutScale */ + + auto dPAfterDropoutScaleTensor = tensor_create( + CUDNN_DATA_FLOAT, VIRTUAL_ID + 11, p_dim, + p_stride, true, false); // is virtual + auto scaleDropoutTensor = tensor_create( + CUDNN_DATA_FLOAT, D_CONST_ID, scale_dim, + scale_stride, false, true); // is by value + auto multiply_op3 = binary_pw_op_create( + dPTensor, scaleDropoutTensor, + dPAfterDropoutScaleTensor, multiplyDesc); + ops.push_back(std::move(multiply_op3)); + + /******************************************************************************* + * dPAfterDropoutScale * bmmScale -> dPScaledTensor */ + + auto dPScaledTensor = tensor_create( + CUDNN_DATA_FLOAT, VIRTUAL_ID + 12, p_dim, + p_stride, true, false); // is virtual + auto multiply_op4 = binary_pw_op_create( + dPAfterDropoutScaleTensor, bmmScaleTensor, + dPScaledTensor, multiplyDesc); + ops.push_back(std::move(multiply_op4)); + + /******************************************************************************* + * K.T -> K */ + + int64_t kDim[4] = {b, h, s_kv, d}; + int64_t kStride[4]; + generateMatrixStrides( + b, h, s_q, s_kv, d, kStride, + layout, NVTE_QKV_Matrix::NVTE_K_Matrix); + auto kTensor = tensor_create( + tensorType, VIRTUAL_ID + 13, kDim, + kStride, true, false); // is virtual + auto reshape_op2 = cudnn_frontend::OperationBuilder( + CUDNN_BACKEND_OPERATION_RESHAPE_DESCRIPTOR) + .setxDesc(kTransposeTensor) + .setyDesc(kTensor) + .build(); + ops.push_back(std::move(reshape_op2)); + + /******************************************************************************* + * dP @ K -> dqAccumTensor */ + + auto dqAccumTensor = cudnn_frontend::TensorBuilder() + .setDim(4, dqkv_dim) + .setStride(4, dqkv_stride) + .setId(dQ_ACCUM_ID) + .setAlignment(16) // 16B alignment is needed to run a tensor core engine + .setDataType(CUDNN_DATA_FLOAT) + .setVirtual(false) + .setByValue(false) + .setReorderType( + cudnn_frontend::cudnnBackendTensorReordering_t::CUDNN_TENSOR_REORDERING_F16x16) + .build(); + + auto matmul_3_Desc = cudnn_frontend::MatMulDescBuilder() + .setComputeType(CUDNN_DATA_FLOAT) + .build(); + auto matmul_op3 = cudnn_frontend::OperationBuilder( + CUDNN_BACKEND_OPERATION_MATMUL_DESCRIPTOR) + .setaMatDesc(dPTensor) + .setbMatDesc(kTensor) + .setcMatDesc(dqAccumTensor) + .setmatmulDesc(matmul_3_Desc) + .build(); + + ops.push_back(std::move(matmul_op3)); + + 
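+    // Note: dqAccumTensor is a non-virtual FP32 accumulator backed by extra workspace
+    // memory (dQ_ACCUM_ID is bound to a workspace pointer below); the identity pointwise
+    // op at the end of the graph copies it into the lower-precision dQ output tensor.
+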
/******************************************************************************* + * dP.T @ Q -> dK */ + + auto dPTransposeTensor = tensor_create( + CUDNN_DATA_FLOAT, VIRTUAL_ID + 14, p_transpose_dim, + p_transpose_stride, true, false); // is virtual + auto reshape_op3 = cudnn_frontend::OperationBuilder( + CUDNN_BACKEND_OPERATION_RESHAPE_DESCRIPTOR) + .setxDesc(dPTensor) + .setyDesc(dPTransposeTensor) + .build(); + ops.push_back(std::move(reshape_op3)); + + auto matmul_4_Desc = cudnn_frontend::MatMulDescBuilder() + .setComputeType(CUDNN_DATA_FLOAT) + .build(); + auto matmul_op4 = cudnn_frontend::OperationBuilder( + CUDNN_BACKEND_OPERATION_MATMUL_DESCRIPTOR) + .setaMatDesc(dPTransposeTensor) + .setbMatDesc(qTensor) + .setcMatDesc(dKTensor) + .setmatmulDesc(matmul_4_Desc) + .build(); + + ops.push_back(std::move(matmul_op4)); + + /******************************************************************************* + * dqAccumTensor @ identity -> dqTensor */ + + auto identityDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_IDENTITY); + auto identity_op = unary_pw_op_create(dqAccumTensor, dQTensor, identityDesc); + ops.push_back(std::move(identity_op)); + + for (unsigned int i = 0; i < ops.size(); i++) { + all_ops.push_back(&ops[i]); + } + + // Create an Operation Graph + auto opGraph = cudnn_frontend::OperationGraphBuilder() + .setHandle(handle) + .setOperationGraph(all_ops.size(), all_ops.data()) + .build(); + + cudnn_frontend::EngineConfigList filtered_configs; + auto statuses = cudnn_frontend::get_heuristics_list<1>( + {"heuristics_instant"}, opGraph, allowAllConfig, filtered_configs, true); + + if (filtered_configs.size() == 0) { + cudnn_frontend::set_error_and_throw_exception( + nullptr, CUDNN_STATUS_NOT_SUPPORTED, + "run_mha_bprop: No config returned by the heuristics"); + } + + auto plan = cudnn_frontend::ExecutionPlanBuilder() + .setHandle(handle) + .setEngineConfig(filtered_configs[0], opGraph.getTag()) + .build(); + + cache.insert({descriptor, plan}); + return plan; + }; + + auto plan = get_plan(fmha_bprop_cache, descriptor); + + auto plan_workspace_size = plan.getWorkspaceSize(); + + // Exit to request upper level API to allocate memory if needed + size_t softmaxSum_workspace_size = b * h * s_q * sizeof(float); + size_t dqAccum_workspace_size = b * s_q * h * d * sizeof(float); + if (workspace == nullptr) { + *workspace_size = plan_workspace_size + softmaxSum_workspace_size + + dqAccum_workspace_size; + return; + } + + void *devPtrSoftmaxSum = static_cast(workspace) + plan_workspace_size; + void *devPtrdQAccumulator = static_cast(devPtrSoftmaxSum) + + softmaxSum_workspace_size; + NVTE_CHECK_CUDA(cudaMemset(devPtrdQAccumulator, 0, dqAccum_workspace_size)); + + std::set> data_ptrs; + // add all the data pointers to be used in the variant pack + float negInfinity = -1.0E+10f; + float scale_dropout = 1.0f/(1.0f - dropout_probability); + data_ptrs.insert(std::pair(dQ_ID, devPtrdQ)); + data_ptrs.insert(std::pair(dQ_ACCUM_ID, devPtrdQAccumulator)); + data_ptrs.insert(std::pair(dK_ID, devPtrdK)); + data_ptrs.insert(std::pair(dV_ID, devPtrdV)); + + data_ptrs.insert(std::pair(Q_ID, devPtrQ)); + data_ptrs.insert(std::pair(K_ID, devPtrKTranspose)); + data_ptrs.insert(std::pair(V_ID, devPtrVTranspose)); + data_ptrs.insert(std::pair(O_ID, devPtrO)); + data_ptrs.insert(std::pair(dO_ID, devPtrdO)); + data_ptrs.insert(std::pair(S_STATS_ID, devPtrSoftmaxStats)); + data_ptrs.insert(std::pair(S_SUM_ID, devPtrSoftmaxSum)); + data_ptrs.insert(std::pair(D_SEED_ID, devPtrDropoutSeed)); + 
data_ptrs.insert(std::pair(D_OFFSET_ID, devPtrDropoutOffset)); + data_ptrs.insert(std::pair(MASK_VAL_ID, &negInfinity)); + + float scaleProb = 1.0f - dropout_probability; + data_ptrs.insert(std::pair(D_CONST_ID, &scale_dropout)); + data_ptrs.insert(std::pair(S_CONST_ID, &scaling_factor)); + data_ptrs.insert(std::pair(SCALE_PROB, &scaleProb)); + + auto variantPack = cudnn_frontend::VariantPackBuilder() + .setWorkspacePointer(workspace) + .setDataPointers(data_ptrs) + .build(); + + NVTE_CHECK_CUDNN( + cudnnBackendExecute(handle, plan.get_raw_desc(), variantPack.get_raw_desc())); + } catch (cudnn_frontend::cudnnException &e) { + NVTE_ERROR(e.what()); + } +} + +} // namespace fused_attn + +using namespace transformer_engine::fused_attn; +void fused_attn_arbitrary_seqlen_fwd_qkvpacked( + size_t batch, size_t max_seqlen, size_t num_head, size_t head_dim, bool is_training, + float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, + NVTE_Mask_Type mask_type, const Tensor *input_QKV, const Tensor *input_Bias, Tensor *output_O, + NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens, const Tensor *rng_state, + Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) { + using namespace transformer_engine; + + NVTE_CHECK(qkv_layout == NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED, + "qkv_layout must be NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED."); + + // QKV shape is [b, s, 3, h, d] + void *devPtrQKV = input_QKV->data.dptr; + const auto stride = num_head * head_dim; + + void *devPtrQ = static_cast(devPtrQKV); + void *devPtrK = static_cast(static_cast(devPtrQKV) + stride); + void *devPtrV = static_cast(static_cast(devPtrQKV) + 2 * stride); + + void *devPtrO = output_O->data.dptr; + + void *devPtrS = nullptr; + + if (Aux_CTX_Tensors->size == 0) { + Aux_CTX_Tensors->size = 2; + Tensor *output_S = reinterpret_cast(Aux_CTX_Tensors->tensors[0]); + output_S->data.dptr = nullptr; + output_S->data.shape = {batch, num_head, max_seqlen, 1}; + output_S->data.dtype = DType::kFloat32; + Tensor *output_rng_state = reinterpret_cast(Aux_CTX_Tensors->tensors[1]); + output_rng_state->data.dptr = nullptr; + output_rng_state->data.shape = {2}; + output_rng_state->data.dtype = DType::kInt64; + } else if (Aux_CTX_Tensors->size == 2) { + Tensor *output_S = reinterpret_cast(Aux_CTX_Tensors->tensors[0]); + devPtrS = output_S->data.dptr; + Tensor *output_rng_state = reinterpret_cast(Aux_CTX_Tensors->tensors[1]); + output_rng_state->data.dptr = rng_state->data.dptr; + } + + void* devPtrDropoutSeed = rng_state->data.dptr; + void* devPtrDropoutOffset = reinterpret_cast( + reinterpret_cast(rng_state->data.dptr) + 1); + + const DType QKV_type = input_QKV->data.dtype; + size_t workspace_size = 0; + + fused_attn_arbitrary_seqlen_fwd_impl(batch, num_head, max_seqlen, max_seqlen, head_dim, + is_training, attn_scale, p_dropout, qkv_layout, + devPtrQ, devPtrK, devPtrV, devPtrS, devPtrO, + devPtrDropoutSeed, devPtrDropoutOffset, + get_cudnn_dtype(QKV_type), + workspace->data.dptr, &workspace_size, stream, handle); + + if (workspace_size > 0) { + if (workspace->data.dptr == nullptr) { + workspace->data.shape = {workspace_size}; + workspace->data.dtype = DType::kByte; + return; + } + } else if (workspace_size == 0) { + workspace->data.shape = {1}; + workspace->data.dtype = DType::kByte; + return; + } +} + +void fused_attn_arbitrary_seqlen_bwd_qkvpacked(size_t batch, size_t max_seqlen, size_t num_head, + size_t head_dim, float attn_scale, float p_dropout, + NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, + 
NVTE_Mask_Type mask_type, + const Tensor *input_QKV, const Tensor *input_O, + const Tensor *input_dO, Tensor *output_S, + Tensor *output_dQKV, Tensor *output_dBias, + const Tensor *cu_seqlens, const Tensor *rng_state, + Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) { + using namespace transformer_engine; + + NVTE_CHECK(qkv_layout == NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED, + "qkv_layout must be NVTE_QKV_INTERLEAVED."); + + // QKV shape is [b, s, 3, h, d] + void *devPtrQKV = input_QKV->data.dptr; + + auto stride = num_head * head_dim; + void *devPtrQ = devPtrQKV; + void *devPtrK = static_cast(static_cast(devPtrQKV) + stride); + void *devPtrV = static_cast(static_cast(devPtrQKV) + 2 * stride); + + void* devPtrO = input_O->data.dptr; + void *devPtrdO = input_dO->data.dptr; + + // dQKV shape is [b, s, 3, h, d] + void *devPtrdQKV = output_dQKV->data.dptr; + void *devPtrdQ = devPtrdQKV; + void *devPtrdK = static_cast(static_cast(devPtrdQKV) + stride); + void *devPtrdV = static_cast(static_cast(devPtrdQKV) + 2 * stride); + + void *devPtrSoftmaxStats = nullptr; + devPtrSoftmaxStats = output_S->data.dptr; + + void* devPtrDropoutSeed = rng_state->data.dptr; + void* devPtrDropoutOffset = reinterpret_cast( + reinterpret_cast(rng_state->data.dptr) + 1); + + const auto qkv_type = input_QKV->data.dtype; + size_t workspace_size = 0; + + fused_attn_arbitrary_seqlen_bwd_impl(batch, num_head, max_seqlen, max_seqlen, head_dim, + attn_scale, p_dropout, qkv_layout, + devPtrQ, devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats, + devPtrdQ, devPtrdK, devPtrdV, devPtrdO, + devPtrDropoutSeed, devPtrDropoutOffset, + get_cudnn_dtype(qkv_type), + workspace->data.dptr, &workspace_size, stream, handle); + + if (workspace_size > 0) { + if (workspace->data.dptr == nullptr) { + workspace->data.shape = {workspace_size}; + workspace->data.dtype = DType::kByte; + return; + } + } else if (workspace_size == 0) { + workspace->data.shape = {1}; + workspace->data.dtype = DType::kByte; + return; + } +} +} // namespace transformer_engine +#endif // CUDNN_VERSION >= 8900 diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h new file mode 100644 index 0000000000..68ebe0c7c0 --- /dev/null +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h @@ -0,0 +1,44 @@ +/************************************************************************* + * Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/*! 
\file fused_attn_f16_arbitrary_seqlen.h
+ * \brief Functions for fused attention with seqlen > 512
+ */
+
+#ifndef TRANSFORMER_ENGINE_COMMON_FUSED_ATTN_FUSED_ATTN_ARBITRARY_SEQLEN_H_
+#define TRANSFORMER_ENGINE_COMMON_FUSED_ATTN_FUSED_ATTN_ARBITRARY_SEQLEN_H_
+
+#include "transformer_engine/fused_attn.h"
+
+#include <cudnn.h>
+
+#include "common/common.h"
+
+namespace transformer_engine {
+#if (CUDNN_VERSION >= 8900)
+void fused_attn_arbitrary_seqlen_fwd_qkvpacked(size_t batch, size_t max_seqlen, size_t num_head,
+                                               size_t head_size, bool is_training, float attn_scale,
+                                               float p_dropout, NVTE_QKV_Layout qkv_layout,
+                                               NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
+                                               const Tensor *input_QKV, const Tensor *input_Bias,
+                                               Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors,
+                                               const Tensor *cu_seqlens, const Tensor *rng_state,
+                                               Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
+
+void fused_attn_arbitrary_seqlen_bwd_qkvpacked(size_t batch, size_t max_seqlen, size_t num_head,
+                                               size_t head_dim, float attn_scale, float p_dropout,
+                                               NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
+                                               NVTE_Mask_Type mask_type, const Tensor *input_QKV,
+                                               const Tensor *input_O,
+                                               const Tensor *input_dO, Tensor *output_S,
+                                               Tensor *output_dQKV, Tensor *output_dBias,
+                                               const Tensor *cu_seqlens, const Tensor *rng_state,
+                                               Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
+
+#endif  // CUDNN_VERSION >= 8900
+}  // namespace transformer_engine
+
+#endif  // TRANSFORMER_ENGINE_COMMON_FUSED_ATTN_FUSED_ATTN_ARBITRARY_SEQLEN_H_
diff --git a/transformer_engine/common/fused_attn/fused_attn_fp16_bf16_max_seqlen_512.cu b/transformer_engine/common/fused_attn/fused_attn_f16_max512_seqlen.cu
similarity index 98%
rename from transformer_engine/common/fused_attn/fused_attn_fp16_bf16_max_seqlen_512.cu
rename to transformer_engine/common/fused_attn/fused_attn_f16_max512_seqlen.cu
index e8906b31c4..932414ffc0 100644
--- a/transformer_engine/common/fused_attn/fused_attn_fp16_bf16_max_seqlen_512.cu
+++ b/transformer_engine/common/fused_attn/fused_attn_f16_max512_seqlen.cu
@@ -4,7 +4,7 @@
 * See LICENSE for license information.
************************************************************************/ -#include "fused_attn_fp16_bf16_max_seqlen_512.h" +#include "fused_attn_f16_max512_seqlen.h" #include #include @@ -1239,7 +1239,7 @@ void fused_attn_max_512_fwd_qkvpacked( size_t batch, size_t max_seqlen, size_t num_head, size_t head_dim, bool is_training, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, const Tensor *input_QKV, const Tensor *input_Bias, Tensor *output_O, - NVTETensorPack *Aux_Output_Tensors, const Tensor *cu_seqlens, const Tensor *rng_state, + NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens, const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) { using namespace transformer_engine; @@ -1260,14 +1260,14 @@ void fused_attn_max_512_fwd_qkvpacked( void *devPtrS = nullptr; - if (Aux_Output_Tensors->size == 0) { - Aux_Output_Tensors->size = 1; - Tensor *output_S = reinterpret_cast(Aux_Output_Tensors->tensors[0]); + if (Aux_CTX_Tensors->size == 0) { + Aux_CTX_Tensors->size = 1; + Tensor *output_S = reinterpret_cast(Aux_CTX_Tensors->tensors[0]); output_S->data.dptr = nullptr; output_S->data.shape = {batch, num_head, max_seqlen, max_seqlen}; output_S->data.dtype = input_QKV->data.dtype; - } else if (Aux_Output_Tensors->size == 1) { - Tensor *output_S = reinterpret_cast(Aux_Output_Tensors->tensors[0]); + } else if (Aux_CTX_Tensors->size == 1) { + Tensor *output_S = reinterpret_cast(Aux_CTX_Tensors->tensors[0]); devPtrS = output_S->data.dptr; } @@ -1307,7 +1307,7 @@ void fused_attn_max_512_fwd_kvpacked(size_t batch, size_t q_max_seqlen, size_t k NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, const Tensor *input_Q, const Tensor *input_KV, const Tensor *input_Bias, Tensor *output_O, - NVTETensorPack *Aux_Output_Tensors, const Tensor *q_cu_seqlens, + NVTETensorPack *Aux_CTX_Tensors, const Tensor *q_cu_seqlens, const Tensor *kv_cu_seqlens, const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) { using namespace transformer_engine; @@ -1336,14 +1336,14 @@ void fused_attn_max_512_fwd_kvpacked(size_t batch, size_t q_max_seqlen, size_t k const DType kv_type = input_KV->data.dtype; NVTE_CHECK(q_type == kv_type, "data type of Q must be equal to data type of KV."); - if (Aux_Output_Tensors->size == 0) { - Aux_Output_Tensors->size = 1; - Tensor *output_S = reinterpret_cast(Aux_Output_Tensors->tensors[0]); + if (Aux_CTX_Tensors->size == 0) { + Aux_CTX_Tensors->size = 1; + Tensor *output_S = reinterpret_cast(Aux_CTX_Tensors->tensors[0]); output_S->data.dptr = nullptr; output_S->data.shape = {batch, num_head, q_max_seqlen, kv_max_seqlen}; output_S->data.dtype = q_type; - } else if (Aux_Output_Tensors->size == 1) { - Tensor *output_S = reinterpret_cast(Aux_Output_Tensors->tensors[0]); + } else if (Aux_CTX_Tensors->size == 1) { + Tensor *output_S = reinterpret_cast(Aux_CTX_Tensors->tensors[0]); devPtrS = output_S->data.dptr; } @@ -1381,7 +1381,7 @@ void fused_attn_max_512_bwd_qkvpacked(size_t batch, size_t max_seqlen, size_t nu size_t head_dim, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, const Tensor *input_QKV, - const Tensor *input_dO, const NVTETensorPack *Aux_CTX_Tensors, + const Tensor *input_dO, Tensor *output_S, Tensor *output_dQKV, Tensor *output_dBias, const Tensor *cu_seqlens, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) { @@ -1408,12 +1408,8 @@ void 
fused_attn_max_512_bwd_qkvpacked(size_t batch, size_t max_seqlen, size_t nu void *devPtrdBias = output_dBias->data.dptr; - NVTE_CHECK(Aux_CTX_Tensors->size == 1); - void *devPtrS = nullptr; - if (Aux_CTX_Tensors->size == 1) { - Tensor *output_S = reinterpret_cast(Aux_CTX_Tensors->tensors[0]); - devPtrS = output_S->data.dptr; - } + void *devPtrS = output_S->data.dptr; + // devPtrdS reuses the memory of devPtrS void *devPtrdS = devPtrS; @@ -1446,7 +1442,7 @@ void fused_attn_max_512_bwd_kvpacked(size_t batch, size_t q_max_seqlen, size_t k float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, const Tensor *input_Q, const Tensor *input_KV, - const Tensor *input_dO, const NVTETensorPack *Aux_CTX_Tensors, + const Tensor *input_dO, Tensor *output_S, Tensor *output_dQ, Tensor *output_dKV, Tensor *output_dBias, const Tensor *q_cu_seqlens, const Tensor *kv_cu_seqlens, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) { @@ -1472,12 +1468,8 @@ void fused_attn_max_512_bwd_kvpacked(size_t batch, size_t q_max_seqlen, size_t k void *devPtrdBias = output_dBias->data.dptr; - NVTE_CHECK(Aux_CTX_Tensors->size == 1); - void *devPtrS = nullptr; - if (Aux_CTX_Tensors->size == 1) { - Tensor *output_S = reinterpret_cast(Aux_CTX_Tensors->tensors[0]); - devPtrS = output_S->data.dptr; - } + void *devPtrS = output_S->data.dptr; + // devPtrdS reuses the memory of devPtrS void *devPtrdS = devPtrS; diff --git a/transformer_engine/common/fused_attn/fused_attn_fp16_bf16_max_seqlen_512.h b/transformer_engine/common/fused_attn/fused_attn_f16_max512_seqlen.h similarity index 91% rename from transformer_engine/common/fused_attn/fused_attn_fp16_bf16_max_seqlen_512.h rename to transformer_engine/common/fused_attn/fused_attn_f16_max512_seqlen.h index 3e11a1f02a..75545d0b40 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp16_bf16_max_seqlen_512.h +++ b/transformer_engine/common/fused_attn/fused_attn_f16_max512_seqlen.h @@ -24,7 +24,7 @@ void fused_attn_max_512_fwd_qkvpacked(size_t batch, size_t max_seqlen, size_t nu float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, const Tensor *input_QKV, const Tensor *input_Bias, - Tensor *output_O, NVTETensorPack *Aux_Output_Tensors, + Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens, const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); @@ -34,7 +34,7 @@ void fused_attn_max_512_fwd_kvpacked(size_t batch, size_t q_max_seqlen, size_t k NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, const Tensor *input_Q, const Tensor *input_KV, const Tensor *input_Bias, Tensor *output_O, - NVTETensorPack *Aux_Output_Tensors, const Tensor *q_cu_seqlens, + NVTETensorPack *Aux_CTX_Tensors, const Tensor *q_cu_seqlens, const Tensor *kv_cu_seqlens, const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); @@ -42,7 +42,7 @@ void fused_attn_max_512_bwd_qkvpacked(size_t batch, size_t max_seqlen, size_t nu size_t head_dim, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, const Tensor *input_QKV, - const Tensor *input_dO, const NVTETensorPack *Aux_CTX_Tensors, + const Tensor *input_dO, Tensor *output_S, Tensor *output_dQKV, Tensor *output_dBias, const Tensor *cu_seqlens, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); @@ -52,7 +52,7 @@ void fused_attn_max_512_bwd_kvpacked(size_t batch, size_t q_max_seqlen, size_t k float p_dropout, 
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, const Tensor *input_Q, const Tensor *input_KV, - const Tensor *input_dO, const NVTETensorPack *Aux_CTX_Tensors, + const Tensor *input_dO, Tensor *output_S, Tensor *output_dQ, Tensor *output_dKV, Tensor *output_dBias, const Tensor *q_cu_seqlens, const Tensor *kv_cu_seqlens, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.cu b/transformer_engine/common/fused_attn/fused_attn_fp8.cu index 768ac8eb20..8fc208bfcd 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.cu +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.cu @@ -991,7 +991,7 @@ static cudnn_frontend::Tensor createdSQBMM( } // fused attention FWD FP8 -void fa_fwd_fp8(int64_t b, int64_t s_q, int64_t s_kv, int64_t h, int64_t d, +void fused_attn_fp8_fwd_impl(int64_t b, int64_t s_q, int64_t s_kv, int64_t h, int64_t d, bool isTraining, float attnScale, float dropoutProbability, NVTE_QKV_Layout layout, void* devPtrQ, void* devPtrK, void* devPtrV, @@ -1303,7 +1303,7 @@ void fa_fwd_fp8(int64_t b, int64_t s_q, int64_t s_kv, int64_t h, int64_t d, } // fused attention BWD FP8 -void fa_bwd_fp8(int64_t b, int64_t s_q, int64_t s_kv, int64_t h, int64_t d, +void fused_attn_fp8_bwd_impl(int64_t b, int64_t s_q, int64_t s_kv, int64_t h, int64_t d, float attnScale, float dropoutProbability, NVTE_QKV_Layout layout, void* devPtrQ, void* devPtrK, void* devPtrV, void* devPtrM, void* devPtrZInv, @@ -1858,7 +1858,7 @@ void fa_bwd_fp8(int64_t b, int64_t s_q, int64_t s_kv, int64_t h, int64_t d, #if (CUDNN_VERSION >= 8900) // fused attention FWD FP8 with packed QKV -void fused_attn_fwd_fp8_qkvpacked( +void fused_attn_fp8_fwd_qkvpacked( size_t b, size_t max_seqlen, size_t h, size_t d, bool is_training, float attn_scale, @@ -1866,7 +1866,7 @@ void fused_attn_fwd_fp8_qkvpacked( const Tensor *input_QKV, Tensor *input_output_S, Tensor *output_O, - NVTETensorPack* Aux_Output_Tensors, + NVTETensorPack* Aux_CTX_Tensors, const Tensor *cu_seqlens, const Tensor *rng_state, Tensor *workspace, @@ -1888,23 +1888,29 @@ void fused_attn_fwd_fp8_qkvpacked( void* devPtrM = nullptr; void* devPtrZInv = nullptr; - if (Aux_Output_Tensors->size == 0) { + if (Aux_CTX_Tensors->size == 0) { if (is_training) { - Aux_Output_Tensors->size = 2; - Tensor *output_M = reinterpret_cast(Aux_Output_Tensors->tensors[0]); - Tensor *output_ZInv = reinterpret_cast(Aux_Output_Tensors->tensors[1]); + Aux_CTX_Tensors->size = 3; + Tensor *output_M = reinterpret_cast(Aux_CTX_Tensors->tensors[0]); + Tensor *output_ZInv = reinterpret_cast(Aux_CTX_Tensors->tensors[1]); + Tensor *output_rng_state = reinterpret_cast(Aux_CTX_Tensors->tensors[2]); output_M->data.dptr = nullptr; output_M->data.shape = {b, h, max_seqlen, 1}; output_M->data.dtype = DType::kFloat32; output_ZInv->data.dptr = nullptr; output_ZInv->data.shape = {b, h, max_seqlen, 1}; output_ZInv->data.dtype = DType::kFloat32; + output_rng_state->data.dptr = nullptr; + output_rng_state->data.shape = {2}; + output_rng_state->data.dtype = DType::kInt64; } - } else if (Aux_Output_Tensors->size == 2) { - Tensor *output_M = reinterpret_cast(Aux_Output_Tensors->tensors[0]); - Tensor *output_ZInv = reinterpret_cast(Aux_Output_Tensors->tensors[1]); + } else if (Aux_CTX_Tensors->size == 3) { + Tensor *output_M = reinterpret_cast(Aux_CTX_Tensors->tensors[0]); + Tensor *output_ZInv = reinterpret_cast(Aux_CTX_Tensors->tensors[1]); + Tensor *output_rng_state = 
reinterpret_cast(Aux_CTX_Tensors->tensors[2]); devPtrM = output_M->data.dptr; devPtrZInv = output_ZInv->data.dptr; + output_rng_state->data.dptr = rng_state->data.dptr; } void* devPtrAmaxS = input_output_S->amax.dptr; @@ -1921,7 +1927,7 @@ void fused_attn_fwd_fp8_qkvpacked( const DType QKV_type = input_QKV->data.dtype; size_t workspace_size = 0; - fused_attn::fa_fwd_fp8( + fused_attn::fused_attn_fp8_fwd_impl( b, max_seqlen, max_seqlen, h, d, is_training, attn_scale, p_dropout, qkv_layout, devPtrQ, devPtrK, devPtrV, @@ -1948,7 +1954,7 @@ void fused_attn_fwd_fp8_qkvpacked( } } // fused attention BWD FP8 with packed QKV -void fused_attn_bwd_fp8_qkvpacked( +void fused_attn_fp8_bwd_qkvpacked( size_t b, size_t max_seqlen, size_t h, size_t d, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, @@ -2011,7 +2017,7 @@ void fused_attn_bwd_fp8_qkvpacked( const DType QKV_type = input_QKV->data.dtype; size_t workspace_size = 0; - fused_attn::fa_bwd_fp8( + fused_attn::fused_attn_fp8_bwd_impl( b, max_seqlen, max_seqlen, h, d, attn_scale, p_dropout, qkv_layout, devPtrQ, devPtrK, devPtrV, diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.h b/transformer_engine/common/fused_attn/fused_attn_fp8.h index e43683d338..111dfddd10 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.h +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.h @@ -13,7 +13,7 @@ namespace transformer_engine { #if (CUDNN_VERSION >= 8900) // fused attention FWD FP8 with packed QKV -void fused_attn_fwd_fp8_qkvpacked( +void fused_attn_fp8_fwd_qkvpacked( size_t b, size_t max_seqlen, size_t h, size_t d, bool is_training, float attn_scale, @@ -21,7 +21,7 @@ void fused_attn_fwd_fp8_qkvpacked( const Tensor *input_QKV, Tensor *input_output_S, Tensor *output_O, - NVTETensorPack* Aux_Output_Tensors, + NVTETensorPack* Aux_CTX_Tensors, const Tensor *cu_seqlens, const Tensor *rng_state, Tensor *workspace, @@ -29,7 +29,7 @@ void fused_attn_fwd_fp8_qkvpacked( cudnnHandle_t handle); // fused attention BWD FP8 with packed QKV -void fused_attn_bwd_fp8_qkvpacked( +void fused_attn_fp8_bwd_qkvpacked( size_t b, size_t max_seqlen, size_t h, size_t d, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, diff --git a/transformer_engine/common/fused_attn/utils.cu b/transformer_engine/common/fused_attn/utils.cu index cae42bafa0..ebba6efa21 100644 --- a/transformer_engine/common/fused_attn/utils.cu +++ b/transformer_engine/common/fused_attn/utils.cu @@ -249,7 +249,6 @@ __global__ void cu_seqlens_to_actual_seqlens(size_t b, kv_seqlens[tid] = kv_cu_seqlens[tid + 1] - kv_cu_seqlens[tid]; } } - } // namespace fused_attn // get cuDNN data type diff --git a/transformer_engine/common/include/transformer_engine/fused_attn.h b/transformer_engine/common/include/transformer_engine/fused_attn.h index ed6dd4c041..447b1f9d6a 100644 --- a/transformer_engine/common/include/transformer_engine/fused_attn.h +++ b/transformer_engine/common/include/transformer_engine/fused_attn.h @@ -94,6 +94,38 @@ enum NVTE_Mask_Type { NVTE_CAUSAL_MASK = 2, }; +enum NVTE_Fused_Attn_Backend { + /*! No supported backend */ + NVTE_No_Backend = -1, + /*! cuDNN-based FP16/BF16 fused attention for <= 512 sequence length */ + NVTE_F16_max512_seqlen = 0, + /*! cuDNN-based FP16/BF16 fused attention for any sequence length */ + NVTE_F16_arbitrary_seqlen = 1, + /*! cuDNN-based FP8 fused attention for <= 512 sequence length */ + NVTE_FP8 = 2, +}; + +/*! \brief Get fused attention backend based on input parameters. 
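+ *
+ * The returned value identifies which fused attention implementation supports the given
+ * combination of inputs; NVTE_No_Backend indicates that no fused implementation applies.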
+ * + * \param[in] q_dtype The data type of Tensor Q. + * \param[in] kv_dtype The data type of Tensors K, V. + * \param[in] qkv_layout The layout of Tensors Q, K, V. + * \param[in] bias_type The attention bias type. + * \param[in] attn_mask_type The attention mask type. + * \param[in] dropout The dropout probability. + * \param[in] max_seqlen_q The sequence length of Q. + * \param[in] max_seqlen_kv The sequence length of K, V. + * \param[in] head_dim The head dimension of Q, K, V. + */ +NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( + NVTEDType q_dtype, + NVTEDType kv_dtype, + NVTE_QKV_Layout qkv_layout, + NVTE_Bias_Type bias_type, + NVTE_Mask_Type attn_mask_type, + float dropout, size_t max_seqlen_q, + size_t max_seqlen_kv, size_t head_dim); + /*! \brief Compute dot product attention with packed QKV input. * * Computes: @@ -104,36 +136,38 @@ enum NVTE_Mask_Type { * * Support Matrix: \verbatim - | precision | qkv layout | bias | mask | dropout | sequence length | head_dim | - | FP8 | QKV_INTERLEAVED | NO_BIAS | PADDING | Yes | <= 512 | 64 | - | FP16/BF16 | QKV_INTERLEAVED | NO_BIAS/POST_SCALE_BIAS | PADDING/CAUSAL | Yes | <= 512 | 64 | + | backend | precision | qkv layout | bias | mask | dropout | sequence length | head_dim | + | 0 | FP16/BF16 | QKV_INTERLEAVED | NO_BIAS/POST_SCALE_BIAS | PADDING/CAUSAL | Yes | <= 512 | 64 | + | 1 | FP16/BF16 | QKV_INTERLEAVED | NO_BIAS | CAUSAL | Yes | > 512 | 64, 128 | + | 2 | FP8 | QKV_INTERLEAVED | NO_BIAS | PADDING | Yes | <= 512 | 64 | \endverbatim * - * \param[in] QKV The QKV tensor in packed format, - * [total_seqs, 3, num_heads, head_dim]. - * \param[in] Bias The Bias tensor. - * \param[in,out] S The S tensor. - * \param[out] O The output O tensor. - * \param[out] Aux_Output_Tensors Auxiliary output tensors when training, e.g. M, ZInv. - * \param[in] cu_seqlens Accumulative sequence lengths, [batch_size + 1]. - * \param[in] rng_state Seed and offset of CUDA random number generator. - * \param[in] max_seqlen Max sequence length used for computing. - * It may be >= max(cu_seqlens). - * \param[in] is_training Whether this is in training mode or inference. - * \param[in] attn_scale Scaling factor for Q * K.T. - * \param[in] dropout Dropout probability. - * \param[in] qkv_layout QKV tensor's layout. - * \param[in] bias_type Bias type. - * \param[in] attn_mask_type Attention mask type. - * \param[in] workspace Workspace tensor. - * \param[in] stream CUDA stream used for this operation. + * \param[in] QKV The QKV tensor in packed format, + * [total_seqs, 3, num_heads, head_dim]. + * \param[in] Bias The Bias tensor. + * \param[in,out] S The S tensor. + * \param[out] O The output O tensor. + * \param[out] Aux_CTX_Tensors Auxiliary output tensors when training, + * e.g. M, ZInv, rng_state. + * \param[in] cu_seqlens Accumulative sequence lengths, [batch_size + 1]. + * \param[in] rng_state Seed and offset of CUDA random number generator. + * \param[in] max_seqlen Max sequence length used for computing, + * it may be >= max(cu_seqlens). + * \param[in] is_training Whether this is in training mode or inference. + * \param[in] attn_scale Scaling factor for Q * K.T. + * \param[in] dropout Dropout probability. + * \param[in] qkv_layout QKV tensor's layout. + * \param[in] bias_type Bias type. + * \param[in] attn_mask_type Attention mask type. + * \param[in] workspace Workspace tensor. + * \param[in] stream CUDA stream used for this operation. 
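+ *
+ * \note The number and contents of the tensors in Aux_CTX_Tensors depend on the
+ *       backend selected for the given inputs (see nvte_get_fused_attn_backend).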
*/ void nvte_fused_attn_fwd_qkvpacked( const NVTETensor QKV, const NVTETensor Bias, NVTETensor S, NVTETensor O, - NVTETensorPack* Aux_Output_Tensors, + NVTETensorPack* Aux_CTX_Tensors, const NVTETensor cu_seqlens, const NVTETensor rng_state, size_t max_seqlen, @@ -147,30 +181,32 @@ void nvte_fused_attn_fwd_qkvpacked( * * Support Matrix: \verbatim - | precision | qkv layout | bias | mask | dropout | sequence length | head_dim | - | FP8 | QKV_INTERLEAVED | NO_BIAS | PADDING | Yes | <= 512 | 64 | - | FP16/BF16 | QKV_INTERLEAVED | NO_BIAS/POST_SCALE_BIAS | PADDING/CAUSAL | Yes | <= 512 | 64 | + | backend | precision | qkv layout | bias | mask | dropout | sequence length | head_dim | + | 0 | FP16/BF16 | QKV_INTERLEAVED | NO_BIAS/POST_SCALE_BIAS | PADDING/CAUSAL | Yes | <= 512 | 64 | + | 1 | FP16/BF16 | QKV_INTERLEAVED | NO_BIAS | CAUSAL | Yes | > 512 | 64, 128 | + | 2 | FP8 | QKV_INTERLEAVED | NO_BIAS | PADDING | Yes | <= 512 | 64 | \endverbatim * - * \param[in] QKV The QKV tensor in packed format, - * [total_seqs, 3, num_heads, head_dim]. - * \param[in] O The O tensor from forward. - * \param[in] dO The gradient of the O tensor. - * \param[in] S The S tensor. - * \param[in,out] dP The gradient of the P tensor. - * \param[in] Aux_CTX_Tensors Auxiliary tensors from forward when in training mode. - * \param[out] dQKV The gradient of the QKV tensor. - * \param[out] dBias The gradient of the Bias tensor. - * \param[in] cu_seqlens Accumulative sequence lengths, [batch_size + 1]. - * \param[in] max_seqlen Max sequence length used for computing. - * It may be >= max(cu_seqlens). - * \param[in] attn_scale Scaling factor for Q * K.T. - * \param[in] dropout Dropout probability. - * \param[in] qkv_layout QKV tensor's layout. - * \param[in] bias_type Bias type. - * \param[in] attn_mask_type Attention mask type. - * \param[in] workspace Workspace tensor. - * \param[in] stream CUDA stream used for this operation. + * \param[in] QKV The QKV tensor in packed format, + * [total_seqs, 3, num_heads, head_dim]. + * \param[in] O The O tensor from forward. + * \param[in] dO The gradient of the O tensor. + * \param[in] S The S tensor. + * \param[in,out] dP The gradient of the P tensor. + * \param[in] Aux_CTX_Tensors Auxiliary tensors from context when in training mode, + * e.g. M, ZInv, rng_state. + * \param[out] dQKV The gradient of the QKV tensor. + * \param[out] dBias The gradient of the Bias tensor. + * \param[in] cu_seqlens Accumulative sequence lengths, [batch_size + 1]. + * \param[in] max_seqlen Max sequence length used for computing, + * it may be >= max(cu_seqlens). + * \param[in] attn_scale Scaling factor for Q * K.T. + * \param[in] dropout Dropout probability. + * \param[in] qkv_layout QKV tensor's layout. + * \param[in] bias_type Bias type. + * \param[in] attn_mask_type Attention mask type. + * \param[in] workspace Workspace tensor. + * \param[in] stream CUDA stream used for this operation. */ void nvte_fused_attn_bwd_qkvpacked( const NVTETensor QKV, @@ -199,31 +235,32 @@ void nvte_fused_attn_bwd_qkvpacked( * * Support Matrix: \verbatim - | precision | qkv layout | bias | mask | dropout | sequence length | head_dim | - | FP16/BF16 | QKV_INTERLEAVED | NO_BIAS/POST_SCALE_BIAS | PADDING/CAUSAL | Yes | <= 512 | 64 | + | backend | precision | qkv layout | bias | mask | dropout | sequence length | head_dim | + | 0 | FP16/BF16 | QKV_INTERLEAVED | NO_BIAS/POST_SCALE_BIAS | PADDING/CAUSAL | Yes | <= 512 | 64 | \endverbatim * - * \param[in] Q The Q tensor, [total_seqs_q, num_heads, head_dim]. 
- * \param[in] KV The KV tensor, [total_seqs_kv, 2, num_heads, head_dim]. - * \param[in] Bias The Bias tensor. - * \param[in,out] S The S tensor. - * \param[out] O The output O tensor. - * \param[out] Aux_Output_Tensors Auxiliary output tensors when training, e.g. M, ZInv. - * \param[in] cu_seqlens_q Accumulative sequence lengths for Q, [batch_size + 1]. - * \param[in] cu_seqlens_kv Accumulative sequence lengths for KV, [batch_size + 1]. - * \param[in] rng_state Seed and offset of CUDA random number generator. - * \param[in] max_seqlen_q Max sequence length used for computing - * for Q. It may be >= max(cu_seqlens_q). - * \param[in] max_seqlen_kv Max sequence length used for computing - * for KV. It may be >= max(cu_seqlens_kv). - * \param[in] is_training Whether this is in training mode or inference. - * \param[in] attn_scale Scaling factor for Q * K.T. - * \param[in] dropout Dropout probability. - * \param[in] qkv_layout QKV tensor's layout. - * \param[in] bias_type Bias type. - * \param[in] attn_mask_type Attention mask type. - * \param[in] workspace Workspace tensor. - * \param[in] stream CUDA stream used for this operation. + * \param[in] Q The Q tensor, [total_seqs_q, num_heads, head_dim]. + * \param[in] KV The KV tensor, [total_seqs_kv, 2, num_heads, head_dim]. + * \param[in] Bias The Bias tensor. + * \param[in,out] S The S tensor. + * \param[out] O The output O tensor. + * \param[out] Aux_CTX_Tensors Auxiliary output tensors when training, + * e.g. M, ZInv, rng_state. + * \param[in] cu_seqlens_q Accumulative sequence lengths for Q, [batch_size + 1]. + * \param[in] cu_seqlens_kv Accumulative sequence lengths for KV, [batch_size + 1]. + * \param[in] rng_state Seed and offset of CUDA random number generator. + * \param[in] max_seqlen_q Max sequence length used for computing for Q. + * it may be >= max(cu_seqlens_q). + * \param[in] max_seqlen_kv Max sequence length used for computing for KV. + * it may be >= max(cu_seqlens_kv). + * \param[in] is_training Whether this is in training mode or inference. + * \param[in] attn_scale Scaling factor for Q * K.T. + * \param[in] dropout Dropout probability. + * \param[in] qkv_layout QKV tensor's layout. + * \param[in] bias_type Bias type. + * \param[in] attn_mask_type Attention mask type. + * \param[in] workspace Workspace tensor. + * \param[in] stream CUDA stream used for this operation. */ void nvte_fused_attn_fwd_kvpacked( const NVTETensor Q, @@ -231,7 +268,7 @@ void nvte_fused_attn_fwd_kvpacked( const NVTETensor Bias, NVTETensor S, NVTETensor O, - NVTETensorPack* Aux_Output_Tensors, + NVTETensorPack* Aux_CTX_Tensors, const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, const NVTETensor rng_state, @@ -246,33 +283,34 @@ void nvte_fused_attn_fwd_kvpacked( * * Support Matrix: \verbatim - | precision | qkv layout | bias | mask | dropout | sequence length | head_dim | - | FP16/BF16 | QKV_INTERLEAVED | NO_BIAS/POST_SCALE_BIAS | PADDING/CAUSAL | Yes | <= 512 | 64 | + | backend | precision | qkv layout | bias | mask | dropout | sequence length | head_dim | + | 0 | FP16/BF16 | QKV_INTERLEAVED | NO_BIAS/POST_SCALE_BIAS | PADDING/CAUSAL | Yes | <= 512 | 64 | \endverbatim * - * \param[in] Q The Q tensor, [total_seqs_q, num_heads, head_dim]. - * \param[in] KV The KV tensor, [total_seqs_kv, 2, num_heads, head_dim]. - * \param[in] O The O tensor from forward. - * \param[in] dO The gradient of the O tensor. - * \param[in] S The S tensor. - * \param[in,out] dP The gradient of the P tensor. 
- * \param[in] Aux_CTX_Tensors Auxiliary tensors from forward when in training mode. - * \param[out] dQ The gradient of the Q tensor. - * \param[out] dKV The gradient of the KV tensor. - * \param[out] dBias The gradient of the Bias tensor. - * \param[in] cu_seqlens_q Accumulative sequence lengths for Q, [batch_size + 1]. - * \param[in] cu_seqlens_kv Accumulative sequence lengths for KV, [batch_size + 1]. - * \param[in] max_seqlen_q Max sequence length used for computing - * for Q. It may be >= max(cu_seqlens_q). - * \param[in] max_seqlen_kv Max sequence length used for computing - * for KV. It may be >= max(cu_seqlens_kv). - * \param[in] attn_scale Scaling factor for Q * K.T. - * \param[in] dropout Dropout probability. - * \param[in] qkv_layout QKV tensor's layout. - * \param[in] bias_type Bias type. - * \param[in] attn_mask_type Attention mask type. - * \param[in] workspace Workspace tensor. - * \param[in] stream CUDA stream used for this operation. + * \param[in] Q The Q tensor, [total_seqs_q, num_heads, head_dim]. + * \param[in] KV The KV tensor, [total_seqs_kv, 2, num_heads, head_dim]. + * \param[in] O The O tensor from forward. + * \param[in] dO The gradient of the O tensor. + * \param[in] S The S tensor. + * \param[in,out] dP The gradient of the P tensor. + * \param[in] Aux_CTX_Tensors Auxiliary tensors from context when in training mode, + * e.g. M, ZInv, rng_state. + * \param[out] dQ The gradient of the Q tensor. + * \param[out] dKV The gradient of the KV tensor. + * \param[out] dBias The gradient of the Bias tensor. + * \param[in] cu_seqlens_q Accumulative sequence lengths for Q, [batch_size + 1]. + * \param[in] cu_seqlens_kv Accumulative sequence lengths for KV, [batch_size + 1]. + * \param[in] max_seqlen_q Max sequence length used for computing for Q. + * it may be >= max(cu_seqlens_q). + * \param[in] max_seqlen_kv Max sequence length used for computing for KV. + * it may be >= max(cu_seqlens_kv). + * \param[in] attn_scale Scaling factor for Q * K.T. + * \param[in] dropout Dropout probability. + * \param[in] qkv_layout QKV tensor's layout. + * \param[in] bias_type Bias type. + * \param[in] attn_mask_type Attention mask type. + * \param[in] workspace Workspace tensor. + * \param[in] stream CUDA stream used for this operation. 
*/ void nvte_fused_attn_bwd_kvpacked( const NVTETensor Q, diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index f81b37cbc7..492ebe5cb6 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -15,6 +15,16 @@ from flash_attn.flash_attn_interface import flash_attn_unpadded_func import transformer_engine_extensions as tex +from transformer_engine.pytorch.cpp_extensions.fused_attn import ( + fused_attn_fwd_qkvpacked, + fused_attn_bwd_qkvpacked, + fused_attn_fwd_kvpacked, + fused_attn_bwd_kvpacked, + QKVLayout, + AttnBiasType, + AttnMaskType, + FusedAttnBackend, +) from transformer_engine.pytorch.module import LayerNormLinear, Linear from transformer_engine.pytorch.utils import ( divide, @@ -26,6 +36,7 @@ AttnMaskTypes, AttnTypes, dist_group_type, + TE_DType, ) from transformer_engine.pytorch.softmax import FusedScaleMaskSoftmax from transformer_engine.pytorch.distributed import ( @@ -267,9 +278,9 @@ def backward(ctx, return dq, dk, dv -def _check_if_interleaved(q, k, v): - data_ptr = q.storage().data_ptr() - check_ptrs = all(x.storage().data_ptr() == data_ptr for x in [q, k, v]) +def _check_if_interleaved_qkv(q, k, v): + data_ptr = q.untyped_storage().data_ptr() + check_ptrs = all(x.untyped_storage().data_ptr() == data_ptr for x in [q, k, v]) if not check_ptrs: return False @@ -288,9 +299,32 @@ def _check_if_interleaved(q, k, v): for i, x in enumerate([q, k, v])) return check_offsets +def _check_if_interleaved_kv(k, v): + data_ptr = k.untyped_storage().data_ptr() + check_ptrs = all(x.untyped_storage().data_ptr() == data_ptr for x in [k, v]) + if not check_ptrs: + return False + + stride = k.stride() + check_strides = all(stride == x.stride() for x in [k, v]) + if not check_strides: + return False + + shape = k.shape + check_shapes = all(shape == x.shape for x in [k, v]) + if not check_shapes: + return False + + last_dim_size = shape[-1] + check_offsets = all(i * last_dim_size == x.storage_offset() + for i, x in enumerate([k, v])) + return check_offsets + + class FlashAttention(torch.nn.Module): - """Dot product attention implementation by using the flash-attn package. + """Dot product attention, using HazyResearch flash-attn package: + https://github.com/HazyResearch/flash-attention """ def __init__( @@ -321,9 +355,9 @@ def forward( """flash-attn fprop""" assert ( - (query_layer.dtype in [torch.float16, torch.bfloat16]) - and (key_layer.dtype in [torch.float16, torch.bfloat16]) - and (value_layer.dtype in [torch.float16, torch.bfloat16]) + query_layer.dtype in [torch.float16, torch.bfloat16] + and key_layer.dtype in [torch.float16, torch.bfloat16] + and value_layer.dtype in [torch.float16, torch.bfloat16] ), 'FlashAttention currently only supports FP16 and BF16.' 
assert ( query_layer.is_cuda and key_layer.is_cuda and value_layer.is_cuda @@ -333,7 +367,7 @@ def forward( if (query_layer.shape[-1] == 128 and query_layer.shape[0] * query_layer.shape[1] >= 512 and - _check_if_interleaved(query_layer, key_layer, value_layer)): + _check_if_interleaved_qkv(query_layer, key_layer, value_layer)): query_layer, key_layer, value_layer = _PrepareQKVForFA.apply(query_layer, key_layer, value_layer) @@ -369,6 +403,286 @@ def forward( return output.view(batch_size, seqlen, -1).transpose(0, 1).contiguous() +class FusedAttnFunc_qkvpacked(torch.autograd.Function): + """Function for FusedAttention with packed QKV input""" + + @staticmethod + def forward(ctx, is_training, max_seqlen, cu_seqlens, qkv, qkv_dtype, attn_bias, attn_scale, + dropout_p, fast_zero_fill, qkv_layout, attn_bias_type, attn_mask_type, + rng_gen, fused_attention_backend): + out, aux_ctx_tensors = fused_attn_fwd_qkvpacked( + is_training, max_seqlen, cu_seqlens, qkv, qkv_dtype, + fused_attention_backend, attn_bias, + None, None, None, None, None, + attn_scale, dropout_p, fast_zero_fill, qkv_layout, attn_bias_type, attn_mask_type, + rng_gen) + + ctx.save_for_backward(qkv, out, cu_seqlens) + ctx.aux_ctx_tensors = aux_ctx_tensors + ctx.max_seqlen = max_seqlen + ctx.qkv_dtype = qkv_dtype + ctx.attn_scale = attn_scale + ctx.dropout_p = dropout_p + ctx.fast_zero_fill = fast_zero_fill + ctx.qkv_layout = qkv_layout + ctx.attn_bias_type = attn_bias_type + ctx.attn_mask_type = attn_mask_type + ctx.fused_attention_backend = fused_attention_backend + + return out + + @staticmethod + def backward(ctx, d_out): + qkv, out, cu_seqlens = ctx.saved_tensors + dqkv, *rest = fused_attn_bwd_qkvpacked( + ctx.max_seqlen, cu_seqlens, qkv, out, d_out, + ctx.qkv_dtype, ctx.aux_ctx_tensors, + ctx.fused_attention_backend, + None, None, None, None, None, None, None, None, None, + ctx.attn_scale, ctx.dropout_p, ctx.fast_zero_fill, + ctx.qkv_layout, ctx.attn_bias_type, ctx.attn_mask_type) + + # if no_bias, return dqkv + if ctx.attn_bias_type == "no_bias": + return (None, None, None, dqkv, None, None, None, + None, None, None, None, None, None, + None, None, None, None, None, None) + # else, return (dqkv, dbias) + return (None, None, None, dqkv, None, rest[0], None, + None, None, None, None, None, None, + None, None, None, None, None, None) + +class FusedAttnFunc_kvpacked(torch.autograd.Function): + """Function for FusedAttention with packed KV input""" + + @staticmethod + def forward(ctx, is_training, max_seqlen_q, max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv, + q, kv, qkv_dtype, attn_bias, attn_scale, dropout_p, fast_zero_fill, + qkv_layout, attn_bias_type, attn_mask_type, + rng_gen, fused_attention_backend): + out, aux_ctx_tensors = fused_attn_fwd_kvpacked( + is_training, max_seqlen_q, max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv, + q, kv, qkv_dtype, fused_attention_backend, attn_bias, + None, None, None, None, None, + attn_scale, dropout_p, fast_zero_fill, qkv_layout, attn_bias_type, attn_mask_type, + rng_gen) + + ctx.save_for_backward(q, kv, out, cu_seqlens_q, cu_seqlens_kv) + ctx.aux_ctx_tensors = aux_ctx_tensors + ctx.max_seqlen_q = max_seqlen_q + ctx.max_seqlen_kv = max_seqlen_kv + ctx.qkv_dtype = qkv_dtype + ctx.attn_scale = attn_scale + ctx.dropout_p = dropout_p + ctx.fast_zero_fill = fast_zero_fill + ctx.qkv_layout = qkv_layout + ctx.attn_bias_type = attn_bias_type + ctx.attn_mask_type = attn_mask_type + ctx.fused_attention_backend = fused_attention_backend + + return out + + @staticmethod + def backward(ctx, d_out): + q, kv, 
out, cu_seqlens_q, cu_seqlens_kv = ctx.saved_tensors + dq, dkv, *rest = fused_attn_bwd_kvpacked( + ctx.max_seqlen_q, ctx.max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv, + q, kv, out, d_out, + ctx.qkv_dtype, ctx.aux_ctx_tensors, + ctx.fused_attention_backend, + None, None, None, None, None, None, None, None, None, + ctx.attn_scale, ctx.dropout_p, ctx.fast_zero_fill, + ctx.qkv_layout, ctx.attn_bias_type, ctx.attn_mask_type) + + # if no_bias, return dqkv + if ctx.attn_bias_type == "no_bias": + return (None, None, None, None, None, dq, dkv, None, None, None, + None, None, None, None, None, None, + None, None, None, None, None, None) + # else, return (dqkv, dbias) + return (None, None, None, None, None, dq, dkv, None, rest[0], None, + None, None, None, None, None, None, + None, None, None, None, None, None) + +class FusedAttention(torch.nn.Module): + """Dot product attention, with multiple backends: + + 1. FusedAttnBackend["F16_max512_seqlen"] + cuDNN based fused attention for FP16/BF16 and <=512 sequence length. + 2. FusedAttnBackend["F16_arbitrary_seqlen"] + cuDNN based fused attention for FP16/BF16 and any sequence length. + + Support matrix: + + | backend | 1 | 2 | + | flash based | no | yes | + | cuDNN based | yes | yes | + | qkv dtype | fp16/bf16 | fp16/bf16 | + | attn_type | self/cross | self | + | qkv_layout | | | + | - qkv | qkv_interleaved | qkv_interleaved | + | - (q,kv) | kv_interleaved | | + | mask_type | causal/no_mask | causal | + | bias_type | no_bias/post_scale_bias | no_bias | + | dropout | yes | yes | + | max_seqlen | <=512 | any | + | head_dim | 64 | 64,128 | + | output dtype | fp16/bf16 | fp16/bf16 | + """ + + def __init__( + self, + norm_factor: float, + attention_dropout: float = 0.0, + attention_dropout_ctx: Optional[Callable] = nullcontext, + attn_mask_type: str = "causal", + attention_type: str = "self", + ) -> None: + super().__init__() + + self.norm_factor = norm_factor + self.attention_dropout = attention_dropout + self.attention_dropout_ctx = attention_dropout_ctx + self.attn_mask_type = attn_mask_type + self.attention_type = attention_type + + def forward( + self, + query_layer: torch.Tensor, + key_layer: torch.Tensor, + value_layer: torch.Tensor, + fused_attention_backend: tex.NVTE_Fused_Attn_Backend, + core_attention_bias_type: str = "no_bias", + core_attention_bias: Optional[torch.Tensor] = None, + fast_zero_fill: bool = True, + ) -> torch.Tensor: + """fused attention fprop""" + + assert ( + (query_layer.dtype in [torch.float16, torch.bfloat16]) + and (key_layer.dtype in [torch.float16, torch.bfloat16]) + and (value_layer.dtype in [torch.float16, torch.bfloat16]) + ), 'FusedAttention only supports FP16 and BF16 data types.' + assert ( + query_layer.is_cuda and key_layer.is_cuda and value_layer.is_cuda + ), 'FusedAttention only supports CUDA tensors.' 
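+
+        # Pack the separate q/k/v (or k/v for cross attention) tensors into the
+        # contiguous layouts expected by the fused kernels:
+        #   self attention:  QKV packed as [b * s_q, 3, h, d]   ("qkv_interleaved")
+        #   cross attention: Q kept separate, KV packed as [b * s_kv, 2, h, d]  ("kv_interleaved")
+        # cu_seqlens are generated as [0, s, 2*s, ...], i.e. every sequence is assumed
+        # to span the full (unpadded) length.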
+ + qkv_dtype = TE_DType[query_layer.dtype] + seqlen_q, batch_size = query_layer.shape[0], query_layer.shape[1] + seqlen_kv = key_layer.shape[0] + max_seqlen_q = seqlen_q + max_seqlen_kv = seqlen_kv + + if self.attention_type == "self": + if _check_if_interleaved_qkv(query_layer, key_layer, value_layer): + query_layer = query_layer.unsqueeze(3) + key_layer = key_layer.unsqueeze(3) + value_layer = value_layer.unsqueeze(3) + # [s, b, h, 3, d] + mixed_layer = torch.cat([query_layer, key_layer, value_layer], dim = 3) + # [b, s, 3, h, d] + mixed_layer = mixed_layer.transpose(2, 3).transpose(0, 1).contiguous() + else: + query_layer = query_layer.unsqueeze(2) + key_layer = key_layer.unsqueeze(2) + value_layer = value_layer.unsqueeze(2) + # [s, b, 3, h, d] + mixed_layer = torch.cat([query_layer, key_layer, value_layer], dim = 2) + # [b, s, 3, h, d] + mixed_layer = mixed_layer.transpose(0, 1).contiguous() + + # [total_seqs, 3, h, d] + mixed_layer = mixed_layer.view( + mixed_layer.shape[0] * mixed_layer.shape[1], *mixed_layer.shape[2:]).contiguous() + + qkv_layout = "qkv_interleaved" + max_seqlen = seqlen_q + cu_seqlens = torch.arange( + 0, + (batch_size + 1) * seqlen_q, + step=seqlen_q, + dtype=torch.int32, + device=query_layer.device) + + with self.attention_dropout_ctx(): + output = FusedAttnFunc_qkvpacked.apply( + self.training, + max_seqlen, + cu_seqlens, + mixed_layer, + qkv_dtype, + core_attention_bias, + 1.0/self.norm_factor, + self.attention_dropout if self.training else 0.0, + fast_zero_fill, + qkv_layout, + core_attention_bias_type, + self.attn_mask_type, + None, # rng_gen + fused_attention_backend, + ) + output = output.view(batch_size, seqlen_q, -1).transpose(0, 1).contiguous() + + if self.attention_type == "cross": + if _check_if_interleaved_kv(key_layer, value_layer): + # [s, b, h, 2, d] + key_layer = key_layer.unsqueeze(3) + value_layer = value_layer.unsqueeze(3) + key_value = torch.cat([key_layer, value_layer], dim = 3) + # [b, s, 2, h, d] + key_value = key_value.transpose(2, 3).transpose(0, 1).contiguous() + else: + # [s, b, 2, h, d] + key_layer = key_layer.unsqueeze(2) + value_layer = value_layer.unsqueeze(2) + key_value = torch.cat([key_layer, value_layer], dim = 2) + # [b, s, 2, h, d] + key_value = key_value.transpose(0, 1).contiguous() + + # [total_seqs, 2, h, d] + query_layer = query_layer.transpose(0, 1).contiguous() + query_layer = query_layer.view( + query_layer.shape[0] * query_layer.shape[1], *query_layer.shape[2:]) + key_value = key_value.view([key_value.shape[0] * key_value.shape[1]] + + key_value.shape[2:]).contiguous() + + qkv_layout = "kv_interleaved" + cu_seqlens_q = torch.arange( + 0, + (batch_size + 1) * seqlen_q, + step=seqlen_q, + dtype=torch.int32, + device=query_layer.device) + cu_seqlens_kv = torch.arange( + 0, + (batch_size + 1) * seqlen_kv, + step=seqlen_kv, + dtype=torch.int32, + device=key_layer.device) + + with self.attention_dropout_ctx(): + outputs = FusedAttnFunc_kvpacked.apply( + self.training, + max_seqlen_q, max_seqlen_kv, + cu_seqlens_q, cu_seqlens_kv, + query_layer, key_value, + qkv_dtype, + core_attention_bias, + 1.0/self.norm_factor, + self.attention_dropout if self.training else 0.0, + fast_zero_fill, + qkv_layout, + core_attention_bias_type, + self.attn_mask_type, + None, # rng_gen + fused_attention_backend, + ) + + output = (outputs[0].view(batch_size, seqlen_q, -1).transpose(0, 1).contiguous(), + outputs[1].view(batch_size, seqlen_q, -1).transpose(0, 1).contiguous()) + return output + + class DotProductAttention(torch.nn.Module): """Allows 
the model to jointly attend to information from different representation subspaces as described in the paper: @@ -422,15 +736,16 @@ def __init__( get_rng_state_tracker: Optional[Callable] = None, tp_group: Optional[dist_group_type] = None, layer_number: Optional[int] = None, + attention_type: str = "self", ) -> None: super().__init__() - tp_size = tp_size if tp_group is None else get_distributed_world_size(tp_group) + self.tp_size = tp_size if tp_group is None else get_distributed_world_size(tp_group) self.tp_group = tp_group self.get_rng_state_tracker = get_rng_state_tracker projection_size = kv_channels * num_attention_heads - self.hidden_size_per_partition = divide(projection_size, tp_size) + self.hidden_size_per_partition = divide(projection_size, self.tp_size) self.hidden_size_per_attention_head = divide( projection_size, num_attention_heads ) @@ -447,18 +762,28 @@ def __init__( int(os.getenv("NVTE_FLASH_ATTN", "1")) and self.device_compute_capability >= 8.0 ) + self.use_fused_attention = ( + int(os.getenv("NVTE_FUSED_ATTN", "1")) + and self.device_compute_capability >= 8.0 + ) attn_kwargs = { "attention_dropout": attention_dropout, "attention_dropout_ctx": attention_dropout_ctx, "attn_mask_type": attn_mask_type, } + self.attention_type = attention_type self.attn_mask_type = attn_mask_type + self.attention_dropout = attention_dropout if self.use_flash_attention: self.flash_attention = FlashAttention(norm_factor, **attn_kwargs) - # Instantiating both types since use of flash-attn + # Instantiating three types since use of flash-attn and FusedAttention # might be ruled out due to forward inputs. + if self.use_fused_attention: + self.fused_attention = FusedAttention( + norm_factor, **attn_kwargs, + attention_type = attention_type) self.unfused_attention = UnfusedDotProductAttention( norm_factor, **attn_kwargs, layer_number=layer_number) @@ -489,6 +814,9 @@ def forward( value_layer: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, checkpoint_core_attention: bool = False, + core_attention_bias_type: str = "no_bias", + core_attention_bias: Optional[torch.Tensor] = None, + fast_zero_fill: bool = True, ) -> torch.Tensor: """ Dot Product Attention Layer. @@ -506,6 +834,17 @@ def forward( (:attr:`sequence_length`, :attr:`batch_size`, :attr:`num_attention_heads` * :attr:`kv_channels`) is returned. + .. note:: + + `DotProductAttention` supports three backends: 1) `FlashAttention` which calls + HazyResearch's FlashAttention PyTorch API, 2) `FusedAttention` which has multiple + fused attention implementations as its backends (see `FusedAttention` for + more details), and 3) `UnfusedDotProductAttention` which is the native PyTorch + implementation with fused scaled masked softmax. Users can use environment variables + `NVTE_FLASH_ATTN`, `NVTE_FUSED_ATTN`, and `NVTE_FUSED_ATTN_BACKEND` to control + which DotProductAttention backend, and FusedAttention backend if applicable, to use. + The default DotProductAttention backend is 1. + Parameters ---------- query_layer : torch.Tensor @@ -521,9 +860,17 @@ def forward( during the backward pass in order to save memory that would otherwise be occupied to store the forward activations until backprop. + core_attention_bias_type: str, default = `no_bias` + Bias type, {`no_bias`, `pre_scale_bias`, 'post_scale_bias`} + core_attention_bias: Optional[torch.Tensor], default = `None` + Bias tensor for Q * K.T + fast_zero_fill: bool, defautl = `True` + Whether to use the fast path to set output tensors to 0 or not. 
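+
+        .. note::
+
+            Backend selection happens when the module is constructed, so the
+            `NVTE_FLASH_ATTN` and `NVTE_FUSED_ATTN` environment variables must be set
+            beforehand. An illustrative sketch (the values shown are one possible choice,
+            not the defaults):
+
+            .. code-block:: python
+
+                import os
+                # disable flash-attn and allow the cuDNN fused attention backends
+                os.environ["NVTE_FLASH_ATTN"] = "0"
+                os.environ["NVTE_FUSED_ATTN"] = "1"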
""" use_flash_attention = self.use_flash_attention + use_fused_attention = self.use_fused_attention + if (query_layer.dtype not in [torch.bfloat16, torch.float16] or key_layer.dtype not in [torch.bfloat16, torch.float16] or value_layer.dtype not in [torch.bfloat16, torch.float16] @@ -533,9 +880,26 @@ def forward( if self.attn_mask_type == "padding" and attention_mask is not None: use_flash_attention = False + use_fused_attention = False if is_in_onnx_export_mode(): use_flash_attention = False + use_fused_attention = False + + qkv_layout = "qkv_interleaved" if self.attention_type == "self" else "kv_interleaved" + fused_attention_backend = tex.get_fused_attn_backend( + TE_DType[query_layer.dtype], + TE_DType[key_layer.dtype], + QKVLayout[qkv_layout], + AttnBiasType[core_attention_bias_type], + AttnMaskType[self.attn_mask_type], + self.attention_dropout, + query_layer.shape[0], key_layer.shape[0], + query_layer.shape[-1]) + # DPA does not support FP8; for FP8, use cpp_extensions modules directly + is_backend_avail = (fused_attention_backend in + [FusedAttnBackend["F16_max512_seqlen"], FusedAttnBackend["F16_arbitrary_seqlen"]]) + use_fused_attention = use_fused_attention and is_backend_avail if use_flash_attention: if checkpoint_core_attention: @@ -545,6 +909,22 @@ def forward( value_layer) return self.flash_attention(query_layer, key_layer, value_layer) + if use_fused_attention: + if checkpoint_core_attention: + return self._checkpointed_attention_forward(self.fused_attention, + query_layer, + key_layer, + value_layer, + fused_attention_backend, + core_attention_bias_type, + core_attention_bias, + fast_zero_fill) + return self.fused_attention(query_layer, key_layer, value_layer, + fused_attention_backend, + core_attention_bias_type, + core_attention_bias, + fast_zero_fill) + if checkpoint_core_attention: return self._checkpointed_attention_forward( self.unfused_attention, @@ -747,6 +1127,9 @@ def forward( checkpoint_core_attention: bool = False, inference_params: Optional[Any] = None, rotary_pos_emb: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None, + core_attention_bias_type: str = "no_bias", + core_attention_bias: Optional[torch.Tensor] = None, + fast_zero_fill: bool = True, ) -> Tuple[Union[torch.Tensor, None], ...]: """MultiHeadAttention FWD""" # hidden_states: [sq, b, h] @@ -947,7 +1330,10 @@ def forward( key_layer, value_layer, attention_mask, - checkpoint_core_attention=checkpoint_core_attention, + checkpoint_core_attention = checkpoint_core_attention, + core_attention_bias_type = core_attention_bias_type, + core_attention_bias = core_attention_bias, + fast_zero_fill = fast_zero_fill, ) # ================= diff --git a/transformer_engine/pytorch/constants.py b/transformer_engine/pytorch/constants.py index b8495b58f3..8d109026fb 100644 --- a/transformer_engine/pytorch/constants.py +++ b/transformer_engine/pytorch/constants.py @@ -22,7 +22,7 @@ torch.bfloat16: tex.DType.kBFloat16, } -AttnMaskTypes = ("causal", "padding") +AttnMaskTypes = ("causal", "padding", "no_mask") AttnTypes = ("self", "cross") diff --git a/transformer_engine/pytorch/cpp_extensions/fused_attn.py b/transformer_engine/pytorch/cpp_extensions/fused_attn.py index 51eb6b6774..35a1fa72f3 100644 --- a/transformer_engine/pytorch/cpp_extensions/fused_attn.py +++ b/transformer_engine/pytorch/cpp_extensions/fused_attn.py @@ -7,6 +7,12 @@ from typing import Tuple, List, Union import torch import transformer_engine_extensions as tex +from transformer_engine_extensions import ( + NVTE_QKV_Layout, + 
NVTE_Bias_Type, + NVTE_Mask_Type, + NVTE_Fused_Attn_Backend +) __all__ = ['fused_attn_fwd_qkvpacked', @@ -24,6 +30,34 @@ tex.DType.kInt32: torch.int32, } +QKVLayout = { + "not_interleaved": NVTE_QKV_Layout.NVTE_NOT_INTERLEAVED, + "qkv_interleaved": NVTE_QKV_Layout.NVTE_QKV_INTERLEAVED, + "kv_interleaved": NVTE_QKV_Layout.NVTE_KV_INTERLEAVED, + } + +AttnBiasType = { + "no_bias": NVTE_Bias_Type.NVTE_NO_BIAS, + "pre_scale_bias": NVTE_Bias_Type.NVTE_PRE_SCALE_BIAS, + "post_scale_bias": NVTE_Bias_Type.NVTE_POST_SCALE_BIAS, + } + +AttnMaskType = { + "no_mask": NVTE_Mask_Type.NVTE_NO_MASK, + "padding": NVTE_Mask_Type.NVTE_PADDING_MASK, + "causal": NVTE_Mask_Type.NVTE_CAUSAL_MASK, + } + +FusedAttnBackend = { + "F16_max512_seqlen": NVTE_Fused_Attn_Backend.NVTE_F16_max512_seqlen, + "F16_arbitrary_seqlen": NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen, + "FP8": NVTE_Fused_Attn_Backend.NVTE_FP8, + "No_Backend": NVTE_Fused_Attn_Backend.NVTE_No_Backend, + } + +BACKEND_F16m512_FP8_THREADS_PER_CTA = 128 +BACKEND_F16arb_ELTS_PER_THREADS = 16 + def check_tensor(x: torch.Tensor): """Check tensor properties.""" @@ -109,7 +143,8 @@ def fused_attn_fwd_qkvpacked( cu_seqlens: torch.Tensor, qkv: torch.Tensor, qkv_dtype: tex.DType, - bias: torch.Tensor = None, + fused_attention_backend: tex.NVTE_Fused_Attn_Backend, + attn_bias: torch.Tensor = None, d_scale_qkv: torch.Tensor = None, q_scale_s: torch.Tensor = None, q_scale_o: torch.Tensor = None, @@ -117,9 +152,9 @@ def fused_attn_fwd_qkvpacked( amax_o: torch.Tensor = None, attn_scale: float = None, dropout: float = 0.0, - set_zero: bool = True, + fast_zero_fill: bool = True, qkv_layout: str = "qkv_interleaved", - bias_type: str = "no_bias", + attn_bias_type: str = "no_bias", attn_mask_type: str = "padding", rng_gen: torch.Generator = None, ) -> Tuple[Union[torch.Tensor, None], ...]: @@ -139,8 +174,10 @@ def fused_attn_fwd_qkvpacked( shape [total_seqs, 3, num_heads, head_dim], where total_seqs = cu_seqlens[-1] qkv_dtype: tex.DType data type of QKV; in tex.DType, not torch.dtype - bias: torch.Tensor, default = None - input tensor Bias when bias_type is "pre_scale_bias" or "post_scale_bias"; + fused_attention_backend: tex.NVTE_Fused_Attn_Backend + please see FusedAttention module for details on supported backends. 
+ attn_bias: torch.Tensor, default = None + input tensor Bias when attn_bias_type is "pre_scale_bias" or "post_scale_bias"; shape [1, num_heads, max_seqlen, max_seqlen], same data type as qkv d_scale_qkv: torch.Tensor, default = None input tensor for the dequantization of QKV in FP8 computations @@ -158,12 +195,12 @@ def fused_attn_fwd_qkvpacked( dropout: float, default = 0.0 dropout probability, 0.0 means no dropout, 1.0 means no output; dropout must be 0.0 if is_training is False - set_zero: bool, default = True - if True, initializes the output tensor O to zero using the mha_fill method; - if False, doesn't initialize O after its allocation + fast_zero_fill: bool, default = True + if True, initializes the output tensor O to zero using the fast filling method; + if False, uses PyTorch's .fill_() method qkv_layout: str, default = "qkv_interleaved" layout of QKV; {"qkv_interleaved", "kv_interleaved", "not_interleaved"} - bias_type: str, default = "no_bias" + attn_bias_type: str, default = "no_bias" type of the bias; {"no_bias", "pre_scale_bias", "post_scale_bias"} attn_mask_type: str, default = "padding" type of the attention mask; {"padding", "causal", "no_mask"} @@ -178,15 +215,26 @@ def fused_attn_fwd_qkvpacked( shape [total_seqs, num_heads, head_dim], where total_seqs = cu_seqlens[-1] aux_ctx_tensors: List[torch.Tensor] auxiliary output tensors used for the backward; - if is_training is True, aux_ctx_tensors = [M, ZInv, rng_state] - if is_training is False, aux_ctx_tensors = [rng_state] - M: torch.Tensor - max(Q*K.T) - shape [batch_size, num_heads, max_seqlen, 1], dtype float32 - ZInv: torch.Tensor - 1/sum(e^(x - max(x))), where x=Q*K.T - shape [batch_size, num_heads, max_seqlen, 1], dtype float32 - rng_state: torch.Tensor + if is_training is True, aux_ctx_tensors = [softmax-related tensors, rng_state] + if is_training is False, aux_ctx_tensors = None + + softmax-related tensors: + 1. if fused_attention_backend == FusedAttnBackend["F16_max512_seqlen"] + softmax: torch.Tensor + Softmax(Q*K.T) + shape [batch_size, num_heads, max_seqlen, max_seqlen], dtype float32 + 2. if fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"] + softmaxStats: torch.Tensor + log(sum(e^(x - max(x)))), where x=Q*K.T + shape [batch_size, num_heads, max_seqlen, 1], dtype float32 + 3. if fused_attention_backend == FusedAttnBackend["FP8"] + M: torch.Tensor + max(Q*K.T) + shape [batch_size, num_heads, max_seqlen, 1], dtype float32 + ZInv: torch.Tensor + 1/sum(e^(x - max(x))), where x=Q*K.T + shape [batch_size, num_heads, max_seqlen, 1], dtype float32 + rng_state: torch.Tensor, optional, if backend is not F16_max512_seqlen state of the random number generator; [seed, offset], dtype uint64 """ @@ -203,60 +251,58 @@ def fused_attn_fwd_qkvpacked( if attn_scale is None: attn_scale = 1.0 / math.sqrt(d) - if bias_type != "no_bias": - assert bias is not None, "bias tensor cannot be None when bias_type is not no_bias." - assert (bias.shape == [1, h, max_seqlen, max_seqlen] - ), "bias tensor must be in [1, h, max_seqlen, max_seqlen] shape." - assert (bias.dtype == qkv.dtype - ), "bias tensor must be in the same dtype as qkv." 
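The updated returns section above spells out which auxiliary tensors each backend hands back for the backward pass. As an illustration only (not part of the patch; plain strings stand in for the FusedAttnBackend keys defined earlier in this file), a caller could label them like this:

def describe_aux_ctx_tensors(aux_ctx_tensors, backend_name, is_training):
    """Label the auxiliary tensors returned by fused_attn_fwd_qkvpacked.

    backend_name is one of the FusedAttnBackend keys used in this module:
    "F16_max512_seqlen", "F16_arbitrary_seqlen" or "FP8".
    """
    if not is_training or not aux_ctx_tensors:
        return {}
    if backend_name == "F16_max512_seqlen":
        # this backend returns the full softmax and no rng_state tensor
        return {"softmax": aux_ctx_tensors[0]}
    if backend_name == "F16_arbitrary_seqlen":
        return {"softmax_stats": aux_ctx_tensors[0],
                "rng_state": aux_ctx_tensors[-1]}
    if backend_name == "FP8":
        m, z_inv, rng_state = aux_ctx_tensors
        return {"M": m, "ZInv": z_inv, "rng_state": rng_state}
    raise ValueError(f"unknown fused attention backend: {backend_name}")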
- - # FP8 fused attention API - if (qkv_type is torch.uint8) and (max_seqlen <= 512) and (d == 64): - assert (qkv_layout == "qkv_interleaved" - and bias_type == "no_bias" - and attn_mask_type == "padding" - ), """The FP8 fused attention API currently only supports qkv_interleaved layout, - no_bias type, and padding attention mask type.""" - assert (d_scale_qkv is not None), "d_scale_qkv is required for the FP8 API." - assert (q_scale_s is not None), "q_scale_s is required for the FP8 API." - assert (q_scale_o is not None), "q_scale_o is required for the FP8 API." - assert (amax_s is not None), "amax_s is required for the FP8 API." - assert (amax_o is not None), "amax_o is required for the FP8 API." + if attn_bias_type != "no_bias": + assert (attn_bias is not None + ), "attn_bias tensor cannot be None when attn_bias_type is not no_bias." + assert (attn_bias.shape == [1, h, max_seqlen, max_seqlen] + ), "attn_bias tensor must be in [1, h, max_seqlen, max_seqlen] shape." + assert (attn_bias.dtype == qkv.dtype + ), "attn_bias tensor must be in the same dtype as qkv." + + assert (fused_attention_backend != FusedAttnBackend["No_Backend"] + ), "Fused attention does not support this input combination." + + # BF16/FP16 fused attention API from fmha_v1 apex + if fused_attention_backend == FusedAttnBackend["F16_max512_seqlen"]: + rng_elts_per_thread = (max_seqlen * max_seqlen + + BACKEND_F16m512_FP8_THREADS_PER_CTA - 1)//BACKEND_F16m512_FP8_THREADS_PER_CTA + + # BF16/FP16 fused attention API from fmha_v2 + if fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]: + rng_elts_per_thread = BACKEND_F16arb_ELTS_PER_THREADS + + # FP8 fused attention API from fmha_v2 + if fused_attention_backend == FusedAttnBackend["FP8"]: + rng_elts_per_thread = (max_seqlen * max_seqlen + + BACKEND_F16m512_FP8_THREADS_PER_CTA - 1)//BACKEND_F16m512_FP8_THREADS_PER_CTA + + assert (d_scale_qkv is not None + ), "d_scale_qkv is required as an input for FP8 fused attention." + assert (q_scale_s is not None + ), "q_scale_s is required as an input for FP8 fused attention." + assert (q_scale_o is not None + ), "q_scale_o is required as an input for FP8 fused attention." + assert (amax_s is not None + ), "amax_s is required as an input for FP8 fused attention." + assert (amax_o is not None + ), "amax_o is required as an input for FP8 fused attention." check_scalar(d_scale_qkv) check_scalar(q_scale_s) check_scalar(q_scale_o) check_scalar(amax_s) check_scalar(amax_o) - # BF16/FP16 fused attention API from fmha_v2 - elif (qkv_type is torch.bfloat16 or qkv_type is torch.float16) and (max_seqlen > 512): - # add BF/FP16 support for >512 sequence length - assert False, "The BF16/FP16 support for >512 sequence length is coming!" - - # BF16/FP16 fused attention API from fmha_v1 apex - elif (qkv_type is torch.bfloat16 or qkv_type is torch.float16) and (max_seqlen <= 512): - # add BF/FP16 support for <=512 sequence length - assert False, "The BF16/FP16 support for <=512 sequence length is coming!" - - else: - assert False, "No support for this dtype and max_seqlen combination." 
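The hunk above replaces the old dtype/max_seqlen dispatch with an explicit backend check and computes rng_elts_per_thread, the number of Philox random values each CUDA thread will draw for dropout. The self-contained snippet below (illustration only; the two constants are redeclared from earlier in this file so it runs on its own) reproduces that arithmetic:

BACKEND_F16m512_FP8_THREADS_PER_CTA = 128
BACKEND_F16arb_ELTS_PER_THREADS = 16

def philox_elts_per_thread(max_seqlen, backend_name):
    if backend_name in ("F16_max512_seqlen", "FP8"):
        # ceiling division: roughly one dropout mask element per attention score
        num = max_seqlen * max_seqlen + BACKEND_F16m512_FP8_THREADS_PER_CTA - 1
        return num // BACKEND_F16m512_FP8_THREADS_PER_CTA
    # the arbitrary-seqlen backend uses a fixed per-thread budget
    return BACKEND_F16arb_ELTS_PER_THREADS

assert philox_elts_per_thread(512, "F16_max512_seqlen") == 2048   # 512*512/128
assert philox_elts_per_thread(384, "FP8") == 1152                 # 384*384/128
assert philox_elts_per_thread(2048, "F16_arbitrary_seqlen") == 16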
- # execute kernel output_tensors = tex.fused_attn_fwd_qkvpacked( b, max_seqlen, total_seqs, h, d, - is_training, attn_scale, dropout, set_zero, qkv_layout, bias_type, attn_mask_type, - cu_seqlens, - qkv, - qkv_dtype, - d_scale_qkv, - q_scale_s, - q_scale_o, - amax_s, - amax_o, - bias, - rng_gen, + is_training, attn_scale, dropout, fast_zero_fill, + QKVLayout[qkv_layout], AttnBiasType[attn_bias_type], AttnMaskType[attn_mask_type], + cu_seqlens, qkv, qkv_dtype, + d_scale_qkv, q_scale_s, q_scale_o, amax_s, amax_o, attn_bias, + rng_gen, rng_elts_per_thread, ) + # out, aux_ctx_tensors return output_tensors[0], output_tensors[1:] @@ -267,7 +313,8 @@ def fused_attn_bwd_qkvpacked( o: torch.Tensor, d_o: torch.Tensor, qkv_dtype: tex.DType, - aux_ctx_tensors: List[torch.Tensor] = None, + aux_ctx_tensors: List[torch.Tensor], + fused_attention_backend: tex.NVTE_Fused_Attn_Backend, d_scale_qkv: torch.Tensor = None, d_scale_s: torch.Tensor = None, d_scale_o: torch.Tensor = None, @@ -279,9 +326,9 @@ def fused_attn_bwd_qkvpacked( amax_dqkv: torch.Tensor = None, attn_scale: float = None, dropout: float = 0.0, - set_zero: bool = True, + fast_zero_fill: bool = True, qkv_layout: str = "qkv_interleaved", - bias_type: str = "no_bias", + attn_bias_type: str = "no_bias", attn_mask_type: str = "padding", ) -> Tuple[Union[torch.Tensor, None], ...]: """Fused Attention BWD for packed QKV input. @@ -306,6 +353,8 @@ def fused_attn_bwd_qkvpacked( aux_ctx_tensors: List[torch.Tensor] auxiliary output tensors of the forward pass when its is_training is True, e.g. aux_ctx_tensors = [M, ZInv, rng_state] + fused_attention_backend: tex.NVTE_Fused_Attn_Backend + please see FusedAttention module for details on supported backends. d_scale_qkv: torch.Tensor, default = None input tensor for the dequantization of QKV in FP8 computations d_scale_s: torch.Tensor, default = None @@ -330,12 +379,12 @@ def fused_attn_bwd_qkvpacked( dropout: float, default = 0.0 dropout probability, 0.0 means no dropout, 1.0 means no output; dropout must be 0.0 if is_training is False - set_zero: bool, default = True - if True, initializes the output tensor O to zero using the mha_fill method; - if False, doesn't initialize O after its allocation + fast_zero_fill: bool, default = True + if True, initializes the output tensor O to zero using the fast filling method; + if False, uses PyTorch's .fill_() method qkv_layout: str, default = "qkv_interleaved" layout of QKV; {"qkv_interleaved", "kv_interleaved", "not_interleaved"} - bias_type: str, default = "no_bias" + attn_bias_type: str, default = "no_bias" type of the bias; {"no_bias", "pre_scale_bias", "post_scale_bias"} attn_mask_type: str, default = "padding" type of the attention mask; {"padding", "causal", "no_mask"} @@ -345,8 +394,8 @@ def fused_attn_bwd_qkvpacked( d_qkv: torch.Tensor gradient tensor of QKV; same data type and shape as QKV d_bias: torch.Tensor, optional - gradient tensor of Bias when bias_type is "pre_scale_bias" or "post_scale_bias"; - same data type and shape as Bias + gradient tensor of Bias when attn_bias_type is "pre_scale_bias" + or "post_scale_bias"; same data type and shape as Bias """ check_cu_seqlens(cu_seqlens) @@ -363,29 +412,27 @@ def fused_attn_bwd_qkvpacked( if attn_scale is None: attn_scale = 1.0 / math.sqrt(d) - assert (len(aux_ctx_tensors) >= 1 - ), "aux_ctx_tensors must contain rng_state as its last element." 
- rng_state = aux_ctx_tensors[-1] - check_rng_state(rng_state) - - # FP8 fused attention API - if (qkv_type is torch.uint8) and (max_seqlen <= 512) and d == 64: - assert (qkv_layout == "qkv_interleaved" - and bias_type == "no_bias" - and attn_mask_type == "padding" - ), """The FP8 fused attention API currently only supports qkv_interleaved layout, - no_bias type, and padding attention mask type.""" - assert (d_scale_qkv is not None), "d_scale_qkv is required for the FP8 API." - assert (d_scale_s is not None), "d_scale_s is required for the FP8 API." - assert (d_scale_o is not None), "d_scale_o is required for the FP8 API." - assert (d_scale_do is not None), "d_scale_do is required for the FP8 API." - assert (q_scale_s is not None), "q_scale_s is required for the FP8 API." - assert (q_scale_dp is not None), "q_scale_dp is required for the FP8 API." - assert (q_scale_dqkv is not None), "q_scale_dqkv is required for the FP8 API." - assert (amax_dp is not None), "amax_dp is required for the FP8 API." - assert (amax_dqkv is not None), "amax_dqkv is required for the FP8 API." + assert (fused_attention_backend != FusedAttnBackend["No_Backend"] + ), "Fused attention does not support this input combination." + + if fused_attention_backend != FusedAttnBackend["F16_max512_seqlen"]: + assert (len(aux_ctx_tensors) >= 1 + ), "aux_ctx_tensors must contain rng_state as its last element." + rng_state = aux_ctx_tensors[-1] + check_rng_state(rng_state) + + if fused_attention_backend == FusedAttnBackend["FP8"]: + assert (d_scale_qkv is not None), "d_scale_qkv is required for FP8 fused attention." + assert (d_scale_s is not None), "d_scale_s is required for FP8 fused attention." + assert (d_scale_o is not None), "d_scale_o is required for FP8 fused attention." + assert (d_scale_do is not None), "d_scale_do is required for FP8 fused attention." + assert (q_scale_s is not None), "q_scale_s is required for FP8 fused attention." + assert (q_scale_dp is not None), "q_scale_dp is required for FP8 fused attention." + assert (q_scale_dqkv is not None), "q_scale_dqkv is required for FP8 fused attention." + assert (amax_dp is not None), "amax_dp is required for FP8 fused attention." + assert (amax_dqkv is not None), "amax_dqkv is required for FP8 fused attention." assert (len(aux_ctx_tensors) == 3 - ), "aux_ctx_tensors is required to be [M, ZInv, rng_state] for the FP8 API." + ), "aux_ctx_tensors is required to be [M, ZInv, rng_state] for FP8 fused attention." check_scalar(d_scale_qkv) check_scalar(d_scale_s) check_scalar(d_scale_o) @@ -399,37 +446,21 @@ def fused_attn_bwd_qkvpacked( check_stats(m, b, h, max_seqlen) check_stats(z_inv, b, h, max_seqlen) - # BF16/FP16 fused attention API from fmha_v2 - elif (qkv_type is torch.bfloat16 or qkv_type is torch.float16) and (max_seqlen > 512): - # add BF/FP16 support for >512 sequence length - assert False, "The BF16/FP16 support for >512 sequence length is coming!" - - # BF16/FP16 fused attention API from fmha_v1 apex - elif (qkv_type is torch.bfloat16 or qkv_type is torch.float16) and (max_seqlen <= 512): - # add BF/FP16 support for <=512 sequence length - assert False, "The BF16/FP16 support for <=512 sequence length is coming!" - - else: - assert False, "No support for this dtype and max_seqlen combination." 
- # execute kernel output_tensors = tex.fused_attn_bwd_qkvpacked( b, max_seqlen, total_seqs, h, d, - attn_scale, dropout, set_zero, qkv_layout, bias_type, attn_mask_type, - cu_seqlens, - qkv, o, d_o, - qkv_dtype, - aux_ctx_tensors, + attn_scale, dropout, fast_zero_fill, + QKVLayout[qkv_layout], AttnBiasType[attn_bias_type], AttnMaskType[attn_mask_type], + cu_seqlens, qkv, o, d_o, qkv_dtype, aux_ctx_tensors, d_scale_qkv, d_scale_s, d_scale_o, d_scale_do, - q_scale_s, q_scale_dp, q_scale_dqkv, - amax_dp, amax_dqkv, + q_scale_s, q_scale_dp, q_scale_dqkv, amax_dp, amax_dqkv, ) - if bias_type == "no_bias": - # return d_qkv when bias_type is no_bias - return output_tensors[0] + if attn_bias_type == "no_bias": + # return d_qkv when attn_bias_type is no_bias + return output_tensors # otherwise return (d_qkv, d_bias) - return output_tensors + return output_tensors[0], output_tensors[1] def fused_attn_fwd_kvpacked( @@ -441,7 +472,8 @@ def fused_attn_fwd_kvpacked( q: torch.Tensor, kv: torch.Tensor, qkv_dtype: tex.DType, - bias: torch.Tensor = None, + fused_attention_backend: tex.NVTE_Fused_Attn_Backend, + attn_bias: torch.Tensor = None, d_scale_qkv: torch.Tensor = None, q_scale_s: torch.Tensor = None, q_scale_o: torch.Tensor = None, @@ -449,9 +481,9 @@ def fused_attn_fwd_kvpacked( amax_o: torch.Tensor = None, attn_scale: float = None, dropout: float = 0.0, - set_zero: bool = True, + fast_zero_fill: bool = True, qkv_layout: str = "qkv_interleaved", - bias_type: str = "no_bias", + attn_bias_type: str = "no_bias", attn_mask_type: str = "padding", rng_gen: torch.Generator = None, ) -> Tuple[Union[torch.Tensor, None], ...]: @@ -479,8 +511,10 @@ def fused_attn_fwd_kvpacked( where total_seqs_kv = cu_seqlens_kv[-1] qkv_dtype: tex.DType data type of Q and KV; in tex.DType, not torch.dtype - bias: torch.Tensor, default = None - input tensor Bias when bias_type is "pre_scale_bias" or "post_scale_bias"; + fused_attention_backend: tex.NVTE_Fused_Attn_Backend + please see FusedAttention module for details on supported backends. 
+ attn_bias: torch.Tensor, default = None + input tensor Bias when attn_bias_type is "pre_scale_bias" or "post_scale_bias"; shape [1, num_heads, max_seqlen_q, max_seqlen_kv], same data type as q and kv d_scale_qkv: torch.Tensor, default = None input tensor for the dequantization of QKV in FP8 computations @@ -498,12 +532,12 @@ def fused_attn_fwd_kvpacked( dropout: float, default = 0.0 dropout probability, 0.0 means no dropout, 1.0 means no output; dropout must be 0.0 if is_training is False - set_zero: bool, default = True - if True, initializes the output tensor O to zero using the mha_fill method; - if False, doesn't initialize O after its allocation + fast_zero_fill: bool, default = True + if True, initializes the output tensor O to zero using the fast filling method; + if False, uses PyTorch's .fill_() method qkv_layout: str, default = "qkv_interleaved" layout of QKV; {"qkv_interleaved", "kv_interleaved", "not_interleaved"} - bias_type: str, default = "no_bias" + attn_bias_type: str, default = "no_bias" type of the bias; {"no_bias", "pre_scale_bias", "post_scale_bias"} attn_mask_type: str, default = "padding" type of the attention mask; {"padding", "causal", "no_mask"} @@ -518,15 +552,26 @@ def fused_attn_fwd_kvpacked( shape [total_seqs, num_heads, head_dim], where total_seqs = cu_seqlens[-1] aux_ctx_tensors: List[torch.Tensor] auxiliary output tensors used for the backward; - if is_training is True, aux_ctx_tensors = [M, ZInv, rng_state] - if is_training is False, aux_ctx_tensors = [rng_state] - M: torch.Tensor - max(Q*K.T) - shape [batch_size, num_heads, max_seqlen, 1], dtype float32 - ZInv: torch.Tensor - 1/sum(e^(x - max(x))), where x=Q*K.T - shape [batch_size, num_heads, max_seqlen, 1], dtype float32 - rng_state: torch.Tensor + if is_training is True, aux_ctx_tensors = [softmax-related tensors, rng_state] + if is_training is False, aux_ctx_tensors = None + + softmax-related tensors: + 1. if fused_attention_backend == FusedAttnBackend["F16_max512_seqlen"] + softmax: torch.Tensor + Softmax(Q*K.T) + shape [batch_size, num_heads, max_seqlen_q, max_seqlen_kv], dtype float32 + 2. if fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"] + softmaxStats: torch.Tensor + log(sum(e^(x - max(x)))), where x=Q*K.T + shape [batch_size, num_heads, max_seqlen_q, 1], dtype float32 + 3. if fused_attention_backend == FusedAttnBackend["FP8"] + M: torch.Tensor + max(Q*K.T) + shape [batch_size, num_heads, max_seqlen_q, 1], dtype float32 + ZInv: torch.Tensor + 1/sum(e^(x - max(x))), where x=Q*K.T + shape [batch_size, num_heads, max_seqlen_q, 1], dtype float32 + rng_state: torch.Tensor, optional, if backend is not F16_max512_seqlen state of the random number generator; [seed, offset], dtype uint64 """ @@ -551,49 +596,42 @@ def fused_attn_fwd_kvpacked( if attn_scale is None: attn_scale = 1.0 / math.sqrt(d) - if bias_type != "no_bias": - assert bias is not None, "bias tensor cannot be None when bias_type is not no_bias." - assert (bias.shape == [1, h, max_seqlen_q, max_seqlen_kv] - ), "bias tensor must be in [1, h, max_seqlen_q, max_seqlen_kv] shape." - assert (bias.dtype == q.dtype - ), "bias tensor must be in the same dtype as q and kv." + if attn_bias_type != "no_bias": + assert (attn_bias is not None + ), "attn_bias tensor cannot be None when attn_bias_type is not no_bias." + assert (attn_bias.shape == [1, h, max_seqlen_q, max_seqlen_kv] + ), "attn_bias tensor must be in [1, h, max_seqlen_q, max_seqlen_kv] shape." 
+ assert (attn_bias.dtype == q.dtype + ), "attn_bias tensor must be in the same dtype as q and kv." - # FP8 fused attention API - if (qkv_type is torch.uint8) and (max_seqlen_q <= 512) and (max_seqlen_kv <= 512) \ - and (d == 64): - assert False, "The FP8 fused attention API currently only supports packed QKV input." - - # BF16/FP16 fused attention API from fmha_v2 - elif (qkv_type is torch.bfloat16 or qkv_type is torch.float16) \ - and (max_seqlen_q > 512) and (max_seqlen_kv > 512): - # add BF/FP16 support for >512 sequence length - assert False, "The BF16/FP16 support for >512 sequence length is coming!" + assert (fused_attention_backend != FusedAttnBackend["No_Backend"] + ), "Fused attention does not support this input combination." # BF16/FP16 fused attention API from fmha_v1 apex - elif (qkv_type is torch.bfloat16 or qkv_type is torch.float16) \ - and (max_seqlen_q <= 512) and (max_seqlen_kv <= 512): - # add BF/FP16 support for <=512 sequence length - assert False, "The BF16/FP16 support for <=512 sequence length is coming!" + if fused_attention_backend == FusedAttnBackend["F16_max512_seqlen"]: + rng_elts_per_thread = (max_seqlen_q * max_seqlen_kv + + BACKEND_F16m512_FP8_THREADS_PER_CTA - 1)//BACKEND_F16m512_FP8_THREADS_PER_CTA - else: - assert False, "No support for this dtype and max_seqlen combination." + # BF16/FP16 fused attention API from fmha_v2 + if fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]: + rng_elts_per_thread = BACKEND_F16arb_ELTS_PER_THREADS + + # FP8 fused attention API from fmha_v2 + if fused_attention_backend == FusedAttnBackend["FP8"]: + rng_elts_per_thread = (max_seqlen_q * max_seqlen_q + + BACKEND_F16m512_FP8_THREADS_PER_CTA - 1)//BACKEND_F16m512_FP8_THREADS_PER_CTA # execute kernel output_tensors = tex.fused_attn_fwd_kvpacked( b, max_seqlen_q, max_seqlen_kv, total_seqs_q, total_seqs_kv, h, d, - is_training, attn_scale, dropout, set_zero, qkv_layout, bias_type, attn_mask_type, - cu_seqlens_q, cu_seqlens_kv, - q, kv, - qkv_dtype, - d_scale_qkv, - q_scale_s, - q_scale_o, - amax_s, - amax_o, - bias, - rng_gen, + is_training, attn_scale, dropout, fast_zero_fill, + QKVLayout[qkv_layout], AttnBiasType[attn_bias_type], AttnMaskType[attn_mask_type], + cu_seqlens_q, cu_seqlens_kv, q, kv, qkv_dtype, + d_scale_qkv, q_scale_s, q_scale_o, amax_s, amax_o, + attn_bias, rng_gen, rng_elts_per_thread, ) + # out, aux_ctx_tensors return output_tensors[0], output_tensors[1:] @@ -607,7 +645,8 @@ def fused_attn_bwd_kvpacked( o: torch.Tensor, d_o: torch.Tensor, qkv_dtype: tex.DType, - aux_ctx_tensors: List[torch.Tensor] = None, + aux_ctx_tensors: List[torch.Tensor], + fused_attention_backend: tex.NVTE_Fused_Attn_Backend, d_scale_qkv: torch.Tensor = None, d_scale_s: torch.Tensor = None, d_scale_o: torch.Tensor = None, @@ -619,9 +658,9 @@ def fused_attn_bwd_kvpacked( amax_dqkv: torch.Tensor = None, attn_scale: float = None, dropout: float = 0.0, - set_zero: bool = True, + fast_zero_fill: bool = True, qkv_layout: str = "qkv_interleaved", - bias_type: str = "no_bias", + attn_bias_type: str = "no_bias", attn_mask_type: str = "padding", ) -> Tuple[Union[torch.Tensor, None], ...]: """Fused Attention BWD for packed KV input. @@ -654,6 +693,8 @@ def fused_attn_bwd_kvpacked( aux_ctx_tensors: List[torch.Tensor] auxiliary output tensors of the forward pass when its is_training is True, e.g. aux_ctx_tensors = [M, ZInv, rng_state] + fused_attention_backend: tex.NVTE_Fused_Attn_Backend + please see FusedAttention module for details on supported backends. 
d_scale_qkv: torch.Tensor, default = None input tensor for the dequantization of QKV in FP8 computations d_scale_s: torch.Tensor, default = None @@ -679,12 +720,12 @@ def fused_attn_bwd_kvpacked( dropout: float, default = 0.0 dropout probability, 0.0 means no dropout, 1.0 means no output; dropout must be 0.0 if is_training is False - set_zero: bool, default = True - if True, initializes the output tensor O to zero using the mha_fill method; - if False, doesn't initialize O after its allocation + fast_zero_fill: bool, default = True + if True, initializes the output tensor O to zero using the fast filling method; + if False, uses PyTorch's .fill_() method qkv_layout: str, default = "qkv_interleaved" layout of QKV; {"qkv_interleaved", "kv_interleaved", "not_interleaved"} - bias_type: str, default = "no_bias" + attn_bias_type: str, default = "no_bias" type of the bias; {"no_bias", "pre_scale_bias", "post_scale_bias"} attn_mask_type: str, default = "padding" type of the attention mask; {"padding", "causal", "no_mask"} @@ -696,8 +737,8 @@ def fused_attn_bwd_kvpacked( d_kv: torch.Tensor gradient tensor of KV; same data type and shape as KV d_bias: torch.Tensor, optional - gradient tensor of Bias when bias_type is "pre_scale_bias" or "post_scale_bias"; - same data type and shape as Bias + gradient tensor of Bias when attn_bias_type is "pre_scale_bias" + or "post_scale_bias"; same data type and shape as Bias """ check_cu_seqlens(cu_seqlens_q) @@ -722,45 +763,52 @@ def fused_attn_bwd_kvpacked( if attn_scale is None: attn_scale = 1.0 / math.sqrt(d) - assert (len(aux_ctx_tensors) >= 1 - ), "aux_ctx_tensors must contain rng_state as its last element." - rng_state = aux_ctx_tensors[-1] - check_rng_state(rng_state) - - # FP8 fused attention API - if (qkv_type is torch.uint8) and (max_seqlen_q <= 512) and (max_seqlen_kv <= 512) \ - and d == 64: - assert False, "The FP8 fused attention API currently only supports packed QKV input." - - ############### BF16/FP16 fused attention API from fmha_v2 ################ - elif (qkv_type is torch.bfloat16 or qkv_type is torch.float16) \ - and (max_seqlen_q > 512) and (max_seqlen_kv > 512): - # add BF/FP16 support for >512 sequence length - assert False, "The BF16/FP16 support for >512 sequence length is coming!" - - ############### BF16/FP16 fused attention API from fmha_v1 apex ################ - elif (qkv_type is torch.bfloat16 or qkv_type is torch.float16) \ - and (max_seqlen_q <= 512) and (max_seqlen_kv <= 512): - # add BF/FP16 support for <=512 sequence length - assert False, "The BF16/FP16 support for <=512 sequence length is coming!" - - else: - assert False, "No support for this dtype and max_seqlen combination." + assert (fused_attention_backend != FusedAttnBackend["No_Backend"] + ), "Fused attention does not support this input combination." + + if fused_attention_backend != FusedAttnBackend["F16_max512_seqlen"]: + assert (len(aux_ctx_tensors) >= 1 + ), "aux_ctx_tensors must contain rng_state as its last element." + rng_state = aux_ctx_tensors[-1] + check_rng_state(rng_state) + + if fused_attention_backend == FusedAttnBackend["FP8"]: + assert (d_scale_qkv is not None), "d_scale_qkv is required for FP8 fused attention." + assert (d_scale_s is not None), "d_scale_s is required for FP8 fused attention." + assert (d_scale_o is not None), "d_scale_o is required for FP8 fused attention." + assert (d_scale_do is not None), "d_scale_do is required for FP8 fused attention." + assert (q_scale_s is not None), "q_scale_s is required for FP8 fused attention." 
+ assert (q_scale_dp is not None), "q_scale_dp is required for FP8 fused attention." + assert (q_scale_dqkv is not None), "q_scale_dqkv is required for FP8 fused attention." + assert (amax_dp is not None), "amax_dp is required for FP8 fused attention." + assert (amax_dqkv is not None), "amax_dqkv is required for FP8 fused attention." + assert (len(aux_ctx_tensors) == 3 + ), "aux_ctx_tensors is required to be [M, ZInv, rng_state] for FP8 fused attention." + check_scalar(d_scale_qkv) + check_scalar(d_scale_s) + check_scalar(d_scale_o) + check_scalar(d_scale_do) + check_scalar(q_scale_s) + check_scalar(q_scale_dp) + check_scalar(q_scale_dqkv) + check_scalar(amax_dp) + check_scalar(amax_dqkv) + m, z_inv = aux_ctx_tensors[:2] + check_stats(m, b, h, max_seqlen_q) + check_stats(z_inv, b, h, max_seqlen_q) # execute kernel output_tensors = tex.fused_attn_bwd_kvpacked( b, max_seqlen_q, max_seqlen_kv, total_seqs_q, total_seqs_kv, h, d, - attn_scale, dropout, set_zero, qkv_layout, bias_type, attn_mask_type, - cu_seqlens_q, cu_seqlens_kv, - q, kv, o, d_o, - qkv_dtype, - aux_ctx_tensors, + attn_scale, dropout, fast_zero_fill, + QKVLayout[qkv_layout], AttnBiasType[attn_bias_type], AttnMaskType[attn_mask_type], + cu_seqlens_q, cu_seqlens_kv, q, kv, o, d_o, qkv_dtype, aux_ctx_tensors, d_scale_qkv, d_scale_s, d_scale_o, d_scale_do, - q_scale_s, q_scale_dp, q_scale_dqkv, - amax_dp, amax_dqkv, + q_scale_s, q_scale_dp, q_scale_dqkv, amax_dp, amax_dqkv, ) - # returns (d_q, d_kv) when bias_type is no_bias; otherwise returns (d_q, d_kv, d_bias) - if bias_type == "no_bias": - return output_tensors[:2] - return output_tensors + if attn_bias_type == "no_bias": + # return (d_q, d_kv) when attn_bias_type is no_bias + return output_tensors + # otherwise return (d_q, d_kv), d_bias + return output_tensors[:2], output_tensors[2] diff --git a/transformer_engine/pytorch/csrc/common.h b/transformer_engine/pytorch/csrc/common.h index 4904e1ebad..17d36b9911 100644 --- a/transformer_engine/pytorch/csrc/common.h +++ b/transformer_engine/pytorch/csrc/common.h @@ -58,7 +58,10 @@ enum FP8FwdTensors { GEMM1_OUTPUT = 2, GEMM2_INPUT = 3, GEMM2_WEIGHT = 4, - GEMM2_OUTPUT = 5 + GEMM2_OUTPUT = 5, + GEMM3_INPUT = 6, + GEMM3_WEIGHT = 7, + GEMM3_OUTPUT = 8 }; // Used as named indices on the `scale`, `scale_inv`, @@ -67,7 +70,9 @@ enum FP8BwdTensors { GRAD_OUTPUT1 = 0, GRAD_INPUT1 = 1, GRAD_OUTPUT2 = 2, - GRAD_INPUT2 = 3 + GRAD_INPUT2 = 3, + GRAD_OUTPUT3 = 4, + GRAD_INPUT3 = 5 }; @@ -81,6 +86,9 @@ transformer_engine::DType getTransformerEngineFP8Type(bool e4m3_if_hybrid, inline at::ScalarType GetATenDType(transformer_engine::DType t) { switch (t) { case transformer_engine::DType::kInt32: + return torch::kInt32; + case transformer_engine::DType::kInt64: + return torch::kInt64; case transformer_engine::DType::kFloat32: return at::kFloat; case transformer_engine::DType::kFloat16: diff --git a/transformer_engine/pytorch/csrc/extensions.cu b/transformer_engine/pytorch/csrc/extensions.cu index 6d8ec6f2bb..69248d4aa9 100644 --- a/transformer_engine/pytorch/csrc/extensions.cu +++ b/transformer_engine/pytorch/csrc/extensions.cu @@ -12,43 +12,21 @@ constexpr int block_size = 512; constexpr int ctas_per_sm = 4; -// convert QKV layout to enum -NVTE_QKV_Layout get_nvte_qkv_layout(const std::string qkv_layout) { - if (qkv_layout == "not_interleaved") { - return NVTE_QKV_Layout::NVTE_NOT_INTERLEAVED; - } else if (qkv_layout == "qkv_interleaved") { - return NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED; - } else if (qkv_layout == "kv_interleaved") { - return 
NVTE_QKV_Layout::NVTE_KV_INTERLEAVED; - } else { - NVTE_ERROR("Invalid QKV layout. \n"); - } -} - -// convert bias type to enum -NVTE_Bias_Type get_nvte_bias_type(const std::string bias_type) { - if (bias_type == "no_bias") { - return NVTE_Bias_Type::NVTE_NO_BIAS; - } else if (bias_type == "pre_scale_bias") { - return NVTE_Bias_Type::NVTE_PRE_SCALE_BIAS; - } else if (bias_type == "post_scale_bias") { - return NVTE_Bias_Type::NVTE_POST_SCALE_BIAS; - } else { - NVTE_ERROR("Invalid bias type. \n"); - } -} - -// convert attn mask type to enum -NVTE_Mask_Type get_nvte_mask_type(const std::string mask_type) { - if (mask_type == "padding") { - return NVTE_Mask_Type::NVTE_PADDING_MASK; - } else if (mask_type == "causal") { - return NVTE_Mask_Type::NVTE_CAUSAL_MASK; - } else if (mask_type == "no_mask") { - return NVTE_Mask_Type::NVTE_NO_MASK; - } else { - NVTE_ERROR("Invalid attention mask type. \n"); - } +// get the fused attention backend +NVTE_Fused_Attn_Backend get_fused_attn_backend( + const transformer_engine::DType q_dtype, + const transformer_engine::DType kv_dtype, + NVTE_QKV_Layout qkv_layout, + NVTE_Bias_Type bias_type, + NVTE_Mask_Type attn_mask_type, + float p_dropout, size_t max_seqlen_q, + size_t max_seqlen_kv, size_t head_dim) { + NVTE_Fused_Attn_Backend fused_attention_backend = + nvte_get_fused_attn_backend( + static_cast(q_dtype), static_cast(kv_dtype), + qkv_layout, bias_type, attn_mask_type, + p_dropout, max_seqlen_q, max_seqlen_kv, head_dim); + return fused_attention_backend; } // fast zero-fills of tensors @@ -103,10 +81,8 @@ __global__ void unpack(at::PhiloxCudaState arg, int64_t* rng_state_ptr) { // extract PhiloxCudaState from CUDA random number generator at::PhiloxCudaState init_philox_state( at::CUDAGeneratorImpl* gen, - size_t max_seq_len, - size_t threads_per_cta) { + size_t elts_per_thread) { at::PhiloxCudaState philox_args; - size_t elts_per_thread = (max_seq_len * max_seq_len + threads_per_cta - 1)/threads_per_cta; std::lock_guard lock(gen->mutex_); philox_args = gen->philox_cuda_state(elts_per_thread); return philox_args; @@ -117,7 +93,7 @@ std::vector fused_attn_fwd_qkvpacked( size_t b, size_t max_seqlen, size_t total_seqs, size_t h, size_t d, bool is_training, float attn_scale, float p_dropout, bool set_zero, - std::string qkv_layout, std::string bias_type, std::string attn_mask_type, + NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, const at::Tensor cu_seqlens, const at::Tensor QKV, const transformer_engine::DType qkv_type, @@ -127,15 +103,18 @@ std::vector fused_attn_fwd_qkvpacked( c10::optional amax_S, c10::optional amax_O, const c10::optional Bias, - const c10::optional rng_gen) { + const c10::optional rng_gen, + size_t rng_elts_per_thread) { using namespace transformer_engine; // create output tensor O auto options = torch::TensorOptions().dtype(GetATenDType(qkv_type)).device(torch::kCUDA); auto O = torch::empty({static_cast(total_seqs), static_cast(h), static_cast(d)}, options); - if (set_zero) { + if (set_zero && (h * d % block_size == 0)) { mha_fill(O, cu_seqlens.index({torch::indexing::Slice(-1, torch::indexing::None)})); + } else { + O.fill_(0); } // construct NVTE tensors @@ -166,7 +145,7 @@ std::vector fused_attn_fwd_qkvpacked( } else { NVTE_ERROR("Fused attention only supports FP8 and BF16/FP16 data types. 
\n"); } - if ((bias_type != "no_bias") && (Bias.has_value())) { + if ((bias_type != NVTE_NO_BIAS) && (Bias.has_value())) { auto bias_shape = Bias.value().sizes().vec(); std::vector shape{bias_shape.begin(), bias_shape.end()}; te_Bias = makeTransformerEngineTensor(Bias.value().data_ptr(), shape, @@ -175,23 +154,16 @@ std::vector fused_attn_fwd_qkvpacked( te_cu_seqlens = makeTransformerEngineTensor(cu_seqlens.data_ptr(), {b+1}, DType::kInt32, nullptr, nullptr, nullptr); - // convert strings to enums - NVTE_QKV_Layout qkv_layout_enum = get_nvte_qkv_layout(qkv_layout); - NVTE_Bias_Type bias_type_enum = get_nvte_bias_type(bias_type); - NVTE_Mask_Type attn_mask_type_enum = get_nvte_mask_type(attn_mask_type); - // extract random number generator seed and offset auto gen = at::get_generator_or_default( rng_gen, at::cuda::detail::getDefaultCUDAGenerator()); - size_t threads_per_cta = 128; - at::PhiloxCudaState philox_args = init_philox_state(gen, max_seqlen, threads_per_cta); + at::PhiloxCudaState philox_args = init_philox_state(gen, rng_elts_per_thread); auto rng_state = torch::empty({2}, options.dtype(torch::kInt64)); unpack<<<1, 1, 0, at::cuda::getCurrentCUDAStream()>>>( philox_args, static_cast(rng_state.data_ptr())); auto te_rng_state = makeTransformerEngineTensor(rng_state); // create auxiliary output tensors - // if training, tensors are [M, ZInv] NVTETensorPack nvte_aux_tensor_pack; nvte_tensor_pack_create(&nvte_aux_tensor_pack); @@ -209,7 +181,7 @@ std::vector fused_attn_fwd_qkvpacked( te_rng_state.data(), max_seqlen, is_training, attn_scale, p_dropout, - qkv_layout_enum, bias_type_enum, attn_mask_type_enum, + qkv_layout, bias_type, attn_mask_type, workspace.data(), at::cuda::getCurrentCUDAStream()); @@ -219,10 +191,9 @@ std::vector fused_attn_fwd_qkvpacked( workspace_data.data_ptr(), workspace.shape(), workspace.dtype()); - // output_tensors = [O, nvte_aux_tensor_pack.tensors, rng_state] + // output_tensors = [O, nvte_aux_tensor_pack.tensors] std::vector output_tensors; output_tensors.push_back(O); - // nvte_aux_tensor_pack.size is 0 if inference for (size_t i = 0; i < nvte_aux_tensor_pack.size; ++i) { auto tensor = reinterpret_cast(nvte_aux_tensor_pack.tensors[i]); // allocate memory for nvte_aux_tensor_pack.tensors @@ -230,9 +201,6 @@ std::vector fused_attn_fwd_qkvpacked( output_tensors.push_back(output_tensor); tensor->data.dptr = output_tensor.data_ptr(); } - if (is_training) { - output_tensors.push_back(rng_state); - } // execute the kernel nvte_fused_attn_fwd_qkvpacked( @@ -245,14 +213,14 @@ std::vector fused_attn_fwd_qkvpacked( te_rng_state.data(), max_seqlen, is_training, attn_scale, p_dropout, - qkv_layout_enum, bias_type_enum, attn_mask_type_enum, + qkv_layout, bias_type, attn_mask_type, workspace.data(), at::cuda::getCurrentCUDAStream()); // destroy tensor wrappers, but not allocated memory nvte_tensor_pack_destroy(&nvte_aux_tensor_pack); - // if training, [O, M, ZInv, rng_state]; if inference, [O] + // if training, [O, softmax-related tensors, rng_state]; if inference, [O] return output_tensors; } @@ -261,7 +229,7 @@ std::vector fused_attn_bwd_qkvpacked( size_t b, size_t max_seqlen, size_t total_seqs, size_t h, size_t d, float attn_scale, float p_dropout, bool set_zero, - std::string qkv_layout, std::string bias_type, std::string attn_mask_type, + NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, const at::Tensor cu_seqlens, const at::Tensor QKV, const at::Tensor O, @@ -281,13 +249,18 @@ std::vector fused_attn_bwd_qkvpacked( // create output 
tensor dQKV at::Tensor dQKV = torch::empty_like(QKV); - if (set_zero) { + auto max_tokens = dQKV.size(0); + auto self_2d = dQKV.view({max_tokens, -1}); + auto fcd_size = self_2d.size(1); + if (set_zero && (fcd_size % block_size == 0)) { mha_fill(dQKV, cu_seqlens.index({torch::indexing::Slice(-1, torch::indexing::None)})); + } else { + dQKV.fill_(0); } auto options = torch::TensorOptions().dtype(GetATenDType(qkv_type)).device(torch::kCUDA); at::Tensor dBias; TensorWrapper te_dBias; - if (bias_type != "no_bias") { + if (bias_type != NVTE_NO_BIAS) { dBias = torch::zeros({1, static_cast(h), static_cast(max_seqlen), static_cast(max_seqlen)}, options); @@ -341,13 +314,7 @@ std::vector fused_attn_bwd_qkvpacked( NVTE_ERROR("Fused attention only supports FP8 and BF16/FP16 data types. \n"); } - // convert strings to enums - NVTE_QKV_Layout qkv_layout_enum = get_nvte_qkv_layout(qkv_layout); - NVTE_Bias_Type bias_type_enum = get_nvte_bias_type(bias_type); - NVTE_Mask_Type attn_mask_type_enum = get_nvte_mask_type(attn_mask_type); - // convert auxiliary tensors from forward into NVTETensors - // aux_ctx_tensors are [M, ZInv, rng_state] NVTETensorPack nvte_aux_tensor_pack; nvte_tensor_pack_create(&nvte_aux_tensor_pack); nvte_aux_tensor_pack.size = Aux_CTX_Tensors.size(); @@ -380,7 +347,7 @@ std::vector fused_attn_bwd_qkvpacked( te_cu_seqlens.data(), max_seqlen, attn_scale, p_dropout, - qkv_layout_enum, bias_type_enum, attn_mask_type_enum, + qkv_layout, bias_type, attn_mask_type, workspace.data(), at::cuda::getCurrentCUDAStream()); @@ -403,7 +370,7 @@ std::vector fused_attn_bwd_qkvpacked( te_cu_seqlens.data(), max_seqlen, attn_scale, p_dropout, - qkv_layout_enum, bias_type_enum, attn_mask_type_enum, + qkv_layout, bias_type, attn_mask_type, workspace.data(), at::cuda::getCurrentCUDAStream()); @@ -419,7 +386,7 @@ std::vector fused_attn_fwd_kvpacked( size_t total_seqs_q, size_t total_seqs_kv, size_t h, size_t d, bool is_training, float attn_scale, float p_dropout, bool set_zero, - std::string qkv_layout, std::string bias_type, std::string attn_mask_type, + NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const at::Tensor Q, @@ -431,15 +398,18 @@ std::vector fused_attn_fwd_kvpacked( c10::optional amax_S, c10::optional amax_O, const c10::optional Bias, - const c10::optional rng_gen) { + const c10::optional rng_gen, + size_t rng_elts_per_thread) { using namespace transformer_engine; // create output tensor O auto options = torch::TensorOptions().dtype(GetATenDType(qkv_type)).device(torch::kCUDA); auto O = torch::empty({static_cast(total_seqs_q), static_cast(h), static_cast(d)}, options); - if (set_zero) { + if (set_zero && (h * d % block_size == 0)) { mha_fill(O, cu_seqlens_q.index({torch::indexing::Slice(-1, torch::indexing::None)})); + } else { + O.fill_(0); } // construct NVTE tensors @@ -474,7 +444,7 @@ std::vector fused_attn_fwd_kvpacked( } else { NVTE_ERROR("Fused attention only supports FP8 and BF16/FP16 data types. 
\n"); } - if ((bias_type != "no_bias") && (Bias.has_value())) { + if ((bias_type != NVTE_NO_BIAS) && (Bias.has_value())) { auto bias_shape = Bias.value().sizes().vec(); std::vector shape{bias_shape.begin(), bias_shape.end()}; te_Bias = makeTransformerEngineTensor(Bias.value().data_ptr(), shape, @@ -485,24 +455,16 @@ std::vector fused_attn_fwd_kvpacked( te_cu_seqlens_kv = makeTransformerEngineTensor(cu_seqlens_kv.data_ptr(), {b+1}, DType::kInt32, nullptr, nullptr, nullptr); - // convert strings to enums - NVTE_QKV_Layout qkv_layout_enum = get_nvte_qkv_layout(qkv_layout); - NVTE_Bias_Type bias_type_enum = get_nvte_bias_type(bias_type); - NVTE_Mask_Type attn_mask_type_enum = get_nvte_mask_type(attn_mask_type); - // extract rng seed and offset auto gen = at::get_generator_or_default( rng_gen, at::cuda::detail::getDefaultCUDAGenerator()); - size_t threads_per_cta = 128; - at::PhiloxCudaState philox_args = init_philox_state( - gen, max(max_seqlen_q, max_seqlen_kv), threads_per_cta); + at::PhiloxCudaState philox_args = init_philox_state(gen, rng_elts_per_thread); auto rng_state = torch::empty({2}, options.dtype(torch::kInt64)); unpack<<<1, 1, 0, at::cuda::getCurrentCUDAStream()>>>( philox_args, static_cast(rng_state.data_ptr())); auto te_rng_state = makeTransformerEngineTensor(rng_state); // create auxiliary output tensors - // if training, tensors are [M, ZInv] NVTETensorPack nvte_aux_tensor_pack; nvte_tensor_pack_create(&nvte_aux_tensor_pack); @@ -522,7 +484,7 @@ std::vector fused_attn_fwd_kvpacked( te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training, attn_scale, p_dropout, - qkv_layout_enum, bias_type_enum, attn_mask_type_enum, + qkv_layout, bias_type, attn_mask_type, workspace.data(), at::cuda::getCurrentCUDAStream()); @@ -532,10 +494,9 @@ std::vector fused_attn_fwd_kvpacked( workspace_data.data_ptr(), workspace.shape(), workspace.dtype()); - // output_tensors = [O, nvte_aux_tensor_pack.tensors, rng_state] + // output_tensors = [O, nvte_aux_tensor_pack.tensors] std::vector output_tensors; output_tensors.push_back(O); - // nvte_aux_tensor_pack.size is 0 if inference for (size_t i = 0; i < nvte_aux_tensor_pack.size; ++i) { auto tensor = reinterpret_cast(nvte_aux_tensor_pack.tensors[i]); // allocate memory for nvte_aux_tensor_pack.tensors @@ -543,9 +504,6 @@ std::vector fused_attn_fwd_kvpacked( output_tensors.push_back(output_tensor); tensor->data.dptr = output_tensor.data_ptr(); } - if (is_training) { - output_tensors.push_back(rng_state); - } // execute the kernel nvte_fused_attn_fwd_kvpacked( @@ -560,14 +518,14 @@ std::vector fused_attn_fwd_kvpacked( te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training, attn_scale, p_dropout, - qkv_layout_enum, bias_type_enum, attn_mask_type_enum, + qkv_layout, bias_type, attn_mask_type, workspace.data(), at::cuda::getCurrentCUDAStream()); // destroy tensor wrappers, but not allocated memory nvte_tensor_pack_destroy(&nvte_aux_tensor_pack); - // if training, [O, M, ZInv, rng_state]; if inference, [O] + // if training, [O, softmax-related tensors, rng_state]; if inference, [O] return output_tensors; } @@ -577,7 +535,7 @@ std::vector fused_attn_bwd_kvpacked( size_t total_seqs_q, size_t total_seqs_kv, size_t h, size_t d, float attn_scale, float p_dropout, bool set_zero, - std::string qkv_layout, std::string bias_type, std::string attn_mask_type, + NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const at::Tensor Q, @@ -600,14 +558,23 @@ 
std::vector fused_attn_bwd_kvpacked( // create output tensors dQ and dKV at::Tensor dQ = torch::empty_like(Q); at::Tensor dKV = torch::empty_like(KV); - if (set_zero) { + auto max_tokens_q = dQ.size(0); + auto self_2d_q = dQ.view({max_tokens_q, -1}); + auto fcd_size_q = self_2d_q.size(1); + auto max_tokens_kv = dQ.size(0); + auto self_2d_kv = dQ.view({max_tokens_kv, -1}); + auto fcd_size_kv = self_2d_kv.size(1); + if (set_zero && (fcd_size_q % block_size == 0) && (fcd_size_kv % block_size == 0)) { mha_fill(dQ, cu_seqlens_q.index({torch::indexing::Slice(-1, torch::indexing::None)})); mha_fill(dKV, cu_seqlens_kv.index({torch::indexing::Slice(-1, torch::indexing::None)})); + } else { + dQ.fill_(0); + dKV.fill_(0); } auto options = torch::TensorOptions().dtype(GetATenDType(qkv_type)).device(torch::kCUDA); at::Tensor dBias; TensorWrapper te_dBias; - if (bias_type != "no_bias") { + if (bias_type != NVTE_NO_BIAS) { dBias = torch::zeros({1, static_cast(h), static_cast(max_seqlen_q), static_cast(max_seqlen_kv)}, options); @@ -674,13 +641,7 @@ std::vector fused_attn_bwd_kvpacked( te_cu_seqlens_kv = makeTransformerEngineTensor(cu_seqlens_kv.data_ptr(), {b+1}, DType::kInt32, nullptr, nullptr, nullptr); - // convert strings to enums - NVTE_QKV_Layout qkv_layout_enum = get_nvte_qkv_layout(qkv_layout); - NVTE_Bias_Type bias_type_enum = get_nvte_bias_type(bias_type); - NVTE_Mask_Type attn_mask_type_enum = get_nvte_mask_type(attn_mask_type); - // convert auxiliary tensors from forward to NVTETensors - // aux_ctx_tensors are [M, ZInv, rng_state] NVTETensorPack nvte_aux_tensor_pack; nvte_tensor_pack_create(&nvte_aux_tensor_pack); nvte_aux_tensor_pack.size = Aux_CTX_Tensors.size(); @@ -711,7 +672,7 @@ std::vector fused_attn_bwd_kvpacked( te_cu_seqlens_kv.data(), max_seqlen_q, max_seqlen_kv, attn_scale, p_dropout, - qkv_layout_enum, bias_type_enum, attn_mask_type_enum, + qkv_layout, bias_type, attn_mask_type, workspace.data(), at::cuda::getCurrentCUDAStream()); @@ -737,7 +698,7 @@ std::vector fused_attn_bwd_kvpacked( te_cu_seqlens_kv.data(), max_seqlen_q, max_seqlen_kv, attn_scale, p_dropout, - qkv_layout_enum, bias_type_enum, attn_mask_type_enum, + qkv_layout, bias_type, attn_mask_type, workspace.data(), at::cuda::getCurrentCUDAStream()); @@ -2227,6 +2188,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("dswiglu", &dswiglu, "Backward of SwiGLU"); m.def("fa_prepare_fwd", &fa_prepare_fwd, "Prepare QKV for Flash Attention"); m.def("fa_prepare_bwd", &fa_prepare_bwd, "Backward of QKV preparation for Flash Attention"); + m.def("get_fused_attn_backend", &get_fused_attn_backend, "Get Fused Attention backend"); // Misc m.def("get_cublasLt_version", &get_cublasLt_version, "Get cublasLt version"); @@ -2279,11 +2241,37 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { .value("GEMM1_OUTPUT", transformer_engine::FP8FwdTensors::GEMM1_OUTPUT) .value("GEMM2_INPUT", transformer_engine::FP8FwdTensors::GEMM2_INPUT) .value("GEMM2_WEIGHT", transformer_engine::FP8FwdTensors::GEMM2_WEIGHT) - .value("GEMM2_OUTPUT", transformer_engine::FP8FwdTensors::GEMM2_OUTPUT); + .value("GEMM2_OUTPUT", transformer_engine::FP8FwdTensors::GEMM2_OUTPUT) + .value("GEMM3_INPUT", transformer_engine::FP8FwdTensors::GEMM3_INPUT) + .value("GEMM3_WEIGHT", transformer_engine::FP8FwdTensors::GEMM3_WEIGHT) + .value("GEMM3_OUTPUT", transformer_engine::FP8FwdTensors::GEMM3_OUTPUT); py::enum_(m, "FP8BwdTensors") .value("GRAD_OUTPUT1", transformer_engine::FP8BwdTensors::GRAD_OUTPUT1) .value("GRAD_INPUT1", transformer_engine::FP8BwdTensors::GRAD_INPUT1) 
.value("GRAD_OUTPUT2", transformer_engine::FP8BwdTensors::GRAD_OUTPUT2) - .value("GRAD_INPUT2", transformer_engine::FP8BwdTensors::GRAD_INPUT2); + .value("GRAD_INPUT2", transformer_engine::FP8BwdTensors::GRAD_INPUT2) + .value("GRAD_OUTPUT3", transformer_engine::FP8BwdTensors::GRAD_OUTPUT3) + .value("GRAD_INPUT3", transformer_engine::FP8BwdTensors::GRAD_INPUT3); + + py::enum_(m, "NVTE_Bias_Type") + .value("NVTE_NO_BIAS", NVTE_Bias_Type::NVTE_NO_BIAS) + .value("NVTE_PRE_SCALE_BIAS", NVTE_Bias_Type::NVTE_PRE_SCALE_BIAS) + .value("NVTE_POST_SCALE_BIAS", NVTE_Bias_Type::NVTE_POST_SCALE_BIAS); + + py::enum_(m, "NVTE_Mask_Type") + .value("NVTE_NO_MASK", NVTE_Mask_Type::NVTE_NO_MASK) + .value("NVTE_PADDING_MASK", NVTE_Mask_Type::NVTE_PADDING_MASK) + .value("NVTE_CAUSAL_MASK", NVTE_Mask_Type::NVTE_CAUSAL_MASK); + + py::enum_(m, "NVTE_QKV_Layout") + .value("NVTE_NOT_INTERLEAVED", NVTE_QKV_Layout::NVTE_NOT_INTERLEAVED) + .value("NVTE_QKV_INTERLEAVED", NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED) + .value("NVTE_KV_INTERLEAVED", NVTE_QKV_Layout::NVTE_KV_INTERLEAVED); + + py::enum_(m, "NVTE_Fused_Attn_Backend") + .value("NVTE_F16_max512_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) + .value("NVTE_F16_arbitrary_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) + .value("NVTE_FP8", NVTE_Fused_Attn_Backend::NVTE_FP8) + .value("NVTE_No_Backend", NVTE_Fused_Attn_Backend::NVTE_No_Backend); } diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index a2083e5492..1467397c63 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -7,17 +7,22 @@ #include "common.h" #include "../common.h" -NVTE_QKV_Layout get_nvte_qkv_layout(const std::string qkv_layout); - -NVTE_Bias_Type get_nvte_bias_type(const std::string bias_type); - -NVTE_Mask_Type get_nvte_mask_type(const std::string mask_type); +NVTE_Fused_Attn_Backend get_fused_attn_backend( + const transformer_engine::DType q_dtype, + const transformer_engine::DType kv_dtype, + NVTE_QKV_Layout qkv_layout, + NVTE_Bias_Type bias_type, + NVTE_Mask_Type attn_mask_type, + float p_dropout, size_t max_seqlen_q, + size_t max_seqlen_kv, size_t head_dim); std::vector fused_attn_fwd_qkvpacked( size_t b, size_t max_seqlen, size_t total_seqs, - size_t h, size_t d, - bool is_training, float attn_scale, float p_dropout, bool set_zero, - std::string qkv_layout, std::string bias_type, std::string attn_mask_type, + size_t h, size_t d, bool is_training, + float attn_scale, float p_dropout, bool set_zero, + NVTE_QKV_Layout qkv_layout, + NVTE_Bias_Type bias_type, + NVTE_Mask_Type attn_mask_type, const at::Tensor cu_seqlens, const at::Tensor QKV, const transformer_engine::DType qkv_type, @@ -27,13 +32,16 @@ std::vector fused_attn_fwd_qkvpacked( c10::optional amax_S, c10::optional amax_O, const c10::optional Bias, - const c10::optional rng_gen); + const c10::optional rng_gen, + size_t rng_elts_per_thread); std::vector fused_attn_bwd_qkvpacked( size_t b, size_t max_seqlen, size_t total_seqs, - size_t h, size_t d, - float attn_scale, float p_dropout, bool set_zero, - std::string qkv_layout, std::string bias_type, std::string attn_mask_type, + size_t h, size_t d, float attn_scale, + float p_dropout, bool set_zero, + NVTE_QKV_Layout qkv_layout, + NVTE_Bias_Type bias_type, + NVTE_Mask_Type attn_mask_type, const at::Tensor cu_seqlens, const at::Tensor QKV, const at::Tensor O, @@ -53,9 +61,11 @@ std::vector fused_attn_bwd_qkvpacked( std::vector fused_attn_fwd_kvpacked( size_t 
b, size_t max_seqlen_q, size_t max_seqlen_kv, size_t total_seqs_q, size_t total_seqs_kv, - size_t h, size_t d, - bool is_training, float attn_scale, float p_dropout, bool set_zero, - std::string qkv_layout, std::string bias_type, std::string attn_mask_type, + size_t h, size_t d, bool is_training, + float attn_scale, float p_dropout, bool set_zero, + NVTE_QKV_Layout qkv_layout, + NVTE_Bias_Type bias_type, + NVTE_Mask_Type attn_mask_type, const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const at::Tensor Q, @@ -67,14 +77,17 @@ std::vector fused_attn_fwd_kvpacked( c10::optional amax_S, c10::optional amax_O, const c10::optional Bias, - const c10::optional rng_gen); + const c10::optional rng_gen, + size_t rng_elts_per_thread); std::vector fused_attn_bwd_kvpacked( size_t b, size_t max_seqlen_q, size_t max_seqlen_kv, size_t total_seqs_q, size_t total_seqs_kv, - size_t h, size_t d, - float attn_scale, float p_dropout, bool set_zero, - std::string qkv_layout, std::string bias_type, std::string attn_mask_type, + size_t h, size_t d, float attn_scale, + float p_dropout, bool set_zero, + NVTE_QKV_Layout qkv_layout, + NVTE_Bias_Type bias_type, + NVTE_Mask_Type attn_mask_type, const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const at::Tensor Q, diff --git a/transformer_engine/pytorch/transformer.py b/transformer_engine/pytorch/transformer.py index b30236acad..6a39c2cab1 100644 --- a/transformer_engine/pytorch/transformer.py +++ b/transformer_engine/pytorch/transformer.py @@ -400,6 +400,9 @@ def forward( checkpoint_core_attention: bool = False, inference_params: Optional[Any] = None, rotary_pos_emb: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None, + core_attention_bias_type: str = "no_bias", + core_attention_bias: Optional[torch.Tensor] = None, + fast_zero_fill: bool = True, ) -> torch.Tensor: """ Transformer Layer: attention block and a feedforward network (MLP) @@ -442,6 +445,12 @@ def forward( rotary_pos_emb: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], default = `None` Embeddings for query and key tensors for applying rotary position embedding. By default no input embedding is applied. + core_attention_bias_type: str, default = `no_bias` + Bias type, {`no_bias`, `pre_scale_bias`, 'post_scale_bias`} + core_attention_bias: Optional[torch.Tensor], default = `None` + Bias tensor for Q * K.T + fast_zero_fill: bool, default = `True` + Whether to set output tensors to 0 or not before use. 
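These new forward arguments are threaded from TransformerLayer through MultiHeadAttention down to DotProductAttention (see the matching changes earlier in this patch). A minimal usage sketch — not from the patch, sizes are arbitrary, it assumes Transformer Engine with a CUDA GPU, and the params_dtype keyword is assumed to be available as in other TE modules — could look like:

import torch
import transformer_engine.pytorch as te

hidden, heads, seq, batch = 1024, 16, 128, 2
layer = te.TransformerLayer(hidden, 4 * hidden, heads,
                            params_dtype=torch.bfloat16).cuda()

x = torch.randn(seq, batch, hidden, dtype=torch.bfloat16, device="cuda")
# post_scale_bias: added to the attention scores after scaling;
# shape [1, num_heads, seq_q, seq_kv], same dtype as the activations.
bias = torch.zeros(1, heads, seq, seq, dtype=torch.bfloat16, device="cuda")

y = layer(x,
          attention_mask=None,
          core_attention_bias_type="post_scale_bias",
          core_attention_bias=bias,
          fast_zero_fill=True)
print(y.shape)  # (seq, batch, hidden)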
""" hidden_states = hidden_states.contiguous() @@ -470,6 +479,9 @@ def forward( is_first_microbatch=is_first_microbatch, checkpoint_core_attention=checkpoint_core_attention, rotary_pos_emb=rotary_pos_emb, + core_attention_bias_type=core_attention_bias_type, + core_attention_bias=core_attention_bias, + fast_zero_fill=fast_zero_fill, ) if self.apply_residual_connection_post_layernorm and not self.output_layernorm: @@ -513,6 +525,9 @@ def forward( encoder_output=encoder_output, is_first_microbatch=is_first_microbatch, checkpoint_core_attention=checkpoint_core_attention, + core_attention_bias_type=core_attention_bias_type, + core_attention_bias=core_attention_bias, + fast_zero_fill=fast_zero_fill, ) if self.apply_residual_connection_post_layernorm: attention_output, attention_bias, residual = inter_attention_outputs From ac919e4559f1d04e782da31268894272c8eb79d4 Mon Sep 17 00:00:00 2001 From: asfiyab-nvidia <117682710+asfiyab-nvidia@users.noreply.github.com> Date: Tue, 20 Jun 2023 17:59:31 -0700 Subject: [PATCH 35/68] Fix BF16 ONNX export for successful ONNX Runtime Verification (#290) Signed-off-by: Asfiya Baig --- transformer_engine/pytorch/attention.py | 7 ++++++- transformer_engine/pytorch/te_onnx_extensions.py | 3 +++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 492ebe5cb6..ab164cff79 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -180,14 +180,19 @@ def forward( key_layer = key_layer.reshape(output_size[3], output_size[0] * output_size[1], -1) # preallocting result tensor: [b * np, sq, sk] + # WAR to set dtype to FP32 as ONNX lacks BF16 support for ConstantOfShape operator + is_bf16 = query_layer.dtype == torch.bfloat16 matmul_result = torch.empty( output_size[0] * output_size[1], output_size[2], output_size[3], - dtype=query_layer.dtype, + dtype=torch.float32 if is_in_onnx_export_mode() and is_bf16 else query_layer.dtype, device=torch.cuda.current_device(), ) + if is_in_onnx_export_mode() and is_bf16: + matmul_result = matmul_result.bfloat16() + scale = self.norm_factor if apply_qk_layer_scaling: scale *= self.layer_number diff --git a/transformer_engine/pytorch/te_onnx_extensions.py b/transformer_engine/pytorch/te_onnx_extensions.py index f641926cc2..3f3e97f198 100755 --- a/transformer_engine/pytorch/te_onnx_extensions.py +++ b/transformer_engine/pytorch/te_onnx_extensions.py @@ -254,6 +254,7 @@ def onnx_te_gemm( """ONNX graph for te_gemm""" # pylint: disable=unused-argument is_fp16 = is_dtype_fp16(inputs) + is_bf16 = is_dtype_bf16(inputs) if input_type == int(tex.DType.kFloat8E4M3): inputs = dequantize(g, inputs, input_scale_inverse, input_fp8_tensor, out_type) @@ -277,6 +278,8 @@ def onnx_te_gemm( else: if is_fp16: output = g.op("Cast", output, to_i=_C_onnx.TensorProtoDataType.FLOAT16) + elif is_bf16: + output = g.op("Cast", output, to_i=_C_onnx.TensorProtoDataType.BFLOAT16) return output From 96ed6fc69d99a9cff49637dbc58c837c8d921ad7 Mon Sep 17 00:00:00 2001 From: Neta Zmora <96238833+nzmora-nvidia@users.noreply.github.com> Date: Fri, 23 Jun 2023 05:04:33 +0300 Subject: [PATCH 36/68] Fix layer_norm ONNX export (#293) * Fix ONNX export of layer_norm ONNX has a spec bug: ConstantOfShape supports all dtypes except for BF16. To WAR we use dtype FP32 and then cast to BF16. Will also issue a PR to the ONNX sig committee to change the spec in opset 20. 
Signed-off-by: Neta Zmora * fix lint Signed-off-by: Kirthi Shankar Sivamani --------- Signed-off-by: Neta Zmora Signed-off-by: Kirthi Shankar Sivamani Co-authored-by: Kirthi Shankar Sivamani --- .../pytorch/te_onnx_extensions.py | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/transformer_engine/pytorch/te_onnx_extensions.py b/transformer_engine/pytorch/te_onnx_extensions.py index 3f3e97f198..5990160294 100755 --- a/transformer_engine/pytorch/te_onnx_extensions.py +++ b/transformer_engine/pytorch/te_onnx_extensions.py @@ -304,6 +304,20 @@ def onnx_layernorm_fwd_fp8(g, inputs, weight, bias, eps, scale, amax, def onnx_layernorm_fwd(g, inputs, weight, bias, eps, zero_centered_gamma): """ONNX graph for layernorm_fwd""" # pylint: disable=unused-argument + + def ones_like(inp, dtype): + """Returns a tensor filled with the scalar value 1, with the same size as input and + with dtype data-type""" + shape = g.op("Shape", inp) + # WAR ONNX spec: ConstantOfShape accepts all data types except for BF16. To WAR + # create a ConstantOfShape with type FP32 and then add a Cast to BF16. + is_bf16 = dtype == torch.bfloat16 + one = g.op("ConstantOfShape", shape, value_t=torch.tensor([1], + dtype=torch.float32 if is_bf16 else dtype)) + if is_bf16: + one = g.op("Cast", one, to_i=_C_onnx.TensorProtoDataType.BFLOAT16) + return one + normalized_shape = torch.onnx.symbolic_helper._get_tensor_sizes(inputs) if normalized_shape is None: ndim = torch.onnx.symbolic_helper._get_tensor_rank(inputs) @@ -314,8 +328,7 @@ def onnx_layernorm_fwd(g, inputs, weight, bias, eps, zero_centered_gamma): if zero_centered_gamma: inputs_dtype = inputs.type().dtype() - shape = g.op("Shape", weight) - one = g.op("ConstantOfShape", shape, value_t=torch.tensor([1], dtype=inputs_dtype)) + one = ones_like(weight, inputs_dtype) weight = g.op("Add", weight, one) axis = -len(normalized_shape) From 94beb13062f98e03ca71197aeab6821545c4e679 Mon Sep 17 00:00:00 2001 From: zlsh80826 Date: Tue, 18 Jul 2023 22:47:31 +0800 Subject: [PATCH 37/68] [JAX] Fully remove attn_type and set self_attn_mask_type default to 'causal' (#324) * Fully remove attn_type and set self_attn_mask_type default to 'causal' Signed-off-by: Reese Wang * Fix tests with new arguments Signed-off-by: Reese Wang * Explicit self_attn_mask_type for examples Signed-off-by: Reese Wang * Update transformer_engine/jax/flax/transformer.py Co-authored-by: Kirthi Shankar Sivamani Signed-off-by: zlsh80826 * Update transformer_engine/jax/flax/transformer.py Co-authored-by: Kirthi Shankar Sivamani Signed-off-by: zlsh80826 --------- Signed-off-by: Reese Wang Signed-off-by: zlsh80826 Co-authored-by: Kirthi Shankar Sivamani --- .../encoder/test_model_parallel_encoder.py | 1 + examples/jax/encoder/test_multigpu_encoder.py | 1 + .../encoder/test_multiprocessing_encoder.py | 1 + .../jax/encoder/test_single_gpu_encoder.py | 1 + tests/jax/test_layer.py | 2 + tests/jax/test_praxis_layers.py | 20 ++++----- transformer_engine/jax/flax/transformer.py | 43 +++---------------- transformer_engine/jax/praxis/transformer.py | 6 +-- 8 files changed, 24 insertions(+), 51 deletions(-) diff --git a/examples/jax/encoder/test_model_parallel_encoder.py b/examples/jax/encoder/test_model_parallel_encoder.py index 4a26244fff..75c41964c9 100644 --- a/examples/jax/encoder/test_model_parallel_encoder.py +++ b/examples/jax/encoder/test_model_parallel_encoder.py @@ -48,6 +48,7 @@ def __call__(self, x, mask, disable_dropout=False): attention_dropout=0.1, dropout_rng_name=DROPOUT_KEY, 
layer_type=te_flax.TransformerLayerType.ENCODER, + self_attn_mask_type='padding', enable_relative_embedding=False, dtype=jnp.bfloat16) x = te_Encoder()(x, attention_mask=mask, deterministic=disable_dropout) diff --git a/examples/jax/encoder/test_multigpu_encoder.py b/examples/jax/encoder/test_multigpu_encoder.py index ef3837c8d4..53be4b7134 100644 --- a/examples/jax/encoder/test_multigpu_encoder.py +++ b/examples/jax/encoder/test_multigpu_encoder.py @@ -45,6 +45,7 @@ def __call__(self, x, mask, disable_dropout=False): attention_dropout=0.1, dropout_rng_name=DROPOUT_KEY, layer_type=te_flax.TransformerLayerType.ENCODER, + self_attn_mask_type='padding', enable_relative_embedding=False, dtype=jnp.bfloat16) x = te_Encoder()(x, attention_mask=mask, deterministic=disable_dropout) diff --git a/examples/jax/encoder/test_multiprocessing_encoder.py b/examples/jax/encoder/test_multiprocessing_encoder.py index a21346458c..c1cf94332f 100644 --- a/examples/jax/encoder/test_multiprocessing_encoder.py +++ b/examples/jax/encoder/test_multiprocessing_encoder.py @@ -51,6 +51,7 @@ def __call__(self, x, mask, disable_dropout=False): attention_dropout=0.1, dropout_rng_name=DROPOUT_KEY, layer_type=te_flax.TransformerLayerType.ENCODER, + self_attn_mask_type='padding', enable_relative_embedding=False, dtype=jnp.bfloat16) x = te_Encoder()(x, attention_mask=mask, deterministic=disable_dropout) diff --git a/examples/jax/encoder/test_single_gpu_encoder.py b/examples/jax/encoder/test_single_gpu_encoder.py index 62798eed82..6e519d87cc 100644 --- a/examples/jax/encoder/test_single_gpu_encoder.py +++ b/examples/jax/encoder/test_single_gpu_encoder.py @@ -40,6 +40,7 @@ def __call__(self, x, mask, disable_dropout=False): attention_dropout=0.1, dropout_rng_name=DROPOUT_KEY, layer_type=te_flax.TransformerLayerType.ENCODER, + self_attn_mask_type='padding', enable_relative_embedding=False, dtype=jnp.bfloat16) x = te_Encoder()(x, attention_mask=mask, deterministic=disable_dropout) diff --git a/tests/jax/test_layer.py b/tests/jax/test_layer.py index 30143e5f75..ef1faebaf0 100644 --- a/tests/jax/test_layer.py +++ b/tests/jax/test_layer.py @@ -171,6 +171,7 @@ def forward_runner(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08): layer_cls = partial(TransformerLayer, hidden_dropout_dims=(sequence_dim,), layer_type=TransformerLayerType.ENCODER, + self_attn_mask_type='padding', dtype=dtype, **te_layer_attrs) @@ -215,6 +216,7 @@ def forward_backward_runner(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e- layer_cls = partial(TransformerLayer, hidden_dropout_dims=(sequence_dim,), layer_type=TransformerLayerType.ENCODER, + self_attn_mask_type='padding', dtype=dtype, **te_layer_attrs) ref_layer, ref_params, ref_others = generate_layer(ref_layer_cls, init_rng, inputs, diff --git a/tests/jax/test_praxis_layers.py b/tests/jax/test_praxis_layers.py index de44b3a163..7a329d39ac 100644 --- a/tests/jax/test_praxis_layers.py +++ b/tests/jax/test_praxis_layers.py @@ -659,38 +659,38 @@ def test_forward(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08): class MultiHeadAttnAttr: USE_BIAS = 'use_bias' LN_TYPE = 'layernorm_type' - ATTN_TYPE = 'attn_type' + ATTN_MASK_TYPE = 'attn_mask_type' ZERO_CEN = 'zero_centered_gamma' ATTRS = [{ USE_BIAS: True, LN_TYPE: 'layernorm', ZERO_CEN: False, - ATTN_TYPE: 'padding' + ATTN_MASK_TYPE: 'padding' }, { USE_BIAS: True, LN_TYPE: 'layernorm', ZERO_CEN: True, - ATTN_TYPE: 'padding' + ATTN_MASK_TYPE: 'padding' }, { USE_BIAS: True, LN_TYPE: 'rmsnorm', ZERO_CEN: False, - ATTN_TYPE: 'padding' + 
ATTN_MASK_TYPE: 'padding' }, { USE_BIAS: True, LN_TYPE: 'layernorm', ZERO_CEN: False, - ATTN_TYPE: 'causal' + ATTN_MASK_TYPE: 'causal' }, { USE_BIAS: True, LN_TYPE: 'layernorm', ZERO_CEN: True, - ATTN_TYPE: 'causal' + ATTN_MASK_TYPE: 'causal' }, { USE_BIAS: True, LN_TYPE: 'rmsnorm', ZERO_CEN: False, - ATTN_TYPE: 'causal' + ATTN_MASK_TYPE: 'causal' }] @@ -714,7 +714,7 @@ def generate_praxis_p_and_flax_cls(self, dtype, attrs): bias_init = WeightInit.Constant(0.0) apply_residual_connection_post_layernorm = False output_layernorm = False - attn_type = attrs[MultiHeadAttnAttr.ATTN_TYPE] + attn_mask_type = attrs[MultiHeadAttnAttr.ATTN_MASK_TYPE] fuse_qkv: bool = True transpose_batch_sequence = True scale_attn_logits = False @@ -734,7 +734,7 @@ def generate_praxis_p_and_flax_cls(self, dtype, attrs): bias_init=bias_init, apply_residual_connection_post_layernorm=apply_residual_connection_post_layernorm, output_layernorm=output_layernorm, - attn_type=attn_type, + attn_mask_type=attn_mask_type, fuse_qkv=fuse_qkv, transpose_batch_sequence=transpose_batch_sequence, scale_attn_logits=scale_attn_logits, @@ -752,7 +752,7 @@ def generate_praxis_p_and_flax_cls(self, dtype, attrs): bias_init=TransformerEngineBaseLayer.generate_params_init("bias", bias_init), apply_residual_connection_post_layernorm=apply_residual_connection_post_layernorm, output_layernorm=output_layernorm, - attn_type=attn_type, + attn_mask_type=attn_mask_type, fuse_qkv=fuse_qkv, transpose_batch_sequence=transpose_batch_sequence, scale_attn_logits=scale_attn_logits, diff --git a/transformer_engine/jax/flax/transformer.py b/transformer_engine/jax/flax/transformer.py index 14ad7f02e8..a5cf05bb5e 100644 --- a/transformer_engine/jax/flax/transformer.py +++ b/transformer_engine/jax/flax/transformer.py @@ -202,10 +202,10 @@ class MultiHeadAttention(nn.Module): Multi-head Attention (MHA), including Query, Key, Value and Output projection. - .. warning:: + .. note:: - Argument :attr:`attn_type` is deprecated and superseded by :attr:`attn_mask_type`. - :attr:`attn_type` is ignored in version 0.10 and will be fully removed in version 0.11. + Argument :attr:`mask` will be ignored when + :attr:`attn_mask_type` is set to `"causal"`. Parameters ---------- @@ -244,11 +244,9 @@ class MultiHeadAttention(nn.Module): Indicate if apply residual connection with the output of layer normalization. output_layernorm : bool, default = False Indicate if apply a layer normalization at the end of MHA. - attn_type: Any, defult = None - *Deprecated*, will be ignored in v0.10 and be fully removed in v0.11. - Please use `attn_mask_type` to config the attention mask. attn_mask_type: {'causal', 'padding'}, default = 'causal' Type of attention mask passed into softmax operation. + Introduced in v0.10.0. 
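Because the implicit 'padding' default for encoders is being removed, existing encoder configurations now have to request padding masking explicitly. A minimal construction sketch (editorial note, not part of the patch; the size arguments are assumptions, only layer_type and self_attn_mask_type reflect this change):

    import jax.numpy as jnp
    import transformer_engine.jax.flax as te_flax

    encoder_layer = te_flax.TransformerLayer(
        hidden_size=512,                 # hypothetical sizes
        mlp_hidden_size=2048,
        num_attention_heads=8,
        layer_type=te_flax.TransformerLayerType.ENCODER,
        self_attn_mask_type='padding',   # no longer implied for encoders; the default is now 'causal'
        dtype=jnp.bfloat16,
    )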
Optimization parameters ----------------------- @@ -284,8 +282,6 @@ class MultiHeadAttention(nn.Module): bias_init: Initializer = nn.initializers.zeros apply_residual_connection_post_layernorm: bool = False output_layernorm: bool = False - # TODO(rewang): remove attn_type and the related doc after v0.11 - attn_type: Any = None attn_mask_type: str = 'causal' dtype: DType = jnp.float32 fuse_qkv: bool = True @@ -297,14 +293,6 @@ class MultiHeadAttention(nn.Module): def __post_init__(self): if self.kernel_init is None: self.kernel_init = nn.initializers.variance_scaling(1.0, 'fan_in', 'normal') - # TODO(rewang): remove attn_type after v0.11 - if self.attn_type is not None: - warnings.warn( - "The 'attn_type' argument in the 'MultiHeadAttention' is" - " deprecated in version 0.10 and will be removed in version 0.11." - " Passing value in attn_type will be ignored, please use `attn_mask_type`" - " to config the attention mask type.", - category=DeprecationWarning) super().__post_init__() @nn.compact @@ -803,13 +791,6 @@ class TransformerLayer(nn.Module): an attention block and a feedforward network (MLP). This standard layer is based on the paper “Attention Is All You Need”. - .. warning:: - - Argument :attr:`self_attn_mask_type` is introduced in version 0.10. - Starting from version 0.11, the default value will be `"causal"`. - However, to ensure compatibility with earlier versions, before 0.11, - the default value will be `"padding"` for the encoder and `"causal"` for the decoder. - .. note:: Argument :attr:`attention_mask` will be ignored when @@ -877,6 +858,7 @@ class TransformerLayer(nn.Module): Transformer in conjunction with the TransformerLayerType.ENCODER option. self_attn_mask_type: {'causal', 'padding'}, default = 'causal' Type of attention mask passed into softmax operation. + Introduced in v0.10.0. enable_relative_embedding: bool, default = True Whether to enable relative embedding as shifting of attention logits. relative_embedding: flax.linen.Module, default = None @@ -930,7 +912,7 @@ class TransformerLayer(nn.Module): output_layernorm: bool = False float32_attention_logits: bool = False layer_type: TransformerLayerType = TransformerLayerType.ENCODER - self_attn_mask_type: str = None # TODO(rewang): default to 'causal' after 0.11 + self_attn_mask_type: str = 'causal' enable_relative_embedding: bool = True relative_embedding: nn.Module = None dtype: DType = jnp.float32 @@ -946,19 +928,6 @@ def __post_init__(self): if self.mlp_kernel_init is None: self.mlp_kernel_init = nn.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal') - # TODO(rewang): default to 'causal' in 0.11 (also updated the doc after 0.11) - if self.self_attn_mask_type is None: - warnings.warn( - "The 'self_attn_mask_type' argument in the 'TransformerLayer' is" - " introduced in version 0.10. Starting from version 0.11, the default" - " value will be 'causal'. 
However, to ensure compatibility with earlier" - " versions, before 0.11, the default value will be 'padding' for the" - " encoder and 'causal' for the decoder.", - category=FutureWarning) - if self.layer_type == TransformerLayerType.ENCODER: - self.self_attn_mask_type = 'padding' - else: - self.self_attn_mask_type = 'causal' super().__post_init__() @nn.compact diff --git a/transformer_engine/jax/praxis/transformer.py b/transformer_engine/jax/praxis/transformer.py index 1260c266b5..9bf9628490 100644 --- a/transformer_engine/jax/praxis/transformer.py +++ b/transformer_engine/jax/praxis/transformer.py @@ -5,7 +5,7 @@ Praxis Modules related Transformer """ from functools import partial -from typing import Any, Optional, Sequence, Tuple +from typing import Optional, Sequence, Tuple from praxis import pax_fiddle from praxis.base_layer import WeightInit @@ -73,8 +73,6 @@ class MultiHeadAttention(TransformerEngineBaseLayer): bias_init: WeightInit = WeightInit.Constant(0.0) apply_residual_connection_post_layernorm: bool = False output_layernorm: bool = False - # TODO(rewang): remove attn_type and the related doc after v0.11 - attn_type: Any = None attn_mask_type: str = 'causal' fuse_qkv: bool = True transpose_batch_sequence: bool = True @@ -147,7 +145,7 @@ class TransformerLayer(TransformerEngineBaseLayer): output_layernorm: bool = False float32_attention_logits: bool = False layer_type: TransformerLayerType = TransformerLayerType.ENCODER - self_attn_mask_type: str = None # TODO(rewang): default to 'causal' after 0.11 + self_attn_mask_type: str = 'causal' enable_relative_embedding: bool = True relative_embedding: pax_fiddle.Config[RelativePositionBiases] = pax_fiddle.template_field(None) drop_path: float = 0.0 From 32ad922b143c4c6da4f0e1aaf65b12e0fe0de035 Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Tue, 18 Jul 2023 21:27:26 -0400 Subject: [PATCH 38/68] FA does not support head_dim > 64 on Ada (#328) * FA does not support head_dim > 64 on Ada Signed-off-by: Kirthi Shankar Sivamani * Add cc8.7 to no FA list Signed-off-by: Kirthi Shankar Sivamani --------- Signed-off-by: Kirthi Shankar Sivamani --- transformer_engine/pytorch/attention.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 44baa5cda5..9cf59e5b01 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -879,7 +879,7 @@ def forward( if (query_layer.dtype not in [torch.bfloat16, torch.float16] or key_layer.dtype not in [torch.bfloat16, torch.float16] or value_layer.dtype not in [torch.bfloat16, torch.float16] - or (self.device_compute_capability == 8.6 and key_layer.shape[-1] > 64) + or (self.device_compute_capability in (8.6, 8.7, 8.9) and key_layer.shape[-1] > 64) ): use_flash_attention = False From 33576bec9ed8534d97920010097e8db7687525ab Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Tue, 18 Jul 2023 21:30:23 -0400 Subject: [PATCH 39/68] FlashAttention 2.0 support (#329) * FA v2.0 support Signed-off-by: Kirthi Shankar Sivamani * fix typo Signed-off-by: Kirthi Shankar Sivamani --------- Signed-off-by: Kirthi Shankar Sivamani --- setup.py | 2 +- transformer_engine/pytorch/attention.py | 24 +++++++++++++++--------- 2 files changed, 16 insertions(+), 10 deletions(-) diff --git a/setup.py b/setup.py index 98edddfc3e..81ba934cbd 100644 --- a/setup.py +++ b/setup.py @@ -290,7 +290,7 @@ def add_unique(l: List[str], vals: Union[str, List[str]]) -> None: # 
Framework-specific requirements if "pytorch" in frameworks(): - add_unique(install_reqs, ["torch", "flash-attn>=1.0.6, <=1.0.7"]) + add_unique(install_reqs, ["torch", "flash-attn>=1.0.6, <=2.0.0.post1"]) add_unique(test_reqs, ["numpy", "onnxruntime", "torchvision"]) if "jax" in frameworks(): if not found_pybind11(): diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 9cf59e5b01..48600b17df 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -12,8 +12,6 @@ import torch -from flash_attn.flash_attn_interface import flash_attn_unpadded_func - import transformer_engine_extensions as tex from transformer_engine.pytorch.cpp_extensions.fused_attn import ( fused_attn_fwd_qkvpacked, @@ -47,6 +45,12 @@ _flash_attn_version = packaging.version.Version(version("flash-attn")) _flash_attn_version_required = packaging.version.Version("1.0.6") +_flash_attn_2_available = _flash_attn_version >= packaging.version.Version("2") + +if _flash_attn_2_available: + from flash_attn.flash_attn_interface import flash_attn_varlen_func as flash_attn_forward_func # pylint: disable=no-name-in-module +else: + from flash_attn.flash_attn_interface import flash_attn_unpadded_func as flash_attn_forward_func # pylint: disable=no-name-in-module __all__ = ["DotProductAttention"] @@ -397,11 +401,14 @@ def forward( device=query_layer.device) with self.attention_dropout_ctx(): - output = flash_attn_unpadded_func( + fa_optional_forward_kwargs = {} + if not _flash_attn_2_available: + fa_optional_forward_kwargs["deterministic"] = self.deterministic + output = flash_attn_forward_func( query_layer, key_layer, value_layer, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen, self.attention_dropout if self.training else 0.0, softmax_scale=1.0/self.norm_factor, causal=self.attn_causal_mask, - deterministic=self.deterministic, + **fa_optional_forward_kwargs ) # [(b sq), np, hn] -> [sq, b, (np hn)] @@ -700,11 +707,10 @@ class DotProductAttention(torch.nn.Module): .. warning:: - For the default attention mechanism, this module executes a non-deterministic version of - `flash-attn `_ whenever possible in order to - achieve optimal performance. To observe deterministic behavior, set the environment - variable :attr:`NVTE_ALLOW_NONDETERMINISTIC_ALGO=0`. In order to disable - `flash-attn` entirely, set :attr:`NVTE_FLASH_ATTN=0`. + FlashAttention uses a non-deterministic algorithm for optimal performance. To observe + deterministic behavior at the cost of performance, use FlashAttention version < `2.0.0` + and set the environment variable :attr:`NVTE_ALLOW_NONDETERMINISTIC_ALGO=0`. In order + to disable`flash-attn` entirely, set :attr:`NVTE_FLASH_ATTN=0`. 
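For reference, both environment variables mentioned in this warning are read when the attention modules are constructed, so they need to be set up front; a minimal sketch (editorial note, not part of the patch):

    import os

    # Set before building any TransformerEngine attention modules.
    os.environ["NVTE_FLASH_ATTN"] = "0"                   # disable flash-attn entirely
    os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "0"  # deterministic algorithm; honored only with flash-attn < 2.0.0

    import transformer_engine.pytorch as te  # noqa: E402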
Parameters ---------- From 07774089b079016ce79c16935f9e1c04fc3c62e2 Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Wed, 19 Jul 2023 21:42:38 -0400 Subject: [PATCH 40/68] Relax FA 2.0 checks for Ada (#331) Signed-off-by: Kirthi Shankar Sivamani --- transformer_engine/pytorch/attention.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 48600b17df..f1d86e224d 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -885,10 +885,15 @@ def forward( if (query_layer.dtype not in [torch.bfloat16, torch.float16] or key_layer.dtype not in [torch.bfloat16, torch.float16] or value_layer.dtype not in [torch.bfloat16, torch.float16] - or (self.device_compute_capability in (8.6, 8.7, 8.9) and key_layer.shape[-1] > 64) ): use_flash_attention = False + if key_layer.shape[-1] > 64: + if self.device_compute_capability in (8.6, 8.7): + use_flash_attention = False + elif not _flash_attn_2_available and self.device_compute_capability == 8.9: + use_flash_attention = False + if self.attn_mask_type == "padding" and attention_mask is not None: use_flash_attention = False use_fused_attention = False From 3f9db848564ec78d9c7b215a5bd81978b57b0ffe Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Tue, 25 Jul 2023 17:59:32 -0700 Subject: [PATCH 41/68] Make QK layer scaling opt-in (#339) Signed-off-by: Kirthi Shankar Sivamani --- transformer_engine/pytorch/attention.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index f1d86e224d..e75b67784b 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -157,6 +157,10 @@ def __init__( # on average it should not be partition dependent. self.attention_dropout = torch.nn.Dropout(attention_dropout) + # An FP16 training trick required for certain GPT-like models. 
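# (Editorial note, not part of the patch.) Roughly, the trick scales the Q @ K.T
# attention scores down by the layer number before the softmax (compensating inside
# the FP32 softmax) so the FP16 intermediates do not overflow in deep GPT-style
# stacks. It is now opt-in: it only activates when NVTE_APPLY_QK_LAYER_SCALING=1 is
# set in the environment, a layer_number was passed to the module, and the inputs
# are FP16 (see the dtype check in forward below).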
+ self.apply_qk_layer_scaling = ( + bool(int(os.getenv("NVTE_APPLY_QK_LAYER_SCALING", "0"))) and layer_number is not None) + def forward( self, query_layer: torch.Tensor, @@ -166,7 +170,7 @@ def forward( ) -> torch.Tensor: """core attention fprop""" batch_size, seqlen = query_layer.shape[1], query_layer.shape[0] - apply_qk_layer_scaling = self.layer_number is not None and key_layer.dtype == torch.float16 + apply_qk_layer_scaling = self.apply_qk_layer_scaling and key_layer.dtype == torch.float16 # [b, np, sq, sk] output_size = ( From 058f9126871477fe7fc5e950964a304f406dde16 Mon Sep 17 00:00:00 2001 From: Przemyslaw Tredak Date: Thu, 27 Jul 2023 23:48:22 +0200 Subject: [PATCH 42/68] Exposing RMSNorm in pyTorch (#306) * Exposing RMSNorm in pyTorch extensions Signed-off-by: Przemek Tredak * First pass at the Python API Signed-off-by: Przemek Tredak * Small fixes Signed-off-by: Przemek Tredak * Added numerics tests and fixed issues Signed-off-by: Przemek Tredak * Lint fixes Signed-off-by: Przemek Tredak * Added RMSNorm to LayerNormMLP Signed-off-by: Przemek Tredak * Added ONNX export and tests for RMSNorm Signed-off-by: Przemek Tredak * Fix python lint Signed-off-by: Przemek Tredak * Fix BERT case Signed-off-by: Przemek Tredak * Added normalization option to the TransformerLayer Added tests Fixed test failures Signed-off-by: Przemek Tredak * Fix documentation Co-authored-by: Przemyslaw Tredak Signed-off-by: Kirthi Shankar Sivamani * Fix kwarg bug Signed-off-by: Kirthi Shankar Sivamani * Fix IMA and invalid type error Signed-off-by: Kirthi Shankar Sivamani * Increase RMSNorm threshold for bf16 case Signed-off-by: Kirthi Shankar Sivamani * Fix ONNX tests Signed-off-by: Kirthi Shankar Sivamani --------- Signed-off-by: Przemek Tredak Signed-off-by: Kirthi Shankar Sivamani Co-authored-by: Kirthi Shankar Sivamani --- docs/api/c/index.rst | 1 + docs/api/c/rmsnorm.rst | 9 + docs/api/pytorch.rst | 2 + setup.py | 8 +- tests/pytorch/test_numerics.py | 108 +- tests/pytorch/test_onnx_export.py | 100 +- tests/pytorch/test_sanity.py | 52 +- transformer_engine/pytorch/__init__.py | 3 + transformer_engine/pytorch/attention.py | 3 + .../pytorch/cpp_extensions/normalization.py | 85 +- transformer_engine/pytorch/csrc/common.cu | 8 + transformer_engine/pytorch/csrc/common.h | 3 + transformer_engine/pytorch/csrc/extensions.cu | 2277 ----------------- transformer_engine/pytorch/csrc/extensions.h | 81 + .../pytorch/csrc/extensions/activation.cu | 267 ++ .../pytorch/csrc/extensions/attention.cu | 876 +++++++ .../pytorch/csrc/extensions/cast.cu | 75 + .../pytorch/csrc/extensions/gemm.cu | 75 + .../pytorch/csrc/extensions/misc.cu | 25 + .../pytorch/csrc/extensions/normalization.cu | 404 +++ .../pytorch/csrc/extensions/pybind.cpp | 158 ++ .../pytorch/csrc/extensions/softmax.cu | 211 ++ .../pytorch/csrc/extensions/transpose.cu | 321 +++ transformer_engine/pytorch/csrc/ts_fp8_op.cpp | 40 + transformer_engine/pytorch/module/__init__.py | 1 + transformer_engine/pytorch/module/_common.py | 95 + .../pytorch/module/layernorm_linear.py | 177 +- .../pytorch/module/layernorm_mlp.py | 122 +- transformer_engine/pytorch/module/rmsnorm.py | 168 ++ .../pytorch/te_onnx_extensions.py | 82 +- transformer_engine/pytorch/transformer.py | 16 +- 31 files changed, 3374 insertions(+), 2479 deletions(-) create mode 100644 docs/api/c/rmsnorm.rst delete mode 100644 transformer_engine/pytorch/csrc/extensions.cu create mode 100644 transformer_engine/pytorch/csrc/extensions/activation.cu create mode 100644 
transformer_engine/pytorch/csrc/extensions/attention.cu create mode 100644 transformer_engine/pytorch/csrc/extensions/cast.cu create mode 100644 transformer_engine/pytorch/csrc/extensions/gemm.cu create mode 100644 transformer_engine/pytorch/csrc/extensions/misc.cu create mode 100644 transformer_engine/pytorch/csrc/extensions/normalization.cu create mode 100644 transformer_engine/pytorch/csrc/extensions/pybind.cpp create mode 100644 transformer_engine/pytorch/csrc/extensions/softmax.cu create mode 100644 transformer_engine/pytorch/csrc/extensions/transpose.cu create mode 100644 transformer_engine/pytorch/module/_common.py create mode 100644 transformer_engine/pytorch/module/rmsnorm.py diff --git a/docs/api/c/index.rst b/docs/api/c/index.rst index f98a419088..faf6cd4575 100644 --- a/docs/api/c/index.rst +++ b/docs/api/c/index.rst @@ -19,6 +19,7 @@ directly from C/C++, without Python. gemm.h fused_attn.h layer_norm.h + rmsnorm.h softmax.h transformer_engine.h transpose.h diff --git a/docs/api/c/rmsnorm.rst b/docs/api/c/rmsnorm.rst new file mode 100644 index 0000000000..9b43f26e91 --- /dev/null +++ b/docs/api/c/rmsnorm.rst @@ -0,0 +1,9 @@ +.. + Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + + See LICENSE for license information. + +rmsnorm.h +============ + +.. doxygenfile:: rmsnorm.h diff --git a/docs/api/pytorch.rst b/docs/api/pytorch.rst index e62984b3c8..22a571279b 100644 --- a/docs/api/pytorch.rst +++ b/docs/api/pytorch.rst @@ -11,6 +11,8 @@ pyTorch .. autoapiclass:: transformer_engine.pytorch.LayerNorm(hidden_size, eps=1e-5, **kwargs) +.. autoapiclass:: transformer_engine.pytorch.RMSNorm(hidden_size, eps=1e-5, **kwargs) + .. autoapiclass:: transformer_engine.pytorch.LayerNormLinear(in_features, out_features, eps=1e-5, bias=True, **kwargs) :members: forward diff --git a/setup.py b/setup.py index 81ba934cbd..ded19044fc 100644 --- a/setup.py +++ b/setup.py @@ -461,16 +461,20 @@ def setup_common_extension() -> CMakeExtension: cmake_flags=cmake_flags, ) +def _all_files_in_dir(path): + return list(path.iterdir()) + def setup_pytorch_extension() -> setuptools.Extension: """Setup CUDA extension for PyTorch support""" # Source files src_dir = root_path / "transformer_engine" / "pytorch" / "csrc" + extensions_dir = src_dir / "extensions" sources = [ - src_dir / "extensions.cu", src_dir / "common.cu", src_dir / "ts_fp8_op.cpp", - ] + ] + \ + _all_files_in_dir(extensions_dir) # Header files include_dirs = [ diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index 15b820893a..2ed901cb20 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -21,7 +21,7 @@ attention_mask_func, ) from transformer_engine.pytorch import ( - DotProductAttention, Linear, LayerNormLinear, LayerNormMLP, TransformerLayer + DotProductAttention, Linear, LayerNormLinear, LayerNormMLP, TransformerLayer, RMSNorm ) from transformer_engine.pytorch.distributed import checkpoint as te_checkpoint @@ -59,6 +59,8 @@ def __init__(self, hidden_size, eps, num_attention_heads, embed, num_layers, seq all_activations = ["gelu", "relu", "reglu", "geglu", "swiglu"] +all_normalizations = ["LayerNorm", "RMSNorm"] + def get_causal_attn_mask(sq: int) -> torch.Tensor: return torch.triu(torch.ones(sq, sq, device="cuda"), diagonal=1).bool() @@ -74,7 +76,16 @@ def assert_allclose(l1: List[torch.Tensor], l2: List[torch.Tensor], atol: float) """Ensures two lists are equal.""" assert len(l1) == len(l2), "Unequal number of outputs." 
for t1, t2 in zip(l1, l2): - assert torch.allclose(t1, t2, atol=atol), "Outputs not close enough." + result = torch.allclose(t1, t2, atol=atol) + if not result: + diff = torch.abs(t1 - t2).flatten() + m = torch.argmax(diff) + msg = (f"Outputs not close enough." + f"Location of the maximum difference: {m.item()} " + f"with {t1.flatten()[m].item()} vs {t2.flatten()[m].item()} " + f"(diff {diff[m].item()})." + ) + raise AssertionError(msg) def _set_cuda_rng_state(new_state, device=-1): @@ -310,11 +321,38 @@ def forward( return context_layer +# Adapted from https://github.com/bzhangGo/rmsnorm/blob/c6691f20ec0af4128c8159c903071f7575404295/rmsnorm_torch.py +class TorchRMSNorm(nn.Module): + def __init__(self, in_features, eps=1e-5): + super().__init__() + + self.eps = eps + self.in_features = in_features + + self.weight = nn.Parameter(torch.ones(in_features)) + self.register_parameter("weight", self.weight) + + def forward(self, x): + norm_x = x.norm(2, dim=-1, keepdim=True) + d_x = self.in_features + + rms_x = norm_x * d_x ** (-1. / 2) + x_normed = x / (rms_x + self.eps) + + return self.weight * x_normed class TorchLayerNormLinear(nn.Module): - def __init__(self, in_features: int, out_features: int, eps: float, bias: bool = True): + def __init__(self, in_features: int, out_features: int, + eps: float, bias: bool = True, + normalization: str = "LayerNorm"): super().__init__() - self.layernorm = nn.LayerNorm(in_features, eps=eps) + if normalization == "LayerNorm": + self.layernorm = nn.LayerNorm(in_features, eps=eps) + elif normalization == "RMSNorm": + self.layernorm = TorchRMSNorm(in_features, eps=eps) + else: + raise RuntimeError("Unsupported normalization") + self.linear = nn.Linear(in_features, out_features) def forward(self, x: torch.Tensor) -> torch.Tensor: @@ -355,9 +393,15 @@ def forward(self, x): class TorchLayerNormMLP(nn.Module): def __init__(self, hidden_size: int, ffn_hidden_size: int, - eps: float = 1e-5, activation = 'gelu'): + eps: float = 1e-5, activation = 'gelu', + normalization: str = "LayerNorm"): super().__init__() - self.ln = nn.LayerNorm(hidden_size, eps=eps) + if normalization == "LayerNorm": + self.ln = nn.LayerNorm(hidden_size, eps=eps) + elif normalization == "RMSNorm": + self.ln = TorchRMSNorm(hidden_size, eps=eps) + else: + raise RuntimeError("Unsupported normalization") if 'glu' in activation: fc1_output_features = 2 * ffn_hidden_size self.gelu = TorchGLU(activation) @@ -830,11 +874,48 @@ def test_linear_accuracy(dtype, bs, model): else: assert_allclose(te_outputs[0], torch_outputs[0], 5e-2) +@pytest.mark.parametrize("dtype", param_types) +@pytest.mark.parametrize("bs", batch_sizes) +@pytest.mark.parametrize("model", model_configs.keys()) +def test_rmsnorm_accuracy(dtype, bs, model): + config = model_configs[model] + + te_rmsnorm = ( + RMSNorm( + config.hidden_size, + ) + .to(dtype=dtype) + .cuda() + .eval() + ) + + torch_rmsnorm = ( + TorchRMSNorm( + config.hidden_size, + ) + .to(dtype=dtype) + .cuda() + .eval() + ) + + # Share params + with torch.no_grad(): + torch_rmsnorm.weight = Parameter(te_rmsnorm.weight.clone()) + + te_outputs = _test_granular_accuracy(te_rmsnorm, bs, dtype, config) + torch_outputs = _test_granular_accuracy(torch_rmsnorm, bs, dtype, config) + + # Check output. 
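# (Editorial note, not part of the patch.) The TorchRMSNorm reference above computes
#   y = weight * x / (sqrt(mean(x ** 2, dim=-1)) + eps)
# i.e. RMS normalization with no mean subtraction and no bias term; that is the
# behaviour te.RMSNorm is checked against in this test.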
+ if dtype == torch.float32: + assert_allclose(te_outputs[0], torch_outputs[0], 1e-7) + else: + assert_allclose(te_outputs[0], torch_outputs[0], 2e-2) @pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("bs", batch_sizes) @pytest.mark.parametrize("model", model_configs.keys()) -def test_layernorm_linear_accuracy(dtype, bs, model): +@pytest.mark.parametrize("normalization", all_normalizations) +def test_layernorm_linear_accuracy(dtype, bs, model, normalization): config = model_configs[model] te_ln_linear = ( @@ -843,6 +924,7 @@ def test_layernorm_linear_accuracy(dtype, bs, model): 4 * config.hidden_size, config.eps, bias=True, + normalization=normalization, ) .to(dtype=dtype) .cuda() @@ -855,6 +937,7 @@ def test_layernorm_linear_accuracy(dtype, bs, model): 4 * config.hidden_size, config.eps, bias=True, + normalization=normalization, ) .to(dtype=dtype) .cuda() @@ -864,7 +947,8 @@ def test_layernorm_linear_accuracy(dtype, bs, model): # Share params with torch.no_grad(): torch_ln_linear.layernorm.weight = Parameter(te_ln_linear.layer_norm_weight.clone()) - torch_ln_linear.layernorm.bias = Parameter(te_ln_linear.layer_norm_bias.clone()) + if normalization != "RMSNorm": + torch_ln_linear.layernorm.bias = Parameter(te_ln_linear.layer_norm_bias.clone()) torch_ln_linear.linear.weight = Parameter(te_ln_linear.weight.clone()) torch_ln_linear.linear.bias = Parameter(te_ln_linear.bias.clone()) @@ -882,7 +966,8 @@ def test_layernorm_linear_accuracy(dtype, bs, model): @pytest.mark.parametrize("bs", batch_sizes) @pytest.mark.parametrize("model", model_configs.keys()) @pytest.mark.parametrize("activation", all_activations) -def test_layernorm_mlp_accuracy(dtype, bs, model, activation): +@pytest.mark.parametrize("normalization", all_normalizations) +def test_layernorm_mlp_accuracy(dtype, bs, model, activation, normalization): config = model_configs[model] te_ln_mlp = ( @@ -890,6 +975,7 @@ def test_layernorm_mlp_accuracy(dtype, bs, model, activation): config.hidden_size, 4 * config.hidden_size, activation=activation, + normalization=normalization, ) .to(dtype=dtype) .cuda() @@ -901,6 +987,7 @@ def test_layernorm_mlp_accuracy(dtype, bs, model, activation): config.hidden_size, 4 * config.hidden_size, activation=activation, + normalization=normalization, ) .to(dtype=dtype) .cuda() @@ -910,7 +997,8 @@ def test_layernorm_mlp_accuracy(dtype, bs, model, activation): # Share params with torch.no_grad(): torch_ln_mlp.ln.weight = Parameter(te_ln_mlp.layer_norm_weight.clone()) - torch_ln_mlp.ln.bias = Parameter(te_ln_mlp.layer_norm_bias.clone()) + if normalization != "RMSNorm": + torch_ln_mlp.ln.bias = Parameter(te_ln_mlp.layer_norm_bias.clone()) torch_ln_mlp.fc1.weight = Parameter(te_ln_mlp.fc1_weight.clone()) torch_ln_mlp.fc1.bias = Parameter(te_ln_mlp.fc1_bias.clone()) torch_ln_mlp.fc2.weight = Parameter(te_ln_mlp.fc2_weight.clone()) diff --git a/tests/pytorch/test_onnx_export.py b/tests/pytorch/test_onnx_export.py index cf158e9082..d4e834bdf2 100644 --- a/tests/pytorch/test_onnx_export.py +++ b/tests/pytorch/test_onnx_export.py @@ -71,6 +71,8 @@ supported_activations = ["gelu", "relu", "reglu", "geglu", "swiglu"] +all_normalizations = ["LayerNorm", "RMSNorm"] + @pytest.fixture() def seed_default_rng(): @@ -676,6 +678,90 @@ def forward(self, inp): validate_result( fname, inp, model, atol=atol, is_fp8=use_fp8, allow_cnt_errors=3, te_outputs=te_outputs) +@pytest.mark.parametrize("scale_factor", [448, 112]) +@pytest.mark.parametrize( + "use_fp8, precision, atol", [ + [False, torch.float32, 1e-7], + 
[False, torch.float16, 1e-7], + [False, torch.bfloat16, 1e-7], + [False, "fake-torch.bfloat16", 1e-7], + [True, torch.float32, 1e-7], + [True, torch.float16, 1e-7], + [True, torch.bfloat16, 1e-2], + [True, "fake-torch.bfloat16", 1e-2] +]) +def test_export_rmsnorm( + seed_default_rng, + use_fp8: bool, + scale_factor: float, + precision: torch.dtype, + atol: float +): + fake_bf16_io = precision == "fake-torch.bfloat16" + # reset precision to torch.bfloat16 after capturing fake BF16 mode + precision = torch.bfloat16 if precision == "fake-torch.bfloat16" else precision + + # Skip FP8 tests on non-hopper devices + if use_fp8 and not fp8_available: + pytest.skip(reason_for_no_fp8) + + # Set dimensions (these are arbitrary). + inp_shape = [64, 32] + + class Test_RMSnorm(nn.Module): + def __init__(self) -> None: + super().__init__() + eps = 1e-6 # An arbitrary small value + dtype = torch.float if fake_bf16_io else precision + self.ln = te.RMSNorm(inp_shape[1], eps, params_dtype=dtype).eval().cuda() + + def forward(self, inp): + ret = self.ln(inp) + return ret + + class TestFP8_RMSnorm(nn.Module): + def __init__(self) -> None: + super().__init__() + normalized_shape = torch.Size(inp.shape[1:]) + self.weight = torch.randn(*normalized_shape, device="cuda", + dtype=torch.float32 if fake_bf16_io else precision) + self.eps = 1e-6 # An arbitrary small value + + self.fp8_tensor = tex.FP8FwdTensors.GEMM1_INPUT + self.meta = create_meta(scale_factor) + self.fp8_type = tex.DType.kFloat8E4M3 + + def forward(self, inp): + ret = texcpp.rmsnorm_fwd_fp8_inf( + inp, + self.weight, + self.eps, + self.meta, + self.fp8_tensor, + self.fp8_type, + False) + + ret = cast_from_fp8( + ret, + self.meta, + self.fp8_tensor, + self.fp8_type, + as_te_type(precision)) + if fake_bf16_io: + ret = ret.type(torch.float32) + return ret + + inp = torch.randn(*inp_shape, device="cuda", dtype=torch.float32 if fake_bf16_io else precision) + model = TestFP8_RMSnorm() if use_fp8 else Test_RMSnorm() + high_prec_str = dtype2str(precision, fake_bf16_io=fake_bf16_io) + fp8_str = f"_fp8-{scale_factor}" if use_fp8 else "" + fname = f"te.layernorm{fp8_str}{high_prec_str}.onnx" + do_export(model, inp, fname, use_fp8=use_fp8) + te_outputs = te_infer(model, inp, is_fp8=use_fp8) + serialize_inputs_outputs(fname, inp, te_outputs) + if fake_bf16_io or precision != torch.bfloat16: + validate_result( + fname, inp, model, atol=atol, is_fp8=use_fp8, allow_cnt_errors=3, te_outputs=te_outputs) @skip_FP8 @pytest.mark.parametrize("softmax_fn", [ @@ -916,6 +1002,7 @@ def forward(self, inp): (torch.bfloat16, False), ]) @pytest.mark.parametrize("zero_centered_gamma", [False, True]) +@pytest.mark.parametrize("normalization", all_normalizations) def test_export_layernorm_linear( seed_default_rng, scale_factor: float, @@ -924,12 +1011,16 @@ def test_export_layernorm_linear( return_bias: bool, return_layernorm_output: bool, precision: torch.dtype, - zero_centered_gamma: bool + zero_centered_gamma: bool, + normalization: str, ): # Skip FP8 tests on non-hopper devices if use_fp8 and not fp8_available: pytest.skip(reason_for_no_fp8) + if normalization == "RMSNorm" and zero_centered_gamma: + pytest.skip("RMSNorm does not support zero_centered_gamma yet!") + # Set dimensions (these are arbitrary). 
in_features = 64 out_features = 256 @@ -950,6 +1041,7 @@ def test_export_layernorm_linear( return_layernorm_output=return_layernorm_output, params_dtype=precision, zero_centered_gamma=zero_centered_gamma, + normalization=normalization, ).to(device='cuda') if use_fp8: set_layer_scale(model, scale_factor, num_gemms=1) @@ -980,6 +1072,7 @@ def test_export_layernorm_linear( ]) @pytest.mark.parametrize("zero_centered_gamma", [False, True]) @pytest.mark.parametrize("activation", supported_activations) +@pytest.mark.parametrize("normalization", all_normalizations) def test_export_layernorm_mlp( seed_default_rng, scale_factor: float, @@ -990,11 +1083,15 @@ def test_export_layernorm_mlp( precision: torch.dtype, zero_centered_gamma: bool, activation: str, + normalization: str, ): # Skip FP8 tests on non-hopper devices if use_fp8 and not fp8_available: pytest.skip(reason_for_no_fp8) + if normalization == "RMSNorm" and zero_centered_gamma: + pytest.skip("RMSNorm does not support zero_centered_gamma yet!") + # Set dimensions (these are arbitrary). in_features = 64 out_features = 256 @@ -1016,6 +1113,7 @@ def test_export_layernorm_mlp( params_dtype=precision, zero_centered_gamma=zero_centered_gamma, activation=activation, + normalization=normalization, ).to(device='cuda') if use_fp8: set_layer_scale(model, scale_factor, num_gemms=2) diff --git a/tests/pytorch/test_sanity.py b/tests/pytorch/test_sanity.py index 101734b570..1643172c54 100644 --- a/tests/pytorch/test_sanity.py +++ b/tests/pytorch/test_sanity.py @@ -95,6 +95,7 @@ def __init__( all_boolean = [True, False] all_activations = ["gelu", "relu", "reglu", "geglu", "swiglu"] +all_normalizations = ["LayerNorm", "RMSNorm"] def _disable_wgrads(block): for p in block.parameters(): @@ -314,10 +315,16 @@ def _test_sanity_common(block, bs, dtype, config, fp8_recipe, skip_wgrad, skip_d @pytest.mark.parametrize("skip_wgrad", all_boolean) @pytest.mark.parametrize("zero_centered_gamma", all_boolean) @pytest.mark.parametrize("skip_dgrad", all_boolean) -def test_sanity_layernorm_linear(dtype, bs, fp8_recipe, model, skip_wgrad, zero_centered_gamma, skip_dgrad): +@pytest.mark.parametrize("normalization", all_normalizations) +def test_sanity_layernorm_linear(dtype, bs, fp8_recipe, model, skip_wgrad, + zero_centered_gamma, skip_dgrad, + normalization): if fp8_recipe is not None and not fp8_available: pytest.skip(reason_for_no_fp8) + if normalization == "RMSNorm" and zero_centered_gamma: + pytest.skip("RMSNorm does not support zero_centered_gamma yet!") + config = model_configs[model] sigma = 0.023 @@ -330,6 +337,7 @@ def test_sanity_layernorm_linear(dtype, bs, fp8_recipe, model, skip_wgrad, zero_ eps=config.eps, init_method=init_method, zero_centered_gamma=zero_centered_gamma, + normalization=normalization, ) .to(dtype=dtype) .cuda() @@ -370,10 +378,16 @@ def test_sanity_linear(dtype, bs, fp8_recipe, model, skip_wgrad, skip_dgrad): @pytest.mark.parametrize("zero_centered_gamma", all_boolean) @pytest.mark.parametrize("skip_dgrad", all_boolean) @pytest.mark.parametrize("activation", all_activations) -def test_sanity_layernorm_mlp(dtype, bs, fp8_recipe, model, skip_wgrad, zero_centered_gamma, skip_dgrad, activation): +@pytest.mark.parametrize("normalization", all_normalizations) +def test_sanity_layernorm_mlp(dtype, bs, fp8_recipe, model, skip_wgrad, + zero_centered_gamma, skip_dgrad, activation, + normalization): if fp8_recipe is not None and not fp8_available: pytest.skip(reason_for_no_fp8) + if normalization == "RMSNorm" and zero_centered_gamma: + 
pytest.skip("RMSNorm does not support zero_centered_gamma yet!") + config = model_configs[model] sigma = 0.023 @@ -389,6 +403,7 @@ def test_sanity_layernorm_mlp(dtype, bs, fp8_recipe, model, skip_wgrad, zero_cen output_layer_init_method=output_layer_init_method, zero_centered_gamma=zero_centered_gamma, activation=activation, + normalization=normalization, ) .to(dtype=dtype) .cuda() @@ -404,10 +419,16 @@ def test_sanity_layernorm_mlp(dtype, bs, fp8_recipe, model, skip_wgrad, zero_cen @pytest.mark.parametrize("zero_centered_gamma", all_boolean) @pytest.mark.parametrize("bias", all_boolean) @pytest.mark.parametrize("activation", all_activations) -def test_sanity_gpt(dtype, bs, fp8_recipe, model, skip_wgrad, zero_centered_gamma, bias, activation): +@pytest.mark.parametrize("normalization", all_normalizations) +def test_sanity_gpt(dtype, bs, fp8_recipe, model, skip_wgrad, + zero_centered_gamma, bias, activation, + normalization): if fp8_recipe is not None and not fp8_available: pytest.skip(reason_for_no_fp8) + if normalization == "RMSNorm" and zero_centered_gamma: + pytest.skip("RMSNorm does not support zero_centered_gamma yet!") + config = model_configs[model] sigma = 0.023 @@ -430,6 +451,7 @@ def test_sanity_gpt(dtype, bs, fp8_recipe, model, skip_wgrad, zero_centered_gamm zero_centered_gamma=zero_centered_gamma, bias=bias, activation=activation, + normalization=normalization, ) .to(dtype=dtype) .cuda() @@ -444,10 +466,15 @@ def test_sanity_gpt(dtype, bs, fp8_recipe, model, skip_wgrad, zero_centered_gamm @pytest.mark.parametrize("model", model_configs.keys()) @pytest.mark.parametrize("skip_wgrad", all_boolean) @pytest.mark.parametrize("zero_centered_gamma", all_boolean) -def test_sanity_bert(dtype, bs, fp8_recipe, model, skip_wgrad, zero_centered_gamma): +@pytest.mark.parametrize("normalization", all_normalizations) +def test_sanity_bert(dtype, bs, fp8_recipe, model, skip_wgrad, zero_centered_gamma, + normalization): if fp8_recipe is not None and not fp8_available: pytest.skip(reason_for_no_fp8) + if normalization == "RMSNorm" and zero_centered_gamma: + pytest.skip("RMSNorm does not support zero_centered_gamma yet!") + config = model_configs[model] sigma = 0.023 @@ -468,6 +495,7 @@ def test_sanity_bert(dtype, bs, fp8_recipe, model, skip_wgrad, zero_centered_gam apply_residual_connection_post_layernorm=True, output_layernorm=True, zero_centered_gamma=zero_centered_gamma, + normalization=normalization, ) .to(dtype=dtype) .cuda() @@ -482,10 +510,15 @@ def test_sanity_bert(dtype, bs, fp8_recipe, model, skip_wgrad, zero_centered_gam @pytest.mark.parametrize("model", model_configs.keys()) @pytest.mark.parametrize("skip_wgrad", all_boolean) @pytest.mark.parametrize("zero_centered_gamma", all_boolean) -def test_sanity_T5(dtype, bs, fp8_recipe, model, skip_wgrad, zero_centered_gamma): +@pytest.mark.parametrize("normalization", all_normalizations) +def test_sanity_T5(dtype, bs, fp8_recipe, model, skip_wgrad, zero_centered_gamma, + normalization): if fp8_recipe is not None and not fp8_available: pytest.skip(reason_for_no_fp8) + if normalization == "RMSNorm" and zero_centered_gamma: + pytest.skip("RMSNorm does not support zero_centered_gamma yet!") + config = model_configs[model] sigma = 0.023 @@ -507,6 +540,7 @@ def test_sanity_T5(dtype, bs, fp8_recipe, model, skip_wgrad, zero_centered_gamma output_layernorm=False, layer_type="decoder", zero_centered_gamma=zero_centered_gamma, + normalization=normalization, ) .to(dtype=dtype) .cuda() @@ -669,10 +703,15 @@ def 
test_sanity_gradient_accumulation_fusion(dtype, bs, fp8_recipe, model, skip_ @pytest.mark.parametrize("model", model_configs.keys()) @pytest.mark.parametrize("skip_wgrad", all_boolean) @pytest.mark.parametrize("zero_centered_gamma", all_boolean) -def test_gpt_cuda_graph(dtype, bs, fp8_recipe, model, skip_wgrad, zero_centered_gamma): +@pytest.mark.parametrize("normalization", all_normalizations) +def test_gpt_cuda_graph(dtype, bs, fp8_recipe, model, skip_wgrad, zero_centered_gamma, + normalization): if fp8_recipe is not None and not fp8_available: pytest.skip(reason_for_no_fp8) + if normalization == "RMSNorm" and zero_centered_gamma: + pytest.skip("RMSNorm does not support zero_centered_gamma yet!") + config = model_configs[model] sigma = 0.023 @@ -694,6 +733,7 @@ def test_gpt_cuda_graph(dtype, bs, fp8_recipe, model, skip_wgrad, zero_centered_ output_layernorm=False, zero_centered_gamma=zero_centered_gamma, fuse_qkv_params=True, + normalization=normalization, ) .to(dtype=dtype) .cuda() diff --git a/transformer_engine/pytorch/__init__.py b/transformer_engine/pytorch/__init__.py index e7654b895f..b67ecd05b9 100644 --- a/transformer_engine/pytorch/__init__.py +++ b/transformer_engine/pytorch/__init__.py @@ -7,6 +7,7 @@ from .module import Linear from .module import LayerNormMLP from .module import LayerNorm +from .module import RMSNorm from .attention import DotProductAttention from .transformer import TransformerLayer from .fp8 import fp8_autocast @@ -21,4 +22,6 @@ onnx_te_gemm, onnx_layernorm_fwd_fp8, onnx_layernorm_fwd, + onnx_rmsnorm_fwd, + onnx_rmsnorm_fwd_fp8 ) diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index e75b67784b..dd3f561c95 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -990,6 +990,7 @@ def __init__( ub_split_rs: bool = False, ub_split_ag: bool = False, bias: bool = True, + normalization: str = "LayerNorm", ) -> None: super().__init__() self.layer_number = (layer_number,) @@ -1044,6 +1045,7 @@ def __init__( ub_bulk_wgrad=ub_bulk_wgrad, ub_bulk_dgrad=ub_bulk_dgrad, ub_split_ag=ub_split_ag, + normalization=normalization, **common_gemm_kwargs, ) else: @@ -1072,6 +1074,7 @@ def __init__( ub_bulk_wgrad=ub_bulk_wgrad, ub_bulk_dgrad=ub_bulk_dgrad, ub_split_ag=ub_split_ag, + normalization=normalization, **common_gemm_kwargs, ) else: diff --git a/transformer_engine/pytorch/cpp_extensions/normalization.py b/transformer_engine/pytorch/cpp_extensions/normalization.py index ddee0152dc..54c7a0789f 100644 --- a/transformer_engine/pytorch/cpp_extensions/normalization.py +++ b/transformer_engine/pytorch/cpp_extensions/normalization.py @@ -10,7 +10,10 @@ __all__ = ['layernorm_fwd_fp8', 'layernorm_fwd_fp8_inf', - 'layernorm_fwd_inf'] + 'layernorm_fwd_inf', + 'rmsnorm_fwd_fp8', + 'rmsnorm_fwd_fp8_inf', + 'rmsnorm_fwd_inf'] def layernorm_fwd_fp8( @@ -99,3 +102,83 @@ def layernorm_fwd_inf( eps, zero_centered_gamma, ) + +def rmsnorm_fwd_fp8( + inp: torch.Tensor, + weight: torch.Tensor, + eps: float, + fp8_meta_tensor: tex.FP8TensorMeta, + fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors], + otype: tex.DType, + sm_margin: int, + zero_centered_gamma: bool, + rmsnorm_out: Optional[torch.Tensor] = None, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """RMSNorm with FP8 output""" + if rmsnorm_out is not None: + return tex.rmsnorm_fwd_fp8_noalloc( + inp, + weight, + eps, + fp8_meta_tensor.scale[fp8_tensor], + rmsnorm_out, + fp8_meta_tensor.amax_history[0][fp8_tensor], + 
fp8_meta_tensor.scale_inv[fp8_tensor], + otype, + sm_margin, + zero_centered_gamma + ) + + return tex.rmsnorm_fwd_fp8( + inp, + weight, + eps, + fp8_meta_tensor.scale[fp8_tensor], + fp8_meta_tensor.amax_history[0][fp8_tensor], + fp8_meta_tensor.scale_inv[fp8_tensor], + otype, + sm_margin, + zero_centered_gamma + ) + + +def rmsnorm_fwd_fp8_inf( + inp: torch.Tensor, + weight: torch.Tensor, + eps: float, + fp8_meta_tensor: tex.FP8TensorMeta, + fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors], + otype: tex.DType, + zero_centered_gamma, +) -> torch.Tensor: + """RMSNorm with FP8 output. + + This version of rmsnorm_fwd_fp8 is specialized for inference, and returns + only the normalized output. + """ + ret = torch.ops.tex_ts.rmsnorm_fwd_fp8_inf_ts( + inp, + weight, + eps, + fp8_meta_tensor.scale, + fp8_meta_tensor.amax_history, + fp8_meta_tensor.scale_inv, + fp8_tensor, + otype, + zero_centered_gamma) + return ret + + +def rmsnorm_fwd_inf( + inp: torch.Tensor, + weight: torch.Tensor, + eps: float, + zero_centered_gamma: bool, +) -> torch.Tensor: + """RMSNorm with FP8 output""" + return torch.ops.tex_ts.rmsnorm_fwd_inf_ts( + inp, + weight, + eps, + zero_centered_gamma, + ) diff --git a/transformer_engine/pytorch/csrc/common.cu b/transformer_engine/pytorch/csrc/common.cu index 1d20607940..3209dda004 100644 --- a/transformer_engine/pytorch/csrc/common.cu +++ b/transformer_engine/pytorch/csrc/common.cu @@ -137,3 +137,11 @@ at::Tensor allocateTorchTensor(int M, return at::empty({static_cast(M)}, at::CUDA(GetATenDType(dtype))); } + +void *getDataPtr(at::Tensor t) { + if (t.numel() > 0) { + return t.data_ptr(); + } else { + return nullptr; + } +} diff --git a/transformer_engine/pytorch/csrc/common.h b/transformer_engine/pytorch/csrc/common.h index 17d36b9911..7c17f1f34c 100644 --- a/transformer_engine/pytorch/csrc/common.h +++ b/transformer_engine/pytorch/csrc/common.h @@ -9,6 +9,7 @@ #include #include +#include #include #include #include @@ -180,4 +181,6 @@ at::Tensor allocateTorchTensor(int M, transformer_engine::DType dtype ); +void *getDataPtr(at::Tensor t); + #endif // TRANSFORMER_ENGINE_PYTORCH_CSRC_COMMON_H_ diff --git a/transformer_engine/pytorch/csrc/extensions.cu b/transformer_engine/pytorch/csrc/extensions.cu deleted file mode 100644 index 69248d4aa9..0000000000 --- a/transformer_engine/pytorch/csrc/extensions.cu +++ /dev/null @@ -1,2277 +0,0 @@ -/************************************************************************* - * Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * - * See LICENSE for license information. 
- ************************************************************************/ - -#include "extensions.h" -#ifdef NVTE_WITH_USERBUFFERS -#include "comm_gemm_overlap.h" -#endif // NVTE_WITH_USERBUFFERS - -constexpr int block_size = 512; -constexpr int ctas_per_sm = 4; - -// get the fused attention backend -NVTE_Fused_Attn_Backend get_fused_attn_backend( - const transformer_engine::DType q_dtype, - const transformer_engine::DType kv_dtype, - NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, - NVTE_Mask_Type attn_mask_type, - float p_dropout, size_t max_seqlen_q, - size_t max_seqlen_kv, size_t head_dim) { - NVTE_Fused_Attn_Backend fused_attention_backend = - nvte_get_fused_attn_backend( - static_cast(q_dtype), static_cast(kv_dtype), - qkv_layout, bias_type, attn_mask_type, - p_dropout, max_seqlen_q, max_seqlen_kv, head_dim); - return fused_attention_backend; -} - -// fast zero-fills of tensors -template -__global__ void __launch_bounds__(block_size) mha_fill_kernel(scalar_t* out_tensor, - const int32_t* const start_row, - const size_t num_rows) { - size_t row_stride = gridDim.y * blockDim.x; - size_t row_index = blockIdx.x + static_cast(start_row[0]); - size_t col_index = blockIdx.y * blockDim.x + threadIdx.x; - while (row_index < num_rows) { - out_tensor[row_index*row_stride + col_index] = 0; - row_index += gridDim.x; - } -} - -// fast zero-fills of tensors -void mha_fill(const at::Tensor &self, const at::Tensor &start_index) { - auto max_tokens = self.size(0); - auto self_2d = self.view({max_tokens, -1}); - auto fcd_size = self_2d.size(1); - TORCH_CHECK(self.is_contiguous(), "input not contiguous"); - TORCH_CHECK(fcd_size % block_size == 0, "input size not aligned to block size"); - const int num_mp = at::cuda::getCurrentDeviceProperties()->multiProcessorCount; - uint64_t num_blk_y = (uint64_t)(fcd_size / block_size); - uint64_t num_blk_x = (uint64_t)((num_mp * ctas_per_sm + num_blk_y - 1) / num_blk_y); - dim3 dim_grid(num_blk_x, num_blk_y); - dim3 dim_block(block_size); - AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2( - at::ScalarType::Half, at::ScalarType::BFloat16, - self_2d.scalar_type(), "mha_fill", [&]() { - mha_fill_kernel<<>>( - self_2d.data_ptr(), - static_cast(start_index.data_ptr()), - max_tokens); - C10_CUDA_KERNEL_LAUNCH_CHECK(); - }); -} - -// extract seed and offset from PhiloxCudaState -__global__ void unpack(at::PhiloxCudaState arg, int64_t* rng_state_ptr) { - if (arg.captured_) { - rng_state_ptr[0] = static_cast(*arg.seed_.ptr); - rng_state_ptr[1] = static_cast( - *(arg.offset_.ptr) + static_cast(arg.offset_intragraph_)); - } else { - rng_state_ptr[0] = static_cast(arg.seed_.val); - rng_state_ptr[1] = static_cast(arg.offset_.val); - } -} - -// extract PhiloxCudaState from CUDA random number generator -at::PhiloxCudaState init_philox_state( - at::CUDAGeneratorImpl* gen, - size_t elts_per_thread) { - at::PhiloxCudaState philox_args; - std::lock_guard lock(gen->mutex_); - philox_args = gen->philox_cuda_state(elts_per_thread); - return philox_args; -} - -// fused attention FWD with packed QKV -std::vector fused_attn_fwd_qkvpacked( - size_t b, size_t max_seqlen, size_t total_seqs, - size_t h, size_t d, - bool is_training, float attn_scale, float p_dropout, bool set_zero, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, - const at::Tensor cu_seqlens, - const at::Tensor QKV, - const transformer_engine::DType qkv_type, - const c10::optional descale_QKV, - const c10::optional scale_S, - const c10::optional scale_O, - c10::optional amax_S, - 
c10::optional amax_O, - const c10::optional Bias, - const c10::optional rng_gen, - size_t rng_elts_per_thread) { - using namespace transformer_engine; - - // create output tensor O - auto options = torch::TensorOptions().dtype(GetATenDType(qkv_type)).device(torch::kCUDA); - auto O = torch::empty({static_cast(total_seqs), - static_cast(h), static_cast(d)}, options); - if (set_zero && (h * d % block_size == 0)) { - mha_fill(O, cu_seqlens.index({torch::indexing::Slice(-1, torch::indexing::None)})); - } else { - O.fill_(0); - } - - // construct NVTE tensors - TensorWrapper te_QKV, te_S, te_O, te_Bias, te_cu_seqlens; - if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) { - // FP8 - if ((!descale_QKV.has_value()) || (!scale_S.has_value()) || (!scale_O.has_value()) - || (!amax_S.has_value()) || (!amax_O.has_value())) { - std::string err_tensors = "descale_QKV, scale_S, scale_O, amax_S and amax_O"; - NVTE_ERROR(err_tensors + std::string("are required for FP8 operation. \n")); - } - te_QKV = makeTransformerEngineTensor(QKV.data_ptr(), {total_seqs, 3, h, d}, - qkv_type, nullptr, nullptr, descale_QKV.value().data_ptr()); - at::Tensor descale_S = torch::empty_like(scale_S.value()); - te_S = makeTransformerEngineTensor(nullptr, {0}, - DType::kFloat32, amax_S.value().data_ptr(), - scale_S.value().data_ptr(), descale_S.data_ptr()); - te_O = makeTransformerEngineTensor(O.data_ptr(), {total_seqs, h, d}, - qkv_type, amax_O.value().data_ptr(), scale_O.value().data_ptr(), nullptr); - } else if (qkv_type == DType::kBFloat16 || qkv_type == DType::kFloat16) { - // BF16 or FP16 - te_QKV = makeTransformerEngineTensor(QKV.data_ptr(), {total_seqs, 3, h, d}, - qkv_type, nullptr, nullptr, nullptr); - te_S = makeTransformerEngineTensor(nullptr, {0}, - DType::kFloat32, nullptr, nullptr, nullptr); - te_O = makeTransformerEngineTensor(O.data_ptr(), {total_seqs, h, d}, - qkv_type, nullptr, nullptr, nullptr); - } else { - NVTE_ERROR("Fused attention only supports FP8 and BF16/FP16 data types. 
\n"); - } - if ((bias_type != NVTE_NO_BIAS) && (Bias.has_value())) { - auto bias_shape = Bias.value().sizes().vec(); - std::vector shape{bias_shape.begin(), bias_shape.end()}; - te_Bias = makeTransformerEngineTensor(Bias.value().data_ptr(), shape, - DType::kFloat32, nullptr, nullptr, nullptr); - } - te_cu_seqlens = makeTransformerEngineTensor(cu_seqlens.data_ptr(), {b+1}, - DType::kInt32, nullptr, nullptr, nullptr); - - // extract random number generator seed and offset - auto gen = at::get_generator_or_default( - rng_gen, at::cuda::detail::getDefaultCUDAGenerator()); - at::PhiloxCudaState philox_args = init_philox_state(gen, rng_elts_per_thread); - auto rng_state = torch::empty({2}, options.dtype(torch::kInt64)); - unpack<<<1, 1, 0, at::cuda::getCurrentCUDAStream()>>>( - philox_args, static_cast(rng_state.data_ptr())); - auto te_rng_state = makeTransformerEngineTensor(rng_state); - - // create auxiliary output tensors - NVTETensorPack nvte_aux_tensor_pack; - nvte_tensor_pack_create(&nvte_aux_tensor_pack); - - // create workspace - TensorWrapper workspace; - - // populate tensors with appropriate shapes and dtypes - nvte_fused_attn_fwd_qkvpacked( - te_QKV.data(), - te_Bias.data(), - te_S.data(), - te_O.data(), - &nvte_aux_tensor_pack, - te_cu_seqlens.data(), - te_rng_state.data(), - max_seqlen, - is_training, attn_scale, p_dropout, - qkv_layout, bias_type, attn_mask_type, - workspace.data(), - at::cuda::getCurrentCUDAStream()); - - // allocate memory for workspace and auxiliary output tensors - auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype()); - workspace = makeTransformerEngineTensor( - workspace_data.data_ptr(), - workspace.shape(), workspace.dtype()); - - // output_tensors = [O, nvte_aux_tensor_pack.tensors] - std::vector output_tensors; - output_tensors.push_back(O); - for (size_t i = 0; i < nvte_aux_tensor_pack.size; ++i) { - auto tensor = reinterpret_cast(nvte_aux_tensor_pack.tensors[i]); - // allocate memory for nvte_aux_tensor_pack.tensors - auto output_tensor = allocateSpace(tensor->data.shape, tensor->data.dtype, false); - output_tensors.push_back(output_tensor); - tensor->data.dptr = output_tensor.data_ptr(); - } - - // execute the kernel - nvte_fused_attn_fwd_qkvpacked( - te_QKV.data(), - te_Bias.data(), - te_S.data(), - te_O.data(), - &nvte_aux_tensor_pack, - te_cu_seqlens.data(), - te_rng_state.data(), - max_seqlen, - is_training, attn_scale, p_dropout, - qkv_layout, bias_type, attn_mask_type, - workspace.data(), - at::cuda::getCurrentCUDAStream()); - - // destroy tensor wrappers, but not allocated memory - nvte_tensor_pack_destroy(&nvte_aux_tensor_pack); - - // if training, [O, softmax-related tensors, rng_state]; if inference, [O] - return output_tensors; -} - -// fused attention BWD with packed QKV -std::vector fused_attn_bwd_qkvpacked( - size_t b, size_t max_seqlen, size_t total_seqs, - size_t h, size_t d, - float attn_scale, float p_dropout, bool set_zero, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, - const at::Tensor cu_seqlens, - const at::Tensor QKV, - const at::Tensor O, - const at::Tensor dO, - const transformer_engine::DType qkv_type, - const std::vector Aux_CTX_Tensors, - const c10::optional descale_QKV, - const c10::optional descale_S, - const c10::optional descale_O, - const c10::optional descale_dO, - const c10::optional scale_S, - const c10::optional scale_dP, - const c10::optional scale_dQKV, - c10::optional amax_dP, - c10::optional amax_dQKV) { - using namespace transformer_engine; - - // 
create output tensor dQKV - at::Tensor dQKV = torch::empty_like(QKV); - auto max_tokens = dQKV.size(0); - auto self_2d = dQKV.view({max_tokens, -1}); - auto fcd_size = self_2d.size(1); - if (set_zero && (fcd_size % block_size == 0)) { - mha_fill(dQKV, cu_seqlens.index({torch::indexing::Slice(-1, torch::indexing::None)})); - } else { - dQKV.fill_(0); - } - auto options = torch::TensorOptions().dtype(GetATenDType(qkv_type)).device(torch::kCUDA); - at::Tensor dBias; - TensorWrapper te_dBias; - if (bias_type != NVTE_NO_BIAS) { - dBias = torch::zeros({1, static_cast(h), - static_cast(max_seqlen), - static_cast(max_seqlen)}, options); - te_dBias = makeTransformerEngineTensor(dBias); - } - - // construct NVTE tensors - TensorWrapper te_QKV, te_O, te_dO, te_S, te_dP, te_dQKV; - if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) { - // FP8 - if ((!descale_QKV.has_value()) || (!descale_S.has_value()) - || (!descale_O.has_value()) || (!descale_dO.has_value()) - || (!scale_S.has_value()) || (!scale_dP.has_value()) - || (!scale_dQKV.has_value()) - || (!amax_dP.has_value()) || (!amax_dQKV.has_value())) { - std::string err_tensors = "descale_QKV, descale_S, descale_O, scale_S, scale_dP, "; - err_tensors = err_tensors + std::string("scale_dQKV, amax_dP and amax_dQKV"); - NVTE_ERROR(err_tensors + std::string("are required for FP8 operation. \n")); - } - te_QKV = makeTransformerEngineTensor(QKV.data_ptr(), {total_seqs, 3, h, d}, - qkv_type, nullptr, nullptr, descale_QKV.value().data_ptr()); - te_O = makeTransformerEngineTensor(O.data_ptr(), {total_seqs, h, d}, - qkv_type, nullptr, nullptr, descale_O.value().data_ptr()); - te_dO = makeTransformerEngineTensor(dO.data_ptr(), {total_seqs, h, d}, - qkv_type, nullptr, nullptr, descale_dO.value().data_ptr()); - te_S = makeTransformerEngineTensor(nullptr, {0}, - DType::kFloat32, - nullptr, scale_S.value().data_ptr(), descale_S.value().data_ptr()); - at::Tensor descale_dP = torch::empty_like(scale_dP.value()); - te_dP = makeTransformerEngineTensor(nullptr, {0}, - DType::kFloat32, amax_dP.value().data_ptr(), scale_dP.value().data_ptr(), - descale_dP.data_ptr()); - te_dQKV = makeTransformerEngineTensor(dQKV.data_ptr(), {total_seqs, 3, h, d}, - qkv_type, - amax_dQKV.value().data_ptr(), scale_dQKV.value().data_ptr(), nullptr); - } else if (qkv_type == DType::kBFloat16 || qkv_type == DType::kFloat16) { - // BF16 or FP16 - te_QKV = makeTransformerEngineTensor(QKV.data_ptr(), {total_seqs, 3, h, d}, - qkv_type, nullptr, nullptr, nullptr); - te_O = makeTransformerEngineTensor(O.data_ptr(), {total_seqs, h, d}, - qkv_type, nullptr, nullptr, nullptr); - te_dO = makeTransformerEngineTensor(dO.data_ptr(), {total_seqs, h, d}, - qkv_type, nullptr, nullptr, nullptr); - te_S = makeTransformerEngineTensor(nullptr, {0}, - DType::kFloat32, nullptr, nullptr, nullptr); - te_dP = makeTransformerEngineTensor(nullptr, {0}, - DType::kFloat32, nullptr, nullptr, nullptr); - te_dQKV = makeTransformerEngineTensor(dQKV.data_ptr(), {total_seqs, 3, h, d}, - qkv_type, nullptr, nullptr, nullptr); - } else { - NVTE_ERROR("Fused attention only supports FP8 and BF16/FP16 data types. 
\n"); - } - - // convert auxiliary tensors from forward into NVTETensors - NVTETensorPack nvte_aux_tensor_pack; - nvte_tensor_pack_create(&nvte_aux_tensor_pack); - nvte_aux_tensor_pack.size = Aux_CTX_Tensors.size(); - for (size_t i = 0; i < nvte_aux_tensor_pack.size; ++i) { - auto tensor = reinterpret_cast(nvte_aux_tensor_pack.tensors[i]); - tensor->data.dptr = Aux_CTX_Tensors[i].data_ptr(); - std::vector tmp(Aux_CTX_Tensors[i].sizes().vec()); - tensor->data.shape = std::vector(tmp.begin(), tmp.end()); - tensor->data.dtype = GetTransformerEngineDType(Aux_CTX_Tensors[i].scalar_type()); - } - - // create cu_seqlens tensorwrappers - TensorWrapper te_cu_seqlens; - te_cu_seqlens = makeTransformerEngineTensor(cu_seqlens.data_ptr(), {b+1}, - DType::kInt32, nullptr, nullptr, nullptr); - - // create workspace - TensorWrapper workspace; - - // populate tensors with appropriate shapes and dtypes - nvte_fused_attn_bwd_qkvpacked( - te_QKV.data(), - te_O.data(), - te_dO.data(), - te_S.data(), - te_dP.data(), - &nvte_aux_tensor_pack, - te_dQKV.data(), - te_dBias.data(), - te_cu_seqlens.data(), - max_seqlen, - attn_scale, p_dropout, - qkv_layout, bias_type, attn_mask_type, - workspace.data(), - at::cuda::getCurrentCUDAStream()); - - // allocate memory for workspace - auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype()); - workspace = makeTransformerEngineTensor( - workspace_data.data_ptr(), - workspace.shape(), workspace.dtype()); - - // execute kernel - nvte_fused_attn_bwd_qkvpacked( - te_QKV.data(), - te_O.data(), - te_dO.data(), - te_S.data(), - te_dP.data(), - &nvte_aux_tensor_pack, - te_dQKV.data(), - te_dBias.data(), - te_cu_seqlens.data(), - max_seqlen, - attn_scale, p_dropout, - qkv_layout, bias_type, attn_mask_type, - workspace.data(), - at::cuda::getCurrentCUDAStream()); - - // destroy tensor wrappers - nvte_tensor_pack_destroy(&nvte_aux_tensor_pack); - - return {dQKV, dBias}; -} - -// fused attention FWD with packed KV -std::vector fused_attn_fwd_kvpacked( - size_t b, size_t max_seqlen_q, size_t max_seqlen_kv, - size_t total_seqs_q, size_t total_seqs_kv, - size_t h, size_t d, - bool is_training, float attn_scale, float p_dropout, bool set_zero, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, - const at::Tensor cu_seqlens_q, - const at::Tensor cu_seqlens_kv, - const at::Tensor Q, - const at::Tensor KV, - const transformer_engine::DType qkv_type, - const c10::optional descale_QKV, - const c10::optional scale_S, - const c10::optional scale_O, - c10::optional amax_S, - c10::optional amax_O, - const c10::optional Bias, - const c10::optional rng_gen, - size_t rng_elts_per_thread) { - using namespace transformer_engine; - - // create output tensor O - auto options = torch::TensorOptions().dtype(GetATenDType(qkv_type)).device(torch::kCUDA); - auto O = torch::empty({static_cast(total_seqs_q), - static_cast(h), static_cast(d)}, options); - if (set_zero && (h * d % block_size == 0)) { - mha_fill(O, cu_seqlens_q.index({torch::indexing::Slice(-1, torch::indexing::None)})); - } else { - O.fill_(0); - } - - // construct NVTE tensors - TensorWrapper te_Q, te_KV, te_S, te_O, te_Bias, te_cu_seqlens_q, te_cu_seqlens_kv; - if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) { - // FP8 - if ((!descale_QKV.has_value()) || (!scale_S.has_value()) || (!scale_O.has_value()) - || (!amax_S.has_value()) || (!amax_O.has_value())) { - std::string err_tensors = "descale_QKV, scale_S, scale_O, amax_S and amax_O"; - NVTE_ERROR(err_tensors + 
std::string("are required for FP8 operation. \n")); - } - te_Q = makeTransformerEngineTensor(Q.data_ptr(), {total_seqs_q, h, d}, - qkv_type, nullptr, nullptr, descale_QKV.value().data_ptr()); - te_KV = makeTransformerEngineTensor(KV.data_ptr(), {total_seqs_kv, 2, h, d}, - qkv_type, nullptr, nullptr, descale_QKV.value().data_ptr()); - at::Tensor descale_S = torch::empty_like(scale_S.value()); - te_S = makeTransformerEngineTensor(nullptr, {0}, - DType::kFloat32, amax_S.value().data_ptr(), - scale_S.value().data_ptr(), descale_S.data_ptr()); - te_O = makeTransformerEngineTensor(O.data_ptr(), {total_seqs_q, h, d}, - qkv_type, amax_O.value().data_ptr(), scale_O.value().data_ptr(), nullptr); - } else if (qkv_type == DType::kBFloat16 || qkv_type == DType::kFloat16) { - // BF16 or FP16 - te_Q = makeTransformerEngineTensor(Q.data_ptr(), {total_seqs_q, h, d}, - qkv_type, nullptr, nullptr, nullptr); - te_KV = makeTransformerEngineTensor(KV.data_ptr(), {total_seqs_kv, 2, h, d}, - qkv_type, nullptr, nullptr, nullptr); - te_S = makeTransformerEngineTensor(nullptr, {0}, - DType::kFloat32, nullptr, nullptr, nullptr); - te_O = makeTransformerEngineTensor(O.data_ptr(), {total_seqs_q, h, d}, - qkv_type, nullptr, nullptr, nullptr); - } else { - NVTE_ERROR("Fused attention only supports FP8 and BF16/FP16 data types. \n"); - } - if ((bias_type != NVTE_NO_BIAS) && (Bias.has_value())) { - auto bias_shape = Bias.value().sizes().vec(); - std::vector shape{bias_shape.begin(), bias_shape.end()}; - te_Bias = makeTransformerEngineTensor(Bias.value().data_ptr(), shape, - DType::kFloat32, nullptr, nullptr, nullptr); - } - te_cu_seqlens_q = makeTransformerEngineTensor(cu_seqlens_q.data_ptr(), {b+1}, - DType::kInt32, nullptr, nullptr, nullptr); - te_cu_seqlens_kv = makeTransformerEngineTensor(cu_seqlens_kv.data_ptr(), {b+1}, - DType::kInt32, nullptr, nullptr, nullptr); - - // extract rng seed and offset - auto gen = at::get_generator_or_default( - rng_gen, at::cuda::detail::getDefaultCUDAGenerator()); - at::PhiloxCudaState philox_args = init_philox_state(gen, rng_elts_per_thread); - auto rng_state = torch::empty({2}, options.dtype(torch::kInt64)); - unpack<<<1, 1, 0, at::cuda::getCurrentCUDAStream()>>>( - philox_args, static_cast(rng_state.data_ptr())); - auto te_rng_state = makeTransformerEngineTensor(rng_state); - - // create auxiliary output tensors - NVTETensorPack nvte_aux_tensor_pack; - nvte_tensor_pack_create(&nvte_aux_tensor_pack); - - // create workspace - TensorWrapper workspace; - - // populate tensors with appropriate shapes and dtypes - nvte_fused_attn_fwd_kvpacked( - te_Q.data(), - te_KV.data(), - te_Bias.data(), - te_S.data(), - te_O.data(), - &nvte_aux_tensor_pack, - te_cu_seqlens_q.data(), - te_cu_seqlens_kv.data(), - te_rng_state.data(), - max_seqlen_q, max_seqlen_kv, - is_training, attn_scale, p_dropout, - qkv_layout, bias_type, attn_mask_type, - workspace.data(), - at::cuda::getCurrentCUDAStream()); - - // allocate memory for workspace and auxiliary output tensors - auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype()); - workspace = makeTransformerEngineTensor( - workspace_data.data_ptr(), - workspace.shape(), workspace.dtype()); - - // output_tensors = [O, nvte_aux_tensor_pack.tensors] - std::vector output_tensors; - output_tensors.push_back(O); - for (size_t i = 0; i < nvte_aux_tensor_pack.size; ++i) { - auto tensor = reinterpret_cast(nvte_aux_tensor_pack.tensors[i]); - // allocate memory for nvte_aux_tensor_pack.tensors - auto output_tensor = allocateSpace(tensor->data.shape, 
tensor->data.dtype, false); - output_tensors.push_back(output_tensor); - tensor->data.dptr = output_tensor.data_ptr(); - } - - // execute the kernel - nvte_fused_attn_fwd_kvpacked( - te_Q.data(), - te_KV.data(), - te_Bias.data(), - te_S.data(), - te_O.data(), - &nvte_aux_tensor_pack, - te_cu_seqlens_q.data(), - te_cu_seqlens_kv.data(), - te_rng_state.data(), - max_seqlen_q, max_seqlen_kv, - is_training, attn_scale, p_dropout, - qkv_layout, bias_type, attn_mask_type, - workspace.data(), - at::cuda::getCurrentCUDAStream()); - - // destroy tensor wrappers, but not allocated memory - nvte_tensor_pack_destroy(&nvte_aux_tensor_pack); - - // if training, [O, softmax-related tensors, rng_state]; if inference, [O] - return output_tensors; -} - -// fused attention BWD with packed KV -std::vector fused_attn_bwd_kvpacked( - size_t b, size_t max_seqlen_q, size_t max_seqlen_kv, - size_t total_seqs_q, size_t total_seqs_kv, - size_t h, size_t d, - float attn_scale, float p_dropout, bool set_zero, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, - const at::Tensor cu_seqlens_q, - const at::Tensor cu_seqlens_kv, - const at::Tensor Q, - const at::Tensor KV, - const at::Tensor O, - const at::Tensor dO, - const transformer_engine::DType qkv_type, - const std::vector Aux_CTX_Tensors, - const c10::optional descale_QKV, - const c10::optional descale_S, - const c10::optional descale_O, - const c10::optional descale_dO, - const c10::optional scale_S, - const c10::optional scale_dP, - const c10::optional scale_dQKV, - c10::optional amax_dP, - c10::optional amax_dQKV) { - using namespace transformer_engine; - - // create output tensors dQ and dKV - at::Tensor dQ = torch::empty_like(Q); - at::Tensor dKV = torch::empty_like(KV); - auto max_tokens_q = dQ.size(0); - auto self_2d_q = dQ.view({max_tokens_q, -1}); - auto fcd_size_q = self_2d_q.size(1); - auto max_tokens_kv = dQ.size(0); - auto self_2d_kv = dQ.view({max_tokens_kv, -1}); - auto fcd_size_kv = self_2d_kv.size(1); - if (set_zero && (fcd_size_q % block_size == 0) && (fcd_size_kv % block_size == 0)) { - mha_fill(dQ, cu_seqlens_q.index({torch::indexing::Slice(-1, torch::indexing::None)})); - mha_fill(dKV, cu_seqlens_kv.index({torch::indexing::Slice(-1, torch::indexing::None)})); - } else { - dQ.fill_(0); - dKV.fill_(0); - } - auto options = torch::TensorOptions().dtype(GetATenDType(qkv_type)).device(torch::kCUDA); - at::Tensor dBias; - TensorWrapper te_dBias; - if (bias_type != NVTE_NO_BIAS) { - dBias = torch::zeros({1, static_cast(h), - static_cast(max_seqlen_q), - static_cast(max_seqlen_kv)}, options); - te_dBias = makeTransformerEngineTensor(dBias); - } - - // construct NVTE tensors - TensorWrapper te_Q, te_KV, te_O, te_dO, te_S, te_dP, te_dQ, te_dKV; - if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) { - // FP8 - if ((!descale_QKV.has_value()) || (!descale_S.has_value()) - || (!descale_O.has_value()) || (!descale_dO.has_value()) - || (!scale_S.has_value()) || (!scale_dP.has_value()) - || (!scale_dQKV.has_value()) - || (!amax_dP.has_value()) || (!amax_dQKV.has_value())) { - std::string err_tensors = "descale_QKV, descale_S, descale_O, scale_S, scale_dP, "; - err_tensors = err_tensors + std::string("scale_dQKV, amax_dP and amax_dQKV"); - NVTE_ERROR(err_tensors + std::string("are required for FP8 operation. 
\n")); - } - te_Q = makeTransformerEngineTensor(Q.data_ptr(), {total_seqs_q, h, d}, - qkv_type, nullptr, nullptr, descale_QKV.value().data_ptr()); - te_KV = makeTransformerEngineTensor(KV.data_ptr(), {total_seqs_kv, 2, h, d}, - qkv_type, nullptr, nullptr, descale_QKV.value().data_ptr()); - te_O = makeTransformerEngineTensor(O.data_ptr(), {total_seqs_q, h, d}, - qkv_type, nullptr, nullptr, descale_O.value().data_ptr()); - te_dO = makeTransformerEngineTensor(dO.data_ptr(), {total_seqs_q, h, d}, - qkv_type, nullptr, nullptr, descale_dO.value().data_ptr()); - te_S = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32, nullptr, - scale_S.value().data_ptr(), descale_S.value().data_ptr()); - at::Tensor descale_dP = torch::empty_like(scale_dP.value()); - te_dP = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32, - amax_dP.value().data_ptr(), scale_dP.value().data_ptr(), - descale_dP.data_ptr()); - te_dQ = makeTransformerEngineTensor(dQ.data_ptr(), {total_seqs_q, h, d}, qkv_type, - amax_dQKV.value().data_ptr(), scale_dQKV.value().data_ptr(), nullptr); - te_dKV = makeTransformerEngineTensor(dKV.data_ptr(), {total_seqs_kv, 2, h, d}, qkv_type, - amax_dQKV.value().data_ptr(), scale_dQKV.value().data_ptr(), nullptr); - } else if (qkv_type == DType::kBFloat16 || qkv_type == DType::kFloat16) { - // BF16 or FP16 - te_Q = makeTransformerEngineTensor(Q.data_ptr(), {total_seqs_q, h, d}, - qkv_type, nullptr, nullptr, nullptr); - te_KV = makeTransformerEngineTensor(KV.data_ptr(), {total_seqs_kv, 2, h, d}, - qkv_type, nullptr, nullptr, nullptr); - te_O = makeTransformerEngineTensor(O.data_ptr(), {total_seqs_q, h, d}, - qkv_type, nullptr, nullptr, nullptr); - te_dO = makeTransformerEngineTensor(dO.data_ptr(), {total_seqs_q, h, d}, - qkv_type, nullptr, nullptr, nullptr); - te_S = makeTransformerEngineTensor(nullptr, {0}, - DType::kFloat32, nullptr, nullptr, nullptr); - te_dP = makeTransformerEngineTensor(nullptr, {0}, - DType::kFloat32, nullptr, nullptr, nullptr); - te_dQ = makeTransformerEngineTensor(dQ.data_ptr(), {total_seqs_q, h, d}, - qkv_type, nullptr, nullptr, nullptr); - te_dKV = makeTransformerEngineTensor(dKV.data_ptr(), {total_seqs_kv, 2, h, d}, - qkv_type, nullptr, nullptr, nullptr); - } else { - NVTE_ERROR("Fused attention only supports FP8 and BF16/FP16 data types. 
\n"); - } - - // create cu_seqlens tensorwrappers - TensorWrapper te_cu_seqlens_q, te_cu_seqlens_kv; - te_cu_seqlens_q = makeTransformerEngineTensor(cu_seqlens_q.data_ptr(), {b+1}, - DType::kInt32, nullptr, nullptr, nullptr); - te_cu_seqlens_kv = makeTransformerEngineTensor(cu_seqlens_kv.data_ptr(), {b+1}, - DType::kInt32, nullptr, nullptr, nullptr); - - // convert auxiliary tensors from forward to NVTETensors - NVTETensorPack nvte_aux_tensor_pack; - nvte_tensor_pack_create(&nvte_aux_tensor_pack); - nvte_aux_tensor_pack.size = Aux_CTX_Tensors.size(); - for (size_t i = 0; i < nvte_aux_tensor_pack.size; ++i) { - auto tensor = reinterpret_cast(nvte_aux_tensor_pack.tensors[i]); - tensor->data.dptr = Aux_CTX_Tensors[i].data_ptr(); - std::vector tmp(Aux_CTX_Tensors[i].sizes().vec()); - tensor->data.shape = std::vector(tmp.begin(), tmp.end()); - tensor->data.dtype = GetTransformerEngineDType(Aux_CTX_Tensors[i].scalar_type()); - } - - // create workspace - TensorWrapper workspace; - - // populate tensors with appropriate shapes and dtypes - nvte_fused_attn_bwd_kvpacked( - te_Q.data(), - te_KV.data(), - te_O.data(), - te_dO.data(), - te_S.data(), - te_dP.data(), - &nvte_aux_tensor_pack, - te_dQ.data(), - te_dKV.data(), - te_dBias.data(), - te_cu_seqlens_q.data(), - te_cu_seqlens_kv.data(), - max_seqlen_q, max_seqlen_kv, - attn_scale, p_dropout, - qkv_layout, bias_type, attn_mask_type, - workspace.data(), - at::cuda::getCurrentCUDAStream()); - - // allocate memory for workspace - auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype()); - workspace = makeTransformerEngineTensor( - workspace_data.data_ptr(), - workspace.shape(), workspace.dtype()); - - // execute kernel - nvte_fused_attn_bwd_kvpacked( - te_Q.data(), - te_KV.data(), - te_O.data(), - te_dO.data(), - te_S.data(), - te_dP.data(), - &nvte_aux_tensor_pack, - te_dQ.data(), - te_dKV.data(), - te_dBias.data(), - te_cu_seqlens_q.data(), - te_cu_seqlens_kv.data(), - max_seqlen_q, max_seqlen_kv, - attn_scale, p_dropout, - qkv_layout, bias_type, attn_mask_type, - workspace.data(), - at::cuda::getCurrentCUDAStream()); - - // destroy tensor wrappers - nvte_tensor_pack_destroy(&nvte_aux_tensor_pack); - - return {dQ, dKV, dBias}; -} - -void te_gemm(at::Tensor A, - at::Tensor A_scale_inverse, - transformer_engine::DType A_type, - bool transa, - at::Tensor B, - at::Tensor B_scale_inverse, - transformer_engine::DType B_type, - bool transb, - at::Tensor D, - at::Tensor D_scale, - transformer_engine::DType D_type, - at::Tensor D_amax, - at::Tensor bias, - transformer_engine::DType bias_type, - at::Tensor pre_gelu_out, - bool grad, - at::Tensor workspace, - size_t workspaceSize, - bool accumulate, - bool use_split_accumulator, - int math_sm_count -) { - using namespace transformer_engine; - auto te_A = makeTransformerEngineTensor(A.data_ptr(), - {static_cast(A.size(0)), - static_cast(A.size(1))}, - A_type, nullptr, nullptr, - A_scale_inverse.data_ptr()); - auto te_B = makeTransformerEngineTensor(B.data_ptr(), - {static_cast(B.size(0)), - static_cast(B.size(1))}, - B_type, nullptr, nullptr, - B_scale_inverse.data_ptr()); - auto te_D = makeTransformerEngineTensor(D.data_ptr(), - {static_cast(D.size(0)), - static_cast(D.size(1))}, - D_type, D_amax.data_ptr(), - D_scale.data_ptr(), nullptr); - auto te_bias = makeTransformerEngineTensor(bias.data_ptr(), {static_cast(bias.size(0))}, - bias_type); - - const auto gelu_shape = pre_gelu_out.data_ptr() == nullptr - ? 
std::vector{static_cast(pre_gelu_out.size(0))} - : std::vector{static_cast(pre_gelu_out.size(0)), - static_cast(pre_gelu_out.size(1))}; - auto te_pre_gelu_out = makeTransformerEngineTensor(pre_gelu_out.data_ptr(), - gelu_shape, - GetTransformerEngineDType( - pre_gelu_out.scalar_type())); - auto te_workspace = makeTransformerEngineTensor(workspace.data_ptr(), - {workspaceSize}, - DType::kByte); - - nvte_cublas_gemm(te_A.data(), - te_B.data(), - te_D.data(), - te_bias.data(), - te_pre_gelu_out.data(), - transa, - transb, - grad, - te_workspace.data(), - accumulate, - use_split_accumulator, - math_sm_count, - at::cuda::getCurrentCUDAStream()); -} - - -void fused_cast_transpose(at::Tensor input, - at::Tensor scale, - at::Tensor amax, - at::Tensor scale_inv, - at::Tensor input_cast, - at::Tensor input_transpose, - transformer_engine::DType otype -) { - using namespace transformer_engine; - - size_t M = static_cast(input.size(0)); - size_t N = static_cast(input.size(1)); - - auto input_cu = makeTransformerEngineTensor(input); - auto output_cast_cu = makeTransformerEngineTensor(input_cast.data_ptr(), {M, N}, otype, - amax.data_ptr(), scale.data_ptr(), - scale_inv.data_ptr()); - auto output_transpose_cu = makeTransformerEngineTensor(input_transpose.data_ptr(), {N, M}, otype, - amax.data_ptr(), scale.data_ptr(), - scale_inv.data_ptr()); - - nvte_cast_transpose(input_cu.data(), output_cast_cu.data(), output_transpose_cu.data(), - at::cuda::getCurrentCUDAStream()); -} - - -std::vector fused_cast_transpose_bgrad(at::Tensor grad_output, - at::Tensor scale, - at::Tensor amax, - at::Tensor scale_inv, - transformer_engine::DType otype -) { - using namespace transformer_engine; - - size_t M = static_cast(grad_output.size(0)); - size_t N = static_cast(grad_output.size(1)); - - DType grad_output_type = GetTransformerEngineDType(grad_output.scalar_type()); - auto grad_bias = allocateTorchTensor(grad_output.size(-1), grad_output_type); - auto grad_output_cast = - allocateTorchTensor(grad_output.size(0), - grad_output.size(1), - DType::kByte); - auto grad_output_transpose = - allocateTorchTensor(grad_output.size(1), - grad_output.size(0), - DType::kByte); - - auto input_cu = makeTransformerEngineTensor(grad_output); - auto cast_output_cu = makeTransformerEngineTensor(grad_output_cast.data_ptr(), {M, N}, - otype, amax.data_ptr(), scale.data_ptr(), - scale_inv.data_ptr()); - auto transposed_output_cu = makeTransformerEngineTensor(grad_output_transpose.data_ptr(), - {N, M}, otype, amax.data_ptr(), - scale.data_ptr(), scale_inv.data_ptr()); - auto dbias_cu = makeTransformerEngineTensor(grad_bias); - transformer_engine::TensorWrapper workspace; - - nvte_cast_transpose_dbias(input_cu.data(), cast_output_cu.data(), - transposed_output_cu.data(), dbias_cu.data(), - workspace.data(), at::cuda::getCurrentCUDAStream()); - - // Fill workspace - auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype()); - workspace = makeTransformerEngineTensor(workspace_data.data_ptr(), - workspace.shape(), - workspace.dtype()); - - nvte_cast_transpose_dbias(input_cu.data(), cast_output_cu.data(), - transposed_output_cu.data(), dbias_cu.data(), - workspace.data(), at::cuda::getCurrentCUDAStream()); - - return {grad_bias, grad_output_cast, grad_output_transpose}; -} - - -std::vector fused_fp8_transpose_bgrad(at::Tensor grad_output, - at::Tensor scale, - at::Tensor amax, - at::Tensor scale_inv, - transformer_engine::DType otype, - transformer_engine::DType grad_bias_type -) { - using namespace transformer_engine; - - size_t M 
= static_cast(grad_output.size(0)); - size_t N = static_cast(grad_output.size(1)); - - auto grad_bias = allocateTorchTensor(grad_output.size(-1), grad_bias_type); - auto grad_output_transpose = - allocateTorchTensor(grad_output.size(1), - grad_output.size(0), - DType::kByte); - auto input_cu = makeTransformerEngineTensor(grad_output.data_ptr(), {M, N}, - otype, amax.data_ptr(), scale.data_ptr(), - scale_inv.data_ptr()); - auto transposed_output_cu = makeTransformerEngineTensor(grad_output_transpose.data_ptr(), - {N, M}, otype, amax.data_ptr(), - scale.data_ptr(), scale_inv.data_ptr()); - auto dbias_cu = makeTransformerEngineTensor(grad_bias); - transformer_engine::TensorWrapper workspace; - - nvte_fp8_transpose_dbias(input_cu.data(), transposed_output_cu.data(), dbias_cu.data(), - workspace.data(), at::cuda::getCurrentCUDAStream()); - - // Fill workspace - auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype()); - workspace = makeTransformerEngineTensor(workspace_data.data_ptr(), - workspace.shape(), - workspace.dtype()); - - nvte_fp8_transpose_dbias(input_cu.data(), transposed_output_cu.data(), dbias_cu.data(), - workspace.data(), at::cuda::getCurrentCUDAStream()); - - return {grad_bias, grad_output_transpose}; -} - - - -std::vector fused_cast_transpose_bgrad_dgelu(at::Tensor grad_output, - at::Tensor gelu_input, - at::Tensor scale, - at::Tensor amax, - at::Tensor scale_inv, - transformer_engine::DType otype -) { - using namespace transformer_engine; - - size_t M = static_cast(grad_output.size(0)); - size_t N = static_cast(grad_output.size(1)); - - DType grad_output_type = GetTransformerEngineDType(grad_output.scalar_type()); - auto grad_bias = allocateTorchTensor(grad_output.size(-1), grad_output_type); - auto dgelu = - allocateTorchTensor(grad_output.size(0), - grad_output.size(1), - DType::kByte); - auto dgelu_transpose = - allocateTorchTensor(grad_output.size(1), - grad_output.size(0), - DType::kByte); - - transformer_engine::TensorWrapper workspace; - auto gelu_input_cu = makeTransformerEngineTensor(gelu_input); - auto input_cu = makeTransformerEngineTensor(grad_output); - auto cast_output_cu = makeTransformerEngineTensor(dgelu.data_ptr(), {M, N}, - otype, amax.data_ptr(), scale.data_ptr(), - scale_inv.data_ptr()); - auto transposed_output_cu = makeTransformerEngineTensor(dgelu_transpose.data_ptr(), {N, M}, - otype, amax.data_ptr(), scale.data_ptr(), - scale_inv.data_ptr()); - auto dbias_cu = makeTransformerEngineTensor(grad_bias); - - nvte_cast_transpose_dbias_dgelu(input_cu.data(), gelu_input_cu.data(), - cast_output_cu.data(), transposed_output_cu.data(), - dbias_cu.data(), workspace.data(), - at::cuda::getCurrentCUDAStream()); - - // Fill workspace - auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype()); - workspace = makeTransformerEngineTensor(workspace_data.data_ptr(), - workspace.shape(), - workspace.dtype()); - - nvte_cast_transpose_dbias_dgelu(input_cu.data(), gelu_input_cu.data(), - cast_output_cu.data(), transposed_output_cu.data(), - dbias_cu.data(), workspace.data(), - at::cuda::getCurrentCUDAStream()); - - return {grad_bias, dgelu, dgelu_transpose}; -} - - -void fused_multi_cast_transpose(std::vector input_list, - std::vector scale_list, - std::vector cast_output_list, - std::vector transposed_output_list, - std::vector amax_list, - std::vector scale_inv_list, - transformer_engine::DType otype -) { - using namespace transformer_engine; - - // Extract properties from PyTorch tensors - std::vector input_dptr_list, scale_dptr_list, - 
cast_output_dptr_list, transposed_output_dptr_list, - amax_dptr_list, scale_inv_dptr_list; - std::vector> input_shape_list, scale_shape_list, - cast_output_shape_list, transposed_output_shape_list, - amax_shape_list, scale_inv_shape_list; - std::vector input_type_list, scale_type_list, - cast_output_type_list, transposed_output_type_list, - amax_type_list, scale_inv_type_list; - auto extract_tensor_props_skip_dtype = [](at::Tensor& tensor, - std::vector& dptr_list, - std::vector>& shape_list) { - dptr_list.push_back(tensor.data_ptr()); - shape_list.push_back({}); - for (int d = 0; d < tensor.dim(); ++d) { - shape_list.back().push_back(tensor.size(d)); - } - }; - auto extract_tensor_props = [](at::Tensor& tensor, - std::vector& dptr_list, - std::vector>& shape_list, - std::vector& type_list) { - dptr_list.push_back(tensor.data_ptr()); - shape_list.push_back({}); - for (int d = 0; d < tensor.dim(); ++d) { - shape_list.back().push_back(tensor.size(d)); - } - type_list.push_back(GetTransformerEngineDType(tensor.scalar_type())); - }; - for (size_t tensor_id = 0; tensor_id < input_list.size(); ++tensor_id) { - extract_tensor_props(input_list[tensor_id], - input_dptr_list, - input_shape_list, - input_type_list); - extract_tensor_props(scale_list[tensor_id], - scale_dptr_list, - scale_shape_list, - scale_type_list); - extract_tensor_props_skip_dtype(cast_output_list[tensor_id], - cast_output_dptr_list, - cast_output_shape_list); - cast_output_type_list.push_back(otype); - extract_tensor_props_skip_dtype(transposed_output_list[tensor_id], - transposed_output_dptr_list, - transposed_output_shape_list); - transposed_output_type_list.push_back(otype); - extract_tensor_props(amax_list[tensor_id], - amax_dptr_list, - amax_shape_list, - amax_type_list); - extract_tensor_props(scale_inv_list[tensor_id], - scale_inv_dptr_list, - scale_inv_shape_list, - scale_inv_type_list); - } - - transformer_engine::TensorWrapper workspace; - - // Construct TE tensors - std::vector nvte_input_list, - nvte_cast_output_list, nvte_transposed_output_list; - std::vector tensor_wrappers; - auto make_tensor = [&tensor_wrappers](void* dptr, - const std::vector& shape, - transformer_engine::DType dtype, - void* amax_dptr, - void* scale_dptr, - void* scale_inv_dptr) - -> NVTETensor { - tensor_wrappers.emplace_back(makeTransformerEngineTensor(dptr, shape, dtype, amax_dptr, - scale_dptr, scale_inv_dptr)); - return tensor_wrappers.back().data(); - }; - for (size_t i = 0; i < input_dptr_list.size(); ++i) { - nvte_input_list.emplace_back(make_tensor(input_dptr_list[i], - input_shape_list[i], - input_type_list[i], - nullptr, - nullptr, - nullptr)); - nvte_cast_output_list.emplace_back(make_tensor(cast_output_dptr_list[i], - cast_output_shape_list[i], - cast_output_type_list[i], - amax_dptr_list[i], - scale_dptr_list[i], - scale_inv_dptr_list[i])); - nvte_transposed_output_list.emplace_back(make_tensor(transposed_output_dptr_list[i], - transposed_output_shape_list[i], - transposed_output_type_list[i], - amax_dptr_list[i], - scale_dptr_list[i], - scale_inv_dptr_list[i])); - } - - // Check tensor lists - NVTE_CHECK(nvte_cast_output_list.size() == nvte_input_list.size(), - "Number of input and C output tensors must match"); - NVTE_CHECK(nvte_transposed_output_list.size() == nvte_input_list.size(), - "Number of input and T output tensors must match"); - - // Launch TE kernel - nvte_multi_cast_transpose(nvte_input_list.size(), - nvte_input_list.data(), - nvte_cast_output_list.data(), - nvte_transposed_output_list.data(), - 
at::cuda::getCurrentCUDAStream()); -} - - -at::Tensor fp8_transpose(at::Tensor input, - transformer_engine::DType otype -) { - using namespace transformer_engine; - - size_t M = static_cast(input.size(0)); - size_t N = static_cast(input.size(1)); - - auto output = - allocateTorchTensor(input.size(1), - input.size(0), - DType::kByte); - - auto input_cu = makeTransformerEngineTensor(input.data_ptr(), {M, N}, otype); - auto output_cu = makeTransformerEngineTensor(output.data_ptr(), {N, M}, otype); - - nvte_transpose(input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream()); - - return output; -} - - -at::Tensor gelu(at::Tensor input, - at::Tensor scale, - at::Tensor amax, - at::Tensor scale_inv, - transformer_engine::DType otype -) { - using namespace transformer_engine; - - size_t N = static_cast(input.size(-1)); - size_t M = input.numel() / N; - - auto output = - allocateTorchTensor(M, - N, - otype); - - auto itype = GetTransformerEngineDType(input.scalar_type()); - auto input_cu = makeTransformerEngineTensor(input.data_ptr(), {M, N}, itype); - auto output_cu = makeTransformerEngineTensor(output.data_ptr(), {M, N}, otype, - amax.data_ptr(), scale.data_ptr(), - scale_inv.data_ptr()); - - nvte_gelu(input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream()); - - return output; -} - -at::Tensor dgelu(at::Tensor grad, - at::Tensor input, - transformer_engine::DType otype -) { - using namespace transformer_engine; - - size_t N = static_cast(input.size(-1)); - size_t M = input.numel() / N; - - auto output = - allocateTorchTensor(M, - N, - otype); - - auto itype = GetTransformerEngineDType(input.scalar_type()); - auto gtype = GetTransformerEngineDType(grad.scalar_type()); - auto input_cu = makeTransformerEngineTensor(input.data_ptr(), {M, N}, itype); - auto grad_cu = makeTransformerEngineTensor(grad.data_ptr(), {M, N}, gtype); - auto output_cu = makeTransformerEngineTensor(output.data_ptr(), {M, N}, otype); - - nvte_dgelu(grad_cu.data(), input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream()); - - return output; -} - -at::Tensor relu(at::Tensor input, - at::Tensor scale, - at::Tensor amax, - at::Tensor scale_inv, - transformer_engine::DType otype -) { - using namespace transformer_engine; - - size_t N = static_cast(input.size(-1)); - size_t M = static_cast(input.numel()) / N; - - auto output = - allocateTorchTensor(M, - N, - otype); - - auto itype = GetTransformerEngineDType(input.scalar_type()); - auto input_cu = makeTransformerEngineTensor(input.data_ptr(), {M, N}, itype); - auto output_cu = makeTransformerEngineTensor(output.data_ptr(), {M, N}, otype, - amax.data_ptr(), scale.data_ptr(), - scale_inv.data_ptr()); - - nvte_relu(input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream()); - - return output; -} - -at::Tensor drelu(at::Tensor grad, - at::Tensor input, - transformer_engine::DType otype -) { - using namespace transformer_engine; - - size_t N = static_cast(input.size(-1)); - size_t M = input.numel() / N; - - auto output = - allocateTorchTensor(M, - N, - otype); - - auto itype = GetTransformerEngineDType(input.scalar_type()); - auto gtype = GetTransformerEngineDType(grad.scalar_type()); - auto input_cu = makeTransformerEngineTensor(input.data_ptr(), {M, N}, itype); - auto grad_cu = makeTransformerEngineTensor(grad.data_ptr(), {M, N}, gtype); - auto output_cu = makeTransformerEngineTensor(output.data_ptr(), {M, N}, otype); - - nvte_drelu(grad_cu.data(), input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream()); - - return output; -} - 
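// The GLU-family wrappers that follow (geglu/reglu/swiglu and their backward
// passes) all view the last dimension of the input as two concatenated halves
// of width N/2 and emit an output of width N/2, as the {M, N} -> {M, N / 2}
// tensor shapes below show. A minimal CPU reference sketch of that gating
// pattern, assuming the activation is applied to the first half and the second
// half is used as the gate (the exact half ordering is defined by the NVTE
// kernels, not by this sketch; the helper names here are illustrative only):

#include <cmath>
#include <cstddef>
#include <vector>

static std::vector<float> glu_forward_reference(const std::vector<float>& x,
                                                size_t M, size_t N,
                                                float (*act)(float)) {
  // x is row-major with shape [M, N]; the result has shape [M, N / 2].
  std::vector<float> y(M * (N / 2));
  for (size_t i = 0; i < M; ++i) {
    for (size_t j = 0; j < N / 2; ++j) {
      const float a = x[i * N + j];          // half passed through the activation
      const float b = x[i * N + N / 2 + j];  // gating half
      y[i * (N / 2) + j] = act(a) * b;
    }
  }
  return y;
}

// Example activation for the geglu case: the erf-based GELU.
static float gelu_reference(float v) {
  return 0.5f * v * (1.0f + std::erf(v / std::sqrt(2.0f)));
}

// Usage sketch: glu_forward_reference(x, M, N, gelu_reference) mirrors the
// {M, N} -> {M, N / 2} shape contract of the geglu wrapper below.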
-at::Tensor geglu(at::Tensor input, - at::Tensor scale, - at::Tensor amax, - at::Tensor scale_inv, - transformer_engine::DType otype -) { - using namespace transformer_engine; - - size_t N = static_cast(input.size(-1)); - size_t M = input.numel() / N; - - auto output = - allocateTorchTensor(M, - N / 2, - otype); - - auto itype = GetTransformerEngineDType(input.scalar_type()); - auto input_cu = makeTransformerEngineTensor(input.data_ptr(), {M, N}, itype); - auto output_cu = makeTransformerEngineTensor(output.data_ptr(), {M, N / 2}, otype, - amax.data_ptr(), scale.data_ptr(), - scale_inv.data_ptr()); - - nvte_geglu(input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream()); - - return output; -} - -at::Tensor dgeglu(at::Tensor grad, - at::Tensor input, - transformer_engine::DType otype -) { - using namespace transformer_engine; - - size_t N = static_cast(input.size(-1)); - size_t M = input.numel() / N; - - auto output = - allocateTorchTensor(M, - N, - otype); - - auto itype = GetTransformerEngineDType(input.scalar_type()); - auto gtype = GetTransformerEngineDType(grad.scalar_type()); - auto input_cu = makeTransformerEngineTensor(input.data_ptr(), {M, N}, itype); - auto grad_cu = makeTransformerEngineTensor(grad.data_ptr(), {M, N / 2}, gtype); - auto output_cu = makeTransformerEngineTensor(output.data_ptr(), {M, N}, otype); - - nvte_dgeglu(grad_cu.data(), input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream()); - - return output; -} - -at::Tensor reglu(at::Tensor input, - at::Tensor scale, - at::Tensor amax, - at::Tensor scale_inv, - transformer_engine::DType otype -) { - using namespace transformer_engine; - - size_t N = static_cast(input.size(-1)); - size_t M = input.numel() / N; - - auto output = - allocateTorchTensor(M, - N / 2, - otype); - - auto itype = GetTransformerEngineDType(input.scalar_type()); - auto input_cu = makeTransformerEngineTensor(input.data_ptr(), {M, N}, itype); - auto output_cu = makeTransformerEngineTensor(output.data_ptr(), {M, N / 2}, otype, - amax.data_ptr(), scale.data_ptr(), - scale_inv.data_ptr()); - - nvte_reglu(input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream()); - - return output; -} - -at::Tensor dreglu(at::Tensor grad, - at::Tensor input, - transformer_engine::DType otype -) { - using namespace transformer_engine; - - size_t N = static_cast(input.size(-1)); - size_t M = input.numel() / N; - - auto output = - allocateTorchTensor(M, - N, - otype); - - auto itype = GetTransformerEngineDType(input.scalar_type()); - auto gtype = GetTransformerEngineDType(grad.scalar_type()); - auto input_cu = makeTransformerEngineTensor(input.data_ptr(), {M, N}, itype); - auto grad_cu = makeTransformerEngineTensor(grad.data_ptr(), {M, N / 2}, gtype); - auto output_cu = makeTransformerEngineTensor(output.data_ptr(), {M, N}, otype); - - nvte_dreglu(grad_cu.data(), input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream()); - - return output; -} - -at::Tensor swiglu(at::Tensor input, - at::Tensor scale, - at::Tensor amax, - at::Tensor scale_inv, - transformer_engine::DType otype -) { - using namespace transformer_engine; - - size_t N = static_cast(input.size(-1)); - size_t M = input.numel() / N; - - auto output = - allocateTorchTensor(M, - N / 2, - otype); - - auto itype = GetTransformerEngineDType(input.scalar_type()); - auto input_cu = makeTransformerEngineTensor(input.data_ptr(), {M, N}, itype); - auto output_cu = makeTransformerEngineTensor(output.data_ptr(), {M, N / 2}, otype, - amax.data_ptr(), scale.data_ptr(), - 
scale_inv.data_ptr()); - - nvte_swiglu(input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream()); - - return output; -} - -at::Tensor dswiglu(at::Tensor grad, - at::Tensor input, - transformer_engine::DType otype -) { - using namespace transformer_engine; - - size_t N = static_cast(input.size(-1)); - size_t M = input.numel() / N; - - auto output = - allocateTorchTensor(M, - N, - otype); - - auto itype = GetTransformerEngineDType(input.scalar_type()); - auto gtype = GetTransformerEngineDType(grad.scalar_type()); - auto input_cu = makeTransformerEngineTensor(input.data_ptr(), {M, N}, itype); - auto grad_cu = makeTransformerEngineTensor(grad.data_ptr(), {M, N / 2}, gtype); - auto output_cu = makeTransformerEngineTensor(output.data_ptr(), {M, N}, otype); - - nvte_dswiglu(grad_cu.data(), input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream()); - - return output; -} - -std::vector layernorm_bwd(const at::Tensor &dz, - const at::Tensor &x, - const at::Tensor &mu, - const at::Tensor &rsigma, - const at::Tensor &gamma, - const int sm_margin, - const bool zero_centered_gamma -) { - auto dx = at::empty_like(x); - auto dgamma = at::empty_like(gamma); - auto dbeta = at::empty_like(gamma); - transformer_engine::TensorWrapper workspace, barrier, dgamma_part, dbeta_part; - - auto dz_cu = makeTransformerEngineTensor(dz); - auto x_cu = makeTransformerEngineTensor(x); - auto mu_cu = makeTransformerEngineTensor(mu); - auto rsigma_cu = makeTransformerEngineTensor(rsigma); - auto gamma_cu = makeTransformerEngineTensor(gamma); - auto dx_cu = makeTransformerEngineTensor(dx); - auto dgamma_cu = makeTransformerEngineTensor(dgamma); - auto dbeta_cu = makeTransformerEngineTensor(dbeta); - - // This call populates tensors with the required config. - const auto bwd_fun = zero_centered_gamma ? nvte_layernorm1p_bwd : nvte_layernorm_bwd; - bwd_fun(dz_cu.data(), x_cu.data(), mu_cu.data(), rsigma_cu.data(), gamma_cu.data(), - dx_cu.data(), dgamma_cu.data(), dbeta_cu.data(), dgamma_part.data(), - dbeta_part.data(), at::cuda::getCurrentCUDAStream(), - at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin, - workspace.data(), barrier.data()); - - // Alloc space for Tensors. - auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype()); - auto barrier_data = allocateSpace(barrier.shape(), barrier.dtype(), true); - auto dgamma_part_data = allocateSpace(dgamma_part.shape(), dgamma_part.dtype()); - auto dbeta_part_data = allocateSpace(dbeta_part.shape(), dbeta_part.dtype()); - workspace = makeTransformerEngineTensor(workspace_data.data_ptr(), - workspace.shape(), - workspace.dtype()); - barrier = makeTransformerEngineTensor(barrier_data.data_ptr(), - barrier.shape(), - barrier.dtype()); - dgamma_part = makeTransformerEngineTensor(dgamma_part_data.data_ptr(), - dgamma_part.shape(), - dgamma_part.dtype()); - dbeta_part = makeTransformerEngineTensor(dbeta_part_data.data_ptr(), - dbeta_part.shape(), - dbeta_part.dtype()); - - // Actual call to bwd kernel. 
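-  // Note: the first bwd_fun call above is issued with empty workspace, barrier,
-  // dgamma_part and dbeta_part wrappers only so the kernel can report the shapes
-  // and dtypes it needs; this second call, made after those buffers have been
-  // allocated, performs the actual backward computation.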
- bwd_fun(dz_cu.data(), x_cu.data(), mu_cu.data(), rsigma_cu.data(), gamma_cu.data(), - dx_cu.data(), dgamma_cu.data(), dbeta_cu.data(), dgamma_part.data(), - dbeta_part.data(), at::cuda::getCurrentCUDAStream(), - at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin, - workspace.data(), barrier.data()); - - return { dx, dgamma, dbeta }; -} - - -std::vector layernorm_fwd_fp8(const at::Tensor &input, - const at::Tensor &weight, - const at::Tensor &bias, - float eps, - at::Tensor scale, - at::Tensor amax, - at::Tensor scale_inv, - transformer_engine::DType otype, - const int sm_margin, - const bool zero_centered_gamma -) { - using namespace transformer_engine; - - size_t N = static_cast(input.size(0)); - size_t H = static_cast(input.size(1)); - - DType itype = GetTransformerEngineDType(input.scalar_type()); - - auto ln_out = at::empty_like(input, at::CUDA(GetATenDType(otype))); - auto mu = at::empty({static_cast(N)}, at::CUDA(at::kFloat)); - auto rsigma = at::empty({static_cast(N)}, at::CUDA(at::kFloat)); - auto input_cu = makeTransformerEngineTensor(input); - auto gamma_cu = makeTransformerEngineTensor(weight); - auto beta_cu = makeTransformerEngineTensor(bias); - auto z_cu = makeTransformerEngineTensor(ln_out.data_ptr(), {N, H}, otype, - amax.data_ptr(), scale.data_ptr(), - scale_inv.data_ptr()); - auto mu_cu = makeTransformerEngineTensor(mu); - auto rsigma_cu = makeTransformerEngineTensor(rsigma); - transformer_engine::TensorWrapper workspace, barrier; - - // This call populates workspace and barrier tensors with the required config - const auto func = zero_centered_gamma ? nvte_layernorm1p_fwd : nvte_layernorm_fwd; - func(input_cu.data(), gamma_cu.data(), beta_cu.data(), eps, z_cu.data(), - mu_cu.data(), rsigma_cu.data(), at::cuda::getCurrentCUDAStream(), - at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin, - workspace.data(), barrier.data()); - - // Fill workspace and barrier - auto workspace_data = allocateSpace(workspace.shape(), - workspace.dtype()); - auto barrier_data = allocateSpace(barrier.shape(), - barrier.dtype(), - true); - workspace = makeTransformerEngineTensor(workspace_data.data_ptr(), - workspace.shape(), - workspace.dtype()); - barrier = makeTransformerEngineTensor(barrier_data.data_ptr(), - barrier.shape(), - barrier.dtype()); - - // Actual call to fwd kernel - func(input_cu.data(), gamma_cu.data(), beta_cu.data(), eps, z_cu.data(), - mu_cu.data(), rsigma_cu.data(), at::cuda::getCurrentCUDAStream(), - at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin, - workspace.data(), barrier.data()); - - return {ln_out, mu, rsigma}; -} - - -std::vector layernorm_fwd_fp8_noalloc(const at::Tensor &input, - const at::Tensor &weight, - const at::Tensor &bias, - float eps, - at::Tensor scale, - at::Tensor ln_out, - at::Tensor amax, - at::Tensor scale_inv, - transformer_engine::DType otype, - const int sm_margin, - const bool zero_centered_gamma -) { - using namespace transformer_engine; - - size_t N = static_cast(input.size(0)); - size_t H = static_cast(input.size(1)); - - DType itype = GetTransformerEngineDType(input.scalar_type()); - - auto mu = at::empty({static_cast(N)}, at::CUDA(at::kFloat)); - auto rsigma = at::empty({static_cast(N)}, at::CUDA(at::kFloat)); - auto input_cu = makeTransformerEngineTensor(input); - auto gamma_cu = makeTransformerEngineTensor(weight); - auto beta_cu = makeTransformerEngineTensor(bias); - auto z_cu = makeTransformerEngineTensor(ln_out.data_ptr(), {N, H}, otype, - amax.data_ptr(), 
scale.data_ptr(), - scale_inv.data_ptr()); - auto mu_cu = makeTransformerEngineTensor(mu); - auto rsigma_cu = makeTransformerEngineTensor(rsigma); - transformer_engine::TensorWrapper workspace, barrier; - - // This call populates workspace and barrier tensors with the required config - const auto func = zero_centered_gamma ? nvte_layernorm1p_fwd : nvte_layernorm_fwd; - func(input_cu.data(), gamma_cu.data(), beta_cu.data(), eps, z_cu.data(), - mu_cu.data(), rsigma_cu.data(), at::cuda::getCurrentCUDAStream(), - at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin, - workspace.data(), barrier.data()); - - // Fill workspace and barrier - auto workspace_data = allocateSpace(workspace.shape(), - workspace.dtype()); - auto barrier_data = allocateSpace(barrier.shape(), - barrier.dtype(), - true); - workspace = makeTransformerEngineTensor(workspace_data.data_ptr(), - workspace.shape(), - workspace.dtype()); - barrier = makeTransformerEngineTensor(barrier_data.data_ptr(), - barrier.shape(), - barrier.dtype()); - - // Actual call to fwd kernel - func(input_cu.data(), gamma_cu.data(), beta_cu.data(), eps, z_cu.data(), - mu_cu.data(), rsigma_cu.data(), at::cuda::getCurrentCUDAStream(), - at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin, - workspace.data(), barrier.data()); - - return {ln_out, mu, rsigma}; -} - - -at::Tensor layernorm_fwd_fp8_inf(const at::Tensor &input, - const at::Tensor &weight, - const at::Tensor &bias, - float eps, - at::Tensor scale, - at::Tensor amax, - at::Tensor scale_inv, - transformer_engine::DType otype, - const bool zero_centered_gamma -) { - // This is a specialized version of layernorm_fwd_fp8, optimized for inference, - // which only returns the normalized output. - std::vector out = layernorm_fwd_fp8( - input, weight, bias, eps, scale, amax, scale_inv, otype, 0, zero_centered_gamma); - return out[0]; -} - - -std::vector layernorm_fwd(const at::Tensor &input, - const at::Tensor &weight, - const at::Tensor &bias, - float eps, - const int sm_margin, - const bool zero_centered_gamma -) { - using namespace transformer_engine; - - size_t N = static_cast(input.size(0)); - size_t H = static_cast(input.size(1)); - - DType itype = GetTransformerEngineDType(input.scalar_type()); - - auto ln_out = at::empty_like(input, at::CUDA(GetATenDType(itype))); - auto mu = at::empty({static_cast(N)}, at::CUDA(at::kFloat)); - auto rsigma = at::empty({static_cast(N)}, at::CUDA(at::kFloat)); - auto input_cu = makeTransformerEngineTensor(input); - auto gamma_cu = makeTransformerEngineTensor(weight); - auto beta_cu = makeTransformerEngineTensor(bias); - auto z_cu = makeTransformerEngineTensor(ln_out); - auto mu_cu = makeTransformerEngineTensor(mu); - auto rsigma_cu = makeTransformerEngineTensor(rsigma); - transformer_engine::TensorWrapper workspace, barrier; - - // This call populates workspace and barrier tensors with the required config - const auto func = zero_centered_gamma ? 
nvte_layernorm1p_fwd : nvte_layernorm_fwd; - func(input_cu.data(), gamma_cu.data(), beta_cu.data(), eps, z_cu.data(), - mu_cu.data(), rsigma_cu.data(), at::cuda::getCurrentCUDAStream(), - at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin, - workspace.data(), barrier.data()); - - // Fill workspace and barrier - auto workspace_data = allocateSpace(workspace.shape(), - workspace.dtype()); - auto barrier_data = allocateSpace(barrier.shape(), - barrier.dtype(), - true); - workspace = makeTransformerEngineTensor(workspace_data.data_ptr(), - workspace.shape(), - workspace.dtype()); - barrier = makeTransformerEngineTensor(barrier_data.data_ptr(), - barrier.shape(), - barrier.dtype()); - - // Actual call to fwd kernel - func(input_cu.data(), gamma_cu.data(), beta_cu.data(), eps, z_cu.data(), - mu_cu.data(), rsigma_cu.data(), at::cuda::getCurrentCUDAStream(), - at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin, - workspace.data(), barrier.data()); - - return {ln_out, mu, rsigma}; -} - - -std::vector layernorm_fwd_noalloc(const at::Tensor &input, - const at::Tensor &weight, - const at::Tensor &bias, - at::Tensor ln_out, - float eps, - const int sm_margin, - const bool zero_centered_gamma -) { - using namespace transformer_engine; - - size_t N = static_cast(input.size(0)); - size_t H = static_cast(input.size(1)); - - DType itype = GetTransformerEngineDType(input.scalar_type()); - - auto mu = at::empty({static_cast(N)}, at::CUDA(at::kFloat)); - auto rsigma = at::empty({static_cast(N)}, at::CUDA(at::kFloat)); - auto input_cu = makeTransformerEngineTensor(input); - auto gamma_cu = makeTransformerEngineTensor(weight); - auto beta_cu = makeTransformerEngineTensor(bias); - auto z_cu = makeTransformerEngineTensor(ln_out); - auto mu_cu = makeTransformerEngineTensor(mu); - auto rsigma_cu = makeTransformerEngineTensor(rsigma); - transformer_engine::TensorWrapper workspace, barrier; - - // This call populates workspace and barrier tensors with the required config - const auto func = zero_centered_gamma ? nvte_layernorm1p_fwd : nvte_layernorm_fwd; - func(input_cu.data(), gamma_cu.data(), beta_cu.data(), eps, z_cu.data(), - mu_cu.data(), rsigma_cu.data(), at::cuda::getCurrentCUDAStream(), - at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin, - workspace.data(), barrier.data()); - - // Fill workspace and barrier - auto workspace_data = allocateSpace(workspace.shape(), - workspace.dtype()); - auto barrier_data = allocateSpace(barrier.shape(), - barrier.dtype(), - true); - workspace = makeTransformerEngineTensor(workspace_data.data_ptr(), - workspace.shape(), - workspace.dtype()); - barrier = makeTransformerEngineTensor(barrier_data.data_ptr(), - barrier.shape(), - barrier.dtype()); - - // Actual call to fwd kernel - func(input_cu.data(), gamma_cu.data(), beta_cu.data(), eps, z_cu.data(), - mu_cu.data(), rsigma_cu.data(), at::cuda::getCurrentCUDAStream(), - at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin, - workspace.data(), barrier.data()); - - return {ln_out, mu, rsigma}; -} - - -at::Tensor layernorm_fwd_inf(const at::Tensor &input, - const at::Tensor &weight, - const at::Tensor &bias, - float eps, - const bool zero_centered_gamma -) { - // This is a specialized version of layernorm_fwd, optimized for inference, - // which only returns the normalized output. 
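-  // It forwards to layernorm_fwd with sm_margin fixed to 0 and discards the
-  // mu/rsigma statistics that the training path also returns.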
- std::vector out = layernorm_fwd(input, weight, bias, eps, 0, zero_centered_gamma); - return out[0]; -} - - -at::Tensor cast_to_fp8(const at::Tensor &input, - const at::Tensor &scale, - at::Tensor amax, - at::Tensor scale_inv, - transformer_engine::DType otype -) { - using namespace transformer_engine; - auto input_shape = input.sizes().vec(); - std::vector shape{input_shape.begin(), input_shape.end()}; - - auto output = at::empty_like(input, at::CUDA(GetATenDType(otype))); - - auto input_cu = makeTransformerEngineTensor(input); - auto output_cu = makeTransformerEngineTensor(output.data_ptr(), shape, otype, - amax.data_ptr(), scale.data_ptr(), - scale_inv.data_ptr()); - - nvte_fp8_quantize(input_cu.data(), output_cu.data(), - at::cuda::getCurrentCUDAStream()); - - return output; -} - - -void cast_to_fp8_noalloc(const at::Tensor &input, - const at::Tensor &scale, - at::Tensor output, - at::Tensor amax, - at::Tensor scale_inv, - transformer_engine::DType otype -) { - using namespace transformer_engine; - size_t N = static_cast(input.size(0)); - size_t H = static_cast(input.size(1)); - - auto input_cu = makeTransformerEngineTensor(input); - auto output_cu = makeTransformerEngineTensor(output.data_ptr(), {N, H}, otype, - amax.data_ptr(), scale.data_ptr(), - scale_inv.data_ptr()); - - nvte_fp8_quantize(input_cu.data(), output_cu.data(), - at::cuda::getCurrentCUDAStream()); - - return; -} - - -at::Tensor cast_from_fp8(const at::Tensor &input, - const at::Tensor &scale_inv, - transformer_engine::DType itype, - transformer_engine::DType otype -) { - using namespace transformer_engine; - auto input_shape = input.sizes().vec(); - std::vector shape{input_shape.begin(), input_shape.end()}; - - auto output = at::empty_like(input, at::CUDA(GetATenDType(otype))); - - auto input_cu = makeTransformerEngineTensor(input.data_ptr(), shape, itype, - nullptr, nullptr, scale_inv.data_ptr()); - auto output_cu = makeTransformerEngineTensor(output); - - nvte_fp8_dequantize(input_cu.data(), output_cu.data(), - at::cuda::getCurrentCUDAStream()); - - return output; -} - - -at::Tensor scaled_softmax_forward(at::Tensor input, - float scale_factor -) { - using namespace transformer_engine; - AT_ASSERTM(input.dim() == 4, "expected 4D tensor"); - AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) || - (input.scalar_type() == at::ScalarType::BFloat16), - "Only fp16 and bf16 are supported"); - - const int batches = input.size(0); - const int attn_heads = input.size(1); - const int query_seq_len = input.size(2); - const int key_seq_len = input.size(3); - - TORCH_CHECK(key_seq_len <= 4096); - TORCH_CHECK(query_seq_len > 1); - - // Output - auto act_options = input.options().requires_grad(false); - auto softmax_results = - torch::empty({batches, attn_heads, query_seq_len, key_seq_len}, act_options); - - auto input_cu = makeTransformerEngineTensor(input); - auto softmax_results_cu = makeTransformerEngineTensor(softmax_results); - - nvte_scaled_softmax_forward(input_cu.data(), softmax_results_cu.data(), scale_factor, - at::cuda::getCurrentCUDAStream()); - - return softmax_results; -} - - -at::Tensor scaled_softmax_backward(at::Tensor output_grad_, - at::Tensor softmax_results_, - float scale_factor -) { - using namespace transformer_engine; - - auto output_grads = output_grad_.contiguous(); - auto softmax_results = softmax_results_.contiguous(); - - AT_ASSERTM(output_grads.dim() == 4, "expected 4D tensor"); - AT_ASSERTM(softmax_results.dim() == 4, "expected 4D tensor"); - - AT_ASSERTM((output_grads.scalar_type() == 
at::ScalarType::Half) || - (output_grads.scalar_type() == at::ScalarType::BFloat16), - "Only fp16 and bf16 are supported"); - AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) || - (softmax_results.scalar_type() == at::ScalarType::BFloat16), - "Only fp16 and bf16 are supported"); - - auto output_grads_cu = makeTransformerEngineTensor(output_grads); - auto softmax_results_cu = makeTransformerEngineTensor(softmax_results); - - // Produce gradients in place. - nvte_scaled_softmax_backward( - output_grads_cu.data(), softmax_results_cu.data(), output_grads_cu.data(), - scale_factor, at::cuda::getCurrentCUDAStream()); - - return output_grads; -} - - -at::Tensor scaled_masked_softmax_forward(at::Tensor input, - at::Tensor mask, - float scale_factor -) { - using namespace transformer_engine; - - AT_ASSERTM(input.dim() == 4, "expected 4D tensor"); - AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) || - (input.scalar_type() == at::ScalarType::BFloat16), - "Only fp16 and bf16 are supported"); - AT_ASSERTM(mask.dim() == 4, "expected 4D tensor"); - if (!input.is_contiguous()) - input = input.contiguous(); - if (!mask.is_contiguous()) - mask = mask.contiguous(); - - const int batches = input.size(0); - const int pad_batches = mask.size(0); - const int attn_heads = input.size(1); - const int query_seq_len = input.size(2); - const int key_seq_len = input.size(3); - TORCH_CHECK(key_seq_len <= 4096); - TORCH_CHECK(query_seq_len > 1); - TORCH_CHECK(pad_batches == 1 || pad_batches == batches); - TORCH_CHECK(mask.size(1) == 1); - TORCH_CHECK(mask.size(2) == query_seq_len); - TORCH_CHECK(mask.size(3) == key_seq_len); - - auto act_options = input.options().requires_grad(false); - auto softmax_results = - torch::empty({batches, attn_heads, query_seq_len, key_seq_len}, act_options); - - - auto input_cu = makeTransformerEngineTensor(input); - auto mask_cu = makeTransformerEngineTensor(mask); - auto softmax_results_cu = makeTransformerEngineTensor(softmax_results); - - nvte_scaled_masked_softmax_forward( - input_cu.data(), mask_cu.data(), softmax_results_cu.data(), - scale_factor, at::cuda::getCurrentCUDAStream()); - - return softmax_results; -} - - -at::Tensor scaled_masked_softmax_backward(at::Tensor output_grad_, - at::Tensor softmax_results_, - float scale_factor -) { - using namespace transformer_engine; - - auto output_grads = output_grad_.contiguous(); - auto softmax_results = softmax_results_.contiguous(); - - AT_ASSERTM(output_grads.dim() == 4, "expected 3D tensor"); - AT_ASSERTM(softmax_results.dim() == 4, "expected 3D tensor"); - - AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) || - (output_grads.scalar_type() == at::ScalarType::BFloat16), - "Only fp16 and bf16 are supported"); - AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) || - (softmax_results.scalar_type() == at::ScalarType::BFloat16), - "Only fp16 and bf16 are supported"); - - auto output_grads_cu = makeTransformerEngineTensor(output_grads); - auto softmax_results_cu = makeTransformerEngineTensor(softmax_results); - - // Produce gradients in place. 
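    // The gradient is written into output_grads itself, avoiding an extra allocation.
    // The unmasked backward kernel is reused on purpose: masked positions received
    // (near-)zero probability in the forward pass, so for y = softmax(scale * x) the
    // per-element gradient
    //   dx_i = scale * y_i * (dy_i - sum_j dy_j * y_j)
    // already vanishes wherever y_i == 0, and the mask is not needed again here.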
- nvte_scaled_softmax_backward( - output_grads_cu.data(), softmax_results_cu.data(), output_grads_cu.data(), - scale_factor, at::cuda::getCurrentCUDAStream()); - - return output_grads; -} - - -at::Tensor scaled_upper_triang_masked_softmax_forward(at::Tensor input, - float scale_factor -) { - using namespace transformer_engine; - - AT_ASSERTM(input.dim() == 3, "expected 3D tensor"); - AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) || - (input.scalar_type() == at::ScalarType::BFloat16), - "Only fp16 and bf16 are supported"); - - const int attn_batches = input.size(0); - const int seq_len = input.size(1); - TORCH_CHECK(seq_len <= 2048); - - // Output - auto act_options = input.options().requires_grad(false); - auto softmax_results = - torch::empty({attn_batches, seq_len, seq_len}, act_options); - - auto input_cu = makeTransformerEngineTensor(input); - auto softmax_results_cu = makeTransformerEngineTensor(softmax_results); - - nvte_scaled_upper_triang_masked_softmax_forward(input_cu.data(), - softmax_results_cu.data(), - scale_factor, - at::cuda::getCurrentCUDAStream()); - - return softmax_results; -} - - -at::Tensor scaled_upper_triang_masked_softmax_backward(at::Tensor output_grads_, - at::Tensor softmax_results_, - float scale_factor -) { - using namespace transformer_engine; - - auto output_grads = output_grads_.contiguous(); - auto softmax_results = softmax_results_.contiguous(); - - AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor"); - AT_ASSERTM(softmax_results.dim() == 3, "expected 3D tensor"); - - AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) || - (output_grads.scalar_type() == at::ScalarType::BFloat16), - "Only fp16 and bf16 are supported"); - AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) || - (softmax_results.scalar_type() == at::ScalarType::BFloat16), - "Only fp16 and bf16 are supported"); - - TORCH_CHECK(output_grads.size(1) == output_grads.size(2)); - - auto output_grads_cu = makeTransformerEngineTensor(output_grads); - auto softmax_results_cu = makeTransformerEngineTensor(softmax_results); - - // Produce gradients in place. 
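    // Same in-place scheme as above, specialized for the causal (upper-triangular)
    // mask: inputs are 3-D ([attn_batches, seq, seq], with heads folded into the
    // leading dimension, hence the square-shape check on sizes 1 and 2), and entries
    // above the diagonal were zeroed by the forward softmax, so they receive zero
    // gradient here as well.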
- nvte_scaled_upper_triang_masked_softmax_backward(output_grads_cu.data(), - softmax_results_cu.data(), - output_grads_cu.data(), - scale_factor, - at::cuda::getCurrentCUDAStream()); - - return output_grads; -} - - -size_t get_cublasLt_version() { - return cublasLtGetVersion(); -} - - -bool userbuf_comm_available() { // TODO(ksivamani) check on python side -#ifdef NVTE_WITH_USERBUFFERS - return true; -#else - return false; -#endif -} - -void placeholder() {} // TODO(ksivamani) clean this up - -namespace flash_attention { - -constexpr int warp_size = 32; -constexpr int type_size = 2; // FP16 or BF16 -constexpr int nvec = sizeof(uint64_t) / type_size; -constexpr int load_size = warp_size * nvec; -constexpr int block_size = 512; - -template -__launch_bounds__(block_size) -__global__ void prepare_kernel_fwd(const T *qkvi, - T *qkv, - const size_t B, - const size_t S, - const size_t Z, - const size_t W) { - const int warpid = (blockDim.x * blockIdx.x + threadIdx.x) / warp_size; - const int id_in_warp = threadIdx.x % warp_size; - const size_t offset_input = blockIdx.y * W + warpid * 3 * W * Z + id_in_warp * nvec; - const T *my_input = qkvi + offset_input; - - const size_t s = warpid / B; - if (s >= S) return; - - const size_t b = warpid % B; - - const size_t offset_output = blockIdx.y * B * S * Z * W + - (s + b * S) * W * Z + - id_in_warp * nvec; - - T *my_output = qkv + offset_output; - - for (int i = 0; i < Z; ++i) { - uint64_t *out = reinterpret_cast(my_output + i * load_size); - *out = *reinterpret_cast(my_input + i * load_size * 3); - } -} - -template -__launch_bounds__(block_size) -__global__ void prepare_kernel_bwd(const T *q, const T *k, const T *v, - T *qkv, const size_t B, const size_t S, - const size_t Z, const size_t W) { - const T *input = blockIdx.y == 0 ? q : (blockIdx.y == 1 ? 
k : v); - - const int warpid = (blockDim.x * blockIdx.x + threadIdx.x) / warp_size; - const int id_in_warp = threadIdx.x % warp_size; - const size_t offset_input = warpid * W * Z + id_in_warp * nvec; - const T *my_input = input + offset_input; - - const size_t b = warpid / S; - if (b >= B) return; - - const size_t s = warpid % S; - - const size_t offset_output = (b + s * B) * 3 * W * Z + - id_in_warp * nvec + blockIdx.y * W; - - T *my_output = qkv + offset_output; - - for (int i = 0; i < Z; ++i) { - uint64_t *out = reinterpret_cast(my_output + i * load_size * 3); - *out = *reinterpret_cast(my_input + i * load_size); - } -} - -} // namespace flash_attention - -at::Tensor fa_prepare_fwd(at::Tensor qkvi) { - NVTE_CHECK(qkvi.dim() == 4, "Expected 4-dim tensor."); - NVTE_CHECK(qkvi.scalar_type() == at::ScalarType::Half || - qkvi.scalar_type() == at::ScalarType::BFloat16); - NVTE_CHECK(qkvi.size(3) % flash_attention::load_size == 0); - NVTE_CHECK(qkvi.size(3) == flash_attention::load_size); - NVTE_CHECK(qkvi.stride(3) == 1, "Wrong stride."); - NVTE_CHECK(qkvi.stride(2) == 3 * qkvi.size(3), "Wrong stride."); - NVTE_CHECK(qkvi.stride(1) == 3 * qkvi.size(3) * qkvi.size(2), "Wrong stride."); - NVTE_CHECK(qkvi.stride(0) == 3 * qkvi.size(3) * qkvi.size(2) * qkvi.size(1), "Wrong stride."); - - // [s, b, n, h * 3] -> [3, b, s, n, h] - std::vector shape = {3, qkvi.size(1), qkvi.size(0), qkvi.size(2), qkvi.size(3)}; - at::Tensor qkv = at::empty(shape, at::CUDA(qkvi.scalar_type())); - - size_t warps = qkvi.size(0) * qkvi.size(1); - size_t warps_per_block = flash_attention::block_size / flash_attention::warp_size; - size_t blocks = (warps + warps_per_block - 1) / warps_per_block; - dim3 grid(blocks, 3); - int threads = flash_attention::block_size; - if (qkvi.scalar_type() == at::ScalarType::Half) { - using dtype = at::Half; - flash_attention::prepare_kernel_fwd<<>>( - qkvi.data_ptr(), - qkv.data_ptr(), - shape[1], - shape[2], - shape[3], - shape[4]); - } else { - using dtype = at::BFloat16; - flash_attention::prepare_kernel_fwd<<>>( - qkvi.data_ptr(), - qkv.data_ptr(), - shape[1], - shape[2], - shape[3], - shape[4]); - } - - return qkv; -} - -at::Tensor fa_prepare_bwd(at::Tensor q, at::Tensor k, at::Tensor v) { - NVTE_CHECK(q.is_contiguous()); - NVTE_CHECK(k.is_contiguous()); - NVTE_CHECK(v.is_contiguous()); - NVTE_CHECK(q.dim() == 4, "Expected 4-dim tensor."); - NVTE_CHECK(k.dim() == 4, "Expected 4-dim tensor."); - NVTE_CHECK(v.dim() == 4, "Expected 4-dim tensor."); - NVTE_CHECK(q.scalar_type() == at::ScalarType::Half || - q.scalar_type() == at::ScalarType::BFloat16); - NVTE_CHECK(k.scalar_type() == q.scalar_type()); - NVTE_CHECK(v.scalar_type() == q.scalar_type()); - NVTE_CHECK(q.size(3) % flash_attention::load_size == 0); - NVTE_CHECK(q.size(3) == flash_attention::load_size); - NVTE_CHECK(k.size(3) % flash_attention::load_size == 0); - NVTE_CHECK(k.size(3) == flash_attention::load_size); - NVTE_CHECK(v.size(3) % flash_attention::load_size == 0); - NVTE_CHECK(v.size(3) == flash_attention::load_size); - - // 3 x [s, b, n, h] -> [b, s, n, 3 * h] - - std::vector shape = {q.size(1), q.size(0), q.size(2), 3 * q.size(3)}; - at::Tensor qkv = at::empty(shape, at::CUDA(q.scalar_type())); - - size_t warps = q.size(0) * q.size(1); - size_t warps_per_block = flash_attention::block_size / flash_attention::warp_size; - size_t blocks = (warps + warps_per_block - 1) / warps_per_block; - dim3 grid(blocks, 3); - int threads = flash_attention::block_size; - if (q.scalar_type() == at::ScalarType::Half) { - using dtype = 
at::Half; - flash_attention::prepare_kernel_bwd<<>>( - q.data_ptr(), - k.data_ptr(), - v.data_ptr(), - qkv.data_ptr(), - q.size(0), - q.size(1), - q.size(2), - q.size(3)); - } else { - using dtype = at::BFloat16; - flash_attention::prepare_kernel_bwd<<>>( - q.data_ptr(), - k.data_ptr(), - v.data_ptr(), - qkv.data_ptr(), - q.size(0), - q.size(1), - q.size(2), - q.size(3)); - } - - return qkv; -} - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - // Softmax functions - m.def("scaled_softmax_forward", &scaled_softmax_forward, "Scaled Softmax FWD"); - m.def("scaled_softmax_backward", &scaled_softmax_backward, "Scaled Softmax BWD"); - m.def("scaled_masked_softmax_forward", &scaled_masked_softmax_forward, - "Scaled Masked Softmax FWD"); - m.def("scaled_masked_softmax_backward", &scaled_masked_softmax_backward, - "Scaled Masked Softmax BWD"); - m.def("scaled_upper_triang_masked_softmax_forward", - &scaled_upper_triang_masked_softmax_forward, - "Scaled Upper-Triangular Masked Softmax FWD"); - m.def("scaled_upper_triang_masked_softmax_backward", - &scaled_upper_triang_masked_softmax_backward, - "Scaled Upper-Triangular Masked Softmax BWD"); - - // Other granular functions - m.def("layernorm_fwd_fp8", &layernorm_fwd_fp8, "LN FWD FP8"); - m.def("layernorm_fwd_fp8_noalloc", &layernorm_fwd_fp8_noalloc, "LN FWD FP8"); - m.def("layernorm_bwd", &layernorm_bwd, "LN BWD"); - m.def("layernorm_fwd", &layernorm_fwd, "LN FWD"); - m.def("layernorm_fwd_noalloc", &layernorm_fwd_noalloc, "LN FWD"); - m.def("fused_cast_transpose", &fused_cast_transpose, "Fused Cast + Transpose"); - m.def("fused_cast_transpose_bgrad", &fused_cast_transpose_bgrad, - "Fused Cast + Transpose + BGRAD"); - m.def("fused_fp8_transpose_bgrad", &fused_fp8_transpose_bgrad, - "Fused FP8 Transpose + BGRAD"); - m.def("fused_cast_transpose_bgrad_dgelu", &fused_cast_transpose_bgrad_dgelu, - "Fused Cast + Transpose + BGRAD + DGELU"); - m.def("fused_multi_cast_transpose", &fused_multi_cast_transpose, - "Fused Multi-tensor Cast + Transpose"); - m.def("cast_to_fp8", &cast_to_fp8, "Cast to FP8"); - m.def("cast_to_fp8_noalloc", &cast_to_fp8_noalloc, "Cast to FP8"); - m.def("cast_from_fp8", &cast_from_fp8, "Cast from FP8"); - m.def("te_gemm", &te_gemm, "CublasLt GEMM"); - m.def("fused_attn_fwd_qkvpacked", &fused_attn_fwd_qkvpacked, - "Fused Attention FP8/BF16/FP16 FWD with packed QKV"); - m.def("fused_attn_bwd_qkvpacked", &fused_attn_bwd_qkvpacked, - "Fused Attention FP8/BF16/FP16 BWD with packed QKV"); - m.def("fused_attn_fwd_kvpacked", &fused_attn_fwd_kvpacked, - "Fused Attention FP8/BF16/FP16 FWD with packed KV"); - m.def("fused_attn_bwd_kvpacked", &fused_attn_bwd_kvpacked, - "Fused Attention FP8/BF16/FP16 BWD with packed KV"); - m.def("fp8_transpose", &fp8_transpose, "Transpose with FP8 I/O"); - m.def("gelu", &gelu, "GeLU with FP8 output"); - m.def("relu", &relu, "ReLU with FP8 output"); - m.def("geglu", &geglu, "GeGLU with FP8 output"); - m.def("reglu", ®lu, "ReGLU with FP8 output"); - m.def("swiglu", &swiglu, "SwiGLU with FP8 output"); - m.def("dgelu", &dgelu, "Backward of GeLU"); - m.def("drelu", &drelu, "Backward of ReLU"); - m.def("dgeglu", &dgeglu, "Backward of GeGLU"); - m.def("dreglu", &dreglu, "Backward of ReGLU"); - m.def("dswiglu", &dswiglu, "Backward of SwiGLU"); - m.def("fa_prepare_fwd", &fa_prepare_fwd, "Prepare QKV for Flash Attention"); - m.def("fa_prepare_bwd", &fa_prepare_bwd, "Backward of QKV preparation for Flash Attention"); - m.def("get_fused_attn_backend", &get_fused_attn_backend, "Get Fused Attention backend"); - - // Misc - 
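  // A minimal sketch (hypothetical helper, delayed-scaling margin omitted) of how the
  // scale / scale_inv fields of the FP8TensorMeta structure bound under "Data
  // structures" below relate to a recorded amax, assuming 448.f as the largest
  // representable magnitude of FP8 E4M3:
  //   quantize:   q = saturate(x * scale)
  //   dequantize: x ~= q * scale_inv, with scale_inv = 1 / scale
  auto example_fp8_scale = [](float amax, float fp8_max = 448.f) -> float {
    return (amax > 0.f) ? fp8_max / amax : 1.f;
  };
  (void)example_fp8_scale;  // illustrative only, intentionally unused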
m.def("get_cublasLt_version", &get_cublasLt_version, "Get cublasLt version"); - m.def("userbuf_comm_available", &userbuf_comm_available, "If userbuf backend is available"); - - // Data structures - py::class_(m, "FP8TensorMeta") - .def(py::init<>()) - .def_readwrite("scale", &transformer_engine::FP8TensorMeta::scale) - .def_readwrite("scale_inv", &transformer_engine::FP8TensorMeta::scale_inv) - .def_readwrite("amax_history", &transformer_engine::FP8TensorMeta::amax_history); - -#ifdef NVTE_WITH_USERBUFFERS - py::enum_(m, "UbufOverlapAlgo") - .value("BULK_OVERLAP_AG", ubuf::UBOverlapAlgo::BULK_OVERLAP_AG) - .value("BULK_OVERLAP_RS", ubuf::UBOverlapAlgo::BULK_OVERLAP_RS) - .value("SPLIT_PIPELINED_RS", ubuf::UBOverlapAlgo::SPLIT_PIPELINED_RS) - .value("SPLIT_PIPELINED_AG", ubuf::UBOverlapAlgo::SPLIT_PIPELINED_AG); - - py::class_(m, "UbufCommOverlap") - .def(py::init()) - .def("bulk_overlap", &ubuf::UbufCommOverlap::bulk_overlap) - .def("split_overlap_rs", &ubuf::UbufCommOverlap::split_overlap_rs) - .def("copy_input_to_ubuf", &ubuf::UbufCommOverlap::copy_input_to_ubuf) - .def("get_ubuf_output", &ubuf::UbufCommOverlap::get_ubuf_output); - - py::class_(m, "UbufP2PCommOverlap") - .def(py::init()) - .def("split_overlap_ag", &ubuf::UbufP2PCommOverlap::split_overlap_ag) - .def("copy_input_to_ubuf", &ubuf::UbufP2PCommOverlap::copy_input_to_ubuf) - .def("get_ubuf_output", &ubuf::UbufP2PCommOverlap::get_ubuf_output); -#else // NVTE_WITH_USERBUFFERS - m.def("UbufOverlapAlgo", &placeholder, "Dummy function for python side annotations"); - m.def("UbufCommOverlap", &placeholder, "Dummy function for python side annotations"); - m.def("UbufP2PCommOverlap", &placeholder, "Dummy function for python side annotations"); -#endif // NVTE_WITH_USERBUFFERS - - py::enum_(m, "DType", py::module_local()) - .value("kByte", transformer_engine::DType::kByte) - .value("kInt32", transformer_engine::DType::kInt32) - .value("kFloat32", transformer_engine::DType::kFloat32) - .value("kFloat16", transformer_engine::DType::kFloat16) - .value("kBFloat16", transformer_engine::DType::kBFloat16) - .value("kFloat8E4M3", transformer_engine::DType::kFloat8E4M3) - .value("kFloat8E5M2", transformer_engine::DType::kFloat8E5M2); - - py::enum_(m, "FP8FwdTensors") - .value("GEMM1_INPUT", transformer_engine::FP8FwdTensors::GEMM1_INPUT) - .value("GEMM1_WEIGHT", transformer_engine::FP8FwdTensors::GEMM1_WEIGHT) - .value("GEMM1_OUTPUT", transformer_engine::FP8FwdTensors::GEMM1_OUTPUT) - .value("GEMM2_INPUT", transformer_engine::FP8FwdTensors::GEMM2_INPUT) - .value("GEMM2_WEIGHT", transformer_engine::FP8FwdTensors::GEMM2_WEIGHT) - .value("GEMM2_OUTPUT", transformer_engine::FP8FwdTensors::GEMM2_OUTPUT) - .value("GEMM3_INPUT", transformer_engine::FP8FwdTensors::GEMM3_INPUT) - .value("GEMM3_WEIGHT", transformer_engine::FP8FwdTensors::GEMM3_WEIGHT) - .value("GEMM3_OUTPUT", transformer_engine::FP8FwdTensors::GEMM3_OUTPUT); - - py::enum_(m, "FP8BwdTensors") - .value("GRAD_OUTPUT1", transformer_engine::FP8BwdTensors::GRAD_OUTPUT1) - .value("GRAD_INPUT1", transformer_engine::FP8BwdTensors::GRAD_INPUT1) - .value("GRAD_OUTPUT2", transformer_engine::FP8BwdTensors::GRAD_OUTPUT2) - .value("GRAD_INPUT2", transformer_engine::FP8BwdTensors::GRAD_INPUT2) - .value("GRAD_OUTPUT3", transformer_engine::FP8BwdTensors::GRAD_OUTPUT3) - .value("GRAD_INPUT3", transformer_engine::FP8BwdTensors::GRAD_INPUT3); - - py::enum_(m, "NVTE_Bias_Type") - .value("NVTE_NO_BIAS", NVTE_Bias_Type::NVTE_NO_BIAS) - .value("NVTE_PRE_SCALE_BIAS", NVTE_Bias_Type::NVTE_PRE_SCALE_BIAS) - 
.value("NVTE_POST_SCALE_BIAS", NVTE_Bias_Type::NVTE_POST_SCALE_BIAS); - - py::enum_(m, "NVTE_Mask_Type") - .value("NVTE_NO_MASK", NVTE_Mask_Type::NVTE_NO_MASK) - .value("NVTE_PADDING_MASK", NVTE_Mask_Type::NVTE_PADDING_MASK) - .value("NVTE_CAUSAL_MASK", NVTE_Mask_Type::NVTE_CAUSAL_MASK); - - py::enum_(m, "NVTE_QKV_Layout") - .value("NVTE_NOT_INTERLEAVED", NVTE_QKV_Layout::NVTE_NOT_INTERLEAVED) - .value("NVTE_QKV_INTERLEAVED", NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED) - .value("NVTE_KV_INTERLEAVED", NVTE_QKV_Layout::NVTE_KV_INTERLEAVED); - - py::enum_(m, "NVTE_Fused_Attn_Backend") - .value("NVTE_F16_max512_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) - .value("NVTE_F16_arbitrary_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) - .value("NVTE_FP8", NVTE_Fused_Attn_Backend::NVTE_FP8) - .value("NVTE_No_Backend", NVTE_Fused_Attn_Backend::NVTE_No_Backend); -} diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 1467397c63..d06906b5a2 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -106,6 +106,10 @@ std::vector fused_attn_bwd_kvpacked( c10::optional amax_dP, c10::optional amax_dQKV); +at::Tensor fa_prepare_fwd(at::Tensor qkvi); + +at::Tensor fa_prepare_bwd(at::Tensor q, at::Tensor k, at::Tensor v); + void te_gemm(at::Tensor A, at::Tensor A_scale_inverse, transformer_engine::DType A_type, @@ -318,6 +322,77 @@ at::Tensor layernorm_fwd_inf(const at::Tensor &input, const bool zero_centered_gamma ); +/*************************************************************************************************** + * RMSNorm + **************************************************************************************************/ + +std::vector rmsnorm_bwd(const at::Tensor &dz, + const at::Tensor &x, + const at::Tensor &rsigma, + const at::Tensor &gamma, + const int sm_margin, + const bool zero_centered_gamma +); + + +std::vector rmsnorm_fwd_fp8(const at::Tensor &input, + const at::Tensor &weight, + float eps, + at::Tensor scale, + at::Tensor amax, + at::Tensor scale_inv, + transformer_engine::DType otype, + const int sm_margin, + const bool zero_centered_gamma +); + +std::vector rmsnorm_fwd_fp8_noalloc(const at::Tensor &input, + const at::Tensor &weight, + float eps, + at::Tensor scale, + at::Tensor ln_out, + at::Tensor amax, + at::Tensor scale_inv, + transformer_engine::DType otype, + const int sm_margin, + const bool zero_centered_gamma +); + +at::Tensor rmsnorm_fwd_fp8_inf(const at::Tensor &input, + const at::Tensor &weight, + float eps, + at::Tensor scale, + at::Tensor amax, + at::Tensor scale_inv, + transformer_engine::DType otype, + const bool zero_centered_gamma +); + +std::vector rmsnorm_fwd(const at::Tensor &input, + const at::Tensor &weight, + float eps, + const int sm_margin, + const bool zero_centered_gamma +); + +std::vector rmsnorm_fwd_noalloc(const at::Tensor &input, + const at::Tensor &weight, + at::Tensor ln_out, + float eps, + const int sm_margin, + const bool zero_centered_gamma +); + +at::Tensor rmsnorm_fwd_inf(const at::Tensor &input, + const at::Tensor &weight, + float eps, + const bool zero_centered_gamma +); + +/*************************************************************************************************** + * Cast + **************************************************************************************************/ + at::Tensor cast_to_fp8(const at::Tensor &input, const at::Tensor &scale, at::Tensor amax, @@ -374,3 +449,9 @@ at::Tensor 
scaled_upper_triang_masked_softmax_backward(at::Tensor output_grads_, at::Tensor softmax_results_, float scale_factor ); + +size_t get_cublasLt_version(); + +bool userbuf_comm_available(); + +void placeholder(); diff --git a/transformer_engine/pytorch/csrc/extensions/activation.cu b/transformer_engine/pytorch/csrc/extensions/activation.cu new file mode 100644 index 0000000000..05c61acc59 --- /dev/null +++ b/transformer_engine/pytorch/csrc/extensions/activation.cu @@ -0,0 +1,267 @@ +/************************************************************************* + * Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include "extensions.h" + +at::Tensor gelu(at::Tensor input, + at::Tensor scale, + at::Tensor amax, + at::Tensor scale_inv, + transformer_engine::DType otype +) { + using namespace transformer_engine; + + size_t N = static_cast(input.size(-1)); + size_t M = input.numel() / N; + + auto output = + allocateTorchTensor(M, + N, + otype); + + auto itype = GetTransformerEngineDType(input.scalar_type()); + auto input_cu = makeTransformerEngineTensor(input.data_ptr(), {M, N}, itype); + auto output_cu = makeTransformerEngineTensor(output.data_ptr(), {M, N}, otype, + amax.data_ptr(), scale.data_ptr(), + scale_inv.data_ptr()); + + nvte_gelu(input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream()); + + return output; +} + +at::Tensor dgelu(at::Tensor grad, + at::Tensor input, + transformer_engine::DType otype +) { + using namespace transformer_engine; + + size_t N = static_cast(input.size(-1)); + size_t M = input.numel() / N; + + auto output = + allocateTorchTensor(M, + N, + otype); + + auto itype = GetTransformerEngineDType(input.scalar_type()); + auto gtype = GetTransformerEngineDType(grad.scalar_type()); + auto input_cu = makeTransformerEngineTensor(input.data_ptr(), {M, N}, itype); + auto grad_cu = makeTransformerEngineTensor(grad.data_ptr(), {M, N}, gtype); + auto output_cu = makeTransformerEngineTensor(output.data_ptr(), {M, N}, otype); + + nvte_dgelu(grad_cu.data(), input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream()); + + return output; +} + +at::Tensor relu(at::Tensor input, + at::Tensor scale, + at::Tensor amax, + at::Tensor scale_inv, + transformer_engine::DType otype +) { + using namespace transformer_engine; + + size_t N = static_cast(input.size(-1)); + size_t M = static_cast(input.numel()) / N; + + auto output = + allocateTorchTensor(M, + N, + otype); + + auto itype = GetTransformerEngineDType(input.scalar_type()); + auto input_cu = makeTransformerEngineTensor(input.data_ptr(), {M, N}, itype); + auto output_cu = makeTransformerEngineTensor(output.data_ptr(), {M, N}, otype, + amax.data_ptr(), scale.data_ptr(), + scale_inv.data_ptr()); + + nvte_relu(input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream()); + + return output; +} + +at::Tensor drelu(at::Tensor grad, + at::Tensor input, + transformer_engine::DType otype +) { + using namespace transformer_engine; + + size_t N = static_cast(input.size(-1)); + size_t M = input.numel() / N; + + auto output = + allocateTorchTensor(M, + N, + otype); + + auto itype = GetTransformerEngineDType(input.scalar_type()); + auto gtype = GetTransformerEngineDType(grad.scalar_type()); + auto input_cu = makeTransformerEngineTensor(input.data_ptr(), {M, N}, itype); + auto grad_cu = makeTransformerEngineTensor(grad.data_ptr(), {M, N}, gtype); + auto output_cu = 
makeTransformerEngineTensor(output.data_ptr(), {M, N}, otype); + + nvte_drelu(grad_cu.data(), input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream()); + + return output; +} + +at::Tensor geglu(at::Tensor input, + at::Tensor scale, + at::Tensor amax, + at::Tensor scale_inv, + transformer_engine::DType otype +) { + using namespace transformer_engine; + + size_t N = static_cast(input.size(-1)); + size_t M = input.numel() / N; + + auto output = + allocateTorchTensor(M, + N / 2, + otype); + + auto itype = GetTransformerEngineDType(input.scalar_type()); + auto input_cu = makeTransformerEngineTensor(input.data_ptr(), {M, N}, itype); + auto output_cu = makeTransformerEngineTensor(output.data_ptr(), {M, N / 2}, otype, + amax.data_ptr(), scale.data_ptr(), + scale_inv.data_ptr()); + + nvte_geglu(input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream()); + + return output; +} + +at::Tensor dgeglu(at::Tensor grad, + at::Tensor input, + transformer_engine::DType otype +) { + using namespace transformer_engine; + + size_t N = static_cast(input.size(-1)); + size_t M = input.numel() / N; + + auto output = + allocateTorchTensor(M, + N, + otype); + + auto itype = GetTransformerEngineDType(input.scalar_type()); + auto gtype = GetTransformerEngineDType(grad.scalar_type()); + auto input_cu = makeTransformerEngineTensor(input.data_ptr(), {M, N}, itype); + auto grad_cu = makeTransformerEngineTensor(grad.data_ptr(), {M, N / 2}, gtype); + auto output_cu = makeTransformerEngineTensor(output.data_ptr(), {M, N}, otype); + + nvte_dgeglu(grad_cu.data(), input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream()); + + return output; +} + +at::Tensor reglu(at::Tensor input, + at::Tensor scale, + at::Tensor amax, + at::Tensor scale_inv, + transformer_engine::DType otype +) { + using namespace transformer_engine; + + size_t N = static_cast(input.size(-1)); + size_t M = input.numel() / N; + + auto output = + allocateTorchTensor(M, + N / 2, + otype); + + auto itype = GetTransformerEngineDType(input.scalar_type()); + auto input_cu = makeTransformerEngineTensor(input.data_ptr(), {M, N}, itype); + auto output_cu = makeTransformerEngineTensor(output.data_ptr(), {M, N / 2}, otype, + amax.data_ptr(), scale.data_ptr(), + scale_inv.data_ptr()); + + nvte_reglu(input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream()); + + return output; +} + +at::Tensor dreglu(at::Tensor grad, + at::Tensor input, + transformer_engine::DType otype +) { + using namespace transformer_engine; + + size_t N = static_cast(input.size(-1)); + size_t M = input.numel() / N; + + auto output = + allocateTorchTensor(M, + N, + otype); + + auto itype = GetTransformerEngineDType(input.scalar_type()); + auto gtype = GetTransformerEngineDType(grad.scalar_type()); + auto input_cu = makeTransformerEngineTensor(input.data_ptr(), {M, N}, itype); + auto grad_cu = makeTransformerEngineTensor(grad.data_ptr(), {M, N / 2}, gtype); + auto output_cu = makeTransformerEngineTensor(output.data_ptr(), {M, N}, otype); + + nvte_dreglu(grad_cu.data(), input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream()); + + return output; +} + +at::Tensor swiglu(at::Tensor input, + at::Tensor scale, + at::Tensor amax, + at::Tensor scale_inv, + transformer_engine::DType otype +) { + using namespace transformer_engine; + + size_t N = static_cast(input.size(-1)); + size_t M = input.numel() / N; + + auto output = + allocateTorchTensor(M, + N / 2, + otype); + + auto itype = GetTransformerEngineDType(input.scalar_type()); + auto input_cu = 
makeTransformerEngineTensor(input.data_ptr(), {M, N}, itype); + auto output_cu = makeTransformerEngineTensor(output.data_ptr(), {M, N / 2}, otype, + amax.data_ptr(), scale.data_ptr(), + scale_inv.data_ptr()); + + nvte_swiglu(input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream()); + + return output; +} + +at::Tensor dswiglu(at::Tensor grad, + at::Tensor input, + transformer_engine::DType otype +) { + using namespace transformer_engine; + + size_t N = static_cast(input.size(-1)); + size_t M = input.numel() / N; + + auto output = + allocateTorchTensor(M, + N, + otype); + + auto itype = GetTransformerEngineDType(input.scalar_type()); + auto gtype = GetTransformerEngineDType(grad.scalar_type()); + auto input_cu = makeTransformerEngineTensor(input.data_ptr(), {M, N}, itype); + auto grad_cu = makeTransformerEngineTensor(grad.data_ptr(), {M, N / 2}, gtype); + auto output_cu = makeTransformerEngineTensor(output.data_ptr(), {M, N}, otype); + + nvte_dswiglu(grad_cu.data(), input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream()); + + return output; +} diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cu b/transformer_engine/pytorch/csrc/extensions/attention.cu new file mode 100644 index 0000000000..4904fbade5 --- /dev/null +++ b/transformer_engine/pytorch/csrc/extensions/attention.cu @@ -0,0 +1,876 @@ +/************************************************************************* + * Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include "extensions.h" + +constexpr int block_size = 512; +constexpr int ctas_per_sm = 4; + +// get the fused attention backend +NVTE_Fused_Attn_Backend get_fused_attn_backend( + const transformer_engine::DType q_dtype, + const transformer_engine::DType kv_dtype, + NVTE_QKV_Layout qkv_layout, + NVTE_Bias_Type bias_type, + NVTE_Mask_Type attn_mask_type, + float p_dropout, size_t max_seqlen_q, + size_t max_seqlen_kv, size_t head_dim) { + NVTE_Fused_Attn_Backend fused_attention_backend = + nvte_get_fused_attn_backend( + static_cast(q_dtype), static_cast(kv_dtype), + qkv_layout, bias_type, attn_mask_type, + p_dropout, max_seqlen_q, max_seqlen_kv, head_dim); + return fused_attention_backend; +} + +// fast zero-fills of tensors +template +__global__ void __launch_bounds__(block_size) mha_fill_kernel(scalar_t* out_tensor, + const int32_t* const start_row, + const size_t num_rows) { + size_t row_stride = gridDim.y * blockDim.x; + size_t row_index = blockIdx.x + static_cast(start_row[0]); + size_t col_index = blockIdx.y * blockDim.x + threadIdx.x; + while (row_index < num_rows) { + out_tensor[row_index*row_stride + col_index] = 0; + row_index += gridDim.x; + } +} + +// fast zero-fills of tensors +void mha_fill(const at::Tensor &self, const at::Tensor &start_index) { + auto max_tokens = self.size(0); + auto self_2d = self.view({max_tokens, -1}); + auto fcd_size = self_2d.size(1); + TORCH_CHECK(self.is_contiguous(), "input not contiguous"); + TORCH_CHECK(fcd_size % block_size == 0, "input size not aligned to block size"); + const int num_mp = at::cuda::getCurrentDeviceProperties()->multiProcessorCount; + uint64_t num_blk_y = (uint64_t)(fcd_size / block_size); + uint64_t num_blk_x = (uint64_t)((num_mp * ctas_per_sm + num_blk_y - 1) / num_blk_y); + dim3 dim_grid(num_blk_x, num_blk_y); + dim3 dim_block(block_size); + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2( + at::ScalarType::Half, 
at::ScalarType::BFloat16, + self_2d.scalar_type(), "mha_fill", [&]() { + mha_fill_kernel<<>>( + self_2d.data_ptr(), + static_cast(start_index.data_ptr()), + max_tokens); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + }); +} + +// extract seed and offset from PhiloxCudaState +__global__ void unpack(at::PhiloxCudaState arg, int64_t* rng_state_ptr) { + if (arg.captured_) { + rng_state_ptr[0] = static_cast(*arg.seed_.ptr); + rng_state_ptr[1] = static_cast( + *(arg.offset_.ptr) + static_cast(arg.offset_intragraph_)); + } else { + rng_state_ptr[0] = static_cast(arg.seed_.val); + rng_state_ptr[1] = static_cast(arg.offset_.val); + } +} + +// extract PhiloxCudaState from CUDA random number generator +at::PhiloxCudaState init_philox_state( + at::CUDAGeneratorImpl* gen, + size_t elts_per_thread) { + at::PhiloxCudaState philox_args; + std::lock_guard lock(gen->mutex_); + philox_args = gen->philox_cuda_state(elts_per_thread); + return philox_args; +} + +// fused attention FWD with packed QKV +std::vector fused_attn_fwd_qkvpacked( + size_t b, size_t max_seqlen, size_t total_seqs, + size_t h, size_t d, + bool is_training, float attn_scale, float p_dropout, bool set_zero, + NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, + const at::Tensor cu_seqlens, + const at::Tensor QKV, + const transformer_engine::DType qkv_type, + const c10::optional descale_QKV, + const c10::optional scale_S, + const c10::optional scale_O, + c10::optional amax_S, + c10::optional amax_O, + const c10::optional Bias, + const c10::optional rng_gen, + size_t rng_elts_per_thread) { + using namespace transformer_engine; + + // create output tensor O + auto options = torch::TensorOptions().dtype(GetATenDType(qkv_type)).device(torch::kCUDA); + auto O = torch::empty({static_cast(total_seqs), + static_cast(h), static_cast(d)}, options); + if (set_zero && (h * d % block_size == 0)) { + mha_fill(O, cu_seqlens.index({torch::indexing::Slice(-1, torch::indexing::None)})); + } else { + O.fill_(0); + } + + // construct NVTE tensors + TensorWrapper te_QKV, te_S, te_O, te_Bias, te_cu_seqlens; + if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) { + // FP8 + if ((!descale_QKV.has_value()) || (!scale_S.has_value()) || (!scale_O.has_value()) + || (!amax_S.has_value()) || (!amax_O.has_value())) { + std::string err_tensors = "descale_QKV, scale_S, scale_O, amax_S and amax_O"; + NVTE_ERROR(err_tensors + std::string("are required for FP8 operation. \n")); + } + te_QKV = makeTransformerEngineTensor(QKV.data_ptr(), {total_seqs, 3, h, d}, + qkv_type, nullptr, nullptr, descale_QKV.value().data_ptr()); + at::Tensor descale_S = torch::empty_like(scale_S.value()); + te_S = makeTransformerEngineTensor(nullptr, {0}, + DType::kFloat32, amax_S.value().data_ptr(), + scale_S.value().data_ptr(), descale_S.data_ptr()); + te_O = makeTransformerEngineTensor(O.data_ptr(), {total_seqs, h, d}, + qkv_type, amax_O.value().data_ptr(), scale_O.value().data_ptr(), nullptr); + } else if (qkv_type == DType::kBFloat16 || qkv_type == DType::kFloat16) { + // BF16 or FP16 + te_QKV = makeTransformerEngineTensor(QKV.data_ptr(), {total_seqs, 3, h, d}, + qkv_type, nullptr, nullptr, nullptr); + te_S = makeTransformerEngineTensor(nullptr, {0}, + DType::kFloat32, nullptr, nullptr, nullptr); + te_O = makeTransformerEngineTensor(O.data_ptr(), {total_seqs, h, d}, + qkv_type, nullptr, nullptr, nullptr); + } else { + NVTE_ERROR("Fused attention only supports FP8 and BF16/FP16 data types. 
\n"); + } + if ((bias_type != NVTE_NO_BIAS) && (Bias.has_value())) { + auto bias_shape = Bias.value().sizes().vec(); + std::vector shape{bias_shape.begin(), bias_shape.end()}; + te_Bias = makeTransformerEngineTensor(Bias.value().data_ptr(), shape, + DType::kFloat32, nullptr, nullptr, nullptr); + } + te_cu_seqlens = makeTransformerEngineTensor(cu_seqlens.data_ptr(), {b+1}, + DType::kInt32, nullptr, nullptr, nullptr); + + // extract random number generator seed and offset + auto gen = at::get_generator_or_default( + rng_gen, at::cuda::detail::getDefaultCUDAGenerator()); + at::PhiloxCudaState philox_args = init_philox_state(gen, rng_elts_per_thread); + auto rng_state = torch::empty({2}, options.dtype(torch::kInt64)); + unpack<<<1, 1, 0, at::cuda::getCurrentCUDAStream()>>>( + philox_args, static_cast(rng_state.data_ptr())); + auto te_rng_state = makeTransformerEngineTensor(rng_state); + + // create auxiliary output tensors + NVTETensorPack nvte_aux_tensor_pack; + nvte_tensor_pack_create(&nvte_aux_tensor_pack); + + // create workspace + TensorWrapper workspace; + + // populate tensors with appropriate shapes and dtypes + nvte_fused_attn_fwd_qkvpacked( + te_QKV.data(), + te_Bias.data(), + te_S.data(), + te_O.data(), + &nvte_aux_tensor_pack, + te_cu_seqlens.data(), + te_rng_state.data(), + max_seqlen, + is_training, attn_scale, p_dropout, + qkv_layout, bias_type, attn_mask_type, + workspace.data(), + at::cuda::getCurrentCUDAStream()); + + // allocate memory for workspace and auxiliary output tensors + auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype()); + workspace = makeTransformerEngineTensor( + workspace_data.data_ptr(), + workspace.shape(), workspace.dtype()); + + // output_tensors = [O, nvte_aux_tensor_pack.tensors] + std::vector output_tensors; + output_tensors.push_back(O); + for (size_t i = 0; i < nvte_aux_tensor_pack.size; ++i) { + auto tensor = reinterpret_cast(nvte_aux_tensor_pack.tensors[i]); + // allocate memory for nvte_aux_tensor_pack.tensors + auto output_tensor = allocateSpace(tensor->data.shape, tensor->data.dtype, false); + output_tensors.push_back(output_tensor); + tensor->data.dptr = output_tensor.data_ptr(); + } + + // execute the kernel + nvte_fused_attn_fwd_qkvpacked( + te_QKV.data(), + te_Bias.data(), + te_S.data(), + te_O.data(), + &nvte_aux_tensor_pack, + te_cu_seqlens.data(), + te_rng_state.data(), + max_seqlen, + is_training, attn_scale, p_dropout, + qkv_layout, bias_type, attn_mask_type, + workspace.data(), + at::cuda::getCurrentCUDAStream()); + + // destroy tensor wrappers, but not allocated memory + nvte_tensor_pack_destroy(&nvte_aux_tensor_pack); + + // if training, [O, softmax-related tensors, rng_state]; if inference, [O] + return output_tensors; +} + +// fused attention BWD with packed QKV +std::vector fused_attn_bwd_qkvpacked( + size_t b, size_t max_seqlen, size_t total_seqs, + size_t h, size_t d, + float attn_scale, float p_dropout, bool set_zero, + NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, + const at::Tensor cu_seqlens, + const at::Tensor QKV, + const at::Tensor O, + const at::Tensor dO, + const transformer_engine::DType qkv_type, + const std::vector Aux_CTX_Tensors, + const c10::optional descale_QKV, + const c10::optional descale_S, + const c10::optional descale_O, + const c10::optional descale_dO, + const c10::optional scale_S, + const c10::optional scale_dP, + const c10::optional scale_dQKV, + c10::optional amax_dP, + c10::optional amax_dQKV) { + using namespace transformer_engine; + + // 
create output tensor dQKV + at::Tensor dQKV = torch::empty_like(QKV); + auto max_tokens = dQKV.size(0); + auto self_2d = dQKV.view({max_tokens, -1}); + auto fcd_size = self_2d.size(1); + if (set_zero && (fcd_size % block_size == 0)) { + mha_fill(dQKV, cu_seqlens.index({torch::indexing::Slice(-1, torch::indexing::None)})); + } else { + dQKV.fill_(0); + } + auto options = torch::TensorOptions().dtype(GetATenDType(qkv_type)).device(torch::kCUDA); + at::Tensor dBias; + TensorWrapper te_dBias; + if (bias_type != NVTE_NO_BIAS) { + dBias = torch::zeros({1, static_cast(h), + static_cast(max_seqlen), + static_cast(max_seqlen)}, options); + te_dBias = makeTransformerEngineTensor(dBias); + } + + // construct NVTE tensors + TensorWrapper te_QKV, te_O, te_dO, te_S, te_dP, te_dQKV; + if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) { + // FP8 + if ((!descale_QKV.has_value()) || (!descale_S.has_value()) + || (!descale_O.has_value()) || (!descale_dO.has_value()) + || (!scale_S.has_value()) || (!scale_dP.has_value()) + || (!scale_dQKV.has_value()) + || (!amax_dP.has_value()) || (!amax_dQKV.has_value())) { + std::string err_tensors = "descale_QKV, descale_S, descale_O, scale_S, scale_dP, "; + err_tensors = err_tensors + std::string("scale_dQKV, amax_dP and amax_dQKV"); + NVTE_ERROR(err_tensors + std::string("are required for FP8 operation. \n")); + } + te_QKV = makeTransformerEngineTensor(QKV.data_ptr(), {total_seqs, 3, h, d}, + qkv_type, nullptr, nullptr, descale_QKV.value().data_ptr()); + te_O = makeTransformerEngineTensor(O.data_ptr(), {total_seqs, h, d}, + qkv_type, nullptr, nullptr, descale_O.value().data_ptr()); + te_dO = makeTransformerEngineTensor(dO.data_ptr(), {total_seqs, h, d}, + qkv_type, nullptr, nullptr, descale_dO.value().data_ptr()); + te_S = makeTransformerEngineTensor(nullptr, {0}, + DType::kFloat32, + nullptr, scale_S.value().data_ptr(), descale_S.value().data_ptr()); + at::Tensor descale_dP = torch::empty_like(scale_dP.value()); + te_dP = makeTransformerEngineTensor(nullptr, {0}, + DType::kFloat32, amax_dP.value().data_ptr(), scale_dP.value().data_ptr(), + descale_dP.data_ptr()); + te_dQKV = makeTransformerEngineTensor(dQKV.data_ptr(), {total_seqs, 3, h, d}, + qkv_type, + amax_dQKV.value().data_ptr(), scale_dQKV.value().data_ptr(), nullptr); + } else if (qkv_type == DType::kBFloat16 || qkv_type == DType::kFloat16) { + // BF16 or FP16 + te_QKV = makeTransformerEngineTensor(QKV.data_ptr(), {total_seqs, 3, h, d}, + qkv_type, nullptr, nullptr, nullptr); + te_O = makeTransformerEngineTensor(O.data_ptr(), {total_seqs, h, d}, + qkv_type, nullptr, nullptr, nullptr); + te_dO = makeTransformerEngineTensor(dO.data_ptr(), {total_seqs, h, d}, + qkv_type, nullptr, nullptr, nullptr); + te_S = makeTransformerEngineTensor(nullptr, {0}, + DType::kFloat32, nullptr, nullptr, nullptr); + te_dP = makeTransformerEngineTensor(nullptr, {0}, + DType::kFloat32, nullptr, nullptr, nullptr); + te_dQKV = makeTransformerEngineTensor(dQKV.data_ptr(), {total_seqs, 3, h, d}, + qkv_type, nullptr, nullptr, nullptr); + } else { + NVTE_ERROR("Fused attention only supports FP8 and BF16/FP16 data types. 
\n"); + } + + // convert auxiliary tensors from forward into NVTETensors + NVTETensorPack nvte_aux_tensor_pack; + nvte_tensor_pack_create(&nvte_aux_tensor_pack); + nvte_aux_tensor_pack.size = Aux_CTX_Tensors.size(); + for (size_t i = 0; i < nvte_aux_tensor_pack.size; ++i) { + auto tensor = reinterpret_cast(nvte_aux_tensor_pack.tensors[i]); + tensor->data.dptr = Aux_CTX_Tensors[i].data_ptr(); + std::vector tmp(Aux_CTX_Tensors[i].sizes().vec()); + tensor->data.shape = std::vector(tmp.begin(), tmp.end()); + tensor->data.dtype = GetTransformerEngineDType(Aux_CTX_Tensors[i].scalar_type()); + } + + // create cu_seqlens tensorwrappers + TensorWrapper te_cu_seqlens; + te_cu_seqlens = makeTransformerEngineTensor(cu_seqlens.data_ptr(), {b+1}, + DType::kInt32, nullptr, nullptr, nullptr); + + // create workspace + TensorWrapper workspace; + + // populate tensors with appropriate shapes and dtypes + nvte_fused_attn_bwd_qkvpacked( + te_QKV.data(), + te_O.data(), + te_dO.data(), + te_S.data(), + te_dP.data(), + &nvte_aux_tensor_pack, + te_dQKV.data(), + te_dBias.data(), + te_cu_seqlens.data(), + max_seqlen, + attn_scale, p_dropout, + qkv_layout, bias_type, attn_mask_type, + workspace.data(), + at::cuda::getCurrentCUDAStream()); + + // allocate memory for workspace + auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype()); + workspace = makeTransformerEngineTensor( + workspace_data.data_ptr(), + workspace.shape(), workspace.dtype()); + + // execute kernel + nvte_fused_attn_bwd_qkvpacked( + te_QKV.data(), + te_O.data(), + te_dO.data(), + te_S.data(), + te_dP.data(), + &nvte_aux_tensor_pack, + te_dQKV.data(), + te_dBias.data(), + te_cu_seqlens.data(), + max_seqlen, + attn_scale, p_dropout, + qkv_layout, bias_type, attn_mask_type, + workspace.data(), + at::cuda::getCurrentCUDAStream()); + + // destroy tensor wrappers + nvte_tensor_pack_destroy(&nvte_aux_tensor_pack); + + return {dQKV, dBias}; +} + +// fused attention FWD with packed KV +std::vector fused_attn_fwd_kvpacked( + size_t b, size_t max_seqlen_q, size_t max_seqlen_kv, + size_t total_seqs_q, size_t total_seqs_kv, + size_t h, size_t d, + bool is_training, float attn_scale, float p_dropout, bool set_zero, + NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, + const at::Tensor cu_seqlens_q, + const at::Tensor cu_seqlens_kv, + const at::Tensor Q, + const at::Tensor KV, + const transformer_engine::DType qkv_type, + const c10::optional descale_QKV, + const c10::optional scale_S, + const c10::optional scale_O, + c10::optional amax_S, + c10::optional amax_O, + const c10::optional Bias, + const c10::optional rng_gen, + size_t rng_elts_per_thread) { + using namespace transformer_engine; + + // create output tensor O + auto options = torch::TensorOptions().dtype(GetATenDType(qkv_type)).device(torch::kCUDA); + auto O = torch::empty({static_cast(total_seqs_q), + static_cast(h), static_cast(d)}, options); + if (set_zero && (h * d % block_size == 0)) { + mha_fill(O, cu_seqlens_q.index({torch::indexing::Slice(-1, torch::indexing::None)})); + } else { + O.fill_(0); + } + + // construct NVTE tensors + TensorWrapper te_Q, te_KV, te_S, te_O, te_Bias, te_cu_seqlens_q, te_cu_seqlens_kv; + if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) { + // FP8 + if ((!descale_QKV.has_value()) || (!scale_S.has_value()) || (!scale_O.has_value()) + || (!amax_S.has_value()) || (!amax_O.has_value())) { + std::string err_tensors = "descale_QKV, scale_S, scale_O, amax_S and amax_O"; + NVTE_ERROR(err_tensors + 
std::string("are required for FP8 operation. \n")); + } + te_Q = makeTransformerEngineTensor(Q.data_ptr(), {total_seqs_q, h, d}, + qkv_type, nullptr, nullptr, descale_QKV.value().data_ptr()); + te_KV = makeTransformerEngineTensor(KV.data_ptr(), {total_seqs_kv, 2, h, d}, + qkv_type, nullptr, nullptr, descale_QKV.value().data_ptr()); + at::Tensor descale_S = torch::empty_like(scale_S.value()); + te_S = makeTransformerEngineTensor(nullptr, {0}, + DType::kFloat32, amax_S.value().data_ptr(), + scale_S.value().data_ptr(), descale_S.data_ptr()); + te_O = makeTransformerEngineTensor(O.data_ptr(), {total_seqs_q, h, d}, + qkv_type, amax_O.value().data_ptr(), scale_O.value().data_ptr(), nullptr); + } else if (qkv_type == DType::kBFloat16 || qkv_type == DType::kFloat16) { + // BF16 or FP16 + te_Q = makeTransformerEngineTensor(Q.data_ptr(), {total_seqs_q, h, d}, + qkv_type, nullptr, nullptr, nullptr); + te_KV = makeTransformerEngineTensor(KV.data_ptr(), {total_seqs_kv, 2, h, d}, + qkv_type, nullptr, nullptr, nullptr); + te_S = makeTransformerEngineTensor(nullptr, {0}, + DType::kFloat32, nullptr, nullptr, nullptr); + te_O = makeTransformerEngineTensor(O.data_ptr(), {total_seqs_q, h, d}, + qkv_type, nullptr, nullptr, nullptr); + } else { + NVTE_ERROR("Fused attention only supports FP8 and BF16/FP16 data types. \n"); + } + if ((bias_type != NVTE_NO_BIAS) && (Bias.has_value())) { + auto bias_shape = Bias.value().sizes().vec(); + std::vector shape{bias_shape.begin(), bias_shape.end()}; + te_Bias = makeTransformerEngineTensor(Bias.value().data_ptr(), shape, + DType::kFloat32, nullptr, nullptr, nullptr); + } + te_cu_seqlens_q = makeTransformerEngineTensor(cu_seqlens_q.data_ptr(), {b+1}, + DType::kInt32, nullptr, nullptr, nullptr); + te_cu_seqlens_kv = makeTransformerEngineTensor(cu_seqlens_kv.data_ptr(), {b+1}, + DType::kInt32, nullptr, nullptr, nullptr); + + // extract rng seed and offset + auto gen = at::get_generator_or_default( + rng_gen, at::cuda::detail::getDefaultCUDAGenerator()); + at::PhiloxCudaState philox_args = init_philox_state(gen, rng_elts_per_thread); + auto rng_state = torch::empty({2}, options.dtype(torch::kInt64)); + unpack<<<1, 1, 0, at::cuda::getCurrentCUDAStream()>>>( + philox_args, static_cast(rng_state.data_ptr())); + auto te_rng_state = makeTransformerEngineTensor(rng_state); + + // create auxiliary output tensors + NVTETensorPack nvte_aux_tensor_pack; + nvte_tensor_pack_create(&nvte_aux_tensor_pack); + + // create workspace + TensorWrapper workspace; + + // populate tensors with appropriate shapes and dtypes + nvte_fused_attn_fwd_kvpacked( + te_Q.data(), + te_KV.data(), + te_Bias.data(), + te_S.data(), + te_O.data(), + &nvte_aux_tensor_pack, + te_cu_seqlens_q.data(), + te_cu_seqlens_kv.data(), + te_rng_state.data(), + max_seqlen_q, max_seqlen_kv, + is_training, attn_scale, p_dropout, + qkv_layout, bias_type, attn_mask_type, + workspace.data(), + at::cuda::getCurrentCUDAStream()); + + // allocate memory for workspace and auxiliary output tensors + auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype()); + workspace = makeTransformerEngineTensor( + workspace_data.data_ptr(), + workspace.shape(), workspace.dtype()); + + // output_tensors = [O, nvte_aux_tensor_pack.tensors] + std::vector output_tensors; + output_tensors.push_back(O); + for (size_t i = 0; i < nvte_aux_tensor_pack.size; ++i) { + auto tensor = reinterpret_cast(nvte_aux_tensor_pack.tensors[i]); + // allocate memory for nvte_aux_tensor_pack.tensors + auto output_tensor = allocateSpace(tensor->data.shape, 
tensor->data.dtype, false); + output_tensors.push_back(output_tensor); + tensor->data.dptr = output_tensor.data_ptr(); + } + + // execute the kernel + nvte_fused_attn_fwd_kvpacked( + te_Q.data(), + te_KV.data(), + te_Bias.data(), + te_S.data(), + te_O.data(), + &nvte_aux_tensor_pack, + te_cu_seqlens_q.data(), + te_cu_seqlens_kv.data(), + te_rng_state.data(), + max_seqlen_q, max_seqlen_kv, + is_training, attn_scale, p_dropout, + qkv_layout, bias_type, attn_mask_type, + workspace.data(), + at::cuda::getCurrentCUDAStream()); + + // destroy tensor wrappers, but not allocated memory + nvte_tensor_pack_destroy(&nvte_aux_tensor_pack); + + // if training, [O, softmax-related tensors, rng_state]; if inference, [O] + return output_tensors; +} + +// fused attention BWD with packed KV +std::vector fused_attn_bwd_kvpacked( + size_t b, size_t max_seqlen_q, size_t max_seqlen_kv, + size_t total_seqs_q, size_t total_seqs_kv, + size_t h, size_t d, + float attn_scale, float p_dropout, bool set_zero, + NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, + const at::Tensor cu_seqlens_q, + const at::Tensor cu_seqlens_kv, + const at::Tensor Q, + const at::Tensor KV, + const at::Tensor O, + const at::Tensor dO, + const transformer_engine::DType qkv_type, + const std::vector Aux_CTX_Tensors, + const c10::optional descale_QKV, + const c10::optional descale_S, + const c10::optional descale_O, + const c10::optional descale_dO, + const c10::optional scale_S, + const c10::optional scale_dP, + const c10::optional scale_dQKV, + c10::optional amax_dP, + c10::optional amax_dQKV) { + using namespace transformer_engine; + + // create output tensors dQ and dKV + at::Tensor dQ = torch::empty_like(Q); + at::Tensor dKV = torch::empty_like(KV); + auto max_tokens_q = dQ.size(0); + auto self_2d_q = dQ.view({max_tokens_q, -1}); + auto fcd_size_q = self_2d_q.size(1); + auto max_tokens_kv = dQ.size(0); + auto self_2d_kv = dQ.view({max_tokens_kv, -1}); + auto fcd_size_kv = self_2d_kv.size(1); + if (set_zero && (fcd_size_q % block_size == 0) && (fcd_size_kv % block_size == 0)) { + mha_fill(dQ, cu_seqlens_q.index({torch::indexing::Slice(-1, torch::indexing::None)})); + mha_fill(dKV, cu_seqlens_kv.index({torch::indexing::Slice(-1, torch::indexing::None)})); + } else { + dQ.fill_(0); + dKV.fill_(0); + } + auto options = torch::TensorOptions().dtype(GetATenDType(qkv_type)).device(torch::kCUDA); + at::Tensor dBias; + TensorWrapper te_dBias; + if (bias_type != NVTE_NO_BIAS) { + dBias = torch::zeros({1, static_cast(h), + static_cast(max_seqlen_q), + static_cast(max_seqlen_kv)}, options); + te_dBias = makeTransformerEngineTensor(dBias); + } + + // construct NVTE tensors + TensorWrapper te_Q, te_KV, te_O, te_dO, te_S, te_dP, te_dQ, te_dKV; + if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) { + // FP8 + if ((!descale_QKV.has_value()) || (!descale_S.has_value()) + || (!descale_O.has_value()) || (!descale_dO.has_value()) + || (!scale_S.has_value()) || (!scale_dP.has_value()) + || (!scale_dQKV.has_value()) + || (!amax_dP.has_value()) || (!amax_dQKV.has_value())) { + std::string err_tensors = "descale_QKV, descale_S, descale_O, scale_S, scale_dP, "; + err_tensors = err_tensors + std::string("scale_dQKV, amax_dP and amax_dQKV"); + NVTE_ERROR(err_tensors + std::string("are required for FP8 operation. 
\n")); + } + te_Q = makeTransformerEngineTensor(Q.data_ptr(), {total_seqs_q, h, d}, + qkv_type, nullptr, nullptr, descale_QKV.value().data_ptr()); + te_KV = makeTransformerEngineTensor(KV.data_ptr(), {total_seqs_kv, 2, h, d}, + qkv_type, nullptr, nullptr, descale_QKV.value().data_ptr()); + te_O = makeTransformerEngineTensor(O.data_ptr(), {total_seqs_q, h, d}, + qkv_type, nullptr, nullptr, descale_O.value().data_ptr()); + te_dO = makeTransformerEngineTensor(dO.data_ptr(), {total_seqs_q, h, d}, + qkv_type, nullptr, nullptr, descale_dO.value().data_ptr()); + te_S = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32, nullptr, + scale_S.value().data_ptr(), descale_S.value().data_ptr()); + at::Tensor descale_dP = torch::empty_like(scale_dP.value()); + te_dP = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32, + amax_dP.value().data_ptr(), scale_dP.value().data_ptr(), + descale_dP.data_ptr()); + te_dQ = makeTransformerEngineTensor(dQ.data_ptr(), {total_seqs_q, h, d}, qkv_type, + amax_dQKV.value().data_ptr(), scale_dQKV.value().data_ptr(), nullptr); + te_dKV = makeTransformerEngineTensor(dKV.data_ptr(), {total_seqs_kv, 2, h, d}, qkv_type, + amax_dQKV.value().data_ptr(), scale_dQKV.value().data_ptr(), nullptr); + } else if (qkv_type == DType::kBFloat16 || qkv_type == DType::kFloat16) { + // BF16 or FP16 + te_Q = makeTransformerEngineTensor(Q.data_ptr(), {total_seqs_q, h, d}, + qkv_type, nullptr, nullptr, nullptr); + te_KV = makeTransformerEngineTensor(KV.data_ptr(), {total_seqs_kv, 2, h, d}, + qkv_type, nullptr, nullptr, nullptr); + te_O = makeTransformerEngineTensor(O.data_ptr(), {total_seqs_q, h, d}, + qkv_type, nullptr, nullptr, nullptr); + te_dO = makeTransformerEngineTensor(dO.data_ptr(), {total_seqs_q, h, d}, + qkv_type, nullptr, nullptr, nullptr); + te_S = makeTransformerEngineTensor(nullptr, {0}, + DType::kFloat32, nullptr, nullptr, nullptr); + te_dP = makeTransformerEngineTensor(nullptr, {0}, + DType::kFloat32, nullptr, nullptr, nullptr); + te_dQ = makeTransformerEngineTensor(dQ.data_ptr(), {total_seqs_q, h, d}, + qkv_type, nullptr, nullptr, nullptr); + te_dKV = makeTransformerEngineTensor(dKV.data_ptr(), {total_seqs_kv, 2, h, d}, + qkv_type, nullptr, nullptr, nullptr); + } else { + NVTE_ERROR("Fused attention only supports FP8 and BF16/FP16 data types. 
\n"); + } + + // create cu_seqlens tensorwrappers + TensorWrapper te_cu_seqlens_q, te_cu_seqlens_kv; + te_cu_seqlens_q = makeTransformerEngineTensor(cu_seqlens_q.data_ptr(), {b+1}, + DType::kInt32, nullptr, nullptr, nullptr); + te_cu_seqlens_kv = makeTransformerEngineTensor(cu_seqlens_kv.data_ptr(), {b+1}, + DType::kInt32, nullptr, nullptr, nullptr); + + // convert auxiliary tensors from forward to NVTETensors + NVTETensorPack nvte_aux_tensor_pack; + nvte_tensor_pack_create(&nvte_aux_tensor_pack); + nvte_aux_tensor_pack.size = Aux_CTX_Tensors.size(); + for (size_t i = 0; i < nvte_aux_tensor_pack.size; ++i) { + auto tensor = reinterpret_cast(nvte_aux_tensor_pack.tensors[i]); + tensor->data.dptr = Aux_CTX_Tensors[i].data_ptr(); + std::vector tmp(Aux_CTX_Tensors[i].sizes().vec()); + tensor->data.shape = std::vector(tmp.begin(), tmp.end()); + tensor->data.dtype = GetTransformerEngineDType(Aux_CTX_Tensors[i].scalar_type()); + } + + // create workspace + TensorWrapper workspace; + + // populate tensors with appropriate shapes and dtypes + nvte_fused_attn_bwd_kvpacked( + te_Q.data(), + te_KV.data(), + te_O.data(), + te_dO.data(), + te_S.data(), + te_dP.data(), + &nvte_aux_tensor_pack, + te_dQ.data(), + te_dKV.data(), + te_dBias.data(), + te_cu_seqlens_q.data(), + te_cu_seqlens_kv.data(), + max_seqlen_q, max_seqlen_kv, + attn_scale, p_dropout, + qkv_layout, bias_type, attn_mask_type, + workspace.data(), + at::cuda::getCurrentCUDAStream()); + + // allocate memory for workspace + auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype()); + workspace = makeTransformerEngineTensor( + workspace_data.data_ptr(), + workspace.shape(), workspace.dtype()); + + // execute kernel + nvte_fused_attn_bwd_kvpacked( + te_Q.data(), + te_KV.data(), + te_O.data(), + te_dO.data(), + te_S.data(), + te_dP.data(), + &nvte_aux_tensor_pack, + te_dQ.data(), + te_dKV.data(), + te_dBias.data(), + te_cu_seqlens_q.data(), + te_cu_seqlens_kv.data(), + max_seqlen_q, max_seqlen_kv, + attn_scale, p_dropout, + qkv_layout, bias_type, attn_mask_type, + workspace.data(), + at::cuda::getCurrentCUDAStream()); + + // destroy tensor wrappers + nvte_tensor_pack_destroy(&nvte_aux_tensor_pack); + + return {dQ, dKV, dBias}; +} + +namespace flash_attention { + +constexpr int warp_size = 32; +constexpr int type_size = 2; // FP16 or BF16 +constexpr int nvec = sizeof(uint64_t) / type_size; +constexpr int load_size = warp_size * nvec; +constexpr int block_size = 512; + +template +__launch_bounds__(block_size) +__global__ void prepare_kernel_fwd(const T *qkvi, + T *qkv, + const size_t B, + const size_t S, + const size_t Z, + const size_t W) { + const int warpid = (blockDim.x * blockIdx.x + threadIdx.x) / warp_size; + const int id_in_warp = threadIdx.x % warp_size; + const size_t offset_input = blockIdx.y * W + warpid * 3 * W * Z + id_in_warp * nvec; + const T *my_input = qkvi + offset_input; + + const size_t s = warpid / B; + if (s >= S) return; + + const size_t b = warpid % B; + + const size_t offset_output = blockIdx.y * B * S * Z * W + + (s + b * S) * W * Z + + id_in_warp * nvec; + + T *my_output = qkv + offset_output; + + for (int i = 0; i < Z; ++i) { + uint64_t *out = reinterpret_cast(my_output + i * load_size); + *out = *reinterpret_cast(my_input + i * load_size * 3); + } +} + +template +__launch_bounds__(block_size) +__global__ void prepare_kernel_bwd(const T *q, const T *k, const T *v, + T *qkv, const size_t B, const size_t S, + const size_t Z, const size_t W) { + const T *input = blockIdx.y == 0 ? q : (blockIdx.y == 1 ? 
k : v); + + const int warpid = (blockDim.x * blockIdx.x + threadIdx.x) / warp_size; + const int id_in_warp = threadIdx.x % warp_size; + const size_t offset_input = warpid * W * Z + id_in_warp * nvec; + const T *my_input = input + offset_input; + + const size_t b = warpid / S; + if (b >= B) return; + + const size_t s = warpid % S; + + const size_t offset_output = (b + s * B) * 3 * W * Z + + id_in_warp * nvec + blockIdx.y * W; + + T *my_output = qkv + offset_output; + + for (int i = 0; i < Z; ++i) { + uint64_t *out = reinterpret_cast(my_output + i * load_size * 3); + *out = *reinterpret_cast(my_input + i * load_size); + } +} + +} // namespace flash_attention + +at::Tensor fa_prepare_fwd(at::Tensor qkvi) { + NVTE_CHECK(qkvi.dim() == 4, "Expected 4-dim tensor."); + NVTE_CHECK(qkvi.scalar_type() == at::ScalarType::Half || + qkvi.scalar_type() == at::ScalarType::BFloat16); + NVTE_CHECK(qkvi.size(3) % flash_attention::load_size == 0); + NVTE_CHECK(qkvi.size(3) == flash_attention::load_size); + NVTE_CHECK(qkvi.stride(3) == 1, "Wrong stride."); + NVTE_CHECK(qkvi.stride(2) == 3 * qkvi.size(3), "Wrong stride."); + NVTE_CHECK(qkvi.stride(1) == 3 * qkvi.size(3) * qkvi.size(2), "Wrong stride."); + NVTE_CHECK(qkvi.stride(0) == 3 * qkvi.size(3) * qkvi.size(2) * qkvi.size(1), "Wrong stride."); + + // [s, b, n, h * 3] -> [3, b, s, n, h] + std::vector shape = {3, qkvi.size(1), qkvi.size(0), qkvi.size(2), qkvi.size(3)}; + at::Tensor qkv = at::empty(shape, at::CUDA(qkvi.scalar_type())); + + size_t warps = qkvi.size(0) * qkvi.size(1); + size_t warps_per_block = flash_attention::block_size / flash_attention::warp_size; + size_t blocks = (warps + warps_per_block - 1) / warps_per_block; + dim3 grid(blocks, 3); + int threads = flash_attention::block_size; + if (qkvi.scalar_type() == at::ScalarType::Half) { + using dtype = at::Half; + flash_attention::prepare_kernel_fwd<<>>( + qkvi.data_ptr(), + qkv.data_ptr(), + shape[1], + shape[2], + shape[3], + shape[4]); + } else { + using dtype = at::BFloat16; + flash_attention::prepare_kernel_fwd<<>>( + qkvi.data_ptr(), + qkv.data_ptr(), + shape[1], + shape[2], + shape[3], + shape[4]); + } + + return qkv; +} + +at::Tensor fa_prepare_bwd(at::Tensor q, at::Tensor k, at::Tensor v) { + NVTE_CHECK(q.is_contiguous()); + NVTE_CHECK(k.is_contiguous()); + NVTE_CHECK(v.is_contiguous()); + NVTE_CHECK(q.dim() == 4, "Expected 4-dim tensor."); + NVTE_CHECK(k.dim() == 4, "Expected 4-dim tensor."); + NVTE_CHECK(v.dim() == 4, "Expected 4-dim tensor."); + NVTE_CHECK(q.scalar_type() == at::ScalarType::Half || + q.scalar_type() == at::ScalarType::BFloat16); + NVTE_CHECK(k.scalar_type() == q.scalar_type()); + NVTE_CHECK(v.scalar_type() == q.scalar_type()); + NVTE_CHECK(q.size(3) % flash_attention::load_size == 0); + NVTE_CHECK(q.size(3) == flash_attention::load_size); + NVTE_CHECK(k.size(3) % flash_attention::load_size == 0); + NVTE_CHECK(k.size(3) == flash_attention::load_size); + NVTE_CHECK(v.size(3) % flash_attention::load_size == 0); + NVTE_CHECK(v.size(3) == flash_attention::load_size); + + // 3 x [s, b, n, h] -> [b, s, n, 3 * h] + + std::vector shape = {q.size(1), q.size(0), q.size(2), 3 * q.size(3)}; + at::Tensor qkv = at::empty(shape, at::CUDA(q.scalar_type())); + + size_t warps = q.size(0) * q.size(1); + size_t warps_per_block = flash_attention::block_size / flash_attention::warp_size; + size_t blocks = (warps + warps_per_block - 1) / warps_per_block; + dim3 grid(blocks, 3); + int threads = flash_attention::block_size; + if (q.scalar_type() == at::ScalarType::Half) { + using dtype = 
at::Half; + flash_attention::prepare_kernel_bwd<<>>( + q.data_ptr(), + k.data_ptr(), + v.data_ptr(), + qkv.data_ptr(), + q.size(0), + q.size(1), + q.size(2), + q.size(3)); + } else { + using dtype = at::BFloat16; + flash_attention::prepare_kernel_bwd<<>>( + q.data_ptr(), + k.data_ptr(), + v.data_ptr(), + qkv.data_ptr(), + q.size(0), + q.size(1), + q.size(2), + q.size(3)); + } + + return qkv; +} diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cu b/transformer_engine/pytorch/csrc/extensions/cast.cu new file mode 100644 index 0000000000..0e886e4107 --- /dev/null +++ b/transformer_engine/pytorch/csrc/extensions/cast.cu @@ -0,0 +1,75 @@ +/************************************************************************* + * Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include "extensions.h" + +at::Tensor cast_to_fp8(const at::Tensor &input, + const at::Tensor &scale, + at::Tensor amax, + at::Tensor scale_inv, + transformer_engine::DType otype +) { + using namespace transformer_engine; + auto input_shape = input.sizes().vec(); + std::vector shape{input_shape.begin(), input_shape.end()}; + + auto output = at::empty_like(input, at::CUDA(GetATenDType(otype))); + + auto input_cu = makeTransformerEngineTensor(input); + auto output_cu = makeTransformerEngineTensor(output.data_ptr(), shape, otype, + amax.data_ptr(), scale.data_ptr(), + scale_inv.data_ptr()); + + nvte_fp8_quantize(input_cu.data(), output_cu.data(), + at::cuda::getCurrentCUDAStream()); + + return output; +} + + +void cast_to_fp8_noalloc(const at::Tensor &input, + const at::Tensor &scale, + at::Tensor output, + at::Tensor amax, + at::Tensor scale_inv, + transformer_engine::DType otype +) { + using namespace transformer_engine; + size_t N = static_cast(input.size(0)); + size_t H = static_cast(input.size(1)); + + auto input_cu = makeTransformerEngineTensor(input); + auto output_cu = makeTransformerEngineTensor(output.data_ptr(), {N, H}, otype, + amax.data_ptr(), scale.data_ptr(), + scale_inv.data_ptr()); + + nvte_fp8_quantize(input_cu.data(), output_cu.data(), + at::cuda::getCurrentCUDAStream()); + + return; +} + + +at::Tensor cast_from_fp8(const at::Tensor &input, + const at::Tensor &scale_inv, + transformer_engine::DType itype, + transformer_engine::DType otype +) { + using namespace transformer_engine; + auto input_shape = input.sizes().vec(); + std::vector shape{input_shape.begin(), input_shape.end()}; + + auto output = at::empty_like(input, at::CUDA(GetATenDType(otype))); + + auto input_cu = makeTransformerEngineTensor(input.data_ptr(), shape, itype, + nullptr, nullptr, scale_inv.data_ptr()); + auto output_cu = makeTransformerEngineTensor(output); + + nvte_fp8_dequantize(input_cu.data(), output_cu.data(), + at::cuda::getCurrentCUDAStream()); + + return output; +} diff --git a/transformer_engine/pytorch/csrc/extensions/gemm.cu b/transformer_engine/pytorch/csrc/extensions/gemm.cu new file mode 100644 index 0000000000..1a7630edce --- /dev/null +++ b/transformer_engine/pytorch/csrc/extensions/gemm.cu @@ -0,0 +1,75 @@ +/************************************************************************* + * Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. 
+ ************************************************************************/ + +#include "extensions.h" + +void te_gemm(at::Tensor A, + at::Tensor A_scale_inverse, + transformer_engine::DType A_type, + bool transa, + at::Tensor B, + at::Tensor B_scale_inverse, + transformer_engine::DType B_type, + bool transb, + at::Tensor D, + at::Tensor D_scale, + transformer_engine::DType D_type, + at::Tensor D_amax, + at::Tensor bias, + transformer_engine::DType bias_type, + at::Tensor pre_gelu_out, + bool grad, + at::Tensor workspace, + size_t workspaceSize, + bool accumulate, + bool use_split_accumulator, + int math_sm_count +) { + using namespace transformer_engine; + auto te_A = makeTransformerEngineTensor(A.data_ptr(), + {static_cast(A.size(0)), + static_cast(A.size(1))}, + A_type, nullptr, nullptr, + A_scale_inverse.data_ptr()); + auto te_B = makeTransformerEngineTensor(B.data_ptr(), + {static_cast(B.size(0)), + static_cast(B.size(1))}, + B_type, nullptr, nullptr, + B_scale_inverse.data_ptr()); + auto te_D = makeTransformerEngineTensor(D.data_ptr(), + {static_cast(D.size(0)), + static_cast(D.size(1))}, + D_type, D_amax.data_ptr(), + D_scale.data_ptr(), nullptr); + auto te_bias = makeTransformerEngineTensor(bias.data_ptr(), {static_cast(bias.size(0))}, + bias_type); + + const auto gelu_shape = pre_gelu_out.data_ptr() == nullptr + ? std::vector{static_cast(pre_gelu_out.size(0))} + : std::vector{static_cast(pre_gelu_out.size(0)), + static_cast(pre_gelu_out.size(1))}; + auto te_pre_gelu_out = makeTransformerEngineTensor(pre_gelu_out.data_ptr(), + gelu_shape, + GetTransformerEngineDType( + pre_gelu_out.scalar_type())); + auto te_workspace = makeTransformerEngineTensor(workspace.data_ptr(), + {workspaceSize}, + DType::kByte); + + nvte_cublas_gemm(te_A.data(), + te_B.data(), + te_D.data(), + te_bias.data(), + te_pre_gelu_out.data(), + transa, + transb, + grad, + te_workspace.data(), + accumulate, + use_split_accumulator, + math_sm_count, + at::cuda::getCurrentCUDAStream()); +} diff --git a/transformer_engine/pytorch/csrc/extensions/misc.cu b/transformer_engine/pytorch/csrc/extensions/misc.cu new file mode 100644 index 0000000000..e6275d1159 --- /dev/null +++ b/transformer_engine/pytorch/csrc/extensions/misc.cu @@ -0,0 +1,25 @@ +/************************************************************************* + * Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include "extensions.h" +#ifdef NVTE_WITH_USERBUFFERS +#include "comm_gemm_overlap.h" +#endif // NVTE_WITH_USERBUFFERS + +size_t get_cublasLt_version() { + return cublasLtGetVersion(); +} + + +bool userbuf_comm_available() { // TODO(ksivamani) check on python side +#ifdef NVTE_WITH_USERBUFFERS + return true; +#else + return false; +#endif +} + +void placeholder() {} // TODO(ksivamani) clean this up diff --git a/transformer_engine/pytorch/csrc/extensions/normalization.cu b/transformer_engine/pytorch/csrc/extensions/normalization.cu new file mode 100644 index 0000000000..6c723cd37f --- /dev/null +++ b/transformer_engine/pytorch/csrc/extensions/normalization.cu @@ -0,0 +1,404 @@ +/************************************************************************* + * Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. 
+ ************************************************************************/ + +#include "extensions.h" + +std::vector layernorm_bwd(const at::Tensor &dz, + const at::Tensor &x, + const at::Tensor &mu, + const at::Tensor &rsigma, + const at::Tensor &gamma, + const int sm_margin, + const bool zero_centered_gamma +) { + auto dx = at::empty_like(x); + auto dgamma = at::empty_like(gamma); + auto dbeta = at::empty_like(gamma); + transformer_engine::TensorWrapper workspace, barrier, dgamma_part, dbeta_part; + + auto dz_cu = makeTransformerEngineTensor(dz); + auto x_cu = makeTransformerEngineTensor(x); + auto mu_cu = makeTransformerEngineTensor(mu); + auto rsigma_cu = makeTransformerEngineTensor(rsigma); + auto gamma_cu = makeTransformerEngineTensor(gamma); + auto dx_cu = makeTransformerEngineTensor(dx); + auto dgamma_cu = makeTransformerEngineTensor(dgamma); + auto dbeta_cu = makeTransformerEngineTensor(dbeta); + + // This call populates tensors with the required config. + const auto bwd_fun = zero_centered_gamma ? nvte_layernorm1p_bwd : nvte_layernorm_bwd; + bwd_fun(dz_cu.data(), x_cu.data(), mu_cu.data(), rsigma_cu.data(), gamma_cu.data(), + dx_cu.data(), dgamma_cu.data(), dbeta_cu.data(), dgamma_part.data(), + dbeta_part.data(), at::cuda::getCurrentCUDAStream(), + at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin, + workspace.data(), barrier.data()); + + // Alloc space for Tensors. + auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype()); + auto barrier_data = allocateSpace(barrier.shape(), barrier.dtype(), true); + auto dgamma_part_data = allocateSpace(dgamma_part.shape(), dgamma_part.dtype()); + auto dbeta_part_data = allocateSpace(dbeta_part.shape(), dbeta_part.dtype()); + workspace = makeTransformerEngineTensor(workspace_data.data_ptr(), + workspace.shape(), + workspace.dtype()); + barrier = makeTransformerEngineTensor(barrier_data.data_ptr(), + barrier.shape(), + barrier.dtype()); + dgamma_part = makeTransformerEngineTensor(dgamma_part_data.data_ptr(), + dgamma_part.shape(), + dgamma_part.dtype()); + dbeta_part = makeTransformerEngineTensor(dbeta_part_data.data_ptr(), + dbeta_part.shape(), + dbeta_part.dtype()); + + // Actual call to bwd kernel. 
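+  // The first bwd_fun call above only queried the required sizes; now that the
+  // workspace, barrier, dgamma_part and dbeta_part wrappers are backed by real
+  // buffers, this second call performs the actual backward computation.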
+ bwd_fun(dz_cu.data(), x_cu.data(), mu_cu.data(), rsigma_cu.data(), gamma_cu.data(), + dx_cu.data(), dgamma_cu.data(), dbeta_cu.data(), dgamma_part.data(), + dbeta_part.data(), at::cuda::getCurrentCUDAStream(), + at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin, + workspace.data(), barrier.data()); + + return { dx, dgamma, dbeta }; +} + + +std::vector layernorm_fwd_fp8(const at::Tensor &input, + const at::Tensor &weight, + const at::Tensor &bias, + float eps, + at::Tensor scale, + at::Tensor amax, + at::Tensor scale_inv, + transformer_engine::DType otype, + const int sm_margin, + const bool zero_centered_gamma +) { + using namespace transformer_engine; + + auto ln_out = at::empty_like(input, at::CUDA(GetATenDType(otype))); + return layernorm_fwd_fp8_noalloc(input, weight, bias, eps, + scale, ln_out, amax, scale_inv, + otype, sm_margin, zero_centered_gamma); +} + + +std::vector layernorm_fwd_fp8_noalloc(const at::Tensor &input, + const at::Tensor &weight, + const at::Tensor &bias, + float eps, + at::Tensor scale, + at::Tensor ln_out, + at::Tensor amax, + at::Tensor scale_inv, + transformer_engine::DType otype, + const int sm_margin, + const bool zero_centered_gamma +) { + using namespace transformer_engine; + + size_t N = static_cast(input.size(0)); + size_t H = static_cast(input.size(1)); + + DType itype = GetTransformerEngineDType(input.scalar_type()); + + auto mu = at::empty({static_cast(N)}, at::CUDA(at::kFloat)); + auto rsigma = at::empty({static_cast(N)}, at::CUDA(at::kFloat)); + auto input_cu = makeTransformerEngineTensor(input); + auto gamma_cu = makeTransformerEngineTensor(weight); + auto beta_cu = makeTransformerEngineTensor(bias); + auto z_cu = makeTransformerEngineTensor(ln_out.data_ptr(), {N, H}, otype, + getDataPtr(amax), getDataPtr(scale), + getDataPtr(scale_inv)); + auto mu_cu = makeTransformerEngineTensor(mu); + auto rsigma_cu = makeTransformerEngineTensor(rsigma); + transformer_engine::TensorWrapper workspace, barrier; + + // This call populates workspace and barrier tensors with the required config + const auto func = zero_centered_gamma ? nvte_layernorm1p_fwd : nvte_layernorm_fwd; + func(input_cu.data(), gamma_cu.data(), beta_cu.data(), eps, z_cu.data(), + mu_cu.data(), rsigma_cu.data(), at::cuda::getCurrentCUDAStream(), + at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin, + workspace.data(), barrier.data()); + + // Fill workspace and barrier + auto workspace_data = allocateSpace(workspace.shape(), + workspace.dtype()); + auto barrier_data = allocateSpace(barrier.shape(), + barrier.dtype(), + true); + workspace = makeTransformerEngineTensor(workspace_data.data_ptr(), + workspace.shape(), + workspace.dtype()); + barrier = makeTransformerEngineTensor(barrier_data.data_ptr(), + barrier.shape(), + barrier.dtype()); + + // Actual call to fwd kernel + func(input_cu.data(), gamma_cu.data(), beta_cu.data(), eps, z_cu.data(), + mu_cu.data(), rsigma_cu.data(), at::cuda::getCurrentCUDAStream(), + at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin, + workspace.data(), barrier.data()); + + return {ln_out, mu, rsigma}; +} + + +at::Tensor layernorm_fwd_fp8_inf(const at::Tensor &input, + const at::Tensor &weight, + const at::Tensor &bias, + float eps, + at::Tensor scale, + at::Tensor amax, + at::Tensor scale_inv, + transformer_engine::DType otype, + const bool zero_centered_gamma +) { + // This is a specialized version of layernorm_fwd_fp8, optimized for inference, + // which only returns the normalized output. 
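+  // sm_margin is fixed at 0 below, on the assumption that inference does not
+  // overlap other kernels (e.g. communication) that would need reserved SMs.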
+ std::vector out = layernorm_fwd_fp8( + input, weight, bias, eps, scale, amax, scale_inv, otype, 0, zero_centered_gamma); + return out[0]; +} + + +std::vector layernorm_fwd(const at::Tensor &input, + const at::Tensor &weight, + const at::Tensor &bias, + float eps, + const int sm_margin, + const bool zero_centered_gamma +) { + using namespace transformer_engine; + + DType itype = GetTransformerEngineDType(input.scalar_type()); + auto ln_out = at::empty_like(input, at::CUDA(GetATenDType(itype))); + + return layernorm_fwd_noalloc(input, weight, bias, ln_out, eps, + sm_margin, zero_centered_gamma); +} + + +std::vector layernorm_fwd_noalloc(const at::Tensor &input, + const at::Tensor &weight, + const at::Tensor &bias, + at::Tensor ln_out, + float eps, + const int sm_margin, + const bool zero_centered_gamma +) { + using namespace transformer_engine; + + DType itype = GetTransformerEngineDType(input.scalar_type()); + + return layernorm_fwd_fp8_noalloc(input, weight, bias, eps, at::Tensor(), + ln_out, at::Tensor(), at::Tensor(), + itype, sm_margin, zero_centered_gamma); +} + + +at::Tensor layernorm_fwd_inf(const at::Tensor &input, + const at::Tensor &weight, + const at::Tensor &bias, + float eps, + const bool zero_centered_gamma +) { + // This is a specialized version of layernorm_fwd, optimized for inference, + // which only returns the normalized output. + std::vector out = layernorm_fwd(input, weight, bias, eps, 0, zero_centered_gamma); + return out[0]; +} + +std::vector rmsnorm_bwd(const at::Tensor &dz, + const at::Tensor &x, + const at::Tensor &rsigma, + const at::Tensor &gamma, + const int sm_margin, + const bool zero_centered_gamma +) { + NVTE_CHECK(zero_centered_gamma == false, + "Zero-centered gamma is not supported yet for RMSNorm."); + auto dx = at::empty_like(x); + auto dgamma = at::empty_like(gamma); + transformer_engine::TensorWrapper workspace, barrier, dgamma_part; + + auto dz_cu = makeTransformerEngineTensor(dz); + auto x_cu = makeTransformerEngineTensor(x); + auto rsigma_cu = makeTransformerEngineTensor(rsigma); + auto gamma_cu = makeTransformerEngineTensor(gamma); + auto dx_cu = makeTransformerEngineTensor(dx); + auto dgamma_cu = makeTransformerEngineTensor(dgamma); + + // This call populates tensors with the required config. + const auto bwd_fun = nvte_rmsnorm_bwd; + bwd_fun(dz_cu.data(), x_cu.data(), rsigma_cu.data(), gamma_cu.data(), + dx_cu.data(), dgamma_cu.data(), dgamma_part.data(), + at::cuda::getCurrentCUDAStream(), + at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin, + workspace.data(), barrier.data()); + + // Alloc space for Tensors. + auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype()); + auto barrier_data = allocateSpace(barrier.shape(), barrier.dtype(), true); + auto dgamma_part_data = allocateSpace(dgamma_part.shape(), dgamma_part.dtype()); + workspace = makeTransformerEngineTensor(workspace_data.data_ptr(), + workspace.shape(), + workspace.dtype()); + barrier = makeTransformerEngineTensor(barrier_data.data_ptr(), + barrier.shape(), + barrier.dtype()); + dgamma_part = makeTransformerEngineTensor(dgamma_part_data.data_ptr(), + dgamma_part.shape(), + dgamma_part.dtype()); + + // Actual call to bwd kernel. 
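+  // Same two-step pattern as the LayerNorm backward above; RMSNorm has no
+  // beta parameter, so only dx and dgamma are produced (no dbeta buffers).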
+ bwd_fun(dz_cu.data(), x_cu.data(), rsigma_cu.data(), gamma_cu.data(), + dx_cu.data(), dgamma_cu.data(), dgamma_part.data(), + at::cuda::getCurrentCUDAStream(), + at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin, + workspace.data(), barrier.data()); + + return { dx, dgamma }; +} + + +std::vector rmsnorm_fwd_fp8(const at::Tensor &input, + const at::Tensor &weight, + float eps, + at::Tensor scale, + at::Tensor amax, + at::Tensor scale_inv, + transformer_engine::DType otype, + const int sm_margin, + const bool zero_centered_gamma +) { + using namespace transformer_engine; + + auto ln_out = at::empty_like(input, at::CUDA(GetATenDType(otype))); + return rmsnorm_fwd_fp8_noalloc(input, weight, eps, + scale, ln_out, amax, scale_inv, + otype, sm_margin, zero_centered_gamma); +} + + +std::vector rmsnorm_fwd_fp8_noalloc(const at::Tensor &input, + const at::Tensor &weight, + float eps, + at::Tensor scale, + at::Tensor ln_out, + at::Tensor amax, + at::Tensor scale_inv, + transformer_engine::DType otype, + const int sm_margin, + const bool zero_centered_gamma +) { + using namespace transformer_engine; + NVTE_CHECK(zero_centered_gamma == false, + "Zero-centered gamma is not supported yet for RMSNorm."); + + size_t N = static_cast(input.size(0)); + size_t H = static_cast(input.size(1)); + + DType itype = GetTransformerEngineDType(input.scalar_type()); + + auto rsigma = at::empty({static_cast(N)}, at::CUDA(at::kFloat)); + auto input_cu = makeTransformerEngineTensor(input); + auto gamma_cu = makeTransformerEngineTensor(weight); + auto z_cu = makeTransformerEngineTensor(ln_out.data_ptr(), {N, H}, otype, + getDataPtr(amax), getDataPtr(scale), + getDataPtr(scale_inv)); + auto rsigma_cu = makeTransformerEngineTensor(rsigma); + transformer_engine::TensorWrapper workspace, barrier; + + // This call populates workspace and barrier tensors with the required config + const auto func = nvte_rmsnorm_fwd; + func(input_cu.data(), gamma_cu.data(), eps, z_cu.data(), + rsigma_cu.data(), at::cuda::getCurrentCUDAStream(), + at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin, + workspace.data(), barrier.data()); + + // Fill workspace and barrier + auto workspace_data = allocateSpace(workspace.shape(), + workspace.dtype()); + auto barrier_data = allocateSpace(barrier.shape(), + barrier.dtype(), + true); + workspace = makeTransformerEngineTensor(workspace_data.data_ptr(), + workspace.shape(), + workspace.dtype()); + barrier = makeTransformerEngineTensor(barrier_data.data_ptr(), + barrier.shape(), + barrier.dtype()); + + // Actual call to fwd kernel + func(input_cu.data(), gamma_cu.data(), eps, z_cu.data(), + rsigma_cu.data(), at::cuda::getCurrentCUDAStream(), + at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin, + workspace.data(), barrier.data()); + + return {ln_out, rsigma}; +} + + +at::Tensor rmsnorm_fwd_fp8_inf(const at::Tensor &input, + const at::Tensor &weight, + float eps, + at::Tensor scale, + at::Tensor amax, + at::Tensor scale_inv, + transformer_engine::DType otype, + const bool zero_centered_gamma +) { + // This is a specialized version of rmsnorm_fwd_fp8, optimized for inference, + // which only returns the normalized output. 
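+  // The rsigma statistic returned by the training-path kernel is only needed
+  // for backward, so it is discarded here.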
+ std::vector out = rmsnorm_fwd_fp8( + input, weight, eps, scale, amax, scale_inv, otype, 0, zero_centered_gamma); + return out[0]; +} + + +std::vector rmsnorm_fwd(const at::Tensor &input, + const at::Tensor &weight, + float eps, + const int sm_margin, + const bool zero_centered_gamma +) { + using namespace transformer_engine; + + DType itype = GetTransformerEngineDType(input.scalar_type()); + auto ln_out = at::empty_like(input, at::CUDA(GetATenDType(itype))); + + return rmsnorm_fwd_noalloc(input, weight, ln_out, eps, + sm_margin, zero_centered_gamma); +} + + +std::vector rmsnorm_fwd_noalloc(const at::Tensor &input, + const at::Tensor &weight, + at::Tensor ln_out, + float eps, + const int sm_margin, + const bool zero_centered_gamma +) { + using namespace transformer_engine; + + DType itype = GetTransformerEngineDType(input.scalar_type()); + + return rmsnorm_fwd_fp8_noalloc(input, weight, eps, at::Tensor(), + ln_out, at::Tensor(), at::Tensor(), + itype, sm_margin, zero_centered_gamma); +} + + +at::Tensor rmsnorm_fwd_inf(const at::Tensor &input, + const at::Tensor &weight, + float eps, + const bool zero_centered_gamma +) { + // This is a specialized version of rmsnorm_fwd, optimized for inference, + // which only returns the normalized output. + std::vector out = rmsnorm_fwd(input, weight, eps, 0, zero_centered_gamma); + return out[0]; +} diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp new file mode 100644 index 0000000000..6dc48a4b5c --- /dev/null +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -0,0 +1,158 @@ +/************************************************************************* + * Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. 
+ ************************************************************************/ + +#include "../extensions.h" +#ifdef NVTE_WITH_USERBUFFERS +#include "comm_gemm_overlap.h" +#endif // NVTE_WITH_USERBUFFERS + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + // Softmax functions + m.def("scaled_softmax_forward", &scaled_softmax_forward, "Scaled Softmax FWD"); + m.def("scaled_softmax_backward", &scaled_softmax_backward, "Scaled Softmax BWD"); + m.def("scaled_masked_softmax_forward", &scaled_masked_softmax_forward, + "Scaled Masked Softmax FWD"); + m.def("scaled_masked_softmax_backward", &scaled_masked_softmax_backward, + "Scaled Masked Softmax BWD"); + m.def("scaled_upper_triang_masked_softmax_forward", + &scaled_upper_triang_masked_softmax_forward, + "Scaled Upper-Triangular Masked Softmax FWD"); + m.def("scaled_upper_triang_masked_softmax_backward", + &scaled_upper_triang_masked_softmax_backward, + "Scaled Upper-Triangular Masked Softmax BWD"); + + // Other granular functions + m.def("layernorm_fwd_fp8", &layernorm_fwd_fp8, "LN FWD FP8"); + m.def("layernorm_fwd_fp8_noalloc", &layernorm_fwd_fp8_noalloc, "LN FWD FP8"); + m.def("layernorm_bwd", &layernorm_bwd, "LN BWD"); + m.def("layernorm_fwd", &layernorm_fwd, "LN FWD"); + m.def("layernorm_fwd_noalloc", &layernorm_fwd_noalloc, "LN FWD"); + m.def("rmsnorm_fwd_fp8", &rmsnorm_fwd_fp8, "LN FWD FP8"); + m.def("rmsnorm_fwd_fp8_noalloc", &rmsnorm_fwd_fp8_noalloc, "LN FWD FP8"); + m.def("rmsnorm_bwd", &rmsnorm_bwd, "LN BWD"); + m.def("rmsnorm_fwd", &rmsnorm_fwd, "LN FWD"); + m.def("rmsnorm_fwd_noalloc", &rmsnorm_fwd_noalloc, "LN FWD"); + m.def("fused_cast_transpose", &fused_cast_transpose, "Fused Cast + Transpose"); + m.def("fused_cast_transpose_bgrad", &fused_cast_transpose_bgrad, + "Fused Cast + Transpose + BGRAD"); + m.def("fused_fp8_transpose_bgrad", &fused_fp8_transpose_bgrad, + "Fused FP8 Transpose + BGRAD"); + m.def("fused_cast_transpose_bgrad_dgelu", &fused_cast_transpose_bgrad_dgelu, + "Fused Cast + Transpose + BGRAD + DGELU"); + m.def("fused_multi_cast_transpose", &fused_multi_cast_transpose, + "Fused Multi-tensor Cast + Transpose"); + m.def("cast_to_fp8", &cast_to_fp8, "Cast to FP8"); + m.def("cast_to_fp8_noalloc", &cast_to_fp8_noalloc, "Cast to FP8"); + m.def("cast_from_fp8", &cast_from_fp8, "Cast from FP8"); + m.def("te_gemm", &te_gemm, "CublasLt GEMM"); + m.def("fused_attn_fwd_qkvpacked", &fused_attn_fwd_qkvpacked, + "Fused Attention FP8/BF16/FP16 FWD with packed QKV"); + m.def("fused_attn_bwd_qkvpacked", &fused_attn_bwd_qkvpacked, + "Fused Attention FP8/BF16/FP16 BWD with packed QKV"); + m.def("fused_attn_fwd_kvpacked", &fused_attn_fwd_kvpacked, + "Fused Attention FP8/BF16/FP16 FWD with packed KV"); + m.def("fused_attn_bwd_kvpacked", &fused_attn_bwd_kvpacked, + "Fused Attention FP8/BF16/FP16 BWD with packed KV"); + m.def("fp8_transpose", &fp8_transpose, "Transpose with FP8 I/O"); + m.def("gelu", &gelu, "GeLU with FP8 output"); + m.def("relu", &relu, "ReLU with FP8 output"); + m.def("geglu", &geglu, "GeGLU with FP8 output"); + m.def("reglu", ®lu, "ReGLU with FP8 output"); + m.def("swiglu", &swiglu, "SwiGLU with FP8 output"); + m.def("dgelu", &dgelu, "Backward of GeLU"); + m.def("drelu", &drelu, "Backward of ReLU"); + m.def("dgeglu", &dgeglu, "Backward of GeGLU"); + m.def("dreglu", &dreglu, "Backward of ReGLU"); + m.def("dswiglu", &dswiglu, "Backward of SwiGLU"); + m.def("fa_prepare_fwd", &fa_prepare_fwd, "Prepare QKV for Flash Attention"); + m.def("fa_prepare_bwd", &fa_prepare_bwd, "Backward of QKV preparation for Flash Attention"); + 
m.def("get_fused_attn_backend", &get_fused_attn_backend, "Get Fused Attention backend"); + + // Misc + m.def("get_cublasLt_version", &get_cublasLt_version, "Get cublasLt version"); + m.def("userbuf_comm_available", &userbuf_comm_available, "If userbuf backend is available"); + + // Data structures + py::class_(m, "FP8TensorMeta") + .def(py::init<>()) + .def_readwrite("scale", &transformer_engine::FP8TensorMeta::scale) + .def_readwrite("scale_inv", &transformer_engine::FP8TensorMeta::scale_inv) + .def_readwrite("amax_history", &transformer_engine::FP8TensorMeta::amax_history); + +#ifdef NVTE_WITH_USERBUFFERS + py::enum_(m, "UbufOverlapAlgo") + .value("BULK_OVERLAP_AG", ubuf::UBOverlapAlgo::BULK_OVERLAP_AG) + .value("BULK_OVERLAP_RS", ubuf::UBOverlapAlgo::BULK_OVERLAP_RS) + .value("SPLIT_PIPELINED_RS", ubuf::UBOverlapAlgo::SPLIT_PIPELINED_RS) + .value("SPLIT_PIPELINED_AG", ubuf::UBOverlapAlgo::SPLIT_PIPELINED_AG); + + py::class_(m, "UbufCommOverlap") + .def(py::init()) + .def("bulk_overlap", &ubuf::UbufCommOverlap::bulk_overlap) + .def("split_overlap_rs", &ubuf::UbufCommOverlap::split_overlap_rs) + .def("copy_input_to_ubuf", &ubuf::UbufCommOverlap::copy_input_to_ubuf) + .def("get_ubuf_output", &ubuf::UbufCommOverlap::get_ubuf_output); + + py::class_(m, "UbufP2PCommOverlap") + .def(py::init()) + .def("split_overlap_ag", &ubuf::UbufP2PCommOverlap::split_overlap_ag) + .def("copy_input_to_ubuf", &ubuf::UbufP2PCommOverlap::copy_input_to_ubuf) + .def("get_ubuf_output", &ubuf::UbufP2PCommOverlap::get_ubuf_output); +#else // NVTE_WITH_USERBUFFERS + m.def("UbufOverlapAlgo", &placeholder, "Dummy function for python side annotations"); + m.def("UbufCommOverlap", &placeholder, "Dummy function for python side annotations"); + m.def("UbufP2PCommOverlap", &placeholder, "Dummy function for python side annotations"); +#endif // NVTE_WITH_USERBUFFERS + + py::enum_(m, "DType", py::module_local()) + .value("kByte", transformer_engine::DType::kByte) + .value("kInt32", transformer_engine::DType::kInt32) + .value("kFloat32", transformer_engine::DType::kFloat32) + .value("kFloat16", transformer_engine::DType::kFloat16) + .value("kBFloat16", transformer_engine::DType::kBFloat16) + .value("kFloat8E4M3", transformer_engine::DType::kFloat8E4M3) + .value("kFloat8E5M2", transformer_engine::DType::kFloat8E5M2); + + py::enum_(m, "FP8FwdTensors") + .value("GEMM1_INPUT", transformer_engine::FP8FwdTensors::GEMM1_INPUT) + .value("GEMM1_WEIGHT", transformer_engine::FP8FwdTensors::GEMM1_WEIGHT) + .value("GEMM1_OUTPUT", transformer_engine::FP8FwdTensors::GEMM1_OUTPUT) + .value("GEMM2_INPUT", transformer_engine::FP8FwdTensors::GEMM2_INPUT) + .value("GEMM2_WEIGHT", transformer_engine::FP8FwdTensors::GEMM2_WEIGHT) + .value("GEMM2_OUTPUT", transformer_engine::FP8FwdTensors::GEMM2_OUTPUT) + .value("GEMM3_INPUT", transformer_engine::FP8FwdTensors::GEMM3_INPUT) + .value("GEMM3_WEIGHT", transformer_engine::FP8FwdTensors::GEMM3_WEIGHT) + .value("GEMM3_OUTPUT", transformer_engine::FP8FwdTensors::GEMM3_OUTPUT); + + py::enum_(m, "FP8BwdTensors") + .value("GRAD_OUTPUT1", transformer_engine::FP8BwdTensors::GRAD_OUTPUT1) + .value("GRAD_INPUT1", transformer_engine::FP8BwdTensors::GRAD_INPUT1) + .value("GRAD_OUTPUT2", transformer_engine::FP8BwdTensors::GRAD_OUTPUT2) + .value("GRAD_INPUT2", transformer_engine::FP8BwdTensors::GRAD_INPUT2) + .value("GRAD_OUTPUT3", transformer_engine::FP8BwdTensors::GRAD_OUTPUT3) + .value("GRAD_INPUT3", transformer_engine::FP8BwdTensors::GRAD_INPUT3); + + py::enum_(m, "NVTE_Bias_Type") + .value("NVTE_NO_BIAS", 
NVTE_Bias_Type::NVTE_NO_BIAS) + .value("NVTE_PRE_SCALE_BIAS", NVTE_Bias_Type::NVTE_PRE_SCALE_BIAS) + .value("NVTE_POST_SCALE_BIAS", NVTE_Bias_Type::NVTE_POST_SCALE_BIAS); + + py::enum_(m, "NVTE_Mask_Type") + .value("NVTE_NO_MASK", NVTE_Mask_Type::NVTE_NO_MASK) + .value("NVTE_PADDING_MASK", NVTE_Mask_Type::NVTE_PADDING_MASK) + .value("NVTE_CAUSAL_MASK", NVTE_Mask_Type::NVTE_CAUSAL_MASK); + + py::enum_(m, "NVTE_QKV_Layout") + .value("NVTE_NOT_INTERLEAVED", NVTE_QKV_Layout::NVTE_NOT_INTERLEAVED) + .value("NVTE_QKV_INTERLEAVED", NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED) + .value("NVTE_KV_INTERLEAVED", NVTE_QKV_Layout::NVTE_KV_INTERLEAVED); + + py::enum_(m, "NVTE_Fused_Attn_Backend") + .value("NVTE_F16_max512_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) + .value("NVTE_F16_arbitrary_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) + .value("NVTE_FP8", NVTE_Fused_Attn_Backend::NVTE_FP8) + .value("NVTE_No_Backend", NVTE_Fused_Attn_Backend::NVTE_No_Backend); +} diff --git a/transformer_engine/pytorch/csrc/extensions/softmax.cu b/transformer_engine/pytorch/csrc/extensions/softmax.cu new file mode 100644 index 0000000000..6bfbb7bb96 --- /dev/null +++ b/transformer_engine/pytorch/csrc/extensions/softmax.cu @@ -0,0 +1,211 @@ +/************************************************************************* + * Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include "extensions.h" + +at::Tensor scaled_softmax_forward(at::Tensor input, + float scale_factor +) { + using namespace transformer_engine; + AT_ASSERTM(input.dim() == 4, "expected 4D tensor"); + AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) || + (input.scalar_type() == at::ScalarType::BFloat16), + "Only fp16 and bf16 are supported"); + + const int batches = input.size(0); + const int attn_heads = input.size(1); + const int query_seq_len = input.size(2); + const int key_seq_len = input.size(3); + + TORCH_CHECK(key_seq_len <= 4096); + TORCH_CHECK(query_seq_len > 1); + + // Output + auto act_options = input.options().requires_grad(false); + auto softmax_results = + torch::empty({batches, attn_heads, query_seq_len, key_seq_len}, act_options); + + auto input_cu = makeTransformerEngineTensor(input); + auto softmax_results_cu = makeTransformerEngineTensor(softmax_results); + + nvte_scaled_softmax_forward(input_cu.data(), softmax_results_cu.data(), scale_factor, + at::cuda::getCurrentCUDAStream()); + + return softmax_results; +} + + +at::Tensor scaled_softmax_backward(at::Tensor output_grad_, + at::Tensor softmax_results_, + float scale_factor +) { + using namespace transformer_engine; + + auto output_grads = output_grad_.contiguous(); + auto softmax_results = softmax_results_.contiguous(); + + AT_ASSERTM(output_grads.dim() == 4, "expected 4D tensor"); + AT_ASSERTM(softmax_results.dim() == 4, "expected 4D tensor"); + + AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) || + (output_grads.scalar_type() == at::ScalarType::BFloat16), + "Only fp16 and bf16 are supported"); + AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) || + (softmax_results.scalar_type() == at::ScalarType::BFloat16), + "Only fp16 and bf16 are supported"); + + auto output_grads_cu = makeTransformerEngineTensor(output_grads); + auto softmax_results_cu = makeTransformerEngineTensor(softmax_results); + + // Produce gradients in place. 
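+  // output_grads serves as both the incoming gradient and the output buffer,
+  // so no additional tensor is allocated for the result.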
+ nvte_scaled_softmax_backward( + output_grads_cu.data(), softmax_results_cu.data(), output_grads_cu.data(), + scale_factor, at::cuda::getCurrentCUDAStream()); + + return output_grads; +} + + +at::Tensor scaled_masked_softmax_forward(at::Tensor input, + at::Tensor mask, + float scale_factor +) { + using namespace transformer_engine; + + AT_ASSERTM(input.dim() == 4, "expected 4D tensor"); + AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) || + (input.scalar_type() == at::ScalarType::BFloat16), + "Only fp16 and bf16 are supported"); + AT_ASSERTM(mask.dim() == 4, "expected 4D tensor"); + if (!input.is_contiguous()) + input = input.contiguous(); + if (!mask.is_contiguous()) + mask = mask.contiguous(); + + const int batches = input.size(0); + const int pad_batches = mask.size(0); + const int attn_heads = input.size(1); + const int query_seq_len = input.size(2); + const int key_seq_len = input.size(3); + TORCH_CHECK(key_seq_len <= 4096); + TORCH_CHECK(query_seq_len > 1); + TORCH_CHECK(pad_batches == 1 || pad_batches == batches); + TORCH_CHECK(mask.size(1) == 1); + TORCH_CHECK(mask.size(2) == query_seq_len); + TORCH_CHECK(mask.size(3) == key_seq_len); + + auto act_options = input.options().requires_grad(false); + auto softmax_results = + torch::empty({batches, attn_heads, query_seq_len, key_seq_len}, act_options); + + + auto input_cu = makeTransformerEngineTensor(input); + auto mask_cu = makeTransformerEngineTensor(mask); + auto softmax_results_cu = makeTransformerEngineTensor(softmax_results); + + nvte_scaled_masked_softmax_forward( + input_cu.data(), mask_cu.data(), softmax_results_cu.data(), + scale_factor, at::cuda::getCurrentCUDAStream()); + + return softmax_results; +} + + +at::Tensor scaled_masked_softmax_backward(at::Tensor output_grad_, + at::Tensor softmax_results_, + float scale_factor +) { + using namespace transformer_engine; + + auto output_grads = output_grad_.contiguous(); + auto softmax_results = softmax_results_.contiguous(); + + AT_ASSERTM(output_grads.dim() == 4, "expected 3D tensor"); + AT_ASSERTM(softmax_results.dim() == 4, "expected 3D tensor"); + + AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) || + (output_grads.scalar_type() == at::ScalarType::BFloat16), + "Only fp16 and bf16 are supported"); + AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) || + (softmax_results.scalar_type() == at::ScalarType::BFloat16), + "Only fp16 and bf16 are supported"); + + auto output_grads_cu = makeTransformerEngineTensor(output_grads); + auto softmax_results_cu = makeTransformerEngineTensor(softmax_results); + + // Produce gradients in place. 
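+  // The unmasked backward kernel is reused: masked positions carry zero
+  // probability in softmax_results, so their gradient is already zero and the
+  // mask does not need to be re-applied.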
+ nvte_scaled_softmax_backward( + output_grads_cu.data(), softmax_results_cu.data(), output_grads_cu.data(), + scale_factor, at::cuda::getCurrentCUDAStream()); + + return output_grads; +} + + +at::Tensor scaled_upper_triang_masked_softmax_forward(at::Tensor input, + float scale_factor +) { + using namespace transformer_engine; + + AT_ASSERTM(input.dim() == 3, "expected 3D tensor"); + AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) || + (input.scalar_type() == at::ScalarType::BFloat16), + "Only fp16 and bf16 are supported"); + + const int attn_batches = input.size(0); + const int seq_len = input.size(1); + TORCH_CHECK(seq_len <= 2048); + + // Output + auto act_options = input.options().requires_grad(false); + auto softmax_results = + torch::empty({attn_batches, seq_len, seq_len}, act_options); + + auto input_cu = makeTransformerEngineTensor(input); + auto softmax_results_cu = makeTransformerEngineTensor(softmax_results); + + nvte_scaled_upper_triang_masked_softmax_forward(input_cu.data(), + softmax_results_cu.data(), + scale_factor, + at::cuda::getCurrentCUDAStream()); + + return softmax_results; +} + + +at::Tensor scaled_upper_triang_masked_softmax_backward(at::Tensor output_grads_, + at::Tensor softmax_results_, + float scale_factor +) { + using namespace transformer_engine; + + auto output_grads = output_grads_.contiguous(); + auto softmax_results = softmax_results_.contiguous(); + + AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor"); + AT_ASSERTM(softmax_results.dim() == 3, "expected 3D tensor"); + + AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) || + (output_grads.scalar_type() == at::ScalarType::BFloat16), + "Only fp16 and bf16 are supported"); + AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) || + (softmax_results.scalar_type() == at::ScalarType::BFloat16), + "Only fp16 and bf16 are supported"); + + TORCH_CHECK(output_grads.size(1) == output_grads.size(2)); + + auto output_grads_cu = makeTransformerEngineTensor(output_grads); + auto softmax_results_cu = makeTransformerEngineTensor(softmax_results); + + // Produce gradients in place. + nvte_scaled_upper_triang_masked_softmax_backward(output_grads_cu.data(), + softmax_results_cu.data(), + output_grads_cu.data(), + scale_factor, + at::cuda::getCurrentCUDAStream()); + + return output_grads; +} diff --git a/transformer_engine/pytorch/csrc/extensions/transpose.cu b/transformer_engine/pytorch/csrc/extensions/transpose.cu new file mode 100644 index 0000000000..c58d474fb2 --- /dev/null +++ b/transformer_engine/pytorch/csrc/extensions/transpose.cu @@ -0,0 +1,321 @@ +/************************************************************************* + * Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. 
+ ************************************************************************/ + +#include "extensions.h" + +void fused_cast_transpose(at::Tensor input, + at::Tensor scale, + at::Tensor amax, + at::Tensor scale_inv, + at::Tensor input_cast, + at::Tensor input_transpose, + transformer_engine::DType otype +) { + using namespace transformer_engine; + + size_t M = static_cast(input.size(0)); + size_t N = static_cast(input.size(1)); + + auto input_cu = makeTransformerEngineTensor(input); + auto output_cast_cu = makeTransformerEngineTensor(input_cast.data_ptr(), {M, N}, otype, + amax.data_ptr(), scale.data_ptr(), + scale_inv.data_ptr()); + auto output_transpose_cu = makeTransformerEngineTensor(input_transpose.data_ptr(), {N, M}, otype, + amax.data_ptr(), scale.data_ptr(), + scale_inv.data_ptr()); + + nvte_cast_transpose(input_cu.data(), output_cast_cu.data(), output_transpose_cu.data(), + at::cuda::getCurrentCUDAStream()); +} + + +std::vector fused_cast_transpose_bgrad(at::Tensor grad_output, + at::Tensor scale, + at::Tensor amax, + at::Tensor scale_inv, + transformer_engine::DType otype +) { + using namespace transformer_engine; + + size_t M = static_cast(grad_output.size(0)); + size_t N = static_cast(grad_output.size(1)); + + DType grad_output_type = GetTransformerEngineDType(grad_output.scalar_type()); + auto grad_bias = allocateTorchTensor(grad_output.size(-1), grad_output_type); + auto grad_output_cast = + allocateTorchTensor(grad_output.size(0), + grad_output.size(1), + DType::kByte); + auto grad_output_transpose = + allocateTorchTensor(grad_output.size(1), + grad_output.size(0), + DType::kByte); + + auto input_cu = makeTransformerEngineTensor(grad_output); + auto cast_output_cu = makeTransformerEngineTensor(grad_output_cast.data_ptr(), {M, N}, + otype, amax.data_ptr(), scale.data_ptr(), + scale_inv.data_ptr()); + auto transposed_output_cu = makeTransformerEngineTensor(grad_output_transpose.data_ptr(), + {N, M}, otype, amax.data_ptr(), + scale.data_ptr(), scale_inv.data_ptr()); + auto dbias_cu = makeTransformerEngineTensor(grad_bias); + transformer_engine::TensorWrapper workspace; + + nvte_cast_transpose_dbias(input_cu.data(), cast_output_cu.data(), + transposed_output_cu.data(), dbias_cu.data(), + workspace.data(), at::cuda::getCurrentCUDAStream()); + + // Fill workspace + auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype()); + workspace = makeTransformerEngineTensor(workspace_data.data_ptr(), + workspace.shape(), + workspace.dtype()); + + nvte_cast_transpose_dbias(input_cu.data(), cast_output_cu.data(), + transposed_output_cu.data(), dbias_cu.data(), + workspace.data(), at::cuda::getCurrentCUDAStream()); + + return {grad_bias, grad_output_cast, grad_output_transpose}; +} + + +std::vector fused_fp8_transpose_bgrad(at::Tensor grad_output, + at::Tensor scale, + at::Tensor amax, + at::Tensor scale_inv, + transformer_engine::DType otype, + transformer_engine::DType grad_bias_type +) { + using namespace transformer_engine; + + size_t M = static_cast(grad_output.size(0)); + size_t N = static_cast(grad_output.size(1)); + + auto grad_bias = allocateTorchTensor(grad_output.size(-1), grad_bias_type); + auto grad_output_transpose = + allocateTorchTensor(grad_output.size(1), + grad_output.size(0), + DType::kByte); + auto input_cu = makeTransformerEngineTensor(grad_output.data_ptr(), {M, N}, + otype, amax.data_ptr(), scale.data_ptr(), + scale_inv.data_ptr()); + auto transposed_output_cu = makeTransformerEngineTensor(grad_output_transpose.data_ptr(), + {N, M}, otype, amax.data_ptr(), 
+ scale.data_ptr(), scale_inv.data_ptr()); + auto dbias_cu = makeTransformerEngineTensor(grad_bias); + transformer_engine::TensorWrapper workspace; + + nvte_fp8_transpose_dbias(input_cu.data(), transposed_output_cu.data(), dbias_cu.data(), + workspace.data(), at::cuda::getCurrentCUDAStream()); + + // Fill workspace + auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype()); + workspace = makeTransformerEngineTensor(workspace_data.data_ptr(), + workspace.shape(), + workspace.dtype()); + + nvte_fp8_transpose_dbias(input_cu.data(), transposed_output_cu.data(), dbias_cu.data(), + workspace.data(), at::cuda::getCurrentCUDAStream()); + + return {grad_bias, grad_output_transpose}; +} + + + +std::vector fused_cast_transpose_bgrad_dgelu(at::Tensor grad_output, + at::Tensor gelu_input, + at::Tensor scale, + at::Tensor amax, + at::Tensor scale_inv, + transformer_engine::DType otype +) { + using namespace transformer_engine; + + size_t M = static_cast(grad_output.size(0)); + size_t N = static_cast(grad_output.size(1)); + + DType grad_output_type = GetTransformerEngineDType(grad_output.scalar_type()); + auto grad_bias = allocateTorchTensor(grad_output.size(-1), grad_output_type); + auto dgelu = + allocateTorchTensor(grad_output.size(0), + grad_output.size(1), + DType::kByte); + auto dgelu_transpose = + allocateTorchTensor(grad_output.size(1), + grad_output.size(0), + DType::kByte); + + transformer_engine::TensorWrapper workspace; + auto gelu_input_cu = makeTransformerEngineTensor(gelu_input); + auto input_cu = makeTransformerEngineTensor(grad_output); + auto cast_output_cu = makeTransformerEngineTensor(dgelu.data_ptr(), {M, N}, + otype, amax.data_ptr(), scale.data_ptr(), + scale_inv.data_ptr()); + auto transposed_output_cu = makeTransformerEngineTensor(dgelu_transpose.data_ptr(), {N, M}, + otype, amax.data_ptr(), scale.data_ptr(), + scale_inv.data_ptr()); + auto dbias_cu = makeTransformerEngineTensor(grad_bias); + + nvte_cast_transpose_dbias_dgelu(input_cu.data(), gelu_input_cu.data(), + cast_output_cu.data(), transposed_output_cu.data(), + dbias_cu.data(), workspace.data(), + at::cuda::getCurrentCUDAStream()); + + // Fill workspace + auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype()); + workspace = makeTransformerEngineTensor(workspace_data.data_ptr(), + workspace.shape(), + workspace.dtype()); + + nvte_cast_transpose_dbias_dgelu(input_cu.data(), gelu_input_cu.data(), + cast_output_cu.data(), transposed_output_cu.data(), + dbias_cu.data(), workspace.data(), + at::cuda::getCurrentCUDAStream()); + + return {grad_bias, dgelu, dgelu_transpose}; +} + + +void fused_multi_cast_transpose(std::vector input_list, + std::vector scale_list, + std::vector cast_output_list, + std::vector transposed_output_list, + std::vector amax_list, + std::vector scale_inv_list, + transformer_engine::DType otype +) { + using namespace transformer_engine; + + // Extract properties from PyTorch tensors + std::vector input_dptr_list, scale_dptr_list, + cast_output_dptr_list, transposed_output_dptr_list, + amax_dptr_list, scale_inv_dptr_list; + std::vector> input_shape_list, scale_shape_list, + cast_output_shape_list, transposed_output_shape_list, + amax_shape_list, scale_inv_shape_list; + std::vector input_type_list, scale_type_list, + cast_output_type_list, transposed_output_type_list, + amax_type_list, scale_inv_type_list; + auto extract_tensor_props_skip_dtype = [](at::Tensor& tensor, + std::vector& dptr_list, + std::vector>& shape_list) { + dptr_list.push_back(tensor.data_ptr()); + 
shape_list.push_back({}); + for (int d = 0; d < tensor.dim(); ++d) { + shape_list.back().push_back(tensor.size(d)); + } + }; + auto extract_tensor_props = [](at::Tensor& tensor, + std::vector& dptr_list, + std::vector>& shape_list, + std::vector& type_list) { + dptr_list.push_back(tensor.data_ptr()); + shape_list.push_back({}); + for (int d = 0; d < tensor.dim(); ++d) { + shape_list.back().push_back(tensor.size(d)); + } + type_list.push_back(GetTransformerEngineDType(tensor.scalar_type())); + }; + for (size_t tensor_id = 0; tensor_id < input_list.size(); ++tensor_id) { + extract_tensor_props(input_list[tensor_id], + input_dptr_list, + input_shape_list, + input_type_list); + extract_tensor_props(scale_list[tensor_id], + scale_dptr_list, + scale_shape_list, + scale_type_list); + extract_tensor_props_skip_dtype(cast_output_list[tensor_id], + cast_output_dptr_list, + cast_output_shape_list); + cast_output_type_list.push_back(otype); + extract_tensor_props_skip_dtype(transposed_output_list[tensor_id], + transposed_output_dptr_list, + transposed_output_shape_list); + transposed_output_type_list.push_back(otype); + extract_tensor_props(amax_list[tensor_id], + amax_dptr_list, + amax_shape_list, + amax_type_list); + extract_tensor_props(scale_inv_list[tensor_id], + scale_inv_dptr_list, + scale_inv_shape_list, + scale_inv_type_list); + } + + transformer_engine::TensorWrapper workspace; + + // Construct TE tensors + std::vector nvte_input_list, + nvte_cast_output_list, nvte_transposed_output_list; + std::vector tensor_wrappers; + auto make_tensor = [&tensor_wrappers](void* dptr, + const std::vector& shape, + transformer_engine::DType dtype, + void* amax_dptr, + void* scale_dptr, + void* scale_inv_dptr) + -> NVTETensor { + tensor_wrappers.emplace_back(makeTransformerEngineTensor(dptr, shape, dtype, amax_dptr, + scale_dptr, scale_inv_dptr)); + return tensor_wrappers.back().data(); + }; + for (size_t i = 0; i < input_dptr_list.size(); ++i) { + nvte_input_list.emplace_back(make_tensor(input_dptr_list[i], + input_shape_list[i], + input_type_list[i], + nullptr, + nullptr, + nullptr)); + nvte_cast_output_list.emplace_back(make_tensor(cast_output_dptr_list[i], + cast_output_shape_list[i], + cast_output_type_list[i], + amax_dptr_list[i], + scale_dptr_list[i], + scale_inv_dptr_list[i])); + nvte_transposed_output_list.emplace_back(make_tensor(transposed_output_dptr_list[i], + transposed_output_shape_list[i], + transposed_output_type_list[i], + amax_dptr_list[i], + scale_dptr_list[i], + scale_inv_dptr_list[i])); + } + + // Check tensor lists + NVTE_CHECK(nvte_cast_output_list.size() == nvte_input_list.size(), + "Number of input and C output tensors must match"); + NVTE_CHECK(nvte_transposed_output_list.size() == nvte_input_list.size(), + "Number of input and T output tensors must match"); + + // Launch TE kernel + nvte_multi_cast_transpose(nvte_input_list.size(), + nvte_input_list.data(), + nvte_cast_output_list.data(), + nvte_transposed_output_list.data(), + at::cuda::getCurrentCUDAStream()); +} + + +at::Tensor fp8_transpose(at::Tensor input, + transformer_engine::DType otype +) { + using namespace transformer_engine; + + size_t M = static_cast(input.size(0)); + size_t N = static_cast(input.size(1)); + + auto output = + allocateTorchTensor(input.size(1), + input.size(0), + DType::kByte); + + auto input_cu = makeTransformerEngineTensor(input.data_ptr(), {M, N}, otype); + auto output_cu = makeTransformerEngineTensor(output.data_ptr(), {N, M}, otype); + + nvte_transpose(input_cu.data(), output_cu.data(), 
at::cuda::getCurrentCUDAStream()); + + return output; +} diff --git a/transformer_engine/pytorch/csrc/ts_fp8_op.cpp b/transformer_engine/pytorch/csrc/ts_fp8_op.cpp index b0424d6f4b..6f38253052 100755 --- a/transformer_engine/pytorch/csrc/ts_fp8_op.cpp +++ b/transformer_engine/pytorch/csrc/ts_fp8_op.cpp @@ -328,6 +328,44 @@ at::Tensor layernorm_fwd_inf_ts(const at::Tensor &input, return output; } +at::Tensor rmsnorm_fwd_fp8_inf_ts(const at::Tensor &input, + const at::Tensor &weight, + double eps, + at::Tensor scale, + at::Tensor amax, + at::Tensor scale_inv, + int64_t fp8_tensor, + int64_t otype, + const bool zero_centered_gamma) { + transformer_engine::DType otype_arg = reverse_map_dtype(otype); + float eps_float = static_cast(eps); + + at::Tensor output = rmsnorm_fwd_fp8_inf(input, + weight, + eps_float, + scale, + amax, + scale_inv, + otype_arg, + zero_centered_gamma); + + return output; +} + +at::Tensor rmsnorm_fwd_inf_ts(const at::Tensor &input, + const at::Tensor &weight, + double eps, + const bool zero_centered_gamma) { + float eps_float = static_cast(eps); + + at::Tensor output = rmsnorm_fwd_inf(input, + weight, + eps_float, + zero_centered_gamma); + + return output; +} + TORCH_LIBRARY(tex_ts, m) { m.def("cast_to_fp8_ts", &cast_to_fp8_ts); m.def("cast_from_fp8_ts", &cast_from_fp8_ts); @@ -339,4 +377,6 @@ TORCH_LIBRARY(tex_ts, m) { m.def("te_gemm_ts", &te_gemm_ts); m.def("layernorm_fwd_fp8_inf_ts", &layernorm_fwd_fp8_inf_ts); m.def("layernorm_fwd_inf_ts", &layernorm_fwd_inf_ts); + m.def("rmsnorm_fwd_fp8_inf_ts", &rmsnorm_fwd_fp8_inf_ts); + m.def("rmsnorm_fwd_inf_ts", &rmsnorm_fwd_inf_ts); } diff --git a/transformer_engine/pytorch/module/__init__.py b/transformer_engine/pytorch/module/__init__.py index fef96e7738..51463eb12d 100644 --- a/transformer_engine/pytorch/module/__init__.py +++ b/transformer_engine/pytorch/module/__init__.py @@ -7,3 +7,4 @@ from .linear import Linear from .layernorm_mlp import LayerNormMLP from .layernorm import LayerNorm +from .rmsnorm import RMSNorm diff --git a/transformer_engine/pytorch/module/_common.py b/transformer_engine/pytorch/module/_common.py new file mode 100644 index 0000000000..4b8d4de643 --- /dev/null +++ b/transformer_engine/pytorch/module/_common.py @@ -0,0 +1,95 @@ +# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Internal function used by multiple modules.""" + +from typing import Union, Dict, Any + +import torch + +from .. import cpp_extensions as tex +from ..fp8 import get_fp8_te_dtype + +def _get_normalization_func(normalization: str, + fp8_output: bool, + is_grad_enabled: bool, + forward: bool): + fwd_normalization_funcs = { + ('LayerNorm', True, True): tex.layernorm_fwd_fp8, + ('LayerNorm', True, False): tex.layernorm_fwd_fp8_inf, + ('LayerNorm', False, True): tex.layernorm_fwd_noalloc, + ('LayerNorm', False, False): tex.layernorm_fwd_inf, + ('RMSNorm', True, True): tex.rmsnorm_fwd_fp8, + ('RMSNorm', True, False): tex.rmsnorm_fwd_fp8_inf, + ('RMSNorm', False, True): tex.rmsnorm_fwd_noalloc, + ('RMSNorm', False, False): tex.rmsnorm_fwd_inf, + } + bwd_normalization_funcs = { + 'LayerNorm': tex.layernorm_bwd, + 'RMSNorm': tex.rmsnorm_bwd, + } + + if forward: + return fwd_normalization_funcs[(normalization, fp8_output, is_grad_enabled)] + assert not fp8_output, "FP8 output is not supported in backward normalization!" + assert is_grad_enabled, "Gradient has to be enabled to call backward normalization!" 
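+    # The forward table above is keyed on (normalization, fp8 output, grad enabled);
+    # the backward path has no FP8 or inference variants, hence the asserts guarding
+    # the single lookup below.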
+ return bwd_normalization_funcs[normalization] + +def _apply_normalization(inputmat:torch.Tensor, + ln_out: torch.Tensor, + ln_weight: torch.Tensor, + ln_bias: Union[torch.Tensor, None], + eps: float, + fp8_out: bool, + fp8_meta: Dict[str, Any], + normalization: str, + fwd_ln_sm_margin: int, + zero_centered_gamma: bool, + is_grad_enabled: bool): + normalization_func = _get_normalization_func(normalization, + fp8_out, + is_grad_enabled, + True) + + inputs = (inputmat, ln_weight) if ln_bias is None else (inputmat, ln_weight, ln_bias) + if fp8_out: + fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) + + if is_grad_enabled: + output_key = "ln_out" if normalization == "LayerNorm" else "rmsnorm_out" + output_kwarg = {output_key: ln_out} + output = normalization_func( + *inputs, + eps, + fp8_meta["scaling_fwd"], + tex.FP8FwdTensors.GEMM1_INPUT, + fp8_dtype_forward, + fwd_ln_sm_margin, + zero_centered_gamma, + **output_kwarg, + ) + else: + return normalization_func( + *inputs, + eps, + fp8_meta["scaling_fwd"], + tex.FP8FwdTensors.GEMM1_INPUT, + fp8_dtype_forward, + zero_centered_gamma, + ), None, None + else: + if is_grad_enabled: + output = normalization_func( + *inputs, ln_out, eps, + fwd_ln_sm_margin, zero_centered_gamma + ) + else: + return normalization_func( + *inputs, eps, zero_centered_gamma + ), None, None + if normalization == "RMSNorm": + output = (ln_out, None, output[1]) + elif normalization == "LayerNorm": + output = (ln_out, output[1], output[2]) + return output diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index c18da5ed85..698d88a284 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -12,7 +12,7 @@ from torch.nn.parameter import Parameter from torch.nn import init -import transformer_engine_extensions as tex +from .. import cpp_extensions as tex from .base import ( get_workspace, @@ -38,22 +38,13 @@ reduce_scatter_along_first_dim, gather_along_first_dim, ) -from ..cpp_extensions import ( - fp8_gemm, - gemm, - fp8_cast_transpose_fused, - layernorm_fwd_fp8, - layernorm_fwd_fp8_inf, - layernorm_fwd_inf, - cast_to_fp8, - cast_from_fp8, -) from ..constants import GemmParallelModes, dist_group_type, TE_DType from ..jit import no_torch_dynamo +from ._common import _apply_normalization -__all__ = ["LayerNormLinear"] +__all__ = ["LayerNormLinear"] class _LayerNormLinear(torch.autograd.Function): """LayerNormLinear semi-top level module @@ -65,7 +56,7 @@ def forward( ctx, inp: torch.Tensor, ln_weight: torch.Tensor, - ln_bias: torch.Tensor, + ln_bias: Union[torch.Tensor, None], weight: torch.Tensor, weight_fp8: Union[torch.Tensor, None], weight_t_fp8: Union[torch.Tensor, None], @@ -91,6 +82,7 @@ def forward( ub_bulk_wgrad: bool, ub_bulk_dgrad: bool, ub_split_ag: bool, + normalization: str, ) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]: # Make sure input dimensions are compatible in_features = ln_weight.numel() @@ -105,10 +97,9 @@ def forward( # Cast for native AMP inputmat = cast_if_needed(inputmat, activation_dtype) ln_weight = cast_if_needed(ln_weight, activation_dtype) - ln_bias = cast_if_needed(ln_bias, activation_dtype) - # If residual connection is after LN, we need `ln_out` - # tensor in higher precision, this comes at the cost - # of an extra fp8 cast. 
+ if ln_bias is not None: + ln_bias = cast_if_needed(ln_bias, activation_dtype) + if ub_split_ag: tp_world_size = get_distributed_world_size(tp_group) if tp_world_size == 1 or (not is_grad_enabled) or return_layernorm_output: @@ -118,69 +109,35 @@ def forward( dim_size[0] = dim_size[0] * tp_world_size ub_obj_lnout = get_ub("qkv_fprop") ln_out = ub_obj_lnout.get_ubuf_output(0) - if fp8: - fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) - - if not return_layernorm_output: - if is_grad_enabled: - if not ub_split_ag: - ln_out = torch.empty_like(inputmat, dtype=torch.uint8) - _, mu, rsigma = layernorm_fwd_fp8( - inputmat, - ln_weight, - ln_bias, - eps, - fp8_meta["scaling_fwd"], - tex.FP8FwdTensors.GEMM1_INPUT, - fp8_dtype_forward, - fwd_ln_sm_margin, - zero_centered_gamma, - ln_out = ln_out - ) - else: - mu = rsigma = None - ln_out = layernorm_fwd_fp8_inf( - inputmat, - ln_weight, - ln_bias, - eps, - fp8_meta["scaling_fwd"], - tex.FP8FwdTensors.GEMM1_INPUT, - fp8_dtype_forward, - zero_centered_gamma, - ) - else: - if is_grad_enabled: - ln_out_return, mu, rsigma = tex.layernorm_fwd( - inputmat, ln_weight, ln_bias, eps, fwd_ln_sm_margin, zero_centered_gamma - ) - else: - ln_out_return, mu, rsigma = layernorm_fwd_inf( - inputmat, ln_weight, ln_bias, eps, zero_centered_gamma - ), None, None - - ln_out = cast_to_fp8( - ln_out_return, + else: + ln_out_dtype = torch.uint8 if fp8 else inputmat.dtype + ln_out = torch.empty_like(inputmat, dtype=ln_out_dtype) + + fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) + + ln_out, mu, rsigma = _apply_normalization(inputmat, + ln_out, + ln_weight, + ln_bias, + eps, + fp8 and not return_layernorm_output, + fp8_meta, + normalization, + fwd_ln_sm_margin, + zero_centered_gamma, + is_grad_enabled) + # If residual connection is after LN, we need `ln_out_return` + # tensor in higher precision, this comes at the cost + # of an extra fp8 cast. 
+ if return_layernorm_output: + ln_out_return = ln_out + if fp8: + ln_out = tex.cast_to_fp8( + ln_out, fp8_meta["scaling_fwd"], tex.FP8FwdTensors.GEMM1_INPUT, fp8_dtype_forward, ) - else: - if is_grad_enabled: - if ub_split_ag: - _, mu, rsigma = tex.layernorm_fwd_noalloc( - inputmat, ln_weight, ln_bias, ln_out, eps, - fwd_ln_sm_margin, zero_centered_gamma - ) - else: - ln_out, mu, rsigma = tex.layernorm_fwd( - inputmat, ln_weight, ln_bias, eps, fwd_ln_sm_margin, zero_centered_gamma - ) - else: - ln_out, mu, rsigma = layernorm_fwd_inf( - inputmat, ln_weight, ln_bias, eps, zero_centered_gamma - ), None, None - ln_out_return = ln_out # Column Parallel Linear if ub_split_ag: ln_out_total = ub_obj_lnout.get_ubuf_output(1) @@ -200,7 +157,7 @@ def forward( if update_fp8_weights: if is_grad_enabled: - fp8_cast_transpose_fused( + tex.fp8_cast_transpose_fused( weight, fp8_meta["scaling_fwd"], tex.FP8FwdTensors.GEMM1_WEIGHT, @@ -210,13 +167,13 @@ def forward( ) else: weight_t_fp8 = None - weight_fp8 = cast_to_fp8( + weight_fp8 = tex.cast_to_fp8( weight, fp8_meta["scaling_fwd"], tex.FP8FwdTensors.GEMM1_WEIGHT, fp8_dtype_forward) - out = fp8_gemm( + out = tex.fp8_gemm( weight_fp8, fp8_meta["scaling_fwd"].scale_inv, tex.FP8FwdTensors.GEMM1_WEIGHT, @@ -247,7 +204,7 @@ def forward( fp8_meta["scaling_fwd"].amax_history[0][tex.FP8FwdTensors.GEMM1_WEIGHT] = \ torch.amax(weight).float() - out, _, _ = gemm( + out, _, _ = tex.gemm( weight, ln_out_total, activation_dtype, @@ -289,6 +246,7 @@ def forward( ctx.ub_bulk_wgrad = ub_bulk_wgrad ctx.ub_bulk_dgrad = ub_bulk_dgrad ctx.requires_dgrad = inp.requires_grad + ctx.normalization = normalization # Row Parallel Linear if parallel_mode == "row" and sequence_parallel: @@ -379,7 +337,7 @@ def backward( ) # DGRAD: Evaluated unconditionally to feed into Linear backward - _ = fp8_gemm( + _ = tex.fp8_gemm( weight_t_fp8, fwd_scale_inverses, tex.FP8FwdTensors.GEMM1_WEIGHT, @@ -397,7 +355,7 @@ def backward( ) else: # DGRAD: Evaluated unconditionally to feed into Linear backward - _, _, _ = gemm( + _, _, _ = tex.gemm( weight, grad_output, ctx.activation_dtype, @@ -427,7 +385,7 @@ def backward( # WGRAD if not ctx.fp8_meta["recipe"].override_linear_precision.wgrad: ln_out_total_t = tex.fp8_transpose(ln_out_total, fp8_dtype_forward) - wgrad = fp8_gemm( + wgrad = tex.fp8_gemm( ln_out_total_t, fwd_scale_inverses, tex.FP8FwdTensors.GEMM1_INPUT, @@ -446,14 +404,14 @@ def backward( ub=ub_obj_dgrad if ctx.ub_bulk_wgrad else None ) else: - ln_out_total_c = cast_from_fp8( + ln_out_total_c = tex.cast_from_fp8( ln_out_total, ctx.fp8_meta["scaling_fwd"], tex.FP8FwdTensors.GEMM1_INPUT, fp8_dtype_forward, TE_DType[ctx.activation_dtype], ) - wgrad, _, _ = gemm( + wgrad, _, _ = tex.gemm( ln_out_total_c, grad_output, ctx.activation_dtype, @@ -468,7 +426,7 @@ def backward( ) else: # WGRAD - wgrad, grad_bias, _ = gemm( + wgrad, grad_bias, _ = tex.gemm( ln_out_total, grad_output, ctx.activation_dtype, @@ -496,10 +454,18 @@ def backward( if ctx.return_layernorm_output: d_ln_out = d_ln_out + grad_outputs[1].view_as(d_ln_out) - dxmat, dgamma, dbeta = tex.layernorm_bwd( - d_ln_out, inputmat, mu, rsigma, ln_weight, - ctx.bwd_ln_sm_margin, ctx.zero_centered_gamma - ) + if ctx.normalization == "LayerNorm": + dxmat, dgamma, dbeta = tex.layernorm_bwd( + d_ln_out, inputmat, mu, rsigma, ln_weight, + ctx.bwd_ln_sm_margin, ctx.zero_centered_gamma + ) + elif ctx.normalization == "RMSNorm": + dxmat, dgamma = tex.rmsnorm_bwd( + d_ln_out, inputmat, rsigma, ln_weight, + ctx.bwd_ln_sm_margin, 
ctx.zero_centered_gamma + ) + dbeta = None + if not ctx.use_bias: grad_bias = None @@ -533,6 +499,7 @@ def backward( None, None, None, + None, ) @@ -555,6 +522,8 @@ class LayerNormLinear(TransformerEngineBaseModule): a value added to the denominator of layer normalization for numerical stability. bias : bool, default = `True` if set to `False`, the layer will not learn an additive bias. + normalization : { 'LayerNorm', 'RMSNorm' }, default = 'LayerNorm' + type of normalization applied. init_method : Callable, default = `None` used for initializing weights in the following way: `init_method(weight)`. When set to `None`, defaults to `torch.nn.init.normal_(mean=0.0, std=0.023)`. @@ -624,6 +593,7 @@ def __init__( get_rng_state_tracker: Optional[Callable] = None, init_method: Optional[Callable] = None, bias: bool = True, + normalization: str = 'LayerNorm', return_bias: bool = False, params_dtype: Optional[torch.dtype] = None, parallel_mode: Optional[str] = None, @@ -649,9 +619,11 @@ def __init__( self.in_features = in_features self.out_features = out_features self.fuse_wgrad_accumulation = fuse_wgrad_accumulation + self.normalization = normalization + assert normalization in ['LayerNorm', 'RMSNorm'], "Unsupported normalization type!" self.use_bias = bias self.return_bias = return_bias - self.apply_bias = bias and not return_bias + self.apply_bias = self.use_bias and not return_bias self.return_layernorm_output = return_layernorm_output self.parameters_split = parameters_split self.zero_centered_gamma = zero_centered_gamma @@ -696,15 +668,18 @@ def __init__( dtype=params_dtype, ) ) - self.layer_norm_bias = Parameter( - torch.empty( - in_features, - device=torch.cuda.current_device(), - dtype=params_dtype, - ) - ) setattr(self.layer_norm_weight, "sequence_parallel", self.sequence_parallel) - setattr(self.layer_norm_bias, "sequence_parallel", self.sequence_parallel) + if self.normalization != "RMSNorm": + self.layer_norm_bias = Parameter( + torch.empty( + in_features, + device=torch.cuda.current_device(), + dtype=params_dtype, + ) + ) + setattr(self.layer_norm_bias, "sequence_parallel", self.sequence_parallel) + else: + self.layer_norm_bias = None self.reset_layer_norm_parameters() self.weight_tensor = torch.empty( @@ -796,7 +771,8 @@ def reset_layer_norm_parameters(self) -> None: init.ones_(self.layer_norm_weight) else: init.zeros_(self.layer_norm_weight) - init.zeros_(self.layer_norm_bias) + if self.layer_norm_bias is not None: + init.zeros_(self.layer_norm_bias) def get_fp8_weights_scratchpad( self, @@ -915,6 +891,7 @@ def forward( self.ub_bulk_wgrad, self.ub_bulk_dgrad, self.ub_split_ag, + self.normalization, ) out = fwd_fn(*args) diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index bce92cabd7..d2d866667b 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -46,6 +46,8 @@ from ..constants import dist_group_type, TE_DType from ..jit import no_torch_dynamo +from ._common import _apply_normalization + __all__ = ["LayerNormMLP"] @@ -107,6 +109,7 @@ def forward( ub_split_rs: bool, ub_split_ag: bool, activation: str, + normalization: str, ) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]: # Make sure input dimensions are compatible in_features = ln_weight.numel() @@ -124,7 +127,8 @@ def forward( # Cast for native AMP inputmat = cast_if_needed(inputmat, activation_dtype) ln_weight = cast_if_needed(ln_weight, activation_dtype) - ln_bias = cast_if_needed(ln_bias, 
activation_dtype) + if ln_bias is not None: + ln_bias = cast_if_needed(ln_bias, activation_dtype) if ub_split_ag: tp_world_size = get_distributed_world_size(tp_group) @@ -133,70 +137,39 @@ def forward( if ub_split_ag: ub_obj_lnout = get_ub("fc1_fprop") ln_out = ub_obj_lnout.get_ubuf_output(0) + else: + ln_out_dtype = torch.uint8 if (fp8 and not return_layernorm_output) else inputmat.dtype + ln_out = torch.empty_like(inputmat, dtype=ln_out_dtype) if ub_split_rs: tp_world_size = get_distributed_world_size(tp_group) if tp_world_size == 1: ub_split_rs = False + fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) + + ln_out, mu, rsigma = _apply_normalization(inputmat, + ln_out, + ln_weight, + ln_bias, + eps, + fp8 and not return_layernorm_output, + fp8_meta, + normalization, + fwd_ln_sm_margin, + zero_centered_gamma, + is_grad_enabled) # If residual connection is after LN, we need `ln_out` # tensor in higher precision, this comes at the cost # of an extra fp8 cast. - if fp8: - fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) - if not return_layernorm_output: - if is_grad_enabled: - if not ub_split_ag: - ln_out = torch.empty_like(inputmat, dtype=torch.uint8) - _, mu, rsigma = tex.layernorm_fwd_fp8( - inputmat, - ln_weight, - ln_bias, - eps, - fp8_meta["scaling_fwd"], - tex.FP8FwdTensors.GEMM1_INPUT, - fp8_dtype_forward, - fwd_ln_sm_margin, - zero_centered_gamma, - ln_out = ln_out, - ) - else: - ln_out = tex.layernorm_fwd_fp8_inf( - inputmat, - ln_weight, - ln_bias, - eps, - fp8_meta["scaling_fwd"], - tex.FP8FwdTensors.GEMM1_INPUT, - fp8_dtype_forward, - zero_centered_gamma, - ) - else: - ln_out_return, mu, rsigma = tex.layernorm_fwd( - inputmat, ln_weight, ln_bias, eps, fwd_ln_sm_margin, zero_centered_gamma - ) + if return_layernorm_output: + ln_out_return = ln_out + if fp8: ln_out = tex.cast_to_fp8( - ln_out_return, + ln_out, fp8_meta["scaling_fwd"], tex.FP8FwdTensors.GEMM1_INPUT, fp8_dtype_forward, ) - else: - if is_grad_enabled: - if ub_split_ag: - _, mu, rsigma = tex.layernorm_fwd_noalloc( - inputmat, ln_weight, ln_bias, ln_out, eps, - fwd_ln_sm_margin, zero_centered_gamma - ) - else: - ln_out, mu, rsigma = tex.layernorm_fwd( - inputmat, ln_weight, ln_bias, eps, fwd_ln_sm_margin, zero_centered_gamma - ) - else: - ln_out, mu, rsigma = tex.layernorm_fwd_inf( - inputmat, ln_weight, ln_bias, eps, zero_centered_gamma - ), None, None - - ln_out_return = ln_out # Column Parallel Linear if ub_split_ag: ln_out_total = ub_obj_lnout.get_ubuf_output(1) @@ -422,6 +395,7 @@ def forward( ctx.ub_bulk_dgrad = ub_bulk_dgrad ctx.ub_split_ag = ub_split_ag ctx.requires_dgrad = inp.requires_grad + ctx.normalization = normalization # Row Parallel Linear if ub_split_rs: @@ -804,10 +778,17 @@ def backward( if ctx.return_layernorm_output: d_ln_out = d_ln_out + grad_outputs[1].view_as(d_ln_out) - dxmat, dgamma, dbeta = tex.layernorm_bwd( - d_ln_out, inputmat, mu, rsigma, ln_weight, - ctx.bwd_ln_sm_margin, ctx.zero_centered_gamma - ) + if ctx.normalization == "LayerNorm": + dxmat, dgamma, dbeta = tex.layernorm_bwd( + d_ln_out, inputmat, mu, rsigma, ln_weight, + ctx.bwd_ln_sm_margin, ctx.zero_centered_gamma + ) + elif ctx.normalization == "RMSNorm": + dxmat, dgamma = tex.rmsnorm_bwd( + d_ln_out, inputmat, rsigma, ln_weight, + ctx.bwd_ln_sm_margin, ctx.zero_centered_gamma + ) + dbeta = None return ( dxmat.view(ctx.inp_shape) if ctx.requires_dgrad else None, @@ -846,6 +827,7 @@ def backward( None, None, None, + None, ) @@ -864,6 +846,8 @@ class 
LayerNormMLP(TransformerEngineBaseModule): a value added to the denominator of layer normalization for numerical stability. bias : bool, default = `True` if set to `False`, the FC1 and FC2 layers will not learn an additive bias. + normalization : { 'LayerNorm', 'RMSNorm' }, default = 'LayerNorm' + type of normalization applied. activation : str, default = 'gelu' activation function used. Options: 'gelu', 'geglu', 'relu', 'reglu', 'squared_relu', 'swiglu'. @@ -942,6 +926,7 @@ def __init__( tp_size: int = 1, init_method: Optional[Callable] = None, bias: bool = True, + normalization: str = 'LayerNorm', activation : str = "gelu", output_layer_init_method: Optional[Callable] = None, fuse_wgrad_accumulation: bool = False, @@ -960,6 +945,8 @@ def __init__( params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype self.fuse_wgrad_accumulation = fuse_wgrad_accumulation + self.normalization = normalization + assert normalization in ['LayerNorm', 'RMSNorm'], "Unsupported normalization type!" self.use_bias = bias self.activation = activation self.return_bias = return_bias @@ -1005,15 +992,18 @@ def __init__( dtype=params_dtype, ) ) - self.layer_norm_bias = Parameter( - torch.empty( - hidden_size, - device=torch.cuda.current_device(), - dtype=params_dtype, - ) - ) setattr(self.layer_norm_weight, "sequence_parallel", self.sequence_parallel) - setattr(self.layer_norm_bias, "sequence_parallel", self.sequence_parallel) + if self.normalization != "RMSNorm": + self.layer_norm_bias = Parameter( + torch.empty( + hidden_size, + device=torch.cuda.current_device(), + dtype=params_dtype, + ) + ) + setattr(self.layer_norm_bias, "sequence_parallel", self.sequence_parallel) + else: + self.layer_norm_bias = None self.reset_layer_norm_parameters() if self.activation in ['reglu', 'geglu', 'swiglu']: @@ -1114,7 +1104,8 @@ def reset_layer_norm_parameters(self) -> None: init.ones_(self.layer_norm_weight) else: init.zeros_(self.layer_norm_weight) - init.zeros_(self.layer_norm_bias) + if self.layer_norm_bias is not None: + init.zeros_(self.layer_norm_bias) def get_fp8_weights_scratchpad( self, @@ -1217,6 +1208,7 @@ def forward( self.ub_split_rs, self.ub_split_ag, self.activation, + self.normalization, ) out = fwd_fn(*args) diff --git a/transformer_engine/pytorch/module/rmsnorm.py b/transformer_engine/pytorch/module/rmsnorm.py new file mode 100644 index 0000000000..dc7db1a221 --- /dev/null +++ b/transformer_engine/pytorch/module/rmsnorm.py @@ -0,0 +1,168 @@ +# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""RMSNorm API""" +import os +from typing import Union, Tuple, Optional + +import torch +from torch.nn.parameter import Parameter +from torch.nn import init + +from .. import cpp_extensions as tex +from ..jit import no_torch_dynamo + + +__all__ = ["RMSNorm"] + + +class _RMSNorm(torch.autograd.Function): + """functional RMSNorm""" + + @staticmethod + def forward( + ctx, + inp: torch.Tensor, + rmsnorm_weight: torch.Tensor, + eps: float, + fwd_rmsnorm_sm_margin: int, + bwd_rmsnorm_sm_margin: int, + zero_centered_gamma: bool, + is_grad_enabled: bool, + ) -> torch.Tensor: + # Make sure input dimensions are compatible + in_features = rmsnorm_weight.numel() + assert inp.is_cuda, "TransformerEngine needs CUDA." 
+ assert inp.shape[-1] == in_features, "RMSNorm not possible" + inputmat = inp.view((-1, in_features)) + + if is_grad_enabled: + rmsnorm_out, rsigma = tex.rmsnorm_fwd(inputmat, rmsnorm_weight, + eps, fwd_rmsnorm_sm_margin, + zero_centered_gamma) + ctx.save_for_backward(inputmat, rmsnorm_weight, rsigma) + ctx.inp_shape = inp.shape + ctx.bwd_rmsnorm_sm_margin = bwd_rmsnorm_sm_margin + ctx.zero_centered_gamma = zero_centered_gamma + else: + rmsnorm_out = tex.rmsnorm_fwd_inf(inputmat, rmsnorm_weight, + eps, + zero_centered_gamma) + return rmsnorm_out.view_as(inp) + + @staticmethod + def backward( + ctx, grad_output: torch.Tensor + ) -> Tuple[Union[torch.Tensor, None], ...]: + inputmat, rmsnorm_weight, rsigma = ctx.saved_tensors + grad_output = grad_output.contiguous() + d_rmsnorm_out = grad_output.view(inputmat.shape) + dxmat, dgamma = tex.rmsnorm_bwd( + d_rmsnorm_out, inputmat, rsigma, rmsnorm_weight, + ctx.bwd_rmsnorm_sm_margin, ctx.zero_centered_gamma + ) + return ( + dxmat.view(ctx.inp_shape), + dgamma, + None, + None, + None, + None, + None, + ) + + +class RMSNorm(torch.nn.Module): + r""" + Applies Root Mean Square Layer Normalization over a mini-batch of inputs as described in + the paper `Root Mean Square Layer Normalization `__ + + .. math:: + y = \frac{x}{RMS(x) + \varepsilon} * \gamma + + where + + .. math:: + RMS(x) = \sqrt{\frac{1}{n}\sum_{i=0}^nx_i^2} + + :math:`\gamma` is a learnable affine transform parameter of size :attr:`hidden_size` + + Parameters + ---------- + hidden_size : int + size of each input sample. + eps : float, default = 1e-5 + a value added to the denominator of layer normalization for numerical stability. + sequence_parallel : bool, default = `False` + if set to `True`, uses sequence parallelism. + params_dtype : torch.dtype, default = `torch.get_default_dtype()` + it controls the type used to allocate the initial parameters. Useful when + the model is trained with lower precision and the original FP32 parameters + would not fit in GPU memory. + zero_centered_gamma : bool, default = 'False' + if set to 'True', gamma parameter in RMSNorm is initialized to 0 and + the RMSNorm formula changes to + + .. math:: + y = \frac{x}{RMS(x) + \varepsilon} * (1 + \gamma) + """ + + def __init__( + self, + hidden_size: int, + eps: float = 1e-5, + sequence_parallel: bool = False, + params_dtype: Optional[torch.dtype] = None, + zero_centered_gamma: bool = False, + ) -> None: + super().__init__() + params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype + self.eps = eps + self.zero_centered_gamma = zero_centered_gamma + self.weight = Parameter( + torch.empty( + hidden_size, + device=torch.cuda.current_device(), + dtype=params_dtype, + ) + ) + setattr(self.weight, "sequence_parallel", sequence_parallel) + self.reset_rms_norm_parameters() + + # These many SMs are subtracted from the total SM count when calling forward + # and backward RMSNorm C APIs. These envvars can be used to prevent the LN + # kernels from using all SMs in the device. This is useful for cases such as + # communication overlap with RMSNorm. 
+ self.fwd_rmsnorm_sm_margin = int(os.getenv("NVTE_FWD_LAYERNORM_SM_MARGIN", "0")) + self.bwd_rmsnorm_sm_margin = int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0")) + + def reset_rms_norm_parameters(self) -> None: + """Init RMSNorm params""" + if not self.zero_centered_gamma: + init.ones_(self.weight) + else: + init.zeros_(self.weight) + + + @no_torch_dynamo + def forward(self, inp: torch.Tensor) -> torch.Tensor: + """RMSNorm FWD""" + if torch.is_grad_enabled(): + fwd_fn = _RMSNorm.apply + args = [] + else: + fwd_fn = _RMSNorm.forward + args = [None] + + args += ( + inp, + self.weight, + self.eps, + self.fwd_rmsnorm_sm_margin, + self.bwd_rmsnorm_sm_margin, + self.zero_centered_gamma, + torch.is_grad_enabled() + ) + + return fwd_fn(*args) diff --git a/transformer_engine/pytorch/te_onnx_extensions.py b/transformer_engine/pytorch/te_onnx_extensions.py index 5990160294..7227205099 100755 --- a/transformer_engine/pytorch/te_onnx_extensions.py +++ b/transformer_engine/pytorch/te_onnx_extensions.py @@ -283,6 +283,20 @@ def onnx_te_gemm( return output +def _ones_like(g, inp, dtype): + """Returns a tensor filled with the scalar value 1, with the same size as input and + with dtype data-type""" + shape = g.op("Shape", inp) + # WAR ONNX spec: ConstantOfShape accepts all data types except for BF16. To WAR + # create a ConstantOfShape with type FP32 and then add a Cast to BF16. + is_bf16 = dtype == torch.bfloat16 + one = g.op("ConstantOfShape", shape, value_t=torch.tensor([1], + dtype=torch.float32 if is_bf16 else dtype)) + if is_bf16: + one = g.op("Cast", one, to_i=_C_onnx.TensorProtoDataType.BFLOAT16) + return one + + @symbolic_helper.parse_args("v", "v", "v", "f", "v", "v", "fs", "i", "i", "b") def onnx_layernorm_fwd_fp8(g, inputs, weight, bias, eps, scale, amax, scale_inv, fp8_tensor, otype, zero_centered_gamma): @@ -305,19 +319,6 @@ def onnx_layernorm_fwd(g, inputs, weight, bias, eps, zero_centered_gamma): """ONNX graph for layernorm_fwd""" # pylint: disable=unused-argument - def ones_like(inp, dtype): - """Returns a tensor filled with the scalar value 1, with the same size as input and - with dtype data-type""" - shape = g.op("Shape", inp) - # WAR ONNX spec: ConstantOfShape accepts all data types except for BF16. To WAR - # create a ConstantOfShape with type FP32 and then add a Cast to BF16. 
- is_bf16 = dtype == torch.bfloat16 - one = g.op("ConstantOfShape", shape, value_t=torch.tensor([1], - dtype=torch.float32 if is_bf16 else dtype)) - if is_bf16: - one = g.op("Cast", one, to_i=_C_onnx.TensorProtoDataType.BFLOAT16) - return one - normalized_shape = torch.onnx.symbolic_helper._get_tensor_sizes(inputs) if normalized_shape is None: ndim = torch.onnx.symbolic_helper._get_tensor_rank(inputs) @@ -328,7 +329,7 @@ def ones_like(inp, dtype): if zero_centered_gamma: inputs_dtype = inputs.type().dtype() - one = ones_like(weight, inputs_dtype) + one = _ones_like(g, weight, inputs_dtype) weight = g.op("Add", weight, one) axis = -len(normalized_shape) @@ -344,6 +345,57 @@ def ones_like(inp, dtype): ) return ln +@symbolic_helper.parse_args("v", "v", "f", "v", "v", "fs", "i", "i", "b") +def onnx_rmsnorm_fwd_fp8(g, inputs, weight, eps, scale, amax, + scale_inv, fp8_tensor, otype, zero_centered_gamma): + """ONNX graph for rmsnorm_fwd_fp8""" + # pylint: disable=unused-argument + inp_dtype = get_TensorProtoDataType(inputs) + + if inp_dtype != get_TensorProtoDataType(weight): + weight = g.op("Cast", weight, to_i=inp_dtype) + + ln = onnx_rmsnorm_fwd(g, inputs, weight, eps, zero_centered_gamma) + fp8_ln = quantize(g, ln, scale_inv, fp8_tensor) + return fp8_ln + + +@symbolic_helper.parse_args("v", "v", "f", "b") +def onnx_rmsnorm_fwd(g, inputs, weight, eps, zero_centered_gamma): + """ONNX graph for rmsnorm_fwd""" + # pylint: disable=unused-argument + + normalized_shape = torch.onnx.symbolic_helper._get_tensor_sizes(inputs) + if normalized_shape is None: + ndim = torch.onnx.symbolic_helper._get_tensor_rank(inputs) + assert ndim is not None + normalized_shape = list(range(0, ndim)) + # Normalization axis = 0, so normalized_shape uses all dims except dim = 0 + normalized_shape = normalized_shape[1:] + + if zero_centered_gamma: + inputs_dtype = inputs.type().dtype() + one = _ones_like(g, weight, inputs_dtype) + weight = g.op("Add", weight, one) + + axis = -len(normalized_shape) + + inputs_float = g.op("Cast", inputs, to_i=_C_onnx.TensorProtoDataType.FLOAT) + + norm = g.op("ReduceL2", inputs_float, axes_i=[axis]) + shape = g.op("Shape", inputs_float, start_i=-1) + shape_f = g.op("Cast", shape, to_i=_C_onnx.TensorProtoDataType.FLOAT) + n_reciprocal = g.op("Reciprocal", shape_f) + sqrt_n_reciprocal = g.op("Sqrt", n_reciprocal) + rms = g.op("Mul", norm, sqrt_n_reciprocal) + eps_tensor = g.op("ConstantOfShape", shape, value_t=torch.tensor([eps], dtype=torch.float32)) + rms_eps = g.op("Add", rms, eps_tensor) + normalized_input = g.op("Div", inputs_float, rms_eps) + result = g.op("Mul", weight, normalized_input) + result = g.op("Cast", result, to_i=get_TensorProtoDataType(inputs)) + + + return result register_custom_op_symbolic('tex_ts::cast_to_fp8_ts', onnx_cast_to_fp8, VER) register_custom_op_symbolic('tex_ts::cast_from_fp8_ts', onnx_cast_from_fp8, VER) @@ -355,3 +407,5 @@ def ones_like(inp, dtype): register_custom_op_symbolic('tex_ts::te_gemm_ts', onnx_te_gemm, VER) register_custom_op_symbolic('tex_ts::layernorm_fwd_fp8_inf_ts', onnx_layernorm_fwd_fp8, VER) register_custom_op_symbolic('tex_ts::layernorm_fwd_inf_ts', onnx_layernorm_fwd, VER) +register_custom_op_symbolic('tex_ts::rmsnorm_fwd_fp8_inf_ts', onnx_rmsnorm_fwd_fp8, VER) +register_custom_op_symbolic('tex_ts::rmsnorm_fwd_inf_ts', onnx_rmsnorm_fwd, VER) diff --git a/transformer_engine/pytorch/transformer.py b/transformer_engine/pytorch/transformer.py index 55c547b7ec..7f1b9a7246 100644 --- a/transformer_engine/pytorch/transformer.py +++ 
b/transformer_engine/pytorch/transformer.py @@ -11,7 +11,7 @@ import torch import transformer_engine_extensions as tex -from transformer_engine.pytorch.module import LayerNormMLP, LayerNorm +from transformer_engine.pytorch.module import LayerNormMLP, LayerNorm, RMSNorm from transformer_engine.pytorch.attention import MultiHeadAttention from transformer_engine.pytorch.jit import ( set_jit_fusion_options, @@ -128,6 +128,8 @@ class TransformerLayer(torch.nn.Module): .. math:: y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \varepsilon}} * (1 + \gamma) + \beta + normalization : { 'LayerNorm', 'RMSNorm' }, default = 'LayerNorm' + type of normalization applied. qkv_weight_interleaved : bool, default = `True` if set to `False`, the QKV weight is interpreted as a concatenation of query, key, and value weights along the `0th` dimension. The default @@ -220,7 +222,8 @@ def __init__( qkv_weight_interleaved: bool = True, ub_tp_comm_overlap: bool = False, bias: bool = True, - activation: str = 'gelu' + activation: str = 'gelu', + normalization: str = "LayerNorm", ) -> None: super().__init__() @@ -312,6 +315,7 @@ def __init__( input_layernorm=not output_layernorm, attention_type="self", bias=bias, + normalization=normalization, ) if layer_type == "decoder": @@ -322,6 +326,7 @@ def __init__( input_layernorm=True, attention_type="cross", bias=bias, + normalization=normalization, ) # LayerNorm -> activation(Linear + Bias) -> Linear @@ -353,6 +358,7 @@ def __init__( ub_split_rs=ub_split_rs, ub_split_ag=ub_split_ag, activation=activation, + normalization=normalization, ) self.hidden_dropout = hidden_dropout @@ -376,8 +382,12 @@ def __init__( hidden_size, seq_length, micro_batch_size ) + norm_module = { + "LayerNorm": LayerNorm, + "RMSNorm": RMSNorm, + } if self.output_layernorm: - self.layernorm = LayerNorm( + self.layernorm = norm_module[normalization]( hidden_size, eps=layernorm_epsilon, sequence_parallel=self.sequence_parallel, From 5ed7e82c55a5adb03388c0854a36a449a21cad3b Mon Sep 17 00:00:00 2001 From: cyanguwa <8636796+cyanguwa@users.noreply.github.com> Date: Fri, 28 Jul 2023 17:50:11 -0700 Subject: [PATCH 43/68] Add support for multi-query and grouped-query attention (#338) * add support for multi-query/grouped-query attention Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix lint Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * revert to flash-attn 1.0.6 and build 2.0.0.post1 manually in CI Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * add keyword name for DPA input Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix fused attn tests Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix skipif for pytest Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * Update transformer_engine/pytorch/attention.py Signed-off-by: Kirthi Shankar Sivamani * Update tests/pytorch/test_fused_attn.py Signed-off-by: Kirthi Shankar Sivamani * Fix TP and SP case Signed-off-by: Kirthi Shankar Sivamani * add skipifs for pytest Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * remove higher limit for flash-attn version Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --------- Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Signed-off-by: Kirthi Shankar Sivamani Co-authored-by: Kirthi Shankar Sivamani --- qa/L0_unittest/test.sh | 1 + setup.py | 2 +- tests/pytorch/test_fused_attn.py | 114 
++++++++++++++++++++++ tests/pytorch/test_numerics.py | 2 +- transformer_engine/pytorch/attention.py | 83 +++++++++++++--- transformer_engine/pytorch/transformer.py | 10 ++ 6 files changed, 195 insertions(+), 17 deletions(-) diff --git a/qa/L0_unittest/test.sh b/qa/L0_unittest/test.sh index d061b62453..f02ea1c6e8 100644 --- a/qa/L0_unittest/test.sh +++ b/qa/L0_unittest/test.sh @@ -11,3 +11,4 @@ pytest -v -s $TE_PATH/tests/pytorch/test_sanity.py PYTORCH_JIT=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_numerics.py NVTE_TORCH_COMPILE=0 pytest -v -s $TE_PATH/tests/pytorch/test_onnx_export.py pytest -v -s $TE_PATH/tests/pytorch/test_jit.py +pytest -v -s $TE_PATH/tests/pytorch/test_fused_attn.py diff --git a/setup.py b/setup.py index ded19044fc..e42b6e01d0 100644 --- a/setup.py +++ b/setup.py @@ -290,7 +290,7 @@ def add_unique(l: List[str], vals: Union[str, List[str]]) -> None: # Framework-specific requirements if "pytorch" in frameworks(): - add_unique(install_reqs, ["torch", "flash-attn>=1.0.6, <=2.0.0.post1"]) + add_unique(install_reqs, ["torch", "flash-attn>=1.0.6"]) add_unique(test_reqs, ["numpy", "onnxruntime", "torchvision"]) if "jax" in frameworks(): if not found_pybind11(): diff --git a/tests/pytorch/test_fused_attn.py b/tests/pytorch/test_fused_attn.py index 1aa100672c..99a82eb6e1 100644 --- a/tests/pytorch/test_fused_attn.py +++ b/tests/pytorch/test_fused_attn.py @@ -8,11 +8,19 @@ from transformer_engine.pytorch.utils import ( init_method_normal, scaled_init_method_normal, + get_device_compute_capability, ) +from transformer_engine.pytorch.fp8 import is_fp8_available from transformer_engine.pytorch import TransformerLayer from transformer_engine.pytorch.attention import DotProductAttention import os +from pkg_resources import packaging +from importlib.metadata import version +fp8_available, reason_for_no_fp8 = is_fp8_available() +_flash_attn_version = packaging.version.Version(version("flash-attn")) +_flash_attn_2_available = _flash_attn_version >= packaging.version.Version("2") + class ModelConfig: def __init__( self, num_layers, hidden_size, num_attention_heads, head_dim, seq_len, @@ -45,6 +53,8 @@ def __init__( batch_sizes = [1, 2, 32] +@pytest.mark.skipif( + get_device_compute_capability() < 8.0, reason="Compute capability 8.0+ is required.") @pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("bs", batch_sizes) @pytest.mark.parametrize("model", model_configs.keys()) @@ -113,6 +123,8 @@ def _run_dot_product_attention(dtype, bs, config, backend): return op, inp.grad +@pytest.mark.skipif( + get_device_compute_capability() < 8.0, reason="Compute capability 8.0+ is required.") @pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("bs", batch_sizes) @pytest.mark.parametrize("model", model_configs.keys()) @@ -208,12 +220,114 @@ def _run_transformer_layer(dtype, bs, config, backend): return op, inp.grad +@pytest.mark.skipif(not _flash_attn_2_available, reason="FA2.0 is not available") +@pytest.mark.skipif( + get_device_compute_capability() < 8.0, reason="Compute capability 8.0+ is required.") +@pytest.mark.parametrize("dtype", param_types) +@pytest.mark.parametrize("bs", batch_sizes) +@pytest.mark.parametrize("model", model_configs.keys()) +def test_transformer_layer_gqa(dtype, bs, model): + """Test TransformerLayer module when its DotProductAttention is enabled with + FlashAttention, FusedAttention, or UnfusedDotProductAttention backend""" + + config = model_configs[model] + def find_factors(x): + f = [] + for i in 
range(1, x + 1): + if x % i == 0: + f.append(i) + return f + + num_querys_per_gqa_group = find_factors(config.num_attention_heads) + + for num_q_per_gqa_group in num_querys_per_gqa_group: + flash_attn_fwd, flash_attn_bwd = _run_transformer_layer_gqa( + dtype, bs, config, "FlashAttention", num_q_per_gqa_group) + unfused_attn_fwd, unfused_attn_bwd = _run_transformer_layer_gqa( + dtype, bs, config, "UnfusedDotProductAttention", num_q_per_gqa_group) + + atol, rtol = 5e-1, 5e-1 + assert torch.allclose(flash_attn_fwd, unfused_attn_fwd, atol = atol, rtol = rtol) + assert torch.allclose(flash_attn_bwd, unfused_attn_bwd, atol = atol, rtol = rtol) + +def _run_transformer_layer_gqa(dtype, bs, config, backend, num_querys_per_gqa_group): + + torch.manual_seed(1234) + torch.cuda.manual_seed(1234) + os.environ["NVTE_FLASH_ATTN"] = "0" + if backend == "FlashAttention": + os.environ["NVTE_FLASH_ATTN"] = "1" + + inp = 0.1 * torch.randn( + config.seq_len, bs, config.num_attention_heads * config.head_dim, + dtype = dtype).cuda() + inp.requires_grad=True + seqlens = torch.empty(bs, dtype = torch.int32).cuda() + seqlens.fill_(config.seq_len) + cu_seqlens = torch.zeros(bs + 1, device = inp.device, dtype = torch.int32) + cu_seqlens[1:] = torch.cumsum(seqlens, dim = 0) + op_grad = 0.001 * torch.randint(0, 200, ( + config.seq_len, bs, config.num_attention_heads * config.head_dim + ), dtype = dtype).cuda() + + sigma = 0.02 + init_method = init_method_normal(sigma) + output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers) + + layer_number = 1 + drop_path_rate = 0.0 + drop_path_rates = [ + rate.item() for rate in torch.linspace(0, drop_path_rate, config.num_layers)] + + block = ( + TransformerLayer( + config.hidden_size, + 4 * config.hidden_size, + config.num_attention_heads, + num_gqa_groups = config.num_attention_heads / num_querys_per_gqa_group, + layernorm_epsilon = 1e-5, + hidden_dropout = 0.0, + attention_dropout = config.dropout_p, + init_method = init_method, + output_layer_init_method = output_layer_init_method, + layer_number = layer_number, + kv_channels = config.head_dim, + self_attn_mask_type = config.attn_mask_type, + tp_group = None, + tp_size = 1, + params_dtype = dtype, + get_rng_state_tracker = None, + fuse_wgrad_accumulation = False, + seq_length = config.seq_len, + micro_batch_size = bs, + sequence_parallel = False, + apply_residual_connection_post_layernorm = False, + output_layernorm = False, + layer_type = "encoder", + drop_path_rate = drop_path_rates[layer_number - 1], + set_parallel_mode = True, + fuse_qkv_params = True, + zero_centered_gamma = False, + qkv_weight_interleaved = False, + ub_tp_comm_overlap = False, + bias = True, + ) + .to(dtype = dtype) + .cuda() + ) + + op = block(inp) + op.backward(op_grad) + + return op, inp.grad + model_configs_fp8 = { "test1": ModelConfig(1, 1024, 16, 64, 512, 0.0, "no_mask"), } batch_sizes_fp8 = [1, 4] param_types_fp8 = [torch.float16] +@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) @pytest.mark.parametrize("dtype", param_types_fp8) @pytest.mark.parametrize("bs", batch_sizes_fp8) @pytest.mark.parametrize("model", model_configs_fp8.keys()) diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index 2ed901cb20..143fc9a74d 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -805,7 +805,7 @@ def test_dpa_accuracy(dtype, bs, model): DotProductAttention( config.num_attention_heads, config.embed, - 0.1, # dropout + attention_dropout=0.1, # dropout ) .to(dtype=dtype) 
.cuda() diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index dd3f561c95..8966f261ed 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -180,6 +180,15 @@ def forward( key_layer.size(0), ) + assert key_layer.shape == value_layer.shape, "Keys and values must have the same shape!" + if key_layer.shape[2] != query_layer.shape[2]: + assert (query_layer.shape[2]%key_layer.shape[2]==0 + ),"The number of attention heads must be divisible by the number of GQA groups!" + key_layer = key_layer.repeat_interleave( + int(query_layer.shape[2]/key_layer.shape[2]), dim = 2) + value_layer = value_layer.repeat_interleave( + int(query_layer.shape[2]/value_layer.shape[2]), dim = 2) + # [sq, b, np, hn] -> [sq, b * np, hn] query_layer = query_layer.reshape( output_size[2], output_size[0] * output_size[1], -1 @@ -722,6 +731,14 @@ class DotProductAttention(torch.nn.Module): number of attention heads in the transformer layer. kv_channels : int number of key-value channels. + num_gqa_groups : Optional[int] = None + number of GQA groups in the transformer layer. + Grouped Query Attention is described in + `this paper `_. + This only affects the keys and values, not the queries. + GQA-1 is equivalent to Multi-Query Attention + (`MQA `_), while GQA-H + is equivalent to MHA, i.e. `num_gqa_groups = num_attention_heads`. attention_dropout: float, default = 0.0 dropout probability for the dropout op during multi-head attention. attn_mask_type: {'causal', 'padding'}, default = `causal` @@ -744,6 +761,7 @@ def __init__( self, num_attention_heads: int, kv_channels: int, + num_gqa_groups: Optional[int] = None, attention_dropout: float = 0.0, attn_mask_type: str = "causal", sequence_parallel: bool = False, @@ -758,12 +776,16 @@ def __init__( self.tp_size = tp_size if tp_group is None else get_distributed_world_size(tp_group) self.tp_group = tp_group self.get_rng_state_tracker = get_rng_state_tracker + self.num_attention_heads = num_attention_heads - projection_size = kv_channels * num_attention_heads - self.hidden_size_per_partition = divide(projection_size, self.tp_size) - self.hidden_size_per_attention_head = divide( - projection_size, num_attention_heads + self.hidden_size_per_attention_head = kv_channels + self.num_gqa_groups = ( + num_attention_heads if num_gqa_groups is None else num_gqa_groups ) + self.num_gqa_groups_per_partition = int(self.num_gqa_groups // tp_size) + + assert (num_attention_heads % self.num_gqa_groups == 0 + ), "The number of attention heads must be divisible by the number of GQA groups!" if sequence_parallel or get_rng_state_tracker is None: attention_dropout_ctx = nullcontext @@ -883,6 +905,10 @@ def forward( Whether to use the fast path to set output tensors to 0 or not. """ + assert (key_layer.shape[-2] == self.num_gqa_groups_per_partition + and value_layer.shape[-2] == self.num_gqa_groups_per_partition + ), f"Keys and values must have {self.num_gqa_groups} heads!" 
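+        # Editorial note (added for clarity, not part of the original change):
+        # with, e.g., num_attention_heads=16 and num_gqa_groups=4, query_layer
+        # carries 16 heads per partition while key_layer and value_layer carry
+        # only 4; the unfused backend above repeat_interleaves keys and values
+        # by a factor of 4 so that every query head attends to its group's
+        # shared key/value head.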
+
         use_flash_attention = self.use_flash_attention
         use_fused_attention = self.use_fused_attention
@@ -898,6 +924,9 @@ def forward(
         elif not _flash_attn_2_available and self.device_compute_capability == 8.9:
             use_flash_attention = False

+        if not _flash_attn_2_available and self.num_gqa_groups != self.num_attention_heads:
+            use_flash_attention = False
+
         if self.attn_mask_type == "padding" and attention_mask is not None:
             use_flash_attention = False
             use_fused_attention = False
@@ -919,7 +948,9 @@ def forward(
             # DPA does not support FP8; for FP8, use cpp_extensions modules directly
             is_backend_avail = (fused_attention_backend in
                 [FusedAttnBackend["F16_max512_seqlen"], FusedAttnBackend["F16_arbitrary_seqlen"]])
-            use_fused_attention = use_fused_attention and is_backend_avail
+            use_fused_attention = (use_fused_attention
+                                   and is_backend_avail
+                                   and self.num_gqa_groups == self.num_attention_heads)

         if use_flash_attention:
             if checkpoint_core_attention:
@@ -974,6 +1005,7 @@ def __init__(
         attn_mask_type: str = "causal",
         tp_group: Optional[dist_group_type] = None,
         tp_size: int = 1,
+        num_gqa_groups: Optional[int] = None,
         fuse_wgrad_accumulation: bool = False,
         get_rng_state_tracker: Optional[Callable] = None,
         sequence_parallel: bool = False,
@@ -1002,6 +1034,7 @@ def __init__(
         self.params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype
         self.init_method = init_method
         self.attn_mask_type = attn_mask_type
+        self.num_attention_heads = num_attention_heads

         if not fuse_qkv_params:
             qkv_weight_interleaved = False
@@ -1017,6 +1050,15 @@ def __init__(
         self.hidden_size_per_attention_head = kv_channels
         self.num_attention_heads_per_partition = divide(num_attention_heads, tp_size)

+        self.num_gqa_groups = (
+            num_attention_heads if num_gqa_groups is None else num_gqa_groups
+        )
+        assert (num_attention_heads % self.num_gqa_groups == 0
+                ), "The number of attention heads must be divisible by the number of GQA groups!"
+        assert (self.num_gqa_groups % tp_size == 0
+                ), "The number of GQA groups must be divisible by tensor parallel size!"
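+        # Editorial worked example (not part of the original patch): with
+        # hidden_size=1024, num_attention_heads=16, kv_channels=64 and
+        # num_gqa_groups=4, the two lines below give
+        # num_gqa_groups_per_partition = 4 // tp_size and
+        # hidden_size_kv = 1024 * 4 // 16 = 256, so the fused key_value
+        # projection outputs 2 * 256 = 512 features instead of 2 * 1024.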
+ self.num_gqa_groups_per_partition = int(self.num_gqa_groups // tp_size) + self.hidden_size_kv = int(hidden_size * self.num_gqa_groups // num_attention_heads) common_gemm_kwargs = { "fuse_wgrad_accumulation": fuse_wgrad_accumulation, @@ -1029,7 +1071,7 @@ def __init__( qkv_parallel_mode = "column" if set_parallel_mode else None - if self.attention_type == "self": + if self.attention_type == "self" and self.num_gqa_groups == self.num_attention_heads: if self.input_layernorm: self.layernorm_qkv = LayerNormLinear( hidden_size, @@ -1059,7 +1101,9 @@ def __init__( parameters_split=("query_", "key_", "value_") if not fuse_qkv_params else None, **common_gemm_kwargs, ) - else: + elif ((self.attention_type == "cross") + or (self.attention_type == "self" + and self.num_gqa_groups != self.num_attention_heads)): if self.input_layernorm: self.layernorm_query = LayerNormLinear( hidden_size, @@ -1089,7 +1133,7 @@ def __init__( ) self.key_value = Linear( hidden_size, - 2 * hidden_size, + 2 * self.hidden_size_kv, init_method=init_method, bias=bias, return_bias=False, @@ -1102,7 +1146,8 @@ def __init__( self.core_attention = DotProductAttention( num_attention_heads, kv_channels, - attention_dropout, + num_gqa_groups=self.num_gqa_groups, + attention_dropout=attention_dropout, tp_size=tp_size, get_rng_state_tracker=get_rng_state_tracker, attn_mask_type=attn_mask_type, @@ -1131,7 +1176,7 @@ def _allocate_memory( return torch.empty( inference_max_sequence_len, batch_size, - self.num_attention_heads_per_partition, + self.num_gqa_groups_per_partition, self.hidden_size_per_attention_head, dtype=dtype, device=torch.cuda.current_device(), @@ -1192,7 +1237,7 @@ def forward( # Query, Key, and Value # ===================== - if self.attention_type == "self": + if self.attention_type == "self" and self.num_gqa_groups == self.num_attention_heads: # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)] if self.input_layernorm: layernorm_qkv_outputs = self.layernorm_qkv( @@ -1235,17 +1280,25 @@ def forward( query_layer, key_layer, value_layer = split_tensor_along_dim( mixed_x_layer, split_dim, 3 ) - else: + elif ((self.attention_type == "cross") + or (self.attention_type == "self" + and self.num_gqa_groups != self.num_attention_heads)): + + if self.attention_type == "cross": + input_tensor = encoder_output + else: + input_tensor = hidden_states + # Attention heads [sk, b, h] --> [sk, b, (np * 2 * hn)] mixed_kv_layer = self.key_value( - encoder_output, + input_tensor, is_first_microbatch=is_first_microbatch, ) if self.qkv_weight_interleaved: # [sq, b, (np * 2 * hn)] --> [sq, b, np, 2 * hn] new_tensor_shape = mixed_kv_layer.size()[:-1] + ( - self.num_attention_heads_per_partition, + self.num_gqa_groups_per_partition, 2 * self.hidden_size_per_attention_head, ) # split along last dimension @@ -1253,7 +1306,7 @@ def forward( else: # [sq, b, (np * 2 * hn)] --> [sq, b, 2 * np, hn] new_tensor_shape = mixed_kv_layer.size()[:-1] + ( - 2 * self.num_attention_heads_per_partition, + 2 * self.num_gqa_groups_per_partition, self.hidden_size_per_attention_head, ) # split along second last dimension diff --git a/transformer_engine/pytorch/transformer.py b/transformer_engine/pytorch/transformer.py index 7f1b9a7246..572b905dd8 100644 --- a/transformer_engine/pytorch/transformer.py +++ b/transformer_engine/pytorch/transformer.py @@ -86,6 +86,14 @@ class TransformerLayer(torch.nn.Module): intermediate size to which input samples are projected. num_attention_heads : int number of attention heads in the transformer layer. 
+ num_gqa_groups : int, default = `None` + number of GQA groups in the transformer layer. + Grouped Query Attention is described in + `this paper `_. + This only affects the keys and values, not the querys. + GQA-1 is equivalent to Multi-Query Attention + (`MQA `_), while GQA-H + is equivalent to MHA, i.e. `num_gqa_groups = num_attention_heads`. layernorm_epsilon : float, default = 1e-5 a value added to the denominator of layer normalization for numerical stability. @@ -194,6 +202,7 @@ def __init__( hidden_size: int, ffn_hidden_size: int, num_attention_heads: int, + num_gqa_groups: Optional[int] = None, layernorm_epsilon: float = 1e-5, hidden_dropout: float = 0.1, attention_dropout: float = 0.1, @@ -293,6 +302,7 @@ def __init__( "layer_number": layer_number, "tp_group": tp_group, "tp_size": self.tp_size, + "num_gqa_groups": num_gqa_groups, "fuse_wgrad_accumulation": fuse_wgrad_accumulation, "get_rng_state_tracker": get_rng_state_tracker, "sequence_parallel": self.sequence_parallel, From 9347b10ad9bb1faa289d92920fc0d889efeec177 Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Mon, 31 Jul 2023 15:12:41 -0700 Subject: [PATCH 44/68] Add compilation OOM note for FA 2.0 (#346) Add compilation warning for FA 2.0 Signed-off-by: Kirthi Shankar Sivamani --- README.rst | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/README.rst b/README.rst index d892eae244..5920e36e5c 100644 --- a/README.rst +++ b/README.rst @@ -191,6 +191,14 @@ From source `See the installation guide `_. +Compiling with Flash Attention 2 +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +TransformerEngine release v0.11.0 adds support for Flash Attention 2.0 for improved performance. It is a known issue that Flash Attention 2.0 compilation is +resource intensive and requires a large amount of RAM (see `bug `_), which may lead to out of memory +errors during the installation of TransformerEngine. To circumvent the issue, please try setting **MAX_JOBS=1** in the environment. If the errors persist, then +proceed to install a supported version of Flash Attention 1 (v1.0.6 to v1.0.9). 
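+For example (editorial sketch, not part of the original note): running the install as
+``MAX_JOBS=1 pip install flash-attn==2.0.0.post1`` limits the build to a single parallel
+compilation job and keeps peak host memory low.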
+ Model Support ---------- From 3f01b4f812e0e501257278ec269499ea02b2d4f3 Mon Sep 17 00:00:00 2001 From: Tim Moon <4406448+timmoon10@users.noreply.github.com> Date: Wed, 19 Jul 2023 21:40:44 -0700 Subject: [PATCH 45/68] Replace deprecated sharding API in JAX test (#332) Replace deprecated sharding API Signed-off-by: Tim Moon Co-authored-by: Kirthi Shankar Sivamani --- tests/jax/test_sharding.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/tests/jax/test_sharding.py b/tests/jax/test_sharding.py index 217af3f816..ea216ac514 100644 --- a/tests/jax/test_sharding.py +++ b/tests/jax/test_sharding.py @@ -5,7 +5,6 @@ import jax import numpy as np import pytest -from jax.experimental import maps from utils import is_devices_enough from transformer_engine.jax.flax import extend_logical_axis_rules @@ -79,7 +78,7 @@ def test_infer_major_sharding_type( sharding_type): devices = np.asarray(jax.devices()[:DEVICE_COUNT]).reshape(*mesh_shape) with global_shard_guard(_get_sharding_resource(mesh_names, sharding_type)): - with maps.Mesh(devices, mesh_names): + with jax.sharding.Mesh(devices, mesh_names): assert infer_major_sharding_type() is sharding_type.value[0] @pytest.mark.parametrize('mesh_shape,mesh_names,sharding_type', MESH_CONFIG) @@ -150,7 +149,7 @@ def get_ref_sm(): devices = np.asarray(jax.devices()[:DEVICE_COUNT]).reshape(*mesh_shape) with global_shard_guard(_get_sharding_resource(mesh_names, sharding_type)): - with maps.Mesh(devices, mesh_names): + with jax.sharding.Mesh(devices, mesh_names): test_sm = get_fp8_meta_sharding_meta( sharding_type, num_of_fp8_meta, @@ -240,7 +239,7 @@ def get_ref_sm(): devices = np.asarray(jax.devices()[:DEVICE_COUNT]).reshape(*mesh_shape) with global_shard_guard(_get_sharding_resource(mesh_names, sharding_type)): - with maps.Mesh(devices, mesh_names): + with jax.sharding.Mesh(devices, mesh_names): test_sm = get_dot_sharding_meta( sharding_type, a_shape, @@ -319,7 +318,7 @@ def get_ref_sm(): devices = np.asarray(jax.devices()[:DEVICE_COUNT]).reshape(*mesh_shape) with global_shard_guard(_get_sharding_resource(mesh_names, sharding_type)): - with maps.Mesh(devices, mesh_names): + with jax.sharding.Mesh(devices, mesh_names): ref_sm, need_assert = get_ref_sm() try: test_sm = get_elementwise_sharding_meta( From 9799608b50c30989cdc75468dd76b4bebed8738e Mon Sep 17 00:00:00 2001 From: Shijie Date: Fri, 18 Aug 2023 07:07:10 +0800 Subject: [PATCH 46/68] [Paddle] Add nn layer (#361) * Add nn.layer: softmax, attention, transformer Signed-off-by: Shijie Wang * code refactor Signed-off-by: Shijie Wang * code refactor Signed-off-by: Shijie Wang * update docs and set dropout=0.1 Signed-off-by: Shijie Wang * Update transformer_engine/paddle/layer/attention.py Signed-off-by: Kirthi Shankar Sivamani --------- Signed-off-by: Shijie Wang Signed-off-by: Kirthi Shankar Sivamani Co-authored-by: Kirthi Shankar Sivamani --- tests/paddle/test_layers.py | 490 +++++++++++++++ tests/paddle/test_operators.py | 8 +- transformer_engine/paddle/__init__.py | 3 +- transformer_engine/paddle/constants.py | 6 + transformer_engine/paddle/cpp_extensions.py | 8 +- transformer_engine/paddle/layer/__init__.py | 3 + transformer_engine/paddle/layer/attention.py | 568 ++++++++++++++++++ transformer_engine/paddle/layer/layernorm.py | 2 +- .../paddle/layer/layernorm_linear.py | 3 +- .../paddle/layer/layernorm_mlp.py | 2 +- transformer_engine/paddle/layer/softmax.py | 237 ++++++++ .../paddle/layer/transformer.py | 260 ++++++++ transformer_engine/paddle/utils.py | 34 ++ 13 files changed, 
1610 insertions(+), 14 deletions(-) create mode 100644 transformer_engine/paddle/layer/attention.py create mode 100644 transformer_engine/paddle/layer/softmax.py create mode 100644 transformer_engine/paddle/layer/transformer.py diff --git a/tests/paddle/test_layers.py b/tests/paddle/test_layers.py index 3bd3a562db..171b9233e7 100644 --- a/tests/paddle/test_layers.py +++ b/tests/paddle/test_layers.py @@ -3,6 +3,7 @@ # See LICENSE for license information. """Test TE Paddle Layer-level APIs""" +import math import os import pytest from utils import assert_allclose @@ -605,3 +606,492 @@ def test_layernorm_mlp_fp8(bs, hidden_size, ffn_hidden_size, has_bias, no_dbias, if do_calibration: assert paddle.count_nonzero(layer_te.fp8_meta["scaling_fwd"].amax_history).item() > 0 + + +@pytest.mark.skipif(paddle.device.cuda.get_device_capability() < (8, 0), + reason="cuDNN fMHA requires Ampere+ GPU") +@pytest.mark.parametrize('bs', [1, 2, 8]) +@pytest.mark.parametrize('hidden_size, num_heads', [[1024, 16], [768, 12]]) +@pytest.mark.parametrize('q_seqlen, kv_seqlen', [[128, 128], [512, 512]]) +@pytest.mark.parametrize('attn_type', ['self', 'cross']) +@pytest.mark.parametrize('mask_type', ['causal', 'padding']) +@pytest.mark.parametrize('math_dtype', ['bfloat16', 'float16']) +def test_dot_product_attention(bs, hidden_size, num_heads, q_seqlen, kv_seqlen, attn_type, + mask_type, math_dtype): + """ + Test DotProductAttention Layer + """ + paddle.set_default_dtype(math_dtype) + rtol = 1e-4 + atol = 2e-2 + + head_size = hidden_size // num_heads + self_attn_qkv_input = paddle.normal(mean=0.0, + std=0.02, + shape=(bs, q_seqlen, 3, num_heads, + head_size)).astype(math_dtype) + cross_attn_q_input = paddle.normal(mean=0.0, + std=0.02, + shape=(bs, q_seqlen, num_heads, + head_size)).astype(math_dtype) + cross_attn_kv_input = paddle.normal(mean=0.0, + std=0.02, + shape=(bs, kv_seqlen, 2, num_heads, + head_size)).astype(math_dtype) + + q_actual_seqlen = paddle.randint(low=20, high=q_seqlen, shape=(bs,), dtype='int32') + kv_actual_seqlen = paddle.randint(low=20, high=kv_seqlen, shape=(bs,), + dtype='int32') if attn_type == 'cross' else q_actual_seqlen + attn_mask = paddle.ones(shape=(bs, 1, q_seqlen, kv_seqlen), dtype='bool') + + grad_out = paddle.normal(mean=0.0, std=0.02, + shape=(bs, q_seqlen, num_heads, head_size)).astype('float32') + for i in range(0, bs): + grad_out[i, q_actual_seqlen[i]:, :, :] = 0 + grad_out = grad_out.astype(math_dtype) + + for i in range(0, bs): + attn_mask[i, 0, 0:q_actual_seqlen[i], 0:kv_actual_seqlen[i]] = False + + norm_factor = math.sqrt(hidden_size // num_heads) + layer_te = te.DotProductAttention(norm_factor, + attention_dropout=0.0, + attn_mask_type=mask_type, + attention_type=attn_type, + backend='transformer_engine') + layer_pd = te.DotProductAttention(norm_factor, + attention_dropout=0.0, + attn_mask_type=mask_type, + attention_type=attn_type, + backend='paddle') + + def calc_attn_output_and_grad(layer, q, kv, mask, dout): + _q = paddle.to_tensor(q, stop_gradient=False) + _kv = paddle.to_tensor(kv, stop_gradient=False) if kv is not None else None + + out = layer(_q, _kv, mask) + out.backward(dout) + return out, _q.grad, _kv.grad if _kv is not None else None + + if attn_type == 'self': + out, qkv_grad, _ = calc_attn_output_and_grad(layer_te, self_attn_qkv_input, None, attn_mask, + grad_out) + out_ref, qkv_grad_ref, _ = calc_attn_output_and_grad(layer_pd, self_attn_qkv_input, None, + attn_mask, grad_out) + valid_out_ref = paddle.full_like(out_ref, 0) + for i in range(0, bs): + 
valid_out_ref[i, 0:q_actual_seqlen[i], :, :] = out_ref[i, 0:q_actual_seqlen[i], :, :] + + q_grad = qkv_grad[:, :, 0] + k_grad = qkv_grad[:, :, 1] + v_grad = qkv_grad[:, :, 2] + q_grad_ref = qkv_grad_ref[:, :, 0] + k_grad_ref = qkv_grad_ref[:, :, 1] + v_grad_ref = qkv_grad_ref[:, :, 2] + + else: + out, q_grad, kv_grad = calc_attn_output_and_grad(layer_te, cross_attn_q_input, + cross_attn_kv_input, attn_mask, grad_out) + out_ref, q_grad_ref, kv_grad_ref = calc_attn_output_and_grad(layer_pd, cross_attn_q_input, + cross_attn_kv_input, attn_mask, + grad_out) + + valid_out_ref = paddle.full_like(out_ref, 0) + for i in range(0, bs): + valid_out_ref[i, 0:q_actual_seqlen[i], :, :] = out_ref[i, 0:q_actual_seqlen[i], :, :] + + k_grad = kv_grad[:, :, 0] + v_grad = kv_grad[:, :, 1] + k_grad_ref = kv_grad_ref[:, :, 0] + v_grad_ref = kv_grad_ref[:, :, 1] + + valid_q_grad_ref = paddle.full_like(q_grad_ref, 0) + valid_k_grad_ref = paddle.full_like(k_grad_ref, 0) + valid_v_grad_ref = paddle.full_like(v_grad_ref, 0) + for i in range(0, bs): + valid_q_grad_ref[i, 0:q_actual_seqlen[i], :, :] = q_grad_ref[i, 0:q_actual_seqlen[i], :, :] + valid_k_grad_ref[i, 0:kv_actual_seqlen[i], :, :] = k_grad_ref[i, + 0:kv_actual_seqlen[i], :, :] + valid_v_grad_ref[i, 0:kv_actual_seqlen[i], :, :] = v_grad_ref[i, + 0:kv_actual_seqlen[i], :, :] + + assert_allclose(out, valid_out_ref, rtol=rtol, atol=atol) + assert_allclose(q_grad, valid_q_grad_ref, rtol=rtol, atol=atol) + assert_allclose(k_grad, valid_k_grad_ref, rtol=rtol, atol=atol) + assert_allclose(v_grad, valid_v_grad_ref, rtol=rtol, atol=atol) + + +@pytest.mark.skipif(paddle.device.cuda.get_device_capability() < (8, 0), + reason="cuDNN fMHA requires Ampere+ GPU") +@pytest.mark.parametrize('bs', [1, 2, 8]) +@pytest.mark.parametrize('hidden_size, num_heads, ffn_hidden_size', [[1024, 16, 4096]]) +@pytest.mark.parametrize('q_seqlen, kv_seqlen', [[128, 128], [512, 512]]) +@pytest.mark.parametrize('has_bias, no_dbias', [[False, True], [True, True], [True, False]]) +@pytest.mark.parametrize('no_wgrad', [True, False]) +@pytest.mark.parametrize('mask_type', ['causal', 'padding']) +@pytest.mark.parametrize('math_dtype', ['bfloat16', 'float16']) +@pytest.mark.parametrize('output_layernorm', [True, False]) +@pytest.mark.parametrize('return_layernorm_output', [True, False]) +def test_transformer_encoder_layer(bs, hidden_size, num_heads, ffn_hidden_size, has_bias, no_dbias, + no_wgrad, q_seqlen, kv_seqlen, mask_type, math_dtype, + output_layernorm, return_layernorm_output): + """ + Test Transformer Encoder Layer + """ + paddle.set_default_dtype(math_dtype) + rtol = 5e-2 + atol = 5e-2 + eps = 1e-3 + + encoder_input = paddle.uniform(shape=(bs, q_seqlen, hidden_size), dtype=math_dtype) + + q_actual_seqlen = paddle.ones(shape=(bs,), dtype='int32') * q_seqlen + kv_actual_seqlen = q_actual_seqlen + attn_mask = paddle.ones(shape=(bs, 1, q_seqlen, kv_seqlen), dtype='bool') + + grad_out = paddle.normal(mean=0.0, std=0.02, + shape=(bs, q_seqlen, hidden_size)).astype('float32') + for i in range(0, bs): + grad_out[i, q_actual_seqlen[i]:, :] = 0 + grad_out = grad_out.astype(math_dtype) + + for i in range(0, bs): + attn_mask[i, 0, 0:q_actual_seqlen[i], 0:kv_actual_seqlen[i]] = False + + layer_te = te.TransformerLayer(hidden_size, + ffn_hidden_size, + num_heads, + layernorm_epsilon=eps, + hidden_dropout=0.0, + attention_dropout=0.0, + weight_attr=None, + bias_attr=None if has_bias else False, + self_attn_mask_type=mask_type, + apply_residual_connection_post_layernorm=return_layernorm_output, + 
output_layernorm=output_layernorm, + layer_type='encoder', + backend='transformer_engine') + layer_pd = te.TransformerLayer(hidden_size, + ffn_hidden_size, + num_heads, + layernorm_epsilon=eps, + hidden_dropout=0.0, + attention_dropout=0.0, + weight_attr=None, + bias_attr=None if has_bias else False, + self_attn_mask_type=mask_type, + apply_residual_connection_post_layernorm=return_layernorm_output, + output_layernorm=output_layernorm, + layer_type='encoder', + backend='paddle') + + # MultiHeadAttention params + if output_layernorm: + layer_pd.self_attention.qkv.weight.copy_(layer_te.self_attention.qkv.weight.T, True) + layer_pd.self_attention.qkv.weight.stop_gradient = no_wgrad + layer_te.self_attention.qkv.weight.stop_gradient = no_wgrad + if has_bias: + layer_pd.self_attention.qkv.bias.copy_(layer_te.self_attention.qkv.bias, True) + layer_pd.self_attention.qkv.bias.stop_gradient = no_dbias + layer_te.self_attention.qkv.bias.stop_gradient = no_dbias + else: + layer_pd.self_attention.layernorm_qkv.ln_weight.copy_( + layer_te.self_attention.layernorm_qkv.ln_weight, True) + layer_pd.self_attention.layernorm_qkv.ln_bias.copy_( + layer_te.self_attention.layernorm_qkv.ln_bias, True) + layer_pd.self_attention.layernorm_qkv.weight.copy_( + layer_te.self_attention.layernorm_qkv.weight.T, True) + layer_pd.self_attention.layernorm_qkv.ln_weight.stop_gradient = no_wgrad + layer_pd.self_attention.layernorm_qkv.ln_bias.stop_gradient = no_dbias + layer_pd.self_attention.layernorm_qkv.weight.stop_gradient = no_wgrad + layer_te.self_attention.layernorm_qkv.ln_weight.stop_gradient = no_wgrad + layer_te.self_attention.layernorm_qkv.ln_bias.stop_gradient = no_dbias + layer_te.self_attention.layernorm_qkv.weight.stop_gradient = no_wgrad + if has_bias: + layer_pd.self_attention.layernorm_qkv.bias.copy_( + layer_te.self_attention.layernorm_qkv.bias, True) + layer_pd.self_attention.layernorm_qkv.bias.stop_gradient = no_dbias + layer_te.self_attention.layernorm_qkv.bias.stop_gradient = no_dbias + + layer_pd.self_attention.proj.weight.copy_(layer_te.self_attention.proj.weight.T, True) + layer_pd.self_attention.proj.weight.stop_gradient = no_wgrad + layer_te.self_attention.proj.weight.stop_gradient = no_wgrad + if has_bias: + layer_pd.self_attention.proj.bias.copy_(layer_te.self_attention.proj.bias, True) + layer_pd.self_attention.proj.bias.stop_gradient = no_dbias + layer_te.self_attention.proj.bias.stop_gradient = no_dbias + + # LayerNorm MLP params + layer_pd.layernorm_mlp.ln_weight.copy_(layer_te.layernorm_mlp.ln_weight, True) + layer_pd.layernorm_mlp.ln_bias.copy_(layer_te.layernorm_mlp.ln_bias, True) + layer_pd.layernorm_mlp.fc1_weight.copy_(layer_te.layernorm_mlp.fc1_weight.T, True) + layer_pd.layernorm_mlp.fc2_weight.copy_(layer_te.layernorm_mlp.fc2_weight.T, True) + layer_pd.layernorm_mlp.ln_weight.stop_gradient = no_wgrad + layer_pd.layernorm_mlp.ln_bias.stop_gradient = no_dbias + layer_pd.layernorm_mlp.fc1_weight.stop_gradient = no_wgrad + layer_pd.layernorm_mlp.fc2_weight.stop_gradient = no_wgrad + layer_te.layernorm_mlp.ln_weight.stop_gradient = no_wgrad + layer_te.layernorm_mlp.ln_bias.stop_gradient = no_dbias + layer_te.layernorm_mlp.fc1_weight.stop_gradient = no_wgrad + layer_te.layernorm_mlp.fc2_weight.stop_gradient = no_wgrad + if has_bias: + layer_pd.layernorm_mlp.fc1_bias.copy_(layer_te.layernorm_mlp.fc1_bias, True) + layer_pd.layernorm_mlp.fc2_bias.copy_(layer_te.layernorm_mlp.fc2_bias, True) + layer_pd.layernorm_mlp.fc1_bias.stop_gradient = no_dbias + 
layer_pd.layernorm_mlp.fc2_bias.stop_gradient = no_dbias + layer_te.layernorm_mlp.fc1_bias.stop_gradient = no_dbias + layer_te.layernorm_mlp.fc2_bias.stop_gradient = no_dbias + + if output_layernorm: + layer_pd.layernorm.weight.copy_(layer_te.layernorm.weight, True) + layer_pd.layernorm.bias.copy_(layer_te.layernorm.bias, True) + layer_pd.layernorm.weight.stop_gradient = no_wgrad + layer_pd.layernorm.bias.stop_gradient = no_dbias + layer_te.layernorm.weight.stop_gradient = no_wgrad + layer_te.layernorm.bias.stop_gradient = no_dbias + + def calc_transformer_output_and_grad(layer, encoder_input, mask, dout): + _encoder_input = paddle.to_tensor(encoder_input, stop_gradient=False) + out = layer(_encoder_input, mask) + out.backward(dout) + return out, _encoder_input.grad + + out_ref, grad_input_ref = calc_transformer_output_and_grad(layer_pd, encoder_input, attn_mask, + grad_out) + out, grad_input = calc_transformer_output_and_grad(layer_te, encoder_input, attn_mask, grad_out) + + assert_allclose(out, out_ref, rtol=rtol, atol=atol) + assert_allclose(grad_input, grad_input_ref, rtol=rtol, atol=atol) + if not no_wgrad: + if output_layernorm: + assert_allclose(layer_te.self_attention.qkv.weight.grad, + layer_pd.self_attention.qkv.weight.grad.T, + rtol=rtol, + atol=atol) + else: + assert_allclose(layer_te.self_attention.layernorm_qkv.weight.grad, + layer_pd.self_attention.layernorm_qkv.weight.grad.T, + rtol=rtol, + atol=atol) + if not no_dbias: + if output_layernorm: + assert_allclose(layer_te.self_attention.qkv.bias.grad, + layer_pd.self_attention.qkv.bias.grad, + rtol=0.01, + atol=0.5) + else: + assert_allclose(layer_te.self_attention.layernorm_qkv.bias.grad, + layer_pd.self_attention.layernorm_qkv.bias.grad, + rtol=0.01, + atol=0.5) + + +@pytest.mark.skipif(paddle.device.cuda.get_device_capability() < (8, 0), + reason="cuDNN fMHA requires Ampere+ GPU") +@pytest.mark.parametrize('bs', [1, 2, 8]) +@pytest.mark.parametrize('hidden_size, num_heads, ffn_hidden_size', [[1024, 16, 4096]]) +@pytest.mark.parametrize('q_seqlen, kv_seqlen', [[128, 128], [512, 512]]) +@pytest.mark.parametrize('has_bias, no_dbias', [[False, True], [True, True], [True, False]]) +@pytest.mark.parametrize('no_wgrad', [True, False]) +@pytest.mark.parametrize('mask_type', ['causal', 'padding']) +@pytest.mark.parametrize('math_dtype', ['bfloat16', 'float16']) +@pytest.mark.parametrize('output_layernorm', [True, False]) +@pytest.mark.parametrize('return_layernorm_output', [True, False]) +def test_transformer_decoder_layer(bs, hidden_size, num_heads, ffn_hidden_size, has_bias, no_dbias, + no_wgrad, q_seqlen, kv_seqlen, mask_type, math_dtype, + output_layernorm, return_layernorm_output): + """ + Test Transformer Decoder Layer + """ + paddle.set_default_dtype(math_dtype) + rtol = 5e-2 + atol = 5e-2 + eps = 1e-3 + + encoder_input = paddle.uniform(shape=(bs, q_seqlen, hidden_size), dtype=math_dtype) + encoder_output = paddle.uniform(shape=(bs, kv_seqlen, hidden_size), dtype=math_dtype) + + q_actual_seqlen = paddle.ones(shape=(bs,), dtype='int32') * q_seqlen + kv_actual_seqlen = q_actual_seqlen + attn_mask = paddle.ones(shape=(bs, 1, q_seqlen, kv_seqlen), dtype='bool') + + grad_out = paddle.normal(mean=0.0, std=0.2, shape=(bs, q_seqlen, hidden_size)).astype('float32') + for i in range(0, bs): + grad_out[i, q_actual_seqlen[i]:, :] = 0 + grad_out = grad_out.astype(math_dtype) + + for i in range(0, bs): + attn_mask[i, 0, 0:q_actual_seqlen[i], 0:kv_actual_seqlen[i]] = False + + layer_te = te.TransformerLayer(hidden_size, + ffn_hidden_size, + 
num_heads, + layernorm_epsilon=eps, + hidden_dropout=0.0, + attention_dropout=0.0, + weight_attr=None, + bias_attr=None if has_bias else False, + self_attn_mask_type=mask_type, + apply_residual_connection_post_layernorm=return_layernorm_output, + output_layernorm=output_layernorm, + layer_type='decoder', + backend='transformer_engine') + layer_pd = te.TransformerLayer(hidden_size, + ffn_hidden_size, + num_heads, + layernorm_epsilon=eps, + hidden_dropout=0.0, + attention_dropout=0.0, + weight_attr=None, + bias_attr=None if has_bias else False, + self_attn_mask_type=mask_type, + apply_residual_connection_post_layernorm=return_layernorm_output, + output_layernorm=output_layernorm, + layer_type='decoder', + backend='paddle') + + # MultiHeadAttention params - self attn + if output_layernorm: + layer_pd.self_attention.qkv.weight.copy_(layer_te.self_attention.qkv.weight.T, True) + layer_pd.self_attention.qkv.weight.stop_gradient = no_wgrad + layer_te.self_attention.qkv.weight.stop_gradient = no_wgrad + if has_bias: + layer_pd.self_attention.qkv.bias.copy_(layer_te.self_attention.qkv.bias, True) + layer_pd.self_attention.qkv.bias.stop_gradient = no_dbias + layer_te.self_attention.qkv.bias.stop_gradient = no_dbias + else: + layer_pd.self_attention.layernorm_qkv.ln_weight.copy_( + layer_te.self_attention.layernorm_qkv.ln_weight, True) + layer_pd.self_attention.layernorm_qkv.ln_bias.copy_( + layer_te.self_attention.layernorm_qkv.ln_bias, True) + layer_pd.self_attention.layernorm_qkv.weight.copy_( + layer_te.self_attention.layernorm_qkv.weight.T, True) + layer_pd.self_attention.layernorm_qkv.ln_weight.stop_gradient = no_wgrad + layer_pd.self_attention.layernorm_qkv.ln_bias.stop_gradient = no_dbias + layer_pd.self_attention.layernorm_qkv.weight.stop_gradient = no_wgrad + layer_te.self_attention.layernorm_qkv.ln_weight.stop_gradient = no_wgrad + layer_te.self_attention.layernorm_qkv.ln_bias.stop_gradient = no_dbias + layer_te.self_attention.layernorm_qkv.weight.stop_gradient = no_wgrad + if has_bias: + layer_pd.self_attention.layernorm_qkv.bias.copy_( + layer_te.self_attention.layernorm_qkv.bias, True) + layer_pd.self_attention.layernorm_qkv.bias.stop_gradient = no_dbias + layer_te.self_attention.layernorm_qkv.bias.stop_gradient = no_dbias + + layer_pd.self_attention.proj.weight.copy_(layer_te.self_attention.proj.weight.T, True) + layer_pd.self_attention.proj.weight.stop_gradient = no_wgrad + layer_te.self_attention.proj.weight.stop_gradient = no_wgrad + if has_bias: + layer_pd.self_attention.proj.bias.copy_(layer_te.self_attention.proj.bias, True) + layer_pd.self_attention.proj.bias.stop_gradient = no_dbias + layer_te.self_attention.proj.bias.stop_gradient = no_dbias + + # MultiHeadAttention params - cross attn + layer_pd.inter_attention.layernorm_query.ln_weight.copy_( + layer_te.inter_attention.layernorm_query.ln_weight, True) + layer_pd.inter_attention.layernorm_query.ln_bias.copy_( + layer_te.inter_attention.layernorm_query.ln_bias, True) + layer_pd.inter_attention.layernorm_query.weight.copy_( + layer_te.inter_attention.layernorm_query.weight.T, True) + layer_pd.inter_attention.layernorm_query.ln_weight.stop_gradient = no_wgrad + layer_pd.inter_attention.layernorm_query.ln_bias.stop_gradient = no_dbias + layer_pd.inter_attention.layernorm_query.weight.stop_gradient = no_wgrad + layer_te.inter_attention.layernorm_query.ln_weight.stop_gradient = no_wgrad + layer_te.inter_attention.layernorm_query.ln_bias.stop_gradient = no_dbias + layer_te.inter_attention.layernorm_query.weight.stop_gradient = 
no_wgrad + if has_bias: + layer_pd.inter_attention.layernorm_query.bias.copy_( + layer_te.inter_attention.layernorm_query.bias, True) + layer_pd.inter_attention.layernorm_query.bias.stop_gradient = no_dbias + layer_te.inter_attention.layernorm_query.bias.stop_gradient = no_dbias + + layer_pd.inter_attention.key_value.weight.copy_(layer_te.inter_attention.key_value.weight.T, + True) + layer_pd.inter_attention.key_value.weight.stop_gradient = no_wgrad + layer_te.inter_attention.key_value.weight.stop_gradient = no_wgrad + layer_pd.inter_attention.proj.weight.copy_(layer_te.inter_attention.proj.weight.T, True) + layer_pd.inter_attention.proj.weight.stop_gradient = no_wgrad + layer_te.inter_attention.proj.weight.stop_gradient = no_wgrad + if has_bias: + layer_pd.inter_attention.key_value.bias.copy_(layer_te.inter_attention.key_value.bias, True) + layer_pd.inter_attention.key_value.bias.stop_gradient = no_dbias + layer_te.inter_attention.key_value.bias.stop_gradient = no_dbias + layer_pd.inter_attention.proj.bias.copy_(layer_te.inter_attention.proj.bias, True) + layer_pd.inter_attention.proj.bias.stop_gradient = no_dbias + layer_te.inter_attention.proj.bias.stop_gradient = no_dbias + + # LayerNorm MLP params + layer_pd.layernorm_mlp.ln_weight.copy_(layer_te.layernorm_mlp.ln_weight, True) + layer_pd.layernorm_mlp.ln_bias.copy_(layer_te.layernorm_mlp.ln_bias, True) + layer_pd.layernorm_mlp.fc1_weight.copy_(layer_te.layernorm_mlp.fc1_weight.T, True) + layer_pd.layernorm_mlp.fc2_weight.copy_(layer_te.layernorm_mlp.fc2_weight.T, True) + layer_pd.layernorm_mlp.ln_weight.stop_gradient = no_wgrad + layer_pd.layernorm_mlp.ln_bias.stop_gradient = no_dbias + layer_pd.layernorm_mlp.fc1_weight.stop_gradient = no_wgrad + layer_pd.layernorm_mlp.fc2_weight.stop_gradient = no_wgrad + layer_te.layernorm_mlp.ln_weight.stop_gradient = no_wgrad + layer_te.layernorm_mlp.ln_bias.stop_gradient = no_dbias + layer_te.layernorm_mlp.fc1_weight.stop_gradient = no_wgrad + layer_te.layernorm_mlp.fc2_weight.stop_gradient = no_wgrad + if has_bias: + layer_pd.layernorm_mlp.fc1_bias.copy_(layer_te.layernorm_mlp.fc1_bias, True) + layer_pd.layernorm_mlp.fc2_bias.copy_(layer_te.layernorm_mlp.fc2_bias, True) + layer_pd.layernorm_mlp.fc1_bias.stop_gradient = no_dbias + layer_pd.layernorm_mlp.fc2_bias.stop_gradient = no_dbias + layer_te.layernorm_mlp.fc1_bias.stop_gradient = no_dbias + layer_te.layernorm_mlp.fc2_bias.stop_gradient = no_dbias + + if output_layernorm: + layer_pd.layernorm.weight.copy_(layer_te.layernorm.weight, True) + layer_pd.layernorm.bias.copy_(layer_te.layernorm.bias, True) + layer_pd.layernorm.weight.stop_gradient = no_wgrad + layer_pd.layernorm.bias.stop_gradient = no_dbias + layer_te.layernorm.weight.stop_gradient = no_wgrad + layer_te.layernorm.bias.stop_gradient = no_dbias + + def calc_transformer_output_and_grad(layer, encoder_input, mask, encoder_output, + enc_dec_attn_mask, dout): + _encoder_input = paddle.to_tensor(encoder_input, stop_gradient=False) + _encoder_output = paddle.to_tensor(encoder_output, stop_gradient=False) + out = layer(_encoder_input, mask, _encoder_output, enc_dec_attn_mask) + out.backward(dout) + return out, _encoder_input.grad, _encoder_output.grad + + out_ref, grad_encoder_input_ref, grad_encoder_output_ref = calc_transformer_output_and_grad( + layer_pd, encoder_input, attn_mask, encoder_output, attn_mask, grad_out) + out, grad_encoder_input, grad_encoder_output = calc_transformer_output_and_grad( + layer_te, encoder_input, attn_mask, encoder_output, attn_mask, grad_out) + + 
assert_allclose(out, out_ref, rtol=rtol, atol=atol) + assert_allclose(grad_encoder_input, grad_encoder_input_ref, rtol=rtol, atol=atol) + assert_allclose(grad_encoder_output, grad_encoder_output_ref, rtol=rtol, atol=atol) + if not no_wgrad: + if output_layernorm: + assert_allclose(layer_te.self_attention.qkv.weight.grad, + layer_pd.self_attention.qkv.weight.grad.T, + rtol=rtol, + atol=atol) + else: + assert_allclose(layer_te.self_attention.layernorm_qkv.weight.grad, + layer_pd.self_attention.layernorm_qkv.weight.grad.T, + rtol=rtol, + atol=0.1) + assert_allclose(layer_te.inter_attention.layernorm_query.weight.grad, + layer_pd.inter_attention.layernorm_query.weight.grad.T, + rtol=rtol, + atol=atol) + if not no_dbias: + if output_layernorm: + assert_allclose(layer_te.self_attention.qkv.bias.grad, + layer_pd.self_attention.qkv.bias.grad, + rtol=0.01, + atol=0.5) + else: + assert_allclose(layer_te.self_attention.layernorm_qkv.bias.grad, + layer_pd.self_attention.layernorm_qkv.bias.grad, + rtol=0.01, + atol=0.5) + assert_allclose(layer_te.inter_attention.layernorm_query.bias.grad, + layer_pd.inter_attention.layernorm_query.bias.grad, + rtol=rtol, + atol=atol) diff --git a/tests/paddle/test_operators.py b/tests/paddle/test_operators.py index c2769ee2bc..662978086a 100644 --- a/tests/paddle/test_operators.py +++ b/tests/paddle/test_operators.py @@ -46,7 +46,7 @@ from transformer_engine.common.recipe import DelayedScaling np.random.seed(10) -paddle.seed(10) +paddle.seed(11) GEMM_CASES = [(256, 256, 512), (32, 32, 32), (16384, 1024, 2816), (16384, 2816, 1024), (16384, 1024, 1024)] is_fp8_supported, reason = is_fp8_available() @@ -400,7 +400,7 @@ def test_layernorm_fwd(self): y_ref, mu_ref, rsigma_ref = self.calc_fwd_ref(x, eps, gamma, beta) - assert_allclose(y, y_ref, rtol=1e-5, atol=1e-5) + assert_allclose(y, y_ref, rtol=1e-4, atol=1e-4) assert_allclose(mu, mu_ref, rtol=1e-3, atol=1e-3) assert_allclose(rsigma, rsigma_ref, rtol=5e-2, atol=5e-2) @@ -725,10 +725,8 @@ def _get_fused_attention_out(self): q_grad = dq k_grad = dkv[:, :, 0, :, :] v_grad = dkv[:, :, 1, :, :] - fwd_out = paddle.reshape( - out, shape=[self.batch_size, self.q_seqlen, self.num_heads, self.head_size]) - return fwd_out, q_grad, k_grad, v_grad + return out, q_grad, k_grad, v_grad @pytest.mark.skipif(paddle.device.cuda.get_device_capability() < (8, 0), reason="cuDNN fMHA requires Ampere+ GPU") diff --git a/transformer_engine/paddle/__init__.py b/transformer_engine/paddle/__init__.py index 798ebb0527..6184c566d1 100644 --- a/transformer_engine/paddle/__init__.py +++ b/transformer_engine/paddle/__init__.py @@ -3,5 +3,6 @@ # See LICENSE for license information. 
"""Transformer Engine bindings for Paddle""" -from .layer import Linear, LayerNorm, LayerNormLinear, LayerNormMLP from .fp8 import fp8_autocast +from .layer import (Linear, LayerNorm, LayerNormLinear, LayerNormMLP, FusedScaleMaskSoftmax, + DotProductAttention, MultiHeadAttention, TransformerLayer) diff --git a/transformer_engine/paddle/constants.py b/transformer_engine/paddle/constants.py index 0ae9e28b43..eac161ec60 100644 --- a/transformer_engine/paddle/constants.py +++ b/transformer_engine/paddle/constants.py @@ -40,3 +40,9 @@ class FP8BwdTensors(Enum): paddle.float16: tex.DType.kFloat16, paddle.bfloat16: tex.DType.kBFloat16, } + +AttnMaskTypes = ("causal", "padding", "no_mask") + +AttnTypes = ("self", "cross") + +LayerTypes = ("encoder", "decoder") diff --git a/transformer_engine/paddle/cpp_extensions.py b/transformer_engine/paddle/cpp_extensions.py index b16c1c81e6..97a141973b 100644 --- a/transformer_engine/paddle/cpp_extensions.py +++ b/transformer_engine/paddle/cpp_extensions.py @@ -435,9 +435,9 @@ def fused_attn_fwd_qkvpacked( assert (Bias.dtype == qkv.dtype), "bias tensor must be in the same dtype as qkv." if set_zero: - out = paddle.full(shape=[total_seqs, h, d], fill_value=0, dtype=qkv.dtype) + out = paddle.full(shape=[b, max_seqlen, h, d], fill_value=0, dtype=qkv.dtype) else: - out = paddle.empty(shape=[total_seqs, h, d], dtype=qkv.dtype) + out = paddle.empty(shape=[b, max_seqlen, h, d], dtype=qkv.dtype) if is_training: softmax_aux = paddle.empty(shape=[b, h, max_seqlen, max_seqlen], dtype=qkv.dtype) @@ -574,9 +574,9 @@ def fused_attn_fwd_kvpacked( assert (Bias.dtype == q.dtype), "bias tensor must be in the same dtype as q and kv." if set_zero: - out = paddle.full(shape=[total_seqs_q, h, d], fill_value=0, dtype=q.dtype) + out = paddle.full(shape=[b, max_seqlen_q, h, d], fill_value=0, dtype=q.dtype) else: - out = paddle.empty(shape=[total_seqs_q, h, d], dtype=q.dtype) + out = paddle.empty(shape=[b, max_seqlen_q, h, d], dtype=q.dtype) if is_training: softmax_aux = paddle.empty(shape=[b, h, max_seqlen_q, max_seqlen_kv], dtype=q.dtype) diff --git a/transformer_engine/paddle/layer/__init__.py b/transformer_engine/paddle/layer/__init__.py index bf5efd2753..b4d6ec9fef 100644 --- a/transformer_engine/paddle/layer/__init__.py +++ b/transformer_engine/paddle/layer/__init__.py @@ -3,7 +3,10 @@ # See LICENSE for license information. """Layer level Paddle APIs""" +from .attention import DotProductAttention, MultiHeadAttention from .layernorm import LayerNorm from .layernorm_linear import LayerNormLinear from .layernorm_mlp import LayerNormMLP from .linear import Linear +from .softmax import FusedScaleMaskSoftmax +from .transformer import TransformerLayer diff --git a/transformer_engine/paddle/layer/attention.py b/transformer_engine/paddle/layer/attention.py new file mode 100644 index 0000000000..a5aac3566f --- /dev/null +++ b/transformer_engine/paddle/layer/attention.py @@ -0,0 +1,568 @@ +# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. 
+"""Attntion API""" + +import math +import warnings +from typing import Optional, Tuple, Union + +import paddle +import paddle.nn.functional as F + +from transformer_engine.paddle.constants import ( + AttnTypes, + TE_DType, +) +from transformer_engine.paddle.cpp_extensions import ( + fused_attn_fwd_qkvpacked, + fused_attn_bwd_qkvpacked, + fused_attn_fwd_kvpacked, + fused_attn_bwd_kvpacked, +) +from transformer_engine.paddle.utils import (attention_mask_func, mask_to_cu_seqlens) +from .base import TransformerEngineBaseLayer +from .layernorm_linear import LayerNormLinear +from .linear import Linear +from .softmax import FusedScaleMaskSoftmax + + +class FusedAttnFuncPackedQKV(paddle.autograd.PyLayer): + """Function for FusedAttention with packed QKV input""" + + @staticmethod + def forward(ctx, qkv, cu_seqlens, attn_bias, rng_state, max_seqlen, attn_scale, qkv_dtype, + dropout_p, set_zero, qkv_layout, attn_bias_type, attn_mask_type, is_training): + """Forward function for FusedAttention with packed QKV input""" + out, aux_ctx_tensors = fused_attn_fwd_qkvpacked( + qkv, + cu_seqlens, + rng_state, + is_training, + max_seqlen, + qkv_dtype, + attn_bias, + attn_scale, + dropout_p, + set_zero, + qkv_layout, + attn_bias_type, + attn_mask_type, + ) + + ctx.save_for_backward(qkv, out, cu_seqlens, rng_state, aux_ctx_tensors) + ctx.max_seqlen = max_seqlen + ctx.qkv_dtype = qkv_dtype + ctx.attn_scale = attn_scale + ctx.dropout_p = dropout_p + ctx.set_zero = set_zero + ctx.qkv_layout = qkv_layout + ctx.attn_bias_type = attn_bias_type + ctx.attn_mask_type = attn_mask_type + + return out + + @staticmethod + def backward(ctx, d_out): + """Backward function for FusedAttention with packed QKV input""" + qkv, out, cu_seqlens, rng_state, aux_ctx_tensors = ctx.saved_tensor() + dqkv, *rest = fused_attn_bwd_qkvpacked(qkv, cu_seqlens, rng_state, out, d_out, + aux_ctx_tensors, ctx.max_seqlen, ctx.qkv_dtype, + ctx.attn_scale, ctx.dropout_p, ctx.set_zero, + ctx.qkv_layout, ctx.attn_bias_type, + ctx.attn_mask_type) + + # if no_bias, return dqkv + if ctx.attn_bias_type == "no_bias": + return (dqkv, None, None) + # else, return (dqkv, dbias) + return (dqkv, None, rest[0], None) + + +class FusedAttnFuncPackedKV(paddle.autograd.PyLayer): + """Function for FusedAttention with packed KV input""" + + @staticmethod + def forward(ctx, q, kv, cu_seqlens_q, cu_seqlens_kv, attn_bias, rng_state, max_seqlen_q, + max_seqlen_kv, attn_scale, qkv_dtype, dropout_p, set_zero, qkv_layout, + attn_bias_type, attn_mask_type, is_training): + """Forward function for FusedAttention with packed KV input""" + out, aux_ctx_tensors = fused_attn_fwd_kvpacked(q, kv, cu_seqlens_q, cu_seqlens_kv, + rng_state, is_training, max_seqlen_q, + max_seqlen_kv, qkv_dtype, attn_bias, + attn_scale, dropout_p, set_zero, qkv_layout, + attn_bias_type, attn_mask_type) + + ctx.save_for_backward(q, kv, out, cu_seqlens_q, cu_seqlens_kv, rng_state, aux_ctx_tensors) + ctx.max_seqlen_q = max_seqlen_q + ctx.max_seqlen_kv = max_seqlen_kv + ctx.qkv_dtype = qkv_dtype + ctx.attn_scale = attn_scale + ctx.dropout_p = dropout_p + ctx.set_zero = set_zero + ctx.qkv_layout = qkv_layout + ctx.attn_bias_type = attn_bias_type + ctx.attn_mask_type = attn_mask_type + + return out + + @staticmethod + def backward(ctx, d_out): + """Backward function for FusedAttention with packed KV input""" + q, kv, out, cu_seqlens_q, cu_seqlens_kv, rng_state, aux_ctx_tensors = ctx.saved_tensor() + dq, dkv, *rest = fused_attn_bwd_kvpacked(q, kv, cu_seqlens_q, cu_seqlens_kv, rng_state, out, + d_out, 
aux_ctx_tensors, ctx.max_seqlen_q, + ctx.max_seqlen_kv, ctx.qkv_dtype, ctx.attn_scale, + ctx.dropout_p, ctx.set_zero, ctx.qkv_layout, + ctx.attn_bias_type, ctx.attn_mask_type) + + # if no_bias, return dq, dkv + if ctx.attn_bias_type == "no_bias": + return (dq, dkv, None, None, None) + # else, return (dq, dkv, dbias) + return (dq, dkv, None, None, rest[0], None) + + +class DotProductAttention(paddle.nn.Layer): + """Dot Product Attention Layer + Allows the model to jointly attend to information from different + representation subspaces as described in the paper: + `Attention Is All You Need `_. + + .. note:: + + Argument :attr:`attention_mask` will be ignored in the `forward` call when + :attr:`attn_mask_type` is set to `"causal"`. + + Parameters + ---------- + norm_factor : float + normalization factor for the attention scores. + attention_dropout: float, default = 0.1 + dropout probability for the dropout op during multi-head attention. + attn_mask_type: {'causal', 'padding', 'no_mask'}, default = `causal` + type of attention mask passed into softmax operation. + attention_type: {'self', 'cross'}, default = `self` + type of attention operation. + backend: {'transformer_engine', 'paddle'}, default = `transformer_engine` + backend to use for attention operation. + + """ + + def __init__(self, + norm_factor: float, + attention_dropout: float = 0.1, + attn_mask_type: str = "causal", + attention_type: str = "self", + backend: str = 'transformer_engine') -> None: + super().__init__() + + self.norm_factor = norm_factor + self.attn_mask_type = attn_mask_type + self.attention_dropout = attention_dropout + self.attention_type = attention_type + self.backend = backend + self.rng_state = paddle.zeros((2,), dtype='int64') + self.rng_state.persistable = True + if self.backend != 'transformer_engine': + self.scale_mask_softmax = FusedScaleMaskSoftmax(attn_mask_type, + attention_mask_func, + backend=self.backend) + + def forward( + self, + query_layer: paddle.Tensor, + key_value_layer: paddle.Tensor = None, + attention_mask: Optional[paddle.Tensor] = None, + core_attention_bias_type: str = "no_bias", + core_attention_bias: Optional[paddle.Tensor] = None, + set_zero: bool = True, + ) -> paddle.Tensor: + """ + Dot Product Attention Layer. + + .. note:: + + Argument :attr:`attention_mask` will be ignored when :attr:`attn_mask_type` + is set to `"causal"`. + + .. note:: + + For self attention, :attr:`query_layer` is the `[query, key, value]` tensor + stacked along the 2nd dimension, which must be of shape (:attr:`batch_size`, + :attr:`seq_length`, 3, :attr:`num_attention_heads`, :attr:`size_per_head`). + And :attr:`key_value_layer` is `None`. + For cross attention, :attr:`query_layer` is the `[query]` tensor, which must + be of shape (:attr:`batch_size`, :attr:`seq_length`, :attr:`num_attention_heads`, + :attr:`size_per_head`). And :attr:`key_value_layer` is the `[key, value]` tensor, + which must be of shape (:attr:`batch_size`, :attr:`seq_length`, 2, + :attr:`num_attention_heads`, :attr:`size_per_head`). + + + + Parameters + ---------- + query_layer : paddle.Tensor + Query tensor. + key_value_layer : paddle.Tensor + Key tensor. + attention_mask : Optional[paddle.Tensor], default = `None` + Boolean tensor used to mask out softmax input when not using attention. 
+ core_attention_bias_type: str, default = `no_bias` + only support no_bias type currently, {`no_bias`} + core_attention_bias: Optional[paddle.Tensor], default = `None` + Bias tensor for Q * K.T + set_zero: bool, defautl = `True` + Whether to use the fast path to set output tensors to 0 or not. + """ + + if self.backend == 'transformer_engine': + return self._te_forward(query_layer, key_value_layer, attention_mask, + core_attention_bias_type, core_attention_bias, set_zero) + if self.backend == 'paddle': + if core_attention_bias_type != "no_bias": + warnings.warn("Paddle backend dot product attention does not support bias yet. " + "Bias will be ignored.") + return self._pd_forward(query_layer, key_value_layer, attention_mask) + raise AttributeError(f"Backend {self.backend} is not supported.") + + def _te_forward( + self, + query_layer: paddle.Tensor, + key_value_layer: paddle.Tensor = None, + attention_mask: Optional[paddle.Tensor] = None, + core_attention_bias_type: str = "no_bias", + core_attention_bias: Optional[paddle.Tensor] = None, + set_zero: bool = True, + ) -> paddle.Tensor: + + gen_state = paddle.get_rng_state()[0].__getstate__() + self.rng_state[0], self.rng_state[1] = gen_state[1], gen_state[2] # [seed, offset] + if self.attention_type == "self": + # self attention - q: [b, s, 3, h, d] kv: None + assert (len(query_layer.shape) == 5 and query_layer.shape[2] == 3 + and key_value_layer is None + ), "query shape must be [b, s, 3, h, d] for dot product self attention" + max_seqlen = query_layer.shape[1] + cu_seqlens, _ = mask_to_cu_seqlens(attention_mask) + qkv_dtype = TE_DType[query_layer.dtype] + qkv_layout = "qkv_interleaved" + + output = FusedAttnFuncPackedQKV.apply( + query_layer, + cu_seqlens, + core_attention_bias, + self.rng_state, + max_seqlen, + 1.0 / self.norm_factor, + qkv_dtype, + self.attention_dropout if self.training else 0.0, + set_zero, + qkv_layout, + core_attention_bias_type, + self.attn_mask_type, + self.training, + ) + elif self.attention_type == "cross": + # cross attention - q: [b, s_q, h, d] kv: [b, s_kv, 2, h, d] + assert ( + len(query_layer.shape) == 4 and len(key_value_layer.shape) == 5 + and key_value_layer.shape[2] == 2 + ), "query shape must be [b, s, h, d] and key shape must be [b, s, 2, h, d]" \ + "for dot product cross attention" + max_seqlen_q = query_layer.shape[1] + max_seqlen_kv = key_value_layer.shape[1] + cu_seqlens_q, cu_seqlens_kv = mask_to_cu_seqlens(attention_mask, need_kv=True) + qkv_dtype = TE_DType[query_layer.dtype] + qkv_layout = "kv_interleaved" + output = FusedAttnFuncPackedKV.apply( + query_layer, + key_value_layer, + cu_seqlens_q, + cu_seqlens_kv, + core_attention_bias, + self.rng_state, + max_seqlen_q, + max_seqlen_kv, + 1.0 / self.norm_factor, + qkv_dtype, + self.attention_dropout if self.training else 0.0, + set_zero, + qkv_layout, + core_attention_bias_type, + self.attn_mask_type, + self.training, + ) + else: + raise ValueError("attention_type must be one of ['self', 'cross']") + return output + + def _pd_forward( + self, + query_layer: paddle.Tensor, + key_value_layer: paddle.Tensor = None, + attention_mask: Optional[paddle.Tensor] = None, + ) -> paddle.Tensor: + if self.attention_type == "self": + # self attention - q: [b, s, 3, h, d] k: None + assert (len(query_layer.shape) == 5 and query_layer.shape[2] == 3 + and key_value_layer is None + ), "query shape must be [b, s, 3, h, d] for dot product self attention" + q = query_layer[:, :, 0] + k = query_layer[:, :, 1] + v = query_layer[:, :, 2] + elif self.attention_type == 
"cross": + # cross attention - q: [b, s, h, d] kv: [b, s, 2, h, d] + assert ( + len(query_layer.shape) == 4 and len(key_value_layer.shape) == 5 + and key_value_layer.shape[2] == 2 + ), f"query shape must be [b, s, h, d] and key_value shape must be [b, s, 2, h, d]" \ + f"for dot product cross attention. The actual shape is q: {query_layer.shape}" \ + f"kv: {key_value_layer.shape}" + q = query_layer + k = key_value_layer[:, :, 0] + v = key_value_layer[:, :, 1] + + q = paddle.transpose(x=q, perm=[0, 2, 1, 3]) + k = paddle.transpose(x=k, perm=[0, 2, 1, 3]) + v = paddle.transpose(x=v, perm=[0, 2, 1, 3]) + + product = paddle.matmul(x=q * (1.0 / self.norm_factor), y=k, transpose_y=True) + attention_probs = self.scale_mask_softmax(product, attention_mask, scale=None) + + if self.attention_dropout > 0: + attention_probs = F.dropout( + attention_probs, + self.attention_dropout, + training=self.training, + ) + + out = paddle.matmul(attention_probs, v) + out = paddle.transpose(out, perm=[0, 2, 1, 3]) # [b, s, h, d] + # out = paddle.reshape(x=out, shape=[0, 0, out.shape[2] * out.shape[3]]) + return out + + +class MultiHeadAttention(TransformerEngineBaseLayer): + """Attention w/ QKV and Proj Gemms + + Parameters + ---------- + hidden_size: int + hidden size of the model. + num_attention_heads: int + number of attention heads. + attention_dropout: float, default = 0.1 + dropout probability for the dropout op during multi-head attention. + layernorm_epsilon: float, default = 1e-5 + epsilon to use in the layer norm operations. + weight_attr: Union[paddle.ParamAttr, None], default = `None` + paddle.ParamAttr object for the weight parameter. + bias_attr: Union[paddle.ParamAttr, None, bool], default = `None` + paddle.ParamAttr object for the bias parameter. + attn_mask_type: {'causal', 'padding', 'no_mask'}, default = `causal` + type of attention mask passed into softmax operation. + params_dtype: Optional[paddle.dtype], default = `None` + data type for the weights and biases. + return_layernorm_output: bool, default = `False` + whether to return the output of the layernorm operation. + input_layernorm: bool, default = `False` + whether to apply layernorm to the input. + attention_type: {'self', 'cross'}, default = `self` + type of attention operation. + zero_centered_gamma: bool, default = `False` + whether to zero initialize the gamma of the layernorm operation. + backend: {'transformer_engine', 'paddle'}, default = `transformer_engine` + backend to use for attention operation. 
+ """ + + def __init__( + self, + hidden_size: int, + num_attention_heads: int, + attention_dropout: float = 0.1, + layernorm_epsilon: float = 1e-5, + weight_attr: Union[paddle.ParamAttr, None] = None, + bias_attr: Union[paddle.ParamAttr, None, bool] = None, + attn_mask_type: str = "causal", + params_dtype: Optional[paddle.dtype] = None, + return_layernorm_output: bool = False, + input_layernorm: bool = False, + attention_type: str = "self", + zero_centered_gamma: bool = False, + backend: str = 'transformer_engine', + ) -> None: + super().__init__() + self.input_layernorm = input_layernorm + self.attention_type = attention_type + self.return_layernorm_output = return_layernorm_output + self.params_dtype = paddle.get_default_dtype() if params_dtype is None else params_dtype + self.weight_attr = weight_attr + self.bias_attr = bias_attr + self.attn_mask_type = attn_mask_type + + assert attention_type in AttnTypes, f"attention_type {attention_type} not supported" + + self.hidden_size_per_attention_head = hidden_size // num_attention_heads + self.num_attention_heads = num_attention_heads + norm_factor = math.sqrt(self.hidden_size_per_attention_head) + self.backend = backend + + if self.attention_type == "self": + if self.input_layernorm: + self.layernorm_qkv = LayerNormLinear( + hidden_size, + 3 * hidden_size, + eps=layernorm_epsilon, + weight_attr=self.weight_attr, + bias_attr=self.bias_attr, + return_layernorm_output=return_layernorm_output, + zero_centered_gamma=zero_centered_gamma, + backend=self.backend, + ) + else: + self.qkv = Linear( + hidden_size, + 3 * hidden_size, + self.weight_attr, + self.bias_attr, + backend=self.backend, + ) + + else: # cross attention + if self.input_layernorm: + self.layernorm_query = LayerNormLinear( + hidden_size, + hidden_size, + eps=layernorm_epsilon, + weight_attr=self.weight_attr, + bias_attr=self.bias_attr, + return_layernorm_output=return_layernorm_output, + zero_centered_gamma=zero_centered_gamma, + backend=self.backend, + ) + else: + self.query_layer = Linear( + hidden_size, + hidden_size, + self.weight_attr, + self.bias_attr, + backend=self.backend, + ) + self.key_value = Linear( + hidden_size, + 2 * hidden_size, + self.weight_attr, + self.bias_attr, + backend=self.backend, + ) + + # Attention. + self.core_attention = DotProductAttention( + norm_factor, + attention_dropout, + attn_mask_type=attn_mask_type, + attention_type=self.attention_type, + backend=self.backend, + ) + + # Linear + self.proj = Linear( + hidden_size, + hidden_size, + self.weight_attr, + self.bias_attr, + backend=self.backend, + ) + + def forward( + self, + hidden_states: paddle.Tensor, + attention_mask: Optional[paddle.Tensor] = None, + encoder_output: Optional[paddle.Tensor] = None, + core_attention_bias_type: str = "no_bias", + core_attention_bias: Optional[paddle.Tensor] = None, + set_zero: bool = True, + ) -> Tuple[Union[paddle.Tensor, None], ...]: + """ + MultiHeadAttention Layer. + + + Parameters + ---------- + hidden_states : paddle.Tensor + Input tensor. + attention_mask : Optional[paddle.Tensor], default = `None` + Boolean tensor used to mask out softmax input when not using attention. + encoder_output : Optional[paddle.Tensor], default = `None` + Output of the encoder layer. + core_attention_bias_type: str, default = `no_bias` + only support no_bias type currently, {`no_bias`} + core_attention_bias: Optional[paddle.Tensor], default = `None` + Bias tensor for Q * K.T + set_zero: bool, defautl = `True` + Whether to use the fast path to set output tensors to 0 or not. 
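+
+        Returns
+        -------
+        attention_output : paddle.Tensor
+            Output tensor of shape [b, s_q, hidden_size]. If both :attr:`input_layernorm`
+            and :attr:`return_layernorm_output` are `True`, the tuple
+            `(attention_output, layernorm_output)` is returned instead.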
+ + """ + + # hidden_states: [b, s_q, hidden_size] + if self.attn_mask_type != "causal" and attention_mask is not None: + assert (attention_mask.dtype == paddle.bool), "Attention mask must be a boolean tensor" + + if self.attention_type == "self": + if self.input_layernorm: + layernorm_qkv_outputs = self.layernorm_qkv(hidden_states) + if self.return_layernorm_output: + mixed_qkv_layer, layernorm_output = layernorm_qkv_outputs + else: + mixed_qkv_layer = layernorm_qkv_outputs + else: + mixed_qkv_layer = self.qkv(hidden_states) + + # [b, s_q, 3 * hidden_size] --> [b, s_q, 3, num_heads, head_size] + mixed_qkv_layer = mixed_qkv_layer.reshape( + shape=[0, 0, 3, self.num_attention_heads, self.hidden_size_per_attention_head]) + + context_layer = self.core_attention( + query_layer=mixed_qkv_layer, + key_value_layer=None, + attention_mask=attention_mask, + core_attention_bias_type=core_attention_bias_type, + core_attention_bias=core_attention_bias, + set_zero=set_zero, + ) + + else: # cross attention + mixed_kv_layer = self.key_value(encoder_output) + # [b, s_kv, 2 * hidden_size] --> [b, s_kv, 2, num_heads, head_size] + mixed_kv_layer = mixed_kv_layer.reshape( + shape=[0, 0, 2, self.num_attention_heads, self.hidden_size_per_attention_head]) + + if self.input_layernorm: + layernorm_query_outputs = self.layernorm_query(hidden_states) + if self.return_layernorm_output: + query_layer, layernorm_output = layernorm_query_outputs + else: + query_layer = layernorm_query_outputs + else: + query_layer = self.query_layer(hidden_states) + + query_layer = query_layer.reshape( + shape=[0, 0, self.num_attention_heads, self.hidden_size_per_attention_head]) + context_layer = self.core_attention( + query_layer=query_layer, + key_value_layer=mixed_kv_layer, + attention_mask=attention_mask, + core_attention_bias_type=core_attention_bias_type, + core_attention_bias=core_attention_bias, + set_zero=set_zero, + ) + + context_layer = paddle.reshape(context_layer, + [0, 0, context_layer.shape[2] * context_layer.shape[3]]) + # Output. 
[b, s, hidden] + attention_output = self.proj(context_layer) + + if self.input_layernorm and self.return_layernorm_output: + return attention_output, layernorm_output + return attention_output diff --git a/transformer_engine/paddle/layer/layernorm.py b/transformer_engine/paddle/layer/layernorm.py index a706c85c88..3f0b8c4a50 100644 --- a/transformer_engine/paddle/layer/layernorm.py +++ b/transformer_engine/paddle/layer/layernorm.py @@ -126,7 +126,7 @@ def _pd_forward( "Paddle backend does not support LayerNorm with zero-centered scale.") return F.layer_norm(x=inp, - normalized_shape=inp.shape[1:], + normalized_shape=inp.shape[-1], weight=self.weight, bias=self.bias, epsilon=self.eps) diff --git a/transformer_engine/paddle/layer/layernorm_linear.py b/transformer_engine/paddle/layer/layernorm_linear.py index 88736ba75f..608f02a6ff 100644 --- a/transformer_engine/paddle/layer/layernorm_linear.py +++ b/transformer_engine/paddle/layer/layernorm_linear.py @@ -402,7 +402,6 @@ def _te_forward( if self.return_layernorm_output: out, ln_out = out return out, ln_out - return out def _pd_forward( @@ -415,7 +414,7 @@ def _pd_forward( "Paddle backend does not support LayerNorm with zero-centered scale.") ln_out = F.layer_norm(x=inp, - normalized_shape=inp.shape[1:], + normalized_shape=inp.shape[-1], weight=self.ln_weight, bias=self.ln_bias, epsilon=self.eps) diff --git a/transformer_engine/paddle/layer/layernorm_mlp.py b/transformer_engine/paddle/layer/layernorm_mlp.py index 7bf3cc6fab..6d725114b0 100644 --- a/transformer_engine/paddle/layer/layernorm_mlp.py +++ b/transformer_engine/paddle/layer/layernorm_mlp.py @@ -624,7 +624,7 @@ def _pd_forward( "Paddle backend does not support LayerNorm with zero-centered scale.") ln_out = F.layer_norm(x=inp, - normalized_shape=inp.shape[1:], + normalized_shape=inp.shape[-1], weight=self.ln_weight, bias=self.ln_bias, epsilon=self.eps) diff --git a/transformer_engine/paddle/layer/softmax.py b/transformer_engine/paddle/layer/softmax.py new file mode 100644 index 0000000000..33b0293e0a --- /dev/null +++ b/transformer_engine/paddle/layer/softmax.py @@ -0,0 +1,237 @@ +# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. +"""Fused scaled masked softmax functions""" + +import os +import warnings +from typing import Callable, Tuple, Union, Optional + +import paddle + +from transformer_engine.paddle.cpp_extensions import ( + scaled_upper_triang_masked_softmax_forward, + scaled_upper_triang_masked_softmax_backward, + scaled_masked_softmax_forward, + scaled_masked_softmax_backward, + scaled_softmax_forward, + scaled_softmax_backward, +) + +THREADS_PER_WARP = 32 +THREADS_PER_BLOCK = 128 + +_default_causal_mask = {} + + +def _get_default_causal_mask(seqlen: int) -> paddle.Tensor: + """Return the causal upper triangular mask for softmax input""" + if seqlen not in _default_causal_mask: + _default_causal_mask[seqlen] = paddle.triu(paddle.ones((seqlen, seqlen)), + diagonal=1).cast('bool') + return _default_causal_mask[seqlen] + + +class ScaledUpperTriangMaskedSoftmax(paddle.autograd.PyLayer): + """ + Fused operation which performs following three operations in sequence + 1. Scale the tensor. + 2. Apply upper triangular mask (typically used in gpt models). + 3. Perform softmax. 
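+
+    The input is expected to be a 3D tensor of shape (attn_batches, s_q, s_kv);
+    ``FusedScaleMaskSoftmax`` reshapes its 4D input accordingly before calling this op
+    on the causal-mask path.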
+    """
+
+    @staticmethod
+    def forward(ctx, inputs: paddle.Tensor, scale: float) -> paddle.Tensor:
+        """ScaledUpperTriangMaskedSoftmax fwd"""
+        scale_t = paddle.Tensor([scale])
+        softmax_results = scaled_upper_triang_masked_softmax_forward(inputs, scale_t[0])
+
+        ctx.save_for_backward(softmax_results, scale_t)
+        return softmax_results
+
+    @staticmethod
+    def backward(ctx, output_grads: paddle.Tensor) -> Tuple[Union[paddle.Tensor, None], ...]:
+        """ScaledUpperTriangMaskedSoftmax bwd"""
+        softmax_results, scale_t = ctx.saved_tensor()
+        input_grads = scaled_upper_triang_masked_softmax_backward(output_grads, softmax_results,
+                                                                  scale_t[0])
+
+        return input_grads, None
+
+
+class ScaledMaskedSoftmax(paddle.autograd.PyLayer):
+    """
+    Fused operation which performs following three operations in sequence
+    1. Scale the tensor.
+    2. Apply the mask.
+    3. Perform softmax.
+    """
+
+    @staticmethod
+    def forward(ctx, inputs: paddle.Tensor, mask: paddle.Tensor, scale: float) -> paddle.Tensor:
+        """ScaledMaskedSoftmax fwd"""
+        scale_t = paddle.Tensor([scale])
+
+        softmax_results = scaled_masked_softmax_forward(inputs, mask, scale_t[0])
+        ctx.save_for_backward(softmax_results, scale_t)
+        return softmax_results
+
+    @staticmethod
+    def backward(ctx, output_grads: paddle.Tensor) -> Tuple[Union[paddle.Tensor, None], ...]:
+        """ScaledMaskedSoftmax bwd"""
+        softmax_results, scale_t = ctx.saved_tensor()
+
+        input_grads = scaled_masked_softmax_backward(output_grads, softmax_results, scale_t[0])
+        return input_grads, None, None
+
+
+class ScaledSoftmax(paddle.autograd.PyLayer):
+    """
+    Fused operation which performs following two operations in sequence
+    1. Scale the tensor.
+    2. Perform softmax.
+    """
+
+    @staticmethod
+    def forward(ctx, inputs: paddle.Tensor, scale: float) -> paddle.Tensor:
+        """ScaledSoftmax fwd"""
+        scale_t = paddle.Tensor([scale])
+
+        softmax_results = scaled_softmax_forward(inputs, scale_t[0])
+        ctx.save_for_backward(softmax_results, scale_t)
+        return softmax_results
+
+    @staticmethod
+    def backward(ctx, output_grads: paddle.Tensor) -> Tuple[Union[paddle.Tensor, None], ...]:
+        """ScaledSoftmax bwd"""
+        softmax_results, scale_t = ctx.saved_tensor()
+
+        input_grads = scaled_softmax_backward(output_grads, softmax_results, scale_t[0])
+        return input_grads, None, None
+
+
+class FusedScaleMaskSoftmax(paddle.nn.Layer):
+    """
+    fused operation: scaling + mask + softmax
+
+    Arguments:
+        attn_mask_type: attention mask type ('padding' or 'causal')
+        mask_func: mask function to be applied.
+        softmax_in_fp32: if true, softmax is performed at fp32 precision. 
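+        backend: backend to use for the softmax operation,
+            {'transformer_engine', 'paddle'}, default = 'transformer_engine'.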
+ """ + + def __init__( + self, + attn_mask_type: str, + mask_func: Callable, + softmax_in_fp32: bool = True, + backend: str = 'transformer_engine', + ) -> None: + super().__init__() + self.attn_mask_type = attn_mask_type + self.scaled_masked_softmax_fusion = bool(int(os.getenv("NVTE_MASKED_SOFTMAX_FUSION", "1"))) + self.mask_func = mask_func + self.softmax_in_fp32 = softmax_in_fp32 + self.backend = backend + + def forward( + self, + inp: paddle.Tensor, + mask: paddle.Tensor, + scale: Optional[float] = None, + ) -> paddle.Tensor: + """FusedScaleMaskSoftmax fprop""" + # [batch_size, num_heads, s_q, s_kv] + assert inp.dim() == 4 + self.input_is_fp16 = inp.dtype == paddle.float16 + self.input_is_bf16 = inp.dtype == paddle.bfloat16 + self.input_in_16bit_float = self.input_is_fp16 or self.input_is_bf16 + + assert (scale is None or self.softmax_in_fp32), "softmax should be in fp32 when scaled" + + if self.backend == 'transformer_engine' and not self.is_kernel_available(*inp.shape): + warnings.warn( + "fused kernel is not available for this input shape, fall back to paddle backend") + self.backend = 'paddle' + + if self.backend == 'transformer_engine': + return self._te_forward(inp, mask, scale) + if self.backend == 'paddle': + return self._pd_forward(inp, mask, scale) + raise AttributeError(f"Backend {self.backend} is not supported.") + + def is_kernel_available(self, b: int, h: int, s_q: int, s_kv: int) -> bool: + """Check FusedScaleMaskSoftmax kernel availability based on size""" + attn_batches = b * h + + if (self.scaled_masked_softmax_fusion # user want to fuse + and self.input_in_16bit_float # input must be fp16 + and 16 < s_kv <= 4096 # s_kv must be 16 ~ 2048 + and s_q % 4 == 0 # s_q must be a multiple of 4 + and attn_batches % 4 == 0 # b * h must be a multiple of 4 + ): + if 0 <= s_kv <= 4096: + batch_per_block = self.get_batch_per_block(int(s_kv)) + + if self.attn_mask_type == "causal": + if attn_batches % batch_per_block == 0: + return True + else: + if s_q % batch_per_block == 0: + return True + return False + + def _te_forward(self, + inp: paddle.Tensor, + mask: paddle.Tensor, + scale: Optional[float] = None) -> paddle.Tensor: + """Fused masked softmax kernel""" + b, h, s_q, s_kv = inp.size() + scale = 1.0 if scale is None else scale + + if self.attn_mask_type == "causal": + assert s_q == s_kv, "causal mask is only for self attention" + + # input is 3D tensor (attn_batches, s_q, s_kv) + inp = inp.reshape((-1, s_q, s_kv)) + probs = ScaledUpperTriangMaskedSoftmax.apply(inp, scale) + return probs.reshape((b, h, s_q, s_kv)) + # input is 4D tensor (b, h, s_q, s_kv) + if mask is not None: + return ScaledMaskedSoftmax.apply(inp, mask, scale) + return ScaledSoftmax.apply(inp, scale) + + def _pd_forward(self, + inp: paddle.Tensor, + mask: paddle.Tensor, + scale: Optional[float] = None) -> paddle.Tensor: + """Call Paddle OP""" + if self.input_in_16bit_float and self.softmax_in_fp32: + inp = paddle.cast(inp, 'float32') + + if scale is not None: + inp = inp * scale + + if self.attn_mask_type == "causal": + mask = _get_default_causal_mask(inp.shape[2]) + + mask_output = self.mask_func(inp, mask) if mask is not None else inp + probs = paddle.nn.functional.softmax(mask_output, axis=-1) + + if self.input_in_16bit_float and self.softmax_in_fp32: + if self.input_is_fp16: + probs = paddle.cast(probs, 'float16') + else: + probs = paddle.cast(probs, 'bfloat16') + + return probs + + @staticmethod + def get_batch_per_block(key_seq_len: int) -> int: + """Softmax utility""" + pow2 = 1 << (key_seq_len - 
1).bit_length() + warp_size = pow2 if pow2 < THREADS_PER_WARP else THREADS_PER_WARP + batches_per_warp = 2 if pow2 <= 128 else 1 + warps_per_block = THREADS_PER_BLOCK // warp_size + batches_per_block = warps_per_block * batches_per_warp + return batches_per_block diff --git a/transformer_engine/paddle/layer/transformer.py b/transformer_engine/paddle/layer/transformer.py new file mode 100644 index 0000000000..6e6afd4ca2 --- /dev/null +++ b/transformer_engine/paddle/layer/transformer.py @@ -0,0 +1,260 @@ +# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. +"""Transformer""" + +from typing import Optional, Union + +import paddle + +from transformer_engine.paddle.constants import ( + AttnMaskTypes, + LayerTypes, +) +from transformer_engine.paddle.layer import (LayerNormMLP, LayerNorm, MultiHeadAttention) +from .base import TransformerEngineBaseLayer + + +class TransformerLayer(TransformerEngineBaseLayer): + r""" + TransformerLayer is made up of an attention block and a feedforward network (MLP). + This standard layer is based on the paper "Attention Is All You Need". + + Parameters + ---------- + hidden_size : int + size of each input sample. + ffn_hidden_size : int + intermediate size to which input samples are projected. + num_attention_heads : int + number of attention heads in the transformer layer. + layernorm_epsilon : float, default = 1e-5 + a value added to the denominator of layer normalization + for numerical stability. + hidden_dropout: float, default = 0.1 + dropout probability for the dropout op after FC2 layer. + attention_dropout: float, default = 0.1 + dropout probability for the dropout op during multi-head attention. + self_attn_mask_type: {'causal', 'padding'}, default = `causal` + type of attention mask passed into softmax operation. + apply_residual_connection_post_layernorm : bool, default = `False` + if set to `True`, residual connections are taken + from the output of layer norm (default is taken + from input of layer norm) + output_layernorm: bool, default = `False` + if set to `True`, layer normalization is applied on the output side, + after the final dropout-add. default behavior is to apply layer + normalization on the input side, before the QKV transformation. + layer_type: {'encoder', 'decoder'}, default = `encoder` + if set to `decoder`, an additional cross-attn block is added after self-attn. + This can be used for structures like `T5` Transformer in conjunction with the + `encoder` option. + zero_centered_gamma : bool, default = 'False' + if set to 'True', gamma parameter in LayerNorm is initialized to 0 and + the LayerNorm formula changes to + + .. math:: + y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \varepsilon}} * + (1 + \gamma) + \beta + activation : str, default = 'gelu' + Type of activation used in MLP block. + Options are: 'gelu', 'relu', 'reglu', 'geglu' and 'swiglu'. + + params_dtype : paddle.dtype, default = `paddle.get_default_dtype()` + it controls the type used to allocate the initial parameters. Useful when + the model is trained with lower precision and the original FP32 parameters + would not fit in GPU memory. 
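+
+    Example (a minimal encoder-layer sketch; the sizes, dtype and mask below are
+    illustrative assumptions modeled on the unit tests, not requirements of the API):
+
+    .. code-block:: python
+
+        paddle.set_default_dtype('float16')    # fused attention kernels run in fp16/bf16
+        layer = TransformerLayer(hidden_size=1024,
+                                 ffn_hidden_size=4096,
+                                 num_attention_heads=16,
+                                 self_attn_mask_type='padding')
+        inp = paddle.uniform(shape=(2, 128, 1024), dtype='float16')   # [b, s, hidden]
+        mask = paddle.zeros(shape=[2, 1, 128, 128], dtype='bool')     # True = masked out
+        out = layer(inp, mask)                                        # [b, s, hidden]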
+ """ + + def __init__(self, + hidden_size: int, + ffn_hidden_size: int, + num_attention_heads: int, + layernorm_epsilon: float = 1e-5, + hidden_dropout: float = 0.1, + attention_dropout: float = 0.1, + weight_attr: Union[paddle.ParamAttr, None] = None, + bias_attr: Union[paddle.ParamAttr, None, bool] = None, + self_attn_mask_type: str = "causal", + params_dtype: Optional[paddle.dtype] = None, + apply_residual_connection_post_layernorm: bool = False, + output_layernorm: bool = False, + layer_type: str = "encoder", + zero_centered_gamma: bool = False, + activation: str = 'gelu', + backend: str = 'transformer_engine') -> None: + super().__init__() + + params_dtype = paddle.get_default_dtype() if params_dtype is None else params_dtype + self.output_layernorm = output_layernorm + self.layer_type = layer_type + self.apply_residual_connection_post_layernorm = apply_residual_connection_post_layernorm + self.self_attn_mask_type = self_attn_mask_type + + assert (self_attn_mask_type + in AttnMaskTypes), f"self_attn_mask_type {self_attn_mask_type} not supported" + assert layer_type in LayerTypes, f"layer_type {layer_type} not supported" + + attention_args = ( + hidden_size, + num_attention_heads, + attention_dropout, + layernorm_epsilon, + weight_attr, + bias_attr, + ) + common_attention_kwargs = { + "params_dtype": params_dtype, + "return_layernorm_output": apply_residual_connection_post_layernorm, + "zero_centered_gamma": zero_centered_gamma, + "backend": backend, + } + + self.self_attention = MultiHeadAttention( + *attention_args, + **common_attention_kwargs, + attn_mask_type=self_attn_mask_type, + input_layernorm=not output_layernorm, + attention_type="self", + ) + + if layer_type == "decoder": + self.inter_attention = MultiHeadAttention( + *attention_args, + **common_attention_kwargs, + attn_mask_type="padding", + input_layernorm=True, + attention_type="cross", + ) + + self.layernorm_mlp = LayerNormMLP( + hidden_size, + ffn_hidden_size, + eps=layernorm_epsilon, + weight_attr=weight_attr, + bias_attr=bias_attr, + activation=activation, + return_layernorm_output=apply_residual_connection_post_layernorm, + zero_centered_gamma=zero_centered_gamma, + backend=backend, + ) + + self.hidden_dropout = hidden_dropout + + if self.output_layernorm: + self.layernorm = LayerNorm( + hidden_size, + layernorm_epsilon, + weight_attr, + bias_attr, + zero_centered_gamma=zero_centered_gamma, + backend=backend, + ) + + def forward( + self, + hidden_states: paddle.Tensor, + attention_mask: Optional[paddle.Tensor] = None, + encoder_output: Optional[paddle.Tensor] = None, + enc_dec_attn_mask: Optional[paddle.Tensor] = None, + core_attention_bias_type: str = "no_bias", + core_attention_bias: Optional[paddle.Tensor] = None, + set_zero: bool = True, + ) -> paddle.Tensor: + """ + Transformer Layer: attention block and a feedforward network (MLP) + + .. note:: + + Argument :attr:`attention_mask` will be ignored when :attr:`self_attn_mask_type` + is set to `"causal"`. + + Parameters + ---------- + hidden_states : paddle.Tensor + Input tensor. + attention_mask : Optional[paddle.Tensor], default = `None` + Boolean tensor used to mask out self-attention softmax input. + encoder_output : Optional[paddle.Tensor], default = `None` + Output of the encoder block to be fed into the decoder block if using + `layer_type="decoder"`. + enc_dec_attn_mask : Optional[paddle.Tensor], default = `None` + Boolean tensor used to mask out inter-attention softmax input if using + `layer_type="decoder"`. 
+ core_attention_bias_type: str, default = `no_bias` + core_attention_bias: Optional[paddle.Tensor], default = `None` + Bias tensor for Q * K.T + set_zero: bool, default = `True` + Whether to set output tensors to 0 or not before use. + """ + + if self.self_attn_mask_type != "causal" and attention_mask is not None: + assert (attention_mask.dtype == paddle.bool), "Attention mask must be a boolean tensor" + + assert core_attention_bias_type in ['no_bias'], f"Only no_bias is supported currently, " \ + f"but receive core_attention_bias_type = {core_attention_bias_type}" + + # Self attention. + self_attention_outputs = self.self_attention( + hidden_states, + attention_mask, + core_attention_bias_type=core_attention_bias_type, + core_attention_bias=core_attention_bias, + set_zero=set_zero, + ) + + if self.apply_residual_connection_post_layernorm and not self.output_layernorm: + attention_output, residual = self_attention_outputs + else: + attention_output = self_attention_outputs + residual = hidden_states + + # dropoout add. + out = paddle.nn.functional.dropout( + attention_output, + p=self.hidden_dropout, + training=True, + ) + bda_output = residual + out + + # Cross attention. + if self.layer_type == "decoder": + inter_attention_outputs = self.inter_attention( + bda_output, + enc_dec_attn_mask, + encoder_output=encoder_output, + core_attention_bias_type=core_attention_bias_type, + core_attention_bias=core_attention_bias, + set_zero=set_zero, + ) + if self.apply_residual_connection_post_layernorm: + attention_output, residual = inter_attention_outputs + else: + attention_output = inter_attention_outputs + residual = bda_output + + out = paddle.nn.functional.dropout( + attention_output, + p=self.hidden_dropout, + training=True, + ) + bda_output = residual + out + + # MLP. + mlp_outputs = self.layernorm_mlp(bda_output) + if self.apply_residual_connection_post_layernorm: + mlp_output, residual = mlp_outputs + else: + mlp_output = mlp_outputs + residual = bda_output + + # dropoout add. + out = paddle.nn.functional.dropout(mlp_output, p=self.hidden_dropout, training=True) + output = residual + out + + # For BERT like architectures. + if self.output_layernorm: + output = self.layernorm(output) + + # output: [b, s, hidden] + return output diff --git a/transformer_engine/paddle/utils.py b/transformer_engine/paddle/utils.py index 8bc1152a6f..9ade785d6e 100644 --- a/transformer_engine/paddle/utils.py +++ b/transformer_engine/paddle/utils.py @@ -52,3 +52,37 @@ def get_paddle_act_func(activation): if activation not in funcs: raise "Activation type " + activation + " is not supported." 
return funcs[activation] + + +def attention_mask_func(attention_scores: paddle.Tensor, + attention_mask: paddle.Tensor) -> paddle.Tensor: + """Get attention mask""" + + def _masked_fill(x, mask, value): + y = paddle.full(x.shape, value, x.dtype) + return paddle.where(mask, y, x) + + attention_scores = _masked_fill(attention_scores, attention_mask, -10000.0) + return attention_scores + + +def mask_to_cu_seqlens(mask: paddle.Tensor, need_kv: bool = False) -> paddle.Tensor: + """Convert mask to cu_seqlens""" + assert 'bool' in str(mask.dtype), "mask must be bool dtype" + assert len(mask.shape) == 4 and mask.shape[1] == 1, "mask must be [b, 1, s_q, s_kv]" + q_actual_seqlens = paddle.sum(mask[:, :, :, 0] == False, axis=(-1, -2), dtype='int32') # pylint: disable=singleton-comparison + q_cu_seqlens = paddle.cumsum(q_actual_seqlens) + q_cu_seqlens = paddle.concat([paddle.zeros([1], dtype=paddle.int32), q_cu_seqlens], axis=0) + if not need_kv: + return q_cu_seqlens, None + kv_actual_seqlens = paddle.sum(mask[:, :, 0, :] == False, axis=(-1, -2), dtype='int32') # pylint: disable=singleton-comparison + kv_cu_seqlens = paddle.cumsum(kv_actual_seqlens) + kv_cu_seqlens = paddle.concat([paddle.zeros([1], dtype=paddle.int32), kv_cu_seqlens], axis=0) + return q_cu_seqlens, kv_cu_seqlens + + +def divide(numerator: int, denominator: int) -> int: + """Ensure that numerator is divisible by the denominator and return + the division value.""" + assert (numerator % denominator == 0), f"{numerator} is not divisible by {denominator}" + return numerator // denominator From d661d06c38ddaa6859b161fde5f00491e7184b04 Mon Sep 17 00:00:00 2001 From: Sudhakar Singh Date: Fri, 18 Aug 2023 00:48:04 -0700 Subject: [PATCH 47/68] fix for amax_and_scale_update when reduce_amax=False (#386) Signed-off-by: Sudhakar Singh --- transformer_engine/pytorch/module/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 56ee70d8c9..0352a7ba2b 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -78,7 +78,7 @@ def _prepare_backward( # Update amax and scale; Skip all setup for global amax reduction if not fp8_meta["recipe"].reduce_amax: - FP8GlobalStateManager.amax_and_scale_update(fp8_meta, False) + amax_and_scale_update(fp8_meta, False) else: # From previous iteration FP8GlobalStateManager.copy_amax_from_global_buffer(fp8_meta, forward=False) From 8cdd80df74f7bcfff7db041b306f378205782845 Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Sat, 19 Aug 2023 01:04:24 -0700 Subject: [PATCH 48/68] PyTorch MultiheadAttention API (#387) * PyTorch MultiheadAttention API Signed-off-by: Kirthi Shankar Sivamani * Fix ONNX export tests Signed-off-by: Kirthi Shankar Sivamani * Expose MultiheadAttention for import Signed-off-by: Kirthi Shankar Sivamani * Expand mask type and add no mask numerical test Signed-off-by: Kirthi Shankar Sivamani --------- Signed-off-by: Kirthi Shankar Sivamani --- docs/api/pytorch.rst | 3 + tests/pytorch/test_numerics.py | 87 ++++++++- tests/pytorch/test_onnx_export.py | 3 +- transformer_engine/pytorch/__init__.py | 1 + transformer_engine/pytorch/attention.py | 213 ++++++++++++++++++++-- transformer_engine/pytorch/transformer.py | 8 +- 6 files changed, 288 insertions(+), 27 deletions(-) diff --git a/docs/api/pytorch.rst b/docs/api/pytorch.rst index 22a571279b..af71e1a2a7 100644 --- a/docs/api/pytorch.rst +++ b/docs/api/pytorch.rst @@ -22,6 +22,9 @@ pyTorch 
.. autoapiclass:: transformer_engine.pytorch.DotProductAttention(num_attention_heads, kv_channels, **kwargs) :members: forward +.. autoapiclass:: transformer_engine.pytorch.MultiheadAttention(hidden_size, num_attention_heads, **kwargs) + :members: forward + .. autoapiclass:: transformer_engine.pytorch.TransformerLayer(hidden_size, ffn_hidden_size, num_attention_heads, **kwargs) :members: forward diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index 6260c291c4..f8eda48cc3 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -21,7 +21,8 @@ attention_mask_func, ) from transformer_engine.pytorch import ( - DotProductAttention, Linear, LayerNormLinear, LayerNormMLP, TransformerLayer, RMSNorm + DotProductAttention, LayerNormLinear, LayerNormMLP, Linear, + MultiheadAttention, RMSNorm, TransformerLayer ) from transformer_engine.pytorch.distributed import checkpoint as te_checkpoint @@ -60,6 +61,9 @@ def __init__(self, hidden_size, eps, num_attention_heads, embed, num_layers, seq all_normalizations = ["LayerNorm", "RMSNorm"] +mask_types = ["causal", "no_mask"] + + def get_causal_attn_mask(sq: int) -> torch.Tensor: return torch.triu(torch.ones(sq, sq, device="cuda"), diagonal=1).bool() @@ -320,6 +324,7 @@ def forward( return context_layer + # Adapted from https://github.com/bzhangGo/rmsnorm/blob/c6691f20ec0af4128c8159c903071f7575404295/rmsnorm_torch.py class TorchRMSNorm(nn.Module): def __init__(self, in_features, eps=1e-5): @@ -341,6 +346,7 @@ def forward(self, x): return (self.weight.float() * x_normed).to(x.dtype) + class TorchLayerNormLinear(nn.Module): def __init__(self, in_features: int, out_features: int, eps: float, bias: bool = True, @@ -371,7 +377,11 @@ def __init__(self, hidden_size: int, num_attention_heads: int): ) def forward(self, x, attn_mask=None): - return self.mhsa(x, x, x, attn_mask=attn_mask, need_weights=False) + output = self.mhsa(x, x, x, attn_mask=attn_mask, need_weights=False) + if isinstance(output, tuple): + output = output[0] + return output + _supported_act = {'geglu' : nn.GELU(approximate="tanh"), 'gelu' : nn.GELU(approximate="tanh"), @@ -379,6 +389,7 @@ def forward(self, x, attn_mask=None): 'relu' : nn.ReLU(), 'swiglu' : nn.SiLU()} + class TorchGLU(nn.Module): def __init__(self, activation: str): super().__init__() @@ -391,6 +402,7 @@ def forward(self, x): a = self.act(a) return a * b + class TorchLayerNormMLP(nn.Module): def __init__(self, hidden_size: int, ffn_hidden_size: int, eps: float = 1e-5, activation = 'gelu', @@ -431,7 +443,7 @@ def forward( attn_mask: Optional[torch.Tensor] = None, ) -> torch.Tensor: a = self.ln(x) - b, _ = self.causal_attn(a, attn_mask) + b = self.causal_attn(a, attn_mask) x = x + self.resid_attn_dropout(b) n = self.ln_mlp(x) x = x + self.resid_mlp_dropout(n) @@ -754,6 +766,75 @@ def test_gpt_accuracy(dtype, bs, model): assert_allclose(te_outputs[0], torch_outputs[0], 5e-2) +def _test_mha_accuracy(block, bs, dtype, config, mask_type): + reset_rng_states() + + inp_hidden_states = torch.randn( + config.seq_len, bs, config.hidden_size, dtype=dtype, requires_grad=True + ).cuda() + inp_hidden_states.retain_grad() + inp_attn_mask = get_causal_attn_mask(config.seq_len) if mask_type == "causal" else None + + out = block(inp_hidden_states, inp_attn_mask) + loss = out.sum() + loss.backward() + + torch.cuda.synchronize() + outputs = [out, inp_hidden_states.grad] + for p in block.parameters(): + if p.requires_grad: + outputs.append(p.grad) + return outputs + + 
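The tuple check added to `TorchMHA.forward` above mirrors the behaviour of `torch.nn.MultiheadAttention`: its forward returns an `(output, weights)` pair even when attention weights are not requested, with `weights` simply set to `None`. A small standalone illustration (sizes are arbitrary):

    import torch

    mha = torch.nn.MultiheadAttention(embed_dim=64, num_heads=4, batch_first=False)
    x = torch.randn(8, 2, 64)                        # [seq, batch, hidden]
    out, weights = mha(x, x, x, need_weights=False)
    print(out.shape)   # torch.Size([8, 2, 64])
    print(weights)     # None when need_weights=False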
+@pytest.mark.parametrize("dtype", param_types) +@pytest.mark.parametrize("bs", batch_sizes) +@pytest.mark.parametrize("model", model_configs.keys()) +@pytest.mark.parametrize("mask_type", mask_types) +def test_mha_accuracy(dtype, bs, model, mask_type): + config = model_configs[model] + + te_mha = ( + MultiheadAttention( + config.hidden_size, + config.num_attention_heads, + fuse_qkv_params=True, + qkv_weight_interleaved=False, + input_layernorm=False, + attn_mask_type=mask_type, + ) + .to(dtype=dtype) + .cuda() + .eval() + ) + + torch_mha = ( + TorchMHA( + config.hidden_size, + config.num_attention_heads, + ) + .to(dtype=dtype) + .cuda() + .eval() + ) + + # Share params + with torch.no_grad(): + torch_mha.mhsa.in_proj_weight = Parameter(te_mha.qkv.weight.clone()) + torch_mha.mhsa.in_proj_bias = Parameter(te_mha.qkv.bias.clone()) + torch_mha.mhsa.out_proj.weight = Parameter(te_mha.proj.weight.clone()) + torch_mha.mhsa.out_proj.bias = Parameter(te_mha.proj.bias.clone()) + + te_outputs = _test_mha_accuracy(te_mha, bs, dtype, config, mask_type) + torch_outputs = _test_mha_accuracy(torch_mha, bs, dtype, config, mask_type) + + # Check output. + if dtype == torch.float32: + assert_allclose(te_outputs[0], torch_outputs[0], 5e-3) + else: + assert_allclose(te_outputs[0], torch_outputs[0], 5e-2) + + def _test_granular_accuracy(block, bs, dtype, config): reset_rng_states() diff --git a/tests/pytorch/test_onnx_export.py b/tests/pytorch/test_onnx_export.py index 65b2f39684..1e1fafcac5 100644 --- a/tests/pytorch/test_onnx_export.py +++ b/tests/pytorch/test_onnx_export.py @@ -1267,7 +1267,7 @@ def test_export_multihead_attention( input_ln_str = "_input-ln" if input_layernorm else "" fname = f"te.multihead_attention{fp8_str}{attn_mask_str}{attn_type_str}{input_ln_str}{fuse_qkv_str}{dtype_str}.onnx" - model = te.attention.MultiHeadAttention( + model = te.MultiheadAttention( *attention_args, attn_mask_type=attn_mask_type, params_dtype=precision, @@ -1275,6 +1275,7 @@ def test_export_multihead_attention( input_layernorm=input_layernorm, attention_type=attention_type, fuse_qkv_params=fuse_qkv_params, + return_bias=True, ).to(device='cuda') inp_context = (hidden_states_context, attention_mask, encoder_output) diff --git a/transformer_engine/pytorch/__init__.py b/transformer_engine/pytorch/__init__.py index b67ecd05b9..92a07e1242 100644 --- a/transformer_engine/pytorch/__init__.py +++ b/transformer_engine/pytorch/__init__.py @@ -9,6 +9,7 @@ from .module import LayerNorm from .module import RMSNorm from .attention import DotProductAttention +from .attention import MultiheadAttention from .transformer import TransformerLayer from .fp8 import fp8_autocast from .export import onnx_export diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 79f4b71c4e..6842a9bc60 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -30,6 +30,7 @@ attention_mask_func, split_tensor_along_dim, get_device_compute_capability, + get_default_init_method, ) from transformer_engine.pytorch.constants import ( AttnMaskTypes, @@ -56,7 +57,7 @@ from flash_attn.flash_attn_interface import flash_attn_unpadded_func as flash_attn_forward_func # pylint: disable=no-name-in-module,ungrouped-imports -__all__ = ["DotProductAttention"] +__all__ = ["DotProductAttention", "MultiheadAttention"] def _rotate_half(x: torch.Tensor) -> torch.Tensor: @@ -1181,20 +1182,132 @@ def forward( ) -class MultiHeadAttention(torch.nn.Module): - """Parallel attention w/o QKV and 
Proj Gemms - BMM1 -> softmax + dropout -> BMM2 +class MultiheadAttention(torch.nn.Module): + r""" + Multi-head Attention (MHA), including Query, + Key, Value and Output projection. + + .. note:: + + Argument :attr:`attention_mask` will be ignored in the `forward` call when + :attr:`self_attn_mask_type` is set to `"causal"`. + + Parameters + ---------- + hidden_size : int + size of each input sample. + num_attention_heads : int + number of attention heads in the transformer layer. + kv_channels: int, default = `None` + number of key-value channels. defaults to + :attr:`hidden_size` / :attr:`num_attention_heads` if `None`. + attention_dropout: float, default = 0.1 + dropout probability for the dropout op during multi-head attention. + layernorm_epsilon : float, default = 1e-5 + a value added to the denominator of layer normalization + for numerical stability. + init_method : Callable, default = `None` + used for initializing weights of QKV and FC1 weights in the following way: + `init_method(weight)`. When set to `None`, defaults to + `torch.nn.init.normal_(mean=0.0, std=0.023)`. + output_layer_init_method : Callable, default = `None` + used for initializing weights of PROJ and FC2 in the following way: + `output_layer_init_method(weight)`. When set to `None`, defaults to + `torch.nn.init.normal_(mean=0.0, std=0.023)`. + layer_number: int, default = `None` + layer number of the current `TransformerLayer` when multiple such modules are + concatenated to form a transformer block. + attn_mask_type: {'causal', 'padding', 'no_mask'}, default = `causal` + type of attention mask passed into softmax operation. + num_gqa_groups : int, default = `None` + number of GQA groups in the transformer layer. + Grouped Query Attention is described in + `this paper `_. + This only affects the keys and values, not the querys. + GQA-1 is equivalent to Multi-Query Attention + (`MQA `_), while GQA-H + is equivalent to MHA, i.e. `num_gqa_groups = num_attention_heads`. + return_layernorm_output : bool, default = `False` + if set to `True`, output of layernorm is returned from the forward + together with the output of the linear transformation. + Example use case: residual connection for transformer module is + taken post layernorm. + input_layernorm: bool, default = `True` + if set to `False`, layer normalization to the input is not applied. + attention_type: { 'self', 'cross' }, default = 'self' + type of attention applied. + zero_centered_gamma : bool, default = 'False' + if set to 'True', gamma parameter in LayerNorm is initialized to 0 and + the LayerNorm formula changes to + + .. math:: + y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \varepsilon}} * + (1 + \gamma) + \beta + normalization : { 'LayerNorm', 'RMSNorm' }, default = 'LayerNorm' + type of normalization applied. + qkv_weight_interleaved : bool, default = `True` + if set to `False`, the QKV weight is interpreted as a concatenation of + query, key, and value weights along the `0th` dimension. The default + interpretation is that the individual `q`, `k`, and `v` weights for each + attention head are interleaved. This parameter is set to `False` when + using :attr:`fuse_qkv_params=False`. + bias : bool, default = `True` + if set to `False`, the transformer layer will not learn any additive biases. + device : Union[torch.device, str], default = "cuda" + The device on which the parameters of the model will allocated. It is the user's + responsibility to ensure all parameters are moved to the GPU before running the + forward pass. 
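The `num_gqa_groups` option documented above only changes how many key/value heads are allocated; the query head count stays fixed. A quick numeric sketch of the grouping for a hypothetical 16-head layer:

    num_attention_heads = 16

    for num_gqa_groups in (1, 4, 16):
        queries_per_kv_head = num_attention_heads // num_gqa_groups
        print(num_gqa_groups, queries_per_kv_head)

    # 1  -> all 16 query heads share one KV head (MQA)
    # 4  -> groups of 4 query heads share a KV head (GQA)
    # 16 -> one KV head per query head, i.e. standard MHA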
+ + Parallelism parameters + ---------------------- + set_parallel_mode : bool, default = `False` + if set to `True`, QKV and FC1 layers are used as Column Parallel + whereas PROJ and FC2 is used as Row Parallel as described + `here `_. + sequence_parallel : bool, default = `False` + if set to `True`, uses sequence parallelism. + tp_group : ProcessGroup, default = `None` + tensor parallel process group. + tp_size : int, default = 1 + used as TP (tensor parallel) world size when TP groups are not formed during + initialization. In this case, users must call the + `set_tensor_parallel_group(tp_group)` method on the initialized module before the + forward pass to supply the tensor parallel group needed for tensor and sequence + parallel collectives. + + Optimization parameters + ----------------------- + fuse_wgrad_accumulation : bool, default = 'False' + if set to `True`, enables fusing of creation and accumulation of + the weight gradient. When enabled, it is assumed that the weights + have an additional `main_grad` attribute (used instead of the + regular `grad`) which is a pre-allocated buffer of the correct + size to accumulate gradients in. + params_dtype : torch.dtype, default = `torch.get_default_dtype()` + it controls the type used to allocate the initial parameters. Useful when + the model is trained with lower precision and the original FP32 parameters + would not fit in GPU memory. + return_bias : bool, default = `False` + when set to `True`, this module will not apply the additive bias itself, but + instead return the bias value during the forward pass together with the + output of the linear transformation :math:`y = xA^T`. This is useful when + the bias addition can be fused to subsequent operations. + fuse_qkv_params: bool, default = 'False' + if set to `True`, `TransformerLayer` module exposes a single fused + parameter for query-key-value. This enables optimizations such as QKV + fusion without concatentations/splits and also enables the argument + `fuse_wgrad_accumulation`. 
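Putting the parameters above together, here is a minimal usage sketch of the new module in the same style as the numerics tests added by this patch (hidden size, head count and dtype are illustrative only):

    import torch
    import transformer_engine.pytorch as te

    layer = te.MultiheadAttention(
        hidden_size=1024,
        num_attention_heads=16,
        fuse_qkv_params=True,
    ).to(dtype=torch.bfloat16).cuda()

    x = torch.randn(128, 4, 1024, dtype=torch.bfloat16, device="cuda")  # [seq, batch, hidden]
    out = layer(x)       # causal self-attention by default; no attention_mask needed
    print(out.shape)     # torch.Size([128, 4, 1024])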
""" def __init__( self, hidden_size: int, num_attention_heads: int, - kv_channels: int, - attention_dropout: float, - layernorm_epsilon: float, - init_method: Callable, - output_layer_init_method: Callable, + kv_channels: Optional[int] = None, + attention_dropout: float = 0.1, + layernorm_epsilon: float = 1e-5, + init_method: Optional[Callable] = None, + output_layer_init_method: Optional[Callable] = None, layer_number: Optional[int] = None, attn_mask_type: str = "causal", tp_group: Optional[dist_group_type] = None, @@ -1204,6 +1317,7 @@ def __init__( get_rng_state_tracker: Optional[Callable] = None, sequence_parallel: bool = False, params_dtype: Optional[torch.dtype] = None, + return_bias: bool = False, return_layernorm_output: bool = False, input_layernorm: bool = False, attention_type: str = "self", @@ -1227,9 +1341,16 @@ def __init__( self.tp_group = tp_group self.return_layernorm_output = return_layernorm_output self.params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype - self.init_method = init_method self.attn_mask_type = attn_mask_type self.num_attention_heads = num_attention_heads + self.return_bias = return_bias + + kv_channels = kv_channels if kv_channels else (hidden_size // num_attention_heads) + + if init_method is None: + init_method = get_default_init_method() + if output_layer_init_method is None: + output_layer_init_method = get_default_init_method() if not fuse_qkv_params: qkv_weight_interleaved = False @@ -1358,7 +1479,7 @@ def __init__( hidden_size, init_method=output_layer_init_method, bias=bias, - return_bias=True, + return_bias=return_bias, parallel_mode="row" if set_parallel_mode else None, ub_split_rs=ub_split_rs, ub_split_ag=ub_split_ag, @@ -1395,10 +1516,54 @@ def forward( core_attention_bias: Optional[torch.Tensor] = None, fast_zero_fill: bool = True, ) -> Tuple[Union[torch.Tensor, None], ...]: - """MultiHeadAttention FWD""" + """ + Forward propagation for MultiheadAttention layer. + + .. note:: + + Argument :attr:`attention_mask` will be ignored when :attr:`self_attn_mask_type` + is set to `"causal"`. + + Parameters + ---------- + hidden_states : torch.Tensor + Input tensor. + attention_mask : Optional[torch.Tensor], default = `None` + Boolean tensor used to mask out self-attention softmax input. + encoder_output : Optional[torch.Tensor], default = `None` + Output of the encoder block to be fed into the decoder block if using + `layer_type="decoder"`. + is_first_microbatch : {True, False, None}, default = None + During training using either gradient accumulation or + pipeline parallelism a minibatch of data is further split + into microbatches. Between the microbatches of the same minibatch + the model weights are not updated. Setting this parameter indicates + whether the current microbatch is the first in a minibatch or not. + When set, this parameter enables additional optimizations: + + * during FP8 training, it allows caching of the FP8 versions of + the weights + * it also allows skipping gradient accumulation during the + first microbatch (since it is the first gradient being + produced) + checkpoint_core_attention: bool, default = `False` + If true, forward activations for core attention are recomputed + during the backward pass in order to save memory that would + otherwise be occupied to store the forward activations until + backprop. + rotary_pos_emb: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], default = `None` + Embeddings for query and key tensors for applying rotary position + embedding. 
By default no input embedding is applied. + core_attention_bias_type: str, default = `no_bias` + Bias type, {`no_bias`, `pre_scale_bias`, 'post_scale_bias`} + core_attention_bias: Optional[torch.Tensor], default = `None` + Bias tensor for Q * K.T + fast_zero_fill: bool, default = `True` + Whether to set output tensors to 0 or not before use. + """ # hidden_states: [sq, b, h] - if self.attn_mask_type != "causal" and attention_mask is not None: + if self.attn_mask_type == "padding" and attention_mask is not None: assert ( attention_mask.dtype == torch.bool ), "Attention mask must be a boolean tensor" @@ -1604,20 +1769,28 @@ def forward( key_layer, value_layer, attention_mask, - checkpoint_core_attention = checkpoint_core_attention, - core_attention_bias_type = core_attention_bias_type, - core_attention_bias = core_attention_bias, - fast_zero_fill = fast_zero_fill, + checkpoint_core_attention=checkpoint_core_attention, + core_attention_bias_type=core_attention_bias_type, + core_attention_bias=core_attention_bias, + fast_zero_fill=fast_zero_fill, ) # ================= # Output. [sq, b, h] # ================= - attention_output, attention_bias = self.proj( + projection_output = self.proj( context_layer, is_first_microbatch=is_first_microbatch ) + if self.return_bias: + attention_output, attention_bias = projection_output + else: + attention_output, attention_bias = projection_output, None + + outputs = (attention_output,) + if self.return_bias: + outputs += (attention_bias,) if self.input_layernorm and self.return_layernorm_output: - return attention_output, attention_bias, layernorm_output - return attention_output, attention_bias + outputs += (layernorm_output,) + return outputs if len(outputs) > 1 else outputs[0] diff --git a/transformer_engine/pytorch/transformer.py b/transformer_engine/pytorch/transformer.py index f27784d135..de93cd652f 100644 --- a/transformer_engine/pytorch/transformer.py +++ b/transformer_engine/pytorch/transformer.py @@ -12,7 +12,7 @@ import transformer_engine_extensions as tex from transformer_engine.pytorch.module import LayerNormMLP, LayerNorm, RMSNorm -from transformer_engine.pytorch.attention import MultiHeadAttention +from transformer_engine.pytorch.attention import MultiheadAttention from transformer_engine.pytorch.jit import ( set_jit_fusion_options, warmup_jit_bias_dropout_add_all_dtypes, @@ -323,25 +323,27 @@ def __init__( "ub_split_rs" : ub_split_rs, } - self.self_attention = MultiHeadAttention( + self.self_attention = MultiheadAttention( *attention_args, **common_attention_kwargs, attn_mask_type=self_attn_mask_type, input_layernorm=not output_layernorm, attention_type="self", bias=bias, + return_bias=True, normalization=normalization, device=device, ) if layer_type == "decoder": - self.inter_attention = MultiHeadAttention( + self.inter_attention = MultiheadAttention( *attention_args, **common_attention_kwargs, attn_mask_type="padding", input_layernorm=True, attention_type="cross", bias=bias, + return_bias=True, normalization=normalization, device=device, ) From 5b16352a5eb6bcb6e506fef5c0d8319a1c73400a Mon Sep 17 00:00:00 2001 From: cyanguwa <8636796+cyanguwa@users.noreply.github.com> Date: Fri, 25 Aug 2023 15:35:26 -0700 Subject: [PATCH 49/68] Fix rng_state issue and minor compiler warning (#395) fix rng_state issue and minor compiler warning Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- tests/pytorch/test_fused_attn.py | 6 ++---- .../common/transpose/transpose_fusion.cu | 2 -- .../pytorch/csrc/extensions/attention.cu | 16 
++++++++++++++-- 3 files changed, 16 insertions(+), 8 deletions(-) diff --git a/tests/pytorch/test_fused_attn.py b/tests/pytorch/test_fused_attn.py index f516b70b0e..3c8a10e9e9 100644 --- a/tests/pytorch/test_fused_attn.py +++ b/tests/pytorch/test_fused_attn.py @@ -181,9 +181,6 @@ def _run_transformer_layer(dtype, bs, config, backend, ckpt_attn, bias_type): seqlens.fill_(config.seq_len) cu_seqlens = torch.zeros(bs + 1, device = inp.device, dtype = torch.int32) cu_seqlens[1:] = torch.cumsum(seqlens, dim = 0) - op_grad = torch.randn( - config.seq_len, bs, config.num_attention_heads * config.head_dim, - dtype = dtype).cuda() sigma = 0.02 init_method = init_method_normal(sigma) @@ -241,7 +238,8 @@ def _run_transformer_layer(dtype, bs, config, backend, ckpt_attn, bias_type): checkpoint_core_attention = ckpt_attn, core_attention_bias_type = bias_type, core_attention_bias = bias) - op.backward(op_grad) + loss = op.sum() + loss.backward() return op, inp.grad diff --git a/transformer_engine/common/transpose/transpose_fusion.cu b/transformer_engine/common/transpose/transpose_fusion.cu index ba89c4abd2..8561a6881b 100644 --- a/transformer_engine/common/transpose/transpose_fusion.cu +++ b/transformer_engine/common/transpose/transpose_fusion.cu @@ -293,8 +293,6 @@ transpose_dbias_kernel_notaligned(const Param param, } } OVec out_trans[nvec_in]; // NOLINT(*) - const bool valid_store = my_place < tile_length && - warp_id_in_tile * n_iterations + i < tile_height; transpose_regs_partial_dbias( in[current_in ^ 1], out_trans, diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cu b/transformer_engine/pytorch/csrc/extensions/attention.cu index 4904fbade5..423b16013f 100644 --- a/transformer_engine/pytorch/csrc/extensions/attention.cu +++ b/transformer_engine/pytorch/csrc/extensions/attention.cu @@ -194,7 +194,13 @@ std::vector fused_attn_fwd_qkvpacked( for (size_t i = 0; i < nvte_aux_tensor_pack.size; ++i) { auto tensor = reinterpret_cast(nvte_aux_tensor_pack.tensors[i]); // allocate memory for nvte_aux_tensor_pack.tensors - auto output_tensor = allocateSpace(tensor->data.shape, tensor->data.dtype, false); + at::Tensor output_tensor; + if (nvte_aux_tensor_pack.size >= 2) { + output_tensor = (i < nvte_aux_tensor_pack.size-1) + ? allocateSpace(tensor->data.shape, tensor->data.dtype, false) : rng_state; + } else { + output_tensor = allocateSpace(tensor->data.shape, tensor->data.dtype, false); + } output_tensors.push_back(output_tensor); tensor->data.dptr = output_tensor.data_ptr(); } @@ -497,7 +503,13 @@ std::vector fused_attn_fwd_kvpacked( for (size_t i = 0; i < nvte_aux_tensor_pack.size; ++i) { auto tensor = reinterpret_cast(nvte_aux_tensor_pack.tensors[i]); // allocate memory for nvte_aux_tensor_pack.tensors - auto output_tensor = allocateSpace(tensor->data.shape, tensor->data.dtype, false); + at::Tensor output_tensor; + if (nvte_aux_tensor_pack.size >= 2) { + output_tensor = (i < nvte_aux_tensor_pack.size-1) + ? 
allocateSpace(tensor->data.shape, tensor->data.dtype, false) : rng_state; + } else { + output_tensor = allocateSpace(tensor->data.shape, tensor->data.dtype, false); + } output_tensors.push_back(output_tensor); tensor->data.dptr = output_tensor.data_ptr(); } From e6db29d15bdfeaefea091372a1b43a8a59d0f51d Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Fri, 25 Aug 2023 19:21:17 -0700 Subject: [PATCH 50/68] [PyTorch] move mask types to fprop (#402) * API change and some test fixes Signed-off-by: Kirthi Shankar Sivamani * more test fixes Signed-off-by: Kirthi Shankar Sivamani * ONNX fixes Signed-off-by: Kirthi Shankar Sivamani * fix Signed-off-by: Kirthi Shankar Sivamani * Fixed fused attention tests Signed-off-by: Kirthi Shankar Sivamani * rm duplicate test Signed-off-by: Kirthi Shankar Sivamani --------- Signed-off-by: Kirthi Shankar Sivamani --- tests/pytorch/test_fused_attn.py | 252 +++++++++++----------- tests/pytorch/test_numerics.py | 24 ++- tests/pytorch/test_onnx_export.py | 29 +-- tests/pytorch/test_sanity.py | 10 +- transformer_engine/pytorch/attention.py | 145 ++++++++----- transformer_engine/pytorch/softmax.py | 5 +- transformer_engine/pytorch/transformer.py | 48 +++-- 7 files changed, 287 insertions(+), 226 deletions(-) diff --git a/tests/pytorch/test_fused_attn.py b/tests/pytorch/test_fused_attn.py index 3c8a10e9e9..32442e40fb 100644 --- a/tests/pytorch/test_fused_attn.py +++ b/tests/pytorch/test_fused_attn.py @@ -77,10 +77,10 @@ def test_dot_product_attention(dtype, bs, model, ckpt_attn, bias_type): atol, rtol = (2.5e-2, 2.5e-2) if dtype == torch.bfloat16 else (5e-3, 5e-3) if bias_type == "no_bias": - assert torch.allclose(fused_attn_fwd, flash_attn_fwd, atol = atol, rtol = rtol) - assert torch.allclose(fused_attn_bwd, flash_attn_bwd, atol = atol, rtol = rtol) - assert torch.allclose(fused_attn_fwd, unfused_attn_fwd, atol = atol, rtol = rtol) - assert torch.allclose(fused_attn_bwd, unfused_attn_bwd, atol = atol, rtol = rtol) + assert torch.allclose(fused_attn_fwd, flash_attn_fwd, atol=atol, rtol=rtol) + assert torch.allclose(fused_attn_bwd, flash_attn_bwd, atol=atol, rtol=rtol) + assert torch.allclose(fused_attn_fwd, unfused_attn_fwd, atol=atol, rtol=rtol) + assert torch.allclose(fused_attn_bwd, unfused_attn_bwd, atol=atol, rtol=rtol) def _run_dot_product_attention(dtype, bs, config, backend, ckpt_attn, bias_type): @@ -94,18 +94,18 @@ def _run_dot_product_attention(dtype, bs, config, backend, ckpt_attn, bias_type) inp = torch.randn( config.seq_len, bs, 3, config.num_attention_heads, config.head_dim, - dtype = dtype).cuda() + dtype=dtype).cuda() inp.requires_grad=True - seqlens = torch.empty(bs, dtype = torch.int32).cuda() + seqlens = torch.empty(bs, dtype=torch.int32).cuda() seqlens.fill_(config.seq_len) - cu_seqlens = torch.zeros(bs + 1, device = inp.device, dtype = torch.int32) - cu_seqlens[1:] = torch.cumsum(seqlens, dim = 0) + cu_seqlens = torch.zeros(bs + 1, device=inp.device, dtype=torch.int32) + cu_seqlens[1:] = torch.cumsum(seqlens, dim=0) op_grad = torch.randn( config.seq_len, bs, config.num_attention_heads * config.head_dim, dtype = dtype).cuda() if bias_type != "no_bias": bias = torch.randn(1, config.num_attention_heads, config.seq_len, config.seq_len, - dtype = dtype).cuda() + dtype=dtype).cuda() else: bias = None @@ -113,24 +113,23 @@ def _run_dot_product_attention(dtype, bs, config, backend, ckpt_attn, bias_type) DotProductAttention( config.num_attention_heads, config.head_dim, - attention_dropout = config.dropout_p, - attn_mask_type = 
config.attn_mask_type, - sequence_parallel = False, - tp_size = 1, - get_rng_state_tracker = get_dummy_cuda_rng_tracker, - tp_group = None, - layer_number = 1, - attention_type = "self" - ).to(dtype = dtype).cuda() + attention_dropout=config.dropout_p, + sequence_parallel=False, + tp_size=1, + get_rng_state_tracker=get_dummy_cuda_rng_tracker, + tp_group=None, + layer_number=1, + attention_type="self" + ).to(dtype=dtype).cuda() ) q = inp[:, :,0,:,:] k = inp[:, :,1,:,:] v = inp[:, :,2,:,:] - op = block(q, k, v, - checkpoint_core_attention = ckpt_attn, - core_attention_bias_type = bias_type, - core_attention_bias = bias) + op = block(q, k, v, attn_mask_type=config.attn_mask_type, + checkpoint_core_attention=ckpt_attn, + core_attention_bias_type=bias_type, + core_attention_bias=bias) op.backward(op_grad) return op, inp.grad @@ -158,10 +157,10 @@ def test_transformer_layer(dtype, bs, model, ckpt_attn, bias_type): atol, rtol = (5e-1, 5e-2) if bias_type == "no_bias": - assert torch.allclose(fused_attn_fwd, flash_attn_fwd, atol = atol, rtol = rtol) - assert torch.allclose(fused_attn_bwd, flash_attn_bwd, atol = atol, rtol = rtol) - assert torch.allclose(fused_attn_fwd, unfused_attn_fwd, atol = atol, rtol = rtol) - assert torch.allclose(fused_attn_bwd, unfused_attn_bwd, atol = atol, rtol = rtol) + assert torch.allclose(fused_attn_fwd, flash_attn_fwd, atol=atol, rtol=rtol) + assert torch.allclose(fused_attn_bwd, flash_attn_bwd, atol=atol, rtol=rtol) + assert torch.allclose(fused_attn_fwd, unfused_attn_fwd, atol=atol, rtol=rtol) + assert torch.allclose(fused_attn_bwd, unfused_attn_bwd, atol=atol, rtol=rtol) def _run_transformer_layer(dtype, bs, config, backend, ckpt_attn, bias_type): @@ -175,12 +174,12 @@ def _run_transformer_layer(dtype, bs, config, backend, ckpt_attn, bias_type): inp = torch.randn( config.seq_len, bs, config.num_attention_heads * config.head_dim, - dtype = dtype).cuda() + dtype=dtype).cuda() inp.requires_grad=True - seqlens = torch.empty(bs, dtype = torch.int32).cuda() + seqlens = torch.empty(bs, dtype=torch.int32).cuda() seqlens.fill_(config.seq_len) - cu_seqlens = torch.zeros(bs + 1, device = inp.device, dtype = torch.int32) - cu_seqlens[1:] = torch.cumsum(seqlens, dim = 0) + cu_seqlens = torch.zeros(bs + 1, device=inp.device, dtype=torch.int32) + cu_seqlens[1:] = torch.cumsum(seqlens, dim=0) sigma = 0.02 init_method = init_method_normal(sigma) @@ -192,7 +191,7 @@ def _run_transformer_layer(dtype, bs, config, backend, ckpt_attn, bias_type): rate.item() for rate in torch.linspace(0, drop_path_rate, config.num_layers)] if bias_type != "no_bias": bias = torch.randn(1, config.num_attention_heads, config.seq_len, config.seq_len, - dtype = dtype).cuda() + dtype=dtype).cuda() else: bias = None @@ -201,43 +200,42 @@ def _run_transformer_layer(dtype, bs, config, backend, ckpt_attn, bias_type): config.hidden_size, 4 * config.hidden_size, config.num_attention_heads, - layernorm_epsilon = 1e-5, - hidden_dropout = 0.0, - attention_dropout = config.dropout_p, - init_method = init_method, - output_layer_init_method = output_layer_init_method, - layer_number = layer_number, - kv_channels = config.head_dim, - self_attn_mask_type = config.attn_mask_type, - tp_group = None, - tp_size = 1, - params_dtype = dtype, - get_rng_state_tracker = None, - fuse_wgrad_accumulation = False, - seq_length = config.seq_len, - micro_batch_size = bs, - sequence_parallel = False, - apply_residual_connection_post_layernorm = False, - output_layernorm = False, - layer_type = "encoder", - drop_path_rate = 
drop_path_rates[layer_number - 1], - set_parallel_mode = True, - fuse_qkv_params = True, - zero_centered_gamma = False, - qkv_weight_interleaved = False, - ub_tp_comm_overlap = False, - bias = True, + layernorm_epsilon=1e-5, + hidden_dropout=0.0, + attention_dropout=config.dropout_p, + init_method=init_method, + output_layer_init_method=output_layer_init_method, + layer_number=layer_number, + kv_channels=config.head_dim, + tp_group=None, + tp_size=1, + params_dtype=dtype, + get_rng_state_tracker=None, + fuse_wgrad_accumulation=False, + seq_length=config.seq_len, + micro_batch_size=bs, + sequence_parallel=False, + apply_residual_connection_post_layernorm=False, + output_layernorm=False, + layer_type="encoder", + drop_path_rate=drop_path_rates[layer_number - 1], + set_parallel_mode=True, + fuse_qkv_params=True, + zero_centered_gamma=False, + qkv_weight_interleaved=False, + ub_tp_comm_overlap=False, + bias=True, ) - .to(dtype = dtype) + .to(dtype=dtype) .cuda() ) num_iters = 10 for i in range(num_iters): - op = block(inp, - checkpoint_core_attention = ckpt_attn, - core_attention_bias_type = bias_type, - core_attention_bias = bias) + op = block(inp, self_attn_mask_type=config.attn_mask_type, + checkpoint_core_attention=ckpt_attn, + core_attention_bias_type=bias_type, + core_attention_bias=bias) loss = op.sum() loss.backward() @@ -270,8 +268,8 @@ def find_factors(x): dtype, bs, config, "UnfusedDotProductAttention", num_q_per_gqa_group) atol, rtol = 5e-1, 5e-2 - assert torch.allclose(flash_attn_fwd, unfused_attn_fwd, atol = atol, rtol = rtol) - assert torch.allclose(flash_attn_bwd, unfused_attn_bwd, atol = atol, rtol = rtol) + assert torch.allclose(flash_attn_fwd, unfused_attn_fwd, atol=atol, rtol=rtol) + assert torch.allclose(flash_attn_bwd, unfused_attn_bwd, atol=atol, rtol=rtol) def _run_transformer_layer_gqa(dtype, bs, config, backend, num_querys_per_gqa_group): @@ -282,15 +280,15 @@ def _run_transformer_layer_gqa(dtype, bs, config, backend, num_querys_per_gqa_gr inp = torch.randn( config.seq_len, bs, config.num_attention_heads * config.head_dim, - dtype = dtype).cuda() + dtype=dtype).cuda() inp.requires_grad=True - seqlens = torch.empty(bs, dtype = torch.int32).cuda() + seqlens = torch.empty(bs, dtype=torch.int32).cuda() seqlens.fill_(config.seq_len) - cu_seqlens = torch.zeros(bs + 1, device = inp.device, dtype = torch.int32) - cu_seqlens[1:] = torch.cumsum(seqlens, dim = 0) + cu_seqlens = torch.zeros(bs + 1, device=inp.device, dtype=torch.int32) + cu_seqlens[1:] = torch.cumsum(seqlens, dim=0) op_grad = torch.randn( config.seq_len, bs, config.num_attention_heads * config.head_dim, - dtype = dtype).cuda() + dtype=dtype).cuda() sigma = 0.02 init_method = init_method_normal(sigma) @@ -306,39 +304,38 @@ def _run_transformer_layer_gqa(dtype, bs, config, backend, num_querys_per_gqa_gr config.hidden_size, 4 * config.hidden_size, config.num_attention_heads, - num_gqa_groups = config.num_attention_heads / num_querys_per_gqa_group, - layernorm_epsilon = 1e-5, - hidden_dropout = 0.0, - attention_dropout = config.dropout_p, - init_method = init_method, - output_layer_init_method = output_layer_init_method, - layer_number = layer_number, - kv_channels = config.head_dim, - self_attn_mask_type = config.attn_mask_type, - tp_group = None, - tp_size = 1, - params_dtype = dtype, - get_rng_state_tracker = None, - fuse_wgrad_accumulation = False, - seq_length = config.seq_len, - micro_batch_size = bs, - sequence_parallel = False, - apply_residual_connection_post_layernorm = False, - output_layernorm = False, - 
layer_type = "encoder", - drop_path_rate = drop_path_rates[layer_number - 1], - set_parallel_mode = True, - fuse_qkv_params = True, - zero_centered_gamma = False, - qkv_weight_interleaved = False, - ub_tp_comm_overlap = False, - bias = True, + num_gqa_groups=config.num_attention_heads / num_querys_per_gqa_group, + layernorm_epsilon=1e-5, + hidden_dropout=0.0, + attention_dropout=config.dropout_p, + init_method=init_method, + output_layer_init_method=output_layer_init_method, + layer_number=layer_number, + kv_channels=config.head_dim, + tp_group=None, + tp_size= 1, + params_dtype=dtype, + get_rng_state_tracker=None, + fuse_wgrad_accumulation=False, + seq_length=config.seq_len, + micro_batch_size=bs, + sequence_parallel=False, + apply_residual_connection_post_layernorm=False, + output_layernorm=False, + layer_type="encoder", + drop_path_rate=drop_path_rates[layer_number - 1], + set_parallel_mode=True, + fuse_qkv_params=True, + zero_centered_gamma=False, + qkv_weight_interleaved=False, + ub_tp_comm_overlap=False, + bias=True, ) - .to(dtype = dtype) + .to(dtype=dtype) .cuda() ) - op = block(inp) + op = block(inp, self_attn_mask_type=config.attn_mask_type) op.backward(op_grad) return op, inp.grad @@ -365,8 +362,8 @@ def test_dpa_fp8(dtype, bs, model): dtype, bs, config, "UnfusedDotProductAttention") atol, rtol = (2.5e-2, 2.5e-2) - assert torch.allclose(fused_attn_fwd, unfused_attn_fwd, atol = atol, rtol = rtol) - assert torch.allclose(fused_attn_bwd, unfused_attn_bwd, atol = atol, rtol = rtol) + assert torch.allclose(fused_attn_fwd, unfused_attn_fwd, atol=atol, rtol=rtol) + assert torch.allclose(fused_attn_bwd, unfused_attn_bwd, atol=atol, rtol=rtol) def _run_dpa_fp8(dtype, bs, config, backend): @@ -376,15 +373,15 @@ def _run_dpa_fp8(dtype, bs, config, backend): inp = 0.01 * torch.randn( bs * config.seq_len, config.num_attention_heads * config.head_dim, - dtype = dtype).cuda() + dtype=dtype).cuda() inp.requires_grad=True - seqlens = torch.empty(bs, dtype = torch.int32).cuda() + seqlens = torch.empty(bs, dtype=torch.int32).cuda() seqlens.fill_(config.seq_len) - cu_seqlens = torch.zeros(bs + 1, device = inp.device, dtype = torch.int32) - cu_seqlens[1:] = torch.cumsum(seqlens, dim = 0) + cu_seqlens = torch.zeros(bs + 1, device=inp.device, dtype=torch.int32) + cu_seqlens[1:] = torch.cumsum(seqlens, dim=0) op_grad = 0.01 * torch.randn( bs * config.seq_len, config.num_attention_heads * config.head_dim, - dtype = dtype).cuda() + dtype=dtype).cuda() torch.save(op_grad, 'op_grad.pt') fp8_recipe = recipe.DelayedScaling( @@ -395,7 +392,7 @@ def _run_dpa_fp8(dtype, bs, config, backend): amax_compute_algo="most_recent", ) - dpa = DPA_FP8(config).to(dtype = torch.float16).cuda() + dpa = DPA_FP8(config).to(dtype=torch.float16).cuda() with fp8_autocast(enabled=True, fp8_recipe=fp8_recipe): op = dpa(inp, cu_seqlens, config.seq_len) op.backward(op_grad) @@ -416,31 +413,30 @@ def _run_dpa_fp8_ref(dtype, bs, config, backend): inp = torch.load('qkv.pt').cuda() inp.requires_grad=True - seqlens = torch.empty(bs, dtype = torch.int32).cuda() + seqlens = torch.empty(bs, dtype=torch.int32).cuda() seqlens.fill_(config.seq_len) - cu_seqlens = torch.zeros(bs + 1, device = inp.device, dtype = torch.int32) - cu_seqlens[1:] = torch.cumsum(seqlens, dim = 0) + cu_seqlens = torch.zeros(bs + 1, device=inp.device, dtype=torch.int32) + cu_seqlens[1:] = torch.cumsum(seqlens, dim=0) op_grad = torch.load('op_grad.pt').cuda().view(bs, config.seq_len, -1).transpose(0,1) block = ( DotProductAttention( config.num_attention_heads, 
config.head_dim, - attention_dropout = config.dropout_p, - attn_mask_type = config.attn_mask_type, - sequence_parallel = False, - tp_size = 1, - get_rng_state_tracker = None, - tp_group = None, - layer_number = 1, - attention_type = "self" - ).to(dtype = dtype).cuda() + attention_dropout=config.dropout_p, + sequence_parallel=False, + tp_size=1, + get_rng_state_tracker=None, + tp_group=None, + layer_number=1, + attention_type="self" + ).to(dtype=dtype).cuda() ) q = inp[:, :,0,:,:] k = inp[:, :,1,:,:] v = inp[:, :,2,:,:] - op = block(q, k, v) + op = block(q, k, v, attn_mask_type=config.attn_mask_type) op.backward(op_grad) torch.save(op,'ctx_ref.pt') torch.save(inp.grad,'dqkv_ref.pt') @@ -533,8 +529,8 @@ def forward( workspace, bias=qkv_bias, use_bias=True, - out_index = META_QKV, - fp8_meta_tensor = fp8_meta["scaling_fwd"], + out_index=META_QKV, + fp8_meta_tensor=fp8_meta["scaling_fwd"], use_split_accumulator=_2X_ACC_FPROP, D_dtype=fp8_dtype_forward, ) @@ -558,13 +554,13 @@ def forward( fp8_meta["scaling_fwd"].scale[META_O], fp8_meta["scaling_fwd"].amax_history[0][META_S], fp8_meta["scaling_fwd"].amax_history[0][META_O], - attn_scale = None, - dropout = p_dropout, - fast_zero_fill = fast_zero_fill, - qkv_layout = "qkv_interleaved", - attn_bias_type = "no_bias", - attn_mask_type = "padding", - rng_gen = None, + attn_scale=None, + dropout=p_dropout, + fast_zero_fill=fast_zero_fill, + qkv_layout="qkv_interleaved", + attn_bias_type="no_bias", + attn_mask_type="padding", + rng_gen=None, ) M, ZInv, philox_unpacked = aux_ctx_tensors diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index f8eda48cc3..bf9f7502fd 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -376,8 +376,8 @@ def __init__(self, hidden_size: int, num_attention_heads: int): batch_first=False, ) - def forward(self, x, attn_mask=None): - output = self.mhsa(x, x, x, attn_mask=attn_mask, need_weights=False) + def forward(self, x, attention_mask=None): + output = self.mhsa(x, x, x, attn_mask=attention_mask, need_weights=False) if isinstance(output, tuple): output = output[0] return output @@ -461,7 +461,7 @@ def _test_e2e_selective_recompute(block, bs, dtype, config, recompute=False): te_out = block( te_inp_hidden_states, - te_inp_attn_mask, + attention_mask=te_inp_attn_mask, checkpoint_core_attention=recompute, ) loss = te_out.sum() @@ -526,13 +526,13 @@ def _test_e2e_full_recompute(block, bs, dtype, config, recompute=False): get_dummy_cuda_rng_tracker, None, # tp_group te_inp_hidden_states, - te_inp_attn_mask, + attention_mask=te_inp_attn_mask, checkpoint_core_attention=False, ) else: te_out = block( te_inp_hidden_states, - te_inp_attn_mask, + attention_mask=te_inp_attn_mask, checkpoint_core_attention=False, ) loss = te_out.sum() @@ -766,7 +766,7 @@ def test_gpt_accuracy(dtype, bs, model): assert_allclose(te_outputs[0], torch_outputs[0], 5e-2) -def _test_mha_accuracy(block, bs, dtype, config, mask_type): +def _test_mha_accuracy(block, bs, dtype, config, mask_type, te=True): reset_rng_states() inp_hidden_states = torch.randn( @@ -775,7 +775,12 @@ def _test_mha_accuracy(block, bs, dtype, config, mask_type): inp_hidden_states.retain_grad() inp_attn_mask = get_causal_attn_mask(config.seq_len) if mask_type == "causal" else None - out = block(inp_hidden_states, inp_attn_mask) + forward_kwargs = {} + if te: + forward_kwargs["attn_mask_type"] = mask_type + forward_kwargs["attention_mask"] = inp_attn_mask + + out = block(inp_hidden_states, **forward_kwargs) loss = out.sum() 
loss.backward() @@ -801,7 +806,6 @@ def test_mha_accuracy(dtype, bs, model, mask_type): fuse_qkv_params=True, qkv_weight_interleaved=False, input_layernorm=False, - attn_mask_type=mask_type, ) .to(dtype=dtype) .cuda() @@ -825,8 +829,8 @@ def test_mha_accuracy(dtype, bs, model, mask_type): torch_mha.mhsa.out_proj.weight = Parameter(te_mha.proj.weight.clone()) torch_mha.mhsa.out_proj.bias = Parameter(te_mha.proj.bias.clone()) - te_outputs = _test_mha_accuracy(te_mha, bs, dtype, config, mask_type) - torch_outputs = _test_mha_accuracy(torch_mha, bs, dtype, config, mask_type) + te_outputs = _test_mha_accuracy(te_mha, bs, dtype, config, mask_type, te=True) + torch_outputs = _test_mha_accuracy(torch_mha, bs, dtype, config, mask_type, te=False) # Check output. if dtype == torch.float32: diff --git a/tests/pytorch/test_onnx_export.py b/tests/pytorch/test_onnx_export.py index 1e1fafcac5..14640febde 100644 --- a/tests/pytorch/test_onnx_export.py +++ b/tests/pytorch/test_onnx_export.py @@ -783,7 +783,6 @@ def __init__(self, softmax_fn, fake_bf16_io, mask_inp=False): self.fake_bf16_io = fake_bf16_io if self.softmax_fn == te.softmax.FusedScaleMaskSoftmax: self.fused_scaled_softmax = te.softmax.FusedScaleMaskSoftmax( - attn_mask_type="causal", mask_func=te.utils.attention_mask_func, softmax_in_fp32=True, ) @@ -793,7 +792,7 @@ def forward(self, inp, mask): inp = inp.type(torch.bfloat16) if self.fused_scaled_softmax: - ret = self.fused_scaled_softmax(inp, mask, self.scale) + ret = self.fused_scaled_softmax(inp, mask, "causal", self.scale) else: if self.mask_inp: ret = self.softmax_fn.apply(inp, mask, self.scale) @@ -867,7 +866,6 @@ def __init__(self, use_default_te_mask_fn: bool, fake_bf16_io: bool): # even when is_in_onnx_export_mode()==False. os.environ["NVTE_MASKED_SOFTMAX_FUSION"] = "0" self.fused_scaled_softmax = te.softmax.FusedScaleMaskSoftmax( - attn_mask_type="causal", mask_func=te.utils.attention_mask_func, softmax_in_fp32=True, ) @@ -875,7 +873,7 @@ def __init__(self, use_default_te_mask_fn: bool, fake_bf16_io: bool): def forward(self, inp, mask): if self.fake_bf16_io: inp = inp.type(torch.bfloat16) - ret = self.fused_scaled_softmax(inp, mask, self.scale) + ret = self.fused_scaled_softmax(inp, mask, "causal", scale=self.scale) if self.fake_bf16_io: ret = ret.type(torch.float) return ret @@ -1161,13 +1159,13 @@ def test_export_core_attention( query_layer = torch.randn(qkv_size, dtype=precision, device="cuda") key_layer = torch.randn(qkv_size, dtype=precision, device="cuda") value_layer = torch.randn(qkv_size, dtype=precision, device="cuda") - input_names = ["query", "key", "value", "attention_mask"] + input_names = ["query", "key", "value", "attention_mask", "attn_mask_type"] attention_mask = None if use_mask: # Generate a random mask with 50% probability for 0 or 1. 
probs = 0.5 * torch.ones(qkv_size[1], qkv_size[2], qkv_size[0], qkv_size[0], device="cuda", dtype=precision) attention_mask = torch.bernoulli(probs).to("cuda", dtype=torch.bool) - inp = (query_layer, key_layer, value_layer, attention_mask) + inp = (query_layer, key_layer, value_layer, attention_mask, attn_mask_type) mask_str = get_attn_mask_str(use_mask, attn_mask_type) high_prec_str = dtype2str(precision) @@ -1177,7 +1175,6 @@ def test_export_core_attention( num_attention_heads=num_attention_heads, kv_channels=kv_channels, attention_dropout=0.5, - attn_mask_type=attn_mask_type, ).to(device='cuda') do_export(model, inp, @@ -1193,9 +1190,8 @@ def test_export_core_attention( test_configs_multihead_attention = [ #"use_mask, attn_mask_type" - (False, "causal"), # calls ScaledUpperTriangMaskedSoftmax + (False, "no_mask"), # calls ScaledUpperTriangMaskedSoftmax (True, "padding"), # calls ScaledMaskedSoftmax - (False, "padding"), # calls ScaledSoftmax ] test_configs_attention_type = [ #"input_layernorm, attention_type, fuse_qkv_params" @@ -1269,7 +1265,6 @@ def test_export_multihead_attention( model = te.MultiheadAttention( *attention_args, - attn_mask_type=attn_mask_type, params_dtype=precision, return_layernorm_output=return_layernorm_output, input_layernorm=input_layernorm, @@ -1278,8 +1273,8 @@ def test_export_multihead_attention( return_bias=True, ).to(device='cuda') - inp_context = (hidden_states_context, attention_mask, encoder_output) - input_names = ["hidden_states", "attention_mask", "encoder_output"] + inp_context = (hidden_states_context, attention_mask, encoder_output, attn_mask_type) + input_names = ["hidden_states", "attention_mask", "encoder_output", "attn_mask_type"] output_names=["attention_output", "attention_bias"] do_export(model, inp_context, fname, use_fp8, input_names=input_names, output_names=output_names, dynamic_axes={"hidden_states": {0: "seq", 1:"bs"}, @@ -1347,13 +1342,13 @@ def test_export_transformer_layer( num_attention_heads = 4 input_tensor = torch.rand(sequence_length, batch_size, hidden_size, dtype=precision, device="cuda") - input_names = ["input", "attention_mask"] + input_names = ["input", "attention_mask", "self_attn_mask_type"] attention_mask = None if use_mask and attn_mask_type != "causal": # Generate a random mask with 50% probability for 0 or 1. 
probs = 0.5 * torch.ones(batch_size, 1, sequence_length, sequence_length, device="cuda", dtype=precision) attention_mask = torch.bernoulli(probs).to("cuda", dtype=torch.bool) - inp = (input_tensor, attention_mask) + inp = (input_tensor, attention_mask, attn_mask_type) fp8_str = "_fp8" if use_fp8 else "" fuse_qkv_params_str = "_fused-qkv" if fuse_qkv_params else "" @@ -1365,7 +1360,6 @@ def test_export_transformer_layer( hidden_size, ffn_hidden_size, num_attention_heads, - self_attn_mask_type=attn_mask_type, output_layernorm=output_layernorm, params_dtype=precision, fuse_qkv_params=fuse_qkv_params, @@ -1547,17 +1541,16 @@ def test_export_gpt_generation( hidden_size, ffn_hidden_size, num_attention_heads, - self_attn_mask_type=attn_mask_type, output_layernorm=output_layernorm, params_dtype=precision, fuse_qkv_params=fuse_qkv_params, zero_centered_gamma=zero_centered_gamma).to(device='cuda') # "Context phase": use full input sequence length - input_names = ["input"] + input_names = ["input", "attention_mask", "self_attn_mask_type"] output_names = ["output"] input_tensor = torch.rand(sequence_length, batch_size, hidden_size, dtype=precision, device="cuda") - inp = (input_tensor,) + inp = (input_tensor, None, attn_mask_type) do_export(model, inp, fname, use_fp8, input_names=input_names, output_names=output_names, dynamic_axes={"input": {0: "seq", 1:"bs"}, diff --git a/tests/pytorch/test_sanity.py b/tests/pytorch/test_sanity.py index 2605c563d6..21497b417f 100644 --- a/tests/pytorch/test_sanity.py +++ b/tests/pytorch/test_sanity.py @@ -176,7 +176,7 @@ def _test_sanity_e2e_amp(block, bs, dtype, config, fp8_recipe, skip_wgrad): use_fp8 = fp8_recipe is not None with torch.autocast(device_type="cuda", enabled=True, dtype=dtype): with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe): - te_out = block(te_inp_hidden_states, te_inp_attn_mask) + te_out = block(te_inp_hidden_states, attention_mask=te_inp_attn_mask) loss = te_out.sum() loss.backward() @@ -217,7 +217,7 @@ def _test_sanity_e2e_gradient_accumulation_fusion(block, bs, dtype, config, fp8_ use_fp8 = fp8_recipe is not None with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe): - te_out = block(te_inp_hidden_states, te_inp_attn_mask) + te_out = block(te_inp_hidden_states, attention_mask=te_inp_attn_mask) loss = te_out.sum() loss.backward() torch.cuda.synchronize() @@ -253,7 +253,7 @@ def _test_sanity_e2e(block, bs, dtype, config, fp8_recipe, skip_wgrad): use_fp8 = fp8_recipe is not None with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe): - te_out = block(te_inp_hidden_states, te_inp_attn_mask) + te_out = block(te_inp_hidden_states, attention_mask=te_inp_attn_mask) loss = te_out.sum() loss.backward() torch.cuda.synchronize() @@ -282,7 +282,9 @@ def _test_sanity_e2e_T5(block, bs, dtype, config, fp8_recipe, skip_wgrad): use_fp8 = fp8_recipe is not None with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe): te_out = block( - te_inp_hidden_states, te_inp_attn_mask, encoder_output=te_inp_hidden_states + te_inp_hidden_states, + attention_mask=te_inp_attn_mask, + encoder_output=te_inp_hidden_states ) loss = te_out.sum() loss.backward() diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 6842a9bc60..a30f20d3a8 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -196,23 +196,15 @@ def __init__( norm_factor: float, attention_dropout: float = 0.0, attention_dropout_ctx: Optional[Callable] = nullcontext, - attn_mask_type: str = "causal", 
layer_number: Optional[int] = None, ) -> None: super().__init__() - assert ( - attn_mask_type in AttnMaskTypes - ), f"attn_mask_type {attn_mask_type} not supported" - self.norm_factor = norm_factor self.attention_dropout_ctx = attention_dropout_ctx self.layer_number = layer_number - self.scale_mask_softmax = FusedScaleMaskSoftmax( - attn_mask_type, - attention_mask_func, - ) + self.scale_mask_softmax = FusedScaleMaskSoftmax(attention_mask_func) # Dropout. Note that for a single iteration, this layer will generate # different outputs on different number of parallel partitions but @@ -228,11 +220,17 @@ def forward( query_layer: torch.Tensor, key_layer: torch.Tensor, value_layer: torch.Tensor, + attn_mask_type: str = "causal", attention_mask: Optional[torch.Tensor] = None, core_attention_bias_type: str = "no_bias", core_attention_bias: Optional[torch.Tensor] = None, ) -> torch.Tensor: """core attention fprop""" + + assert ( + attn_mask_type in AttnMaskTypes + ), f"attn_mask_type {attn_mask_type} not supported" + batch_size, seqlen = query_layer.shape[1], query_layer.shape[0] apply_qk_layer_scaling = self.apply_qk_layer_scaling and key_layer.dtype == torch.float16 @@ -321,7 +319,8 @@ def forward( # attention scores and attention mask [b, np, sq, sk] softmax_scale = self.layer_number if apply_qk_layer_scaling else None - attention_probs = self.scale_mask_softmax(attention_scores, attention_mask, softmax_scale) + attention_probs = self.scale_mask_softmax( + attention_scores, attention_mask, attn_mask_type, softmax_scale) # This is actually dropping out entire tokens to attend to, which might # seem a bit unusual, but is taken from the original Transformer paper. @@ -464,7 +463,6 @@ def __init__( norm_factor: float, attention_dropout: float = 0.0, attention_dropout_ctx: Optional[Callable] = nullcontext, - attn_mask_type: str = "causal", deterministic: bool = False, ) -> None: super().__init__() @@ -473,7 +471,6 @@ def __init__( _flash_attn_version >= _flash_attn_version_required ), f"FlashAttention minimum version {_flash_attn_version_required} is required." 
- self.attn_causal_mask = attn_mask_type == "causal" self.norm_factor = norm_factor self.attention_dropout_ctx = attention_dropout_ctx self.attention_dropout = attention_dropout @@ -484,6 +481,7 @@ def forward( query_layer: torch.Tensor, key_layer: torch.Tensor, value_layer: torch.Tensor, + attn_mask_type: str = "causal", ) -> torch.Tensor: """flash-attn fprop""" @@ -531,7 +529,7 @@ def forward( output = flash_attn_forward_func( query_layer, key_layer, value_layer, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen, self.attention_dropout if self.training else 0.0, - softmax_scale=1.0/self.norm_factor, causal=self.attn_causal_mask, + softmax_scale=1.0/self.norm_factor, causal=attn_mask_type=="causal", **fa_optional_forward_kwargs ) @@ -703,7 +701,6 @@ def __init__( norm_factor: float, attention_dropout: float = 0.0, attention_dropout_ctx: Optional[Callable] = nullcontext, - attn_mask_type: str = "causal", attention_type: str = "self", ) -> None: super().__init__() @@ -711,7 +708,6 @@ def __init__( self.norm_factor = norm_factor self.attention_dropout = attention_dropout self.attention_dropout_ctx = attention_dropout_ctx - self.attn_mask_type = attn_mask_type self.attention_type = attention_type self.use_FAv2_bwd = (os.getenv("NVTE_FUSED_ATTN_USE_FAv2_BWD", "1") == "1" and _flash_attn_2_available @@ -722,6 +718,7 @@ def forward( query_layer: torch.Tensor, key_layer: torch.Tensor, value_layer: torch.Tensor, + attn_mask_type: str = "causal", fused_attention_backend: tex.NVTE_Fused_Attn_Backend = tex.NVTE_Fused_Attn_Backend.NVTE_No_Backend, core_attention_bias_type: str = "no_bias", @@ -797,7 +794,7 @@ def forward( fast_zero_fill, qkv_layout, core_attention_bias_type, - self.attn_mask_type, + attn_mask_type, None, # rng_gen fused_attention_backend, use_FAv2_bwd @@ -858,7 +855,7 @@ def forward( fast_zero_fill, qkv_layout, core_attention_bias_type, - self.attn_mask_type, + attn_mask_type, None, # rng_gen fused_attention_backend, use_FAv2_bwd @@ -886,6 +883,11 @@ class DotProductAttention(torch.nn.Module): and set the environment variable :attr:`NVTE_ALLOW_NONDETERMINISTIC_ALGO=0`. In order to disable`flash-attn` entirely, set :attr:`NVTE_FLASH_ATTN=0`. + .. warning:: + + Argument :attr:`attn_mask_type` has been moved to the `forward` method and + is deprecated. It will be fully removed in future releases. + Parameters ---------- num_attention_heads : int @@ -902,8 +904,6 @@ class DotProductAttention(torch.nn.Module): is equivalent to MHA, i.e. `num_gqa_groups = num_attention_heads`. attention_dropout: float, default = 0.0 dropout probability for the dropout op during multi-head attention. - attn_mask_type: {'causal', 'padding', 'no_mask'}, default = `causal` - type of attention mask passed into softmax operation. layer_number: int, default = `None` layer number of the current `DotProductAttention` when multiple such modules are concatenated, for instance in consecutive transformer blocks. @@ -924,7 +924,7 @@ def __init__( kv_channels: int, num_gqa_groups: Optional[int] = None, attention_dropout: float = 0.0, - attn_mask_type: str = "causal", + attn_mask_type: Optional[str] = None, sequence_parallel: bool = False, tp_size: int = 1, get_rng_state_tracker: Optional[Callable] = None, @@ -934,6 +934,14 @@ def __init__( ) -> None: super().__init__() + if attn_mask_type is not None: + warnings.warn( + "Argument :attr:`attn_mask_type` has been moved to the `forward` method and" + "is deprecated. 
It will be fully removed in future releases.", + category=DeprecationWarning, + ) + + self.attn_mask_type = attn_mask_type self.tp_size = tp_size if tp_group is None else get_distributed_world_size(tp_group) self.tp_group = tp_group self.get_rng_state_tracker = get_rng_state_tracker @@ -978,10 +986,8 @@ def __init__( attn_kwargs = { "attention_dropout": attention_dropout, "attention_dropout_ctx": attention_dropout_ctx, - "attn_mask_type": attn_mask_type, } self.attention_type = attention_type - self.attn_mask_type = attn_mask_type self.attention_dropout = attention_dropout if self.use_flash_attention: @@ -1025,6 +1031,7 @@ def forward( key_layer: torch.Tensor, value_layer: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, + attn_mask_type: str = "causal", checkpoint_core_attention: bool = False, core_attention_bias_type: str = "no_bias", core_attention_bias: Optional[torch.Tensor] = None, @@ -1067,6 +1074,8 @@ def forward( Value tensor. attention_mask : Optional[torch.Tensor], default = `None` Boolean tensor used to mask out softmax input when not using flash-attn. + attn_mask_type: {'causal', 'padding', 'no_mask'}, default = `causal` + type of attention mask passed into softmax operation. checkpoint_core_attention : bool, default = `False` If true, forward activations for attention are recomputed during the backward pass in order to save memory that would @@ -1080,6 +1089,15 @@ def forward( Whether to use the fast path to set output tensors to 0 or not. """ + if self.attn_mask_type is not None: + warnings.warn( + "Argument :attr:`attn_mask_type` has been moved to the `forward` method and" + "is deprecated. It will be fully removed in future releases.", + category=DeprecationWarning, + ) + # Keep previous functionality for current users. + attn_mask_type = self.attn_mask_type + assert (key_layer.shape[-2] == self.num_gqa_groups_per_partition and value_layer.shape[-2] == self.num_gqa_groups_per_partition ), f"Keys and values must have {self.num_gqa_groups} heads!" 
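With this patch, the mask type for `DotProductAttention` is selected per forward call instead of at construction time; passing it to `__init__` still works but emits a `DeprecationWarning`. A minimal sketch of the new calling convention, with illustrative shapes:

    import torch
    import transformer_engine.pytorch as te

    dpa = te.DotProductAttention(num_attention_heads=16, kv_channels=64)

    seq, bs = 128, 2
    q = torch.randn(seq, bs, 16, 64, dtype=torch.bfloat16, device="cuda")
    k = torch.randn(seq, bs, 16, 64, dtype=torch.bfloat16, device="cuda")
    v = torch.randn(seq, bs, 16, 64, dtype=torch.bfloat16, device="cuda")

    out = dpa(q, k, v, attn_mask_type="causal")   # mask type is now a forward argument
    print(out.shape)                              # torch.Size([128, 2, 1024])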
@@ -1102,7 +1120,7 @@ def forward( if not _flash_attn_2_available and self.num_gqa_groups != self.num_attention_heads: use_flash_attention = False - if self.attn_mask_type == "padding" and attention_mask is not None: + if attn_mask_type == "padding" and attention_mask is not None: use_flash_attention = False use_fused_attention = False @@ -1121,7 +1139,7 @@ def forward( TE_DType[key_layer.dtype], QKVLayout[qkv_layout], AttnBiasType[core_attention_bias_type], - AttnMaskType[self.attn_mask_type], + AttnMaskType[attn_mask_type], self.attention_dropout, query_layer.shape[0], key_layer.shape[0], query_layer.shape[-1]) @@ -1144,8 +1162,10 @@ def forward( return self._checkpointed_attention_forward(self.flash_attention, query_layer, key_layer, - value_layer) - return self.flash_attention(query_layer, key_layer, value_layer) + value_layer, + attn_mask_type=attn_mask_type) + return self.flash_attention( + query_layer, key_layer, value_layer, attn_mask_type=attn_mask_type) if use_fused_attention: if checkpoint_core_attention: @@ -1153,15 +1173,17 @@ def forward( query_layer, key_layer, value_layer, - fused_attention_backend = fused_attention_backend, - core_attention_bias_type = core_attention_bias_type, - core_attention_bias = core_attention_bias, - fast_zero_fill = fast_zero_fill) + attn_mask_type=attn_mask_type, + fused_attention_backend=fused_attention_backend, + core_attention_bias_type=core_attention_bias_type, + core_attention_bias=core_attention_bias, + fast_zero_fill=fast_zero_fill) return self.fused_attention(query_layer, key_layer, value_layer, - fused_attention_backend = fused_attention_backend, - core_attention_bias_type = core_attention_bias_type, - core_attention_bias = core_attention_bias, - fast_zero_fill = fast_zero_fill) + attn_mask_type=attn_mask_type, + fused_attention_backend=fused_attention_backend, + core_attention_bias_type=core_attention_bias_type, + core_attention_bias=core_attention_bias, + fast_zero_fill=fast_zero_fill) if checkpoint_core_attention: return self._checkpointed_attention_forward( @@ -1169,16 +1191,18 @@ def forward( query_layer, key_layer, value_layer, - attention_mask = attention_mask, - core_attention_bias_type = core_attention_bias_type, - core_attention_bias = core_attention_bias, + attn_mask_type=attn_mask_type, + attention_mask=attention_mask, + core_attention_bias_type=core_attention_bias_type, + core_attention_bias=core_attention_bias, ) return self.unfused_attention(query_layer, key_layer, value_layer, - attention_mask = attention_mask, - core_attention_bias_type = core_attention_bias_type, - core_attention_bias = core_attention_bias, + attn_mask_type=attn_mask_type, + attention_mask=attention_mask, + core_attention_bias_type=core_attention_bias_type, + core_attention_bias=core_attention_bias, ) @@ -1190,7 +1214,12 @@ class MultiheadAttention(torch.nn.Module): .. note:: Argument :attr:`attention_mask` will be ignored in the `forward` call when - :attr:`self_attn_mask_type` is set to `"causal"`. + :attr:`attn_mask_type` is set to `"causal"`. + + .. warning:: + + Argument :attr:`attn_mask_type` has been moved to the `forward` method and + is deprecated. It will be fully removed in future releases. Parameters ---------- @@ -1217,8 +1246,6 @@ class MultiheadAttention(torch.nn.Module): layer_number: int, default = `None` layer number of the current `TransformerLayer` when multiple such modules are concatenated to form a transformer block. 
- attn_mask_type: {'causal', 'padding', 'no_mask'}, default = `causal` - type of attention mask passed into softmax operation. num_gqa_groups : int, default = `None` number of GQA groups in the transformer layer. Grouped Query Attention is described in @@ -1309,7 +1336,7 @@ def __init__( init_method: Optional[Callable] = None, output_layer_init_method: Optional[Callable] = None, layer_number: Optional[int] = None, - attn_mask_type: str = "causal", + attn_mask_type: Optional[str] = None, tp_group: Optional[dist_group_type] = None, tp_size: int = 1, num_gqa_groups: Optional[int] = None, @@ -1334,6 +1361,15 @@ def __init__( device: Union[torch.device, str] = "cuda", ) -> None: super().__init__() + + if attn_mask_type is not None: + warnings.warn( + "Argument :attr:`attn_mask_type` has been moved to the `forward` method and" + "is deprecated. It will be fully removed in future releases.", + category=DeprecationWarning, + ) + + self.attn_mask_type = attn_mask_type self.layer_number = layer_number self.input_layernorm = input_layernorm self.attention_type = attention_type @@ -1341,7 +1377,6 @@ def __init__( self.tp_group = tp_group self.return_layernorm_output = return_layernorm_output self.params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype - self.attn_mask_type = attn_mask_type self.num_attention_heads = num_attention_heads self.return_bias = return_bias @@ -1467,7 +1502,6 @@ def __init__( attention_dropout=attention_dropout, tp_size=tp_size, get_rng_state_tracker=get_rng_state_tracker, - attn_mask_type=attn_mask_type, sequence_parallel=sequence_parallel, tp_group=tp_group, layer_number=self.layer_number, @@ -1508,6 +1542,7 @@ def forward( hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, encoder_output: Optional[torch.Tensor] = None, + attn_mask_type: str = "causal", is_first_microbatch: Optional[bool] = None, checkpoint_core_attention: bool = False, inference_params: Optional[Any] = None, @@ -1521,7 +1556,7 @@ def forward( .. note:: - Argument :attr:`attention_mask` will be ignored when :attr:`self_attn_mask_type` + Argument :attr:`attention_mask` will be ignored when :attr:`attn_mask_type` is set to `"causal"`. Parameters @@ -1530,6 +1565,8 @@ def forward( Input tensor. attention_mask : Optional[torch.Tensor], default = `None` Boolean tensor used to mask out self-attention softmax input. + attn_mask_type: {'causal', 'padding', 'no_mask'}, default = `causal` + type of attention mask passed into softmax operation. encoder_output : Optional[torch.Tensor], default = `None` Output of the encoder block to be fed into the decoder block if using `layer_type="decoder"`. @@ -1563,7 +1600,16 @@ def forward( """ # hidden_states: [sq, b, h] - if self.attn_mask_type == "padding" and attention_mask is not None: + if self.attn_mask_type is not None: + warnings.warn( + "Argument :attr:`attn_mask_type` has been moved to the `forward` method and" + "is deprecated. It will be fully removed in future releases.", + category=DeprecationWarning, + ) + # Keep previous functionality for current users. 
+ attn_mask_type = self.attn_mask_type + + if attn_mask_type == "padding" and attention_mask is not None: assert ( attention_mask.dtype == torch.bool ), "Attention mask must be a boolean tensor" @@ -1768,7 +1814,8 @@ def forward( query_layer, key_layer, value_layer, - attention_mask, + attention_mask=attention_mask, + attn_mask_type=attn_mask_type, checkpoint_core_attention=checkpoint_core_attention, core_attention_bias_type=core_attention_bias_type, core_attention_bias=core_attention_bias, diff --git a/transformer_engine/pytorch/softmax.py b/transformer_engine/pytorch/softmax.py index b4166309d7..036ea98369 100644 --- a/transformer_engine/pytorch/softmax.py +++ b/transformer_engine/pytorch/softmax.py @@ -215,19 +215,16 @@ class FusedScaleMaskSoftmax(nn.Module): fused operation: scaling + mask + softmax Arguments: - attn_mask_type: attention mask type (pad or causal) mask_func: mask function to be applied. softmax_in_fp32: if true, softmax in performed at fp32 precision. """ def __init__( self, - attn_mask_type: str, mask_func: Callable, softmax_in_fp32: bool = True, ) -> None: super().__init__() - self.attn_mask_type = attn_mask_type self.scaled_masked_softmax_fusion = bool( int(os.getenv("NVTE_MASKED_SOFTMAX_FUSION", "1")) ) @@ -249,6 +246,7 @@ def forward( self, inp: torch.Tensor, mask: torch.Tensor, + attn_mask_type: str, scale: Optional[float] = None, ) -> torch.Tensor: """FusedScaleMaskSoftmax fprop""" @@ -257,6 +255,7 @@ def forward( self.input_in_fp16 = inp.dtype == torch.float16 self.input_in_bf16 = inp.dtype == torch.bfloat16 self.input_in_float16 = self.input_in_fp16 or self.input_in_bf16 + self.attn_mask_type = attn_mask_type assert ( scale is None or self.softmax_in_fp32 diff --git a/transformer_engine/pytorch/transformer.py b/transformer_engine/pytorch/transformer.py index de93cd652f..6b45a10fb3 100644 --- a/transformer_engine/pytorch/transformer.py +++ b/transformer_engine/pytorch/transformer.py @@ -73,10 +73,10 @@ class TransformerLayer(torch.nn.Module): Arguments :attr:`attention_softmax_in_fp32` and :attr:`apply_query_key_layer_scaling` are deprecated and will be fully removed in future releases. - .. note:: + .. warning:: - Argument :attr:`attention_mask` will be ignored in the `forward` call when - :attr:`self_attn_mask_type` is set to `"causal"`. + Argument :attr:`self_attn_mask_type` has been moved to the `forward` method and + is deprecated. It will be fully removed in future releases. Parameters ---------- @@ -127,8 +127,6 @@ class TransformerLayer(torch.nn.Module): kv_channels: int, default = `None` number of key-value channels. defaults to :attr:`hidden_size` / :attr:`num_attention_heads` if `None`. - self_attn_mask_type: {'causal', 'padding'}, default = `causal` - type of attention mask passed into softmax operation. zero_centered_gamma : bool, default = 'False' if set to 'True', gamma parameter in LayerNorm is initialized to 0 and the LayerNorm formula changes to @@ -214,7 +212,7 @@ def __init__( output_layer_init_method: Optional[Callable] = None, layer_number: Optional[int] = None, kv_channels: Optional[int] = None, - self_attn_mask_type: str = "causal", + self_attn_mask_type: Optional[str] = None, tp_group: Optional[dist_group_type] = None, tp_size: int = 1, params_dtype: Optional[torch.dtype] = None, @@ -241,6 +239,13 @@ def __init__( ) -> None: super().__init__() + if self_attn_mask_type is not None: + warnings.warn( + "Argument :attr:`self_attn_mask_type` has been moved to the `forward` method and" + "is deprecated. 
It will be fully removed in future releases.", + category=DeprecationWarning, + ) + warnings.warn( "Arguments `attention_softmax_in_fp32` and `apply_query_key_layer_scaling`" "are deprecated and will be fully removed in future releases.", @@ -252,6 +257,7 @@ def __init__( tex.userbuf_comm_available() ), "Userbuffer communication backend not available." + self.self_attn_mask_type = self_attn_mask_type params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype ub_tp_comm_overlap = ub_tp_comm_overlap and bool(int(os.getenv("NVTE_UB_OVERLAP", "1"))) ub_bulk_wgrad = ub_tp_comm_overlap and bool(int(os.getenv("NVTE_UB_BULK_WGRAD", "1"))) @@ -265,10 +271,7 @@ def __init__( self.apply_residual_connection_post_layernorm = ( apply_residual_connection_post_layernorm ) - self.self_attn_mask_type = self_attn_mask_type - assert ( - self_attn_mask_type in AttnMaskTypes - ), f"self_attn_mask_type {self_attn_mask_type} not supported" + assert layer_type in LayerTypes, f"layer_type {layer_type} not supported" if not fuse_qkv_params: @@ -326,7 +329,6 @@ def __init__( self.self_attention = MultiheadAttention( *attention_args, **common_attention_kwargs, - attn_mask_type=self_attn_mask_type, input_layernorm=not output_layernorm, attention_type="self", bias=bias, @@ -429,6 +431,7 @@ def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, + self_attn_mask_type: str = "causal", encoder_output: Optional[torch.Tensor] = None, enc_dec_attn_mask: Optional[torch.Tensor] = None, is_first_microbatch: Optional[bool] = None, @@ -453,6 +456,8 @@ def forward( Input tensor. attention_mask : Optional[torch.Tensor], default = `None` Boolean tensor used to mask out self-attention softmax input. + self_attn_mask_type: {'causal', 'padding'}, default = `causal` + type of attention mask passed into softmax operation. encoder_output : Optional[torch.Tensor], default = `None` Output of the encoder block to be fed into the decoder block if using `layer_type="decoder"`. @@ -488,6 +493,19 @@ def forward( Whether to set output tensors to 0 or not before use. """ + if self.self_attn_mask_type is not None: + warnings.warn( + "Argument :attr:`self_attn_mask_type` has been moved to the `forward` method and" + "is deprecated. It will be fully removed in future releases.", + category=DeprecationWarning, + ) + # Keep previous functionality for current users. + self_attn_mask_type = self.self_attn_mask_type + + assert ( + self_attn_mask_type in AttnMaskTypes + ), f"self_attn_mask_type {self_attn_mask_type} not supported" + hidden_states = hidden_states.contiguous() if self.sequence_parallel and self.seq_length is not None: @@ -495,7 +513,7 @@ def forward( hidden_states.shape[0] == self.seq_length // self.tp_size ), "Sequence dimension must be split across TP group when using sequence parallel." - if self.self_attn_mask_type != "causal" and attention_mask is not None: + if self_attn_mask_type != "causal" and attention_mask is not None: assert ( attention_mask.dtype == torch.bool ), "Attention mask must be a boolean tensor" @@ -509,7 +527,8 @@ def forward( # Self attention. 
self_attention_outputs = self.self_attention( hidden_states, - attention_mask, + attention_mask=attention_mask, + attn_mask_type=self_attn_mask_type, inference_params=inference_params, is_first_microbatch=is_first_microbatch, checkpoint_core_attention=checkpoint_core_attention, @@ -556,7 +575,8 @@ def forward( if self.layer_type == "decoder": inter_attention_outputs = self.inter_attention( bda_output, - enc_dec_attn_mask, + attention_mask=enc_dec_attn_mask, + attn_mask_type=self_attn_mask_type, encoder_output=encoder_output, is_first_microbatch=is_first_microbatch, checkpoint_core_attention=checkpoint_core_attention, From 0170797ce9fc2a6114f4e72383ad58e1fa321dfd Mon Sep 17 00:00:00 2001 From: Tian Zheng Date: Sun, 27 Aug 2023 02:08:10 +0800 Subject: [PATCH 51/68] [Paddle] Add parallel support (#357) * [Paddle] Add TP, DP, PP, FSDP Signed-off-by: Tian Zheng (Engrg-Hardware 1) * Minor fix Signed-off-by: Tian Zheng (Engrg-Hardware 1) * Fix CI failure Signed-off-by: Tian Zheng (Engrg-Hardware 1) * Remove set_nccl_overlap_warning_if_tp Signed-off-by: Tian Zheng (Engrg-Hardware 1) * Improve variable naming Signed-off-by: Tian Zheng (Engrg-Hardware 1) * Refactor FP8 Buffer Signed-off-by: Tian Zheng (Engrg-Hardware 1) * Stylic changes Signed-off-by: Tian Zheng (Engrg-Hardware 1) * Fix FP32 parallel training Signed-off-by: Tian Zheng (Engrg-Hardware 1) * Fix numel performance issue Signed-off-by: Tian Zheng (Engrg-Hardware 1) * Squashed commit of the following: commit 79e2e5fd774e67dcdda9aae01a9f31a6479c5d70 Author: Tian Zheng (Engrg-Hardware 1) Date: Sun Aug 20 14:39:16 2023 +0000 Add TP test Signed-off-by: Tian Zheng (Engrg-Hardware 1) commit 1d40ad60540490f97ed82ba877cc6eda8902cbf6 Author: Tian Zheng (Engrg-Hardware 1) Date: Sun Aug 20 14:22:25 2023 +0000 Fix tp_size when disabled Signed-off-by: Tian Zheng (Engrg-Hardware 1) commit 6632f735a0c8251862355fc74622af59fae3a509 Author: Tian Zheng (Engrg-Hardware 1) Date: Sun Aug 20 05:52:18 2023 +0000 Add TP for attention and transformer layer Signed-off-by: Tian Zheng (Engrg-Hardware 1) Signed-off-by: Tian Zheng (Engrg-Hardware 1) * Add shape check Signed-off-by: Tian Zheng (Engrg-Hardware 1) * Add FSDP check for stage 1,2,3 Signed-off-by: Tian Zheng (Engrg-Hardware 1) * Review changes Signed-off-by: Tian Zheng (Engrg-Hardware 1) * Fix group_sharding test Signed-off-by: Tian Zheng (Engrg-Hardware 1) * Support NVTE_FUSE_ATTN Signed-off-by: Tian Zheng (Engrg-Hardware 1) * Fix CI errors Signed-off-by: Tian Zheng (Engrg-Hardware 1) --------- Signed-off-by: Tian Zheng (Engrg-Hardware 1) Co-authored-by: Kirthi Shankar Sivamani --- .../paddle/mnist/test_single_gpu_mnist.py | 8 +- tests/paddle/dist_launcher.py | 140 ++++++++++ tests/paddle/parallel_tests/amax_reduction.py | 87 ++++++ tests/paddle/parallel_tests/group_sharding.py | 187 +++++++++++++ .../parallel_tests/layernorm_linear_tp.py | 119 ++++++++ .../paddle/parallel_tests/layernorm_mlp_tp.py | 125 +++++++++ tests/paddle/parallel_tests/linear_pp.py | 192 +++++++++++++ tests/paddle/parallel_tests/linear_tp.py | 180 ++++++++++++ tests/paddle/parallel_tests/transformer_tp.py | 151 ++++++++++ tests/paddle/test_layers.py | 10 +- tests/paddle/test_operators.py | 8 +- tests/paddle/test_parallel.py | 89 ++++++ tests/paddle/utils.py | 18 ++ transformer_engine/paddle/constants.py | 4 + transformer_engine/paddle/distributed.py | 100 +++++++ transformer_engine/paddle/fp8.py | 92 +++++-- transformer_engine/paddle/fp8_buffer.py | 257 ++++++++++++++++++ transformer_engine/paddle/layer/attention.py | 106 +++++--- 
transformer_engine/paddle/layer/base.py | 78 +++++- transformer_engine/paddle/layer/layernorm.py | 2 +- .../paddle/layer/layernorm_linear.py | 109 ++++++-- .../paddle/layer/layernorm_mlp.py | 153 +++++++++-- transformer_engine/paddle/layer/linear.py | 145 ++++++++-- .../paddle/layer/transformer.py | 28 +- 24 files changed, 2248 insertions(+), 140 deletions(-) create mode 100644 tests/paddle/dist_launcher.py create mode 100644 tests/paddle/parallel_tests/amax_reduction.py create mode 100644 tests/paddle/parallel_tests/group_sharding.py create mode 100644 tests/paddle/parallel_tests/layernorm_linear_tp.py create mode 100644 tests/paddle/parallel_tests/layernorm_mlp_tp.py create mode 100644 tests/paddle/parallel_tests/linear_pp.py create mode 100644 tests/paddle/parallel_tests/linear_tp.py create mode 100644 tests/paddle/parallel_tests/transformer_tp.py create mode 100644 tests/paddle/test_parallel.py create mode 100644 transformer_engine/paddle/distributed.py create mode 100644 transformer_engine/paddle/fp8_buffer.py diff --git a/examples/paddle/mnist/test_single_gpu_mnist.py b/examples/paddle/mnist/test_single_gpu_mnist.py index dabeb55656..cffd036d95 100644 --- a/examples/paddle/mnist/test_single_gpu_mnist.py +++ b/examples/paddle/mnist/test_single_gpu_mnist.py @@ -57,11 +57,13 @@ def forward(self, x): def train(args, model, train_loader, optimizer, epoch, use_fp8): """Training function.""" model.train() + losses = [] for batch_id, (data, labels) in enumerate(train_loader): with paddle.amp.auto_cast(dtype='bfloat16', level='O2'): # pylint: disable=not-context-manager with te.fp8_autocast(enabled=use_fp8): outputs = model(data) loss = F.cross_entropy(outputs, labels) + losses.append(loss.item()) loss.backward() optimizer.step() @@ -74,7 +76,9 @@ def train(args, model, train_loader, optimizer, epoch, use_fp8): f"Loss: {loss.item():.6f}") if args.dry_run: return loss.item() - return loss.item() + avg_loss = sum(losses) / len(losses) + print(f"Train Epoch: {epoch}, Average Loss: {avg_loss}") + return avg_loss def evaluate(model, test_loader, epoch, use_fp8): @@ -226,7 +230,7 @@ def setUpClass(cls): @staticmethod def verify(actual): """Check If loss and accuracy match target""" - desired_traing_loss = 0.5 + desired_traing_loss = 0.1 desired_test_accuracy = 0.98 assert actual[0] < desired_traing_loss assert actual[1] > desired_test_accuracy diff --git a/tests/paddle/dist_launcher.py b/tests/paddle/dist_launcher.py new file mode 100644 index 0000000000..e59b686435 --- /dev/null +++ b/tests/paddle/dist_launcher.py @@ -0,0 +1,140 @@ +# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. 
+"""Helper functions to launch distributed tests""" + +import copy +import os +from pathlib import Path +import subprocess +import time +import unittest + +from paddle import fluid +from paddle.distributed.utils.launch_utils import ( + TrainerProc, + find_free_ports, + get_cluster, + watch_local_trainers, +) + +__all__ = ['TestDistributed'] + + +def get_cluster_from_args(selected_gpus): + """Get node information from selected GPUs""" + cluster_node_ips = '127.0.0.1' + node_ip = '127.0.0.1' + + node_ips = [x.strip() for x in cluster_node_ips.split(',')] + + node_ips.index(node_ip) + + free_ports = None + + free_ports = find_free_ports(len(selected_gpus)) + if free_ports is not None: + free_ports = list(free_ports) + + trainer_endpoints = [] + for ip in node_ips: + trainer_endpoints.append([f"{ip}:{port}" for port in free_ports]) + return get_cluster(node_ips, node_ip, trainer_endpoints, selected_gpus) + + +def get_gpus(selected_gpus): + """Get selected GPU string""" + selected_gpus = [x.strip() for x in selected_gpus.split(',')] + return selected_gpus + + +def start_local_trainers( + cluster, + pod, + training_script, + training_script_args, + allocator_strategy="auto_growth", +): + """Launch trainers""" + current_env = copy.copy(os.environ.copy()) + # paddle broadcast ncclUniqueId use socket, and + # proxy maybe make trainers unreachable, so delete them. + # if we set them to "", grpc will log error message "bad uri" + # so just delete them. + current_env.pop("http_proxy", None) + current_env.pop("https_proxy", None) + + procs = [] + for t in pod.trainers: + proc_env = { + "FLAGS_selected_gpus": ",".join([str(g) for g in t.gpus]), + "PADDLE_TRAINER_ID": f"{t.rank}", + "PADDLE_CURRENT_ENDPOINT": f"{t.endpoint}", + "PADDLE_TRAINERS_NUM": f"{cluster.trainers_nranks()}", + "PADDLE_TRAINER_ENDPOINTS": ",".join(cluster.trainers_endpoints()), + "PYTHONPATH": str(Path(__file__).resolve().parent), + } + + proc_env["FLAGS_allocator_strategy"] = allocator_strategy + if allocator_strategy == "auto_growth": + proc_env["FLAGS_fraction_of_gpu_memory_to_use"] = "0.1" + + current_env.update(proc_env) + + print(f"trainer proc env:{current_env}") + + if os.getenv('WITH_COVERAGE', 'OFF') == 'ON': + cmd = "python -m coverage run --branch -p " + training_script + else: + cmd = "python -u " + training_script + + print(f"start trainer proc:{cmd} env:{proc_env}") + + fn = None + + proc = subprocess.Popen(cmd.split(" ") + training_script_args, env=current_env) # pylint: disable=consider-using-with + + tp = TrainerProc() + tp.proc = proc + tp.rank = t.rank + tp.log_fn = fn + tp.cmd = cmd + + procs.append(tp) + + return procs + + +class TestDistributed(unittest.TestCase): + """Base class for distributed test""" + + @staticmethod + def run_2gpu( + target_file_name, + allocator_strategy="auto_growth", + ): + """Run target file in subprocesses""" + if (not fluid.core.is_compiled_with_cuda() or fluid.core.get_cuda_device_count() == 0): + return + + selected_gpus = get_gpus('0,1') + cluster = None + pod = None + + cluster, pod = get_cluster_from_args(selected_gpus) + + procs = start_local_trainers( + cluster, + pod, + allocator_strategy=allocator_strategy, + training_script=target_file_name, + training_script_args=[], + ) + + while True: + alive = watch_local_trainers(procs, cluster.trainers_endpoints()) + + if not alive: + print(f"Local procs complete, POD info:{pod}") + break + time.sleep(3) diff --git a/tests/paddle/parallel_tests/amax_reduction.py b/tests/paddle/parallel_tests/amax_reduction.py new file mode 100644 
index 0000000000..931af07657 --- /dev/null +++ b/tests/paddle/parallel_tests/amax_reduction.py @@ -0,0 +1,87 @@ +# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. +"""Unittest for Linear layer in tensor parallel""" + +import unittest + +import paddle +from paddle.distributed import fleet + +from utils import assert_allclose, set_random_seed +import transformer_engine.paddle as te + + +def assert_allclose_across_ranks(tensor, group=None): + """Assert tensor is identical in all ranks""" + gathered_list = [] + paddle.distributed.all_gather(gathered_list, tensor, group=group) + assert len(gathered_list) > 1 + for gathered_tensor in gathered_list: + assert_allclose(tensor, gathered_tensor) + + +class TestAmaxReduction(unittest.TestCase): + """Tests Amax reduction""" + + def setUp(self): + self.data_parallel_size = 2 + self.init_dist_env() + self.global_dtype = 'bfloat16' + paddle.set_default_dtype(self.global_dtype) + + def init_dist_env(self): + """Init Paddle Fleet environment""" + strategy = fleet.DistributedStrategy() + strategy.hybrid_configs = { + "dp_degree": self.data_parallel_size, + "mp_degree": 1, + "pp_degree": 1, + } + fleet.init(is_collective=True, strategy=strategy) + + def test_amax_reduction(self): + """Tests column parallel linear""" + set_random_seed(1024) + layer1 = te.Linear(16, 16) + layer2 = te.Linear(16, 16) + model = paddle.nn.Sequential(layer1, layer2) + model = fleet.distributed_model(model) + + rank_id = paddle.distributed.get_rank() + set_random_seed(rank_id) + + optimizer = paddle.optimizer.SGD(learning_rate=10.0, parameters=model.parameters()) + optimizer = fleet.distributed_optimizer(optimizer) + + def train_one_step(layer, inp, optimizer): + inp = paddle.to_tensor(inp) + inp.stop_gradient = False + out = layer(inp) + loss = out.mean() + loss.backward() + optimizer.step() + optimizer.clear_grad() + return loss + + for _ in range(5): + inp = paddle.uniform([16, 16], self.global_dtype) + with te.fp8_autocast(enabled=True): + train_one_step(model, inp, optimizer) + + assert_allclose_across_ranks(layer1.fp8_meta["scaling_fwd"].amax_history[-1]) + assert_allclose_across_ranks(layer1.fp8_meta["scaling_fwd"].scale) + assert_allclose_across_ranks(layer1.fp8_meta["scaling_fwd"].scale_inv) + assert_allclose_across_ranks(layer2.fp8_meta["scaling_fwd"].amax_history[-1]) + assert_allclose_across_ranks(layer2.fp8_meta["scaling_fwd"].scale) + assert_allclose_across_ranks(layer2.fp8_meta["scaling_fwd"].scale_inv) + assert_allclose_across_ranks(layer1.fp8_meta["scaling_bwd"].amax_history[-1]) + assert_allclose_across_ranks(layer1.fp8_meta["scaling_bwd"].scale) + assert_allclose_across_ranks(layer1.fp8_meta["scaling_bwd"].scale_inv) + assert_allclose_across_ranks(layer2.fp8_meta["scaling_bwd"].amax_history[-1]) + assert_allclose_across_ranks(layer2.fp8_meta["scaling_bwd"].scale) + assert_allclose_across_ranks(layer2.fp8_meta["scaling_bwd"].scale_inv) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/paddle/parallel_tests/group_sharding.py b/tests/paddle/parallel_tests/group_sharding.py new file mode 100644 index 0000000000..b8e4fd885d --- /dev/null +++ b/tests/paddle/parallel_tests/group_sharding.py @@ -0,0 +1,187 @@ +# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. 
+"""Unittest for group sharding""" + +import unittest + +import paddle +from paddle.distributed import fleet +from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer import ( + DygraphShardingOptimizer,) + +from utils import assert_allclose, set_random_seed +import transformer_engine.paddle as te + + +class TestGroupSharding(unittest.TestCase): + """Tests group sharding""" + + def setUp(self): + self.set_attr() + self.init_dist_env() + paddle.set_default_dtype(self.global_dtype) + + def set_attr(self): + """Set test configs""" + self.sharding_degree = 2 + self.global_dtype = 'float32' + self.rtol = 1e-5 + self.atol = 1e-5 + self.batch_size = 16 + self.in_channels = 16 + self.out_channels = 32 + self.fp8 = False + + def init_dist_env(self): + """Init Paddle Fleet environment""" + strategy = fleet.DistributedStrategy() + strategy.hybrid_configs = { + "dp_degree": 1, + "mp_degree": 1, + "pp_degree": 1, + "sharding_degree": self.sharding_degree, + } + self.strategy = strategy + fleet.init(is_collective=True, strategy=strategy) + + def _get_model_and_optimizer(self, model, stage): + if stage == 1: + optimizer = DygraphShardingOptimizer( + hcg=fleet.get_hybrid_communicate_group(), + user_defined_strategy=self.strategy, + params=model.parameters(), + inner_optimizer_class=paddle.optimizer.AdamW, + learning_rate=0.01, + ) + model = fleet.distributed_model(model) + optimizer = fleet.distributed_optimizer(optimizer) + elif stage in [2, 3]: + optimizer = paddle.optimizer.AdamW(learning_rate=0.01, parameters=model.parameters()) + group = fleet.get_hybrid_communicate_group().get_sharding_parallel_group() + + class ShardingLevel: # pylint: disable=too-few-public-methods, + """Paddle sharding options""" + kStage1 = 'os' + kStage2 = 'os_g' + kStage3 = 'p_g_os' + + level = ShardingLevel.kStage3 if stage == 3 else ShardingLevel.kStage2 + model, optimizer, _ = paddle.distributed.sharding.group_sharded_parallel( + model=model, + optimizer=optimizer, + level=level, + group=group, + segment_size=256, + ) + else: + raise ValueError(f"Stage {stage} not supported") + return model, optimizer + + def test_group_sharding_stage1(self): + """Tests group sharding training""" + set_random_seed(1024) + model_te = te.Linear(self.in_channels, self.out_channels) + model_pd = paddle.nn.Linear(self.in_channels, self.out_channels) + model_pd.weight.copy_(model_te.weight.T, True) + model_pd.bias.copy_(model_te.bias, True) + + model_te, optimizer_te = self._get_model_and_optimizer(model_te, stage=1) + model_pd, optimizer_pd = self._get_model_and_optimizer(model_pd, stage=1) + + rank_id = paddle.distributed.get_rank() + paddle.seed(rank_id) + + def train_one_step(model, inp, optimizer): + out = model(inp) + loss = out.mean() + loss.backward() + optimizer.step() + optimizer.clear_grad() + return loss + + for _ in range(5): + inp = paddle.uniform([self.batch_size, self.in_channels], self.global_dtype) + with te.fp8_autocast(enabled=False): + loss_te = train_one_step(model_te, inp, optimizer_te) + loss_pd = train_one_step(model_pd, inp, optimizer_pd) + assert_allclose(loss_te, loss_pd, rtol=self.rtol, atol=self.atol) + + assert len(optimizer_te.state_dict()) == 4, \ + "Expect each rank to hold 4 optimizer state entries." 
+ + def test_group_sharding_stage2(self): + """Tests group sharding training""" + set_random_seed(1024) + model_te = te.Linear(self.in_channels, self.out_channels) + model_pd = paddle.nn.Linear(self.in_channels, self.out_channels) + model_pd.weight.copy_(model_te.weight.T, True) + model_pd.bias.copy_(model_te.bias, True) + + model_te, optimizer_te = self._get_model_and_optimizer(model_te, stage=2) + model_pd, optimizer_pd = self._get_model_and_optimizer(model_pd, stage=2) + + rank_id = paddle.distributed.get_rank() + paddle.seed(rank_id) + + def train_one_step(model, inp, optimizer): + out = model(inp) + loss = out.mean() + loss.backward() + # Check gradients are split to different trainers + if rank_id == 0: + assert model.bias.grad is None and model.weight.grad is not None + elif rank_id == 1: + assert model.weight.grad is None and model.bias.grad is not None + optimizer.step() + optimizer.clear_grad() + return loss + + for _ in range(5): + inp = paddle.uniform([self.batch_size, self.in_channels], self.global_dtype) + with te.fp8_autocast(enabled=False): + loss_te = train_one_step(model_te, inp, optimizer_te) + loss_pd = train_one_step(model_pd, inp, optimizer_pd) + assert_allclose(loss_te, loss_pd, rtol=self.rtol, atol=self.atol) + + assert len(optimizer_te.state_dict()) == 4, \ + "Expect each rank to hold 4 optimizer state entries." + + def test_group_sharding_stage3(self): + """Tests group sharding training""" + set_random_seed(1024) + model_te = te.Linear(self.in_channels, self.out_channels) + model_pd = paddle.nn.Linear(self.in_channels, self.out_channels) + model_pd.weight.copy_(model_te.weight.T, True) + model_pd.bias.copy_(model_te.bias, True) + + model_te, optimizer_te = self._get_model_and_optimizer(model_te, stage=3) + model_pd, optimizer_pd = self._get_model_and_optimizer(model_pd, stage=3) + + rank_id = paddle.distributed.get_rank() + paddle.seed(rank_id) + + def train_one_step(model, inp, optimizer): + out = model(inp) + loss = out.mean() + loss.backward() + optimizer.step() + optimizer.clear_grad() + return loss + + for _ in range(5): + inp = paddle.uniform([self.batch_size, self.in_channels], self.global_dtype) + with te.fp8_autocast(enabled=False): + loss_te = train_one_step(model_te, inp, optimizer_te) + loss_pd = train_one_step(model_pd, inp, optimizer_pd) + assert_allclose(loss_te, loss_pd, rtol=self.rtol, atol=self.atol) + + for name, value in optimizer_te.state_dict().items(): + if name.endswith('w_0_moment1_0'): + assert value.numel() == \ + self.in_channels * self.out_channels // self.sharding_degree, \ + "Expect optimizer state to be sharded across trainers." + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/paddle/parallel_tests/layernorm_linear_tp.py b/tests/paddle/parallel_tests/layernorm_linear_tp.py new file mode 100644 index 0000000000..1034fb26fc --- /dev/null +++ b/tests/paddle/parallel_tests/layernorm_linear_tp.py @@ -0,0 +1,119 @@ +# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. 
+"""Unittest for LayerNormLinear layer in tensor parallel""" + +import unittest + +import paddle +from paddle.distributed import fleet +from paddle.distributed.fleet.layers.mpu import mp_ops + +from utils import assert_allclose, assert_shape, set_random_seed +import transformer_engine.paddle as te + + +class TestLayerNormLinearTp(unittest.TestCase): + """Tests LayerNormLinear layer with column/row parallelism in BF16""" + + def setUp(self): + self.set_attr() + self.init_dist_env() + paddle.set_default_dtype(self.global_dtype) + + def init_dist_env(self): + """Init Paddle Fleet environment""" + strategy = fleet.DistributedStrategy() + self.model_parallel_size = 2 + strategy.hybrid_configs = { + "dp_degree": 1, + "mp_degree": self.model_parallel_size, + "pp_degree": 1, + } + fleet.init(is_collective=True, strategy=strategy) + self.hcg = fleet.get_hybrid_communicate_group() + self.tp_group = self.hcg.get_model_parallel_group() + + def set_attr(self): + """Set test configs""" + self.batch_size = 16 + self.in_features = 32 + self.out_features = 64 + self.global_dtype = 'bfloat16' + self.rtol = 1e-3 + self.atol = 1e-3 + self.eps = 1e-3 + self.fp8 = False + + def test_column_parallel_layer(self): + """Tests column parallel LayerNormLinear""" + set_random_seed(1024) + layer_te = te.LayerNormLinear( + self.in_features, + self.out_features, + eps=self.eps, + parallel_mode='column', + ) + layer_pd = te.LayerNormLinear( + self.in_features, + self.out_features, + eps=self.eps, + backend='paddle', + ) + # Get total weight + total_weight = [] + partial_weight = layer_te.weight.clone().detach() + paddle.distributed.all_gather(total_weight, partial_weight, group=self.tp_group) + total_weight = paddle.concat(total_weight, axis=0) + layer_pd.weight.copy_(total_weight.T, True) + + assert_shape(layer_te.weight, + [self.out_features // self.model_parallel_size, self.in_features]) + assert_shape(layer_te.bias, [self.out_features // self.model_parallel_size]) + + optimizer_te = paddle.optimizer.SGD(learning_rate=0.001, parameters=layer_te.parameters()) + optimizer_pd = paddle.optimizer.SGD(learning_rate=0.001, parameters=layer_pd.parameters()) + + layer_te = fleet.distributed_model(layer_te) + optimizer_te = fleet.distributed_optimizer(optimizer_te) + + def train_one_step(layer, inp, optimizer, gather=False): + inp = paddle.to_tensor(inp) + inp.stop_gradient = False + out = layer(inp) + if gather: + total_out = mp_ops._c_concat(out, group=self.tp_group) + else: + total_out = out + loss = total_out.mean() + loss.backward() + optimizer.step() + optimizer.clear_grad() + return loss, inp.grad + + for _ in range(5): + inp = paddle.uniform([self.batch_size, self.in_features], self.global_dtype) + with te.fp8_autocast(enabled=self.fp8): + loss_tp, grad_input = train_one_step(layer_te, inp, optimizer_te, gather=True) + loss_ref, grad_input_ref = train_one_step(layer_pd, inp, optimizer_pd) + assert_allclose(loss_tp, loss_ref, rtol=self.rtol, atol=self.atol) + assert_allclose(grad_input, grad_input_ref, rtol=self.rtol, atol=self.atol) + + +class TestLayerNormLinearTpFp8(TestLayerNormLinearTp): + """Tests LayernormLinear layer with column/row parallelism in FP8""" + + def set_attr(self): + """Set test configs""" + self.batch_size = 16 + self.in_features = 32 + self.out_features = 64 + self.global_dtype = 'bfloat16' + self.rtol = 1e-2 + self.atol = 1e-2 + self.eps = 1e-3 + self.fp8 = True + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/paddle/parallel_tests/layernorm_mlp_tp.py 
b/tests/paddle/parallel_tests/layernorm_mlp_tp.py new file mode 100644 index 0000000000..f579f5f371 --- /dev/null +++ b/tests/paddle/parallel_tests/layernorm_mlp_tp.py @@ -0,0 +1,125 @@ +# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. +"""Unittest for LayerNormMLP layer in tensor parallel""" + +import unittest + +import paddle +from paddle.distributed import fleet + +from utils import assert_allclose, assert_shape, set_random_seed +import transformer_engine.paddle as te + + +class TestLayerNormMLPTp(unittest.TestCase): + """Tests LayerNormMLP layer with model parallel in BF16""" + + def setUp(self): + self.set_attr() + self.init_dist_env() + paddle.set_default_dtype(self.global_dtype) + + def init_dist_env(self): + """Init Paddle Fleet environment""" + strategy = fleet.DistributedStrategy() + self.model_parallel_size = 2 + strategy.hybrid_configs = { + "dp_degree": 1, + "mp_degree": self.model_parallel_size, + "pp_degree": 1, + } + fleet.init(is_collective=True, strategy=strategy) + self.hcg = fleet.get_hybrid_communicate_group() + self.tp_group = self.hcg.get_model_parallel_group() + + def set_attr(self): + """Set test configs""" + self.batch_size = 16 + self.hidden_size = 32 + self.ffn_hidden_size = 64 + self.global_dtype = 'bfloat16' + self.rtol = 1e-3 + self.atol = 1e-3 + self.eps = 1e-3 + self.fp8 = False + + def test_parallel_layer(self): + """Tests parallel LayerNormMLP""" + set_random_seed(1024) + layer_te = te.LayerNormMLP( + hidden_size=self.hidden_size, + ffn_hidden_size=self.ffn_hidden_size, + eps=self.eps, + set_parallel_mode=True, + ) + layer_pd = te.LayerNormMLP( + hidden_size=self.hidden_size, + ffn_hidden_size=self.ffn_hidden_size, + eps=self.eps, + set_parallel_mode=False, + backend='paddle', + ) + + def _get_total_weight(local_weight, tp_group, axis): + total_weight = [] + partial_weight = local_weight.clone().detach() + paddle.distributed.all_gather(total_weight, partial_weight, group=tp_group) + total_weight = paddle.concat(total_weight, axis=axis) + return total_weight + + # Get total weight + total_fc1_weight = _get_total_weight(layer_te.fc1_weight, tp_group=self.tp_group, axis=0) + total_fc2_weight = _get_total_weight(layer_te.fc2_weight, tp_group=self.tp_group, axis=1) + layer_pd.fc1_weight.copy_(total_fc1_weight.T, True) + layer_pd.fc2_weight.copy_(total_fc2_weight.T, True) + + assert_shape(layer_te.fc1_weight, + [self.ffn_hidden_size // self.model_parallel_size, self.hidden_size]) + assert_shape(layer_te.fc1_bias, [self.ffn_hidden_size // self.model_parallel_size]) + assert_shape(layer_te.fc2_weight, + [self.hidden_size, self.ffn_hidden_size // self.model_parallel_size]) + assert_shape(layer_te.fc2_bias, [self.hidden_size]) + + optimizer_te = paddle.optimizer.SGD(learning_rate=0.001, parameters=layer_te.parameters()) + optimizer_pd = paddle.optimizer.SGD(learning_rate=0.001, parameters=layer_pd.parameters()) + + layer_te = fleet.distributed_model(layer_te) + optimizer_te = fleet.distributed_optimizer(optimizer_te) + + def train_one_step(layer, inp, optimizer): + inp = paddle.to_tensor(inp) + inp.stop_gradient = False + out = layer(inp) + loss = out.mean() + loss.backward() + optimizer.step() + optimizer.clear_grad() + return loss, inp.grad + + for _ in range(5): + inp = paddle.uniform([self.batch_size, self.hidden_size], self.global_dtype) + with te.fp8_autocast(enabled=self.fp8): + loss_tp, grad_input = train_one_step(layer_te, inp, optimizer_te) + loss_ref, grad_input_ref = 
train_one_step(layer_pd, inp, optimizer_pd) + assert_allclose(loss_tp, loss_ref, rtol=self.rtol, atol=self.atol) + assert_allclose(grad_input, grad_input_ref, rtol=self.rtol, atol=self.atol) + + +class TestLayerNormMLPTpFp8(TestLayerNormMLPTp): + """Tests LayerNormMLP layer with tensor parallelism in FP8""" + + def set_attr(self): + """Set test configs""" + self.batch_size = 16 + self.hidden_size = 32 + self.ffn_hidden_size = 64 + self.global_dtype = 'bfloat16' + self.rtol = 1e-2 + self.atol = 1e-2 + self.eps = 1e-3 + self.fp8 = True + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/paddle/parallel_tests/linear_pp.py b/tests/paddle/parallel_tests/linear_pp.py new file mode 100644 index 0000000000..994e15ba7d --- /dev/null +++ b/tests/paddle/parallel_tests/linear_pp.py @@ -0,0 +1,192 @@ +# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. +"""Unittest for Linear layer in pipeline parallel""" + +import unittest + +import numpy as np + +import paddle +from paddle.distributed import fleet + +from paddle.distributed.fleet.meta_parallel import ( + LayerDesc, + PipelineLayer, +) + +from utils import assert_allclose, set_random_seed +import transformer_engine.paddle as te + + +class TEPipelineModel(PipelineLayer): + """Model for pipeline parallel test""" + + def __init__(self, + in_features, + hidden_features, + weight_attrs, + use_te=True, + use_fp8=False, + **kwargs): + self.in_features = in_features + self.hidden_features = hidden_features + self.fp8 = use_fp8 + hcg = fleet.get_hybrid_communicate_group() + self.dp_group = hcg.get_data_parallel_group() + + Linear = te.Linear if use_te else paddle.nn.Linear + model_desc = [ + LayerDesc(Linear, self.in_features, self.hidden_features, weight_attr=weight_attrs[0]), + LayerDesc(Linear, self.hidden_features, self.in_features, weight_attr=weight_attrs[1]), + ] + super().__init__(layers=model_desc, loss_fn=paddle.nn.CrossEntropyLoss(), **kwargs) + + def forward(self, *args, **kwargs): + with te.fp8_autocast(enabled=self.fp8, fp8_group=self.dp_group): + return super().forward(*args, **kwargs) + + +class StandaloneModel(paddle.nn.Layer): + """Model for pipeline parallel test""" + + def __init__(self, in_features, hidden_features, weight_attrs): + super().__init__() + self.in_features = in_features + self.hidden_features = hidden_features + Linear = paddle.nn.Linear + self.layer = paddle.nn.Sequential( + Linear(self.in_features, self.hidden_features, weight_attr=weight_attrs[0]), + Linear(self.hidden_features, self.in_features, weight_attr=weight_attrs[1]), + ) + self.loss = paddle.nn.CrossEntropyLoss() + + def forward(self, inp): + out = self.layer(inp[0]) + loss = self.loss(out, inp[1]) + return loss + + +class TestLinearPipelineParallel(unittest.TestCase): + """Tests Linear layer with pipeline parallel""" + + def setUp(self): + self.set_attr() + self.init_dist_env() + paddle.set_default_dtype(self.global_dtype) + + def init_dist_env(self): + """Init Paddle Fleet environment""" + strategy = fleet.DistributedStrategy() + self.pipeline_parallel_size = 2 + strategy.hybrid_configs = { + "dp_degree": 1, + "mp_degree": 1, + "pp_degree": self.pipeline_parallel_size, + } + strategy.pipeline_configs = { + "accumulate_steps": self.batch_size // self.micro_batch_size, + "micro_batch_size": self.micro_batch_size, + } + fleet.init(is_collective=True, strategy=strategy) + self.rank = fleet.worker_index() + self.hcg = fleet.get_hybrid_communicate_group() + + def set_attr(self): + 
"""Set test configs""" + self.batch_size = 32 + self.micro_batch_size = 16 + self.in_features = 32 + self.hidden_features = 64 + self.global_dtype = 'float32' + self.rtol = 1e-5 + self.atol = 1e-5 + self.iter = 10 + self.fp8 = False + + def test_pipeline_train(self): + """Test pipeline parallel training""" + set_random_seed(1024) + + weight1_np = np.random.normal(size=[self.in_features, self.hidden_features]) + weight2_np = np.random.normal(size=[self.hidden_features, self.in_features]) + weight_attrs = [ + paddle.ParamAttr(initializer=paddle.nn.initializer.Assign(weight1_np)), + paddle.ParamAttr(initializer=paddle.nn.initializer.Assign(weight2_np)), + ] + weight_attrs_transposed = [ + paddle.ParamAttr(initializer=paddle.nn.initializer.Assign(weight1_np.T)), + paddle.ParamAttr(initializer=paddle.nn.initializer.Assign(weight2_np.T)), + ] + + pipe_model = TEPipelineModel( + self.in_features, + self.hidden_features, + weight_attrs_transposed, + use_te=True, + use_fp8=self.fp8, + seg_method="layer:Linear", + num_stages=self.pipeline_parallel_size, + ) + + # Check if model is split across ranks as expected + for name, sublayer in pipe_model.named_sublayers(): + if name in ('_loss_fn', 'shared_layers'): + continue + if self.rank == 0: + assert tuple(sublayer.weight.shape) == weight1_np.T.shape, \ + f"Shape does not match, expect: {weight1_np.T.shape}, " \ + f"actual: {tuple(sublayer.weight.shape)}" + elif self.rank == 1: + assert tuple(sublayer.weight.shape) == weight2_np.T.shape, \ + f"Shape does not match, expect: {weight2_np.T.shape}, " \ + f"actual: {tuple(sublayer.weight.shape)}" + + standalone_model = StandaloneModel( + self.in_features, + self.hidden_features, + weight_attrs, + ) + + optimizer_te = paddle.optimizer.SGD(learning_rate=0.1, parameters=pipe_model.parameters()) + optimizer_pd = paddle.optimizer.SGD(learning_rate=0.1, + parameters=standalone_model.parameters()) + + pipe_model = fleet.distributed_model(pipe_model) + optimizer_te = fleet.distributed_optimizer(optimizer_te) + + def train_one_step(layer, inp, optimizer): + loss = layer(inp) + loss.backward() + optimizer.step() + optimizer.clear_grad() + return loss + + for i in range(self.iter): + inp = paddle.to_tensor(np.random.normal(size=[self.batch_size, self.in_features]), + dtype=self.global_dtype) + label = paddle.to_tensor(np.random.randint(self.in_features, size=[self.batch_size, 1])) + loss_te = pipe_model.train_batch([inp, label], optimizer_te) + loss_pd = train_one_step(standalone_model, [inp, label], optimizer_pd) + print(f"Iter: {i}, loss_te: {loss_te.item()}, loss_pd: {loss_pd.item()}") + assert_allclose(loss_te, loss_pd, rtol=self.rtol, atol=self.atol) + + +class TestLinearPipelineParallelFP8(TestLinearPipelineParallel): + """Tests Linear layer with column/row parallelism in FP8""" + + def set_attr(self): + """Set test configs""" + self.batch_size = 32 + self.micro_batch_size = 16 + self.in_features = 32 + self.hidden_features = 64 + self.global_dtype = 'float32' + self.rtol = 5e-2 + self.atol = 5e-2 + self.iter = 10 + self.fp8 = True + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/paddle/parallel_tests/linear_tp.py b/tests/paddle/parallel_tests/linear_tp.py new file mode 100644 index 0000000000..fe0aeddccd --- /dev/null +++ b/tests/paddle/parallel_tests/linear_tp.py @@ -0,0 +1,180 @@ +# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. 
+"""Unittest for Linear layer in tensor parallel""" + +import unittest + +import paddle +from paddle.distributed import fleet +from paddle.distributed.fleet.layers.mpu import mp_ops + +from utils import assert_allclose, assert_shape, set_random_seed +import transformer_engine.paddle as te + + +class TestLinearTp(unittest.TestCase): + """Tests Linear layer with column/row parallelism in BF16""" + + def setUp(self): + self.set_attr() + self.init_dist_env() + paddle.set_default_dtype(self.global_dtype) + + def init_dist_env(self): + """Init Paddle Fleet environment""" + strategy = fleet.DistributedStrategy() + self.model_parallel_size = 2 + strategy.hybrid_configs = { + "dp_degree": 1, + "mp_degree": self.model_parallel_size, + "pp_degree": 1, + } + fleet.init(is_collective=True, strategy=strategy) + self.rank = fleet.worker_index() + self.hcg = fleet.get_hybrid_communicate_group() + self.tp_group = self.hcg.get_model_parallel_group() + self.world_size = self.hcg.get_model_parallel_world_size() + + def set_attr(self): + """Set test configs""" + self.batch_size = 16 + self.in_features = 32 + self.out_features = 64 + self.global_dtype = 'bfloat16' + self.rtol = 1e-3 + self.atol = 1e-3 + self.fp8 = False + + def test_column_parallel_layer(self): + """Tests column parallel linear""" + set_random_seed(1024) + layer_te = te.Linear( + self.in_features, + self.out_features, + parallel_mode='column', + ) + layer_pd = te.Linear( + self.in_features, + self.out_features, + backend='paddle', + ) + # Get total weight + total_weight = [] + partial_weight = layer_te.weight.clone().detach() + paddle.distributed.all_gather(total_weight, partial_weight, group=self.tp_group) + total_weight = paddle.concat(total_weight, axis=0) + layer_pd.weight.copy_(total_weight.T, True) + + assert_shape(layer_te.weight, + [self.out_features // self.model_parallel_size, self.in_features]) + assert_shape(layer_te.bias, [self.out_features // self.model_parallel_size]) + + optimizer_te = paddle.optimizer.SGD(learning_rate=0.001, parameters=layer_te.parameters()) + optimizer_pd = paddle.optimizer.SGD(learning_rate=0.001, parameters=layer_pd.parameters()) + + layer_te = fleet.distributed_model(layer_te) + optimizer_te = fleet.distributed_optimizer(optimizer_te) + + def train_one_step(layer, inp, optimizer, gather=False): + inp = paddle.to_tensor(inp) + inp.stop_gradient = False + out = layer(inp) + if gather: + total_out = mp_ops._c_concat(out, group=self.tp_group) + else: + total_out = out + loss = total_out.mean() + loss.backward() + optimizer.step() + optimizer.clear_grad() + return loss, inp.grad + + for _ in range(5): + inp = paddle.uniform([self.batch_size, self.in_features], self.global_dtype) + with te.fp8_autocast(enabled=self.fp8): + loss_tp, grad_input = train_one_step(layer_te, inp, optimizer_te, gather=True) + loss_ref, grad_input_ref = train_one_step(layer_pd, inp, optimizer_pd) + assert_allclose(loss_tp, loss_ref, rtol=self.rtol, atol=self.atol) + assert_allclose(grad_input, grad_input_ref, rtol=self.rtol, atol=self.atol) + + def test_row_parallel_layer(self): + """Tests row parallel linear""" + set_random_seed(1024) + layer_te = te.Linear( + self.in_features, + self.out_features, + parallel_mode='row', + ) + layer_pd = te.Linear( + self.in_features, + self.out_features, + backend='paddle', + ) + # Get total weight + total_weight = [] + partial_weight = layer_te.weight.clone().detach() + paddle.distributed.all_gather(total_weight, partial_weight, group=self.tp_group) + total_weight = paddle.concat(total_weight, 
axis=1) + layer_pd.weight.copy_(total_weight.T, True) + + assert_shape(layer_te.weight, + [self.out_features, self.in_features // self.model_parallel_size]) + assert_shape(layer_te.bias, [self.out_features]) + + optimizer_te = paddle.optimizer.SGD(learning_rate=0.001, parameters=layer_te.parameters()) + optimizer_pd = paddle.optimizer.SGD(learning_rate=0.001, parameters=layer_pd.parameters()) + + # Note(tizheng): For this test, we cannot wrap model with fleet.distributed_model, + # because it will broadcast inputs across mp group. However, RPL expects splitted + # inputs, which is different on each rank. + + def train_one_step(layer, inp, optimizer, split=False): + inp = paddle.to_tensor(inp, stop_gradient=True) + if split: + # TODO(tizheng): Why not working? + # issue: https://github.com/PaddlePaddle/Paddle/issues/55565 + # input_parallel = mp_ops._c_split(inp, group=layer.tp_group) + split_size = inp.shape[1] // self.world_size + input_parallel = inp[:, split_size * self.rank:split_size * (self.rank + 1)] + else: + input_parallel = inp + input_parallel.stop_gradient = False + out = layer(input_parallel) + loss = out.mean() + loss.backward() + optimizer.step() + optimizer.clear_grad() + if split: + grad_input = [] + paddle.distributed.all_gather(grad_input, input_parallel.grad, group=self.tp_group) + grad_input = paddle.concat(grad_input, axis=1) + else: + grad_input = input_parallel.grad + return loss, grad_input + + for _ in range(5): + inp = paddle.uniform([self.batch_size, self.in_features], self.global_dtype) + with te.fp8_autocast(enabled=self.fp8): + loss_tp, grad_input = train_one_step(layer_te, inp, optimizer_te, split=True) + loss_ref, grad_input_ref = train_one_step(layer_pd, inp, optimizer_pd) + assert_allclose(loss_tp, loss_ref, rtol=self.rtol, atol=self.atol) + assert_allclose(grad_input, grad_input_ref, rtol=self.rtol, atol=self.atol) + + +class TestLinearTpFP8(TestLinearTp): + """Tests Linear layer with column/row parallelism in FP8""" + + def set_attr(self): + """Set test configs""" + self.batch_size = 16 + self.in_features = 32 + self.out_features = 64 + self.global_dtype = 'bfloat16' + self.rtol = 1e-2 + self.atol = 1e-2 + self.fp8 = True + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/paddle/parallel_tests/transformer_tp.py b/tests/paddle/parallel_tests/transformer_tp.py new file mode 100644 index 0000000000..69fef08d56 --- /dev/null +++ b/tests/paddle/parallel_tests/transformer_tp.py @@ -0,0 +1,151 @@ +# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. 
+"""Unittest for Transformer layer in tensor parallel""" + +import unittest + +import paddle +from paddle.distributed import fleet + +from utils import assert_allclose, set_random_seed +import transformer_engine.paddle as te + + +class TestTransformerTp(unittest.TestCase): + """Tests Transformer layer with model parallel in BF16""" + + def setUp(self): + self.set_attr() + self.init_dist_env() + paddle.set_default_dtype(self.global_dtype) + + def init_dist_env(self): + """Init Paddle Fleet environment""" + strategy = fleet.DistributedStrategy() + self.model_parallel_size = 2 + strategy.hybrid_configs = { + "dp_degree": 1, + "mp_degree": self.model_parallel_size, + "pp_degree": 1, + } + fleet.init(is_collective=True, strategy=strategy) + self.hcg = fleet.get_hybrid_communicate_group() + self.tp_group = self.hcg.get_model_parallel_group() + + def set_attr(self): + """Set test configs""" + self.batch_size = 16 + self.hidden_size = 1024 + self.num_heads = 16 + self.ffn_hidden_size = 4096 + self.q_seqlen = 128 + self.kv_seqlen = 128 + self.mask_type = 'padding' + self.layer_type = 'encoder' + self.global_dtype = 'bfloat16' + self.rtol = 5e-2 + self.atol = 5e-2 + self.eps = 1e-3 + self.fp8 = False + + def test_parallel_layer(self): + """Tests parallel Transformer""" + set_random_seed(1024) + common_args = [ + self.hidden_size, + self.ffn_hidden_size, + self.num_heads, + ] + common_kwargs = { + 'layernorm_epsilon': self.eps, + 'hidden_dropout': 0.0, + 'attention_dropout': 0.0, + 'self_attn_mask_type': self.mask_type, + 'layer_type': self.layer_type, + } + layer_tp = te.TransformerLayer(*common_args, **common_kwargs, set_parallel_mode=True) + layer_single = te.TransformerLayer(*common_args, **common_kwargs, set_parallel_mode=False) + + def _get_total_weight(local_weight, tp_group, axis): + total_weight = [] + partial_weight = local_weight.clone().detach() + paddle.distributed.all_gather(total_weight, partial_weight, group=tp_group) + total_weight = paddle.concat(total_weight, axis=axis) + return total_weight + + def _get_weight(obj, weight_names): + for name in weight_names: + obj = getattr(obj, name) + return obj + + def copy_weight(layer_src, layer_dst, partition_mode, weight_names): + weight_src = _get_weight(layer_src, weight_names) + weight_dst = _get_weight(layer_dst, weight_names) + if partition_mode is None: + total_weight = weight_src + elif partition_mode == 'column': + total_weight = _get_total_weight(weight_src, tp_group=self.tp_group, axis=0) + elif partition_mode == 'row': + total_weight = _get_total_weight(weight_src, tp_group=self.tp_group, axis=1) + else: + raise ValueError(f"Partition Mode {partition_mode} is not supported.") + assert weight_dst.shape == total_weight.shape, \ + f"Shapes of src:{total_weight.shape} and dst:{weight_dst.shape} do not match." 
+ weight_dst.copy_(total_weight, True) + + copy_weight(layer_tp, layer_single, None, ['self_attention', 'layernorm_qkv', 'ln_weight']) + copy_weight(layer_tp, layer_single, 'column', ['self_attention', 'layernorm_qkv', 'weight']) + copy_weight(layer_tp, layer_single, 'row', ['self_attention', 'proj', 'weight']) + copy_weight(layer_tp, layer_single, None, ['layernorm_mlp', 'ln_weight']) + copy_weight(layer_tp, layer_single, 'column', ['layernorm_mlp', 'fc1_weight']) + copy_weight(layer_tp, layer_single, 'row', ['layernorm_mlp', 'fc2_weight']) + + optimizer_tp = paddle.optimizer.SGD(learning_rate=0.1, parameters=layer_tp.parameters()) + optimizer_single = paddle.optimizer.SGD(learning_rate=0.1, + parameters=layer_single.parameters()) + + layer_tp = fleet.distributed_model(layer_tp) + optimizer_tp = fleet.distributed_optimizer(optimizer_tp) + + def train_one_step(layer, inp_list, optimizer, fp8_enabled): + with te.fp8_autocast(enabled=fp8_enabled): + out = layer(*inp_list) + loss = out.mean() + loss.backward() + optimizer.step() + optimizer.clear_grad() + return loss + + for _ in range(5): + inp = paddle.uniform([self.batch_size, self.q_seqlen, self.hidden_size], + self.global_dtype) + mask = paddle.zeros(shape=(self.batch_size, 1, self.q_seqlen, self.kv_seqlen), + dtype='bool') + loss_tp = train_one_step(layer_tp, [inp, mask], optimizer_tp, self.fp8) + loss_single = train_one_step(layer_single, [inp, mask], optimizer_single, self.fp8) + assert_allclose(loss_tp, loss_single, rtol=self.rtol, atol=self.atol) + + +class TestTransformerTpFp8(TestTransformerTp): + """Tests Transformer layer with tensor parallelism in FP8""" + + def set_attr(self): + """Set test configs""" + self.batch_size = 16 + self.hidden_size = 1024 + self.num_heads = 16 + self.ffn_hidden_size = 4096 + self.q_seqlen = 128 + self.kv_seqlen = 128 + self.mask_type = 'padding' + self.layer_type = 'encoder' + self.global_dtype = 'bfloat16' + self.rtol = 5e-2 + self.atol = 5e-2 + self.eps = 1e-3 + self.fp8 = True + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/paddle/test_layers.py b/tests/paddle/test_layers.py index 171b9233e7..bb93458230 100644 --- a/tests/paddle/test_layers.py +++ b/tests/paddle/test_layers.py @@ -98,8 +98,8 @@ def test_linear_bf16(bs, in_features, out_features, has_bias, no_dbias, no_dgrad """ Test BF16 Linear """ - rtol = 1e-2 - atol = 1e-2 + rtol = 5e-2 + atol = 5e-2 input_tensor = paddle.uniform(shape=(bs, in_features), dtype=activation_dtype) input_tensor.stop_gradient = no_dgrad @@ -258,8 +258,8 @@ def test_layernorm_linear_bf16(bs, in_features, out_features, has_bias, no_dbias Test BF16 LayerNormLinear Layer """ paddle.set_default_dtype(activation_dtype) - rtol = 1e-2 - atol = 1e-2 + rtol = 5e-2 + atol = 5e-2 input_tensor = paddle.uniform(shape=(bs, in_features), dtype=activation_dtype) input_tensor.stop_gradient = no_dgrad @@ -905,7 +905,7 @@ def test_transformer_decoder_layer(bs, hidden_size, num_heads, ffn_hidden_size, """ paddle.set_default_dtype(math_dtype) rtol = 5e-2 - atol = 5e-2 + atol = 6e-2 eps = 1e-3 encoder_input = paddle.uniform(shape=(bs, q_seqlen, hidden_size), dtype=math_dtype) diff --git a/tests/paddle/test_operators.py b/tests/paddle/test_operators.py index 662978086a..241f96214b 100644 --- a/tests/paddle/test_operators.py +++ b/tests/paddle/test_operators.py @@ -728,8 +728,8 @@ def _get_fused_attention_out(self): return out, q_grad, k_grad, v_grad - @pytest.mark.skipif(paddle.device.cuda.get_device_capability() < (8, 0), - reason="cuDNN fMHA requires Ampere+ GPU") + 
@pytest.mark.skipif(paddle.device.cuda.get_device_capability() not in ((8, 0), (9, 0)), + reason="cuDNN fMHA requires Ampere and Hopper GPU") @pytest.mark.parametrize('b, s, h, d', SELF_ATTN_CASES) @pytest.mark.parametrize('dtype', ['float16', 'bfloat16']) @pytest.mark.parametrize('is_causal_masking', [True, False]) @@ -745,8 +745,8 @@ def test_self_attn_forward_backward(self, b, s, h, d, dtype, is_causal_masking): assert_allclose(k_grad_ref, k_grad, rtol=1e-3, atol=1e-2) assert_allclose(v_grad_ref, v_grad, rtol=1e-3, atol=1e-2) - @pytest.mark.skipif(paddle.device.cuda.get_device_capability() < (8, 0), - reason="cuDNN fMHA requires Ampere+ GPU") + @pytest.mark.skipif(paddle.device.cuda.get_device_capability() not in ((8, 0), (9, 0)), + reason="cuDNN fMHA requires Ampere and Hopper GPU") @pytest.mark.parametrize('b, s_q, s_kv, h, d', CROSS_ATTN_CASES) @pytest.mark.parametrize('dtype', ['float16', 'bfloat16']) def test_cross_attn_forward_backward(self, b, s_q, s_kv, h, d, dtype): diff --git a/tests/paddle/test_parallel.py b/tests/paddle/test_parallel.py new file mode 100644 index 0000000000..d6e02747d1 --- /dev/null +++ b/tests/paddle/test_parallel.py @@ -0,0 +1,89 @@ +# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. +"""Test TE Paddle Parallel""" + +from pathlib import Path +import unittest + +from dist_launcher import TestDistributed +from utils import is_devices_enough + +from transformer_engine.paddle.fp8 import is_fp8_available + +test_root = Path(__file__).resolve().parent +gpu_has_fp8, reason = is_fp8_available() + + +class TestParallelLinear(TestDistributed): + """Test Linear in Parallel mode""" + + @unittest.skipIf(not is_devices_enough(2), "TestParallelLinear needs 2 GPUs") + @unittest.skipIf(not gpu_has_fp8, reason) + def test_linear_tp(self): + """Tests linear with tensor parallel in BF16""" + self.run_2gpu(str(test_root / 'parallel_tests' / 'linear_tp.py')) + + +class TestParallelLayerNormLinear(TestDistributed): + """Test LayerNormLinear in Parallel mode""" + + @unittest.skipIf(not is_devices_enough(2), "TestParallelLayerNormLinear needs 2 GPUs") + @unittest.skipIf(not gpu_has_fp8, reason) + def test_layernorm_linear_tp(self): + """Tests layernorm_linear with tensor parallel in BF16""" + self.run_2gpu(str(test_root / 'parallel_tests' / 'layernorm_linear_tp.py')) + + +class TestParallelLayerNormMLP(TestDistributed): + """Test LayerNormMLP in Parallel mode""" + + @unittest.skipIf(not is_devices_enough(2), "TestParallelLayerNormMLP needs 2 GPUs") + @unittest.skipIf(not gpu_has_fp8, reason) + def test_layernorm_mlp_tp(self): + """Tests layernorm_mlp with tensor parallel in BF16""" + self.run_2gpu(str(test_root / 'parallel_tests' / 'layernorm_mlp_tp.py')) + + +class TestAmaxReduction(TestDistributed): + """Test amax reduction in dp mode""" + + @unittest.skipIf(not is_devices_enough(2), "TestAmaxReduction needs 2 GPUs") + @unittest.skipIf(not gpu_has_fp8, reason) + def test_amax_reduction(self): + """Tests amax reduction""" + self.run_2gpu(str(test_root / 'parallel_tests' / 'amax_reduction.py')) + + +class TestPipelineParallel(TestDistributed): + """Test pipeline parallel""" + + @unittest.skipIf(not is_devices_enough(2), "TestPipelineParallel needs 2 GPUs") + @unittest.skipIf(not gpu_has_fp8, reason) + def test_pipeline_parallel(self): + """Tests pipeline parallel""" + self.run_2gpu(str(test_root / 'parallel_tests' / 'linear_pp.py')) + + +class TestGroupSharding(TestDistributed): + """Test group sharding""" + 
+ @unittest.skipIf(not is_devices_enough(2), "TestGroupSharding needs 2 GPUs") + @unittest.skipIf(not gpu_has_fp8, reason) + def test_group_sharding(self): + """Tests group sharding""" + self.run_2gpu(str(test_root / 'parallel_tests' / 'group_sharding.py')) + + +class TestParallelTransformerLayer(TestDistributed): + """Test Transformer Layer in Parallel mode""" + + @unittest.skipIf(not is_devices_enough(2), "TestParallelTransformerLayer needs 2 GPUs") + @unittest.skipIf(not gpu_has_fp8, reason) + def test_transformer_tp(self): + """Tests Transformer Layer with tensor parallel in BF16""" + self.run_2gpu(str(test_root / 'parallel_tests' / 'transformer_tp.py')) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/paddle/utils.py b/tests/paddle/utils.py index 432b39c2e0..5960cccd3d 100644 --- a/tests/paddle/utils.py +++ b/tests/paddle/utils.py @@ -34,3 +34,21 @@ def assert_allclose(actual, if isinstance(desired, paddle.Tensor): desired = paddle.cast(desired, 'float32').numpy() np.testing.assert_allclose(actual, desired, rtol, atol, equal_nan, err_msg, verbose) + + +def assert_shape(inp, expected_shape): + """Assert the shape of input tensor equals to expected shape""" + assert inp.shape == expected_shape, f"Expected tensor shape: {expected_shape} != " \ + f"actual tensor shape: {inp.shape}" + + +def is_devices_enough(required): + """If the number of device is enough""" + return paddle.device.cuda.device_count() >= required + + +def set_random_seed(seed): + """Set random seed for reproducability.""" + np.random.seed(seed) + paddle.seed(seed) + paddle.distributed.fleet.meta_parallel.model_parallel_random_seed(seed) diff --git a/transformer_engine/paddle/constants.py b/transformer_engine/paddle/constants.py index eac161ec60..cfecd39564 100644 --- a/transformer_engine/paddle/constants.py +++ b/transformer_engine/paddle/constants.py @@ -46,3 +46,7 @@ class FP8BwdTensors(Enum): AttnTypes = ("self", "cross") LayerTypes = ("encoder", "decoder") + +GemmParallelModes = ("row", "column", None) + +dist_group_type = paddle.distributed.collective.Group diff --git a/transformer_engine/paddle/distributed.py b/transformer_engine/paddle/distributed.py new file mode 100644 index 0000000000..5bf51c9274 --- /dev/null +++ b/transformer_engine/paddle/distributed.py @@ -0,0 +1,100 @@ +# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. 
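The parallel test classes in test_parallel.py above only gate on device count and FP8 support; the actual checks live in the standalone scripts under parallel_tests/ that run_2gpu launches as separate processes. A rough sketch of the skeleton such a script typically needs is below; the class name, topology degrees, and the assertion are illustrative assumptions, not taken from this patch.

import unittest

from paddle.distributed import fleet


class TestLinearTp(unittest.TestCase):
    """Hypothetical 2-GPU tensor-parallel test body (illustrative only)."""

    def setUp(self):
        # Build a 2-way model-parallel topology so TE layers can pick up
        # the model-parallel group through the Fleet API.
        strategy = fleet.DistributedStrategy()
        strategy.hybrid_configs = {"dp_degree": 1, "mp_degree": 2, "pp_degree": 1}
        fleet.init(is_collective=True, strategy=strategy)

    def test_mp_rank(self):
        hcg = fleet.get_hybrid_communicate_group()
        self.assertIn(hcg.get_model_parallel_rank(), (0, 1))


if __name__ == '__main__':
    unittest.main()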
+"""Methods needed for distributed training.""" + +from contextlib import contextmanager +from typing import Optional, Union, Tuple + +import paddle + +import paddle.distributed.fleet.base.topology as tp +from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker +from paddle.distributed.fleet.layers.mpu import mp_ops + +from .constants import dist_group_type + +_weight_split_axis = { + 'transformer_engine': { + 'row': 1, + 'column': 0 + }, + 'paddle': { + 'row': 0, + 'column': 1 + } +} + + +def get_tp_group_and_world_size(tp_group: Union[dist_group_type, None], + enable_tp: bool = True) -> Tuple[Union[dist_group_type, None], int]: + """Get TP group and world size using Fleet API""" + if not (paddle.distributed.is_initialized() and enable_tp): + return None, 1 + model_parallel_group = (tp._HYBRID_PARALLEL_GROUP.get_model_parallel_group() + if tp_group is None else tp_group) + world_size = (tp._HYBRID_PARALLEL_GROUP.get_model_parallel_world_size() + if tp_group is None else tp_group.nranks) + return model_parallel_group, world_size + + +@contextmanager +def track_rng_state(enable: bool) -> None: + """ + Applies get_rng_state_tracker().rng_state() to the context. + If not enabled, it does nothing. + """ + if enable: + with get_rng_state_tracker().rng_state(): + yield + else: + yield + + +def set_tensor_dist_attr(tensor: paddle.Tensor, is_parallel: bool, axis: int) -> None: + """Set distributed attributes for the input tensor""" + tensor.is_distributed = is_parallel + if is_parallel: + tensor.split_axis = axis + + +def set_weight_tensor_dist_attr(tensor: paddle.Tensor, is_parallel: bool, + parallel_mode: Optional[str], backend: str) -> None: + """Set distributed attributes for the weight tensor""" + if not is_parallel or parallel_mode is None: + return + set_tensor_dist_attr(tensor, is_parallel, axis=_weight_split_axis[backend][parallel_mode]) + + +def allreduce( + input_: paddle.Tensor, + tp_group: Optional[dist_group_type] = None, +) -> paddle.Tensor: + """All-reduce the input tensor across model parallel group.""" + + # Bypass the function if we are using only 1 GPU. + if tp_group is None or tp_group.nranks == 1: + return input_ + + # All-reduce. + output = mp_ops._mp_allreduce( + input_, + group=tp_group, + use_calc_stream=True, + use_model_parallel=True, + ) + + return output + + +def identity( + input_: paddle.Tensor, + tp_group: Optional[dist_group_type] = None, +) -> paddle.Tensor: + """ + Identity when forward. + Allreduce across model parallel group when backward. + """ + output = mp_ops._c_identity(input_, group=tp_group) + + return output diff --git a/transformer_engine/paddle/fp8.py b/transformer_engine/paddle/fp8.py index bcd7ae2b22..576b8d859c 100644 --- a/transformer_engine/paddle/fp8.py +++ b/transformer_engine/paddle/fp8.py @@ -3,9 +3,8 @@ # See LICENSE for license information. 
"""FP8 utilities for TransformerEngine""" -import copy from contextlib import contextmanager -from typing import Tuple, Optional, Dict, Any +from typing import Tuple, Optional, Dict, Any, Union import numpy as np @@ -13,6 +12,9 @@ import transformer_engine_paddle as tex from transformer_engine.common.recipe import DelayedScaling, Format +from .constants import dist_group_type +from .fp8_buffer import FP8MetaFwdBuffer, FP8MetaBwdBuffer + # FP8 support _is_fp8_available = None _reason_for_no_fp8 = "" @@ -50,21 +52,27 @@ class FP8State: """Stores FP8 state""" def __init__(self): - self.fp8_enabled = False - self.fp8_calibration = False - self.fp8_recipe = None + self._fp8_enabled = False + self._fp8_calibration = False + self._fp8_recipe = None + self._fp8_distributed_group = None + self._is_first_fp8_module = False + self._fp8_autocast_counter = 0 + self._fp8_autocast_depth = 0 + self._fp8_fwd_buffer = FP8MetaFwdBuffer() + self._fp8_bwd_buffer = FP8MetaBwdBuffer() def is_fp8_enabled(self) -> bool: """Is FP8 enabled""" - return self.fp8_enabled + return self._fp8_enabled def is_fp8_calibration(self) -> bool: """Is FP8 calibration""" - return self.fp8_calibration + return self._fp8_calibration def get_fp8_recipe(self) -> DelayedScaling: """Return the fp8 recipe""" - return self.fp8_recipe + return self._fp8_recipe @staticmethod def get_default_fp8_recipe() -> DelayedScaling: @@ -73,6 +81,63 @@ def get_default_fp8_recipe() -> DelayedScaling: """ return DelayedScaling() + def get_autocast_id(self) -> int: + """Returns the number of times of entering the `fp8_autocast` context. + as a unique ID for different training steps.""" + return self._fp8_autocast_counter + + def is_first_fp8_module(self): + """Returns `True` only the first time when called multiple + times from within the same `fp8_autocast` context. 
+ """ + tmp = self._is_first_fp8_module + self._is_first_fp8_module = False + return tmp + + def get_fp8_group(self) -> Union[dist_group_type, None]: + """Return the fp8 group for scale/amax comm""" + return self._fp8_distributed_group + + def get_fp8_fwd_buffer(self) -> FP8MetaFwdBuffer: + """Returns global fp8 forward buffer.""" + return self._fp8_fwd_buffer + + def get_fp8_bwd_buffer(self) -> FP8MetaBwdBuffer: + """Returns global fp8 backward buffer.""" + return self._fp8_bwd_buffer + + def enter( + self, + enabled: bool, + calibrating: bool, + fp8_recipe: Optional[DelayedScaling], + fp8_group: Optional[dist_group_type], + ) -> None: + """Called when entering 'fp8_autocast'""" + self.saved_states = (self._fp8_enabled, self._fp8_calibration, self._fp8_recipe, + self._fp8_distributed_group, self._is_first_fp8_module) + + self._fp8_enabled = enabled + self._fp8_calibration = calibrating + self._fp8_recipe = self.get_default_fp8_recipe() if fp8_recipe is None else fp8_recipe + self._fp8_distributed_group = fp8_group + + if self._fp8_autocast_depth == 0: + self._is_first_fp8_module = True + self._fp8_autocast_counter += 1 + self._fp8_autocast_depth += 1 + + def exit(self): + """Called when exiting 'fp8_autocast'""" + # Restore saved states + (self._fp8_enabled, self._fp8_calibration, self._fp8_recipe, self._fp8_distributed_group, + self._is_first_fp8_module) = self.saved_states + + self._fp8_autocast_depth -= 1 + + if self._fp8_autocast_depth == 0: + self._fp8_fwd_buffer.finalize() + _global_fp8_state = FP8State() @@ -87,25 +152,20 @@ def fp8_autocast( enabled: bool = False, calibrating: bool = False, fp8_recipe: Optional[DelayedScaling] = None, + fp8_group: Optional[dist_group_type] = None, ) -> None: """ Context manager for FP8 usage. """ - - global _global_fp8_state - saved_fp8_state = copy.deepcopy(_global_fp8_state) try: - _global_fp8_state.fp8_enabled = enabled - _global_fp8_state.fp8_calibration = calibrating - _global_fp8_state.fp8_recipe = FP8State.get_default_fp8_recipe( - ) if fp8_recipe is None else fp8_recipe + _global_fp8_state.enter(enabled, calibrating, fp8_recipe, fp8_group) if enabled: fp8_available, reason_for_no_fp8 = is_fp8_available() assert fp8_available, reason_for_no_fp8 yield finally: - _global_fp8_state = saved_fp8_state + _global_fp8_state.exit() def get_fp8_te_dtype(fp8_recipe: DelayedScaling, fprop_tensor: bool = True) -> tex.DType: diff --git a/transformer_engine/paddle/fp8_buffer.py b/transformer_engine/paddle/fp8_buffer.py new file mode 100644 index 0000000000..76b0c9db59 --- /dev/null +++ b/transformer_engine/paddle/fp8_buffer.py @@ -0,0 +1,257 @@ +# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. +"""FP8 meta buffer for FP8 amax reduction""" + +from abc import ABC, abstractmethod +from functools import partial +import os +from typing import Dict, Any, List, Union + +import numpy as np +import paddle + +from .constants import dist_group_type + + +class FP8MetaBufferBase(ABC): + """ + A global buffer that holds FP8 meta for reduction across trainers. 
+ """ + + def __init__(self): + self._data = {} + self._buffer_delete_key = None + self._amax_reduce_wait_func = None + self._dp_amax_reduce_interval = None + self._dp_amax_reduce_idx = 0 + + @staticmethod + @abstractmethod + def _get_meta_tensor_key(): + """Returns scaling key in `fp8_meta`.""" + + @staticmethod + @abstractmethod + def _get_buffer_position_key(): + """Returns module position key in `fp8_meta`.""" + + @staticmethod + @abstractmethod + def _get_autocast_key(): + """Returns autocast id key in `fp8_meta`.""" + + def _get_amax_buffer_key(self, fp8_meta: Dict[str, Any]) -> str: + """Return a key in `_data` for the AMAX storage.""" + return f"AMAX_{fp8_meta[self._get_autocast_key()]}" + + def _execute_deletion(self) -> None: + """Delete the key from global amax buffer.""" + if (self._buffer_delete_key is not None and self._buffer_delete_key in self._data): + del self._data[self._buffer_delete_key] + + def _wait_handle_and_split( + self, + contiguous_amax: paddle.Tensor, + chunk_sizes: List[int], + amax_buffer_key: str, + wait_handle: Union[bool, None], + ) -> None: + """Wait for amax reduction to finish and then copy reduced amax to buffer""" + if wait_handle is not None: + wait_handle.wait() + self._data[amax_buffer_key] = list(contiguous_amax.split(chunk_sizes)) + + def _global_amax_reduction( + self, + fp8_meta: Dict[str, Any], + tp_group: dist_group_type, + tp_size: int, + ) -> None: + """Concatenate, reduce, and split amaxes in the global buffer.""" + + def _reduce_tensor_across_group_op_max(tensor, group, sync_op): + if paddle.distributed.is_initialized(): + wait_handle = paddle.distributed.all_reduce( + tensor, + op=paddle.distributed.ReduceOp.MAX, + group=group, + sync_op=sync_op, + ) + return wait_handle + return None + + amax_buffer_key = self._get_amax_buffer_key(fp8_meta) + # Key already deleted. + if amax_buffer_key not in self._data: + return None + + # Reduce AMAX in DP-domain at an interval. + if self._dp_amax_reduce_interval is None: + self._dp_amax_reduce_interval = int(os.getenv("NVTE_DP_AMAX_REDUCE_INTERVAL", "1")) + + tp_amax_reduce = False + if self._dp_amax_reduce_idx == 0: + reduce_group = fp8_meta["fp8_group"] + else: + tp_amax_reduce = True + self._dp_amax_reduce_idx = (self._dp_amax_reduce_idx + 1) % self._dp_amax_reduce_interval + + if tp_amax_reduce: + if tp_size > 1: + reduce_group = tp_group + else: + return None + + chunk_sizes = [x.shape[0] for x in self._data[amax_buffer_key]] + contiguous_amax = paddle.concat(self._data[amax_buffer_key]) + + wait_handle = _reduce_tensor_across_group_op_max( + contiguous_amax, + reduce_group, + not fp8_meta["async_amax_reduction"], + ) + + return partial( + self._wait_handle_and_split, + contiguous_amax, + chunk_sizes, + amax_buffer_key, + wait_handle, + ) + + def add_amax(self, fp8_meta: Dict[str, Any]) -> None: + """Append `amax_history` to global buffer.""" + buffer_key = self._get_amax_buffer_key(fp8_meta) + fp8_meta_tensor_key = self._get_meta_tensor_key() + buffer_position_key = self._get_buffer_position_key() + + if buffer_key not in self._data: + self._data[buffer_key] = [fp8_meta[fp8_meta_tensor_key].amax_history[0]] + else: + self._data[buffer_key].append(fp8_meta[fp8_meta_tensor_key].amax_history[0]) + + if buffer_position_key not in fp8_meta: + fp8_meta[buffer_position_key] = len(self._data[buffer_key]) - 1 + + # Catch incorrect fp8_autocast usage. 
+ assert fp8_meta[buffer_position_key] == len(self._data[buffer_key]) - 1, \ + "Same module is being invoked more than once inside an `fp8_autocast` " \ + "region when using FP8 with amax reduction. This behavior is currently " \ + "unsupported. For more details and correct usage, please see " \ + "https://github.com/NVIDIA/TransformerEngine/pull/93." + + def copy_amax_from_buffer(self, fp8_meta: Dict[str, Any]) -> None: + """Populate current amax with the correct location from buffer.""" + fp8_meta_tensor_key = self._get_meta_tensor_key() + buffer_position_key = self._get_buffer_position_key() + if buffer_position_key not in fp8_meta: + return + + amax_buffer_key = self._get_amax_buffer_key(fp8_meta) + assert amax_buffer_key in self._data, "TE internal error." + + fp8_meta[fp8_meta_tensor_key].amax_history[0] = self._data[amax_buffer_key][ + fp8_meta[buffer_position_key]] + + def set_for_deletion(self, fp8_meta: Dict[str, Any]) -> None: + """Delete this amax key from global buffer during autocast end.""" + if self._get_autocast_key() not in fp8_meta: + return + self._buffer_delete_key = self._get_amax_buffer_key(fp8_meta) + + def get_amax_reduce_handle(self) -> Union[bool, None]: + """Return AMAX reduction wait handle.""" + return self._amax_reduce_handle + + def wait(self) -> None: + """Wait for reduced amax to be available in buffer.""" + if self._amax_reduce_wait_func is not None: + self._amax_reduce_wait_func() # pylint: disable=not-callable + self._amax_reduce_wait_func = None + + def to_numpy(self) -> Dict[str, List[np.array]]: + """Convert to numpy arrays""" + out = {} + for k, v in self._data.items(): + out[k] = [tensor.numpy() for tensor in v] + return out + + def from_numpy(self, buffer: Dict[str, np.array]) -> None: + """Set buffer values from numpy arrays""" + for k, v in buffer.items(): + self._data[k] = [paddle.to_tensor(arr) for arr in v] + + +class FP8MetaFwdBuffer(FP8MetaBufferBase): + """FP8Meta Buffer for forward""" + + @staticmethod + def _get_meta_tensor_key() -> str: + """Returns scaling key in `fp8_meta`.""" + return "scaling_fwd" + + @staticmethod + def _get_buffer_position_key() -> str: + """Returns module position key in `fp8_meta`.""" + return "global_fp8_buffer_pos_fwd" + + @staticmethod + def _get_autocast_key() -> str: + """Returns module position key in `fp8_meta`.""" + return "autocast_id_fwd" + + def set_for_amax_reduction( + self, + fp8_meta: Dict[str, Any], + tp_group: dist_group_type, + tp_size: int, + ) -> None: + """Sets up the function to call during autocast exit.""" + self._amax_global_reduce_func = partial( + self._global_amax_reduction, + fp8_meta, + tp_group, + tp_size, + ) + + def finalize(self) -> None: + """ + Called at FP8 autocast end. + Performs AMAX reduction and delete unused buffer entries. 
+ """ + if hasattr(self, '_amax_global_reduce_func') and callable(self._amax_global_reduce_func): + self._amax_reduce_wait_func = self._amax_global_reduce_func() + self._execute_deletion() + + +class FP8MetaBwdBuffer(FP8MetaBufferBase): + """FP8Meta Buffer for backward""" + + @staticmethod + def _get_meta_tensor_key() -> str: + """Returns scaling key in `fp8_meta`.""" + return "scaling_bwd" + + @staticmethod + def _get_buffer_position_key() -> str: + """Returns module position key in `fp8_meta`.""" + return "global_fp8_buffer_pos_bwd" + + @staticmethod + def _get_autocast_key() -> str: + """Returns module position key in `fp8_meta`.""" + return "autocast_id_bwd" + + def finalize( + self, + fp8_meta: Dict[str, Any], + tp_group: dist_group_type, + tp_size: int, + ) -> None: + """ + Called at FP8 autocast end in backward. + Performs AMAX reduction and delete unused buffer entries. + """ + self._amax_reduce_wait_func = self._global_amax_reduction(fp8_meta, tp_group, tp_size) + self._execute_deletion() diff --git a/transformer_engine/paddle/layer/attention.py b/transformer_engine/paddle/layer/attention.py index a5aac3566f..565321baad 100644 --- a/transformer_engine/paddle/layer/attention.py +++ b/transformer_engine/paddle/layer/attention.py @@ -4,27 +4,25 @@ """Attntion API""" import math +import os import warnings from typing import Optional, Tuple, Union import paddle import paddle.nn.functional as F -from transformer_engine.paddle.constants import ( - AttnTypes, - TE_DType, -) -from transformer_engine.paddle.cpp_extensions import ( +from .layernorm_linear import LayerNormLinear +from .linear import Linear +from .softmax import FusedScaleMaskSoftmax +from ..constants import AttnTypes, TE_DType, dist_group_type +from ..cpp_extensions import ( fused_attn_fwd_qkvpacked, fused_attn_bwd_qkvpacked, fused_attn_fwd_kvpacked, fused_attn_bwd_kvpacked, ) -from transformer_engine.paddle.utils import (attention_mask_func, mask_to_cu_seqlens) -from .base import TransformerEngineBaseLayer -from .layernorm_linear import LayerNormLinear -from .linear import Linear -from .softmax import FusedScaleMaskSoftmax +from ..distributed import get_tp_group_and_world_size, track_rng_state +from ..utils import attention_mask_func, divide, mask_to_cu_seqlens class FusedAttnFuncPackedQKV(paddle.autograd.PyLayer): @@ -161,9 +159,20 @@ def __init__(self, self.attn_mask_type = attn_mask_type self.attention_dropout = attention_dropout self.attention_type = attention_type - self.backend = backend self.rng_state = paddle.zeros((2,), dtype='int64') self.rng_state.persistable = True + + self.backend = backend + + arch = paddle.device.cuda.get_device_capability() + self.is_fused_attn_supported = arch in ((8, 0), (9, 0)) + self.enable_fused_attn = int(os.getenv("NVTE_FUSED_ATTN", + "0")) and self.is_fused_attn_supported + + if not self.enable_fused_attn and backend == 'transformer_engine': + # FMHA is not enabled, falling back to Paddle backend + self.backend = 'paddle' + if self.backend != 'transformer_engine': self.scale_mask_softmax = FusedScaleMaskSoftmax(attn_mask_type, attention_mask_func, @@ -343,7 +352,7 @@ def _pd_forward( return out -class MultiHeadAttention(TransformerEngineBaseLayer): +class MultiHeadAttention(paddle.nn.Layer): """Attention w/ QKV and Proj Gemms Parameters @@ -390,6 +399,8 @@ def __init__( input_layernorm: bool = False, attention_type: str = "self", zero_centered_gamma: bool = False, + set_parallel_mode: bool = False, + tp_group: Optional[dist_group_type] = None, backend: str = 'transformer_engine', ) -> 
None: super().__init__() @@ -403,11 +414,19 @@ def __init__( assert attention_type in AttnTypes, f"attention_type {attention_type} not supported" + self.tp_group, self.tp_size = get_tp_group_and_world_size(tp_group, + enable_tp=set_parallel_mode) + self.tensor_parallel = self.tp_size > 1 + self.hidden_size_per_attention_head = hidden_size // num_attention_heads self.num_attention_heads = num_attention_heads norm_factor = math.sqrt(self.hidden_size_per_attention_head) + self.set_parallel_mode = set_parallel_mode self.backend = backend + self.num_attention_heads_per_partition = divide(self.num_attention_heads, self.tp_size) + qkv_parallel_mode = "column" if set_parallel_mode else None + if self.attention_type == "self": if self.input_layernorm: self.layernorm_qkv = LayerNormLinear( @@ -418,6 +437,8 @@ def __init__( bias_attr=self.bias_attr, return_layernorm_output=return_layernorm_output, zero_centered_gamma=zero_centered_gamma, + parallel_mode=qkv_parallel_mode, + tp_group=self.tp_group, backend=self.backend, ) else: @@ -426,6 +447,8 @@ def __init__( 3 * hidden_size, self.weight_attr, self.bias_attr, + parallel_mode=qkv_parallel_mode, + tp_group=self.tp_group, backend=self.backend, ) @@ -439,6 +462,8 @@ def __init__( bias_attr=self.bias_attr, return_layernorm_output=return_layernorm_output, zero_centered_gamma=zero_centered_gamma, + parallel_mode=qkv_parallel_mode, + tp_group=self.tp_group, backend=self.backend, ) else: @@ -447,6 +472,8 @@ def __init__( hidden_size, self.weight_attr, self.bias_attr, + parallel_mode=qkv_parallel_mode, + tp_group=self.tp_group, backend=self.backend, ) self.key_value = Linear( @@ -454,6 +481,8 @@ def __init__( 2 * hidden_size, self.weight_attr, self.bias_attr, + parallel_mode=qkv_parallel_mode, + tp_group=self.tp_group, backend=self.backend, ) @@ -472,6 +501,8 @@ def __init__( hidden_size, self.weight_attr, self.bias_attr, + parallel_mode="row" if set_parallel_mode else None, + tp_group=self.tp_group, backend=self.backend, ) @@ -520,23 +551,26 @@ def forward( mixed_qkv_layer = self.qkv(hidden_states) # [b, s_q, 3 * hidden_size] --> [b, s_q, 3, num_heads, head_size] - mixed_qkv_layer = mixed_qkv_layer.reshape( - shape=[0, 0, 3, self.num_attention_heads, self.hidden_size_per_attention_head]) - - context_layer = self.core_attention( - query_layer=mixed_qkv_layer, - key_value_layer=None, - attention_mask=attention_mask, - core_attention_bias_type=core_attention_bias_type, - core_attention_bias=core_attention_bias, - set_zero=set_zero, - ) + mixed_qkv_layer = mixed_qkv_layer.reshape(shape=[ + 0, 0, 3, self.num_attention_heads_per_partition, self.hidden_size_per_attention_head + ]) + + with track_rng_state(enable=self.tensor_parallel): + context_layer = self.core_attention( + query_layer=mixed_qkv_layer, + key_value_layer=None, + attention_mask=attention_mask, + core_attention_bias_type=core_attention_bias_type, + core_attention_bias=core_attention_bias, + set_zero=set_zero, + ) else: # cross attention mixed_kv_layer = self.key_value(encoder_output) # [b, s_kv, 2 * hidden_size] --> [b, s_kv, 2, num_heads, head_size] - mixed_kv_layer = mixed_kv_layer.reshape( - shape=[0, 0, 2, self.num_attention_heads, self.hidden_size_per_attention_head]) + mixed_kv_layer = mixed_kv_layer.reshape(shape=[ + 0, 0, 2, self.num_attention_heads_per_partition, self.hidden_size_per_attention_head + ]) if self.input_layernorm: layernorm_query_outputs = self.layernorm_query(hidden_states) @@ -547,16 +581,18 @@ def forward( else: query_layer = self.query_layer(hidden_states) - query_layer = 
query_layer.reshape( - shape=[0, 0, self.num_attention_heads, self.hidden_size_per_attention_head]) - context_layer = self.core_attention( - query_layer=query_layer, - key_value_layer=mixed_kv_layer, - attention_mask=attention_mask, - core_attention_bias_type=core_attention_bias_type, - core_attention_bias=core_attention_bias, - set_zero=set_zero, - ) + query_layer = query_layer.reshape(shape=[ + 0, 0, self.num_attention_heads_per_partition, self.hidden_size_per_attention_head + ]) + with track_rng_state(enable=self.tensor_parallel): + context_layer = self.core_attention( + query_layer=query_layer, + key_value_layer=mixed_kv_layer, + attention_mask=attention_mask, + core_attention_bias_type=core_attention_bias_type, + core_attention_bias=core_attention_bias, + set_zero=set_zero, + ) context_layer = paddle.reshape(context_layer, [0, 0, context_layer.shape[2] * context_layer.shape[3]]) diff --git a/transformer_engine/paddle/layer/base.py b/transformer_engine/paddle/layer/base.py index 5e16fda098..0f5a1af65c 100644 --- a/transformer_engine/paddle/layer/base.py +++ b/transformer_engine/paddle/layer/base.py @@ -5,6 +5,7 @@ from abc import ABC, abstractmethod from contextlib import contextmanager +import os import pickle from typing import Generator, Dict, Tuple, Union, Any @@ -14,7 +15,7 @@ from paddle.fluid import core from paddle.fluid.framework import _dygraph_tracer -from ..constants import FP8BwdTensors +from ..constants import FP8BwdTensors, dist_group_type from ..cpp_extensions import cast_transpose, cast_transpose_bgrad, cast_to_fp8 from ..fp8 import ( FP8State, @@ -24,7 +25,6 @@ get_fp8_te_dtype, ) from ..profile import nvtx_range -from ..utils import get_bias_dtype, cast_if_needed _2X_ACC_FPROP = False _2X_ACC_DGRAD = True @@ -61,9 +61,15 @@ def __init__(self) -> None: self.fp8_calibration = False self.fp8_meta = {} self.fp8_meta["fp8_checkpoint"] = False + self.fp8_meta["fp8_group"] = None self.fp8_meta["recipe"] = FP8State.get_default_fp8_recipe() self.fp8_meta["scaling_fwd"] = FP8TensorMeta(is_forward=True) self.fp8_meta["scaling_bwd"] = FP8TensorMeta(is_forward=False) + self.tp_group = None + self.tp_size = 1 + self.fp8_meta["autocast_id_fwd_stack"] = [] + self.fp8_meta["async_amax_reduction"] = bool( + int(os.getenv("NVTE_ASYNC_AMAX_REDUCTION", "0"))) def set_activation_dtype(self, inp: paddle.Tensor) -> None: """Get activation data type for AMP.""" @@ -102,18 +108,20 @@ def set_activation_dtype(self, inp: paddle.Tensor) -> None: # assume FP8 execution. def fp8_init(self, num_gemms: int = 1) -> None: """Initialize fp8 related metadata and tensors during fprop.""" - state = get_global_fp8_state() - self.fp8_enabled = state.is_fp8_enabled() - self.fp8_calibration = state.is_fp8_calibration() + global_fp8_state = get_global_fp8_state() + self.fp8_enabled = global_fp8_state.is_fp8_enabled() + self.fp8_calibration = global_fp8_state.is_fp8_calibration() self.fp8_meta["fp8_checkpoint"] = self.fp8_enabled or self.fp8_calibration if self.fp8_enabled or self.fp8_calibration: # FP8 init has already been run and recipe is the same, don't do anything. 
- if self.fp8_initialized and state.get_fp8_recipe() == self.fp8_meta["recipe"]: + if self.fp8_initialized and global_fp8_state.get_fp8_recipe( + ) == self.fp8_meta["recipe"]: return # Set FP8, recipe, and other FP8 metadata - self.fp8_meta["recipe"] = state.get_fp8_recipe() + self.fp8_meta["recipe"] = global_fp8_state.get_fp8_recipe() + self.fp8_meta["fp8_group"] = global_fp8_state.get_fp8_group() # Set FP8_MAX per tensor according to recipe self.fp8_meta["fp8_max_fwd"] = self.fp8_meta["recipe"].fp8_format.value.max_fwd @@ -136,6 +144,8 @@ def _get_fp8_state(self) -> paddle.Tensor: state = {} state["scaling_fwd"] = self.fp8_meta["scaling_fwd"].to_numpy() state["scaling_bwd"] = self.fp8_meta["scaling_bwd"].to_numpy() + state["global_fp8_fwd_buffer"] = get_global_fp8_state().get_fp8_fwd_buffer().to_numpy() + state["global_fp8_bwd_buffer"] = get_global_fp8_state().get_fp8_bwd_buffer().to_numpy() # Store other pickelable values. extra = {} for k, v in self.fp8_meta.items(): @@ -179,6 +189,12 @@ def _set_fp8_state(self, state: paddle.Tensor) -> None: self.fp8_meta["scaling_fwd"].from_numpy(state["scaling_fwd"]) self.fp8_meta["scaling_bwd"].from_numpy(state["scaling_bwd"]) + # Restore global FP8 buffer states. + global_fp8_fwd_buffer = get_global_fp8_state().get_fp8_fwd_buffer() + global_fp8_bwd_buffer = get_global_fp8_state().get_fp8_bwd_buffer() + global_fp8_fwd_buffer.from_numpy(state["global_fp8_fwd_buffer"]) + global_fp8_bwd_buffer.from_numpy(state["global_fp8_bwd_buffer"]) + # Load extra items. self.fp8_meta.update(state["extra_fp8_variables"]) self.fp8_meta["recipe"].amax_history_len = self.fp8_meta["scaling_fwd"].amax_history.shape[ @@ -210,9 +226,22 @@ def prepare_forward( # Previous iteration was grad_enabled if self.fp8_meta.get("update_amax_and_scale_fwd", False): - amax_and_scale_update(self.fp8_meta, True) + global_fp8_fwd_buffer = get_global_fp8_state().get_fp8_fwd_buffer() + global_fp8_fwd_buffer.wait() + if self.fp8_meta["recipe"].reduce_amax: + global_fp8_fwd_buffer.copy_amax_from_buffer(self.fp8_meta) + amax_and_scale_update(self.fp8_meta, True) + global_fp8_fwd_buffer.set_for_deletion(self.fp8_meta) + else: + amax_and_scale_update(self.fp8_meta, True) if self.fp8_enabled and self.training: + # Setup for amax reduction + if self.fp8_meta["recipe"].reduce_amax: + global_fp8_state = get_global_fp8_state() + self.fp8_meta["first_module"] = global_fp8_state.is_first_fp8_module() + self.fp8_meta["autocast_id_fwd"] = global_fp8_state.get_autocast_id() + self.fp8_meta["autocast_id_fwd_stack"].append(self.fp8_meta["autocast_id_fwd"]) self.fp8_meta["update_amax_and_scale_fwd"] = True else: self.fp8_meta["update_amax_and_scale_fwd"] = False @@ -220,18 +249,47 @@ def prepare_forward( with nvtx_range(self.__class__.__name__ + " forward"): yield inp + if self.fp8_enabled and self.training and self.fp8_meta["recipe"].reduce_amax: + global_fp8_state = get_global_fp8_state() + global_fp8_fwd_buffer = global_fp8_state.get_fp8_fwd_buffer() + global_fp8_fwd_buffer.add_amax(self.fp8_meta) + global_fp8_fwd_buffer.set_for_amax_reduction( + self.fp8_meta, + self.tp_group, + self.tp_size, + ) + @staticmethod @contextmanager def prepare_backward(fp8_enabled: bool, fp8_meta: Dict[str, Any], + tp_group: dist_group_type, + tp_size: int, name: str = "") -> Generator[None, None, None]: """Checks and prep for BWD.""" if fp8_enabled: - amax_and_scale_update(fp8_meta, False) + global_fp8_state = get_global_fp8_state() + global_fp8_bwd_buffer = global_fp8_state.get_fp8_bwd_buffer() + 
global_fp8_bwd_buffer.wait() + + if fp8_meta["recipe"].reduce_amax: + global_fp8_bwd_buffer.copy_amax_from_buffer(fp8_meta) + amax_and_scale_update(fp8_meta, False) + global_fp8_bwd_buffer.set_for_deletion(fp8_meta) + + # Get new backward key. + fp8_meta["autocast_id_bwd"] = fp8_meta["autocast_id_fwd_stack"].pop(0) + else: + amax_and_scale_update(fp8_meta, False) with nvtx_range(name + " backward"): yield + if fp8_enabled and fp8_meta["recipe"].reduce_amax: + global_fp8_bwd_buffer.add_amax(fp8_meta) + if fp8_meta["first_module"]: + global_fp8_bwd_buffer.finalize(fp8_meta, tp_group, tp_size) + @staticmethod def grad_output_preprocess( ctx, grad_output: paddle.Tensor) -> Tuple[Union[paddle.Tensor, None], ...]: @@ -258,8 +316,6 @@ def grad_output_preprocess( FP8BwdTensors.GRAD_OUTPUT1, fp8_dtype_backward, ) - bias_dtype = get_bias_dtype(ctx.activation_dtype) - bgrad = cast_if_needed(bgrad, bias_dtype) else: if not ctx.fp8_meta["recipe"].override_linear_precision.wgrad: grad_output_c, grad_output_t = cast_transpose( diff --git a/transformer_engine/paddle/layer/layernorm.py b/transformer_engine/paddle/layer/layernorm.py index 3f0b8c4a50..89c03ee25c 100644 --- a/transformer_engine/paddle/layer/layernorm.py +++ b/transformer_engine/paddle/layer/layernorm.py @@ -31,7 +31,7 @@ def forward( zero_centered_gamma: bool, ) -> paddle.Tensor: # Make sure input dimensions are compatible - in_features = ln_weight.numel() + in_features = ln_weight.shape[0] assert inp.shape[-1] == in_features, "LayerNorm not possible" inputmat = inp.reshape((-1, in_features)) diff --git a/transformer_engine/paddle/layer/layernorm_linear.py b/transformer_engine/paddle/layer/layernorm_linear.py index 608f02a6ff..285cf4609a 100644 --- a/transformer_engine/paddle/layer/layernorm_linear.py +++ b/transformer_engine/paddle/layer/layernorm_linear.py @@ -4,7 +4,7 @@ """LayerNormLinear API""" import os -from typing import Union, Tuple, Dict, Any +from typing import Union, Tuple, Dict, Any, Optional import paddle import paddle.nn.functional as F @@ -21,9 +21,22 @@ from .base import TransformerEngineBaseLayer from .linear import _linear_fwd, _linear_bwd -from ..constants import TE_DType, FP8FwdTensors, FP8BwdTensors +from ..constants import TE_DType, FP8FwdTensors, FP8BwdTensors, GemmParallelModes, dist_group_type +from ..distributed import ( + allreduce, + get_tp_group_and_world_size, + identity, + track_rng_state, + set_tensor_dist_attr, + set_weight_tensor_dist_attr, +) from ..fp8 import get_fp8_te_dtype -from ..utils import cast_if_needed, cast_if_needed_inplace, assert_dim_for_fp8_forward_exec +from ..utils import ( + assert_dim_for_fp8_forward_exec, + cast_if_needed, + cast_if_needed_inplace, + divide, +) __all__ = ["LayerNormLinear", "_layernorm_fwd_fp8_cast", "_layernorm_bwd"] @@ -128,9 +141,13 @@ def forward( fwd_ln_sm_margin: int, bwd_ln_sm_margin: int, zero_centered_gamma: bool, + parallel_mode: Union[str, None], + tensor_parallel: bool, + tp_group: Union[dist_group_type, None], + tp_size: int, ) -> Union[Tuple[paddle.Tensor, ...], paddle.Tensor]: # Make sure input dimensions are compatible - in_features = ln_weight.numel() + in_features = ln_weight.shape[0] assert inp.shape[-1] == in_features, "GEMM not possible" inputmat = inp.reshape((-1, in_features)) if fp8_enabled: @@ -169,6 +186,9 @@ def forward( fp8_calibration, fp8_meta, activation_dtype, + parallel_mode, + tensor_parallel, + tp_group, is_grad_enabled, ) @@ -192,6 +212,10 @@ def forward( ctx.return_layernorm_output = return_layernorm_output ctx.bwd_ln_sm_margin = 
bwd_ln_sm_margin ctx.zero_centered_gamma = zero_centered_gamma + ctx.parallel_mode = parallel_mode + ctx.tensor_parallel = tensor_parallel + ctx.tp_group = tp_group + ctx.tp_size = tp_size ctx.requires_dgrad = not inp.stop_gradient ctx.requires_bgrad = use_bias and not bias.stop_gradient ctx.requires_ln_bgrad = not ln_bias.stop_gradient @@ -208,6 +232,8 @@ def backward( ...]) -> Tuple[Union[paddle.Tensor, None], ...]: with TransformerEngineBaseLayer.prepare_backward(ctx.fp8_enabled, ctx.fp8_meta, + ctx.tp_group, + ctx.tp_size, name="_LayerNormLinear"): ( inputmat, @@ -262,6 +288,9 @@ def backward( ctx.fp8_meta, True, # Always compute dgrad to feed into LayerNorm bwd ctx.activation_dtype, + ctx.parallel_mode, + ctx.tensor_parallel, + ctx.tp_group, ) if not ctx.fp8_enabled: @@ -307,6 +336,8 @@ def __init__( bias_attr: Union[paddle.ParamAttr, None, bool] = None, return_layernorm_output: bool = False, zero_centered_gamma: bool = False, + parallel_mode: Optional[str] = None, + tp_group: Union[dist_group_type, None] = None, backend: str = 'transformer_engine', ) -> None: super().__init__() @@ -322,9 +353,23 @@ def __init__( self._bias_attr = bias_attr self._dtype = self._helper.get_default_dtype() + # Set parallel configs + self.tp_group, self.tp_size = get_tp_group_and_world_size(tp_group, + enable_tp=parallel_mode + is not None) + self.tensor_parallel = self.tp_size > 1 + self.parallel_mode = parallel_mode + assert (self.parallel_mode + in GemmParallelModes), f"parallel_mode {parallel_mode} not supported" + + if self.parallel_mode == "column": + self.out_features = divide(self.out_features, self.tp_size) + elif self.parallel_mode == "row": + self.in_features = divide(self.in_features, self.tp_size) + # LayerNorm weights self.ln_weight = self.create_parameter( - shape=[in_features], + shape=[self.in_features], attr=paddle.ParamAttr(initializer=Constant( value=0.0 if self.zero_centered_gamma else 1.0)), dtype=self._dtype, @@ -332,34 +377,48 @@ def __init__( ) self.ln_bias = self.create_parameter( - shape=[in_features], + shape=[self.in_features], attr=paddle.ParamAttr(initializer=Constant(value=0.0)), dtype=self._dtype, is_bias=True, ) - # Linear weights - self.weight = self.create_parameter( - shape=[out_features, in_features] - if self.backend == 'transformer_engine' else [in_features, out_features], - attr=self._weight_attr, - dtype=self._dtype, - is_bias=False, - ) + # Initialize Linear weight parameter + with track_rng_state(enable=self.tensor_parallel): + # TE linear weight is in column major + self.weight = self.create_parameter( + shape=[self.out_features, self.in_features] + if self.backend == 'transformer_engine' else [self.in_features, self.out_features], + attr=self._weight_attr, + dtype=self._dtype, + is_bias=False, + ) + set_weight_tensor_dist_attr(self.weight, self.tensor_parallel, self.parallel_mode, + self.backend) + # Initialize Linear bias parameter self.has_bias = self._bias_attr is not False use_default_bias = self._bias_attr is None or self._bias_attr is True if self.has_bias: self.bias = self.create_parameter( - shape=[out_features], + shape=[self.out_features], attr=self._bias_attr if not use_default_bias else paddle.ParamAttr( initializer=Constant(value=0.0)), dtype=self._dtype, is_bias=True, ) + if parallel_mode == "column": + set_tensor_dist_attr(self.bias, self.tensor_parallel, axis=0) else: self.bias = None + # For RPL, bias has to be added after TP collectives + # So it cannot be fused with the GEMM + if self.parallel_mode == "row" and self.tensor_parallel and 
self.has_bias: + self.gemm_bias_fused_add = False + else: + self.gemm_bias_fused_add = True + # These many SMs are subtracted from the total SM count when calling forward # and backward LayerNorm C APIs. These envvars can be used to prevent the LN # kernels from using all SMs in the device. This is useful for cases such as @@ -385,8 +444,8 @@ def _te_forward( self.ln_weight, self.ln_bias, self.weight, - self.bias, - self.has_bias, + self.bias if self.gemm_bias_fused_add else None, + self.has_bias and self.gemm_bias_fused_add, self.eps, self.fp8_enabled, self.fp8_calibration, @@ -397,10 +456,19 @@ def _te_forward( self.fwd_ln_sm_margin, self.bwd_ln_sm_margin, self.zero_centered_gamma, + self.parallel_mode, + self.tensor_parallel, + self.tp_group, + self.tp_size, ) if self.return_layernorm_output: out, ln_out = out + + if not self.gemm_bias_fused_add: + out = out + cast_if_needed_inplace(self.bias, self.activation_dtype) + + if self.return_layernorm_output: return out, ln_out return out @@ -418,7 +486,12 @@ def _pd_forward( weight=self.ln_weight, bias=self.ln_bias, epsilon=self.eps) - out = F.linear(ln_out, self.weight, self.bias) + if self.parallel_mode == 'column' and self.tensor_parallel: + ln_out = identity(ln_out, self.tp_group) + out = F.linear(ln_out, self.weight, self.bias if self.gemm_bias_fused_add else None) + if self.parallel_mode == 'row' and self.tensor_parallel: + out = allreduce(out, self.tp_group) + out = out + self.bias if self.bias is not None else out if self.return_layernorm_output: return out, ln_out return out diff --git a/transformer_engine/paddle/layer/layernorm_mlp.py b/transformer_engine/paddle/layer/layernorm_mlp.py index 6d725114b0..9b89d05d47 100644 --- a/transformer_engine/paddle/layer/layernorm_mlp.py +++ b/transformer_engine/paddle/layer/layernorm_mlp.py @@ -4,25 +4,38 @@ """LayerNormMLP API""" import os -from typing import Union, Tuple, Dict, Any +from typing import Union, Tuple, Dict, Any, Optional import paddle import paddle.nn.functional as F from paddle.nn.initializer import Constant +from .base import TransformerEngineBaseLayer +from .layernorm_linear import _layernorm_fwd_fp8_cast, _layernorm_bwd +from .linear import _linear_fwd_fp8, _linear_fwd_non_fp8, _linear_bwd_fp8, _linear_bwd_non_fp8 +from ..constants import TE_DType, FP8FwdTensors, FP8BwdTensors, dist_group_type from ..cpp_extensions import ( cast_from_fp8, dgelu_cast_transpose_bgrad_fp8, gelu_fp8, transpose, ) - -from .base import TransformerEngineBaseLayer -from .layernorm_linear import _layernorm_fwd_fp8_cast, _layernorm_bwd -from .linear import _linear_fwd_fp8, _linear_fwd_non_fp8, _linear_bwd_fp8, _linear_bwd_non_fp8 -from ..constants import TE_DType, FP8FwdTensors, FP8BwdTensors +from ..distributed import ( + allreduce, + get_tp_group_and_world_size, + identity, + track_rng_state, + set_tensor_dist_attr, + set_weight_tensor_dist_attr, +) from ..fp8 import get_fp8_te_dtype -from ..utils import cast_if_needed, assert_dim_for_fp8_forward_exec, get_paddle_act_func +from ..utils import ( + assert_dim_for_fp8_forward_exec, + cast_if_needed, + cast_if_needed_inplace, + divide, + get_paddle_act_func, +) __all__ = ["LayerNormMLP"] @@ -43,7 +56,11 @@ def _mlp_forward( fp8_calibration: bool, fp8_meta: Dict[str, Any], activation_dtype: paddle.dtype, + activation: str, is_grad_enabled: bool, + set_parallel_mode: bool, + tensor_parallel: bool, + tp_group: Union[dist_group_type, None], ): if fp8_enabled: fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) @@ -56,6 +73,9 @@ def 
_mlp_forward( use_fc1_bias, fp8_meta, activation_dtype, + 'column' if set_parallel_mode else None, + tensor_parallel, + tp_group, is_grad_enabled, ) @@ -75,6 +95,9 @@ def _mlp_forward( use_fc2_bias, fp8_meta, activation_dtype, + 'row' if set_parallel_mode else None, + tensor_parallel, + tp_group, is_grad_enabled, ) else: @@ -88,7 +111,10 @@ def _mlp_forward( fp8_calibration, fp8_meta, activation_dtype, - activation='gelu', + 'column' if set_parallel_mode else None, + tensor_parallel, + tp_group, + activation=activation, ) fc2_out = _linear_fwd_non_fp8( @@ -101,6 +127,9 @@ def _mlp_forward( fp8_calibration, fp8_meta, activation_dtype, + 'row' if set_parallel_mode else None, + tensor_parallel, + tp_group, ) return ( fc1_out, @@ -136,6 +165,9 @@ def _mlp_backward( requires_dgrad: bool, activation_dtype: paddle.dtype, activation: str, + set_parallel_mode: bool, + tensor_parallel: bool, + tp_group: Union[dist_group_type, None], ): ( fc1_dgrad, @@ -179,6 +211,9 @@ def _mlp_backward( True, requires_fc2_wgrad, activation_dtype, + 'row' if set_parallel_mode else None, + tensor_parallel, + tp_group, ) # GELU Bwd @@ -193,7 +228,7 @@ def _mlp_backward( if requires_fc1_bgrad: fc1_bgrad = fc1_bgrad_ - # FC2 Bwd + # FC1 Bwd requires_fc1_wgrad = not fc1_weight.stop_gradient dgelu_no_fp8, fc1_input_no_fp8, fc1_input_t = None, None, None if requires_fc1_wgrad: @@ -231,6 +266,9 @@ def _mlp_backward( requires_dgrad, requires_fc1_wgrad, activation_dtype, + 'column' if set_parallel_mode else None, + tensor_parallel, + tp_group, ) else: dgelu, fc2_wgrad, fc2_bgrad = _linear_bwd_non_fp8( @@ -240,6 +278,9 @@ def _mlp_backward( requires_fc2_bgrad, True, activation_dtype, + 'row' if set_parallel_mode else None, + tensor_parallel, + tp_group, gelu_input=fc1_out, activation=activation, ) @@ -250,6 +291,9 @@ def _mlp_backward( requires_fc1_bgrad, requires_dgrad, activation_dtype, + 'column' if set_parallel_mode else None, + tensor_parallel, + tp_group, ) return ( fc1_dgrad, @@ -286,9 +330,13 @@ def forward( bwd_ln_sm_margin: int, zero_centered_gamma: bool, activation: str, + set_parallel_mode: bool, + tensor_parallel: bool, + tp_group: Union[dist_group_type, None], + tp_size: int, ) -> Union[Tuple[paddle.Tensor, ...], paddle.Tensor]: # Make sure input dimensions are compatible - in_features = ln_weight.numel() + in_features = ln_weight.shape[0] assert inp.shape[-1] == in_features, "GEMM not possible" inputmat = inp.reshape((-1, in_features)) if fp8_enabled: @@ -341,7 +389,11 @@ def forward( fp8_calibration, fp8_meta, activation_dtype, + activation, is_grad_enabled, + set_parallel_mode, + tensor_parallel, + tp_group, ) if is_grad_enabled: @@ -369,6 +421,10 @@ def forward( ctx.return_layernorm_output = return_layernorm_output ctx.bwd_ln_sm_margin = bwd_ln_sm_margin ctx.zero_centered_gamma = zero_centered_gamma + ctx.set_parallel_mode = set_parallel_mode + ctx.tensor_parallel = tensor_parallel + ctx.tp_group = tp_group + ctx.tp_size = tp_size ctx.requires_dgrad = not inp.stop_gradient ctx.requires_fc1_bgrad = use_fc1_bias and not fc1_bias.stop_gradient ctx.requires_fc2_bgrad = use_fc2_bias and not fc2_bias.stop_gradient @@ -387,6 +443,8 @@ def backward( ...]) -> Tuple[Union[paddle.Tensor, None], ...]: with TransformerEngineBaseLayer.prepare_backward(ctx.fp8_enabled, ctx.fp8_meta, + ctx.tp_group, + ctx.tp_size, name="_LayerNormMLP"): ( inputmat, @@ -442,6 +500,9 @@ def backward( True, ctx.activation_dtype, ctx.activation, + ctx.set_parallel_mode, + ctx.tensor_parallel, + ctx.tp_group, ) if not ctx.fp8_enabled: # fc2_bias 
is fused with gemm for non-FP8 path @@ -491,6 +552,8 @@ def __init__( activation: str = "gelu", return_layernorm_output: bool = False, zero_centered_gamma: bool = False, + set_parallel_mode: bool = False, + tp_group: Optional[dist_group_type] = None, backend: str = 'transformer_engine', ) -> None: super().__init__() @@ -507,6 +570,17 @@ def __init__( self._bias_attr = bias_attr self._dtype = self._helper.get_default_dtype() + # Set parallel configs + self.tp_group, self.tp_size = get_tp_group_and_world_size(tp_group, + enable_tp=set_parallel_mode) + self.tensor_parallel = self.tp_size > 1 + self.set_parallel_mode = set_parallel_mode + + if self.set_parallel_mode: + self.size_per_partition = divide(self.ffn_hidden_size, self.tp_size) + else: + self.size_per_partition = self.ffn_hidden_size + # LayerNorm weights self.ln_weight = self.create_parameter( shape=[self.hidden_size], @@ -524,36 +598,47 @@ def __init__( ) # FC1 weights - self.fc1_weight = self.create_parameter( - shape=[self.ffn_hidden_size, self.hidden_size] - if self.backend == 'transformer_engine' else [self.hidden_size, self.ffn_hidden_size], - attr=self._weight_attr, - dtype=self._dtype, - is_bias=False, - ) + with track_rng_state(enable=self.tensor_parallel): + self.fc1_weight = self.create_parameter( + shape=[self.size_per_partition, self.hidden_size] if self.backend + == 'transformer_engine' else [self.hidden_size, self.size_per_partition], + attr=self._weight_attr, + dtype=self._dtype, + is_bias=False, + ) + set_weight_tensor_dist_attr(self.fc1_weight, + self.tensor_parallel, + parallel_mode='column', + backend=self.backend) self.has_bias = self._bias_attr is not False - if self._bias_attr is None or self._bias_attr is True: + use_default_bias = self._bias_attr is None or self._bias_attr is True + if use_default_bias: self._bias_attr = paddle.ParamAttr(initializer=Constant(value=0.0)) if self.has_bias: self.fc1_bias = self.create_parameter( - shape=[self.ffn_hidden_size], + shape=[self.size_per_partition], attr=self._bias_attr, dtype=self._dtype, is_bias=True, ) + set_tensor_dist_attr(self.fc1_bias, self.tensor_parallel, axis=0) else: self.fc1_bias = None # FC2 weights self.fc2_weight = self.create_parameter( - shape=[self.hidden_size, self.ffn_hidden_size] - if self.backend == 'transformer_engine' else [self.ffn_hidden_size, self.hidden_size], + shape=[self.hidden_size, self.size_per_partition] if self.backend + == 'transformer_engine' else [self.size_per_partition, self.hidden_size], attr=self._weight_attr, dtype=self._dtype, is_bias=False, ) + set_weight_tensor_dist_attr(self.fc2_weight, + self.tensor_parallel, + parallel_mode='row', + backend=self.backend) if self.has_bias: self.fc2_bias = self.create_parameter( @@ -565,6 +650,13 @@ def __init__( else: self.fc2_bias = None + # For RPL, bias has to be added after TP collectives + # So it cannot be fused with the GEMM + if self.set_parallel_mode and self.tensor_parallel and self.has_bias: + self.gemm_bias_fused_add = False + else: + self.gemm_bias_fused_add = True + # These many SMs are subtracted from the total SM count when calling forward # and backward LayerNorm C APIs. These envvars can be used to prevent the LN # kernels from using all SMs in the device. 
This is useful for cases such as @@ -606,12 +698,20 @@ def _te_forward( self.bwd_ln_sm_margin, self.zero_centered_gamma, self.activation, + self.set_parallel_mode, + self.tensor_parallel, + self.tp_group, + self.tp_size, ) if self.return_layernorm_output: out, ln_out = out - return out, ln_out + if not self.gemm_bias_fused_add: + out = out + cast_if_needed_inplace(self.fc2_bias, self.activation_dtype) + + if self.return_layernorm_output: + return out, ln_out return out def _pd_forward( @@ -628,11 +728,16 @@ def _pd_forward( weight=self.ln_weight, bias=self.ln_bias, epsilon=self.eps) + if self.set_parallel_mode and self.tensor_parallel: + ln_out = identity(ln_out, self.tp_group) fc1_out = F.linear(ln_out, self.fc1_weight, self.fc1_bias) act_func = get_paddle_act_func(self.activation) act_out = act_func(fc1_out) - out = F.linear(act_out, self.fc2_weight, self.fc2_bias) - + out = F.linear(act_out, self.fc2_weight, + self.fc2_bias if self.gemm_bias_fused_add else None) + if self.set_parallel_mode and self.tensor_parallel: + out = allreduce(out, self.tp_group) + out = out + self.fc2_bias if self.fc2_bias is not None else out if self.return_layernorm_output: return out, ln_out return out diff --git a/transformer_engine/paddle/layer/linear.py b/transformer_engine/paddle/layer/linear.py index dc9863e062..ff164067a7 100644 --- a/transformer_engine/paddle/layer/linear.py +++ b/transformer_engine/paddle/layer/linear.py @@ -3,7 +3,7 @@ # See LICENSE for license information. """Linear API""" -from typing import Union, Tuple, Dict, Any +from typing import Union, Tuple, Dict, Any, Optional import paddle import paddle.nn.functional as F @@ -17,13 +17,22 @@ _2X_ACC_WGRAD, ) -from ..fp8 import get_fp8_te_dtype -from ..constants import FP8FwdTensors, FP8BwdTensors +from ..constants import FP8FwdTensors, FP8BwdTensors, GemmParallelModes, dist_group_type from ..cpp_extensions import gemm, fp8_gemm, cast_to_fp8, cast_transpose +from ..distributed import ( + allreduce, + get_tp_group_and_world_size, + identity, + track_rng_state, + set_tensor_dist_attr, + set_weight_tensor_dist_attr, +) +from ..fp8 import get_fp8_te_dtype from ..utils import ( + assert_dim_for_fp8_forward_exec, cast_if_needed, cast_if_needed_inplace, - assert_dim_for_fp8_forward_exec, + divide, get_bias_dtype, ) @@ -39,12 +48,15 @@ def _linear_fwd_fp8( use_bias: bool, fp8_meta: Dict[str, Any], activation_dtype: paddle.dtype, + parallel_mode: Union[str, None], + tensor_parallel: bool, + tp_group: Union[dist_group_type, None], is_grad_enabled: bool, ): """FP8 path of Linear Fwd""" fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) bias_dtype = get_bias_dtype(activation_dtype) - bias = cast_if_needed_inplace(bias, bias_dtype) + bias = cast_if_needed(bias, bias_dtype) if is_grad_enabled: weight_fp8, weight_t_fp8 = cast_transpose( @@ -78,6 +90,10 @@ def _linear_fwd_fp8( use_split_accumulator=_2X_ACC_FPROP, ) + # Row Parallel Linear + if parallel_mode == "row" and tensor_parallel: + out = allreduce(out, tp_group) + return out, weight_t_fp8 @@ -91,6 +107,9 @@ def _linear_fwd_non_fp8( fp8_calibration: bool, fp8_meta: Dict[str, Any], activation_dtype: paddle.dtype, + parallel_mode: Union[str, None], + tensor_parallel: bool, + tp_group: Union[dist_group_type, None], activation: str = "", ): """Non-FP8 path of Linear Fwd""" @@ -123,6 +142,9 @@ def _linear_fwd_non_fp8( return out, gelu_out out, _, _ = outputs + # Row Parallel Linear + if parallel_mode == "row" and tensor_parallel: + out = allreduce(out, tp_group) return out @@ 
-137,6 +159,9 @@ def _linear_fwd( fp8_calibration: bool, fp8_meta: Dict[str, Any], activation_dtype: paddle.dtype, + parallel_mode: Union[str, None], + tensor_parallel: bool, + tp_group: Union[dist_group_type, None], is_grad_enabled: bool, ): if fp8_enabled: @@ -149,6 +174,9 @@ def _linear_fwd( use_bias, fp8_meta, activation_dtype, + parallel_mode, + tensor_parallel, + tp_group, is_grad_enabled, ) else: @@ -162,6 +190,9 @@ def _linear_fwd( fp8_calibration, fp8_meta, activation_dtype, + parallel_mode, + tensor_parallel, + tp_group, ) return ( out, @@ -184,6 +215,9 @@ def _linear_bwd_fp8( requires_dgrad: bool, requires_wgrad: bool, activation_dtype: paddle.dtype, + parallel_mode: Union[str, None], + tensor_parallel: bool, + tp_group: Union[dist_group_type, None], ): dgrad, wgrad = None, None fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) @@ -202,6 +236,9 @@ def _linear_bwd_fp8( get_workspace(), use_split_accumulator=_2X_ACC_DGRAD, ) + if parallel_mode == "column" and tensor_parallel: + dgrad = allreduce(dgrad, tp_group) + if requires_wgrad: if not fp8_meta["recipe"].override_linear_precision.wgrad: wgrad = fp8_gemm( @@ -236,6 +273,9 @@ def _linear_bwd_non_fp8( requires_bgrad: bool, requires_dgrad: bool, activation_dtype: paddle.dtype, + parallel_mode: Union[str, None], + tensor_parallel: bool, + tp_group: Union[dist_group_type, None], gelu_input: Union[paddle.Tensor, None] = None, activation: str = "", ): @@ -255,6 +295,9 @@ def _linear_bwd_non_fp8( gelu_input=gelu_input, grad=True, ) + if parallel_mode == "column" and tensor_parallel: + dgrad = allreduce(dgrad, tp_group) + if requires_wgrad: wgrad, bgrad, _ = gemm( inputmat, @@ -288,6 +331,9 @@ def _linear_bwd( fp8_meta: Dict[str, Any], requires_dgrad: bool, activation_dtype: paddle.dtype, + parallel_mode: Union[str, None], + tensor_parallel: bool, + tp_group: Union[dist_group_type, None], ): dgrad, wgrad, bgrad = None, None, None requires_wgrad = not weight.stop_gradient @@ -307,6 +353,9 @@ def _linear_bwd( requires_dgrad, requires_wgrad, activation_dtype, + parallel_mode, + tensor_parallel, + tp_group, ) else: dgrad, wgrad, bgrad = _linear_bwd_non_fp8( @@ -316,6 +365,9 @@ def _linear_bwd( requires_bgrad, requires_dgrad, activation_dtype, + parallel_mode, + tensor_parallel, + tp_group, ) return dgrad, wgrad, bgrad @@ -335,6 +387,10 @@ def forward( fp8_meta: Dict[str, Any], activation_dtype: paddle.dtype, is_grad_enabled: bool, + parallel_mode: Union[str, None], + tensor_parallel: bool, + tp_group: Union[dist_group_type, None], + tp_size: int, ) -> paddle.Tensor: # Make sure input dimensions are compatible in_features = weight.shape[-1] @@ -385,6 +441,9 @@ def forward( fp8_calibration, fp8_meta, activation_dtype, + parallel_mode, + tensor_parallel, + tp_group, is_grad_enabled, ) @@ -402,6 +461,10 @@ def forward( ctx.fp8_meta = fp8_meta ctx.use_bias = use_bias ctx.inp_shape = inp.shape + ctx.parallel_mode = parallel_mode + ctx.tensor_parallel = tensor_parallel + ctx.tp_group = tp_group + ctx.tp_size = tp_size ctx.requires_dgrad = not inp.stop_gradient ctx.requires_bgrad = use_bias and not bias.stop_gradient @@ -411,6 +474,8 @@ def forward( def backward(ctx, grad_output: paddle.Tensor) -> Tuple[Union[paddle.Tensor, None], ...]: with TransformerEngineBaseLayer.prepare_backward(ctx.fp8_enabled, ctx.fp8_meta, + ctx.tp_group, + ctx.tp_size, name="_Linear"): ( inputmat, @@ -444,6 +509,9 @@ def backward(ctx, grad_output: paddle.Tensor) -> Tuple[Union[paddle.Tensor, None ctx.fp8_meta, ctx.requires_dgrad, 
ctx.activation_dtype, + ctx.parallel_mode, + ctx.tensor_parallel, + ctx.tp_group, ) if not ctx.fp8_enabled: @@ -474,6 +542,8 @@ def __init__( out_features: int, weight_attr: Union[paddle.ParamAttr, None] = None, bias_attr: Union[paddle.ParamAttr, None, bool] = None, + parallel_mode: Optional[str] = None, + tp_group: Union[dist_group_type, None] = None, backend: str = 'transformer_engine', ) -> None: super().__init__() @@ -484,28 +554,56 @@ def __init__( self._bias_attr = bias_attr self._dtype = self._helper.get_default_dtype() - # TE linear weight is in column major - self.weight = self.create_parameter( - shape=[out_features, in_features] - if self.backend == 'transformer_engine' else [in_features, out_features], - attr=self._weight_attr, - dtype=self._dtype, - is_bias=False, - ) + # Set parallel configs + self.tp_group, self.tp_size = get_tp_group_and_world_size(tp_group, + enable_tp=parallel_mode + is not None) + self.tensor_parallel = self.tp_size > 1 + self.parallel_mode = parallel_mode + assert (self.parallel_mode + in GemmParallelModes), f"parallel_mode {parallel_mode} not supported" + + if self.parallel_mode == "column": + self.out_features = divide(self.out_features, self.tp_size) + elif self.parallel_mode == "row": + self.in_features = divide(self.in_features, self.tp_size) + + # Initialize weight parameter + with track_rng_state(enable=self.tensor_parallel): + # TE linear weight is in column major + self.weight = self.create_parameter( + shape=[self.out_features, self.in_features] + if self.backend == 'transformer_engine' else [self.in_features, self.out_features], + attr=self._weight_attr, + dtype=self._dtype, + is_bias=False, + ) + set_weight_tensor_dist_attr(self.weight, self.tensor_parallel, self.parallel_mode, + self.backend) + # Initialize bias parameter self.has_bias = self._bias_attr is not False use_default_bias = self._bias_attr is None or self._bias_attr is True if self.has_bias: self.bias = self.create_parameter( - shape=[out_features], + shape=[self.out_features], attr=self._bias_attr if not use_default_bias else paddle.ParamAttr( initializer=Constant(value=0.0)), dtype=self._dtype, is_bias=True, ) + if parallel_mode == "column": + set_tensor_dist_attr(self.bias, self.tensor_parallel, axis=0) else: self.bias = None + # For RPL, bias has to be added after TP collectives + # So it cannot be fused with the GEMM + if self.parallel_mode == "row" and self.tensor_parallel and self.has_bias: + self.gemm_bias_fused_add = False + else: + self.gemm_bias_fused_add = True + def _te_forward( self, inp: paddle.Tensor, @@ -521,15 +619,22 @@ def _te_forward( out = _Linear.apply( self.weight, inp, - self.bias, - self.has_bias, + self.bias if self.gemm_bias_fused_add else None, + self.has_bias and self.gemm_bias_fused_add, self.fp8_enabled, self.fp8_calibration, self.fp8_meta, self.activation_dtype, paddle.is_grad_enabled(), + self.parallel_mode, + self.tensor_parallel, + self.tp_group, + self.tp_size, ) + if not self.gemm_bias_fused_add: + out = out + cast_if_needed_inplace(self.bias, self.activation_dtype) + return out def _pd_forward( @@ -537,7 +642,13 @@ def _pd_forward( inp: paddle.Tensor, ) -> paddle.Tensor: """Calls Paddle OP""" - return F.linear(inp, self.weight, self.bias) + if self.parallel_mode == 'column' and self.tensor_parallel: + inp = identity(inp, self.tp_group) + out = F.linear(inp, self.weight, self.bias if self.gemm_bias_fused_add else None) + if self.parallel_mode == 'row' and self.tensor_parallel: + out = allreduce(out, self.tp_group) + out = out + self.bias if 
self.bias is not None else out + return out def forward(self, *args, **kwargs): """forward""" diff --git a/transformer_engine/paddle/layer/transformer.py b/transformer_engine/paddle/layer/transformer.py index 6e6afd4ca2..a95b9fcfe1 100644 --- a/transformer_engine/paddle/layer/transformer.py +++ b/transformer_engine/paddle/layer/transformer.py @@ -7,15 +7,11 @@ import paddle -from transformer_engine.paddle.constants import ( - AttnMaskTypes, - LayerTypes, -) -from transformer_engine.paddle.layer import (LayerNormMLP, LayerNorm, MultiHeadAttention) -from .base import TransformerEngineBaseLayer +from . import LayerNormMLP, LayerNorm, MultiHeadAttention +from ..constants import AttnMaskTypes, LayerTypes, dist_group_type -class TransformerLayer(TransformerEngineBaseLayer): +class TransformerLayer(paddle.nn.Layer): r""" TransformerLayer is made up of an attention block and a feedforward network (MLP). This standard layer is based on the paper "Attention Is All You Need". @@ -64,6 +60,16 @@ class TransformerLayer(TransformerEngineBaseLayer): it controls the type used to allocate the initial parameters. Useful when the model is trained with lower precision and the original FP32 parameters would not fit in GPU memory. + + Parallelism parameters + ---------------------- + set_parallel_mode : bool, default = `False` + if set to `True`, QKV and FC1 layers are used as Column Parallel + whereas PROJ and FC2 is used as Row Parallel as described + `here `_. + tp_group : ProcessGroup, default = `None` + tensor parallel process group. + """ def __init__(self, @@ -82,6 +88,8 @@ def __init__(self, layer_type: str = "encoder", zero_centered_gamma: bool = False, activation: str = 'gelu', + set_parallel_mode: bool = False, + tp_group: Optional[dist_group_type] = None, backend: str = 'transformer_engine') -> None: super().__init__() @@ -90,6 +98,8 @@ def __init__(self, self.layer_type = layer_type self.apply_residual_connection_post_layernorm = apply_residual_connection_post_layernorm self.self_attn_mask_type = self_attn_mask_type + self.set_parallel_mode = set_parallel_mode + self.tp_group = tp_group assert (self_attn_mask_type in AttnMaskTypes), f"self_attn_mask_type {self_attn_mask_type} not supported" @@ -107,6 +117,8 @@ def __init__(self, "params_dtype": params_dtype, "return_layernorm_output": apply_residual_connection_post_layernorm, "zero_centered_gamma": zero_centered_gamma, + "set_parallel_mode": set_parallel_mode, + "tp_group": tp_group, "backend": backend, } @@ -136,6 +148,8 @@ def __init__(self, activation=activation, return_layernorm_output=apply_residual_connection_post_layernorm, zero_centered_gamma=zero_centered_gamma, + set_parallel_mode=set_parallel_mode, + tp_group=tp_group, backend=backend, ) From 112f67f6bbb93d2d3e42fb75c16801815f187e95 Mon Sep 17 00:00:00 2001 From: Przemyslaw Tredak Date: Thu, 21 Sep 2023 00:55:08 +0200 Subject: [PATCH 52/68] [pyTorch] Enable the model to change precision between iterations (#414) * Enable the model to be change precision between iterations Signed-off-by: Przemek Tredak * Add test Signed-off-by: Przemek Tredak * Fix for the test Signed-off-by: Przemek Tredak --------- Signed-off-by: Przemek Tredak Co-authored-by: Kirthi Shankar Sivamani --- tests/pytorch/test_sanity.py | 13 +++++++++++++ transformer_engine/pytorch/module/base.py | 3 +-- 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/tests/pytorch/test_sanity.py b/tests/pytorch/test_sanity.py index 21497b417f..65af2f9713 100644 --- a/tests/pytorch/test_sanity.py +++ 
b/tests/pytorch/test_sanity.py @@ -788,3 +788,16 @@ def test_gpt_cuda_graph(dtype, bs, fp8_recipe, model, skip_wgrad, zero_centered_ ) _test_sanity_e2e_cuda_graph(block, bs, dtype, config, fp8_recipe, skip_wgrad) + +def test_model_multiple_cast(): + a = torch.zeros((16,16)).cuda() + m = Linear(16,32) + + y = m(a) + assert y.dtype == torch.float32 + + m.half() + a = a.half() + + y2 = m(a) + assert y2.dtype == torch.float16 diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 0352a7ba2b..82d39eeaf0 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -445,8 +445,7 @@ def set_activation_dtype(self, inp: torch.Tensor) -> None: return # All checks after this have already been performed once, thus skip - # We assume that user doesn't change input types across iterations - if hasattr(self, "activation_dtype"): + if hasattr(self, "activation_dtype") and self.activation_dtype == inp.dtype: return dtype = inp.dtype From 291cb4fcbe97d8711c3bd4b78afb02d8cb440a34 Mon Sep 17 00:00:00 2001 From: zlsh80826 Date: Fri, 22 Sep 2023 10:05:29 +0800 Subject: [PATCH 53/68] [Paddle] Eliminate amax update bubbles by using custom_ops (#436) * Eliminate amax_and_scale_update bubbles Signed-off-by: rewang * Add CUDA check Signed-off-by: rewang --------- Signed-off-by: rewang --- tests/paddle/test_operators.py | 38 ++++++++- transformer_engine/paddle/csrc/custom_ops.cu | 81 +++++++++++++++----- transformer_engine/paddle/fp8.py | 32 ++------ transformer_engine/paddle/fp8_buffer.py | 7 +- 4 files changed, 108 insertions(+), 50 deletions(-) diff --git a/tests/paddle/test_operators.py b/tests/paddle/test_operators.py index c4211a7218..7a2472e4bc 100644 --- a/tests/paddle/test_operators.py +++ b/tests/paddle/test_operators.py @@ -865,13 +865,17 @@ def test_scaled_upper_triang_masked_softmax_fwd_bwd(dtype): assert_allclose(dx_ref, dx, rtol=1e-4, atol=5e-3) -def test_update_scale(): +def test_amax_and_scale_update(): """Test update_scale""" num_gemm = 6 + history_len = 1024 recipe = DelayedScaling() fp8_max = recipe.fp8_format.value.max_fwd - amax_tensor = paddle.rand(shape=[num_gemm], dtype='float32') * fp8_max + amax_history_tensor = paddle.rand(shape=[history_len, num_gemm], dtype='float32') + rolled_history_ref = paddle.roll(amax_history_tensor, -1, axis=0) + rolled_history_ref[0] = 0.0 + amax_tensor = paddle.max(amax_history_tensor, axis=0) scale_tensor = paddle.ones(shape=[num_gemm], dtype='float32') def calc_ref(amax, scale, fp8_max, margin=0): @@ -884,6 +888,32 @@ def calc_ref(amax, scale, fp8_max, margin=0): return sf scale_ref = calc_ref(amax_tensor, scale_tensor, fp8_max, 0.) - scale_actual = tex.update_scale(amax_tensor, scale_tensor, fp8_max, 0.) + scale_inv_ref = 1. 
/ scale_ref - assert_allclose(scale_ref, scale_actual, rtol=1e-5, atol=1e-5) + # Placeholder + scale_actual = paddle.zeros_like(scale_tensor) + scale_inv_actual = paddle.zeros_like(scale_tensor) + + tex.amax_and_scale_update_inplace(_amax_history=amax_history_tensor, + _scale=scale_actual, + _scale_inv=scale_inv_actual, + fp8_max=fp8_max, + margin=0., + amax_compute="max") + + assert_allclose(scale_actual, scale_ref, rtol=1e-7, atol=1e-7) + assert_allclose(scale_inv_actual, scale_inv_ref, rtol=1e-7, atol=1e-7) + assert_allclose(amax_history_tensor, rolled_history_ref, rtol=1e-7, atol=1e-7) + + +def test_update_latest_history(): + """Test update_latest_history""" + num_gemm = 6 + history_len = 1024 + + amax_history_tensor = paddle.rand(shape=[history_len, num_gemm], dtype='float32') + amax = paddle.rand(shape=[num_gemm], dtype='float32') + + tex.update_latest_amax_history_inplace(_history=amax_history_tensor, amax=amax) + + assert_allclose(amax_history_tensor[0], amax, rtol=1e-7, atol=1e-7) diff --git a/transformer_engine/paddle/csrc/custom_ops.cu b/transformer_engine/paddle/csrc/custom_ops.cu index 76f8987306..44e0202e53 100644 --- a/transformer_engine/paddle/csrc/custom_ops.cu +++ b/transformer_engine/paddle/csrc/custom_ops.cu @@ -1019,28 +1019,62 @@ void te_scaled_upper_triang_masked_softmax_backward(paddle::Tensor &output_grads softmax_results.stream()); } -__global__ void UpdateScalesKernel(const float *amax, const float *scale, float margin, - float fp8_max, size_t size, float *scale_out) { +__global__ void UpdateFP8MetaKernel(const float *amax, const float *rolled_amax_history, + float *amax_history, float *scale, float *scale_inv, + float margin, float fp8_max, size_t history_numel, + size_t amax_numel) { int idx = blockIdx.x * blockDim.x + threadIdx.x; - if (idx < size) { + if (idx >= history_numel) { + return; + } + + amax_history[idx] = rolled_amax_history[idx]; + + if (idx < amax_numel) { float exp = floor(log2(fp8_max / amax[idx])) - margin; float sf = round(powf(2.0f, abs(exp))); - sf = ((amax[idx] > 0.0f) && isfinite(amax[idx])) ? sf : scale[idx]; - scale_out[idx] = exp < 0.0f ? 1 / sf : sf; + float scale_reg = scale[idx]; + sf = ((amax[idx] > 0.0f) && isfinite(amax[idx])) ? sf : scale_reg; + scale_reg = exp < 0.0f ? 
1 / sf : sf; + scale[idx] = scale_reg; + scale_inv[idx] = 1.0f / scale_reg; + amax_history[idx] = 0.0f; } } -std::vector update_scale(const paddle::Tensor &amax, const paddle::Tensor &scale, - float fp8_max, float margin) { - const size_t block_size = 512; - size_t size = static_cast(amax.numel()); - size_t num_blocks = (size + block_size - 1) / block_size; - auto scale_out = paddle::empty_like(scale, scale.dtype(), scale.place()); - UpdateScalesKernel<<>>( - amax.data(), scale.data(), margin, fp8_max, size, scale_out.data()); +void amax_and_scale_update_inplace(paddle::Tensor &amax_history, // NOLINT + paddle::Tensor &scale, // NOLINT + paddle::Tensor &scale_inv, // NOLINT + float fp8_max, float margin, const std::string &amax_compute) { + NVTE_CHECK(amax_compute == "max" || amax_compute == "most_recent"); + + paddle::Tensor amax; + + if (amax_compute == "max") { + amax = amax_history.max({0}); + } else { + amax = amax_history.slice(0, 1); + } + + const auto rolled_amax_history = amax_history.roll({-1}, {0}); + + auto size = amax_history.numel(); + constexpr int BLOCK_SIZE = 256; + size_t num_blocks = (size + BLOCK_SIZE - 1) / BLOCK_SIZE; + UpdateFP8MetaKernel<<>>( + amax.data(), rolled_amax_history.data(), amax_history.data(), + scale.data(), scale_inv.data(), margin, fp8_max, amax_history.numel(), + amax.numel()); + NVTE_CHECK_CUDA(cudaGetLastError()); +} - return {scale_out}; +void update_latest_amax_history_inplace(paddle::Tensor &history, // NOLINT + const paddle::Tensor &amax) { + // Copy amax to history[0] + NVTE_CHECK_CUDA(cudaMemcpyAsync(history.data(), amax.data(), + amax.numel() * SizeOf(amax.dtype()), cudaMemcpyDeviceToDevice, + amax.stream())); } } // namespace paddle_ext @@ -1242,8 +1276,17 @@ PD_BUILD_OP(te_scaled_upper_triang_masked_softmax_backward) .SetKernelFn( PD_KERNEL(transformer_engine::paddle_ext::te_scaled_upper_triang_masked_softmax_backward)); -PD_BUILD_OP(update_scale) - .Inputs({"Amax", "Scale"}) - .Outputs({"ScaleOut"}) - .Attrs({"fp8_max: float", "margin: float"}) - .SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::update_scale)); +PD_BUILD_OP(amax_and_scale_update_inplace) + .Inputs({"_amax_history", "_scale", "_scale_inv"}) + .Outputs({"amax_history", "scale", "scale_inv"}) + .SetInplaceMap({{"_amax_history", "amax_history"}, + {"_scale", "scale"}, + {"_scale_inv", "scale_inv"}}) + .Attrs({"fp8_max: float", "margin: float", "amax_compute: std::string"}) + .SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::amax_and_scale_update_inplace)); + +PD_BUILD_OP(update_latest_amax_history_inplace) + .Inputs({"_history", "amax"}) + .Outputs({"history"}) + .SetInplaceMap({{"_history", "history"}}) + .SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::update_latest_amax_history_inplace)); diff --git a/transformer_engine/paddle/fp8.py b/transformer_engine/paddle/fp8.py index e56f1de767..abf347042a 100644 --- a/transformer_engine/paddle/fp8.py +++ b/transformer_engine/paddle/fp8.py @@ -197,30 +197,12 @@ def amax_and_scale_update( fp8_max_key = "fp8_max_fwd" if fwd_update else "fp8_max_bwd" if not callable(amax_compute) and sf_compute is None: - # Obtain amax from history - amax_history = fp8_meta[fp8_meta_tensor_key].amax_history - if amax_compute == "max": - amax = paddle.max(amax_history, axis=0) - else: # amax_compute_algo == "most_recent" - amax = amax_history[0] - - # Update amax history and set next amax to zero - if amax_history.shape[0] > 1: - amax_history = paddle.roll(amax_history, -1, 0) - amax_history[0] = 0.0 - 
fp8_meta[fp8_meta_tensor_key].amax_history = amax_history - - # Update scaling factor - fp8_meta[fp8_meta_tensor_key].scale = tex.update_scale( - amax=amax, - scale=fp8_meta[fp8_meta_tensor_key].scale, - fp8_max=fp8_meta[fp8_max_key], - margin=float(fp8_meta["recipe"].margin)) - - # Update scale_inv - fp8_meta[fp8_meta_tensor_key].scale_inv = \ - 1.0 / fp8_meta[fp8_meta_tensor_key].scale - + tex.amax_and_scale_update_inplace(_amax_history=fp8_meta[fp8_meta_tensor_key].amax_history, + _scale=fp8_meta[fp8_meta_tensor_key].scale, + _scale_inv=fp8_meta[fp8_meta_tensor_key].scale_inv, + fp8_max=fp8_meta[fp8_max_key], + margin=float(fp8_meta["recipe"].margin), + amax_compute=amax_compute) else: raise ValueError("We only support the fp8 recipe with 'max' or 'most_recent' " "amax_compute_algo and default scaling_factor_compute_algo at this " @@ -247,7 +229,7 @@ def prepare(self, num_gemms: bool, amax_history_len: int) -> None: curr_len = self.amax_history.shape[0] num_fp8_tensors = self.amax_history.shape[1] if amax_history_len < curr_len: - self.amax_history = (self.amax_history[:amax_history_len]) + self.amax_history = self.amax_history[:amax_history_len] elif amax_history_len > curr_len: extra_rows = amax_history_len - curr_len self.amax_history = paddle.concat([ diff --git a/transformer_engine/paddle/fp8_buffer.py b/transformer_engine/paddle/fp8_buffer.py index b6f082d69d..93090195a1 100644 --- a/transformer_engine/paddle/fp8_buffer.py +++ b/transformer_engine/paddle/fp8_buffer.py @@ -11,6 +11,7 @@ import numpy as np import paddle +import transformer_engine_paddle as tex from .constants import dist_group_type, RecomputeFunctionNames @@ -152,8 +153,10 @@ def copy_amax_from_buffer(self, fp8_meta: Dict[str, Any]) -> None: amax_buffer_key = self._get_amax_buffer_key(fp8_meta) assert amax_buffer_key in self._data, "TE internal error." 
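For reference, the semantics that the fused `amax_and_scale_update_inplace` op collapses into a single kernel launch can be written out directly: reduce the amax history to one amax per tensor, roll the history, and refresh scale and scale_inv in place. The following is an editor-added NumPy sketch, not part of the patch; it uses the power-of-two scale formula in effect at this point in the series.

.. code-block:: python

    import numpy as np

    def amax_and_scale_update_ref(amax_history, scale, fp8_max, margin=0.0,
                                  amax_compute="max"):
        """Reference semantics of the fused op: returns (history, scale, scale_inv)."""
        # Reduce the history to one amax per tensor ("max" or "most_recent").
        amax = amax_history.max(axis=0) if amax_compute == "max" else amax_history[0]
        # Roll the history and clear the slot for the next iteration.
        history = np.roll(amax_history, -1, axis=0)
        history[0] = 0.0
        # Power-of-two scale (later simplified to (fp8_max / amax) / 2**margin).
        exp = np.floor(np.log2(fp8_max / amax)) - margin
        sf = np.round(2.0 ** np.abs(exp))
        sf = np.where((amax > 0.0) & np.isfinite(amax), sf, scale)
        new_scale = np.where(exp < 0.0, 1.0 / sf, sf)
        return history, new_scale, 1.0 / new_scale
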
- fp8_meta[fp8_meta_tensor_key].amax_history[0] = self._data[amax_buffer_key][ - fp8_meta[buffer_position_key]] + # Copy amax to amax_history[0] + tex.update_latest_amax_history_inplace( + _history=fp8_meta[fp8_meta_tensor_key].amax_history, + amax=self._data[amax_buffer_key][fp8_meta[buffer_position_key]]) def set_for_deletion(self, fp8_meta: Dict[str, Any]) -> None: """Delete this amax key from global buffer during autocast end.""" From a6e1b10f05718c0853792532e9fa556c60a411f3 Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Fri, 22 Sep 2023 23:42:31 -0700 Subject: [PATCH 54/68] Change scaling factor from E8M0 to E8M23 (#427) * Change scaling factor from E8M0 to E8M23 Signed-off-by: Kirthi Shankar Sivamani * fix formula Signed-off-by: Kirthi Shankar Sivamani --------- Signed-off-by: Kirthi Shankar Sivamani --- tests/paddle/test_operators.py | 4 +--- transformer_engine/common/recipe.py | 3 +-- transformer_engine/jax/fp8.py | 10 +++------- transformer_engine/paddle/csrc/custom_ops.cu | 7 ++----- transformer_engine/pytorch/fp8.py | 5 +---- transformer_engine/tensorflow/fp8.py | 4 +--- 6 files changed, 9 insertions(+), 24 deletions(-) diff --git a/tests/paddle/test_operators.py b/tests/paddle/test_operators.py index 7a2472e4bc..fbdd95de95 100644 --- a/tests/paddle/test_operators.py +++ b/tests/paddle/test_operators.py @@ -880,11 +880,9 @@ def test_amax_and_scale_update(): def calc_ref(amax, scale, fp8_max, margin=0): """Calculate reference scale""" - exp = paddle.floor(paddle.log2(fp8_max / amax)) - margin - sf = paddle.round(2**paddle.abs(exp)) + sf = (fp8_max / amax) / (2 ** margin) sf = paddle.where(amax > 0.0, sf, scale) sf = paddle.where(paddle.isfinite(amax), sf, scale) - sf = paddle.where(exp < 0, 1 / sf, sf) return sf scale_ref = calc_ref(amax_tensor, scale_tensor, fp8_max, 0.) diff --git a/transformer_engine/common/recipe.py b/transformer_engine/common/recipe.py index 3bb5320475..c5d2ee4972 100644 --- a/transformer_engine/common/recipe.py +++ b/transformer_engine/common/recipe.py @@ -115,8 +115,7 @@ def scaling_factor_compute(amax: Tensor, .. code-block:: python FP8_MAX = maximum_representable_value(fp8_format) - exp = get_exponent(FP8_MAX / amax) - margin - new_scaling_factor = 2.0 ^ exp + new_scaling_factor = (FP8_MAX / amax) / (2 ^ margin) * The scaling factor should always be a power of 2 to not introduce numerical error during the conversion from FP8 to higher precision format. diff --git a/transformer_engine/jax/fp8.py b/transformer_engine/jax/fp8.py index f5015a315f..83aad88c07 100644 --- a/transformer_engine/jax/fp8.py +++ b/transformer_engine/jax/fp8.py @@ -310,11 +310,9 @@ def _update_fp8_metas_impl(fp8_metas: Collection) -> Collection: amax = fp8_meta_arrays[fp8_amax_idx][..., 0:1] scale = fp8_meta_arrays[fp8_scale_idx] - exp = jnp.floor(jnp.log2(fp8_max / amax)) - FP8Helper.MARGIN - sf = jnp.round(jnp.power(2, jnp.abs(exp))) + sf = (fp8_max / amax) / (2 ** FP8Helper.MARGIN) sf = jnp.where(amax > 0.0, sf, scale) sf = jnp.where(jnp.isfinite(amax), sf, scale) - scale = jnp.where(exp < 0, 1 / sf, sf) fp8_meta_arrays[fp8_scale_idx] = scale fp8_meta_arrays[fp8_scale_inv_idx] = 1 / scale @@ -426,11 +424,9 @@ def update_fp8_metas(state: Collection) -> Collection: .. 
code-block:: python - exp = floor(log2(fp8_max / amax)) - margin - sf = round(power(2, abs(exp))) + sf = (fp8_max / amax) / (2 ^ margin) sf = sf if amax > 0.0, else original_scale - sf = sf if isfinite(amax), else original_scale) - updated_scale = 1/sf if exp < 0, else sf + updated_scale = sf if isfinite(amax), else original_scale) updated_scale_inv = 1/updated_scale Collection = [dict, flax.core.frozen_dict.FrozenDict] diff --git a/transformer_engine/paddle/csrc/custom_ops.cu b/transformer_engine/paddle/csrc/custom_ops.cu index 44e0202e53..d08080b168 100644 --- a/transformer_engine/paddle/csrc/custom_ops.cu +++ b/transformer_engine/paddle/csrc/custom_ops.cu @@ -1032,11 +1032,8 @@ __global__ void UpdateFP8MetaKernel(const float *amax, const float *rolled_amax_ amax_history[idx] = rolled_amax_history[idx]; if (idx < amax_numel) { - float exp = floor(log2(fp8_max / amax[idx])) - margin; - float sf = round(powf(2.0f, abs(exp))); - float scale_reg = scale[idx]; - sf = ((amax[idx] > 0.0f) && isfinite(amax[idx])) ? sf : scale_reg; - scale_reg = exp < 0.0f ? 1 / sf : sf; + float sf = (fp8_max / amax[idx]) / powf(2.0f, margin); + float scale_reg = ((amax[idx] > 0.0f) && isfinite(amax[idx])) ? sf : scale[idx]; scale[idx] = scale_reg; scale_inv[idx] = 1.0f / scale_reg; amax_history[idx] = 0.0f; diff --git a/transformer_engine/pytorch/fp8.py b/transformer_engine/pytorch/fp8.py index 5e9f6634f9..51cd565f5b 100644 --- a/transformer_engine/pytorch/fp8.py +++ b/transformer_engine/pytorch/fp8.py @@ -538,12 +538,9 @@ def _default_sf_compute( margin: int, ) -> torch.Tensor: """Default function to convert amax to scaling factor.""" - exp = torch.floor(torch.log2(fp8_max / amax)) - margin - sf = torch.round(torch.pow(2, torch.abs(exp))) + sf = (fp8_max / amax) / (2 ** margin) sf = torch.where(amax > 0.0, sf, scale) sf = torch.where(torch.isfinite(amax), sf, scale) - sf = torch.where(exp < 0, 1 / sf, sf) - return sf diff --git a/transformer_engine/tensorflow/fp8.py b/transformer_engine/tensorflow/fp8.py index d04471ff12..b6dfb69308 100644 --- a/transformer_engine/tensorflow/fp8.py +++ b/transformer_engine/tensorflow/fp8.py @@ -157,11 +157,9 @@ def get_fp8_recipe(): def _default_sf_compute(amax, scale, fp8_max, margin): """Default function to convert amax to scaling factor.""" - exp = tf.math.floor(tf.experimental.numpy.log2(fp8_max / amax)) - margin - sf = tf.math.round(tf.math.pow(2.0, tf.math.abs(exp))) + sf = (fp8_max / amax) / (2 ** margin) sf = tf.where(amax > 0.0, sf, scale) sf = tf.where(tf.math.is_finite(amax), sf, scale) - sf = tf.where(exp < 0, 1.0 / sf, sf) return sf From a7b22b754cd49ccf556240d725a9bdb2ae68caff Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Fri, 22 Sep 2023 23:42:44 -0700 Subject: [PATCH 55/68] [PyTorch] Fix ONNX exports (#437) * Fix ONNX exports Signed-off-by: Kirthi Shankar Sivamani * docs Signed-off-by: Kirthi Shankar Sivamani * review Signed-off-by: Kirthi Shankar Sivamani --------- Signed-off-by: Kirthi Shankar Sivamani --- tests/pytorch/test_onnx_export.py | 173 ++-------------------- transformer_engine/pytorch/attention.py | 65 +++----- transformer_engine/pytorch/transformer.py | 34 ++--- 3 files changed, 48 insertions(+), 224 deletions(-) diff --git a/tests/pytorch/test_onnx_export.py b/tests/pytorch/test_onnx_export.py index 14640febde..533e0cff6a 100644 --- a/tests/pytorch/test_onnx_export.py +++ b/tests/pytorch/test_onnx_export.py @@ -763,156 +763,6 @@ def forward(self, inp): validate_result( fname, inp, model, atol=atol, is_fp8=use_fp8, allow_cnt_errors=3, 
te_outputs=te_outputs) -@skip_FP8 -@pytest.mark.parametrize("softmax_fn", [ - softmax_defs.ScaledUpperTriangMaskedSoftmax, - softmax_defs.ScaledMaskedSoftmax, - softmax_defs.ScaledSoftmax, - te.softmax.FusedScaleMaskSoftmax, -]) -# Softmax kernel only supports FP16 or BF16! -@pytest.mark.parametrize("precision", [torch.float16, torch.bfloat16, "fake-torch.bfloat16"]) -def test_export_softmax(seed_default_rng, set_max_seq_len, softmax_fn, precision): - class Test_Softmax(nn.Module): - def __init__(self, softmax_fn, fake_bf16_io, mask_inp=False): - super().__init__() - self.softmax_fn = softmax_fn - self.scale = 8 # arbitrary value - self.mask_inp = mask_inp - self.fused_scaled_softmax = None - self.fake_bf16_io = fake_bf16_io - if self.softmax_fn == te.softmax.FusedScaleMaskSoftmax: - self.fused_scaled_softmax = te.softmax.FusedScaleMaskSoftmax( - mask_func=te.utils.attention_mask_func, - softmax_in_fp32=True, - ) - - def forward(self, inp, mask): - if self.fake_bf16_io: - inp = inp.type(torch.bfloat16) - - if self.fused_scaled_softmax: - ret = self.fused_scaled_softmax(inp, mask, "causal", self.scale) - else: - if self.mask_inp: - ret = self.softmax_fn.apply(inp, mask, self.scale) - else: - ret = self.softmax_fn.apply(inp, self.scale) - if self.fake_bf16_io: - ret = ret.type(torch.float32) - return ret - - fake_bf16_io = precision == "fake-torch.bfloat16" - precision = torch.bfloat16 if fake_bf16_io else precision - - # Set dimensions (these are arbitrary). - batch_size, n_heads, seq_len_q, seq_len_k = 64, 96, 32, 32 - mask = None - input_names = ["input", "mask"] - inp_shape = [batch_size, n_heads, seq_len_q, seq_len_k] - if softmax_fn == softmax_defs.ScaledUpperTriangMaskedSoftmax: - inp_shape = [batch_size, seq_len_q, seq_len_k] - kernel_str = "ScaledUpperTriangMaskedSoftmax" - model = Test_Softmax(softmax_fn, fake_bf16_io) - elif softmax_fn == softmax_defs.ScaledMaskedSoftmax: - # Generate a random mask with 50% probability for 0 or 1. - probs = 0.5 * torch.ones(1, 1, seq_len_q, seq_len_k, device="cuda", dtype=precision) - mask = torch.bernoulli(probs).to("cuda", dtype=torch.bool) - kernel_str = "ScaledMaskedSoftmax" - model = Test_Softmax(softmax_fn, fake_bf16_io, mask_inp=True) - elif softmax_fn == softmax_defs.ScaledSoftmax: - kernel_str = "ScaledSoftmax" - model = Test_Softmax(softmax_fn, fake_bf16_io) - elif softmax_fn == te.softmax.FusedScaleMaskSoftmax: - kernel_str = "TorchSoftmax" - model = Test_Softmax(softmax_fn, fake_bf16_io) - - input_tensor = torch.randn(*inp_shape, device="cuda", dtype=torch.float32 if fake_bf16_io else precision) - high_prec_str = dtype2str(precision, fake_bf16_io=fake_bf16_io) - fname = f"{kernel_str}{high_prec_str}.onnx" - inp = (input_tensor, mask) - dynamic_axes = {} - if mask is not None: - dynamic_axes = {"mask": {2:"seq_len_q", 3:"seq_len_k"}} - do_export(model, inp, fname, input_names=input_names, dynamic_axes=dynamic_axes) - te_outputs = te_infer(model, inp, is_fp8=False) - serialize_inputs_outputs(fname, inp, te_outputs, input_names=input_names) - if fake_bf16_io or precision != torch.bfloat16: - atol = 5e-2 if fake_bf16_io else 1e-3 - validate_result(fname, inp, model, atol=atol, input_names=input_names, te_outputs=te_outputs) - - -# Test dynamically generated softmax mask. -# Softmax kernel only supports FP16 or BF16! 
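Returning briefly to the scaling-factor change in PATCH 54 above: the old code quantized the scale to a power of two (E8M0-like), while the new formula keeps the full FP32 value. A small editor-added comparison, assuming the E4M3 format with fp8_max = 448:

.. code-block:: python

    import math

    def scale_old(amax, fp8_max, margin=0):   # power-of-two scale
        exp = math.floor(math.log2(fp8_max / amax)) - margin
        sf = round(2.0 ** abs(exp))
        return 1.0 / sf if exp < 0 else sf

    def scale_new(amax, fp8_max, margin=0):   # full-precision ratio
        return (fp8_max / amax) / 2 ** margin

    print(scale_old(3.0, 448.0))   # 128
    print(scale_new(3.0, 448.0))   # ~149.33
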
-@skip_FP8 -@pytest.mark.parametrize("precision", [torch.float16, torch.bfloat16, "fake-torch.bfloat16"]) -def test_softmax_mask_fn(seed_default_rng, precision): - fake_bf16_io = precision == "fake-torch.bfloat16" - # reset precision to torch.bfloat16 after capturing fake BF16 mode - precision = torch.bfloat16 if fake_bf16_io else precision - - class Test_Softmax(nn.Module): - def __init__(self, use_default_te_mask_fn: bool, fake_bf16_io: bool): - super().__init__() - self.scale = 1 # arbitrary value - self.fake_bf16_io = fake_bf16_io - - if use_default_te_mask_fn: - os.environ["NVTE_ONNX_KVCACHE_MAX_SEQ_LEN"] = "0" - else: - os.environ["NVTE_ONNX_KVCACHE_MAX_SEQ_LEN"] = f"{seq_len_q}" - - # Use NVTE_MASKED_SOFTMAX_FUSION to force TE to use forward_torch_softmax - # even when is_in_onnx_export_mode()==False. - os.environ["NVTE_MASKED_SOFTMAX_FUSION"] = "0" - self.fused_scaled_softmax = te.softmax.FusedScaleMaskSoftmax( - mask_func=te.utils.attention_mask_func, - softmax_in_fp32=True, - ) - - def forward(self, inp, mask): - if self.fake_bf16_io: - inp = inp.type(torch.bfloat16) - ret = self.fused_scaled_softmax(inp, mask, "causal", scale=self.scale) - if self.fake_bf16_io: - ret = ret.type(torch.float) - return ret - - # Set dimensions (these are arbitrary). - mask = None - batch_size, n_heads, seq_len_q, seq_len_k = 64, 96, 32, 32 - assert seq_len_q == seq_len_k # This is a causal (TRILU) mask - inp_shape = [batch_size, n_heads, seq_len_q, seq_len_k] - input_tensor = torch.randn( - *inp_shape, device="cuda", dtype=torch.float if fake_bf16_io else precision) - inp = (input_tensor, mask) - high_prec_str = dtype2str(precision, fake_bf16_io=fake_bf16_io) - - # Compare the outputs of TE when using the default softmax mask - # to the TE outputs produced when using the ONNX-compatible causal mask. - # This verifies that _get_onnx_export_causal_mask generates a correct mask. - model = Test_Softmax(use_default_te_mask_fn=True, fake_bf16_io=fake_bf16_io) - te_outputs_default_mask = te_infer(model, inp, is_fp8=True) - with te.onnx_export(True): - # ONNX export mode forces use of the ONNX-compatible causal mask. - model_onnx_mask = Test_Softmax(use_default_te_mask_fn=False, fake_bf16_io=fake_bf16_io) - te_outputs_onnx_mask = te_infer(model_onnx_mask, inp, is_fp8=True) - compare_outputs(te_outputs_default_mask, te_outputs_onnx_mask, - atol=0, rtol=0, max_errors_printed=10, allow_cnt_errors=0, fname="softmax masking") - - # Compare the outputs of TE when using the default softmax mask - # to the ORT ONNX outputs produced when using the ONNX-compatible causal mask. 
- input_names = ["input", "mask"] - kernel_str = "FusedScaleMaskSoftmax" - fname = f"{kernel_str}{high_prec_str}.onnx" - do_export(model, inp, fname, input_names=input_names) - serialize_inputs_outputs(fname, inp, te_outputs=te_outputs_default_mask, input_names=input_names) - if fake_bf16_io or precision != torch.bfloat16: - atol = 1e-2 if fake_bf16_io else 1e-3 - validate_result( - fname, inp, model_onnx_mask, atol=atol, - input_names=input_names, te_outputs=te_outputs_default_mask) - @pytest.mark.parametrize("scale_factor", [1]) @pytest.mark.parametrize("use_fp8", [False, True]) @@ -1159,13 +1009,13 @@ def test_export_core_attention( query_layer = torch.randn(qkv_size, dtype=precision, device="cuda") key_layer = torch.randn(qkv_size, dtype=precision, device="cuda") value_layer = torch.randn(qkv_size, dtype=precision, device="cuda") - input_names = ["query", "key", "value", "attention_mask", "attn_mask_type"] + input_names = ["query", "key", "value", "attention_mask"] attention_mask = None if use_mask: # Generate a random mask with 50% probability for 0 or 1. probs = 0.5 * torch.ones(qkv_size[1], qkv_size[2], qkv_size[0], qkv_size[0], device="cuda", dtype=precision) attention_mask = torch.bernoulli(probs).to("cuda", dtype=torch.bool) - inp = (query_layer, key_layer, value_layer, attention_mask, attn_mask_type) + inp = (query_layer, key_layer, value_layer, attention_mask) mask_str = get_attn_mask_str(use_mask, attn_mask_type) high_prec_str = dtype2str(precision) @@ -1175,6 +1025,7 @@ def test_export_core_attention( num_attention_heads=num_attention_heads, kv_channels=kv_channels, attention_dropout=0.5, + attn_mask_type=attn_mask_type, ).to(device='cuda') do_export(model, inp, @@ -1190,8 +1041,9 @@ def test_export_core_attention( test_configs_multihead_attention = [ #"use_mask, attn_mask_type" - (False, "no_mask"), # calls ScaledUpperTriangMaskedSoftmax + (False, "causal"), # calls ScaledUpperTriangMaskedSoftmax (True, "padding"), # calls ScaledMaskedSoftmax + (False, "padding"), # calls ScaledSoftmax ] test_configs_attention_type = [ #"input_layernorm, attention_type, fuse_qkv_params" @@ -1265,6 +1117,7 @@ def test_export_multihead_attention( model = te.MultiheadAttention( *attention_args, + attn_mask_type=attn_mask_type, params_dtype=precision, return_layernorm_output=return_layernorm_output, input_layernorm=input_layernorm, @@ -1273,8 +1126,8 @@ def test_export_multihead_attention( return_bias=True, ).to(device='cuda') - inp_context = (hidden_states_context, attention_mask, encoder_output, attn_mask_type) - input_names = ["hidden_states", "attention_mask", "encoder_output", "attn_mask_type"] + inp_context = (hidden_states_context, attention_mask, encoder_output) + input_names = ["hidden_states", "attention_mask", "encoder_output"] output_names=["attention_output", "attention_bias"] do_export(model, inp_context, fname, use_fp8, input_names=input_names, output_names=output_names, dynamic_axes={"hidden_states": {0: "seq", 1:"bs"}, @@ -1342,13 +1195,13 @@ def test_export_transformer_layer( num_attention_heads = 4 input_tensor = torch.rand(sequence_length, batch_size, hidden_size, dtype=precision, device="cuda") - input_names = ["input", "attention_mask", "self_attn_mask_type"] + input_names = ["input", "attention_mask"] attention_mask = None if use_mask and attn_mask_type != "causal": # Generate a random mask with 50% probability for 0 or 1. 
probs = 0.5 * torch.ones(batch_size, 1, sequence_length, sequence_length, device="cuda", dtype=precision) attention_mask = torch.bernoulli(probs).to("cuda", dtype=torch.bool) - inp = (input_tensor, attention_mask, attn_mask_type) + inp = (input_tensor, attention_mask) fp8_str = "_fp8" if use_fp8 else "" fuse_qkv_params_str = "_fused-qkv" if fuse_qkv_params else "" @@ -1360,6 +1213,7 @@ def test_export_transformer_layer( hidden_size, ffn_hidden_size, num_attention_heads, + self_attn_mask_type=attn_mask_type, output_layernorm=output_layernorm, params_dtype=precision, fuse_qkv_params=fuse_qkv_params, @@ -1541,16 +1395,17 @@ def test_export_gpt_generation( hidden_size, ffn_hidden_size, num_attention_heads, + self_attn_mask_type=attn_mask_type, output_layernorm=output_layernorm, params_dtype=precision, fuse_qkv_params=fuse_qkv_params, zero_centered_gamma=zero_centered_gamma).to(device='cuda') # "Context phase": use full input sequence length - input_names = ["input", "attention_mask", "self_attn_mask_type"] + input_names = ["input"] output_names = ["output"] input_tensor = torch.rand(sequence_length, batch_size, hidden_size, dtype=precision, device="cuda") - inp = (input_tensor, None, attn_mask_type) + inp = (input_tensor,) do_export(model, inp, fname, use_fp8, input_names=input_names, output_names=output_names, dynamic_axes={"input": {0: "seq", 1:"bs"}, diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index b8f9befb1f..f9aa63ce8a 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -186,6 +186,7 @@ def backward(ctx, tensors = split_tensor_along_dim(grad_outputs[0], ctx.dim, 2) return tensors[0], tensors[1], None + class UnfusedDotProductAttention(torch.nn.Module): """Parallel attention w/o QKV and Proj Gemms BMM1 -> softmax + dropout -> BMM2 @@ -883,11 +884,6 @@ class DotProductAttention(torch.nn.Module): and set the environment variable :attr:`NVTE_ALLOW_NONDETERMINISTIC_ALGO=0`. In order to disable`flash-attn` entirely, set :attr:`NVTE_FLASH_ATTN=0`. - .. warning:: - - Argument :attr:`attn_mask_type` has been moved to the `forward` method and - is deprecated. It will be fully removed in future releases. - Parameters ---------- num_attention_heads : int @@ -907,6 +903,12 @@ class DotProductAttention(torch.nn.Module): layer_number: int, default = `None` layer number of the current `DotProductAttention` when multiple such modules are concatenated, for instance in consecutive transformer blocks. + attn_mask_type: {'causal', 'padding', 'no_mask'}, default = `causal` + type of attention mask passed into softmax operation. Overridden by + :attr:`attn_mask_type` in the `forward` method. The forward + arg is useful for dynamically changing mask types, e.g. a different + mask for training and inference. The init arg is useful for cases + involving compilation/tracing, e.g. ONNX export. Parallelism parameters ---------------------- @@ -924,7 +926,7 @@ def __init__( kv_channels: int, num_gqa_groups: Optional[int] = None, attention_dropout: float = 0.0, - attn_mask_type: Optional[str] = None, + attn_mask_type: str = "causal", sequence_parallel: bool = False, tp_size: int = 1, get_rng_state_tracker: Optional[Callable] = None, @@ -934,13 +936,6 @@ def __init__( ) -> None: super().__init__() - if attn_mask_type is not None: - warnings.warn( - "Argument :attr:`attn_mask_type` has been moved to the `forward` method and" - "is deprecated. 
It will be fully removed in future releases.", - category=DeprecationWarning, - ) - self.attn_mask_type = attn_mask_type self.tp_size = tp_size if tp_group is None else get_distributed_world_size(tp_group) self.tp_group = tp_group @@ -1031,7 +1026,7 @@ def forward( key_layer: torch.Tensor, value_layer: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, - attn_mask_type: str = "causal", + attn_mask_type: Optional[str] = None, checkpoint_core_attention: bool = False, core_attention_bias_type: str = "no_bias", core_attention_bias: Optional[torch.Tensor] = None, @@ -1087,7 +1082,7 @@ def forward( Value tensor. attention_mask : Optional[torch.Tensor], default = `None` Boolean tensor used to mask out softmax input when not using flash-attn. - attn_mask_type: {'causal', 'padding', 'no_mask'}, default = `causal` + attn_mask_type: {'causal', 'padding', 'no_mask'}, default = `None` type of attention mask passed into softmax operation. checkpoint_core_attention : bool, default = `False` If true, forward activations for attention are recomputed @@ -1102,13 +1097,7 @@ def forward( Whether to use the fast path to set output tensors to 0 or not. """ - if self.attn_mask_type is not None: - warnings.warn( - "Argument :attr:`attn_mask_type` has been moved to the `forward` method and" - "is deprecated. It will be fully removed in future releases.", - category=DeprecationWarning, - ) - # Keep previous functionality for current users. + if attn_mask_type is None: attn_mask_type = self.attn_mask_type assert (key_layer.shape[-2] == self.num_gqa_groups_per_partition @@ -1229,11 +1218,6 @@ class MultiheadAttention(torch.nn.Module): Argument :attr:`attention_mask` will be ignored in the `forward` call when :attr:`attn_mask_type` is set to `"causal"`. - .. warning:: - - Argument :attr:`attn_mask_type` has been moved to the `forward` method and - is deprecated. It will be fully removed in future releases. - Parameters ---------- hidden_size : int @@ -1259,6 +1243,12 @@ class MultiheadAttention(torch.nn.Module): layer_number: int, default = `None` layer number of the current `TransformerLayer` when multiple such modules are concatenated to form a transformer block. + attn_mask_type: {'causal', 'padding', 'no_mask'}, default = `causal` + type of attention mask passed into softmax operation. Overridden by + :attr:`attn_mask_type` in the `forward` method. The forward + arg is useful for dynamically changing mask types, e.g. a different + mask for training and inference. The init arg is useful for cases + involving compilation/tracing, e.g. ONNX export. num_gqa_groups : int, default = `None` number of GQA groups in the transformer layer. Grouped Query Attention is described in @@ -1349,7 +1339,7 @@ def __init__( init_method: Optional[Callable] = None, output_layer_init_method: Optional[Callable] = None, layer_number: Optional[int] = None, - attn_mask_type: Optional[str] = None, + attn_mask_type: str = "causal", tp_group: Optional[dist_group_type] = None, tp_size: int = 1, num_gqa_groups: Optional[int] = None, @@ -1375,13 +1365,6 @@ def __init__( ) -> None: super().__init__() - if attn_mask_type is not None: - warnings.warn( - "Argument :attr:`attn_mask_type` has been moved to the `forward` method and" - "is deprecated. 
It will be fully removed in future releases.", - category=DeprecationWarning, - ) - self.attn_mask_type = attn_mask_type self.layer_number = layer_number self.input_layernorm = input_layernorm @@ -1555,7 +1538,7 @@ def forward( hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, encoder_output: Optional[torch.Tensor] = None, - attn_mask_type: str = "causal", + attn_mask_type: Optional[str] = None, is_first_microbatch: Optional[bool] = None, checkpoint_core_attention: bool = False, inference_params: Optional[Any] = None, @@ -1578,7 +1561,7 @@ def forward( Input tensor. attention_mask : Optional[torch.Tensor], default = `None` Boolean tensor used to mask out self-attention softmax input. - attn_mask_type: {'causal', 'padding', 'no_mask'}, default = `causal` + attn_mask_type: {'causal', 'padding', 'no_mask'}, default = `None` type of attention mask passed into softmax operation. encoder_output : Optional[torch.Tensor], default = `None` Output of the encoder block to be fed into the decoder block if using @@ -1613,13 +1596,7 @@ def forward( """ # hidden_states: [sq, b, h] - if self.attn_mask_type is not None: - warnings.warn( - "Argument :attr:`attn_mask_type` has been moved to the `forward` method and" - "is deprecated. It will be fully removed in future releases.", - category=DeprecationWarning, - ) - # Keep previous functionality for current users. + if attn_mask_type is None: attn_mask_type = self.attn_mask_type if attn_mask_type == "padding" and attention_mask is not None: diff --git a/transformer_engine/pytorch/transformer.py b/transformer_engine/pytorch/transformer.py index 6b45a10fb3..d4046ec7da 100644 --- a/transformer_engine/pytorch/transformer.py +++ b/transformer_engine/pytorch/transformer.py @@ -73,10 +73,9 @@ class TransformerLayer(torch.nn.Module): Arguments :attr:`attention_softmax_in_fp32` and :attr:`apply_query_key_layer_scaling` are deprecated and will be fully removed in future releases. - .. warning:: - - Argument :attr:`self_attn_mask_type` has been moved to the `forward` method and - is deprecated. It will be fully removed in future releases. + .. note:: + Argument :attr:`attention_mask` will be ignored in the `forward` call when + :attr:`self_attn_mask_type` is set to `"causal"`. Parameters ---------- @@ -127,6 +126,12 @@ class TransformerLayer(torch.nn.Module): kv_channels: int, default = `None` number of key-value channels. defaults to :attr:`hidden_size` / :attr:`num_attention_heads` if `None`. + self_attn_mask_type: {'causal', 'padding', 'no_mask'}, default = `causal` + type of attention mask passed into softmax operation. Overridden by + :attr:`self_attn_mask_type` in the `forward` method. The forward + arg is useful for dynamically changing mask types, e.g. a different + mask for training and inference. The init arg is useful for cases + involving compilation/tracing, e.g. ONNX export. 
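A short usage sketch of the pattern this docstring describes (editor-added, with hypothetical sizes): the mask type chosen at construction is what gets traced for ONNX export, while the forward argument overrides it per call.

.. code-block:: python

    import torch
    import transformer_engine.pytorch as te

    layer = te.TransformerLayer(
        hidden_size=1024,
        ffn_hidden_size=4096,
        num_attention_heads=16,
        self_attn_mask_type="causal",   # static choice, e.g. for ONNX export
    ).cuda()

    x = torch.randn(128, 2, 1024, device="cuda")   # [seq, batch, hidden]
    y = layer(x)                                   # uses the __init__ mask type
    y = layer(x, self_attn_mask_type="no_mask")    # per-call override
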
zero_centered_gamma : bool, default = 'False' if set to 'True', gamma parameter in LayerNorm is initialized to 0 and the LayerNorm formula changes to @@ -212,7 +217,7 @@ def __init__( output_layer_init_method: Optional[Callable] = None, layer_number: Optional[int] = None, kv_channels: Optional[int] = None, - self_attn_mask_type: Optional[str] = None, + self_attn_mask_type: str = "causal", tp_group: Optional[dist_group_type] = None, tp_size: int = 1, params_dtype: Optional[torch.dtype] = None, @@ -239,13 +244,6 @@ def __init__( ) -> None: super().__init__() - if self_attn_mask_type is not None: - warnings.warn( - "Argument :attr:`self_attn_mask_type` has been moved to the `forward` method and" - "is deprecated. It will be fully removed in future releases.", - category=DeprecationWarning, - ) - warnings.warn( "Arguments `attention_softmax_in_fp32` and `apply_query_key_layer_scaling`" "are deprecated and will be fully removed in future releases.", @@ -431,7 +429,7 @@ def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, - self_attn_mask_type: str = "causal", + self_attn_mask_type: Optional[str] = None, encoder_output: Optional[torch.Tensor] = None, enc_dec_attn_mask: Optional[torch.Tensor] = None, is_first_microbatch: Optional[bool] = None, @@ -456,7 +454,7 @@ def forward( Input tensor. attention_mask : Optional[torch.Tensor], default = `None` Boolean tensor used to mask out self-attention softmax input. - self_attn_mask_type: {'causal', 'padding'}, default = `causal` + self_attn_mask_type: {'causal', 'padding', 'no_mask'}, default = `causal` type of attention mask passed into softmax operation. encoder_output : Optional[torch.Tensor], default = `None` Output of the encoder block to be fed into the decoder block if using @@ -493,13 +491,7 @@ def forward( Whether to set output tensors to 0 or not before use. """ - if self.self_attn_mask_type is not None: - warnings.warn( - "Argument :attr:`self_attn_mask_type` has been moved to the `forward` method and" - "is deprecated. It will be fully removed in future releases.", - category=DeprecationWarning, - ) - # Keep previous functionality for current users. + if self_attn_mask_type is None: self_attn_mask_type = self.self_attn_mask_type assert ( From a402c4d2cb11d5860385f0bb8edc7597b442d3e6 Mon Sep 17 00:00:00 2001 From: cyanguwa <8636796+cyanguwa@users.noreply.github.com> Date: Fri, 22 Sep 2023 23:44:03 -0700 Subject: [PATCH 56/68] Fix layernorm in GQA (#434) * [PyTorch] Implement GQA based on fused q, k, v projection. Additionally fixes #392 Signed-off-by: Markus Schnoes * [PyTorch] Extend parameters_split option in Linear and LayerNormLinear to support splitting with different sizes as required by unfused GQA. 
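The sizes involved are worth spelling out (editor-added sketch with hypothetical dimensions): with grouped-query attention, the key/value projection width shrinks by the ratio of attention heads to GQA groups, so the fused QKV projection is no longer 3x the hidden size and the split sections differ.

.. code-block:: python

    hidden_size, num_attention_heads, num_gqa_groups = 4096, 32, 8

    kv_channels = hidden_size // num_attention_heads                       # 128
    hidden_size_kv = hidden_size * num_gqa_groups // num_attention_heads   # 1024

    qkv_out_features = hidden_size + 2 * hidden_size_kv   # 6144, not 3 * 4096
    parameters_split = {"query_": hidden_size,             # 4096
                        "key_": hidden_size_kv,            # 1024
                        "value_": hidden_size_kv}          # 1024
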
Signed-off-by: Markus Schnoes * fix parameters split Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix noop cat to bypass torch.cat and support uneven split Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix unit tests Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix torch.split args Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix cuda graph due to noop_cat Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix lint Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * remove the use of enumerate when possible Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix strides in SplitAlongDim Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --------- Signed-off-by: Markus Schnoes Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Co-authored-by: Markus Schnoes --- tests/pytorch/test_fused_attn.py | 13 +- transformer_engine/pytorch/attention.py | 146 +++++++++++------- transformer_engine/pytorch/module/base.py | 37 +++-- .../pytorch/module/layernorm_linear.py | 55 ++++--- transformer_engine/pytorch/module/linear.py | 55 ++++--- 5 files changed, 194 insertions(+), 112 deletions(-) diff --git a/tests/pytorch/test_fused_attn.py b/tests/pytorch/test_fused_attn.py index 32442e40fb..1a1515d843 100644 --- a/tests/pytorch/test_fused_attn.py +++ b/tests/pytorch/test_fused_attn.py @@ -141,7 +141,8 @@ def _run_dot_product_attention(dtype, bs, config, backend, ckpt_attn, bias_type) @pytest.mark.parametrize("model", model_configs.keys()) @pytest.mark.parametrize("ckpt_attn", [False]) @pytest.mark.parametrize("bias_type", ["no_bias", "post_scale_bias"]) -def test_transformer_layer(dtype, bs, model, ckpt_attn, bias_type): +@pytest.mark.parametrize("fused_qkv_params", [True, False]) +def test_transformer_layer(dtype, bs, model, ckpt_attn, bias_type, fused_qkv_params): """Test TransformerLayer module when its DotProductAttention is enabled with FlashAttention, FusedAttention, or UnfusedDotProductAttention backend""" @@ -149,11 +150,11 @@ def test_transformer_layer(dtype, bs, model, ckpt_attn, bias_type): if bias_type == "no_bias": flash_attn_fwd, flash_attn_bwd = _run_transformer_layer( - dtype, bs, config, "FlashAttention", ckpt_attn, bias_type) + dtype, bs, config, "FlashAttention", ckpt_attn, bias_type, fused_qkv_params) fused_attn_fwd, fused_attn_bwd = _run_transformer_layer( - dtype, bs, config, "FusedAttention", ckpt_attn, bias_type) + dtype, bs, config, "FusedAttention", ckpt_attn, bias_type, fused_qkv_params) unfused_attn_fwd, unfused_attn_bwd = _run_transformer_layer( - dtype, bs, config, "UnfusedDotProductAttention", ckpt_attn, bias_type) + dtype, bs, config, "UnfusedDotProductAttention", ckpt_attn, bias_type, fused_qkv_params) atol, rtol = (5e-1, 5e-2) if bias_type == "no_bias": @@ -162,7 +163,7 @@ def test_transformer_layer(dtype, bs, model, ckpt_attn, bias_type): assert torch.allclose(fused_attn_fwd, unfused_attn_fwd, atol=atol, rtol=rtol) assert torch.allclose(fused_attn_bwd, unfused_attn_bwd, atol=atol, rtol=rtol) -def _run_transformer_layer(dtype, bs, config, backend, ckpt_attn, bias_type): +def _run_transformer_layer(dtype, bs, config, backend, ckpt_attn, bias_type, fused_qkv_params): reset_rng_states() os.environ["NVTE_FLASH_ATTN"] = "0" @@ -220,7 +221,7 @@ def _run_transformer_layer(dtype, bs, config, backend, ckpt_attn, bias_type): layer_type="encoder", 
drop_path_rate=drop_path_rates[layer_number - 1], set_parallel_mode=True, - fuse_qkv_params=True, + fuse_qkv_params=fused_qkv_params, zero_centered_gamma=False, qkv_weight_interleaved=False, ub_tp_comm_overlap=False, diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index f9aa63ce8a..bcf5584f3d 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -8,8 +8,9 @@ import math from importlib.metadata import version from contextlib import nullcontext -from typing import Any, Callable, Optional, Tuple, Union, Dict +from typing import Any, Callable, Optional, Tuple, Union, Dict, List from pkg_resources import packaging +import numpy as np import torch @@ -84,48 +85,61 @@ def apply_rotary_pos_emb(t: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor: return torch.cat((t, t_pass), dim=-1) -class _SplitLastDim(torch.autograd.Function): +class _SplitAlongDim(torch.autograd.Function): """""" @staticmethod def forward(ctx, mixed_x_layer: torch.Tensor, - num_parts: int + split_dim: int, + split_size_or_sections: Union[int, List[int], Tuple[int]], ) -> Tuple[torch.Tensor, ...]: - return split_tensor_along_dim(mixed_x_layer, -1, num_parts) + ctx.split_dim = split_dim + ctx.split_size_or_sections = split_size_or_sections + return torch.split(mixed_x_layer, split_size_or_sections, dim = split_dim) @staticmethod def backward(ctx, *grad_outputs): assert len(grad_outputs) > 0, "No gradients received for backprop!" + if isinstance(ctx.split_size_or_sections, (list, tuple)): + split_sizes = ctx.split_size_or_sections + assert (len(grad_outputs) == len(split_sizes) + ), "Unequal number of gradients vs split sections for backprop!" + if isinstance(ctx.split_size_or_sections, int): + split_sizes = [ctx.split_size_or_sections] * len(grad_outputs) + dims = len(grad_outputs[0].shape) + split_dim = (ctx.split_dim + dims) % dims + noop_ok = True strides = grad_outputs[0].stride() data_ptr = grad_outputs[0].storage().data_ptr() - shape = grad_outputs[0].shape - last_dim_size = grad_outputs[0].shape[-1] + shape = list(grad_outputs[0].shape) for i, tensor in enumerate(grad_outputs): + shape_i = shape + shape_i[split_dim] = split_sizes[i] + offset_size = sum(split_sizes[:i]) * np.prod(shape[split_dim+1:]) if (tensor.stride() != strides or - tensor.shape != shape or + list(tensor.shape) != shape_i or tensor.storage().data_ptr() != data_ptr or - tensor.storage_offset() != i * last_dim_size): + tensor.storage_offset() != offset_size): noop_ok = False break if noop_ok: - ret = torch.Tensor().to(grad_outputs[0].dtype) ret = torch.Tensor().to(device=grad_outputs[0].device, dtype=grad_outputs[0].dtype) new_shape = list(shape) - new_shape[-1] = new_shape[-1] * len(grad_outputs) - ret.set_(grad_outputs[0].storage(), + new_shape[split_dim] = sum(split_sizes) + ret.set_(grad_outputs[0].untyped_storage(), grad_outputs[0].storage_offset(), new_shape, - grad_outputs[0].stride() + strides ) - return ret, None + return ret, None, None - return torch.cat(grad_outputs, dim = -1), None + return torch.cat(grad_outputs, dim = split_dim), None, None class _CombineQKV(torch.autograd.Function): """""" @@ -1401,8 +1415,8 @@ def __init__( num_attention_heads if num_gqa_groups is None else num_gqa_groups ) assert (num_attention_heads % self.num_gqa_groups == 0 - ), "The number of GQA groups must be divisible by the number of attention heads!" 
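The interleaved split in the new forward path can be checked shape by shape; below is a minimal editor-added sketch with illustrative sizes (ng = GQA groups per partition, np = attention heads per partition, hn = head size).

.. code-block:: python

    import torch

    sq, b, hn = 128, 2, 64       # sequence length, batch, head size
    np_, ng = 32, 8              # attention heads and GQA groups per partition

    # Output of the fused QKV projection: [sq, b, ng * (np/ng + 2) * hn]
    mixed = torch.randn(sq, b, ng * (np_ // ng + 2) * hn)

    mixed = mixed.view(sq, b, ng, np_ // ng + 2, hn)
    q, k, v = torch.split(mixed, (np_ // ng, 1, 1), dim=-2)

    q = q.reshape(sq, b, -1, hn)   # [sq, b, np, hn] -> torch.Size([128, 2, 32, 64])
    k = k.reshape(sq, b, -1, hn)   # [sq, b, ng, hn] -> torch.Size([128, 2, 8, 64])
    v = v.reshape(sq, b, -1, hn)   # [sq, b, ng, hn] -> torch.Size([128, 2, 8, 64])
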
- assert (num_attention_heads % tp_size == 0 + ), "The number of attention heads must be divisible by the number of GQA groups!" + assert (self.num_gqa_groups % tp_size == 0 ), "The number of GQA groups must be divisible by tensor parallel size!" self.num_gqa_groups_per_partition = int(self.num_gqa_groups // tp_size) self.hidden_size_kv = int(hidden_size * self.num_gqa_groups // num_attention_heads) @@ -1419,18 +1433,21 @@ def __init__( qkv_parallel_mode = "column" if set_parallel_mode else None - if self.attention_type == "self" and self.num_gqa_groups == self.num_attention_heads: + if self.attention_type == "self": + parameters_split = {"query_": hidden_size, + "key_": self.hidden_size_kv, + "value_": self.hidden_size_kv} if not fuse_qkv_params else None if self.input_layernorm: self.layernorm_qkv = LayerNormLinear( hidden_size, - 3 * hidden_size, + hidden_size + 2 * self.hidden_size_kv, eps=layernorm_epsilon, init_method=init_method, bias=bias, return_bias=False, parallel_mode=qkv_parallel_mode, return_layernorm_output=return_layernorm_output, - parameters_split=("query_", "key_", "value_") if not fuse_qkv_params else None, + parameters_split=parameters_split, zero_centered_gamma=zero_centered_gamma, ub_bulk_wgrad=ub_bulk_wgrad, ub_bulk_dgrad=ub_bulk_dgrad, @@ -1441,17 +1458,15 @@ def __init__( else: self.qkv = Linear( hidden_size, - 3 * hidden_size, + hidden_size + 2 * self.hidden_size_kv, init_method=init_method, bias=bias, return_bias=False, parallel_mode=qkv_parallel_mode, - parameters_split=("query_", "key_", "value_") if not fuse_qkv_params else None, + parameters_split=parameters_split, **common_gemm_kwargs, ) - elif ((self.attention_type == "cross") - or (self.attention_type == "self" - and self.num_gqa_groups != self.num_attention_heads)): + elif self.attention_type == "cross": if self.input_layernorm: self.layernorm_query = LayerNormLinear( hidden_size, @@ -1461,6 +1476,7 @@ def __init__( bias=bias, return_bias=False, parallel_mode=qkv_parallel_mode, + parameters_split=("query_",) if not fuse_qkv_params else None, return_layernorm_output=return_layernorm_output, zero_centered_gamma=zero_centered_gamma, ub_bulk_wgrad=ub_bulk_wgrad, @@ -1636,8 +1652,8 @@ def forward( # Query, Key, and Value # ===================== - if self.attention_type == "self" and self.num_gqa_groups == self.num_attention_heads: - # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)] + if self.attention_type == "self": + # Attention heads [sq, b, h] --> [sq, b, ng * (np/ng + 2) * hn] if self.input_layernorm: layernorm_qkv_outputs = self.layernorm_qkv( hidden_states, @@ -1653,49 +1669,59 @@ def forward( is_first_microbatch=is_first_microbatch, ) + num_queries_per_key_value = (self.num_attention_heads_per_partition // + self.num_gqa_groups_per_partition) if self.qkv_weight_interleaved: - # [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn] - new_tensor_shape = mixed_x_layer.size()[:-1] + ( - self.num_attention_heads_per_partition, - 3 * self.hidden_size_per_attention_head, - ) - # split along last dimension - split_dim = -1 - else: - # [sq, b, (np * 3 * hn)] --> [sq, b, 3 * np, hn] + # [sq, b, ng * (np/ng + 2) * hn] --> [sq, b, ng, (np/ng + 2), hn] new_tensor_shape = mixed_x_layer.size()[:-1] + ( - 3 * self.num_attention_heads_per_partition, + self.num_gqa_groups_per_partition, + (num_queries_per_key_value + 2), self.hidden_size_per_attention_head, ) # split along second last dimension split_dim = -2 + else: + # [sq, b, ng * (np/ng + 2) * hn] --> [sq, b, (np/ng + 2), ng, hn] + new_tensor_shape = 
mixed_x_layer.size()[:-1] + ( + (num_queries_per_key_value + 2), + self.num_gqa_groups_per_partition, + self.hidden_size_per_attention_head + ) + # split along third last dimension + split_dim = -3 mixed_x_layer = mixed_x_layer.view(*new_tensor_shape) - # mixed_x_layer --> 3 [sq, b, np, hn] - if split_dim == -1 and not is_in_onnx_export_mode(): - query_layer, key_layer, value_layer = _SplitLastDim.apply(mixed_x_layer, 3) - else: - query_layer, key_layer, value_layer = split_tensor_along_dim( - mixed_x_layer, split_dim, 3 + # qkv_weight_interleaved: + # [sq, b, ng, (np/ng + 2), hn] + # --> [sq, b, ng, np/ng, hn], [sq, b, ng, 1, hn], [sq, b, ng, 1, hn] + # not qkv_weight_interleaved: + # [sq, b, (np/ng + 2), ng, hn] + # --> [sq, b, np/ng, np, hn], [sq, b, 1, ng, hn], [sq, b, 1, ng, hn] + if not is_in_onnx_export_mode(): + query_layer, key_layer, value_layer = _SplitAlongDim.apply( + mixed_x_layer, split_dim, (num_queries_per_key_value, 1, 1) ) - elif ((self.attention_type == "cross") - or (self.attention_type == "self" - and self.num_gqa_groups != self.num_attention_heads)): - - if self.attention_type == "cross": - input_tensor = encoder_output else: - input_tensor = hidden_states - - # Attention heads [sk, b, h] --> [sk, b, (np * 2 * hn)] + query_layer, key_layer, value_layer = torch.split( + mixed_x_layer, (num_queries_per_key_value, 1, 1), dim = split_dim, + ) + + # query: -> [sq, b, np, hn] + # key, value: -> [sq, b, ng, hn] + query_layer, key_layer, value_layer = (x.reshape(x.size(0), x.size(1), -1, + self.hidden_size_per_attention_head) + for x in (query_layer, key_layer, value_layer)) + + elif self.attention_type == "cross": + # Attention heads [sk, b, h] --> [sk, b, (ng * 2 * hn)] mixed_kv_layer = self.key_value( - input_tensor, + encoder_output, is_first_microbatch=is_first_microbatch, ) if self.qkv_weight_interleaved: - # [sq, b, (np * 2 * hn)] --> [sq, b, np, 2 * hn] + # [sq, b, (ng * 2 * hn)] --> [sq, b, ng, 2 * hn] new_tensor_shape = mixed_kv_layer.size()[:-1] + ( self.num_gqa_groups_per_partition, 2 * self.hidden_size_per_attention_head, @@ -1703,7 +1729,7 @@ def forward( # split along last dimension split_dim = -1 else: - # [sq, b, (np * 2 * hn)] --> [sq, b, 2 * np, hn] + # [sq, b, (ng * 2 * hn)] --> [sq, b, 2 * ng, hn] new_tensor_shape = mixed_kv_layer.size()[:-1] + ( 2 * self.num_gqa_groups_per_partition, self.hidden_size_per_attention_head, @@ -1713,11 +1739,15 @@ def forward( mixed_kv_layer = mixed_kv_layer.view(*new_tensor_shape) - # mixed_kv_layer --> 2 [sk, b, np, hn] - if split_dim == -1 and not is_in_onnx_export_mode(): - key_layer, value_layer = _SplitLastDim.apply(mixed_kv_layer, 2) + # mixed_kv_layer --> 2 [sk, b, ng, hn] + if not is_in_onnx_export_mode(): + key_layer, value_layer = _SplitAlongDim.apply( + mixed_kv_layer, split_dim, mixed_kv_layer.shape[split_dim] // 2, + ) else: - key_layer, value_layer = split_tensor_along_dim(mixed_kv_layer, split_dim, 2) + key_layer, value_layer = torch.split( + mixed_kv_layer, mixed_kv_layer.shape[split_dim] // 2, dim = split_dim, + ) # Attention head [sq, b, h] --> [sq, b, hp] if self.input_layernorm: diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 82d39eeaf0..50d7b9f2fb 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -212,8 +212,9 @@ def forward(ctx, *params_split: Tuple[torch.Tensor, ...], ) -> torch.Tensor: assert not full_param_buffer.requires_grad, "Buffers should not require gradient" + sum_params_shape = 
sum(p.shape[0] for p in params_split) assert ( - full_param_buffer.shape[0] % len(params_split) == 0 + full_param_buffer.shape[0] == sum_params_shape ), "Dimensions not compatible for concatenation" param_temp = full_param_buffer.new() @@ -223,18 +224,19 @@ def forward(ctx, full_param_buffer.stride()) param_temp.requires_grad = True - ctx.save_for_backward(full_param_buffer, *params_split) + ctx.save_for_backward(*params_split) return param_temp @staticmethod def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]: - full_param_buffer, *params_split = ctx.saved_tensors - - split_size = full_param_buffer.shape[0] // len(params_split) + params_split = ctx.saved_tensors grads = [] - + slice_begin = 0 for i, _ in enumerate(params_split): - grads.append(grad_output[i * split_size : (i+1) * split_size]) + slice_size = params_split[i].shape[0] + slice_end = slice_begin + slice_size + grads.append(grad_output[slice_begin:slice_end]) + slice_begin = slice_end return None, *grads @@ -753,7 +755,11 @@ def grad_output_preprocess( return grad_output_mat, grad_output_c, grad_output_t, grad_bias - def noop_cat(self, buffer_name: str, pnames: List[str]) -> torch.Tensor: + def noop_cat(self, + buffer_name: str, + pnames: List[str], + parameters_split: Dict[str, int] + ) -> torch.Tensor: """No-op replacement of `torch.cat`. The buffer and split parameters must occupy the same memory region. If this is not the case, then the split parameters are concatenated and the buffer is overwritten. The parameters' memory is then @@ -762,17 +768,24 @@ def noop_cat(self, buffer_name: str, pnames: List[str]) -> torch.Tensor: assert hasattr(self, buffer_name), f"No buffer named {buffer_name}" full_param_buffer = getattr(self, buffer_name) - split_size = full_param_buffer.shape[0] // len(pnames) params = [getattr(self, name) for name in pnames] + slice_begin = 0 for i, p in enumerate(params): - if p.data.data_ptr() != full_param_buffer[i*split_size : (i+1)*split_size].data_ptr(): + slice_size = parameters_split[pnames[i].split('_')[0]+'_'] + slice_end = slice_begin + slice_size + if p.data.data_ptr() != full_param_buffer[slice_begin:slice_end].data_ptr(): with torch.no_grad(): setattr(self, buffer_name, torch.cat(params)) - for j, pname in enumerate(pnames): + slice_begin_j = 0 + for pname in pnames: + slice_size_j = parameters_split[pname.split('_')[0]+'_'] + slice_end_j = slice_begin_j + slice_size_j full_param_buffer = getattr(self, buffer_name) setattr(self, pname, - Parameter(full_param_buffer[j*split_size : (j+1)*split_size])) + Parameter(full_param_buffer[slice_begin_j:slice_end_j])) + slice_begin_j = slice_end_j break + slice_begin = slice_end return _NoopCat.apply(getattr(self, buffer_name), *[getattr(self, name) for name in pnames]) diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 9115971524..761b0abf6b 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -536,11 +536,14 @@ class LayerNormLinear(TransformerEngineBaseModule): together with the output of the linear transformation. Example use case: residual connection for transformer module is taken post layernorm. 
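A brief usage sketch of the dict form described here (editor-added, hypothetical sizes): each key becomes the prefix of a separately exposed parameter, and each value is that projection's share of out_features.

.. code-block:: python

    import transformer_engine.pytorch as te

    layer = te.LayerNormLinear(
        in_features=4096,
        out_features=4096 + 2 * 1024,
        parameters_split={"query_": 4096, "key_": 1024, "value_": 1024},
    )

    print(layer.query_weight.shape)   # torch.Size([4096, 4096])
    print(layer.key_weight.shape)     # torch.Size([1024, 4096])
    print(layer.value_bias.shape)     # torch.Size([1024])
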
- parameters_split : Tuple[str, ...], default = None - if a tuple of strings is provided, the weight and bias parameters of the - module are exposed as `N` separate `torch.nn.parameter.Parameter`s each, - split along the first dimension, where `N` is the length of the argument - and the strings contained are the names of the split parameters. + parameters_split : Optional[Union[Tuple[str, ...], Dict[str, int]]], default = None + if a tuple of strings or a dict of strings to integers is provided, + the weight and bias parameters of the module are exposed as `N` separate + `torch.nn.parameter.Parameter`s each, split along the first dimension, + where `N` is the length of the argument and the strings contained are the + names of the split parameters. In the case of a tuple, each parameter + has the same shape. In the case of a dict, the values give the + `out_features` for each projection. zero_centered_gamma : bool, default = 'False' if set to 'True', gamma parameter in LayerNorm is initialized to 0 and the LayerNorm formula changes to @@ -607,7 +610,7 @@ def __init__( parallel_mode: Optional[str] = None, return_layernorm_output: bool = False, skip_weight_param_allocation: bool = False, - parameters_split: Optional[Tuple[str, ...]] = None, + parameters_split: Optional[Union[Tuple[str, ...], Dict[str, int]]] = None, zero_centered_gamma: bool = False, ub_bulk_wgrad: bool = False, ub_bulk_dgrad: bool = False, @@ -707,23 +710,35 @@ def __init__( self.bias_tensor.zero_() if parameters_split is None: - parameters_split = ("",) - - assert ( - self.out_features % len(parameters_split) == 0 - ), f"Weight and bias params cannot be split into {len(parameters_split)} parts" - - split_size = self.out_features // len(parameters_split) + parameters_split = {"": self.out_features} + elif isinstance(parameters_split, tuple): + assert ( + self.out_features % len(parameters_split) == 0 + ), f"Weight and bias params cannot be split into {len(parameters_split)} parts" + split_size = self.out_features // len(parameters_split) + parameters_split = {key: split_size for key in parameters_split} + elif isinstance(parameters_split, dict): + overall_split_size = sum(parameters_split.values()) + assert( + self.out_features == overall_split_size + ), f"Overall sum of parameters_split (={overall_split_size}) does not match "\ + f"to out features (={self.out_features})" + else: + assert False, "Type of 'parameters_split' is not None, tuple or dict" + self.updated_parameters_split = parameters_split self.weight_names = [] self.bias_names = [] - for i, pname in enumerate(parameters_split): + slice_begin = 0 + for pname, slice_size in parameters_split.items(): wname = pname + "weight" bname = pname + "bias" + slice_end = slice_begin + slice_size + self.register_parameter( - wname, Parameter(self.weight_tensor[i * split_size : (i+1) * split_size]) + wname, Parameter(self.weight_tensor[slice_begin:slice_end]) ) set_tensor_model_parallel_attributes( @@ -735,7 +750,7 @@ def __init__( if self.use_bias: self.register_parameter( - bname, Parameter(self.bias_tensor[i * split_size : (i+1) * split_size]) + bname, Parameter(self.bias_tensor[slice_begin:slice_end]) ) else: setattr(self, bname, torch.Tensor().to(dtype=params_dtype, device=device)) @@ -746,6 +761,8 @@ def __init__( self.weight_names.append(wname) self.bias_names.append(bname) + slice_begin = slice_end + self.fp8_weight_shapes.append(torch.Size((self.out_features, self.in_features))) @@ -841,12 +858,14 @@ def forward( bias_tensor = ( self.bias if self.parameters_split is 
None else self.bias_tensor if not torch.is_grad_enabled() - else self.noop_cat("bias_tensor", self.bias_names) + else self.noop_cat("bias_tensor", self.bias_names, + self.updated_parameters_split) ) weight_tensor = ( self.weight if self.parameters_split is None else self.weight_tensor if not torch.is_grad_enabled() - else self.noop_cat("weight_tensor", self.weight_names) + else self.noop_cat("weight_tensor", self.weight_names, + self.updated_parameters_split) ) # Fetch the fp8 weights placeholders (for linear/gemm) diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index c54a7aed73..45a163966b 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -461,11 +461,14 @@ class Linear(TransformerEngineBaseModule): init_method : Callable, default = `None` used for initializing weights in the following way: `init_method(weight)`. When set to `None`, defaults to `torch.nn.init.normal_(mean=0.0, std=0.023)`. - parameters_split : Tuple[str, ...], default = None - if a tuple of strings is provided, the weight and bias parameters of the - module are exposed as `N` separate `torch.nn.parameter.Parameter`s each, - split along the first dimension, where `N` is the length of the argument - and the strings contained are the names of the split parameters. + parameters_split : Optional[Union[Tuple[str, ...], Dict[str, int]]], default = None + if a tuple of strings or a dict of strings to integers is provided, + the weight and bias parameters of the module are exposed as `N` separate + `torch.nn.parameter.Parameter`s each, split along the first dimension, + where `N` is the length of the argument and the strings contained are the + names of the split parameters. In the case of a tuple, each parameter + has the same shape. In the case of a dict, the values give the + `out_features` for each projection. device : Union[torch.device, str], default = "cuda" The device on which the parameters of the model will allocated. 
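As a usage sketch of the dict form of parameters_split documented above (names and sizes here are made up, not library defaults, and a CUDA device is assumed): a single fused projection whose query slice is wider than the key and value slices, exposed as separately named parameters over one shared buffer.

    import transformer_engine.pytorch as te

    hidden_size = 1024
    # Hypothetical GQA-style sizes: 16 query heads, 4 KV heads, head_dim 64.
    split = {"query_": 16 * 64, "key_": 4 * 64, "value_": 4 * 64}

    qkv_proj = te.Linear(hidden_size, sum(split.values()), parameters_split=split)

    # The names given in the dict prefix the exposed weight/bias parameters.
    print(qkv_proj.query_weight.shape)   # torch.Size([1024, 1024])
    print(qkv_proj.key_weight.shape)     # torch.Size([256, 1024])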
It is the user's responsibility to ensure all parameters are moved to the GPU before running the @@ -522,7 +525,7 @@ def __init__( params_dtype: Optional[torch.dtype] = None, parallel_mode: Optional[str] = None, skip_weight_param_allocation: bool = False, - parameters_split: Optional[Tuple[str, ...]] = None, + parameters_split: Optional[Union[Tuple[str, ...], Dict[str, int]]] = None, ub_split_rs: bool = False, ub_split_ag: bool = False, device: Union[torch.device, str] = "cuda", @@ -598,23 +601,35 @@ def __init__( self.bias_tensor.zero_() if parameters_split is None: - parameters_split = ("",) - - assert ( - self.out_features % len(parameters_split) == 0 - ), f"Weight and bias params cannot be split into {len(parameters_split)} parts" - - split_size = self.out_features // len(parameters_split) + parameters_split = {"": self.out_features} + elif isinstance(parameters_split, tuple): + assert ( + self.out_features % len(parameters_split) == 0 + ), f"Weight and bias params cannot be split into {len(parameters_split)} parts" + split_size = self.out_features // len(parameters_split) + parameters_split = {key: split_size for key in parameters_split} + elif isinstance(parameters_split, dict): + overall_split_size = sum(parameters_split.values()) + assert( + self.out_features == overall_split_size + ), f"Overall sum of parameters_split (={overall_split_size}) does not match "\ + f"to out features (={self.out_features})" + else: + assert False, "Type of 'parameters_split' is not None, tuple or dict" + self.updated_parameters_split = parameters_split self.weight_names = [] self.bias_names = [] - for i, pname in enumerate(parameters_split): + slice_begin = 0 + for pname, slice_size in parameters_split.items(): wname = pname + "weight" bname = pname + "bias" + slice_end = slice_begin + slice_size + self.register_parameter( - wname, Parameter(self.weight_tensor[i * split_size : (i+1) * split_size]) + wname, Parameter(self.weight_tensor[slice_begin:slice_end]) ) set_tensor_model_parallel_attributes( @@ -626,7 +641,7 @@ def __init__( if self.use_bias: self.register_parameter( - bname, Parameter(self.bias_tensor[i * split_size : (i+1) * split_size]) + bname, Parameter(self.bias_tensor[slice_begin:slice_end]) ) else: setattr(self, bname, torch.Tensor().to(dtype=params_dtype, device=device)) @@ -637,6 +652,8 @@ def __init__( self.weight_names.append(wname) self.bias_names.append(bname) + slice_begin = slice_end + self.fp8_weight_shapes.append(torch.Size((self.out_features, self.in_features))) # For RPL, bias has to be added after TP collectives @@ -715,12 +732,14 @@ def forward( bias_tensor = ( self.bias if self.parameters_split is None else self.bias_tensor if not torch.is_grad_enabled() - else self.noop_cat("bias_tensor", self.bias_names) + else self.noop_cat("bias_tensor", self.bias_names, + self.updated_parameters_split) ) weight_tensor = ( self.weight if self.parameters_split is None else self.weight_tensor if not torch.is_grad_enabled() - else self.noop_cat("weight_tensor", self.weight_names) + else self.noop_cat("weight_tensor", self.weight_names, + self.updated_parameters_split) ) # Fetch the fp8 weights placeholders (for linear/gemm) From 2f57bffa6321b385a6e4a679b8973c3c7676183e Mon Sep 17 00:00:00 2001 From: cyanguwa <8636796+cyanguwa@users.noreply.github.com> Date: Sun, 24 Sep 2023 23:00:37 -0700 Subject: [PATCH 57/68] [C/Pytorch] Expand layout support for fused attention (#403) * add flexible layout support Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * add support for 
flexible qkv layout Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * add more changes Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fixes for compiling Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * remove redudant file Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix options device error Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix typos Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * more changes; WIP Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * more changes; WIP Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fixes and tests Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fixes and wrong results Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * sb3hd/bs3hd working on top of 3xsbhd/bshd/thd Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix dQ, dK, dV Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * add nvtx Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * remove qkvso_strides on torch side; cover it in generateQKVStrides Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * all 15 layouts pass Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * add workspace optimization Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * minor fixes and test Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * removed most debug info/clean up Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * add note to deprecate some qkv layouts Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix code for unit tests in test_fused_attn.py Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * further remove debug info Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * remove a couple more comments Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix numerics tests Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fixes for lint Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix fp8 tests Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix onnx for core attn; not fixed Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * remove nvtx and add env var for workspace opt Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * remove testing for env var Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * replace zeros/zeros_like with empty/empty_like Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix nvtx marker name for _q_k_v API Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * remove sm80 when compiling for h100 Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * add mapping from qkv layout to layout group and qkv format Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * clean up enums mapping and remove trailing spaces Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * simplify workspace opt control logic; only need env var 
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix fp8 test, and minor modifications for other tests Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * avoid overwriting model configs in unit test Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * random fixes/improvements: get_qkv_format/etc, default values, docstrings, comments Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix minor issues: invalid syntax Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * change workspace opt logic back to FORCE_WORKSPACE_OPT Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix FP8 tests and generateStrides function Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix get_backend logic for max512/arbitrary Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix unit tests; need cleanup Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * clean up unit tests for layouts, and fix minor lint issue Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * minor tweaks for CI testing: onnx string issue and test fused attn first Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * remove one unsupported layout from max512 and add a check to qkvpacked API Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix te layer test; reduce test time Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * revert compiler option changes; add back sm80 for even h100 Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * remove some unit tests or make them optional to reduce CI time Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * remove more unit tests temporarily Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * remove _q_k_v in naming and add NVTE_ERROR for FP8 Aux_CTX_Tensors size checks Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * add more deprecation notes Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * remove temp tests from last commit Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * replace with te::getenv Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * remove prints from last commit Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * remove redundant contiguous() Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * remove thd->bs3hd user warning to avoid GPU sync Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * adjust fused attn bs in tests Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * temporary fix for onnx issue; more fixes in PR 437 Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * remove unused variables Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --------- Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Signed-off-by: Charlene Yang Signed-off-by: cyanguwa <8636796+cyanguwa@users.noreply.github.com> Signed-off-by: Kirthi Shankar Sivamani Co-authored-by: Kirthi Shankar Sivamani --- qa/L0_unittest/test.sh | 2 +- tests/pytorch/test_fused_attn.py | 202 ++++- tests/pytorch/test_numerics.py | 2 +- tests/pytorch/test_onnx_export.py | 2 + 
.../common/fused_attn/fused_attn.cpp | 273 ++++++- .../fused_attn_f16_arbitrary_seqlen.cu | 178 ++++- .../fused_attn_f16_arbitrary_seqlen.h | 24 + .../fused_attn_f16_max512_seqlen.cu | 139 +++- .../fused_attn/fused_attn_f16_max512_seqlen.h | 23 + .../common/fused_attn/fused_attn_fp8.cu | 220 ++++- .../common/fused_attn/fused_attn_fp8.h | 39 + transformer_engine/common/fused_attn/utils.cu | 262 +++++- .../include/transformer_engine/fused_attn.h | 215 ++++- transformer_engine/pytorch/attention.py | 753 +++++++++++------- transformer_engine/pytorch/constants.py | 5 + .../pytorch/cpp_extensions/fused_attn.py | 419 +++++++++- transformer_engine/pytorch/csrc/extensions.h | 46 ++ .../pytorch/csrc/extensions/attention.cu | 438 ++++++++++ .../pytorch/csrc/extensions/pybind.cpp | 21 +- transformer_engine/pytorch/transformer.py | 3 +- 20 files changed, 2832 insertions(+), 434 deletions(-) diff --git a/qa/L0_unittest/test.sh b/qa/L0_unittest/test.sh index f02ea1c6e8..268a534a82 100644 --- a/qa/L0_unittest/test.sh +++ b/qa/L0_unittest/test.sh @@ -9,6 +9,6 @@ set -e pip install pytest==6.2.5 onnxruntime==1.13.1 pytest -v -s $TE_PATH/tests/pytorch/test_sanity.py PYTORCH_JIT=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_numerics.py -NVTE_TORCH_COMPILE=0 pytest -v -s $TE_PATH/tests/pytorch/test_onnx_export.py pytest -v -s $TE_PATH/tests/pytorch/test_jit.py pytest -v -s $TE_PATH/tests/pytorch/test_fused_attn.py +NVTE_TORCH_COMPILE=0 pytest -v -s $TE_PATH/tests/pytorch/test_onnx_export.py diff --git a/tests/pytorch/test_fused_attn.py b/tests/pytorch/test_fused_attn.py index 1a1515d843..1b43fa36eb 100644 --- a/tests/pytorch/test_fused_attn.py +++ b/tests/pytorch/test_fused_attn.py @@ -39,20 +39,23 @@ def __init__( model_configs = { "test1": ModelConfig(1, 1024, 16, 64, 128, 0.0, "causal"), - "test2": ModelConfig(1, 1024, 16, 64, 512, 0.0, "causal"), - "test3": ModelConfig(1, 1024, 16, 64, 2048, 0.0, "causal"), - "test4": ModelConfig(1, 2048, 16, 128, 128, 0.0, "causal"), - "test5": ModelConfig(1, 2048, 16, 128, 512, 0.0, "causal"), - "test6": ModelConfig(1, 2048, 16, 128, 2048, 0.0, "causal"), - "test7": ModelConfig(1, 1024, 16, 64, 128, 0.0, "no_mask"), - "test8": ModelConfig(1, 1024, 16, 64, 512, 0.0, "no_mask"), + "test2": ModelConfig(1, 1024, 16, 64, 2048, 0.0, "causal"), + "test3": ModelConfig(1, 2048, 16, 128, 128, 0.0, "causal"), + "test4": ModelConfig(1, 3072, 24, 128, 2048, 0.0, "causal"), + "test5": ModelConfig(1, 1024, 16, 64, 128, 0.0, "no_mask"), } +if os.getenv('NVTE_ADDITIONAL_TESTS', '0') == '1': + model_configs["test6"] = ModelConfig(1, 1024, 16, 64, 512, 0.0, "causal") + model_configs["test7"] = ModelConfig(1, 2048, 16, 128, 512, 0.0, "causal") + model_configs["test8"] = ModelConfig(1, 2048, 16, 128, 2048, 0.0, "causal") + model_configs["test9"] = ModelConfig(1, 1024, 16, 64, 512, 0.0, "no_mask") + param_types = [torch.float16] if torch.cuda.is_bf16_supported(): param_types.append(torch.bfloat16) -batch_sizes = [1, 2, 32] +batch_sizes = [1, 2] # add more if needed, e.g. 
32 @pytest.mark.skipif( get_device_compute_capability() < 8.0, reason="Compute capability 8.0+ is required.") @@ -77,10 +80,10 @@ def test_dot_product_attention(dtype, bs, model, ckpt_attn, bias_type): atol, rtol = (2.5e-2, 2.5e-2) if dtype == torch.bfloat16 else (5e-3, 5e-3) if bias_type == "no_bias": - assert torch.allclose(fused_attn_fwd, flash_attn_fwd, atol=atol, rtol=rtol) - assert torch.allclose(fused_attn_bwd, flash_attn_bwd, atol=atol, rtol=rtol) - assert torch.allclose(fused_attn_fwd, unfused_attn_fwd, atol=atol, rtol=rtol) - assert torch.allclose(fused_attn_bwd, unfused_attn_bwd, atol=atol, rtol=rtol) + torch.testing.assert_close(fused_attn_fwd, flash_attn_fwd, atol=atol, rtol=rtol) + torch.testing.assert_close(fused_attn_bwd, flash_attn_bwd, atol=atol, rtol=rtol) + torch.testing.assert_close(fused_attn_fwd, unfused_attn_fwd, atol=atol, rtol=rtol) + torch.testing.assert_close(fused_attn_bwd, unfused_attn_bwd, atol=atol, rtol=rtol) def _run_dot_product_attention(dtype, bs, config, backend, ckpt_attn, bias_type): @@ -126,7 +129,11 @@ def _run_dot_product_attention(dtype, bs, config, backend, ckpt_attn, bias_type) q = inp[:, :,0,:,:] k = inp[:, :,1,:,:] v = inp[:, :,2,:,:] - op = block(q, k, v, attn_mask_type=config.attn_mask_type, + op = block(q, k, v, + qkv_format='sbhd', + cu_seqlens_q = cu_seqlens, + cu_seqlens_kv = cu_seqlens, + attn_mask_type=config.attn_mask_type, checkpoint_core_attention=ckpt_attn, core_attention_bias_type=bias_type, core_attention_bias=bias) @@ -134,6 +141,130 @@ def _run_dot_product_attention(dtype, bs, config, backend, ckpt_attn, bias_type) return op, inp.grad +qkv_layouts = [ + 'sb3hd', 'sbh3d', 'sbhd_sb2hd', 'sbhd_sbh2d', 'sbhd_sbhd_sbhd', + 'bs3hd', 'bsh3d', 'bshd_bs2hd', 'bshd_bsh2d', 'bshd_bshd_bshd', + # will add tests for thd layouts later when the support is available in fused attention + #'t3hd', 'th3d', 'thd_t2hd', 'thd_th2d', 'thd_thd_thd', + ] + +@pytest.mark.skipif( + get_device_compute_capability() < 8.0, reason="Compute capability 8.0+ is required.") +@pytest.mark.parametrize("dtype", param_types) +@pytest.mark.parametrize("bs", batch_sizes) +@pytest.mark.parametrize("model", model_configs.keys()) +@pytest.mark.parametrize("workspace_opt", [True, False]) +@pytest.mark.parametrize("qkv_layout", qkv_layouts) +def test_dpa_qkv_layout(dtype, bs, model, workspace_opt, qkv_layout): + """Test DotProductAttention module with different QKV layouts""" + + config = model_configs[model] + + flash_attn_fwd, flash_attn_bwd = _run_dpa_qkv_layout( + dtype, bs, config, "FlashAttention", qkv_layout, workspace_opt) + fused_attn_fwd, fused_attn_bwd = _run_dpa_qkv_layout( + dtype, bs, config, "FusedAttention", qkv_layout, workspace_opt) + unfused_attn_fwd, unfused_attn_bwd = _run_dpa_qkv_layout( + dtype, bs, config, "UnfusedDotProductAttention", qkv_layout, workspace_opt) + + atol, rtol = (5e-2, 5e-2) if dtype == torch.bfloat16 else (2.5e-3, 2.5e-3) + torch.testing.assert_close(flash_attn_fwd, unfused_attn_fwd, atol = atol, rtol = rtol) + torch.testing.assert_close(fused_attn_fwd, unfused_attn_fwd, atol = atol, rtol = rtol) + torch.testing.assert_close(fused_attn_fwd, flash_attn_fwd, atol = atol, rtol = rtol) + for i in range(len(flash_attn_bwd)): + torch.testing.assert_close(flash_attn_bwd[i], unfused_attn_bwd[i], atol = atol, rtol = rtol) + torch.testing.assert_close(fused_attn_bwd[i], flash_attn_bwd[i], atol = atol, rtol = rtol) + torch.testing.assert_close(fused_attn_bwd[i], unfused_attn_bwd[i], atol = atol, rtol = rtol) + +def _run_dpa_qkv_layout(dtype, 
bs, config, backend, qkv_layout, workspace_opt): + + torch.manual_seed(1234) + torch.cuda.manual_seed(1234) + os.environ["NVTE_FLASH_ATTN"] = "0" + os.environ["NVTE_FUSED_ATTN"] = "0" + if backend == "FlashAttention": + os.environ["NVTE_FLASH_ATTN"] = "1" + if backend == "FusedAttention": + os.environ["NVTE_FUSED_ATTN"] = "1" + os.environ["NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT"] = "1" if workspace_opt else "0" + + + dim_to_num = {'b': bs, + 's': config.seq_len, + 'h': config.num_attention_heads, + 'd': config.head_dim, + 't': bs * config.seq_len, + '3': 3, + '2': 2} + + inp = [] + for i,layout in enumerate(qkv_layout.split('_')): + tensor_shape = [dim_to_num[j] for j in layout] + tensor = 0.1 * torch.randn(tensor_shape, dtype = dtype).cuda() + tensor_count = 1 + split_dim = 0 + for dim,l in enumerate(layout): + if l.isdigit(): + tensor_count = int(l) + split_dim = dim + break + tensors = torch.split(tensor, 1, dim = split_dim) if split_dim != 0 else [tensor] + for j in range(tensor_count): + if split_dim != 0: + inp.append(tensors[j].squeeze(split_dim)) + else: + inp.append(tensors[j]) + for i in range(3): + inp[i].requires_grad=True + + seqlens = torch.empty(bs, dtype = torch.int32).cuda() + seqlens.fill_(config.seq_len) + cu_seqlens = torch.zeros(bs + 1, device = inp[0].device, dtype = torch.int32) + cu_seqlens[1:] = torch.cumsum(seqlens, dim = 0) + qkv_format = ''.join([i for i in qkv_layout.split('_')[0] if i.isalpha()]) + qkv_format_no_thd = qkv_format if qkv_format != 'thd' else 'bshd' + op_grad_shape = [dim_to_num[i] for i in qkv_format_no_thd] + op_grad_shape_new = [*op_grad_shape[:-2], op_grad_shape[-2] * op_grad_shape[-1]] + op_grad = 0.001 * torch.randint(0, 200, op_grad_shape_new, dtype = dtype).cuda() + + block = ( + DotProductAttention( + config.num_attention_heads, + config.head_dim, + attention_dropout = config.dropout_p, + attn_mask_type = config.attn_mask_type, + sequence_parallel = False, + tp_size = 1, + get_rng_state_tracker = None, + tp_group = None, + layer_number = 1, + attention_type = "self" + ).to(dtype = dtype).cuda() + ) + + if qkv_format != 'thd': + op = block(inp[0], inp[1], inp[2], qkv_format=qkv_format) + else: + cu_seqlens_q = torch.arange( + 0, + (bs + 1) * config.seq_len, + step=config.seq_len, + dtype=torch.int32, + device=inp[0].device) + cu_seqlens_kv = torch.arange( + 0, + (bs + 1) * config.seq_len, + step=config.seq_len, + dtype=torch.int32, + device=inp[1].device) + op = block(inp[0], inp[1], inp[2], + qkv_format=qkv_format, + cu_seqlens_q = cu_seqlens_q, + cu_seqlens_kv = cu_seqlens_kv) + op.backward(op_grad) + + return op, (inp[0].grad, inp[1].grad, inp[2].grad) + @pytest.mark.skipif( get_device_compute_capability() < 8.0, reason="Compute capability 8.0+ is required.") @pytest.mark.parametrize("dtype", param_types) @@ -158,10 +289,10 @@ def test_transformer_layer(dtype, bs, model, ckpt_attn, bias_type, fused_qkv_par atol, rtol = (5e-1, 5e-2) if bias_type == "no_bias": - assert torch.allclose(fused_attn_fwd, flash_attn_fwd, atol=atol, rtol=rtol) - assert torch.allclose(fused_attn_bwd, flash_attn_bwd, atol=atol, rtol=rtol) - assert torch.allclose(fused_attn_fwd, unfused_attn_fwd, atol=atol, rtol=rtol) - assert torch.allclose(fused_attn_bwd, unfused_attn_bwd, atol=atol, rtol=rtol) + torch.testing.assert_close(fused_attn_fwd, flash_attn_fwd, atol=atol, rtol=rtol) + torch.testing.assert_close(fused_attn_bwd, flash_attn_bwd, atol=atol, rtol=rtol) + torch.testing.assert_close(fused_attn_fwd, unfused_attn_fwd, atol=atol, rtol=rtol) + 
torch.testing.assert_close(fused_attn_bwd, unfused_attn_bwd, atol=atol, rtol=rtol) def _run_transformer_layer(dtype, bs, config, backend, ckpt_attn, bias_type, fused_qkv_params): @@ -231,7 +362,7 @@ def _run_transformer_layer(dtype, bs, config, backend, ckpt_attn, bias_type, fus .cuda() ) - num_iters = 10 + num_iters = 5 for i in range(num_iters): op = block(inp, self_attn_mask_type=config.attn_mask_type, checkpoint_core_attention=ckpt_attn, @@ -269,8 +400,8 @@ def find_factors(x): dtype, bs, config, "UnfusedDotProductAttention", num_q_per_gqa_group) atol, rtol = 5e-1, 5e-2 - assert torch.allclose(flash_attn_fwd, unfused_attn_fwd, atol=atol, rtol=rtol) - assert torch.allclose(flash_attn_bwd, unfused_attn_bwd, atol=atol, rtol=rtol) + torch.testing.assert_close(flash_attn_fwd, unfused_attn_fwd, atol=atol, rtol=rtol) + torch.testing.assert_close(flash_attn_bwd, unfused_attn_bwd, atol=atol, rtol=rtol) def _run_transformer_layer_gqa(dtype, bs, config, backend, num_querys_per_gqa_group): @@ -363,8 +494,8 @@ def test_dpa_fp8(dtype, bs, model): dtype, bs, config, "UnfusedDotProductAttention") atol, rtol = (2.5e-2, 2.5e-2) - assert torch.allclose(fused_attn_fwd, unfused_attn_fwd, atol=atol, rtol=rtol) - assert torch.allclose(fused_attn_bwd, unfused_attn_bwd, atol=atol, rtol=rtol) + torch.testing.assert_close(fused_attn_fwd, unfused_attn_fwd, atol=atol, rtol=rtol) + torch.testing.assert_close(fused_attn_bwd, unfused_attn_bwd, atol=atol, rtol=rtol) def _run_dpa_fp8(dtype, bs, config, backend): @@ -427,7 +558,7 @@ def _run_dpa_fp8_ref(dtype, bs, config, backend): attention_dropout=config.dropout_p, sequence_parallel=False, tp_size=1, - get_rng_state_tracker=None, + get_rng_state_tracker=get_dummy_cuda_rng_tracker, tp_group=None, layer_number=1, attention_type="self" @@ -439,8 +570,6 @@ def _run_dpa_fp8_ref(dtype, bs, config, backend): v = inp[:, :,2,:,:] op = block(q, k, v, attn_mask_type=config.attn_mask_type) op.backward(op_grad) - torch.save(op,'ctx_ref.pt') - torch.save(inp.grad,'dqkv_ref.pt') return op, inp.grad @@ -455,6 +584,8 @@ def _run_dpa_fp8_ref(dtype, bs, config, backend): from transformer_engine.pytorch.cpp_extensions.fused_attn import ( fused_attn_fwd_qkvpacked, fused_attn_bwd_qkvpacked, + fused_attn_fwd, + fused_attn_bwd, FusedAttnBackend) _CUBLASLT_WORKSPACE_SIZE_BYTES = 33_554_432 # 32MiB @@ -542,11 +673,15 @@ def forward( torch.save(qkv_out_fp16, 'qkv.pt') # FMHA - context_, aux_ctx_tensors, *rest = fused_attn_fwd_qkvpacked( + context_, aux_ctx_tensors, *rest = fused_attn_fwd( is_training, max_s, + max_s, cu_seqlens, - qkv_out, + cu_seqlens, + qkv_out[:,0,:,:], + qkv_out[:,1,:,:], + qkv_out[:,2,:,:], fp8_dtype_forward, FusedAttnBackend["FP8"], None, @@ -558,7 +693,7 @@ def forward( attn_scale=None, dropout=p_dropout, fast_zero_fill=fast_zero_fill, - qkv_layout="qkv_interleaved", + qkv_layout="t3hd", attn_bias_type="no_bias", attn_mask_type="padding", rng_gen=None, @@ -617,10 +752,14 @@ def backward( grad_output, ctx.fp8_meta["scaling_bwd"], META_DO, fp8_dtype_backward ) - dqkv, *rest = fused_attn_bwd_qkvpacked( + dq, dk, dv, *rest = fused_attn_bwd( ctx.max_s, + ctx.max_s, + ctx.cu_seqlens, ctx.cu_seqlens, - qkv_out, + qkv_out[:,0,:,:], + qkv_out[:,1,:,:], + qkv_out[:,2,:,:], context, proj_dgrad.view_as(context), fp8_dtype_forward, @@ -638,10 +777,11 @@ def backward( None, ctx.p_dropout, ctx.fast_zero_fill, - "qkv_interleaved", + "t3hd", "no_bias", "padding", ) + dqkv = torch.cat([dq.unsqueeze(1), dk.unsqueeze(1), dv.unsqueeze(1)], dim=1) dqkv_grad_output_c = dqkv.view(-1, 
3*ctx.hidden_size) dqkv_grad_output_c_fp16 = ext.cast_from_fp8(dqkv_grad_output_c, diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index bf9f7502fd..eeb14ba444 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -871,7 +871,7 @@ def _test_dpa_accuracy(block, bs, dtype, config): key.retain_grad() value.retain_grad() - out = block(query, key, value, mask) + out = block(query, key, value, attention_mask=mask) loss = out.sum() loss.backward() diff --git a/tests/pytorch/test_onnx_export.py b/tests/pytorch/test_onnx_export.py index 533e0cff6a..727ccce3dd 100644 --- a/tests/pytorch/test_onnx_export.py +++ b/tests/pytorch/test_onnx_export.py @@ -1005,6 +1005,7 @@ def test_export_core_attention( # Set dimensions (these are arbitrary). seq_len, batch_size, num_attention_heads, kv_channels = (64, 4, 1, 64) qkv_size = (seq_len, batch_size, num_attention_heads, kv_channels) + qkv_format = "sbhd" query_layer = torch.randn(qkv_size, dtype=precision, device="cuda") key_layer = torch.randn(qkv_size, dtype=precision, device="cuda") @@ -1025,6 +1026,7 @@ def test_export_core_attention( num_attention_heads=num_attention_heads, kv_channels=kv_channels, attention_dropout=0.5, + qkv_format=qkv_format, attn_mask_type=attn_mask_type, ).to(device='cuda') do_export(model, diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index a651ea005f..f724d1d051 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -12,6 +12,66 @@ #include "fused_attn_fp8.h" #include "../util/cuda_runtime.h" +// map NVTE_QKV_Layout to NVTE_QKV_Layout_Group +NVTE_QKV_Layout_Group nvte_get_qkv_layout_group(NVTE_QKV_Layout qkv_layout) { + switch (qkv_layout) { + case NVTE_QKV_Layout::NVTE_SB3HD: + case NVTE_QKV_Layout::NVTE_BS3HD: + case NVTE_QKV_Layout::NVTE_T3HD: + case NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED: + return NVTE_QKV_Layout_Group::NVTE_3HD; + case NVTE_QKV_Layout::NVTE_SBH3D: + case NVTE_QKV_Layout::NVTE_BSH3D: + case NVTE_QKV_Layout::NVTE_TH3D: + return NVTE_QKV_Layout_Group::NVTE_H3D; + case NVTE_QKV_Layout::NVTE_SBHD_SB2HD: + case NVTE_QKV_Layout::NVTE_BSHD_BS2HD: + case NVTE_QKV_Layout::NVTE_THD_T2HD: + case NVTE_QKV_Layout::NVTE_KV_INTERLEAVED: + return NVTE_QKV_Layout_Group::NVTE_HD_2HD; + case NVTE_QKV_Layout::NVTE_SBHD_SBH2D: + case NVTE_QKV_Layout::NVTE_BSHD_BSH2D: + case NVTE_QKV_Layout::NVTE_THD_TH2D: + return NVTE_QKV_Layout_Group::NVTE_HD_H2D; + case NVTE_QKV_Layout::NVTE_SBHD_SBHD_SBHD: + case NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD: + case NVTE_QKV_Layout::NVTE_THD_THD_THD: + case NVTE_QKV_Layout::NVTE_NOT_INTERLEAVED: + return NVTE_QKV_Layout_Group::NVTE_HD_HD_HD; + default: + NVTE_ERROR("qkv_layout not supported!"); + } +} + +// map NVTE_QKV_Layout to NVTE_QKV_Format +NVTE_QKV_Format nvte_get_qkv_format(NVTE_QKV_Layout qkv_layout) { + switch (qkv_layout) { + case NVTE_QKV_Layout::NVTE_SB3HD: + case NVTE_QKV_Layout::NVTE_SBH3D: + case NVTE_QKV_Layout::NVTE_SBHD_SB2HD: + case NVTE_QKV_Layout::NVTE_SBHD_SBH2D: + case NVTE_QKV_Layout::NVTE_SBHD_SBHD_SBHD: + return NVTE_QKV_Format::NVTE_SBHD; + case NVTE_QKV_Layout::NVTE_BS3HD: + case NVTE_QKV_Layout::NVTE_BSH3D: + case NVTE_QKV_Layout::NVTE_BSHD_BS2HD: + case NVTE_QKV_Layout::NVTE_BSHD_BSH2D: + case NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD: + return NVTE_QKV_Format::NVTE_BSHD; + case NVTE_QKV_Layout::NVTE_T3HD: + case NVTE_QKV_Layout::NVTE_TH3D: + case NVTE_QKV_Layout::NVTE_THD_T2HD: 
+ case NVTE_QKV_Layout::NVTE_THD_TH2D: + case NVTE_QKV_Layout::NVTE_THD_THD_THD: + case NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED: + case NVTE_QKV_Layout::NVTE_KV_INTERLEAVED: + case NVTE_QKV_Layout::NVTE_NOT_INTERLEAVED: + return NVTE_QKV_Format::NVTE_THD; + default: + NVTE_ERROR("qkv_layout not supported!"); + } +} + // select a backend for fused attention NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( NVTEDType q_dtype, @@ -26,6 +86,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( const int device_id = cuda::current_device(); const int sm_arch_ = cuda::sm_arch(device_id); NVTE_CHECK(q_dtype == kv_dtype, "Q and KV must have the same data type."); + NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); if ((q_dtype == NVTEDType::kNVTEFloat8E4M3) || (q_dtype == NVTEDType::kNVTEFloat8E5M2) && (sm_arch_ >= 90) && (max_seqlen_q == max_seqlen_kv) @@ -33,7 +94,8 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( && (head_dim == 64) && (bias_type == NVTE_Bias_Type::NVTE_NO_BIAS) && (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK) - && (qkv_layout == NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED)) { + && ((qkv_layout == NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED) + || (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD))) { #if (CUDNN_VERSION >= 8900) backend = NVTE_Fused_Attn_Backend::NVTE_FP8; #else @@ -52,7 +114,12 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( || (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK) || (attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK)) && ((qkv_layout == NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED) - || (qkv_layout == NVTE_QKV_Layout::NVTE_KV_INTERLEAVED))) { + || (qkv_layout == NVTE_QKV_Layout::NVTE_KV_INTERLEAVED) + || (qkv_layout == NVTE_QKV_Layout::NVTE_SB3HD) + || (qkv_layout == NVTE_QKV_Layout::NVTE_SBHD_SB2HD) + || (qkv_layout == NVTE_QKV_Layout::NVTE_BS3HD) + || (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BS2HD) + || (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD))) { flag_m512 = true; } if ( @@ -65,7 +132,9 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( && ((head_dim == 64) || (head_dim == 128)) && (bias_type == NVTE_Bias_Type::NVTE_NO_BIAS) && (attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK) - && (qkv_layout == NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED)) { + && ((qkv_layout == NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED) + || (qkv_format == NVTE_QKV_Format::NVTE_SBHD) + || (qkv_format == NVTE_QKV_Format::NVTE_BSHD))) { flag_arb = true; } if (((max_seqlen_q > 512) || (max_seqlen_kv > 512)) @@ -438,3 +507,201 @@ void nvte_fused_attn_bwd_kvpacked( NVTE_ERROR("Invalid combination of data type and sequence length for fused attention. 
\n"); } } +// NVTE fused attention FWD with separate Q, K and V +void nvte_fused_attn_fwd( + const NVTETensor Q, + const NVTETensor K, + const NVTETensor V, + const NVTETensor Bias, + NVTETensor S, + NVTETensor O, + NVTETensorPack* Aux_CTX_Tensors, + const NVTETensor cu_seqlens_q, + const NVTETensor cu_seqlens_kv, + const NVTETensor rng_state, + size_t max_seqlen_q, size_t max_seqlen_kv, + bool is_training, float attn_scale, float dropout, + NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, + NVTE_Mask_Type attn_mask_type, + NVTETensor workspace, + cudaStream_t stream) { + NVTE_API_CALL(nvte_flash_attn_fwd); + using namespace transformer_engine; + const Tensor *input_cu_seqlens_q = reinterpret_cast(cu_seqlens_q); + const Tensor *input_cu_seqlens_kv = reinterpret_cast(cu_seqlens_kv); + const Tensor *input_rng_state = reinterpret_cast(rng_state); + const Tensor *input_Q = reinterpret_cast(Q); + const Tensor *input_K = reinterpret_cast(K); + const Tensor *input_V = reinterpret_cast(V); + const Tensor *input_Bias = reinterpret_cast(Bias); + Tensor *input_output_S = reinterpret_cast(S); + Tensor *output_O = reinterpret_cast(O); + Tensor *wkspace = reinterpret_cast(workspace); + + auto ndim = input_Q->data.shape.size(); + size_t b = input_cu_seqlens_q->data.shape[0] - 1; + size_t h = input_Q->data.shape[ndim - 2]; + size_t d = input_Q->data.shape[ndim - 1]; + + auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle(); + const NVTEDType Q_type = static_cast(input_Q->data.dtype); + const NVTEDType KV_type = static_cast(input_K->data.dtype); + + NVTE_Fused_Attn_Backend fused_attention_backend = + nvte_get_fused_attn_backend( + Q_type, KV_type, + qkv_layout, bias_type, attn_mask_type, + dropout, max_seqlen_q, max_seqlen_kv, d); + + if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { +#if (CUDNN_VERSION >= 8901) + fused_attn_max_512_fwd( + b, max_seqlen_q, max_seqlen_kv, h, d, + is_training, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, + input_Q, input_K, input_V, input_Bias, output_O, + Aux_CTX_Tensors, + input_cu_seqlens_q, input_cu_seqlens_kv, + input_rng_state, + wkspace, stream, handle); +#else + NVTE_ERROR("cuDNN 8.9.1 is required for BF16/FP16 fused attention with max_seqlen<=512. \n"); +#endif + } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) { +#if (CUDNN_VERSION >= 8900) + fused_attn_arbitrary_seqlen_fwd( + b, max_seqlen_q, max_seqlen_kv, h, d, + is_training, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, + input_Q, input_K, input_V, input_Bias, output_O, + Aux_CTX_Tensors, + input_cu_seqlens_q, input_cu_seqlens_kv, + input_rng_state, + wkspace, stream, handle); +#else + NVTE_ERROR( + "cuDNN 8.9.0 is required for BF16/FP16 fused attention with arbitrary sequence length. \n"); +#endif + } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_FP8) { +#if (CUDNN_VERSION >= 8900) + fused_attn_fp8_fwd( + b, max_seqlen_q, max_seqlen_kv, h, d, + is_training, attn_scale, dropout, qkv_layout, + input_Q, input_K, input_V, input_output_S, output_O, + Aux_CTX_Tensors, + input_cu_seqlens_q, input_cu_seqlens_kv, + input_rng_state, + wkspace, stream, handle); +#else + NVTE_ERROR("cuDNN 8.9.0 is required for FP8 fused attention. \n"); +#endif + } else { + NVTE_ERROR("Invalid combination of data type and sequence length for fused attention. 
\n"); + } +} +// NVTE fused attention BWD with separate Q, K and V +void nvte_fused_attn_bwd( + const NVTETensor Q, + const NVTETensor K, + const NVTETensor V, + const NVTETensor O, + const NVTETensor dO, + const NVTETensor S, + NVTETensor dP, + const NVTETensorPack* Aux_CTX_Tensors, + NVTETensor dQ, + NVTETensor dK, + NVTETensor dV, + NVTETensor dBias, + const NVTETensor cu_seqlens_q, + const NVTETensor cu_seqlens_kv, + size_t max_seqlen_q, size_t max_seqlen_kv, + float attn_scale, float dropout, + NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, + NVTE_Mask_Type attn_mask_type, + NVTETensor workspace, + cudaStream_t stream) { + NVTE_API_CALL(nvte_flash_attn_bwd); + using namespace transformer_engine; + const Tensor *input_cu_seqlens_q = reinterpret_cast(cu_seqlens_q); + const Tensor *input_cu_seqlens_kv = reinterpret_cast(cu_seqlens_kv); + const Tensor *input_Q = reinterpret_cast(Q); + const Tensor *input_K = reinterpret_cast(K); + const Tensor *input_V = reinterpret_cast(V); + const Tensor *input_O = reinterpret_cast(O); + const Tensor *input_dO = reinterpret_cast(dO); + const Tensor *input_S = reinterpret_cast(S); + Tensor *input_output_dP = reinterpret_cast(dP); + Tensor *output_dQ = reinterpret_cast(dQ); + Tensor *output_dK = reinterpret_cast(dK); + Tensor *output_dV = reinterpret_cast(dV); + Tensor *output_dBias = reinterpret_cast(dBias); + Tensor *wkspace = reinterpret_cast(workspace); + + auto ndim = input_Q->data.shape.size(); + size_t b = input_cu_seqlens_q->data.shape[0] - 1; + size_t h = input_Q->data.shape[ndim - 2]; + size_t d = input_Q->data.shape[ndim - 1]; + + auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle(); + const NVTEDType Q_type = static_cast(input_Q->data.dtype); + const NVTEDType KV_type = static_cast(input_K->data.dtype); + + NVTE_Fused_Attn_Backend fused_attention_backend = + nvte_get_fused_attn_backend( + Q_type, KV_type, + qkv_layout, bias_type, attn_mask_type, + dropout, max_seqlen_q, max_seqlen_kv, d); + + if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { +#if (CUDNN_VERSION >= 8901) + Tensor *output_S = reinterpret_cast(Aux_CTX_Tensors->tensors[0]); + fused_attn_max_512_bwd( + b, max_seqlen_q, max_seqlen_kv, h, d, + attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, + input_Q, input_K, input_V, input_dO, + output_S, + output_dQ, output_dK, output_dV, output_dBias, + input_cu_seqlens_q, input_cu_seqlens_kv, + wkspace, stream, handle); +#else + NVTE_ERROR("cuDNN 8.9.1 is required for BF16/FP16 fused attention with max_seqlen<=512. \n"); +#endif + } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) { +#if (CUDNN_VERSION >= 8900) + Tensor *output_S = reinterpret_cast(Aux_CTX_Tensors->tensors[0]); + const Tensor *input_rng_state = reinterpret_cast(Aux_CTX_Tensors->tensors[1]); + fused_attn_arbitrary_seqlen_bwd( + b, max_seqlen_q, max_seqlen_kv, h, d, + attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, + input_Q, input_K, input_V, input_O, input_dO, + output_S, + output_dQ, output_dK, output_dV, output_dBias, + input_cu_seqlens_q, input_cu_seqlens_kv, + input_rng_state, wkspace, stream, handle); +#else + const char *err_msg = + "cuDNN 8.9.0 is required for BF16/FP16 fused attention " + "with arbitrary sequence length. 
\n"; + NVTE_ERROR(err_msg); +#endif + } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_FP8) { +#if (CUDNN_VERSION >= 8900) + const Tensor *input_M = reinterpret_cast(Aux_CTX_Tensors->tensors[0]); + const Tensor *input_ZInv = reinterpret_cast(Aux_CTX_Tensors->tensors[1]); + const Tensor *input_rng_state = reinterpret_cast(Aux_CTX_Tensors->tensors[2]); + fused_attn_fp8_bwd( + b, max_seqlen_q, max_seqlen_kv, h, d, + attn_scale, dropout, qkv_layout, + input_Q, input_K, input_V, input_O, input_dO, + input_M, input_ZInv, + input_S, input_output_dP, + output_dQ, output_dK, output_dV, + input_cu_seqlens_q, input_cu_seqlens_kv, + input_rng_state, + wkspace, stream, handle); +#else + NVTE_ERROR("cuDNN 8.9.0 is required for FP8 fused attention. \n"); +#endif + } else { + NVTE_ERROR("Invalid combination of data type and sequence length for fused attention. \n"); + } +} diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu index 8bed01732e..e2da13729b 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu @@ -15,6 +15,7 @@ #include "../common.h" #include "utils.h" #include "../util/cuda_runtime.h" +#include "../util/system.h" #if (CUDNN_VERSION >= 8900) #define Q_ID 1 @@ -1059,6 +1060,7 @@ void fused_attn_arbitrary_seqlen_bwd_impl( auto matmul_3_Desc = cudnn_frontend::MatMulDescBuilder() .setComputeType(CUDNN_DATA_FLOAT) .build(); + if (!use_workspace_opt) { auto matmul_op3 = cudnn_frontend::OperationBuilder( CUDNN_BACKEND_OPERATION_MATMUL_DESCRIPTOR) @@ -1221,9 +1223,6 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked( Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) { using namespace transformer_engine; - NVTE_CHECK(qkv_layout == NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED, - "qkv_layout must be NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED."); - // QKV shape is [b, s, 3, h, d] void *devPtrQKV = input_QKV->data.dptr; const auto stride = 2 * num_head * head_dim; @@ -1295,9 +1294,6 @@ void fused_attn_arbitrary_seqlen_bwd_qkvpacked(size_t batch, size_t max_seqlen, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) { using namespace transformer_engine; - NVTE_CHECK(qkv_layout == NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED, - "qkv_layout must be NVTE_QKV_INTERLEAVED."); - // QKV shape is [b, s, 3, h, d] void *devPtrQKV = input_QKV->data.dptr; @@ -1337,21 +1333,16 @@ void fused_attn_arbitrary_seqlen_bwd_qkvpacked(size_t batch, size_t max_seqlen, (batch * num_head * max_seqlen_div_up_q * max_seqlen_div_up_kv * 2 + 1048576 - 1) / 1048576; // default upper limit for dp workspace 256MB size_t max_allowed_dp_workspace = 256; - const char* env_workspace_limit_char = std::getenv("NVTE_FUSED_ATTN_DP_WORKSPACE_LIMIT"); - if (env_workspace_limit_char != nullptr) { - try { - std::string env_dp_workspace_limit(env_workspace_limit_char); - int dp_workspace_limit = std::stoi(env_dp_workspace_limit); - if (dp_workspace_limit > max_allowed_dp_workspace) { - max_allowed_dp_workspace = dp_workspace_limit; - } - } catch (...) { - NVTE_ERROR( - "Invalid argument for NVTE_FUSED_ATTN_DP_WORKSPACE_LIMIT (integer; in MBytes)! 
\n"); - } - } if (required_dp_workspace <= max_allowed_dp_workspace) { - use_workspace_opt = true; + use_workspace_opt = true; + } + use_workspace_opt = transformer_engine::getenv( + "NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT", use_workspace_opt); + // will not be needed in cuDNN 8.9.6 + NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); + if ((layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) + || (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_H2D)) { + use_workspace_opt = false; } } #endif @@ -1378,5 +1369,152 @@ void fused_attn_arbitrary_seqlen_bwd_qkvpacked(size_t batch, size_t max_seqlen, NVTE_ERROR("Unexpected workspace_size."); } } + +void fused_attn_arbitrary_seqlen_fwd( + size_t batch, size_t max_seqlen_q, size_t max_seqlen_kv, + size_t num_head, size_t head_dim, bool is_training, + float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, + NVTE_Mask_Type mask_type, const Tensor *input_Q, const Tensor *input_K, + const Tensor *input_V, const Tensor *input_Bias, Tensor *output_O, + NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, + const Tensor *rng_state, + Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) { + using namespace transformer_engine; + + const DType QKV_type = input_Q->data.dtype; + void *devPtrQ = input_Q->data.dptr; + void *devPtrK = input_K->data.dptr; + void *devPtrV = input_V->data.dptr; + void *devPtrO = output_O->data.dptr; + void *devPtrS = nullptr; + + if (Aux_CTX_Tensors->size == 0) { + Aux_CTX_Tensors->size = 2; + Tensor *output_S = reinterpret_cast(Aux_CTX_Tensors->tensors[0]); + output_S->data.dptr = nullptr; + output_S->data.shape = {batch, num_head, max_seqlen_q, 1}; + output_S->data.dtype = DType::kFloat32; + Tensor *output_rng_state = reinterpret_cast(Aux_CTX_Tensors->tensors[1]); + output_rng_state->data.dptr = nullptr; + output_rng_state->data.shape = {2}; + output_rng_state->data.dtype = DType::kInt64; + } else if (Aux_CTX_Tensors->size == 2) { + Tensor *output_S = reinterpret_cast(Aux_CTX_Tensors->tensors[0]); + devPtrS = output_S->data.dptr; + Tensor *output_rng_state = reinterpret_cast(Aux_CTX_Tensors->tensors[1]); + output_rng_state->data.dptr = rng_state->data.dptr; + } else { + NVTE_ERROR("Unexpected Aux_CTX_Tensors->size."); + } + + void* devPtrDropoutSeed = rng_state->data.dptr; + void* devPtrDropoutOffset = reinterpret_cast( + reinterpret_cast(rng_state->data.dptr) + 1); + + size_t workspace_size = 0; + + fused_attn_arbitrary_seqlen_fwd_impl(batch, num_head, max_seqlen_q, max_seqlen_kv, head_dim, + is_training, attn_scale, p_dropout, qkv_layout, + devPtrQ, devPtrK, devPtrV, devPtrS, devPtrO, + devPtrDropoutSeed, devPtrDropoutOffset, + get_cudnn_dtype(QKV_type), + workspace->data.dptr, &workspace_size, stream, handle); + + if (workspace_size > 0) { + if (workspace->data.dptr == nullptr) { + workspace->data.shape = {workspace_size}; + workspace->data.dtype = DType::kByte; + return; + } + } else if (workspace_size == 0) { + workspace->data.shape = {1}; + workspace->data.dtype = DType::kByte; + return; + } else { + NVTE_ERROR("Unexpected workspace_size."); + } +} + +void fused_attn_arbitrary_seqlen_bwd(size_t batch, size_t max_seqlen_q, size_t max_seqlen_kv, + size_t num_head, size_t head_dim, float attn_scale, + float p_dropout, NVTE_QKV_Layout qkv_layout, + NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, + const Tensor *input_Q, const Tensor *input_K, + const Tensor *input_V, const Tensor *input_O, + const Tensor *input_dO, Tensor 
*output_S, + Tensor *output_dQ, Tensor *output_dK, Tensor *output_dV, + Tensor *output_dBias, const Tensor *cu_seqlens_q, + const Tensor *cu_seqlens_kv, + const Tensor *rng_state, Tensor *workspace, + cudaStream_t stream, cudnnHandle_t handle) { + using namespace transformer_engine; + + const auto QKV_type = input_Q->data.dtype; + void *devPtrQ = input_Q->data.dptr; + void *devPtrK = input_K->data.dptr; + void *devPtrV = input_V->data.dptr; + void* devPtrO = input_O->data.dptr; + void *devPtrdO = input_dO->data.dptr; + + void *devPtrdQ = output_dQ->data.dptr; + void *devPtrdK = output_dK->data.dptr; + void *devPtrdV = output_dV->data.dptr; + void *devPtrSoftmaxStats = nullptr; + devPtrSoftmaxStats = output_S->data.dptr; + + void* devPtrDropoutSeed = rng_state->data.dptr; + void* devPtrDropoutOffset = reinterpret_cast( + reinterpret_cast(rng_state->data.dptr) + 1); + + size_t workspace_size = 0; + + bool use_workspace_opt = false; +#if (CUDNN_VERSION >= 8905) + const int device_id = cuda::current_device(); + const int sm_arch_ = cuda::sm_arch(device_id); + if (sm_arch_ >= 90) { + // quick estimate of dp workspace size + size_t max_seqlen_div_up_q = ((max_seqlen_q + 64 - 1) / 64) * 64; + size_t max_seqlen_div_up_kv = ((max_seqlen_kv + 64 - 1) / 64) * 64; + size_t required_dp_workspace = + (batch * num_head * max_seqlen_div_up_q * max_seqlen_div_up_kv * 2 + 1048576 - 1) / 1048576; + // default upper limit for dp workspace 256MB + size_t max_allowed_dp_workspace = 256; + if (required_dp_workspace <= max_allowed_dp_workspace) { + use_workspace_opt = true; + } + use_workspace_opt = transformer_engine::getenv( + "NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT", use_workspace_opt); + // will not be needed in cuDNN 8.9.6 + NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); + if ((layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) + || (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_H2D)) { + use_workspace_opt = false; + } + } +#endif + + fused_attn_arbitrary_seqlen_bwd_impl(batch, num_head, max_seqlen_q, max_seqlen_kv, head_dim, + attn_scale, p_dropout, qkv_layout, + devPtrQ, devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats, + devPtrdQ, devPtrdK, devPtrdV, devPtrdO, + devPtrDropoutSeed, devPtrDropoutOffset, + get_cudnn_dtype(QKV_type), workspace->data.dptr, + &workspace_size, stream, handle, use_workspace_opt); + + if (workspace_size > 0) { + if (workspace->data.dptr == nullptr) { + workspace->data.shape = {workspace_size}; + workspace->data.dtype = DType::kByte; + return; + } + } else if (workspace_size == 0) { + workspace->data.shape = {1}; + workspace->data.dtype = DType::kByte; + return; + } else { + NVTE_ERROR("Unexpected workspace_size."); + } +} } // namespace transformer_engine #endif // CUDNN_VERSION >= 8900 diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h index 68ebe0c7c0..202e06987d 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h @@ -38,6 +38,30 @@ void fused_attn_arbitrary_seqlen_bwd_qkvpacked(size_t batch, size_t max_seqlen, const Tensor *cu_seqlens, const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); +void fused_attn_arbitrary_seqlen_fwd(size_t batch, size_t max_seqlen_q, size_t max_seqlen_kv, + size_t num_head, size_t head_size, bool is_training, + float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, + 
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, + const Tensor *input_Q, const Tensor *input_K, + const Tensor *input_V, const Tensor *input_Bias, + Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, + const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, + const Tensor *rng_state, + Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); + +void fused_attn_arbitrary_seqlen_bwd(size_t batch, size_t max_seqlen_q, size_t max_seqlen_kv, + size_t num_head, size_t head_dim, float attn_scale, + float p_dropout, NVTE_QKV_Layout qkv_layout, + NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, + const Tensor *input_Q, const Tensor *input_K, + const Tensor *input_V, const Tensor *input_O, + const Tensor *input_dO, Tensor *output_S, + Tensor *output_dQ, Tensor *output_dK, + Tensor *output_dV, Tensor *output_dBias, + const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, + const Tensor *rng_state, + Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); + #endif // CUDNN_VERSION >= 8900 } // namespace transformer_engine diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_max512_seqlen.cu b/transformer_engine/common/fused_attn/fused_attn_f16_max512_seqlen.cu index 00fb3e66c2..663ff37187 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_max512_seqlen.cu +++ b/transformer_engine/common/fused_attn/fused_attn_f16_max512_seqlen.cu @@ -1250,9 +1250,6 @@ void fused_attn_max_512_fwd_qkvpacked( Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) { using namespace transformer_engine; - NVTE_CHECK(qkv_layout == NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED, - "qkv_layout must be NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED."); - // QKV shape is [b, s, 3, h, d] void *devPtrQKV = input_QKV->data.dptr; const auto stride = 2 * num_head * head_dim; @@ -1323,8 +1320,6 @@ void fused_attn_max_512_fwd_kvpacked(size_t batch, size_t q_max_seqlen, size_t k Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) { using namespace transformer_engine; - NVTE_CHECK(qkv_layout == NVTE_QKV_Layout::NVTE_KV_INTERLEAVED, - "qkv_layout must be NVTE_QKV_Layout::NVTE_KV_INTERLEAVED."); NVTE_CHECK(bias_type == NVTE_Bias_Type::NVTE_NO_BIAS || bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS, "NVTE_PRE_SCALE_BIAS is not implemented in fused_attn_max_512."); @@ -1391,6 +1386,76 @@ void fused_attn_max_512_fwd_kvpacked(size_t batch, size_t q_max_seqlen, size_t k NVTE_ERROR("Unexpected workspace_size."); } } +void fused_attn_max_512_fwd(size_t batch, size_t q_max_seqlen, size_t kv_max_seqlen, + size_t num_head, size_t head_dim, bool is_training, + float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, + NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, + const Tensor *input_Q, const Tensor *input_K, + const Tensor *input_V, + const Tensor *input_Bias, Tensor *output_O, + NVTETensorPack *Aux_CTX_Tensors, const Tensor *q_cu_seqlens, + const Tensor *kv_cu_seqlens, const Tensor *rng_state, + Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) { + using namespace transformer_engine; + + void *devPtrQ = input_Q->data.dptr; + void *devPtrK = input_K->data.dptr; + void *devPtrV = input_V->data.dptr; + + void *devPtrBias = input_Bias->data.dptr; + + void *devPtrO = output_O->data.dptr; + + void *devPtrS = nullptr; + + const DType q_type = input_Q->data.dtype; + const DType kv_type = input_K->data.dtype; + NVTE_CHECK(q_type == kv_type, "data type of Q must be equal to data type of KV."); + + if (Aux_CTX_Tensors->size == 0) { + Aux_CTX_Tensors->size = 1; + Tensor *output_S = 
reinterpret_cast(Aux_CTX_Tensors->tensors[0]); + output_S->data.dptr = nullptr; + output_S->data.shape = {batch, num_head, q_max_seqlen, kv_max_seqlen}; + output_S->data.dtype = q_type; + } else if (Aux_CTX_Tensors->size == 1) { + Tensor *output_S = reinterpret_cast(Aux_CTX_Tensors->tensors[0]); + devPtrS = output_S->data.dptr; + } else { + NVTE_ERROR("Unexpected Aux_CTX_Tensors->size."); + } + + void *devQCuSeqlen = q_cu_seqlens->data.dptr; + void *devKVCuSeqlen = kv_cu_seqlens->data.dptr; + + const DType rng_state_type = rng_state->data.dtype; + NVTE_CHECK(rng_state_type == DType::kInt64); + void *devPtrDropoutSeed = rng_state->data.dptr; + void *devPtrDropoutOffset = + static_cast(static_cast(rng_state->data.dptr) + 1); + + size_t workspace_size = 0; + + fused_attn_max_512_fwd_impl( + batch, num_head, q_max_seqlen, kv_max_seqlen, head_dim, is_training, attn_scale, p_dropout, + qkv_layout, bias_type, mask_type, devPtrQ, devPtrK, devPtrV, devPtrS, devPtrO, devPtrBias, + devQCuSeqlen, devKVCuSeqlen, devPtrDropoutSeed, devPtrDropoutOffset, workspace->data.dptr, + &workspace_size, get_cudnn_dtype(q_type), stream, handle); + + if (workspace_size > 0) { + if (workspace->data.dptr == nullptr) { + workspace->data.shape = {workspace_size}; + workspace->data.dtype = DType::kByte; + return; + } + } else if (workspace_size == 0) { + workspace->data.shape = {1}; + workspace->data.dtype = DType::kByte; + return; + } else { + NVTE_ERROR("Unexpected workspace_size."); + } +} void fused_attn_max_512_bwd_qkvpacked(size_t batch, size_t max_seqlen, size_t num_head, size_t head_dim, float attn_scale, float p_dropout, @@ -1402,9 +1467,6 @@ void fused_attn_max_512_bwd_qkvpacked(size_t batch, size_t max_seqlen, size_t nu cudaStream_t stream, cudnnHandle_t handle) { using namespace transformer_engine; - NVTE_CHECK(qkv_layout == NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED, - "qkv_layout must be NVTE_QKV_INTERLEAVED."); - // QKV shape is [b, s, 3, h, d] void *devPtrQKV = input_QKV->data.dptr; @@ -1465,9 +1527,6 @@ void fused_attn_max_512_bwd_kvpacked(size_t batch, size_t q_max_seqlen, size_t k Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) { using namespace transformer_engine; - NVTE_CHECK(qkv_layout == NVTE_QKV_Layout::NVTE_KV_INTERLEAVED, - "qkv_layout must be NVTE_KV_INTERLEAVED."); - // Q shape is [b, s, h, d] // KV shape is [b, s, 2, h, d] auto stride = 2 * num_head * head_dim; @@ -1518,5 +1577,63 @@ void fused_attn_max_512_bwd_kvpacked(size_t batch, size_t q_max_seqlen, size_t k NVTE_ERROR("Unexpected workspace_size."); } } +void fused_attn_max_512_bwd(size_t batch, size_t q_max_seqlen, size_t kv_max_seqlen, + size_t num_head, size_t head_dim, float attn_scale, + float p_dropout, NVTE_QKV_Layout qkv_layout, + NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, + const Tensor *input_Q, const Tensor *input_K, + const Tensor *input_V, + const Tensor *input_dO, Tensor *output_S, + Tensor *output_dQ, Tensor *output_dK, Tensor *output_dV, + Tensor *output_dBias, + const Tensor *q_cu_seqlens, const Tensor *kv_cu_seqlens, + Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) { + using namespace transformer_engine; + + void *devPtrQ = input_Q->data.dptr; + void *devPtrK = input_K->data.dptr; + void *devPtrV = input_V->data.dptr; + + void *devPtrdO = input_dO->data.dptr; + + void *devPtrdQ = output_dQ->data.dptr; + void *devPtrdK = output_dK->data.dptr; + void *devPtrdV = output_dV->data.dptr; + + void *devPtrdBias = output_dBias->data.dptr; + + void *devPtrS = output_S->data.dptr; + + // devPtrdS 
reuses the memory of devPtrS + void *devPtrdS = devPtrS; + + void *devPtrQCuSeqlens = q_cu_seqlens->data.dptr; + void *devPtrKVCuSeqlens = kv_cu_seqlens->data.dptr; + + const auto q_type = input_Q->data.dtype; + const auto kv_type = input_K->data.dtype; + NVTE_CHECK(q_type == kv_type, "data type of Q must be equal to data type of KV."); + size_t workspace_size = 0; + + fused_attn_max_512_bwd_impl( + batch, num_head, q_max_seqlen, kv_max_seqlen, head_dim, attn_scale, p_dropout, qkv_layout, + mask_type, bias_type, devPtrQ, devPtrK, devPtrV, devPtrS, devPtrdQ, devPtrdK, devPtrdV, + devPtrdO, devPtrdS, devPtrdBias, devPtrQCuSeqlens, devPtrKVCuSeqlens, workspace->data.dptr, + &workspace_size, get_cudnn_dtype(q_type), stream, handle); + + if (workspace_size > 0) { + if (workspace->data.dptr == nullptr) { + workspace->data.shape = {workspace_size}; + workspace->data.dtype = DType::kByte; + return; + } + } else if (workspace_size == 0) { + workspace->data.shape = {1}; + workspace->data.dtype = DType::kByte; + return; + } else { + NVTE_ERROR("Unexpected workspace_size."); + } +} } // namespace transformer_engine #endif // CUDNN_VERSION >= 8901 diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_max512_seqlen.h b/transformer_engine/common/fused_attn/fused_attn_f16_max512_seqlen.h index 75545d0b40..e2106347ff 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_max512_seqlen.h +++ b/transformer_engine/common/fused_attn/fused_attn_f16_max512_seqlen.h @@ -38,6 +38,17 @@ void fused_attn_max_512_fwd_kvpacked(size_t batch, size_t q_max_seqlen, size_t k const Tensor *kv_cu_seqlens, const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); +void fused_attn_max_512_fwd(size_t batch, size_t q_max_seqlen, size_t kv_max_seqlen, + size_t num_head, size_t head_dim, bool is_training, + float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, + NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, + const Tensor *input_Q, const Tensor *input_K, + const Tensor *input_V, + const Tensor *input_Bias, Tensor *output_O, + NVTETensorPack *Aux_CTX_Tensors, const Tensor *q_cu_seqlens, + const Tensor *kv_cu_seqlens, const Tensor *rng_state, + Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); + void fused_attn_max_512_bwd_qkvpacked(size_t batch, size_t max_seqlen, size_t num_head, size_t head_dim, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, @@ -56,6 +67,18 @@ void fused_attn_max_512_bwd_kvpacked(size_t batch, size_t q_max_seqlen, size_t k Tensor *output_dQ, Tensor *output_dKV, Tensor *output_dBias, const Tensor *q_cu_seqlens, const Tensor *kv_cu_seqlens, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); + +void fused_attn_max_512_bwd(size_t batch, size_t q_max_seqlen, size_t kv_max_seqlen, + size_t num_head, size_t head_dim, float attn_scale, + float p_dropout, NVTE_QKV_Layout qkv_layout, + NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, + const Tensor *input_Q, const Tensor *input_K, + const Tensor *input_V, + const Tensor *input_dO, Tensor *output_S, + Tensor *output_dQ, Tensor *output_dK, Tensor *output_dV, + Tensor *output_dBias, + const Tensor *q_cu_seqlens, const Tensor *kv_cu_seqlens, + Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); #endif // CUDNN_VERSION >= 8901 } // namespace transformer_engine diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.cu b/transformer_engine/common/fused_attn/fused_attn_fp8.cu index c4bdecac8f..120406202e 100644 --- 
a/transformer_engine/common/fused_attn/fused_attn_fp8.cu +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.cu @@ -173,6 +173,7 @@ static cudnn_frontend::Tensor createScale( static cudnn_frontend::Tensor createScaleWithOffset( const cudnn_frontend::Tensor& prevBlockOutputTensor, const std::string& scale_tensor_name, + NVTE_QKV_Layout layout, cudnnDataType_t tensorType, bool isOutputVirtual, bool isScaleByValue, @@ -192,7 +193,7 @@ static cudnn_frontend::Tensor createScaleWithOffset( generateMatrixStrides(output_dim[0], output_dim[1], output_dim[2], 0 /*s_kv = 0 for placeholder*/, output_dim[3], output_stride, - NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED, NVTE_QKV_Matrix::NVTE_Q_Matrix); + layout, NVTE_QKV_Matrix::NVTE_Q_Matrix); } else { // Otherwise output dim and stride should be the same as prev block dim and stride for (int i = 0; i < 4; i++) { @@ -1163,6 +1164,7 @@ void fused_attn_fp8_fwd_impl(int64_t b, int64_t s_q, int64_t s_kv, int64_t h, in auto OTensor = createScaleWithOffset( OTensor_before_quan_O_tensor, // input tensor "scaleO", // scale tensor + layout, // qkv layout tensorType, // output tensor type false, // output not virtual false, // scale is by value @@ -1515,6 +1517,7 @@ void fused_attn_fp8_bwd_impl(int64_t b, int64_t s_q, int64_t s_kv, int64_t h, in auto dVTensor = createScaleWithOffset( dVTensor_before_quan_dV, // input tensor "scaledV", // scale tensor + layout, // qkv layout CUDNN_DATA_FP8_E5M2, // output tensor type false, // output not virtual false, // scale is by value @@ -1631,7 +1634,7 @@ void fused_attn_fp8_bwd_impl(int64_t b, int64_t s_q, int64_t s_kv, int64_t h, in // (dS * K) * descale dS auto After_dS_K_before_dequan_K = createScale( - After_dS_K, // input tensor + After_dS_K, // input tensor descaledSTensor, // scale tensor CUDNN_DATA_FLOAT, // output tensor type true, // output is virtual @@ -1641,7 +1644,7 @@ void fused_attn_fp8_bwd_impl(int64_t b, int64_t s_q, int64_t s_kv, int64_t h, in // (dS * K) * descale dS * descale K auto After_dS_K_before_quan_dQ = createScale( - After_dS_K_before_dequan_K, // input tensor + After_dS_K_before_dequan_K, // input tensor descaleKTensor, // scale tensor CUDNN_DATA_FLOAT, // output tensor type true, // output is virtual @@ -1651,8 +1654,9 @@ void fused_attn_fp8_bwd_impl(int64_t b, int64_t s_q, int64_t s_kv, int64_t h, in // (dS * K) * descale dS * descale K * scale dQ auto dQ = createScaleWithOffset( - After_dS_K_before_quan_dQ, // input tensor + After_dS_K_before_quan_dQ, // input tensor "scaledQ", // scale tensor + layout, // qkv layout CUDNN_DATA_FP8_E5M2, // output tensor type false, // output not virtual false, // scale is by value @@ -1671,7 +1675,7 @@ void fused_attn_fp8_bwd_impl(int64_t b, int64_t s_q, int64_t s_kv, int64_t h, in // (dS.T * Q) * descale dS auto After_dSTranspose_Q_before_dequan_Q = createScale( - After_dSTranspose_Q, // input tensor + After_dSTranspose_Q, // input tensor descaledSTensor, // scale tensor CUDNN_DATA_FLOAT, // output tensor type true, // output is virtual @@ -1681,7 +1685,7 @@ void fused_attn_fp8_bwd_impl(int64_t b, int64_t s_q, int64_t s_kv, int64_t h, in // (dS.T * Q) * descale dS * descale Q auto After_dSTranspose_Q_before_quan_dK = createScale( - After_dSTranspose_Q_before_dequan_Q, // input tensor + After_dSTranspose_Q_before_dequan_Q, // input tensor descaleQTensor, // scale tensor CUDNN_DATA_FLOAT, // output tensor type true, // output is virtual @@ -1691,8 +1695,9 @@ void fused_attn_fp8_bwd_impl(int64_t b, int64_t s_q, int64_t s_kv, int64_t h, in // (dS.T * Q) * 
descale dS * descale Q * scale dK auto dK = createScaleWithOffset( - After_dSTranspose_Q_before_quan_dK, // input tensor + After_dSTranspose_Q_before_quan_dK, // input tensor "scaledK", // scale tensor + layout, // qkv layout CUDNN_DATA_FP8_E5M2, // output tensor type false, // output not virtual false, // scale is by value @@ -1911,6 +1916,8 @@ void fused_attn_fp8_fwd_qkvpacked( devPtrM = output_M->data.dptr; devPtrZInv = output_ZInv->data.dptr; output_rng_state->data.dptr = rng_state->data.dptr; + } else { + NVTE_ERROR("Unexpected Aux_CTX_Tensors->size."); } void* devPtrAmaxS = input_output_S->amax.dptr; @@ -2048,5 +2055,204 @@ void fused_attn_fp8_bwd_qkvpacked( return; } } +// fused attention FWD FP8 with separate Q, K, V +void fused_attn_fp8_fwd( + size_t b, size_t max_seqlen_q, size_t max_seqlen_kv, + size_t h, size_t d, + bool is_training, float attn_scale, + float p_dropout, NVTE_QKV_Layout qkv_layout, + const Tensor *input_Q, + const Tensor *input_K, + const Tensor *input_V, + Tensor *input_output_S, + Tensor *output_O, + NVTETensorPack* Aux_CTX_Tensors, + const Tensor *cu_seqlens_q, + const Tensor *cu_seqlens_kv, + const Tensor *rng_state, + Tensor *workspace, + cudaStream_t stream, + cudnnHandle_t handle) { + using namespace transformer_engine; + void* devPtrQ = input_Q->data.dptr; + void* devPtrK = input_K->data.dptr; + void* devPtrV = input_V->data.dptr; + void* devPtrDescaleQ = input_Q->scale_inv.dptr; + void* devPtrDescaleK = input_Q->scale_inv.dptr; + void* devPtrDescaleV = input_Q->scale_inv.dptr; + + void* devPtrO = output_O->data.dptr; + void* devPtrAmaxO = output_O->amax.dptr; + void* devPtrScaleO = output_O->scale.dptr; + + void* devPtrM = nullptr; + void* devPtrZInv = nullptr; + if (Aux_CTX_Tensors->size == 0) { + if (is_training) { + Aux_CTX_Tensors->size = 3; + Tensor *output_M = reinterpret_cast(Aux_CTX_Tensors->tensors[0]); + Tensor *output_ZInv = reinterpret_cast(Aux_CTX_Tensors->tensors[1]); + Tensor *output_rng_state = reinterpret_cast(Aux_CTX_Tensors->tensors[2]); + output_M->data.dptr = nullptr; + output_M->data.shape = {b, h, max_seqlen_q, 1}; + output_M->data.dtype = DType::kFloat32; + output_ZInv->data.dptr = nullptr; + output_ZInv->data.shape = {b, h, max_seqlen_q, 1}; + output_ZInv->data.dtype = DType::kFloat32; + output_rng_state->data.dptr = nullptr; + output_rng_state->data.shape = {2}; + output_rng_state->data.dtype = DType::kInt64; + } + } else if (Aux_CTX_Tensors->size == 3) { + Tensor *output_M = reinterpret_cast(Aux_CTX_Tensors->tensors[0]); + Tensor *output_ZInv = reinterpret_cast(Aux_CTX_Tensors->tensors[1]); + Tensor *output_rng_state = reinterpret_cast(Aux_CTX_Tensors->tensors[2]); + devPtrM = output_M->data.dptr; + devPtrZInv = output_ZInv->data.dptr; + output_rng_state->data.dptr = rng_state->data.dptr; + } else { + NVTE_ERROR("Unexpected Aux_CTX_Tensors->size."); + } + + void* devPtrAmaxS = input_output_S->amax.dptr; + void* devPtrScaleS = input_output_S->scale.dptr; + void* devPtrDescaleS = input_output_S->scale_inv.dptr; + + void* devPtrcuSeqlensQ = reinterpret_cast( + reinterpret_cast(cu_seqlens_q->data.dptr)); + void* devPtrcuSeqlensKV = reinterpret_cast( + reinterpret_cast(cu_seqlens_kv->data.dptr)); + void* devPtrDropoutSeed = reinterpret_cast( + reinterpret_cast(rng_state->data.dptr)); + void* devPtrDropoutOffset = reinterpret_cast( + reinterpret_cast(rng_state->data.dptr) + 1); + + const DType QKV_type = input_Q->data.dtype; + size_t workspace_size = 0; + + fused_attn::fused_attn_fp8_fwd_impl( + b, max_seqlen_q, max_seqlen_kv, h, d, 
+ is_training, attn_scale, p_dropout, qkv_layout, + devPtrQ, devPtrK, devPtrV, + devPtrM, devPtrZInv, + devPtrO, + devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV, + devPtrDescaleS, devPtrScaleS, devPtrScaleO, + devPtrAmaxO, devPtrAmaxS, + devPtrcuSeqlensQ, devPtrcuSeqlensKV, + devPtrDropoutSeed, devPtrDropoutOffset, + get_cudnn_dtype(QKV_type), + workspace->data.dptr, &workspace_size, stream, handle); + + if (workspace_size > 0) { + if (workspace->data.dptr == nullptr) { + workspace->data.shape = { workspace_size }; + workspace->data.dtype = DType::kByte; + return; + } + } else if (workspace_size == 0) { + workspace->data.shape = { 1 }; + workspace->data.dtype = DType::kByte; + return; + } +} +// fused attention BWD FP8 with separate Q, K, V +void fused_attn_fp8_bwd( + size_t b, size_t max_seqlen_q, size_t max_seqlen_kv, + size_t h, size_t d, + float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, + const Tensor *input_Q, + const Tensor *input_K, + const Tensor *input_V, + const Tensor *input_O, + const Tensor *input_dO, + const Tensor *input_M, + const Tensor *input_ZInv, + const Tensor *input_S, + Tensor *input_output_dP, + const Tensor *output_dQ, + const Tensor *output_dK, + const Tensor *output_dV, + const Tensor *cu_seqlens_q, + const Tensor *cu_seqlens_kv, + const Tensor *rng_state, + Tensor *workspace, + cudaStream_t stream, + cudnnHandle_t handle) { + using namespace transformer_engine; + void* devPtrQ = input_Q->data.dptr; + void* devPtrK = input_K->data.dptr; + void* devPtrV = input_V->data.dptr; + void* devPtrDescaleQ = input_Q->scale_inv.dptr; + void* devPtrDescaleK = input_Q->scale_inv.dptr; + void* devPtrDescaleV = input_Q->scale_inv.dptr; + + void* devPtrO = input_O->data.dptr; + void* devPtrDescaleO = input_O->scale_inv.dptr; + void* devPtrdO = input_dO->data.dptr; + void* devPtrDescaledO = input_dO->scale_inv.dptr; + + void* devPtrM = input_M->data.dptr; + void* devPtrZInv = input_ZInv->data.dptr; + + void* devPtrScaleS = input_S->scale.dptr; + void* devPtrDescaleS = input_S->scale_inv.dptr; + void* devPtrAmaxdS = input_output_dP->amax.dptr; + void* devPtrScaledS = input_output_dP->scale.dptr; + void* devPtrDescaledS = input_output_dP->scale_inv.dptr; + + void* devPtrdQ = output_dQ->data.dptr; + void* devPtrdK = output_dK->data.dptr; + void* devPtrdV = output_dV->data.dptr; + void* devPtrAmaxdQ = output_dQ->amax.dptr; + void* devPtrAmaxdK = output_dQ->amax.dptr; + void* devPtrAmaxdV = output_dQ->amax.dptr; + void* devPtrScaledQ = output_dQ->scale.dptr; + void* devPtrScaledK = output_dQ->scale.dptr; + void* devPtrScaledV = output_dQ->scale.dptr; + + void* devPtrcuSeqlensQ = reinterpret_cast( + reinterpret_cast(cu_seqlens_q->data.dptr)); + void* devPtrcuSeqlensKV = reinterpret_cast( + reinterpret_cast(cu_seqlens_kv->data.dptr)); + void* devPtrDropoutSeed = reinterpret_cast( + reinterpret_cast(rng_state->data.dptr)); + void* devPtrDropoutOffset = reinterpret_cast( + reinterpret_cast(rng_state->data.dptr) + 1); + + const DType QKV_type = input_Q->data.dtype; + size_t workspace_size = 0; + + fused_attn::fused_attn_fp8_bwd_impl( + b, max_seqlen_q, max_seqlen_kv, h, d, + attn_scale, p_dropout, qkv_layout, + devPtrQ, devPtrK, devPtrV, + devPtrM, devPtrZInv, + devPtrO, devPtrdO, + devPtrdQ, devPtrdK, devPtrdV, + devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV, + devPtrDescaleO, devPtrDescaledO, + devPtrDescaleS, devPtrDescaledS, + devPtrScaleS, devPtrScaledS, + devPtrScaledQ, devPtrScaledK, devPtrScaledV, + devPtrAmaxdS, + devPtrAmaxdQ, devPtrAmaxdK, devPtrAmaxdV, + 
devPtrcuSeqlensQ, devPtrcuSeqlensKV, + devPtrDropoutSeed, devPtrDropoutOffset, + get_cudnn_dtype(QKV_type), + workspace->data.dptr, &workspace_size, stream, handle); + + if (workspace_size > 0) { + if (workspace->data.dptr == nullptr) { + workspace->data.shape = { workspace_size }; + workspace->data.dtype = DType::kByte; + return; + } + } else if (workspace_size == 0) { + workspace->data.shape = { 1 }; + workspace->data.dtype = DType::kByte; + return; + } +} #endif // end of CUDNN>=8900 } // namespace transformer_engine diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.h b/transformer_engine/common/fused_attn/fused_attn_fp8.h index 111dfddd10..d78f0f97ca 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.h +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.h @@ -46,5 +46,44 @@ void fused_attn_fp8_bwd_qkvpacked( Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); + +// fused attention FWD FP8 with separate Q, K, V +void fused_attn_fp8_fwd( + size_t b, size_t max_seqlen_q, size_t max_seqlen_kv, + size_t h, size_t d, + bool is_training, float attn_scale, + float p_dropout, NVTE_QKV_Layout qkv_layout, + const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, + Tensor *input_output_S, + Tensor *output_O, + NVTETensorPack* Aux_CTX_Tensors, + const Tensor *cu_seqlens_q, + const Tensor *cu_seqlens_kv, + const Tensor *rng_state, + Tensor *workspace, + cudaStream_t stream, + cudnnHandle_t handle); + +// fused attention BWD FP8 with separate Q, K, V +void fused_attn_fp8_bwd( + size_t b, size_t max_seqlen_q, size_t max_seqlen_kv, + size_t h, size_t d, + float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, + const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, + const Tensor *input_O, + const Tensor *input_dO, + const Tensor *input_M, + const Tensor *input_ZInv, + const Tensor *input_S, + Tensor *input_output_dP, + const Tensor *output_dQ, + const Tensor *output_dK, + const Tensor *output_dV, + const Tensor *cu_seqlens_q, + const Tensor *cu_seqlens_kv, + const Tensor *rng_state, + Tensor *workspace, + cudaStream_t stream, + cudnnHandle_t handle); #endif // end of CUDNN>=8900 } // namespace transformer_engine diff --git a/transformer_engine/common/fused_attn/utils.cu b/transformer_engine/common/fused_attn/utils.cu index ebba6efa21..fc4be20cf6 100644 --- a/transformer_engine/common/fused_attn/utils.cu +++ b/transformer_engine/common/fused_attn/utils.cu @@ -30,6 +30,7 @@ void generateMatrixStrides( constexpr int seqlen_q_dim_idx = 2; constexpr int seqlen_kv_dim_idx = 3; + // to be deprecated in the future switch (matrix) { case NVTE_QKV_Matrix::NVTE_Q_Matrix: if (layout == NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED) { @@ -37,7 +38,8 @@ void generateMatrixStrides( strideA[seqlen_dim_idx] = 3 * h * d; strideA[head_dim_idx] = d; strideA[batch_dim_idx] = s_q * 3 * h * d; - } else { + } else if ((layout == NVTE_QKV_Layout::NVTE_KV_INTERLEAVED) + || (layout == NVTE_QKV_Layout::NVTE_NOT_INTERLEAVED)) { strideA[hidden_dim_idx] = 1; strideA[seqlen_dim_idx] = h * d; strideA[head_dim_idx] = d; @@ -55,7 +57,7 @@ void generateMatrixStrides( strideA[hidden_dim_idx] = 1; strideA[head_dim_idx] = d; strideA[batch_dim_idx] = s_kv * 2 * h * d; - } else { + } else if (layout == NVTE_QKV_Layout::NVTE_NOT_INTERLEAVED) { strideA[seqlen_dim_idx] = h * d; strideA[hidden_dim_idx] = 1; strideA[head_dim_idx] = d; @@ -73,7 +75,7 @@ void generateMatrixStrides( strideA[hidden_transpose_dim_idx] = 1; strideA[head_dim_idx] = d; strideA[batch_dim_idx] = s_kv 
* 2 * h * d; - } else { + } else if (layout == NVTE_QKV_Layout::NVTE_NOT_INTERLEAVED) { strideA[seqlen_transpose_dim_idx] = h * d; strideA[hidden_transpose_dim_idx] = 1; strideA[head_dim_idx] = d; @@ -91,7 +93,7 @@ void generateMatrixStrides( strideA[seqlen_dim_idx] = 2* h * d; strideA[head_dim_idx] = d; strideA[batch_dim_idx] = s_kv * 2 * h * d; - } else { + } else if (layout == NVTE_QKV_Layout::NVTE_NOT_INTERLEAVED) { strideA[hidden_dim_idx] = 1; strideA[seqlen_dim_idx] = h * d; strideA[head_dim_idx] = d; @@ -100,21 +102,21 @@ void generateMatrixStrides( break; case NVTE_QKV_Matrix::NVTE_V_Matrix_Transpose: if (layout == NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED) { - strideA[hidden_transpose_dim_idx] = 1; - strideA[seqlen_transpose_dim_idx] = 3 * h * d; - strideA[head_dim_idx] = d; - strideA[batch_dim_idx] = s_kv * 3 * h * d; - } else if (layout == NVTE_QKV_Layout::NVTE_KV_INTERLEAVED) { - strideA[hidden_transpose_dim_idx] = 1; - strideA[seqlen_transpose_dim_idx] = 2* h * d; - strideA[head_dim_idx] = d; - strideA[batch_dim_idx] = s_kv * 2 * h * d; - } else { - strideA[hidden_transpose_dim_idx] = 1; - strideA[seqlen_transpose_dim_idx] = h * d; - strideA[head_dim_idx] = d; - strideA[batch_dim_idx] = s_kv * h * d; - } + strideA[hidden_transpose_dim_idx] = 1; + strideA[seqlen_transpose_dim_idx] = 3 * h * d; + strideA[head_dim_idx] = d; + strideA[batch_dim_idx] = s_kv * 3 * h * d; + } else if (layout == NVTE_QKV_Layout::NVTE_KV_INTERLEAVED) { + strideA[hidden_transpose_dim_idx] = 1; + strideA[seqlen_transpose_dim_idx] = 2* h * d; + strideA[head_dim_idx] = d; + strideA[batch_dim_idx] = s_kv * 2 * h * d; + } else if (layout == NVTE_QKV_Layout::NVTE_NOT_INTERLEAVED) { + strideA[hidden_transpose_dim_idx] = 1; + strideA[seqlen_transpose_dim_idx] = h * d; + strideA[head_dim_idx] = d; + strideA[batch_dim_idx] = s_kv * h * d; + } break; case NVTE_QKV_Matrix::NVTE_S_Matrix: strideA[seqlen_kv_dim_idx] = 1; @@ -129,6 +131,228 @@ void generateMatrixStrides( strideA[batch_dim_idx] = s_q * h * d; break; } + + // new way of getting strides + switch (layout) { + case NVTE_QKV_Layout::NVTE_SB3HD: + if ((matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix) + || (matrix == NVTE_QKV_Matrix::NVTE_K_Matrix) + || (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix)) { + strideA[batch_dim_idx] = 3 * h * d; + strideA[head_dim_idx] = d; + strideA[seqlen_dim_idx] = b * 3 * h * d; + strideA[hidden_dim_idx] = 1; + } else if ((matrix == NVTE_QKV_Matrix::NVTE_K_Matrix_Transpose) + || (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix_Transpose)) { + strideA[batch_dim_idx] = 3 * h * d; + strideA[head_dim_idx] = d; + strideA[seqlen_transpose_dim_idx] = b * 3 * h * d; + strideA[hidden_transpose_dim_idx] = 1; + } else if (matrix == NVTE_QKV_Matrix::NVTE_O_Matrix) { + strideA[batch_dim_idx] = h * d; + strideA[head_dim_idx] = d; + strideA[seqlen_dim_idx] = b * h * d; + strideA[hidden_dim_idx] = 1; + } + break; + case NVTE_QKV_Layout::NVTE_SBH3D: + if ((matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix) + || (matrix == NVTE_QKV_Matrix::NVTE_K_Matrix) + || (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix)) { + strideA[batch_dim_idx] = 3 * h * d; + strideA[head_dim_idx] = 3 * d; + strideA[seqlen_dim_idx] = b * 3 * h * d; + strideA[hidden_dim_idx] = 1; + } else if ((matrix == NVTE_QKV_Matrix::NVTE_K_Matrix_Transpose) + || (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix_Transpose)) { + strideA[batch_dim_idx] = 3 * h * d; + strideA[head_dim_idx] = 3 * d; + strideA[seqlen_transpose_dim_idx] = b * 3 * h * d; + strideA[hidden_transpose_dim_idx] = 1; + } else if (matrix == 
NVTE_QKV_Matrix::NVTE_O_Matrix) { + strideA[batch_dim_idx] = h * d; + strideA[head_dim_idx] = d; + strideA[seqlen_dim_idx] = b * h * d; + strideA[hidden_dim_idx] = 1; + } + break; + case NVTE_QKV_Layout::NVTE_SBHD_SB2HD: + if ((matrix == NVTE_QKV_Matrix::NVTE_K_Matrix) + || (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix)) { + strideA[batch_dim_idx] = 2 * h * d; + strideA[head_dim_idx] = d; + strideA[seqlen_dim_idx] = b * 2 * h * d; + strideA[hidden_dim_idx] = 1; + } else if ((matrix == NVTE_QKV_Matrix::NVTE_K_Matrix_Transpose) + || (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix_Transpose)) { + strideA[batch_dim_idx] = 2 * h * d; + strideA[head_dim_idx] = d; + strideA[seqlen_transpose_dim_idx] = b * 2 * h * d; + strideA[hidden_transpose_dim_idx] = 1; + } else if ((matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix) + || (matrix == NVTE_QKV_Matrix::NVTE_O_Matrix)) { + strideA[batch_dim_idx] = h * d; + strideA[head_dim_idx] = d; + strideA[seqlen_dim_idx] = b * h * d; + strideA[hidden_dim_idx] = 1; + } + break; + case NVTE_QKV_Layout::NVTE_SBHD_SBH2D: + if ((matrix == NVTE_QKV_Matrix::NVTE_K_Matrix) + || (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix)) { + strideA[batch_dim_idx] = 2 * h * d; + strideA[head_dim_idx] = 2 * d; + strideA[seqlen_dim_idx] = b * 2 * h * d; + strideA[hidden_dim_idx] = 1; + } else if ((matrix == NVTE_QKV_Matrix::NVTE_K_Matrix_Transpose) + || (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix_Transpose)) { + strideA[batch_dim_idx] = 2 * h * d; + strideA[head_dim_idx] = 2 * d; + strideA[seqlen_transpose_dim_idx] = b * 2 * h * d; + strideA[hidden_transpose_dim_idx] = 1; + } else if ((matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix) + || (matrix == NVTE_QKV_Matrix::NVTE_O_Matrix)) { + strideA[batch_dim_idx] = h * d; + strideA[head_dim_idx] = d; + strideA[seqlen_dim_idx] = b * h * d; + strideA[hidden_dim_idx] = 1; + } + break; + case NVTE_QKV_Layout::NVTE_SBHD_SBHD_SBHD: + if ((matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix) + || (matrix == NVTE_QKV_Matrix::NVTE_K_Matrix) + || (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix) + || (matrix == NVTE_QKV_Matrix::NVTE_O_Matrix)) { + strideA[batch_dim_idx] = h * d; + strideA[head_dim_idx] = d; + strideA[seqlen_dim_idx] = b * h * d; + strideA[hidden_dim_idx] = 1; + } else if ((matrix == NVTE_QKV_Matrix::NVTE_K_Matrix_Transpose) + || (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix_Transpose)) { + strideA[batch_dim_idx] = h * d; + strideA[head_dim_idx] = d; + strideA[seqlen_transpose_dim_idx] = b * h * d; + strideA[hidden_transpose_dim_idx] = 1; + } + break; + case NVTE_QKV_Layout::NVTE_BS3HD: + case NVTE_QKV_Layout::NVTE_T3HD: + if ((matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix) + || (matrix == NVTE_QKV_Matrix::NVTE_K_Matrix) + || (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix)) { + strideA[batch_dim_idx] = s_q * 3 * h * d; + strideA[head_dim_idx] = d; + strideA[seqlen_dim_idx] = 3 * h * d; + strideA[hidden_dim_idx] = 1; + } else if ((matrix == NVTE_QKV_Matrix::NVTE_K_Matrix_Transpose) + || (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix_Transpose)) { + strideA[batch_dim_idx] = s_q * 3 * h * d; + strideA[head_dim_idx] = d; + strideA[seqlen_transpose_dim_idx] = 3 * h * d; + strideA[hidden_transpose_dim_idx] = 1; + } else if (matrix == NVTE_QKV_Matrix::NVTE_O_Matrix) { + strideA[batch_dim_idx] = s_q * h * d; + strideA[head_dim_idx] = d; + strideA[seqlen_dim_idx] = h * d; + strideA[hidden_dim_idx] = 1; + } + break; + case NVTE_QKV_Layout::NVTE_BSH3D: + case NVTE_QKV_Layout::NVTE_TH3D: + if ((matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix) + || (matrix == NVTE_QKV_Matrix::NVTE_K_Matrix) + || (matrix == 
NVTE_QKV_Matrix::NVTE_V_Matrix)) { + strideA[batch_dim_idx] = s_q * 3 * h * d; + strideA[head_dim_idx] = 3 * d; + strideA[seqlen_dim_idx] = 3 * h * d; + strideA[hidden_dim_idx] = 1; + } else if ((matrix == NVTE_QKV_Matrix::NVTE_K_Matrix_Transpose) + || (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix_Transpose)) { + strideA[batch_dim_idx] = s_q * 3 * h * d; + strideA[head_dim_idx] = 3 * d; + strideA[seqlen_transpose_dim_idx] = 3 * h * d; + strideA[hidden_transpose_dim_idx] = 1; + } else if (matrix == NVTE_QKV_Matrix::NVTE_O_Matrix) { + strideA[batch_dim_idx] = s_q * h * d; + strideA[head_dim_idx] = d; + strideA[seqlen_dim_idx] = h * d; + strideA[hidden_dim_idx] = 1; + } + break; + case NVTE_QKV_Layout::NVTE_BSHD_BS2HD: + case NVTE_QKV_Layout::NVTE_THD_T2HD: + if ((matrix == NVTE_QKV_Matrix::NVTE_K_Matrix) + || (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix)) { + strideA[batch_dim_idx] = s_kv * 2 * h * d; + strideA[head_dim_idx] = d; + strideA[seqlen_dim_idx] = 2 * h * d; + strideA[hidden_dim_idx] = 1; + } else if ((matrix == NVTE_QKV_Matrix::NVTE_K_Matrix_Transpose) + || (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix_Transpose)) { + strideA[batch_dim_idx] = s_kv * 2 * h * d; + strideA[head_dim_idx] = d; + strideA[seqlen_transpose_dim_idx] = 2 * h * d; + strideA[hidden_transpose_dim_idx] = 1; + } else if ((matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix) + || (matrix == NVTE_QKV_Matrix::NVTE_O_Matrix)) { + strideA[batch_dim_idx] = s_q * h * d; + strideA[head_dim_idx] = d; + strideA[seqlen_dim_idx] = h * d; + strideA[hidden_dim_idx] = 1; + } + break; + case NVTE_QKV_Layout::NVTE_BSHD_BSH2D: + case NVTE_QKV_Layout::NVTE_THD_TH2D: + if ((matrix == NVTE_QKV_Matrix::NVTE_K_Matrix) + || (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix)) { + strideA[batch_dim_idx] = s_kv * 2 * h * d; + strideA[head_dim_idx] = 2 * d; + strideA[seqlen_dim_idx] = 2 * h * d; + strideA[hidden_dim_idx] = 1; + } else if ((matrix == NVTE_QKV_Matrix::NVTE_K_Matrix_Transpose) + || (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix_Transpose)) { + strideA[batch_dim_idx] = s_kv * 2 * h * d; + strideA[head_dim_idx] = 2 * d; + strideA[seqlen_transpose_dim_idx] = 2 * h * d; + strideA[hidden_transpose_dim_idx] = 1; + } else if ((matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix) + || (matrix == NVTE_QKV_Matrix::NVTE_O_Matrix)) { + strideA[batch_dim_idx] = s_q * h * d; + strideA[head_dim_idx] = d; + strideA[seqlen_dim_idx] = h * d; + strideA[hidden_dim_idx] = 1; + } + break; + case NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD: + case NVTE_QKV_Layout::NVTE_THD_THD_THD: + if ((matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix) + || (matrix == NVTE_QKV_Matrix::NVTE_O_Matrix)) { + strideA[batch_dim_idx] = s_q * h * d; + strideA[head_dim_idx] = d; + strideA[seqlen_dim_idx] = h * d; + strideA[hidden_dim_idx] = 1; + } else if ((matrix == NVTE_QKV_Matrix::NVTE_K_Matrix) + || (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix)) { + strideA[batch_dim_idx] = s_kv * h * d; + strideA[head_dim_idx] = d; + strideA[seqlen_dim_idx] = h * d; + strideA[hidden_dim_idx] = 1; + } else if ((matrix == NVTE_QKV_Matrix::NVTE_K_Matrix_Transpose) + || (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix_Transpose)) { + strideA[batch_dim_idx] = s_kv * h * d; + strideA[head_dim_idx] = d; + strideA[seqlen_transpose_dim_idx] = h * d; + strideA[hidden_transpose_dim_idx] = 1; + } + break; + } + + if (matrix == NVTE_QKV_Matrix::NVTE_S_Matrix) { + strideA[seqlen_kv_dim_idx] = 1; + strideA[seqlen_q_dim_idx] = s_kv; + strideA[head_dim_idx] = s_q * s_kv; + strideA[batch_dim_idx] = h * s_q * s_kv; + } } bool allowAllConfig(cudnnBackendDescriptor_t 
engine_config) { diff --git a/transformer_engine/common/include/transformer_engine/fused_attn.h b/transformer_engine/common/include/transformer_engine/fused_attn.h index b71573ec1b..6de3c63512 100644 --- a/transformer_engine/common/include/transformer_engine/fused_attn.h +++ b/transformer_engine/common/include/transformer_engine/fused_attn.h @@ -18,7 +18,17 @@ extern "C" { #endif /*! \enum NVTE_QKV_Layout - * \brief QKV matrix layouts + * \brief Memory layouts of QKV tensors + * `S`, `B`, `H`, `D`, and `T` stand for sequence length, batch size, the number of heads, + head size, and the total number of sequences in a batch, i.e. `t = sum(s_i) for i = 0...b-1`. + `SBHD` and `BSHD`-based layouts are used when sequences in a batch are of equal length + or padded to the same length, and `THD`-based layouts are used when sequences have + different lengths in a batch. + * \note {`NVTE_QKV_INTERLEAVED`, `NVTE_KV_INTERLEAVED` and `NVTE_NOT_INTERLEAVED` + will be deprecated in the next release. Please use their equivalent enums instead, i.e. `NVTE_T3HD`, + `NVTE_THD_T2HD` and `NVTE_THD_THD_THD` when sequences are of variable lengths, and `NVTE_BS3HD`, + `NVTE_BSHD_BS2HD` and `NVTE_BSHD_BSHD_BSHD` when sequences are of equal length or padded + to equal length.} */ enum NVTE_QKV_Layout { /*! Separate Q, K, V tensors. @@ -67,7 +77,51 @@ enum NVTE_QKV_Layout { | num_heads * head_dim \endverbatim */ - NVTE_KV_INTERLEAVED = 2 + NVTE_KV_INTERLEAVED = 2, + + NVTE_SB3HD = 3, + NVTE_SBH3D = 4, + NVTE_SBHD_SB2HD = 5, + NVTE_SBHD_SBH2D = 6, + NVTE_SBHD_SBHD_SBHD = 7, + NVTE_BS3HD = 8, + NVTE_BSH3D = 9, + NVTE_BSHD_BS2HD = 10, + NVTE_BSHD_BSH2D = 11, + NVTE_BSHD_BSHD_BSHD = 12, + NVTE_T3HD = 13, + NVTE_TH3D = 14, + NVTE_THD_T2HD = 15, + NVTE_THD_TH2D = 16, + NVTE_THD_THD_THD = 17, +}; + +/*! \enum NVTE_QKV_Layout_Group + * \brief Grouping of QKV layouts + */ +enum NVTE_QKV_Layout_Group { + /*! 3HD QKV layouts, e.g. BS3HD */ + NVTE_3HD = 0, + /*! H3D QKV layouts, e.g. BSH3D */ + NVTE_H3D = 1, + /*! HD_2HD QKV layouts, e.g. BSHD_BS2HD */ + NVTE_HD_2HD = 2, + /*! HD_H2D QKV layouts, e.g. BSHD_BSH2D */ + NVTE_HD_H2D = 3, + /*! HD_HD_HD QKV layouts, e.g. BSHD_BSHD_BSHD */ + NVTE_HD_HD_HD = 4, +}; + +/*! \enum NVTE_QKV_Format + * \brief Dimension formats for QKV tensors + */ +enum NVTE_QKV_Format { + /*! SBHD QKV format */ + NVTE_SBHD = 0, + /*! BSHD QKV format */ + NVTE_BSHD = 1, + /*! THD QKV format */ + NVTE_THD = 2, }; /*! \enum NVTE_Bias_Type @@ -94,6 +148,9 @@ enum NVTE_Mask_Type { NVTE_CAUSAL_MASK = 2, }; +/*! \enum NVTE_Fused_Attn_Backend + * \brief Fused attention backends + */ enum NVTE_Fused_Attn_Backend { /*! No supported backend */ NVTE_No_Backend = -1, @@ -105,8 +162,24 @@ enum NVTE_Fused_Attn_Backend { NVTE_FP8 = 2, }; +/*! \brief Get layout group for a given QKV layout + * + * \param[in] qkv_layout QKV layout, e.g. sbh3d. + * + * \return qkv layout group, e.g. h3d. + */ +NVTE_QKV_Layout_Group nvte_get_qkv_layout_group(NVTE_QKV_Layout qkv_layout); + +/*! \brief Get QKV format for a given QKV layout + * + * \param[in] qkv_layout QKV layout, e.g. sbh3d. + * + * \return qkv format, e.g. sbhd. + */ +NVTE_QKV_Format nvte_get_qkv_format(NVTE_QKV_Layout qkv_layout); + /*! \brief Get fused attention backend based on input parameters. - * + * * \param[in] q_dtype The data type of Tensor Q. * \param[in] kv_dtype The data type of Tensors K, V. * \param[in] qkv_layout The layout of Tensors Q, K, V. 
@@ -152,7 +225,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( * \param[in] cu_seqlens Accumulative sequence lengths, [batch_size + 1]. * \param[in] rng_state Seed and offset of CUDA random number generator. * \param[in] max_seqlen Max sequence length used for computing, - * it may be >= max(cu_seqlens). + * it may be >= max(seqlen_i) for i=0,...batch_size-1. * \param[in] is_training Whether this is in training mode or inference. * \param[in] attn_scale Scaling factor for Q * K.T. * \param[in] dropout Dropout probability. @@ -199,7 +272,7 @@ void nvte_fused_attn_fwd_qkvpacked( * \param[out] dBias The gradient of the Bias tensor. * \param[in] cu_seqlens Accumulative sequence lengths, [batch_size + 1]. * \param[in] max_seqlen Max sequence length used for computing, - * it may be >= max(cu_seqlens). + * it may be >= max(seqlen_i) for i=0,...batch_size-1. * \param[in] attn_scale Scaling factor for Q * K.T. * \param[in] dropout Dropout probability. * \param[in] qkv_layout QKV tensor's layout. @@ -249,10 +322,10 @@ void nvte_fused_attn_bwd_qkvpacked( * \param[in] cu_seqlens_q Accumulative sequence lengths for Q, [batch_size + 1]. * \param[in] cu_seqlens_kv Accumulative sequence lengths for KV, [batch_size + 1]. * \param[in] rng_state Seed and offset of CUDA random number generator. - * \param[in] max_seqlen_q Max sequence length used for computing for Q. - * it may be >= max(cu_seqlens_q). - * \param[in] max_seqlen_kv Max sequence length used for computing for KV. - * it may be >= max(cu_seqlens_kv). + * \param[in] max_seqlen_q Max sequence length used for computing for Q. + * it may be >= max(seqlen_q_i) for i=0,...batch_size-1. + * \param[in] max_seqlen_kv Max sequence length used for computing for KV. + * it may be >= max(seqlen_kv_i) for i=0,...batch_size-1. * \param[in] is_training Whether this is in training mode or inference. * \param[in] attn_scale Scaling factor for Q * K.T. * \param[in] dropout Dropout probability. @@ -300,10 +373,10 @@ void nvte_fused_attn_fwd_kvpacked( * \param[out] dBias The gradient of the Bias tensor. * \param[in] cu_seqlens_q Accumulative sequence lengths for Q, [batch_size + 1]. * \param[in] cu_seqlens_kv Accumulative sequence lengths for KV, [batch_size + 1]. - * \param[in] max_seqlen_q Max sequence length used for computing for Q. - * it may be >= max(cu_seqlens_q). - * \param[in] max_seqlen_kv Max sequence length used for computing for KV. - * it may be >= max(cu_seqlens_kv). + * \param[in] max_seqlen_q Max sequence length used for computing for Q. + * it may be >= max(seqlen_q_i) for i=0,...batch_size-1. + * \param[in] max_seqlen_kv Max sequence length used for computing for KV. + * it may be >= max(seqlen_kv_i) for i=0,...batch_size-1. * \param[in] attn_scale Scaling factor for Q * K.T. * \param[in] dropout Dropout probability. * \param[in] qkv_layout QKV tensor's layout. @@ -332,6 +405,122 @@ void nvte_fused_attn_bwd_kvpacked( NVTETensor workspace, cudaStream_t stream); +/*! \brief Compute dot product attention with separate Q, K and V. 
+ * + * Computes: + * - P = Q * Transpose(K) + Bias + * - S = ScaleMaskSoftmax(P) + * - D = Dropout(S) + * - O = D * Transpose(V) + * + * Support Matrix: + \verbatim + | backend | precision | qkv format | bias | mask | dropout | sequence length | head_dim | + | 0 | FP16/BF16 | SBHD, BSHD | NO/POST_SCALE_BIAS | PADDING/CAUSAL_MASK | Yes | <= 512 | 64 | + | 1 | FP16/BF16 | SBHD, BSHD | NO/POST_SCALE_BIAS | CAUSAL_MASK | Yes | > 512 | 64, 128 | + | 2 | FP8 | THD | NO_BIAS | PADDING_MASK | Yes | <= 512 | 64 | + \endverbatim + * + * \param[in] Q The Q tensor. + * \param[in] K The K tensor. + * \param[in] V The V tensor. + * \param[in] Bias The Bias tensor. + * \param[in,out] S The S tensor. + * \param[out] O The output O tensor. + * \param[out] Aux_CTX_Tensors Auxiliary output tensors when training, + * e.g. M, ZInv, rng_state. + * \param[in] cu_seqlens_q Cumulative sequence lengths for Q, [batch_size + 1]. + * \param[in] cu_seqlens_kv Cumulative sequence lengths for K and V, [batch_size + 1]. + * \param[in] rng_state Seed and offset of CUDA random number generator. + * \param[in] max_seqlen_q Max sequence length used for computing for Q. + * it may be >= max(seqlen_q_i) for i=0,...batch_size-1. + * \param[in] max_seqlen_kv Max sequence length used for computing for K and V. + * it may be >= max(seqlen_kv_i) for i=0,...batch_size-1. + * \param[in] is_training Whether this is in training mode or inference. + * \param[in] attn_scale Scaling factor for Q * K.T. + * \param[in] dropout Dropout probability. + * \param[in] qkv_layout QKV tensors' layout. + * \param[in] bias_type Bias type. + * \param[in] attn_mask_type Attention mask type. + * \param[in] workspace Workspace tensor. + * \param[in] stream CUDA stream used for this operation. + */ +void nvte_fused_attn_fwd( + const NVTETensor Q, + const NVTETensor K, + const NVTETensor V, + const NVTETensor Bias, + NVTETensor S, + NVTETensor O, + NVTETensorPack* Aux_CTX_Tensors, + const NVTETensor cu_seqlens_q, + const NVTETensor cu_seqlens_kv, + const NVTETensor rng_state, + size_t max_seqlen_q, size_t max_seqlen_kv, + bool is_training, float attn_scale, float dropout, + NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, + NVTE_Mask_Type attn_mask_type, + NVTETensor workspace, + cudaStream_t stream); + +/*! \brief Compute the backward of the dot product attention with separate Q, K and V. + * + * Support Matrix: + \verbatim + | backend | precision | qkv format | bias | mask | dropout | sequence length | head_dim | + | 0 | FP16/BF16 | SBHD, BSHD | NO/POST_SCALE_BIAS | PADDING/CAUSAL_MASK | Yes | <= 512 | 64 | + | 1 | FP16/BF16 | SBHD, BSHD | NO/POST_SCALE_BIAS | CAUSAL_MASK | Yes | > 512 | 64, 128 | + | 2 | FP8 | THD | NO_BIAS | PADDING_MASK | Yes | <= 512 | 64 | + \endverbatim + * + * \param[in] Q The Q tensor. + * \param[in] K The K tensor. + * \param[in] V The V tensor. + * \param[in] O The O tensor from forward. + * \param[in] dO The gradient of the O tensor. + * \param[in] S The S tensor. + * \param[in,out] dP The gradient of the P tensor. + * \param[in] Aux_CTX_Tensors Auxiliary tensors from context when in training mode, + * e.g. M, ZInv, rng_state. + * \param[out] dQ The gradient of the Q tensor. + * \param[out] dK The gradient of the K tensor. + * \param[out] dV The gradient of the V tensor. + * \param[out] dBias The gradient of the Bias tensor. + * \param[in] cu_seqlens_q Cumulative sequence lengths for Q, [batch_size + 1]. + * \param[in] cu_seqlens_kv Cumulative sequence lengths for K and V, [batch_size + 1]. 
+ * \param[in] max_seqlen_q Max sequence length used for computing for Q. + * it may be >= max(seqlen_q_i) for i=0,...batch_size-1. + * \param[in] max_seqlen_kv Max sequence length used for computing for K and V. + * it may be >= max(seqlen_kv_i) for i=0,...batch_size-1. + * \param[in] attn_scale Scaling factor for Q * K.T. + * \param[in] dropout Dropout probability. + * \param[in] qkv_layout QKV tensors' layout. + * \param[in] bias_type Bias type. + * \param[in] attn_mask_type Attention mask type. + * \param[in] workspace Workspace tensor. + * \param[in] stream CUDA stream used for this operation. + */ +void nvte_fused_attn_bwd( + const NVTETensor Q, + const NVTETensor K, + const NVTETensor V, + const NVTETensor O, + const NVTETensor dO, + const NVTETensor S, + NVTETensor dP, + const NVTETensorPack* Aux_CTX_Tensors, + NVTETensor dQ, + NVTETensor dK, + NVTETensor dV, + NVTETensor dBias, + const NVTETensor cu_seqlens_q, + const NVTETensor cu_seqlens_kv, + size_t max_seqlen_q, size_t max_seqlen_kv, + float attn_scale, float dropout, + NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, + NVTE_Mask_Type attn_mask_type, + NVTETensor workspace, + cudaStream_t stream); #ifdef __cplusplus } // extern "C" #endif diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index bcf5584f3d..625cd8644e 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -20,6 +20,8 @@ fused_attn_bwd_qkvpacked, fused_attn_fwd_kvpacked, fused_attn_bwd_kvpacked, + fused_attn_fwd, + fused_attn_bwd, QKVLayout, AttnBiasType, AttnMaskType, @@ -37,6 +39,7 @@ AttnMaskTypes, AttnTypes, AttnBiasTypes, + QKVLayouts, dist_group_type, TE_DType, ) @@ -141,64 +144,6 @@ def backward(ctx, return torch.cat(grad_outputs, dim = split_dim), None, None -class _CombineQKV(torch.autograd.Function): - """""" - - @staticmethod - def forward(ctx, - query_layer: torch.Tensor, - key_layer: torch.Tensor, # pylint: disable=unused-argument - value_layer: torch.Tensor, # pylint: disable=unused-argument - dim: int, - ) -> torch.Tensor: - - mixed_layer = torch.Tensor().to(device=query_layer.device, - dtype=query_layer.dtype) - new_shape = list(query_layer.shape) - new_shape[dim] = new_shape[dim] * 3 - mixed_layer.set_(query_layer.untyped_storage(), - query_layer.storage_offset(), - new_shape, - query_layer.stride()) - ctx.dim = dim - return mixed_layer - - @staticmethod - def backward(ctx, - *grad_outputs, - ) -> Tuple[torch.Tensor, ...]: - assert len(grad_outputs) > 0, "No gradients received for backprop!" - tensors = split_tensor_along_dim(grad_outputs[0], ctx.dim, 3) - return tensors[0], tensors[1], tensors[2], None - -class _CombineKV(torch.autograd.Function): - """""" - - @staticmethod - def forward(ctx, - key_layer: torch.Tensor, - value_layer: torch.Tensor, # pylint: disable=unused-argument - dim: int, - ) -> torch.Tensor: - - mixed_layer = torch.Tensor().to(device=key_layer.device, - dtype=key_layer.dtype) - new_shape = list(key_layer.shape) - new_shape[dim] = new_shape[dim] * 2 - mixed_layer.set_(key_layer.untyped_storage(), - key_layer.storage_offset(), - new_shape, - key_layer.stride()) - ctx.dim = dim - return mixed_layer - - @staticmethod - def backward(ctx, - *grad_outputs, - ) -> Tuple[torch.Tensor, ...]: - assert len(grad_outputs) > 0, "No gradients received for backprop!" 
- tensors = split_tensor_along_dim(grad_outputs[0], ctx.dim, 2) - return tensors[0], tensors[1], None class UnfusedDotProductAttention(torch.nn.Module): @@ -235,6 +180,9 @@ def forward( query_layer: torch.Tensor, key_layer: torch.Tensor, value_layer: torch.Tensor, + qkv_layout: str = "sbh3d", + cu_seqlens_q: Optional[torch.Tensor] = None, # pylint: disable=unused-argument + cu_seqlens_kv: Optional[torch.Tensor] = None, # pylint: disable=unused-argument attn_mask_type: str = "causal", attention_mask: Optional[torch.Tensor] = None, core_attention_bias_type: str = "no_bias", @@ -242,6 +190,15 @@ def forward( ) -> torch.Tensor: """core attention fprop""" + assert (qkv_layout in QKVLayouts + ), f"UnfusedDotProductAttention does not support qkv_layout = {qkv_layout}!" + qkv_format = ''.join([i for i in qkv_layout.split('_')[0] if i.isalpha()]) + assert (qkv_format != 'thd' + ), """UnfusedDotProductAttention does not support variable sequence lengths!""" + if qkv_format == 'bshd': + # convert to sbhd and use sbhd implementation for now + query_layer, key_layer, value_layer = [x.transpose(0, 1) + for x in [query_layer, key_layer, value_layer]] assert ( attn_mask_type in AttnMaskTypes ), f"attn_mask_type {attn_mask_type} not supported" @@ -257,7 +214,6 @@ def forward( key_layer.size(0), ) - assert key_layer.shape == value_layer.shape, "Keys and values must have the same shape!" if key_layer.shape[2] != query_layer.shape[2]: assert (query_layer.shape[2]%key_layer.shape[2]==0 ),"The number of attention heads must be divisible by the number of GQA groups!" @@ -367,11 +323,19 @@ def forward( # change view [b, np, sq, hn] context_layer = context_layer.view(*output_size) - # [b, np, sq, hn] --> [sq, b, np, hn] - context_layer = context_layer.permute(2, 0, 1, 3).contiguous() + if qkv_format == 'sbhd': + # [b, np, sq, hn] --> [sq, b, np, hn] + context_layer = context_layer.permute(2, 0, 1, 3).contiguous() + + # [sq, b, np, hn] --> [sq, b, hp] + context_layer = context_layer.view(seqlen, batch_size, -1) - # [sq, b, np, hn] --> [sq, b, hp] - context_layer = context_layer.view(seqlen, batch_size, -1) + if qkv_format == 'bshd': + # [b, np, sq, hn] --> [b, sq, np, hn] + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + + # [b, sq, np, hn] --> [b, sq, hp] + context_layer = context_layer.view(batch_size, seqlen, -1) return context_layer @@ -406,66 +370,100 @@ def backward(ctx, dq, dk, dv = split_tensor_along_dim(dqkv, -1, 3) return dq, dk, dv +def _get_qkv_layout( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + qkv_format: str = 'sbhd', + ) -> str: + """Get qkv layout. -def _check_qkv_layout(q, k, v): - data_ptr = q.untyped_storage().data_ptr() - check_ptrs = all(x.untyped_storage().data_ptr() == data_ptr for x in [q, k, v]) - if not check_ptrs: - return False - - stride = q.stride() - check_strides = all(stride == x.stride() for x in [q, k, v]) - if not check_strides: - return False - - shape = q.shape - check_shapes = all(shape == x.shape for x in [q, k, v]) - if not check_shapes: - return False - - last_dim_size = shape[-1] - check_offsets = all(i * last_dim_size == x.storage_offset() - for i, x in enumerate([q, k, v])) - if check_offsets: - return "sbh3d" - - last_dims_size = shape[-1] * shape[-2] - check_offsets = all(i * last_dims_size == x.storage_offset() - for i, x in enumerate([q, k, v])) - if check_offsets: - return "sb3hd" + Parameters + ---------- + q: torch.Tensor + Query tensor. + k: torch.Tensor + Key tensor. + v: torch.Tensor + Value tensor. 
+ qkv_format: str, default = `sbhd` + Dimension format for `q`, `k` and `v`, {`sbhd`, `bshd`, `thd`}. `s` stands for + the sequence length dimension, `b` batch size, `h` the number of attention heads, + `d` head size, and `t` the total number of sequences in a batch, i.e. + `t = sum(s_i) for i = 0...b-1`. + + Returns + ---------- + qkv_layout: str + Memory layout of `q`, `k` and `v`. Each `qkv_format` can be mapped to one of five + memory layouts. For example, `sb3hd` means `q`, `k`, `v` are created as one chunk + of memory and that they are interleaved in the `2`nd dimension. `sbhd_sbh2d` means + `q` and `kv` are created in two chunks and that `q` itself is contiguous and `k`, `v` + are interleaved with each other in the `3`rd dimension, `k = kv[:,:,:,0,:]` and + `v = kv[:,:,:,1,:]`. + Mapping: + `sbhd`: {`sb3hd`, `sbh3d`, `sbhd_sb2hd`, `sbhd_sbh2d`, `sbhd_sbhd_sbhd`} + `bshd`: {`bs3hd`, `bsh3d`, `bshd_bs2hd`, `bshd_bsh2d`, `bshd_bshd_bshd`} + `thd` : {`t3hd`, `th3d`, `thd_t2hd`, `thd_th2d`, `thd_thd_thd`} + """ - return "other" + check_last_dim_contiguous = all(x.stride(-1) == 1 for x in [q, k, v]) + assert check_last_dim_contiguous, "q, k and v must have stride 1 in their last dimension!" -def _check_kv_layout(k, v): + data_ptr = q.untyped_storage().data_ptr() + check_ptrs_qkv = all(x.untyped_storage().data_ptr() == data_ptr for x in [q, k, v]) data_ptr = k.untyped_storage().data_ptr() - check_ptrs = all(x.untyped_storage().data_ptr() == data_ptr for x in [k, v]) - if not check_ptrs: - return False + check_ptrs_kv = all(x.untyped_storage().data_ptr() == data_ptr for x in [k, v]) + stride = q.stride() + check_strides_qkv = all(stride == x.stride() for x in [q, k, v]) stride = k.stride() - check_strides = all(stride == x.stride() for x in [k, v]) - if not check_strides: - return False + check_strides_kv = all(stride == x.stride() for x in [k, v]) + shape = q.shape + check_shapes_qkv = all(shape == x.shape for x in [q, k, v]) shape = k.shape - check_shapes = all(shape == x.shape for x in [k, v]) - if not check_shapes: - return False + check_shapes_kv = all(shape == x.shape for x in [k, v]) - last_dim_size = shape[-1] - check_offsets = all(i * last_dim_size == x.storage_offset() + last_dim_size = q.shape[-1] + check_last_dim_offsets_qkv = all(i * last_dim_size == x.storage_offset() + for i, x in enumerate([q, k, v])) + last_dim_size = k.shape[-1] + check_last_dim_offsets_kv = all(i * last_dim_size == x.storage_offset() for i, x in enumerate([k, v])) - if check_offsets: - return "sbh2d" - last_dims_size = shape[-1] * shape[-2] - check_offsets = all(i * last_dims_size == x.storage_offset() + last_two_dims_size = q.shape[-1] * q.shape[-2] + check_last_two_dims_offsets_qkv = all(i * last_two_dims_size == x.storage_offset() + for i, x in enumerate([q, k, v])) + last_two_dims_size = k.shape[-1] * k.shape[-2] + check_last_two_dims_offsets_kv = all(i * last_two_dims_size == x.storage_offset() for i, x in enumerate([k, v])) - if check_offsets: - return "sb2hd" - return "other" + qkv_layout = None + if (check_ptrs_qkv and check_strides_qkv and check_shapes_qkv + and check_last_two_dims_offsets_qkv + and not check_last_dim_offsets_qkv): + # sb3hd, bs3hd, t3hd + qkv_layout = qkv_format[:-2] + '3' + qkv_format[-2:] + elif check_ptrs_qkv and check_strides_qkv and check_shapes_qkv and check_last_dim_offsets_qkv: + # sbh3d, bsh3d, th3d + qkv_layout = qkv_format[:-1] + '3' + qkv_format[-1:] + elif (check_ptrs_kv and check_strides_kv and check_shapes_kv + and check_last_two_dims_offsets_kv + and not 
check_last_dim_offsets_kv): + # sbhd_sb2hd, bshd_bs2hd, thd_t2hd + qkv_layout = qkv_format + '_' + qkv_format[:-2] + '2' + qkv_format[-2:] + elif (check_ptrs_kv and check_strides_kv and check_shapes_kv + and check_last_dim_offsets_kv): + # sbhd_sbh2d, bshd_bsh2d, thd_th2d + qkv_layout = qkv_format + '_' + qkv_format[:-1] + '2' + qkv_format[-1:] + elif check_strides_kv and check_shapes_kv: + # sbhd_sbhd_sbhd, bshd_bshd_bshd, thd_thd_thd + qkv_layout = '_'.join(list([qkv_format])*3) + else: + raise Exception("The provided qkv memory layout is not supported!") + + return qkv_layout class FlashAttention(torch.nn.Module): @@ -496,6 +494,9 @@ def forward( query_layer: torch.Tensor, key_layer: torch.Tensor, value_layer: torch.Tensor, + qkv_layout: str = "sbh3d", + cu_seqlens_q: Optional[torch.Tensor] = None, + cu_seqlens_kv: Optional[torch.Tensor] = None, attn_mask_type: str = "causal", ) -> torch.Tensor: """flash-attn fprop""" @@ -504,52 +505,87 @@ def forward( query_layer.dtype in [torch.float16, torch.bfloat16] and key_layer.dtype in [torch.float16, torch.bfloat16] and value_layer.dtype in [torch.float16, torch.bfloat16] - ), 'FlashAttention currently only supports FP16 and BF16.' + ), "FlashAttention currently only supports FP16 and BF16." assert ( query_layer.is_cuda and key_layer.is_cuda and value_layer.is_cuda - ), 'FlashAttention currently only supports CUDA tensors.' - - # For now just 128, will make it more general in the future - - if (query_layer.shape[-1] == 128 and - query_layer.shape[0] * query_layer.shape[1] >= 512 and - _check_qkv_layout(query_layer, key_layer, value_layer) == "sbh3d"): - query_layer, key_layer, value_layer = _PrepareQKVForFA.apply(query_layer, - key_layer, - value_layer) - else: - query_layer, key_layer, value_layer = [x.transpose(0,1).contiguous() - for x in (query_layer, key_layer, value_layer)] - - batch_size, seqlen = query_layer.shape[0], query_layer.shape[1] - - # [b, sq, np, hn] + ), "FlashAttention currently only supports CUDA tensors." + assert ( + qkv_layout in QKVLayouts + ), f"FlashAttention does not support qkv_layout = {qkv_layout}!" + + qkv_format = ''.join([i for i in qkv_layout.split('_')[0] if i.isalpha()]) + + if qkv_format == 'sbhd': + # For now just 128, will make it more general in the future + if (query_layer.shape[-1] == 128 and + query_layer.shape[0] * query_layer.shape[1] >= 512 and + qkv_layout == "sbh3d"): + query_layer, key_layer, value_layer = _PrepareQKVForFA.apply(query_layer, + key_layer, + value_layer) + else: + query_layer, key_layer, value_layer = [x.transpose(0,1).contiguous() + for x in (query_layer, key_layer, value_layer)] + + if qkv_format == 'bshd': + query_layer, key_layer, value_layer = [x.contiguous() + for x in (query_layer, key_layer, value_layer)] + + if qkv_format in ['sbhd', 'bshd']: + batch_size, max_seqlen_q, max_seqlen_kv = ( + query_layer.shape[0], query_layer.shape[1], key_layer.shape[1]) + if cu_seqlens_q is None: + cu_seqlens_q = torch.arange( + 0, + (batch_size + 1) * max_seqlen_q, + step=max_seqlen_q, + dtype=torch.int32, + device=query_layer.device) + if cu_seqlens_kv is None: + cu_seqlens_kv = torch.arange( + 0, + (batch_size + 1) * max_seqlen_kv, + step=max_seqlen_kv, + dtype=torch.int32, + device=key_layer.device) + + if qkv_format == 'thd': + assert (_flash_attn_2_available + ), "flash-attn v2 is required for variable sequence length support!" + assert (cu_seqlens_q is not None and cu_seqlens_kv is not None + ), "cu_seqlens_q and cu_seqlens_kv can not be None when qkv_format = thd!" 
+ seqlens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1] + seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1] + max_seqlen_q = seqlens_q.max().item() + max_seqlen_kv = seqlens_kv.max().item() + + # [b * s, h, d] query_layer, key_layer, value_layer = [ x.view(x.shape[0] * x.shape[1], *x.shape[2:]) for x in [query_layer, key_layer, value_layer] ] - max_seqlen = seqlen - cu_seqlens = torch.arange( - 0, - (batch_size + 1) * seqlen, - step=seqlen, - dtype=torch.int32, - device=query_layer.device) - with self.attention_dropout_ctx(): fa_optional_forward_kwargs = {} if not _flash_attn_2_available: fa_optional_forward_kwargs["deterministic"] = self.deterministic output = flash_attn_forward_func( - query_layer, key_layer, value_layer, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen, + query_layer, key_layer, value_layer, + cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv, self.attention_dropout if self.training else 0.0, - softmax_scale=1.0/self.norm_factor, causal=attn_mask_type=="causal", + softmax_scale=1.0/self.norm_factor, + causal=attn_mask_type=="causal", **fa_optional_forward_kwargs ) - # [(b sq), np, hn] -> [sq, b, (np hn)] - return output.view(batch_size, seqlen, -1).transpose(0, 1).contiguous() + if qkv_format == 'sbhd': + # (bs)hd -> bs(hd) -> sb(hd) + output = output.view(batch_size, max_seqlen_q, -1).transpose(0, 1).contiguous() + if qkv_format == 'bshd': + # (bs)hd -> bs(hd) + output = output.view(batch_size, max_seqlen_q, -1).contiguous() + + return output class FusedAttnFunc_qkvpacked(torch.autograd.Function): @@ -685,6 +721,77 @@ def backward(ctx, d_out): None, None, None, None, None, None, None, None, None, None, None, None) +class FusedAttnFunc(torch.autograd.Function): + """Function for FusedAttention with separate Q, K, V tensors""" + + @staticmethod + def forward(ctx, is_training, max_seqlen_q, max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv, + q, k, v, qkv_dtype, attn_bias, attn_scale, dropout_p, fast_zero_fill, + qkv_layout, attn_bias_type, attn_mask_type, + rng_gen, fused_attention_backend, use_FAv2_bwd): + out, aux_ctx_tensors = fused_attn_fwd( + is_training, max_seqlen_q, max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv, + q, k, v, qkv_dtype, fused_attention_backend, attn_bias, + None, None, None, None, None, + attn_scale, dropout_p, fast_zero_fill, qkv_layout, attn_bias_type, attn_mask_type, + rng_gen) + + ctx.save_for_backward(q, k, v, out, cu_seqlens_q, cu_seqlens_kv) + ctx.aux_ctx_tensors = aux_ctx_tensors + ctx.max_seqlen_q = max_seqlen_q + ctx.max_seqlen_kv = max_seqlen_kv + ctx.qkv_dtype = qkv_dtype + ctx.attn_scale = attn_scale + ctx.dropout_p = dropout_p + ctx.fast_zero_fill = fast_zero_fill + ctx.qkv_layout = qkv_layout + ctx.attn_bias_type = attn_bias_type + ctx.attn_mask_type = attn_mask_type + ctx.fused_attention_backend = fused_attention_backend + ctx.use_FAv2_bwd = use_FAv2_bwd + + return out + + @staticmethod + def backward(ctx, d_out): + q, k, v, out, cu_seqlens_q, cu_seqlens_kv = ctx.saved_tensors + if ctx.use_FAv2_bwd: + softmax_lse, rng_state = ctx.aux_ctx_tensors + dq = torch.empty_like(q) + dk = torch.empty_like(k) + dv = torch.empty_like(v) + maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x + d_out, q, k, v, out = [maybe_contiguous(x) + for x in (d_out, q, k, v, out)] + flash_attn_cuda_bwd( + d_out, q, k, v, out, softmax_lse, dq, dk, dv, + cu_seqlens_q, cu_seqlens_kv, ctx.max_seqlen_q, ctx.max_seqlen_kv, + ctx.dropout_p, ctx.attn_scale, False, + ctx.attn_mask_type == "causal", None, rng_state + ) + dq = dq[..., :d_out.shape[-1]] + dk 
= dk[..., :d_out.shape[-1]] + dv = dv[..., :d_out.shape[-1]] + else: + dq, dk, dv, *rest = fused_attn_bwd( + ctx.max_seqlen_q, ctx.max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv, + q, k, v, out, d_out, + ctx.qkv_dtype, ctx.aux_ctx_tensors, + ctx.fused_attention_backend, + None, None, None, None, None, None, None, None, None, + ctx.attn_scale, ctx.dropout_p, ctx.fast_zero_fill, + ctx.qkv_layout, ctx.attn_bias_type, ctx.attn_mask_type) + + # if no_bias, return dqkv + if ctx.attn_bias_type == "no_bias": + return (None, None, None, None, None, dq, dk, dv, None, None, None, + None, None, None, None, None, None, + None, None, None, None, None, None) + # else, return (dqkv, dbias) + return (None, None, None, None, None, dq, dk, dv, None, rest[0], None, + None, None, None, None, None, None, + None, None, None, None, None, None) + class FusedAttention(torch.nn.Module): """Dot product attention, with multiple backends: @@ -695,20 +802,23 @@ class FusedAttention(torch.nn.Module): Support matrix: - | backend | 1 | 2 | - | flash based | no | yes | - | cuDNN based | yes | yes | - | qkv dtype | fp16/bf16 | fp16/bf16 | - | attn_type | self/cross | self | - | qkv_layout | | | - | - qkv | qkv_interleaved | qkv_interleaved | - | - (q,kv) | kv_interleaved | | - | mask_type | causal/no_mask | causal | - | bias_type | no_bias/post_scale_bias | no_bias | - | dropout | yes | yes | - | max_seqlen | <=512 | any | - | head_dim | 64 | 64,128 | - | output dtype | fp16/bf16 | fp16/bf16 | + | backend | 1 | 2 | + | flash based | no | yes | + | cuDNN based | yes | yes | + | qkv dtype | fp16/bf16 | fp16/bf16 | + | attn_type | self/cross | self | + | qkv_layout | | | + | - qkv | qkv_interleaved | qkv_interleaved | + | - (q,kv) | kv_interleaved | | + | - (q,k,v) | sb3hd, bs3hd | sb3hd, bs3hd | + | | sbhd_sb2hd, bshd_bs2hd | sbhd_sb2hd, bshd_bs2hd | + | | bshd_bshd_bshd | sbhd_sbhd_sbhd, bshd_bshd_bshd | + | mask_type | causal/no_mask | causal | + | bias_type | no_bias/post_scale_bias | no_bias | + | dropout | yes | yes | + | max_seqlen | <=512 | any | + | head_dim | 64 | 64,128 | + | output dtype | fp16/bf16 | fp16/bf16 | """ def __init__( @@ -733,6 +843,9 @@ def forward( query_layer: torch.Tensor, key_layer: torch.Tensor, value_layer: torch.Tensor, + qkv_layout: str = "sbh3d", + cu_seqlens_q: Optional[torch.Tensor] = None, + cu_seqlens_kv: Optional[torch.Tensor] = None, attn_mask_type: str = "causal", fused_attention_backend: tex.NVTE_Fused_Attn_Backend = tex.NVTE_Fused_Attn_Backend.NVTE_No_Backend, @@ -743,8 +856,8 @@ def forward( """fused attention fprop""" assert (fused_attention_backend - != tex.NVTE_Fused_Attn_Backend.NVTE_No_Backend - ), 'No fused attention backend supports this input combination!' + != tex.NVTE_Fused_Attn_Backend.NVTE_No_Backend + ), 'No fused attention backend supports this input combination!' assert ( (query_layer.dtype in [torch.float16, torch.bfloat16]) and (key_layer.dtype in [torch.float16, torch.bfloat16]) @@ -753,132 +866,66 @@ def forward( assert ( query_layer.is_cuda and key_layer.is_cuda and value_layer.is_cuda ), 'FusedAttention only supports CUDA tensors.' + assert ( + qkv_layout in QKVLayouts + ), f"FusedAttention does not support qkv_layout = {qkv_layout}!" 
+ + qkv_format = ''.join([i for i in qkv_layout.split('_')[0] if i.isalpha()]) + if qkv_format in ['sbhd', 'bshd']: + if qkv_format == 'sbhd': + batch_size, max_seqlen_q, max_seqlen_kv = ( + query_layer.shape[1], query_layer.shape[0], key_layer.shape[0]) + if qkv_format == 'bshd': + batch_size, max_seqlen_q, max_seqlen_kv = ( + query_layer.shape[0], query_layer.shape[1], key_layer.shape[1]) + if cu_seqlens_q is None: + cu_seqlens_q = torch.arange( + 0, + (batch_size + 1) * max_seqlen_q, + step=max_seqlen_q, + dtype=torch.int32, + device=query_layer.device) + if cu_seqlens_kv is None: + cu_seqlens_kv = torch.arange( + 0, + (batch_size + 1) * max_seqlen_kv, + step=max_seqlen_kv, + dtype=torch.int32, + device=key_layer.device) + if qkv_format == 'thd': + assert (cu_seqlens_q is not None and cu_seqlens_kv is not None + ), "cu_seqlens_q and cu_seqlens_kv can not be None when qkv_format = thd!" + seqlens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1] + seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1] + max_seqlen_q = seqlens_q.max().item() + max_seqlen_kv = seqlens_kv.max().item() qkv_dtype = TE_DType[query_layer.dtype] - seqlen_q, batch_size = query_layer.shape[0], query_layer.shape[1] - seqlen_kv = key_layer.shape[0] - max_seqlen_q = seqlen_q - max_seqlen_kv = seqlen_kv - if self.attention_type == "self": - qkv_layout = _check_qkv_layout(query_layer, key_layer, value_layer) - if qkv_layout == "sbh3d": - mixed_layer = _CombineQKV.apply(query_layer, key_layer, value_layer, 3) - # [s, b, h, 3, d] - mixed_layer = mixed_layer.view( - *mixed_layer.shape[0:3], 3, query_layer.shape[-1]) - # [b, s, 3, h, d] - mixed_layer = mixed_layer.transpose(2, 3).transpose(0, 1).contiguous() - elif qkv_layout == "sb3hd": - mixed_layer = _CombineQKV.apply(query_layer, key_layer, value_layer, 2) - # [s, b, 3, h, d] - mixed_layer = mixed_layer.view( - *mixed_layer.shape[0:2], 3, *query_layer.shape[2:]) - # [b, s, 3, h, d] - mixed_layer = mixed_layer.transpose(0, 1).contiguous() - else: - raise Exception("FusedAttention only supports qkv layout sbh3d or sb3hd!") - - # [total_seqs, 3, h, d] - mixed_layer = mixed_layer.view( - mixed_layer.shape[0] * mixed_layer.shape[1], *mixed_layer.shape[2:]) - - qkv_layout = "qkv_interleaved" - max_seqlen = seqlen_q - cu_seqlens = torch.arange( - 0, - (batch_size + 1) * seqlen_q, - step=seqlen_q, - dtype=torch.int32, - device=query_layer.device) - use_FAv2_bwd = (self.use_FAv2_bwd - and (fused_attention_backend - == tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen) - and core_attention_bias_type == "no_bias") - - with self.attention_dropout_ctx(): - output = FusedAttnFunc_qkvpacked.apply( - self.training, - max_seqlen, - cu_seqlens, - mixed_layer, - qkv_dtype, - core_attention_bias, - 1.0/self.norm_factor, - self.attention_dropout if self.training else 0.0, - fast_zero_fill, - qkv_layout, - core_attention_bias_type, - attn_mask_type, - None, # rng_gen - fused_attention_backend, - use_FAv2_bwd - ) - output = output.view(batch_size, seqlen_q, -1).transpose(0, 1).contiguous() - - if self.attention_type == "cross": - kv_layout = _check_kv_layout(key_layer, value_layer) - if kv_layout == "sbh2d": - key_value = _CombineKV.apply(key_layer, value_layer, 3) - # [s, b, h, 2, d] - key_value = key_value.view( - *key_value.shape[0:3], 2, key_layer.shape[-1]) - # [b, s, 2, h, d] - key_value = key_value.transpose(2, 3).transpose(0, 1).contiguous() - elif qkv_layout == "sb2hd": - key_value = _CombineKV.apply(key_layer, value_layer, 2) - # [s, b, 2, h, d] - key_value = key_value.view( - 
*key_value.shape[0:2], 2, *key_layer.shape[2:]) - # [b, s, 2, h, d] - key_value = key_value.transpose(0, 1).contiguous() - else: - raise Exception("FusedAttention only supports kv layout sbh2d or sb2hd!") - - # [total_seqs, h, d] - query_layer = query_layer.transpose(0, 1).contiguous() - query_layer = query_layer.view( - query_layer.shape[0] * query_layer.shape[1], *query_layer.shape[2:]) - # [total_seqs, 2, h, d] - key_value = key_value.view([key_value.shape[0] * key_value.shape[1]] - + key_value.shape[2:]) - - qkv_layout = "kv_interleaved" - cu_seqlens_q = torch.arange( - 0, - (batch_size + 1) * seqlen_q, - step=seqlen_q, - dtype=torch.int32, - device=query_layer.device) - cu_seqlens_kv = torch.arange( - 0, - (batch_size + 1) * seqlen_kv, - step=seqlen_kv, - dtype=torch.int32, - device=key_layer.device) - - with self.attention_dropout_ctx(): - outputs = FusedAttnFunc_kvpacked.apply( - self.training, - max_seqlen_q, max_seqlen_kv, - cu_seqlens_q, cu_seqlens_kv, - query_layer, key_value, - qkv_dtype, - core_attention_bias, - 1.0/self.norm_factor, - self.attention_dropout if self.training else 0.0, - fast_zero_fill, - qkv_layout, - core_attention_bias_type, - attn_mask_type, - None, # rng_gen - fused_attention_backend, - use_FAv2_bwd - ) + use_FAv2_bwd = (self.use_FAv2_bwd + and (fused_attention_backend + == tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen)) + with self.attention_dropout_ctx(): + output = FusedAttnFunc.apply( + self.training, + max_seqlen_q, max_seqlen_kv, + cu_seqlens_q, cu_seqlens_kv, + query_layer, key_layer, value_layer, + qkv_dtype, + core_attention_bias, + 1.0/self.norm_factor, + self.attention_dropout if self.training else 0.0, + fast_zero_fill, + qkv_layout, + core_attention_bias_type, + attn_mask_type, + None, # rng_gen + fused_attention_backend, + use_FAv2_bwd, + ) - output = (outputs[0].view(batch_size, seqlen_q, -1).transpose(0, 1).contiguous(), - outputs[1].view(batch_size, seqlen_q, -1).transpose(0, 1).contiguous()) - return output + # ...hd -> ...(hd) + return output.view(*output.shape[:-2], -1) class DotProductAttention(torch.nn.Module): @@ -917,6 +964,16 @@ class DotProductAttention(torch.nn.Module): layer_number: int, default = `None` layer number of the current `DotProductAttention` when multiple such modules are concatenated, for instance in consecutive transformer blocks. + qkv_format: str, default = `sbhd` + dimension format for `query_layer`, `key_layer` and `value_layer`, + {`sbhd`, `bshd`, `thd`}. `s` stands for the sequence length, `b` batch size, + `h` the number of heads, `d` head size, and `t` the total number of sequences + in a batch, with `t = sum(s_i), for i = 0...b-1`. `sbhd` and `bshd` formats + are used for when sequences in a batch are of equal length or padded to + equal length, and the `thd` format is used for when sequences in a batch + have different lengths. Please note that these formats do not reflect how + tensors `query_layer`, `key_layer`, `value_layer` are laid out in memory. + For that, please use `_get_qkv_layout` to gain the layout information. attn_mask_type: {'causal', 'padding', 'no_mask'}, default = `causal` type of attention mask passed into softmax operation. Overridden by :attr:`attn_mask_type` in the `forward` method. 
The forward @@ -940,6 +997,7 @@ def __init__( kv_channels: int, num_gqa_groups: Optional[int] = None, attention_dropout: float = 0.0, + qkv_format: str = "sbhd", attn_mask_type: str = "causal", sequence_parallel: bool = False, tp_size: int = 1, @@ -950,6 +1008,7 @@ def __init__( ) -> None: super().__init__() + self.qkv_format = qkv_format self.attn_mask_type = attn_mask_type self.tp_size = tp_size if tp_group is None else get_distributed_world_size(tp_group) self.tp_group = tp_group @@ -1040,6 +1099,9 @@ def forward( key_layer: torch.Tensor, value_layer: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, + qkv_format: Optional[str] = None, + cu_seqlens_q: Optional[torch.Tensor] = None, + cu_seqlens_kv: Optional[torch.Tensor] = None, attn_mask_type: Optional[str] = None, checkpoint_core_attention: bool = False, core_attention_bias_type: str = "no_bias", @@ -1082,9 +1144,11 @@ def forward( If FusedAttention is being used, users can also choose to switch to flash-attn's implementation for backward by setting :attr:`NVTE_FUSED_ATTN_USE_FAv2_BWD=1` (default: 0), because of the performance differences between various versions of - flash-attn and FusedAttention. Further, :attr:`NVTE_FUSED_ATTN_DP_WORKSPACE_LIMIT` - can be used to enable the workspace related optimizations in FusedAttention - (default: 256MB; raise the limit to enable these performance optimizations). + flash-attn and FusedAttention. Further, :attr:`NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT` + can be used to enable (:attr:`1`) or disable (:attr:`0`) the workspace related + optimizations in FusedAttention. When unset, TransformerEngine determines the code path + based on its internal logic. These optimizations trade memory for performance + and should be used with care. Parameters ---------- @@ -1094,6 +1158,14 @@ def forward( Key tensor. value_layer : torch.Tensor Value tensor. + qkv_format: str, default = `None` + If provided, overrides :attr:`qkv_format` from initialization. + cu_seqlens_q: Optional[torch.Tensor], default = `None` + Cumulative sum of sequence lengths in a batch for `query_layer`, + with shape [batch_size + 1] and dtype torch.int32. + cu_seqlens_kv: Optional[torch.Tensor], default = `None` + Cumulative sum of sequence lengths in a batch for `key_layer` and `value_layer`, + with shape [batch_size + 1] and dtype torch.int32. attention_mask : Optional[torch.Tensor], default = `None` Boolean tensor used to mask out softmax input when not using flash-attn. attn_mask_type: {'causal', 'padding', 'no_mask'}, default = `None` @@ -1111,12 +1183,57 @@ def forward( Whether to use the fast path to set output tensors to 0 or not. """ + assert (key_layer.shape == value_layer.shape + ), "Keys and values must have the same shape!" + if attn_mask_type is None: attn_mask_type = self.attn_mask_type + if qkv_format is None: + qkv_format = self.qkv_format assert (key_layer.shape[-2] == self.num_gqa_groups_per_partition - and value_layer.shape[-2] == self.num_gqa_groups_per_partition - ), f"Keys and values must have {self.num_gqa_groups} heads!" + and value_layer.shape[-2] == self.num_gqa_groups_per_partition + ), f"Keys and values must have num_gqa_group = {self.num_gqa_groups} heads!" + assert (qkv_format in ['sbhd', 'bshd', 'thd'] + ), "DotProductAttention only supports qkv_format = {'sbhd', 'bshd', 'thd'}!" + + if qkv_format == 'thd': + assert (all(len(x.shape) == 3 for x in (query_layer, key_layer, value_layer)) + ), "Queries, keys and values must be 3D tensors when qkv_format = thd!" 
+ assert (cu_seqlens_q is not None and cu_seqlens_kv is not None + ), "cu_seqlens_q and cu_seqlens_kv can not be None when qkv_format = thd!" + assert (cu_seqlens_q.shape == cu_seqlens_kv.shape + and len(cu_seqlens_q.shape) == 1 + and len(cu_seqlens_kv.shape) == 1 + ), "cu_seqlens_q and cu_seqlens_q must both have shape [batch_size + 1]!" + assert (cu_seqlens_q.dtype == torch.int32 + and cu_seqlens_kv.dtype == torch.int32 + ), "cu_seqlens_q and cu_seqlens_q must both be in dtype torch.int32!" + seqlens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1] + seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1] + max_seqlen_q = seqlens_q.max().item() + max_seqlen_kv = seqlens_kv.max().item() + + if qkv_format in ['sbhd', 'bshd']: + assert (all(len(x.shape) == 4 for x in (query_layer, key_layer, value_layer)) + ), f"Queries, keys and values must be 4D tensors when qkv_format = {qkv_format}!" + if qkv_format == 'sbhd': + max_seqlen_q, max_seqlen_kv = (query_layer.shape[0], key_layer.shape[0]) + if qkv_format == 'bshd': + max_seqlen_q, max_seqlen_kv = (query_layer.shape[1], key_layer.shape[1]) + if cu_seqlens_q is not None: + seqlens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1] + assert (all(seqlens_q <= max_seqlen_q) + ), """Sequence lengths indicated by cu_seqlens_q must be no greater than + the sequence dimention in 'query_layer'!""" + if cu_seqlens_kv is not None: + seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1] + assert (all(seqlens_kv <= max_seqlen_kv) + ), """Sequence lengths indicated by cu_seqlens_kv must be no greater than + the sequence dimention in 'key_layer' and 'value_layer'!""" + + qkv_layout = _get_qkv_layout(query_layer, key_layer, value_layer, + qkv_format = qkv_format) use_flash_attention = self.use_flash_attention use_fused_attention = self.use_fused_attention @@ -1147,8 +1264,6 @@ def forward( use_flash_attention = False use_fused_attention = False - qkv_layout = "qkv_interleaved" if self.attention_type == "self" else "kv_interleaved" - if use_fused_attention: fused_attention_backend = tex.get_fused_attn_backend( TE_DType[query_layer.dtype], @@ -1157,7 +1272,7 @@ def forward( AttnBiasType[core_attention_bias_type], AttnMaskType[attn_mask_type], self.attention_dropout, - query_layer.shape[0], key_layer.shape[0], + max_seqlen_q, max_seqlen_kv, query_layer.shape[-1]) # DPA does not support FP8; for FP8, use cpp_extensions modules directly is_backend_avail = (fused_attention_backend in @@ -1179,9 +1294,16 @@ def forward( query_layer, key_layer, value_layer, - attn_mask_type=attn_mask_type) - return self.flash_attention( - query_layer, key_layer, value_layer, attn_mask_type=attn_mask_type) + qkv_layout = qkv_layout, + cu_seqlens_q = cu_seqlens_q, + cu_seqlens_kv = cu_seqlens_kv, + attn_mask_type = attn_mask_type) + return self.flash_attention(query_layer, key_layer, value_layer, + qkv_layout = qkv_layout, + cu_seqlens_q = cu_seqlens_q, + cu_seqlens_kv = cu_seqlens_kv, + attn_mask_type = attn_mask_type) + if use_fused_attention: if checkpoint_core_attention: @@ -1189,17 +1311,23 @@ def forward( query_layer, key_layer, value_layer, - attn_mask_type=attn_mask_type, - fused_attention_backend=fused_attention_backend, - core_attention_bias_type=core_attention_bias_type, - core_attention_bias=core_attention_bias, - fast_zero_fill=fast_zero_fill) + qkv_layout = qkv_layout, + cu_seqlens_q = cu_seqlens_q, + cu_seqlens_kv = cu_seqlens_kv, + attn_mask_type = attn_mask_type, + fused_attention_backend = fused_attention_backend, + core_attention_bias_type = core_attention_bias_type, + core_attention_bias 
= core_attention_bias, + fast_zero_fill = fast_zero_fill) return self.fused_attention(query_layer, key_layer, value_layer, - attn_mask_type=attn_mask_type, - fused_attention_backend=fused_attention_backend, - core_attention_bias_type=core_attention_bias_type, - core_attention_bias=core_attention_bias, - fast_zero_fill=fast_zero_fill) + qkv_layout = qkv_layout, + cu_seqlens_q = cu_seqlens_q, + cu_seqlens_kv = cu_seqlens_kv, + attn_mask_type = attn_mask_type, + fused_attention_backend = fused_attention_backend, + core_attention_bias_type = core_attention_bias_type, + core_attention_bias = core_attention_bias, + fast_zero_fill = fast_zero_fill) if checkpoint_core_attention: return self._checkpointed_attention_forward( @@ -1207,19 +1335,23 @@ def forward( query_layer, key_layer, value_layer, - attn_mask_type=attn_mask_type, - attention_mask=attention_mask, - core_attention_bias_type=core_attention_bias_type, - core_attention_bias=core_attention_bias, - ) + qkv_layout = qkv_layout, + cu_seqlens_q = cu_seqlens_q, + cu_seqlens_kv = cu_seqlens_kv, + attn_mask_type = attn_mask_type, + attention_mask = attention_mask, + core_attention_bias_type = core_attention_bias_type, + core_attention_bias = core_attention_bias) return self.unfused_attention(query_layer, key_layer, value_layer, - attn_mask_type=attn_mask_type, - attention_mask=attention_mask, - core_attention_bias_type=core_attention_bias_type, - core_attention_bias=core_attention_bias, - ) + qkv_layout = qkv_layout, + cu_seqlens_q = cu_seqlens_q, + cu_seqlens_kv = cu_seqlens_kv, + attn_mask_type = attn_mask_type, + attention_mask = attention_mask, + core_attention_bias_type = core_attention_bias_type, + core_attention_bias = core_attention_bias) class MultiheadAttention(torch.nn.Module): @@ -1834,6 +1966,9 @@ def forward( query_layer, key_layer, value_layer, + qkv_format='sbhd', + cu_seqlens_q=None, + cu_seqlens_kv=None, attention_mask=attention_mask, attn_mask_type=attn_mask_type, checkpoint_core_attention=checkpoint_core_attention, diff --git a/transformer_engine/pytorch/constants.py b/transformer_engine/pytorch/constants.py index ee43fa10d9..0504cde47c 100644 --- a/transformer_engine/pytorch/constants.py +++ b/transformer_engine/pytorch/constants.py @@ -28,6 +28,11 @@ AttnBiasTypes = ("pre_scale_bias", "post_scale_bias", "no_bias") +QKVLayouts = ( + "sb3hd", "sbh3d", "sbhd_sb2hd", "sbhd_sbh2d", "sbhd_sbhd_sbhd", + "bs3hd", "bsh3d", "bshd_bs2hd", "bshd_bsh2d", "bshd_bshd_bshd", + "t3hd", "th3d", "thd_t2hd", "thd_th2d", "thd_thd_thd") + LayerTypes = ("encoder", "decoder") GemmParallelModes = ("row", "column", None) diff --git a/transformer_engine/pytorch/cpp_extensions/fused_attn.py b/transformer_engine/pytorch/cpp_extensions/fused_attn.py index dd6fb3e2f8..77b5302d6c 100644 --- a/transformer_engine/pytorch/cpp_extensions/fused_attn.py +++ b/transformer_engine/pytorch/cpp_extensions/fused_attn.py @@ -18,7 +18,9 @@ __all__ = ['fused_attn_fwd_qkvpacked', 'fused_attn_bwd_qkvpacked', 'fused_attn_fwd_kvpacked', - 'fused_attn_bwd_kvpacked'] + 'fused_attn_bwd_kvpacked', + 'fused_attn_fwd', + 'fused_attn_bwd'] TORCH_DType = { @@ -34,6 +36,21 @@ "not_interleaved": NVTE_QKV_Layout.NVTE_NOT_INTERLEAVED, "qkv_interleaved": NVTE_QKV_Layout.NVTE_QKV_INTERLEAVED, "kv_interleaved": NVTE_QKV_Layout.NVTE_KV_INTERLEAVED, + "sb3hd": NVTE_QKV_Layout.NVTE_SB3HD, + "sbh3d": NVTE_QKV_Layout.NVTE_SBH3D, + "sbhd_sb2hd": NVTE_QKV_Layout.NVTE_SBHD_SB2HD, + "sbhd_sbh2d": NVTE_QKV_Layout.NVTE_SBHD_SBH2D, + "sbhd_sbhd_sbhd": NVTE_QKV_Layout.NVTE_SBHD_SBHD_SBHD, + "bs3hd": 
NVTE_QKV_Layout.NVTE_BS3HD, + "bsh3d": NVTE_QKV_Layout.NVTE_BSH3D, + "bshd_bs2hd": NVTE_QKV_Layout.NVTE_BSHD_BS2HD, + "bshd_bsh2d": NVTE_QKV_Layout.NVTE_BSHD_BSH2D, + "bshd_bshd_bshd": NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD, + "t3hd": NVTE_QKV_Layout.NVTE_T3HD, + "th3d": NVTE_QKV_Layout.NVTE_TH3D, + "thd_t2hd": NVTE_QKV_Layout.NVTE_THD_T2HD, + "thd_th2d": NVTE_QKV_Layout.NVTE_THD_TH2D, + "thd_thd_thd": NVTE_QKV_Layout.NVTE_THD_THD_THD, } AttnBiasType = { @@ -166,9 +183,10 @@ def fused_attn_fwd_qkvpacked( if True, runs training and produces auxiliary tensors aux_ctx_tensors for the backward; if False, runs inference and doesn't produce aux_ctx_tensors max_seqlen: int - max sequence length for QKV, used for padding; may be larger than max(cu_seqlens) + max sequence length for QKV, used for padding; may be larger than max(seqlens), + seqlens = cu_seqlens[1:] - cu_seqlens[:-1] cu_seqlens: torch.Tensor - accumulative sequence lengths for QKV; shape [batch_size + 1] + cumulative sequence lengths for QKV; shape [batch_size + 1] qkv: torch.Tensor input tensor QKV; shape [total_seqs, 3, num_heads, head_dim], where total_seqs = cu_seqlens[-1] @@ -336,9 +354,10 @@ def fused_attn_bwd_qkvpacked( Parameters ---------- max_seqlen: int - max sequence length for QKV, used for padding; may be larger than max(cu_seqlens_q) + max sequence length for QKV, used for padding; may be larger than max(seqlens) + seqlens = cu_seqlens[1:] - cu_seqlens[:-1] cu_seqlens: torch.Tensor - accumulative sequence lengths for QKV; shape [batch_size + 1] + cumulative sequence lengths for QKV; shape [batch_size + 1] qkv: torch.Tensor input tensor QKV; shape [total_seqs, 3, num_heads, head_dim], where total_seqs = cu_seqlens[-1] @@ -482,7 +501,7 @@ def fused_attn_fwd_kvpacked( attn_scale: float = None, dropout: float = 0.0, fast_zero_fill: bool = True, - qkv_layout: str = "qkv_interleaved", + qkv_layout: str = "kv_interleaved", attn_bias_type: str = "no_bias", attn_mask_type: str = "padding", rng_gen: torch.Generator = None, @@ -495,13 +514,15 @@ def fused_attn_fwd_kvpacked( if True, runs training and produces auxiliary tensors aux_ctx_tensors for the backward; if False, runs inference and doesn't produce aux_ctx_tensors max_seqlen_q: int - max sequence length for Q, used for padding; may be larger than max(cu_seqlens_q) + max sequence length for Q, used for padding; may be larger than max(seqlens_q), + seqlens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1] max_seqlen_kv: int - max sequence length for KV, used for padding; may be larger than max(cu_seqlens_kv) + max sequence length for KV, used for padding; may be larger than max(seqlens_kv), + seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1] cu_seqlens_q: torch.Tensor - accumulative sequence lengths for Q; shape [batch_size + 1] + cumulative sequence lengths for Q; shape [batch_size + 1] cu_seqlens_kv: torch.Tensor - accumulative sequence lengths for KV; shape [batch_size + 1] + cumulative sequence lengths for KV; shape [batch_size + 1] q: torch.Tensor input tensor Q; shape [total_seqs_q, num_heads, head_dim], where total_seqs_q = cu_seqlens_q[-1] @@ -535,7 +556,7 @@ def fused_attn_fwd_kvpacked( fast_zero_fill: bool, default = True if True, initializes the output tensor O to zero using the fast filling method; if False, uses PyTorch's .fill_() method - qkv_layout: str, default = "qkv_interleaved" + qkv_layout: str, default = "kv_interleaved" layout of QKV; {"qkv_interleaved", "kv_interleaved", "not_interleaved"} attn_bias_type: str, default = "no_bias" type of the bias; {"no_bias", 
"pre_scale_bias", "post_scale_bias"} @@ -659,7 +680,7 @@ def fused_attn_bwd_kvpacked( attn_scale: float = None, dropout: float = 0.0, fast_zero_fill: bool = True, - qkv_layout: str = "qkv_interleaved", + qkv_layout: str = "kv_interleaved", attn_bias_type: str = "no_bias", attn_mask_type: str = "padding", ) -> Tuple[Union[torch.Tensor, None], ...]: @@ -668,13 +689,15 @@ def fused_attn_bwd_kvpacked( Parameters ---------- max_seqlen_q: int - max sequence length for Q, used for padding; may be larger than max(cu_seqlens_q) + max sequence length for Q, used for padding; may be larger than max(seqlens_q), + seqlens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1] max_seqlen_kv: int - max sequence length for KV, used for padding; may be larger than max(cu_seqlens_kv) + max sequence length for KV, used for padding; may be larger than max(seqlens_kv), + seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1] cu_seqlens_q: torch.Tensor - accumulative sequence lengths for Q; shape [batch_size + 1] + cumulative sequence lengths for Q; shape [batch_size + 1] cu_seqlens_kv: torch.Tensor - accumulative sequence lengths for KV; shape [batch_size + 1] + cumulative sequence lengths for KV; shape [batch_size + 1] q: torch.Tensor input tensor Q; shape [total_seqs_q, num_heads, head_dim], where total_seqs_q = cu_seqlens_q[-1] @@ -723,7 +746,7 @@ def fused_attn_bwd_kvpacked( fast_zero_fill: bool, default = True if True, initializes the output tensor O to zero using the fast filling method; if False, uses PyTorch's .fill_() method - qkv_layout: str, default = "qkv_interleaved" + qkv_layout: str, default = "kv_interleaved" layout of QKV; {"qkv_interleaved", "kv_interleaved", "not_interleaved"} attn_bias_type: str, default = "no_bias" type of the bias; {"no_bias", "pre_scale_bias", "post_scale_bias"} @@ -812,3 +835,365 @@ def fused_attn_bwd_kvpacked( return output_tensors # otherwise return (d_q, d_kv), d_bias return output_tensors[:2], output_tensors[2] + +def fused_attn_fwd( + is_training: bool, + max_seqlen_q: int, + max_seqlen_kv: int, + cu_seqlens_q: torch.Tensor, + cu_seqlens_kv: torch.Tensor, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + qkv_dtype: tex.DType, + fused_attention_backend: tex.NVTE_Fused_Attn_Backend, + attn_bias: torch.Tensor = None, + d_scale_qkv: torch.Tensor = None, + q_scale_s: torch.Tensor = None, + q_scale_o: torch.Tensor = None, + amax_s: torch.Tensor = None, + amax_o: torch.Tensor = None, + attn_scale: float = None, + dropout: float = 0.0, + fast_zero_fill: bool = True, + qkv_layout: str = "sbh3d", + attn_bias_type: str = "no_bias", + attn_mask_type: str = "padding", + rng_gen: torch.Generator = None, +) -> Tuple[Union[torch.Tensor, None], ...]: + """Fused Attention FWD for separate QKV input. 
+ + Parameters + ---------- + is_training: bool + if True, runs training and produces auxiliary tensors aux_ctx_tensors + for the backward; if False, runs inference and doesn't produce aux_ctx_tensors + max_seqlen_q: int + max sequence length for Q, used for padding; + may be larger than max(seqlens_q), + seqlens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1] + max_seqlen_kv: int + max sequence length for K and V, used for padding; + may be larger than max(seqlens_kv), + seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1] + cu_seqlens_q: torch.Tensor + cumulative sequence lengths for Q; shape [batch_size + 1] + cu_seqlens_kv: torch.Tensor + cumulative sequence lengths for K and V; shape [batch_size + 1] + q: torch.Tensor + input tensor Q; + shape [total_seqs_q, num_heads, head_dim], + where total_seqs_q = cu_seqlens_q[-1], + or [batch_size, seqlen_q, num_heads, head_dim], + or [seqlen_q, batch_size, num_heads, head_dim] + k: torch.Tensor + input tensor K; + shape [total_seqs_kv, num_heads, head_dim], + where total_seqs_kv = cu_seqlens_kv[-1], + or [batch_size, seqlen_kv, num_heads, head_dim], + or [seqlen_kv, batch_size, num_heads, head_dim] + v: torch.Tensor + input tensor V; + shape [total_seqs_kv, num_heads, head_dim], + where total_seqs_kv = cu_seqlens_kv[-1], + or [batch_size, seqlen_kv, num_heads, head_dim], + or [seqlen_kv, batch_size, num_heads, head_dim] + qkv_dtype: tex.DType + data type of Q, K and V; in tex.DType, not torch.dtype + fused_attention_backend: tex.NVTE_Fused_Attn_Backend + please see FusedAttention module for details on supported backends. + attn_bias: torch.Tensor, default = None + input tensor Bias when attn_bias_type is "pre_scale_bias" or "post_scale_bias"; + shape [1, num_heads, max_seqlen_q, max_seqlen_kv], same data type as q, k and v + d_scale_qkv: torch.Tensor, default = None + input tensor for the dequantization of Q, K and V in FP8 computations + q_scale_s: torch.Tensor, default = None + input tensor for the quantization of S in FP8 computations, S = Softmax(Q * K.T) + q_scale_o: torch.Tensor, default = None + input tensor for the quantization of O in FP8 computations + amax_s: torch.Tensor, default = None + output tensor, amax of S, used by the next iteration in FP8 computations + amax_o: torch.Tensor, default = None + output tensor, amax of O, used by the next iteration in FP8 computations + attn_scale: float, default = None + if not None, use attn_scale as the attention scale for Q*K.T BMM; + if None, use 1.0/sqrt(head_dim) as the default + dropout: float, default = 0.0 + dropout probability, 0.0 means no dropout, 1.0 means no output; + dropout must be 0.0 if is_training is False + fast_zero_fill: bool, default = True + if True, initializes the output tensor O to zero using the fast filling method; + if False, uses PyTorch's .fill_() method + qkv_layout: str, default = "sbh3d" + layout of Q, K and V; + {"sb3hd", "sbh3d", "sbhd_sb2hd", "sbhd_sbh2d", "sbhd_sbhd_sbhd", + "bs3hd", "bsh3d", "bshd_bs2hd", "bshd_bsh2d", "bshd_bshd_bshd", + "t3hd", "th3d", "thd_t2hd", "thd_th2d", "thd_thd_thd"} + attn_bias_type: str, default = "no_bias" + type of the bias; {"no_bias", "pre_scale_bias", "post_scale_bias"} + attn_mask_type: str, default = "padding" + type of the attention mask; {"padding", "causal", "no_mask"} + rng_gen: torch.Generator, default = None + random number generator; + if None, uses the default CUDA generator from PyTorch; otherwise, uses rng_gen + + Returns + ---------- + o: torch.Tensor + output tensor O, of the attention calculation; same data type as 
Q, K and V; + same shape as Q + aux_ctx_tensors: List[torch.Tensor] + auxiliary output tensors used for the backward; + if is_training is True, aux_ctx_tensors = [softmax-related tensors, rng_state] + if is_training is False, aux_ctx_tensors = None + + softmax-related tensors: + 1. if fused_attention_backend == FusedAttnBackend["F16_max512_seqlen"] + softmax: torch.Tensor + Softmax(Q*K.T) + shape [batch_size, num_heads, max_seqlen_q, max_seqlen_kv], dtype float32 + 2. if fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"] + softmaxStats: torch.Tensor + log(sum(e^(x - max(x)))), where x=Q*K.T + shape [batch_size, num_heads, max_seqlen_q, 1], dtype float32 + 3. if fused_attention_backend == FusedAttnBackend["FP8"] + M: torch.Tensor + max(Q*K.T) + shape [batch_size, num_heads, max_seqlen_q, 1], dtype float32 + ZInv: torch.Tensor + 1/sum(e^(x - max(x))), where x=Q*K.T + shape [batch_size, num_heads, max_seqlen_q, 1], dtype float32 + rng_state: torch.Tensor, optional, if backend is not F16_max512_seqlen + state of the random number generator; + [seed, offset], dtype uint64 + """ + + check_cu_seqlens(cu_seqlens_q) + check_cu_seqlens(cu_seqlens_kv) + assert (cu_seqlens_q.numel() == cu_seqlens_kv.numel() + ), "cu_seqlens_q and cu_seqlens_kv must have the same length." + h = q.shape[-2] + d = q.shape[-1] + + if attn_scale is None: + attn_scale = 1.0 / math.sqrt(d) + + if attn_bias_type != "no_bias": + assert (attn_bias is not None + ), "attn_bias tensor cannot be None when attn_bias_type is not no_bias." + assert (attn_bias.shape == torch.Size([1, h, max_seqlen_q, max_seqlen_kv]) + ), "attn_bias tensor must be in [1, h, max_seqlen_q, max_seqlen_kv] shape." + assert (attn_bias.dtype == q.dtype + ), "attn_bias tensor must be in the same dtype as q and kv." + + assert (fused_attention_backend != FusedAttnBackend["No_Backend"] + ), "Fused attention does not support this input combination." 
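    # Hypothetical usage sketch (illustrative only; names such as `backend` are
    # placeholders): with equal-length sequences in the separate-tensor 'bshd' layout,
    # a caller could prepare the arguments along these lines, using the same TE_DType
    # mapping and tex.get_fused_attn_backend() query as the FusedAttention module above:
    #
    #   b, s, h, d = 2, 512, 16, 64
    #   q = torch.randn(b, s, h, d, dtype=torch.bfloat16, device="cuda")
    #   k, v = torch.randn_like(q), torch.randn_like(q)
    #   cu_seqlens = torch.arange(0, (b + 1) * s, step=s, dtype=torch.int32, device="cuda")
    #   out, aux_ctx_tensors = fused_attn_fwd(
    #       True, s, s, cu_seqlens, cu_seqlens, q, k, v,
    #       TE_DType[q.dtype], backend,
    #       attn_scale=1.0 / math.sqrt(d), dropout=0.1,
    #       qkv_layout="bshd_bshd_bshd", attn_mask_type="causal")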
+ + # BF16/FP16 fused attention API from fmha_v1 apex + if fused_attention_backend == FusedAttnBackend["F16_max512_seqlen"]: + rng_elts_per_thread = (max_seqlen_q * max_seqlen_kv + + BACKEND_F16m512_FP8_THREADS_PER_CTA - 1)//BACKEND_F16m512_FP8_THREADS_PER_CTA + + # BF16/FP16 fused attention API from fmha_v2 + if fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]: + rng_elts_per_thread = BACKEND_F16arb_ELTS_PER_THREADS + + # FP8 fused attention API from fmha_v2 + if fused_attention_backend == FusedAttnBackend["FP8"]: + rng_elts_per_thread = (max_seqlen_q * max_seqlen_q + + BACKEND_F16m512_FP8_THREADS_PER_CTA - 1)//BACKEND_F16m512_FP8_THREADS_PER_CTA + + # execute kernel + output_tensors = tex.fused_attn_fwd( + max_seqlen_q, max_seqlen_kv, is_training, attn_scale, dropout, fast_zero_fill, + QKVLayout[qkv_layout], AttnBiasType[attn_bias_type], AttnMaskType[attn_mask_type], + cu_seqlens_q, cu_seqlens_kv, q, k, v, qkv_dtype, + d_scale_qkv, q_scale_s, q_scale_o, amax_s, amax_o, + attn_bias, rng_gen, rng_elts_per_thread, + ) + + # out, aux_ctx_tensors + return output_tensors[0], output_tensors[1:] + + +def fused_attn_bwd( + max_seqlen_q: int, + max_seqlen_kv: int, + cu_seqlens_q: torch.Tensor, + cu_seqlens_kv: torch.Tensor, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + o: torch.Tensor, + d_o: torch.Tensor, + qkv_dtype: tex.DType, + aux_ctx_tensors: List[torch.Tensor], + fused_attention_backend: tex.NVTE_Fused_Attn_Backend, + d_scale_qkv: torch.Tensor = None, + d_scale_s: torch.Tensor = None, + d_scale_o: torch.Tensor = None, + d_scale_do: torch.Tensor = None, + q_scale_s: torch.Tensor = None, + q_scale_dp: torch.Tensor = None, + q_scale_dqkv: torch.Tensor = None, + amax_dp: torch.Tensor = None, + amax_dqkv: torch.Tensor = None, + attn_scale: float = None, + dropout: float = 0.0, + fast_zero_fill: bool = True, + qkv_layout: str = "sbh3d", + attn_bias_type: str = "no_bias", + attn_mask_type: str = "padding", +) -> Tuple[Union[torch.Tensor, None], ...]: + """Fused Attention BWD for packed KV input. 
+ + Parameters + ---------- + max_seqlen_q: int + max sequence length for Q, used for padding; may be larger than max(seqlens_q), + seqlens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1] + max_seqlen_kv: int + max sequence length for K and V, used for padding; + may be larger than max(seqlens_kv), + seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1] + cu_seqlens_q: torch.Tensor + cumulative sequence lengths for Q; shape [batch_size + 1] + cu_seqlens_kv: torch.Tensor + cumulative sequence lengths for K and V; shape [batch_size + 1] + q: torch.Tensor + input tensor Q; + shape [total_seqs_q, num_heads, head_dim], + where total_seqs_q = cu_seqlens_q[-1], + or [batch_size, seqlen_q, num_heads, head_dim], + or [seqlen_q, batch_size, num_heads, head_dim] + k: torch.Tensor + input tensor K; + shape [total_seqs_kv, num_heads, head_dim], + where total_seqs_kv = cu_seqlens_kv[-1], + or [batch_size, seqlen_kv, num_heads, head_dim], + or [seqlen_kv, batch_size, num_heads, head_dim] + v: torch.Tensor + input tensor V; + shape [total_seqs_kv, num_heads, head_dim], + where total_seqs_kv = cu_seqlens_kv[-1], + or [batch_size, seqlen_kv, num_heads, head_dim], + or [seqlen_kv, batch_size, num_heads, head_dim] + o: torch.Tensor + input tensor O (output of forward); same data type as Q, K and V; + same shape as Q + d_o: torch.Tensor + input tensor dO (gradient of O); same data type as Q, K and V; + same shape as Q + qkv_dtype: tex.DType + data type of Q, K and V; in tex.DType, not torch.dtype + aux_ctx_tensors: List[torch.Tensor] + auxiliary output tensors of the forward pass when its is_training is True, + e.g. aux_ctx_tensors = [M, ZInv, rng_state] + fused_attention_backend: tex.NVTE_Fused_Attn_Backend + please see FusedAttention module for details on supported backends. + d_scale_qkv: torch.Tensor, default = None + input tensor for the dequantization of Q, K and V in FP8 computations + d_scale_s: torch.Tensor, default = None + input tensor for the dequantization of S in FP8 computations, S = Softmax(Q * K.T) + d_scale_o: torch.Tensor, default = None + input tensor for the dequantization of O in FP8 computations + d_scale_do: torch.Tensor, default = None + input tensor for the dequantization of dO in FP8 computations + q_scale_s: torch.Tensor, default = None + input tensor for the quantization of S in FP8 computations + q_scale_dp: torch.Tensor, default = None + input tensor for the quantization of dP in FP8 computations, P = Q * K.T + q_scale_dqkv: torch.Tensor, default = None + input tensor for the quantization of dQ, dK and dV in FP8 computations + amax_dp: torch.Tensor, default = None + output tensor, amax of dP, used by the next iteration in FP8 computations, + P = Q * K.T + amax_dqkv: torch.Tensor, default = None + output tensor, amax of dQ, dK and dV, used by the next iteration in FP8 computations + attn_scale: float, default = None + if not None, use attn_scale as the attention scale for Q*K.T BMM; + if None, use 1.0/sqrt(head_dim) as the default + dropout: float, default = 0.0 + dropout probability, 0.0 means no dropout, 1.0 means no output; + dropout must be 0.0 if is_training is False + fast_zero_fill: bool, default = True + if True, initializes the output tensor O to zero using the fast filling method; + if False, uses PyTorch's .fill_() method + qkv_layout: str, default = "sbh3d" + layout of Q, K and V; + {"sb3hd", "sbh3d", "sbhd_sb2hd", "sbhd_sbh2d", "sbhd_sbhd_sbhd", + "bs3hd", "bsh3d", "bshd_bs2hd", "bshd_bsh2d", "bshd_bshd_bshd", + "t3hd", "th3d", "thd_t2hd", "thd_th2d", "thd_thd_thd"} + 
attn_bias_type: str, default = "no_bias" + type of the bias; {"no_bias", "pre_scale_bias", "post_scale_bias"} + attn_mask_type: str, default = "padding" + type of the attention mask; {"padding", "causal", "no_mask"} + + Returns + ---------- + d_q: torch.Tensor + gradient tensor of Q; same data type and shape as Q + d_k: torch.Tensor + gradient tensor of K; same data type and shape as K + d_v: torch.Tensor + gradient tensor of V; same data type and shape as V + d_bias: torch.Tensor, optional + gradient tensor of Bias when attn_bias_type is "pre_scale_bias" + or "post_scale_bias"; same data type and shape as Bias + """ + + check_cu_seqlens(cu_seqlens_q) + check_cu_seqlens(cu_seqlens_kv) + assert (cu_seqlens_q.numel() == cu_seqlens_kv.numel() + ), "cu_seqlens_q and cu_seqlens_kv must have the same length." + b = cu_seqlens_q.numel() - 1 + h = q.shape[-2] + d = q.shape[-1] + + if attn_scale is None: + attn_scale = 1.0 / math.sqrt(d) + + assert (fused_attention_backend != FusedAttnBackend["No_Backend"] + ), "Fused attention does not support this input combination." + + if fused_attention_backend != FusedAttnBackend["F16_max512_seqlen"]: + assert (len(aux_ctx_tensors) >= 1 + ), "aux_ctx_tensors must contain rng_state as its last element." + rng_state = aux_ctx_tensors[-1] + check_rng_state(rng_state) + + if fused_attention_backend == FusedAttnBackend["FP8"]: + assert (d_scale_qkv is not None), "d_scale_qkv is required for FP8 fused attention." + assert (d_scale_s is not None), "d_scale_s is required for FP8 fused attention." + assert (d_scale_o is not None), "d_scale_o is required for FP8 fused attention." + assert (d_scale_do is not None), "d_scale_do is required for FP8 fused attention." + assert (q_scale_s is not None), "q_scale_s is required for FP8 fused attention." + assert (q_scale_dp is not None), "q_scale_dp is required for FP8 fused attention." + assert (q_scale_dqkv is not None), "q_scale_dqkv is required for FP8 fused attention." + assert (amax_dp is not None), "amax_dp is required for FP8 fused attention." + assert (amax_dqkv is not None), "amax_dqkv is required for FP8 fused attention." + assert (len(aux_ctx_tensors) == 3 + ), "aux_ctx_tensors is required to be [M, ZInv, rng_state] for FP8 fused attention." 
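    # Hypothetical call sequence (illustrative only; `max_sq`, `cu_q`, `backend`, etc.
    # are placeholders): with the FP8 backend the forward returns
    # aux_ctx_tensors = [M, ZInv, rng_state] (see fused_attn_fwd above), and those
    # tensors are threaded through to this function unchanged, e.g.
    #
    #   out, aux_ctx_tensors = fused_attn_fwd(True, max_sq, max_skv, cu_q, cu_kv,
    #                                         q, k, v, qkv_dtype, backend, ...)
    #   dq, dk, dv, *rest = fused_attn_bwd(max_sq, max_skv, cu_q, cu_kv,
    #                                      q, k, v, out, d_out, qkv_dtype,
    #                                      aux_ctx_tensors, backend, ...)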
+ check_scalar(d_scale_qkv) + check_scalar(d_scale_s) + check_scalar(d_scale_o) + check_scalar(d_scale_do) + check_scalar(q_scale_s) + check_scalar(q_scale_dp) + check_scalar(q_scale_dqkv) + check_scalar(amax_dp) + check_scalar(amax_dqkv) + m, z_inv = aux_ctx_tensors[:2] + check_stats(m, b, h, max_seqlen_q) + check_stats(z_inv, b, h, max_seqlen_q) + + # execute kernel + output_tensors = tex.fused_attn_bwd( + max_seqlen_q, max_seqlen_kv, attn_scale, dropout, fast_zero_fill, + QKVLayout[qkv_layout], AttnBiasType[attn_bias_type], AttnMaskType[attn_mask_type], + cu_seqlens_q, cu_seqlens_kv, q, k, v, o, d_o, qkv_dtype, aux_ctx_tensors, + d_scale_qkv, d_scale_s, d_scale_o, d_scale_do, + q_scale_s, q_scale_dp, q_scale_dqkv, amax_dp, amax_dqkv, + ) + + return tuple(output_tensors) diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index d06906b5a2..274a523ec0 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -106,6 +106,52 @@ std::vector fused_attn_bwd_kvpacked( c10::optional amax_dP, c10::optional amax_dQKV); +std::vector fused_attn_fwd( + size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, + float attn_scale, float p_dropout, bool set_zero, + NVTE_QKV_Layout qkv_layout, + NVTE_Bias_Type bias_type, + NVTE_Mask_Type attn_mask_type, + const at::Tensor cu_seqlens_q, + const at::Tensor cu_seqlens_kv, + const at::Tensor Q, + const at::Tensor K, + const at::Tensor V, + const transformer_engine::DType qkv_type, + const c10::optional descale_QKV, + const c10::optional scale_S, + const c10::optional scale_O, + c10::optional amax_S, + c10::optional amax_O, + const c10::optional Bias, + const c10::optional rng_gen, + size_t rng_elts_per_thread); + +std::vector fused_attn_bwd( + size_t max_seqlen_q, size_t max_seqlen_kv, + float attn_scale, float p_dropout, bool set_zero, + NVTE_QKV_Layout qkv_layout, + NVTE_Bias_Type bias_type, + NVTE_Mask_Type attn_mask_type, + const at::Tensor cu_seqlens_q, + const at::Tensor cu_seqlens_kv, + const at::Tensor Q, + const at::Tensor K, + const at::Tensor V, + const at::Tensor O, + const at::Tensor dO, + const transformer_engine::DType qkv_type, + const std::vector Aux_CTX_Tensors, + const c10::optional descale_QKV, + const c10::optional descale_S, + const c10::optional descale_O, + const c10::optional descale_dO, + const c10::optional scale_S, + const c10::optional scale_dP, + const c10::optional scale_dQKV, + c10::optional amax_dP, + c10::optional amax_dQKV); + at::Tensor fa_prepare_fwd(at::Tensor qkvi); at::Tensor fa_prepare_bwd(at::Tensor q, at::Tensor k, at::Tensor v); diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cu b/transformer_engine/pytorch/csrc/extensions/attention.cu index d2b91cc194..4f2d958f13 100644 --- a/transformer_engine/pytorch/csrc/extensions/attention.cu +++ b/transformer_engine/pytorch/csrc/extensions/attention.cu @@ -717,6 +717,444 @@ std::vector fused_attn_bwd_kvpacked( return {dQ, dKV, dBias}; } +// fused attention FWD with separate Q, K and V tensors +std::vector fused_attn_fwd( + size_t max_seqlen_q, size_t max_seqlen_kv, + bool is_training, float attn_scale, float p_dropout, bool set_zero, + NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, + const at::Tensor cu_seqlens_q, + const at::Tensor cu_seqlens_kv, + const at::Tensor Q, + const at::Tensor K, + const at::Tensor V, + const transformer_engine::DType qkv_type, + const c10::optional descale_QKV, + const c10::optional 
scale_S, + const c10::optional scale_O, + c10::optional amax_S, + c10::optional amax_O, + const c10::optional Bias, + const c10::optional rng_gen, + size_t rng_elts_per_thread) { + using namespace transformer_engine; + + auto q_sizes = Q.sizes().vec(); + std::vector q_shape{q_sizes.begin(), q_sizes.end()}; + auto k_sizes = K.sizes().vec(); + std::vector k_shape{k_sizes.begin(), k_sizes.end()}; + auto v_sizes = V.sizes().vec(); + std::vector v_shape{v_sizes.begin(), v_sizes.end()}; + + // create output tensor O + auto O = torch::empty_like(Q); + + // construct NVTE tensors + TensorWrapper te_Q, te_K, te_V, te_S, te_O, te_Bias; + TensorWrapper te_cu_seqlens_q, te_cu_seqlens_kv; + if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) { + // FP8 + auto h = Q.size(-2); + auto d = Q.size(-1); + if (set_zero && ((h * d) % block_size == 0)) { + mha_fill(O, cu_seqlens_q.index({torch::indexing::Slice(-1, torch::indexing::None)})); + } else { + O.fill_(0); + } + if ((!descale_QKV.has_value()) || (!scale_S.has_value()) || (!scale_O.has_value()) + || (!amax_S.has_value()) || (!amax_O.has_value())) { + std::string err_tensors = "descale_QKV, scale_S, scale_O, amax_S and amax_O"; + NVTE_ERROR(err_tensors + std::string("are required for FP8 operation. \n")); + } + te_Q = makeTransformerEngineTensor(Q.data_ptr(), q_shape, + qkv_type, nullptr, nullptr, descale_QKV.value().data_ptr()); + te_K = makeTransformerEngineTensor(K.data_ptr(), k_shape, + qkv_type, nullptr, nullptr, descale_QKV.value().data_ptr()); + te_V = makeTransformerEngineTensor(V.data_ptr(), v_shape, + qkv_type, nullptr, nullptr, descale_QKV.value().data_ptr()); + at::Tensor descale_S = torch::empty_like(scale_S.value()); + te_S = makeTransformerEngineTensor(nullptr, {0}, + DType::kFloat32, amax_S.value().data_ptr(), + scale_S.value().data_ptr(), descale_S.data_ptr()); + te_O = makeTransformerEngineTensor(O.data_ptr(), q_shape, + qkv_type, amax_O.value().data_ptr(), scale_O.value().data_ptr(), nullptr); + } else if (qkv_type == DType::kBFloat16 || qkv_type == DType::kFloat16) { + // BF16 or FP16 + te_Q = makeTransformerEngineTensor(Q.data_ptr(), q_shape, + qkv_type, nullptr, nullptr, nullptr); + te_K = makeTransformerEngineTensor(K.data_ptr(), k_shape, + qkv_type, nullptr, nullptr, nullptr); + te_V = makeTransformerEngineTensor(V.data_ptr(), v_shape, + qkv_type, nullptr, nullptr, nullptr); + te_S = makeTransformerEngineTensor(nullptr, {0}, + DType::kFloat32, nullptr, nullptr, nullptr); + te_O = makeTransformerEngineTensor(O.data_ptr(), q_shape, + qkv_type, nullptr, nullptr, nullptr); + } else { + NVTE_ERROR("Fused attention only supports FP8 and BF16/FP16 data types. 
\n"); + } + if ((bias_type != NVTE_NO_BIAS) && (Bias.has_value())) { + auto bias_sizes = Bias.value().sizes().vec(); + std::vector bias_shape{bias_sizes.begin(), bias_sizes.end()}; + te_Bias = makeTransformerEngineTensor(Bias.value().data_ptr(), bias_shape, + DType::kFloat32, nullptr, nullptr, nullptr); + } + auto cu_seqlens_q_sizes = cu_seqlens_q.sizes().vec(); + std::vector cu_seqlens_q_shape{cu_seqlens_q_sizes.begin(), cu_seqlens_q_sizes.end()}; + auto cu_seqlens_kv_sizes = cu_seqlens_kv.sizes().vec(); + std::vector cu_seqlens_kv_shape{cu_seqlens_kv_sizes.begin(), cu_seqlens_kv_sizes.end()}; + te_cu_seqlens_q = makeTransformerEngineTensor(cu_seqlens_q.data_ptr(), cu_seqlens_q_shape, + DType::kInt32, nullptr, nullptr, nullptr); + te_cu_seqlens_kv = makeTransformerEngineTensor(cu_seqlens_kv.data_ptr(), cu_seqlens_kv_shape, + DType::kInt32, nullptr, nullptr, nullptr); + + // extract rng seed and offset + auto gen = at::get_generator_or_default( + rng_gen, at::cuda::detail::getDefaultCUDAGenerator()); + at::PhiloxCudaState philox_args = init_philox_state(gen, rng_elts_per_thread); + auto options = torch::TensorOptions().dtype(torch::kInt64).device(torch::kCUDA); + auto rng_state = torch::empty({2}, options); + unpack<<<1, 1, 0, at::cuda::getCurrentCUDAStream()>>>( + philox_args, static_cast(rng_state.data_ptr())); + auto te_rng_state = makeTransformerEngineTensor(rng_state); + + // create auxiliary output tensors + NVTETensorPack nvte_aux_tensor_pack; + nvte_tensor_pack_create(&nvte_aux_tensor_pack); + + // create workspace + TensorWrapper workspace; + + // populate tensors with appropriate shapes and dtypes + nvte_fused_attn_fwd( + te_Q.data(), + te_K.data(), + te_V.data(), + te_Bias.data(), + te_S.data(), + te_O.data(), + &nvte_aux_tensor_pack, + te_cu_seqlens_q.data(), + te_cu_seqlens_kv.data(), + te_rng_state.data(), + max_seqlen_q, max_seqlen_kv, + is_training, attn_scale, p_dropout, + qkv_layout, bias_type, attn_mask_type, + workspace.data(), + at::cuda::getCurrentCUDAStream()); + + // allocate memory for workspace and auxiliary output tensors + auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype()); + workspace = makeTransformerEngineTensor( + workspace_data.data_ptr(), + workspace.shape(), workspace.dtype()); + + // output_tensors = [O, nvte_aux_tensor_pack.tensors] + std::vector output_tensors; + output_tensors.push_back(O); + for (size_t i = 0; i < nvte_aux_tensor_pack.size; ++i) { + auto tensor = reinterpret_cast(nvte_aux_tensor_pack.tensors[i]); + // allocate memory for nvte_aux_tensor_pack.tensors + at::Tensor output_tensor; + if (nvte_aux_tensor_pack.size >= 2) { + output_tensor = (i < nvte_aux_tensor_pack.size-1) + ? 
allocateSpace(tensor->data.shape, tensor->data.dtype, false) : rng_state; + } else { + output_tensor = allocateSpace(tensor->data.shape, tensor->data.dtype, false); + } + output_tensors.push_back(output_tensor); + tensor->data.dptr = output_tensor.data_ptr(); + } + + // execute the kernel + nvte_fused_attn_fwd( + te_Q.data(), + te_K.data(), + te_V.data(), + te_Bias.data(), + te_S.data(), + te_O.data(), + &nvte_aux_tensor_pack, + te_cu_seqlens_q.data(), + te_cu_seqlens_kv.data(), + te_rng_state.data(), + max_seqlen_q, max_seqlen_kv, + is_training, attn_scale, p_dropout, + qkv_layout, bias_type, attn_mask_type, + workspace.data(), + at::cuda::getCurrentCUDAStream()); + + // destroy tensor wrappers, but not allocated memory + nvte_tensor_pack_destroy(&nvte_aux_tensor_pack); + + // if training, [O, softmax-related tensors, rng_state]; if inference, [O] + return output_tensors; +} + +// fused attention BWD with separate Q, K and V +std::vector fused_attn_bwd( + size_t max_seqlen_q, size_t max_seqlen_kv, + float attn_scale, float p_dropout, bool set_zero, + NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, + const at::Tensor cu_seqlens_q, + const at::Tensor cu_seqlens_kv, + const at::Tensor Q, + const at::Tensor K, + const at::Tensor V, + const at::Tensor O, + const at::Tensor dO, + const transformer_engine::DType qkv_type, + const std::vector Aux_CTX_Tensors, + const c10::optional descale_QKV, + const c10::optional descale_S, + const c10::optional descale_O, + const c10::optional descale_dO, + const c10::optional scale_S, + const c10::optional scale_dP, + const c10::optional scale_dQKV, + c10::optional amax_dP, + c10::optional amax_dQKV) { + using namespace transformer_engine; + + auto q_sizes = Q.sizes().vec(); + std::vector q_shape{q_sizes.begin(), q_sizes.end()}; + auto k_sizes = K.sizes().vec(); + std::vector k_shape{k_sizes.begin(), k_sizes.end()}; + auto v_sizes = V.sizes().vec(); + std::vector v_shape{v_sizes.begin(), v_sizes.end()}; + auto options = torch::TensorOptions().dtype(GetATenDType(qkv_type)).device(torch::kCUDA); + + at::Tensor dQ; + at::Tensor dK; + at::Tensor dV; + at::Tensor dQKV, dKV; + NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); + std::vector tmp_shape; + switch (layout_group) { + case NVTE_QKV_Layout_Group::NVTE_3HD: + tmp_shape = std::vector{q_sizes.begin(), q_sizes.end()}; + tmp_shape.insert(tmp_shape.begin() + tmp_shape.size() - 2, int64_t(3)); + dQKV = torch::empty(c10::IntArrayRef(tmp_shape), options); + dQ = dQKV.index({"...", torch::indexing::Slice(0, 1, 1), + torch::indexing::Slice(0, torch::indexing::None, 1), + torch::indexing::Slice(0, torch::indexing::None, 1)}).squeeze(tmp_shape.size() - 3); + dK = dQKV.index({"...", torch::indexing::Slice(1, 2, 1), + torch::indexing::Slice(0, torch::indexing::None, 1), + torch::indexing::Slice(0, torch::indexing::None, 1)}).squeeze(tmp_shape.size() - 3); + dV = dQKV.index({"...", torch::indexing::Slice(2, torch::indexing::None, 1), + torch::indexing::Slice(0, torch::indexing::None, 1), + torch::indexing::Slice(0, torch::indexing::None, 1)}).squeeze(tmp_shape.size() - 3); + break; + case NVTE_QKV_Layout_Group::NVTE_H3D: + tmp_shape = std::vector{q_sizes.begin(), q_sizes.end()}; + tmp_shape.insert(tmp_shape.begin() + tmp_shape.size() - 1, int64_t(3)); + dQKV = torch::empty(c10::IntArrayRef(tmp_shape), options); + dQ = dQKV.index({"...", torch::indexing::Slice(0, 1, 1), + torch::indexing::Slice(0, torch::indexing::None, 1)}).squeeze(tmp_shape.size() - 2); + dK 
= dQKV.index({"...", torch::indexing::Slice(1, 2, 1), + torch::indexing::Slice(0, torch::indexing::None, 1)}).squeeze(tmp_shape.size() - 2); + dV = dQKV.index({"...", torch::indexing::Slice(2, torch::indexing::None, 1), + torch::indexing::Slice(0, torch::indexing::None, 1)}).squeeze(tmp_shape.size() - 2); + break; + case NVTE_QKV_Layout_Group::NVTE_HD_2HD: + dQ = torch::empty_like(Q); + tmp_shape = std::vector{k_sizes.begin(), k_sizes.end()}; + tmp_shape.insert(tmp_shape.begin() + tmp_shape.size() - 2, int64_t(2)); + dKV = torch::empty(c10::IntArrayRef(tmp_shape), options); + dK = dKV.index({"...", torch::indexing::Slice(0, 1, 1), + torch::indexing::Slice(0, torch::indexing::None, 1), + torch::indexing::Slice(0, torch::indexing::None, 1)}).squeeze(tmp_shape.size() - 3); + dV = dKV.index({"...", torch::indexing::Slice(1, torch::indexing::None, 1), + torch::indexing::Slice(0, torch::indexing::None, 1), + torch::indexing::Slice(0, torch::indexing::None, 1)}).squeeze(tmp_shape.size() - 3); + break; + case NVTE_QKV_Layout_Group::NVTE_HD_H2D: + dQ = torch::empty_like(Q); + tmp_shape = std::vector{k_sizes.begin(), k_sizes.end()}; + tmp_shape.insert(tmp_shape.begin() + tmp_shape.size() - 1, int64_t(2)); + dKV = torch::empty(c10::IntArrayRef(tmp_shape), options); + dK = dKV.index({"...", torch::indexing::Slice(0, 1, 1), + torch::indexing::Slice(0, torch::indexing::None, 1)}).squeeze(tmp_shape.size() - 2); + dV = dKV.index({"...", torch::indexing::Slice(1, torch::indexing::None, 1), + torch::indexing::Slice(0, torch::indexing::None, 1)}).squeeze(tmp_shape.size() - 2); + break; + case NVTE_QKV_Layout_Group::NVTE_HD_HD_HD: + dQ = torch::empty_like(Q); + dK = torch::empty_like(K); + dV = torch::empty_like(V); + break; + default: + NVTE_ERROR("QKV layout not supported!"); + } + + at::Tensor dBias; + TensorWrapper te_dBias; + if (bias_type != NVTE_NO_BIAS) { + dBias = torch::empty({1, static_cast(Q.size(-2)), + static_cast(max_seqlen_q), + static_cast(max_seqlen_kv)}, options); + te_dBias = makeTransformerEngineTensor(dBias); + } + + // construct NVTE tensors + TensorWrapper te_Q, te_K, te_V, te_O, te_dO, te_S, te_dP, te_dQ, te_dK, te_dV; + if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) { + // FP8 + auto h_q = Q.size(-2); + auto h_kv = K.size(-2); + auto d = Q.size(-1); + if (set_zero + && ((h_q * d) % block_size == 0) + && ((h_kv * d) % block_size == 0) + && dQ.is_contiguous() + && dK.is_contiguous() + && dV.is_contiguous()) { + mha_fill(dQ, cu_seqlens_q.index({torch::indexing::Slice(-1, torch::indexing::None)})); + mha_fill(dK, cu_seqlens_kv.index({torch::indexing::Slice(-1, torch::indexing::None)})); + mha_fill(dV, cu_seqlens_kv.index({torch::indexing::Slice(-1, torch::indexing::None)})); + } else { + dQ.fill_(0); + dK.fill_(0); + dV.fill_(0); + } + if ((!descale_QKV.has_value()) || (!descale_S.has_value()) + || (!descale_O.has_value()) || (!descale_dO.has_value()) + || (!scale_S.has_value()) || (!scale_dP.has_value()) + || (!scale_dQKV.has_value()) + || (!amax_dP.has_value()) || (!amax_dQKV.has_value())) { + std::string err_tensors = "descale_QKV, descale_S, descale_O, scale_S, scale_dP, "; + err_tensors = err_tensors + std::string("scale_dQKV, amax_dP and amax_dQKV"); + NVTE_ERROR(err_tensors + std::string("are required for FP8 operation. 
\n")); + } + te_Q = makeTransformerEngineTensor(Q.data_ptr(), q_shape, + qkv_type, nullptr, nullptr, descale_QKV.value().data_ptr()); + te_K = makeTransformerEngineTensor(K.data_ptr(), k_shape, + qkv_type, nullptr, nullptr, descale_QKV.value().data_ptr()); + te_V = makeTransformerEngineTensor(V.data_ptr(), v_shape, + qkv_type, nullptr, nullptr, descale_QKV.value().data_ptr()); + te_O = makeTransformerEngineTensor(O.data_ptr(), q_shape, + qkv_type, nullptr, nullptr, descale_O.value().data_ptr()); + te_dO = makeTransformerEngineTensor(dO.data_ptr(), q_shape, + qkv_type, nullptr, nullptr, descale_dO.value().data_ptr()); + te_S = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32, nullptr, + scale_S.value().data_ptr(), descale_S.value().data_ptr()); + at::Tensor descale_dP = torch::empty_like(scale_dP.value()); + te_dP = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32, + amax_dP.value().data_ptr(), scale_dP.value().data_ptr(), + descale_dP.data_ptr()); + te_dQ = makeTransformerEngineTensor(dQ.data_ptr(), q_shape, qkv_type, + amax_dQKV.value().data_ptr(), scale_dQKV.value().data_ptr(), nullptr); + te_dK = makeTransformerEngineTensor(dK.data_ptr(), k_shape, qkv_type, + amax_dQKV.value().data_ptr(), scale_dQKV.value().data_ptr(), nullptr); + te_dV = makeTransformerEngineTensor(dV.data_ptr(), v_shape, qkv_type, + amax_dQKV.value().data_ptr(), scale_dQKV.value().data_ptr(), nullptr); + } else if (qkv_type == DType::kBFloat16 || qkv_type == DType::kFloat16) { + // BF16 or FP16 + te_Q = makeTransformerEngineTensor(Q.data_ptr(), q_shape, + qkv_type, nullptr, nullptr, nullptr); + te_K = makeTransformerEngineTensor(K.data_ptr(), k_shape, + qkv_type, nullptr, nullptr, nullptr); + te_V = makeTransformerEngineTensor(V.data_ptr(), v_shape, + qkv_type, nullptr, nullptr, nullptr); + te_O = makeTransformerEngineTensor(O.data_ptr(), q_shape, + qkv_type, nullptr, nullptr, nullptr); + te_dO = makeTransformerEngineTensor(dO.data_ptr(), q_shape, + qkv_type, nullptr, nullptr, nullptr); + te_S = makeTransformerEngineTensor(nullptr, {0}, + DType::kFloat32, nullptr, nullptr, nullptr); + te_dP = makeTransformerEngineTensor(nullptr, {0}, + DType::kFloat32, nullptr, nullptr, nullptr); + te_dQ = makeTransformerEngineTensor(dQ.data_ptr(), q_shape, + qkv_type, nullptr, nullptr, nullptr); + te_dK = makeTransformerEngineTensor(dK.data_ptr(), k_shape, + qkv_type, nullptr, nullptr, nullptr); + te_dV = makeTransformerEngineTensor(dV.data_ptr(), v_shape, + qkv_type, nullptr, nullptr, nullptr); + } else { + NVTE_ERROR("Fused attention only supports FP8 and BF16/FP16 data types. 
\n"); + } + + // create cu_seqlens tensorwrappers + auto cu_seqlens_q_sizes = cu_seqlens_q.sizes().vec(); + std::vector cu_seqlens_q_shape{cu_seqlens_q_sizes.begin(), cu_seqlens_q_sizes.end()}; + auto cu_seqlens_kv_sizes = cu_seqlens_kv.sizes().vec(); + std::vector cu_seqlens_kv_shape{cu_seqlens_kv_sizes.begin(), cu_seqlens_kv_sizes.end()}; + TensorWrapper te_cu_seqlens_q, te_cu_seqlens_kv, te_qkvso_strides; + te_cu_seqlens_q = makeTransformerEngineTensor(cu_seqlens_q.data_ptr(), cu_seqlens_q_shape, + DType::kInt32, nullptr, nullptr, nullptr); + te_cu_seqlens_kv = makeTransformerEngineTensor(cu_seqlens_kv.data_ptr(), cu_seqlens_kv_shape, + DType::kInt32, nullptr, nullptr, nullptr); + + // convert auxiliary tensors from forward to NVTETensors + NVTETensorPack nvte_aux_tensor_pack; + nvte_tensor_pack_create(&nvte_aux_tensor_pack); + nvte_aux_tensor_pack.size = Aux_CTX_Tensors.size(); + for (size_t i = 0; i < nvte_aux_tensor_pack.size; ++i) { + auto tensor = reinterpret_cast(nvte_aux_tensor_pack.tensors[i]); + tensor->data.dptr = Aux_CTX_Tensors[i].data_ptr(); + std::vector tmp(Aux_CTX_Tensors[i].sizes().vec()); + tensor->data.shape = std::vector(tmp.begin(), tmp.end()); + tensor->data.dtype = GetTransformerEngineDType(Aux_CTX_Tensors[i].scalar_type()); + } + + // create workspace + TensorWrapper workspace; + + // populate tensors with appropriate shapes and dtypes + nvte_fused_attn_bwd( + te_Q.data(), + te_K.data(), + te_V.data(), + te_O.data(), + te_dO.data(), + te_S.data(), + te_dP.data(), + &nvte_aux_tensor_pack, + te_dQ.data(), + te_dK.data(), + te_dV.data(), + te_dBias.data(), + te_cu_seqlens_q.data(), + te_cu_seqlens_kv.data(), + max_seqlen_q, max_seqlen_kv, + attn_scale, p_dropout, + qkv_layout, bias_type, attn_mask_type, + workspace.data(), + at::cuda::getCurrentCUDAStream()); + + // allocate memory for workspace + auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype()); + workspace = makeTransformerEngineTensor( + workspace_data.data_ptr(), + workspace.shape(), workspace.dtype()); + + // execute kernel + nvte_fused_attn_bwd( + te_Q.data(), + te_K.data(), + te_V.data(), + te_O.data(), + te_dO.data(), + te_S.data(), + te_dP.data(), + &nvte_aux_tensor_pack, + te_dQ.data(), + te_dK.data(), + te_dV.data(), + te_dBias.data(), + te_cu_seqlens_q.data(), + te_cu_seqlens_kv.data(), + max_seqlen_q, max_seqlen_kv, + attn_scale, p_dropout, + qkv_layout, bias_type, attn_mask_type, + workspace.data(), + at::cuda::getCurrentCUDAStream()); + + // destroy tensor wrappers + nvte_tensor_pack_destroy(&nvte_aux_tensor_pack); + + return {dQ, dK, dV, dBias}; +} + namespace flash_attention { constexpr int warp_size = 32; diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index 93196962e0..abc15022b0 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -56,6 +56,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { "Fused Attention FP8/BF16/FP16 FWD with packed KV"); m.def("fused_attn_bwd_kvpacked", &fused_attn_bwd_kvpacked, "Fused Attention FP8/BF16/FP16 BWD with packed KV"); + m.def("fused_attn_fwd", &fused_attn_fwd, + "Fused Attention FP8/BF16/FP16 FWD with separate Q, K and V"); + m.def("fused_attn_bwd", &fused_attn_bwd, + "Fused Attention FP8/BF16/FP16 BWD with separate Q, K and V"); m.def("fp8_transpose", &fp8_transpose, "Transpose with FP8 I/O"); m.def("gelu", &gelu, "GeLU with FP8 output"); m.def("relu", &relu, "ReLU with FP8 output"); @@ 
-148,7 +152,22 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { py::enum_(m, "NVTE_QKV_Layout") .value("NVTE_NOT_INTERLEAVED", NVTE_QKV_Layout::NVTE_NOT_INTERLEAVED) .value("NVTE_QKV_INTERLEAVED", NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED) - .value("NVTE_KV_INTERLEAVED", NVTE_QKV_Layout::NVTE_KV_INTERLEAVED); + .value("NVTE_KV_INTERLEAVED", NVTE_QKV_Layout::NVTE_KV_INTERLEAVED) + .value("NVTE_SB3HD", NVTE_QKV_Layout::NVTE_SB3HD) + .value("NVTE_SBH3D", NVTE_QKV_Layout::NVTE_SBH3D) + .value("NVTE_SBHD_SB2HD", NVTE_QKV_Layout::NVTE_SBHD_SB2HD) + .value("NVTE_SBHD_SBH2D", NVTE_QKV_Layout::NVTE_SBHD_SBH2D) + .value("NVTE_SBHD_SBHD_SBHD", NVTE_QKV_Layout::NVTE_SBHD_SBHD_SBHD) + .value("NVTE_BS3HD", NVTE_QKV_Layout::NVTE_BS3HD) + .value("NVTE_BSH3D", NVTE_QKV_Layout::NVTE_BSH3D) + .value("NVTE_BSHD_BS2HD", NVTE_QKV_Layout::NVTE_BSHD_BS2HD) + .value("NVTE_BSHD_BSH2D", NVTE_QKV_Layout::NVTE_BSHD_BSH2D) + .value("NVTE_BSHD_BSHD_BSHD", NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD) + .value("NVTE_T3HD", NVTE_QKV_Layout::NVTE_T3HD) + .value("NVTE_TH3D", NVTE_QKV_Layout::NVTE_TH3D) + .value("NVTE_THD_T2HD", NVTE_QKV_Layout::NVTE_THD_T2HD) + .value("NVTE_THD_TH2D", NVTE_QKV_Layout::NVTE_THD_TH2D) + .value("NVTE_THD_THD_THD", NVTE_QKV_Layout::NVTE_THD_THD_THD); py::enum_(m, "NVTE_Fused_Attn_Backend") .value("NVTE_F16_max512_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) diff --git a/transformer_engine/pytorch/transformer.py b/transformer_engine/pytorch/transformer.py index d4046ec7da..8ac14758e7 100644 --- a/transformer_engine/pytorch/transformer.py +++ b/transformer_engine/pytorch/transformer.py @@ -74,6 +74,7 @@ class TransformerLayer(torch.nn.Module): are deprecated and will be fully removed in future releases. .. note:: + Argument :attr:`attention_mask` will be ignored in the `forward` call when :attr:`self_attn_mask_type` is set to `"causal"`. 
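(For reference, a minimal sketch of the behaviour described in the note above; the layer sizes and sequence length are illustrative and not taken from this patch.)

    import torch
    import transformer_engine.pytorch as te

    layer = te.TransformerLayer(
        hidden_size=1024,
        ffn_hidden_size=4096,
        num_attention_heads=16,
        self_attn_mask_type="causal",
    ).cuda()
    x = torch.randn(128, 2, 1024, device="cuda")   # [s, b, h]
    # attention_mask is ignored for the "causal" mask type; passing one has no effect.
    y = layer(x, attention_mask=None)              # output: [s, b, h]
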
@@ -624,5 +625,5 @@ def forward( if self.output_layernorm: output = self.layernorm(output) - # output: [b, s, h] + # output: [s, b, h] return output From f575ff935c54307fffcdf6b051f8eba105fb02e2 Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Tue, 26 Sep 2023 22:47:55 -0700 Subject: [PATCH 58/68] Add release to deprecation warnings (#447) Change deprecation warnings Signed-off-by: Kirthi Shankar Sivamani --- transformer_engine/jax/__init__.py | 30 ++++++++++++------- transformer_engine/pytorch/module/base.py | 4 +-- .../pytorch/module/layernorm_linear.py | 8 ++--- transformer_engine/pytorch/module/linear.py | 8 ++--- transformer_engine/pytorch/transformer.py | 4 +-- 5 files changed, 32 insertions(+), 22 deletions(-) diff --git a/transformer_engine/jax/__init__.py b/transformer_engine/jax/__init__.py index 0459402172..793e6c3f8b 100644 --- a/transformer_engine/jax/__init__.py +++ b/transformer_engine/jax/__init__.py @@ -10,29 +10,39 @@ extend_logical_axis_rules = deprecate_wrapper( flax.extend_logical_axis_rules, - "extend_logical_axis_rules is moving to transformer_engine.jax.flax module") + "extend_logical_axis_rules is moving to transformer_engine.jax.flax module" + " and will be fully removed in the next release (v1.0.0).") DenseGeneral = deprecate_wrapper(flax.DenseGeneral, - "DenseGeneral is moving to transformer_engine.jax.flax module") + "DenseGeneral is moving to transformer_engine.jax.flax module" + " and will be fully removed in the next release (v1.0.0).") LayerNorm = deprecate_wrapper(flax.LayerNorm, - "LayerNorm is moving to transformer_engine.jax.flax module") + "LayerNorm is moving to transformer_engine.jax.flax module" + " and will be fully removed in the next release (v1.0.0).") LayerNormDenseGeneral = deprecate_wrapper( flax.LayerNormDenseGeneral, - "LayerNormDenseGeneral is moving to transformer_engine.jax.flax module") + "LayerNormDenseGeneral is moving to transformer_engine.jax.flax module" + " and will be fully removed in the next release (v1.0.0).") LayerNormMLP = deprecate_wrapper(flax.LayerNormMLP, - "LayerNormMLP is moving to transformer_engine.jax.flax module") + "LayerNormMLP is moving to transformer_engine.jax.flax module" + " and will be fully removed in the next release (v1.0.0).") TransformerEngineBase = deprecate_wrapper( flax.TransformerEngineBase, - "TransformerEngineBase is moving to transformer_engine.jax.flax module") + "TransformerEngineBase is moving to transformer_engine.jax.flax module" + " and will be fully removed in the next release (v1.0.0).") MultiHeadAttention = deprecate_wrapper( - flax.MultiHeadAttention, "MultiHeadAttention is moving to transformer_engine.jax.flax module") + flax.MultiHeadAttention, "MultiHeadAttention is moving to transformer_engine.jax.flax module" + " and will be fully removed in the next release (v1.0.0).") RelativePositionBiases = deprecate_wrapper( flax.RelativePositionBiases, - "RelativePositionBiases is moving to transformer_engine.jax.flax module") + "RelativePositionBiases is moving to transformer_engine.jax.flax module" + " and will be fully removed in the next release (v1.0.0).") TransformerLayer = deprecate_wrapper( - flax.TransformerLayer, "TransformerLayer is moving to transformer_engine.jax.flax module") + flax.TransformerLayer, "TransformerLayer is moving to transformer_engine.jax.flax module" + " and will be fully removed in the next release (v1.0.0).") TransformerLayerType = deprecate_wrapper( flax.TransformerLayerType, - "TransformerLayerType is moving to transformer_engine.jax.flax 
module") + "TransformerLayerType is moving to transformer_engine.jax.flax module" + " and will be fully removed in the next release (v1.0.0).") __all__ = [ 'fp8_autocast', 'update_collections', 'update_fp8_metas', 'get_delayed_scaling', diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 50d7b9f2fb..8bb9d55f38 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -366,7 +366,7 @@ def set_extra_state(self, state: torch.Tensor) -> None: if isinstance(state, list): warnings.warn( "This checkpoint format is deprecated and will be" - "removed in a future release of Transformer Engine" + "removed in the next release (v1.0.0)." ) # Retrieve checkpointed items. @@ -412,7 +412,7 @@ def set_extra_state(self, state: torch.Tensor) -> None: else: warnings.warn( "This checkpoint format is deprecated and will be" - "removed in a future release of Transformer Engine" + "removed in the next release (v1.0.0)." ) # Load extra items. self.fp8_meta.update(state["extra_fp8_variables"]) diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 761b0abf6b..b7372f81fe 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -514,7 +514,7 @@ class LayerNormLinear(TransformerEngineBaseModule): .. warning:: Argument :attr:`skip_weight_param_allocation` is deprecated and will - be fully removed in future releases. + be fully removed in the next release (v1.0.0). Parameters ---------- @@ -622,7 +622,7 @@ def __init__( if skip_weight_param_allocation: warnings.warn( "Argument `skip_weight_param_allocation` is deprecated and" - "will be fully removed in future releases. It is ignored" + "will be fully removed in the next release (v1.0.0). It is ignored" "starting from v0.11.", category=DeprecationWarning, ) @@ -827,7 +827,7 @@ def forward( .. warning:: Arguments :attr:`weight` and :attr:`bias` are deprecated and will - be fully removed in future releases. + be fully removed in the next release (v1.0.0). Parameters ---------- @@ -851,7 +851,7 @@ def forward( if weight is not None or bias is not None: raise RuntimeError( "Arguments `weight` and `bias` are deprecated and " - "will be fully removed in future releases." + "will be fully removed in the next release (v1.0.0)." ) with self.prepare_forward(inp, is_first_microbatch) as inp: diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 45a163966b..98ca2015ed 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -448,7 +448,7 @@ class Linear(TransformerEngineBaseModule): .. warning:: Argument :attr:`skip_weight_param_allocation` is deprecated and will - be fully removed in future releases. + be fully removed in the next release (v1.0.0). Parameters ---------- @@ -535,7 +535,7 @@ def __init__( if skip_weight_param_allocation: warnings.warn( "Argument `skip_weight_param_allocation` is deprecated and" - "will be fully removed in future releases. It has ignored" + "will be fully removed in the next release (v1.0.0). It has ignored" "starting from v0.11.", category=DeprecationWarning, ) @@ -701,7 +701,7 @@ def forward( .. warning:: Arguments :attr:`weight` and :attr:`bias` are deprecated and will - be fully removed in future releases. + be fully removed in the next release (v1.0.0). 
Parameters ---------- @@ -725,7 +725,7 @@ def forward( if weight is not None or bias is not None: raise RuntimeError( "Arguments `weight` and `bias` are deprecated and " - "will be fully removed in future releases." + "will be fully removed in the next release (v1.0.0)." ) with self.prepare_forward(inp, is_first_microbatch) as inp: diff --git a/transformer_engine/pytorch/transformer.py b/transformer_engine/pytorch/transformer.py index 8ac14758e7..d8a1aa1ad2 100644 --- a/transformer_engine/pytorch/transformer.py +++ b/transformer_engine/pytorch/transformer.py @@ -71,7 +71,7 @@ class TransformerLayer(torch.nn.Module): .. warning:: Arguments :attr:`attention_softmax_in_fp32` and :attr:`apply_query_key_layer_scaling` - are deprecated and will be fully removed in future releases. + are deprecated and will be fully removed in the next release (v1.0.0). .. note:: @@ -247,7 +247,7 @@ def __init__( warnings.warn( "Arguments `attention_softmax_in_fp32` and `apply_query_key_layer_scaling`" - "are deprecated and will be fully removed in future releases.", + "are deprecated and will be fully removed in the next release (v1.0.0).", category=DeprecationWarning, ) From dfd29c48fe61e9fe419bb02710b53f064c39d1a3 Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Tue, 26 Sep 2023 22:48:09 -0700 Subject: [PATCH 59/68] Keep previous FA version (#450) Signed-off-by: Kirthi Shankar Sivamani --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index bcccd8208f..5959c2b941 100644 --- a/setup.py +++ b/setup.py @@ -290,7 +290,7 @@ def add_unique(l: List[str], vals: Union[str, List[str]]) -> None: # Framework-specific requirements if "pytorch" in frameworks(): - add_unique(install_reqs, ["torch", "flash-attn>=1.0.6, <=2.2.1"]) + add_unique(install_reqs, ["torch", "flash-attn>=1.0.6, <=2.0.4"]) add_unique(test_reqs, ["numpy", "onnxruntime", "torchvision"]) if "jax" in frameworks(): if not found_pybind11(): From 02a4ccceb02309ae1544562edd689b2edcc89696 Mon Sep 17 00:00:00 2001 From: vasunvidia <108759426+vasunvidia@users.noreply.github.com> Date: Thu, 5 Oct 2023 13:20:16 -0700 Subject: [PATCH 60/68] Atomic gemm and FP8 Reduce Scatter (#449) * Initial commit Signed-off-by: Vasudevan Rengasamy * Repro for RS output mismatch with Single GEMM + Split pipelined RS Signed-off-by: Vasudevan Rengasamy * minor changes for AG->GEMM pipelined overlap Signed-off-by: Vasudevan Rengasamy * Add Atomic Gemm cublasApi attributes and initial implementation of AG->Atomic GEMM Signed-off-by: Vasudevan Rengasamy * AtomicGemm+RS functional with workaround Signed-off-by: Vasudevan Rengasamy * add amax update to layernorm_linear for FP8 unit test accuracy Signed-off-by: Vasudevan Rengasamy * Enable reducescatter2_userbuff_strided variants Signed-off-by: Vasudevan Rengasamy * Bug fix Signed-off-by: Vasudevan Rengasamy * AG+AtomicGemm overlap functional but gemm doesnt overlap with comm Signed-off-by: Vasudevan Rengasamy * Add userbuffers_sendrecv kernel variants Signed-off-by: Vasudevan Rengasamy * TransformerLayer API changes to enable AtomicGemm+RS overlap Signed-off-by: Vasudevan Rengasamy * Code cleanup Signed-off-by: Vasudevan Rengasamy * Code cleanup2 Signed-off-by: Vasudevan Rengasamy * [UB] AllGather Atomic GEMM overlap using userbuffer_sendrecv kernels Signed-off-by: Vasudevan Rengasamy * Code cleanup + bug fix for multiatomic sendrecv kernel Signed-off-by: Vasudevan Rengasamy * cleanup Signed-off-by: Vasudevan Rengasamy * Bug fixes Signed-off-by: Vasudevan Rengasamy * [UB] Add 
shuffling for better AG AtomicGEMM overlap Signed-off-by: Vasudevan Rengasamy * Bug fix for AG AtomicGemm overlap Signed-off-by: Vasudevan Rengasamy * Bug fix for multiAtomicAG and singleAtomicAG Signed-off-by: Vasudevan Rengasamy * Use chunk_i+1 as recv_chunk for multiatomic_AG with shuffling Signed-off-by: Vasudevan Rengasamy * Launch AtomicGEMM after first-chunk AG Signed-off-by: Vasudevan Rengasamy * Rebase to main Signed-off-by: Vasudevan Rengasamy * Add FP8 ReduceScatter kernels, AtomicGEMM+FP8 RS not functional Signed-off-by: Vasudevan Rengasamy * Revert "Add FP8 ReduceScatter kernels, AtomicGEMM+FP8 RS not functional" This reverts commit 80a47a76355440cd5fb4314c96fe9fda632d87f9. Signed-off-by: Vasudevan Rengasamy * Add support for NVLS-MC and FP8 Reduce Scatter Signed-off-by: Vasudevan Rengasamy * Bug fix Signed-off-by: Vasudevan Rengasamy * Atomic and Multiatomic FP8 RS functional Signed-off-by: Vasudevan Rengasamy * Remove debug print Signed-off-by: Vasudevan Rengasamy * UB comm initialization hang fix Signed-off-by: Vasudevan Rengasamy * Code cleanup Signed-off-by: Vasudevan Rengasamy * Create new GEMM API for Atomic GEMM Signed-off-by: Vasudevan Rengasamy * CI ready Signed-off-by: Kirthi Shankar Sivamani * more fixes Signed-off-by: Kirthi Shankar Sivamani * license Signed-off-by: Kirthi Shankar Sivamani * Bug fix Signed-off-by: Vasudevan Rengasamy * Revert NVLS-MC Signed-off-by: Vasudevan Rengasamy * Check cu* versions for running atomic gemms Signed-off-by: Kirthi Shankar Sivamani * lint Signed-off-by: Kirthi Shankar Sivamani * fixes Signed-off-by: Kirthi Shankar Sivamani * Cleanup Signed-off-by: Vasudevan Rengasamy * Add experimental warning Signed-off-by: Kirthi Shankar Sivamani * Better wording Signed-off-by: Kirthi Shankar Sivamani * Add warning to c api Signed-off-by: Kirthi Shankar Sivamani * Fix wording Signed-off-by: Kirthi Shankar Sivamani --------- Signed-off-by: Vasudevan Rengasamy Signed-off-by: Kirthi Shankar Sivamani Co-authored-by: Kirthi Shankar Sivamani --- tests/pytorch/test_onnx_export.py | 4 +- .../common/gemm/cublaslt_gemm.cu | 108 +- .../common/include/transformer_engine/gemm.h | 46 + transformer_engine/pytorch/attention.py | 6 + .../pytorch/cpp_extensions/gemm.py | 26 +- .../pytorch/csrc/comm_gemm_overlap.h | 502 ++- transformer_engine/pytorch/csrc/extensions.h | 26 + .../pytorch/csrc/extensions/gemm.cu | 80 + .../pytorch/csrc/extensions/pybind.cpp | 12 +- .../csrc/userbuffers/userbuffers-host.cpp | 186 +- .../pytorch/csrc/userbuffers/userbuffers.cu | 2949 ++++++++++++++--- .../pytorch/csrc/userbuffers/userbuffers.h | 83 + transformer_engine/pytorch/module/base.py | 19 +- .../pytorch/module/layernorm_linear.py | 69 +- .../pytorch/module/layernorm_mlp.py | 124 +- transformer_engine/pytorch/module/linear.py | 62 +- transformer_engine/pytorch/transformer.py | 20 + 17 files changed, 3619 insertions(+), 703 deletions(-) diff --git a/tests/pytorch/test_onnx_export.py b/tests/pytorch/test_onnx_export.py index 727ccce3dd..171b2f23c4 100644 --- a/tests/pytorch/test_onnx_export.py +++ b/tests/pytorch/test_onnx_export.py @@ -506,7 +506,7 @@ def forward(self, inp, weight): self.fp8_tensor_weight, self.weights_type) - ret = fp8_gemm( + ret, _ = fp8_gemm( weight_fp8, self.meta_weight.scale_inv, self.fp8_tensor_weight, @@ -1324,7 +1324,7 @@ def forward(self, inp, weight): self.fp8_tensor_weight, self.weights_type) - ret = fp8_gemm( + ret, _ = fp8_gemm( weight_fp8, self.meta_weight.scale_inv, self.fp8_tensor_weight, diff --git 
a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index 7f8b0b723d..95ef55bba4 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -7,6 +7,7 @@ #include #include #include +#include #include #include #include "../common.h" @@ -50,6 +51,10 @@ void cublas_gemm(const Tensor *inputA, bool accumulate, bool use_split_accumulator, int math_sm_count, + int m_split, + int n_split, + bool gemm_producer, + const Tensor *inputCounter, cudaStream_t stream ) { void *A = inputA->data.dptr; @@ -63,6 +68,10 @@ void cublas_gemm(const Tensor *inputA, void *bias_ptr = inputBias->data.dptr; const bool bias = bias_ptr != nullptr; void *pre_gelu_out = outputPreGelu->data.dptr; + void *counter = nullptr; + if (inputCounter != nullptr) { + counter = inputCounter->data.dptr; + } const bool gelu = pre_gelu_out != nullptr; const bool use_fp8 = is_fp8_dtype(inputA->data.dtype) || is_fp8_dtype(inputB->data.dtype); @@ -223,6 +232,27 @@ void cublas_gemm(const Tensor *inputA, NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue))); +#if CUDA_VERSION >= 12020 && CUBLAS_VERSION >= 120205 + if (counter != nullptr) { + if (m_split == 0) m_split=1; + if (n_split == 0) n_split=1; + NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute( + operationDesc, CUBLASLT_MATMUL_DESC_ATOMIC_SYNC_NUM_CHUNKS_D_ROWS, + &m_split, sizeof(m_split))); + NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute( + operationDesc, CUBLASLT_MATMUL_DESC_ATOMIC_SYNC_NUM_CHUNKS_D_COLS, + &n_split, sizeof(n_split))); + if (gemm_producer) { + NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute( + operationDesc, CUBLASLT_MATMUL_DESC_ATOMIC_SYNC_OUT_COUNTERS_POINTER, + &counter, sizeof(counter))); + } else { + NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute( + operationDesc, CUBLASLT_MATMUL_DESC_ATOMIC_SYNC_IN_COUNTERS_POINTER, + &counter, sizeof(counter))); + } + } +#endif NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceCreate(&preference)); NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute( @@ -254,7 +284,6 @@ void cublas_gemm(const Tensor *inputA, workspaceSize, stream)); /* stream */ - NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceDestroy(preference)); NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutDestroy(Ddesc)); NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutDestroy(Cdesc)); @@ -320,5 +349,82 @@ void nvte_cublas_gemm(const NVTETensor A, wspace->data.shape[0], accumulate, use_split_accumulator, math_sm_count, + 0, + 0, + false, + nullptr, + stream); +} + +void nvte_cublas_atomic_gemm(const NVTETensor A, + const NVTETensor B, + NVTETensor D, + const NVTETensor bias, + NVTETensor pre_gelu_out, + bool transa, + bool transb, + bool grad, + NVTETensor workspace, + bool accumulate, + bool use_split_accumulator, + int math_sm_count, + int m_split, + int n_split, + bool gemm_producer, + const NVTETensor counter, + cudaStream_t stream) { + NVTE_API_CALL(nvte_cublas_atomic_gemm); + + int cudart_version; + NVTE_CHECK_CUDA(cudaRuntimeGetVersion(&cudart_version)); + NVTE_CHECK(cudart_version >= 12020, "Cuda version 12.2 is required for atomic gemm."); + NVTE_CHECK(cublasLtGetVersion() >= 120205, "Cublas version 12.2.5 is required for atomic gemm."); + + using namespace transformer_engine; + const Tensor *inputA = reinterpret_cast(A); + const Tensor *inputB = reinterpret_cast(B); + Tensor *outputD = reinterpret_cast(D); + const Tensor *biasTensor = reinterpret_cast(bias); + Tensor *outputGelu = 
reinterpret_cast(pre_gelu_out); + const Tensor *inputCounter = reinterpret_cast(counter); + Tensor *wspace = reinterpret_cast(workspace); + + const int m = transa ? inputA->data.shape[0] : inputA->data.shape[1]; + const int k = transa ? inputA->data.shape[1] : inputA->data.shape[0]; + const int n = transb ? inputB->data.shape[1] : inputB->data.shape[0]; + int lda, ldb, ldd; + if (transa && !transb) { // TN + lda = k; + ldb = k; + ldd = m; + } else if (!transa && !transb) { // NN + lda = m; + ldb = k; + ldd = m; + } else if (!transa && transb) { // NT + lda = m; + ldb = n; + ldd = m; + } else { // TT + NVTE_ERROR("TT layout not allowed."); + } + + cublas_gemm(inputA, + inputB, + outputD, + biasTensor, + outputGelu, + m, n, k, + lda, ldb, ldd, + (transa) ? CUBLAS_OP_T : CUBLAS_OP_N, + (transb) ? CUBLAS_OP_T : CUBLAS_OP_N, + grad, wspace->data.dptr, + wspace->data.shape[0], + accumulate, use_split_accumulator, + math_sm_count, + m_split, + n_split, + gemm_producer, + inputCounter, stream); } diff --git a/transformer_engine/common/include/transformer_engine/gemm.h b/transformer_engine/common/include/transformer_engine/gemm.h index 8cd549b658..5faff43afa 100644 --- a/transformer_engine/common/include/transformer_engine/gemm.h +++ b/transformer_engine/common/include/transformer_engine/gemm.h @@ -54,6 +54,52 @@ void nvte_cublas_gemm(const NVTETensor A, cudaStream_t stream ); +/*! \brief Compute matrix multiplication of 2 matrices with chunking and atomic counters. + * + * \warning Cublas atomic gemm uses a beta API and is not tested for all use cases. + * + * Computes: + * - `D = AB` if both `bias` and `pre_gelu_out` are empty tensors + * - `D = AB + bias` if `pre_gelu_out` is empty and `bias` is not empty + * - `D = GELU(AB + bias)` if both `bias` and `pre_gelu_out` are not empty tensors + * + * \param[in] A The A matrix. + * \param[in] B The B matrix. + * \param[in,out] D Output matrix. + * \param[in] bias Bias tensor. + * \param[in,out] pre_gelu_out Output matrix before GELU activation. + * \param[in] transa Whether A matrix is transposed. + * \param[in] transb Whether B matrix is transposed. + * \param[in] grad Whether this operation is part of the + * gradient computation. + * \param[out] workspace Workspace tensor. + * \param[in] accumulate Whether to accumulate the result into the D matrix. + * \param[in] use_split_accumulator Whether to use split accumulator in the FP8 GEMM. + * \param[in] math_sm_count Number of GPU SMs to use (default=0: use cuBLAS heuristics) + * \param[in] m_split Number of chunks/splits along m-dimension for Atomic GEMM. + * \param[in] n_split Number of chunks/splits along n-dimension for Atomic GEMM. + * \param[in] gemm_producer Whether Atomic GEMM is the producer or consumer. + * \param[in,out] counter counter[chunk_i]=0 indicates chunk_i has been produced. + * \param[in] stream CUDA stream used for the operation. 
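+ *
+ * Illustrative producer-side call (a sketch only; A, B, D, bias, pre_gelu_out,
+ * workspace and counter are assumed to be NVTETensor handles created elsewhere,
+ * and num_splits is the chosen chunk count):
+ *
+ *   nvte_cublas_atomic_gemm(A, B, D, bias, pre_gelu_out,
+ *                           true,        // transa
+ *                           false,       // transb
+ *                           false,       // grad
+ *                           workspace,
+ *                           false,       // accumulate
+ *                           true,        // use_split_accumulator
+ *                           0,           // math_sm_count (cuBLAS heuristics)
+ *                           num_splits,  // m_split
+ *                           0,           // n_split
+ *                           true,        // gemm_producer
+ *                           counter, stream);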
+ */ +void nvte_cublas_atomic_gemm(const NVTETensor A, + const NVTETensor B, + NVTETensor D, + const NVTETensor bias, + NVTETensor pre_gelu_out, + bool transa, + bool transb, + bool grad, + NVTETensor workspace, + bool accumulate, + bool use_split_accumulator, + int math_sm_count, + int m_split, + int n_split, + bool gemm_producer, + const NVTETensor counter, + cudaStream_t stream +); #ifdef __cplusplus } // extern "C" #endif diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 625cd8644e..3fb67b990a 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -1505,6 +1505,8 @@ def __init__( ub_bulk_dgrad: bool = False, ub_split_rs: bool = False, ub_split_ag: bool = False, + ub_atomic_gemm_rs: bool = False, + ub_atomic_gemm_ag: bool = False, bias: bool = True, normalization: str = "LayerNorm", device: Union[torch.device, str] = "cuda", @@ -1585,6 +1587,7 @@ def __init__( ub_bulk_dgrad=ub_bulk_dgrad, ub_split_ag=ub_split_ag, normalization=normalization, + ub_atomic_gemm_ag=ub_atomic_gemm_ag, **common_gemm_kwargs, ) else: @@ -1615,6 +1618,7 @@ def __init__( ub_bulk_dgrad=ub_bulk_dgrad, ub_split_ag=ub_split_ag, normalization=normalization, + ub_atomic_gemm_ag=ub_atomic_gemm_ag, **common_gemm_kwargs, ) else: @@ -1661,6 +1665,8 @@ def __init__( parallel_mode="row" if set_parallel_mode else None, ub_split_rs=ub_split_rs, ub_split_ag=ub_split_ag, + ub_atomic_gemm_rs=ub_atomic_gemm_rs, + ub_atomic_gemm_ag=ub_atomic_gemm_ag, **common_gemm_kwargs, ) diff --git a/transformer_engine/pytorch/cpp_extensions/gemm.py b/transformer_engine/pytorch/cpp_extensions/gemm.py index c84dd1cb39..2d271c950c 100644 --- a/transformer_engine/pytorch/cpp_extensions/gemm.py +++ b/transformer_engine/pytorch/cpp_extensions/gemm.py @@ -92,22 +92,40 @@ def fp8_gemm( assert ub is not None, 'ub object is None!' if ub_algo == tex.UbufOverlapAlgo.BULK_OVERLAP_AG: fn = ub.bulk_overlap - args = tuple(args + (1,)) + extra_output_tensor = ( + empty_tensor if extra_output_tensor is None else extra_output_tensor + ) + args = tuple(args + (1, extra_output_tensor,)) elif ub_algo == tex.UbufOverlapAlgo.BULK_OVERLAP_RS: fn = ub.bulk_overlap - args = tuple(args + (0,)) + extra_output_tensor = ( + empty_tensor if extra_output_tensor is None else extra_output_tensor + ) + args = tuple(args + (0, extra_output_tensor,)) elif ub_algo == tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG: fn = ub.split_overlap_ag extra_output_tensor = ( empty_tensor if extra_output_tensor is None else extra_output_tensor ) args = tuple(args + (extra_output_tensor,)) + elif ub_algo == tex.UbufOverlapAlgo.ATOMIC_GEMM_AG: + fn = ub.atomic_gemm_overlap_ag + extra_output_tensor = ( + empty_tensor if extra_output_tensor is None else extra_output_tensor + ) + args = tuple(args + (extra_output_tensor,)) elif ub_algo == tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS: fn = ub.split_overlap_rs assert ( extra_output_tensor is not None ), 'SPLIT_PIPELINED_RS requires extra output tensor' args = tuple(args + (True, extra_output_tensor,)) + elif ub_algo == tex.UbufOverlapAlgo.ATOMIC_GEMM_RS: + fn = ub.atomic_gemm_overlap_rs + assert ( + extra_output_tensor is not None + ), 'ATOMIC_GEMM_RS requires extra output tensor' + args = tuple(args + (True, extra_output_tensor,)) _ = fn(*args) if return_output: @@ -204,10 +222,10 @@ def gemm( assert ub is not None, 'ub object is None!' 
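        # Each UbufOverlapAlgo value below dispatches to a userbuffers overlap
        # routine on `ub`. Every branch appends an extra output tensor (an empty
        # placeholder when unused); the pipelined and atomic reduce-scatter
        # variants require a real tensor to receive the reduce-scattered output.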
if ub_algo == tex.UbufOverlapAlgo.BULK_OVERLAP_AG: fn = ub.bulk_overlap - args = tuple(args + (1,)) + args = tuple(args + (1, empty_tensor)) elif ub_algo == tex.UbufOverlapAlgo.BULK_OVERLAP_RS: fn = ub.bulk_overlap - args = tuple(args + (0,)) + args = tuple(args + (0, empty_tensor)) elif ub_algo == tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG: fn = ub.split_overlap_ag extra_output_tensor = ( diff --git a/transformer_engine/pytorch/csrc/comm_gemm_overlap.h b/transformer_engine/pytorch/csrc/comm_gemm_overlap.h index 5dd71e4758..edac58a9dd 100644 --- a/transformer_engine/pytorch/csrc/comm_gemm_overlap.h +++ b/transformer_engine/pytorch/csrc/comm_gemm_overlap.h @@ -4,30 +4,32 @@ * See LICENSE for license information. ************************************************************************/ +#include "userbuffers/userbuffers.h" #include #include #include #include +#include #include #include #include #include #include #include -#include "userbuffers/userbuffers.h" #define HALF_BYTES 2 #define UB_MAX_SM 32 -#define CHECK_CUDA(call) \ - do { \ - cudaError_t status_ = call; \ - if (status_ != cudaSuccess) { \ - fprintf(stderr, "CUDA Error at line %d: %s\n", __LINE__, cudaGetErrorString(status_)); \ - exit(1); \ - } \ +#define CHECK_CUDA(call) \ + do { \ + cudaError_t status_ = call; \ + if (status_ != cudaSuccess) { \ + fprintf(stderr, "CUDA Error at line %d: %s\n", __LINE__, cudaGetErrorString(status_)); \ + exit(1); \ + } \ } while (0) +using namespace torch::indexing; namespace ubuf { enum class COMM_TYPE { RS = 0, AG = 1 }; @@ -36,11 +38,16 @@ enum class UBOverlapAlgo { BULK_OVERLAP_AG = 0, BULK_OVERLAP_RS = 1, SPLIT_PIPELINED_AG = 2, - SPLIT_PIPELINED_RS = 3 + SPLIT_PIPELINED_RS = 3, + ATOMIC_GEMM_RS = 4, + ATOMIC_GEMM_AG = 5 }; -struct UbufCommOverlap : torch::CustomClassHolder { - communicator *_ub_comm; +struct UbufBase { + static inline communicator *_ub_comm{nullptr}; + static inline bool comm_created{false}; +}; +struct UbufCommOverlap : torch::CustomClassHolder, UbufBase { int _tp_id; int _tp_size; int _num_splits; @@ -49,24 +56,53 @@ struct UbufCommOverlap : torch::CustomClassHolder { void *_ubuf_ptr; torch::Tensor _ubuf; torch::Tensor output_tensor; + torch::Tensor _ubuf_scale_inv; + bool _ubuf_scale_inv_initialized; + torch::Tensor counter; + torch::Tensor _empty_tensor; at::cuda::CUDAStream _stream_comm = at::cuda::getStreamFromPool(true); std::vector _stream_compute; cudaEvent_t _start_compute, _stop_compute, _start_d2dcopy, _start_comm, _stop_comm; + int comm_sms; + int cga_size; + int use_ce; UbufCommOverlap(torch::Tensor sample, int rank, int tp_size, int num_comm_sm, int comm_cga_size, - int num_splits, bool set_sm_margin, int num_max_streams) { + int num_splits, bool set_sm_margin, int num_max_streams, + torch::Tensor empty_tensor) { // Initialize userbuf communicator - create_communicator_grouped2(&_ub_comm, 1, 1, tp_size, 1); - _ub_comm->use_ce = 0; - _ub_comm->sms = num_comm_sm; - _ub_comm->cga_size = comm_cga_size; + if (!comm_created) { + if (rank == 0) { + printf("!!! [UB] Create UbufCommOverlap Communicator\n"); + } + create_communicator_grouped2(&_ub_comm, 1, 1, tp_size, 1); + comm_created = true; + } + use_ce = 0; + comm_sms = num_comm_sm; + cga_size = comm_cga_size; + _empty_tensor = empty_tensor; // Allocate and register extra userbuffers int ubuf_bytes = sample.numel() * sample.element_size(); _ub_reg = register_user_buffer_collective(reinterpret_cast(&_ubuf_ptr), ubuf_bytes, _ub_comm, true); + if (rank == 0) { + printf("!!! 
[UB] Register UBuf %d\n", _ub_reg); + } _ubuf = torch::from_blob(_ubuf_ptr, {sample.size(0), sample.size(1)}, sample.options()); + const char *env_p = std::getenv("NVTE_RS_STRIDED_ATOMIC"); + const char *env_q = std::getenv("NVTE_UB_ATOMIC_GEMM_RS"); + if (rank == 0 && env_p != nullptr && env_q != nullptr && env_q[0] == '1') { + if (env_p[0] == '1') + printf("!! Using reducescatter2_userbuff_strided_atomic\n"); + else if (env_p[0] == '2') + printf("!! Using reducescatter2_userbuff_strided_multiatomic\n"); + else + printf("!! Using reducescatter2_userbuff_strided\n"); + } + at::cuda::CUDAStream stream_main = at::cuda::getDefaultCUDAStream(); for (int i = 0; i < std::min(num_max_streams, num_splits); i++) { cudaStream_t stream; @@ -78,6 +114,7 @@ struct UbufCommOverlap : torch::CustomClassHolder { _num_splits = num_splits; _tp_size = tp_size; _tp_id = (rank % tp_size); + _ubuf_scale_inv_initialized = false; // Set the number of SMs for GEMM with margin cudaDeviceProp prop; @@ -85,6 +122,9 @@ struct UbufCommOverlap : torch::CustomClassHolder { _math_sms = (set_sm_margin) ? prop.multiProcessorCount - num_comm_sm : prop.multiProcessorCount; output_tensor = torch::Tensor(); + auto counter_options = torch::TensorOptions().dtype(torch::kInt32).device(torch::kCUDA); + counter = torch::zeros({num_splits * 2}, counter_options); + counter.index_put_({Slice(None, num_splits)}, 1); // CUDA event creation cudaEventCreateWithFlags(&_start_compute, 0); cudaEventCreateWithFlags(&_stop_compute, 0); @@ -97,13 +137,17 @@ struct UbufCommOverlap : torch::CustomClassHolder { ** Bulk GEMM + COMM ** This function assumes the communication input is pre-copied to _ubuf */ - std::vector bulk_overlap( - at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor, - transformer_engine::DType A_type, bool transa, at::Tensor B, at::Tensor B_scale_inverse, - int64_t B_fp8_tensor, transformer_engine::DType B_type, bool transb, at::Tensor D, - at::Tensor D_scale, transformer_engine::DType D_type, at::Tensor D_amax, at::Tensor bias, - transformer_engine::DType bias_type, at::Tensor pre_gelu_out, bool grad, at::Tensor workspace, - size_t workspaceSize, bool accumulate, bool use_split_accumulator, int comm_type) { + std::vector + bulk_overlap(at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor, + transformer_engine::DType A_type, bool transa, at::Tensor B, + at::Tensor B_scale_inverse, int64_t B_fp8_tensor, transformer_engine::DType B_type, + bool transb, at::Tensor D, at::Tensor D_scale, transformer_engine::DType D_type, + at::Tensor D_amax, at::Tensor bias, transformer_engine::DType bias_type, + at::Tensor pre_gelu_out, bool grad, at::Tensor workspace, size_t workspaceSize, + bool accumulate, bool use_split_accumulator, int comm_type, at::Tensor rs_output) { + _ub_comm->use_ce = use_ce; + _ub_comm->sms = comm_sms; + _ub_comm->cga_size = cga_size; // Get the current userbuf offset char *ubuf_wt_ptr = reinterpret_cast(_ubuf.data_ptr()); int comm_elements = (_ubuf.numel() / 2) * _ubuf.element_size(); // UBUF uses 2Byte element size @@ -121,15 +165,30 @@ struct UbufCommOverlap : torch::CustomClassHolder { if (_comm_type == COMM_TYPE::AG) { allgather2_userbuff_inplace(_ub_reg, 0, comm_elements, _ub_comm, (cudaStream_t)_stream_comm); } else if (_comm_type == COMM_TYPE::RS) { - reducescatter2_userbuff_inplace(_ub_reg, 0, comm_elements, _ub_comm, - (cudaStream_t)_stream_comm); + if (_ubuf.element_size() == 1) { + assert(_ubuf_scale_inv_initialized); + comm_elements *= 2; + float *scale_inv_ptr = 
reinterpret_cast(_ubuf_scale_inv.data_ptr()); + assert(rs_output.numel() == _ubuf.numel() / _tp_size); + assert(rs_output.size(0) == _ubuf.size(0) / _tp_size); + assert(rs_output.element_size() == 2); + char *rs_output_ptr = reinterpret_cast(rs_output.data_ptr()); + reducescatter2_userbuff_fp8<__nv_fp8_e5m2>(rs_output_ptr, scale_inv_ptr, _ub_reg, 0, + comm_elements, _ub_comm, + (cudaStream_t)_stream_comm); + } else { + reducescatter2_userbuff_inplace(_ub_reg, 0, comm_elements, _ub_comm, + (cudaStream_t)_stream_comm); + } } else { NVTE_ERROR("Not supported communication type."); } - if (A_scale_inverse.numel()) A_scale_inverse = A_scale_inverse[A_fp8_tensor]; + if (A_scale_inverse.numel()) + A_scale_inverse = A_scale_inverse[A_fp8_tensor]; - if (B_scale_inverse.numel()) B_scale_inverse = B_scale_inverse[B_fp8_tensor]; + if (B_scale_inverse.numel()) + B_scale_inverse = B_scale_inverse[B_fp8_tensor]; assert(pre_gelu_out.numel() == 0); te_gemm(A, A_scale_inverse, A_type, transa, B, B_scale_inverse, B_type, transb, D, D_scale, @@ -147,6 +206,117 @@ struct UbufCommOverlap : torch::CustomClassHolder { return {D, output_tensor}; } // bulk_overlap + /* + ** Split FPROP GEMM + ReduceScatter + */ + void atomic_gemm_overlap_rs(at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor, + transformer_engine::DType A_type, bool transa, at::Tensor B, + at::Tensor B_scale_inverse, int64_t B_fp8_tensor, + transformer_engine::DType B_type, bool transb, at::Tensor D, + at::Tensor D_scale, transformer_engine::DType D_type, + at::Tensor D_amax, at::Tensor bias, + transformer_engine::DType bias_type, at::Tensor pre_gelu_out, + bool grad, at::Tensor workspace, size_t workspaceSize, + bool accumulate, bool use_split_accumulator, bool gemm_overlap, + at::Tensor rs_output) { + _ub_comm->use_ce = use_ce; + _ub_comm->sms = comm_sms; + _ub_comm->cga_size = cga_size; + // Get GEMM dimensions + int m = A.size(0); + int k = A.size(1); + int n = B.size(0); + int m_chunk = m / _num_splits; + int workspace_size_chunk = workspaceSize / _stream_compute.size(); + + // Get input, output, and workspace data pointers + char *input_a_chunk_ptr = reinterpret_cast(A.data_ptr()); + char *output_buf_chunk_ptr = reinterpret_cast(_ubuf.data_ptr()); + char *workspace_ptr = reinterpret_cast(workspace.data_ptr()); + int *counter_ptr = reinterpret_cast(counter.data_ptr()); + char *rs_output_ptr = reinterpret_cast(rs_output.data_ptr()); + int ori_sms = _ub_comm->sms; + + // Catch up the default torch stream + at::cuda::CUDAStream stream_main = at::cuda::getDefaultCUDAStream(); + CHECK_CUDA(cudaEventRecord(_start_compute, stream_main)); + CHECK_CUDA(cudaEventRecord(_stop_comm, _stream_comm)); + for (int i = 0; i < _stream_compute.size(); i++) { + CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_compute[i], _start_compute, 0)); + CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_compute[i], _stop_comm, 0)); + } + + if (A_scale_inverse.numel()) + A_scale_inverse = A_scale_inverse[A_fp8_tensor]; + + if (B_scale_inverse.numel()) + B_scale_inverse = B_scale_inverse[B_fp8_tensor]; + + assert(pre_gelu_out.numel() == 0); + + torch::Tensor input_a = torch::from_blob(input_a_chunk_ptr, {m, k}, A.options()); + torch::Tensor output_d = torch::from_blob(output_buf_chunk_ptr, {n, m}, _ubuf.options()); + // torch::zeros({n, m}, _ubuf.options()); + torch::Tensor workspace_chunk = + torch::from_blob(workspace_ptr, {workspace_size_chunk}, workspace.options()); + at::cuda::setCurrentCUDAStream(_stream_compute[0]); + te_atomic_gemm(input_a, 
A_scale_inverse, A_type, transa, B, B_scale_inverse, B_type, transb, + output_d, D_scale, D_type, D_amax, bias, bias_type, pre_gelu_out, grad, + workspace_chunk, workspace_size_chunk, accumulate, use_split_accumulator, + _math_sms, _num_splits /*m_split*/, 0 /*n_split*/, true /*gemm_producer*/, + counter); + for (int i = 0; i < _num_splits; i++) { + const char *env_p = std::getenv("NVTE_RS_STRIDED_ATOMIC"); + if (env_p != nullptr && env_p[0] == '1') { + if (i == _num_splits - 1) { + _ub_comm->sms = UB_MAX_SM; + } + if (_ubuf.element_size() == 1) { + assert(_ubuf_scale_inv_initialized); + float *d_scale_inv_ptr = reinterpret_cast(_ubuf_scale_inv.data_ptr()); + reducescatter2_userbuff_strided_atomic_fp8<__nv_fp8_e4m3>( + rs_output_ptr, d_scale_inv_ptr, _ub_reg, i * m_chunk, m_chunk, n, m, m, _num_splits, + &counter_ptr[i], _ub_comm, (cudaStream_t)_stream_comm); + } else { + reducescatter2_userbuff_strided_atomic(rs_output_ptr, _ub_reg, i * m_chunk, m_chunk, n, m, + _num_splits, &counter_ptr[i], _ub_comm, + (cudaStream_t)_stream_comm); + } + } else if (env_p != nullptr && env_p[0] == '2') { + if (_ubuf.element_size() == 1) { + assert(_ubuf_scale_inv_initialized); + float *d_scale_inv_ptr = reinterpret_cast(_ubuf_scale_inv.data_ptr()); + reducescatter2_userbuff_strided_multiatomic_fp8<__nv_fp8_e4m3>( + rs_output_ptr, d_scale_inv_ptr, _ub_reg, m_chunk, m_chunk, n, m, m, _num_splits, + counter_ptr, _ub_comm, (cudaStream_t)_stream_comm); + } else { + reducescatter2_userbuff_strided_multiatomic(rs_output_ptr, _ub_reg, m_chunk, m_chunk, n, + m, _num_splits, counter_ptr, _ub_comm, + (cudaStream_t)_stream_comm); + } + break; + } else { + consumer(counter_ptr, i, (cudaStream_t)_stream_comm); + // if (i == _num_splits-1) { + // _ub_comm->sms = UB_MAX_SM; + // } + reducescatter2_userbuff_strided(rs_output_ptr, _ub_reg, i * m_chunk, m_chunk, n, m, + _ub_comm, (cudaStream_t)_stream_comm); + } + + rs_output_ptr += m_chunk * rs_output.element_size(); + } + + _ub_comm->sms = ori_sms; + CHECK_CUDA(cudaEventRecord(_stop_compute, (cudaStream_t)_stream_compute[0])); + CHECK_CUDA(cudaEventRecord(_stop_comm, (cudaStream_t)_stream_comm)); + CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_compute, 0)); + CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_comm, 0)); + at::cuda::setCurrentCUDAStream(stream_main); + + return; + } // split_overlap_rs + /* ** Split FPROP GEMM + ReduceScatter */ @@ -160,6 +330,9 @@ struct UbufCommOverlap : torch::CustomClassHolder { size_t workspaceSize, bool accumulate, bool use_split_accumulator, bool gemm_overlap, at::Tensor rs_output) { // Get GEMM dimensions + _ub_comm->use_ce = use_ce; + _ub_comm->sms = comm_sms; + _ub_comm->cga_size = cga_size; int m = A.size(0); int k = A.size(1); int n = B.size(0); @@ -174,7 +347,6 @@ struct UbufCommOverlap : torch::CustomClassHolder { char *workspace_ptr = reinterpret_cast(workspace.data_ptr()); char *rs_output_ptr = reinterpret_cast(rs_output.data_ptr()); - int ubuf_offset = 0; int ori_sms = _ub_comm->sms; // Catch up the default torch stream @@ -184,9 +356,11 @@ struct UbufCommOverlap : torch::CustomClassHolder { CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_compute[i], _start_compute, 0)); } - if (A_scale_inverse.numel()) A_scale_inverse = A_scale_inverse[A_fp8_tensor]; + if (A_scale_inverse.numel()) + A_scale_inverse = A_scale_inverse[A_fp8_tensor]; - if (B_scale_inverse.numel()) B_scale_inverse = B_scale_inverse[B_fp8_tensor]; + if (B_scale_inverse.numel()) + B_scale_inverse = 
B_scale_inverse[B_fp8_tensor]; assert(pre_gelu_out.numel() == 0); @@ -223,10 +397,19 @@ struct UbufCommOverlap : torch::CustomClassHolder { CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_comm, _start_comm, 0)); // Communication chunk - reducescatter2_userbuff_stridedoutput(rs_output_ptr, _ub_reg, (i - 1) * output_chunk_size, - m_chunk, n, m, _ub_comm, (cudaStream_t)_stream_comm); + if (_ubuf.element_size() == 1) { + assert(_ubuf_scale_inv_initialized); + float *d_scale_inv_ptr = reinterpret_cast(_ubuf_scale_inv.data_ptr()); + reducescatter2_userbuff_stridedoutput_fp8<__nv_fp8_e4m3>( + rs_output_ptr, d_scale_inv_ptr, _ub_reg, (i - 1) * output_chunk_size, m_chunk, n, m, + _ub_comm, (cudaStream_t)_stream_comm); + } else { + reducescatter2_userbuff_stridedoutput(rs_output_ptr, _ub_reg, (i - 1) * output_chunk_size, + m_chunk, n, m, _ub_comm, + (cudaStream_t)_stream_comm); + } - rs_output_ptr += m_chunk * _ubuf.element_size(); + rs_output_ptr += m_chunk * rs_output.element_size(); } int last_compute_stream_id = (_num_splits + _stream_compute.size() - 1) % _stream_compute.size(); @@ -236,9 +419,17 @@ struct UbufCommOverlap : torch::CustomClassHolder { // Last communication chunk with max SM _ub_comm->sms = UB_MAX_SM; - reducescatter2_userbuff_stridedoutput(rs_output_ptr, _ub_reg, - (_num_splits - 1) * output_chunk_size, m_chunk, n, m, - _ub_comm, (cudaStream_t)_stream_comm); + if (_ubuf.element_size() == 1) { + assert(_ubuf_scale_inv_initialized); + float *d_scale_inv_ptr = reinterpret_cast(_ubuf_scale_inv.data_ptr()); + reducescatter2_userbuff_stridedoutput_fp8<__nv_fp8_e4m3>( + rs_output_ptr, d_scale_inv_ptr, _ub_reg, (_num_splits - 1) * output_chunk_size, m_chunk, + n, m, _ub_comm, (cudaStream_t)_stream_comm); + } else { + reducescatter2_userbuff_stridedoutput(rs_output_ptr, _ub_reg, + (_num_splits - 1) * output_chunk_size, m_chunk, n, m, + _ub_comm, (cudaStream_t)_stream_comm); + } } else { for (int i = 0; i < _num_splits; i++) { torch::Tensor input_a_chunk = @@ -259,13 +450,21 @@ struct UbufCommOverlap : torch::CustomClassHolder { CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_comm, _start_comm, 0)); // Communication chunk. 
Uses MAX_SM at the last chunk - if (i == _num_splits-1) { + if (i == _num_splits - 1) { _ub_comm->sms = UB_MAX_SM; } - reducescatter2_userbuff_stridedoutput(rs_output_ptr, _ub_reg, i * output_chunk_size, - m_chunk, n, m, _ub_comm, (cudaStream_t)_stream_comm); - - rs_output_ptr += m_chunk * _ubuf.element_size(); + if (_ubuf.element_size() == 1) { + assert(_ubuf_scale_inv_initialized); + float *d_scale_inv_ptr = reinterpret_cast(_ubuf_scale_inv.data_ptr()); + reducescatter2_userbuff_stridedoutput_fp8<__nv_fp8_e4m3>( + rs_output_ptr, d_scale_inv_ptr, _ub_reg, i * output_chunk_size, m_chunk, n, m, + _ub_comm, (cudaStream_t)_stream_comm); + } else { + reducescatter2_userbuff_stridedoutput(rs_output_ptr, _ub_reg, i * output_chunk_size, + m_chunk, n, m, _ub_comm, + (cudaStream_t)_stream_comm); + } + rs_output_ptr += m_chunk * rs_output.element_size(); input_a_chunk_ptr += input_a_chunk_size * B.element_size(); output_buf_chunk_ptr += output_chunk_size * _ubuf.element_size(); } @@ -283,6 +482,12 @@ struct UbufCommOverlap : torch::CustomClassHolder { return; } // split_overlap_rs + void set_ubuf_scale_inv(const torch::Tensor &scale_inv) { + _ubuf_scale_inv = scale_inv; + _ubuf_scale_inv_initialized = true; + } + + bool is_fp8_ubuf() { return (_ubuf.element_size() == 1); } /* ** Helper function to copy input to _ubuf */ @@ -311,7 +516,8 @@ struct UbufCommOverlap : torch::CustomClassHolder { torch::Tensor &get_ubuf_output(int comm_type) { char *ubuf_wt_ptr = reinterpret_cast(_ubuf.data_ptr()); COMM_TYPE _comm_type = static_cast(comm_type); - if (_comm_type != COMM_TYPE::AG && _comm_type != COMM_TYPE::RS) NVTE_ERROR("Invalid comm_type"); + if (_comm_type != COMM_TYPE::AG && _comm_type != COMM_TYPE::RS) + NVTE_ERROR("Invalid comm_type"); if (_comm_type == COMM_TYPE::RS) ubuf_wt_ptr += _ubuf.numel() / _tp_size * _tp_id * _ubuf.element_size(); int output_c_dim0 = (_comm_type == COMM_TYPE::AG) ? _ubuf.size(0) : _ubuf.size(0) / _tp_size; @@ -321,35 +527,51 @@ struct UbufCommOverlap : torch::CustomClassHolder { } }; // UbufCommOverlap -struct UbufP2PCommOverlap : torch::CustomClassHolder { - communicator *_ub_comm; +struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase { int _tp_id; int _tp_size; int _ub_reg; int _next_rank, _prev_rank, _rank, _rank_round_tp; int _aggregate2; int _math_sms; + int _self_chunk_id; void *_ubuf_ptr; torch::Tensor _ubuf; + torch::Tensor counter; + torch::Tensor _empty_tensor; std::vector _ubufs; at::cuda::CUDAStream _stream_send = at::cuda::getStreamFromPool(true); at::cuda::CUDAStream _stream_recv = at::cuda::getStreamFromPool(true); std::vector _stream_compute; cudaEvent_t _start_compute, _stop_compute, _stop_send, _stop_recv; + int use_ce; + int sms; + int cga_size; - UbufP2PCommOverlap(torch::Tensor sample, int rank, int tp_size, bool aggregate2, - int num_max_streams) { + UbufP2PCommOverlap(torch::Tensor sample, int rank, int tp_size, int num_comm_sm, + int comm_cga_size, bool set_sm_margin, bool aggregate2, int num_max_streams, + torch::Tensor empty_tensor) { // Initialize userbuf communicator - create_communicator_grouped2(&_ub_comm, 1, 1, tp_size, 1); - _ub_comm->use_ce = 1; - _ub_comm->sms = 1; - _ub_comm->cga_size = 1; + if (!comm_created) { + if (rank == 0) { + printf("!!! 
[UB] Create UbufP2PCommOverlap Communicator\n"); + } + create_communicator_grouped2(&_ub_comm, 1, 1, tp_size, 1); + comm_created = true; + } + use_ce = 1; + sms = 1; + cga_size = 1; + _empty_tensor = empty_tensor; // Create workspace tensor with userbuffer int ubuf_bytes = sample.numel() * sample.element_size(); int ubuf_chunk_bytes = ubuf_bytes / tp_size; _ub_reg = register_user_buffer_collective(reinterpret_cast(&_ubuf_ptr), ubuf_bytes, _ub_comm, true); + if (rank == 0) { + printf("!!! [UBP2P] Register UBuf %d\n", _ub_reg); + } _ubuf = torch::from_blob(_ubuf_ptr, {sample.size(0), sample.size(1)}, sample.options()); // Create tensor chunks for easy management @@ -372,7 +594,7 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder { // Set the number of SMs for GEMM with margin cudaDeviceProp prop; cudaGetDeviceProperties(&prop, 0); - _math_sms = prop.multiProcessorCount; + _math_sms = (set_sm_margin) ? prop.multiProcessorCount - num_comm_sm : prop.multiProcessorCount; _tp_size = tp_size; _aggregate2 = aggregate2; @@ -383,6 +605,26 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder { _next_rank = (tp_size + rank + 1) % tp_size + _rank_round_tp; _prev_rank = (tp_size + rank + -1) % tp_size + _rank_round_tp; + auto counter_options = torch::TensorOptions().dtype(torch::kInt32).device(torch::kCUDA); + counter = torch::zeros({tp_size * 2}, counter_options); + counter.index_put_({Slice(None, tp_size)}, 1); + _self_chunk_id = _tp_id; + + const char *env_p = std::getenv("NVTE_AG_P2P_ATOMIC"); + if (rank == 0 && env_p != nullptr) { + if (env_p[0] == '1') { + printf("!!userbuffers_sendrecv_atomic\n"); + } else if (env_p[0] == '2') { + printf("!!userbuffers_sendrecv_multiatomic\n"); + } else if (env_p[0] == '3') { + printf("!!userbuffers_sendrecv_multiatomic_shuffle\n"); + _self_chunk_id = 0; + } else { + printf("!!userbuffers_sendrecv\n"); + } + } + counter.index_put_({_self_chunk_id}, 0); + // CUDA event creation cudaEventCreateWithFlags(&_start_compute, 0); cudaEventCreateWithFlags(&_stop_compute, 0); @@ -390,11 +632,144 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder { cudaEventCreateWithFlags(&_stop_recv, 0); } + /* + ** Split AllGather + AtomicGEMM using P2P communication + ** This function assumes the input_b is pre-copied to _ubufs[rank_id]. This is + *needed to have AG outputs + ** in each rank to be in the contiguous memory space after all ring exchange + *phases. + */ + torch::Tensor atomic_gemm_overlap_ag( + at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor, + transformer_engine::DType A_type, bool transa, at::Tensor B, at::Tensor B_scale_inverse, + int64_t B_fp8_tensor, transformer_engine::DType B_type, bool transb, at::Tensor D, + at::Tensor D_scale, transformer_engine::DType D_type, at::Tensor D_amax, at::Tensor bias, + transformer_engine::DType bias_type, at::Tensor pre_gelu_out, bool grad, at::Tensor workspace, + size_t workspaceSize, bool accumulate, bool use_split_accumulator, at::Tensor B_copy) { + _ub_comm->use_ce = use_ce; + _ub_comm->sms = sms; + _ub_comm->cga_size = cga_size; + // Get GEMM dimensions between TN and NN input layouts + const int m = (transa) ? A.size(0) : A.size(1); + const int k = (transa) ? 
A.size(1) : A.size(0); + const int n_chunk = _ubufs[0].size(0); + + // Get communication and GEMM output chunk sizes + const int comm_bytes = _ubufs[0].numel() * _ubufs[0].element_size(); + + // Get output and workspace data pointers + char *output_ptr = reinterpret_cast(D.data_ptr()); + char *workspace_ptr = reinterpret_cast(workspace.data_ptr()); + int *counter_ptr = reinterpret_cast(counter.data_ptr()); + int workspace_size_chunk = workspaceSize / _stream_compute.size(); + + if (A_scale_inverse.numel()) + A_scale_inverse = A_scale_inverse[A_fp8_tensor]; + + if (B_scale_inverse.numel()) + B_scale_inverse = B_scale_inverse[B_fp8_tensor]; + + at::cuda::CUDAStream stream_main = at::cuda::getDefaultCUDAStream(); + CHECK_CUDA(cudaEventRecord(_start_compute, (cudaStream_t)stream_main)); + + assert(pre_gelu_out.numel() == 0); + // Catch up the default torch stream + CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_send, _start_compute, 0)); + CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_recv, _start_compute, 0)); + CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_compute[0], _start_compute, 0)); + + torch::Tensor output_chunk = torch::from_blob(output_ptr, {_ubuf.size(0), m}, D.options()); + torch::Tensor workspace_chunk = + torch::from_blob(workspace_ptr, {workspace_size_chunk}, workspace.options()); + for (int i = 0; i < _tp_size; i++) { + // Set the userbuffer id. Buffer under send is the input for the current + // GEMM chunk The initial input chunk is stored _ubuf[rank]. This is to + // have the AG output in all ranks to be contiguous after the ring + // exchanges + int send_chunk_id = (_tp_size + _tp_id - i) % _tp_size; + int recv_chunk_id = (_tp_size + _tp_id - i - 1) % _tp_size; + int send_offset = comm_bytes * send_chunk_id; + int recv_offset = comm_bytes * recv_chunk_id; + + if (i < _tp_size - 1) { + const char *env_p = std::getenv("NVTE_AG_P2P_ATOMIC"); + if (env_p != nullptr && env_p[0] == '1') { + userbuffers_sendrecv_atomic(_ub_reg, _ub_reg, send_offset, recv_offset, comm_bytes, + _ub_comm, _next_rank, _prev_rank, &counter_ptr[recv_chunk_id], + (cudaStream_t)_stream_recv); + } else if (env_p != nullptr && env_p[0] == '2') { + if (i == 0) { + userbuffers_sendrecv_multiatomic(_ub_reg, _ub_reg, comm_bytes, comm_bytes, comm_bytes, + _ub_comm, _next_rank, _prev_rank, _tp_size, + counter_ptr, false, (cudaStream_t)_stream_recv); + } + } else if (env_p != nullptr && env_p[0] == '3') { + if (i == 0) { + userbuffers_sendrecv_multiatomic(_ub_reg, _ub_reg, comm_bytes, comm_bytes, comm_bytes, + _ub_comm, _next_rank, _prev_rank, _tp_size, + counter_ptr, true, (cudaStream_t)_stream_recv); + } + } else { + // P2P communication + // userbuffers_send(_ub_reg, send_offset, _ub_reg, send_offset, + // comm_bytes, _ub_comm, + // _next_rank, (cudaStream_t)_stream_send); + // userbuffers_recv(_ub_reg, recv_offset, _ub_reg, recv_offset, + // comm_bytes, _ub_comm, + // _prev_rank, (cudaStream_t)_stream_recv); + // CHECK_CUDA(cudaEventRecord(_stop_recv, + // (cudaStream_t)_stream_recv)); + // CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_send, + // _stop_recv, 0)); + userbuffers_sendrecv(_ub_reg, _ub_reg, send_offset, recv_offset, comm_bytes, _ub_comm, + _next_rank, _prev_rank, (cudaStream_t)_stream_recv); + producer(counter_ptr, recv_chunk_id, (cudaStream_t)_stream_recv); + } + if (i == 0) { + at::cuda::setCurrentCUDAStream(_stream_compute[0]); + te_atomic_gemm(A, A_scale_inverse, A_type, transa, _ubuf, B_scale_inverse, B_type, transb, + output_chunk, D_scale, D_type, D_amax, bias, 
bias_type, pre_gelu_out, grad, + workspace_chunk, workspace_size_chunk, accumulate, use_split_accumulator, + _math_sms, 0, _tp_size, false, counter); + } + } else { + // GEMM + // userbuffers_send_multiatomic(_ub_reg, 0, _ub_reg, 0, comm_bytes, + // _ub_comm, + // _next_rank, _tp_size, comm_bytes, comm_bytes, + // (cudaStream_t)_stream_send); + // userbuffers_recv_multiatomic(_ub_reg, 0, _ub_reg, 0, comm_bytes, + // _ub_comm, + // _prev_rank, _tp_size, counter_ptr, + // (cudaStream_t)_stream_recv); + if (B_copy.numel() > 0) { + assert(B_copy.numel() == _ubufs[_tp_id].numel()); + assert(B_copy.element_size() == _ubufs[_tp_id].element_size()); + CHECK_CUDA(cudaMemcpyAsync(B_copy.data_ptr(), _ubufs[_tp_id].data_ptr(), + _ubufs[_tp_id].numel() * _ubufs[_tp_id].element_size(), + cudaMemcpyDeviceToDevice, (cudaStream_t)_stream_send)); + CHECK_CUDA(cudaEventRecord(_stop_send, (cudaStream_t)_stream_send)); + CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_send, 0)); + } + } + } + for (int i = 0; i < _tp_size; i++) { + if (i != _self_chunk_id) { + consumer(counter_ptr, i, (cudaStream_t)_stream_compute[0]); + } + } + at::cuda::setCurrentCUDAStream(stream_main); + CHECK_CUDA(cudaEventRecord(_stop_compute, (cudaStream_t)_stream_compute[0])); + CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_compute, 0)); + + return D; + } // split_overlap_ag /* ** Split AllGather + GEMM using P2P communication - ** This function assumes the input_b is pre-copied to _ubufs[rank_id]. This is needed to have AG - *outputs - ** in each rank to be in the contiguous memory space after all ring exchange phases. + ** This function assumes the input_b is pre-copied to _ubufs[rank_id]. This is + *needed to have AG outputs + ** in each rank to be in the contiguous memory space after all ring exchange + *phases. */ torch::Tensor split_overlap_ag(at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor, transformer_engine::DType A_type, bool transa, at::Tensor B, @@ -405,6 +780,9 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder { transformer_engine::DType bias_type, at::Tensor pre_gelu_out, bool grad, at::Tensor workspace, size_t workspaceSize, bool accumulate, bool use_split_accumulator, at::Tensor B_copy) { + _ub_comm->use_ce = use_ce; + _ub_comm->sms = sms; + _ub_comm->cga_size = cga_size; // Get GEMM dimensions between TN and NN input layouts const int m = (transa) ? A.size(0) : A.size(1); const int k = (transa) ? A.size(1) : A.size(0); @@ -419,9 +797,11 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder { char *workspace_ptr = reinterpret_cast(workspace.data_ptr()); int workspace_size_chunk = workspaceSize / _stream_compute.size(); - if (A_scale_inverse.numel()) A_scale_inverse = A_scale_inverse[A_fp8_tensor]; + if (A_scale_inverse.numel()) + A_scale_inverse = A_scale_inverse[A_fp8_tensor]; - if (B_scale_inverse.numel()) B_scale_inverse = B_scale_inverse[B_fp8_tensor]; + if (B_scale_inverse.numel()) + B_scale_inverse = B_scale_inverse[B_fp8_tensor]; at::cuda::CUDAStream stream_main = at::cuda::getDefaultCUDAStream(); CHECK_CUDA(cudaEventRecord(_start_compute, (cudaStream_t)stream_main)); @@ -506,9 +886,10 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder { CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_compute[0], _start_compute, 0)); for (int i = 0; i < _tp_size; i++) { - // Set the userbuffer id. Buffer under send is the input for the current GEMM chunk - // The initial input chunk is stored _ubuf[rank]. 
This is to have the AG output in all ranks - // to be contiguous after the ring exchanges + // Set the userbuffer id. Buffer under send is the input for the current + // GEMM chunk The initial input chunk is stored _ubuf[rank]. This is to + // have the AG output in all ranks to be contiguous after the ring + // exchanges int send_chunk_id = (_tp_size + _tp_id - i) % _tp_size; int recv_chunk_id = (_tp_size + _tp_id - i - 1) % _tp_size; int send_offset = comm_bytes * send_chunk_id; @@ -581,7 +962,8 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder { torch::Tensor get_ubuf_output(int comm_type) { char *ubuf_wt_ptr = reinterpret_cast(_ubuf.data_ptr()); COMM_TYPE _comm_type = static_cast(comm_type); - if (_comm_type != COMM_TYPE::AG && _comm_type != COMM_TYPE::RS) NVTE_ERROR("Invalid comm_type"); + if (_comm_type != COMM_TYPE::AG && _comm_type != COMM_TYPE::RS) + NVTE_ERROR("Invalid comm_type"); if (_comm_type == COMM_TYPE::RS) ubuf_wt_ptr += _ubuf.numel() / _tp_size * _tp_id * _ubuf.element_size(); int output_c_dim0 = (_comm_type == COMM_TYPE::AG) ? _ubuf.size(0) : _ubuf.size(0) / _tp_size; diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 274a523ec0..4eaca7c896 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -179,6 +179,32 @@ void te_gemm(at::Tensor A, int math_sm_count ); +void te_atomic_gemm(at::Tensor A, + at::Tensor A_scale_inverse, + transformer_engine::DType A_type, + bool transa, + at::Tensor B, + at::Tensor B_scale_inverse, + transformer_engine::DType B_type, + bool transb, + at::Tensor D, + at::Tensor D_scale, + transformer_engine::DType D_type, + at::Tensor D_amax, + at::Tensor bias, + transformer_engine::DType bias_type, + at::Tensor pre_gelu_out, + bool grad, + at::Tensor workspace, + size_t workspaceSize, + bool accumulate, + bool use_split_accumulator, + int math_sm_count, + int m_split, + int n_split, + bool gemm_producer, + at::Tensor counter +); void fused_cast_transpose(at::Tensor input, at::Tensor scale, diff --git a/transformer_engine/pytorch/csrc/extensions/gemm.cu b/transformer_engine/pytorch/csrc/extensions/gemm.cu index 1a7630edce..480b8716b2 100644 --- a/transformer_engine/pytorch/csrc/extensions/gemm.cu +++ b/transformer_engine/pytorch/csrc/extensions/gemm.cu @@ -6,6 +6,7 @@ #include "extensions.h" + void te_gemm(at::Tensor A, at::Tensor A_scale_inverse, transformer_engine::DType A_type, @@ -73,3 +74,82 @@ void te_gemm(at::Tensor A, math_sm_count, at::cuda::getCurrentCUDAStream()); } + +void te_atomic_gemm(at::Tensor A, + at::Tensor A_scale_inverse, + transformer_engine::DType A_type, + bool transa, + at::Tensor B, + at::Tensor B_scale_inverse, + transformer_engine::DType B_type, + bool transb, + at::Tensor D, + at::Tensor D_scale, + transformer_engine::DType D_type, + at::Tensor D_amax, + at::Tensor bias, + transformer_engine::DType bias_type, + at::Tensor pre_gelu_out, + bool grad, + at::Tensor workspace, + size_t workspaceSize, + bool accumulate, + bool use_split_accumulator, + int math_sm_count, + int m_split, + int n_split, + bool gemm_producer, + at::Tensor counter +) { + using namespace transformer_engine; + auto te_A = makeTransformerEngineTensor(A.data_ptr(), + {static_cast(A.size(0)), + static_cast(A.size(1))}, + A_type, nullptr, nullptr, + A_scale_inverse.data_ptr()); + auto te_B = makeTransformerEngineTensor(B.data_ptr(), + {static_cast(B.size(0)), + static_cast(B.size(1))}, + B_type, nullptr, nullptr, + 
B_scale_inverse.data_ptr()); + auto te_D = makeTransformerEngineTensor(D.data_ptr(), + {static_cast(D.size(0)), + static_cast(D.size(1))}, + D_type, D_amax.data_ptr(), + D_scale.data_ptr(), nullptr); + auto te_bias = makeTransformerEngineTensor(bias.data_ptr(), {static_cast(bias.size(0))}, + bias_type); + auto te_counter = makeTransformerEngineTensor(counter.data_ptr(), + {static_cast(counter.size(0))}, + DType::kInt32); + + const auto gelu_shape = pre_gelu_out.data_ptr() == nullptr + ? std::vector{static_cast(pre_gelu_out.size(0))} + : std::vector{static_cast(pre_gelu_out.size(0)), + static_cast(pre_gelu_out.size(1))}; + auto te_pre_gelu_out = makeTransformerEngineTensor(pre_gelu_out.data_ptr(), + gelu_shape, + GetTransformerEngineDType( + pre_gelu_out.scalar_type())); + auto te_workspace = makeTransformerEngineTensor(workspace.data_ptr(), + {workspaceSize}, + DType::kByte); + + nvte_cublas_atomic_gemm(te_A.data(), + te_B.data(), + te_D.data(), + te_bias.data(), + te_pre_gelu_out.data(), + transa, + transb, + grad, + te_workspace.data(), + accumulate, + use_split_accumulator, + math_sm_count, + m_split, + n_split, + gemm_producer, + te_counter.data(), + at::cuda::getCurrentCUDAStream()); +} diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index abc15022b0..7e80299d15 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -91,18 +91,24 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { .value("BULK_OVERLAP_AG", ubuf::UBOverlapAlgo::BULK_OVERLAP_AG) .value("BULK_OVERLAP_RS", ubuf::UBOverlapAlgo::BULK_OVERLAP_RS) .value("SPLIT_PIPELINED_RS", ubuf::UBOverlapAlgo::SPLIT_PIPELINED_RS) - .value("SPLIT_PIPELINED_AG", ubuf::UBOverlapAlgo::SPLIT_PIPELINED_AG); + .value("SPLIT_PIPELINED_AG", ubuf::UBOverlapAlgo::SPLIT_PIPELINED_AG) + .value("ATOMIC_GEMM_RS", ubuf::UBOverlapAlgo::ATOMIC_GEMM_RS) + .value("ATOMIC_GEMM_AG", ubuf::UBOverlapAlgo::ATOMIC_GEMM_AG); py::class_(m, "UbufCommOverlap") - .def(py::init()) + .def(py::init()) .def("bulk_overlap", &ubuf::UbufCommOverlap::bulk_overlap) .def("split_overlap_rs", &ubuf::UbufCommOverlap::split_overlap_rs) + .def("set_ubuf_scale_inv", &ubuf::UbufCommOverlap::set_ubuf_scale_inv) + .def("atomic_gemm_overlap_rs", &ubuf::UbufCommOverlap::atomic_gemm_overlap_rs) + .def("is_fp8_ubuf", &ubuf::UbufCommOverlap::is_fp8_ubuf) .def("copy_input_to_ubuf", &ubuf::UbufCommOverlap::copy_input_to_ubuf) .def("get_ubuf_output", &ubuf::UbufCommOverlap::get_ubuf_output); py::class_(m, "UbufP2PCommOverlap") - .def(py::init()) + .def(py::init()) .def("split_overlap_ag", &ubuf::UbufP2PCommOverlap::split_overlap_ag) + .def("atomic_gemm_overlap_ag", &ubuf::UbufP2PCommOverlap::atomic_gemm_overlap_ag) .def("copy_input_to_ubuf", &ubuf::UbufP2PCommOverlap::copy_input_to_ubuf) .def("get_ubuf_output", &ubuf::UbufP2PCommOverlap::get_ubuf_output); #else // NVTE_WITH_USERBUFFERS diff --git a/transformer_engine/pytorch/csrc/userbuffers/userbuffers-host.cpp b/transformer_engine/pytorch/csrc/userbuffers/userbuffers-host.cpp index 59afc4b452..7c08070728 100644 --- a/transformer_engine/pytorch/csrc/userbuffers/userbuffers-host.cpp +++ b/transformer_engine/pytorch/csrc/userbuffers/userbuffers-host.cpp @@ -4,10 +4,13 @@ * See LICENSE for license information. 
************************************************************************/ +#include "userbuffers.h" #include +#include #include #include #include +#include #include #include #include @@ -15,9 +18,6 @@ #include #include #include -#include -#include -#include "userbuffers.h" static int oob_bcast(void *comm_context, void *buf, int size, int root) { MPI_Bcast(buf, size, MPI_BYTE, root, @@ -38,20 +38,31 @@ static int oob_gather(void *comm_context, int root, void *sbuf, void *rbuf, int int stringCmp(const void *a, const void *b) { return strcmp((const char *)a, (const char *)b); } -#define CUDACHECK(cmd) \ - do { \ - cudaError_t e = cmd; \ - if (e != cudaSuccess) { \ - printf("Failed: Cuda error %s:%d '%s'\n", __FILE__, __LINE__, cudaGetErrorString(e)); \ - exit(EXIT_FAILURE); \ - } \ +#define CUDACHECK(cmd) \ + do { \ + cudaError_t e = cmd; \ + if (e != cudaSuccess) { \ + printf("Failed: Cuda error %s:%d '%s'\n", __FILE__, __LINE__, cudaGetErrorString(e)); \ + exit(EXIT_FAILURE); \ + } \ } while (0) -#define NVTE_UB_ERROR(x) \ - do { \ - throw std::runtime_error(std::string(__FILE__ ":") + std::to_string(__LINE__) + \ - " in function " + __func__ + ": " + x); \ - } while (false) +#define CUCHECK(cmd) \ + do { \ + CUresult retval = cmd; \ + if (retval != CUDA_SUCCESS) { \ + const char *error_string; \ + cuGetErrorString(retval, &error_string); \ + printf("Failed: Cuda error %s:%d '%s'\n", __FILE__, __LINE__, error_string); \ + exit(EXIT_FAILURE); \ + } \ + } while (0); + +#define NVTE_UB_ERROR(x) \ + do { \ + throw std::runtime_error(std::string(__FILE__ ":") + std::to_string(__LINE__) + \ + " in function " + __func__ + ": " + x); \ + } while (false) int pipe_rank(communicator *comm, int step) { int mynode = comm->myrank / comm->nvsize; @@ -89,12 +100,14 @@ int create_communicator_grouped2(communicator **comm, int pipegpus, int pipenode (*comm)->push = 1; (*comm)->use_ce = 0; (*comm)->cga_size = 2; - for (int i = 0; i < userbuffers_op_types; i++) (*comm)->basecounter[i] = 0; + for (int i = 0; i < userbuffers_op_types; i++) + (*comm)->basecounter[i] = 0; (*comm)->head = 0; (*comm)->tail = 0; (*comm)->activeproxy = 1; (*comm)->active_nreqs = 0; - for (int i = 0; i < userbuffers_op_types; i++) (*comm)->active_req[i].active = -1; + for (int i = 0; i < userbuffers_op_types; i++) + (*comm)->active_req[i].active = -1; int ret = 0; // split communicator @@ -112,8 +125,10 @@ int create_communicator_grouped2(communicator **comm, int pipegpus, int pipenode color = 0; for (int n = 0; n < size; n++) { - if (n > 0 && strcmp(host_names[n - 1], host_names[n])) color++; - if (strcmp(host_name, host_names[n]) == 0) break; + if (n > 0 && strcmp(host_names[n - 1], host_names[n])) + color++; + if (strcmp(host_name, host_names[n]) == 0) + break; } free(host_names); @@ -128,14 +143,22 @@ int create_communicator_grouped2(communicator **comm, int pipegpus, int pipenode cpu_set_t cpuset; CPU_ZERO(&cpuset); int core; - if (mylocal == 0) core = 50; - if (mylocal == 1) core = 58; - if (mylocal == 2) core = 18; - if (mylocal == 3) core = 26; - if (mylocal == 4) core = 114; - if (mylocal == 5) core = 122; - if (mylocal == 6) core = 82; - if (mylocal == 7) core = 90; + if (mylocal == 0) + core = 50; + if (mylocal == 1) + core = 58; + if (mylocal == 2) + core = 18; + if (mylocal == 3) + core = 26; + if (mylocal == 4) + core = 114; + if (mylocal == 5) + core = 122; + if (mylocal == 6) + core = 82; + if (mylocal == 7) + core = 90; CPU_SET(core, &cpuset); if (!getenv("NVTE_NODOUBLE")) { @@ -144,7 +167,8 @@ int 
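The node-color loop above (used for the intra-node MPI_Comm_split) is compact enough to misread inside the diff. A small host-side sketch of the same logic, with a hypothetical hard-coded hostname list (illustrative only), shows the intent: count hostname changes in the gathered list until the rank's own host is found and use that count as the split color, so ranks on the same node get the same color and land in the same intra-node communicator.

// Illustrative sketch of the hostname -> color mapping used for the intra-node
// communicator split. Hostnames below are hypothetical; not part of the patch.
#include <cstdio>
#include <cstring>

static int node_color(const char *host_name, const char *host_names[], int size) {
  int color = 0;
  for (int n = 0; n < size; n++) {
    if (n > 0 && strcmp(host_names[n - 1], host_names[n]))
      color++;                                  // a different node name starts here
    if (strcmp(host_name, host_names[n]) == 0)
      break;                                    // stop once our own node is found
  }
  return color;
}

int main() {
  const char *gathered[] = {"node0", "node0", "node1", "node1"};  // one entry per rank
  for (int rank = 0; rank < 4; rank++)
    std::printf("rank %d on %s -> color %d\n", rank, gathered[rank],
                node_color(gathered[rank], gathered, 4));
  return 0;
}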
create_communicator_grouped2(communicator **comm, int pipegpus, int pipenode else CPU_SET(core + 128, &cpuset); } - if (getenv("NVTE_DOPIN")) pthread_setaffinity_np(pthread_self(), sizeof(cpu_set_t), &cpuset); + if (getenv("NVTE_DOPIN")) + pthread_setaffinity_np(pthread_self(), sizeof(cpu_set_t), &cpuset); if (ndev == numlocal) { // all visible devices if (cur_dev != mylocal) @@ -175,7 +199,8 @@ int create_communicator_grouped2(communicator **comm, int pipegpus, int pipenode int datanodegroup_id = myrank / numlocal / datanodes; // data reduction group node belongs, equals 0 for all if both // pipenodes=1 and tensornodes=1 - // mpi communicator only needed for SHARP which is always allreduce1/data-parallel + // mpi communicator only needed for SHARP which is always + // allreduce1/data-parallel MPI_Comm_split(MPI_COMM_WORLD, mylocal + numlocal * datanodegroup_id, rank, &(*comm)->comm_inter); // different rails from same group are in different subcommunicators @@ -192,19 +217,37 @@ int create_communicator_grouped2(communicator **comm, int pipegpus, int pipenode char *ib_dev_list; int ZIONROCE = getenv("NVTE_ZIONROCE") ? atoi(getenv("NVTE_ZIONROCE")) : 0; int ROCE = getenv("NVTE_ROCE") ? atoi(getenv("NVTE_ROCE")) : 0; - if (ZIONROCE) ROCE = 1; + if (ZIONROCE) + ROCE = 1; int DGX_H100 = device_prop.major == 9; switch (mylocal) { - case 0:ib_dev_list = "mlx5_0:1"; break; // NOLINT(*) - case 1:ib_dev_list = (char*)(DGX_H100?"mlx5_3:1":"mlx5_1:1"); break; // NOLINT(*) - case 2:ib_dev_list = (char*)(ZIONROCE?"mlx5_4:1":DGX_H100?"mlx5_4:1":"mlx5_2:1"); break; // NOLINT(*) - case 3:ib_dev_list = (char*)(DGX_H100?"mlx5_5:1":"mlx5_3:1"); break; // NOLINT(*) - case 4:ib_dev_list = (char*)(DGX_H100?"mlx5_6:1":"mlx5_6:1"); break; // NOLINT(*) - case 5:ib_dev_list = (char*)(DGX_H100?"mlx5_9:1":"mlx5_7:1"); break; // NOLINT(*) - case 6:ib_dev_list = (char*)(ZIONROCE?"mlx5_10:1":DGX_H100?"mlx5_10:1":"mlx5_8:1"); break; // NOLINT(*) - case 7:ib_dev_list = (char*)(DGX_H100?"mlx5_11:1":"mlx5_9:1"); break; // NOLINT(*) - default: break; + case 0: + ib_dev_list = "mlx5_0:1"; + break; // NOLINT(*) + case 1: + ib_dev_list = (char *)(DGX_H100 ? "mlx5_3:1" : "mlx5_1:1"); // NOLINT(*) + break; // NOLINT(*) + case 2: + ib_dev_list = (char *)(ZIONROCE ? "mlx5_4:1" : DGX_H100 ? "mlx5_4:1" : "mlx5_2:1"); // NOLINT(*) + break; // NOLINT(*) + case 3: + ib_dev_list = (char *)(DGX_H100 ? "mlx5_5:1" : "mlx5_3:1"); // NOLINT(*) + break; // NOLINT(*) + case 4: + ib_dev_list = (char *)(DGX_H100 ? "mlx5_6:1" : "mlx5_6:1"); // NOLINT(*) + break; // NOLINT(*) + case 5: + ib_dev_list = (char *)(DGX_H100 ? "mlx5_9:1" : "mlx5_7:1"); // NOLINT(*) + break; // NOLINT(*) + case 6: + ib_dev_list = (char *)(ZIONROCE ? "mlx5_10:1" : DGX_H100 ? "mlx5_10:1" : "mlx5_8:1"); // NOLINT(*) + break; // NOLINT(*) + case 7: + ib_dev_list = (char *)(DGX_H100 ? 
"mlx5_11:1" : "mlx5_9:1"); // NOLINT(*) + break; // NOLINT(*) + default: + break; } (*comm)->fifo = reinterpret_cast(malloc(sizeof(ub_request) * NVTE_MAX_REQUESTS)); @@ -215,7 +258,8 @@ int create_communicator_grouped2(communicator **comm, int pipegpus, int pipenode CUDACHECK(cudaMallocHost((void **)&(*comm)->hostflags, // NOLINT(*) (NVTE_MAX_SMS + 100) * sizeof(int))); - for (int i = 0; i < 100 + NVTE_MAX_SMS; i++) (*comm)->hostflags[i] = 0; + for (int i = 0; i < 100 + NVTE_MAX_SMS; i++) + (*comm)->hostflags[i] = 0; _mm_mfence(); sleep(1); @@ -223,13 +267,16 @@ int create_communicator_grouped2(communicator **comm, int pipegpus, int pipenode (*comm)->ibnvsize = (*comm)->nvsize; #define NBUF 2 + #define LOCALSIZE 4 * (NVTE_REG0_OFFSET(*comm) + NVTE_REG0_FLAGS + NVTE_REG0_COMMBUFFER * NBUF) // peer pointers + op flags + comm buffer - CUDACHECK(cudaMalloc(&(*comm)->gpu_ptrs, LOCALSIZE)); // flags and pointers, no block data yet + CUDACHECK(cudaMalloc(&(*comm)->gpu_ptrs, + LOCALSIZE)); // flags and pointers, no block data yet CUDACHECK(cudaMemset((*comm)->gpu_ptrs, 0, LOCALSIZE)); CUDACHECK(cudaDeviceSynchronize()); - register_user_buffer_collective(&((*comm)->gpu_ptrs), LOCALSIZE, *comm); // will use handler 0 + register_user_buffer_collective(&((*comm)->gpu_ptrs), LOCALSIZE, + *comm); // will use handler 0 CUDACHECK(cudaMalloc(&(*comm)->send_id, (*comm)->nranks * sizeof(int))); CUDACHECK(cudaMalloc(&(*comm)->recv_id, NVTE_MAX_REGIONS * (*comm)->nranks * sizeof(int))); CUDACHECK(cudaMemset((*comm)->send_id, 0, (*comm)->nranks * sizeof(int))); @@ -243,7 +290,6 @@ int create_communicator_grouped2(communicator **comm, int pipegpus, int pipenode #define GPU_PAGE_MASK (~GPU_PAGE_OFFSET) CUDACHECK(cudaMalloc(&(*comm)->flags, 2 * GPU_PAGE_SIZE)); unsigned int flag = 1; - // cuPointerSetAttribute(&flag, CU_POINTER_ATTRIBUTE_SYNC_MEMOPS, (CUdeviceptr)(*comm)->flags); CUDACHECK(cudaMemset((*comm)->flags, 0, 2 * GPU_PAGE_SIZE)); (*comm)->flags = reinterpret_cast(((CUdeviceptr)(*comm)->flags + GPU_PAGE_SIZE - 1) & GPU_PAGE_MASK); @@ -275,7 +321,8 @@ int create_communicator_grouped2(communicator **comm, int pipegpus, int pipenode pthread_attr_setschedparam(&attr, ¶m); if (getenv("NVTE_UBDEBUG")) - printf("%d/%d:(%d x %d): DP %d x %d TP %d x %d, DPGROUP %dx%d TPGROUP %dx%d PIPE_ID %d/%d\n", + printf("%d/%d:(%d x %d): DP %d x %d TP %d x %d, DPGROUP %dx%d TPGROUP " + "%dx%d PIPE_ID %d/%d\n", myrank, nranks, myrank / numlocal, myrank % numlocal, (*comm)->my_node, (*comm)->ar_nvrank, (*comm)->my2_node, (*comm)->ar2_nvrank, (*comm)->num_nodes, (*comm)->ar_nvsize, (*comm)->num2_nodes, (*comm)->ar2_nvsize, (*comm)->pipe_id, @@ -300,9 +347,9 @@ void destroy_communicator(communicator *comm) { } int register_user_buffer_collective(void **gpubuff, size_t bytes, communicator *comm, bool alloc) { - if (comm->free_region > NVTE_MAX_REGIONS) return -1; + if (comm->free_region > NVTE_MAX_REGIONS) + return -1; int hndl = comm->free_region; - // printf("%d register %d size %lld\n",comm->myrank,hndl,bytes);fflush(NULL); comm->peer_ptr[hndl] = reinterpret_cast(malloc(sizeof(void *) * (comm->nvsize))); if (alloc) { @@ -313,25 +360,22 @@ int register_user_buffer_collective(void **gpubuff, size_t bytes, communicator * reinterpret_cast(malloc(sizeof(cudaIpcMemHandle_t) * (comm->nvsize))); CUDACHECK(cudaIpcGetMemHandle(&memhndl[comm->nvrank], *gpubuff)); - MPI_Allgather(&memhndl[comm->nvrank], sizeof(cudaIpcMemHandle_t), MPI_BYTE, memhndl, sizeof(cudaIpcMemHandle_t), MPI_BYTE, comm->comm_intra); - for (int i = 0; i < comm->nvsize; 
i++) if (i != comm->nvrank) CUDACHECK(cudaIpcOpenMemHandle((void **)&(comm->peer_ptr[hndl][i]), // NOLINT(*) memhndl[i], cudaIpcMemLazyEnablePeerAccess)); comm->peer_ptr[hndl][comm->nvrank] = *gpubuff; CUDACHECK(cudaDeviceSynchronize()); - CUDACHECK( cudaMemcpy(reinterpret_cast(comm->gpu_ptrs) + (hndl * comm->nvsize * sizeof(void *)), comm->peer_ptr[hndl], comm->nvsize * sizeof(void *), cudaMemcpyHostToDevice)); - CUDACHECK(cudaDeviceSynchronize()); free(memhndl); comm->mem_ptr[hndl] = *gpubuff; + return comm->free_region++; } @@ -352,8 +396,10 @@ int allgather2_userbuff_inplace_gpu(const int maxcredit, const int handler, cons void allreduce_nonsharp_inplace(const int handler, const int offset, const int elements, communicator *comm, cudaStream_t stream, int op) { - if (elements < 64) NVTE_UB_ERROR("Userbuffer comm for given config not implemented."); - // if(comm->myrank==0) fprintf(stderr,"AR2(%d) user call launch_mode=%d\n",op,comm->launch_mode); + if (elements < 64) + NVTE_UB_ERROR("Userbuffer comm for given config not implemented."); + // if(comm->myrank==0) fprintf(stderr,"AR2(%d) user call + // launch_mode=%d\n",op,comm->launch_mode); const int ar_nvsize = op == userbuffers_allreduceop_nonsharp ? comm->ar_nvsize : comm->ar2_nvsize; int blocksize = elements * 2; int maxcredit = 0; @@ -361,19 +407,19 @@ void allreduce_nonsharp_inplace(const int handler, const int offset, const int e blocksize = (comm->nblocks - 1 + (comm->alignblock - 1 + elements * 2) / comm->alignblock) / comm->nblocks; // FIXME TUNING blocksize *= comm->alignblock; - if (blocksize < comm->minblock) blocksize = comm->minblock; + if (blocksize < comm->minblock) + blocksize = comm->minblock; maxcredit = (elements * 2 + blocksize - 1) / blocksize; - // if(maxcredit>4) maxcredit=4; - // if(maxcredit>4 && ar_nvsize==1) maxcredit=4; size_t peerblock = sizeof(int) * NVTE_REG0_COMMBUFFER / maxcredit; // max size we can fit - if (blocksize > peerblock * ar_nvsize) blocksize = peerblock * ar_nvsize; - // blocksize=elements*2; + if (blocksize > peerblock * ar_nvsize) + blocksize = peerblock * ar_nvsize; int sms = allreduce2_userbuff_inplace_gpu(maxcredit, handler, offset, elements, blocksize, comm, stream, op); if (num_nodes > 1 && comm->launch_mode & NVTE_LAUNCH_CPU) { - if (!sms) return; + if (!sms) + return; comm->fifo[comm->head].optype = op; comm->fifo[comm->head].basecounter = comm->basecounter[op]; comm->fifo[comm->head].blocksize = blocksize; @@ -399,7 +445,8 @@ void allreduce2_userbuff_inplace(const int handler, const int offset, const int void allreduce_userbuff_inplace(const int handler, const int offset, const int elements, communicator *comm, cudaStream_t stream) { - if (elements < 64) NVTE_UB_ERROR("Userbuffer comm for given config not implemented."); + if (elements < 64) + NVTE_UB_ERROR("Userbuffer comm for given config not implemented."); allreduce_nonsharp_inplace(handler, offset, elements, comm, stream, userbuffers_allreduceop_nonsharp); return; @@ -407,7 +454,8 @@ void allreduce_userbuff_inplace(const int handler, const int offset, const int e void reducescatter_userbuff_inplace(const int handler, const int offset, const int elements, communicator *comm, cudaStream_t stream) { - if (elements < 64) NVTE_UB_ERROR("Userbuffer comm for given config not implemented."); + if (elements < 64) + NVTE_UB_ERROR("Userbuffer comm for given config not implemented."); int op = userbuffers_allreduceop_nonsharp; const int ar_nvsize = op == userbuffers_allreduceop_nonsharp ? 
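register_user_buffer_collective above builds its peer-pointer table by exchanging CUDA IPC handles over the intra-node communicator. A minimal sketch of the same three steps, assuming one process per GPU on a single node and omitting most error handling (illustrative only, not part of the patch): export a handle for the local allocation, all-gather the handles, then open every peer handle into a local pointer.

// Illustrative sketch of the CUDA IPC exchange performed by
// register_user_buffer_collective. Assumes one GPU per rank on one node.
#include <cuda_runtime.h>
#include <mpi.h>
#include <cstdio>
#include <vector>

int main(int argc, char **argv) {
  MPI_Init(&argc, &argv);
  int rank, nranks;
  MPI_Comm_rank(MPI_COMM_WORLD, &rank);
  MPI_Comm_size(MPI_COMM_WORLD, &nranks);
  cudaSetDevice(rank);

  void *local = nullptr;
  cudaMalloc(&local, 1 << 20);                       // buffer being registered

  std::vector<cudaIpcMemHandle_t> handles(nranks);
  cudaIpcGetMemHandle(&handles[rank], local);        // 1. export local handle
  MPI_Allgather(MPI_IN_PLACE, 0, MPI_DATATYPE_NULL,  // 2. share handles with peers
                handles.data(), sizeof(cudaIpcMemHandle_t), MPI_BYTE, MPI_COMM_WORLD);

  std::vector<void *> peer_ptr(nranks, nullptr);     // 3. map every peer buffer
  for (int i = 0; i < nranks; i++)
    if (i != rank)
      cudaIpcOpenMemHandle(&peer_ptr[i], handles[i], cudaIpcMemLazyEnablePeerAccess);
  peer_ptr[rank] = local;

  std::printf("rank %d mapped %d peer buffers\n", rank, nranks - 1);
  MPI_Finalize();
  return 0;
}

The patch additionally copies the resulting pointer table into device memory (gpu_ptrs) so kernels can dereference peer buffers directly.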
comm->ar_nvsize : comm->ar2_nvsize; @@ -418,17 +466,20 @@ void reducescatter_userbuff_inplace(const int handler, const int offset, const i blocksize = (comm->nblocks - 1 + (comm->alignblock - 1 + elements * 2) / comm->alignblock) / comm->nblocks; // FIXME TUNING blocksize *= comm->alignblock; - if (blocksize < comm->minblock) blocksize = comm->minblock; + if (blocksize < comm->minblock) + blocksize = comm->minblock; maxcredit = (elements * 2 + blocksize - 1) / blocksize; size_t peerblock = sizeof(int) * NVTE_REG0_COMMBUFFER / maxcredit; // max size we can fit - if (blocksize > peerblock * ar_nvsize) blocksize = peerblock * ar_nvsize; + if (blocksize > peerblock * ar_nvsize) + blocksize = peerblock * ar_nvsize; int sms = reducescatter2_userbuff_inplace_gpu(maxcredit, handler, offset, elements, blocksize, comm, stream, op); if (num_nodes > 1 && comm->launch_mode & NVTE_LAUNCH_CPU) { - if (!sms) return; + if (!sms) + return; comm->fifo[comm->head].optype = op; comm->fifo[comm->head].basecounter = comm->basecounter[op]; comm->fifo[comm->head].blocksize = blocksize; @@ -448,7 +499,8 @@ void reducescatter_userbuff_inplace(const int handler, const int offset, const i void allgather_userbuff_inplace(const int handler, const int offset, const int elements, communicator *comm, cudaStream_t stream) { - if (elements < 64) NVTE_UB_ERROR("Userbuffer comm for given config not implemented."); + if (elements < 64) + NVTE_UB_ERROR("Userbuffer comm for given config not implemented."); int op = userbuffers_allreduceop_nonsharp; const int ar_nvsize = op == userbuffers_allreduceop_nonsharp ? comm->ar_nvsize : comm->ar2_nvsize; int blocksize = elements * 2; @@ -458,11 +510,13 @@ void allgather_userbuff_inplace(const int handler, const int offset, const int e blocksize = (comm->nblocks - 1 + (comm->alignblock - 1 + elements * 2) / comm->alignblock) / comm->nblocks; // FIXME TUNING blocksize *= comm->alignblock; - if (blocksize < comm->minblock) blocksize = comm->minblock; + if (blocksize < comm->minblock) + blocksize = comm->minblock; maxcredit = (elements * 2 + blocksize - 1) / blocksize; size_t peerblock = sizeof(int) * NVTE_REG0_COMMBUFFER / maxcredit; // max size we can fit - if (blocksize > peerblock * ar_nvsize) blocksize = peerblock * ar_nvsize; + if (blocksize > peerblock * ar_nvsize) + blocksize = peerblock * ar_nvsize; int sms = allgather2_userbuff_inplace_gpu(maxcredit, handler, offset, elements, blocksize, comm, stream, op); diff --git a/transformer_engine/pytorch/csrc/userbuffers/userbuffers.cu b/transformer_engine/pytorch/csrc/userbuffers/userbuffers.cu index 2c8e9dc61d..ecd17a45d7 100644 --- a/transformer_engine/pytorch/csrc/userbuffers/userbuffers.cu +++ b/transformer_engine/pytorch/csrc/userbuffers/userbuffers.cu @@ -12,22 +12,42 @@ #else #include #endif +#include "userbuffers.h" #include +#include #include -#include "userbuffers.h" #define MAX_THREADS 1024 #define TIMEOUT 200000000000ull -#define CUDACHECK(cmd) \ - do { \ - cudaError_t e = cmd; \ - if (e != cudaSuccess) { \ - printf("Failed: Cuda error %s:%d '%s'\n", __FILE__, __LINE__, cudaGetErrorString(e)); \ - exit(EXIT_FAILURE); \ - } \ +#define CUDACHECK(cmd) \ + do { \ + cudaError_t e = cmd; \ + if (e != cudaSuccess) { \ + printf("Failed: Cuda error %s:%d '%s'\n", __FILE__, __LINE__, cudaGetErrorString(e)); \ + exit(EXIT_FAILURE); \ + } \ } while (0) +#define ATOMIC_CONSUMER(chunk) \ + if (counters) { \ + if (threadIdx.x == 0 && blockIdx.x == 0) { \ + int old_val; \ + while (0 != (old_val = atomicCAS(((unsigned int *)counters) + chunk, 0, 
0))) { \ + } \ + ((unsigned int *)counters)[chunk] = 1; \ + asm volatile("fence.sc.gpu;\n"); \ + } \ + if (blockIdx.x == 0) \ + __syncthreads(); \ + } + +#define ATOMIC_PRODUCER(chunk) \ + if (counters) { \ + ((unsigned int *)counters)[chunk] = 0; \ + asm volatile("fence.sc.gpu;\n"); \ + } + template __global__ void __launch_bounds__(MAX_THREADS) userbuffers_fp16_sum_inplace_gpu_rw(const int op, const int flagoffset, const int firstrank, @@ -36,8 +56,7 @@ __global__ void __launch_bounds__(MAX_THREADS) __shared__ int4 *userptr[RANKS]; int *flagptr, physgpu, targetgpu, *myptr; int *reduceidptr, reduce_id; - // if(blockIdx.x==0 && threadIdx.x==0) printf("%d/%d(phys %d gpustep %d firstrank %d):RRkernel(d) - // start, size %lld\n",myrank,RANKS,gpustep*myrank+firstrank,gpustep,firstrank,numlines*16ull); + if (threadIdx.x < RANKS) { physgpu = myrank * gpustep + firstrank; targetgpu = threadIdx.x * gpustep + firstrank; @@ -66,7 +85,8 @@ __global__ void __launch_bounds__(MAX_THREADS) int warp = blockIdx.x + (threadIdx.x >> 5); int dest[RANKS]; #pragma unroll - for (int i = 0; i < RANKS; i++) dest[i] = (i + myrank + warp) & (RANKS - 1); + for (int i = 0; i < RANKS; i++) + dest[i] = (i + myrank + warp) & (RANKS - 1); __syncthreads(); for (int line = threadIdx.x + blockDim.x * (myrank + RANKS * blockIdx.x); line < numlines; @@ -86,7 +106,8 @@ __global__ void __launch_bounds__(MAX_THREADS) for (int i = 1; i < RANKS; i++) { half *x = reinterpret_cast(&val[i]); #pragma unroll - for (int j = 0; j < 8; j++) s[j] += x[j]; + for (int j = 0; j < 8; j++) + s[j] += x[j]; } #pragma unroll for (int i = 0; i < RANKS; i++) { @@ -96,7 +117,8 @@ __global__ void __launch_bounds__(MAX_THREADS) } __syncthreads(); - if (threadIdx.x == 0) __threadfence_system(); + if (threadIdx.x == 0) + __threadfence_system(); __syncthreads(); if (threadIdx.x < RANKS) { @@ -111,7 +133,8 @@ __global__ void __launch_bounds__(MAX_THREADS) } } } - if (threadIdx.x == 0 && blockIdx.x == 0) *reduceidptr = reduce_id; + if (threadIdx.x == 0 && blockIdx.x == 0) + *reduceidptr = reduce_id; } // fp16 inplace reduce kernel (Volta,Hopper) template @@ -150,7 +173,8 @@ __global__ void __launch_bounds__(MAX_THREADS) int warp = blockIdx.x + (threadIdx.x >> 5); int dest[RANKS]; #pragma unroll - for (int i = 0; i < RANKS; i++) dest[i] = (i + myrank + warp) & (RANKS - 1); + for (int i = 0; i < RANKS; i++) + dest[i] = (i + myrank + warp) & (RANKS - 1); __syncthreads(); for (int line = threadIdx.x + blockDim.x * (myrank + RANKS * blockIdx.x); line < numlines; @@ -169,13 +193,15 @@ __global__ void __launch_bounds__(MAX_THREADS) for (int i = 1; i < RANKS; i++) { half *x = reinterpret_cast(&val[i]); #pragma unroll - for (int j = 0; j < 8; j++) s[j] += x[j]; + for (int j = 0; j < 8; j++) + s[j] += x[j]; } userptr[myrank][lineoffset + line] = sum; } __syncthreads(); - if (threadIdx.x == 0) __threadfence(); + if (threadIdx.x == 0) + __threadfence(); __syncthreads(); if (threadIdx.x < RANKS) { @@ -217,7 +243,8 @@ __global__ void __launch_bounds__(MAX_THREADS) userptr[myrank][lineoffset + line + blockDim.x * dest[i]] = val[i]; } } - if (threadIdx.x == 0 && blockIdx.x == 0) *reduceidptr = reduce_id; + if (threadIdx.x == 0 && blockIdx.x == 0) + *reduceidptr = reduce_id; } // fp16 inplace reduce kernel (Ampere) template @@ -227,19 +254,19 @@ __global__ void __launch_bounds__(MAX_THREADS) const int mylineoffset, const int totallines, void **commbuff, const int handleridx) { __shared__ int4 *userptr[RANKS]; - int *flagptr, physgpu, targetgpu, *myptr; + volatile int 
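The ATOMIC_PRODUCER / ATOMIC_CONSUMER macros above implement a per-chunk handshake in which a slot value of 0 means "chunk ready": the producer stores 0 once its data is visible, and the consumer spins until it observes 0 and then re-arms the slot to 1. A minimal two-kernel sketch of the same handshake (illustrative only; it assumes the two single-block kernels can run concurrently, which the overlap code also relies on):

// Illustrative sketch of the counter handshake behind ATOMIC_PRODUCER /
// ATOMIC_CONSUMER. Not part of the patch; all names are hypothetical.
#include <cuda_runtime.h>
#include <cstdio>

__global__ void producer(int *data, unsigned int *counter) {
  data[0] = 42;             // "produce" the chunk
  __threadfence();          // make the data visible before signalling
  atomicExch(counter, 0u);  // 0 == ready
}

__global__ void consumer(const int *data, unsigned int *counter, int *out) {
  while (atomicCAS(counter, 0u, 0u) != 0u) {  // spin until the producer signals
  }
  *counter = 1u;            // re-arm the slot for the next chunk
  __threadfence();
  out[0] = data[0];
}

int main() {
  int *data, *out;
  unsigned int *counter;
  cudaMalloc(&data, sizeof(int));
  cudaMalloc(&out, sizeof(int));
  cudaMalloc(&counter, sizeof(unsigned int));
  const unsigned int not_ready = 1;
  cudaMemcpy(counter, &not_ready, sizeof(not_ready), cudaMemcpyHostToDevice);

  cudaStream_t s_prod, s_cons;
  cudaStreamCreate(&s_prod);
  cudaStreamCreate(&s_cons);
  consumer<<<1, 1, 0, s_cons>>>(data, counter, out);  // starts spinning
  producer<<<1, 1, 0, s_prod>>>(data, counter);       // releases it
  cudaDeviceSynchronize();

  int result = 0;
  cudaMemcpy(&result, out, sizeof(result), cudaMemcpyDeviceToHost);
  std::printf("consumer read %d\n", result);
  return 0;
}

In the patch the same idea is applied per GEMM chunk, with the counters indexed by chunk id rather than a single flag.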
*flagptr; + int physgpu, targetgpu, *myptr; int *reduceidptr, reduce_id; + int lastSM = 0; if (threadIdx.x < RANKS) { physgpu = myrank * gpustep + firstrank; targetgpu = threadIdx.x * gpustep + firstrank; - const int blockflagoffset = NVTE_MAX_NVLINK * 2 * blockIdx.x; myptr = (reinterpret_cast(commbuff[physgpu])) + flagoffset; reduceidptr = myptr - NVTE_MAX_OPS; // +op; reduce_id = (*reduceidptr) + 1; - flagptr = (reinterpret_cast(commbuff[targetgpu])) + flagoffset + blockflagoffset; - myptr += blockflagoffset; - - flagptr[physgpu] = reduce_id; + flagptr = (reinterpret_cast(commbuff[targetgpu])) + flagoffset; + if (blockIdx.x == 0) + flagptr[physgpu] = reduce_id; volatile int *flag = (volatile int *)&(myptr[targetgpu]); userptr[threadIdx.x] = reinterpret_cast(commbuff[targetgpu + handleridx]); clock_t s = clock64(); @@ -252,11 +279,18 @@ __global__ void __launch_bounds__(MAX_THREADS) } } __syncthreads(); + if (threadIdx.x == 0) { + const int adder = blockIdx.x == 0 ? NVTE_MAX_SMS - gridDim.x + 1 : 1; + int old_val = atomicAdd(myptr + (NVTE_MAX_NVLINK * 2), adder); + if (old_val + adder == NVTE_MAX_SMS * reduce_id) + lastSM = 1; + } int warp = blockIdx.x + (threadIdx.x >> 5); int dest[RANKS]; #pragma unroll - for (int i = 0; i < RANKS; i++) dest[i] = (i + myrank + warp) & (RANKS - 1); + for (int i = 0; i < RANKS; i++) + dest[i] = (i + myrank + warp) & (RANKS - 1); __syncthreads(); for (int line = threadIdx.x + blockDim.x * blockIdx.x; line < totallines; @@ -275,13 +309,15 @@ __global__ void __launch_bounds__(MAX_THREADS) for (int i = 1; i < RANKS; i++) { half *x = reinterpret_cast(&val[i]); #pragma unroll - for (int j = 0; j < 8; j++) s[j] += x[j]; + for (int j = 0; j < 8; j++) + s[j] += x[j]; } userptr[myrank][mylineoffset + line] = sum; } - if (threadIdx.x == 0 && blockIdx.x == 0) *reduceidptr = reduce_id; + if (threadIdx.x == 0 && lastSM) + *reduceidptr = reduce_id; } // fp16 inplace reduce-scatter kernel template @@ -293,19 +329,19 @@ __global__ void __launch_bounds__(MAX_THREADS) const int skiplines, void **commbuff, const int handleridx, void *outbuf) { __shared__ int4 *userptr[RANKS]; - int *flagptr, physgpu, targetgpu, *myptr; + volatile int *flagptr; + int physgpu, targetgpu, *myptr; int *reduceidptr, reduce_id; + int lastSM = 0; if (threadIdx.x < RANKS) { physgpu = myrank * gpustep + firstrank; targetgpu = threadIdx.x * gpustep + firstrank; - const int blockflagoffset = NVTE_MAX_NVLINK * 2 * blockIdx.x; myptr = (reinterpret_cast(commbuff[physgpu])) + flagoffset; reduceidptr = myptr - NVTE_MAX_OPS; // +op; reduce_id = (*reduceidptr) + 1; - flagptr = (reinterpret_cast(commbuff[targetgpu])) + flagoffset + blockflagoffset; - myptr += blockflagoffset; - - flagptr[physgpu] = reduce_id; + flagptr = (reinterpret_cast(commbuff[targetgpu])) + flagoffset; + if (blockIdx.x == 0) + flagptr[physgpu] = reduce_id; volatile int *flag = (volatile int *)&(myptr[targetgpu]); userptr[threadIdx.x] = reinterpret_cast(commbuff[targetgpu + handleridx]); clock_t s = clock64(); @@ -318,11 +354,18 @@ __global__ void __launch_bounds__(MAX_THREADS) } } __syncthreads(); + if (threadIdx.x == 0) { + const int adder = blockIdx.x == 0 ? 
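The adder/atomicAdd bookkeeping introduced above replaces the earlier per-block flag offsets: every thread block adds to a persistent counter in such a way that one full launch always contributes exactly NVTE_MAX_SMS, so the block that sees the counter reach NVTE_MAX_SMS * reduce_id is the last one to finish and is the only one that publishes reduce_id. A stripped-down sketch of that accounting (illustrative only; MAX_BLOCKS stands in for NVTE_MAX_SMS and all other names are hypothetical):

// Illustrative sketch of the lastSM accounting: one launch always contributes
// exactly MAX_BLOCKS to a persistent arrival counter.
#include <cuda_runtime.h>
#include <cstdio>

constexpr int MAX_BLOCKS = 32;

__global__ void collective_step(unsigned int *arrivals, int *last_block, int call_id) {
  // ... the per-block work for this call would go here ...
  if (threadIdx.x == 0) {
    const unsigned int adder = (blockIdx.x == 0) ? MAX_BLOCKS - gridDim.x + 1 : 1;
    const unsigned int old = atomicAdd(arrivals, adder);
    if (old + adder == static_cast<unsigned int>(MAX_BLOCKS) * call_id)
      *last_block = blockIdx.x;  // only the last arriving block reaches this
  }
}

int main() {
  unsigned int *arrivals;
  int *last_block;
  cudaMalloc(&arrivals, sizeof(unsigned int));
  cudaMalloc(&last_block, sizeof(int));
  cudaMemset(arrivals, 0, sizeof(unsigned int));

  for (int call_id = 1; call_id <= 2; ++call_id) {
    collective_step<<<8, 128>>>(arrivals, last_block, call_id);  // 8 <= MAX_BLOCKS
    cudaDeviceSynchronize();
    int h = -1;
    cudaMemcpy(&h, last_block, sizeof(int), cudaMemcpyDeviceToHost);
    std::printf("call %d: last block to finish was %d\n", call_id, h);
  }
  return 0;
}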
NVTE_MAX_SMS - gridDim.x + 1 : 1; + int old_val = atomicAdd(myptr + (NVTE_MAX_NVLINK * 2), adder); + if (old_val + adder == NVTE_MAX_SMS * reduce_id) + lastSM = 1; + } int warp = blockIdx.x + (threadIdx.x >> 5); int dest[RANKS]; #pragma unroll - for (int i = 0; i < RANKS; i++) dest[i] = (i + myrank + warp) & (RANKS - 1); + for (int i = 0; i < RANKS; i++) + dest[i] = (i + myrank + warp) & (RANKS - 1); __syncthreads(); for (int line = threadIdx.x + blockDim.x * blockIdx.x; line < totallines; @@ -341,24 +384,28 @@ __global__ void __launch_bounds__(MAX_THREADS) for (int i = 1; i < RANKS; i++) { half *x = reinterpret_cast(&val[i]); #pragma unroll - for (int j = 0; j < 8; j++) s[j] += x[j]; + for (int j = 0; j < 8; j++) + s[j] += x[j]; } (reinterpret_cast(outbuf))[(line / rowlines) * skiplines + (line % rowlines)] = sum; } - if (threadIdx.x == 0 && blockIdx.x == 0) *reduceidptr = reduce_id; + if (threadIdx.x == 0 && lastSM) + *reduceidptr = reduce_id; } // fp16 reduce-scatter kernel (out of place) +#if 0 +// All MC kernels here template __global__ void __launch_bounds__(MAX_THREADS) - userbuffers_fp16_sum_inplace_gpu_rr_ag(const int op, const int flagoffset, const int firstrank, - const int myrank, const int gpustep, - const int mylineoffset, const int totallines, - void **commbuff, const int handleridx) { - __shared__ int4 *userptr[RANKS]; + userbuffers_fp16_sum_inplace_gpu_mc(const int op, const int flagoffset, const int firstrank, + const int myrank, const int gpustep, const int lineoffset, + const int numlines, void **commbuff, const int handleridx, + float4 *mc_ptr) { int *flagptr, physgpu, targetgpu, *myptr; int *reduceidptr, reduce_id; + if (threadIdx.x < RANKS) { physgpu = myrank * gpustep + firstrank; targetgpu = threadIdx.x * gpustep + firstrank; @@ -371,114 +418,322 @@ __global__ void __launch_bounds__(MAX_THREADS) flagptr[physgpu] = reduce_id; volatile int *flag = (volatile int *)&(myptr[targetgpu]); - userptr[threadIdx.x] = reinterpret_cast(commbuff[targetgpu + handleridx]); clock_t s = clock64(); - } - - int warp = blockIdx.x + (threadIdx.x >> 5); - int dest[RANKS]; - - int skipmy = 0; -#pragma unroll - for (int i = 0; i < RANKS; i++) { - int dst = (i + warp + myrank) & (RANKS - 1); - if (dst == myrank) { - skipmy++; - continue; + while (*flag < reduce_id) { + if (clock64() - s > TIMEOUT) { + printf("NVONLY RSBAR:SM %d [%d]:expecting %d got %d\n", blockIdx.x, threadIdx.x, reduce_id, + *flag); + break; + } } - dest[i - skipmy] = dst; + reduce_id++; } __syncthreads(); +#define UNROLL_MC 8 + const int loop_step0 = blockDim.x * gridDim.x * RANKS; + const int loop_step = loop_step0 * UNROLL_MC; + const int start_elem = threadIdx.x + blockDim.x * (myrank + RANKS * blockIdx.x); + const int end_elem = max(start_elem, numlines); + const int aligned_elem = ((end_elem - start_elem) / loop_step) * loop_step; + const int end_aligned = start_elem + aligned_elem; - for (int line = threadIdx.x + blockDim.x * blockIdx.x; line < totallines; - line += blockDim.x * gridDim.x) { - int4 val[RANKS - 1]; - + for (int line = start_elem; line < end_aligned; line += loop_step) { + uint4 val[UNROLL_MC]; #pragma unroll - for (int i = 0; i < RANKS - 1; i++) { - val[i] = userptr[dest[i]][mylineoffset + line + totallines * dest[i]]; - } - + for (int i = 0; i < UNROLL_MC; i++) +#if defined(NVTE_UB_FP16) + asm("multimem.ld_reduce.global.add.v4.f16x2 {%0,%1,%2,%3}, [%4];" + : "=r"(val[i].x), "=r"(val[i].y), "=r"(val[i].z), "=r"(val[i].w) + : "l"(mc_ptr + (lineoffset + line + i * loop_step0)) + : "memory"); +#else + 
asm("multimem.ld_reduce.global.add.v4.bf16x2 {%0,%1,%2,%3}, [%4];" + : "=r"(val[i].x), "=r"(val[i].y), "=r"(val[i].z), "=r"(val[i].w) + : "l"(mc_ptr + (lineoffset + line + i * loop_step0)) + : "memory"); +#endif #pragma unroll - for (int i = 0; i < RANKS - 1; i++) { - userptr[myrank][mylineoffset + line + totallines * dest[i]] = val[i]; + for (int i = 0; i < UNROLL_MC; i++) + asm volatile("multimem.st.global.v4.f32 [%0], {%1,%2,%3,%4};" ::"l"( + mc_ptr + (lineoffset + line + i * loop_step0)), + "r"(val[i].x), "r"(val[i].y), "r"(val[i].z), "r"(val[i].w) + : "memory"); + } + for (int line = end_aligned; line < end_elem; line += loop_step0) { + uint4 val; +#if defined(NVTE_UB_FP16) + asm("multimem.ld_reduce.global.add.v4.f16x2 {%0,%1,%2,%3}, [%4];" + : "=r"(val.x), "=r"(val.y), "=r"(val.z), "=r"(val.w) + : "l"(mc_ptr + (lineoffset + line)) + : "memory"); +#else + asm("multimem.ld_reduce.global.add.v4.bf16x2 {%0,%1,%2,%3}, [%4];" + : "=r"(val.x), "=r"(val.y), "=r"(val.z), "=r"(val.w) + : "l"(mc_ptr + (lineoffset + line)) + : "memory"); +#endif + asm volatile( + "multimem.st.global.v4.f32 [%0], {%1,%2,%3,%4};" ::"l"(mc_ptr + (lineoffset + line)), + "r"(val.x), "r"(val.y), "r"(val.z), "r"(val.w) + : "memory"); + } + + __syncthreads(); + if (threadIdx.x == 0) + __threadfence_system(); + __syncthreads(); + + if (threadIdx.x < RANKS) { + flagptr[physgpu] = reduce_id; + volatile int *flag = (volatile int *)&myptr[targetgpu]; + clock_t s = clock64(); + while (*flag < reduce_id) { + if (clock64() - s > 2ull * TIMEOUT) { + printf("NVONLY AGBAR:SM %d [%d]:expecting %d got %d\n", blockIdx.x, threadIdx.x, reduce_id, + *flag); + break; + } } } - if (threadIdx.x == 0 && blockIdx.x == 0) *reduceidptr = reduce_id; -} // fp16 inplace reduce kernel (Ampere) + if (threadIdx.x == 0 && blockIdx.x == 0) + *reduceidptr = reduce_id; +} // fp16 inplace reduce kernel (Hopper) MC template __global__ void __launch_bounds__(MAX_THREADS) - userbuffers_fp16_sum_inplace_gpu_rw_ag(const int op, const int flagoffset, const int firstrank, + userbuffers_fp16_sum_inplace_gpu_mc_rs(const int op, const int flagoffset, const int firstrank, const int myrank, const int gpustep, const int mylineoffset, const int totallines, - void **commbuff, const int handleridx) { - __shared__ int4 *userptr[RANKS]; - int *flagptr, physgpu, targetgpu, *myptr; + void **commbuff, const int handleridx, float4 *mc_ptr) { + volatile int *flagptr; + int physgpu, targetgpu, *myptr; int *reduceidptr, reduce_id; - int4 *localptr; + uint4 *localptr = reinterpret_cast(commbuff[myrank * gpustep + firstrank + handleridx]); + int lastSM = 0; + if (threadIdx.x < RANKS) { physgpu = myrank * gpustep + firstrank; targetgpu = threadIdx.x * gpustep + firstrank; - const int blockflagoffset = NVTE_MAX_NVLINK * 2 * blockIdx.x; myptr = (reinterpret_cast(commbuff[physgpu])) + flagoffset; reduceidptr = myptr - NVTE_MAX_OPS; // +op; reduce_id = (*reduceidptr) + 1; - flagptr = (reinterpret_cast(commbuff[targetgpu])) + flagoffset + blockflagoffset; - myptr += blockflagoffset; - userptr[threadIdx.x] = reinterpret_cast(commbuff[targetgpu + handleridx]); - reduce_id++; + flagptr = (reinterpret_cast(commbuff[targetgpu])) + flagoffset; + if (blockIdx.x == 0) + flagptr[physgpu] = reduce_id; + volatile int *flag = (volatile int *)&(myptr[targetgpu]); + clock_t s = clock64(); + while (*flag < reduce_id) { + if (clock64() - s > TIMEOUT) { + printf("NVONLY RSBAR:SM %d [%d]:expecting %d got %d\n", blockIdx.x, threadIdx.x, reduce_id, + *flag); + break; + } + } } __syncthreads(); - localptr = 
userptr[myrank]; + if (threadIdx.x == 0) { + const int adder = blockIdx.x == 0 ? NVTE_MAX_SMS - gridDim.x + 1 : 1; + int old_val = atomicAdd(myptr + (NVTE_MAX_NVLINK * 2), adder); + if (old_val + adder == NVTE_MAX_SMS * reduce_id) + lastSM = 1; + } + const int loop_step0 = blockDim.x * gridDim.x; + const int loop_step = loop_step0 * UNROLL_MC; + const int start_elem = threadIdx.x + blockDim.x * blockIdx.x; + const int end_elem = max(start_elem, totallines); + const int aligned_elem = ((end_elem - start_elem) / loop_step) * loop_step; + const int end_aligned = start_elem + aligned_elem; - int warp = blockIdx.x + (threadIdx.x >> 5); - int dest[RANKS - 1]; - int skipmy = 0; + for (int line = start_elem; line < end_aligned; line += loop_step) { + uint4 val[UNROLL_MC]; #pragma unroll - for (int i = 0; i < RANKS; i++) { - int dst = (i + warp + myrank) & (RANKS - 1); - if (dst == myrank) { - skipmy++; - continue; + for (int i = 0; i < UNROLL_MC; i++) +#if defined(NVTE_UB_FP16) + asm("multimem.ld_reduce.global.add.v4.f16x2 {%0,%1,%2,%3}, [%4];" + : "=r"(val[i].x), "=r"(val[i].y), "=r"(val[i].z), "=r"(val[i].w) + : "l"(mc_ptr + (mylineoffset + line + i * loop_step0)) + : "memory"); +#else + asm("multimem.ld_reduce.global.add.v4.bf16x2 {%0,%1,%2,%3}, [%4];" + : "=r"(val[i].x), "=r"(val[i].y), "=r"(val[i].z), "=r"(val[i].w) + : "l"(mc_ptr + (mylineoffset + line + i * loop_step0)) + : "memory"); +#endif +#pragma unroll + for (int i = 0; i < UNROLL_MC; i++) + localptr[mylineoffset + line + i * loop_step0] = val[i]; + } + for (int line = end_aligned; line < end_elem; line += loop_step0) { + uint4 val; +#if defined(NVTE_UB_FP16) + asm("multimem.ld_reduce.global.add.v4.f16x2 {%0,%1,%2,%3}, [%4];" + : "=r"(val.x), "=r"(val.y), "=r"(val.z), "=r"(val.w) + : "l"(mc_ptr + (mylineoffset + line)) + : "memory"); +#else + asm("multimem.ld_reduce.global.add.v4.bf16x2 {%0,%1,%2,%3}, [%4];" + : "=r"(val.x), "=r"(val.y), "=r"(val.z), "=r"(val.w) + : "l"(mc_ptr + (mylineoffset + line)) + : "memory"); +#endif + localptr[mylineoffset + line] = val; + } + + if (threadIdx.x == 0 && lastSM) + *reduceidptr = reduce_id; +} // fp16 inplace reduce-scatter kernel MC + +template +__global__ void __launch_bounds__(MAX_THREADS) + userbuffers_fp16_sum_inplace_gpu_mc_rs_oop(const int op, const int flagoffset, + const int firstrank, const int myrank, + const int gpustep, const int mylineoffset, + const int totallines, const int rowlines, + const int skiplines, void **commbuff, + const int handleridx, void *outbuf, float4 *mc_ptr) { + volatile int *flagptr; + int physgpu, targetgpu, *myptr; + int *reduceidptr, reduce_id; + int lastSM = 0; + + if (threadIdx.x < RANKS) { + physgpu = myrank * gpustep + firstrank; + targetgpu = threadIdx.x * gpustep + firstrank; + myptr = (reinterpret_cast(commbuff[physgpu])) + flagoffset; + reduceidptr = myptr - NVTE_MAX_OPS; // +op; + reduce_id = (*reduceidptr) + 1; + flagptr = (reinterpret_cast(commbuff[targetgpu])) + flagoffset; + if (blockIdx.x == 0) + flagptr[physgpu] = reduce_id; + volatile int *flag = (volatile int *)&(myptr[targetgpu]); + clock_t s = clock64(); + while (*flag < reduce_id) { + if (clock64() - s > TIMEOUT) { + printf("[%d] NVONLY RSBAR:SM %d [%d]:expecting %d got %d\n", myrank, blockIdx.x, + threadIdx.x, reduce_id, *flag); + break; + } } - dest[i - skipmy] = dst; } -#define UNROLLAG 4 __syncthreads(); + if (threadIdx.x == 0) { + const int adder = blockIdx.x == 0 ? 
NVTE_MAX_SMS - gridDim.x + 1 : 1; + int old_val = atomicAdd(myptr + (NVTE_MAX_NVLINK * 2), adder); + if (old_val + adder == NVTE_MAX_SMS * reduce_id) + lastSM = 1; + } + const int loop_step0 = blockDim.x * gridDim.x; - const int loop_step = loop_step0 * UNROLLAG; + const int loop_step = loop_step0 * UNROLL_MC; const int start_elem = threadIdx.x + blockDim.x * blockIdx.x; const int end_elem = max(start_elem, totallines); const int aligned_elem = ((end_elem - start_elem) / loop_step) * loop_step; const int end_aligned = start_elem + aligned_elem; - for (int line = start_elem; line < end_aligned; line += loop_step) { - int4 val[UNROLLAG]; + uint4 val[UNROLL_MC]; +#pragma unroll + for (int i = 0; i < UNROLL_MC; i++) +#if defined(NVTE_UB_FP16) + asm("multimem.ld_reduce.global.add.v4.f16x2 {%0,%1,%2,%3}, [%4];" + : "=r"(val[i].x), "=r"(val[i].y), "=r"(val[i].z), "=r"(val[i].w) + : "l"(mc_ptr + (mylineoffset + line + i * loop_step0)) + : "memory"); +#else + asm("multimem.ld_reduce.global.add.v4.bf16x2 {%0,%1,%2,%3}, [%4];" + : "=r"(val[i].x), "=r"(val[i].y), "=r"(val[i].z), "=r"(val[i].w) + : "l"(mc_ptr + (mylineoffset + line + i * loop_step0)) + : "memory"); +#endif #pragma unroll - for (int j = 0; j < UNROLLAG; j++) val[j] = localptr[mylineoffset + line + loop_step0 * j]; + for (int i = 0; i < UNROLL_MC; i++) + (reinterpret_cast(outbuf))[((line + i * loop_step0) / rowlines) * skiplines + + ((line + i * loop_step0) % rowlines)] = val[i]; + } + for (int line = end_aligned; line < end_elem; line += loop_step0) { + uint4 val; +#if defined(NVTE_UB_FP16) + asm("multimem.ld_reduce.global.add.v4.f16x2 {%0,%1,%2,%3}, [%4];" + : "=r"(val.x), "=r"(val.y), "=r"(val.z), "=r"(val.w) + : "l"(mc_ptr + (mylineoffset + line)) + : "memory"); +#else + asm("multimem.ld_reduce.global.add.v4.bf16x2 {%0,%1,%2,%3}, [%4];" + : "=r"(val.x), "=r"(val.y), "=r"(val.z), "=r"(val.w) + : "l"(mc_ptr + (mylineoffset + line)) + : "memory"); +#endif + reinterpret_cast (outbuf)[(line / rowlines) * skiplines + (line % rowlines)] = val; + } + + if (threadIdx.x == 0 && lastSM) + *reduceidptr = reduce_id; +} // fp16 reduce-scatter kernel (out of place) fp16 MC + +template +__global__ void __launch_bounds__(MAX_THREADS) + userbuffers_fp16_sum_inplace_gpu_mc_ag(const int op, const int flagoffset, const int firstrank, + const int myrank, const int gpustep, + const int mylineoffset, const int totallines, + void **commbuff, const int handleridx, uint4 *mc_ptr) { + volatile int *flagptr; + int physgpu, targetgpu, *myptr; + int *reduceidptr, reduce_id; + uint4 *localptr = reinterpret_cast(commbuff[myrank * gpustep + firstrank + handleridx]); + + if (threadIdx.x < RANKS) { + physgpu = myrank * gpustep + firstrank; + targetgpu = threadIdx.x * gpustep + firstrank; + myptr = (reinterpret_cast(commbuff[physgpu])) + flagoffset; + reduceidptr = myptr - NVTE_MAX_OPS; // +op; + reduce_id = (*reduceidptr) + 1; + flagptr = (reinterpret_cast(commbuff[targetgpu])) + flagoffset; + } + __syncthreads(); + const int loop_step0 = blockDim.x * gridDim.x; + const int loop_step = loop_step0 * UNROLL_MC; + const int start_elem = threadIdx.x + blockDim.x * blockIdx.x; + const int end_elem = max(start_elem, totallines); + const int aligned_elem = ((end_elem - start_elem) / loop_step) * loop_step; + const int end_aligned = start_elem + aligned_elem; + for (int line = start_elem; line < end_aligned; line += loop_step) { + uint4 val[UNROLL_MC]; #pragma unroll - for (int j = 0; j < UNROLLAG; j++) + for (int i = 0; i < UNROLL_MC; i++) + val[i] = localptr[mylineoffset + 
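The UNROLL_MC loops above all share one structure: a bulk loop that issues UNROLL_MC independent multimem loads per iteration over the largest span that divides evenly into the unrolled stride, followed by a plain grid-stride remainder loop. A simplified sketch of that loop shape on ordinary global memory (illustrative only; the multimem ld_reduce/st instructions themselves require an sm_90 multicast mapping and are not reproduced here):

// Illustrative sketch of the unroll-plus-remainder loop shape used by the
// UNROLL_MC kernels above, applied to a simple scale kernel.
#include <cuda_runtime.h>
#include <cstdio>

__global__ void scale_unrolled(const float *in, float *out, int n, float alpha) {
  constexpr int UNROLL = 4;
  const int loop_step0 = blockDim.x * gridDim.x;
  const int loop_step = loop_step0 * UNROLL;
  const int start_elem = threadIdx.x + blockDim.x * blockIdx.x;
  const int end_elem = max(start_elem, n);
  const int end_aligned = start_elem + ((end_elem - start_elem) / loop_step) * loop_step;

  // Bulk: each iteration issues UNROLL independent loads before any store.
  for (int i = start_elem; i < end_aligned; i += loop_step) {
    float v[UNROLL];
#pragma unroll
    for (int j = 0; j < UNROLL; ++j) v[j] = in[i + j * loop_step0];
#pragma unroll
    for (int j = 0; j < UNROLL; ++j) out[i + j * loop_step0] = alpha * v[j];
  }
  // Remainder: plain grid-stride loop for the tail that cannot be unrolled.
  for (int i = end_aligned; i < end_elem; i += loop_step0)
    out[i] = alpha * in[i];
}

int main() {
  const int n = 10000;
  float *in, *out;
  cudaMallocManaged(&in, n * sizeof(float));
  cudaMallocManaged(&out, n * sizeof(float));
  for (int i = 0; i < n; ++i) in[i] = 1.0f;
  scale_unrolled<<<8, 256>>>(in, out, n, 2.0f);
  cudaDeviceSynchronize();
  std::printf("out[0]=%f out[%d]=%f\n", out[0], n - 1, out[n - 1]);
  return 0;
}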
line + i * loop_step0]; #pragma unroll - for (int i = 0; i < RANKS - 1; i++) { - userptr[dest[i]][mylineoffset + line + j * loop_step0] = val[j]; - } + for (int i = 0; i < UNROLL_MC; i++) + asm volatile("multimem.st.global.v4.f32 [%0], {%1,%2,%3,%4};" ::"l"( + mc_ptr + (mylineoffset + line + i * loop_step0)), + "r"(val[i].x), "r"(val[i].y), "r"(val[i].z), "r"(val[i].w) + : "memory"); } - for (int line = end_aligned; line < end_elem; line += loop_step0) { - int4 sum = localptr[mylineoffset + line]; -#pragma unroll - for (int i = 0; i < RANKS - 1; i++) { - userptr[dest[i]][mylineoffset + line] = sum; - } + uint4 val = localptr[mylineoffset + line]; + asm volatile( + "multimem.st.global.v4.f32 [%0], {%1,%2,%3,%4};" ::"l"(mc_ptr + (mylineoffset + line)), + "r"(val.x), "r"(val.y), "r"(val.z), "r"(val.w) + : "memory"); } __syncthreads(); - if (threadIdx.x == 0) __threadfence_system(); + if (threadIdx.x == 0) + __threadfence_system(); __syncthreads(); - if (threadIdx.x < RANKS) { + __shared__ int lastSM; + if (threadIdx.x == 0) { + const int adder = blockIdx.x == 0 ? NVTE_MAX_SMS - gridDim.x + 1 : 1; + int old_val = atomicAdd(myptr + (NVTE_MAX_NVLINK * 2), adder); + if (old_val + adder == NVTE_MAX_SMS * reduce_id) + lastSM = 1; + else + lastSM = 0; + } + __syncthreads(); + if (lastSM && threadIdx.x < RANKS) { + if (threadIdx.x == 0) + *reduceidptr = reduce_id; flagptr[physgpu] = reduce_id; volatile int *flag = (volatile int *)&myptr[targetgpu]; clock_t s = clock64(); @@ -490,229 +745,983 @@ __global__ void __launch_bounds__(MAX_THREADS) } } } - if (threadIdx.x == 0 && blockIdx.x == 0) *reduceidptr = reduce_id; -} // fp16 inplace allgather kernel (Volta,Hopper) +} // fp16 inplace allgather kernel (Hopper) MC +#else template __global__ void __launch_bounds__(MAX_THREADS) - userbuffers_fp16_sum_inplace_gpu_rr_blocked(const int op, const int flagoffset, - const int firstrank, const int myrank, - const int lineoffset, const int numlines, - void **commbuff, const int handleridx, - const int peerblocklines, int *hostflags, - int *gpuflag, const int numblocks) { - const int basecounter = gpuflag[NVTE_GF_STATE + op]; + userbuffers_fp16_sum_inplace_gpu_mc(const int op, const int flagoffset, const int firstrank, + const int myrank, const int gpustep, const int lineoffset, + const int numlines, void **commbuff, const int handleridx, + float4 *mc_ptr) {} +template +__global__ void __launch_bounds__(MAX_THREADS) userbuffers_fp16_sum_inplace_gpu_mc_rs_oop( + const int op, const int flagoffset, const int firstrank, const int myrank, const int gpustep, + const int mylineoffset, const int totallines, const int rowlines, const int skiplines, + void **commbuff, const int handleridx, void *outbuf, float4 *mc_ptr) {} +template +__global__ void __launch_bounds__(MAX_THREADS) + userbuffers_fp16_sum_inplace_gpu_mc_ag(const int op, const int flagoffset, const int firstrank, + const int myrank, const int gpustep, + const int mylineoffset, const int totallines, + void **commbuff, const int handleridx, uint4 *mc_ptr) {} +template +__global__ void __launch_bounds__(MAX_THREADS) + userbuffers_fp16_sum_inplace_gpu_mc_rs(const int op, const int flagoffset, const int firstrank, + const int myrank, const int gpustep, + const int mylineoffset, const int totallines, + void **commbuff, const int handleridx, float4 *mc_ptr) {} +#endif -#define REDUCETHREADS (blockDim.x - 32) +template +__global__ void __launch_bounds__(MAX_THREADS) userbuffers_fp16_sum_inplace_gpu_rr_rs_oop_fp8( + const int op, const int flagoffset, const int 
firstrank, const int myrank, const int gpustep, + const int mylineoffset, const int totallines, const int rowlines, const int skiplines, + void **commbuff, const int handleridx, void *outbuf, float *scale) { + __shared__ int4 *userptr[RANKS]; + volatile int *flagptr; + int physgpu, targetgpu, *myptr; + int *reduceidptr, reduce_id; + int lastSM = 0; + half hscale = (half)*scale; - if (threadIdx.x < 32) { - int *flagptr; - if (threadIdx.x < RANKS) { - if (!blockIdx.x) { - flagptr = reinterpret_cast(commbuff[threadIdx.x + firstrank]); - flagptr[flagoffset + myrank + firstrank] = basecounter; - } - volatile int *flag = (volatile int *)&((reinterpret_cast( - commbuff[myrank + firstrank]))[flagoffset + threadIdx.x + firstrank]); - while (*flag < basecounter) { + if (threadIdx.x < RANKS) { + physgpu = myrank * gpustep + firstrank; + targetgpu = threadIdx.x * gpustep + firstrank; + myptr = (reinterpret_cast(commbuff[physgpu])) + flagoffset; + reduceidptr = myptr - NVTE_MAX_OPS; // +op; + reduce_id = (*reduceidptr) + 1; + flagptr = (reinterpret_cast(commbuff[targetgpu])) + flagoffset; + if (blockIdx.x == 0) + flagptr[physgpu] = reduce_id; + volatile int *flag = (volatile int *)&(myptr[targetgpu]); + userptr[threadIdx.x] = reinterpret_cast(commbuff[targetgpu + handleridx]); + clock_t s = clock64(); + while (*flag < reduce_id) { + if (clock64() - s > TIMEOUT) { + printf("[%d] NVONLY RSBAR:SM %d [%d]:expecting %d got %d\n", myrank, blockIdx.x, + threadIdx.x, reduce_id, *flag); + break; } } - __syncthreads(); - - int startblock = 0, endblock = numblocks; + } + __syncthreads(); + if (threadIdx.x == 0) { + const int adder = blockIdx.x == 0 ? NVTE_MAX_SMS - gridDim.x + 1 : 1; + int old_val = atomicAdd(myptr + (NVTE_MAX_NVLINK * 2), adder); + if (old_val + adder == NVTE_MAX_SMS * reduce_id) + lastSM = 1; + } + int warp = blockIdx.x + (threadIdx.x >> 5); + int dest[RANKS]; +#pragma unroll + for (int i = 0; i < RANKS; i++) + dest[i] = (i + myrank + warp) & (RANKS - 1); - for (int nblock = 0; nblock < endblock; nblock++) { - asm volatile("bar.sync 13, %0;" ::"r"(REDUCETHREADS + 32)); + __syncthreads(); + for (int line = threadIdx.x + blockDim.x * blockIdx.x; line < totallines; + line += blockDim.x * gridDim.x) { + int4 val[RANKS]; - if (threadIdx.x == 0) { - __threadfence(); - if (blockIdx.x) gpuflag[op * NVTE_MAX_SMS * 2 + blockIdx.x] = nblock + basecounter + 1; - } else if (blockIdx.x == 0) { - int expecting = (basecounter + nblock + 1); - if (threadIdx.x < gridDim.x) - while (((volatile int *)gpuflag)[op * NVTE_MAX_SMS * 2 + threadIdx.x] < expecting) { - } - } - if (!blockIdx.x) { - asm volatile("bar.sync 15, %0;" ::"r"(32)); - if (!threadIdx.x) hostflags[0] = nblock + basecounter + 1; - } +#pragma unroll + for (int i = 0; i < RANKS; i++) { + val[i] = userptr[dest[i]][mylineoffset + line]; } - int cachedflag = basecounter; + int4 sum[2] = {{0, 0, 0, 0}, {0, 0, 0, 0}}; + half *s = reinterpret_cast(&sum); -#define ALLGATHERFLAG NVTE_GF_IBSHARPDONE +#pragma unroll + for (int i = 0; i < RANKS; i++) { + fp8type *x = reinterpret_cast(&val[i]); +#pragma unroll + for (int j = 0; j < sizeof(int4) / sizeof(fp8type); j++) + s[j] += hscale * (half)(x[j]); + } + int hline = 2 * line; + (reinterpret_cast(outbuf))[(hline / rowlines) * skiplines + (hline % rowlines)] = + sum[0]; + hline++; + (reinterpret_cast(outbuf))[(hline / rowlines) * skiplines + (hline % rowlines)] = + sum[1]; + } - if (blockIdx.x == 0 && threadIdx.x < RANKS) { - while (cachedflag < basecounter + numblocks) { - int newflag = ((volatile int 
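The fp8 reduce-scatter path above reads int4 vectors of fp8 values from every rank, converts each element to half, applies the dequantization scale, and accumulates in half precision. A scalar sketch of that inner step (illustrative only: __nv_fp8_e4m3 is assumed as the fp8 type, while the kernel in the patch is templated over it; requires cuda_fp8.h from CUDA 11.8+ and an architecture with native half math, e.g. nvcc -arch=sm_90; the vectorized int4 loads are omitted):

// Illustrative sketch of the fp8 -> half dequantize-and-accumulate step.
// Compile with an arch that has native half math, e.g. -arch=sm_90.
#include <cuda_fp16.h>
#include <cuda_fp8.h>
#include <cuda_runtime.h>
#include <cstdio>

constexpr int RANKS = 4;

__global__ void fp8_reduce(const __nv_fp8_e4m3 *in, __half *out, int n, const float *scale) {
  const __half hscale = __float2half(*scale);
  for (int i = threadIdx.x + blockDim.x * blockIdx.x; i < n; i += blockDim.x * gridDim.x) {
    __half acc = __float2half(0.0f);
    for (int r = 0; r < RANKS; ++r)
      acc = acc + hscale * (__half)in[r * n + i];  // dequantize, then accumulate
    out[i] = acc;
  }
}

int main() {
  const int n = 1024;
  __nv_fp8_e4m3 *in;
  __half *out;
  float *scale;
  cudaMallocManaged(&in, RANKS * n * sizeof(__nv_fp8_e4m3));
  cudaMallocManaged(&out, n * sizeof(__half));
  cudaMallocManaged(&scale, sizeof(float));
  *scale = 0.5f;
  for (int i = 0; i < RANKS * n; ++i) in[i] = __nv_fp8_e4m3(1.0f);

  fp8_reduce<<<4, 256>>>(in, out, n, scale);
  cudaDeviceSynchronize();
  std::printf("out[0] = %f (expected %d * 1.0 * 0.5)\n", __half2float(out[0]), RANKS);
  return 0;
}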
*)gpuflag)[ALLGATHERFLAG]; - if (newflag == cachedflag) continue; - cachedflag = newflag; - flagptr[flagoffset + myrank + 32 + firstrank] = cachedflag; + if (threadIdx.x == 0 && lastSM) + *reduceidptr = reduce_id; +} // fp16 reduce-scatter kernel (out of place) (fp8->fp16) + +template +__global__ void __launch_bounds__(MAX_THREADS) + userbuffers_fp16_sum_inplace_gpu_rr_rs_oop_atomic_fp8( + const int op, const int flagoffset, const int firstrank, const int myrank, + const int gpustep, const int mylineoffset, const int totallines, const int rowlines, + const int skiplines_out, const int skiplines_in, void **commbuff, const int handleridx, + void *outbuf, float *scale, void *counters, const int numchunks, const int atomicindex) { + __shared__ int4 *userptr[RANKS]; + volatile int *flagptr; + int physgpu, targetgpu, *myptr; + int *reduceidptr, reduce_id; + int lastSM = 0; + half hscale = (half)*scale; + + if (threadIdx.x < RANKS) { + physgpu = myrank * gpustep + firstrank; + targetgpu = threadIdx.x * gpustep + firstrank; + // const int blockflagoffset = MAX_NVLINK * 2 * blockIdx.x; + myptr = (reinterpret_cast(commbuff[physgpu])) + flagoffset; + reduceidptr = myptr - NVTE_MAX_OPS; // +op; + reduce_id = (*reduceidptr); + flagptr = (reinterpret_cast(commbuff[targetgpu])) + flagoffset; // + blockflagoffset; + } + + for (int chunk_i = 0; chunk_i < numchunks; chunk_i++) { + ATOMIC_CONSUMER(chunk_i); + + lastSM = 0; + if (threadIdx.x < RANKS) { + reduce_id++; + if (blockIdx.x == 0) + flagptr[physgpu] = reduce_id; + volatile int *flag = (volatile int *)&(myptr[targetgpu]); + userptr[threadIdx.x] = reinterpret_cast(commbuff[targetgpu + handleridx]); + clock_t s = clock64(); + while (*flag < reduce_id) { + if (clock64() - s > TIMEOUT) { + printf("[%d] NVONLY RSBAR:SM %d [%d]:expecting %d got %d\n", myrank, blockIdx.x, + threadIdx.x, reduce_id, *flag); + break; + } } } + __syncthreads(); + if (threadIdx.x == 0) { + const int adder = blockIdx.x == 0 ? NVTE_MAX_SMS - gridDim.x + 1 : 1; + int old_val = atomicAdd(myptr + (NVTE_MAX_NVLINK * 2), /*numchunks * */ adder); + if (old_val + adder == NVTE_MAX_SMS * (reduce_id /* + numchunks*/)) + lastSM = 1; + } - if (blockIdx.x == 0 && threadIdx.x == 0) gpuflag[NVTE_GF_STATE + op] = basecounter + numblocks; - } else { - const int warp = blockIdx.x + (threadIdx.x >> 5); - int4 *userptr[RANKS]; - int4 *userptrmyrank; + int warp = blockIdx.x + (threadIdx.x >> 5); + int dest[RANKS]; #pragma unroll for (int i = 0; i < RANKS; i++) - userptr[i] = reinterpret_cast( - commbuff[((i + myrank + warp) & (RANKS - 1)) + handleridx + firstrank]); - userptrmyrank = reinterpret_cast(commbuff[myrank + handleridx + firstrank]); + dest[i] = (i + myrank + warp) & (RANKS - 1); + __syncthreads(); + for (int line = threadIdx.x + blockDim.x * blockIdx.x; line < totallines; + line += blockDim.x * gridDim.x) { + int4 val[RANKS]; + const int rowlines_in = rowlines / 2; + const int index_in = skiplines_in == 0 + ? mylineoffset + myrank * totallines + line + : (numchunks <= 1 ? 
1 : chunk_i) * mylineoffset + + myrank * (totallines * skiplines_in / rowlines_in) + + (line / rowlines_in) * skiplines_in + (line % rowlines_in); + const int index1_out = chunk_i * mylineoffset * 2 + ((2 * line) / rowlines) * skiplines_out + + ((2 * line) % rowlines); + const int index2_out = chunk_i * mylineoffset * 2 + + ((2 * line + 1) / rowlines) * skiplines_out + + ((2 * line + 1) % rowlines); - int blocklineoffset = 0; +#pragma unroll + for (int i = 0; i < RANKS; i++) { + val[i] = userptr[dest[i]][index_in]; + } - while (blocklineoffset < numlines) { - const int remainder = min(numlines - blocklineoffset, peerblocklines * RANKS); - const int blocklines = remainder / RANKS; - const int blockstart = lineoffset + blocklineoffset + blocklines * myrank; + int4 sum[2] = {{0, 0, 0, 0}, {0, 0, 0, 0}}; + half *s = reinterpret_cast(&sum); - for (int line = threadIdx.x - 32 + REDUCETHREADS * blockIdx.x; line < blocklines; - line += REDUCETHREADS * gridDim.x) { - int4 val[RANKS]; +#pragma unroll + for (int i = 0; i < RANKS; i++) { + fp8type *x = reinterpret_cast(&val[i]); +#pragma unroll + for (int j = 0; j < sizeof(int4) / sizeof(fp8type); j++) + s[j] += hscale * (half)(x[j]); + } + (reinterpret_cast(outbuf))[index1_out] = sum[0]; + (reinterpret_cast(outbuf))[index2_out] = sum[1]; + } + } + if (threadIdx.x == 0 && lastSM) + *reduceidptr = reduce_id; +} // fp16 reduce-scatter kernel (out of place) (fp8->fp16) + +template +__global__ void __launch_bounds__(MAX_THREADS) + userbuffers_fp16_sum_inplace_gpu_rr_rs_oop_stride(const int op, const int flagoffset, + const int firstrank, const int myrank, + const int gpustep, const int mylineoffset, + const int totallines, const int rowlines, + const int skiplines, void **commbuff, + const int handleridx, void *outbuf) { + __shared__ int4 *userptr[RANKS]; + volatile int *flagptr; + int physgpu, targetgpu, *myptr; + int *reduceidptr, reduce_id; + int lastSM = 0; + + if (threadIdx.x < RANKS) { + physgpu = myrank * gpustep + firstrank; + targetgpu = threadIdx.x * gpustep + firstrank; + myptr = (reinterpret_cast(commbuff[physgpu])) + flagoffset; + reduceidptr = myptr - NVTE_MAX_OPS; // +op; + reduce_id = (*reduceidptr) + 1; + flagptr = (reinterpret_cast(commbuff[targetgpu])) + flagoffset; + if (blockIdx.x == 0) + flagptr[physgpu] = reduce_id; + volatile int *flag = (volatile int *)&(myptr[targetgpu]); + userptr[threadIdx.x] = reinterpret_cast(commbuff[targetgpu + handleridx]); + clock_t s = clock64(); + while (*flag < reduce_id) { + if (clock64() - s > TIMEOUT) { + printf("[%d] NVONLY RSBAR:SM %d [%d]:expecting %d got %d\n", myrank, blockIdx.x, + threadIdx.x, reduce_id, *flag); + break; + } + } + } + __syncthreads(); + if (threadIdx.x == 0) { + const int adder = blockIdx.x == 0 ? 
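The index expressions above map a rank's contiguous stream of reduced "lines" into a strided, row-major output: (line / rowlines) selects the output row (scaled by the row stride skiplines), (line % rowlines) the position within this rank's block, while the input index adds a per-rank offset such as myrank * (totallines * skiplines / rowlines) to pick this rank's region of the userbuffer. A tiny host-side sketch of the output mapping with hypothetical sizes (illustrative only):

// Illustrative sketch of the strided output indexing used by the
// reduce-scatter (out of place) kernels. Sizes are hypothetical.
#include <cstdio>

int main() {
  const int rowlines = 4;               // lines this kernel writes per output row
  const int skiplines = 16;             // row stride of the full output, in lines
  const int totallines = 3 * rowlines;  // three rows' worth of data

  for (int line = 0; line < totallines; ++line) {
    const int index_out = (line / rowlines) * skiplines + (line % rowlines);
    std::printf("line %2d -> output offset %2d\n", line, index_out);
  }
  return 0;
}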
NVTE_MAX_SMS - gridDim.x + 1 : 1; + int old_val = atomicAdd(myptr + (NVTE_MAX_NVLINK * 2), adder); + if (old_val + adder == NVTE_MAX_SMS * reduce_id) + lastSM = 1; + } + int warp = blockIdx.x + (threadIdx.x >> 5); + int dest[RANKS]; #pragma unroll - for (int i = 0; i < RANKS; i++) { - val[i] = userptr[i][blockstart + line]; - } + for (int i = 0; i < RANKS; i++) + dest[i] = (i + myrank + warp) & (RANKS - 1); - int4 sum = val[0]; - half *s = reinterpret_cast(&sum); + for (int line = threadIdx.x + blockDim.x * blockIdx.x; line < totallines; + line += blockDim.x * gridDim.x) { + int4 val[RANKS]; + int index_in = mylineoffset + myrank * (totallines * skiplines / rowlines) + + (line / rowlines) * skiplines + (line % rowlines); #pragma unroll - for (int i = 1; i < RANKS; i++) { - half *x = reinterpret_cast(&val[i]); + for (int i = 0; i < RANKS; i++) { + val[i] = userptr[dest[i]][index_in]; + } + + int4 sum = val[0]; + half *s = reinterpret_cast(&sum); + #pragma unroll - for (int j = 0; j < sizeof(int4) / sizeof(half); j++) s[j] += x[j]; - } + for (int i = 1; i < RANKS; i++) { + half *x = reinterpret_cast(&val[i]); +#pragma unroll + for (int j = 0; j < 8; j++) + s[j] += x[j]; + } - userptrmyrank[blockstart + line] = sum; - } // single block loop + int index_out = (line / rowlines) * skiplines + (line % rowlines); + (reinterpret_cast(outbuf))[index_out] = sum; + } - asm volatile("bar.sync 13, %0;" ::"r"(REDUCETHREADS + 32)); + if (threadIdx.x == 0 && lastSM) + *reduceidptr = reduce_id; +} // fp16 reduce-scatter kernel (out of place) fp16 + +#if 0 +template +__global__ void +__launch_bounds__(MAX_THREADS) +userbuffers_fp16_sum_inplace_gpu_rr_rs_oop_stride_atomic_fp8( + const int op, const int flagoffset, const int firstrank, const int myrank, const int gpustep, + const int mylineoffset, const int totallines, const int rowlines, const int skiplines, + const int numchunks, void **commbuff, const int handleridx, void* outbuf, void *counters, + float* scale) { + if (counters) { + if ( threadIdx.x == 0 ) { + // spin-lock on counter from producer + int old_val; + while (0 != (old_val = atomicCAS(((unsigned int*)counters), 0, 0) )) {} + + // make sure all threadblocks have read/waited on counters. + int old_val2; + atomicInc(((unsigned int *)counters)+numchunks, gridDim.x-1); + while (0 != (old_val2 = atomicCAS(((unsigned int*)counters)+numchunks, 0, 0) )) {} + + // reset counter for next producer. 
+ ((unsigned int*)counters)[0] = 1; + asm volatile ("fence.sc.gpu;\n"); + } + } + __syncthreads(); - blocklineoffset += peerblocklines * RANKS; - } // block loop NVLINK-REDUCESCATTER - const int nwarps = (REDUCETHREADS >> 5) / (RANKS - 1); - const int myblockDim = nwarps << 5; - const int mywarp = ((threadIdx.x - 32) >> 5) / (RANKS - 1); - const int maxthreadIdx = myblockDim * (RANKS - 1) + 32; - const int mydest = (myrank + 1 + ((threadIdx.x - 32) >> 5) % (RANKS - 1)) & (RANKS - 1); - const int mythreadIdx = (mywarp << 5) + (threadIdx.x & 31); - volatile int *flag = (volatile int *)&((reinterpret_cast( - commbuff[myrank + firstrank]))[flagoffset + mydest + 32 + firstrank]); + __shared__ int4* userptr[RANKS]; + volatile int *flagptr; + int physgpu, targetgpu, *myptr; + int *reduceidptr, reduce_id; + int lastSM = 0; + half hscale = (half) *scale; - int4 *userptrmydest = userptr[((RANKS << 10) + mydest - myrank - warp) & (RANKS - 1)]; + if (threadIdx.x < RANKS) { + physgpu = myrank*gpustep+firstrank; + targetgpu = threadIdx.x*gpustep+firstrank; + myptr = (reinterpret_cast(commbuff[physgpu])) + flagoffset; + reduceidptr = myptr-NVTE_MAX_OPS; // +op; + reduce_id =(*reduceidptr)+1; + flagptr = (reinterpret_cast(commbuff[targetgpu])) + flagoffset; + if (blockIdx.x == 0) flagptr[physgpu] = reduce_id; + volatile int* flag = (volatile int*)&(myptr[targetgpu]); + userptr[threadIdx.x] = reinterpret_cast(commbuff[targetgpu+handleridx]); + clock_t s = clock64(); + while (*flag < reduce_id) { + if (clock64()-s > TIMEOUT) { + printf("[%d] NVONLY RSBAR:SM %d [%d]:expecting %d got %d\n", + myrank, blockIdx.x, threadIdx.x, reduce_id, *flag); + break; + } + } + } + __syncthreads(); + if (threadIdx.x == 0) { + const int adder = blockIdx.x == 0 ? NVTE_MAX_SMS-gridDim.x+1 : 1; + int old_val = atomicAdd(myptr+(NVTE_MAX_NVLINK*2), adder); + if (old_val+adder == NVTE_MAX_SMS*reduce_id) lastSM = 1; + } - blocklineoffset = 0; - int gathercounter = basecounter + 1; - while (blocklineoffset < numlines) { - const int remainder = min(numlines - blocklineoffset, peerblocklines * RANKS); - const int blocklines = remainder / RANKS; - const int blockstart = lineoffset + blocklineoffset; -#define UNROLL 6 - int4 *myptr = &userptrmyrank[blockstart + blocklines * mydest]; - int4 *peerptr = &userptrmydest[blockstart + blocklines * mydest]; + int warp = blockIdx.x+(threadIdx.x>>5); + int dest[RANKS]; +#pragma unroll + for (int i = 0; i < RANKS; i++) + dest[i] = (i+myrank+warp)&(RANKS-1); - if (threadIdx.x < maxthreadIdx) { - const int start_elem = mythreadIdx + myblockDim * blockIdx.x; - const int end_elem = max(start_elem, blocklines); - const int aligned_elem = ((end_elem - start_elem) / (myblockDim * gridDim.x * UNROLL)) * - (myblockDim * gridDim.x * UNROLL); - const int end_aligned = start_elem + aligned_elem; + for (int line = threadIdx.x+blockDim.x*blockIdx.x; + line < totallines; line+=blockDim.x*gridDim.x) { + int4 val[RANKS]; + int index_in = mylineoffset + myrank*(totallines*skiplines/rowlines/2) + + (line/rowlines)*skiplines/2+(line%rowlines); - if (mythreadIdx == 0) { - while (*flag < gathercounter) { - } - gathercounter++; +#pragma unroll + for (int i = 0; i < RANKS; i++) { + val[i] = userptr[dest[i]][index_in]; } - asm volatile("bar.sync %0, %1;" ::"r"(1 + mydest), "r"(myblockDim)); + int4 sum[2] = {{0, 0, 0, 0}, {0, 0, 0, 0}}; + half *s = reinterpret_cast(&sum); - for (int line = start_elem; line < end_aligned; line += myblockDim * gridDim.x * UNROLL) { - int4 val[UNROLL]; #pragma unroll - for (int i = 0; i < 
UNROLL; i++) val[i] = peerptr[line + i * myblockDim * gridDim.x]; + for (int i = 0; i < RANKS; i++) { + fp8type *x = reinterpret_cast(&val[i]); #pragma unroll - for (int i = 0; i < UNROLL; i++) myptr[line + i * myblockDim * gridDim.x] = val[i]; + for (int j=0; j < sizeof(int4)/sizeof(fp8type); j++) s[j] += hscale * (half)(x[j]); } - for (int line = end_aligned; line < end_elem; line += myblockDim * gridDim.x) - myptr[line] = peerptr[line]; + int hline = 2*line; + int index_out1 = (hline/rowlines)*skiplines+(hline%rowlines); + (reinterpret_cast(outbuf))[index_out1] = sum[0]; + hline++; + int index_out2 = (hline/rowlines)*skiplines+(hline%rowlines); + (reinterpret_cast(outbuf))[index_out2] = sum[1]; } - blocklineoffset += peerblocklines * RANKS; - } // block loop for NVLINK-ALLGATHER - } // worker warps else block -} // fp16 inplace reduce kernel with SHARP / in blocks -// threadfence and SMs sync to SM0 -#define SMBAR(offset, block) \ - asm volatile("bar.sync 13, %0;" ::"r"(blockDim.x)); \ - if (threadIdx.x == 0) { \ - __threadfence_system(); \ - if (blockIdx.x) gpuflag[offset + blockIdx.x] = block + basecounter + 1; \ - } else if (blockIdx.x == 0) { \ - int expecting = (basecounter + block + 1); \ - if (threadIdx.x < gridDim.x) \ - while (((volatile int *)gpuflag)[offset + threadIdx.x] < expecting) { \ - } \ - } \ - if (blockIdx.x == 0) asm volatile("bar.sync 15, %0;" ::"r"(32)); + if (threadIdx.x == 0 && lastSM) *reduceidptr = reduce_id; +} // fp16 reduce-scatter kernel (out of place) fp16 +#endif template -__global__ void __launch_bounds__(MAX_THREADS) userbuffers_fp16_sum_inplace_gpu_rr_blocked2( - const int op, const int maxcredit, const int headstart, const int myibrank, const int ibranks, - const int commbufoffset, const int flagoffset, const int firstrank, const int myrank, - const int gpustep, const int lineoffset, const int numlines, void **commbuff, - const int handleridx, const int peerblocklines, int *hostflags, int *gpuflag, - const int numblocks) { - const int basecounter = gpuflag[NVTE_GF_STATE + op]; - if (threadIdx.x < 32) { - int *flagptr; - volatile int *localflag = (volatile int *)&( - ((int *)commbuff[gpustep * myrank + firstrank])[flagoffset]); // NOLINT(*) - // initial intranode barrier - once - if (threadIdx.x < RANKS) { - if (!blockIdx.x) { - flagptr = reinterpret_cast(commbuff[gpustep * threadIdx.x + firstrank]); - flagptr[flagoffset + gpustep * myrank + firstrank] = basecounter; +__global__ void __launch_bounds__(MAX_THREADS) + userbuffers_fp16_sum_inplace_gpu_rr_rs_oop_stride_atomic( + const int op, const int flagoffset, const int firstrank, const int myrank, + const int gpustep, const int mylineoffset, const int totallines, const int rowlines, + const int skiplines, const int numchunks, void **commbuff, const int handleridx, + void *outbuf, void *counters) { + if (counters) { + if (threadIdx.x == 0) { + // spin-lock on counter from producer + int old_val; + while (0 != (old_val = atomicCAS(((unsigned int *)counters), 0, 0))) { } - volatile int *flag = &localflag[gpustep * threadIdx.x + firstrank]; - while (*flag < basecounter) { + + // make sure all threadblocks have read/waited on counters. + int old_val2; + atomicInc(((unsigned int *)counters) + numchunks, gridDim.x - 1); + while (0 != (old_val2 = atomicCAS(((unsigned int *)counters) + numchunks, 0, 0))) { } + + // reset counter for next producer. 
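+      // Descriptive comments (not part of the original patch): the counter is re-armed
+      // to 1 below so the next producer can flip it back to 0, and fence.sc.gpu orders
+      // the reset ahead of the kernel's subsequent memory traffic. After this
+      // handshake the kernel runs the usual userbuffers barrier: block 0 publishes the
+      // new reduce_id to each peer's flag page, the first RANKS threads of every block
+      // spin (with a TIMEOUT guard) until the matching flags arrive, and thread 0 of
+      // each block joins an atomicAdd tally so that exactly the last block to finish
+      // ("lastSM") commits reduce_id back to reduceidptr.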
+ ((unsigned int *)counters)[0] = 1; + asm volatile("fence.sc.gpu;\n"); } - __syncthreads(); + } + __syncthreads(); - for (int nblock = 0; nblock < numblocks + headstart; nblock++) { - if (nblock < numblocks) { - // RS happens here - SMBAR(op * 2 * NVTE_MAX_SMS, nblock); - if (!blockIdx.x && !threadIdx.x) - hostflags[NVTE_HF_NVRSDONE + (op & 1)] = nblock + basecounter + 1; - } + __shared__ int4 *userptr[RANKS]; + volatile int *flagptr; + int physgpu, targetgpu, *myptr; + int *reduceidptr, reduce_id; + int lastSM = 0; - if (nblock >= headstart) { - for (int ibflag = threadIdx.x; ibflag < ibranks; ibflag += 32) - if (ibflag != myibrank) - while (localflag[NVTE_REG0_IBRS + ibflag] < basecounter + nblock - headstart + 1) { - } - asm volatile("bar.sync 13, %0;" ::"r"(blockDim.x)); - // REDUCE happens here - SMBAR(op * 2 * NVTE_MAX_SMS + NVTE_MAX_SMS, nblock - headstart); - if (!blockIdx.x && !threadIdx.x) - hostflags[NVTE_HF_NVREDUCEDONE + (op & 1)] = nblock + basecounter + 1 - headstart; + if (threadIdx.x < RANKS) { + physgpu = myrank * gpustep + firstrank; + targetgpu = threadIdx.x * gpustep + firstrank; + myptr = (reinterpret_cast(commbuff[physgpu])) + flagoffset; + reduceidptr = myptr - NVTE_MAX_OPS; // +op; + reduce_id = (*reduceidptr) + 1; + flagptr = (reinterpret_cast(commbuff[targetgpu])) + flagoffset; + if (blockIdx.x == 0) + flagptr[physgpu] = reduce_id; + volatile int *flag = (volatile int *)&(myptr[targetgpu]); + userptr[threadIdx.x] = reinterpret_cast(commbuff[targetgpu + handleridx]); + clock_t s = clock64(); + while (*flag < reduce_id) { + if (clock64() - s > TIMEOUT) { + printf("[%d] NVONLY RSBAR:SM %d [%d]:expecting %d got %d\n", myrank, blockIdx.x, + threadIdx.x, reduce_id, *flag); + break; } } - // final part doing NVAG based on responses from NIC-RMW:IBAG + } + __syncthreads(); + if (threadIdx.x == 0) { + const int adder = blockIdx.x == 0 ? 
NVTE_MAX_SMS - gridDim.x + 1 : 1; + int old_val = atomicAdd(myptr + (NVTE_MAX_NVLINK * 2), adder); + if (old_val + adder == NVTE_MAX_SMS * reduce_id) + lastSM = 1; + } - if (blockIdx.x == 0) { - for (int nblock = 0; nblock < numblocks; nblock++) { - const int expected = basecounter + nblock + 1; - for (int ibflag = threadIdx.x; ibflag < ibranks; ibflag += 32) + int warp = blockIdx.x + (threadIdx.x >> 5); + int dest[RANKS]; +#pragma unroll + for (int i = 0; i < RANKS; i++) + dest[i] = (i + myrank + warp) & (RANKS - 1); + + for (int line = threadIdx.x + blockDim.x * blockIdx.x; line < totallines; + line += blockDim.x * gridDim.x) { + int4 val[RANKS]; + int index_in = mylineoffset + myrank * (totallines * skiplines / rowlines) + + (line / rowlines) * skiplines + (line % rowlines); + +#pragma unroll + for (int i = 0; i < RANKS; i++) { + val[i] = userptr[dest[i]][index_in]; + } + + int4 sum = val[0]; + half *s = reinterpret_cast(&sum); + +#pragma unroll + for (int i = 1; i < RANKS; i++) { + half *x = reinterpret_cast(&val[i]); +#pragma unroll + for (int j = 0; j < 8; j++) + s[j] += x[j]; + } + + int index_out = (line / rowlines) * skiplines + (line % rowlines); + (reinterpret_cast(outbuf))[index_out] = sum; + } + + if (threadIdx.x == 0 && lastSM) + *reduceidptr = reduce_id; +} // fp16 reduce-scatter kernel (out of place) fp16 + +template +__global__ void __launch_bounds__(MAX_THREADS) + userbuffers_fp16_sum_inplace_gpu_rr_rs_oop_stride_multiatomic( + const int op, const int flagoffset, const int firstrank, const int myrank, + const int gpustep, const int mylineoffset, const int totallines, const int rowlines, + const int skiplines, const int numchunks, void **commbuff, const int handleridx, + void *outbuf, void *counters) { + for (int chunk_i = 0; chunk_i < numchunks; chunk_i++) { + if (counters) { + if (threadIdx.x == 0) { + // spin-lock on counter from producer + int old_val; + while (0 != (old_val = atomicCAS(((unsigned int *)counters) + chunk_i, 0, 0))) { + } + + // make sure all threadblocks have read/waited on counters. + int old_val2; + atomicInc(((unsigned int *)counters) + numchunks + chunk_i, gridDim.x - 1); + while (0 != + (old_val2 = atomicCAS(((unsigned int *)counters) + numchunks + chunk_i, 0, 0))) { + } + + // reset counter for next producer. + ((unsigned int *)counters)[chunk_i] = 1; + asm volatile("fence.sc.gpu;\n"); + } + } + __syncthreads(); + + __shared__ int4 *userptr[RANKS]; + volatile int *flagptr; + int physgpu, targetgpu, *myptr; + int *reduceidptr, reduce_id; + int lastSM = 0; + + if (threadIdx.x < RANKS) { + physgpu = myrank * gpustep + firstrank; + targetgpu = threadIdx.x * gpustep + firstrank; + myptr = (reinterpret_cast(commbuff[physgpu])) + flagoffset; + reduceidptr = myptr - NVTE_MAX_OPS; // +op; + reduce_id = (*reduceidptr) + 1; + flagptr = (reinterpret_cast(commbuff[targetgpu])) + flagoffset; + if (blockIdx.x == 0) + flagptr[physgpu] = reduce_id; + volatile int *flag = (volatile int *)&(myptr[targetgpu]); + userptr[threadIdx.x] = reinterpret_cast(commbuff[targetgpu + handleridx]); + clock_t s = clock64(); + while (*flag < reduce_id) { + if (clock64() - s > TIMEOUT) { + printf("[%d] NVONLY RSBAR:SM %d [%d]:expecting %d got %d\n", myrank, blockIdx.x, + threadIdx.x, reduce_id, *flag); + break; + } + } + } + __syncthreads(); + if (threadIdx.x == 0) { + const int adder = blockIdx.x == 0 ? 
NVTE_MAX_SMS - gridDim.x + 1 : 1; + int old_val = atomicAdd(myptr + (NVTE_MAX_NVLINK * 2), adder); + if (old_val + adder == NVTE_MAX_SMS * reduce_id) + lastSM = 1; + } + + int warp = blockIdx.x + (threadIdx.x >> 5); + int dest[RANKS]; +#pragma unroll + for (int i = 0; i < RANKS; i++) + dest[i] = (i + myrank + warp) & (RANKS - 1); + + for (int line = threadIdx.x + blockDim.x * blockIdx.x; line < totallines; + line += blockDim.x * gridDim.x) { + int4 val[RANKS]; + int index_in = chunk_i * mylineoffset + myrank * (totallines * skiplines / rowlines) + + (line / rowlines) * skiplines + (line % rowlines); + +#pragma unroll + for (int i = 0; i < RANKS; i++) { + val[i] = userptr[dest[i]][index_in]; + } + + int4 sum = val[0]; + half *s = reinterpret_cast(&sum); + +#pragma unroll + for (int i = 1; i < RANKS; i++) { + half *x = reinterpret_cast(&val[i]); +#pragma unroll + for (int j = 0; j < 8; j++) + s[j] += x[j]; + } + + int index_out = chunk_i * mylineoffset + (line / rowlines) * skiplines + (line % rowlines); + (reinterpret_cast(outbuf))[index_out] = sum; + } + if (threadIdx.x == 0 && lastSM) + *reduceidptr = reduce_id; + } +} // fp16 reduce-scatter kernel (out of place) fp16 + +template +__global__ void __launch_bounds__(MAX_THREADS) + userbuffers_fp16_sum_inplace_gpu_rr_ag(const int op, const int flagoffset, const int firstrank, + const int myrank, const int gpustep, + const int mylineoffset, const int totallines, + void **commbuff, const int handleridx) { + __shared__ int4 *userptr[RANKS]; + volatile int *flagptr; + int physgpu, targetgpu, *myptr; + int *reduceidptr, reduce_id; + if (threadIdx.x < RANKS) { + physgpu = myrank * gpustep + firstrank; + targetgpu = threadIdx.x * gpustep + firstrank; + myptr = (reinterpret_cast(commbuff[physgpu])) + flagoffset; + reduceidptr = myptr - NVTE_MAX_OPS; // +op; + reduce_id = (*reduceidptr) + 1; + flagptr = (reinterpret_cast(commbuff[targetgpu])) + flagoffset; + volatile int *flag = (volatile int *)&(myptr[targetgpu]); + userptr[threadIdx.x] = reinterpret_cast(commbuff[targetgpu + handleridx]); + clock_t s = clock64(); + } + + int warp = blockIdx.x + (threadIdx.x >> 5); + int dest[RANKS]; + + int skipmy = 0; +#pragma unroll + for (int i = 0; i < RANKS; i++) { + int dst = (i + warp + myrank) & (RANKS - 1); + if (dst == myrank) { + skipmy++; + continue; + } + dest[i - skipmy] = dst; + } + __syncthreads(); + + for (int line = threadIdx.x + blockDim.x * blockIdx.x; line < totallines; + line += blockDim.x * gridDim.x) { + int4 val[RANKS - 1]; + +#pragma unroll + for (int i = 0; i < RANKS - 1; i++) { + val[i] = userptr[dest[i]][mylineoffset + line + totallines * dest[i]]; + } + +#pragma unroll + for (int i = 0; i < RANKS - 1; i++) { + userptr[myrank][mylineoffset + line + totallines * dest[i]] = val[i]; + } + } + __shared__ int lastSM; + if (threadIdx.x == 0) { + const int adder = blockIdx.x == 0 ? 
NVTE_MAX_SMS - gridDim.x + 1 : 1; + int old_val = atomicAdd(myptr + (NVTE_MAX_NVLINK * 2), adder); + if (old_val + adder == NVTE_MAX_SMS * reduce_id) + lastSM = 1; + else + lastSM = 0; + } + __syncthreads(); + if (lastSM && threadIdx.x < RANKS) { + if (threadIdx.x == 0) + *reduceidptr = reduce_id; + flagptr[physgpu] = reduce_id; + volatile int *flag = (volatile int *)&myptr[targetgpu]; + clock_t s = clock64(); + while (*flag < reduce_id) { + if (clock64() - s > 2ull * TIMEOUT) { + printf("NVONLY AGBAR:SM %d [%d]:expecting %d got %d\n", blockIdx.x, threadIdx.x, reduce_id, + *flag); + break; + } + } + } +} // fp16 inplace reduce kernel (Ampere) + +template +__global__ void __launch_bounds__(MAX_THREADS) + userbuffers_fp16_sum_inplace_gpu_rw_ag(const int op, const int flagoffset, const int firstrank, + const int myrank, const int gpustep, + const int mylineoffset, const int totallines, + void **commbuff, const int handleridx) { + __shared__ int4 *userptr[RANKS]; + volatile int *flagptr; + int physgpu, targetgpu, *myptr; + int *reduceidptr, reduce_id; + int4 *localptr; + if (threadIdx.x < RANKS) { + physgpu = myrank * gpustep + firstrank; + targetgpu = threadIdx.x * gpustep + firstrank; + myptr = (reinterpret_cast(commbuff[physgpu])) + flagoffset; + reduceidptr = myptr - NVTE_MAX_OPS; // +op; + reduce_id = (*reduceidptr) + 1; + flagptr = (reinterpret_cast(commbuff[targetgpu])) + flagoffset; + userptr[threadIdx.x] = reinterpret_cast(commbuff[targetgpu + handleridx]); + } + __syncthreads(); + localptr = userptr[myrank]; + + int warp = blockIdx.x + (threadIdx.x >> 5); + int dest[RANKS - 1]; + int skipmy = 0; +#pragma unroll + for (int i = 0; i < RANKS; i++) { + int dst = (i + warp + myrank) & (RANKS - 1); + if (dst == myrank) { + skipmy++; + continue; + } + dest[i - skipmy] = dst; + } +#define UNROLLAG 4 + __syncthreads(); + const int loop_step0 = blockDim.x * gridDim.x; + const int loop_step = loop_step0 * UNROLLAG; + const int start_elem = threadIdx.x + blockDim.x * blockIdx.x; + const int end_elem = max(start_elem, totallines); + const int aligned_elem = ((end_elem - start_elem) / loop_step) * loop_step; + const int end_aligned = start_elem + aligned_elem; + + for (int line = start_elem; line < end_aligned; line += loop_step) { + int4 val[UNROLLAG]; +#pragma unroll + for (int j = 0; j < UNROLLAG; j++) + val[j] = localptr[mylineoffset + line + loop_step0 * j]; + +#pragma unroll + for (int j = 0; j < UNROLLAG; j++) +#pragma unroll + for (int i = 0; i < RANKS - 1; i++) { + userptr[dest[i]][mylineoffset + line + j * loop_step0] = val[j]; + } + } + + for (int line = end_aligned; line < end_elem; line += loop_step0) { + int4 sum = localptr[mylineoffset + line]; +#pragma unroll + for (int i = 0; i < RANKS - 1; i++) { + userptr[dest[i]][mylineoffset + line] = sum; + } + } + + __syncthreads(); + if (threadIdx.x == 0) + __threadfence_system(); + __syncthreads(); + + __shared__ int lastSM; + if (threadIdx.x == 0) { + const int adder = blockIdx.x == 0 ? 
NVTE_MAX_SMS - gridDim.x + 1 : 1; + int old_val = atomicAdd(myptr + (NVTE_MAX_NVLINK * 2), adder); + if (old_val + adder == NVTE_MAX_SMS * reduce_id) + lastSM = 1; + else + lastSM = 0; + } + __syncthreads(); + if (lastSM && threadIdx.x < RANKS) { + if (threadIdx.x == 0) + *reduceidptr = reduce_id; + flagptr[physgpu] = reduce_id; + volatile int *flag = (volatile int *)&myptr[targetgpu]; + clock_t s = clock64(); + while (*flag < reduce_id) { + if (clock64() - s > 2ull * TIMEOUT) { + printf("NVONLY AGBAR:SM %d [%d]:expecting %d got %d\n", blockIdx.x, threadIdx.x, reduce_id, + *flag); + break; + } + } + } +} // fp16 inplace allgather kernel (Volta,Hopper) + +template +__global__ void __launch_bounds__(MAX_THREADS) + userbuffers_fp16_sum_inplace_gpu_rr_blocked(const int op, const int flagoffset, + const int firstrank, const int myrank, + const int lineoffset, const int numlines, + void **commbuff, const int handleridx, + const int peerblocklines, int *hostflags, + int *gpuflag, const int numblocks) { + const int basecounter = gpuflag[NVTE_GF_STATE + op]; + +#define REDUCETHREADS (blockDim.x - 32) + + if (threadIdx.x < 32) { + int *flagptr; + if (threadIdx.x < RANKS) { + if (!blockIdx.x) { + flagptr = reinterpret_cast(commbuff[threadIdx.x + firstrank]); + flagptr[flagoffset + myrank + firstrank] = basecounter; + } + volatile int *flag = (volatile int *)&((reinterpret_cast( + commbuff[myrank + firstrank]))[flagoffset + threadIdx.x + firstrank]); + while (*flag < basecounter) { + } + } + __syncthreads(); + + int startblock = 0, endblock = numblocks; + + for (int nblock = 0; nblock < endblock; nblock++) { + asm volatile("bar.sync 13, %0;" ::"r"(REDUCETHREADS + 32)); + + if (threadIdx.x == 0) { + __threadfence(); + if (blockIdx.x) + gpuflag[op * NVTE_MAX_SMS * 2 + blockIdx.x] = nblock + basecounter + 1; + } else if (blockIdx.x == 0) { + int expecting = (basecounter + nblock + 1); + if (threadIdx.x < gridDim.x) + while (((volatile int *)gpuflag)[op * NVTE_MAX_SMS * 2 + threadIdx.x] < expecting) { + } + } + if (!blockIdx.x) { + asm volatile("bar.sync 15, %0;" ::"r"(32)); + if (!threadIdx.x) + hostflags[0] = nblock + basecounter + 1; + } + } + + int cachedflag = basecounter; + +#define ALLGATHERFLAG NVTE_GF_IBSHARPDONE + + if (blockIdx.x == 0 && threadIdx.x < RANKS) { + while (cachedflag < basecounter + numblocks) { + int newflag = ((volatile int *)gpuflag)[ALLGATHERFLAG]; + if (newflag == cachedflag) + continue; + cachedflag = newflag; + flagptr[flagoffset + myrank + 32 + firstrank] = cachedflag; + } + } + + if (blockIdx.x == 0 && threadIdx.x == 0) + gpuflag[NVTE_GF_STATE + op] = basecounter + numblocks; + } else { + const int warp = blockIdx.x + (threadIdx.x >> 5); + int4 *userptr[RANKS]; + int4 *userptrmyrank; +#pragma unroll + for (int i = 0; i < RANKS; i++) + userptr[i] = reinterpret_cast( + commbuff[((i + myrank + warp) & (RANKS - 1)) + handleridx + firstrank]); + userptrmyrank = reinterpret_cast(commbuff[myrank + handleridx + firstrank]); + __syncthreads(); + + int blocklineoffset = 0; + + while (blocklineoffset < numlines) { + const int remainder = min(numlines - blocklineoffset, peerblocklines * RANKS); + const int blocklines = remainder / RANKS; + const int blockstart = lineoffset + blocklineoffset + blocklines * myrank; + + for (int line = threadIdx.x - 32 + REDUCETHREADS * blockIdx.x; line < blocklines; + line += REDUCETHREADS * gridDim.x) { + int4 val[RANKS]; + +#pragma unroll + for (int i = 0; i < RANKS; i++) { + val[i] = userptr[i][blockstart + line]; + } + + int4 sum = val[0]; + half *s 
= reinterpret_cast(&sum); + +#pragma unroll + for (int i = 1; i < RANKS; i++) { + half *x = reinterpret_cast(&val[i]); +#pragma unroll + for (int j = 0; j < sizeof(int4) / sizeof(half); j++) + s[j] += x[j]; + } + + userptrmyrank[blockstart + line] = sum; + } // single block loop + + asm volatile("bar.sync 13, %0;" ::"r"(REDUCETHREADS + 32)); + + blocklineoffset += peerblocklines * RANKS; + } // block loop NVLINK-REDUCESCATTER + const int nwarps = (REDUCETHREADS >> 5) / (RANKS - 1); + const int myblockDim = nwarps << 5; + const int mywarp = ((threadIdx.x - 32) >> 5) / (RANKS - 1); + const int maxthreadIdx = myblockDim * (RANKS - 1) + 32; + const int mydest = (myrank + 1 + ((threadIdx.x - 32) >> 5) % (RANKS - 1)) & (RANKS - 1); + const int mythreadIdx = (mywarp << 5) + (threadIdx.x & 31); + volatile int *flag = (volatile int *)&((reinterpret_cast( + commbuff[myrank + firstrank]))[flagoffset + mydest + 32 + firstrank]); + + int4 *userptrmydest = userptr[((RANKS << 10) + mydest - myrank - warp) & (RANKS - 1)]; + + blocklineoffset = 0; + int gathercounter = basecounter + 1; + while (blocklineoffset < numlines) { + const int remainder = min(numlines - blocklineoffset, peerblocklines * RANKS); + const int blocklines = remainder / RANKS; + const int blockstart = lineoffset + blocklineoffset; + +#define UNROLL 6 + int4 *myptr = &userptrmyrank[blockstart + blocklines * mydest]; + int4 *peerptr = &userptrmydest[blockstart + blocklines * mydest]; + + if (threadIdx.x < maxthreadIdx) { + const int start_elem = mythreadIdx + myblockDim * blockIdx.x; + const int end_elem = max(start_elem, blocklines); + const int aligned_elem = ((end_elem - start_elem) / (myblockDim * gridDim.x * UNROLL)) * + (myblockDim * gridDim.x * UNROLL); + const int end_aligned = start_elem + aligned_elem; + + if (mythreadIdx == 0) { + while (*flag < gathercounter) { + } + gathercounter++; + } + + asm volatile("bar.sync %0, %1;" ::"r"(1 + mydest), "r"(myblockDim)); + + for (int line = start_elem; line < end_aligned; line += myblockDim * gridDim.x * UNROLL) { + int4 val[UNROLL]; +#pragma unroll + for (int i = 0; i < UNROLL; i++) + val[i] = peerptr[line + i * myblockDim * gridDim.x]; +#pragma unroll + for (int i = 0; i < UNROLL; i++) + myptr[line + i * myblockDim * gridDim.x] = val[i]; + } + for (int line = end_aligned; line < end_elem; line += myblockDim * gridDim.x) + myptr[line] = peerptr[line]; + } + blocklineoffset += peerblocklines * RANKS; + } // block loop for NVLINK-ALLGATHER + } // worker warps else block +} // fp16 inplace reduce kernel with SHARP / in blocks + +// threadfence and SMs sync to SM0 +#define SMBAR(offset, block) \ + asm volatile("bar.sync 13, %0;" ::"r"(blockDim.x)); \ + if (threadIdx.x == 0) { \ + __threadfence_system(); \ + if (blockIdx.x) \ + gpuflag[offset + blockIdx.x] = block + basecounter + 1; \ + } else if (blockIdx.x == 0) { \ + int expecting = (basecounter + block + 1); \ + if (threadIdx.x < gridDim.x) \ + while (((volatile int *)gpuflag)[offset + threadIdx.x] < expecting) { \ + } \ + } \ + if (blockIdx.x == 0) \ + asm volatile("bar.sync 15, %0;" ::"r"(32)); + +template +__global__ void __launch_bounds__(MAX_THREADS) userbuffers_fp16_sum_inplace_gpu_rr_blocked2( + const int op, const int maxcredit, const int headstart, const int myibrank, const int ibranks, + const int commbufoffset, const int flagoffset, const int firstrank, const int myrank, + const int gpustep, const int lineoffset, const int numlines, void **commbuff, + const int handleridx, const int peerblocklines, int *hostflags, int *gpuflag, 
+ const int numblocks) { + const int basecounter = gpuflag[NVTE_GF_STATE + op]; + if (threadIdx.x < 32) { + int *flagptr; + volatile int *localflag = (volatile int *)&( + ((int *)commbuff[gpustep * myrank + firstrank])[flagoffset]); // NOLINT(*) + // initial intranode barrier - once + if (threadIdx.x < RANKS) { + if (!blockIdx.x) { + flagptr = reinterpret_cast(commbuff[gpustep * threadIdx.x + firstrank]); + flagptr[flagoffset + gpustep * myrank + firstrank] = basecounter; + } + volatile int *flag = &localflag[gpustep * threadIdx.x + firstrank]; + while (*flag < basecounter) { + } + } + __syncthreads(); + + for (int nblock = 0; nblock < numblocks + headstart; nblock++) { + if (nblock < numblocks) { + // RS happens here + SMBAR(op * 2 * NVTE_MAX_SMS, nblock); + if (!blockIdx.x && !threadIdx.x) + hostflags[NVTE_HF_NVRSDONE + (op & 1)] = nblock + basecounter + 1; + } + + if (nblock >= headstart) { + for (int ibflag = threadIdx.x; ibflag < ibranks; ibflag += 32) + if (ibflag != myibrank) + while (localflag[NVTE_REG0_IBRS + ibflag] < basecounter + nblock - headstart + 1) { + } + asm volatile("bar.sync 13, %0;" ::"r"(blockDim.x)); + // REDUCE happens here + SMBAR(op * 2 * NVTE_MAX_SMS + NVTE_MAX_SMS, nblock - headstart); + if (!blockIdx.x && !threadIdx.x) + hostflags[NVTE_HF_NVREDUCEDONE + (op & 1)] = nblock + basecounter + 1 - headstart; + } + } + // final part doing NVAG based on responses from NIC-RMW:IBAG + + if (blockIdx.x == 0) { + for (int nblock = 0; nblock < numblocks; nblock++) { + const int expected = basecounter + nblock + 1; + for (int ibflag = threadIdx.x; ibflag < ibranks; ibflag += 32) if (ibflag != myibrank) while (localflag[NVTE_REG0_IBAG + ibflag] < expected) { } @@ -722,7 +1731,8 @@ __global__ void __launch_bounds__(MAX_THREADS) userbuffers_fp16_sum_inplace_gpu_ } } - if (blockIdx.x == 0 && threadIdx.x == 0) gpuflag[NVTE_GF_STATE + op] = basecounter + numblocks; + if (blockIdx.x == 0 && threadIdx.x == 0) + gpuflag[NVTE_GF_STATE + op] = basecounter + numblocks; } else { // sync warp // reducethreads const int warp = blockIdx.x + (threadIdx.x >> 5); @@ -762,7 +1772,8 @@ __global__ void __launch_bounds__(MAX_THREADS) userbuffers_fp16_sum_inplace_gpu_ for (int i = 1; i < RANKS; i++) { half *x = reinterpret_cast(&val[i]); #pragma unroll - for (int j = 0; j < sizeof(int4) / sizeof(half); j++) s[j] += x[j]; + for (int j = 0; j < sizeof(int4) / sizeof(half); j++) + s[j] += x[j]; } userptrmyrank[blockstart + line] = sum; @@ -801,13 +1812,15 @@ __global__ void __launch_bounds__(MAX_THREADS) userbuffers_fp16_sum_inplace_gpu_ : tempbufptr[i * ibblocklines + line]; half *x = reinterpret_cast(&val[(i + 1) % UNROLLRS]); #pragma unroll - for (int j = 0; j < 16; j++) s[j] += x[j]; + for (int j = 0; j < 16; j++) + s[j] += x[j]; } #pragma unroll for (int i = 1; i < UNROLLRS; i++) { half *x = reinterpret_cast(&val[i]); #pragma unroll - for (int j = 0; j < 16; j++) s[j] += x[j]; + for (int j = 0; j < 16; j++) + s[j] += x[j]; } userptrmyrank[tempstart + line] = sum; } @@ -858,9 +1871,11 @@ __global__ void __launch_bounds__(MAX_THREADS) userbuffers_fp16_sum_inplace_gpu_ for (int line = start_elem; line < end_aligned; line += myblockDim * gridDim.x * UNROLL) { int4 val[UNROLL]; #pragma unroll - for (int i = 0; i < UNROLL; i++) val[i] = peerptr[line + i * myblockDim * gridDim.x]; + for (int i = 0; i < UNROLL; i++) + val[i] = peerptr[line + i * myblockDim * gridDim.x]; #pragma unroll - for (int i = 0; i < UNROLL; i++) myptr[line + i * myblockDim * gridDim.x] = val[i]; + for (int i = 0; i < UNROLL; i++) 
+ myptr[line + i * myblockDim * gridDim.x] = val[i]; } for (int line = end_aligned; line < end_elem; line += myblockDim * gridDim.x) myptr[line] = peerptr[line]; @@ -952,7 +1967,8 @@ __global__ void __launch_bounds__(MAX_THREADS) userbuffers_fp16_sum_inplace_gpu_ for (int i = 1; i < RANKS; i++) { half *x = reinterpret_cast(&val[i]); #pragma unroll - for (int j = 0; j < sizeof(int4) / sizeof(half); j++) s[j] += x[j]; + for (int j = 0; j < sizeof(int4) / sizeof(half); j++) + s[j] += x[j]; } userptrmyrank[blockstart + line] = sum; @@ -971,9 +1987,6 @@ __global__ void __launch_bounds__(MAX_THREADS) userbuffers_fp16_sum_inplace_gpu_ int4 *tempbufptr = &internalbuf[((nblock - headstart) % maxcredit) * peerblocklines]; const int tempstart = lineoffset + (nblock - headstart) * peerblocklines * RANKS + myrank * blocklines + ibblocklines * myibrank; - // if(threadIdx.x==32) printf("[%d] block%d thread %d offset %d line %d ibblocklines %d ptr - // %lx commbufoffset - // %d\n",myrank,blockIdx.x,threadIdx.x,tempstart,0,ibblocklines,(void*)&tempbufptr[(1-myibrank)*ibblocklines],(1-myibrank)*ibblocklines*16); asm volatile("bar.sync 13, %0;" ::"r"(REDUCETHREADS + 32)); @@ -994,13 +2007,15 @@ __global__ void __launch_bounds__(MAX_THREADS) userbuffers_fp16_sum_inplace_gpu_ : tempbufptr[i * ibblocklines + line]; half *x = reinterpret_cast(&val[(i + 1) % UNROLLRS]); #pragma unroll - for (int j = 0; j < 16; j++) s[j] += x[j]; + for (int j = 0; j < 16; j++) + s[j] += x[j]; } #pragma unroll for (int i = 1; i < UNROLLRS; i++) { half *x = reinterpret_cast(&val[i]); #pragma unroll - for (int j = 0; j < 16; j++) s[j] += x[j]; + for (int j = 0; j < 16; j++) + s[j] += x[j]; } userptrmyrank[tempstart + line] = sum; } @@ -1048,7 +2063,8 @@ __global__ void __launch_bounds__(MAX_THREADS) userbuffers_fp16_sum_inplace_gpu_ } } - if (blockIdx.x == 0 && threadIdx.x == 0) gpuflag[NVTE_GF_STATE + op] = basecounter + numblocks; + if (blockIdx.x == 0 && threadIdx.x == 0) + gpuflag[NVTE_GF_STATE + op] = basecounter + numblocks; } else { // sync warp // reducethreads const int warp = blockIdx.x + (threadIdx.x >> 5); @@ -1105,9 +2121,11 @@ __global__ void __launch_bounds__(MAX_THREADS) userbuffers_fp16_sum_inplace_gpu_ for (int line = start_elem; line < end_aligned; line += myblockDim * gridDim.x * UNROLL) { int4 val[UNROLL]; #pragma unroll - for (int i = 0; i < UNROLL; i++) val[i] = peerptr[line + i * myblockDim * gridDim.x]; + for (int i = 0; i < UNROLL; i++) + val[i] = peerptr[line + i * myblockDim * gridDim.x]; #pragma unroll - for (int i = 0; i < UNROLL; i++) myptr[line + i * myblockDim * gridDim.x] = val[i]; + for (int i = 0; i < UNROLL; i++) + myptr[line + i * myblockDim * gridDim.x] = val[i]; } for (int line = end_aligned; line < end_elem; line += myblockDim * gridDim.x) myptr[line] = peerptr[line]; @@ -1125,102 +2143,134 @@ __global__ void userbuffers_fp16_sum_inplace_gpu_null(const int op, int *hostfla gpuflag[NVTE_GF_STATE + op] = basecounter; while (((volatile int *)gpuflag)[NVTE_GF_IBSHARPDONE] < basecounter) { } -} +} + +#define callranks_block(x) \ + if (comm->ar_nvsize == x) \ + userbuffers_fp16_sum_inplace_gpu_rr_blocked<<>>( \ + userbuffers_allreduceop_sharp, NVTE_REG0_OFFSET(comm), comm->ar_firstgpu, comm->ar_nvrank, \ + offset / 8, elements / 8, reinterpret_cast(comm->gpu_ptrs), \ + handler * comm->nvsize, blocksize / sizeof(int4) / comm->ar_nvsize, \ + reinterpret_cast(comm->hostflags), comm->flags, \ + (elements * 2 + blocksize - 1) / blocksize); + +#define callranks2_block(x) \ + if (ar_nvsize == x) { \ + int 
numblocks = (elements * 2 + blocksize - 1) / blocksize; \ + int headstart = numblocks - 1; /*<3?numblocks-1:3;*/ \ + if (headstart > maxcredit) \ + headstart = maxcredit; \ + if (x == 1) \ + headstart = maxcredit; \ + if (headstart > numblocks) \ + headstart = numblocks; \ + if (headstart == 0) \ + headstart = 1; \ + userbuffers_fp16_sum_inplace_gpu_rr_blocked2<<>>( \ + op, maxcredit, headstart, my_node, num_nodes, \ + NVTE_REG0_OFFSET(comm) + NVTE_REG0_FLAGS + \ + (op == userbuffers_allreduceop_nonsharp ? NVTE_REG0_COMMBUFFER : 0), \ + NVTE_REG0_OFFSET(comm) + NVTE_REG0_OPFLAGS * op, ar_firstgpu, ar_nvrank, ar_step, \ + offset / 8, elements / 8, reinterpret_cast(comm->gpu_ptrs), \ + handler * comm->nvsize, blocksize / sizeof(int4) / ar_nvsize, \ + reinterpret_cast(comm->hostflags), comm->flags, numblocks); \ + } + +#define callranks2_block_rs(x) \ + if (ar_nvsize == x) { \ + int numblocks = (elements * 2 + blocksize - 1) / blocksize; \ + int headstart = numblocks - 1; /*<3?numblocks-1:3;*/ \ + if (headstart > maxcredit) \ + headstart = maxcredit; \ + if (x == 1) \ + headstart = maxcredit; \ + if (headstart > numblocks) \ + headstart = numblocks; \ + if (headstart == 0) \ + headstart = 1; \ + userbuffers_fp16_sum_inplace_gpu_rr_blocked2_rs<<>>( \ + op, maxcredit, headstart, my_node, num_nodes, \ + NVTE_REG0_OFFSET(comm) + NVTE_REG0_FLAGS + \ + (op == userbuffers_allreduceop_nonsharp ? NVTE_REG0_COMMBUFFER : 0), \ + NVTE_REG0_OFFSET(comm) + NVTE_REG0_OPFLAGS * op, ar_firstgpu, ar_nvrank, ar_step, \ + offset / 8, elements / 8, reinterpret_cast(comm->gpu_ptrs), \ + handler * comm->nvsize, blocksize / sizeof(int4) / ar_nvsize, \ + reinterpret_cast(comm->hostflags), comm->flags, numblocks); \ + } + +#define callranks2_block_ag(x) \ + if (ar_nvsize == x) { \ + int numblocks = (elements * 2 + blocksize - 1) / blocksize; \ + int headstart = numblocks - 1; /*<3?numblocks-1:3;*/ \ + if (headstart > maxcredit) \ + headstart = maxcredit; \ + if (x == 1) \ + headstart = maxcredit; \ + if (headstart > numblocks) \ + headstart = numblocks; \ + if (headstart == 0) \ + headstart = 1; \ + userbuffers_fp16_sum_inplace_gpu_rr_blocked2_ag<<>>( \ + op, maxcredit, headstart, my_node, num_nodes, \ + NVTE_REG0_OFFSET(comm) + NVTE_REG0_FLAGS + \ + (op == userbuffers_allreduceop_nonsharp ? NVTE_REG0_COMMBUFFER : 0), \ + NVTE_REG0_OFFSET(comm) + NVTE_REG0_OPFLAGS * op, ar_firstgpu, ar_nvrank, ar_step, \ + offset / 8, elements / 8, reinterpret_cast(comm->gpu_ptrs), \ + handler * comm->nvsize, blocksize / sizeof(int4) / ar_nvsize, \ + reinterpret_cast(comm->hostflags), comm->flags, numblocks); \ + } + +#define callranks(x) \ + if (ar_nvsize == x) { \ + int arg1 = op - NVTE_MAX_OPS, \ + arg2 = NVTE_REG0_OFFSET(comm) - \ + (op == userbuffers_allreduceop_nonsharp ? 2 : 1) * NVTE_REG0_SINGLENODE + \ + NVTE_MAX_OPS, \ + arg3 = ar_firstgpu, arg4 = ar_nvrank, arg5 = ar_step, arg6 = offset / 8, \ + arg7 = elements / 8; \ + void **arg8 = reinterpret_cast(comm->gpu_ptrs); \ + int arg9 = handler * comm->nvsize; \ + void *kernelArgs[] = {reinterpret_cast(&arg1), reinterpret_cast(&arg2), \ + reinterpret_cast(&arg3), reinterpret_cast(&arg4), \ + reinterpret_cast(&arg5), reinterpret_cast(&arg6), \ + reinterpret_cast(&arg7), reinterpret_cast(&arg8), \ + reinterpret_cast(&arg9)}; \ + CUDACHECK(cudaLaunchKernelExC( \ + &cfg, \ + reinterpret_cast(comm->use_rr_kernel ? 
userbuffers_fp16_sum_inplace_gpu_rr \ + : userbuffers_fp16_sum_inplace_gpu_rw), \ + kernelArgs)); \ + } -#define callranks_block(x) \ - if (comm->ar_nvsize == x) \ - userbuffers_fp16_sum_inplace_gpu_rr_blocked<<>>( \ - userbuffers_allreduceop_sharp, NVTE_REG0_OFFSET(comm), comm->ar_firstgpu, comm->ar_nvrank, \ - offset / 8, elements / 8, reinterpret_cast(comm->gpu_ptrs), \ - handler * comm->nvsize, blocksize / sizeof(int4) / comm->ar_nvsize, \ - reinterpret_cast(comm->hostflags), comm->flags, \ - (elements * 2 + blocksize - 1) / blocksize); +#define callranksMC(x) \ + if (ar_nvsize == x) { \ + int arg1 = op - NVTE_MAX_OPS, \ + arg2 = NVTE_REG0_OFFSET(comm) - \ + (op == userbuffers_allreduceop_nonsharp ? 2 : 1) * NVTE_REG0_SINGLENODE + \ + NVTE_MAX_OPS, \ + arg3 = ar_firstgpu, arg4 = ar_nvrank, arg5 = ar_step, arg6 = offset / 8, \ + arg7 = elements / 8; \ + void **arg8 = reinterpret_cast(comm->gpu_ptrs); \ + int arg9 = handler * comm->nvsize; \ + void *arg10 = comm->mc_ptr[handler]; \ + void *kernelArgs[] = {reinterpret_cast(&arg1), reinterpret_cast(&arg2), \ + reinterpret_cast(&arg3), reinterpret_cast(&arg4), \ + reinterpret_cast(&arg5), reinterpret_cast(&arg6), \ + reinterpret_cast(&arg7), reinterpret_cast(&arg8), \ + reinterpret_cast(&arg9), reinterpret_cast(&arg10)}; \ + CUDACHECK(cudaLaunchKernelExC( \ + &cfg, reinterpret_cast(userbuffers_fp16_sum_inplace_gpu_mc), kernelArgs)); \ + } -#define callranks2_block(x) \ - if (ar_nvsize == x) { \ - int numblocks = (elements * 2 + blocksize - 1) / blocksize; \ - int headstart = numblocks - 1; /*<3?numblocks-1:3;*/ \ - if (headstart > maxcredit) headstart = maxcredit; \ - if (x == 1) headstart = maxcredit; \ - if (headstart > numblocks) headstart = numblocks; \ - if (headstart == 0) headstart = 1; \ - userbuffers_fp16_sum_inplace_gpu_rr_blocked2<<>>( \ - op, maxcredit, headstart, my_node, num_nodes, \ - NVTE_REG0_OFFSET(comm) + NVTE_REG0_FLAGS + \ - (op == userbuffers_allreduceop_nonsharp ? NVTE_REG0_COMMBUFFER : 0), \ - NVTE_REG0_OFFSET(comm) + NVTE_REG0_OPFLAGS * op, ar_firstgpu, ar_nvrank, ar_step, \ - offset / 8, elements / 8, reinterpret_cast(comm->gpu_ptrs), \ - handler * comm->nvsize, blocksize / sizeof(int4) / ar_nvsize, \ - reinterpret_cast(comm->hostflags), comm->flags, numblocks); \ - } - -#define callranks2_block_rs(x) \ - if (ar_nvsize == x) { \ - int numblocks = (elements * 2 + blocksize - 1) / blocksize; \ - int headstart = numblocks - 1; /*<3?numblocks-1:3;*/ \ - if (headstart > maxcredit) headstart = maxcredit; \ - if (x == 1) headstart = maxcredit; \ - if (headstart > numblocks) headstart = numblocks; \ - if (headstart == 0) headstart = 1; \ - userbuffers_fp16_sum_inplace_gpu_rr_blocked2_rs<<>>( \ - op, maxcredit, headstart, my_node, num_nodes, \ - NVTE_REG0_OFFSET(comm) + NVTE_REG0_FLAGS + \ - (op == userbuffers_allreduceop_nonsharp ? 
NVTE_REG0_COMMBUFFER : 0), \ - NVTE_REG0_OFFSET(comm) + NVTE_REG0_OPFLAGS * op, ar_firstgpu, ar_nvrank, ar_step, \ - offset / 8, elements / 8, reinterpret_cast(comm->gpu_ptrs), \ - handler * comm->nvsize, blocksize / sizeof(int4) / ar_nvsize, \ - reinterpret_cast(comm->hostflags), comm->flags, numblocks); \ - } - -#define callranks2_block_ag(x) \ - if (ar_nvsize == x) { \ - int numblocks = (elements * 2 + blocksize - 1) / blocksize; \ - int headstart = numblocks - 1; /*<3?numblocks-1:3;*/ \ - if (headstart > maxcredit) headstart = maxcredit; \ - if (x == 1) headstart = maxcredit; \ - if (headstart > numblocks) headstart = numblocks; \ - if (headstart == 0) headstart = 1; \ - userbuffers_fp16_sum_inplace_gpu_rr_blocked2_ag<<>>( \ - op, maxcredit, headstart, my_node, num_nodes, \ - NVTE_REG0_OFFSET(comm) + NVTE_REG0_FLAGS + \ - (op == userbuffers_allreduceop_nonsharp ? NVTE_REG0_COMMBUFFER : 0), \ - NVTE_REG0_OFFSET(comm) + NVTE_REG0_OPFLAGS * op, ar_firstgpu, ar_nvrank, ar_step, \ - offset / 8, elements / 8, reinterpret_cast(comm->gpu_ptrs), \ - handler * comm->nvsize, blocksize / sizeof(int4) / ar_nvsize, \ - reinterpret_cast(comm->hostflags), comm->flags, numblocks); \ - } - -#define callranks(x) \ - if (ar_nvsize == x) { \ - int arg1 = op - NVTE_MAX_OPS, \ - arg2 = NVTE_REG0_OFFSET(comm) - \ - (op == userbuffers_allreduceop_nonsharp ? 2 : 1) * NVTE_REG0_SINGLENODE + \ - NVTE_MAX_OPS, \ - arg3 = ar_firstgpu, arg4 = ar_nvrank, arg5 = ar_step, arg6 = offset / 8, \ - arg7 = elements / 8; \ - void **arg8 = reinterpret_cast(comm->gpu_ptrs); \ - int arg9 = handler * comm->nvsize; \ - void *kernelArgs[] = {reinterpret_cast(&arg1), reinterpret_cast(&arg2), \ - reinterpret_cast(&arg3), reinterpret_cast(&arg4), \ - reinterpret_cast(&arg5), reinterpret_cast(&arg6), \ - reinterpret_cast(&arg7), reinterpret_cast(&arg8), \ - reinterpret_cast(&arg9)}; \ - CUDACHECK(cudaLaunchKernelExC( \ - &cfg, \ - reinterpret_cast(comm->use_rr_kernel ? userbuffers_fp16_sum_inplace_gpu_rr \ - : userbuffers_fp16_sum_inplace_gpu_rw), \ - kernelArgs)); \ - } - -#define SETUP_LAUNCH_CONFIG(sms, threads, stream) \ - cudaLaunchConfig_t cfg = {sms, threads, 0, stream, NULL, 0}; \ - cudaLaunchAttribute attribute_ub[2]; \ - attribute_ub[1].id = cudaLaunchAttributeClusterDimension; \ - attribute_ub[1].val.clusterDim.x = sms % comm->cga_size == 0 ? comm->cga_size : 1; \ - attribute_ub[1].val.clusterDim.y = 1; \ - attribute_ub[1].val.clusterDim.z = 1; \ - attribute_ub[0].id = cudaLaunchAttributeCooperative; \ - cfg.attrs = attribute_ub; \ +#define SETUP_LAUNCH_CONFIG(sms, threads, stream) \ + cudaLaunchConfig_t cfg = {sms, threads, 0, stream, NULL, 0}; \ + cudaLaunchAttribute attribute_ub[2]; \ + attribute_ub[1].id = cudaLaunchAttributeClusterDimension; \ + attribute_ub[1].val.clusterDim.x = sms % comm->cga_size == 0 ? comm->cga_size : 1; \ + attribute_ub[1].val.clusterDim.y = 1; \ + attribute_ub[1].val.clusterDim.z = 1; \ + attribute_ub[0].id = cudaLaunchAttributeCooperative; \ + cfg.attrs = attribute_ub; \ cfg.numAttrs = comm->sm_arch >= 9 ? 
2 : 1; int allreduce_userbuff_inplace_gpu(const int handler, const int offset, const int elements, @@ -1232,10 +2282,12 @@ int allreduce_userbuff_inplace_gpu(const int handler, const int offset, const in const int ar_nvsize = comm->nvsize; const int ar_firstgpu = comm->ar_firstgpu; const int ar_nvrank = comm->ar_nvrank; - if (elements < 8) return 0; + if (elements < 8) + return 0; int sms = sms = comm->sms; int warps = comm->threads / 32; - if (warps < comm->ar_nvsize) warps = comm->ar_nvsize; + if (warps < comm->ar_nvsize) + warps = comm->ar_nvsize; if (comm->launch_mode & NVTE_LAUNCH_GPU) { if (comm->ar_nvsize == 1) @@ -1259,109 +2311,502 @@ int allreduce2_userbuff_inplace_gpu(const int maxcredit, const int handler, cons const int ar_step = op == userbuffers_allreduceop_nonsharp2 ? 1 : comm->ar2_nvsize; const int ar_nvsize = op == userbuffers_allreduceop_nonsharp ? comm->ar_nvsize : comm->ar2_nvsize; const int ar_nvrank = op == userbuffers_allreduceop_nonsharp ? comm->ar_nvrank : comm->ar2_nvrank; - - if (elements < 8) return 0; + + if (elements < 8) + return 0; + int sms = ar_nvsize == 1 ? 2 : comm->sms; + int warps = comm->threads / 32; + if (warps < ar_nvsize) + warps = ar_nvsize; + if (num_nodes > 1) { + callranks2_block(1) callranks2_block(2) callranks2_block(4) callranks2_block(8) + } else { + SETUP_LAUNCH_CONFIG(sms, warps * 32, stream); + callranks(2) callranks(4) callranks(8) + } + return sms; +} + +#define callranks_ag(x) \ + if (ar_nvsize == x) { \ + int arg1 = op - NVTE_MAX_OPS, \ + arg2 = NVTE_REG0_OFFSET(comm) - \ + (op == userbuffers_allreduceop_nonsharp ? 2 : 1) * NVTE_REG0_SINGLENODE + \ + NVTE_MAX_OPS, \ + arg3 = ar_firstgpu, arg4 = ar_nvrank, arg5 = ar_step, arg7 = elements / 8 / x, \ + arg6 = offset / 8 + (comm->use_rr_kernel ? 0 : arg4 * arg7); \ + void **arg8 = reinterpret_cast(comm->gpu_ptrs); \ + int arg9 = handler * comm->nvsize; \ + void *kernelArgs[] = {reinterpret_cast(&arg1), reinterpret_cast(&arg2), \ + reinterpret_cast(&arg3), reinterpret_cast(&arg4), \ + reinterpret_cast(&arg5), reinterpret_cast(&arg6), \ + reinterpret_cast(&arg7), reinterpret_cast(&arg8), \ + reinterpret_cast(&arg9)}; \ + CUDACHECK(cudaLaunchKernelExC( \ + &cfg, \ + reinterpret_cast(comm->use_rr_kernel ? userbuffers_fp16_sum_inplace_gpu_rr_ag \ + : userbuffers_fp16_sum_inplace_gpu_rw_ag), \ + kernelArgs)); \ + } + +#define callranks_agMC(x) \ + if (ar_nvsize == x) { \ + int arg1 = op - NVTE_MAX_OPS, \ + arg2 = NVTE_REG0_OFFSET(comm) - \ + (op == userbuffers_allreduceop_nonsharp ? 2 : 1) * NVTE_REG0_SINGLENODE + \ + NVTE_MAX_OPS, \ + arg3 = ar_firstgpu, arg4 = ar_nvrank, arg5 = ar_step, arg7 = elements / 8 / x, \ + arg6 = offset / 8 + arg4 * arg7; \ + void **arg8 = reinterpret_cast(comm->gpu_ptrs); \ + int arg9 = handler * comm->nvsize; \ + uint4 *arg10 = reinterpret_cast(comm->mc_ptr[handler]); \ + void *kernelArgs[] = {reinterpret_cast(&arg1), reinterpret_cast(&arg2), \ + reinterpret_cast(&arg3), reinterpret_cast(&arg4), \ + reinterpret_cast(&arg5), reinterpret_cast(&arg6), \ + reinterpret_cast(&arg7), reinterpret_cast(&arg8), \ + reinterpret_cast(&arg9), reinterpret_cast(&arg10)}; \ + CUDACHECK(cudaLaunchKernelExC( \ + &cfg, reinterpret_cast(userbuffers_fp16_sum_inplace_gpu_mc_ag), kernelArgs)); \ + } + +#define callranks_rs(x) \ + if (ar_nvsize == x) { \ + int arg1 = op - NVTE_MAX_OPS, \ + arg2 = NVTE_REG0_OFFSET(comm) - \ + (op == userbuffers_allreduceop_nonsharp ? 
2 : 1) * NVTE_REG0_SINGLENODE + \ + NVTE_MAX_OPS, \ + arg3 = ar_firstgpu, arg4 = ar_nvrank, arg5 = ar_step, arg7 = elements / 8 / x, \ + arg6 = offset / 8 + arg4 * arg7; \ + void **arg8 = reinterpret_cast(comm->gpu_ptrs); \ + int arg9 = handler * comm->nvsize; \ + void *kernelArgs[] = {reinterpret_cast(&arg1), reinterpret_cast(&arg2), \ + reinterpret_cast(&arg3), reinterpret_cast(&arg4), \ + reinterpret_cast(&arg5), reinterpret_cast(&arg6), \ + reinterpret_cast(&arg7), reinterpret_cast(&arg8), \ + reinterpret_cast(&arg9)}; \ + CUDACHECK(cudaLaunchKernelExC( \ + &cfg, reinterpret_cast(userbuffers_fp16_sum_inplace_gpu_rr_rs), kernelArgs)); \ + } + +#define callranks_rsMC(x) \ + if (ar_nvsize == x) { \ + int arg1 = op - NVTE_MAX_OPS, \ + arg2 = NVTE_REG0_OFFSET(comm) - \ + (op == userbuffers_allreduceop_nonsharp ? 2 : 1) * NVTE_REG0_SINGLENODE + \ + NVTE_MAX_OPS, \ + arg3 = ar_firstgpu, arg4 = ar_nvrank, arg5 = ar_step, arg7 = elements / 8 / x, \ + arg6 = offset / 8 + arg4 * arg7; \ + void **arg8 = reinterpret_cast(comm->gpu_ptrs); \ + int arg9 = handler * comm->nvsize; \ + void *arg10 = comm->mc_ptr[handler]; \ + void *kernelArgs[] = {reinterpret_cast(&arg1), reinterpret_cast(&arg2), \ + reinterpret_cast(&arg3), reinterpret_cast(&arg4), \ + reinterpret_cast(&arg5), reinterpret_cast(&arg6), \ + reinterpret_cast(&arg7), reinterpret_cast(&arg8), \ + reinterpret_cast(&arg9), reinterpret_cast(&arg10)}; \ + CUDACHECK(cudaLaunchKernelExC( \ + &cfg, reinterpret_cast(userbuffers_fp16_sum_inplace_gpu_mc_rs), kernelArgs)); \ + } + +#define callranks_rs_oop(x) \ + if (ar_nvsize == x) { \ + int arg1 = op - NVTE_MAX_OPS, \ + arg2 = NVTE_REG0_OFFSET(comm) - \ + (op == userbuffers_allreduceop_nonsharp ? 2 : 1) * NVTE_REG0_SINGLENODE + \ + NVTE_MAX_OPS, \ + arg3 = ar_firstgpu, arg4 = ar_nvrank, arg5 = ar_step, arg7 = elements / 8 / x, \ + arg6 = offset / 8 + arg4 * arg7, arg8 = rowelements / 8, arg9 = strideelements / 8; \ + void **arg10 = reinterpret_cast(comm->gpu_ptrs); \ + int arg11 = handler * comm->nvsize; \ + void *arg12 = output; \ + void *kernelArgs[] = {reinterpret_cast(&arg1), reinterpret_cast(&arg2), \ + reinterpret_cast(&arg3), reinterpret_cast(&arg4), \ + reinterpret_cast(&arg5), reinterpret_cast(&arg6), \ + reinterpret_cast(&arg7), reinterpret_cast(&arg8), \ + reinterpret_cast(&arg9), reinterpret_cast(&arg10), \ + reinterpret_cast(&arg11), reinterpret_cast(&arg12)}; \ + CUDACHECK(cudaLaunchKernelExC( \ + &cfg, reinterpret_cast(userbuffers_fp16_sum_inplace_gpu_rr_rs_oop), \ + kernelArgs)); \ + } + +#define callranks_rs_oop_fp8(x) \ + if (ar_nvsize == x) { \ + int arg1 = op - NVTE_MAX_OPS, \ + arg2 = NVTE_REG0_OFFSET(comm) - \ + (op == userbuffers_allreduceop_nonsharp ? 
2 : 1) * NVTE_REG0_SINGLENODE + \ + NVTE_MAX_OPS, \ + arg3 = ar_firstgpu, arg4 = ar_nvrank, arg5 = ar_step, arg7 = elements / 16 / x, \ + arg6 = offset / 16 + arg4 * arg7, arg8 = rowelements / 8, arg9 = strideelements / 8; \ + void **arg10 = reinterpret_cast(comm->gpu_ptrs); \ + int arg11 = handler * comm->nvsize; \ + void *arg12 = output; \ + float *arg13 = scale; \ + void *kernelArgs[] = {reinterpret_cast(&arg1), reinterpret_cast(&arg2), \ + reinterpret_cast(&arg3), reinterpret_cast(&arg4), \ + reinterpret_cast(&arg5), reinterpret_cast(&arg6), \ + reinterpret_cast(&arg7), reinterpret_cast(&arg8), \ + reinterpret_cast(&arg9), reinterpret_cast(&arg10), \ + reinterpret_cast(&arg11), reinterpret_cast(&arg12), \ + reinterpret_cast(&arg13)}; \ + CUDACHECK(cudaLaunchKernelExC( \ + &cfg, \ + reinterpret_cast(userbuffers_fp16_sum_inplace_gpu_rr_rs_oop_fp8), \ + kernelArgs)); \ + } + +#define callranks_rs_oopMC(x) \ + if (ar_nvsize == x) { \ + int arg1 = op - NVTE_MAX_OPS, \ + arg2 = NVTE_REG0_OFFSET(comm) - \ + (op == userbuffers_allreduceop_nonsharp ? 2 : 1) * NVTE_REG0_SINGLENODE + \ + NVTE_MAX_OPS, \ + arg3 = ar_firstgpu, arg4 = ar_nvrank, arg5 = ar_step, arg7 = elements / 8 / x, \ + arg6 = offset / 8 + arg4 * arg7, arg8 = rowelements / 8, arg9 = strideelements / 8; \ + void **arg10 = reinterpret_cast(comm->gpu_ptrs); \ + int arg11 = handler * comm->nvsize; \ + void *arg12 = output; \ + void *arg13 = comm->mc_ptr[handler]; \ + void *kernelArgs[] = {reinterpret_cast(&arg1), reinterpret_cast(&arg2), \ + reinterpret_cast(&arg3), reinterpret_cast(&arg4), \ + reinterpret_cast(&arg5), reinterpret_cast(&arg6), \ + reinterpret_cast(&arg7), reinterpret_cast(&arg8), \ + reinterpret_cast(&arg9), reinterpret_cast(&arg10), \ + reinterpret_cast(&arg11), reinterpret_cast(&arg12), \ + reinterpret_cast(&arg13)}; \ + CUDACHECK(cudaLaunchKernelExC( \ + &cfg, reinterpret_cast(userbuffers_fp16_sum_inplace_gpu_mc_rs_oop), \ + kernelArgs)); \ + } + +#define callranks_rs_oop_atomic_fp8(x) \ + if (ar_nvsize == x) { \ + int arg1 = op - NVTE_MAX_OPS, \ + arg2 = NVTE_REG0_OFFSET(comm) - \ + (op == userbuffers_allreduceop_nonsharp ? 2 : 1) * NVTE_REG0_SINGLENODE + \ + NVTE_MAX_OPS, \ + arg3 = ar_firstgpu, arg4 = ar_nvrank, arg5 = ar_step, arg7 = elements / 16 / x, \ + arg6 = offset / 16, arg8 = rowelements / 8, arg9 = strideelements_out / 8, \ + arg10 = strideelements_in / 16; \ + void **arg11 = reinterpret_cast(comm->gpu_ptrs); \ + int arg12 = handler * comm->nvsize; \ + void *arg13 = output; \ + float *arg14 = scale; \ + void *arg15 = counters; \ + int arg16 = numchunks, arg17 = atomicindex; \ + void *kernelArgs[] = {reinterpret_cast(&arg1), reinterpret_cast(&arg2), \ + reinterpret_cast(&arg3), reinterpret_cast(&arg4), \ + reinterpret_cast(&arg5), reinterpret_cast(&arg6), \ + reinterpret_cast(&arg7), reinterpret_cast(&arg8), \ + reinterpret_cast(&arg9), reinterpret_cast(&arg10), \ + reinterpret_cast(&arg11), reinterpret_cast(&arg12), \ + reinterpret_cast(&arg13), reinterpret_cast(&arg14), \ + reinterpret_cast(&arg15), reinterpret_cast(&arg16), \ + reinterpret_cast(&arg17)}; \ + CUDACHECK(cudaLaunchKernelExC( \ + &cfg, \ + reinterpret_cast( \ + userbuffers_fp16_sum_inplace_gpu_rr_rs_oop_atomic_fp8), \ + kernelArgs)); \ + } + +#define callranks_rs_oop_stride(x) \ + if (ar_nvsize == x) { \ + int arg1 = op - NVTE_MAX_OPS, \ + arg2 = NVTE_REG0_OFFSET(comm) - \ + (op == userbuffers_allreduceop_nonsharp ? 
2 : 1) * NVTE_REG0_SINGLENODE + \ + NVTE_MAX_OPS, \ + arg3 = ar_firstgpu, arg4 = ar_nvrank, arg5 = ar_step, arg7 = elements / 8 / x, \ + arg6 = offset / 8, arg8 = rowelements / 8, arg9 = strideelements / 8; \ + void **arg10 = reinterpret_cast(comm->gpu_ptrs); \ + int arg11 = handler * comm->nvsize; \ + void *arg12 = output; \ + void *kernelArgs[] = {reinterpret_cast(&arg1), reinterpret_cast(&arg2), \ + reinterpret_cast(&arg3), reinterpret_cast(&arg4), \ + reinterpret_cast(&arg5), reinterpret_cast(&arg6), \ + reinterpret_cast(&arg7), reinterpret_cast(&arg8), \ + reinterpret_cast(&arg9), reinterpret_cast(&arg10), \ + reinterpret_cast(&arg11), reinterpret_cast(&arg12)}; \ + CUDACHECK(cudaLaunchKernelExC( \ + &cfg, reinterpret_cast(userbuffers_fp16_sum_inplace_gpu_rr_rs_oop_stride), \ + kernelArgs)); \ + } + +#if 0 +#define callranks_rs_oop_stride_atomic_fp8(x) \ + if (ar_nvsize == x) { \ + int arg1 = op - NVTE_MAX_OPS, \ + arg2 = NVTE_REG0_OFFSET(comm) - \ + (op == userbuffers_allreduceop_nonsharp ? 2 : 1) * NVTE_REG0_SINGLENODE + \ + NVTE_MAX_OPS, \ + arg3 = ar_firstgpu, arg4 = ar_nvrank, arg5 = ar_step, arg7 = elements / 16 / x, \ + arg6 = offset / 16, arg8 = rowelements / 8, arg9 = strideelements / 8, arg10 = numchunks; \ + void **arg11 = reinterpret_cast(comm->gpu_ptrs); \ + int arg12 = handler * comm->nvsize; \ + void *arg13 = output; \ + void *arg14 = counters; \ + float *arg15 = scale; \ + void *kernelArgs[] = {reinterpret_cast(&arg1), reinterpret_cast(&arg2), \ + reinterpret_cast(&arg3), reinterpret_cast(&arg4), \ + reinterpret_cast(&arg5), reinterpret_cast(&arg6), \ + reinterpret_cast(&arg7), reinterpret_cast(&arg8), \ + reinterpret_cast(&arg9), reinterpret_cast(&arg10), \ + reinterpret_cast(&arg11), reinterpret_cast(&arg12), \ + reinterpret_cast(&arg13), reinterpret_cast(&arg14), \ + reinterpret_cast(&arg15)}; \ + CUDACHECK(cudaLaunchKernelExC( \ + &cfg, \ + reinterpret_cast( \ + userbuffers_fp16_sum_inplace_gpu_rr_rs_oop_stride_atomic_fp8), \ + kernelArgs)); \ + } +#endif + +#define callranks_rs_oop_stride_atomic(x) \ + if (ar_nvsize == x) { \ + int arg1 = op - NVTE_MAX_OPS, \ + arg2 = NVTE_REG0_OFFSET(comm) - \ + (op == userbuffers_allreduceop_nonsharp ? 2 : 1) * NVTE_REG0_SINGLENODE + \ + NVTE_MAX_OPS, \ + arg3 = ar_firstgpu, arg4 = ar_nvrank, arg5 = ar_step, arg7 = elements / 8 / x, \ + arg6 = offset / 8, arg8 = rowelements / 8, arg9 = strideelements / 8, arg10 = numchunks; \ + void **arg11 = reinterpret_cast(comm->gpu_ptrs); \ + int arg12 = handler * comm->nvsize; \ + void *arg13 = output; \ + void *arg14 = counters; \ + void *kernelArgs[] = {reinterpret_cast(&arg1), reinterpret_cast(&arg2), \ + reinterpret_cast(&arg3), reinterpret_cast(&arg4), \ + reinterpret_cast(&arg5), reinterpret_cast(&arg6), \ + reinterpret_cast(&arg7), reinterpret_cast(&arg8), \ + reinterpret_cast(&arg9), reinterpret_cast(&arg10), \ + reinterpret_cast(&arg11), reinterpret_cast(&arg12), \ + reinterpret_cast(&arg13), reinterpret_cast(&arg14)}; \ + CUDACHECK(cudaLaunchKernelExC( \ + &cfg, \ + reinterpret_cast(userbuffers_fp16_sum_inplace_gpu_rr_rs_oop_stride_atomic), \ + kernelArgs)); \ + } + +#define callranks_rs_oop_stride_multiatomic(x) \ + if (ar_nvsize == x) { \ + int arg1 = op - NVTE_MAX_OPS, \ + arg2 = NVTE_REG0_OFFSET(comm) - \ + (op == userbuffers_allreduceop_nonsharp ? 
2 : 1) * NVTE_REG0_SINGLENODE + \ + NVTE_MAX_OPS, \ + arg3 = ar_firstgpu, arg4 = ar_nvrank, arg5 = ar_step, arg7 = elements / 8 / x, \ + arg6 = offset / 8, arg8 = rowelements / 8, arg9 = strideelements / 8, arg10 = numchunks; \ + void **arg11 = reinterpret_cast(comm->gpu_ptrs); \ + int arg12 = handler * comm->nvsize; \ + void *arg13 = output; \ + void *arg14 = counters; \ + void *kernelArgs[] = {reinterpret_cast(&arg1), reinterpret_cast(&arg2), \ + reinterpret_cast(&arg3), reinterpret_cast(&arg4), \ + reinterpret_cast(&arg5), reinterpret_cast(&arg6), \ + reinterpret_cast(&arg7), reinterpret_cast(&arg8), \ + reinterpret_cast(&arg9), reinterpret_cast(&arg10), \ + reinterpret_cast(&arg11), reinterpret_cast(&arg12), \ + reinterpret_cast(&arg13), reinterpret_cast(&arg14)}; \ + CUDACHECK( \ + cudaLaunchKernelExC(&cfg, \ + reinterpret_cast( \ + userbuffers_fp16_sum_inplace_gpu_rr_rs_oop_stride_multiatomic), \ + kernelArgs)); \ + } + +int reducescatter2_userbuff_inplace_gpu(const int maxcredit, const int handler, const int offset, + const int elements, const int blocksize, communicator *comm, + cudaStream_t stream, int op) { + // schedule GPU kernel only + // CPU/SHARP part is responsibility of caller + + const int num_nodes = op == userbuffers_allreduceop_nonsharp ? comm->num_nodes : comm->num2_nodes; + const int my_node = op == userbuffers_allreduceop_nonsharp ? comm->my_node : comm->my2_node; + const int ar_firstgpu = + op == userbuffers_allreduceop_nonsharp ? comm->ar_firstgpu : comm->ar2_firstgpu; + const int ar_step = op == userbuffers_allreduceop_nonsharp2 ? 1 : comm->ar2_nvsize; + const int ar_nvsize = op == userbuffers_allreduceop_nonsharp ? comm->ar_nvsize : comm->ar2_nvsize; + const int ar_nvrank = op == userbuffers_allreduceop_nonsharp ? comm->ar_nvrank : comm->ar2_nvrank; + + if (elements < 8) + return 0; + int sms = ar_nvsize == 1 ? 2 : comm->sms; + int warps = comm->threads / 32; + if (warps < ar_nvsize) + warps = ar_nvsize; + + if (num_nodes > 1) { + callranks2_block_rs(1) callranks2_block_rs(2) callranks2_block_rs(4) callranks2_block_rs(8) + } else { + SETUP_LAUNCH_CONFIG(sms, warps * 32, stream); + callranks_rs(2) callranks_rs(4) callranks_rs(8) + } + return sms; +} + +void reducescatter2_userbuff_strided(void *output, const int handler, const int offset, + const int rowelements, const int colelements, + const int strideelements, communicator *comm, + cudaStream_t stream) { + const int elements = rowelements * colelements; + const int op = userbuffers_allreduceop_nonsharp2; + const int blocksize = elements * 2; + const int ar_firstgpu = + op == userbuffers_allreduceop_nonsharp ? comm->ar_firstgpu : comm->ar2_firstgpu; + const int ar_step = op == userbuffers_allreduceop_nonsharp2 ? 1 : comm->ar2_nvsize; + const int ar_nvsize = op == userbuffers_allreduceop_nonsharp ? comm->ar_nvsize : comm->ar2_nvsize; + const int ar_nvrank = op == userbuffers_allreduceop_nonsharp ? comm->ar_nvrank : comm->ar2_nvrank; + + if (elements < 64) + return; + int sms = ar_nvsize == 1 ? 
2 : comm->sms; + int warps = comm->threads / 32; + if (warps < ar_nvsize) + warps = ar_nvsize; + + SETUP_LAUNCH_CONFIG(sms, warps * 32, stream); + callranks_rs_oop_stride(2) callranks_rs_oop_stride(4) callranks_rs_oop_stride(8) +} +void reducescatter2_userbuff_strided_atomic(void *output, const int handler, const int offset, + const int rowelements, const int colelements, + const int strideelements, const int numchunks, + void *counters, communicator *comm, + cudaStream_t stream) { + const int elements = rowelements * colelements; + const int op = userbuffers_allreduceop_nonsharp2; + const int blocksize = elements * 2; + const int ar_firstgpu = + op == userbuffers_allreduceop_nonsharp ? comm->ar_firstgpu : comm->ar2_firstgpu; + const int ar_step = op == userbuffers_allreduceop_nonsharp2 ? 1 : comm->ar2_nvsize; + const int ar_nvsize = op == userbuffers_allreduceop_nonsharp ? comm->ar_nvsize : comm->ar2_nvsize; + const int ar_nvrank = op == userbuffers_allreduceop_nonsharp ? comm->ar_nvrank : comm->ar2_nvrank; + + if (elements < 64) + return; + int sms = ar_nvsize == 1 ? 2 : comm->sms; + int warps = comm->threads / 32; + if (warps < ar_nvsize) + warps = ar_nvsize; + + SETUP_LAUNCH_CONFIG(sms, warps * 32, stream); + callranks_rs_oop_stride_atomic(2) callranks_rs_oop_stride_atomic(4) + callranks_rs_oop_stride_atomic(8) +} + +#if 0 + template + void reducescatter2_userbuff_strided_atomic_fp8( + void* output, float *scale, const int handler, const int offset, const int rowelements, + const int colelements, const int strideelements, const int numchunks, void *counters, + communicator* comm, cudaStream_t stream) { + const int elements = rowelements*colelements; + const int op = userbuffers_allreduceop_nonsharp2; + const int blocksize = elements; + const int ar_firstgpu = op == userbuffers_allreduceop_nonsharp ? + comm->ar_firstgpu : comm->ar2_firstgpu; + const int ar_step = op == userbuffers_allreduceop_nonsharp2 ? + 1 : comm->ar2_nvsize; + const int ar_nvsize = op == userbuffers_allreduceop_nonsharp ? + comm->ar_nvsize : comm->ar2_nvsize; + const int ar_nvrank = op == userbuffers_allreduceop_nonsharp ? + comm->ar_nvrank : comm->ar2_nvrank; + + assert(comm->sm_arch >= 9); + if (elements < 128) return; + int sms = ar_nvsize == 1 ? 2 : comm->sms; + int warps = comm->threads/32; + if (warps < ar_nvsize) warps = ar_nvsize; + + SETUP_LAUNCH_CONFIG(sms, warps*32, stream); + callranks_rs_oop_stride_atomic_fp8(2) + callranks_rs_oop_stride_atomic_fp8(4) + callranks_rs_oop_stride_atomic_fp8(8) + } +#endif +template +void reducescatter2_userbuff_strided_universal_fp8(void *output, float *scale, const int handler, + const int offset, const int rowelements, + const int colelements, + const int strideelements_out, + const int strideelements_in, const int numchunks, + const int atomicindex, void *counters, + communicator *comm, cudaStream_t stream) { + const int elements = rowelements * colelements; + const int op = userbuffers_allreduceop_nonsharp2; + const int blocksize = elements; + const int ar_firstgpu = + op == userbuffers_allreduceop_nonsharp ? comm->ar_firstgpu : comm->ar2_firstgpu; + const int ar_step = op == userbuffers_allreduceop_nonsharp2 ? 1 : comm->ar2_nvsize; + const int ar_nvsize = op == userbuffers_allreduceop_nonsharp ? comm->ar_nvsize : comm->ar2_nvsize; + const int ar_nvrank = op == userbuffers_allreduceop_nonsharp ? comm->ar_nvrank : comm->ar2_nvrank; + assert(comm->sm_arch >= 9); + if (elements < 128) + return; int sms = ar_nvsize == 1 ? 
2 : comm->sms; int warps = comm->threads / 32; - if (warps < ar_nvsize) warps = ar_nvsize; - if (num_nodes > 1) { - callranks2_block(1) callranks2_block(2) callranks2_block(4) callranks2_block(8) - } else { - SETUP_LAUNCH_CONFIG(sms, warps * 32, stream); - callranks(2) callranks(4) callranks(8) - } - return sms; -} - -#define callranks_ag(x) \ - if (ar_nvsize == x) { \ - int arg1 = op - NVTE_MAX_OPS, \ - arg2 = NVTE_REG0_OFFSET(comm) - \ - (op == userbuffers_allreduceop_nonsharp ? 2 : 1) * NVTE_REG0_SINGLENODE + \ - NVTE_MAX_OPS, \ - arg3 = ar_firstgpu, arg4 = ar_nvrank, arg5 = ar_step, arg7 = elements / 8 / x, \ - arg6 = offset / 8 + (comm->use_rr_kernel ? 0 : arg4 * arg7); \ - void **arg8 = reinterpret_cast(comm->gpu_ptrs); \ - int arg9 = handler * comm->nvsize; \ - void *kernelArgs[] = {reinterpret_cast(&arg1), reinterpret_cast(&arg2), \ - reinterpret_cast(&arg3), reinterpret_cast(&arg4), \ - reinterpret_cast(&arg5), reinterpret_cast(&arg6), \ - reinterpret_cast(&arg7), reinterpret_cast(&arg8), \ - reinterpret_cast(&arg9)}; \ - CUDACHECK(cudaLaunchKernelExC( \ - &cfg, \ - reinterpret_cast(comm->use_rr_kernel ? userbuffers_fp16_sum_inplace_gpu_rr_ag \ - : userbuffers_fp16_sum_inplace_gpu_rw_ag), \ - kernelArgs)); \ - } + if (warps < ar_nvsize) + warps = ar_nvsize; -#define callranks_rs(x) \ - if (ar_nvsize == x) { \ - int arg1 = op - NVTE_MAX_OPS, \ - arg2 = NVTE_REG0_OFFSET(comm) - \ - (op == userbuffers_allreduceop_nonsharp ? 2 : 1) * NVTE_REG0_SINGLENODE + \ - NVTE_MAX_OPS, \ - arg3 = ar_firstgpu, arg4 = ar_nvrank, arg5 = ar_step, arg7 = elements / 8 / x, \ - arg6 = offset / 8 + arg4 * arg7; \ - void **arg8 = reinterpret_cast(comm->gpu_ptrs); \ - int arg9 = handler * comm->nvsize; \ - void *kernelArgs[] = {reinterpret_cast(&arg1), reinterpret_cast(&arg2), \ - reinterpret_cast(&arg3), reinterpret_cast(&arg4), \ - reinterpret_cast(&arg5), reinterpret_cast(&arg6), \ - reinterpret_cast(&arg7), reinterpret_cast(&arg8), \ - reinterpret_cast(&arg9)}; \ - CUDACHECK(cudaLaunchKernelExC( \ - &cfg, reinterpret_cast(userbuffers_fp16_sum_inplace_gpu_rr_rs), kernelArgs)); \ - } - -#define callranks_rs_oop(x) \ - if (ar_nvsize == x) { \ - int arg1 = op - NVTE_MAX_OPS, \ - arg2 = NVTE_REG0_OFFSET(comm) - \ - (op == userbuffers_allreduceop_nonsharp ? 
2 : 1) * NVTE_REG0_SINGLENODE + \ - NVTE_MAX_OPS, \ - arg3 = ar_firstgpu, arg4 = ar_nvrank, arg5 = ar_step, arg7 = elements / 8 / x, \ - arg6 = offset / 8 + arg4 * arg7, arg8 = rowelements / 8, arg9 = strideelements / 8; \ - void **arg10 = reinterpret_cast(comm->gpu_ptrs); \ - int arg11 = handler * comm->nvsize; \ - void *arg12 = output; \ - void *kernelArgs[] = {reinterpret_cast(&arg1), reinterpret_cast(&arg2), \ - reinterpret_cast(&arg3), reinterpret_cast(&arg4), \ - reinterpret_cast(&arg5), reinterpret_cast(&arg6), \ - reinterpret_cast(&arg7), reinterpret_cast(&arg8), \ - reinterpret_cast(&arg9), reinterpret_cast(&arg10), \ - reinterpret_cast(&arg11), reinterpret_cast(&arg12)}; \ - CUDACHECK(cudaLaunchKernelExC( \ - &cfg, reinterpret_cast(userbuffers_fp16_sum_inplace_gpu_rr_rs_oop), \ - kernelArgs)); \ - } + SETUP_LAUNCH_CONFIG(sms, warps * 32, stream); + callranks_rs_oop_atomic_fp8(2) callranks_rs_oop_atomic_fp8(4) callranks_rs_oop_atomic_fp8(8) +} -int reducescatter2_userbuff_inplace_gpu(const int maxcredit, const int handler, const int offset, - const int elements, const int blocksize, communicator *comm, - cudaStream_t stream, int op) { - // schedule GPU kernel only - // CPU/SHARP part is responsibility of caller +template +void reducescatter2_userbuff_strided_atomic_fp8(void *output, float *scale, const int handler, + const int offset, const int rowelements, + const int colelements, const int strideelements_out, + const int strideelements_in, const int numchunks, + void *counters, communicator *comm, + cudaStream_t stream) { + reducescatter2_userbuff_strided_universal_fp8( + output, scale, handler, offset, rowelements, colelements, strideelements_out, + strideelements_in, 1, numchunks, counters /*nullptr*/, comm, stream); +} +template +void reducescatter2_userbuff_strided_multiatomic_fp8( + void *output, float *scale, const int handler, const int offset, const int rowelements, + const int colelements, const int strideelements_out, const int strideelements_in, + const int numchunks, void *counters, communicator *comm, cudaStream_t stream) { + reducescatter2_userbuff_strided_universal_fp8( + output, scale, handler, offset, rowelements, colelements, strideelements_out, + strideelements_in, numchunks, 0, counters /*nullptr*/, comm, stream); +} - const int num_nodes = op == userbuffers_allreduceop_nonsharp ? comm->num_nodes : comm->num2_nodes; - const int my_node = op == userbuffers_allreduceop_nonsharp ? comm->my_node : comm->my2_node; +void reducescatter2_userbuff_strided_multiatomic(void *output, const int handler, const int offset, + const int rowelements, const int colelements, + const int strideelements, const int numchunks, + void *counters, communicator *comm, + cudaStream_t stream) { + const int elements = rowelements * colelements; + const int op = userbuffers_allreduceop_nonsharp2; + const int blocksize = elements * 2; const int ar_firstgpu = op == userbuffers_allreduceop_nonsharp ? comm->ar_firstgpu : comm->ar2_firstgpu; const int ar_step = op == userbuffers_allreduceop_nonsharp2 ? 1 : comm->ar2_nvsize; const int ar_nvsize = op == userbuffers_allreduceop_nonsharp ? comm->ar_nvsize : comm->ar2_nvsize; const int ar_nvrank = op == userbuffers_allreduceop_nonsharp ? comm->ar_nvrank : comm->ar2_nvrank; - if (elements < 8) return 0; + if (elements < 64) + return; int sms = ar_nvsize == 1 ? 
2 : comm->sms; int warps = comm->threads / 32; - if (warps < ar_nvsize) warps = ar_nvsize; + if (warps < ar_nvsize) + warps = ar_nvsize; - if (num_nodes > 1) { - callranks2_block_rs(1) callranks2_block_rs(2) callranks2_block_rs(4) callranks2_block_rs(8) - } else { - SETUP_LAUNCH_CONFIG(sms, warps * 32, stream); - callranks_rs(2) callranks_rs(4) callranks_rs(8) - } - return sms; + SETUP_LAUNCH_CONFIG(sms, warps * 32, stream); + // if(comm->use_mc && (comm->memflags[handler] & NVTE_UB_MEM_MC_CREATED)) { + // //callranks_rs_oopMC(2) + // //callranks_rs_oopMC(4) + // //callranks_rs_oopMC(8) + // } else { + // if(comm->memflags[handler] & NVTE_UB_MEM_UC_CONTIG) { + // //callranks_rs_oopUCPTR(2) + // //callranks_rs_oopUCPTR(4) + // //callranks_rs_oopUCPTR(8) + // } else { + callranks_rs_oop_stride_multiatomic(2) callranks_rs_oop_stride_multiatomic(4) + callranks_rs_oop_stride_multiatomic(8) + // } + //} } int allgather2_userbuff_inplace_gpu(const int maxcredit, const int handler, const int offset, @@ -1378,10 +2823,12 @@ int allgather2_userbuff_inplace_gpu(const int maxcredit, const int handler, cons const int ar_nvsize = op == userbuffers_allreduceop_nonsharp ? comm->ar_nvsize : comm->ar2_nvsize; const int ar_nvrank = op == userbuffers_allreduceop_nonsharp ? comm->ar_nvrank : comm->ar2_nvrank; - if (elements < 8) return 0; + if (elements < 8) + return 0; int sms = ar_nvsize == 1 ? 2 : comm->sms; int warps = comm->threads / 32; - if (warps < ar_nvsize) warps = ar_nvsize; + if (warps < ar_nvsize) + warps = ar_nvsize; if (num_nodes > 1) { callranks2_block_ag(1) callranks2_block_ag(2) callranks2_block_ag(4) callranks2_block_ag(8) @@ -1402,13 +2849,15 @@ void allgather2_userbuff_inplace(const int handler, const int offset, const int const int ar_nvsize = op == userbuffers_allreduceop_nonsharp ? comm->ar_nvsize : comm->ar2_nvsize; const int ar_nvrank = op == userbuffers_allreduceop_nonsharp ? comm->ar_nvrank : comm->ar2_nvrank; - if (elements < 64) return; + if (elements < 64) + return; int sms = ar_nvsize == 1 ? 2 : comm->sms; int warps = comm->threads / 32; - if (warps < ar_nvsize) warps = ar_nvsize; + if (warps < ar_nvsize) + warps = ar_nvsize; SETUP_LAUNCH_CONFIG(sms, warps * 32, stream); - callranks_ag(2) callranks_ag(4) callranks_ag(8) + callranks_ag(2) callranks_ag(4) callranks_ag(8) } void allgather2_userbuff_inplace_sliced(const int handler, const int offset, const int elements, @@ -1436,13 +2885,15 @@ void reducescatter2_userbuff_inplace(const int handler, const int offset, const const int ar_nvsize = op == userbuffers_allreduceop_nonsharp ? comm->ar_nvsize : comm->ar2_nvsize; const int ar_nvrank = op == userbuffers_allreduceop_nonsharp ? comm->ar_nvrank : comm->ar2_nvrank; - if (elements < 64) return; + if (elements < 64) + return; int sms = ar_nvsize == 1 ? 2 : comm->sms; int warps = comm->threads / 32; - if (warps < ar_nvsize) warps = ar_nvsize; + if (warps < ar_nvsize) + warps = ar_nvsize; SETUP_LAUNCH_CONFIG(sms, warps * 32, stream); - callranks_rs(2) callranks_rs(4) callranks_rs(8) + callranks_rs(2) callranks_rs(4) callranks_rs(8) } void reducescatter2_userbuff_stridedoutput(void *output, const int handler, const int offset, const int rowelements, const int colelements, @@ -1457,21 +2908,124 @@ void reducescatter2_userbuff_stridedoutput(void *output, const int handler, cons const int ar_nvsize = op == userbuffers_allreduceop_nonsharp ? comm->ar_nvsize : comm->ar2_nvsize; const int ar_nvrank = op == userbuffers_allreduceop_nonsharp ? 
comm->ar_nvrank : comm->ar2_nvrank; - if (elements < 64) return; + if (elements < 64) + return; int sms = ar_nvsize == 1 ? 2 : comm->sms; int warps = comm->threads / 32; - if (warps < ar_nvsize) warps = ar_nvsize; + if (warps < ar_nvsize) + warps = ar_nvsize; SETUP_LAUNCH_CONFIG(sms, warps * 32, stream); - callranks_rs_oop(2) callranks_rs_oop(4) callranks_rs_oop(8) + callranks_rs_oop(2) callranks_rs_oop(4) callranks_rs_oop(8) } void reducescatter2_userbuff(void *output, const int handler, const int offset, const int elements, communicator *comm, cudaStream_t stream) { reducescatter2_userbuff_stridedoutput(output, handler, offset, elements, 1, 0, comm, stream); } +template +void reducescatter2_userbuff_stridedoutput_fp8(void *output, float *scale, const int handler, + const int offset, const int rowelements, + const int colelements, const int strideelements, + communicator *comm, cudaStream_t stream) { + const int elements = rowelements * colelements; + const int op = userbuffers_allreduceop_nonsharp2; + const int blocksize = elements; + const int ar_firstgpu = + op == userbuffers_allreduceop_nonsharp ? comm->ar_firstgpu : comm->ar2_firstgpu; + const int ar_step = op == userbuffers_allreduceop_nonsharp2 ? 1 : comm->ar2_nvsize; + const int ar_nvsize = op == userbuffers_allreduceop_nonsharp ? comm->ar_nvsize : comm->ar2_nvsize; + const int ar_nvrank = op == userbuffers_allreduceop_nonsharp ? comm->ar_nvrank : comm->ar2_nvrank; + assert(comm->sm_arch >= 9); + if (elements < 128) + return; + int sms = ar_nvsize == 1 ? 2 : comm->sms; + int warps = comm->threads / 32; + if (warps < ar_nvsize) + warps = ar_nvsize; + + SETUP_LAUNCH_CONFIG(sms, warps * 32, stream); + callranks_rs_oop_fp8(2) callranks_rs_oop_fp8(4) callranks_rs_oop_fp8(8) +} + +template +void reducescatter2_userbuff_fp8(void *output, float *scale, const int handler, const int offset, + const int elements, communicator *comm, cudaStream_t stream) { + reducescatter2_userbuff_stridedoutput_fp8(output, scale, handler, offset, elements, 1, 0, + comm, stream); +} + +template void reducescatter2_userbuff_fp8<__nv_fp8_e5m2>(void *output, float *scale, + const int handler, const int offset, + const int elements, communicator *comm, + cudaStream_t stream = 0); +template void reducescatter2_userbuff_fp8<__nv_fp8_e4m3>(void *output, float *scale, + const int handler, const int offset, + const int elements, communicator *comm, + cudaStream_t stream = 0); +#if 0 +template void reducescatter2_userbuff_strided_atomic_fp8<__nv_fp8_e4m3>( + void* output, float *scale, const int handler, const int offset, + const int rowelements, const int colelements, const int strideelements, + const int numchunks, void *counters, communicator* comm, cudaStream_t stream = 0); +#endif +template void reducescatter2_userbuff_strided_atomic_fp8<__nv_fp8_e4m3>( + void *output, float *scale, const int handler, const int offset, const int rowelements, + const int colelements, const int strideelements_out, const int strideelements_in, + const int numchunks, void *counters, communicator *comm, cudaStream_t stream = 0); +template void reducescatter2_userbuff_strided_multiatomic_fp8<__nv_fp8_e4m3>( + void *output, float *scale, const int handler, const int offset, const int rowelements, + const int colelements, const int strideelements_out, const int strideelements_in, + const int numchunks, void *counters, communicator *comm, cudaStream_t stream = 0); +__global__ void __launch_bounds__(MAX_THREADS) + kuserbuffers_pullsendrecv(int myrank, int peer, int *recv_id, int 
*send_flagptr, + int *recv_flagptr, int4 *srcptr, int4 *dstptr, const int lines) { + if (blockIdx.x == 0 && threadIdx.x == 0) { + atomicAdd_system(send_flagptr, 1); + } + +#define UNROLLCOPY 8 + const int start_elem = threadIdx.x + blockDim.x * blockIdx.x; + const int end_elem = lines; + const int aligned_elem = (end_elem - start_elem) & (~(blockDim.x * gridDim.x * UNROLLCOPY - 1)); + const int end_aligned = start_elem + aligned_elem; + + if (threadIdx.x == 0) { + const int signal_id = (*recv_id) + 1; + volatile int *flag = (volatile int *)recv_flagptr; + clock_t s = clock64(); + while (*flag < signal_id) { + if (clock64() - s > TIMEOUT) { + printf("[%d from %d] pullrecv: expected %d, stuck with %d\n", myrank, peer, signal_id, + *flag); + break; + } + } + if (lines == 0) { + *recv_id = signal_id; + return; + } // otherwise need an extra kernel + } + __syncthreads(); + + if (end_elem <= start_elem) + return; + + for (int line = start_elem; line < end_aligned; line += blockDim.x * gridDim.x * UNROLLCOPY) { + int4 val[UNROLLCOPY]; +#pragma unroll + for (int i = 0; i < UNROLLCOPY; i++) + val[i] = srcptr[line + i * blockDim.x * gridDim.x]; +#pragma unroll + for (int i = 0; i < UNROLLCOPY; i++) + dstptr[line + i * blockDim.x * gridDim.x] = val[i]; + } + for (int line = end_aligned; line < end_elem; line += blockDim.x * gridDim.x) + dstptr[line] = srcptr[line]; +} + __global__ void kuserbuffers_pullsend(int myrank, int peer, int *send_id, int *flagptr) { - atomicAdd(flagptr, 1); + atomicAdd_system(flagptr, 1); } __global__ void kuserbuffers_inc(int *id) { @@ -1514,14 +3068,17 @@ __global__ void __launch_bounds__(MAX_THREADS) } __syncthreads(); - if (end_elem <= start_elem) return; + if (end_elem <= start_elem) + return; for (int line = start_elem; line < end_aligned; line += blockDim.x * gridDim.x * UNROLLCOPY) { int4 val[UNROLLCOPY]; #pragma unroll - for (int i = 0; i < UNROLLCOPY; i++) val[i] = srcptr[line + i * blockDim.x * gridDim.x]; + for (int i = 0; i < UNROLLCOPY; i++) + val[i] = srcptr[line + i * blockDim.x * gridDim.x]; #pragma unroll - for (int i = 0; i < UNROLLCOPY; i++) dstptr[line + i * blockDim.x * gridDim.x] = val[i]; + for (int i = 0; i < UNROLLCOPY; i++) + dstptr[line + i * blockDim.x * gridDim.x] = val[i]; } for (int line = end_aligned; line < end_elem; line += blockDim.x * gridDim.x) dstptr[line] = srcptr[line]; @@ -1539,18 +3096,22 @@ __global__ void __launch_bounds__(MAX_THREADS) for (int line = start_elem; line < end_aligned; line += blockDim.x * gridDim.x * UNROLLCOPY) { int4 val[UNROLLCOPY]; #pragma unroll - for (int i = 0; i < UNROLLCOPY; i++) val[i] = srcptr[line + i * blockDim.x * gridDim.x]; + for (int i = 0; i < UNROLLCOPY; i++) + val[i] = srcptr[line + i * blockDim.x * gridDim.x]; #pragma unroll - for (int i = 0; i < UNROLLCOPY; i++) dstptr[line + i * blockDim.x * gridDim.x] = val[i]; + for (int i = 0; i < UNROLLCOPY; i++) + dstptr[line + i * blockDim.x * gridDim.x] = val[i]; } for (int line = end_aligned; line < end_elem; line += blockDim.x * gridDim.x) dstptr[line] = srcptr[line]; } __syncthreads(); - if (threadIdx.x) return; + if (threadIdx.x) + return; __threadfence_system(); - atomicAdd(flagptr, 1); // otherwise need local SM sync before sending flag - } else { // 0 bytes and 1 SM only + atomicAdd_system(flagptr, + 1); // otherwise need local SM sync before sending flag + } else { // 0 bytes and 1 SM only atomicAdd_system(flagptr, 1); } } @@ -1559,7 +3120,8 @@ __global__ void kuserbuffers_pushrecv(int myrank, int peer, int *recv_id, int *f const int signal_id = 
(*recv_id) + adder; *recv_id = signal_id; volatile int *flag = (volatile int *)flagptr; - if (*flag >= signal_id) return; + if (*flag >= signal_id) + return; clock_t s = clock64(); while (atomicAdd_system(flagptr, 0) < signal_id) { if (clock64() - s > TIMEOUT) { @@ -1569,13 +3131,203 @@ __global__ void kuserbuffers_pushrecv(int myrank, int peer, int *recv_id, int *f } } -#define CUDACHECK(cmd) \ - do { \ - cudaError_t e = cmd; \ - if (e != cudaSuccess) { \ - printf("Failed: Cuda error %s:%d '%s'\n", __FILE__, __LINE__, cudaGetErrorString(e)); \ - exit(EXIT_FAILURE); \ - } \ +__global__ void __launch_bounds__(MAX_THREADS) + kuserbuffers_pushsendrecv(int *send_id, int *send_flagptr, int4 *srcptr, int4 *dstptr, + const int lines, int myrank, int peer, int *recv_id, + int *recv_flagptr, int adder) { + if (lines) { + const int start_elem = threadIdx.x + blockDim.x * blockIdx.x; + const int end_elem = lines; + const int aligned_elem = + ((end_elem - start_elem) & (~(blockDim.x * gridDim.x * UNROLLCOPY - 1))); + const int end_aligned = start_elem + aligned_elem; + if (end_elem > start_elem) { + for (int line = start_elem; line < end_aligned; line += blockDim.x * gridDim.x * UNROLLCOPY) { + int4 val[UNROLLCOPY]; +#pragma unroll + for (int i = 0; i < UNROLLCOPY; i++) { + val[i] = srcptr[line + i * blockDim.x * gridDim.x]; + } +#pragma unroll + for (int i = 0; i < UNROLLCOPY; i++) { + dstptr[line + i * blockDim.x * gridDim.x] = val[i]; + } + } + for (int line = end_aligned; line < end_elem; line += blockDim.x * gridDim.x) { + dstptr[line] = srcptr[line]; + } + } + __syncthreads(); + if (threadIdx.x) + return; + __threadfence_system(); + atomicAdd_system(send_flagptr, + 1); // otherwise need local SM sync before sending flag + } else { // 0 bytes and 1 SM only + atomicAdd_system(send_flagptr, 1); + } + + if (blockIdx.x == 0 && threadIdx.x == 0) { + const int signal_id = (*recv_id) + adder; + *recv_id = signal_id; + volatile int *flag = (volatile int *)recv_flagptr; + if (*flag >= signal_id) + return; + clock_t s = clock64(); + while (*flag < signal_id) { + if (clock64() - s > TIMEOUT) { + printf("%d from %d] pushrecv: expected %d, stuck with %d\n", myrank, peer, signal_id, + *flag); + return; + } + } + } +} + +__global__ void __launch_bounds__(MAX_THREADS) + kuserbuffers_pushsendrecv_atomic(int *send_id, int *send_flagptr, int4 *srcptr, int4 *dstptr, + const int lines, int myrank, int peer, int *recv_id, + int *recv_flagptr, int adder, void *counters) { + if (lines) { + const int start_elem = threadIdx.x + blockDim.x * blockIdx.x; + const int end_elem = lines; + const int aligned_elem = + ((end_elem - start_elem) & (~(blockDim.x * gridDim.x * UNROLLCOPY - 1))); + const int end_aligned = start_elem + aligned_elem; + if (end_elem > start_elem) { + for (int line = start_elem; line < end_aligned; line += blockDim.x * gridDim.x * UNROLLCOPY) { + int4 val[UNROLLCOPY]; +#pragma unroll + for (int i = 0; i < UNROLLCOPY; i++) { + val[i] = srcptr[line + i * blockDim.x * gridDim.x]; + } +#pragma unroll + for (int i = 0; i < UNROLLCOPY; i++) { + dstptr[line + i * blockDim.x * gridDim.x] = val[i]; + } + } + for (int line = end_aligned; line < end_elem; line += blockDim.x * gridDim.x) { + dstptr[line] = srcptr[line]; + } + } + __syncthreads(); + if (threadIdx.x) + return; + __threadfence_system(); + atomicAdd_system(send_flagptr, + 1); // otherwise need local SM sync before sending flag + } else { // 0 bytes and 1 SM only + atomicAdd_system(send_flagptr, 1); + } + + if (blockIdx.x == 0 && threadIdx.x == 0) { + 
const int signal_id = (*recv_id) + adder; + *recv_id = signal_id; + volatile int *flag = (volatile int *)recv_flagptr; + // if(*flag>=signal_id) return; + clock_t s = clock64(); + while (*flag < signal_id) { + if (clock64() - s > TIMEOUT) { + printf("%d from %d] pushrecv: expected %d, stuck with %d\n", myrank, peer, signal_id, + *flag); /*return;*/ + } + } + + // Decrement atomic val to signal current output tile finish + if (counters) { + ((unsigned int *)counters)[0] = 0; + asm volatile("fence.sc.gpu;\n"); + } + } +} + +__global__ void __launch_bounds__(MAX_THREADS) + kuserbuffers_pushsendrecv_multiatomic(int *send_id, int *send_flagptr, int4 *srcptr, + int4 *dstptr, const int lines, int myrank, int peer, + int *recv_id, int *recv_flagptr, int adder, + void *counters, int nchunks, int send_stride, + int recv_stride, bool shuffle) { + for (int chunk_i = 0; chunk_i < nchunks - 1; chunk_i++) { + int send_chunk_id = shuffle ? chunk_i : (nchunks + myrank - chunk_i) % nchunks; + int recv_chunk_id = shuffle ? chunk_i + 1 : (nchunks + myrank - chunk_i - 1) % nchunks; + int send_offset = (send_chunk_id * send_stride) / 16; + int recv_offset = ((shuffle ? recv_chunk_id : send_chunk_id) * recv_stride) / 16; + + if (lines) { + const int start_elem = threadIdx.x + blockDim.x * blockIdx.x; + const int end_elem = lines; + const int aligned_elem = + ((end_elem - start_elem) & (~(blockDim.x * gridDim.x * UNROLLCOPY - 1))); + const int end_aligned = start_elem + aligned_elem; + if (end_elem > start_elem) { + for (int line = start_elem; line < end_aligned; + line += blockDim.x * gridDim.x * UNROLLCOPY) { + int4 val[UNROLLCOPY]; +#pragma unroll + for (int i = 0; i < UNROLLCOPY; i++) { + val[i] = srcptr[send_offset + line + i * blockDim.x * gridDim.x]; + } +#pragma unroll + for (int i = 0; i < UNROLLCOPY; i++) { + dstptr[recv_offset + line + i * blockDim.x * gridDim.x] = val[i]; + } + } + for (int line = end_aligned; line < end_elem; line += blockDim.x * gridDim.x) { + dstptr[recv_offset + line] = srcptr[send_offset + line]; + } + } + __syncthreads(); + if (!threadIdx.x) { + __threadfence_system(); + atomicAdd_system(send_flagptr, + 1); // otherwise need local SM sync before sending flag + } + } else { // 0 bytes and 1 SM only + atomicAdd_system(send_flagptr, 1); + } + + // wait for message to arrive. + if (blockIdx.x == 0 && threadIdx.x == 0) { + const int signal_id = (*recv_id) + adder; + *recv_id = signal_id; + volatile int *flag = (volatile int *)recv_flagptr; + // if(*flag>=signal_id) return; + clock_t s = clock64(); + while (*flag < signal_id) { + if (clock64() - s > TIMEOUT) { + printf("%d from %d] pushrecv: expected %d, stuck with %d\n", myrank, peer, signal_id, + *flag); /*return;*/ + } + } + } + + // Producer must update counters. + if (blockIdx.x == 0 && threadIdx.x == 0) { + // Decrement atomic val to signal current output tile finish + if (counters) { + ((unsigned int *)counters)[recv_chunk_id /*chunk_i+1*/] = 0; + asm volatile("fence.sc.gpu;\n"); + } + } + + // sync all CTAs before moving to next chunk. 
+ if (threadIdx.x == 0) { + int old_val2; + atomicInc(((unsigned int *)counters) + nchunks + chunk_i, gridDim.x - 1); + while (0 != (old_val2 = atomicCAS(((unsigned int *)counters) + nchunks + chunk_i, 0, 0))) { + } + } + __syncthreads(); + } +} + +#define CUDACHECK(cmd) \ + do { \ + cudaError_t e = cmd; \ + if (e != cudaSuccess) { \ + printf("Failed: Cuda error %s:%d '%s'\n", __FILE__, __LINE__, cudaGetErrorString(e)); \ + exit(EXIT_FAILURE); \ + } \ } while (0) #define INTRANODE(peer) ((peer / comm->nvsize) == (comm->myrank / comm->nvsize)) @@ -1611,7 +3363,8 @@ void userbuffers_send(const int srchandler, const size_t srcoffset, const int ds comm->hostflags + userbuffers_sendop); return; } - if (!(comm->launch_mode & NVTE_LAUNCH_GPU)) return; + if (!(comm->launch_mode & NVTE_LAUNCH_GPU)) + return; if (comm->push == 0) { kuserbuffers_pullsend<<<1, 1, 0, stream>>>(comm->myrank, peer, &(comm->send_id[peer]), reinterpret_cast(flagptr)); @@ -1633,10 +3386,145 @@ void userbuffers_send(const int srchandler, const size_t srcoffset, const int ds } } +void userbuffers_sendrecv(const int srchandler, const int dsthandler, const size_t send_offset, + const size_t recv_offset, const size_t bytes, communicator *comm, + const int send_peer, const int recv_peer, cudaStream_t stream) { + bool signalonly = (bytes / 16 == 0) || (comm->use_ce != 0); + int send_peerlocal = send_peer % comm->nvsize; + int recv_peerlocal = recv_peer % comm->nvsize; + void *flagptr_send = + (comm->peer_ptr[0][send_peerlocal]) + + ((NVTE_REG0_OFFSET(comm) + NVTE_REG0_RECV + comm->myrank * NVTE_MAX_REGIONS + dsthandler) * + sizeof(int)); + void *flagptr_recv = + (comm->mem_ptr[0]) + + ((NVTE_REG0_OFFSET(comm) + NVTE_REG0_RECV + recv_peer * NVTE_MAX_REGIONS + dsthandler) * + sizeof(int)); + + void *send_srcptr = (comm->mem_ptr[srchandler]) + send_offset; + void *send_dstptr = (comm->peer_ptr[dsthandler][send_peerlocal]) + send_offset; + if (comm->use_ce) + CUDACHECK(cudaMemcpyAsync(send_dstptr, send_srcptr, bytes, cudaMemcpyDeviceToDevice, stream)); + SETUP_LAUNCH_CONFIG(signalonly ? 1 : comm->sms, signalonly ? 1 : 1024, stream); + + int *arg1 = &comm->send_id[send_peer]; + int *arg2 = reinterpret_cast(flagptr_send); + int4 *arg3 = reinterpret_cast(send_srcptr); + int4 *arg4 = reinterpret_cast(send_dstptr); + int arg5 = signalonly ? 0 : bytes / 16; + int arg6 = comm->myrank; + int arg7 = recv_peer; + int *arg8 = &comm->recv_id[recv_peer * NVTE_MAX_REGIONS + dsthandler]; + int *arg9 = reinterpret_cast(flagptr_recv); + int arg10 = signalonly ? 
1 : comm->sms; + void *kernelArgs[] = {reinterpret_cast(&arg1), reinterpret_cast(&arg2), + reinterpret_cast(&arg3), reinterpret_cast(&arg4), + reinterpret_cast(&arg5), reinterpret_cast(&arg6), + reinterpret_cast(&arg7), reinterpret_cast(&arg8), + reinterpret_cast(&arg9), reinterpret_cast(&arg10)}; + CUDACHECK( + cudaLaunchKernelExC(&cfg, reinterpret_cast(kuserbuffers_pushsendrecv), kernelArgs)); + //} +} + +void userbuffers_sendrecv_atomic(const int srchandler, const int dsthandler, + const size_t send_offset, const size_t recv_offset, + const size_t bytes, communicator *comm, const int send_peer, + const int recv_peer, void *counters, cudaStream_t stream) { + assert(comm->push && comm->use_ce == 0); + bool signalonly = (bytes / 16 == 0) || (comm->use_ce != 0); + + int send_peerlocal = send_peer % comm->nvsize; + int recv_peerlocal = recv_peer % comm->nvsize; + void *flagptr_send = + (comm->peer_ptr[0][send_peerlocal]) + + ((NVTE_REG0_OFFSET(comm) + NVTE_REG0_RECV + comm->myrank * NVTE_MAX_REGIONS + dsthandler) * + sizeof(int)); + void *flagptr_recv = + (comm->mem_ptr[0]) + + ((NVTE_REG0_OFFSET(comm) + NVTE_REG0_RECV + recv_peer * NVTE_MAX_REGIONS + dsthandler) * + sizeof(int)); + + void *send_srcptr = (comm->mem_ptr[srchandler]) + send_offset; + void *send_dstptr = (comm->peer_ptr[dsthandler][send_peerlocal]) + send_offset; + if (comm->use_ce) { + CUDACHECK(cudaMemcpyAsync(send_dstptr, send_srcptr, bytes, cudaMemcpyDeviceToDevice, stream)); + } + SETUP_LAUNCH_CONFIG(signalonly ? 1 : comm->sms, signalonly ? 1 : 1024, stream); + + int *arg1 = &comm->send_id[send_peer]; + int *arg2 = reinterpret_cast(flagptr_send); + int4 *arg3 = reinterpret_cast(send_srcptr); + int4 *arg4 = reinterpret_cast(send_dstptr); + int arg5 = signalonly ? 0 : bytes / 16; + int arg6 = comm->myrank; + int arg7 = recv_peer; + int *arg8 = &comm->recv_id[recv_peer * NVTE_MAX_REGIONS + dsthandler]; + int *arg9 = reinterpret_cast(flagptr_recv); + int arg10 = signalonly ? 
1 : comm->sms; + void *arg11 = counters; + void *kernelArgs[] = {reinterpret_cast(&arg1), reinterpret_cast(&arg2), + reinterpret_cast(&arg3), reinterpret_cast(&arg4), + reinterpret_cast(&arg5), reinterpret_cast(&arg6), + reinterpret_cast(&arg7), reinterpret_cast(&arg8), + reinterpret_cast(&arg9), reinterpret_cast(&arg10), + reinterpret_cast(&arg11)}; + CUDACHECK(cudaLaunchKernelExC(&cfg, reinterpret_cast(kuserbuffers_pushsendrecv_atomic), + kernelArgs)); +} + +void userbuffers_sendrecv_multiatomic(const int srchandler, const int dsthandler, + const size_t send_stride, const size_t recv_stride, + const size_t bytes, communicator *comm, const int send_peer, + const int recv_peer, const int nchunks, void *counters, + bool shuffle, cudaStream_t stream) { + assert(comm->push && comm->use_ce == 0); + + int send_peerlocal = send_peer % comm->nvsize; + int recv_peerlocal = recv_peer % comm->nvsize; + void *flagptr_send = + (comm->peer_ptr[0][send_peerlocal]) + + ((NVTE_REG0_OFFSET(comm) + NVTE_REG0_RECV + comm->myrank * NVTE_MAX_REGIONS + dsthandler) * + sizeof(int)); + void *flagptr_recv = + (comm->mem_ptr[0]) + + ((NVTE_REG0_OFFSET(comm) + NVTE_REG0_RECV + recv_peer * NVTE_MAX_REGIONS + dsthandler) * + sizeof(int)); + + SETUP_LAUNCH_CONFIG(comm->sms, 1024, stream); + + int *arg1 = &comm->send_id[send_peer]; + int *arg2 = reinterpret_cast(flagptr_send); + int4 *arg3 = reinterpret_cast((comm->mem_ptr[srchandler])); + int4 *arg4 = reinterpret_cast((comm->peer_ptr[dsthandler][send_peerlocal])); + int arg5 = bytes / 16; + int arg6 = comm->myrank; + int arg7 = recv_peer; + int *arg8 = &comm->recv_id[recv_peer * NVTE_MAX_REGIONS + dsthandler]; + int *arg9 = reinterpret_cast(flagptr_recv); + int arg10 = comm->sms; + void *arg11 = counters; + int arg12 = nchunks; + int arg13 = send_stride; + int arg14 = recv_stride; + bool arg15 = shuffle; + void *kernelArgs[] = {reinterpret_cast(&arg1), reinterpret_cast(&arg2), + reinterpret_cast(&arg3), reinterpret_cast(&arg4), + reinterpret_cast(&arg5), reinterpret_cast(&arg6), + reinterpret_cast(&arg7), reinterpret_cast(&arg8), + reinterpret_cast(&arg9), reinterpret_cast(&arg10), + reinterpret_cast(&arg11), reinterpret_cast(&arg12), + reinterpret_cast(&arg13), reinterpret_cast(&arg14), + reinterpret_cast(&arg15)}; + CUDACHECK(cudaLaunchKernelExC( + &cfg, reinterpret_cast(kuserbuffers_pushsendrecv_multiatomic), kernelArgs)); +} + __global__ void __launch_bounds__(MAX_THREADS) kuserbuffers_alltoall(void **baseflagptrs, int flagoffset, int4 *basesrcptr, void **dstptrs, size_t dstoffset, const int lines, const int myrank) { - if (blockIdx.x == myrank) return; + if (blockIdx.x == myrank) + return; int4 *dstptr = reinterpret_cast(dstptrs[blockIdx.x] + dstoffset); int *flagptr = reinterpret_cast(baseflagptrs[blockIdx.x] + flagoffset); const size_t myblockoffset = blockIdx.x * lines; @@ -1652,14 +3540,18 @@ __global__ void __launch_bounds__(MAX_THREADS) for (int line = start_elem; line < end_aligned; line += blockDim.x * UNROLLCOPY) { int4 val[UNROLLCOPY]; #pragma unroll - for (int i = 0; i < UNROLLCOPY; i++) val[i] = srcptr[line + i * blockDim.x]; + for (int i = 0; i < UNROLLCOPY; i++) + val[i] = srcptr[line + i * blockDim.x]; #pragma unroll - for (int i = 0; i < UNROLLCOPY; i++) dstptr[line + i * blockDim.x] = val[i]; + for (int i = 0; i < UNROLLCOPY; i++) + dstptr[line + i * blockDim.x] = val[i]; } - for (int line = end_aligned; line < end_elem; line += blockDim.x) dstptr[line] = srcptr[line]; + for (int line = end_aligned; line < end_elem; line += blockDim.x) + 
dstptr[line] = srcptr[line]; } __syncthreads(); - if (threadIdx.x) return; + if (threadIdx.x) + return; __threadfence_system(); atomicAdd(flagptr, 1); @@ -1702,7 +3594,8 @@ void userbuffers_recv(const int srchandler, const size_t srcoffset, const int ds sizeof(int)); bool signalonly = (bytes / 16 == 0) || (comm->use_ce != 0); bool intranode = INTRANODE(peer); - if (!(comm->launch_mode & NVTE_LAUNCH_GPU)) return; + if (!(comm->launch_mode & NVTE_LAUNCH_GPU)) + return; if (comm->push == 0 && intranode) { void *dstptr = (comm->mem_ptr[dsthandler]) + dstoffset; void *srcptr = (comm->peer_ptr[srchandler][peerlocal]) + srcoffset; @@ -1728,7 +3621,45 @@ void userbuffers_alltoall_recv(communicator *comm, cudaStream_t stream) { (comm->mem_ptr[0]) + ((NVTE_REG0_OFFSET(comm) + NVTE_REG0_OPFLAGS * userbuffers_alltoall) * sizeof(int)); - if (!(comm->launch_mode & NVTE_LAUNCH_GPU)) return; + if (!(comm->launch_mode & NVTE_LAUNCH_GPU)) + return; kuserbuffers_pushrecv<<<1, 1, 0, stream>>>(comm->myrank, -1, reinterpret_cast(flagptr + 4), reinterpret_cast(flagptr), comm->nranks - 1); } + +// producer +static __global__ void producer_kernel(void *atomic_ptr, int chunk_i) { + // Decrement atomic val to signal current output tile finish + if (blockIdx.x == 0 && threadIdx.x == 0) { + ((unsigned int *)atomic_ptr)[chunk_i] = 0; + } + + // COMM kernel need to explicitely flash gmem. + // GEMM kernel already executed, and can not see gmem + // change without COMM kernel explicitely make change + asm volatile("fence.sc.gpu;\n"); +} + +// consumer +static __global__ void consumer_kernel(void *atomic_ptr, int chunk_i) { + // Wait for producer to change the val to 0, which signal producer ready + if (blockIdx.x == 0 && threadIdx.x == 0) { + int old_val; + while (0 != (old_val = atomicCAS((unsigned int *)atomic_ptr + chunk_i, 0, 0))) { + } + ((unsigned int *)atomic_ptr)[chunk_i] = 1; + asm volatile("fence.sc.gpu;\n"); + } +} + +void producer(void *atomic_ptr, int chunk_i, cudaStream_t stream) { + dim3 block(1); + dim3 grid(1); + producer_kernel<<>>(atomic_ptr, chunk_i); +} + +void consumer(void *atomic_ptr, int chunk_i, cudaStream_t stream) { + dim3 block(1); + dim3 grid(1); + consumer_kernel<<>>(atomic_ptr, chunk_i); +} diff --git a/transformer_engine/pytorch/csrc/userbuffers/userbuffers.h b/transformer_engine/pytorch/csrc/userbuffers/userbuffers.h index d6ec23c40d..7f635771c9 100644 --- a/transformer_engine/pytorch/csrc/userbuffers/userbuffers.h +++ b/transformer_engine/pytorch/csrc/userbuffers/userbuffers.h @@ -24,6 +24,18 @@ #define NVTE_LAUNCH_CPU 2 #define NVTE_MAX_NVLINK 8 +#define UB_MEM_UC_CONTIG 1 +#define UB_MEM_MC_CREATED 2 +#define UB_MEM_ALLOCATED 4 + +#define NVTE_UB_MEM_UC_CONTIG 1 +#define NVTE_UB_MEM_MC_CREATED 2 +#define NVTE_UB_MEM_ALLOCATED 4 + +#ifdef UCP +#include +#endif + // region 0 flag offsets #define NVTE_REG0_OPFLAGS 1024 #define NVTE_REG0_RECV (NVTE_REG0_OPFLAGS * userbuffers_op_types) @@ -35,6 +47,10 @@ #define NVTE_REG0_IBRS 32 #define NVTE_REG0_IBAG 512 +#if defined(UCP) || !defined(NOSHARP) +#undef REG0_COMMBUFFER +#define REG0_COMMBUFFER (1024*1024*16) +#endif // gpuflags map offsets #define NVTE_GF_STATE 16000 #define NVTE_GF_IBSHARPDONE 0 @@ -81,6 +97,19 @@ struct communicator { void *mem_ptr[NVTE_MAX_REGIONS]; void **peer_ptr[NVTE_MAX_REGIONS]; + + int memflags[NVTE_MAX_REGIONS]; // UC,MC, user/lib allocated + + CUmemGenericAllocationHandle *uchandles[NVTE_MAX_REGIONS]; + void* ucbase_ptr[NVTE_MAX_REGIONS]; // only for cuMem allocated memory + size_t mem_size[NVTE_MAX_REGIONS]; + + 
void* mc_ptr[NVTE_MAX_REGIONS]; + void* mc_baseptr; + CUmemGenericAllocationHandle mc_handle; + size_t mc_offset, mc_maxsize; + int use_mc; // 1: use MC if available, 0: override not to use MC + int ar_nvsize, ar_firstgpu, ar_nvrank; // number of gpus(and first gpu in a group) of gpus per node in reduction subgroup // (_splitar init used) would be equal to (nvsize,0) for regular comm_create @@ -120,6 +149,8 @@ struct communicator { }; typedef struct communicator communicator; +void producer(void *atomic_ptr, int chunk_i, cudaStream_t stream); +void consumer(void *atomic_ptr, int chunk_i, cudaStream_t stream); int create_communicator(communicator **comm); /* creates communicator, allocates all internal buffers if necessary */ @@ -191,6 +222,45 @@ void reducescatter2_userbuff_stridedoutput(void *output, const int handler, cons const int rowelements, const int colelements, const int strideelements, communicator *comm, cudaStream_t stream = 0); +template +void reducescatter2_userbuff_stridedoutput_fp8(void* output, float* scale, const int handler, + const int offset, const int rowelements, + const int colelements, const int strideelements, + communicator* comm, cudaStream_t stream = 0); +template +void reducescatter2_userbuff_fp8(void* output, float* scale, const int handler, const int offset, + const int elements, communicator* comm, cudaStream_t stream = 0); +#if 0 +template +void reducescatter2_userbuff_strided_atomic_fp8(void* output, float *scale, const int handler, + const int offset, const int rowelements, + const int colelements, const int strideelements, + const int numchunks, void *counters, + communicator* comm, cudaStream_t stream = 0); +#endif +template +void reducescatter2_userbuff_strided_atomic_fp8(void* output, float *scale, const int handler, + const int offset, const int rowelements, + const int colelements, const int strideelements_out, + const int strideelements_in, const int numchunks, + void *counters, communicator* comm, + cudaStream_t stream = 0); +template +void reducescatter2_userbuff_strided_multiatomic_fp8( + void* output, float *scale, const int handler, const int offset, const int rowelements, + const int colelements, const int strideelements_out, const int strideelements_in, + const int numchunks, void *counters, communicator* comm, cudaStream_t stream = 0); +void reducescatter2_userbuff_strided( + void* output, const int handler, const int offset, const int rowelements, const int colelements, + const int strideelements, communicator* comm, cudaStream_t stream = 0); +void reducescatter2_userbuff_strided_atomic( + void* output, const int handler , const int offset, const int rowelements, const int colelements, + const int strideelements, const int numchunks, void *counters, communicator* comm, + cudaStream_t stream = 0); +void reducescatter2_userbuff_strided_multiatomic( + void* output, const int handler, const int offset, const int rowelements, const int colelements, + const int strideelements, const int numchunks, void *counters, communicator* comm, + cudaStream_t stream = 0); /* everything should be 16byte aligned = 8 elts aligned output is strided: row starts separated by stride elements*/ @@ -208,6 +278,19 @@ void userbuffers_send(const int srchandler, const size_t srcoffset, const int ds void userbuffers_recv(const int srchandler, const size_t srcoffset, const int dsthandler, const size_t dstoffset, const size_t bytes, communicator *comm, const int peer, cudaStream_t stream = 0); +void userbuffers_sendrecv( + const int srchandler, const int dsthandler, 
const size_t send_offset, const size_t recv_offset, + const size_t bytes, communicator* comm, const int send_peer, const int recv_peer, + cudaStream_t stream = 0); +void userbuffers_sendrecv_atomic( + const int srchandler, const int dsthandler, const size_t send_offset, const size_t recv_offset, + const size_t bytes, communicator* comm, const int send_peer, const int recv_peer, void *counters, + cudaStream_t stream = 0); +void userbuffers_sendrecv_multiatomic( + const int srchandler, const int dsthandler, const size_t send_offset, const size_t recv_offset, + const size_t bytes, communicator* comm, const int send_peer, const int recv_peer, + const int nchunks, void *counters, bool shuffle, cudaStream_t stream = 0); + // alltoall split send and recv to allow for overlap // send kicks in sending data to the destination - invoke on same stream as data generation diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 8bb9d55f38..7076e59600 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -124,6 +124,8 @@ def initialize_ub( fp8_buf = [ "qkv_fprop", "qkv_dgrad", "proj_dgrad", "fc1_fprop", "fc1_dgrad", "fc2_dgrad" ] + if bool(int(os.getenv("NVTE_UB_FP8_RS", "0"))): + fp8_buf.append ("proj_fprop") # Default overlap methods for layers methods = { "ring_exchange":["qkv_fprop", "fc1_fprop", "proj_dgrad", "fc2_dgrad"], @@ -153,8 +155,12 @@ def add_ub( sample_buffer, # Sample userbuffer rank_id, # Rank id tp_size, # TP size + num_sm, # Number of communication SMs + cga_size, # CGA cluster size + set_sm_margin, # Set SM margin aggregate, # Aggregate 2X GEMM chunks _NUM_MAX_UB_STREAMS, # Max concurrent GEMM streams + torch.Tensor(), # empty tensor to pass to counters ) else: ub_obj = tex.UbufCommOverlap( @@ -166,6 +172,7 @@ def add_ub( num_splits, # Number of communication splits set_sm_margin, # Set SM margin _NUM_MAX_UB_STREAMS, # Max concurrent GEMM streams + torch.Tensor(), # empty tensor to pass to counters ) _ub_communicators[name] = ub_obj @@ -676,10 +683,12 @@ def grad_output_preprocess( grad_output_mat = grad_output.view((-1, grad_output.shape[-1])) gather_grad_output = row_parallel_mode and ctx.sequence_parallel + if gather_grad_output: + ub_overlap_ag = ctx.ub_split_ag or ctx.ub_atomic_gemm_ag # No-FP8 case: bgrad is fused with wgrad for this case. 
if not ctx.fp8: if gather_grad_output: - if not ctx.ub_split_ag: + if not ub_overlap_ag: grad_output_mat, _ = gather_along_first_dim( grad_output_mat, ctx.tp_group ) @@ -698,8 +707,8 @@ def grad_output_preprocess( and ctx.fp8_meta["recipe"].override_linear_precision.wgrad ): assert ( - not ctx.ub_split_ag - ), "override_linear_precision.wgrad not supported with ub_split_ag" + not ub_overlap_ag + ), "override_linear_precision.wgrad not supported with UB AG overlap" grad_output_mat, _ = gather_along_first_dim(grad_output_mat, ctx.tp_group) # FP8 case with gather: unfused bgrad, cast, transpose for efficient gather elif gather_grad_output: @@ -707,7 +716,7 @@ def grad_output_preprocess( grad_bias = grad_output_mat.sum(dim=0) else: grad_bias = None - if ctx.ub_split_ag: + if ub_overlap_ag: grad_output_c = ctx.ub_obj_gradout.get_ubuf_output(0) else: grad_output_c = torch.empty_like(grad_output_mat, dtype=torch.uint8) @@ -718,7 +727,7 @@ def grad_output_preprocess( fp8_dtype_backward, out=grad_output_c, ) - if not ctx.ub_split_ag: + if not ub_overlap_ag: grad_output_c, _ = gather_along_first_dim(grad_output_c, ctx.tp_group) grad_output_t = tex.fp8_transpose(grad_output_c, fp8_dtype_backward) else: diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index b7372f81fe..71af058415 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -83,6 +83,7 @@ def forward( ub_bulk_dgrad: bool, ub_split_ag: bool, normalization: str, + ub_atomic_gemm_ag: bool, ) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]: # Make sure input dimensions are compatible in_features = ln_weight.numel() @@ -100,11 +101,12 @@ def forward( if ln_bias is not None: ln_bias = cast_if_needed(ln_bias, activation_dtype) - if ub_split_ag: + if ub_split_ag or ub_atomic_gemm_ag: tp_world_size = get_distributed_world_size(tp_group) if tp_world_size == 1 or (not is_grad_enabled) or return_layernorm_output: ub_split_ag = False - if ub_split_ag: + ub_atomic_gemm_ag = False + if ub_split_ag or ub_atomic_gemm_ag: dim_size = list(inputmat.size()) dim_size[0] = dim_size[0] * tp_world_size ub_obj_lnout = get_ub("qkv_fprop") @@ -112,6 +114,8 @@ def forward( else: ln_out_dtype = torch.uint8 if fp8 else inputmat.dtype ln_out = torch.empty_like(inputmat, dtype=ln_out_dtype) + if ub_atomic_gemm_ag: + assert fp8, "AtomicGemm overlap supported only for FP8 GEMM." 
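
A minimal sketch of how the new atomic-GEMM overlap flags used above are typically selected when going through TransformerLayer (environment-variable names and the mutual-exclusion asserts come from the transformer.py hunk later in this patch; the surrounding training script and ub_tp_comm_overlap setup are assumed):

import os

# Illustrative sketch: choose the atomic-GEMM overlaps instead of the split overlaps.
# Each atomic variant is mutually exclusive with its split counterpart (asserted in
# transformer.py), requires tensor-parallel comm overlap to be enabled, and is FP8-only.
os.environ["NVTE_UB_SPLIT_AG"] = "0"        # split AG defaults to "1"; turn it off first
os.environ["NVTE_UB_ATOMIC_GEMM_AG"] = "1"  # all-gather overlap via atomic GEMM
os.environ["NVTE_UB_SPLIT_RS"] = "0"        # likewise for the reduce-scatter pair
os.environ["NVTE_UB_ATOMIC_GEMM_RS"] = "1"  # reduce-scatter overlap via atomic GEMM
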
fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) @@ -139,7 +143,7 @@ def forward( fp8_dtype_forward, ) # Column Parallel Linear - if ub_split_ag: + if ub_split_ag or ub_atomic_gemm_ag: ln_out_total = ub_obj_lnout.get_ubuf_output(1) ln_out = torch.empty_like(ln_out) elif parallel_mode == "column" and sequence_parallel: @@ -173,6 +177,8 @@ def forward( tex.FP8FwdTensors.GEMM1_WEIGHT, fp8_dtype_forward) + ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG if ub_split_ag else None + ub_algo = tex.UbufOverlapAlgo.ATOMIC_GEMM_AG if ub_atomic_gemm_ag else ub_algo out = tex.fp8_gemm( weight_fp8, fp8_meta["scaling_fwd"].scale_inv, @@ -187,9 +193,9 @@ def forward( bias=bias, use_bias=use_bias, use_split_accumulator=_2X_ACC_FPROP, - ub_algo=tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG if ub_split_ag else None, - ub=ub_obj_lnout if ub_split_ag else None, - extra_output_tensor=ln_out if ub_split_ag else None, + ub_algo=ub_algo, + ub=ub_obj_lnout if (ub_split_ag or ub_atomic_gemm_ag) else None, + extra_output_tensor=ln_out if (ub_split_ag or ub_atomic_gemm_ag) else None, ) else: # Cast for native AMP @@ -339,6 +345,14 @@ def backward( fp8_dtype_backward = get_fp8_te_dtype( ctx.fp8_meta["recipe"], fprop_tensor=False ) + out_index, meta_tensor, out_te_type, out_type = ( + None, None, None, ctx.activation_dtype) + if ctx.ub_bulk_wgrad and ub_obj_dgrad.is_fp8_ubuf(): + out_index = tex.FP8BwdTensors.GRAD_INPUT1 + meta_tensor = ctx.fp8_meta["scaling_bwd"] + out_te_type = fp8_dtype_backward + out_type = torch.uint8 + ub_obj_dgrad.set_ubuf_scale_inv(meta_tensor.scale_inv[out_index]) # DGRAD: Evaluated unconditionally to feed into Linear backward _ = tex.fp8_gemm( @@ -350,12 +364,15 @@ def backward( ctx.fp8_meta["scaling_bwd"].scale_inv, tex.FP8BwdTensors.GRAD_OUTPUT1, fp8_dtype_backward, - ctx.activation_dtype, + out_type, get_workspace(), out=dgrad, use_split_accumulator=_2X_ACC_DGRAD, ub_algo=tex.UbufOverlapAlgo.BULK_OVERLAP_AG if ctx.ub_bulk_dgrad else None, - ub=ub_obj_lnout if ctx.ub_bulk_dgrad else None + ub=ub_obj_lnout if ctx.ub_bulk_dgrad else None, + out_index=out_index, + fp8_meta_tensor = meta_tensor, + D_dtype = out_te_type, ) else: # DGRAD: Evaluated unconditionally to feed into Linear backward @@ -387,6 +404,15 @@ def backward( if weight.requires_grad: if ctx.fp8: # WGRAD + extra_output_tensor = None + if ctx.ub_bulk_wgrad: + if ub_obj_dgrad.is_fp8_ubuf(): + dim_size = list(ub_obj_dgrad.get_ubuf_output(0).size()) # RS output + extra_output_tensor = torch.empty( + dim_size, dtype=ctx.activation_dtype, device=dgrad.device) + dgrad = extra_output_tensor + else: + dgrad = ub_obj_dgrad.get_ubuf_output(0) if not ctx.fp8_meta["recipe"].override_linear_precision.wgrad: ln_out_total_t = tex.fp8_transpose(ln_out_total, fp8_dtype_forward) wgrad = tex.fp8_gemm( @@ -405,7 +431,8 @@ def backward( use_split_accumulator=_2X_ACC_WGRAD, ub_algo=tex.UbufOverlapAlgo.BULK_OVERLAP_RS if ctx.ub_bulk_wgrad else None, - ub=ub_obj_dgrad if ctx.ub_bulk_wgrad else None + ub=ub_obj_dgrad if ctx.ub_bulk_wgrad else None, + extra_output_tensor=extra_output_tensor ) else: ln_out_total_c = tex.cast_from_fp8( @@ -426,7 +453,8 @@ def backward( out=weight.main_grad if ctx.fuse_wgrad_accumulation else None, ub_algo=tex.UbufOverlapAlgo.BULK_OVERLAP_RS if ctx.ub_bulk_wgrad else None, - ub=ub_obj_dgrad if ctx.ub_bulk_wgrad else None + ub=ub_obj_dgrad if ctx.ub_bulk_wgrad else None, + extra_output_tensor=extra_output_tensor ) else: # WGRAD @@ -443,12 +471,14 @@ def backward( ub_algo=tex.UbufOverlapAlgo.BULK_OVERLAP_RS 
if ctx.ub_bulk_wgrad else None, ub=ub_obj_dgrad if ctx.ub_bulk_wgrad else None ) + if ctx.ub_bulk_wgrad: + dgrad = ub_obj_dgrad.get_ubuf_output(0) # Reduce-scatter output - - if ctx.ub_bulk_wgrad: - dgrad = ub_obj_dgrad.get_ubuf_output(0) # Reduce-scatter output # Column Parallel Linear - elif ctx.parallel_mode == "column" and ctx.tensor_parallel and handle is not None: + if ((not ctx.ub_bulk_wgrad) + and ctx.parallel_mode == "column" + and ctx.tensor_parallel + and handle is not None): handle.wait() # LayerNorm gradient @@ -504,6 +534,7 @@ def backward( None, None, None, + None, ) @@ -616,6 +647,7 @@ def __init__( ub_bulk_dgrad: bool = False, ub_split_ag: bool = False, device: Union[torch.device, str] = "cuda", + ub_atomic_gemm_ag: bool = False, ) -> None: super().__init__() @@ -642,12 +674,18 @@ def __init__( self.ub_bulk_wgrad = ub_bulk_wgrad self.ub_bulk_dgrad = ub_bulk_dgrad self.ub_split_ag = ub_split_ag + self.ub_atomic_gemm_ag = ub_atomic_gemm_ag - if ub_bulk_wgrad or ub_bulk_dgrad or ub_split_ag: + if ub_bulk_wgrad or ub_bulk_dgrad or ub_split_ag or ub_atomic_gemm_ag: assert ( tex.userbuf_comm_available() ), "Userbuffer communication backend not available." + if ub_atomic_gemm_ag: + warnings.warn( + "Atomic gemm uses a beta API from cublas and is not tested for all use cases." + ) + if tp_group is None: self.tp_size = tp_size if tp_size == 1: @@ -909,6 +947,7 @@ def forward( self.ub_bulk_dgrad, self.ub_split_ag, self.normalization, + self.ub_atomic_gemm_ag, ) out = fwd_fn(*args) diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index ea9f7b5b2b..2daf73f11c 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -4,6 +4,7 @@ """LayerNormMLP API""" import os +import warnings from typing import Union, Optional, Callable, Tuple, List, Dict, Any import torch @@ -107,7 +108,9 @@ def forward( ub_bulk_wgrad: bool, ub_bulk_dgrad: bool, ub_split_rs: bool, + ub_atomic_gemm_rs: bool, ub_split_ag: bool, + ub_atomic_gemm_ag: bool, activation: str, normalization: str, ) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]: @@ -130,20 +133,25 @@ def forward( if ln_bias is not None: ln_bias = cast_if_needed(ln_bias, activation_dtype) - if ub_split_ag: + if ub_split_ag or ub_atomic_gemm_ag: tp_world_size = get_distributed_world_size(tp_group) if tp_world_size == 1 or (not is_grad_enabled) or return_layernorm_output: ub_split_ag = False - if ub_split_ag: + ub_atomic_gemm_ag = False + ub_overlap_ag = ub_split_ag or ub_atomic_gemm_ag + if ub_overlap_ag: ub_obj_lnout = get_ub("fc1_fprop") ln_out = ub_obj_lnout.get_ubuf_output(0) else: ln_out_dtype = torch.uint8 if (fp8 and not return_layernorm_output) else inputmat.dtype ln_out = torch.empty_like(inputmat, dtype=ln_out_dtype) - if ub_split_rs: + if ub_split_rs or ub_atomic_gemm_rs: tp_world_size = get_distributed_world_size(tp_group) if tp_world_size == 1: ub_split_rs = False + ub_atomic_gemm_rs = False + if ub_atomic_gemm_rs or ub_atomic_gemm_ag: + assert fp8, "AtomicGemm overlap supported only for FP8 GEMM." 
fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) @@ -171,7 +179,7 @@ def forward( fp8_dtype_forward, ) # Column Parallel Linear - if ub_split_ag: + if ub_overlap_ag: ln_out_total = ub_obj_lnout.get_ubuf_output(1) ln_out = torch.empty_like(ln_out) elif set_parallel_mode and sequence_parallel: @@ -223,6 +231,8 @@ def forward( fp8_dtype_forward, ) + ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG if ub_split_ag else None + ub_algo = tex.UbufOverlapAlgo.ATOMIC_GEMM_AG if ub_atomic_gemm_ag else ub_algo fc1_out = tex.fp8_gemm( fc1_weight_fp8, fp8_meta["scaling_fwd"].scale_inv, @@ -237,9 +247,9 @@ def forward( bias=fc1_bias, use_bias=use_fc1_bias, use_split_accumulator=_2X_ACC_FPROP, - ub_algo=tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG if ub_split_ag else None, - ub=ub_obj_lnout if ub_split_ag else None, - extra_output_tensor=ln_out if ub_split_ag else None, + ub_algo=ub_algo, + ub=ub_obj_lnout if ub_overlap_ag else None, + extra_output_tensor=ln_out if ub_overlap_ag else None, ) gelu_out = activation_func( @@ -249,18 +259,29 @@ def forward( fp8_dtype_forward, ) - if ub_split_rs: + fc2_out_index, fc2_meta_tensor, fc2_te_type, out_type = ( + None, None, None, activation_dtype) + if ub_split_rs or ub_atomic_gemm_rs: ub_obj_fc2out = get_ub("fc2_fprop") fc2_out = ub_obj_fc2out.get_ubuf_output(1) dim_size = list(gelu_out.size()) dim_size[0] = dim_size[0] // tp_world_size dim_size[1] = fc2_weight.size(0) rs_out = torch.empty(dim_size, dtype=activation_dtype, device=gelu_out.device) + + if ub_obj_fc2out.is_fp8_ubuf(): + fc2_out_index = tex.FP8FwdTensors.GEMM2_OUTPUT + fc2_meta_tensor = fp8_meta["scaling_fwd"] + fc2_te_type = fp8_dtype_forward + out_type = torch.uint8 + ub_obj_fc2out.set_ubuf_scale_inv(fc2_meta_tensor.scale_inv[fc2_out_index]) else: dim_size = list(gelu_out.size()) dim_size[1] = fc2_weight.size(0) fc2_out = torch.empty(dim_size, dtype=activation_dtype, device=gelu_out.device) + ub_algo=tex.UbufOverlapAlgo.ATOMIC_GEMM_RS if ub_atomic_gemm_rs else None + ub_algo=tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS if ub_split_rs else ub_algo _ = tex.fp8_gemm( fc2_weight_fp8, fp8_meta["scaling_fwd"].scale_inv, @@ -270,15 +291,18 @@ def forward( fp8_meta["scaling_fwd"].scale_inv, tex.FP8FwdTensors.GEMM2_INPUT, fp8_dtype_forward, - activation_dtype, + out_type, get_workspace(), bias=fc2_bias, use_bias=use_fc2_bias, use_split_accumulator=_2X_ACC_FPROP, out=fc2_out, - ub_algo=tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS if ub_split_rs else None, - ub=ub_obj_fc2out if ub_split_rs else None, - extra_output_tensor=rs_out if ub_split_rs else None, + ub_algo=ub_algo, + ub=ub_obj_fc2out if ub_split_rs or ub_atomic_gemm_rs else None, + extra_output_tensor=rs_out if ub_split_rs or ub_atomic_gemm_rs else None, + out_index=fc2_out_index, + fp8_meta_tensor = fc2_meta_tensor, + D_dtype = fc2_te_type, ) else: # Cast for native AMP @@ -394,11 +418,12 @@ def forward( ctx.ub_bulk_wgrad = ub_bulk_wgrad ctx.ub_bulk_dgrad = ub_bulk_dgrad ctx.ub_split_ag = ub_split_ag + ctx.ub_atomic_gemm_ag = ub_atomic_gemm_ag ctx.requires_dgrad = inp.requires_grad ctx.normalization = normalization # Row Parallel Linear - if ub_split_rs: + if ub_split_rs or ub_atomic_gemm_rs: fc2_out = rs_out elif set_parallel_mode and sequence_parallel: fc2_out, _ = reduce_scatter_along_first_dim(fc2_out, tp_group) @@ -447,11 +472,15 @@ def backward( dim_size[0] = dim_size[0] * tp_world_size ub_obj_lnout = get_ub("fc1_dgrad") ub_obj_lnout.copy_input_to_ubuf(ln_out, 1) - if ctx.ub_split_ag: + ub_overlap_ag = ctx.ub_split_ag or 
ctx.ub_atomic_gemm_ag + if ub_overlap_ag: tp_world_size = get_distributed_world_size(ctx.tp_group) if tp_world_size == 1: ctx.ub_split_ag = False - if ctx.ub_split_ag: + ctx.ub_overlap_ag = False + ub_overlap_ag = ctx.ub_split_ag or ctx.ub_atomic_gemm_ag + + if ub_overlap_ag: dim_size = list(grad_outputs[0].size()) dim_size[0] = dim_size[0] * tp_world_size ctx.ub_obj_gradout = get_ub("fc2_dgrad") @@ -497,6 +526,8 @@ def backward( ctx.fp8_meta["recipe"], fprop_tensor=False ) + ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG if ctx.ub_split_ag else None + ub_algo = tex.UbufOverlapAlgo.ATOMIC_GEMM_AG if ctx.ub_atomic_gemm_ag else ub_algo # FC2 DGRAD; Unconditional fc2_dgrad = tex.fp8_gemm( fc2_weight_t_fp8, @@ -510,10 +541,10 @@ def backward( ctx.activation_dtype, get_workspace(), use_split_accumulator=_2X_ACC_DGRAD, - ub_algo=tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG if ctx.ub_split_ag else None, - ub=ctx.ub_obj_gradout if ctx.ub_split_ag else None, + ub_algo=ub_algo, + ub=ctx.ub_obj_gradout if ub_overlap_ag else None, ) - if ctx.ub_split_ag: + if ub_overlap_ag: grad_output_t = tex.fp8_transpose(grad_output_c, fp8_dtype_backward) # FC2 WGRAD if not ctx.fp8_meta["recipe"].override_linear_precision.wgrad: @@ -595,11 +626,19 @@ def backward( ) dgelu_t = None + out_index, meta_tensor, out_te_type, out_type = ( + None, None, None, ctx.activation_dtype) fc1_dgrad_size = list(dgelu.size()) fc1_dgrad_size[1] = fc1_weight.size(1) if ctx.ub_bulk_wgrad: # allocate dgrad output ub_obj_dgrad = get_ub("fc1_wgrad") fc1_dgrad = ub_obj_dgrad.get_ubuf_output(1) # AllGather output + if ub_obj_dgrad.is_fp8_ubuf(): + out_index = tex.FP8BwdTensors.GRAD_INPUT2 + meta_tensor = ctx.fp8_meta["scaling_bwd"] + out_te_type = fp8_dtype_backward + out_type = torch.uint8 + ub_obj_dgrad.set_ubuf_scale_inv(meta_tensor.scale_inv[out_index]) else: fc1_dgrad = torch.empty( fc1_dgrad_size, dtype=ctx.activation_dtype, device=fc1_weight.device @@ -614,12 +653,15 @@ def backward( ctx.fp8_meta["scaling_bwd"].scale_inv, tex.FP8BwdTensors.GRAD_OUTPUT2, fp8_dtype_backward, - ctx.activation_dtype, + out_type, get_workspace(), out=fc1_dgrad, use_split_accumulator=_2X_ACC_DGRAD, ub_algo=tex.UbufOverlapAlgo.BULK_OVERLAP_AG if ctx.ub_bulk_dgrad else None, - ub=ub_obj_lnout if ctx.ub_bulk_dgrad else None + ub=ub_obj_lnout if ctx.ub_bulk_dgrad else None, + out_index=out_index, + fp8_meta_tensor = meta_tensor, + D_dtype = out_te_type, ) else: # FC2 DGRAD; Unconditional @@ -703,6 +745,15 @@ def backward( if fc1_weight.requires_grad: if ctx.fp8: # FC1 WGRAD + extra_output_tensor = None + if ctx.ub_bulk_wgrad: + if ub_obj_dgrad.is_fp8_ubuf(): + dim_size = list(ub_obj_dgrad.get_ubuf_output(0).size()) # RS output + extra_output_tensor = torch.empty( + dim_size, dtype=ctx.activation_dtype, device=fc1_dgrad.device) + fc1_dgrad = extra_output_tensor + else: + fc1_dgrad = ub_obj_dgrad.get_ubuf_output(0) if not ctx.fp8_meta["recipe"].override_linear_precision.wgrad: ln_out_total_t = tex.fp8_transpose(ln_out_total, fp8_dtype_forward) fc1_wgrad = tex.fp8_gemm( @@ -724,6 +775,7 @@ def backward( ub_algo=tex.UbufOverlapAlgo.BULK_OVERLAP_RS if ctx.ub_bulk_wgrad else None, ub=ub_obj_dgrad if ctx.ub_bulk_wgrad else None, + extra_output_tensor=extra_output_tensor, ) else: ln_out_total_c = tex.cast_from_fp8( @@ -747,6 +799,7 @@ def backward( ub_algo=tex.UbufOverlapAlgo.BULK_OVERLAP_RS if ctx.ub_bulk_wgrad else None, ub=ub_obj_dgrad if ctx.ub_bulk_wgrad else None, + extra_output_tensor=extra_output_tensor, ) else: # FC1 WGRAD @@ -768,11 +821,14 @@ def backward( 
fc1_wgrad, _, _ = fc1_wgrad_outputs else: fc1_wgrad, fc1_bias_grad, _ = fc1_wgrad_outputs + if ctx.ub_bulk_wgrad: + fc1_dgrad = ub_obj_dgrad.get_ubuf_output(0) # Reduce-scatter output # Column Parallel Linear - if ctx.ub_bulk_wgrad: - fc1_dgrad = ub_obj_dgrad.get_ubuf_output(0) # Reduce-scatter output - elif ctx.set_parallel_mode and ctx.tensor_parallel and handle is not None: + if ((not ctx.ub_bulk_wgrad) + and ctx.set_parallel_mode + and ctx.tensor_parallel + and handle is not None): handle.wait() # LayerNorm gradient @@ -832,6 +888,8 @@ def backward( None, None, None, + None, + None, ) @@ -947,8 +1005,10 @@ def __init__( ub_bulk_wgrad: bool = False, ub_bulk_dgrad: bool = False, ub_split_rs: bool = False, + ub_atomic_gemm_rs: bool = False, ub_split_ag: bool = False, device: Union[torch.device, str] = "cuda", + ub_atomic_gemm_ag: bool = False, ) -> None: super().__init__() @@ -969,12 +1029,24 @@ def __init__( self.ub_bulk_dgrad = ub_bulk_dgrad self.ub_split_rs = ub_split_rs self.ub_split_ag = ub_split_ag - - if ub_bulk_wgrad or ub_bulk_dgrad or ub_split_rs or ub_split_ag: + self.ub_atomic_gemm_rs = ub_atomic_gemm_rs + self.ub_atomic_gemm_ag = ub_atomic_gemm_ag + + if (ub_bulk_wgrad # pylint: disable=too-many-boolean-expressions + or ub_bulk_dgrad + or ub_split_rs + or ub_split_ag + or ub_atomic_gemm_rs + or ub_atomic_gemm_ag): assert ( tex.userbuf_comm_available() ), "Userbuffer communication backend not available." + if ub_atomic_gemm_rs or ub_atomic_gemm_ag: + warnings.warn( + "Atomic gemm uses a beta API from cublas and is not tested for all use cases." + ) + if tp_group is None: self.tp_size = tp_size if tp_size == 1: @@ -1189,7 +1261,9 @@ def forward( self.ub_bulk_wgrad, self.ub_bulk_dgrad, self.ub_split_rs, + self.ub_atomic_gemm_rs, self.ub_split_ag, + self.ub_atomic_gemm_ag, self.activation, self.normalization, ) diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 98ca2015ed..2d9dbac057 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -77,6 +77,8 @@ def forward( is_grad_enabled: bool, ub_split_rs: bool, ub_split_ag: bool, + ub_atomic_gemm_rs: bool, + ub_atomic_gemm_ag: bool, ) -> torch.Tensor: # Make sure input dimensions are compatible in_features = weight.shape[-1] @@ -88,10 +90,13 @@ def forward( update_fp8_weights = is_first_microbatch is None or is_first_microbatch - if ub_split_rs: + if ub_split_rs or ub_atomic_gemm_rs: tp_world_size = get_distributed_world_size(tp_group) if tp_world_size == 1: ub_split_rs = False + ub_atomic_gemm_rs = False + if ub_atomic_gemm_rs or ub_atomic_gemm_ag: + assert fp8, "AtomicGemm overlap supported only for FP8 GEMM." 
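
Whether the reduce-scatter GEMM below writes FP8 into the userbuffer depends on how the "proj_fprop" buffer was registered; a hedged sketch of that hook, based on the base.py hunk earlier in this patch (the initialize_ub call site itself is assumed):

import os

# Illustrative sketch: with NVTE_UB_FP8_RS set, base.py appends "proj_fprop" to the FP8
# userbuffer list, so ub_obj_projout.is_fp8_ubuf() is true and the FP8 output branch in
# the forward pass below is taken. Set this before userbuffer initialization runs.
os.environ["NVTE_UB_FP8_RS"] = "1"
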
# Cast for native AMP inputmat = cast_if_needed(inputmat, activation_dtype) inputmat_no_fp8 = inputmat @@ -155,18 +160,29 @@ def forward( fp8_dtype_forward, ) - if ub_split_rs: + proj_out_index, meta_tensor, proj_out_tetype, proj_out_pttype = ( + None, None, None, activation_dtype) + if ub_split_rs or ub_atomic_gemm_rs: ub_obj_projout = get_ub("proj_fprop") out = ub_obj_projout.get_ubuf_output(1) dim_size = list(inputmat_total.size()) dim_size[0] = dim_size[0] // tp_world_size dim_size[1] = weight.size(0) rs_out = torch.empty(dim_size, dtype=activation_dtype, device=inputmat_total.device) + + if ub_obj_projout.is_fp8_ubuf(): + proj_out_index = tex.FP8FwdTensors.GEMM1_OUTPUT + meta_tensor = fp8_meta["scaling_fwd"] + proj_out_tetype = fp8_dtype_forward + proj_out_pttype = torch.uint8 + ub_obj_projout.set_ubuf_scale_inv(meta_tensor.scale_inv[proj_out_index]) else: dim_size = list(inputmat_total.size()) dim_size[1] = weight.size(0) out = torch.empty(dim_size, dtype=activation_dtype, device=inputmat_total.device) + ub_algo=tex.UbufOverlapAlgo.ATOMIC_GEMM_RS if ub_atomic_gemm_rs else None + ub_algo=tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS if ub_split_rs else ub_algo _ = fp8_gemm( weight_fp8, fp8_meta["scaling_fwd"].scale_inv, @@ -176,15 +192,18 @@ def forward( fp8_meta["scaling_fwd"].scale_inv, tex.FP8FwdTensors.GEMM1_INPUT, fp8_dtype_forward, - activation_dtype, + proj_out_pttype, get_workspace(), bias=bias, use_bias=use_bias, use_split_accumulator=_2X_ACC_FPROP, out=out, - ub_algo=tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS if ub_split_rs else None, - ub=ub_obj_projout if ub_split_rs else None, - extra_output_tensor=rs_out if ub_split_rs else None, + ub_algo=ub_algo, + ub=ub_obj_projout if (ub_split_rs or ub_atomic_gemm_rs) else None, + extra_output_tensor=rs_out if (ub_split_rs or ub_atomic_gemm_rs) else None, + out_index=proj_out_index, + fp8_meta_tensor = meta_tensor, + D_dtype = proj_out_tetype, ) else: # Cast for native AMP @@ -245,11 +264,12 @@ def forward( ctx.parallel_mode = parallel_mode ctx.tp_group = tp_group ctx.ub_split_ag = ub_split_ag + ctx.ub_atomic_gemm_ag = ub_atomic_gemm_ag ctx.tp_size = tp_size ctx.requires_dgrad = inp.requires_grad # Row Parallel Linear - if ub_split_rs: + if ub_split_rs or ub_atomic_gemm_rs: out = rs_out elif parallel_mode == "row" and sequence_parallel: out, _ = reduce_scatter_along_first_dim(out, tp_group) @@ -275,11 +295,12 @@ def backward( fwd_scale_inverses, ) = ctx.saved_tensors - if ctx.ub_split_ag: + if ctx.ub_split_ag or ctx.ub_atomic_gemm_ag: tp_world_size = get_distributed_world_size(ctx.tp_group) if tp_world_size == 1: ctx.ub_split_ag = False - if ctx.ub_split_ag: + ctx.ub_atomic_gemm_ag = False + if ctx.ub_split_ag or ctx.ub_atomic_gemm_ag: dim_size = list(grad_output.size()) dim_size[0] = dim_size[0] * tp_world_size ctx.ub_obj_gradout = get_ub("proj_dgrad") @@ -323,6 +344,8 @@ def backward( ctx.fp8_meta["recipe"], fprop_tensor=False ) + ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG if ctx.ub_split_ag else None + ub_algo = tex.UbufOverlapAlgo.ATOMIC_GEMM_AG if ctx.ub_atomic_gemm_ag else ub_algo if ctx.requires_dgrad: if ctx.fp8: dgrad = fp8_gemm( @@ -337,8 +360,8 @@ def backward( ctx.activation_dtype, get_workspace(), use_split_accumulator=_2X_ACC_DGRAD, - ub_algo=tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG if ctx.ub_split_ag else None, - ub=ctx.ub_obj_gradout if ctx.ub_split_ag else None, + ub_algo=ub_algo, + ub=ctx.ub_obj_gradout if ctx.ub_split_ag or ctx.ub_atomic_gemm_ag else None, ) else: dgrad, _, _ = gemm( @@ -366,7 +389,7 @@ def backward( if 
ctx.fp8: # WGRAD if not ctx.fp8_meta["recipe"].override_linear_precision.wgrad: - if ctx.ub_split_ag: + if ctx.ub_split_ag or ctx.ub_atomic_gemm_ag: grad_output_t = tex.fp8_transpose(grad_output_c, fp8_dtype_backward) wgrad = fp8_gemm( inputmat_t_total, @@ -436,6 +459,8 @@ def backward( None, None, None, + None, + None, ) @@ -529,6 +554,8 @@ def __init__( ub_split_rs: bool = False, ub_split_ag: bool = False, device: Union[torch.device, str] = "cuda", + ub_atomic_gemm_rs: bool = False, + ub_atomic_gemm_ag: bool = False, ) -> None: super().__init__() @@ -550,12 +577,19 @@ def __init__( self.parameters_split = parameters_split self.ub_split_rs = ub_split_rs self.ub_split_ag = ub_split_ag + self.ub_atomic_gemm_rs = ub_atomic_gemm_rs + self.ub_atomic_gemm_ag = ub_atomic_gemm_ag - if ub_split_rs or ub_split_ag: + if ub_split_rs or ub_split_ag or ub_atomic_gemm_rs: assert ( tex.userbuf_comm_available() ), "Userbuffer communication backend not available." + if ub_atomic_gemm_rs or ub_atomic_gemm_ag: + warnings.warn( + "Atomic gemm uses a beta API from cublas and is not tested for all use cases." + ) + if tp_group is None: self.tp_size = tp_size if tp_size == 1: @@ -774,6 +808,8 @@ def forward( torch.is_grad_enabled(), self.ub_split_rs, self.ub_split_ag, + self.ub_atomic_gemm_rs, + self.ub_atomic_gemm_ag, ) out = linear_fn(*args) diff --git a/transformer_engine/pytorch/transformer.py b/transformer_engine/pytorch/transformer.py index d8a1aa1ad2..cded3bf53f 100644 --- a/transformer_engine/pytorch/transformer.py +++ b/transformer_engine/pytorch/transformer.py @@ -263,6 +263,22 @@ def __init__( ub_bulk_dgrad = ub_tp_comm_overlap and bool(int(os.getenv("NVTE_UB_BULK_DGRAD", "1"))) ub_split_ag = ub_tp_comm_overlap and bool(int(os.getenv("NVTE_UB_SPLIT_AG", "1"))) ub_split_rs = ub_tp_comm_overlap and bool(int(os.getenv("NVTE_UB_SPLIT_RS", "1"))) + ub_atomic_gemm_rs = (ub_tp_comm_overlap + and bool(int(os.getenv("NVTE_UB_ATOMIC_GEMM_RS", "0")))) + assert ( + not (ub_split_rs and ub_atomic_gemm_rs) + ), "Only one type of RS overlap NVTE_UB_SPLIT_RS/NVTE_UB_ATOMIC_GEMM_RS should be enabled." + ub_atomic_gemm_ag = (ub_tp_comm_overlap + and bool(int(os.getenv("NVTE_UB_ATOMIC_GEMM_AG", "0")))) + assert ( + not (ub_split_ag and ub_atomic_gemm_ag) + ), "Only one type of AG overlap NVTE_UB_SPLIT_AG/NVTE_UB_ATOMIC_GEMM_AG should be enabled." + + if ub_atomic_gemm_rs or ub_atomic_gemm_ag: + warnings.warn( + "Atomic gemm uses a beta API from cublas and is not tested for all use cases." + ) + bias_dropout_fusion = bool(int(os.getenv("NVTE_BIAS_DROPOUT_FUSION", "1"))) self.layer_number = layer_number self.output_layernorm = output_layernorm @@ -323,6 +339,8 @@ def __init__( "ub_bulk_dgrad" : ub_bulk_dgrad, "ub_split_ag" : ub_split_ag, "ub_split_rs" : ub_split_rs, + "ub_atomic_gemm_rs" : ub_atomic_gemm_rs, + "ub_atomic_gemm_ag" : ub_atomic_gemm_ag, } self.self_attention = MultiheadAttention( @@ -377,6 +395,8 @@ def __init__( ub_bulk_dgrad=ub_bulk_dgrad, ub_split_rs=ub_split_rs, ub_split_ag=ub_split_ag, + ub_atomic_gemm_rs=ub_atomic_gemm_rs, + ub_atomic_gemm_ag=ub_atomic_gemm_ag, activation=activation, normalization=normalization, device=device, From 8eae4ce2b8fdfbbe525fc8bfecb0df5498cc9687 Mon Sep 17 00:00:00 2001 From: Ming-Xu Huang Date: Sat, 7 Oct 2023 01:20:16 +0800 Subject: [PATCH 61/68] [JAX] Enhance Dropout in TransformerLayer. (#444) * [JAX] Enhance Dropout in TransformerLayer. 1. Fixed missing setup of dropout RNG key in TransformerLayer and LayerNormMLP. 2. 
Allowing seperated dropout rate for FC1's output and other hiddens. Signed-off-by: Ming Huang * Fix wrong fp8 scale in _update_fp8_metas_impl Signed-off-by: Ming Huang * Fix typo Signed-off-by: Ming Huang --------- Signed-off-by: Ming Huang Co-authored-by: Kirthi Shankar Sivamani --- tests/jax/test_helper.py | 9 ++++---- tests/jax/test_layer.py | 8 +++++++ tests/jax/test_praxis_layers.py | 3 +++ transformer_engine/jax/flax/module.py | 6 +++++- transformer_engine/jax/flax/transformer.py | 22 +++++++++++++++----- transformer_engine/jax/fp8.py | 6 +++--- transformer_engine/jax/praxis/transformer.py | 4 ++++ 7 files changed, 44 insertions(+), 14 deletions(-) diff --git a/tests/jax/test_helper.py b/tests/jax/test_helper.py index 91ca06a90e..815aab6099 100644 --- a/tests/jax/test_helper.py +++ b/tests/jax/test_helper.py @@ -72,11 +72,10 @@ def get_fp8_scale(fp8_max, amax, scale): amax = np.array(amax) scale = np.array(scale) - exp = np.floor(np.log2(fp8_max / amax)) - FP8Helper.MARGIN - sf = np.round(np.power(2, np.abs(exp))) - sf = np.where(amax > 0.0, sf, scale) - sf = np.where(np.isfinite(amax), sf, scale) - return np.where(exp < 0, 1 / sf, sf) + sf = (fp8_max / amax) / (2**FP8Helper.MARGIN) + sf = jnp.where(amax > 0.0, sf, scale) + sf = jnp.where(jnp.isfinite(amax), sf, scale) + return sf amax_meta_shape = (num_of_meta, FP8Helper.AMAX_HISTORY_LEN) scale_meta_shape = (num_of_meta, 1) diff --git a/tests/jax/test_layer.py b/tests/jax/test_layer.py index a635c687b7..4f9e224663 100644 --- a/tests/jax/test_layer.py +++ b/tests/jax/test_layer.py @@ -167,6 +167,7 @@ def forward_runner(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08): if k == 'dropout_rate': te_layer_attrs['attention_dropout'] = v te_layer_attrs['hidden_dropout'] = v + te_layer_attrs['intermediate_dropout'] = v elif k == 'fuse_mlp_wi': continue else: @@ -174,6 +175,7 @@ def forward_runner(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08): ref_layer_cls = partial(RefEncoderLayer, dtype=dtype, **attrs) layer_cls = partial(TransformerLayer, hidden_dropout_dims=(sequence_dim,), + intermediate_dropout_dims=(sequence_dim,), layer_type=TransformerLayerType.ENCODER, self_attn_mask_type='padding', dtype=dtype, @@ -212,6 +214,7 @@ def forward_backward_runner(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e- if k == 'dropout_rate': te_layer_attrs['attention_dropout'] = v te_layer_attrs['hidden_dropout'] = v + te_layer_attrs['intermediate_dropout'] = v elif k == 'fuse_mlp_wi': continue else: @@ -219,6 +222,7 @@ def forward_backward_runner(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e- ref_layer_cls = partial(RefEncoderLayer, dtype=dtype, **attrs) layer_cls = partial(TransformerLayer, hidden_dropout_dims=(sequence_dim,), + intermediate_dropout_dims=(sequence_dim,), layer_type=TransformerLayerType.ENCODER, self_attn_mask_type='padding', dtype=dtype, @@ -381,6 +385,7 @@ def forward_runner(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08): if k == 'dropout_rate': te_layer_attrs['attention_dropout'] = v te_layer_attrs['hidden_dropout'] = v + te_layer_attrs['intermediate_dropout'] = v elif k == 'fuse_mlp_wi': continue else: @@ -388,6 +393,7 @@ def forward_runner(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08): ref_layer_cls = partial(RefDecoderLayer, dtype=dtype, **attrs) layer_cls = partial(TransformerLayer, hidden_dropout_dims=(sequence_dim,), + intermediate_dropout_dims=(sequence_dim,), layer_type=TransformerLayerType.DECODER, dtype=dtype, **te_layer_attrs) @@ -426,6 +432,7 @@ def forward_backward_runner(self, 
data_shape, dtype, attrs, rtol=1e-05, atol=1e- if k == 'dropout_rate': te_layer_attrs['attention_dropout'] = v te_layer_attrs['hidden_dropout'] = v + te_layer_attrs['intermediate_dropout'] = v elif k == 'fuse_mlp_wi': continue else: @@ -433,6 +440,7 @@ def forward_backward_runner(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e- ref_layer_cls = partial(RefDecoderLayer, dtype=dtype, **attrs) layer_cls = partial(TransformerLayer, hidden_dropout_dims=(sequence_dim,), + intermediate_dropout_dims=(sequence_dim,), layer_type=TransformerLayerType.DECODER, dtype=dtype, **te_layer_attrs) diff --git a/tests/jax/test_praxis_layers.py b/tests/jax/test_praxis_layers.py index 12ad919077..5a1bf41fb2 100644 --- a/tests/jax/test_praxis_layers.py +++ b/tests/jax/test_praxis_layers.py @@ -957,6 +957,7 @@ def generate_praxis_p_and_flax_cls(self, dtype, attrs): layernorm_type = attrs[TransformerLayerAttr.LN_TYPE] hidden_dropout = 0.0 attention_dropout = 0.0 + intermediate_dropout = 0.0 mlp_activations = attrs[TransformerLayerAttr.ACTIVATION] kernel_init = WeightInit.Gaussian(1.0) use_bias = attrs[TransformerLayerAttr.USE_BIAS] @@ -991,6 +992,7 @@ def generate_praxis_p_and_flax_cls(self, dtype, attrs): layernorm_type=layernorm_type, hidden_dropout=hidden_dropout, attention_dropout=attention_dropout, + intermediate_dropout=intermediate_dropout, mlp_activations=mlp_activations, use_bias=use_bias, bias_init=bias_init, @@ -1007,6 +1009,7 @@ def generate_praxis_p_and_flax_cls(self, dtype, attrs): layernorm_type=layernorm_type, hidden_dropout=hidden_dropout, attention_dropout=attention_dropout, + intermediate_dropout=intermediate_dropout, mlp_activations=mlp_activations, mha_kernel_init=TransformerEngineBaseLayer.generate_params_init( "mha_kernel", kernel_init), diff --git a/transformer_engine/jax/flax/module.py b/transformer_engine/jax/flax/module.py index d95bece5ad..89da212367 100644 --- a/transformer_engine/jax/flax/module.py +++ b/transformer_engine/jax/flax/module.py @@ -739,6 +739,8 @@ class LayerNormMLP(TransformerEngineBase): activations: Sequence[Union[str, Callable]], default = ('relu',) The sequence of activation functions to apply after the first linear transformation. Each activation has its own transformation layer. + intermediate_dropout_rng_name: str, default = 'dropout' + The key in given RNGs via flax.linen.Module.apply that for generating Dropout masks. intermediate_dropout_rate: float, default = 0.1 Dropout probability for the dropout op after the :attr:`activations`. intermediate_hidden_dropout_dims: Sequence[int], default = () @@ -779,6 +781,7 @@ class LayerNormMLP(TransformerEngineBase): bias_axes_2: Tuple[str, ...] 
= ('embed',) return_layernorm_output: bool = True activations: Sequence[Union[str, Callable]] = ('relu',) + intermediate_dropout_rng_name: str = 'dropout' intermediate_dropout_rate: float = 0.1 intermediate_hidden_dropout_dims: Sequence[int] = () axis: Union[Iterable[int], int] = -1 @@ -985,7 +988,8 @@ def fp8_meta_generator(): z = jnp.reshape(z, (*z.shape[:-2], -1)) z = nn.Dropout(rate=self.intermediate_dropout_rate, - broadcast_dims=self.intermediate_hidden_dropout_dims)( + broadcast_dims=self.intermediate_hidden_dropout_dims, + rng_collection=self.intermediate_dropout_rng_name)( z, deterministic=deterministic) # DenseGeneral 2 diff --git a/transformer_engine/jax/flax/transformer.py b/transformer_engine/jax/flax/transformer.py index 2a3d5979fd..451d7731b1 100644 --- a/transformer_engine/jax/flax/transformer.py +++ b/transformer_engine/jax/flax/transformer.py @@ -883,6 +883,10 @@ class TransformerLayer(nn.Module): Dimensions that will share the same dropout mask for hidden attention_dropout: float, default = 0.1 Dropout probability for the dropout op during multi-head attention. + intermediate_dropout: float, default = 0.1 + Dropout probability for the dropout op after FC1 layer. + intermediate_dropout_dims: Sequence[int], default = () + Dimensions that will share the same dropout mask for hidden after FC1 layer. dropout_rng_name: str, default = 'dropout' The key in given RNGs via flax.linen.Module.apply that for generating Dropout masks in the Multi-Head Attention. @@ -963,6 +967,8 @@ class TransformerLayer(nn.Module): hidden_dropout: float = 0.1 hidden_dropout_dims: Sequence[int] = () attention_dropout: float = 0.1 + intermediate_dropout: float = 0.1 + intermediate_dropout_dims: Sequence[int] = () dropout_rng_name: str = 'dropout' mha_kernel_init: Initializer = None mlp_kernel_init: Initializer = None @@ -1078,6 +1084,8 @@ def __call__(self, else: mha_name = 'self_attention' + inputs = _with_sharding_constraint(inputs, (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES)) + # [batch, length, emb_dim] -> [batch, length, emb_dim] x, residual = MultiHeadAttention( num_heads=self.num_attention_heads, @@ -1113,14 +1121,15 @@ def hidden_dropout(x, deterministic): assert -x_shape_len <= dims < x_shape_len return nn.Dropout(rate=self.hidden_dropout, - broadcast_dims=self.hidden_dropout_dims)(x, - deterministic=deterministic) + broadcast_dims=self.hidden_dropout_dims, + rng_collection=self.dropout_rng_name)(x, deterministic=deterministic) x = hidden_dropout(x, deterministic) if self.drop_path > 0.0: drop_path_shape = _generate_drop_path_shape(x.shape, batch_dim) x = nn.Dropout(rate=self.drop_path, - broadcast_dims=drop_path_shape)(x, deterministic=deterministic) + broadcast_dims=drop_path_shape, + rng_collection=self.dropout_rng_name)(x, deterministic=deterministic) x = x + residual mlp_input = x @@ -1156,6 +1165,8 @@ def hidden_dropout(x, deterministic): y = hidden_dropout(y, deterministic) mlp_input = y + residual + mlp_input = _with_sharding_constraint(mlp_input, (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES)) + # MlpBlock residual = mlp_input z, ln_out = LayerNormMLP( @@ -1167,8 +1178,9 @@ def hidden_dropout(x, deterministic): return_layernorm_output=self.apply_residual_connection_post_layernorm, intermediate_dim=self.mlp_hidden_size, activations=self.mlp_activations, - intermediate_dropout_rate=self.hidden_dropout, - intermediate_hidden_dropout_dims=self.hidden_dropout_dims, + intermediate_dropout_rng_name=self.dropout_rng_name, + intermediate_dropout_rate=self.intermediate_dropout, + 
intermediate_hidden_dropout_dims=self.intermediate_dropout_dims, dtype=self.dtype, scale_axes=(W_NO_SHARD_AXES,), ln_bias_axes=(W_NO_SHARD_AXES,), diff --git a/transformer_engine/jax/fp8.py b/transformer_engine/jax/fp8.py index 83aad88c07..c64bcbd6d0 100644 --- a/transformer_engine/jax/fp8.py +++ b/transformer_engine/jax/fp8.py @@ -310,11 +310,11 @@ def _update_fp8_metas_impl(fp8_metas: Collection) -> Collection: amax = fp8_meta_arrays[fp8_amax_idx][..., 0:1] scale = fp8_meta_arrays[fp8_scale_idx] - sf = (fp8_max / amax) / (2 ** FP8Helper.MARGIN) + sf = (fp8_max / amax) / (2**FP8Helper.MARGIN) sf = jnp.where(amax > 0.0, sf, scale) sf = jnp.where(jnp.isfinite(amax), sf, scale) - fp8_meta_arrays[fp8_scale_idx] = scale - fp8_meta_arrays[fp8_scale_inv_idx] = 1 / scale + fp8_meta_arrays[fp8_scale_idx] = sf + fp8_meta_arrays[fp8_scale_inv_idx] = 1 / sf return jax.tree_util.tree_unflatten(treedef, fp8_meta_arrays) diff --git a/transformer_engine/jax/praxis/transformer.py b/transformer_engine/jax/praxis/transformer.py index 9bf9628490..b16c4e731e 100644 --- a/transformer_engine/jax/praxis/transformer.py +++ b/transformer_engine/jax/praxis/transformer.py @@ -137,6 +137,8 @@ class TransformerLayer(TransformerEngineBaseLayer): hidden_dropout: float = 0.1 hidden_dropout_dims: Sequence[int] = () attention_dropout: float = 0.1 + intermediate_dropout: float = 0.1 + intermediate_dropout_dims: Sequence[int] = () dropout_rng_name: str = 'dropout' mlp_activations: Sequence[str] = ('relu',) use_bias: bool = False @@ -190,6 +192,8 @@ def setup(self) -> None: hidden_dropout=self.hidden_dropout, hidden_dropout_dims=self.hidden_dropout_dims, attention_dropout=self.attention_dropout, + intermediate_dropout=self.intermediate_dropout, + intermediate_dropout_dims=self.intermediate_dropout_dims, dropout_rng_name=self.dropout_rng_name, mha_kernel_init=TransformerEngineBaseLayer.generate_params_init( "mha_kernel", self.params_init), From 61a6a188914bf56cd3aa05cc77d1e88412c9bb0c Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Thu, 19 Oct 2023 14:44:31 -0700 Subject: [PATCH 62/68] [PyTorch] rm unused docs (#484) RM unused docs Signed-off-by: Kirthi Shankar Sivamani --- docs/api/pytorch.rst | 1 - 1 file changed, 1 deletion(-) diff --git a/docs/api/pytorch.rst b/docs/api/pytorch.rst index e31f44fef5..aea66b257f 100644 --- a/docs/api/pytorch.rst +++ b/docs/api/pytorch.rst @@ -29,7 +29,6 @@ pyTorch :members: forward, set_context_parallel_group, set_tensor_parallel_group .. autoapiclass:: transformer_engine.pytorch.InferenceParams(max_batch_size, max_sequence_length) - :members: swap_key_value_dict .. 
autoapiclass:: transformer_engine.pytorch.CudaRNGStatesTracker() :members: reset, get_states, set_states, add, fork From 719f422f802086d995446431388849b2749c4d94 Mon Sep 17 00:00:00 2001 From: Tim Moon <4406448+timmoon10@users.noreply.github.com> Date: Fri, 20 Oct 2023 01:14:51 -0700 Subject: [PATCH 63/68] Fix incorrect dtype in LayerNormLinear (#483) Signed-off-by: Tim Moon Co-authored-by: Kirthi Shankar Sivamani --- transformer_engine/pytorch/module/layernorm_linear.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index a910946218..a8e83631bc 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -112,7 +112,7 @@ def forward( ub_obj_lnout = get_ub("qkv_fprop") ln_out = ub_obj_lnout.get_ubuf_output(0) else: - ln_out_dtype = torch.uint8 if fp8 else inputmat.dtype + ln_out_dtype = torch.uint8 if (fp8 and not return_layernorm_output) else inputmat.dtype ln_out = torch.empty_like(inputmat, dtype=ln_out_dtype) if ub_atomic_gemm_ag: assert fp8, "AtomicGemm overlap supported only for FP8 GEMM." From 1214da0e47662a1d1aa9fad1b622ca59a707a651 Mon Sep 17 00:00:00 2001 From: Alp Dener Date: Fri, 20 Oct 2023 13:11:04 -0500 Subject: [PATCH 64/68] Incorrect use of extend_fsdp_sharding_meta() in cross_fused_attn() (#482) fixed incorrect of extend_fsdp_sharding_meta() in cross_fused_attn() Signed-off-by: Alp Dener --- transformer_engine/jax/fused_attn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/jax/fused_attn.py b/transformer_engine/jax/fused_attn.py index aaca58b2d5..3951d87274 100644 --- a/transformer_engine/jax/fused_attn.py +++ b/transformer_engine/jax/fused_attn.py @@ -206,7 +206,7 @@ def cross_fused_attn(q: jnp.ndarray, tp_dims=([2, 3, None, None], [2]), dp_axis_name=dp_axis_name, tp_axis_name=tp_axis_name) - sharding_meta = extend_fsdp_sharding_meta(sharding_meta, {0: 0, 2: 0}) + sharding_meta, _ = extend_fsdp_sharding_meta(sharding_meta, {0: 0, 2: 0}) inputs_ = tuple( jnp.reshape(x, new_shape) if x is not None else None From ebfeaad52204ce687f908e4fdbcf8caff704f1b8 Mon Sep 17 00:00:00 2001 From: Przemyslaw Tredak Date: Fri, 20 Oct 2023 07:37:22 +0200 Subject: [PATCH 65/68] Better way of checking cuDNN version (#485) * Ability to check cuDNN version from Python Signed-off-by: Przemek Tredak * Modify the fused attention test to not use the CUDNN_VERSION env variable which is specific to NGC containers Signed-off-by: Przemek Tredak --------- Signed-off-by: Przemek Tredak --- tests/pytorch/test_fused_attn.py | 10 +++++++++- transformer_engine/pytorch/csrc/common.h | 1 + transformer_engine/pytorch/csrc/extensions.h | 2 ++ transformer_engine/pytorch/csrc/extensions/misc.cu | 4 ++++ transformer_engine/pytorch/csrc/extensions/pybind.cpp | 1 + 5 files changed, 17 insertions(+), 1 deletion(-) diff --git a/tests/pytorch/test_fused_attn.py b/tests/pytorch/test_fused_attn.py index a3a2656d0b..ac868b83d9 100644 --- a/tests/pytorch/test_fused_attn.py +++ b/tests/pytorch/test_fused_attn.py @@ -44,7 +44,15 @@ fp8_available, reason_for_no_fp8 = fp8.FP8GlobalStateManager.is_fp8_available() _flash_attn_version = packaging.version.Version(version("flash-attn")) _flash_attn_2_available = _flash_attn_version >= packaging.version.Version("2") -_cudnn_version = [int(i) for i in os.environ['CUDNN_VERSION'].split('.')] + +def _get_cudnn_version(): + cudnn_version_encoded = 
ext.get_cudnn_version() + cudnn_major = cudnn_version_encoded // 1000 + cudnn_minor = (cudnn_version_encoded - cudnn_major * 1000) // 100 + cudnn_patch = cudnn_version_encoded - 1000 * cudnn_major - 100 * cudnn_minor + return [cudnn_major, cudnn_minor, cudnn_patch] + +_cudnn_version = _get_cudnn_version() class ModelConfig: diff --git a/transformer_engine/pytorch/csrc/common.h b/transformer_engine/pytorch/csrc/common.h index 7c17f1f34c..d40f3db45b 100644 --- a/transformer_engine/pytorch/csrc/common.h +++ b/transformer_engine/pytorch/csrc/common.h @@ -31,6 +31,7 @@ #include #include #include +#include #include #include #include diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 4eaca7c896..d1789cedb2 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -524,6 +524,8 @@ at::Tensor scaled_upper_triang_masked_softmax_backward(at::Tensor output_grads_, size_t get_cublasLt_version(); +size_t get_cudnn_version(); + bool userbuf_comm_available(); void placeholder(); diff --git a/transformer_engine/pytorch/csrc/extensions/misc.cu b/transformer_engine/pytorch/csrc/extensions/misc.cu index e6275d1159..48aa98bbf1 100644 --- a/transformer_engine/pytorch/csrc/extensions/misc.cu +++ b/transformer_engine/pytorch/csrc/extensions/misc.cu @@ -13,6 +13,10 @@ size_t get_cublasLt_version() { return cublasLtGetVersion(); } +size_t get_cudnn_version() { + return cudnnGetVersion(); +} + bool userbuf_comm_available() { // TODO(ksivamani) check on python side #ifdef NVTE_WITH_USERBUFFERS diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index 7e80299d15..fd117782ab 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -77,6 +77,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { // Misc m.def("get_cublasLt_version", &get_cublasLt_version, "Get cublasLt version"); + m.def("get_cudnn_version", &get_cudnn_version, "Get cuDNN version"); m.def("userbuf_comm_available", &userbuf_comm_available, "If userbuf backend is available"); // Data structures From 7eca973ae8dcf6b62d755db18096a41f47b40337 Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Mon, 23 Oct 2023 14:23:05 -0700 Subject: [PATCH 66/68] [PyTorch] Fixes and tests for FP8 + activation recompute (#487) * initial test fix Signed-off-by: Kirthi Shankar Sivamani * Drop eval for selective checkpointing tests Signed-off-by: Kirthi Shankar Sivamani * Remove redundant recompute for FA Signed-off-by: Kirthi Shankar Sivamani * CI fix; Decouple fused attention and numerics tests Signed-off-by: Kirthi Shankar Sivamani --------- Signed-off-by: Kirthi Shankar Sivamani --- tests/pytorch/test_fused_attn.py | 36 ++++- tests/pytorch/test_numerics.py | 152 ++++++++++++---------- transformer_engine/pytorch/attention.py | 13 -- transformer_engine/pytorch/fp8.py | 23 ++++ transformer_engine/pytorch/module/base.py | 20 +-- 5 files changed, 154 insertions(+), 90 deletions(-) diff --git a/tests/pytorch/test_fused_attn.py b/tests/pytorch/test_fused_attn.py index ac868b83d9..fd37bd371c 100644 --- a/tests/pytorch/test_fused_attn.py +++ b/tests/pytorch/test_fused_attn.py @@ -25,8 +25,6 @@ QKVLayout, fused_attn_bwd, fused_attn_fwd, - fused_attn_bwd_qkvpacked, - fused_attn_fwd_qkvpacked, ) import transformer_engine.pytorch.fp8 as fp8 from transformer_engine.pytorch.module.base import ( @@ -38,13 +36,24 @@ init_method_normal, 
scaled_init_method_normal, ) +from transformer_engine.pytorch.distributed import _set_cuda_rng_state, CudaRNGStatesTracker import transformer_engine_extensions as tex -from test_numerics import get_dummy_cuda_rng_tracker, reset_rng_states + +# Only run FP8 tests on H100. fp8_available, reason_for_no_fp8 = fp8.FP8GlobalStateManager.is_fp8_available() _flash_attn_version = packaging.version.Version(version("flash-attn")) _flash_attn_2_available = _flash_attn_version >= packaging.version.Version("2") + +seed = 1234 +torch.manual_seed(seed) +torch.cuda.manual_seed(seed) +# Record initial RNG state from script run. +_cpu_rng_state = torch.get_rng_state() +_cuda_rng_state = torch.cuda.get_rng_state() + + def _get_cudnn_version(): cudnn_version_encoded = ext.get_cudnn_version() cudnn_major = cudnn_version_encoded // 1000 @@ -52,6 +61,13 @@ def _get_cudnn_version(): cudnn_patch = cudnn_version_encoded - 1000 * cudnn_major - 100 * cudnn_minor return [cudnn_major, cudnn_minor, cudnn_patch] + +def reset_rng_states() -> None: + """revert back to initial RNG state.""" + torch.set_rng_state(_cpu_rng_state) + _set_cuda_rng_state(_cuda_rng_state) + + _cudnn_version = _get_cudnn_version() @@ -210,6 +226,13 @@ def _run_dot_product_attention(dtype, bs, config, backend, ckpt_attn, bias_type) else: bias = None + _DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker() + _DUMMY_CUDA_RNG_STATE_TRACKER.add("model-parallel-rng", seed) + + def get_dummy_cuda_rng_tracker(): + """Get cuda rng tracker.""" + return _DUMMY_CUDA_RNG_STATE_TRACKER + block = ( DotProductAttention( config.num_attention_heads, @@ -733,6 +756,13 @@ def _run_dpa_fp8_ref(dtype, bs, config, backend): cu_seqlens[1:] = torch.cumsum(seqlens, dim=0) op_grad = torch.load('op_grad.pt').cuda().view(bs, config.seq_len, -1).transpose(0,1) + _DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker() + _DUMMY_CUDA_RNG_STATE_TRACKER.add("model-parallel-rng", seed) + + def get_dummy_cuda_rng_tracker(): + """Get cuda rng tracker.""" + return _DUMMY_CUDA_RNG_STATE_TRACKER + block = ( DotProductAttention( config.num_attention_heads, diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index 21ee0968d9..02fb63e71f 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -12,6 +12,7 @@ import torch.nn as nn from torch.nn import Parameter +from transformer_engine.pytorch.fp8 import fp8_autocast, FP8GlobalStateManager from transformer_engine.pytorch.utils import ( init_method_normal, scaled_init_method_normal, @@ -25,6 +26,10 @@ from transformer_engine.pytorch.distributed import _set_cuda_rng_state, CudaRNGStatesTracker +# Only run FP8 tests on H100. +fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() + + seed = 1234 torch.manual_seed(seed) torch.cuda.manual_seed(seed) @@ -90,20 +95,11 @@ def assert_allclose(l1: List[torch.Tensor], l2: List[torch.Tensor], atol: float) def reset_rng_states() -> None: - # revert back to initial RNG state. 
+ """revert back to initial RNG state.""" torch.set_rng_state(_cpu_rng_state) _set_cuda_rng_state(_cuda_rng_state) -_DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker() -_DUMMY_CUDA_RNG_STATE_TRACKER.add("model-parallel-rng", seed) - - -def get_dummy_cuda_rng_tracker(): - """Get cuda rng tracker.""" - return _DUMMY_CUDA_RNG_STATE_TRACKER - - class TorchScaledMaskedSoftmax(nn.Module): def __init__(self) -> None: super().__init__() @@ -343,41 +339,21 @@ def forward( return x -def _test_e2e_selective_recompute(block, bs, dtype, config, recompute=False): +def _test_e2e_selective_recompute(bs, dtype, config, fp8, recompute=False): reset_rng_states() - - te_inp_hidden_states = torch.randn( - config.seq_len, bs, config.hidden_size, dtype=dtype, requires_grad=True - ).cuda() - te_inp_hidden_states.retain_grad() - te_inp_attn_mask = get_causal_attn_mask(config.seq_len) - - te_out = block( - te_inp_hidden_states, - attention_mask=te_inp_attn_mask, - checkpoint_core_attention=recompute, - ) - loss = te_out.sum() - loss.backward() - torch.cuda.synchronize() - - outputs = [te_out, te_inp_hidden_states.grad] - for p in block.parameters(): - if p.requires_grad: - outputs.append(p.grad) - return outputs - - -@pytest.mark.parametrize("dtype", param_types) -@pytest.mark.parametrize("bs", batch_sizes) -@pytest.mark.parametrize("model", model_configs.keys()) -def test_gpt_selective_activation_recompute(dtype, bs, model): - config = model_configs[model] + FP8GlobalStateManager.reset() sigma = 0.023 init_method = init_method_normal(sigma) output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers) + _DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker() + _DUMMY_CUDA_RNG_STATE_TRACKER.add("model-parallel-rng", seed) + + def get_dummy_cuda_rng_tracker(): + """Get cuda rng tracker.""" + return _DUMMY_CUDA_RNG_STATE_TRACKER + block = ( TransformerLayer( config.hidden_size, @@ -395,38 +371,19 @@ def test_gpt_selective_activation_recompute(dtype, bs, model): params_dtype=dtype, ) .cuda() - .eval() ) - outputs = _test_e2e_selective_recompute(block, bs, dtype, config, recompute=False) - outputs_recompute = _test_e2e_selective_recompute(block, bs, dtype, config, recompute=True) - assert_all_equal(outputs, outputs_recompute) - - -def _test_e2e_full_recompute(block, bs, dtype, config, recompute=False): - reset_rng_states() - te_inp_hidden_states = torch.randn( config.seq_len, bs, config.hidden_size, dtype=dtype, requires_grad=True ).cuda() te_inp_hidden_states.retain_grad() te_inp_attn_mask = get_causal_attn_mask(config.seq_len) - if recompute: - te_out = te_checkpoint( - block, - False, # distribute_saved_activations - get_dummy_cuda_rng_tracker, - None, # tp_group - te_inp_hidden_states, - attention_mask=te_inp_attn_mask, - checkpoint_core_attention=False, - ) - else: + with fp8_autocast(enabled=fp8): te_out = block( te_inp_hidden_states, attention_mask=te_inp_attn_mask, - checkpoint_core_attention=False, + checkpoint_core_attention=recompute, ) loss = te_out.sum() loss.backward() @@ -442,13 +399,33 @@ def _test_e2e_full_recompute(block, bs, dtype, config, recompute=False): @pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("bs", batch_sizes) @pytest.mark.parametrize("model", model_configs.keys()) -def test_gpt_full_activation_recompute(dtype, bs, model): +@pytest.mark.parametrize("fp8", all_boolean) +def test_gpt_selective_activation_recompute(dtype, bs, model, fp8): + if fp8 and not fp8_available: + pytest.skip(reason_for_no_fp8) + config = model_configs[model] + outputs = 
_test_e2e_selective_recompute(bs, dtype, config, fp8, recompute=False) + outputs_recompute = _test_e2e_selective_recompute(bs, dtype, config, fp8, recompute=True) + assert_all_equal(outputs, outputs_recompute) + + +def _test_e2e_full_recompute(bs, dtype, config, fp8, recompute=False): + reset_rng_states() + FP8GlobalStateManager.reset() + sigma = 0.023 init_method = init_method_normal(sigma) output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers) + _DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker() + _DUMMY_CUDA_RNG_STATE_TRACKER.add("model-parallel-rng", seed) + + def get_dummy_cuda_rng_tracker(): + """Get cuda rng tracker.""" + return _DUMMY_CUDA_RNG_STATE_TRACKER + block = ( TransformerLayer( config.hidden_size, @@ -466,11 +443,54 @@ def test_gpt_full_activation_recompute(dtype, bs, model): params_dtype=dtype, ) .cuda() - .eval() ) - outputs = _test_e2e_full_recompute(block, bs, dtype, config, recompute=False) - outputs_recompute = _test_e2e_full_recompute(block, bs, dtype, config, recompute=True) + te_inp_hidden_states = torch.randn( + config.seq_len, bs, config.hidden_size, dtype=dtype, requires_grad=True + ).cuda() + te_inp_hidden_states.retain_grad() + te_inp_attn_mask = get_causal_attn_mask(config.seq_len) + + with fp8_autocast(enabled=fp8): + if recompute: + te_out = te_checkpoint( + block, + False, # distribute_saved_activations + get_dummy_cuda_rng_tracker, + None, # tp_group + te_inp_hidden_states, + attention_mask=te_inp_attn_mask, + checkpoint_core_attention=False, + ) + else: + te_out = block( + te_inp_hidden_states, + attention_mask=te_inp_attn_mask, + checkpoint_core_attention=False, + ) + loss = te_out.sum() + loss.backward() + torch.cuda.synchronize() + + outputs = [te_out, te_inp_hidden_states.grad] + for p in block.parameters(): + if p.requires_grad: + outputs.append(p.grad) + return outputs + + +@pytest.mark.parametrize("dtype", param_types) +@pytest.mark.parametrize("bs", batch_sizes) +@pytest.mark.parametrize("model", model_configs.keys()) +@pytest.mark.parametrize("fp8", all_boolean) +def test_gpt_full_activation_recompute(dtype, bs, model, fp8): + if fp8 and not fp8_available: + pytest.skip(reason_for_no_fp8) + + config = model_configs[model] + + outputs = _test_e2e_full_recompute(bs, dtype, config, fp8, recompute=False) + outputs_recompute = _test_e2e_full_recompute(bs, dtype, config, fp8, recompute=True) assert_all_equal(outputs, outputs_recompute) @@ -565,8 +585,8 @@ def _test_e2e_checkpointing(bs, dtype, config, checkpoint=False, steps=10, path= def test_gpt_checkpointing(dtype, bs, model): config = model_configs[model] outputs = _test_e2e_checkpointing(bs, dtype, config, checkpoint=False) - outputs_recompute = _test_e2e_checkpointing(bs, dtype, config, checkpoint=True) - assert_all_equal(outputs, outputs_recompute) + outputs_checkpoint = _test_e2e_checkpointing(bs, dtype, config, checkpoint=True) + assert_all_equal(outputs, outputs_checkpoint) def _test_e2e_gpt_accuracy(block, bs, dtype, config): diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 0d2dbe0bc8..6f1aafe3f0 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -2164,19 +2164,6 @@ def forward( ) if use_flash_attention: - if checkpoint_core_attention: - return self._checkpointed_attention_forward(self.flash_attention, - query_layer, - key_layer, - value_layer, - attention_mask=attention_mask, - qkv_layout=qkv_layout, - cu_seqlens_q=cu_seqlens_q, - cu_seqlens_kv=cu_seqlens_kv, 
- attn_mask_type=attn_mask_type, - cp_group=self.cp_group, - cp_global_ranks=self.cp_global_ranks, - cp_stream=self.cp_stream) return self.flash_attention(query_layer, key_layer, value_layer, diff --git a/transformer_engine/pytorch/fp8.py b/transformer_engine/pytorch/fp8.py index 24c97be6e9..c89ff10968 100644 --- a/transformer_engine/pytorch/fp8.py +++ b/transformer_engine/pytorch/fp8.py @@ -75,6 +75,29 @@ class FP8GlobalStateManager: dp_amax_reduce_forward_idx = 0 dp_amax_reduce_backward_idx = 0 + @classmethod + def reset(cls) -> None: + """Reset the global state""" + cls.FP8_ENABLED = False + cls.FP8_CALIBRATION = False + cls.FP8_RECIPE = None + cls.FP8_DISTRIBUTED_GROUP = None + cls.IS_FIRST_FP8_MODULE = False + cls.FP8_AUTOCAST_COUNTER = 0 + cls.FP8_CURRENT_CONTEXT_ID = 0 + cls.FP8_AUTOCAST_DEPTH = 0 + cls.global_fp8_buffer = {} + cls.fp8_tensors_recompute_buffer = [] + cls.amax_forward_global_reduce_func = None + cls.buffer_delete_key_fwd = None + cls.buffer_delete_key_bwd = None + cls.amax_reduce_handle_fwd = None + cls.fp8_available = None + cls.reason_for_no_fp8 = "" + cls.dp_amax_reduce_interval = None + cls.dp_amax_reduce_forward_idx = 0 + cls.dp_amax_reduce_backward_idx = 0 + @classmethod def is_fp8_available(cls) -> Tuple[bool, str]: """Return if fp8 support is available""" diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 73b0bcdb76..5803cfa2f9 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -28,6 +28,7 @@ gather_along_first_dim, is_fp8_activation_recompute_enabled, in_fp8_activation_recompute_phase, + get_distributed_world_size, ) from ..cpp_extensions import ( fp8_cast_transpose_fused, @@ -77,9 +78,7 @@ def _prepare_backward( _amax_reduce_handle_bwd = None # Update amax and scale; Skip all setup for global amax reduction - if not fp8_meta["recipe"].reduce_amax: - amax_and_scale_update(fp8_meta, False) - else: + if fp8_meta["recipe"].reduce_amax and get_distributed_world_size(fp8_meta["fp8_group"]) > 1: # From previous iteration FP8GlobalStateManager.copy_amax_from_global_buffer(fp8_meta, forward=False) amax_and_scale_update(fp8_meta, False) @@ -89,11 +88,14 @@ def _prepare_backward( fp8_meta["autocast_id_bwd"] = fp8_meta["autocast_id_fwd_stack"].pop(0) FP8GlobalStateManager.add_amax_to_global_buffer(fp8_meta, forward=False) + else: + amax_and_scale_update(fp8_meta, False) with torch.cuda.nvtx.range(name + " backward"): yield - if fp8 and fp8_meta["recipe"].reduce_amax: + if (fp8 and fp8_meta["recipe"].reduce_amax + and get_distributed_world_size(fp8_meta["fp8_group"]) > 1): if fp8_meta["first_module"]: _amax_reduce_handle_bwd = FP8GlobalStateManager.global_amax_reduction( fp8_meta, @@ -549,7 +551,8 @@ def prepare_forward( # Previous iteration was grad_enabled if self.fp8_meta.get("update_amax_and_scale_fwd", False): - if self.fp8_meta["recipe"].reduce_amax: + if (self.fp8_meta["recipe"].reduce_amax + and get_distributed_world_size(self.fp8_meta["fp8_group"]) > 1): FP8GlobalStateManager.copy_amax_from_global_buffer(self.fp8_meta, forward=True) amax_and_scale_update( self.fp8_meta, True, update_weight_scale_inv=update_weight_scale_inv @@ -562,7 +565,8 @@ def prepare_forward( if self.fp8 and self.training: # Setup for amax reduction - if self.fp8_meta["recipe"].reduce_amax: + if (self.fp8_meta["recipe"].reduce_amax + and get_distributed_world_size(self.fp8_meta["fp8_group"]) > 1): self.fp8_meta["first_module"] = FP8GlobalStateManager.is_first_fp8_module() if 
self.fp8_meta["first_module"]: # Wait for the prior AMAX reduction to finish @@ -588,7 +592,6 @@ def prepare_forward( self.fp8 and self.training and is_fp8_activation_recompute_enabled() - and not in_fp8_activation_recompute_phase() ): FP8GlobalStateManager.copy_forward_fp8_meta_tensors_for_recompute(self.fp8_meta) @@ -599,7 +602,8 @@ def prepare_forward( FP8GlobalStateManager.restore_fp8_meta_tensors(self.fp8_meta) return - if self.fp8 and self.training and self.fp8_meta["recipe"].reduce_amax: + if (self.fp8 and self.training and self.fp8_meta["recipe"].reduce_amax + and get_distributed_world_size(self.fp8_meta["fp8_group"]) > 1): FP8GlobalStateManager.set_fp8_context_id(self.fp8_meta["autocast_id_fwd"]) reduce_func = partial( FP8GlobalStateManager.global_amax_reduction, From d58c08c72d289cb80f9c4fb729a2bda80b78b6ca Mon Sep 17 00:00:00 2001 From: Tim Moon <4406448+timmoon10@users.noreply.github.com> Date: Tue, 31 Oct 2023 11:08:34 -0700 Subject: [PATCH 67/68] [PyTorch] Experimental FP8 tensor class (#452) * Experimental FP8 tensor Co-authored-by: Tim Moon Co-authored-by: Sudhakar Singh Co-authored-by: Przemyslaw Tredak Signed-off-by: Kirthi Shankar Sivamani * Add fp8 tensor to ci test Signed-off-by: Kirthi Shankar Sivamani * review comments and tests Signed-off-by: Kirthi Shankar Sivamani * Minor changes Signed-off-by: Kirthi Shankar Sivamani * Default to FP8 usage Signed-off-by: Kirthi Shankar Sivamani * Fix docs Signed-off-by: Kirthi Shankar Sivamani * Naming changes Signed-off-by: Kirthi Shankar Sivamani * minor fix Signed-off-by: Kirthi Shankar Sivamani * Fix transpose caching Signed-off-by: Kirthi Shankar Sivamani * Debug transpose caching Handle case where transpose cache is updated externally. Signed-off-by: Tim Moon * Rename FP8GlobalStateManager.with_fp8_parameters Signed-off-by: Tim Moon * remove Float8Tensor from import API Signed-off-by: Kirthi Shankar Sivamani * Avoid caching FP8 transposes if not required Signed-off-by: Tim Moon * Fix import error in FP8 tensor tests Signed-off-by: Tim Moon * Fix tranpose caching and checkpointing bug Signed-off-by: Kirthi Shankar Sivamani * Improve caching and fix distopt case Signed-off-by: Kirthi Shankar Sivamani * Update transformer_engine/pytorch/float8_tensor.py Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> * Remove recursive logic Signed-off-by: Kirthi Shankar Sivamani * Fix cache reset bug Signed-off-by: Kirthi Shankar Sivamani * Store FP8 attributes in dict Easier for multiple tensors to share, e.g. detached tensors. 
Signed-off-by: Tim Moon * Make sure scale_inv is 1D tensor Signed-off-by: Tim Moon * Make sure scale_inv is 1D tensor Signed-off-by: Tim Moon * Fixes and detach recipe Signed-off-by: Kirthi Shankar Sivamani * Set default fp8 data type Signed-off-by: Kirthi Shankar Sivamani --------- Signed-off-by: Kirthi Shankar Sivamani Signed-off-by: Tim Moon Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> Co-authored-by: Kirthi Shankar Sivamani Co-authored-by: Sudhakar Singh Co-authored-by: Przemyslaw Tredak --- docs/api/pytorch.rst | 2 + qa/L0_pytorch_unittest/test.sh | 1 + tests/pytorch/test_float8tensor.py | 318 ++++++++ tests/pytorch/test_numerics.py | 133 +++- tests/pytorch/test_onnx_export.py | 2 +- tests/pytorch/test_torch_save_load.py | 4 +- transformer_engine/pytorch/__init__.py | 1 + transformer_engine/pytorch/distributed.py | 10 +- transformer_engine/pytorch/float8_tensor.py | 689 ++++++++++++++++++ transformer_engine/pytorch/fp8.py | 63 +- transformer_engine/pytorch/module/base.py | 81 +- .../pytorch/module/layernorm_linear.py | 79 +- .../pytorch/module/layernorm_mlp.py | 119 ++- transformer_engine/pytorch/module/linear.py | 87 ++- 14 files changed, 1448 insertions(+), 141 deletions(-) create mode 100644 tests/pytorch/test_float8tensor.py create mode 100644 transformer_engine/pytorch/float8_tensor.py diff --git a/docs/api/pytorch.rst b/docs/api/pytorch.rst index aea66b257f..f179569251 100644 --- a/docs/api/pytorch.rst +++ b/docs/api/pytorch.rst @@ -35,6 +35,8 @@ pyTorch .. autoapifunction:: transformer_engine.pytorch.fp8_autocast +.. autoapifunction:: transformer_engine.pytorch.fp8_model_init + .. autoapifunction:: transformer_engine.pytorch.checkpoint .. autoapifunction:: transformer_engine.pytorch.onnx_export diff --git a/qa/L0_pytorch_unittest/test.sh b/qa/L0_pytorch_unittest/test.sh index 268a534a82..54ba2a09c0 100644 --- a/qa/L0_pytorch_unittest/test.sh +++ b/qa/L0_pytorch_unittest/test.sh @@ -12,3 +12,4 @@ PYTORCH_JIT=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pyt pytest -v -s $TE_PATH/tests/pytorch/test_jit.py pytest -v -s $TE_PATH/tests/pytorch/test_fused_attn.py NVTE_TORCH_COMPILE=0 pytest -v -s $TE_PATH/tests/pytorch/test_onnx_export.py +pytest -v -s $TE_PATH/tests/pytorch/test_float8tensor.py diff --git a/tests/pytorch/test_float8tensor.py b/tests/pytorch/test_float8tensor.py new file mode 100644 index 0000000000..dc48c886cf --- /dev/null +++ b/tests/pytorch/test_float8tensor.py @@ -0,0 +1,318 @@ +# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. 
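# The tests in this file revolve around a simple quantize/dequantize round trip.
# As an illustrative sketch (assuming only the Float8Tensor API exercised below,
# with a made-up shape and scale), basic usage looks roughly like:
#
#     x = 2 * torch.rand([23]) - 1
#     x_fp8 = Float8Tensor.to_float8(
#         x, fp8_dtype=tex.DType.kFloat8E4M3, scale=torch.full([1], 3.5)
#     )                                   # uint8 data plus a scale_inv tensor
#     x_roundtrip = x_fp8.from_float8()   # dequantize back to x.dtype
#
# The loose rtol/atol tables defined below reflect the limited precision of the
# FP8 formats.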
+ +from collections.abc import Iterable +from typing import Any, Dict, List, Tuple, Union + +import pytest +import torch + +import transformer_engine.common.recipe +import transformer_engine.pytorch as te +from transformer_engine.pytorch.float8_tensor import Float8Tensor +from transformer_engine.pytorch.fp8 import FP8GlobalStateManager +import transformer_engine_extensions as tex + +# PyTorch tensor dtypes +_dtypes: List[torch.dtype] = [torch.float32, torch.float16, torch.bfloat16] +# TE FP8 dtypes +_fp8_dtypes: List[tex.DType] = [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2] + +# Numerical tolerances with FP8 types +_tols: Dict[tex.DType, Dict[str, float]] = { + tex.DType.kFloat8E4M3: dict(rtol=0.125, atol=0.0675), # epsilon = 0.0625 + tex.DType.kFloat8E5M2: dict(rtol=0.25, atol=0.125), # epsilon = 0.125 +} + +def _to_list(x: Union[Iterable, Any]) -> List: + """Convert to list if iterable, otherwise put in singleton list""" + if isinstance(x, Iterable): + return list(x) + else: + return [x] + +# Types that can be interpreted as tensor dims +DimsType = Union[Iterable[int], int] + +# Check if FP8 is supported +fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() + +@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) +class TestFloat8Tensor: + + @staticmethod + def setup_class(cls) -> None: + # Configure RNG + seed = 1234 + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + + def test_constructor( + self, + dims: DimsType = 1, + fp8_dtype: tex.DType = tex.DType.kFloat8E4M3, + scale_inv: float = 0.375, + dtype: torch.dtype = torch.float32, + ) -> None: + """Call constructor and perform sanity checks""" + dims = _to_list(dims) + tensor = Float8Tensor( + data=torch.zeros(dims, device="cuda", dtype=torch.uint8), + fp8_dtype=fp8_dtype, + fp8_scale_inv=torch.full([1], scale_inv), + dtype=dtype, + ) + assert list(tensor.size()) == dims, "Incorrect dims" + assert tensor.dtype == dtype, "Incorrect nominal dtype" + assert tensor.is_cuda, "Incorrect device" + + def _test_quantize_dequantize( + self, + fp8_dtype: tex.DType = tex.DType.kFloat8E4M3, + scale: float = 3.5, + dtype: torch.dtype = torch.float32, + dims: DimsType = 23, + ) -> None: + """Check numerical error when casting to FP8 and back""" + + # Initialize random data + x_ref = 2 * torch.rand(_to_list(dims), dtype=dtype, device="cpu") - 1 + + # Cast to FP8 and back + x_fp8 = Float8Tensor.to_float8( + x_ref, + fp8_dtype=fp8_dtype, + scale=torch.full([1], scale), + ) + x_fp8 = x_fp8.from_float8().cpu() + + # Check results + torch.testing.assert_close(x_fp8, x_ref, **_tols[fp8_dtype]) + + # Make sure we are not trivially passing the test + with pytest.raises(AssertionError): + torch.testing.assert_close(x_fp8, -x_ref, **_tols[fp8_dtype]) + + @pytest.mark.parametrize("fp8_dtype", _fp8_dtypes) + @pytest.mark.parametrize("dtype", _dtypes) + def test_quantize_dequantize_dtypes( + self, + fp8_dtype: tex.DType, + dtype: torch.dtype, + ) -> None: + self._test_quantize_dequantize(fp8_dtype=fp8_dtype, dtype=dtype) + + @pytest.mark.parametrize("scale", [0.375, 1, 3.5]) + def test_quantize_dequantize_scales(self, scale: float) -> None: + self._test_quantize_dequantize(scale=scale) + + @pytest.mark.parametrize("dims", [[], 1, 311, [7,11], [7,5,3], [2,3,5,3]]) + def test_quantize_dequantize_dims(self, dims: DimsType) -> None: + self._test_quantize_dequantize(dims=dims) + + def test_fp8_meta( + self, + dtype: torch.dtype = torch.float32, + dims: DimsType = 23, + ) -> None: + """Construct Float8Tensor using FP8 metadata 
and perform basic checks""" + + # Get FP8 metadata from linear module + fp8_dtype = tex.DType.kFloat8E4M3 + recipe = transformer_engine.common.recipe.DelayedScaling( + fp8_format=transformer_engine.common.recipe.Format.E4M3, + ) + with te.fp8_autocast(enabled=True, fp8_recipe=recipe): + module = te.Linear(32, 32) + _ = module(torch.zeros([8, 32], device="cuda")) + fp8_meta = module.fp8_meta + fp8_meta_index = tex.FP8FwdTensors.GEMM1_WEIGHT + fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key(forward=True) + + # Initialize random data + dims = _to_list(dims) + x_ref = 2 * torch.rand(dims, dtype=dtype, device="cpu") - 1 + + # Make Float8Tensor + x_fp8 = Float8Tensor.to_float8( + x_ref, + fp8_meta=fp8_meta, + fp8_meta_index=fp8_meta_index, + ) + x_ref = x_fp8.from_float8() + assert list(x_fp8.size()) == dims, "Incorrect dims" + assert x_fp8.dtype == dtype, "Incorrect nominal dtype" + assert x_fp8.is_cuda, "Incorrect device" + assert x_fp8._fp8_dtype == fp8_dtype, "Incorrect FP8 dtype" + + # Change FP8 metadata scale + fp8_meta[fp8_meta_key].scale[fp8_meta_index] = 2 + fp8_meta[fp8_meta_key].scale_inv.fill_(123) + + # Check results + torch.testing.assert_close(x_fp8, x_ref, **_tols[fp8_dtype]) + with pytest.raises(AssertionError): + # Make sure we are not trivially passing the test + torch.testing.assert_close(x_fp8, -x_ref, **_tols[fp8_dtype]) + + # Check if scaling factor is updated after in-place ops + x_fp8 += 0 + fp8_meta[fp8_meta_key].scale[fp8_meta_index] = 4 + fp8_meta[fp8_meta_key].scale_inv.fill_(321) + assert x_fp8._scale_inv.item() == 0.5, "Incorrect FP8 scale_inv" + torch.testing.assert_close(x_fp8, x_ref, **_tols[fp8_dtype]) + y = x_fp8.detach() + y += 0 + assert x_fp8._scale_inv.item() == 0.25, "Incorrect FP8 scale_inv" + torch.testing.assert_close(x_fp8, x_ref, **_tols[fp8_dtype]) + + def test_basic_ops( + self, + dims: DimsType = 23, + fp8_dtype: tex.DType = tex.DType.kFloat8E4M3, + scale: float = 3.5, + dtype: torch.dtype = torch.float32, + ) -> None: + """Test basic out-of-place ops""" + + # Initialize random data + dims = _to_list(dims) + x_ref = 2 * torch.rand(dims, dtype=dtype, device="cpu") - 1 + y_ref = 2 * torch.rand(dims, dtype=dtype, device="cpu") - 1 + x_fp8 = Float8Tensor.to_float8( + x_ref, + fp8_dtype=fp8_dtype, + scale=torch.full([1], scale), + ) + y_fp8 = Float8Tensor.to_float8( + y_ref, + fp8_dtype=fp8_dtype, + scale=torch.full([1], scale), + ) + x_ref = x_fp8.from_float8() + y_ref = y_fp8.from_float8() + + # Exact operations + torch.testing.assert_close(-x_fp8, -x_ref, rtol=0, atol=0) + torch.testing.assert_close(x_fp8.abs(), x_ref.abs(), rtol=0, atol=0) + + # Operations with numerical error + tols = _tols[fp8_dtype] + torch.testing.assert_close(x_fp8 + y_fp8, x_ref + y_ref, **tols) + torch.testing.assert_close(x_fp8 - y_fp8, x_ref - y_ref, **tols) + torch.testing.assert_close(x_fp8 * y_fp8, x_ref * y_ref, **tols) + torch.testing.assert_close(x_fp8 + y_ref, x_ref + y_ref, **tols) + torch.testing.assert_close(x_ref + y_fp8, x_ref + y_ref, **tols) + torch.testing.assert_close(torch.sin(x_fp8), torch.sin(x_ref), **tols) + + # Make sure we are not trivially passing tests + with pytest.raises(AssertionError): + torch.testing.assert_close(x_fp8 + y_fp8, x_ref - y_fp8, **tols) + + def test_inplace_ops( + self, + dims: DimsType = 23, + fp8_dtype: tex.DType = tex.DType.kFloat8E4M3, + scale: float = 3.5, + dtype: torch.dtype = torch.float32, + ) -> None: + """Test in-place ops""" + + # Initialize random data + dims = _to_list(dims) + x_ref = 2 * torch.rand(dims, 
dtype=dtype, device="cpu") - 1 + y_ref = 2 * torch.rand(dims, dtype=dtype, device="cpu") - 1 + x_fp8 = Float8Tensor.to_float8( + x_ref, + fp8_dtype=fp8_dtype, + scale=torch.full([1], scale), + ) + y_fp8 = Float8Tensor.to_float8( + y_ref, + fp8_dtype=fp8_dtype, + scale=torch.full([1], scale), + ) + x_ref = x_fp8.from_float8() + y_ref = y_fp8.from_float8() + + # In-place operations + tols = _tols[fp8_dtype] + x_fp8 += y_ref + x_ref += y_ref + torch.testing.assert_close(x_fp8, x_ref, **tols) + x_ref = x_fp8.from_float8() + x_fp8 -= y_fp8 + x_ref -= y_fp8 + torch.testing.assert_close(x_fp8, x_ref, **tols) + x_ref = x_fp8.from_float8() + x_fp8 *= 2 + x_ref *= 2 + torch.testing.assert_close(x_fp8, x_ref, **tols) + x_ref = x_fp8.from_float8() + + # Make sure we are not trivially passing tests + x_ref += 123 + with pytest.raises(AssertionError): + torch.testing.assert_close(x_fp8, x_ref, **tols) + + @pytest.mark.parametrize("dims", [[33, 41], [5, 7, 11]]) + @pytest.mark.parametrize("transpose_dims", [(0, 1), (-2, -1), (0, 0)]) + def test_transpose( + self, + dims: DimsType, + transpose_dims: Tuple[int, int], + fp8_dtype: tex.DType = tex.DType.kFloat8E4M3, + scale: float = 1, + dtype: torch.dtype = torch.float32, + ) -> None: + """Test transpose""" + + # Initialize random data + dims = _to_list(dims) + x_ref = 2 * torch.rand(dims, dtype=dtype, device="cpu") - 1 + x_fp8 = Float8Tensor.to_float8( + x_ref, + fp8_dtype=fp8_dtype, + scale=torch.full([1], scale), + ) + x_ref = x_fp8.from_float8() + + # Perform transpose + y_fp8 = x_fp8.transpose(*transpose_dims) + y_ref = x_ref.transpose(*transpose_dims) + + # Check results + tols = dict(rtol=0, atol=0) + torch.testing.assert_close(y_fp8, y_ref, **tols) + + # Make sure we are not trivially passing the test + if transpose_dims[0] != transpose_dims[1]: + with pytest.raises(AssertionError): + torch.testing.assert_close( + y_fp8, + x_ref, + **tols, + ) + + # Check transpose caching + if x_fp8.dim() == 2 and transpose_dims[0] != transpose_dims[1]: + x_fp8 += 0.5 + x_ref = x_fp8.from_float8() + torch.testing.assert_close( + x_fp8.transpose(*transpose_dims, update_cache=True), + x_ref.transpose(*transpose_dims), + **tols, + ) + torch.testing.assert_close( + x_fp8.transpose(*transpose_dims, update_cache=True), + x_ref.transpose(*transpose_dims), + **tols, + ) + x_fp8 += 0.5 + x_ref = x_fp8.from_float8() + torch.testing.assert_close( + x_fp8.transpose(*transpose_dims, update_cache=True), + x_ref.transpose(*transpose_dims), + **tols, + ) diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index 02fb63e71f..474f0a95b9 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -12,7 +12,7 @@ import torch.nn as nn from torch.nn import Parameter -from transformer_engine.pytorch.fp8 import fp8_autocast, FP8GlobalStateManager +from transformer_engine.pytorch.fp8 import fp8_autocast, FP8GlobalStateManager, fp8_model_init from transformer_engine.pytorch.utils import ( init_method_normal, scaled_init_method_normal, @@ -339,7 +339,7 @@ def forward( return x -def _test_e2e_selective_recompute(bs, dtype, config, fp8, recompute=False): +def _test_e2e_selective_recompute(bs, dtype, config, fp8, fp8_model_params=False, recompute=False): reset_rng_states() FP8GlobalStateManager.reset() @@ -354,24 +354,26 @@ def get_dummy_cuda_rng_tracker(): """Get cuda rng tracker.""" return _DUMMY_CUDA_RNG_STATE_TRACKER - block = ( - TransformerLayer( - config.hidden_size, - 4 * config.hidden_size, - config.num_attention_heads, - 
layernorm_epsilon=config.eps, - init_method=init_method, - output_layer_init_method=output_layer_init_method, - hidden_dropout=0.1, - attention_dropout=0.1, - kv_channels=config.embed, - apply_residual_connection_post_layernorm=False, - output_layernorm=False, - get_rng_state_tracker=get_dummy_cuda_rng_tracker, - params_dtype=dtype, + with fp8_model_init(enabled=fp8 and fp8_model_params): + block = ( + TransformerLayer( + config.hidden_size, + 4 * config.hidden_size, + config.num_attention_heads, + layernorm_epsilon=config.eps, + init_method=init_method, + output_layer_init_method=output_layer_init_method, + hidden_dropout=0.1, + attention_dropout=0.1, + kv_channels=config.embed, + apply_residual_connection_post_layernorm=False, + output_layernorm=False, + get_rng_state_tracker=get_dummy_cuda_rng_tracker, + params_dtype=dtype, + fuse_qkv_params=True, + ) + .cuda() ) - .cuda() - ) te_inp_hidden_states = torch.randn( config.seq_len, bs, config.hidden_size, dtype=dtype, requires_grad=True @@ -400,18 +402,19 @@ def get_dummy_cuda_rng_tracker(): @pytest.mark.parametrize("bs", batch_sizes) @pytest.mark.parametrize("model", model_configs.keys()) @pytest.mark.parametrize("fp8", all_boolean) -def test_gpt_selective_activation_recompute(dtype, bs, model, fp8): +@pytest.mark.parametrize("fp8_model_params", all_boolean) +def test_gpt_selective_activation_recompute(dtype, bs, model, fp8, fp8_model_params): if fp8 and not fp8_available: pytest.skip(reason_for_no_fp8) config = model_configs[model] - outputs = _test_e2e_selective_recompute(bs, dtype, config, fp8, recompute=False) - outputs_recompute = _test_e2e_selective_recompute(bs, dtype, config, fp8, recompute=True) + outputs = _test_e2e_selective_recompute(bs, dtype, config, fp8, fp8_model_params, recompute=False) + outputs_recompute = _test_e2e_selective_recompute(bs, dtype, config, fp8, fp8_model_params, recompute=True) assert_all_equal(outputs, outputs_recompute) -def _test_e2e_full_recompute(bs, dtype, config, fp8, recompute=False): +def _test_e2e_full_recompute(bs, dtype, config, fp8, fp8_model_params=False, recompute=False): reset_rng_states() FP8GlobalStateManager.reset() @@ -426,7 +429,8 @@ def get_dummy_cuda_rng_tracker(): """Get cuda rng tracker.""" return _DUMMY_CUDA_RNG_STATE_TRACKER - block = ( + with fp8_model_init(enabled=fp8 and fp8_model_params): + block = ( TransformerLayer( config.hidden_size, 4 * config.hidden_size, @@ -441,9 +445,10 @@ def get_dummy_cuda_rng_tracker(): output_layernorm=False, get_rng_state_tracker=get_dummy_cuda_rng_tracker, params_dtype=dtype, + fuse_qkv_params=True, ) .cuda() - ) + ) te_inp_hidden_states = torch.randn( config.seq_len, bs, config.hidden_size, dtype=dtype, requires_grad=True @@ -483,14 +488,15 @@ def get_dummy_cuda_rng_tracker(): @pytest.mark.parametrize("bs", batch_sizes) @pytest.mark.parametrize("model", model_configs.keys()) @pytest.mark.parametrize("fp8", all_boolean) -def test_gpt_full_activation_recompute(dtype, bs, model, fp8): +@pytest.mark.parametrize("fp8_model_params", all_boolean) +def test_gpt_full_activation_recompute(dtype, bs, model, fp8, fp8_model_params): if fp8 and not fp8_available: pytest.skip(reason_for_no_fp8) config = model_configs[model] - outputs = _test_e2e_full_recompute(bs, dtype, config, fp8, recompute=False) - outputs_recompute = _test_e2e_full_recompute(bs, dtype, config, fp8, recompute=True) + outputs = _test_e2e_full_recompute(bs, dtype, config, fp8, fp8_model_params, recompute=False) + outputs_recompute = _test_e2e_full_recompute(bs, dtype, config, fp8, 
fp8_model_params, recompute=True) assert_all_equal(outputs, outputs_recompute) @@ -871,6 +877,7 @@ def test_linear_accuracy(dtype, bs, model): else: assert_allclose(te_outputs[0], torch_outputs[0], 5e-2) + @pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("bs", batch_sizes) @pytest.mark.parametrize("model", model_configs.keys()) @@ -911,6 +918,7 @@ def test_rmsnorm_accuracy(dtype, bs, model, eps): else: assert_allclose(te_outputs[0], torch_outputs[0], 2e-2) + @pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("bs", batch_sizes) @pytest.mark.parametrize("model", model_configs.keys()) @@ -1110,3 +1118,72 @@ def test_gpt_cuda_graph(dtype, bs, model): assert_allclose(out, graphed_out, 1e-3) assert_allclose(params, graphed_params, 1e-3) assert_allclose(grads, graphed_grads, 1e-3) + + +def _test_gpt_fp8_parameters(bs, dtype, config, fp8_model_params): + reset_rng_states() + FP8GlobalStateManager.reset() + + sigma = 0.023 + init_method = init_method_normal(sigma) + output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers) + + _DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker() + _DUMMY_CUDA_RNG_STATE_TRACKER.add("model-parallel-rng", seed) + + def get_dummy_cuda_rng_tracker(): + """Get cuda rng tracker.""" + return _DUMMY_CUDA_RNG_STATE_TRACKER + + with fp8_model_init(enabled=fp8_model_params): + block = ( + TransformerLayer( + config.hidden_size, + 4 * config.hidden_size, + config.num_attention_heads, + layernorm_epsilon=config.eps, + init_method=init_method, + output_layer_init_method=output_layer_init_method, + hidden_dropout=0.1, + attention_dropout=0.1, + kv_channels=config.embed, + apply_residual_connection_post_layernorm=False, + output_layernorm=False, + get_rng_state_tracker=get_dummy_cuda_rng_tracker, + params_dtype=dtype, + fuse_qkv_params=True, + ) + .cuda() + ) + + te_inp_hidden_states = torch.randn( + config.seq_len, bs, config.hidden_size, dtype=dtype, requires_grad=True + ).cuda() + te_inp_hidden_states.retain_grad() + te_inp_attn_mask = get_causal_attn_mask(config.seq_len) + + with fp8_autocast(enabled=True): + te_out = block(te_inp_hidden_states, attention_mask=te_inp_attn_mask) + loss = te_out.sum() + loss.backward() + torch.cuda.synchronize() + + outputs = [te_out, te_inp_hidden_states.grad] + for p in block.parameters(): + if p.requires_grad: + outputs.append(p.grad) + return outputs + + +@pytest.mark.parametrize("dtype", param_types) +@pytest.mark.parametrize("bs", batch_sizes) +@pytest.mark.parametrize("model", model_configs.keys()) +def test_gpt_fp8_parameters(dtype, bs, model): + if not fp8_available: + pytest.skip(reason_for_no_fp8) + + config = model_configs[model] + + outputs = _test_gpt_fp8_parameters(bs, dtype, config, False) + outputs_fp8_params = _test_gpt_fp8_parameters(bs, dtype, config, True) + assert_all_equal(outputs, outputs_fp8_params) diff --git a/tests/pytorch/test_onnx_export.py b/tests/pytorch/test_onnx_export.py index 4774cd39ab..dd50f15e43 100644 --- a/tests/pytorch/test_onnx_export.py +++ b/tests/pytorch/test_onnx_export.py @@ -147,7 +147,7 @@ def set_layer_scale(module: torch.nn.Module, scale: float, num_gemms: int): """Initialize the FP8 quantization scales in module""" NB_SCALES_PER_GEMM = 3 # One scale per: input, weights, and output GEMM tensors. 
nb_total_scales = num_gemms * NB_SCALES_PER_GEMM - module.fp8_init(num_gemms) + module.init_fp8_metadata(num_gemms) module.fp8_meta["scaling_fwd"].scale = torch.ones( nb_total_scales, dtype=torch.float32, device="cuda") / scale module.fp8_meta["scaling_fwd"].scale_inv = torch.ones( diff --git a/tests/pytorch/test_torch_save_load.py b/tests/pytorch/test_torch_save_load.py index f35b60ede2..2732db6ad9 100644 --- a/tests/pytorch/test_torch_save_load.py +++ b/tests/pytorch/test_torch_save_load.py @@ -16,7 +16,7 @@ import torch import transformer_engine.pytorch as te import transformer_engine_extensions as tex -from transformer_engine.pytorch.cpp_extensions import fp8_gemm, cast_to_fp8, cast_from_fp8 +from transformer_engine.pytorch.cpp_extensions import fp8_gemm, cast_to_fp8 from transformer_engine.pytorch.module.base import get_workspace from transformer_engine.pytorch.module.base import TransformerEngineBaseModule @@ -93,7 +93,7 @@ def forward(self, inp, weight): model_in = Test_TE_Export(precision, True) with te.fp8_autocast(enabled=True): - model_in.fp8_init() + model_in.init_fp8_metadata() # scaling fwd model_in.fp8_meta["scaling_fwd"].scale = torch.ones(3, dtype=torch.float32, device="cuda") * scale_fwd model_in.fp8_meta["scaling_fwd"].scale_inv = torch.ones(3, dtype=torch.float32, device="cuda") / scale_fwd diff --git a/transformer_engine/pytorch/__init__.py b/transformer_engine/pytorch/__init__.py index 8ff601f6f1..b29853a3a7 100644 --- a/transformer_engine/pytorch/__init__.py +++ b/transformer_engine/pytorch/__init__.py @@ -13,6 +13,7 @@ from .attention import MultiheadAttention from .transformer import TransformerLayer from .fp8 import fp8_autocast +from .fp8 import fp8_model_init from .export import onnx_export from .distributed import checkpoint from .distributed import CudaRNGStatesTracker diff --git a/transformer_engine/pytorch/distributed.py b/transformer_engine/pytorch/distributed.py index abc3936e25..1d93d03f3f 100644 --- a/transformer_engine/pytorch/distributed.py +++ b/transformer_engine/pytorch/distributed.py @@ -83,14 +83,16 @@ def initialize_affine_weight_gpu( weight: torch.Tensor, init_method: Callable, get_rng_state_tracker: Callable, - partition_dim: int, + partition_dim: int = 0, stride: int = 1, + set_tp_attributes: bool = True, ) -> None: """Initialize affine weight for model parallel on GPU.""" - set_tensor_model_parallel_attributes( - tensor=weight, is_parallel=True, dim=partition_dim, stride=stride - ) + if set_tp_attributes: + set_tensor_model_parallel_attributes( + tensor=weight, is_parallel=True, dim=partition_dim, stride=stride + ) if get_rng_state_tracker is None: init_method(weight) diff --git a/transformer_engine/pytorch/float8_tensor.py b/transformer_engine/pytorch/float8_tensor.py new file mode 100644 index 0000000000..1868bb4ed2 --- /dev/null +++ b/transformer_engine/pytorch/float8_tensor.py @@ -0,0 +1,689 @@ +# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. 
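
The new float8_tensor.py module added here backs the Float8Tensor tests earlier in this series. A minimal round-trip sketch of its public helpers, assuming a CUDA device and the default E4M3 format; the names, shapes, and literal values below are illustrative and not part of the module:

    import torch
    from transformer_engine.pytorch.float8_tensor import Float8Tensor

    x = torch.randn(33, 41, device="cuda", dtype=torch.float32)
    x_fp8 = Float8Tensor.to_float8(x)   # cast to FP8; scale defaults to 1, format to E4M3
    x_hi = x_fp8.from_float8()          # cast back to the nominal dtype (float32 here)
    x_fp8 += 0.5                        # in-place ops dequantize, apply the op, then re-cast
    x_fp8.data = torch.randn(33, 41, device="cuda")  # assigning to .data also re-casts into the FP8 buffer

In-place updates go through __torch_dispatch__: values are cast to the nominal dtype, the op is applied, and the result is copied back into the existing FP8 buffer.
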
+ +"""Tensor class with FP8 data""" +from __future__ import annotations +from typing import Any, Dict, Optional + +import torch +from torch.utils._pytree import tree_map +import transformer_engine_extensions as tex + +from .constants import TE_DType +from .fp8 import FP8GlobalStateManager + + +aten = torch.ops.aten +c10d = torch.ops.c10d + + +def _make_fp8_attr_property_funcs(name: str) -> Any: + """Make accessors for an FP8 attribute + + We store FP8 attributes in a dictionary so we can share them + between tensors with the same data, e.g. detached tensors. For + convenience, we also expose them as property attributes. This + function creates the accessors for property attributes. + + Parameters + ---------- + name: str + Key in dictionary of FP8 attributes + + """ + def get_func(self) -> Any: + return self._fp8_attrs[name] + def set_func(self, value: Any) -> None: + self._fp8_attrs[name] = value + def del_func(self) -> None: + del self._fp8_attrs[name] + return dict(fget=get_func, fset=set_func, fdel=del_func) + + +class _FromFloat8Func(torch.autograd.Function): + """Cast from FP8 to other dtype""" + @staticmethod + def forward( + ctx, + tensor: Float8Tensor, + dtype: Optional[torch.dtype] = None, + ) -> torch.Tensor: + if dtype is None: + dtype = tensor.dtype + data = tensor._data.contiguous().view(1,-1).detach() + out = tex.cast_from_fp8( + data, + tensor._scale_inv, + tensor._fp8_dtype, + TE_DType[dtype], + ) + out = out.view(tensor.size()) + return out + + @staticmethod + def backward(ctx, grad): + # Assume that we want gradients in full precision + return grad, None + + +class _ToFloat8Func(torch.autograd.Function): + """Cast to FP8 from other dtype""" + @staticmethod + def forward( + ctx, + tensor: torch.Tensor, + fp8_meta: Optional[Dict[str, Any]] = None, + fp8_meta_forward: bool = True, + fp8_meta_index: Optional[int] = None, + fp8_dtype: tex.DType = tex.DType.kFloat8E4M3, + scale: Optional[torch.Tensor] = None, + amax: Optional[torch.Tensor] = None, + scale_inv: Optional[torch.Tensor] = None, + ): + + # Manually compute scale-inverse if needed + if scale is not None and scale_inv is None: + if isinstance(scale, torch.Tensor): + scale_inv = scale.reciprocal() + else: + scale_inv = 1 / scale + + # Extract data from FP8 meta tensors if provided + if fp8_meta is not None: + fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key( + forward=fp8_meta_forward, + ) + if fp8_meta_index is None: + raise ValueError( + "To initialize Float8Tensor with FP8 meta tensors, " + "the FP8 meta tensor index must also be provided" + ) + if scale is None: + scale = fp8_meta[fp8_meta_key].scale[fp8_meta_index] + if amax is None: + amax = fp8_meta[fp8_meta_key].amax_history[0][fp8_meta_index] + if scale_inv is None: + scale_inv = fp8_meta[fp8_meta_key].scale_inv[fp8_meta_index] + scale_inv = scale_inv.detach().view(1).clone() + + # Check input tensor + tensor = tensor.contiguous().cuda().detach() + if tensor.dtype not in (torch.float32, torch.bfloat16, torch.float16): + tensor = tensor.float() + + # Check scale + if not isinstance(scale, torch.Tensor): + if scale is None: + scale = 1 + scale = torch.full( + [1], + scale, + dtype=torch.float32, + device=tensor.device, + ) + if scale.numel() != 1: + raise ValueError( + "Attempted to initialize Float8Tensor with invalid scale tensor" + ) + scale = scale.to(device=tensor.device, dtype=torch.float32) + + # Check scale-inverse + if scale_inv is None: + scale_inv = scale.reciprocal() + scale_inv = scale_inv.to(device=tensor.device, dtype=torch.float32) + + # 
Check amax + if amax is None: + amax = torch.empty_like(scale) + if not (amax.numel() == 1 and amax.is_cuda and amax.dtype == torch.float32): + raise ValueError( + "Attempted to initialize Float8Tensor with invalid amax tensor" + ) + + # Cast data to FP8 + data = tex.cast_to_fp8( + tensor.view(1,-1), + scale, + amax, + scale_inv, + fp8_dtype, + ) + data = data.view(tensor.size()) + + # Construct FP8 tensor + return Float8Tensor( + data=data, + fp8_meta=fp8_meta, + fp8_meta_forward=fp8_meta_forward, + fp8_meta_index=fp8_meta_index, + fp8_dtype=fp8_dtype, + fp8_scale_inv=scale_inv, + dtype=tensor.dtype, + ) + + @staticmethod + def backward(ctx, grad): + # Assume that we want gradients in full precision + return grad, None, None, None, None, None, None, None + +class _IdentityFunc(torch.autograd.Function): + """Identity function + + If constructor keyword-arguments are provided, then construct a + new Float8Tensor using the provided tensor's attributes. + + """ + + @staticmethod + def forward( + ctx, + tensor: Float8Tensor, + init_kwargs: Optional[Dict[str, Any]] = None, + ) -> torch.Tensor: + + # Return input tensor if constructor kwargs are not provided + ctx.input_dtype = tensor.dtype + if init_kwargs is None: + return tensor + + # Construct new tensor if constructor kwargs are provided + default_kwargs = dict( + data=tensor._data, + fp8_meta=tensor._fp8_meta, + fp8_meta_forward=tensor._fp8_meta_forward, + fp8_meta_index=tensor._fp8_meta_index, + fp8_dtype=tensor._fp8_dtype, + fp8_scale_inv=tensor._scale_inv, + dtype=tensor.dtype, + ) + for key, val in default_kwargs.items(): + if key not in init_kwargs: + init_kwargs[key] = val + return Float8Tensor(**init_kwargs) + + @staticmethod + def backward(ctx, grad): + return grad.to(ctx.input_dtype), None + + +class Float8Tensor(torch.Tensor): + """Experimental tensor class with FP8 data + + The tensor presents as having a standard, higher-precision dtype, + but the data itself is (scaled) FP8. For most tensor operations, + the data will be cast to the nominal dtype before performing the + operation. + + Parameters + ---------- + data: torch.Tensor + Raw FP8 data in a uint8 tensor + fp8_attrs: dict, optional + FP8 metadata, primarily managed by Float8Tensor. If + provided, all other FP8 configuration is ignored. + fp8_meta: dict, optional + FP8 metadata object, primarily managed by TE modules. + fp8_meta_forward: bool, default = `True` + Whether to access the FP8 metadata for the + forward pass. Ignored if fp8_meta is not + provided. + fp8_meta_index: int, optional + Index to access in FP8 meta tensors. Required if + fp8_meta is provided and otherwise ignored. + fp8_dtype: transformer_engine_extensions.DType, tex.DType.kFloat8E4M3 + FP8 format. + fp8_scale_inv: torch.Tensor + Reciprocal of the scaling factor applied when + casting to FP8, i.e. the scaling factor that must + be applied when casting from FP8 to higher + precision. Can be inferred from fp8_meta if + provided. + dtype: torch.dtype, default = torch.float32 + Nominal tensor datatype. 
+ + """ + + def __new__( + cls, + *, + data: torch.Tensor, + fp8_attrs: Optional[Dict[str, Any]] = None, + fp8_meta: Optional[Dict[str, Any]] = None, + fp8_meta_forward: bool = True, + fp8_meta_index: Optional[int] = None, + fp8_dtype: tex.DType = tex.DType.kFloat8E4M3, + fp8_scale_inv: Optional[torch.Tensor] = None, + dtype: torch.dtype = torch.float32, + ): + + # Check that data buffer is valid + if data.element_size() != 1: + raise ValueError( + "Float8Tensor requires data buffer with 8-bit dtype " + f"(got dtype={data.dtype})" + ) + if data.requires_grad: + raise ValueError( + "Float8Tensor requires non-differentiable data buffer" + ) + data = data.cuda() + + # Initialize tensor object + self = torch.Tensor._make_wrapper_subclass( + cls, + data.size(), + strides=data.stride(), + storage_offset=data.storage_offset(), + dtype=dtype, + layout=data.layout, + requires_grad=data.requires_grad, + device=data.device, + ) + self._data: torch.Tensor = data + + # Initialize dict of class attributes + # Note: We store FP8 attributes in a dictionary so we can + # share them between tensors with the same data, e.g. detached + # tensors. + self._fp8_attrs: dict = {} + if fp8_attrs is not None: + self._fp8_attrs = fp8_attrs + return self + + # FP8 meta tensors + if fp8_meta is not None and fp8_meta_index is None: + raise ValueError( + "To initialize Float8Tensor with FP8 meta tensors, " + "the FP8 meta tensor index must also be provided" + ) + self._fp8_meta: Optional[Dict[str, Any]] = fp8_meta + self._fp8_meta_forward: bool = fp8_meta_forward + self._fp8_meta_index: Optional[int] = fp8_meta_index + + # FP8 dtype + assert ( + fp8_dtype in (tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2) + ), f"Unsupported fp8_dtype {fp8_dtype}." + self._fp8_dtype: tex.DType = fp8_dtype + + # Cached transpose + self._transpose: Optional[Float8Tensor] = None + + # FP8 scale-inverse + self._scale_inv: Optional[torch.Tensor] = fp8_scale_inv + if self._scale_inv is None and self._fp8_meta is not None: + fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key( + forward=self._fp8_meta_forward, + ) + scale_inv = self._fp8_meta[fp8_meta_key].scale_inv[self._fp8_meta_index] + self._scale_inv = scale_inv.detach().view(1).clone() + if self._scale_inv is None: + raise ValueError( + "Attempted to initialize Float8Tensor without specifying scale-inverse" + ) + if not isinstance(self._scale_inv, torch.Tensor): + self._scale_inv = torch.full( + [1], + self._scale_inv, + dtype=torch.float32, + device=self._data.device, + ) + if self._scale_inv.numel() != 1: + raise ValueError( + "Attempted to initialize Float8Tensor with invalid scale-inverse tensor" + ) + self._scale_inv = self._scale_inv.to( + device=self._data.device, + dtype=torch.float32, + ) + + return self + + @classmethod + def make_like( + cls, + tensor: Float8Tensor, + *, + data: torch.Tensor, + fp8_attrs: Optional[Dict[str, Any]] = None, + **kwargs, + ) -> Float8Tensor: + """Use attributes of a Float8Tensor to create another Float8Tensor + + See constructor for list of keyword arguments. 
+ + """ + default_kwargs = dict( + fp8_meta=tensor._fp8_meta, + fp8_meta_forward=tensor._fp8_meta_forward, + fp8_meta_index=tensor._fp8_meta_index, + fp8_dtype=tensor._fp8_dtype, + fp8_scale_inv=tensor._scale_inv, + dtype=tensor.dtype, + ) + for key, val in default_kwargs.items(): + if key not in kwargs: + kwargs[key] = val + return Float8Tensor(data=data, fp8_attrs=fp8_attrs, **kwargs) + + def __repr__(self): + return ( + "Float8Tensor(" + f"fp8_dtype={self._fp8_dtype}, " + f"scale_inv={self._scale_inv.item()}, " + f"data={self.from_float8(dtype=self.dtype)}" + ")" + ) + + def from_float8(self, dtype: Optional[torch.dtype] = None) -> torch.Tensor: + """ + Construct plain PyTorch tensor from Float8Tensor + + By default the resulting tensor's dtype is the + Float8Tensor's nominal dtype. + """ + return _FromFloat8Func.apply(self, dtype) + + @classmethod + def to_float8( + cls, + tensor: torch.Tensor, + *, + fp8_meta: Optional[Dict[str, Any]] = None, + fp8_meta_forward: bool = True, + fp8_meta_index: Optional[int] = None, + fp8_dtype: tex.DType = tex.DType.kFloat8E4M3, + scale: Optional[torch.Tensor] = None, + amax: Optional[torch.Tensor] = None, + scale_inv: Optional[torch.Tensor] = None, + ): + """Construct Float8Tensor from plain PyTorch tensor""" + return _ToFloat8Func.apply( + tensor, + fp8_meta, + fp8_meta_forward, + fp8_meta_index, + fp8_dtype, + scale, + amax, + scale_inv, + ) + + def float(self) -> torch.Tensor: + return self.from_float8(dtype=torch.float32) + + def bfloat16(self) -> torch.Tensor: + return self.from_float8(dtype=torch.bfloat16) + + def half(self) -> torch.Tensor: + return self.from_float8(dtype=torch.float16) + + def cpu(self) -> torch.Tensor: + return self.from_float8().cpu() + + def clone(self) -> Float8Tensor: + return _IdentityFunc.apply(self, {"data": self._data.detach().clone()}) + + def expand_as(self, other: torch.Tensor): + if other is self: + # Note: expand_as is hackily used to create dummy autograd nodes + # and access the backward graph (see + # https://github.com/pytorch/pytorch/blob/238fb660851268f44ff88127887041fea352fe48/torch/nn/parallel/distributed.py#L1026). + # We equally hackily add a dummy function to handle this + # case. + return _IdentityFunc.apply(self) + return super().expand_as(other) + + def _transpose_no_cache(self) -> torch.Tensor: + """ + Swap tensor dimensions + + For basic 2D matrix transposes, an optimized transpose kernel + is applied and a Float8Tensor is returned. + """ + + # Use optimized kernel for basic 2D transpose + # TODO Support differentiation # pylint: disable=fixme + return Float8Tensor.make_like( + self, + data=tex.fp8_transpose( + self._data.contiguous().detach(), + self._fp8_dtype, + ), + ) + + def transpose( + self, + dim0: int = 0, + dim1: int = 1, + *, + update_cache: Optional[bool] = None, + ) -> torch.Tensor: + """ + Swap tensor dimensions + + For basic 2D matrix transposes, an optimized transpose kernel + is applied and a Float8Tensor is returned. + + Parameters + ---------- + dim0: int, default = 0 + The first dimension to be transposed + dim1: int, default = 1 + The second dimension to be transposed + update_cache: Optional[bool], default = None + If set to `True`, the result is computed and stored in a cache. + If set to `False`, the result is computed only if the cache is + empty, otherwise the cache is returned. If set to `None`, the + result is not cached. Caching is only supported for basic 2D + transposes and the cache is reset after any in-place operations. 
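
A short sketch of these caching semantics, assuming x_fp8 is any 2D Float8Tensor such as the ones built in the tests earlier in this series:

    t1 = x_fp8.transpose(update_cache=True)    # compute the transpose and store it in the cache
    t2 = x_fp8.transpose(update_cache=False)   # cache is populated, so the cached result is returned
    x_fp8 += 0.5                               # any in-place update resets the cache
    t3 = x_fp8.transpose(update_cache=False)   # cache is empty again, so the transpose is recomputed

With update_cache=None (the default) the result is computed fresh and never cached.
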
+ """ + + # Handle non-2D transposes + if -self.dim() <= dim0 < 0: + dim0 += self.dim() + if -self.dim() <= dim1 < 0: + dim1 += self.dim() + if self.dim() != 2 or dim0 == dim1: + if update_cache is not None: + raise ValueError( + "Transpose caching is only supported for basic 2D transposes " + f"(ndims={self.dim()}, dim0={dim0}, dim1={dim1})" + ) + return super().transpose(dim0, dim1) + + # No caching. + if update_cache is None: + return self._transpose_no_cache() + + # Update cache. + if update_cache or self._transpose is None: + self._transpose = self._transpose_no_cache() + + return self._transpose + + @torch.no_grad() + def reset_fp8_meta_scale_inv(self) -> None: + """Replace FP8 meta tensor scale-inverse with cached value + + The FP8 meta tensor scale_inv entry corresponding to this + tensor is replaced with the scale_inv value used to construct + the tensor. + + """ + if self._fp8_meta is None: + return + fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key( + forward=self._fp8_meta_forward, + ) + scale_inv = self._fp8_meta[fp8_meta_key].scale_inv[self._fp8_meta_index] + scale_inv.view(1).copy_(self._scale_inv.view(1)) + + def to_dtype(self, dtype: torch.dtype) -> Float8Tensor: + """Create `Float8Tensor` with given nominal dtype + + The new tensor has the same underlying FP8 data. + + """ + return Float8Tensor.make_like( + self, + data=self._data, + fp8_attrs=self._fp8_attrs, + dtype=dtype, + ) + + def _reset_caches(self) -> None: + """Reset cached values + + Should be called after any in-place operation. + + """ + self._transpose = None + + @classmethod + def __torch_dispatch__(cls, func, types, args, kwargs=None): + + # In-place copy op + if func == aten.copy_.default: + + # Check tensors + dst = args[0] + src = args[1] + if not isinstance(dst, Float8Tensor): + raise RuntimeError("Expected to copy into Float8Tensor") + if not isinstance(src, torch.Tensor): + raise RuntimeError("Expected to copy from tensor") + if not dst._data.is_contiguous(): + raise RuntimeError("Transformer Engine cast kernels require contiguous data") + + # Make sure input is in expected format + if isinstance(src, Float8Tensor): + src = src.from_float8() + src = src.expand(dst.size()) + src = src.to( + device=dst.device, + memory_format=torch.contiguous_format, + ) + + # Update scaling factor if FP8 meta tensors are available + if dst._fp8_meta is None: + scale = dst._scale_inv.reciprocal() + else: + fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key( + forward=dst._fp8_meta_forward, + ) + scale = dst._fp8_meta[fp8_meta_key].scale[dst._fp8_meta_index] + dst._scale_inv = scale.detach().view(1).reciprocal() + + # Cast to FP8 + tex.cast_to_fp8_noalloc( + src.view(1,-1), + scale, + dst._data.view(1,-1), + torch.empty_like(dst._scale_inv), # amax + dst._scale_inv, + dst._fp8_dtype, + ) + + # Nothing to return for in-place ops + dst._reset_caches() + return None + + # Slice op + # TODO Consider additional bookkeeping so we invalidate caches # pylint: disable=fixme + # if these slices are modified in-place + if func == aten.slice.Tensor: + tensor = args[0] + data = tensor._data + data_slice = data.__torch_dispatch__( + func, + types, + [data] + list(args[1:]), + kwargs, + ) + return Float8Tensor.make_like(tensor, data=data_slice) + + # Detach op + if func == aten.detach.default: + # Simply return a new Float8Tensor with the same attrs + return Float8Tensor.make_like( + args[0], + data=args[0]._data, + fp8_attrs=args[0]._fp8_attrs, + ) + + def maybe_unwrap(t): + if isinstance(t, Float8Tensor): + return 
t.from_float8() + return t + + def maybe_update_inplace(arg, new_arg, schema_arg): + """Update values of FP8 tensors + + Keep the same FP8 scaling factors. + + """ + if( + isinstance(arg, Float8Tensor) and + isinstance(new_arg, torch.Tensor) and + hasattr(schema_arg, 'alias_info') and + hasattr(schema_arg.alias_info, 'is_write') and + schema_arg.alias_info.is_write + ): + arg.copy_(new_arg) + arg._reset_caches() + + # In-place op + if func._schema.is_mutable: + # Cast to higher precision, perform op, and cast values + # back to original FP8 buffers + new_args = tree_map(maybe_unwrap, args) + new_kwargs = tree_map(maybe_unwrap, kwargs) + schema_args = func._schema.arguments + args_len = len(args) + out = super().__torch_dispatch__(func, types, new_args, new_kwargs) + for arg, new_arg, schema_arg in zip(args, new_args, schema_args): + maybe_update_inplace(arg, new_arg, schema_arg) + for kwarg, new_kwarg, schema_arg in zip(kwargs, new_kwargs, schema_args[args_len:]): + assert kwarg == new_kwarg == schema_arg.name, "name of the kw argument should match" + maybe_update_inplace(kwargs[kwarg], new_kwargs[new_kwarg], schema_arg) + return None + + # Default op + # Note: cast to higher precision and perform op + args = tree_map(maybe_unwrap, args) + if kwargs is not None: + kwargs = tree_map(maybe_unwrap, kwargs) + out = super().__torch_dispatch__(func, types, args, kwargs) + return out + + def _get_data(self) -> Float8Tensor: + """Get tensor data property""" + return super().data + + def _set_data(self, tensor: torch.Tensor) -> None: + """Set tensor data property + + Cast tensor to FP8 and store in FP8 buffer. + + """ + with torch.no_grad(): + self.copy_(tensor) + + # Cast to FP8 when setting Float8Tensor.data + data = property(_get_data, _set_data) + + # Accessors for objects in self._fp8_attrs + # Note: We store FP8 attributes in a dictionary so we can share + # them between tensors with the same data, e.g. detached tensors. + # For convenience, we also expose them as property attributes. 
+ _fp8_meta = property(**_make_fp8_attr_property_funcs("fp8_meta")) + _fp8_meta_forward = property(**_make_fp8_attr_property_funcs("fp8_meta_forward")) + _fp8_meta_index = property(**_make_fp8_attr_property_funcs("fp8_meta_index")) + _fp8_dtype = property(**_make_fp8_attr_property_funcs("dtype")) + _transpose = property(**_make_fp8_attr_property_funcs("transpose")) + _scale_inv = property(**_make_fp8_attr_property_funcs("scale_inv")) + + # Do not force the Float8Tensor type on the returned tensor + __torch_function__ = torch._C._disabled_torch_function_impl diff --git a/transformer_engine/pytorch/fp8.py b/transformer_engine/pytorch/fp8.py index c89ff10968..c7d4524113 100644 --- a/transformer_engine/pytorch/fp8.py +++ b/transformer_engine/pytorch/fp8.py @@ -17,7 +17,7 @@ from .jit import jit_fuser -__all__ = ["fp8_autocast"] +__all__ = ["fp8_autocast", "fp8_model_init"] def check_fp8_support() -> Tuple[bool, str]: @@ -59,6 +59,7 @@ class FP8GlobalStateManager: FP8_CALIBRATION = False FP8_RECIPE = None FP8_DISTRIBUTED_GROUP = None + FP8_PARAMETERS = False IS_FIRST_FP8_MODULE = False FP8_AUTOCAST_COUNTER = 0 FP8_CURRENT_CONTEXT_ID = 0 @@ -277,6 +278,11 @@ def is_fp8_calibration(cls) -> bool: """Is FP8 calibration""" return cls.FP8_CALIBRATION + @classmethod + def with_fp8_parameters(cls) -> bool: + """Should the parameters be stored as FP8""" + return cls.FP8_PARAMETERS + @classmethod def is_first_fp8_module(cls): """Returns `True` only the first time when called multiple @@ -400,6 +406,11 @@ def fp8_autocast_enter( fp8_group: Optional[dist_group_type] = None, ) -> None: """Set state and tracking variables for entry into FP8 region.""" + if cls.FP8_AUTOCAST_DEPTH == 0: + if callable(cls.amax_forward_global_reduce_func): + cls.amax_reduce_handle_fwd = cls.amax_forward_global_reduce_func() # pylint: disable=not-callable + cls.delete_key_from_amax_buffer(forward=True) + cls.FP8_ENABLED = enabled cls.FP8_CALIBRATION = calibrating cls.FP8_RECIPE = get_default_fp8_recipe() if fp8_recipe is None else fp8_recipe @@ -419,11 +430,6 @@ def fp8_autocast_exit(cls): """Set state and tracking variables for exit from FP8 region.""" cls.FP8_AUTOCAST_DEPTH -= 1 - if cls.FP8_AUTOCAST_DEPTH == 0: - if callable(cls.amax_forward_global_reduce_func): - cls.amax_reduce_handle_fwd = cls.amax_forward_global_reduce_func() # pylint: disable=not-callable - cls.delete_key_from_amax_buffer(forward=True) - @classmethod def copy_forward_fp8_meta_tensors_for_recompute(cls, fp8_meta: Dict[str, Any]) -> None: """Copy the scaling factors and amaxes for recompute forward phase @@ -477,9 +483,45 @@ def restore_fp8_meta_tensors(fp8_meta: Dict[str, Any]) -> None: fp8_meta["scaling_fwd"].scale_inv = fp8_meta["updated_scale_inv_fwd"] +@contextmanager +def fp8_model_init(enabled: bool = True) -> None: + """ + Context manager for FP8 initialization of parameters. + + Example usage: + + .. code-block:: python + + with fp8_model_init(enabled=True): + model = transformer_engine.pytorch.Linear(768, 768) + + Parameters + ---------- + enabled: bool, default = `True` + when enabled, Transformer Engine modules created inside this `fp8_model_init` + region will hold only FP8 copies of its parameters, as opposed to the default + behavior where both higher precision and FP8 copies are present. 
Setting this + option to `True` may result in lower memory consumption and is especially + useful for scenarios like: + + * full model training using optimizer with master weights, where the high + precision copies of weights are already present in the optimizer. + * inference, where only the FP8 copies of the parameters are used. + * LoRA-like fine-tuning, where the main parameters of the model do not change. + + This functionality is *EXPERIMENTAL*. + """ + try: + _fp8_parameters = FP8GlobalStateManager.FP8_PARAMETERS + FP8GlobalStateManager.FP8_PARAMETERS = enabled + yield + finally: + FP8GlobalStateManager.FP8_PARAMETERS = _fp8_parameters # pylint: disable=used-before-assignment + + @contextmanager def fp8_autocast( - enabled: bool = False, + enabled: bool = True, calibrating: bool = False, fp8_recipe: Optional[DelayedScaling] = None, fp8_group: Optional[dist_group_type] = None, @@ -508,7 +550,7 @@ def fp8_autocast( Parameters ---------- - enabled: bool, default = `False` + enabled: bool, default = `True` whether or not to enable fp8 calibrating: bool, default = `False` calibration mode allows collecting statistics such as amax and scale @@ -523,7 +565,10 @@ def fp8_autocast( """ try: fp8_state = FP8GlobalStateManager.get_fp8_autocast_state() - FP8GlobalStateManager.fp8_autocast_enter(enabled, calibrating, fp8_recipe, fp8_group) + FP8GlobalStateManager.fp8_autocast_enter(enabled=enabled, + calibrating=calibrating, + fp8_recipe=fp8_recipe, + fp8_group=fp8_group) yield finally: FP8GlobalStateManager.set_fp8_autocast_state(fp8_state) # pylint: disable=used-before-assignment diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 5803cfa2f9..1dbc40dc70 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -36,6 +36,7 @@ cast_to_fp8, ) from ..constants import dist_group_type +from ..float8_tensor import Float8Tensor _2X_ACC_FPROP = False _2X_ACC_DGRAD = True @@ -451,21 +452,29 @@ def set_fp8_weights(self) -> None: setattr( self, weight_cast_attr, - torch.empty( - shape, - device=torch.cuda.current_device(), - dtype=torch.uint8, - ), + Float8Tensor( + data=torch.empty( + shape, + device=torch.cuda.current_device(), + dtype=torch.uint8, + ), + fp8_dtype=tex.DType.kFloat8E4M3, + fp8_scale_inv=1, + ) ) setattr( self, weight_transpose_attr, - torch.empty( - shape[1], - shape[0], - device=torch.cuda.current_device(), - dtype=torch.uint8, - ), + Float8Tensor( + data=torch.empty( + shape[1], + shape[0], + device=torch.cuda.current_device(), + dtype=torch.uint8, + ), + fp8_dtype=tex.DType.kFloat8E4M3, + fp8_scale_inv=1, + ) ) def set_tensor_parallel_group(self, tp_group: Union[dist_group_type, None]) -> None: @@ -483,12 +492,17 @@ def set_tensor_parallel_group(self, tp_group: Union[dist_group_type, None]) -> N # This routine is shared across FP8 and FP8_calibration paths so should not actually # assume FP8 execution. 
- def fp8_init(self, num_gemms: int = 1) -> None: + def init_fp8_metadata(self, num_gemms: int = 1) -> None: """Initialize fp8 related metadata and tensors during fprop.""" + self.fp8_parameters = FP8GlobalStateManager.with_fp8_parameters() self.fp8 = FP8GlobalStateManager.is_fp8_enabled() self.fp8_calibration = FP8GlobalStateManager.is_fp8_calibration() self.fp8_meta["fp8_checkpoint"] = self.fp8 or self.fp8_calibration + if self.fp8_parameters and not self.fp8_initialized: + self.fp8_meta["num_gemms"] = num_gemms + self.init_fp8_meta_tensors() + if self.fp8 or self.fp8_calibration: # FP8 init has already been run and recipe is the same, don't do anything. if (self.fp8_initialized @@ -536,7 +550,7 @@ def prepare_forward( assert self.tp_group_initialized, "TP group not initialized." self.set_activation_dtype(inp) - self.fp8_init(num_gemms=num_gemms) + self.init_fp8_metadata(num_gemms=num_gemms) # Create persistent tensors for fp8 weights and their transposes # only when fp8 weight caching is used. @@ -765,7 +779,7 @@ def noop_cat(self, def get_fp8_weights_empty_tensors( self, is_first_microbatch: Union[bool, None], - ) -> List[torch.Tensor]: + ) -> List[Float8Tensor]: """ Returns empty tensors to be later used to store fp8 version of weights and their transposes (for the bwd pass) for this batch (or microbatch). @@ -781,23 +795,42 @@ def get_fp8_weights_empty_tensors( fp8_weight_tensors = [] for shape in self.fp8_weight_shapes: fp8_weight_tensors.append( - torch.empty( - shape, - device=torch.cuda.current_device(), - dtype=torch.uint8, + Float8Tensor( + data=torch.empty( + shape, + device=torch.cuda.current_device(), + dtype=torch.uint8, + ), + fp8_dtype=tex.DType.kFloat8E4M3, + fp8_scale_inv=1, ) ) - fp8_weight_tensors.append( - torch.empty( - shape[1], - shape[0], - device=torch.cuda.current_device(), - dtype=torch.uint8, + Float8Tensor( + data=torch.empty( + shape[1], + shape[0], + device=torch.cuda.current_device(), + dtype=torch.uint8, + ), + fp8_dtype=tex.DType.kFloat8E4M3, + fp8_scale_inv=1, ) ) return fp8_weight_tensors + def state_dict(self, *args, **kwargs) -> Dict: + """Get dictionary containing module state""" + state = super().state_dict(*args, **kwargs) + + # Convert Float8Tensors to plain tensors + # Note: Float8Tensors don't serialize well, especially if they + # contain references to FP8 metadata. 
+ for key, val in state.items(): + if isinstance(val, Float8Tensor): + state[key] = val.from_float8() + + return state @abstractmethod def forward(self): diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index a8e83631bc..d4746ba3a0 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -23,7 +23,7 @@ _2X_ACC_DGRAD, _2X_ACC_WGRAD, ) -from ..fp8 import get_fp8_te_dtype +from ..fp8 import get_fp8_te_dtype, FP8GlobalStateManager from ..utils import ( divide, get_default_init_method, @@ -43,6 +43,7 @@ from ._common import _apply_normalization +from ..float8_tensor import Float8Tensor __all__ = ["LayerNormLinear"] @@ -79,10 +80,11 @@ def forward( fwd_ln_sm_margin: int, bwd_ln_sm_margin: int, zero_centered_gamma: bool, + normalization: str, + primary_weights_in_fp8: bool, ub_bulk_wgrad: bool, ub_bulk_dgrad: bool, ub_split_ag: bool, - normalization: str, ub_atomic_gemm_ag: bool, ) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]: # Make sure input dimensions are compatible @@ -159,28 +161,43 @@ def forward( ) bias = cast_if_needed(bias, bias_dtype) if use_bias else bias - if update_fp8_weights: + if primary_weights_in_fp8: + # Weight is already in FP8 + weight.reset_fp8_meta_scale_inv() + weight_fp8 = weight + weight_t_fp8 = None + if is_grad_enabled: + weight_t_fp8 = weight_fp8.transpose(update_cache=is_first_microbatch) + + elif update_fp8_weights: + # Need to cast weights to FP8 + weight_fp8 = Float8Tensor( + data=weight_fp8._data, + fp8_meta=fp8_meta, + fp8_meta_index=tex.FP8FwdTensors.GEMM1_WEIGHT, + ) if is_grad_enabled: tex.fp8_cast_transpose_fused( weight, fp8_meta["scaling_fwd"], tex.FP8FwdTensors.GEMM1_WEIGHT, fp8_dtype_forward, - cast_out=weight_fp8, - transpose_out=weight_t_fp8, + cast_out=weight_fp8._data, + transpose_out=weight_t_fp8._data, ) else: - weight_t_fp8 = None - weight_fp8 = tex.cast_to_fp8( + weight_fp8._data = tex.cast_to_fp8( weight, fp8_meta["scaling_fwd"], tex.FP8FwdTensors.GEMM1_WEIGHT, - fp8_dtype_forward) + fp8_dtype_forward, + ) + weight_t_fp8 = None ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG if ub_split_ag else None ub_algo = tex.UbufOverlapAlgo.ATOMIC_GEMM_AG if ub_atomic_gemm_ag else ub_algo out, _ = tex.fp8_gemm( - weight_fp8, + weight_fp8._data, fp8_meta["scaling_fwd"].scale_inv, tex.FP8FwdTensors.GEMM1_WEIGHT, fp8_dtype_forward, @@ -356,7 +373,7 @@ def backward( # DGRAD: Evaluated unconditionally to feed into Linear backward _ = tex.fp8_gemm( - weight_t_fp8, + weight_t_fp8._data, fwd_scale_inverses, tex.FP8FwdTensors.GEMM1_WEIGHT, fp8_dtype_forward, @@ -544,6 +561,7 @@ def backward( None, None, None, + None, ) @@ -646,10 +664,10 @@ def __init__( return_layernorm_output: bool = False, parameters_split: Optional[Union[Tuple[str, ...], Dict[str, int]]] = None, zero_centered_gamma: bool = False, + device: Union[torch.device, str] = "cuda", ub_bulk_wgrad: bool = False, ub_bulk_dgrad: bool = False, ub_split_ag: bool = False, - device: Union[torch.device, str] = "cuda", ub_atomic_gemm_ag: bool = False, ) -> None: super().__init__() @@ -666,6 +684,7 @@ def __init__( self.return_layernorm_output = return_layernorm_output self.parameters_split = parameters_split self.zero_centered_gamma = zero_centered_gamma + self.primary_weights_in_fp8 = FP8GlobalStateManager.with_fp8_parameters() self.ub_bulk_wgrad = ub_bulk_wgrad self.ub_bulk_dgrad = ub_bulk_dgrad self.ub_split_ag = ub_split_ag @@ -719,18 +738,30 @@ def __init__( 
self.layer_norm_bias = None self.reset_layer_norm_parameters() - self.weight_tensor = torch.empty( + temp_weight = torch.empty( self.out_features, self.in_features, device=device, dtype=params_dtype) initialize_affine_weight_gpu( - self.weight_tensor, + temp_weight, init_method, get_rng_state_tracker, partition_dim=1 if self.parallel_mode == "row" else 0, stride=1, ) + if self.primary_weights_in_fp8: + self.init_fp8_metadata() + self.fp8_meta["update_amax_and_scale_fwd"] = True + + self.weight_tensor = Float8Tensor.to_float8( + temp_weight, + fp8_meta=self.fp8_meta, + fp8_meta_index=tex.FP8FwdTensors.GEMM1_WEIGHT, + ) + else: + self.weight_tensor = temp_weight + if self.use_bias: self.bias_tensor = torch.empty( self.out_features, @@ -769,10 +800,17 @@ def __init__( bname = pname + "bias" slice_end = slice_begin + slice_size - - self.register_parameter( - wname, Parameter(self.weight_tensor[slice_begin:slice_end]) - ) + # NOTE(future): Figure out a way to support slicing when weights + # are of `Float8Tensor` class + if self.primary_weights_in_fp8: + assert len(parameters_split) == 1, ("Slicing operation is not " + "supported in Float8Tensor " + "class!") + self.register_parameter(wname, Parameter(self.weight_tensor)) + else: + self.register_parameter( + wname, Parameter(self.weight_tensor[slice_begin:slice_end]) + ) set_tensor_model_parallel_attributes( tensor=getattr(self, wname), @@ -833,7 +871,7 @@ def get_fp8_weights_scratchpad( `is_first_microbatch` is not `None`) or return empty fp8 weight tensors (if `is_first_microbatch is None`) """ - if not self.fp8: + if not self.fp8 or self.primary_weights_in_fp8: return [None, None] if is_first_microbatch is None: @@ -877,6 +915,8 @@ def forward( """ with self.prepare_forward(inp, is_first_microbatch) as inp: + assert self.fp8 or not self.primary_weights_in_fp8, \ + "Need to run inside fp8_autocast region when weights are stored in FP8." 
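
The assertion above states the usage contract for FP8-resident parameters: a module constructed under fp8_model_init must run its forward pass inside fp8_autocast. A minimal sketch, assuming FP8-capable hardware and illustrative sizes (layer, inp, and out are assumed names, not identifiers from this patch):

    import torch
    import transformer_engine.pytorch as te

    with te.fp8_model_init(enabled=True):
        layer = te.LayerNormLinear(768, 768)   # weight is held as a Float8Tensor only

    inp = torch.randn(16, 768, device="cuda")
    with te.fp8_autocast(enabled=True):        # required when weights are stored in FP8
        out = layer(inp)

Checkpointing such a module still yields plain tensors, because the state_dict override in module/base.py above converts any Float8Tensor entries via from_float8().
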
bias_tensor = ( self.bias if self.parameters_split is None else self.bias_tensor if not torch.is_grad_enabled() @@ -927,10 +967,11 @@ def forward( self.fwd_ln_sm_margin, self.bwd_ln_sm_margin, self.zero_centered_gamma, + self.normalization, + self.primary_weights_in_fp8, self.ub_bulk_wgrad, self.ub_bulk_dgrad, self.ub_split_ag, - self.normalization, self.ub_atomic_gemm_ag, ) out = fwd_fn(*args) diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index d41c8d39df..40256dba6a 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -20,7 +20,7 @@ _2X_ACC_DGRAD, _2X_ACC_WGRAD, ) -from ..fp8 import get_fp8_te_dtype +from ..fp8 import get_fp8_te_dtype, FP8GlobalStateManager from ..jit import ( bias_gelu_fused, bgrad_dgelu_fused, @@ -47,6 +47,7 @@ from ..constants import dist_group_type, TE_DType from ..jit import no_torch_dynamo +from ..float8_tensor import Float8Tensor from ._common import _apply_normalization @@ -105,14 +106,15 @@ def forward( fwd_ln_sm_margin: int, bwd_ln_sm_margin: int, zero_centered_gamma: bool, + activation: str, + normalization: str, + primary_weights_in_fp8: bool, ub_bulk_wgrad: bool, ub_bulk_dgrad: bool, ub_split_rs: bool, ub_atomic_gemm_rs: bool, ub_split_ag: bool, ub_atomic_gemm_ag: bool, - activation: str, - normalization: str, ) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]: # Make sure input dimensions are compatible in_features = ln_weight.numel() @@ -196,45 +198,68 @@ def forward( fc1_bias = cast_if_needed(fc1_bias, bias_dtype) if use_fc1_bias else fc1_bias fc2_bias = cast_if_needed(fc2_bias, bias_dtype) if use_fc2_bias else fc2_bias - if update_fp8_weights: + if primary_weights_in_fp8: + # Weights are already in FP8 + fc1_weight.reset_fp8_meta_scale_inv() + fc2_weight.reset_fp8_meta_scale_inv() + fc1_weight_fp8 = fc1_weight + fc2_weight_fp8 = fc2_weight + fc1_weight_t_fp8 = None + fc2_weight_t_fp8 = None if is_grad_enabled: + fc1_weight_t_fp8 = fc1_weight_fp8.transpose(update_cache=is_first_microbatch) + fc2_weight_t_fp8 = fc2_weight_fp8.transpose(update_cache=is_first_microbatch) + + elif update_fp8_weights: + # Need to cast weights to FP8 + fc1_weight_fp8 = Float8Tensor( + data=fc1_weight_fp8._data, + fp8_meta=fp8_meta, + fp8_meta_index=tex.FP8FwdTensors.GEMM1_WEIGHT, + ) + fc2_weight_fp8 = Float8Tensor( + data=fc2_weight_fp8._data, + fp8_meta=fp8_meta, + fp8_meta_index=tex.FP8FwdTensors.GEMM2_WEIGHT, + ) + if is_grad_enabled: + # Fused cast-transpose kernels tex.fp8_cast_transpose_fused( fc1_weight, fp8_meta["scaling_fwd"], tex.FP8FwdTensors.GEMM1_WEIGHT, fp8_dtype_forward, - cast_out=fc1_weight_fp8, - transpose_out=fc1_weight_t_fp8, + cast_out=fc1_weight_fp8._data, + transpose_out=fc1_weight_t_fp8._data, ) - tex.fp8_cast_transpose_fused( fc2_weight, fp8_meta["scaling_fwd"], tex.FP8FwdTensors.GEMM2_WEIGHT, fp8_dtype_forward, - cast_out=fc2_weight_fp8, - transpose_out=fc2_weight_t_fp8, + cast_out=fc2_weight_fp8._data, + transpose_out=fc2_weight_t_fp8._data, ) else: - fc1_weight_t_fp8 = None - fc1_weight_fp8 = tex.cast_to_fp8( + fc1_weight_fp8._data = tex.cast_to_fp8( fc1_weight, fp8_meta["scaling_fwd"], tex.FP8FwdTensors.GEMM1_WEIGHT, fp8_dtype_forward, ) - fc2_weight_t_fp8 = None - fc2_weight_fp8 = tex.cast_to_fp8( + fc1_weight_t_fp8 = None + fc2_weight_fp8._data = tex.cast_to_fp8( fc2_weight, fp8_meta["scaling_fwd"], tex.FP8FwdTensors.GEMM2_WEIGHT, fp8_dtype_forward, ) + fc2_weight_t_fp8 = None ub_algo = 
tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG if ub_split_ag else None ub_algo = tex.UbufOverlapAlgo.ATOMIC_GEMM_AG if ub_atomic_gemm_ag else ub_algo fc1_out, _ = tex.fp8_gemm( - fc1_weight_fp8, + fc1_weight_fp8._data, fp8_meta["scaling_fwd"].scale_inv, tex.FP8FwdTensors.GEMM1_WEIGHT, fp8_dtype_forward, @@ -283,7 +308,7 @@ def forward( ub_algo=tex.UbufOverlapAlgo.ATOMIC_GEMM_RS if ub_atomic_gemm_rs else None ub_algo=tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS if ub_split_rs else ub_algo _ = tex.fp8_gemm( - fc2_weight_fp8, + fc2_weight_fp8._data, fp8_meta["scaling_fwd"].scale_inv, tex.FP8FwdTensors.GEMM2_WEIGHT, fp8_dtype_forward, @@ -530,7 +555,7 @@ def backward( ub_algo = tex.UbufOverlapAlgo.ATOMIC_GEMM_AG if ctx.ub_atomic_gemm_ag else ub_algo # FC2 DGRAD; Unconditional fc2_dgrad, _ = tex.fp8_gemm( - fc2_weight_t_fp8, + fc2_weight_t_fp8._data, fwd_scale_inverses, tex.FP8FwdTensors.GEMM2_WEIGHT, fp8_dtype_forward, @@ -645,7 +670,7 @@ def backward( ) # FC1 DGRAD: Unconditional _ = tex.fp8_gemm( - fc1_weight_t_fp8, + fc1_weight_t_fp8._data, fwd_scale_inverses, tex.FP8FwdTensors.GEMM1_WEIGHT, fp8_dtype_forward, @@ -908,6 +933,7 @@ def backward( None, None, None, + None, ) @@ -1020,12 +1046,12 @@ def __init__( micro_batch_size: Optional[int] = None, set_parallel_mode: bool = False, zero_centered_gamma: bool = False, + device: Union[torch.device, str] = "cuda", ub_bulk_wgrad: bool = False, ub_bulk_dgrad: bool = False, ub_split_rs: bool = False, ub_atomic_gemm_rs: bool = False, ub_split_ag: bool = False, - device: Union[torch.device, str] = "cuda", ub_atomic_gemm_ag: bool = False, ) -> None: super().__init__() @@ -1043,6 +1069,7 @@ def __init__( self.activation == 'gelu') self.set_parallel_mode = set_parallel_mode self.zero_centered_gamma = zero_centered_gamma + self.primary_weights_in_fp8 = FP8GlobalStateManager.with_fp8_parameters() self.ub_bulk_wgrad = ub_bulk_wgrad self.ub_bulk_dgrad = ub_bulk_dgrad self.ub_split_rs = ub_split_rs @@ -1102,19 +1129,30 @@ def __init__( else: fc1_output_features = self.size_per_partition # FC1 init - self.fc1_weight = Parameter( - torch.empty(fc1_output_features, hidden_size, device=device, dtype=params_dtype) - ) - self.fp8_weight_shapes.append(self.fc1_weight.shape) + fc1_temp_weight = torch.empty( + fc1_output_features, hidden_size, device=device, dtype=params_dtype) initialize_affine_weight_gpu( - self.fc1_weight, + fc1_temp_weight, init_method, get_rng_state_tracker, - partition_dim=0, - stride=1, + set_tp_attributes=False, ) + if self.primary_weights_in_fp8: + self.init_fp8_metadata(num_gemms=2) + self.fp8_meta["update_amax_and_scale_fwd"] = True + + fc1_temp_weight = Float8Tensor.to_float8( + fc1_temp_weight, + fp8_meta=self.fp8_meta, + fp8_meta_index=tex.FP8FwdTensors.GEMM1_WEIGHT, + ) + + self.fc1_weight = Parameter(fc1_temp_weight) + set_tensor_model_parallel_attributes(self.fc1_weight, True, 0, 1) + self.fp8_weight_shapes.append(self.fc1_weight.shape) + if self.use_bias: self.fc1_bias = Parameter( torch.empty(fc1_output_features, device=device, dtype=params_dtype) @@ -1127,19 +1165,27 @@ def __init__( self.fc1_bias.zero_() # FC2 init - self.fc2_weight = Parameter( - torch.empty(hidden_size, self.size_per_partition, device=device, dtype=params_dtype) - ) - self.fp8_weight_shapes.append(self.fc2_weight.shape) + fc2_temp_weight = torch.empty( + hidden_size, self.size_per_partition, device=device, dtype=params_dtype) initialize_affine_weight_gpu( - self.fc2_weight, + fc2_temp_weight, output_layer_init_method, get_rng_state_tracker, - partition_dim=1, - stride=1, + 
set_tp_attributes=False, ) + if self.primary_weights_in_fp8: + fc2_temp_weight = Float8Tensor.to_float8( + fc2_temp_weight, + fp8_meta=self.fp8_meta, + fp8_meta_index=tex.FP8FwdTensors.GEMM2_WEIGHT, + ) + + self.fc2_weight = Parameter(fc2_temp_weight) + set_tensor_model_parallel_attributes(self.fc2_weight, True, 1, 1) + self.fp8_weight_shapes.append(self.fc2_weight.shape) + if self.use_bias: self.fc2_bias = Parameter( torch.empty(hidden_size, device=device, dtype=params_dtype) @@ -1192,7 +1238,7 @@ def get_fp8_weights_scratchpad( `is_first_microbatch` is not `None`) or return empty fp8 weight tensors (if `is_first_microbatch is None`) """ - if not self.fp8: + if not self.fp8 or self.primary_weights_in_fp8: return [None, None, None, None] if is_first_microbatch is None: @@ -1235,6 +1281,8 @@ def forward( """ with self.prepare_forward(inp, is_first_microbatch, num_gemms=2) as inp: + assert self.fp8 or not self.primary_weights_in_fp8, \ + "Need to run inside fp8_autocast region when weights are stored in FP8." # Fetch the fp8 weights placeholders (for linear/gemm) weight1_fp8, weight1_t_fp8, weight2_fp8, weight2_t_fp8 = \ self.get_fp8_weights_scratchpad( @@ -1279,14 +1327,15 @@ def forward( self.fwd_ln_sm_margin, self.bwd_ln_sm_margin, self.zero_centered_gamma, + self.activation, + self.normalization, + self.primary_weights_in_fp8, self.ub_bulk_wgrad, self.ub_bulk_dgrad, self.ub_split_rs, self.ub_atomic_gemm_rs, self.ub_split_ag, self.ub_atomic_gemm_ag, - self.activation, - self.normalization, ) out = fwd_fn(*args) diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 5e2cab22fe..b14877e74b 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -20,7 +20,7 @@ _2X_ACC_DGRAD, _2X_ACC_WGRAD, ) -from ..fp8 import get_fp8_te_dtype +from ..fp8 import get_fp8_te_dtype, FP8GlobalStateManager from ..utils import ( divide, get_default_init_method, @@ -45,6 +45,8 @@ from ..constants import GemmParallelModes, dist_group_type from ..jit import no_torch_dynamo +from ..float8_tensor import Float8Tensor + __all__ = ["Linear"] @@ -57,9 +59,9 @@ class _Linear(torch.autograd.Function): @staticmethod def forward( ctx, - weight: torch.Tensor, - weight_fp8: Union[torch.Tensor, None], - weight_t_fp8: Union[torch.Tensor, None], + weight: Union[Float8Tensor, torch.Tensor], + weight_fp8: Union[Float8Tensor, None], + weight_t_fp8: Union[Float8Tensor, None], inp: torch.Tensor, bias: torch.Tensor, use_bias: bool, @@ -75,6 +77,7 @@ def forward( activation_dtype: torch.dtype, parallel_mode: Union[str, None], is_grad_enabled: bool, + primary_weights_in_fp8: bool, ub_split_rs: bool, ub_split_ag: bool, ub_atomic_gemm_rs: bool, @@ -141,24 +144,38 @@ def forward( ) bias = cast_if_needed(bias, bias_dtype) if use_bias else bias - if update_fp8_weights: + if primary_weights_in_fp8: + # Weight is already in FP8 + weight.reset_fp8_meta_scale_inv() + weight_fp8 = weight + weight_t_fp8 = None + if is_grad_enabled: + weight_t_fp8 = weight_fp8.transpose(update_cache=is_first_microbatch) + + elif update_fp8_weights: + # Need to cast weights to FP8 + weight_fp8 = Float8Tensor( + data=weight_fp8._data, + fp8_meta=fp8_meta, + fp8_meta_index=tex.FP8FwdTensors.GEMM1_WEIGHT, + ) if is_grad_enabled: fp8_cast_transpose_fused( weight, fp8_meta["scaling_fwd"], tex.FP8FwdTensors.GEMM1_WEIGHT, fp8_dtype_forward, - cast_out=weight_fp8, - transpose_out=weight_t_fp8, + cast_out=weight_fp8._data, + transpose_out=weight_t_fp8._data, ) else: - 
weight_t_fp8 = None - weight_fp8 = cast_to_fp8( + weight_fp8._data = cast_to_fp8( weight, fp8_meta["scaling_fwd"], tex.FP8FwdTensors.GEMM1_WEIGHT, fp8_dtype_forward, ) + weight_t_fp8 = None proj_out_index, meta_tensor, proj_out_tetype, proj_out_pttype = ( None, None, None, activation_dtype) @@ -184,7 +201,7 @@ def forward( ub_algo=tex.UbufOverlapAlgo.ATOMIC_GEMM_RS if ub_atomic_gemm_rs else None ub_algo=tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS if ub_split_rs else ub_algo _ = fp8_gemm( - weight_fp8, + weight_fp8._data, fp8_meta["scaling_fwd"].scale_inv, tex.FP8FwdTensors.GEMM1_WEIGHT, fp8_dtype_forward, @@ -245,6 +262,9 @@ def forward( if is_grad_enabled: fp8_wgrad = fp8 and not fp8_meta["recipe"].override_linear_precision.wgrad + if fp8: + assert hasattr(weight_t_fp8, "_data"), \ + "_data attr doesn't exist (before save for bwd)" ctx.save_for_backward( inputmat_no_fp8 if weight.requires_grad and not fp8_wgrad else None, inputmat_t if weight.requires_grad and fp8_wgrad else None, @@ -294,6 +314,9 @@ def backward( weight_t_fp8, fwd_scale_inverses, ) = ctx.saved_tensors + if weight_t_fp8 is not None: + assert hasattr(weight_t_fp8, "_data"), \ + "_data attr doesn't exist (after restore in bwd)" if ctx.ub_split_ag or ctx.ub_atomic_gemm_ag: tp_world_size = get_distributed_world_size(ctx.tp_group) @@ -349,7 +372,7 @@ def backward( if ctx.requires_dgrad: if ctx.fp8: dgrad, _ = fp8_gemm( - weight_t_fp8, + weight_t_fp8._data, fwd_scale_inverses, tex.FP8FwdTensors.GEMM1_WEIGHT, fp8_dtype_forward, @@ -470,6 +493,7 @@ def backward( None, None, None, + None, ) @@ -554,9 +578,9 @@ def __init__( params_dtype: Optional[torch.dtype] = None, parallel_mode: Optional[str] = None, parameters_split: Optional[Union[Tuple[str, ...], Dict[str, int]]] = None, + device: Union[torch.device, str] = "cuda", ub_split_rs: bool = False, ub_split_ag: bool = False, - device: Union[torch.device, str] = "cuda", ub_atomic_gemm_rs: bool = False, ub_atomic_gemm_ag: bool = False, ) -> None: @@ -570,6 +594,7 @@ def __init__( self.return_bias = return_bias self.apply_bias = bias and not return_bias self.parameters_split = parameters_split + self.primary_weights_in_fp8 = FP8GlobalStateManager.with_fp8_parameters() self.ub_split_rs = ub_split_rs self.ub_split_ag = ub_split_ag self.ub_atomic_gemm_rs = ub_atomic_gemm_rs @@ -609,18 +634,31 @@ def __init__( self.sequence_parallel = (self.tp_size > 1) and sequence_parallel - self.weight_tensor = torch.empty( + temp_weight = torch.empty( self.out_features, self.in_features, device=device, dtype=params_dtype) + # TODO(ksivaman): This functionality works with FP8 outside TE. 
initialize_affine_weight_gpu( - self.weight_tensor, + temp_weight, init_method, get_rng_state_tracker, partition_dim=1 if self.parallel_mode == "row" else 0, stride=1, ) + if self.primary_weights_in_fp8: + self.init_fp8_metadata() + self.fp8_meta["update_amax_and_scale_fwd"] = True + + self.weight_tensor = Float8Tensor.to_float8( + temp_weight, + fp8_meta=self.fp8_meta, + fp8_meta_index=tex.FP8FwdTensors.GEMM1_WEIGHT, + ) + else: + self.weight_tensor = temp_weight + if self.use_bias: self.bias_tensor = torch.empty(self.out_features, device=device, dtype=params_dtype) else: @@ -657,9 +695,17 @@ def __init__( slice_end = slice_begin + slice_size - self.register_parameter( - wname, Parameter(self.weight_tensor[slice_begin:slice_end]) - ) + # TODO(ksivaman): Add indexing op to torch dispatcher for float8 + if self.primary_weights_in_fp8: + assert len(parameters_split) == 1, ("Slicing operation is not " + "supported in Float8Tensor " + "class!") + self.register_parameter(wname, Parameter(self.weight_tensor)) + else: + + self.register_parameter( + wname, Parameter(self.weight_tensor[slice_begin:slice_end]) + ) set_tensor_model_parallel_attributes( tensor=getattr(self, wname), @@ -697,13 +743,13 @@ def __init__( def get_fp8_weights_scratchpad( self, is_first_microbatch: Union[bool, None], - ) -> List[torch.Tensor]: + ) -> List[Float8Tensor]: """ Fetch the fp8 weight tensor placeholders if they exist (when `is_first_microbatch` is not `None`) or return empty fp8 weight tensors (if `is_first_microbatch is None`) """ - if not self.fp8: + if not self.fp8 or self.primary_weights_in_fp8: return [None, None] if is_first_microbatch is None: @@ -747,6 +793,8 @@ def forward( """ with self.prepare_forward(inp, is_first_microbatch) as inp: + assert self.fp8 or not self.primary_weights_in_fp8, \ + "Need to run inside fp8_autocast region when weights are stored in FP8." 
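
The same contract applies to Linear. One use case called out in the fp8_model_init docstring is inference, where only the FP8 copy of each weight is needed; a sketch under that assumption, with illustrative sizes and assumed names:

    import torch
    import transformer_engine.pytorch as te

    with te.fp8_model_init(enabled=True):
        proj = te.Linear(1024, 1024)           # no higher-precision master copy is kept

    x = torch.randn(32, 1024, device="cuda")
    with torch.no_grad(), te.fp8_autocast(enabled=True):
        y = proj(x)
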
bias_tensor = ( self.bias if self.parameters_split is None else self.bias_tensor if not torch.is_grad_enabled() @@ -790,6 +838,7 @@ def forward( self.activation_dtype, self.parallel_mode, torch.is_grad_enabled(), + self.primary_weights_in_fp8, self.ub_split_rs, self.ub_split_ag, self.ub_atomic_gemm_rs, From 66d91d5219f295ec1e2e714a4926ddb67a2b8f80 Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Tue, 24 Oct 2023 12:11:53 -0700 Subject: [PATCH 68/68] [paddle] add documentation (#489) * paddle documentation Signed-off-by: Kirthi Shankar Sivamani * minor fix Signed-off-by: Kirthi Shankar Sivamani * review comments Signed-off-by: Kirthi Shankar Sivamani --------- Signed-off-by: Kirthi Shankar Sivamani --- docs/api/framework.rst | 1 + docs/api/paddle.rst | 34 ++++++++++++ transformer_engine/paddle/fp8.py | 38 ++++++++++++++ transformer_engine/paddle/layer/attention.py | 32 +++++++----- transformer_engine/paddle/layer/layernorm.py | 28 +++++++++- .../paddle/layer/layernorm_linear.py | 47 ++++++++++++++++- .../paddle/layer/layernorm_mlp.py | 52 ++++++++++++++++++- transformer_engine/paddle/layer/linear.py | 34 +++++++++++- transformer_engine/paddle/layer/softmax.py | 27 +++++++--- .../paddle/layer/transformer.py | 12 +++-- transformer_engine/paddle/recompute.py | 14 ++++- 11 files changed, 288 insertions(+), 31 deletions(-) create mode 100644 docs/api/paddle.rst diff --git a/docs/api/framework.rst b/docs/api/framework.rst index 81d980e089..e298535ed0 100644 --- a/docs/api/framework.rst +++ b/docs/api/framework.rst @@ -10,3 +10,4 @@ Framework-specific API pytorch jax + paddle diff --git a/docs/api/paddle.rst b/docs/api/paddle.rst new file mode 100644 index 0000000000..0ce6ce2284 --- /dev/null +++ b/docs/api/paddle.rst @@ -0,0 +1,34 @@ +.. + Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + + See LICENSE for license information. + +paddle +====== + +.. autoapiclass:: transformer_engine.paddle.Linear(in_features, out_features, **kwargs) + :members: forward + +.. autoapiclass:: transformer_engine.paddle.LayerNorm(hidden_size, eps=1e-5, **kwargs) + +.. autoapiclass:: transformer_engine.paddle.LayerNormLinear(in_features, out_features, eps=1e-5, **kwargs) + :members: forward + +.. autoapiclass:: transformer_engine.paddle.LayerNormMLP(hidden_size, ffn_hidden_size, eps=1e-5, **kwargs) + :members: forward + +.. autoapiclass:: transformer_engine.paddle.FusedScaleMaskSoftmax(attn_mask_type, mask_func, **kwargs) + :members: forward + +.. autoapiclass:: transformer_engine.paddle.DotProductAttention(num_attention_heads, kv_channels, **kwargs) + :members: forward + +.. autoapiclass:: transformer_engine.paddle.MultiHeadAttention(hidden_size, num_attention_heads, **kwargs) + :members: forward + +.. autoapiclass:: transformer_engine.paddle.TransformerLayer(hidden_size, ffn_hidden_size, num_attention_heads, **kwargs) + :members: forward + +.. autoapifunction:: transformer_engine.paddle.fp8_autocast + +.. autoapifunction:: transformer_engine.paddle.recompute diff --git a/transformer_engine/paddle/fp8.py b/transformer_engine/paddle/fp8.py index abf347042a..9ec3037236 100644 --- a/transformer_engine/paddle/fp8.py +++ b/transformer_engine/paddle/fp8.py @@ -15,6 +15,10 @@ from .constants import dist_group_type from .fp8_buffer import FP8MetaFwdBuffer, FP8MetaBwdBuffer, FP8RecomputeBuffer + +__all__ = ['fp8_autocast'] + + # FP8 support _is_fp8_available = None _reason_for_no_fp8 = "" @@ -166,6 +170,40 @@ def fp8_autocast( ) -> None: """ Context manager for FP8 usage. + + .. 
code-block:: python + + with fp8_autocast(enabled=True): + out = model(inp) + + .. note:: + + Support for FP8 in the Linear layer of Transformer Engine is currently limited to tensors + with shapes where both dimensions are divisible by 16. In terms of the input to the full + Transformer network, this typically requires padding sequence length to be multiple of 16. + + .. note:: + + When :attr:`fp8_recipe.reduce_amax==True`, any module must not be invoked more than once + inside a single `fp8_autocast` region. This is unsupported behavior because the amax + reduction is handled during the exit of the `fp8_autocast` context. Calling the same + module more than once inside an `fp8_autocast` region overrides the amax tensors + before reduction can occur. + + Parameters + ---------- + enabled: bool, default = `False` + whether or not to enable fp8 + calibrating: bool, default = `False` + calibration mode allows collecting statistics such as amax and scale + data of fp8 tensors even when executing without fp8 enabled. This is + useful for saving an inference ready fp8 checkpoint while training + using a higher precision. + fp8_recipe: recipe.DelayedScaling, default = `None` + recipe used for FP8 training. + fp8_group: paddle.distributed.collective.Group, default = `None` + distributed group over which amaxes for the fp8 tensors + are reduced at the end of each training step. """ try: _global_fp8_state.enter(enabled, calibrating, fp8_recipe, fp8_group) diff --git a/transformer_engine/paddle/layer/attention.py b/transformer_engine/paddle/layer/attention.py index 8c9be22748..02aa53b042 100644 --- a/transformer_engine/paddle/layer/attention.py +++ b/transformer_engine/paddle/layer/attention.py @@ -29,6 +29,9 @@ from ..recompute import recompute +__all__ = ["DotProductAttention", "MultiHeadAttention"] + + class FusedAttnFuncPackedQKV(paddle.autograd.PyLayer): """Function for FusedAttention with packed QKV input""" @@ -129,7 +132,7 @@ def backward(ctx, d_out): class DotProductAttention(paddle.nn.Layer): - """Dot Product Attention Layer + """ Allows the model to jointly attend to information from different representation subspaces as described in the paper: `Attention Is All You Need `_. @@ -150,8 +153,7 @@ class DotProductAttention(paddle.nn.Layer): attention_type: {'self', 'cross'}, default = `self` type of attention operation. backend: {'transformer_engine', 'paddle'}, default = `transformer_engine` - backend to use for attention operation. - + backend to use for attention operation. """ def __init__(self, @@ -215,17 +217,17 @@ def forward( Parameters ---------- query_layer : paddle.Tensor - Query tensor. + Query tensor. key_value_layer : paddle.Tensor - Key tensor. + Key tensor. attention_mask : Optional[paddle.Tensor], default = `None` - Boolean tensor used to mask out softmax input when not using attention. + Boolean tensor used to mask out softmax input when not using attention. core_attention_bias_type: str, default = `no_bias` - only support no_bias type currently, {`no_bias`} + only support no_bias type currently, {`no_bias`} core_attention_bias: Optional[paddle.Tensor], default = `None` - Bias tensor for Q * K.T - set_zero: bool, defautl = `True` - Whether to use the fast path to set output tensors to 0 or not. + Bias tensor for Q * K.T + set_zero: bool, default = `True` + Whether to use the fast path to set output tensors to 0 or not. 
""" backend = self.backend @@ -358,7 +360,9 @@ def _pd_forward( class MultiHeadAttention(paddle.nn.Layer): - """Attention w/ QKV and Proj Gemms + """ + Multi-head Attention (MHA), including Query, + Key, Value and Output projection. Parameters ---------- @@ -387,7 +391,8 @@ class MultiHeadAttention(paddle.nn.Layer): zero_centered_gamma: bool, default = `False` whether to zero initialize the gamma of the layernorm operation. backend: {'transformer_engine', 'paddle'}, default = `transformer_engine` - backend to use for attention operation. + backend to use for attention operation. If set to 'paddle', a framework + only no-FP8 path is executed with limited optimization. Parallelism parameters ---------------------- @@ -542,7 +547,6 @@ def forward( """ MultiHeadAttention Layer. - Parameters ---------- hidden_states : paddle.Tensor @@ -555,7 +559,7 @@ def forward( only support no_bias type currently, {`no_bias`} core_attention_bias: Optional[paddle.Tensor], default = `None` Bias tensor for Q * K.T - set_zero: bool, defautl = `True` + set_zero: bool, default = `True` Whether to use the fast path to set output tensors to 0 or not. recompute_core_attention: bool, default = `False` If true, forward activations for core attention are recomputed diff --git a/transformer_engine/paddle/layer/layernorm.py b/transformer_engine/paddle/layer/layernorm.py index 89c03ee25c..77c164e48a 100644 --- a/transformer_engine/paddle/layer/layernorm.py +++ b/transformer_engine/paddle/layer/layernorm.py @@ -63,7 +63,33 @@ def backward(ctx, grad_output: paddle.Tensor) -> Tuple[Union[paddle.Tensor, None class LayerNorm(paddle.nn.Layer): r""" Applies Layer Normalization over a mini-batch of inputs as described in - the paper `Layer Normalization ` + the paper `Layer Normalization `__ + + .. math:: + y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \varepsilon}} * \gamma + \beta + + :math:`\gamma` and :math:`\beta` are learnable affine transform parameters of + size :attr:`hidden_size` + + Parameters + ---------- + hidden_size : int + size of each input sample. + eps : float, default = 1e-5 + a value added to the denominator of layer normalization for numerical stability. + weight_attr: Union[paddle.ParamAttr, None], default = None + optional `paddle.ParamAttr` for weight. + bias_attr: Union[paddle.ParamAttr, None, bool], default = None + optional `paddle.ParamAttr` for bias. + zero_centered_gamma : bool, default = 'False' + if set to 'True', gamma parameter in LayerNorm is initialized to 0 and + the LayerNorm formula changes to + + .. math:: + y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \varepsilon}} * + (1 + \gamma) + \beta + backend: {'transformer_engine', 'paddle'}, default = `transformer_engine` + backend to use for softmax operation. """ def __init__( diff --git a/transformer_engine/paddle/layer/layernorm_linear.py b/transformer_engine/paddle/layer/layernorm_linear.py index 1d13ee093f..e1b46aaa18 100644 --- a/transformer_engine/paddle/layer/layernorm_linear.py +++ b/transformer_engine/paddle/layer/layernorm_linear.py @@ -40,7 +40,7 @@ saved_tensor_allow_none, ) -__all__ = ["LayerNormLinear", "_layernorm_fwd_fp8_cast", "_layernorm_bwd"] +__all__ = ["LayerNormLinear"] def _layernorm_fwd_fp8_cast( @@ -331,6 +331,42 @@ def backward( class LayerNormLinear(TransformerEngineBaseLayer): r""" Applies layer normalization followed by linear transformation to the incoming data. + + Parameters + ---------- + in_features : int + size of each input sample. + out_features : int + size of each output sample. 
+ eps : float, default = 1e-5 + a value added to the denominator of layer normalization for numerical stability. + weight_attr: Union[paddle.ParamAttr, None], default = None + optional `paddle.ParamAttr` for weight. + bias_attr: Union[paddle.ParamAttr, None, bool], default = None + optional `paddle.ParamAttr` for bias. + return_layernorm_output : bool, default = `False` + if set to `True`, output of layernorm is returned from the forward + together with the output of the linear transformation. + Example use case: residual connection for transformer module is + taken post layernorm. + zero_centered_gamma : bool, default = 'False' + if set to 'True', gamma parameter in LayerNorm is initialized to 0 and + the LayerNorm formula changes to + + .. math:: + y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \varepsilon}} * + (1 + \gamma) + \beta + backend: {'transformer_engine', 'paddle'}, default = 'transformer_engine' + if set to 'paddle', a framework only no-FP8 path is executed with limited optimization. + + Parallelism parameters + ---------------------- + tp_group : ProcessGroup, default = `None` + tensor parallel process group. + parallel_mode : {None, 'Column', 'Row'}, default = `None` + used to decide whether this Linear layer is Column Parallel Linear or Row + Parallel Linear as described `here `_. + When set to `None`, no communication is performed. """ def __init__( @@ -503,7 +539,14 @@ def _pd_forward( return out def forward(self, *args, **kwargs): - """forward""" + """ + Apply layer normalization to the input followed by a linear transformation. + + Parameters + ---------- + inp : torch.Tensor + Input tensor. + """ if self.backend == 'transformer_engine': return self._te_forward(*args, **kwargs) if self.backend == 'paddle': diff --git a/transformer_engine/paddle/layer/layernorm_mlp.py b/transformer_engine/paddle/layer/layernorm_mlp.py index 85364552cc..c4752f6406 100644 --- a/transformer_engine/paddle/layer/layernorm_mlp.py +++ b/transformer_engine/paddle/layer/layernorm_mlp.py @@ -39,6 +39,7 @@ saved_tensor_allow_none, ) + __all__ = ["LayerNormMLP"] @@ -549,7 +550,47 @@ def backward( class LayerNormMLP(TransformerEngineBaseLayer): r""" - Applies layer normalization followed by linear transformation to the incoming data. + Applies layer normalization on the input followed by the MLP module, consisting of + 2 successive linear transformations, separated by the GeLU activation. + + Parameters + ---------- + hidden_size : int + size of each input sample. + ffn_hidden_size : int + intermediate size to which input samples are projected. + eps : float, default = 1e-5 + a value added to the denominator of layer normalization for numerical stability. + weight_attr: Union[paddle.ParamAttr, None], default = None + optional `paddle.ParamAttr` for weight. + bias_attr: Union[paddle.ParamAttr, None, bool], default = None + optional `paddle.ParamAttr` for bias. + activation : str, default = 'gelu' + activation function used. + Options: 'gelu', 'geglu', 'relu', 'reglu', 'squared_relu', 'swiglu'. + return_layernorm_output : bool, default = `False` + if set to `True`, output of layernorm is returned from the forward + together with the output of the linear transformation. + Example use case: residual connection for transformer module + is taken post layernorm. + zero_centered_gamma : bool, default = 'False' + if set to 'True', gamma parameter in LayerNorm is initialized to 0 and + the LayerNorm formula changes to + + .. 
math:: + y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \varepsilon}} * + (1 + \gamma) + \beta + backend: {'transformer_engine', 'paddle'}, default = 'transformer_engine' + if set to 'paddle', a framework only no-FP8 path is executed with limited optimization. + + Parallelism parameters + ---------------------- + set_parallel_mode : bool, default = `False` + if set to `True`, FC1 is used as Column Parallel and FC2 is used as Row + Parallel as described `here `_. + tp_group : paddle.distributed.collective.Group, default = `None` + tensor parallel process group. + """ def __init__( @@ -753,7 +794,14 @@ def _pd_forward( return out def forward(self, *args, **kwargs): - """forward""" + """ + Apply layer normalization to the input followed by a feedforward network (MLP Block). + + Parameters + ---------- + inp : torch.Tensor + Input tensor. + """ if self.backend == 'transformer_engine': return self._te_forward(*args, **kwargs) if self.backend == 'paddle': diff --git a/transformer_engine/paddle/layer/linear.py b/transformer_engine/paddle/layer/linear.py index 9644f9c4e7..1c4ba3ef9b 100644 --- a/transformer_engine/paddle/layer/linear.py +++ b/transformer_engine/paddle/layer/linear.py @@ -38,7 +38,7 @@ saved_tensor_allow_none, ) -__all__ = ["Linear", "_linear_fwd", "_linear_fwd_fp8", "_linear_bwd", "_linear_fwd_non_fp8"] +__all__ = ["Linear"] def _linear_fwd_fp8( @@ -541,6 +541,29 @@ def backward(ctx, grad_output: paddle.Tensor) -> Tuple[Union[paddle.Tensor, None class Linear(TransformerEngineBaseLayer): """ Applies a linear transformation to the incoming data :math:`y = xA^T + b` + + Parameters + ---------- + in_features : int + size of each input sample. + out_features : int + size of each output sample. + weight_attr: Union[paddle.ParamAttr, None], default = None + optional `paddle.ParamAttr` for weight. + bias_attr: Union[paddle.ParamAttr, None, bool], default = None + optional `paddle.ParamAttr` for bias. + backend: {'transformer_engine', 'paddle'}, default = 'transformer_engine' + if set to 'paddle', a framework only no-FP8 path is executed with limited optimization. + + Parallelism parameters + ---------------------- + tp_group : ProcessGroup, default = `None` + tensor parallel process group. + parallel_mode : {None, 'Column', 'Row'}, default = `None` + used to decide whether this Linear layer is Column Parallel Linear or Row + Parallel Linear as described `here `_. + When set to `None`, no communication is performed. + """ def __init__( @@ -658,7 +681,14 @@ def _pd_forward( return out def forward(self, *args, **kwargs): - """forward""" + """ + Apply the linear transformation to the input. + + Parameters + ---------- + inp : torch.Tensor + Input tensor. + """ if self.backend == 'transformer_engine': return self._te_forward(*args, **kwargs) if self.backend == 'paddle': diff --git a/transformer_engine/paddle/layer/softmax.py b/transformer_engine/paddle/layer/softmax.py index 33b0293e0a..b48dd26259 100644 --- a/transformer_engine/paddle/layer/softmax.py +++ b/transformer_engine/paddle/layer/softmax.py @@ -18,9 +18,14 @@ scaled_softmax_backward, ) + +__all__ = ["FusedScaleMaskSoftmax"] + + THREADS_PER_WARP = 32 THREADS_PER_BLOCK = 128 + _default_causal_mask = {} @@ -112,12 +117,22 @@ def backward(ctx, output_grads: paddle.Tensor) -> Tuple[Union[paddle.Tensor, Non class FusedScaleMaskSoftmax(paddle.nn.Layer): """ - fused operation: scaling + mask + softmax - - Arguments: - attn_mask_type: attention mask type (pad or causal) - mask_func: mask function to be applied. 
- softmax_in_fp32: if true, softmax in performed at fp32 precision. + Scaled and masked softmax module for paddle with fused optimizations. + + Parameters + ---------- + attn_mask_type : str, default = `causal` + type of attention mask, can be 'causal', 'padding', or 'no_mask'. + mask_func : callable + custom callable for applying the mask to the softmax input. + `masked_input=mask_func(inp, mask)`. + softmax_in_fp32 : bool, default = True + perform softmax computation in fp32. + layernorm_epsilon : float, default = 1e-5 + a value added to the denominator of layer normalization + for numerical stability. + backend: {'transformer_engine', 'paddle'}, default = `transformer_engine` + backend to use for operation. """ def __init__( diff --git a/transformer_engine/paddle/layer/transformer.py b/transformer_engine/paddle/layer/transformer.py index ada4107648..95c592c672 100644 --- a/transformer_engine/paddle/layer/transformer.py +++ b/transformer_engine/paddle/layer/transformer.py @@ -8,9 +8,9 @@ import paddle from paddle.incubate.nn.layer.fused_dropout_add import FusedDropoutAdd -from . import LayerNormMLP, LayerNorm, MultiHeadAttention -from ..constants import AttnMaskTypes, LayerTypes, dist_group_type -from ..distributed import get_tp_group_and_world_size, track_rng_state +from transformer_engine.paddle.layer import LayerNormMLP, LayerNorm, MultiHeadAttention +from transformer_engine.paddle.constants import AttnMaskTypes, LayerTypes, dist_group_type +from transformer_engine.paddle.distributed import get_tp_group_and_world_size, track_rng_state class TransformerLayer(paddle.nn.Layer): @@ -33,6 +33,10 @@ class TransformerLayer(paddle.nn.Layer): dropout probability for the dropout op after FC2 layer. attention_dropout: float, default = 0.1 dropout probability for the dropout op during multi-head attention. + weight_attr: Union[paddle.ParamAttr, None], default = None + optional `paddle.ParamAttr` for weight. + bias_attr: Union[paddle.ParamAttr, None, bool], default = None + optional `paddle.ParamAttr` for bias. self_attn_mask_type: {'causal', 'padding'}, default = `causal` type of attention mask passed into softmax operation. apply_residual_connection_post_layernorm : bool, default = `False` @@ -62,6 +66,8 @@ class TransformerLayer(paddle.nn.Layer): it controls the type used to allocate the initial parameters. Useful when the model is trained with lower precision and the original FP32 parameters would not fit in GPU memory. + backend: {'transformer_engine', 'paddle'}, default = 'transformer_engine' + if set to 'paddle', a framework only no-FP8 path is executed with limited optimization. Parallelism parameters ---------------------- diff --git a/transformer_engine/paddle/recompute.py b/transformer_engine/paddle/recompute.py index cf42505bc8..b4d22f5240 100644 --- a/transformer_engine/paddle/recompute.py +++ b/transformer_engine/paddle/recompute.py @@ -11,7 +11,9 @@ from .constants import RecomputeFunctionNames from .fp8 import get_global_fp8_state -__all__ = ['recompute', 'is_in_recompute_phase'] + +__all__ = ['recompute'] + _DISABLE_RECOMPUTE = int(os.getenv("NVTE_DISABLE_RECOMPUTE", "0")) @@ -35,6 +37,16 @@ def recompute(function, *args, **kwargs): """ This is a wrapper of paddle.distributed.fleet.utils.recompute. It provides necessary state information for fp8 layers. + + Parameters + ---------- + function: Callable + paddle module used to run the forward and backward passes using + the specified :attr:`args` and :attr:`kwargs`. 
+ args : tuple + tuple of paddle tensors for inputs to :attr:`function`. + kwargs : dict + dictionary of string keys for keyword arguments to :attr:`function`. """ assert not _DISABLE_RECOMPUTE, "Recompute is disabled. " \ f"Got NVTE_DISABLE_RECOMPUTE={_DISABLE_RECOMPUTE}."
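The paddle bindings documented in the final patch can be exercised end to end with a few lines. The snippet below is a minimal sketch and not part of the patch series itself: the layer sizes, tensor shapes, and variable names are illustrative assumptions, and it only combines the documented transformer_engine.paddle.Linear, fp8_autocast, and recompute entry points in the way the new docstrings describe (both GEMM dimensions divisible by 16, the forward pass executed inside the autocast region).

    import paddle
    import transformer_engine.paddle as te

    # Illustrative sizes; FP8 execution expects both GEMM dimensions of the
    # Linear layer to be divisible by 16, as noted in the fp8_autocast docstring.
    model = te.Linear(768, 768)
    inp = paddle.rand((16, 768), dtype='float32')

    # Forward pass inside the FP8 autocast region; leaving fp8_recipe as None
    # falls back to the default DelayedScaling recipe.
    with te.fp8_autocast(enabled=True):
        out = model(inp)
        # The documented activation-recompute wrapper could be used here
        # instead, e.g.: out = te.recompute(model, inp)

    out.mean().backward()

Per the fp8_autocast docstring, passing calibrating=True instead of enabled=True collects amax and scale statistics while still executing in higher precision, which is useful for producing an FP8-ready checkpoint from a higher-precision run.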