
#0: Add WriteShard and ReadShard MeshBuffer APIs and resolve MeshBuffer dealloc issues #16960

Open · wants to merge 1 commit into main from asaigal/mesh_buffer_io
Conversation

@tt-asaigal (Contributor) commented on Jan 22, 2025

Ticket

No ticket.

Problem description

  • MeshBuffer deallocation on destruction is currently a no-op on main (a deallocate is issued to address 0 when destroying any MeshBuffer)
  • Basic IO functionality for MeshBuffer needs to be added

What's changed

  • Resolve deallocation issue: allocate the single-device backing buffer when it is created, and store it in the MeshBuffer object. The backing buffer is deleted, and automatically deallocated at the correct address, when the MeshBuffer is destroyed.
  • Add WriteShard and ReadShard APIs. Multi-device sharding and replication builds on top of this (to be added in a separate PR).
  • Add tests.

Checklist

  • Post commit CI passes
  • Blackhole Post commit (if applicable)
  • Model regression CI testing passes (if applicable)
  • Device performance regression CI testing passes (if applicable)
  • (For models and ops writers) Full new models tests passes
  • New/Existing tests provide coverage for changes

tt_metal/distributed/mesh_command_queue.cpp
tt_metal/distributed/mesh_buffer.hpp
tt_metal/distributed/mesh_buffer.cpp

namespace tt::tt_metal::distributed::test {
namespace {

using MeshBufferTest = T3000MultiDeviceFixture;

class DeviceLocalShardedBufferTestConfig {
Collaborator:

what's the plan with DeviceLocalBufferConfig?

Contributor Author:

This is a struct used for testing only. It generates device-local sharding parameters based on a user-provided config, making it easier for us to test single-device sharding functionality.
DeviceLocalBufferConfig is a core user-facing struct exposed in the mesh_buffer.hpp header.

Comment on lines +27 to +48
template <typename DType>
void WriteShard(
MeshCommandQueue& mesh_cq,
std::shared_ptr<MeshBuffer>& mesh_buffer,
std::vector<DType>& src,
const Coordinate& coord,
bool blocking = false) {
mesh_cq.enqueue_write_shard(mesh_buffer, src.data(), coord, blocking);
}

template <typename DType>
void ReadShard(
MeshCommandQueue& mesh_cq,
std::vector<DType>& dst,
std::shared_ptr<MeshBuffer>& mesh_buffer,
const Coordinate& coord,
bool blocking = true) {
auto shard = mesh_buffer->get_device_buffer(coord);
dst.resize(shard->page_size() * shard->num_pages() / sizeof(DType));
mesh_cq.enqueue_read_shard(dst.data(), mesh_buffer, coord, blocking);
}

Collaborator:

void ReadShard(Buffer& buffer, std::vector<DType>& host_buffer, const uint32_t& core_id);

We should converge on the interface for eventual unification with single-device

Contributor Author:

I agree, my preference is to modify the tt-metal variants, since:

  1. They currently don't work with Fast Dispatch, so they'll need a command queue in the argument list
  2. They pass in a 1D core_id which doesn't make sense for a 2D grid.

I think this should be a separate effort, since it requires dedicated dispatch changes to properly support.

@tt-asaigal tt-asaigal force-pushed the asaigal/mesh_buffer_io branch 2 times, most recently from f9508de to 42f0a14 Compare January 22, 2025 03:49
@omilyutin-tt (Contributor) left a comment:

The APIs look nice - thank you for following up!

std::vector<std::vector<std::shared_ptr<Buffer>>> buffers_;
// Buffer owned by the MeshBuffer. Responsible for interfacing with the
// single device allocator.
Contributor:

Add "not set if the MeshBuffer is externally owned"

Contributor Author:

done


void MeshCommandQueue::enqueue_read_shard(
void* host_data, std::shared_ptr<MeshBuffer>& mesh_buffer, const Coordinate& coord, bool blocking) {
TT_FATAL(blocking, "Only blocking reads are currently supported from MeshBuffer shards.");
Contributor:

blocking here doesn't do anything - is this intentional?

Contributor Author:

Yes, this is intentional for now. Once we support non-blocking reads, this parameter will actually do something. For now, we have the API in place and just assert if the user tries to do a non-blocking read.

}
EXPECT_FALSE(allocator->allocated_buffers.contains(buffer.get()));
Contributor:

Does the old test pass as it is? Just wondering. Btw, I tried testing with Buffer::is_allocated(), but tt::tt_metal::detail::DeallocateBuffer(backing_buffer_ptr); didn't reset the variable that tracks the allocation state. I think you'll have better luck with the new revision. Can you try:

std::shared_ptr<Buffer> buffer;
{
  auto replicated_buffer = MeshBuffer::create(buffer_config, device_local_config, mesh_device_.get());
  buffer = replicated_buffer->get_device_buffer(Coordinate{0, 0});
  EXPECT_TRUE(buffer->is_allocated());
}
EXPECT_FALSE(buffer->is_allocated());

Contributor Author:

The original test does not pass as is. The reason is that the previous implementation of MeshBuffer stored the backing buffer as the shard for device (0, 0). So doing the following gives you an accurate representation of the state of the backing buffer:

