#0: Remove hardcoded grid width in all_gather and skip test_sharded_matmul test when the device grid size is too small
tt-aho committed Dec 26, 2024
1 parent c0c1bb6 commit 9879b37
Showing 2 changed files with 23 additions and 10 deletions.
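
On the test side, the fix applies the usual pytest guard: query the compute-with-storage grid that the device actually exposes and skip the test when the requested core grid does not fit. A minimal standalone sketch of that guard is below; the helper name is illustrative, while the ttnn and pytest calls are the ones used in the diff that follows.

import pytest
import ttnn

def skip_if_core_grid_unavailable(mesh_device, core_grid: ttnn.CoreGrid) -> None:
    # Worker-core grid actually available on this device.
    compute_grid_size = mesh_device.compute_with_storage_grid_size()
    # Skip rather than fail when the requested core grid does not fit.
    if (compute_grid_size.x < core_grid.x) or (compute_grid_size.y < core_grid.y):
        pytest.skip("Test requires larger grid size")

In the diff itself the check is inlined, and the ttnn.CoreGrid(y=4, x=8) value is bound once to a core_grid variable that is then reused for both sharded memory configs and the matmul call, so the grid is no longer hardcoded in three places.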
11 changes: 8 additions & 3 deletions tests/ttnn/unit_tests/test_multi_device.py
@@ -511,11 +511,16 @@ def test_sharded_matmul(t3k_mesh_device):
     q_heads_1B4D = ttnn.to_device(q_heads_1B4D, t3k_mesh_device)
     keys_1BDP = ttnn.to_device(keys_1BDP, t3k_mesh_device)

+    core_grid = ttnn.CoreGrid(y=4, x=8)
+    compute_grid_size = t3k_mesh_device.compute_with_storage_grid_size()
+    if (compute_grid_size.x < core_grid.x) or (compute_grid_size.y < core_grid.y):
+        pytest.skip("Test requires larger grid size")
+
     q_heads_1B4D = ttnn.to_memory_config(
         q_heads_1B4D,
         ttnn.create_sharded_memory_config(
             shape=(32, 128),
-            core_grid=ttnn.CoreGrid(y=4, x=8),
+            core_grid=core_grid,
             strategy=ttnn.ShardStrategy.HEIGHT,
             orientation=ttnn.ShardOrientation.ROW_MAJOR,
             use_height_and_width_as_shard_shape=True,
@@ -526,7 +531,7 @@ def test_sharded_matmul(t3k_mesh_device):
         keys_1BDP,
         ttnn.create_sharded_memory_config(
             shape=(128, 32),
-            core_grid=ttnn.CoreGrid(y=4, x=8),
+            core_grid=core_grid,
             strategy=ttnn.ShardStrategy.HEIGHT,
             orientation=ttnn.ShardOrientation.ROW_MAJOR,
             use_height_and_width_as_shard_shape=True,
@@ -542,7 +547,7 @@ def test_sharded_matmul(t3k_mesh_device):
         q_heads_1B4D,
         keys_1BDP,
         dtype=ttnn.bfloat16,
-        core_grid=ttnn.CoreGrid(y=4, x=8),
+        core_grid=core_grid,
         compute_kernel_config=compute_kernel_attn,
     )

@@ -33,14 +33,15 @@ using namespace ccl;

 static std::tuple<CoreRangeSet, CoreRangeSet, std::map<std::pair<uint32_t, uint32_t>, std::vector<CoreRangeSet>>>
 get_all_worker_cores(
-    AllGatherConfig const& all_gather_config,
+    const AllGatherConfig& all_gather_config,
     uint32_t num_links,
     uint32_t num_full_send_directions,
-    CoreCoord const& core_grid_offset,
+    const CoreCoord& core_grid_offset,
     bool is_linear,
     uint32_t ring_size,
-    uint32_t ring_index) {
-    constexpr uint32_t worker_grid_width = 8;
+    uint32_t ring_index,
+    const CoreCoord& grid_size) {
+    uint32_t worker_grid_width = grid_size.x;
     const bool fit_sender_and_receiver_workers_on_same_row =
         (worker_grid_width / 2) >= all_gather_config.get_num_workers_per_link();

@@ -62,7 +63,7 @@ get_all_worker_cores(
         bool receiver_enabled = (!is_linear || !is_first_chip_in_chain);

         for (uint32_t link = 0; link < num_links; ++link) {
-            uint32_t max_cols = 8;
+            uint32_t max_cols = worker_grid_width;
             uint32_t curr_row = link * (((all_gather_config.get_num_workers_per_link() * 2 - 1) / max_cols) + 1) +
                                 (full_send_direction * num_links *
                                      (((all_gather_config.get_num_workers_per_link() * 2 - 1) / max_cols) + 1)) +
@@ -451,8 +452,15 @@ operation::ProgramWithCallbacks all_gather_multi_core_with_workers_helper(
     }

     // KERNEL CREATION
-    auto const& [all_receiver_workers, all_sender_workers, worker_core_map] = get_all_worker_cores(
-        all_gather_config, num_links, num_full_send_directions, core_grid_offset, is_linear, ring_size, ring_index);
+    const auto& [all_receiver_workers, all_sender_workers, worker_core_map] = get_all_worker_cores(
+        all_gather_config,
+        num_links,
+        num_full_send_directions,
+        core_grid_offset,
+        is_linear,
+        ring_size,
+        ring_index,
+        input_tensor.device()->compute_with_storage_grid_size());
     auto all_sender_worker_cores = corerange_to_cores(all_sender_workers, std::nullopt, true);
     auto all_receiver_worker_cores = corerange_to_cores(all_receiver_workers, std::nullopt, true);

