Skip to content

Commit

Permalink
test: clean up tests for draw
Browse files Browse the repository at this point in the history
  • Loading branch information
beckermr committed Aug 23, 2024
1 parent 8454028 commit 101a36e
Showing 1 changed file with 53 additions and 30 deletions.
83 changes: 53 additions & 30 deletions tests/test_draw.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,7 +380,9 @@ def test_drawImage():
os.path.join(os.path.dirname(__file__), 'fits_files/tpv.fits')))

assert_raises(ValueError, obj.drawImage, bounds=galsim.BoundsI())
if hasattr(galsim, "_galsim"):
if is_jax_galsim():
pass
else:
assert_raises(ValueError, obj.drawImage, image=im10, gain=0.)
assert_raises(ValueError, obj.drawImage, image=im10, gain=-1.)
assert_raises(ValueError, obj.drawImage, image=im10, area=0.)
Expand All @@ -392,7 +394,9 @@ def test_drawImage():
# These options are invalid unless metho=phot
assert_raises(TypeError, obj.drawImage, image=im10, n_photons=3)
assert_raises(TypeError, obj.drawImage, rng=galsim.BaseDeviate(234))
if hasattr(galsim, "_galsim"):
if is_jax_galsim():
pass
else:
assert_raises(TypeError, obj.drawImage, max_extra_noise=23)
assert_raises(TypeError, obj.drawImage, poisson_flux=True)
assert_raises(TypeError, obj.drawImage, maxN=10000)
Expand Down Expand Up @@ -523,14 +527,14 @@ def test_drawKImage():
"""Test the various optional parameters to the drawKImage function.
In particular test the parameters image, and scale in various combinations.
"""
if hasattr(galsim, "_galsim"):
maxk_threshold = 1.e-4
N = 1174
Ns = 37
else:
if is_jax_galsim():
maxk_threshold = 1.e-3
N = 880
Ns = 28
else:
maxk_threshold = 1.e-4
N = 1174
Ns = 37

# We use a Moffat profile with beta = 1.5, since its real-space profile is
# flux / (2 pi rD^2) * (1 + (r/rD)^2)^3/2
Expand Down Expand Up @@ -1109,31 +1113,31 @@ def test_shoot():
# in exact arithmetic. We had an assert there which blew up in a not very nice way.
obj = galsim.Gaussian(sigma=0.2398318) + 0.1*galsim.Gaussian(sigma=0.47966352)
obj = obj.withFlux(100001)
if hasattr(galsim, "_galsim"):
image1 = galsim.ImageF(32,32, init_value=100)
else:
if is_jax_galsim():
# jax galsim needs double images here
image1 = galsim.ImageD(32,32, init_value=100)
else:
image1 = galsim.ImageF(32,32, init_value=100)
rng = galsim.BaseDeviate(1234)
obj.drawImage(image1, method='phot', poisson_flux=False, add_to_image=True, rng=rng,
maxN=100000)

# The test here is really just that it doesn't crash.
# But let's do something to check correctness.
if hasattr(galsim, "_galsim"):
image2 = galsim.ImageF(32,32)
else:
if is_jax_galsim():
# jax galsim needs double images here
image2 = galsim.ImageD(32,32)
else:
image2 = galsim.ImageF(32,32)
rng = galsim.BaseDeviate(1234)
obj.drawImage(image2, method='phot', poisson_flux=False, add_to_image=False, rng=rng,
maxN=100000)
image2 += 100
if hasattr(galsim, "_galsim"):
np.testing.assert_array_almost_equal(image2.array, image1.array, decimal=12)
else:
if is_jax_galsim():
# jax galsim works not as well
np.testing.assert_array_almost_equal(image2.array, image1.array, decimal=10)
else:
np.testing.assert_array_almost_equal(image2.array, image1.array, decimal=12)

# Also check that you get the same answer with a smaller maxN.
image3 = galsim.ImageF(32,32, init_value=100)
Expand All @@ -1148,7 +1152,9 @@ def test_shoot():

