diff --git a/pyproject.toml b/pyproject.toml index e53721c1..9fe3c509 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,6 +17,7 @@ dynamic = ["version"] requires-python = ">=3.9" dependencies = [ "numpy", + "scipy", ] [project.urls] diff --git a/src/tdastro/sources/spline_model.py b/src/tdastro/sources/spline_model.py new file mode 100644 index 00000000..a8c21a51 --- /dev/null +++ b/src/tdastro/sources/spline_model.py @@ -0,0 +1,90 @@ +"""The SplineModel represents SED functions as a two dimensional grid +of (time, wavelength) -> flux value that is interpolated using a 2D spline. + +It is adapted from sncosmo's TimeSeriesSource model: +https://github.com/sncosmo/sncosmo/blob/v2.10.1/sncosmo/models.py +""" + +from scipy.interpolate import RectBivariateSpline + +from tdastro.base_models import PhysicalModel + + +class SplineModel(PhysicalModel): + """A time series model defined by sample points where the intermediate + points are fit by a spline. Based on sncosmo's TimeSeriesSource: + https://github.com/sncosmo/sncosmo/blob/v2.10.1/sncosmo/models.py + + Attributes + ---------- + _times : `numpy.ndarray` + A length T array containing the times at which the data was sampled. + _wavelengths : `numpy.ndarray` + A length W array containing the wavelengths at which the data was sampled. + _spline : `RectBivariateSpline` + The spline object for predicting the flux from a given (time, wavelength). + name : `str` + The name of the model being used. + amplitude : `float` + A unitless scaling parameter for the flux density values. + """ + + def __init__( + self, + times, + wavelengths, + flux, + amplitude=1.0, + time_degree=3, + wave_degree=3, + name=None, + **kwargs, + ): + """Create the SplineModel from a grid of (timestep, wavelength, flux) points. + + Parameters + ---------- + times : `numpy.ndarray` + A length T array containing the times at which the data was sampled. + wavelengths : `numpy.ndarray` + A length W array containing the wavelengths at which the data was sampled. + flux : `numpy.ndarray` + A shape (T, W) matrix with flux values for each pair of time and wavelength. + Fluxes provided in erg / s / cm^2 / Angstrom. + amplitude : `float` + A unitless scaling parameter for the flux density values. Default = 1.0 + time_degree : `int` + The polynomial degree to use in the time dimension. + wave_degree : `int` + The polynomial degree to use in the wavelength dimension. + name : `str`, optional + The name of the model. + **kwargs : `dict`, optional + Any additional keyword arguments. + """ + super().__init__(**kwargs) + + self.name = name + self.amplitude = amplitude + self._times = times + self._wavelengths = wavelengths + self._spline = RectBivariateSpline(times, wavelengths, flux, kx=time_degree, ky=wave_degree) + + def _evaluate(self, times, wavelengths, **kwargs): + """Draw effect-free observations for this object. + + Parameters + ---------- + times : `numpy.ndarray` + A length T array of timestamps. + wavelengths : `numpy.ndarray`, optional + A length N array of wavelengths. + **kwargs : `dict`, optional + Any additional keyword arguments. + + Returns + ------- + flux_density : `numpy.ndarray` + A length T x N matrix of SED values. + """ + return self.amplitude * self._spline(times, wavelengths, grid=True) diff --git a/tests/tdastro/sources/test_spline_source.py b/tests/tdastro/sources/test_spline_source.py new file mode 100644 index 00000000..7160748a --- /dev/null +++ b/tests/tdastro/sources/test_spline_source.py @@ -0,0 +1,42 @@ +import numpy as np +from tdastro.sources.spline_model import SplineModel + + +def test_spline_model_flat() -> None: + """Test that we can sample and create a flat SplineModel object.""" + times = np.linspace(1.0, 5.0, 20) + wavelengths = np.linspace(100.0, 500.0, 25) + fluxes = np.full((len(times), len(wavelengths)), 1.0) + model = SplineModel(times, wavelengths, fluxes) + + test_times = np.array([0.0, 1.0, 2.0, 3.0, 10.0]) + test_waves = np.array([0.0, 100.0, 200.0, 1000.0]) + + values = model.evaluate(test_times, test_waves) + assert values.shape == (5, 4) + expected = np.full_like(values, 1.0) + np.testing.assert_array_almost_equal(values, expected) + + model2 = SplineModel(times, wavelengths, fluxes, amplitude=5.0) + values2 = model2.evaluate(test_times, test_waves) + assert values2.shape == (5, 4) + expected2 = np.full_like(values2, 5.0) + np.testing.assert_array_almost_equal(values2, expected2) + + +def test_spline_model_interesting() -> None: + """Test that we can sample and create a flat SplineModel object.""" + times = np.array([1.0, 2.0, 3.0]) + wavelengths = np.array([10.0, 20.0, 30.0]) + fluxes = np.array([[1.0, 5.0, 1.0], [5.0, 10.0, 5.0], [1.0, 5.0, 3.0]]) + model = SplineModel(times, wavelengths, fluxes, time_degree=1, wave_degree=1) + + test_times = np.array([1.0, 1.5, 2.0, 3.0]) + test_waves = np.array([10.0, 15.0, 20.0, 30.0]) + values = model.evaluate(test_times, test_waves) + assert values.shape == (4, 4) + + expected = np.array( + [[1.0, 3.0, 5.0, 1.0], [3.0, 5.25, 7.5, 3.0], [5.0, 7.5, 10.0, 5.0], [1.0, 3.0, 5.0, 3.0]] + ) + np.testing.assert_array_almost_equal(values, expected)