From 240ee29fd93ffd77160d081b8f8038d756a3b199 Mon Sep 17 00:00:00 2001 From: Joseph Chu Date: Sat, 21 Sep 2024 00:53:38 +0000 Subject: [PATCH] #11403: Support 2x4-submeshes across 8x4 mesh 1. Submeshing to support creating submesh on galaxy mesh 2. Key change to start enabling more T3000 Tests onto galaxy: - Currently ttnn.all_gather(..) in a ring relies on MeshDevice being initialized in ring-order. Now we decouple this so we don't require that MeshDevice is initialized with devices in a ring-order. Instead, we now explicitly request for a ring-order in the operation that requires it. --- conftest.py | 6 +- .../test_falcon_create_qkv_heads.py | 2 +- .../tests/unit_tests/test_falcon_softmax.py | 2 +- .../t3000/llama2_70b/tests/test_llama_perf.py | 1 - .../llama2_70b/tt/llama_mlp_optimized.py | 3 +- tests/scripts/tg/run_tg_model_perf_tests.sh | 6 +- .../test_multidevice_TG.py | 16 ++ .../unit_tests/gtests/test_ccl_on_galaxy.cpp | 12 +- .../unit_tests/operations/test_all_gather.py | 178 +++++++++++++----- .../test_all_gather_llama_perf_sweep.py | 2 +- .../operations/test_all_gather_nightly.py | 2 +- .../test_distributed_layernorm_sharded.py | 2 +- .../test_reduce_scatter_llama_perf_sweep.py | 2 +- .../operations/test_reduce_scatter_nightly.py | 2 +- .../test_reduce_scatter_post_commit.py | 2 +- tests/ttnn/unit_tests/test_multi_device.py | 4 + .../device/mesh_configurations/T3000.json | 2 +- tt_metal/impl/device/mesh_device.cpp | 140 +++++++++++--- tt_metal/impl/device/mesh_device.hpp | 19 +- tt_metal/impl/device/mesh_device_view.cpp | 111 ++++++++++- tt_metal/impl/device/mesh_device_view.hpp | 19 +- ttnn/cpp/pybind11/multi_device.hpp | 22 +-- .../ccl/all_gather/device/all_gather_op.cpp | 8 +- ttnn/cpp/ttnn/tensor/tensor.cpp | 8 +- 24 files changed, 436 insertions(+), 135 deletions(-) diff --git a/conftest.py b/conftest.py index 36a8d35b6b57..9908a40aa6d6 100644 --- a/conftest.py +++ b/conftest.py @@ -232,10 +232,7 @@ def pcie_mesh_device(request, silicon_arch_name, silicon_arch_wormhole_b0, devic request.node.pci_ids = device_ids[:num_pcie_devices_requested] mesh_device = ttnn.open_mesh_device( - ttnn.MeshShape(1, num_pcie_devices_requested), - dispatch_core_type=get_dispatch_core_type(), - **device_params, - physical_device_ids=device_ids[:num_pcie_devices_requested], + ttnn.MeshShape(2, 2), dispatch_core_type=get_dispatch_core_type(), **device_params, offset=(0, 1) ) logger.debug(f"multidevice with {mesh_device.get_num_devices()} devices is created") @@ -255,6 +252,7 @@ def t3k_mesh_device(request, silicon_arch_name, silicon_arch_wormhole_b0, device if ttnn.get_num_devices() < 8: pytest.skip() + request.node.pci_ids = ttnn.get_pcie_device_ids() mesh_device = ttnn.open_mesh_device( ttnn.MeshShape(2, 4), dispatch_core_type=get_dispatch_core_type(), diff --git a/models/demos/t3000/falcon40b/tests/unit_tests/test_falcon_create_qkv_heads.py b/models/demos/t3000/falcon40b/tests/unit_tests/test_falcon_create_qkv_heads.py index 8dc05c8cbc79..c1c758e66046 100644 --- a/models/demos/t3000/falcon40b/tests/unit_tests/test_falcon_create_qkv_heads.py +++ b/models/demos/t3000/falcon40b/tests/unit_tests/test_falcon_create_qkv_heads.py @@ -10,7 +10,7 @@ from tests.tt_eager.python_api_testing.sweep_tests.comparison_funcs import ( comp_pcc, ) -from models.utility_functions import torch2tt_tensor, tt2torch_tensor, skip_for_grayskull, get_devices_for_t3000 +from models.utility_functions import torch2tt_tensor, tt2torch_tensor, skip_for_grayskull from models.demos.t3000.falcon40b.tt.model_config import ( get_model_config, ) diff --git a/models/demos/t3000/falcon40b/tests/unit_tests/test_falcon_softmax.py b/models/demos/t3000/falcon40b/tests/unit_tests/test_falcon_softmax.py index 6a14cac85b50..ccbed44e4546 100644 --- a/models/demos/t3000/falcon40b/tests/unit_tests/test_falcon_softmax.py +++ b/models/demos/t3000/falcon40b/tests/unit_tests/test_falcon_softmax.py @@ -11,7 +11,7 @@ from tests.tt_eager.python_api_testing.sweep_tests.comparison_funcs import ( comp_pcc, ) -from models.utility_functions import torch2tt_tensor, tt2torch_tensor, skip_for_grayskull, get_devices_for_t3000 +from models.utility_functions import torch2tt_tensor, tt2torch_tensor, skip_for_grayskull from models.demos.t3000.falcon40b.tt.model_config import ( get_model_config, ) diff --git a/models/demos/t3000/llama2_70b/tests/test_llama_perf.py b/models/demos/t3000/llama2_70b/tests/test_llama_perf.py index b4e81524ec69..8ea224c848a2 100644 --- a/models/demos/t3000/llama2_70b/tests/test_llama_perf.py +++ b/models/demos/t3000/llama2_70b/tests/test_llama_perf.py @@ -25,7 +25,6 @@ disable_compilation_reports, nearest_32, skip_for_grayskull, - get_devices_for_t3000, ) from models.perf.perf_utils import prep_perf_report from models.perf.device_perf_utils import run_device_perf, check_device_perf, prep_device_perf_report diff --git a/models/demos/t3000/llama2_70b/tt/llama_mlp_optimized.py b/models/demos/t3000/llama2_70b/tt/llama_mlp_optimized.py index 1ac4f56a41a8..4d30605e63c7 100644 --- a/models/demos/t3000/llama2_70b/tt/llama_mlp_optimized.py +++ b/models/demos/t3000/llama2_70b/tt/llama_mlp_optimized.py @@ -72,12 +72,11 @@ def load_weights(self): padded_w3[:, :, :, :H4] = self.state_dict[w3_str].transpose(-2, -1) # w1: 8k x 4k. width-sharded on 12 banks, 4224 over 12 banks. - device = self.mesh_device.get_device(0) weight_grid = ttnn.CoreRangeSet( { ttnn.CoreRange( ttnn.CoreCoord(0, 0), - ttnn.CoreCoord(device.dram_grid_size().x - 1, device.dram_grid_size().y - 1), + ttnn.CoreCoord(self.mesh_device.dram_grid_size().x - 1, self.mesh_device.dram_grid_size().y - 1), ) } ) diff --git a/tests/scripts/tg/run_tg_model_perf_tests.sh b/tests/scripts/tg/run_tg_model_perf_tests.sh index 7cd43da8c897..76d85050dcc5 100755 --- a/tests/scripts/tg/run_tg_model_perf_tests.sh +++ b/tests/scripts/tg/run_tg_model_perf_tests.sh @@ -1,6 +1,10 @@ #!/bin/bash -run_tg_llm_tests() { +run_t3k_tests_on_tg_tests() { + + echo "LOG_METAL: Running T3000 tests on TG" + env pytest -n auto models/demos/t3000/llama2_70b/tests/test_llama_perf_decode.py -m "model_perf_t3000" --timeout=600 ; fail+=$? + # Merge all the generated reports env python models/perf/merge_perf_results.py; fail+=$? diff --git a/tests/ttnn/multichip_unit_tests/test_multidevice_TG.py b/tests/ttnn/multichip_unit_tests/test_multidevice_TG.py index 3d927636f91a..4ca13900f54d 100644 --- a/tests/ttnn/multichip_unit_tests/test_multidevice_TG.py +++ b/tests/ttnn/multichip_unit_tests/test_multidevice_TG.py @@ -1573,3 +1573,19 @@ def test_sharded_distributed_layernorm(mesh_device, input_width, input_height, c is_pass, output_pcc = comp_pcc(torch_output_tensor, tt_output_tensor, pcc=0.999) assert is_pass, f"PCC value: {output_pcc}" + + +def test_ttnn_multi_device_all_gather_all_devices(t3k_mesh_device): + """Example test for running a 2x4-Ring All-Gather on galaxy""" + full_tensor = torch.ones((1, 1, 32, 32 * t3k_mesh_device.get_num_devices()), dtype=torch.bfloat16) + for i in range(t3k_mesh_device.get_num_devices()): + full_tensor[..., i * 32 : (i + 1) * 32] = i + + ttnn_tensor = ttnn.from_torch(full_tensor, mesh_mapper=ShardTensorToMesh(t3k_mesh_device, dim=3)) + ttnn_tensor = ttnn.to_device(ttnn_tensor, t3k_mesh_device) + ttnn_tensor = ttnn.all_gather(ttnn_tensor, dim=3, num_links=1) + + device_tensors: typing.List[ttnn.Tensor] = ttnn.get_device_tensors(ttnn_tensor) + for device_tensor in device_tensors: + device_tensor_torch = ttnn.to_torch(device_tensor) + assert torch.all(device_tensor_torch == full_tensor) diff --git a/tests/ttnn/unit_tests/gtests/test_ccl_on_galaxy.cpp b/tests/ttnn/unit_tests/gtests/test_ccl_on_galaxy.cpp index 6d7ed90ee8b9..2d2e99504677 100644 --- a/tests/ttnn/unit_tests/gtests/test_ccl_on_galaxy.cpp +++ b/tests/ttnn/unit_tests/gtests/test_ccl_on_galaxy.cpp @@ -130,8 +130,9 @@ TEST(GalaxyTests, TestAllGatherDeadlock) { } // Iterate over each row and run line all-gather multiple times. // For each row, send adversarial traffic to the first chip, that can hang the network if the CCL is not tagged. + auto view = MeshDeviceView(*mesh); for (uint32_t row = 0; row < 8; row++) { - auto devs = mesh->get_devices_on_row(row); + auto devs = view.get_devices_on_row(row); std::vector device_ids = {}; for (auto dev : devs) { device_ids.push_back(dev->id()); @@ -189,13 +190,14 @@ TEST(GalaxyTests, TestReduceScatterDeadlock) { std::shared_ptr mesh = ttnn::multi_device::open_mesh_device(mesh_shape, 0, 0, 1, DispatchCoreType::WORKER); // Create the outer ring on which Reduce Scatter will be run. This allows us to verify that there are no deadlocks when we send CCLs to the // first tunnel (forward path). - std::vector ring_devices = mesh->get_devices_on_row(0); // Tunnel 0 - std::vector ring_devices_1 = mesh->get_devices_on_column(mesh_shape.second - 1); // Orthogonal to tunnel .. no deadlocks + auto view = MeshDeviceView(*mesh); + std::vector ring_devices = view.get_devices_on_row(0); // Tunnel 0 + std::vector ring_devices_1 = view.get_devices_on_column(mesh_shape.second - 1); // Orthogonal to tunnel .. no deadlocks ring_devices_1 = std::vector(ring_devices_1.begin() + 1, ring_devices_1.end()); - std::vector ring_devices_2 = mesh->get_devices_on_row(7); // Tunnel 7 .. potential deadlocks with lack of buffering + std::vector ring_devices_2 = view.get_devices_on_row(7); // Tunnel 7 .. potential deadlocks with lack of buffering std::reverse(ring_devices_2.begin(), ring_devices_2.end()); ring_devices_2 = std::vector(ring_devices_2.begin() + 1, ring_devices_2.end()); - std::vector ring_devices_3 = mesh->get_devices_on_column(0); // Orthogonal to tunnel .. no deadlocks + std::vector ring_devices_3 = view.get_devices_on_column(0); // Orthogonal to tunnel .. no deadlocks std::reverse(ring_devices_3.begin(), ring_devices_3.end()); ring_devices_3 = std::vector(ring_devices_3.begin() + 1, ring_devices_3.end() - 1); diff --git a/tests/ttnn/unit_tests/operations/test_all_gather.py b/tests/ttnn/unit_tests/operations/test_all_gather.py index bc58fef10eb6..501d85be176f 100644 --- a/tests/ttnn/unit_tests/operations/test_all_gather.py +++ b/tests/ttnn/unit_tests/operations/test_all_gather.py @@ -119,7 +119,7 @@ def run_with_trace( def run_all_gather_impl( - all_devices, + t3k_mesh_device, num_devices, input_shape, dim, @@ -130,12 +130,11 @@ def run_all_gather_impl( use_program_cache, function_level_defaults, all_gather_operation, - devices, num_iters=1, enable_async=False, ): # Use Async mode based on test input config - for device in all_devices: + for device in t3k_mesh_device.get_devices(): device.enable_async(enable_async) if enable_async: logger.info(f"Using Async Mode for All Gather Op Dispatch") @@ -149,17 +148,19 @@ def run_all_gather_impl( logger.info(f"dim: {dim}") input_tensor = torch.rand(input_shape).bfloat16() + input_tensor_mesh = ttnn.from_torch( + input_tensor, + dtype=input_dtype, + layout=layout, + device=t3k_mesh_device, + memory_config=mem_config, + mesh_mapper=ShardTensorToMesh(t3k_mesh_device, dim=dim), + ) - input_tensors = torch.chunk(input_tensor, num_devices, dim) - tt_input_tensors = [] - for i, t in enumerate(input_tensors): - tt_input_tensors.append(ttnn.Tensor(t, input_dtype).to(layout).to(devices[i], mem_config)) - - input_tensor_mesh = ttnn.aggregate_as_tensor(tt_input_tensors) for i in range(num_iters): tt_out_tensor = all_gather_operation(input_tensor_mesh, dim, num_links=num_links, memory_config=mem_config) - for d in devices: + for d in t3k_mesh_device.get_devices(): ttnn.synchronize_device(d) logger.info(f"Done iteration {i}") @@ -199,7 +200,7 @@ def run_all_gather_on_n300_impl( pytest.skip(f"Skipping unsupported case {message}.") return run_all_gather_impl( - all_devices, + t3k_mesh_device, num_devices, input_shape, dim, @@ -210,14 +211,13 @@ def run_all_gather_on_n300_impl( use_program_cache, function_level_defaults, all_gather_operation, - all_devices, num_iters, enable_async, ) def run_all_gather_on_t3000_impl( - all_devices, + t3k_mesh_device, num_devices, input_shape, dim, @@ -231,7 +231,7 @@ def run_all_gather_on_t3000_impl( num_iters=1, enable_async=False, ): - if len(all_devices) != 8: + if t3k_mesh_device.get_num_devices() != 8: pytest.skip("Not T3000!") (is_known_failure, message) = is_unsupported_case_t3k( @@ -240,10 +240,8 @@ def run_all_gather_on_t3000_impl( if is_known_failure: pytest.skip(f"Skipping unsupported case {message}.") - devices = get_devices_for_t3000(all_devices, num_devices) - return run_all_gather_impl( - all_devices, + t3k_mesh_device, num_devices, input_shape, dim, @@ -254,14 +252,13 @@ def run_all_gather_on_t3000_impl( use_program_cache, function_level_defaults, all_gather_operation, - devices, num_iters, enable_async, ) def run_all_gather_on_t3000_impl_tight_loop( - all_devices, + t3k_mesh_device, num_devices, input_shape, dim, @@ -276,7 +273,7 @@ def run_all_gather_on_t3000_impl_tight_loop( enable_async=False, ): run_all_gather_on_t3000_impl( - all_devices, + t3k_mesh_device, num_devices, input_shape, dim, @@ -324,7 +321,7 @@ def run_all_gather_on_t3000_impl_tight_loop( @pytest.mark.parametrize("num_iters", [1]) # restore to 500: https://github.com/tenstorrent/tt-metal/issues/9686 @pytest.mark.parametrize("enable_async", [True, False]) def test_all_gather_on_t3000_post_commit_looping( - all_devices, + t3k_mesh_device, num_devices, input_shape, dim, @@ -338,7 +335,7 @@ def test_all_gather_on_t3000_post_commit_looping( enable_async, ): run_all_gather_on_t3000_impl_tight_loop( - all_devices, + t3k_mesh_device, num_devices, input_shape, dim, @@ -386,7 +383,7 @@ def test_all_gather_on_t3000_post_commit_looping( @pytest.mark.parametrize("num_iters", [1000]) # TODO: restore to 500 @pytest.mark.parametrize("enable_async", [True, False]) def test_all_gather_on_t3000_nightly_commit_looping( - all_devices, + t3k_mesh_device, num_devices, input_shape, dim, @@ -400,7 +397,7 @@ def test_all_gather_on_t3000_nightly_commit_looping( enable_async, ): run_all_gather_on_t3000_impl_tight_loop( - all_devices, + t3k_mesh_device, num_devices, input_shape, dim, @@ -422,6 +419,90 @@ def test_all_gather_on_t3000_nightly_commit_looping( "num_devices, num_links, input_shape, dim, layout", [ (4, 2, [4, 1, 33, 256], 0, ttnn.ROW_MAJOR_LAYOUT), # https://github.com/tenstorrent/tt-metal/issues/9686 + # (8, 1, [8, 8, 256, 384], 1, ttnn.ROW_MAJOR_LAYOUT), # https://github.com/tenstorrent/tt-metal/issues/9686 + # (4, 2, [8, 8, 256, 384], 1, ttnn.ROW_MAJOR_LAYOUT), # https://github.com/tenstorrent/tt-metal/issues/9686 + # (4, 2, [8, 8, 256, 384], 1, ttnn.TILE_LAYOUT), # https://github.com/tenstorrent/tt-metal/issues/9686 + # (8, 1, [8, 8, 256, 384], 1, ttnn.TILE_LAYOUT), # https://github.com/tenstorrent/tt-metal/issues/9686 + # (4, 2, [8, 5, 13, 384], 3, ttnn.ROW_MAJOR_LAYOUT), # https://github.com/tenstorrent/tt-metal/issues/9686 + # (8, 1, [8, 5, 13, 512], 3, ttnn.ROW_MAJOR_LAYOUT), # https://github.com/tenstorrent/tt-metal/issues/9686 + # (4, 2, [8, 5, 32, 384], 3, ttnn.TILE_LAYOUT), # https://github.com/tenstorrent/tt-metal/issues/9686 + # (8, 1, [8, 5, 32, 512], 3, ttnn.TILE_LAYOUT), + # Only for BFP8B + # # ([1, 1, 640, 32768], 3, ttnn.TILE_LAYOUT), # https://github.com/tenstorrent/tt-metal/issues/9686 + # # MLP AllGather, Llama 2 decode attn, mlp. Llama2, Falcon 40B decode mlp attn + # (8, 1, [1, 1, 32, 32768], 3, ttnn.TILE_LAYOUT), # https://github.com/tenstorrent/tt-metal/issues/9686 + # (4, 2, [1, 1, 32, 16384], 3, ttnn.TILE_LAYOUT), # https://github.com/tenstorrent/tt-metal/issues/9686 + # # (4, 2, [1, 1, 32, 32768], 3, ttnn.TILE_LAYOUT), # https://github.com/tenstorrent/tt-metal/issues/9686 + # # (8, 1, [1, 1, 32, 32768], 3, ttnn.ROW_MAJOR_LAYOUT), # https://github.com/tenstorrent/tt-metal/issues/9686 + # # Input, Selfout, Final AllGather, Llama2, Falcon 40B decode mlp attn + # (8, 1, [1, 1, 32, 8192], 3, ttnn.TILE_LAYOUT), # https://github.com/tenstorrent/tt-metal/issues/9686 + # (4, 2, [1, 1, 32, 8192], 3, ttnn.TILE_LAYOUT), # https://github.com/tenstorrent/tt-metal/issues/9686 + # (8, 1, [1, 1, 32, 8192], 3, ttnn.ROW_MAJOR_LAYOUT), # https://github.com/tenstorrent/tt-metal/issues/9686 + # # Falcon 40B prefill + # # 8 chips + # (8, 1, [1, 1, 2048, 8192], 3, ttnn.TILE_LAYOUT), # https://github.com/tenstorrent/tt-metal/issues/9686 + # (8, 1, [1, 1, 2048, 8192], 3, ttnn.ROW_MAJOR_LAYOUT), # https://github.com/tenstorrent/tt-metal/issues/9686 + # # Falcon 40B prefill, also mixtral expert reduction (w/ zero filled tensor) + # # 8 chips + # (8, 1, [1, 1, 2048, 32768], 3, ttnn.TILE_LAYOUT), # https://github.com/tenstorrent/tt-metal/issues/9686 + # # Llama/falcon40B galaxy mlp weights stationary -> emulation of row/col reduce + # (8, 1, [1, 1, 256, 1024], 2, ttnn.TILE_LAYOUT), # https://github.com/tenstorrent/tt-metal/issues/9686 + # (8, 1, [1, 1, 246, 4096], 2, ttnn.ROW_MAJOR_LAYOUT), # https://github.com/tenstorrent/tt-metal/issues/9686 + # (8, 1, [1, 1, 246, 4096], 2, ttnn.TILE_LAYOUT), # https://github.com/tenstorrent/tt-metal/issues/9686 + # (8, 1, [1, 1, 8192, 32], 2, ttnn.TILE_LAYOUT), # https://github.com/tenstorrent/tt-metal/issues/9686 + # (8, 1, [1, 1, 1024, 256], 3, ttnn.TILE_LAYOUT), # https://github.com/tenstorrent/tt-metal/issues/9686 + # (8, 1, [1, 1, 256, 2048], 2, ttnn.TILE_LAYOUT), # https://github.com/tenstorrent/tt-metal/issues/9686 + # (8, 1, [1, 1, 256, 8192], 2, ttnn.TILE_LAYOUT), # double on reduction dim for 8 chip # https://github.com/tenstorrent/tt-metal/issues/9686 + # (8, 1, [8, 1, 256, 32], 0, ttnn.TILE_LAYOUT), # https://github.com/tenstorrent/tt-metal/issues/9686 + # (8, 1, [8, 8, 128, 4096], 1, ttnn.TILE_LAYOUT), # https://github.com/tenstorrent/tt-metal/issues/9686 + ], +) +@pytest.mark.parametrize( + "input_dtype", + [ + ttnn.bfloat16, + # ttnn.bfloat8_b, # https://github.com/tenstorrent/tt-metal/issues/9686 + ], +) +@pytest.mark.parametrize( + "mem_config", + [ + ttnn.MemoryConfig(buffer_type=ttnn.BufferType.DRAM), + # ttnn.MemoryConfig(buffer_type=ttnn.BufferType.L1), # https://github.com/tenstorrent/tt-metal/issues/9686 + ], +) +def test_x4_all_gather_on_t3000_post_commit( + pcie_mesh_device, + num_devices, + input_shape, + dim, + num_links, + input_dtype, + layout, + mem_config, + use_program_cache, + function_level_defaults, +): + run_all_gather_on_t3000_impl( + pcie_mesh_device, + num_devices, + input_shape, + dim, + num_links, + input_dtype, + layout, + mem_config, + use_program_cache, + function_level_defaults, + all_gather_operation=ttnn.all_gather, + ) + + +# Enumerate the post-commit cases explicitly +@skip_for_grayskull("Requires eth connected devices to run") +@pytest.mark.parametrize( + "num_devices, num_links, input_shape, dim, layout", + [ (8, 1, [8, 1, 33, 256], 0, ttnn.ROW_MAJOR_LAYOUT), # https://github.com/tenstorrent/tt-metal/issues/9686 # (8, 1, [8, 8, 256, 384], 1, ttnn.ROW_MAJOR_LAYOUT), # https://github.com/tenstorrent/tt-metal/issues/9686 # (4, 2, [8, 8, 256, 384], 1, ttnn.ROW_MAJOR_LAYOUT), # https://github.com/tenstorrent/tt-metal/issues/9686 @@ -476,7 +557,7 @@ def test_all_gather_on_t3000_nightly_commit_looping( ], ) def test_all_gather_on_t3000_post_commit( - all_devices, + t3k_mesh_device, num_devices, input_shape, dim, @@ -488,7 +569,7 @@ def test_all_gather_on_t3000_post_commit( function_level_defaults, ): run_all_gather_on_t3000_impl( - all_devices, + t3k_mesh_device, num_devices, input_shape, dim, @@ -557,7 +638,7 @@ def run_line_all_gather( # This run function is deprecated and is only intended to be used by 2-link tests on t3k def run_line_all_gather_deprecated( - all_devices, + pcie_mesh_device, num_devices, input_shape, dim, @@ -570,10 +651,7 @@ def run_line_all_gather_deprecated( enable_async, num_iters=1, ): - if len(all_devices) != 8: - pytest.skip("Not T3000!") - - for device in all_devices: + for device in pcie_mesh_device.get_devices(): device.enable_async(enable_async) logger.info(f"Input shape: {input_shape}") @@ -585,23 +663,23 @@ def run_line_all_gather_deprecated( if is_known_failure: pytest.skip(f"Skipping unsupported case {message}.") - devices = get_devices_for_t3000(all_devices, num_devices) - logger.info(f"Input shape: {input_shape}") logger.info(f"dim: {dim}") input_tensor = torch.rand(input_shape).bfloat16() - input_tensors = torch.chunk(input_tensor, num_devices, dim) - tt_input_tensors = [] - for i, t in enumerate(input_tensors): - tt_input_tensors.append(ttnn.Tensor(t, input_dtype).to(layout).to(devices[i], mem_config)) - - input_tensor_mesh = ttnn.aggregate_as_tensor(tt_input_tensors) + input_tensor_mesh = ttnn.from_torch( + input_tensor, + mesh_mapper=ShardTensorToMesh(pcie_mesh_device, dim=dim), + memory_config=mem_config, + device=pcie_mesh_device, + layout=layout, + dtype=input_dtype, + ) for i in range(num_iters): tt_out_tensor = ttnn.line_all_gather(input_tensor_mesh, dim, num_links=num_links, memory_config=mem_config) - for d in devices: + for d in pcie_mesh_device.get_devices(): ttnn.synchronize_device(d) logger.info(f"Done iteration {i}") @@ -711,7 +789,7 @@ def test_line_all_gather_on_t3000_post_commit( ) @pytest.mark.parametrize("enable_async", [True, False]) def test_line_all_gather_on_t3000_post_commit_two_link( - all_devices, + pcie_mesh_device, num_devices, input_shape, dim, @@ -725,7 +803,7 @@ def test_line_all_gather_on_t3000_post_commit_two_link( num_iters=1, ): run_line_all_gather_deprecated( - all_devices, + pcie_mesh_device, num_devices, input_shape, dim, @@ -984,7 +1062,7 @@ def test_all_gather_on_t3000_nightly( pytest.xfail(reason="Known failure") run_all_gather_on_t3000_impl( - all_devices, + t3k_mesh_device, num_devices, input_shape, dim, @@ -1094,12 +1172,14 @@ def run_all_gather_sharded( tt_input_tensors_dups = [] tt_input_tensors = [] - for i, t in enumerate(input_tensors): - tt_input_tensors_dups.append(ttnn.Tensor(t, input_dtype).to(tensor_layout).to(devices[i], input_mem_config)) - tt_input_tensors.append(ttnn.Tensor(t, input_dtype).to(tensor_layout).to(devices[i], input_mem_config)) - - input_tensor_mesh = ttnn.aggregate_as_tensor(tt_input_tensors) - + input_tensor_mesh = ttnn.from_torch( + unchunked_input_tensor, + layout=tensor_layout, + dtype=input_dtype, + device=t3k_mesh_device, + memory_config=input_mem_config, + mesh_mapper=ShardTensorToMesh(t3k_mesh_device, dim=dim), + ) if trace_mode: tt_out_tensor = run_with_trace( t3k_mesh_device, diff --git a/tests/ttnn/unit_tests/operations/test_all_gather_llama_perf_sweep.py b/tests/ttnn/unit_tests/operations/test_all_gather_llama_perf_sweep.py index 18f9855f85b6..d04d4a6094c7 100644 --- a/tests/ttnn/unit_tests/operations/test_all_gather_llama_perf_sweep.py +++ b/tests/ttnn/unit_tests/operations/test_all_gather_llama_perf_sweep.py @@ -7,7 +7,7 @@ from loguru import logger import ttnn from tests.tt_eager.python_api_testing.sweep_tests.comparison_funcs import comp_equal, comp_pcc -from models.utility_functions import skip_for_grayskull, get_devices_for_t3000 +from models.utility_functions import skip_for_grayskull import itertools from ttnn import ShardTensorToMesh from tests.ttnn.unit_tests.operations.test_all_gather import run_all_gather_sharded diff --git a/tests/ttnn/unit_tests/operations/test_all_gather_nightly.py b/tests/ttnn/unit_tests/operations/test_all_gather_nightly.py index 43ce613608e0..7c4005cda1bb 100644 --- a/tests/ttnn/unit_tests/operations/test_all_gather_nightly.py +++ b/tests/ttnn/unit_tests/operations/test_all_gather_nightly.py @@ -7,7 +7,7 @@ from loguru import logger import ttnn from tests.tt_eager.python_api_testing.sweep_tests.comparison_funcs import comp_equal, comp_pcc -from models.utility_functions import skip_for_grayskull, get_devices_for_t3000 +from models.utility_functions import skip_for_grayskull from tests.ttnn.unit_tests.operations.test_all_gather import ( is_unsupported_case, run_line_all_gather, diff --git a/tests/ttnn/unit_tests/operations/test_distributed_layernorm_sharded.py b/tests/ttnn/unit_tests/operations/test_distributed_layernorm_sharded.py index 9f256b465f51..d07d54d7d120 100644 --- a/tests/ttnn/unit_tests/operations/test_distributed_layernorm_sharded.py +++ b/tests/ttnn/unit_tests/operations/test_distributed_layernorm_sharded.py @@ -13,7 +13,7 @@ comp_allclose, ) -from models.utility_functions import tt2torch_tensor, get_devices_for_t3000, skip_for_grayskull +from models.utility_functions import tt2torch_tensor, skip_for_grayskull def rms_norm(x, gamma, eps): diff --git a/tests/ttnn/unit_tests/operations/test_reduce_scatter_llama_perf_sweep.py b/tests/ttnn/unit_tests/operations/test_reduce_scatter_llama_perf_sweep.py index bfee281eb3ba..1f6aef79bebf 100644 --- a/tests/ttnn/unit_tests/operations/test_reduce_scatter_llama_perf_sweep.py +++ b/tests/ttnn/unit_tests/operations/test_reduce_scatter_llama_perf_sweep.py @@ -7,7 +7,7 @@ from loguru import logger import ttnn from tests.tt_eager.python_api_testing.sweep_tests.comparison_funcs import comp_equal, comp_pcc -from models.utility_functions import skip_for_grayskull, get_devices_for_t3000 +from models.utility_functions import skip_for_grayskull from tests.ttnn.unit_tests.operations.test_reduce_scatter_post_commit import run_reduce_scatter_sharded_test diff --git a/tests/ttnn/unit_tests/operations/test_reduce_scatter_nightly.py b/tests/ttnn/unit_tests/operations/test_reduce_scatter_nightly.py index 9ab877af87cf..d7e91e884d6b 100644 --- a/tests/ttnn/unit_tests/operations/test_reduce_scatter_nightly.py +++ b/tests/ttnn/unit_tests/operations/test_reduce_scatter_nightly.py @@ -11,7 +11,7 @@ is_unsupported_case, run_reduce_scatter_test, ) -from models.utility_functions import skip_for_grayskull, get_devices_for_t3000 +from models.utility_functions import skip_for_grayskull import itertools diff --git a/tests/ttnn/unit_tests/operations/test_reduce_scatter_post_commit.py b/tests/ttnn/unit_tests/operations/test_reduce_scatter_post_commit.py index 8749f14f5b60..fc3c6ce8bbc1 100644 --- a/tests/ttnn/unit_tests/operations/test_reduce_scatter_post_commit.py +++ b/tests/ttnn/unit_tests/operations/test_reduce_scatter_post_commit.py @@ -7,7 +7,7 @@ from loguru import logger import ttnn from tests.tt_eager.python_api_testing.sweep_tests.comparison_funcs import comp_equal, comp_pcc -from models.utility_functions import skip_for_grayskull, get_devices_for_t3000 +from models.utility_functions import skip_for_grayskull def is_unsupported_case(input_shape, scatter_dim, math_op, mem_config, num_devices, num_links, input_dtype, layout): diff --git a/tests/ttnn/unit_tests/test_multi_device.py b/tests/ttnn/unit_tests/test_multi_device.py index f1c2728857f0..982e539689df 100644 --- a/tests/ttnn/unit_tests/test_multi_device.py +++ b/tests/ttnn/unit_tests/test_multi_device.py @@ -587,3 +587,7 @@ def test_validate_as_tensor(tmp_path, mesh_device, height, width): for device in mesh_device.get_devices(): device_tensor = ttnn.get_device_tensor(tensor, device) assert torch.allclose(ttnn.to_torch(device_tensor), torch_input_tensor) + + +def test_ttnn_visualize_mesh_device(t3k_mesh_device): + ttnn.visualize_mesh_device(t3k_mesh_device) diff --git a/tt_metal/impl/device/mesh_configurations/T3000.json b/tt_metal/impl/device/mesh_configurations/T3000.json index 2c62209d01fc..acfe3edac004 100644 --- a/tt_metal/impl/device/mesh_configurations/T3000.json +++ b/tt_metal/impl/device/mesh_configurations/T3000.json @@ -1,6 +1,6 @@ { "logical_to_physical_coordinates": [ [[0, 0], [0, 0, 0, 0]], [[0, 1], [0, 1, 0, 0]], [[0, 2], [0, 2, 0, 0]], [[0, 3], [0, 3, 0, 0]], - [[1, 0], [1, 3, 0, 0]], [[1, 1], [1, 2, 0, 0]], [[1, 2], [1, 1, 0, 0]], [[1, 3], [1, 0, 0, 0]] + [[1, 0], [1, 0, 0, 0]], [[1, 1], [1, 1, 0, 0]], [[1, 2], [1, 2, 0, 0]], [[1, 3], [1, 3, 0, 0]] ] } diff --git a/tt_metal/impl/device/mesh_device.cpp b/tt_metal/impl/device/mesh_device.cpp index e90d4a8925e2..7c3f1a461078 100644 --- a/tt_metal/impl/device/mesh_device.cpp +++ b/tt_metal/impl/device/mesh_device.cpp @@ -127,6 +127,14 @@ std::vector SystemMesh::get_mapped_physical_device_ids(const MeshDevi } return physical_device_ids; } +void SystemMesh::register_mesh_device(const std::shared_ptr &mesh_device, const std::vector& devices) { + std::vector physical_device_ids; + for (auto device : devices) { + physical_device_ids.push_back(device->id()); + } + this->assigned_mesh_device_devices.insert({mesh_device->get_mesh_id(), mesh_device}); + this->assigned_devices.insert({mesh_device->get_mesh_id(), physical_device_ids}); +} std::vector SystemMesh::map_mesh_device( std::shared_ptr mesh_device, @@ -145,7 +153,6 @@ std::vector SystemMesh::map_mesh_device( TT_FATAL(requested_num_rows <= max_num_rows, "Requested too many rows: {} > {}", requested_num_rows, max_num_rows); TT_FATAL(requested_num_rows*requested_num_cols <= max_num_rows*max_num_cols, "Requested submesh is too big: {}x{}", requested_num_rows, requested_num_cols); - this->assigned_mesh_device_devices.insert({mesh_device->get_mesh_id(), mesh_device}); auto physical_device_ids = user_provided_physical_device_ids.empty() ? this->get_mapped_physical_device_ids(MeshDeviceConfig{mesh_device->shape(), offset}) : @@ -158,27 +165,34 @@ std::vector SystemMesh::map_mesh_device( for (auto physical_device_id : physical_device_ids) { auto mapped_device = this->opened_devices[mesh_device->get_mesh_id()].at(physical_device_id); mapped_devices.push_back(mapped_device); - this->assigned_devices[mesh_device->get_mesh_id()].push_back(physical_device_id); this->assigned_physical_id_to_device.insert({physical_device_id, mapped_device}); } + + this->register_mesh_device(mesh_device, mapped_devices); // TODO: change this return mapped_devices; } void SystemMesh::unmap_mesh_device(const std::shared_ptr& mesh_device) { auto mesh_id = mesh_device->get_mesh_id(); - - // Clean up all state related to this virtual mesh this->assigned_mesh_device_devices.erase(mesh_id); - // Remove the devices from assigned_physical_id_to_device - for (auto physical_id : this->assigned_devices.at(mesh_id)) { - this->assigned_physical_id_to_device.erase(physical_id); + // Close the devices + if (mesh_device->is_parent_mesh()) { + for (auto physical_id : this->assigned_devices.at(mesh_id)) { + this->assigned_physical_id_to_device.erase(physical_id); + } + tt::tt_metal::detail::CloseDevices(this->opened_devices.at(mesh_id)); + this->opened_devices.erase(mesh_id); } this->assigned_devices.erase(mesh_id); +} - // Close the devices - tt::tt_metal::detail::CloseDevices(this->opened_devices.at(mesh_id)); - this->opened_devices.erase(mesh_id); +Device* SystemMesh::get_device(const chip_id_t physical_device_id) { + auto it = this->assigned_physical_id_to_device.find(physical_device_id); + if (it == this->assigned_physical_id_to_device.end()) { + TT_THROW("Physical Device ID: {} not found in assigned devices", physical_device_id); + } + return it->second; } static MeshDeviceID generate_unique_mesh_id() { @@ -186,7 +200,8 @@ static MeshDeviceID generate_unique_mesh_id() { return next_id++; } -MeshDevice::MeshDevice(const MeshShape& mesh_device_shape) : mesh_device_shape(mesh_device_shape), mesh_id(generate_unique_mesh_id()) {} +MeshDevice::MeshDevice(const MeshShape& mesh_device_shape, std::shared_ptr parent_mesh) + : mesh_device_shape(mesh_device_shape), mesh_id(generate_unique_mesh_id()), parent_mesh(parent_mesh) {} std::shared_ptr MeshDevice::create( const MeshShape& mesh_device_shape, @@ -203,6 +218,36 @@ std::shared_ptr MeshDevice::create( return mesh_device; } +std::shared_ptr MeshDevice::create_submesh(const MeshShape &submesh_shape, const MeshOffset &offset) { + if (submesh_shape.first <= 0 || submesh_shape.second <= 0) { + TT_THROW("Invalid submesh shape: ({}, {}). Both dimensions must be positive.", submesh_shape.first, submesh_shape.second); + } + + if (offset.first < 0 || offset.second < 0) { + TT_THROW("Invalid offset: ({}, {}). Offset must be non-negative.", offset.first, offset.second); + } + + if (offset.first + submesh_shape.first > this->mesh_device_shape.first || + offset.second + submesh_shape.second > this->mesh_device_shape.second) { + TT_THROW("Submesh ({}x{}) with offset ({}, {}) does not fit within parent mesh ({}x{}).", + submesh_shape.first, submesh_shape.second, + offset.first, offset.second, + this->mesh_device_shape.first, this->mesh_device_shape.second); + } + + auto submesh = std::make_shared(submesh_shape, shared_from_this()); + auto start_coordinate = Coordinate{offset.first, offset.second}; + auto end_coordinate = Coordinate{offset.first + submesh_shape.first - 1, offset.second + submesh_shape.second - 1}; + submesh->primary_view = std::make_unique(*this, start_coordinate, end_coordinate); + submesh->devices = submesh->primary_view->get_devices(); + SystemMesh::instance().register_mesh_device(submesh, submesh->devices); + this->submeshes.push_back(submesh); + log_trace(LogMetal, "Instantiating submesh {}: {}x{} with offset: {} {}", submesh->get_mesh_id(), submesh_shape.first, submesh_shape.second, offset.first, offset.second); + log_trace(LogMetal, "Submesh {} instantiated with {} devices", submesh->get_mesh_id(), submesh->devices); + + return submesh; +} + void MeshDevice::initialize( size_t l1_small_size, size_t trace_region_size, @@ -223,16 +268,18 @@ void MeshDevice::initialize( this->devices = instance.map_mesh_device( shared_from_this(), num_command_queues, l1_small_size, trace_region_size, dispatch_core_type, offset, physical_device_ids); this->primary_view = std::make_unique(*this); - - for (int device_index = 0; device_index < this->devices.size(); device_index++) { - this->physical_id_to_device_index.insert({this->devices[device_index]->id(), device_index}); - } } MeshDevice::~MeshDevice() { if (not this->devices.empty()) { this->close_devices(); } + for (auto submesh : this->submeshes) { + submesh->close_devices(); + } + this->primary_view.reset(); + this->devices.clear(); + this->parent_mesh.reset(); } Device* MeshDevice::get_device_index(int logical_device_id) const { @@ -241,7 +288,7 @@ Device* MeshDevice::get_device_index(int logical_device_id) const { } Device* MeshDevice::get_device(int physical_device_id) const { - return this->devices.at(this->physical_id_to_device_index.at(physical_device_id)); + return SystemMesh::instance().get_device(physical_device_id); } std::vector MeshDevice::get_devices() const { return this->devices; } @@ -250,14 +297,6 @@ Device* MeshDevice::get_device(int row_idx, int col_idx) const { return this->get_device_index(row_idx * num_cols() + col_idx); } -std::vector MeshDevice::get_devices_on_row(int row_idx) const { - return this->primary_view->get_devices_on_row(row_idx); -} - -std::vector MeshDevice::get_devices_on_column(int col_idx) const { - return this->primary_view->get_devices_on_column(col_idx); -} - const DeviceIds MeshDevice::get_device_ids() const { DeviceIds device_ids; for (auto device : this->get_devices()) { @@ -283,7 +322,6 @@ MeshShape MeshDevice::shape() const { return this->mesh_device_shape; } void MeshDevice::close_devices() { SystemMesh::instance().unmap_mesh_device(shared_from_this()); this->devices.clear(); - this->physical_id_to_device_index.clear(); this->primary_view.reset(); } @@ -295,8 +333,60 @@ std::shared_ptr MeshDevice::get_view() const { return this std::shared_ptr MeshDevice::get_view() { return this->primary_view; } +std::vector> MeshDevice::get_submesh_views() { + std::vector> submesh_views; + if (this->submeshes.empty()) { + submesh_views.push_back(this->get_view()); + } + else { + for (auto submesh : this->submeshes) { + submesh_views.push_back(submesh->get_view()); + } + } + return submesh_views; +} + MeshDeviceID MeshDevice::get_mesh_id() const { return this->mesh_id; } +bool MeshDevice::is_parent_mesh() const { return this->parent_mesh == nullptr; } + +std::shared_ptr SystemMesh::get_mesh_device(const std::vector& physical_device_ids) { + log_trace(LogMetal, "Getting mesh device for {} physical devices: {}", physical_device_ids.size(), physical_device_ids); + std::unordered_set input_set(physical_device_ids.begin(), physical_device_ids.end()); + + for (const auto& [mesh_id, mesh_device] : this->assigned_mesh_device_devices) { + const auto& assigned_devices = this->assigned_devices.at(mesh_id); + std::unordered_set assigned_set(assigned_devices.begin(), assigned_devices.end()); + log_trace(LogMetal, "Assigned devices: {}", assigned_devices); + + if (input_set == assigned_set) { + return mesh_device; + } + } + TT_THROW("No mesh device found for the provided devices"); +} + +std::shared_ptr MeshDevice::fetch_mesh_device(const std::vector& devices) { + TT_FATAL(devices.size() > 0, "No devices provided"); + auto& instance = SystemMesh::instance(); + std::vector physical_device_ids; + for (auto device : devices) { + physical_device_ids.push_back(device->id()); + } + return instance.get_mesh_device(physical_device_ids); +} + +std::vector> MeshDevice::get_submeshes() const { return this->submeshes; } + +std::shared_ptr MeshDevice::get_view(const Device* device) { + for (auto submesh_view : this->get_submesh_views()) { + if (submesh_view->contains_device(device->id())) { + return submesh_view; + } + } + TT_THROW("Device {} not found in any submesh view", device->id()); +} + std::ostream& operator<<(std::ostream& os, const MeshDevice& mesh_device) { return os << mesh_device.to_string(); } bool validate_worker_modes(const std::vector& workers) { diff --git a/tt_metal/impl/device/mesh_device.hpp b/tt_metal/impl/device/mesh_device.hpp index 940110973cce..8a5cccc041aa 100644 --- a/tt_metal/impl/device/mesh_device.hpp +++ b/tt_metal/impl/device/mesh_device.hpp @@ -72,6 +72,7 @@ class SystemMesh { // Get the physical device IDs mapped to a MeshDevice std::vector get_mapped_physical_device_ids(const MeshDeviceConfig &config) const; + void register_mesh_device(const std::shared_ptr &mesh_device, const std::vector& devices); // Map MeshDevice to physical devices std::vector map_mesh_device( @@ -85,14 +86,18 @@ class SystemMesh { // Unmap MeshDevice, releasing the associated physical devices. void unmap_mesh_device(const std::shared_ptr &mesh_device); + std::shared_ptr get_mesh_device(const std::vector& physical_device_ids); + Device* get_device(const chip_id_t physical_device_id); }; class MeshDevice : public std::enable_shared_from_this { + private: MeshDeviceID mesh_id; MeshShape mesh_device_shape; std::shared_ptr primary_view; std::vector devices; - std::unordered_map physical_id_to_device_index; + std::shared_ptr parent_mesh; + std::vector> submeshes; void initialize( size_t l1_small_size, @@ -103,7 +108,7 @@ class MeshDevice : public std::enable_shared_from_this { const std::vector &physical_device_ids); public: - MeshDevice(const MeshShape &mesh_device_shape); + MeshDevice(const MeshShape &mesh_device_shape, std::shared_ptr parent_mesh = nullptr); ~MeshDevice(); MeshDevice(const MeshDevice &) = delete; @@ -116,8 +121,6 @@ class MeshDevice : public std::enable_shared_from_this { Device *get_device_index(int logical_device_id) const; Device *get_device(int physical_device_id) const; Device *get_device(int row_idx, int col_idx) const; - std::vector get_devices_on_row(int row_idx) const; - std::vector get_devices_on_column(int col_idx) const; const DeviceIds get_device_ids() const; @@ -138,6 +141,7 @@ class MeshDevice : public std::enable_shared_from_this { std::string to_string() const; MeshDeviceID get_mesh_id() const; + bool is_parent_mesh() const; static std::shared_ptr create( const MeshShape &mesh_device_shape, @@ -147,6 +151,13 @@ class MeshDevice : public std::enable_shared_from_this { DispatchCoreType dispatch_core_type, const std::pair &offset = {0, 0}, const std::vector &physical_device_ids = {}); + + std::vector> get_submeshes() const; + std::vector> get_submesh_views(); + std::shared_ptr get_view(const Device* device); + + std::shared_ptr create_submesh(const MeshShape &submesh_shape, const MeshOffset &offset = {0, 0}); + static std::shared_ptr fetch_mesh_device(const std::vector& devices); }; std::ostream &operator<<(std::ostream &os, const MeshDevice &mesh_device); diff --git a/tt_metal/impl/device/mesh_device_view.cpp b/tt_metal/impl/device/mesh_device_view.cpp index cc4a227780f6..8276596b9f11 100644 --- a/tt_metal/impl/device/mesh_device_view.cpp +++ b/tt_metal/impl/device/mesh_device_view.cpp @@ -11,6 +11,16 @@ namespace tt::tt_metal { using MeshDevice = tt::tt_metal::MeshDevice; +static std::vector get_devices_from_coordinates(MeshDeviceView& mesh, const std::vector& coords) { + std::vector devices; + for (const auto& coord : coords) { + if (auto device = mesh.get_device(coord.row, coord.col)) { + devices.push_back(device); + } + } + return devices; +} + MeshDeviceView::MeshDeviceView(const MeshDevice& mesh) : top_left_(0, 0), bottom_right_(mesh.num_rows() - 1, mesh.num_cols() - 1) { for (size_t row = 0; row < mesh.num_rows(); ++row) { @@ -24,12 +34,12 @@ MeshDeviceView::MeshDeviceView(const MeshDevice& mesh) } MeshDeviceView::MeshDeviceView(const MeshDevice& mesh, Coordinate top_left, Coordinate bottom_right) - : top_left_(top_left), bottom_right_(bottom_right) { + : top_left_(0, 0), bottom_right_(Coordinate{bottom_right.row - top_left.row, bottom_right.col - top_left.col}) { for (size_t row = top_left.row; row <= bottom_right.row; ++row) { for (size_t col = top_left.col; col <= bottom_right.col; ++col) { if (auto device = mesh.get_device(row, col)) { devices_.push_back(device); - device_coordinates_[(device)->id()] = {row, col}; + device_coordinates_[(device)->id()] = {row - top_left.row, col - top_left.col}; } } } @@ -55,10 +65,6 @@ MeshDeviceView::const_device_pointer MeshDeviceView::get_device(size_t row, size return nullptr; } -const std::vector& MeshDeviceView::get_devices() const { - return devices_; -} - MeshDeviceView::DeviceView MeshDeviceView::get_devices(const Coordinate& start, const Coordinate& end) { if (start.row > end.row || start.col > end.col) { log_fatal("Invalid coordinates: start {} must be less than or equal to end {}", start, end); @@ -158,6 +164,11 @@ bool MeshDeviceView::operator==(const MeshDeviceView& other) const { bottom_right_ == other.bottom_right_; } + +bool MeshDeviceView::contains_device(chip_id_t device_id) const { + return device_coordinates_.find(device_id) != device_coordinates_.end(); +} + Coordinate MeshDeviceView::find_device(chip_id_t device_id) const { auto it = device_coordinates_.find(device_id); if (it != device_coordinates_.end()) { @@ -199,5 +210,93 @@ void MeshDeviceView::validate_coordinates() const { throw std::invalid_argument("Invalid coordinates: top_left must be less than or equal to bottom_right"); } } +// Get the boundary coordinates of the subgrid defined by offset and shape +std::vector MeshDeviceView::get_ring_coordinates(const MeshShape& shape, const Coordinate& offset) const { + std::vector boundary_coords; + + size_t start_row = offset.row; + size_t start_col = offset.col; + size_t end_row = offset.row + shape.first - 1; + size_t end_col = offset.col + shape.second - 1; + + // Validate the specified subgrid + if (start_row >= num_rows() || start_col >= num_cols() || + end_row >= num_rows() || end_col >= num_cols()) { + throw std::invalid_argument("Subgrid is out of mesh bounds."); + } + + // Traverse the top row from left to right + for (size_t col = start_col; col <= end_col; ++col) { + boundary_coords.emplace_back(Coordinate{start_row, col}); + } + + // Traverse the rightmost column from top+1 to bottom + for (size_t row = start_row + 1; row <= end_row; ++row) { + boundary_coords.emplace_back(Coordinate{row, end_col}); + } + + // Traverse the bottom row from right to left, if there is more than one row + if (end_row > start_row and end_col > start_col) { + for (size_t col = end_col - 1; col + 1 > start_col; --col) { + boundary_coords.emplace_back(Coordinate{end_row, col}); + } + for (size_t row = end_row - 1; row > start_row; --row) { + boundary_coords.emplace_back(Coordinate{row, start_col}); + } + } + + return boundary_coords; +} + +std::vector MeshDeviceView::get_line_coordinates(size_t length, const Coordinate& offset) const { + std::vector line_coords; + auto [row, col] = offset; + bool left_to_right = true; + + for (size_t i = 0; i < length && row < num_rows() && col < num_cols(); ++i) { + line_coords.emplace_back(Coordinate{row, col}); + + if (left_to_right) { + if (col < num_cols() - 1) { + ++col; + } else { + ++row; + left_to_right = false; + } + } else { + if (col > 0) { + --col; + } else { + ++row; + left_to_right = true; + } + } + } + + return line_coords; +} + +std::vector MeshDeviceView::get_line_devices() { + auto boundary_coords = get_line_coordinates(this->num_rows() * this->num_cols(), this->top_left_); + return get_devices_from_coordinates(*this, boundary_coords); +} + +std::vector MeshDeviceView::get_ring_devices() { + auto boundary_coords = get_ring_coordinates(shape(), this->top_left_); + return get_devices_from_coordinates(*this, boundary_coords); +} + +MeshDeviceView::DeviceView MeshDeviceView::get_devices(IterationOrder order) { + switch (order) { + case IterationOrder::ROW_MAJOR: + return this->devices_; + case IterationOrder::RING: + return this->get_ring_devices(); + case IterationOrder::LINE: + return this->get_line_devices(); + default: + TT_THROW("Unsupported iteration order: {}", order); + } +} } // namespace tt::tt_metal diff --git a/tt_metal/impl/device/mesh_device_view.hpp b/tt_metal/impl/device/mesh_device_view.hpp index 73c9e2b61c20..9e6f40eb39cc 100644 --- a/tt_metal/impl/device/mesh_device_view.hpp +++ b/tt_metal/impl/device/mesh_device_view.hpp @@ -53,6 +53,13 @@ struct Coordinate { * specific sub-regions. This is particularly useful for collective communication operations * (CCL-ops), such as line all-gather, which require column or row views of the device mesh. */ + +enum class IterationOrder { + ROW_MAJOR, + RING, + LINE +}; + class MeshDeviceView { public: using device_pointer = Device*; @@ -68,12 +75,11 @@ class MeshDeviceView { [[nodiscard]] device_pointer get_device(size_t row, size_t col); [[nodiscard]] const_device_pointer get_device(size_t row, size_t col) const; - [[nodiscard]] const std::vector& get_devices() const; - // Get devices spanning the rectangular region defined by the top-left and bottom-right coordinates // devices are returned in row-major order with start/end coordinates inclusive [[nodiscard]] DeviceView get_devices(const Coordinate& start, const Coordinate& end); [[nodiscard]] DeviceView get_devices(const MeshShape& shape); + [[nodiscard]] DeviceView get_devices(IterationOrder order = IterationOrder::ROW_MAJOR); [[nodiscard]] DeviceView get_devices_on_row(size_t row) const; [[nodiscard]] DeviceView get_devices_on_column(size_t col) const; @@ -86,7 +92,7 @@ class MeshDeviceView { [[nodiscard]] bool empty() const noexcept; [[nodiscard]] size_t size() const noexcept; - [[nodiscard]] std::pair shape() const noexcept; + [[nodiscard]] MeshShape shape() const noexcept; [[nodiscard]] bool contains(const Coordinate& coord) const noexcept; [[nodiscard]] const_device_pointer at(const Coordinate& coord) const noexcept; @@ -99,10 +105,17 @@ class MeshDeviceView { [[nodiscard]] std::size_t num_cols() const { return bottom_right_.col - top_left_.col + 1; } [[nodiscard]] std::size_t num_devices() const { return devices_.size(); } + [[nodiscard]] bool contains_device(chip_id_t device_id) const; [[nodiscard]] Coordinate find_device(chip_id_t device_id) const; [[nodiscard]] chip_id_t find_device_id(const Coordinate& coord) const; private: + [[nodiscard]] std::vector get_ring_devices(); + [[nodiscard]] std::vector get_line_devices(); + + [[nodiscard]] std::vector get_ring_coordinates(const MeshShape& shape, const Coordinate& offset) const; + [[nodiscard]] std::vector get_line_coordinates(size_t length, const Coordinate& offset) const; + std::vector devices_; std::unordered_map device_coordinates_; Coordinate top_left_; diff --git a/ttnn/cpp/pybind11/multi_device.hpp b/ttnn/cpp/pybind11/multi_device.hpp index 70d9755d0400..f468217ebe7b 100644 --- a/ttnn/cpp/pybind11/multi_device.hpp +++ b/ttnn/cpp/pybind11/multi_device.hpp @@ -47,6 +47,7 @@ void py_module(py::module& module) { py::arg("offset"), py::arg("physical_device_ids")) .def("get_num_devices", &MeshDevice::num_devices) + .def("get_mesh_id", &MeshDevice::get_mesh_id) .def("get_device_ids", &MeshDevice::get_device_ids) .def( "get_device", @@ -62,26 +63,7 @@ void py_module(py::module& module) { Returns: List[Device]: The devices in the device mesh. )doc") - .def( - "get_devices_on_row", - &MeshDevice::get_devices_on_row, - py::return_value_policy::reference, - R"doc( - Get the devices in a row of the device mesh. - - Returns: - List[Device]: The devices on a row in the device mesh. - )doc") - .def( - "get_devices_on_column", - &MeshDevice::get_devices_on_column, - py::return_value_policy::reference, - R"doc( - Get the devices in a row of the device mesh. - - Returns: - List[Device]: The devices on a row in the device mesh. - )doc") + .def("create_submesh", &MeshDevice::create_submesh, py::arg("submesh_shape"), py::arg("offset") = std::pair{0, 0}, py::return_value_policy::reference_internal, py::keep_alive<0, 1>()) .def( "compute_with_storage_grid_size", &MeshDevice::compute_with_storage_grid_size, diff --git a/ttnn/cpp/ttnn/operations/ccl/all_gather/device/all_gather_op.cpp b/ttnn/cpp/ttnn/operations/ccl/all_gather/device/all_gather_op.cpp index 66f70cb1ef1c..80fe9642e647 100644 --- a/ttnn/cpp/ttnn/operations/ccl/all_gather/device/all_gather_op.cpp +++ b/ttnn/cpp/ttnn/operations/ccl/all_gather/device/all_gather_op.cpp @@ -6,6 +6,7 @@ #include "ttnn/deprecated/tt_dnn/op_library/math.hpp" #include "tt_metal/host_api.hpp" +#include "tt_metal/impl/device/mesh_device.hpp" #include "ttnn/tensor/tensor_utils.hpp" @@ -191,17 +192,18 @@ Tensor all_gather( if (num_devices == 2){ ccl_topology = ttnn::ccl::Topology::Linear; } + auto mesh_device = MeshDevice::fetch_mesh_device(devices); std::vector output_tensors = {Tensor(operation::get_workers_for_op_output({input_tensor}))}; operation::launch_op( - [dim, num_links, memory_config, user_defined_num_workers, user_defined_num_buffers_per_channel, devices, ccl_topology]( + [=]( const std::vector& input_tensors, const std::vector>& optional_input_tensors, const std::vector>& optional_output_tensors) mutable -> std::vector { const auto& input_tensor = input_tensors.at(0); - + auto submesh_view = mesh_device->get_view(input_tensor.device()); return operation::run( - create_all_gather_struct(input_tensor, dim, num_links, memory_config, user_defined_num_workers, user_defined_num_buffers_per_channel, devices, ccl_topology), + create_all_gather_struct(input_tensor, dim, num_links, memory_config, user_defined_num_workers, user_defined_num_buffers_per_channel, submesh_view->get_devices(IterationOrder::RING), ccl_topology), {input_tensor}); }, {input_tensor}, diff --git a/ttnn/cpp/ttnn/tensor/tensor.cpp b/ttnn/cpp/ttnn/tensor/tensor.cpp index d72b0bf50f8c..33e59b484b14 100644 --- a/ttnn/cpp/ttnn/tensor/tensor.cpp +++ b/ttnn/cpp/ttnn/tensor/tensor.cpp @@ -698,7 +698,7 @@ Tensor allocate_tensor_on_device( const std::optional& tile ) { // Top level wrapper to asynchronously create a device tensor (multi-device) - Tensor device_tensor = Tensor(mesh_device->get_devices()); + Tensor device_tensor = Tensor(mesh_device->get_view()->get_devices(IterationOrder::LINE)); uint32_t device_tensor_ref_count = device_tensor.tensor_attributes->record_main_thread_ref_count(); const auto& workers = device_tensor.get_workers(); uint32_t num_workers = workers.size(); @@ -791,14 +791,16 @@ std::vector distribute_tensor_to_mesh(const Tensor& tensor, MeshDevice& return workers; }; - if (mesh_device.get_view() != nullptr and std::holds_alternative(tensor.get_storage())) { + auto mesh_view = mesh_device.get_view(); + if (mesh_view != nullptr and std::holds_alternative(tensor.get_storage())) { const auto& host_storage = std::get(tensor.get_storage()); return std::visit([&](const auto& strategy) { using StrategyType = std::decay_t; if constexpr (std::is_same_v) { - auto mesh_view = mesh_device.get_view(); return mesh_view->get_devices(strategy.shard_mesh); + } else if constexpr (std::is_same_v) { + return mesh_view->get_devices(IterationOrder::LINE); } else { return get_multi_device_workers(mesh_device.get_devices()); }