Skip to content

Commit

Permalink
#0: Fix T3000 MeshDevice fixture to reflect usage
Browse files Browse the repository at this point in the history
  • Loading branch information
cfjchu committed Oct 2, 2024
1 parent fbe2897 commit 7de6ae0
Show file tree
Hide file tree
Showing 7 changed files with 96 additions and 15 deletions.
2 changes: 1 addition & 1 deletion conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,7 @@ def t3k_mesh_device(request, silicon_arch_name, silicon_arch_wormhole_b0, device

request.node.pci_ids = ttnn.get_pcie_device_ids()
mesh_device = ttnn.open_mesh_device(
ttnn.MeshShape(2, 4),
ttnn.MeshShape(1, 8),
dispatch_core_type=get_dispatch_core_type(),
**device_params,
)
Expand Down
4 changes: 4 additions & 0 deletions tests/ttnn/unit_tests/test_multi_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -587,3 +587,7 @@ def test_validate_as_tensor(tmp_path, mesh_device, height, width):
for device in mesh_device.get_devices():
device_tensor = ttnn.get_device_tensor(tensor, device)
assert torch.allclose(ttnn.to_torch(device_tensor), torch_input_tensor)


def test_visualize_mesh_device(t3k_mesh_device):
ttnn.visualize_mesh_device(t3k_mesh_device)
2 changes: 1 addition & 1 deletion tt_metal/impl/device/mesh_configurations/T3000.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{
"logical_to_physical_coordinates": [
[[0, 0], [0, 0, 0, 0]], [[0, 1], [0, 1, 0, 0]], [[0, 2], [0, 2, 0, 0]], [[0, 3], [0, 3, 0, 0]],
[[1, 0], [1, 3, 0, 0]], [[1, 1], [1, 2, 0, 0]], [[1, 2], [1, 1, 0, 0]], [[1, 3], [1, 0, 0, 0]]
[[1, 0], [1, 0, 0, 0]], [[1, 1], [1, 1, 0, 0]], [[1, 2], [1, 2, 0, 0]], [[1, 3], [1, 3, 0, 0]]
]
}
28 changes: 21 additions & 7 deletions tt_metal/impl/device/mesh_device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -113,16 +113,29 @@ std::vector<chip_id_t> SystemMesh::get_mapped_physical_device_ids(const MeshDevi
auto [requested_rows, requested_cols] = config.mesh_shape;
auto [row_offset, col_offset] = config.offset;

for (int row = 0; row < requested_rows; row++) {
for (int col = 0; col < requested_cols; col++) {
auto logical_device_id = (row + row_offset) * system_mesh_cols + (col + col_offset);
auto logical_coordinate = Coordinate{logical_device_id / system_mesh_cols, logical_device_id % system_mesh_cols};
if (requested_rows == 1) {
TT_FATAL(row_offset == 0 and col_offset == 0, "Row and column offsets unsupported for single row mesh");
auto line_coords = MeshDeviceView::get_line_coordinates(requested_cols, Coordinate{row_offset, col_offset}, system_mesh_rows, system_mesh_cols);
for (const auto& logical_coordinate : line_coords) {
auto physical_coordinate = this->logical_to_physical_coordinates.at(logical_coordinate);
auto physical_device_id = this->physical_coordinate_to_device_id.at(physical_coordinate);
physical_device_ids.push_back(physical_device_id);

log_debug(LogMetal, "Logical device ID: {}, Logical coordinate: {}, Physical coordinate: {}, Physical device ID: {}",
logical_device_id, logical_coordinate, physical_coordinate, physical_device_id);
log_debug(LogMetal, "Logical coordinate: {}, Physical coordinate: {}, Physical device ID: {}",
logical_coordinate, physical_coordinate, physical_device_id);
}
} else {
for (int row = 0; row < requested_rows; row++) {
for (int col = 0; col < requested_cols; col++) {
auto logical_device_id = (row + row_offset) * system_mesh_cols + (col + col_offset);
auto logical_coordinate = Coordinate{logical_device_id / system_mesh_cols, logical_device_id % system_mesh_cols};
auto physical_coordinate = this->logical_to_physical_coordinates.at(logical_coordinate);
auto physical_device_id = this->physical_coordinate_to_device_id.at(physical_coordinate);
physical_device_ids.push_back(physical_device_id);

log_debug(LogMetal, "Logical device ID: {}, Logical coordinate: {}, Physical coordinate: {}, Physical device ID: {}",
logical_device_id, logical_coordinate, physical_coordinate, physical_device_id);
}
}
}
return physical_device_ids;
Expand Down Expand Up @@ -313,7 +326,8 @@ std::vector<int> get_t3k_physical_device_ids_ring() {
auto num_devices = instance.get_num_devices();
TT_FATAL(num_devices == 8, "T3000 ring topology only works with 8 devices");

auto physical_device_ids = instance.get_mapped_physical_device_ids(MeshDeviceConfig{instance.get_shape(), MeshOffset{0, 0}});
auto physical_device_ids = instance.get_mapped_physical_device_ids(
MeshDeviceConfig{MeshShape{1, 8}, MeshOffset{0, 0}});
return physical_device_ids;
}

Expand Down
1 change: 0 additions & 1 deletion tt_metal/impl/device/mesh_device.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
#include <optional>
#include <vector>

#include "mesh_device_view.hpp"
#include "tt_metal/impl/device/device.hpp"
#include "tt_metal/impl/device/mesh_device_view.hpp"

Expand Down
69 changes: 64 additions & 5 deletions tt_metal/impl/device/mesh_device_view.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@
// SPDX-License-Identifier: Apache-2.0

#include "tt_metal/impl/device/mesh_device_view.hpp"
#include "tt_metal/impl/device/mesh_device.hpp"

#include <algorithm>
#include <stdexcept>

#include "tt_metal/impl/device/mesh_device.hpp"

namespace tt::tt_metal {

using MeshDevice = tt::tt_metal::MeshDevice;
Expand Down Expand Up @@ -55,9 +57,7 @@ MeshDeviceView::const_device_pointer MeshDeviceView::get_device(size_t row, size
return nullptr;
}

const std::vector<MeshDeviceView::device_pointer>& MeshDeviceView::get_devices() const {
return devices_;
}
const std::vector<MeshDeviceView::device_pointer>& MeshDeviceView::get_devices() const { return devices_; }

MeshDeviceView::DeviceView MeshDeviceView::get_devices(const Coordinate& start, const Coordinate& end) {
if (start.row > end.row || start.col > end.col) {
Expand Down Expand Up @@ -194,10 +194,69 @@ void MeshDeviceView::initialize_from_devices(const std::vector<device_pointer>&
bottom_right_ = {max_row, max_col};
}

std::vector<Coordinate> MeshDeviceView::get_line_coordinates(
size_t length, const Coordinate& offset, size_t num_rows, size_t num_cols) {
std::vector<Coordinate> line_coords;
auto [row_index, col_index] = offset;
bool left_to_right = true;

for (size_t i = 0; i < length && row_index < num_rows && col_index < num_cols; ++i) {
line_coords.emplace_back(Coordinate{row_index, col_index});

if (left_to_right && col_index < num_cols - 1) {
col_index++;
} else if (!left_to_right && col_index > 0) {
col_index--;
} else {
row_index++;
left_to_right = !left_to_right;
}
}

TT_FATAL(line_coords.size() == length, "Failed to get line coordinates");
return line_coords;
}

std::vector<Coordinate> MeshDeviceView::get_ring_coordinates(const MeshShape& ring_shape, const Coordinate& offset, size_t num_rows, size_t num_cols) {
auto [start_row, start_col] = offset;
auto [ring_rows, ring_cols] = ring_shape;
auto end_row = start_row + ring_rows - 1;
auto end_col = start_col + ring_cols - 1;

// Validate the specified subgrid
std::vector<Coordinate> boundary_coords;
if (start_row + ring_rows > num_rows || start_col + ring_cols > num_cols) {
throw std::invalid_argument("Subgrid is out of mesh bounds.");
}

// Traverse the top row from left to right
for (size_t col = start_col; col <= end_col; ++col) {
boundary_coords.emplace_back(Coordinate{start_row, col});
}

// Traverse the rightmost column from top+1 to bottom
for (size_t row = start_row + 1; row <= end_row; ++row) {
boundary_coords.emplace_back(Coordinate{row, end_col});
}

// Traverse the bottom row from right to left, if there is more than one row
if (ring_rows > 1 and ring_cols > 1) {
for (size_t col = end_col - 1; col >= start_col; --col) {
boundary_coords.emplace_back(Coordinate{end_row, col});
}
for (size_t row = end_row - 1; row >= start_row; --row) {
boundary_coords.emplace_back(Coordinate{row, start_col});
}
}

return boundary_coords;
}


void MeshDeviceView::validate_coordinates() const {
if (top_left_.row > bottom_right_.row || top_left_.col > bottom_right_.col) {
throw std::invalid_argument("Invalid coordinates: top_left must be less than or equal to bottom_right");
}
}

} // namespace tt::tt_metal
} // namespace tt::tt_metal
5 changes: 5 additions & 0 deletions tt_metal/impl/device/mesh_device_view.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,11 @@ class MeshDeviceView {
[[nodiscard]] Coordinate find_device(chip_id_t device_id) const;
[[nodiscard]] chip_id_t find_device_id(const Coordinate& coord) const;

// Given a starting coordinate, get the coordinates of a line of devices where device[i-1] is connected to device[i]
// The current support only provides left-to-right and right-to-left snaking of the line.
[[nodiscard]] static std::vector<Coordinate> get_line_coordinates(size_t length, const Coordinate& offset, size_t num_rows, size_t num_cols);
[[nodiscard]] std::vector<Coordinate> get_ring_coordinates(const MeshShape& ring_shape, const Coordinate& offset, size_t num_rows, size_t num_cols) const;

private:
std::vector<device_pointer> devices_;
std::unordered_map<chip_id_t, Coordinate> device_coordinates_;
Expand Down

0 comments on commit 7de6ae0

Please sign in to comment.