Skip to content

Commit

Permalink
Merge branch 'main' into random_seed
Browse files Browse the repository at this point in the history
  • Loading branch information
jeremykubica committed Jul 17, 2024
2 parents 39092b4 + a94c158 commit 3f44a01
Show file tree
Hide file tree
Showing 4 changed files with 162 additions and 17 deletions.
11 changes: 0 additions & 11 deletions src/tdastro/effects/effect_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,6 @@ class EffectModel(ParameterizedNode):
def __init__(self, **kwargs):
super().__init__(**kwargs)

def required_parameters(self):
"""Returns a list of the parameters of a PhysicalModel
that this effect needs to access.
Returns
-------
parameters : `list` of `str`
A list of every required parameter the effect needs.
"""
return []

def apply(self, flux_density, wavelengths=None, physical_model=None, **kwargs):
"""Apply the effect to observations (flux_density values)
Expand Down
86 changes: 86 additions & 0 deletions src/tdastro/effects/redshift.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
from tdastro.effects.effect_model import EffectModel


class Redshift(EffectModel):
"""A redshift effect model.
This contains a "pre-effect" method, which is used to calculate the emitted wavelengths/times
needed to give us the observed wavelengths and times given the redshift. Times are calculated
with respect to the t0 of the given model.
Notes
-----
Conversions used are as follows:
- rest_frame_wavelength = observation_frame_wavelength / (1 + redshift)
- rest_frame_time = (observation_frame_time - t0) / (1 + redshift) + t0
- observation_frame_flux = rest_frame_flux / (1 + redshift)
"""

def __init__(self, redshift=None, t0=None, **kwargs):
"""Create a Redshift effect model.
Parameters
----------
redshift : `float`
The redshift.
t0 : `float`
The reference epoch; e.g. the time of the maximum light of a supernova or the epoch of zero phase
for a periodic variable star.
**kwargs : `dict`, optional
Any additional keyword arguments.
"""
super().__init__(**kwargs)
self.add_parameter("redshift", redshift, required=True, **kwargs)
self.add_parameter("t0", t0, required=True, **kwargs)

def __str__(self) -> str:
"""Return a string representation of the Redshift effect model."""
return f"RedshiftEffect(redshift={self.redshift})"

def pre_effect(self, observer_frame_times, observer_frame_wavelengths, **kwargs):
"""Calculate the rest-frame times and wavelengths needed to give us the observer-frame times
and wavelengths (given the redshift).
Parameters
----------
observer_frame_times : numpy.ndarray
The times at which the observation is made.
observer_frame_wavelengths : numpy.ndarray
The wavelengths at which the observation is made.
**kwargs : `dict`, optional
Any additional keyword arguments.
Returns
-------
tuple of (numpy.ndarray, numpy.ndarray)
The rest-frame times and wavelengths needed to generate the rest-frame flux densities,
which will later be redshifted back to observer-frame flux densities at the observer-frame
times and wavelengths.
"""
observed_times_rel_to_t0 = observer_frame_times - self.t0
rest_frame_times_rel_to_t0 = observed_times_rel_to_t0 / (1 + self.redshift)
rest_frame_times = rest_frame_times_rel_to_t0 + self.t0
rest_frame_wavelengths = observer_frame_wavelengths / (1 + self.redshift)
return (rest_frame_times, rest_frame_wavelengths)

def apply(self, flux_density, wavelengths, physical_model=None, **kwargs):
"""Apply the effect to observations (flux_density values).
Parameters
----------
flux_density : `numpy.ndarray`
A length T X N matrix of flux density values.
wavelengths : `numpy.ndarray`, optional
A length N array of wavelengths.
physical_model : `PhysicalModel`
A PhysicalModel from which the effect may query parameters such as redshift, position, or
distance.
**kwargs : `dict`, optional
Any additional keyword arguments.
Returns
-------
flux_density : `numpy.ndarray`
The redshifted results.
"""
return flux_density / (1 + self.redshift)
11 changes: 5 additions & 6 deletions src/tdastro/sources/physical_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,12 +79,6 @@ def add_effect(self, effect, allow_dups=True, **kwargs):
if effect_type == type(prev):
raise ValueError("Added the effect type to a model {effect_type} more than once.")

required: list = effect.required_parameters()
for parameter in required:
# Raise an AttributeError if the parameter is missing or set to None.
if getattr(self, parameter) is None:
raise AttributeError(f"Parameter {parameter} unset for model {type(self).__name__}")

self.effects.append(effect)

def _evaluate(self, times, wavelengths, **kwargs):
Expand Down Expand Up @@ -129,6 +123,11 @@ def evaluate(self, times, wavelengths, resample_parameters=False, **kwargs):
if resample_parameters:
self.sample_parameters(kwargs)

# Pre-effects are adjustments done to times and/or wavelengths, before flux density computation.
for effect in self.effects:
if hasattr(effect, "pre_effect"):
times, wavelengths = effect.pre_effect(times, wavelengths, **kwargs)

# Compute the flux density for both the current object and add in anything
# behind it, such as a host galaxy.
flux_density = self._evaluate(times, wavelengths, **kwargs)
Expand Down
71 changes: 71 additions & 0 deletions tests/tdastro/effects/test_redshift.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
import numpy as np
from tdastro.effects.redshift import Redshift
from tdastro.sources.step_source import StepSource


def get_no_effect_and_redshifted_values(times, wavelengths, t0, t1, brightness, redshift) -> tuple:
"""Get the values for a source with no effects and a redshifted source."""
model_no_effects = StepSource(brightness=brightness, t0=t0, t1=t1)
model_redshift = StepSource(brightness=brightness, t0=t0, t1=t1, redshift=redshift)
model_redshift.add_effect(Redshift(redshift=model_redshift, t0=model_redshift))

values_no_effects = model_no_effects.evaluate(times, wavelengths)
values_redshift = model_redshift.evaluate(times, wavelengths)

# Check shape of output is as expected
assert values_no_effects.shape == (len(times), len(wavelengths))
assert values_redshift.shape == (len(times), len(wavelengths))

return values_no_effects, values_redshift


def test_redshift() -> None:
"""Test that we can create a Redshift object and it gives us values as expected."""
times = np.array([1, 2, 3, 5, 10])
wavelengths = np.array([100.0, 200.0, 300.0])
t0 = 1.0
t1 = 2.0
brightness = 15.0
redshift = 1.0

# Get the values for a redshifted step source, and a step source with no effects for comparison
(values_no_effects, values_redshift) = get_no_effect_and_redshifted_values(
times, wavelengths, t0, t1, brightness, redshift
)

# Check that the step source activates within the correct time range:
# For values_no_effects, the activated values are in the range [t0, t1]
for i, time in enumerate(times):
if t0 <= time and time <= t1:
assert np.all(values_no_effects[i] == brightness)
else:
assert np.all(values_no_effects[i] == 0.0)

# With redshift = 1.0, the activated values are *observed* in the range [(t0-t0)*(1+redshift)+t0,
# (t1-t0*(1+redshift+t0] (the first term reduces to t0). Also, the values are scaled by a factor
# of (1+redshift).
for i, time in enumerate(times):
if t0 <= time and time <= (t1 - t0) * (1 + redshift) + t0:
assert np.all(values_redshift[i] == brightness / (1 + redshift))
else:
assert np.all(values_redshift[i] == 0.0)


def test_other_redshift_values() -> None:
"""Test that we can create a Redshift object with various other redshift values."""
times = np.linspace(0, 100, 1000)
wavelengths = np.array([100.0, 200.0, 300.0])
t0 = 10.0
t1 = 30.0
brightness = 50.0

for redshift in [0.0, 0.5, 2.0, 3.0, 30.0]:
model_redshift = StepSource(brightness=brightness, t0=t0, t1=t1, redshift=redshift)
model_redshift.add_effect(Redshift(redshift=model_redshift, t0=model_redshift))
values_redshift = model_redshift.evaluate(times, wavelengths)

for i, time in enumerate(times):
if t0 <= time and time <= (t1 - t0) * (1 + redshift) + t0:
assert np.all(values_redshift[i] == brightness / (1 + redshift))
else:
assert np.all(values_redshift[i] == 0.0)

0 comments on commit 3f44a01

Please sign in to comment.