diff --git a/models/demos/t3000/llama2_70b/tt/llama_mlp_optimized.py b/models/demos/t3000/llama2_70b/tt/llama_mlp_optimized.py index 1ac4f56a41a8..4d30605e63c7 100644 --- a/models/demos/t3000/llama2_70b/tt/llama_mlp_optimized.py +++ b/models/demos/t3000/llama2_70b/tt/llama_mlp_optimized.py @@ -72,12 +72,11 @@ def load_weights(self): padded_w3[:, :, :, :H4] = self.state_dict[w3_str].transpose(-2, -1) # w1: 8k x 4k. width-sharded on 12 banks, 4224 over 12 banks. - device = self.mesh_device.get_device(0) weight_grid = ttnn.CoreRangeSet( { ttnn.CoreRange( ttnn.CoreCoord(0, 0), - ttnn.CoreCoord(device.dram_grid_size().x - 1, device.dram_grid_size().y - 1), + ttnn.CoreCoord(self.mesh_device.dram_grid_size().x - 1, self.mesh_device.dram_grid_size().y - 1), ) } ) diff --git a/tests/scripts/tg/run_tg_model_perf_tests.sh b/tests/scripts/tg/run_tg_model_perf_tests.sh index 7cd43da8c897..76d85050dcc5 100755 --- a/tests/scripts/tg/run_tg_model_perf_tests.sh +++ b/tests/scripts/tg/run_tg_model_perf_tests.sh @@ -1,6 +1,10 @@ #!/bin/bash -run_tg_llm_tests() { +run_t3k_tests_on_tg_tests() { + + echo "LOG_METAL: Running T3000 tests on TG" + env pytest -n auto models/demos/t3000/llama2_70b/tests/test_llama_perf_decode.py -m "model_perf_t3000" --timeout=600 ; fail+=$? + # Merge all the generated reports env python models/perf/merge_perf_results.py; fail+=$? diff --git a/tests/ttnn/multichip_unit_tests/test_multidevice_TG.py b/tests/ttnn/multichip_unit_tests/test_multidevice_TG.py index 3d927636f91a..4ca13900f54d 100644 --- a/tests/ttnn/multichip_unit_tests/test_multidevice_TG.py +++ b/tests/ttnn/multichip_unit_tests/test_multidevice_TG.py @@ -1573,3 +1573,19 @@ def test_sharded_distributed_layernorm(mesh_device, input_width, input_height, c is_pass, output_pcc = comp_pcc(torch_output_tensor, tt_output_tensor, pcc=0.999) assert is_pass, f"PCC value: {output_pcc}" + + +def test_ttnn_multi_device_all_gather_all_devices(t3k_mesh_device): + """Example test for running a 2x4-Ring All-Gather on galaxy""" + full_tensor = torch.ones((1, 1, 32, 32 * t3k_mesh_device.get_num_devices()), dtype=torch.bfloat16) + for i in range(t3k_mesh_device.get_num_devices()): + full_tensor[..., i * 32 : (i + 1) * 32] = i + + ttnn_tensor = ttnn.from_torch(full_tensor, mesh_mapper=ShardTensorToMesh(t3k_mesh_device, dim=3)) + ttnn_tensor = ttnn.to_device(ttnn_tensor, t3k_mesh_device) + ttnn_tensor = ttnn.all_gather(ttnn_tensor, dim=3, num_links=1) + + device_tensors: typing.List[ttnn.Tensor] = ttnn.get_device_tensors(ttnn_tensor) + for device_tensor in device_tensors: + device_tensor_torch = ttnn.to_torch(device_tensor) + assert torch.all(device_tensor_torch == full_tensor) diff --git a/tests/ttnn/unit_tests/gtests/test_ccl_on_galaxy.cpp b/tests/ttnn/unit_tests/gtests/test_ccl_on_galaxy.cpp index 6d7ed90ee8b9..2d2e99504677 100644 --- a/tests/ttnn/unit_tests/gtests/test_ccl_on_galaxy.cpp +++ b/tests/ttnn/unit_tests/gtests/test_ccl_on_galaxy.cpp @@ -130,8 +130,9 @@ TEST(GalaxyTests, TestAllGatherDeadlock) { } // Iterate over each row and run line all-gather multiple times. // For each row, send adversarial traffic to the first chip, that can hang the network if the CCL is not tagged. + auto view = MeshDeviceView(*mesh); for (uint32_t row = 0; row < 8; row++) { - auto devs = mesh->get_devices_on_row(row); + auto devs = view.get_devices_on_row(row); std::vector device_ids = {}; for (auto dev : devs) { device_ids.push_back(dev->id()); @@ -189,13 +190,14 @@ TEST(GalaxyTests, TestReduceScatterDeadlock) { std::shared_ptr mesh = ttnn::multi_device::open_mesh_device(mesh_shape, 0, 0, 1, DispatchCoreType::WORKER); // Create the outer ring on which Reduce Scatter will be run. This allows us to verify that there are no deadlocks when we send CCLs to the // first tunnel (forward path). - std::vector ring_devices = mesh->get_devices_on_row(0); // Tunnel 0 - std::vector ring_devices_1 = mesh->get_devices_on_column(mesh_shape.second - 1); // Orthogonal to tunnel .. no deadlocks + auto view = MeshDeviceView(*mesh); + std::vector ring_devices = view.get_devices_on_row(0); // Tunnel 0 + std::vector ring_devices_1 = view.get_devices_on_column(mesh_shape.second - 1); // Orthogonal to tunnel .. no deadlocks ring_devices_1 = std::vector(ring_devices_1.begin() + 1, ring_devices_1.end()); - std::vector ring_devices_2 = mesh->get_devices_on_row(7); // Tunnel 7 .. potential deadlocks with lack of buffering + std::vector ring_devices_2 = view.get_devices_on_row(7); // Tunnel 7 .. potential deadlocks with lack of buffering std::reverse(ring_devices_2.begin(), ring_devices_2.end()); ring_devices_2 = std::vector(ring_devices_2.begin() + 1, ring_devices_2.end()); - std::vector ring_devices_3 = mesh->get_devices_on_column(0); // Orthogonal to tunnel .. no deadlocks + std::vector ring_devices_3 = view.get_devices_on_column(0); // Orthogonal to tunnel .. no deadlocks std::reverse(ring_devices_3.begin(), ring_devices_3.end()); ring_devices_3 = std::vector(ring_devices_3.begin() + 1, ring_devices_3.end() - 1); diff --git a/tests/ttnn/unit_tests/test_multi_device.py b/tests/ttnn/unit_tests/test_multi_device.py index f1c2728857f0..982e539689df 100644 --- a/tests/ttnn/unit_tests/test_multi_device.py +++ b/tests/ttnn/unit_tests/test_multi_device.py @@ -587,3 +587,7 @@ def test_validate_as_tensor(tmp_path, mesh_device, height, width): for device in mesh_device.get_devices(): device_tensor = ttnn.get_device_tensor(tensor, device) assert torch.allclose(ttnn.to_torch(device_tensor), torch_input_tensor) + + +def test_ttnn_visualize_mesh_device(t3k_mesh_device): + ttnn.visualize_mesh_device(t3k_mesh_device) diff --git a/tt_metal/impl/device/mesh_configurations/T3000.json b/tt_metal/impl/device/mesh_configurations/T3000.json index 2c62209d01fc..acfe3edac004 100644 --- a/tt_metal/impl/device/mesh_configurations/T3000.json +++ b/tt_metal/impl/device/mesh_configurations/T3000.json @@ -1,6 +1,6 @@ { "logical_to_physical_coordinates": [ [[0, 0], [0, 0, 0, 0]], [[0, 1], [0, 1, 0, 0]], [[0, 2], [0, 2, 0, 0]], [[0, 3], [0, 3, 0, 0]], - [[1, 0], [1, 3, 0, 0]], [[1, 1], [1, 2, 0, 0]], [[1, 2], [1, 1, 0, 0]], [[1, 3], [1, 0, 0, 0]] + [[1, 0], [1, 0, 0, 0]], [[1, 1], [1, 1, 0, 0]], [[1, 2], [1, 2, 0, 0]], [[1, 3], [1, 3, 0, 0]] ] } diff --git a/tt_metal/impl/device/mesh_device.cpp b/tt_metal/impl/device/mesh_device.cpp index e90d4a8925e2..7c3f1a461078 100644 --- a/tt_metal/impl/device/mesh_device.cpp +++ b/tt_metal/impl/device/mesh_device.cpp @@ -127,6 +127,14 @@ std::vector SystemMesh::get_mapped_physical_device_ids(const MeshDevi } return physical_device_ids; } +void SystemMesh::register_mesh_device(const std::shared_ptr &mesh_device, const std::vector& devices) { + std::vector physical_device_ids; + for (auto device : devices) { + physical_device_ids.push_back(device->id()); + } + this->assigned_mesh_device_devices.insert({mesh_device->get_mesh_id(), mesh_device}); + this->assigned_devices.insert({mesh_device->get_mesh_id(), physical_device_ids}); +} std::vector SystemMesh::map_mesh_device( std::shared_ptr mesh_device, @@ -145,7 +153,6 @@ std::vector SystemMesh::map_mesh_device( TT_FATAL(requested_num_rows <= max_num_rows, "Requested too many rows: {} > {}", requested_num_rows, max_num_rows); TT_FATAL(requested_num_rows*requested_num_cols <= max_num_rows*max_num_cols, "Requested submesh is too big: {}x{}", requested_num_rows, requested_num_cols); - this->assigned_mesh_device_devices.insert({mesh_device->get_mesh_id(), mesh_device}); auto physical_device_ids = user_provided_physical_device_ids.empty() ? this->get_mapped_physical_device_ids(MeshDeviceConfig{mesh_device->shape(), offset}) : @@ -158,27 +165,34 @@ std::vector SystemMesh::map_mesh_device( for (auto physical_device_id : physical_device_ids) { auto mapped_device = this->opened_devices[mesh_device->get_mesh_id()].at(physical_device_id); mapped_devices.push_back(mapped_device); - this->assigned_devices[mesh_device->get_mesh_id()].push_back(physical_device_id); this->assigned_physical_id_to_device.insert({physical_device_id, mapped_device}); } + + this->register_mesh_device(mesh_device, mapped_devices); // TODO: change this return mapped_devices; } void SystemMesh::unmap_mesh_device(const std::shared_ptr& mesh_device) { auto mesh_id = mesh_device->get_mesh_id(); - - // Clean up all state related to this virtual mesh this->assigned_mesh_device_devices.erase(mesh_id); - // Remove the devices from assigned_physical_id_to_device - for (auto physical_id : this->assigned_devices.at(mesh_id)) { - this->assigned_physical_id_to_device.erase(physical_id); + // Close the devices + if (mesh_device->is_parent_mesh()) { + for (auto physical_id : this->assigned_devices.at(mesh_id)) { + this->assigned_physical_id_to_device.erase(physical_id); + } + tt::tt_metal::detail::CloseDevices(this->opened_devices.at(mesh_id)); + this->opened_devices.erase(mesh_id); } this->assigned_devices.erase(mesh_id); +} - // Close the devices - tt::tt_metal::detail::CloseDevices(this->opened_devices.at(mesh_id)); - this->opened_devices.erase(mesh_id); +Device* SystemMesh::get_device(const chip_id_t physical_device_id) { + auto it = this->assigned_physical_id_to_device.find(physical_device_id); + if (it == this->assigned_physical_id_to_device.end()) { + TT_THROW("Physical Device ID: {} not found in assigned devices", physical_device_id); + } + return it->second; } static MeshDeviceID generate_unique_mesh_id() { @@ -186,7 +200,8 @@ static MeshDeviceID generate_unique_mesh_id() { return next_id++; } -MeshDevice::MeshDevice(const MeshShape& mesh_device_shape) : mesh_device_shape(mesh_device_shape), mesh_id(generate_unique_mesh_id()) {} +MeshDevice::MeshDevice(const MeshShape& mesh_device_shape, std::shared_ptr parent_mesh) + : mesh_device_shape(mesh_device_shape), mesh_id(generate_unique_mesh_id()), parent_mesh(parent_mesh) {} std::shared_ptr MeshDevice::create( const MeshShape& mesh_device_shape, @@ -203,6 +218,36 @@ std::shared_ptr MeshDevice::create( return mesh_device; } +std::shared_ptr MeshDevice::create_submesh(const MeshShape &submesh_shape, const MeshOffset &offset) { + if (submesh_shape.first <= 0 || submesh_shape.second <= 0) { + TT_THROW("Invalid submesh shape: ({}, {}). Both dimensions must be positive.", submesh_shape.first, submesh_shape.second); + } + + if (offset.first < 0 || offset.second < 0) { + TT_THROW("Invalid offset: ({}, {}). Offset must be non-negative.", offset.first, offset.second); + } + + if (offset.first + submesh_shape.first > this->mesh_device_shape.first || + offset.second + submesh_shape.second > this->mesh_device_shape.second) { + TT_THROW("Submesh ({}x{}) with offset ({}, {}) does not fit within parent mesh ({}x{}).", + submesh_shape.first, submesh_shape.second, + offset.first, offset.second, + this->mesh_device_shape.first, this->mesh_device_shape.second); + } + + auto submesh = std::make_shared(submesh_shape, shared_from_this()); + auto start_coordinate = Coordinate{offset.first, offset.second}; + auto end_coordinate = Coordinate{offset.first + submesh_shape.first - 1, offset.second + submesh_shape.second - 1}; + submesh->primary_view = std::make_unique(*this, start_coordinate, end_coordinate); + submesh->devices = submesh->primary_view->get_devices(); + SystemMesh::instance().register_mesh_device(submesh, submesh->devices); + this->submeshes.push_back(submesh); + log_trace(LogMetal, "Instantiating submesh {}: {}x{} with offset: {} {}", submesh->get_mesh_id(), submesh_shape.first, submesh_shape.second, offset.first, offset.second); + log_trace(LogMetal, "Submesh {} instantiated with {} devices", submesh->get_mesh_id(), submesh->devices); + + return submesh; +} + void MeshDevice::initialize( size_t l1_small_size, size_t trace_region_size, @@ -223,16 +268,18 @@ void MeshDevice::initialize( this->devices = instance.map_mesh_device( shared_from_this(), num_command_queues, l1_small_size, trace_region_size, dispatch_core_type, offset, physical_device_ids); this->primary_view = std::make_unique(*this); - - for (int device_index = 0; device_index < this->devices.size(); device_index++) { - this->physical_id_to_device_index.insert({this->devices[device_index]->id(), device_index}); - } } MeshDevice::~MeshDevice() { if (not this->devices.empty()) { this->close_devices(); } + for (auto submesh : this->submeshes) { + submesh->close_devices(); + } + this->primary_view.reset(); + this->devices.clear(); + this->parent_mesh.reset(); } Device* MeshDevice::get_device_index(int logical_device_id) const { @@ -241,7 +288,7 @@ Device* MeshDevice::get_device_index(int logical_device_id) const { } Device* MeshDevice::get_device(int physical_device_id) const { - return this->devices.at(this->physical_id_to_device_index.at(physical_device_id)); + return SystemMesh::instance().get_device(physical_device_id); } std::vector MeshDevice::get_devices() const { return this->devices; } @@ -250,14 +297,6 @@ Device* MeshDevice::get_device(int row_idx, int col_idx) const { return this->get_device_index(row_idx * num_cols() + col_idx); } -std::vector MeshDevice::get_devices_on_row(int row_idx) const { - return this->primary_view->get_devices_on_row(row_idx); -} - -std::vector MeshDevice::get_devices_on_column(int col_idx) const { - return this->primary_view->get_devices_on_column(col_idx); -} - const DeviceIds MeshDevice::get_device_ids() const { DeviceIds device_ids; for (auto device : this->get_devices()) { @@ -283,7 +322,6 @@ MeshShape MeshDevice::shape() const { return this->mesh_device_shape; } void MeshDevice::close_devices() { SystemMesh::instance().unmap_mesh_device(shared_from_this()); this->devices.clear(); - this->physical_id_to_device_index.clear(); this->primary_view.reset(); } @@ -295,8 +333,60 @@ std::shared_ptr MeshDevice::get_view() const { return this std::shared_ptr MeshDevice::get_view() { return this->primary_view; } +std::vector> MeshDevice::get_submesh_views() { + std::vector> submesh_views; + if (this->submeshes.empty()) { + submesh_views.push_back(this->get_view()); + } + else { + for (auto submesh : this->submeshes) { + submesh_views.push_back(submesh->get_view()); + } + } + return submesh_views; +} + MeshDeviceID MeshDevice::get_mesh_id() const { return this->mesh_id; } +bool MeshDevice::is_parent_mesh() const { return this->parent_mesh == nullptr; } + +std::shared_ptr SystemMesh::get_mesh_device(const std::vector& physical_device_ids) { + log_trace(LogMetal, "Getting mesh device for {} physical devices: {}", physical_device_ids.size(), physical_device_ids); + std::unordered_set input_set(physical_device_ids.begin(), physical_device_ids.end()); + + for (const auto& [mesh_id, mesh_device] : this->assigned_mesh_device_devices) { + const auto& assigned_devices = this->assigned_devices.at(mesh_id); + std::unordered_set assigned_set(assigned_devices.begin(), assigned_devices.end()); + log_trace(LogMetal, "Assigned devices: {}", assigned_devices); + + if (input_set == assigned_set) { + return mesh_device; + } + } + TT_THROW("No mesh device found for the provided devices"); +} + +std::shared_ptr MeshDevice::fetch_mesh_device(const std::vector& devices) { + TT_FATAL(devices.size() > 0, "No devices provided"); + auto& instance = SystemMesh::instance(); + std::vector physical_device_ids; + for (auto device : devices) { + physical_device_ids.push_back(device->id()); + } + return instance.get_mesh_device(physical_device_ids); +} + +std::vector> MeshDevice::get_submeshes() const { return this->submeshes; } + +std::shared_ptr MeshDevice::get_view(const Device* device) { + for (auto submesh_view : this->get_submesh_views()) { + if (submesh_view->contains_device(device->id())) { + return submesh_view; + } + } + TT_THROW("Device {} not found in any submesh view", device->id()); +} + std::ostream& operator<<(std::ostream& os, const MeshDevice& mesh_device) { return os << mesh_device.to_string(); } bool validate_worker_modes(const std::vector& workers) { diff --git a/tt_metal/impl/device/mesh_device.hpp b/tt_metal/impl/device/mesh_device.hpp index 940110973cce..8a5cccc041aa 100644 --- a/tt_metal/impl/device/mesh_device.hpp +++ b/tt_metal/impl/device/mesh_device.hpp @@ -72,6 +72,7 @@ class SystemMesh { // Get the physical device IDs mapped to a MeshDevice std::vector get_mapped_physical_device_ids(const MeshDeviceConfig &config) const; + void register_mesh_device(const std::shared_ptr &mesh_device, const std::vector& devices); // Map MeshDevice to physical devices std::vector map_mesh_device( @@ -85,14 +86,18 @@ class SystemMesh { // Unmap MeshDevice, releasing the associated physical devices. void unmap_mesh_device(const std::shared_ptr &mesh_device); + std::shared_ptr get_mesh_device(const std::vector& physical_device_ids); + Device* get_device(const chip_id_t physical_device_id); }; class MeshDevice : public std::enable_shared_from_this { + private: MeshDeviceID mesh_id; MeshShape mesh_device_shape; std::shared_ptr primary_view; std::vector devices; - std::unordered_map physical_id_to_device_index; + std::shared_ptr parent_mesh; + std::vector> submeshes; void initialize( size_t l1_small_size, @@ -103,7 +108,7 @@ class MeshDevice : public std::enable_shared_from_this { const std::vector &physical_device_ids); public: - MeshDevice(const MeshShape &mesh_device_shape); + MeshDevice(const MeshShape &mesh_device_shape, std::shared_ptr parent_mesh = nullptr); ~MeshDevice(); MeshDevice(const MeshDevice &) = delete; @@ -116,8 +121,6 @@ class MeshDevice : public std::enable_shared_from_this { Device *get_device_index(int logical_device_id) const; Device *get_device(int physical_device_id) const; Device *get_device(int row_idx, int col_idx) const; - std::vector get_devices_on_row(int row_idx) const; - std::vector get_devices_on_column(int col_idx) const; const DeviceIds get_device_ids() const; @@ -138,6 +141,7 @@ class MeshDevice : public std::enable_shared_from_this { std::string to_string() const; MeshDeviceID get_mesh_id() const; + bool is_parent_mesh() const; static std::shared_ptr create( const MeshShape &mesh_device_shape, @@ -147,6 +151,13 @@ class MeshDevice : public std::enable_shared_from_this { DispatchCoreType dispatch_core_type, const std::pair &offset = {0, 0}, const std::vector &physical_device_ids = {}); + + std::vector> get_submeshes() const; + std::vector> get_submesh_views(); + std::shared_ptr get_view(const Device* device); + + std::shared_ptr create_submesh(const MeshShape &submesh_shape, const MeshOffset &offset = {0, 0}); + static std::shared_ptr fetch_mesh_device(const std::vector& devices); }; std::ostream &operator<<(std::ostream &os, const MeshDevice &mesh_device); diff --git a/tt_metal/impl/device/mesh_device_view.cpp b/tt_metal/impl/device/mesh_device_view.cpp index cc4a227780f6..4a9ea95286ef 100644 --- a/tt_metal/impl/device/mesh_device_view.cpp +++ b/tt_metal/impl/device/mesh_device_view.cpp @@ -24,12 +24,12 @@ MeshDeviceView::MeshDeviceView(const MeshDevice& mesh) } MeshDeviceView::MeshDeviceView(const MeshDevice& mesh, Coordinate top_left, Coordinate bottom_right) - : top_left_(top_left), bottom_right_(bottom_right) { + : top_left_(0, 0), bottom_right_(Coordinate{bottom_right.row - top_left.row, bottom_right.col - top_left.col}) { for (size_t row = top_left.row; row <= bottom_right.row; ++row) { for (size_t col = top_left.col; col <= bottom_right.col; ++col) { if (auto device = mesh.get_device(row, col)) { devices_.push_back(device); - device_coordinates_[(device)->id()] = {row, col}; + device_coordinates_[(device)->id()] = {row - top_left.row, col - top_left.col}; } } } @@ -158,6 +158,11 @@ bool MeshDeviceView::operator==(const MeshDeviceView& other) const { bottom_right_ == other.bottom_right_; } + +bool MeshDeviceView::contains_device(chip_id_t device_id) const { + return device_coordinates_.find(device_id) != device_coordinates_.end(); +} + Coordinate MeshDeviceView::find_device(chip_id_t device_id) const { auto it = device_coordinates_.find(device_id); if (it != device_coordinates_.end()) { @@ -199,5 +204,57 @@ void MeshDeviceView::validate_coordinates() const { throw std::invalid_argument("Invalid coordinates: top_left must be less than or equal to bottom_right"); } } +// Get the boundary coordinates of the subgrid defined by offset and shape +std::vector MeshDeviceView::get_ring_coordinates(const MeshShape& shape, const Coordinate& offset) { + std::vector boundary_coords; + + size_t start_row = offset.row; + size_t start_col = offset.col; + size_t end_row = offset.row + shape.first - 1; + size_t end_col = offset.col + shape.second - 1; + + // Validate the specified subgrid + if (start_row >= num_rows() || start_col >= num_cols() || + end_row >= num_rows() || end_col >= num_cols()) { + throw std::invalid_argument("Subgrid is out of mesh bounds."); + } + + // Traverse the top row from left to right + for (size_t col = start_col; col <= end_col; ++col) { + boundary_coords.emplace_back(Coordinate{start_row, col}); + } + + // Traverse the rightmost column from top+1 to bottom + for (size_t row = start_row + 1; row <= end_row; ++row) { + boundary_coords.emplace_back(Coordinate{row, end_col}); + } + + // Traverse the bottom row from right to left, if there is more than one row + if (end_row > start_row and end_col > start_col) { + for (size_t col = end_col - 1; col + 1 > start_col; --col) { + boundary_coords.emplace_back(Coordinate{end_row, col}); + } + for (size_t row = end_row - 1; row > start_row; --row) { + boundary_coords.emplace_back(Coordinate{row, start_col}); + } + } + + return boundary_coords; +} + +std::vector MeshDeviceView::get_ring_devices() { + return get_ring_devices(shape(), this->top_left_); +} + +std::vector MeshDeviceView::get_ring_devices(const MeshShape& shape, const Coordinate& offset) { + auto boundary_coords = get_ring_coordinates(shape, offset); + std::vector ring_devices; + for (const auto& coord : boundary_coords) { + if (auto device = this->get_device(coord.row, coord.col)) { + ring_devices.push_back(device); + } + } + return ring_devices; +} } // namespace tt::tt_metal diff --git a/tt_metal/impl/device/mesh_device_view.hpp b/tt_metal/impl/device/mesh_device_view.hpp index 73c9e2b61c20..4510c98d1edd 100644 --- a/tt_metal/impl/device/mesh_device_view.hpp +++ b/tt_metal/impl/device/mesh_device_view.hpp @@ -86,7 +86,7 @@ class MeshDeviceView { [[nodiscard]] bool empty() const noexcept; [[nodiscard]] size_t size() const noexcept; - [[nodiscard]] std::pair shape() const noexcept; + [[nodiscard]] MeshShape shape() const noexcept; [[nodiscard]] bool contains(const Coordinate& coord) const noexcept; [[nodiscard]] const_device_pointer at(const Coordinate& coord) const noexcept; @@ -99,10 +99,15 @@ class MeshDeviceView { [[nodiscard]] std::size_t num_cols() const { return bottom_right_.col - top_left_.col + 1; } [[nodiscard]] std::size_t num_devices() const { return devices_.size(); } + [[nodiscard]] bool contains_device(chip_id_t device_id) const; [[nodiscard]] Coordinate find_device(chip_id_t device_id) const; [[nodiscard]] chip_id_t find_device_id(const Coordinate& coord) const; + [[nodiscard]] std::vector get_ring_devices(); + [[nodiscard]] std::vector get_ring_devices(const MeshShape& shape, const Coordinate& offset); private: + std::vector get_ring_coordinates(const MeshShape& shape, const Coordinate& offset); + std::vector devices_; std::unordered_map device_coordinates_; Coordinate top_left_; diff --git a/ttnn/cpp/pybind11/multi_device.hpp b/ttnn/cpp/pybind11/multi_device.hpp index 70d9755d0400..f468217ebe7b 100644 --- a/ttnn/cpp/pybind11/multi_device.hpp +++ b/ttnn/cpp/pybind11/multi_device.hpp @@ -47,6 +47,7 @@ void py_module(py::module& module) { py::arg("offset"), py::arg("physical_device_ids")) .def("get_num_devices", &MeshDevice::num_devices) + .def("get_mesh_id", &MeshDevice::get_mesh_id) .def("get_device_ids", &MeshDevice::get_device_ids) .def( "get_device", @@ -62,26 +63,7 @@ void py_module(py::module& module) { Returns: List[Device]: The devices in the device mesh. )doc") - .def( - "get_devices_on_row", - &MeshDevice::get_devices_on_row, - py::return_value_policy::reference, - R"doc( - Get the devices in a row of the device mesh. - - Returns: - List[Device]: The devices on a row in the device mesh. - )doc") - .def( - "get_devices_on_column", - &MeshDevice::get_devices_on_column, - py::return_value_policy::reference, - R"doc( - Get the devices in a row of the device mesh. - - Returns: - List[Device]: The devices on a row in the device mesh. - )doc") + .def("create_submesh", &MeshDevice::create_submesh, py::arg("submesh_shape"), py::arg("offset") = std::pair{0, 0}, py::return_value_policy::reference_internal, py::keep_alive<0, 1>()) .def( "compute_with_storage_grid_size", &MeshDevice::compute_with_storage_grid_size, diff --git a/ttnn/cpp/pybind11/operations/core.hpp b/ttnn/cpp/pybind11/operations/core.hpp index 3825201f266b..71c300b12fb1 100644 --- a/ttnn/cpp/pybind11/operations/core.hpp +++ b/ttnn/cpp/pybind11/operations/core.hpp @@ -284,7 +284,8 @@ void py_module(py::module& module) { py::arg("trace_id"), py::kw_only(), py::arg("cq_id") = ttnn::DefaultQueueId, - py::arg("blocking") = true); + py::arg("blocking") = true, + py::call_guard()); module.def( "release_trace", diff --git a/ttnn/cpp/ttnn/operations/ccl/all_gather/device/all_gather_op.cpp b/ttnn/cpp/ttnn/operations/ccl/all_gather/device/all_gather_op.cpp index 66f70cb1ef1c..dda3a09ea01c 100644 --- a/ttnn/cpp/ttnn/operations/ccl/all_gather/device/all_gather_op.cpp +++ b/ttnn/cpp/ttnn/operations/ccl/all_gather/device/all_gather_op.cpp @@ -6,6 +6,7 @@ #include "ttnn/deprecated/tt_dnn/op_library/math.hpp" #include "tt_metal/host_api.hpp" +#include "tt_metal/impl/device/mesh_device.hpp" #include "ttnn/tensor/tensor_utils.hpp" @@ -191,17 +192,18 @@ Tensor all_gather( if (num_devices == 2){ ccl_topology = ttnn::ccl::Topology::Linear; } + auto mesh_device = MeshDevice::fetch_mesh_device(devices); std::vector output_tensors = {Tensor(operation::get_workers_for_op_output({input_tensor}))}; operation::launch_op( - [dim, num_links, memory_config, user_defined_num_workers, user_defined_num_buffers_per_channel, devices, ccl_topology]( + [=]( const std::vector& input_tensors, const std::vector>& optional_input_tensors, const std::vector>& optional_output_tensors) mutable -> std::vector { const auto& input_tensor = input_tensors.at(0); - + auto submesh_view = mesh_device->get_view(input_tensor.device()); return operation::run( - create_all_gather_struct(input_tensor, dim, num_links, memory_config, user_defined_num_workers, user_defined_num_buffers_per_channel, devices, ccl_topology), + create_all_gather_struct(input_tensor, dim, num_links, memory_config, user_defined_num_workers, user_defined_num_buffers_per_channel, submesh_view->get_ring_devices(), ccl_topology), {input_tensor}); }, {input_tensor}, diff --git a/ttnn/cpp/ttnn/tensor/tensor.cpp b/ttnn/cpp/ttnn/tensor/tensor.cpp index d72b0bf50f8c..896fdb7d3a0f 100644 --- a/ttnn/cpp/ttnn/tensor/tensor.cpp +++ b/ttnn/cpp/ttnn/tensor/tensor.cpp @@ -791,14 +791,16 @@ std::vector distribute_tensor_to_mesh(const Tensor& tensor, MeshDevice& return workers; }; - if (mesh_device.get_view() != nullptr and std::holds_alternative(tensor.get_storage())) { + auto mesh_view = mesh_device.get_view(); + if (mesh_view != nullptr and std::holds_alternative(tensor.get_storage())) { const auto& host_storage = std::get(tensor.get_storage()); return std::visit([&](const auto& strategy) { using StrategyType = std::decay_t; if constexpr (std::is_same_v) { - auto mesh_view = mesh_device.get_view(); return mesh_view->get_devices(strategy.shard_mesh); + } else if constexpr (std::is_same_v) { + return mesh_view->get_ring_devices(); } else { return get_multi_device_workers(mesh_device.get_devices()); }