Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

#15863: Implementing the view operation #15865

Merged
merged 11 commits into from
Dec 10, 2024
71 changes: 71 additions & 0 deletions tests/ttnn/unit_tests/test_view.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

import pytest

import torch

import ttnn

from tests.ttnn.utils_for_testing import assert_with_pcc


# Reshape in Tile layout with shapes that are not divisible by 32
@pytest.mark.parametrize(
"input_shape, output_shape, layout",
[
((1, 15), (15,), ttnn.ROW_MAJOR_LAYOUT), # RM_last dimension matches, 1D output
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could add test which demonstrates that it will throw if I pass wrong dimensions?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That is done in test_invalid_cases

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just saw it! thanks.

((2, 1, 1, 1, 15), (2, 15), ttnn.ROW_MAJOR_LAYOUT), # RM_last dimension matches
((16, 1, 1, 247, 13), (1, 16, 247, 13), ttnn.TILE_LAYOUT), # last two dimensions match
(
(16, 1, 1, 256, 16),
(8, 16, 32, 16),
ttnn.TILE_LAYOUT,
), # last dimension match but second last multiple of 32 but does not match
((32, 32, 32, 15), (32768, 15), ttnn.TILE_LAYOUT), # Very large tensor
],
)
def test_view(input_shape, output_shape, layout, device):
torch_input_tensor = torch.randn(input_shape, dtype=torch.bfloat16)
torch_result = torch_input_tensor.reshape(output_shape)

input_tensor = ttnn.from_torch(torch_input_tensor, layout=layout, dtype=ttnn.bfloat16, device=device)
ttnn_output = ttnn.view(input_tensor, output_shape)
assert layout == ttnn_output.layout
output = ttnn.to_torch(ttnn_output)
assert_with_pcc(torch_result, output, 0.9999)


