From c6a61fc44815967109be4dd079cbaebaba1de22d Mon Sep 17 00:00:00 2001
From: Sean Nijjar
Date: Fri, 25 Oct 2024 20:59:41 +0000
Subject: [PATCH] #0: resolve multi-link line reduce scatter PCC issues

- Also added some reduce-scatter tests that do in-flight reshards.
---
 .../ccl/test_reduce_scatter_TG_nightly.py     |   6 +-
 .../ccl/test_reduce_scatter_post_commit.py    | 179 +++++++++++++++++-
 .../host/reduce_scatter_full_worker_grid.cpp  |   4 +-
 .../host/reduce_scatter_common.cpp            |  22 ++-
 4 files changed, 189 insertions(+), 22 deletions(-)

diff --git a/tests/ttnn/unit_tests/operations/ccl/test_reduce_scatter_TG_nightly.py b/tests/ttnn/unit_tests/operations/ccl/test_reduce_scatter_TG_nightly.py
index 2cbe8f5aa29..9e9fbf479f5 100644
--- a/tests/ttnn/unit_tests/operations/ccl/test_reduce_scatter_TG_nightly.py
+++ b/tests/ttnn/unit_tests/operations/ccl/test_reduce_scatter_TG_nightly.py
@@ -208,7 +208,7 @@ def run_line_reduce_scatter_on_TG_with_mesh_tensor_along_rows(
 @pytest.mark.parametrize(
     "num_devices, num_links, per_chip_output_shape, dim, layout",
     [
-        (4, 1, [1, 4, 32, 2304], 1, ttnn.TILE_LAYOUT),
+        (4, 2, [1, 4, 32, 2304], 1, ttnn.TILE_LAYOUT),
     ],
 )
 @pytest.mark.parametrize(
@@ -270,8 +270,8 @@ def test_line_reduce_scatter_on_TG_rows_post_commit(
 @pytest.mark.parametrize(
     "num_devices, num_links, per_chip_output_shape, dim, layout",
     [
-        (8, 1, [1, 8, 32, 1280], 1, ttnn.TILE_LAYOUT),
-        (8, 1, [8, 1, 32, 1280], 0, ttnn.TILE_LAYOUT),
+        (8, 2, [1, 8, 32, 1280], 1, ttnn.TILE_LAYOUT),
+        (8, 2, [8, 1, 32, 1280], 0, ttnn.TILE_LAYOUT),
     ],
 )
 @pytest.mark.parametrize(
diff --git a/tests/ttnn/unit_tests/operations/ccl/test_reduce_scatter_post_commit.py b/tests/ttnn/unit_tests/operations/ccl/test_reduce_scatter_post_commit.py
index 3d052a1a7a1..9fbc710ed7c 100644
--- a/tests/ttnn/unit_tests/operations/ccl/test_reduce_scatter_post_commit.py
+++ b/tests/ttnn/unit_tests/operations/ccl/test_reduce_scatter_post_commit.py
@@ -11,9 +11,6 @@
 
 
 def is_unsupported_case(input_shape, scatter_dim, math_op, mem_config, num_devices, num_links, input_dtype, layout):
-    if scatter_dim != 3:
-        return True, "Only support for scatter_dim=3 is tested so far"
-
     elem_size = 2 if input_dtype == ttnn.bfloat16 else 1
     tensor_size_bytes = elem_size
     for i in input_shape:
@@ -322,6 +319,69 @@ def test_line_reduce_scatter_post_commit(
     )
 
 
+# ~2:45 extra time in the current state
+@skip_for_grayskull("Requires eth connected devices to run")
+@pytest.mark.timeout(120)
+@pytest.mark.parametrize(
+    "num_devices, num_links",
+    [
+        (4, 2),
+    ],
+)
+@pytest.mark.parametrize(
+    "per_chip_output_shape, scatter_dim, layout",
+    [
+        ([1, 1, 32, 1280], 1, ttnn.TILE_LAYOUT),
+        ([1, 1, 32, 1024], 1, ttnn.TILE_LAYOUT),
+    ],
+)
+@pytest.mark.parametrize(
+    "input_dtype",
+    [
+        ttnn.bfloat16,
+    ],
+)
+@pytest.mark.parametrize(
+    "mem_config",
+    [
+        ttnn.MemoryConfig(buffer_type=ttnn.BufferType.DRAM),
+    ],
+)
+@pytest.mark.parametrize("math_op", [ttnn.ReduceType.Sum])
+@pytest.mark.parametrize("enable_async", [True])
+def test_line_reduce_scatter_post_commit_4chip(
+    pcie_mesh_device,
+    num_devices,
+    per_chip_output_shape,
+    scatter_dim,
+    num_links,
+    math_op,
+    input_dtype,
+    layout,
+    mem_config,
+    use_program_cache,
+    function_level_defaults,
+    enable_async,
+    num_iters=1,
+):
+    run_reduce_scatter_test(
+        pcie_mesh_device,
+        num_devices,
+        per_chip_output_shape,
+        scatter_dim,
+        num_links,
+        math_op,
+        input_dtype,
+        layout,
+        mem_config,
+        use_program_cache,
+        function_level_defaults,
+        num_iters=num_iters,
+        enable_async=enable_async,
+        topology=ttnn.Topology.Linear,
+    )
+
+
 def run_reduce_scatter_sharded_test(
     t3k_mesh_device,
     num_devices,
     per_chip_output_shape,
     output_shard_shape,
     dim,
     num_links,
     math_op,
     shard_grid,
     orientation,
     input_dtype,
     tensor_layout,
     tensor_mem_layout,
     use_program_cache,
     function_level_defaults,
+    in_shard_override=None,
+    in_shard_grid_override=None,
+    topology=ttnn.Topology.Ring,
     enable_async=True,
     num_iters=1,
     n_worker=None,
@@ -355,15 +418,23 @@ def run_reduce_scatter_sharded_test(
     t3k_mesh_device.enable_async(enable_async)
 
     # Generate input tensors
-    input_shard_shape = list(output_shard_shape)
-    if scatter_dim == 3:
-        input_shard_shape[1] *= num_devices
+    if in_shard_grid_override is None:
+        assert in_shard_override is None
+        in_shard_grid = shard_grid
+        input_shard_shape = list(output_shard_shape)
+        if scatter_dim == 3:
+            input_shard_shape[1] *= num_devices
+        else:
+            input_shard_shape[0] *= num_devices
     else:
-        input_shard_shape[0] *= num_devices
+        assert in_shard_override is not None
+        input_shard_shape = list(in_shard_override)
+        in_shard_grid = in_shard_grid_override
+
 
     tt_input_tensors = []
     input_shard_spec = ttnn.ShardSpec(
-        shard_grid,
+        in_shard_grid,
         tuple(input_shard_shape),
         orientation,
         False,
@@ -421,6 +492,7 @@
             math_op=math_op,
             num_links=num_links,
             memory_config=output_mem_config,
+            topology=topology,
         )
 
     for device_id in t3k_mesh_device.get_device_ids():
@@ -544,6 +616,97 @@ def test_width_sharded_reduce_scatter_post_commit(
     )
 
 
+@skip_for_grayskull("Requires eth connected devices to run")
+@pytest.mark.timeout(120)
+@pytest.mark.parametrize(
+    "num_devices, num_links",
+    [
+        (4, 2),
+    ],
+)
+@pytest.mark.parametrize("dim", [3])
+@pytest.mark.parametrize(
+    "tensor_mem_layout",
+    [
+        ttnn.TensorMemoryLayout.WIDTH_SHARDED,
+    ],
+)
+@pytest.mark.parametrize("tensor_layout", [ttnn.TILE_LAYOUT])
+@pytest.mark.parametrize("orientation", [ttnn.ShardOrientation.ROW_MAJOR])
+@pytest.mark.parametrize(
+    "input_dtype",
+    [
+        ttnn.bfloat16,
+        ttnn.bfloat8_b,
+    ],
+)
+@pytest.mark.parametrize(
+    "per_chip_output_shape,output_shard_shape,shard_grid,in_shard_override,in_shard_grid_override",
+    (
+        # LLama
+        (
+            (1, 1, 32, 1280),
+            (32, 128),
+            ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(4, 1))}),
+            (32, 160),
+            ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(7, 4))}),
+        ),
+        (
+            (1, 1, 32, 1280),
+            (32, 128),
+            ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(4, 1))}),
+            None,
+            None,
+        ),
+    ),
+)
+@pytest.mark.parametrize("topology", [ttnn.Topology.Ring, ttnn.Topology.Linear])
+@pytest.mark.parametrize("math_op", [ttnn.ReduceType.Sum])
+@pytest.mark.parametrize("enable_async", [True])
+def test_width_sharded_reduce_scatter_post_commit_4chip(
+    pcie_mesh_device,
+    num_devices,
+    per_chip_output_shape,
+    output_shard_shape,
+    dim,
+    num_links,
+    math_op,
+    topology,
+    shard_grid,
+    orientation,
+    input_dtype,
+    tensor_layout,
+    tensor_mem_layout,
+    use_program_cache,
+    function_level_defaults,
+    in_shard_override,
+    in_shard_grid_override,
+    enable_async,
+    num_iters=1,
+):
+    run_reduce_scatter_sharded_test(
+        pcie_mesh_device,
+        num_devices,
+        per_chip_output_shape,
+        output_shard_shape,
+        dim,
+        num_links,
+        math_op,
+        shard_grid,
+        orientation,
+        input_dtype,
+        tensor_layout,
+        tensor_mem_layout,
+        use_program_cache=use_program_cache,
+        function_level_defaults=function_level_defaults,
+        in_shard_override=in_shard_override,
+        in_shard_grid_override=in_shard_grid_override,
+        topology=topology,
+        enable_async=enable_async,
+        num_iters=num_iters,
+    )
+
+
 @skip_for_grayskull("Requires eth connected devices to run")
 @pytest.mark.skip("Hangs")
 @pytest.mark.timeout(120)
