Skip to content

Commit

Permalink
Make thin_tabulated work for spline interpolation
Browse files Browse the repository at this point in the history
  • Loading branch information
rmjarvis committed Oct 27, 2023
1 parent af33266 commit 9ae7a2a
Show file tree
Hide file tree
Showing 5 changed files with 142 additions and 77 deletions.
5 changes: 2 additions & 3 deletions galsim/bandpass.py
Original file line number Diff line number Diff line change
Expand Up @@ -596,9 +596,8 @@ def thin(self, rel_err=1.e-4, trim_zeros=True, preserve_range=True, fast_search=
newx, newf = utilities.thin_tabulated_values(x, f, rel_err=rel_err,
trim_zeros=trim_zeros,
preserve_range=preserve_range,
fast_search=fast_search)
interpolant = (self.interpolant if not isinstance(self._tp, LookupTable)
else self._tp.interpolant)
fast_search=fast_search,
interpolant=interpolant)
tp = _LookupTable(newx, newf, interpolant)
blue_limit = np.min(newx)
red_limit = np.max(newx)
Expand Down
7 changes: 4 additions & 3 deletions galsim/sed.py
Original file line number Diff line number Diff line change
Expand Up @@ -891,12 +891,13 @@ def thin(self, rel_err=1.e-4, trim_zeros=True, preserve_range=True, fast_search=
spec_native = self._spec(rest_wave_native)

# Note that this is thinning in native units, not nm and photons/nm.
interpolant = (self.interpolant if not isinstance(self._spec, LookupTable)
else self._spec.interpolant)
newx, newf = utilities.thin_tabulated_values(
rest_wave_native, spec_native, rel_err=rel_err,
trim_zeros=trim_zeros, preserve_range=preserve_range, fast_search=fast_search)
trim_zeros=trim_zeros, preserve_range=preserve_range,
fast_search=fast_search, interpolant=interpolant)

interpolant = (self.interpolant if not isinstance(self._spec, LookupTable)
else self._spec.interpolant)
newspec = _LookupTable(newx, newf, interpolant=interpolant)
return SED(newspec, self.wave_type, self.flux_type, redshift=self.redshift,
fast=self.fast)
Expand Down
99 changes: 74 additions & 25 deletions galsim/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
from .position import Position, PositionD, PositionI, _PositionD, _PositionI
from .angle import AngleUnit, arcsec
from .image import Image
from .table import trapz, LookupTable2D
from .table import trapz, _LookupTable, LookupTable2D
from .wcs import JacobianWCS, PixelScale
from .position import _PositionD
from .random import BaseDeviate, UniformDeviate
Expand Down Expand Up @@ -223,32 +223,58 @@ def _convertPositions(pos, units, func):

return pos

def _lin_approx_err(x, f, i):
def _spline_approx_err(x, f, left, right, splitpoints, i):
    r"""Error as \int abs(f(x) - approx(x)) on each side of the ith data point when that
    point is added to the current set of spline knots.

    For splines, the error cannot be evaluated from the local interval alone, because the
    spline slopes depend on every knot.  So a trial spline table is built from all of the
    current splitpoints plus the candidate point ``i``, and the tabulated values are compared
    against that trial spline over [left, i] and [i, right].

    Parameters:
        x:           The tabulated x values (indexed with a list of indices below).
        f:           The tabulated f(x) values, indexed the same way as ``x``.
        left:        Index of the first point of the interval being split.
        right:       Index of the last point of the interval being split.
        splitpoints: The indices currently kept.  It is concatenated with ``[i]``, so it is
                     expected to be a list.
        i:           Index of the candidate split point.

    Returns:
        a tuple (errleft, errright) with the integrated absolute error over [left, i] and
        [i, right] respectively.
    """
    knots = sorted(splitpoints + [i])
    trial = _LookupTable(x[knots], f[knots], 'spline')

    def _abs_err(lo, hi):
        # The error integral itself is still done with trapz on the original samples
        # (i.e. linear between samples, per the original note on this function).
        xs = x[lo:hi+1]
        return trapz(np.abs(f[lo:hi+1] - trial(xs)), xs)

    return _abs_err(left, i), _abs_err(i, right)

