From f9508de078b250006f9179968e663bb70f5b64c5 Mon Sep 17 00:00:00 2001
From: asaigal
Date: Mon, 20 Jan 2025 21:39:14 +0000
Subject: [PATCH] #0: Add WriteShard and ReadShard MeshBuffer APIs and resolve MeshBuffer dealloc issues

- Add tests for reading and writing shards with Interleaved and Sharded configs
- Add test for deallocation, verifying addresses
---
 .../tt_metal/distributed/test_mesh_buffer.cpp | 149 +++++++++++++++++-
 tt_metal/distributed/distributed.hpp          |  22 +++
 tt_metal/distributed/mesh_buffer.cpp          |  79 ++++++----
 tt_metal/distributed/mesh_buffer.hpp          |  30 ++--
 tt_metal/distributed/mesh_command_queue.cpp   |  81 ++++++++++
 tt_metal/distributed/mesh_command_queue.hpp   |  16 ++
 6 files changed, 328 insertions(+), 49 deletions(-)

diff --git a/tests/tt_metal/distributed/test_mesh_buffer.cpp b/tests/tt_metal/distributed/test_mesh_buffer.cpp
index 5d451de33bd..b24cf72acce 100644
--- a/tests/tt_metal/distributed/test_mesh_buffer.cpp
+++ b/tests/tt_metal/distributed/test_mesh_buffer.cpp
@@ -9,13 +9,56 @@
 #include
 #include "tests/tt_metal/tt_metal/common/multi_device_fixture.hpp"
-#include "tt_metal/distributed/mesh_buffer.hpp"
+#include "tt_metal/distributed/distributed.hpp"

 namespace tt::tt_metal::distributed::test {
 namespace {

 using MeshBufferTest = T3000MultiDeviceFixture;

+class DeviceLocalShardedBufferTestConfig {
+public:
+    std::array<uint32_t, 2> num_pages_per_core;
+    std::array<uint32_t, 2> num_cores;
+    std::array<uint32_t, 2> page_shape;
+    uint32_t element_size = 1;
+    TensorMemoryLayout mem_config = TensorMemoryLayout::HEIGHT_SHARDED;
+    ShardOrientation shard_orientation = ShardOrientation::ROW_MAJOR;
+
+    DeviceLocalShardedBufferTestConfig(
+        const std::array<uint32_t, 2>& num_pages_per_core_,
+        const std::array<uint32_t, 2>& num_cores_,
+        const std::array<uint32_t, 2> page_shape_,
+        const TensorMemoryLayout& shard_strategy_) {
+        this->num_pages_per_core = num_pages_per_core_;
+        this->num_cores = num_cores_;
+        this->page_shape = page_shape_;
+        this->mem_config = shard_strategy_;
+    }
+
+    std::array<uint32_t, 2> tensor2d_shape() {
+        return {num_pages_per_core[0] * num_cores[0], num_pages_per_core[1] * num_cores[1]};
+    }
+
+    uint32_t num_pages() { return tensor2d_shape()[0] * tensor2d_shape()[1]; }
+
+    std::array<uint32_t, 2> shard_shape() {
+        return {num_pages_per_core[0] * page_shape[0], num_pages_per_core[1] * page_shape[1]};
+    }
+
+    CoreRangeSet shard_grid() {
+        return CoreRangeSet(std::set<CoreRange>(
+            {CoreRange(CoreCoord(0, 0), CoreCoord(this->num_cores[0] - 1, this->num_cores[1] - 1))}));
+    }
+
+    uint32_t page_size() { return page_shape[0] * page_shape[1] * element_size; }
+
+    ShardSpecBuffer shard_parameters() {
+        return ShardSpecBuffer(
+            this->shard_grid(), this->shard_shape(), this->shard_orientation, this->page_shape, this->tensor2d_shape());
+    }
+};
+
 TEST_F(MeshBufferTest, ConfigValidation) {
     const DeviceLocalBufferConfig device_local_config{
         .page_size = 1024,
         .buffer_type = BufferType::DRAM,
@@ -78,6 +121,10 @@ TEST_F(MeshBufferTest, ReplicatedBufferInitialization) {
 }

 TEST_F(MeshBufferTest, Deallocation) {
+    // Verify that a buffer is deallocated on the MeshDevice when it goes
+    // out of scope on host. Create a buffer with a certain config in limited
+    // scope. Record its address. Create another buffer with the same config
+    // outside the scope. Verify that addresses match.
     const DeviceLocalBufferConfig device_local_config{
         .page_size = 1024,
         .buffer_type = BufferType::DRAM,
@@ -85,15 +132,13 @@ TEST_F(MeshBufferTest, Deallocation) {
         .bottom_up = false};

     const ReplicatedBufferConfig buffer_config{.size = 16 << 10};
-    std::shared_ptr<Buffer> buffer;
-    Allocator* allocator = nullptr;
+    uint32_t expected_address = 0;
     {
         auto replicated_buffer = MeshBuffer::create(buffer_config, device_local_config, mesh_device_.get());
-        buffer = replicated_buffer->get_device_buffer(Coordinate{0, 0});
-        allocator = buffer->allocator();
-        EXPECT_TRUE(allocator->allocated_buffers.contains(buffer.get()));
+        expected_address = replicated_buffer->address();
     }
-    EXPECT_FALSE(allocator->allocated_buffers.contains(buffer.get()));
+    auto replicated_buffer = MeshBuffer::create(buffer_config, device_local_config, mesh_device_.get());
+    EXPECT_EQ(replicated_buffer->address(), expected_address);
 }

 TEST_F(MeshBufferTest, GetDeviceBuffer) {
@@ -112,5 +157,95 @@ TEST_F(MeshBufferTest, GetDeviceBuffer) {
     EXPECT_NO_THROW(replicated_buffer->get_device_buffer(Coordinate{1, 3}));
 }

+TEST_F(MeshBufferTest, TestInterleavedShardsReadWrite) {
+    constexpr uint32_t NUM_ITERS = 100;
+    uint32_t seed = tt::parse_env("TT_METAL_SEED", 0);
+    uint32_t single_tile_size = ::tt::tt_metal::detail::TileSize(DataFormat::Float16_b);
+
+    for (auto buffer_type : {BufferType::L1, BufferType::DRAM}) {
+        DeviceLocalBufferConfig per_device_buffer_config{
+            .page_size = single_tile_size,
+            .buffer_type = buffer_type,
+            .buffer_layout = TensorMemoryLayout::INTERLEAVED,
+            .bottom_up = false};
+
+        std::uniform_int_distribution<uint32_t> gen_num_tiles(1, 1024);
+        std::mt19937 rng(seed);
+        for (int i = 0; i < NUM_ITERS; i++) {
+            uint32_t num_random_tiles = gen_num_tiles(rng);
+            ReplicatedBufferConfig global_buffer_config = {
+                .size = num_random_tiles * single_tile_size,
+            };
+
+            std::shared_ptr<MeshBuffer> buf =
+                MeshBuffer::create(global_buffer_config, per_device_buffer_config, mesh_device_.get());
+
+            std::vector<bfloat16> src_vec = create_constant_vector_of_bfloat16(num_random_tiles * single_tile_size, i);
+            for (std::size_t logical_x = 0; logical_x < buf->device()->num_cols(); logical_x++) {
+                for (std::size_t logical_y = 0; logical_y < buf->device()->num_rows(); logical_y++) {
+                    WriteShard(mesh_device_->mesh_command_queue(), buf, src_vec, Coordinate(logical_y, logical_x));
+                }
+            }
+
+            for (std::size_t logical_x = 0; logical_x < buf->device()->num_cols(); logical_x++) {
+                for (std::size_t logical_y = 0; logical_y < buf->device()->num_rows(); logical_y++) {
+                    std::vector<bfloat16> dst_vec = {};
+                    ReadShard(mesh_device_->mesh_command_queue(), dst_vec, buf, Coordinate(logical_y, logical_x));
+                    for (int j = 0; j < dst_vec.size(); j++) {
+                        EXPECT_EQ(dst_vec[j].to_float(), i);
+                    }
+                }
+            }
+        }
+    }
+}
+
+TEST_F(MeshBufferTest, TestDeviceLocalMeshBufferSharding) {
+    CoreCoord core_grid_size = mesh_device_->compute_with_storage_grid_size();
+    std::vector<std::array<uint32_t, 2>> num_pages_per_core_vec = {{1, 1}, {3, 137}, {67, 4}, {7, 11}, {2, 2}};
+    std::vector<std::array<uint32_t, 2>> page_shapes = {{1, 1024}, {1, 2048}, {1, 4}, {32, 32}, {1, 120}};
+    std::vector<TensorMemoryLayout> shard_strategies = {
+        TensorMemoryLayout::HEIGHT_SHARDED, TensorMemoryLayout::WIDTH_SHARDED, TensorMemoryLayout::BLOCK_SHARDED};
+
+    for (const auto shard_strategy : shard_strategies) {
+        for (const auto& num_pages_per_core : num_pages_per_core_vec) {
+            for (const auto& page_shape : page_shapes) {
+                DeviceLocalShardedBufferTestConfig test_config(
+                    num_pages_per_core, {core_grid_size.x, core_grid_size.y}, page_shape, shard_strategy);
+                DeviceLocalBufferConfig per_device_buffer_config{
+                    .page_size = test_config.page_size(),
+                    .buffer_type = BufferType::L1,
+                    .buffer_layout = test_config.mem_config,
+                    .shard_parameters = test_config.shard_parameters(),
+                    .bottom_up = false};
+
+                uint32_t buf_size = test_config.num_pages() * test_config.page_size();
+                ReplicatedBufferConfig global_buffer_config{
+                    .size = buf_size,
+                };
+                auto buf = MeshBuffer::create(global_buffer_config, per_device_buffer_config, mesh_device_.get());
+                std::vector<uint32_t> src_vec(buf_size / sizeof(uint32_t), 0);
+                std::iota(src_vec.begin(), src_vec.end(), 0);
+
+                for (std::size_t logical_x = 0; logical_x < buf->device()->num_cols(); logical_x++) {
+                    for (std::size_t logical_y = 0; logical_y < buf->device()->num_rows(); logical_y++) {
+                        WriteShard(mesh_device_->mesh_command_queue(), buf, src_vec, Coordinate(logical_y, logical_x));
+                    }
+                }
+
+                for (std::size_t logical_x = 0; logical_x < buf->device()->num_cols(); logical_x++) {
+                    for (std::size_t logical_y = 0; logical_y < buf->device()->num_rows(); logical_y++) {
+                        std::vector<uint32_t> dst_vec = {};
+                        ReadShard(mesh_device_->mesh_command_queue(), dst_vec, buf, Coordinate(logical_y, logical_x));
+                        for (int j = 0; j < dst_vec.size(); j++) {
+                            EXPECT_EQ(dst_vec[j], j);
+                        }
+                    }
+                }
+            }
+        }
+    }
+}
+
 } // namespace
 } // namespace tt::tt_metal::distributed::test
diff --git a/tt_metal/distributed/distributed.hpp b/tt_metal/distributed/distributed.hpp
index 75d7839e19c..6e10b142f56 100644
--- a/tt_metal/distributed/distributed.hpp
+++ b/tt_metal/distributed/distributed.hpp
@@ -24,6 +24,28 @@ void AddProgramToMeshWorkload(MeshWorkload& mesh_workload, Program& program, con
 void EnqueueMeshWorkload(MeshCommandQueue& mesh_cq, MeshWorkload& mesh_workload, bool blocking);

+template <typename DType>
+void WriteShard(
+    MeshCommandQueue& mesh_cq,
+    std::shared_ptr<MeshBuffer>& mesh_buffer,
+    std::vector<DType>& src,
+    const Coordinate& coord,
+    bool blocking = false) {
+    mesh_cq.enqueue_write_shard(mesh_buffer, src.data(), coord, blocking);
+}
+
+template <typename DType>
+void ReadShard(
+    MeshCommandQueue& mesh_cq,
+    std::vector<DType>& dst,
+    std::shared_ptr<MeshBuffer>& mesh_buffer,
+    const Coordinate& coord,
+    bool blocking = true) {
+    auto shard = mesh_buffer->get_device_buffer(coord);
+    dst.resize(shard->page_size() * shard->num_pages() / sizeof(DType));
+    mesh_cq.enqueue_read_shard(dst.data(), mesh_buffer, coord, blocking);
+}
+
 void Finish(MeshCommandQueue& mesh_cq);

 } // namespace distributed
diff --git a/tt_metal/distributed/mesh_buffer.cpp b/tt_metal/distributed/mesh_buffer.cpp
index 153e40ed3fd..991f2edc266 100644
--- a/tt_metal/distributed/mesh_buffer.cpp
+++ b/tt_metal/distributed/mesh_buffer.cpp
@@ -59,38 +59,25 @@ std::shared_ptr<MeshBuffer> MeshBuffer::create(
         }},
         mesh_buffer_config);

-    // Rely on the single device allocator to provide the address for the entire mesh buffer.
-    // TODO: use mesh allocator, when available.
-    std::shared_ptr<Buffer> backing_buffer = Buffer::create(
-        mesh_device->get_device(0, 0),
-        /*address=*/address.value_or(0),
-        device_local_size,
-        device_local_config.page_size,
-        device_local_config.buffer_type,
-        device_local_config.buffer_layout,
-        device_local_config.shard_parameters,
-        device_local_config.bottom_up);
     std::shared_ptr<MeshBuffer> mesh_buffer;
     if (!address.has_value()) {
-        *address = tt::tt_metal::detail::AllocateBuffer(backing_buffer.get());
-        auto* backing_buffer_ptr = backing_buffer.get();
+        // Rely on the single device allocator to provide the address for the entire mesh buffer.
+        // The address provided to the backing buffer is used as the address for the MeshBuffer object.
+        // TODO: use mesh allocator, when available.
+        std::shared_ptr<Buffer> backing_buffer = Buffer::create(
+            mesh_device->get_device(0, 0),
+            device_local_size,
+            device_local_config.page_size,
+            device_local_config.buffer_type,
+            device_local_config.buffer_layout,
+            device_local_config.shard_parameters,
+            device_local_config.bottom_up);
+
         mesh_buffer = std::shared_ptr<MeshBuffer>(
-            new MeshBuffer(
-                mesh_buffer_config,
-                device_local_config,
-                *address,
-                device_local_size,
-                mesh_device,
-                std::move(backing_buffer)),
-            [backing_buffer_ptr](MeshBuffer*) { tt::tt_metal::detail::DeallocateBuffer(backing_buffer_ptr); });
+            new MeshBuffer(mesh_buffer_config, device_local_config, device_local_size, mesh_device, backing_buffer));
     } else {
-        mesh_buffer = std::shared_ptr<MeshBuffer>(new MeshBuffer(
-            mesh_buffer_config,
-            device_local_config,
-            *address,
-            device_local_size,
-            mesh_device,
-            std::move(backing_buffer)));
+        mesh_buffer = std::shared_ptr<MeshBuffer>(
+            new MeshBuffer(mesh_buffer_config, device_local_config, address.value(), device_local_size, mesh_device));
     }

     mesh_buffer->allocate();
@@ -99,6 +86,13 @@ std::shared_ptr<MeshBuffer> MeshBuffer::create(
 }

 void MeshBuffer::allocate() {
+    if (backing_buffer_) {
+        TT_FATAL(
+            !address_, "The address for a MeshBuffer should not explicitly be initialized when it is being allocated");
+        address_ = backing_buffer_->address();
+    } else {
+        TT_FATAL(address_, "A MeshBuffer should be provided a valid address if it's not being allocated");
+    }
     buffers_ = std::vector<std::vector<std::shared_ptr<Buffer>>>(
         mesh_device_->num_rows(), std::vector<std::shared_ptr<Buffer>>(mesh_device_->num_cols()));

@@ -117,11 +111,7 @@ void MeshBuffer::allocate() {
     for (int row = 0; row < mesh_device_->num_rows(); row++) {
         for (int col = 0; col < mesh_device_->num_cols(); col++) {
-            if (row == 0 and col == 0) {
-                buffers_[row][col] = backing_buffer_;
-            } else {
-                buffers_[row][col] = allocate_device_buffer_at_address(Coordinate{row, col});
-            }
+            buffers_[row][col] = allocate_device_buffer_at_address(Coordinate{row, col});
         }
     }
 }
@@ -156,4 +146,27 @@ const ShardedBufferConfig& MeshBuffer::global_shard_spec() const {
     return std::get<ShardedBufferConfig>(config_);
 }

+uint32_t MeshBuffer::datum_size_bytes() const {
+    // Limitation for now.
+    TT_FATAL(
+        this->global_layout() == MeshBufferLayout::SHARDED,
+        "Can only query datum size for buffers sharded across the Mesh");
+    return this->global_shard_spec().compute_datum_size_bytes();
+}
+
+Shape2D MeshBuffer::physical_shard_shape() const {
+    TT_FATAL(
+        this->global_layout() == MeshBufferLayout::SHARDED,
+        "Can only query physical shard shape for buffers sharded across the Mesh");
+    auto sharded_config = std::get<ShardedBufferConfig>(config_);
+    Shape2D physical_shard_shape = sharded_config.shard_shape;
+    if (physical_shard_shape.height() == 0) {
+        physical_shard_shape = {sharded_config.global_buffer_shape.height(), physical_shard_shape.width()};
+    }
+    if (physical_shard_shape.width() == 0) {
+        physical_shard_shape = {physical_shard_shape.height(), sharded_config.global_buffer_shape.width()};
+    }
+    return physical_shard_shape;
+}
+
 } // namespace tt::tt_metal::distributed
diff --git a/tt_metal/distributed/mesh_buffer.hpp b/tt_metal/distributed/mesh_buffer.hpp
index 4558ec10fc0..4de80d06070 100644
--- a/tt_metal/distributed/mesh_buffer.hpp
+++ b/tt_metal/distributed/mesh_buffer.hpp
@@ -8,6 +8,7 @@
 #include
 #include
 #include
+#include

 namespace tt::tt_metal::distributed {

@@ -40,19 +41,17 @@ struct ShardedBufferConfig {
     DeviceAddr global_size = 0;

     // Global shape of the buffer; at metal-level, we expect the shape to be aligned with the mesh shape.
-    // TODO: Consider a 2D shape class.
-    std::pair<size_t, size_t> global_buffer_shape = {0, 0};
+    Shape2D global_buffer_shape = {0, 0};

     // Shard shape, sent to each device.
-    // TODO: Consider a 2D shape class.
-    std::pair<size_t, size_t> shard_shape = {0, 0};
+    Shape2D shard_shape = {0, 0};

     // Orientation of the shards in a mesh.
     ShardOrientation shard_orientation = ShardOrientation::ROW_MAJOR;

     // Computes the number of bytes per datum in the sharded buffer.
     uint32_t compute_datum_size_bytes() const {
-        return global_size / (global_buffer_shape.first * global_buffer_shape.second);
+        return global_size / (global_buffer_shape.height() * global_buffer_shape.width());
     }
 };

@@ -80,32 +79,45 @@ class MeshBuffer {
     const DeviceLocalBufferConfig& device_local_config() const { return device_local_config_; }

     std::shared_ptr<Buffer> get_device_buffer(const Coordinate& device_coord);
+    uint32_t datum_size_bytes() const;
+    Shape2D physical_shard_shape() const;

 private:
     MeshBuffer(
         const MeshBufferConfig& config,
         const DeviceLocalBufferConfig& device_local_config,
-        DeviceAddr address,
         DeviceAddr device_local_size,
         MeshDevice* mesh_device,
         std::shared_ptr<Buffer> backing_buffer) :
         config_(config),
         device_local_config_(device_local_config),
         mesh_device_(mesh_device),
-        address_(address),
         device_local_size_(device_local_size),
         backing_buffer_(std::move(backing_buffer)) {}

-    void allocate();
+    MeshBuffer(
+        const MeshBufferConfig& config,
+        const DeviceLocalBufferConfig& device_local_config,
+        DeviceAddr address,
+        DeviceAddr device_local_size,
+        MeshDevice* mesh_device) :
+        config_(config),
+        device_local_config_(device_local_config),
+        mesh_device_(mesh_device),
+        address_(address),
+        device_local_size_(device_local_size) {}

+    void allocate();
     MeshBufferConfig config_;
     DeviceLocalBufferConfig device_local_config_;
     MeshDevice* mesh_device_ = nullptr;
     DeviceAddr address_ = 0;
     DeviceAddr device_local_size_ = 0;
-    // TODO: Conisder optimizing with SmallVector.
+    // TODO: Consider optimizing with SmallVector.
     std::vector<std::vector<std::shared_ptr<Buffer>>> buffers_;
+    // Buffer owned by the MeshBuffer. Responsible for interfacing with the
+    // single device allocator.
     std::shared_ptr<Buffer> backing_buffer_;
 };
diff --git a/tt_metal/distributed/mesh_command_queue.cpp b/tt_metal/distributed/mesh_command_queue.cpp
index c88428f6297..1c7c33879c5 100644
--- a/tt_metal/distributed/mesh_command_queue.cpp
+++ b/tt_metal/distributed/mesh_command_queue.cpp
@@ -4,6 +4,7 @@
 #include "tt_metal/distributed/mesh_command_queue.hpp"
 #include "tt_metal/distributed/mesh_workload_utils.hpp"
+#include "tt_metal/impl/buffers/dispatch.hpp"

 namespace tt::tt_metal::distributed {

@@ -173,4 +174,84 @@ void MeshCommandQueue::finish() {
     }
 }

+void MeshCommandQueue::write_shard_to_device(
+    std::shared_ptr<Buffer>& shard_view,
+    const void* src,
+    std::array& expected_num_workers_completed,
+    tt::stl::Span<const SubDeviceId> sub_device_ids) {
+    auto device = shard_view->device();
+    BufferRegion region(0, shard_view->size());
+    buffer_dispatch::write_to_device_buffer(
+        src, *shard_view, region, id_, expected_num_workers_completed, this->dispatch_core_type(), sub_device_ids);
+}
+
+void MeshCommandQueue::read_shard_from_device(
+    std::shared_ptr<Buffer>& shard_view,
+    void* dst,
+    std::array& expected_num_workers_completed,
+    tt::stl::Span<const SubDeviceId> sub_device_ids) {
+    auto device = shard_view->device();
+    chip_id_t mmio_device_id = tt::Cluster::instance().get_associated_mmio_device(device->id());
+    uint16_t channel = tt::Cluster::instance().get_assigned_channel_for_device(device->id());
+
+    bool exit_condition = false;
+
+    BufferRegion region(0, shard_view->size());
+
+    if (is_sharded(shard_view->buffer_layout())) {
+        auto dispatch_params = buffer_dispatch::initialize_sharded_buf_read_dispatch_params(
+            *shard_view, id_, expected_num_workers_completed);
+        auto cores = buffer_dispatch::get_cores_for_sharded_buffer(
+            dispatch_params.width_split, dispatch_params.buffer_page_mapping, *shard_view);
+        for (uint32_t core_id = 0; core_id < shard_view->num_cores(); ++core_id) {
+            buffer_dispatch::copy_sharded_buffer_from_core_to_completion_queue(
+                core_id, *shard_view, dispatch_params, sub_device_ids, cores[core_id], this->dispatch_core_type());
+            if (dispatch_params.pages_per_txn > 0) {
+                auto read_descriptor = std::get(
+                    *buffer_dispatch::generate_sharded_buffer_read_descriptor(dst, dispatch_params, *shard_view));
+                buffer_dispatch::copy_completion_queue_data_into_user_space(
+                    read_descriptor, mmio_device_id, channel, id_, device->sysmem_manager(), exit_condition);
+            }
+        }
+    } else {
+        auto dispatch_params = buffer_dispatch::initialize_interleaved_buf_read_dispatch_params(
+            *shard_view, id_, expected_num_workers_completed, region);
+        buffer_dispatch::copy_interleaved_buffer_to_completion_queue(
+            dispatch_params, *shard_view, sub_device_ids, this->dispatch_core_type());
+        if (dispatch_params.pages_per_txn > 0) {
+            auto read_descriptor = std::get(
+                *buffer_dispatch::generate_interleaved_buffer_read_descriptor(dst, dispatch_params, *shard_view));
+            buffer_dispatch::copy_completion_queue_data_into_user_space(
+                read_descriptor, mmio_device_id, channel, id_, device->sysmem_manager(), exit_condition);
+        }
+    }
+}
+
+void MeshCommandQueue::enqueue_write_shard(
+    std::shared_ptr<MeshBuffer>& mesh_buffer, void* host_data, const Coordinate& coord, bool blocking) {
+    // TODO: Add proper support for SubDevices once SubDeviceManager and allocator are moved up to MeshDevice
+    // We should not be querying SubDevices from device 0.
+    auto sub_device_ids = tt::stl::Span<const SubDeviceId>(mesh_device_->get_device(0)->get_sub_device_ids());
+    std::array expected_num_workers_completed;
+    expected_num_workers_completed[0] = expected_num_workers_completed_;
+    auto shard = mesh_buffer->get_device_buffer(coord);
+    this->write_shard_to_device(shard, host_data, expected_num_workers_completed, sub_device_ids);
+
+    if (blocking) {
+        this->finish();
+    }
+}
+
+void MeshCommandQueue::enqueue_read_shard(
+    void* host_data, std::shared_ptr<MeshBuffer>& mesh_buffer, const Coordinate& coord, bool blocking) {
+    TT_FATAL(blocking, "Only blocking reads are currently supported from MeshBuffer shards.");
+    // TODO: Add proper support for SubDevices once SubDeviceManager and allocator are moved up to MeshDevice
+    // We should not be querying SubDevices from device 0.
+    auto sub_device_ids = tt::stl::Span<const SubDeviceId>(mesh_device_->get_device(0)->get_sub_device_ids());
+    std::array expected_num_workers_completed;
+    expected_num_workers_completed[0] = expected_num_workers_completed_;
+    auto shard = mesh_buffer->get_device_buffer(coord);
+    this->read_shard_from_device(shard, host_data, expected_num_workers_completed, sub_device_ids);
+}
+
 } // namespace tt::tt_metal::distributed
diff --git a/tt_metal/distributed/mesh_command_queue.hpp b/tt_metal/distributed/mesh_command_queue.hpp
index fc628779846..034eec7682e 100644
--- a/tt_metal/distributed/mesh_command_queue.hpp
+++ b/tt_metal/distributed/mesh_command_queue.hpp
@@ -5,7 +5,9 @@
 #pragma once

 #include
+#include
+#include "tt_metal/distributed/mesh_buffer.hpp"
 #include "tt_metal/distributed/mesh_workload.hpp"

 namespace tt::tt_metal::distributed {

@@ -21,6 +23,16 @@ class MeshCommandQueue {
     void populate_dispatch_core_type();
     CoreCoord virtual_program_dispatch_core() const;
     CoreType dispatch_core_type() const;
+    void write_shard_to_device(
+        std::shared_ptr<Buffer>& shard_view,
+        const void* src,
+        std::array& expected_num_workers_completed,
+        tt::stl::Span<const SubDeviceId> sub_device_ids);
+    void read_shard_from_device(
+        std::shared_ptr<Buffer>& shard_view,
+        void* dst,
+        std::array& expected_num_workers_completed,
+        tt::stl::Span<const SubDeviceId> sub_device_ids);
     tt::tt_metal::WorkerConfigBufferMgr config_buffer_mgr_;
     LaunchMessageRingBufferState worker_launch_message_buffer_state_;
     uint32_t expected_num_workers_completed_ = 0;
@@ -35,6 +47,10 @@ class MeshCommandQueue {
     uint32_t id() const { return id_; }
     WorkerConfigBufferMgr& get_config_buffer_mgr(uint32_t index) { return config_buffer_mgr_; };
     void enqueue_mesh_workload(MeshWorkload& mesh_workload, bool blocking);
+    void enqueue_write_shard(
+        std::shared_ptr<MeshBuffer>& mesh_buffer, void* host_data, const Coordinate& coord, bool blocking);
+    void enqueue_read_shard(
+        void* host_data, std::shared_ptr<MeshBuffer>& mesh_buffer, const Coordinate& coord, bool blocking);
     void finish();
 };
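
A minimal usage sketch of the entry points introduced by this patch (illustrative only, not part of the diff). It mirrors the added tests: a replicated MeshBuffer is created, one shard is written with WriteShard, then read back with ReadShard. The helper name write_and_read_one_shard is hypothetical, and `mesh_device` is assumed to be a valid MeshDevice* obtained elsewhere (e.g. from a test fixture).

    // Sketch: exercise WriteShard/ReadShard on one shard of a replicated MeshBuffer.
    // Assumes `mesh_device` is a valid MeshDevice* created elsewhere.
    void write_and_read_one_shard(MeshDevice* mesh_device) {
        // Per-device view of the buffer: 1 KiB pages, interleaved in DRAM.
        DeviceLocalBufferConfig local_config{
            .page_size = 1024,
            .buffer_type = BufferType::DRAM,
            .buffer_layout = TensorMemoryLayout::INTERLEAVED,
            .bottom_up = false};
        // Replicate a 16 KiB buffer across every device in the mesh.
        ReplicatedBufferConfig global_config{.size = 16 << 10};

        auto buffer = MeshBuffer::create(global_config, local_config, mesh_device);

        std::vector<uint32_t> src(global_config.size / sizeof(uint32_t));
        std::iota(src.begin(), src.end(), 0);
        std::vector<uint32_t> dst = {};

        // Write the shard owned by the device at mesh coordinate (row 0, col 0),
        // then read it back; ReadShard resizes `dst` to the shard size.
        WriteShard(mesh_device->mesh_command_queue(), buffer, src, Coordinate{0, 0});
        ReadShard(mesh_device->mesh_command_queue(), dst, buffer, Coordinate{0, 0});
    }

Note that enqueue_read_shard currently requires blocking == true, so ReadShard's blocking default is the only supported read mode, while WriteShard defaults to non-blocking and can be flushed with Finish(mesh_cq).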