Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Initial Implementation of a JAX Salt2 model #191

Merged
merged 2 commits into from
Nov 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
134 changes: 134 additions & 0 deletions src/tdastro/astro_utils/salt2_color_law.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
"""The ColorLaw used by SALT models as defined by in (Guy J., 2007)

It is adapted from sncosmo's SALT2ColorLaw class (but implemented in JAX):
https://github.com/sncosmo/sncosmo/blob/v2.10.1/sncosmo/salt2utils.pyx
"""

import jax.numpy as jnp
import numpy as np

# Constants used in SALT2ColorLaw computations (from:
# https://github.com/sncosmo/sncosmo/blob/v2.10.1/sncosmo/salt2utils.pyx)
_SALT2CL_B = 4302.57 # B-band-ish wavelength
_SALT2CL_V = 5428.55 # V-band-ish wavelength
_WAVESCALE = 1.0 / (_SALT2CL_V - _SALT2CL_B)


class SALT2ColorLaw:
"""An object that applies the color law to the given wavelengths.

Parameters
----------
wave_min : float
The minimum wavelength (in angstroms)
wave_max : float
The maximum wavelength (in angstroms)
coeffs : list, numpy array, or jax array
The <= 6 coefficients of the polynomial to use.

Attributes
----------
coeffs : JAX array
The final 7 coefficients to use (based on the <= 6 provided in the parameters).
scaled_wave_min : float
The minimum wavelength shifted and scaled.
scaled_wave_max : float
The maximum wavelength shifted and scaled.
value_at_min : float
The value of the polynomial at wave_min.
value_at_max : float
The value of the polynomial at wave_max.
_exponents : JAX array
A precomputed array of the exponents to use.
"""

def __init__(self, wave_min, wave_max, coeffs):
# Create the internal coefficient array. The new first entry is 1.0 minus the
# sum of the given entries. The first six given entries are then listed.
coeffs = np.array(coeffs)
num_coeffs = min(len(coeffs), 6)

padded_coeffs = np.zeros(7)
padded_coeffs[1 : (num_coeffs + 1)] = np.array(coeffs)[0:num_coeffs]
padded_coeffs[0] = 1.0 - np.sum(padded_coeffs)
self.coeffs = jnp.asarray(padded_coeffs)

# Compute the bounds for the wavelengths and the value of the polynomial at both bounds.
self.scaled_wave_min = (wave_min - _SALT2CL_B) * _WAVESCALE
self.scaled_wave_max = (wave_max - _SALT2CL_B) * _WAVESCALE

self.exponents = jnp.arange(1, 8, 1)
self.value_at_min = jnp.sum(self.coeffs * jnp.power(self.scaled_wave_min, self.exponents))
self.value_at_max = jnp.sum(self.coeffs * jnp.power(self.scaled_wave_max, self.exponents))

# Precompute the polynomials derivative at the min and max wavelength.
dcoeffs = jnp.arange(2, 8, 1) * self.coeffs[1:7]
dexponents = jnp.arange(1, 7, 1)
self.deriv_at_min = jnp.sum(dcoeffs * jnp.power(self.scaled_wave_min, dexponents)) + self.coeffs[0]
self.deriv_at_max = jnp.sum(dcoeffs * jnp.power(self.scaled_wave_max, dexponents)) + self.coeffs[0]

@classmethod
def from_file(cls, filename):
"""Create the SALT2ColorLaw object from data in a file.

Parameters
----------
filename : str
The name of the file to load.
"""
with open(filename, mode="r") as f:
data = f.read().split()

# The first line holds the number of coefficients N and the next N lines
# each hold a single coefficient.
num_coeff = int(data[0])
coeffs = np.array(data[1 : (1 + num_coeff)], dtype=float)

# The rest of the lines (if any) are meta-data with a label and value on each line.
wave_min = 3000.0
wave_max = 7000.0
for i in range(1 + num_coeff, len(data), 2):
if "min_lambda" in data[i]:
wave_min = float(data[i + 1])
elif "max_lambda" in data[i]:
wave_max = float(data[i + 1])
elif "version" in data[i]:
version = int(data[i + 1])
if version != 1:
raise RuntimeError(f"Unsupported version {version}.")

return SALT2ColorLaw(wave_min, wave_max, coeffs)

def apply(self, wavelengths):
"""Apply the color law to the given wavelengths.

Parameters
----------
wavelengths : array
The wavelengths in angstroms.
"""
num_waves = len(wavelengths)
shifted_wave = (jnp.asarray(wavelengths) - _SALT2CL_B) * _WAVESCALE

