diff --git a/tests/test_functional.py b/tests/test_functional.py index 2d4e959ad..d4f65755f 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -26,12 +26,12 @@ def assert_all_approx_close(a, b, rtol=1e-3, atol=1e-3, count=0, throw=True): - idx = torch.isclose(a, b, rtol, atol) + idx = torch.isclose(a, b, rtol=rtol, atol=atol) sumval = (idx == 0).sum().item() if sumval > count: if throw: print(f"Too many values not close: assert {sumval} < {count}") - torch.testing.assert_close(a, b, rtol, atol) + torch.testing.assert_close(a, b, rtol=rtol, atol=atol) return sumval diff --git a/tests/test_modules.py b/tests/test_modules.py index f809aa791..674620e29 100644 --- a/tests/test_modules.py +++ b/tests/test_modules.py @@ -42,11 +42,11 @@ def get_args(): def assert_all_approx_close(a, b, atol=1e-8, rtol=1e-5, count=10): - idx = torch.isclose(a, b, rtol, atol) + idx = torch.isclose(a, b, rtol=rtol, atol=atol) sumval = (idx == 0).sum().item() if sumval > count: print(f"Too many values not close: assert {sumval} < {count}") - torch.testing.assert_close(a, b, rtol, atol) + torch.testing.assert_close(a, b, rtol=rtol, atol=atol) class LinearFunction(torch.autograd.Function): diff --git a/tests/test_optim.py b/tests/test_optim.py index e379c424a..9395b8820 100644 --- a/tests/test_optim.py +++ b/tests/test_optim.py @@ -145,7 +145,7 @@ def test_optimizer32bit(dim1, dim2, gtype, optim_name): # since Lion can have pretty noisy updates where things lie at the boundary # allow up to 10 errors for Lion - assert_most_approx_close(p1, p2.float(), atol, rtol, max_error_count=10) + assert_most_approx_close(p1, p2.float(), atol=atol, rtol=rtol, max_error_count=10) if i % (k // 5) == 0 and i > 0: path = get_temp_dir() @@ -157,7 +157,7 @@ def test_optimizer32bit(dim1, dim2, gtype, optim_name): rm_path(path) # since Lion can have pretty noisy updates where things lie at the boundary # allow up to 10 errors for Lion - assert_most_approx_close(p1, p2.float(), atol, rtol, max_error_count=10) + assert_most_approx_close(p1, p2.float(), atol=atol, rtol=rtol, max_error_count=10) for name1, name2 in str2statenames[optim_name]: # since Lion can have pretty noisy updates where things lie at the boundary # allow up to 10 errors for Lion