From cb55df259cfb22a856ca92107a778343edea5fc7 Mon Sep 17 00:00:00 2001
From: drisspg
Date: Wed, 8 May 2024 17:58:52 -0700
Subject: [PATCH] Enable restricted split + cat in order to enable SP (#253)

Summary:
This comes from needing to support sequence parallelism in torchtitan.

Pull Request resolved: https://github.com/pytorch-labs/float8_experimental/pull/253

Reviewed By: wanchaol

Differential Revision: D57134004

Pulled By: drisspg

fbshipit-source-id: e6c67ba7b2b96045867ece467400b0e4a3305e1d
---
 float8_experimental/float8_ops.py | 48 ++++++++++++++++++++++++++++++-
 pyproject.toml                    |  2 +-
 test/test_base.py                 | 17 ++++++++++-
 3 files changed, 64 insertions(+), 3 deletions(-)

diff --git a/float8_experimental/float8_ops.py b/float8_experimental/float8_ops.py
index 7eec3b6c..e22ccf3b 100644
--- a/float8_experimental/float8_ops.py
+++ b/float8_experimental/float8_ops.py
@@ -3,7 +3,7 @@
 #
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.
-from typing import Any, Dict
+from typing import Any, Dict, Tuple
 
 import torch
 
@@ -50,6 +50,52 @@ def float8_desugar_op(aten_op, args, kwargs=None):
     )
 
 
+@implements([aten.split.Tensor])
+def float8_split(aten_op, args, kwargs=None):
+    new_data_tensors = aten_op(args[0]._data, *args[1:], **kwargs)
+
+    def make_float8(data):
+        return Float8Tensor(
+            data, args[0]._scale, args[0]._orig_dtype, args[0]._mm_config
+        )
+
+    out = map(make_float8, new_data_tensors)
+    return list(out)
+
+
+# `cat_cuda` does not support float8 e4m3fn, so concatenate the raw bits as uint8
+@implements([aten.cat.default])
+def float8_cat(aten_op, args, kwargs=None):
+    chunked_tensors: Tuple[Float8Tensor] = args[0]
+
+    orig_dtype = chunked_tensors[0]._orig_dtype
+    scale = chunked_tensors[0]._scale
+    mm_config = chunked_tensors[0]._mm_config
+    fp8_dtype = chunked_tensors[0]._data.dtype
+    chunk_data = []
+    for chunk in chunked_tensors:
+        assert isinstance(
+            chunk, Float8Tensor
+        ), "Expecting all chunks to be of type Float8Tensor"
+        assert (
+            chunk._orig_dtype == orig_dtype
+        ), "Expecting all chunks to be of the same dtype"
+        assert (
+            chunk._scale is scale
+        ), "Expecting all chunks to have the same scale as a result of a split"
+        assert (
+            chunk._mm_config is mm_config
+        ), "Expecting all chunks to have the same mm config as a result of a split"
+        assert (
+            chunk._data.dtype == fp8_dtype
+        ), "Expecting all chunks to have the same fp8 dtype as a result of a split"
+        chunk_data.append(chunk._data.view(torch.uint8))
+
+    new_data = aten_op(chunk_data, *args[1:], **kwargs)
+    new_data = new_data.view(fp8_dtype)
+    return Float8Tensor(new_data, scale, orig_dtype, mm_config)
+
+
 @implements([aten.sum.dim_IntList])
 def float8_cast_up_op(aten_op, args, kwargs=None):
     """Be careful with this function, this is a "fallback" op that
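Illustration (not part of the diff above): float8_cat works because every chunk produced by float8_split shares the same scale, so no requantization is needed; the fp8 payload can be reinterpreted as uint8, concatenated there, and reinterpreted back without changing a single bit. A minimal sketch of that invariant using only plain torch ops, assuming a PyTorch build with float8_e4m3fn support:

    import torch

    # reinterpret fp8 bits as uint8, concatenate, reinterpret back
    a = torch.randn(4, 8).to(torch.float8_e4m3fn)
    b = torch.randn(4, 8).to(torch.float8_e4m3fn)

    cat_bits = torch.cat([a.view(torch.uint8), b.view(torch.uint8)], dim=0)
    cat_fp8 = cat_bits.view(torch.float8_e4m3fn)

    # the round trip is bitwise lossless
    assert torch.equal(cat_fp8[:4].view(torch.uint8), a.view(torch.uint8))
    assert torch.equal(cat_fp8[4:].view(torch.uint8), b.view(torch.uint8))
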
diff --git a/pyproject.toml b/pyproject.toml
index 116e8e99..858e53b4 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -29,7 +29,7 @@ dev = [
     "black==23.3.0",
     "usort==1.0.6",
     "ufmt==2.1.0",
-    "libcst==1.0.1",
+    "libcst==1.1.0",
     "pytest==7.4.0",
     "bumpver",
     "pip-tools",
diff --git a/test/test_base.py b/test/test_base.py
index 2c8da53d..d4545ef0 100644
--- a/test/test_base.py
+++ b/test/test_base.py
@@ -44,6 +44,12 @@
 is_H100 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)
 
 
+def bitwise_identical(a: Float8Tensor, b: Float8Tensor) -> bool:
+    assert torch.all(a._scale == b._scale).item(), "scales are not identical"
+    assert torch.all(a._data == b._data).item(), "data is not identical"
+    return True
+
+
 class TestFloat8Tensor(unittest.TestCase):
     def test_preserves_dtype(self) -> None:
         # hp means high precision, lp means low precision
@@ -68,6 +74,15 @@ def test_differentiable_casts(self) -> None:
             # the gradient should be unchanged through both casts
             torch.testing.assert_close(grad, x.grad, rtol=0, atol=0)
 
+    def test_split_cat(self):
+        a = torch.rand(16, 16, dtype=torch.bfloat16)
+        scale = tensor_to_scale(a, torch.float8_e4m3fn)
+        fp8_a = Float8Tensor.to_float8(a, scale, torch.float8_e4m3fn)
+
+        splits = torch.split(fp8_a, 16)
+        catted = torch.cat(splits, dim=0)
+        assert bitwise_identical(fp8_a, catted)
+
 
 class TestFloat8Linear:
     def _test_linear_impl(
@@ -357,7 +372,7 @@ def test_different_configs_error(self):
         ):
             a @ b
 
-    def test_merge_configs(sel):
+    def test_merge_configs(self):
         a = ScaledMMConfig(False, True, True)
         b = ScaledMMConfig(True, False, False)
         with pytest.raises(
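For reference beyond test_split_cat, a minimal standalone sketch of the behaviour this patch enables. The import paths and the explicit float8_ops import are assumptions about the package layout (they mirror how test/test_base.py uses these helpers), not something the diff itself shows:

    import torch

    import float8_experimental.float8_ops  # noqa: F401  (assumed to register the aten overrides)
    from float8_experimental.float8_tensor import Float8Tensor
    from float8_experimental.float8_utils import tensor_to_scale

    x = torch.rand(16, 16, dtype=torch.bfloat16)
    scale = tensor_to_scale(x, torch.float8_e4m3fn)
    x_fp8 = Float8Tensor.to_float8(x, scale, torch.float8_e4m3fn)

    # aten.split.Tensor: each chunk stays a Float8Tensor and shares x_fp8's scale
    chunks = torch.split(x_fp8, 4, dim=0)

    # aten.cat.default: the raw bits are concatenated as uint8 and re-wrapped
    x_back = torch.cat(chunks, dim=0)

    assert torch.equal(x_fp8._data.view(torch.uint8), x_back._data.view(torch.uint8))

The "restricted" in the subject line follows from the asserts in float8_cat: all chunks must share one scale, one mm config, and one fp8 dtype, which is exactly what a split produces, so the round trip moves bytes without touching the quantized values.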