diff --git a/tests/tt_metal/tt_metal/perf_microbenchmark/ethernet/test_ethernet_hop_latencies_no_edm.cpp b/tests/tt_metal/tt_metal/perf_microbenchmark/ethernet/test_ethernet_hop_latencies_no_edm.cpp index a03799f896e4..aad3f3ca8947 100644 --- a/tests/tt_metal/tt_metal/perf_microbenchmark/ethernet/test_ethernet_hop_latencies_no_edm.cpp +++ b/tests/tt_metal/tt_metal/perf_microbenchmark/ethernet/test_ethernet_hop_latencies_no_edm.cpp @@ -8,6 +8,7 @@ #include #include +#include "impl/device/mesh_device_view.hpp" #include "tt_metal/common/logger.hpp" #include "device/tt_arch_types.h" #include "impl/device/device.hpp" @@ -26,6 +27,7 @@ #include "tt_metal/test_utils/stimulus.hpp" #include "tt_metal/detail/persistent_kernel_cache.hpp" +#include "tt_metal/impl/device/mesh_device.hpp" using tt::tt_metal::Device; @@ -41,8 +43,7 @@ class T3000TestDevice { num_devices_ = tt::tt_metal::GetNumAvailableDevices(); if (arch_ == tt::ARCH::WORMHOLE_B0 and tt::tt_metal::GetNumAvailableDevices() == 8 and tt::tt_metal::GetNumPCIeDevices() == 4) { - devices_ = tt::tt_metal::detail::CreateDevices({0,1,2,3,4,5,6,7}); - tt::Cluster::instance().set_internal_routing_info_for_ethernet_cores(true); + mesh_device_ = tt::tt_metal::MeshDevice::create(tt::tt_metal::MeshDeviceConfig(tt::tt_metal::MeshShape{2, 4})); } else { TT_THROW("This suite can only be run on T3000 Wormhole devices"); @@ -57,15 +58,12 @@ class T3000TestDevice { void TearDown() { device_open = false; - tt::Cluster::instance().set_internal_routing_info_for_ethernet_cores(false); - for (auto [device_id, device_ptr] : devices_) { - tt::tt_metal::CloseDevice(device_ptr); - } + mesh_device_->close_devices(); } - std::map devices_; tt::ARCH arch_; size_t num_devices_; + std::shared_ptr mesh_device_; private: bool device_open; @@ -420,23 +418,51 @@ int main (int argc, char** argv) { TT_ASSERT(std::all_of(max_concurrent_samples.begin(), max_concurrent_samples.end(), [](std::size_t n) { return n > 0; })); T3000TestDevice test_fixture; + auto view = test_fixture.mesh_device_->get_view(); - // Device setup - std::vector device_ids = std::vector{0, 1, 2, 3, 4, 5, 6, 7}; - - auto get_device_list = [](std::map &all_devices, std::size_t n_hops) { + auto get_device_list = [](const std::shared_ptr& view, std::size_t n_hops) { switch (n_hops) { case 2: - return std::vector{all_devices[0], all_devices[1]}; + return std::vector{ + view->get_device(0, 0), + view->get_device(0, 1), + }; case 4: - return std::vector{all_devices[0], all_devices[1], all_devices[2], all_devices[3]}; + return std::vector{ + view->get_device(1, 1), + view->get_device(0, 1), + view->get_device(0, 2), + view->get_device(1, 2), + }; case 8: - return std::vector{all_devices[0], all_devices[4], all_devices[5], all_devices[1], all_devices[2], all_devices[6], all_devices[7], all_devices[3]}; + return std::vector{ + view->get_device(1, 1), + view->get_device(1, 0), + view->get_device(0, 0), + view->get_device(0, 1), + view->get_device(0, 2), + view->get_device(0, 3), + view->get_device(1, 3), + view->get_device(1, 2), + }; case 12: // Does an extra loop through the inner ring - return std::vector{all_devices[0], all_devices[4], all_devices[5], all_devices[1], all_devices[2], all_devices[3], all_devices[0], all_devices[1], all_devices[2], all_devices[6], all_devices[7], all_devices[3]}; + return std::vector{ + view->get_device(1, 1), + view->get_device(1, 0), + view->get_device(0, 0), + view->get_device(0, 1), + view->get_device(0, 2), + view->get_device(1, 2), + view->get_device(1, 1), + view->get_device(0, 1), + view->get_device(0, 2), + view->get_device(0, 3), + view->get_device(1, 3), + view->get_device(1, 2), + }; default: TT_THROW("Unsupported hop_count"); @@ -448,7 +474,7 @@ int main (int argc, char** argv) { constexpr std::size_t placeholder_arg_value = 1; for (auto n_hops : hop_counts) { - auto devices = get_device_list(test_fixture.devices_, n_hops); + auto devices = get_device_list(view, n_hops); std::vector hop_eth_sockets = build_eth_sockets_list(devices); for (auto max_concurrent_samples : max_concurrent_samples) { diff --git a/tests/ttnn/unit_tests/gtests/ttnn_test_fixtures.hpp b/tests/ttnn/unit_tests/gtests/ttnn_test_fixtures.hpp index 2b0a8fc04a1e..c8506b79cd95 100644 --- a/tests/ttnn/unit_tests/gtests/ttnn_test_fixtures.hpp +++ b/tests/ttnn/unit_tests/gtests/ttnn_test_fixtures.hpp @@ -67,12 +67,7 @@ class T3kMultiDeviceFixture : public ::testing::Test { if (num_devices < 8 or arch != tt::ARCH::WORMHOLE_B0) { GTEST_SKIP() << "Skipping T3K Multi-Device test suite on non T3K machine."; } - constexpr auto DEFAULT_NUM_COMMAND_QUEUES = 1; mesh_device_ = MeshDevice::create( - DEFAULT_L1_SMALL_SIZE, - DEFAULT_TRACE_REGION_SIZE, - DEFAULT_NUM_COMMAND_QUEUES, - DispatchCoreType::WORKER, MeshDeviceConfig(MeshShape{2, 4}, MeshType::Ring)); } diff --git a/tt_metal/impl/device/mesh_device.cpp b/tt_metal/impl/device/mesh_device.cpp index 0c8d169056c3..e1cc8228c0b0 100644 --- a/tt_metal/impl/device/mesh_device.cpp +++ b/tt_metal/impl/device/mesh_device.cpp @@ -216,11 +216,11 @@ MeshDevice::MeshDevice(const MeshShape& mesh_device_shape, MeshType type, std::w : mesh_device_shape(mesh_device_shape), type(type), mesh_id(generate_unique_mesh_id()), parent_mesh(parent_mesh) {} std::shared_ptr MeshDevice::create( + const MeshDeviceConfig& config, size_t l1_small_size, size_t trace_region_size, size_t num_command_queues, - DispatchCoreType dispatch_core_type, - const MeshDeviceConfig& config) + DispatchCoreType dispatch_core_type) { auto mesh_device = std::make_shared(config.mesh_shape, config.mesh_type); mesh_device->initialize(l1_small_size, trace_region_size, num_command_queues, dispatch_core_type, config); diff --git a/tt_metal/impl/device/mesh_device.hpp b/tt_metal/impl/device/mesh_device.hpp index 1f3cf43592fb..91f1d12f9cfc 100644 --- a/tt_metal/impl/device/mesh_device.hpp +++ b/tt_metal/impl/device/mesh_device.hpp @@ -27,7 +27,7 @@ struct MeshDeviceConfig { MeshDeviceConfig( const MeshShape &mesh_shape, - MeshType mesh_type = MeshType::RowMajor) : + MeshType mesh_type) : mesh_shape(mesh_shape), offset(MeshOffset{0, 0}), physical_device_ids(std::vector()), @@ -174,11 +174,11 @@ class MeshDevice : public std::enable_shared_from_this { static std::shared_ptr fetch_mesh_device(const std::vector& devices); static std::shared_ptr create( - size_t l1_small_size, - size_t trace_region_size, - size_t num_command_queues, - DispatchCoreType dispatch_core_type, - const MeshDeviceConfig &config); + const MeshDeviceConfig &config, + size_t l1_small_size = DEFAULT_L1_SMALL_SIZE, + size_t trace_region_size = DEFAULT_TRACE_REGION_SIZE, + size_t num_command_queues = 1, + DispatchCoreType dispatch_core_type = DispatchCoreType::WORKER); }; std::ostream &operator<<(std::ostream &os, const MeshDevice &mesh_device); diff --git a/ttnn/cpp/pybind11/multi_device.hpp b/ttnn/cpp/pybind11/multi_device.hpp index c9c661e04af0..d339fec8ac86 100644 --- a/ttnn/cpp/pybind11/multi_device.hpp +++ b/ttnn/cpp/pybind11/multi_device.hpp @@ -37,13 +37,12 @@ void py_module(py::module& module) { const std::pair& offset, const std::vector& physical_device_ids, MeshType mesh_type) { - auto config = MeshDeviceConfig(mesh_device_shape, offset, physical_device_ids, mesh_type); return MeshDevice::create( + MeshDeviceConfig(mesh_device_shape, offset, physical_device_ids, mesh_type), l1_small_size, trace_region_size, num_command_queues, - dispatch_core_type, - config); + dispatch_core_type); }), py::kw_only(), py::arg("mesh_shape"), diff --git a/ttnn/cpp/ttnn/multi_device.cpp b/ttnn/cpp/ttnn/multi_device.cpp index b8a9e91a900b..cb1e48fc7bf2 100644 --- a/ttnn/cpp/ttnn/multi_device.cpp +++ b/ttnn/cpp/ttnn/multi_device.cpp @@ -14,7 +14,7 @@ namespace ttnn::multi_device { std::shared_ptr open_mesh_device(const MeshShape& mesh_shape, size_t l1_small_size, size_t trace_region_size, size_t num_command_queues, DispatchCoreType dispatch_core_type, MeshType mesh_type, const std::pair& offset, const std::vector& physical_device_ids) { auto config = MeshDeviceConfig(mesh_shape, offset, physical_device_ids, mesh_type); - return MeshDevice::create(l1_small_size, trace_region_size, num_command_queues, dispatch_core_type, config); + return MeshDevice::create(config, l1_small_size, trace_region_size, num_command_queues, dispatch_core_type); } void close_mesh_device(const std::shared_ptr& mesh_device) {