-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
7d14709
commit 0d8515f
Showing
3 changed files
with
125 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -17,6 +17,7 @@ dynamic = ["version"] | |
requires-python = ">=3.9" | ||
dependencies = [ | ||
"numpy", | ||
"scipy", | ||
] | ||
|
||
[project.urls] | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
import types | ||
|
||
import numpy as np | ||
|
||
from scipy.interpolate import RectBivariateSpline | ||
|
||
from tdastro.base_models import PhysicalModel | ||
|
||
|
||
class SplineModel(PhysicalModel): | ||
"""A time series model define by sample points where the intermediate | ||
points are fit by a spline. Based on sncosmo's TimeSeriesSource: | ||
https://github.com/sncosmo/sncosmo/blob/v2.10.1/sncosmo/models.py | ||
Attributes | ||
---------- | ||
_times : `numpy.ndarray` | ||
A length T array containing the times at which the data was sampled. | ||
_wavelengths : `numpy.ndarray` | ||
A length W array containing the wavelengths at which the data was sampled. | ||
_spline : `RectBivariateSpline` | ||
The spline object for predicting the flux from a given (time, wavelength). | ||
name : `str` | ||
The name of the model being used. | ||
amplitude : `float` | ||
A unitless scaling parameter for the flux density values. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
times, | ||
wavelengths, | ||
flux, | ||
amplitude=1.0, | ||
time_degree=3, | ||
wave_degree=3, | ||
name=None, | ||
**kwargs, | ||
): | ||
"""Create the SplineModel from a grid of (timestep, wavelength, flux) points. | ||
Parameters | ||
---------- | ||
times : `numpy.ndarray` | ||
A length T array containing the times at which the data was sampled. | ||
wavelengths : `numpy.ndarray` | ||
A length W array containing the wavelengths at which the data was sampled. | ||
flux : `numpy.ndarray` | ||
A shape (T, W) matrix with flux values for each pair of time and wavelength. | ||
Fluxes provided in erg / s / cm^2 / Angstrom. | ||
amplitude : `float` | ||
A unitless scaling parameter for the flux density values. Default = 1.0 | ||
time_degree : `int` | ||
The polynomial degree to use in the time dimension. | ||
wave_degree : `int` | ||
The polynomial degree to use in the wavelength dimension. | ||
name : `str`, optional | ||
The name of the model. | ||
**kwargs : `dict`, optional | ||
Any additional keyword arguments. | ||
""" | ||
super().__init__(**kwargs) | ||
|
||
self.name = name | ||
self.amplitude = amplitude | ||
self._times = times | ||
self._wavelengths = wavelengths | ||
self._spline = RectBivariateSpline(times, wavelengths, flux, kx=time_degree, ky=wave_degree) | ||
|
||
def _evaluate(self, times, wavelengths): | ||
"""Draw effect-free observations for this object. | ||
Parameters | ||
---------- | ||
times : `numpy.ndarray` | ||
A length N array of timestamps. | ||
wavelengths : `numpy.ndarray`, optional | ||
A length N array of wavelengths. | ||
**kwargs : `dict`, optional | ||
Any additional keyword arguments. | ||
Returns | ||
------- | ||
flux_density : `numpy.ndarray` | ||
A length N-array of flux densities. | ||
""" | ||
return self.amplitude * self._spline(times, wavelengths, grid=False) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
import numpy as np | ||
from tdastro.sources.spline_model import SplineModel | ||
|
||
|
||
def test_spline_model_flat() -> None: | ||
"""Test that we can sample and create a flat SplineModel object.""" | ||
times = np.linspace(1.0, 5.0, 20) | ||
wavelengths = np.linspace(100.0, 500.0, 25) | ||
fluxes = np.full((len(times), len(wavelengths)), 1.0) | ||
model = SplineModel(times, wavelengths, fluxes) | ||
|
||
test_times = np.array([0.0, 1.0, 1.0, 2.0, 2.0, 3.0, 3.0, 10.0]) | ||
test_waves = np.array([100.0, 200.0, 200.0, 0.0, 200.0, 1000.0, 200.0, 200.0]) | ||
|
||
values = model.evaluate(test_times, test_waves) | ||
expected = np.array([1.0] * len(test_times)) | ||
np.testing.assert_array_almost_equal(values, expected) | ||
|
||
model2 = SplineModel(times, wavelengths, fluxes, amplitude=5.0) | ||
values2 = model2.evaluate(test_times, test_waves) | ||
expected2 = np.array([5.0] * len(test_times)) | ||
np.testing.assert_array_almost_equal(values2, expected2) | ||
|
||
|
||
def test_spline_model_interesting() -> None: | ||
"""Test that we can sample and create a flat SplineModel object.""" | ||
times = np.array([1.0, 2.0, 3.0]) | ||
wavelengths = np.array([10.0, 20.0, 30.0]) | ||
fluxes = np.array([[1.0, 5.0, 1.0], [5.0, 10.0, 5.0], [1.0, 5.0, 3.0]]) | ||
model = SplineModel(times, wavelengths, fluxes, time_degree=1, wave_degree=1) | ||
|
||
test_times = np.array([1.0, 2.0, 3.0, 1.0, 2.0, 3.0, 1.0, 2.0, 3.0, 1.5, 2.0]) | ||
test_waves = np.array([10.0, 10.0, 10.0, 20.0, 20.0, 20.0, 30.0, 30.0, 30.0, 10.0, 15.0]) | ||
values = model.evaluate(test_times, test_waves) | ||
|
||
expected = np.array([1.0, 5.0, 1.0, 5.0, 10.0, 5.0, 1.0, 5.0, 3.0, 3.0, 7.5]) | ||
np.testing.assert_array_almost_equal(values, expected) |