Skip to content

Commit

Permalink
#15141: Support for additional system queries on MeshDevice
Browse files Browse the repository at this point in the history
This adds methods on the MeshDevice for the following:

MeshDevice::get_memory_allocation_statistics
MeshDevice::get_num_dram_channels
  • Loading branch information
cfjchu committed Dec 6, 2024
1 parent 5ba94f0 commit 1373d64
Show file tree
Hide file tree
Showing 3 changed files with 91 additions and 16 deletions.
16 changes: 16 additions & 0 deletions tests/ttnn/distributed/test_distributed.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,4 +26,20 @@ TEST_F(DistributedTest, TestSystemMeshTearDownWithoutClose) {
EXPECT_GT(cols, 0);
}

TEST_F(DistributedTest, TestMemoryAllocationStatistics) {
auto mesh = ttnn::distributed::open_mesh_device(
{2, 4}, DEFAULT_L1_SMALL_SIZE, DEFAULT_TRACE_REGION_SIZE, 1, tt::tt_metal::DispatchCoreType::WORKER);
auto stats = mesh->get_memory_allocation_statistics(tt::tt_metal::BufferType::DRAM);
for (auto* device : mesh->get_devices()) {
auto device_stats = device->get_memory_allocation_statistics(tt::tt_metal::BufferType::DRAM);
EXPECT_EQ(stats.total_allocatable_size_bytes, device_stats.total_allocatable_size_bytes);
}
}

TEST_F(DistributedTest, TestNumDramChannels) {
auto mesh = ttnn::distributed::open_mesh_device(
{2, 4}, DEFAULT_L1_SMALL_SIZE, DEFAULT_TRACE_REGION_SIZE, 1, tt::tt_metal::DispatchCoreType::WORKER);
EXPECT_EQ(mesh->num_dram_channels(), 96); // 8 devices * 12 channels
}

} // namespace ttnn::distributed::test
23 changes: 18 additions & 5 deletions tt_metal/distributed/mesh_device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,10 @@ static MeshDeviceID generate_unique_mesh_id() {
return next_id++;
}

Device* MeshDevice::reference_device() const {
return this->devices.at(0);
}

MeshDevice::MeshDevice(const MeshShape& mesh_device_shape, MeshType type, std::weak_ptr<MeshDevice> parent_mesh) :
mesh_device_shape(mesh_device_shape),
type(type),
Expand Down Expand Up @@ -403,13 +407,11 @@ const DeviceIds MeshDevice::get_device_ids() const {

size_t MeshDevice::num_devices() const { return this->devices.size(); }

CoreCoord MeshDevice::compute_with_storage_grid_size() const {
return get_device_index(0)->compute_with_storage_grid_size();
}
CoreCoord MeshDevice::compute_with_storage_grid_size() const { return this->reference_device()->compute_with_storage_grid_size(); }

CoreCoord MeshDevice::dram_grid_size() const { return get_device_index(0)->dram_grid_size(); }
CoreCoord MeshDevice::dram_grid_size() const { return this->reference_device()->dram_grid_size(); }

tt::ARCH MeshDevice::arch() const { return get_device_index(0)->arch(); }
tt::ARCH MeshDevice::arch() const { return this->reference_device()->arch(); }

size_t MeshDevice::num_rows() const { return this->mesh_device_shape.first; }

Expand Down Expand Up @@ -517,4 +519,15 @@ MeshSubDeviceManagerId::MeshSubDeviceManagerId(const MeshDevice& mesh_device) {
this->sub_device_manager_ids.resize(mesh_device.num_devices());
}

int MeshDevice::num_dram_channels() const {
return this->reference_device()->num_dram_channels() * this->num_devices();
}

allocator::Statistics MeshDevice::get_memory_allocation_statistics(const BufferType &buffer_type, SubDeviceId sub_device_id) const {
// With current implementation, we assume that all devices have the same memory allocation statistics.
// This will be made more explicit in the future to have lock-step allocation across devices.
// Right now, we just return the statistics of the first device.
return this->reference_device()->get_memory_allocation_statistics(buffer_type, sub_device_id);
}

} // namespace tt::tt_metal::distributed
68 changes: 57 additions & 11 deletions tt_metal/distributed/mesh_device.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,9 @@ class MeshDevice : public std::enable_shared_from_this<MeshDevice> {
const DispatchCoreConfig& dispatch_core_config,
const MeshDeviceConfig& config);

