Skip to content

Commit

Permalink
Merge branch 'main' into more_complex_model
Browse files Browse the repository at this point in the history
  • Loading branch information
jeremykubica committed Jun 27, 2024
2 parents 0a4d73c + a046276 commit af8852a
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 16 deletions.
16 changes: 8 additions & 8 deletions src/tdastro/base_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,13 +69,13 @@ def add_effect(self, effect):

self.effects.append(effect)

def _evaluate(self, times, wavelengths=None, **kwargs):
def _evaluate(self, times, wavelengths, **kwargs):
"""Draw effect-free observations for this object.
Parameters
----------
times : `numpy.ndarray`
A length N array of timestamps.
A length T array of timestamps.
wavelengths : `numpy.ndarray`, optional
A length N array of wavelengths.
**kwargs : `dict`, optional
Expand All @@ -84,17 +84,17 @@ def _evaluate(self, times, wavelengths=None, **kwargs):
Returns
-------
flux_density : `numpy.ndarray`
A length N-array of flux densities.
A length T x N matrix of SED values.
"""
raise NotImplementedError()

def evaluate(self, times, wavelengths=None, **kwargs):
def evaluate(self, times, wavelengths, **kwargs):
"""Draw observations for this object and apply the noise.
Parameters
----------
times : `numpy.ndarray`
A length N array of timestamps.
A length T array of timestamps.
wavelengths : `numpy.ndarray`, optional
A length N array of wavelengths.
**kwargs : `dict`, optional
Expand All @@ -103,7 +103,7 @@ def evaluate(self, times, wavelengths=None, **kwargs):
Returns
-------
flux_density : `numpy.ndarray`
A length N-array of flux densities.
A length T x N matrix of SED values.
"""
flux_density = self._evaluate(times, wavelengths, **kwargs)
for effect in self.effects:
Expand Down Expand Up @@ -134,7 +134,7 @@ def apply(self, flux_density, wavelengths=None, physical_model=None, **kwargs):
Parameters
----------
flux_density : `numpy.ndarray`
A length N array of flux density values.
A length T X N matrix of flux density values.
wavelengths : `numpy.ndarray`, optional
A length N array of wavelengths.
physical_model : `PhysicalModel`
Expand All @@ -146,6 +146,6 @@ def apply(self, flux_density, wavelengths=None, physical_model=None, **kwargs):
Returns
-------
flux_density : `numpy.ndarray`
A length N-array of flux densities after the effect is applied.
A length T x N matrix of flux densities after the effect is applied.
"""
raise NotImplementedError()
8 changes: 4 additions & 4 deletions src/tdastro/sources/static_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,13 @@ def __init__(self, brightness, **kwargs):
# Otherwise assume we were given the parameter itself.
self.brightness = brightness

def _evaluate(self, times, wavelengths=None, **kwargs):
def _evaluate(self, times, wavelengths, **kwargs):
"""Draw effect-free observations for this object.
Parameters
----------
times : `numpy.ndarray`
A length N array of timestamps.
A length T array of timestamps.
wavelengths : `numpy.ndarray`, optional
A length N array of wavelengths.
**kwargs : `dict`, optional
Expand All @@ -51,6 +51,6 @@ def _evaluate(self, times, wavelengths=None, **kwargs):
Returns
-------
flux_density : `numpy.ndarray`
A length N-array of flux densities.
A length T x N matrix of SED values.
"""
return np.full_like(times, self.brightness)
return np.full((len(times), len(wavelengths)), self.brightness)
7 changes: 5 additions & 2 deletions tests/tdastro/effects/test_white_noise.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,10 @@ def test_white_noise() -> None:
model = StaticSource(brightness=brightness_generator)
model.add_effect(WhiteNoise(scale=0.01))

values = model.evaluate(np.array([1, 2, 3, 4, 5]))
assert len(values) == 5
times = np.array([1, 2, 3, 5, 10])
wavelengths = np.array([100.0, 200.0, 300.0])

values = model.evaluate(times, wavelengths)
assert values.shape == (5, 3)
assert not np.all(values == 10.0)
assert np.all(np.abs(values - 10.0) < 1.0)
7 changes: 5 additions & 2 deletions tests/tdastro/sources/test_static_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,11 @@ def test_static_source() -> None:
assert model.dec is None
assert model.distance is None

values = model.evaluate(np.array([1, 2, 3, 4, 5, 10]))
assert len(values) == 6
times = np.array([1, 2, 3, 4, 5, 10])
wavelengths = np.array([100.0, 200.0, 300.0])

values = model.evaluate(times, wavelengths)
assert values.shape == (6, 3)
assert np.all(values == 10.0)


Expand Down

0 comments on commit af8852a

Please sign in to comment.