@pytest.mark.parametrize(
"input_shape, output_shape, layout",
[
((2, 1, 1, 1, 15), (1, 30), ttnn.ROW_MAJOR_LAYOUT), # RM last dimension doesn't match
(
(16, 1, 256, 1, 16),
(8, 16, 32, 16),
ttnn.TILE_LAYOUT,
), # TILE last dimension match but second last does not match, shape mult of 32 only
(
(16, 1, 1, 256, 16),
(8, 16, 32, 1, 16),
ttnn.TILE_LAYOUT,
), # TILE last dimension match but second last does not match, tensor mult of 32 only
(
(256, 1, 1, 16, 16),
(8, 16, 32, 1, 16),
ttnn.TILE_LAYOUT,
), # TILE last dimension match but second last does not match, none mult of 32
(
(16, 8, 1, 32, 16),
(8, 16, 31, 16),
ttnn.TILE_LAYOUT,
), # Volume doesn't match but padded volume does
],
)
def test_invalid_cases(input_shape, output_shape, layout, device):
# Verifies invalid cases do cause an assertion
torch_input_tensor = torch.rand(input_shape, dtype=torch.bfloat16)
input_tensor = ttnn.from_torch(torch_input_tensor, dtype=ttnn.bfloat16, layout=layout, device=device)
with pytest.raises(RuntimeError):
ttnn.view(input_tensor, output_shape)
2 changes: 2 additions & 0 deletions ttnn/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,8 @@ set(ALL_TTNN_SRCS
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/data_movement/reshape_on_device/reshape_pybind.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/data_movement/reshape_on_device/device/reshape_op.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/data_movement/reshape_on_device/device/reshape_program_factory.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/data_movement/view/view.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/data_movement/view/view_pybind.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/data_movement/reshape_view/reshape.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/data_movement/reshape_view/reshape_pybind.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/data_movement/reshape_view/device/reshape_rm_op.cpp
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include "ttnn/operations/data_movement/repeat_interleave/repeat_interleave_pybind.hpp"
#include "ttnn/operations/data_movement/reshape_on_device/reshape_pybind.hpp"
#include "ttnn/operations/data_movement/reshape_view/reshape_pybind.hpp"
#include "ttnn/operations/data_movement/view/view_pybind.hpp"
#include "ttnn/operations/data_movement/sharded_partial/interleaved_to_sharded_partial/interleaved_to_sharded_partial_pybind.hpp"
#include "ttnn/operations/data_movement/sharded_partial/sharded_to_interleaved_partial/sharded_to_interleaved_partial_pybind.hpp"
#include "ttnn/operations/data_movement/slice/slice_pybind.hpp"
Expand Down Expand Up @@ -75,6 +76,7 @@ void py_module(py::module& module) {
py_bind_repeat(module);
py_bind_reshape(module);
py_bind_reshape_view(module);
py_bind_view(module);
py_bind_reshard(module);
py_bind_sharded_to_interleaved(module);
py_bind_sharded_to_interleaved_partial(module);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ namespace detail {
ttnn::Tensor convert_tile_to_rm(const ttnn::Tensor& tensor, const ttnn::Shape& shape, const uint32_t tile_first_dim, const uint32_t tile_second_dim, const MemoryConfig &memory_config, const uint8_t queue_id, const PadValue &pad_value);
}

ttnn::Shape shape_corrector(const ttnn::Tensor& tensor, const ttnn::Shape& shape);
ttnn::Shape tiling_reshape_corrector(const ttnn::Shape& shape);
ttnn::Tensor PerformView(const ttnn::Tensor& tensor, const ttnn::Shape& shapeconst, uint32_t tile_first_dim, const uint32_t tile_second_dim);
void Validate_transform (const ttnn::Shape& input_shape, const ttnn::Shape& output_shape);
Expand Down
52 changes: 52 additions & 0 deletions ttnn/cpp/ttnn/operations/data_movement/view/view.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#include "view.hpp"
#include "ttnn/cpp/ttnn/operations/data_movement/reshape_view/reshape.hpp"

namespace ttnn::operations::data_movement {

ttnn::Tensor ViewOperation::invoke(const ttnn::Tensor& tensor, const ttnn::SimpleShape& shape) {
auto layout = tensor.get_layout();
jvegaTT marked this conversation as resolved.
Show resolved Hide resolved
auto tensor_shape = tensor.get_shape();
// First Case, No reshape Required
if (tensor_shape == shape) {
return tensor;
}

const uint32_t tile_first_dim = tensor.get_tensor_spec().tile().get_width();
const uint32_t tile_second_dim = tensor.get_tensor_spec().tile().get_height();
const uint32_t shape_second_last_dim = shape.rank() >= 2 ? shape[-2] : 1;
const uint32_t tensor_shape_second_last_dim = tensor_shape.rank() >= 2 ? tensor_shape[-2] : 1;
// Validate the operation
TT_FATAL(
shape.volume() == tensor.get_logical_volume(),
"Invalid view, logical volumes are changing from {} to {}",
tensor.get_logical_volume(),
shape.volume());
TT_FATAL(
ttnn::has_storage_type_of(tensor, ttnn::StorageType::DEVICE),
"View requires the tensor be stored on device, use reshape instead");
TT_FATAL(
(tensor_shape[-1] == shape[-1]),
"The last dimension can not change in view, attempting to change last dimension from {} to {}, use reshape "
"instead",
tensor_shape[-1],
shape[-1]);
TT_FATAL(
(tensor.get_layout() == ttnn::ROW_MAJOR_LAYOUT) || // Its row major
(tensor_shape_second_last_dim == shape_second_last_dim) || // Second last dimension is the same
((shape_second_last_dim % tile_second_dim == 0) && (tensor_shape_second_last_dim % tile_second_dim == 0)),
"Invalid second last dims for TILED reshape, from {} to {}, use reshape instead\n",
tensor_shape_second_last_dim,
shape_second_last_dim);
// Perform the View
return PerformView(tensor, shape, tile_first_dim, tile_second_dim);
}

ttnn::Tensor ViewOperation::invoke(const ttnn::Tensor& tensor, tt::stl::Span<const int32_t> shape_vector) {
return invoke(tensor, tt::tt_metal::infer_dims_for_reshape(tensor, shape_vector));
}

} // namespace ttnn::operations::data_movement
20 changes: 20 additions & 0 deletions ttnn/cpp/ttnn/operations/data_movement/view/view.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include "ttnn/decorators.hpp"

namespace ttnn {
namespace operations::data_movement {

struct ViewOperation {
static ttnn::Tensor invoke(const ttnn::Tensor& input_tensor, const ttnn::Shape& shape);
static ttnn::Tensor invoke(const ttnn::Tensor& input_tensor, const ttnn::SimpleShape& logical_shape);
bbradelTT marked this conversation as resolved.
Show resolved Hide resolved
static ttnn::Tensor invoke(const ttnn::Tensor& input_tensor, tt::stl::Span<const int32_t> shape_vector);
};

} // namespace operations::data_movement
constexpr auto view = ttnn::register_operation<"ttnn::view", ttnn::operations::data_movement::ViewOperation>();
} // namespace ttnn
69 changes: 69 additions & 0 deletions ttnn/cpp/ttnn/operations/data_movement/view/view_pybind.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#include "view_pybind.hpp"

#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

#include "ttnn/cpp/pybind11/decorators.hpp"
#include "ttnn/operations/data_movement/view/view.hpp"
#include "ttnn/types.hpp"

namespace ttnn::operations::data_movement {

namespace detail {

template <typename data_movement_operation_t>
void bind_view(pybind11::module& module, const data_movement_operation_t& operation, const char* doc) {
bind_registered_operation(
module,
operation,
doc,
ttnn::pybind_overload_t{
[](const data_movement_operation_t& self, const ttnn::Tensor& input_tensor, const ttnn::SimpleShape& shape)
-> ttnn::Tensor { return self(input_tensor, shape); },
py::arg("input_tensor"),
py::arg("shape"),
},
ttnn::pybind_overload_t{
[](const data_movement_operation_t& self,
const ttnn::Tensor& input_tensor,
const ttnn::SmallVector<int32_t> shape) -> ttnn::Tensor { return self(input_tensor, shape); },
py::arg("input_tensor"),
py::arg("shape"),
});
}

} // namespace detail

void py_bind_view(pybind11::module& module) {
detail::bind_view(
module,
ttnn::view,
R"doc(

This is a 0 cost view operation that returns the same tensor that was passed to it but with a new shape

Note: The following conditions must be met:
* the memory is stored on the device
* the last dimension must not change
* In Tiled the second last two dimensions must not change OR there is no padding on the second last dimension

Args:
* input_tensor: Input Tensor.
* new_shape: New shape of tensor.

Returns:
ttnn.Tensor: a reference to the input tensor but with the new shape.

Example:

>>> tensor = ttnn.from_torch(torch.tensor((2, 1, 1, 1, 4), dtype=torch.bfloat16), device=device)
>>> output = ttnn.view(tensor, (2, 1, 4))

)doc");

} // namespace ttnn::operations::data_movement
} // namespace ttnn::operations::data_movement
13 changes: 13 additions & 0 deletions ttnn/cpp/ttnn/operations/data_movement/view/view_pybind.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include "pybind11/pybind_fwd.hpp"

namespace ttnn::operations::data_movement {

void py_bind_view(pybind11::module& module);

} // namespace ttnn::operations::data_movement
Loading