diff --git a/tests/tt_eager/python_api_testing/unit_testing/misc/test_transpose.py b/tests/tt_eager/python_api_testing/unit_testing/misc/test_transpose.py index a8f8385c059b..7fac4556c4e2 100644 --- a/tests/tt_eager/python_api_testing/unit_testing/misc/test_transpose.py +++ b/tests/tt_eager/python_api_testing/unit_testing/misc/test_transpose.py @@ -655,18 +655,16 @@ def test_transpose_hc(dtype, shape, device): ) @pytest.mark.parametrize( "shape", - [(9216, 128), (1, 32), (1, 12), (1, 35), (16, 32), (34, 8)], + [(9216, 128), (1, 32), (1, 12), (1, 35), (16, 32), (34, 8), [21843, 768]], ) @pytest.mark.parametrize( "layout", - [ttnn.TILE_LAYOUT], + [ttnn.TILE_LAYOUT, ttnn.ROW_MAJOR_LAYOUT], ) def test_transpose_2D(dtype, shape, layout, device): torch.manual_seed(2005) if is_grayskull() and dtype == ttnn.float32: pytest.skip("Skipping float32 tests on Grayskull") - if layout == ttnn.ROW_MAJOR_LAYOUT and dtype == ttnn.bfloat16 and (shape[-1] % 2 or shape[-2] % 2): - pytest.skip("Skipping RM odd inner dim test cases") torch_input = torch.randn(shape, dtype=torch.bfloat16) torch_output = torch_input.transpose(0, 1) diff --git a/ttnn/cpp/ttnn/operations/data_movement/transpose/transpose.cpp b/ttnn/cpp/ttnn/operations/data_movement/transpose/transpose.cpp index c7a208489990..04c4ff0d20fa 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/transpose/transpose.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/transpose/transpose.cpp @@ -7,6 +7,7 @@ #include "ttnn/decorators.hpp" #include "device/transpose_op.hpp" #include "ttnn/operations/data_movement/permute/permute.hpp" +#include "ttnn/operations/data_movement/permute/device/permute_device_operation.hpp" #include "ttnn/operations/data_movement/transpose/transpose.hpp" #include "ttnn/cpp/ttnn/operations/copy.hpp" #include "ttnn/cpp/ttnn/operations/data_movement/pad/pad.hpp" @@ -19,38 +20,6 @@ namespace ttnn::operations::data_movement { namespace detail { -inline uint32_t get_estimated_size_of_cbs(const Tensor& input_tensor_a) 
{ - // Circular Buffer sizes: - uint32_t element_size = input_tensor_a.element_size(); - uint32_t Wt = input_tensor_a.get_padded_shape()[-1] / tt::constants::TILE_WIDTH; - uint32_t Ht = input_tensor_a.get_padded_shape()[-2] / tt::constants::TILE_HEIGHT; - uint32_t HtWt = Ht * Wt; - auto data_format = tt::tt_metal::datatype_to_dataformat_converter(input_tensor_a.get_dtype()); - uint32_t tile_size = tt::tt_metal::detail::TileSize(data_format); - - uint32_t cb_src0_size = 2 * Wt * tile_size; - uint32_t cb_output_size = 2 * Ht * tile_size; - uint32_t cb_im_size = Ht * Wt * tile_size; - uint32_t cb_im2_size = Ht * tile_size; - return cb_src0_size + cb_output_size + cb_im_size + cb_im2_size; -} - -inline uint32_t get_max_l1_space(const Tensor& input_tensor_a) { - tt::tt_metal::Device* device = input_tensor_a.device(); - const std::vector<uint32_t>& bank_ids = - device->bank_ids_from_logical_core(BufferType::L1, *device->compute_cores_.begin()); - std::optional<DeviceAddr> lowest_address = allocator::lowest_occupied_l1_address(*device->allocator_, bank_ids[0]); - uint32_t max_l1_space = lowest_address.has_value() ? 
lowest_address.value() : device->l1_size_per_core(); - max_l1_space = max_l1_space - device->get_base_allocator_addr(HalMemType::L1); - return max_l1_space; -} - -inline bool rm_enough_available_space(const Tensor& input_tensor_a) { - uint32_t max_l1_space = get_max_l1_space(input_tensor_a); - uint32_t estimated_size_of_cbs = get_estimated_size_of_cbs(input_tensor_a); - return max_l1_space > estimated_size_of_cbs; -} - inline Tensor transpose_( const Tensor& a, TransposeOpDim transpose_dim, @@ -86,16 +55,14 @@ inline Tensor transpose_( tiled_only = true; // CN only has a tiled implementation at the moment break; case TransposeOpDim::WH: // THIS NEEDS TO BE FIXED - if (((W * a.element_size()) % FACE_WIDTH != 0) || ((H * a.element_size()) % FACE_WIDTH != 0)) { - tiled_only = true; - } else if (a.device()->arch() == tt::ARCH::GRAYSKULL) { + if (a.device()->arch() == tt::ARCH::GRAYSKULL) { tiled_only = a.shape()[-2] > 256; // hangs right now past this dimension, #13660 will turn it from a // hang into a PCC issue for GS and improve perf for WH - } else if ( - !a.is_sharded() && a.layout() == Layout::ROW_MAJOR && - !rm_enough_available_space( - a)) { // rm is L1 intensive, if it overflows we can do tiled which allocates much smaller CBs - tiled_only = true; + } else if (!a.is_sharded() && a.layout() == Layout::ROW_MAJOR) { // rm is L1 intensive, if it overflows we + // can do tiled which allocates much + // smaller CBs + return ttnn::prim::permute( + a, ttnn::SmallVector<uint32_t>({0, 1, 3, 2}), output_mem_config, std::nullopt); } break; default: break;