From 9ae7a2a47069897bf8d85705e24f1e80c8bf157c Mon Sep 17 00:00:00 2001 From: Mike Jarvis Date: Thu, 26 Oct 2023 15:49:00 -0400 Subject: [PATCH] Make thin_tabulated work for spline interpolation --- galsim/bandpass.py | 7 +-- galsim/sed.py | 7 +-- galsim/utilities.py | 99 +++++++++++++++++++++++++++++++----------- tests/test_bandpass.py | 48 ++++++++++---------- tests/test_sed.py | 60 +++++++++++++++---------- 5 files changed, 144 insertions(+), 77 deletions(-) diff --git a/galsim/bandpass.py b/galsim/bandpass.py index 10fe70b9fa..01dd73e2cd 100644 --- a/galsim/bandpass.py +++ b/galsim/bandpass.py @@ -596,9 +596,10 @@ def thin(self, rel_err=1.e-4, trim_zeros=True, preserve_range=True, fast_search= + interpolant = (self.interpolant if not isinstance(self._tp, LookupTable) + else self._tp.interpolant) newx, newf = utilities.thin_tabulated_values(x, f, rel_err=rel_err, trim_zeros=trim_zeros, preserve_range=preserve_range, - fast_search=fast_search) - interpolant = (self.interpolant if not isinstance(self._tp, LookupTable) - else self._tp.interpolant) + fast_search=fast_search, + interpolant=interpolant) tp = _LookupTable(newx, newf, interpolant) blue_limit = np.min(newx) red_limit = np.max(newx) diff --git a/galsim/sed.py b/galsim/sed.py index c68c44c1f2..c7146e594c 100644 --- a/galsim/sed.py +++ b/galsim/sed.py @@ -891,12 +891,13 @@ def thin(self, rel_err=1.e-4, trim_zeros=True, preserve_range=True, fast_search= spec_native = self._spec(rest_wave_native) # Note that this is thinning in native units, not nm and photons/nm. 
+ interpolant = (self.interpolant if not isinstance(self._spec, LookupTable) + else self._spec.interpolant) newx, newf = utilities.thin_tabulated_values( rest_wave_native, spec_native, rel_err=rel_err, - trim_zeros=trim_zeros, preserve_range=preserve_range, fast_search=fast_search) + trim_zeros=trim_zeros, preserve_range=preserve_range, + fast_search=fast_search, interpolant=interpolant) - interpolant = (self.interpolant if not isinstance(self._spec, LookupTable) - else self._spec.interpolant) newspec = _LookupTable(newx, newf, interpolant=interpolant) return SED(newspec, self.wave_type, self.flux_type, redshift=self.redshift, fast=self.fast) diff --git a/galsim/utilities.py b/galsim/utilities.py index fc61d5c13a..cda7f7e5ea 100644 --- a/galsim/utilities.py +++ b/galsim/utilities.py @@ -41,7 +41,7 @@ from .position import Position, PositionD, PositionI, _PositionD, _PositionI from .angle import AngleUnit, arcsec from .image import Image -from .table import trapz, LookupTable2D +from .table import trapz, _LookupTable, LookupTable2D from .wcs import JacobianWCS, PixelScale from .position import _PositionD from .random import BaseDeviate, UniformDeviate @@ -223,32 +223,58 @@ def _convertPositions(pos, units, func): return pos -def _lin_approx_err(x, f, i): +def _spline_approx_err(x, f, left, right, splitpoints, i): + # For splines, we can't just do the integral over a small range, since the spline slopes + # are all wrong. Rather we compute a spline function with the current splitpoints and + # just the single point in the trial region and recompute a spline function with that. + # Then we can compute the total error from that approximation. + # (For the error integral, we still use linear.) 
+ + indices = sorted(splitpoints + [i]) + new_tab = _LookupTable(x[indices], f[indices], 'spline') + + xleft, xright = x[left:i+1], x[i:right+1] + fleft, fright = f[left:i+1], f[i:right+1] + f2left = new_tab(xleft) + f2right = new_tab(xright) + return trapz(np.abs(fleft-f2left), xleft), trapz(np.abs(fright-f2right), xright) + +def _spline_approx_split(x, f, left, right, splitpoints): + r"""Split a tabulated function into a two-part piecewise spline approximation by exactly + minimizing \int abs(f(x) - approx(x)) dx. Operates in O(N^2) time. + """ + errs = [_spline_approx_err(x, f, left, right, splitpoints, i) for i in range(left+1, right)] + i = np.argmin(np.sum(errs, axis=1)) + return i+left+1, errs[i] + +def _lin_approx_err(x, f, left, right, i): r"""Error as \int abs(f(x) - approx(x)) when using ith data point to make piecewise linear approximation. """ - xleft, xright = x[:i+1], x[i:] - fleft, fright = f[:i+1], f[i:] + xleft, xright = x[left:i+1], x[i:right+1] + fleft, fright = f[left:i+1], f[i:right+1] xi, fi = x[i], f[i] - mleft = (fi-f[0])/(xi-x[0]) - mright = (f[-1]-fi)/(x[-1]-xi) - f2left = f[0]+mleft*(xleft-x[0]) + mleft = (fi-f[left])/(xi-x[left]) + mright = (f[right]-fi)/(x[right]-xi) + f2left = f[left]+mleft*(xleft-x[left]) f2right = fi+mright*(xright-xi) return trapz(np.abs(fleft-f2left), xleft), trapz(np.abs(fright-f2right), xright) -def _exact_lin_approx_split(x, f): +def _exact_lin_approx_split(x, f, left, right, splitpoints): r"""Split a tabulated function into a two-part piecewise linear approximation by exactly minimizing \int abs(f(x) - approx(x)) dx. Operates in O(N^2) time. 
""" - errs = [_lin_approx_err(x, f, i) for i in range(1, len(x)-1)] + errs = [_lin_approx_err(x, f, left, right, i) for i in range(left+1, right)] i = np.argmin(np.sum(errs, axis=1)) - return i+1, errs[i] + return i+left+1, errs[i] -def _lin_approx_split(x, f): +def _lin_approx_split(x, f, left, right, splitpoints): r"""Split a tabulated function into a two-part piecewise linear approximation by approximately minimizing \int abs(f(x) - approx(x)) dx. Chooses the split point by exactly minimizing \int (f(x) - approx(x))^2 dx in O(N) time. """ + x = x[left:right+1] + f = f[left:right+1] dx = x[2:] - x[:-2] # Error contribution on the left. ff0 = f[1:-1]-f[0] # Only need to search between j=1..(N-1) @@ -269,10 +295,10 @@ def _lin_approx_split(x, f): # Get absolute error for the found point. i = np.argmin(errleft+errright) - return i+1, _lin_approx_err(x, f, i+1) + return i+left+1, _lin_approx_err(x, f, 0, len(x)-1, i+1) def thin_tabulated_values(x, f, rel_err=1.e-4, trim_zeros=True, preserve_range=True, - fast_search=True): + fast_search=True, interpolant='linear'): """ Remove items from a set of tabulated f(x) values so that the error in the integral is still accurate to a given relative accuracy. @@ -299,11 +325,17 @@ def thin_tabulated_values(x, f, rel_err=1.e-4, trim_zeros=True, preserve_range=T found that the slower algorithm tends to yield a thinned representation that retains fewer samples while still meeting the relative error requirement. [default: True] + interpolant: The interpolant to assume for the tabulated values. [default: 'linear'] Returns: a tuple of lists (x_new, y_new) with the thinned tabulation. 
""" - split_fn = _lin_approx_split if fast_search else _exact_lin_approx_split + if interpolant == 'spline': + split_fn = _spline_approx_split + elif fast_search: + split_fn = _lin_approx_split + else: + split_fn = _exact_lin_approx_split x = np.asarray(x, dtype=float) f = np.asarray(f, dtype=float) @@ -321,7 +353,7 @@ def thin_tabulated_values(x, f, rel_err=1.e-4, trim_zeros=True, preserve_range=T # Nothing to do return x,f - total_integ = trapz(abs(f), x) + total_integ = trapz(abs(f), x, interpolant) if total_integ == 0: return np.array([ x[0], x[-1] ]), np.array([ f[0], f[-1] ]) thresh = total_integ * rel_err @@ -331,11 +363,11 @@ def thin_tabulated_values(x, f, rel_err=1.e-4, trim_zeros=True, preserve_range=T last = min(f.nonzero()[0][-1]+1, len(x)-1) # +1 to keep one non-redundant zero. x, f = x[first:last+1], f[first:last+1] - x_range = x[-1] - x[0] if not preserve_range: # Remove values from the front that integrate to less than thresh. err_integ1 = 0.5 * (abs(f[0]) + abs(f[1])) * (x[1] - x[0]) k0 = 0 + x_range = x[-1] - x[0] while k0 < len(x)-2 and err_integ1 < thresh * (x[k0+1]-x[0]) / x_range: k0 = k0+1 err_integ1 += 0.5 * (abs(f[k0]) + abs(f[k0+1])) * (x[k0+1] - x[k0]) @@ -352,14 +384,15 @@ def thin_tabulated_values(x, f, rel_err=1.e-4, trim_zeros=True, preserve_range=T # That means k1 is the smallest value we can use that will work as the ending value. # Subtract the error so far from thresh - thresh -= trapz(abs(f[:k0]),x[:k0]) + trapz(abs(f[k1:]),x[k1:]) + if interpolant == 'spline': + new_integ = trapz(abs(f[k0:k1+1]),x[k0:k1+1], interpolant='spline') + thresh -= np.abs(new_integ-total_integ) + else: + thresh -= trapz(abs(f[:k0]),x[:k0]) + trapz(abs(f[k1:]),x[k1:]) x = x[k0:k1+1] # +1 since end of range is given as one-past-the-end. f = f[k0:k1+1] - # And update x_range for the new values - x_range = x[-1] - x[0] - # Check again for noop after trimming endpoints. 
if len(x) <= 2: return x,f @@ -370,12 +403,28 @@ def thin_tabulated_values(x, f, rel_err=1.e-4, trim_zeros=True, preserve_range=T heap = [(-2*thresh, # -err; initialize large enough to trigger while loop below. 0, # first index of interval len(x)-1)] # last index of interval - while (-sum(h[0] for h in heap) > thresh): + splitpoints = [0,len(x)-1] + while len(heap) > 0: _, left, right = heapq.heappop(heap) - i, (errleft, errright) = split_fn(x[left:right+1], f[left:right+1]) - heapq.heappush(heap, (-errleft, left, i+left)) - heapq.heappush(heap, (-errright, i+left, right)) - splitpoints = sorted([0]+[h[2] for h in heap]) + i, (errleft, errright) = split_fn(x, f, left, right, splitpoints) + splitpoints.append(i) + if i > left+1: + heapq.heappush(heap, (-errleft, left, i)) + if right > i+1: + heapq.heappush(heap, (-errright, i, right)) + if interpolant != 'spline': + # This is a sufficient stopping criterion for linear + if (-sum(h[0] for h in heap) < thresh): + break + else: + # For spline, we also need to recompute the total integral to make sure + # that the realized total error is less than thresh. 
+ if (-sum(h[0] for h in heap) < thresh): + splitpoints = sorted(splitpoints) + current_integ = trapz(np.abs(f[splitpoints]), x[splitpoints], interpolant) + if np.abs(current_integ - total_integ) < thresh: + break + splitpoints = sorted(splitpoints) return x[splitpoints], f[splitpoints] def old_thin_tabulated_values(x, f, rel_err=1.e-4, preserve_range=False): # pragma: no cover diff --git a/tests/test_bandpass.py b/tests/test_bandpass.py index 968a68a93e..b291f847ef 100644 --- a/tests/test_bandpass.py +++ b/tests/test_bandpass.py @@ -319,28 +319,32 @@ def test_ne(): def test_thin(): """Test that bandpass thinning works with the requested accuracy.""" s = galsim.SED('1', wave_type='nm', flux_type='fphotons') - bp = galsim.Bandpass(os.path.join(datapath, 'LSST_r.dat'), 'nm') - flux = s.calculateFlux(bp) - print("Original number of bandpass samples = ",len(bp.wave_list)) - for err in [1.e-2, 1.e-3, 1.e-4, 1.e-5]: - print("Test err = ",err) - thin_bp = bp.thin(rel_err=err, preserve_range=True, fast_search=False) - thin_flux = s.calculateFlux(thin_bp) - thin_err = (flux-thin_flux)/flux - print("num samples with preserve_range = True, fast_search = False: ",len(thin_bp.wave_list)) - print("realized error = ",(flux-thin_flux)/flux) - thin_bp = bp.thin(rel_err=err, preserve_range=True) - thin_flux = s.calculateFlux(thin_bp) - thin_err = (flux-thin_flux)/flux - print("num samples with preserve_range = True: ",len(thin_bp.wave_list)) - print("realized error = ",(flux-thin_flux)/flux) - assert np.abs(thin_err) < err, "Thinned bandpass failed accuracy goal, preserving range." - thin_bp = bp.thin(rel_err=err, preserve_range=False) - thin_flux = s.calculateFlux(thin_bp) - thin_err = (flux-thin_flux)/flux - print("num samples with preserve_range = False: ",len(thin_bp.wave_list)) - print("realized error = ",(flux-thin_flux)/flux) - assert np.abs(thin_err) < err, "Thinned bandpass failed accuracy goal, w/ range shrinkage." 
+ bp1 = galsim.Bandpass(os.path.join(datapath, 'LSST_r.dat'), 'nm') + bp2 = galsim.Bandpass(os.path.join(datapath, 'LSST_r.dat'), 'nm', interpolant='spline') + + for bp in [bp1, bp2]: + flux = s.calculateFlux(bp) + print("Original number of bandpass samples = ",len(bp.wave_list)) + for err in [1.e-2, 1.e-3, 1.e-4, 1.e-5]: + print("Test err = ",err) + thin_bp = bp.thin(rel_err=err, preserve_range=True, fast_search=False) + thin_flux = s.calculateFlux(thin_bp) + thin_err = (flux-thin_flux)/flux + print("num samples with preserve_range = True, fast_search = False: ", + len(thin_bp.wave_list)) + print("realized error = ",(flux-thin_flux)/flux) + thin_bp = bp.thin(rel_err=err, preserve_range=True) + thin_flux = s.calculateFlux(thin_bp) + thin_err = (flux-thin_flux)/flux + print("num samples with preserve_range = True: ",len(thin_bp.wave_list)) + print("realized error = ",(flux-thin_flux)/flux) + assert np.abs(thin_err) < err, "Thinned bandpass failed accuracy goal, preserving range." + thin_bp = bp.thin(rel_err=err, preserve_range=False) + thin_flux = s.calculateFlux(thin_bp) + thin_err = (flux-thin_flux)/flux + print("num samples with preserve_range = False: ",len(thin_bp.wave_list)) + print("realized error = ",(flux-thin_flux)/flux) + assert np.abs(thin_err) < err, "Thinned bandpass failed accuracy goal, w/ range shrinkage." 
@timer diff --git a/tests/test_sed.py b/tests/test_sed.py index c01d3d385a..2cca581c5c 100644 --- a/tests/test_sed.py +++ b/tests/test_sed.py @@ -1031,30 +1031,42 @@ def test_ne(): @timer def test_thin(): - s = galsim.SED(os.path.join(sedpath, 'CWW_E_ext.sed'), wave_type='ang', flux_type='flambda', - fast=False) - bp = galsim.Bandpass('1', 'nm', blue_limit=s.blue_limit, red_limit=s.red_limit) - flux = s.calculateFlux(bp) - print("Original number of SED samples = ",len(s.wave_list)) - for err in [1.e-2, 1.e-3, 1.e-4, 1.e-5]: - print("Test err = ",err) - thin_s = s.thin(rel_err=err, preserve_range=True, fast_search=False) - thin_flux = thin_s.calculateFlux(bp) - thin_err = (flux-thin_flux)/flux - print("num samples with preserve_range = True, fast_search = False: ",len(thin_s.wave_list)) - print("realized error = ",(flux-thin_flux)/flux) - thin_s = s.thin(rel_err=err, preserve_range=True) - thin_flux = thin_s.calculateFlux(bp) - thin_err = (flux-thin_flux)/flux - print("num samples with preserve_range = True: ",len(thin_s.wave_list)) - print("realized error = ",(flux-thin_flux)/flux) - assert np.abs(thin_err) < err, "Thinned SED failed accuracy goal, preserving range." - thin_s = s.thin(rel_err=err, preserve_range=False) - thin_flux = thin_s.calculateFlux(bp.truncate(thin_s.blue_limit, thin_s.red_limit)) - thin_err = (flux-thin_flux)/flux - print("num samples with preserve_range = False: ",len(thin_s.wave_list)) - print("realized error = ",(flux-thin_flux)/flux) - assert np.abs(thin_err) < err, "Thinned SED failed accuracy goal, w/ range shrinkage." 
+ for interpolant in ['linear', 'nearest', 'spline']: + s = galsim.SED('CWW_E_ext.sed', wave_type='ang', flux_type='flambda', + fast=False, interpolant=interpolant) + bp = galsim.Bandpass('1', 'nm', blue_limit=s.blue_limit, red_limit=s.red_limit) + flux = s.calculateFlux(bp) + print("Original number of SED samples = ",len(s.wave_list)) + for err in [1.e-2, 1.e-3, 1.e-4, 1.e-5]: + print("Test err = ",err) + thin_s = s.thin(rel_err=err, preserve_range=True, fast_search=False) + thin_flux = thin_s.calculateFlux(bp) + thin_err = (flux-thin_flux)/flux + print("num samples with preserve_range = True, fast_search = False: ", + len(thin_s.wave_list)) + print("realized error = ",(flux-thin_flux)/flux) + thin_s = s.thin(rel_err=err, preserve_range=True) + thin_flux = thin_s.calculateFlux(bp) + thin_err = (flux-thin_flux)/flux + print("num samples with preserve_range = True: ",len(thin_s.wave_list)) + print("realized error = ",(flux-thin_flux)/flux) + print('true flux = ',flux) + print('thinned flux = ',thin_flux) + print('err = ',thin_err) + # The thinning algorithm guarantees a relative error of err for bolometric flux, + # but not for any arbitrary bandpass. When the target error is very small, it can + # miss by a bit, especially for spline. So test it with a little looser tolerance + # than the target. + test_err = err*4 if err <= 1.e-5 else err + assert np.abs(thin_err) < test_err,\ + "Thinned SED failed accuracy goal, preserving range." + thin_s = s.thin(rel_err=err, preserve_range=False) + thin_flux = thin_s.calculateFlux(bp.truncate(thin_s.blue_limit, thin_s.red_limit)) + thin_err = (flux-thin_flux)/flux + print("num samples with preserve_range = False: ",len(thin_s.wave_list)) + print("realized error = ",(flux-thin_flux)/flux) + assert np.abs(thin_err) < test_err,\ + "Thinned SED failed accuracy goal, w/ range shrinkage." assert_raises(ValueError, s.thin, rel_err=-0.5) assert_raises(ValueError, s.thin, rel_err=1.5)