Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add FunctionNode and basic cosmology #33

Merged
merged 14 commits into from
Jul 15, 2024
57 changes: 57 additions & 0 deletions src/tdastro/astro_utils/cosmology.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import astropy.cosmology.units as cu
from astropy import units as u

from tdastro.base_models import FunctionNode


def redshift_to_distance(redshift, cosmology):
"""Compute a source's luminosity distance given its redshift and a
specified cosmology using astropy's redshift_distance().

Parameters
----------
redshift : `float`
The redshift value.
cosmology : `astropy.cosmology`
The cosmology specification.

Returns
-------
distance : `float`
The luminosity distance (in pc)
"""
z = redshift * cu.redshift
distance = z.to(u.pc, cu.redshift_distance(cosmology, kind="luminosity"))
return distance.value


class RedshiftDistFunc(FunctionNode):
"""A wrapper class for the redshift_to_distance() function.

Attributes
----------
cosmology : `astropy.cosmology`
The cosmology specification.
kind : `str`
The distance type for the Equivalency as defined by
astropy.cosmology.units.redshift_distance.

Parameters
----------
redshift : function or constant
The function or constant providing the redshift value.
cosmology : `astropy.cosmology`
The cosmology specification.
"""

def __init__(self, redshift, cosmology):
# Call the super class's constructor with the needed information.
super().__init__(
func=redshift_to_distance,
redshift=redshift,
cosmology=cosmology,
)

def __str__(self):
"""Return the string representation of the function."""
return f"RedshiftDistFunc({self.cosmology.name}, {self.kind})"
155 changes: 121 additions & 34 deletions src/tdastro/base_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,8 @@


class ParameterSource(Enum):
"""ParameterSource specifies where a PhysicalModel should get the value
for a given parameter: a constant value, a function, or from another
parameterized model.
"""ParameterSource specifies where a ParameterizedNode should get the value
for a given parameter: a constant value or from another ParameterizedNode.
"""

CONSTANT = 1
Expand All @@ -27,7 +26,7 @@ class ParameterizedNode:
(ParameterSource, setter information, required). The attributes are
stored in the order in which they need to be set.
sample_iteration : `int`
A counter used to syncronize sampling runs. Tracks how many times this
A counter used to syncronize sampling runs. Tracks how many times this
model's parameters have been resampled.
"""

Expand All @@ -39,6 +38,37 @@ def __str__(self):
"""Return the string representation of the node."""
return "ParameterizedNode"

def check_resample(self, other):
"""Check if we need to resample the current node based
on the state of another node trying to access its attributes
or methods.

Parameters
----------
other : `ParameterizedNode`
The node that is accessing the attribute or method
of the current node.

Returns
-------
bool
Indicates whether to resample or not.

Raises
------
``ValueError`` if the graph has gotten out of sync.
"""
if other == self:
return False
if other.sample_iteration == self.sample_iteration:
return False
if other.sample_iteration != self.sample_iteration + 1:
raise ValueError(
f"Node {str(other)} at iteration {other.sample_iteration} accessing"
f" parent {str(self)} at iteration {self.sample_iteration}."
)
return True

