From 66d91d5219f295ec1e2e714a4926ddb67a2b8f80 Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Tue, 24 Oct 2023 12:11:53 -0700 Subject: [PATCH] [paddle] add documentation (#489) * paddle documentation Signed-off-by: Kirthi Shankar Sivamani * minor fix Signed-off-by: Kirthi Shankar Sivamani * review comments Signed-off-by: Kirthi Shankar Sivamani --------- Signed-off-by: Kirthi Shankar Sivamani --- docs/api/framework.rst | 1 + docs/api/paddle.rst | 34 ++++++++++++ transformer_engine/paddle/fp8.py | 38 ++++++++++++++ transformer_engine/paddle/layer/attention.py | 32 +++++++----- transformer_engine/paddle/layer/layernorm.py | 28 +++++++++- .../paddle/layer/layernorm_linear.py | 47 ++++++++++++++++- .../paddle/layer/layernorm_mlp.py | 52 ++++++++++++++++++- transformer_engine/paddle/layer/linear.py | 34 +++++++++++- transformer_engine/paddle/layer/softmax.py | 27 +++++++--- .../paddle/layer/transformer.py | 12 +++-- transformer_engine/paddle/recompute.py | 14 ++++- 11 files changed, 288 insertions(+), 31 deletions(-) create mode 100644 docs/api/paddle.rst diff --git a/docs/api/framework.rst b/docs/api/framework.rst index 81d980e089..e298535ed0 100644 --- a/docs/api/framework.rst +++ b/docs/api/framework.rst @@ -10,3 +10,4 @@ Framework-specific API pytorch jax + paddle diff --git a/docs/api/paddle.rst b/docs/api/paddle.rst new file mode 100644 index 0000000000..0ce6ce2284 --- /dev/null +++ b/docs/api/paddle.rst @@ -0,0 +1,34 @@ +.. + Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + + See LICENSE for license information. + +paddle +====== + +.. autoapiclass:: transformer_engine.paddle.Linear(in_features, out_features, **kwargs) + :members: forward + +.. autoapiclass:: transformer_engine.paddle.LayerNorm(hidden_size, eps=1e-5, **kwargs) + +.. autoapiclass:: transformer_engine.paddle.LayerNormLinear(in_features, out_features, eps=1e-5, **kwargs) + :members: forward + +.. autoapiclass:: transformer_engine.paddle.LayerNormMLP(hidden_size, ffn_hidden_size, eps=1e-5, **kwargs) + :members: forward + +.. autoapiclass:: transformer_engine.paddle.FusedScaleMaskSoftmax(attn_mask_type, mask_func, **kwargs) + :members: forward + +.. autoapiclass:: transformer_engine.paddle.DotProductAttention(num_attention_heads, kv_channels, **kwargs) + :members: forward + +.. autoapiclass:: transformer_engine.paddle.MultiHeadAttention(hidden_size, num_attention_heads, **kwargs) + :members: forward + +.. autoapiclass:: transformer_engine.paddle.TransformerLayer(hidden_size, ffn_hidden_size, num_attention_heads, **kwargs) + :members: forward + +.. autoapifunction:: transformer_engine.paddle.fp8_autocast + +.. autoapifunction:: transformer_engine.paddle.recompute diff --git a/transformer_engine/paddle/fp8.py b/transformer_engine/paddle/fp8.py index abf347042a..9ec3037236 100644 --- a/transformer_engine/paddle/fp8.py +++ b/transformer_engine/paddle/fp8.py @@ -15,6 +15,10 @@ from .constants import dist_group_type from .fp8_buffer import FP8MetaFwdBuffer, FP8MetaBwdBuffer, FP8RecomputeBuffer + +__all__ = ['fp8_autocast'] + + # FP8 support _is_fp8_available = None _reason_for_no_fp8 = "" @@ -166,6 +170,40 @@ def fp8_autocast( ) -> None: """ Context manager for FP8 usage. + + .. code-block:: python + + with fp8_autocast(enabled=True): + out = model(inp) + + .. note:: + + Support for FP8 in the Linear layer of Transformer Engine is currently limited to tensors + with shapes where both dimensions are divisible by 16. 
In terms of the input to the full + Transformer network, this typically requires padding sequence length to be multiple of 16. + + .. note:: + + When :attr:`fp8_recipe.reduce_amax==True`, any module must not be invoked more than once + inside a single `fp8_autocast` region. This is unsupported behavior because the amax + reduction is handled during the exit of the `fp8_autocast` context. Calling the same + module more than once inside an `fp8_autocast` region overrides the amax tensors + before reduction can occur. + + Parameters + ---------- + enabled: bool, default = `False` + whether or not to enable fp8 + calibrating: bool, default = `False` + calibration mode allows collecting statistics such as amax and scale + data of fp8 tensors even when executing without fp8 enabled. This is + useful for saving an inference ready fp8 checkpoint while training + using a higher precision. + fp8_recipe: recipe.DelayedScaling, default = `None` + recipe used for FP8 training. + fp8_group: paddle.distributed.collective.Group, default = `None` + distributed group over which amaxes for the fp8 tensors + are reduced at the end of each training step. """ try: _global_fp8_state.enter(enabled, calibrating, fp8_recipe, fp8_group) diff --git a/transformer_engine/paddle/layer/attention.py b/transformer_engine/paddle/layer/attention.py index 8c9be22748..02aa53b042 100644 --- a/transformer_engine/paddle/layer/attention.py +++ b/transformer_engine/paddle/layer/attention.py @@ -29,6 +29,9 @@ from ..recompute import recompute +__all__ = ["DotProductAttention", "MultiHeadAttention"] + + class FusedAttnFuncPackedQKV(paddle.autograd.PyLayer): """Function for FusedAttention with packed QKV input""" @@ -129,7 +132,7 @@ def backward(ctx, d_out): class DotProductAttention(paddle.nn.Layer): - """Dot Product Attention Layer + """ Allows the model to jointly attend to information from different representation subspaces as described in the paper: `Attention Is All You Need `_. @@ -150,8 +153,7 @@ class DotProductAttention(paddle.nn.Layer): attention_type: {'self', 'cross'}, default = `self` type of attention operation. backend: {'transformer_engine', 'paddle'}, default = `transformer_engine` - backend to use for attention operation. - + backend to use for attention operation. """ def __init__(self, @@ -215,17 +217,17 @@ def forward( Parameters ---------- query_layer : paddle.Tensor - Query tensor. + Query tensor. key_value_layer : paddle.Tensor - Key tensor. + Key tensor. attention_mask : Optional[paddle.Tensor], default = `None` - Boolean tensor used to mask out softmax input when not using attention. + Boolean tensor used to mask out softmax input when not using attention. core_attention_bias_type: str, default = `no_bias` - only support no_bias type currently, {`no_bias`} + only support no_bias type currently, {`no_bias`} core_attention_bias: Optional[paddle.Tensor], default = `None` - Bias tensor for Q * K.T - set_zero: bool, defautl = `True` - Whether to use the fast path to set output tensors to 0 or not. + Bias tensor for Q * K.T + set_zero: bool, default = `True` + Whether to use the fast path to set output tensors to 0 or not. """ backend = self.backend @@ -358,7 +360,9 @@ def _pd_forward( class MultiHeadAttention(paddle.nn.Layer): - """Attention w/ QKV and Proj Gemms + """ + Multi-head Attention (MHA), including Query, + Key, Value and Output projection. 
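+
+    A minimal usage sketch is given below; the sizes and the assumed
+    [batch, seq, hidden] input layout are illustrative, not requirements
+    of the layer.
+
+    .. code-block:: python
+
+        import paddle
+        import transformer_engine.paddle as te
+
+        # illustrative sizes: hidden size 1024 split across 16 heads
+        mha = te.MultiHeadAttention(hidden_size=1024, num_attention_heads=16)
+
+        # hypothetical input, assumed [batch, seq, hidden] layout
+        hidden_states = paddle.rand([4, 128, 1024])
+        out = mha(hidden_states)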
Parameters ---------- @@ -387,7 +391,8 @@ class MultiHeadAttention(paddle.nn.Layer): zero_centered_gamma: bool, default = `False` whether to zero initialize the gamma of the layernorm operation. backend: {'transformer_engine', 'paddle'}, default = `transformer_engine` - backend to use for attention operation. + backend to use for attention operation. If set to 'paddle', a framework + only no-FP8 path is executed with limited optimization. Parallelism parameters ---------------------- @@ -542,7 +547,6 @@ def forward( """ MultiHeadAttention Layer. - Parameters ---------- hidden_states : paddle.Tensor @@ -555,7 +559,7 @@ def forward( only support no_bias type currently, {`no_bias`} core_attention_bias: Optional[paddle.Tensor], default = `None` Bias tensor for Q * K.T - set_zero: bool, defautl = `True` + set_zero: bool, default = `True` Whether to use the fast path to set output tensors to 0 or not. recompute_core_attention: bool, default = `False` If true, forward activations for core attention are recomputed diff --git a/transformer_engine/paddle/layer/layernorm.py b/transformer_engine/paddle/layer/layernorm.py index 89c03ee25c..77c164e48a 100644 --- a/transformer_engine/paddle/layer/layernorm.py +++ b/transformer_engine/paddle/layer/layernorm.py @@ -63,7 +63,33 @@ def backward(ctx, grad_output: paddle.Tensor) -> Tuple[Union[paddle.Tensor, None class LayerNorm(paddle.nn.Layer): r""" Applies Layer Normalization over a mini-batch of inputs as described in - the paper `Layer Normalization ` + the paper `Layer Normalization `__ + + .. math:: + y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \varepsilon}} * \gamma + \beta + + :math:`\gamma` and :math:`\beta` are learnable affine transform parameters of + size :attr:`hidden_size` + + Parameters + ---------- + hidden_size : int + size of each input sample. + eps : float, default = 1e-5 + a value added to the denominator of layer normalization for numerical stability. + weight_attr: Union[paddle.ParamAttr, None], default = None + optional `paddle.ParamAttr` for weight. + bias_attr: Union[paddle.ParamAttr, None, bool], default = None + optional `paddle.ParamAttr` for bias. + zero_centered_gamma : bool, default = 'False' + if set to 'True', gamma parameter in LayerNorm is initialized to 0 and + the LayerNorm formula changes to + + .. math:: + y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \varepsilon}} * + (1 + \gamma) + \beta + backend: {'transformer_engine', 'paddle'}, default = `transformer_engine` + backend to use for softmax operation. """ def __init__( diff --git a/transformer_engine/paddle/layer/layernorm_linear.py b/transformer_engine/paddle/layer/layernorm_linear.py index 1d13ee093f..e1b46aaa18 100644 --- a/transformer_engine/paddle/layer/layernorm_linear.py +++ b/transformer_engine/paddle/layer/layernorm_linear.py @@ -40,7 +40,7 @@ saved_tensor_allow_none, ) -__all__ = ["LayerNormLinear", "_layernorm_fwd_fp8_cast", "_layernorm_bwd"] +__all__ = ["LayerNormLinear"] def _layernorm_fwd_fp8_cast( @@ -331,6 +331,42 @@ def backward( class LayerNormLinear(TransformerEngineBaseLayer): r""" Applies layer normalization followed by linear transformation to the incoming data. + + Parameters + ---------- + in_features : int + size of each input sample. + out_features : int + size of each output sample. + eps : float, default = 1e-5 + a value added to the denominator of layer normalization for numerical stability. + weight_attr: Union[paddle.ParamAttr, None], default = None + optional `paddle.ParamAttr` for weight. 
+ bias_attr: Union[paddle.ParamAttr, None, bool], default = None + optional `paddle.ParamAttr` for bias. + return_layernorm_output : bool, default = `False` + if set to `True`, output of layernorm is returned from the forward + together with the output of the linear transformation. + Example use case: residual connection for transformer module is + taken post layernorm. + zero_centered_gamma : bool, default = 'False' + if set to 'True', gamma parameter in LayerNorm is initialized to 0 and + the LayerNorm formula changes to + + .. math:: + y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \varepsilon}} * + (1 + \gamma) + \beta + backend: {'transformer_engine', 'paddle'}, default = 'transformer_engine' + if set to 'paddle', a framework only no-FP8 path is executed with limited optimization. + + Parallelism parameters + ---------------------- + tp_group : ProcessGroup, default = `None` + tensor parallel process group. + parallel_mode : {None, 'Column', 'Row'}, default = `None` + used to decide whether this Linear layer is Column Parallel Linear or Row + Parallel Linear as described `here `_. + When set to `None`, no communication is performed. """ def __init__( @@ -503,7 +539,14 @@ def _pd_forward( return out def forward(self, *args, **kwargs): - """forward""" + """ + Apply layer normalization to the input followed by a linear transformation. + + Parameters + ---------- + inp : torch.Tensor + Input tensor. + """ if self.backend == 'transformer_engine': return self._te_forward(*args, **kwargs) if self.backend == 'paddle': diff --git a/transformer_engine/paddle/layer/layernorm_mlp.py b/transformer_engine/paddle/layer/layernorm_mlp.py index 85364552cc..c4752f6406 100644 --- a/transformer_engine/paddle/layer/layernorm_mlp.py +++ b/transformer_engine/paddle/layer/layernorm_mlp.py @@ -39,6 +39,7 @@ saved_tensor_allow_none, ) + __all__ = ["LayerNormMLP"] @@ -549,7 +550,47 @@ def backward( class LayerNormMLP(TransformerEngineBaseLayer): r""" - Applies layer normalization followed by linear transformation to the incoming data. + Applies layer normalization on the input followed by the MLP module, consisting of + 2 successive linear transformations, separated by the GeLU activation. + + Parameters + ---------- + hidden_size : int + size of each input sample. + ffn_hidden_size : int + intermediate size to which input samples are projected. + eps : float, default = 1e-5 + a value added to the denominator of layer normalization for numerical stability. + weight_attr: Union[paddle.ParamAttr, None], default = None + optional `paddle.ParamAttr` for weight. + bias_attr: Union[paddle.ParamAttr, None, bool], default = None + optional `paddle.ParamAttr` for bias. + activation : str, default = 'gelu' + activation function used. + Options: 'gelu', 'geglu', 'relu', 'reglu', 'squared_relu', 'swiglu'. + return_layernorm_output : bool, default = `False` + if set to `True`, output of layernorm is returned from the forward + together with the output of the linear transformation. + Example use case: residual connection for transformer module + is taken post layernorm. + zero_centered_gamma : bool, default = 'False' + if set to 'True', gamma parameter in LayerNorm is initialized to 0 and + the LayerNorm formula changes to + + .. math:: + y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \varepsilon}} * + (1 + \gamma) + \beta + backend: {'transformer_engine', 'paddle'}, default = 'transformer_engine' + if set to 'paddle', a framework only no-FP8 path is executed with limited optimization. 
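+
+    A minimal usage sketch (the sizes and the assumed [batch, seq, hidden]
+    input shape below are illustrative, not requirements):
+
+    .. code-block:: python
+
+        import paddle
+        import transformer_engine.paddle as te
+
+        # layernorm -> FC1 -> activation -> FC2, with illustrative sizes
+        mlp = te.LayerNormMLP(hidden_size=1024, ffn_hidden_size=4096)
+
+        # hypothetical activation, assumed [batch, seq, hidden] layout
+        inp = paddle.rand([4, 128, 1024])
+        out = mlp(inp)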
+ + Parallelism parameters + ---------------------- + set_parallel_mode : bool, default = `False` + if set to `True`, FC1 is used as Column Parallel and FC2 is used as Row + Parallel as described `here `_. + tp_group : paddle.distributed.collective.Group, default = `None` + tensor parallel process group. + """ def __init__( @@ -753,7 +794,14 @@ def _pd_forward( return out def forward(self, *args, **kwargs): - """forward""" + """ + Apply layer normalization to the input followed by a feedforward network (MLP Block). + + Parameters + ---------- + inp : torch.Tensor + Input tensor. + """ if self.backend == 'transformer_engine': return self._te_forward(*args, **kwargs) if self.backend == 'paddle': diff --git a/transformer_engine/paddle/layer/linear.py b/transformer_engine/paddle/layer/linear.py index 9644f9c4e7..1c4ba3ef9b 100644 --- a/transformer_engine/paddle/layer/linear.py +++ b/transformer_engine/paddle/layer/linear.py @@ -38,7 +38,7 @@ saved_tensor_allow_none, ) -__all__ = ["Linear", "_linear_fwd", "_linear_fwd_fp8", "_linear_bwd", "_linear_fwd_non_fp8"] +__all__ = ["Linear"] def _linear_fwd_fp8( @@ -541,6 +541,29 @@ def backward(ctx, grad_output: paddle.Tensor) -> Tuple[Union[paddle.Tensor, None class Linear(TransformerEngineBaseLayer): """ Applies a linear transformation to the incoming data :math:`y = xA^T + b` + + Parameters + ---------- + in_features : int + size of each input sample. + out_features : int + size of each output sample. + weight_attr: Union[paddle.ParamAttr, None], default = None + optional `paddle.ParamAttr` for weight. + bias_attr: Union[paddle.ParamAttr, None, bool], default = None + optional `paddle.ParamAttr` for bias. + backend: {'transformer_engine', 'paddle'}, default = 'transformer_engine' + if set to 'paddle', a framework only no-FP8 path is executed with limited optimization. + + Parallelism parameters + ---------------------- + tp_group : ProcessGroup, default = `None` + tensor parallel process group. + parallel_mode : {None, 'Column', 'Row'}, default = `None` + used to decide whether this Linear layer is Column Parallel Linear or Row + Parallel Linear as described `here `_. + When set to `None`, no communication is performed. + """ def __init__( @@ -658,7 +681,14 @@ def _pd_forward( return out def forward(self, *args, **kwargs): - """forward""" + """ + Apply the linear transformation to the input. + + Parameters + ---------- + inp : torch.Tensor + Input tensor. + """ if self.backend == 'transformer_engine': return self._te_forward(*args, **kwargs) if self.backend == 'paddle': diff --git a/transformer_engine/paddle/layer/softmax.py b/transformer_engine/paddle/layer/softmax.py index 33b0293e0a..b48dd26259 100644 --- a/transformer_engine/paddle/layer/softmax.py +++ b/transformer_engine/paddle/layer/softmax.py @@ -18,9 +18,14 @@ scaled_softmax_backward, ) + +__all__ = ["FusedScaleMaskSoftmax"] + + THREADS_PER_WARP = 32 THREADS_PER_BLOCK = 128 + _default_causal_mask = {} @@ -112,12 +117,22 @@ def backward(ctx, output_grads: paddle.Tensor) -> Tuple[Union[paddle.Tensor, Non class FusedScaleMaskSoftmax(paddle.nn.Layer): """ - fused operation: scaling + mask + softmax - - Arguments: - attn_mask_type: attention mask type (pad or causal) - mask_func: mask function to be applied. - softmax_in_fp32: if true, softmax in performed at fp32 precision. + Scaled and masked softmax module for paddle with fused optimizations. + + Parameters + ---------- + attn_mask_type : str, default = `causal` + type of attention mask, can be 'causal', 'padding', or 'no_mask'. 
+ mask_func : callable + custom callable for applying the mask to the softmax input. + `masked_input=mask_func(inp, mask)`. + softmax_in_fp32 : bool, default = True + perform softmax computation in fp32. + layernorm_epsilon : float, default = 1e-5 + a value added to the denominator of layer normalization + for numerical stability. + backend: {'transformer_engine', 'paddle'}, default = `transformer_engine` + backend to use for operation. """ def __init__( diff --git a/transformer_engine/paddle/layer/transformer.py b/transformer_engine/paddle/layer/transformer.py index ada4107648..95c592c672 100644 --- a/transformer_engine/paddle/layer/transformer.py +++ b/transformer_engine/paddle/layer/transformer.py @@ -8,9 +8,9 @@ import paddle from paddle.incubate.nn.layer.fused_dropout_add import FusedDropoutAdd -from . import LayerNormMLP, LayerNorm, MultiHeadAttention -from ..constants import AttnMaskTypes, LayerTypes, dist_group_type -from ..distributed import get_tp_group_and_world_size, track_rng_state +from transformer_engine.paddle.layer import LayerNormMLP, LayerNorm, MultiHeadAttention +from transformer_engine.paddle.constants import AttnMaskTypes, LayerTypes, dist_group_type +from transformer_engine.paddle.distributed import get_tp_group_and_world_size, track_rng_state class TransformerLayer(paddle.nn.Layer): @@ -33,6 +33,10 @@ class TransformerLayer(paddle.nn.Layer): dropout probability for the dropout op after FC2 layer. attention_dropout: float, default = 0.1 dropout probability for the dropout op during multi-head attention. + weight_attr: Union[paddle.ParamAttr, None], default = None + optional `paddle.ParamAttr` for weight. + bias_attr: Union[paddle.ParamAttr, None, bool], default = None + optional `paddle.ParamAttr` for bias. self_attn_mask_type: {'causal', 'padding'}, default = `causal` type of attention mask passed into softmax operation. apply_residual_connection_post_layernorm : bool, default = `False` @@ -62,6 +66,8 @@ class TransformerLayer(paddle.nn.Layer): it controls the type used to allocate the initial parameters. Useful when the model is trained with lower precision and the original FP32 parameters would not fit in GPU memory. + backend: {'transformer_engine', 'paddle'}, default = 'transformer_engine' + if set to 'paddle', a framework only no-FP8 path is executed with limited optimization. Parallelism parameters ---------------------- diff --git a/transformer_engine/paddle/recompute.py b/transformer_engine/paddle/recompute.py index cf42505bc8..b4d22f5240 100644 --- a/transformer_engine/paddle/recompute.py +++ b/transformer_engine/paddle/recompute.py @@ -11,7 +11,9 @@ from .constants import RecomputeFunctionNames from .fp8 import get_global_fp8_state -__all__ = ['recompute', 'is_in_recompute_phase'] + +__all__ = ['recompute'] + _DISABLE_RECOMPUTE = int(os.getenv("NVTE_DISABLE_RECOMPUTE", "0")) @@ -35,6 +37,16 @@ def recompute(function, *args, **kwargs): """ This is a wrapper of paddle.distributed.fleet.utils.recompute. It provides necessary state information for fp8 layers. + + Parameters + ---------- + function: Callable + paddle module used to run the forward and backward passes using + the specified :attr:`args` and :attr:`kwargs`. + args : tuple + tuple of torch tensors for inputs to :attr:`function`. + kwargs : dict + dictionary of string keys for keyword arguments to :attr:`function`. """ assert not _DISABLE_RECOMPUTE, "Recompute is disabled. " \ f"Got NVTE_DISABLE_RECOMPUTE={_DISABLE_RECOMPUTE}."
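
The layers and utilities documented in this patch compose as in the sketch below; the layer sizes, the assumed [batch, seq, hidden] input layout, and the availability of FP8-capable hardware are illustrative assumptions, not details taken from the patch itself.

.. code-block:: python

    import paddle
    import transformer_engine.paddle as te

    # one Transformer block with illustrative sizes
    layer = te.TransformerLayer(hidden_size=1024,
                                ffn_hidden_size=4096,
                                num_attention_heads=16)

    # hypothetical input, assumed [batch, seq, hidden] layout
    inp = paddle.rand([4, 128, 1024])

    with te.fp8_autocast(enabled=True):
        # recompute re-runs the layer's forward during the backward pass
        # instead of storing its intermediate activations
        out = te.recompute(layer, inp)

    # the amax reduction described in the fp8_autocast docstring happens
    # when the context exits; the backward pass here runs outside of it
    out.mean().backward()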