Skip to content

Commit

Permalink
Reorganize the base classes
Browse files Browse the repository at this point in the history
  • Loading branch information
jeremykubica committed Jul 15, 2024
1 parent ac94f79 commit ae8349d
Show file tree
Hide file tree
Showing 13 changed files with 341 additions and 322 deletions.
321 changes: 14 additions & 307 deletions src/tdastro/base_models.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
"""The base models used throughout TDAstro including physical objects, effects, and populations."""
"""The base models used to specify the TDAstro computation graph."""

import types
from enum import Enum

import numpy as np


class ParameterSource(Enum):
"""ParameterSource specifies where a PhysicalModel should get the value
Expand All @@ -18,10 +16,9 @@ class ParameterSource(Enum):
MODEL_METHOD = 4


class ParameterizedModel:
class ParameterizedNode:
"""Any model that uses parameters that can be set by constants,
functions, or other parameterized models. ParameterizedModels can
include physical objects or statistical distributions.
functions, or other parameterized nodes.
Attributes
----------
Expand All @@ -39,11 +36,11 @@ def __init__(self, **kwargs):
self.sample_iteration = 0

def __str__(self):
"""Return the string representation of the model."""
return "ParameterizedModel"
"""Return the string representation of the node."""
return "ParameterizedNode"

def set_parameter(self, name, value=None, **kwargs):
"""Set a single *existing* parameter to the ParameterizedModel.
"""Set a single *existing* parameter to the ParameterizedNode.
Notes
-----
Expand All @@ -56,7 +53,7 @@ def set_parameter(self, name, value=None, **kwargs):
The parameter name to add.
value : any, optional
The information to use to set the parameter. Can be a constant,
function, ParameterizedModel, or self.
function, ParameterizedNode, or self.
**kwargs : `dict`, optional
All other keyword arguments, possibly including the parameter setters.
Expand All @@ -80,14 +77,14 @@ def set_parameter(self, name, value=None, **kwargs):
# Case 1: If we are getting from a static function, sample it.
self.setters[name] = (ParameterSource.FUNCTION, value, required)
setattr(self, name, value(**kwargs))
elif isinstance(value, types.MethodType) and isinstance(value.__self__, ParameterizedModel):
# Case 2: We are trying to use the method from a ParameterizedModel.
elif isinstance(value, types.MethodType) and isinstance(value.__self__, ParameterizedNode):
# Case 2: We are trying to use the method from a ParameterizedNode.
# Note that this will (correctly) fail if we are adding a model method from the current
# object that requires an unset attribute.
self.setters[name] = (ParameterSource.MODEL_METHOD, value, required)
setattr(self, name, value(**kwargs))
elif isinstance(value, ParameterizedModel):
# Case 3: We are trying to access an attribute from a parameterized model.
elif isinstance(value, ParameterizedNode):
# Case 3: We are trying to access an attribute from a ParameterizedNode.
if not hasattr(value, name):
raise ValueError(f"Attribute {name} missing from parent.")
self.setters[name] = (ParameterSource.MODEL_ATTRIBUTE, value, required)
Expand All @@ -103,7 +100,7 @@ def set_parameter(self, name, value=None, **kwargs):
raise ValueError(f"Missing required parameter {name}")

def add_parameter(self, name, value=None, required=False, **kwargs):
"""Add a single *new* parameter to the ParameterizedModel.
"""Add a single *new* parameter to the ParameterizedNode.
Notes
-----
Expand All @@ -118,7 +115,7 @@ def add_parameter(self, name, value=None, required=False, **kwargs):
The parameter name to add.
value : any, optional
The information to use to set the parameter. Can be a constant,
function, ParameterizedModel, or self.
function, ParameterizedNode, or self.
required : `bool`
Fail if the parameter is set to ``None``.
**kwargs : `dict`, optional
Expand All @@ -141,7 +138,7 @@ def add_parameter(self, name, value=None, required=False, **kwargs):