def set_parameter(self, name, value=None, **kwargs):
"""Set a single *existing* parameter to the ParameterizedNode.

Expand Down Expand Up @@ -72,31 +102,30 @@ def set_parameter(self, name, value=None, **kwargs):
# The value wasn't set, but the name is in kwargs.
value = kwargs[name]

if value is not None:
if callable(value):
if isinstance(value, types.FunctionType):
# Case 1: If we are getting from a static function, sample it.
# Case 1a: This is a static function (not attached to an object).
self.setters[name] = (ParameterSource.FUNCTION, value, required)
setattr(self, name, value(**kwargs))
elif isinstance(value, types.MethodType) and isinstance(value.__self__, ParameterizedNode):
# Case 2: We are trying to use the method from a ParameterizedNode.
# Note that this will (correctly) fail if we are adding a model method from the current
# object that requires an unset attribute.
elif isinstance(value.__self__, ParameterizedNode):
# Case 1b: This is a method attached to another ParameterizedNode.
self.setters[name] = (ParameterSource.MODEL_METHOD, value, required)
setattr(self, name, value(**kwargs))
elif isinstance(value, ParameterizedNode):
# Case 3: We are trying to access an attribute from a ParameterizedNode.
if not hasattr(value, name):
raise ValueError(f"Attribute {name} missing from parent.")
self.setters[name] = (ParameterSource.MODEL_ATTRIBUTE, value, required)
setattr(self, name, getattr(value, name))
else:
# Case 4: The value is constant.
self.setters[name] = (ParameterSource.CONSTANT, value, required)
setattr(self, name, value)
elif not required:
self.setters[name] = (ParameterSource.CONSTANT, None, required)
setattr(self, name, None)
# Case 1c: This is a general callable method from another object.
# We treat it as static (we don't resample the other object).
self.setters[name] = (ParameterSource.FUNCTION, value, required)
setattr(self, name, value(**kwargs))
elif isinstance(value, ParameterizedNode):
# Case 2: We are trying to access a parameter of another ParameterizedNode.
if not hasattr(value, name):
raise ValueError(f"Attribute {name} missing from parent.")
self.setters[name] = (ParameterSource.MODEL_ATTRIBUTE, value, required)
setattr(self, name, getattr(value, name))
else:
# Case 3: The value is constant (including None).
self.setters[name] = (ParameterSource.CONSTANT, value, required)
setattr(self, name, value)

if required and getattr(self, name) is None:
raise ValueError(f"Missing required parameter {name}")

def add_parameter(self, name, value=None, required=False, **kwargs):
Expand Down Expand Up @@ -157,6 +186,9 @@ def sample_parameters(self, max_depth=50, **kwargs):
if max_depth == 0:
raise ValueError(f"Maximum sampling depth exceeded at {self}. Potential infinite loop.")

# Increase the sampling iteration.
self.sample_iteration += 1

# Run through each parameter and sample it based on the given recipe.
# As of Python 3.7 dictionaries are guaranteed to preserve insertion ordering,
# so this will iterate through attributes in the order they were inserted.
Expand All @@ -167,21 +199,76 @@ def sample_parameters(self, max_depth=50, **kwargs):
elif source_type == ParameterSource.FUNCTION:
sampled_value = setter(**kwargs)
elif source_type == ParameterSource.MODEL_ATTRIBUTE:
# Check if we need to resample the parent (needs to be done before
# we read its attribute).
if setter.sample_iteration == self.sample_iteration:
# Check if we need to resample the parent (before accessing the attribute).
if setter.check_resample(self):
setter.sample_parameters(max_depth - 1, **kwargs)
sampled_value = getattr(setter, name)
elif source_type == ParameterSource.MODEL_METHOD:
# Check if we need to resample the parent (needs to be done before
# we evaluate its method). Do not resample the current object.
parent = setter.__self__
if parent is not self and parent.sample_iteration == self.sample_iteration:
parent.sample_parameters(max_depth - 1, **kwargs)
# Check if we need to resample the parent (before calling the method).
parent_node = setter.__self__
if parent_node.check_resample(self):
parent_node.sample_parameters(max_depth - 1, **kwargs)
sampled_value = setter(**kwargs)
else:
raise ValueError(f"Unknown ParameterSource type {source_type} for {name}")
setattr(self, name, sampled_value)

# Increase the sampling iteration.
self.sample_iteration += 1

class FunctionNode(ParameterizedNode):
"""A class to wrap functions and their argument settings.

Attributes
----------
func : `function` or `method`
The function to call during an evaluation.
args_names : `list`
A list of argument names to pass to the function.

Examples
--------
my_func = TDFunc(random.randint, a=1, b=10)
value1 = my_func() # Sample from default range
value2 = my_func(b=20) # Sample from extended range

Note
----
All the function's parameters that will be used need to be specified
in either the default_args dict, object_args list, or as a kwarg in the
constructor. Arguments cannot be first given during function call.
For example the following will fail (because b is not defined in the
constructor):

my_func = TDFunc(random.randint, a=1)
value1 = my_func(b=10.0)
"""

def __init__(self, func, **kwargs):
super().__init__(**kwargs)
self.func = func
self.arg_names = []

# Add all of the parameters from default_args or the kwargs.
for key, value in kwargs.items():
self.arg_names.append(key)
self.add_parameter(key, value)

def __str__(self):
"""Return the string representation of the function."""
return f"FunctionNode({self.func.name})"

def compute(self, **kwargs):
"""Execute the wrapped function.

Parameters
----------
**kwargs : `dict`, optional
Additional function arguments.
"""
args = {}
for key in self.arg_names:
# Override with the kwarg if the parameter is there.
if key in kwargs:
args[key] = kwargs[key]
else:
args[key] = getattr(self, key)
return self.func(**args)
26 changes: 21 additions & 5 deletions src/tdastro/sources/physical_model.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""The base PhysicalModel used for all sources."""

from tdastro.astro_utils.cosmology import RedshiftDistFunc
from tdastro.base_models import ParameterizedNode


