Skip to content

Commit

Permalink
Try to avoid lambda in mul_sed the same way we do in mul_bandpass
Browse files Browse the repository at this point in the history
  • Loading branch information
rmjarvis committed Oct 27, 2023
1 parent 9ae7a2a commit bb6ffb2
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 9 deletions.
24 changes: 21 additions & 3 deletions galsim/sed.py
Original file line number Diff line number Diff line change
Expand Up @@ -501,9 +501,27 @@ def _mul_sed(self, other):

wave_list, blue_limit, red_limit = utilities.combine_wave_list(self, other)
if fast:
zfactor1 = (1.+redshift) / (1.+self.redshift)
zfactor2 = (1.+redshift) / (1.+other.redshift)
spec = lambda w: self._fast_spec(w * zfactor1) * other._fast_spec(w * zfactor2)
if (isinstance(self._fast_spec, LookupTable)
and not self._fast_spec.x_log
and not self._fast_spec.f_log):
x = wave_list / (1.0 + self.redshift)
# Add in 500 uniformly spaced values to help improve accuracy.
x = utilities.merge_sorted([x, np.linspace(x[0], x[-1], 500)])
zfactor2 = (1.+redshift) / (1.+other.redshift)
f = self._fast_spec(x) * other._fast_spec(x*zfactor2)
spec = _LookupTable(x, f, self._fast_spec.interpolant)
elif (isinstance(other._fast_spec, LookupTable)
and not other._fast_spec.x_log
and not other._fast_spec.f_log):
x = wave_list / (1.0 + other.redshift)
x = utilities.merge_sorted([x, np.linspace(x[0], x[-1], 500)])
zfactor1 = (1.+redshift) / (1.+other.redshift)
f = self._fast_spec(x*zfactor1) * other._fast_spec(x)
spec = _LookupTable(x, f, other._fast_spec.interpolant)
else:
zfactor1 = (1.+redshift) / (1.+self.redshift)
zfactor2 = (1.+redshift) / (1.+other.redshift)
spec = lambda w: self._fast_spec(w * zfactor1) * other._fast_spec(w * zfactor2)
else:
spec = lambda w: self(w * (1.+redshift)) * other(w * (1.+redshift))
spectral = self.spectral or other.spectral
Expand Down
4 changes: 2 additions & 2 deletions galsim/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,10 +392,10 @@ def _check_range(self, x):
slop = (self.x_max - self.x_min) * 1.e-6
if np.min(x,initial=self.x_min) < self.x_min - slop:
raise GalSimRangeError("x value(s) below the range of the LookupTable.",
x, self.x_min, self.x_max)
x[x<self.x_min], self.x_min, self.x_max) from None
if np.max(x,initial=self.x_max) > self.x_max + slop: # pragma: no branch
raise GalSimRangeError("x value(s) above the range of the LookupTable.",
x, self.x_min, self.x_max)
x[x>self.x_max], self.x_min, self.x_max) from None

def getArgs(self):
return self.x
Expand Down
4 changes: 2 additions & 2 deletions tests/test_chromatic.py
Original file line number Diff line number Diff line change
Expand Up @@ -2617,7 +2617,7 @@ def test_chromatic_invariant():
assert isinstance(chrom, galsim.ChromaticTransformation)
assert not isinstance(chrom, galsim.SimpleChromaticTransformation)
check_chromatic_invariant(chrom)
np.testing.assert_allclose(chrom.sed(waves), flux * bulge_SED(waves) * waves**1.03)
np.testing.assert_allclose(chrom.sed(waves), flux * bulge_SED(waves) * waves**1.03, rtol=1.e-4)
# Not picklable, but run str, repr
str(chrom)
repr(chrom)
Expand Down Expand Up @@ -3199,7 +3199,7 @@ def test_shoot_transformation():
img = obj.drawImage(bandpass, nx=25, ny=25, scale=0.2, method='phot', rng=rng,
poisson_flux=False)
print(img.added_flux)
np.testing.assert_allclose(img.added_flux, flux)
np.testing.assert_allclose(img.added_flux, flux, rtol=1.e-6)
img = obj.drawImage(bandpass, nx=25, ny=25, scale=0.2, method='phot', rng=rng)
print(img.added_flux)
assert abs(img.added_flux - flux) > 0.1
Expand Down
10 changes: 8 additions & 2 deletions tests/test_sed.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,9 +287,12 @@ def test_SED_mul():
f = a*e
np.testing.assert_almost_equal(f(x), a(x) * e(x), 10,
err_msg="Found wrong value in SED.__mul__")
f = e*a
np.testing.assert_almost_equal(f(x), e(x) * a(x), 10,
f2 = e*a
np.testing.assert_almost_equal(f2(x), e(x) * a(x), 10,
err_msg="Found wrong value in SED.__mul__")
if sed is sed0:
check_pickle(f)
check_pickle(f2)

# SED multiplied by dimensionless, non-constant SED
g = galsim.SED('wave', 'nm', '1')
Expand All @@ -299,6 +302,9 @@ def test_SED_mul():
h2 = g*a
np.testing.assert_almost_equal(h2(x), g(x) * a(x), 10,
err_msg="Found wrong value in SED.__mul__")
if sed is sed0:
check_pickle(h)
check_pickle(h2)

assert_raises(TypeError, a.__mul__, 'invalid')

Expand Down

0 comments on commit bb6ffb2

Please sign in to comment.