From 4f565bd1e9076579ebe406b3303a1f88c4bf0ec0 Mon Sep 17 00:00:00 2001
From: Naif Tarafdar <135640067+ntarafdar@users.noreply.github.com>
Date: Thu, 14 Nov 2024 16:05:56 -0800
Subject: [PATCH] #12184: Alignment fix for BH on I2S and S2I (fix after revert) (#15055)

### Ticket
[Link to Github Issue](https://github.com/tenstorrent/tt-metal/issues/12184#event-15046053642)

### Problem description
Moving data from DRAM to L1 on Blackhole (BH) hit alignment failures because much of the alignment handling was hardcoded for the Wormhole (WH) case.

### What's changed
Added logic to handle alignment on the interleaved-to-sharded (I2S) and sharded-to-interleaved (S2I) paths.

### Checklist
- [x] Post commit CI passes (https://github.com/tenstorrent/tt-metal/actions/runs/11846373612)
- [ ] Blackhole Post commit (if applicable)
- [ ] Model regression CI testing passes (if applicable)
- [ ] Device performance regression CI testing passes (if applicable)
- [ ] New/Existing tests provide coverage for changes
---
 .../unit_testing/misc/test_sharded.py         | 14 ++-
 tests/ttnn/unit_tests/operations/test_core.py | 94 +++++++++++++++++++
 ttnn/cpp/ttnn/operations/core/core.cpp        | 39 +++++++-
 ...interleaved_to_sharded_program_factory.cpp | 48 +++++++---
 .../device/sharded_to_interleaved_op.cpp      |  3 +-
 ...sharded_to_interleaved_program_factory.cpp | 13 ++-
 6 files changed, 185 insertions(+), 26 deletions(-)

diff --git a/tests/tt_eager/python_api_testing/unit_testing/misc/test_sharded.py b/tests/tt_eager/python_api_testing/unit_testing/misc/test_sharded.py
index b3e41058c67..5df2d752340 100644
--- a/tests/tt_eager/python_api_testing/unit_testing/misc/test_sharded.py
+++ b/tests/tt_eager/python_api_testing/unit_testing/misc/test_sharded.py
@@ -101,6 +101,7 @@ def test_sharded_tile(
 
 
 # TODO (7735): Switch to new interleaved_to_sharded with sharded_mem_config input and re-enable BLOCK sharded tests
+@skip_for_blackhole("WIP")
 @pytest.mark.parametrize(
     "input_shape, shard_scheme, shard_size, num_cores",
     [
@@ -180,7 +181,7 @@ def test_sharded_rm(
     assert passing
 
 
-@skip_for_blackhole("Mismatching on BH, see #12349")
+@skip_for_blackhole("BH LLK issue with untilize, #14594")
 @pytest.mark.parametrize("H, num_cores", [[100352, 98], [25088, 98]])
 @pytest.mark.parametrize("in_sharded", [True, False])
 @pytest.mark.parametrize("out_sharded", [True, False])
@@ -256,7 +257,7 @@ def test_sharded_untilize(H, num_cores, in_sharded, out_sharded, dtype, device, function_level_defaults):
     assert passing
 
 
-@skip_for_blackhole("Mismatching on BH, see #12349")
+@skip_for_blackhole("Mismatching on BH, see #14609")
 @pytest.mark.parametrize("H, num_cores", [[25088, 98]])
 @pytest.mark.parametrize("output_dtype", [ttnn.bfloat16, ttnn.bfloat8_b])
 def test_sharded_tilize(H, num_cores, output_dtype, device, function_level_defaults):
@@ -895,6 +896,7 @@ def test_partial_sharded_op_binary(
     assert passing
 
 
+@pytest.mark.skipif(is_blackhole(), reason="BH ND hang, see issue #14745")
 @pytest.mark.parametrize("in0_sharded", [True, False], ids=["in0_sharded", "in0_unsharded"])
 @pytest.mark.parametrize("in1_sharded", [True, False], ids=["in1_sharded", "in1_unsharded"])
 @pytest.mark.parametrize("out_sharded", [True, False], ids=["out_sharded", "out_unsharded"])
@@ -1335,6 +1337,7 @@ def test_sharded_matmul_2d_transposed(
     assert passing
 
 
+@pytest.mark.skipif(is_blackhole(), reason="BH ND hang, see issue #14745")
 def test_resharded_binary_to_matmul(device, function_level_defaults):
     grid_size_binary = device.compute_with_storage_grid_size()
     num_cores_binary = 98
@@ -1426,6 +1429,7 @@ def test_resharded_binary_to_matmul(device, function_level_defaults):
     assert passing
 
 
+@pytest.mark.skipif(is_blackhole(), reason="BH ND hang, see issue #14745")
 @pytest.mark.parametrize("in_sharded", [True, False], ids=["in0_sharded", "in0_unsharded"])
 @pytest.mark.parametrize("out_sharded", [False], ids=["out_unsharded"])
 @pytest.mark.parametrize("dtype", [ttnn.bfloat16, ttnn.bfloat8_b])
@@ -1501,6 +1505,7 @@ def test_sharded_untilize_padded_shard(in_sharded, out_sharded, dtype, device, function_level_defaults):
     assert passing
 
 
+@pytest.mark.skipif(is_blackhole(), reason="BH ND hang, see issue #14745")
 @pytest.mark.parametrize("in_sharded", [True, False], ids=["in0_sharded", "in0_unsharded"])
 @pytest.mark.parametrize("out_sharded", [False], ids=["out_unsharded"])
 @pytest.mark.parametrize("activations_dtype", [ttnn.bfloat16, ttnn.bfloat8_b])
@@ -1691,6 +1696,7 @@ def test_block_sharded_untilize_with_unpadding(in_sharded, out_sharded, dtype, device, function_level_defaults):
         "unbatched_16_shape_out_interleaved",
     ],
 )
+@skip_for_blackhole("BH Issue with untilize LLK, see #14594")
 @pytest.mark.parametrize("dtype", [ttnn.bfloat16, ttnn.bfloat8_b])
 def test_width_sharded_untilize_with_unpadding(
     shape, output_H, in_sharded, out_sharded, dtype, device, function_level_defaults
@@ -1761,7 +1767,7 @@ def test_width_sharded_untilize_with_unpadding(
     assert passing
 
 
-@skip_for_blackhole("Mismatching on BH, see #12349")
+@skip_for_blackhole("BH LLK Issue with tilize, #14609")
 @pytest.mark.parametrize("input_shape", [[8, 1, 49, 2048], [1, 1, 8, 2048], [16, 1, 49, 2048], [1, 1, 16, 2048]])
 @pytest.mark.parametrize("sharding_config", [(True, True), (False, False)], ids=["both_sharded", "both_interleaved"])
 @pytest.mark.parametrize("output_dtype", [ttnn.bfloat16, ttnn.bfloat8_b])
@@ -1833,7 +1839,6 @@ def test_sharded_tilize_with_val_padding(input_shape, sharding_config, output_dtype, device, function_level_defaults):
     assert passing
 
 
-@skip_for_blackhole("Mismatching on BH, see #12349")
 @pytest.mark.parametrize("N", [8, 16])
 @pytest.mark.parametrize("in_sharded", [True], ids=["in0_sharded"])
 @pytest.mark.parametrize("out_sharded", [True], ids=["out_sharded"])
@@ -2064,6 +2069,7 @@ def test_sharded_matmul_1d_in1_wormhole(device, function_level_defaults):
     assert passing
 
 
+@pytest.mark.skipif(is_blackhole(), reason="BH ND hang, see issue #14745")
 @pytest.mark.parametrize("in0_sharded", [True, False], ids=["in0_sharded", "in0_unsharded"])
 @pytest.mark.parametrize("in1_sharded", [True, False], ids=["in1_sharded", "in1_unsharded"])
 @pytest.mark.parametrize("out_sharded", [True, False], ids=["out_sharded", "out_unsharded"])
diff --git a/tests/ttnn/unit_tests/operations/test_core.py b/tests/ttnn/unit_tests/operations/test_core.py
index 23b9d1f8459..c39154379df 100644
--- a/tests/ttnn/unit_tests/operations/test_core.py
+++ b/tests/ttnn/unit_tests/operations/test_core.py
@@ -439,3 +439,97 @@ def test_create_sharded_memory_config(device, shape, strategy, orientation, core_grid):
     passing = torch.equal(input_data, output_data)
 
     assert passing
+
+
+@pytest.mark.parametrize(
+    "shape, shard_shape, strategy, orientation, core_grid",
+    [
+        ([1, 1, 2, 16], None, ttnn.ShardStrategy.WIDTH, ttnn.ShardOrientation.ROW_MAJOR, ttnn.CoreGrid(y=1, x=1)),
+        ([1, 1, 2, 16], None, ttnn.ShardStrategy.WIDTH, ttnn.ShardOrientation.ROW_MAJOR, ttnn.CoreGrid(y=2, x=1)),
+        ([1, 1, 32, 16], None, ttnn.ShardStrategy.HEIGHT, ttnn.ShardOrientation.ROW_MAJOR, ttnn.CoreGrid(y=2, x=1)),
+        ([1, 1, 64, 16], None, ttnn.ShardStrategy.HEIGHT, ttnn.ShardOrientation.ROW_MAJOR, ttnn.CoreGrid(y=2, x=1)),
+        (
+            [1, 1, 2, 16],
+            [2, 16],
+            ttnn.ShardStrategy.HEIGHT,
+            ttnn.ShardOrientation.ROW_MAJOR,
+            ttnn.CoreRangeSet(
+                {
+                    ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(0, 0)),
+                }
+            ),
+        ),
+        (
+            [1, 1, 5280, 16],
+            [5280, 16],
+            ttnn.ShardStrategy.HEIGHT,
+            ttnn.ShardOrientation.ROW_MAJOR,
+            ttnn.CoreRangeSet(
+                {
+                    ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(0, 0)),
+                }
+            ),
+        ),
+        # TODO: Add this test back by checking for core grid size and skipping if we can't do it
+        # (
+        #     [1, 1, 675840, 16],
+        #     [5280, 16],
+        #     ttnn.ShardStrategy.HEIGHT,
+        #     ttnn.ShardOrientation.ROW_MAJOR,
+        #     ttnn.CoreRangeSet(
+        #         {
+        #             ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(11, 9)),  # 120
+        #             ttnn.CoreRange(ttnn.CoreCoord(12, 0), ttnn.CoreCoord(12, 7)),  # 8
+        #         }
+        #     ),
+        # ),
+    ],
+)
+@pytest.mark.parametrize(
+    "input_buffer_type",
+    [
+        ttnn.L1_MEMORY_CONFIG,
+        ttnn.DRAM_MEMORY_CONFIG,
+    ],
+)
+@pytest.mark.parametrize(
+    "output_buffer_type",
+    [
+        ttnn.L1_MEMORY_CONFIG,
+        ttnn.DRAM_MEMORY_CONFIG,
+    ],
+)
+def test_bh_alignment_i2s(
+    device, shape, shard_shape, strategy, orientation, core_grid, input_buffer_type, output_buffer_type
+):
+    torch.manual_seed(0)
+    input_data = torch.randn(shape, dtype=torch.bfloat16)
+    if shard_shape is None:
+        shard_config = ttnn.create_sharded_memory_config(
+            shape=shape,
+            core_grid=core_grid,
+            strategy=strategy,
+            orientation=orientation,
+            use_height_and_width_as_shard_shape=False,
+        )
+    else:
+        shard_config = ttnn.create_sharded_memory_config(
+            shape=shard_shape,
+            core_grid=core_grid,
+            strategy=strategy,
+            orientation=orientation,
+            use_height_and_width_as_shard_shape=True,
+        )
+    x_t = ttnn.from_torch(
+        input_data,
+        device=device,
+        layout=ttnn.ROW_MAJOR_LAYOUT,
+        memory_config=input_buffer_type,
+        dtype=ttnn.bfloat16,
+    )
+    x_t_sharded = ttnn.to_memory_config(x_t, shard_config)
+    x_t = ttnn.to_memory_config(x_t_sharded, output_buffer_type)
+    output_data = ttnn.from_device(x_t)
+    output_data = ttnn.to_torch(output_data)
+    passing = torch.equal(input_data, output_data)
+    assert passing
diff --git a/ttnn/cpp/ttnn/operations/core/core.cpp b/ttnn/cpp/ttnn/operations/core/core.cpp
index dba2edf328b..978f3413b48 100644
--- a/ttnn/cpp/ttnn/operations/core/core.cpp
+++ b/ttnn/cpp/ttnn/operations/core/core.cpp
@@ -11,6 +11,8 @@
 #include "ttnn/cpp/ttnn/operations/data_movement/reshape_view/reshape.hpp"
 #include "ttnn/operations/data_movement/data_transfer/data_transfer.hpp"
 #include "ttnn/distributed/types.hpp"
+#include "ttnn/operations/data_movement/sharded/sharded_to_interleaved/sharded_to_interleaved.hpp"
+#include "ttnn/operations/data_movement/sharded/interleaved_to_sharded/interleaved_to_sharded.hpp"
 
 namespace ttnn::operations::core {
 
@@ -54,12 +56,30 @@ ttnn::Tensor squeeze_from_4D(const ttnn::Tensor& tensor, const int rank) {
 }
 
 ttnn::Tensor to_device(const ttnn::Tensor& tensor, Device* device, const std::optional<MemoryConfig>& memory_config) {
-    return tensor.to(device, memory_config.value_or(ttnn::DRAM_MEMORY_CONFIG));
+    auto mem_config = memory_config.value_or(ttnn::DRAM_MEMORY_CONFIG);
+    if (mem_config.is_sharded() && (device->arch() == tt::ARCH::BLACKHOLE)) {
+        auto interleaved_tensor = tensor.to(device, ttnn::DRAM_MEMORY_CONFIG);
+        return ttnn::interleaved_to_sharded(ttnn::DefaultQueueId, interleaved_tensor, mem_config, std::nullopt);
+    }
+    else {
+        return tensor.to(device, mem_config);
+    }
+
 }
 
 ttnn::Tensor to_device(
     const ttnn::Tensor& tensor, MeshDevice* mesh_device, const std::optional<MemoryConfig>& memory_config) {
-    return tensor.to(mesh_device, memory_config.value_or(ttnn::DRAM_MEMORY_CONFIG));
+    auto mem_config = memory_config.value_or(ttnn::DRAM_MEMORY_CONFIG);
+    // Currently no direct sharded write support on BLACKHOLE due to alignment issues
+    if (mem_config.is_sharded() && (mesh_device->arch() == tt::ARCH::BLACKHOLE)) {
+        auto interleaved_tensor = tensor.to(mesh_device, ttnn::DRAM_MEMORY_CONFIG);
+        return ttnn::interleaved_to_sharded(ttnn::DefaultQueueId, interleaved_tensor, mem_config, std::nullopt);
+    }
+    else {
+        return tensor.to(mesh_device, mem_config);
+    }
+
+
 }
 
 ttnn::Tensor allocate_tensor_on_device(
@@ -86,7 +106,20 @@ void copy_host_to_device_tensor(ttnn::Tensor host_tensor, ttnn::Tensor device_tensor, uint8_t cq_id) {
     tt::tt_metal::write_tensor(host_tensor, device_tensor, cq_id);
 }
 
-ttnn::Tensor from_device(const ttnn::Tensor& tensor, bool blocking, uint8_t cq_id) { return tensor.cpu(blocking, cq_id); }
+
+ttnn::Tensor from_device(const ttnn::Tensor& tensor, bool blocking, uint8_t cq_id) {
+
+    // Currently no direct sharded read support on BLACKHOLE due to alignment issues
+    if (tensor.is_sharded() && (tensor.device()->arch() == tt::ARCH::BLACKHOLE)) {
+        auto interleaved_tensor = ttnn::sharded_to_interleaved(cq_id, tensor, ttnn::DRAM_MEMORY_CONFIG, std::nullopt);
+        return interleaved_tensor.cpu(blocking, cq_id);
+    }
+    else {
+        return tensor.cpu(blocking, cq_id);
+
+    }
+
+}
 
 void deallocate(Tensor& tensor, bool force) { tensor.deallocate(force); }
 
diff --git a/ttnn/cpp/ttnn/operations/data_movement/sharded/interleaved_to_sharded/device/interleaved_to_sharded_program_factory.cpp b/ttnn/cpp/ttnn/operations/data_movement/sharded/interleaved_to_sharded/device/interleaved_to_sharded_program_factory.cpp
index d41cadcf1d1..b07f464e4ca 100644
--- a/ttnn/cpp/ttnn/operations/data_movement/sharded/interleaved_to_sharded/device/interleaved_to_sharded_program_factory.cpp
+++ b/ttnn/cpp/ttnn/operations/data_movement/sharded/interleaved_to_sharded/device/interleaved_to_sharded_program_factory.cpp
@@ -32,6 +32,13 @@ operation::ProgramWithCallbacks interleaved_to_sharded_multi_core(
     bool rm_orientation = shard_spec.orientation == ShardOrientation::ROW_MAJOR;
     CoreCoord end_core = (*shard_spec.grid.ranges().rbegin()).end_coord;
 
+
+    bool convert_df = input_cb_data_format != output_cb_data_format;
+    auto src_buffer = input.buffer();
+    auto dst_buffer = output.buffer();
+    bool src_is_dram = src_buffer->buffer_type() == tt::tt_metal::BufferType::DRAM;
+    bool is_blackhole = (input.device()->arch() == tt::ARCH::BLACKHOLE);
+
     if (input.get_layout() == Layout::TILE) {
         num_units = input.volume() / TILE_HW;
         input_unit_size = tt::tt_metal::detail::TileSize(input_cb_data_format);
@@ -66,13 +73,6 @@ operation::ProgramWithCallbacks interleaved_to_sharded_multi_core(
         padded_offset_bytes = align(input_unit_size, input.buffer()->alignment());
     }
 
-    bool convert_df = input_cb_data_format != output_cb_data_format;
-
-    auto src_buffer = input.buffer();
-
-    auto dst_buffer = output.buffer();
-
-    bool src_is_dram = src_buffer->buffer_type() == tt::tt_metal::BufferType::DRAM ? 1 : 0;
     auto all_cores = shard_spec.grid;
 
     uint32_t input_cb_index = tt::CB::c_in0;
@@ -94,10 +94,17 @@ operation::ProgramWithCallbacks interleaved_to_sharded_multi_core(
             .set_globally_allocated_address(*output.buffer());
     auto cb_output = tt::tt_metal::CreateCircularBuffer(program, all_cores, output_cb_out_config);
     uint32_t dram_alignment = hal.get_alignment(HalMemType::DRAM);
-    if (src_is_dram && input_unit_size % dram_alignment != 0) {
-        uint32_t scratch_cb_page_size = align(input_unit_size, dram_alignment);
+    if ((src_is_dram && input_unit_size % dram_alignment != 0) || is_blackhole) {
+        uint32_t scratch_cb_page_size;
+        // Scratch buffer used to bridge DRAM (64B) and L1 (16B) alignment
+        if (is_blackhole) {
+            scratch_cb_page_size = align(input_unit_size, hal.get_alignment(HalMemType::L1));
+        }
+        else {
+            scratch_cb_page_size = align(input_unit_size, dram_alignment);
+        }
         tt::tt_metal::CircularBufferConfig scratch_cb_out_config =
-            tt::tt_metal::CircularBufferConfig(1 * scratch_cb_page_size, {{scratch_cb_index, input_cb_data_format}})
+            tt::tt_metal::CircularBufferConfig(4 * scratch_cb_page_size, {{scratch_cb_index, input_cb_data_format}})
                 .set_page_size(scratch_cb_index, scratch_cb_page_size);
         auto cb_scratch = tt::tt_metal::CreateCircularBuffer(program, all_cores, scratch_cb_out_config);
     }
@@ -236,10 +243,23 @@ operation::ProgramWithCallbacks interleaved_to_sharded_multi_core(
             }
 
             uint32_t dram_alignment = hal.get_alignment(HalMemType::DRAM);
-            bool aligned = src_is_dram ? curr_idx_w % dram_alignment == 0 : true;
+            uint32_t l1_alignment = hal.get_alignment(HalMemType::L1);
+            bool aligned = (src_is_dram ? curr_idx_w % dram_alignment == 0 : true);
+            aligned = aligned && !is_blackhole;
             uint32_t aligned_width_offset, aligned_shard_width, aligned_offset;
             if (!aligned) {
-                aligned_width_offset = tt::round_down(curr_idx_w, dram_alignment);
+                // TODO: verify this is right; the non-BH case is left unchanged for now and should be investigated
+                if (!is_blackhole) {
+                    aligned_width_offset = tt::round_down(curr_idx_w, dram_alignment);
+                }
+                else {
+                    if (src_is_dram) {
+                        aligned_width_offset = tt::round_down(curr_idx_w, dram_alignment);
+                    }
+                    else {
+                        aligned_width_offset = tt::round_down(curr_idx_w, l1_alignment);
+                    }
+                }
                 aligned_offset = curr_idx_w - aligned_width_offset;
                 aligned_shard_width = aligned_offset + shard_width;
             } else {
@@ -256,7 +276,7 @@ operation::ProgramWithCallbacks interleaved_to_sharded_multi_core(
                 num_units_per_row,
                 shard_height,
                 shard_width,
-                padded_offset_bytes,
+                (is_blackhole) ? shard_width : padded_offset_bytes,
                 static_cast<uint32_t>(aligned),
                 aligned_width_offset,
                 aligned_shard_width,
@@ -305,6 +325,4 @@ operation::ProgramWithCallbacks interleaved_to_sharded_multi_core(
 
     return {.program = std::move(program), .override_runtime_arguments_callback = override_runtime_arguments_callback};
 }
-
-
 }
diff --git a/ttnn/cpp/ttnn/operations/data_movement/sharded/sharded_to_interleaved/device/sharded_to_interleaved_op.cpp b/ttnn/cpp/ttnn/operations/data_movement/sharded/sharded_to_interleaved/device/sharded_to_interleaved_op.cpp
index 55b32e3c00a..f736258f7d6 100644
--- a/ttnn/cpp/ttnn/operations/data_movement/sharded/sharded_to_interleaved/device/sharded_to_interleaved_op.cpp
+++ b/ttnn/cpp/ttnn/operations/data_movement/sharded/sharded_to_interleaved/device/sharded_to_interleaved_op.cpp
@@ -20,9 +20,8 @@ void ShardedToInterleavedDeviceOperation::validate(const std::vector<Tensor>& input_tensors) const {
     TT_FATAL(input_tensor.memory_config().buffer_type == BufferType::L1, "Input tensor must be in L1");
     TT_FATAL(this->output_mem_config.memory_layout == TensorMemoryLayout::INTERLEAVED, "Output memory config must be Interleaved");
     if (input_tensor.get_layout() == Layout::ROW_MAJOR) {
-        uint32_t dram_alignment = hal.get_alignment(HalMemType::DRAM);
         uint32_t l1_alignment = hal.get_alignment(HalMemType::L1);
-        TT_FATAL((*input_tensor.memory_config().shard_spec).shape[1] * input_tensor.element_size() % (this->output_mem_config.buffer_type == BufferType::DRAM ? dram_alignment : l1_alignment) == 0, "Shard page size must be aligned to {}B for L1 Tensor, or {}B for DRAM tensor", l1_alignment, dram_alignment);
+        TT_FATAL((*input_tensor.memory_config().shard_spec).shape[1] * input_tensor.element_size() % (l1_alignment) == 0, "Shard page size must be aligned to {}B for L1 Tensor", l1_alignment);
     }
     if (input_tensor.get_dtype() != this->output_dtype) {
         TT_FATAL(input_tensor.get_layout() == Layout::TILE, "If diff output type, tensor must be TILED");
diff --git a/ttnn/cpp/ttnn/operations/data_movement/sharded/sharded_to_interleaved/device/sharded_to_interleaved_program_factory.cpp b/ttnn/cpp/ttnn/operations/data_movement/sharded/sharded_to_interleaved/device/sharded_to_interleaved_program_factory.cpp
index 6d585e65a13..2cb58883bf1 100644
--- a/ttnn/cpp/ttnn/operations/data_movement/sharded/sharded_to_interleaved/device/sharded_to_interleaved_program_factory.cpp
+++ b/ttnn/cpp/ttnn/operations/data_movement/sharded/sharded_to_interleaved/device/sharded_to_interleaved_program_factory.cpp
@@ -98,6 +98,7 @@ operation::ProgramWithCallbacks sharded_to_interleaved_multi_core(
             tt_metal::ReaderDataMovementConfig(reader_compile_time_args));
 
     bool dst_is_dram = dst_buffer->buffer_type() == tt_metal::BufferType::DRAM ? 1 : 0;
+    bool is_blackhole = (input.device()->arch() == tt::ARCH::BLACKHOLE);
 
     tt_metal::KernelHandle unary_writer_kernel_id;
     if (input.get_layout() == Layout::TILE) {
@@ -141,7 +142,8 @@ operation::ProgramWithCallbacks sharded_to_interleaved_multi_core(
 
     uint32_t curr_idx_w = 0;
     const auto cores = corerange_to_cores(all_cores, std::nullopt, rm_orientation);
-    uint32_t padded_shard_width = align(output_unit_size, dst_buffer->alignment());
+    uint32_t padded_offset_bytes;
+
     for (const auto& core : cores) {
         if (input.get_layout() == Layout::TILE) {
             uint32_t shard_height = num_units_per_shard_height;
@@ -217,6 +219,13 @@ operation::ProgramWithCallbacks sharded_to_interleaved_multi_core(
                 }
             }
         }
+        uint32_t dram_alignment = hal.get_alignment(HalMemType::DRAM);
+        uint32_t l1_alignment = hal.get_alignment(HalMemType::L1);
+        uint32_t padded_shard_width = align(output_unit_size, dst_buffer->alignment());
+        if (is_blackhole) {
+            if (!dst_is_dram)
+                padded_shard_width = align(output_unit_size, l1_alignment);
+        }
         tt_metal::SetRuntimeArgs(
             program,
             unary_writer_kernel_id,
@@ -225,7 +234,7 @@ operation::ProgramWithCallbacks sharded_to_interleaved_multi_core(
                 num_units_per_row,
                 shard_height,
                 shard_width,
-                padded_shard_width,
+                (is_blackhole) ? shard_width : padded_shard_width,
                 curr_idx_w,
                 curr_idx_h});
             curr_idx_w += output_unit_size;
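
---

Note for reviewers: both program factories apply the same round-down pattern. When a shard's byte offset is not legal for the stricter memory (64B for DRAM vs. 16B for L1 on Blackhole), the read address is rounded down to the alignment boundary and the residual is handed to the kernel so it can skip the leading bytes; that is why `aligned_shard_width`, not `shard_width`, goes into the reader's runtime args on the unaligned path. Below is a minimal standalone sketch of that arithmetic; the `round_down`/`round_up` helpers are hypothetical stand-ins for `tt::round_down` and `align`, and the constants stand in for the `hal.get_alignment` queries (values are illustrative only).

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical stand-ins for tt::round_down / align.
constexpr uint32_t round_down(uint32_t v, uint32_t alignment) { return v - (v % alignment); }
constexpr uint32_t round_up(uint32_t v, uint32_t alignment) { return round_down(v + alignment - 1, alignment); }

int main() {
    constexpr uint32_t dram_alignment = 64;  // assumed Blackhole DRAM alignment
    constexpr uint32_t l1_alignment = 16;    // assumed Blackhole L1 alignment

    uint32_t curr_idx_w = 200;  // shard's byte offset into the row: not 64B-aligned
    uint32_t shard_width = 96;  // bytes this core must copy

    // Round the read address down to a legal boundary and carry the residual.
    uint32_t aligned_width_offset = round_down(curr_idx_w, dram_alignment);  // 192
    uint32_t aligned_offset = curr_idx_w - aligned_width_offset;             // 8: leading bytes to skip
    uint32_t aligned_shard_width = aligned_offset + shard_width;             // 104: bytes actually read

    assert(aligned_width_offset % dram_alignment == 0);
    // The widened transfer still covers every byte of the shard.
    assert(aligned_width_offset + aligned_shard_width >= curr_idx_w + shard_width);

    // The scratch CB that stages unaligned reads is sized the same way:
    // pad the page size up to the alignment of the memory being targeted.
    uint32_t input_unit_size = 40;
    uint32_t scratch_cb_page_size = round_up(input_unit_size, l1_alignment);  // 48
    assert(scratch_cb_page_size % l1_alignment == 0);
    return 0;
}
```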