diff --git a/tests/ttnn/distributed/CMakeLists.txt b/tests/ttnn/distributed/CMakeLists.txt index 5294a6d73a56..f41d726988af 100644 --- a/tests/ttnn/distributed/CMakeLists.txt +++ b/tests/ttnn/distributed/CMakeLists.txt @@ -1,4 +1,8 @@ -add_executable(test_distributed test_distributed.cpp) +add_executable( + test_distributed + test_distributed.cpp + test_distributed_atexit.cpp +) # Set up properties for the target setup_ttnn_test_target(test_distributed) diff --git a/tests/ttnn/distributed/test_distributed_atexit.cpp b/tests/ttnn/distributed/test_distributed_atexit.cpp new file mode 100644 index 000000000000..d3d1ea75f291 --- /dev/null +++ b/tests/ttnn/distributed/test_distributed_atexit.cpp @@ -0,0 +1,32 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#include + +#include +#include +#include + +namespace ttnn::distributed::test { + +class DistributedTest : public ::testing::Test { +protected: + void SetUp() override {} + void TearDown() override {} +}; + +// Test that the mesh device is properly cleaned up when the program exits +static std::shared_ptr mesh; + +TEST_F(DistributedTest, TestSystemMeshTearDownWithoutClose_StaticMesh) { + auto& sys = tt::tt_metal::distributed::SystemMesh::instance(); + mesh = ttnn::distributed::open_mesh_device( + {2, 4}, DEFAULT_L1_SMALL_SIZE, DEFAULT_TRACE_REGION_SIZE, 1, tt::tt_metal::DispatchCoreType::WORKER); + + auto [rows, cols] = sys.get_shape(); + EXPECT_GT(rows, 0); + EXPECT_GT(cols, 0); +} + +} // namespace ttnn::distributed::test diff --git a/tt_metal/distributed/mesh_device.cpp b/tt_metal/distributed/mesh_device.cpp index db820323ae51..eacca0ffd8d4 100644 --- a/tt_metal/distributed/mesh_device.cpp +++ b/tt_metal/distributed/mesh_device.cpp @@ -48,19 +48,61 @@ static std::unordered_map load_translatio if (mapping.size() != 2 || mapping[0].size() != 2 || mapping[1].size() != 5) { throw std::runtime_error("Invalid coordinate format in JSON file: " + filename); } - result.emplace(LogicalCoordinate{mapping[0][0], mapping[0][1]}, PhysicalCoordinate{ - mapping[1][0], // cluster_id - mapping[1][2], // x - mapping[1][1], // y - mapping[1][3], // rack - mapping[1][4] // shelf - }); + result.emplace( + LogicalCoordinate{mapping[0][0], mapping[0][1]}, + PhysicalCoordinate{ + mapping[1][0], // cluster_id + mapping[1][2], // x + mapping[1][1], // y + mapping[1][3], // rack + mapping[1][4] // shelf + }); } return result; } -MeshShape SystemMesh::get_system_mesh_shape(size_t system_num_devices) { +class SystemMesh::Impl { +private: + using LogicalCoordinate = Coordinate; + using PhysicalCoordinate = eth_coord_t; + + std::unordered_map> opened_devices; + std::unordered_map> assigned_devices; + std::unordered_map> assigned_mesh_device_devices; + + MeshShape logical_mesh_shape; + std::unordered_map logical_to_physical_coordinates; + std::unordered_map physical_coordinate_to_device_id; + std::unordered_map physical_device_id_to_coordinate; + +public: + Impl() = default; + ~Impl() = default; + + bool is_system_mesh_initialized() const; + void initialize(); + const MeshShape& get_shape() const; + size_t get_num_devices() const; + std::vector map_mesh_device( + const std::shared_ptr& mesh_device, + size_t num_command_queues, + size_t l1_small_size, + size_t trace_region_size, + const DispatchCoreConfig& dispatch_core_config, + const MeshDeviceConfig& config); + std::vector get_mapped_physical_device_ids(const MeshDeviceConfig& config) const; + void remove_expired_mesh_devices(); + Device* get_device(const chip_id_t physical_device_id) const; + void register_mesh_device(const std::shared_ptr& mesh_device, const std::vector& devices); + + static MeshShape get_system_mesh_shape(size_t system_num_devices); + static std::unordered_map get_system_mesh_translation_map( + size_t system_num_devices); +}; + +// Implementation of private static methods +MeshShape SystemMesh::Impl::get_system_mesh_shape(size_t system_num_devices) { const std::unordered_map system_mesh_to_shape = { {1, MeshShape{1, 1}}, // single-device {2, MeshShape{1, 2}}, // N300 @@ -68,13 +110,15 @@ MeshShape SystemMesh::get_system_mesh_shape(size_t system_num_devices) { {32, MeshShape{8, 4}}, // TG {64, MeshShape{8, 8}}, // TGG }; - TT_FATAL(system_mesh_to_shape.contains(system_num_devices), "Unsupported number of devices: {}", system_num_devices); + TT_FATAL( + system_mesh_to_shape.contains(system_num_devices), "Unsupported number of devices: {}", system_num_devices); auto shape = system_mesh_to_shape.at(system_num_devices); log_debug(LogMetal, "Logical SystemMesh Shape: {}x{}", shape.first, shape.second); return shape; } -std::unordered_map SystemMesh::get_system_mesh_translation_map(size_t system_num_devices) { +std::unordered_map SystemMesh::Impl::get_system_mesh_translation_map( + size_t system_num_devices) { const std::unordered_map system_mesh_translation_map = { {1, "device.json"}, {2, "N300.json"}, @@ -82,28 +126,25 @@ std::unordered_map SystemMesh::get_system {32, "TG.json"}, {64, "TGG.json"}, }; - TT_FATAL(system_mesh_translation_map.contains(system_num_devices), "Unsupported number of devices: {}", system_num_devices); + TT_FATAL( + system_mesh_translation_map.contains(system_num_devices), + "Unsupported number of devices: {}", + system_num_devices); auto translation_config_file = get_config_path(system_mesh_translation_map.at(system_num_devices)); return load_translation_map(translation_config_file, "logical_to_physical_coordinates"); } -bool SystemMesh::is_system_mesh_initialized() const { - return this->physical_coordinate_to_device_id.size() > 0; -} +// Implementation of public methods +bool SystemMesh::Impl::is_system_mesh_initialized() const { return this->physical_coordinate_to_device_id.size() > 0; } -SystemMesh& SystemMesh::instance() { - static SystemMesh instance; - if (!instance.is_system_mesh_initialized()) { - instance.initialize(); - } - return instance; -} -void SystemMesh::initialize() { +void SystemMesh::Impl::initialize() { this->physical_device_id_to_coordinate = tt::Cluster::instance().get_user_chip_ethernet_coordinates(); if (this->physical_device_id_to_coordinate.empty()) { // Only WH has ethernet coordinates. Fabric will assign chip ids for BH auto arch = tt::Cluster::instance().arch(); - TT_FATAL(arch == ARCH::GRAYSKULL or arch == ARCH::BLACKHOLE, "Expected Wormhole chips to have ethernet coordinates assigned by cluster descriptor"); + TT_FATAL( + arch == ARCH::GRAYSKULL or arch == ARCH::BLACKHOLE, + "Expected Wormhole chips to have ethernet coordinates assigned by cluster descriptor"); const int num_detected_devices = tt::Cluster::instance().number_of_devices(); for (auto chip_id = 0; chip_id < num_detected_devices; chip_id++) { PhysicalCoordinate coord{0, chip_id, 0, 0, 0}; @@ -116,19 +157,18 @@ void SystemMesh::initialize() { } } - // Initialize the system mesh shape and translation map auto num_devices = physical_coordinate_to_device_id.size(); - this->logical_mesh_shape = SystemMesh::get_system_mesh_shape(num_devices); - this->logical_to_physical_coordinates = SystemMesh::get_system_mesh_translation_map(num_devices); + this->logical_mesh_shape = get_system_mesh_shape(num_devices); + this->logical_to_physical_coordinates = get_system_mesh_translation_map(num_devices); } -const MeshShape& SystemMesh::get_shape() const { return this->logical_mesh_shape; } -size_t SystemMesh::get_num_devices() const { +const MeshShape& SystemMesh::Impl::get_shape() const { return this->logical_mesh_shape; } +size_t SystemMesh::Impl::get_num_devices() const { auto [num_rows, num_cols] = this->get_shape(); return num_rows * num_cols; } -std::vector SystemMesh::get_mapped_physical_device_ids(const MeshDeviceConfig& config) const { +std::vector SystemMesh::Impl::get_mapped_physical_device_ids(const MeshDeviceConfig& config) const { std::vector physical_device_ids; auto [system_mesh_rows, system_mesh_cols] = this->get_shape(); auto [requested_rows, requested_cols] = config.mesh_shape; @@ -136,32 +176,44 @@ std::vector SystemMesh::get_mapped_physical_device_ids(const MeshDevi if (requested_rows == 1) { TT_FATAL(row_offset == 0 and col_offset == 0, "Row and column offsets unsupported for single row mesh"); - auto line_coords = MeshDeviceView::get_line_coordinates(requested_cols, Coordinate{row_offset, col_offset}, system_mesh_rows, system_mesh_cols); + auto line_coords = MeshDeviceView::get_line_coordinates( + requested_cols, Coordinate{row_offset, col_offset}, system_mesh_rows, system_mesh_cols); for (const auto& logical_coordinate : line_coords) { auto physical_coordinate = this->logical_to_physical_coordinates.at(logical_coordinate); auto physical_device_id = this->physical_coordinate_to_device_id.at(physical_coordinate); physical_device_ids.push_back(physical_device_id); - log_debug(LogMetal, "Logical coordinate: {}, Physical coordinate: {}, Physical device ID: {}", - logical_coordinate, physical_coordinate, physical_device_id); + log_debug( + LogMetal, + "Logical coordinate: {}, Physical coordinate: {}, Physical device ID: {}", + logical_coordinate, + physical_coordinate, + physical_device_id); } } else { for (int row = 0; row < requested_rows; row++) { for (int col = 0; col < requested_cols; col++) { auto logical_device_id = (row + row_offset) * system_mesh_cols + (col + col_offset); - auto logical_coordinate = Coordinate{logical_device_id / system_mesh_cols, logical_device_id % system_mesh_cols}; + auto logical_coordinate = + Coordinate{logical_device_id / system_mesh_cols, logical_device_id % system_mesh_cols}; auto physical_coordinate = this->logical_to_physical_coordinates.at(logical_coordinate); auto physical_device_id = this->physical_coordinate_to_device_id.at(physical_coordinate); physical_device_ids.push_back(physical_device_id); - log_debug(LogMetal, "Logical device ID: {}, Logical coordinate: {}, Physical coordinate: {}, Physical device ID: {}", - logical_device_id, logical_coordinate, physical_coordinate, physical_device_id); + log_debug( + LogMetal, + "Logical device ID: {}, Logical coordinate: {}, Physical coordinate: {}, Physical device ID: {}", + logical_device_id, + logical_coordinate, + physical_coordinate, + physical_device_id); } } } return physical_device_ids; } -void SystemMesh::register_mesh_device(const std::shared_ptr &mesh_device, const std::vector& devices) { +void SystemMesh::Impl::register_mesh_device( + const std::shared_ptr& mesh_device, const std::vector& devices) { std::vector physical_device_ids; for (auto device : devices) { physical_device_ids.push_back(device->id()); @@ -170,26 +222,35 @@ void SystemMesh::register_mesh_device(const std::shared_ptr &mesh_de this->assigned_devices.insert({mesh_device->get_mesh_id(), physical_device_ids}); } -std::vector SystemMesh::map_mesh_device( +std::vector SystemMesh::Impl::map_mesh_device( const std::shared_ptr& mesh_device, size_t num_command_queues, size_t l1_small_size, size_t trace_region_size, - const DispatchCoreConfig &dispatch_core_config, + const DispatchCoreConfig& dispatch_core_config, const MeshDeviceConfig& config) { + this->remove_expired_mesh_devices(); auto [requested_num_rows, requested_num_cols] = mesh_device->shape(); auto [max_num_rows, max_num_cols] = this->logical_mesh_shape; auto [row_offset, col_offset] = config.offset; - log_debug(LogMetal, "Mapping MeshDevice ({}x{}) with offset: {}, {}", requested_num_rows, requested_num_cols, row_offset, col_offset); + log_debug( + LogMetal, + "Mapping MeshDevice ({}x{}) with offset: {}, {}", + requested_num_rows, + requested_num_cols, + row_offset, + col_offset); TT_FATAL(requested_num_rows <= max_num_rows, "Requested too many rows: {} > {}", requested_num_rows, max_num_rows); - TT_FATAL(requested_num_rows*requested_num_cols <= max_num_rows*max_num_cols, "Requested submesh is too big: {}x{}", requested_num_rows, requested_num_cols); - + TT_FATAL( + requested_num_rows * requested_num_cols <= max_num_rows * max_num_cols, + "Requested submesh is too big: {}x{}", + requested_num_rows, + requested_num_cols); - auto physical_device_ids = config.physical_device_ids.empty() ? - this->get_mapped_physical_device_ids(config) : - config.physical_device_ids; + auto physical_device_ids = + config.physical_device_ids.empty() ? this->get_mapped_physical_device_ids(config) : config.physical_device_ids; this->opened_devices[mesh_device->get_mesh_id()] = tt::tt_metal::detail::CreateDevices( physical_device_ids, num_command_queues, l1_small_size, trace_region_size, dispatch_core_config); @@ -198,34 +259,66 @@ std::vector SystemMesh::map_mesh_device( for (auto physical_device_id : physical_device_ids) { auto mapped_device = this->opened_devices[mesh_device->get_mesh_id()].at(physical_device_id); mapped_devices.push_back(mapped_device); - this->assigned_physical_id_to_device.insert({physical_device_id, mapped_device}); } - this->register_mesh_device(mesh_device, mapped_devices); // here + this->register_mesh_device(mesh_device, mapped_devices); // here return mapped_devices; } -void SystemMesh::unmap_mesh_device(const MeshDevice* mesh_device) { - auto mesh_id = mesh_device->get_mesh_id(); - this->assigned_mesh_device_devices.erase(mesh_id); +void SystemMesh::Impl::remove_expired_mesh_devices() { + std::vector stale_ids; + for (const auto& [mesh_id, weak_mesh_device] : assigned_mesh_device_devices) { + if (weak_mesh_device.expired()) { + stale_ids.push_back(mesh_id); + } + } + for (auto mesh_id : stale_ids) { + this->assigned_mesh_device_devices.erase(mesh_id); + + if (assigned_devices.count(mesh_id)) { + assigned_devices.erase(mesh_id); + } - // Close the devices - if (mesh_device->is_parent_mesh()) { - for (auto physical_id : this->assigned_devices.at(mesh_id)) { - this->assigned_physical_id_to_device.erase(physical_id); + if (opened_devices.count(mesh_id)) { + tt::tt_metal::detail::CloseDevices(this->opened_devices.at(mesh_id)); + this->opened_devices.erase(mesh_id); } - tt::tt_metal::detail::CloseDevices(this->opened_devices.at(mesh_id)); - this->opened_devices.erase(mesh_id); } - this->assigned_devices.erase(mesh_id); } -Device* SystemMesh::get_device(const chip_id_t physical_device_id) const { - auto it = this->assigned_physical_id_to_device.find(physical_device_id); - if (it == this->assigned_physical_id_to_device.end()) { - TT_THROW("Physical Device ID: {} not found in assigned devices", physical_device_id); +SystemMesh::SystemMesh() : pimpl(std::make_unique()) {} +SystemMesh::~SystemMesh() = default; + +SystemMesh& SystemMesh::instance() { + static SystemMesh instance; + if (!instance.pimpl->is_system_mesh_initialized()) { + instance.pimpl->initialize(); } - return it->second; + return instance; +} + +const MeshShape& SystemMesh::get_shape() const { return pimpl->get_shape(); } + +size_t SystemMesh::get_num_devices() const { return pimpl->get_num_devices(); } + +void SystemMesh::register_mesh_device( + const std::shared_ptr& mesh_device, const std::vector& devices) { + pimpl->register_mesh_device(mesh_device, devices); +} + +std::vector SystemMesh::map_mesh_device( + const std::shared_ptr& mesh_device, + size_t num_command_queues, + size_t l1_small_size, + size_t trace_region_size, + const DispatchCoreConfig& dispatch_core_config, + const MeshDeviceConfig& config) { + return pimpl->map_mesh_device( + mesh_device, num_command_queues, l1_small_size, trace_region_size, dispatch_core_config, config); +} + +std::vector SystemMesh::get_mapped_physical_device_ids(const MeshDeviceConfig& config) const { + return pimpl->get_mapped_physical_device_ids(config); } static MeshDeviceID generate_unique_mesh_id() { @@ -233,16 +326,18 @@ static MeshDeviceID generate_unique_mesh_id() { return next_id++; } -MeshDevice::MeshDevice(const MeshShape& mesh_device_shape, MeshType type, std::weak_ptr parent_mesh) - : mesh_device_shape(mesh_device_shape), type(type), mesh_id(generate_unique_mesh_id()), parent_mesh(std::move(parent_mesh)) {} +MeshDevice::MeshDevice(const MeshShape& mesh_device_shape, MeshType type, std::weak_ptr parent_mesh) : + mesh_device_shape(mesh_device_shape), + type(type), + mesh_id(generate_unique_mesh_id()), + parent_mesh(std::move(parent_mesh)) {} std::shared_ptr MeshDevice::create( const MeshDeviceConfig& config, size_t l1_small_size, size_t trace_region_size, size_t num_command_queues, - const DispatchCoreConfig &dispatch_core_config) -{ + const DispatchCoreConfig& dispatch_core_config) { auto mesh_device = std::make_shared(config.mesh_shape, config.mesh_type); mesh_device->initialize(l1_small_size, trace_region_size, num_command_queues, dispatch_core_config, config); @@ -250,12 +345,12 @@ std::shared_ptr MeshDevice::create( } std::shared_ptr MeshDevice::create_submesh( - const MeshShape &submesh_shape, - const MeshOffset &offset, - MeshType type) -{ + const MeshShape& submesh_shape, const MeshOffset& offset, MeshType type) { if (submesh_shape.first <= 0 || submesh_shape.second <= 0) { - TT_THROW("Invalid submesh shape: ({}, {}). Both dimensions must be positive.", submesh_shape.first, submesh_shape.second); + TT_THROW( + "Invalid submesh shape: ({}, {}). Both dimensions must be positive.", + submesh_shape.first, + submesh_shape.second); } if (offset.first < 0 || offset.second < 0) { @@ -264,10 +359,14 @@ std::shared_ptr MeshDevice::create_submesh( if (offset.first + submesh_shape.first > this->mesh_device_shape.first || offset.second + submesh_shape.second > this->mesh_device_shape.second) { - TT_THROW("Submesh ({}x{}) with offset ({}, {}) does not fit within parent mesh ({}x{}).", - submesh_shape.first, submesh_shape.second, - offset.first, offset.second, - this->mesh_device_shape.first, this->mesh_device_shape.second); + TT_THROW( + "Submesh ({}x{}) with offset ({}, {}) does not fit within parent mesh ({}x{}).", + submesh_shape.first, + submesh_shape.second, + offset.first, + offset.second, + this->mesh_device_shape.first, + this->mesh_device_shape.second); } auto submesh = std::make_shared(submesh_shape, type, shared_from_this()); @@ -277,16 +376,20 @@ std::shared_ptr MeshDevice::create_submesh( submesh->devices = submesh->primary_view->get_devices(); SystemMesh::instance().register_mesh_device(submesh, submesh->devices); this->submeshes.push_back(submesh); - log_trace(LogMetal, "Instantiating submesh {}: {}x{} with offset: {} {}", submesh->get_mesh_id(), submesh_shape.first, submesh_shape.second, offset.first, offset.second); + log_trace( + LogMetal, + "Instantiating submesh {}: {}x{} with offset: {} {}", + submesh->get_mesh_id(), + submesh_shape.first, + submesh_shape.second, + offset.first, + offset.second); log_trace(LogMetal, "Submesh {} instantiated with {} devices", submesh->get_mesh_id(), submesh->devices); return submesh; } -std::vector> MeshDevice::create_submeshes( - const MeshShape &submesh_shape, - MeshType type) -{ +std::vector> MeshDevice::create_submeshes(const MeshShape& submesh_shape, MeshType type) { std::vector> submeshes; for (int row = 0; row < this->num_rows(); row += submesh_shape.first) { for (int col = 0; col < this->num_cols(); col += submesh_shape.second) { @@ -301,16 +404,16 @@ void MeshDevice::initialize( size_t l1_small_size, size_t trace_region_size, size_t num_command_queues, - const DispatchCoreConfig &dispatch_core_config, - const MeshDeviceConfig& config) -{ + const DispatchCoreConfig& dispatch_core_config, + const MeshDeviceConfig& config) { auto [num_rows, num_cols] = this->shape(); auto num_requested_devices = num_rows * num_cols; auto num_available_devices = tt::tt_metal::GetNumAvailableDevices(); TT_FATAL( num_requested_devices <= num_available_devices, "User has requested more devices than available: {} requested, {} available", - num_requested_devices, num_available_devices); + num_requested_devices, + num_available_devices); auto& instance = SystemMesh::instance(); this->devices = instance.map_mesh_device( @@ -318,9 +421,7 @@ void MeshDevice::initialize( this->primary_view = std::make_shared(*this); } -MeshDevice::~MeshDevice() { - close_devices(); -} +MeshDevice::~MeshDevice() { close_devices(); } Device* MeshDevice::get_device_index(size_t logical_device_id) const { TT_FATAL(logical_device_id >= 0 and logical_device_id < num_devices(), "Invalid device index"); @@ -328,7 +429,12 @@ Device* MeshDevice::get_device_index(size_t logical_device_id) const { } Device* MeshDevice::get_device(chip_id_t physical_device_id) const { - return SystemMesh::instance().get_device(physical_device_id); + for (auto device : this->devices) { + if (device->id() == physical_device_id) { + return device; + } + } + TT_THROW("Physical Device ID: {} not found in assigned devices", physical_device_id); } std::vector MeshDevice::get_devices() const { return this->primary_view->get_devices(this->type); } @@ -347,7 +453,9 @@ const DeviceIds MeshDevice::get_device_ids() const { size_t MeshDevice::num_devices() const { return this->devices.size(); } -CoreCoord MeshDevice::compute_with_storage_grid_size() const { return get_device_index(0)->compute_with_storage_grid_size(); } +CoreCoord MeshDevice::compute_with_storage_grid_size() const { + return get_device_index(0)->compute_with_storage_grid_size(); +} CoreCoord MeshDevice::dram_grid_size() const { return get_device_index(0)->dram_grid_size(); } @@ -363,9 +471,7 @@ void MeshDevice::close_devices() { for (const auto& submesh : this->submeshes) { submesh->close_devices(); } - if (not this->devices.empty()) { - SystemMesh::instance().unmap_mesh_device(this); - } + this->submeshes.clear(); this->parent_mesh.reset(); this->devices.clear(); this->primary_view.reset(); @@ -383,34 +489,6 @@ MeshDeviceID MeshDevice::get_mesh_id() const { return this->mesh_id; } bool MeshDevice::is_parent_mesh() const { return this->parent_mesh.expired(); } -std::shared_ptr SystemMesh::get_mesh_device(const std::vector& physical_device_ids) { - log_trace(LogMetal, "Getting mesh device for {} physical devices: {}", physical_device_ids.size(), physical_device_ids); - std::unordered_set input_set(physical_device_ids.begin(), physical_device_ids.end()); - - for (const auto& [mesh_id, weak_mesh_device] : this->assigned_mesh_device_devices) { - if (auto mesh_device = weak_mesh_device.lock()) { - const auto& assigned_devices = this->assigned_devices.at(mesh_id); - std::unordered_set assigned_set(assigned_devices.begin(), assigned_devices.end()); - log_trace(LogMetal, "Assigned devices: {}", assigned_devices); - - if (input_set == assigned_set) { - return mesh_device; - } - } - } - TT_THROW("No mesh device found for the provided devices"); -} - -std::shared_ptr MeshDevice::fetch_mesh_device(const std::vector& devices) { - TT_FATAL(devices.size() > 0, "No devices provided"); - auto& instance = SystemMesh::instance(); - std::vector physical_device_ids; - for (auto device : devices) { - physical_device_ids.push_back(device->id()); - } - return instance.get_mesh_device(physical_device_ids); -} - std::vector> MeshDevice::get_submeshes() const { return this->submeshes; } std::ostream& operator<<(std::ostream& os, const MeshDevice& mesh_device) { return os << mesh_device.to_string(); } diff --git a/tt_metal/distributed/mesh_device.hpp b/tt_metal/distributed/mesh_device.hpp index 2c0ee62f872d..83e170e32bd0 100644 --- a/tt_metal/distributed/mesh_device.hpp +++ b/tt_metal/distributed/mesh_device.hpp @@ -44,63 +44,35 @@ struct MeshDeviceConfig { // It is responsible for the assignment of devices in a MeshDevice to physical devices, and the creation and deletion of // device resources. class SystemMesh { -private: - using LogicalCoordinate = Coordinate; - using PhysicalCoordinate = eth_coord_t; - - // Keep track of the devices that were opened so we can close them later. We shouldn't - // to keep track of this but DevicePool seems to open all devices associated with an MMIO device id - std::unordered_map> opened_devices; - std::unordered_map> assigned_devices; - std::unordered_map> assigned_mesh_device_devices; - std::unordered_map assigned_physical_id_to_device; - - // Logical mesh shape and coordinates - MeshShape logical_mesh_shape; - std::unordered_map logical_to_physical_coordinates; - - // Handling of physical coordinates - std::unordered_map physical_coordinate_to_device_id; - std::unordered_map physical_device_id_to_coordinate; - - SystemMesh() = default; - SystemMesh(const SystemMesh&) = delete; - SystemMesh& operator=(const SystemMesh&) = delete; - SystemMesh(SystemMesh&&) = delete; - SystemMesh& operator=(SystemMesh&&) = delete; - - static MeshShape get_system_mesh_shape(size_t system_num_devices); - static std::unordered_map get_system_mesh_translation_map( - size_t system_num_devices); - - bool is_system_mesh_initialized() const; + private: + friend class MeshDevice; + class Impl; // Forward declaration only + std::unique_ptr pimpl; + SystemMesh(); + ~SystemMesh(); + + void register_mesh_device(const std::shared_ptr &mesh_device, const std::vector& devices); + std::vector map_mesh_device( + const std::shared_ptr& mesh_device, + size_t num_command_queues, + size_t l1_small_size, + size_t trace_region_size, + const DispatchCoreConfig& dispatch_core_config, + const MeshDeviceConfig &config); public: static SystemMesh& instance(); - - void initialize(); + SystemMesh(const SystemMesh &) = delete; + SystemMesh &operator=(const SystemMesh &) = delete; + SystemMesh(SystemMesh &&) = delete; + SystemMesh &operator=(SystemMesh &&) = delete; // Return the shape of the logical mesh const MeshShape& get_shape() const; size_t get_num_devices() const; // Get the physical device IDs mapped to a MeshDevice - std::vector get_mapped_physical_device_ids(const MeshDeviceConfig& config) const; - void register_mesh_device(const std::shared_ptr& mesh_device, const std::vector& devices); - - // Map MeshDevice to physical devices - std::vector map_mesh_device( - const std::shared_ptr& mesh_device, - size_t num_command_queues, - size_t l1_small_size, - size_t trace_region_size, - const DispatchCoreConfig& dispatch_core_config, - const MeshDeviceConfig& config); - - // Unmap MeshDevice, releasing the associated physical devices. - void unmap_mesh_device(const MeshDevice* mesh_device); - std::shared_ptr get_mesh_device(const std::vector& physical_device_ids); - Device* get_device(const chip_id_t physical_device_id) const; + std::vector get_mapped_physical_device_ids(const MeshDeviceConfig &config) const; }; class MeshDevice : public std::enable_shared_from_this { @@ -171,7 +143,6 @@ class MeshDevice : public std::enable_shared_from_this { size_t num_program_cache_entries() const; - static std::shared_ptr fetch_mesh_device(const std::vector& devices); static std::shared_ptr create( const MeshDeviceConfig& config, size_t l1_small_size = DEFAULT_L1_SMALL_SIZE,