From 0a4d73cd966cdfacc0d31242ce928bdde96a09a1 Mon Sep 17 00:00:00 2001 From: Jeremy Kubica <104161096+jeremykubica@users.noreply.github.com> Date: Thu, 27 Jun 2024 09:21:16 -0400 Subject: [PATCH] Upgrade the spline model to produce a 2D array --- src/tdastro/sources/spline_model.py | 15 +++++++++------ tests/tdastro/sources/test_spline_source.py | 19 ++++++++++++------- 2 files changed, 21 insertions(+), 13 deletions(-) diff --git a/src/tdastro/sources/spline_model.py b/src/tdastro/sources/spline_model.py index d3ad0086..7c73cea0 100644 --- a/src/tdastro/sources/spline_model.py +++ b/src/tdastro/sources/spline_model.py @@ -1,6 +1,9 @@ -import types +"""The SplineModel represents SED functions as a two dimensional grid +of (time, wavelength) -> flux value that is interpolated using a 2D spline. -import numpy as np +It is adapted from sncosmo's TimeSeriesSource model: +https://github.com/sncosmo/sncosmo/blob/v2.10.1/sncosmo/models.py +""" from scipy.interpolate import RectBivariateSpline @@ -67,13 +70,13 @@ def __init__( self._wavelengths = wavelengths self._spline = RectBivariateSpline(times, wavelengths, flux, kx=time_degree, ky=wave_degree) - def _evaluate(self, times, wavelengths): + def _evaluate(self, times, wavelengths, **kwargs): """Draw effect-free observations for this object. Parameters ---------- times : `numpy.ndarray` - A length N array of timestamps. + A length T array of timestamps. wavelengths : `numpy.ndarray`, optional A length N array of wavelengths. **kwargs : `dict`, optional @@ -82,6 +85,6 @@ def _evaluate(self, times, wavelengths): Returns ------- flux_density : `numpy.ndarray` - A length N-array of flux densities. + A length T x N matrix of SED values. """ - return self.amplitude * self._spline(times, wavelengths, grid=False) + return self.amplitude * self._spline(times, wavelengths, grid=True) diff --git a/tests/tdastro/sources/test_spline_source.py b/tests/tdastro/sources/test_spline_source.py index 72ce1ac1..7160748a 100644 --- a/tests/tdastro/sources/test_spline_source.py +++ b/tests/tdastro/sources/test_spline_source.py @@ -9,16 +9,18 @@ def test_spline_model_flat() -> None: fluxes = np.full((len(times), len(wavelengths)), 1.0) model = SplineModel(times, wavelengths, fluxes) - test_times = np.array([0.0, 1.0, 1.0, 2.0, 2.0, 3.0, 3.0, 10.0]) - test_waves = np.array([100.0, 200.0, 200.0, 0.0, 200.0, 1000.0, 200.0, 200.0]) + test_times = np.array([0.0, 1.0, 2.0, 3.0, 10.0]) + test_waves = np.array([0.0, 100.0, 200.0, 1000.0]) values = model.evaluate(test_times, test_waves) - expected = np.array([1.0] * len(test_times)) + assert values.shape == (5, 4) + expected = np.full_like(values, 1.0) np.testing.assert_array_almost_equal(values, expected) model2 = SplineModel(times, wavelengths, fluxes, amplitude=5.0) values2 = model2.evaluate(test_times, test_waves) - expected2 = np.array([5.0] * len(test_times)) + assert values2.shape == (5, 4) + expected2 = np.full_like(values2, 5.0) np.testing.assert_array_almost_equal(values2, expected2) @@ -29,9 +31,12 @@ def test_spline_model_interesting() -> None: fluxes = np.array([[1.0, 5.0, 1.0], [5.0, 10.0, 5.0], [1.0, 5.0, 3.0]]) model = SplineModel(times, wavelengths, fluxes, time_degree=1, wave_degree=1) - test_times = np.array([1.0, 2.0, 3.0, 1.0, 2.0, 3.0, 1.0, 2.0, 3.0, 1.5, 2.0]) - test_waves = np.array([10.0, 10.0, 10.0, 20.0, 20.0, 20.0, 30.0, 30.0, 30.0, 10.0, 15.0]) + test_times = np.array([1.0, 1.5, 2.0, 3.0]) + test_waves = np.array([10.0, 15.0, 20.0, 30.0]) values = model.evaluate(test_times, test_waves) + assert values.shape == (4, 4) - expected = np.array([1.0, 5.0, 1.0, 5.0, 10.0, 5.0, 1.0, 5.0, 3.0, 3.0, 7.5]) + expected = np.array( + [[1.0, 3.0, 5.0, 1.0], [3.0, 5.25, 7.5, 3.0], [5.0, 7.5, 10.0, 5.0], [1.0, 3.0, 5.0, 3.0]] + ) np.testing.assert_array_almost_equal(values, expected)