Skip to content

Commit

Permalink
#0: Update global semaphores and global circular buffers metal apis to be thread-safe instead of depending on ttnn apis
Browse files Browse the repository at this point in the history
  • Loading branch information
tt-aho committed Dec 11, 2024
1 parent 6b2d29d commit eb76f95
Show file tree
Hide file tree
Showing 9 changed files with 143 additions and 95 deletions.
5 changes: 3 additions & 2 deletions tests/tt_metal/tt_metal/api/test_global_semaphores.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ TEST_F(DispatchFixture, ResetGlobalSemaphores) {
for (auto device : devices_) {
{
uint32_t initial_value = 1;
uint32_t reset_value = 2;
std::vector<uint32_t> overwrite_value = {2};
auto global_semaphore = tt::tt_metal::CreateGlobalSemaphore(device, cores, initial_value);
auto address = global_semaphore->address();
Expand All @@ -104,14 +105,14 @@ TEST_F(DispatchFixture, ResetGlobalSemaphores) {

EXPECT_EQ(sem_vals[0], overwrite_value[0]);
}
global_semaphore->reset_semaphore_value();
global_semaphore->reset_semaphore_value(reset_value);
Synchronize(device);
for (const auto& core : cores_vec) {
auto sem_vals = tt::llrt::read_hex_vec_from_core(
device->id(), device->worker_core_from_logical_core(core), address, sizeof(uint32_t));
tt::llrt::write_hex_vec_to_core(
device->id(), device->worker_core_from_logical_core(core), overwrite_value, address);
EXPECT_EQ(sem_vals[0], initial_value);
EXPECT_EQ(sem_vals[0], reset_value);
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion tests/ttnn/unit_tests/test_global_semaphore.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def run_global_semaphore(device):

assert ttnn.get_global_semaphore_address(global_sem0) != ttnn.get_global_semaphore_address(global_sem1)

ttnn.reset_global_semaphore_value(global_sem0)
ttnn.reset_global_semaphore_value(global_sem0, 3)


@pytest.mark.parametrize("enable_async_mode", (False, True), indirect=True)
Expand Down
112 changes: 60 additions & 52 deletions tt_metal/impl/buffers/global_circular_buffer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@ GlobalCircularBuffer::GlobalCircularBuffer(
const std::unordered_map<CoreCoord, CoreRangeSet>& sender_receiver_core_mapping,
uint32_t size,
BufferType buffer_type,
tt::stl::Span<const SubDeviceId> sub_device_ids) :
tt::stl::Span<const SubDeviceId> sub_device_ids,
Private) :
device_(device), sender_receiver_core_mapping_(sender_receiver_core_mapping), size_(size) {
TT_FATAL(this->device_ != nullptr, "Device cannot be null");
uint32_t num_sender_cores = sender_receiver_core_mapping.size();
Expand Down Expand Up @@ -86,58 +87,65 @@ void GlobalCircularBuffer::setup_cb_buffers(
shard_parameters,
std::nullopt);

const auto& core_to_core_id = this->cb_config_buffer_->get_buffer_page_mapping()->core_to_core_id_;

std::vector<uint32_t> cb_config_host_buffer(cb_config_size / sizeof(uint32_t), 0);
uint32_t buffer_address = this->cb_buffer_->address();
uint32_t noc_xy_address = this->cb_config_buffer_->address() + num_config_elements * sizeof(uint32_t);
uint32_t pages_sent_address = align(noc_xy_address + num_noc_xy_words * sizeof(uint32_t), l1_alignment);

for (const auto& [sender_core, receiver_cores] : this->sender_receiver_core_mapping_) {
const auto& receiver_cores_vec = corerange_to_cores(receiver_cores);
uint32_t sender_idx = core_to_core_id.at(sender_core) * cb_config_page_size / sizeof(uint32_t);
uint32_t num_receivers = receiver_cores.num_cores();
uint32_t pages_acked_address = pages_sent_address + num_receivers * l1_alignment;
cb_config_host_buffer[sender_idx++] = 1;
cb_config_host_buffer[sender_idx++] = receiver_cores.num_cores();
cb_config_host_buffer[sender_idx++] = buffer_address;
cb_config_host_buffer[sender_idx++] = this->size_;
cb_config_host_buffer[sender_idx++] = buffer_address;
cb_config_host_buffer[sender_idx++] = noc_xy_address;
cb_config_host_buffer[sender_idx++] = pages_sent_address;

auto sender_physical_coord = this->device_->worker_core_from_logical_core(sender_core);
for (uint32_t i = 0; i < receiver_cores_vec.size(); i++) {
auto receiver_physical_coord = this->device_->worker_core_from_logical_core(receiver_cores_vec[i]);
cb_config_host_buffer[sender_idx++] = receiver_physical_coord.x;
cb_config_host_buffer[sender_idx++] = receiver_physical_coord.y;

uint32_t receiver_idx = core_to_core_id.at(receiver_cores_vec[i]) * cb_config_page_size / sizeof(uint32_t);
cb_config_host_buffer[receiver_idx++] = 0;
cb_config_host_buffer[receiver_idx++] = num_receivers;
cb_config_host_buffer[receiver_idx++] = buffer_address;
cb_config_host_buffer[receiver_idx++] = this->size_;
cb_config_host_buffer[receiver_idx++] = buffer_address;
cb_config_host_buffer[receiver_idx++] = noc_xy_address;
cb_config_host_buffer[receiver_idx++] = pages_sent_address + 2 * i * l1_alignment;
cb_config_host_buffer[receiver_idx++] = sender_physical_coord.x;
cb_config_host_buffer[receiver_idx++] = sender_physical_coord.y;
}
}

// Write the config buffer to the device
// Only block for the slow dispatch case
if (this->device_->using_slow_dispatch()) {
detail::WriteToBuffer(*this->cb_config_buffer_, cb_config_host_buffer);
tt::Cluster::instance().l1_barrier(this->device_->id());
} else {
EnqueueWriteBuffer(
this->device_->command_queue(),
this->cb_config_buffer_,
cb_config_host_buffer.data(),
false,
sub_device_ids);
}
auto device = this->device_;
device->push_work([device,
cb_config_size,
cb_config_page_size,
num_noc_xy_words,
l1_alignment,
buffer_address = this->cb_buffer_->address(),
cb_config_buffer = this->cb_config_buffer_,
size = this->size_,
sender_receiver_core_mapping = this->sender_receiver_core_mapping_,
sub_device_ids = std::vector<SubDeviceId>(sub_device_ids.begin(), sub_device_ids.end())] {
auto config_buffer_address = cb_config_buffer->address();
const auto& core_to_core_id = cb_config_buffer->get_buffer_page_mapping()->core_to_core_id_;
std::vector<uint32_t> cb_config_host_buffer(cb_config_size / sizeof(uint32_t), 0);
uint32_t noc_xy_address = config_buffer_address + num_config_elements * sizeof(uint32_t);
uint32_t pages_sent_address = align(noc_xy_address + num_noc_xy_words * sizeof(uint32_t), l1_alignment);

for (const auto& [sender_core, receiver_cores] : sender_receiver_core_mapping) {
const auto& receiver_cores_vec = corerange_to_cores(receiver_cores);
uint32_t sender_idx = core_to_core_id.at(sender_core) * cb_config_page_size / sizeof(uint32_t);
uint32_t num_receivers = receiver_cores.num_cores();
uint32_t pages_acked_address = pages_sent_address + num_receivers * l1_alignment;
cb_config_host_buffer[sender_idx++] = 1;
cb_config_host_buffer[sender_idx++] = receiver_cores.num_cores();
cb_config_host_buffer[sender_idx++] = buffer_address;
cb_config_host_buffer[sender_idx++] = size;
cb_config_host_buffer[sender_idx++] = buffer_address;
cb_config_host_buffer[sender_idx++] = noc_xy_address;
cb_config_host_buffer[sender_idx++] = pages_sent_address;

auto sender_physical_coord = device->worker_core_from_logical_core(sender_core);
for (uint32_t i = 0; i < receiver_cores_vec.size(); i++) {
auto receiver_physical_coord = device->worker_core_from_logical_core(receiver_cores_vec[i]);
cb_config_host_buffer[sender_idx++] = receiver_physical_coord.x;
cb_config_host_buffer[sender_idx++] = receiver_physical_coord.y;

uint32_t receiver_idx =
core_to_core_id.at(receiver_cores_vec[i]) * cb_config_page_size / sizeof(uint32_t);
cb_config_host_buffer[receiver_idx++] = 0;
cb_config_host_buffer[receiver_idx++] = num_receivers;
cb_config_host_buffer[receiver_idx++] = buffer_address;
cb_config_host_buffer[receiver_idx++] = size;
cb_config_host_buffer[receiver_idx++] = buffer_address;
cb_config_host_buffer[receiver_idx++] = noc_xy_address;
cb_config_host_buffer[receiver_idx++] = pages_sent_address + 2 * i * l1_alignment;
cb_config_host_buffer[receiver_idx++] = sender_physical_coord.x;
cb_config_host_buffer[receiver_idx++] = sender_physical_coord.y;
}
}
if (device->using_slow_dispatch()) {
detail::WriteToBuffer(*cb_config_buffer, cb_config_host_buffer);
tt::Cluster::instance().l1_barrier(device->id());
} else {
EnqueueWriteBuffer(
device->command_queue(), cb_config_buffer, cb_config_host_buffer.data(), false, sub_device_ids);
}
});
}

std::shared_ptr<GlobalCircularBuffer> GlobalCircularBuffer::create(
Expand All @@ -147,7 +155,7 @@ std::shared_ptr<GlobalCircularBuffer> GlobalCircularBuffer::create(
BufferType buffer_type,
tt::stl::Span<const SubDeviceId> sub_device_ids) {
return std::make_shared<GlobalCircularBuffer>(
device, sender_receiver_core_mapping, size, buffer_type, sub_device_ids);
device, sender_receiver_core_mapping, size, buffer_type, sub_device_ids, Private());
}

const Buffer& GlobalCircularBuffer::cb_buffer() const { return *this->cb_buffer_; }
Expand Down
7 changes: 6 additions & 1 deletion tt_metal/impl/buffers/global_circular_buffer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,18 @@ namespace v1 {
namespace experimental {

class GlobalCircularBuffer {
struct Private {
explicit Private() = default;
};

public:
GlobalCircularBuffer(
Device* device,
const std::unordered_map<CoreCoord, CoreRangeSet>& sender_receiver_core_mapping,
uint32_t size,
BufferType buffer_type,
tt::stl::Span<const SubDeviceId> sub_device_ids);
tt::stl::Span<const SubDeviceId> sub_device_ids,
Private);

GlobalCircularBuffer(const GlobalCircularBuffer&) = default;
GlobalCircularBuffer& operator=(const GlobalCircularBuffer&) = default;
Expand Down
48 changes: 29 additions & 19 deletions tt_metal/impl/buffers/global_semaphore.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,22 +24,25 @@ GlobalSemaphore::GlobalSemaphore(
const CoreRangeSet& cores,
uint32_t initial_value,
BufferType buffer_type,
tt::stl::Span<const SubDeviceId> sub_device_ids) :
device_(device), cores_(cores), initial_value_(initial_value) {
this->setup_buffer(buffer_type, sub_device_ids);
tt::stl::Span<const SubDeviceId> sub_device_ids,
Private) :
device_(device), cores_(cores) {
this->setup_buffer(initial_value, buffer_type, sub_device_ids);
}

GlobalSemaphore::GlobalSemaphore(
Device* device,
CoreRangeSet&& cores,
uint32_t initial_value,
BufferType buffer_type,
tt::stl::Span<const SubDeviceId> sub_device_ids) :
device_(device), cores_(std::move(cores)), initial_value_(initial_value) {
this->setup_buffer(buffer_type, sub_device_ids);
tt::stl::Span<const SubDeviceId> sub_device_ids,
Private) :
device_(device), cores_(std::move(cores)) {
this->setup_buffer(initial_value, buffer_type, sub_device_ids);
}

void GlobalSemaphore::setup_buffer(BufferType buffer_type, tt::stl::Span<const SubDeviceId> sub_device_ids) {
void GlobalSemaphore::setup_buffer(
uint32_t initial_value, BufferType buffer_type, tt::stl::Span<const SubDeviceId> sub_device_ids) {
TT_FATAL(
buffer_type == BufferType::L1 or buffer_type == BufferType::L1_SMALL,
"Global semaphore can only be created for L1 buffer types");
Expand All @@ -58,8 +61,7 @@ void GlobalSemaphore::setup_buffer(BufferType buffer_type, tt::stl::Span<const S
shard_parameters,
std::nullopt);

this->host_buffer_ = std::vector<uint32_t>(num_cores, this->initial_value_);
this->reset_semaphore_value(sub_device_ids);
this->reset_semaphore_value(initial_value, sub_device_ids);
}

std::shared_ptr<GlobalSemaphore> GlobalSemaphore::create(
Expand All @@ -68,31 +70,39 @@ std::shared_ptr<GlobalSemaphore> GlobalSemaphore::create(
uint32_t initial_value,
BufferType buffer_type,
tt::stl::Span<const SubDeviceId> sub_device_ids) {
return std::make_shared<GlobalSemaphore>(device, cores, initial_value, buffer_type, sub_device_ids);
return std::make_shared<GlobalSemaphore>(device, cores, initial_value, buffer_type, sub_device_ids, Private());
}
std::shared_ptr<GlobalSemaphore> GlobalSemaphore::create(
Device* device,
CoreRangeSet&& cores,
uint32_t initial_value,
BufferType buffer_type,
tt::stl::Span<const SubDeviceId> sub_device_ids) {
return std::make_shared<GlobalSemaphore>(device, std::move(cores), initial_value, buffer_type, sub_device_ids);
return std::make_shared<GlobalSemaphore>(
device, std::move(cores), initial_value, buffer_type, sub_device_ids, Private());
}

Device* GlobalSemaphore::device() const { return device_; }

DeviceAddr GlobalSemaphore::address() const { return buffer_->address(); }

void GlobalSemaphore::reset_semaphore_value(tt::stl::Span<const SubDeviceId> sub_device_ids) {
void GlobalSemaphore::reset_semaphore_value(uint32_t reset_value, tt::stl::Span<const SubDeviceId> sub_device_ids) {
// Write the requested reset value to the semaphore on the device
// Only block for the slow dispatch case
if (this->device_->using_slow_dispatch()) {
detail::WriteToBuffer(*this->buffer_, this->host_buffer_);
tt::Cluster::instance().l1_barrier(this->device_->id());
} else {
EnqueueWriteBuffer(
this->device_->command_queue(), this->buffer_, this->host_buffer_.data(), false, sub_device_ids);
}
auto* device = this->device_;
device->push_work([device,
reset_value,
sub_device_ids = std::vector<SubDeviceId>(sub_device_ids.begin(), sub_device_ids.end()),
num_cores = this->cores_.num_cores(),
buffer = this->buffer_] {
std::vector<uint32_t> host_buffer(num_cores, reset_value);
if (device->using_slow_dispatch()) {
detail::WriteToBuffer(*buffer, host_buffer);
tt::Cluster::instance().l1_barrier(device->id());
} else {
EnqueueWriteBuffer(device->command_queue(), buffer, host_buffer, false, sub_device_ids);
}
});
}

} // namespace tt::tt_metal
Expand Down
24 changes: 14 additions & 10 deletions tt_metal/impl/buffers/global_semaphore.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,20 +20,26 @@ class Buffer;
class Device;

class GlobalSemaphore {
struct Private {
explicit Private() = default;
};

public:
GlobalSemaphore(
Device* device,
const CoreRangeSet& cores,
uint32_t initial_value,
BufferType buffer_type = BufferType::L1,
tt::stl::Span<const SubDeviceId> sub_device_ids = {});
BufferType buffer_type,
tt::stl::Span<const SubDeviceId> sub_device_ids,
Private);

GlobalSemaphore(
Device* device,
CoreRangeSet&& cores,
uint32_t initial_value,
BufferType buffer_type = BufferType::L1,
tt::stl::Span<const SubDeviceId> sub_device_ids = {});
BufferType buffer_type,
tt::stl::Span<const SubDeviceId> sub_device_ids,
Private);

GlobalSemaphore(const GlobalSemaphore&) = default;
GlobalSemaphore& operator=(const GlobalSemaphore&) = default;
Expand All @@ -59,21 +65,19 @@ class GlobalSemaphore {

DeviceAddr address() const;

void reset_semaphore_value(tt::stl::Span<const SubDeviceId> sub_device_ids = {});
void reset_semaphore_value(uint32_t reset_value, tt::stl::Span<const SubDeviceId> sub_device_ids = {});

static constexpr auto attribute_names = std::forward_as_tuple("cores", "initial_value");
const auto attribute_values() const { return std::make_tuple(this->cores_, this->initial_value_); }
static constexpr auto attribute_names = std::forward_as_tuple("cores");
const auto attribute_values() const { return std::make_tuple(this->cores_); }

private:
void setup_buffer(BufferType buffer_type, tt::stl::Span<const SubDeviceId> sub_device_ids);
void setup_buffer(uint32_t initial_value, BufferType buffer_type, tt::stl::Span<const SubDeviceId> sub_device_ids);

// GlobalSemaphore is implemented as a wrapper around a sharded buffer
// This can be updated in the future to be its own container with optimized dispatch functions
std::shared_ptr<Buffer> buffer_;
std::vector<uint32_t> host_buffer_;
Device* device_;
CoreRangeSet cores_;
uint32_t initial_value_ = 0;
};

} // namespace v0
Expand Down
18 changes: 14 additions & 4 deletions ttnn/cpp/pybind11/global_semaphore.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,15 +57,20 @@ void py_module(py::module& module) {

module.def(
"reset_global_semaphore_value",
py::overload_cast<const std::shared_ptr<GlobalSemaphore>&, const std::vector<SubDeviceId>&>(
&reset_global_semaphore_value),
[](const std::shared_ptr<GlobalSemaphore>& global_semaphore,
uint32_t reset_value,
const std::vector<SubDeviceId>& sub_device_ids) {
ttnn::global_semaphore::reset_global_semaphore_value(global_semaphore, reset_value, sub_device_ids);
},
py::arg("global_semaphore"),
py::arg("reset_value"),
py::arg("sub_device_ids") = std::vector<SubDeviceId>(),
R"doc(
Reset the value of the global semaphore.
Args:
global_semaphore (GlobalSemaphore): The global semaphore object.
reset_value (int): The value to reset the global semaphore to.
sub_device_ids (List[ttnn.SubDeviceIds]): Sub-device IDs to wait on before writing the global semaphore value to device.
Defaults to waiting on all sub-devices.
)doc");
Expand Down Expand Up @@ -111,15 +116,20 @@ void py_module(py::module& module) {

module.def(
"reset_global_semaphore_value",
py::overload_cast<const MultiDeviceGlobalSemaphore&, const std::vector<SubDeviceId>&>(
&reset_global_semaphore_value),
[](const MultiDeviceGlobalSemaphore& global_semaphore,
uint32_t reset_value,
const std::vector<SubDeviceId>& sub_device_ids) {
ttnn::global_semaphore::reset_global_semaphore_value(global_semaphore, reset_value, sub_device_ids);
},
py::arg("global_semaphore"),
py::arg("reset_value"),
py::arg("sub_device_ids") = std::vector<SubDeviceId>(),
R"doc(
Reset the value of the global semaphore.
Args:
global_semaphore (GlobalSemaphore): The global semaphore object.
reset_value (int): The value to reset the global semaphore to.
sub_device_ids (List[ttnn.SubDeviceIds]): Sub-device IDs to wait on before writing the global semaphore value to device.
)doc");
}
Expand Down
Loading

0 comments on commit eb76f95

Please sign in to comment.