From eb76f950567e209dab72232ffa0c001f509e943c Mon Sep 17 00:00:00 2001
From: Austin Ho
Date: Wed, 11 Dec 2024 21:57:52 +0000
Subject: [PATCH] #0: Update global semaphores and global circular buffers
 metal apis to be thread-safe instead of depending on ttnn apis

---
 .../tt_metal/api/test_global_semaphores.cpp  |   5 +-
 .../ttnn/unit_tests/test_global_semaphore.py |   2 +-
 .../impl/buffers/global_circular_buffer.cpp  | 112 ++++++++++--------
 .../impl/buffers/global_circular_buffer.hpp  |   7 +-
 tt_metal/impl/buffers/global_semaphore.cpp   |  48 +++++---
 tt_metal/impl/buffers/global_semaphore.hpp   |  24 ++--
 ttnn/cpp/pybind11/global_semaphore.cpp       |  18 ++-
 ttnn/cpp/ttnn/global_semaphore.cpp           |  14 ++-
 ttnn/cpp/ttnn/global_semaphore.hpp           |   8 +-
 9 files changed, 143 insertions(+), 95 deletions(-)

diff --git a/tests/tt_metal/tt_metal/api/test_global_semaphores.cpp b/tests/tt_metal/tt_metal/api/test_global_semaphores.cpp
index f88e71efda9b..58a0f9873537 100644
--- a/tests/tt_metal/tt_metal/api/test_global_semaphores.cpp
+++ b/tests/tt_metal/tt_metal/api/test_global_semaphores.cpp
@@ -86,6 +86,7 @@ TEST_F(DispatchFixture, ResetGlobalSemaphores) {
     for (auto device : devices_) {
         {
             uint32_t initial_value = 1;
+            uint32_t reset_value = 2;
             std::vector<uint32_t> overwrite_value = {2};
             auto global_semaphore = tt::tt_metal::CreateGlobalSemaphore(device, cores, initial_value);
             auto address = global_semaphore->address();
@@ -104,14 +105,14 @@
                 EXPECT_EQ(sem_vals[0], overwrite_value[0]);
             }
 
-            global_semaphore->reset_semaphore_value();
+            global_semaphore->reset_semaphore_value(reset_value);
             Synchronize(device);
             for (const auto& core : cores_vec) {
                 auto sem_vals = tt::llrt::read_hex_vec_from_core(
                     device->id(), device->worker_core_from_logical_core(core), address, sizeof(uint32_t));
                 tt::llrt::write_hex_vec_to_core(
                     device->id(), device->worker_core_from_logical_core(core), overwrite_value, address);
-                EXPECT_EQ(sem_vals[0], initial_value);
+                EXPECT_EQ(sem_vals[0], reset_value);
             }
         }
     }
diff --git a/tests/ttnn/unit_tests/test_global_semaphore.py b/tests/ttnn/unit_tests/test_global_semaphore.py
index 24c6fa107deb..32c17742c8bb 100644
--- a/tests/ttnn/unit_tests/test_global_semaphore.py
+++ b/tests/ttnn/unit_tests/test_global_semaphore.py
@@ -29,7 +29,7 @@ def run_global_semaphore(device):
 
     assert ttnn.get_global_semaphore_address(global_sem0) != ttnn.get_global_semaphore_address(global_sem1)
 
-    ttnn.reset_global_semaphore_value(global_sem0)
+    ttnn.reset_global_semaphore_value(global_sem0, 3)
 
 
 @pytest.mark.parametrize("enable_async_mode", (False, True), indirect=True)
diff --git a/tt_metal/impl/buffers/global_circular_buffer.cpp b/tt_metal/impl/buffers/global_circular_buffer.cpp
index 094670b2a301..3a05110abfd4 100644
--- a/tt_metal/impl/buffers/global_circular_buffer.cpp
+++ b/tt_metal/impl/buffers/global_circular_buffer.cpp
@@ -28,7 +28,8 @@ GlobalCircularBuffer::GlobalCircularBuffer(
     const std::unordered_map<CoreCoord, CoreRangeSet>& sender_receiver_core_mapping,
     uint32_t size,
     BufferType buffer_type,
-    tt::stl::Span<const SubDeviceId> sub_device_ids) :
+    tt::stl::Span<const SubDeviceId> sub_device_ids,
+    Private) :
     device_(device), sender_receiver_core_mapping_(sender_receiver_core_mapping), size_(size) {
     TT_FATAL(this->device_ != nullptr, "Device cannot be null");
     uint32_t num_sender_cores = sender_receiver_core_mapping.size();
@@ -86,58 +87,65 @@ void GlobalCircularBuffer::setup_cb_buffers(
         shard_parameters,
         std::nullopt);
 
-    const auto& core_to_core_id = this->cb_config_buffer_->get_buffer_page_mapping()->core_to_core_id_;
-
-    std::vector<uint32_t> cb_config_host_buffer(cb_config_size / sizeof(uint32_t), 0);
-    uint32_t buffer_address = this->cb_buffer_->address();
-    uint32_t noc_xy_address = this->cb_config_buffer_->address() + num_config_elements * sizeof(uint32_t);
-    uint32_t pages_sent_address = align(noc_xy_address + num_noc_xy_words * sizeof(uint32_t), l1_alignment);
-
-    for (const auto& [sender_core, receiver_cores] : this->sender_receiver_core_mapping_) {
-        const auto& receiver_cores_vec = corerange_to_cores(receiver_cores);
-        uint32_t sender_idx = core_to_core_id.at(sender_core) * cb_config_page_size / sizeof(uint32_t);
-        uint32_t num_receivers = receiver_cores.num_cores();
-        uint32_t pages_acked_address = pages_sent_address + num_receivers * l1_alignment;
-        cb_config_host_buffer[sender_idx++] = 1;
-        cb_config_host_buffer[sender_idx++] = receiver_cores.num_cores();
-        cb_config_host_buffer[sender_idx++] = buffer_address;
-        cb_config_host_buffer[sender_idx++] = this->size_;
-        cb_config_host_buffer[sender_idx++] = buffer_address;
-        cb_config_host_buffer[sender_idx++] = noc_xy_address;
-        cb_config_host_buffer[sender_idx++] = pages_sent_address;
-
-        auto sender_physical_coord = this->device_->worker_core_from_logical_core(sender_core);
-        for (uint32_t i = 0; i < receiver_cores_vec.size(); i++) {
-            auto receiver_physical_coord = this->device_->worker_core_from_logical_core(receiver_cores_vec[i]);
-            cb_config_host_buffer[sender_idx++] = receiver_physical_coord.x;
-            cb_config_host_buffer[sender_idx++] = receiver_physical_coord.y;
-
-            uint32_t receiver_idx = core_to_core_id.at(receiver_cores_vec[i]) * cb_config_page_size / sizeof(uint32_t);
-            cb_config_host_buffer[receiver_idx++] = 0;
-            cb_config_host_buffer[receiver_idx++] = num_receivers;
-            cb_config_host_buffer[receiver_idx++] = buffer_address;
-            cb_config_host_buffer[receiver_idx++] = this->size_;
-            cb_config_host_buffer[receiver_idx++] = buffer_address;
-            cb_config_host_buffer[receiver_idx++] = noc_xy_address;
-            cb_config_host_buffer[receiver_idx++] = pages_sent_address + 2 * i * l1_alignment;
-            cb_config_host_buffer[receiver_idx++] = sender_physical_coord.x;
-            cb_config_host_buffer[receiver_idx++] = sender_physical_coord.y;
-        }
-    }
-
     // Write the config buffer to the device
     // Only block for the slow dispatch case
-    if (this->device_->using_slow_dispatch()) {
-        detail::WriteToBuffer(*this->cb_config_buffer_, cb_config_host_buffer);
-        tt::Cluster::instance().l1_barrier(this->device_->id());
-    } else {
-        EnqueueWriteBuffer(
-            this->device_->command_queue(),
-            this->cb_config_buffer_,
-            cb_config_host_buffer.data(),
-            false,
-            sub_device_ids);
-    }
+    auto device = this->device_;
+    device->push_work([device,
+                       cb_config_size,
+                       cb_config_page_size,
+                       num_noc_xy_words,
+                       l1_alignment,
+                       buffer_address = this->cb_buffer_->address(),
+                       cb_config_buffer = this->cb_config_buffer_,
+                       size = this->size_,
+                       sender_receiver_core_mapping = this->sender_receiver_core_mapping_,
+                       sub_device_ids = std::vector<SubDeviceId>(sub_device_ids.begin(), sub_device_ids.end())] {
+        auto config_buffer_address = cb_config_buffer->address();
+        const auto& core_to_core_id = cb_config_buffer->get_buffer_page_mapping()->core_to_core_id_;
+        std::vector<uint32_t> cb_config_host_buffer(cb_config_size / sizeof(uint32_t), 0);
+        uint32_t noc_xy_address = config_buffer_address + num_config_elements * sizeof(uint32_t);
+        uint32_t pages_sent_address = align(noc_xy_address + num_noc_xy_words * sizeof(uint32_t), l1_alignment);
+
+        for (const auto& [sender_core, receiver_cores] : sender_receiver_core_mapping) {
+            const auto& receiver_cores_vec = corerange_to_cores(receiver_cores);
+            uint32_t sender_idx = core_to_core_id.at(sender_core) * cb_config_page_size / sizeof(uint32_t);
+            uint32_t num_receivers = receiver_cores.num_cores();
+            uint32_t pages_acked_address = pages_sent_address + num_receivers * l1_alignment;
+            cb_config_host_buffer[sender_idx++] = 1;
+            cb_config_host_buffer[sender_idx++] = receiver_cores.num_cores();
+            cb_config_host_buffer[sender_idx++] = buffer_address;
+            cb_config_host_buffer[sender_idx++] = size;
+            cb_config_host_buffer[sender_idx++] = buffer_address;
+            cb_config_host_buffer[sender_idx++] = noc_xy_address;
+            cb_config_host_buffer[sender_idx++] = pages_sent_address;
+
+            auto sender_physical_coord = device->worker_core_from_logical_core(sender_core);
+            for (uint32_t i = 0; i < receiver_cores_vec.size(); i++) {
+                auto receiver_physical_coord = device->worker_core_from_logical_core(receiver_cores_vec[i]);
+                cb_config_host_buffer[sender_idx++] = receiver_physical_coord.x;
+                cb_config_host_buffer[sender_idx++] = receiver_physical_coord.y;
+
+                uint32_t receiver_idx =
+                    core_to_core_id.at(receiver_cores_vec[i]) * cb_config_page_size / sizeof(uint32_t);
+                cb_config_host_buffer[receiver_idx++] = 0;
+                cb_config_host_buffer[receiver_idx++] = num_receivers;
+                cb_config_host_buffer[receiver_idx++] = buffer_address;
+                cb_config_host_buffer[receiver_idx++] = size;
+                cb_config_host_buffer[receiver_idx++] = buffer_address;
+                cb_config_host_buffer[receiver_idx++] = noc_xy_address;
+                cb_config_host_buffer[receiver_idx++] = pages_sent_address + 2 * i * l1_alignment;
+                cb_config_host_buffer[receiver_idx++] = sender_physical_coord.x;
+                cb_config_host_buffer[receiver_idx++] = sender_physical_coord.y;
+            }
+        }
+        if (device->using_slow_dispatch()) {
+            detail::WriteToBuffer(*cb_config_buffer, cb_config_host_buffer);
+            tt::Cluster::instance().l1_barrier(device->id());
+        } else {
+            EnqueueWriteBuffer(
+                device->command_queue(), cb_config_buffer, cb_config_host_buffer.data(), false, sub_device_ids);
+        }
+    });
 }
 
 std::shared_ptr<GlobalCircularBuffer> GlobalCircularBuffer::create(
@@ -147,7 +155,7 @@ std::shared_ptr<GlobalCircularBuffer> GlobalCircularBuffer::create(
     BufferType buffer_type,
     tt::stl::Span<const SubDeviceId> sub_device_ids) {
     return std::make_shared<GlobalCircularBuffer>(
-        device, sender_receiver_core_mapping, size, buffer_type, sub_device_ids);
+        device, sender_receiver_core_mapping, size, buffer_type, sub_device_ids, Private());
 }
 
 const Buffer& GlobalCircularBuffer::cb_buffer() const { return *this->cb_buffer_; }
