From 0326577b41bfdf59274ad1449fbffa4d5d76deac Mon Sep 17 00:00:00 2001 From: Che-Yu Wu Date: Tue, 22 Oct 2024 08:06:25 +0800 Subject: [PATCH] #10548: Support tile layout for width/height-sharded concat (#13744) Co-authored-by: Artem Yerofieiev <169092593+ayerofieiev-tt@users.noreply.github.com> --- .../ttnn/unit_tests/operations/test_concat.py | 40 ++++++++++++- .../concat/device/concat_device_operation.cpp | 5 +- .../concat/device/concat_program_factory.cpp | 56 ++++++++++--------- ...oncat.cpp => reader_s2s_tensor_concat.cpp} | 0 4 files changed, 72 insertions(+), 29 deletions(-) rename ttnn/cpp/ttnn/operations/data_movement/concat/device/kernels/dataflow/{reader_s2s_rm_tensor_concat.cpp => reader_s2s_tensor_concat.cpp} (100%) diff --git a/tests/ttnn/unit_tests/operations/test_concat.py b/tests/ttnn/unit_tests/operations/test_concat.py index ca04c0512d5..fed420e300e 100644 --- a/tests/ttnn/unit_tests/operations/test_concat.py +++ b/tests/ttnn/unit_tests/operations/test_concat.py @@ -34,13 +34,14 @@ def test_concat(device, height, width, dim, async_mode): @pytest.mark.parametrize( - "inputs, output_shard_shape, shard_grid, strategy, cache_mode", + "inputs, output_shard_shape, shard_grid, strategy, layout, cache_mode", ( ( [((1, 1, 160, 32), (80, 32)), ((1, 1, 160, 32), (80, 32))], (80, 64), ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(0, 1))}), ttnn.ShardStrategy.HEIGHT, + ttnn.ROW_MAJOR_LAYOUT, False, ), ( @@ -48,6 +49,7 @@ def test_concat(device, height, width, dim, async_mode): (80, 48), ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(0, 1))}), ttnn.ShardStrategy.HEIGHT, + ttnn.ROW_MAJOR_LAYOUT, False, ), ( @@ -60,6 +62,7 @@ def test_concat(device, height, width, dim, async_mode): } ), ttnn.ShardStrategy.HEIGHT, + ttnn.ROW_MAJOR_LAYOUT, False, ), pytest.param( @@ -72,6 +75,7 @@ def test_concat(device, height, width, dim, async_mode): } ), ttnn.ShardStrategy.HEIGHT, + ttnn.ROW_MAJOR_LAYOUT, True, marks=pytest.mark.xfail(reason="two tensors concat kernel doesn't work with program cache (#13466)"), ), @@ -80,6 +84,7 @@ def test_concat(device, height, width, dim, async_mode): (8, 48), ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(0, 1))}), ttnn.ShardStrategy.HEIGHT, + ttnn.ROW_MAJOR_LAYOUT, False, ), ( @@ -87,6 +92,7 @@ def test_concat(device, height, width, dim, async_mode): (8, 48), ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(0, 1))}), ttnn.ShardStrategy.HEIGHT, + ttnn.ROW_MAJOR_LAYOUT, True, ), ( @@ -94,6 +100,7 @@ def test_concat(device, height, width, dim, async_mode): (38, 16), ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(0, 3))}), ttnn.ShardStrategy.WIDTH, + ttnn.ROW_MAJOR_LAYOUT, False, ), ( @@ -101,12 +108,39 @@ def test_concat(device, height, width, dim, async_mode): (38, 16), ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(0, 3))}), ttnn.ShardStrategy.WIDTH, + ttnn.ROW_MAJOR_LAYOUT, True, ), + ( + [((1, 1, 256, 96), (64, 96)), ((1, 1, 256, 64), (64, 64)), ((1, 1, 256, 32), (64, 32))], + (64, 192), + ttnn.CoreRangeSet( + { + ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(0, 1)), + ttnn.CoreRange(ttnn.CoreCoord(1, 0), ttnn.CoreCoord(2, 0)), + } + ), + ttnn.ShardStrategy.HEIGHT, + ttnn.TILE_LAYOUT, + False, + ), + ( + [((1, 1, 32, 512), (32, 64)), ((1, 1, 64, 512), (64, 64)), ((1, 1, 96, 512), (96, 64))], + (192, 64), + ttnn.CoreRangeSet( + { + ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(0, 3)), + ttnn.CoreRange(ttnn.CoreCoord(1, 0), ttnn.CoreCoord(2, 1)), + } + ), + ttnn.ShardStrategy.WIDTH, + ttnn.TILE_LAYOUT, + False, + ), ), ) @pytest.mark.parametrize("async_mode", [True, False], ids=["async_on", "async_off"]) -def test_sharded_concat(device, inputs, output_shard_shape, shard_grid, strategy, cache_mode, async_mode): +def test_sharded_concat(device, inputs, output_shard_shape, shard_grid, strategy, layout, cache_mode, async_mode): device.enable_async(async_mode) if cache_mode: device.enable_program_cache() @@ -124,7 +158,7 @@ def _gen_inputs(input_specs): use_height_and_width_as_shard_shape=True, ) torch_input_tensor = torch.rand(shape, dtype=torch.bfloat16) - input_tensor = ttnn.from_torch(torch_input_tensor, layout=ttnn.ROW_MAJOR_LAYOUT, device=device) + input_tensor = ttnn.from_torch(torch_input_tensor, layout=layout, device=device) input_tensor = ttnn.to_memory_config(input_tensor, input_sharded_memory_config) input_tensors.append((torch_input_tensor, input_tensor)) return input_tensors diff --git a/ttnn/cpp/ttnn/operations/data_movement/concat/device/concat_device_operation.cpp b/ttnn/cpp/ttnn/operations/data_movement/concat/device/concat_device_operation.cpp index d703db7714a..53785605aa0 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/concat/device/concat_device_operation.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/concat/device/concat_device_operation.cpp @@ -53,7 +53,10 @@ void ConcatDeviceOperation::validate(const std::vector &input_tensors) c } TT_FATAL(in_ref.is_sharded() == shard_first, "All tensors must be sharded or all must be interleaved"); if (shard_first) { - TT_FATAL(in_ref.get_layout() == Layout::ROW_MAJOR, "Only row major supported for sharded concat."); + // TODO(jerrysky3): Remove this when we replace the two tensors concat kernel with the general one. + TT_FATAL( + input_tensors.size() > 2 || in_ref.get_layout() == Layout::ROW_MAJOR, + "Only row major supported for sharded two tensors concat."); TT_FATAL(in_ref.shard_spec().has_value(), "Sharded tensors must have a shard spec."); TT_FATAL( in_ref.shard_spec().value().grid == first_input.shard_spec().value().grid, diff --git a/ttnn/cpp/ttnn/operations/data_movement/concat/device/concat_program_factory.cpp b/ttnn/cpp/ttnn/operations/data_movement/concat/device/concat_program_factory.cpp index bd9bfec09b3..af79fddf7d9 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/concat/device/concat_program_factory.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/concat/device/concat_program_factory.cpp @@ -154,7 +154,7 @@ operation::ProgramWithCallbacks s2s_rm_concat_two_tensors_multi_core( return {.program = std::move(program), .override_runtime_arguments_callback = override_runtime_arguments_callback}; } -// Concat sharded tensors into sharded output in row-major layout. Currently it only supports height-sharded width +// Concat sharded tensors into sharded output in row-major/tile layout. Currently it only supports height-sharded width // concat or width-sharded height concat. // // It is done by copying each row of each input sharded tensor to the right offset in the sharded output tensor based on @@ -164,9 +164,9 @@ operation::ProgramWithCallbacks s2s_rm_concat_two_tensors_multi_core( // output. The memory address gap between neighbor input rows is exactly the output width. In height concat, all input // rows are placed at column 0 but sequential rows in the output. The address gap between neighbor input rows is still // the output width (which is equal to the input width). -operation::ProgramWithCallbacks s2s_rm_concat_multi_core( +operation::ProgramWithCallbacks s2s_concat_multi_core( const std::vector &input_tensors, uint32_t dim, Tensor &output) { - TT_FATAL(dim == 2 || dim == 3, "Sharded concat RM only supports dim=2 or 3"); + TT_FATAL(dim == 2 || dim == 3, "Sharded concat only supports dim=2 or 3"); const bool is_height_concat = dim == 2; tt_metal::Program program = tt_metal::CreateProgram(); @@ -175,26 +175,37 @@ operation::ProgramWithCallbacks s2s_rm_concat_multi_core( const uint32_t num_input_tensors = input_tensors.size(); const uint32_t cb_dst_id = 16; TT_FATAL(num_input_tensors <= cb_dst_id, "Not enough circular buffer for {} inputs.", num_input_tensors); + const tt::DataFormat cb_data_format = tt_metal::datatype_to_dataformat_converter(output.get_dtype()); + const bool rm_layout = output.get_layout() == Layout::ROW_MAJOR; // Assume inputs and output have the same element size and alignment. const uint32_t element_size = input_tensors[0].element_size(); const uint32_t alignment = input_tensors[0].buffer()->alignment(); - std::vector all_stick_sizes; - all_stick_sizes.push_back(output.shard_spec().value().shape[1]); - std::transform( - input_tensors.begin(), input_tensors.end(), std::back_inserter(all_stick_sizes), [](const Tensor &tensor) { - return tensor.element_size() * tensor.shard_spec().value().shape[1]; - }); - const uint32_t page_size = find_greatest_common_page_size(all_stick_sizes, alignment); - const uint32_t elements_per_page = page_size / element_size; + uint32_t page_size; + uint32_t elements_per_page_width; + uint32_t elements_per_page_height; + if (rm_layout) { + std::vector all_stick_sizes; + all_stick_sizes.push_back(output.shard_spec().value().shape[1]); + std::transform( + input_tensors.begin(), input_tensors.end(), std::back_inserter(all_stick_sizes), [](const Tensor &tensor) { + return tensor.element_size() * tensor.shard_spec().value().shape[1]; + }); + page_size = find_greatest_common_page_size(all_stick_sizes, alignment); + elements_per_page_width = page_size / element_size; + elements_per_page_height = 1; + } else { + page_size = tt_metal::detail::TileSize(cb_data_format); + elements_per_page_width = TILE_WIDTH; + elements_per_page_height = TILE_HEIGHT; + } vector cb_inputs(num_input_tensors); vector input_num_pages_per_stick(num_input_tensors); vector input_num_sticks(num_input_tensors); vector input_write_offsets(num_input_tensors); - const tt::DataFormat cb_data_format = tt_metal::datatype_to_dataformat_converter(output.get_dtype()); // Assume inputs and output have the same sharding grid. const auto all_cores = input_tensors[0].shard_spec().value().grid; @@ -202,10 +213,8 @@ operation::ProgramWithCallbacks s2s_rm_concat_multi_core( uint32_t curr_input_write_offset = 0; for (uint32_t input_id = 0; input_id < num_input_tensors; input_id++) { const auto shard_spec = input_tensors[input_id].shard_spec().value(); - const uint32_t shard_height = shard_spec.shape[0]; - const uint32_t shard_width = shard_spec.shape[1]; - input_num_pages_per_stick[input_id] = div_up(shard_width, elements_per_page); - input_num_sticks[input_id] = shard_height; + input_num_pages_per_stick[input_id] = div_up(shard_spec.shape[1], elements_per_page_width); + input_num_sticks[input_id] = div_up(shard_spec.shape[0], elements_per_page_height); input_write_offsets[input_id] = curr_input_write_offset; const uint32_t input_num_pages = input_num_pages_per_stick[input_id] * input_num_sticks[input_id]; @@ -221,12 +230,11 @@ operation::ProgramWithCallbacks s2s_rm_concat_multi_core( // Output CB const auto output_shard_spec = output.shard_spec().value(); - const uint32_t output_shard_height = output_shard_spec.shape[0]; - const uint32_t output_shard_width = output_shard_spec.shape[1]; - const uint32_t output_num_pages_per_stick = div_up(output_shard_width, elements_per_page); + const uint32_t output_num_pages_per_stick = div_up(output_shard_spec.shape[1], elements_per_page_width); + const uint32_t output_num_sticks = div_up(output_shard_spec.shape[0], elements_per_page_height); const tt_metal::CircularBufferConfig output_cb_config = tt_metal::CircularBufferConfig( - page_size * output_shard_height * output_num_pages_per_stick, {{cb_dst_id, cb_data_format}}) + page_size * output_num_sticks * output_num_pages_per_stick, {{cb_dst_id, cb_data_format}}) .set_page_size(cb_dst_id, page_size) .set_globally_allocated_address(*output.buffer()); auto cb_output = tt_metal::CreateCircularBuffer(program, all_cores, output_cb_config); @@ -250,15 +258,13 @@ operation::ProgramWithCallbacks s2s_rm_concat_multi_core( tt_metal::KernelHandle unary_reader_kernel_id = tt_metal::CreateKernel( program, - "ttnn/cpp/ttnn/operations/data_movement/concat/device/kernels/dataflow/" - "reader_s2s_rm_tensor_concat.cpp", + "ttnn/cpp/ttnn/operations/data_movement/concat/device/kernels/dataflow/reader_s2s_tensor_concat.cpp", all_cores, tt_metal::ReaderDataMovementConfig(compile_time_args)); tt_metal::KernelHandle unary_writer_kernel_id = tt_metal::CreateKernel( program, - "ttnn/cpp/ttnn/operations/data_movement/concat/device/kernels/dataflow/" - "reader_s2s_rm_tensor_concat.cpp", + "ttnn/cpp/ttnn/operations/data_movement/concat/device/kernels/dataflow/reader_s2s_tensor_concat.cpp", all_cores, tt_metal::WriterDataMovementConfig(compile_time_args)); @@ -422,7 +428,7 @@ operation::ProgramWithCallbacks sharded_concat_multi_core( // case. return s2s_rm_concat_two_tensors_multi_core(input_tensors, dim, output); } else { - return s2s_rm_concat_multi_core(input_tensors, dim, output); + return s2s_concat_multi_core(input_tensors, dim, output); } } else { return s2i_rm_concat_multi_core(input_tensors, dim, output); diff --git a/ttnn/cpp/ttnn/operations/data_movement/concat/device/kernels/dataflow/reader_s2s_rm_tensor_concat.cpp b/ttnn/cpp/ttnn/operations/data_movement/concat/device/kernels/dataflow/reader_s2s_tensor_concat.cpp similarity index 100% rename from ttnn/cpp/ttnn/operations/data_movement/concat/device/kernels/dataflow/reader_s2s_rm_tensor_concat.cpp rename to ttnn/cpp/ttnn/operations/data_movement/concat/device/kernels/dataflow/reader_s2s_tensor_concat.cpp