Skip to content

Commit

Permalink
#13745:move tensor.reshape_unsafe to ttnn.experimental
Browse files Browse the repository at this point in the history
  • Loading branch information
nardoTT committed Dec 4, 2024
1 parent 1556034 commit 47a3bcd
Show file tree
Hide file tree
Showing 8 changed files with 256 additions and 3 deletions.
3 changes: 2 additions & 1 deletion tests/ttnn/unit_tests/test_reshape.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@ def test_reshape_sharded_rm(device, n, c, h, w):
torch_input_tensor, layout=ttnn.ROW_MAJOR_LAYOUT, device=device, memory_config=sharded_mem_config
)

tt_output_tensor = tt_input_tensor.reshape_unsafe(n, c, h * 2, w // 2)
# tt_output_tensor = tt_input_tensor.reshape_unsafe(n, c, h * 2, w // 2)
tt_output_tensor = ttnn.experimental.reshape(tt_input_tensor, n, c, h * 2, w // 2)

sharded_mem_config = ttnn.create_sharded_memory_config(
tt_output_tensor.shape,
Expand Down
2 changes: 2 additions & 0 deletions ttnn/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -580,6 +580,8 @@ set(ALL_TTNN_SRCS
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/data_movement/expand/expand_pybind.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/data_movement/expand/device/expand_rm_program_factory.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/data_movement/expand/device/expand_device_operation.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/experimental/reshape/reshape.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/experimental/reshape/reshape_pybind.cpp
)

#Split src and python bindings
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,10 @@ void ReshapeDeviceOperation::validate(const std::vector<Tensor>& input_tensors)

TT_FATAL(
input_tensor_a.memory_config().memory_layout == TensorMemoryLayout::INTERLEAVED,
"Reshape does not currently support sharding");
"Use view_unsafe for reshaping sharded inputs");
TT_FATAL(
this->output_mem_config.memory_layout == TensorMemoryLayout::INTERLEAVED,
"Reshape does not currently support sharding");
"Use view_unsafe for reshaping sharded inputs");

if (input_tensor_a.get_layout() == Layout::TILE) {
TT_FATAL(input_tensor_a.volume() % TILE_HW == 0, "Error");
Expand Down
6 changes: 6 additions & 0 deletions ttnn/cpp/ttnn/operations/experimental/experimental_pybind.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@
#include "ttnn/operations/experimental/ccl/all_gather_matmul/all_gather_matmul_pybind.hpp"
#include "ttnn/operations/experimental/ccl/all_reduce/all_reduce_pybind.hpp"
#include "ttnn/operations/experimental/plusone/plusone_pybind.hpp"

#include "ttnn/operations/experimental/reshape/reshape_pybind.hpp"

namespace ttnn::operations::experimental {

void py_module(py::module& module) {
Expand Down Expand Up @@ -76,10 +79,13 @@ void py_module(py::module& module) {

plusone::detail::bind_experimental_plusone_operation(module);

reshape::detail::py_bind_reshape(module);

// CCL ops
auto m_experimental_ccl = module.def_submodule("ccl", "experiemental collective communication operations");
ccl::py_bind_all_gather_matmul(m_experimental_ccl);
ccl::py_bind_all_reduce(m_experimental_ccl);

}

} // namespace ttnn::operations::experimental
133 changes: 133 additions & 0 deletions ttnn/cpp/ttnn/operations/experimental/reshape/reshape.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@

// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#include "ttnn/common/constants.hpp"
#include "ttnn/run_operation.hpp"
#include "reshape.hpp"
#include "tt_metal/common/constants.hpp"
#include <functional>
#include <ttnn/operations/numpy/functions.hpp>
#include "ttnn/operations/experimental/auto_format/auto_format.hpp"
#include "ttnn/tensor/tensor_utils.hpp"
#include "ttnn/operations/data_movement/data_transfer/data_transfer.hpp"
#include "ttnn/operations/data_movement/slice/slice.hpp"
#include "ttnn/operations/core/core.hpp"


#include "ttnn/tensor/tensor.hpp"

#include <cstdint>
#include <memory>

#include "common/bfloat16.hpp"
#include "ttnn/tensor/tensor_impl.hpp"
#include "ttnn/tensor/tensor_impl_wrapper.hpp"
#include "ttnn/tensor/tensor_utils.hpp"
#include "ttnn/tensor/types.hpp"
#include "tt_metal/common/constants.hpp"
#include "tt_metal/common/math.hpp"
#include "tt_metal/third_party/tracy/public/tracy/Tracy.hpp"
#include "tt_metal/graph/graph_tracking.hpp"
#include "ttnn/distributed/api.hpp"
#include "ttnn/distributed/types.hpp"
#include "ttnn/core.hpp"


namespace ttnn{

namespace operations::experimental::reshape {
ttnn::Tensor tensor_reshape(const ttnn::Tensor& input_tensor, const ttnn::Shape& new_shape) {
ZoneScoped;
GraphTracker::instance().track_function_start("ttnn::experimental::reshape", input_tensor, new_shape);
const auto& new_padded_shape = new_shape.padded_shape();
const auto tile = input_tensor.get_tensor_spec().tile();
TT_ASSERT(
input_tensor.volume() == new_padded_shape.volume(),
"{} != {}",
input_tensor.volume(),
new_padded_shape.volume());
if (input_tensor.get_layout() == Layout::TILE) {
TT_ASSERT(
new_padded_shape[-2] % tile.get_tile_shape()[0] == 0 &&
new_padded_shape[-1] % tile.get_tile_shape()[1] == 0 &&
"Expected a multiple of 32 for H, W (or -1 evaluating to such) in ttnn::experimental::reshape()!");
}
auto output = std::visit(
[&input_tensor, &new_shape, &tile](auto&& storage) -> Tensor {
using T = std::decay_t<decltype(storage)>;
const auto& tensor = input_tensor;
if constexpr (std::is_same_v<T, MultiDeviceHostStorage>) {
auto updated_storage = std::get<T>(tensor.get_storage());
for (int i = 0; i < updated_storage.shapes.size(); i++) {
updated_storage.shapes[i] = new_shape;
}
return Tensor(updated_storage, new_shape, tensor.get_dtype(), tensor.get_layout(), tile);
}
if constexpr (std::is_same_v<T, MultiDeviceStorage>) {
MultiDeviceStorage updated_storage = std::get<T>(tensor.get_storage());
std::unordered_map<int, ttnn::Shape> new_shapes;

for (auto device_id : updated_storage.ordered_device_ids) {
new_shapes.insert({device_id, new_shape});
}
updated_storage.shapes = new_shapes;
return Tensor(updated_storage, new_shape, tensor.get_dtype(), tensor.get_layout(), tile);
}
if constexpr (std::is_same_v<T, DeviceStorage>) {
if (input_tensor.get_layout() == Layout::ROW_MAJOR) {
if (tensor.memory_config().memory_layout != TensorMemoryLayout::HEIGHT_SHARDED) {
DeviceStorage device_storage = std::get<T>(tensor.get_storage());
DeviceBuffer device_buffer = device_storage.get_buffer();
device_buffer->set_page_size(new_shape[-1] * tensor.element_size());
device_storage.insert_buffer(device_buffer);
return Tensor(device_storage, new_shape, tensor.get_dtype(), tensor.get_layout(), tile);
} else {
DeviceStorage device_storage = std::get<T>(tensor.get_storage());
DeviceBuffer device_buffer = device_storage.get_buffer();
ShardSpecBuffer shard_spec_buffer = device_buffer->shard_spec();

auto shard_spec = shard_spec_buffer.tensor_shard_spec;
auto shard_shape = shard_spec.shape;

uint32_t mul_div = new_shape[-1] > shard_shape[1] ? (new_shape[-1] / shard_shape[1])
: (shard_shape[1] / new_shape[-1]);
shard_spec.shape[0] =
new_shape[-1] > shard_shape[1] ? shard_shape[0] / mul_div : shard_shape[0] * mul_div;
shard_spec.shape[1] = new_shape[-1];

shard_spec_buffer.page_shape = {1, new_shape[-1]};
shard_spec_buffer.tensor2d_shape = {shard_spec.shape[0], 1};
shard_spec_buffer.set_shard_spec(shard_spec);

device_buffer->set_shard_spec(shard_spec_buffer);
device_storage.insert_buffer(device_buffer);

return Tensor(device_storage, new_shape, tensor.get_dtype(), tensor.get_layout(), tile);
}
} else {
return Tensor(tensor.get_storage(), new_shape, tensor.get_dtype(), tensor.get_layout(), tile);
}
} else {
return Tensor(tensor.get_storage(), new_shape, tensor.get_dtype(), tensor.get_layout(), tile);
}
},
input_tensor.get_storage());
output = tt::tt_metal::set_tensor_id(output);
GraphTracker::instance().track_function_end(output);
return output;
}



ttnn::Tensor ReshapeOperation::invoke(const ttnn::Tensor& tensor, const ttnn::SimpleShape& shape) {
return tensor_reshape(tensor, shape);
}

ttnn::Tensor ReshapeOperation::invoke(const ttnn::Tensor& tensor, const ttnn::Shape& shape) {
return tensor_reshape(tensor, shape);
}

} // namespace operations::experimental::reshape
} //namespace ttnn
26 changes: 26 additions & 0 deletions ttnn/cpp/ttnn/operations/experimental/reshape/reshape.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include "ttnn/run_operation.hpp"
#include "ttnn/decorators.hpp"
#include <optional>

namespace ttnn {
namespace operations::experimental::reshape {


struct ReshapeOperation {
static ttnn::Tensor invoke(const ttnn::Tensor& input_tensor, const ttnn::Shape& shape);
static ttnn::Tensor invoke(const ttnn::Tensor& input_tensor, const ttnn::SimpleShape& shape);
};

} // namespace operations::experimental::reshape

namespace experimental {
constexpr auto reshape =
ttnn::register_operation_with_auto_launch_op<"ttnn::experimental::reshape", ttnn::operations::experimental::reshape::ReshapeOperation>();
} // namespace experimental
} // namespace ttnn
72 changes: 72 additions & 0 deletions ttnn/cpp/ttnn/operations/experimental/reshape/reshape_pybind.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#include "reshape_pybind.hpp"
#include "reshape.hpp"

#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

#include "ttnn/cpp/pybind11/decorators.hpp"

#include "ttnn/types.hpp"

#include "ttnn/tensor/tensor.hpp"
#include "ttnn/tensor/tensor_impl.hpp"


namespace ttnn::operations::experimental::reshape::detail {
namespace py = pybind11;

void py_bind_reshape(py::module& module) {
auto doc = R"doc(
Note: for a 0 cost view, the following conditions must be met:
* the last dimension must not change
* In Tiled the second last two dimensions must not change OR there is no padding on the second last dimension
Args:
* input_tensor: Input Tensor.
* new_shape: New shape of tensor.
Returns:
ttnn.Tensor: the output tensor with the new shape.
Example:
>>> tensor = ttnn.from_torch(torch.tensor((1, 4), dtype=torch.bfloat16), device=device)
>>> output = ttnn.experimental.reshape(tensor, (1, 1, 2, 2))
)doc";
bind_registered_operation(
module,
ttnn::experimental::reshape,
doc,
ttnn::pybind_overload_t{
[](const decltype(ttnn::experimental::reshape)& self, ttnn::Tensor& input_tensor, int N, int C, int H, int W) {
return self(input_tensor, infer_dims_for_reshape(input_tensor, ttnn::SmallVector<int>{N, C, H, W}));
},
py::arg("input_tensor"),
py::arg("N"),
py::arg("C"),
py::arg("H"),
py::arg("W"),
},

ttnn::pybind_overload_t{
[](const decltype(ttnn::experimental::reshape)& self, ttnn::Tensor& input_tensor, const ttnn::Shape& shape) {
return self(input_tensor, shape); },
py::arg("input_tensor"),
py::arg("shape"),
},
ttnn::pybind_overload_t{
[](const decltype(ttnn::experimental::reshape)& self, ttnn::Tensor& input_tensor, const ttnn::SmallVector<int32_t>& shape) {
return self(input_tensor, infer_dims_for_reshape(input_tensor, shape));
},
py::arg("input_tensor"),
py::arg("shape"),
});
}

} // namespace ttnn::operations::experimental::reshape::detail
13 changes: 13 additions & 0 deletions ttnn/cpp/ttnn/operations/experimental/reshape/reshape_pybind.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include "pybind11/pybind_fwd.hpp"

namespace ttnn::operations::experimental::reshape::detail {

void py_bind_reshape(pybind11::module& module);

} // namespace ttnn::operations::experimental

0 comments on commit 47a3bcd

Please sign in to comment.