Skip to content

Commit

Permalink
more test skipping
Browse files Browse the repository at this point in the history
  • Loading branch information
dakinggg committed Oct 6, 2023
1 parent 9cd5ec4 commit 82ae824
Showing 1 changed file with 12 additions and 12 deletions.
24 changes: 12 additions & 12 deletions tests/test_lion8b.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,13 +331,15 @@ def test_lion8b_fused_unfused_unquantized_same(w_init: str, grad_strategy: str,
# at all; latter is "ground truth"
assert cossim(diffs_true, diffs_fq, dim=-1) > min_cossim
assert _nmse(diffs_true, diffs_fq) < max_nmse

# error correction not supported on torch 2.1
if version.parse(torch.__version__) < version.parse('2.1.0'):
# fused impl with errors should also be close to "true" updates;
assert cossim(diffs_true, diffs_fqe, dim=-1) > min_cossim
assert _nmse(diffs_true, diffs_fqe) < max_nmse

# fused impl with errors should also be close to "true" updates;
assert cossim(diffs_true, diffs_fqe, dim=-1) > min_cossim
assert _nmse(diffs_true, diffs_fqe) < max_nmse

# error correction should reduce error, or at least do no worse
assert _nmse(diffs_true, diffs_fqe) <= _nmse(diffs_true, diffs_fq)
# error correction should reduce error, or at least do no worse
assert _nmse(diffs_true, diffs_fqe) <= _nmse(diffs_true, diffs_fq)

# if sgd weights aren't different than LION weights, we haven't
# changed them enough to meaningfully test the LION logic
Expand Down Expand Up @@ -601,12 +603,10 @@ def _time_kernels(N: int, D: int, min_elems_traversed: int):
try:
assert times[True] < times[False] + atol
assert times[True] < times['NA'] + atol
assert times['ecc'] < times['NA'] + atol
print('')
print('time fused (ms): ', times[True] * 1e3)
print('time fused+ecc (ms): ', times['ecc'] * 1e3)
print('time unfused (ms): ', times[False] * 1e3)
print('time unquantized (ms): ', times['NA'] * 1e3)

# error correction not supported on torch 2.1
if version.parse(torch.__version__) < version.parse('2.1.0'):
assert times['ecc'] < times['NA'] + atol
break
except AssertionError as e:
if it >= 2: # allow 3 retries to avoid flakiness
Expand Down

0 comments on commit 82ae824

Please sign in to comment.