Skip to content

Commit

Permalink
[PyTorch] Experimental FP8 tensor class (#452)
Browse files Browse the repository at this point in the history
* Experimental FP8 tensor

Co-authored-by: Tim Moon <[email protected]>
Co-authored-by: Sudhakar Singh <[email protected]>
Co-authored-by: Przemyslaw Tredak <[email protected]>
Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* Add fp8 tensor to ci test

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* review comments and tests

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* Minor changes

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* Default to FP8 usage

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* Fix docs

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* Naming changes

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* minor fix

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* Fix transpose caching

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* Debug transpose caching

Handle case where transpose cache is updated externally.

Signed-off-by: Tim Moon <[email protected]>

* Rename FP8GlobalStateManager.with_fp8_parameters

Signed-off-by: Tim Moon <[email protected]>

* remove Float8Tensor from import API

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* Avoid caching FP8 transposes if not required

Signed-off-by: Tim Moon <[email protected]>

* Fix import error in FP8 tensor tests

Signed-off-by: Tim Moon <[email protected]>

* Fix tranpose caching and checkpointing bug

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* Improve caching and fix distopt case

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* Update transformer_engine/pytorch/float8_tensor.py

Signed-off-by: Tim Moon <[email protected]>

* Remove recursive logic

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* Fix cache reset bug

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* Store FP8 attributes in dict

Easier for multiple tensors to share, e.g. detached tensors.

Signed-off-by: Tim Moon <[email protected]>

* Make sure scale_inv is 1D tensor

Signed-off-by: Tim Moon <[email protected]>

* Make sure scale_inv is 1D tensor

Signed-off-by: Tim Moon <[email protected]>

* Fixes and detach recipe

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* Set default fp8 data type

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

---------

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>
Signed-off-by: Tim Moon <[email protected]>
Signed-off-by: Tim Moon <[email protected]>
Co-authored-by: Kirthi Shankar Sivamani <[email protected]>
Co-authored-by: Sudhakar Singh <[email protected]>
Co-authored-by: Przemyslaw Tredak <[email protected]>
  • Loading branch information
4 people committed Oct 31, 2023
1 parent 7eca973 commit d58c08c
Show file tree
Hide file tree
Showing 14 changed files with 1,448 additions and 141 deletions.
2 changes: 2 additions & 0 deletions docs/api/pytorch.rst
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ pyTorch

.. autoapifunction:: transformer_engine.pytorch.fp8_autocast

.. autoapifunction:: transformer_engine.pytorch.fp8_model_init

.. autoapifunction:: transformer_engine.pytorch.checkpoint

.. autoapifunction:: transformer_engine.pytorch.onnx_export
1 change: 1 addition & 0 deletions qa/L0_pytorch_unittest/test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,4 @@ PYTORCH_JIT=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pyt
pytest -v -s $TE_PATH/tests/pytorch/test_jit.py
pytest -v -s $TE_PATH/tests/pytorch/test_fused_attn.py
NVTE_TORCH_COMPILE=0 pytest -v -s $TE_PATH/tests/pytorch/test_onnx_export.py
pytest -v -s $TE_PATH/tests/pytorch/test_float8tensor.py
318 changes: 318 additions & 0 deletions tests/pytorch/test_float8tensor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,318 @@
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

from collections.abc import Iterable
from typing import Any, Dict, List, Tuple, Union

import pytest
import torch

import transformer_engine.common.recipe
import transformer_engine.pytorch as te
from transformer_engine.pytorch.float8_tensor import Float8Tensor
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
import transformer_engine_extensions as tex

# PyTorch tensor dtypes
_dtypes: List[torch.dtype] = [torch.float32, torch.float16, torch.bfloat16]
# TE FP8 dtypes
_fp8_dtypes: List[tex.DType] = [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]

# Numerical tolerances with FP8 types
_tols: Dict[tex.DType, Dict[str, float]] = {
tex.DType.kFloat8E4M3: dict(rtol=0.125, atol=0.0675), # epsilon = 0.0625
tex.DType.kFloat8E5M2: dict(rtol=0.25, atol=0.125), # epsilon = 0.125
}

def _to_list(x: Union[Iterable, Any]) -> List:
"""Convert to list if iterable, otherwise put in singleton list"""
if isinstance(x, Iterable):
return list(x)
else:
return [x]

# Types that can be interpreted as tensor dims
DimsType = Union[Iterable[int], int]

# Check if FP8 is supported
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()

@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
class TestFloat8Tensor:

@staticmethod
def setup_class(cls) -> None:
# Configure RNG
seed = 1234
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)

def test_constructor(
self,
dims: DimsType = 1,
fp8_dtype: tex.DType = tex.DType.kFloat8E4M3,
scale_inv: float = 0.375,
dtype: torch.dtype = torch.float32,
) -> None:
"""Call constructor and perform sanity checks"""
dims = _to_list(dims)
tensor = Float8Tensor(
data=torch.zeros(dims, device="cuda", dtype=torch.uint8),
fp8_dtype=fp8_dtype,
fp8_scale_inv=torch.full([1], scale_inv),
dtype=dtype,
)
assert list(tensor.size()) == dims, "Incorrect dims"
assert tensor.dtype == dtype, "Incorrect nominal dtype"
assert tensor.is_cuda, "Incorrect device"

