-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Initial Implementation of a JAX Salt2 model
1 parent
291c03c
commit c0744af
Showing
21 changed files
with
55,027 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,134 @@ | ||
"""The ColorLaw used by SALT models as defined by in (Guy J., 2007) | ||
It is adapted from sncosmo's SALT2ColorLaw class (but implemented in JAX): | ||
https://github.com/sncosmo/sncosmo/blob/v2.10.1/sncosmo/salt2utils.pyx | ||
""" | ||
|
||
import jax.numpy as jnp | ||
import numpy as np | ||
|
||
# Constants used in SALT2ColorLaw computations (from: | ||
# https://github.com/sncosmo/sncosmo/blob/v2.10.1/sncosmo/salt2utils.pyx) | ||
_SALT2CL_B = 4302.57 # B-band-ish wavelength | ||
_SALT2CL_V = 5428.55 # V-band-ish wavelength | ||
_WAVESCALE = 1.0 / (_SALT2CL_V - _SALT2CL_B) | ||
|
||
|
||
class SALT2ColorLaw: | ||
"""An object that applies the color law to the given wavelengths. | ||
Parameters | ||
---------- | ||
wave_min : float | ||
The minimum wavelength (in angstroms) | ||
wave_max : float | ||
The maximum wavelength (in angstroms) | ||
coeffs : list, numpy array, or jax array | ||
The <= 6 coefficients of the polynomial to use. | ||
Attributes | ||
---------- | ||
coeffs : JAX array | ||
The final 7 coefficients to use (based on the <= 6 provided in the parameters). | ||
scaled_wave_min : float | ||
The minimum wavelength shifted and scaled. | ||
scaled_wave_max : float | ||
The maximum wavelength shifted and scaled. | ||
value_at_min : float | ||
The value of the polynomial at wave_min. | ||
value_at_max : float | ||
The value of the polynomial at wave_max. | ||
_exponents : JAX array | ||
A precomputed array of the exponents to use. | ||
""" | ||
|
||
def __init__(self, wave_min, wave_max, coeffs): | ||
# Create the internal coefficient array. The new first entry is 1.0 minus the | ||
# sum of the given entries. The first six given entries are then listed. | ||
coeffs = np.array(coeffs) | ||
num_coeffs = min(len(coeffs), 6) | ||
|
||
padded_coeffs = np.zeros(7) | ||
padded_coeffs[1 : (num_coeffs + 1)] = np.array(coeffs)[0:num_coeffs] | ||
padded_coeffs[0] = 1.0 - np.sum(padded_coeffs) | ||
self.coeffs = jnp.asarray(padded_coeffs) | ||
|
||
# Compute the bounds for the wavelengths and the value of the polynomial at both bounds. | ||
self.scaled_wave_min = (wave_min - _SALT2CL_B) * _WAVESCALE | ||
self.scaled_wave_max = (wave_max - _SALT2CL_B) * _WAVESCALE | ||
|
||
self.exponents = jnp.arange(1, 8, 1) | ||
self.value_at_min = jnp.sum(self.coeffs * jnp.power(self.scaled_wave_min, self.exponents)) | ||
self.value_at_max = jnp.sum(self.coeffs * jnp.power(self.scaled_wave_max, self.exponents)) | ||
|
||
# Precompute the polynomials derivative at the min and max wavelength. | ||
dcoeffs = jnp.arange(2, 8, 1) * self.coeffs[1:7] | ||
dexponents = jnp.arange(1, 7, 1) | ||
self.deriv_at_min = jnp.sum(dcoeffs * jnp.power(self.scaled_wave_min, dexponents)) + self.coeffs[0] | ||
self.deriv_at_max = jnp.sum(dcoeffs * jnp.power(self.scaled_wave_max, dexponents)) + self.coeffs[0] | ||
|
||
@classmethod | ||
def from_file(cls, filename): | ||
"""Create the SALT2ColorLaw object from data in a file. | ||
Parameters | ||
---------- | ||
filename : str | ||
The name of the file to load. | ||
""" | ||
with open(filename, mode="r") as f: | ||
data = f.read().split() | ||
|
||
# The first line holds the number of coefficients N and the next N lines | ||
# each hold a single coefficient. | ||
num_coeff = int(data[0]) | ||
coeffs = np.array(data[1 : (1 + num_coeff)], dtype=float) | ||
|
||
# The rest of the lines (if any) are meta-data with a label and value on each line. | ||
wave_min = 3000.0 | ||
wave_max = 7000.0 | ||
for i in range(1 + num_coeff, len(data), 2): | ||
if "min_lambda" in data[i]: | ||
wave_min = float(data[i + 1]) | ||
elif "max_lambda" in data[i]: | ||
wave_max = float(data[i + 1]) | ||
elif "version" in data[i]: | ||
version = int(data[i + 1]) | ||
if version != 1: | ||
raise RuntimeError(f"Unsupported version {version}.") | ||
|
||
return SALT2ColorLaw(wave_min, wave_max, coeffs) | ||
|
||
def apply(self, wavelengths): | ||
"""Apply the color law to the given wavelengths. | ||
Parameters | ||
---------- | ||
wavelengths : array | ||
The wavelengths in angstroms. | ||
""" | ||
num_waves = len(wavelengths) | ||
shifted_wave = (jnp.asarray(wavelengths) - _SALT2CL_B) * _WAVESCALE | ||
|
||
# Compute the three cases of interest. | ||
# 1) If the shifted wave is past the lower bound, extrapolate a value based on | ||
# the value and derivative at that bound. | ||
below = self.value_at_min + self.deriv_at_min * (shifted_wave - self.scaled_wave_min) | ||
# 2) If the shifted wave is past the upper bound, extrapolate a value based on | ||
# the value and derivative at that bound. | ||
above = self.value_at_max + self.deriv_at_max * (shifted_wave - self.scaled_wave_max) | ||
# 3) If the shifted value is in the middle, use the polynomial. | ||
wave_T = jnp.reshape(shifted_wave, (num_waves, 1)) | ||
coeffs_all = jnp.tile(self.coeffs, (num_waves, 1)) | ||
middle = jnp.sum((coeffs_all * jnp.power(wave_T, self.exponents)).T, axis=0) | ||
|
||
result = -jnp.where( | ||
shifted_wave < self.scaled_wave_min, | ||
below, | ||
jnp.where( | ||
shifted_wave > self.scaled_wave_max, | ||
above, | ||
middle, | ||
), | ||
) | ||
return result |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,121 @@ | ||
from pathlib import Path | ||
|
||
from tdastro.astro_utils.salt2_color_law import SALT2ColorLaw | ||
from tdastro.sources.physical_model import PhysicalModel | ||
from tdastro.utils.bicubic_interp import BicubicInterpolator | ||
|
||
|
||
class SALT2JaxModel(PhysicalModel): | ||
"""A SALT2 model implemented with JAX for it can use auto-differentiation. | ||
The model is defined in (Guy J., 2007) as: | ||
flux(time, wave) = x0 * [M0(time, wave) + x1 * M1(time, wave)] * exp(c * CL(wave)) | ||
where x0, x1, and c are given parameters, M0 is the average spectral sequence, | ||
M1 is the first compoment to describe variability, and CL is the average color | ||
correction law. | ||
We use the formulation in sncosmo where CL is defined such that: | ||
flux(time, wave) = x0 * [M0(time, wave) + x1 * M1(time, wave)] * 10 ** (-0.4 * c * CL(wave)) | ||
This class is based on the sncosmo implementation at: | ||
https://github.com/sncosmo/sncosmo/blob/v2.10.1/sncosmo/models.py | ||
The wrapped sncosmo version in sncosmo_models.py is faster and should be used | ||
when auto-differentiation is not needed. | ||
Attributes | ||
---------- | ||
_m0_model : BicubicInterpolator | ||
The interpolator for the m0 parameter. | ||
_m1_model : BicubicInterpolator | ||
The interpolator for the m1 parameter. | ||
_colorlaw : SALT2ColorLaw | ||
The data to apply the color law. | ||
Parameters | ||
---------- | ||
x0 : parameter | ||
The SALT2 x0 parameter. | ||
x1 : parameter | ||
The SALT2 x1 parameter. | ||
c : parameter | ||
The SALT2 c parameter. | ||
model_dir : `str` | ||
The path for the model file directory. | ||
Default: "" | ||
m0_filename : `str` | ||
The file name for the m0 model component. | ||
Default: "salt2_template_0.dat" | ||
m1_filename : `str` | ||
The file name for the m1 model component. | ||
Default: "salt2_template_1.dat" | ||
cl_filename : `str` | ||
The file name of the color law correction coefficients. | ||
Default: "salt2_color_correction.dat", | ||
**kwargs : `dict`, optional | ||
Any additional keyword arguments. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
x0=None, | ||
x1=None, | ||
c=None, | ||
model_dir="", | ||
m0_filename="salt2_template_0.dat", | ||
m1_filename="salt2_template_1.dat", | ||
cl_filename="salt2_color_correction.dat", | ||
**kwargs, | ||
): | ||
super().__init__(**kwargs) | ||
|
||
# Add the model specific parameters. | ||
self.add_parameter("x0", x0, **kwargs) | ||
self.add_parameter("x1", x1, **kwargs) | ||
self.add_parameter("c", c, **kwargs) | ||
|
||
# Load the data files. | ||
model_path = Path(model_dir) | ||
self._m0_model = BicubicInterpolator.from_grid_file( | ||
model_path / m0_filename, | ||
scale_factor=1e-12, | ||
) | ||
self._m1_model = BicubicInterpolator.from_grid_file( | ||
model_path / m1_filename, | ||
scale_factor=1e-12, | ||
) | ||
|
||
# Use the default color correction values. | ||
self._colorlaw = SALT2ColorLaw.from_file(model_path / cl_filename) | ||
|
||
def compute_flux(self, phase, wavelengths, graph_state, **kwargs): | ||
"""Draw effect-free observations for this object. | ||
Parameters | ||
---------- | ||
phase : `numpy.ndarray` | ||
A length T array of rest frame timestamps. | ||
wavelengths : `numpy.ndarray`, optional | ||
A length N array of wavelengths (in angstroms). | ||
graph_state : `GraphState` | ||
An object mapping graph parameters to their values. | ||
**kwargs : `dict`, optional | ||
Any additional keyword arguments. | ||
Returns | ||
------- | ||
flux_density : `numpy.ndarray` | ||
A length T x N matrix of SED values (in nJy). | ||
""" | ||
m0_vals = self._m0_model(phase, wavelengths) | ||
m1_vals = self._m1_model(phase, wavelengths) | ||
params = self.get_local_params(graph_state) | ||
|
||
flux_density = ( | ||
params["x0"] | ||
* (m0_vals + params["x1"] * m1_vals) | ||
* 10.0 ** (-0.4 * self._colorlaw.apply(wavelengths) * params["c"]) | ||
) | ||
return flux_density |
Empty file.
Oops, something went wrong.