diff --git a/tests/tt_metal/tt_metal/perf_microbenchmark/ethernet/test_ethernet_hop_latencies_no_edm.cpp b/tests/tt_metal/tt_metal/perf_microbenchmark/ethernet/test_ethernet_hop_latencies_no_edm.cpp index f9adf4f13fbe..dc922d70ecb2 100644 --- a/tests/tt_metal/tt_metal/perf_microbenchmark/ethernet/test_ethernet_hop_latencies_no_edm.cpp +++ b/tests/tt_metal/tt_metal/perf_microbenchmark/ethernet/test_ethernet_hop_latencies_no_edm.cpp @@ -449,48 +449,48 @@ int main(int argc, char** argv) { T3000TestDevice test_fixture; auto view = test_fixture.mesh_device_->get_view(); - auto get_device_list = [](const std::shared_ptr& view, std::size_t n_hops) { + auto get_device_list = [](const MeshDeviceView& view, std::size_t n_hops) { switch (n_hops) { case 2: return std::vector{ - view->get_device(0, 0), - view->get_device(0, 1), + view.get_device(0, 0), + view.get_device(0, 1), }; case 4: return std::vector{ - view->get_device(1, 1), - view->get_device(0, 1), - view->get_device(0, 2), - view->get_device(1, 2), + view.get_device(1, 1), + view.get_device(0, 1), + view.get_device(0, 2), + view.get_device(1, 2), }; case 8: return std::vector{ - view->get_device(1, 1), - view->get_device(1, 0), - view->get_device(0, 0), - view->get_device(0, 1), - view->get_device(0, 2), - view->get_device(0, 3), - view->get_device(1, 3), - view->get_device(1, 2), + view.get_device(1, 1), + view.get_device(1, 0), + view.get_device(0, 0), + view.get_device(0, 1), + view.get_device(0, 2), + view.get_device(0, 3), + view.get_device(1, 3), + view.get_device(1, 2), }; case 12: // Does an extra loop through the inner ring return std::vector{ - view->get_device(1, 1), - view->get_device(1, 0), - view->get_device(0, 0), - view->get_device(0, 1), - view->get_device(0, 2), - view->get_device(1, 2), - view->get_device(1, 1), - view->get_device(0, 1), - view->get_device(0, 2), - view->get_device(0, 3), - view->get_device(1, 3), - view->get_device(1, 2), + view.get_device(1, 1), + view.get_device(1, 0), + 
view.get_device(0, 0), + view.get_device(0, 1), + view.get_device(0, 2), + view.get_device(1, 2), + view.get_device(1, 1), + view.get_device(0, 1), + view.get_device(0, 2), + view.get_device(0, 3), + view.get_device(1, 3), + view.get_device(1, 2), }; default: TT_THROW("Unsupported hop_count"); return std::vector{}; diff --git a/tests/ttnn/distributed/CMakeLists.txt b/tests/ttnn/distributed/CMakeLists.txt index f41d726988af..5823925eec31 100644 --- a/tests/ttnn/distributed/CMakeLists.txt +++ b/tests/ttnn/distributed/CMakeLists.txt @@ -1,11 +1,13 @@ add_executable( test_distributed test_distributed.cpp - test_distributed_atexit.cpp + test_distributed_reshape.cpp ) +add_executable(test_distributed_atexit test_distributed_atexit.cpp) # Set up properties for the target setup_ttnn_test_target(test_distributed) - +setup_ttnn_test_target(test_distributed_atexit) # Add test to CTest add_test(NAME test_distributed COMMAND test_distributed) +add_test(NAME test_distributed_atexit COMMAND test_distributed_atexit) diff --git a/tests/ttnn/distributed/test_distributed_reshape.cpp b/tests/ttnn/distributed/test_distributed_reshape.cpp new file mode 100644 index 000000000000..0c5ddb11f0b8 --- /dev/null +++ b/tests/ttnn/distributed/test_distributed_reshape.cpp @@ -0,0 +1,280 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. 
+// +// SPDX-License-Identifier: Apache-2.0 + +#include + +#include +#include +#include +#include +#include "tests/tt_metal/test_utils/env_vars.hpp" + +namespace ttnn::distributed::test { + +// Helper function to check test environment +void check_test_environment() { + auto slow_dispatch = getenv("TT_METAL_SLOW_DISPATCH_MODE"); + const auto arch = tt::get_arch_from_string(tt::test_utils::get_umd_arch_name()); + const size_t num_devices = tt::tt_metal::GetNumAvailableDevices(); + if (slow_dispatch) { + GTEST_SKIP() << "Skipping Multi-Device test suite, since it can only be run in Fast Dispatch Mode."; + } + if (num_devices < 8 or arch != tt::ARCH::WORMHOLE_B0) { + GTEST_SKIP() << "Skipping T3K Multi-Device test suite on non T3K machine."; + } +} + +std::vector get_physical_device_ids(const MeshDevice& mesh) { + std::vector device_ids; + for (auto* device : mesh.get_devices(ttnn::distributed::MeshType::RowMajor)) { + device_ids.push_back(device->id()); + } + return device_ids; +} + +static constexpr std::array kMeshShapes{ + {{1, 1}, {1, 2}, {1, 3}, {1, 4}, {1, 5}, {1, 6}, {1, 7}, {1, 8}, {2, 1}, {2, 2}, {2, 3}, {2, 4}, + {3, 1}, {3, 2}, {4, 1}, {4, 2}, {8, 1}, {7, 1}, {6, 1}, {5, 1}, {4, 1}, {3, 1}, {2, 1}, {1, 1}}}; + +class MeshConfigurationTest : public ::testing::TestWithParam { +protected: + void SetUp() override { check_test_environment(); } +}; + +TEST_P(MeshConfigurationTest, TestMeshConfigurations) { + const auto& shape = GetParam(); + auto mesh = ttnn::distributed::open_mesh_device( + {shape.num_rows, shape.num_cols}, + DEFAULT_L1_SMALL_SIZE, + DEFAULT_TRACE_REGION_SIZE, + 1, + tt::tt_metal::DispatchCoreType::WORKER); + EXPECT_EQ(mesh->num_rows(), shape.num_rows); + EXPECT_EQ(mesh->num_cols(), shape.num_cols); + ttnn::distributed::close_mesh_device(mesh); +} + +// Test all possible mesh configurations on T3000 +INSTANTIATE_TEST_SUITE_P(MeshShapes, MeshConfigurationTest, ::testing::ValuesIn(kMeshShapes)); + +class MeshReshapeTest : public 
::testing::TestWithParam> { +protected: + void SetUp() override { check_test_environment(); } +}; + +TEST_P(MeshReshapeTest, TestReshapeBetweenConfigurations) { + const auto& [old_shape, new_shape] = GetParam(); + + if ((old_shape.num_rows * old_shape.num_cols) != (new_shape.num_rows * new_shape.num_cols)) { + GTEST_SKIP() << "Device counts don't match; we test this in InvalidReshapeDimensions"; + } + if (old_shape.num_rows == 1 or old_shape.num_cols == 1) { + GTEST_SKIP() << "Old shape is 1xN or Nx1; we test this in From1x4To2x2Invalid"; + } + + auto mesh = ttnn::distributed::open_mesh_device( + {old_shape.num_rows, old_shape.num_cols}, + DEFAULT_L1_SMALL_SIZE, + DEFAULT_TRACE_REGION_SIZE, + 1, + tt::tt_metal::DispatchCoreType::WORKER); + + EXPECT_EQ(mesh->num_rows(), old_shape.num_rows); + EXPECT_EQ(mesh->num_cols(), old_shape.num_cols); + + auto original_order = get_physical_device_ids(*mesh); + + // Attempt reshape + mesh->reshape({new_shape.num_rows, new_shape.num_cols}); + + // Verify new shape + EXPECT_EQ(mesh->num_rows(), new_shape.num_rows); + EXPECT_EQ(mesh->num_cols(), new_shape.num_cols); + + // Verify device ordering is preserved + EXPECT_EQ(get_physical_device_ids(*mesh), original_order); +} + +// Generate all possible combinations of shapes from kMeshShapes +INSTANTIATE_TEST_SUITE_P( + ReshapeConfigurations, + MeshReshapeTest, + ::testing::Combine(::testing::ValuesIn(kMeshShapes), ::testing::ValuesIn(kMeshShapes))); + +// Base class for non-parameterized tests +class T3000ReshapeTest : public ::testing::Test { +protected: + void SetUp() override { check_test_environment(); } +}; + +TEST_F(T3000ReshapeTest, InvalidReshapeDimensions) { + auto mesh = ttnn::distributed::open_mesh_device( + {1, 8}, DEFAULT_L1_SMALL_SIZE, DEFAULT_TRACE_REGION_SIZE, 1, tt::tt_metal::DispatchCoreType::WORKER); + + // Test reshaping to dimensions that don't match total device count + EXPECT_THROW(mesh->reshape({3, 3}), std::runtime_error); // 9 devices != 8 + 
EXPECT_THROW(mesh->reshape({1, 9}), std::runtime_error); // 9 devices != 8 + + // Verify original shape is preserved after failed reshapes + EXPECT_EQ(mesh->num_rows(), 1); + EXPECT_EQ(mesh->num_cols(), 8); +} + +TEST_F(T3000ReshapeTest, From1x8To2x4) { + auto mesh = ttnn::distributed::open_mesh_device( + {1, 8}, DEFAULT_L1_SMALL_SIZE, DEFAULT_TRACE_REGION_SIZE, 1, tt::tt_metal::DispatchCoreType::WORKER); + + EXPECT_EQ(mesh->num_rows(), 1); + EXPECT_EQ(mesh->num_cols(), 8); + auto original_order = get_physical_device_ids(*mesh); + + mesh->reshape({2, 4}); + EXPECT_EQ(mesh->num_rows(), 2); + EXPECT_EQ(mesh->num_cols(), 4); + auto new_order = get_physical_device_ids(*mesh); + EXPECT_EQ(original_order, new_order); +} + +TEST_F(T3000ReshapeTest, OnRingTopology) { + auto mesh = ttnn::distributed::open_mesh_device( + {1, 8}, + DEFAULT_L1_SMALL_SIZE, + DEFAULT_TRACE_REGION_SIZE, + 1, + tt::tt_metal::DispatchCoreType::WORKER, + ttnn::distributed::MeshType::Ring); + + EXPECT_EQ(mesh->num_rows(), 1); + EXPECT_EQ(mesh->num_cols(), 8); + auto original_order = get_physical_device_ids(*mesh); + + mesh->reshape({2, 4}); + + EXPECT_EQ(mesh->num_rows(), 2); + EXPECT_EQ(mesh->num_cols(), 4); + auto new_order = get_physical_device_ids(*mesh); + EXPECT_EQ(original_order, new_order); +} + +TEST_F(T3000ReshapeTest, InvalidTotalDeviceCount) { + auto mesh = ttnn::distributed::open_mesh_device( + {1, 8}, DEFAULT_L1_SMALL_SIZE, DEFAULT_TRACE_REGION_SIZE, 1, tt::tt_metal::DispatchCoreType::WORKER); + + // Test reshaping to dimensions that don't match total device count + EXPECT_THROW(mesh->reshape({3, 3}), std::runtime_error); // 9 devices != 8 + EXPECT_THROW(mesh->reshape({1, 9}), std::runtime_error); // 9 devices != 8 + + // Verify original shape is preserved after failed reshapes + EXPECT_EQ(mesh->num_rows(), 1); + EXPECT_EQ(mesh->num_cols(), 8); +} + +TEST_F(T3000ReshapeTest, MultipleReshapes) { + auto mesh = ttnn::distributed::open_mesh_device( + {1, 8}, DEFAULT_L1_SMALL_SIZE, 
DEFAULT_TRACE_REGION_SIZE, 1, tt::tt_metal::DispatchCoreType::WORKER); + + auto original_order = get_physical_device_ids(*mesh); + + // Test multiple reshapes + mesh->reshape({2, 4}); // 1x8 -> 2x4 + auto order1 = get_physical_device_ids(*mesh); + EXPECT_EQ(order1, original_order); + + mesh->reshape({4, 2}); // 2x4 -> 4x2 + auto order2 = get_physical_device_ids(*mesh); + EXPECT_EQ(order2, original_order); + + mesh->reshape({1, 8}); // 4x2 -> 1x8 (back to original) + auto final_order = get_physical_device_ids(*mesh); + EXPECT_EQ(final_order, original_order); +} + +TEST_F(T3000ReshapeTest, RingPreservation) { + auto mesh = ttnn::distributed::open_mesh_device( + {1, 8}, DEFAULT_L1_SMALL_SIZE, DEFAULT_TRACE_REGION_SIZE, 1, tt::tt_metal::DispatchCoreType::WORKER); + + // Store original device positions + std::vector original_layout; + for (size_t i = 0; i < mesh->num_rows(); ++i) { + for (size_t j = 0; j < mesh->num_cols(); ++j) { + original_layout.push_back(mesh->get_device(i, j)->id()); + } + } + + mesh->reshape({2, 4}); + + // Verify devices are still connected in a Ring topology + std::vector new_layout; + for (size_t i = 0; i < mesh->num_rows(); ++i) { + for (size_t j = 0; j < mesh->num_cols(); ++j) { + new_layout.push_back(mesh->get_device(i, j)->id()); + } + } + EXPECT_EQ(new_layout, original_layout); +} + +TEST_F(T3000ReshapeTest, From1x4To2x2Invalid) { + auto mesh = ttnn::distributed::open_mesh_device( + {1, 4}, DEFAULT_L1_SMALL_SIZE, DEFAULT_TRACE_REGION_SIZE, 1, tt::tt_metal::DispatchCoreType::WORKER); + + // This is an invalid reshape because the 1x4 mesh does not fully cover the 2x2 mesh + EXPECT_THROW(mesh->reshape({2, 2}), std::runtime_error); +} + +TEST_F(T3000ReshapeTest, From1x4To2x2Valid) { + auto& system_mesh = tt::tt_metal::distributed::SystemMesh::instance(); + + // Fetch the device ids for a physically connected 2x2 mesh. 
+ auto physical_device_ids = system_mesh.get_mapped_physical_device_ids( + MeshDeviceConfig(MeshShape{2, 2}, ttnn::distributed::MeshType::Line)); + + // Supply the physical device ids to the mesh constructor that we know we know is 2x2 physically connected. + // We will create a 1x4 mesh and then reshape it to 2x2. + auto mesh = ttnn::distributed::open_mesh_device( + {1, 4}, + DEFAULT_L1_SMALL_SIZE, + DEFAULT_TRACE_REGION_SIZE, + 1, + tt::tt_metal::DispatchCoreType::WORKER, + ttnn::distributed::MeshType::Line, + MeshOffset{0, 0}, + physical_device_ids); + + mesh->reshape({2, 2}); + EXPECT_EQ(mesh->num_rows(), 2); + EXPECT_EQ(mesh->num_cols(), 2); + auto new_layout = get_physical_device_ids(*mesh); + for (auto physical_device_id : physical_device_ids) { + EXPECT_TRUE(std::find(new_layout.begin(), new_layout.end(), physical_device_id) != new_layout.end()); + } +} + +TEST_F(T3000ReshapeTest, From2x2To1x4) { + auto mesh = ttnn::distributed::open_mesh_device( + {2, 2}, DEFAULT_L1_SMALL_SIZE, DEFAULT_TRACE_REGION_SIZE, 1, tt::tt_metal::DispatchCoreType::WORKER); + + std::vector original_layout; + for (size_t i = 0; i < mesh->num_rows(); ++i) { + for (size_t j = 0; j < mesh->num_cols(); ++j) { + auto id = mesh->get_device(i, j)->id(); + original_layout.push_back(id); + } + } + + mesh->reshape({1, 4}); + EXPECT_EQ(mesh->num_rows(), 1); + EXPECT_EQ(mesh->num_cols(), 4); + + std::vector new_layout; + for (size_t i = 0; i < mesh->num_rows(); ++i) { + for (size_t j = 0; j < mesh->num_cols(); ++j) { + auto id = mesh->get_device(i, j)->id(); + new_layout.push_back(id); + } + } + + EXPECT_EQ(new_layout, original_layout); +} + +} // namespace ttnn::distributed::test diff --git a/tt_metal/distributed/mesh_device.cpp b/tt_metal/distributed/mesh_device.cpp index fb709c39901e..073c92b18314 100644 --- a/tt_metal/distributed/mesh_device.cpp +++ b/tt_metal/distributed/mesh_device.cpp @@ -26,7 +26,8 @@ static std::string get_config_path(const std::string& filename) { return root_path + 
"/tt_metal/distributed/mesh_configurations/" + filename; } -static std::unordered_map load_translation_map(const std::string& filename, const std::string& key) { +static std::unordered_map load_translation_map( + const std::string& filename, const std::string& key) { std::ifstream file(filename); if (!file.is_open()) { throw std::runtime_error("Unable to open file: " + filename); @@ -72,6 +73,7 @@ class SystemMesh::Impl { MeshShape logical_mesh_shape_; std::unordered_map logical_to_physical_coordinates_; + std::unordered_map logical_to_device_id_; std::unordered_map physical_coordinate_to_device_id_; std::unordered_map physical_device_id_to_coordinate_; @@ -91,6 +93,8 @@ class SystemMesh::Impl { static MeshShape get_system_mesh_shape(size_t system_num_devices); static std::unordered_map get_system_mesh_translation_map( size_t system_num_devices); + + chip_id_t get_physical_device_id(size_t logical_row_idx, size_t logical_col_idx) const; }; // Implementation of private static methods @@ -152,6 +156,9 @@ void SystemMesh::Impl::initialize() { auto num_devices = physical_coordinate_to_device_id_.size(); logical_mesh_shape_ = get_system_mesh_shape(num_devices); logical_to_physical_coordinates_ = get_system_mesh_translation_map(num_devices); + for (const auto& [logical_coordinate, physical_coordinate] : logical_to_physical_coordinates_) { + logical_to_device_id_.emplace(logical_coordinate, physical_coordinate_to_device_id_.at(physical_coordinate)); + } } const MeshShape& SystemMesh::Impl::get_shape() const { return logical_mesh_shape_; } @@ -160,50 +167,125 @@ size_t SystemMesh::Impl::get_num_devices() const { return num_rows * num_cols; } +chip_id_t SystemMesh::Impl::get_physical_device_id(size_t logical_row_idx, size_t logical_col_idx) const { + TT_FATAL( + logical_row_idx < logical_mesh_shape_.num_rows, + "Row index out of bounds: {} >= {}", + logical_row_idx, + logical_mesh_shape_.num_rows); + TT_FATAL( + logical_col_idx < logical_mesh_shape_.num_cols, + "Column 
index out of bounds: {} >= {}", + logical_col_idx, + logical_mesh_shape_.num_cols); + auto logical_coordinate = Coordinate{logical_row_idx, logical_col_idx}; + return logical_to_device_id_.at(logical_coordinate); +} + std::vector SystemMesh::Impl::get_mapped_physical_device_ids(const MeshDeviceConfig& config) const { std::vector physical_device_ids; auto [system_mesh_rows, system_mesh_cols] = this->get_shape(); - auto [requested_rows, requested_cols] = config.mesh_shape; + auto [requested_num_rows, requested_num_cols] = config.mesh_shape; auto [row_offset, col_offset] = config.offset; - if (requested_rows == 1) { + // First check if total size fits + TT_FATAL( + requested_num_rows * requested_num_cols <= system_mesh_rows * system_mesh_cols, + "Requested submesh is too big: {}x{}", + requested_num_rows, + requested_num_cols); + + bool is_single_row_or_column = requested_num_rows == 1 or requested_num_cols == 1; + if (is_single_row_or_column) { TT_FATAL(row_offset == 0 and col_offset == 0, "Row and column offsets unsupported for single row mesh"); + auto line_length = requested_num_rows * requested_num_cols; auto line_coords = MeshDeviceView::get_line_coordinates( - requested_cols, Coordinate{row_offset, col_offset}, system_mesh_rows, system_mesh_cols); + line_length, Coordinate{row_offset, col_offset}, system_mesh_rows, system_mesh_cols); for (const auto& logical_coordinate : line_coords) { - auto physical_coordinate = logical_to_physical_coordinates_.at(logical_coordinate); - auto physical_device_id = physical_coordinate_to_device_id_.at(physical_coordinate); + auto physical_device_id = logical_to_device_id_.at(logical_coordinate); physical_device_ids.push_back(physical_device_id); log_debug( LogMetal, - "Logical coordinate: {}, Physical coordinate: {}, Physical device ID: {}", + "Logical coordinate: {}, Physical device ID: {}", logical_coordinate, - physical_coordinate, physical_device_id); } + return physical_device_ids; + } + bool requires_rotation = 
requested_num_rows > system_mesh_rows || requested_num_cols > system_mesh_cols; + + + if (requires_rotation) { + bool can_rotate = requested_num_rows <= system_mesh_cols && requested_num_cols <= system_mesh_rows; + if (can_rotate) { + // Rotate requested shape; row_offset and col_offset refer to original orientation + std::swap(requested_num_rows, requested_num_cols); + } else { + TT_THROW("User has requested a submesh that is too big and is not rotatable: {}x{} and SystemMesh is {}x{}.", + requested_num_rows, requested_num_cols, + system_mesh_rows, system_mesh_cols); + } } else { - for (int row = 0; row < requested_rows; row++) { - for (int col = 0; col < requested_cols; col++) { - auto logical_device_id = (row + row_offset) * system_mesh_cols + (col + col_offset); - auto logical_coordinate = - Coordinate{logical_device_id / system_mesh_cols, logical_device_id % system_mesh_cols}; - auto physical_coordinate = logical_to_physical_coordinates_.at(logical_coordinate); - auto physical_device_id = physical_coordinate_to_device_id_.at(physical_coordinate); - physical_device_ids.push_back(physical_device_id); - - log_debug( - LogMetal, - "Logical device ID: {}, Logical coordinate: {}, Physical coordinate: {}, Physical device ID: {}", - logical_device_id, - logical_coordinate, - physical_coordinate, - physical_device_id); + // If no rotation, check dimensions directly + TT_FATAL( + requested_num_rows <= system_mesh_rows && requested_num_cols <= system_mesh_cols, + "Requested submesh is too big: {}x{} and SystemMesh is {}x{}", + requested_num_rows, requested_num_cols, + system_mesh_rows, system_mesh_cols); + } + + size_t original_rows = system_mesh_rows; + size_t original_cols = system_mesh_cols; + + // Check that offsets fit in the original mesh + TT_FATAL( + row_offset + requested_num_rows <= original_rows, + "Row offset + requested rows exceeds mesh size: {} + {} > {}", + row_offset, requested_num_rows, original_rows); + TT_FATAL( + col_offset + requested_num_cols <= 
original_cols, + "Column offset + requested columns exceeds mesh size: {} + {} > {}", + col_offset, requested_num_cols, original_cols); + + // Map each submesh coordinate to the original logical coordinates + for (size_t row = 0; row < requested_num_rows; row++) { + for (size_t col = 0; col < requested_num_cols; col++) { + Coordinate logical_coordinate; + if (requires_rotation) { + // After swapping requested_num_rows and requested_num_cols, + // (row, col) now iterate over the rotated shape. + size_t old_row = row_offset + row; // top row + size_t old_col = col_offset + col; // increasing columns horizontally + logical_coordinate = Coordinate{old_row, old_col}; + } else { + logical_coordinate = Coordinate{row + row_offset, col + col_offset}; } + + TT_FATAL( + logical_coordinate.row < logical_mesh_shape_.num_rows, + "Row coordinate out of bounds: {} >= {}", + logical_coordinate.row, + logical_mesh_shape_.num_rows); + TT_FATAL( + logical_coordinate.col < logical_mesh_shape_.num_cols, + "Column coordinate out of bounds: {} >= {}", + logical_coordinate.col, + logical_mesh_shape_.num_cols); + + auto physical_device_id = logical_to_device_id_.at(logical_coordinate); + physical_device_ids.push_back(physical_device_id); + + log_debug( + LogMetal, + "Logical coordinate: {}, Physical device ID: {}", + logical_coordinate, + physical_device_id); } } return physical_device_ids; } + void SystemMesh::Impl::register_mesh_device( const std::shared_ptr& mesh_device, const std::vector& devices) { std::vector physical_device_ids; @@ -226,14 +308,9 @@ std::vector SystemMesh::Impl::request_available_devices(const MeshDev requested_num_cols, row_offset, col_offset); - TT_FATAL(requested_num_rows <= max_num_rows, "Requested too many rows: {} > {}", requested_num_rows, max_num_rows); - TT_FATAL( - requested_num_rows * requested_num_cols <= max_num_rows * max_num_cols, - "Requested submesh is too big: {}x{}", - requested_num_rows, - requested_num_cols); - return 
config.physical_device_ids.empty() ? this->get_mapped_physical_device_ids(config) : config.physical_device_ids; + return config.physical_device_ids.empty() ? this->get_mapped_physical_device_ids(config) + : config.physical_device_ids; } SystemMesh::SystemMesh() : pimpl_(std::make_unique()) {} @@ -247,6 +324,10 @@ SystemMesh& SystemMesh::instance() { return instance; } +chip_id_t SystemMesh::get_physical_device_id(size_t logical_row_idx, size_t logical_col_idx) const { + return pimpl_->get_physical_device_id(logical_row_idx, logical_col_idx); +} + const MeshShape& SystemMesh::get_shape() const { return pimpl_->get_shape(); } size_t SystemMesh::get_num_devices() const { return pimpl_->get_num_devices(); } @@ -269,9 +350,7 @@ static MeshDeviceID generate_unique_mesh_id() { return next_id++; } -Device* MeshDevice::reference_device() const { - return this->devices.at(0); -} +Device* MeshDevice::reference_device() const { return this->devices.at(0); } MeshDevice::MeshDevice(const MeshShape& mesh_device_shape, MeshType type, std::weak_ptr parent_mesh) : mesh_device_shape(mesh_device_shape), @@ -319,8 +398,8 @@ std::shared_ptr MeshDevice::create_submesh( auto submesh = std::make_shared(submesh_shape, type, shared_from_this()); auto start_coordinate = Coordinate{offset.row, offset.col}; auto end_coordinate = Coordinate{offset.row + submesh_shape.num_rows - 1, offset.col + submesh_shape.num_cols - 1}; - submesh->primary_view = std::make_shared(*this, start_coordinate, end_coordinate); - submesh->devices = submesh->primary_view->get_devices(); + submesh->view = std::make_unique(*this, start_coordinate, end_coordinate); + submesh->devices = submesh->view->get_devices(); SystemMesh::instance().register_mesh_device(submesh, submesh->devices); this->submeshes.push_back(submesh); log_trace( @@ -353,15 +432,6 @@ void MeshDevice::initialize( size_t num_command_queues, const DispatchCoreConfig& dispatch_core_config, const MeshDeviceConfig& config) { - auto [num_rows, num_cols] = 
this->shape(); - auto num_requested_devices = num_rows * num_cols; - auto num_available_devices = tt::tt_metal::GetNumAvailableDevices(); - TT_FATAL( - num_requested_devices <= num_available_devices, - "User has requested more devices than available: {} requested, {} available", - num_requested_devices, - num_available_devices); - auto& system_mesh = SystemMesh::instance(); auto physical_device_ids = system_mesh.request_available_devices(config); @@ -371,7 +441,7 @@ void MeshDevice::initialize( for (auto physical_device_id : physical_device_ids) { this->devices.push_back(this->opened_devices.at(physical_device_id)); } - this->primary_view = std::make_shared(*this); + this->view = std::make_unique(*this); system_mesh.register_mesh_device(shared_from_this(), this->devices); } @@ -391,8 +461,11 @@ Device* MeshDevice::get_device(chip_id_t physical_device_id) const { TT_THROW("Physical Device ID: {} not found in assigned devices", physical_device_id); } -std::vector MeshDevice::get_devices() const { return this->primary_view->get_devices(this->type); } +std::vector MeshDevice::get_devices(const std::optional& requested_type) const { + return this->view->get_devices(requested_type.value_or(this->type)); +} +// TODO: Remove this function once we have a proper view interface Device* MeshDevice::get_device(size_t row_idx, size_t col_idx) const { return this->get_device_index(row_idx * num_cols() + col_idx); } @@ -407,7 +480,9 @@ const DeviceIds MeshDevice::get_device_ids() const { size_t MeshDevice::num_devices() const { return this->devices.size(); } -CoreCoord MeshDevice::compute_with_storage_grid_size() const { return this->reference_device()->compute_with_storage_grid_size(); } +CoreCoord MeshDevice::compute_with_storage_grid_size() const { + return this->reference_device()->compute_with_storage_grid_size(); +} CoreCoord MeshDevice::dram_grid_size() const { return this->reference_device()->dram_grid_size(); } @@ -419,6 +494,41 @@ size_t MeshDevice::num_cols() const { 
return this->mesh_device_shape.num_cols; } MeshShape MeshDevice::shape() const { return this->mesh_device_shape; } +void MeshDevice::reshape(const MeshShape& new_shape) { + TT_FATAL( + new_shape.num_rows * new_shape.num_cols == this->num_devices(), + "New shape must have the same number of devices as current shape"); + + std::unordered_map physical_device_id_to_linearized_index; + for (size_t i = 0; i < this->num_devices(); i++) { + physical_device_id_to_linearized_index[this->devices[i]->id()] = i; + } + + // From an MxN mesh, we can always reduce rank to a 1xM*N Line mesh. + // However, going from a Line mesh to an MxN mesh is not always possible. + if (new_shape.num_rows != 1 and new_shape.num_cols != 1) { + auto new_physical_device_ids = + SystemMesh::instance().request_available_devices(MeshDeviceConfig{new_shape}); + + for (size_t i = 0; i < new_physical_device_ids.size(); i++) { + if (physical_device_id_to_linearized_index.find(new_physical_device_ids[i]) == physical_device_id_to_linearized_index.end()) { + TT_THROW( + "User has requested a reshape of the MeshDevice to shape: {}x{}, but it is not possible to form a " + "physically connected mesh of {}x{} grid with the opened devices from the original shape: {}x{}.", + new_shape.num_rows, + new_shape.num_cols, + new_shape.num_rows, + new_shape.num_cols, + this->num_rows(), + this->num_cols()); + } + } + } + + this->mesh_device_shape = new_shape; + this->view = std::make_unique(*this); +} + void MeshDevice::close_devices() { for (const auto& submesh : this->submeshes) { submesh->close_devices(); @@ -430,16 +540,17 @@ void MeshDevice::close_devices() { this->submeshes.clear(); this->parent_mesh.reset(); this->devices.clear(); - this->primary_view.reset(); + this->view.reset(); } std::string MeshDevice::to_string() const { return fmt::format("MeshDevice({}x{} grid, {} devices)", this->num_rows(), this->num_cols(), this->num_devices()); } -std::shared_ptr MeshDevice::get_view() const { return this->primary_view; 
} - -std::shared_ptr MeshDevice::get_view() { return this->primary_view; } +const MeshDeviceView& MeshDevice::get_view() const { + TT_FATAL(view, "MeshDeviceView is not initialized"); + return *view; +} MeshDeviceID MeshDevice::get_mesh_id() const { return this->mesh_id; } @@ -475,7 +586,8 @@ size_t MeshDevice::num_program_cache_entries() const { return total_entries; } -MeshSubDeviceManagerId MeshDevice::create_sub_device_manager(tt::stl::Span sub_devices, DeviceAddr local_l1_size) { +MeshSubDeviceManagerId MeshDevice::create_sub_device_manager( + tt::stl::Span sub_devices, DeviceAddr local_l1_size) { MeshSubDeviceManagerId mesh_sub_device_manager_id(*this); for (uint32_t i = 0; i < this->num_devices(); i++) { auto* device = this->devices[i]; @@ -511,25 +623,21 @@ void MeshDevice::load_sub_device_manager(MeshSubDeviceManagerId mesh_sub_device_ for (uint32_t i = 0; i < this->num_devices(); i++) { auto* device = this->devices[i]; auto sub_device_manager_id = mesh_sub_device_manager_id.sub_device_manager_ids[i]; - device->push_work([device, sub_device_manager_id]() { - device->load_sub_device_manager(sub_device_manager_id); - }); + device->push_work( + [device, sub_device_manager_id]() { device->load_sub_device_manager(sub_device_manager_id); }); } } void MeshDevice::clear_loaded_sub_device_manager() { for (auto* device : this->devices) { - device->push_work([device]() { - device->clear_loaded_sub_device_manager(); - }); + device->push_work([device]() { device->clear_loaded_sub_device_manager(); }); } } void MeshDevice::remove_sub_device_manager(MeshSubDeviceManagerId mesh_sub_device_manager_id) { for (uint32_t i = 0; i < this->num_devices(); i++) { auto* device = this->devices[i]; auto sub_device_manager_id = mesh_sub_device_manager_id.sub_device_manager_ids[i]; - device->push_work([device, sub_device_manager_id]() { - device->remove_sub_device_manager(sub_device_manager_id); - }); + device->push_work( + [device, sub_device_manager_id]() { 
device->remove_sub_device_manager(sub_device_manager_id); }); } } @@ -541,7 +649,8 @@ int MeshDevice::num_dram_channels() const { return this->reference_device()->num_dram_channels() * this->num_devices(); } -allocator::Statistics MeshDevice::get_memory_allocation_statistics(const BufferType &buffer_type, SubDeviceId sub_device_id) const { +allocator::Statistics MeshDevice::get_memory_allocation_statistics( + const BufferType& buffer_type, SubDeviceId sub_device_id) const { // With current implementation, we assume that all devices have the same memory allocation statistics. // This will be made more explicit in the future to have lock-step allocation across devices. // Right now, we just return the statistics of the first device. diff --git a/tt_metal/distributed/mesh_device.hpp b/tt_metal/distributed/mesh_device.hpp index 01a63d2e2865..b0e2a004d5ad 100644 --- a/tt_metal/distributed/mesh_device.hpp +++ b/tt_metal/distributed/mesh_device.hpp @@ -48,6 +48,7 @@ struct MeshDeviceConfig { // SystemMesh creates a virtualization over the physical devices in the system. // It creates a logical 2D-mesh of devices and manages the mapping between logical and physical device coordinates. +// It serves as a query interface between the logical 2D coordinates to physical device IDs. 
class SystemMesh { private: friend class MeshDevice; @@ -70,6 +71,9 @@ class SystemMesh { const MeshShape& get_shape() const; size_t get_num_devices() const; + // Gets the physical device ID for a given logical row and column index + chip_id_t get_physical_device_id(size_t logical_row_idx, size_t logical_col_idx) const; + // Get the physical device IDs mapped to a MeshDevice std::vector get_mapped_physical_device_ids(const MeshDeviceConfig &config) const; }; @@ -79,7 +83,7 @@ class MeshDevice : public std::enable_shared_from_this { MeshDeviceID mesh_id; MeshShape mesh_device_shape; MeshType type; - std::shared_ptr primary_view; + std::unique_ptr view; std::map opened_devices; std::vector devices; std::vector> submeshes; // Parent owns submeshes and responsible fortheir destruction @@ -105,7 +109,10 @@ class MeshDevice : public std::enable_shared_from_this { MeshDevice(MeshDevice&&) = delete; MeshDevice& operator=(MeshDevice&&) = delete; - std::vector get_devices() const; + // A MeshDevice is a collection of devices arranged in a 2D grid. + // The type parameter allows the caller to specify how to linearize the devices in the mesh. + // If type is not provided, the default behavior is to return the devices based on the MeshType of the MeshDevice. + std::vector get_devices(const std::optional& type = std::nullopt) const; Device* get_device_index(size_t logical_device_id) const; Device* get_device(chip_id_t physical_device_id) const; Device* get_device(size_t row_idx, size_t col_idx) const; @@ -117,9 +124,21 @@ class MeshDevice : public std::enable_shared_from_this { size_t num_cols() const; MeshShape shape() const; + // Reshapes the logical mesh and re-maps the physical devices to the new logical coordinates. + // Reshaping Rules: + // 1. The old_shape volume must equal the new_shape volume (i.e. number of devices must remain constant) + // 2. 
Line-to-Line Reshaping (when either dimension is 1): + // - Always possible between 1xN and Nx1 shapes (e.g.: 1x8 <-> 8x1) + // 3. Grid-to-Grid Reshaping: + // - Only possible if the devices can form a connected physical mesh in the new shape + // - Must maintain physical connectivity between adjacent devices + // 4. Line-to-Grid Reshaping: + // - Only possible if the physical devices can form a connected physical mesh in the new shape + // - Example: 1x8 -> 2x4 is possible only if physical mesh permits a 2x4 configuration + void reshape(const MeshShape& new_shape); + void close_devices(); - std::shared_ptr get_view() const; - std::shared_ptr get_view(); + const MeshDeviceView& get_view() const; std::string to_string() const; MeshDeviceID get_mesh_id() const; diff --git a/tt_metal/distributed/mesh_device_view.cpp b/tt_metal/distributed/mesh_device_view.cpp index f9e115f0437f..1c71c877823a 100644 --- a/tt_metal/distributed/mesh_device_view.cpp +++ b/tt_metal/distributed/mesh_device_view.cpp @@ -12,7 +12,7 @@ namespace tt::tt_metal::distributed { static std::vector get_devices_from_coordinates( - MeshDeviceView& mesh, const std::vector& coords) { + const MeshDeviceView& mesh, const std::vector& coords) { std::vector devices; for (const auto& coord : coords) { if (auto device = mesh.get_device(coord.row, coord.col)) { @@ -52,11 +52,7 @@ MeshDeviceView::MeshDeviceView(std::vector devices, const Coordi initialize_from_devices(devices_, std::move(mapper)); } -MeshDeviceView::device_pointer MeshDeviceView::get_device(size_t row, size_t col) { - return const_cast(std::as_const(*this).get_device(row, col)); -} - -MeshDeviceView::const_device_pointer MeshDeviceView::get_device(size_t row, size_t col) const { +MeshDeviceView::device_pointer MeshDeviceView::get_device(size_t row, size_t col) const { for (const auto& device : devices_) { auto it = device_coordinates_.find(device->id()); if (it != device_coordinates_.end() && it->second.row == row && it->second.col == col) { @@
-66,7 +62,7 @@ MeshDeviceView::const_device_pointer MeshDeviceView::get_device(size_t row, size return nullptr; } -MeshDeviceView::DeviceView MeshDeviceView::get_devices(const Coordinate& start, const Coordinate& end) { +MeshDeviceView::DeviceView MeshDeviceView::get_devices(const Coordinate& start, const Coordinate& end) const { if (start.row > end.row || start.col > end.col) { log_fatal("Invalid coordinates: start {} must be less than or equal to end {}", start, end); } @@ -82,8 +78,8 @@ MeshDeviceView::DeviceView MeshDeviceView::get_devices(const Coordinate& start, return devices_in_region; } -MeshDeviceView::DeviceView MeshDeviceView::get_devices(const MeshShape& shape) { - return get_devices({0, 0}, {shape.num_rows - 1, shape.num_cols - 1}); +MeshDeviceView::DeviceView MeshDeviceView::get_devices(const MeshShape& submesh_shape) const { + return get_devices({0, 0}, {submesh_shape.num_rows - 1, submesh_shape.num_cols - 1}); } std::vector MeshDeviceView::get_devices_on_row(size_t row) const { @@ -214,7 +210,7 @@ std::vector MeshDeviceView::get_line_coordinates( } std::vector MeshDeviceView::get_ring_coordinates( - const MeshShape& ring_shape, const Coordinate& offset, size_t num_rows, size_t num_cols) { + const MeshShape& ring_shape, const Coordinate& offset, size_t num_rows, size_t num_cols) const { auto [start_row, start_col] = offset; auto [ring_rows, ring_cols] = ring_shape; auto end_row = start_row + ring_rows - 1; @@ -258,18 +254,18 @@ void MeshDeviceView::validate_coordinates() const { } } -std::vector MeshDeviceView::get_line_devices() { +std::vector MeshDeviceView::get_line_devices() const { auto boundary_coords = get_line_coordinates(this->num_rows() * this->num_cols(), this->top_left_, this->num_rows(), this->num_cols()); return get_devices_from_coordinates(*this, boundary_coords); } -std::vector MeshDeviceView::get_ring_devices() { +std::vector MeshDeviceView::get_ring_devices() const { auto boundary_coords = get_ring_coordinates(shape(), 
this->top_left_, this->num_rows(), this->num_cols()); return get_devices_from_coordinates(*this, boundary_coords); } -MeshDeviceView::DeviceView MeshDeviceView::get_devices(MeshType type) { +MeshDeviceView::DeviceView MeshDeviceView::get_devices(MeshType type) const { switch (type) { case MeshType::RowMajor: return this->devices_; case MeshType::Ring: return this->get_ring_devices(); diff --git a/tt_metal/distributed/mesh_device_view.hpp b/tt_metal/distributed/mesh_device_view.hpp index 31af7aba3764..0524814b7971 100644 --- a/tt_metal/distributed/mesh_device_view.hpp +++ b/tt_metal/distributed/mesh_device_view.hpp @@ -75,14 +75,13 @@ class MeshDeviceView { MeshDeviceView(const MeshDevice& mesh, Coordinate top_left, Coordinate bottom_right); MeshDeviceView(std::vector devices, const CoordinateMapper& mapper); - [[nodiscard]] device_pointer get_device(size_t row, size_t col); - [[nodiscard]] const_device_pointer get_device(size_t row, size_t col) const; + [[nodiscard]] device_pointer get_device(size_t row, size_t col) const; // Get devices spanning the rectangular region defined by the top-left and bottom-right coordinates // devices are returned in row-major order with start/end coordinates inclusive - [[nodiscard]] DeviceView get_devices(const Coordinate& start, const Coordinate& end); - [[nodiscard]] DeviceView get_devices(const MeshShape& shape); - [[nodiscard]] DeviceView get_devices(MeshType type = MeshType::RowMajor); + [[nodiscard]] DeviceView get_devices(const Coordinate& start, const Coordinate& end) const; + [[nodiscard]] DeviceView get_devices(const MeshShape& submesh_shape) const; + [[nodiscard]] DeviceView get_devices(MeshType type = MeshType::RowMajor) const; [[nodiscard]] DeviceView get_devices_on_row(size_t row) const; [[nodiscard]] DeviceView get_devices_on_column(size_t col) const; @@ -114,9 +113,9 @@ class MeshDeviceView { [[nodiscard]] static std::vector get_line_coordinates( size_t length, const Coordinate& offset, size_t num_rows, size_t 
num_cols); [[nodiscard]] std::vector get_ring_coordinates( - const MeshShape& ring_shape, const Coordinate& offset, size_t num_rows, size_t num_cols); - [[nodiscard]] std::vector get_ring_devices(); - [[nodiscard]] std::vector get_line_devices(); + const MeshShape& ring_shape, const Coordinate& offset, size_t num_rows, size_t num_cols) const; + [[nodiscard]] std::vector get_ring_devices() const; + [[nodiscard]] std::vector get_line_devices() const; private: std::vector devices_; diff --git a/ttnn/cpp/ttnn/distributed/api.cpp b/ttnn/cpp/ttnn/distributed/api.cpp index fee7fa1566c5..986f127f2b67 100644 --- a/ttnn/cpp/ttnn/distributed/api.cpp +++ b/ttnn/cpp/ttnn/distributed/api.cpp @@ -156,13 +156,13 @@ std::vector get_mapped_devices(const Tensor& tensor, MeshDevice& mesh_d } return workers; }; - if (mesh_device.get_view() != nullptr and std::holds_alternative(tensor.get_storage())) { + if (std::holds_alternative(tensor.get_storage())) { const auto& host_storage = std::get(tensor.get_storage()); return std::visit( tt::stl::overloaded{ [&](const ShardTensor2D& s) { - return mesh_device.get_view()->get_devices(MeshShape{s.shard_mesh.y, s.shard_mesh.x}); + return mesh_device.get_view().get_devices(MeshShape{s.shard_mesh.y, s.shard_mesh.x}); }, [&](const auto&) { return get_workers_for_tensor(); }}, host_storage.strategy); diff --git a/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp b/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp index 1e1ac97f507d..d41713d53f1d 100644 --- a/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp +++ b/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp @@ -100,7 +100,12 @@ void py_module(py::module& module) { "get_device", py::overload_cast(&MeshDevice::get_device, py::const_), py::return_value_policy::reference) - .def("get_devices", &MeshDevice::get_devices, py::return_value_policy::reference, R"doc( + .def( + "get_devices", + &MeshDevice::get_devices, + py::return_value_policy::reference, + py::arg("type") = py::none(), + R"doc( Get the 
devices in the device mesh. Returns: diff --git a/ttnn/cpp/ttnn/operations/ccl/all_gather/device/all_gather_op.cpp b/ttnn/cpp/ttnn/operations/ccl/all_gather/device/all_gather_op.cpp index e32b1232ae85..119e2d840198 100644 --- a/ttnn/cpp/ttnn/operations/ccl/all_gather/device/all_gather_op.cpp +++ b/ttnn/cpp/ttnn/operations/ccl/all_gather/device/all_gather_op.cpp @@ -300,7 +300,7 @@ Tensor all_gather( topology == ttnn::ccl::Topology::Linear, "This all_gather API with cluster_axis is currently supported only for the Linear topology"); const auto mesh_view = mesh_device.get_view(); - std::size_t num_devices = (cluster_axis == 0) ? mesh_view->num_rows() : mesh_view->num_cols(); + std::size_t num_devices = (cluster_axis == 0) ? mesh_view.num_rows() : mesh_view.num_cols(); int32_t rank = input_tensor.get_logical_shape().rank(); @@ -330,7 +330,7 @@ Tensor all_gather( const std::vector>& optional_output_tensors) mutable -> std::vector { const auto& input_device_tensor = input_tensors.at(0); - const auto coordinate = mesh_view->find_device(input_device_tensor.device()->id()); + const auto coordinate = mesh_view.find_device(input_device_tensor.device()->id()); const auto view_index = (cluster_axis == 0) ? coordinate.col : coordinate.row; const auto device_index = (cluster_axis == 0) ? 
coordinate.row : coordinate.col; @@ -341,7 +341,7 @@ Tensor all_gather( } else { new_coord.col = line_index % num_devices; } - return mesh_view->find_device_id(new_coord); + return mesh_view.find_device_id(new_coord); }; bool is_last_chip_in_clockwise_direction = device_index == (num_devices - 1); diff --git a/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/reduce_scatter_op.cpp b/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/reduce_scatter_op.cpp index 4f633e5d8bc7..a89e0407c3e5 100644 --- a/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/reduce_scatter_op.cpp +++ b/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/reduce_scatter_op.cpp @@ -196,7 +196,7 @@ Tensor reduce_scatter( topology == ttnn::ccl::Topology::Linear, "This all_gather API with cluster_axis is currently supported only for the Linear topology"); const auto mesh_view = mesh_device.get_view(); - std::size_t num_devices = (cluster_axis == 0) ? mesh_view->num_rows() : mesh_view->num_cols(); + std::size_t num_devices = (cluster_axis == 0) ? mesh_view.num_rows() : mesh_view.num_cols(); int16_t rank = input_tensor.get_logical_shape().rank(); @@ -227,7 +227,7 @@ Tensor reduce_scatter( const std::vector>& optional_output_tensors) mutable -> std::vector { const auto& input_device_tensor = input_tensors.at(0); - const auto coordinate = mesh_view->find_device(input_device_tensor.device()->id()); + const auto coordinate = mesh_view.find_device(input_device_tensor.device()->id()); const auto view_index = (cluster_axis == 0) ? coordinate.col : coordinate.row; const auto device_index = (cluster_axis == 0) ? coordinate.row : coordinate.col; @@ -238,7 +238,7 @@ Tensor reduce_scatter( } else { new_coord.col = line_index % num_devices; } - return mesh_view->find_device_id(new_coord); + return mesh_view.find_device_id(new_coord); }; bool is_last_chip_in_clockwise_direction = device_index == (num_devices - 1);