From 8ee1f823ccc6680fd4ba5aa9004487692da1c545 Mon Sep 17 00:00:00 2001
From: Brian Liu
Date: Fri, 6 Dec 2024 22:44:40 +0000
Subject: [PATCH] #0: Update buffer/allocator to better handle sharded tensors
 with partial shards (i.e. sharded tensors with padding greater than page
 alignment)

- Use aligned size and aligned page size in Buffer::aligned_size_per_bank
- Remove size arg in allocator::allocate_buffer and get it from the buffer
  instead
  * Remove branching for sharded buffers in AllocateBuffer
---
 tt_metal/impl/allocator/allocator.cpp |  5 +++--
 tt_metal/impl/allocator/allocator.hpp |  2 +-
 tt_metal/impl/buffers/buffer.cpp      |  2 +-
 tt_metal/tt_metal.cpp                 | 10 ++--------
 4 files changed, 7 insertions(+), 12 deletions(-)

diff --git a/tt_metal/impl/allocator/allocator.cpp b/tt_metal/impl/allocator/allocator.cpp
index aaeec83d553..bdbbe8dd3f9 100644
--- a/tt_metal/impl/allocator/allocator.cpp
+++ b/tt_metal/impl/allocator/allocator.cpp
@@ -437,9 +437,10 @@ void reset_allocator_size(Allocator& allocator, const BufferType& buffer_type) {
     }
 }
 
-DeviceAddr allocate_buffer(Allocator& allocator, DeviceAddr size, Buffer* buffer) {
+DeviceAddr allocate_buffer(Allocator& allocator, Buffer* buffer) {
     DeviceAddr address = 0;
-    auto page_size = buffer->page_size();
+    auto size = buffer->aligned_size();
+    auto page_size = buffer->aligned_page_size();
     auto buffer_type = buffer->buffer_type();
     auto bottom_up = buffer->bottom_up();
     auto num_shards = buffer->num_cores();
diff --git a/tt_metal/impl/allocator/allocator.hpp b/tt_metal/impl/allocator/allocator.hpp
index 0d62f8cd5de..1852e959766 100644
--- a/tt_metal/impl/allocator/allocator.hpp
+++ b/tt_metal/impl/allocator/allocator.hpp
@@ -136,7 +136,7 @@ void shrink_allocator_size(
     Allocator& allocator, const BufferType& buffer_type, DeviceAddr shrink_size, bool bottom_up = true);
 void reset_allocator_size(Allocator& allocator, const BufferType& buffer_type);
 
-DeviceAddr allocate_buffer(Allocator& allocator, DeviceAddr size, Buffer* buffer);
+DeviceAddr allocate_buffer(Allocator& allocator, Buffer* buffer);
 
 void mark_allocations_unsafe(Allocator& allocator);
 
diff --git a/tt_metal/impl/buffers/buffer.cpp b/tt_metal/impl/buffers/buffer.cpp
index 34975b0442b..b1e5ec3e337 100644
--- a/tt_metal/impl/buffers/buffer.cpp
+++ b/tt_metal/impl/buffers/buffer.cpp
@@ -467,7 +467,7 @@ DeviceAddr Buffer::aligned_size() const {
 DeviceAddr Buffer::aligned_size_per_bank() const {
     uint32_t num_banks = is_sharded(this->buffer_layout_) ?
         this->num_cores().value() : this->device_->num_banks(this->buffer_type());
-    return tt::tt_metal::detail::SizeBytesPerBank(this->size_, this->page_size_, num_banks, this->alignment());
+    return tt::tt_metal::detail::SizeBytesPerBank(this->aligned_size(), this->aligned_page_size(), num_banks, this->alignment());
 }
 
 DeviceAddr Buffer::sharded_page_address(uint32_t bank_id, uint32_t page_index) const {
diff --git a/tt_metal/tt_metal.cpp b/tt_metal/tt_metal.cpp
index ff1987983e9..38f061df987 100644
--- a/tt_metal/tt_metal.cpp
+++ b/tt_metal/tt_metal.cpp
@@ -856,15 +856,9 @@ DeviceAddr AllocateBuffer(Buffer* buffer) {
             *buffer->sub_device_manager_id(),
             buffer->device()->get_active_sub_device_manager_id());
     }
-    auto allocator = buffer->allocator();
-    DeviceAddr allocated_addr;
-    if (is_sharded(buffer->buffer_layout())) {
-        allocated_addr = allocator::allocate_buffer(
-            *allocator, buffer->shard_spec().size() * buffer->num_cores().value() * buffer->page_size(), buffer);
-    } else {
-        allocated_addr = allocator::allocate_buffer(*allocator, buffer->size(), buffer);
-    }
+    DeviceAddr allocated_addr = allocator::allocate_buffer(*buffer->allocator(), buffer);
+
    // Assertion here because buffer class returns a u32 when address is queried
    // Requires updating all use cases of buffer address to accept a u64 to remove
    TT_ASSERT(allocated_addr <= std::numeric_limits<uint32_t>::max());
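
Note (not part of the patch): a minimal sketch of the round-up math that
aligned_page_size()/aligned_size() are assumed to perform, and of why the
deleted shard_spec().size() * num_cores * page_size arithmetic can misstate
the footprint once padding is involved. round_up, the assumption that
shard_spec().size() means pages per full shard, and all concrete numbers
below are illustrative only, not the tt-metal API.

#include <cstdint>
#include <iostream>

using DeviceAddr = std::uint64_t;

// Round value up to the nearest multiple of alignment.
static DeviceAddr round_up(DeviceAddr value, DeviceAddr alignment) {
    return ((value + alignment - 1) / alignment) * alignment;
}

int main() {
    // Hypothetical buffer: 1000-byte pages, 32-byte alignment, 10 pages
    // sharded across 3 cores (4 + 4 + 2 pages, so the last shard is partial).
    const DeviceAddr page_size = 1000;
    const DeviceAddr alignment = 32;
    const DeviceAddr num_pages = 10;
    const DeviceAddr pages_per_full_shard = 4;
    const DeviceAddr num_shards = 3;

    // Each page is padded out to the alignment boundary.
    const DeviceAddr aligned_page_size = round_up(page_size, alignment);  // 1024
    // The footprint the allocator should see: every page at its aligned size.
    const DeviceAddr aligned_size = num_pages * aligned_page_size;        // 10240

    // The removed sharded branch priced every shard as full and used the
    // unaligned page size: 4 * 3 * 1000 = 12000 bytes, which matches neither
    // the aligned footprint nor the unaligned one (10 * 1000 = 10000).
    const DeviceAddr old_sharded_size = pages_per_full_shard * num_shards * page_size;

    std::cout << "aligned_size=" << aligned_size
              << " old_sharded_size=" << old_sharded_size << "\n";
    return 0;
}

Deriving both values inside allocator::allocate_buffer from the Buffer itself
means callers such as AllocateBuffer can no longer pass in a size computed
under different assumptions, which is exactly what the removed branching did.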