Skip to content

Commit

Permalink
#0: Update buffer/allocator to better handle sharded tensors with par…
Browse files Browse the repository at this point in the history
…tial shards (i.e., sharded tensors with padding greater than page alignment)

- Use aligned size and aligned page size in Buffer::aligned_size_per_bank
- Remove size arg in allocator::allocate_buffer and get the size from the buffer instead
  * Remove branching for sharded buffers in AllocateBuffer
  • Loading branch information
TT-BrianLiu committed Dec 7, 2024
1 parent d5c4061 commit 8ee1f82
Show file tree
Hide file tree
Showing 4 changed files with 7 additions and 12 deletions.
5 changes: 3 additions & 2 deletions tt_metal/impl/allocator/allocator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -437,9 +437,10 @@ void reset_allocator_size(Allocator& allocator, const BufferType& buffer_type) {
}
}

DeviceAddr allocate_buffer(Allocator& allocator, DeviceAddr size, Buffer* buffer) {
DeviceAddr allocate_buffer(Allocator& allocator, Buffer* buffer) {
DeviceAddr address = 0;
auto page_size = buffer->page_size();
auto size = buffer->aligned_size();
auto page_size = buffer->aligned_page_size();
auto buffer_type = buffer->buffer_type();
auto bottom_up = buffer->bottom_up();
auto num_shards = buffer->num_cores();
Expand Down
2 changes: 1 addition & 1 deletion tt_metal/impl/allocator/allocator.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ void shrink_allocator_size(
Allocator& allocator, const BufferType& buffer_type, DeviceAddr shrink_size, bool bottom_up = true);
void reset_allocator_size(Allocator& allocator, const BufferType& buffer_type);

DeviceAddr allocate_buffer(Allocator& allocator, DeviceAddr size, Buffer* buffer);
DeviceAddr allocate_buffer(Allocator& allocator, Buffer* buffer);

void mark_allocations_unsafe(Allocator& allocator);

Expand Down
2 changes: 1 addition & 1 deletion tt_metal/impl/buffers/buffer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -467,7 +467,7 @@ DeviceAddr Buffer::aligned_size() const {

// Returns the result of detail::SizeBytesPerBank for this buffer, given a
// total size, a page size, a bank count, and this buffer's alignment.
//
// NOTE(review): this is a diff rendering without +/- markers (the hunk is
// "1 addition & 1 deletion"), so both the pre-change and post-change return
// statements appear below. Per the commit description, the intended
// post-commit code is the second return (aligned size / aligned page size);
// the first is the removed line — confirm against the repository.
DeviceAddr Buffer::aligned_size_per_bank() const {
// Bank count: for a sharded layout, the value held by num_cores(); otherwise
// the device's bank count for this buffer's type.
uint32_t num_banks = is_sharded(this->buffer_layout_) ? this->num_cores().value() : this->device_->num_banks(this->buffer_type());
// Pre-change form: passes the raw size_ and page_size_ members.
return tt::tt_metal::detail::SizeBytesPerBank(this->size_, this->page_size_, num_banks, this->alignment());
// Post-change form: passes aligned_size() and aligned_page_size() instead.
return tt::tt_metal::detail::SizeBytesPerBank(this->aligned_size(), this->aligned_page_size(), num_banks, this->alignment());
}

DeviceAddr Buffer::sharded_page_address(uint32_t bank_id, uint32_t page_index) const {
Expand Down
10 changes: 2 additions & 8 deletions tt_metal/tt_metal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -856,15 +856,9 @@ DeviceAddr AllocateBuffer(Buffer* buffer) {
*buffer->sub_device_manager_id(),
buffer->device()->get_active_sub_device_manager_id());
}
auto allocator = buffer->allocator();
DeviceAddr allocated_addr;

if (is_sharded(buffer->buffer_layout())) {
allocated_addr = allocator::allocate_buffer(
*allocator, buffer->shard_spec().size() * buffer->num_cores().value() * buffer->page_size(), buffer);
} else {
allocated_addr = allocator::allocate_buffer(*allocator, buffer->size(), buffer);
}
DeviceAddr allocated_addr = allocator::allocate_buffer(*buffer->allocator(), buffer);

// Assertion here because buffer class returns a u32 when address is queried
// Requires updating all use cases of buffer address to accept a u64 to remove
TT_ASSERT(allocated_addr <= std::numeric_limits<uint32_t>::max());
Expand Down

0 comments on commit 8ee1f82

Please sign in to comment.