From 796b656198894bd368263da85cfd64263c702208 Mon Sep 17 00:00:00 2001 From: Oleg Milyutin Date: Mon, 25 Nov 2024 21:24:48 +0000 Subject: [PATCH] Moved optional simple device into creation header, other nit fixes --- ttnn/CMakeLists.txt | 1 + ttnn/cpp/ttnn/distributed/api.cpp | 2 +- ttnn/cpp/ttnn/operations/creation.cpp | 33 +++++++++++++ ttnn/cpp/ttnn/operations/creation.hpp | 69 +++++++++++++++++++-------- ttnn/cpp/ttnn/simple_device.hpp | 33 ++----------- 5 files changed, 88 insertions(+), 50 deletions(-) create mode 100644 ttnn/cpp/ttnn/operations/creation.cpp diff --git a/ttnn/CMakeLists.txt b/ttnn/CMakeLists.txt index 4ef870a0578d..63c56e92d738 100644 --- a/ttnn/CMakeLists.txt +++ b/ttnn/CMakeLists.txt @@ -391,6 +391,7 @@ set(ALL_TTNN_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/data_movement/clone/clone_pybind.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/data_movement/clone/device/clone_program_factory.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/data_movement/clone/device/clone_device_operation.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/creation.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/sharding_utilities.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/uniform/uniform.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/uniform/uniform_pybind.cpp diff --git a/ttnn/cpp/ttnn/distributed/api.cpp b/ttnn/cpp/ttnn/distributed/api.cpp index 74f8725bf970..f8738bf60d7a 100644 --- a/ttnn/cpp/ttnn/distributed/api.cpp +++ b/ttnn/cpp/ttnn/distributed/api.cpp @@ -161,7 +161,7 @@ Tensor get_device_tensor(const Tensor& multi_device_tensor, const int device_id) tensor_storage != nullptr && tensor_storage->has_buffer_for_device_id(device_id)) { return Tensor{ DeviceStorage{tensor_storage->get_buffer_for_device_id(device_id)}, - multi_device_tensor.get_legacy_shape(), + multi_device_tensor.get_shape(), multi_device_tensor.get_dtype(), multi_device_tensor.get_layout()}; } else if (std::holds_alternative(multi_device_tensor.get_storage())) { diff --git a/ttnn/cpp/ttnn/operations/creation.cpp b/ttnn/cpp/ttnn/operations/creation.cpp new file mode 100644 index 000000000000..2cf5a0ec7dbd --- /dev/null +++ b/ttnn/cpp/ttnn/operations/creation.cpp @@ -0,0 +1,33 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#include "ttnn/operations/creation.hpp" + +#include + +#include "ttnn/simple_device.hpp" + +namespace ttnn::operations::creation::detail { + +OptionalSimpleDevice::OptionalSimpleDevice(std::nullopt_t) {} +OptionalSimpleDevice::OptionalSimpleDevice(ttnn::SimpleDevice device) : + device_(std::make_optional(device)) {} + +// TODO: some of these won't be needed, as we unify the APIs. +OptionalSimpleDevice::OptionalSimpleDevice(const std::optional>& device) : + device_(device.has_value() ? std::make_optional(&device->get()) : std::nullopt) {} +OptionalSimpleDevice::OptionalSimpleDevice( + const std::optional>& mesh_device) : + device_(mesh_device.has_value() ? std::make_optional(&mesh_device->get()) : std::nullopt) {} +OptionalSimpleDevice::OptionalSimpleDevice(std::reference_wrapper device) : + device_(std::make_optional(&device.get())) {} +OptionalSimpleDevice::OptionalSimpleDevice(std::reference_wrapper mesh_device) : + device_(std::make_optional(&mesh_device.get())) {} + +OptionalSimpleDevice::OptionalSimpleDevice(tt::tt_metal::Device& device) : + device_(std::make_optional(&device)) {} +OptionalSimpleDevice::OptionalSimpleDevice(tt::tt_metal::distributed::MeshDevice& mesh_device) : + device_(std::make_optional(&mesh_device)) {} + +} // namespace ttnn::operations::creation::detail diff --git a/ttnn/cpp/ttnn/operations/creation.hpp b/ttnn/cpp/ttnn/operations/creation.hpp index dc3a298cbfa8..69184adc08d0 100644 --- a/ttnn/cpp/ttnn/operations/creation.hpp +++ b/ttnn/cpp/ttnn/operations/creation.hpp @@ -36,9 +36,37 @@ struct boxed { consteval auto invoke() const noexcept -> T { return value; } }; +// Helper class to transparently bind instances of Device / MeshDevice along with their reference wrappers to +// SimpleDevice +class OptionalSimpleDevice { + public: + OptionalSimpleDevice() = default; + OptionalSimpleDevice(std::nullopt_t); + OptionalSimpleDevice(ttnn::SimpleDevice device); + OptionalSimpleDevice(const std::optional>& device); + OptionalSimpleDevice( + const std::optional>& mesh_device); + OptionalSimpleDevice(std::reference_wrapper device); + OptionalSimpleDevice(std::reference_wrapper mesh_device); + OptionalSimpleDevice(tt::tt_metal::Device& device); + OptionalSimpleDevice(tt::tt_metal::distributed::MeshDevice& mesh_device); + + OptionalSimpleDevice(const OptionalSimpleDevice&) = default; + OptionalSimpleDevice& operator=(const OptionalSimpleDevice&) = default; + OptionalSimpleDevice(OptionalSimpleDevice&&) = delete; + OptionalSimpleDevice& operator=(OptionalSimpleDevice&&) = delete; + + bool has_value() { return device_.has_value(); } + ttnn::SimpleDevice* operator->() { return &(*device_); } + ttnn::SimpleDevice operator*() { return *device_; } + + private: + std::optional device_; +}; + // Converts an instance of SimpleDevice to a vector of the underlying Devices. // TODO: Consider moving the helper into a dedicated header with the related utils. -inline std::vector get_workers_from_device(ttnn::OptionalSimpleDevice device) { +inline std::vector get_workers_from_device(OptionalSimpleDevice device) { return device.has_value() ? device->get_devices() : std::vector{}; } @@ -88,10 +116,10 @@ inline ttnn::Tensor full_impl( const T fill_value, const std::optional& dtype = std::nullopt, const std::optional& layout = std::nullopt, - const std::vector workers = {}, + const std::vector& workers = {}, const std::optional& memory_config = std::nullopt, std::optional optional_output_tensor = std::nullopt) { - const std::vector workers_to_use = + const std::vector& workers_to_use = optional_output_tensor.has_value() ? optional_output_tensor->get_workers(/*blocking=*/true) : workers; Layout layout_value = optional_output_tensor.has_value() ? optional_output_tensor.value().get_layout() : layout.value_or(ttnn::ROW_MAJOR_LAYOUT); @@ -109,7 +137,7 @@ inline ttnn::Tensor full( const T fill_value, const std::optional& dtype = std::nullopt, const std::optional& layout = std::nullopt, - ttnn::OptionalSimpleDevice device = std::nullopt, + detail::OptionalSimpleDevice device = std::nullopt, const std::optional& memory_config = std::nullopt, std::optional optional_output_tensor = std::nullopt, uint8_t queue_id = ttnn::DefaultQueueId) { @@ -132,7 +160,7 @@ struct FullWith { const ttnn::Shape& shape, const std::optional& dtype = std::nullopt, const std::optional& layout = std::nullopt, - ttnn::OptionalSimpleDevice device = std::nullopt, + detail::OptionalSimpleDevice device = std::nullopt, const std::optional& memory_config = std::nullopt) { return full(shape, fill_value, dtype, layout, device, memory_config); } @@ -151,7 +179,7 @@ inline ttnn::Tensor full_like_impl( const T fill_value, const std::optional& dtype = std::nullopt, const std::optional& layout = std::nullopt, - ttnn::OptionalSimpleDevice device = std::nullopt, + detail::OptionalSimpleDevice device = std::nullopt, const std::optional& memory_config = std::nullopt, std::optional optional_output_tensor = std::nullopt) { Layout layout_value = optional_output_tensor.has_value() ? optional_output_tensor.value().get_layout() : layout.value_or(tensor.get_layout()); @@ -195,7 +223,7 @@ inline ttnn::Tensor full_like( const T fill_value, const std::optional& dtype = std::nullopt, const std::optional& layout = std::nullopt, - ttnn::OptionalSimpleDevice device = std::nullopt, + detail::OptionalSimpleDevice device = std::nullopt, const std::optional& memory_config = std::nullopt) { return full_like_impl(ttnn::DefaultQueueId, tensor, fill_value, dtype, layout, device, memory_config, std::nullopt); } @@ -209,7 +237,7 @@ struct FullLikeWith { const ttnn::Tensor& tensor, const std::optional& dtype = std::nullopt, const std::optional& layout = std::nullopt, - ttnn::OptionalSimpleDevice device = std::nullopt, + detail::OptionalSimpleDevice device = std::nullopt, const std::optional& memory_config = std::nullopt, std::optional optional_output_tensor = std::nullopt) { return full_like_impl(queue_id, tensor, fill_value, dtype, layout, device, memory_config, optional_output_tensor); @@ -219,7 +247,7 @@ struct FullLikeWith { const ttnn::Tensor& tensor, const std::optional& dtype = std::nullopt, const std::optional& layout = std::nullopt, - ttnn::OptionalSimpleDevice device = std::nullopt, + detail::OptionalSimpleDevice device = std::nullopt, const std::optional& memory_config = std::nullopt, std::optional optional_output_tensor = std::nullopt) { return full_like_impl(ttnn::DefaultQueueId, tensor, fill_value, dtype, layout, device, memory_config, optional_output_tensor); @@ -248,9 +276,9 @@ struct EmptyLike { const ttnn::Tensor& tensor, const std::optional& dtype = std::nullopt, const std::optional& layout = std::nullopt, - ttnn::OptionalSimpleDevice device_arg = std::nullopt, + detail::OptionalSimpleDevice device_arg = std::nullopt, const std::optional& memory_config = std::nullopt) { - const std::vector devices = + const std::vector& devices = device_arg.has_value() ? device_arg->get_devices() : tensor.get_workers(/*blocking=*/true); Layout layout_value = layout.value_or(tensor.get_layout()); DataType dtype_value = dtype.value_or(tensor.get_dtype()); @@ -266,7 +294,7 @@ struct Full { const float fill_value, const std::optional& dtype = std::nullopt, const std::optional& layout = std::nullopt, - ttnn::OptionalSimpleDevice device = std::nullopt, + detail::OptionalSimpleDevice device = std::nullopt, const std::optional& memory_config = std::nullopt, std::optional optional_output_tensor = std::nullopt) { return full_impl( @@ -286,7 +314,7 @@ struct Full { const int fill_value, const std::optional& dtype = std::nullopt, const std::optional& layout = std::nullopt, - ttnn::OptionalSimpleDevice device = std::nullopt, + detail::OptionalSimpleDevice device = std::nullopt, const std::optional& memory_config = std::nullopt, std::optional optional_output_tensor = std::nullopt) { return full_impl( @@ -305,7 +333,7 @@ struct Full { const float fill_value, const std::optional& dtype = std::nullopt, const std::optional& layout = std::nullopt, - ttnn::OptionalSimpleDevice device = std::nullopt, + detail::OptionalSimpleDevice device = std::nullopt, const std::optional& memory_config = std::nullopt, std::optional optional_output_tensor = std::nullopt) { return full_impl( @@ -324,7 +352,7 @@ struct Full { const int fill_value, const std::optional& dtype = std::nullopt, const std::optional& layout = std::nullopt, - ttnn::OptionalSimpleDevice device = std::nullopt, + detail::OptionalSimpleDevice device = std::nullopt, const std::optional& memory_config = std::nullopt, std::optional optional_output_tensor = std::nullopt) { return full_impl( @@ -346,7 +374,7 @@ struct FullLike { const float fill_value, const std::optional& dtype = std::nullopt, const std::optional& layout = std::nullopt, - ttnn::OptionalSimpleDevice device = std::nullopt, + detail::OptionalSimpleDevice device = std::nullopt, const std::optional& memory_config = std::nullopt, std::optional optional_output_tensor = std::nullopt) { return full_like_impl(queue_id, tensor, fill_value, dtype, layout, device, memory_config, optional_output_tensor); @@ -358,7 +386,7 @@ struct FullLike { const int fill_value, const std::optional& dtype = std::nullopt, const std::optional& layout = std::nullopt, - ttnn::OptionalSimpleDevice device = std::nullopt, + detail::OptionalSimpleDevice device = std::nullopt, const std::optional& memory_config = std::nullopt, std::optional optional_output_tensor = std::nullopt) { return full_like_impl(queue_id, tensor, fill_value, dtype, layout, device, memory_config, optional_output_tensor); @@ -369,7 +397,7 @@ struct FullLike { const float fill_value, const std::optional& dtype = std::nullopt, const std::optional& layout = std::nullopt, - ttnn::OptionalSimpleDevice device = std::nullopt, + detail::OptionalSimpleDevice device = std::nullopt, const std::optional& memory_config = std::nullopt, std::optional optional_output_tensor = std::nullopt) { return full_like_impl(ttnn::DefaultQueueId, tensor, fill_value, dtype, layout, device, memory_config, optional_output_tensor); @@ -380,7 +408,7 @@ struct FullLike { const int fill_value, const std::optional& dtype = std::nullopt, const std::optional& layout = std::nullopt, - ttnn::OptionalSimpleDevice device = std::nullopt, + detail::OptionalSimpleDevice device = std::nullopt, const std::optional& memory_config = std::nullopt, std::optional optional_output_tensor = std::nullopt) { return full_like_impl(ttnn::DefaultQueueId, tensor, fill_value, dtype, layout, device, memory_config, optional_output_tensor); @@ -438,6 +466,7 @@ constexpr auto ones_like = constexpr auto empty_like = ttnn::decorators::register_operation<"ttnn::empty_like", ttnn::operations::creation::EmptyLike>(); -constexpr auto arange = ttnn::decorators::register_operation<"ttnn::arange", ttnn::operations::creation::Arange>(); +constexpr auto arange = + ttnn::decorators::register_operation_with_auto_launch_op<"ttnn::arange", ttnn::operations::creation::Arange>(); } // namespace ttnn diff --git a/ttnn/cpp/ttnn/simple_device.hpp b/ttnn/cpp/ttnn/simple_device.hpp index 96d6622139a3..af1f3d0bda9a 100644 --- a/ttnn/cpp/ttnn/simple_device.hpp +++ b/ttnn/cpp/ttnn/simple_device.hpp @@ -4,6 +4,8 @@ #pragma once +#include + #include "tt_metal/distributed/mesh_device.hpp" #include "tt_metal/impl/device/device.hpp" @@ -22,6 +24,8 @@ class SimpleDevice { SimpleDevice(tt::tt_metal::distributed::MeshDevice* mesh_device) : metal_device_{mesh_device} {} SimpleDevice(const SimpleDevice&) = default; SimpleDevice& operator=(const SimpleDevice&) = default; + SimpleDevice(SimpleDevice&&) = delete; + SimpleDevice& operator=(SimpleDevice&&) = delete; std::vector get_devices() { if (auto* device = std::get_if(&metal_device_); device != nullptr) { @@ -35,33 +39,4 @@ class SimpleDevice { std::variant metal_device_; }; -class OptionalSimpleDevice { - public: - // Allow implicit conversions for transparent migration. - OptionalSimpleDevice(std::nullopt_t) {} - OptionalSimpleDevice(SimpleDevice device) : device_(std::make_optional(device)) {} - - // TODO: some of these won't be needed, as we unify the APIs. - OptionalSimpleDevice(const std::optional>& device) : - device_(device.has_value() ? std::make_optional(&device->get()) : std::nullopt) {} - OptionalSimpleDevice( - const std::optional>& mesh_device) : - device_(mesh_device.has_value() ? std::make_optional(&mesh_device->get()) : std::nullopt) {} - OptionalSimpleDevice(std::reference_wrapper device) : - device_(std::make_optional(&device.get())) {} - OptionalSimpleDevice(std::reference_wrapper mesh_device) : - device_(std::make_optional(&mesh_device.get())) {} - - OptionalSimpleDevice(tt::tt_metal::Device& device) : device_(std::make_optional(&device)) {} - OptionalSimpleDevice(tt::tt_metal::distributed::MeshDevice& mesh_device) : - device_(std::make_optional(&mesh_device)) {} - - bool has_value() { return device_.has_value(); } - SimpleDevice* operator->() { return &(*device_); } - SimpleDevice operator*() { return *device_; } - - private: - std::optional device_; -}; - } // namespace ttnn