#0: Add WriteShard and ReadShard MeshBuffer APIs and resolve MeshBuffer dealloc issues #16960
base: main
namespace tt::tt_metal::distributed::test {
namespace {

using MeshBufferTest = T3000MultiDeviceFixture;

class DeviceLocalShardedBufferTestConfig {
What's the plan with DeviceLocalBufferConfig?
This is a struct used for testing only. It generates device-local sharding parameters from a user-provided config, making it easier for us to test single-device sharding functionality.
DeviceLocalBufferConfig is a core user-facing struct exposed in the mesh_buffer.hpp header.
template <typename DType>
void WriteShard(
    MeshCommandQueue& mesh_cq,
    std::shared_ptr<MeshBuffer>& mesh_buffer,
    std::vector<DType>& src,
    const Coordinate& coord,
    bool blocking = false) {
    mesh_cq.enqueue_write_shard(mesh_buffer, src.data(), coord, blocking);
}

template <typename DType>
void ReadShard(
    MeshCommandQueue& mesh_cq,
    std::vector<DType>& dst,
    std::shared_ptr<MeshBuffer>& mesh_buffer,
    const Coordinate& coord,
    bool blocking = true) {
    auto shard = mesh_buffer->get_device_buffer(coord);
    dst.resize(shard->page_size() * shard->num_pages() / sizeof(DType));
    mesh_cq.enqueue_read_shard(dst.data(), mesh_buffer, coord, blocking);
}
void ReadShard(Buffer& buffer, std::vector<DType>& host_buffer, const uint32_t& core_id);
We should converge on the interface for eventual unification with single-device
I agree, my preference is to modify the tt-metal variants, since:
- They currently don't work with Fast Dispatch, so they'll need a command queue in the argument list
- They pass in a 1D core_id which doesn't make sense for a 2D grid.
I think this should be a separate effort, since it requires dedicated dispatch changes to properly support.
f9508de to 42f0a14
Passing Post Commit: https://github.com/tenstorrent/tt-metal/actions/runs/12900817938
The APIs look nice - thank you for following up!
tt_metal/distributed/mesh_buffer.hpp
Outdated
std::vector<std::vector<std::shared_ptr<Buffer>>> buffers_;
// Buffer owned by the MeshBuffer. Responsible for interfacing with the
// single device allocator.
Add "not set if the MeshBuffer is externally owned"
done
void MeshCommandQueue::enqueue_read_shard(
    void* host_data, std::shared_ptr<MeshBuffer>& mesh_buffer, const Coordinate& coord, bool blocking) {
    TT_FATAL(blocking, "Only blocking reads are currently supported from MeshBuffer shards.");
blocking here doesn't do anything - is this intentional?
Yes, this is intentional for now. Once we support non-blocking reads, this parameter will take effect. For now, the API is in place and we assert if the user attempts a non-blocking read.
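For readers following along, the "accept the flag now, assert until it is supported" pattern described above can be sketched roughly as follows. This is a minimal illustration, not the tt-metal implementation: FakeShard, read_shard and read_shard_as are hypothetical names, and a thrown std::invalid_argument stands in for TT_FATAL.

```cpp
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <stdexcept>
#include <vector>

// Hypothetical stand-in for a device shard; not the tt-metal Buffer class.
struct FakeShard {
    std::vector<std::uint8_t> bytes;
};

// The `blocking` flag is part of the signature so that callers do not have
// to change once non-blocking reads exist, but only `true` is honored here.
// The exception is an assumption standing in for TT_FATAL.
inline void read_shard(void* host_data, const FakeShard& shard, bool blocking) {
    if (!blocking) {
        throw std::invalid_argument("Only blocking reads are currently supported.");
    }
    if (!shard.bytes.empty()) {
        std::memcpy(host_data, shard.bytes.data(), shard.bytes.size());
    }
}

// Sizes the destination from the shard's byte count before reading, in the
// same spirit as the ReadShard test helper quoted earlier in this review.
template <typename DType>
std::vector<DType> read_shard_as(const FakeShard& shard, bool blocking = true) {
    if (shard.bytes.size() % sizeof(DType) != 0) {
        throw std::invalid_argument("Shard size is not a multiple of the element size.");
    }
    std::vector<DType> dst(shard.bytes.size() / sizeof(DType));
    read_shard(dst.data(), shard, blocking);
    return dst;
}
```

Calling read_shard_as<std::uint32_t>(shard) on an 8-byte shard yields two elements, while passing blocking = false throws instead of silently behaving like a blocking read.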
}
EXPECT_FALSE(allocator->allocated_buffers.contains(buffer.get()));
Does the old test pass as it is? Just wondering. By the way, I tried testing with Buffer::is_allocated(), but tt::tt_metal::detail::DeallocateBuffer(backing_buffer_ptr); didn't reset the variable that tracks the allocation state. I think you'll have better luck with the new revision. Can you try:
std::shared_ptr<Buffer> buffer;
{
    auto replicated_buffer = MeshBuffer::create(buffer_config, device_local_config, mesh_device_.get());
    buffer = replicated_buffer->get_device_buffer(Coordinate{0, 0});
    EXPECT_TRUE(buffer->is_allocated());
}
EXPECT_FALSE(buffer->is_allocated());
The original test does not pass as is. The reason is that the previous implementation of MeshBuffer stored the backing buffer as the shard for device (0, 0). So doing the following gave you an accurate representation of the state of the backing buffer:
buffer = replicated_buffer->get_device_buffer(Coordinate{0, 0});
allocator = buffer->allocator();
EXPECT_TRUE(allocator->allocated_buffers.contains(buffer.get()));
This is no longer the case, since the backing buffer is not accessible to the user in any form.
The test you've described will not pass either, since the individual shards returned by get_device_buffer and stored in a temporary variable do not get deallocated when a MeshBuffer goes out of scope.
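The ownership behavior described above can be illustrated with a small standalone sketch. It assumes a simplified model in which the mesh-level object privately owns one backing allocation and hands out shards as shared_ptr views. ToyAllocator, ToyBuffer, ToyShard and ToyMeshBuffer are hypothetical names and do not mirror the real tt-metal classes.

```cpp
#include <cstddef>
#include <memory>
#include <unordered_set>
#include <vector>

// Hypothetical allocator that only tracks which addresses are live.
struct ToyAllocator {
    std::unordered_set<std::size_t> live;
};

// Hypothetical RAII buffer: registers its address on construction and
// releases it on destruction.
class ToyBuffer {
public:
    ToyBuffer(ToyAllocator& alloc, std::size_t address) : alloc_(alloc), address_(address) {
        alloc_.live.insert(address_);
    }
    ~ToyBuffer() { alloc_.live.erase(address_); }
    ToyBuffer(const ToyBuffer&) = delete;
    ToyBuffer& operator=(const ToyBuffer&) = delete;

private:
    ToyAllocator& alloc_;
    std::size_t address_;
};

// Hypothetical shard: a view that records an address but owns no allocation.
struct ToyShard {
    std::size_t address;
};

// Owns the backing buffer privately; shards are shared with callers.
class ToyMeshBuffer {
public:
    ToyMeshBuffer(ToyAllocator& alloc, std::size_t address, std::size_t num_shards)
        : backing_(std::make_unique<ToyBuffer>(alloc, address)) {
        for (std::size_t i = 0; i < num_shards; ++i) {
            shards_.push_back(std::make_shared<ToyShard>(ToyShard{address}));
        }
    }
    std::shared_ptr<ToyShard> get_device_buffer(std::size_t index) const { return shards_.at(index); }

private:
    std::unique_ptr<ToyBuffer> backing_;
    std::vector<std::shared_ptr<ToyShard>> shards_;
};
```

In this model, a shard copied out of the mesh object outlives it because the caller still holds a reference, while the backing address is released as soon as the mesh object is destroyed. That is why the allocator state, rather than a held shard, is the thing to check in a deallocation test.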
@@ -112,5 +157,95 @@ TEST_F(MeshBufferTest, GetDeviceBuffer) {
    EXPECT_NO_THROW(replicated_buffer->get_device_buffer(Coordinate{1, 3}));
}

TEST_F(MeshBufferTest, TestInterleavedShardsReadWrite) {
MeshBufferTest already has "Test" in it, drop the Test prefix here and below
done
std::shared_ptr<MeshBuffer> buf =
    MeshBuffer::create(global_buffer_config, per_device_buffer_config, mesh_device_.get());

std::vector<uint32_t> src_vec = create_constant_vector_of_bfloat16(num_random_tiles * single_tile_size, i);
Can you create a std::vector<bfloat16> and write it using the WriteShard API? Or rely on a simpler dtype like float or int? The same vector can then be used to compare with the output:
EXPECT_THAT(dst_vec, Pointwise(Eq(), src_vec))
... without a loop.
Done, the tests no longer do for-loop comparisons.
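As a rough illustration of what a whole-vector comparison buys over a per-element loop, the sketch below uses only the standard library to report the first index at which two vectors differ. The helper name first_mismatch is hypothetical; the tests in this PR rely on the GoogleTest matcher suggested above rather than anything like this.

```cpp
#include <algorithm>
#include <cstddef>
#include <iterator>
#include <optional>
#include <vector>

// Returns the first index at which the two vectors differ, or std::nullopt
// if they are equal. A length difference counts as a mismatch at the index
// where the shorter vector ends.
template <typename T>
std::optional<std::size_t> first_mismatch(const std::vector<T>& expected, const std::vector<T>& actual) {
    auto its = std::mismatch(expected.begin(), expected.end(), actual.begin(), actual.end());
    if (its.first == expected.end() && its.second == actual.end()) {
        return std::nullopt;
    }
    return static_cast<std::size_t>(std::distance(expected.begin(), its.first));
}
```

A single check on the result replaces the loop of per-element expectations and still pinpoints where a readback diverged from the written data.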
uint32_t seed = tt::parse_env("TT_METAL_SEED", 0);
uint32_t single_tile_size = ::tt::tt_metal::detail::TileSize(DataFormat::Float16_b);

for (auto buffer_type : {BufferType::L1, BufferType::DRAM}) {
If we want to do this, this should be a parameterized TEST_P.
std::vector<uint32_t> dst_vec = {};
ReadShard(mesh_device_->mesh_command_queue(), dst_vec, buf, Coordinate(logical_y, logical_x));
for (int j = 0; j < dst_vec.size(); j++) {
    EXPECT_EQ(dst_vec[j], j);
Same here for comparing vectors without a loop
done
CoreCoord core_grid_size = mesh_device_->compute_with_storage_grid_size();
std::vector<std::array<uint32_t, 2>> num_pages_per_core_vec = {{1, 1}, {3, 137}, {67, 4}, {7, 11}, {2, 2}};
std::vector<std::array<uint32_t, 2>> page_shapes = {{1, 1024}, {1, 2048}, {1, 4}, {32, 32}, {1, 120}};
std::vector<TensorMemoryLayout> shard_strategies = {
    TensorMemoryLayout::HEIGHT_SHARDED, TensorMemoryLayout::WIDTH_SHARDED, TensorMemoryLayout::BLOCK_SHARDED};

for (const auto shard_strategy : shard_strategies) {
    for (const auto& num_pages_per_core : num_pages_per_core_vec) {
        for (const auto& page_shape : page_shapes) {
            DeviceLocalShardedBufferTestConfig test_config(
Same here for using a parameterized test suite.
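One way to picture the suggestion: the three nested loops enumerate a Cartesian product, and a parameterized suite needs that product as a flat list of cases. The sketch below builds such a list with the standard library only. Layout, ShardCase and make_cases are hypothetical stand-ins (Layout substitutes for TensorMemoryLayout), and the GoogleTest wiring itself is left out.

```cpp
#include <array>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for TensorMemoryLayout.
enum class Layout { HeightSharded, WidthSharded, BlockSharded };

// One test case: the tuple that the nested loops would otherwise produce.
struct ShardCase {
    Layout layout;
    std::array<std::uint32_t, 2> num_pages_per_core;
    std::array<std::uint32_t, 2> page_shape;
};

// Flattens the three loops into a single list, preserving loop order
// (layout outermost, page shape innermost).
inline std::vector<ShardCase> make_cases(
    const std::vector<Layout>& layouts,
    const std::vector<std::array<std::uint32_t, 2>>& pages_per_core,
    const std::vector<std::array<std::uint32_t, 2>>& page_shapes) {
    std::vector<ShardCase> cases;
    cases.reserve(layouts.size() * pages_per_core.size() * page_shapes.size());
    for (const auto layout : layouts) {
        for (const auto& pages : pages_per_core) {
            for (const auto& shape : page_shapes) {
                cases.push_back(ShardCase{layout, pages, shape});
            }
        }
    }
    return cases;
}
```

With the values from the test above this gives 3 x 5 x 5 = 75 cases. In GoogleTest, a list like this could be passed to INSTANTIATE_TEST_SUITE_P through ::testing::ValuesIn, so each combination is reported as its own test.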
namespace tt::tt_metal::distributed::test {
namespace {

using MeshBufferTest = T3000MultiDeviceFixture;

class DeviceLocalShardedBufferTestConfig {
public:
A class with just public data members and methods is a struct, let's do that and remove the constructor? You can use aggregate initialization with this syntax:
DeviceLocalShardedBufferTestConfig config{
    .num_pages_per_core = ...,
    .num_cores = ...,
    // etc
};
done
namespace tt::tt_metal::distributed::test {
namespace {

using MeshBufferTest = T3000MultiDeviceFixture;

class DeviceLocalShardedBufferTestConfig {
public:
    std::array<uint32_t, 2> num_pages_per_core;
Shape2D?
done
42f0a14 to 17ceb33
…er dealloc issues
- Add tests for reading and writing shards with Interleaved and Sharded configs
- Add test for deallocation, verifying addresses
17ceb33 to befeaca
overall lgtm once other feedback is addressed
Ticket
No ticket.
Problem description
- MeshBuffer deallocation on destruction is currently a nop on main (it issues a deallocate to address 0 when destroying any MeshBuffer)
- [...] MeshBuffer needs to be added
What's changed
- [...] MeshBuffer object. The backing buffer gets deleted and automatically deallocated at the correct address when it's destroyed
- [...] WriteShard and ReadShard APIs. Multi-device sharding and replication builds on top of this (to be added in a separate PR).
Checklist