From 501c0fd751e9a7067f9554d94857ec1df0020971 Mon Sep 17 00:00:00 2001 From: beckermr Date: Fri, 13 Sep 2024 16:51:11 -0500 Subject: [PATCH] test: just do less for jax --- tests/test_interpolatedimage.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/tests/test_interpolatedimage.py b/tests/test_interpolatedimage.py index b953012fa7..57503e09a0 100644 --- a/tests/test_interpolatedimage.py +++ b/tests/test_interpolatedimage.py @@ -404,14 +404,18 @@ def test_unit_integrals(): print(str(interp)) # Compute directly with int1d n = interp.ixrange//2 + 1 - print("number of intervas: ",n) + if is_jax_galsim(): + # jax galsim is slow when doing direct integration + _n_do = min(n, 100) + else: + _n_do = n + direct_integrals = np.zeros(n) if isinstance(interp, galsim.Delta): # int1d doesn't handle this well. direct_integrals[0] = 1 else: - for k in range(n): - print(k, n) + for k in range(_n_do): direct_integrals[k] = galsim.integ.int1d(interp.xval, k-0.5, k+0.5) print('direct: ',direct_integrals) @@ -420,7 +424,7 @@ def test_unit_integrals(): print('integrals: ',len(integrals),integrals) assert len(integrals) == n - np.testing.assert_allclose(integrals, direct_integrals, atol=1.e-12) + np.testing.assert_allclose(integrals[_n_do], direct_integrals[_n_do], atol=1.e-12) if n > 10: print('n>10 for ',repr(interp))