def sample_parameters(self, max_depth=50, **kwargs):
"""Sample the model's underlying parameters if they are provided by a function
or ParameterizedModel.
or ParameterizedNode.
Parameters
----------
Expand Down Expand Up @@ -188,293 +185,3 @@ def sample_parameters(self, max_depth=50, **kwargs):

# Increase the sampling iteration.
self.sample_iteration += 1


class PhysicalModel(ParameterizedModel):
"""A physical model of a source of flux.
Physical models can have fixed attributes (where you need to create a new model
to change them) and settable attributes that can be passed functions or constants.
They can also have special background pointers that link to another PhysicalModel
producing flux. We can chain these to have a supernova in front of a star in front
of a static background.
Attributes
----------
ra : `float`
The object's right ascension (in degrees)
dec : `float`
The object's declination (in degrees)
distance : `float`
The object's distance (in pc)
background : `PhysicalModel`
A source of background flux such as a host galaxy.
effects : `list`
A list of effects to apply to an observations.
"""

def __init__(self, ra=None, dec=None, distance=None, background=None, **kwargs):
super().__init__(**kwargs)
self.effects = []

# Set RA, dec, and distance from the parameters.
self.add_parameter("ra", ra)
self.add_parameter("dec", dec)
self.add_parameter("distance", distance)

# Background is an object not a sampled parameter
self.background = background

def __str__(self):
"""Return the string representation of the model."""
return "PhysicalModel"

def add_effect(self, effect, allow_dups=True, **kwargs):
"""Add a transformational effect to the PhysicalModel.
Effects are applied in the order in which they are added.
Parameters
----------
effect : `EffectModel`
The effect to apply.
allow_dups : `bool`
Allow multiple effects of the same type.
Default = ``True``
**kwargs : `dict`, optional
Any additional keyword arguments.
Raises
------
Raises a ``AttributeError`` if the PhysicalModel does not have all of the
required attributes.
"""
# Check that we have not added this effect before.
if not allow_dups:
effect_type = type(effect)
for prev in self.effects:
if effect_type == type(prev):
raise ValueError("Added the effect type to a model {effect_type} more than once.")

required: list = effect.required_parameters()
for parameter in required:
# Raise an AttributeError if the parameter is missing or set to None.
if getattr(self, parameter) is None:
raise AttributeError(f"Parameter {parameter} unset for model {type(self).__name__}")

self.effects.append(effect)

def _evaluate(self, times, wavelengths, **kwargs):
"""Draw effect-free observations for this object.
Parameters
----------
times : `numpy.ndarray`
A length T array of timestamps.
wavelengths : `numpy.ndarray`, optional
A length N array of wavelengths.
**kwargs : `dict`, optional
Any additional keyword arguments.
Returns
-------
flux_density : `numpy.ndarray`
A length T x N matrix of SED values.
"""
raise NotImplementedError()

def evaluate(self, times, wavelengths, resample_parameters=False, **kwargs):
"""Draw observations for this object and apply the noise.
Parameters
----------
times : `numpy.ndarray`
A length T array of timestamps.
wavelengths : `numpy.ndarray`, optional
A length N array of wavelengths.
resample_parameters : `bool`
Treat this evaluation as a completely new object, resampling the
parameters from the original provided functions.
**kwargs : `dict`, optional
Any additional keyword arguments.
Returns
-------
flux_density : `numpy.ndarray`
A length T x N matrix of SED values.
"""
if resample_parameters:
self.sample_parameters(kwargs)

# Compute the flux density for both the current object and add in anything
# behind it, such as a host galaxy.
flux_density = self._evaluate(times, wavelengths, **kwargs)
if self.background is not None:
flux_density += self.background._evaluate(times, wavelengths, ra=self.ra, dec=self.dec, **kwargs)

for effect in self.effects:
flux_density = effect.apply(flux_density, wavelengths, self, **kwargs)
return flux_density

def sample_parameters(self, include_effects=True, **kwargs):
"""Sample the model's underlying parameters if they are provided by a function
or ParameterizedModel.
Parameters
----------
include_effects : `bool`
Resample the parameters for the effects models.
**kwargs : `dict`, optional
All the keyword arguments, including the values needed to sample
parameters.
"""
if self.background is not None:
self.background.sample_parameters(include_effects, **kwargs)
super().sample_parameters(**kwargs)

if include_effects:
for effect in self.effects:
effect.sample_parameters(**kwargs)


class EffectModel(ParameterizedModel):
"""A physical or systematic effect to apply to an observation."""

def __init__(self, **kwargs):
super().__init__(**kwargs)

def __str__(self):
"""Return the string representation of the model."""
return "EffectModel"

def required_parameters(self):
"""Returns a list of the parameters of a PhysicalModel
that this effect needs to access.
Returns
-------
parameters : `list` of `str`
A list of every required parameter the effect needs.
"""
return []

def apply(self, flux_density, wavelengths=None, physical_model=None, **kwargs):
"""Apply the effect to observations (flux_density values)
Parameters
----------
flux_density : `numpy.ndarray`
A length T X N matrix of flux density values.
wavelengths : `numpy.ndarray`, optional
A length N array of wavelengths.
physical_model : `PhysicalModel`
A PhysicalModel from which the effect may query parameters
such as redshift, position, or distance.
**kwargs : `dict`, optional
Any additional keyword arguments.
Returns
-------
flux_density : `numpy.ndarray`
A length T x N matrix of flux densities after the effect is applied.
"""
raise NotImplementedError()


class PopulationModel(ParameterizedModel):
"""A model of a population of PhysicalModels.
Attributes
----------
num_sources : `int`
The number of different sources in the population.
sources : `list`
A list of sources from which to draw.
"""

def __init__(self, rng=None, **kwargs):
super().__init__(**kwargs)
self.num_sources = 0
self.sources = []

def __str__(self):
"""Return the string representation of the model."""
return f"PopulationModel with {self.num_sources} sources."

def add_source(self, new_source, **kwargs):
"""Add a new source to the population.
Parameters
----------
new_source : `PhysicalModel`
A source from the population.
**kwargs : `dict`, optional
Any additional keyword arguments.
"""
if not isinstance(new_source, PhysicalModel):
raise ValueError("All sources must be PhysicalModels")
self.sources.append(new_source)
self.num_sources += 1

def draw_source(self):
"""Sample a single source from the population.
Returns
-------
source : `PhysicalModel`
A source from the population.
"""
raise NotImplementedError()

def add_effect(self, effect, allow_dups=False, **kwargs):
"""Add a transformational effect to all PhysicalModels in this population.
Effects are applied in the order in which they are added.
Parameters
----------
effect : `EffectModel`
The effect to apply.
allow_dups : `bool`
Allow multiple effects of the same type.
Default = ``True``
**kwargs : `dict`, optional
Any additional keyword arguments.
Raises
------
Raises a ``AttributeError`` if the PhysicalModel does not have all of the
required attributes.
"""
for source in self.sources:
source.add_effect(effect, allow_dups=allow_dups, **kwargs)

def evaluate(self, samples, times, wavelengths, resample_parameters=False, **kwargs):
"""Draw observations from a single (randomly sampled) source.
Parameters
----------
samples : `int`
The number of sources to samples.
times : `numpy.ndarray`
A length T array of timestamps.
wavelengths : `numpy.ndarray`, optional
A length N array of wavelengths.
resample_parameters : `bool`
Treat this evaluation as a completely new object, resampling the
parameters from the original provided functions.
**kwargs : `dict`, optional
Any additional keyword arguments.
Returns
-------
results : `numpy.ndarray`
A shape (samples, T, N) matrix of SED values.
"""
if samples <= 0:
raise ValueError("The number of samples must be > 0.")

results = []
for _ in range(samples):
source = self.draw_source()
object_fluxes = source.evaluate(times, wavelengths, resample_parameters, **kwargs)
results.append(object_fluxes)
return np.array(results)
Loading

0 comments on commit ae8349d

Please sign in to comment.