diff --git a/tests/test_draw.py b/tests/test_draw.py index 55dc7ab55a..07cbc8aad3 100644 --- a/tests/test_draw.py +++ b/tests/test_draw.py @@ -380,7 +380,9 @@ def test_drawImage(): os.path.join(os.path.dirname(__file__), 'fits_files/tpv.fits'))) assert_raises(ValueError, obj.drawImage, bounds=galsim.BoundsI()) - if hasattr(galsim, "_galsim"): + if is_jax_galsim(): + pass + else: assert_raises(ValueError, obj.drawImage, image=im10, gain=0.) assert_raises(ValueError, obj.drawImage, image=im10, gain=-1.) assert_raises(ValueError, obj.drawImage, image=im10, area=0.) @@ -392,7 +394,9 @@ def test_drawImage(): # These options are invalid unless metho=phot assert_raises(TypeError, obj.drawImage, image=im10, n_photons=3) assert_raises(TypeError, obj.drawImage, rng=galsim.BaseDeviate(234)) - if hasattr(galsim, "_galsim"): + if is_jax_galsim(): + pass + else: assert_raises(TypeError, obj.drawImage, max_extra_noise=23) assert_raises(TypeError, obj.drawImage, poisson_flux=True) assert_raises(TypeError, obj.drawImage, maxN=10000) @@ -523,14 +527,14 @@ def test_drawKImage(): """Test the various optional parameters to the drawKImage function. In particular test the parameters image, and scale in various combinations. """ - if hasattr(galsim, "_galsim"): - maxk_threshold = 1.e-4 - N = 1174 - Ns = 37 - else: + if is_jax_galsim(): maxk_threshold = 1.e-3 N = 880 Ns = 28 + else: + maxk_threshold = 1.e-4 + N = 1174 + Ns = 37 # We use a Moffat profile with beta = 1.5, since its real-space profile is # flux / (2 pi rD^2) * (1 + (r/rD)^2)^3/2 @@ -1109,31 +1113,31 @@ def test_shoot(): # in exact arithmetic. We had an assert there which blew up in a not very nice way. obj = galsim.Gaussian(sigma=0.2398318) + 0.1*galsim.Gaussian(sigma=0.47966352) obj = obj.withFlux(100001) - if hasattr(galsim, "_galsim"): - image1 = galsim.ImageF(32,32, init_value=100) - else: + if is_jax_galsim(): # jax galsim needs double images here image1 = galsim.ImageD(32,32, init_value=100) + else: + image1 = galsim.ImageF(32,32, init_value=100) rng = galsim.BaseDeviate(1234) obj.drawImage(image1, method='phot', poisson_flux=False, add_to_image=True, rng=rng, maxN=100000) # The test here is really just that it doesn't crash. # But let's do something to check correctness. - if hasattr(galsim, "_galsim"): - image2 = galsim.ImageF(32,32) - else: + if is_jax_galsim(): # jax galsim needs double images here image2 = galsim.ImageD(32,32) + else: + image2 = galsim.ImageF(32,32) rng = galsim.BaseDeviate(1234) obj.drawImage(image2, method='phot', poisson_flux=False, add_to_image=False, rng=rng, maxN=100000) image2 += 100 - if hasattr(galsim, "_galsim"): - np.testing.assert_array_almost_equal(image2.array, image1.array, decimal=12) - else: + if is_jax_galsim(): # jax galsim works not as well np.testing.assert_array_almost_equal(image2.array, image1.array, decimal=10) + else: + np.testing.assert_array_almost_equal(image2.array, image1.array, decimal=12) # Also check that you get the same answer with a smaller maxN. image3 = galsim.ImageF(32,32, init_value=100) @@ -1148,7 +1152,9 @@ def test_shoot(): # Warns if flux is 1 and n_photons not given. psf = galsim.Gaussian(sigma=3) - if hasattr(galsim, "_galsim"): + if is_jax_galsim(): + pass + else: with assert_warns(galsim.GalSimWarning): psf.drawImage(method='phot') with assert_warns(galsim.GalSimWarning): @@ -1213,20 +1219,26 @@ def test_drawImage_area_exptime(): # Shooting with flux=1 raises a warning. obj1 = obj.withFlux(1) - if hasattr(galsim, "_galsim"): + if is_jax_galsim(): + pass + else: with assert_warns(galsim.GalSimWarning): obj1.drawImage(method='phot') # But not if we explicitly tell it to shoot 1 photon with assert_raises(AssertionError): assert_warns(galsim.GalSimWarning, obj1.drawImage, method='phot', n_photons=1) # Likewise for makePhot - if hasattr(galsim, "_galsim"): + if is_jax_galsim(): + pass + else: with assert_warns(galsim.GalSimWarning): obj1.makePhot() with assert_raises(AssertionError): assert_warns(galsim.GalSimWarning, obj1.makePhot, n_photons=1) # And drawPhot - if hasattr(galsim, "_galsim"): + if is_jax_galsim(): + pass + else: with assert_warns(galsim.GalSimWarning): obj1.drawPhot(im1) with assert_raises(AssertionError): @@ -1247,7 +1259,11 @@ def test_fft(): [4,6,8,4], [2,4,6,6] ], xmin=-2, ymin=-2, dtype=dt, scale=0.1) - if hasattr(galsim, "_galsim") or dt not in [np.complex128, complex]: + if is_gjax_galsim() and dt not in [np.complex128, complex]: + kim = xim.calculate_fft() + xim2 = kim.calculate_inverse_fft() + np.testing.assert_array_almost_equal(xim.array, xim2.array) + else: kim = xim.calculate_fft() xim2 = kim.calculate_inverse_fft() np.testing.assert_array_almost_equal(xim.array, xim2.array) @@ -1282,7 +1298,11 @@ def test_fft(): xim2 = galsim.Image([ [2,4,6], [4,6,8] ], xmin=-2, ymin=-1, dtype=dt, scale=0.1) - if hasattr(galsim, "_galsim") or dt not in [np.complex128, complex]: + if is_gjax_galsim() and dt not in [np.complex128, complex]: + kim = xim.calculate_fft() + kim2 = xim2.calculate_fft() + np.testing.assert_array_almost_equal(kim.array, kim2.array) + else: kim = xim.calculate_fft() kim2 = xim2.calculate_fft() np.testing.assert_array_almost_equal(kim.array, kim2.array) @@ -1305,14 +1325,14 @@ def test_fft(): # Now use drawKImage (as above in test_drawKImage) to get a more realistic k-space image # NB. It is useful to have this come out not a multiple of 4, since some of the # calculation needs to be different when N/2 is odd. - if hasattr(galsim, "_galsim"): - maxk_threshold = 1.e-4 - N = 1174 - Nfft = 1536 - else: + if is_jax_galsim(): maxk_threshold = 0.78e-3 N = 912 Nfft = 1024 + else: + maxk_threshold = 1.e-4 + N = 1174 + Nfft = 1536 obj = galsim.Moffat(flux=test_flux, beta=1.5, scale_radius=0.5) obj = obj.withGSParams(maxk_threshold=maxk_threshold) im1 = obj.drawKImage() @@ -1562,9 +1582,12 @@ def test_np_fft(): def round_cast(array, dt): # array.astype(dt) doesn't round to the nearest for integer types. # This rounds first if dt is integer and then casts. - # NOTE JAX doesn't round to the nearest int when drawing - if hasattr(galsim, "_galsim") and dt(0.5) != 0.5: - array = np.around(array) + if is_jax_galsim(): + # NOTE JAX doesn't round to the nearest int when drawing + pass + else: + if dt(0.5) != 0.5: + array = np.around(array) return array.astype(dt) @timer