From 0f737626d96caff0086966e52b824c934088d21e Mon Sep 17 00:00:00 2001 From: Anil Mahmud Date: Wed, 4 Dec 2024 12:56:06 +0000 Subject: [PATCH] #14063: Fix test by replace wrong default arguments in kernel code --- .../unit_testing/misc/test_rotary_embedding.py | 2 -- .../device/kernels/compute/rotary_embedding.cpp | 4 ++-- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/tests/tt_eager/python_api_testing/unit_testing/misc/test_rotary_embedding.py b/tests/tt_eager/python_api_testing/unit_testing/misc/test_rotary_embedding.py index c69aee8286b..e8f69c9fffb 100644 --- a/tests/tt_eager/python_api_testing/unit_testing/misc/test_rotary_embedding.py +++ b/tests/tt_eager/python_api_testing/unit_testing/misc/test_rotary_embedding.py @@ -100,7 +100,6 @@ def test_rotary_embedding_prefill(W, Z, Y, X, cache_size, in_sharded, out_sharde assert p -@skip_for_blackhole("Mismatching on Blackhole, see #12349") @pytest.mark.parametrize( "W, Z, Y, X", ([1, 1, 32, 64], [1, 71, 32, 64], [1, 1, 64, 64], [1, 71, 64, 64], [1, 32, 32, 64], [1, 2, 32, 64]), @@ -246,7 +245,6 @@ def test_rotary_embedding_prefill_fp32( assert p -@skip_for_blackhole("Mismatching on Blackhole, see #12349") @pytest.mark.skipif(is_grayskull(), reason="GS does not support fp32") @pytest.mark.parametrize("W, Z, Y, X", [(1, 1, 32, 64)]) @pytest.mark.parametrize("cache_size", [2048]) diff --git a/ttnn/cpp/ttnn/operations/experimental/transformer/rotary_embedding/device/kernels/compute/rotary_embedding.cpp b/ttnn/cpp/ttnn/operations/experimental/transformer/rotary_embedding/device/kernels/compute/rotary_embedding.cpp index 0c309f10dce..76e7730289d 100644 --- a/ttnn/cpp/ttnn/operations/experimental/transformer/rotary_embedding/device/kernels/compute/rotary_embedding.cpp +++ b/ttnn/cpp/ttnn/operations/experimental/transformer/rotary_embedding/device/kernels/compute/rotary_embedding.cpp @@ -51,7 +51,7 @@ ALWI void UNTILIZE_TILES(uint32_t in0_cb, uint32_t out_cb, uint32_t num_tiles) { } ALWI void TILIZE_ROWS(uint32_t in0_cb, uint32_t sync_cb, uint32_t out_cb, uint32_t num_tiles) { - tilize_init_short(in0_cb, num_tiles); + tilize_init_short(in0_cb, num_tiles, out_cb); cb_wait_front(in0_cb, num_tiles); cb_wait_front(sync_cb, num_tiles); cb_reserve_back(out_cb, num_tiles); @@ -61,7 +61,7 @@ ALWI void TILIZE_ROWS(uint32_t in0_cb, uint32_t sync_cb, uint32_t out_cb, uint32 // Pop shared cbs after tilize cb_pop_front(in0_cb, num_tiles); cb_pop_front(sync_cb, num_tiles); - tilize_uninit(in0_cb); + tilize_uninit(in0_cb, out_cb); } namespace NAMESPACE {