Skip to content

Commit

Permalink
Moved optional simple device into creation header, other nit fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
omilyutin-tt committed Nov 26, 2024
1 parent 2426a8d commit 796b656
Show file tree
Hide file tree
Showing 5 changed files with 88 additions and 50 deletions.
1 change: 1 addition & 0 deletions ttnn/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -391,6 +391,7 @@ set(ALL_TTNN_SRCS
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/data_movement/clone/clone_pybind.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/data_movement/clone/device/clone_program_factory.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/data_movement/clone/device/clone_device_operation.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/creation.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/sharding_utilities.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/uniform/uniform.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/uniform/uniform_pybind.cpp
Expand Down
2 changes: 1 addition & 1 deletion ttnn/cpp/ttnn/distributed/api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ Tensor get_device_tensor(const Tensor& multi_device_tensor, const int device_id)
tensor_storage != nullptr && tensor_storage->has_buffer_for_device_id(device_id)) {
return Tensor{
DeviceStorage{tensor_storage->get_buffer_for_device_id(device_id)},
multi_device_tensor.get_legacy_shape(),
multi_device_tensor.get_shape(),
multi_device_tensor.get_dtype(),
multi_device_tensor.get_layout()};
} else if (std::holds_alternative<tt::tt_metal::DeviceStorage>(multi_device_tensor.get_storage())) {
Expand Down
33 changes: 33 additions & 0 deletions ttnn/cpp/ttnn/operations/creation.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#include "ttnn/operations/creation.hpp"

#include <optional>

#include "ttnn/simple_device.hpp"

namespace ttnn::operations::creation::detail {

OptionalSimpleDevice::OptionalSimpleDevice(std::nullopt_t) {}
OptionalSimpleDevice::OptionalSimpleDevice(ttnn::SimpleDevice device) :
device_(std::make_optional<ttnn::SimpleDevice>(device)) {}

// TODO: some of these won't be needed, as we unify the APIs.
OptionalSimpleDevice::OptionalSimpleDevice(const std::optional<std::reference_wrapper<tt::tt_metal::Device>>& device) :
device_(device.has_value() ? std::make_optional<SimpleDevice>(&device->get()) : std::nullopt) {}
OptionalSimpleDevice::OptionalSimpleDevice(
const std::optional<std::reference_wrapper<tt::tt_metal::distributed::MeshDevice>>& mesh_device) :
device_(mesh_device.has_value() ? std::make_optional<SimpleDevice>(&mesh_device->get()) : std::nullopt) {}
OptionalSimpleDevice::OptionalSimpleDevice(std::reference_wrapper<tt::tt_metal::Device> device) :
device_(std::make_optional<SimpleDevice>(&device.get())) {}
OptionalSimpleDevice::OptionalSimpleDevice(std::reference_wrapper<tt::tt_metal::distributed::MeshDevice> mesh_device) :
device_(std::make_optional<SimpleDevice>(&mesh_device.get())) {}

OptionalSimpleDevice::OptionalSimpleDevice(tt::tt_metal::Device& device) :
device_(std::make_optional<SimpleDevice>(&device)) {}
OptionalSimpleDevice::OptionalSimpleDevice(tt::tt_metal::distributed::MeshDevice& mesh_device) :
device_(std::make_optional<SimpleDevice>(&mesh_device)) {}

} // namespace ttnn::operations::creation::detail
69 changes: 49 additions & 20 deletions ttnn/cpp/ttnn/operations/creation.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,37 @@ struct boxed {
consteval auto invoke() const noexcept -> T { return value; }
};

// Helper class to transparently bind instances of Device / MeshDevice along with their reference wrappers to
// SimpleDevice
class OptionalSimpleDevice {
public:
OptionalSimpleDevice() = default;
OptionalSimpleDevice(std::nullopt_t);
OptionalSimpleDevice(ttnn::SimpleDevice device);
OptionalSimpleDevice(const std::optional<std::reference_wrapper<tt::tt_metal::Device>>& device);
OptionalSimpleDevice(
const std::optional<std::reference_wrapper<tt::tt_metal::distributed::MeshDevice>>& mesh_device);
OptionalSimpleDevice(std::reference_wrapper<tt::tt_metal::Device> device);
OptionalSimpleDevice(std::reference_wrapper<tt::tt_metal::distributed::MeshDevice> mesh_device);
OptionalSimpleDevice(tt::tt_metal::Device& device);
OptionalSimpleDevice(tt::tt_metal::distributed::MeshDevice& mesh_device);

OptionalSimpleDevice(const OptionalSimpleDevice&) = default;
OptionalSimpleDevice& operator=(const OptionalSimpleDevice&) = default;
OptionalSimpleDevice(OptionalSimpleDevice&&) = delete;
OptionalSimpleDevice& operator=(OptionalSimpleDevice&&) = delete;

bool has_value() { return device_.has_value(); }
ttnn::SimpleDevice* operator->() { return &(*device_); }
ttnn::SimpleDevice operator*() { return *device_; }

private:
std::optional<ttnn::SimpleDevice> device_;
};

// Converts an instance of SimpleDevice to a vector of the underlying Devices.
// TODO: Consider moving the helper into a dedicated header with the related utils.
inline std::vector<Device*> get_workers_from_device(ttnn::OptionalSimpleDevice device) {
inline std::vector<Device*> get_workers_from_device(OptionalSimpleDevice device) {
return device.has_value() ? device->get_devices() : std::vector<Device*>{};
}

Expand Down Expand Up @@ -88,10 +116,10 @@ inline ttnn::Tensor full_impl(
const T fill_value,
const std::optional<DataType>& dtype = std::nullopt,
const std::optional<Layout>& layout = std::nullopt,
const std::vector<Device*> workers = {},
const std::vector<Device*>& workers = {},
const std::optional<MemoryConfig>& memory_config = std::nullopt,
std::optional<ttnn::Tensor> optional_output_tensor = std::nullopt) {
const std::vector<Device*> workers_to_use =
const std::vector<Device*>& workers_to_use =
optional_output_tensor.has_value() ? optional_output_tensor->get_workers(/*blocking=*/true) : workers;

Layout layout_value = optional_output_tensor.has_value() ? optional_output_tensor.value().get_layout() : layout.value_or(ttnn::ROW_MAJOR_LAYOUT);
Expand All @@ -109,7 +137,7 @@ inline ttnn::Tensor full(
const T fill_value,
const std::optional<DataType>& dtype = std::nullopt,
const std::optional<Layout>& layout = std::nullopt,
ttnn::OptionalSimpleDevice device = std::nullopt,
detail::OptionalSimpleDevice device = std::nullopt,
const std::optional<MemoryConfig>& memory_config = std::nullopt,
std::optional<ttnn::Tensor> optional_output_tensor = std::nullopt,
uint8_t queue_id = ttnn::DefaultQueueId) {
Expand All @@ -132,7 +160,7 @@ struct FullWith {
const ttnn::Shape& shape,
const std::optional<DataType>& dtype = std::nullopt,
const std::optional<Layout>& layout = std::nullopt,
ttnn::OptionalSimpleDevice device = std::nullopt,
detail::OptionalSimpleDevice device = std::nullopt,
const std::optional<MemoryConfig>& memory_config = std::nullopt) {
return full(shape, fill_value, dtype, layout, device, memory_config);
}
Expand All @@ -151,7 +179,7 @@ inline ttnn::Tensor full_like_impl(
const T fill_value,
const std::optional<DataType>& dtype = std::nullopt,
const std::optional<Layout>& layout = std::nullopt,
ttnn::OptionalSimpleDevice device = std::nullopt,
detail::OptionalSimpleDevice device = std::nullopt,
const std::optional<MemoryConfig>& memory_config = std::nullopt,
std::optional<ttnn::Tensor> optional_output_tensor = std::nullopt) {
Layout layout_value = optional_output_tensor.has_value() ? optional_output_tensor.value().get_layout() : layout.value_or(tensor.get_layout());
Expand Down Expand Up @@ -195,7 +223,7 @@ inline ttnn::Tensor full_like(
const T fill_value,
const std::optional<DataType>& dtype = std::nullopt,
const std::optional<Layout>& layout = std::nullopt,
ttnn::OptionalSimpleDevice device = std::nullopt,
detail::OptionalSimpleDevice device = std::nullopt,
const std::optional<MemoryConfig>& memory_config = std::nullopt) {
return full_like_impl(ttnn::DefaultQueueId, tensor, fill_value, dtype, layout, device, memory_config, std::nullopt);
}
Expand All @@ -209,7 +237,7 @@ struct FullLikeWith {
const ttnn::Tensor& tensor,
const std::optional<DataType>& dtype = std::nullopt,
const std::optional<Layout>& layout = std::nullopt,
ttnn::OptionalSimpleDevice device = std::nullopt,
detail::OptionalSimpleDevice device = std::nullopt,
const std::optional<MemoryConfig>& memory_config = std::nullopt,
std::optional<ttnn::Tensor> optional_output_tensor = std::nullopt) {
return full_like_impl(queue_id, tensor, fill_value, dtype, layout, device, memory_config, optional_output_tensor);
Expand All @@ -219,7 +247,7 @@ struct FullLikeWith {
const ttnn::Tensor& tensor,
const std::optional<DataType>& dtype = std::nullopt,
const std::optional<Layout>& layout = std::nullopt,
ttnn::OptionalSimpleDevice device = std::nullopt,
detail::OptionalSimpleDevice device = std::nullopt,
const std::optional<MemoryConfig>& memory_config = std::nullopt,
std::optional<ttnn::Tensor> optional_output_tensor = std::nullopt) {
return full_like_impl(ttnn::DefaultQueueId, tensor, fill_value, dtype, layout, device, memory_config, optional_output_tensor);
Expand Down Expand Up @@ -248,9 +276,9 @@ struct EmptyLike {
const ttnn::Tensor& tensor,
const std::optional<DataType>& dtype = std::nullopt,
const std::optional<Layout>& layout = std::nullopt,
ttnn::OptionalSimpleDevice device_arg = std::nullopt,
detail::OptionalSimpleDevice device_arg = std::nullopt,
const std::optional<MemoryConfig>& memory_config = std::nullopt) {
const std::vector<Device*> devices =
const std::vector<Device*>& devices =
device_arg.has_value() ? device_arg->get_devices() : tensor.get_workers(/*blocking=*/true);
Layout layout_value = layout.value_or(tensor.get_layout());
DataType dtype_value = dtype.value_or(tensor.get_dtype());
Expand All @@ -266,7 +294,7 @@ struct Full {
const float fill_value,
const std::optional<DataType>& dtype = std::nullopt,
const std::optional<Layout>& layout = std::nullopt,
ttnn::OptionalSimpleDevice device = std::nullopt,
detail::OptionalSimpleDevice device = std::nullopt,
const std::optional<MemoryConfig>& memory_config = std::nullopt,
std::optional<ttnn::Tensor> optional_output_tensor = std::nullopt) {
return full_impl(
Expand All @@ -286,7 +314,7 @@ struct Full {
const int fill_value,
const std::optional<DataType>& dtype = std::nullopt,
const std::optional<Layout>& layout = std::nullopt,
ttnn::OptionalSimpleDevice device = std::nullopt,
detail::OptionalSimpleDevice device = std::nullopt,
const std::optional<MemoryConfig>& memory_config = std::nullopt,
std::optional<ttnn::Tensor> optional_output_tensor = std::nullopt) {
return full_impl(
Expand All @@ -305,7 +333,7 @@ struct Full {
const float fill_value,
const std::optional<DataType>& dtype = std::nullopt,
const std::optional<Layout>& layout = std::nullopt,
ttnn::OptionalSimpleDevice device = std::nullopt,
detail::OptionalSimpleDevice device = std::nullopt,
const std::optional<MemoryConfig>& memory_config = std::nullopt,
std::optional<ttnn::Tensor> optional_output_tensor = std::nullopt) {
return full_impl(
Expand All @@ -324,7 +352,7 @@ struct Full {
const int fill_value,
const std::optional<DataType>& dtype = std::nullopt,
const std::optional<Layout>& layout = std::nullopt,
ttnn::OptionalSimpleDevice device = std::nullopt,
detail::OptionalSimpleDevice device = std::nullopt,
const std::optional<MemoryConfig>& memory_config = std::nullopt,
std::optional<ttnn::Tensor> optional_output_tensor = std::nullopt) {
return full_impl(
Expand All @@ -346,7 +374,7 @@ struct FullLike {
const float fill_value,
const std::optional<DataType>& dtype = std::nullopt,
const std::optional<Layout>& layout = std::nullopt,
ttnn::OptionalSimpleDevice device = std::nullopt,
detail::OptionalSimpleDevice device = std::nullopt,
const std::optional<MemoryConfig>& memory_config = std::nullopt,
std::optional<ttnn::Tensor> optional_output_tensor = std::nullopt) {
return full_like_impl(queue_id, tensor, fill_value, dtype, layout, device, memory_config, optional_output_tensor);
Expand All @@ -358,7 +386,7 @@ struct FullLike {
const int fill_value,
const std::optional<DataType>& dtype = std::nullopt,
const std::optional<Layout>& layout = std::nullopt,
ttnn::OptionalSimpleDevice device = std::nullopt,
detail::OptionalSimpleDevice device = std::nullopt,
const std::optional<MemoryConfig>& memory_config = std::nullopt,
std::optional<ttnn::Tensor> optional_output_tensor = std::nullopt) {
return full_like_impl(queue_id, tensor, fill_value, dtype, layout, device, memory_config, optional_output_tensor);
Expand All @@ -369,7 +397,7 @@ struct FullLike {
const float fill_value,
const std::optional<DataType>& dtype = std::nullopt,
const std::optional<Layout>& layout = std::nullopt,
ttnn::OptionalSimpleDevice device = std::nullopt,
detail::OptionalSimpleDevice device = std::nullopt,
const std::optional<MemoryConfig>& memory_config = std::nullopt,
std::optional<ttnn::Tensor> optional_output_tensor = std::nullopt) {
return full_like_impl(ttnn::DefaultQueueId, tensor, fill_value, dtype, layout, device, memory_config, optional_output_tensor);
Expand All @@ -380,7 +408,7 @@ struct FullLike {
const int fill_value,
const std::optional<DataType>& dtype = std::nullopt,
const std::optional<Layout>& layout = std::nullopt,
ttnn::OptionalSimpleDevice device = std::nullopt,
detail::OptionalSimpleDevice device = std::nullopt,
const std::optional<MemoryConfig>& memory_config = std::nullopt,
std::optional<ttnn::Tensor> optional_output_tensor = std::nullopt) {
return full_like_impl(ttnn::DefaultQueueId, tensor, fill_value, dtype, layout, device, memory_config, optional_output_tensor);
Expand Down Expand Up @@ -438,6 +466,7 @@ constexpr auto ones_like =
constexpr auto empty_like =
ttnn::decorators::register_operation<"ttnn::empty_like", ttnn::operations::creation::EmptyLike>();

constexpr auto arange = ttnn::decorators::register_operation<"ttnn::arange", ttnn::operations::creation::Arange>();
constexpr auto arange =
ttnn::decorators::register_operation_with_auto_launch_op<"ttnn::arange", ttnn::operations::creation::Arange>();

} // namespace ttnn
33 changes: 4 additions & 29 deletions ttnn/cpp/ttnn/simple_device.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

#pragma once

#include <functional>

#include "tt_metal/distributed/mesh_device.hpp"
#include "tt_metal/impl/device/device.hpp"

Expand All @@ -22,6 +24,8 @@ class SimpleDevice {
SimpleDevice(tt::tt_metal::distributed::MeshDevice* mesh_device) : metal_device_{mesh_device} {}
SimpleDevice(const SimpleDevice&) = default;
SimpleDevice& operator=(const SimpleDevice&) = default;
SimpleDevice(SimpleDevice&&) = delete;
SimpleDevice& operator=(SimpleDevice&&) = delete;

std::vector<tt::tt_metal::Device*> get_devices() {
if (auto* device = std::get_if<tt::tt_metal::Device*>(&metal_device_); device != nullptr) {
Expand All @@ -35,33 +39,4 @@ class SimpleDevice {
std::variant<tt::tt_metal::Device*, tt::tt_metal::distributed::MeshDevice*> metal_device_;
};

class OptionalSimpleDevice {
public:
// Allow implicit conversions for transparent migration.
OptionalSimpleDevice(std::nullopt_t) {}
OptionalSimpleDevice(SimpleDevice device) : device_(std::make_optional<SimpleDevice>(device)) {}

// TODO: some of these won't be needed, as we unify the APIs.
OptionalSimpleDevice(const std::optional<std::reference_wrapper<tt::tt_metal::Device>>& device) :
device_(device.has_value() ? std::make_optional<SimpleDevice>(&device->get()) : std::nullopt) {}
OptionalSimpleDevice(
const std::optional<std::reference_wrapper<tt::tt_metal::distributed::MeshDevice>>& mesh_device) :
device_(mesh_device.has_value() ? std::make_optional<SimpleDevice>(&mesh_device->get()) : std::nullopt) {}
OptionalSimpleDevice(std::reference_wrapper<tt::tt_metal::Device> device) :
device_(std::make_optional<SimpleDevice>(&device.get())) {}
OptionalSimpleDevice(std::reference_wrapper<tt::tt_metal::distributed::MeshDevice> mesh_device) :
device_(std::make_optional<SimpleDevice>(&mesh_device.get())) {}

OptionalSimpleDevice(tt::tt_metal::Device& device) : device_(std::make_optional<SimpleDevice>(&device)) {}
OptionalSimpleDevice(tt::tt_metal::distributed::MeshDevice& mesh_device) :
device_(std::make_optional<SimpleDevice>(&mesh_device)) {}

bool has_value() { return device_.has_value(); }
SimpleDevice* operator->() { return &(*device_); }
SimpleDevice operator*() { return *device_; }

private:
std::optional<SimpleDevice> device_;
};

} // namespace ttnn

0 comments on commit 796b656

Please sign in to comment.