diff --git a/src/tdastro/sources/salt2_jax.py b/src/tdastro/sources/salt2_jax.py index 985c320..f9d7430 100644 --- a/src/tdastro/sources/salt2_jax.py +++ b/src/tdastro/sources/salt2_jax.py @@ -1,6 +1,9 @@ from pathlib import Path +from astropy import units as u + from tdastro.astro_utils.salt2_color_law import SALT2ColorLaw +from tdastro.astro_utils.unit_utils import flam_to_fnu from tdastro.sources.physical_model import PhysicalModel from tdastro.utils.bicubic_interp import BicubicInterpolator @@ -58,6 +61,9 @@ class SALT2JaxModel(PhysicalModel): Any additional keyword arguments. """ + # A class variable for the units so we are not computing them each time. + _FLAM_UNIT = u.erg / u.second / u.cm**2 / u.AA + def __init__( self, x0=None, @@ -101,6 +107,7 @@ def compute_flux(self, phase, wavelengths, graph_state, **kwargs): A length N array of wavelengths (in angstroms). graph_state : `GraphState` An object mapping graph parameters to their values. + **kwargs : `dict`, optional Any additional keyword arguments. @@ -118,4 +125,13 @@ def compute_flux(self, phase, wavelengths, graph_state, **kwargs): * (m0_vals + params["x1"] * m1_vals) * 10.0 ** (-0.4 * self._colorlaw.apply(wavelengths) * params["c"]) ) + + # Convert to the correct units. + flux_density = flam_to_fnu( + flux_density, + wavelengths, + wave_unit=u.AA, + flam_unit=self._FLAM_UNIT, + fnu_unit=u.nJy, + ) return flux_density diff --git a/tests/tdastro/sources/test_salt2.py b/tests/tdastro/sources/test_salt2.py index ea786fa..ada9651 100644 --- a/tests/tdastro/sources/test_salt2.py +++ b/tests/tdastro/sources/test_salt2.py @@ -2,33 +2,7 @@ import pytest from sncosmo.models import SALT2Source from tdastro.sources.salt2_jax import SALT2JaxModel - - -def test_salt2_model(test_data_dir): - """Test loading a SALT2 object from a file and querying it.""" - dir_name = test_data_dir / "truncated-salt2-h17" - model = SALT2JaxModel(x0=0.5, x1=0.2, c=1.0, model_dir=dir_name) - - assert model._colorlaw is not None - assert model._m0_model is not None - assert model._m1_model is not None - - # Test compared to values computed via sncosmo's implementation that - # fall within the range of the truncated grid. We multiple by 1e12 - # for comparison precision purposes. - times = np.array([1.0, 2.1, 3.9, 4.0]) - waves = np.array([4000.0, 4102.0, 4200.0]) - expected_times_1e12 = np.array( - [ - [0.12842110, 0.17791164, 0.17462753], - [0.12287933, 0.17060205, 0.17152248], - [0.11121435, 0.15392100, 0.16234423], - [0.11051545, 0.15288580, 0.16170497], - ] - ) - - flux = model.evaluate(times, waves) - assert np.allclose(flux * 1e12, expected_times_1e12) +from tdastro.sources.sncomso_models import SncosmoWrapperModel def test_salt2_model_parity(test_data_dir): @@ -37,8 +11,11 @@ def test_salt2_model_parity(test_data_dir): """ dir_name = test_data_dir / "truncated-salt2-h17" td_model = SALT2JaxModel(x0=0.4, x1=0.3, c=1.1, model_dir=dir_name) - sn_model = SALT2Source(modeldir=dir_name) - sn_model.set(x0=0.4, x1=0.3, c=1.1) + + # We need to overwrite the source parameter to correspond to + # the truncated directory data. + sn_model = SncosmoWrapperModel("SALT2", x0=0.4, x1=0.3, c=1.1) + sn_model.source = SALT2Source(modeldir=dir_name) # Test compared to values computed via sncosmo's implementation that # fall within the range of the truncated grid. We multiple by 1e12 @@ -46,9 +23,10 @@ def test_salt2_model_parity(test_data_dir): times = np.arange(-1.0, 15.0, 0.01) waves = np.arange(3800.0, 4200.0, 0.5) + # Allow TDAstro to return both sets of results in f_nu. flux_td = td_model.evaluate(times, waves) - flux_sn = sn_model._flux(times, waves) - assert np.allclose(flux_td * 1e12, flux_sn * 1e12) + flux_sn = sn_model.evaluate(times, waves) + assert np.allclose(flux_td, flux_sn) def test_salt2_no_model(test_data_dir):