Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
ydshieh committed Nov 3, 2024
1 parent 4a7ac3d commit f214b55
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 6 deletions.
6 changes: 3 additions & 3 deletions tests/models/musicgen/test_modeling_musicgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -1479,7 +1479,7 @@ def test_sdpa_can_dispatch_composite_models(self):
@parameterized.expand([("float16",), ("bfloat16",), ("float32",)])
@require_torch_sdpa
def test_eager_matches_sdpa_inference(self, torch_dtype: str):
for _ in range(100):
for _ in range(30):
self._test_eager_matches_sdpa_inference(torch_dtype)

def _test_eager_matches_sdpa_inference(self, torch_dtype: str):
Expand Down Expand Up @@ -1508,10 +1508,10 @@ def _test_eager_matches_sdpa_inference(self, torch_dtype: str):
atols = {
("cpu", False, torch.float32): 1e-6,
("cpu", False, torch.float16): 5e-3,
("cpu", False, torch.bfloat16): 5e-2,
("cpu", False, torch.bfloat16): 1e-2,
("cpu", True, torch.float32): 1e-6,
("cpu", True, torch.float16): 5e-3,
("cpu", True, torch.bfloat16): 5e-2,
("cpu", True, torch.bfloat16): 1e-2,
("cuda", False, torch.float32): 1e-6,
("cuda", False, torch.bfloat16): 1e-2,
("cuda", False, torch.float16): 5e-3,
Expand Down
6 changes: 3 additions & 3 deletions tests/test_modeling_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -3931,7 +3931,7 @@ def test_sdpa_can_dispatch_composite_models(self):
@parameterized.expand([("float16",), ("bfloat16",), ("float32",)])
@require_torch_sdpa
def test_eager_matches_sdpa_inference(self, torch_dtype: str):
for _ in range(100):
for _ in range(30):
self._test_eager_matches_sdpa_inference(torch_dtype)

def _test_eager_matches_sdpa_inference(self, torch_dtype: str):
Expand Down Expand Up @@ -3960,10 +3960,10 @@ def _test_eager_matches_sdpa_inference(self, torch_dtype: str):
atols = {
("cpu", False, torch.float32): 1e-6,
("cpu", False, torch.float16): 5e-3,
("cpu", False, torch.bfloat16): 5e-2,
("cpu", False, torch.bfloat16): 1e-2,
("cpu", True, torch.float32): 1e-6,
("cpu", True, torch.float16): 5e-3,
("cpu", True, torch.bfloat16): 5e-2,
("cpu", True, torch.bfloat16): 1e-2,
("cuda", False, torch.float32): 1e-6,
("cuda", False, torch.bfloat16): 1e-2,
("cuda", False, torch.float16): 5e-3,
Expand Down

0 comments on commit f214b55

Please sign in to comment.