Expand All @@ -18,22 +19,36 @@ class PhysicalModel(ParameterizedNode):
The object's right ascension (in degrees)
dec : `float`
The object's declination (in degrees)
redshift : `float`
The object's redshift.
distance : `float`
The object's distance (in pc).
The object's luminosity distance (in pc). If no value is provided and
a ``cosmology`` parameter is given, the model will try to derive from
the redshift and the cosmology.
background : `PhysicalModel`
A source of background flux such as a host galaxy.
effects : `list`
A list of effects to apply to an observations.
"""

def __init__(self, ra=None, dec=None, distance=None, background=None, **kwargs):
def __init__(self, ra=None, dec=None, redshift=None, distance=None, background=None, **kwargs):
super().__init__(**kwargs)
self.effects = []

# Set RA, dec, and redshift from the parameters.
self.add_parameter("ra", ra)
self.add_parameter("dec", dec)
self.add_parameter("distance", distance)
self.add_parameter("redshift", redshift)

# If the luminosity distance is provided, use that. Otherwise try the
# redshift value using the cosmology (if given). Finally, default to None.
if distance is not None:
self.add_parameter("distance", distance)
elif redshift is not None and kwargs.get("cosmology", None) is not None:
self._redshift_func = RedshiftDistFunc(redshift=self, **kwargs)
self.add_parameter("distance", self._redshift_func.compute)
else:
self.add_parameter("distance", None)

# Background is an object not a sampled parameter
self.background = background
Expand Down Expand Up @@ -140,10 +155,11 @@ def sample_parameters(self, include_effects=True, **kwargs):
All the keyword arguments, including the values needed to sample
parameters.
"""
if self.background is not None:
if self.background is not None and self.background.check_resample(self):
self.background.sample_parameters(include_effects, **kwargs)
super().sample_parameters(**kwargs)

if include_effects:
for effect in self.effects:
effect.sample_parameters(**kwargs)
if effect.check_resample(self):
effect.sample_parameters(**kwargs)
12 changes: 12 additions & 0 deletions tests/tdastro/astro_utils/test_cosmology.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from astropy.cosmology import WMAP9, Planck18
from tdastro.astro_utils.cosmology import redshift_to_distance


def test_redshift_to_distance():
"""Test that we can convert the redshift to a distance using a given cosmology."""
wmap9_val = redshift_to_distance(1100, cosmology=WMAP9)
planck18_val = redshift_to_distance(1100, cosmology=Planck18)

assert abs(planck18_val - wmap9_val) > 1000.0
assert 13.0 * 1e12 < wmap9_val < 16.0 * 1e12
assert 13.0 * 1e12 < planck18_val < 16.0 * 1e12
9 changes: 3 additions & 6 deletions tests/tdastro/populations/test_fixed_population.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import numpy as np
import pytest
from tdastro.base_models import FunctionNode
from tdastro.effects.white_noise import WhiteNoise
from tdastro.populations.fixed_population import FixedPopulation
from tdastro.sources.static_source import StaticSource
Expand Down Expand Up @@ -84,17 +85,13 @@ def test_fixed_population_sample_sources():
assert np.allclose(counts, [0.4 * itr, 0.2 * itr, 0.4 * itr], rtol=0.05)


def _random_brightness():
"""Returns a random value [0.0, 100.0]"""
return 100.0 * random.random()


def test_fixed_population_sample_fluxes():
"""Test that we can create a population of sources and sample its sources' flux."""
random.seed(1001)
brightness_func = FunctionNode(random.uniform, a=0.0, b=100.0)
population = FixedPopulation()
population.add_source(StaticSource(brightness=100.0), 10.0)
population.add_source(StaticSource(brightness=_random_brightness), 10.0)
population.add_source(StaticSource(brightness=brightness_func.compute), 10.0)
population.add_source(StaticSource(brightness=200.0), 20.0)
population.add_source(StaticSource(brightness=150.0), 10.0)

Expand Down
21 changes: 20 additions & 1 deletion tests/tdastro/sources/test_physical_models.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,29 @@
from astropy.cosmology import Planck18
from tdastro.sources.physical_model import PhysicalModel


def test_physical_model():
"""Test that we can create a PhysicalModel."""
# Everything is specified.
model1 = PhysicalModel(ra=1.0, dec=2.0, distance=3.0)
model1 = PhysicalModel(ra=1.0, dec=2.0, distance=3.0, redshift=0.0)
assert model1.ra == 1.0
assert model1.dec == 2.0
assert model1.distance == 3.0
assert model1.redshift == 0.0

# Derive the distance from the redshift:
model2 = PhysicalModel(ra=1.0, dec=2.0, redshift=1100.0, cosmology=Planck18)
assert model2.ra == 1.0
assert model2.dec == 2.0
assert model2.redshift == 1100.0
assert 13.0 * 1e12 < model2.distance < 16.0 * 1e12

# Neither distance nor redshift are specified.
model3 = PhysicalModel(ra=1.0, dec=2.0)
assert model3.redshift is None
assert model3.distance is None

# Redshift is specified but cosmology is not.
model4 = PhysicalModel(ra=1.0, dec=2.0, redshift=1100.0)
assert model4.redshift == 1100.0
assert model4.distance is None
Loading
Loading