From c230ae596c69e5b7f91a0d287637faf058aed517 Mon Sep 17 00:00:00 2001 From: Austin Ho Date: Tue, 10 Dec 2024 19:02:19 +0000 Subject: [PATCH] #0: Update global cb, sem apis to take in sub_device_ids to know what to stall on when writing to device --- .../tt_metal/api/test_global_semaphores.cpp | 8 ++-- tt_metal/host_api.hpp | 38 ++++++++++------ .../impl/buffers/global_circular_buffer.cpp | 24 +++++++--- .../impl/buffers/global_circular_buffer.hpp | 10 +++-- tt_metal/impl/buffers/global_semaphore.cpp | 45 +++++++++++++------ tt_metal/impl/buffers/global_semaphore.hpp | 29 +++++++++--- .../tt_metal/global_circular_buffer.hpp | 5 ++- tt_metal/tt_metal.cpp | 21 ++++++--- ttnn/cpp/pybind11/global_circular_buffer.cpp | 26 +++++++++-- ttnn/cpp/pybind11/global_semaphore.cpp | 35 +++++++++++++-- ttnn/cpp/ttnn/global_circular_buffer.cpp | 14 +++--- ttnn/cpp/ttnn/global_circular_buffer.hpp | 6 ++- ttnn/cpp/ttnn/global_semaphore.cpp | 33 +++++++++----- ttnn/cpp/ttnn/global_semaphore.hpp | 15 +++++-- 14 files changed, 225 insertions(+), 84 deletions(-) diff --git a/tests/tt_metal/tt_metal/api/test_global_semaphores.cpp b/tests/tt_metal/tt_metal/api/test_global_semaphores.cpp index 7417bdd13df..f592258864f 100644 --- a/tests/tt_metal/tt_metal/api/test_global_semaphores.cpp +++ b/tests/tt_metal/tt_metal/api/test_global_semaphores.cpp @@ -21,7 +21,7 @@ TEST_F(DispatchFixture, InitializeGlobalSemaphores) { uint32_t initial_value = 1; auto global_semaphore = tt::tt_metal::CreateGlobalSemaphore(device, cores, initial_value); auto address = global_semaphore->address(); - + Synchronize(device); for (const auto& core : cores_vec) { auto sem_vals = tt::llrt::read_hex_vec_from_core( device->id(), device->worker_core_from_logical_core(core), address, sizeof(uint32_t)); @@ -33,7 +33,7 @@ TEST_F(DispatchFixture, InitializeGlobalSemaphores) { uint32_t initial_value = 2; auto global_semaphore = tt::tt_metal::CreateGlobalSemaphore(device, cores, initial_value); auto address = 
global_semaphore->address(); - + Synchronize(device); for (const auto& core : cores_vec) { auto sem_vals = tt::llrt::read_hex_vec_from_core( device->id(), device->worker_core_from_logical_core(core), address, sizeof(uint32_t)); @@ -61,6 +61,7 @@ TEST_F(DispatchFixture, CreateMultipleGlobalSemaphoresOnSameCore) { global_semaphores.push_back(tt::tt_metal::CreateGlobalSemaphore(device, cores[i], initial_values[i])); addresses.push_back(global_semaphores[i]->address()); } + Synchronize(device); for (size_t i = 0; i < cores.size(); i++) { const auto& address = addresses[i]; const auto& initial_value = initial_values[i]; @@ -85,7 +86,7 @@ TEST_F(DispatchFixture, ResetGlobalSemaphores) { std::vector overwrite_value = {2}; auto global_semaphore = tt::tt_metal::CreateGlobalSemaphore(device, cores, initial_value); auto address = global_semaphore->address(); - + Synchronize(device); for (const auto& core : cores_vec) { auto sem_vals = tt::llrt::read_hex_vec_from_core( device->id(), device->worker_core_from_logical_core(core), address, sizeof(uint32_t)); @@ -101,6 +102,7 @@ TEST_F(DispatchFixture, ResetGlobalSemaphores) { EXPECT_EQ(sem_vals[0], overwrite_value[0]); } global_semaphore->reset_semaphore_value(); + Synchronize(device); for (const auto& core : cores_vec) { auto sem_vals = tt::llrt::read_hex_vec_from_core( device->id(), device->worker_core_from_logical_core(core), address, sizeof(uint32_t)); diff --git a/tt_metal/host_api.hpp b/tt_metal/host_api.hpp index be3b0e5fcad..8a9dfcb2003 100644 --- a/tt_metal/host_api.hpp +++ b/tt_metal/host_api.hpp @@ -299,16 +299,21 @@ uint32_t CreateSemaphore( * * Return value: std::shared_ptr * - * | Argument | Description | Type | Valid Range | Required | - * |---------------|------------------------------------------------------|-----------------------------------------------------------|--------------|----------| - * | device | The device to create the semaphore on | Device * | | Yes | - * | cores | Range of the Tensix co-ordinates 
using the semaphore | const CoreRangeSet & | | Yes | - * | initial_value | Initial value of the semaphore | uint32_t | | Yes | - * | buffer_type | Buffer type to store the semaphore | BufferType | L1 types | No | + * | Argument | Description | Type | Valid Range | Required | + * |----------------|--------------------------------------------------------|-----------------------------------------------------------|--------------|----------| + * | device | The device to create the semaphore on | Device * | | Yes | + * | cores | Range of the Tensix co-ordinates using the semaphore | const CoreRangeSet & | | Yes | + * | initial_value | Initial value of the semaphore | uint32_t | | Yes | + * | buffer_type | Buffer type to store the semaphore | BufferType | L1 types | No | + * | sub_device_ids | Sub-device ids to wait on before writing the semaphore | tt::stl::Span | | No | */ // clang-format on std::shared_ptr CreateGlobalSemaphore( - Device* device, const CoreRangeSet& cores, uint32_t initial_value, BufferType buffer_type = BufferType::L1); + Device* device, + const CoreRangeSet& cores, + uint32_t initial_value, + BufferType buffer_type = BufferType::L1, + tt::stl::Span sub_device_ids = {}); // clang-format off /** @@ -317,16 +322,21 @@ std::shared_ptr CreateGlobalSemaphore( * * Return value: std::shared_ptr * - * | Argument | Description | Type | Valid Range | Required | - * |---------------|------------------------------------------------------|-----------------------------------------------------------|--------------|----------| - * | device | The device to create the semaphore on | Device * | | Yes | - * | cores | Range of the Tensix co-ordinates using the semaphore | CoreRangeSet && | | Yes | - * | initial_value | Initial value of the semaphore | uint32_t | | Yes | - * | buffer_type | Buffer type to store the semaphore | BufferType | L1 types | No | + * | Argument | Description | Type | Valid Range | Required | + * 
|----------------|--------------------------------------------------------|-----------------------------------------------------------|--------------|----------| + * | device | The device to create the semaphore on | Device * | | Yes | + * | cores | Range of the Tensix co-ordinates using the semaphore | CoreRangeSet && | | Yes | + * | initial_value | Initial value of the semaphore | uint32_t | | Yes | + * | buffer_type | Buffer type to store the semaphore | BufferType | L1 types | No | + * | sub_device_ids | Sub-device ids to wait on before writing the semaphore | tt::stl::Span | | No | */ // clang-format on std::shared_ptr CreateGlobalSemaphore( - Device* device, CoreRangeSet&& cores, uint32_t initial_value, BufferType buffer_type = BufferType::L1); + Device* device, + CoreRangeSet&& cores, + uint32_t initial_value, + BufferType buffer_type = BufferType::L1, + tt::stl::Span sub_device_ids = {}); // clang-format off /** diff --git a/tt_metal/impl/buffers/global_circular_buffer.cpp b/tt_metal/impl/buffers/global_circular_buffer.cpp index 2d8760f1af5..3cb765abf60 100644 --- a/tt_metal/impl/buffers/global_circular_buffer.cpp +++ b/tt_metal/impl/buffers/global_circular_buffer.cpp @@ -27,7 +27,8 @@ GlobalCircularBuffer::GlobalCircularBuffer( Device* device, const std::unordered_map& sender_receiver_core_mapping, uint32_t size, - BufferType buffer_type) : + BufferType buffer_type, + tt::stl::Span sub_device_ids) : device_(device), sender_receiver_core_mapping_(sender_receiver_core_mapping), size_(size) { TT_FATAL(this->device_ != nullptr, "Device cannot be null"); uint32_t num_sender_cores = sender_receiver_core_mapping.size(); @@ -46,10 +47,11 @@ GlobalCircularBuffer::GlobalCircularBuffer( TT_FATAL(num_receiver_cores == this->receiver_cores_.num_cores(), "Duplicate receiver cores found"); this->all_cores_ = this->sender_cores_.merge(this->receiver_cores_); TT_FATAL(this->all_cores_.num_cores() == num_sender_cores + num_receiver_cores, "Duplicate cores found"); - 
this->setup_cb_buffers(buffer_type, max_num_receivers_per_sender); + this->setup_cb_buffers(buffer_type, max_num_receivers_per_sender, sub_device_ids); } -void GlobalCircularBuffer::setup_cb_buffers(BufferType buffer_type, uint32_t max_num_receivers_per_sender) { +void GlobalCircularBuffer::setup_cb_buffers( + BufferType buffer_type, uint32_t max_num_receivers_per_sender, tt::stl::Span sub_device_ids) { TT_FATAL( buffer_type == BufferType::L1 or buffer_type == BufferType::L1_SMALL, "Global circular buffer can only be created for L1 buffer types"); @@ -123,12 +125,18 @@ void GlobalCircularBuffer::setup_cb_buffers(BufferType buffer_type, uint32_t max } } - // Blocking write of cb config to buffer + // Write the config buffer to the device + // Only block for the slow dispatch case if (this->device_->using_slow_dispatch()) { detail::WriteToBuffer(*this->cb_config_buffer_, cb_config_host_buffer); tt::Cluster::instance().l1_barrier(this->device_->id()); } else { - EnqueueWriteBuffer(this->device_->command_queue(), this->cb_config_buffer_, cb_config_host_buffer.data(), true); + EnqueueWriteBuffer( + this->device_->command_queue(), + this->cb_config_buffer_, + cb_config_host_buffer.data(), + false, + sub_device_ids); } } @@ -136,8 +144,10 @@ std::shared_ptr GlobalCircularBuffer::create( Device* device, const std::unordered_map& sender_receiver_core_mapping, uint32_t size, - BufferType buffer_type) { - return std::make_unique(device, sender_receiver_core_mapping, size, buffer_type); + BufferType buffer_type, + tt::stl::Span sub_device_ids) { + return std::make_shared( + device, sender_receiver_core_mapping, size, buffer_type, sub_device_ids); } const Buffer& GlobalCircularBuffer::cb_buffer() const { return *this->cb_buffer_; } diff --git a/tt_metal/impl/buffers/global_circular_buffer.hpp b/tt_metal/impl/buffers/global_circular_buffer.hpp index d18ed91e0c4..ca0c56da71f 100644 --- a/tt_metal/impl/buffers/global_circular_buffer.hpp +++ 
b/tt_metal/impl/buffers/global_circular_buffer.hpp @@ -9,6 +9,7 @@ #include "tt_metal/common/core_coord.hpp" #include "tt_metal/impl/buffers/buffer_constants.hpp" +#include "tt_metal/impl/sub_device/sub_device_types.hpp" #include "tt_metal/llrt/hal.hpp" namespace tt::tt_metal { @@ -30,7 +31,8 @@ class GlobalCircularBuffer { Device* device, const std::unordered_map& sender_receiver_core_mapping, uint32_t size, - BufferType buffer_type); + BufferType buffer_type, + tt::stl::Span sub_device_ids); GlobalCircularBuffer(const GlobalCircularBuffer&) = default; GlobalCircularBuffer& operator=(const GlobalCircularBuffer&) = default; @@ -42,7 +44,8 @@ class GlobalCircularBuffer { Device* device, const std::unordered_map& sender_receiver_core_mapping, uint32_t size, - BufferType buffer_type = BufferType::L1); + BufferType buffer_type = BufferType::L1, + tt::stl::Span sub_device_ids = {}); const Buffer& cb_buffer() const; @@ -57,7 +60,8 @@ class GlobalCircularBuffer { const auto attribute_values() const { return std::make_tuple(this->sender_receiver_core_mapping_, this->size_); } private: - void setup_cb_buffers(BufferType buffer_type, uint32_t max_num_receivers_per_sender); + void setup_cb_buffers( + BufferType buffer_type, uint32_t max_num_receivers_per_sender, tt::stl::Span sub_device_ids); // GlobalCircularBuffer is implemented as a wrapper around a sharded buffer // This can be updated in the future to be its own container with optimized dispatch functions diff --git a/tt_metal/impl/buffers/global_semaphore.cpp b/tt_metal/impl/buffers/global_semaphore.cpp index 64d16beb377..7d523f64d12 100644 --- a/tt_metal/impl/buffers/global_semaphore.cpp +++ b/tt_metal/impl/buffers/global_semaphore.cpp @@ -20,17 +20,26 @@ namespace tt::tt_metal { GlobalSemaphore::GlobalSemaphore( - Device* device, const CoreRangeSet& cores, uint32_t initial_value, BufferType buffer_type) : + Device* device, + const CoreRangeSet& cores, + uint32_t initial_value, + BufferType buffer_type, + tt::stl::Span 
sub_device_ids) : device_(device), cores_(cores), initial_value_(initial_value) { - this->setup_buffer(buffer_type); + this->setup_buffer(buffer_type, sub_device_ids); } -GlobalSemaphore::GlobalSemaphore(Device* device, CoreRangeSet&& cores, uint32_t initial_value, BufferType buffer_type) : +GlobalSemaphore::GlobalSemaphore( + Device* device, + CoreRangeSet&& cores, + uint32_t initial_value, + BufferType buffer_type, + tt::stl::Span sub_device_ids) : device_(device), cores_(std::move(cores)), initial_value_(initial_value) { - this->setup_buffer(buffer_type); + this->setup_buffer(buffer_type, sub_device_ids); } -void GlobalSemaphore::setup_buffer(BufferType buffer_type) { +void GlobalSemaphore::setup_buffer(BufferType buffer_type, tt::stl::Span sub_device_ids) { TT_FATAL( buffer_type == BufferType::L1 or buffer_type == BufferType::L1_SMALL, "Global semaphore can only be created for L1 buffer types"); @@ -50,29 +59,39 @@ void GlobalSemaphore::setup_buffer(BufferType buffer_type) { std::nullopt); this->host_buffer_ = std::vector(num_cores, this->initial_value_); - this->reset_semaphore_value(); + this->reset_semaphore_value(sub_device_ids); } std::shared_ptr GlobalSemaphore::create( - Device* device, const CoreRangeSet& cores, uint32_t initial_value, BufferType buffer_type) { - return std::make_unique(device, cores, initial_value, buffer_type); + Device* device, + const CoreRangeSet& cores, + uint32_t initial_value, + BufferType buffer_type, + tt::stl::Span sub_device_ids) { + return std::make_shared(device, cores, initial_value, buffer_type, sub_device_ids); } std::shared_ptr GlobalSemaphore::create( - Device* device, CoreRangeSet&& cores, uint32_t initial_value, BufferType buffer_type) { - return std::make_unique(device, std::move(cores), initial_value, buffer_type); + Device* device, + CoreRangeSet&& cores, + uint32_t initial_value, + BufferType buffer_type, + tt::stl::Span sub_device_ids) { + return std::make_shared(device, std::move(cores), initial_value, 
buffer_type, sub_device_ids); } Device* GlobalSemaphore::device() const { return device_; } DeviceAddr GlobalSemaphore::address() const { return buffer_->address(); } -void GlobalSemaphore::reset_semaphore_value() { - // Blocking write of semaphore value to buffer +void GlobalSemaphore::reset_semaphore_value(tt::stl::Span sub_device_ids) { + // Write the semaphore's initial value to the device + // Only block for the slow dispatch case if (this->device_->using_slow_dispatch()) { detail::WriteToBuffer(*this->buffer_, this->host_buffer_); tt::Cluster::instance().l1_barrier(this->device_->id()); } else { - EnqueueWriteBuffer(this->device_->command_queue(), this->buffer_, this->host_buffer_.data(), true); + EnqueueWriteBuffer( + this->device_->command_queue(), this->buffer_, this->host_buffer_.data(), false, sub_device_ids); } } diff --git a/tt_metal/impl/buffers/global_semaphore.hpp b/tt_metal/impl/buffers/global_semaphore.hpp index f6d657998f8..0d912b2f9ac 100644 --- a/tt_metal/impl/buffers/global_semaphore.hpp +++ b/tt_metal/impl/buffers/global_semaphore.hpp @@ -9,6 +9,7 @@ #include "tt_metal/common/core_coord.hpp" #include "tt_metal/impl/buffers/buffer_constants.hpp" +#include "tt_metal/impl/sub_device/sub_device_types.hpp" #include "tt_metal/llrt/hal.hpp" namespace tt::tt_metal { @@ -21,10 +22,18 @@ class Device; class GlobalSemaphore { public: GlobalSemaphore( - Device* device, const CoreRangeSet& cores, uint32_t initial_value, BufferType buffer_type = BufferType::L1); + Device* device, + const CoreRangeSet& cores, + uint32_t initial_value, + BufferType buffer_type = BufferType::L1, + tt::stl::Span sub_device_ids = {}); GlobalSemaphore( - Device* device, CoreRangeSet&& cores, uint32_t initial_value, BufferType buffer_type = BufferType::L1); + Device* device, + CoreRangeSet&& cores, + uint32_t initial_value, + BufferType buffer_type = BufferType::L1, + tt::stl::Span sub_device_ids = {}); GlobalSemaphore(const GlobalSemaphore&) = default; GlobalSemaphore&
operator=(const GlobalSemaphore&) = default; @@ -33,22 +42,30 @@ class GlobalSemaphore { GlobalSemaphore& operator=(GlobalSemaphore&&) noexcept = default; static std::shared_ptr create( - Device* device, const CoreRangeSet& cores, uint32_t initial_value, BufferType buffer_type = BufferType::L1); + Device* device, + const CoreRangeSet& cores, + uint32_t initial_value, + BufferType buffer_type = BufferType::L1, + tt::stl::Span sub_device_ids = {}); static std::shared_ptr create( - Device* device, CoreRangeSet&& cores, uint32_t initial_value, BufferType buffer_type = BufferType::L1); + Device* device, + CoreRangeSet&& cores, + uint32_t initial_value, + BufferType buffer_type = BufferType::L1, + tt::stl::Span sub_device_ids = {}); Device* device() const; DeviceAddr address() const; - void reset_semaphore_value(); + void reset_semaphore_value(tt::stl::Span sub_device_ids = {}); static constexpr auto attribute_names = std::forward_as_tuple("cores", "initial_value"); const auto attribute_values() const { return std::make_tuple(this->cores_, this->initial_value_); } private: - void setup_buffer(BufferType buffer_type); + void setup_buffer(BufferType buffer_type, tt::stl::Span sub_device_ids); // GlobalSemaphore is implemented as a wrapper around a sharded buffer // This can be updated in the future to be its own container with optimized dispatch functions diff --git a/tt_metal/include/tt_metal/global_circular_buffer.hpp b/tt_metal/include/tt_metal/global_circular_buffer.hpp index 776296a589a..3c19ee7a07b 100644 --- a/tt_metal/include/tt_metal/global_circular_buffer.hpp +++ b/tt_metal/include/tt_metal/global_circular_buffer.hpp @@ -22,13 +22,16 @@ namespace experimental { * @param sender_receiver_core_mapping The mapping of remote sender to remote receiver cores for the circular buffer. * @param size Size of the global circular buffer per core in bytes. * @param buffer_type Buffer type to store the global circular buffer. Can only be an L1 buffer type. 
+ * @param sub_device_ids Sub-device IDs to wait on before writing the global circular buffer config to device. Defaults + * to waiting on all sub-devices. * @return Handle to the allocated global circular buffer. */ std::shared_ptr CreateGlobalCircularBuffer( Device* device, const std::unordered_map& sender_receiver_core_mapping, uint32_t size, - BufferType buffer_type = BufferType::L1); + BufferType buffer_type = BufferType::L1, + tt::stl::Span sub_device_ids = {}); } // namespace experimental diff --git a/tt_metal/tt_metal.cpp b/tt_metal/tt_metal.cpp index e59f14430cd..c6bb2cfb93b 100644 --- a/tt_metal/tt_metal.cpp +++ b/tt_metal/tt_metal.cpp @@ -1153,13 +1153,21 @@ uint32_t CreateSemaphore( } std::shared_ptr CreateGlobalSemaphore( - Device* device, const CoreRangeSet& cores, uint32_t initial_value, BufferType buffer_type) { - return GlobalSemaphore::create(device, cores, initial_value, buffer_type); + Device* device, + const CoreRangeSet& cores, + uint32_t initial_value, + BufferType buffer_type, + tt::stl::Span sub_device_ids) { + return GlobalSemaphore::create(device, cores, initial_value, buffer_type, sub_device_ids); } std::shared_ptr CreateGlobalSemaphore( - Device* device, CoreRangeSet&& cores, uint32_t initial_value, BufferType buffer_type) { - return GlobalSemaphore::create(device, std::move(cores), initial_value, buffer_type); + Device* device, + CoreRangeSet&& cores, + uint32_t initial_value, + BufferType buffer_type, + tt::stl::Span sub_device_ids) { + return GlobalSemaphore::create(device, std::move(cores), initial_value, buffer_type, sub_device_ids); } std::shared_ptr CreateBuffer(const InterleavedBufferConfig& config) { @@ -1362,8 +1370,9 @@ std::shared_ptr CreateGlobalCircularBuffer( Device* device, const std::unordered_map& sender_receiver_core_mapping, uint32_t size, - BufferType buffer_type) { - return GlobalCircularBuffer::create(device, sender_receiver_core_mapping, size, buffer_type); + BufferType buffer_type, + tt::stl::Span 
sub_device_ids) { + return GlobalCircularBuffer::create(device, sender_receiver_core_mapping, size, buffer_type, sub_device_ids); } CBHandle CreateCircularBuffer( diff --git a/ttnn/cpp/pybind11/global_circular_buffer.cpp b/ttnn/cpp/pybind11/global_circular_buffer.cpp index f736ee99781..4c21941b73c 100644 --- a/ttnn/cpp/pybind11/global_circular_buffer.cpp +++ b/ttnn/cpp/pybind11/global_circular_buffer.cpp @@ -19,12 +19,19 @@ void py_module(py::module& module) { // Single Device APIs module.def( "create_global_circular_buffer", - py::overload_cast&, uint32_t, BufferType>( - &create_global_circular_buffer), + [](Device* device, + const std::unordered_map& sender_receiver_core_mapping, + uint32_t size, + BufferType buffer_type, + const std::vector& sub_device_ids) { + return ttnn::global_circular_buffer::create_global_circular_buffer( + device, sender_receiver_core_mapping, size, buffer_type, sub_device_ids); + }, py::arg("device"), py::arg("sender_receiver_core_mapping"), py::arg("size"), py::arg("buffer_type") = tt::tt_metal::BufferType::L1, + py::arg("sub_device_ids") = std::vector(), R"doc( Create a GlobalCircularBuffer Object on a single device. @@ -33,17 +40,26 @@ void py_module(py::module& module) { sender_receiver_core_mapping (dict): The mapping of remote sender to remote receiver cores for the circular buffer. size (int): Size of the global circular buffer per core in bytes. buffer_type (BufferType): The type of buffer to use for the global circular buffer. + sub_device_ids (List[ttnn.SubDeviceId]): Sub-device IDs to wait on before writing the global circular buffer config to device. + Defaults to waiting on all sub-devices.
)doc"); // Multi Device APIs module.def( "create_global_circular_buffer", - py::overload_cast&, uint32_t, BufferType>( - &create_global_circular_buffer), + [](MeshDevice* mesh_device, + const std::unordered_map& sender_receiver_core_mapping, + uint32_t size, + BufferType buffer_type, + const std::vector& sub_device_ids) { + return ttnn::global_circular_buffer::create_global_circular_buffer( + mesh_device, sender_receiver_core_mapping, size, buffer_type, sub_device_ids); + }, py::arg("mesh_device"), py::arg("sender_receiver_core_mapping"), py::arg("size"), py::arg("buffer_type") = tt::tt_metal::BufferType::L1, + py::arg("sub_device_ids") = std::vector(), R"doc( Create a GlobalCircularBuffer Object on a single device. @@ -52,6 +68,8 @@ void py_module(py::module& module) { sender_receiver_core_mapping (dict): The mapping of remote sender to remote receiver cores for the circular buffer. size (int): Size of the global circular buffer per core in bytes. buffer_type (BufferType): The type of buffer to use for the global circular buffer. + sub_device_ids (List[ttnn.SubDeviceId]): Sub-device IDs to wait on before writing the global circular buffer config to device. + Defaults to waiting on all sub-devices.
)doc"); } diff --git a/ttnn/cpp/pybind11/global_semaphore.cpp b/ttnn/cpp/pybind11/global_semaphore.cpp index 79a97de58df..f6e44cb3419 100644 --- a/ttnn/cpp/pybind11/global_semaphore.cpp +++ b/ttnn/cpp/pybind11/global_semaphore.cpp @@ -19,11 +19,19 @@ void py_module(py::module& module) { // Single Device APIs module.def( "create_global_semaphore", - py::overload_cast(&create_global_semaphore), + [](Device* device, + const CoreRangeSet& cores, + uint32_t initial_value, + BufferType buffer_type, + const std::vector& sub_device_ids) { + return ttnn::global_semaphore::create_global_semaphore( + device, cores, initial_value, buffer_type, sub_device_ids); + }, py::arg("device"), py::arg("cores"), py::arg("initial_value"), py::arg("buffer_type") = tt::tt_metal::BufferType::L1, + py::arg("sub_device_ids") = std::vector(), R"doc( Create a GlobalSemaphore Object on a single device. @@ -32,6 +40,8 @@ void py_module(py::module& module) { cores (CoreRangeSet): The cores on which the global semaphore will be used for synchronization. initial_value (int): The initial value of the global semaphore. buffer_type (BufferType): The type of buffer to use for the global semaphore. + sub_device_ids (List[ttnn.SubDeviceId]): Sub-device IDs to wait on before writing the global semaphore value to device. + Defaults to waiting on all sub-devices. )doc"); module.def( @@ -47,23 +57,35 @@ void py_module(py::module& module) { module.def( "reset_global_semaphore_value", - py::overload_cast&>(&reset_global_semaphore_value), + py::overload_cast&, const std::vector&>( + &reset_global_semaphore_value), py::arg("global_semaphore"), + py::arg("sub_device_ids") = std::vector(), R"doc( Reset the value of the global semaphore. Args: global_semaphore (GlobalSemaphore): The global semaphore object. + sub_device_ids (List[ttnn.SubDeviceId]): Sub-device IDs to wait on before writing the global semaphore value to device. + Defaults to waiting on all sub-devices.
)doc"); // Multi Device APIs module.def( "create_global_semaphore", - py::overload_cast(&create_global_semaphore), + [](MeshDevice* mesh_device, + const CoreRangeSet& cores, + uint32_t initial_value, + BufferType buffer_type, + const std::vector& sub_device_ids) { + return ttnn::global_semaphore::create_global_semaphore( + mesh_device, cores, initial_value, buffer_type, sub_device_ids); + }, py::arg("mesh_device"), py::arg("cores"), py::arg("initial_value"), py::arg("buffer_type") = tt::tt_metal::BufferType::L1, + py::arg("sub_device_ids") = std::vector(), R"doc( Create a GlobalSemaphore Object on a single device. @@ -72,6 +94,8 @@ void py_module(py::module& module) { cores (CoreRangeSet): The cores on which the global semaphore will be used for synchronization. initial_value (int): The initial value of the global semaphore. buffer_type (BufferType): The type of buffer to use for the global semaphore. + sub_device_ids (List[ttnn.SubDeviceId]): Sub-device IDs to wait on before writing the global semaphore value to device. + Defaults to waiting on all sub-devices. )doc"); module.def( @@ -87,13 +111,16 @@ void py_module(py::module& module) { module.def( "reset_global_semaphore_value", - py::overload_cast(&reset_global_semaphore_value), + py::overload_cast&>( + &reset_global_semaphore_value), py::arg("global_semaphore"), + py::arg("sub_device_ids") = std::vector(), R"doc( Reset the value of the global semaphore. Args: global_semaphore (GlobalSemaphore): The global semaphore object. + sub_device_ids (List[ttnn.SubDeviceId]): Sub-device IDs to wait on before writing the global semaphore value to device. Defaults to waiting on all sub-devices.
)doc"); } diff --git a/ttnn/cpp/ttnn/global_circular_buffer.cpp b/ttnn/cpp/ttnn/global_circular_buffer.cpp index 7c5967fa3c2..76cc9df3e9f 100644 --- a/ttnn/cpp/ttnn/global_circular_buffer.cpp +++ b/ttnn/cpp/ttnn/global_circular_buffer.cpp @@ -21,12 +21,13 @@ std::shared_ptr create_global_circular_buffer( Device* device, const std::unordered_map& sender_receiver_core_mapping, uint32_t size, - BufferType buffer_type) { + BufferType buffer_type, + tt::stl::Span sub_device_ids) { std::shared_ptr global_cb; device->push_work( - [device, &sender_receiver_core_mapping, size, buffer_type, &global_cb]() { + [device, &sender_receiver_core_mapping, size, buffer_type, sub_device_ids, &global_cb]() { global_cb = tt::tt_metal::v1::experimental::CreateGlobalCircularBuffer( - device, sender_receiver_core_mapping, size, buffer_type); + device, sender_receiver_core_mapping, size, buffer_type, sub_device_ids); }, /*blocking=*/true); return global_cb; @@ -36,15 +37,16 @@ MultiDeviceGlobalCircularBuffer create_global_circular_buffer( MeshDevice* mesh_device, const std::unordered_map& sender_receiver_core_mapping, uint32_t size, - BufferType buffer_type) { + BufferType buffer_type, + tt::stl::Span sub_device_ids) { MultiDeviceGlobalCircularBuffer multi_device_global_cb(mesh_device); const auto& devices = mesh_device->get_devices(); for (uint32_t i = 0; i < devices.size(); ++i) { auto* device = devices[i]; auto& global_cb = multi_device_global_cb.global_circular_buffers[i]; - device->push_work([device, &sender_receiver_core_mapping, size, buffer_type, &global_cb]() { + device->push_work([device, &sender_receiver_core_mapping, size, buffer_type, sub_device_ids, &global_cb]() { global_cb = tt::tt_metal::v1::experimental::CreateGlobalCircularBuffer( - device, sender_receiver_core_mapping, size, buffer_type); + device, sender_receiver_core_mapping, size, buffer_type, sub_device_ids); }); } for (auto* device : devices) { diff --git a/ttnn/cpp/ttnn/global_circular_buffer.hpp 
b/ttnn/cpp/ttnn/global_circular_buffer.hpp index bb84ce3a7ab..39f18a0a63d 100644 --- a/ttnn/cpp/ttnn/global_circular_buffer.hpp +++ b/ttnn/cpp/ttnn/global_circular_buffer.hpp @@ -20,13 +20,15 @@ std::shared_ptr create_global_circular_buffer( Device* device, const std::unordered_map& sender_receiver_core_mapping, uint32_t size, - BufferType buffer_type = BufferType::L1); + BufferType buffer_type = BufferType::L1, + tt::stl::Span sub_device_ids = {}); // Multi Device APIs MultiDeviceGlobalCircularBuffer create_global_circular_buffer( MeshDevice* mesh_device, const std::unordered_map& sender_receiver_core_mapping, uint32_t size, - BufferType buffer_type = BufferType::L1); + BufferType buffer_type = BufferType::L1, + tt::stl::Span sub_device_ids = {}); } // namespace ttnn::global_circular_buffer diff --git a/ttnn/cpp/ttnn/global_semaphore.cpp b/ttnn/cpp/ttnn/global_semaphore.cpp index da1ebf8f0f0..a74a4b350cc 100644 --- a/ttnn/cpp/ttnn/global_semaphore.cpp +++ b/ttnn/cpp/ttnn/global_semaphore.cpp @@ -5,8 +5,9 @@ #include "global_semaphore.hpp" #include -#include "tt_metal/impl/buffers/global_semaphore.hpp" #include "tt_metal/host_api.hpp" +#include "tt_metal/impl/buffers/global_semaphore.hpp" +#include "tt_metal/tt_stl/span.hpp" namespace ttnn::global_semaphore { @@ -18,11 +19,15 @@ MultiDeviceGlobalSemaphore::MultiDeviceGlobalSemaphore(MeshDevice* mesh_device) } std::shared_ptr create_global_semaphore( - Device* device, const CoreRangeSet& cores, uint32_t initial_value, BufferType buffer_type) { + Device* device, + const CoreRangeSet& cores, + uint32_t initial_value, + BufferType buffer_type, + tt::stl::Span sub_device_ids) { std::shared_ptr global_semaphore = nullptr; device->push_work( - [device, &cores, initial_value, buffer_type, &global_semaphore] { - global_semaphore = GlobalSemaphore::create(device, cores, initial_value, buffer_type); + [device, &cores, initial_value, buffer_type, sub_device_ids, &global_semaphore] { + global_semaphore = 
GlobalSemaphore::create(device, cores, initial_value, buffer_type, sub_device_ids); }, /*blocking=*/true); return global_semaphore; @@ -35,20 +40,25 @@ DeviceAddr get_global_semaphore_address(const std::shared_ptr& return address; } -void reset_global_semaphore_value(const std::shared_ptr& global_semaphore) { +void reset_global_semaphore_value( + const std::shared_ptr& global_semaphore, const std::vector& sub_device_ids) { auto* device = global_semaphore->device(); - device->push_work([global_semaphore] { global_semaphore->reset_semaphore_value(); }); + device->push_work([global_semaphore, sub_device_ids] { global_semaphore->reset_semaphore_value(sub_device_ids); }); } MultiDeviceGlobalSemaphore create_global_semaphore( - MeshDevice* mesh_device, const CoreRangeSet& cores, uint32_t initial_value, BufferType buffer_type) { + MeshDevice* mesh_device, + const CoreRangeSet& cores, + uint32_t initial_value, + BufferType buffer_type, + tt::stl::Span sub_device_ids) { MultiDeviceGlobalSemaphore multi_device_global_semaphore(mesh_device); const auto& devices = mesh_device->get_devices(); for (uint32_t i = 0; i < devices.size(); ++i) { auto* device = devices[i]; auto& global_semaphore = multi_device_global_semaphore.global_semaphores[i]; - device->push_work([device, &cores, initial_value, buffer_type, &global_semaphore] { - global_semaphore = GlobalSemaphore::create(device, cores, initial_value, buffer_type); + device->push_work([device, &cores, initial_value, buffer_type, sub_device_ids, &global_semaphore] { + global_semaphore = GlobalSemaphore::create(device, cores, initial_value, buffer_type, sub_device_ids); }); } for (auto device : devices) { @@ -71,9 +81,10 @@ std::vector get_global_semaphore_address(const MultiDeviceGlobalSema return addresses; } -void reset_global_semaphore_value(const MultiDeviceGlobalSemaphore& global_semaphore) { +void reset_global_semaphore_value( + const MultiDeviceGlobalSemaphore& global_semaphore, const std::vector& sub_device_ids) { for 
(const auto& global_semaphore : global_semaphore.global_semaphores) { - reset_global_semaphore_value(global_semaphore); + reset_global_semaphore_value(global_semaphore, sub_device_ids); } } diff --git a/ttnn/cpp/ttnn/global_semaphore.hpp b/ttnn/cpp/ttnn/global_semaphore.hpp index 70f56fb4b4c..b04cda2dd27 100644 --- a/ttnn/cpp/ttnn/global_semaphore.hpp +++ b/ttnn/cpp/ttnn/global_semaphore.hpp @@ -17,17 +17,24 @@ struct MultiDeviceGlobalSemaphore { // Single Device APIs std::shared_ptr create_global_semaphore( - Device* device, const CoreRangeSet& cores, uint32_t initial_value, BufferType buffer_type = BufferType::L1); + Device* device, + const CoreRangeSet& cores, + uint32_t initial_value, + BufferType buffer_type = BufferType::L1, + tt::stl::Span sub_device_ids = {}); DeviceAddr get_global_semaphore_address(const std::shared_ptr& global_semaphore); -void reset_global_semaphore_value(const std::shared_ptr& global_semaphore); +void reset_global_semaphore_value( + const std::shared_ptr& global_semaphore, const std::vector& sub_device_ids = {}); // Multi Device APIs MultiDeviceGlobalSemaphore create_global_semaphore( MeshDevice* mesh_device, const CoreRangeSet& cores, uint32_t initial_value, - BufferType buffer_type = BufferType::L1); + BufferType buffer_type = BufferType::L1, + tt::stl::Span sub_device_ids = {}); std::vector get_global_semaphore_address(const MultiDeviceGlobalSemaphore& global_semaphore); -void reset_global_semaphore_value(const MultiDeviceGlobalSemaphore& global_semaphore); +void reset_global_semaphore_value( + const MultiDeviceGlobalSemaphore& global_semaphore, const std::vector& sub_device_ids = {}); } // namespace ttnn::global_semaphore