diff --git a/ttnn/cpp/ttnn/operations/data_movement/reshape_view/reshape.cpp b/ttnn/cpp/ttnn/operations/data_movement/reshape_view/reshape.cpp index 63f9fdd7c045..e45f0805b824 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/reshape_view/reshape.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/reshape_view/reshape.cpp @@ -35,6 +35,7 @@ ttnn::Tensor host_reshape(const ttnn::Tensor& tensor, const ttnn::Shape& shape) ttnn::Tensor slice_input; std::vector begins; std::vector ends; + TT_FATAL(tensor_shape.rank() <= 4, "Only up to 4D tensors"); auto host_tensor_4d = unsqueeze_to_4D(rm_tensor); auto tensor_shape_4d = host_tensor_4d.shape(); begins = std::vector({0, 0, 0, 0}); @@ -98,14 +99,19 @@ ttnn::Tensor row_major_reshape(const ttnn::Tensor& tensor, const ttnn::Shape& sh } -ttnn::Shape get_shape_from_vector(const ttnn::Tensor& tensor, const std::vector & shape) { +ttnn::Shape get_shape_from_vector_with_possible_negative_values(const ttnn::Tensor& tensor, const std::vector & shape) { std::int64_t new_volume = 1; std::int64_t index_of_negative_1 = -1; for (auto index = 0; index < shape.size(); ++index) { if (shape[index] == -1) { if (index_of_negative_1 != -1) { - TT_THROW("Shape cannot have more than 1 elements that is set to -1!"); + std::string error_msg = "Shape cannot have more than 1 elements that is set to -1! Shape used: ("; + for(auto & s: shape) { + error_msg += std::to_string(s) + ","; + } + error_msg += ")"; + TT_THROW(error_msg.c_str()); } index_of_negative_1 = index; } @@ -153,7 +159,7 @@ ttnn::Tensor ReshapeViewOperation::invoke( const std::vector & shape_vector ) { - auto shape = detail::get_shape_from_vector(tensor, shape_vector); + auto shape = detail::get_shape_from_vector_with_possible_negative_values(tensor, shape_vector); return invoke(tensor, shape); }