Enable restricted split + cat in order to enable SP (#253)
Summary:
This comes from needing to support sequence parallelism in torchtitan
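
As a rough illustration of the pattern this unlocks (a hedged sketch, not part of this diff; import paths are assumed, and the to_float8 / tensor_to_scale calls mirror the new test below), sequence parallelism needs to shard a float8 activation along the sequence dimension and gather it back without touching the scale:

import torch
from float8_experimental.float8_tensor import Float8Tensor
from float8_experimental.float8_utils import tensor_to_scale

x = torch.randn(8, 1024, 512, dtype=torch.bfloat16)  # (batch, seq, dim)
scale = tensor_to_scale(x, torch.float8_e4m3fn)
fp8_x = Float8Tensor.to_float8(x, scale, torch.float8_e4m3fn)

shards = torch.split(fp8_x, 256, dim=1)  # dispatched to float8_split; every shard reuses x's scale
merged = torch.cat(shards, dim=1)        # dispatched to float8_cat; bitwise-identical to fp8_x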

Pull Request resolved: #253

Reviewed By: wanchaol

Differential Revision: D57134004

Pulled By: drisspg

fbshipit-source-id: e6c67ba7b2b96045867ece467400b0e4a3305e1d
drisspg authored and facebook-github-bot committed May 9, 2024
1 parent 14b00aa commit cb55df2
Showing 3 changed files with 64 additions and 3 deletions.
48 changes: 47 additions & 1 deletion float8_experimental/float8_ops.py
@@ -3,7 +3,7 @@
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.
from typing import Any, Dict
from typing import Any, Dict, Tuple

import torch

@@ -50,6 +50,52 @@ def float8_desugar_op(aten_op, args, kwargs=None):
)


@implements([aten.split.Tensor])
def float8_split(aten_op, args, kwargs=None):
    new_data_tensors = aten_op(args[0]._data, *args[1:], **kwargs)

    def make_float8(data):
        return Float8Tensor(
            data, args[0]._scale, args[0]._orig_dtype, args[0]._mm_config
        )

    out = map(make_float8, new_data_tensors)
    return list(out)


# `cat` is not implemented for float8 e4m3fn (it errors in `cat_cuda`), so we
# concatenate the raw bytes via a uint8 view and reinterpret the result.
@implements([aten.cat.default])
def float8_cat(aten_op, args, kwargs=None):
    chunked_tensors: Tuple[Float8Tensor] = args[0]

    orig_dtype = chunked_tensors[0]._orig_dtype
    scale = chunked_tensors[0]._scale
    mm_config = chunked_tensors[0]._mm_config
    fp8_dtype = chunked_tensors[0]._data.dtype
    chunk_data = []
    for chunk in chunked_tensors:
        assert isinstance(
            chunk, Float8Tensor
        ), "Expecting all chunks to be of type Float8Tensor"
        assert (
            chunk._orig_dtype == orig_dtype
        ), "Expecting all chunks to be of the same dtype"
        assert (
            chunk._scale is scale
        ), "Expecting all chunks to have the same scale as a result of a split"
        assert (
            chunk._mm_config is mm_config
        ), "Expecting all chunks to have the same mm config as a result of a split"
        assert (
            chunk._data.dtype == fp8_dtype
        ), "Expecting all chunks to be of the same dtype as a result of a split"
        chunk_data.append(chunk._data.view(torch.uint8))

    new_data = aten_op(chunk_data, *args[1:], **kwargs)
    new_data = new_data.view(fp8_dtype)
    return Float8Tensor(new_data, scale, orig_dtype, mm_config)


@implements([aten.sum.dim_IntList])
def float8_cast_up_op(aten_op, args, kwargs=None):
"""Be careful with this function, this is a "fallback" op that
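The cat handler above works because float8 tensors can be reinterpreted as bytes: the cat kernel errors on float8 e4m3fn, so the handler concatenates the data as uint8 and views the result back. A minimal standalone sketch of that view trick (assumes a PyTorch build that ships the float8 dtypes; not part of this diff):

import torch

a = torch.randn(4, 8).to(torch.float8_e4m3fn)
b = torch.randn(4, 8).to(torch.float8_e4m3fn)

# cat lacks a float8 kernel, so reinterpret each tensor as uint8 (same element
# size), concatenate the bytes, then view the result as float8 again.
out = torch.cat([a.view(torch.uint8), b.view(torch.uint8)], dim=0)
out = out.view(torch.float8_e4m3fn)
assert out.shape == (8, 8) and out.dtype == torch.float8_e4m3fn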
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -29,7 +29,7 @@ dev = [
"black==23.3.0",
"usort==1.0.6",
"ufmt==2.1.0",
"libcst==1.0.1",
"libcst==1.1.0",
"pytest==7.4.0",
"bumpver",
"pip-tools",
17 changes: 16 additions & 1 deletion test/test_base.py
@@ -44,6 +44,12 @@
is_H100 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)


def bitwise_identical(a: Float8Tensor, b: Float8Tensor) -> bool:
    assert torch.all(a._scale == b._scale).item(), "scales are not identical"
    assert torch.all(a._data == b._data).item(), "data is not identical"
    return True


class TestFloat8Tensor(unittest.TestCase):
    def test_preserves_dtype(self) -> None:
        # hp means high precision, lp means low precision
@@ -68,6 +74,15 @@ def test_differentiable_casts(self) -> None:
        # the gradient should be unchanged through both casts
        torch.testing.assert_close(grad, x.grad, rtol=0, atol=0)

    def test_split_cat(self):
        a = torch.rand(16, 16, dtype=torch.bfloat16)
        scale = tensor_to_scale(a, torch.float8_e4m3fn)
        fp8_a = Float8Tensor.to_float8(a, scale, torch.float8_e4m3fn)

        splits = torch.split(fp8_a, 16)
        catted = torch.cat(splits, dim=0)
        assert bitwise_identical(fp8_a, catted)


class TestFloat8Linear:
    def _test_linear_impl(
@@ -357,7 +372,7 @@ def test_different_configs_error(self):
        ):
            a @ b

    def test_merge_configs(sel):
    def test_merge_configs(self):
        a = ScaledMMConfig(False, True, True)
        b = ScaledMMConfig(True, False, False)
        with pytest.raises(
