-
Notifications
You must be signed in to change notification settings - Fork 333
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[PyTorch] Experimental FP8 tensor class (#452)
* Experimental FP8 tensor Co-authored-by: Tim Moon <[email protected]> Co-authored-by: Sudhakar Singh <[email protected]> Co-authored-by: Przemyslaw Tredak <[email protected]> Signed-off-by: Kirthi Shankar Sivamani <[email protected]> * Add fp8 tensor to ci test Signed-off-by: Kirthi Shankar Sivamani <[email protected]> * review comments and tests Signed-off-by: Kirthi Shankar Sivamani <[email protected]> * Minor changes Signed-off-by: Kirthi Shankar Sivamani <[email protected]> * Default to FP8 usage Signed-off-by: Kirthi Shankar Sivamani <[email protected]> * Fix docs Signed-off-by: Kirthi Shankar Sivamani <[email protected]> * Naming changes Signed-off-by: Kirthi Shankar Sivamani <[email protected]> * minor fix Signed-off-by: Kirthi Shankar Sivamani <[email protected]> * Fix transpose caching Signed-off-by: Kirthi Shankar Sivamani <[email protected]> * Debug transpose caching Handle case where transpose cache is updated externally. Signed-off-by: Tim Moon <[email protected]> * Rename FP8GlobalStateManager.with_fp8_parameters Signed-off-by: Tim Moon <[email protected]> * remove Float8Tensor from import API Signed-off-by: Kirthi Shankar Sivamani <[email protected]> * Avoid caching FP8 transposes if not required Signed-off-by: Tim Moon <[email protected]> * Fix import error in FP8 tensor tests Signed-off-by: Tim Moon <[email protected]> * Fix tranpose caching and checkpointing bug Signed-off-by: Kirthi Shankar Sivamani <[email protected]> * Improve caching and fix distopt case Signed-off-by: Kirthi Shankar Sivamani <[email protected]> * Update transformer_engine/pytorch/float8_tensor.py Signed-off-by: Tim Moon <[email protected]> * Remove recursive logic Signed-off-by: Kirthi Shankar Sivamani <[email protected]> * Fix cache reset bug Signed-off-by: Kirthi Shankar Sivamani <[email protected]> * Store FP8 attributes in dict Easier for multiple tensors to share, e.g. detached tensors. Signed-off-by: Tim Moon <[email protected]> * Make sure scale_inv is 1D tensor Signed-off-by: Tim Moon <[email protected]> * Make sure scale_inv is 1D tensor Signed-off-by: Tim Moon <[email protected]> * Fixes and detach recipe Signed-off-by: Kirthi Shankar Sivamani <[email protected]> * Set default fp8 data type Signed-off-by: Kirthi Shankar Sivamani <[email protected]> --------- Signed-off-by: Kirthi Shankar Sivamani <[email protected]> Signed-off-by: Tim Moon <[email protected]> Signed-off-by: Tim Moon <[email protected]> Co-authored-by: Kirthi Shankar Sivamani <[email protected]> Co-authored-by: Sudhakar Singh <[email protected]> Co-authored-by: Przemyslaw Tredak <[email protected]>
- Loading branch information
1 parent
7eca973
commit d58c08c
Showing
14 changed files
with
1,448 additions
and
141 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,318 @@ | ||
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
# | ||
# See LICENSE for license information. | ||
|
||
from collections.abc import Iterable | ||
from typing import Any, Dict, List, Tuple, Union | ||
|
||
import pytest | ||
import torch | ||
|
||
import transformer_engine.common.recipe | ||
import transformer_engine.pytorch as te | ||
from transformer_engine.pytorch.float8_tensor import Float8Tensor | ||
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager | ||
import transformer_engine_extensions as tex | ||
|
||
# PyTorch tensor dtypes | ||
_dtypes: List[torch.dtype] = [torch.float32, torch.float16, torch.bfloat16] | ||
# TE FP8 dtypes | ||
_fp8_dtypes: List[tex.DType] = [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2] | ||
|
||
# Numerical tolerances with FP8 types | ||
_tols: Dict[tex.DType, Dict[str, float]] = { | ||
tex.DType.kFloat8E4M3: dict(rtol=0.125, atol=0.0675), # epsilon = 0.0625 | ||
tex.DType.kFloat8E5M2: dict(rtol=0.25, atol=0.125), # epsilon = 0.125 | ||
} | ||
|
||
def _to_list(x: Union[Iterable, Any]) -> List: | ||
"""Convert to list if iterable, otherwise put in singleton list""" | ||
if isinstance(x, Iterable): | ||
return list(x) | ||
else: | ||
return [x] | ||
|
||
# Types that can be interpreted as tensor dims | ||
DimsType = Union[Iterable[int], int] | ||
|
||
# Check if FP8 is supported | ||
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() | ||
|
||
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) | ||
class TestFloat8Tensor: | ||
|
||
@staticmethod | ||
def setup_class(cls) -> None: | ||
# Configure RNG | ||
seed = 1234 | ||
torch.manual_seed(seed) | ||
torch.cuda.manual_seed(seed) | ||
|
||
def test_constructor( | ||
self, | ||
dims: DimsType = 1, | ||
fp8_dtype: tex.DType = tex.DType.kFloat8E4M3, | ||
scale_inv: float = 0.375, | ||
dtype: torch.dtype = torch.float32, | ||
) -> None: | ||
"""Call constructor and perform sanity checks""" | ||
dims = _to_list(dims) | ||
tensor = Float8Tensor( | ||
data=torch.zeros(dims, device="cuda", dtype=torch.uint8), | ||
fp8_dtype=fp8_dtype, | ||
fp8_scale_inv=torch.full([1], scale_inv), | ||
dtype=dtype, | ||
) | ||
assert list(tensor.size()) == dims, "Incorrect dims" | ||
assert tensor.dtype == dtype, "Incorrect nominal dtype" | ||
assert tensor.is_cuda, "Incorrect device" | ||
|
||
def _test_quantize_dequantize( | ||
self, | ||
fp8_dtype: tex.DType = tex.DType.kFloat8E4M3, | ||
scale: float = 3.5, | ||
dtype: torch.dtype = torch.float32, | ||
dims: DimsType = 23, | ||
) -> None: | ||
"""Check numerical error when casting to FP8 and back""" | ||
|
||
# Initialize random data | ||
x_ref = 2 * torch.rand(_to_list(dims), dtype=dtype, device="cpu") - 1 | ||
|
||
# Cast to FP8 and back | ||
x_fp8 = Float8Tensor.to_float8( | ||
x_ref, | ||
fp8_dtype=fp8_dtype, | ||
scale=torch.full([1], scale), | ||
) | ||
x_fp8 = x_fp8.from_float8().cpu() | ||
|
||
# Check results | ||
torch.testing.assert_close(x_fp8, x_ref, **_tols[fp8_dtype]) | ||
|
||
# Make sure we are not trivially passing the test | ||
with pytest.raises(AssertionError): | ||
torch.testing.assert_close(x_fp8, -x_ref, **_tols[fp8_dtype]) | ||
|
||
@pytest.mark.parametrize("fp8_dtype", _fp8_dtypes) | ||
@pytest.mark.parametrize("dtype", _dtypes) | ||
def test_quantize_dequantize_dtypes( | ||
self, | ||
fp8_dtype: tex.DType, | ||
dtype: torch.dtype, | ||
) -> None: | ||
self._test_quantize_dequantize(fp8_dtype=fp8_dtype, dtype=dtype) | ||
|
||
@pytest.mark.parametrize("scale", [0.375, 1, 3.5]) | ||
def test_quantize_dequantize_scales(self, scale: float) -> None: | ||
self._test_quantize_dequantize(scale=scale) | ||
|
||
@pytest.mark.parametrize("dims", [[], 1, 311, [7,11], [7,5,3], [2,3,5,3]]) | ||
def test_quantize_dequantize_dims(self, dims: DimsType) -> None: | ||
self._test_quantize_dequantize(dims=dims) | ||
|
||
def test_fp8_meta( | ||
self, | ||
dtype: torch.dtype = torch.float32, | ||
dims: DimsType = 23, | ||
) -> None: | ||
"""Construct Float8Tensor using FP8 metadata and perform basic checks""" | ||
|
||
# Get FP8 metadata from linear module | ||
fp8_dtype = tex.DType.kFloat8E4M3 | ||
recipe = transformer_engine.common.recipe.DelayedScaling( | ||
fp8_format=transformer_engine.common.recipe.Format.E4M3, | ||
) | ||
with te.fp8_autocast(enabled=True, fp8_recipe=recipe): | ||
module = te.Linear(32, 32) | ||
_ = module(torch.zeros([8, 32], device="cuda")) | ||
fp8_meta = module.fp8_meta | ||
fp8_meta_index = tex.FP8FwdTensors.GEMM1_WEIGHT | ||
fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key(forward=True) | ||
|
||
# Initialize random data | ||
dims = _to_list(dims) | ||
x_ref = 2 * torch.rand(dims, dtype=dtype, device="cpu") - 1 | ||
|
||
# Make Float8Tensor | ||
x_fp8 = Float8Tensor.to_float8( | ||
x_ref, | ||
fp8_meta=fp8_meta, | ||
fp8_meta_index=fp8_meta_index, | ||
) | ||
x_ref = x_fp8.from_float8() | ||
assert list(x_fp8.size()) == dims, "Incorrect dims" | ||
assert x_fp8.dtype == dtype, "Incorrect nominal dtype" | ||
assert x_fp8.is_cuda, "Incorrect device" | ||
assert x_fp8._fp8_dtype == fp8_dtype, "Incorrect FP8 dtype" | ||
|
||
# Change FP8 metadata scale | ||
fp8_meta[fp8_meta_key].scale[fp8_meta_index] = 2 | ||
fp8_meta[fp8_meta_key].scale_inv.fill_(123) | ||
|
||
# Check results | ||
torch.testing.assert_close(x_fp8, x_ref, **_tols[fp8_dtype]) | ||
with pytest.raises(AssertionError): | ||
# Make sure we are not trivially passing the test | ||
torch.testing.assert_close(x_fp8, -x_ref, **_tols[fp8_dtype]) | ||
|
||
# Check if scaling factor is updated after in-place ops | ||
x_fp8 += 0 | ||
fp8_meta[fp8_meta_key].scale[fp8_meta_index] = 4 | ||
fp8_meta[fp8_meta_key].scale_inv.fill_(321) | ||
assert x_fp8._scale_inv.item() == 0.5, "Incorrect FP8 scale_inv" | ||
torch.testing.assert_close(x_fp8, x_ref, **_tols[fp8_dtype]) | ||
y = x_fp8.detach() | ||
y += 0 | ||
assert x_fp8._scale_inv.item() == 0.25, "Incorrect FP8 scale_inv" | ||
torch.testing.assert_close(x_fp8, x_ref, **_tols[fp8_dtype]) | ||
|
||
def test_basic_ops( | ||
self, | ||
dims: DimsType = 23, | ||
fp8_dtype: tex.DType = tex.DType.kFloat8E4M3, | ||
scale: float = 3.5, | ||
dtype: torch.dtype = torch.float32, | ||
) -> None: | ||
"""Test basic out-of-place ops""" | ||
|
||
# Initialize random data | ||
dims = _to_list(dims) | ||
x_ref = 2 * torch.rand(dims, dtype=dtype, device="cpu") - 1 | ||
y_ref = 2 * torch.rand(dims, dtype=dtype, device="cpu") - 1 | ||
x_fp8 = Float8Tensor.to_float8( | ||
x_ref, | ||
fp8_dtype=fp8_dtype, | ||
scale=torch.full([1], scale), | ||
) | ||
y_fp8 = Float8Tensor.to_float8( | ||
y_ref, | ||
fp8_dtype=fp8_dtype, | ||
scale=torch.full([1], scale), | ||
) | ||
x_ref = x_fp8.from_float8() | ||
y_ref = y_fp8.from_float8() | ||
|
||
# Exact operations | ||
torch.testing.assert_close(-x_fp8, -x_ref, rtol=0, atol=0) | ||
torch.testing.assert_close(x_fp8.abs(), x_ref.abs(), rtol=0, atol=0) | ||
|
||
# Operations with numerical error | ||
tols = _tols[fp8_dtype] | ||
torch.testing.assert_close(x_fp8 + y_fp8, x_ref + y_ref, **tols) | ||
torch.testing.assert_close(x_fp8 - y_fp8, x_ref - y_ref, **tols) | ||
torch.testing.assert_close(x_fp8 * y_fp8, x_ref * y_ref, **tols) | ||
torch.testing.assert_close(x_fp8 + y_ref, x_ref + y_ref, **tols) | ||
torch.testing.assert_close(x_ref + y_fp8, x_ref + y_ref, **tols) | ||
torch.testing.assert_close(torch.sin(x_fp8), torch.sin(x_ref), **tols) | ||
|
||
# Make sure we are not trivially passing tests | ||
with pytest.raises(AssertionError): | ||
torch.testing.assert_close(x_fp8 + y_fp8, x_ref - y_fp8, **tols) | ||
|
||
def test_inplace_ops( | ||
self, | ||
dims: DimsType = 23, | ||
fp8_dtype: tex.DType = tex.DType.kFloat8E4M3, | ||
scale: float = 3.5, | ||
dtype: torch.dtype = torch.float32, | ||
) -> None: | ||
"""Test in-place ops""" | ||
|
||
# Initialize random data | ||
dims = _to_list(dims) | ||
x_ref = 2 * torch.rand(dims, dtype=dtype, device="cpu") - 1 | ||
y_ref = 2 * torch.rand(dims, dtype=dtype, device="cpu") - 1 | ||
x_fp8 = Float8Tensor.to_float8( | ||
x_ref, | ||
fp8_dtype=fp8_dtype, | ||
scale=torch.full([1], scale), | ||
) | ||
y_fp8 = Float8Tensor.to_float8( | ||
y_ref, | ||
fp8_dtype=fp8_dtype, | ||
scale=torch.full([1], scale), | ||
) | ||
x_ref = x_fp8.from_float8() | ||
y_ref = y_fp8.from_float8() | ||
|
||
# In-place operations | ||
tols = _tols[fp8_dtype] | ||
x_fp8 += y_ref | ||
x_ref += y_ref | ||
torch.testing.assert_close(x_fp8, x_ref, **tols) | ||
x_ref = x_fp8.from_float8() | ||
x_fp8 -= y_fp8 | ||
x_ref -= y_fp8 | ||
torch.testing.assert_close(x_fp8, x_ref, **tols) | ||
x_ref = x_fp8.from_float8() | ||
x_fp8 *= 2 | ||
x_ref *= 2 | ||
torch.testing.assert_close(x_fp8, x_ref, **tols) | ||
x_ref = x_fp8.from_float8() | ||
|
||
# Make sure we are not trivially passing tests | ||
x_ref += 123 | ||
with pytest.raises(AssertionError): | ||
torch.testing.assert_close(x_fp8, x_ref, **tols) | ||
|
||
@pytest.mark.parametrize("dims", [[33, 41], [5, 7, 11]]) | ||
@pytest.mark.parametrize("transpose_dims", [(0, 1), (-2, -1), (0, 0)]) | ||
def test_transpose( | ||
self, | ||
dims: DimsType, | ||
transpose_dims: Tuple[int, int], | ||
fp8_dtype: tex.DType = tex.DType.kFloat8E4M3, | ||
scale: float = 1, | ||
dtype: torch.dtype = torch.float32, | ||
) -> None: | ||
"""Test transpose""" | ||
|
||
# Initialize random data | ||
dims = _to_list(dims) | ||
x_ref = 2 * torch.rand(dims, dtype=dtype, device="cpu") - 1 | ||
x_fp8 = Float8Tensor.to_float8( | ||
x_ref, | ||
fp8_dtype=fp8_dtype, | ||
scale=torch.full([1], scale), | ||
) | ||
x_ref = x_fp8.from_float8() | ||
|
||
# Perform transpose | ||
y_fp8 = x_fp8.transpose(*transpose_dims) | ||
y_ref = x_ref.transpose(*transpose_dims) | ||
|
||
# Check results | ||
tols = dict(rtol=0, atol=0) | ||
torch.testing.assert_close(y_fp8, y_ref, **tols) | ||
|
||
# Make sure we are not trivially passing the test | ||
if transpose_dims[0] != transpose_dims[1]: | ||
with pytest.raises(AssertionError): | ||
torch.testing.assert_close( | ||
y_fp8, | ||
x_ref, | ||
**tols, | ||
) | ||
|
||
# Check transpose caching | ||
if x_fp8.dim() == 2 and transpose_dims[0] != transpose_dims[1]: | ||
x_fp8 += 0.5 | ||
x_ref = x_fp8.from_float8() | ||
torch.testing.assert_close( | ||
x_fp8.transpose(*transpose_dims, update_cache=True), | ||
x_ref.transpose(*transpose_dims), | ||
**tols, | ||
) | ||
torch.testing.assert_close( | ||
x_fp8.transpose(*transpose_dims, update_cache=True), | ||
x_ref.transpose(*transpose_dims), | ||
**tols, | ||
) | ||
x_fp8 += 0.5 | ||
x_ref = x_fp8.from_float8() | ||
torch.testing.assert_close( | ||
x_fp8.transpose(*transpose_dims, update_cache=True), | ||
x_ref.transpose(*transpose_dims), | ||
**tols, | ||
) |
Oops, something went wrong.