def _spline_approx_split(x, f, left, right, splitpoints):
    r"""Choose the point at which to split the interval [left, right] for a spline
    approximation by minimizing \int abs(f(x) - approx(x)) dx over the interval.

    Every interior index (left+1 .. right-1) is tried in turn; each trial is scored with
    `_spline_approx_err`, which builds one trial spline per candidate.

    Parameters:
        x:           The tabulated x values.
        f:           The tabulated f(x) values.
        left:        Index of the first point of the interval being split.
        right:       Index of the last point of the interval being split.
        splitpoints: The indices currently kept (passed through to `_spline_approx_err`).

    Returns:
        a tuple (index, (errleft, errright)) where index is the chosen split point, given as
        an index into the full ``x`` array, and the errors are those on either side of it.
    """
    candidates = range(left+1, right)
    errs = [_spline_approx_err(x, f, left, right, splitpoints, k) for k in candidates]
    best = np.argmin(np.sum(errs, axis=1))
    # best is an offset into candidates, which start at left+1.
    return best+left+1, errs[best]

def _lin_approx_err(x, f, left, right, i):
    r"""Error as \int abs(f(x) - approx(x)) when using ith data point to make piecewise linear
    approximation of the interval [left, right].

    The approximation is the chord from (x[left], f[left]) to (x[i], f[i]) followed by the
    chord from (x[i], f[i]) to (x[right], f[right]).

    Parameters:
        x:      The tabulated x values.
        f:      The tabulated f(x) values.
        left:   Index of the first point of the interval being split.
        right:  Index of the last point of the interval being split.
        i:      Index of the candidate split point.  The slopes below divide by
                (x[i]-x[left]) and (x[right]-x[i]), so i should lie strictly between
                left and right.

    Returns:
        a tuple (errleft, errright) with the integrated absolute error over [left, i] and
        [i, right] respectively.
    """
    # Only the samples inside [left, right] take part; the endpoints of the full array
    # are irrelevant here.
    xleft, xright = x[left:i+1], x[i:right+1]
    fleft, fright = f[left:i+1], f[i:right+1]
    xi, fi = x[i], f[i]
    # Slopes of the two chords meeting at the candidate point.
    mleft = (fi-f[left])/(xi-x[left])
    mright = (f[right]-fi)/(x[right]-xi)
    # The piecewise linear approximation evaluated at the original sample locations.
    f2left = f[left]+mleft*(xleft-x[left])
    f2right = fi+mright*(xright-xi)
    return trapz(np.abs(fleft-f2left), xleft), trapz(np.abs(fright-f2right), xright)

def _exact_lin_approx_split(x, f, left, right, splitpoints):
    r"""Split a tabulated function into a two-part piecewise linear approximation by exactly
    minimizing \int abs(f(x) - approx(x)) dx.  Operates in O(N^2) time.

    Parameters:
        x:           The tabulated x values.
        f:           The tabulated f(x) values.
        left:        Index of the first point of the interval being split.
        right:       Index of the last point of the interval being split.
        splitpoints: Unused here; accepted so that all of the split functions can be called
                     with the same arguments.

    Returns:
        a tuple (index, (errleft, errright)) where index is the chosen split point, given as
        an index into the full ``x`` array, and the errors are those on either side of it.
    """
    # Try every interior point of [left, right] as the split point.
    errs = [_lin_approx_err(x, f, left, right, i) for i in range(left+1, right)]
    i = np.argmin(np.sum(errs, axis=1))
    # i is an offset into the candidate list, which starts at left+1.
    return i+left+1, errs[i]

def _lin_approx_split(x, f):
def _lin_approx_split(x, f, left, right, splitpoints):
r"""Split a tabulated function into a two-part piecewise linear approximation by approximately
minimizing \int abs(f(x) - approx(x)) dx. Chooses the split point by exactly minimizing
\int (f(x) - approx(x))^2 dx in O(N) time.
"""
x = x[left:right+1]
f = f[left:right+1]
dx = x[2:] - x[:-2]
# Error contribution on the left.
ff0 = f[1:-1]-f[0] # Only need to search between j=1..(N-1)
Expand All @@ -269,10 +295,10 @@ def _lin_approx_split(x, f):

# Get absolute error for the found point.
i = np.argmin(errleft+errright)
return i+1, _lin_approx_err(x, f, i+1)
return i+left+1, _lin_approx_err(x, f, 0, len(x)-1, i+1)

