diff --git a/tests/ttnn/unit_tests/operations/test_reduction.py b/tests/ttnn/unit_tests/operations/test_reduction.py
index 6769fea9414..ae2aec562e0 100644
--- a/tests/ttnn/unit_tests/operations/test_reduction.py
+++ b/tests/ttnn/unit_tests/operations/test_reduction.py
@@ -48,3 +48,25 @@ def test_var(device, batch_size, h, w, dim):
 
     output_tensor = ttnn.to_torch(output_tensor)
     assert_with_pcc(torch_output_tensor, output_tensor, pcc=0.99)
+
+
+@pytest.mark.parametrize("batch_size", [1, 16])
+@pytest.mark.parametrize("c", [1, 4, 8, 16])
+@pytest.mark.parametrize("h", [32, 64, 41, 37])  # NOTE(review): 41/37 are not multiples of 32 - presumably exercises padding; confirm
+@pytest.mark.parametrize("w", [32, 64, 31, 63])  # NOTE(review): 31/63 are not multiples of 32 - presumably exercises padding; confirm
+@pytest.mark.parametrize("dim", [None, [0, 1, 2, 3]])  # both values request a sum over every axis of the 4D input
+@pytest.mark.parametrize("keepdim", [True])
+def test_sum_4d_tensors(device, batch_size, c, h, w, dim, keepdim):
+    torch.manual_seed(0)  # fixed seed so the random input is the same on every run
+
+    torch_input_tensor = torch.randn((batch_size, c, h, w), dtype=torch.bfloat16)
+    torch_output_tensor = torch.sum(torch_input_tensor, dim=dim, keepdim=keepdim)  # PyTorch reference result
+
+    input_tensor = ttnn.from_torch(torch_input_tensor, layout=ttnn.TILE_LAYOUT, device=device)
+
+    output_tensor = ttnn.sum(input_tensor, dim=dim, keepdim=keepdim)  # operation under test
+    output_tensor = ttnn.to_layout(output_tensor, ttnn.TILE_LAYOUT)
+    output_tensor = ttnn.from_device(output_tensor)
+
+    output_tensor = ttnn.to_torch(output_tensor)  # back to a torch tensor for comparison with the reference
+    assert_with_pcc(torch_output_tensor, output_tensor, pcc=0.99)  # NOTE(review): pcc presumably a correlation threshold - confirm
diff --git a/ttnn/cpp/ttnn/operations/reduction/generic/generic_reductions.cpp b/ttnn/cpp/ttnn/operations/reduction/generic/generic_reductions.cpp
index 1f971e25b52..18cb871dec8 100644
--- a/ttnn/cpp/ttnn/operations/reduction/generic/generic_reductions.cpp
+++ b/ttnn/cpp/ttnn/operations/reduction/generic/generic_reductions.cpp
@@ -93,7 +93,7 @@ static Tensor reduce_impl(
     auto input_tensor = ttnn::unsqueeze_to_4D(input_tensor_arg);
 
     Tensor output_tensor;
-    if (!dim_arg.has_value()) {
+    if (!dim_arg.has_value() || dim.size() == rank) {
         if constexpr (
             reduce_type == ReduceType::Sum || reduce_type == ReduceType::Max || reduce_type == ReduceType::Min) {
             output_tensor = input_tensor;