# Compute the three cases of interest.
# 1) If the shifted wave is past the lower bound, extrapolate a value based on
# the value and derivative at that bound.
below = self.value_at_min + self.deriv_at_min * (shifted_wave - self.scaled_wave_min)
# 2) If the shifted wave is past the upper bound, extrapolate a value based on
# the value and derivative at that bound.
above = self.value_at_max + self.deriv_at_max * (shifted_wave - self.scaled_wave_max)
# 3) If the shifted value is in the middle, use the polynomial.
wave_T = jnp.reshape(shifted_wave, (num_waves, 1))
coeffs_all = jnp.tile(self.coeffs, (num_waves, 1))
middle = jnp.sum((coeffs_all * jnp.power(wave_T, self.exponents)).T, axis=0)

result = -jnp.where(
shifted_wave < self.scaled_wave_min,
below,
jnp.where(
shifted_wave > self.scaled_wave_max,
above,
middle,
),
)
return result
121 changes: 121 additions & 0 deletions src/tdastro/sources/salt2_jax.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
from pathlib import Path

from tdastro.astro_utils.salt2_color_law import SALT2ColorLaw
from tdastro.sources.physical_model import PhysicalModel
from tdastro.utils.bicubic_interp import BicubicInterpolator


class SALT2JaxModel(PhysicalModel):
"""A SALT2 model implemented with JAX for it can use auto-differentiation.

The model is defined in (Guy J., 2007) as:

flux(time, wave) = x0 * [M0(time, wave) + x1 * M1(time, wave)] * exp(c * CL(wave))

where x0, x1, and c are given parameters, M0 is the average spectral sequence,
M1 is the first compoment to describe variability, and CL is the average color
correction law.

We use the formulation in sncosmo where CL is defined such that:

flux(time, wave) = x0 * [M0(time, wave) + x1 * M1(time, wave)] * 10 ** (-0.4 * c * CL(wave))

This class is based on the sncosmo implementation at:
https://github.com/sncosmo/sncosmo/blob/v2.10.1/sncosmo/models.py
The wrapped sncosmo version in sncosmo_models.py is faster and should be used
when auto-differentiation is not needed.

Attributes
----------
_m0_model : BicubicInterpolator
The interpolator for the m0 parameter.
_m1_model : BicubicInterpolator
The interpolator for the m1 parameter.
_colorlaw : SALT2ColorLaw
The data to apply the color law.

Parameters
----------
x0 : parameter
The SALT2 x0 parameter.
x1 : parameter
The SALT2 x1 parameter.
c : parameter
The SALT2 c parameter.
model_dir : `str`
The path for the model file directory.
Default: ""
m0_filename : `str`
The file name for the m0 model component.
Default: "salt2_template_0.dat"
m1_filename : `str`
The file name for the m1 model component.
Default: "salt2_template_1.dat"
cl_filename : `str`
The file name of the color law correction coefficients.
Default: "salt2_color_correction.dat",
**kwargs : `dict`, optional
Any additional keyword arguments.
"""

def __init__(
self,
x0=None,
x1=None,
c=None,
model_dir="",
m0_filename="salt2_template_0.dat",
m1_filename="salt2_template_1.dat",
cl_filename="salt2_color_correction.dat",
**kwargs,
):
super().__init__(**kwargs)

# Add the model specific parameters.
self.add_parameter("x0", x0, **kwargs)
self.add_parameter("x1", x1, **kwargs)
self.add_parameter("c", c, **kwargs)

# Load the data files.
model_path = Path(model_dir)
self._m0_model = BicubicInterpolator.from_grid_file(
model_path / m0_filename,
scale_factor=1e-12,
)
self._m1_model = BicubicInterpolator.from_grid_file(
model_path / m1_filename,
scale_factor=1e-12,
)

# Use the default color correction values.
self._colorlaw = SALT2ColorLaw.from_file(model_path / cl_filename)

def compute_flux(self, phase, wavelengths, graph_state, **kwargs):
"""Draw effect-free observations for this object.

Parameters
----------
phase : `numpy.ndarray`
A length T array of rest frame timestamps.
wavelengths : `numpy.ndarray`, optional
A length N array of wavelengths (in angstroms).
graph_state : `GraphState`
An object mapping graph parameters to their values.
**kwargs : `dict`, optional
Any additional keyword arguments.

Returns
-------
flux_density : `numpy.ndarray`
A length T x N matrix of SED values (in nJy).
"""
m0_vals = self._m0_model(phase, wavelengths)
m1_vals = self._m1_model(phase, wavelengths)
params = self.get_local_params(graph_state)

flux_density = (
params["x0"]
* (m0_vals + params["x1"] * m1_vals)
* 10.0 ** (-0.4 * self._colorlaw.apply(wavelengths) * params["c"])
)
return flux_density
Empty file added src/tdastro/utils/__init__.py
Empty file.
Loading