diff --git a/tests/tt_eager/python_api_testing/sweep_tests/pytests/tt_dnn/test_repeat_interleave.py b/tests/tt_eager/python_api_testing/sweep_tests/pytests/tt_dnn/test_repeat_interleave.py index 1a297d86aa1f..cc3b85dfb8eb 100644 --- a/tests/tt_eager/python_api_testing/sweep_tests/pytests/tt_dnn/test_repeat_interleave.py +++ b/tests/tt_eager/python_api_testing/sweep_tests/pytests/tt_dnn/test_repeat_interleave.py @@ -15,7 +15,7 @@ from tests.tt_eager.python_api_testing.sweep_tests.run_pytorch_ci_tests import ( run_single_pytorch_test, ) -from models.utility_functions import is_wormhole_b0, skip_for_blackhole +from models.utility_functions import is_wormhole_b0, is_grayskull, skip_for_blackhole shapes = ( [[1, 1, 32, 32]], # Single core @@ -30,6 +30,8 @@ @pytest.mark.parametrize("dim", [0, 2, -4, -2, 1, 3]) @pytest.mark.parametrize("repeat", [2, 3, 4]) def test_run_repeat_interleave_test(input_shapes, dim, repeat, device): + if is_grayskull and dim == 3: + pytest.skip("Grayskull does not support dim=3 because we cannot tranpose WH reliably") datagen_func = [ generation_funcs.gen_func_with_cast( partial(generation_funcs.gen_rand_along_dim, low=-100, high=100), diff --git a/tests/tt_eager/python_api_testing/unit_testing/misc/test_transpose.py b/tests/tt_eager/python_api_testing/unit_testing/misc/test_transpose.py index 0581ad0fb8d2..3973e8fa5726 100644 --- a/tests/tt_eager/python_api_testing/unit_testing/misc/test_transpose.py +++ b/tests/tt_eager/python_api_testing/unit_testing/misc/test_transpose.py @@ -660,22 +660,25 @@ def test_transpose_hc(dtype, shape, device): ) @pytest.mark.parametrize( "shape", - [(1, 32), (1, 12), (1, 35), (16, 32), (34, 8)], + [(9216, 128), (1, 32), (1, 12), (1, 35), (16, 32), (34, 8)], ) @pytest.mark.parametrize( "layout", - [ttnn.ROW_MAJOR_LAYOUT, ttnn.TILE_LAYOUT], + [ttnn.TILE_LAYOUT], ) -@pytest.mark.parametrize( - "dims", - [(1, 0), (-1, -2)], -) -def test_transpose_2D(dtype, shape, layout, dims, device): +def test_transpose_2D(dtype, shape, layout, device): if is_grayskull() and dtype == ttnn.float32: pytest.skip("Skipping float32 tests on Grayskull") - if layout == ttnn.ROW_MAJOR_LAYOUT and dtype == ttnn.bfloat16 and shape[-1] % 2: + if layout == ttnn.ROW_MAJOR_LAYOUT and dtype == ttnn.bfloat16 and (shape[-1] % 2 or shape[-2] % 2): pytest.skip("Skipping RM odd inner dim test cases") - transpose(shape, device, dim0=0, dim1=1, input_dtype=dtype) + + torch_input = torch.randn(shape, dtype=torch.bfloat16) + torch_output = torch_input.transpose(0, 1) + + tt_input = ttnn.from_torch(torch_input, dtype=ttnn.DataType.BFLOAT16, layout=layout, device=device) + tt_output = ttnn.transpose(tt_input, 0, 1) + tt_output = ttnn.to_torch(tt_output) + assert_with_pcc(torch_output, tt_output, 0.9999) @pytest.mark.parametrize( @@ -689,7 +692,7 @@ def test_transpose_2D(dtype, shape, layout, dims, device): ) @pytest.mark.parametrize( "layout", - [ttnn.ROW_MAJOR_LAYOUT, ttnn.TILE_LAYOUT], + [ttnn.TILE_LAYOUT], ) @pytest.mark.parametrize( "dims", @@ -698,10 +701,41 @@ def test_transpose_2D(dtype, shape, layout, dims, device): def test_transpose_3D(dtype, shape, layout, dims, device): if is_grayskull() and dtype == ttnn.float32: pytest.skip("Skipping float32 tests on Grayskull") - - new_shape = shape - new_shape[dims[0]], new_shape[dims[1]] = shape[dims[1]], shape[dims[0]] - if layout == ttnn.ROW_MAJOR_LAYOUT and dtype == ttnn.bfloat16 and (shape[-1] % 2 or new_shape[-1] % 2): + if layout == ttnn.ROW_MAJOR_LAYOUT and dtype == ttnn.bfloat16 and (shape[-1] % 2 or shape[dims[-1]] % 2): pytest.skip("Skipping RM odd inner dim test cases") - transpose(shape, device, dim0=dims[0], dim1=dims[1], input_dtype=dtype) + torch_input = torch.randn(shape, dtype=torch.bfloat16) + torch_output = torch_input.transpose(dims[0], dims[1]) + + tt_input = ttnn.from_torch(torch_input, dtype=ttnn.DataType.BFLOAT16, layout=layout, device=device) + tt_output = ttnn.transpose(tt_input, dims[0], dims[1]) + tt_output = ttnn.to_torch(tt_output) + assert_with_pcc(torch_output, tt_output, 0.9999) + + +@pytest.mark.parametrize( + "shape", + [[4, 3, 1280, 40], [1, 4096, 4096]], +) +def test_transpose_4d_wh_rm(shape, device): + torch_input = torch.randn(shape, dtype=torch.bfloat16) + torch_output = torch_input.transpose(-1, -2) + + tt_input = ttnn.from_torch(torch_input, dtype=ttnn.DataType.BFLOAT16, layout=ttnn.ROW_MAJOR_LAYOUT, device=device) + tt_output = ttnn.transpose(tt_input, -1, -2) + tt_output = ttnn.to_torch(tt_output) + assert_with_pcc(torch_output, tt_output, 0.9999) + + +@pytest.mark.parametrize( + "shape", + [[4, 3, 1280, 40], [1, 1200, 1280]], +) +def test_transpose_4d_wh_tile(shape, device): + torch_input = torch.randn(shape, dtype=torch.bfloat16) + torch_output = torch_input.transpose(-1, -2) + + tt_input = ttnn.from_torch(torch_input, dtype=ttnn.DataType.BFLOAT16, layout=ttnn.TILE_LAYOUT, device=device) + tt_output = ttnn.transpose(tt_input, -1, -2) + tt_output = ttnn.to_torch(tt_output) + assert_with_pcc(torch_output, tt_output, 0.9999) diff --git a/tests/ttnn/unit_tests/gtests/CMakeLists.txt b/tests/ttnn/unit_tests/gtests/CMakeLists.txt index 3937aa52f132..bb3672127434 100644 --- a/tests/ttnn/unit_tests/gtests/CMakeLists.txt +++ b/tests/ttnn/unit_tests/gtests/CMakeLists.txt @@ -2,7 +2,6 @@ set(TTNN_UNIT_TESTS_SRC ${CMAKE_CURRENT_SOURCE_DIR}/test_add.cpp ${CMAKE_CURRENT_SOURCE_DIR}/test_graph_add.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/test_repeat_interleave.cpp ${CMAKE_CURRENT_SOURCE_DIR}/test_async_runtime.cpp ${CMAKE_CURRENT_SOURCE_DIR}/test_multiprod_queue.cpp ${CMAKE_CURRENT_SOURCE_DIR}/test_multi_cq_multi_dev.cpp diff --git a/tests/ttnn/unit_tests/gtests/test_repeat_interleave.cpp b/tests/ttnn/unit_tests/gtests/test_repeat_interleave.cpp deleted file mode 100644 index 1dee81c29e06..000000000000 --- a/tests/ttnn/unit_tests/gtests/test_repeat_interleave.cpp +++ /dev/null @@ -1,105 +0,0 @@ -// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. -// -// SPDX-License-Identifier: Apache-2.0 - -#include "gtest/gtest.h" - -#include "tt_metal/common/bfloat16.hpp" -#include "ttnn/device.hpp" -#include "ttnn/operations/core/core.hpp" -#include "ttnn/async_runtime.hpp" -#include "ttnn/operations/data_movement/repeat_interleave/repeat_interleave.hpp" -#include "ttnn/operations/numpy/functions.hpp" -#include "tt_metal/common/logger.hpp" - -#include "ttnn_test_fixtures.hpp" - -#include - -namespace ttnn { -namespace operations { -namespace data_movement { -namespace test { - -void run_repeat_interleave_test(tt::tt_metal::Device* device, const uint32_t repeats, const uint32_t dim) { - MemoryConfig mem_cfg; - mem_cfg.memory_layout = tt::tt_metal::TensorMemoryLayout::INTERLEAVED; - mem_cfg.buffer_type = BufferType::DRAM; - - const uint32_t io_cq = 0; - const uint32_t input_buf_size_datums = 32 * 32; - const uint32_t output_buf_size_datums = input_buf_size_datums * repeats; - const uint32_t datum_size_bytes = 2; - ttnn::SimpleShape input_shape{1, 1, 32, 32}; - auto host_data = std::shared_ptr(new uint16_t[input_buf_size_datums]); - auto readback_data = std::shared_ptr(new uint16_t[output_buf_size_datums]); - - for (uint16_t i = 0; i < 32; i++) { - for (uint16_t j = 0; j < 32; j++) { - host_data[i * 32 + j] = i; - } - } - - auto input_buffer = ttnn::allocate_buffer_on_device(input_buf_size_datums * datum_size_bytes, device, input_shape, DataType::UINT16, Layout::TILE, mem_cfg); - auto input_storage = tt::tt_metal::DeviceStorage{input_buffer}; - Tensor input_tensor = Tensor(input_storage, input_shape, DataType::UINT16, Layout::TILE); - ttnn::write_buffer(io_cq, input_tensor, {host_data}); - - ttnn::Tensor output_tensor = ttnn::repeat_interleave(input_tensor, repeats, dim); - - ttnn::read_buffer(io_cq, output_tensor, {readback_data}); - - tt::log_debug("input_data: \n {}", input_tensor.write_to_string()); - tt::log_debug("readback_data: \n {}", output_tensor.write_to_string()); - - for (int i = 0; i < input_buf_size_datums; i++) { - auto input_value = host_data[i]; - for(int r = 0; r < repeats; r++) { - auto value = readback_data[i + r * input_buf_size_datums]; - ASSERT_EQ(input_value, value); - } - } - - input_tensor.deallocate(); - output_tensor.deallocate(); -} - -struct RepeatInterleaveParams { - int repeats = 0; - int dim = 0; -}; - -class RepeatInterleaveTest : public ttnn::TTNNFixtureWithDevice, public ::testing::WithParamInterface {}; - -TEST_P(RepeatInterleaveTest, RunsCorrectly) { - RepeatInterleaveParams params = GetParam(); - run_repeat_interleave_test(device_, params.repeats, params.dim); -} - -INSTANTIATE_TEST_SUITE_P( - RepeatInterleaveWithDim0, - RepeatInterleaveTest, - ::testing::Values( - RepeatInterleaveParams{1, 0}, - RepeatInterleaveParams{2, 0}, - RepeatInterleaveParams{3, 0} - ) -); - -// tests/ttnn/unit_tests/operations/test_repeat_interleave.py proves that it should work over dim 1 too -// likely need to fix the comparison in the test -INSTANTIATE_TEST_SUITE_P( - DISABLED_RepeatInterleaveWithDim1, - RepeatInterleaveTest, - ::testing::Values( - RepeatInterleaveParams{1, 1}, - RepeatInterleaveParams{2, 1}, - RepeatInterleaveParams{3, 1} - ) -); - - -} // namespace test -} // namespace binary -} // namespace operations -} // namespace ttnn diff --git a/tests/ttnn/unit_tests/operations/test_concat.py b/tests/ttnn/unit_tests/operations/test_concat.py index 62af874f4175..a758ef0edbdb 100644 --- a/tests/ttnn/unit_tests/operations/test_concat.py +++ b/tests/ttnn/unit_tests/operations/test_concat.py @@ -112,3 +112,14 @@ def test_sharded_concat( assert_with_pcc(torch_output_tensor, output, 0.9999) assert_with_pcc(torch_output_tensor, output) + + +@pytest.mark.parametrize("dim", [0, 1, 2, 3]) +def test_concat_5d(device, dim): + torch_input_tensor = torch.rand(1, 1, 1, 1, 2, dtype=torch.bfloat16) + torch_result = torch.cat([torch_input_tensor, torch_input_tensor], dim=dim) + + ttnn_input_tensor = ttnn.from_torch(torch_input_tensor, layout=ttnn.ROW_MAJOR_LAYOUT, device=device) + ttnn_result = ttnn.concat([ttnn_input_tensor, ttnn_input_tensor], dim=dim) + ttnn_result = ttnn.to_torch(ttnn_result) + assert_with_pcc(torch_result, ttnn_result, 0.9999) diff --git a/tests/ttnn/unit_tests/operations/test_repeat_interleave.py b/tests/ttnn/unit_tests/operations/test_repeat_interleave.py index aefd70b99c2b..ab21f1462ffa 100644 --- a/tests/ttnn/unit_tests/operations/test_repeat_interleave.py +++ b/tests/ttnn/unit_tests/operations/test_repeat_interleave.py @@ -10,20 +10,29 @@ from loguru import logger from tests.ttnn.utils_for_testing import assert_with_pcc +from models.utility_functions import ( + is_grayskull, + is_wormhole_b0, +) @pytest.mark.parametrize("repeats", [1, 2, 3]) @pytest.mark.parametrize("dim", [0, 1, 2, 3]) -@pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.parametrize("dtype", [ttnn.bfloat16, ttnn.uint16]) def test_repeat_interleave(device, repeats, dim, dtype): - torch_input_tensor = torch.rand(1, 1, 32, 32, dtype=dtype) - torch_result = torch.repeat_interleave(torch_input_tensor, repeats, dim=dim) - - input_tensor = ttnn.from_torch(torch_input_tensor, layout=ttnn.TILE_LAYOUT, device=device) + if dtype == ttnn.uint16: + if is_grayskull: + pytest.skip("Grayskull does not support uint16") + torch_dtype = torch.int16 + torch_input_tensor = torch.randint(0, 100, (1, 1, 32, 32), dtype=torch_dtype) + else: + torch_dtype = torch.bfloat16 + torch_input_tensor = torch.rand(1, 1, 32, 32, dtype=torch_dtype) + torch_result = torch.repeat_interleave(torch_input_tensor, repeats, dim=dim) + input_tensor = ttnn.from_torch(torch_input_tensor, layout=ttnn.TILE_LAYOUT, dtype=dtype, device=device) output = ttnn.repeat_interleave(input_tensor, repeats, dim=dim) output = ttnn.to_torch(output) - assert_with_pcc(torch_result, output, 0.9999) diff --git a/ttnn/cpp/ttnn/operations/core/to_memory_config/to_memory_config_op.hpp b/ttnn/cpp/ttnn/operations/core/to_memory_config/to_memory_config_op.hpp index 2251e9314555..a8d0ecf4e957 100644 --- a/ttnn/cpp/ttnn/operations/core/to_memory_config/to_memory_config_op.hpp +++ b/ttnn/cpp/ttnn/operations/core/to_memory_config/to_memory_config_op.hpp @@ -24,7 +24,7 @@ struct ToMemoryConfig { // TODO: Move to cpp once we merge with tt_eager static Tensor invoke( - const ttnn::Tensor& tensor, const ttnn::MemoryConfig& memory_config, std::optional dtype) { + const ttnn::Tensor& tensor, const ttnn::MemoryConfig& memory_config, std::optional dtype = std::nullopt) { // Temporary until we see why buffer data not being populated const auto original_shape = tensor.get_shape(); diff --git a/ttnn/cpp/ttnn/operations/data_movement/concat/concat.cpp b/ttnn/cpp/ttnn/operations/data_movement/concat/concat.cpp index 758b8676b0a0..322e467360c0 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/concat/concat.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/concat/concat.cpp @@ -81,7 +81,7 @@ namespace data_movement { return output; }); // Convert dim after unsqueeze - dim = dim + 4 - rank; + dim = rank < 4 ? dim + 4 - rank : dim; auto output_tensor = concat_impl(itensor, dim, mem_config); while (output_tensor.get_shape().rank() > rank) { const auto shape = output_tensor.get_shape(); diff --git a/ttnn/cpp/ttnn/operations/data_movement/concat/device/concat_device_operation.cpp b/ttnn/cpp/ttnn/operations/data_movement/concat/device/concat_device_operation.cpp index 1e3d958c6613..731cf2b2fb67 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/concat/device/concat_device_operation.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/concat/device/concat_device_operation.cpp @@ -123,7 +123,9 @@ Tensor concat_impl(std::vector &input_tensors, const std::int64_t dim, c "Current concat implementation requires aligned last dim when concatting on last dim"); } } - Layout target_layout = Layout::TILE; + // row major should default to row major and tilized to tilized implementations, but the below loop turned RM to tilized when possible + Layout target_layout = input_tensors[0].get_layout(); + // this should be dead code when instantiating layout to match the input for (const auto &input_tensor : input_tensors) { if (input_tensor.get_layout() == Layout::ROW_MAJOR) { const auto &input_shape = input_tensor.get_legacy_shape(); diff --git a/ttnn/cpp/ttnn/operations/data_movement/repeat_interleave/repeat_interleave.cpp b/ttnn/cpp/ttnn/operations/data_movement/repeat_interleave/repeat_interleave.cpp index 7a7bb6f80425..515e8d64df53 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/repeat_interleave/repeat_interleave.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/repeat_interleave/repeat_interleave.cpp @@ -6,51 +6,55 @@ #include "repeat_interleave.hpp" #include "ttnn/cpp/ttnn/operations/data_movement/reshape_on_device/reshape.hpp" +#include "ttnn/cpp/ttnn/operations/data_movement/unsqueeze/unsqueeze.hpp" +#include "ttnn/cpp/ttnn/operations/data_movement/transpose/transpose.hpp" +#include "ttnn/cpp/ttnn/operations/copy.hpp" namespace ttnn { namespace operations { namespace data_movement { - // repeat interleave supports repeats as 1 to inf, dim between 0 to 2 ttnn::Tensor ExecuteRepeatInterleave::invoke(const ttnn::Tensor& input_a, uint32_t repeat, int32_t dim, std::optional output_mem_config) { std::vector combined_tensors; combined_tensors.reserve(repeat); - auto shape_wh = input_a.get_legacy_shape(); MemoryConfig mem_config = output_mem_config.value_or(input_a.memory_config()); - // normalizing the negative dim - uint32_t normalized_dim = input_a.get_legacy_shape().get_normalized_index(dim); - // check if dim is 1 or 3 - if (normalized_dim & 1) { - constexpr uint32_t tmp_dim = 2; - std::vector dims = {0, 1, 2, 3}; - std::swap(dims[dim], dims[tmp_dim]); - Tensor transpose_input = ttnn::permute(input_a, dims); - Tensor ril_result = ExecuteRepeatInterleave::invoke(transpose_input, repeat, tmp_dim, mem_config); - return ttnn::permute(ril_result, dims); + if (repeat == 1) { + return ttnn::to_memory_config(input_a, mem_config); + } + uint32_t input_rank = input_a.get_shape().rank(); + uint32_t normalized_dim = input_a.get_shape().get_normalized_index(dim); + if (normalized_dim == input_rank - 1) { + auto transposed_input = ttnn::transpose(input_a, -1, -2, mem_config); + auto repeated_input = ExecuteRepeatInterleave::invoke(transposed_input, repeat, -2, mem_config); + return ttnn::transpose(repeated_input, -1, -2, mem_config); + } + + ttnn::Tensor rm_input = input_a; + bool typecast = input_a.get_dtype() != DataType::BFLOAT16; + if (typecast) { + rm_input = ttnn::typecast(rm_input, DataType::BFLOAT16, mem_config); + } + + rm_input = ttnn::to_layout(rm_input, Layout::ROW_MAJOR, std::nullopt, std::nullopt, (Device*)nullptr); + std::vector final_shape; + final_shape.reserve(input_rank); + for (uint32_t i = 0; i < rm_input.get_shape().rank(); i++) { + final_shape.push_back(rm_input.get_shape()[i]); } - if (normalized_dim <= 1) { - for (int i = 0; i < repeat; i++) { - combined_tensors.push_back(input_a); - } - // TODO: For dim = 1 facing issue with concat_op - if (normalized_dim) { - Tensor concat_out = ttnn::concat(combined_tensors, 2); - return ttnn::reshape_on_device(concat_out, shape_wh[0], shape_wh[1] * repeat, shape_wh[2], shape_wh[3]); - } else { - Tensor concat_out = ttnn::concat(combined_tensors, 1); - return ttnn::reshape_on_device(concat_out, shape_wh[0] * repeat, shape_wh[1], shape_wh[2], shape_wh[3]); - } - } else { - Tensor reshape_out = ttnn::reshape_on_device(input_a, 1, 1, shape_wh[0] * shape_wh[1] * shape_wh[2], shape_wh[3]); - for (int i = 0; i < repeat; i++) { - combined_tensors.push_back(reshape_out); - } - Tensor concat_out = ttnn::concat(combined_tensors, 1); - std::vector permute_dims = {0, 2, 1, 3}; - Tensor permute_out = ttnn::permute(concat_out, permute_dims); - return ttnn::reshape_on_device(permute_out, shape_wh[0], shape_wh[1], shape_wh[2] * repeat, shape_wh[3]); + + final_shape[normalized_dim] *= repeat; + + auto unsqueezed_tensor = ttnn::unsqueeze(rm_input, normalized_dim + 1); + for (uint32_t i = 0; i < repeat; i++) { + combined_tensors.push_back(unsqueezed_tensor); } + + auto concatenated_tensor = ttnn::concat(combined_tensors, normalized_dim + 1); + auto reshaped_tensor = ttnn::reshape(concatenated_tensor, ttnn::Shape(final_shape)); + auto original_layout = ttnn::to_layout(reshaped_tensor, input_a.get_layout(), std::nullopt, std::nullopt, (Device*)nullptr); + return typecast ? ttnn::typecast(original_layout, input_a.get_dtype(), mem_config) : original_layout; + } } // namespace data_movement diff --git a/ttnn/cpp/ttnn/operations/data_movement/transpose/transpose.cpp b/ttnn/cpp/ttnn/operations/data_movement/transpose/transpose.cpp index ef8936a78c40..50e3576e8e58 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/transpose/transpose.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/transpose/transpose.cpp @@ -35,6 +35,7 @@ inline Tensor transpose_(const Tensor &a, TransposeOpDim transpose_dim, const Me case TransposeOpDim::HC: pad_c = a.get_layout() == Layout::TILE && a.get_shape().with_tile_padding()[1] % 32 != 0; break; + // bubble dim around to make it possible as these implementations don't have a kernel case TransposeOpDim::NH: return ttnn::permute((const ttnn::Tensor)a, std::vector({2, 1, 0, 3}), output_mem_config); case TransposeOpDim::NW: @@ -42,7 +43,15 @@ inline Tensor transpose_(const Tensor &a, TransposeOpDim transpose_dim, const Me case TransposeOpDim::CW: return ttnn::permute((const ttnn::Tensor)a, std::vector({0, 3, 2, 1}), output_mem_config); case TransposeOpDim::CN: - tiled_only = true; + tiled_only = true; // CN only has a tiled implementation at the moment + break; + case TransposeOpDim::WH: // THIS NEEDS TO BE FIXED + if (a.device()->arch() == tt::ARCH::GRAYSKULL) { + tiled_only = a.shape()[-2] > 256; // horrible hack because PCC on transpose HW row major is terrible on GS in this code path - kernel spits out garbage and has some demuxing for greater than this size that doesn't work + } + else if (a.device()->arch() == tt::ARCH::WORMHOLE_B0) { + tiled_only = !a.is_sharded() && (a.shape()[-2]*a.shape()[-1] >= 400000); // CB blows up on large sizes, hack until transpose_wh_rm is optimized + } default: break; } @@ -98,17 +107,18 @@ ttnn::Tensor ExecuteTranspose::invoke( uint32_t normalized_dim1 = input_tensor.get_legacy_shape().get_normalized_index(dim1); uint32_t normalized_dim2 = input_tensor.get_legacy_shape().get_normalized_index(dim2); - uint32_t rank_diff = 4 - input_tensor.get_shape().rank(); - Tensor input_4d = input_tensor; - if (rank_diff > 0) { - input_4d = ttnn::unsqueeze_to_4D(input_tensor); + Tensor input_unsqueezed = input_tensor; + uint32_t initial_rank = input_tensor.get_shape().rank(); + if (initial_rank < 4) { + input_unsqueezed = ttnn::unsqueeze_to_4D(input_tensor); + uint32_t rank_diff = 4 - initial_rank; normalized_dim1 += rank_diff; normalized_dim2 += rank_diff; } bool wh = (normalized_dim2 == 2 && normalized_dim1 == 0) || (normalized_dim2 == 0 && normalized_dim1 == 2); - bool typecast = input_4d.get_dtype() == DataType::BFLOAT8_B and input_4d.get_layout() == Layout::TILE and !wh and !input_4d.is_sharded(); - Tensor input_typecasted = typecast ? ttnn::typecast(input_4d, DataType::BFLOAT16) : input_4d; + bool typecast = input_unsqueezed.get_dtype() == DataType::BFLOAT8_B and input_unsqueezed.get_layout() == Layout::TILE and !wh and !input_unsqueezed.is_sharded(); + Tensor input_typecasted = typecast ? ttnn::typecast(input_unsqueezed, DataType::BFLOAT16) : input_unsqueezed; auto input_shape = input_typecasted.get_shape(); @@ -135,8 +145,8 @@ ttnn::Tensor ExecuteTranspose::invoke( auto& a = input_tensors.at(0); auto memory_config = memory_config_arg.value_or(a.memory_config()); - TT_FATAL(normalized_dim1 <= 3, "dimension have to be 0-3 only corresponding to N,C,H,W"); - TT_FATAL(normalized_dim2 <= 3, "dimension have to be 0-3 only corresponding to N,C,H,W"); + TT_FATAL(normalized_dim1 <= 3, "dimension has to be 0-3 only corresponding to N,C,H,W"); + TT_FATAL(normalized_dim2 <= 3, "dimension has to be 0-3 only corresponding to N,C,H,W"); if ( (normalized_dim1 == normalized_dim2) || @@ -170,7 +180,7 @@ ttnn::Tensor ExecuteTranspose::invoke( }, {input_typecasted}, output_tensors); auto output = ttnn::reshape(output_tensors.at(0), ttnn::Shape(output_shape, padded_output_shape)); - output = ttnn::squeeze_from_4D(output, input_tensor.get_shape().rank()); + output = initial_rank < 4 ? ttnn::squeeze_from_4D(output, initial_rank) : output; return typecast ? ttnn::typecast(output, DataType::BFLOAT8_B) : output; } diff --git a/ttnn/cpp/ttnn/operations/reduction/prod/prod.cpp b/ttnn/cpp/ttnn/operations/reduction/prod/prod.cpp index da09bc5ff2ef..e4f56968aa9e 100644 --- a/ttnn/cpp/ttnn/operations/reduction/prod/prod.cpp +++ b/ttnn/cpp/ttnn/operations/reduction/prod/prod.cpp @@ -121,11 +121,6 @@ Tensor ProdOperation::invoke(const Tensor& input_a, bool all_dimensions, int64_t // permute back after_permute_dims = {0, 1, 3, 2}; Tensor res_host = ttnn::permute(new_unpad_tensor, after_permute_dims, output_mem_config); - if(res_host.storage_type() != StorageType::DEVICE or res_host.storage_type() != StorageType::MULTI_DEVICE) { - res_host = res_host.pad_to_tile(0.0f); - res_host = res_host.to(Layout::TILE); - res_host = res_host.to(input_a.device()); - } return res_host; } }