diff --git a/conftest.py b/conftest.py index 8ce28b82872c..88ff11111fc9 100644 --- a/conftest.py +++ b/conftest.py @@ -280,7 +280,7 @@ def t3k_mesh_device(request, silicon_arch_name, silicon_arch_wormhole_b0, device request.node.pci_ids = ttnn.get_pcie_device_ids() mesh_device = ttnn.open_mesh_device( - ttnn.MeshShape(2, 4), + ttnn.MeshShape(1, 8), dispatch_core_type=get_dispatch_core_type(), **device_params, ) diff --git a/tests/ttnn/unit_tests/test_multi_device.py b/tests/ttnn/unit_tests/test_multi_device.py index f1c2728857f0..32eb74e3c999 100644 --- a/tests/ttnn/unit_tests/test_multi_device.py +++ b/tests/ttnn/unit_tests/test_multi_device.py @@ -587,3 +587,7 @@ def test_validate_as_tensor(tmp_path, mesh_device, height, width): for device in mesh_device.get_devices(): device_tensor = ttnn.get_device_tensor(tensor, device) assert torch.allclose(ttnn.to_torch(device_tensor), torch_input_tensor) + + +def test_visualize_mesh_device(t3k_mesh_device): + ttnn.visualize_mesh_device(t3k_mesh_device) diff --git a/tt_metal/impl/device/mesh_configurations/T3000.json b/tt_metal/impl/device/mesh_configurations/T3000.json index 2c62209d01fc..acfe3edac004 100644 --- a/tt_metal/impl/device/mesh_configurations/T3000.json +++ b/tt_metal/impl/device/mesh_configurations/T3000.json @@ -1,6 +1,6 @@ { "logical_to_physical_coordinates": [ [[0, 0], [0, 0, 0, 0]], [[0, 1], [0, 1, 0, 0]], [[0, 2], [0, 2, 0, 0]], [[0, 3], [0, 3, 0, 0]], - [[1, 0], [1, 3, 0, 0]], [[1, 1], [1, 2, 0, 0]], [[1, 2], [1, 1, 0, 0]], [[1, 3], [1, 0, 0, 0]] + [[1, 0], [1, 0, 0, 0]], [[1, 1], [1, 1, 0, 0]], [[1, 2], [1, 2, 0, 0]], [[1, 3], [1, 3, 0, 0]] ] } diff --git a/tt_metal/impl/device/mesh_device.cpp b/tt_metal/impl/device/mesh_device.cpp index e90d4a8925e2..1798df086c17 100644 --- a/tt_metal/impl/device/mesh_device.cpp +++ b/tt_metal/impl/device/mesh_device.cpp @@ -113,16 +113,29 @@ std::vector SystemMesh::get_mapped_physical_device_ids(const MeshDevi auto [requested_rows, requested_cols] = config.mesh_shape; auto [row_offset, col_offset] = config.offset; - for (int row = 0; row < requested_rows; row++) { - for (int col = 0; col < requested_cols; col++) { - auto logical_device_id = (row + row_offset) * system_mesh_cols + (col + col_offset); - auto logical_coordinate = Coordinate{logical_device_id / system_mesh_cols, logical_device_id % system_mesh_cols}; + if (requested_rows == 1) { + TT_FATAL(row_offset == 0 and col_offset == 0, "Row and column offsets unsupported for single row mesh"); + auto line_coords = MeshDeviceView::get_line_coordinates(requested_cols, Coordinate{row_offset, col_offset}, system_mesh_rows, system_mesh_cols); + for (const auto& logical_coordinate : line_coords) { auto physical_coordinate = this->logical_to_physical_coordinates.at(logical_coordinate); auto physical_device_id = this->physical_coordinate_to_device_id.at(physical_coordinate); physical_device_ids.push_back(physical_device_id); - log_debug(LogMetal, "Logical device ID: {}, Logical coordinate: {}, Physical coordinate: {}, Physical device ID: {}", - logical_device_id, logical_coordinate, physical_coordinate, physical_device_id); + log_debug(LogMetal, "Logical coordinate: {}, Physical coordinate: {}, Physical device ID: {}", + logical_coordinate, physical_coordinate, physical_device_id); + } + } else { + for (int row = 0; row < requested_rows; row++) { + for (int col = 0; col < requested_cols; col++) { + auto logical_device_id = (row + row_offset) * system_mesh_cols + (col + col_offset); + auto logical_coordinate = Coordinate{logical_device_id / system_mesh_cols, logical_device_id % system_mesh_cols}; + auto physical_coordinate = this->logical_to_physical_coordinates.at(logical_coordinate); + auto physical_device_id = this->physical_coordinate_to_device_id.at(physical_coordinate); + physical_device_ids.push_back(physical_device_id); + + log_debug(LogMetal, "Logical device ID: {}, Logical coordinate: {}, Physical coordinate: {}, Physical device ID: {}", + logical_device_id, logical_coordinate, physical_coordinate, physical_device_id); + } } } return physical_device_ids; @@ -313,7 +326,8 @@ std::vector get_t3k_physical_device_ids_ring() { auto num_devices = instance.get_num_devices(); TT_FATAL(num_devices == 8, "T3000 ring topology only works with 8 devices"); - auto physical_device_ids = instance.get_mapped_physical_device_ids(MeshDeviceConfig{instance.get_shape(), MeshOffset{0, 0}}); + auto physical_device_ids = instance.get_mapped_physical_device_ids( + MeshDeviceConfig{MeshShape{1, 8}, MeshOffset{0, 0}}); return physical_device_ids; } diff --git a/tt_metal/impl/device/mesh_device.hpp b/tt_metal/impl/device/mesh_device.hpp index 940110973cce..6ae7c58ba792 100644 --- a/tt_metal/impl/device/mesh_device.hpp +++ b/tt_metal/impl/device/mesh_device.hpp @@ -9,7 +9,6 @@ #include #include -#include "mesh_device_view.hpp" #include "tt_metal/impl/device/device.hpp" #include "tt_metal/impl/device/mesh_device_view.hpp" diff --git a/tt_metal/impl/device/mesh_device_view.cpp b/tt_metal/impl/device/mesh_device_view.cpp index cc4a227780f6..89963b10764f 100644 --- a/tt_metal/impl/device/mesh_device_view.cpp +++ b/tt_metal/impl/device/mesh_device_view.cpp @@ -3,10 +3,12 @@ // SPDX-License-Identifier: Apache-2.0 #include "tt_metal/impl/device/mesh_device_view.hpp" -#include "tt_metal/impl/device/mesh_device.hpp" + #include #include +#include "tt_metal/impl/device/mesh_device.hpp" + namespace tt::tt_metal { using MeshDevice = tt::tt_metal::MeshDevice; @@ -55,9 +57,7 @@ MeshDeviceView::const_device_pointer MeshDeviceView::get_device(size_t row, size return nullptr; } -const std::vector& MeshDeviceView::get_devices() const { - return devices_; -} +const std::vector& MeshDeviceView::get_devices() const { return devices_; } MeshDeviceView::DeviceView MeshDeviceView::get_devices(const Coordinate& start, const Coordinate& end) { if (start.row > end.row || start.col > end.col) { @@ -194,10 +194,69 @@ void MeshDeviceView::initialize_from_devices(const std::vector& bottom_right_ = {max_row, max_col}; } +std::vector MeshDeviceView::get_line_coordinates( + size_t length, const Coordinate& offset, size_t num_rows, size_t num_cols) { + std::vector line_coords; + auto [row_index, col_index] = offset; + bool left_to_right = true; + + for (size_t i = 0; i < length && row_index < num_rows && col_index < num_cols; ++i) { + line_coords.emplace_back(Coordinate{row_index, col_index}); + + if (left_to_right && col_index < num_cols - 1) { + col_index++; + } else if (!left_to_right && col_index > 0) { + col_index--; + } else { + row_index++; + left_to_right = !left_to_right; + } + } + + TT_FATAL(line_coords.size() == length, "Failed to get line coordinates"); + return line_coords; +} + +std::vector MeshDeviceView::get_ring_coordinates(const MeshShape& ring_shape, const Coordinate& offset, size_t num_rows, size_t num_cols) { + auto [start_row, start_col] = offset; + auto [ring_rows, ring_cols] = ring_shape; + auto end_row = start_row + ring_rows - 1; + auto end_col = start_col + ring_cols - 1; + + // Validate the specified subgrid + std::vector boundary_coords; + if (start_row + ring_rows > num_rows || start_col + ring_cols > num_cols) { + throw std::invalid_argument("Subgrid is out of mesh bounds."); + } + + // Traverse the top row from left to right + for (size_t col = start_col; col <= end_col; ++col) { + boundary_coords.emplace_back(Coordinate{start_row, col}); + } + + // Traverse the rightmost column from top+1 to bottom + for (size_t row = start_row + 1; row <= end_row; ++row) { + boundary_coords.emplace_back(Coordinate{row, end_col}); + } + + // Traverse the bottom row from right to left, if there is more than one row + if (ring_rows > 1 and ring_cols > 1) { + for (size_t col = end_col - 1; col >= start_col; --col) { + boundary_coords.emplace_back(Coordinate{end_row, col}); + } + for (size_t row = end_row - 1; row >= start_row; --row) { + boundary_coords.emplace_back(Coordinate{row, start_col}); + } + } + + return boundary_coords; +} + + void MeshDeviceView::validate_coordinates() const { if (top_left_.row > bottom_right_.row || top_left_.col > bottom_right_.col) { throw std::invalid_argument("Invalid coordinates: top_left must be less than or equal to bottom_right"); } } -} // namespace tt::tt_metal +} // namespace tt::tt_metal diff --git a/tt_metal/impl/device/mesh_device_view.hpp b/tt_metal/impl/device/mesh_device_view.hpp index 73c9e2b61c20..be95a5b162c1 100644 --- a/tt_metal/impl/device/mesh_device_view.hpp +++ b/tt_metal/impl/device/mesh_device_view.hpp @@ -102,6 +102,11 @@ class MeshDeviceView { [[nodiscard]] Coordinate find_device(chip_id_t device_id) const; [[nodiscard]] chip_id_t find_device_id(const Coordinate& coord) const; + // Given a starting coordinate, get the coordinates of a line of devices where device[i-1] is connected to device[i] + // The current support only provides left-to-right and right-to-left snaking of the line. + [[nodiscard]] static std::vector get_line_coordinates(size_t length, const Coordinate& offset, size_t num_rows, size_t num_cols); + [[nodiscard]] std::vector get_ring_coordinates(const MeshShape& ring_shape, const Coordinate& offset, size_t num_rows, size_t num_cols) const; + private: std::vector devices_; std::unordered_map device_coordinates_;