def _test_quantize_dequantize(
self,
fp8_dtype: tex.DType = tex.DType.kFloat8E4M3,
scale: float = 3.5,
dtype: torch.dtype = torch.float32,
dims: DimsType = 23,
) -> None:
"""Check numerical error when casting to FP8 and back"""

# Initialize random data
x_ref = 2 * torch.rand(_to_list(dims), dtype=dtype, device="cpu") - 1

# Cast to FP8 and back
x_fp8 = Float8Tensor.to_float8(
x_ref,
fp8_dtype=fp8_dtype,
scale=torch.full([1], scale),
)
x_fp8 = x_fp8.from_float8().cpu()

# Check results
torch.testing.assert_close(x_fp8, x_ref, **_tols[fp8_dtype])

# Make sure we are not trivially passing the test
with pytest.raises(AssertionError):
torch.testing.assert_close(x_fp8, -x_ref, **_tols[fp8_dtype])

@pytest.mark.parametrize("fp8_dtype", _fp8_dtypes)
@pytest.mark.parametrize("dtype", _dtypes)
def test_quantize_dequantize_dtypes(
self,
fp8_dtype: tex.DType,
dtype: torch.dtype,
) -> None:
self._test_quantize_dequantize(fp8_dtype=fp8_dtype, dtype=dtype)

@pytest.mark.parametrize("scale", [0.375, 1, 3.5])
def test_quantize_dequantize_scales(self, scale: float) -> None:
self._test_quantize_dequantize(scale=scale)

@pytest.mark.parametrize("dims", [[], 1, 311, [7,11], [7,5,3], [2,3,5,3]])
def test_quantize_dequantize_dims(self, dims: DimsType) -> None:
self._test_quantize_dequantize(dims=dims)

def test_fp8_meta(
self,
dtype: torch.dtype = torch.float32,
dims: DimsType = 23,
) -> None:
"""Construct Float8Tensor using FP8 metadata and perform basic checks"""

# Get FP8 metadata from linear module
fp8_dtype = tex.DType.kFloat8E4M3
recipe = transformer_engine.common.recipe.DelayedScaling(
fp8_format=transformer_engine.common.recipe.Format.E4M3,
)
with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
module = te.Linear(32, 32)
_ = module(torch.zeros([8, 32], device="cuda"))
fp8_meta = module.fp8_meta
fp8_meta_index = tex.FP8FwdTensors.GEMM1_WEIGHT
fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key(forward=True)

# Initialize random data
dims = _to_list(dims)
x_ref = 2 * torch.rand(dims, dtype=dtype, device="cpu") - 1

# Make Float8Tensor
x_fp8 = Float8Tensor.to_float8(
x_ref,
fp8_meta=fp8_meta,
fp8_meta_index=fp8_meta_index,
)
x_ref = x_fp8.from_float8()
assert list(x_fp8.size()) == dims, "Incorrect dims"
assert x_fp8.dtype == dtype, "Incorrect nominal dtype"
assert x_fp8.is_cuda, "Incorrect device"
assert x_fp8._fp8_dtype == fp8_dtype, "Incorrect FP8 dtype"

# Change FP8 metadata scale
fp8_meta[fp8_meta_key].scale[fp8_meta_index] = 2
fp8_meta[fp8_meta_key].scale_inv.fill_(123)

# Check results
torch.testing.assert_close(x_fp8, x_ref, **_tols[fp8_dtype])
with pytest.raises(AssertionError):
# Make sure we are not trivially passing the test
torch.testing.assert_close(x_fp8, -x_ref, **_tols[fp8_dtype])

# Check if scaling factor is updated after in-place ops
x_fp8 += 0
fp8_meta[fp8_meta_key].scale[fp8_meta_index] = 4
fp8_meta[fp8_meta_key].scale_inv.fill_(321)
assert x_fp8._scale_inv.item() == 0.5, "Incorrect FP8 scale_inv"
torch.testing.assert_close(x_fp8, x_ref, **_tols[fp8_dtype])
y = x_fp8.detach()
y += 0
assert x_fp8._scale_inv.item() == 0.25, "Incorrect FP8 scale_inv"
torch.testing.assert_close(x_fp8, x_ref, **_tols[fp8_dtype])

def test_basic_ops(
self,
dims: DimsType = 23,
fp8_dtype: tex.DType = tex.DType.kFloat8E4M3,
scale: float = 3.5,
dtype: torch.dtype = torch.float32,
) -> None:
"""Test basic out-of-place ops"""

