undo revert of #16247 (#16430)
### Ticket
[Link to GitHub Issue](#16233)

### Problem description
Reduction ops don't support `dim` arguments of every possible length: apart from full reductions and the innermost one or two dims, only a few single-dim cases are special-cased, so reducing over an arbitrary subset of dimensions (for example `dim=[0, 1]` on a 4D tensor) is not handled. A repro sketch is shown below.
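
For example, a minimal repro sketch (mirroring the new `test_sum_4d_tensor_dims` test; it assumes an attached ttnn device, and the shapes are illustrative):

```python
# Repro sketch only -- shapes mirror test_sum_4d_tensor_dims; assumes a ttnn device is attached.
import torch
import ttnn

device = ttnn.open_device(device_id=0)

torch_input = torch.randn((3, 5, 37, 63), dtype=torch.bfloat16)
expected = torch.sum(torch_input, dim=[0, 1], keepdim=True)

x = ttnn.from_torch(torch_input, layout=ttnn.TILE_LAYOUT, device=device)
# Reducing over an arbitrary subset of dims (here [0, 1]) is the case this change adds.
y = ttnn.sum(x, dim=[0, 1], keepdim=True)
result = ttnn.to_torch(ttnn.from_device(y))

ttnn.close_device(device)
```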

### What's changed
Added support for reducing over any combination of dims. The `dim` handling (defaulting and sorting) is factored out into `generate_reduce_dim()`, and `reduce_impl()` now reduces the requested dims one at a time: any dim above the last two is transposed into the `-2` position, reduced there with `keepdim=true`, and transposed back, while `mean` over multiple dims falls back to a chain of sums followed by a scalar multiply. New tests cover `sum`, `mean`, and `prod` on 2D/3D/4D tensors with assorted `dim` combinations. A sketch of the decomposition follows.
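
A minimal Python/torch sketch of the decomposition used by the new path (the helper name below is invented for illustration; the actual implementation is the templated loop in `generic_reductions.cpp`):

```python
# Illustrative sketch of the per-dim reduction loop; not the actual C++ implementation.
import torch


def reduce_over_dims(x: torch.Tensor, dims: list[int]) -> torch.Tensor:
    """Sum-reduce x over `dims`, keeping reduced dims as size 1."""
    out = x
    for d in sorted(dims, reverse=True):         # walk dims from innermost to outermost
        if d < x.dim() - 2:
            out = out.transpose(d, -2)           # move the dim into the -2 slot
            out = out.sum(dim=-2, keepdim=True)  # reduce it there
            out = out.transpose(d, -2)           # move it back
        else:
            out = out.sum(dim=d, keepdim=True)   # last two dims reduce in place
    return out


# Agrees with torch's native multi-dim reduction.
x = torch.randn(3, 5, 37, 63)
assert torch.allclose(reduce_over_dims(x, [0, 1]), x.sum(dim=[0, 1], keepdim=True), atol=1e-5)
```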

### Checklist
- [x] Post commit CI passes:
https://github.com/tenstorrent/tt-metal/actions/runs/12635471794
- [x] Blackhole Post commit (if applicable):
https://github.com/tenstorrent/tt-metal/actions/runs/12653898951
- [x] Model regression CI testing passes (if applicable):
https://github.com/tenstorrent/tt-metal/actions/runs/12643280742
- [x] Device performance regression CI testing passes (if applicable):
https://github.com/tenstorrent/tt-metal/actions/runs/12643286124
- [ ] **(For models and ops writers)** Full [new models](https://github.com/tenstorrent/tt-metal/actions/workflows/full-new-models-suite.yaml) tests pass
- [x] New/Existing tests provide coverage for changes

---------

Signed-off-by: asandhupatlaTT <[email protected]>
asandhupatlaTT authored Jan 7, 2025
1 parent 66ae1a9 commit 44287f6
Showing 3 changed files with 196 additions and 70 deletions.
144 changes: 124 additions & 20 deletions tests/ttnn/unit_tests/operations/test_reduction.py
@@ -50,13 +50,38 @@ def test_var(device, batch_size, h, w, dim):
assert_with_pcc(torch_output_tensor, output_tensor, pcc=0.99)


@pytest.mark.parametrize("batch_size", [1, 16])
@pytest.mark.parametrize("c", [1, 4, 8, 16])
@pytest.mark.parametrize("h", [32, 64, 41, 37])
@pytest.mark.parametrize("w", [32, 64, 31, 63])
@pytest.mark.parametrize("dim", [None, [0, 1, 2, 3]])
@pytest.mark.parametrize("batch_size", [1])
@pytest.mark.parametrize("c", [11])
@pytest.mark.parametrize("h", [67])
@pytest.mark.parametrize("w", [77])
@pytest.mark.parametrize("dim", [0, 1, 2, 3])
@pytest.mark.parametrize("keepdim", [True])
def test_prod(device, batch_size, c, h, w, dim, keepdim):
torch.manual_seed(0)

torch_input_tensor = torch.randn((batch_size, c, h, w), dtype=torch.bfloat16)
torch_output_tensor = torch.prod(torch_input_tensor, dim=dim, keepdim=keepdim)

input_tensor = ttnn.from_torch(
torch_input_tensor, layout=ttnn.TILE_LAYOUT, device=device, memory_config=ttnn.L1_MEMORY_CONFIG
)

output_tensor = ttnn.prod(input_tensor, dim=dim, memory_config=ttnn.L1_MEMORY_CONFIG)
output_tensor = ttnn.from_device(output_tensor)

output_tensor = ttnn.to_torch(output_tensor)
assert len(output_tensor.shape) == len(torch_output_tensor.shape)
assert output_tensor.shape == torch_output_tensor.shape
# assert_with_pcc(torch_output_tensor, output_tensor, pcc=0.99)


@pytest.mark.parametrize("batch_size", [3])
@pytest.mark.parametrize("c", [5])
@pytest.mark.parametrize("h", [37])
@pytest.mark.parametrize("w", [63])
@pytest.mark.parametrize("dim", [None, [], 0, 2, [0, 1], [1, 3], [0, 1, 2], [1, 2, 3], [0, 1, 2, 3]])
@pytest.mark.parametrize("keepdim", [True])
def test_sum_4d_tensors(device, batch_size, c, h, w, dim, keepdim):
def test_sum_4d_tensor_dims(device, batch_size, c, h, w, dim, keepdim):
torch.manual_seed(0)

torch_input_tensor = torch.randn((batch_size, c, h, w), dtype=torch.bfloat16)
@@ -72,26 +97,105 @@ def test_sum_4d_tensors(device, batch_size, c, h, w, dim, keepdim):
assert_with_pcc(torch_output_tensor, output_tensor, pcc=0.99)


@pytest.mark.parametrize("batch_size", [1])
@pytest.mark.parametrize("c", [11])
@pytest.mark.parametrize("h", [67])
@pytest.mark.parametrize("w", [77])
@pytest.mark.parametrize("dim", [0, 1, 2, 3])
@pytest.mark.parametrize("c", [3])
@pytest.mark.parametrize("h", [31])
@pytest.mark.parametrize("w", [32])
@pytest.mark.parametrize("dim", [[0, 2], [0, 1, 2]])
@pytest.mark.parametrize("keepdim", [True])
def test_prod(device, batch_size, c, h, w, dim, keepdim):
def test_sum_3d_tensor_dims(device, c, h, w, dim, keepdim):
torch.manual_seed(0)

torch_input_tensor = torch.randn((c, h, w), dtype=torch.bfloat16)
torch_output_tensor = torch.sum(torch_input_tensor, dim=dim, keepdim=keepdim)

input_tensor = ttnn.from_torch(torch_input_tensor, layout=ttnn.TILE_LAYOUT, device=device)

output_tensor = ttnn.sum(input_tensor, dim=dim, keepdim=keepdim)
output_tensor = ttnn.to_layout(output_tensor, ttnn.TILE_LAYOUT)
output_tensor = ttnn.from_device(output_tensor)

output_tensor = ttnn.to_torch(output_tensor)
assert_with_pcc(torch_output_tensor, output_tensor, pcc=0.99)


@pytest.mark.parametrize("h", [41])
@pytest.mark.parametrize("w", [31])
@pytest.mark.parametrize("dim", [0, 1, [0, 1]])
@pytest.mark.parametrize("keepdim", [True])
def test_sum_2d_tensor_dims(device, h, w, dim, keepdim):
torch.manual_seed(0)

torch_input_tensor = torch.randn((h, w), dtype=torch.bfloat16)
torch_output_tensor = torch.sum(torch_input_tensor, dim=dim, keepdim=keepdim)

input_tensor = ttnn.from_torch(torch_input_tensor, layout=ttnn.TILE_LAYOUT, device=device)

output_tensor = ttnn.sum(input_tensor, dim=dim, keepdim=keepdim)
output_tensor = ttnn.to_layout(output_tensor, ttnn.TILE_LAYOUT)
output_tensor = ttnn.from_device(output_tensor)

output_tensor = ttnn.to_torch(output_tensor)
assert_with_pcc(torch_output_tensor, output_tensor, pcc=0.99)


@pytest.mark.parametrize("batch_size", [3])
@pytest.mark.parametrize("c", [5])
@pytest.mark.parametrize("h", [37])
@pytest.mark.parametrize("w", [63])
@pytest.mark.parametrize("dim", [None, [], 0, 2, [0, 1], [1, 3], [0, 1, 2], [1, 2, 3], [0, 1, 2, 3]])
@pytest.mark.parametrize("keepdim", [True])
def test_mean_4d_tensor_dims(device, batch_size, c, h, w, dim, keepdim):
torch.manual_seed(0)

torch_input_tensor = torch.randn((batch_size, c, h, w), dtype=torch.bfloat16)
torch_output_tensor = torch.prod(torch_input_tensor, dim=dim, keepdim=keepdim)
torch_output_tensor = torch.mean(torch_input_tensor, dim=dim, keepdim=keepdim)

input_tensor = ttnn.from_torch(
torch_input_tensor, layout=ttnn.TILE_LAYOUT, device=device, memory_config=ttnn.L1_MEMORY_CONFIG
)
input_tensor = ttnn.from_torch(torch_input_tensor, layout=ttnn.TILE_LAYOUT, device=device)

output_tensor = ttnn.prod(input_tensor, dim=dim, memory_config=ttnn.L1_MEMORY_CONFIG)
output_tensor = ttnn.mean(input_tensor, dim=dim, keepdim=keepdim)
output_tensor = ttnn.to_layout(output_tensor, ttnn.TILE_LAYOUT)
output_tensor = ttnn.from_device(output_tensor)

output_tensor = ttnn.to_torch(output_tensor)
assert len(output_tensor.shape) == len(torch_output_tensor.shape)
assert output_tensor.shape == torch_output_tensor.shape
# assert_with_pcc(torch_output_tensor, output_tensor, pcc=0.99)
assert_with_pcc(torch_output_tensor, output_tensor, pcc=0.99)


@pytest.mark.parametrize("c", [3])
@pytest.mark.parametrize("h", [31])
@pytest.mark.parametrize("w", [32])
@pytest.mark.parametrize("dim", [[0, 2], [0, 1, 2]])
@pytest.mark.parametrize("keepdim", [True])
def test_mean_3d_tensor_dims(device, c, h, w, dim, keepdim):
torch.manual_seed(0)

torch_input_tensor = torch.randn((c, h, w), dtype=torch.bfloat16)
torch_output_tensor = torch.mean(torch_input_tensor, dim=dim, keepdim=keepdim)

input_tensor = ttnn.from_torch(torch_input_tensor, layout=ttnn.TILE_LAYOUT, device=device)

output_tensor = ttnn.mean(input_tensor, dim=dim, keepdim=keepdim)
output_tensor = ttnn.to_layout(output_tensor, ttnn.TILE_LAYOUT)
output_tensor = ttnn.from_device(output_tensor)

output_tensor = ttnn.to_torch(output_tensor)
assert_with_pcc(torch_output_tensor, output_tensor, pcc=0.99)


@pytest.mark.parametrize("h", [41])
@pytest.mark.parametrize("w", [31])
@pytest.mark.parametrize("dim", [0, 1, [0, 1]])
@pytest.mark.parametrize("keepdim", [True])
def test_mean_2d_tensor_dims(device, h, w, dim, keepdim):
torch.manual_seed(0)

torch_input_tensor = torch.randn((h, w), dtype=torch.bfloat16)
torch_output_tensor = torch.mean(torch_input_tensor, dim=dim, keepdim=keepdim)

input_tensor = ttnn.from_torch(torch_input_tensor, layout=ttnn.TILE_LAYOUT, device=device)

output_tensor = ttnn.mean(input_tensor, dim=dim, keepdim=keepdim)
output_tensor = ttnn.to_layout(output_tensor, ttnn.TILE_LAYOUT)
output_tensor = ttnn.from_device(output_tensor)

output_tensor = ttnn.to_torch(output_tensor)
assert_with_pcc(torch_output_tensor, output_tensor, pcc=0.99)
3 changes: 2 additions & 1 deletion tests/ttnn/unit_tests/operations/test_reduction_mean.py
@@ -43,7 +43,8 @@ def test_mean_without_dim(device, batch_size, h, w):
output_tensor = ttnn.mean(input_tensor, keepdim=True)
output_tensor = ttnn.to_torch(output_tensor)
# PCC does not work for a single value. Assert on allclose.
close_passed, close_message = comp_allclose(torch_output_tensor, output_tensor, rtol=0.001, atol=0.001)
# See https://github.com/tenstorrent/tt-metal/issues/16454 for why the tolerance values were changed
close_passed, close_message = comp_allclose(torch_output_tensor, output_tensor, rtol=0.001, atol=0.00139)
if not close_passed:
print(f"Found mismatch: torch_output_tensor {torch_output_tensor}\n output_tensor {output_tensor}")
assert close_passed, construct_pcc_assert_message(close_message, torch_output_tensor, output_tensor)
119 changes: 70 additions & 49 deletions ttnn/cpp/ttnn/operations/reduction/generic/generic_reductions.cpp
@@ -12,20 +12,10 @@
namespace ttnn {
namespace operations::reduction {

template <ReduceType reduce_type>
static Tensor reduce_impl(
const Tensor& input_tensor_arg,
const std::optional<std::variant<int, ttnn::SmallVector<int>>>& dim_arg,
const bool keepdim,
const std::optional<MemoryConfig>& memory_config_arg,
const std::optional<DeviceComputeKernelConfig>& compute_kernel_config,
float scalar,
bool reshape) {
using ttnn::operations::experimental::auto_format::AutoFormat;
ttnn::SmallVector<int> generate_reduce_dim(
const Tensor& input_tensor_arg, const std::optional<std::variant<int, ttnn::SmallVector<int>>>& dim_arg) {
auto input_shape = input_tensor_arg.get_shape();
auto rank = input_shape.size();
auto memory_config = memory_config_arg.value_or(input_tensor_arg.memory_config());

ttnn::SmallVector<int> dim{};
if (dim_arg.has_value()) {
if (not std::holds_alternative<ttnn::SmallVector<int>>(dim_arg.value())) {
@@ -34,7 +24,8 @@ static Tensor reduce_impl(
} else {
dim = std::get<ttnn::SmallVector<int>>(dim_arg.value());
}
} else {
}
if (dim.empty()) {
dim = ttnn::SmallVector<int>(rank);
for (int i = 0; i < rank; i++) {
dim[i] = i;
@@ -55,6 +46,22 @@
}

std::sort(dim.begin(), dim.end());
return dim;
}

template <ReduceType reduce_type>
static Tensor reduce_impl(
const Tensor& input_tensor_arg,
const ttnn::SmallVector<int>& dim,
const bool keepdim,
const std::optional<MemoryConfig>& memory_config_arg,
const std::optional<DeviceComputeKernelConfig>& compute_kernel_config,
float scalar,
bool reshape) {
using ttnn::operations::experimental::auto_format::AutoFormat;
auto input_shape = input_tensor_arg.get_shape();
auto rank = input_shape.size();
auto memory_config = memory_config_arg.value_or(input_tensor_arg.memory_config());

ttnn::SmallVector<uint32_t> output_shape;
for (int axis = 0; axis < input_shape.size(); axis++) {
@@ -68,45 +75,58 @@
}
}

if (dim.size() == 1 && (rank == 3 || rank == 4)) {
if (dim[0] == 1 && rank == 4) {
Tensor output = ttnn::transpose(input_tensor_arg, 1, -2, memory_config);
output = reduce_impl<reduce_type>(
output, 2, /*keepdim=*/true, memory_config, compute_kernel_config, scalar, /*reshape=*/true);
output = ttnn::transpose(output, 1, -2, memory_config);
if (reshape) {
output = ttnn::reshape(output, ttnn::Shape{output_shape});
}
return output;
} else if (dim[0] == 0) {
Tensor output = ttnn::transpose(input_tensor_arg, 0, -2, memory_config);
output = reduce_impl<reduce_type>(
output, -2, /*keepdim=*/true, memory_config, compute_kernel_config, scalar, /*reshape=*/true);
output = ttnn::transpose(output, 0, -2, memory_config);
if (reshape) {
output = ttnn::reshape(output, ttnn::Shape{output_shape});
}
return output;
}
}

auto input_tensor = ttnn::unsqueeze_to_4D(input_tensor_arg);

Tensor output_tensor;
if (!dim_arg.has_value() || dim.size() == rank) {
if constexpr (
reduce_type == ReduceType::Sum || reduce_type == ReduceType::Max || reduce_type == ReduceType::Min) {
output_tensor = input_tensor;
for (int rank = input_tensor.get_legacy_shape().rank() - 1; rank >= 0; rank--) {
output_tensor = reduce_impl<reduce_type>(
output_tensor, rank, true, memory_config, compute_kernel_config, scalar, false);
bool single_reduce_op = (dim.size() == 1 && (dim[0] == rank - 1 || dim[0] == rank - 2)) ||
(dim.size() == 2 && dim[1] == rank - 1 && dim[0] == rank - 2);
if (!single_reduce_op) {
auto reduce_4d_loop = [&](const bool use_reduce_type) -> Tensor {
Tensor output_tensor = input_tensor;
int offset = 4 - rank;
for (int i_dim = rank - 1; i_dim >= 0; i_dim--) {
bool found = std::find(dim.begin(), dim.end(), i_dim) != dim.end();
if (found) {
bool transpose = i_dim < rank - 2;
int adjusted_dim = offset + i_dim;
int reduce_dim = adjusted_dim;
if (transpose) {
output_tensor = ttnn::transpose(output_tensor, adjusted_dim, 2, memory_config);
reduce_dim = 2;
}
if (use_reduce_type) {
output_tensor = reduce_impl<reduce_type>(
output_tensor,
{reduce_dim},
/*keepdim=*/true,
memory_config,
compute_kernel_config,
scalar,
/*reshape=*/false);
} else {
output_tensor = reduce_impl<ReduceType::Sum>(
output_tensor,
{reduce_dim},
/*keepdim=*/true,
memory_config,
compute_kernel_config,
scalar,
/*reshape=*/false);
}
if (transpose) {
output_tensor = ttnn::transpose(output_tensor, adjusted_dim, -2, memory_config);
}
}
}
return output_tensor;
};
constexpr bool linear_type =
reduce_type == ReduceType::Sum || reduce_type == ReduceType::Max || reduce_type == ReduceType::Min;
if (dim.size() == 1 || linear_type) {
output_tensor = reduce_4d_loop(/*use_reduce_type=*/true);
} else if constexpr (reduce_type == ReduceType::Mean) {
output_tensor = input_tensor;
for (int rank = input_tensor.get_legacy_shape().rank() - 1; rank >= 0; rank--) {
output_tensor = reduce_impl<ReduceType::Sum>(
output_tensor, rank, true, memory_config, compute_kernel_config, scalar, false);
}
output_tensor = reduce_4d_loop(
/*use_reduce_type=*/false);
float inv_volume = 1.0f / input_tensor.get_logical_volume();
output_tensor = ttnn::mul_sfpu(inv_volume, output_tensor, memory_config);
} else {
@@ -193,7 +213,7 @@
}

if (reshape) {
output_tensor = ttnn::reshape(output_tensor, ttnn::Shape{output_shape});
output_tensor = ttnn::reshape(output_tensor, ttnn::SimpleShape{output_shape});
}

return output_tensor;
@@ -207,8 +227,9 @@ Tensor Reduce<reduce_type>::invoke(
const std::optional<MemoryConfig>& memory_config_arg,
const std::optional<DeviceComputeKernelConfig>& compute_kernel_config,
float scalar) {
ttnn::SmallVector<int> dim = generate_reduce_dim(input_tensor_arg, dim_arg);
return reduce_impl<reduce_type>(
input_tensor_arg, dim_arg, keepdim, memory_config_arg, compute_kernel_config, scalar, true);
input_tensor_arg, dim, keepdim, memory_config_arg, compute_kernel_config, scalar, true);
}

template class Reduce<ReduceType::Sum>;