diff --git a/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/host/reduce_scatter_full_worker_grid.cpp b/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/host/reduce_scatter_full_worker_grid.cpp
index 5db23aa52b8..108c21414e3 100644
--- a/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/host/reduce_scatter_full_worker_grid.cpp
+++ b/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/host/reduce_scatter_full_worker_grid.cpp
@@ -339,7 +339,7 @@ static std::pair> select_worker_cores_
     auto const& lower_half_of_cores =
         CoreRangeSet(CoreRange(CoreCoord(0, 0), CoreCoord(workers_per_direction - 1, num_links - 1)));
     auto const& upper_half_of_cores = CoreRangeSet(
-        CoreRange(CoreCoord(workers_per_direction, 0), CoreCoord(num_edm_channels - 1, num_links - 1)));
+        CoreRange(CoreCoord(0, num_links), CoreCoord(workers_per_direction - 1, (2 * num_links) - 1)));
     if (topology_config.ring_index == 0) {
         log_trace(tt::LogOp, "Start of line, putting CCL send cores in lower half");
         return {upper_half_of_cores, lower_half_of_cores};
@@ -650,7 +650,7 @@ operation::ProgramWithCallbacks reduce_scatter_with_workers(
     std::function is_worker_in_clockwise_direction_fn =
         [is_linear, enable_bidirectional, num_edm_channels_per_link](std::size_t x) {
             static constexpr std::size_t bidirectional_directions = 2;
-            return is_linear ? (x < (num_edm_channels_per_link / bidirectional_directions)):
+            return is_linear ? ((x % num_edm_channels_per_link) < (num_edm_channels_per_link / bidirectional_directions)):
                 enable_bidirectional ? (x % bidirectional_directions == 0) : true;
         };
 
diff --git a/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/host/reduce_scatter_common.cpp b/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/host/reduce_scatter_common.cpp
index 1fedbae5584..160dffb2d66 100644
--- a/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/host/reduce_scatter_common.cpp
+++ b/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/host/reduce_scatter_common.cpp
@@ -94,18 +94,20 @@ std::vector build_worker_attributes(
         worker_receiver_semaphore_id :
         worker_receiver_semaphore_id_second_core_range;
 
+    std::array worker_slice_index = {0, 0};
+
     for (std::size_t l = 0; l < num_links; l++) {
         for (std::size_t i = 0; i < workers_per_slice; i++) {
             auto worker_id = get_global_worker_id(l, i, num_channels_per_link);
             TT_ASSERT(worker_cores_idx < worker_cores_list.size());
-
+            auto direction = is_buffer_in_clockwise_direction_fn(worker_id) ? Direction::CLOCKWISE : Direction::COUNTER_CLOCKWISE;
             worker_attributes.push_back(
                 {
                     l,
                     i,
-                    i,
-                    is_buffer_in_clockwise_direction_fn(worker_id) ? Direction::CLOCKWISE : Direction::COUNTER_CLOCKWISE,
-                    first_workers_list[worker_cores_idx],
+                    worker_slice_index[static_cast(direction)]++,
+                    direction,
+                    first_workers_list.at(worker_cores_idx),
                     first_send_to_edm_sem_id,
                     first_read_from_edm_sem_id
                 }
@@ -119,14 +121,15 @@ std::vector build_worker_attributes(
                 TT_ASSERT(second_vec_index < second_workers_list.value().size());
                 std::size_t my_logical_index = workers_per_slice + i;
                 std::size_t my_idx = worker_attributes.size();
+                auto direction = is_buffer_in_clockwise_direction_fn(my_logical_index) ?
+                    Direction::CLOCKWISE : Direction::COUNTER_CLOCKWISE;
                 worker_attributes.push_back(
                     {
                         l,
                         my_logical_index,
-                        i,
-                        is_buffer_in_clockwise_direction_fn(my_logical_index) ?
-                            Direction::CLOCKWISE : Direction::COUNTER_CLOCKWISE,
-                        second_workers_list.value()[second_vec_index],
+                        worker_slice_index[static_cast(direction)]++,
+                        direction,
+                        second_workers_list.value().at(second_vec_index),
                         second_send_to_edm_sem_id,
                         second_read_from_edm_sem_id
                     }
@@ -150,9 +153,10 @@ std::vector build_worker_attributes(
     // Log worker attributes
     log_trace(tt::LogOp, "Worker Attributes:");
     for (const auto &wa : worker_attributes) {
-        log_trace(tt::LogOp, "\tAttributes: link={}, index={}, core_logical=(x={},y={}), direction={}, associated_core=(x={},y={}), associated_index={}",
+        log_trace(tt::LogOp, "\tAttributes: link={}, chan_index={}, slice_index: {}, core_logical=(x={},y={}), direction={}, associated_core=(x={},y={}), associated_index={}",
             wa.link,
             wa.channel,
+            wa.index_in_slice,
             wa.location_logical.x,
             wa.location_logical.y,
             wa.direction == Direction::CLOCKWISE ? "CLOCKWISE": "COUNTER-CLOCKWISE",