# Initialize random data
dims = _to_list(dims)
x_ref = 2 * torch.rand(dims, dtype=dtype, device="cpu") - 1
y_ref = 2 * torch.rand(dims, dtype=dtype, device="cpu") - 1
x_fp8 = Float8Tensor.to_float8(
x_ref,
fp8_dtype=fp8_dtype,
scale=torch.full([1], scale),
)
y_fp8 = Float8Tensor.to_float8(
y_ref,
fp8_dtype=fp8_dtype,
scale=torch.full([1], scale),
)
x_ref = x_fp8.from_float8()
y_ref = y_fp8.from_float8()

# Exact operations
torch.testing.assert_close(-x_fp8, -x_ref, rtol=0, atol=0)
torch.testing.assert_close(x_fp8.abs(), x_ref.abs(), rtol=0, atol=0)

# Operations with numerical error
tols = _tols[fp8_dtype]
torch.testing.assert_close(x_fp8 + y_fp8, x_ref + y_ref, **tols)
torch.testing.assert_close(x_fp8 - y_fp8, x_ref - y_ref, **tols)
torch.testing.assert_close(x_fp8 * y_fp8, x_ref * y_ref, **tols)
torch.testing.assert_close(x_fp8 + y_ref, x_ref + y_ref, **tols)
torch.testing.assert_close(x_ref + y_fp8, x_ref + y_ref, **tols)
torch.testing.assert_close(torch.sin(x_fp8), torch.sin(x_ref), **tols)

# Make sure we are not trivially passing tests
with pytest.raises(AssertionError):
torch.testing.assert_close(x_fp8 + y_fp8, x_ref - y_fp8, **tols)

def test_inplace_ops(
self,
dims: DimsType = 23,
fp8_dtype: tex.DType = tex.DType.kFloat8E4M3,
scale: float = 3.5,
dtype: torch.dtype = torch.float32,
) -> None:
"""Test in-place ops"""

# Initialize random data
dims = _to_list(dims)
x_ref = 2 * torch.rand(dims, dtype=dtype, device="cpu") - 1
y_ref = 2 * torch.rand(dims, dtype=dtype, device="cpu") - 1
x_fp8 = Float8Tensor.to_float8(
x_ref,
fp8_dtype=fp8_dtype,
scale=torch.full([1], scale),
)
y_fp8 = Float8Tensor.to_float8(
y_ref,
fp8_dtype=fp8_dtype,
scale=torch.full([1], scale),
)
x_ref = x_fp8.from_float8()
y_ref = y_fp8.from_float8()

# In-place operations
tols = _tols[fp8_dtype]
x_fp8 += y_ref
x_ref += y_ref
torch.testing.assert_close(x_fp8, x_ref, **tols)
x_ref = x_fp8.from_float8()
x_fp8 -= y_fp8
x_ref -= y_fp8
torch.testing.assert_close(x_fp8, x_ref, **tols)
x_ref = x_fp8.from_float8()
x_fp8 *= 2
x_ref *= 2
torch.testing.assert_close(x_fp8, x_ref, **tols)
x_ref = x_fp8.from_float8()

# Make sure we are not trivially passing tests
x_ref += 123
with pytest.raises(AssertionError):
torch.testing.assert_close(x_fp8, x_ref, **tols)

@pytest.mark.parametrize("dims", [[33, 41], [5, 7, 11]])
@pytest.mark.parametrize("transpose_dims", [(0, 1), (-2, -1), (0, 0)])
def test_transpose(
self,
dims: DimsType,
transpose_dims: Tuple[int, int],
fp8_dtype: tex.DType = tex.DType.kFloat8E4M3,
scale: float = 1,
dtype: torch.dtype = torch.float32,
) -> None:
"""Test transpose"""

# Initialize random data
dims = _to_list(dims)
x_ref = 2 * torch.rand(dims, dtype=dtype, device="cpu") - 1
x_fp8 = Float8Tensor.to_float8(
x_ref,
fp8_dtype=fp8_dtype,
scale=torch.full([1], scale),
)
x_ref = x_fp8.from_float8()

# Perform transpose
y_fp8 = x_fp8.transpose(*transpose_dims)
y_ref = x_ref.transpose(*transpose_dims)

# Check results
tols = dict(rtol=0, atol=0)
torch.testing.assert_close(y_fp8, y_ref, **tols)

# Make sure we are not trivially passing the test
if transpose_dims[0] != transpose_dims[1]:
with pytest.raises(AssertionError):
torch.testing.assert_close(
y_fp8,
x_ref,
**tols,
)

# Check transpose caching
if x_fp8.dim() == 2 and transpose_dims[0] != transpose_dims[1]:
x_fp8 += 0.5
x_ref = x_fp8.from_float8()
torch.testing.assert_close(
x_fp8.transpose(*transpose_dims, update_cache=True),
x_ref.transpose(*transpose_dims),
**tols,
)
torch.testing.assert_close(
x_fp8.transpose(*transpose_dims, update_cache=True),
x_ref.transpose(*transpose_dims),
**tols,
)
x_fp8 += 0.5
x_ref = x_fp8.from_float8()
torch.testing.assert_close(
x_fp8.transpose(*transpose_dims, update_cache=True),
x_ref.transpose(*transpose_dims),
**tols,
)
Loading

0 comments on commit d58c08c

Please sign in to comment.