Skip to content

Commit

Permalink
Merge pull request #195 from lincc-frameworks/jax_salt2_fixes
Browse files Browse the repository at this point in the history
Fix the units for the JAX SALT2 model
  • Loading branch information
jeremykubica authored Dec 2, 2024
2 parents 97f783f + 40ba853 commit 9a932e9
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 31 deletions.
16 changes: 16 additions & 0 deletions src/tdastro/sources/salt2_jax.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
from pathlib import Path

from astropy import units as u

from tdastro.astro_utils.salt2_color_law import SALT2ColorLaw
from tdastro.astro_utils.unit_utils import flam_to_fnu
from tdastro.sources.physical_model import PhysicalModel
from tdastro.utils.bicubic_interp import BicubicInterpolator

Expand Down Expand Up @@ -58,6 +61,9 @@ class SALT2JaxModel(PhysicalModel):
Any additional keyword arguments.
"""

# A class variable for the units so we are not computing them each time.
_FLAM_UNIT = u.erg / u.second / u.cm**2 / u.AA

def __init__(
self,
x0=None,
Expand Down Expand Up @@ -101,6 +107,7 @@ def compute_flux(self, phase, wavelengths, graph_state, **kwargs):
A length N array of wavelengths (in angstroms).
graph_state : `GraphState`
An object mapping graph parameters to their values.
**kwargs : `dict`, optional
Any additional keyword arguments.
Expand All @@ -118,4 +125,13 @@ def compute_flux(self, phase, wavelengths, graph_state, **kwargs):
* (m0_vals + params["x1"] * m1_vals)
* 10.0 ** (-0.4 * self._colorlaw.apply(wavelengths) * params["c"])
)

# Convert to the correct units.
flux_density = flam_to_fnu(
flux_density,
wavelengths,
wave_unit=u.AA,
flam_unit=self._FLAM_UNIT,
fnu_unit=u.nJy,
)
return flux_density
40 changes: 9 additions & 31 deletions tests/tdastro/sources/test_salt2.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,33 +2,7 @@
import pytest
from sncosmo.models import SALT2Source
from tdastro.sources.salt2_jax import SALT2JaxModel


def test_salt2_model(test_data_dir):
"""Test loading a SALT2 object from a file and querying it."""
dir_name = test_data_dir / "truncated-salt2-h17"
model = SALT2JaxModel(x0=0.5, x1=0.2, c=1.0, model_dir=dir_name)

assert model._colorlaw is not None
assert model._m0_model is not None
assert model._m1_model is not None

# Test compared to values computed via sncosmo's implementation that
# fall within the range of the truncated grid. We multiple by 1e12
# for comparison precision purposes.
times = np.array([1.0, 2.1, 3.9, 4.0])
waves = np.array([4000.0, 4102.0, 4200.0])
expected_times_1e12 = np.array(
[
[0.12842110, 0.17791164, 0.17462753],
[0.12287933, 0.17060205, 0.17152248],
[0.11121435, 0.15392100, 0.16234423],
[0.11051545, 0.15288580, 0.16170497],
]
)

flux = model.evaluate(times, waves)
assert np.allclose(flux * 1e12, expected_times_1e12)
from tdastro.sources.sncomso_models import SncosmoWrapperModel


def test_salt2_model_parity(test_data_dir):
Expand All @@ -37,18 +11,22 @@ def test_salt2_model_parity(test_data_dir):
"""
dir_name = test_data_dir / "truncated-salt2-h17"
td_model = SALT2JaxModel(x0=0.4, x1=0.3, c=1.1, model_dir=dir_name)
sn_model = SALT2Source(modeldir=dir_name)
sn_model.set(x0=0.4, x1=0.3, c=1.1)

# We need to overwrite the source parameter to correspond to
# the truncated directory data.
sn_model = SncosmoWrapperModel("SALT2", x0=0.4, x1=0.3, c=1.1)
sn_model.source = SALT2Source(modeldir=dir_name)

# Test compared to values computed via sncosmo's implementation that
# fall within the range of the truncated grid. We multiple by 1e12
# for comparison precision purposes.
times = np.arange(-1.0, 15.0, 0.01)
waves = np.arange(3800.0, 4200.0, 0.5)

# Allow TDAstro to return both sets of results in f_nu.
flux_td = td_model.evaluate(times, waves)
flux_sn = sn_model._flux(times, waves)
assert np.allclose(flux_td * 1e12, flux_sn * 1e12)
flux_sn = sn_model.evaluate(times, waves)
assert np.allclose(flux_td, flux_sn)


def test_salt2_no_model(test_data_dir):
Expand Down

0 comments on commit 9a932e9

Please sign in to comment.