Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a spline-based model #12

Merged
merged 5 commits into from
Jun 27, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
Make a basic spline model
  • Loading branch information
jeremykubica committed Jun 21, 2024
commit 0d8515f7d802e96d5122f2c4df74dc92352f3bd4
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ dynamic = ["version"]
requires-python = ">=3.9"
dependencies = [
"numpy",
"scipy",
]

[project.urls]
Expand Down
87 changes: 87 additions & 0 deletions src/tdastro/sources/spline_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
import types

import numpy as np

from scipy.interpolate import RectBivariateSpline

from tdastro.base_models import PhysicalModel


class SplineModel(PhysicalModel):
"""A time series model define by sample points where the intermediate
jeremykubica marked this conversation as resolved.
Show resolved Hide resolved
points are fit by a spline. Based on sncosmo's TimeSeriesSource:
https://github.com/sncosmo/sncosmo/blob/v2.10.1/sncosmo/models.py

Attributes
----------
_times : `numpy.ndarray`
A length T array containing the times at which the data was sampled.
_wavelengths : `numpy.ndarray`
A length W array containing the wavelengths at which the data was sampled.
_spline : `RectBivariateSpline`
The spline object for predicting the flux from a given (time, wavelength).
name : `str`
The name of the model being used.
amplitude : `float`
A unitless scaling parameter for the flux density values.
"""

def __init__(
self,
times,
wavelengths,
flux,
amplitude=1.0,
time_degree=3,
wave_degree=3,
name=None,
**kwargs,
):
"""Create the SplineModel from a grid of (timestep, wavelength, flux) points.

Parameters
----------
times : `numpy.ndarray`
A length T array containing the times at which the data was sampled.
wavelengths : `numpy.ndarray`
A length W array containing the wavelengths at which the data was sampled.
flux : `numpy.ndarray`
A shape (T, W) matrix with flux values for each pair of time and wavelength.
Fluxes provided in erg / s / cm^2 / Angstrom.
amplitude : `float`
A unitless scaling parameter for the flux density values. Default = 1.0
time_degree : `int`
The polynomial degree to use in the time dimension.
wave_degree : `int`
The polynomial degree to use in the wavelength dimension.
name : `str`, optional
The name of the model.
**kwargs : `dict`, optional
Any additional keyword arguments.
"""
super().__init__(**kwargs)

self.name = name
self.amplitude = amplitude
self._times = times
self._wavelengths = wavelengths
self._spline = RectBivariateSpline(times, wavelengths, flux, kx=time_degree, ky=wave_degree)

def _evaluate(self, times, wavelengths):
"""Draw effect-free observations for this object.

Parameters
----------
times : `numpy.ndarray`
A length N array of timestamps.
wavelengths : `numpy.ndarray`, optional
A length N array of wavelengths.
**kwargs : `dict`, optional
Any additional keyword arguments.

Returns
-------
flux_density : `numpy.ndarray`
A length N-array of flux densities.
"""
return self.amplitude * self._spline(times, wavelengths, grid=False)
37 changes: 37 additions & 0 deletions tests/tdastro/sources/test_spline_source.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import numpy as np
from tdastro.sources.spline_model import SplineModel


def test_spline_model_flat() -> None:
"""Test that we can sample and create a flat SplineModel object."""
times = np.linspace(1.0, 5.0, 20)
wavelengths = np.linspace(100.0, 500.0, 25)
fluxes = np.full((len(times), len(wavelengths)), 1.0)
model = SplineModel(times, wavelengths, fluxes)

test_times = np.array([0.0, 1.0, 1.0, 2.0, 2.0, 3.0, 3.0, 10.0])
test_waves = np.array([100.0, 200.0, 200.0, 0.0, 200.0, 1000.0, 200.0, 200.0])

values = model.evaluate(test_times, test_waves)
expected = np.array([1.0] * len(test_times))
np.testing.assert_array_almost_equal(values, expected)

model2 = SplineModel(times, wavelengths, fluxes, amplitude=5.0)
values2 = model2.evaluate(test_times, test_waves)
expected2 = np.array([5.0] * len(test_times))
np.testing.assert_array_almost_equal(values2, expected2)


def test_spline_model_interesting() -> None:
"""Test that we can sample and create a flat SplineModel object."""
times = np.array([1.0, 2.0, 3.0])
wavelengths = np.array([10.0, 20.0, 30.0])
fluxes = np.array([[1.0, 5.0, 1.0], [5.0, 10.0, 5.0], [1.0, 5.0, 3.0]])
model = SplineModel(times, wavelengths, fluxes, time_degree=1, wave_degree=1)

test_times = np.array([1.0, 2.0, 3.0, 1.0, 2.0, 3.0, 1.0, 2.0, 3.0, 1.5, 2.0])
test_waves = np.array([10.0, 10.0, 10.0, 20.0, 20.0, 20.0, 30.0, 30.0, 30.0, 10.0, 15.0])
values = model.evaluate(test_times, test_waves)

expected = np.array([1.0, 5.0, 1.0, 5.0, 10.0, 5.0, 1.0, 5.0, 3.0, 3.0, 7.5])
np.testing.assert_array_almost_equal(values, expected)