diff --git a/tests/ttnn/unit_tests/test_to_dtype.py b/tests/ttnn/unit_tests/test_to_dtype.py
new file mode 100644
index 000000000000..112359c4389d
--- /dev/null
+++ b/tests/ttnn/unit_tests/test_to_dtype.py
@@ -0,0 +1,39 @@
+# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
+
+# SPDX-License-Identifier: Apache-2.0
+
+import pytest
+
+import torch
+
+import ttnn
+
+from tests.ttnn.utils_for_testing import assert_with_pcc
+
+
+@pytest.mark.parametrize("height", [32])
+@pytest.mark.parametrize("width", [32])
+@pytest.mark.parametrize("from_dtype", [ttnn.float32, ttnn.bfloat16])
+@pytest.mark.parametrize("to_dtype", [ttnn.bfloat16, ttnn.float32, ttnn.bfloat8_b])
+def test_to_dtype(height, width, from_dtype, to_dtype):
+    torch_input_tensor = torch.rand((height, width), dtype=torch.bfloat16)
+
+    input_tensor = ttnn.from_torch(torch_input_tensor)
+    assert input_tensor.layout == ttnn.ROW_MAJOR_LAYOUT
+
+    input_tensor = ttnn.to_dtype(input_tensor, from_dtype)
+    assert input_tensor.dtype == from_dtype
+    assert input_tensor.layout == ttnn.ROW_MAJOR_LAYOUT
+    assert tuple(input_tensor.shape) == (height, width)
+
+    output_tensor = ttnn.to_dtype(input_tensor, to_dtype)
+    assert output_tensor.dtype == to_dtype
+    if to_dtype == ttnn.bfloat8_b:
+        assert output_tensor.layout == ttnn.TILE_LAYOUT
+    else:
+        assert output_tensor.layout == ttnn.ROW_MAJOR_LAYOUT
+    assert tuple(output_tensor.shape) == (height, width)
+
+    output_tensor = ttnn.to_torch(output_tensor).to(torch_input_tensor.dtype)
+
+    assert_with_pcc(torch_input_tensor, output_tensor)
diff --git a/tt_eager/tensor/host_buffer/functions.hpp b/tt_eager/tensor/host_buffer/functions.hpp
index 47626ab2134c..17c57b253f01 100644
--- a/tt_eager/tensor/host_buffer/functions.hpp
+++ b/tt_eager/tensor/host_buffer/functions.hpp
@@ -182,9 +182,9 @@ borrowed_buffer::Buffer<T> get_as(Tensor& tensor) {
         [](auto&& storage) -> borrowed_buffer::Buffer<T> {
             using StorageType = std::decay_t<decltype(storage)>;
             if constexpr (std::is_same_v<StorageType, OwnedStorage>) {
-                return get_as<T>(storage.buffer);
+                return host_buffer::get_as<T>(storage.buffer);
             } else if constexpr (std::is_same_v<StorageType, BorrowedStorage>) {
-                return get_as<T>(storage.buffer);
+                return host_buffer::get_as<T>(storage.buffer);
             } else {
                 TT_THROW("Tensor must have OwnedStorage or BorrowedStorage");
             }
@@ -198,9 +198,9 @@ const borrowed_buffer::Buffer<T> get_as(const Tensor& tensor) {
         [](auto&& storage) -> const borrowed_buffer::Buffer<T> {
             using StorageType = std::decay_t<decltype(storage)>;
             if constexpr (std::is_same_v<StorageType, OwnedStorage>) {
-                return get_as<T>(storage.buffer);
+                return host_buffer::get_as<T>(storage.buffer);
            } else if constexpr (std::is_same_v<StorageType, BorrowedStorage>) {
-                return get_as<T>(storage.buffer);
+                return host_buffer::get_as<T>(storage.buffer);
             } else {
                 TT_THROW("Tensor must have OwnedStorage or BorrowedStorage");
             }
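For context, the round trip exercised by the new test above can be reproduced outside pytest. A minimal sketch, assuming a local ttnn build; the dtype and layout behavior shown in the comments is exactly what the test asserts:

```python
import torch
import ttnn

torch_input = torch.rand((32, 32), dtype=torch.bfloat16)

tensor = ttnn.from_torch(torch_input)           # host tensor, ROW_MAJOR_LAYOUT
tensor = ttnn.to_dtype(tensor, ttnn.bfloat8_b)  # block-float needs tiles, so layout becomes TILE_LAYOUT
assert tensor.layout == ttnn.TILE_LAYOUT

round_tripped = ttnn.to_torch(tensor).to(torch_input.dtype)
print(torch.allclose(torch_input, round_tripped, atol=1e-1))  # bfp8_b is lossy: close, not bit-equal
```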
diff --git a/ttnn/cpp/pybind11/operations/core.hpp b/ttnn/cpp/pybind11/operations/core.hpp
index d045816bdc7b..5ca96b60e3fa 100644
--- a/ttnn/cpp/pybind11/operations/core.hpp
+++ b/ttnn/cpp/pybind11/operations/core.hpp
@@ -87,6 +87,19 @@ void py_module(py::module& module) {
 
     module.def("deallocate", &ttnn::operations::core::deallocate, py::arg("tensor"), py::arg("force") = true);
 
+    module.def(
+        "reallocate",
+        [](ttnn::Tensor& input_tensor, const std::optional<ttnn::MemoryConfig>& memory_config = std::nullopt)
+            -> ttnn::Tensor { return reallocate(input_tensor, memory_config); },
+        py::arg("tensor"),
+        py::arg("memory_config") = std::nullopt,
+        R"doc(
+        Deallocates the device tensor and returns a reallocated tensor
+
+        Args:
+            * :attr:`input_tensor`: Input Tensor
+        )doc");
+
     bind_registered_operation(
         module,
         ttnn::to_memory_config,
@@ -106,29 +119,25 @@ void py_module(py::module& module) {
             >>> tensor = ttnn.to_device(ttnn.from_torch(torch.randn((10, 64, 32), dtype=torch.bfloat16)), device)
             >>> tensor = ttnn.to_memory_config(tensor, memory_config)
         )doc",
-        ttnn::pybind_overload_t{
-            [](const std::decay_t<decltype(ttnn::to_memory_config)> self,
-               const ttnn::Tensor& tensor,
-               const ttnn::MemoryConfig& memory_config,
-               const std::optional<ttnn::DataType>& dtype) -> ttnn::Tensor {
-                return self(tensor, memory_config, dtype);
-            },
-            py::arg("tensor"),
-            py::arg("memory_config"),
-            py::arg("dtype") = std::nullopt});
+        ttnn::pybind_arguments_t{py::arg("tensor"), py::arg("memory_config"), py::arg("dtype") = std::nullopt});
 
-    module.def(
-        "reallocate",
-        [](ttnn::Tensor& input_tensor, const std::optional<ttnn::MemoryConfig>& memory_config = std::nullopt)
-            -> ttnn::Tensor { return reallocate(input_tensor, memory_config); },
-        py::arg("tensor"),
-        py::arg("memory_config") = std::nullopt,
-        R"doc(
-Deallocates device tensor and returns a reallocated tensor
+    bind_registered_operation(
+        module,
+        ttnn::to_dtype,
+        R"doc(to_dtype(tensor: ttnn.Tensor, dtype: DataType) -> ttnn.Tensor
-
-Args:
-    * :attr:`input_tensor`: Input Tensor
-    )doc");
+
+        Converts a tensor to the desired dtype
+
+
+        Args:
+            * :attr:`tensor`: the ttnn.Tensor
+            * :attr:`dtype`: `ttnn` data type.
+
+        Example::
+            >>> tensor = ttnn.from_torch(torch.randn((10, 64, 32), dtype=torch.bfloat16))
+            >>> tensor = ttnn.to_dtype(tensor, dtype=ttnn.uint16)
+        )doc",
+        ttnn::pybind_arguments_t{py::arg("tensor"), py::arg("dtype")});
 
     bind_registered_operation(
         module,
diff --git a/ttnn/cpp/ttnn/decorators.hpp b/ttnn/cpp/ttnn/decorators.hpp
index 9a4d196186f7..5d1c37bfa111 100644
--- a/ttnn/cpp/ttnn/decorators.hpp
+++ b/ttnn/cpp/ttnn/decorators.hpp
@@ -283,7 +283,7 @@ struct operation_without_validation_t {
     template <typename... args_t>
     static auto input_tensors_to_validate(args_t&&... args) {
         return std::make_tuple();
-    };
+    }
 
     template <typename... args_t>
     static auto execute(args_t&&... args) {
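From Python, the two bindings added in core.hpp above are used together with the existing device helpers. A hedged sketch, assuming a device is available via the usual `ttnn.open_device`/`ttnn.close_device` helpers (not part of this diff):

```python
import torch
import ttnn

device = ttnn.open_device(device_id=0)

# to_dtype operates on host tensors (ROW_MAJOR), per the checks in ToDtype::execute
host_tensor = ttnn.from_torch(torch.randn((10, 64, 32), dtype=torch.bfloat16))
host_tensor = ttnn.to_dtype(host_tensor, dtype=ttnn.float32)

# reallocate frees the device buffer and returns a freshly allocated tensor,
# which can help defragment device memory after intermediates are released
device_tensor = ttnn.to_device(host_tensor, device)
device_tensor = ttnn.reallocate(device_tensor)

ttnn.close_device(device)
```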
diff --git a/ttnn/cpp/ttnn/op_library/to_dtype/to_dtype_op.hpp b/ttnn/cpp/ttnn/op_library/to_dtype/to_dtype_op.hpp
index 53a675ce68c7..bd1f3a9fcb39 100644
--- a/ttnn/cpp/ttnn/op_library/to_dtype/to_dtype_op.hpp
+++ b/ttnn/cpp/ttnn/op_library/to_dtype/to_dtype_op.hpp
@@ -19,7 +19,7 @@ namespace core {
 
 namespace detail {
 
-inline Tensor convert_to_cpp_dtypes(const Tensor& input_tensor) {
+inline Tensor convert_to_cpp_supported_dtype(const Tensor& input_tensor) {
     auto input_dtype = input_tensor.get_dtype();
 
     auto buffer = std::visit(
@@ -55,8 +55,6 @@ inline Tensor convert_to_cpp_dtypes(const Tensor& input_tensor) {
             unpack_bfp4_tiles_into_float_vec(uint32_data, /*row_major_output=*/false, /*is_exp_a=*/false);
         buffer = owned_buffer::create<float>(std::move(float_unpacked_data));
         input_dtype = DataType::FLOAT32;
-    } else {
-        TT_THROW("Unsupported input data type");
     }
 
     return std::visit(
@@ -77,6 +75,107 @@ inline Tensor convert_to_cpp_dtypes(const Tensor& input_tensor) {
         buffer);
 }
 
+template <typename NewT, typename OldT>
+inline std::vector<NewT> cast(const borrowed_buffer::Buffer<OldT>& input_buffer) {
+    std::vector<NewT> output_vector(input_buffer.size());
+    for (auto index = 0; index < input_buffer.size(); ++index) {
+        auto convert_value = [](auto&& value) {
+            if constexpr (std::is_same_v<OldT, ::bfloat16>) {
+                return value.to_float();
+            } else if constexpr (std::is_same_v<NewT, ::bfloat16>) {
+                return static_cast<float>(value);
+            } else {
+                return value;
+            }
+        };
+        auto value = input_buffer[index];
+        output_vector[index] = static_cast<NewT>(convert_value(value));
+    }
+    return output_vector;
+}
+
+template <typename T>
+Tensor create_owned_tensor(std::vector<T>&& data, const Shape& shape, DataType data_type, Layout layout) {
+    auto buffer = owned_buffer::create(std::move(data));
+    auto storage = OwnedStorage{std::move(buffer)};
+    return Tensor(std::move(storage), shape, data_type, layout);
+}
+
+template <typename T>
+inline Tensor create_tensor_from_buffer(
+    const borrowed_buffer::Buffer<T>& input_buffer, const Shape& shape, const DataType& dtype) {
+    switch (dtype) {
+        case DataType::UINT16: {
+            auto data = cast<uint16_t, T>(input_buffer);
+            return create_owned_tensor(std::move(data), shape, dtype, Layout::ROW_MAJOR);
+        }
+        case DataType::INT32: {
+            auto data = cast<int32_t, T>(input_buffer);
+            return create_owned_tensor(std::move(data), shape, dtype, Layout::ROW_MAJOR);
+        }
+        case DataType::UINT32: {
+            auto data = cast<uint32_t, T>(input_buffer);
+            return create_owned_tensor(std::move(data), shape, dtype, Layout::ROW_MAJOR);
+        }
+        case DataType::FLOAT32: {
+            auto data = cast<float, T>(input_buffer);
+            return create_owned_tensor(std::move(data), shape, dtype, Layout::ROW_MAJOR);
+        }
+        case DataType::BFLOAT16: {
+            auto data = cast<::bfloat16, T>(input_buffer);
+            return create_owned_tensor(std::move(data), shape, dtype, Layout::ROW_MAJOR);
+        }
+        case DataType::BFLOAT8_B: {
+            auto data = cast<float, T>(input_buffer);
+            auto uint32_vector = pack_fp32_vec_as_bfp8_tiles(data, /*row_major_input=*/false, /*is_exp_a=*/false);
+            auto buffer = owned_buffer::create<uint32_t>(std::move(uint32_vector));
+            auto storage = OwnedStorage{std::move(buffer)};
+            return Tensor(std::move(storage), shape, dtype, Layout::ROW_MAJOR).to(ttnn::TILE_LAYOUT);
+        }
+        case DataType::BFLOAT4_B: {
+            auto data = cast<float, T>(input_buffer);
+            auto uint32_vector = pack_fp32_vec_as_bfp4_tiles(data, /*row_major_input=*/false, /*is_exp_a=*/false);
+            auto buffer = owned_buffer::create<uint32_t>(std::move(uint32_vector));
+            auto storage = OwnedStorage{std::move(buffer)};
+            return Tensor(std::move(storage), shape, dtype, Layout::ROW_MAJOR).to(ttnn::TILE_LAYOUT);
+        }
+        default: {
+            TT_THROW(fmt::format("Unsupported DataType: {}", dtype));
+            break;
+        }
+    }
+}
+
+inline Tensor convert_to_dtype(const Tensor& input_tensor, const DataType& dtype) {
+    auto input_dtype = input_tensor.get_dtype();
+
+    switch (input_dtype) {
+        case DataType::UINT16: {
+            auto buffer = host_buffer::get_as<uint16_t>(input_tensor);
+            return create_tensor_from_buffer(buffer, input_tensor.get_shape(), dtype);
+        }
+        case DataType::INT32: {
+            auto buffer = host_buffer::get_as<int32_t>(input_tensor);
+            return create_tensor_from_buffer(buffer, input_tensor.get_shape(), dtype);
+        }
+        case DataType::UINT32: {
+            auto buffer = host_buffer::get_as<uint32_t>(input_tensor);
+            return create_tensor_from_buffer(buffer, input_tensor.get_shape(), dtype);
+        }
+        case DataType::FLOAT32: {
+            auto buffer = host_buffer::get_as<float>(input_tensor);
+            return create_tensor_from_buffer(buffer, input_tensor.get_shape(), dtype);
+        }
+        case DataType::BFLOAT16: {
+            auto buffer = host_buffer::get_as<::bfloat16>(input_tensor);
+            return create_tensor_from_buffer(buffer, input_tensor.get_shape(), dtype);
+        }
+        default: TT_THROW(fmt::format("Unsupported DataType: {}", input_dtype)); break;
+    }
+
+    return input_tensor;
+}
+
 } // namespace detail
 
 struct ToDtype {
@@ -95,7 +194,7 @@ struct ToDtype {
     template <typename... Args>
     static auto input_tensors_to_validate(const Tensor& tensor_arg, Args&&... args) {
         return std::make_tuple(tensor_arg);
-    };
+    }
 
     // TODO: Move to cpp once we merge with tt_eager
     static Tensor execute(const ttnn::Tensor& input_tensor, const ttnn::DataType& dtype) {
@@ -110,8 +209,8 @@ struct ToDtype {
             TT_THROW("Only ROW_MAJOR_LAYOUT is supported");
         }
 
-        auto output_tensor = input_tensor;
-        return output_tensor;
+        auto intermediate_tensor = detail::convert_to_cpp_supported_dtype(input_tensor);
+        return detail::convert_to_dtype(intermediate_tensor, dtype);
     };
};
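The control flow of `detail::convert_to_dtype` / `detail::cast` above can be modeled in a few lines of Python: dispatch on the dtype, widen through float where needed, then cast element-wise (C++ `static_cast` and Python `int()` both truncate toward zero). This is a sketch of the dispatch only; dtypes are stood in by plain callables, not the real buffer types:

```python
# Model of the element-wise casts; the uint variants mimic wraparound.
CASTS = {
    "uint16": lambda v: int(v) & 0xFFFF,
    "int32": int,
    "uint32": lambda v: int(v) & 0xFFFFFFFF,
    "float32": float,
}

def convert_to_dtype(values, input_dtype, dtype):
    # Mirrors the two TT_THROW branches: both source and target must be supported.
    if input_dtype not in CASTS:
        raise ValueError(f"Unsupported DataType: {input_dtype}")
    if dtype not in CASTS:
        raise ValueError(f"Unsupported DataType: {dtype}")
    return [CASTS[dtype](v) for v in values]

print(convert_to_dtype([1.9, 2.5, -3.1], "float32", "int32"))  # [1, 2, -3]
```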
diff --git a/ttnn/cpp/ttnn/op_library/to_memory_config/to_memory_config_op.hpp b/ttnn/cpp/ttnn/op_library/to_memory_config/to_memory_config_op.hpp
index 24edc6f0855c..e23b6d21924b 100644
--- a/ttnn/cpp/ttnn/op_library/to_memory_config/to_memory_config_op.hpp
+++ b/ttnn/cpp/ttnn/op_library/to_memory_config/to_memory_config_op.hpp
@@ -43,7 +43,7 @@ struct ToMemoryConfig {
     template <typename... Args>
     static auto input_tensors_to_validate(const Tensor& tensor_arg, Args&&... args) {
         return std::make_tuple(tensor_arg);
-    };
+    }
 
     // TODO: Move to cpp once we merge with tt_eager
     static Tensor execute(
diff --git a/ttnn/cpp/ttnn/operations/normalization.hpp b/ttnn/cpp/ttnn/operations/normalization.hpp
index aa9cce6bf392..1229f1bbcc18 100644
--- a/ttnn/cpp/ttnn/operations/normalization.hpp
+++ b/ttnn/cpp/ttnn/operations/normalization.hpp
@@ -50,7 +50,7 @@ struct Softmax {
     }
 };
 
-struct LayerNorm : tt::operations::primary::LayerNorm {
+struct LayerNorm {
     static inline const std::array<ttnn::TensorSchema, 4> input_tensor_schemas() {
         return {
             ttnn::TensorSchema{
@@ -123,7 +123,7 @@ struct LayerNorm : tt::operations::primary::LayerNorm {
     }
 };
 
-struct RMSNorm : tt::operations::primary::LayerNorm {
+struct RMSNorm {
     static inline const std::array<ttnn::TensorSchema, 2> input_tensor_schemas() {
         return {
             ttnn::TensorSchema{
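The LayerNorm and RMSNorm structs above now carry their own ttnn-level schemas instead of inheriting from `tt::operations::primary::LayerNorm`; the Python entry points are unchanged. A hedged usage sketch, assuming a device and the usual `ttnn.from_torch(..., layout=..., device=...)` helpers (the shapes here are illustrative):

```python
import torch
import ttnn

device = ttnn.open_device(device_id=0)

hidden = ttnn.from_torch(torch.randn((1, 32, 64), dtype=torch.bfloat16), layout=ttnn.TILE_LAYOUT, device=device)
weight = ttnn.from_torch(torch.ones((64,), dtype=torch.bfloat16), layout=ttnn.TILE_LAYOUT, device=device)

# LayerNorm over the last dimension; the RMS variant is ttnn.rms_norm(hidden, weight=weight)
output = ttnn.layer_norm(hidden, weight=weight)

ttnn.close_device(device)
```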
diff --git a/ttnn/cpp/ttnn/operations/transformer.hpp b/ttnn/cpp/ttnn/operations/transformer.hpp
index 1767e8f84d02..da31a9744c3d 100644
--- a/ttnn/cpp/ttnn/operations/transformer.hpp
+++ b/ttnn/cpp/ttnn/operations/transformer.hpp
@@ -236,7 +236,7 @@ struct ConcatenateHeads : public tt::tt_metal::NlpConcatHeads {
 };
 
 template <bool in_place>
-struct AttentionSoftmax : public tt::operations::primary::Softmax {
+struct ExecuteAttentionSoftmax {
     static inline const std::array<ttnn::TensorSchema, 2> input_tensor_schemas() {
         return {
             ttnn::TensorSchema{4, 4, {ttnn::bfloat16, ttnn::bfloat8_b}, {ttnn::TILE_LAYOUT}, true, false, false, false},
@@ -274,7 +274,7 @@ struct AttentionSoftmax : public tt::operations::primary::Softmax {
         auto kernel_config_val = init_device_compute_kernel_config(
             input_tensor.device()->arch(), compute_kernel_config, MathFidelity::HiFi4, true, false, false);
         auto output_tensor = operation::run(
-            AttentionSoftmax{
+            tt::operations::primary::Softmax{
                 head_size,
                 in_place,
                 memory_config.value_or(input_tensor.memory_config()),
@@ -301,10 +301,12 @@ constexpr auto split_query_key_value_and_split_heads =
 constexpr auto concatenate_heads =
     ttnn::register_operation<ttnn::operations::transformer::ConcatenateHeads>("ttnn::transfomer::concatenate_heads");
 
-constexpr auto attention_softmax = ttnn::register_operation<ttnn::operations::transformer::AttentionSoftmax<false>>(
-    "ttnn::transfomer::attention_softmax");
-constexpr auto attention_softmax_ = ttnn::register_operation<ttnn::operations::transformer::AttentionSoftmax<true>>(
-    "ttnn::transfomer::attention_softmax_");
+constexpr auto attention_softmax =
+    ttnn::register_operation<ttnn::operations::transformer::ExecuteAttentionSoftmax<false>>(
+        "ttnn::transfomer::attention_softmax");
+constexpr auto attention_softmax_ =
+    ttnn::register_operation<ttnn::operations::transformer::ExecuteAttentionSoftmax<true>>(
+        "ttnn::transfomer::attention_softmax_");
 
 }  // namespace transformer
 }  // namespace ttnn
diff --git a/ttnn/cpp/ttnn/operations/unary.hpp b/ttnn/cpp/ttnn/operations/unary.hpp
index 0561cc096582..6c002fcc40ac 100644
--- a/ttnn/cpp/ttnn/operations/unary.hpp
+++ b/ttnn/cpp/ttnn/operations/unary.hpp
@@ -36,7 +36,7 @@ inline const std::array<TensorSchema, 1> input_tensor_schemas() {
 template <typename... Args>
 inline auto input_tensors_to_validate(const Tensor& input_tensor, Args&&... args) {
     return std::make_tuple(input_tensor);
-};
+}
 
 inline Tensor execute(
     const Tensor& input_tensor,
@@ -53,7 +53,7 @@ inline Tensor execute(
 }  // namespace detail
 
 template <UnaryOpType... unary_op_types>
-struct Unary : public EltwiseUnary {
+struct Unary {
     static const std::array<TensorSchema, 1> input_tensor_schemas() { return detail::input_tensor_schemas(); }
 
     template <typename... Args>
@@ -68,7 +68,7 @@ struct Unary : public EltwiseUnary {
 };
 
 template <UnaryOpType... unary_op_types>
-struct UnaryWithFastAndApproximateMode : public EltwiseUnary {
+struct UnaryWithFastAndApproximateMode {
     static const std::array<TensorSchema, 1> input_tensor_schemas() { return detail::input_tensor_schemas(); }
 
     template <typename... Args>
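The rename to `ExecuteAttentionSoftmax` above only affects the C++ registration; the Python entry points keep their names. A hedged usage sketch, assuming a device, a 4D tiled input per the schema above, and that `head_size`/`attention_mask` remain optional keyword arguments:

```python
import torch
import ttnn

device = ttnn.open_device(device_id=0)

scores = ttnn.from_torch(
    torch.randn((1, 4, 32, 32), dtype=torch.bfloat16), layout=ttnn.TILE_LAYOUT, device=device
)

# attention_softmax returns a new tensor; attention_softmax_ is the in-place variant
probs = ttnn.transformer.attention_softmax(scores, head_size=None, attention_mask=None)

ttnn.close_device(device)
```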
diff --git a/ttnn/ttnn/__init__.py b/ttnn/ttnn/__init__.py
index 47ec5fc67cf4..217508a233b2 100644
--- a/ttnn/ttnn/__init__.py
+++ b/ttnn/ttnn/__init__.py
@@ -250,6 +250,7 @@ def manage_config(name, value):
     to_device,
     from_device,
     to_layout,
+    to_dtype,
     reshape,
     to_memory_config,
     deallocate,
diff --git a/ttnn/ttnn/operations/core.py b/ttnn/ttnn/operations/core.py
index aea2c429fff2..7121a7b1bda6 100644
--- a/ttnn/ttnn/operations/core.py
+++ b/ttnn/ttnn/operations/core.py
@@ -491,6 +491,13 @@ def _golden_function(tensor, *args, **kwargs):
 to_layout = ttnn.register_operation(golden_function=_golden_function)(ttnn._ttnn.operations.core.to_layout)
 
 
+def _golden_function(tensor, *args, **kwargs):
+    return tensor
+
+
+to_dtype = ttnn.register_operation(golden_function=_golden_function)(ttnn._ttnn.operations.core.to_dtype)
+
+
 def _clone_validate_input_tensors(operation_name, input_tensor, *args, **kwargs):
     ttnn.validate_input_tensor(
         operation_name,
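The golden function registered above treats `to_dtype` as an identity on the torch side, so ttnn's comparison mode checks the converted output against the unconverted input. A minimal sketch of why that is sound for widening conversions (bfloat16 to float32 is exact, so the round trip is bit-for-bit):

```python
import torch
import ttnn

torch_tensor = torch.rand((32, 32), dtype=torch.bfloat16)
converted = ttnn.to_dtype(ttnn.from_torch(torch_tensor), ttnn.float32)

# The identity golden holds exactly here; lossy targets such as bfloat8_b
# would only match approximately (hence assert_with_pcc in the new test).
assert torch.equal(ttnn.to_torch(converted).to(torch.bfloat16), torch_tensor)
```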