Skip to content

Commit

Permalink
Initial Implementation of a JAX Salt2 model
Browse files Browse the repository at this point in the history
jeremykubica committed Nov 21, 2024
1 parent 291c03c commit c0744af
Showing 21 changed files with 55,027 additions and 0 deletions.
134 changes: 134 additions & 0 deletions src/tdastro/astro_utils/salt2_color_law.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
"""The ColorLaw used by SALT models as defined by in (Guy J., 2007)
It is adapted from sncosmo's SALT2ColorLaw class (but implemented in JAX):
https://github.com/sncosmo/sncosmo/blob/v2.10.1/sncosmo/salt2utils.pyx
"""

import jax.numpy as jnp
import numpy as np

# Constants used in SALT2ColorLaw computations (from:
# https://github.com/sncosmo/sncosmo/blob/v2.10.1/sncosmo/salt2utils.pyx)
_SALT2CL_B = 4302.57 # B-band-ish wavelength
_SALT2CL_V = 5428.55 # V-band-ish wavelength
_WAVESCALE = 1.0 / (_SALT2CL_V - _SALT2CL_B)


class SALT2ColorLaw:
"""An object that applies the color law to the given wavelengths.
Parameters
----------
wave_min : float
The minimum wavelength (in angstroms)
wave_max : float
The maximum wavelength (in angstroms)
coeffs : list, numpy array, or jax array
The <= 6 coefficients of the polynomial to use.
Attributes
----------
coeffs : JAX array
The final 7 coefficients to use (based on the <= 6 provided in the parameters).
scaled_wave_min : float
The minimum wavelength shifted and scaled.
scaled_wave_max : float
The maximum wavelength shifted and scaled.
value_at_min : float
The value of the polynomial at wave_min.
value_at_max : float
The value of the polynomial at wave_max.
_exponents : JAX array
A precomputed array of the exponents to use.
"""

def __init__(self, wave_min, wave_max, coeffs):
# Create the internal coefficient array. The new first entry is 1.0 minus the
# sum of the given entries. The first six given entries are then listed.
coeffs = np.array(coeffs)
num_coeffs = min(len(coeffs), 6)

padded_coeffs = np.zeros(7)
padded_coeffs[1 : (num_coeffs + 1)] = np.array(coeffs)[0:num_coeffs]
padded_coeffs[0] = 1.0 - np.sum(padded_coeffs)
self.coeffs = jnp.asarray(padded_coeffs)

# Compute the bounds for the wavelengths and the value of the polynomial at both bounds.
self.scaled_wave_min = (wave_min - _SALT2CL_B) * _WAVESCALE
self.scaled_wave_max = (wave_max - _SALT2CL_B) * _WAVESCALE

self.exponents = jnp.arange(1, 8, 1)
self.value_at_min = jnp.sum(self.coeffs * jnp.power(self.scaled_wave_min, self.exponents))
self.value_at_max = jnp.sum(self.coeffs * jnp.power(self.scaled_wave_max, self.exponents))

# Precompute the polynomials derivative at the min and max wavelength.
dcoeffs = jnp.arange(2, 8, 1) * self.coeffs[1:7]
dexponents = jnp.arange(1, 7, 1)
self.deriv_at_min = jnp.sum(dcoeffs * jnp.power(self.scaled_wave_min, dexponents)) + self.coeffs[0]
self.deriv_at_max = jnp.sum(dcoeffs * jnp.power(self.scaled_wave_max, dexponents)) + self.coeffs[0]

@classmethod
def from_file(cls, filename):
"""Create the SALT2ColorLaw object from data in a file.
Parameters
----------
filename : str
The name of the file to load.
"""
with open(filename, mode="r") as f:
data = f.read().split()

# The first line holds the number of coefficients N and the next N lines
# each hold a single coefficient.
num_coeff = int(data[0])
coeffs = np.array(data[1 : (1 + num_coeff)], dtype=float)

# The rest of the lines (if any) are meta-data with a label and value on each line.
wave_min = 3000.0
wave_max = 7000.0
for i in range(1 + num_coeff, len(data), 2):
if "min_lambda" in data[i]:
wave_min = float(data[i + 1])
elif "max_lambda" in data[i]:
wave_max = float(data[i + 1])
elif "version" in data[i]:
version = int(data[i + 1])
if version != 1:
raise RuntimeError(f"Unsupported version {version}.")

return SALT2ColorLaw(wave_min, wave_max, coeffs)

def apply(self, wavelengths):
"""Apply the color law to the given wavelengths.
Parameters
----------
wavelengths : array
The wavelengths in angstroms.
"""
num_waves = len(wavelengths)
shifted_wave = (jnp.asarray(wavelengths) - _SALT2CL_B) * _WAVESCALE

