diff --git a/experimental/torch_xla2/test/test_ops.py b/experimental/torch_xla2/test/test_ops.py index e6f4bb652c3..3b95476d413 100644 --- a/experimental/torch_xla2/test/test_ops.py +++ b/experimental/torch_xla2/test/test_ops.py @@ -154,7 +154,6 @@ "nn.functional.pixel_unshuffle", "nn.functional.poisson_nll_loss", "nn.functional.rrelu", - "nn.functional.scaled_dot_product_attention", "nn.functional.softmin", "nn.functional.unfold", "nn.functional.upsample_nearest", @@ -362,6 +361,8 @@ def test_reference_eager(self, device, dtype, op): # To avoid errors during testing, replace values below 1 with 1. sample_input.input = self.replace_values_below_threshold( sample_input.input, 1) + if op.name == "nn.functional.scaled_dot_product_attention": + check_output = sample_input.kwargs.get('dropout_p') == 0.0 ignore_index = op.name in should_ignore_indexes