#0: Add WriteShard and ReadShard MeshBuffer APIs and resolve MeshBuffer dealloc issues #16960

Open · wants to merge 1 commit into base: main
137 changes: 130 additions & 7 deletions tests/tt_metal/distributed/test_mesh_buffer.cpp
@@ -9,13 +9,44 @@
#include <tt-metalium/mesh_device_view.hpp>

#include "tests/tt_metal/tt_metal/common/multi_device_fixture.hpp"
#include "tt_metal/distributed/mesh_buffer.hpp"
#include "tt_metal/distributed/distributed.hpp"

namespace tt::tt_metal::distributed::test {
namespace {

using MeshBufferTest = T3000MultiDeviceFixture;

struct DeviceLocalShardedBufferTestConfig {
Shape2D num_pages_per_core;
Shape2D num_cores;
Shape2D page_shape;
uint32_t element_size = 1;
TensorMemoryLayout mem_config = TensorMemoryLayout::HEIGHT_SHARDED;
ShardOrientation shard_orientation = ShardOrientation::ROW_MAJOR;

Shape2D tensor2d_shape() {
return {num_pages_per_core.height() * num_cores.height(), num_pages_per_core.width() * num_cores.width()};
}

uint32_t num_pages() { return tensor2d_shape().height() * tensor2d_shape().width(); }

std::array<uint32_t, 2> shard_shape() {
return {num_pages_per_core.height() * page_shape.height(), num_pages_per_core.width() * page_shape.width()};
}

CoreRangeSet shard_grid() {
return CoreRangeSet(std::set<CoreRange>(
{CoreRange(CoreCoord(0, 0), CoreCoord(this->num_cores.height() - 1, this->num_cores.width() - 1))}));
}

uint32_t page_size() { return page_shape.height() * page_shape.width() * element_size; }

ShardSpecBuffer shard_parameters() {
return ShardSpecBuffer(
this->shard_grid(), this->shard_shape(), this->shard_orientation, this->page_shape, this->tensor2d_shape());
}
};
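// Worked example (illustrative, not part of this diff): with num_pages_per_core = {2, 3},
// num_cores = {4, 5}, page_shape = {32, 32}, and element_size = 4, the helpers above yield
// tensor2d_shape() == {8, 15}, num_pages() == 120, shard_shape() == {64, 96}, and
// page_size() == 32 * 32 * 4 == 4096 bytes.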

TEST_F(MeshBufferTest, ConfigValidation) {
const DeviceLocalBufferConfig device_local_config{
.page_size = 1024,
@@ -78,22 +109,24 @@ TEST_F(MeshBufferTest, ReplicatedBufferInitialization) {
}

TEST_F(MeshBufferTest, Deallocation) {
// Verify that a buffer is deallocated on the MeshDevice when it goes
// out of scope on host. Create a buffer with a certain config in limited
// scope. Record its address. Create another buffer with the same config
// outside the scope. Verify that addresses match.
const DeviceLocalBufferConfig device_local_config{
.page_size = 1024,
.buffer_type = BufferType::DRAM,
.buffer_layout = TensorMemoryLayout::INTERLEAVED,
.bottom_up = false};

const ReplicatedBufferConfig buffer_config{.size = 16 << 10};
std::shared_ptr<Buffer> buffer;
Allocator* allocator = nullptr;
uint32_t expected_address = 0;
{
auto replicated_buffer = MeshBuffer::create(buffer_config, device_local_config, mesh_device_.get());
buffer = replicated_buffer->get_device_buffer(Coordinate{0, 0});
allocator = buffer->allocator();
EXPECT_TRUE(allocator->allocated_buffers.contains(buffer.get()));
expected_address = replicated_buffer->address();
}
EXPECT_FALSE(allocator->allocated_buffers.contains(buffer.get()));
Contributor:
Does the old test pass as-is? Just wondering. By the way, I tried testing with Buffer::is_allocated(), but tt::tt_metal::detail::DeallocateBuffer(backing_buffer_ptr); didn't reset the variable that tracks the allocation state. I think you'll have better luck with the new revision. Can you try:

std::shared_ptr<Buffer> buffer;
{
  auto replicated_buffer = MeshBuffer::create(buffer_config, device_local_config, mesh_device_.get());
  buffer = replicated_buffer->get_device_buffer(Coordinate{0, 0});
  EXPECT_TRUE(buffer->is_allocated());
}
EXPECT_FALSE(buffer->is_allocated());

Contributor Author:

The original test does not pass as-is. The reason is that the previous implementation of MeshBuffer stored the backing buffer as the shard for device (0, 0), so the following gives an accurate representation of the backing buffer's state:

        buffer = replicated_buffer->get_device_buffer(Coordinate{0, 0});
        allocator = buffer->allocator();
        EXPECT_TRUE(allocator->allocated_buffers.contains(buffer.get()));

This is not the case anymore, since the backing buffer is not accessible to the user in any form.

The test you've described will not pass either, since the individual shards returned by get_device_buffer and stored in a temporary variable do not get deallocated when a MeshBuffer goes out of scope.
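To illustrate (a sketch of the scenario described above, not code from this PR): get_device_buffer returns a shared_ptr, so a shard handle held by the caller shares ownership of the device-local Buffer and outlives the MeshBuffer, which means its allocation state cannot be used to observe the mesh-level deallocation:

std::shared_ptr<Buffer> shard;
{
    auto replicated_buffer = MeshBuffer::create(buffer_config, device_local_config, mesh_device_.get());
    shard = replicated_buffer->get_device_buffer(Coordinate{0, 0});
}   // replicated_buffer is destroyed here, but `shard` keeps its Buffer alive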

auto replicated_buffer = MeshBuffer::create(buffer_config, device_local_config, mesh_device_.get());
EXPECT_EQ(replicated_buffer->address(), expected_address);
}

TEST_F(MeshBufferTest, GetDeviceBuffer) {
@@ -112,5 +145,95 @@
EXPECT_NO_THROW(replicated_buffer->get_device_buffer(Coordinate{1, 3}));
}

TEST_F(MeshBufferTest, InterleavedShardsReadWrite) {
constexpr uint32_t NUM_ITERS = 100;
Contributor:

Why bother with 100 iterations? Can we do 2 to confirm the re-read re-write path?

Contributor Author:

Each iteration is randomized. I think it's worth adding randomized testing for all core data structures and APIs, especially when it's cheap (on the order of a few seconds).

uint32_t seed = tt::parse_env("TT_METAL_SEED", 0);
uint32_t single_tile_size = ::tt::tt_metal::detail::TileSize(DataFormat::UInt32);

for (auto buffer_type : {BufferType::L1, BufferType::DRAM}) {
Contributor:

If we want to do this, this should be a parameterized TEST_P.
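For illustration, a minimal sketch of that parameterization (hypothetical names; assumes the fixture derives from ::testing::Test so it can be combined with WithParamInterface):

class MeshBufferReadWriteTest : public T3000MultiDeviceFixture,
                                public ::testing::WithParamInterface<BufferType> {};

TEST_P(MeshBufferReadWriteTest, InterleavedShardsReadWrite) {
    const BufferType buffer_type = GetParam();
    // ... same body as the loop below, using buffer_type instead of the hardcoded BufferType::L1 ...
}

INSTANTIATE_TEST_SUITE_P(
    MeshBufferTests, MeshBufferReadWriteTest, ::testing::Values(BufferType::L1, BufferType::DRAM));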

DeviceLocalBufferConfig per_device_buffer_config{
.page_size = single_tile_size,
.buffer_type = BufferType::L1,
.buffer_layout = TensorMemoryLayout::INTERLEAVED,
.bottom_up = false};

std::uniform_int_distribution<int> gen_num_tiles(1, 1024);
std::mt19937 rng(seed);
for (int i = 0; i < NUM_ITERS; i++) {
uint32_t num_random_tiles = gen_num_tiles(rng);
ReplicatedBufferConfig global_buffer_config = {
.size = num_random_tiles * single_tile_size,
};

std::shared_ptr<MeshBuffer> buf =
MeshBuffer::create(global_buffer_config, per_device_buffer_config, mesh_device_.get());

std::vector<uint32_t> src_vec(num_random_tiles * single_tile_size / sizeof(uint32_t), 0);
std::iota(src_vec.begin(), src_vec.end(), i);
for (std::size_t logical_x = 0; logical_x < buf->device()->num_cols(); logical_x++) {
for (std::size_t logical_y = 0; logical_y < buf->device()->num_rows(); logical_y++) {
WriteShard(mesh_device_->mesh_command_queue(), buf, src_vec, Coordinate(logical_y, logical_x));
}
}

for (std::size_t logical_x = 0; logical_x < buf->device()->num_cols(); logical_x++) {
for (std::size_t logical_y = 0; logical_y < buf->device()->num_rows(); logical_y++) {
std::vector<uint32_t> dst_vec = {};
ReadShard(mesh_device_->mesh_command_queue(), dst_vec, buf, Coordinate(logical_y, logical_x));
EXPECT_EQ(dst_vec, src_vec);
}
}
}
}
}

TEST_F(MeshBufferTest, DeviceLocalMeshBufferSharding) {
CoreCoord core_grid_size = mesh_device_->compute_with_storage_grid_size();
std::vector<std::array<uint32_t, 2>> num_pages_per_core_vec = {{1, 1}, {3, 137}, {67, 4}, {7, 11}, {2, 2}};
std::vector<std::array<uint32_t, 2>> page_shapes = {{1, 1024}, {1, 2048}, {1, 4}, {32, 32}, {1, 120}};
std::vector<TensorMemoryLayout> shard_strategies = {
TensorMemoryLayout::HEIGHT_SHARDED, TensorMemoryLayout::WIDTH_SHARDED, TensorMemoryLayout::BLOCK_SHARDED};

for (const auto shard_strategy : shard_strategies) {
for (const auto& num_pages_per_core : num_pages_per_core_vec) {
for (const auto& page_shape : page_shapes) {
DeviceLocalShardedBufferTestConfig test_config{
.num_pages_per_core = num_pages_per_core,
.num_cores = {core_grid_size.x, core_grid_size.y},
.page_shape = page_shape,
.mem_config = shard_strategy};
DeviceLocalBufferConfig per_device_buffer_config{
.page_size = test_config.page_size(),
.buffer_type = BufferType::L1,
.buffer_layout = test_config.mem_config,
.shard_parameters = test_config.shard_parameters(),
.bottom_up = false};

uint32_t buf_size = test_config.num_pages() * test_config.page_size();
ReplicatedBufferConfig global_buffer_config{
.size = buf_size,
};
auto buf = MeshBuffer::create(global_buffer_config, per_device_buffer_config, mesh_device_.get());
std::vector<uint32_t> src_vec(buf_size / sizeof(uint32_t), 0);
std::iota(src_vec.begin(), src_vec.end(), 0);

for (std::size_t logical_x = 0; logical_x < buf->device()->num_cols(); logical_x++) {
for (std::size_t logical_y = 0; logical_y < buf->device()->num_rows(); logical_y++) {
WriteShard(mesh_device_->mesh_command_queue(), buf, src_vec, Coordinate(logical_y, logical_x));
}
}

for (std::size_t logical_x = 0; logical_x < buf->device()->num_cols(); logical_x++) {
for (std::size_t logical_y = 0; logical_y < buf->device()->num_rows(); logical_y++) {
std::vector<uint32_t> dst_vec = {};
ReadShard(mesh_device_->mesh_command_queue(), dst_vec, buf, Coordinate(logical_y, logical_x));
EXPECT_EQ(dst_vec, src_vec);
}
}
}
}
}
}

} // namespace
} // namespace tt::tt_metal::distributed::test
22 changes: 22 additions & 0 deletions tt_metal/distributed/distributed.hpp
@@ -24,6 +24,28 @@ void AddProgramToMeshWorkload(MeshWorkload& mesh_workload, Program& program, con

void EnqueueMeshWorkload(MeshCommandQueue& mesh_cq, MeshWorkload& mesh_workload, bool blocking);

template <typename DType>
void WriteShard(
MeshCommandQueue& mesh_cq,
std::shared_ptr<MeshBuffer>& mesh_buffer,
std::vector<DType>& src,
const Coordinate& coord,
bool blocking = false) {
mesh_cq.enqueue_write_shard(mesh_buffer, src.data(), coord, blocking);
}

template <typename DType>
void ReadShard(
MeshCommandQueue& mesh_cq,
std::vector<DType>& dst,
std::shared_ptr<MeshBuffer>& mesh_buffer,
const Coordinate& coord,
bool blocking = true) {
auto shard = mesh_buffer->get_device_buffer(coord);
dst.resize(shard->page_size() * shard->num_pages() / sizeof(DType));
mesh_cq.enqueue_read_shard(dst.data(), mesh_buffer, coord, blocking);
}

Collaborator (on lines +27 to +48):

void ReadShard(Buffer& buffer, std::vector<DType>& host_buffer, const uint32_t& core_id);

We should converge on the interface for eventual unification with the single-device API.

Contributor Author:

I agree, my preference is to modify the tt-metal variants, since:

  1. They currently don't work with Fast Dispatch, so they'll need a command queue in the argument list.
  2. They pass in a 1D core_id which doesn't make sense for a 2D grid.

I think this should be a separate effort, since it requires dedicated dispatch changes to properly support.
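For reference, a minimal usage sketch of the two helpers as declared above (buffer size and coordinates are illustrative):

std::vector<uint32_t> src(1024, 0xA5A5);   // host data sized to one shard
WriteShard(mesh_device->mesh_command_queue(), mesh_buffer, src, Coordinate{0, 0});  // non-blocking by default

std::vector<uint32_t> dst;                 // ReadShard resizes dst to the shard's size
ReadShard(mesh_device->mesh_command_queue(), dst, mesh_buffer, Coordinate{0, 0});   // blocking by default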

void Finish(MeshCommandQueue& mesh_cq);

} // namespace distributed
79 changes: 46 additions & 33 deletions tt_metal/distributed/mesh_buffer.cpp
@@ -59,38 +59,25 @@ std::shared_ptr<MeshBuffer> MeshBuffer::create(
}},
mesh_buffer_config);

// Rely on the single device allocator to provide the address for the entire mesh buffer.
// TODO: use mesh allocator, when available.
std::shared_ptr<Buffer> backing_buffer = Buffer::create(
mesh_device->get_device(0, 0),
/*address=*/address.value_or(0),
device_local_size,
device_local_config.page_size,
device_local_config.buffer_type,
device_local_config.buffer_layout,
device_local_config.shard_parameters,
device_local_config.bottom_up);
std::shared_ptr<MeshBuffer> mesh_buffer;
if (!address.has_value()) {
*address = tt::tt_metal::detail::AllocateBuffer(backing_buffer.get());
auto* backing_buffer_ptr = backing_buffer.get();
// Rely on the single device allocator to provide the address for the entire mesh buffer.
// The address provided to the backing buffer is used as the address for the MeshBuffer object.
// TODO: use mesh allocator, when available.
std::shared_ptr<Buffer> backing_buffer = Buffer::create(
mesh_device->get_device(0, 0),
device_local_size,
device_local_config.page_size,
device_local_config.buffer_type,
device_local_config.buffer_layout,
device_local_config.shard_parameters,
device_local_config.bottom_up);

mesh_buffer = std::shared_ptr<MeshBuffer>(
new MeshBuffer(
mesh_buffer_config,
device_local_config,
*address,
device_local_size,
mesh_device,
std::move(backing_buffer)),
[backing_buffer_ptr](MeshBuffer*) { tt::tt_metal::detail::DeallocateBuffer(backing_buffer_ptr); });
new MeshBuffer(mesh_buffer_config, device_local_config, device_local_size, mesh_device, backing_buffer));
} else {
mesh_buffer = std::shared_ptr<MeshBuffer>(new MeshBuffer(
mesh_buffer_config,
device_local_config,
*address,
device_local_size,
mesh_device,
std::move(backing_buffer)));
mesh_buffer = std::shared_ptr<MeshBuffer>(
new MeshBuffer(mesh_buffer_config, device_local_config, address.value(), device_local_size, mesh_device));
}

mesh_buffer->allocate();
@@ -99,6 +86,13 @@
}

void MeshBuffer::allocate() {
if (backing_buffer_) {
TT_FATAL(
!address_, "The address for a MeshBuffer should not be explicitly initialized when it is being allocated");
address_ = backing_buffer_->address();
} else {
TT_FATAL(address_, "A MeshBuffer should be provided a valid address if it's not being allocated");
}
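// Illustration (not part of this diff): the check above distinguishes the two creation
// modes — create() without an explicit address sets backing_buffer_, whose allocation
// supplies address_; create() with an explicit address constructs no backing buffer and
// requires address_ to already be valid.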
buffers_ = std::vector<std::vector<std::shared_ptr<Buffer>>>(
mesh_device_->num_rows(), std::vector<std::shared_ptr<Buffer>>(mesh_device_->num_cols()));

@@ -117,11 +111,7 @@

for (int row = 0; row < mesh_device_->num_rows(); row++) {
for (int col = 0; col < mesh_device_->num_cols(); col++) {
if (row == 0 and col == 0) {
buffers_[row][col] = backing_buffer_;
} else {
buffers_[row][col] = allocate_device_buffer_at_address(Coordinate{row, col});
}
buffers_[row][col] = allocate_device_buffer_at_address(Coordinate{row, col});
}
}
}
@@ -156,4 +146,27 @@
return std::get<ShardedBufferConfig>(config_);
}

uint32_t MeshBuffer::datum_size_bytes() const {
// Limitation for now.
TT_FATAL(
this->global_layout() == MeshBufferLayout::SHARDED,
"Can only query datum size for buffers sharded across the Mesh");
return this->global_shard_spec().compute_datum_size_bytes();
}

Shape2D MeshBuffer::physical_shard_shape() const {
TT_FATAL(
this->global_layout() == MeshBufferLayout::SHARDED,
"Can only query physical shard shape for buffers sharded across the Mesh");
auto sharded_config = std::get<ShardedBufferConfig>(config_);
Shape2D physical_shard_shape = sharded_config.shard_shape;
if (physical_shard_shape.height() == 0) {
physical_shard_shape = {sharded_config.global_buffer_shape.height(), physical_shard_shape.width()};
}
if (physical_shard_shape.width() == 0) {
physical_shard_shape = {physical_shard_shape.height(), sharded_config.global_buffer_shape.width()};
}
return physical_shard_shape;
}
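// Illustration (not part of this diff): a zero-sized shard dimension falls back to the
// corresponding global dimension, e.g. with global_buffer_shape = {64, 128} and
// shard_shape = {0, 32}, physical_shard_shape() returns {64, 32}.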

} // namespace tt::tt_metal::distributed
31 changes: 22 additions & 9 deletions tt_metal/distributed/mesh_buffer.hpp
@@ -8,6 +8,7 @@
#include <buffer_constants.hpp>
#include <mesh_device.hpp>
#include <mesh_device_view.hpp>
#include <tt-metalium/shape2d.hpp>

namespace tt::tt_metal::distributed {

@@ -40,19 +41,17 @@ struct ShardedBufferConfig {
DeviceAddr global_size = 0;

// Global shape of the buffer; at metal-level, we expect the shape to be aligned with the mesh shape.
// TODO: Consider a 2D shape class.
std::pair<size_t, size_t> global_buffer_shape = {0, 0};
Shape2D global_buffer_shape = {0, 0};

// Shard shape, sent to each device.
// TODO: Consider a 2D shape class.
std::pair<size_t, size_t> shard_shape = {0, 0};
Shape2D shard_shape = {0, 0};

// Orientation of the shards in a mesh.
ShardOrientation shard_orientation = ShardOrientation::ROW_MAJOR;

// Computes the number of bytes per datum in the sharded buffer.
uint32_t compute_datum_size_bytes() const {
return global_size / (global_buffer_shape.first * global_buffer_shape.second);
return global_size / (global_buffer_shape.height() * global_buffer_shape.width());
}
};
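// Worked example (illustrative, not part of this diff): a buffer with
// global_buffer_shape = {32, 64} and global_size = 8192 bytes has
// compute_datum_size_bytes() == 8192 / (32 * 64) == 4, i.e. 4-byte datums such as uint32_t.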

@@ -80,32 +79,46 @@
const DeviceLocalBufferConfig& device_local_config() const { return device_local_config_; }

std::shared_ptr<Buffer> get_device_buffer(const Coordinate& device_coord);
uint32_t datum_size_bytes() const;
Shape2D physical_shard_shape() const;

private:
MeshBuffer(
const MeshBufferConfig& config,
const DeviceLocalBufferConfig& device_local_config,
DeviceAddr address,
DeviceAddr device_local_size,
MeshDevice* mesh_device,
std::shared_ptr<Buffer> backing_buffer) :
config_(config),
device_local_config_(device_local_config),
mesh_device_(mesh_device),
address_(address),
device_local_size_(device_local_size),
backing_buffer_(std::move(backing_buffer)) {}

void allocate();
MeshBuffer(
const MeshBufferConfig& config,
const DeviceLocalBufferConfig& device_local_config,
DeviceAddr address,
DeviceAddr device_local_size,
MeshDevice* mesh_device) :
config_(config),
device_local_config_(device_local_config),
mesh_device_(mesh_device),
address_(address),
device_local_size_(device_local_size) {}

void allocate();
MeshBufferConfig config_;
DeviceLocalBufferConfig device_local_config_;
MeshDevice* mesh_device_ = nullptr;
DeviceAddr address_ = 0;
DeviceAddr device_local_size_ = 0;

// TODO: Conisder optimizing with SmallVector.
// TODO: Consider optimizing with SmallVector.
std::vector<std::vector<std::shared_ptr<Buffer>>> buffers_;
// Buffer owned by the MeshBuffer. Responsible for interfacing with the
// single device allocator. This data-structure is not populated if memory
// for the MeshBuffer is externally owned.
std::shared_ptr<Buffer> backing_buffer_;
};
