Skip to content

Commit

Permalink
Merge pull request #198 from lincc-frameworks/passbands
Browse files Browse the repository at this point in the history
Small restructuring of passbands
  • Loading branch information
jeremykubica authored Dec 3, 2024
2 parents e22f83c + cdc6aaa commit 817d06f
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 30 deletions.
38 changes: 19 additions & 19 deletions src/tdastro/astro_utils/passbands.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@ class PassbandGroup:
The path to the directory containing the passband tables.
waves : np.ndarray
The union of all wavelengths in the passbands.
_in_band_wave_indices : dict
A dictionary mapping the passband name (eg, "LSST_u") to the indices of that specific
passband's wavelengths in the full waves list.
"""

def __init__(
Expand Down Expand Up @@ -64,6 +67,7 @@ def __init__(
Additional keyword arguments to pass to the Passband constructor.
"""
self.passbands = {}
self._in_band_wave_indices = {}

if preset is None and passband_parameters is None and given_passbands is None:
raise ValueError(
Expand Down Expand Up @@ -102,7 +106,6 @@ def __init__(

# Compute the unique points and bounds for the group.
self._update_waves()
self._calculate_in_band_wave_indices()

def __str__(self) -> str:
"""Return a string representation of the PassbandGroup."""
Expand Down Expand Up @@ -145,7 +148,14 @@ def _load_preset(self, preset: str, table_dir: Optional[str], **kwargs) -> None:

def _update_waves(self, threshold=1e-5) -> None:
"""Update the group's wave attribute to be the union of all wavelengths in
the passbands.
the passbands and update the group's _in_band_wave_indices attribute, which is
the indices of the group's wave grid that are in the passband's wave grid.
Eg, if a group's waves are [11, 12, 13, 14, 15] and a single band's are [13, 14],
we get [2, 3].
The indices are stored in the passband's _in_band_wave_indices attribute as either
a tuple of two ints (lower, upper) or a 1D np.ndarray of ints.
Parameters
----------
Expand All @@ -166,15 +176,10 @@ def _update_waves(self, threshold=1e-5) -> None:
gap_sizes = np.insert(sorted_waves[1:] - sorted_waves[:-1], 0, 1e8)
self.waves = sorted_waves[gap_sizes >= threshold]

def _calculate_in_band_wave_indices(self) -> None:
"""Calculate the indices of the group's wave grid that are in the passband's wave grid.
Eg, if a group's waves are [11, 12, 13, 14, 15] and a single band's are [13, 14], we get [2, 3].
The indices are stored in the passband's _in_band_wave_indices attribute as either a tuple of two ints
(lower, upper) or a 1D np.ndarray of ints.
"""
for passband in self.passbands.values():
# Update the mapping of each passband's wavelengths to the corresponding indices in the
# unioned list of all wavelengths.
self._in_band_wave_indices = {}
for name, passband in self.passbands.items():
# We only want the fluxes that are in the passband's wavelength range
# So, find the indices in the group's wave grid that are in the passband's wave grid
lower, upper = passband.waves[0], passband.waves[-1]
Expand All @@ -187,7 +192,7 @@ def _calculate_in_band_wave_indices(self) -> None:
indices = slice(lower_index, upper_index + 1)
else:
indices = np.searchsorted(self.waves, passband.waves)
passband._in_band_wave_indices = indices
self._in_band_wave_indices[name] = indices

def process_transmission_tables(
self, delta_wave: Optional[float] = 5.0, trim_quantile: Optional[float] = 1e-3
Expand All @@ -206,7 +211,6 @@ def process_transmission_tables(
passband.process_transmission_table(delta_wave, trim_quantile)

self._update_waves()
self._calculate_in_band_wave_indices()

def fluxes_to_bandfluxes(self, flux_density_matrix: np.ndarray) -> np.ndarray:
"""Calculate bandfluxes for all passbands in the group.
Expand All @@ -232,12 +236,12 @@ def fluxes_to_bandfluxes(self, flux_density_matrix: np.ndarray) -> np.ndarray:

bandfluxes = {}
for full_name, passband in self.passbands.items():
indices = passband._in_band_wave_indices
indices = self._in_band_wave_indices[full_name]

if indices is None:
raise ValueError(
f"Passband {full_name} does not have _in_band_wave_indices set. "
"This should have been calculated in PassbandGroup._calculate_in_band_wave_indices."
"This should have been calculated in PassbandGroup._update_waves."
)

in_band_fluxes = flux_density_matrix[:, indices]
Expand Down Expand Up @@ -270,9 +274,6 @@ class Passband:
processed_transmission_table : np.ndarray
A 2D array where the first col is wavelengths (Angstrom) and the second col is transmission values.
This table is both interpolated to the _wave_grid and normalized to calculate phi_b(λ).
_in_band_wave_indices : np.ndarray, optional
The indices of the full wave grid used in PassbandGroup that correspond to this Passband's wave grid.
This is only set when the Passband is part of a PassbandGroup.
"""

def __init__(
Expand Down Expand Up @@ -326,7 +327,6 @@ def __init__(
self.table_path = Path(table_path) if table_path is not None else None
self.table_url = table_url
self.units = units
self._in_band_wave_indices = None

if table_values is not None:
if table_values.shape[1] != 2:
Expand Down
20 changes: 9 additions & 11 deletions tests/tdastro/astro_utils/test_passband_groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,16 +289,14 @@ def test_passband_group_calculate_in_band_wave_indices(passbands_dir, tmp_path):

# Note that passband_A and passband_B have overlapping wavelength ranges
# Where passband_A covers 100-300 and passband_B covers 250-350 (and passband_C covers 400-600)
np.testing.assert_allclose(
passband_A._in_band_wave_indices, np.array([0, 1, 2, 3, 4, 5, 6, 7, 9, 11, 13])
)
np.testing.assert_allclose(passband_A.waves, toy_passband_group.waves[passband_A._in_band_wave_indices])
toy_a_inds = toy_passband_group._in_band_wave_indices["TOY_a"]
np.testing.assert_allclose(toy_a_inds, np.array([0, 1, 2, 3, 4, 5, 6, 7, 9, 11, 13]))
np.testing.assert_allclose(passband_A.waves, toy_passband_group.waves[toy_a_inds])

np.testing.assert_allclose(passband_B._in_band_wave_indices, np.array([8, 10, 12, 14, 15, 16]))
np.testing.assert_allclose(passband_B.waves, toy_passband_group.waves[passband_B._in_band_wave_indices])
toy_b_inds = toy_passband_group._in_band_wave_indices["TOY_b"]
np.testing.assert_allclose(toy_b_inds, np.array([8, 10, 12, 14, 15, 16]))
np.testing.assert_allclose(passband_B.waves, toy_passband_group.waves[toy_b_inds])

assert passband_C._in_band_wave_indices == slice(17, 28)
np.testing.assert_allclose(
passband_C.waves,
toy_passband_group.waves[passband_C._in_band_wave_indices],
)
toy_c_inds = toy_passband_group._in_band_wave_indices["TOY_c"]
assert toy_c_inds == slice(17, 28)
np.testing.assert_allclose(passband_C.waves, toy_passband_group.waves[toy_c_inds])

0 comments on commit 817d06f

Please sign in to comment.