#0: Expose Asynchronous Runtime APIs through TTNN
  - Add multithreaded data movement, memory allocation, workload
    dispatch, and synchronization APIs to ttnn::async_runtime (a usage
    sketch follows this list)
  - Expose a ttnn::DeviceBuffer object (inheriting from
    tt_metal::DeviceBuffer and using the async engine for memory
    management). This object can be used directly to create
    DeviceStorage and device tensors
  - Add support for optional_output_tensors in launch_op (see the
    signature sketch after the bcast hunk below)
  - Add a TTNN C++ test suite exercising these APIs
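
For orientation, the call flow these APIs enable looks roughly like the following. This is a condensed, illustrative sketch assembled from the new test added in this commit (tests/ttnn/unit_tests/test_async_runtime.cpp); device, mem_cfg, shape, buf_size_bytes, host_data, readback_data, and op stand in for the values the test sets up, and error handling is omitted.

// Condensed sketch based on test_async_runtime.cpp below; not a standalone program.
device->set_worker_mode(WorkExecutorMode::ASYNCHRONOUS);   // enable the async engine

uint32_t io_cq = 1;                 // command queue used for host<->device data movement
uint32_t workload_dispatch_cq = 0;  // command queue used for workload dispatch

// Asynchronous memory allocation; the returned buffer can back DeviceStorage directly.
auto buffer = ttnn::allocate_buffer_on_device(
    buf_size_bytes, device, shape, DataType::BFLOAT16, Layout::TILE, mem_cfg);
Tensor input = Tensor(tt::tt_metal::DeviceStorage{buffer}, shape, DataType::BFLOAT16, Layout::TILE);

// Multithreaded data movement plus host/device synchronization via events.
auto write_event = std::make_shared<Event>();
ttnn::write_buffer(io_cq, input, {host_data});
ttnn::record_event(device->command_queue(io_cq), write_event);
ttnn::event_synchronize(write_event);

// Workload dispatch on one command queue, readback on the other.
Tensor output = ttnn::run_operation(workload_dispatch_cq, op, {input}).at(0);
ttnn::read_buffer(io_cq, output, {readback_data});
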
tt-asaigal committed May 9, 2024
1 parent ffc0f10 commit f3cdb4a
Showing 45 changed files with 543 additions and 128 deletions.
1 change: 1 addition & 0 deletions .github/workflows/ttnn-post-commit.yaml
@@ -21,6 +21,7 @@ jobs:
test-group: [
{name: ttnn group 1, cmd: pytest $TT_METAL_HOME/tests/ttnn/unit_tests -v --splits 2 --group 1},
{name: ttnn group 2, cmd: pytest $TT_METAL_HOME/tests/ttnn/unit_tests -v --splits 2 --group 2},
{name: ttnn cpp tests, cmd: ./build/test/ttnn/unit_tests},

]
name: ${{ matrix.test-group.name }} ${{ matrix.runner-info.arch }} ${{ matrix.runner-info.name }}
144 changes: 144 additions & 0 deletions tests/ttnn/unit_tests/test_async_runtime.cpp
@@ -0,0 +1,144 @@
// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#include "tensor/tensor.hpp"
#include "ttnn_multi_command_queue_fixture.hpp"
#include "tt_dnn/op_library/eltwise_binary/eltwise_binary_op.hpp"
#include "tt_dnn/op_library/moreh_sum/moreh_sum_op.hpp"
#include "common/bfloat16.hpp"
#include "ttnn/cpp/ttnn/async_runtime.hpp"
#include "tt_numpy/functions.hpp"
#include <cmath>

using namespace tt;
using namespace tt_metal;
using MultiCommandQueueSingleDeviceFixture = ttnn::MultiCommandQueueSingleDeviceFixture;
using namespace constants;

TEST_F(MultiCommandQueueSingleDeviceFixture, TestAsyncPreallocatedOutputs) {
Device* device = this->device_;
MemoryConfig mem_cfg = MemoryConfig{
.memory_layout = tt::tt_metal::TensorMemoryLayout::INTERLEAVED,
.buffer_type = BufferType::DRAM,
.shard_spec = std::nullopt};

uint32_t input_buf_size_datums = 1024 * 1024;
uint32_t output_buf_size_datums = 1024 * 32;
uint32_t datum_size_bytes = 2;
uint32_t io_cq = 1; // Data reads and writes done through CQ1
uint32_t workload_dispatch_cq = 0; // Workload dispatched through CQ0

ttnn::Shape input_shape = ttnn::Shape(Shape({1, 1, 1024, 1024}));
auto host_data = std::shared_ptr<bfloat16 []>(new bfloat16[input_buf_size_datums]);
auto readback_data = std::shared_ptr<bfloat16 []>(new bfloat16[output_buf_size_datums]);


for (int i = 0; i < input_buf_size_datums; i++) {
host_data[i] = bfloat16(static_cast<float>(1));
}
// Create golden data using tt_eager APIs
Tensor np_tensor = tt::numpy::full<float>(input_shape.value(), static_cast<float>(1), DataType::BFLOAT16).to(Layout::TILE).to(device);
std::vector<int64_t> reduce_dims = {3};
Tensor np_out = tt::operations::primary::moreh_sum(np_tensor, reduce_dims);
Tensor np_out_host = np_out.cpu();
const bfloat16* golden_output = std::get<owned_buffer::Buffer<bfloat16>>(std::get<OwnedStorage>(np_out_host.get_storage()).buffer).begin();
// Enable Asynchronous Execution and test ttnn runtime APIs
device->set_worker_mode(WorkExecutorMode::ASYNCHRONOUS);
// Events for host - device synchronization
auto write_event = std::make_shared<Event>();
auto workload_event = std::make_shared<Event>();
// Running sum-reduce with preallocated output
auto op = tt::operations::primary::MorehSum{.dim = 3};
// Preallocate Input and Output Tensors on Device
auto input_buffer = ttnn::allocate_buffer_on_device(input_buf_size_datums * datum_size_bytes, device, input_shape, DataType::BFLOAT16, Layout::TILE, mem_cfg);
auto output_buffer = ttnn::allocate_buffer_on_device(output_buf_size_datums * datum_size_bytes, device, np_out.get_shape(), DataType::BFLOAT16, Layout::TILE, mem_cfg);
auto input_storage = tt::tt_metal::DeviceStorage{input_buffer};
auto output_storage = tt::tt_metal::DeviceStorage{output_buffer};
Tensor input_tensor = Tensor(input_storage, input_shape, DataType::BFLOAT16, Layout::TILE);
Tensor output_tensor = Tensor(output_storage, np_out.get_shape(), DataType::BFLOAT16, Layout::TILE);
// Populate input_tensor with data
ttnn::write_buffer(io_cq, input_tensor, {host_data});
// Record the completion of the write event
ttnn::record_event(device->command_queue(io_cq), write_event);
// Host stalls until write is completed, before sending workload
ttnn::event_synchronize(write_event);
// Dispatch workload. Preallocated output_tensor is populated by the op.
ttnn::run_operation(workload_dispatch_cq, op, {input_tensor}, {}, {output_tensor}).at(0);
// Record completion of workload
ttnn::record_event(device->command_queue(workload_dispatch_cq), workload_event);
ttnn::event_synchronize(workload_event);
// Read output back, once workload is complete
ttnn::read_buffer(io_cq, output_tensor, {readback_data});
// Ensure that reference count book keeping is done correctly
// Tensors only have one reference in the main thread. Ensure this is true.
EXPECT_EQ(input_tensor.tensor_attributes->main_thread_ref_count, 1);
EXPECT_EQ(output_tensor.tensor_attributes->main_thread_ref_count, 1);
// Buffers are currently jointly owned by the original buffer object, the storage object and the tensor (3).
EXPECT_EQ(input_buffer.use_count(), 3);
EXPECT_EQ(output_buffer.use_count(), 3);
// Deallocate tensors (tensor gives up buffer). Done asynchronously, so sync on queue after.
input_tensor.deallocate();
output_tensor.deallocate();
ttnn::queue_synchronize(device->command_queue(io_cq));
// Buffer only has 2 owners in main thread.
EXPECT_EQ(input_buffer.use_count(), 2);
EXPECT_EQ(output_buffer.use_count(), 2);
for (int i = 0; i < output_buf_size_datums; i++) {
EXPECT_EQ(readback_data[i], golden_output[i]);
}
}

TEST_F(MultiCommandQueueSingleDeviceFixture, TestAsyncRuntimeAllocatedBuffers) {
Device* device = this->device_;
device->set_worker_mode(WorkExecutorMode::ASYNCHRONOUS);
MemoryConfig mem_cfg = MemoryConfig{
.memory_layout = tt::tt_metal::TensorMemoryLayout::INTERLEAVED,
.buffer_type = BufferType::DRAM,
.shard_spec = std::nullopt};

uint32_t buf_size_datums = 1024 * 1024;
uint32_t datum_size_bytes = 2;
std::vector<uint32_t> inputs = {4, 9, 16, 25, 36, 64};
uint32_t io_cq = 1;
uint32_t workload_dispatch_cq = 0;
ttnn::Shape shape = ttnn::Shape(Shape({1, 1, 1024, 1024}));

auto host_data = std::shared_ptr<bfloat16 []>(new bfloat16[buf_size_datums]);
auto readback_data = std::shared_ptr<bfloat16 []>(new bfloat16[buf_size_datums]);
for (int loop = 0; loop < 10; loop++) {
log_info(LogTest, "Running outer loop {}", loop);
for (auto input_val : inputs) {
for (int i = 0; i < buf_size_datums; i++) {
host_data[i] = bfloat16(static_cast<float>(input_val));
}

auto write_event = std::make_shared<Event>();
auto workload_event = std::make_shared<Event>();
auto input_buffer = ttnn::allocate_buffer_on_device(buf_size_datums * datum_size_bytes, device, shape, DataType::BFLOAT16, Layout::TILE, mem_cfg);
auto input_storage = tt::tt_metal::DeviceStorage{input_buffer};
Tensor input_tensor = Tensor(input_storage, shape, DataType::BFLOAT16, Layout::TILE);
ttnn::write_buffer(io_cq, input_tensor, {host_data}); // Write using cq 1
ttnn::record_event(device->command_queue(io_cq), write_event); // Record write on cq 1
// Wait until cq 1 write is complete
ttnn::event_synchronize(write_event);
auto op0 = tt::tt_metal::EltwiseUnary{std::vector{tt::tt_metal::UnaryWithParam{tt::tt_metal::UnaryOpType::SQRT}}};
auto op1 = tt::tt_metal::EltwiseUnary{std::vector{tt::tt_metal::UnaryWithParam{tt::tt_metal::UnaryOpType::NEG}}};
// Run operation on cq 0
Tensor output_tensor = ttnn::run_operation(workload_dispatch_cq, op0, {input_tensor}).at(0);
auto dummy_buffer_0 = ttnn::allocate_buffer_on_device(buf_size_datums * datum_size_bytes, device, shape, DataType::BFLOAT16, Layout::TILE, mem_cfg);
output_tensor = ttnn::run_operation(workload_dispatch_cq, op1, {output_tensor}).at(0);
// Allocate this buffer to stress test async allocation across op execution and explicit allocation
auto dummy_buffer_1 = ttnn::allocate_buffer_on_device(buf_size_datums * datum_size_bytes, device, shape, DataType::BFLOAT16, Layout::TILE, mem_cfg);
// Record cq 0 prog execution
ttnn::record_event(device->command_queue(workload_dispatch_cq), workload_event);
// Wait until cq 0 prog execution is done
ttnn::event_synchronize(workload_event);
// Read using cq 1
ttnn::read_buffer(io_cq, output_tensor, {readback_data});
for (int i = 0; i < buf_size_datums; i++) {
EXPECT_EQ(static_cast<int>(floor(bfloat16(readback_data[i]).to_float())), static_cast<int>(-1 * sqrt(input_val)));
}
}
}
}
38 changes: 38 additions & 0 deletions tests/ttnn/unit_tests/ttnn_multi_command_queue_fixture.hpp
@@ -0,0 +1,38 @@
// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#include "gtest/gtest.h"
#include "tt_metal/host_api.hpp"
#include "tt_metal/test_utils/env_vars.hpp"
#include "tt_metal/impl/dispatch/command_queue.hpp"
#include "tt_metal/llrt/rtoptions.hpp"

namespace ttnn {

class MultiCommandQueueSingleDeviceFixture : public ::testing::Test {
protected:
void SetUp() override {
auto slow_dispatch = getenv("TT_METAL_SLOW_DISPATCH_MODE");
arch_ = tt::get_arch_from_string(tt::test_utils::get_env_arch_name());
num_devices_ = tt::tt_metal::GetNumAvailableDevices();
if (slow_dispatch) {
GTEST_SKIP() << "Skipping Multi CQ test suite, since it can only be run in Fast Dispatch Mode.";
}

if (arch_ == tt::ARCH::WORMHOLE_B0 and num_devices_ != 1) {
device_ = tt::tt_metal::CreateDevice(0); // Create device here so teardown can gracefully run
GTEST_SKIP() << "Skipping for Multi-Chip Wormhole, since not enough dispatch cores.";
}
device_ = tt::tt_metal::CreateDevice(0, 2);
}

void TearDown() override {
tt::tt_metal::CloseDevice(device_);
}

tt::tt_metal::Device* device_;
tt::ARCH arch_;
size_t num_devices_;
};
}
8 changes: 2 additions & 6 deletions tt_eager/tensor/tensor.cpp
@@ -853,9 +853,7 @@ void* get_raw_host_data_ptr(const Tensor& tensor) {
}

void memcpy(CommandQueue& queue, void* dst, const Tensor& src, const std::optional<std::size_t> transfer_size) {
if (not transfer_size.has_value()) {
TT_ASSERT("transfer_size is not supported for memcpy right now!");
}
TT_ASSERT(not transfer_size.has_value(), "transfer_size is not supported for memcpy right now!");
if (not is_device_tensor(src)) {
TT_THROW("memcpy: src tensor must be on device");
}
@@ -872,9 +870,7 @@ void memcpy(void* dst, const Tensor& src, const std::optional<std::size_t> trans
}

void memcpy(CommandQueue& queue, Tensor& dst, const void* src, const std::optional<std::size_t> transfer_size) {
if (not transfer_size.has_value()) {
TT_ASSERT("transfer_size is not supported for memcpy right now!");
}
TT_ASSERT(not transfer_size.has_value(), "transfer_size is not supported for memcpy right now!");
if (not is_device_tensor(dst)) {
TT_THROW("memcpy: memcpy to non-device tensor is not supported!");
}
35 changes: 19 additions & 16 deletions tt_eager/tensor/tensor_impl.cpp
@@ -101,22 +101,7 @@ std::array<uint32_t, 2> get_sharded_page_shape(Layout layout, DataType dtype, s
return page_shape;
}

namespace detail {

DeviceBuffer allocate_interleaved_buffer_on_device(uint32_t buffer_size_bytes, Device *device, const Shape& shape, DataType data_type, Layout layout, const MemoryConfig& memory_config) {
uint32_t page_size = get_page_size(data_type, layout, buffer_size_bytes, shape);
return std::make_shared<Buffer>(device, buffer_size_bytes, page_size, memory_config.buffer_type);
}

DeviceBuffer allocate_contiguous_buffer_on_device(uint32_t buffer_size_bytes, Device *device, const MemoryConfig& memory_config) {
return std::make_shared<Buffer>(device, buffer_size_bytes, buffer_size_bytes, memory_config.buffer_type);
}


DeviceBuffer allocate_sharded_buffer_on_device(uint32_t buffer_size_bytes, Device *device,
const Shape& shape, DataType data_type, Layout layout,
std::optional<ShardSpecBuffer> shard_params,
const MemoryConfig& memory_config) {
void validate_sharded_buffer_allocation(const Shape& shape, Layout layout, std::optional<ShardSpecBuffer> shard_params, const MemoryConfig& memory_config) {
TT_ASSERT(shard_params.has_value(), "Shard params are required for sharded buffer and they were not initialized");

auto shard_spec = memory_config.shard_spec.value();
@@ -158,7 +143,25 @@ DeviceBuffer allocate_sharded_buffer_on_device(uint32_t buffer_size_bytes, Devic
// Require alignment for now
// TT_ASSERT(shard_shape[1] * tensor_impl::element_size_bytes_wrapper(data_type) % ADDRESS_ALIGNMENT == 0);
}
}

namespace detail {

DeviceBuffer allocate_interleaved_buffer_on_device(uint32_t buffer_size_bytes, Device *device, const Shape& shape, DataType data_type, Layout layout, const MemoryConfig& memory_config) {
uint32_t page_size = get_page_size(data_type, layout, buffer_size_bytes, shape);
return std::make_shared<Buffer>(device, buffer_size_bytes, page_size, memory_config.buffer_type);
}

DeviceBuffer allocate_contiguous_buffer_on_device(uint32_t buffer_size_bytes, Device *device, const MemoryConfig& memory_config) {
return std::make_shared<Buffer>(device, buffer_size_bytes, buffer_size_bytes, memory_config.buffer_type);
}


DeviceBuffer allocate_sharded_buffer_on_device(uint32_t buffer_size_bytes, Device *device,
const Shape& shape, DataType data_type, Layout layout,
std::optional<ShardSpecBuffer> shard_params,
const MemoryConfig& memory_config) {
validate_sharded_buffer_allocation(shape, layout, shard_params, memory_config);
auto page_shape = shard_params.value().page_shape;
uint32_t size_of_element = element_size_bytes_wrapper(data_type);
uint32_t page_size = page_shape[0] * page_shape[1] * size_of_element;
5 changes: 4 additions & 1 deletion tt_eager/tensor/tensor_impl.hpp
@@ -210,7 +210,7 @@ inline std::vector<T> convert_layout_tile_to_row_major(const Shape& shape, const
// Validators
// ======================================================================================
void validate_on_device_dtype_and_layout(Device* device, const Shape& shape, DataType dtype, Layout layout);

void validate_sharded_buffer_allocation(const Shape& shape, Layout layout, std::optional<ShardSpecBuffer> shard_params, const MemoryConfig& memory_config);
// -----------------------------------------------------------------------------------------------------------------------------------------------
// ===============================================================================================================================================
// High Level APIs
@@ -220,6 +220,9 @@ void validate_on_device_dtype_and_layout(Device* device, const Shape& shape, Dat
// ======================================================================================
// Data reader, writer, and initializers
// ======================================================================================

uint32_t get_page_size(DataType dtype, Layout layout, uint32_t total_size_bytes, const Shape& shape);

DeviceBuffer allocate_buffer_on_device(
uint32_t buffer_size_bytes,
Device* device,
2 changes: 1 addition & 1 deletion tt_eager/tt_dnn/op_library/all_gather/all_gather_op.cpp
@@ -104,7 +104,7 @@ std::vector<Tensor> all_gather_impl(const std::vector<Tensor>& input_tensors, co
// Package output in vector, to populate it with launch_op
std::vector<Tensor> output_for_curr_device = {output_tensors[i]};
operation::launch_op(
[is_ring, dim, num_links, i, num_inputs, output_mem_config, topology] (const std::vector<Tensor>& input_tensors, const std::vector<std::optional<const Tensor>>& optional_input_tensors) mutable -> std::vector<Tensor> {
[is_ring, dim, num_links, i, num_inputs, output_mem_config, topology] (const std::vector<Tensor>& input_tensors, const std::vector<std::optional<const Tensor>>& optional_input_tensors, const std::vector<std::optional<Tensor>>& optional_output_tensors) mutable -> std::vector<Tensor> {
bool is_last_chip_in_clockwise_direction = is_ring ? false : i == (num_inputs - 1);
bool is_last_chip_in_counter_clockwise_direction = is_ring ? false : i == 0;

2 changes: 1 addition & 1 deletion tt_eager/tt_dnn/op_library/bcast/bcast_op.hpp
@@ -77,7 +77,7 @@ inline Tensor bcast(

std::vector<Tensor> output_tensors = {Tensor(operation::get_workers_for_op_output({input_tensor_a}))};
operation::launch_with_autoformat(
[bcast_op, bcast_dim, output_mem_config] (const std::vector<Tensor>& input_tensors, const std::vector<std::optional<const Tensor>>& optional_input_tensors) mutable -> std::vector<Tensor> {
[bcast_op, bcast_dim, output_mem_config] (const std::vector<Tensor>& input_tensors, const std::vector<std::optional<const Tensor>>& optional_input_tensors, const std::vector<std::optional<Tensor>>& optional_output_tensors) mutable -> std::vector<Tensor> {
using tt::constants::TILE_HEIGHT;
using tt::constants::TILE_WIDTH;
auto& input_tensor_a = input_tensors.at(0);
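
As the all_gather and bcast hunks above show, callables passed to operation::launch_op and operation::launch_with_autoformat now receive the preallocated outputs as a third parameter. A minimal sketch of the updated callable shape follows; the body is illustrative only, and ops that do not consume preallocated outputs simply ignore the extra argument, as the two lambdas above do.

// Illustrative only: the callable signature expected by launch_op after this commit.
auto op_callable = [] (const std::vector<Tensor>& input_tensors,
                       const std::vector<std::optional<const Tensor>>& optional_input_tensors,
                       const std::vector<std::optional<Tensor>>& optional_output_tensors)
                       -> std::vector<Tensor> {
    // Compute and return output tensors here; optional_output_tensors carries
    // any preallocated outputs supplied by the caller.
    return {};
};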