-
Notifications
You must be signed in to change notification settings - Fork 91
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
#15863: Implementing the view operation #15865
Merged
Merged
Changes from all commits
Commits
Show all changes
11 commits
Select commit
Hold shift + click to select a range
954b28d
#15863: Implementing the view operation
jvegaTT fc85881
#0: updating copyright
jvegaTT b25af68
#0: small oops
jvegaTT f82dc8f
#0: comments cleanup
jvegaTT fe48804
#0: fix doc message for view pybind
jvegaTT 18750a2
#15863: changing from Shape to SimpleShape
jvegaTT c02254e
Merge branch 'jvega/view_op_implementation' of github.com:tenstorrent…
jvegaTT 41f44d1
Merge branch 'main' into jvega/view_op_implementation
jvegaTT e4c0936
#15863: removing Shape from hpp
jvegaTT 2f7db75
Merge branch 'jvega/view_op_implementation' of github.com:tenstorrent…
jvegaTT 12acb7a
Merge branch 'main' into jvega/view_op_implementation
jvegaTT File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. | ||
|
||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
import pytest | ||
|
||
import torch | ||
|
||
import ttnn | ||
|
||
from tests.ttnn.utils_for_testing import assert_with_pcc | ||
|
||
|
||
# Reshape in Tile layout with shapes that are not divisible by 32 | ||
@pytest.mark.parametrize( | ||
"input_shape, output_shape, layout", | ||
[ | ||
((1, 15), (15,), ttnn.ROW_MAJOR_LAYOUT), # RM_last dimension matches, 1D output | ||
((2, 1, 1, 1, 15), (2, 15), ttnn.ROW_MAJOR_LAYOUT), # RM_last dimension matches | ||
((16, 1, 1, 247, 13), (1, 16, 247, 13), ttnn.TILE_LAYOUT), # last two dimensions match | ||
( | ||
(16, 1, 1, 256, 16), | ||
(8, 16, 32, 16), | ||
ttnn.TILE_LAYOUT, | ||
), # last dimension match but second last multiple of 32 but does not match | ||
((32, 32, 32, 15), (32768, 15), ttnn.TILE_LAYOUT), # Very large tensor | ||
], | ||
) | ||
def test_view(input_shape, output_shape, layout, device): | ||
torch_input_tensor = torch.randn(input_shape, dtype=torch.bfloat16) | ||
torch_result = torch_input_tensor.reshape(output_shape) | ||
|
||
input_tensor = ttnn.from_torch(torch_input_tensor, layout=layout, dtype=ttnn.bfloat16, device=device) | ||
ttnn_output = ttnn.view(input_tensor, output_shape) | ||
assert layout == ttnn_output.layout | ||
output = ttnn.to_torch(ttnn_output) | ||
assert_with_pcc(torch_result, output, 0.9999) | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"input_shape, output_shape, layout", | ||
[ | ||
((2, 1, 1, 1, 15), (1, 30), ttnn.ROW_MAJOR_LAYOUT), # RM last dimension doesn't match | ||
( | ||
(16, 1, 256, 1, 16), | ||
(8, 16, 32, 16), | ||
ttnn.TILE_LAYOUT, | ||
), # TILE last dimension match but second last does not match, shape mult of 32 only | ||
( | ||
(16, 1, 1, 256, 16), | ||
(8, 16, 32, 1, 16), | ||
ttnn.TILE_LAYOUT, | ||
), # TILE last dimension match but second last does not match, tensor mult of 32 only | ||
( | ||
(256, 1, 1, 16, 16), | ||
(8, 16, 32, 1, 16), | ||
ttnn.TILE_LAYOUT, | ||
), # TILE last dimension match but second last does not match, none mult of 32 | ||
( | ||
(16, 8, 1, 32, 16), | ||
(8, 16, 31, 16), | ||
ttnn.TILE_LAYOUT, | ||
), # Volume doesn't match but padded volume does | ||
], | ||
) | ||
def test_invalid_cases(input_shape, output_shape, layout, device): | ||
# Verifies invalid cases do cause an assertion | ||
torch_input_tensor = torch.rand(input_shape, dtype=torch.bfloat16) | ||
input_tensor = ttnn.from_torch(torch_input_tensor, dtype=ttnn.bfloat16, layout=layout, device=device) | ||
with pytest.raises(RuntimeError): | ||
ttnn.view(input_tensor, output_shape) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. | ||
// | ||
// SPDX-License-Identifier: Apache-2.0 | ||
|
||
#include "view.hpp" | ||
#include "ttnn/cpp/ttnn/operations/data_movement/reshape_view/reshape.hpp" | ||
|
||
namespace ttnn::operations::data_movement { | ||
|
||
ttnn::Tensor ViewOperation::invoke(const ttnn::Tensor& tensor, const ttnn::SimpleShape& shape) { | ||
auto layout = tensor.get_layout(); | ||
jvegaTT marked this conversation as resolved.
Show resolved
Hide resolved
|
||
auto tensor_shape = tensor.get_shape(); | ||
// First Case, No reshape Required | ||
if (tensor_shape == shape) { | ||
return tensor; | ||
} | ||
|
||
const uint32_t tile_first_dim = tensor.get_tensor_spec().tile().get_width(); | ||
const uint32_t tile_second_dim = tensor.get_tensor_spec().tile().get_height(); | ||
const uint32_t shape_second_last_dim = shape.rank() >= 2 ? shape[-2] : 1; | ||
const uint32_t tensor_shape_second_last_dim = tensor_shape.rank() >= 2 ? tensor_shape[-2] : 1; | ||
// Validate the operation | ||
TT_FATAL( | ||
shape.volume() == tensor.get_logical_volume(), | ||
"Invalid view, logical volumes are changing from {} to {}", | ||
tensor.get_logical_volume(), | ||
shape.volume()); | ||
TT_FATAL( | ||
ttnn::has_storage_type_of(tensor, ttnn::StorageType::DEVICE), | ||
"View requires the tensor be stored on device, use reshape instead"); | ||
TT_FATAL( | ||
(tensor_shape[-1] == shape[-1]), | ||
"The last dimension can not change in view, attempting to change last dimension from {} to {}, use reshape " | ||
"instead", | ||
tensor_shape[-1], | ||
shape[-1]); | ||
TT_FATAL( | ||
(tensor.get_layout() == ttnn::ROW_MAJOR_LAYOUT) || // Its row major | ||
(tensor_shape_second_last_dim == shape_second_last_dim) || // Second last dimension is the same | ||
((shape_second_last_dim % tile_second_dim == 0) && (tensor_shape_second_last_dim % tile_second_dim == 0)), | ||
"Invalid second last dims for TILED reshape, from {} to {}, use reshape instead\n", | ||
tensor_shape_second_last_dim, | ||
shape_second_last_dim); | ||
// Perform the View | ||
return PerformView(tensor, shape, tile_first_dim, tile_second_dim); | ||
} | ||
|
||
ttnn::Tensor ViewOperation::invoke(const ttnn::Tensor& tensor, tt::stl::Span<const int32_t> shape_vector) { | ||
return invoke(tensor, tt::tt_metal::infer_dims_for_reshape(tensor, shape_vector)); | ||
} | ||
|
||
} // namespace ttnn::operations::data_movement |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. | ||
// | ||
// SPDX-License-Identifier: Apache-2.0 | ||
|
||
#pragma once | ||
|
||
#include "ttnn/decorators.hpp" | ||
|
||
namespace ttnn { | ||
namespace operations::data_movement { | ||
|
||
struct ViewOperation { | ||
static ttnn::Tensor invoke(const ttnn::Tensor& input_tensor, const ttnn::SimpleShape& logical_shape); | ||
bbradelTT marked this conversation as resolved.
Show resolved
Hide resolved
|
||
static ttnn::Tensor invoke(const ttnn::Tensor& input_tensor, tt::stl::Span<const int32_t> shape_vector); | ||
}; | ||
|
||
} // namespace operations::data_movement | ||
constexpr auto view = ttnn::register_operation<"ttnn::view", ttnn::operations::data_movement::ViewOperation>(); | ||
} // namespace ttnn |
69 changes: 69 additions & 0 deletions
69
ttnn/cpp/ttnn/operations/data_movement/view/view_pybind.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. | ||
// | ||
// SPDX-License-Identifier: Apache-2.0 | ||
|
||
#include "view_pybind.hpp" | ||
|
||
#include <pybind11/pybind11.h> | ||
#include <pybind11/stl.h> | ||
|
||
#include "ttnn/cpp/pybind11/decorators.hpp" | ||
#include "ttnn/operations/data_movement/view/view.hpp" | ||
#include "ttnn/types.hpp" | ||
|
||
namespace ttnn::operations::data_movement { | ||
|
||
namespace detail { | ||
|
||
template <typename data_movement_operation_t> | ||
void bind_view(pybind11::module& module, const data_movement_operation_t& operation, const char* doc) { | ||
bind_registered_operation( | ||
module, | ||
operation, | ||
doc, | ||
ttnn::pybind_overload_t{ | ||
[](const data_movement_operation_t& self, const ttnn::Tensor& input_tensor, const ttnn::SimpleShape& shape) | ||
-> ttnn::Tensor { return self(input_tensor, shape); }, | ||
py::arg("input_tensor"), | ||
py::arg("shape"), | ||
}, | ||
ttnn::pybind_overload_t{ | ||
[](const data_movement_operation_t& self, | ||
const ttnn::Tensor& input_tensor, | ||
const ttnn::SmallVector<int32_t> shape) -> ttnn::Tensor { return self(input_tensor, shape); }, | ||
py::arg("input_tensor"), | ||
py::arg("shape"), | ||
}); | ||
} | ||
|
||
} // namespace detail | ||
|
||
void py_bind_view(pybind11::module& module) { | ||
detail::bind_view( | ||
module, | ||
ttnn::view, | ||
R"doc( | ||
|
||
This is a 0 cost view operation that returns the same tensor that was passed to it but with a new shape | ||
|
||
Note: The following conditions must be met: | ||
* the memory is stored on the device | ||
* the last dimension must not change | ||
* In Tiled the second last two dimensions must not change OR there is no padding on the second last dimension | ||
|
||
Args: | ||
* input_tensor: Input Tensor. | ||
* new_shape: New shape of tensor. | ||
|
||
Returns: | ||
ttnn.Tensor: a reference to the input tensor but with the new shape. | ||
|
||
Example: | ||
|
||
>>> tensor = ttnn.from_torch(torch.tensor((2, 1, 1, 1, 4), dtype=torch.bfloat16), device=device) | ||
>>> output = ttnn.view(tensor, (2, 1, 4)) | ||
|
||
)doc"); | ||
|
||
} // namespace ttnn::operations::data_movement | ||
} // namespace ttnn::operations::data_movement |
13 changes: 13 additions & 0 deletions
13
ttnn/cpp/ttnn/operations/data_movement/view/view_pybind.hpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. | ||
// | ||
// SPDX-License-Identifier: Apache-2.0 | ||
|
||
#pragma once | ||
|
||
#include "pybind11/pybind_fwd.hpp" | ||
|
||
namespace ttnn::operations::data_movement { | ||
|
||
void py_bind_view(pybind11::module& module); | ||
|
||
} // namespace ttnn::operations::data_movement |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
could add test which demonstrates that it will throw if I pass wrong dimensions?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That is done in test_invalid_cases
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I just saw it! thanks.