diff --git a/tests/tt_eager/integration_tests/test_bert.cpp b/tests/tt_eager/integration_tests/test_bert.cpp index ccd6d2a2fa39..17b8dac197ed 100644 --- a/tests/tt_eager/integration_tests/test_bert.cpp +++ b/tests/tt_eager/integration_tests/test_bert.cpp @@ -2,22 +2,19 @@ // // SPDX-License-Identifier: Apache-2.0 -#include "tt_metal/common/constants.hpp" -#include "tensor/tensor.hpp" -#include "tensor/owned_buffer.hpp" -#include "tt_metal/host_api.hpp" - -#include "tt_numpy/functions.hpp" - -#include "tt_dnn/op_library/operation.hpp" +#include -#include "tt_dnn/op_library/bmm/bmm_op.hpp" +#include "tensor/host_buffer/types.hpp" +#include "tensor/tensor.hpp" #include "tt_dnn/op_library/bcast/bcast_op.hpp" -#include "tt_dnn/op_library/transformer_tms/transformer_tms.hpp" +#include "tt_dnn/op_library/bmm/bmm_op.hpp" #include "tt_dnn/op_library/layernorm/layernorm_op.hpp" +#include "tt_dnn/op_library/operation.hpp" #include "tt_dnn/op_library/softmax/softmax_op.hpp" - -#include +#include "tt_dnn/op_library/transformer_tms/transformer_tms.hpp" +#include "tt_metal/common/constants.hpp" +#include "tt_metal/host_api.hpp" +#include "tt_numpy/functions.hpp" using Parameters = std::map; diff --git a/tests/tt_eager/ops/test_eltwise_binary_op.cpp b/tests/tt_eager/ops/test_eltwise_binary_op.cpp index 9fbc09771eb5..bce5675294a2 100644 --- a/tests/tt_eager/ops/test_eltwise_binary_op.cpp +++ b/tests/tt_eager/ops/test_eltwise_binary_op.cpp @@ -3,8 +3,8 @@ // SPDX-License-Identifier: Apache-2.0 #include "common/constants.hpp" -#include "tensor/owned_buffer.hpp" -#include "tensor/owned_buffer_functions.hpp" +#include "tensor/host_buffer/functions.hpp" +#include "tensor/host_buffer/types.hpp" #include "tensor/tensor.hpp" #include "tt_dnn/op_library/eltwise_binary/eltwise_binary_op.hpp" #include "tt_numpy/functions.hpp" diff --git a/tests/tt_eager/ops/test_eltwise_unary_op.cpp b/tests/tt_eager/ops/test_eltwise_unary_op.cpp index 88cf0bf92553..2cadefcf1307 100644 --- a/tests/tt_eager/ops/test_eltwise_unary_op.cpp +++ b/tests/tt_eager/ops/test_eltwise_unary_op.cpp @@ -5,8 +5,8 @@ #include #include "common/constants.hpp" -#include "tensor/owned_buffer.hpp" -#include "tensor/owned_buffer_functions.hpp" +#include "tensor/host_buffer/functions.hpp" +#include "tensor/host_buffer/types.hpp" #include "tensor/tensor.hpp" #include "tt_dnn/op_library/eltwise_unary/eltwise_unary_op.hpp" #include "tt_dnn/op_library/operation.hpp" diff --git a/tests/tt_eager/ops/test_multi_queue_api.cpp b/tests/tt_eager/ops/test_multi_queue_api.cpp index b322c1dbc044..7b6a5c186810 100644 --- a/tests/tt_eager/ops/test_multi_queue_api.cpp +++ b/tests/tt_eager/ops/test_multi_queue_api.cpp @@ -6,8 +6,8 @@ #include "common/constants.hpp" #include "queue/queue.hpp" -#include "tensor/owned_buffer.hpp" -#include "tensor/owned_buffer_functions.hpp" +#include "tensor/host_buffer/functions.hpp" +#include "tensor/host_buffer/types.hpp" #include "tensor/tensor.hpp" #include "tt_dnn/op_library/eltwise_unary/eltwise_unary_op.hpp" #include "tt_dnn/op_library/operation.hpp" diff --git a/tests/tt_eager/ops/test_pad_op.cpp b/tests/tt_eager/ops/test_pad_op.cpp index afa349fa819e..d41a4b388ccf 100644 --- a/tests/tt_eager/ops/test_pad_op.cpp +++ b/tests/tt_eager/ops/test_pad_op.cpp @@ -5,15 +5,13 @@ #include #include "common/constants.hpp" +#include "tensor/host_buffer/types.hpp" #include "tensor/tensor.hpp" -#include "tensor/owned_buffer.hpp" +#include "tt_dnn/op_library/operation.hpp" +#include "tt_dnn/op_library/pad/pad_op.hpp" #include "tt_metal/host_api.hpp" - #include "tt_numpy/functions.hpp" -#include "tt_dnn/op_library/pad/pad_op.hpp" -#include "tt_dnn/op_library/operation.hpp" - using tt::tt_metal::DataType; using tt::tt_metal::Device; diff --git a/tests/tt_eager/ops/test_tilize_op.cpp b/tests/tt_eager/ops/test_tilize_op.cpp index f4a9108beb92..a06fe5780cbc 100644 --- a/tests/tt_eager/ops/test_tilize_op.cpp +++ b/tests/tt_eager/ops/test_tilize_op.cpp @@ -2,17 +2,17 @@ // // SPDX-License-Identifier: Apache-2.0 -#include "tt_metal/host_api.hpp" -#include "tensor/tensor.hpp" -#include "tensor/owned_buffer.hpp" -#include "tensor/owned_buffer_functions.hpp" -#include "tt_dnn/op_library/tilize/tilize_op.hpp" -#include "common/constants.hpp" -#include - #include #include #include +#include + +#include "common/constants.hpp" +#include "tensor/host_buffer/functions.hpp" +#include "tensor/host_buffer/types.hpp" +#include "tensor/tensor.hpp" +#include "tt_dnn/op_library/tilize/tilize_op.hpp" +#include "tt_metal/host_api.hpp" using namespace tt; using namespace tt_metal; diff --git a/tests/tt_eager/ops/test_tilize_op_channels_last.cpp b/tests/tt_eager/ops/test_tilize_op_channels_last.cpp index 1234a1e2b01f..b3983aff6e31 100644 --- a/tests/tt_eager/ops/test_tilize_op_channels_last.cpp +++ b/tests/tt_eager/ops/test_tilize_op_channels_last.cpp @@ -2,19 +2,18 @@ // // SPDX-License-Identifier: Apache-2.0 -#include "tt_metal/host_api.hpp" -#include "tensor/tensor.hpp" -#include "tensor/owned_buffer.hpp" -#include "tensor/owned_buffer_functions.hpp" -#include "tensor/owned_buffer_functions.hpp" -#include "tt_dnn/op_library/tilize/tilize_op.hpp" -#include "common/constants.hpp" -#include "tt_numpy/functions.hpp" - #include #include #include +#include "common/constants.hpp" +#include "tensor/host_buffer/functions.hpp" +#include "tensor/host_buffer/types.hpp" +#include "tensor/tensor.hpp" +#include "tt_dnn/op_library/tilize/tilize_op.hpp" +#include "tt_metal/host_api.hpp" +#include "tt_numpy/functions.hpp" + using namespace tt; using namespace tt_metal; using namespace constants; diff --git a/tests/tt_eager/ops/test_tilize_zero_padding.cpp b/tests/tt_eager/ops/test_tilize_zero_padding.cpp index bcf60908d56c..6e8c73a1609c 100644 --- a/tests/tt_eager/ops/test_tilize_zero_padding.cpp +++ b/tests/tt_eager/ops/test_tilize_zero_padding.cpp @@ -2,18 +2,18 @@ // // SPDX-License-Identifier: Apache-2.0 -#include "tt_metal/host_api.hpp" -#include "tensor/tensor.hpp" -#include "tensor/owned_buffer.hpp" -#include "tensor/owned_buffer_functions.hpp" -#include "tt_dnn/op_library/tilize/tilize_op.hpp" -#include "common/constants.hpp" -#include "tt_numpy/functions.hpp" - #include #include #include +#include "common/constants.hpp" +#include "tensor/host_buffer/functions.hpp" +#include "tensor/host_buffer/types.hpp" +#include "tensor/tensor.hpp" +#include "tt_dnn/op_library/tilize/tilize_op.hpp" +#include "tt_metal/host_api.hpp" +#include "tt_numpy/functions.hpp" + using namespace tt; using namespace tt_metal; using namespace constants; diff --git a/tests/tt_eager/ops/test_tilize_zero_padding_channels_last.cpp b/tests/tt_eager/ops/test_tilize_zero_padding_channels_last.cpp index 3742e8743e78..4862c3ec50f9 100644 --- a/tests/tt_eager/ops/test_tilize_zero_padding_channels_last.cpp +++ b/tests/tt_eager/ops/test_tilize_zero_padding_channels_last.cpp @@ -2,18 +2,18 @@ // // SPDX-License-Identifier: Apache-2.0 -#include "tt_metal/host_api.hpp" -#include "tensor/tensor.hpp" -#include "tensor/owned_buffer.hpp" -#include "tensor/owned_buffer_functions.hpp" -#include "tt_dnn/op_library/tilize/tilize_op.hpp" -#include "common/constants.hpp" -#include "tt_numpy/functions.hpp" - #include #include #include +#include "common/constants.hpp" +#include "tensor/host_buffer/functions.hpp" +#include "tensor/host_buffer/types.hpp" +#include "tensor/tensor.hpp" +#include "tt_dnn/op_library/tilize/tilize_op.hpp" +#include "tt_metal/host_api.hpp" +#include "tt_numpy/functions.hpp" + using namespace tt; using namespace tt_metal; using namespace constants; diff --git a/tests/tt_eager/ops/test_transpose_wh_multi_core.cpp b/tests/tt_eager/ops/test_transpose_wh_multi_core.cpp index 8e5a972d6c6b..1131f21891a7 100644 --- a/tests/tt_eager/ops/test_transpose_wh_multi_core.cpp +++ b/tests/tt_eager/ops/test_transpose_wh_multi_core.cpp @@ -2,16 +2,16 @@ // // SPDX-License-Identifier: Apache-2.0 -#include "tt_metal/host_api.hpp" -#include "tensor/tensor.hpp" -#include "tensor/owned_buffer.hpp" -#include "tensor/owned_buffer_functions.hpp" -#include "tt_dnn/op_library/transpose/transpose_op.hpp" -#include - #include #include #include +#include + +#include "tensor/host_buffer/functions.hpp" +#include "tensor/host_buffer/types.hpp" +#include "tensor/tensor.hpp" +#include "tt_dnn/op_library/transpose/transpose_op.hpp" +#include "tt_metal/host_api.hpp" using namespace tt; using namespace tt_metal; diff --git a/tests/tt_eager/ops/test_transpose_wh_single_core.cpp b/tests/tt_eager/ops/test_transpose_wh_single_core.cpp index 8e5a972d6c6b..1131f21891a7 100644 --- a/tests/tt_eager/ops/test_transpose_wh_single_core.cpp +++ b/tests/tt_eager/ops/test_transpose_wh_single_core.cpp @@ -2,16 +2,16 @@ // // SPDX-License-Identifier: Apache-2.0 -#include "tt_metal/host_api.hpp" -#include "tensor/tensor.hpp" -#include "tensor/owned_buffer.hpp" -#include "tensor/owned_buffer_functions.hpp" -#include "tt_dnn/op_library/transpose/transpose_op.hpp" -#include - #include #include #include +#include + +#include "tensor/host_buffer/functions.hpp" +#include "tensor/host_buffer/types.hpp" +#include "tensor/tensor.hpp" +#include "tt_dnn/op_library/transpose/transpose_op.hpp" +#include "tt_metal/host_api.hpp" using namespace tt; using namespace tt_metal; diff --git a/tests/tt_eager/tensors/test_async_tensor_apis.cpp b/tests/tt_eager/tensors/test_async_tensor_apis.cpp index 127830f0b3b7..14f6f8a526ad 100644 --- a/tests/tt_eager/tensors/test_async_tensor_apis.cpp +++ b/tests/tt_eager/tensors/test_async_tensor_apis.cpp @@ -2,24 +2,23 @@ // // SPDX-License-Identifier: Apache-2.0 -#include "tensor/types.hpp" -#include "tt_metal/host_api.hpp" +#include +#include +#include +#include + +#include "common/bfloat16.hpp" +#include "common/constants.hpp" +#include "tensor/host_buffer/functions.hpp" +#include "tensor/host_buffer/types.hpp" #include "tensor/tensor.hpp" #include "tensor/tensor_impl.hpp" -#include "tensor/owned_buffer.hpp" -#include "tensor/owned_buffer_functions.hpp" +#include "tensor/types.hpp" #include "tests/tt_metal/tt_metal/unit_tests_common/common/common_fixture.hpp" #include "tt_dnn/op_library/eltwise_binary/eltwise_binary_op.hpp" -#include "common/bfloat16.hpp" -#include "common/constants.hpp" - +#include "tt_metal/host_api.hpp" #include "tt_numpy/functions.hpp" -#include -#include -#include -#include - using namespace tt; using namespace tt_metal; using namespace constants; diff --git a/tests/tt_eager/tensors/test_copy_and_move.cpp b/tests/tt_eager/tensors/test_copy_and_move.cpp index a0f65de48327..347609dec87b 100644 --- a/tests/tt_eager/tensors/test_copy_and_move.cpp +++ b/tests/tt_eager/tensors/test_copy_and_move.cpp @@ -2,20 +2,19 @@ // // SPDX-License-Identifier: Apache-2.0 -#include "tt_metal/host_api.hpp" -#include "tensor/tensor.hpp" -#include "tensor/tensor_impl.hpp" -#include "tensor/owned_buffer.hpp" -#include "tensor/owned_buffer_functions.hpp" -#include "common/bfloat16.hpp" -#include "common/constants.hpp" - -#include "tt_numpy/functions.hpp" - #include #include #include +#include "common/bfloat16.hpp" +#include "common/constants.hpp" +#include "tensor/host_buffer/functions.hpp" +#include "tensor/host_buffer/types.hpp" +#include "tensor/tensor.hpp" +#include "tensor/tensor_impl.hpp" +#include "tt_metal/host_api.hpp" +#include "tt_numpy/functions.hpp" + using namespace tt; using namespace tt_metal; using namespace constants; diff --git a/tests/tt_eager/tensors/test_host_device_loopback.cpp b/tests/tt_eager/tensors/test_host_device_loopback.cpp index 8c21ecb7d9cc..a490d7e3a328 100644 --- a/tests/tt_eager/tensors/test_host_device_loopback.cpp +++ b/tests/tt_eager/tensors/test_host_device_loopback.cpp @@ -2,18 +2,18 @@ // // SPDX-License-Identifier: Apache-2.0 -#include "tt_metal/host_api.hpp" -#include "tensor/tensor.hpp" -#include "tensor/owned_buffer.hpp" -#include "tensor/owned_buffer_functions.hpp" -#include "tt_dnn/op_library/eltwise_binary/eltwise_binary_op.hpp" -#include "common/constants.hpp" -#include "tt_numpy/functions.hpp" - #include #include #include +#include "common/constants.hpp" +#include "tensor/host_buffer/functions.hpp" +#include "tensor/host_buffer/types.hpp" +#include "tensor/tensor.hpp" +#include "tt_dnn/op_library/eltwise_binary/eltwise_binary_op.hpp" +#include "tt_metal/host_api.hpp" +#include "tt_numpy/functions.hpp" + using namespace tt; using namespace tt_metal; using namespace constants; diff --git a/tests/tt_eager/tensors/test_ranks.cpp b/tests/tt_eager/tensors/test_ranks.cpp index 78cf668c3e59..ccd355b96eb0 100644 --- a/tests/tt_eager/tensors/test_ranks.cpp +++ b/tests/tt_eager/tensors/test_ranks.cpp @@ -2,20 +2,19 @@ // // SPDX-License-Identifier: Apache-2.0 -#include "tt_metal/host_api.hpp" -#include "tensor/tensor.hpp" -#include "tensor/tensor_impl.hpp" -#include "tensor/owned_buffer.hpp" -#include "tensor/owned_buffer_functions.hpp" -#include "common/bfloat16.hpp" -#include "common/constants.hpp" - -#include "tt_numpy/functions.hpp" - #include #include #include +#include "common/bfloat16.hpp" +#include "common/constants.hpp" +#include "tensor/host_buffer/functions.hpp" +#include "tensor/host_buffer/types.hpp" +#include "tensor/tensor.hpp" +#include "tensor/tensor_impl.hpp" +#include "tt_metal/host_api.hpp" +#include "tt_numpy/functions.hpp" + using namespace tt; using namespace tt_metal; using namespace constants; diff --git a/tests/tt_eager/tensors/test_raw_host_memory_pointer.cpp b/tests/tt_eager/tensors/test_raw_host_memory_pointer.cpp index 787c404e1aae..ddce6e4e81a3 100644 --- a/tests/tt_eager/tensors/test_raw_host_memory_pointer.cpp +++ b/tests/tt_eager/tensors/test_raw_host_memory_pointer.cpp @@ -8,8 +8,8 @@ #include "common/bfloat16.hpp" #include "common/constants.hpp" -#include "tensor/owned_buffer.hpp" -#include "tensor/owned_buffer_functions.hpp" +#include "tensor/host_buffer/functions.hpp" +#include "tensor/host_buffer/types.hpp" #include "tensor/tensor.hpp" #include "tensor/tensor_impl.hpp" #include "tt_dnn/op_library/eltwise_binary/eltwise_binary_op.hpp" diff --git a/tt_eager/tensor/borrowed_buffer.hpp b/tt_eager/tensor/borrowed_buffer.hpp deleted file mode 100644 index 03f1c87ad123..000000000000 --- a/tt_eager/tensor/borrowed_buffer.hpp +++ /dev/null @@ -1,64 +0,0 @@ -// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. -// -// SPDX-License-Identifier: Apache-2.0 - -#pragma once - -#include - -namespace tt { - -namespace tt_metal { - -namespace borrowed_buffer { - -template -struct Buffer { - explicit Buffer() = default; - explicit Buffer(T* data_ptr, std::size_t size) : - data_ptr_(data_ptr), - size_(size) {} - - const std::size_t size() const { return this->size_; } - - inline T& operator[](std::size_t index) noexcept { return this->data_ptr_[index]; } - inline const T& operator[](std::size_t index) const noexcept { return this->data_ptr_[index]; } - - inline T* begin() noexcept { return this->data_ptr_; } - inline T* end() noexcept { return this->data_ptr_ + this->size(); } - - inline const T* begin() const noexcept { return this->data_ptr_; } - inline const T* end() const noexcept { return this->data_ptr_ + this->size(); } - - inline void* data() noexcept { return static_cast(this->data_ptr_); } - inline const void* data() const noexcept { return static_cast(this->data_ptr_); } - private: - T* data_ptr_; - std::size_t size_; -}; - - -template -bool operator==(const Buffer& buffer_a, const Buffer& buffer_b) noexcept { - if (buffer_a.size() != buffer_b.size()) { - return false; - } - for (auto index = 0; index < buffer_a.size(); index++) { - if (buffer_a[index] != buffer_b[index]) { - return false; - } - } - return true; -} - - -template -bool operator!=(const Buffer& buffer_a, const Buffer& buffer_b) noexcept { - return not (buffer_a == buffer_b); -} - -} // namespace borrowed_buffer - -} // namespace tt_metal - -} // namespace tt diff --git a/tt_eager/tensor/borrowed_buffer_functions.hpp b/tt_eager/tensor/borrowed_buffer_functions.hpp deleted file mode 100644 index 61fc436425a1..000000000000 --- a/tt_eager/tensor/borrowed_buffer_functions.hpp +++ /dev/null @@ -1,79 +0,0 @@ -// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. -// -// SPDX-License-Identifier: Apache-2.0 - -#pragma once - -#include "tensor/tensor.hpp" -#include "tensor/borrowed_buffer.hpp" - -#include - -namespace tt { - -namespace tt_metal { - -namespace borrowed_buffer { - -template -void validate_datatype(const Tensor& tensor) { - if constexpr (std::is_same_v) { - TT_FATAL(tensor.get_dtype() == DataType::UINT32); - } else if constexpr (std::is_same_v) { - TT_FATAL(tensor.get_dtype() == DataType::INT32); - } else if constexpr (std::is_same_v) { - TT_FATAL(tensor.get_dtype() == DataType::FLOAT32); - } else if constexpr (std::is_same_v) { - TT_FATAL(tensor.get_dtype() == DataType::BFLOAT16); - } else if constexpr (std::is_same_v) { - TT_FATAL(tensor.get_dtype() == DataType::UINT16); - } else { - static_assert(tt::stl::concepts::always_false_v, "Unsupported DataType"); - } -} - -template -Buffer get_as(BorrowedBuffer& buffer) { - return std::get>(buffer); -} - -template -const Buffer get_as(const BorrowedBuffer& buffer) { - return std::get>(buffer); -} - -template -Buffer get_as(Tensor& tensor) { - validate_datatype(tensor); - return std::visit( - [](auto&& storage) -> Buffer { - using StorageType = std::decay_t; - if constexpr (std::is_same_v) { - return get_as(storage.buffer); - } else { - TT_THROW("Tensor must have BorrowedStorage"); - } - }, - tensor.get_storage()); -} - -template -const Buffer get_as(const Tensor& tensor) { - validate_datatype(tensor); - return std::visit( - [](auto&& storage) -> Buffer { - using StorageType = std::decay_t; - if constexpr (std::is_same_v) { - return get_as(storage.buffer); - } else { - TT_THROW("Tensor must have BorrowedStorage"); - } - }, - tensor.get_storage()); -} - -} // namespace borrowed_buffer - -} // namespace tt_metal - -} // namespace tt diff --git a/tt_eager/tensor/host_buffer/functions.hpp b/tt_eager/tensor/host_buffer/functions.hpp new file mode 100644 index 000000000000..47626ab2134c --- /dev/null +++ b/tt_eager/tensor/host_buffer/functions.hpp @@ -0,0 +1,215 @@ +// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include + +#include "tensor/host_buffer/types.hpp" +#include "tensor/tensor.hpp" + +namespace tt { + +namespace tt_metal { + +namespace borrowed_buffer { + +template +void validate_datatype(const Tensor& tensor) { + if constexpr (std::is_same_v) { + TT_FATAL(tensor.get_dtype() == DataType::UINT32); + } else if constexpr (std::is_same_v) { + TT_FATAL(tensor.get_dtype() == DataType::INT32); + } else if constexpr (std::is_same_v) { + TT_FATAL(tensor.get_dtype() == DataType::FLOAT32); + } else if constexpr (std::is_same_v) { + TT_FATAL(tensor.get_dtype() == DataType::BFLOAT16); + } else if constexpr (std::is_same_v) { + TT_FATAL(tensor.get_dtype() == DataType::UINT16); + } else { + static_assert(tt::stl::concepts::always_false_v, "Unsupported DataType"); + } +} + +template +Buffer get_as(BorrowedBuffer& buffer) { + return std::get>(buffer); +} + +template +const Buffer get_as(const BorrowedBuffer& buffer) { + return std::get>(buffer); +} + +template +Buffer get_as(Tensor& tensor) { + validate_datatype(tensor); + return std::visit( + [](auto&& storage) -> Buffer { + using StorageType = std::decay_t; + if constexpr (std::is_same_v) { + return get_as(storage.buffer); + } else { + TT_THROW("Tensor must have BorrowedStorage"); + } + }, + tensor.get_storage()); +} + +template +const Buffer get_as(const Tensor& tensor) { + validate_datatype(tensor); + return std::visit( + [](auto&& storage) -> const Buffer { + using StorageType = std::decay_t; + if constexpr (std::is_same_v) { + return get_as(storage.buffer); + } else { + TT_THROW("Tensor must have BorrowedStorage"); + } + }, + tensor.get_storage()); +} + +} // namespace borrowed_buffer + +namespace owned_buffer { + +template +Buffer create(std::vector&& storage) { + return Buffer{std::make_shared>(std::forward>(storage))}; +} + +template +Buffer create(std::size_t size) { + return create(std::vector(size, 0)); +} + +template +void validate_datatype(const Tensor& tensor) { + if constexpr (std::is_same_v) { + TT_FATAL( + tensor.get_dtype() == DataType::UINT32 or tensor.get_dtype() == DataType::BFLOAT8_B or + tensor.get_dtype() == DataType::BFLOAT4_B); + } else if constexpr (std::is_same_v) { + TT_FATAL(tensor.get_dtype() == DataType::INT32); + } else if constexpr (std::is_same_v) { + TT_FATAL(tensor.get_dtype() == DataType::FLOAT32); + } else if constexpr (std::is_same_v) { + TT_FATAL(tensor.get_dtype() == DataType::BFLOAT16); + } else if constexpr (std::is_same_v) { + TT_FATAL(tensor.get_dtype() == DataType::UINT16); + } else { + static_assert(tt::stl::concepts::always_false_v, "Unsupported DataType"); + } +} + +template +Buffer get_as(OwnedBuffer& buffer) { + return std::get>(buffer); +} + +template +const Buffer get_as(const OwnedBuffer& buffer) { + return std::get>(buffer); +} + +template +Buffer get_as(Tensor& tensor) { + validate_datatype(tensor); + return std::visit( + [](auto&& storage) -> Buffer { + using StorageType = std::decay_t; + if constexpr (std::is_same_v) { + return get_as(storage.buffer); + } else { + TT_THROW("Tensor must have OwnedStorage"); + } + }, + tensor.get_storage()); +} + +template +const Buffer get_as(const Tensor& tensor) { + validate_datatype(tensor); + return std::visit( + [](auto&& storage) -> const Buffer { + using StorageType = std::decay_t; + if constexpr (std::is_same_v) { + return get_as(storage.buffer); + } else { + TT_THROW("Tensor must have OwnedStorage"); + } + }, + tensor.get_storage()); +} + +} // namespace owned_buffer + +namespace host_buffer { + +template +borrowed_buffer::Buffer get_as(OwnedBuffer& buffer) { + auto& owned_buffer = std::get>(buffer); + return borrowed_buffer::Buffer(owned_buffer.begin(), owned_buffer.size()); +} + +template +const borrowed_buffer::Buffer get_as(const OwnedBuffer& buffer) { + auto owned_buffer = std::get>(buffer); + return borrowed_buffer::Buffer(owned_buffer.begin(), owned_buffer.size()); +} + +template +borrowed_buffer::Buffer get_as(OwnedBuffer&& buffer) = delete; +template +borrowed_buffer::Buffer get_as(const OwnedBuffer&& buffer) = delete; + +template +borrowed_buffer::Buffer get_as(BorrowedBuffer& buffer) { + return borrowed_buffer::get_as(buffer); +} + +template +const borrowed_buffer::Buffer get_as(const BorrowedBuffer& buffer) { + return borrowed_buffer::get_as(buffer); +} + +template +borrowed_buffer::Buffer get_as(Tensor& tensor) { + return std::visit( + [](auto&& storage) -> borrowed_buffer::Buffer { + using StorageType = std::decay_t; + if constexpr (std::is_same_v) { + return get_as(storage.buffer); + } else if constexpr (std::is_same_v) { + return get_as(storage.buffer); + } else { + TT_THROW("Tensor must have OwnedStorage or BorrowedStorage"); + } + }, + tensor.get_storage()); +} + +template +const borrowed_buffer::Buffer get_as(const Tensor& tensor) { + return std::visit( + [](auto&& storage) -> const borrowed_buffer::Buffer { + using StorageType = std::decay_t; + if constexpr (std::is_same_v) { + return get_as(storage.buffer); + } else if constexpr (std::is_same_v) { + return get_as(storage.buffer); + } else { + TT_THROW("Tensor must have OwnedStorage or BorrowedStorage"); + } + }, + tensor.get_storage()); +} + +} // namespace host_buffer + +} // namespace tt_metal + +} // namespace tt diff --git a/tt_eager/tensor/owned_buffer.hpp b/tt_eager/tensor/host_buffer/types.hpp similarity index 59% rename from tt_eager/tensor/owned_buffer.hpp rename to tt_eager/tensor/host_buffer/types.hpp index 076e16f4d2e2..8dc74daab0fd 100644 --- a/tt_eager/tensor/owned_buffer.hpp +++ b/tt_eager/tensor/host_buffer/types.hpp @@ -11,9 +11,55 @@ namespace tt { namespace tt_metal { +namespace borrowed_buffer { + +template +struct Buffer { + explicit Buffer() = default; + explicit Buffer(T* data_ptr, std::size_t size) : data_ptr_(data_ptr), size_(size) {} + + const std::size_t size() const { return this->size_; } + + inline T& operator[](std::size_t index) noexcept { return this->data_ptr_[index]; } + inline const T& operator[](std::size_t index) const noexcept { return this->data_ptr_[index]; } + + inline T* begin() noexcept { return this->data_ptr_; } + inline T* end() noexcept { return this->data_ptr_ + this->size(); } + + inline const T* begin() const noexcept { return this->data_ptr_; } + inline const T* end() const noexcept { return this->data_ptr_ + this->size(); } + + inline void* data() noexcept { return static_cast(this->data_ptr_); } + inline const void* data() const noexcept { return static_cast(this->data_ptr_); } + + private: + T* data_ptr_; + std::size_t size_; +}; + +template +bool operator==(const Buffer& buffer_a, const Buffer& buffer_b) noexcept { + if (buffer_a.size() != buffer_b.size()) { + return false; + } + for (auto index = 0; index < buffer_a.size(); index++) { + if (buffer_a[index] != buffer_b[index]) { + return false; + } + } + return true; +} + +template +bool operator!=(const Buffer& buffer_a, const Buffer& buffer_b) noexcept { + return not(buffer_a == buffer_b); +} + +} // namespace borrowed_buffer + namespace owned_buffer { -template +template struct Buffer { explicit Buffer() = default; explicit Buffer(std::shared_ptr>&& shared_vector) : @@ -36,7 +82,7 @@ struct Buffer { inline const T* begin() const noexcept { return this->pointer_for_faster_access_; } inline const T* end() const noexcept { return this->pointer_for_faster_access_ + this->size(); } - inline bool is_allocated() const{ return bool(this->shared_vector_); } + inline bool is_allocated() const { return bool(this->shared_vector_); } inline const std::vector& get() const { return *this->shared_vector_; } inline const std::shared_ptr> get_ptr() const noexcept { return this->shared_vector_; } inline void reset() { this->shared_vector_.reset(); } @@ -44,14 +90,14 @@ struct Buffer { inline void* data() noexcept { return static_cast(this->pointer_for_faster_access_); } inline const void* data() const noexcept { return static_cast(this->pointer_for_faster_access_); } inline uint32_t use_count() const noexcept { return this->shared_vector_.use_count(); } + private: std::shared_ptr> shared_vector_; T* pointer_for_faster_access_; std::size_t size_; }; - -template +template bool operator==(const Buffer& buffer_a, const Buffer& buffer_b) noexcept { if (buffer_a.size() != buffer_b.size()) { return false; @@ -64,10 +110,9 @@ bool operator==(const Buffer& buffer_a, const Buffer& buffer_b) noexcept { return true; } - -template +template bool operator!=(const Buffer& buffer_a, const Buffer& buffer_b) noexcept { - return not (buffer_a == buffer_b); + return not(buffer_a == buffer_b); } } // namespace owned_buffer diff --git a/tt_eager/tensor/owned_buffer_functions.hpp b/tt_eager/tensor/owned_buffer_functions.hpp deleted file mode 100644 index af39652acbdd..000000000000 --- a/tt_eager/tensor/owned_buffer_functions.hpp +++ /dev/null @@ -1,89 +0,0 @@ -// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. -// -// SPDX-License-Identifier: Apache-2.0 - -#pragma once - -#include "tensor/tensor.hpp" -#include "tensor/owned_buffer.hpp" - -#include - -namespace tt { - -namespace tt_metal { - -namespace owned_buffer { - -template -Buffer create(std::vector&& storage) { - return Buffer{std::make_shared>(std::forward>(storage))}; -} - -template -Buffer create(std::size_t size) { - return create(std::vector(size, 0)); -} - -template -void validate_datatype(const Tensor& tensor) { - if constexpr (std::is_same_v) { - TT_FATAL(tensor.get_dtype() == DataType::UINT32 or tensor.get_dtype() == DataType::BFLOAT8_B or tensor.get_dtype() == DataType::BFLOAT4_B); - } else if constexpr (std::is_same_v) { - TT_FATAL(tensor.get_dtype() == DataType::INT32); - } else if constexpr (std::is_same_v) { - TT_FATAL(tensor.get_dtype() == DataType::FLOAT32); - } else if constexpr (std::is_same_v) { - TT_FATAL(tensor.get_dtype() == DataType::BFLOAT16); - } else if constexpr (std::is_same_v) { - TT_FATAL(tensor.get_dtype() == DataType::UINT16); - } else { - static_assert(tt::stl::concepts::always_false_v, "Unsupported DataType"); - } -} - -template -Buffer get_as(OwnedBuffer& buffer) { - return std::get>(buffer); -} - -template -const Buffer get_as(const OwnedBuffer& buffer) { - return std::get>(buffer); -} - -template -Buffer get_as(Tensor& tensor) { - validate_datatype(tensor); - return std::visit( - [](auto&& storage) -> Buffer { - using StorageType = std::decay_t; - if constexpr (std::is_same_v) { - return get_as(storage.buffer); - } else { - TT_THROW("Tensor must have OwnedStorage"); - } - }, - tensor.get_storage()); -} - -template -const Buffer get_as(const Tensor& tensor) { - validate_datatype(tensor); - return std::visit( - [](auto&& storage) -> Buffer { - using StorageType = std::decay_t; - if constexpr (std::is_same_v) { - return get_as(storage.buffer); - } else { - TT_THROW("Tensor must have OwnedStorage"); - } - }, - tensor.get_storage()); -} - -} // namespace owned_buffer - -} // namespace tt_metal - -} // namespace tt diff --git a/tt_eager/tensor/serialization.cpp b/tt_eager/tensor/serialization.cpp index d0e21ceb9d03..16c6a049be07 100644 --- a/tt_eager/tensor/serialization.cpp +++ b/tt_eager/tensor/serialization.cpp @@ -10,8 +10,7 @@ #include #include -#include "tensor/borrowed_buffer_functions.hpp" -#include "tensor/owned_buffer_functions.hpp" +#include "tensor/host_buffer/functions.hpp" #include "tensor/tensor_utils.hpp" namespace tt { diff --git a/tt_eager/tensor/tensor_impl.hpp b/tt_eager/tensor/tensor_impl.hpp index 25c3a235b7c5..025ad977b3f9 100644 --- a/tt_eager/tensor/tensor_impl.hpp +++ b/tt_eager/tensor/tensor_impl.hpp @@ -6,8 +6,9 @@ #include #include -#include "tensor/borrowed_buffer_functions.hpp" -#include "tensor/owned_buffer_functions.hpp" +#include "common/bfloat4.hpp" +#include "common/bfloat8.hpp" +#include "tensor/host_buffer/functions.hpp" #include "tensor/tensor.hpp" #include "tensor/tensor_impl_wrapper.hpp" #include "tensor/tensor_utils.hpp" @@ -18,9 +19,6 @@ #include "tt_metal/third_party/tracy/public/tracy/Tracy.hpp" #include "tt_stl/concepts.hpp" -#include "common/bfloat4.hpp" -#include "common/bfloat8.hpp" - namespace tt { namespace tt_metal { @@ -322,8 +320,8 @@ inline DeviceBuffer to_device_buffer( if (memory_config.is_sharded()) { TT_ASSERT(shard_spec.has_value(), "If sharded must provide shard_spec"); } - if constexpr (std::is_same_v) { - auto data_to_write = owned_buffer::get_as(storage.buffer); + if constexpr (std::is_same_v or std::is_same_v) { + auto data_to_write = host_buffer::get_as(storage.buffer); TT_ASSERT( compute_buffer_size(shape, data_type) == data_to_write.size(), fmt::format( @@ -339,28 +337,6 @@ inline DeviceBuffer to_device_buffer( data_to_write, device, shape, data_type, layout, memory_config, shard_spec); } else if constexpr (std::is_same_v) { TT_THROW("Device storage doesn't support to_device_buffer"); - } else if constexpr (std::is_same_v) { - if constexpr ( - std::is_same_v or std::is_same_v or std::is_same_v or - std::is_same_v or std::is_same_v) { - auto data_to_write = borrowed_buffer::get_as(storage.buffer); - TT_ASSERT( - compute_buffer_size(shape, data_type) == data_to_write.size(), - fmt::format( - "Tensor buffer size and number of data elements does not match: {} != {}", - compute_buffer_size(shape, data_type), - data_to_write.size())); - if (layout == Layout::TILE) { - TT_ASSERT( - (shape[-2] % tt::constants::TILE_HEIGHT == 0 && shape[-1] % tt::constants::TILE_WIDTH == 0), - "Tensor shape incompatible for specified layout"); - } - return initialize_data_on_device( - data_to_write, device, shape, data_type, layout, memory_config, shard_spec); - - } else { - TT_THROW("Borrowed storage doesn't support this data type"); - } } else if constexpr (std::is_same_v) { TT_THROW("MultiHostStorage storage doesn't support to_device_buffer"); } else if constexpr (std::is_same_v) { diff --git a/tt_eager/tensor/tensor_utils.cpp b/tt_eager/tensor/tensor_utils.cpp index cd913e0675c3..2cdc16150863 100644 --- a/tt_eager/tensor/tensor_utils.cpp +++ b/tt_eager/tensor/tensor_utils.cpp @@ -3,8 +3,9 @@ // SPDX-License-Identifier: Apache-2.0 #include "tensor/tensor_utils.hpp" -#include "tensor/owned_buffer.hpp" -#include "tensor/owned_buffer_functions.hpp" + +#include "tensor/host_buffer/functions.hpp" +#include "tensor/host_buffer/types.hpp" namespace tt { diff --git a/tt_eager/tensor/types.hpp b/tt_eager/tensor/types.hpp index 0332081201ad..23e5b3c952ab 100644 --- a/tt_eager/tensor/types.hpp +++ b/tt_eager/tensor/types.hpp @@ -11,8 +11,7 @@ #include #include "common/bfloat16.hpp" -#include "tensor/borrowed_buffer.hpp" -#include "tensor/owned_buffer.hpp" +#include "tensor/host_buffer/types.hpp" #include "tt_metal/impl/buffers/buffer.hpp" #include "tt_metal/impl/device/device.hpp" #include "tt_metal/tt_stl/concepts.hpp" diff --git a/tt_eager/tt_dnn/op_library/backward/backward_ops.hpp b/tt_eager/tt_dnn/op_library/backward/backward_ops.hpp index 1621f8988772..f5a4ddce68b6 100644 --- a/tt_eager/tt_dnn/op_library/backward/backward_ops.hpp +++ b/tt_eager/tt_dnn/op_library/backward/backward_ops.hpp @@ -4,14 +4,12 @@ #pragma once #include -#include "tt_metal/common/constants.hpp" -#include "tt_dnn/op_library/bcast/bcast_op.hpp" +#include "tensor/host_buffer/functions.hpp" #include "tensor/tensor.hpp" #include "tensor/tensor_utils.hpp" -#include "tensor/owned_buffer_functions.hpp" - - +#include "tt_dnn/op_library/bcast/bcast_op.hpp" +#include "tt_metal/common/constants.hpp" namespace tt { diff --git a/tt_eager/tt_dnn/op_library/composite/composite_ops.hpp b/tt_eager/tt_dnn/op_library/composite/composite_ops.hpp index b71219b61b69..795205290283 100644 --- a/tt_eager/tt_dnn/op_library/composite/composite_ops.hpp +++ b/tt_eager/tt_dnn/op_library/composite/composite_ops.hpp @@ -5,13 +5,13 @@ #pragma once #include -#include "tt_metal/common/constants.hpp" -#include "tensor/owned_buffer_functions.hpp" +#include "tensor/host_buffer/functions.hpp" #include "tensor/tensor.hpp" #include "tensor/tensor_utils.hpp" #include "tt_dnn/op_library/bcast/bcast_op.hpp" #include "tt_dnn/op_library/eltwise_binary/eltwise_binary_op.hpp" #include "tt_dnn/op_library/eltwise_unary/eltwise_unary_op.hpp" +#include "tt_metal/common/constants.hpp" namespace tt { diff --git a/tt_eager/tt_dnn/op_library/pad/pad_op.cpp b/tt_eager/tt_dnn/op_library/pad/pad_op.cpp index cfd9b3cb07e5..c449ceb60d5d 100644 --- a/tt_eager/tt_dnn/op_library/pad/pad_op.cpp +++ b/tt_eager/tt_dnn/op_library/pad/pad_op.cpp @@ -3,13 +3,13 @@ // SPDX-License-Identifier: Apache-2.0 #include "tt_dnn/op_library/pad/pad_op.hpp" + +#include "tensor/host_buffer/functions.hpp" #include "tt_dnn/op_library/copy/copy_op.hpp" #include "tt_dnn/op_library/math.hpp" - -#include "tt_metal/host_api.hpp" #include "tt_metal/common/constants.hpp" #include "tt_metal/detail/util.hpp" -#include "tensor/owned_buffer_functions.hpp" +#include "tt_metal/host_api.hpp" using namespace tt::constants; diff --git a/tt_eager/tt_dnn/op_library/pad/pad_op_multi_core.cpp b/tt_eager/tt_dnn/op_library/pad/pad_op_multi_core.cpp index dc9cd6b1b0af..42133fb83bdd 100644 --- a/tt_eager/tt_dnn/op_library/pad/pad_op_multi_core.cpp +++ b/tt_eager/tt_dnn/op_library/pad/pad_op_multi_core.cpp @@ -2,13 +2,12 @@ // // SPDX-License-Identifier: Apache-2.0 -#include "tt_dnn/op_library/pad/pad_op.hpp" +#include "tensor/host_buffer/functions.hpp" #include "tt_dnn/op_library/math.hpp" - -#include "tt_metal/host_api.hpp" +#include "tt_dnn/op_library/pad/pad_op.hpp" #include "tt_metal/common/constants.hpp" #include "tt_metal/detail/util.hpp" -#include "tensor/owned_buffer_functions.hpp" +#include "tt_metal/host_api.hpp" using namespace tt::constants; diff --git a/tt_eager/tt_dnn/op_library/pool/max_pool.cpp b/tt_eager/tt_dnn/op_library/pool/max_pool.cpp index 68d6ee025f12..594c23504d3b 100644 --- a/tt_eager/tt_dnn/op_library/pool/max_pool.cpp +++ b/tt_eager/tt_dnn/op_library/pool/max_pool.cpp @@ -2,16 +2,17 @@ // // SPDX-License-Identifier: Apache-2.0 +#include "tt_dnn/op_library/pool/max_pool.hpp" + #include #include -#include "tt_dnn/op_library/pool/max_pool.hpp" -#include "tt_dnn/op_library/reduce/reduce_op.hpp" // for reduce_op_utils +#include "detail/util.hpp" +#include "tensor/host_buffer/functions.hpp" +#include "tensor/tensor_utils.hpp" +#include "tt_dnn/op_library/reduce/reduce_op.hpp" // for reduce_op_utils #include "tt_dnn/op_library/work_split.hpp" #include "tt_metal/host_api.hpp" -#include "tensor/tensor_utils.hpp" -#include "tensor/owned_buffer_functions.hpp" -#include "detail/util.hpp" namespace tt { namespace tt_metal { diff --git a/tt_eager/tt_dnn/op_library/pool/max_pool_multi_core.cpp b/tt_eager/tt_dnn/op_library/pool/max_pool_multi_core.cpp index dc34b2c4f6f5..0b264eef7f15 100644 --- a/tt_eager/tt_dnn/op_library/pool/max_pool_multi_core.cpp +++ b/tt_eager/tt_dnn/op_library/pool/max_pool_multi_core.cpp @@ -5,15 +5,15 @@ #include #include +#include "detail/util.hpp" +#include "tensor/host_buffer/functions.hpp" +#include "tensor/tensor_utils.hpp" #include "tt_dnn/op_library/pool/max_pool.hpp" -#include "tt_dnn/op_library/reduce/reduce_op.hpp" // for reduce_op_utils -#include "tt_dnn/op_library/work_split.hpp" +#include "tt_dnn/op_library/reduce/reduce_op.hpp" // for reduce_op_utils #include "tt_dnn/op_library/sharding_utilities.hpp" #include "tt_dnn/op_library/sliding_window_op_infra/utils.hpp" +#include "tt_dnn/op_library/work_split.hpp" #include "tt_metal/host_api.hpp" -#include "tensor/tensor_utils.hpp" -#include "tensor/owned_buffer_functions.hpp" -#include "detail/util.hpp" namespace tt { namespace tt_metal { diff --git a/tt_eager/tt_dnn/op_library/pool/max_pool_single_core.cpp b/tt_eager/tt_dnn/op_library/pool/max_pool_single_core.cpp index 0d0779041122..f8bc1933a676 100644 --- a/tt_eager/tt_dnn/op_library/pool/max_pool_single_core.cpp +++ b/tt_eager/tt_dnn/op_library/pool/max_pool_single_core.cpp @@ -5,13 +5,13 @@ #include #include +#include "detail/util.hpp" +#include "tensor/host_buffer/functions.hpp" +#include "tensor/tensor_utils.hpp" #include "tt_dnn/op_library/pool/max_pool.hpp" -#include "tt_dnn/op_library/reduce/reduce_op.hpp" // for reduce_op_utils +#include "tt_dnn/op_library/reduce/reduce_op.hpp" // for reduce_op_utils #include "tt_dnn/op_library/work_split.hpp" #include "tt_metal/host_api.hpp" -#include "tensor/tensor_utils.hpp" -#include "tensor/owned_buffer_functions.hpp" -#include "detail/util.hpp" namespace tt { namespace tt_metal { diff --git a/tt_eager/tt_dnn/op_library/untilize/untilize_with_halo_op.cpp b/tt_eager/tt_dnn/op_library/untilize/untilize_with_halo_op.cpp index 183be80c2d4b..4c9ec98b7f29 100644 --- a/tt_eager/tt_dnn/op_library/untilize/untilize_with_halo_op.cpp +++ b/tt_eager/tt_dnn/op_library/untilize/untilize_with_halo_op.cpp @@ -4,15 +4,14 @@ #include - +#include "tensor/host_buffer/functions.hpp" +#include "tt_dnn/op_library/math.hpp" +#include "tt_dnn/op_library/sharding_utilities.hpp" #include "tt_dnn/op_library/untilize/untilize_op.hpp" #include "tt_dnn/op_library/work_split.hpp" -#include "tt_dnn/op_library/sharding_utilities.hpp" -#include "tt_dnn/op_library/math.hpp" -#include "tt_metal/host_api.hpp" #include "tt_metal/common/constants.hpp" #include "tt_metal/detail/util.hpp" -#include "tensor/owned_buffer_functions.hpp" +#include "tt_metal/host_api.hpp" using namespace tt::constants; diff --git a/tt_eager/tt_dnn/op_library/untilize/untilize_with_halo_op_v2.cpp b/tt_eager/tt_dnn/op_library/untilize/untilize_with_halo_op_v2.cpp index 8186457bac2c..6b1506951813 100644 --- a/tt_eager/tt_dnn/op_library/untilize/untilize_with_halo_op_v2.cpp +++ b/tt_eager/tt_dnn/op_library/untilize/untilize_with_halo_op_v2.cpp @@ -6,7 +6,7 @@ #include -#include "tensor/owned_buffer_functions.hpp" +#include "tensor/host_buffer/functions.hpp" #include "tt_dnn/op_library/math.hpp" #include "tt_dnn/op_library/sharding_utilities.hpp" #include "tt_dnn/op_library/sliding_window_op_infra/utils.hpp" diff --git a/tt_eager/tt_dnn/op_library/upsample/upsample_op.cpp b/tt_eager/tt_dnn/op_library/upsample/upsample_op.cpp index 734a1e277d52..67bb4124c410 100644 --- a/tt_eager/tt_dnn/op_library/upsample/upsample_op.cpp +++ b/tt_eager/tt_dnn/op_library/upsample/upsample_op.cpp @@ -2,17 +2,18 @@ // // SPDX-License-Identifier: Apache-2.0 +#include "tt_dnn/op_library/upsample/upsample_op.hpp" + #include #include -#include "tt_dnn/op_library/upsample/upsample_op.hpp" +#include "detail/util.hpp" +#include "tensor/host_buffer/functions.hpp" +#include "tensor/tensor_utils.hpp" #include "tt_dnn/op_library/pool/max_pool.hpp" -#include "tt_dnn/op_library/reduce/reduce_op.hpp" // for reduce_op_utils +#include "tt_dnn/op_library/reduce/reduce_op.hpp" // for reduce_op_utils #include "tt_dnn/op_library/work_split.hpp" #include "tt_metal/host_api.hpp" -#include "tensor/tensor_utils.hpp" -#include "tensor/owned_buffer_functions.hpp" -#include "detail/util.hpp" namespace tt { namespace tt_metal { diff --git a/tt_eager/tt_lib/csrc/tt_lib_bindings_tensor.cpp b/tt_eager/tt_lib/csrc/tt_lib_bindings_tensor.cpp index a7b2ec5b2636..c28addfb81ed 100644 --- a/tt_eager/tt_lib/csrc/tt_lib_bindings_tensor.cpp +++ b/tt_eager/tt_lib/csrc/tt_lib_bindings_tensor.cpp @@ -2,38 +2,38 @@ // // SPDX-License-Identifier: Apache-2.0 -#include "tt_lib_bindings.hpp" -#include "tt_dnn/op_library/downsample/downsample_op.hpp" -#include "tt_dnn/op_library/eltwise_binary/eltwise_binary_op.hpp" +#include "tt_lib_bindings_tensor.hpp" + +#include "tensor/host_buffer/types.hpp" +#include "tensor/serialization.hpp" +#include "tensor/tensor_impl.hpp" +#include "tensor/tensor_utils.hpp" +#include "tt_dnn/op_library/auto_format.hpp" +#include "tt_dnn/op_library/compute_kernel_config.hpp" #include "tt_dnn/op_library/conv/conv_op.hpp" #include "tt_dnn/op_library/conv/optimized_conv_op.hpp" -#include "tt_dnn/op_library/softmax/softmax_op.hpp" -#include "tt_dnn/op_library/upsample/upsample_op.hpp" +#include "tt_dnn/op_library/downsample/downsample_op.hpp" +#include "tt_dnn/op_library/eltwise_binary/eltwise_binary_op.hpp" +#include "tt_dnn/op_library/eltwise_unary/eltwise_unary_op.hpp" +#include "tt_dnn/op_library/embeddings/embeddings_op.hpp" +#include "tt_dnn/op_library/fully_connected/fully_connected_op.hpp" #include "tt_dnn/op_library/groupnorm/groupnorm_op.hpp" +#include "tt_dnn/op_library/layernorm/layernorm_op.hpp" #include "tt_dnn/op_library/pool/average_pool.hpp" #include "tt_dnn/op_library/pool/max_pool.hpp" -#include "tt_dnn/op_library/fully_connected/fully_connected_op.hpp" -#include "tt_dnn/op_library/layernorm/layernorm_op.hpp" -#include "tt_dnn/op_library/eltwise_unary/eltwise_unary_op.hpp" -#include "tt_dnn/op_library/auto_format.hpp" -#include "tt_dnn/op_library/split/split_last_dim_two_chunks_tiled.hpp" -#include "tt_dnn/op_library/scan/scan_op.hpp" -#include "tt_dnn/op_library/rotate_half/rotate_half_op.hpp" +#include "tt_dnn/op_library/reduce/reduce_op.hpp" #include "tt_dnn/op_library/rotary_embedding/rotary_embedding_op.hpp" -#include "tt_eager/tt_dnn/op_library/loss/loss_op.hpp" -#include "tt_dnn/op_library/embeddings/embeddings_op.hpp" +#include "tt_dnn/op_library/rotate_half/rotate_half_op.hpp" +#include "tt_dnn/op_library/scan/scan_op.hpp" +#include "tt_dnn/op_library/softmax/softmax_op.hpp" +#include "tt_dnn/op_library/split/split_last_dim_two_chunks_tiled.hpp" #include "tt_dnn/op_library/update_cache/update_cache_op.hpp" -#include "tt_dnn/op_library/reduce/reduce_op.hpp" +#include "tt_dnn/op_library/upsample/upsample_op.hpp" #include "tt_dnn/op_library/work_split.hpp" -#include "tensor/owned_buffer.hpp" -#include "tensor/borrowed_buffer.hpp" -#include "tensor/tensor_impl.hpp" -#include "tensor/tensor_utils.hpp" -#include "tensor/serialization.hpp" -#include "type_caster.hpp" +#include "tt_eager/tt_dnn/op_library/loss/loss_op.hpp" +#include "tt_lib_bindings.hpp" #include "tt_lib_bindings_tensor_impl.hpp" -#include "tt_lib_bindings_tensor.hpp" -#include "tt_dnn/op_library/compute_kernel_config.hpp" +#include "type_caster.hpp" namespace tt::tt_metal{ diff --git a/tt_eager/tt_lib/csrc/tt_lib_bindings_tensor_pytensor.cpp b/tt_eager/tt_lib/csrc/tt_lib_bindings_tensor_pytensor.cpp index 7bace7d4c9d6..acc7d8581e86 100644 --- a/tt_eager/tt_lib/csrc/tt_lib_bindings_tensor_pytensor.cpp +++ b/tt_eager/tt_lib/csrc/tt_lib_bindings_tensor_pytensor.cpp @@ -5,8 +5,7 @@ #include #include -#include "tensor/borrowed_buffer.hpp" -#include "tensor/owned_buffer.hpp" +#include "tensor/host_buffer/types.hpp" #include "tensor/tensor_impl.hpp" #include "tt_dnn/op_library/run_operation.hpp" #include "tt_lib_bindings_tensor.hpp" diff --git a/tt_eager/tt_numpy/functions.hpp b/tt_eager/tt_numpy/functions.hpp index 94bf502a9e2d..cbc35b4ad102 100644 --- a/tt_eager/tt_numpy/functions.hpp +++ b/tt_eager/tt_numpy/functions.hpp @@ -7,8 +7,8 @@ #include #include #include -#include -#include +#include +#include #include #include #include diff --git a/tt_metal/detail/util.hpp b/tt_metal/detail/util.hpp index 7e0723653bb4..994b52af5991 100644 --- a/tt_metal/detail/util.hpp +++ b/tt_metal/detail/util.hpp @@ -3,8 +3,9 @@ // SPDX-License-Identifier: Apache-2.0 #pragma once -#include "tt_metal/common/tt_backend_api_types.hpp" #include "tt_metal/common/math.hpp" +#include "tt_metal/common/tt_backend_api_types.hpp" +#include "tt_metal/hostdevcommon/common_values.hpp" #include "tt_metal/impl/kernels/data_types.hpp" namespace tt::tt_metal::detail{ diff --git a/ttnn/cpp/ttnn/op_library/binary/binary_op.hpp b/ttnn/cpp/ttnn/op_library/binary/binary_op.hpp index fe0135582af4..48e1e4194271 100644 --- a/ttnn/cpp/ttnn/op_library/binary/binary_op.hpp +++ b/ttnn/cpp/ttnn/op_library/binary/binary_op.hpp @@ -9,7 +9,7 @@ #include "tensor/tensor.hpp" #include "third_party/magic_enum/magic_enum.hpp" -#include "tt_eager/tensor/owned_buffer_functions.hpp" +#include "tt_eager/tensor/host_buffer/functions.hpp" #include "tt_eager/tensor/tensor_utils.hpp" #include "tt_eager/tt_dnn/op_library/compute_kernel_config.hpp" #include "tt_eager/tt_dnn/op_library/eltwise_binary/eltwise_binary_op.hpp" diff --git a/ttnn/cpp/ttnn/op_library/to_dtype/to_dtype_op.hpp b/ttnn/cpp/ttnn/op_library/to_dtype/to_dtype_op.hpp new file mode 100644 index 000000000000..53a675ce68c7 --- /dev/null +++ b/ttnn/cpp/ttnn/op_library/to_dtype/to_dtype_op.hpp @@ -0,0 +1,120 @@ +// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include +#include + +#include "tensor/tensor.hpp" +#include "ttnn/operations/core.hpp" +#include "ttnn/types.hpp" + +namespace ttnn { + +namespace operations { + +namespace core { + +namespace detail { + +inline Tensor convert_to_cpp_dtypes(const Tensor& input_tensor) { + auto input_dtype = input_tensor.get_dtype(); + + auto buffer = std::visit( + [](auto&& storage) -> std::variant { + using T = std::decay_t; + if constexpr (std::is_same_v) { + return storage.buffer; + } else if constexpr (std::is_same_v) { + TT_THROW("Device input_tensor cannot be converted to torch"); + } else if constexpr (std::is_same_v) { + return storage.buffer; + } else if constexpr (std::is_same_v) { + TT_THROW("Tensor with MultiDeviceStorage cannot be converted to torch"); + } else if constexpr (std::is_same_v) { + TT_THROW( + "Tensor MultiDeviceHostStorage cannot be converted to torch directly. Use composer(..) " + "functionality."); + } else { + raise_unsupported_storage(); + } + }, + input_tensor.get_storage()); + + if (input_dtype == DataType::BFLOAT8_B) { + auto uint32_data = std::get>(std::get(buffer)).get(); + auto float_unpacked_data = + unpack_bfp8_tiles_into_float_vec(uint32_data, /*row_major_output=*/false, /*is_exp_a=*/false); + buffer = owned_buffer::create(std::move(float_unpacked_data)); + input_dtype = DataType::FLOAT32; + } else if (input_dtype == DataType::BFLOAT4_B) { + auto uint32_data = std::get>(std::get(buffer)).get(); + auto float_unpacked_data = + unpack_bfp4_tiles_into_float_vec(uint32_data, /*row_major_output=*/false, /*is_exp_a=*/false); + buffer = owned_buffer::create(std::move(float_unpacked_data)); + input_dtype = DataType::FLOAT32; + } else { + TT_THROW("Unsupported input data type"); + } + + return std::visit( + [&](auto&& buffer) -> Tensor { + using T = std::decay_t; + if constexpr (std::is_same_v) { + return Tensor{OwnedStorage{buffer}, input_tensor.get_shape(), input_dtype, input_tensor.get_layout()}; + } else if constexpr (std::is_same_v) { + return Tensor{ + BorrowedStorage{buffer, []() {}, []() {}}, + input_tensor.get_shape(), + input_dtype, + input_tensor.get_layout()}; + } else { + TT_THROW("Unsupported buffer type"); + } + }, + buffer); +} + +} // namespace detail + +struct ToDtype { + static inline const std::array input_tensor_schemas() { + return {ttnn::TensorSchema{ + 1, + 8, + {ttnn::bfloat16, ttnn::bfloat8_b, ttnn::bfloat4_b, ttnn::float32, ttnn::uint16, ttnn::uint32, ttnn::int32}, + {ttnn::ROW_MAJOR_LAYOUT, ttnn::TILE_LAYOUT}, + true, + true, + false, + false}}; + } + + template + static auto input_tensors_to_validate(const Tensor& tensor_arg, Args&&... args) { + return std::make_tuple(tensor_arg); + }; + + // TODO: Move to cpp once we merge with tt_eager + static Tensor execute(const ttnn::Tensor& input_tensor, const ttnn::DataType& dtype) { + auto input_layout = input_tensor.get_layout(); + auto input_dtype = input_tensor.get_dtype(); + + if (input_dtype == dtype) { + return input_tensor; + } + + if (input_layout != ttnn::ROW_MAJOR_LAYOUT) { + TT_THROW("Only ROW_MAJOR_LAYOUT is supported"); + } + + auto output_tensor = input_tensor; + return output_tensor; + }; +}; + +} // namespace core +} // namespace operations +} // namespace ttnn diff --git a/ttnn/cpp/ttnn/op_library/to_layout/to_layout_op.hpp b/ttnn/cpp/ttnn/op_library/to_layout/to_layout_op.hpp index 801b1f1a8a0b..d61a8468500a 100644 --- a/ttnn/cpp/ttnn/op_library/to_layout/to_layout_op.hpp +++ b/ttnn/cpp/ttnn/op_library/to_layout/to_layout_op.hpp @@ -9,7 +9,7 @@ #include "tensor/tensor.hpp" #include "third_party/magic_enum/magic_enum.hpp" -#include "tt_eager/tensor/owned_buffer_functions.hpp" +#include "tt_eager/tensor/host_buffer/functions.hpp" #include "tt_eager/tensor/tensor_utils.hpp" #include "tt_eager/tt_dnn/op_library/compute_kernel_config.hpp" #include "tt_eager/tt_dnn/op_library/run_operation.hpp" diff --git a/ttnn/cpp/ttnn/op_library/to_memory_config/to_memory_config_op.hpp b/ttnn/cpp/ttnn/op_library/to_memory_config/to_memory_config_op.hpp index 7f283806e330..24edc6f0855c 100644 --- a/ttnn/cpp/ttnn/op_library/to_memory_config/to_memory_config_op.hpp +++ b/ttnn/cpp/ttnn/op_library/to_memory_config/to_memory_config_op.hpp @@ -9,7 +9,7 @@ #include "tensor/tensor.hpp" // #include "third_party/magic_enum/magic_enum.hpp" -// #include "tt_eager/tensor/owned_buffer_functions.hpp" +// #include "tt_eager/tensor/host_buffer/functions.hpp" // #include "tt_eager/tensor/tensor_utils.hpp" #include "tt_eager/tt_dnn/op_library/compute_kernel_config.hpp" #include "tt_eager/tt_dnn/op_library/copy/copy_op.hpp" @@ -31,7 +31,7 @@ struct ToMemoryConfig { static inline const std::array input_tensor_schemas() { return {ttnn::TensorSchema{ 1, - 4, + 8, {ttnn::bfloat16, ttnn::bfloat8_b, ttnn::bfloat4_b, ttnn::float32, ttnn::uint16, ttnn::uint32, ttnn::int32}, {ttnn::ROW_MAJOR_LAYOUT, ttnn::TILE_LAYOUT}, true, @@ -47,9 +47,7 @@ struct ToMemoryConfig { // TODO: Move to cpp once we merge with tt_eager static Tensor execute( - const ttnn::Tensor tensor, - const ttnn::MemoryConfig& memory_config, - std::optional dtype) { + const ttnn::Tensor& tensor, const ttnn::MemoryConfig& memory_config, std::optional dtype) { // Temporary until we see why buffer data not being populated const auto original_shape = tensor.get_shape(); diff --git a/ttnn/cpp/ttnn/operations/core.hpp b/ttnn/cpp/ttnn/operations/core.hpp index 911dd40b3b04..85d4e03d06fa 100644 --- a/ttnn/cpp/ttnn/operations/core.hpp +++ b/ttnn/cpp/ttnn/operations/core.hpp @@ -17,6 +17,7 @@ #include "ttnn/core.hpp" #include "ttnn/decorators.hpp" #include "ttnn/op_library/to_layout/to_layout_op.hpp" +#include "ttnn/op_library/to_dtype/to_dtype_op.hpp" #include "ttnn/op_library/to_memory_config/to_memory_config_op.hpp" #include "ttnn/types.hpp" #include "ttnn/validation.hpp" @@ -188,6 +189,8 @@ using operations::core::squeeze_from_4D; using operations::core::to_device; using operations::core::unsqueeze_to_4D; +constexpr auto to_dtype = ttnn::register_operation("ttnn::to_dtype"); constexpr auto to_memory_config = ttnn::register_operation("ttnn::to_memory_config"); constexpr auto to_layout = ttnn::register_operation("ttnn::to_layout"); + } // namespace ttnn