From 775b79955edf90bac8d39d2e34ed6f4eb738690e Mon Sep 17 00:00:00 2001
From: Austin Ho
Date: Thu, 2 Jan 2025 17:34:58 +0000
Subject: [PATCH] #0: Don't return shared ptrs of global sems/cbs, and directly
 return the object instead

Global sems/cbs are natively thread-safe now, so the user can decide whether
to use shared ptrs or not.
---
 .../api/test_global_circular_buffers.cpp      | 14 ++--
 .../tt_metal/api/test_global_semaphores.cpp   | 12 ++--
 .../dispatch_program/test_sub_device.cpp      |  2 +-
 .../dispatch/sub_device_test_utils.hpp        | 20 +++---
 .../test_dram_read_remote_cb.cpp              | 10 +--
 .../test_remote_cb_sync_matmul.cpp            |  6 +-
 tt_metal/host_api.hpp                         |  8 +--
 .../impl/buffers/global_circular_buffer.cpp   | 13 +---
 .../impl/buffers/global_circular_buffer.hpp   | 24 ++-----
 tt_metal/impl/buffers/global_semaphore.cpp    | 27 ++-----
 tt_metal/impl/buffers/global_semaphore.hpp    | 36 ++--------
 .../tt_metal/global_circular_buffer.hpp       |  4 +-
 tt_metal/tt_metal.cpp                         | 12 ++--
 ttnn/cpp/pybind11/global_semaphore.cpp        |  4 +-
 ttnn/cpp/ttnn/global_circular_buffer.cpp      | 25 ++-----
 ttnn/cpp/ttnn/global_circular_buffer.hpp      |  4 +-
 ttnn/cpp/ttnn/global_semaphore.cpp            | 45 +++---------
 ttnn/cpp/ttnn/global_semaphore.hpp            | 12 ++--
 .../all_gather_async_pybind.cpp               | 71 +++++++++----------
 .../device/all_gather_async_op.cpp            | 59 ++++++++-------
 .../device/all_gather_async_op.hpp            | 12 ++--
 .../device/all_gather_async_program.cpp       | 24 +++----
 .../device/reduce_scatter_async_op.cpp        | 22 +++---
 23 files changed, 179 insertions(+), 287 deletions(-)

diff --git a/tests/tt_metal/tt_metal/api/test_global_circular_buffers.cpp b/tests/tt_metal/tt_metal/api/test_global_circular_buffers.cpp
index f07f505ff0e..f3a865a8598 100644
--- a/tests/tt_metal/tt_metal/api/test_global_circular_buffers.cpp
+++ b/tests/tt_metal/tt_metal/api/test_global_circular_buffers.cpp
@@ -25,8 +25,8 @@ TEST_F(DispatchFixture, TensixCreateGlobalCircularBuffers) {
         sender_receiver_core_mapping[CoreCoord(0, 0)] = cores;
         auto global_cb = tt::tt_metal::v1::experimental::CreateGlobalCircularBuffer(
             device, sender_receiver_core_mapping, 3200, tt::tt_metal::BufferType::L1);
-        auto buffer_address = global_cb->buffer_address();
-        auto config_address = global_cb->config_address();
+        auto buffer_address = global_cb.buffer_address();
+        auto config_address = global_cb.config_address();
     }
     {
         std::unordered_map<CoreCoord, CoreRangeSet> sender_receiver_core_mapping;
@@ -84,14 +84,14 @@ TEST_F(DispatchFixture, TensixProgramGlobalCircularBuffers) {
         EXPECT_THROW(global_cb_config.remote_index(2), std::exception);
         EXPECT_THROW(
             tt::tt_metal::v1::experimental::CreateCircularBuffer(
-                program, CoreRangeSet(CoreRange({3, 3})), global_cb_config, *global_cb),
+                program, CoreRangeSet(CoreRange({3, 3})), global_cb_config, global_cb),
             std::exception);
         auto remote_cb =
-            tt::tt_metal::v1::experimental::CreateCircularBuffer(program, receiver_cores, global_cb_config, *global_cb);
+            tt::tt_metal::v1::experimental::CreateCircularBuffer(program, receiver_cores, global_cb_config, global_cb);
         tt::tt_metal::detail::CompileProgram(device, program);
         program.finalize(device);
-        tt::tt_metal::v1::experimental::UpdateDynamicCircularBufferAddress(program, remote_cb, *global_cb);
-        EXPECT_THROW(UpdateDynamicCircularBufferAddress(program, remote_cb, *dummy_global_cb), std::exception);
+        tt::tt_metal::v1::experimental::UpdateDynamicCircularBufferAddress(program, remote_cb, global_cb);
+        EXPECT_THROW(UpdateDynamicCircularBufferAddress(program, remote_cb, dummy_global_cb), std::exception);
     }
     {
         tt::tt_metal::Program program = CreateProgram();
@@ -107,7
+107,7 @@ TEST_F(DispatchFixture, TensixProgramGlobalCircularBuffers) { global_cb_config.remote_index(remote_cb_index).set_page_size(cb_page_size).set_data_format(tile_format); global_cb_config.index(local_cb_index).set_page_size(cb_page_size).set_data_format(tile_format); auto remote_cb = - tt::tt_metal::v1::experimental::CreateCircularBuffer(program, receiver_cores, global_cb_config, *global_cb); + tt::tt_metal::v1::experimental::CreateCircularBuffer(program, receiver_cores, global_cb_config, global_cb); tt::tt_metal::detail::CompileProgram(device, program); EXPECT_THROW(program.finalize(device), std::exception); } diff --git a/tests/tt_metal/tt_metal/api/test_global_semaphores.cpp b/tests/tt_metal/tt_metal/api/test_global_semaphores.cpp index 58a0f987353..c77427b3d22 100644 --- a/tests/tt_metal/tt_metal/api/test_global_semaphores.cpp +++ b/tests/tt_metal/tt_metal/api/test_global_semaphores.cpp @@ -20,7 +20,7 @@ TEST_F(DispatchFixture, InitializeGlobalSemaphores) { { uint32_t initial_value = 1; auto global_semaphore = tt::tt_metal::CreateGlobalSemaphore(device, cores, initial_value); - auto address = global_semaphore->address(); + auto address = global_semaphore.address(); Synchronize(device); for (const auto& core : cores_vec) { auto sem_vals = tt::llrt::read_hex_vec_from_core( @@ -32,7 +32,7 @@ TEST_F(DispatchFixture, InitializeGlobalSemaphores) { { uint32_t initial_value = 2; auto global_semaphore = tt::tt_metal::CreateGlobalSemaphore(device, cores, initial_value); - auto address = global_semaphore->address(); + auto address = global_semaphore.address(); Synchronize(device); for (const auto& core : cores_vec) { auto sem_vals = tt::llrt::read_hex_vec_from_core( @@ -53,13 +53,13 @@ TEST_F(DispatchFixture, CreateMultipleGlobalSemaphoresOnSameCore) { } for (auto device : devices_) { { - std::vector> global_semaphores; + std::vector global_semaphores; global_semaphores.reserve(cores.size()); std::vector addresses; addresses.reserve(cores.size()); for (size_t i = 0; i < cores.size(); i++) { global_semaphores.push_back(tt::tt_metal::CreateGlobalSemaphore(device, cores[i], initial_values[i])); - addresses.push_back(global_semaphores[i]->address()); + addresses.push_back(global_semaphores[i].address()); } Synchronize(device); for (size_t i = 0; i < cores.size(); i++) { @@ -89,7 +89,7 @@ TEST_F(DispatchFixture, ResetGlobalSemaphores) { uint32_t reset_value = 2; std::vector overwrite_value = {2}; auto global_semaphore = tt::tt_metal::CreateGlobalSemaphore(device, cores, initial_value); - auto address = global_semaphore->address(); + auto address = global_semaphore.address(); Synchronize(device); for (const auto& core : cores_vec) { auto sem_vals = tt::llrt::read_hex_vec_from_core( @@ -105,7 +105,7 @@ TEST_F(DispatchFixture, ResetGlobalSemaphores) { EXPECT_EQ(sem_vals[0], overwrite_value[0]); } - global_semaphore->reset_semaphore_value(reset_value); + global_semaphore.reset_semaphore_value(reset_value); Synchronize(device); for (const auto& core : cores_vec) { auto sem_vals = tt::llrt::read_hex_vec_from_core( diff --git a/tests/tt_metal/tt_metal/dispatch/dispatch_program/test_sub_device.cpp b/tests/tt_metal/tt_metal/dispatch/dispatch_program/test_sub_device.cpp index f140433f3a9..1426a4d0313 100644 --- a/tests/tt_metal/tt_metal/dispatch/dispatch_program/test_sub_device.cpp +++ b/tests/tt_metal/tt_metal/dispatch/dispatch_program/test_sub_device.cpp @@ -81,7 +81,7 @@ TEST_F(CommandQueueSingleCardFixture, TensixTestSubDeviceSynchronization) { EXPECT_TRUE(std::equal(input_1_it, input_1_it + 
page_size_1 / sizeof(uint32_t), readback.begin())); input_1_it += page_size_1 / sizeof(uint32_t); } - auto sem_addr = global_semaphore->address(); + auto sem_addr = global_semaphore.address(); auto physical_syncer_core = device->worker_core_from_logical_core(syncer_core); tt::llrt::write_hex_vec_to_core(device->id(), physical_syncer_core, std::vector{1}, sem_addr); diff --git a/tests/tt_metal/tt_metal/dispatch/sub_device_test_utils.hpp b/tests/tt_metal/tt_metal/dispatch/sub_device_test_utils.hpp index 54b77acedc1..f0bbd1fa900 100644 --- a/tests/tt_metal/tt_metal/dispatch/sub_device_test_utils.hpp +++ b/tests/tt_metal/tt_metal/dispatch/sub_device_test_utils.hpp @@ -9,7 +9,7 @@ // TODO: ARCH_NAME specific, must remove #include "eth_l1_address_map.h" -inline std::tuple> create_single_sync_program( +inline std::tuple create_single_sync_program( Device* device, SubDevice sub_device) { auto syncer_coord = sub_device.cores(HalProgrammableCoreType::TENSIX).ranges().at(0).start_coord; auto syncer_core = CoreRangeSet(CoreRange(syncer_coord, syncer_coord)); @@ -21,12 +21,12 @@ inline std::tuple> create_s "tests/tt_metal/tt_metal/test_kernels/misc/sub_device/syncer.cpp", syncer_core, DataMovementConfig{.processor = DataMovementProcessor::RISCV_0, .noc = NOC::RISCV_0_default}); - std::array syncer_rt_args = {global_sem->address()}; + std::array syncer_rt_args = {global_sem.address()}; SetRuntimeArgs(syncer_program, syncer_kernel, syncer_core, syncer_rt_args); return {std::move(syncer_program), std::move(syncer_coord), std::move(global_sem)}; } -inline std::tuple> create_basic_sync_program( +inline std::tuple create_basic_sync_program( Device* device, const SubDevice& sub_device_1, const SubDevice& sub_device_2) { auto waiter_coord = sub_device_2.cores(HalProgrammableCoreType::TENSIX).ranges().at(0).start_coord; auto waiter_core = CoreRangeSet(CoreRange(waiter_coord, waiter_coord)); @@ -45,7 +45,7 @@ inline std::tuple> c waiter_core, DataMovementConfig{.processor = DataMovementProcessor::RISCV_0, .noc = NOC::RISCV_0_default}); std::array waiter_rt_args = { - global_sem->address(), incrementer_cores.num_cores(), syncer_core_physical.x, syncer_core_physical.y}; + global_sem.address(), incrementer_cores.num_cores(), syncer_core_physical.x, syncer_core_physical.y}; SetRuntimeArgs(waiter_program, waiter_kernel, waiter_core, waiter_rt_args); Program syncer_program = CreateProgram(); @@ -54,7 +54,7 @@ inline std::tuple> c "tests/tt_metal/tt_metal/test_kernels/misc/sub_device/syncer.cpp", syncer_core, DataMovementConfig{.processor = DataMovementProcessor::RISCV_0, .noc = NOC::RISCV_0_default}); - std::array syncer_rt_args = {global_sem->address()}; + std::array syncer_rt_args = {global_sem.address()}; SetRuntimeArgs(syncer_program, syncer_kernel, syncer_core, syncer_rt_args); Program incrementer_program = CreateProgram(); @@ -64,13 +64,13 @@ inline std::tuple> c incrementer_cores, DataMovementConfig{.processor = DataMovementProcessor::RISCV_1, .noc = NOC::RISCV_1_default}); std::array incrementer_rt_args = { - global_sem->address(), waiter_core_physical.x, waiter_core_physical.y}; + global_sem.address(), waiter_core_physical.x, waiter_core_physical.y}; SetRuntimeArgs(incrementer_program, incrementer_kernel, incrementer_cores, incrementer_rt_args); return { std::move(waiter_program), std::move(syncer_program), std::move(incrementer_program), std::move(global_sem)}; } -inline std::tuple> create_basic_eth_sync_program( +inline std::tuple create_basic_eth_sync_program( Device* device, const SubDevice& sub_device_1, 
const SubDevice& sub_device_2) { auto waiter_coord = sub_device_2.cores(HalProgrammableCoreType::ACTIVE_ETH).ranges().at(0).start_coord; auto waiter_core = CoreRangeSet(CoreRange(waiter_coord, waiter_coord)); @@ -92,7 +92,7 @@ inline std::tuple> c waiter_core, EthernetConfig{.noc = NOC::RISCV_0_default, .processor = DataMovementProcessor::RISCV_0}); std::array waiter_rt_args = { - global_sem->address(), + global_sem.address(), incrementer_cores.num_cores(), syncer_core_physical.x, syncer_core_physical.y, @@ -107,7 +107,7 @@ inline std::tuple> c "tests/tt_metal/tt_metal/test_kernels/misc/sub_device/syncer.cpp", syncer_core, DataMovementConfig{.processor = DataMovementProcessor::RISCV_0, .noc = NOC::RISCV_0_default}); - std::array syncer_rt_args = {global_sem->address()}; + std::array syncer_rt_args = {global_sem.address()}; SetRuntimeArgs(syncer_program, syncer_kernel, syncer_core, syncer_rt_args); Program incrementer_program = CreateProgram(); @@ -117,7 +117,7 @@ inline std::tuple> c incrementer_cores, DataMovementConfig{.processor = DataMovementProcessor::RISCV_1, .noc = NOC::RISCV_1_default}); std::array incrementer_rt_args = { - global_sem->address(), tensix_waiter_core_physical.x, tensix_waiter_core_physical.y}; + global_sem.address(), tensix_waiter_core_physical.x, tensix_waiter_core_physical.y}; SetRuntimeArgs(incrementer_program, incrementer_kernel, incrementer_cores, incrementer_rt_args); return { std::move(waiter_program), std::move(syncer_program), std::move(incrementer_program), std::move(global_sem)}; diff --git a/tests/tt_metal/tt_metal/perf_microbenchmark/10_dram_read_remote_cb_sync/test_dram_read_remote_cb.cpp b/tests/tt_metal/tt_metal/perf_microbenchmark/10_dram_read_remote_cb_sync/test_dram_read_remote_cb.cpp index 5480b3144c0..41147c8f8a0 100644 --- a/tests/tt_metal/tt_metal/perf_microbenchmark/10_dram_read_remote_cb_sync/test_dram_read_remote_cb.cpp +++ b/tests/tt_metal/tt_metal/perf_microbenchmark/10_dram_read_remote_cb_sync/test_dram_read_remote_cb.cpp @@ -80,7 +80,7 @@ void get_max_page_size_and_num_pages( num_pages = total_size / page_size; } -std::tuple,std::shared_ptr> +std::tuple, tt_metal::v1::experimental::GlobalCircularBuffer> create_programs( tt_metal::Device* device, const CoreRangeSet& dram_reader_core, @@ -146,7 +146,7 @@ create_programs( tt_metal::CircularBufferConfig writer_cb_config = tt_metal::CircularBufferConfig(receiver_cb_size); writer_cb_config.remote_index(writer_cb_index).set_page_size(single_tile_size).set_data_format(tile_format); auto writer_cb = - tt_metal::v1::experimental::CreateCircularBuffer(sender_program, dram_reader_core, writer_cb_config, *global_cb); + tt_metal::v1::experimental::CreateCircularBuffer(sender_program, dram_reader_core, writer_cb_config, global_cb); // mixed cb dataformat uint32_t next_layer_num_blocks = num_blocks * 2; @@ -178,7 +178,7 @@ create_programs( tt_metal::CircularBufferConfig receiver_cb_config = tt_metal::CircularBufferConfig(receiver_cb_size); receiver_cb_config.remote_index(receiver_cb_index).set_page_size(single_tile_size).set_data_format(tile_format); auto receiver_cb = tt_metal::v1::experimental::CreateCircularBuffer( - receiver_program, l1_receiver_cores, receiver_cb_config, *global_cb); + receiver_program, l1_receiver_cores, receiver_cb_config, global_cb); log_info("reader_cb_size: {}", reader_cb_size); log_info("receiver_cb_size: {}", receiver_cb_size); @@ -846,7 +846,7 @@ int main(int argc, char** argv) { tt::DataFormat::Bfp8_b, l1_receiver_core, num_receivers, - global_cb->buffer_address()); + 
global_cb.buffer_address()); } else { // output @@ -860,7 +860,7 @@ int main(int argc, char** argv) { tt::DataFormat::Float16_b, l1_receiver_core, num_receivers, - global_cb->buffer_address()); + global_cb.buffer_address()); } //////////////////////////////////////////////////////////////////////////// diff --git a/tests/tt_metal/tt_metal/perf_microbenchmark/11_remote_cb_sync_matmul_single_core/test_remote_cb_sync_matmul.cpp b/tests/tt_metal/tt_metal/perf_microbenchmark/11_remote_cb_sync_matmul_single_core/test_remote_cb_sync_matmul.cpp index 1d467d9d47d..bbb31cc86e2 100644 --- a/tests/tt_metal/tt_metal/perf_microbenchmark/11_remote_cb_sync_matmul_single_core/test_remote_cb_sync_matmul.cpp +++ b/tests/tt_metal/tt_metal/perf_microbenchmark/11_remote_cb_sync_matmul_single_core/test_remote_cb_sync_matmul.cpp @@ -95,7 +95,7 @@ std::tuple get_out_subblock_params( return {1, 1}; } -std::tuple, std::shared_ptr> +std::tuple, ::tt_metal::v1::experimental::GlobalCircularBuffer> create_programs( tt_metal::Device* device, const CoreRangeSet& dram_reader_core, @@ -169,7 +169,7 @@ create_programs( tt_metal::CircularBufferConfig in1_writer_cb_config = tt_metal::CircularBufferConfig(in1_receiver_cb_size); in1_writer_cb_config.remote_index(in1_writer_cb_index).set_page_size(single_tile_size).set_data_format(tile_format); auto writer_cb = tt_metal::v1::experimental::CreateCircularBuffer( - sender_program, dram_reader_core, in1_writer_cb_config, *global_cb); + sender_program, dram_reader_core, in1_writer_cb_config, global_cb); // in0 reader CB uint32_t in0_reader_cb_index = 0; @@ -190,7 +190,7 @@ create_programs( .set_data_format(tile_format); in1_receiver_cb_config.index(in1_pusher_cb_index).set_page_size(single_tile_size).set_data_format(tile_format); auto in1_receiver_cb = tt_metal::v1::experimental::CreateCircularBuffer( - receiver_program, l1_receiver_cores, in1_receiver_cb_config, *global_cb); + receiver_program, l1_receiver_cores, in1_receiver_cb_config, global_cb); // output CB uint32_t output_cb_index = 16; diff --git a/tt_metal/host_api.hpp b/tt_metal/host_api.hpp index 8a9dfcb2003..76029bb9f69 100644 --- a/tt_metal/host_api.hpp +++ b/tt_metal/host_api.hpp @@ -297,7 +297,7 @@ uint32_t CreateSemaphore( * Initializes a global semaphore on all cores within the specified CoreRangeSet. * This only supports tensix cores, and can only use L1 buffer types like BufferType::L1 and BufferType::L1_SMALL. * - * Return value: std::shared_ptr + * Return value: GlobalSemaphore * * | Argument | Description | Type | Valid Range | Required | * |----------------|--------------------------------------------------------|-----------------------------------------------------------|--------------|----------| @@ -308,7 +308,7 @@ uint32_t CreateSemaphore( * | sub_device_ids | Sub-device ids to wait on before writing the semaphore | tt::stl::Span | | No | */ // clang-format on -std::shared_ptr CreateGlobalSemaphore( +GlobalSemaphore CreateGlobalSemaphore( Device* device, const CoreRangeSet& cores, uint32_t initial_value, @@ -320,7 +320,7 @@ std::shared_ptr CreateGlobalSemaphore( * Initializes a global semaphore on all cores within the specified CoreRangeSet. * This only supports tensix cores, and can only use L1 buffer types like BufferType::L1 and BufferType::L1_SMALL. 
* - * Return value: std::shared_ptr + * Return value: GlobalSemaphore * * | Argument | Description | Type | Valid Range | Required | * |----------------|--------------------------------------------------------|-----------------------------------------------------------|--------------|----------| @@ -331,7 +331,7 @@ std::shared_ptr CreateGlobalSemaphore( * | sub_device_ids | Sub-device ids to wait on before writing the semaphore | tt::stl::Span | | No | */ // clang-format on -std::shared_ptr CreateGlobalSemaphore( +GlobalSemaphore CreateGlobalSemaphore( Device* device, CoreRangeSet&& cores, uint32_t initial_value, diff --git a/tt_metal/impl/buffers/global_circular_buffer.cpp b/tt_metal/impl/buffers/global_circular_buffer.cpp index 02cc48b3a87..d5028b046b7 100644 --- a/tt_metal/impl/buffers/global_circular_buffer.cpp +++ b/tt_metal/impl/buffers/global_circular_buffer.cpp @@ -28,8 +28,7 @@ GlobalCircularBuffer::GlobalCircularBuffer( const std::unordered_map& sender_receiver_core_mapping, uint32_t size, BufferType buffer_type, - tt::stl::Span sub_device_ids, - Private) : + tt::stl::Span sub_device_ids) : device_(device), sender_receiver_core_mapping_(sender_receiver_core_mapping), size_(size) { TT_FATAL(this->device_ != nullptr, "Device cannot be null"); uint32_t num_sender_cores = sender_receiver_core_mapping.size(); @@ -148,16 +147,6 @@ void GlobalCircularBuffer::setup_cb_buffers( }); } -std::shared_ptr GlobalCircularBuffer::create( - Device* device, - const std::unordered_map& sender_receiver_core_mapping, - uint32_t size, - BufferType buffer_type, - tt::stl::Span sub_device_ids) { - return std::make_shared( - device, sender_receiver_core_mapping, size, buffer_type, sub_device_ids, Private()); -} - const Buffer& GlobalCircularBuffer::cb_buffer() const { return *this->cb_buffer_; } const CoreRangeSet& GlobalCircularBuffer::sender_cores() const { return this->sender_cores_; } diff --git a/tt_metal/impl/buffers/global_circular_buffer.hpp b/tt_metal/impl/buffers/global_circular_buffer.hpp index 96f8cfec73c..b7f609aa609 100644 --- a/tt_metal/impl/buffers/global_circular_buffer.hpp +++ b/tt_metal/impl/buffers/global_circular_buffer.hpp @@ -26,23 +26,19 @@ namespace v1 { namespace experimental { class GlobalCircularBuffer { - struct Private { - explicit Private() = default; - }; - public: - static std::shared_ptr create( + GlobalCircularBuffer( Device* device, const std::unordered_map& sender_receiver_core_mapping, uint32_t size, BufferType buffer_type = BufferType::L1, tt::stl::Span sub_device_ids = {}); - GlobalCircularBuffer(const GlobalCircularBuffer&) = delete; - GlobalCircularBuffer& operator=(const GlobalCircularBuffer&) = delete; + GlobalCircularBuffer(const GlobalCircularBuffer&) = default; + GlobalCircularBuffer& operator=(const GlobalCircularBuffer&) = default; - GlobalCircularBuffer(GlobalCircularBuffer&&) noexcept = delete; - GlobalCircularBuffer& operator=(GlobalCircularBuffer&&) noexcept = delete; + GlobalCircularBuffer(GlobalCircularBuffer&&) noexcept = default; + GlobalCircularBuffer& operator=(GlobalCircularBuffer&&) noexcept = default; const Buffer& cb_buffer() const; @@ -56,16 +52,6 @@ class GlobalCircularBuffer { static constexpr auto attribute_names = std::forward_as_tuple("sender_receiver_core_mapping", "size"); const auto attribute_values() const { return std::make_tuple(this->sender_receiver_core_mapping_, this->size_); } - // "Private" constructor to prevent direct instantiation - // Use GlobalCircularBuffer::create instead - GlobalCircularBuffer( - Device* device, - 
const std::unordered_map& sender_receiver_core_mapping, - uint32_t size, - BufferType buffer_type, - tt::stl::Span sub_device_ids, - Private); - private: void setup_cb_buffers( BufferType buffer_type, uint32_t max_num_receivers_per_sender, tt::stl::Span sub_device_ids); diff --git a/tt_metal/impl/buffers/global_semaphore.cpp b/tt_metal/impl/buffers/global_semaphore.cpp index af976bd1a07..1a26af28b84 100644 --- a/tt_metal/impl/buffers/global_semaphore.cpp +++ b/tt_metal/impl/buffers/global_semaphore.cpp @@ -24,8 +24,7 @@ GlobalSemaphore::GlobalSemaphore( const CoreRangeSet& cores, uint32_t initial_value, BufferType buffer_type, - tt::stl::Span sub_device_ids, - Private) : + tt::stl::Span sub_device_ids) : device_(device), cores_(cores) { this->setup_buffer(initial_value, buffer_type, sub_device_ids); } @@ -35,8 +34,7 @@ GlobalSemaphore::GlobalSemaphore( CoreRangeSet&& cores, uint32_t initial_value, BufferType buffer_type, - tt::stl::Span sub_device_ids, - Private) : + tt::stl::Span sub_device_ids) : device_(device), cores_(std::move(cores)) { this->setup_buffer(initial_value, buffer_type, sub_device_ids); } @@ -64,29 +62,12 @@ void GlobalSemaphore::setup_buffer( this->reset_semaphore_value(initial_value, sub_device_ids); } -std::shared_ptr GlobalSemaphore::create( - Device* device, - const CoreRangeSet& cores, - uint32_t initial_value, - BufferType buffer_type, - tt::stl::Span sub_device_ids) { - return std::make_shared(device, cores, initial_value, buffer_type, sub_device_ids, Private()); -} -std::shared_ptr GlobalSemaphore::create( - Device* device, - CoreRangeSet&& cores, - uint32_t initial_value, - BufferType buffer_type, - tt::stl::Span sub_device_ids) { - return std::make_shared( - device, std::move(cores), initial_value, buffer_type, sub_device_ids, Private()); -} - Device* GlobalSemaphore::device() const { return device_; } DeviceAddr GlobalSemaphore::address() const { return buffer_->address(); } -void GlobalSemaphore::reset_semaphore_value(uint32_t reset_value, tt::stl::Span sub_device_ids) { +void GlobalSemaphore::reset_semaphore_value( + uint32_t reset_value, tt::stl::Span sub_device_ids) const { // Write the initial value to the semaphore to the device // Only block for the slow dispatch case auto* device = this->device_; diff --git a/tt_metal/impl/buffers/global_semaphore.hpp b/tt_metal/impl/buffers/global_semaphore.hpp index 24e404a28e7..7753abd76b5 100644 --- a/tt_metal/impl/buffers/global_semaphore.hpp +++ b/tt_metal/impl/buffers/global_semaphore.hpp @@ -20,58 +20,36 @@ class Buffer; class Device; class GlobalSemaphore { - struct Private { - explicit Private() = default; - }; - public: - static std::shared_ptr create( + GlobalSemaphore( Device* device, const CoreRangeSet& cores, uint32_t initial_value, BufferType buffer_type = BufferType::L1, tt::stl::Span sub_device_ids = {}); - static std::shared_ptr create( + GlobalSemaphore( Device* device, CoreRangeSet&& cores, uint32_t initial_value, BufferType buffer_type = BufferType::L1, tt::stl::Span sub_device_ids = {}); - GlobalSemaphore(const GlobalSemaphore&) = delete; - GlobalSemaphore& operator=(const GlobalSemaphore&) = delete; + GlobalSemaphore(const GlobalSemaphore&) = default; + GlobalSemaphore& operator=(const GlobalSemaphore&) = default; - GlobalSemaphore(GlobalSemaphore&&) noexcept = delete; - GlobalSemaphore& operator=(GlobalSemaphore&&) noexcept = delete; + GlobalSemaphore(GlobalSemaphore&&) noexcept = default; + GlobalSemaphore& operator=(GlobalSemaphore&&) noexcept = default; Device* device() const; DeviceAddr 
address() const; - void reset_semaphore_value(uint32_t reset_value, tt::stl::Span sub_device_ids = {}); + void reset_semaphore_value(uint32_t reset_value, tt::stl::Span sub_device_ids = {}) const; static constexpr auto attribute_names = std::forward_as_tuple("cores"); const auto attribute_values() const { return std::make_tuple(this->cores_); } - // "Private" constructor to prevent direct instantiation - // Use GlobalSemaphore::create instead - GlobalSemaphore( - Device* device, - const CoreRangeSet& cores, - uint32_t initial_value, - BufferType buffer_type, - tt::stl::Span sub_device_ids, - Private); - - GlobalSemaphore( - Device* device, - CoreRangeSet&& cores, - uint32_t initial_value, - BufferType buffer_type, - tt::stl::Span sub_device_ids, - Private); - private: void setup_buffer(uint32_t initial_value, BufferType buffer_type, tt::stl::Span sub_device_ids); diff --git a/tt_metal/include/tt_metal/global_circular_buffer.hpp b/tt_metal/include/tt_metal/global_circular_buffer.hpp index 3c19ee7a07b..32dc5216e98 100644 --- a/tt_metal/include/tt_metal/global_circular_buffer.hpp +++ b/tt_metal/include/tt_metal/global_circular_buffer.hpp @@ -24,9 +24,9 @@ namespace experimental { * @param buffer_type Buffer type to store the global circular buffer. Can only be an L1 buffer type. * @param sub_device_ids Sub-device IDs to wait on before writing the global circular buffer config to device. Defaults * to waiting on all sub-devices. - * @return Handle to the allocated global circular buffer. + * @return The allocated global circular buffer. */ -std::shared_ptr CreateGlobalCircularBuffer( +GlobalCircularBuffer CreateGlobalCircularBuffer( Device* device, const std::unordered_map& sender_receiver_core_mapping, uint32_t size, diff --git a/tt_metal/tt_metal.cpp b/tt_metal/tt_metal.cpp index 38d17d6631f..78bd56775a8 100644 --- a/tt_metal/tt_metal.cpp +++ b/tt_metal/tt_metal.cpp @@ -1174,22 +1174,22 @@ uint32_t CreateSemaphore( core_spec); } -std::shared_ptr CreateGlobalSemaphore( +GlobalSemaphore CreateGlobalSemaphore( Device* device, const CoreRangeSet& cores, uint32_t initial_value, BufferType buffer_type, tt::stl::Span sub_device_ids) { - return GlobalSemaphore::create(device, cores, initial_value, buffer_type, sub_device_ids); + return GlobalSemaphore(device, cores, initial_value, buffer_type, sub_device_ids); } -std::shared_ptr CreateGlobalSemaphore( +GlobalSemaphore CreateGlobalSemaphore( Device* device, CoreRangeSet&& cores, uint32_t initial_value, BufferType buffer_type, tt::stl::Span sub_device_ids) { - return GlobalSemaphore::create(device, std::move(cores), initial_value, buffer_type, sub_device_ids); + return GlobalSemaphore(device, std::move(cores), initial_value, buffer_type, sub_device_ids); } std::shared_ptr CreateBuffer(const InterleavedBufferConfig& config) { @@ -1388,13 +1388,13 @@ namespace v1 { namespace experimental { -std::shared_ptr CreateGlobalCircularBuffer( +GlobalCircularBuffer CreateGlobalCircularBuffer( Device* device, const std::unordered_map& sender_receiver_core_mapping, uint32_t size, BufferType buffer_type, tt::stl::Span sub_device_ids) { - return GlobalCircularBuffer::create(device, sender_receiver_core_mapping, size, buffer_type, sub_device_ids); + return GlobalCircularBuffer(device, sender_receiver_core_mapping, size, buffer_type, sub_device_ids); } CBHandle CreateCircularBuffer( diff --git a/ttnn/cpp/pybind11/global_semaphore.cpp b/ttnn/cpp/pybind11/global_semaphore.cpp index 726c960e166..c86e03a8865 100644 --- a/ttnn/cpp/pybind11/global_semaphore.cpp +++ 
b/ttnn/cpp/pybind11/global_semaphore.cpp @@ -46,7 +46,7 @@ void py_module(py::module& module) { module.def( "get_global_semaphore_address", - py::overload_cast&>(&get_global_semaphore_address), + py::overload_cast(&get_global_semaphore_address), py::arg("global_semaphore"), R"doc( Get the address of the global semaphore. @@ -57,7 +57,7 @@ void py_module(py::module& module) { module.def( "reset_global_semaphore_value", - [](const std::shared_ptr& global_semaphore, + [](const GlobalSemaphore& global_semaphore, uint32_t reset_value, const std::vector& sub_device_ids) { ttnn::global_semaphore::reset_global_semaphore_value(global_semaphore, reset_value, sub_device_ids); diff --git a/ttnn/cpp/ttnn/global_circular_buffer.cpp b/ttnn/cpp/ttnn/global_circular_buffer.cpp index 76cc9df3e9f..ac68307e1b7 100644 --- a/ttnn/cpp/ttnn/global_circular_buffer.cpp +++ b/ttnn/cpp/ttnn/global_circular_buffer.cpp @@ -14,23 +14,17 @@ MultiDeviceGlobalCircularBuffer::MultiDeviceGlobalCircularBuffer(MeshDevice* mes TT_ASSERT( mesh_device != nullptr, "Must provide a valid mesh_device when initializing a global circular buffer on multiple devices."); - this->global_circular_buffers = std::vector>(mesh_device->num_devices()); + this->global_circular_buffers.reserve(mesh_device->num_devices()); } -std::shared_ptr create_global_circular_buffer( +GlobalCircularBuffer create_global_circular_buffer( Device* device, const std::unordered_map& sender_receiver_core_mapping, uint32_t size, BufferType buffer_type, tt::stl::Span sub_device_ids) { - std::shared_ptr global_cb; - device->push_work( - [device, &sender_receiver_core_mapping, size, buffer_type, sub_device_ids, &global_cb]() { - global_cb = tt::tt_metal::v1::experimental::CreateGlobalCircularBuffer( - device, sender_receiver_core_mapping, size, buffer_type, sub_device_ids); - }, - /*blocking=*/true); - return global_cb; + return tt::tt_metal::v1::experimental::CreateGlobalCircularBuffer( + device, sender_receiver_core_mapping, size, buffer_type, sub_device_ids); } MultiDeviceGlobalCircularBuffer create_global_circular_buffer( @@ -40,17 +34,12 @@ MultiDeviceGlobalCircularBuffer create_global_circular_buffer( BufferType buffer_type, tt::stl::Span sub_device_ids) { MultiDeviceGlobalCircularBuffer multi_device_global_cb(mesh_device); + auto& global_circular_buffers = multi_device_global_cb.global_circular_buffers; const auto& devices = mesh_device->get_devices(); for (uint32_t i = 0; i < devices.size(); ++i) { auto* device = devices[i]; - auto& global_cb = multi_device_global_cb.global_circular_buffers[i]; - device->push_work([device, &sender_receiver_core_mapping, size, buffer_type, sub_device_ids, &global_cb]() { - global_cb = tt::tt_metal::v1::experimental::CreateGlobalCircularBuffer( - device, sender_receiver_core_mapping, size, buffer_type, sub_device_ids); - }); - } - for (auto* device : devices) { - device->synchronize(); + global_circular_buffers.push_back( + create_global_circular_buffer(device, sender_receiver_core_mapping, size, buffer_type, sub_device_ids)); } return multi_device_global_cb; } diff --git a/ttnn/cpp/ttnn/global_circular_buffer.hpp b/ttnn/cpp/ttnn/global_circular_buffer.hpp index 39f18a0a63d..5f5cb6a7a83 100644 --- a/ttnn/cpp/ttnn/global_circular_buffer.hpp +++ b/ttnn/cpp/ttnn/global_circular_buffer.hpp @@ -12,11 +12,11 @@ namespace ttnn::global_circular_buffer { struct MultiDeviceGlobalCircularBuffer { MultiDeviceGlobalCircularBuffer(MeshDevice* mesh_device); - std::vector> global_circular_buffers; + std::vector global_circular_buffers; }; // 
Single Device APIs -std::shared_ptr create_global_circular_buffer( +GlobalCircularBuffer create_global_circular_buffer( Device* device, const std::unordered_map& sender_receiver_core_mapping, uint32_t size, diff --git a/ttnn/cpp/ttnn/global_semaphore.cpp b/ttnn/cpp/ttnn/global_semaphore.cpp index cb5c158fa62..3987dbbb875 100644 --- a/ttnn/cpp/ttnn/global_semaphore.cpp +++ b/ttnn/cpp/ttnn/global_semaphore.cpp @@ -15,39 +15,25 @@ MultiDeviceGlobalSemaphore::MultiDeviceGlobalSemaphore(MeshDevice* mesh_device) TT_ASSERT( mesh_device != nullptr, "Must provide a valid mesh_device when initializing a global semaphore on multiple devices."); - this->global_semaphores = std::vector>(mesh_device->num_devices()); + this->global_semaphores.reserve(mesh_device->num_devices()); } -std::shared_ptr create_global_semaphore( +GlobalSemaphore create_global_semaphore( Device* device, const CoreRangeSet& cores, uint32_t initial_value, BufferType buffer_type, tt::stl::Span sub_device_ids) { - std::shared_ptr global_semaphore = nullptr; - device->push_work( - [device, &cores, initial_value, buffer_type, sub_device_ids, &global_semaphore] { - global_semaphore = GlobalSemaphore::create(device, cores, initial_value, buffer_type, sub_device_ids); - }, - /*blocking=*/true); - return global_semaphore; + return CreateGlobalSemaphore(device, cores, initial_value, buffer_type, sub_device_ids); } -tt::tt_metal::DeviceAddr get_global_semaphore_address(const std::shared_ptr& global_semaphore) { - auto* device = global_semaphore->device(); - tt::tt_metal::DeviceAddr address = 0; - device->push_work([&global_semaphore, &address] { address = global_semaphore->address(); }, /*blocking=*/true); - return address; +tt::tt_metal::DeviceAddr get_global_semaphore_address(const GlobalSemaphore& global_semaphore) { + return global_semaphore.address(); } void reset_global_semaphore_value( - const std::shared_ptr& global_semaphore, - uint32_t reset_value, - tt::stl::Span sub_device_ids) { - auto* device = global_semaphore->device(); - device->push_work([global_semaphore, reset_value, sub_device_ids] { - global_semaphore->reset_semaphore_value(reset_value, sub_device_ids); - }); + const GlobalSemaphore& global_semaphore, uint32_t reset_value, tt::stl::Span sub_device_ids) { + global_semaphore.reset_semaphore_value(reset_value, sub_device_ids); } MultiDeviceGlobalSemaphore create_global_semaphore( @@ -57,16 +43,11 @@ MultiDeviceGlobalSemaphore create_global_semaphore( BufferType buffer_type, tt::stl::Span sub_device_ids) { MultiDeviceGlobalSemaphore multi_device_global_semaphore(mesh_device); + auto& global_semaphores = multi_device_global_semaphore.global_semaphores; const auto& devices = mesh_device->get_devices(); for (uint32_t i = 0; i < devices.size(); ++i) { auto* device = devices[i]; - auto& global_semaphore = multi_device_global_semaphore.global_semaphores[i]; - device->push_work([device, &cores, initial_value, buffer_type, sub_device_ids, &global_semaphore] { - global_semaphore = GlobalSemaphore::create(device, cores, initial_value, buffer_type, sub_device_ids); - }); - } - for (auto device : devices) { - device->synchronize(); + global_semaphores.push_back(create_global_semaphore(device, cores, initial_value, buffer_type, sub_device_ids)); } return multi_device_global_semaphore; } @@ -74,13 +55,7 @@ std::vector get_global_semaphore_address(const MultiDe std::vector addresses(global_semaphore.global_semaphores.size()); const auto& global_semaphores = global_semaphore.global_semaphores; for (uint32_t i = 0; i < 
global_semaphores.size(); ++i) { - const auto& global_semaphore = global_semaphores[i]; - auto& address = addresses[i]; - auto* device = global_semaphore->device(); - device->push_work([&global_semaphore, &address] { address = global_semaphore->address(); }); - } - for (const auto& global_semaphore : global_semaphores) { - global_semaphore->device()->synchronize(); + addresses[i] = get_global_semaphore_address(global_semaphores[i]); } return addresses; } diff --git a/ttnn/cpp/ttnn/global_semaphore.hpp b/ttnn/cpp/ttnn/global_semaphore.hpp index b9aca1bb2b3..bdd645641be 100644 --- a/ttnn/cpp/ttnn/global_semaphore.hpp +++ b/ttnn/cpp/ttnn/global_semaphore.hpp @@ -12,19 +12,21 @@ namespace ttnn::global_semaphore { struct MultiDeviceGlobalSemaphore { MultiDeviceGlobalSemaphore(MeshDevice* mesh_device); - std::vector> global_semaphores; + std::vector global_semaphores; }; // Single Device APIs -std::shared_ptr create_global_semaphore( +GlobalSemaphore create_global_semaphore( Device* device, const CoreRangeSet& cores, uint32_t initial_value, BufferType buffer_type = BufferType::L1, tt::stl::Span sub_device_ids = {}); -tt::tt_metal::DeviceAddr get_global_semaphore_address(const std::shared_ptr& global_semaphore); + +tt::tt_metal::DeviceAddr get_global_semaphore_address(const GlobalSemaphore& global_semaphore); + void reset_global_semaphore_value( - const std::shared_ptr& global_semaphore, + const GlobalSemaphore& global_semaphore, uint32_t reset_value, tt::stl::Span sub_device_ids = {}); @@ -35,7 +37,9 @@ MultiDeviceGlobalSemaphore create_global_semaphore( uint32_t initial_value, BufferType buffer_type = BufferType::L1, tt::stl::Span sub_device_ids = {}); + std::vector get_global_semaphore_address(const MultiDeviceGlobalSemaphore& global_semaphore); + void reset_global_semaphore_value( const MultiDeviceGlobalSemaphore& global_semaphore, uint32_t reset_value, diff --git a/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/all_gather_async_pybind.cpp b/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/all_gather_async_pybind.cpp index aae9cccf7f9..7917d3a9de2 100644 --- a/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/all_gather_async_pybind.cpp +++ b/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/all_gather_async_pybind.cpp @@ -54,42 +54,41 @@ void bind_all_gather_async(pybind11::module& module, const ccl_operation_t& oper py::arg("enable_persistent_fabric_mode") = false, py::arg("create_semaphore_handles") = true}, - ttnn:: - pybind_overload_t{ - [](const ccl_operation_t& self, - const ttnn::Tensor& input_tensor, - const int32_t dim, - const uint32_t cluster_axis, - const MeshDevice& mesh_device, - const ttnn::ccl::Topology topology, - const std::optional num_preferred_links, - const std::optional& memory_config, - std::optional subdevice_id, - bool enable_persistent_fabric_mode, - bool create_semaphore_handles) -> ttnn::Tensor { - return self( - input_tensor, - dim, - cluster_axis, - mesh_device, - topology, - memory_config,// = std::nullopt, - num_preferred_links,// = std::nullopt, - subdevice_id,// = std::nullopt, - enable_persistent_fabric_mode,// = false, - create_semaphore_handles); - }, - py::arg("input_tensor"), - py::arg("dim"), - py::arg("cluster_axis"), - py::arg("mesh_device"), - py::arg("topology"), - py::kw_only(), - py::arg("num_links") = std::nullopt, - py::arg("memory_config") = std::nullopt, - py::arg("subdevice_id") = std::nullopt, - py::arg("enable_persistent_fabric_mode") = false, - py::arg("create_semaphore_handles") = true}); + 
ttnn::pybind_overload_t{ + [](const ccl_operation_t& self, + const ttnn::Tensor& input_tensor, + const int32_t dim, + const uint32_t cluster_axis, + const MeshDevice& mesh_device, + const ttnn::ccl::Topology topology, + const std::optional num_preferred_links, + const std::optional& memory_config, + std::optional subdevice_id, + bool enable_persistent_fabric_mode, + bool create_semaphore_handles) -> ttnn::Tensor { + return self( + input_tensor, + dim, + cluster_axis, + mesh_device, + topology, + memory_config, // = std::nullopt, + num_preferred_links, // = std::nullopt, + subdevice_id, // = std::nullopt, + enable_persistent_fabric_mode, // = false, + create_semaphore_handles); + }, + py::arg("input_tensor"), + py::arg("dim"), + py::arg("cluster_axis"), + py::arg("mesh_device"), + py::arg("topology"), + py::kw_only(), + py::arg("num_links") = std::nullopt, + py::arg("memory_config") = std::nullopt, + py::arg("subdevice_id") = std::nullopt, + py::arg("enable_persistent_fabric_mode") = false, + py::arg("create_semaphore_handles") = true}); } } // namespace detail diff --git a/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/device/all_gather_async_op.cpp b/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/device/all_gather_async_op.cpp index 44b4033cee1..5815d33c3b1 100644 --- a/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/device/all_gather_async_op.cpp +++ b/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/device/all_gather_async_op.cpp @@ -1,10 +1,10 @@ -/// SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC +// SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC // // SPDX-License-Identifier: Apache-2.0 #include "all_gather_async_op.hpp" #include "ttnn/operations/math.hpp" -#include "tt_metal/impl/buffers/global_semaphore.hpp" +#include "ttnn/cpp/ttnn/global_semaphore.hpp" #include "tt_metal/host_api.hpp" @@ -23,19 +23,19 @@ AllGatherAsync create_all_gather_async_struct( const std::optional& memory_config, const std::vector& devices, const ttnn::ccl::Topology topology, - const std::optional>>& semaphore_handles, + const std::optional>& semaphores, bool enable_persistent_fabric_mode) { uint32_t num_devices = devices.size(); std::optional forward_device = std::nullopt; std::optional backward_device = std::nullopt; - std::shared_ptr semaphore_handle = nullptr; + std::optional semaphore = std::nullopt; uint32_t device_index = 0; // Initialize device index for (uint32_t i = 0; i < num_devices; ++i) { if (devices.at(i) == input_tensor.device()) { device_index = i; - if (semaphore_handles.has_value()) { - semaphore_handle = semaphore_handles.value().at(i); // Get raw pointer + if (semaphores.has_value()) { + semaphore = semaphores.value().at(i); // Get raw pointer } if (i != 0) { backward_device = devices.at(i - 1); @@ -55,40 +55,40 @@ AllGatherAsync create_all_gather_async_struct( device_index, memory_config.value_or(input_tensor.memory_config()), topology, - semaphore_handle, + semaphore, enable_persistent_fabric_mode}; } -std::optional>> get_global_semaphores( +std::optional> get_global_semaphores( const std::vector& devices, const CoreRange& core_range, std::optional subdevice_id, bool create_semaphore_handles) { - std::optional>> semaphore_handles_opt; + std::optional> semaphores_opt; if (create_semaphore_handles) { - std::vector> semaphore_handles; + std::vector semaphores; for (const auto& device : devices) { auto worker_subdevice_id = subdevice_id.has_value() ? 
std::vector{subdevice_id.value()} : std::vector{}; - auto handle = GlobalSemaphore::create(device, core_range, 0, BufferType::L1, worker_subdevice_id); - log_trace( - tt::LogOp, "Created semaphore handle at address {} for device {}", handle->address(), device->id()); - semaphore_handles.push_back(handle); + auto sem = + global_semaphore::create_global_semaphore(device, core_range, 0, BufferType::L1, worker_subdevice_id); + log_trace(tt::LogOp, "Created semaphore at address {} for device {}", sem.address(), device->id()); + semaphores.push_back(std::move(sem)); } - // HACK: assert every handle address is the same + // HACK: assert every address is the same TT_FATAL( std::all_of( - semaphore_handles.begin(), - semaphore_handles.end(), - [&](const auto& handle) { return handle->address() == semaphore_handles.front()->address(); }), - "[Hack] All semaphore handles should have the same address"); - semaphore_handles_opt = semaphore_handles; + semaphores.begin(), + semaphores.end(), + [&](const auto& sem) { return sem.address() == semaphores.front().address(); }), + "[Hack] All semaphores should have the same address"); + semaphores_opt = std::move(semaphores); } else { - semaphore_handles_opt = std::nullopt; + semaphores_opt = std::nullopt; } - return semaphore_handles_opt; + return semaphores_opt; } } // namespace all_gather_detail @@ -174,7 +174,7 @@ operation::ProgramWithCallbacks AllGatherAsync::create_program( this->ring_size, this->ring_index, this->topology, - this->semaphore_handle, + this->semaphore, this->enable_persistent_fabric_mode); } @@ -219,7 +219,7 @@ Tensor all_gather_async( CoreCoord grid_size = devices[0]->compute_with_storage_grid_size(); auto core_grid = CoreRange({0, 0}, {grid_size.x - 1, grid_size.y - 1}); - std::optional>> semaphore_handles_opt = + std::optional> semaphores_opt = ttnn::ccl::all_gather_detail::get_global_semaphores(devices, core_grid, subdevice_id, create_semaphore_handles); operation::launch_op( @@ -229,7 +229,7 @@ Tensor all_gather_async( memory_config, devices, ccl_topology, - semaphore_handles_opt, + semaphores_opt, enable_persistent_fabric_mode]( const std::vector& input_tensors, const std::vector>& optional_input_tensors, @@ -244,7 +244,7 @@ Tensor all_gather_async( memory_config, devices, ccl_topology, - semaphore_handles_opt, + semaphores_opt, enable_persistent_fabric_mode), {input_tensor}); }, @@ -285,7 +285,7 @@ Tensor all_gather_async( std::vector output_tensors = {Tensor(operation::get_workers_for_op_output({input_tensor}))}; CoreCoord grid_size = devices[0]->compute_with_storage_grid_size(); auto core_grid = CoreRange({0, 0}, {grid_size.x - 1, grid_size.y - 1}); - std::optional>> semaphore_handles_opt = + std::optional> semaphores_opt = ttnn::ccl::all_gather_detail::get_global_semaphores(devices, core_grid, subdevice_id, create_semaphore_handles); operation::launch_op( @@ -296,7 +296,7 @@ Tensor all_gather_async( cluster_axis, num_devices, topology, - semaphore_handles_opt, + semaphores_opt, enable_persistent_fabric_mode]( const std::vector& input_tensors, const std::vector>& optional_input_tensors, @@ -317,10 +317,9 @@ Tensor all_gather_async( memory_config, devices, topology, - semaphore_handles_opt, + semaphores_opt, enable_persistent_fabric_mode), {input_tensor}); - }, {input_tensor}, output_tensors); diff --git a/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/device/all_gather_async_op.hpp b/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/device/all_gather_async_op.hpp index b5bc4095f2f..d8b7a9c6648 100644 --- 
a/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/device/all_gather_async_op.hpp +++ b/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/device/all_gather_async_op.hpp @@ -34,7 +34,7 @@ struct AllGatherAsync { const uint32_t ring_index; const MemoryConfig output_mem_config; const ccl::Topology topology; - std::optional> semaphore_handle; + const std::optional semaphore; bool enable_persistent_fabric_mode; AllGatherAsync( @@ -46,7 +46,7 @@ struct AllGatherAsync { uint32_t ring_index, MemoryConfig output_mem_config, ccl::Topology topology, - std::optional> semaphore_handle, + std::optional semaphore, bool enable_persistent_fabric_mode) : forward_device(forward_device), backward_device(backward_device), @@ -56,7 +56,7 @@ struct AllGatherAsync { ring_index(ring_index), output_mem_config(output_mem_config), topology(topology), - semaphore_handle(semaphore_handle), + semaphore(semaphore), enable_persistent_fabric_mode(enable_persistent_fabric_mode) {} // Add attributes method for reflection @@ -70,7 +70,7 @@ struct AllGatherAsync { attrs.emplace_back("ring_index", ring_index); attrs.emplace_back("output_mem_config", output_mem_config); attrs.emplace_back("topology", topology); - attrs.emplace_back("semaphore_handle", semaphore_handle.has_value() ? semaphore_handle.value().get() : nullptr); + attrs.emplace_back("semaphore", semaphore); return attrs; } @@ -92,7 +92,7 @@ AllGatherAsync create_all_gather_async_struct( const std::optional& memory_config, const std::vector& devices, const ccl::Topology topology, - const std::optional>& semaphore_handles, + const std::optional>& semaphores, bool enable_persistent_fabric_mode); } // namespace all_gather_async_detail } // namespace ccl @@ -108,7 +108,7 @@ operation::ProgramWithCallbacks all_gather_async_multi_core_with_workers( const uint32_t ring_size, const uint32_t ring_index, ccl::Topology topology, - const std::optional>& semaphore_handle_opt, + const std::optional& semaphore_opt, bool enable_persistent_fabric_mode); namespace operations { diff --git a/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/device/all_gather_async_program.cpp b/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/device/all_gather_async_program.cpp index dc83794cdb8..e9420f5e62a 100644 --- a/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/device/all_gather_async_program.cpp +++ b/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/device/all_gather_async_program.cpp @@ -130,14 +130,14 @@ operation::ProgramWithCallbacks all_gather_async_multi_core_with_workers( const uint32_t ring_size, const uint32_t ring_index, ccl::Topology topology, - const std::optional>& semaphore_handle_opt, + const std::optional& semaphore_opt, bool enable_persistent_fabric_mode) { tt::tt_metal::Program program{}; const bool enable_async_output_tensor = false; - TT_FATAL(semaphore_handle_opt.has_value(), "Semaphore handle is required for compile time"); + TT_FATAL(semaphore_opt.has_value(), "Semaphore is required for compile time"); - auto semaphore_handle = semaphore_handle_opt.value(); + const auto& semaphore = semaphore_opt.value(); Device* device = input_tensor.device(); bool is_first_chip = ring_index == 0; @@ -320,21 +320,13 @@ operation::ProgramWithCallbacks all_gather_async_multi_core_with_workers( writer_cmd_stream.push_back(ttnn::ccl::cmd::uops::fabric_write_cb_to_tensor_slice( output_worker_slice_v2, src0_cb_index, mcast_dest_args)); // 2, mcast the semaphore to all dest for teardown - TT_FATAL( - semaphore_handle != nullptr, - 
"Internal error during all-=gather fatcory. Global semaphore for fabric teardown not properly " - "initialized for non-persistent fabric mode"); writer_cmd_stream.push_back(ttnn::ccl::cmd::uops::fabric_multicast_semaphore_inc( - semaphore_handle.get(), - ttnn::ccl::cmd::CclCommandAtomicInc{1}, - drain_sync_core.x, - drain_sync_core.y, - mcast_dest_args)); + &semaphore, ttnn::ccl::cmd::CclCommandAtomicInc{1}, drain_sync_core.x, drain_sync_core.y, mcast_dest_args)); if (!enable_async_output_tensor) { // 3, wait for n_chip*num_links number of semaphore at teardown semaphore address for first chip, and // n_chip*num_links+1 for other chips writer_cmd_stream.push_back(ttnn::ccl::cmd::uops::local_semaphore_wait( - semaphore_handle.get(), + &semaphore, is_first_chip ? ring_size * num_links : ring_size * num_links + !enable_persistent_fabric_mode)); } @@ -343,7 +335,7 @@ operation::ProgramWithCallbacks all_gather_async_multi_core_with_workers( // 4, send semaphore unicast to forward device except for the last chip if (!is_last_chip) { writer_cmd_stream.push_back(ttnn::ccl::cmd::uops::fabric_unicast_semaphore_inc( - semaphore_handle.get(), + &semaphore, ttnn::ccl::cmd::CclCommandAtomicInc{1}, drain_sync_core.x, drain_sync_core.y, @@ -359,7 +351,7 @@ operation::ProgramWithCallbacks all_gather_async_multi_core_with_workers( info.edm_noc_x, info.edm_noc_y, info.termination_addr, 1)); } // 6. (drain sync core) reset semaphore to 0 - writer_cmd_stream.push_back(ttnn::ccl::cmd::uops::local_core_semaphore_set(semaphore_handle.get(), 0)); + writer_cmd_stream.push_back(ttnn::ccl::cmd::uops::local_core_semaphore_set(&semaphore, 0)); } // set the rt args @@ -382,7 +374,7 @@ operation::ProgramWithCallbacks all_gather_async_multi_core_with_workers( } auto override_runtime_arguments_callback = - [worker_sender_reader_kernel_id, worker_sender_writer_kernel_id, semaphore_handle, sender_worker_cores]( + [worker_sender_reader_kernel_id, worker_sender_writer_kernel_id, semaphore, sender_worker_cores]( const void* operation, Program& program, const std::vector& input_tensors, diff --git a/ttnn/cpp/ttnn/operations/experimental/ccl/reduce_scatter_async/device/reduce_scatter_async_op.cpp b/ttnn/cpp/ttnn/operations/experimental/ccl/reduce_scatter_async/device/reduce_scatter_async_op.cpp index 1f9e4dfac3e..8092a445b0c 100644 --- a/ttnn/cpp/ttnn/operations/experimental/ccl/reduce_scatter_async/device/reduce_scatter_async_op.cpp +++ b/ttnn/cpp/ttnn/operations/experimental/ccl/reduce_scatter_async/device/reduce_scatter_async_op.cpp @@ -5,7 +5,7 @@ #include "ttnn/cpp/ttnn/operations/experimental/ccl/reduce_scatter_async/device/reduce_scatter_async_op.hpp" #include "sub_device/sub_device_types.hpp" #include "tt_metal/host_api.hpp" -#include "tt_metal/impl/buffers/global_semaphore.hpp" +#include "ttnn/cpp/ttnn/global_semaphore.hpp" #include #include @@ -230,7 +230,9 @@ std::vector> create_global_ auto worker_subdevice_id = worker_subdevice_id_opt.has_value() ? std::vector{worker_subdevice_id_opt.value()} : std::vector{}; - auto sem = CreateGlobalSemaphore(d, core_grid, 0, BufferType::L1, worker_subdevice_id); + // TODO: Remove shared_ptr + auto sem = std::make_shared( + global_semaphore::create_global_semaphore(d, core_grid, 0, BufferType::L1, worker_subdevice_id)); semaphores.push_back(sem); } @@ -255,7 +257,9 @@ std::vector> create_global_ auto worker_subdevice_id = worker_subdevice_id_opt.has_value() ? 
std::vector{worker_subdevice_id_opt.value()} : std::vector{}; - auto sem = CreateGlobalSemaphore(devices[i], core_grid, 0, BufferType::L1, worker_subdevice_id); + // TODO: Remove shared_ptr + auto sem = std::make_shared(global_semaphore::create_global_semaphore( + devices[i], core_grid, 0, BufferType::L1, worker_subdevice_id)); if (sem->address() == highest_addr) { semaphores[i] = sem; } else { @@ -310,10 +314,8 @@ Tensor reduce_scatter( std::optional>> from_remote_inputs_semaphores_opt; std::optional>> to_remote_inputs_semaphores_opt; if (create_semaphore_handles) { - const auto from_remote_inputs_semaphores = create_global_semaphores(devices, worker_subdevice_id_opt); - const auto to_remote_inputs_semaphores = create_global_semaphores(devices, worker_subdevice_id_opt); - from_remote_inputs_semaphores_opt = from_remote_inputs_semaphores; - to_remote_inputs_semaphores_opt = to_remote_inputs_semaphores; + from_remote_inputs_semaphores_opt = create_global_semaphores(devices, worker_subdevice_id_opt); + to_remote_inputs_semaphores_opt = create_global_semaphores(devices, worker_subdevice_id_opt); } else { from_remote_inputs_semaphores_opt = std::nullopt; to_remote_inputs_semaphores_opt = std::nullopt; @@ -389,10 +391,8 @@ Tensor reduce_scatter( std::optional>> from_remote_inputs_semaphores_opt; std::optional>> to_remote_inputs_semaphores_opt; if (create_semaphore_handles) { - const auto from_remote_inputs_semaphores = create_global_semaphores(devices, worker_subdevice_id_opt); - const auto to_remote_inputs_semaphores = create_global_semaphores(devices, worker_subdevice_id_opt); - from_remote_inputs_semaphores_opt = from_remote_inputs_semaphores; - to_remote_inputs_semaphores_opt = to_remote_inputs_semaphores; + from_remote_inputs_semaphores_opt = create_global_semaphores(devices, worker_subdevice_id_opt); + to_remote_inputs_semaphores_opt = create_global_semaphores(devices, worker_subdevice_id_opt); } else { from_remote_inputs_semaphores_opt = std::nullopt; to_remote_inputs_semaphores_opt = std::nullopt;
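
A minimal host-side sketch of the API after this change, for reference; it is
not part of the patch. The signatures follow the headers touched above, but the
device setup (CreateDevice/CloseDevice), device id, core ranges, sizes, and the
include paths are illustrative assumptions.

// Sketch: by-value GlobalSemaphore / GlobalCircularBuffer usage.
#include <memory>
#include <unordered_map>

#include "tt_metal/host_api.hpp"
#include "tt_metal/include/tt_metal/global_circular_buffer.hpp"

using namespace tt::tt_metal;

int main() {
    Device* device = CreateDevice(/*device_id=*/0);  // illustrative setup

    // CreateGlobalSemaphore now returns a GlobalSemaphore by value rather than
    // a std::shared_ptr<GlobalSemaphore>, so members are accessed with '.'.
    CoreRangeSet cores(CoreRange({0, 0}, {1, 1}));
    GlobalSemaphore sem = CreateGlobalSemaphore(device, cores, /*initial_value=*/0);
    DeviceAddr sem_addr = sem.address();
    sem.reset_semaphore_value(/*reset_value=*/1);  // now a const member

    // GlobalCircularBuffer follows the same pattern: copy/move are defaulted,
    // with the underlying device buffer shared internally.
    std::unordered_map<CoreCoord, CoreRangeSet> sender_receiver_core_mapping;
    sender_receiver_core_mapping[CoreCoord(0, 0)] = CoreRangeSet(CoreRange({1, 1}, {2, 2}));
    auto global_cb = v1::experimental::CreateGlobalCircularBuffer(
        device, sender_receiver_core_mapping, /*size=*/3200, BufferType::L1);
    auto cb_addr = global_cb.buffer_address();

    // Callers that still want shared ownership opt in explicitly, as the
    // reduce_scatter path above does during its migration off shared_ptr.
    auto shared_sem = std::make_shared<GlobalSemaphore>(std::move(sem));

    (void)sem_addr;
    (void)cb_addr;
    (void)shared_sem;
    CloseDevice(device);
    return 0;
}

Returning by value works here because both classes keep their device storage
behind an internal shared buffer, which is what makes the defaulted copy and
move operations in this patch safe.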