Skip to content

Commit

Permalink
Fix op test for `nn.functional.scaled_dot_product_attention"nn.functi… (
Browse files Browse the repository at this point in the history
  • Loading branch information
guyao authored Sep 16, 2024
1 parent 0e47739 commit 3b0097d
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion experimental/torch_xla2/test/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,6 @@
"nn.functional.pixel_unshuffle",
"nn.functional.poisson_nll_loss",
"nn.functional.rrelu",
"nn.functional.scaled_dot_product_attention",
"nn.functional.softmin",
"nn.functional.unfold",
"nn.functional.upsample_nearest",
Expand Down Expand Up @@ -362,6 +361,8 @@ def test_reference_eager(self, device, dtype, op):
# To avoid errors during testing, replace values below 1 with 1.
sample_input.input = self.replace_values_below_threshold(
sample_input.input, 1)
if op.name == "nn.functional.scaled_dot_product_attention":
check_output = sample_input.kwargs.get('dropout_p') == 0.0

ignore_index = op.name in should_ignore_indexes

Expand Down

0 comments on commit 3b0097d

Please sign in to comment.