diff --git a/tests/tt_metal/tt_metal/api/test_global_semaphores.cpp b/tests/tt_metal/tt_metal/api/test_global_semaphores.cpp index 5f05b2f85d1..7417bdd13df 100644 --- a/tests/tt_metal/tt_metal/api/test_global_semaphores.cpp +++ b/tests/tt_metal/tt_metal/api/test_global_semaphores.cpp @@ -53,7 +53,7 @@ TEST_F(DispatchFixture, CreateMultipleGlobalSemaphoresOnSameCore) { } for (auto device : devices_) { { - std::vector> global_semaphores; + std::vector> global_semaphores; global_semaphores.reserve(cores.size()); std::vector addresses; addresses.reserve(cores.size()); diff --git a/tests/tt_metal/tt_metal/dispatch/sub_device_test_utils.hpp b/tests/tt_metal/tt_metal/dispatch/sub_device_test_utils.hpp index cc3f9db4aef..54b77acedc1 100644 --- a/tests/tt_metal/tt_metal/dispatch/sub_device_test_utils.hpp +++ b/tests/tt_metal/tt_metal/dispatch/sub_device_test_utils.hpp @@ -9,7 +9,7 @@ // TODO: ARCH_NAME specific, must remove #include "eth_l1_address_map.h" -inline std::tuple> create_single_sync_program( +inline std::tuple> create_single_sync_program( Device* device, SubDevice sub_device) { auto syncer_coord = sub_device.cores(HalProgrammableCoreType::TENSIX).ranges().at(0).start_coord; auto syncer_core = CoreRangeSet(CoreRange(syncer_coord, syncer_coord)); @@ -26,7 +26,7 @@ inline std::tuple> create_s return {std::move(syncer_program), std::move(syncer_coord), std::move(global_sem)}; } -inline std::tuple> create_basic_sync_program( +inline std::tuple> create_basic_sync_program( Device* device, const SubDevice& sub_device_1, const SubDevice& sub_device_2) { auto waiter_coord = sub_device_2.cores(HalProgrammableCoreType::TENSIX).ranges().at(0).start_coord; auto waiter_core = CoreRangeSet(CoreRange(waiter_coord, waiter_coord)); @@ -70,7 +70,7 @@ inline std::tuple> c std::move(waiter_program), std::move(syncer_program), std::move(incrementer_program), std::move(global_sem)}; } -inline std::tuple> create_basic_eth_sync_program( +inline std::tuple> create_basic_eth_sync_program( Device* device, const SubDevice& sub_device_1, const SubDevice& sub_device_2) { auto waiter_coord = sub_device_2.cores(HalProgrammableCoreType::ACTIVE_ETH).ranges().at(0).start_coord; auto waiter_core = CoreRangeSet(CoreRange(waiter_coord, waiter_coord)); diff --git a/tests/ttnn/unit_tests/test_global_circular_buffer.py b/tests/ttnn/unit_tests/test_global_circular_buffer.py new file mode 100644 index 00000000000..5314046fe2c --- /dev/null +++ b/tests/ttnn/unit_tests/test_global_circular_buffer.py @@ -0,0 +1,42 @@ +# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +import pytest +import torch +import ttnn + + +def run_global_circular_buffer(device): + sender_cores = [ttnn.CoreCoord(1, 1), ttnn.CoreCoord(2, 2)] + receiver_cores = [ + ttnn.CoreRangeSet( + { + ttnn.CoreRange( + ttnn.CoreCoord(4, 4), + ttnn.CoreCoord(4, 4), + ), + } + ), + ttnn.CoreRangeSet( + { + ttnn.CoreRange( + ttnn.CoreCoord(2, 3), + ttnn.CoreCoord(2, 4), + ), + } + ), + ] + sender_receiver_mapping = dict(zip(sender_cores, receiver_cores)) + + global_circular_buffer = ttnn.create_global_circular_buffer(device, sender_receiver_mapping, 3200) + + +@pytest.mark.parametrize("enable_async_mode", (False, True), indirect=True) +def test_global_circular_buffer(device, enable_async_mode): + run_global_circular_buffer(device) + + +@pytest.mark.parametrize("enable_async_mode", (False, True), indirect=True) +def test_global_circular_buffer_mesh(mesh_device, enable_async_mode): + run_global_circular_buffer(mesh_device) diff --git a/tests/ttnn/unit_tests/test_global_semaphore.py b/tests/ttnn/unit_tests/test_global_semaphore.py new file mode 100644 index 00000000000..24c6fa107de --- /dev/null +++ b/tests/ttnn/unit_tests/test_global_semaphore.py @@ -0,0 +1,42 @@ +# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +import pytest +import torch +import ttnn + + +def run_global_semaphore(device): + tensix_cores0 = ttnn.CoreRangeSet( + { + ttnn.CoreRange( + ttnn.CoreCoord(0, 0), + ttnn.CoreCoord(3, 3), + ), + } + ) + tensix_cores1 = ttnn.CoreRangeSet( + { + ttnn.CoreRange( + ttnn.CoreCoord(4, 4), + ttnn.CoreCoord(4, 4), + ), + } + ) + global_sem0 = ttnn.create_global_semaphore(device, tensix_cores0, 1) + global_sem1 = ttnn.create_global_semaphore(device, tensix_cores1, 2) + + assert ttnn.get_global_semaphore_address(global_sem0) != ttnn.get_global_semaphore_address(global_sem1) + + ttnn.reset_global_semaphore_value(global_sem0) + + +@pytest.mark.parametrize("enable_async_mode", (False, True), indirect=True) +def test_global_semaphore(device, enable_async_mode): + run_global_semaphore(device) + + +@pytest.mark.parametrize("enable_async_mode", (False, True), indirect=True) +def test_global_semaphore_mesh(mesh_device, enable_async_mode): + run_global_semaphore(mesh_device) diff --git a/tests/ttnn/unit_tests/test_sub_device.py b/tests/ttnn/unit_tests/test_sub_device.py new file mode 100644 index 00000000000..7d3f93797a7 --- /dev/null +++ b/tests/ttnn/unit_tests/test_sub_device.py @@ -0,0 +1,112 @@ +# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +import pytest +import torch +import ttnn + + +def run_sub_devices(device): + tensix_cores0 = ttnn.CoreRangeSet( + { + ttnn.CoreRange( + ttnn.CoreCoord(0, 0), + ttnn.CoreCoord(3, 3), + ), + } + ) + tensix_cores1 = ttnn.CoreRangeSet( + { + ttnn.CoreRange( + ttnn.CoreCoord(4, 4), + ttnn.CoreCoord(4, 4), + ), + } + ) + sub_device_1 = ttnn.SubDevice([tensix_cores0]) + sub_device_2 = ttnn.SubDevice([tensix_cores1]) + sub_device_manager1 = device.create_sub_device_manager([sub_device_1, sub_device_2], 3200) + sub_device_manager2 = device.create_sub_device_manager([sub_device_2], 3200) + device.load_sub_device_manager(sub_device_manager1) + device.load_sub_device_manager(sub_device_manager2) + device.clear_loaded_sub_device_manager() + device.remove_sub_device_manager(sub_device_manager1) + device.remove_sub_device_manager(sub_device_manager2) + + +def run_sub_devices_program(device): + is_mesh_device = isinstance(device, ttnn.MeshDevice) + if is_mesh_device: + inputs_mesh_mapper = ttnn.ShardTensorToMesh(device, dim=0) + output_mesh_composer = ttnn.ConcatMeshToTensor(device, dim=0) + num_devices = device.get_num_devices() + else: + inputs_mesh_mapper = None + output_mesh_composer = None + num_devices = 1 + tensix_cores0 = ttnn.CoreRangeSet( + { + ttnn.CoreRange( + ttnn.CoreCoord(0, 0), + ttnn.CoreCoord(3, 3), + ), + } + ) + tensix_cores1 = ttnn.CoreRangeSet( + { + ttnn.CoreRange( + ttnn.CoreCoord(4, 4), + ttnn.CoreCoord(4, 4), + ), + } + ) + sub_device_1 = ttnn.SubDevice([tensix_cores0]) + sub_device_2 = ttnn.SubDevice([tensix_cores1]) + sub_device_manager = device.create_sub_device_manager([sub_device_1, sub_device_2], 3200) + device.load_sub_device_manager(sub_device_manager) + + x = torch.randn(num_devices, 1, 64, 64, dtype=torch.bfloat16) + xt = ttnn.from_torch( + x, + dtype=ttnn.bfloat16, + layout=ttnn.TILE_LAYOUT, + device=device, + memory_config=ttnn.L1_MEMORY_CONFIG, + mesh_mapper=inputs_mesh_mapper, + ) + + grid_size = device.compute_with_storage_grid_size() + shard_size = [32, 64] + shard_scheme = ttnn.TensorMemoryLayout.HEIGHT_SHARDED + shard_orientation = ttnn.ShardOrientation.ROW_MAJOR + yt = ttnn.interleaved_to_sharded( + xt, grid_size, shard_size, shard_scheme, shard_orientation, output_dtype=ttnn.bfloat16 + ) + y = ttnn.to_torch(yt, device=device, mesh_composer=output_mesh_composer) + + eq = torch.equal(x, y) + assert eq + + device.clear_loaded_sub_device_manager() + device.remove_sub_device_manager(sub_device_manager) + + +@pytest.mark.parametrize("enable_async_mode", (False, True), indirect=True) +def test_sub_devices(device, enable_async_mode): + run_sub_devices(device) + + +@pytest.mark.parametrize("enable_async_mode", (False, True), indirect=True) +def test_sub_devices_mesh(mesh_device, enable_async_mode): + run_sub_devices(mesh_device) + + +@pytest.mark.parametrize("enable_async_mode", (False, True), indirect=True) +def test_sub_device_program(device, enable_async_mode): + run_sub_devices_program(device) + + +@pytest.mark.parametrize("enable_async_mode", (False, True), indirect=True) +def test_sub_device_program_mesh(mesh_device, enable_async_mode): + run_sub_devices_program(mesh_device) diff --git a/tt_metal/distributed/mesh_device.cpp b/tt_metal/distributed/mesh_device.cpp index db820323ae5..3641c4ca7ed 100644 --- a/tt_metal/distributed/mesh_device.cpp +++ b/tt_metal/distributed/mesh_device.cpp @@ -441,4 +441,48 @@ size_t MeshDevice::num_program_cache_entries() const { return total_entries; } +MeshSubDeviceManagerId MeshDevice::create_sub_device_manager(tt::stl::Span sub_devices, DeviceAddr local_l1_size) { + MeshSubDeviceManagerId mesh_sub_device_manager_id(*this); + for (uint32_t i = 0; i < this->num_devices(); i++) { + auto* device = this->devices[i]; + auto& sub_device_manager_id = mesh_sub_device_manager_id.sub_device_manager_ids[i]; + device->push_work([device, sub_devices, local_l1_size, &sub_device_manager_id]() { + sub_device_manager_id = device->create_sub_device_manager(sub_devices, local_l1_size); + }); + } + for (auto* device : this->devices) { + device->synchronize(); + } + return mesh_sub_device_manager_id; +} +void MeshDevice::load_sub_device_manager(MeshSubDeviceManagerId mesh_sub_device_manager_id) { + for (uint32_t i = 0; i < this->num_devices(); i++) { + auto* device = this->devices[i]; + auto sub_device_manager_id = mesh_sub_device_manager_id.sub_device_manager_ids[i]; + device->push_work([device, sub_device_manager_id]() { + device->load_sub_device_manager(sub_device_manager_id); + }); + } +} +void MeshDevice::clear_loaded_sub_device_manager() { + for (auto* device : this->devices) { + device->push_work([device]() { + device->clear_loaded_sub_device_manager(); + }); + } +} +void MeshDevice::remove_sub_device_manager(MeshSubDeviceManagerId mesh_sub_device_manager_id) { + for (uint32_t i = 0; i < this->num_devices(); i++) { + auto* device = this->devices[i]; + auto sub_device_manager_id = mesh_sub_device_manager_id.sub_device_manager_ids[i]; + device->push_work([device, sub_device_manager_id]() { + device->remove_sub_device_manager(sub_device_manager_id); + }); + } +} + +MeshSubDeviceManagerId::MeshSubDeviceManagerId(const MeshDevice& mesh_device) { + this->sub_device_manager_ids.resize(mesh_device.num_devices()); +} + } // namespace tt::tt_metal::distributed diff --git a/tt_metal/distributed/mesh_device.hpp b/tt_metal/distributed/mesh_device.hpp index 2c0ee62f872..3c838398aa0 100644 --- a/tt_metal/distributed/mesh_device.hpp +++ b/tt_metal/distributed/mesh_device.hpp @@ -9,8 +9,10 @@ #include #include -#include "tt_metal/impl/device/device.hpp" #include "tt_metal/distributed/mesh_device_view.hpp" +#include "tt_metal/impl/device/device.hpp" +#include "tt_metal/impl/sub_device/sub_device_types.hpp" +#include "tt_metal/tt_stl/span.hpp" namespace tt::tt_metal::distributed { @@ -19,6 +21,8 @@ using MeshDeviceID = size_t; using MeshOffset = std::pair; class MeshDeviceView; +struct MeshSubDeviceManagerId; + struct MeshDeviceConfig { MeshShape mesh_shape; MeshOffset offset; @@ -171,6 +175,12 @@ class MeshDevice : public std::enable_shared_from_this { size_t num_program_cache_entries() const; + MeshSubDeviceManagerId create_sub_device_manager( + tt::stl::Span sub_devices, DeviceAddr local_l1_size); + void load_sub_device_manager(MeshSubDeviceManagerId mesh_sub_device_manager_id); + void clear_loaded_sub_device_manager(); + void remove_sub_device_manager(MeshSubDeviceManagerId mesh_sub_device_manager_id); + static std::shared_ptr fetch_mesh_device(const std::vector& devices); static std::shared_ptr create( const MeshDeviceConfig& config, @@ -182,4 +192,12 @@ class MeshDevice : public std::enable_shared_from_this { std::ostream& operator<<(std::ostream& os, const MeshDevice& mesh_device); +// TODO: This will be removed once we have DistributedDevice +// Currently required since each device manages its own sub-device manager ids +struct MeshSubDeviceManagerId { + MeshSubDeviceManagerId(const MeshDevice& mesh_device); + + std::vector sub_device_manager_ids; +}; + } // namespace tt::tt_metal::distributed diff --git a/tt_metal/host_api.hpp b/tt_metal/host_api.hpp index 57863280c0a..be3b0e5fcad 100644 --- a/tt_metal/host_api.hpp +++ b/tt_metal/host_api.hpp @@ -297,7 +297,7 @@ uint32_t CreateSemaphore( * Initializes a global semaphore on all cores within the specified CoreRangeSet. * This only supports tensix cores, and can only use L1 buffer types like BufferType::L1 and BufferType::L1_SMALL. * - * Return value: std::unique_ptr + * Return value: std::shared_ptr * * | Argument | Description | Type | Valid Range | Required | * |---------------|------------------------------------------------------|-----------------------------------------------------------|--------------|----------| @@ -307,7 +307,7 @@ uint32_t CreateSemaphore( * | buffer_type | Buffer type to store the semaphore | BufferType | L1 types | No | */ // clang-format on -std::unique_ptr CreateGlobalSemaphore( +std::shared_ptr CreateGlobalSemaphore( Device* device, const CoreRangeSet& cores, uint32_t initial_value, BufferType buffer_type = BufferType::L1); // clang-format off @@ -315,7 +315,7 @@ std::unique_ptr CreateGlobalSemaphore( * Initializes a global semaphore on all cores within the specified CoreRangeSet. * This only supports tensix cores, and can only use L1 buffer types like BufferType::L1 and BufferType::L1_SMALL. * - * Return value: std::unique_ptr + * Return value: std::shared_ptr * * | Argument | Description | Type | Valid Range | Required | * |---------------|------------------------------------------------------|-----------------------------------------------------------|--------------|----------| @@ -325,7 +325,7 @@ std::unique_ptr CreateGlobalSemaphore( * | buffer_type | Buffer type to store the semaphore | BufferType | L1 types | No | */ // clang-format on -std::unique_ptr CreateGlobalSemaphore( +std::shared_ptr CreateGlobalSemaphore( Device* device, CoreRangeSet&& cores, uint32_t initial_value, BufferType buffer_type = BufferType::L1); // clang-format off diff --git a/tt_metal/impl/buffers/global_semaphore.cpp b/tt_metal/impl/buffers/global_semaphore.cpp index b38fbf99f6a..807e74a8e10 100644 --- a/tt_metal/impl/buffers/global_semaphore.cpp +++ b/tt_metal/impl/buffers/global_semaphore.cpp @@ -53,15 +53,17 @@ void GlobalSemaphore::setup_buffer(BufferType buffer_type) { this->reset_semaphore_value(); } -std::unique_ptr GlobalSemaphore::create( +std::shared_ptr GlobalSemaphore::create( Device* device, const CoreRangeSet& cores, uint32_t initial_value, BufferType buffer_type) { return std::make_unique(device, cores, initial_value, buffer_type); } -std::unique_ptr GlobalSemaphore::create( +std::shared_ptr GlobalSemaphore::create( Device* device, CoreRangeSet&& cores, uint32_t initial_value, BufferType buffer_type) { return std::make_unique(device, std::move(cores), initial_value, buffer_type); } +Device* GlobalSemaphore::device() const { return device_; } + DeviceAddr GlobalSemaphore::address() const { return buffer_->address(); } void GlobalSemaphore::reset_semaphore_value() { diff --git a/tt_metal/impl/buffers/global_semaphore.hpp b/tt_metal/impl/buffers/global_semaphore.hpp index 5eb5489fe9a..6c2f8d17947 100644 --- a/tt_metal/impl/buffers/global_semaphore.hpp +++ b/tt_metal/impl/buffers/global_semaphore.hpp @@ -32,12 +32,14 @@ class GlobalSemaphore { GlobalSemaphore(GlobalSemaphore&&) noexcept = default; GlobalSemaphore& operator=(GlobalSemaphore&&) noexcept = default; - static std::unique_ptr create( + static std::shared_ptr create( Device* device, const CoreRangeSet& cores, uint32_t initial_value, BufferType buffer_type = BufferType::L1); - static std::unique_ptr create( + static std::shared_ptr create( Device* device, CoreRangeSet&& cores, uint32_t initial_value, BufferType buffer_type = BufferType::L1); + Device* device() const; + DeviceAddr address() const; void reset_semaphore_value(); diff --git a/tt_metal/impl/device/device.hpp b/tt_metal/impl/device/device.hpp index d548184bc79..045a1097aac 100644 --- a/tt_metal/impl/device/device.hpp +++ b/tt_metal/impl/device/device.hpp @@ -377,7 +377,7 @@ class Device { SubDeviceManagerId get_active_sub_device_manager_id() const; SubDeviceManagerId get_default_sub_device_manager_id() const; - SubDeviceManagerId create_sub_device_manager(tt::stl::Span sub_devices, DeviceAddr mesh_l1_size); + SubDeviceManagerId create_sub_device_manager(tt::stl::Span sub_devices, DeviceAddr local_l1_size); void load_sub_device_manager(SubDeviceManagerId sub_device_manager_id); void clear_loaded_sub_device_manager(); void remove_sub_device_manager(SubDeviceManagerId sub_device_manager_id); diff --git a/tt_metal/tt_metal.cpp b/tt_metal/tt_metal.cpp index 0729c27c390..ff1987983e9 100644 --- a/tt_metal/tt_metal.cpp +++ b/tt_metal/tt_metal.cpp @@ -1161,12 +1161,12 @@ uint32_t CreateSemaphore( core_spec); } -std::unique_ptr CreateGlobalSemaphore( +std::shared_ptr CreateGlobalSemaphore( Device* device, const CoreRangeSet& cores, uint32_t initial_value, BufferType buffer_type) { return GlobalSemaphore::create(device, cores, initial_value, buffer_type); } -std::unique_ptr CreateGlobalSemaphore( +std::shared_ptr CreateGlobalSemaphore( Device* device, CoreRangeSet&& cores, uint32_t initial_value, BufferType buffer_type) { return GlobalSemaphore::create(device, std::move(cores), initial_value, buffer_type); } diff --git a/ttnn/CMakeLists.txt b/ttnn/CMakeLists.txt index 8f2baff00f1..2f432ee54e0 100644 --- a/ttnn/CMakeLists.txt +++ b/ttnn/CMakeLists.txt @@ -4,6 +4,8 @@ set(ALL_TTNN_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/core.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/device.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/events.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/global_circular_buffer.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/global_semaphore.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/run_operation.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/distributed/api.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/distributed/distributed_pybind.cpp @@ -586,6 +588,8 @@ set(TTNN_SRC) set(PYBIND_SRC ${CMAKE_CURRENT_SOURCE_DIR}/cpp/pybind11/events.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/cpp/pybind11/global_circular_buffer.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/cpp/pybind11/global_semaphore.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/pybind11/device.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/pybind11/profiler.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/pybind11/tensor.cpp diff --git a/ttnn/cpp/pybind11/__init__.cpp b/ttnn/cpp/pybind11/__init__.cpp index 639498cd807..2f0a322614c 100644 --- a/ttnn/cpp/pybind11/__init__.cpp +++ b/ttnn/cpp/pybind11/__init__.cpp @@ -12,6 +12,8 @@ #include "device.hpp" #include "profiler.hpp" #include "events.hpp" +#include "global_circular_buffer.hpp" +#include "global_semaphore.hpp" #include "tensor.hpp" #include "reports.hpp" #include "ttnn/distributed/distributed_pybind.hpp" @@ -43,6 +45,8 @@ PYBIND11_MODULE(_ttnn, module) { auto m_device = module.def_submodule("device", "ttnn devices"); auto m_multi_device = module.def_submodule("multi_device", "ttnn multi_device"); auto m_events = module.def_submodule("events", "ttnn events"); + auto m_global_circular_buffer = module.def_submodule("global_circular_buffer", "ttnn global circular buffer"); + auto m_global_semaphore = module.def_submodule("global_semaphore", "ttnn global semaphore"); auto m_profiler = module.def_submodule("profiler", "Submodule defining the profiler"); auto m_reports = module.def_submodule("reports", "ttnn reports"); auto m_operations = module.def_submodule("operations", "ttnn Operations"); @@ -58,6 +62,8 @@ PYBIND11_MODULE(_ttnn, module) { ttnn::device::py_device_module_types(m_device); ttnn::distributed::py_module_types(m_multi_device); ttnn::events::py_module_types(m_events); + ttnn::global_circular_buffer::py_module_types(m_global_circular_buffer); + ttnn::global_semaphore::py_module_types(m_global_semaphore); ttnn::reports::py_module_types(m_reports); // FUNCTIONS / OPERATIONS @@ -79,6 +85,8 @@ PYBIND11_MODULE(_ttnn, module) { ttnn::device::py_device_module(m_device); ttnn::distributed::py_module(m_multi_device); ttnn::events::py_module(m_events); + ttnn::global_circular_buffer::py_module(m_global_circular_buffer); + ttnn::global_semaphore::py_module(m_global_semaphore); ttnn::profiler::py_module(m_profiler); ttnn::reports::py_module(m_reports); diff --git a/ttnn/cpp/pybind11/device.cpp b/ttnn/cpp/pybind11/device.cpp index 198326fd7d7..b60a36ed7ad 100644 --- a/ttnn/cpp/pybind11/device.cpp +++ b/ttnn/cpp/pybind11/device.cpp @@ -97,9 +97,23 @@ void py_device_module_types(py::module& m_device) { py::class_>( m_device, "Device", "Class describing a Tenstorrent accelerator device."); + + py::class_(m_device, "SubDevice", "Class describing a sub-device of a Tenstorrent accelerator device."); + + py::class_(m_device, "SubDeviceManagerId", "ID of a sub-device manager."); } void device_module(py::module& m_device) { + auto pySubDevice = static_cast>(m_device.attr("SubDevice")); + pySubDevice.def( + py::init<>([](std::vector cores) { return SubDevice(cores); }), + py::arg("cores"), + R"doc( + Creates a SubDevice object from a list of CoreRangeSet objects, where each CoreRangeSet object + represents the cores from a specific CoreType. + The order of cores is Tensix, then Ethernet. + )doc"); + auto pyDevice = static_cast>>(m_device.attr("Device")); pyDevice .def( @@ -133,7 +147,66 @@ void device_module(py::module& m_device) { "num_program_cache_entries", &Device::num_program_cache_entries, "Number of entries in the program cache for this device") - .def("enable_async", &Device::enable_async); + .def("enable_async", &Device::enable_async) + .def( + "create_sub_device_manager", + [](Device* device, + const std::vector& sub_devices, + DeviceAddr local_l1_size) -> SubDeviceManagerId { + SubDeviceManagerId sub_device_manager_id; + device->push_work( + [device, sub_devices, local_l1_size, &sub_device_manager_id] { + sub_device_manager_id = device->create_sub_device_manager(sub_devices, local_l1_size); + }, + true); + return sub_device_manager_id; + }, + py::arg("sub_devices"), + py::arg("local_l1_size"), + R"doc( + Creates a sub-device manager for the given device. + + Args: + sub_devices (List[ttnn.SubDevice]): The sub-devices to include in the sub-device manager. + local_l1_size (int): The size of the local allocators of each sub-device. The global allocator will be shrunk by this amount. + + Returns: + SubDeviceManagerId: The ID of the created sub-device manager. + )doc") + .def( + "load_sub_device_manager", + [](Device* device, SubDeviceManagerId sub_device_manager_id) { + device->push_work([device, sub_device_manager_id] { + device->push_work( + [device, sub_device_manager_id] { device->load_sub_device_manager(sub_device_manager_id); }); + }); + }, + py::arg("sub_device_manager_id"), + R"doc( + Loads the sub-device manager with the given ID. + + Args: + sub_device_manager_id (SubDeviceManagerId): The ID of the sub-device manager to load. + )doc") + .def( + "clear_loaded_sub_device_manager", + [](Device* device) { device->push_work([device] { device->clear_loaded_sub_device_manager(); }); }, + R"doc( + Clears the loaded sub-device manager for the given device. + )doc") + .def( + "remove_sub_device_manager", + [](Device* device, SubDeviceManagerId sub_device_manager_id) { + device->push_work( + [device, sub_device_manager_id] { device->remove_sub_device_manager(sub_device_manager_id); }); + }, + py::arg("sub_device_manager_id"), + R"doc( + Removes the sub-device manager with the given ID. + + Args: + sub_device_manager_id (SubDeviceManagerId): The ID of the sub-device manager to remove. + )doc"); // *** eps constant *** m_device.attr("EPS_GS") = EPS_GS; m_device.attr("EPS_WHB0") = EPS_WHB0; diff --git a/ttnn/cpp/pybind11/global_circular_buffer.cpp b/ttnn/cpp/pybind11/global_circular_buffer.cpp new file mode 100644 index 00000000000..f736ee99781 --- /dev/null +++ b/ttnn/cpp/pybind11/global_circular_buffer.cpp @@ -0,0 +1,58 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#include "global_circular_buffer.hpp" + +#include "tt_metal/impl/buffers/global_circular_buffer.hpp" +#include "ttnn/cpp/ttnn/global_circular_buffer.hpp" +#include "pybind11/pybind11.h" + +namespace ttnn::global_circular_buffer { + +void py_module_types(py::module& module) { + py::class_>(module, "global_circular_buffer"); + py::class_(module, "multi_device_global_circular_buffer"); +} + +void py_module(py::module& module) { + // Single Device APIs + module.def( + "create_global_circular_buffer", + py::overload_cast&, uint32_t, BufferType>( + &create_global_circular_buffer), + py::arg("device"), + py::arg("sender_receiver_core_mapping"), + py::arg("size"), + py::arg("buffer_type") = tt::tt_metal::BufferType::L1, + R"doc( + Create a GlobalCircularBuffer Object on a single device. + + Args: + device (Device): The device on which to create the global circular buffer. + sender_receiver_core_mapping (dict): The mapping of remote sender to remote receiver cores for the circular buffer. + size (int): Size of the global circular buffer per core in bytes. + buffer_type (BufferType): The type of buffer to use for the global circular buffer. + )doc"); + + // Multi Device APIs + module.def( + "create_global_circular_buffer", + py::overload_cast&, uint32_t, BufferType>( + &create_global_circular_buffer), + py::arg("mesh_device"), + py::arg("sender_receiver_core_mapping"), + py::arg("size"), + py::arg("buffer_type") = tt::tt_metal::BufferType::L1, + R"doc( + Create a GlobalCircularBuffer Object on a single device. + + Args: + mesh_device (MeshDevice): The mesh device on which to create the global circular buffer. + sender_receiver_core_mapping (dict): The mapping of remote sender to remote receiver cores for the circular buffer. + size (int): Size of the global circular buffer per core in bytes. + buffer_type (BufferType): The type of buffer to use for the global circular buffer. + )doc"); +} + +} // namespace ttnn::global_circular_buffer diff --git a/ttnn/cpp/pybind11/global_circular_buffer.hpp b/ttnn/cpp/pybind11/global_circular_buffer.hpp new file mode 100644 index 00000000000..82405656853 --- /dev/null +++ b/ttnn/cpp/pybind11/global_circular_buffer.hpp @@ -0,0 +1,17 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "pybind11/pybind_fwd.hpp" +#include "ttnn/global_circular_buffer.hpp" + +namespace py = pybind11; + +namespace ttnn::global_circular_buffer { + +void py_module_types(py::module& module); +void py_module(py::module& module); + +} // namespace ttnn::global_circular_buffer diff --git a/ttnn/cpp/pybind11/global_semaphore.cpp b/ttnn/cpp/pybind11/global_semaphore.cpp new file mode 100644 index 00000000000..79a97de58df --- /dev/null +++ b/ttnn/cpp/pybind11/global_semaphore.cpp @@ -0,0 +1,100 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#include "global_semaphore.hpp" + +#include "tt_metal/impl/buffers/global_semaphore.hpp" +#include "ttnn/cpp/ttnn/global_semaphore.hpp" +#include "pybind11/pybind11.h" + +namespace ttnn::global_semaphore { + +void py_module_types(py::module& module) { + py::class_>(module, "global_sempahore"); + py::class_(module, "multi_device_global_semaphore"); +} + +void py_module(py::module& module) { + // Single Device APIs + module.def( + "create_global_semaphore", + py::overload_cast(&create_global_semaphore), + py::arg("device"), + py::arg("cores"), + py::arg("initial_value"), + py::arg("buffer_type") = tt::tt_metal::BufferType::L1, + R"doc( + Create a GlobalSemaphore Object on a single device. + + Args: + device (Device): The device on which to create the global semaphore. + cores (CoreRangeSet): The cores on which the global semaphore will be used for synchronization. + initial_value (int): The initial value of the global semaphore. + buffer_type (BufferType): The type of buffer to use for the global semaphore. + )doc"); + + module.def( + "get_global_semaphore_address", + py::overload_cast&>(&get_global_semaphore_address), + py::arg("global_semaphore"), + R"doc( + Get the address of the global semaphore. + + Args: + global_semaphore (GlobalSemaphore): The global semaphore object. + )doc"); + + module.def( + "reset_global_semaphore_value", + py::overload_cast&>(&reset_global_semaphore_value), + py::arg("global_semaphore"), + R"doc( + Reset the value of the global semaphore. + + Args: + global_semaphore (GlobalSemaphore): The global semaphore object. + )doc"); + + // Multi Device APIs + module.def( + "create_global_semaphore", + py::overload_cast(&create_global_semaphore), + py::arg("mesh_device"), + py::arg("cores"), + py::arg("initial_value"), + py::arg("buffer_type") = tt::tt_metal::BufferType::L1, + R"doc( + Create a GlobalSemaphore Object on a single device. + + Args: + mesh_device (MeshDevice): The mesh device on which to create the global semaphore. + cores (CoreRangeSet): The cores on which the global semaphore will be used for synchronization. + initial_value (int): The initial value of the global semaphore. + buffer_type (BufferType): The type of buffer to use for the global semaphore. + )doc"); + + module.def( + "get_global_semaphore_address", + py::overload_cast(&get_global_semaphore_address), + py::arg("global_semaphore"), + R"doc( + Get the address of the global semaphore. + + Args: + global_semaphore (GlobalSemaphore): The global semaphore object. + )doc"); + + module.def( + "reset_global_semaphore_value", + py::overload_cast(&reset_global_semaphore_value), + py::arg("global_semaphore"), + R"doc( + Reset the value of the global semaphore. + + Args: + global_semaphore (GlobalSemaphore): The global semaphore object. + )doc"); +} + +} // namespace ttnn::global_semaphore diff --git a/ttnn/cpp/pybind11/global_semaphore.hpp b/ttnn/cpp/pybind11/global_semaphore.hpp new file mode 100644 index 00000000000..33dedd67ec7 --- /dev/null +++ b/ttnn/cpp/pybind11/global_semaphore.hpp @@ -0,0 +1,17 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "pybind11/pybind_fwd.hpp" +#include "ttnn/global_semaphore.hpp" + +namespace py = pybind11; + +namespace ttnn::global_semaphore { + +void py_module_types(py::module& module); +void py_module(py::module& module); + +} // namespace ttnn::global_semaphore diff --git a/ttnn/cpp/ttnn/device.hpp b/ttnn/cpp/ttnn/device.hpp index b0f19e5e746..d38fd1eb363 100644 --- a/ttnn/cpp/ttnn/device.hpp +++ b/ttnn/cpp/ttnn/device.hpp @@ -5,6 +5,7 @@ #pragma once #include "ttnn/types.hpp" + namespace ttnn { namespace device { diff --git a/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp b/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp index 8bb6ca1d9c3..ec7e8e4691f 100644 --- a/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp +++ b/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp @@ -17,7 +17,10 @@ namespace ttnn::distributed { namespace py = pybind11; -void py_module_types(py::module& module) { py::class_>(module, "MeshDevice"); } +void py_module_types(py::module& module) { + py::class_>(module, "MeshDevice"); + py::class_(module, "MeshSubDeviceManagerId"); +} void py_module(py::module& module) { py::enum_(module, "MeshType") @@ -137,7 +140,50 @@ void py_module(py::module& module) { Returns: Tuple[int, int]: The shape of the device mesh as (num_rows, num_cols). )doc") - .def("__repr__", &MeshDevice::to_string); + .def("__repr__", &MeshDevice::to_string) + .def( + "create_sub_device_manager", + [](MeshDevice& self, const std::vector& sub_devices, DeviceAddr local_l1_size) { + return self.create_sub_device_manager(sub_devices, local_l1_size); + }, + py::arg("sub_devices"), + py::arg("local_l1_size"), + R"doc( + Creates a sub-device manager for the given mesh device. + + Args: + sub_devices (List[ttnn.SubDevice]): The sub-devices to include in the sub-device manager. + local_l1_size (int): The size of the local allocators of each sub-device. The global allocator will be shrunk by this amount. + + Returns: + MeshSubDeviceManagerId: The ID of the created sub-device manager. + )doc") + .def( + "load_sub_device_manager", + &MeshDevice::load_sub_device_manager, + py::arg("mesh_sub_device_manager_id"), + R"doc( + Loads the sub-device manager with the given ID. + + Args: + mesh_sub_device_manager_id (MeshSubDeviceManagerId): The ID of the sub-device manager to load. + )doc") + .def( + "clear_loaded_sub_device_manager", + &MeshDevice::clear_loaded_sub_device_manager, + R"doc( + Clears the loaded sub-device manager for the given mesh device. + )doc") + .def( + "remove_sub_device_manager", + &MeshDevice::remove_sub_device_manager, + py::arg("mesh_sub_device_manager_id"), + R"doc( + Removes the sub-device manager with the given ID. + + Args: + mesh_sub_device_manager_id (MeshSubDeviceManagerId): The ID of the sub-device manager to remove. + )doc"); module.def( "open_mesh_device", diff --git a/ttnn/cpp/ttnn/distributed/types.hpp b/ttnn/cpp/ttnn/distributed/types.hpp index 421c022fcef..557d10c90ec 100644 --- a/ttnn/cpp/ttnn/distributed/types.hpp +++ b/ttnn/cpp/ttnn/distributed/types.hpp @@ -18,6 +18,7 @@ using MeshDevice = tt::tt_metal::distributed::MeshDevice; using MeshDeviceView = tt::tt_metal::distributed::MeshDeviceView; using MeshType = tt::tt_metal::distributed::MeshType; using MeshDeviceConfig = tt::tt_metal::distributed::MeshDeviceConfig; +using MeshSubDeviceManagerId = tt::tt_metal::distributed::MeshSubDeviceManagerId; } // namespace ttnn::distributed @@ -29,6 +30,7 @@ using ttnn::distributed::MeshDevice; using ttnn::distributed::MeshDeviceConfig; using ttnn::distributed::MeshDeviceView; using ttnn::distributed::MeshShape; +using ttnn::distributed::MeshSubDeviceManagerId; using ttnn::distributed::MeshType; } // namespace ttnn diff --git a/ttnn/cpp/ttnn/global_circular_buffer.cpp b/ttnn/cpp/ttnn/global_circular_buffer.cpp new file mode 100644 index 00000000000..7c5967fa3c2 --- /dev/null +++ b/ttnn/cpp/ttnn/global_circular_buffer.cpp @@ -0,0 +1,56 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#include "global_circular_buffer.hpp" + +#include +#include "tt_metal/impl/buffers/global_circular_buffer.hpp" +#include "tt_metal/include/tt_metal/global_circular_buffer.hpp" + +namespace ttnn::global_circular_buffer { + +MultiDeviceGlobalCircularBuffer::MultiDeviceGlobalCircularBuffer(MeshDevice* mesh_device) { + TT_ASSERT( + mesh_device != nullptr, + "Must provide a valid mesh_device when initializing a global circular buffer on multiple devices."); + this->global_circular_buffers = std::vector>(mesh_device->num_devices()); +} + +std::shared_ptr create_global_circular_buffer( + Device* device, + const std::unordered_map& sender_receiver_core_mapping, + uint32_t size, + BufferType buffer_type) { + std::shared_ptr global_cb; + device->push_work( + [device, &sender_receiver_core_mapping, size, buffer_type, &global_cb]() { + global_cb = tt::tt_metal::v1::experimental::CreateGlobalCircularBuffer( + device, sender_receiver_core_mapping, size, buffer_type); + }, + /*blocking=*/true); + return global_cb; +} + +MultiDeviceGlobalCircularBuffer create_global_circular_buffer( + MeshDevice* mesh_device, + const std::unordered_map& sender_receiver_core_mapping, + uint32_t size, + BufferType buffer_type) { + MultiDeviceGlobalCircularBuffer multi_device_global_cb(mesh_device); + const auto& devices = mesh_device->get_devices(); + for (uint32_t i = 0; i < devices.size(); ++i) { + auto* device = devices[i]; + auto& global_cb = multi_device_global_cb.global_circular_buffers[i]; + device->push_work([device, &sender_receiver_core_mapping, size, buffer_type, &global_cb]() { + global_cb = tt::tt_metal::v1::experimental::CreateGlobalCircularBuffer( + device, sender_receiver_core_mapping, size, buffer_type); + }); + } + for (auto* device : devices) { + device->synchronize(); + } + return multi_device_global_cb; +} + +} // namespace ttnn::global_circular_buffer diff --git a/ttnn/cpp/ttnn/global_circular_buffer.hpp b/ttnn/cpp/ttnn/global_circular_buffer.hpp new file mode 100644 index 00000000000..bb84ce3a7ab --- /dev/null +++ b/ttnn/cpp/ttnn/global_circular_buffer.hpp @@ -0,0 +1,32 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include +#include "tt_metal/include/tt_metal/global_circular_buffer.hpp" +#include "ttnn/types.hpp" + +namespace ttnn::global_circular_buffer { + +struct MultiDeviceGlobalCircularBuffer { + MultiDeviceGlobalCircularBuffer(MeshDevice* mesh_device); + std::vector> global_circular_buffers; +}; + +// Single Device APIs +std::shared_ptr create_global_circular_buffer( + Device* device, + const std::unordered_map& sender_receiver_core_mapping, + uint32_t size, + BufferType buffer_type = BufferType::L1); + +// Multi Device APIs +MultiDeviceGlobalCircularBuffer create_global_circular_buffer( + MeshDevice* mesh_device, + const std::unordered_map& sender_receiver_core_mapping, + uint32_t size, + BufferType buffer_type = BufferType::L1); + +} // namespace ttnn::global_circular_buffer diff --git a/ttnn/cpp/ttnn/global_semaphore.cpp b/ttnn/cpp/ttnn/global_semaphore.cpp new file mode 100644 index 00000000000..da1ebf8f0f0 --- /dev/null +++ b/ttnn/cpp/ttnn/global_semaphore.cpp @@ -0,0 +1,80 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#include "global_semaphore.hpp" + +#include +#include "tt_metal/impl/buffers/global_semaphore.hpp" +#include "tt_metal/host_api.hpp" + +namespace ttnn::global_semaphore { + +MultiDeviceGlobalSemaphore::MultiDeviceGlobalSemaphore(MeshDevice* mesh_device) { + TT_ASSERT( + mesh_device != nullptr, + "Must provide a valid mesh_device when initializing a global semaphore on multiple devices."); + this->global_semaphores = std::vector>(mesh_device->num_devices()); +} + +std::shared_ptr create_global_semaphore( + Device* device, const CoreRangeSet& cores, uint32_t initial_value, BufferType buffer_type) { + std::shared_ptr global_semaphore = nullptr; + device->push_work( + [device, &cores, initial_value, buffer_type, &global_semaphore] { + global_semaphore = GlobalSemaphore::create(device, cores, initial_value, buffer_type); + }, + /*blocking=*/true); + return global_semaphore; +} + +DeviceAddr get_global_semaphore_address(const std::shared_ptr& global_semaphore) { + auto* device = global_semaphore->device(); + DeviceAddr address = 0; + device->push_work([&global_semaphore, &address] { address = global_semaphore->address(); }, /*blocking=*/true); + return address; +} + +void reset_global_semaphore_value(const std::shared_ptr& global_semaphore) { + auto* device = global_semaphore->device(); + device->push_work([global_semaphore] { global_semaphore->reset_semaphore_value(); }); +} + +MultiDeviceGlobalSemaphore create_global_semaphore( + MeshDevice* mesh_device, const CoreRangeSet& cores, uint32_t initial_value, BufferType buffer_type) { + MultiDeviceGlobalSemaphore multi_device_global_semaphore(mesh_device); + const auto& devices = mesh_device->get_devices(); + for (uint32_t i = 0; i < devices.size(); ++i) { + auto* device = devices[i]; + auto& global_semaphore = multi_device_global_semaphore.global_semaphores[i]; + device->push_work([device, &cores, initial_value, buffer_type, &global_semaphore] { + global_semaphore = GlobalSemaphore::create(device, cores, initial_value, buffer_type); + }); + } + for (auto device : devices) { + device->synchronize(); + } + return multi_device_global_semaphore; +} +std::vector get_global_semaphore_address(const MultiDeviceGlobalSemaphore& global_semaphore) { + std::vector addresses(global_semaphore.global_semaphores.size()); + const auto& global_semaphores = global_semaphore.global_semaphores; + for (uint32_t i = 0; i < global_semaphores.size(); ++i) { + const auto& global_semaphore = global_semaphores[i]; + auto& address = addresses[i]; + auto* device = global_semaphore->device(); + device->push_work([&global_semaphore, &address] { address = global_semaphore->address(); }); + } + for (const auto& global_semaphore : global_semaphores) { + global_semaphore->device()->synchronize(); + } + return addresses; +} + +void reset_global_semaphore_value(const MultiDeviceGlobalSemaphore& global_semaphore) { + for (const auto& global_semaphore : global_semaphore.global_semaphores) { + reset_global_semaphore_value(global_semaphore); + } +} + +} // namespace ttnn::global_semaphore diff --git a/ttnn/cpp/ttnn/global_semaphore.hpp b/ttnn/cpp/ttnn/global_semaphore.hpp new file mode 100644 index 00000000000..70f56fb4b4c --- /dev/null +++ b/ttnn/cpp/ttnn/global_semaphore.hpp @@ -0,0 +1,33 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include +#include "tt_metal/host_api.hpp" +#include "ttnn/types.hpp" + +namespace ttnn::global_semaphore { + +struct MultiDeviceGlobalSemaphore { + MultiDeviceGlobalSemaphore(MeshDevice* mesh_device); + std::vector> global_semaphores; +}; + +// Single Device APIs +std::shared_ptr create_global_semaphore( + Device* device, const CoreRangeSet& cores, uint32_t initial_value, BufferType buffer_type = BufferType::L1); +DeviceAddr get_global_semaphore_address(const std::shared_ptr& global_semaphore); +void reset_global_semaphore_value(const std::shared_ptr& global_semaphore); + +// Multi Device APIs +MultiDeviceGlobalSemaphore create_global_semaphore( + MeshDevice* mesh_device, + const CoreRangeSet& cores, + uint32_t initial_value, + BufferType buffer_type = BufferType::L1); +std::vector get_global_semaphore_address(const MultiDeviceGlobalSemaphore& global_semaphore); +void reset_global_semaphore_value(const MultiDeviceGlobalSemaphore& global_semaphore); + +} // namespace ttnn::global_semaphore diff --git a/ttnn/cpp/ttnn/types.hpp b/ttnn/cpp/ttnn/types.hpp index d35328d6946..0d67280b5e5 100644 --- a/ttnn/cpp/ttnn/types.hpp +++ b/ttnn/cpp/ttnn/types.hpp @@ -6,6 +6,9 @@ #include "tt_metal/detail/tt_metal.hpp" #include "tt_metal/impl/allocator/allocator.hpp" +#include "tt_metal/impl/buffers/global_circular_buffer.hpp" +#include "tt_metal/impl/buffers/global_semaphore.hpp" +#include "tt_metal/impl/sub_device/sub_device.hpp" #include "ttnn/distributed/types.hpp" #include "ttnn/tensor/tensor.hpp" #include "ttnn/tensor/types.hpp" @@ -55,6 +58,11 @@ static std::ostream& operator<<(std::ostream& os, const CoreGrid& core_grid) { return os; } +using tt::tt_metal::GlobalSemaphore; +using tt::tt_metal::SubDevice; +using tt::tt_metal::SubDeviceManagerId; +using tt::tt_metal::v1::experimental::GlobalCircularBuffer; + } // namespace types using namespace types; diff --git a/ttnn/ttnn/__init__.py b/ttnn/ttnn/__init__.py index 2a987ca2dcd..71f4f748660 100644 --- a/ttnn/ttnn/__init__.py +++ b/ttnn/ttnn/__init__.py @@ -102,6 +102,16 @@ def manage_config(name, value): from ttnn._ttnn.events import create_event, record_event, wait_for_event +from ttnn._ttnn.global_circular_buffer import ( + create_global_circular_buffer, +) + +from ttnn._ttnn.global_semaphore import ( + create_global_semaphore, + get_global_semaphore_address, + reset_global_semaphore_value, +) + from ttnn.types import ( TILE_SIZE, DataType, @@ -175,6 +185,8 @@ def manage_config(name, value): format_input_tensor, format_output_tensor, pad_to_tile_shape, + SubDevice, + SubDeviceManagerId, ) from ttnn.profiler import start_tracy_zone, stop_tracy_zone, tracy_message, tracy_frame diff --git a/ttnn/ttnn/device.py b/ttnn/ttnn/device.py index 894fa137f9d..e620c800a6c 100644 --- a/ttnn/ttnn/device.py +++ b/ttnn/ttnn/device.py @@ -146,5 +146,7 @@ def is_grayskull(device): format_output_tensor = ttnn._ttnn.device.format_output_tensor pad_to_tile_shape = ttnn._ttnn.device.pad_to_tile_shape +SubDevice = ttnn._ttnn.device.SubDevice +SubDeviceManagerId = ttnn._ttnn.device.SubDeviceManagerId __all__ = []