def thin_tabulated_values(x, f, rel_err=1.e-4, trim_zeros=True, preserve_range=True,
fast_search=True):
fast_search=True, interpolant='linear'):
"""
Remove items from a set of tabulated f(x) values so that the error in the integral is still
accurate to a given relative accuracy.
Expand All @@ -299,11 +325,17 @@ def thin_tabulated_values(x, f, rel_err=1.e-4, trim_zeros=True, preserve_range=T
found that the slower algorithm tends to yield a thinned representation
that retains fewer samples while still meeting the relative error
requirement. [default: True]
interpolant: The interpolant to assume for the tabulated values. [default: 'linear']
Returns:
a tuple of lists (x_new, y_new) with the thinned tabulation.
"""
split_fn = _lin_approx_split if fast_search else _exact_lin_approx_split
if interpolant == 'spline':
split_fn = _spline_approx_split
elif fast_search:
split_fn = _lin_approx_split
else:
split_fn = _exact_lin_approx_split

x = np.asarray(x, dtype=float)
f = np.asarray(f, dtype=float)
Expand All @@ -321,7 +353,7 @@ def thin_tabulated_values(x, f, rel_err=1.e-4, trim_zeros=True, preserve_range=T
# Nothing to do
return x,f

total_integ = trapz(abs(f), x)
total_integ = trapz(abs(f), x, interpolant)
if total_integ == 0:
return np.array([ x[0], x[-1] ]), np.array([ f[0], f[-1] ])
thresh = total_integ * rel_err
Expand All @@ -331,11 +363,11 @@ def thin_tabulated_values(x, f, rel_err=1.e-4, trim_zeros=True, preserve_range=T
last = min(f.nonzero()[0][-1]+1, len(x)-1) # +1 to keep one non-redundant zero.
x, f = x[first:last+1], f[first:last+1]

x_range = x[-1] - x[0]
if not preserve_range:
# Remove values from the front that integrate to less than thresh.
err_integ1 = 0.5 * (abs(f[0]) + abs(f[1])) * (x[1] - x[0])
k0 = 0
x_range = x[-1] - x[0]
while k0 < len(x)-2 and err_integ1 < thresh * (x[k0+1]-x[0]) / x_range:
k0 = k0+1
err_integ1 += 0.5 * (abs(f[k0]) + abs(f[k0+1])) * (x[k0+1] - x[k0])
Expand All @@ -352,14 +384,15 @@ def thin_tabulated_values(x, f, rel_err=1.e-4, trim_zeros=True, preserve_range=T
# That means k1 is the smallest value we can use that will work as the ending value.

# Subtract the error so far from thresh
thresh -= trapz(abs(f[:k0]),x[:k0]) + trapz(abs(f[k1:]),x[k1:])
if interpolant == 'spline':
new_integ = trapz(abs(f[k0:k1+1]),x[k0:k1+1], interpolant='spline')
thresh -= np.abs(new_integ-total_integ)
else:
thresh -= trapz(abs(f[:k0]),x[:k0]) + trapz(abs(f[k1:]),x[k1:])

x = x[k0:k1+1] # +1 since end of range is given as one-past-the-end.
f = f[k0:k1+1]

# And update x_range for the new values
x_range = x[-1] - x[0]

# Check again for noop after trimming endpoints.
if len(x) <= 2:
return x,f
Expand All @@ -370,12 +403,28 @@ def thin_tabulated_values(x, f, rel_err=1.e-4, trim_zeros=True, preserve_range=T
heap = [(-2*thresh, # -err; initialize large enough to trigger while loop below.
0, # first index of interval
len(x)-1)] # last index of interval
while (-sum(h[0] for h in heap) > thresh):
splitpoints = [0,len(x)-1]
while len(heap) > 0:
_, left, right = heapq.heappop(heap)
i, (errleft, errright) = split_fn(x[left:right+1], f[left:right+1])
heapq.heappush(heap, (-errleft, left, i+left))
heapq.heappush(heap, (-errright, i+left, right))
splitpoints = sorted([0]+[h[2] for h in heap])
i, (errleft, errright) = split_fn(x, f, left, right, splitpoints)
splitpoints.append(i)
if i > left+1:
heapq.heappush(heap, (-errleft, left, i))
if right > i+1:
heapq.heappush(heap, (-errright, i, right))
if interpolant != 'spline':
# This is a sufficient stopping criterion for linear
if (-sum(h[0] for h in heap) < thresh):
break
else:
# For spline, we also need to recompute the total integral to make sure
# that the realized total error is less than thresh.
if (-sum(h[0] for h in heap) < thresh):
splitpoints = sorted(splitpoints)
current_integ = trapz(f[splitpoints], x[splitpoints], interpolant)
if np.abs(current_integ - total_integ) < thresh:
break
splitpoints = sorted(splitpoints)
return x[splitpoints], f[splitpoints]

def old_thin_tabulated_values(x, f, rel_err=1.e-4, preserve_range=False): # pragma: no cover
Expand Down
48 changes: 26 additions & 22 deletions tests/test_bandpass.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,28 +319,32 @@ def test_ne():
def test_thin():
"""Test that bandpass thinning works with the requested accuracy."""
s = galsim.SED('1', wave_type='nm', flux_type='fphotons')
bp = galsim.Bandpass(os.path.join(datapath, 'LSST_r.dat'), 'nm')
flux = s.calculateFlux(bp)
print("Original number of bandpass samples = ",len(bp.wave_list))
for err in [1.e-2, 1.e-3, 1.e-4, 1.e-5]:
print("Test err = ",err)
thin_bp = bp.thin(rel_err=err, preserve_range=True, fast_search=False)
thin_flux = s.calculateFlux(thin_bp)
thin_err = (flux-thin_flux)/flux
print("num samples with preserve_range = True, fast_search = False: ",len(thin_bp.wave_list))
print("realized error = ",(flux-thin_flux)/flux)
thin_bp = bp.thin(rel_err=err, preserve_range=True)
thin_flux = s.calculateFlux(thin_bp)
thin_err = (flux-thin_flux)/flux
print("num samples with preserve_range = True: ",len(thin_bp.wave_list))
print("realized error = ",(flux-thin_flux)/flux)
assert np.abs(thin_err) < err, "Thinned bandpass failed accuracy goal, preserving range."
thin_bp = bp.thin(rel_err=err, preserve_range=False)
thin_flux = s.calculateFlux(thin_bp)
thin_err = (flux-thin_flux)/flux
print("num samples with preserve_range = False: ",len(thin_bp.wave_list))
print("realized error = ",(flux-thin_flux)/flux)
assert np.abs(thin_err) < err, "Thinned bandpass failed accuracy goal, w/ range shrinkage."
bp1 = galsim.Bandpass(os.path.join(datapath, 'LSST_r.dat'), 'nm')
bp2 = galsim.Bandpass(os.path.join(datapath, 'LSST_r.dat'), 'nm', interpolant='spline')

for bp in [bp1, bp2]:
flux = s.calculateFlux(bp)
print("Original number of bandpass samples = ",len(bp.wave_list))
for err in [1.e-2, 1.e-3, 1.e-4, 1.e-5]:
print("Test err = ",err)
thin_bp = bp.thin(rel_err=err, preserve_range=True, fast_search=False)
thin_flux = s.calculateFlux(thin_bp)
thin_err = (flux-thin_flux)/flux
print("num samples with preserve_range = True, fast_search = False: ",
len(thin_bp.wave_list))
print("realized error = ",(flux-thin_flux)/flux)
thin_bp = bp.thin(rel_err=err, preserve_range=True)
thin_flux = s.calculateFlux(thin_bp)
thin_err = (flux-thin_flux)/flux
print("num samples with preserve_range = True: ",len(thin_bp.wave_list))
print("realized error = ",(flux-thin_flux)/flux)
assert np.abs(thin_err) < err, "Thinned bandpass failed accuracy goal, preserving range."
thin_bp = bp.thin(rel_err=err, preserve_range=False)
thin_flux = s.calculateFlux(thin_bp)
thin_err = (flux-thin_flux)/flux
print("num samples with preserve_range = False: ",len(thin_bp.wave_list))
print("realized error = ",(flux-thin_flux)/flux)
assert np.abs(thin_err) < err, "Thinned bandpass failed accuracy goal, w/ range shrinkage."


@timer
Expand Down
60 changes: 36 additions & 24 deletions tests/test_sed.py
Original file line number Diff line number Diff line change
Expand Up @@ -1031,30 +1031,42 @@ def test_ne():

@timer
def test_thin():
s = galsim.SED(os.path.join(sedpath, 'CWW_E_ext.sed'), wave_type='ang', flux_type='flambda',
fast=False)
bp = galsim.Bandpass('1', 'nm', blue_limit=s.blue_limit, red_limit=s.red_limit)
flux = s.calculateFlux(bp)
print("Original number of SED samples = ",len(s.wave_list))
for err in [1.e-2, 1.e-3, 1.e-4, 1.e-5]:
print("Test err = ",err)
thin_s = s.thin(rel_err=err, preserve_range=True, fast_search=False)
thin_flux = thin_s.calculateFlux(bp)
thin_err = (flux-thin_flux)/flux
print("num samples with preserve_range = True, fast_search = False: ",len(thin_s.wave_list))
print("realized error = ",(flux-thin_flux)/flux)
thin_s = s.thin(rel_err=err, preserve_range=True)
thin_flux = thin_s.calculateFlux(bp)
thin_err = (flux-thin_flux)/flux
print("num samples with preserve_range = True: ",len(thin_s.wave_list))
print("realized error = ",(flux-thin_flux)/flux)
assert np.abs(thin_err) < err, "Thinned SED failed accuracy goal, preserving range."
thin_s = s.thin(rel_err=err, preserve_range=False)
thin_flux = thin_s.calculateFlux(bp.truncate(thin_s.blue_limit, thin_s.red_limit))
thin_err = (flux-thin_flux)/flux
print("num samples with preserve_range = False: ",len(thin_s.wave_list))
print("realized error = ",(flux-thin_flux)/flux)
assert np.abs(thin_err) < err, "Thinned SED failed accuracy goal, w/ range shrinkage."
for interpolant in ['linear', 'nearest', 'spline']:
s = galsim.SED('CWW_E_ext.sed', wave_type='ang', flux_type='flambda',
fast=False, interpolant=interpolant)
bp = galsim.Bandpass('1', 'nm', blue_limit=s.blue_limit, red_limit=s.red_limit)
flux = s.calculateFlux(bp)
print("Original number of SED samples = ",len(s.wave_list))
for err in [1.e-2, 1.e-3, 1.e-4, 1.e-5]:
print("Test err = ",err)
thin_s = s.thin(rel_err=err, preserve_range=True, fast_search=False)
thin_flux = thin_s.calculateFlux(bp)
thin_err = (flux-thin_flux)/flux
print("num samples with preserve_range = True, fast_search = False: ",
len(thin_s.wave_list))
print("realized error = ",(flux-thin_flux)/flux)
thin_s = s.thin(rel_err=err, preserve_range=True)
thin_flux = thin_s.calculateFlux(bp)
thin_err = (flux-thin_flux)/flux
print("num samples with preserve_range = True: ",len(thin_s.wave_list))
print("realized error = ",(flux-thin_flux)/flux)
print('true flux = ',flux)
print('thinned flux = ',thin_flux)
print('err = ',thin_err)
# The thinning algorithm guarantees a relative error of err for bolometric flux,
# but not for any arbitrary bandpass. When the target error is very small, it can
# miss by a bit, especially for spline. So test it with a little looser tolerance
# than the target.
test_err = err*4 if err <= 1.e-5 else err
assert np.abs(thin_err) < test_err,\
"Thinned SED failed accuracy goal, preserving range."
thin_s = s.thin(rel_err=err, preserve_range=False)
thin_flux = thin_s.calculateFlux(bp.truncate(thin_s.blue_limit, thin_s.red_limit))
thin_err = (flux-thin_flux)/flux
print("num samples with preserve_range = False: ",len(thin_s.wave_list))
print("realized error = ",(flux-thin_flux)/flux)
assert np.abs(thin_err) < test_err,\
"Thinned SED failed accuracy goal, w/ range shrinkage."

assert_raises(ValueError, s.thin, rel_err=-0.5)
assert_raises(ValueError, s.thin, rel_err=1.5)
Expand Down

0 comments on commit 9ae7a2a

Please sign in to comment.