diff --git a/tests/ttnn/unit_tests/operations/test_all_reduce_t3000_frequent.py b/tests/ttnn/unit_tests/operations/test_all_reduce_t3000_frequent.py index b8e414c5677..3e75367d05e 100644 --- a/tests/ttnn/unit_tests/operations/test_all_reduce_t3000_frequent.py +++ b/tests/ttnn/unit_tests/operations/test_all_reduce_t3000_frequent.py @@ -111,6 +111,7 @@ def run_all_reduce_test( .to(layout) .to(mesh_device.get_device(mesh_device.get_device_ids()[i]), mem_config) ) + input_tensor = input_tensor.view(1, -1, input_tensor.shape[2], input_tensor.shape[3]) input_tensors.append(input_tensor) unchunked_input_tensor = torch.cat(input_tensors) @@ -133,8 +134,8 @@ def run_all_reduce_test( tt_out_tensors = ttnn.get_device_tensors(output_tensor_mesh) logger.info(f"Compare") - golden_canonical_out_tensor = torch.sum(unchunked_input_tensor, 0) - + golden_canonical_out_tensor = torch.sum(unchunked_input_tensor, 0, keepdim=True) + golden_canonical_out_tensor = golden_canonical_out_tensor.view(per_chip_output_shape) # Compare mismatch = False for i, t in enumerate(tt_out_tensors): @@ -176,6 +177,10 @@ def run_all_reduce_test( ([1, 1, 8192, 32]), ([1, 1, 1024, 32]), ([1, 1, 2048, 32]), + ([4, 1, 32, 4096]), + ([8, 1, 32, 1024]), + ([1, 4, 1024, 32]), + ([2, 4, 2048, 32]), ], ) @pytest.mark.parametrize(