Skip to content

Commit

Permalink
Merge pull request #150 from lincc-frameworks/sncosmo_units
Browse files Browse the repository at this point in the history
Change sncosmo output flux unit to fnu
  • Loading branch information
mi-dai authored Oct 7, 2024
2 parents 7efe1be + 6960fbf commit 3033e07
Show file tree
Hide file tree
Showing 7 changed files with 80 additions and 32 deletions.
13 changes: 8 additions & 5 deletions src/tdastro/astro_utils/redshift.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from tdastro.base_models import FunctionNode


def obs_frame_to_rest_frame(observer_frame_times, observer_frame_wavelengths, redshift, t0):
def obs_to_rest_times_waves(observer_frame_times, observer_frame_wavelengths, redshift, t0):
"""Calculate the rest frame times and wavelengths needed to give user the observer frame times
and wavelengths (given the redshift).
Expand Down Expand Up @@ -33,8 +33,11 @@ def obs_frame_to_rest_frame(observer_frame_times, observer_frame_wavelengths, re
return (rest_frame_times, rest_frame_wavelengths)


def apply_redshift(flux_density, redshift):
"""Apply the redshift effect to rest frame flux density values.
def rest_to_obs_flux(flux_density, redshift):
"""Convert rest-frame flux to obs-frame flux.
The (1+redshift) factor is applied to preserve bolometric flux.
The rest-frame flux is defined as F_nu = L_nu / 4*pi*D_L**2,
where D_L is the luminosity distance.
Parameters
----------
Expand All @@ -46,9 +49,9 @@ def apply_redshift(flux_density, redshift):
Returns
-------
flux_density : `numpy.ndarray`
The redshifted results (in nJy).
The observer frame flux (in nJy).
"""
return flux_density / (1 + redshift)
return flux_density * (1 + redshift)


def redshift_to_distance(redshift, cosmology):
Expand Down
6 changes: 3 additions & 3 deletions src/tdastro/astro_utils/snia_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,8 +148,8 @@ def pdf(self, x1):

def _x0_from_distmod(distmod, x1, c, alpha, beta, m_abs):
"""Calculate the SALT3 x0 parameter given distance modulus based on Tripp relation.
distmod = -2.5*log10(x0) + alpha * x1 - beta * c - m_abs
x0 = 10 ^ (-0.4* (distmod - alpha * x1 + beta * c + m_abs))
distmod = -2.5*log10(x0) + alpha * x1 - beta * c - m_abs + 10.635
x0 = 10 ^ (-0.4* (distmod - alpha * x1 + beta * c + m_abs - 10.635))
Parameters
----------
Expand All @@ -171,7 +171,7 @@ def _x0_from_distmod(distmod, x1, c, alpha, beta, m_abs):
x0 : `float`
The x0 parameter
"""
x0 = np.power(10.0, -0.4 * (distmod - alpha * x1 + beta * c + m_abs))
x0 = np.power(10.0, -0.4 * (distmod - alpha * x1 + beta * c + m_abs - 10.635))

return x0

Expand Down
14 changes: 8 additions & 6 deletions src/tdastro/sources/physical_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import numpy as np

from tdastro.astro_utils.passbands import Passband
from tdastro.astro_utils.redshift import RedshiftDistFunc, apply_redshift, obs_frame_to_rest_frame
from tdastro.astro_utils.redshift import RedshiftDistFunc, obs_to_rest_times_waves, rest_to_obs_flux
from tdastro.base_models import ParameterizedNode
from tdastro.graph_state import GraphState
from tdastro.rand_nodes.np_random import build_rngs_from_hashes
Expand Down Expand Up @@ -100,21 +100,23 @@ def set_apply_redshift(self, apply_redshift):
self.apply_redshift = apply_redshift

