diff --git a/.gitignore b/.gitignore index 65167333710d..44ed13e065fc 100644 --- a/.gitignore +++ b/.gitignore @@ -129,3 +129,4 @@ ttnn/ttnn/runtime ClangBuildAnalyzer.ini ttnn/ttnn/.rpath_checked__ttnn +Testing diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 1474dc932c19..0c7f8d9087ad 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -1,11 +1,12 @@ +include(CTest) enable_testing() include(GoogleTest) add_library(test_common_libs INTERFACE) target_link_libraries(test_common_libs INTERFACE pthread gtest gtest_main magic_enum fmt) if(TT_METAL_BUILD_TESTS) - add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/tt_metal/tt_metal) + add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/tt_metal) endif(TT_METAL_BUILD_TESTS) if(TTNN_BUILD_TESTS) diff --git a/tests/tt_metal/CMakeLists.txt b/tests/tt_metal/CMakeLists.txt new file mode 100644 index 000000000000..45cf2b23df10 --- /dev/null +++ b/tests/tt_metal/CMakeLists.txt @@ -0,0 +1,2 @@ +add_subdirectory(distributed) +add_subdirectory(tt_metal) diff --git a/tests/tt_metal/distributed/CMakeLists.txt b/tests/tt_metal/distributed/CMakeLists.txt new file mode 100644 index 000000000000..61bb8edfce03 --- /dev/null +++ b/tests/tt_metal/distributed/CMakeLists.txt @@ -0,0 +1,15 @@ +set(UNIT_TESTS_DISTRIBUTED_SRC + ${CMAKE_CURRENT_SOURCE_DIR}/test_distributed.cpp +) + +add_executable(distributed_unit_tests ${UNIT_TESTS_DISTRIBUTED_SRC}) +target_link_libraries(distributed_unit_tests PRIVATE tt_metal test_common_libs) + +target_include_directories(distributed_unit_tests PRIVATE + ${PROJECT_SOURCE_DIR}/tt_metal + ${PROJECT_SOURCE_DIR}/tt_metal/distributed +) + +set_target_properties(distributed_unit_tests PROPERTIES RUNTIME_OUTPUT_DIRECTORY ${PROJECT_BINARY_DIR}/test/tt_metal/distributed) + +gtest_discover_tests(distributed_unit_tests) diff --git a/tests/tt_metal/distributed/test_distributed.cpp b/tests/tt_metal/distributed/test_distributed.cpp new file mode 100644 index 000000000000..75f6626c001a --- /dev/null +++ b/tests/tt_metal/distributed/test_distributed.cpp @@ -0,0 +1,45 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#include + +#include "tt_metal/distributed/mesh_device.hpp" +#include "tt_metal/distributed/mesh_device_view.hpp" +#include "tt_metal/llrt/tt_cluster.hpp" + +namespace tt::tt_metal::distributed::test { + +static inline void skip_test_if_not_t3000() { + auto slow_dispatch = getenv("TT_METAL_SLOW_DISPATCH_MODE"); + const auto arch = tt::Cluster::instance().arch(); + const size_t num_devices = tt::Cluster::instance().number_of_devices(); + + if (slow_dispatch) { + GTEST_SKIP() << "Skipping Multi-Device test suite, since it can only be run in Fast Dispatch Mode."; + } + if (num_devices < 8 or arch != tt::ARCH::WORMHOLE_B0) { + GTEST_SKIP() << "Skipping T3K Multi-Device test suite on non T3K machine."; + } +} +class MeshDevice_T3000 : public ::testing::Test { + protected: + void SetUp() override { + skip_test_if_not_t3000(); + this->mesh_device_ = MeshDevice::create(MeshDeviceConfig(MeshShape(2, 4))); + } + + void TearDown() override { + mesh_device_->close_devices(); + mesh_device_.reset(); + } + std::shared_ptr mesh_device_; +}; + +TEST_F(MeshDevice_T3000, SimpleMeshDeviceTest) { + EXPECT_EQ(mesh_device_->num_devices(), 8); + EXPECT_EQ(mesh_device_->num_rows(), 2); + EXPECT_EQ(mesh_device_->num_cols(), 4); +} + +} // namespace tt::tt_metal::distributed::test diff --git a/tests/tt_metal/tt_metal/perf_microbenchmark/ethernet/test_ethernet_hop_latencies_no_edm.cpp b/tests/tt_metal/tt_metal/perf_microbenchmark/ethernet/test_ethernet_hop_latencies_no_edm.cpp index 1ee2f290314a..be2b4d9c2e58 100644 --- a/tests/tt_metal/tt_metal/perf_microbenchmark/ethernet/test_ethernet_hop_latencies_no_edm.cpp +++ b/tests/tt_metal/tt_metal/perf_microbenchmark/ethernet/test_ethernet_hop_latencies_no_edm.cpp @@ -30,6 +30,10 @@ #include "tt_metal/distributed/mesh_device.hpp" using tt::tt_metal::Device; +using tt::tt_metal::distributed::MeshShape; +using tt::tt_metal::distributed::MeshDevice; +using tt::tt_metal::distributed::MeshDeviceView; +using tt::tt_metal::distributed::MeshDeviceConfig; class T3000TestDevice { public: @@ -43,7 +47,7 @@ class T3000TestDevice { num_devices_ = tt::tt_metal::GetNumAvailableDevices(); if (arch_ == tt::ARCH::WORMHOLE_B0 and tt::tt_metal::GetNumAvailableDevices() == 8 and tt::tt_metal::GetNumPCIeDevices() == 4) { - mesh_device_ = tt::tt_metal::MeshDevice::create(tt::tt_metal::MeshDeviceConfig(tt::tt_metal::MeshShape{2, 4})); + mesh_device_ = MeshDevice::create(MeshDeviceConfig(MeshShape{2, 4})); } else { TT_THROW("This suite can only be run on T3000 Wormhole devices"); @@ -63,7 +67,7 @@ class T3000TestDevice { tt::ARCH arch_; size_t num_devices_; - std::shared_ptr mesh_device_; + std::shared_ptr mesh_device_; private: bool device_open; diff --git a/tests/ttnn/unit_tests/gtests/test_ccl_on_galaxy.cpp b/tests/ttnn/unit_tests/gtests/test_ccl_on_galaxy.cpp index df3476bd5454..e9e0ba7c89f0 100644 --- a/tests/ttnn/unit_tests/gtests/test_ccl_on_galaxy.cpp +++ b/tests/ttnn/unit_tests/gtests/test_ccl_on_galaxy.cpp @@ -66,9 +66,9 @@ bool is_tgg_system() return is_galaxy_system && (num_mmio_devices == 8) && (num_devices == 64); } -MeshShape get_mesh_shape() +ttnn::MeshShape get_mesh_shape() { - MeshShape shape; + ttnn::MeshShape shape; if (is_tg_system()) { shape = {8, 4}; @@ -116,8 +116,8 @@ TEST(GalaxyTests, TestAllGatherDeadlock) { } validate_num_tunnels_and_tunnel_depth(); - MeshShape mesh_shape = get_mesh_shape(); - std::shared_ptr mesh = ttnn::distributed::open_mesh_device(mesh_shape, 0, 0, 1, DispatchCoreType::WORKER); + ttnn::MeshShape mesh_shape = get_mesh_shape(); + std::shared_ptr mesh = ttnn::distributed::open_mesh_device(mesh_shape, 0, 0, 1, DispatchCoreType::WORKER); // Setup input data and output data containers MemoryConfig mem_cfg = MemoryConfig{ @@ -137,7 +137,7 @@ TEST(GalaxyTests, TestAllGatherDeadlock) { } // Iterate over each row and run line all-gather multiple times. // For each row, send adversarial traffic to the first chip, that can hang the network if the CCL is not tagged. - auto view = MeshDeviceView(*mesh); + auto view = ttnn::MeshDeviceView(*mesh); for (uint32_t row = 0; row < 8; row++) { auto devs = view.get_devices_on_row(row); std::vector device_ids = {}; @@ -193,11 +193,11 @@ TEST(GalaxyTests, TestReduceScatterDeadlock) { } validate_num_tunnels_and_tunnel_depth(); - MeshShape mesh_shape = get_mesh_shape(); - std::shared_ptr mesh = ttnn::distributed::open_mesh_device(mesh_shape, 0, 0, 1, DispatchCoreType::WORKER); + ttnn::MeshShape mesh_shape = get_mesh_shape(); + std::shared_ptr mesh = ttnn::distributed::open_mesh_device(mesh_shape, 0, 0, 1, DispatchCoreType::WORKER); // Create the outer ring on which Reduce Scatter will be run. This allows us to verify that there are no deadlocks when we send CCLs to the // first tunnel (forward path). - auto view = MeshDeviceView(*mesh); + auto view = ttnn::MeshDeviceView(*mesh); std::vector ring_devices = view.get_devices_on_row(0); // Tunnel 0 std::vector ring_devices_1 = view.get_devices_on_column(mesh_shape.second - 1); // Orthogonal to tunnel .. no deadlocks ring_devices_1 = std::vector(ring_devices_1.begin() + 1, ring_devices_1.end()); diff --git a/tt_metal/distributed/mesh_device.cpp b/tt_metal/distributed/mesh_device.cpp index ec521f9ce408..2982a203a689 100644 --- a/tt_metal/distributed/mesh_device.cpp +++ b/tt_metal/distributed/mesh_device.cpp @@ -14,7 +14,7 @@ #include "tt_metal/distributed/mesh_device_view.hpp" #include "tt_metal/distributed/mesh_device.hpp" -namespace tt::tt_metal { +namespace tt::tt_metal::distributed { using LogicalCoordinate = Coordinate; using PhysicalCoordinate = eth_coord_t; @@ -295,7 +295,7 @@ void MeshDevice::initialize( auto& instance = SystemMesh::instance(); this->devices = instance.map_mesh_device( shared_from_this(), num_command_queues, l1_small_size, trace_region_size, dispatch_core_type, config); - this->primary_view = std::make_shared(*this); + this->primary_view = std::make_shared(*this); } MeshDevice::~MeshDevice() { @@ -420,14 +420,4 @@ void MeshDevice::disable_and_clear_program_cache() { } } -std::vector get_t3k_physical_device_ids_ring() { - auto& instance = SystemMesh::instance(); - auto num_devices = instance.get_num_devices(); - TT_FATAL(num_devices == 8, "T3000 ring topology only works with 8 devices"); - - auto physical_device_ids = instance.get_mapped_physical_device_ids( - MeshDeviceConfig(MeshShape{1, 8}, MeshOffset{0, 0})); - return physical_device_ids; -} - -} // namespace tt::tt_metal +} // namespace tt::tt_metal::distributed diff --git a/tt_metal/distributed/mesh_device.hpp b/tt_metal/distributed/mesh_device.hpp index 7237c8c0158c..6e41b74100b0 100644 --- a/tt_metal/distributed/mesh_device.hpp +++ b/tt_metal/distributed/mesh_device.hpp @@ -12,7 +12,7 @@ #include "tt_metal/impl/device/device.hpp" #include "tt_metal/distributed/mesh_device_view.hpp" -namespace tt::tt_metal { +namespace tt::tt_metal::distributed { using DeviceIds = std::vector; using MeshDeviceID = size_t; @@ -186,6 +186,5 @@ class MeshDevice : public std::enable_shared_from_this { std::ostream &operator<<(std::ostream &os, const MeshDevice &mesh_device); bool validate_worker_modes(const std::vector &workers); -std::vector get_t3k_physical_device_ids_ring(); -} // namespace tt::tt_metal +} // namespace tt::tt_metal::distributed diff --git a/tt_metal/distributed/mesh_device_view.cpp b/tt_metal/distributed/mesh_device_view.cpp index d5f7e80855af..5f33ab3ca8bf 100644 --- a/tt_metal/distributed/mesh_device_view.cpp +++ b/tt_metal/distributed/mesh_device_view.cpp @@ -9,9 +9,7 @@ #include "tt_metal/distributed/mesh_device.hpp" -namespace tt::tt_metal { - -using MeshDevice = tt::tt_metal::MeshDevice; +namespace tt::tt_metal::distributed { static std::vector get_devices_from_coordinates(MeshDeviceView& mesh, const std::vector& coords) { std::vector devices; @@ -287,4 +285,4 @@ MeshDeviceView::DeviceView MeshDeviceView::get_devices(MeshType type) { } } -} // namespace tt::tt_metal +} // namespace tt::tt_metal::distributed diff --git a/tt_metal/distributed/mesh_device_view.hpp b/tt_metal/distributed/mesh_device_view.hpp index 2b16a2652b95..517093d97339 100644 --- a/tt_metal/distributed/mesh_device_view.hpp +++ b/tt_metal/distributed/mesh_device_view.hpp @@ -13,7 +13,7 @@ #include "tt_metal/impl/device/device.hpp" -namespace tt::tt_metal { +namespace tt::tt_metal::distributed { // Forward declaration of MeshDevice class MeshDevice; @@ -128,19 +128,19 @@ inline MeshDeviceView make_mesh_device_view(std::vector devices, MeshDe return MeshDeviceView(std::move(devices), std::move(mapper)); } -} // namespace tt::tt_metal +} // namespace tt::tt_metal::distributed namespace std { // Specializations to enable structured bindings - template<> struct tuple_size : std::integral_constant {}; - template struct tuple_element { + template<> struct tuple_size : std::integral_constant {}; + template struct tuple_element { using type = size_t; }; // Specialization to enable hashing of Coordinate template <> - struct hash { - size_t operator()(const tt::tt_metal::Coordinate& coord) const noexcept { + struct hash { + size_t operator()(const tt::tt_metal::distributed::Coordinate& coord) const noexcept { size_t seed = 0; tt::utils::hash_combine(seed, coord.row); tt::utils::hash_combine(seed, coord.col); diff --git a/ttnn/cpp/pybind11/tensor.cpp b/ttnn/cpp/pybind11/tensor.cpp index b3f6d1480408..931ec2b8b2c2 100644 --- a/ttnn/cpp/pybind11/tensor.cpp +++ b/ttnn/cpp/pybind11/tensor.cpp @@ -294,13 +294,13 @@ void tensor_mem_config_module(py::module& m_tensor) { m_tensor.def( "load_tensor", - static_cast(&load_tensor), + py::overload_cast(&load_tensor), py::arg("file_name"), py::arg("device") = nullptr, R"doc(Load tensor to file)doc"); m_tensor.def( "load_tensor", - static_cast(&load_tensor), + py::overload_cast(&load_tensor), py::arg("file_name"), py::arg("device") = nullptr, R"doc(Load tensor to file)doc"); diff --git a/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp b/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp index 53caf4276ca6..3536ede52f17 100644 --- a/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp +++ b/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp @@ -180,7 +180,7 @@ void py_module(py::module& module) { )doc"); module.def("get_device_tensors", &get_device_tensors, py::arg("tensor"), py::kw_only()); module.def("aggregate_as_tensor", &aggregate_as_tensor, py::arg("tensors"), py::kw_only()); - module.def("get_t3k_physical_device_ids_ring", &tt::tt_metal::get_t3k_physical_device_ids_ring); + module.def("get_t3k_physical_device_ids_ring", &get_t3k_physical_device_ids_ring); } } // namespace ttnn::distributed diff --git a/ttnn/cpp/ttnn/distributed/mesh_device.cpp b/ttnn/cpp/ttnn/distributed/mesh_device.cpp index 3d351ade00b6..6e2547f68724 100644 --- a/ttnn/cpp/ttnn/distributed/mesh_device.cpp +++ b/ttnn/cpp/ttnn/distributed/mesh_device.cpp @@ -82,4 +82,15 @@ Tensor aggregate_as_tensor(std::vector& tensor_shards) } } +std::vector get_t3k_physical_device_ids_ring() { + using namespace tt::tt_metal::distributed; + auto& instance = SystemMesh::instance(); + auto num_devices = instance.get_num_devices(); + TT_FATAL(num_devices == 8, "T3000 ring topology only works with 8 devices"); + + auto physical_device_ids = instance.get_mapped_physical_device_ids( + MeshDeviceConfig(MeshShape{1, 8}, MeshOffset{0, 0})); + return physical_device_ids; +} + } // namespace ttnn::distributed diff --git a/ttnn/cpp/ttnn/distributed/types.hpp b/ttnn/cpp/ttnn/distributed/types.hpp new file mode 100644 index 000000000000..533fd9251b4a --- /dev/null +++ b/ttnn/cpp/ttnn/distributed/types.hpp @@ -0,0 +1,18 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "tt_metal/distributed/mesh_device.hpp" + +namespace ttnn::distributed::types { + +using MeshShape = tt::tt_metal::distributed::MeshShape; +using DeviceIds = tt::tt_metal::distributed::DeviceIds; +using MeshDevice = tt::tt_metal::distributed::MeshDevice; +using MeshDeviceView = tt::tt_metal::distributed::MeshDeviceView; +using MeshType = tt::tt_metal::distributed::MeshType; +using MeshDeviceConfig = tt::tt_metal::distributed::MeshDeviceConfig; + +} // namespace ttnn::distributed diff --git a/ttnn/cpp/ttnn/events.cpp b/ttnn/cpp/ttnn/events.cpp index 5f207b1cfdf9..08e609e495f2 100644 --- a/ttnn/cpp/ttnn/events.cpp +++ b/ttnn/cpp/ttnn/events.cpp @@ -6,6 +6,7 @@ #include #include "tt_metal/impl/event/event.hpp" +#include "ttnn/distributed/types.hpp" namespace ttnn::events { diff --git a/ttnn/cpp/ttnn/events.hpp b/ttnn/cpp/ttnn/events.hpp index 9fd11ccea39f..9976359b7f1d 100644 --- a/ttnn/cpp/ttnn/events.hpp +++ b/ttnn/cpp/ttnn/events.hpp @@ -6,8 +6,6 @@ #include -#include "tt_metal/distributed/mesh_device.hpp" - namespace ttnn::events { struct MultiDeviceEvent diff --git a/ttnn/cpp/ttnn/run_operation_inl.hpp b/ttnn/cpp/ttnn/run_operation_inl.hpp index 510847bdd352..6e0ce10e9055 100644 --- a/ttnn/cpp/ttnn/run_operation_inl.hpp +++ b/ttnn/cpp/ttnn/run_operation_inl.hpp @@ -84,7 +84,7 @@ void launch_op( return; } check_output(output_tensors, workers); - validate_worker_modes(workers); + distributed::validate_worker_modes(workers); // Record ref counts for all tensors before pushing to worker queue. std::vector input_tensor_ref_count(input_tensors.size()); std::vector optional_input_tensor_ref_count(optional_input_tensors.size()); diff --git a/ttnn/cpp/ttnn/tensor/serialization.cpp b/ttnn/cpp/ttnn/tensor/serialization.cpp index 68cc4eb87171..60ee2ee76308 100644 --- a/ttnn/cpp/ttnn/tensor/serialization.cpp +++ b/ttnn/cpp/ttnn/tensor/serialization.cpp @@ -254,7 +254,7 @@ void dump_tensor(const std::string& file_name, const Tensor& tensor, const std:: } template -Tensor load_tensor(const std::string& file_name, T device) { +Tensor load_tensor_helper(const std::string& file_name, T device) { std::ifstream input_stream(file_name, std::ios::in | std::ios::binary); if (not input_stream) { throw std::runtime_error(fmt::format("Cannot open \"{}\"", file_name)); @@ -320,8 +320,12 @@ Tensor load_tensor(const std::string& file_name, T device) { } // Explicit instantiations -template Tensor load_tensor(const std::string&, Device*); -template Tensor load_tensor(const std::string&, MeshDevice*); +Tensor load_tensor(const std::string& file_name, Device* device) { + return load_tensor_helper(file_name, device); +} +Tensor load_tensor(const std::string& file_name, MeshDevice* device) { + return load_tensor_helper(file_name, device); +} } // namespace tt_metal diff --git a/ttnn/cpp/ttnn/tensor/serialization.hpp b/ttnn/cpp/ttnn/tensor/serialization.hpp index 59ab09d82fd0..9ccca3a8397a 100644 --- a/ttnn/cpp/ttnn/tensor/serialization.hpp +++ b/ttnn/cpp/ttnn/tensor/serialization.hpp @@ -15,8 +15,8 @@ namespace tt_metal { void dump_tensor(const std::string& file_name, const Tensor& tensor, const std::unordered_map& strategy); -template -Tensor load_tensor(const std::string& file_name, T device = nullptr); +Tensor load_tensor(const std::string& file_name, Device* device = nullptr); +Tensor load_tensor(const std::string& file_name, MeshDevice* device = nullptr); } // namespace tt_metalls diff --git a/ttnn/cpp/ttnn/tensor/tensor.hpp b/ttnn/cpp/ttnn/tensor/tensor.hpp index 4f2010730970..d819a507ac1f 100644 --- a/ttnn/cpp/ttnn/tensor/tensor.hpp +++ b/ttnn/cpp/ttnn/tensor/tensor.hpp @@ -28,6 +28,11 @@ namespace tt { namespace tt_metal { + +namespace distributed { + class MeshDevice; +} +using MeshDevice = tt::tt_metal::distributed::MeshDevice; struct Tensor { struct TensorAttributes : public std::enable_shared_from_this { Storage storage; diff --git a/ttnn/cpp/ttnn/tensor/tensor_ops.cpp b/ttnn/cpp/ttnn/tensor/tensor_ops.cpp index 932e28087c22..04f3ca4f32c4 100644 --- a/ttnn/cpp/ttnn/tensor/tensor_ops.cpp +++ b/ttnn/cpp/ttnn/tensor/tensor_ops.cpp @@ -79,7 +79,7 @@ Tensor tensor_to(const Tensor& input_tensor, const std::vector& workers ZoneScoped; GraphTracker::instance().track_function_start("Tensor::to", input_tensor, workers, mem_config); TT_FATAL( - validate_worker_modes(workers), "All device threads/workers must be running in the same mode (ASYNC or SYNC)"); + distributed::validate_worker_modes(workers), "All device threads/workers must be running in the same mode (ASYNC or SYNC)"); Tensor device_tensor = Tensor(workers); uint32_t device_tensor_ref_count = device_tensor.tensor_attributes->record_main_thread_ref_count(); uint32_t original_tensor_ref_count = input_tensor.tensor_attributes->record_main_thread_ref_count(); @@ -122,7 +122,7 @@ Tensor tensor_cpu(const Tensor& input_tensor, bool blocking, uint8_t cq_id) { return output; } TT_FATAL( - validate_worker_modes(workers), "All device threads/workers must be running in the same mode (ASYNC or SYNC)"); + distributed::validate_worker_modes(workers), "All device threads/workers must be running in the same mode (ASYNC or SYNC)"); Tensor host_tensor({}, workers.size()); uint32_t original_tensor_ref_count = input_tensor.tensor_attributes->record_main_thread_ref_count(); for (int worker_index = 0; worker_index < workers.size(); worker_index++) { @@ -201,7 +201,7 @@ Tensor tensor_to(const Tensor& input_tensor, Layout target_layout, MeshDevice* m if (mesh_device) { auto workers = distribute_tensor_to_mesh(input_tensor, *mesh_device); TT_FATAL( - validate_worker_modes(workers), + distributed::validate_worker_modes(workers), "All device threads/workers must be running in the same mode (ASYNC or SYNC)"); std::optional distributed_config = std::nullopt; diff --git a/ttnn/cpp/ttnn/tensor/tensor_ops.hpp b/ttnn/cpp/ttnn/tensor/tensor_ops.hpp index 56113a8db259..4b63973c45bf 100644 --- a/ttnn/cpp/ttnn/tensor/tensor_ops.hpp +++ b/ttnn/cpp/ttnn/tensor/tensor_ops.hpp @@ -8,7 +8,9 @@ namespace tt::tt_metal { struct Tensor; struct MemoryConfig; +namespace distributed { class MeshDevice; +} // namespace distributed inline namespace v0 { class CommandQueue; diff --git a/ttnn/cpp/ttnn/types.hpp b/ttnn/cpp/ttnn/types.hpp index 4fa90e7128cd..7fdf72f496da 100644 --- a/ttnn/cpp/ttnn/types.hpp +++ b/ttnn/cpp/ttnn/types.hpp @@ -7,6 +7,7 @@ #include "tt_metal/detail/tt_metal.hpp" #include "tt_metal/impl/allocator/allocator.hpp" +#include "ttnn/distributed/types.hpp" #include "ttnn/tensor/tensor.hpp" #include "ttnn/tensor/types.hpp" @@ -14,10 +15,7 @@ namespace ttnn { namespace types { using Device = tt::tt_metal::Device; -using MeshShape = tt::tt_metal::MeshShape; -using DeviceIds = tt::tt_metal::DeviceIds; -using MeshDevice = tt::tt_metal::MeshDevice; -using MeshDeviceView = tt::tt_metal::MeshDeviceView; +using namespace ttnn::distributed::types; constexpr auto TILE_SIZE = 32;