Skip to content

Commit

Permalink
#13745:link tensor.reshape to ttnn.reshape
Browse files Browse the repository at this point in the history
  • Loading branch information
nardoTT committed Dec 5, 2024
1 parent c6dee3b commit 1467253
Showing 1 changed file with 42 additions and 0 deletions.
42 changes: 42 additions & 0 deletions ttnn/cpp/pybind11/pytensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include "ttnn/common/constants.hpp"
#include "ttnn/operations/core/core.hpp"


using namespace tt::tt_metal;

namespace py = pybind11;
Expand Down Expand Up @@ -1571,6 +1572,47 @@ void pytensor_module(py::module& m_tensor) {
dtype = tt_tensor.get_dtype()
)doc")


.def(
"reshape",
[](Tensor &self, int N, int C, int H, int W) {
return ttnn::reshape(self, infer_dims_for_reshape(self, ttnn::SmallVector<int>{N, C, H, W})); //self.reshape(infer_dims_for_reshape(self, ttnn::SmallVector<int>{N, C, H, W}));
},
R"doc(
Reshapes TT tensor
.. code-block:: python
reshaped_tensor = tt_tensor.reshape(N, C, H, W)
)doc")
.def(
"reshape",
[](Tensor &self, const ttnn::Shape &shape) -> Tensor { return ttnn::reshape(self, shape);},//self.reshape(shape); },
R"doc(
Reshapes TT tensor
.. code-block:: python
reshaped_tensor = tt_tensor.reshape((4, 3, 32))
)doc")
.def(
"reshape",
[](Tensor &self, const ttnn::SmallVector<int32_t> &shape) -> Tensor {
//return self.reshape(infer_dims_for_reshape(self, shape));
return ttnn::reshape(self, infer_dims_for_reshape(self, shape));
},
R"doc(
Reshapes TT tensor
.. code-block:: python
reshaped_tensor = tt_tensor.reshape((4, -1, 32))
)doc")




.def(
"reshape_unsafe",
[](Tensor& self, int N, int C, int H, int W) {
Expand Down

0 comments on commit 1467253

Please sign in to comment.