From 864df86a5e0d68923b0e567e79538b0856798383 Mon Sep 17 00:00:00 2001
From: Oleg Milyutin
Date: Tue, 3 Dec 2024 18:45:22 +0000
Subject: [PATCH] Move creation functions out of numpy/functions

---
 ttnn/cpp/ttnn/operations/creation.hpp        |  59 +++++++++-
 ttnn/cpp/ttnn/operations/numpy/functions.hpp | 111 +------------------
 ttnn/cpp/ttnn/tensor/types.hpp               |  19 ++++
 3 files changed, 78 insertions(+), 111 deletions(-)

diff --git a/ttnn/cpp/ttnn/operations/creation.hpp b/ttnn/cpp/ttnn/operations/creation.hpp
index acd2914c98f..b17e3f53bd0 100644
--- a/ttnn/cpp/ttnn/operations/creation.hpp
+++ b/ttnn/cpp/ttnn/operations/creation.hpp
@@ -100,6 +100,50 @@ Tensor create_scalar(T scalar, DataType data_type, Layout layout, Device* device
     }
 }
 
+template <typename T>
+static Tensor full(
+    uint8_t queue_id,
+    const tt::tt_metal::LegacyShape& shape,
+    T value,
+    const Layout layout,
+    const std::vector<Device*>& devices,
+    const MemoryConfig& output_mem_config,
+    std::optional<Tensor> optional_output_tensor) {
+    constexpr DataType data_type = tt::tt_metal::convert_to_data_type<T>();
+    TensorSpec tensor_spec(
+        shape.logical_shape(),
+        TensorLayout::fromLegacyPaddedShape(data_type, PageConfig(layout), MemoryConfig{}, shape));
+    auto owned_buffer = tt::tt_metal::owned_buffer::create<T>(tensor_spec.padded_shape().volume());
+    // TODO: 15061 - Generalize the header to support generic vector / view types.
+    std::fill(std::begin(owned_buffer), std::end(owned_buffer), value);
+
+    if (!optional_output_tensor.has_value()) {
+        auto output = Tensor(OwnedStorage{owned_buffer}, shape, data_type, layout);
+        if (!devices.empty()) {
+            output = output.to(devices, output_mem_config);
+        }
+        return output;
+    } else {
+        const auto buffers = optional_output_tensor->buffers();
+        const bool using_fast_dispatch = (std::getenv("TT_METAL_SLOW_DISPATCH_MODE") == nullptr);
+
+        for (auto* buffer : buffers) {
+            if (using_fast_dispatch) {
+                auto& cmd_queue = buffer->device()->command_queue(queue_id);
+                if (CommandQueue::default_mode() == CommandQueue::CommandQueueMode::ASYNC) {
+                    tt::tt_metal::EnqueueWriteBuffer(cmd_queue, *buffer, owned_buffer.get_ptr(), /*blocking=*/false);
+                } else {
+                    tt::tt_metal::EnqueueWriteBuffer(cmd_queue, *buffer, owned_buffer.data(), /*blocking=*/false);
+                }
+            } else {
+                tt::tt_metal::detail::WriteToBuffer(*buffer, owned_buffer.get());
+            }
+        }
+
+        return *optional_output_tensor;
+    }
+}
+
 template <typename T>
 inline ttnn::Tensor full_impl(
     uint8_t queue_id,
@@ -122,8 +166,19 @@ inline ttnn::Tensor full_impl(
     MemoryConfig mem_cfg = optional_output_tensor.has_value() ?
         optional_output_tensor.value().memory_config() : memory_config.value_or(ttnn::DRAM_MEMORY_CONFIG);
 
-    return numpy::full_impl(
-        queue_id, shape_value, fill_value, dtype_value, layout_value, workers, mem_cfg, optional_output_tensor);
+    auto concrete_full = [&]<typename FillValueType>(FillValueType concrete_fill_value) {
+        return full(
+            queue_id, shape_value, concrete_fill_value, layout_value, workers, mem_cfg, optional_output_tensor);
+    };
+
+    switch (dtype_value) {
+        case DataType::UINT8: return concrete_full(static_cast<uint8_t>(fill_value));
+        case DataType::UINT16: return concrete_full(static_cast<uint16_t>(fill_value));
+        case DataType::UINT32: return concrete_full(static_cast<uint32_t>(fill_value));
+        case DataType::FLOAT32: return concrete_full(static_cast<float>(fill_value));
+        case DataType::BFLOAT16: return concrete_full(static_cast<::bfloat16>(static_cast<float>(fill_value)));
+        default: TT_THROW("Unsupported DataType!");
+    }
 }
 
 template <typename T>
diff --git a/ttnn/cpp/ttnn/operations/numpy/functions.hpp b/ttnn/cpp/ttnn/operations/numpy/functions.hpp
index 8fa40f0c4e6..a03ed0918ec 100644
--- a/ttnn/cpp/ttnn/operations/numpy/functions.hpp
+++ b/ttnn/cpp/ttnn/operations/numpy/functions.hpp
@@ -26,113 +26,6 @@ using tt::tt_metal::MemoryConfig;
 using tt::tt_metal::OwnedStorage;
 using tt::tt_metal::StorageType;
 using tt::tt_metal::Tensor;
-namespace detail {
-
-template <typename T>
-constexpr static DataType get_data_type() {
-    if constexpr (std::is_same_v<T, uint8_t>) {
-        return DataType::UINT8;
-    } else if constexpr (std::is_same_v<T, uint16_t>) {
-        return DataType::UINT16;
-    } else if constexpr (std::is_same_v<T, int32_t>) {
-        return DataType::INT32;
-    } else if constexpr (std::is_same_v<T, uint32_t>) {
-        return DataType::UINT32;
-    } else if constexpr (std::is_same_v<T, float>) {
-        return DataType::FLOAT32;
-    } else if constexpr (std::is_same_v<T, ::bfloat16>) {
-        return DataType::BFLOAT16;
-    } else {
-        TT_THROW("Unsupported DataType!");
-    }
-}
-
-template <typename T>
-static Tensor full(
-    uint8_t queue_id,
-    const tt::tt_metal::LegacyShape& shape,
-    T value,
-    const Layout layout,
-    const std::vector<Device*>& devices,
-    const MemoryConfig& output_mem_config,
-    std::optional<Tensor> optional_output_tensor) {
-    constexpr DataType data_type = detail::get_data_type<T>();
-    TensorSpec tensor_spec(
-        shape.logical_shape(),
-        TensorLayout::fromLegacyPaddedShape(data_type, PageConfig(layout), MemoryConfig{}, shape));
-    auto owned_buffer = tt::tt_metal::owned_buffer::create<T>(tensor_spec.padded_shape().volume());
-    // TODO: 15061 - Generalize the header to support generic vector / view types.
-    std::fill(std::begin(owned_buffer), std::end(owned_buffer), value);
-
-    if (!optional_output_tensor.has_value()) {
-        auto output = Tensor(OwnedStorage{owned_buffer}, shape, data_type, layout);
-        if (!devices.empty()) {
-            output = output.to(devices, output_mem_config);
-        }
-        return output;
-    } else {
-        const auto buffers = optional_output_tensor->buffers();
-        const bool using_fast_dispatch = (std::getenv("TT_METAL_SLOW_DISPATCH_MODE") == nullptr);
-
-        for (auto* buffer : buffers) {
-            if (using_fast_dispatch) {
-                auto& cmd_queue = buffer->device()->command_queue(queue_id);
-                if (CommandQueue::default_mode() == CommandQueue::CommandQueueMode::ASYNC) {
-                    tt::tt_metal::EnqueueWriteBuffer(cmd_queue, *buffer, owned_buffer.get_ptr(), /*blocking=*/false);
-                } else {
-                    tt::tt_metal::EnqueueWriteBuffer(cmd_queue, *buffer, owned_buffer.data(), /*blocking=*/false);
-                }
-            } else {
-                tt::tt_metal::detail::WriteToBuffer(*buffer, owned_buffer.get());
-            }
-        }
-
-        return *optional_output_tensor;
-    }
-}
-
-}  // namespace detail
-
-template <typename T>
-static Tensor full_impl(
-    uint8_t queue_id,
-    const tt::tt_metal::LegacyShape& shape,
-    const T value,
-    const DataType data_type,
-    const Layout layout,
-    const std::vector<Device*>& devices,
-    const MemoryConfig& output_mem_config,
-    std::optional<Tensor> optional_output_tensor) {
-    switch (data_type) {
-        case DataType::UINT8: {
-            return detail::full<uint8_t>(
-                queue_id, shape, uint8_t(value), layout, devices, output_mem_config, optional_output_tensor);
-        }
-        case DataType::UINT16: {
-            return detail::full<uint16_t>(
-                queue_id, shape, uint16_t(value), layout, devices, output_mem_config, optional_output_tensor);
-        }
-        case DataType::UINT32: {
-            return detail::full<uint32_t>(
-                queue_id, shape, uint32_t(value), layout, devices, output_mem_config, optional_output_tensor);
-        }
-        case DataType::FLOAT32: {
-            return detail::full<float>(
-                queue_id, shape, float(value), layout, devices, output_mem_config, optional_output_tensor);
-        }
-        case DataType::BFLOAT16: {
-            return detail::full<::bfloat16>(
-                queue_id,
-                shape,
-                ::bfloat16(static_cast<float>(value)),
-                layout,
-                devices,
-                output_mem_config,
-                optional_output_tensor);
-        }
-        default: TT_THROW("Unsupported DataType!");
-    }
-}
 
 template <typename T>
 static Tensor arange(
@@ -143,7 +36,7 @@ static Tensor arange(
     Device* device = nullptr,
     const MemoryConfig& output_mem_config = MemoryConfig{
         .memory_layout = tt::tt_metal::TensorMemoryLayout::INTERLEAVED}) {
-    constexpr DataType data_type = detail::get_data_type<T>();
+    constexpr DataType data_type = tt::tt_metal::convert_to_data_type<T>();
     // Current implementation restrictions
     TT_ASSERT(step > 0, "Step must be greater than 0");
     TT_ASSERT(start < stop, "Start must be less than step");
@@ -628,7 +521,7 @@ static void seed(std::size_t seed) { RANDOM_GENERATOR = std::mt19937(seed); }
 
 template <typename T>
 static Tensor uniform(T low, T high, const tt::tt_metal::LegacyShape& shape, const Layout layout = Layout::ROW_MAJOR) {
-    constexpr DataType data_type = detail::get_data_type<T>();
+    constexpr DataType data_type = tt::tt_metal::convert_to_data_type<T>();
 
     auto owned_buffer = tt::tt_metal::owned_buffer::create<T>(tt::tt_metal::compute_volume(shape));
 
diff --git a/ttnn/cpp/ttnn/tensor/types.hpp b/ttnn/cpp/ttnn/tensor/types.hpp
index 3666c710113..40d5a224b57 100644
--- a/ttnn/cpp/ttnn/tensor/types.hpp
+++ b/ttnn/cpp/ttnn/tensor/types.hpp
@@ -41,6 +41,25 @@ enum class DataType {
     INVALID = 8,
 };
 
+template <typename T>
+consteval inline DataType convert_to_data_type() {
+    if constexpr (std::is_same_v<T, uint8_t>) {
+        return DataType::UINT8;
+    } else if constexpr (std::is_same_v<T, uint16_t>) {
+        return DataType::UINT16;
+    } else if constexpr (std::is_same_v<T, int32_t>) {
+        return DataType::INT32;
+    } else if constexpr (std::is_same_v<T, uint32_t>) {
+        return DataType::UINT32;
+    } else if constexpr (std::is_same_v<T, float>) {
+        return DataType::FLOAT32;
+    } else if constexpr (std::is_same_v<T, ::bfloat16>) {
+        return DataType::BFLOAT16;
+    } else {
+        static_assert(false, "Unsupported DataType!");
+    }
+}
+
 inline bool is_floating_point(DataType dtype) {
     switch (dtype) {
         case DataType::BFLOAT16: