diff --git a/tests/ttnn/unit_tests/test_reshape.py b/tests/ttnn/unit_tests/test_reshape.py index f3ae5e8112fb..0900d428aa98 100644 --- a/tests/ttnn/unit_tests/test_reshape.py +++ b/tests/ttnn/unit_tests/test_reshape.py @@ -35,7 +35,7 @@ def test_reshape_sharded_rm(device, n, c, h, w): torch_input_tensor, layout=ttnn.ROW_MAJOR_LAYOUT, device=device, memory_config=sharded_mem_config ) - tt_output_tensor = tt_input_tensor.reshape(n, c, h * 2, w // 2) + tt_output_tensor = tt_input_tensor.reshape_unsafe(n, c, h * 2, w // 2) sharded_mem_config = ttnn.create_sharded_memory_config( tt_output_tensor.shape,