        buffer = replicated_buffer->get_device_buffer(Coordinate{0, 0});
        allocator = buffer->allocator();
        EXPECT_TRUE(allocator->allocated_buffers.contains(buffer.get()));

This is not the case anymore, since the backing buffer is not accessible to the user in any form.

The test you've described will not pass either, since the individual shards returned by get_device_buffer and stored in a temporary variable do not get deallocated when a MeshBuffer goes out of scope.

@@ -112,5 +157,95 @@ TEST_F(MeshBufferTest, GetDeviceBuffer) {
EXPECT_NO_THROW(replicated_buffer->get_device_buffer(Coordinate{1, 3}));
}

TEST_F(MeshBufferTest, TestInterleavedShardsReadWrite) {
Contributor:

MeshBufferTest already has "Test" in it, drop Test prefix here and below

Contributor Author:

done

std::shared_ptr<MeshBuffer> buf =
MeshBuffer::create(global_buffer_config, per_device_buffer_config, mesh_device_.get());

std::vector<uint32_t> src_vec = create_constant_vector_of_bfloat16(num_random_tiles * single_tile_size, i);
Contributor:

Can you create a std::vector<bfloat16> and write it using the WriteShard API? Or rely on a simpler dtype like float or int? The same vector can then be used to compare with the output:

EXPECT_THAT(dst_vect, Pointwise(Eq(), src_vec))

... without a loop.

Contributor Author:

done, no more for-loop comparisons in the tests

uint32_t seed = tt::parse_env("TT_METAL_SEED", 0);
uint32_t single_tile_size = ::tt::tt_metal::detail::TileSize(DataFormat::Float16_b);

for (auto buffer_type : {BufferType::L1, BufferType::DRAM}) {
Contributor:

If we want to do this, this should be a parameterized TEST_P.

std::vector<uint32_t> dst_vec = {};
ReadShard(mesh_device_->mesh_command_queue(), dst_vec, buf, Coordinate(logical_y, logical_x));
for (int j = 0; j < dst_vec.size(); j++) {
EXPECT_EQ(dst_vec[j], j);
Contributor:

Same here for comparing vectors without a loop

Contributor Author:

done

Comment on lines 204 to 213
CoreCoord core_grid_size = mesh_device_->compute_with_storage_grid_size();
std::vector<std::array<uint32_t, 2>> num_pages_per_core_vec = {{1, 1}, {3, 137}, {67, 4}, {7, 11}, {2, 2}};
std::vector<std::array<uint32_t, 2>> page_shapes = {{1, 1024}, {1, 2048}, {1, 4}, {32, 32}, {1, 120}};
std::vector<TensorMemoryLayout> shard_strategies = {
TensorMemoryLayout::HEIGHT_SHARDED, TensorMemoryLayout::WIDTH_SHARDED, TensorMemoryLayout::BLOCK_SHARDED};

for (const auto shard_strategy : shard_strategies) {
for (const auto& num_pages_per_core : num_pages_per_core_vec) {
for (const auto& page_shape : page_shapes) {
DeviceLocalShardedBufferTestConfig test_config(
Contributor:

Same here for using a parameterized test suite.


namespace tt::tt_metal::distributed::test {
namespace {

using MeshBufferTest = T3000MultiDeviceFixture;

class DeviceLocalShardedBufferTestConfig {
public:
Contributor:

A class with just public data members and methods is a struct; let's do that and remove the constructor. You can use aggregate initialization with this syntax:

DeviceLocalShardedBufferTestConfig config{
  .num_pages_per_core = ...,
  .num_cores = ...
  // etc
}

Contributor Author:

done


namespace tt::tt_metal::distributed::test {
namespace {

using MeshBufferTest = T3000MultiDeviceFixture;

class DeviceLocalShardedBufferTestConfig {
public:
std::array<uint32_t, 2> num_pages_per_core;
Contributor:

Shape2D?

Contributor Author:

done

@tt-asaigal tt-asaigal force-pushed the asaigal/mesh_buffer_io branch from 42f0a14 to 17ceb33 Compare January 22, 2025 18:02
…er dealloc issues

  - Add tests for reading and writing shards with Interleaved and Sharded configs
  - Add test for deallocation, verifying addresses
@tt-asaigal tt-asaigal force-pushed the asaigal/mesh_buffer_io branch from 17ceb33 to befeaca Compare January 22, 2025 18:44
@abhullar-tt (Contributor) left a comment:

overall lgtm once other feedback is addressed
