From dbba7ca0f92af09cf79cbdf4ec0134e35390afa3 Mon Sep 17 00:00:00 2001
From: Evan Smal
Date: Wed, 11 Dec 2024 16:03:27 +0000
Subject: [PATCH] Handle padded shards in `ttnn.convert_to_chw`

---
 .../operations/test_convert_to_chw.py         | 55 ++++++++++++++++---
 .../device/convert_to_chw_op.cpp              |  8 ++-
 .../device/convert_to_chw_program_factory.cpp | 39 +++++++++----
 3 files changed, 80 insertions(+), 22 deletions(-)

diff --git a/tests/ttnn/unit_tests/operations/test_convert_to_chw.py b/tests/ttnn/unit_tests/operations/test_convert_to_chw.py
index 4fca6b024e7d..d3880ab5dd5f 100644
--- a/tests/ttnn/unit_tests/operations/test_convert_to_chw.py
+++ b/tests/ttnn/unit_tests/operations/test_convert_to_chw.py
@@ -39,24 +39,61 @@ def test_convert_to_chw(device, C, HW, core_grid):
         [1, 1, 32, HW], core_grid, ttnn.ShardStrategy.WIDTH, ttnn.ShardOrientation.ROW_MAJOR
     )
     actual = ttnn.experimental.convert_to_chw(input_tensor, memory_config=output_memory_config)
-    actual = ttnn.to_torch(actual)
-    assert_with_pcc(expected, actual, 0.9999999)
+    assert_with_pcc(expected, ttnn.to_torch(actual), 1.0)
+
+    return actual
+
+
+@skip_for_grayskull()
+@skip_for_blackhole()
+@pytest.mark.parametrize("C", [1, 2, 4])
+@pytest.mark.parametrize(
+    "HW, core_grid, padded_sharded_dim",
+    (
+        (96, ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(1, 0))}), 64),
+        (1056 * 160, ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(7, 7))}), 2656),
+    ),
+)
+def test_convert_to_chw_padded(device, C, HW, core_grid, padded_sharded_dim):
+    input_tensor = torch.randn([1, 1, HW, C], dtype=torch.bfloat16)
+    expected = input_tensor.transpose(2, 3)
+
+    input_shard_shape = (padded_sharded_dim, 32)
+    input_shard_spec = ttnn.ShardSpec(core_grid, input_shard_shape, ttnn.ShardOrientation.ROW_MAJOR, False)
+    input_mem_config = ttnn.MemoryConfig(ttnn.TensorMemoryLayout.HEIGHT_SHARDED, ttnn.BufferType.L1, input_shard_spec)
+
+    output_shard_shape = (C, padded_sharded_dim)
+    output_shard_spec = ttnn.ShardSpec(core_grid, output_shard_shape, ttnn.ShardOrientation.ROW_MAJOR, False)
+    output_mem_config = ttnn.MemoryConfig(ttnn.TensorMemoryLayout.WIDTH_SHARDED, ttnn.BufferType.L1, output_shard_spec)
+
+    input_tensor = ttnn.from_torch(input_tensor, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT)
+    input_tensor = ttnn.to_device(input_tensor, device, memory_config=input_mem_config)
+
+    actual = ttnn.experimental.convert_to_chw(input_tensor, memory_config=output_mem_config)
+
+    assert_with_pcc(expected, ttnn.to_torch(actual), 1.0)
+
+    return actual
 
 
 @skip_for_grayskull()
 @skip_for_blackhole()
 def test_convert_to_chw_with_program_cache(device, use_program_cache):
-    C, HW = 8, 128
-    core_grid = ttnn.CoreGrid(x=2, y=1)
+    C, HW, core_grid = 2, 256, ttnn.CoreGrid(x=2, y=1)
+
+    C_padded, HW_padded, padded_sharded_dim = 4, 96, 64
+    core_grid_padded = ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(1, 0))})
 
-    for _ in range(2):
-        test_convert_to_chw(device, C, HW, core_grid)
-        test_convert_to_chw(device, C, HW, core_grid)
-        dummy_shape = [1, 1, 128, 128]
+    a, b, c = None, None, None
+    for _ in range(8):
+        a = test_convert_to_chw_padded(device, C_padded, HW_padded, core_grid_padded, padded_sharded_dim)
+        b = test_convert_to_chw(device, C, HW, core_grid)
+        c = test_convert_to_chw_padded(device, C_padded, HW_padded, core_grid_padded, padded_sharded_dim)
+        dummy_shape = [1, 1, 256, 128]
         py_dummy_tensor = torch.randn(dummy_shape)
         tt_dummy_tensor = (
             ttnn.Tensor(py_dummy_tensor, ttnn.bfloat16).to(ttnn.TILE_LAYOUT).to(device, ttnn.L1_MEMORY_CONFIG)
         )
-    assert device.num_program_cache_entries() == 1
+    assert device.num_program_cache_entries() == 2

diff --git a/ttnn/cpp/ttnn/operations/experimental/cnn/convert_to_chw/device/convert_to_chw_op.cpp b/ttnn/cpp/ttnn/operations/experimental/cnn/convert_to_chw/device/convert_to_chw_op.cpp
index fc4322cb9652..05bba0aeeb9b 100644
--- a/ttnn/cpp/ttnn/operations/experimental/cnn/convert_to_chw/device/convert_to_chw_op.cpp
+++ b/ttnn/cpp/ttnn/operations/experimental/cnn/convert_to_chw/device/convert_to_chw_op.cpp
@@ -23,10 +23,16 @@ void ConvertToCHW::validate(const std::vector<Tensor>& input_tensors) const {
     TT_FATAL(C <= TILE_HEIGHT, "C must be less than or equal to 32 (was {})", C);
     TT_FATAL(HW % TILE_HEIGHT == 0, "HW must be divisible by tile size");
 
+    TT_FATAL(input.is_sharded(), "Input tensor must be sharded");
+
+    const auto& input_shard_spec = input.memory_config().shard_spec.value();
+    TT_FATAL(
+        input_shard_spec.shape[0] % TILE_HEIGHT == 0,
+        "Shard height must be divisible by tile size");  // input shards can be padded so HW may not match shard height
+
     TT_FATAL(
         this->memory_config.is_sharded() && this->memory_config.memory_layout == TensorMemoryLayout::WIDTH_SHARDED,
         "Output tensor must be width sharded");
-    // TODO: Check that grids match
 }
 
 std::vector ConvertToCHW::compute_output_shapes(
diff --git a/ttnn/cpp/ttnn/operations/experimental/cnn/convert_to_chw/device/convert_to_chw_program_factory.cpp b/ttnn/cpp/ttnn/operations/experimental/cnn/convert_to_chw/device/convert_to_chw_program_factory.cpp
index fdd27e1e0779..8485f7fe7f1e 100644
--- a/ttnn/cpp/ttnn/operations/experimental/cnn/convert_to_chw/device/convert_to_chw_program_factory.cpp
+++ b/ttnn/cpp/ttnn/operations/experimental/cnn/convert_to_chw/device/convert_to_chw_program_factory.cpp
@@ -20,20 +20,29 @@ operation::ProgramWithCallbacks multi_core_convert_to_chw(
     const auto HW = input_shape[2];
     const auto C = input_shape[3];
 
+    tt::log_debug(tt::LogType::LogOp, "Running op with HW={}, C={}, shard_shape={}", HW, C, a.shard_spec()->shape);
+
     TT_ASSERT(C < TILE_HEIGHT, "C must be 32 or smaller");
-    const uint32_t c_tiles = 1;  // assume C <= 32
-    const uint32_t hw_tiles = HW / TILE_HEIGHT;
-    const uint32_t total_tiles = hw_tiles * c_tiles;
-    const uint32_t total_tiles_per_core = hw_tiles * c_tiles / input_cores.size();
+    const uint32_t total_tiles = HW / TILE_HEIGHT;  // assume C < 32
+    const uint32_t total_tiles_per_core = tt::div_up(total_tiles, input_cores.size());
+
+    tt::log_debug(
+        tt::LogType::LogOp, "Processing {} tiles per core ({} total tiles)", total_tiles_per_core, total_tiles);
 
     const auto create_circular_buffer = [&program, &input_core_grid](
                                             uint32_t index,
-                                            uint32_t num_tiles,
-                                            uint32_t tile_size,
+                                            uint32_t total_size,
+                                            uint32_t page_size,
                                             const tt::DataFormat& format,
                                             Buffer* buffer = nullptr) -> tt::tt_metal::CBHandle {
-        auto config = CircularBufferConfig(num_tiles * tile_size, {{index, format}}).set_page_size(index, tile_size);
+        tt::log_debug(
+            tt::LogType::LogOp,
+            "Creating CB at index {} with total size {} B and page size {} B",
+            index,
+            total_size,
+            page_size);
+        auto config = CircularBufferConfig(total_size, {{index, format}}).set_page_size(index, page_size);
         if (buffer != nullptr) {
             config = config.set_globally_allocated_address(*buffer);
         }
@@ -47,16 +56,22 @@ operation::ProgramWithCallbacks multi_core_convert_to_chw(
     const uint32_t intermediary_tile_size = tt::tt_metal::detail::TileSize(intermediary_format);
 
     const uint32_t cb_in_id = tt::CB::c_in0;
-    const auto cb_in =
-        create_circular_buffer(cb_in_id, total_tiles_per_core, input_tile_size, input_format, a.buffer());
+    const uint32_t cb_in_total_size = total_tiles_per_core * input_tile_size;
+    const uint32_t cb_in_page_size = input_tile_size;
+    const auto cb_in = create_circular_buffer(cb_in_id, cb_in_total_size, cb_in_page_size, input_format, a.buffer());
 
     const uint32_t cb_out_id = tt::CB::c_out0;
+    const uint32_t element_size = tt::datum_size(input_format);
+    const uint32_t cb_out_total_size = C * HW * element_size / input_cores.size();
+    const uint32_t cb_out_page_size = HW / input_cores.size();
     const auto cb_out =
-        create_circular_buffer(cb_out_id, total_tiles_per_core, input_tile_size, input_format, output.buffer());
+        create_circular_buffer(cb_out_id, cb_out_total_size, cb_out_page_size, input_format, output.buffer());
 
     const uint32_t cb_in_transpose_id = tt::CB::c_intermed0;
-    const auto cb_in_transpose =
-        create_circular_buffer(cb_in_transpose_id, 1, intermediary_tile_size, intermediary_format);
+    const uint32_t cb_in_transpose_total_size = intermediary_tile_size;
+    const uint32_t cb_in_transpose_page_size = intermediary_tile_size;
+    const auto cb_in_transpose = create_circular_buffer(
+        cb_in_transpose_id, cb_in_transpose_total_size, cb_in_transpose_page_size, intermediary_format);
 
     std::vector<uint32_t> reader_compile_time_args = {cb_in_id};
     std::vector<uint32_t> writer_compile_time_args = {cb_in_transpose_id, cb_out_id, C};
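
A minimal sketch (not part of the patch) of the shard-padding arithmetic behind the switch from truncating division to tt::div_up above, using the first test_convert_to_chw_padded case (HW=96, two cores, padded shard height 64). The div_up helper is a Python stand-in for tt::div_up, and the variable names are illustrative only:

TILE_HEIGHT = 32


def div_up(a, b):
    # Python stand-in for tt::div_up in the program factory
    return (a + b - 1) // b


HW, num_cores, padded_sharded_dim = 96, 2, 64

total_tiles = HW // TILE_HEIGHT  # 3 tiles of real data
tiles_per_core = div_up(total_tiles, num_cores)  # 2; plain division gives 3 // 2 == 1 and drops a tile

# Each shard holds a whole number of tiles, so the second shard carries
# padding: 2 tiles * 32 rows == 64 rows, matching padded_sharded_dim.
assert tiles_per_core * TILE_HEIGHT == padded_sharded_dim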