# Compute the three cases of interest.
# 1) If the shifted wave is past the lower bound, extrapolate a value based on
# the value and derivative at that bound.
below = self.value_at_min + self.deriv_at_min * (shifted_wave - self.scaled_wave_min)
# 2) If the shifted wave is past the upper bound, extrapolate a value based on
# the value and derivative at that bound.
above = self.value_at_max + self.deriv_at_max * (shifted_wave - self.scaled_wave_max)
# 3) If the shifted value is in the middle, use the polynomial.
wave_T = jnp.reshape(shifted_wave, (num_waves, 1))
coeffs_all = jnp.tile(self.coeffs, (num_waves, 1))
middle = jnp.sum((coeffs_all * jnp.power(wave_T, self.exponents)).T, axis=0)

result = -jnp.where(
shifted_wave < self.scaled_wave_min,
below,
jnp.where(
shifted_wave > self.scaled_wave_max,
above,
middle,
),
)
return result
121 changes: 121 additions & 0 deletions src/tdastro/sources/salt2_jax.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
from pathlib import Path

from tdastro.astro_utils.salt2_color_law import SALT2ColorLaw
from tdastro.sources.physical_model import PhysicalModel
from tdastro.utils.bicubic_interp import BicubicInterpolator


class SALT2JaxModel(PhysicalModel):
"""A SALT2 model implemented with JAX for it can use auto-differentiation.
The model is defined in (Guy J., 2007) as:
flux(time, wave) = x0 * [M0(time, wave) + x1 * M1(time, wave)] * exp(c * CL(wave))
where x0, x1, and c are given parameters, M0 is the average spectral sequence,
M1 is the first compoment to describe variability, and CL is the average color
correction law.
We use the formulation in sncosmo where CL is defined such that:
flux(time, wave) = x0 * [M0(time, wave) + x1 * M1(time, wave)] * 10 ** (-0.4 * c * CL(wave))
This class is based on the sncosmo implementation at:
https://github.com/sncosmo/sncosmo/blob/v2.10.1/sncosmo/models.py
The wrapped sncosmo version in sncosmo_models.py is faster and should be used
when auto-differentiation is not needed.
Attributes
----------
_m0_model : BicubicInterpolator
The interpolator for the m0 parameter.
_m1_model : BicubicInterpolator
The interpolator for the m1 parameter.
_colorlaw : SALT2ColorLaw
The data to apply the color law.
Parameters
----------
x0 : parameter
The SALT2 x0 parameter.
x1 : parameter
The SALT2 x1 parameter.
c : parameter
The SALT2 c parameter.
model_dir : `str`
The path for the model file directory.
Default: ""
m0_filename : `str`
The file name for the m0 model component.
Default: "salt2_template_0.dat"
m1_filename : `str`
The file name for the m1 model component.
Default: "salt2_template_1.dat"
cl_filename : `str`
The file name of the color law correction coefficients.
Default: "salt2_color_correction.dat",
**kwargs : `dict`, optional
Any additional keyword arguments.
"""

def __init__(
self,
x0=None,
x1=None,
c=None,
model_dir="",
m0_filename="salt2_template_0.dat",
m1_filename="salt2_template_1.dat",
cl_filename="salt2_color_correction.dat",
**kwargs,
):
super().__init__(**kwargs)

# Add the model specific parameters.
self.add_parameter("x0", x0, **kwargs)
self.add_parameter("x1", x1, **kwargs)
self.add_parameter("c", c, **kwargs)

# Load the data files.
model_path = Path(model_dir)
self._m0_model = BicubicInterpolator.from_grid_file(
model_path / m0_filename,
scale_factor=1e-12,
)
self._m1_model = BicubicInterpolator.from_grid_file(
model_path / m1_filename,
scale_factor=1e-12,
)

# Use the default color correction values.
self._colorlaw = SALT2ColorLaw.from_file(model_path / cl_filename)

def compute_flux(self, phase, wavelengths, graph_state, **kwargs):
"""Draw effect-free observations for this object.
Parameters
----------
phase : `numpy.ndarray`
A length T array of rest frame timestamps.
wavelengths : `numpy.ndarray`, optional
A length N array of wavelengths (in angstroms).
graph_state : `GraphState`
An object mapping graph parameters to their values.
**kwargs : `dict`, optional
Any additional keyword arguments.
Returns
-------
flux_density : `numpy.ndarray`
A length T x N matrix of SED values (in nJy).
"""
m0_vals = self._m0_model(phase, wavelengths)
m1_vals = self._m1_model(phase, wavelengths)
params = self.get_local_params(graph_state)

flux_density = (
params["x0"]
* (m0_vals + params["x1"] * m1_vals)
* 10.0 ** (-0.4 * self._colorlaw.apply(wavelengths) * params["c"])
)
return flux_density
Empty file added src/tdastro/utils/__init__.py
Empty file.
Loading

0 comments on commit c0744af

Please sign in to comment.