Address renaming comments
omilyutin-tt committed Nov 27, 2024
1 parent 96de4a9 commit 3f7d6c0
Showing 7 changed files with 79 additions and 83 deletions.
2 changes: 1 addition & 1 deletion ttnn/CMakeLists.txt
@@ -621,7 +621,7 @@ set(TTNN_PUBLIC_LINK_DIRS "")
set(TTNN_PRECOMPILED_HEADERS
${PROJECT_SOURCE_DIR}/tt_metal/tt_stl/reflection.hpp
${PROJECT_SOURCE_DIR}/ttnn/cpp/ttnn/operation.hpp
${PROJECT_SOURCE_DIR}/ttnn/cpp/ttnn/simple_device.hpp
${PROJECT_SOURCE_DIR}/ttnn/cpp/ttnn/any_device.hpp
${PROJECT_SOURCE_DIR}/tt_metal/third_party/tracy/public/tracy/Tracy.hpp
${PROJECT_SOURCE_DIR}/tt_metal/third_party/umd/device/device_api_metal.h
${PROJECT_SOURCE_DIR}/tt_metal/third_party/umd/device/cluster.h
24 changes: 11 additions & 13 deletions ttnn/cpp/ttnn/simple_device.hpp → ttnn/cpp/ttnn/any_device.hpp
@@ -4,28 +4,26 @@

#pragma once

#include <functional>

#include "tt_metal/distributed/mesh_device.hpp"
#include "tt_metal/impl/device/device.hpp"

namespace ttnn {

// SimpleDevice is a wrapper around Device / MeshDevice to use in interfaces that can accept either.
// AnyDevice is a wrapper around Device / MeshDevice to use in interfaces that can accept either.
// This class is cheaply copyable, use value semantics to pass it around.
//
// TODO: the eventual goal is to lower this primitive into tt_metal. In the long term, we also want to extend the
// functionality with the "distributed device" semantics.
class SimpleDevice {
public:
class AnyDevice {
public:
// Allow implicit conversion for transparent migration.
// Expect the pointers to be non-null, and remain valid for the lifetime of SimpleDevice.
SimpleDevice(tt::tt_metal::Device* device) : metal_device_{device} {}
SimpleDevice(tt::tt_metal::distributed::MeshDevice* mesh_device) : metal_device_{mesh_device} {}
SimpleDevice(const SimpleDevice&) = default;
SimpleDevice& operator=(const SimpleDevice&) = default;
SimpleDevice(SimpleDevice&&) = delete;
SimpleDevice& operator=(SimpleDevice&&) = delete;
// Expect the pointers to be non-null, and remain valid for the lifetime of AnyDevice.
AnyDevice(tt::tt_metal::Device* device) : metal_device_{device} {}
AnyDevice(tt::tt_metal::distributed::MeshDevice* mesh_device) : metal_device_{mesh_device} {}
AnyDevice(const AnyDevice&) = default;
AnyDevice& operator=(const AnyDevice&) = default;
AnyDevice(AnyDevice&&) = delete;
AnyDevice& operator=(AnyDevice&&) = delete;

std::vector<tt::tt_metal::Device*> get_devices() {
if (auto* device = std::get_if<tt::tt_metal::Device*>(&metal_device_); device != nullptr) {
@@ -35,7 +35,7 @@ class SimpleDevice {
}
}

private:
private:
std::variant<tt::tt_metal::Device*, tt::tt_metal::distributed::MeshDevice*> metal_device_;
};

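For illustration only, a minimal sketch of how the renamed AnyDevice wrapper can be consumed by an interface that accepts either device type; the function print_device_count and the device pointers are hypothetical and not part of this commit:

#include "ttnn/any_device.hpp"
#include <cstdio>

// Hypothetical helper: takes AnyDevice by value, since the wrapper is cheaply copyable.
void print_device_count(ttnn::AnyDevice device) {
    // get_devices() flattens either variant (Device* or MeshDevice*) into a vector of Device*.
    std::printf("managing %zu device(s)\n", device.get_devices().size());
}

// Call sites pass a raw pointer; the implicit converting constructors wrap it transparently:
//   print_device_count(device_ptr);       // tt::tt_metal::Device*
//   print_device_count(mesh_device_ptr);  // tt::tt_metal::distributed::MeshDevice*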
5 changes: 3 additions & 2 deletions ttnn/cpp/ttnn/operations/core/core.cpp
@@ -15,6 +15,7 @@
#include "ttnn/distributed/types.hpp"
#include "ttnn/operations/data_movement/sharded/sharded_to_interleaved/sharded_to_interleaved.hpp"
#include "ttnn/operations/data_movement/sharded/interleaved_to_sharded/interleaved_to_sharded.hpp"
#include "ttnn/tensor/tensor.hpp"

namespace ttnn::operations::core {

@@ -87,7 +88,7 @@ ttnn::Tensor allocate_tensor_on_device(
Layout layout,
Device* device,
const std::optional<MemoryConfig>& memory_config) {
return tt::tt_metal::allocate_tensor_on_workers(
return tt::tt_metal::allocate_tensor_on_devices(
shape, data_type, layout, {device}, memory_config.value_or(ttnn::DRAM_MEMORY_CONFIG));
}

@@ -97,7 +98,7 @@ ttnn::Tensor allocate_tensor_on_device(
Layout layout,
MeshDevice* mesh_device,
const std::optional<MemoryConfig>& memory_config) {
return tt::tt_metal::allocate_tensor_on_workers(
return tt::tt_metal::allocate_tensor_on_devices(
shape, data_type, layout, mesh_device->get_devices(), memory_config.value_or(ttnn::DRAM_MEMORY_CONFIG));
}

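As a hedged sketch of how the two overloads above are exercised after the rename (the dtype and layout values are placeholders, not taken from this commit), both now forward to the same tt::tt_metal entry point:

// Using the overloads from ttnn::operations::core shown in this hunk:
//
//   // Single-device overload: wraps the pointer into a one-element device list.
//   ttnn::Tensor a = allocate_tensor_on_device(shape, DataType::BFLOAT16, Layout::TILE, device, /*memory_config=*/std::nullopt);
//
//   // Mesh overload: expands mesh_device->get_devices() before forwarding.
//   ttnn::Tensor b = allocate_tensor_on_device(shape, DataType::BFLOAT16, Layout::TILE, mesh_device, /*memory_config=*/std::nullopt);
//
// With no memory_config supplied, both default to ttnn::DRAM_MEMORY_CONFIG, as in the hunk above.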
32 changes: 15 additions & 17 deletions ttnn/cpp/ttnn/operations/creation.cpp
@@ -6,28 +6,26 @@

#include <optional>

#include "ttnn/simple_device.hpp"
#include "ttnn/any_device.hpp"

namespace ttnn::operations::creation::detail {

OptionalSimpleDevice::OptionalSimpleDevice(std::nullopt_t) {}
OptionalSimpleDevice::OptionalSimpleDevice(ttnn::SimpleDevice device) :
device_(std::make_optional<ttnn::SimpleDevice>(device)) {}
OptionalAnyDevice::OptionalAnyDevice(std::nullopt_t) {}
OptionalAnyDevice::OptionalAnyDevice(ttnn::AnyDevice device) : device_(std::make_optional<ttnn::AnyDevice>(device)) {}

// TODO: some of these won't be needed, as we unify the APIs.
OptionalSimpleDevice::OptionalSimpleDevice(const std::optional<std::reference_wrapper<tt::tt_metal::Device>>& device) :
device_(device.has_value() ? std::make_optional<SimpleDevice>(&device->get()) : std::nullopt) {}
OptionalSimpleDevice::OptionalSimpleDevice(
OptionalAnyDevice::OptionalAnyDevice(const std::optional<std::reference_wrapper<tt::tt_metal::Device>>& device) :
device_(device.has_value() ? std::make_optional<AnyDevice>(&device->get()) : std::nullopt) {}
OptionalAnyDevice::OptionalAnyDevice(
const std::optional<std::reference_wrapper<tt::tt_metal::distributed::MeshDevice>>& mesh_device) :
device_(mesh_device.has_value() ? std::make_optional<SimpleDevice>(&mesh_device->get()) : std::nullopt) {}
OptionalSimpleDevice::OptionalSimpleDevice(std::reference_wrapper<tt::tt_metal::Device> device) :
device_(std::make_optional<SimpleDevice>(&device.get())) {}
OptionalSimpleDevice::OptionalSimpleDevice(std::reference_wrapper<tt::tt_metal::distributed::MeshDevice> mesh_device) :
device_(std::make_optional<SimpleDevice>(&mesh_device.get())) {}

OptionalSimpleDevice::OptionalSimpleDevice(tt::tt_metal::Device& device) :
device_(std::make_optional<SimpleDevice>(&device)) {}
OptionalSimpleDevice::OptionalSimpleDevice(tt::tt_metal::distributed::MeshDevice& mesh_device) :
device_(std::make_optional<SimpleDevice>(&mesh_device)) {}
device_(mesh_device.has_value() ? std::make_optional<AnyDevice>(&mesh_device->get()) : std::nullopt) {}
OptionalAnyDevice::OptionalAnyDevice(std::reference_wrapper<tt::tt_metal::Device> device) :
device_(std::make_optional<AnyDevice>(&device.get())) {}
OptionalAnyDevice::OptionalAnyDevice(std::reference_wrapper<tt::tt_metal::distributed::MeshDevice> mesh_device) :
device_(std::make_optional<AnyDevice>(&mesh_device.get())) {}

OptionalAnyDevice::OptionalAnyDevice(tt::tt_metal::Device& device) : device_(std::make_optional<AnyDevice>(&device)) {}
OptionalAnyDevice::OptionalAnyDevice(tt::tt_metal::distributed::MeshDevice& mesh_device) :
device_(std::make_optional<AnyDevice>(&mesh_device)) {}

} // namespace ttnn::operations::creation::detail
87 changes: 43 additions & 44 deletions ttnn/cpp/ttnn/operations/creation.hpp
@@ -14,7 +14,7 @@
#include "ttnn/distributed/types.hpp"
#include "ttnn/operations/eltwise/unary/unary.hpp"
#include "ttnn/operations/numpy/functions.hpp"
#include "ttnn/simple_device.hpp"
#include "ttnn/any_device.hpp"
#include "ttnn/tensor/tensor.hpp"
#include "ttnn/tensor/tensor_utils.hpp"
#include "ttnn/tensor/types.hpp"
@@ -37,36 +37,35 @@ struct boxed {
};

// Helper class to transparently bind instances of Device / MeshDevice along with their reference wrappers to
// SimpleDevice
class OptionalSimpleDevice {
public:
OptionalSimpleDevice() = default;
OptionalSimpleDevice(std::nullopt_t);
OptionalSimpleDevice(ttnn::SimpleDevice device);
OptionalSimpleDevice(const std::optional<std::reference_wrapper<tt::tt_metal::Device>>& device);
OptionalSimpleDevice(
const std::optional<std::reference_wrapper<tt::tt_metal::distributed::MeshDevice>>& mesh_device);
OptionalSimpleDevice(std::reference_wrapper<tt::tt_metal::Device> device);
OptionalSimpleDevice(std::reference_wrapper<tt::tt_metal::distributed::MeshDevice> mesh_device);
OptionalSimpleDevice(tt::tt_metal::Device& device);
OptionalSimpleDevice(tt::tt_metal::distributed::MeshDevice& mesh_device);

OptionalSimpleDevice(const OptionalSimpleDevice&) = default;
OptionalSimpleDevice& operator=(const OptionalSimpleDevice&) = default;
OptionalSimpleDevice(OptionalSimpleDevice&&) = delete;
OptionalSimpleDevice& operator=(OptionalSimpleDevice&&) = delete;
// AnyDevice
class OptionalAnyDevice {
public:
OptionalAnyDevice() = default;
OptionalAnyDevice(std::nullopt_t);
OptionalAnyDevice(ttnn::AnyDevice device);
OptionalAnyDevice(const std::optional<std::reference_wrapper<tt::tt_metal::Device>>& device);
OptionalAnyDevice(const std::optional<std::reference_wrapper<tt::tt_metal::distributed::MeshDevice>>& mesh_device);
OptionalAnyDevice(std::reference_wrapper<tt::tt_metal::Device> device);
OptionalAnyDevice(std::reference_wrapper<tt::tt_metal::distributed::MeshDevice> mesh_device);
OptionalAnyDevice(tt::tt_metal::Device& device);
OptionalAnyDevice(tt::tt_metal::distributed::MeshDevice& mesh_device);

OptionalAnyDevice(const OptionalAnyDevice&) = default;
OptionalAnyDevice& operator=(const OptionalAnyDevice&) = default;
OptionalAnyDevice(OptionalAnyDevice&&) = delete;
OptionalAnyDevice& operator=(OptionalAnyDevice&&) = delete;

bool has_value() { return device_.has_value(); }
ttnn::SimpleDevice* operator->() { return &(*device_); }
ttnn::SimpleDevice operator*() { return *device_; }
ttnn::AnyDevice* operator->() { return &(*device_); }
ttnn::AnyDevice operator*() { return *device_; }

private:
std::optional<ttnn::SimpleDevice> device_;
private:
std::optional<ttnn::AnyDevice> device_;
};

// Converts an instance of SimpleDevice to a vector of the underlying Devices.
// Converts an instance of AnyDevice to a vector of the underlying Devices.
// TODO: Consider moving the helper into a dedicated header with the related utils.
inline std::vector<Device*> get_workers_from_device(OptionalSimpleDevice device) {
inline std::vector<Device*> get_workers_from_device(OptionalAnyDevice device) {
return device.has_value() ? device->get_devices() : std::vector<Device*>{};
}

@@ -137,7 +136,7 @@ inline ttnn::Tensor full(
const T fill_value,
const std::optional<DataType>& dtype = std::nullopt,
const std::optional<Layout>& layout = std::nullopt,
detail::OptionalSimpleDevice device = std::nullopt,
detail::OptionalAnyDevice device = std::nullopt,
const std::optional<MemoryConfig>& memory_config = std::nullopt,
std::optional<ttnn::Tensor> optional_output_tensor = std::nullopt,
uint8_t queue_id = ttnn::DefaultQueueId) {
@@ -160,7 +159,7 @@ struct FullWith {
const ttnn::Shape& shape,
const std::optional<DataType>& dtype = std::nullopt,
const std::optional<Layout>& layout = std::nullopt,
detail::OptionalSimpleDevice device = std::nullopt,
detail::OptionalAnyDevice device = std::nullopt,
const std::optional<MemoryConfig>& memory_config = std::nullopt) {
return full(shape, fill_value, dtype, layout, device, memory_config);
}
@@ -179,7 +178,7 @@ inline ttnn::Tensor full_like_impl(
const T fill_value,
const std::optional<DataType>& dtype = std::nullopt,
const std::optional<Layout>& layout = std::nullopt,
detail::OptionalSimpleDevice device = std::nullopt,
detail::OptionalAnyDevice device = std::nullopt,
const std::optional<MemoryConfig>& memory_config = std::nullopt,
std::optional<ttnn::Tensor> optional_output_tensor = std::nullopt) {
Layout layout_value = optional_output_tensor.has_value() ? optional_output_tensor.value().get_layout() : layout.value_or(tensor.get_layout());
@@ -223,7 +222,7 @@ inline ttnn::Tensor full_like(
const T fill_value,
const std::optional<DataType>& dtype = std::nullopt,
const std::optional<Layout>& layout = std::nullopt,
detail::OptionalSimpleDevice device = std::nullopt,
detail::OptionalAnyDevice device = std::nullopt,
const std::optional<MemoryConfig>& memory_config = std::nullopt) {
return full_like_impl(ttnn::DefaultQueueId, tensor, fill_value, dtype, layout, device, memory_config, std::nullopt);
}
@@ -237,7 +236,7 @@ struct FullLikeWith {
const ttnn::Tensor& tensor,
const std::optional<DataType>& dtype = std::nullopt,
const std::optional<Layout>& layout = std::nullopt,
detail::OptionalSimpleDevice device = std::nullopt,
detail::OptionalAnyDevice device = std::nullopt,
const std::optional<MemoryConfig>& memory_config = std::nullopt,
std::optional<ttnn::Tensor> optional_output_tensor = std::nullopt) {
return full_like_impl(queue_id, tensor, fill_value, dtype, layout, device, memory_config, optional_output_tensor);
@@ -247,7 +246,7 @@ struct FullLikeWith {
const ttnn::Tensor& tensor,
const std::optional<DataType>& dtype = std::nullopt,
const std::optional<Layout>& layout = std::nullopt,
detail::OptionalSimpleDevice device = std::nullopt,
detail::OptionalAnyDevice device = std::nullopt,
const std::optional<MemoryConfig>& memory_config = std::nullopt,
std::optional<ttnn::Tensor> optional_output_tensor = std::nullopt) {
return full_like_impl(ttnn::DefaultQueueId, tensor, fill_value, dtype, layout, device, memory_config, optional_output_tensor);
@@ -265,9 +264,9 @@ struct Empty {
const ttnn::Shape& shape,
const DataType& dtype,
const Layout& layout,
ttnn::SimpleDevice device,
ttnn::AnyDevice device,
const MemoryConfig& memory_config) {
return allocate_tensor_on_workers(shape, dtype, layout, device.get_devices(), memory_config);
return allocate_tensor_on_devices(shape, dtype, layout, device.get_devices(), memory_config);
}
};

@@ -276,14 +275,14 @@ struct EmptyLike {
const ttnn::Tensor& tensor,
const std::optional<DataType>& dtype = std::nullopt,
const std::optional<Layout>& layout = std::nullopt,
detail::OptionalSimpleDevice device_arg = std::nullopt,
detail::OptionalAnyDevice device_arg = std::nullopt,
const std::optional<MemoryConfig>& memory_config = std::nullopt) {
const std::vector<Device*>& devices =
device_arg.has_value() ? device_arg->get_devices() : tensor.get_workers(/*blocking=*/true);
Layout layout_value = layout.value_or(tensor.get_layout());
DataType dtype_value = dtype.value_or(tensor.get_dtype());
MemoryConfig mem_cfg = memory_config.value_or(tensor.memory_config());
return allocate_tensor_on_workers(tensor.get_shape(), dtype_value, layout_value, devices, mem_cfg);
return allocate_tensor_on_devices(tensor.get_shape(), dtype_value, layout_value, devices, mem_cfg);
}
};

@@ -294,7 +293,7 @@ struct Full {
const float fill_value,
const std::optional<DataType>& dtype = std::nullopt,
const std::optional<Layout>& layout = std::nullopt,
detail::OptionalSimpleDevice device = std::nullopt,
detail::OptionalAnyDevice device = std::nullopt,
const std::optional<MemoryConfig>& memory_config = std::nullopt,
std::optional<ttnn::Tensor> optional_output_tensor = std::nullopt) {
return full_impl(
@@ -314,7 +313,7 @@ struct Full {
const int fill_value,
const std::optional<DataType>& dtype = std::nullopt,
const std::optional<Layout>& layout = std::nullopt,
detail::OptionalSimpleDevice device = std::nullopt,
detail::OptionalAnyDevice device = std::nullopt,
const std::optional<MemoryConfig>& memory_config = std::nullopt,
std::optional<ttnn::Tensor> optional_output_tensor = std::nullopt) {
return full_impl(
@@ -333,7 +332,7 @@ struct Full {
const float fill_value,
const std::optional<DataType>& dtype = std::nullopt,
const std::optional<Layout>& layout = std::nullopt,
detail::OptionalSimpleDevice device = std::nullopt,
detail::OptionalAnyDevice device = std::nullopt,
const std::optional<MemoryConfig>& memory_config = std::nullopt,
std::optional<ttnn::Tensor> optional_output_tensor = std::nullopt) {
return full_impl(
@@ -352,7 +351,7 @@ struct Full {
const int fill_value,
const std::optional<DataType>& dtype = std::nullopt,
const std::optional<Layout>& layout = std::nullopt,
detail::OptionalSimpleDevice device = std::nullopt,
detail::OptionalAnyDevice device = std::nullopt,
const std::optional<MemoryConfig>& memory_config = std::nullopt,
std::optional<ttnn::Tensor> optional_output_tensor = std::nullopt) {
return full_impl(
@@ -374,7 +373,7 @@ struct FullLike {
const float fill_value,
const std::optional<DataType>& dtype = std::nullopt,
const std::optional<Layout>& layout = std::nullopt,
detail::OptionalSimpleDevice device = std::nullopt,
detail::OptionalAnyDevice device = std::nullopt,
const std::optional<MemoryConfig>& memory_config = std::nullopt,
std::optional<ttnn::Tensor> optional_output_tensor = std::nullopt) {
return full_like_impl(queue_id, tensor, fill_value, dtype, layout, device, memory_config, optional_output_tensor);
@@ -386,7 +385,7 @@ struct FullLike {
const int fill_value,
const std::optional<DataType>& dtype = std::nullopt,
const std::optional<Layout>& layout = std::nullopt,
detail::OptionalSimpleDevice device = std::nullopt,
detail::OptionalAnyDevice device = std::nullopt,
const std::optional<MemoryConfig>& memory_config = std::nullopt,
std::optional<ttnn::Tensor> optional_output_tensor = std::nullopt) {
return full_like_impl(queue_id, tensor, fill_value, dtype, layout, device, memory_config, optional_output_tensor);
@@ -397,7 +396,7 @@ struct FullLike {
const float fill_value,
const std::optional<DataType>& dtype = std::nullopt,
const std::optional<Layout>& layout = std::nullopt,
detail::OptionalSimpleDevice device = std::nullopt,
detail::OptionalAnyDevice device = std::nullopt,
const std::optional<MemoryConfig>& memory_config = std::nullopt,
std::optional<ttnn::Tensor> optional_output_tensor = std::nullopt) {
return full_like_impl(ttnn::DefaultQueueId, tensor, fill_value, dtype, layout, device, memory_config, optional_output_tensor);
@@ -408,14 +407,14 @@ struct FullLike {
const int fill_value,
const std::optional<DataType>& dtype = std::nullopt,
const std::optional<Layout>& layout = std::nullopt,
detail::OptionalSimpleDevice device = std::nullopt,
detail::OptionalAnyDevice device = std::nullopt,
const std::optional<MemoryConfig>& memory_config = std::nullopt,
std::optional<ttnn::Tensor> optional_output_tensor = std::nullopt) {
return full_like_impl(ttnn::DefaultQueueId, tensor, fill_value, dtype, layout, device, memory_config, optional_output_tensor);
}
};

// TODO: #14974 - Onboard this API onto SimpleDevice.
// TODO: #14974 - Onboard this API onto AnyDevice.
struct Arange {
static ttnn::Tensor invoke(
const int64_t stop,
8 changes: 4 additions & 4 deletions ttnn/cpp/ttnn/tensor/tensor.cpp
@@ -834,15 +834,15 @@ void memcpy(Tensor& dst, const Tensor& src, const std::optional<std::size_t> tra
}
}

Tensor allocate_tensor_on_workers(
Tensor allocate_tensor_on_devices(
const ttnn::Shape& shape,
DataType data_type,
Layout layout,
const std::vector<Device*>& workers,
const std::vector<Device*>& devices,
const MemoryConfig& memory_config,
const std::optional<Tile>& tile) {
// Top level wrapper to asynchronously create a device tensor (single- or multi-device).
Tensor device_tensor = Tensor(workers);
Tensor device_tensor = Tensor(devices);
TensorSpec tensor_spec(
shape.logical_shape(),
TensorLayout::fromLegacyPaddedShape(data_type, PageConfig(layout, tile), memory_config, shape));
@@ -855,7 +855,7 @@ Tensor allocate_tensor_on_workers(
uint32_t num_workers = workers_in_use.size();

for (int worker_index = 0; worker_index < num_workers; ++worker_index) {
auto& worker = workers[worker_index];
auto& worker = devices[worker_index];
worker->push_work([worker, device_tensor, tensor_spec, worker_index]() mutable {
auto local_tensor = create_device_tensor(tensor_spec, worker);
insert_buffer_and_shape_for_device(worker, local_tensor, device_tensor, worker_index);
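Finally, a minimal sketch of calling the renamed tt::tt_metal::allocate_tensor_on_devices; the dtype and layout values are placeholders and the mesh_device object is assumed:

//   std::vector<Device*> devices = mesh_device->get_devices();   // or {single_device}
//   Tensor tensor = tt::tt_metal::allocate_tensor_on_devices(
//       shape,                       // ttnn::Shape
//       DataType::BFLOAT16,          // placeholder dtype
//       Layout::TILE,                // placeholder layout
//       devices,
//       ttnn::DRAM_MEMORY_CONFIG);   // tile argument left at its default (std::nullopt)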
