#0: resolve multi-link line reduce scatter PCC issues
- Also added some reduce-scatter tests that do in-flight reshards.
SeanNijjar committed Oct 26, 2024
1 parent 561e2fd commit c6a61fc
Showing 4 changed files with 189 additions and 22 deletions.
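
Context for the "in-flight reshards" note above: the sharded reduce-scatter test helper (file 2 below) gains in_shard_override and in_shard_grid_override arguments. When they are unset, the input shard shape is simply the output shard shape scaled by num_devices along the scatter dimension; when they are set, the input arrives with its own shard shape and grid, so the op has to reshard while scattering. A minimal pure-Python sketch of that derivation (illustrative only; the helper name and example values below are not from the repo):

def derive_input_shard_shape(output_shard_shape, scatter_dim, num_devices, in_shard_override=None):
    # Mirrors the logic added to run_reduce_scatter_sharded_test below.
    if in_shard_override is not None:
        # Caller supplies the input sharding explicitly -> exercises the in-flight reshard path.
        return list(in_shard_override)
    input_shard_shape = list(output_shard_shape)
    if scatter_dim == 3:
        input_shard_shape[1] *= num_devices  # widen each shard along the width
    else:
        input_shard_shape[0] *= num_devices  # otherwise grow the shard height
    return input_shard_shape

print(derive_input_shard_shape((32, 128), scatter_dim=3, num_devices=4))  # [32, 512]
print(derive_input_shard_shape((32, 128), scatter_dim=3, num_devices=4, in_shard_override=(32, 160)))  # [32, 160]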
File 1 of 4: TG line reduce-scatter pytest
@@ -208,7 +208,7 @@ def run_line_reduce_scatter_on_TG_with_mesh_tensor_along_rows(
@pytest.mark.parametrize(
"num_devices, num_links, per_chip_output_shape, dim, layout",
[
(4, 1, [1, 4, 32, 2304], 1, ttnn.TILE_LAYOUT),
(4, 2, [1, 4, 32, 2304], 1, ttnn.TILE_LAYOUT),
],
)
@pytest.mark.parametrize(
@@ -270,8 +270,8 @@ def test_line_reduce_scatter_on_TG_rows_post_commit(
@pytest.mark.parametrize(
"num_devices, num_links, per_chip_output_shape, dim, layout",
[
(8, 1, [1, 8, 32, 1280], 1, ttnn.TILE_LAYOUT),
(8, 1, [8, 1, 32, 1280], 0, ttnn.TILE_LAYOUT),
(8, 2, [1, 8, 32, 1280], 1, ttnn.TILE_LAYOUT),
(8, 2, [8, 1, 32, 1280], 0, ttnn.TILE_LAYOUT),
],
)
@pytest.mark.parametrize(
File 2 of 4: reduce-scatter post-commit pytest (unsharded and sharded variants)
@@ -11,9 +11,6 @@


def is_unsupported_case(input_shape, scatter_dim, math_op, mem_config, num_devices, num_links, input_dtype, layout):
if scatter_dim != 3:
return True, "Only support for scatter_dim=3 is tested so far"

elem_size = 2 if input_dtype == ttnn.bfloat16 else 1
tensor_size_bytes = elem_size
for i in input_shape:
@@ -322,6 +319,69 @@ def test_line_reduce_scatter_post_commit(
)


# ~2:45 extra time in the current state
@skip_for_grayskull("Requires eth connected devices to run")
@pytest.mark.timeout(120)
@pytest.mark.parametrize(
"num_devices, num_links",
[
(4, 2),
],
)
@pytest.mark.parametrize(
"per_chip_output_shape, scatter_dim, layout",
[
([1, 1, 32, 1280], 1, ttnn.TILE_LAYOUT),
([1, 1, 32, 1024], 1, ttnn.TILE_LAYOUT),
],
)
@pytest.mark.parametrize(
"input_dtype",
[
ttnn.bfloat16,
],
)
@pytest.mark.parametrize(
"mem_config",
[
ttnn.MemoryConfig(buffer_type=ttnn.BufferType.DRAM),
],
)
@pytest.mark.parametrize("math_op", [ttnn.ReduceType.Sum])
@pytest.mark.parametrize("enable_async", [True])
def test_line_reduce_scatter_post_commit_4chip(
pcie_mesh_device,
num_devices,
per_chip_output_shape,
scatter_dim,
num_links,
math_op,
input_dtype,
layout,
mem_config,
use_program_cache,
function_level_defaults,
enable_async,
num_iters=1,
):
run_reduce_scatter_test(
pcie_mesh_device,
num_devices,
per_chip_output_shape,
scatter_dim,
num_links,
math_op,
input_dtype,
layout,
mem_config,
use_program_cache,
function_level_defaults,
num_iters=num_iters,
enable_async=enable_async,
topology=ttnn.Topology.Linear,
)
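
The new 4-chip linear-topology case above checks the standard reduce-scatter contract: each chip ends up with the element-wise sum of every chip's input, restricted to its own 1/num_devices slice along the scatter dim. A framework-agnostic reference sketch (numpy used purely for illustration; shapes follow the parametrization above and assume the per-chip input is the per-chip output shape scaled along the scatter dim):

import numpy as np

num_devices, scatter_dim = 4, 1
per_chip_output_shape = (1, 1, 32, 1280)
per_chip_input_shape = list(per_chip_output_shape)
per_chip_input_shape[scatter_dim] *= num_devices  # (1, 4, 32, 1280)

chip_inputs = [np.random.rand(*per_chip_input_shape).astype(np.float32) for _ in range(num_devices)]
reduced = np.sum(chip_inputs, axis=0)                        # element-wise sum across chips
expected = np.split(reduced, num_devices, axis=scatter_dim)  # chip i keeps slice i
assert expected[0].shape == per_chip_output_shape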


def run_reduce_scatter_sharded_test(
t3k_mesh_device,
num_devices,
@@ -337,6 +397,9 @@ def run_reduce_scatter_sharded_test(
tensor_mem_layout,
use_program_cache,
function_level_defaults,
in_shard_override=None,
in_shard_grid_override=None,
topology=ttnn.Topology.Ring,
enable_async=True,
num_iters=1,
n_worker=None,
@@ -355,15 +418,23 @@ def run_reduce_scatter_sharded_test(
t3k_mesh_device.enable_async(enable_async)

# Generate input tensors
input_shard_shape = list(output_shard_shape)
if scatter_dim == 3:
input_shard_shape[1] *= num_devices
if in_shard_grid_override is None:
assert in_shard_override is None
in_shard_grid = shard_grid
input_shard_shape = list(output_shard_shape)
if scatter_dim == 3:
input_shard_shape[1] *= num_devices
else:
input_shard_shape[0] *= num_devices
else:
input_shard_shape[0] *= num_devices
assert in_shard_override is not None
input_shard_shape = list(in_shard_override)
in_shard_grid = in_shard_grid_override

tt_input_tensors = []

input_shard_spec = ttnn.ShardSpec(
shard_grid,
in_shard_grid,
tuple(input_shard_shape),
orientation,
False,
@@ -421,6 +492,7 @@ def run_reduce_scatter_sharded_test(
math_op=math_op,
num_links=num_links,
memory_config=output_mem_config,
topology=topology,
)

for device_id in t3k_mesh_device.get_device_ids():
@@ -544,6 +616,97 @@ def test_width_sharded_reduce_scatter_post_commit(
)


@skip_for_grayskull("Requires eth connected devices to run")
@pytest.mark.timeout(120)
@pytest.mark.parametrize(
"num_devices, num_links",
[
(4, 2),
],
)
@pytest.mark.parametrize("dim", [3])
@pytest.mark.parametrize(
"tensor_mem_layout",
[
ttnn.TensorMemoryLayout.WIDTH_SHARDED,
],
)
@pytest.mark.parametrize("tensor_layout", [ttnn.TILE_LAYOUT])
@pytest.mark.parametrize("orientation", [ttnn.ShardOrientation.ROW_MAJOR])
@pytest.mark.parametrize(
"input_dtype",
[
ttnn.bfloat16,
ttnn.bfloat8_b,
],
)
@pytest.mark.parametrize(
"per_chip_output_shape,output_shard_shape,shard_grid,in_shard_override,in_shard_grid_override",
(
# LLama
(
(1, 1, 32, 1280),
(32, 128),
ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(4, 1))}),
(32, 160),
ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(7, 4))}),
),
(
(1, 1, 32, 1280),
(32, 128),
ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(4, 1))}),
None,
None,
),
),
)
@pytest.mark.parametrize("topology", [ttnn.Topology.Ring, ttnn.Topology.Linear])
@pytest.mark.parametrize("math_op", [ttnn.ReduceType.Sum])
@pytest.mark.parametrize("enable_async", [True])
def test_width_sharded_reduce_scatter_post_commit_4chip(
pcie_mesh_device,
num_devices,
per_chip_output_shape,
output_shard_shape,
dim,
num_links,
math_op,
topology,
shard_grid,
orientation,
input_dtype,
tensor_layout,
tensor_mem_layout,
use_program_cache,
function_level_defaults,
in_shard_override,
in_shard_grid_override,
enable_async,
num_iters=1,
):
run_reduce_scatter_sharded_test(
pcie_mesh_device,
num_devices,
per_chip_output_shape,
output_shard_shape,
dim,
num_links,
math_op,
shard_grid,
orientation,
input_dtype,
tensor_layout,
tensor_mem_layout,
use_program_cache=use_program_cache,
function_level_defaults=function_level_defaults,
in_shard_override=in_shard_override,
in_shard_grid_override=in_shard_grid_override,
topology=topology,
enable_async=enable_async,
num_iters=num_iters,
)


@skip_for_grayskull("Requires eth connected devices to run")
@pytest.mark.skip("Hangs")
@pytest.mark.timeout(120)
File 3 of 4: C++ reduce-scatter host program (worker core selection and direction assignment)
@@ -339,7 +339,7 @@ static std::pair<CoreRangeSet, std::optional<CoreRangeSet>> select_worker_cores_
auto const& lower_half_of_cores =
CoreRangeSet(CoreRange(CoreCoord(0, 0), CoreCoord(workers_per_direction - 1, num_links - 1)));
auto const& upper_half_of_cores = CoreRangeSet(
CoreRange(CoreCoord(workers_per_direction, 0), CoreCoord(num_edm_channels - 1, num_links - 1)));
CoreRange(CoreCoord(0, num_links), CoreCoord(workers_per_direction - 1, (2 * num_links) - 1)));
if (topology_config.ring_index == 0) {
log_trace(tt::LogOp, "Start of line, putting CCL send cores in lower half");
return {upper_half_of_cores, lower_half_of_cores};
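
The core-range change above relocates the second half of the worker cores: previously they extended to the right of the first half (extra columns on the same num_links rows), while now they occupy the next num_links rows on the same columns. A small worked illustration in Python with assumed values (workers_per_direction = 4, num_links = 2, and num_edm_channels taken as 2 * workers_per_direction):

workers_per_direction, num_links = 4, 2  # assumed values for illustration
lower_half = [(x, y) for y in range(num_links) for x in range(workers_per_direction)]
old_upper_half = [(x, y) for y in range(num_links) for x in range(workers_per_direction, 2 * workers_per_direction)]
new_upper_half = [(x, y) for y in range(num_links, 2 * num_links) for x in range(workers_per_direction)]
print(lower_half)      # (0,0)..(3,1)
print(old_upper_half)  # (4,0)..(7,1), i.e. spilled into extra columns
print(new_upper_half)  # (0,2)..(3,3), i.e. extra rows instead
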
@@ -650,7 +650,7 @@ operation::ProgramWithCallbacks reduce_scatter_with_workers(

std::function<bool(uint32_t)> is_worker_in_clockwise_direction_fn = [is_linear, enable_bidirectional, num_edm_channels_per_link](std::size_t x) {
static constexpr std::size_t bidirectional_directions = 2;
return is_linear ? (x < (num_edm_channels_per_link / bidirectional_directions)):
return is_linear ? ((x % num_edm_channels_per_link) < (num_edm_channels_per_link / bidirectional_directions)):
enable_bidirectional ? (x % bidirectional_directions == 0) : true;
};
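
The direction-predicate fix above matters once channel ids span more than one link: ids appear to be assigned globally across links (see get_global_worker_id in the worker-attribute builder below), so for the second link they start at num_edm_channels_per_link and the old x < half test classified all of them as counter-clockwise. Taking the id modulo the per-link channel count restores the clockwise/counter-clockwise split on every link. A worked illustration in Python with an assumed per-link channel count:

num_edm_channels_per_link = 4  # assumed for illustration
half = num_edm_channels_per_link // 2
for x in range(2 * num_edm_channels_per_link):           # two links' worth of global channel ids
    old_clockwise = x < half                             # pre-fix predicate (linear topology)
    new_clockwise = (x % num_edm_channels_per_link) < half
    print(x, old_clockwise, new_clockwise)
# old: only ids 0 and 1 come out clockwise; new: 0, 1 and 4, 5 (the first half of each link).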

File 4 of 4: C++ CCL worker-attribute builder
@@ -94,18 +94,20 @@ std::vector<WorkerAttributes> build_worker_attributes(
worker_receiver_semaphore_id :
worker_receiver_semaphore_id_second_core_range;

std::array<size_t, 2> worker_slice_index = {0, 0};

for (std::size_t l = 0; l < num_links; l++) {
for (std::size_t i = 0; i < workers_per_slice; i++) {
auto worker_id = get_global_worker_id(l, i, num_channels_per_link);
TT_ASSERT(worker_cores_idx < worker_cores_list.size());

auto direction = is_buffer_in_clockwise_direction_fn(worker_id) ? Direction::CLOCKWISE : Direction::COUNTER_CLOCKWISE;
worker_attributes.push_back(
{
l,
i,
i,
is_buffer_in_clockwise_direction_fn(worker_id) ? Direction::CLOCKWISE : Direction::COUNTER_CLOCKWISE,
first_workers_list[worker_cores_idx],
worker_slice_index[static_cast<size_t>(direction)]++,
direction,
first_workers_list.at(worker_cores_idx),
first_send_to_edm_sem_id,
first_read_from_edm_sem_id
}
@@ -119,14 +121,15 @@ std::vector<WorkerAttributes> build_worker_attributes(
TT_ASSERT(second_vec_index < second_workers_list.value().size());
std::size_t my_logical_index = workers_per_slice + i;
std::size_t my_idx = worker_attributes.size();
auto direction = is_buffer_in_clockwise_direction_fn(my_logical_index) ?
Direction::CLOCKWISE : Direction::COUNTER_CLOCKWISE;
worker_attributes.push_back(
{
l,
my_logical_index,
i,
is_buffer_in_clockwise_direction_fn(my_logical_index) ?
Direction::CLOCKWISE : Direction::COUNTER_CLOCKWISE,
second_workers_list.value()[second_vec_index],
worker_slice_index[static_cast<size_t>(direction)]++,
direction,
second_workers_list.value().at(second_vec_index),
second_send_to_edm_sem_id,
second_read_from_edm_sem_id
}
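
The worker-attribute change above swaps the raw per-link channel index for a running per-direction counter (worker_slice_index), so workers that share a direction get consecutive slice indices across all links instead of restarting at 0 on every link. A simplified Python sketch with assumed counts (get_global_worker_id is taken here to be link * channels_per_link + channel, and only the first core range's loop is shown):

CLOCKWISE, COUNTER_CLOCKWISE = 0, 1
num_links, workers_per_slice, channels_per_link = 2, 2, 4  # assumed counts

def direction(worker_id):
    # mirrors the fixed is_worker_in_clockwise_direction_fn for a linear topology
    return CLOCKWISE if (worker_id % channels_per_link) < channels_per_link // 2 else COUNTER_CLOCKWISE

slice_index = [0, 0]  # one counter per direction, as with the C++ std::array<size_t, 2>
for link in range(num_links):
    for chan in range(workers_per_slice):
        worker_id = link * channels_per_link + chan
        d = direction(worker_id)
        print(f"link={link} chan={chan} dir={d} index_in_slice={slice_index[d]}")
        slice_index[d] += 1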
@@ -150,9 +153,10 @@ std::vector<WorkerAttributes> build_worker_attributes(
// Log worker attributes
log_trace(tt::LogOp, "Worker Attributes:");
for (const auto &wa : worker_attributes) {
log_trace(tt::LogOp, "\tAttributes: link={}, index={}, core_logical=(x={},y={}), direction={}, associated_core=(x={},y={}), associated_index={}",
log_trace(tt::LogOp, "\tAttributes: link={}, chan_index={}, slice_index: {}, core_logical=(x={},y={}), direction={}, associated_core=(x={},y={}), associated_index={}",
wa.link,
wa.channel,
wa.index_in_slice,
wa.location_logical.x,
wa.location_logical.y,
wa.direction == Direction::CLOCKWISE ? "CLOCKWISE": "COUNTER-CLOCKWISE",