diff --git a/tests/ttnn/distributed/test_distributed.cpp b/tests/ttnn/distributed/test_distributed.cpp index 817860f62e20..a5be1300e9d5 100644 --- a/tests/ttnn/distributed/test_distributed.cpp +++ b/tests/ttnn/distributed/test_distributed.cpp @@ -16,7 +16,7 @@ class DistributedTest : public ::testing::Test { void TearDown() override {} }; -TEST_F(DistributedTest, TestSystemMeshTearDownWithoutClose) { +TEST_F(DistributedTest, TestSystemMeshTearDownWithoutClose_LocalMesh) { auto& sys = tt::tt_metal::distributed::SystemMesh::instance(); auto mesh = ttnn::distributed::open_mesh_device( {2, 4}, DEFAULT_L1_SMALL_SIZE, DEFAULT_TRACE_REGION_SIZE, 1, tt::tt_metal::DispatchCoreType::WORKER); @@ -26,4 +26,15 @@ TEST_F(DistributedTest, TestSystemMeshTearDownWithoutClose) { EXPECT_GT(cols, 0); } +TEST_F(DistributedTest, TestSystemMeshTearDownWithoutClose_StaticMesh) { + static std::shared_ptr mesh; + auto& sys = tt::tt_metal::distributed::SystemMesh::instance(); + mesh = ttnn::distributed::open_mesh_device( + {2, 4}, DEFAULT_L1_SMALL_SIZE, DEFAULT_TRACE_REGION_SIZE, 1, tt::tt_metal::DispatchCoreType::WORKER); + + auto [rows, cols] = sys.get_shape(); + EXPECT_GT(rows, 0); + EXPECT_GT(cols, 0); +} + } // namespace ttnn::distributed::test diff --git a/tt_metal/distributed/mesh_device.cpp b/tt_metal/distributed/mesh_device.cpp index a81476ef69e9..f5865e4a89ca 100644 --- a/tt_metal/distributed/mesh_device.cpp +++ b/tt_metal/distributed/mesh_device.cpp @@ -53,7 +53,47 @@ static std::unordered_map load_translatio return result; } -MeshShape SystemMesh::get_system_mesh_shape(size_t system_num_devices) { +class SystemMesh::Impl { + private: + using LogicalCoordinate = Coordinate; + using PhysicalCoordinate = eth_coord_t; + + std::unordered_map> opened_devices; + std::unordered_map> assigned_devices; + std::unordered_map> assigned_mesh_device_devices; + + MeshShape logical_mesh_shape; + std::unordered_map logical_to_physical_coordinates; + std::unordered_map physical_coordinate_to_device_id; + std::unordered_map physical_device_id_to_coordinate; + + + public: + Impl() = default; + ~Impl() = default; + + bool is_system_mesh_initialized() const; + void initialize(); + const MeshShape& get_shape() const; + size_t get_num_devices() const; + std::vector map_mesh_device( + std::shared_ptr mesh_device, + size_t num_command_queues, + size_t l1_small_size, + size_t trace_region_size, + DispatchCoreType dispatch_core_type, + const MeshDeviceConfig &config); + std::vector get_mapped_physical_device_ids(const MeshDeviceConfig& config) const; + void remove_expired_mesh_devices(); + Device* get_device(const chip_id_t physical_device_id) const; + void register_mesh_device(const std::shared_ptr &mesh_device, const std::vector& devices); + + static MeshShape get_system_mesh_shape(size_t system_num_devices); + static std::unordered_map get_system_mesh_translation_map(size_t system_num_devices); +}; + +// Implementation of private static methods +MeshShape SystemMesh::Impl::get_system_mesh_shape(size_t system_num_devices) { const std::unordered_map system_mesh_to_shape = { {1, MeshShape{1, 1}}, // single-device {2, MeshShape{1, 2}}, // N300 @@ -67,7 +107,7 @@ MeshShape SystemMesh::get_system_mesh_shape(size_t system_num_devices) { return shape; } -std::unordered_map SystemMesh::get_system_mesh_translation_map(size_t system_num_devices) { +std::unordered_map SystemMesh::Impl::get_system_mesh_translation_map(size_t system_num_devices) { const std::unordered_map system_mesh_translation_map = { {1, "device.json"}, {2, "N300.json"}, @@ -80,36 +120,30 @@ std::unordered_map SystemMesh::get_system return load_translation_map(translation_config_file, "logical_to_physical_coordinates"); } -bool SystemMesh::is_system_mesh_initialized() const { +// Implementation of public methods +bool SystemMesh::Impl::is_system_mesh_initialized() const { return this->physical_coordinate_to_device_id.size() > 0; } -SystemMesh& SystemMesh::instance() { - static SystemMesh instance; - if (!instance.is_system_mesh_initialized()) { - instance.initialize(); - } - return instance; -} -void SystemMesh::initialize() { +void SystemMesh::Impl::initialize() { this->physical_device_id_to_coordinate = tt::Cluster::instance().get_user_chip_ethernet_coordinates(); for (const auto& [chip_id, physical_coordinate] : this->physical_device_id_to_coordinate) { this->physical_coordinate_to_device_id.emplace(physical_coordinate, chip_id); } - // Initialize the system mesh shape and translation map auto num_devices = physical_coordinate_to_device_id.size(); - this->logical_mesh_shape = SystemMesh::get_system_mesh_shape(num_devices); - this->logical_to_physical_coordinates = SystemMesh::get_system_mesh_translation_map(num_devices); + this->logical_mesh_shape = get_system_mesh_shape(num_devices); + this->logical_to_physical_coordinates = get_system_mesh_translation_map(num_devices); } -const MeshShape& SystemMesh::get_shape() const { return this->logical_mesh_shape; } -size_t SystemMesh::get_num_devices() const { +const MeshShape& SystemMesh::Impl::get_shape() const { return this->logical_mesh_shape; } +size_t SystemMesh::Impl::get_num_devices() const { auto [num_rows, num_cols] = this->get_shape(); return num_rows * num_cols; } -std::vector SystemMesh::get_mapped_physical_device_ids(const MeshDeviceConfig& config) const { + +std::vector SystemMesh::Impl::get_mapped_physical_device_ids(const MeshDeviceConfig& config) const { std::vector physical_device_ids; auto [system_mesh_rows, system_mesh_cols] = this->get_shape(); auto [requested_rows, requested_cols] = config.mesh_shape; @@ -142,7 +176,7 @@ std::vector SystemMesh::get_mapped_physical_device_ids(const MeshDevi } return physical_device_ids; } -void SystemMesh::register_mesh_device(const std::shared_ptr &mesh_device, const std::vector& devices) { +void SystemMesh::Impl::register_mesh_device(const std::shared_ptr &mesh_device, const std::vector& devices) { std::vector physical_device_ids; for (auto device : devices) { physical_device_ids.push_back(device->id()); @@ -151,13 +185,14 @@ void SystemMesh::register_mesh_device(const std::shared_ptr &mesh_de this->assigned_devices.insert({mesh_device->get_mesh_id(), physical_device_ids}); } -std::vector SystemMesh::map_mesh_device( +std::vector SystemMesh::Impl::map_mesh_device( std::shared_ptr mesh_device, size_t num_command_queues, size_t l1_small_size, size_t trace_region_size, DispatchCoreType dispatch_core_type, const MeshDeviceConfig& config) { + this->remove_expired_mesh_devices(); auto [requested_num_rows, requested_num_cols] = mesh_device->shape(); auto [max_num_rows, max_num_cols] = this->logical_mesh_shape; @@ -179,34 +214,65 @@ std::vector SystemMesh::map_mesh_device( for (auto physical_device_id : physical_device_ids) { auto mapped_device = this->opened_devices[mesh_device->get_mesh_id()].at(physical_device_id); mapped_devices.push_back(mapped_device); - this->assigned_physical_id_to_device.insert({physical_device_id, mapped_device}); } this->register_mesh_device(mesh_device, mapped_devices); // here return mapped_devices; } -void SystemMesh::unmap_mesh_device(const MeshDevice* mesh_device) { - auto mesh_id = mesh_device->get_mesh_id(); - this->assigned_mesh_device_devices.erase(mesh_id); +void SystemMesh::Impl::remove_expired_mesh_devices() { + std::vector stale_ids; + for (const auto& [mesh_id, weak_mesh_device] : assigned_mesh_device_devices) { + if (weak_mesh_device.expired()) { + stale_ids.push_back(mesh_id); + } + } + for (auto mesh_id : stale_ids) { + this->assigned_mesh_device_devices.erase(mesh_id); - // Close the devices - if (mesh_device->is_parent_mesh()) { - for (auto physical_id : this->assigned_devices.at(mesh_id)) { - this->assigned_physical_id_to_device.erase(physical_id); + if (assigned_devices.count(mesh_id)) { + assigned_devices.erase(mesh_id); + } + + if (opened_devices.count(mesh_id)) { + tt::tt_metal::detail::CloseDevices(this->opened_devices.at(mesh_id)); + this->opened_devices.erase(mesh_id); } - tt::tt_metal::detail::CloseDevices(this->opened_devices.at(mesh_id)); - this->opened_devices.erase(mesh_id); } - this->assigned_devices.erase(mesh_id); } -Device* SystemMesh::get_device(const chip_id_t physical_device_id) const { - auto it = this->assigned_physical_id_to_device.find(physical_device_id); - if (it == this->assigned_physical_id_to_device.end()) { - TT_THROW("Physical Device ID: {} not found in assigned devices", physical_device_id); +SystemMesh::SystemMesh() : pimpl(std::make_unique()) {} +SystemMesh::~SystemMesh() = default; + +SystemMesh& SystemMesh::instance() { + static SystemMesh instance; + if (!instance.pimpl->is_system_mesh_initialized()) { + instance.pimpl->initialize(); } - return it->second; + return instance; +} + +const MeshShape& SystemMesh::get_shape() const { return pimpl->get_shape(); } + +size_t SystemMesh::get_num_devices() const { return pimpl->get_num_devices(); } + +void SystemMesh::register_mesh_device(const std::shared_ptr &mesh_device, const std::vector& devices) { + pimpl->register_mesh_device(mesh_device, devices); +} + +std::vector SystemMesh::map_mesh_device( + std::shared_ptr mesh_device, + size_t num_command_queues, + size_t l1_small_size, + size_t trace_region_size, + DispatchCoreType dispatch_core_type, + const MeshDeviceConfig& config) { + return pimpl->map_mesh_device(mesh_device, num_command_queues, l1_small_size, trace_region_size, dispatch_core_type, config); +} + + +std::vector SystemMesh::get_mapped_physical_device_ids(const MeshDeviceConfig& config) const { + return pimpl->get_mapped_physical_device_ids(config); } static MeshDeviceID generate_unique_mesh_id() { @@ -309,7 +375,12 @@ Device* MeshDevice::get_device_index(size_t logical_device_id) const { } Device* MeshDevice::get_device(chip_id_t physical_device_id) const { - return SystemMesh::instance().get_device(physical_device_id); + for (auto device : this->devices) { + if (device->id() == physical_device_id) { + return device; + } + } + TT_THROW("Physical Device ID: {} not found in assigned devices", physical_device_id); } std::vector MeshDevice::get_devices() const { return this->primary_view->get_devices(this->type); } @@ -344,9 +415,7 @@ void MeshDevice::close_devices() { for (const auto& submesh : this->submeshes) { submesh->close_devices(); } - if (not this->devices.empty()) { - SystemMesh::instance().unmap_mesh_device(this); - } + this->submeshes.clear(); this->parent_mesh.reset(); this->devices.clear(); this->primary_view.reset(); @@ -364,34 +433,6 @@ MeshDeviceID MeshDevice::get_mesh_id() const { return this->mesh_id; } bool MeshDevice::is_parent_mesh() const { return this->parent_mesh.expired(); } -std::shared_ptr SystemMesh::get_mesh_device(const std::vector& physical_device_ids) { - log_trace(LogMetal, "Getting mesh device for {} physical devices: {}", physical_device_ids.size(), physical_device_ids); - std::unordered_set input_set(physical_device_ids.begin(), physical_device_ids.end()); - - for (const auto& [mesh_id, weak_mesh_device] : this->assigned_mesh_device_devices) { - if (auto mesh_device = weak_mesh_device.lock()) { - const auto& assigned_devices = this->assigned_devices.at(mesh_id); - std::unordered_set assigned_set(assigned_devices.begin(), assigned_devices.end()); - log_trace(LogMetal, "Assigned devices: {}", assigned_devices); - - if (input_set == assigned_set) { - return mesh_device; - } - } - } - TT_THROW("No mesh device found for the provided devices"); -} - -std::shared_ptr MeshDevice::fetch_mesh_device(const std::vector& devices) { - TT_FATAL(devices.size() > 0, "No devices provided"); - auto& instance = SystemMesh::instance(); - std::vector physical_device_ids; - for (auto device : devices) { - physical_device_ids.push_back(device->id()); - } - return instance.get_mesh_device(physical_device_ids); -} - std::vector> MeshDevice::get_submeshes() const { return this->submeshes; } std::ostream& operator<<(std::ostream& os, const MeshDevice& mesh_device) { return os << mesh_device.to_string(); } diff --git a/tt_metal/distributed/mesh_device.hpp b/tt_metal/distributed/mesh_device.hpp index b23346ed7239..5c46eba9ad10 100644 --- a/tt_metal/distributed/mesh_device.hpp +++ b/tt_metal/distributed/mesh_device.hpp @@ -50,62 +50,34 @@ struct MeshDeviceConfig { // device resources. class SystemMesh { private: - using LogicalCoordinate = Coordinate; - using PhysicalCoordinate = eth_coord_t; + friend class MeshDevice; + class Impl; // Forward declaration only + std::unique_ptr pimpl; + SystemMesh(); + ~SystemMesh(); - // Keep track of the devices that were opened so we can close them later. We shouldn't - // to keep track of this but DevicePool seems to open all devices associated with an MMIO device id - std::unordered_map> opened_devices; - std::unordered_map> assigned_devices; - std::unordered_map> assigned_mesh_device_devices; - std::unordered_map assigned_physical_id_to_device; - - // Logical mesh shape and coordinates - MeshShape logical_mesh_shape; - std::unordered_map logical_to_physical_coordinates; - - // Handling of physical coordinates - std::unordered_map physical_coordinate_to_device_id; - std::unordered_map physical_device_id_to_coordinate; + void register_mesh_device(const std::shared_ptr &mesh_device, const std::vector& devices); + std::vector map_mesh_device( + std::shared_ptr mesh_device, + size_t num_command_queues, + size_t l1_small_size, + size_t trace_region_size, + DispatchCoreType dispatch_core_type, + const MeshDeviceConfig &config); - SystemMesh() = default; + public: + static SystemMesh &instance(); SystemMesh(const SystemMesh &) = delete; SystemMesh &operator=(const SystemMesh &) = delete; SystemMesh(SystemMesh &&) = delete; SystemMesh &operator=(SystemMesh &&) = delete; - static MeshShape get_system_mesh_shape(size_t system_num_devices); - static std::unordered_map get_system_mesh_translation_map( - size_t system_num_devices); - - bool is_system_mesh_initialized() const; - - public: - static SystemMesh &instance(); - - void initialize(); - // Return the shape of the logical mesh const MeshShape &get_shape() const; size_t get_num_devices() const; // Get the physical device IDs mapped to a MeshDevice std::vector get_mapped_physical_device_ids(const MeshDeviceConfig &config) const; - void register_mesh_device(const std::shared_ptr &mesh_device, const std::vector& devices); - - // Map MeshDevice to physical devices - std::vector map_mesh_device( - std::shared_ptr mesh_device, - size_t num_command_queues, - size_t l1_small_size, - size_t trace_region_size, - DispatchCoreType dispatch_core_type, - const MeshDeviceConfig &config); - - // Unmap MeshDevice, releasing the associated physical devices. - void unmap_mesh_device(const MeshDevice* mesh_device); - std::shared_ptr get_mesh_device(const std::vector& physical_device_ids); - Device* get_device(const chip_id_t physical_device_id) const; }; class MeshDevice : public std::enable_shared_from_this { @@ -177,7 +149,6 @@ class MeshDevice : public std::enable_shared_from_this { size_t num_program_cache_entries() const; - static std::shared_ptr fetch_mesh_device(const std::vector& devices); static std::shared_ptr create( const MeshDeviceConfig &config, size_t l1_small_size = DEFAULT_L1_SMALL_SIZE,