Skip to content

Commit

Permalink
#10332: Make ttnn::event_synchronize block only in the app thread (#11543)
Browse files Browse the repository at this point in the history

- Add multi-threaded event sync test
  • Loading branch information
tt-asaigal authored Aug 23, 2024
1 parent fc937bf commit 14dabb4
Show file tree
Hide file tree
Showing 4 changed files with 67 additions and 9 deletions.
4 changes: 2 additions & 2 deletions tests/ttnn/unit_tests/gtests/test_async_runtime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,12 +65,12 @@ TEST_F(MultiCommandQueueSingleDeviceFixture, TestAsyncPreallocatedOutputs) {
// Record the completion of the write event
ttnn::record_event(device->command_queue(io_cq), write_event);
// Host stalls until write is completed, before sending workload
ttnn::event_synchronize(device, write_event);
ttnn::event_synchronize(write_event);
// Dispatch workload. Preallocated output_tensor is populated by op/
ttnn::moreh_sum(workload_dispatch_cq, input_tensor, /*dim*/3, false, output_tensor);
// Record completion of workload
ttnn::record_event(device->command_queue(workload_dispatch_cq), workload_event);
ttnn::event_synchronize(device, workload_event);
ttnn::event_synchronize(workload_event);
// Read output back, once workload is complete
ttnn::read_buffer(io_cq, output_tensor, {readback_data});
// Ensure that reference count book keeping is done correctly
Expand Down
63 changes: 63 additions & 0 deletions tests/ttnn/unit_tests/gtests/test_multiprod_queue.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include "common/bfloat16.hpp"
#include "ttnn/async_runtime.hpp"
#include "tt_numpy/functions.hpp"
#include "tt_metal/impl/event/event.hpp"
#include <cmath>
#include <thread>

Expand Down Expand Up @@ -90,3 +91,65 @@ TEST_F(MultiCommandQueueSingleDeviceFixture, TestMultiProducerLockBasedQueue) {
t1.join();

}

TEST_F(MultiCommandQueueSingleDeviceFixture, TestMultiAppThreadSync) {
    // Verify that the event_synchronize API stalls the calling thread until
    // the device records the event being polled.
    // Thread 0 = writer thread. Thread 1 = reader thread.
    // Reader cannot read until writer has correctly updated a memory location.
    // Writer cannot update location until reader has picked up data.
    // Use write_event to stall reader and read_event to stall writer.
    Device* device = this->device_;
    // Enable async engine and set queue setting to lock_based, since two app
    // threads push work to the same device worker concurrently.
    device->set_worker_mode(WorkExecutorMode::ASYNCHRONOUS);
    device->set_worker_queue_mode(WorkerQueueMode::LOCKBASED);

    // Single interleaved DRAM buffer shared by the writer and the reader.
    MemoryConfig mem_cfg = MemoryConfig{
        .memory_layout = tt::tt_metal::TensorMemoryLayout::INTERLEAVED,
        .buffer_type = BufferType::DRAM,
        .shard_spec = std::nullopt
    };
    // NOTE(review): write_cq and read_cq are both 0, so both threads issue
    // work through the same command queue — confirm this is intended rather
    // than exercising two CQs.
    uint32_t write_cq = 0;
    uint32_t read_cq = 0;
    uint32_t tensor_buf_size = 1024 * 1024;  // element count (bfloat16 datums)
    uint32_t datum_size_bytes = 2;           // sizeof(bfloat16)

    // Handshake events. Each thread waits on the event the other thread
    // records, then replaces its own event object for the next iteration.
    std::shared_ptr<Event> write_event = std::make_shared<Event>();
    std::shared_ptr<Event> read_event = std::make_shared<Event>();

    ttnn::Shape tensor_shape = ttnn::Shape(Shape({1, 1, 1024, 1024}));
    auto host_data = std::shared_ptr<bfloat16 []>(new bfloat16[tensor_buf_size]);
    auto allocated_buffer = ttnn::allocate_buffer_on_device(tensor_buf_size * datum_size_bytes, device, tensor_shape, DataType::BFLOAT16, Layout::TILE, mem_cfg);
    auto allocated_storage = tt::tt_metal::DeviceStorage{allocated_buffer};
    auto allocated_tensor = Tensor(allocated_storage, tensor_shape, DataType::BFLOAT16, Layout::TILE);
    auto readback_data = std::shared_ptr<bfloat16 []>(new bfloat16[tensor_buf_size]);

    // Writer: waits until the reader has consumed the previous iteration's
    // data before overwriting host_data and the device buffer.
    std::thread t0([&] () {
        for (int j = 0; j < 1000; j++) {
            if (j != 0) {
                // No read has been issued yet on the first iteration, so only
                // stall on read_event from iteration 1 onward.
                ttnn::event_synchronize(read_event);
            }
            read_event = std::make_shared<Event>();
            // Stamp the buffer with a per-iteration value (2 + j) so the
            // reader can tell which iteration's write it observed.
            for (int i = 0; i < tensor_buf_size; i++) {
                host_data[i] = bfloat16(static_cast<float>(2 + j));
            }
            ttnn::write_buffer(write_cq, allocated_tensor, {host_data});
            // Signal the reader that this iteration's write is in flight.
            ttnn::record_event(device->command_queue(write_cq), write_event);
        }
    });

    // Reader: waits for the writer's data, reads it back, validates it
    // against host_data, then releases the writer via read_event.
    std::thread t1([&] () {
        for (int j = 0; j < 1000; j++) {
            ttnn::event_synchronize(write_event);
            write_event = std::make_shared<Event>();
            ttnn::read_buffer(read_cq, allocated_tensor, {readback_data});
            // host_data is only rewritten after read_event is recorded below,
            // so comparing against it here is safe under the handshake.
            for (int i = 0; i < tensor_buf_size; i++) {
                EXPECT_EQ(readback_data[i], host_data[i]);
            }
            ttnn::record_event(device->command_queue(read_cq), read_event);
        }
    });

    t0.join();
    t1.join();
}
7 changes: 1 addition & 6 deletions ttnn/cpp/ttnn/async_runtime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -115,12 +115,7 @@ void queue_synchronize(CommandQueue& cq) {
Finish(cq);
}

void event_synchronize(Device* device, std::shared_ptr<Event> event) {
device->push_work([event] () {
EventSynchronize(event);
});
device->synchronize();
}
// Stall the calling application thread until `event` has been completed on
// the device. Unlike the previous implementation, this does not round-trip
// through the device worker thread — the wait happens inline in the caller.
void event_synchronize(std::shared_ptr<Event> event) {
    EventSynchronize(event);
}

bool event_query(std::shared_ptr<Event> event) { return EventQuery(event); }

Expand Down
2 changes: 1 addition & 1 deletion ttnn/cpp/ttnn/async_runtime.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ namespace ttnn {

void queue_synchronize(CommandQueue& cq);

void event_synchronize(Device* device, std::shared_ptr<Event> event);
void event_synchronize(std::shared_ptr<Event> event);

bool event_query(std::shared_ptr<Event> event);

Expand Down

0 comments on commit 14dabb4

Please sign in to comment.