From d11b5068dd74de6694cea0cce350bc86eb2ba5b2 Mon Sep 17 00:00:00 2001 From: Titus <9048635+Titus-von-Koeller@users.noreply.github.com> Date: Wed, 21 Feb 2024 16:27:46 +0100 Subject: [PATCH] tests: fix all_close to respect max 2 positional args (#1074) --- tests/test_functional.py | 4 ++-- tests/test_modules.py | 4 ++-- tests/test_optim.py | 4 ++-- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/test_functional.py b/tests/test_functional.py index 2d4e959ad..d4f65755f 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -26,12 +26,12 @@ def assert_all_approx_close(a, b, rtol=1e-3, atol=1e-3, count=0, throw=True): - idx = torch.isclose(a, b, rtol, atol) + idx = torch.isclose(a, b, rtol=rtol, atol=atol) sumval = (idx == 0).sum().item() if sumval > count: if throw: print(f"Too many values not close: assert {sumval} < {count}") - torch.testing.assert_close(a, b, rtol, atol) + torch.testing.assert_close(a, b, rtol=rtol, atol=atol) return sumval diff --git a/tests/test_modules.py b/tests/test_modules.py index f809aa791..674620e29 100644 --- a/tests/test_modules.py +++ b/tests/test_modules.py @@ -42,11 +42,11 @@ def get_args(): def assert_all_approx_close(a, b, atol=1e-8, rtol=1e-5, count=10): - idx = torch.isclose(a, b, rtol, atol) + idx = torch.isclose(a, b, rtol=rtol, atol=atol) sumval = (idx == 0).sum().item() if sumval > count: print(f"Too many values not close: assert {sumval} < {count}") - torch.testing.assert_close(a, b, rtol, atol) + torch.testing.assert_close(a, b, rtol=rtol, atol=atol) class LinearFunction(torch.autograd.Function): diff --git a/tests/test_optim.py b/tests/test_optim.py index e379c424a..9395b8820 100644 --- a/tests/test_optim.py +++ b/tests/test_optim.py @@ -145,7 +145,7 @@ def test_optimizer32bit(dim1, dim2, gtype, optim_name): # since Lion can have pretty noisy updates where things lie at the boundary # allow up to 10 errors for Lion - assert_most_approx_close(p1, p2.float(), atol, rtol, max_error_count=10) + assert_most_approx_close(p1, p2.float(), atol=atol, rtol=rtol, max_error_count=10) if i % (k // 5) == 0 and i > 0: path = get_temp_dir() @@ -157,7 +157,7 @@ def test_optimizer32bit(dim1, dim2, gtype, optim_name): rm_path(path) # since Lion can have pretty noisy updates where things lie at the boundary # allow up to 10 errors for Lion - assert_most_approx_close(p1, p2.float(), atol, rtol, max_error_count=10) + assert_most_approx_close(p1, p2.float(), atol=atol, rtol=rtol, max_error_count=10) for name1, name2 in str2statenames[optim_name]: # since Lion can have pretty noisy updates where things lie at the boundary # allow up to 10 errors for Lion