def _evaluate(self, times, wavelengths, graph_state):
"""Draw effect-free observations for this object.
"""Draw effect-free rest frame flux densities.
The rest-frame flux is defined as F_nu = L_nu / 4*pi*D_L**2,
where D_L is the luminosity distance.
Parameters
----------
times : `numpy.ndarray`
A length T array of rest frame timestamps.
wavelengths : `numpy.ndarray`, optional
A length N array of wavelengths (in angstroms).
A length N array of rest frame wavelengths (in angstroms).
graph_state : `GraphState`
An object mapping graph parameters to their values.
Returns
-------
flux_density : `numpy.ndarray`
A length T x N matrix of SED values (in nJy).
A length T x N matrix of rest frame SED values (in nJy).
"""
raise NotImplementedError()

Expand Down Expand Up @@ -161,7 +163,7 @@ def evaluate(self, times, wavelengths, graph_state=None, given_args=None, rng_in
raise ValueError("The 'redshift' parameter is required for redshifted models.")
if params.get("t0", None) is None:
raise ValueError("The 't0' parameter is required for redshifted models.")
times, wavelengths = obs_frame_to_rest_frame(times, wavelengths, params["redshift"], params["t0"])
times, wavelengths = obs_to_rest_times_waves(times, wavelengths, params["redshift"], params["t0"])

# Compute the flux density for both the current object and add in anything
# behind it, such as a host galaxy.
Expand All @@ -179,7 +181,7 @@ def evaluate(self, times, wavelengths, graph_state=None, given_args=None, rng_in
# Post-effects are adjustments done to the flux density after computation.
if self.apply_redshift and params["redshift"] != 0.0:
# We have alread checked that redshift is not None.
flux_density = apply_redshift(flux_density, params["redshift"])
flux_density = rest_to_obs_flux(flux_density, params["redshift"])

return flux_density

Expand Down
16 changes: 14 additions & 2 deletions src/tdastro/sources/sncomso_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@
https://sncosmo.readthedocs.io/en/stable/models.html
"""

from astropy import units as u
from sncosmo.models import get_source

from tdastro.astro_utils.unit_utils import flam_to_fnu
from tdastro.sources.physical_model import PhysicalModel