// This is a reference device used to query properties that are the same for all devices in the mesh.
Device* reference_device() const;

public:
MeshDevice(const MeshShape& mesh_device_shape, MeshType type, std::weak_ptr<MeshDevice> parent_mesh = {});
~MeshDevice();
Expand All @@ -111,15 +114,6 @@ class MeshDevice : public std::enable_shared_from_this<MeshDevice> {
size_t num_cols() const;
MeshShape shape() const;

CoreCoord compute_with_storage_grid_size() const;

CoreCoord dram_grid_size() const;

tt::ARCH arch() const;
void enable_async(bool enable);
void enable_program_cache();
void disable_and_clear_program_cache();

void close_devices();
std::shared_ptr<const MeshDeviceView> get_view() const;
std::shared_ptr<MeshDeviceView> get_view();
Expand All @@ -138,8 +132,6 @@ class MeshDevice : public std::enable_shared_from_this<MeshDevice> {
std::vector<std::shared_ptr<MeshDevice>> create_submeshes(
const MeshShape& submesh_shape, MeshType type = MeshType::RowMajor);

size_t num_program_cache_entries() const;

MeshSubDeviceManagerId create_sub_device_manager(
tt::stl::Span<const SubDevice> sub_devices, DeviceAddr local_l1_size);
void load_sub_device_manager(MeshSubDeviceManagerId mesh_sub_device_manager_id);
Expand All @@ -152,6 +144,20 @@ class MeshDevice : public std::enable_shared_from_this<MeshDevice> {
size_t trace_region_size = DEFAULT_TRACE_REGION_SIZE,
size_t num_command_queues = 1,
const DispatchCoreConfig& dispatch_core_config = DispatchCoreConfig{});

// Device API Queries (API contract with Device class to be supported in future)
CoreCoord compute_with_storage_grid_size() const;
CoreCoord dram_grid_size() const;

tt::ARCH arch() const;
void enable_async(bool enable);
void enable_program_cache();
void disable_and_clear_program_cache();

size_t num_program_cache_entries() const;

int num_dram_channels() const;
allocator::Statistics get_memory_allocation_statistics(const BufferType &buffer_type, SubDeviceId sub_device_id = SubDeviceId{0}) const;
};

std::ostream& operator<<(std::ostream& os, const MeshDevice& mesh_device);
Expand All @@ -164,4 +170,44 @@ struct MeshSubDeviceManagerId {
std::vector<SubDeviceManagerId> sub_device_manager_ids;
};

namespace detail {
template <typename T>
concept HasMethodsForArchitectureQueries = requires(T& device) {
{ device.compute_with_storage_grid_size() } -> std::same_as<CoreCoord>;
{ device.dram_grid_size() } -> std::same_as<CoreCoord>;
{ device.arch() } -> std::same_as<tt::ARCH>;
{ device.num_dram_channels() } -> std::same_as<int>;
};

template <typename T>
concept HasMethodsForAllocator = requires(T& device) {
{ device.get_memory_allocation_statistics(std::declval<BufferType>(), std::declval<SubDeviceId>()) } -> std::same_as<allocator::Statistics>;
};

template <typename T>
concept HasMethodsForProgramCache = requires(T& device) {
{ device.num_program_cache_entries() } -> std::same_as<size_t>;
{ device.enable_program_cache() } -> std::same_as<void>;
{ device.disable_and_clear_program_cache() } -> std::same_as<void>;
};

template <typename T>
concept HasMethodsForAsync = requires(T& device) {
{ device.enable_async(std::declval<bool>()) } -> std::same_as<void>;
};

template <typename T>
concept DeviceInterfaceContract =
HasMethodsForArchitectureQueries<T> &&
HasMethodsForAllocator<T> &&
HasMethodsForProgramCache<T> &&
HasMethodsForAsync<T>;

} // namespace detail

// For now static_asserts are used to ensure that the concepts are satisfied.
// This is a temporary compile-time check to make sure that Device/MeshDevice don't deviate from the expected interface.
static_assert(detail::DeviceInterfaceContract<Device>, "Device must satisfy the DeviceInterfaceContract concept.");
static_assert(detail::DeviceInterfaceContract<MeshDevice>, "MeshDevice must satisfy the DeviceInterfaceContract concept.");

} // namespace tt::tt_metal::distributed

0 comments on commit 1373d64

Please sign in to comment.