From 9a5ab87d8045c4268f26f6d6adab410b7049710f Mon Sep 17 00:00:00 2001 From: beckermr Date: Mon, 26 Aug 2024 14:38:29 -0500 Subject: [PATCH] test: clean up random tests --- tests/test_random.py | 400 +++++++++++++++++++++++-------------------- 1 file changed, 210 insertions(+), 190 deletions(-) diff --git a/tests/test_random.py b/tests/test_random.py index d0e78fe319..5fbd90b43f 100644 --- a/tests/test_random.py +++ b/tests/test_random.py @@ -47,7 +47,7 @@ testseed = 1000 # seed used for UniformDeviate for all tests # Warning! If you change testseed, then all of the *Result variables below must change as well. -if os.environ.get("JAX_GALSIM_TESTING", "0") == "1": +if is_jax_galsim(): # the right answer for the first three uniform deviates produced from testseed uResult = (0.0160653916, 0.228817832, 0.1609966951) @@ -252,20 +252,20 @@ def test_uniform(): # Test generate u.seed(testseed) test_array = np.empty(3) - if hasattr(galsim, "_galsim"): - u.generate(test_array) - else: + if is_jax_galsim(): test_array = u.generate(test_array) + else: + u.generate(test_array) np.testing.assert_array_almost_equal( test_array, np.array(uResult), precision, err_msg='Wrong uniform random number sequence from generate.') # Test add_generate u.seed(testseed) - if hasattr(galsim, "_galsim"): - u.add_generate(test_array) - else: + if is_jax_galsim(): test_array = u.add_generate(test_array) + else: + u.add_generate(test_array) np.testing.assert_array_almost_equal( test_array, 2.*np.array(uResult), precision, err_msg='Wrong uniform random number sequence from generate.') @@ -273,20 +273,20 @@ def test_uniform(): # Test generate with a float32 array u.seed(testseed) test_array = np.empty(3, dtype=np.float32) - if hasattr(galsim, "_galsim"): - u.generate(test_array) - else: + if is_jax_galsim(): test_array = u.generate(test_array) + else: + u.generate(test_array) np.testing.assert_array_almost_equal( test_array, np.array(uResult), precisionF, err_msg='Wrong uniform random number sequence from generate.') # Test add_generate u.seed(testseed) - if hasattr(galsim, "_galsim"): - u.add_generate(test_array) - else: + if is_jax_galsim(): test_array = u.add_generate(test_array) + else: + u.add_generate(test_array) np.testing.assert_array_almost_equal( test_array, 2.*np.array(uResult), precisionF, err_msg='Wrong uniform random number sequence from generate.') @@ -297,26 +297,26 @@ def test_uniform(): v1 = np.empty(555) v2 = np.empty(555) with single_threaded(): - if hasattr(galsim, "_galsim"): - u1.generate(v1) - else: + if is_jax_galsim(): v1 = u1.generate(v1) - with single_threaded(num_threads=10): - if hasattr(galsim, "_galsim"): - u2.generate(v2) else: + u1.generate(v1) + with single_threaded(num_threads=10): + if is_jax_galsim(): v2 = u2.generate(v2) + else: + u2.generate(v2) np.testing.assert_array_equal(v1, v2) with single_threaded(): - if hasattr(galsim, "_galsim"): - u1.add_generate(v1) - else: + if is_jax_galsim(): v1 = u1.add_generate(v1) - with single_threaded(num_threads=10): - if hasattr(galsim, "_galsim"): - u2.add_generate(v2) else: + u1.add_generate(v1) + with single_threaded(num_threads=10): + if is_jax_galsim(): v2 = u2.add_generate(v2) + else: + u2.add_generate(v2) np.testing.assert_array_equal(v1, v2) # Check picklability @@ -337,8 +337,10 @@ def test_uniform(): assert u1 != u2, "Consecutive UniformDeviate(None) compared equal!" # We shouldn't be able to construct a UniformDeviate from anything but a BaseDeviate, int, str, # or None. - if hasattr(galsim, "_galsim"): + if is_jax_galsim(): # jax galsim doesn't test this + pass + else: assert_raises(TypeError, galsim.UniformDeviate, dict()) assert_raises(TypeError, galsim.UniformDeviate, list()) assert_raises(TypeError, galsim.UniformDeviate, set()) @@ -387,7 +389,11 @@ def test_gaussian(): v1,v2 = g(),g2() print('after %d vals, next one is %s, %s'%(nvals,v1,v2)) assert v1 == v2 - if hasattr(galsim, "_galsim"): + if is_jax_galsim(): + # jax doesn't have this issue + assert g.has_reliable_discard + assert not g.generates_in_pairs + else: # Note: For Gaussian, this only works if nvals is even. g2 = galsim.GaussianDeviate(testseed, mean=gMean, sigma=gSigma) g2.discard(nvals+1, suppress_warnings=True) @@ -396,14 +402,12 @@ def test_gaussian(): assert v1 != v2 assert g.has_reliable_discard assert g.generates_in_pairs - else: - # jax doesn't have this issue - assert g.has_reliable_discard - assert not g.generates_in_pairs # If don't explicitly suppress the warning, then a warning is emitted when n is odd. g2 = galsim.GaussianDeviate(testseed, mean=gMean, sigma=gSigma) - if hasattr(galsim, "_galsim"): + if is_jax_galsim(): + pass + else: # jax doesn't do this with assert_warns(galsim.GalSimWarning): g2.discard(nvals+1) @@ -476,10 +480,10 @@ def test_gaussian(): # Test generate g.seed(testseed) test_array = np.empty(3) - if hasattr(galsim, "_galsim"): - g.generate(test_array) - else: + if is_jax_galsim(): test_array = g.generate(test_array) + else: + g.generate(test_array) np.testing.assert_array_almost_equal( test_array, np.array(gResult), precision, err_msg='Wrong Gaussian random number sequence from generate.') @@ -489,29 +493,29 @@ def test_gaussian(): g3 = galsim.GaussianDeviate(testseed, mean=5, sigma=0.3) test_array = np.empty(3) test_array.fill(gSigma**2) - if hasattr(galsim, "_galsim"): - g2.generate_from_variance(test_array) - else: + if is_jax_galsim(): test_array = g2.generate_from_variance(test_array) + else: + g2.generate_from_variance(test_array) np.testing.assert_array_almost_equal( test_array, np.array(gResult)-gMean, precision, err_msg='Wrong Gaussian random number sequence from generate_from_variance.') # After running generate_from_variance, it should be back to using the specified mean, sigma. # Note: need to round up to even number for discard, since gd generates 2 at a time. - if hasattr(galsim, "_galsim"): - g3.discard((len(test_array)+1)//2 * 2) - else: + if is_jax_galsim(): g3.discard(len(test_array)) + else: + g3.discard((len(test_array)+1)//2 * 2) print('g2,g3 = ',g2(),g3()) assert g2() == g3() # Test generate with a float32 array. g.seed(testseed) test_array = np.empty(3, dtype=np.float32) - if hasattr(galsim, "_galsim"): - g.generate(test_array) - else: + if is_jax_galsim(): test_array = g.generate(test_array) + else: + g.generate(test_array) np.testing.assert_array_almost_equal( test_array, np.array(gResult), precisionF, err_msg='Wrong Gaussian random number sequence from generate.') @@ -520,10 +524,10 @@ def test_gaussian(): g2.seed(testseed) test_array = np.empty(3, dtype=np.float32) test_array.fill(gSigma**2) - if hasattr(galsim, "_galsim"): - g2.generate_from_variance(test_array) - else: + if is_jax_galsim(): test_array = g2.generate_from_variance(test_array) + else: + g2.generate_from_variance(test_array) np.testing.assert_array_almost_equal( test_array, np.array(gResult)-gMean, precisionF, err_msg='Wrong Gaussian random number sequence from generate_from_variance.') @@ -534,45 +538,45 @@ def test_gaussian(): v1 = np.empty(555) v2 = np.empty(555) with single_threaded(): - if hasattr(galsim, "_galsim"): - g1.generate(v1) - else: + if is_jax_galsim(): v1 = g1.generate(v1) - with single_threaded(num_threads=10): - if hasattr(galsim, "_galsim"): - g2.generate(v2) else: + g1.generate(v1) + with single_threaded(num_threads=10): + if is_jax_galsim(): v2 = g2.generate(v2) + else: + g2.generate(v2) np.testing.assert_array_equal(v1, v2) with single_threaded(): - if hasattr(galsim, "_galsim"): - g1.add_generate(v1) - else: + if is_jax_galsim(): v1 = g1.add_generate(v1) - with single_threaded(num_threads=10): - if hasattr(galsim, "_galsim"): - g2.add_generate(v2) else: + g1.add_generate(v1) + with single_threaded(num_threads=10): + if is_jax_galsim(): v2 = g2.add_generate(v2) + else: + g2.add_generate(v2) np.testing.assert_array_equal(v1, v2) ud = galsim.UniformDeviate(testseed + 3) ud.generate(v1) v1 += 6.7 - if hasattr(galsim, "_galsim"): - v2[:] = v1 - else: + if is_jax_galsim(): # jax galsim makes a copy v2 = v1.copy() + else: + v2[:] = v1 with single_threaded(): - if hasattr(galsim, "_galsim"): - g1.generate_from_variance(v1) - else: + if is_jax_galsim(): v1 = g1.generate_from_variance(v1) - with single_threaded(num_threads=10): - if hasattr(galsim, "_galsim"): - g2.generate_from_variance(v2) else: + g1.generate_from_variance(v1) + with single_threaded(num_threads=10): + if is_jax_galsim(): v2 = g2.generate_from_variance(v2) + else: + g2.generate_from_variance(v2) np.testing.assert_array_equal(v1, v2) # Check picklability @@ -590,7 +594,9 @@ def test_gaussian(): assert g1 != g2, "Consecutive GaussianDeviate(None) compared equal!" # We shouldn't be able to construct a GaussianDeviate from anything but a BaseDeviate, int, str, # or None. - if hasattr(galsim, "_galsim"): + if is_jax_galsim(): + pass + else: # jax-galsim doesn't test for these things assert_raises(TypeError, galsim.GaussianDeviate, dict()) assert_raises(TypeError, galsim.GaussianDeviate, list()) @@ -709,10 +715,10 @@ def test_binomial(): # Test generate b.seed(testseed) test_array = np.empty(3) - if hasattr(galsim, "_galsim"): - b.generate(test_array) - else: + if is_jax_galsim(): test_array = b.generate(test_array) + else: + b.generate(test_array) np.testing.assert_array_almost_equal( test_array, np.array(bResult), precision, err_msg='Wrong binomial random number sequence from generate.') @@ -720,10 +726,10 @@ def test_binomial(): # Test generate with an int array b.seed(testseed) test_array = np.empty(3, dtype=int) - if hasattr(galsim, "_galsim"): - b.generate(test_array) - else: + if is_jax_galsim(): test_array = b.generate(test_array) + else: + b.generate(test_array) np.testing.assert_array_almost_equal( test_array, np.array(bResult), precisionI, err_msg='Wrong binomial random number sequence from generate.') @@ -734,26 +740,26 @@ def test_binomial(): v1 = np.empty(555) v2 = np.empty(555) with single_threaded(): - if hasattr(galsim, "_galsim"): - b1.generate(v1) - else: + if is_jax_galsim(): v1 = b1.generate(v1) - with single_threaded(num_threads=10): - if hasattr(galsim, "_galsim"): - b2.generate(v2) else: + b1.generate(v1) + with single_threaded(num_threads=10): + if is_jax_galsim(): v2 = b2.generate(v2) + else: + b2.generate(v2) np.testing.assert_array_equal(v1, v2) with single_threaded(): - if hasattr(galsim, "_galsim"): - b1.add_generate(v1) - else: + if is_jax_galsim(): v1 = b1.add_generate(v1) - with single_threaded(num_threads=10): - if hasattr(galsim, "_galsim"): - b2.add_generate(v2) else: + b1.add_generate(v1) + with single_threaded(num_threads=10): + if is_jax_galsim(): v2 = b2.add_generate(v2) + else: + b2.add_generate(v2) np.testing.assert_array_equal(v1, v2) # Check picklability @@ -771,7 +777,9 @@ def test_binomial(): assert b1 != b2, "Consecutive BinomialDeviate(None) compared equal!" # We shouldn't be able to construct a BinomialDeviate from anything but a BaseDeviate, int, str, # or None. - if hasattr(galsim, "_galsim"): + if is_jax_galsim(): + pass + else: # jax does not raise for this assert_raises(TypeError, galsim.BinomialDeviate, dict()) assert_raises(TypeError, galsim.BinomialDeviate, list()) @@ -829,23 +837,23 @@ def test_poisson(): p2.discard(nvals, suppress_warnings=True) v1,v2 = p(),p2() print('With mean = %d, after %d vals, next one is %s, %s'%(high_mean,nvals,v1,v2)) - if hasattr(galsim, "_galsim"): - assert v1 != v2 - assert not p.has_reliable_discard - else: + if is_jax_galsim(): # jax always discards reliably assert v1 == v2 assert p.has_reliable_discard + else: + assert v1 != v2 + assert not p.has_reliable_discard assert not p.generates_in_pairs # Discard normally emits a warning for Poisson p2 = galsim.PoissonDeviate(testseed, mean=pMean) - if hasattr(galsim, "_galsim"): - with assert_warns(galsim.GalSimWarning): - p2.discard(nvals) - else: + if is_jax_galsim(): # jax always discards reliably p2.discard(nvals) + else: + with assert_warns(galsim.GalSimWarning): + p2.discard(nvals) # Check seed, reset p = galsim.PoissonDeviate(testseed, mean=pMean) @@ -915,10 +923,10 @@ def test_poisson(): # Test generate p.seed(testseed) test_array = np.empty(3) - if hasattr(galsim, "_galsim"): - p.generate(test_array) - else: + if is_jax_galsim(): test_array = p.generate(test_array) + else: + p.generate(test_array) np.testing.assert_array_almost_equal( test_array, np.array(pResult), precision, err_msg='Wrong poisson random number sequence from generate.') @@ -926,10 +934,10 @@ def test_poisson(): # Test generate with an int array p.seed(testseed) test_array = np.empty(3, dtype=int) - if hasattr(galsim, "_galsim"): - p.generate(test_array) - else: + if is_jax_galsim(): test_array = p.generate(test_array) + else: + p.generate(test_array) np.testing.assert_array_almost_equal( test_array, np.array(pResult), precisionI, err_msg='Wrong poisson random number sequence from generate.') @@ -937,10 +945,10 @@ def test_poisson(): # Test generate_from_expectation p2 = galsim.PoissonDeviate(testseed, mean=77) test_array = np.array([pMean]*3, dtype=int) - if hasattr(galsim, "_galsim"): - p2.generate_from_expectation(test_array) - else: + if is_jax_galsim(): test_array = p2.generate_from_expectation(test_array) + else: + p2.generate_from_expectation(test_array) np.testing.assert_array_almost_equal( test_array, np.array(pResult), precisionI, err_msg='Wrong poisson random number sequence from generate_from_expectation.') @@ -957,26 +965,26 @@ def test_poisson(): v1 = np.empty(555) v2 = np.empty(555) with single_threaded(): - if hasattr(galsim, "_galsim"): - p1.generate(v1) - else: + if is_jax_galsim(): v1 = p1.generate(v1) - with single_threaded(num_threads=10): - if hasattr(galsim, "_galsim"): - p2.generate(v2) else: + p1.generate(v1) + with single_threaded(num_threads=10): + if is_jax_galsim(): v2 = p2.generate(v2) + else: + p2.generate(v2) np.testing.assert_array_equal(v1, v2) with single_threaded(): - if hasattr(galsim, "_galsim"): - p1.add_generate(v1) - else: + if is_jax_galsim(): v1 = p1.add_generate(v1) - with single_threaded(num_threads=10): - if hasattr(galsim, "_galsim"): - p2.add_generate(v2) else: + p1.add_generate(v1) + with single_threaded(num_threads=10): + if is_jax_galsim(): v2 = p2.add_generate(v2) + else: + p2.add_generate(v2) np.testing.assert_array_equal(v1, v2) # Check picklability @@ -994,7 +1002,9 @@ def test_poisson(): assert p1 != p2, "Consecutive PoissonDeviate(None) compared equal!" # We shouldn't be able to construct a PoissonDeviate from anything but a BaseDeviate, int, str, # or None. - if hasattr(galsim, "_galsim"): + if is_jax_galsim(): + pass + else: assert_raises(TypeError, galsim.PoissonDeviate, dict()) assert_raises(TypeError, galsim.PoissonDeviate, list()) assert_raises(TypeError, galsim.PoissonDeviate, set()) @@ -1129,20 +1139,20 @@ def test_poisson_zeromean(): # Test generate test_array = np.empty(3, dtype=int) - if hasattr(galsim, "_galsim"): - p.generate(test_array) - else: + if is_jax_galsim(): test_array = p.generate(test_array) - np.testing.assert_array_equal(test_array, 0) - if hasattr(galsim, "_galsim"): - p2.generate(test_array) else: - test_array = p2.generate(test_array) + p.generate(test_array) np.testing.assert_array_equal(test_array, 0) - if hasattr(galsim, "_galsim"): - p3.generate(test_array) + if is_jax_galsim(): + test_array = p2.generate(test_array) else: + p2.generate(test_array) + np.testing.assert_array_equal(test_array, 0) + if is_jax_galsim(): test_array = p3.generate(test_array) + else: + p3.generate(test_array) np.testing.assert_array_equal(test_array, 0) # Test generate_from_expectation @@ -1155,7 +1165,9 @@ def test_poisson_zeromean(): # Error raised if mean<0 # jax doesn't raise here - if hasattr(galsim, "_galsim"): + if is_jax_galsim(): + pass + else: with assert_raises(ValueError): p = galsim.PoissonDeviate(testseed, mean=-0.1) with assert_raises(ValueError): @@ -1277,10 +1289,10 @@ def test_weibull(): # Test generate w.seed(testseed) test_array = np.empty(3) - if hasattr(galsim, "_galsim"): - w.generate(test_array) - else: + if is_jax_galsim(): test_array = w.generate(test_array) + else: + w.generate(test_array) np.testing.assert_array_almost_equal( test_array, np.array(wResult), precision, err_msg='Wrong weibull random number sequence from generate.') @@ -1288,10 +1300,10 @@ def test_weibull(): # Test generate with a float32 array w.seed(testseed) test_array = np.empty(3, dtype=np.float32) - if hasattr(galsim, "_galsim"): - w.generate(test_array) - else: + if is_jax_galsim(): test_array = w.generate(test_array) + else: + w.generate(test_array) np.testing.assert_array_almost_equal( test_array, np.array(wResult), precisionF, err_msg='Wrong weibull random number sequence from generate.') @@ -1302,26 +1314,26 @@ def test_weibull(): v1 = np.empty(555) v2 = np.empty(555) with single_threaded(): - if hasattr(galsim, "_galsim"): - w1.generate(v1) - else: + if is_jax_galsim(): v1 = w1.generate(v1) - with single_threaded(num_threads=10): - if hasattr(galsim, "_galsim"): - w2.generate(v2) else: + w1.generate(v1) + with single_threaded(num_threads=10): + if is_jax_galsim(): v2 = w2.generate(v2) + else: + w2.generate(v2) np.testing.assert_array_equal(v1, v2) with single_threaded(): - if hasattr(galsim, "_galsim"): - w1.add_generate(v1) - else: + if is_jax_galsim(): v1 = w1.add_generate(v1) - with single_threaded(num_threads=10): - if hasattr(galsim, "_galsim"): - w2.add_generate(v2) else: + w1.add_generate(v1) + with single_threaded(num_threads=10): + if is_jax_galsim(): v2 = w2.add_generate(v2) + else: + w2.add_generate(v2) np.testing.assert_array_equal(v1, v2) # Check picklability @@ -1339,7 +1351,9 @@ def test_weibull(): assert w1 != w2, "Consecutive WeibullDeviate(None) compared equal!" # We shouldn't be able to construct a WeibullDeviate from anything but a BaseDeviate, int, str, # or None. - if hasattr(galsim, "_galsim"): + if is_jax_galsim(): + pass + else: assert_raises(TypeError, galsim.WeibullDeviate, dict()) assert_raises(TypeError, galsim.WeibullDeviate, list()) assert_raises(TypeError, galsim.WeibullDeviate, set()) @@ -1385,22 +1399,22 @@ def test_gamma(): v1,v2 = g(),g2() print('after %d vals, next one is %s, %s'%(nvals,v1,v2)) # Gamma uses at least 2 rngs per value, but can use arbitrarily more than this. - if hasattr(galsim, "_galsim"): - assert v1 != v2 - assert not g.has_reliable_discard - else: + if is_jax_galsim(): assert v1 == v2 assert g.has_reliable_discard + else: + assert v1 != v2 + assert not g.has_reliable_discard assert not g.generates_in_pairs # Discard normally emits a warning for Gamma g2 = galsim.GammaDeviate(testseed, k=gammaK, theta=gammaTheta) - if hasattr(galsim, "_galsim"): - with assert_warns(galsim.GalSimWarning): - g2.discard(nvals) - else: + if is_jax_galsim(): # jax always discards reliably g2.discard(nvals) + else: + with assert_warns(galsim.GalSimWarning): + g2.discard(nvals) # Check seed, reset g.seed(testseed) @@ -1467,10 +1481,10 @@ def test_gamma(): # Test generate g.seed(testseed) test_array = np.empty(3) - if hasattr(galsim, "_galsim"): - g.generate(test_array) - else: + if is_jax_galsim(): test_array = g.generate(test_array) + else: + g.generate(test_array) np.testing.assert_array_almost_equal( test_array, np.array(gammaResult), precision, err_msg='Wrong gamma random number sequence from generate.') @@ -1478,10 +1492,10 @@ def test_gamma(): # Test generate with a float32 array g.seed(testseed) test_array = np.empty(3, dtype=np.float32) - if hasattr(galsim, "_galsim"): - g.generate(test_array) - else: + if is_jax_galsim(): test_array = g.generate(test_array) + else: + g.generate(test_array) np.testing.assert_array_almost_equal( test_array, np.array(gammaResult), precisionF, err_msg='Wrong gamma random number sequence from generate.') @@ -1501,7 +1515,9 @@ def test_gamma(): assert g1 != g2, "Consecutive GammaDeviate(None) compared equal!" # We shouldn't be able to construct a GammaDeviate from anything but a BaseDeviate, int, str, # or None. - if hasattr(galsim, "_galsim"): + if is_jax_galsim(): + pass + else: assert_raises(TypeError, galsim.GammaDeviate, dict()) assert_raises(TypeError, galsim.GammaDeviate, list()) assert_raises(TypeError, galsim.GammaDeviate, set()) @@ -1547,22 +1563,22 @@ def test_chi2(): v1,v2 = c(),c2() print('after %d vals, next one is %s, %s'%(nvals,v1,v2)) # Chi2 uses at least 2 rngs per value, but can use arbitrarily more than this. - if hasattr(galsim, "_galsim"): - assert v1 != v2 - assert not c.has_reliable_discard - else: + if is_jax_galsim(): assert v1 == v2 assert c.has_reliable_discard + else: + assert v1 != v2 + assert not c.has_reliable_discard assert not c.generates_in_pairs # Discard normally emits a warning for Chi2 c2 = galsim.Chi2Deviate(testseed, n=chi2N) - if hasattr(galsim, "_galsim"): - with assert_warns(galsim.GalSimWarning): - c2.discard(nvals) - else: + if is_jax_galsim(): # jax always discards reliably c2.discard(nvals) + else: + with assert_warns(galsim.GalSimWarning): + c2.discard(nvals) # Check seed, reset c.seed(testseed) @@ -1629,10 +1645,10 @@ def test_chi2(): # Test generate c.seed(testseed) test_array = np.empty(3) - if hasattr(galsim, "_galsim"): - c.generate(test_array) - else: + if is_jax_galsim(): test_array = c.generate(test_array) + else: + c.generate(test_array) np.testing.assert_array_almost_equal( test_array, np.array(chi2Result), precision, err_msg='Wrong Chi^2 random number sequence from generate.') @@ -1640,10 +1656,10 @@ def test_chi2(): # Test generate with a float32 array c.seed(testseed) test_array = np.empty(3, dtype=np.float32) - if hasattr(galsim, "_galsim"): - c.generate(test_array) - else: + if is_jax_galsim(): test_array = c.generate(test_array) + else: + c.generate(test_array) np.testing.assert_array_almost_equal( test_array, np.array(chi2Result), precisionF, err_msg='Wrong Chi^2 random number sequence from generate.') @@ -1663,7 +1679,9 @@ def test_chi2(): assert c1 != c2, "Consecutive Chi2Deviate(None) compared equal!" # We shouldn't be able to construct a Chi2Deviate from anything but a BaseDeviate, int, str, # or None. - if hasattr(galsim, "_galsim"): + if is_jax_galsim(): + pass + else: assert_raises(TypeError, galsim.Chi2Deviate, dict()) assert_raises(TypeError, galsim.Chi2Deviate, list()) assert_raises(TypeError, galsim.Chi2Deviate, set()) @@ -2150,11 +2168,11 @@ def test_permute(): ind_list = list(range(n_list)) # Permute both at the same time. - if hasattr(galsim, "_galsim"): - galsim.random.permute(312, my_list, ind_list) - else: + if is_jax_galsim(): # jax requires arrays galsim.random.permute(312, np.array(my_list), np.array(ind_list)) + else: + galsim.random.permute(312, my_list, ind_list) # Make sure that everything is sensible for ind in range(n_list): @@ -2162,16 +2180,18 @@ def test_permute(): # Repeat with same seed, should do same permutation. my_list = copy.deepcopy(my_list_copy) - if hasattr(galsim, "_galsim"): - galsim.random.permute(312, my_list) - else: + if is_jax_galsim(): galsim.random.permute(312, np.array(my_list)) + else: + galsim.random.permute(312, my_list) for ind in range(n_list): assert my_list_copy[ind_list[ind]] == my_list[ind] # permute with no lists should raise TypeError # jax galsim does not raise - if hasattr(galsim, "_galsim"): + if is_jax_galsim(): + pass + else: with assert_raises(TypeError): galsim.random.permute(312) @@ -2181,16 +2201,16 @@ def test_ne(): """ Check that inequality works as expected for corner cases where the reprs of two unequal BaseDeviates may be the same due to truncation. """ - if hasattr(galsim, "_galsim"): - a = galsim.BaseDeviate(seed='1 2 3 4 5 6 7 8 9 10') - b = galsim.BaseDeviate(seed='1 2 3 7 6 5 4 8 9 10') - assert repr(a) == repr(b) - assert a != b - else: + if is_jax_galsim(): a = galsim.BaseDeviate(seed="(0, 10)") b = galsim.BaseDeviate(seed="(0, 11)") assert repr(a) != repr(b) assert a != b + else: + a = galsim.BaseDeviate(seed='1 2 3 4 5 6 7 8 9 10') + b = galsim.BaseDeviate(seed='1 2 3 7 6 5 4 8 9 10') + assert repr(a) == repr(b) + assert a != b # Check DistDeviate separately, since it overrides __repr__ and __eq__ d1 = galsim.DistDeviate(seed=a, function=galsim.LookupTable([1, 2, 3], [4, 5, 6]))