Expand Down Expand Up @@ -143,8 +145,18 @@ def _evaluate(self, times, wavelengths, graph_state=None, **kwargs):
Returns
-------
flux_density : `numpy.ndarray`
A length T x N matrix of SED values (in ergs/s/cm^2/AA).
A length T x N matrix of SED values (in nJy).
"""
params = self.get_local_params(graph_state)
self._update_sncosmo_model_parameters(graph_state)
return self.source.flux(times - params["t0"], wavelengths)

flux_flam = self.source.flux(times - params["t0"], wavelengths)
flux_fnu = flam_to_fnu(
flux_flam,
wavelengths,
wave_unit=u.AA,
flam_unit=u.erg / u.second / u.cm**2 / u.AA,
fnu_unit=u.nJy,
)

return flux_fnu
2 changes: 1 addition & 1 deletion tests/tdastro/astro_utils/test_redshift.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def test_redshifted_flux_densities() -> None:

for i, time in enumerate(times):
if t0 <= time and time <= (t1 - t0) * (1 + redshift) + t0:
assert np.all(values_redshift[i] == brightness / (1 + redshift))
assert np.all(values_redshift[i] == brightness * (1 + redshift))
else:
assert np.all(values_redshift[i] == 0.0)

Expand Down
24 changes: 20 additions & 4 deletions tests/tdastro/sources/test_sncosmo_models.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import numpy as np
from astropy import units as u
from tdastro.astro_utils.unit_utils import fnu_to_flam
from tdastro.rand_nodes.np_random import NumpyRandomFunc
from tdastro.sources.sncomso_models import SncosmoWrapperModel

Expand All @@ -19,8 +21,15 @@ def test_sncomso_models_hsiao() -> None:
# model = sncosmo.Model(source='hsiao')
# model.set(z=0.0, t0=0.0, amplitude=2.0e10)
# model.flux(5., [4000., 4100., 4200.])
fluxes = model.evaluate([5.0], [4000.0, 4100.0, 4200.0])
assert np.allclose(fluxes, [133.98143039, 152.74613574, 134.40916824])
fluxes_fnu = model.evaluate([5.0], [4000.0, 4100.0, 4200.0])
fluxes_flam = fnu_to_flam(
fluxes_fnu,
[4000.0, 4100.0, 4200.0],
wave_unit=u.AA,
flam_unit=u.erg / u.second / u.cm**2 / u.AA,
fnu_unit=u.nJy,
)
assert np.allclose(fluxes_flam, [133.98143039, 152.74613574, 134.40916824])


def test_sncomso_models_hsiao_t0() -> None:
Expand All @@ -37,8 +46,15 @@ def test_sncomso_models_hsiao_t0() -> None:
# model = sncosmo.Model(source='hsiao')
# model.set(z=0.0, t0=55000., amplitude=2.0e10)
# model.flux(54990., [4000., 4100., 4200.])
fluxes = model.evaluate([54990.0], [4000.0, 4100.0, 4200.0])
assert np.allclose(fluxes, [67.83696271, 67.98471119, 47.20395186])
fluxes_fnu = model.evaluate([54990.0], [4000.0, 4100.0, 4200.0])
fluxes_flam = fnu_to_flam(
fluxes_fnu,
[4000.0, 4100.0, 4200.0],
wave_unit=u.AA,
flam_unit=u.erg / u.second / u.cm**2 / u.AA,
fnu_unit=u.nJy,
)
assert np.allclose(fluxes_flam, [67.83696271, 67.98471119, 47.20395186])


def test_sncomso_models_set() -> None:
Expand Down
37 changes: 26 additions & 11 deletions tests/tdastro/sources/test_snia.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from astropy import units as u
from tdastro.astro_utils.passbands import PassbandGroup
from tdastro.astro_utils.snia_utils import DistModFromRedshift, HostmassX1Func, X0FromDistMod
from tdastro.astro_utils.unit_utils import flam_to_fnu
from tdastro.astro_utils.unit_utils import flam_to_fnu, fnu_to_flam
from tdastro.rand_nodes.np_random import NumpyRandomFunc
from tdastro.sources.sncomso_models import SncosmoWrapperModel
from tdastro.sources.snia_host import SNIaHost
Expand Down Expand Up @@ -44,18 +44,25 @@ def draw_single_random_sn(

res["times"] = times

flux_flam = source.evaluate(times, wave_obs, graph_state=state)
res["flux_flam"] = flux_flam
flux_nJy = source.evaluate(times, wave_obs, graph_state=state)
# res["flux_flam"] = flux_flam

# convert ergs/s/cm^2/AA to nJy
# # convert ergs/s/cm^2/AA to nJy

flux_fnu = flam_to_fnu(
flux_flam, wave_obs, wave_unit=u.AA, flam_unit=u.erg / u.second / u.cm**2 / u.AA, fnu_unit=u.nJy
)
# flux_fnu = flam_to_fnu(
# flux_flam, wave_obs, wave_unit=u.AA, flam_unit=u.erg / u.second / u.cm**2 / u.AA, fnu_unit=u.nJy
# )

res["flux_fnu"] = flux_fnu
res["flux_nJy"] = flux_nJy
res["flux_flam"] = fnu_to_flam(
flux_nJy,
wave_obs,
wave_unit=u.AA,
flam_unit=u.erg / u.second / u.cm**2 / u.AA,
fnu_unit=u.nJy,
)

bandfluxes = passbands.fluxes_to_bandfluxes(flux_fnu)
bandfluxes = passbands.fluxes_to_bandfluxes(flux_nJy)
res["bandfluxes"] = bandfluxes

res["state"] = state
Expand Down Expand Up @@ -128,8 +135,8 @@ def run_snia_end2end(oversampled_observations, passbands_dir, nsample=1):
"table_path": f"{passbands_dir}/LSST/r.dat",
},
{
"filter_name": "i",
"table_path": f"{passbands_dir}/LSST/u.dat",
"filter_name": "g",
"table_path": f"{passbands_dir}/LSST/g.dat",
},
],
survey="LSST",
Expand Down Expand Up @@ -169,6 +176,14 @@ def run_snia_end2end(oversampled_observations, passbands_dir, nsample=1):
time = res["times"]

flux_sncosmo = model.flux(time, wave)
fnu_sncosmo = flam_to_fnu(
flux_sncosmo,
wave,
wave_unit=u.AA,
flam_unit=u.erg / u.second / u.cm**2 / u.AA,
fnu_unit=u.nJy,
)
np.testing.assert_allclose(res["flux_nJy"], fnu_sncosmo, atol=1e-6)
np.testing.assert_allclose(res["flux_flam"], flux_sncosmo, atol=1e-30, rtol=1e-5)

for f, passband in passbands.passbands.items():
Expand Down

0 comments on commit 3033e07

Please sign in to comment.