#12101: Fixing reduce to account for padding while calculating output dimension (#12274)

Problem description
While developing the TT-MLIR compiler, we encountered an issue with the ttnn.mean op: the reduction fails when the input tensor's dims are not tile-aligned, erroring out with the following message:

Unable to reshape a tensor in TILE_LAYOUT to non-tile height and width! Please convert the tensor to ROW_MAJOR_LAYOUT first.
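For reference, a minimal repro sketch (assuming ttnn's usual from_torch/mean API; the device-open call and exact shape here are illustrative, not taken from the commit — the repo's tests use a pytest device fixture instead):

import torch
import ttnn

device = ttnn.open_device(device_id=0)

# Height (63) and width (37) are not multiples of the 32x32 tile.
torch_input = torch.rand((1, 63, 37), dtype=torch.bfloat16)
input_tensor = ttnn.from_torch(torch_input, layout=ttnn.TILE_LAYOUT, device=device)

# Before this fix, the reduction below errored out with the message above.
output_tensor = ttnn.mean(input_tensor, dim=-1)

ttnn.close_device(device)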

The issue is in the TTNN implementation of the reduction op, on this line where we calculate the padding for the output shape:
padded_output_shape.push_back(input_shape[axis]);

We iterate through all the dims of the tensor. When we encounter a dim that we are not reducing, we want to capture both the non-padded and the padded parts of the input shape, but we don't capture the padded part correctly, because input_shape[axis] returns the unpadded dim. Instead, we should retrieve the padded dim, like this:
padded_output_shape.push_back(input_shape.value[axis]);
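To make the unpadded/padded distinction concrete: in TILE_LAYOUT the last two dims are stored rounded up to the 32x32 tile, so the logical dim and the stored (padded) dim differ whenever the tensor is not tile-aligned. A small illustration (the pad_to_tile helper is hypothetical, not part of the ttnn API):

import math

TILE_SIZE = 32  # tile height/width used by TILE_LAYOUT

def pad_to_tile(dim: int) -> int:
    # Round a dimension up to the nearest tile boundary.
    return math.ceil(dim / TILE_SIZE) * TILE_SIZE

unpadded = (1, 63, 37)  # logical shape
padded = unpadded[:-2] + tuple(pad_to_tile(d) for d in unpadded[-2:])
print(padded)  # (1, 64, 64) -- what is actually stored on device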

What's changed
The fix includes the padding in the calculation of the output tensor's shape. This change impacts all reduction ops, which now work for dimensions that are not tile-aligned (e.g., a tensor of shape (1, 63, 37)).
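A quick way to sanity-check the fix on one of the newly covered shapes (a sketch mirroring the updated tests below; assert_with_pcc is the repo's existing test helper):

import torch
import ttnn
from tests.ttnn.utils_for_testing import assert_with_pcc

def check_mean_non_tile_aligned(device):
    torch.manual_seed(0)
    torch_input = torch.rand((1, 63, 37), dtype=torch.bfloat16)
    torch_output = torch.mean(torch_input, dim=-1, keepdim=True)

    input_tensor = ttnn.from_torch(torch_input, layout=ttnn.TILE_LAYOUT, device=device)
    output_tensor = ttnn.to_torch(ttnn.mean(input_tensor, dim=-1))

    assert_with_pcc(torch_output, output_tensor)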
sdjordjevicTT authored Sep 10, 2024
1 parent 37f1d24 commit 6d0d080
Showing 5 changed files with 24 additions and 22 deletions.
12 changes: 6 additions & 6 deletions tests/ttnn/unit_tests/operations/test_max.py
@@ -11,9 +11,9 @@
 from models.utility_functions import torch_random


-@pytest.mark.parametrize("batch_size", [1, 16])
-@pytest.mark.parametrize("h", [32, 64])
-@pytest.mark.parametrize("w", [32, 64])
+@pytest.mark.parametrize("batch_size", [1, 16, 1, 16])
+@pytest.mark.parametrize("h", [32, 64, 41, 37])
+@pytest.mark.parametrize("w", [32, 64, 31, 63])
 @pytest.mark.parametrize("dim", [-1, -2])
 def test_max(device, batch_size, h, w, dim):
     torch.manual_seed(0)
@@ -32,9 +32,9 @@ def test_max(device, batch_size, h, w, dim):
     assert_with_pcc(torch_output_tensor, output_tensor)


-@pytest.mark.parametrize("batch_size", [1, 16])
-@pytest.mark.parametrize("h", [32, 64])
-@pytest.mark.parametrize("w", [32, 64])
+@pytest.mark.parametrize("batch_size", [1, 16, 1, 16])
+@pytest.mark.parametrize("h", [32, 64, 41, 37])
+@pytest.mark.parametrize("w", [32, 64, 31, 63])
 def test_max_global(device, batch_size, h, w):
     torch.manual_seed(0)

6 changes: 3 additions & 3 deletions tests/ttnn/unit_tests/operations/test_mean.py
@@ -11,9 +11,9 @@
 from models.utility_functions import torch_random, is_wormhole_b0


-@pytest.mark.parametrize("batch_size", [1, 16])
-@pytest.mark.parametrize("h", [32, 64])
-@pytest.mark.parametrize("w", [32, 64])
+@pytest.mark.parametrize("batch_size", [1, 16, 1, 16])
+@pytest.mark.parametrize("h", [32, 64, 41, 37])
+@pytest.mark.parametrize("w", [32, 64, 31, 63])
 @pytest.mark.parametrize("dim", [-1, -2])
 def test_mean(device, batch_size, h, w, dim):
     torch.manual_seed(0)
12 changes: 6 additions & 6 deletions tests/ttnn/unit_tests/operations/test_min.py
@@ -11,9 +11,9 @@
 from models.utility_functions import torch_random, is_wormhole_b0


-@pytest.mark.parametrize("batch_size", [1, 16])
-@pytest.mark.parametrize("h", [32, 64])
-@pytest.mark.parametrize("w", [32, 64])
+@pytest.mark.parametrize("batch_size", [1, 16, 1, 16])
+@pytest.mark.parametrize("h", [32, 64, 41, 37])
+@pytest.mark.parametrize("w", [32, 64, 31, 63])
 @pytest.mark.parametrize("dim", [-1, -2])
 def test_min(device, batch_size, h, w, dim):
     if is_wormhole_b0() and dim == -2:
@@ -33,9 +33,9 @@ def test_min(device, batch_size, h, w, dim):
     assert_with_pcc(torch_output_tensor, output_tensor)


-@pytest.mark.parametrize("batch_size", [1, 16])
-@pytest.mark.parametrize("h", [32, 64])
-@pytest.mark.parametrize("w", [32, 64])
+@pytest.mark.parametrize("batch_size", [1, 16, 1, 16])
+@pytest.mark.parametrize("h", [32, 64, 41, 37])
+@pytest.mark.parametrize("w", [32, 64, 31, 63])
 def test_min_global(device, batch_size, h, w):
     if is_wormhole_b0():
         pytest.skip("Issue #6991: PCC mismatch")
12 changes: 6 additions & 6 deletions tests/ttnn/unit_tests/operations/test_sum.py
@@ -11,9 +11,9 @@
 from models.utility_functions import torch_random, is_wormhole_b0


-@pytest.mark.parametrize("batch_size", [1, 16])
-@pytest.mark.parametrize("h", [32, 64])
-@pytest.mark.parametrize("w", [32, 64])
+@pytest.mark.parametrize("batch_size", [1, 16, 1, 16])
+@pytest.mark.parametrize("h", [32, 64, 41, 37])
+@pytest.mark.parametrize("w", [32, 64, 31, 63])
 @pytest.mark.parametrize("dim", [-1, -2, (2, 1)])
 def test_sum(device, batch_size, h, w, dim):
     torch.manual_seed(0)
@@ -33,9 +33,9 @@ def test_sum(device, batch_size, h, w, dim):
     assert_with_pcc(torch_output_tensor, output_tensor)


-@pytest.mark.parametrize("batch_size", [1, 16])
-@pytest.mark.parametrize("h", [32, 64])
-@pytest.mark.parametrize("w", [32, 64])
+@pytest.mark.parametrize("batch_size", [1, 16, 1, 16])
+@pytest.mark.parametrize("h", [32, 64, 41, 37])
+@pytest.mark.parametrize("w", [32, 64, 31, 63])
 def test_sum_global(device, batch_size, h, w):
     torch.manual_seed(0)
     if is_wormhole_b0():
@@ -102,8 +102,10 @@ static Tensor reduce_impl(
             padded_output_shape.push_back(axis >= rank - 2 ? ttnn::TILE_SIZE : 1);
         }
     } else {
         // Get the shape for the output tensor
         output_shape.push_back(input_shape[axis]);
-        padded_output_shape.push_back(input_shape[axis]);
+        // Get the padded shape for the output tensor
+        padded_output_shape.push_back(input_shape.value[axis]);
     }
 }
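For readers who prefer Python, a rough rendering of the shape loop above (illustrative only; it assumes keepdim-style outputs and that the padded input dims were already rounded up to TILE_SIZE):

TILE_SIZE = 32

def reduce_output_shapes(unpadded_in, padded_in, dims_to_reduce):
    # Compute the logical and padded output shapes of a reduction.
    rank = len(unpadded_in)
    output_shape, padded_output_shape = [], []
    for axis in range(rank):
        if axis in dims_to_reduce:
            # Reduced dims collapse to 1; the last two stay tile-padded.
            output_shape.append(1)
            padded_output_shape.append(TILE_SIZE if axis >= rank - 2 else 1)
        else:
            # Keep both the logical dim and its padded counterpart.
            # This is the line the commit fixes: the padded shape used to
            # receive the unpadded dim.
            output_shape.append(unpadded_in[axis])
            padded_output_shape.append(padded_in[axis])
    return tuple(output_shape), tuple(padded_output_shape)

# Mean over the last dim of a (1, 63, 37) tensor stored as (1, 64, 64):
print(reduce_output_shapes((1, 63, 37), (1, 64, 64), {2}))
# -> ((1, 63, 1), (1, 64, 32))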

