diff --git a/src/tdastro/astro_utils/passbands.py b/src/tdastro/astro_utils/passbands.py index e0b60d8..cc72ae2 100644 --- a/src/tdastro/astro_utils/passbands.py +++ b/src/tdastro/astro_utils/passbands.py @@ -23,6 +23,9 @@ class PassbandGroup: The path to the directory containing the passband tables. waves : np.ndarray The union of all wavelengths in the passbands. + _in_band_wave_indices : dict + A dictionary mapping the passband name (eg, "LSST_u") to the indices of that specific + passband's wavelengths in the full waves list. """ def __init__( @@ -64,6 +67,7 @@ def __init__( Additional keyword arguments to pass to the Passband constructor. """ self.passbands = {} + self._in_band_wave_indices = {} if preset is None and passband_parameters is None and given_passbands is None: raise ValueError( @@ -102,7 +106,6 @@ def __init__( # Compute the unique points and bounds for the group. self._update_waves() - self._calculate_in_band_wave_indices() def __str__(self) -> str: """Return a string representation of the PassbandGroup.""" @@ -145,7 +148,14 @@ def _load_preset(self, preset: str, table_dir: Optional[str], **kwargs) -> None: def _update_waves(self, threshold=1e-5) -> None: """Update the group's wave attribute to be the union of all wavelengths in - the passbands. + the passbands and update the group's _in_band_wave_indices attribute, which is + the indices of the group's wave grid that are in the passband's wave grid. + + Eg, if a group's waves are [11, 12, 13, 14, 15] and a single band's are [13, 14], + we get [2, 3]. + + The indices are stored in the passband's _in_band_wave_indices attribute as either + a tuple of two ints (lower, upper) or a 1D np.ndarray of ints. Parameters ---------- @@ -166,15 +176,10 @@ def _update_waves(self, threshold=1e-5) -> None: gap_sizes = np.insert(sorted_waves[1:] - sorted_waves[:-1], 0, 1e8) self.waves = sorted_waves[gap_sizes >= threshold] - def _calculate_in_band_wave_indices(self) -> None: - """Calculate the indices of the group's wave grid that are in the passband's wave grid. - - Eg, if a group's waves are [11, 12, 13, 14, 15] and a single band's are [13, 14], we get [2, 3]. - - The indices are stored in the passband's _in_band_wave_indices attribute as either a tuple of two ints - (lower, upper) or a 1D np.ndarray of ints. - """ - for passband in self.passbands.values(): + # Update the mapping of each passband's wavelengths to the corresponding indices in the + # unioned list of all wavelengths. + self._in_band_wave_indices = {} + for name, passband in self.passbands.items(): # We only want the fluxes that are in the passband's wavelength range # So, find the indices in the group's wave grid that are in the passband's wave grid lower, upper = passband.waves[0], passband.waves[-1] @@ -187,7 +192,7 @@ def _calculate_in_band_wave_indices(self) -> None: indices = slice(lower_index, upper_index + 1) else: indices = np.searchsorted(self.waves, passband.waves) - passband._in_band_wave_indices = indices + self._in_band_wave_indices[name] = indices def process_transmission_tables( self, delta_wave: Optional[float] = 5.0, trim_quantile: Optional[float] = 1e-3 @@ -206,7 +211,6 @@ def process_transmission_tables( passband.process_transmission_table(delta_wave, trim_quantile) self._update_waves() - self._calculate_in_band_wave_indices() def fluxes_to_bandfluxes(self, flux_density_matrix: np.ndarray) -> np.ndarray: """Calculate bandfluxes for all passbands in the group. @@ -232,12 +236,12 @@ def fluxes_to_bandfluxes(self, flux_density_matrix: np.ndarray) -> np.ndarray: bandfluxes = {} for full_name, passband in self.passbands.items(): - indices = passband._in_band_wave_indices + indices = self._in_band_wave_indices[full_name] if indices is None: raise ValueError( f"Passband {full_name} does not have _in_band_wave_indices set. " - "This should have been calculated in PassbandGroup._calculate_in_band_wave_indices." + "This should have been calculated in PassbandGroup._update_waves." ) in_band_fluxes = flux_density_matrix[:, indices] @@ -270,9 +274,6 @@ class Passband: processed_transmission_table : np.ndarray A 2D array where the first col is wavelengths (Angstrom) and the second col is transmission values. This table is both interpolated to the _wave_grid and normalized to calculate phi_b(λ). - _in_band_wave_indices : np.ndarray, optional - The indices of the full wave grid used in PassbandGroup that correspond to this Passband's wave grid. - This is only set when the Passband is part of a PassbandGroup. """ def __init__( @@ -326,7 +327,6 @@ def __init__( self.table_path = Path(table_path) if table_path is not None else None self.table_url = table_url self.units = units - self._in_band_wave_indices = None if table_values is not None: if table_values.shape[1] != 2: diff --git a/tests/tdastro/astro_utils/test_passband_groups.py b/tests/tdastro/astro_utils/test_passband_groups.py index 1cf52e1..b00b312 100644 --- a/tests/tdastro/astro_utils/test_passband_groups.py +++ b/tests/tdastro/astro_utils/test_passband_groups.py @@ -289,16 +289,14 @@ def test_passband_group_calculate_in_band_wave_indices(passbands_dir, tmp_path): # Note that passband_A and passband_B have overlapping wavelength ranges # Where passband_A covers 100-300 and passband_B covers 250-350 (and passband_C covers 400-600) - np.testing.assert_allclose( - passband_A._in_band_wave_indices, np.array([0, 1, 2, 3, 4, 5, 6, 7, 9, 11, 13]) - ) - np.testing.assert_allclose(passband_A.waves, toy_passband_group.waves[passband_A._in_band_wave_indices]) + toy_a_inds = toy_passband_group._in_band_wave_indices["TOY_a"] + np.testing.assert_allclose(toy_a_inds, np.array([0, 1, 2, 3, 4, 5, 6, 7, 9, 11, 13])) + np.testing.assert_allclose(passband_A.waves, toy_passband_group.waves[toy_a_inds]) - np.testing.assert_allclose(passband_B._in_band_wave_indices, np.array([8, 10, 12, 14, 15, 16])) - np.testing.assert_allclose(passband_B.waves, toy_passband_group.waves[passband_B._in_band_wave_indices]) + toy_b_inds = toy_passband_group._in_band_wave_indices["TOY_b"] + np.testing.assert_allclose(toy_b_inds, np.array([8, 10, 12, 14, 15, 16])) + np.testing.assert_allclose(passband_B.waves, toy_passband_group.waves[toy_b_inds]) - assert passband_C._in_band_wave_indices == slice(17, 28) - np.testing.assert_allclose( - passband_C.waves, - toy_passband_group.waves[passband_C._in_band_wave_indices], - ) + toy_c_inds = toy_passband_group._in_band_wave_indices["TOY_c"] + assert toy_c_inds == slice(17, 28) + np.testing.assert_allclose(passband_C.waves, toy_passband_group.waves[toy_c_inds])