#11403: Support 2x4-submeshes across 8x4 mesh
1. Submeshing: support creating a submesh on a Galaxy mesh.
2. Key change to start enabling more T3000 tests on Galaxy:
   - Previously, ttnn.all_gather(..) in a ring relied on the MeshDevice being
     initialized in ring order. We now decouple the two, so the MeshDevice no
     longer has to be opened with its devices in ring order; instead, the
     operation that needs a ring order explicitly requests one (see the sketch
     below).
cfjchu committed Oct 1, 2024
1 parent d8706ff commit 240ee29
Showing 24 changed files with 436 additions and 135 deletions.
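
Before the file-by-file diffs, a minimal sketch of what the two points above enable, written only against ttnn calls that appear elsewhere in this commit. The create_submesh method name and the claim that ttnn.all_gather resolves its own ring order over the submesh are assumptions drawn from the commit description, not confirmed API:

import torch
import ttnn
from ttnn import ShardTensorToMesh

# Open the full 8x4 Galaxy mesh (dispatch/device parameters from the test
# harness are omitted; defaults are assumed to be acceptable for a sketch).
galaxy = ttnn.open_mesh_device(ttnn.MeshShape(8, 4))

# Hypothetical submeshing entry point: carve a 2x4 view out of the 8x4 mesh.
submesh = galaxy.create_submesh(ttnn.MeshShape(2, 4))

# Shard a tensor across the 8 devices of the submesh and ring-all-gather it.
# The ring order is requested by the operation itself rather than being baked
# into the order the MeshDevice was opened with.
torch_tensor = torch.rand((1, 1, 32, 32 * submesh.get_num_devices()), dtype=torch.bfloat16)
tt_tensor = ttnn.from_torch(torch_tensor, mesh_mapper=ShardTensorToMesh(submesh, dim=3))
tt_tensor = ttnn.to_device(tt_tensor, submesh)
tt_tensor = ttnn.all_gather(tt_tensor, dim=3, num_links=1)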
conftest.py: 2 additions & 4 deletions
@@ -232,10 +232,7 @@ def pcie_mesh_device(request, silicon_arch_name, silicon_arch_wormhole_b0, devic
     request.node.pci_ids = device_ids[:num_pcie_devices_requested]
 
     mesh_device = ttnn.open_mesh_device(
-        ttnn.MeshShape(1, num_pcie_devices_requested),
-        dispatch_core_type=get_dispatch_core_type(),
-        **device_params,
-        physical_device_ids=device_ids[:num_pcie_devices_requested],
+        ttnn.MeshShape(2, 2), dispatch_core_type=get_dispatch_core_type(), **device_params, offset=(0, 1)
     )
 
     logger.debug(f"multidevice with {mesh_device.get_num_devices()} devices is created")
@@ -255,6 +252,7 @@ def t3k_mesh_device(request, silicon_arch_name, silicon_arch_wormhole_b0, device
     if ttnn.get_num_devices() < 8:
         pytest.skip()
 
+    request.node.pci_ids = ttnn.get_pcie_device_ids()
     mesh_device = ttnn.open_mesh_device(
         ttnn.MeshShape(2, 4),
         dispatch_core_type=get_dispatch_core_type(),
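For reference, the reworked pcie_mesh_device call in isolation, as a standalone sketch: interpreting offset=(0, 1) as selecting a 2x2 window one column into the physical device mesh is an assumption, and the fixture's dispatch_core_type and device_params arguments are omitted.

import ttnn

# Open a 2x2 mesh view; offset=(0, 1) is assumed to pick the 2x2 window
# starting one column in from the origin of the physical mesh.
mesh_device = ttnn.open_mesh_device(ttnn.MeshShape(2, 2), offset=(0, 1))
print(f"multidevice with {mesh_device.get_num_devices()} devices is created")
ttnn.close_mesh_device(mesh_device)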
@@ -10,7 +10,7 @@
 from tests.tt_eager.python_api_testing.sweep_tests.comparison_funcs import (
     comp_pcc,
 )
-from models.utility_functions import torch2tt_tensor, tt2torch_tensor, skip_for_grayskull, get_devices_for_t3000
+from models.utility_functions import torch2tt_tensor, tt2torch_tensor, skip_for_grayskull
 from models.demos.t3000.falcon40b.tt.model_config import (
     get_model_config,
 )
@@ -11,7 +11,7 @@
 from tests.tt_eager.python_api_testing.sweep_tests.comparison_funcs import (
     comp_pcc,
 )
-from models.utility_functions import torch2tt_tensor, tt2torch_tensor, skip_for_grayskull, get_devices_for_t3000
+from models.utility_functions import torch2tt_tensor, tt2torch_tensor, skip_for_grayskull
 from models.demos.t3000.falcon40b.tt.model_config import (
     get_model_config,
 )
models/demos/t3000/llama2_70b/tests/test_llama_perf.py: 0 additions & 1 deletion
@@ -25,7 +25,6 @@
     disable_compilation_reports,
     nearest_32,
     skip_for_grayskull,
-    get_devices_for_t3000,
 )
 from models.perf.perf_utils import prep_perf_report
 from models.perf.device_perf_utils import run_device_perf, check_device_perf, prep_device_perf_report
models/demos/t3000/llama2_70b/tt/llama_mlp_optimized.py: 1 addition & 2 deletions
@@ -72,12 +72,11 @@ def load_weights(self):
         padded_w3[:, :, :, :H4] = self.state_dict[w3_str].transpose(-2, -1)
 
         # w1: 8k x 4k. width-sharded on 12 banks, 4224 over 12 banks.
-        device = self.mesh_device.get_device(0)
         weight_grid = ttnn.CoreRangeSet(
             {
                 ttnn.CoreRange(
                     ttnn.CoreCoord(0, 0),
-                    ttnn.CoreCoord(device.dram_grid_size().x - 1, device.dram_grid_size().y - 1),
+                    ttnn.CoreCoord(self.mesh_device.dram_grid_size().x - 1, self.mesh_device.dram_grid_size().y - 1),
                 )
             }
         )
tests/scripts/tg/run_tg_model_perf_tests.sh: 5 additions & 1 deletion
@@ -1,6 +1,10 @@
 #!/bin/bash
 
-run_tg_llm_tests() {
+run_t3k_tests_on_tg_tests() {
+
+  echo "LOG_METAL: Running T3000 tests on TG"
+  env pytest -n auto models/demos/t3000/llama2_70b/tests/test_llama_perf_decode.py -m "model_perf_t3000" --timeout=600 ; fail+=$?
+
   # Merge all the generated reports
   env python models/perf/merge_perf_results.py; fail+=$?
 
tests/ttnn/multichip_unit_tests/test_multidevice_TG.py: 16 additions & 0 deletions
@@ -1573,3 +1573,19 @@ def test_sharded_distributed_layernorm(mesh_device, input_width, input_height, c
     is_pass, output_pcc = comp_pcc(torch_output_tensor, tt_output_tensor, pcc=0.999)
 
     assert is_pass, f"PCC value: {output_pcc}"
+
+
+def test_ttnn_multi_device_all_gather_all_devices(t3k_mesh_device):
+    """Example test for running a 2x4-Ring All-Gather on galaxy"""
+    full_tensor = torch.ones((1, 1, 32, 32 * t3k_mesh_device.get_num_devices()), dtype=torch.bfloat16)
+    for i in range(t3k_mesh_device.get_num_devices()):
+        full_tensor[..., i * 32 : (i + 1) * 32] = i
+
+    ttnn_tensor = ttnn.from_torch(full_tensor, mesh_mapper=ShardTensorToMesh(t3k_mesh_device, dim=3))
+    ttnn_tensor = ttnn.to_device(ttnn_tensor, t3k_mesh_device)
+    ttnn_tensor = ttnn.all_gather(ttnn_tensor, dim=3, num_links=1)
+
+    device_tensors: typing.List[ttnn.Tensor] = ttnn.get_device_tensors(ttnn_tensor)
+    for device_tensor in device_tensors:
+        device_tensor_torch = ttnn.to_torch(device_tensor)
+        assert torch.all(device_tensor_torch == full_tensor)
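
As a usage note on the test above, the per-device loop could also be collapsed into one host-side check by composing the gathered shards back into a single torch tensor; treating ttnn.ConcatMeshToTensor as the available mesh composer, and its exact signature, are assumptions:

# Continuing from the test body above (ttnn_tensor, full_tensor, t3k_mesh_device).
# Compose the 8 per-device outputs along dim 0; after the ring all-gather each
# device holds the full tensor, so the composed result is 8 stacked copies.
composed = ttnn.to_torch(ttnn_tensor, mesh_composer=ttnn.ConcatMeshToTensor(t3k_mesh_device, dim=0))
assert torch.all(composed == full_tensor.repeat(t3k_mesh_device.get_num_devices(), 1, 1, 1))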
tests/ttnn/unit_tests/gtests/test_ccl_on_galaxy.cpp: 7 additions & 5 deletions
@@ -130,8 +130,9 @@ TEST(GalaxyTests, TestAllGatherDeadlock) {
     }
     // Iterate over each row and run line all-gather multiple times.
     // For each row, send adversarial traffic to the first chip, that can hang the network if the CCL is not tagged.
+    auto view = MeshDeviceView(*mesh);
     for (uint32_t row = 0; row < 8; row++) {
-        auto devs = mesh->get_devices_on_row(row);
+        auto devs = view.get_devices_on_row(row);
         std::vector<uint32_t> device_ids = {};
         for (auto dev : devs) {
             device_ids.push_back(dev->id());
@@ -189,13 +190,14 @@ TEST(GalaxyTests, TestReduceScatterDeadlock) {
     std::shared_ptr<MeshDevice> mesh = ttnn::multi_device::open_mesh_device(mesh_shape, 0, 0, 1, DispatchCoreType::WORKER);
     // Create the outer ring on which Reduce Scatter will be run. This allows us to verify that there are no deadlocks when we send CCLs to the
     // first tunnel (forward path).
-    std::vector<Device*> ring_devices = mesh->get_devices_on_row(0); // Tunnel 0
-    std::vector<Device*> ring_devices_1 = mesh->get_devices_on_column(mesh_shape.second - 1); // Orthogonal to tunnel .. no deadlocks
+    auto view = MeshDeviceView(*mesh);
+    std::vector<Device*> ring_devices = view.get_devices_on_row(0); // Tunnel 0
+    std::vector<Device*> ring_devices_1 = view.get_devices_on_column(mesh_shape.second - 1); // Orthogonal to tunnel .. no deadlocks
     ring_devices_1 = std::vector<Device*>(ring_devices_1.begin() + 1, ring_devices_1.end());
-    std::vector<Device*> ring_devices_2 = mesh->get_devices_on_row(7); // Tunnel 7 .. potential deadlocks with lack of buffering
+    std::vector<Device*> ring_devices_2 = view.get_devices_on_row(7); // Tunnel 7 .. potential deadlocks with lack of buffering
     std::reverse(ring_devices_2.begin(), ring_devices_2.end());
     ring_devices_2 = std::vector<Device*>(ring_devices_2.begin() + 1, ring_devices_2.end());
-    std::vector<Device*> ring_devices_3 = mesh->get_devices_on_column(0); // Orthogonal to tunnel .. no deadlocks
+    std::vector<Device*> ring_devices_3 = view.get_devices_on_column(0); // Orthogonal to tunnel .. no deadlocks
     std::reverse(ring_devices_3.begin(), ring_devices_3.end());
     ring_devices_3 = std::vector<Device*>(ring_devices_3.begin() + 1, ring_devices_3.end() - 1);
 