diff --git a/tests/test_lion8b.py b/tests/test_lion8b.py index 312480f61a..018ce5f7cd 100644 --- a/tests/test_lion8b.py +++ b/tests/test_lion8b.py @@ -331,13 +331,15 @@ def test_lion8b_fused_unfused_unquantized_same(w_init: str, grad_strategy: str, # at all; latter is "ground truth" assert cossim(diffs_true, diffs_fq, dim=-1) > min_cossim assert _nmse(diffs_true, diffs_fq) < max_nmse + + # error correction not supported on torch 2.1 + if version.parse(torch.__version__) < version.parse('2.1.0'): + # fused impl with errors should also be close to "true" updates; + assert cossim(diffs_true, diffs_fqe, dim=-1) > min_cossim + assert _nmse(diffs_true, diffs_fqe) < max_nmse - # fused impl with errors should also be close to "true" updates; - assert cossim(diffs_true, diffs_fqe, dim=-1) > min_cossim - assert _nmse(diffs_true, diffs_fqe) < max_nmse - - # error correction should reduce error, or at least do no worse - assert _nmse(diffs_true, diffs_fqe) <= _nmse(diffs_true, diffs_fq) + # error correction should reduce error, or at least do no worse + assert _nmse(diffs_true, diffs_fqe) <= _nmse(diffs_true, diffs_fq) # if sgd weights aren't different than LION weights, we haven't # changed them enough to meaningfully test the LION logic @@ -601,12 +603,10 @@ def _time_kernels(N: int, D: int, min_elems_traversed: int): try: assert times[True] < times[False] + atol assert times[True] < times['NA'] + atol - assert times['ecc'] < times['NA'] + atol - print('') - print('time fused (ms): ', times[True] * 1e3) - print('time fused+ecc (ms): ', times['ecc'] * 1e3) - print('time unfused (ms): ', times[False] * 1e3) - print('time unquantized (ms): ', times['NA'] * 1e3) + + # error correction not supported on torch 2.1 + if version.parse(torch.__version__) < version.parse('2.1.0'): + assert times['ecc'] < times['NA'] + atol break except AssertionError as e: if it >= 2: # allow 3 retries to avoid flakiness