Skip to content

Commit

Permalink
#0: Don't return shared ptrs of global sems/cbs, and directly return …
Browse files Browse the repository at this point in the history
…the object instead

global sems/cbs are natively thread safe now, so user can decide whether to use shared ptrs or not
  • Loading branch information
tt-aho committed Jan 2, 2025
1 parent 83f816b commit 775b799
Show file tree
Hide file tree
Showing 23 changed files with 179 additions and 287 deletions.
14 changes: 7 additions & 7 deletions tests/tt_metal/tt_metal/api/test_global_circular_buffers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@ TEST_F(DispatchFixture, TensixCreateGlobalCircularBuffers) {
sender_receiver_core_mapping[CoreCoord(0, 0)] = cores;
auto global_cb = tt::tt_metal::v1::experimental::CreateGlobalCircularBuffer(
device, sender_receiver_core_mapping, 3200, tt::tt_metal::BufferType::L1);
auto buffer_address = global_cb->buffer_address();
auto config_address = global_cb->config_address();
auto buffer_address = global_cb.buffer_address();
auto config_address = global_cb.config_address();
}
{
std::unordered_map<CoreCoord, CoreRangeSet> sender_receiver_core_mapping;
Expand Down Expand Up @@ -84,14 +84,14 @@ TEST_F(DispatchFixture, TensixProgramGlobalCircularBuffers) {
EXPECT_THROW(global_cb_config.remote_index(2), std::exception);
EXPECT_THROW(
tt::tt_metal::v1::experimental::CreateCircularBuffer(
program, CoreRangeSet(CoreRange({3, 3})), global_cb_config, *global_cb),
program, CoreRangeSet(CoreRange({3, 3})), global_cb_config, global_cb),
std::exception);
auto remote_cb =
tt::tt_metal::v1::experimental::CreateCircularBuffer(program, receiver_cores, global_cb_config, *global_cb);
tt::tt_metal::v1::experimental::CreateCircularBuffer(program, receiver_cores, global_cb_config, global_cb);
tt::tt_metal::detail::CompileProgram(device, program);
program.finalize(device);
tt::tt_metal::v1::experimental::UpdateDynamicCircularBufferAddress(program, remote_cb, *global_cb);
EXPECT_THROW(UpdateDynamicCircularBufferAddress(program, remote_cb, *dummy_global_cb), std::exception);
tt::tt_metal::v1::experimental::UpdateDynamicCircularBufferAddress(program, remote_cb, global_cb);
EXPECT_THROW(UpdateDynamicCircularBufferAddress(program, remote_cb, dummy_global_cb), std::exception);
}
{
tt::tt_metal::Program program = CreateProgram();
Expand All @@ -107,7 +107,7 @@ TEST_F(DispatchFixture, TensixProgramGlobalCircularBuffers) {
global_cb_config.remote_index(remote_cb_index).set_page_size(cb_page_size).set_data_format(tile_format);
global_cb_config.index(local_cb_index).set_page_size(cb_page_size).set_data_format(tile_format);
auto remote_cb =
tt::tt_metal::v1::experimental::CreateCircularBuffer(program, receiver_cores, global_cb_config, *global_cb);
tt::tt_metal::v1::experimental::CreateCircularBuffer(program, receiver_cores, global_cb_config, global_cb);
tt::tt_metal::detail::CompileProgram(device, program);
EXPECT_THROW(program.finalize(device), std::exception);
}
Expand Down
12 changes: 6 additions & 6 deletions tests/tt_metal/tt_metal/api/test_global_semaphores.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ TEST_F(DispatchFixture, InitializeGlobalSemaphores) {
{
uint32_t initial_value = 1;
auto global_semaphore = tt::tt_metal::CreateGlobalSemaphore(device, cores, initial_value);
auto address = global_semaphore->address();
auto address = global_semaphore.address();
Synchronize(device);
for (const auto& core : cores_vec) {
auto sem_vals = tt::llrt::read_hex_vec_from_core(
Expand All @@ -32,7 +32,7 @@ TEST_F(DispatchFixture, InitializeGlobalSemaphores) {
{
uint32_t initial_value = 2;
auto global_semaphore = tt::tt_metal::CreateGlobalSemaphore(device, cores, initial_value);
auto address = global_semaphore->address();
auto address = global_semaphore.address();
Synchronize(device);
for (const auto& core : cores_vec) {
auto sem_vals = tt::llrt::read_hex_vec_from_core(
Expand All @@ -53,13 +53,13 @@ TEST_F(DispatchFixture, CreateMultipleGlobalSemaphoresOnSameCore) {
}
for (auto device : devices_) {
{
std::vector<std::shared_ptr<tt::tt_metal::GlobalSemaphore>> global_semaphores;
std::vector<tt::tt_metal::GlobalSemaphore> global_semaphores;
global_semaphores.reserve(cores.size());
std::vector<DeviceAddr> addresses;
addresses.reserve(cores.size());
for (size_t i = 0; i < cores.size(); i++) {
global_semaphores.push_back(tt::tt_metal::CreateGlobalSemaphore(device, cores[i], initial_values[i]));
addresses.push_back(global_semaphores[i]->address());
addresses.push_back(global_semaphores[i].address());
}
Synchronize(device);
for (size_t i = 0; i < cores.size(); i++) {
Expand Down Expand Up @@ -89,7 +89,7 @@ TEST_F(DispatchFixture, ResetGlobalSemaphores) {
uint32_t reset_value = 2;
std::vector<uint32_t> overwrite_value = {2};
auto global_semaphore = tt::tt_metal::CreateGlobalSemaphore(device, cores, initial_value);
auto address = global_semaphore->address();
auto address = global_semaphore.address();
Synchronize(device);
for (const auto& core : cores_vec) {
auto sem_vals = tt::llrt::read_hex_vec_from_core(
Expand All @@ -105,7 +105,7 @@ TEST_F(DispatchFixture, ResetGlobalSemaphores) {

EXPECT_EQ(sem_vals[0], overwrite_value[0]);
}
global_semaphore->reset_semaphore_value(reset_value);
global_semaphore.reset_semaphore_value(reset_value);
Synchronize(device);
for (const auto& core : cores_vec) {
auto sem_vals = tt::llrt::read_hex_vec_from_core(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ TEST_F(CommandQueueSingleCardFixture, TensixTestSubDeviceSynchronization) {
EXPECT_TRUE(std::equal(input_1_it, input_1_it + page_size_1 / sizeof(uint32_t), readback.begin()));
input_1_it += page_size_1 / sizeof(uint32_t);
}
auto sem_addr = global_semaphore->address();
auto sem_addr = global_semaphore.address();
auto physical_syncer_core = device->worker_core_from_logical_core(syncer_core);
tt::llrt::write_hex_vec_to_core(device->id(), physical_syncer_core, std::vector<uint32_t>{1}, sem_addr);

Expand Down
20 changes: 10 additions & 10 deletions tests/tt_metal/tt_metal/dispatch/sub_device_test_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
// TODO: ARCH_NAME specific, must remove
#include "eth_l1_address_map.h"

inline std::tuple<Program, CoreCoord, std::shared_ptr<GlobalSemaphore>> create_single_sync_program(
inline std::tuple<Program, CoreCoord, GlobalSemaphore> create_single_sync_program(
Device* device, SubDevice sub_device) {
auto syncer_coord = sub_device.cores(HalProgrammableCoreType::TENSIX).ranges().at(0).start_coord;
auto syncer_core = CoreRangeSet(CoreRange(syncer_coord, syncer_coord));
Expand All @@ -21,12 +21,12 @@ inline std::tuple<Program, CoreCoord, std::shared_ptr<GlobalSemaphore>> create_s
"tests/tt_metal/tt_metal/test_kernels/misc/sub_device/syncer.cpp",
syncer_core,
DataMovementConfig{.processor = DataMovementProcessor::RISCV_0, .noc = NOC::RISCV_0_default});
std::array<uint32_t, 1> syncer_rt_args = {global_sem->address()};
std::array<uint32_t, 1> syncer_rt_args = {global_sem.address()};
SetRuntimeArgs(syncer_program, syncer_kernel, syncer_core, syncer_rt_args);
return {std::move(syncer_program), std::move(syncer_coord), std::move(global_sem)};
}

inline std::tuple<Program, Program, Program, std::shared_ptr<GlobalSemaphore>> create_basic_sync_program(
inline std::tuple<Program, Program, Program, GlobalSemaphore> create_basic_sync_program(
Device* device, const SubDevice& sub_device_1, const SubDevice& sub_device_2) {
auto waiter_coord = sub_device_2.cores(HalProgrammableCoreType::TENSIX).ranges().at(0).start_coord;
auto waiter_core = CoreRangeSet(CoreRange(waiter_coord, waiter_coord));
Expand All @@ -45,7 +45,7 @@ inline std::tuple<Program, Program, Program, std::shared_ptr<GlobalSemaphore>> c
waiter_core,
DataMovementConfig{.processor = DataMovementProcessor::RISCV_0, .noc = NOC::RISCV_0_default});
std::array<uint32_t, 4> waiter_rt_args = {
global_sem->address(), incrementer_cores.num_cores(), syncer_core_physical.x, syncer_core_physical.y};
global_sem.address(), incrementer_cores.num_cores(), syncer_core_physical.x, syncer_core_physical.y};
SetRuntimeArgs(waiter_program, waiter_kernel, waiter_core, waiter_rt_args);

Program syncer_program = CreateProgram();
Expand All @@ -54,7 +54,7 @@ inline std::tuple<Program, Program, Program, std::shared_ptr<GlobalSemaphore>> c
"tests/tt_metal/tt_metal/test_kernels/misc/sub_device/syncer.cpp",
syncer_core,
DataMovementConfig{.processor = DataMovementProcessor::RISCV_0, .noc = NOC::RISCV_0_default});
std::array<uint32_t, 1> syncer_rt_args = {global_sem->address()};
std::array<uint32_t, 1> syncer_rt_args = {global_sem.address()};
SetRuntimeArgs(syncer_program, syncer_kernel, syncer_core, syncer_rt_args);

Program incrementer_program = CreateProgram();
Expand All @@ -64,13 +64,13 @@ inline std::tuple<Program, Program, Program, std::shared_ptr<GlobalSemaphore>> c
incrementer_cores,
DataMovementConfig{.processor = DataMovementProcessor::RISCV_1, .noc = NOC::RISCV_1_default});
std::array<uint32_t, 3> incrementer_rt_args = {
global_sem->address(), waiter_core_physical.x, waiter_core_physical.y};
global_sem.address(), waiter_core_physical.x, waiter_core_physical.y};
SetRuntimeArgs(incrementer_program, incrementer_kernel, incrementer_cores, incrementer_rt_args);
return {
std::move(waiter_program), std::move(syncer_program), std::move(incrementer_program), std::move(global_sem)};
}

inline std::tuple<Program, Program, Program, std::shared_ptr<GlobalSemaphore>> create_basic_eth_sync_program(
inline std::tuple<Program, Program, Program, GlobalSemaphore> create_basic_eth_sync_program(
Device* device, const SubDevice& sub_device_1, const SubDevice& sub_device_2) {
auto waiter_coord = sub_device_2.cores(HalProgrammableCoreType::ACTIVE_ETH).ranges().at(0).start_coord;
auto waiter_core = CoreRangeSet(CoreRange(waiter_coord, waiter_coord));
Expand All @@ -92,7 +92,7 @@ inline std::tuple<Program, Program, Program, std::shared_ptr<GlobalSemaphore>> c
waiter_core,
EthernetConfig{.noc = NOC::RISCV_0_default, .processor = DataMovementProcessor::RISCV_0});
std::array<uint32_t, 7> waiter_rt_args = {
global_sem->address(),
global_sem.address(),
incrementer_cores.num_cores(),
syncer_core_physical.x,
syncer_core_physical.y,
Expand All @@ -107,7 +107,7 @@ inline std::tuple<Program, Program, Program, std::shared_ptr<GlobalSemaphore>> c
"tests/tt_metal/tt_metal/test_kernels/misc/sub_device/syncer.cpp",
syncer_core,
DataMovementConfig{.processor = DataMovementProcessor::RISCV_0, .noc = NOC::RISCV_0_default});
std::array<uint32_t, 1> syncer_rt_args = {global_sem->address()};
std::array<uint32_t, 1> syncer_rt_args = {global_sem.address()};
SetRuntimeArgs(syncer_program, syncer_kernel, syncer_core, syncer_rt_args);

Program incrementer_program = CreateProgram();
Expand All @@ -117,7 +117,7 @@ inline std::tuple<Program, Program, Program, std::shared_ptr<GlobalSemaphore>> c
incrementer_cores,
DataMovementConfig{.processor = DataMovementProcessor::RISCV_1, .noc = NOC::RISCV_1_default});
std::array<uint32_t, 3> incrementer_rt_args = {
global_sem->address(), tensix_waiter_core_physical.x, tensix_waiter_core_physical.y};
global_sem.address(), tensix_waiter_core_physical.x, tensix_waiter_core_physical.y};
SetRuntimeArgs(incrementer_program, incrementer_kernel, incrementer_cores, incrementer_rt_args);
return {
std::move(waiter_program), std::move(syncer_program), std::move(incrementer_program), std::move(global_sem)};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ void get_max_page_size_and_num_pages(
num_pages = total_size / page_size;
}

std::tuple<std::vector<tt_metal::Program>,std::shared_ptr<tt_metal::v1::experimental::GlobalCircularBuffer>>
std::tuple<std::vector<tt_metal::Program>, tt_metal::v1::experimental::GlobalCircularBuffer>
create_programs(
tt_metal::Device* device,
const CoreRangeSet& dram_reader_core,
Expand Down Expand Up @@ -146,7 +146,7 @@ create_programs(
tt_metal::CircularBufferConfig writer_cb_config = tt_metal::CircularBufferConfig(receiver_cb_size);
writer_cb_config.remote_index(writer_cb_index).set_page_size(single_tile_size).set_data_format(tile_format);
auto writer_cb =
tt_metal::v1::experimental::CreateCircularBuffer(sender_program, dram_reader_core, writer_cb_config, *global_cb);
tt_metal::v1::experimental::CreateCircularBuffer(sender_program, dram_reader_core, writer_cb_config, global_cb);

// mixed cb dataformat
uint32_t next_layer_num_blocks = num_blocks * 2;
Expand Down Expand Up @@ -178,7 +178,7 @@ create_programs(
tt_metal::CircularBufferConfig receiver_cb_config = tt_metal::CircularBufferConfig(receiver_cb_size);
receiver_cb_config.remote_index(receiver_cb_index).set_page_size(single_tile_size).set_data_format(tile_format);
auto receiver_cb = tt_metal::v1::experimental::CreateCircularBuffer(
receiver_program, l1_receiver_cores, receiver_cb_config, *global_cb);
receiver_program, l1_receiver_cores, receiver_cb_config, global_cb);

log_info("reader_cb_size: {}", reader_cb_size);
log_info("receiver_cb_size: {}", receiver_cb_size);
Expand Down Expand Up @@ -846,7 +846,7 @@ int main(int argc, char** argv) {
tt::DataFormat::Bfp8_b,
l1_receiver_core,
num_receivers,
global_cb->buffer_address());
global_cb.buffer_address());

} else {
// output
Expand All @@ -860,7 +860,7 @@ int main(int argc, char** argv) {
tt::DataFormat::Float16_b,
l1_receiver_core,
num_receivers,
global_cb->buffer_address());
global_cb.buffer_address());
}

////////////////////////////////////////////////////////////////////////////
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ std::tuple<uint32_t, uint32_t> get_out_subblock_params(
return {1, 1};
}

std::tuple<std::vector<tt_metal::Program>, std::shared_ptr<tt::tt_metal::v1::experimental::GlobalCircularBuffer>>
std::tuple<std::vector<tt_metal::Program>, ::tt_metal::v1::experimental::GlobalCircularBuffer>
create_programs(
tt_metal::Device* device,
const CoreRangeSet& dram_reader_core,
Expand Down Expand Up @@ -169,7 +169,7 @@ create_programs(
tt_metal::CircularBufferConfig in1_writer_cb_config = tt_metal::CircularBufferConfig(in1_receiver_cb_size);
in1_writer_cb_config.remote_index(in1_writer_cb_index).set_page_size(single_tile_size).set_data_format(tile_format);
auto writer_cb = tt_metal::v1::experimental::CreateCircularBuffer(
sender_program, dram_reader_core, in1_writer_cb_config, *global_cb);
sender_program, dram_reader_core, in1_writer_cb_config, global_cb);

// in0 reader CB
uint32_t in0_reader_cb_index = 0;
Expand All @@ -190,7 +190,7 @@ create_programs(
.set_data_format(tile_format);
in1_receiver_cb_config.index(in1_pusher_cb_index).set_page_size(single_tile_size).set_data_format(tile_format);
auto in1_receiver_cb = tt_metal::v1::experimental::CreateCircularBuffer(
receiver_program, l1_receiver_cores, in1_receiver_cb_config, *global_cb);
receiver_program, l1_receiver_cores, in1_receiver_cb_config, global_cb);

// output CB
uint32_t output_cb_index = 16;
Expand Down
8 changes: 4 additions & 4 deletions tt_metal/host_api.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,7 @@ uint32_t CreateSemaphore(
* Initializes a global semaphore on all cores within the specified CoreRangeSet.
* This only supports tensix cores, and can only use L1 buffer types like BufferType::L1 and BufferType::L1_SMALL.
*
* Return value: std::shared_ptr<GlobalSemaphore>
* Return value: GlobalSemaphore
*
* | Argument | Description | Type | Valid Range | Required |
* |----------------|--------------------------------------------------------|-----------------------------------------------------------|--------------|----------|
Expand All @@ -308,7 +308,7 @@ uint32_t CreateSemaphore(
* | sub_device_ids | Sub-device ids to wait on before writing the semaphore | tt::stl::Span<const SubDeviceId> | | No |
*/
// clang-format on
std::shared_ptr<GlobalSemaphore> CreateGlobalSemaphore(
GlobalSemaphore CreateGlobalSemaphore(
Device* device,
const CoreRangeSet& cores,
uint32_t initial_value,
Expand All @@ -320,7 +320,7 @@ std::shared_ptr<GlobalSemaphore> CreateGlobalSemaphore(
* Initializes a global semaphore on all cores within the specified CoreRangeSet.
* This only supports tensix cores, and can only use L1 buffer types like BufferType::L1 and BufferType::L1_SMALL.
*
* Return value: std::shared_ptr<GlobalSemaphore>
* Return value: GlobalSemaphore
*
* | Argument | Description | Type | Valid Range | Required |
* |----------------|--------------------------------------------------------|-----------------------------------------------------------|--------------|----------|
Expand All @@ -331,7 +331,7 @@ std::shared_ptr<GlobalSemaphore> CreateGlobalSemaphore(
* | sub_device_ids | Sub-device ids to wait on before writing the semaphore | tt::stl::Span<const SubDeviceId> | | No |
*/
// clang-format on
std::shared_ptr<GlobalSemaphore> CreateGlobalSemaphore(
GlobalSemaphore CreateGlobalSemaphore(
Device* device,
CoreRangeSet&& cores,
uint32_t initial_value,
Expand Down
13 changes: 1 addition & 12 deletions tt_metal/impl/buffers/global_circular_buffer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,7 @@ GlobalCircularBuffer::GlobalCircularBuffer(
const std::unordered_map<CoreCoord, CoreRangeSet>& sender_receiver_core_mapping,
uint32_t size,
BufferType buffer_type,
tt::stl::Span<const SubDeviceId> sub_device_ids,
Private) :
tt::stl::Span<const SubDeviceId> sub_device_ids) :
device_(device), sender_receiver_core_mapping_(sender_receiver_core_mapping), size_(size) {
TT_FATAL(this->device_ != nullptr, "Device cannot be null");
uint32_t num_sender_cores = sender_receiver_core_mapping.size();
Expand Down Expand Up @@ -148,16 +147,6 @@ void GlobalCircularBuffer::setup_cb_buffers(
});
}

std::shared_ptr<GlobalCircularBuffer> GlobalCircularBuffer::create(
Device* device,
const std::unordered_map<CoreCoord, CoreRangeSet>& sender_receiver_core_mapping,
uint32_t size,
BufferType buffer_type,
tt::stl::Span<const SubDeviceId> sub_device_ids) {
return std::make_shared<GlobalCircularBuffer>(
device, sender_receiver_core_mapping, size, buffer_type, sub_device_ids, Private());
}

const Buffer& GlobalCircularBuffer::cb_buffer() const { return *this->cb_buffer_; }

const CoreRangeSet& GlobalCircularBuffer::sender_cores() const { return this->sender_cores_; }
Expand Down
Loading

0 comments on commit 775b799

Please sign in to comment.