diff --git a/tests/ttnn/distributed/test_distributed.cpp b/tests/ttnn/distributed/test_distributed.cpp
index 94ffe01bd3a..fb5f53988c5 100644
--- a/tests/ttnn/distributed/test_distributed.cpp
+++ b/tests/ttnn/distributed/test_distributed.cpp
@@ -26,4 +26,28 @@ TEST_F(DistributedTest, TestSystemMeshTearDownWithoutClose) {
     EXPECT_GT(cols, 0);
 }
 
+// The total allocatable DRAM size reported by the mesh must match the value reported by each device.
+TEST_F(DistributedTest, TestMemoryAllocationStatistics) {
+    auto mesh = ttnn::distributed::open_mesh_device(
+        {2, 4}, DEFAULT_L1_SMALL_SIZE, DEFAULT_TRACE_REGION_SIZE, 1, tt::tt_metal::DispatchCoreType::WORKER);
+    auto stats = mesh->get_memory_allocation_statistics(tt::tt_metal::BufferType::DRAM);
+    for (auto* device : mesh->get_devices()) {
+        auto device_stats = device->get_memory_allocation_statistics(tt::tt_metal::BufferType::DRAM);
+        EXPECT_EQ(stats.total_allocatable_size_bytes, device_stats.total_allocatable_size_bytes);
+    }
+}
+
+// The mesh-level channel count must equal the sum of the per-device channel counts.
+TEST_F(DistributedTest, TestNumDramChannels) {
+    auto mesh = ttnn::distributed::open_mesh_device(
+        {2, 4}, DEFAULT_L1_SMALL_SIZE, DEFAULT_TRACE_REGION_SIZE, 1, tt::tt_metal::DispatchCoreType::WORKER);
+    // Sum over the devices instead of hard-coding 8 * 12, which only holds for parts with 12 channels each.
+    int expected_num_dram_channels = 0;
+    for (auto* device : mesh->get_devices()) {
+        expected_num_dram_channels += device->num_dram_channels();
+    }
+    EXPECT_GT(expected_num_dram_channels, 0);
+    EXPECT_EQ(mesh->num_dram_channels(), expected_num_dram_channels);
+}
+
 }  // namespace ttnn::distributed::test
diff --git a/tt_metal/distributed/mesh_device.cpp b/tt_metal/distributed/mesh_device.cpp
index 62eaa78186b..4de962e0342 100644
--- a/tt_metal/distributed/mesh_device.cpp
+++ b/tt_metal/distributed/mesh_device.cpp
@@ -269,6 +269,10 @@ static MeshDeviceID generate_unique_mesh_id() {
     return next_id++;
 }
 
+Device* MeshDevice::reference_device() const {
+    return this->devices.at(0);
+}
+
 MeshDevice::MeshDevice(const MeshShape& mesh_device_shape, MeshType type, std::weak_ptr<MeshDevice> parent_mesh) :
     mesh_device_shape(mesh_device_shape),
     type(type),
@@ -403,13 +407,13 @@ const DeviceIds MeshDevice::get_device_ids() const {
 
 size_t MeshDevice::num_devices() const { return this->devices.size(); }
 
 CoreCoord MeshDevice::compute_with_storage_grid_size() const {
-    return get_device_index(0)->compute_with_storage_grid_size();
+    return this->reference_device()->compute_with_storage_grid_size();
 }
 
-CoreCoord MeshDevice::dram_grid_size() const { return get_device_index(0)->dram_grid_size(); }
+CoreCoord MeshDevice::dram_grid_size() const { return this->reference_device()->dram_grid_size(); }
 
-tt::ARCH MeshDevice::arch() const { return get_device_index(0)->arch(); }
+tt::ARCH MeshDevice::arch() const { return this->reference_device()->arch(); }
 
 size_t MeshDevice::num_rows() const { return this->mesh_device_shape.first; }
 
@@ -517,4 +521,17 @@ MeshSubDeviceManagerId::MeshSubDeviceManagerId(const MeshDevice& mesh_device) {
     this->sub_device_manager_ids.resize(mesh_device.num_devices());
 }
 
+int MeshDevice::num_dram_channels() const {
+    // Reference device's channel count scaled by the mesh size; the cast makes the size_t -> int narrowing explicit.
+    return this->reference_device()->num_dram_channels() * static_cast<int>(this->num_devices());
+}
+
+allocator::Statistics MeshDevice::get_memory_allocation_statistics(
+    const BufferType& buffer_type, SubDeviceId sub_device_id) const {
+    // With current implementation, we assume that all devices have the same memory allocation statistics.
+    // This will be made more explicit in the future to have lock-step allocation across devices.
+    // Right now, we just return the statistics of the first device.
+    return this->reference_device()->get_memory_allocation_statistics(buffer_type, sub_device_id);
+}
+
 }  // namespace tt::tt_metal::distributed
diff --git a/tt_metal/distributed/mesh_device.hpp b/tt_metal/distributed/mesh_device.hpp
index 81fa6f45be6..f4370cb7c58 100644
--- a/tt_metal/distributed/mesh_device.hpp
+++ b/tt_metal/distributed/mesh_device.hpp
@@ -89,6 +89,9 @@ class MeshDevice : public std::enable_shared_from_this<MeshDevice> {
         const DispatchCoreConfig& dispatch_core_config,
         const MeshDeviceConfig& config);
 