diff --git a/tt_metal/impl/buffers/global_circular_buffer.hpp b/tt_metal/impl/buffers/global_circular_buffer.hpp
index ca0c56da71f2..0185462e7a45 100644
--- a/tt_metal/impl/buffers/global_circular_buffer.hpp
+++ b/tt_metal/impl/buffers/global_circular_buffer.hpp
@@ -26,13 +26,18 @@ namespace v1 {
 namespace experimental {
 
 class GlobalCircularBuffer {
+    struct Private {
+        explicit Private() = default;
+    };
+
 public:
     GlobalCircularBuffer(
         Device* device,
         const std::unordered_map<CoreCoord, CoreRangeSet>& sender_receiver_core_mapping,
         uint32_t size,
         BufferType buffer_type,
-        tt::stl::Span<const SubDeviceId> sub_device_ids);
+        tt::stl::Span<const SubDeviceId> sub_device_ids,
+        Private);
 
     GlobalCircularBuffer(const GlobalCircularBuffer&) = default;
     GlobalCircularBuffer& operator=(const GlobalCircularBuffer&) = default;
diff --git a/tt_metal/impl/buffers/global_semaphore.cpp b/tt_metal/impl/buffers/global_semaphore.cpp
index 57ef080d0f74..af976bd1a07f 100644
--- a/tt_metal/impl/buffers/global_semaphore.cpp
+++ b/tt_metal/impl/buffers/global_semaphore.cpp
@@ -24,9 +24,10 @@ GlobalSemaphore::GlobalSemaphore(
     const CoreRangeSet& cores,
     uint32_t initial_value,
     BufferType buffer_type,
-    tt::stl::Span<const SubDeviceId> sub_device_ids) :
-    device_(device), cores_(cores), initial_value_(initial_value) {
-    this->setup_buffer(buffer_type, sub_device_ids);
+    tt::stl::Span<const SubDeviceId> sub_device_ids,
+    Private) :
+    device_(device), cores_(cores) {
+    this->setup_buffer(initial_value, buffer_type, sub_device_ids);
 }
 
 GlobalSemaphore::GlobalSemaphore(
@@ -34,12 +35,14 @@ GlobalSemaphore::GlobalSemaphore(
     CoreRangeSet&& cores,
     uint32_t initial_value,
     BufferType buffer_type,
-    tt::stl::Span<const SubDeviceId> sub_device_ids) :
-    device_(device), cores_(std::move(cores)), initial_value_(initial_value) {
-    this->setup_buffer(buffer_type, sub_device_ids);
+    tt::stl::Span<const SubDeviceId> sub_device_ids,
+    Private) :
+    device_(device), cores_(std::move(cores)) {
+    this->setup_buffer(initial_value, buffer_type, sub_device_ids);
 }
 
-void GlobalSemaphore::setup_buffer(BufferType buffer_type, tt::stl::Span<const SubDeviceId> sub_device_ids) {
+void GlobalSemaphore::setup_buffer(
+    uint32_t initial_value, BufferType buffer_type, tt::stl::Span<const SubDeviceId> sub_device_ids) {
     TT_FATAL(
         buffer_type == BufferType::L1 or buffer_type == BufferType::L1_SMALL,
         "Global semaphore can only be created for L1 buffer types");
@@ -58,8 +61,7 @@ void GlobalSemaphore::setup_buffer(BufferType buffer_type, tt::stl::Span<const
-    this->host_buffer_ = std::vector<uint32_t>(num_cores, this->initial_value_);
-    this->reset_semaphore_value(sub_device_ids);
+    this->reset_semaphore_value(initial_value, sub_device_ids);
 }
 
 std::shared_ptr<GlobalSemaphore> GlobalSemaphore::create(
@@ -68,7 +70,7 @@ std::shared_ptr<GlobalSemaphore> GlobalSemaphore::create(
     uint32_t initial_value,
     BufferType buffer_type,
     tt::stl::Span<const SubDeviceId> sub_device_ids) {
-    return std::make_shared<GlobalSemaphore>(device, cores, initial_value, buffer_type, sub_device_ids);
+    return std::make_shared<GlobalSemaphore>(device, cores, initial_value, buffer_type, sub_device_ids, Private());
 }
 std::shared_ptr<GlobalSemaphore> GlobalSemaphore::create(
     Device* device,
@@ -76,23 +78,31 @@ std::shared_ptr<GlobalSemaphore> GlobalSemaphore::create(
     uint32_t initial_value,
     BufferType buffer_type,
     tt::stl::Span<const SubDeviceId> sub_device_ids) {
-    return std::make_shared<GlobalSemaphore>(device, std::move(cores), initial_value, buffer_type, sub_device_ids);
+    return std::make_shared<GlobalSemaphore>(
+        device, std::move(cores), initial_value, buffer_type, sub_device_ids, Private());
 }
 
 Device* GlobalSemaphore::device() const { return device_; }
 
 DeviceAddr GlobalSemaphore::address() const { return buffer_->address(); }
 
-void GlobalSemaphore::reset_semaphore_value(tt::stl::Span<const SubDeviceId> sub_device_ids) {
+void GlobalSemaphore::reset_semaphore_value(uint32_t reset_value, tt::stl::Span<const SubDeviceId> sub_device_ids) {
     // Write the initial value to the semaphore to the device
     // Only block for the slow dispatch case
-    if (this->device_->using_slow_dispatch()) {
-        detail::WriteToBuffer(*this->buffer_, this->host_buffer_);
-        tt::Cluster::instance().l1_barrier(this->device_->id());
-    } else {
-        EnqueueWriteBuffer(
-            this->device_->command_queue(), this->buffer_, this->host_buffer_.data(), false, sub_device_ids);
-    }
+    auto* device = this->device_;
+    device->push_work([device,
+                       reset_value,
+                       sub_device_ids = std::vector<SubDeviceId>(sub_device_ids.begin(), sub_device_ids.end()),
+                       num_cores = this->cores_.num_cores(),
+                       buffer = this->buffer_] {
+        std::vector<uint32_t> host_buffer(num_cores, reset_value);
+        if (device->using_slow_dispatch()) {
+            detail::WriteToBuffer(*buffer, host_buffer);
+            tt::Cluster::instance().l1_barrier(device->id());
+        } else {
+            EnqueueWriteBuffer(device->command_queue(), buffer, host_buffer, false, sub_device_ids);
+        }
+    });
 }
 
 }  // namespace tt::tt_metal
diff --git a/tt_metal/impl/buffers/global_semaphore.hpp b/tt_metal/impl/buffers/global_semaphore.hpp
index 0d912b2f9ac6..846975262abc 100644
--- a/tt_metal/impl/buffers/global_semaphore.hpp
+++ b/tt_metal/impl/buffers/global_semaphore.hpp
@@ -20,20 +20,26 @@ class Buffer;
 class Device;
 
 class GlobalSemaphore {
+    struct Private {
+        explicit Private() = default;
+    };
+
 public:
     GlobalSemaphore(
         Device* device,
         const CoreRangeSet& cores,
         uint32_t initial_value,
-        BufferType buffer_type = BufferType::L1,
-        tt::stl::Span<const SubDeviceId> sub_device_ids = {});
+        BufferType buffer_type,
+        tt::stl::Span<const SubDeviceId> sub_device_ids,
+        Private);
 
     GlobalSemaphore(
         Device* device,
         CoreRangeSet&& cores,
         uint32_t initial_value,
-        BufferType buffer_type = BufferType::L1,
-        tt::stl::Span<const SubDeviceId> sub_device_ids = {});
+        BufferType buffer_type,
+        tt::stl::Span<const SubDeviceId> sub_device_ids,
+        Private);
 
     GlobalSemaphore(const GlobalSemaphore&) = default;
     GlobalSemaphore& operator=(const GlobalSemaphore&) = default;
@@ -59,21 +65,19 @@ class GlobalSemaphore {
 
     DeviceAddr address() const;
 
-    void reset_semaphore_value(tt::stl::Span<const SubDeviceId> sub_device_ids = {});
+    void reset_semaphore_value(uint32_t reset_value, tt::stl::Span<const SubDeviceId> sub_device_ids = {});
 
-    static constexpr auto attribute_names = std::forward_as_tuple("cores", "initial_value");
-    const auto attribute_values() const { return std::make_tuple(this->cores_, this->initial_value_); }
+    static constexpr auto attribute_names = std::forward_as_tuple("cores");
+    const auto attribute_values() const { return std::make_tuple(this->cores_); }
 
 private:
-    void setup_buffer(BufferType buffer_type, tt::stl::Span<const SubDeviceId> sub_device_ids);
+    void setup_buffer(uint32_t initial_value, BufferType buffer_type, tt::stl::Span<const SubDeviceId> sub_device_ids);
 
     // GlobalSemaphore is implemented as a wrapper around a sharded buffer
     // This can be updated in the future to be its own container with optimized dispatch functions
     std::shared_ptr<Buffer> buffer_;
-    std::vector<uint32_t> host_buffer_;
     Device* device_;
     CoreRangeSet cores_;
-    uint32_t initial_value_ = 0;
 };
 
 }  // namespace v0
diff --git a/ttnn/cpp/pybind11/global_semaphore.cpp b/ttnn/cpp/pybind11/global_semaphore.cpp
index f6e44cb34191..726c960e1668 100644
--- a/ttnn/cpp/pybind11/global_semaphore.cpp
+++ b/ttnn/cpp/pybind11/global_semaphore.cpp
@@ -57,15 +57,20 @@ void py_module(py::module& module) {
 
     module.def(
         "reset_global_semaphore_value",
-        py::overload_cast<const std::shared_ptr<GlobalSemaphore>&, const std::vector<ttnn::SubDeviceId>&>(
-            &reset_global_semaphore_value),
+        [](const std::shared_ptr<GlobalSemaphore>& global_semaphore,
+           uint32_t reset_value,
+           const std::vector<ttnn::SubDeviceId>& sub_device_ids) {
+            ttnn::global_semaphore::reset_global_semaphore_value(global_semaphore, reset_value, sub_device_ids);
+        },
         py::arg("global_semaphore"),
+        py::arg("reset_value"),
         py::arg("sub_device_ids") = std::vector<ttnn::SubDeviceId>(),
         R"doc(
             Reset the value of the global semaphore.
 
             Args:
                 global_semaphore (GlobalSemaphore): The global semaphore object.
+                reset_value (int): The value to reset the global semaphore to.
                 sub_device_ids (List[ttnn.SubDeviceIds]): Sub-device IDs to wait on before writing the
                    global semaphore value to device. Defaults to waiting on all sub-devices.
)doc"); @@ -111,15 +116,20 @@ void py_module(py::module& module) { module.def( "reset_global_semaphore_value", - py::overload_cast&>( - &reset_global_semaphore_value), + [](const MultiDeviceGlobalSemaphore& global_semaphore, + uint32_t reset_value, + const std::vector& sub_device_ids) { + ttnn::global_semaphore::reset_global_semaphore_value(global_semaphore, reset_value, sub_device_ids); + }, py::arg("global_semaphore"), + py::arg("reset_value"), py::arg("sub_device_ids") = std::vector(), R"doc( Reset the value of the global semaphore. Args: global_semaphore (GlobalSemaphore): The global semaphore object. + reset_value (int): The value to reset the global semaphore to. sub_device_ids (List[ttnn.SubDeviceIds]): Sub-device IDs to wait on before writing the global semaphore value to device. )doc"); } diff --git a/ttnn/cpp/ttnn/global_semaphore.cpp b/ttnn/cpp/ttnn/global_semaphore.cpp index a74a4b350ccb..777fe337b718 100644 --- a/ttnn/cpp/ttnn/global_semaphore.cpp +++ b/ttnn/cpp/ttnn/global_semaphore.cpp @@ -41,9 +41,13 @@ DeviceAddr get_global_semaphore_address(const std::shared_ptr& } void reset_global_semaphore_value( - const std::shared_ptr& global_semaphore, const std::vector& sub_device_ids) { + const std::shared_ptr& global_semaphore, + uint32_t reset_value, + tt::stl::Span sub_device_ids) { auto* device = global_semaphore->device(); - device->push_work([global_semaphore, sub_device_ids] { global_semaphore->reset_semaphore_value(sub_device_ids); }); + device->push_work([global_semaphore, reset_value, sub_device_ids] { + global_semaphore->reset_semaphore_value(reset_value, sub_device_ids); + }); } MultiDeviceGlobalSemaphore create_global_semaphore( @@ -82,9 +86,11 @@ std::vector get_global_semaphore_address(const MultiDeviceGlobalSema } void reset_global_semaphore_value( - const MultiDeviceGlobalSemaphore& global_semaphore, const std::vector& sub_device_ids) { + const MultiDeviceGlobalSemaphore& global_semaphore, + uint32_t reset_value, + tt::stl::Span sub_device_ids) { for (const auto& global_semaphore : global_semaphore.global_semaphores) { - reset_global_semaphore_value(global_semaphore, sub_device_ids); + reset_global_semaphore_value(global_semaphore, reset_value, sub_device_ids); } } diff --git a/ttnn/cpp/ttnn/global_semaphore.hpp b/ttnn/cpp/ttnn/global_semaphore.hpp index b04cda2dd274..121e8c03cdf0 100644 --- a/ttnn/cpp/ttnn/global_semaphore.hpp +++ b/ttnn/cpp/ttnn/global_semaphore.hpp @@ -24,7 +24,9 @@ std::shared_ptr create_global_semaphore( tt::stl::Span sub_device_ids = {}); DeviceAddr get_global_semaphore_address(const std::shared_ptr& global_semaphore); void reset_global_semaphore_value( - const std::shared_ptr& global_semaphore, const std::vector& sub_device_ids = {}); + const std::shared_ptr& global_semaphore, + uint32_t reset_value, + tt::stl::Span sub_device_ids = {}); // Multi Device APIs MultiDeviceGlobalSemaphore create_global_semaphore( @@ -35,6 +37,8 @@ MultiDeviceGlobalSemaphore create_global_semaphore( tt::stl::Span sub_device_ids = {}); std::vector get_global_semaphore_address(const MultiDeviceGlobalSemaphore& global_semaphore); void reset_global_semaphore_value( - const MultiDeviceGlobalSemaphore& global_semaphore, const std::vector& sub_device_ids = {}); + const MultiDeviceGlobalSemaphore& global_semaphore, + uint32_t reset_value, + tt::stl::Span sub_device_ids = {}); } // namespace ttnn::global_semaphore