diff --git a/src/tdastro/astro_utils/passbands.py b/src/tdastro/astro_utils/passbands.py index cc72ae2..d55f1bb 100644 --- a/src/tdastro/astro_utils/passbands.py +++ b/src/tdastro/astro_utils/passbands.py @@ -591,16 +591,31 @@ def fluxes_to_bandflux( Parameters ---------- flux_density_matrix : np.ndarray - A 2D array of flux densities where rows are times and columns are wavelengths. + A 2D or 3D array of flux densities. If the array is 2D it contains a single sample where + the rows are the T times and columns are M wavelengths. If the array is 3D it contains S + samples and the values are indexed as (sample_num, time, wavelength). Returns ------- - np.ndarray - An array of bandfluxes with length flux_density_matrix, where each element is the bandflux - at the corresponding time. + bandfluxes : np.ndarray + A 1D or 2D array. If the flux_density_matrix contains a single sample (2D input) then + the function returns a 1D length T array where each element is the bandflux + at the corresponding time. Otherwise the function returns a size S x T array where + each entry corresponds to the value for a given sample at a given time. """ - if flux_density_matrix.size == 0 or len(self.waves) != len(flux_density_matrix[0]): - flux_density_matrix_num_cols = 0 if flux_density_matrix.size == 0 else len(flux_density_matrix[0]) + if flux_density_matrix.size == 0: + raise ValueError("Empty flux density matrix used.") + if len(flux_density_matrix.shape) == 2: + w_axis = 1 + flux_density_matrix_num_cols = flux_density_matrix.shape[1] + elif len(flux_density_matrix.shape) == 3: + w_axis = 2 + flux_density_matrix_num_cols = flux_density_matrix.shape[2] + else: + raise ValueError("Invalid flux density matrix. Must be 2 or 3-dimensional.") + + # Check the number of wavelengths match. + if len(self.waves) != flux_density_matrix_num_cols: raise ValueError( f"Passband mismatched grids: Flux density matrix has {flux_density_matrix_num_cols} " f"columns, which does not match the {len(self.waves)} rows in band {self.full_name}'s " @@ -612,6 +627,5 @@ def fluxes_to_bandflux( # Calculate the bandflux as ∫ f(λ)φ_b(λ) dλ, # where f(λ) is the flux density and φ_b(λ) is the normalized system response integrand = flux_density_matrix * self.processed_transmission_table[:, 1] - bandfluxes = scipy.integrate.trapezoid(integrand, x=self.waves) - + bandfluxes = scipy.integrate.trapezoid(integrand, x=self.waves, axis=w_axis) return bandfluxes diff --git a/src/tdastro/sources/physical_model.py b/src/tdastro/sources/physical_model.py index 3cd43a8..8b1a5c9 100644 --- a/src/tdastro/sources/physical_model.py +++ b/src/tdastro/sources/physical_model.py @@ -141,7 +141,9 @@ def evaluate(self, times, wavelengths, graph_state=None, given_args=None, rng_in Returns ------- flux_density : `numpy.ndarray` - A length T x N matrix of SED values (in nJy). + A length S x T x N matrix of SED values (in nJy), where S is the number of samples, + T is the number of time steps, and N is the number of wavelengths. + If S=1 then the function returns a T x N matrix. """ # Make sure times and wavelengths are numpy arrays. times = np.asarray(times) @@ -152,36 +154,46 @@ def evaluate(self, times, wavelengths, graph_state=None, given_args=None, rng_in graph_state = self.sample_parameters( given_args=given_args, num_samples=1, rng_info=rng_info, **kwargs ) - params = self.get_local_params(graph_state) - - # Pre-effects are adjustments done to times and/or wavelengths, before flux density - # computation. We skip if redshift is 0.0 since there is nothing to do. - if self.apply_redshift and params["redshift"] != 0.0: - if params.get("redshift", None) is None: - raise ValueError("The 'redshift' parameter is required for redshifted models.") - if params.get("t0", None) is None: - raise ValueError("The 't0' parameter is required for redshifted models.") - times, wavelengths = obs_to_rest_times_waves(times, wavelengths, params["redshift"], params["t0"]) - - # Compute the flux density for both the current object and add in anything - # behind it, such as a host galaxy. - flux_density = self.compute_flux(times, wavelengths, graph_state, **kwargs) - if self.background is not None: - flux_density += self.background.compute_flux( - times, - wavelengths, - graph_state, - ra=params["ra"], - dec=params["dec"], - **kwargs, - ) - # Post-effects are adjustments done to the flux density after computation. - if self.apply_redshift and params["redshift"] != 0.0: - # We have alread checked that redshift is not None. - flux_density = rest_to_obs_flux(flux_density, params["redshift"]) + results = np.empty((graph_state.num_samples, len(times), len(wavelengths))) + for sample_num, state in enumerate(graph_state): + params = self.get_local_params(state) + + # Pre-effects are adjustments done to times and/or wavelengths, before flux density + # computation. We skip if redshift is 0.0 since there is nothing to do. + if self.apply_redshift and params["redshift"] != 0.0: + if params.get("redshift", None) is None: + raise ValueError("The 'redshift' parameter is required for redshifted models.") + if params.get("t0", None) is None: + raise ValueError("The 't0' parameter is required for redshifted models.") + times, wavelengths = obs_to_rest_times_waves( + times, wavelengths, params["redshift"], params["t0"] + ) + + # Compute the flux density for both the current object and add in anything + # behind it, such as a host galaxy. + flux_density = self.compute_flux(times, wavelengths, state, **kwargs) + if self.background is not None: + flux_density += self.background.compute_flux( + times, + wavelengths, + state, + ra=params["ra"], + dec=params["dec"], + **kwargs, + ) + + # Post-effects are adjustments done to the flux density after computation. + if self.apply_redshift and params["redshift"] != 0.0: + # We have alread checked that redshift is not None. + flux_density = rest_to_obs_flux(flux_density, params["redshift"]) - return flux_density + # Save the result. + results[sample_num, :, :] = flux_density + + if graph_state.num_samples == 1: + return results[0, :, :] + return results def sample_parameters(self, given_args=None, num_samples=1, rng_info=None, **kwargs): """Sample the model's underlying parameters if they are provided by a function @@ -242,17 +254,17 @@ def get_band_fluxes(self, passband_or_group, times, filters, state) -> np.ndarra Returns ------- - band_fluxes : `numpy.ndarray` or `dict - A length T array of band fluxes, or a dictionary of band names mapped to fluxes (if a passband - group is used). + band_fluxes : `numpy.ndarray` + A matrix of the band fluxes. If only one sample is provided in the GraphState, + then returns a length T array. Otherwise returns a size S x T array where S is the + number of samples in the graph state. """ if isinstance(passband_or_group, Passband): - if filters is not None and not np.array_equal( - filters, np.repeat(passband_or_group.filter_name, len(times)) - ): + if filters is not None and not np.all(filters == passband_or_group.filter_name): raise ValueError( - "If passband_or_group is a Passband, " - "filters must be None or a list of the same filter repeated." + "If passband_or_group is a Passband, filters must either be None " + "or a list where every entry matches the given filter's name: " + f"{passband_or_group.filter_name}." ) spectral_fluxes = self.evaluate(times, passband_or_group.waves, state) return passband_or_group.fluxes_to_bandflux(spectral_fluxes) @@ -260,10 +272,13 @@ def get_band_fluxes(self, passband_or_group, times, filters, state) -> np.ndarra if filters is None: raise ValueError("If passband_or_group is a PassbandGroup, filters must be provided.") - band_fluxes = np.empty_like(times) + band_fluxes = np.empty((state.num_samples, len(times))) for filter_name in np.unique(filters): passband = passband_or_group.passbands[filter_name] filter_mask = filters == filter_name spectral_fluxes = self.evaluate(times[filter_mask], passband.waves, state) - band_fluxes[filter_mask] = passband.fluxes_to_bandflux(spectral_fluxes) + band_fluxes[:, filter_mask] = passband.fluxes_to_bandflux(spectral_fluxes) + + if state.num_samples == 1: + return band_fluxes[0, :] return band_fluxes diff --git a/tests/tdastro/astro_utils/test_passbands.py b/tests/tdastro/astro_utils/test_passbands.py index 5277ff6..fa3f3e7 100644 --- a/tests/tdastro/astro_utils/test_passbands.py +++ b/tests/tdastro/astro_utils/test_passbands.py @@ -395,6 +395,36 @@ def test_passband_fluxes_to_bandflux(passbands_dir, tmp_path): assert len(in_band_flux) == 5 +def test_passband_fluxes_to_bandflux_mult_samples(passbands_dir, tmp_path): + """Test the fluxes_to_bandflux method of the Passband class with multiple samples.""" + transmission_table = "100 0.5\n200 0.75\n300 0.25\n" + a_band = create_toy_passband(tmp_path, transmission_table, delta_wave=100, trim_quantile=None) + + # Define some mock flux values and calculate our expected bandflux + flux = np.array( + [ + [ + [1.0, 1.0, 1.0], + [2.0, 1.0, 1.0], + [3.0, 1.0, 1.0], + [2.0, 1.0, 1.0], + [1.0, 1.0, 1.0], + ], + [ + [2.0, 2.0, 2.0], + [4.0, 2.0, 2.0], + [6.0, 2.0, 2.0], + [4.0, 2.0, 2.0], + [2.0, 2.0, 2.0], + ], + ] + ) + expected = np.array([[1.0, 1.375, 1.75, 1.375, 1.0], [2.0, 2.75, 3.5, 2.75, 2.0]]) + + result = a_band.fluxes_to_bandflux(flux) + np.testing.assert_allclose(result, expected) + + def test_passband_wrapped_from_physical_source(passbands_dir, tmp_path): """Test get_band_fluxes, PhysicalModel's wrapped version of Passband's fluxes_to_bandflux..""" # Set up physical model diff --git a/tests/tdastro/sources/test_physical_models.py b/tests/tdastro/sources/test_physical_models.py index 209e290..8aa2fcd 100644 --- a/tests/tdastro/sources/test_physical_models.py +++ b/tests/tdastro/sources/test_physical_models.py @@ -2,6 +2,7 @@ import pytest from astropy.cosmology import Planck18 from tdastro.astro_utils.passbands import PassbandGroup +from tdastro.math_nodes.given_sampler import GivenSampler from tdastro.sources.physical_model import PhysicalModel from tdastro.sources.static_source import StaticSource @@ -53,6 +54,34 @@ def test_physical_model(): assert model4.get_param(state, "distance") is None +def test_physical_model_evaluate(): + """Test that we can evaluate a PhysicalModel.""" + times = np.array([0.0, 1.0, 2.0, 3.0, 4.0]) + waves = np.array([4000.0, 5000.0]) + brightness = GivenSampler([10.0, 20.0, 30.0]) + static_source = StaticSource(brightness=brightness) + + # Providing no state should give a single sample. + flux = static_source.evaluate(times, waves) + assert flux.shape == (5, 2) + assert np.all(flux == 10.0) + + # Doing a single sample should give a single sample. + state = static_source.sample_parameters(num_samples=1) + flux = static_source.evaluate(times, waves, graph_state=state) + assert flux.shape == (5, 2) + assert np.all(flux == 20.0) + + # We can do multiple samples. + brightness.reset() + state = static_source.sample_parameters(num_samples=3) + flux = static_source.evaluate(times, waves, graph_state=state) + assert flux.shape == (3, 5, 2) + assert np.all(flux[0, :, :] == 10.0) + assert np.all(flux[1, :, :] == 20.0) + assert np.all(flux[2, :, :] == 30.0) + + def test_physical_model_get_band_fluxes(passbands_dir): """Test that band fluxes are computed correctly.""" # It should work fine for any positive Fnu. @@ -60,8 +89,9 @@ def test_physical_model_get_band_fluxes(passbands_dir): static_source = StaticSource(brightness=f_nu) state = static_source.sample_parameters() passbands = PassbandGroup(preset="LSST") + n_passbands = len(passbands) - times = np.arange(len(passbands), dtype=float) + times = np.arange(n_passbands, dtype=float) filters = np.array(sorted(passbands.passbands.keys())) # It should fail if no filters are provided. @@ -72,4 +102,15 @@ def test_physical_model_get_band_fluxes(passbands_dir): _band_fluxes = static_source.get_band_fluxes(passbands.passbands["LSST_r"], times, filters, state) band_fluxes = static_source.get_band_fluxes(passbands, times, filters, state) + assert band_fluxes.shape == (n_passbands,) np.testing.assert_allclose(band_fluxes, f_nu, rtol=1e-10) + + # If we use multiple samples, we should get a correctly sized array. + n_samples = 21 + brightness_list = [1.5 * i for i in range(n_samples)] + static_source2 = StaticSource(brightness=GivenSampler(brightness_list)) + state2 = static_source2.sample_parameters(num_samples=n_samples) + band_fluxes2 = static_source2.get_band_fluxes(passbands, times, filters, state2) + assert band_fluxes2.shape == (n_samples, n_passbands) + for idx, brightness in enumerate(brightness_list): + np.testing.assert_allclose(band_fluxes2[idx, :], brightness, rtol=1e-10)