+    // This is a reference device used to query properties that are the same for all devices in the mesh.
+    Device* reference_device() const;
+
 public:
     MeshDevice(const MeshShape& mesh_device_shape, MeshType type, std::weak_ptr<MeshDevice> parent_mesh = {});
     ~MeshDevice();
@@ -111,15 +114,6 @@ class MeshDevice : public std::enable_shared_from_this<MeshDevice> {
     size_t num_cols() const;
     MeshShape shape() const;
 
-    CoreCoord compute_with_storage_grid_size() const;
-
-    CoreCoord dram_grid_size() const;
-
-    tt::ARCH arch() const;
-    void enable_async(bool enable);
-    void enable_program_cache();
-    void disable_and_clear_program_cache();
-
     void close_devices();
     std::shared_ptr<const MeshDeviceView> get_view() const;
     std::shared_ptr<MeshDeviceView> get_view();
@@ -138,8 +132,6 @@ class MeshDevice : public std::enable_shared_from_this<MeshDevice> {
     std::vector<std::shared_ptr<MeshDevice>> create_submeshes(
         const MeshShape& submesh_shape, MeshType type = MeshType::RowMajor);
 
-    size_t num_program_cache_entries() const;
-
     MeshSubDeviceManagerId create_sub_device_manager(
         tt::stl::Span<const SubDevice> sub_devices, DeviceAddr local_l1_size);
     void load_sub_device_manager(MeshSubDeviceManagerId mesh_sub_device_manager_id);
@@ -152,6 +144,23 @@ class MeshDevice : public std::enable_shared_from_this<MeshDevice> {
         size_t trace_region_size = DEFAULT_TRACE_REGION_SIZE,
         size_t num_command_queues = 1,
         const DispatchCoreConfig& dispatch_core_config = DispatchCoreConfig{});
+
+    // Device API Queries (API contract with Device class to be supported in future)
+    CoreCoord compute_with_storage_grid_size() const;
+    CoreCoord dram_grid_size() const;
+
+    tt::ARCH arch() const;
+    void enable_async(bool enable);
+    void enable_program_cache();
+    void disable_and_clear_program_cache();
+
+    size_t num_program_cache_entries() const;
+
+    // DRAM channel count of the reference device multiplied by the number of devices in the mesh.
+    int num_dram_channels() const;
+    // Allocation statistics of the reference device; all devices are currently assumed to report the same values.
+    allocator::Statistics get_memory_allocation_statistics(
+        const BufferType& buffer_type, SubDeviceId sub_device_id = SubDeviceId{0}) const;
 };
 
 std::ostream& operator<<(std::ostream& os, const MeshDevice& mesh_device);
@@ -164,4 +173,45 @@ struct MeshSubDeviceManagerId {
     std::vector<SubDeviceManagerId> sub_device_manager_ids;
 };
 
+// Compile-time description of the query/control methods that Device and MeshDevice are both expected to expose.
+namespace detail {
+template <typename T>
+concept HasMethodsForArchitectureQueries = requires(T& device) {
+    { device.compute_with_storage_grid_size() } -> std::same_as<CoreCoord>;
+    { device.dram_grid_size() } -> std::same_as<CoreCoord>;
+    { device.arch() } -> std::same_as<tt::ARCH>;
+    { device.num_dram_channels() } -> std::same_as<int>;
+};
+
+template <typename T>
+concept HasMethodsForAllocator = requires(T& device, const BufferType& buffer_type, SubDeviceId sub_device_id) {
+    { device.get_memory_allocation_statistics(buffer_type, sub_device_id) } -> std::same_as<allocator::Statistics>;
+};
+
+template <typename T>
+concept HasMethodsForProgramCache = requires(T& device) {
+    { device.num_program_cache_entries() } -> std::same_as<size_t>;
+    { device.enable_program_cache() } -> std::same_as<void>;
+    { device.disable_and_clear_program_cache() } -> std::same_as<void>;
+};
+
+template <typename T>
+concept HasMethodsForAsync = requires(T& device, bool enable) {
+    { device.enable_async(enable) } -> std::same_as<void>;
+};
+
+template <typename T>
+concept DeviceInterfaceContract =
+    HasMethodsForArchitectureQueries<T> &&
+    HasMethodsForAllocator<T> &&
+    HasMethodsForProgramCache<T> &&
+    HasMethodsForAsync<T>;
+
+}  // namespace detail
+
+// For now static_asserts are used to ensure that the concepts are satisfied.
+// This is a temporary compile-time check to make sure that Device/MeshDevice don't deviate from the expected interface.
+static_assert(detail::DeviceInterfaceContract<Device>, "Device must satisfy the DeviceInterfaceContract concept.");
+static_assert(detail::DeviceInterfaceContract<MeshDevice>, "MeshDevice must satisfy the DeviceInterfaceContract concept.");
+
 }  // namespace tt::tt_metal::distributed