From 01c2dceca249160f107d483365c15a5a0a95409c Mon Sep 17 00:00:00 2001
From: Saad Jameel
Date: Tue, 24 Dec 2024 20:50:26 +0000
Subject: [PATCH] #0: update squeezebert perf and use combined kernel for fold

---
 .../squeezebert/tests/test_perf_device_squeezebert.py |  2 +-
 ttnn/cpp/ttnn/operations/data_movement/fold/fold.cpp  | 10 +++-------
 2 files changed, 4 insertions(+), 8 deletions(-)

diff --git a/models/demos/squeezebert/tests/test_perf_device_squeezebert.py b/models/demos/squeezebert/tests/test_perf_device_squeezebert.py
index 7f8acbca401..381bc9f0c9d 100644
--- a/models/demos/squeezebert/tests/test_perf_device_squeezebert.py
+++ b/models/demos/squeezebert/tests/test_perf_device_squeezebert.py
@@ -18,7 +18,7 @@ def test_perf_device_bare_metal(batch_size, test):
     subdir = "ttnn_squeezebert"
     num_iterations = 1
     margin = 0.03
-    expected_perf = 114.8 if is_grayskull() else 284.5
+    expected_perf = 102.7 if is_grayskull() else 298.7
 
     command = f"pytest tests/ttnn/integration_tests/squeezebert/test_ttnn_squeezebert.py::test_squeezebert_for_question_answering"
     cols = ["DEVICE FW", "DEVICE KERNEL", "DEVICE BRISC KERNEL"]
diff --git a/ttnn/cpp/ttnn/operations/data_movement/fold/fold.cpp b/ttnn/cpp/ttnn/operations/data_movement/fold/fold.cpp
index 754ea6e24b1..21cdc7696db 100644
--- a/ttnn/cpp/ttnn/operations/data_movement/fold/fold.cpp
+++ b/ttnn/cpp/ttnn/operations/data_movement/fold/fold.cpp
@@ -8,6 +8,7 @@
 
 #include "ttnn/operations/math.hpp"
 #include "ttnn/operations/data_movement/transpose/transpose.hpp"
+#include "ttnn/operations/data_movement/permute/device/permute_device_operation.hpp"
 #include "ttnn/cpp/ttnn/operations/data_movement/slice/slice.hpp"
 #include "ttnn/cpp/ttnn/operations/data_movement/reshape_on_device/reshape.hpp"
 #include "ttnn/cpp/ttnn/operations/data_movement/pad/pad.hpp"
@@ -61,13 +62,8 @@ std::vector<Tensor> fold_with_transpose_(
 
     tt::log_debug("pad_output: {}", pad_output.shape());
 
-    // transpose
-    auto transpose_hw_output = ttnn::transpose(pad_output, 2, 3, L1_mem_config);
-
-    tt::log_debug("transpose_hw_output: {}", transpose_hw_output.shape());
-
-    // transpose
-    auto transpose_hc_output = ttnn::transpose(transpose_hw_output, 1, 2, L1_mem_config);
+    auto transpose_hc_output = ttnn::prim::permute(
+        pad_output, ttnn::SmallVector<int64_t>({0, 3, 1, 2}), std::make_optional(L1_mem_config), std::nullopt);
 
     tt::log_debug("transpose_hc_output: {}", transpose_hc_output.shape());