diff --git a/tests/tt_eager/python_api_testing/sweep_tests/pytests/tt_dnn/test_transpose.py b/tests/tt_eager/python_api_testing/sweep_tests/pytests/tt_dnn/test_transpose.py
index 95818476b8c..4f679c70880 100644
--- a/tests/tt_eager/python_api_testing/sweep_tests/pytests/tt_dnn/test_transpose.py
+++ b/tests/tt_eager/python_api_testing/sweep_tests/pytests/tt_dnn/test_transpose.py
@@ -14,6 +14,7 @@
 shape_wh = [
     [[1, 1, 32, 32]],  # Single core
     [[3, 1, 320, 384]],  # Multi core
+    [[1, 1024, 5, 1280]],  # Non page-aligned
 ]
 
 
diff --git a/tests/tt_eager/python_api_testing/unit_testing/misc/test_transpose.py b/tests/tt_eager/python_api_testing/unit_testing/misc/test_transpose.py
index f980fd1d4e4..489b25ba5e9 100644
--- a/tests/tt_eager/python_api_testing/unit_testing/misc/test_transpose.py
+++ b/tests/tt_eager/python_api_testing/unit_testing/misc/test_transpose.py
@@ -9,7 +9,7 @@
 import ttnn
 
 from loguru import logger
-from models.utility_functions import is_grayskull, is_blackhole
+from models.utility_functions import is_grayskull, is_blackhole, torch_random
 from tests.tt_eager.python_api_testing.sweep_tests.comparison_funcs import comp_pcc, comp_equal
 from models.utility_functions import skip_for_grayskull, skip_for_blackhole
 from tests.ttnn.utils_for_testing import assert_with_pcc
@@ -644,7 +644,7 @@ def test_transpose_bfloat8_b(device, shape, swap_dims):
 )
 @pytest.mark.parametrize(
     "shape",
-    [(1, 32, 12, 100), (1, 12, 32, 100), (1, 35, 7, 7), (1, 1, 1, 1)],
+    [(1, 32, 12, 100), (1, 12, 32, 100), (1, 35, 7, 7), (1, 1, 1, 1), (1, 12, 32, 100)],
 )
 def test_transpose_hc(dtype, shape, device):
     if is_grayskull() and dtype == ttnn.float32:
@@ -691,7 +691,7 @@ def test_transpose_2D(dtype, shape, layout, device):
 )
 @pytest.mark.parametrize(
     "shape",
-    [[32, 1, 32], [32, 1, 12], [1, 1, 35], [1, 16, 32], [2, 34, 8]],
+    [[32, 1, 32], [32, 1, 12], [1, 1, 35], [1, 16, 32], [2, 34, 8], (32, 12, 100), (6, 33, 34)],
 )
 @pytest.mark.parametrize(
     "layout",
@@ -699,7 +699,14 @@ def test_transpose_2D(dtype, shape, layout, device):
 )
 @pytest.mark.parametrize(
     "dims",
-    [[0, 1], [0, 2], [2, 1], [-3, -2], [-3, -1], [-2, -1]],
+    [
+        [0, 1],
+        [0, 2],
+        [2, 1],
+        [-3, -2],
+        [-3, -1],
+        [-2, -1],
+    ],
 )
 def test_transpose_3D(dtype, shape, layout, dims, device):
     torch.manual_seed(2005)
@@ -750,14 +757,14 @@ def test_transpose_4d_wh_tile(shape, device):
 @pytest.mark.parametrize(
     "config",
     [
-        [[64, 4, 49, 32], [-2, -1], ttnn.ROW_MAJOR_LAYOUT],  # Page size must be divisible by sizeof(uint32_t)
         [[1, 1370, 1, 3, 1280], [0, -2], ttnn.TILE_LAYOUT],  # untilize doesn't work with 4D
-        [[12, 3], [0, 1], ttnn.ROW_MAJOR_LAYOUT],  # need tensor for this one
+        [[1, 50, 1, 3, 768], [0, -2], ttnn.TILE_LAYOUT],  # untilize doesn't work with 4D
+        [[21843, 768], [0, 1], ttnn.ROW_MAJOR_LAYOUT],  # circular buffer overflow
     ],
 )
 @pytest.mark.parametrize("memory_config", [ttnn.L1_MEMORY_CONFIG, ttnn.DRAM_MEMORY_CONFIG])
 def test_transpose_failures(config, memory_config, device):
-    pytest.skip("Failures to fix after #13217 and #13005 are in - 5D, HC PCC issue and unaligned RM tensor")
+    pytest.skip("Failing pytorch 2.0 trace sweeps")
     torch.manual_seed(2005)
     torch_input = torch.randn(config[0], dtype=torch.bfloat16)
     torch_output = torch_input.transpose(config[1][0], config[1][1])
