Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a spline-based model #12

Merged
merged 5 commits into from
Jun 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ dynamic = ["version"]
requires-python = ">=3.9"
dependencies = [
"numpy",
"scipy",
]

[project.urls]
Expand Down
90 changes: 90 additions & 0 deletions src/tdastro/sources/spline_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
"""The SplineModel represents SED functions as a two dimensional grid
of (time, wavelength) -> flux value that is interpolated using a 2D spline.

It is adapted from sncosmo's TimeSeriesSource model:
https://github.com/sncosmo/sncosmo/blob/v2.10.1/sncosmo/models.py
"""

from scipy.interpolate import RectBivariateSpline

from tdastro.base_models import PhysicalModel


class SplineModel(PhysicalModel):
"""A time series model defined by sample points where the intermediate
points are fit by a spline. Based on sncosmo's TimeSeriesSource:
https://github.com/sncosmo/sncosmo/blob/v2.10.1/sncosmo/models.py

Attributes
----------
_times : `numpy.ndarray`
A length T array containing the times at which the data was sampled.
_wavelengths : `numpy.ndarray`
A length W array containing the wavelengths at which the data was sampled.
_spline : `RectBivariateSpline`
The spline object for predicting the flux from a given (time, wavelength).
name : `str`
The name of the model being used.
amplitude : `float`
A unitless scaling parameter for the flux density values.
"""

def __init__(
self,
times,
wavelengths,
flux,
amplitude=1.0,
time_degree=3,
wave_degree=3,
name=None,
**kwargs,
):
"""Create the SplineModel from a grid of (timestep, wavelength, flux) points.

Parameters
----------
times : `numpy.ndarray`
A length T array containing the times at which the data was sampled.
wavelengths : `numpy.ndarray`
A length W array containing the wavelengths at which the data was sampled.
flux : `numpy.ndarray`
A shape (T, W) matrix with flux values for each pair of time and wavelength.
Fluxes provided in erg / s / cm^2 / Angstrom.
amplitude : `float`
A unitless scaling parameter for the flux density values. Default = 1.0
time_degree : `int`
The polynomial degree to use in the time dimension.
wave_degree : `int`
The polynomial degree to use in the wavelength dimension.
name : `str`, optional
The name of the model.
**kwargs : `dict`, optional
Any additional keyword arguments.
"""
super().__init__(**kwargs)

self.name = name
self.amplitude = amplitude
self._times = times
self._wavelengths = wavelengths
self._spline = RectBivariateSpline(times, wavelengths, flux, kx=time_degree, ky=wave_degree)

def _evaluate(self, times, wavelengths, **kwargs):
"""Draw effect-free observations for this object.

Parameters
----------
times : `numpy.ndarray`
A length T array of timestamps.
wavelengths : `numpy.ndarray`, optional
A length N array of wavelengths.
**kwargs : `dict`, optional
Any additional keyword arguments.

Returns
-------
flux_density : `numpy.ndarray`
A length T x N matrix of SED values.
"""
return self.amplitude * self._spline(times, wavelengths, grid=True)
42 changes: 42 additions & 0 deletions tests/tdastro/sources/test_spline_source.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import numpy as np
from tdastro.sources.spline_model import SplineModel


def test_spline_model_flat() -> None:
"""Test that we can sample and create a flat SplineModel object."""
times = np.linspace(1.0, 5.0, 20)
wavelengths = np.linspace(100.0, 500.0, 25)
fluxes = np.full((len(times), len(wavelengths)), 1.0)
model = SplineModel(times, wavelengths, fluxes)

test_times = np.array([0.0, 1.0, 2.0, 3.0, 10.0])
test_waves = np.array([0.0, 100.0, 200.0, 1000.0])

values = model.evaluate(test_times, test_waves)
assert values.shape == (5, 4)
expected = np.full_like(values, 1.0)
np.testing.assert_array_almost_equal(values, expected)

model2 = SplineModel(times, wavelengths, fluxes, amplitude=5.0)
values2 = model2.evaluate(test_times, test_waves)
assert values2.shape == (5, 4)
expected2 = np.full_like(values2, 5.0)
np.testing.assert_array_almost_equal(values2, expected2)


def test_spline_model_interesting() -> None:
"""Test that we can sample and create a flat SplineModel object."""
times = np.array([1.0, 2.0, 3.0])
wavelengths = np.array([10.0, 20.0, 30.0])
fluxes = np.array([[1.0, 5.0, 1.0], [5.0, 10.0, 5.0], [1.0, 5.0, 3.0]])
model = SplineModel(times, wavelengths, fluxes, time_degree=1, wave_degree=1)

test_times = np.array([1.0, 1.5, 2.0, 3.0])
test_waves = np.array([10.0, 15.0, 20.0, 30.0])
values = model.evaluate(test_times, test_waves)
assert values.shape == (4, 4)

expected = np.array(
[[1.0, 3.0, 5.0, 1.0], [3.0, 5.25, 7.5, 3.0], [5.0, 7.5, 10.0, 5.0], [1.0, 3.0, 5.0, 3.0]]
)
np.testing.assert_array_almost_equal(values, expected)
Loading