diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index e12d8651137..6ba30054dce 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -4201,7 +4201,7 @@ def get_mean_reldiff(failcase, x, ref, atol, rtol): for (_logits_sdpa, _logits_eager) in zip(logits_sdpa, logits_eager) ] # If 80% batch elements have matched results, it's fine - if np.mean(results) < 1.8: + if np.mean(results) < 0.8: fail_cases.append( get_mean_reldiff(failcase, logits_sdpa, logits_eager, atol, rtol) )