# Warns if flux is 1 and n_photons not given.
psf = galsim.Gaussian(sigma=3)
if hasattr(galsim, "_galsim"):
if is_jax_galsim():
pass
else:
with assert_warns(galsim.GalSimWarning):
psf.drawImage(method='phot')
with assert_warns(galsim.GalSimWarning):
Expand Down Expand Up @@ -1213,20 +1219,26 @@ def test_drawImage_area_exptime():

# Shooting with flux=1 raises a warning.
obj1 = obj.withFlux(1)
if hasattr(galsim, "_galsim"):
if is_jax_galsim():
pass
else:
with assert_warns(galsim.GalSimWarning):
obj1.drawImage(method='phot')
# But not if we explicitly tell it to shoot 1 photon
with assert_raises(AssertionError):
assert_warns(galsim.GalSimWarning, obj1.drawImage, method='phot', n_photons=1)
# Likewise for makePhot
if hasattr(galsim, "_galsim"):
if is_jax_galsim():
pass
else:
with assert_warns(galsim.GalSimWarning):
obj1.makePhot()
with assert_raises(AssertionError):
assert_warns(galsim.GalSimWarning, obj1.makePhot, n_photons=1)
# And drawPhot
if hasattr(galsim, "_galsim"):
if is_jax_galsim():
pass
else:
with assert_warns(galsim.GalSimWarning):
obj1.drawPhot(im1)
with assert_raises(AssertionError):
Expand All @@ -1247,7 +1259,11 @@ def test_fft():
[4,6,8,4],
[2,4,6,6] ],
xmin=-2, ymin=-2, dtype=dt, scale=0.1)
if hasattr(galsim, "_galsim") or dt not in [np.complex128, complex]:
if is_gjax_galsim() and dt not in [np.complex128, complex]:
kim = xim.calculate_fft()
xim2 = kim.calculate_inverse_fft()
np.testing.assert_array_almost_equal(xim.array, xim2.array)
else:
kim = xim.calculate_fft()
xim2 = kim.calculate_inverse_fft()
np.testing.assert_array_almost_equal(xim.array, xim2.array)
Expand Down Expand Up @@ -1282,7 +1298,11 @@ def test_fft():
xim2 = galsim.Image([ [2,4,6],
[4,6,8] ],
xmin=-2, ymin=-1, dtype=dt, scale=0.1)
if hasattr(galsim, "_galsim") or dt not in [np.complex128, complex]:
if is_gjax_galsim() and dt not in [np.complex128, complex]:
kim = xim.calculate_fft()
kim2 = xim2.calculate_fft()
np.testing.assert_array_almost_equal(kim.array, kim2.array)
else:
kim = xim.calculate_fft()
kim2 = xim2.calculate_fft()
np.testing.assert_array_almost_equal(kim.array, kim2.array)
Expand All @@ -1305,14 +1325,14 @@ def test_fft():
# Now use drawKImage (as above in test_drawKImage) to get a more realistic k-space image
# NB. It is useful to have this come out not a multiple of 4, since some of the
# calculation needs to be different when N/2 is odd.
if hasattr(galsim, "_galsim"):
maxk_threshold = 1.e-4
N = 1174
Nfft = 1536
else:
if is_jax_galsim():
maxk_threshold = 0.78e-3
N = 912
Nfft = 1024
else:
maxk_threshold = 1.e-4
N = 1174
Nfft = 1536
obj = galsim.Moffat(flux=test_flux, beta=1.5, scale_radius=0.5)
obj = obj.withGSParams(maxk_threshold=maxk_threshold)
im1 = obj.drawKImage()
Expand Down Expand Up @@ -1562,9 +1582,12 @@ def test_np_fft():
def round_cast(array, dt):
# array.astype(dt) doesn't round to the nearest for integer types.
# This rounds first if dt is integer and then casts.
# NOTE JAX doesn't round to the nearest int when drawing
if hasattr(galsim, "_galsim") and dt(0.5) != 0.5:
array = np.around(array)
if is_jax_galsim():
# NOTE JAX doesn't round to the nearest int when drawing
pass
else:
if dt(0.5) != 0.5:
array = np.around(array)
return array.astype(dt)

@timer
Expand Down

0 comments on commit 101a36e

Please sign in to comment.