# support reduction for 3d & 4d dims (#16236)
### Ticket
#16118

### Problem description
The current reduction code does not support `dim.size() > 2`, i.e. reducing over three or four dimensions at once.

### What's changed
Added support for reductions over higher-dimensional `dim` lists (see the usage sketch below).
Note: this only covers the case where `input_tensor_rank == dim.size()`; other combinations are out of scope for this PR.
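
For illustration, a minimal usage sketch of the newly supported full-rank reduction, mirroring the test added in this PR (the tensor shape and `dim` values follow the test parametrization; `ttnn.open_device`/`ttnn.close_device` are assumed as the usual way to obtain and release a device handle outside the pytest fixture):

```python
import torch
import ttnn

device = ttnn.open_device(device_id=0)

# dim lists every axis of the 4D input, i.e. input_tensor_rank == dim.size(),
# which is exactly the case this change enables.
torch_input = torch.randn((1, 4, 32, 32), dtype=torch.bfloat16)

input_tensor = ttnn.from_torch(torch_input, layout=ttnn.TILE_LAYOUT, device=device)
output_tensor = ttnn.sum(input_tensor, dim=[0, 1, 2, 3], keepdim=True)

# Bring the reduced tensor back to host for comparison against torch.sum.
result = ttnn.to_torch(ttnn.from_device(output_tensor))

ttnn.close_device(device)
```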

### Checklist
- [x] Post commit CI passes: https://github.com/tenstorrent/tt-metal/actions/runs/12435618460
- [x] Blackhole Post commit (if applicable): https://github.com/tenstorrent/tt-metal/actions/runs/12435616950
- [x] Model regression CI testing passes (if applicable): https://github.com/tenstorrent/tt-metal/actions/runs/12435615910
- [x] Device performance regression CI testing passes (if applicable): https://github.com/tenstorrent/tt-metal/actions/runs/12435614685
- [ ] **(For models and ops writers)** Full [new models](https://github.com/tenstorrent/tt-metal/actions/workflows/full-new-models-suite.yaml) tests pass
- [x] New/Existing tests provide coverage for changes

---------

Signed-off-by: Amruth Sandhupatla <[email protected]>
asandhupatlaTT authored Dec 20, 2024
1 parent 29e0cae commit b4d8b4c
Showing 2 changed files with 23 additions and 1 deletion.
22 changes: 22 additions & 0 deletions tests/ttnn/unit_tests/operations/test_reduction.py
@@ -48,3 +48,25 @@ def test_var(device, batch_size, h, w, dim):
 
     output_tensor = ttnn.to_torch(output_tensor)
     assert_with_pcc(torch_output_tensor, output_tensor, pcc=0.99)
+
+
+@pytest.mark.parametrize("batch_size", [1, 16])
+@pytest.mark.parametrize("c", [1, 4, 8, 16])
+@pytest.mark.parametrize("h", [32, 64, 41, 37])
+@pytest.mark.parametrize("w", [32, 64, 31, 63])
+@pytest.mark.parametrize("dim", [None, [0, 1, 2, 3]])
+@pytest.mark.parametrize("keepdim", [True])
+def test_sum_4d_tensors(device, batch_size, c, h, w, dim, keepdim):
+    torch.manual_seed(0)
+
+    torch_input_tensor = torch.randn((batch_size, c, h, w), dtype=torch.bfloat16)
+    torch_output_tensor = torch.sum(torch_input_tensor, dim=dim, keepdim=keepdim)
+
+    input_tensor = ttnn.from_torch(torch_input_tensor, layout=ttnn.TILE_LAYOUT, device=device)
+
+    output_tensor = ttnn.sum(input_tensor, dim=dim, keepdim=keepdim)
+    output_tensor = ttnn.to_layout(output_tensor, ttnn.TILE_LAYOUT)
+    output_tensor = ttnn.from_device(output_tensor)
+
+    output_tensor = ttnn.to_torch(output_tensor)
+    assert_with_pcc(torch_output_tensor, output_tensor, pcc=0.99)
1 change: 1 addition & 1 deletion (C++ reduction implementation)
@@ -93,7 +93,7 @@ static Tensor reduce_impl(
     auto input_tensor = ttnn::unsqueeze_to_4D(input_tensor_arg);
 
     Tensor output_tensor;
-    if (!dim_arg.has_value()) {
+    if (!dim_arg.has_value() || dim.size() == rank) {
         if constexpr (
             reduce_type == ReduceType::Sum || reduce_type == ReduceType::Max || reduce_type == ReduceType::Min) {
             output_tensor = input_tensor;
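
In effect, a full-rank `dim` (one entry per input dimension) now takes the same branch as `dim=None`, which already reduces over all dimensions.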
