Skip to content

Commit

Permalink
Add tests for untilize, transpose, and tilize on non-4B aligned row-d…
Browse files Browse the repository at this point in the history
…im Row Major tensors + enable support for these tensors on untilize (#15347)

### Ticket

- #15099
- #12705
- #14227
- #13749

### Problem description

- Several OPs could not support RM inputs that were did not have 4B
aligned rows.
- Transpose, slice, untilize and tilize

### What's changed

- Lift the limitations on transpose, slice, untilize and tilize to
enable support for odd row-dim BFP16 tensors
- Increase transpose coverage on traces to 98.4%
- Increase slice coverage on traces to 94.5%

### Checklist
- [x] Post commit CI passes
https://github.com/tenstorrent/tt-metal/actions/runs/11979119941
- [x] Blackhole Post commit (if applicable)
https://github.com/tenstorrent/tt-metal/actions/runs/11962518460/job/33353111930
same issue as main
- [ ] Model regression CI testing passes (if applicable)
- [ ] Device performance regression CI testing passes (if applicable)
- [ ] New/Existing tests provide coverage for changes
  • Loading branch information
sjameelTT authored Nov 23, 2024
1 parent 42bed58 commit c795acc
Show file tree
Hide file tree
Showing 6 changed files with 122 additions and 20 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
shape_wh = [
[[1, 1, 32, 32]], # Single core
[[3, 1, 320, 384]], # Multi core
[[1, 1024, 5, 1280]], # Non page-aligned
]


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import ttnn

from loguru import logger
from models.utility_functions import is_grayskull, is_blackhole
from models.utility_functions import is_grayskull, is_blackhole, torch_random
from tests.tt_eager.python_api_testing.sweep_tests.comparison_funcs import comp_pcc, comp_equal
from models.utility_functions import skip_for_grayskull, skip_for_blackhole
from tests.ttnn.utils_for_testing import assert_with_pcc
Expand Down Expand Up @@ -644,7 +644,7 @@ def test_transpose_bfloat8_b(device, shape, swap_dims):
)
@pytest.mark.parametrize(
"shape",
[(1, 32, 12, 100), (1, 12, 32, 100), (1, 35, 7, 7), (1, 1, 1, 1)],
[(1, 32, 12, 100), (1, 12, 32, 100), (1, 35, 7, 7), (1, 1, 1, 1), (1, 12, 32, 100)],
)
def test_transpose_hc(dtype, shape, device):
if is_grayskull() and dtype == ttnn.float32:
Expand Down Expand Up @@ -691,15 +691,22 @@ def test_transpose_2D(dtype, shape, layout, device):
)
@pytest.mark.parametrize(
"shape",
[[32, 1, 32], [32, 1, 12], [1, 1, 35], [1, 16, 32], [2, 34, 8]],
[[32, 1, 32], [32, 1, 12], [1, 1, 35], [1, 16, 32], [2, 34, 8], (32, 12, 100), (6, 33, 34)],
)
@pytest.mark.parametrize(
"layout",
[ttnn.TILE_LAYOUT],
)
@pytest.mark.parametrize(
"dims",
[[0, 1], [0, 2], [2, 1], [-3, -2], [-3, -1], [-2, -1]],
[
[0, 1],
[0, 2],
[2, 1],
[-3, -2],
[-3, -1],
[-2, -1],
],
)
def test_transpose_3D(dtype, shape, layout, dims, device):
torch.manual_seed(2005)
Expand Down Expand Up @@ -750,14 +757,14 @@ def test_transpose_4d_wh_tile(shape, device):
@pytest.mark.parametrize(
"config",
[
[[64, 4, 49, 32], [-2, -1], ttnn.ROW_MAJOR_LAYOUT], # Page size must be divisible by sizeof(uint32_t)
[[1, 1370, 1, 3, 1280], [0, -2], ttnn.TILE_LAYOUT], # untilize doesn't work with 4D
[[12, 3], [0, 1], ttnn.ROW_MAJOR_LAYOUT], # need tensor for this one
[[1, 50, 1, 3, 768], [0, -2], ttnn.TILE_LAYOUT], # untilize doesn't work with 4D
[[21843, 768], [0, 1], ttnn.ROW_MAJOR_LAYOUT], # circular buffer overflow
],
)
@pytest.mark.parametrize("memory_config", [ttnn.L1_MEMORY_CONFIG, ttnn.DRAM_MEMORY_CONFIG])
def test_transpose_failures(config, memory_config, device):
pytest.skip("Failures to fix after #13217 and #13005 are in - 5D, HC PCC issue and unaligned RM tensor")
pytest.skip("Failing pytorch 2.0 trace sweeps")
torch.manual_seed(2005)
torch_input = torch.randn(config[0], dtype=torch.bfloat16)
torch_output = torch_input.transpose(config[1][0], config[1][1])
Expand Down Expand Up @@ -793,6 +800,8 @@ def test_transpose_failures(config, memory_config, device):
[[1, 9, 8, 14], [1, 2], ttnn.ROW_MAJOR_LAYOUT], # unaligned RM that fallsback to tiled
[[1, 9, 8, 2], [1, 2], ttnn.ROW_MAJOR_LAYOUT], # unaligned RM that fallsback to tiled
[[1, 2, 8, 2], [1, 2], ttnn.ROW_MAJOR_LAYOUT], # unaligned RM that fallsback to tiled
[[64, 4, 49, 32], [-2, -1], ttnn.ROW_MAJOR_LAYOUT], # Page size must be divisible by sizeof(uint32_t)
[[12, 3], [0, 1], ttnn.ROW_MAJOR_LAYOUT], # need tensor for this one
[
[1, 8, 4096, 40],
[1, 2],
Expand Down Expand Up @@ -943,3 +952,62 @@ def test_transpose_unpadded(shape, dims, layout, dtype, pad_value, device):
assert ttnn.to_torch(a) == float("-inf")
tt_output = ttnn.to_torch(tt_output)
assert_with_pcc(torch_output, tt_output, 0.9999)


@pytest.mark.parametrize("b", [1])
@pytest.mark.parametrize("h", [18])
@pytest.mark.parametrize("w", [65])
@pytest.mark.parametrize("dim0", [1])
@pytest.mark.parametrize("dim1", [2])
def test_transpose_forge_llama(device, b, h, w, dim0, dim1):
torch.manual_seed(2005)

torch_input_tensor = torch_random((b, h, w), -0.1, 0.1, dtype=torch.bfloat16)
torch_output_tensor = torch_input_tensor.transpose(dim0, dim1)

input_tensor = ttnn.to_device(ttnn.from_torch(torch_input_tensor), device, memory_config=ttnn.DRAM_MEMORY_CONFIG)
input_tensor = ttnn.to_layout(input_tensor, layout=ttnn.TILE_LAYOUT)
output_tensor = ttnn.transpose(input_tensor, dim0, dim1, memory_config=ttnn.DRAM_MEMORY_CONFIG)
output_tensor = ttnn.from_device(output_tensor)
output_tensor = ttnn.to_layout(output_tensor, layout=ttnn.ROW_MAJOR_LAYOUT)
output_tensor = ttnn.to_torch(output_tensor)

assert_with_pcc(torch_output_tensor, output_tensor)


@pytest.mark.parametrize("b", [1])
@pytest.mark.parametrize("h", [2])
@pytest.mark.parametrize("w", [3])
@pytest.mark.parametrize("dim0", [-1])
@pytest.mark.parametrize("dim1", [-2])
def test_transpose_forge_basic(device, b, h, w, dim0, dim1):
torch.manual_seed(2005)
torch_input_tensor = torch_random((1, b, h, w), -0.1, 0.1, dtype=torch.bfloat16)
torch_output_tensor = torch_input_tensor.transpose(dim0, dim1)
input_tensor = ttnn.to_device(ttnn.from_torch(torch_input_tensor), device, memory_config=ttnn.DRAM_MEMORY_CONFIG)
input_tensor = ttnn.to_layout(input_tensor, layout=ttnn.TILE_LAYOUT)
output_tensor = ttnn.transpose(input_tensor, dim0, dim1, memory_config=ttnn.DRAM_MEMORY_CONFIG)
output_tensor = ttnn.from_device(output_tensor)
output_tensor = ttnn.to_layout(output_tensor, layout=ttnn.ROW_MAJOR_LAYOUT)
output_tensor = ttnn.to_torch(output_tensor)

assert_with_pcc(torch_output_tensor, output_tensor)


@pytest.mark.parametrize("b", [6])
@pytest.mark.parametrize("h", [33])
@pytest.mark.parametrize("w", [34])
@pytest.mark.parametrize("dim0", [1])
@pytest.mark.parametrize("dim1", [0])
def test_transpose_forge_hc(device, b, h, w, dim0, dim1):
torch.manual_seed(2005)
torch_input_tensor = torch_random((1, b, h, w), -0.1, 0.1, dtype=torch.bfloat16)
torch_output_tensor = torch_input_tensor.transpose(dim0, dim1)
input_tensor = ttnn.to_device(ttnn.from_torch(torch_input_tensor), device, memory_config=ttnn.DRAM_MEMORY_CONFIG)
input_tensor = ttnn.to_layout(input_tensor, layout=ttnn.TILE_LAYOUT)
output_tensor = ttnn.transpose(input_tensor, dim0, dim1, memory_config=ttnn.DRAM_MEMORY_CONFIG)
output_tensor = ttnn.from_device(output_tensor)
output_tensor = ttnn.to_layout(output_tensor, layout=ttnn.ROW_MAJOR_LAYOUT)
output_tensor = ttnn.to_torch(output_tensor)

assert_with_pcc(torch_output_tensor, output_tensor)
36 changes: 31 additions & 5 deletions tests/ttnn/unit_tests/operations/test_slice.py
Original file line number Diff line number Diff line change
Expand Up @@ -752,8 +752,37 @@ def test_slice_adversarial_fixed(input_shape, dim, start, end, step, layout, dev
@pytest.mark.parametrize(
"input_shape, dim, start, end, step, layout",
(
([8732, 4], 1, 0, -1, 4, ttnn.TILE_LAYOUT), # Need tensor for this or a padding aware tiled kernel
([1, 7], 0, 0, -1, 1, ttnn.ROW_MAJOR_LAYOUT), # page size must equal buffer size
([1, 8, 2, 2], 2, -1, -1, 1, ttnn.TILE_LAYOUT), # Buffer size and page size should be larger than 0 bytes
([3], 0, 0, -1, 1, ttnn.TILE_LAYOUT), # Difference in expected shape as it's a 1D tensor
),
)
def test_slice_adversarial(input_shape, dim, start, end, step, layout, device):
pytest.skip("These tests are known to fail")
torch_input = torch.randn(input_shape, dtype=torch.bfloat16)

slice_obj = slice(start, end, step)

# Prepare indices for slicing in the specified dimension
indices = [slice(None)] * len(input_shape) # By default, select all elements along every dimension
indices[dim] = slice_obj # Apply slicing to the target dimension
indices = tuple(indices)

# Apply slicing to the input_tensor
torch_output_tensor = torch_input[indices]

ttnn_tensor = ttnn.from_torch(torch_input, device=device, layout=layout, dtype=ttnn.bfloat16)
ttnn_output = ttnn_tensor[indices]

ttnn_output_tensor = ttnn.to_torch(ttnn_output)

assert_with_pcc(torch_output_tensor, ttnn_output_tensor, 0.999)


@pytest.mark.parametrize(
"input_shape, dim, start, end, step, layout",
(
([8732, 4], 1, 0, -1, 4, ttnn.TILE_LAYOUT), # Need tensor for this or a padding aware tiled kernel
(
[1, 7, 71, 64],
3,
Expand All @@ -762,12 +791,9 @@ def test_slice_adversarial_fixed(input_shape, dim, start, end, step, layout, dev
1,
ttnn.ROW_MAJOR_LAYOUT,
), # An unpadding slice operations for a RowMajor layout on the output tensor requires the last dimension to be on a 32 bit boundary
([1, 8, 2, 2], 2, -1, -1, 1, ttnn.TILE_LAYOUT), # Buffer size and page size should be larger than 0 bytes
([3], 0, 0, -1, 1, ttnn.TILE_LAYOUT), # Difference in expected shape as it's a 1D tensor
),
)
def test_slice_adversarial(input_shape, dim, start, end, step, layout, device):
pytest.skip("These tests are expected to fail at the moment")
def test_slice_adversarial_fixed(input_shape, dim, start, end, step, layout, device):
torch_input = torch.randn(input_shape, dtype=torch.bfloat16)

slice_obj = slice(start, end, step)
Expand Down
15 changes: 15 additions & 0 deletions tests/ttnn/unit_tests/test_to_layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import ttnn

from tests.ttnn.utils_for_testing import assert_with_pcc, check_with_pcc_without_tensor_printout
from models.utility_functions import is_grayskull, is_blackhole, torch_random, skip_for_grayskull


@pytest.mark.parametrize("height", [32, 30])
Expand Down Expand Up @@ -125,3 +126,17 @@ def test_untilize_with_unpadding_W_16(device, in_dtype, use_multicore, use_pack_
passing, pcc_msg = check_with_pcc_without_tensor_printout(torch_input, output_torch)
logger.info(pcc_msg)
assert passing


@pytest.mark.parametrize("h", [1, 18, 65])
@pytest.mark.parametrize("w", [1, 15, 17, 29, 33, 49, 63, 65])
@pytest.mark.parametrize("input_layout", [ttnn.ROW_MAJOR_LAYOUT, ttnn.TILE_LAYOUT])
@pytest.mark.parametrize("output_layout", [ttnn.ROW_MAJOR_LAYOUT, ttnn.TILE_LAYOUT])
def test_to_layout_device(device, h, w, input_layout, output_layout):
torch.manual_seed(2005)
torch_input_tensor = torch_random((h, w), -0.1, 0.1, dtype=torch.bfloat16)
input_tensor = ttnn.from_torch(torch_input_tensor, device=device, dtype=ttnn.bfloat16, layout=input_layout)
new_layout_tensor = ttnn.to_layout(input_tensor, layout=output_layout)
torch_brought_back = ttnn.to_torch(new_layout_tensor)

assert_with_pcc(torch_input_tensor, torch_brought_back)
Original file line number Diff line number Diff line change
Expand Up @@ -102,18 +102,11 @@ void SliceDeviceOperation::validate_with_output_tensors(
(output_tensor_shape[-1] % TILE_WIDTH == 0) && (this->slice_start[-1] % TILE_WIDTH == 0),
"Can only unpad tilized tensor with full tiles");
} else if (input_tensor_a.get_layout() == Layout::ROW_MAJOR) {
TT_FATAL(
(output_tensor_shape[-1] * input_tensor_a.element_size() % sizeof(uint32_t) == 0),
"An unpadding slice operations for a RowMajor layout on the output tensor requires the last dimension to be on a 32 bit boundary. For example, the final dimension needs to be divisible by 2 for bfloat16. The resulting tensor shape is {}, which is not 4B aligned as the last dimension is {}",
output_tensor_shape[-1], input_tensor_a.element_size());
if (has_step) {
for (uint32_t i = 0; i < input_tensor_a.get_legacy_shape().rank(); i++) {
TT_FATAL(step[i] > 0, "Step({}) = {} should be positive", i, step[i]);
}
}
else {
TT_FATAL(this->slice_start[-1] * input_tensor_a.element_size() % sizeof(uint32_t) == 0, "Slice needs to start at an aligned position");
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ void UntilizeWithUnpadding::validate(const std::vector<Tensor>& input_tensors) c
TT_FATAL(input_tensor_a.get_layout() == Layout::TILE, "Can only untilize tile major data");

TT_FATAL(input_tensor_a.volume() % tt::constants::TILE_HW == 0, "Error");
TT_FATAL(((this->output_tensor_end[-1] + 1) % 2 == 0), "Can only unpad to row major tensor of even width");

if (input_tensor_a.memory_config().is_sharded()) {
if (input_tensor_a.memory_config().memory_layout == TensorMemoryLayout::BLOCK_SHARDED) {
Expand Down

0 comments on commit c795acc

Please sign in to comment.