Skip to content

Commit

Permalink
Make a basic spline model
Browse files Browse the repository at this point in the history
  • Loading branch information
jeremykubica committed Jun 21, 2024
1 parent 7d14709 commit 0d8515f
Show file tree
Hide file tree
Showing 3 changed files with 125 additions and 0 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ dynamic = ["version"]
requires-python = ">=3.9"
dependencies = [
"numpy",
"scipy",
]

[project.urls]
Expand Down
87 changes: 87 additions & 0 deletions src/tdastro/sources/spline_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
import types

import numpy as np

from scipy.interpolate import RectBivariateSpline

from tdastro.base_models import PhysicalModel


class SplineModel(PhysicalModel):
"""A time series model define by sample points where the intermediate
points are fit by a spline. Based on sncosmo's TimeSeriesSource:
https://github.com/sncosmo/sncosmo/blob/v2.10.1/sncosmo/models.py
Attributes
----------
_times : `numpy.ndarray`
A length T array containing the times at which the data was sampled.
_wavelengths : `numpy.ndarray`
A length W array containing the wavelengths at which the data was sampled.
_spline : `RectBivariateSpline`
The spline object for predicting the flux from a given (time, wavelength).
name : `str`
The name of the model being used.
amplitude : `float`
A unitless scaling parameter for the flux density values.
"""

def __init__(
self,
times,
wavelengths,
flux,
amplitude=1.0,
time_degree=3,
wave_degree=3,
name=None,
**kwargs,
):
"""Create the SplineModel from a grid of (timestep, wavelength, flux) points.
Parameters
----------
times : `numpy.ndarray`
A length T array containing the times at which the data was sampled.
wavelengths : `numpy.ndarray`
A length W array containing the wavelengths at which the data was sampled.
flux : `numpy.ndarray`
A shape (T, W) matrix with flux values for each pair of time and wavelength.
Fluxes provided in erg / s / cm^2 / Angstrom.
amplitude : `float`
A unitless scaling parameter for the flux density values. Default = 1.0
time_degree : `int`
The polynomial degree to use in the time dimension.
wave_degree : `int`
The polynomial degree to use in the wavelength dimension.
name : `str`, optional
The name of the model.
**kwargs : `dict`, optional
Any additional keyword arguments.
"""
super().__init__(**kwargs)

self.name = name
self.amplitude = amplitude
self._times = times
self._wavelengths = wavelengths
self._spline = RectBivariateSpline(times, wavelengths, flux, kx=time_degree, ky=wave_degree)

def _evaluate(self, times, wavelengths):
"""Draw effect-free observations for this object.
Parameters
----------
times : `numpy.ndarray`
A length N array of timestamps.
wavelengths : `numpy.ndarray`, optional
A length N array of wavelengths.
**kwargs : `dict`, optional
Any additional keyword arguments.
Returns
-------
flux_density : `numpy.ndarray`
A length N-array of flux densities.
"""
return self.amplitude * self._spline(times, wavelengths, grid=False)
37 changes: 37 additions & 0 deletions tests/tdastro/sources/test_spline_source.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import numpy as np
from tdastro.sources.spline_model import SplineModel


def test_spline_model_flat() -> None:
"""Test that we can sample and create a flat SplineModel object."""
times = np.linspace(1.0, 5.0, 20)
wavelengths = np.linspace(100.0, 500.0, 25)
fluxes = np.full((len(times), len(wavelengths)), 1.0)
model = SplineModel(times, wavelengths, fluxes)

test_times = np.array([0.0, 1.0, 1.0, 2.0, 2.0, 3.0, 3.0, 10.0])
test_waves = np.array([100.0, 200.0, 200.0, 0.0, 200.0, 1000.0, 200.0, 200.0])

values = model.evaluate(test_times, test_waves)
expected = np.array([1.0] * len(test_times))
np.testing.assert_array_almost_equal(values, expected)

model2 = SplineModel(times, wavelengths, fluxes, amplitude=5.0)
values2 = model2.evaluate(test_times, test_waves)
expected2 = np.array([5.0] * len(test_times))
np.testing.assert_array_almost_equal(values2, expected2)


def test_spline_model_interesting() -> None:
"""Test that we can sample and create a flat SplineModel object."""
times = np.array([1.0, 2.0, 3.0])
wavelengths = np.array([10.0, 20.0, 30.0])
fluxes = np.array([[1.0, 5.0, 1.0], [5.0, 10.0, 5.0], [1.0, 5.0, 3.0]])
model = SplineModel(times, wavelengths, fluxes, time_degree=1, wave_degree=1)

test_times = np.array([1.0, 2.0, 3.0, 1.0, 2.0, 3.0, 1.0, 2.0, 3.0, 1.5, 2.0])
test_waves = np.array([10.0, 10.0, 10.0, 20.0, 20.0, 20.0, 30.0, 30.0, 30.0, 10.0, 15.0])
values = model.evaluate(test_times, test_waves)

expected = np.array([1.0, 5.0, 1.0, 5.0, 10.0, 5.0, 1.0, 5.0, 3.0, 3.0, 7.5])
np.testing.assert_array_almost_equal(values, expected)

0 comments on commit 0d8515f

Please sign in to comment.