-
Notifications
You must be signed in to change notification settings - Fork 96
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
#0: Add WriteShard and ReadShard MeshBuffer APIs and resolve MeshBuffer dealloc issues #16960
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -9,13 +9,44 @@ | |
#include <tt-metalium/mesh_device_view.hpp> | ||
|
||
#include "tests/tt_metal/tt_metal/common/multi_device_fixture.hpp" | ||
#include "tt_metal/distributed/mesh_buffer.hpp" | ||
#include "tt_metal/distributed/distributed.hpp" | ||
|
||
namespace tt::tt_metal::distributed::test { | ||
namespace { | ||
|
||
using MeshBufferTest = T3000MultiDeviceFixture; | ||
|
||
struct DeviceLocalShardedBufferTestConfig { | ||
Shape2D num_pages_per_core; | ||
Shape2D num_cores; | ||
Shape2D page_shape; | ||
uint32_t element_size = 1; | ||
TensorMemoryLayout mem_config = TensorMemoryLayout::HEIGHT_SHARDED; | ||
ShardOrientation shard_orientation = ShardOrientation::ROW_MAJOR; | ||
|
||
Shape2D tensor2d_shape() { | ||
return {num_pages_per_core.height() * num_cores.height(), num_pages_per_core.width() * num_cores.width()}; | ||
} | ||
|
||
uint32_t num_pages() { return tensor2d_shape().height() * tensor2d_shape().width(); } | ||
|
||
std::array<uint32_t, 2> shard_shape() { | ||
return {num_pages_per_core.height() * page_shape.height(), num_pages_per_core.width() * page_shape.width()}; | ||
} | ||
|
||
CoreRangeSet shard_grid() { | ||
return CoreRangeSet(std::set<CoreRange>( | ||
{CoreRange(CoreCoord(0, 0), CoreCoord(this->num_cores.height() - 1, this->num_cores.width() - 1))})); | ||
} | ||
|
||
uint32_t page_size() { return page_shape.height() * page_shape.width() * element_size; } | ||
|
||
ShardSpecBuffer shard_parameters() { | ||
return ShardSpecBuffer( | ||
this->shard_grid(), this->shard_shape(), this->shard_orientation, this->page_shape, this->tensor2d_shape()); | ||
} | ||
}; | ||
|
||
TEST_F(MeshBufferTest, ConfigValidation) { | ||
const DeviceLocalBufferConfig device_local_config{ | ||
.page_size = 1024, | ||
|
@@ -78,22 +109,24 @@ TEST_F(MeshBufferTest, ReplicatedBufferInitialization) { | |
} | ||
|
||
TEST_F(MeshBufferTest, Deallocation) {
    // A MeshBuffer must release its device allocation when the host handle
    // goes out of scope: allocate inside an inner scope, record the address,
    // then allocate again with an identical config and expect the allocator
    // to hand back the same address.
    const DeviceLocalBufferConfig device_local_config{
        .page_size = 1024,
        .buffer_type = BufferType::DRAM,
        .buffer_layout = TensorMemoryLayout::INTERLEAVED,
        .bottom_up = false};
    const ReplicatedBufferConfig buffer_config{.size = 16 << 10};

    std::shared_ptr<Buffer> buffer;
    Allocator* allocator = nullptr;
    uint32_t expected_address = 0;
    {
        auto scoped_buffer = MeshBuffer::create(buffer_config, device_local_config, mesh_device_.get());
        buffer = scoped_buffer->get_device_buffer(Coordinate{0, 0});
        allocator = buffer->allocator();
        // While the MeshBuffer is alive, its shard is tracked by the allocator.
        EXPECT_TRUE(allocator->allocated_buffers.contains(buffer.get()));
        expected_address = scoped_buffer->address();
    }
    // The shard handle we kept alive must no longer be tracked once the
    // owning MeshBuffer has been destroyed.
    EXPECT_FALSE(allocator->allocated_buffers.contains(buffer.get()));

    auto replicated_buffer = MeshBuffer::create(buffer_config, device_local_config, mesh_device_.get());
    EXPECT_EQ(replicated_buffer->address(), expected_address);
}
|
||
TEST_F(MeshBufferTest, GetDeviceBuffer) { | ||
|
@@ -112,5 +145,95 @@ TEST_F(MeshBufferTest, GetDeviceBuffer) { | |
EXPECT_NO_THROW(replicated_buffer->get_device_buffer(Coordinate{1, 3})); | ||
} | ||
|
||
TEST_F(MeshBufferTest, InterleavedShardsReadWrite) {
    // Randomized stress test of the WriteShard/ReadShard round-trip on
    // interleaved buffers: each iteration picks a random tile count, writes
    // a ramp pattern to every shard in the mesh, reads each shard back, and
    // checks equality.
    constexpr uint32_t NUM_ITERS = 100;
    uint32_t seed = tt::parse_env("TT_METAL_SEED", 0);
    uint32_t single_tile_size = ::tt::tt_metal::detail::TileSize(DataFormat::UInt32);

    for (auto buffer_type : {BufferType::L1, BufferType::DRAM}) {
        // BUG FIX: the config previously hard-coded BufferType::L1, so the
        // DRAM iteration of this loop silently re-tested L1. Use the loop
        // variable so both memory types are actually exercised.
        DeviceLocalBufferConfig per_device_buffer_config{
            .page_size = single_tile_size,
            .buffer_type = buffer_type,
            .buffer_layout = TensorMemoryLayout::INTERLEAVED,
            .bottom_up = false};

        // Re-seeding per buffer type keeps the two runs comparable.
        std::uniform_int_distribution<int> gen_num_tiles(1, 1024);
        std::mt19937 rng(seed);
        for (int i = 0; i < NUM_ITERS; i++) {
            uint32_t num_random_tiles = gen_num_tiles(rng);
            ReplicatedBufferConfig global_buffer_config = {
                .size = num_random_tiles * single_tile_size,
            };

            std::shared_ptr<MeshBuffer> buf =
                MeshBuffer::create(global_buffer_config, per_device_buffer_config, mesh_device_.get());

            // Ramp pattern offset by the iteration index so consecutive
            // iterations write distinct data over the same address range.
            std::vector<uint32_t> src_vec(num_random_tiles * single_tile_size / sizeof(uint32_t), 0);
            std::iota(src_vec.begin(), src_vec.end(), i);
            for (std::size_t logical_x = 0; logical_x < buf->device()->num_cols(); logical_x++) {
                for (std::size_t logical_y = 0; logical_y < buf->device()->num_rows(); logical_y++) {
                    WriteShard(mesh_device_->mesh_command_queue(), buf, src_vec, Coordinate(logical_y, logical_x));
                }
            }

            // Read every shard back and verify it matches what was written.
            for (std::size_t logical_x = 0; logical_x < buf->device()->num_cols(); logical_x++) {
                for (std::size_t logical_y = 0; logical_y < buf->device()->num_rows(); logical_y++) {
                    std::vector<uint32_t> dst_vec = {};
                    ReadShard(mesh_device_->mesh_command_queue(), dst_vec, buf, Coordinate(logical_y, logical_x));
                    EXPECT_EQ(dst_vec, src_vec);
                }
            }
        }
    }
}
|
||
TEST_F(MeshBufferTest, DeviceLocalMeshBufferSharding) {
    // Sweep sharding strategies, per-core page counts, and page shapes, and
    // verify a write/read round-trip through every shard of the mesh for
    // each combination.
    CoreCoord core_grid_size = mesh_device_->compute_with_storage_grid_size();
    std::vector<std::array<uint32_t, 2>> num_pages_per_core_vec = {{1, 1}, {3, 137}, {67, 4}, {7, 11}, {2, 2}};
    std::vector<std::array<uint32_t, 2>> page_shapes = {{1, 1024}, {1, 2048}, {1, 4}, {32, 32}, {1, 120}};
    std::vector<TensorMemoryLayout> shard_strategies = {
        TensorMemoryLayout::HEIGHT_SHARDED, TensorMemoryLayout::WIDTH_SHARDED, TensorMemoryLayout::BLOCK_SHARDED};

    for (const auto shard_strategy : shard_strategies) {
        for (const auto& num_pages_per_core : num_pages_per_core_vec) {
            for (const auto& page_shape : page_shapes) {
                // Derive per-device buffer parameters from the test config.
                DeviceLocalShardedBufferTestConfig test_config{
                    .num_pages_per_core = num_pages_per_core,
                    .num_cores = {core_grid_size.x, core_grid_size.y},
                    .page_shape = page_shape,
                    .mem_config = shard_strategy};
                DeviceLocalBufferConfig per_device_buffer_config{
                    .page_size = test_config.page_size(),
                    .buffer_type = BufferType::L1,
                    .buffer_layout = test_config.mem_config,
                    .shard_parameters = test_config.shard_parameters(),
                    .bottom_up = false};

                const uint32_t buf_size = test_config.num_pages() * test_config.page_size();
                ReplicatedBufferConfig global_buffer_config{
                    .size = buf_size,
                };
                auto mesh_buf = MeshBuffer::create(global_buffer_config, per_device_buffer_config, mesh_device_.get());

                // Ramp pattern covering the whole buffer.
                std::vector<uint32_t> src_vec(buf_size / sizeof(uint32_t), 0);
                std::iota(src_vec.begin(), src_vec.end(), 0);

                for (std::size_t logical_x = 0; logical_x < mesh_buf->device()->num_cols(); logical_x++) {
                    for (std::size_t logical_y = 0; logical_y < mesh_buf->device()->num_rows(); logical_y++) {
                        WriteShard(
                            mesh_device_->mesh_command_queue(), mesh_buf, src_vec, Coordinate(logical_y, logical_x));
                    }
                }

                // Read every shard back and verify it matches what was written.
                for (std::size_t logical_x = 0; logical_x < mesh_buf->device()->num_cols(); logical_x++) {
                    for (std::size_t logical_y = 0; logical_y < mesh_buf->device()->num_rows(); logical_y++) {
                        std::vector<uint32_t> dst_vec = {};
                        ReadShard(
                            mesh_device_->mesh_command_queue(), dst_vec, mesh_buf, Coordinate(logical_y, logical_x));
                        EXPECT_EQ(dst_vec, src_vec);
                    }
                }
            }
        }
    }
}
|
||
} // namespace | ||
} // namespace tt::tt_metal::distributed::test |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -24,6 +24,28 @@ void AddProgramToMeshWorkload(MeshWorkload& mesh_workload, Program& program, con | |
|
||
void EnqueueMeshWorkload(MeshCommandQueue& mesh_cq, MeshWorkload& mesh_workload, bool blocking); | ||
|
||
template <typename DType> | ||
void WriteShard( | ||
MeshCommandQueue& mesh_cq, | ||
std::shared_ptr<MeshBuffer>& mesh_buffer, | ||
std::vector<DType>& src, | ||
const Coordinate& coord, | ||
bool blocking = false) { | ||
mesh_cq.enqueue_write_shard(mesh_buffer, src.data(), coord, blocking); | ||
} | ||
|
||
template <typename DType> | ||
void ReadShard( | ||
MeshCommandQueue& mesh_cq, | ||
std::vector<DType>& dst, | ||
std::shared_ptr<MeshBuffer>& mesh_buffer, | ||
const Coordinate& coord, | ||
bool blocking = true) { | ||
auto shard = mesh_buffer->get_device_buffer(coord); | ||
dst.resize(shard->page_size() * shard->num_pages() / sizeof(DType)); | ||
mesh_cq.enqueue_read_shard(dst.data(), mesh_buffer, coord, blocking); | ||
} | ||
|
||
Comment on lines
+27
to
+48
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
We should converge on the interface for eventual unification with single-device There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I agree, my preference is to modify the tt-metal variants, since:
I think this should be a separate effort, since it requires dedicated dispatch changes to properly support. |
||
void Finish(MeshCommandQueue& mesh_cq); | ||
|
||
} // namespace distributed | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Does the old test pass as it is? Just wondering. Btw I tried testing with
Buffer::is_allocated()
, buttt::tt_metal::detail::DeallocateBuffer(backing_buffer_ptr);
didn't reset the variable that tracks the allocation state. I think you have a better luck with the new revision. Can you try:There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The original test does not pass as is. The reason is that the previous impl for
MeshBuffer
stored the backing buffer as the shard for device(0, 0)
. So doing the following gives you an accurate representation for the state of the backing bufferThis is not the case anymore, since the backing buffer is not accessible to the user in any form.
The test you've described will not pass either, since the individual shards returned by
get_device_buffer
and stored in a temporary variable do not get deallocated when aMeshBuffer
goes out of scope.