diff --git a/tests/ttnn/unit_tests/gtests/test_ccl_on_galaxy.cpp b/tests/ttnn/unit_tests/gtests/test_ccl_on_galaxy.cpp index 6d7ed90ee8b..2d2e9950467 100644 --- a/tests/ttnn/unit_tests/gtests/test_ccl_on_galaxy.cpp +++ b/tests/ttnn/unit_tests/gtests/test_ccl_on_galaxy.cpp @@ -130,8 +130,9 @@ TEST(GalaxyTests, TestAllGatherDeadlock) { } // Iterate over each row and run line all-gather multiple times. // For each row, send adversarial traffic to the first chip, that can hang the network if the CCL is not tagged. + auto view = MeshDeviceView(*mesh); for (uint32_t row = 0; row < 8; row++) { - auto devs = mesh->get_devices_on_row(row); + auto devs = view.get_devices_on_row(row); std::vector device_ids = {}; for (auto dev : devs) { device_ids.push_back(dev->id()); @@ -189,13 +190,14 @@ TEST(GalaxyTests, TestReduceScatterDeadlock) { std::shared_ptr mesh = ttnn::multi_device::open_mesh_device(mesh_shape, 0, 0, 1, DispatchCoreType::WORKER); // Create the outer ring on which Reduce Scatter will be run. This allows us to verify that there are no deadlocks when we send CCLs to the // first tunnel (forward path). - std::vector ring_devices = mesh->get_devices_on_row(0); // Tunnel 0 - std::vector ring_devices_1 = mesh->get_devices_on_column(mesh_shape.second - 1); // Orthogonal to tunnel .. no deadlocks + auto view = MeshDeviceView(*mesh); + std::vector ring_devices = view.get_devices_on_row(0); // Tunnel 0 + std::vector ring_devices_1 = view.get_devices_on_column(mesh_shape.second - 1); // Orthogonal to tunnel .. no deadlocks ring_devices_1 = std::vector(ring_devices_1.begin() + 1, ring_devices_1.end()); - std::vector ring_devices_2 = mesh->get_devices_on_row(7); // Tunnel 7 .. potential deadlocks with lack of buffering + std::vector ring_devices_2 = view.get_devices_on_row(7); // Tunnel 7 .. potential deadlocks with lack of buffering std::reverse(ring_devices_2.begin(), ring_devices_2.end()); ring_devices_2 = std::vector(ring_devices_2.begin() + 1, ring_devices_2.end()); - std::vector ring_devices_3 = mesh->get_devices_on_column(0); // Orthogonal to tunnel .. no deadlocks + std::vector ring_devices_3 = view.get_devices_on_column(0); // Orthogonal to tunnel .. no deadlocks std::reverse(ring_devices_3.begin(), ring_devices_3.end()); ring_devices_3 = std::vector(ring_devices_3.begin() + 1, ring_devices_3.end() - 1); diff --git a/tt_metal/impl/device/mesh_device.cpp b/tt_metal/impl/device/mesh_device.cpp index 1798df086c1..e6e8d3d1097 100644 --- a/tt_metal/impl/device/mesh_device.cpp +++ b/tt_metal/impl/device/mesh_device.cpp @@ -23,7 +23,7 @@ static std::string get_config_path(const std::string& filename) { return root_path + "/tt_metal/impl/device/mesh_configurations/" + filename; } -static std::map load_translation_map(const std::string& filename, const std::string& key) { +static std::unordered_map load_translation_map(const std::string& filename, const std::string& key) { std::ifstream file(filename); if (!file.is_open()) { throw std::runtime_error("Unable to open file: " + filename); @@ -40,7 +40,7 @@ static std::map load_translation_map(cons throw std::runtime_error("Key '" + key + "' not found in JSON file: " + filename); } - std::map result; + std::unordered_map result; for (const auto& mapping : j[key]) { if (mapping.size() != 2 || mapping[0].size() != 2 || mapping[1].size() != 4) { throw std::runtime_error("Invalid coordinate format in JSON file: " + filename); @@ -65,7 +65,7 @@ MeshShape SystemMesh::get_system_mesh_shape(std::size_t system_num_devices) { return shape; } -std::map SystemMesh::get_system_mesh_translation_map(std::size_t system_num_devices) { +std::unordered_map SystemMesh::get_system_mesh_translation_map(std::size_t system_num_devices) { const std::unordered_map system_mesh_translation_map = { {1, "device.json"}, {2, "N300.json"}, @@ -140,14 +140,22 @@ std::vector SystemMesh::get_mapped_physical_device_ids(const MeshDevi } return physical_device_ids; } +void SystemMesh::register_mesh_device(const std::shared_ptr &mesh_device, const std::vector& devices) { + std::vector physical_device_ids; + for (auto device : devices) { + physical_device_ids.push_back(device->id()); + } + this->assigned_mesh_device_devices.insert({mesh_device->get_mesh_id(), mesh_device}); + this->assigned_devices.insert({mesh_device->get_mesh_id(), physical_device_ids}); +} std::vector SystemMesh::map_mesh_device( std::shared_ptr mesh_device, - size_t num_command_queues, - size_t l1_small_size, - size_t trace_region_size, + std::size_t num_command_queues, + std::size_t l1_small_size, + std::size_t trace_region_size, DispatchCoreType dispatch_core_type, - const std::pair& offset, + const std::pair& offset, const std::vector& user_provided_physical_device_ids) { auto [requested_num_rows, requested_num_cols] = mesh_device->shape(); @@ -158,7 +166,6 @@ std::vector SystemMesh::map_mesh_device( TT_FATAL(requested_num_rows <= max_num_rows, "Requested too many rows: {} > {}", requested_num_rows, max_num_rows); TT_FATAL(requested_num_rows*requested_num_cols <= max_num_rows*max_num_cols, "Requested submesh is too big: {}x{}", requested_num_rows, requested_num_cols); - this->assigned_mesh_device_devices.insert({mesh_device->get_mesh_id(), mesh_device}); auto physical_device_ids = user_provided_physical_device_ids.empty() ? this->get_mapped_physical_device_ids(MeshDeviceConfig{mesh_device->shape(), offset}) : @@ -171,27 +178,34 @@ std::vector SystemMesh::map_mesh_device( for (auto physical_device_id : physical_device_ids) { auto mapped_device = this->opened_devices[mesh_device->get_mesh_id()].at(physical_device_id); mapped_devices.push_back(mapped_device); - this->assigned_devices[mesh_device->get_mesh_id()].push_back(physical_device_id); this->assigned_physical_id_to_device.insert({physical_device_id, mapped_device}); } + + this->register_mesh_device(mesh_device, mapped_devices); // TODO: change this return mapped_devices; } -void SystemMesh::unmap_mesh_device(const std::shared_ptr& mesh_device) { +void SystemMesh::unmap_mesh_device(const MeshDevice* mesh_device) { auto mesh_id = mesh_device->get_mesh_id(); - - // Clean up all state related to this virtual mesh this->assigned_mesh_device_devices.erase(mesh_id); - // Remove the devices from assigned_physical_id_to_device - for (auto physical_id : this->assigned_devices.at(mesh_id)) { - this->assigned_physical_id_to_device.erase(physical_id); + // Close the devices + if (mesh_device->is_parent_mesh()) { + for (auto physical_id : this->assigned_devices.at(mesh_id)) { + this->assigned_physical_id_to_device.erase(physical_id); + } + tt::tt_metal::detail::CloseDevices(this->opened_devices.at(mesh_id)); + this->opened_devices.erase(mesh_id); } this->assigned_devices.erase(mesh_id); +} - // Close the devices - tt::tt_metal::detail::CloseDevices(this->opened_devices.at(mesh_id)); - this->opened_devices.erase(mesh_id); +Device* SystemMesh::get_device(const chip_id_t physical_device_id) const { + auto it = this->assigned_physical_id_to_device.find(physical_device_id); + if (it == this->assigned_physical_id_to_device.end()) { + TT_THROW("Physical Device ID: {} not found in assigned devices", physical_device_id); + } + return it->second; } static MeshDeviceID generate_unique_mesh_id() { @@ -199,15 +213,16 @@ static MeshDeviceID generate_unique_mesh_id() { return next_id++; } -MeshDevice::MeshDevice(const MeshShape& mesh_device_shape) : mesh_device_shape(mesh_device_shape), mesh_id(generate_unique_mesh_id()) {} +MeshDevice::MeshDevice(const MeshShape& mesh_device_shape, std::shared_ptr parent_mesh) + : mesh_device_shape(mesh_device_shape), mesh_id(generate_unique_mesh_id()), parent_mesh(parent_mesh) {} std::shared_ptr MeshDevice::create( const MeshShape& mesh_device_shape, - size_t l1_small_size, - size_t trace_region_size, - size_t num_command_queues, + std::size_t l1_small_size, + std::size_t trace_region_size, + std::size_t num_command_queues, DispatchCoreType dispatch_core_type, - const std::pair& offset, + const std::pair& offset, const std::vector& user_provided_physical_device_ids) { auto mesh_device = std::make_shared(mesh_device_shape); @@ -216,12 +231,42 @@ std::shared_ptr MeshDevice::create( return mesh_device; } +std::shared_ptr MeshDevice::create_submesh(const MeshShape &submesh_shape, const MeshOffset &offset) { + if (submesh_shape.first <= 0 || submesh_shape.second <= 0) { + TT_THROW("Invalid submesh shape: ({}, {}). Both dimensions must be positive.", submesh_shape.first, submesh_shape.second); + } + + if (offset.first < 0 || offset.second < 0) { + TT_THROW("Invalid offset: ({}, {}). Offset must be non-negative.", offset.first, offset.second); + } + + if (offset.first + submesh_shape.first > this->mesh_device_shape.first || + offset.second + submesh_shape.second > this->mesh_device_shape.second) { + TT_THROW("Submesh ({}x{}) with offset ({}, {}) does not fit within parent mesh ({}x{}).", + submesh_shape.first, submesh_shape.second, + offset.first, offset.second, + this->mesh_device_shape.first, this->mesh_device_shape.second); + } + + auto submesh = std::make_shared(submesh_shape, shared_from_this()); + auto start_coordinate = Coordinate{offset.first, offset.second}; + auto end_coordinate = Coordinate{offset.first + submesh_shape.first - 1, offset.second + submesh_shape.second - 1}; + submesh->primary_view = std::make_shared(*this, start_coordinate, end_coordinate); + submesh->devices = submesh->primary_view->get_devices(); + SystemMesh::instance().register_mesh_device(submesh, submesh->devices); + this->submeshes.push_back(submesh); + log_trace(LogMetal, "Instantiating submesh {}: {}x{} with offset: {} {}", submesh->get_mesh_id(), submesh_shape.first, submesh_shape.second, offset.first, offset.second); + log_trace(LogMetal, "Submesh {} instantiated with {} devices", submesh->get_mesh_id(), submesh->devices); + + return submesh; +} + void MeshDevice::initialize( - size_t l1_small_size, - size_t trace_region_size, - size_t num_command_queues, + std::size_t l1_small_size, + std::size_t trace_region_size, + std::size_t num_command_queues, DispatchCoreType dispatch_core_type, - const std::pair& offset, + const std::pair& offset, const std::vector& physical_device_ids) { auto [num_rows, num_cols] = this->shape(); @@ -235,42 +280,36 @@ void MeshDevice::initialize( auto& instance = SystemMesh::instance(); this->devices = instance.map_mesh_device( shared_from_this(), num_command_queues, l1_small_size, trace_region_size, dispatch_core_type, offset, physical_device_ids); - this->primary_view = std::make_unique(*this); - - for (int device_index = 0; device_index < this->devices.size(); device_index++) { - this->physical_id_to_device_index.insert({this->devices[device_index]->id(), device_index}); - } + this->primary_view = std::make_shared(*this); } MeshDevice::~MeshDevice() { if (not this->devices.empty()) { this->close_devices(); } + for (auto submesh : this->submeshes) { + submesh->close_devices(); + } + this->primary_view.reset(); + this->devices.clear(); + this->parent_mesh.reset(); } -Device* MeshDevice::get_device_index(int logical_device_id) const { +Device* MeshDevice::get_device_index(std::size_t logical_device_id) const { TT_FATAL(logical_device_id >= 0 and logical_device_id < num_devices(), "Invalid device index"); return this->devices.at(logical_device_id); } -Device* MeshDevice::get_device(int physical_device_id) const { - return this->devices.at(this->physical_id_to_device_index.at(physical_device_id)); +Device* MeshDevice::get_device(chip_id_t physical_device_id) const { + return SystemMesh::instance().get_device(physical_device_id); } -std::vector MeshDevice::get_devices() const { return this->devices; } +std::vector MeshDevice::get_devices() const { return this->primary_view->get_devices(IterationOrder::LINE); } -Device* MeshDevice::get_device(int row_idx, int col_idx) const { +Device* MeshDevice::get_device(std::size_t row_idx, std::size_t col_idx) const { return this->get_device_index(row_idx * num_cols() + col_idx); } -std::vector MeshDevice::get_devices_on_row(int row_idx) const { - return this->primary_view->get_devices_on_row(row_idx); -} - -std::vector MeshDevice::get_devices_on_column(int col_idx) const { - return this->primary_view->get_devices_on_column(col_idx); -} - const DeviceIds MeshDevice::get_device_ids() const { DeviceIds device_ids; for (auto device : this->get_devices()) { @@ -279,7 +318,7 @@ const DeviceIds MeshDevice::get_device_ids() const { return device_ids; } -int MeshDevice::num_devices() const { return num_rows() * num_cols(); } +std::size_t MeshDevice::num_devices() const { return this->devices.size(); } CoreCoord MeshDevice::compute_with_storage_grid_size() const { return get_device_index(0)->compute_with_storage_grid_size(); } @@ -287,16 +326,15 @@ CoreCoord MeshDevice::dram_grid_size() const { return get_device_index(0)->dram_ tt::ARCH MeshDevice::arch() const { return get_device_index(0)->arch(); } -int MeshDevice::num_rows() const { return this->mesh_device_shape.first; } +std::size_t MeshDevice::num_rows() const { return this->mesh_device_shape.first; } -int MeshDevice::num_cols() const { return this->mesh_device_shape.second; } +std::size_t MeshDevice::num_cols() const { return this->mesh_device_shape.second; } MeshShape MeshDevice::shape() const { return this->mesh_device_shape; } void MeshDevice::close_devices() { - SystemMesh::instance().unmap_mesh_device(shared_from_this()); + SystemMesh::instance().unmap_mesh_device(this); this->devices.clear(); - this->physical_id_to_device_index.clear(); this->primary_view.reset(); } @@ -308,8 +346,60 @@ std::shared_ptr MeshDevice::get_view() const { return this std::shared_ptr MeshDevice::get_view() { return this->primary_view; } +std::vector> MeshDevice::get_submesh_views() { + std::vector> submesh_views; + if (this->submeshes.empty()) { + submesh_views.push_back(this->get_view()); + } + else { + for (auto submesh : this->submeshes) { + submesh_views.push_back(submesh->get_view()); + } + } + return submesh_views; +} + MeshDeviceID MeshDevice::get_mesh_id() const { return this->mesh_id; } +bool MeshDevice::is_parent_mesh() const { return this->parent_mesh == nullptr; } + +std::shared_ptr SystemMesh::get_mesh_device(const std::vector& physical_device_ids) { + log_trace(LogMetal, "Getting mesh device for {} physical devices: {}", physical_device_ids.size(), physical_device_ids); + std::unordered_set input_set(physical_device_ids.begin(), physical_device_ids.end()); + + for (const auto& [mesh_id, mesh_device] : this->assigned_mesh_device_devices) { + const auto& assigned_devices = this->assigned_devices.at(mesh_id); + std::unordered_set assigned_set(assigned_devices.begin(), assigned_devices.end()); + log_trace(LogMetal, "Assigned devices: {}", assigned_devices); + + if (input_set == assigned_set) { + return mesh_device; + } + } + TT_THROW("No mesh device found for the provided devices"); +} + +std::shared_ptr MeshDevice::fetch_mesh_device(const std::vector& devices) { + TT_FATAL(devices.size() > 0, "No devices provided"); + auto& instance = SystemMesh::instance(); + std::vector physical_device_ids; + for (auto device : devices) { + physical_device_ids.push_back(device->id()); + } + return instance.get_mesh_device(physical_device_ids); +} + +std::vector> MeshDevice::get_submeshes() const { return this->submeshes; } + +std::shared_ptr MeshDevice::get_view(const Device* device) { + for (auto submesh_view : this->get_submesh_views()) { + if (submesh_view->contains_device(device->id())) { + return submesh_view; + } + } + TT_THROW("Device {} not found in any submesh view", device->id()); +} + std::ostream& operator<<(std::ostream& os, const MeshDevice& mesh_device) { return os << mesh_device.to_string(); } bool validate_worker_modes(const std::vector& workers) { diff --git a/tt_metal/impl/device/mesh_device.hpp b/tt_metal/impl/device/mesh_device.hpp index 6ae7c58ba79..d806c7ee9f9 100644 --- a/tt_metal/impl/device/mesh_device.hpp +++ b/tt_metal/impl/device/mesh_device.hpp @@ -16,7 +16,7 @@ namespace tt::tt_metal { using DeviceIds = std::vector; using MeshDeviceID = std::size_t; -using MeshOffset = std::pair; +using MeshOffset = std::pair; class MeshDeviceView; struct MeshDeviceConfig { @@ -42,7 +42,7 @@ class SystemMesh { // Logical mesh shape and coordinates MeshShape logical_mesh_shape; - std::map logical_to_physical_coordinates; + std::unordered_map logical_to_physical_coordinates; // Handling of physical coordinates std::unordered_map physical_coordinate_to_device_id; @@ -55,7 +55,7 @@ class SystemMesh { SystemMesh &operator=(SystemMesh &&) = delete; static MeshShape get_system_mesh_shape(std::size_t system_num_devices); - static std::map get_system_mesh_translation_map( + static std::unordered_map get_system_mesh_translation_map( std::size_t system_num_devices); bool is_system_mesh_initialized() const; @@ -71,38 +71,43 @@ class SystemMesh { // Get the physical device IDs mapped to a MeshDevice std::vector get_mapped_physical_device_ids(const MeshDeviceConfig &config) const; + void register_mesh_device(const std::shared_ptr &mesh_device, const std::vector& devices); // Map MeshDevice to physical devices std::vector map_mesh_device( std::shared_ptr mesh_device, - size_t num_command_queues, - size_t l1_small_size, - size_t trace_region_size, + std::size_t num_command_queues, + std::size_t l1_small_size, + std::size_t trace_region_size, DispatchCoreType dispatch_core_type, - const std::pair &offset = {0, 0}, + const std::pair &offset = {0, 0}, const std::vector &physical_device_ids = {}); // Unmap MeshDevice, releasing the associated physical devices. - void unmap_mesh_device(const std::shared_ptr &mesh_device); + void unmap_mesh_device(const MeshDevice* mesh_device); + std::shared_ptr get_mesh_device(const std::vector& physical_device_ids); + Device* get_device(const chip_id_t physical_device_id) const; }; class MeshDevice : public std::enable_shared_from_this { + private: MeshDeviceID mesh_id; MeshShape mesh_device_shape; std::shared_ptr primary_view; std::vector devices; - std::unordered_map physical_id_to_device_index; + std::shared_ptr parent_mesh; + std::vector> submeshes; void initialize( - size_t l1_small_size, - size_t trace_region_size, - size_t num_command_queues, + std::size_t l1_small_size, + std::size_t trace_region_size, + std::size_t num_command_queues, DispatchCoreType dispatch_core_type, - const std::pair &offset, + const std::pair &offset, const std::vector &physical_device_ids); public: - MeshDevice(const MeshShape &mesh_device_shape); + MeshDevice(const MeshShape &mesh_device_shape, std::shared_ptr parent_mesh = nullptr); ~MeshDevice(); MeshDevice(const MeshDevice &) = delete; @@ -112,17 +117,15 @@ class MeshDevice : public std::enable_shared_from_this { MeshDevice &operator=(MeshDevice &&) = delete; std::vector get_devices() const; - Device *get_device_index(int logical_device_id) const; - Device *get_device(int physical_device_id) const; - Device *get_device(int row_idx, int col_idx) const; - std::vector get_devices_on_row(int row_idx) const; - std::vector get_devices_on_column(int col_idx) const; + Device *get_device_index(std::size_t logical_device_id) const; + Device *get_device(chip_id_t physical_device_id) const; + Device *get_device(std::size_t row_idx, std::size_t col_idx) const; const DeviceIds get_device_ids() const; - int num_devices() const; - int num_rows() const; - int num_cols() const; + std::size_t num_devices() const; + std::size_t num_rows() const; + std::size_t num_cols() const; MeshShape shape() const; CoreCoord compute_with_storage_grid_size() const; @@ -137,14 +140,22 @@ class MeshDevice : public std::enable_shared_from_this { std::string to_string() const; MeshDeviceID get_mesh_id() const; + bool is_parent_mesh() const; + std::vector> get_submeshes() const; + std::vector> get_submesh_views(); + std::shared_ptr get_view(const Device* device); + + std::shared_ptr create_submesh(const MeshShape &submesh_shape, const MeshOffset &offset = {0, 0}); + + static std::shared_ptr fetch_mesh_device(const std::vector& devices); static std::shared_ptr create( const MeshShape &mesh_device_shape, - size_t l1_small_size, - size_t trace_region_size, - size_t num_command_queues, + std::size_t l1_small_size, + std::size_t trace_region_size, + std::size_t num_command_queues, DispatchCoreType dispatch_core_type, - const std::pair &offset = {0, 0}, + const std::pair &offset = {0, 0}, const std::vector &physical_device_ids = {}); }; diff --git a/tt_metal/impl/device/mesh_device_view.cpp b/tt_metal/impl/device/mesh_device_view.cpp index 89963b10764..6d1f7527b30 100644 --- a/tt_metal/impl/device/mesh_device_view.cpp +++ b/tt_metal/impl/device/mesh_device_view.cpp @@ -13,10 +13,20 @@ namespace tt::tt_metal { using MeshDevice = tt::tt_metal::MeshDevice; +static std::vector get_devices_from_coordinates(MeshDeviceView& mesh, const std::vector& coords) { + std::vector devices; + for (const auto& coord : coords) { + if (auto device = mesh.get_device(coord.row, coord.col)) { + devices.push_back(device); + } + } + return devices; +} + MeshDeviceView::MeshDeviceView(const MeshDevice& mesh) : top_left_(0, 0), bottom_right_(mesh.num_rows() - 1, mesh.num_cols() - 1) { - for (size_t row = 0; row < mesh.num_rows(); ++row) { - for (size_t col = 0; col < mesh.num_cols(); ++col) { + for (std::size_t row = 0; row < mesh.num_rows(); ++row) { + for (std::size_t col = 0; col < mesh.num_cols(); ++col) { if (auto device = mesh.get_device(row, col)) { devices_.push_back(device); device_coordinates_[(device)->id()] = {row, col}; @@ -26,12 +36,12 @@ MeshDeviceView::MeshDeviceView(const MeshDevice& mesh) } MeshDeviceView::MeshDeviceView(const MeshDevice& mesh, Coordinate top_left, Coordinate bottom_right) - : top_left_(top_left), bottom_right_(bottom_right) { - for (size_t row = top_left.row; row <= bottom_right.row; ++row) { - for (size_t col = top_left.col; col <= bottom_right.col; ++col) { + : top_left_(0, 0), bottom_right_(Coordinate{bottom_right.row - top_left.row, bottom_right.col - top_left.col}) { + for (std::size_t row = top_left.row; row <= bottom_right.row; ++row) { + for (std::size_t col = top_left.col; col <= bottom_right.col; ++col) { if (auto device = mesh.get_device(row, col)) { devices_.push_back(device); - device_coordinates_[(device)->id()] = {row, col}; + device_coordinates_[(device)->id()] = {row - top_left.row, col - top_left.col}; } } } @@ -43,11 +53,11 @@ MeshDeviceView::MeshDeviceView(std::vector devices, CoordinateMa initialize_from_devices(devices_, std::move(mapper)); } -MeshDeviceView::device_pointer MeshDeviceView::get_device(size_t row, size_t col) { +MeshDeviceView::device_pointer MeshDeviceView::get_device(std::size_t row, std::size_t col) { return const_cast(std::as_const(*this).get_device(row, col)); } -MeshDeviceView::const_device_pointer MeshDeviceView::get_device(size_t row, size_t col) const { +MeshDeviceView::const_device_pointer MeshDeviceView::get_device(std::size_t row, std::size_t col) const { for (const auto& device : devices_) { auto it = device_coordinates_.find(device->id()); if (it != device_coordinates_.end() && it->second.row == row && it->second.col == col) { @@ -57,16 +67,14 @@ MeshDeviceView::const_device_pointer MeshDeviceView::get_device(size_t row, size return nullptr; } -const std::vector& MeshDeviceView::get_devices() const { return devices_; } - MeshDeviceView::DeviceView MeshDeviceView::get_devices(const Coordinate& start, const Coordinate& end) { if (start.row > end.row || start.col > end.col) { log_fatal("Invalid coordinates: start {} must be less than or equal to end {}", start, end); } DeviceView devices_in_region; - for (size_t row = start.row; row <= end.row; ++row) { - for (size_t col = start.col; col <= end.col; ++col) { + for (std::size_t row = start.row; row <= end.row; ++row) { + for (std::size_t col = start.col; col <= end.col; ++col) { if (auto device = get_device(row, col)) { devices_in_region.push_back(device); } @@ -79,7 +87,7 @@ MeshDeviceView::DeviceView MeshDeviceView::get_devices(const MeshShape& shape) { return get_devices({0, 0}, {shape.first - 1, shape.second - 1}); } -std::vector MeshDeviceView::get_devices_on_row(size_t row) const { +std::vector MeshDeviceView::get_devices_on_row(std::size_t row) const { std::vector row_devices; for (const auto& device : devices_) { auto it = device_coordinates_.find(device->id()); @@ -90,7 +98,7 @@ std::vector MeshDeviceView::get_devices_on_row(s return row_devices; } -std::vector MeshDeviceView::get_devices_on_column(size_t col) const { +std::vector MeshDeviceView::get_devices_on_column(std::size_t col) const { std::vector col_devices; for (const auto& device : devices_) { auto it = device_coordinates_.find(device->id()); @@ -103,7 +111,7 @@ std::vector MeshDeviceView::get_devices_on_colum std::vector> MeshDeviceView::get_row_views() const { std::vector> row_views; - for (size_t row = top_left_.row; row <= bottom_right_.row; ++row) { + for (std::size_t row = top_left_.row; row <= bottom_right_.row; ++row) { row_views.push_back(get_devices_on_row(row)); } return row_views; @@ -111,31 +119,21 @@ std::vector> MeshDeviceView::get_row std::vector> MeshDeviceView::get_column_views() const { std::vector> column_views; - for (size_t col = top_left_.col; col <= bottom_right_.col; ++col) { + for (std::size_t col = top_left_.col; col <= bottom_right_.col; ++col) { column_views.push_back(get_devices_on_column(col)); } return column_views; } -template -MeshDeviceView MeshDeviceView::subview(Pred&& predicate) const { - std::vector filtered_devices; - std::copy_if(devices_.begin(), devices_.end(), std::back_inserter(filtered_devices), std::forward(predicate)); - return MeshDeviceView(filtered_devices, [this](int device_id) { - auto it = device_coordinates_.find(device_id); - return it != device_coordinates_.end() ? std::optional(it->second) : std::nullopt; - }); -} - bool MeshDeviceView::empty() const noexcept { return devices_.empty(); } -size_t MeshDeviceView::size() const noexcept { +std::size_t MeshDeviceView::size() const noexcept { return devices_.size(); } -std::pair MeshDeviceView::shape() const noexcept { +std::pair MeshDeviceView::shape() const noexcept { return {num_rows(), num_cols()}; } @@ -158,6 +156,10 @@ bool MeshDeviceView::operator==(const MeshDeviceView& other) const { bottom_right_ == other.bottom_right_; } +bool MeshDeviceView::contains_device(chip_id_t device_id) const { + return device_coordinates_.find(device_id) != device_coordinates_.end(); +} + Coordinate MeshDeviceView::find_device(chip_id_t device_id) const { auto it = device_coordinates_.find(device_id); if (it != device_coordinates_.end()) { @@ -195,12 +197,12 @@ void MeshDeviceView::initialize_from_devices(const std::vector& } std::vector MeshDeviceView::get_line_coordinates( - size_t length, const Coordinate& offset, size_t num_rows, size_t num_cols) { + std::size_t length, const Coordinate& offset, std::size_t num_rows, std::size_t num_cols) { std::vector line_coords; auto [row_index, col_index] = offset; bool left_to_right = true; - for (size_t i = 0; i < length && row_index < num_rows && col_index < num_cols; ++i) { + for (std::size_t i = 0; i < length && row_index < num_rows && col_index < num_cols; ++i) { line_coords.emplace_back(Coordinate{row_index, col_index}); if (left_to_right && col_index < num_cols - 1) { @@ -217,7 +219,7 @@ std::vector MeshDeviceView::get_line_coordinates( return line_coords; } -std::vector MeshDeviceView::get_ring_coordinates(const MeshShape& ring_shape, const Coordinate& offset, size_t num_rows, size_t num_cols) { +std::vector MeshDeviceView::get_ring_coordinates(const MeshShape& ring_shape, const Coordinate& offset, std::size_t num_rows, std::size_t num_cols) { auto [start_row, start_col] = offset; auto [ring_rows, ring_cols] = ring_shape; auto end_row = start_row + ring_rows - 1; @@ -230,21 +232,21 @@ std::vector MeshDeviceView::get_ring_coordinates(const MeshShape& ri } // Traverse the top row from left to right - for (size_t col = start_col; col <= end_col; ++col) { + for (std::size_t col = start_col; col <= end_col; ++col) { boundary_coords.emplace_back(Coordinate{start_row, col}); } // Traverse the rightmost column from top+1 to bottom - for (size_t row = start_row + 1; row <= end_row; ++row) { + for (std::size_t row = start_row + 1; row <= end_row; ++row) { boundary_coords.emplace_back(Coordinate{row, end_col}); } // Traverse the bottom row from right to left, if there is more than one row if (ring_rows > 1 and ring_cols > 1) { - for (size_t col = end_col - 1; col >= start_col; --col) { + for (std::size_t col = end_col - 1; col >= start_col; --col) { boundary_coords.emplace_back(Coordinate{end_row, col}); } - for (size_t row = end_row - 1; row >= start_row; --row) { + for (std::size_t row = end_row - 1; row >= start_row; --row) { boundary_coords.emplace_back(Coordinate{row, start_col}); } } @@ -259,4 +261,27 @@ void MeshDeviceView::validate_coordinates() const { } } +std::vector MeshDeviceView::get_line_devices() { + auto boundary_coords = get_line_coordinates(this->num_rows() * this->num_cols(), this->top_left_, this->num_rows(), this->num_cols()); + return get_devices_from_coordinates(*this, boundary_coords); +} + +std::vector MeshDeviceView::get_ring_devices() { + auto boundary_coords = get_ring_coordinates(shape(), this->top_left_, this->num_rows(), this->num_cols()); + return get_devices_from_coordinates(*this, boundary_coords); +} + +MeshDeviceView::DeviceView MeshDeviceView::get_devices(IterationOrder order) { + switch (order) { + case IterationOrder::ROW_MAJOR: + return this->devices_; + case IterationOrder::RING: + return this->get_ring_devices(); + case IterationOrder::LINE: + return this->get_line_devices(); + default: + TT_THROW("Unsupported iteration order: {}", order); + } +} + } // namespace tt::tt_metal diff --git a/tt_metal/impl/device/mesh_device_view.hpp b/tt_metal/impl/device/mesh_device_view.hpp index 46d5d67e9fa..bb49e63d50e 100644 --- a/tt_metal/impl/device/mesh_device_view.hpp +++ b/tt_metal/impl/device/mesh_device_view.hpp @@ -17,7 +17,7 @@ namespace tt::tt_metal { // Forward declaration of MeshDevice class MeshDevice; -using MeshShape = std::pair; +using MeshShape = std::pair; struct Coordinate { std::size_t row; @@ -53,6 +53,13 @@ struct Coordinate { * specific sub-regions. This is particularly useful for collective communication operations * (CCL-ops), such as line all-gather, which require column or row views of the device mesh. */ + +enum class IterationOrder { + ROW_MAJOR, + RING, + LINE +}; + class MeshDeviceView { public: using device_pointer = Device*; @@ -65,28 +72,24 @@ class MeshDeviceView { MeshDeviceView(const MeshDevice& mesh, Coordinate top_left, Coordinate bottom_right); MeshDeviceView(std::vector devices, CoordinateMapper mapper); - [[nodiscard]] device_pointer get_device(size_t row, size_t col); - [[nodiscard]] const_device_pointer get_device(size_t row, size_t col) const; - - [[nodiscard]] const std::vector& get_devices() const; + [[nodiscard]] device_pointer get_device(std::size_t row, std::size_t col); + [[nodiscard]] const_device_pointer get_device(std::size_t row, std::size_t col) const; // Get devices spanning the rectangular region defined by the top-left and bottom-right coordinates // devices are returned in row-major order with start/end coordinates inclusive [[nodiscard]] DeviceView get_devices(const Coordinate& start, const Coordinate& end); [[nodiscard]] DeviceView get_devices(const MeshShape& shape); + [[nodiscard]] DeviceView get_devices(IterationOrder order = IterationOrder::ROW_MAJOR); - [[nodiscard]] DeviceView get_devices_on_row(size_t row) const; - [[nodiscard]] DeviceView get_devices_on_column(size_t col) const; + [[nodiscard]] DeviceView get_devices_on_row(std::size_t row) const; + [[nodiscard]] DeviceView get_devices_on_column(std::size_t col) const; [[nodiscard]] DeviceViews get_row_views() const; [[nodiscard]] DeviceViews get_column_views() const; - template - [[nodiscard]] MeshDeviceView subview(Pred&& predicate) const; - [[nodiscard]] bool empty() const noexcept; - [[nodiscard]] size_t size() const noexcept; - [[nodiscard]] std::pair shape() const noexcept; + [[nodiscard]] std::size_t size() const noexcept; + [[nodiscard]] MeshShape shape() const noexcept; [[nodiscard]] bool contains(const Coordinate& coord) const noexcept; [[nodiscard]] const_device_pointer at(const Coordinate& coord) const noexcept; @@ -99,15 +102,18 @@ class MeshDeviceView { [[nodiscard]] std::size_t num_cols() const { return bottom_right_.col - top_left_.col + 1; } [[nodiscard]] std::size_t num_devices() const { return devices_.size(); } + [[nodiscard]] bool contains_device(chip_id_t device_id) const; [[nodiscard]] Coordinate find_device(chip_id_t device_id) const; [[nodiscard]] chip_id_t find_device_id(const Coordinate& coord) const; // Given a starting coordinate, get the coordinates of a line of devices where device[i-1] is connected to device[i] // The current support only provides left-to-right and right-to-left snaking of the line. - [[nodiscard]] static std::vector get_line_coordinates(size_t length, const Coordinate& offset, size_t num_rows, size_t num_cols); - [[nodiscard]] std::vector get_ring_coordinates(const MeshShape& ring_shape, const Coordinate& offset, size_t num_rows, size_t num_cols); + [[nodiscard]] static std::vector get_line_coordinates(std::size_t length, const Coordinate& offset, std::size_t num_rows, std::size_t num_cols); + [[nodiscard]] std::vector get_ring_coordinates(const MeshShape& ring_shape, const Coordinate& offset, std::size_t num_rows, std::size_t num_cols); private: + [[nodiscard]] std::vector get_ring_devices(); + [[nodiscard]] std::vector get_line_devices(); std::vector devices_; std::unordered_map device_coordinates_; Coordinate top_left_; @@ -124,10 +130,21 @@ inline MeshDeviceView make_mesh_device_view(std::vector devices, MeshDe } // namespace tt::tt_metal -// Specializations to enable structured bindings namespace std { + // Specializations to enable structured bindings template<> struct tuple_size : std::integral_constant {}; template struct tuple_element { using type = std::size_t; }; + + // Specialization to enable hashing of Coordinate + template <> + struct hash { + std::size_t operator()(const tt::tt_metal::Coordinate& coord) const noexcept { + std::size_t seed = 0; + tt::utils::hash_combine(seed, coord.row); + tt::utils::hash_combine(seed, coord.col); + return seed; + } + }; } // namespace std diff --git a/ttnn/cpp/pybind11/multi_device.hpp b/ttnn/cpp/pybind11/multi_device.hpp index 70d9755d040..fb2ec846a2f 100644 --- a/ttnn/cpp/pybind11/multi_device.hpp +++ b/ttnn/cpp/pybind11/multi_device.hpp @@ -23,12 +23,12 @@ void py_module(py::module& module) { py_mesh_device .def( py::init([](const MeshShape& mesh_device_shape, - size_t l1_small_size, - size_t trace_region_size, - size_t num_command_queues, + std::size_t l1_small_size, + std::size_t trace_region_size, + std::size_t num_command_queues, DispatchCoreType dispatch_core_type, - const std::pair& offset, - const std::vector& physical_device_ids) { + const std::pair& offset, + const std::vector& physical_device_ids) { return MeshDevice::create( mesh_device_shape, l1_small_size, @@ -47,14 +47,15 @@ void py_module(py::module& module) { py::arg("offset"), py::arg("physical_device_ids")) .def("get_num_devices", &MeshDevice::num_devices) + .def("get_mesh_id", &MeshDevice::get_mesh_id) .def("get_device_ids", &MeshDevice::get_device_ids) .def( "get_device", - py::overload_cast(&MeshDevice::get_device, py::const_), + py::overload_cast(&MeshDevice::get_device, py::const_), py::return_value_policy::reference) .def( "get_device", - py::overload_cast(&MeshDevice::get_device, py::const_), + py::overload_cast(&MeshDevice::get_device, py::const_), py::return_value_policy::reference) .def("get_devices", &MeshDevice::get_devices, py::return_value_policy::reference, R"doc( Get the devices in the device mesh. @@ -62,26 +63,7 @@ void py_module(py::module& module) { Returns: List[Device]: The devices in the device mesh. )doc") - .def( - "get_devices_on_row", - &MeshDevice::get_devices_on_row, - py::return_value_policy::reference, - R"doc( - Get the devices in a row of the device mesh. - - Returns: - List[Device]: The devices on a row in the device mesh. - )doc") - .def( - "get_devices_on_column", - &MeshDevice::get_devices_on_column, - py::return_value_policy::reference, - R"doc( - Get the devices in a row of the device mesh. - - Returns: - List[Device]: The devices on a row in the device mesh. - )doc") + .def("create_submesh", &MeshDevice::create_submesh, py::arg("submesh_shape"), py::arg("offset") = std::pair{0, 0}, py::return_value_policy::reference_internal, py::keep_alive<0, 1>()) .def( "compute_with_storage_grid_size", &MeshDevice::compute_with_storage_grid_size,