Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
ydshieh committed Nov 3, 2024
1 parent ee9033b commit e4e2ece
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions tests/test_modeling_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -3987,7 +3987,7 @@ def _test_eager_matches_sdpa_inference(self, torch_dtype: str):
}

def get_mean_reldiff(failcase, x, ref, atol, rtol):
return f"{failcase}: mean relative difference: {((x - ref).abs() / (ref.abs() + 1e-12)).mean():.3e}, torch atol = {atol}, torch rtol = {rtol}"
return f"{failcase}: max abs difference: {torch.amax(torch.abs(x - ref)):.6e}"

if hasattr(self.model_tester, "num_hidden_layers"):
self.model_tester.num_hidden_layers = 1
Expand Down Expand Up @@ -4201,7 +4201,7 @@ def get_mean_reldiff(failcase, x, ref, atol, rtol):
for (_logits_sdpa, _logits_eager) in zip(logits_sdpa, logits_eager)
]
# If 80% batch elements have matched results, it's fine
if np.mean(results) < 0.8:
if np.mean(results) < 1.8:
fail_cases.append(
get_mean_reldiff(failcase, logits_sdpa, logits_eager, atol, rtol)
)
Expand Down

0 comments on commit e4e2ece

Please sign in to comment.