#13127: Add full support for creating tensors with logical sharding from python

#13127: Add support for ROW_MAJOR block/width sharding with partial shards along width
- Changes to pytensor:
  * Integrate encode_tensor_data converter into create_owned_tensor
  * Integrate decode_tensor_data converter into create_row_major_owned_buffer
    ** Update converters to also handle tilize and untilize if needed
    ** Update converters to also handle interleaved tensors with equivalent sharding specs
  * Switch to querying the logical shape when converting tt tensors to pytorch and numpy tensors
    ** This adds readback support for logically sharded tensors
    ** For pytorch, fork new functionality into tensor.to_torch_with_logical_shape()
    ** Maintain old tensor.to_torch() behaviour which returns data based on padded shape
    *** TODO: This is incorrect and we need to fix tests that use tensor.to_torch()
    *** TODO: Then, we can deprecate the old tensor.to_torch() and use the new one
  * Update enable_borrow logic to account for padding needed for logical sharding
  * Add check for shard spec to have value for sharded memory configs
- Other infra changes/fixes:
  * Integrate tensor APIs into ttnn.from_torch() and ttnn.to_torch() to support logical sharding (see the Python sketch after this list)
  * Update create_default_alignment for ROW_MAJOR logically sharded tensors
  * Add TT_FATAL check to make sure the raw data size in bytes matches the expected device buffer size before sharding
  * Add pybind for creating logical shard spec with logical and physical shard shapes
  * Fix bug with reshape view hardcoding tile sizes
- Add test cases for logical sharding to tests/ttnn/unit_tests/tensor/test_tensor_creation.py
  * Tests tensor creation with ttnn.Tensor(...) and ttnn.from_torch(...)
  * Tests tensor readback with tensor.to_torch_with_logical_shape(...) and ttnn.to_torch(...)
  * Tests API parity between tensor and functional ttnn APIs
- Clean up tests/ttnn/unit_tests/gtests/tensor/test_sharding_with_alignment.cpp
  * Move important converters and helper functions into tensor_impl and tensor_utils
  * Convert expected/output data to TILE/ROW_MAJOR layout for ease of testing
    ** Needed since converters now tilize/untilize to match required Layout
  * Change CoreRangeSet to proper 2D grid for BLOCK sharded cases
  * Change EXPECT_EQ to ASSERT_EQ for ShardWithAlignmentTests
  * Fix formatting for test cases
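
For reference, a minimal sketch of the Python flow this commit enables, mirroring the height-sharded test case below (device setup and shapes are illustrative):

import torch
import ttnn

# Logical shard mode: the (1, 48, 56, 32) tensor flattens to a (2688, 32) plane,
# which splits evenly into 56 logical shards of [48, 32]; padding/alignment to the
# physical layout is handled internally by the encode/decode converters.
shard_spec = ttnn.ShardSpec(
    ttnn.num_cores_to_corerangeset(56, [8, 7], True),
    [48, 32],
    ttnn.ShardOrientation.ROW_MAJOR,
    False,
    ttnn.ShardMode.LOGICAL,
)
memory_config = ttnn.MemoryConfig(ttnn.TensorMemoryLayout.HEIGHT_SHARDED, ttnn.BufferType.L1, shard_spec)

device = ttnn.open_device(device_id=0)
py_tensor = torch.rand((1, 48, 56, 32), dtype=torch.bfloat16)
tt_tensor = ttnn.from_torch(
    py_tensor, ttnn.bfloat16, device=device, layout=ttnn.ROW_MAJOR_LAYOUT, memory_config=memory_config
)
# Readback queries the logical shape, so shard/alignment padding is stripped automatically.
round_tripped = tt_tensor.cpu().to_torch_with_logical_shape()
assert torch.allclose(py_tensor, round_tripped)
ttnn.close_device(device)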
TT-BrianLiu committed Dec 20, 2024
1 parent 5058b8f commit 212be31
Showing 11 changed files with 705 additions and 350 deletions.
475 changes: 187 additions & 288 deletions tests/ttnn/unit_tests/gtests/tensor/test_sharding_with_alignment.cpp

Large diffs are not rendered by default.

140 changes: 140 additions & 0 deletions tests/ttnn/unit_tests/tensor/test_tensor_creation.py
@@ -120,3 +120,143 @@ def test_tensor_creation_api_parity(shape, tt_dtype, layout, device):
    passing = passing and torch.allclose(py_tensor, py_tensor_after_round_trip_3, **allclose_kwargs)
    passing = passing and torch.allclose(py_tensor, py_tensor_after_round_trip_4, **allclose_kwargs)
    assert passing


grid_size = [8, 7]


@pytest.mark.parametrize(
    "layout, tile",
    [
        (ttnn.ROW_MAJOR_LAYOUT, None),
        (ttnn.TILE_LAYOUT, ttnn.Tile([32, 32])),
        (ttnn.TILE_LAYOUT, ttnn.Tile([16, 16])),
        (ttnn.TILE_LAYOUT, ttnn.Tile([32, 16])),
        (ttnn.TILE_LAYOUT, ttnn.Tile([16, 32])),
    ],
)
@pytest.mark.parametrize(
    "tt_dtype",
    [
        ttnn.uint8,
        ttnn.uint16,
        ttnn.uint32,
        ttnn.int32,
        ttnn.float32,
        ttnn.bfloat16,
        ttnn.bfloat8_b,
        ttnn.bfloat4_b,
    ],
)
@pytest.mark.parametrize(
    "shape, memory_config",
    [
        (
            (1, 2, 3, 4),
            ttnn.MemoryConfig(
                ttnn.TensorMemoryLayout.INTERLEAVED,
                ttnn.BufferType.DRAM,
            ),
        ),
        (
            (1, 48, 56, 32),
            ttnn.MemoryConfig(
                ttnn.TensorMemoryLayout.HEIGHT_SHARDED,
                ttnn.BufferType.L1,
                ttnn.ShardSpec(
                    ttnn.num_cores_to_corerangeset(56, grid_size, True),
                    [48, 32],
                    ttnn.ShardOrientation.ROW_MAJOR,
                    False,
                    ttnn.ShardMode.LOGICAL,
                ),
            ),
        ),
        (
            (1, 2, 10, 5),
            ttnn.MemoryConfig(
                ttnn.TensorMemoryLayout.WIDTH_SHARDED,
                ttnn.BufferType.L1,
                ttnn.ShardSpec(
                    ttnn.num_cores_to_corerangeset(3, grid_size, True),
                    [20, 2],
                    ttnn.ShardOrientation.ROW_MAJOR,
                    False,
                    ttnn.ShardMode.LOGICAL,
                ),
            ),
        ),
        (
            (2, 3, 64, 96),
            ttnn.MemoryConfig(
                ttnn.TensorMemoryLayout.BLOCK_SHARDED,
                ttnn.BufferType.L1,
                ttnn.ShardSpec(
                    ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(1, 5))}),
                    [64, 64],
                    ttnn.ShardOrientation.ROW_MAJOR,
                    False,
                    ttnn.ShardMode.LOGICAL,
                ),
            ),
        ),
        (
            (1, 8, 36, 32),
            ttnn.MemoryConfig(
                ttnn.TensorMemoryLayout.BLOCK_SHARDED,
                ttnn.BufferType.L1,
                ttnn.ShardSpec(
                    ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(3, 5))}),
                    [48, 10],
                    [64, 64],  # NOTE: This value is compatible with all PageConfigs in this sweep
                    ttnn.ShardOrientation.ROW_MAJOR,
                    False,
                ),
            ),
        ),
    ],
    ids=[
        "interleaved",
        "height_sharded",
        "width_sharded",
        "block_sharded",
        "block_sharded_with_custom_physical_shard_shape",
    ],
)
def test_tensor_creation_with_memory_config(shape, memory_config, tt_dtype, layout, tile, device):
    torch.manual_seed(0)

    if tt_dtype in (ttnn.bfloat8_b, ttnn.bfloat4_b) and layout == ttnn.ROW_MAJOR_LAYOUT:
        pytest.skip("{} is only valid for ttnn.TILE_LAYOUT!".format(tt_dtype))

    dtype = tt_dtype_to_torch_dtype[tt_dtype]

    if dtype in {torch.uint8, torch.int16, torch.int32}:
        py_tensor = torch.randint(torch.iinfo(dtype).min, torch.iinfo(dtype).max, shape, dtype=dtype)
    else:
        py_tensor = torch.rand(shape, dtype=dtype)

    tt_tensor_1 = ttnn.Tensor(py_tensor, tt_dtype, device, layout, memory_config, tile)
    tt_tensor_2 = ttnn.from_torch(
        py_tensor, tt_dtype, device=device, layout=layout, memory_config=memory_config, tile=tile
    )

    tt_tensor_1 = tt_tensor_1.cpu()
    tt_tensor_2 = tt_tensor_2.cpu()

    py_tensor_after_round_trip_1 = tt_tensor_1.to_torch_with_logical_shape()
    py_tensor_after_round_trip_2 = tt_tensor_2.to_torch_with_logical_shape()
    py_tensor_after_round_trip_3 = ttnn.to_torch(tt_tensor_1)
    py_tensor_after_round_trip_4 = ttnn.to_torch(tt_tensor_2)

    allclose_kwargs = {}
    if tt_dtype == ttnn.bfloat8_b:
        allclose_kwargs = dict(atol=1e-2)
    elif tt_dtype == ttnn.bfloat4_b:
        allclose_kwargs = dict(atol=0.2)

    passing = torch.allclose(py_tensor, py_tensor_after_round_trip_1, **allclose_kwargs)
    passing = passing and torch.allclose(py_tensor, py_tensor_after_round_trip_2, **allclose_kwargs)
    passing = passing and torch.allclose(py_tensor, py_tensor_after_round_trip_3, **allclose_kwargs)
    passing = passing and torch.allclose(py_tensor, py_tensor_after_round_trip_4, **allclose_kwargs)
    assert passing
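
With a standard tt-metal development environment and a device available, the new cases can be run directly with pytest against tests/ttnn/unit_tests/tensor/test_tensor_creation.py (e.g. filtered with -k test_tensor_creation_with_memory_config).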
124 changes: 86 additions & 38 deletions ttnn/cpp/pybind11/pytensor.cpp
@@ -68,14 +68,16 @@ void log_external_operation(

template <typename T>
Tensor create_owned_tensor(T* data_ptr, const ttnn::TensorSpec& tensor_spec) {
+    TT_FATAL(
+        !tensor_spec.memory_config().is_sharded() or tensor_spec.memory_config().shard_spec.has_value(),
+        "Sharded tensors must have a shard spec when converting to tt tensors!");
    std::size_t num_elements = tensor_spec.logical_shape().volume();
-    auto data = std::vector<T>(data_ptr, data_ptr + num_elements);
-    auto buffer = owned_buffer::create(std::move(data));
+    auto logical_data = std::vector<T>(data_ptr, data_ptr + num_elements);

-    if (tensor_spec.layout() == Layout::TILE) {
-        data = tensor_impl::convert_layout_row_major_to_tile(tensor_spec.physical_shape(), tensor_spec.tile(), buffer);
-        buffer = owned_buffer::create(std::move(data));
-    }
+    // See implementation for documentation
+    auto physical_data = tensor_impl::encode_tensor_data(logical_data, tensor_spec);
+
+    auto buffer = owned_buffer::create(std::move(physical_data));
    auto storage = OwnedStorage{std::move(buffer)};
    return Tensor(std::move(storage), tensor_spec);
}
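
For intuition, here is a rough host-side analogue of what encode_tensor_data does for a single unsharded 2D plane: pad the logical data out to the physical shape, then tilize when the layout is TILE. This is a simplification (the real converter also handles shard remapping and tile faces); the helper below is hypothetical, not the actual API.

import torch

def encode_plane_sketch(logical: torch.Tensor, physical_shape, tile=None):
    # Pad the logical (h, w) plane up to the aligned physical (H, W) plane.
    H, W = physical_shape
    physical = torch.zeros(H, W, dtype=logical.dtype)
    physical[: logical.shape[0], : logical.shape[1]] = logical
    if tile is not None:
        # Tilize: emit data tile by tile, each tile contiguous and internally row-major.
        th, tw = tile
        physical = physical.reshape(H // th, th, W // tw, tw).permute(0, 2, 1, 3)
    return physical.flatten()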
@@ -150,12 +152,15 @@ Tensor create_tt_tensor_from_py_data(
    std::size_t py_data_ptr,
    const TensorSpec& tensor_spec,
    Device* device,
-    bool force_disable_borrow,
+    const bool force_disable_borrow,
    const std::function<void()>& on_creation_callback,
    const std::function<void()>& on_destruction_callback) {
    auto layout = tensor_spec.layout();

-    const bool enable_borrow = layout == Layout::ROW_MAJOR and not force_disable_borrow;
+    const bool requires_padding = tensor_spec.logical_shape().volume() !=
+                                  tensor_spec.physical_shape().height() * tensor_spec.physical_shape().width();
+    const bool requires_tilization = layout != Layout::ROW_MAJOR;
+    const bool enable_borrow = !requires_padding and !requires_tilization and !force_disable_borrow;

    auto data_type = tensor_spec.data_type();
    std::size_t num_elements = tensor_spec.logical_shape().volume();
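
The borrow condition above reduces to "the host bytes can be aliased zero-copy": any padding or tilization forces a host-side transform, so the data must be copied instead of borrowed. A Python restatement of the predicate (names illustrative, not the actual API):

def can_borrow_host_data(logical_volume, physical_height, physical_width, is_row_major, force_disable_borrow=False):
    # Zero-copy borrowing is only safe when no host-side transform is needed:
    # no padding (logical elements exactly fill the physical plane) and no tilization.
    requires_padding = logical_volume != physical_height * physical_width
    return is_row_major and not requires_padding and not force_disable_borrow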
@@ -253,7 +258,7 @@ Tensor convert_python_tensor_to_tt_tensor(
    const std::optional<Tile>& optional_tile,
    const MemoryConfig& memory_config,
    Device* device,
-    bool force_disable_borrow = false) {
+    const bool force_disable_borrow = false) {
    GraphTracker::instance().track_function_start(
        "tt::tt_metal::detail::convert_python_tensor_to_tt_tensor",
        py_tensor,
@@ -469,47 +474,62 @@ Tensor convert_python_tensors_to_tt_tensors(

template <typename T>
owned_buffer::Buffer<T> create_row_major_owned_buffer(
-    owned_buffer::Buffer<T> owned_buffer, const ttnn::TensorSpec& tensor_spec) {
-    if (tensor_spec.layout() == Layout::TILE) {
-        auto data = tensor_impl::convert_layout_tile_to_row_major(
-            tensor_spec.physical_shape(), tensor_spec.tile(), owned_buffer);
-        return owned_buffer::create(std::move(data));
-    }
-    return owned_buffer;
+    owned_buffer::Buffer<T> owned_buffer, const ttnn::TensorSpec& tensor_spec, const bool legacy_output) {
+    TT_FATAL(
+        !tensor_spec.memory_config().is_sharded() or tensor_spec.memory_config().shard_spec.has_value(),
+        "Sharded tensors must have a shard spec when converting to tt tensors!");
+
+    if (legacy_output) {
+        if (tensor_spec.layout() == Layout::TILE) {
+            auto data = tensor_impl::convert_layout_tile_to_row_major(
+                tensor_spec.physical_shape(), tensor_spec.tile(), owned_buffer);
+            return owned_buffer::create(std::move(data));
+        }
+        return owned_buffer;
+    }
+
+    auto physical_data = owned_buffer.get();
+
+    // See implementation for documentation
+    auto logical_data = tensor_impl::decode_tensor_data(physical_data, tensor_spec);
+
+    return owned_buffer::create(std::move(logical_data));
}

-std::variant<OwnedBuffer, BorrowedBuffer> get_host_buffer_from_tensor(const Tensor& tt_tensor) {
+std::variant<OwnedBuffer, BorrowedBuffer> get_host_buffer_from_tensor(
+    const Tensor& tt_tensor, const bool legacy_output) {
    TT_ASSERT(tt_tensor.storage_type() == StorageType::OWNED or tt_tensor.storage_type() == StorageType::BORROWED);

    using RetType = std::variant<OwnedBuffer, BorrowedBuffer>;
    return std::visit(
        tt::stl::overloaded{
-            [&tt_tensor](const OwnedStorage& storage) -> RetType {
+            [&tt_tensor, legacy_output](const OwnedStorage& storage) -> RetType {
                const auto& tensor_spec = tt_tensor.get_tensor_spec();
                const auto tt_dtype = tensor_spec.data_type();
                switch (tt_dtype) {
                    case DataType::UINT8: {
                        return create_row_major_owned_buffer(
-                            owned_buffer::get_as<uint8_t>(storage.buffer), tensor_spec);
+                            owned_buffer::get_as<uint8_t>(storage.buffer), tensor_spec, legacy_output);
                    }
                    case DataType::UINT16: {
                        return create_row_major_owned_buffer(
-                            owned_buffer::get_as<uint16_t>(storage.buffer), tensor_spec);
+                            owned_buffer::get_as<uint16_t>(storage.buffer), tensor_spec, legacy_output);
                    }
                    case DataType::INT32: {
                        return create_row_major_owned_buffer(
-                            owned_buffer::get_as<int32_t>(storage.buffer), tensor_spec);
+                            owned_buffer::get_as<int32_t>(storage.buffer), tensor_spec, legacy_output);
                    }
                    case DataType::UINT32: {
                        return create_row_major_owned_buffer(
-                            owned_buffer::get_as<uint32_t>(storage.buffer), tensor_spec);
+                            owned_buffer::get_as<uint32_t>(storage.buffer), tensor_spec, legacy_output);
                    }
                    case DataType::FLOAT32: {
-                        return create_row_major_owned_buffer(owned_buffer::get_as<float>(storage.buffer), tensor_spec);
+                        return create_row_major_owned_buffer(
+                            owned_buffer::get_as<float>(storage.buffer), tensor_spec, legacy_output);
                    }
                    case DataType::BFLOAT16: {
                        return create_row_major_owned_buffer(
-                            owned_buffer::get_as<::bfloat16>(storage.buffer), tensor_spec);
+                            owned_buffer::get_as<::bfloat16>(storage.buffer), tensor_spec, legacy_output);
                    }
                    case DataType::BFLOAT8_B:
                    case DataType::BFLOAT4_B: {
@@ -522,7 +542,7 @@ std::variant<OwnedBuffer, BorrowedBuffer> get_host_buffer_from_tensor(const Tensor& tt_tensor) {
                            : unpack_bfp4_tiles_into_float_vec(
                                  uint32_data, /*row_major_output=*/false, /*is_exp_a=*/false, tile);
                        auto input_float_buffer = owned_buffer::create<float>(std::move(float_unpacked_data));
-                        return create_row_major_owned_buffer(input_float_buffer, tensor_spec);
+                        return create_row_major_owned_buffer(input_float_buffer, tensor_spec, legacy_output);
                    }
                    default: {
                        TT_THROW("Unsupported DataType: {}", tt_dtype);
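
The readback path inverts the encoding via decode_tensor_data (used by create_row_major_owned_buffer above): untilize if needed, then strip padding back down to the logical shape. Again a hypothetical simplification for a single 2D plane, mirroring the encode sketch earlier:

import torch

def decode_plane_sketch(flat: torch.Tensor, physical_shape, logical_shape, tile=None):
    H, W = physical_shape
    if tile is not None:
        # Untilize: invert the tile-major ordering back to a row-major (H, W) plane.
        th, tw = tile
        plane = flat.reshape(H // th, W // tw, th, tw).permute(0, 2, 1, 3).reshape(H, W)
    else:
        plane = flat.reshape(H, W)
    h, w = logical_shape
    return plane[:h, :w]  # strip alignment/shard padding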
Expand All @@ -540,10 +560,20 @@ std::variant<OwnedBuffer, BorrowedBuffer> get_host_buffer_from_tensor(const Tens
tt_tensor.get_storage());
}

py::object convert_tt_tensor_to_torch_tensor(const Tensor& tt_tensor) {
GraphTracker::instance().track_function_start("tt::tt_metal::detail::convert_tt_tensor_to_torch_tensor", tt_tensor);

auto buffer = get_host_buffer_from_tensor(tt_tensor);
py::object convert_tt_tensor_to_torch_tensor(const Tensor& tt_tensor, const bool legacy_output = false) {
GraphTracker::instance().track_function_start(
"tt::tt_metal::detail::convert_tt_tensor_to_torch_tensor", tt_tensor, legacy_output);

// TODO: Remove legacy_output flag which supports old behaviour of returning tensors with padded shape.
// These cases need to be fixed:
// ROW_MAJOR tensors with padding (since ROW_MAJOR has no alignment, cannot automatically strip data unless
// padded shape is queried) Physical sharding on padded shape (unlike interleaved tensors, cannot derive an
// equivalent logical shard spec to strip out data)
// One way to clean this up is:
// 1. Update tests to use ttnn.from_torch and ttnn.to_torch
// 2. Fix usage of tensor.to_torch inside ttnn functional APIs
// 3. Deprecate old tensor.to_torch and rename tensor.to_torch_with_logical_shape back to tensor.to_torch
auto buffer = get_host_buffer_from_tensor(tt_tensor, legacy_output);

py::object torch = py::module_::import("torch");
auto frombuffer = torch.attr("frombuffer");
@@ -576,18 +606,21 @@ py::object convert_tt_tensor_to_torch_tensor(const Tensor& tt_tensor) {
        },
        buffer);

-    auto shape = tt_tensor.get_legacy_shape();
-    auto torch_shape = std::vector<std::uint32_t>(std::begin(shape), std::end(shape));
+    auto logical_shape = tt_tensor.get_logical_shape();
+    auto view = logical_shape.view();
+    std::vector<uint32_t> torch_shape(view.begin(), view.end());
    auto tensor = [&]() {
        if (tt_tensor.volume() == 0) {
            auto pytorch_empty = torch.attr("empty");
-            auto logical_shape = tt_tensor.get_logical_shape();
-            auto view = logical_shape.view();
-            std::vector<uint32_t> shape_vector(view.begin(), view.end());
-            return pytorch_empty(shape_vector, py::arg("dtype") = torch_dtype);
+            return pytorch_empty(torch_shape, py::arg("dtype") = torch_dtype);
        }
        return frombuffer(buffer, py::arg("dtype") = torch_dtype);
    }();

+    if (legacy_output) {
+        auto shape = tt_tensor.get_legacy_shape();
+        torch_shape = std::vector<std::uint32_t>(std::begin(shape), std::end(shape));
+    }
    tensor = tensor.attr("reshape")(torch_shape);
    tensor = tensor.attr("contiguous")();
    if (tt_tensor.storage_type() == StorageType::BORROWED) {
@@ -600,7 +633,7 @@ py::object convert_tt_tensor_to_torch_tensor(const Tensor& tt_tensor) {
py::object convert_tt_tensor_to_numpy_tensor(const Tensor& tt_tensor) {
    GraphTracker::instance().track_function_start("tt::tt_metal::detail::convert_tt_tensor_to_numpy_tensor", tt_tensor);

-    auto buffer = get_host_buffer_from_tensor(tt_tensor);
+    auto buffer = get_host_buffer_from_tensor(tt_tensor, false);

    py::object np = py::module_::import("numpy");
    auto frombuffer = np.attr("frombuffer");
Expand Down Expand Up @@ -633,8 +666,9 @@ py::object convert_tt_tensor_to_numpy_tensor(const Tensor& tt_tensor) {
},
buffer);

auto shape = tt_tensor.get_legacy_shape();
auto np_shape = std::vector<std::uint32_t>(std::begin(shape), std::end(shape));
auto logical_shape = tt_tensor.get_logical_shape();
auto view = logical_shape.view();
std::vector<uint32_t> np_shape(view.begin(), view.end());
auto tensor = frombuffer(buffer, py::arg("dtype") = np_dtype);
tensor = tensor.attr("reshape")(np_shape);
tensor = np.attr("ascontiguousarray")(tensor);
@@ -1545,6 +1579,20 @@ void pytensor_module(py::module& m_tensor) {
            py::return_value_policy::reference)
        .def(
            "to_torch",
+            [](const Tensor& self) -> py::object { return detail::convert_tt_tensor_to_torch_tensor(self, true); },
+            R"doc(
+                Convert tensor to torch tensor using legacy padded shape.
+
+                WARNING: Will be deprecated soon!
+
+                The tensor must be on host when calling this function.
+
+                .. code-block:: python
+
+                    data = tt_tensor.cpu().to_torch() # move TT Tensor to host and convert it to torch tensor
+            )doc")
+        .def(
+            "to_torch_with_logical_shape",
            [](const Tensor& self) -> py::object { return detail::convert_tt_tensor_to_torch_tensor(self); },
            R"doc(
                Convert tensor to torch tensor.
@@ -1553,7 +1601,7 @@ void pytensor_module(py::module& m_tensor) {

                .. code-block:: python

-                    data = tt_tensor.cpu().to_torch() # move TT Tensor to host and convert it to torch tensor
+                    data = tt_tensor.cpu().to_torch_with_logical_shape() # move TT Tensor to host and convert it to torch tensor
            )doc")
        .def(
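
For contrast, a small usage sketch of the two bindings, assuming a TILE-layout tensor with a non-tile-aligned shape (device setup illustrative):

import torch
import ttnn

device = ttnn.open_device(device_id=0)
py_tensor = torch.rand((1, 1, 30, 30), dtype=torch.bfloat16)  # not tile-aligned
tt_tensor = ttnn.from_torch(py_tensor, ttnn.bfloat16, device=device, layout=ttnn.TILE_LAYOUT).cpu()

padded = tt_tensor.to_torch()                      # legacy: padded shape, e.g. (1, 1, 32, 32)
logical = tt_tensor.to_torch_with_logical_shape()  # logical shape (1, 1, 30, 30)
assert logical.shape == py_tensor.shape
ttnn.close_device(device)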
6 changes: 6 additions & 0 deletions ttnn/cpp/pybind11/tensor.cpp
@@ -274,6 +274,12 @@ void tensor_mem_config_module(py::module& m_tensor) {
                     const bool& halo,
                     const ShardMode& shard_mode) { return ShardSpec(core_sets, shard_shape, shard_orientation, halo, shard_mode); }),
            py::arg("grid"), py::arg("shard_shape"), py::arg("shard_orientation"), py::arg("halo"), py::arg("shard_mode") = ShardMode::PHYSICAL)
+        .def(py::init<>([](const CoreRangeSet& core_sets,
+                           const std::array<uint32_t, 2>& shard_shape,
+                           const std::array<uint32_t, 2>& physical_shard_shape,
+                           const ShardOrientation& shard_orientation,
+                           const bool& halo) { return ShardSpec(core_sets, shard_shape, physical_shard_shape, shard_orientation, halo); }),
+            py::arg("grid"), py::arg("shard_shape"), py::arg("physical_shard_shape"), py::arg("shard_orientation"), py::arg("halo"))
        .def_readwrite("shape", &ShardSpec::shape, "Shape of shard.")
        .def_readwrite("grid", &ShardSpec::grid, "Grid to layout shards.")
        .def_readwrite("orientation", &ShardSpec::orientation, "Orientation of cores to read shards")
(Diffs for the remaining changed files are not shown.)