@@ -793,6 +800,8 @@ def test_transpose_failures(config, memory_config, device):
         [[1, 9, 8, 14], [1, 2], ttnn.ROW_MAJOR_LAYOUT],  # unaligned RM that fallsback to tiled
         [[1, 9, 8, 2], [1, 2], ttnn.ROW_MAJOR_LAYOUT],  # unaligned RM that fallsback to tiled
         [[1, 2, 8, 2], [1, 2], ttnn.ROW_MAJOR_LAYOUT],  # unaligned RM that fallsback to tiled
+        [[64, 4, 49, 32], [-2, -1], ttnn.ROW_MAJOR_LAYOUT],  # Page size must be divisible by sizeof(uint32_t)
+        [[12, 3], [0, 1], ttnn.ROW_MAJOR_LAYOUT],  # need tensor for this one
         [
             [1, 8, 4096, 40],
             [1, 2],
@@ -943,3 +952,62 @@ def test_transpose_unpadded(shape, dims, layout, dtype, pad_value, device):
         assert ttnn.to_torch(a) == float("-inf")
     tt_output = ttnn.to_torch(tt_output)
     assert_with_pcc(torch_output, tt_output, 0.9999)
+
+
+@pytest.mark.parametrize("b", [1])
+@pytest.mark.parametrize("h", [18])
+@pytest.mark.parametrize("w", [65])
+@pytest.mark.parametrize("dim0", [1])
+@pytest.mark.parametrize("dim1", [2])
+def test_transpose_forge_llama(device, b, h, w, dim0, dim1):
+    torch.manual_seed(2005)
+
+    torch_input_tensor = torch_random((b, h, w), -0.1, 0.1, dtype=torch.bfloat16)
+    torch_output_tensor = torch_input_tensor.transpose(dim0, dim1)
+
+    input_tensor = ttnn.to_device(ttnn.from_torch(torch_input_tensor), device, memory_config=ttnn.DRAM_MEMORY_CONFIG)
+    input_tensor = ttnn.to_layout(input_tensor, layout=ttnn.TILE_LAYOUT)
+    output_tensor = ttnn.transpose(input_tensor, dim0, dim1, memory_config=ttnn.DRAM_MEMORY_CONFIG)
+    output_tensor = ttnn.from_device(output_tensor)
+    output_tensor = ttnn.to_layout(output_tensor, layout=ttnn.ROW_MAJOR_LAYOUT)
+    output_tensor = ttnn.to_torch(output_tensor)
+
+    assert_with_pcc(torch_output_tensor, output_tensor)
+
+
+@pytest.mark.parametrize("b", [1])
+@pytest.mark.parametrize("h", [2])
+@pytest.mark.parametrize("w", [3])
+@pytest.mark.parametrize("dim0", [-1])
+@pytest.mark.parametrize("dim1", [-2])
+def test_transpose_forge_basic(device, b, h, w, dim0, dim1):
+    torch.manual_seed(2005)
+    torch_input_tensor = torch_random((1, b, h, w), -0.1, 0.1, dtype=torch.bfloat16)
+    torch_output_tensor = torch_input_tensor.transpose(dim0, dim1)
+    input_tensor = ttnn.to_device(ttnn.from_torch(torch_input_tensor), device, memory_config=ttnn.DRAM_MEMORY_CONFIG)
+    input_tensor = ttnn.to_layout(input_tensor, layout=ttnn.TILE_LAYOUT)
+    output_tensor = ttnn.transpose(input_tensor, dim0, dim1, memory_config=ttnn.DRAM_MEMORY_CONFIG)
+    output_tensor = ttnn.from_device(output_tensor)
+    output_tensor = ttnn.to_layout(output_tensor, layout=ttnn.ROW_MAJOR_LAYOUT)
+    output_tensor = ttnn.to_torch(output_tensor)
+
+    assert_with_pcc(torch_output_tensor, output_tensor)
+
+
+@pytest.mark.parametrize("b", [6])
+@pytest.mark.parametrize("h", [33])
+@pytest.mark.parametrize("w", [34])
+@pytest.mark.parametrize("dim0", [1])
+@pytest.mark.parametrize("dim1", [0])
+def test_transpose_forge_hc(device, b, h, w, dim0, dim1):
+    torch.manual_seed(2005)
+    torch_input_tensor = torch_random((1, b, h, w), -0.1, 0.1, dtype=torch.bfloat16)
+    torch_output_tensor = torch_input_tensor.transpose(dim0, dim1)
+    input_tensor = ttnn.to_device(ttnn.from_torch(torch_input_tensor), device, memory_config=ttnn.DRAM_MEMORY_CONFIG)
+    input_tensor = ttnn.to_layout(input_tensor, layout=ttnn.TILE_LAYOUT)
+    output_tensor = ttnn.transpose(input_tensor, dim0, dim1, memory_config=ttnn.DRAM_MEMORY_CONFIG)
+    output_tensor = ttnn.from_device(output_tensor)
+    output_tensor = ttnn.to_layout(output_tensor, layout=ttnn.ROW_MAJOR_LAYOUT)
+    output_tensor = ttnn.to_torch(output_tensor)
+
+    assert_with_pcc(torch_output_tensor, output_tensor)
diff --git a/tests/ttnn/unit_tests/operations/test_slice.py b/tests/ttnn/unit_tests/operations/test_slice.py
index a85b33ada9a..9facb46a90c 100644
--- a/tests/ttnn/unit_tests/operations/test_slice.py
+++ b/tests/ttnn/unit_tests/operations/test_slice.py
@@ -752,8 +752,37 @@ def test_slice_adversarial_fixed(input_shape, dim, start, end, step, layout, dev
 @pytest.mark.parametrize(
     "input_shape, dim, start, end, step, layout",
     (
-        ([8732, 4], 1, 0, -1, 4, ttnn.TILE_LAYOUT),  # Need tensor for this or a padding aware tiled kernel
         ([1, 7], 0, 0, -1, 1, ttnn.ROW_MAJOR_LAYOUT),  # page size must equal buffer size
+        ([1, 8, 2, 2], 2, -1, -1, 1, ttnn.TILE_LAYOUT),  # Buffer size and page size should be larger than 0 bytes
+        ([3], 0, 0, -1, 1, ttnn.TILE_LAYOUT),  # Difference in expected shape as it's a 1D tensor
+    ),
+)
+def test_slice_adversarial(input_shape, dim, start, end, step, layout, device):
+    pytest.skip("These tests are known to fail")
+    torch_input = torch.randn(input_shape, dtype=torch.bfloat16)
+
+    slice_obj = slice(start, end, step)
+
+    # Prepare indices for slicing in the specified dimension
+    indices = [slice(None)] * len(input_shape)  # By default, select all elements along every dimension
+    indices[dim] = slice_obj  # Apply slicing to the target dimension
+    indices = tuple(indices)
+
+    # Apply slicing to the input_tensor
+    torch_output_tensor = torch_input[indices]
+
+    ttnn_tensor = ttnn.from_torch(torch_input, device=device, layout=layout, dtype=ttnn.bfloat16)
+    ttnn_output = ttnn_tensor[indices]
+
+    ttnn_output_tensor = ttnn.to_torch(ttnn_output)
+
+    assert_with_pcc(torch_output_tensor, ttnn_output_tensor, 0.999)
+
+
+@pytest.mark.parametrize(
+    "input_shape, dim, start, end, step, layout",
+    (
+        ([8732, 4], 1, 0, -1, 4, ttnn.TILE_LAYOUT),  # Need tensor for this or a padding aware tiled kernel
         (
             [1, 7, 71, 64],
             3,
@@ -762,12 +791,9 @@ def test_slice_adversarial_fixed(input_shape, dim, start, end, step, layout, dev
             1,
             ttnn.ROW_MAJOR_LAYOUT,
         ),  # An unpadding slice operations for a RowMajor layout on the output tensor requires the last dimension to be on a 32 bit boundary
-        ([1, 8, 2, 2], 2, -1, -1, 1, ttnn.TILE_LAYOUT),  # Buffer size and page size should be larger than 0 bytes
-        ([3], 0, 0, -1, 1, ttnn.TILE_LAYOUT),  # Difference in expected shape as it's a 1D tensor
     ),
 )
-def test_slice_adversarial(input_shape, dim, start, end, step, layout, device):
-    pytest.skip("These tests are expected to fail at the moment")
+def test_slice_adversarial_fixed(input_shape, dim, start, end, step, layout, device):
     torch_input = torch.randn(input_shape, dtype=torch.bfloat16)
 
     slice_obj = slice(start, end, step)
diff --git a/tests/ttnn/unit_tests/test_to_layout.py b/tests/ttnn/unit_tests/test_to_layout.py
index fafab9674a1..b84a8f4c5fc 100644
--- a/tests/ttnn/unit_tests/test_to_layout.py
+++ b/tests/ttnn/unit_tests/test_to_layout.py
@@ -10,6 +10,7 @@
 import ttnn
 
 from tests.ttnn.utils_for_testing import assert_with_pcc, check_with_pcc_without_tensor_printout
+from models.utility_functions import is_grayskull, is_blackhole, torch_random, skip_for_grayskull
 
 
 @pytest.mark.parametrize("height", [32, 30])
@@ -125,3 +126,17 @@ def test_untilize_with_unpadding_W_16(device, in_dtype, use_multicore, use_pack_
     passing, pcc_msg = check_with_pcc_without_tensor_printout(torch_input, output_torch)
     logger.info(pcc_msg)
     assert passing
+
+
+@pytest.mark.parametrize("h", [1, 18, 65])
+@pytest.mark.parametrize("w", [1, 15, 17, 29, 33, 49, 63, 65])
+@pytest.mark.parametrize("input_layout", [ttnn.ROW_MAJOR_LAYOUT, ttnn.TILE_LAYOUT])
+@pytest.mark.parametrize("output_layout", [ttnn.ROW_MAJOR_LAYOUT, ttnn.TILE_LAYOUT])
+def test_to_layout_device(device, h, w, input_layout, output_layout):
+    torch.manual_seed(2005)
+    torch_input_tensor = torch_random((h, w), -0.1, 0.1, dtype=torch.bfloat16)
+    input_tensor = ttnn.from_torch(torch_input_tensor, device=device, dtype=ttnn.bfloat16, layout=input_layout)
+    new_layout_tensor = ttnn.to_layout(input_tensor, layout=output_layout)
+    torch_brought_back = ttnn.to_torch(new_layout_tensor)
+
+    assert_with_pcc(torch_input_tensor, torch_brought_back)
diff --git a/ttnn/cpp/ttnn/operations/data_movement/slice/device/slice_op.cpp b/ttnn/cpp/ttnn/operations/data_movement/slice/device/slice_op.cpp
index 6398986a22d..7cde7097244 100644
--- a/ttnn/cpp/ttnn/operations/data_movement/slice/device/slice_op.cpp
+++ b/ttnn/cpp/ttnn/operations/data_movement/slice/device/slice_op.cpp
@@ -102,18 +102,11 @@ void SliceDeviceOperation::validate_with_output_tensors(
             (output_tensor_shape[-1] % TILE_WIDTH == 0) && (this->slice_start[-1] % TILE_WIDTH == 0),
             "Can only unpad tilized tensor with full tiles");
     } else if (input_tensor_a.get_layout() == Layout::ROW_MAJOR) {
-        TT_FATAL(
-            (output_tensor_shape[-1] * input_tensor_a.element_size() % sizeof(uint32_t) == 0),
-            "An unpadding slice operations for a RowMajor layout on the output tensor requires the last dimension to be on a 32 bit boundary. For example, the final dimension needs to be divisible by 2 for bfloat16. The resulting tensor shape is {}, which is not 4B aligned as the last dimension is {}",
-            output_tensor_shape[-1], input_tensor_a.element_size());
         if (has_step) {
             for (uint32_t i = 0; i < input_tensor_a.get_legacy_shape().rank(); i++) {
                 TT_FATAL(step[i] > 0, "Step({}) = {} should be positive", i, step[i]);
             }
         }
-        else {
-            TT_FATAL(this->slice_start[-1] * input_tensor_a.element_size() % sizeof(uint32_t) == 0, "Slice needs to start at an aligned position");
-        }
     }
 }
 
diff --git a/ttnn/cpp/ttnn/operations/data_movement/untilize_with_unpadding/device/untilize_with_unpadding_op.cpp b/ttnn/cpp/ttnn/operations/data_movement/untilize_with_unpadding/device/untilize_with_unpadding_op.cpp
index 4143b8dfa26..13ee19124cf 100644
--- a/ttnn/cpp/ttnn/operations/data_movement/untilize_with_unpadding/device/untilize_with_unpadding_op.cpp
+++ b/ttnn/cpp/ttnn/operations/data_movement/untilize_with_unpadding/device/untilize_with_unpadding_op.cpp
@@ -17,7 +17,6 @@ void UntilizeWithUnpadding::validate(const std::vector<Tensor>& input_tensors) c
     TT_FATAL(input_tensor_a.get_layout() == Layout::TILE, "Can only untilize tile major data");
     TT_FATAL(input_tensor_a.volume() % tt::constants::TILE_HW == 0, "Error");
-    TT_FATAL(((this->output_tensor_end[-1] + 1) % 2 == 0), "Can only unpad to row major tensor of even width");
 
     if (input_tensor_a.memory_config().is_sharded()) {
         if (input_tensor_a.memory_config().memory_layout == TensorMemoryLayout::BLOCK_SHARDED) {