From 47a3bcdebb06712da41a2a481f5bd60c36f520eb Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Fri, 29 Nov 2024 19:07:55 +0000 Subject: [PATCH] #13745:move tensor.reshape_unsafe to ttnn.experimental --- tests/ttnn/unit_tests/test_reshape.py | 3 +- ttnn/CMakeLists.txt | 2 + .../reshape_on_device/device/reshape_op.cpp | 4 +- .../experimental/experimental_pybind.cpp | 6 + .../experimental/reshape/reshape.cpp | 133 ++++++++++++++++++ .../experimental/reshape/reshape.hpp | 26 ++++ .../experimental/reshape/reshape_pybind.cpp | 72 ++++++++++ .../experimental/reshape/reshape_pybind.hpp | 13 ++ 8 files changed, 256 insertions(+), 3 deletions(-) create mode 100644 ttnn/cpp/ttnn/operations/experimental/reshape/reshape.cpp create mode 100644 ttnn/cpp/ttnn/operations/experimental/reshape/reshape.hpp create mode 100644 ttnn/cpp/ttnn/operations/experimental/reshape/reshape_pybind.cpp create mode 100644 ttnn/cpp/ttnn/operations/experimental/reshape/reshape_pybind.hpp diff --git a/tests/ttnn/unit_tests/test_reshape.py b/tests/ttnn/unit_tests/test_reshape.py index 0900d428aa98..9a8bb32d29e8 100644 --- a/tests/ttnn/unit_tests/test_reshape.py +++ b/tests/ttnn/unit_tests/test_reshape.py @@ -35,7 +35,8 @@ def test_reshape_sharded_rm(device, n, c, h, w): torch_input_tensor, layout=ttnn.ROW_MAJOR_LAYOUT, device=device, memory_config=sharded_mem_config ) - tt_output_tensor = tt_input_tensor.reshape_unsafe(n, c, h * 2, w // 2) + # tt_output_tensor = tt_input_tensor.reshape_unsafe(n, c, h * 2, w // 2) + tt_output_tensor = ttnn.experimental.reshape(tt_input_tensor, n, c, h * 2, w // 2) sharded_mem_config = ttnn.create_sharded_memory_config( tt_output_tensor.shape, diff --git a/ttnn/CMakeLists.txt b/ttnn/CMakeLists.txt index 2f432ee54e0e..683af4bfba7f 100644 --- a/ttnn/CMakeLists.txt +++ b/ttnn/CMakeLists.txt @@ -580,6 +580,8 @@ set(ALL_TTNN_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/data_movement/expand/expand_pybind.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/data_movement/expand/device/expand_rm_program_factory.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/data_movement/expand/device/expand_device_operation.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/experimental/reshape/reshape.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/experimental/reshape/reshape_pybind.cpp ) #Split src and python bindings diff --git a/ttnn/cpp/ttnn/operations/data_movement/reshape_on_device/device/reshape_op.cpp b/ttnn/cpp/ttnn/operations/data_movement/reshape_on_device/device/reshape_op.cpp index 1da82c867772..291bcd154272 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/reshape_on_device/device/reshape_op.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/reshape_on_device/device/reshape_op.cpp @@ -26,10 +26,10 @@ void ReshapeDeviceOperation::validate(const std::vector& input_tensors) TT_FATAL( input_tensor_a.memory_config().memory_layout == TensorMemoryLayout::INTERLEAVED, - "Reshape does not currently support sharding"); + "Use view_unsafe for reshaping sharded inputs"); TT_FATAL( this->output_mem_config.memory_layout == TensorMemoryLayout::INTERLEAVED, - "Reshape does not currently support sharding"); + "Use view_unsafe for reshaping sharded inputs"); if (input_tensor_a.get_layout() == Layout::TILE) { TT_FATAL(input_tensor_a.volume() % TILE_HW == 0, "Error"); diff --git a/ttnn/cpp/ttnn/operations/experimental/experimental_pybind.cpp b/ttnn/cpp/ttnn/operations/experimental/experimental_pybind.cpp index d6f9431947f3..290545fbc030 100644 --- a/ttnn/cpp/ttnn/operations/experimental/experimental_pybind.cpp +++ b/ttnn/cpp/ttnn/operations/experimental/experimental_pybind.cpp @@ -36,6 +36,9 @@ #include "ttnn/operations/experimental/ccl/all_gather_matmul/all_gather_matmul_pybind.hpp" #include "ttnn/operations/experimental/ccl/all_reduce/all_reduce_pybind.hpp" #include "ttnn/operations/experimental/plusone/plusone_pybind.hpp" + +#include "ttnn/operations/experimental/reshape/reshape_pybind.hpp" + namespace ttnn::operations::experimental { void py_module(py::module& module) { @@ -76,10 +79,13 @@ void py_module(py::module& module) { plusone::detail::bind_experimental_plusone_operation(module); + reshape::detail::py_bind_reshape(module); + // CCL ops auto m_experimental_ccl = module.def_submodule("ccl", "experiemental collective communication operations"); ccl::py_bind_all_gather_matmul(m_experimental_ccl); ccl::py_bind_all_reduce(m_experimental_ccl); + } } // namespace ttnn::operations::experimental diff --git a/ttnn/cpp/ttnn/operations/experimental/reshape/reshape.cpp b/ttnn/cpp/ttnn/operations/experimental/reshape/reshape.cpp new file mode 100644 index 000000000000..397d468df6ff --- /dev/null +++ b/ttnn/cpp/ttnn/operations/experimental/reshape/reshape.cpp @@ -0,0 +1,133 @@ + +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#include "ttnn/common/constants.hpp" +#include "ttnn/run_operation.hpp" +#include "reshape.hpp" +#include "tt_metal/common/constants.hpp" +#include +#include +#include "ttnn/operations/experimental/auto_format/auto_format.hpp" +#include "ttnn/tensor/tensor_utils.hpp" +#include "ttnn/operations/data_movement/data_transfer/data_transfer.hpp" +#include "ttnn/operations/data_movement/slice/slice.hpp" +#include "ttnn/operations/core/core.hpp" + + +#include "ttnn/tensor/tensor.hpp" + +#include +#include + +#include "common/bfloat16.hpp" +#include "ttnn/tensor/tensor_impl.hpp" +#include "ttnn/tensor/tensor_impl_wrapper.hpp" +#include "ttnn/tensor/tensor_utils.hpp" +#include "ttnn/tensor/types.hpp" +#include "tt_metal/common/constants.hpp" +#include "tt_metal/common/math.hpp" +#include "tt_metal/third_party/tracy/public/tracy/Tracy.hpp" +#include "tt_metal/graph/graph_tracking.hpp" +#include "ttnn/distributed/api.hpp" +#include "ttnn/distributed/types.hpp" +#include "ttnn/core.hpp" + + +namespace ttnn{ + +namespace operations::experimental::reshape { +ttnn::Tensor tensor_reshape(const ttnn::Tensor& input_tensor, const ttnn::Shape& new_shape) { + ZoneScoped; + GraphTracker::instance().track_function_start("ttnn::experimental::reshape", input_tensor, new_shape); + const auto& new_padded_shape = new_shape.padded_shape(); + const auto tile = input_tensor.get_tensor_spec().tile(); + TT_ASSERT( + input_tensor.volume() == new_padded_shape.volume(), + "{} != {}", + input_tensor.volume(), + new_padded_shape.volume()); + if (input_tensor.get_layout() == Layout::TILE) { + TT_ASSERT( + new_padded_shape[-2] % tile.get_tile_shape()[0] == 0 && + new_padded_shape[-1] % tile.get_tile_shape()[1] == 0 && + "Expected a multiple of 32 for H, W (or -1 evaluating to such) in ttnn::experimental::reshape()!"); + } + auto output = std::visit( + [&input_tensor, &new_shape, &tile](auto&& storage) -> Tensor { + using T = std::decay_t; + const auto& tensor = input_tensor; + if constexpr (std::is_same_v) { + auto updated_storage = std::get(tensor.get_storage()); + for (int i = 0; i < updated_storage.shapes.size(); i++) { + updated_storage.shapes[i] = new_shape; + } + return Tensor(updated_storage, new_shape, tensor.get_dtype(), tensor.get_layout(), tile); + } + if constexpr (std::is_same_v) { + MultiDeviceStorage updated_storage = std::get(tensor.get_storage()); + std::unordered_map new_shapes; + + for (auto device_id : updated_storage.ordered_device_ids) { + new_shapes.insert({device_id, new_shape}); + } + updated_storage.shapes = new_shapes; + return Tensor(updated_storage, new_shape, tensor.get_dtype(), tensor.get_layout(), tile); + } + if constexpr (std::is_same_v) { + if (input_tensor.get_layout() == Layout::ROW_MAJOR) { + if (tensor.memory_config().memory_layout != TensorMemoryLayout::HEIGHT_SHARDED) { + DeviceStorage device_storage = std::get(tensor.get_storage()); + DeviceBuffer device_buffer = device_storage.get_buffer(); + device_buffer->set_page_size(new_shape[-1] * tensor.element_size()); + device_storage.insert_buffer(device_buffer); + return Tensor(device_storage, new_shape, tensor.get_dtype(), tensor.get_layout(), tile); + } else { + DeviceStorage device_storage = std::get(tensor.get_storage()); + DeviceBuffer device_buffer = device_storage.get_buffer(); + ShardSpecBuffer shard_spec_buffer = device_buffer->shard_spec(); + + auto shard_spec = shard_spec_buffer.tensor_shard_spec; + auto shard_shape = shard_spec.shape; + + uint32_t mul_div = new_shape[-1] > shard_shape[1] ? (new_shape[-1] / shard_shape[1]) + : (shard_shape[1] / new_shape[-1]); + shard_spec.shape[0] = + new_shape[-1] > shard_shape[1] ? shard_shape[0] / mul_div : shard_shape[0] * mul_div; + shard_spec.shape[1] = new_shape[-1]; + + shard_spec_buffer.page_shape = {1, new_shape[-1]}; + shard_spec_buffer.tensor2d_shape = {shard_spec.shape[0], 1}; + shard_spec_buffer.set_shard_spec(shard_spec); + + device_buffer->set_shard_spec(shard_spec_buffer); + device_storage.insert_buffer(device_buffer); + + return Tensor(device_storage, new_shape, tensor.get_dtype(), tensor.get_layout(), tile); + } + } else { + return Tensor(tensor.get_storage(), new_shape, tensor.get_dtype(), tensor.get_layout(), tile); + } + } else { + return Tensor(tensor.get_storage(), new_shape, tensor.get_dtype(), tensor.get_layout(), tile); + } + }, + input_tensor.get_storage()); + output = tt::tt_metal::set_tensor_id(output); + GraphTracker::instance().track_function_end(output); + return output; +} + + + +ttnn::Tensor ReshapeOperation::invoke(const ttnn::Tensor& tensor, const ttnn::SimpleShape& shape) { + return tensor_reshape(tensor, shape); +} + +ttnn::Tensor ReshapeOperation::invoke(const ttnn::Tensor& tensor, const ttnn::Shape& shape) { + return tensor_reshape(tensor, shape); +} + +} // namespace operations::experimental::reshape +} //namespace ttnn diff --git a/ttnn/cpp/ttnn/operations/experimental/reshape/reshape.hpp b/ttnn/cpp/ttnn/operations/experimental/reshape/reshape.hpp new file mode 100644 index 000000000000..aa96e96c9a43 --- /dev/null +++ b/ttnn/cpp/ttnn/operations/experimental/reshape/reshape.hpp @@ -0,0 +1,26 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "ttnn/run_operation.hpp" +#include "ttnn/decorators.hpp" +#include + +namespace ttnn { +namespace operations::experimental::reshape { + + +struct ReshapeOperation { + static ttnn::Tensor invoke(const ttnn::Tensor& input_tensor, const ttnn::Shape& shape); + static ttnn::Tensor invoke(const ttnn::Tensor& input_tensor, const ttnn::SimpleShape& shape); +}; + +} // namespace operations::experimental::reshape + +namespace experimental { +constexpr auto reshape = + ttnn::register_operation_with_auto_launch_op<"ttnn::experimental::reshape", ttnn::operations::experimental::reshape::ReshapeOperation>(); +} // namespace experimental +} // namespace ttnn diff --git a/ttnn/cpp/ttnn/operations/experimental/reshape/reshape_pybind.cpp b/ttnn/cpp/ttnn/operations/experimental/reshape/reshape_pybind.cpp new file mode 100644 index 000000000000..a1677e8d0ea6 --- /dev/null +++ b/ttnn/cpp/ttnn/operations/experimental/reshape/reshape_pybind.cpp @@ -0,0 +1,72 @@ +// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#include "reshape_pybind.hpp" +#include "reshape.hpp" + +#include +#include + +#include "ttnn/cpp/pybind11/decorators.hpp" + +#include "ttnn/types.hpp" + +#include "ttnn/tensor/tensor.hpp" +#include "ttnn/tensor/tensor_impl.hpp" + + +namespace ttnn::operations::experimental::reshape::detail { +namespace py = pybind11; + +void py_bind_reshape(py::module& module) { + auto doc = R"doc( + + Note: for a 0 cost view, the following conditions must be met: + * the last dimension must not change + * In Tiled the second last two dimensions must not change OR there is no padding on the second last dimension + + Args: + * input_tensor: Input Tensor. + * new_shape: New shape of tensor. + + Returns: + ttnn.Tensor: the output tensor with the new shape. + + Example: + + >>> tensor = ttnn.from_torch(torch.tensor((1, 4), dtype=torch.bfloat16), device=device) + >>> output = ttnn.experimental.reshape(tensor, (1, 1, 2, 2)) + + )doc"; + bind_registered_operation( + module, + ttnn::experimental::reshape, + doc, + ttnn::pybind_overload_t{ + [](const decltype(ttnn::experimental::reshape)& self, ttnn::Tensor& input_tensor, int N, int C, int H, int W) { + return self(input_tensor, infer_dims_for_reshape(input_tensor, ttnn::SmallVector{N, C, H, W})); + }, + py::arg("input_tensor"), + py::arg("N"), + py::arg("C"), + py::arg("H"), + py::arg("W"), + }, + + ttnn::pybind_overload_t{ + [](const decltype(ttnn::experimental::reshape)& self, ttnn::Tensor& input_tensor, const ttnn::Shape& shape) { + return self(input_tensor, shape); }, + py::arg("input_tensor"), + py::arg("shape"), + }, + ttnn::pybind_overload_t{ + [](const decltype(ttnn::experimental::reshape)& self, ttnn::Tensor& input_tensor, const ttnn::SmallVector& shape) { + return self(input_tensor, infer_dims_for_reshape(input_tensor, shape)); + }, + py::arg("input_tensor"), + py::arg("shape"), + }); +} + +} // namespace ttnn::operations::experimental::reshape::detail diff --git a/ttnn/cpp/ttnn/operations/experimental/reshape/reshape_pybind.hpp b/ttnn/cpp/ttnn/operations/experimental/reshape/reshape_pybind.hpp new file mode 100644 index 000000000000..f528e86d48b3 --- /dev/null +++ b/ttnn/cpp/ttnn/operations/experimental/reshape/reshape_pybind.hpp @@ -0,0 +1,13 @@ +// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "pybind11/pybind_fwd.hpp" + +namespace ttnn::operations::experimental::reshape::detail { + +void py_bind_reshape(pybind11::module& module); + +} // namespace ttnn::operations::experimental