Commit

#0: Add WriteShard and ReadShard MeshBuffer APIs and resolve MeshBuffer dealloc issues

  - Add tests for reading and writing shards with Interleaved and Sharded configs
  - Add test for deallocation, verifying buffer addresses
tt-asaigal committed Jan 22, 2025
1 parent b95d0a3 commit f9508de
Showing 6 changed files with 328 additions and 49 deletions.
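For reference, a minimal usage sketch of the new WriteShard/ReadShard APIs (not part of the commit; assembled from the tests and headers below, and assuming an already-opened MeshDevice handle named mesh_device with the accessor names used by the test fixture):

// Sketch only: round-trips one shard through a replicated MeshBuffer,
// mirroring what the tests in this commit do.
#include <cstdint>
#include <vector>

#include <tt-metalium/mesh_device_view.hpp>

#include "tt_metal/distributed/distributed.hpp"
#include "tt_metal/distributed/mesh_buffer.hpp"

using namespace tt::tt_metal;
using namespace tt::tt_metal::distributed;

void roundtrip_one_shard(const std::shared_ptr<MeshDevice>& mesh_device) {
    const DeviceLocalBufferConfig device_local_config{
        .page_size = 1024,
        .buffer_type = BufferType::DRAM,
        .buffer_layout = TensorMemoryLayout::INTERLEAVED,
        .bottom_up = false};
    const ReplicatedBufferConfig buffer_config{.size = 16 << 10};

    // One device-local buffer is allocated per device in the mesh; all shards share an address.
    auto mesh_buffer = MeshBuffer::create(buffer_config, device_local_config, mesh_device.get());

    // Write host data to the shard on logical device (row 0, col 0), then read it back.
    std::vector<uint32_t> src(buffer_config.size / sizeof(uint32_t), 0xA5A5A5A5);
    WriteShard(mesh_device->mesh_command_queue(), mesh_buffer, src, Coordinate{0, 0});

    std::vector<uint32_t> dst;  // ReadShard resizes dst to the shard size and blocks by default.
    ReadShard(mesh_device->mesh_command_queue(), dst, mesh_buffer, Coordinate{0, 0});
}
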
149 changes: 142 additions & 7 deletions tests/tt_metal/distributed/test_mesh_buffer.cpp
@@ -9,13 +9,56 @@
#include <tt-metalium/mesh_device_view.hpp>

#include "tests/tt_metal/tt_metal/common/multi_device_fixture.hpp"
#include "tt_metal/distributed/mesh_buffer.hpp"
#include "tt_metal/distributed/distributed.hpp"

namespace tt::tt_metal::distributed::test {
namespace {

using MeshBufferTest = T3000MultiDeviceFixture;

class DeviceLocalShardedBufferTestConfig {
public:
std::array<uint32_t, 2> num_pages_per_core;
std::array<uint32_t, 2> num_cores;
std::array<uint32_t, 2> page_shape;
uint32_t element_size = 1;
TensorMemoryLayout mem_config = TensorMemoryLayout::HEIGHT_SHARDED;
ShardOrientation shard_orientation = ShardOrientation::ROW_MAJOR;

DeviceLocalShardedBufferTestConfig(
const std::array<uint32_t, 2>& num_pages_per_core_,
const std::array<uint32_t, 2>& num_cores_,
const std::array<uint32_t, 2> page_shape_,
const TensorMemoryLayout& shard_strategy_) {
this->num_pages_per_core = num_pages_per_core_;
this->num_cores = num_cores_;
this->page_shape = page_shape_;
this->mem_config = shard_strategy_;
}

std::array<uint32_t, 2> tensor2d_shape() {
return {num_pages_per_core[0] * num_cores[0], num_pages_per_core[1] * num_cores[1]};
}

uint32_t num_pages() { return tensor2d_shape()[0] * tensor2d_shape()[1]; }

std::array<uint32_t, 2> shard_shape() {
return {num_pages_per_core[0] * page_shape[0], num_pages_per_core[1] * page_shape[1]};
}

CoreRangeSet shard_grid() {
return CoreRangeSet(std::set<CoreRange>(
{CoreRange(CoreCoord(0, 0), CoreCoord(this->num_cores[0] - 1, this->num_cores[1] - 1))}));
}

uint32_t page_size() { return page_shape[0] * page_shape[1] * element_size; }

ShardSpecBuffer shard_parameters() {
return ShardSpecBuffer(
this->shard_grid(), this->shard_shape(), this->shard_orientation, this->page_shape, this->tensor2d_shape());
}
};

TEST_F(MeshBufferTest, ConfigValidation) {
const DeviceLocalBufferConfig device_local_config{
.page_size = 1024,
@@ -78,22 +121,24 @@ TEST_F(MeshBufferTest, ReplicatedBufferInitialization) {
}

TEST_F(MeshBufferTest, Deallocation) {
// Verify that a buffer is deallocated on the MeshDevice when it goes
// out of scope on host. Create a buffer with a certain config in limited
// scope. Record its address. Create another buffer with the same config
// outside the scope. Verify that addresses match.
const DeviceLocalBufferConfig device_local_config{
.page_size = 1024,
.buffer_type = BufferType::DRAM,
.buffer_layout = TensorMemoryLayout::INTERLEAVED,
.bottom_up = false};

const ReplicatedBufferConfig buffer_config{.size = 16 << 10};
std::shared_ptr<Buffer> buffer;
Allocator* allocator = nullptr;
uint32_t expected_address = 0;
{
auto replicated_buffer = MeshBuffer::create(buffer_config, device_local_config, mesh_device_.get());
buffer = replicated_buffer->get_device_buffer(Coordinate{0, 0});
allocator = buffer->allocator();
EXPECT_TRUE(allocator->allocated_buffers.contains(buffer.get()));
expected_address = replicated_buffer->address();
}
EXPECT_FALSE(allocator->allocated_buffers.contains(buffer.get()));
auto replicated_buffer = MeshBuffer::create(buffer_config, device_local_config, mesh_device_.get());
EXPECT_EQ(replicated_buffer->address(), expected_address);
}

TEST_F(MeshBufferTest, GetDeviceBuffer) {
@@ -112,5 +157,95 @@ TEST_F(MeshBufferTest, GetDeviceBuffer)
EXPECT_NO_THROW(replicated_buffer->get_device_buffer(Coordinate{1, 3}));
}

TEST_F(MeshBufferTest, TestInterleavedShardsReadWrite) {
constexpr uint32_t NUM_ITERS = 100;
uint32_t seed = tt::parse_env("TT_METAL_SEED", 0);
uint32_t single_tile_size = ::tt::tt_metal::detail::TileSize(DataFormat::Float16_b);

for (auto buffer_type : {BufferType::L1, BufferType::DRAM}) {
DeviceLocalBufferConfig per_device_buffer_config{
.page_size = single_tile_size,
            .buffer_type = buffer_type,
.buffer_layout = TensorMemoryLayout::INTERLEAVED,
.bottom_up = false};

std::uniform_int_distribution<int> gen_num_tiles(1, 1024);
std::mt19937 rng(seed);
for (int i = 0; i < NUM_ITERS; i++) {
uint32_t num_random_tiles = gen_num_tiles(rng);
ReplicatedBufferConfig global_buffer_config = {
.size = num_random_tiles * single_tile_size,
};

std::shared_ptr<MeshBuffer> buf =
MeshBuffer::create(global_buffer_config, per_device_buffer_config, mesh_device_.get());

std::vector<uint32_t> src_vec = create_constant_vector_of_bfloat16(num_random_tiles * single_tile_size, i);
for (std::size_t logical_x = 0; logical_x < buf->device()->num_cols(); logical_x++) {
for (std::size_t logical_y = 0; logical_y < buf->device()->num_rows(); logical_y++) {
WriteShard(mesh_device_->mesh_command_queue(), buf, src_vec, Coordinate(logical_y, logical_x));
}
}

for (std::size_t logical_x = 0; logical_x < buf->device()->num_cols(); logical_x++) {
for (std::size_t logical_y = 0; logical_y < buf->device()->num_rows(); logical_y++) {
std::vector<bfloat16> dst_vec = {};
ReadShard(mesh_device_->mesh_command_queue(), dst_vec, buf, Coordinate(logical_y, logical_x));
for (int j = 0; j < dst_vec.size(); j++) {
EXPECT_EQ(dst_vec[j].to_float(), i);
}
}
}
}
}
}

TEST_F(MeshBufferTest, TestDeviceLocalMeshBufferSharding) {
CoreCoord core_grid_size = mesh_device_->compute_with_storage_grid_size();
std::vector<std::array<uint32_t, 2>> num_pages_per_core_vec = {{1, 1}, {3, 137}, {67, 4}, {7, 11}, {2, 2}};
std::vector<std::array<uint32_t, 2>> page_shapes = {{1, 1024}, {1, 2048}, {1, 4}, {32, 32}, {1, 120}};
std::vector<TensorMemoryLayout> shard_strategies = {
TensorMemoryLayout::HEIGHT_SHARDED, TensorMemoryLayout::WIDTH_SHARDED, TensorMemoryLayout::BLOCK_SHARDED};

for (const auto shard_strategy : shard_strategies) {
for (const auto& num_pages_per_core : num_pages_per_core_vec) {
for (const auto& page_shape : page_shapes) {
DeviceLocalShardedBufferTestConfig test_config(
num_pages_per_core, {core_grid_size.x, core_grid_size.y}, page_shape, shard_strategy);
DeviceLocalBufferConfig per_device_buffer_config{
.page_size = test_config.page_size(),
.buffer_type = BufferType::L1,
.buffer_layout = test_config.mem_config,
.shard_parameters = test_config.shard_parameters(),
.bottom_up = false};

uint32_t buf_size = test_config.num_pages() * test_config.page_size();
ReplicatedBufferConfig global_buffer_config{
.size = buf_size,
};
auto buf = MeshBuffer::create(global_buffer_config, per_device_buffer_config, mesh_device_.get());
std::vector<uint32_t> src_vec(buf_size / sizeof(uint32_t), 0);
std::iota(src_vec.begin(), src_vec.end(), 0);

for (std::size_t logical_x = 0; logical_x < buf->device()->num_cols(); logical_x++) {
for (std::size_t logical_y = 0; logical_y < buf->device()->num_rows(); logical_y++) {
WriteShard(mesh_device_->mesh_command_queue(), buf, src_vec, Coordinate(logical_y, logical_x));
}
}

for (std::size_t logical_x = 0; logical_x < buf->device()->num_cols(); logical_x++) {
for (std::size_t logical_y = 0; logical_y < buf->device()->num_rows(); logical_y++) {
std::vector<uint32_t> dst_vec = {};
ReadShard(mesh_device_->mesh_command_queue(), dst_vec, buf, Coordinate(logical_y, logical_x));
for (int j = 0; j < dst_vec.size(); j++) {
EXPECT_EQ(dst_vec[j], j);
}
}
}
}
}
}
}

} // namespace
} // namespace tt::tt_metal::distributed::test
22 changes: 22 additions & 0 deletions tt_metal/distributed/distributed.hpp
@@ -24,6 +24,28 @@ void AddProgramToMeshWorkload(MeshWorkload& mesh_workload, Program& program, con

void EnqueueMeshWorkload(MeshCommandQueue& mesh_cq, MeshWorkload& mesh_workload, bool blocking);

template <typename DType>
void WriteShard(
MeshCommandQueue& mesh_cq,
std::shared_ptr<MeshBuffer>& mesh_buffer,
std::vector<DType>& src,
const Coordinate& coord,
bool blocking = false) {
mesh_cq.enqueue_write_shard(mesh_buffer, src.data(), coord, blocking);
}

template <typename DType>
void ReadShard(
MeshCommandQueue& mesh_cq,
std::vector<DType>& dst,
std::shared_ptr<MeshBuffer>& mesh_buffer,
const Coordinate& coord,
bool blocking = true) {
auto shard = mesh_buffer->get_device_buffer(coord);
dst.resize(shard->page_size() * shard->num_pages() / sizeof(DType));
mesh_cq.enqueue_read_shard(dst.data(), mesh_buffer, coord, blocking);
}

void Finish(MeshCommandQueue& mesh_cq);

} // namespace distributed
79 changes: 46 additions & 33 deletions tt_metal/distributed/mesh_buffer.cpp
@@ -59,38 +59,25 @@ std::shared_ptr<MeshBuffer> MeshBuffer::create(
}},
mesh_buffer_config);

// Rely on the single device allocator to provide the address for the entire mesh buffer.
// TODO: use mesh allocator, when available.
std::shared_ptr<Buffer> backing_buffer = Buffer::create(
mesh_device->get_device(0, 0),
/*address=*/address.value_or(0),
device_local_size,
device_local_config.page_size,
device_local_config.buffer_type,
device_local_config.buffer_layout,
device_local_config.shard_parameters,
device_local_config.bottom_up);
std::shared_ptr<MeshBuffer> mesh_buffer;
if (!address.has_value()) {
*address = tt::tt_metal::detail::AllocateBuffer(backing_buffer.get());
auto* backing_buffer_ptr = backing_buffer.get();
// Rely on the single device allocator to provide the address for the entire mesh buffer.
// The address provided to the backing buffer is used as the address for the MeshBuffer object.
// TODO: use mesh allocator, when available.
std::shared_ptr<Buffer> backing_buffer = Buffer::create(
mesh_device->get_device(0, 0),
device_local_size,
device_local_config.page_size,
device_local_config.buffer_type,
device_local_config.buffer_layout,
device_local_config.shard_parameters,
device_local_config.bottom_up);

mesh_buffer = std::shared_ptr<MeshBuffer>(
new MeshBuffer(
mesh_buffer_config,
device_local_config,
*address,
device_local_size,
mesh_device,
std::move(backing_buffer)),
[backing_buffer_ptr](MeshBuffer*) { tt::tt_metal::detail::DeallocateBuffer(backing_buffer_ptr); });
new MeshBuffer(mesh_buffer_config, device_local_config, device_local_size, mesh_device, backing_buffer));
} else {
mesh_buffer = std::shared_ptr<MeshBuffer>(new MeshBuffer(
mesh_buffer_config,
device_local_config,
*address,
device_local_size,
mesh_device,
std::move(backing_buffer)));
mesh_buffer = std::shared_ptr<MeshBuffer>(
new MeshBuffer(mesh_buffer_config, device_local_config, address.value(), device_local_size, mesh_device));
}

mesh_buffer->allocate();
@@ -99,6 +86,13 @@
}

void MeshBuffer::allocate() {
if (backing_buffer_) {
TT_FATAL(
!address_, "The address for a MeshBuffer should not explicitly be initialized when it is being allocated");
address_ = backing_buffer_->address();
} else {
        TT_FATAL(address_, "A MeshBuffer should be provided a valid address if it's not being allocated");
}
buffers_ = std::vector<std::vector<std::shared_ptr<Buffer>>>(
mesh_device_->num_rows(), std::vector<std::shared_ptr<Buffer>>(mesh_device_->num_cols()));

Expand All @@ -117,11 +111,7 @@ void MeshBuffer::allocate() {

for (int row = 0; row < mesh_device_->num_rows(); row++) {
for (int col = 0; col < mesh_device_->num_cols(); col++) {
if (row == 0 and col == 0) {
buffers_[row][col] = backing_buffer_;
} else {
buffers_[row][col] = allocate_device_buffer_at_address(Coordinate{row, col});
}
buffers_[row][col] = allocate_device_buffer_at_address(Coordinate{row, col});
}
}
}
@@ -156,4 +146,27 @@ const ShardedBufferConfig& MeshBuffer::global_shard_spec() const {
return std::get<ShardedBufferConfig>(config_);
}

uint32_t MeshBuffer::datum_size_bytes() const {
// Limitation for now.
TT_FATAL(
this->global_layout() == MeshBufferLayout::SHARDED,
"Can only query datum size for buffers sharded across the Mesh");
return this->global_shard_spec().compute_datum_size_bytes();
}

Shape2D MeshBuffer::physical_shard_shape() const {
TT_FATAL(
this->global_layout() == MeshBufferLayout::SHARDED,
"Can only query physical shard shape for buffers sharded across the Mesh");
auto sharded_config = std::get<ShardedBufferConfig>(config_);
Shape2D physical_shard_shape = sharded_config.shard_shape;
if (physical_shard_shape.height() == 0) {
physical_shard_shape = {sharded_config.global_buffer_shape.height(), physical_shard_shape.width()};
}
if (physical_shard_shape.width() == 0) {
physical_shard_shape = {physical_shard_shape.height(), sharded_config.global_buffer_shape.width()};
}
return physical_shard_shape;
}

} // namespace tt::tt_metal::distributed