diff --git a/src/tdastro/astro_utils/cosmology.py b/src/tdastro/astro_utils/cosmology.py new file mode 100644 index 00000000..57771220 --- /dev/null +++ b/src/tdastro/astro_utils/cosmology.py @@ -0,0 +1,57 @@ +import astropy.cosmology.units as cu +from astropy import units as u + +from tdastro.base_models import FunctionNode + + +def redshift_to_distance(redshift, cosmology): + """Compute a source's luminosity distance given its redshift and a + specified cosmology using astropy's redshift_distance(). + + Parameters + ---------- + redshift : `float` + The redshift value. + cosmology : `astropy.cosmology` + The cosmology specification. + + Returns + ------- + distance : `float` + The luminosity distance (in pc) + """ + z = redshift * cu.redshift + distance = z.to(u.pc, cu.redshift_distance(cosmology, kind="luminosity")) + return distance.value + + +class RedshiftDistFunc(FunctionNode): + """A wrapper class for the redshift_to_distance() function. + + Attributes + ---------- + cosmology : `astropy.cosmology` + The cosmology specification. + kind : `str` + The distance type for the Equivalency as defined by + astropy.cosmology.units.redshift_distance. + + Parameters + ---------- + redshift : function or constant + The function or constant providing the redshift value. + cosmology : `astropy.cosmology` + The cosmology specification. + """ + + def __init__(self, redshift, cosmology): + # Call the super class's constructor with the needed information. + super().__init__( + func=redshift_to_distance, + redshift=redshift, + cosmology=cosmology, + ) + + def __str__(self): + """Return the string representation of the function.""" + return f"RedshiftDistFunc({self.cosmology.name}, {self.kind})" diff --git a/src/tdastro/base_models.py b/src/tdastro/base_models.py index 1cf1fbfc..cc67c8bb 100644 --- a/src/tdastro/base_models.py +++ b/src/tdastro/base_models.py @@ -5,9 +5,8 @@ class ParameterSource(Enum): - """ParameterSource specifies where a PhysicalModel should get the value - for a given parameter: a constant value, a function, or from another - parameterized model. + """ParameterSource specifies where a ParameterizedNode should get the value + for a given parameter: a constant value or from another ParameterizedNode. """ CONSTANT = 1 @@ -27,7 +26,7 @@ class ParameterizedNode: (ParameterSource, setter information, required). The attributes are stored in the order in which they need to be set. sample_iteration : `int` - A counter used to syncronize sampling runs. Tracks how many times this + A counter used to syncronize sampling runs. Tracks how many times this model's parameters have been resampled. """ @@ -39,6 +38,37 @@ def __str__(self): """Return the string representation of the node.""" return "ParameterizedNode" + def check_resample(self, other): + """Check if we need to resample the current node based + on the state of another node trying to access its attributes + or methods. + + Parameters + ---------- + other : `ParameterizedNode` + The node that is accessing the attribute or method + of the current node. + + Returns + ------- + bool + Indicates whether to resample or not. + + Raises + ------ + ``ValueError`` if the graph has gotten out of sync. + """ + if other == self: + return False + if other.sample_iteration == self.sample_iteration: + return False + if other.sample_iteration != self.sample_iteration + 1: + raise ValueError( + f"Node {str(other)} at iteration {other.sample_iteration} accessing" + f" parent {str(self)} at iteration {self.sample_iteration}." + ) + return True + def set_parameter(self, name, value=None, **kwargs): """Set a single *existing* parameter to the ParameterizedNode. @@ -72,31 +102,30 @@ def set_parameter(self, name, value=None, **kwargs): # The value wasn't set, but the name is in kwargs. value = kwargs[name] - if value is not None: + if callable(value): if isinstance(value, types.FunctionType): - # Case 1: If we are getting from a static function, sample it. + # Case 1a: This is a static function (not attached to an object). self.setters[name] = (ParameterSource.FUNCTION, value, required) - setattr(self, name, value(**kwargs)) - elif isinstance(value, types.MethodType) and isinstance(value.__self__, ParameterizedNode): - # Case 2: We are trying to use the method from a ParameterizedNode. - # Note that this will (correctly) fail if we are adding a model method from the current - # object that requires an unset attribute. + elif isinstance(value.__self__, ParameterizedNode): + # Case 1b: This is a method attached to another ParameterizedNode. self.setters[name] = (ParameterSource.MODEL_METHOD, value, required) - setattr(self, name, value(**kwargs)) - elif isinstance(value, ParameterizedNode): - # Case 3: We are trying to access an attribute from a ParameterizedNode. - if not hasattr(value, name): - raise ValueError(f"Attribute {name} missing from parent.") - self.setters[name] = (ParameterSource.MODEL_ATTRIBUTE, value, required) - setattr(self, name, getattr(value, name)) else: - # Case 4: The value is constant. - self.setters[name] = (ParameterSource.CONSTANT, value, required) - setattr(self, name, value) - elif not required: - self.setters[name] = (ParameterSource.CONSTANT, None, required) - setattr(self, name, None) + # Case 1c: This is a general callable method from another object. + # We treat it as static (we don't resample the other object). + self.setters[name] = (ParameterSource.FUNCTION, value, required) + setattr(self, name, value(**kwargs)) + elif isinstance(value, ParameterizedNode): + # Case 2: We are trying to access a parameter of another ParameterizedNode. + if not hasattr(value, name): + raise ValueError(f"Attribute {name} missing from parent.") + self.setters[name] = (ParameterSource.MODEL_ATTRIBUTE, value, required) + setattr(self, name, getattr(value, name)) else: + # Case 3: The value is constant (including None). + self.setters[name] = (ParameterSource.CONSTANT, value, required) + setattr(self, name, value) + + if required and getattr(self, name) is None: raise ValueError(f"Missing required parameter {name}") def add_parameter(self, name, value=None, required=False, **kwargs): @@ -157,6 +186,9 @@ def sample_parameters(self, max_depth=50, **kwargs): if max_depth == 0: raise ValueError(f"Maximum sampling depth exceeded at {self}. Potential infinite loop.") + # Increase the sampling iteration. + self.sample_iteration += 1 + # Run through each parameter and sample it based on the given recipe. # As of Python 3.7 dictionaries are guaranteed to preserve insertion ordering, # so this will iterate through attributes in the order they were inserted. @@ -167,21 +199,76 @@ def sample_parameters(self, max_depth=50, **kwargs): elif source_type == ParameterSource.FUNCTION: sampled_value = setter(**kwargs) elif source_type == ParameterSource.MODEL_ATTRIBUTE: - # Check if we need to resample the parent (needs to be done before - # we read its attribute). - if setter.sample_iteration == self.sample_iteration: + # Check if we need to resample the parent (before accessing the attribute). + if setter.check_resample(self): setter.sample_parameters(max_depth - 1, **kwargs) sampled_value = getattr(setter, name) elif source_type == ParameterSource.MODEL_METHOD: - # Check if we need to resample the parent (needs to be done before - # we evaluate its method). Do not resample the current object. - parent = setter.__self__ - if parent is not self and parent.sample_iteration == self.sample_iteration: - parent.sample_parameters(max_depth - 1, **kwargs) + # Check if we need to resample the parent (before calling the method). + parent_node = setter.__self__ + if parent_node.check_resample(self): + parent_node.sample_parameters(max_depth - 1, **kwargs) sampled_value = setter(**kwargs) else: raise ValueError(f"Unknown ParameterSource type {source_type} for {name}") setattr(self, name, sampled_value) - # Increase the sampling iteration. - self.sample_iteration += 1 + +class FunctionNode(ParameterizedNode): + """A class to wrap functions and their argument settings. + + Attributes + ---------- + func : `function` or `method` + The function to call during an evaluation. + args_names : `list` + A list of argument names to pass to the function. + + Examples + -------- + my_func = TDFunc(random.randint, a=1, b=10) + value1 = my_func() # Sample from default range + value2 = my_func(b=20) # Sample from extended range + + Note + ---- + All the function's parameters that will be used need to be specified + in either the default_args dict, object_args list, or as a kwarg in the + constructor. Arguments cannot be first given during function call. + For example the following will fail (because b is not defined in the + constructor): + + my_func = TDFunc(random.randint, a=1) + value1 = my_func(b=10.0) + """ + + def __init__(self, func, **kwargs): + super().__init__(**kwargs) + self.func = func + self.arg_names = [] + + # Add all of the parameters from default_args or the kwargs. + for key, value in kwargs.items(): + self.arg_names.append(key) + self.add_parameter(key, value) + + def __str__(self): + """Return the string representation of the function.""" + return f"FunctionNode({self.func.name})" + + def compute(self, **kwargs): + """Execute the wrapped function. + + Parameters + ---------- + **kwargs : `dict`, optional + Additional function arguments. + """ + args = {} + for key in self.arg_names: + # Override with the kwarg if the parameter is there. + if key in kwargs: + args[key] = kwargs[key] + else: + args[key] = getattr(self, key) + return self.func(**args) diff --git a/src/tdastro/sources/physical_model.py b/src/tdastro/sources/physical_model.py index 0b8af424..fb4cb3de 100644 --- a/src/tdastro/sources/physical_model.py +++ b/src/tdastro/sources/physical_model.py @@ -1,5 +1,6 @@ """The base PhysicalModel used for all sources.""" +from tdastro.astro_utils.cosmology import RedshiftDistFunc from tdastro.base_models import ParameterizedNode @@ -18,22 +19,36 @@ class PhysicalModel(ParameterizedNode): The object's right ascension (in degrees) dec : `float` The object's declination (in degrees) + redshift : `float` + The object's redshift. distance : `float` - The object's distance (in pc). + The object's luminosity distance (in pc). If no value is provided and + a ``cosmology`` parameter is given, the model will try to derive from + the redshift and the cosmology. background : `PhysicalModel` A source of background flux such as a host galaxy. effects : `list` A list of effects to apply to an observations. """ - def __init__(self, ra=None, dec=None, distance=None, background=None, **kwargs): + def __init__(self, ra=None, dec=None, redshift=None, distance=None, background=None, **kwargs): super().__init__(**kwargs) self.effects = [] # Set RA, dec, and redshift from the parameters. self.add_parameter("ra", ra) self.add_parameter("dec", dec) - self.add_parameter("distance", distance) + self.add_parameter("redshift", redshift) + + # If the luminosity distance is provided, use that. Otherwise try the + # redshift value using the cosmology (if given). Finally, default to None. + if distance is not None: + self.add_parameter("distance", distance) + elif redshift is not None and kwargs.get("cosmology", None) is not None: + self._redshift_func = RedshiftDistFunc(redshift=self, **kwargs) + self.add_parameter("distance", self._redshift_func.compute) + else: + self.add_parameter("distance", None) # Background is an object not a sampled parameter self.background = background @@ -140,10 +155,11 @@ def sample_parameters(self, include_effects=True, **kwargs): All the keyword arguments, including the values needed to sample parameters. """ - if self.background is not None: + if self.background is not None and self.background.check_resample(self): self.background.sample_parameters(include_effects, **kwargs) super().sample_parameters(**kwargs) if include_effects: for effect in self.effects: - effect.sample_parameters(**kwargs) + if effect.check_resample(self): + effect.sample_parameters(**kwargs) diff --git a/tests/tdastro/astro_utils/test_cosmology.py b/tests/tdastro/astro_utils/test_cosmology.py new file mode 100644 index 00000000..9ed9559e --- /dev/null +++ b/tests/tdastro/astro_utils/test_cosmology.py @@ -0,0 +1,12 @@ +from astropy.cosmology import WMAP9, Planck18 +from tdastro.astro_utils.cosmology import redshift_to_distance + + +def test_redshift_to_distance(): + """Test that we can convert the redshift to a distance using a given cosmology.""" + wmap9_val = redshift_to_distance(1100, cosmology=WMAP9) + planck18_val = redshift_to_distance(1100, cosmology=Planck18) + + assert abs(planck18_val - wmap9_val) > 1000.0 + assert 13.0 * 1e12 < wmap9_val < 16.0 * 1e12 + assert 13.0 * 1e12 < planck18_val < 16.0 * 1e12 diff --git a/tests/tdastro/populations/test_fixed_population.py b/tests/tdastro/populations/test_fixed_population.py index 134ed459..ff56b20a 100644 --- a/tests/tdastro/populations/test_fixed_population.py +++ b/tests/tdastro/populations/test_fixed_population.py @@ -2,6 +2,7 @@ import numpy as np import pytest +from tdastro.base_models import FunctionNode from tdastro.effects.white_noise import WhiteNoise from tdastro.populations.fixed_population import FixedPopulation from tdastro.sources.static_source import StaticSource @@ -84,17 +85,13 @@ def test_fixed_population_sample_sources(): assert np.allclose(counts, [0.4 * itr, 0.2 * itr, 0.4 * itr], rtol=0.05) -def _random_brightness(): - """Returns a random value [0.0, 100.0]""" - return 100.0 * random.random() - - def test_fixed_population_sample_fluxes(): """Test that we can create a population of sources and sample its sources' flux.""" random.seed(1001) + brightness_func = FunctionNode(random.uniform, a=0.0, b=100.0) population = FixedPopulation() population.add_source(StaticSource(brightness=100.0), 10.0) - population.add_source(StaticSource(brightness=_random_brightness), 10.0) + population.add_source(StaticSource(brightness=brightness_func.compute), 10.0) population.add_source(StaticSource(brightness=200.0), 20.0) population.add_source(StaticSource(brightness=150.0), 10.0) diff --git a/tests/tdastro/sources/test_physical_models.py b/tests/tdastro/sources/test_physical_models.py index 72950ded..8098cf1b 100644 --- a/tests/tdastro/sources/test_physical_models.py +++ b/tests/tdastro/sources/test_physical_models.py @@ -1,10 +1,29 @@ +from astropy.cosmology import Planck18 from tdastro.sources.physical_model import PhysicalModel def test_physical_model(): """Test that we can create a PhysicalModel.""" # Everything is specified. - model1 = PhysicalModel(ra=1.0, dec=2.0, distance=3.0) + model1 = PhysicalModel(ra=1.0, dec=2.0, distance=3.0, redshift=0.0) assert model1.ra == 1.0 assert model1.dec == 2.0 assert model1.distance == 3.0 + assert model1.redshift == 0.0 + + # Derive the distance from the redshift: + model2 = PhysicalModel(ra=1.0, dec=2.0, redshift=1100.0, cosmology=Planck18) + assert model2.ra == 1.0 + assert model2.dec == 2.0 + assert model2.redshift == 1100.0 + assert 13.0 * 1e12 < model2.distance < 16.0 * 1e12 + + # Neither distance nor redshift are specified. + model3 = PhysicalModel(ra=1.0, dec=2.0) + assert model3.redshift is None + assert model3.distance is None + + # Redshift is specified but cosmology is not. + model4 = PhysicalModel(ra=1.0, dec=2.0, redshift=1100.0) + assert model4.redshift == 1100.0 + assert model4.distance is None diff --git a/tests/tdastro/sources/test_step_source.py b/tests/tdastro/sources/test_step_source.py index dc069241..899a716f 100644 --- a/tests/tdastro/sources/test_step_source.py +++ b/tests/tdastro/sources/test_step_source.py @@ -1,6 +1,7 @@ import random import numpy as np +from tdastro.base_models import FunctionNode from tdastro.sources.static_source import StaticSource from tdastro.sources.step_source import StepSource @@ -18,7 +19,7 @@ def _sample_brightness(magnitude, **kwargs): return magnitude * random.random() -def _sample_end(duration, **kwargs): +def _sample_end(duration): """Return a random value between 1 and 1 + duration Parameters @@ -54,20 +55,20 @@ def test_step_source() -> None: def test_step_source_resample() -> None: """Check that we can call resample on the model parameters.""" + random.seed(1111) + model = StepSource( - brightness=_sample_brightness, + brightness=FunctionNode(_sample_brightness, magnitude=100.0).compute, t0=0.0, - t1=_sample_end, - magnitude=100.0, - duration=5.0, + t1=FunctionNode(_sample_end, duration=5.0).compute, ) - num_samples = 100 + num_samples = 1000 brightness_vals = np.zeros((num_samples, 1)) t_end_vals = np.zeros((num_samples, 1)) t_start_vals = np.zeros((num_samples, 1)) for i in range(num_samples): - model.sample_parameters(magnitude=100.0, duration=5.0) + model.sample_parameters() brightness_vals[i] = model.brightness t_end_vals[i] = model.t1 t_start_vals[i] = model.t0 @@ -79,6 +80,10 @@ def test_step_source_resample() -> None: assert np.all(t_end_vals >= 1.0) assert np.all(t_end_vals <= 6.0) + # Check that the expected values are close. + assert abs(np.mean(brightness_vals) - 50.0) < 5.0 + assert abs(np.mean(t_end_vals) - 3.5) < 0.5 + # Check that the brightness or end values are not all the same. assert not np.all(brightness_vals == brightness_vals[0]) assert not np.all(t_end_vals == t_end_vals[0]) diff --git a/tests/tdastro/test_base_models.py b/tests/tdastro/test_base_models.py index c33322ac..1fa953e5 100644 --- a/tests/tdastro/test_base_models.py +++ b/tests/tdastro/test_base_models.py @@ -1,7 +1,8 @@ import random +import numpy as np import pytest -from tdastro.base_models import ParameterizedNode +from tdastro.base_models import FunctionNode, ParameterizedNode def _sampler_fun(**kwargs): @@ -15,6 +16,19 @@ def _sampler_fun(**kwargs): return random.random() +def _test_func(value1, value2): + """Return the sum of the two parameters. + + Parameters + ---------- + value1 : `float` + The first parameter. + value2 : `float` + The second parameter. + """ + return value1 + value2 + + class PairModel(ParameterizedNode): """A test class for the ParameterizedNode. @@ -45,6 +59,10 @@ def __init__(self, value1, value2, **kwargs): self.add_parameter("value2", value2, required=True, **kwargs) self.add_parameter("value_sum", self.result, required=True, **kwargs) + def get_value1(self): + """Get the value of value1.""" + return self.value1 + def result(self, **kwargs): """Add the pair of values together @@ -141,3 +159,60 @@ def test_parameterized_node_modify() -> None: # We cannot set a value that hasn't been added. with pytest.raises(KeyError): model.set_parameter("brightness", 5.0) + + +def test_function_node_basic(): + """Test that we can create and query a FunctionNode.""" + my_func = FunctionNode(_test_func, value1=1.0, value2=2.0) + assert my_func.compute() == 3.0 + assert my_func.compute(value2=3.0) == 4.0 + assert my_func.compute(value2=3.0, unused_param=5.0) == 4.0 + assert my_func.compute(value2=3.0, value1=1.0) == 4.0 + + +def test_function_node_chain(): + """Test that we can create and query a chained FunctionNode.""" + func1 = FunctionNode(_test_func, value1=1.0, value2=1.0) + func2 = FunctionNode(_test_func, value1=func1.compute, value2=3.0) + assert func2.compute() == 5.0 + + +def test_np_sampler_method(): + """Test that we can wrap numpy random functions.""" + rng = np.random.default_rng(1001) + my_func = FunctionNode(rng.normal, loc=10.0, scale=1.0) + + # Sample 1000 times with the default values. Check that we are near + # the expected mean and not everything is equal. + vals = np.array([my_func.compute() for _ in range(1000)]) + assert abs(np.mean(vals) - 10.0) < 1.0 + assert not np.all(vals == vals[0]) + + # Override the mean and resample. + vals = np.array([my_func.compute(loc=25.0) for _ in range(1000)]) + assert abs(np.mean(vals) - 25.0) < 1.0 + assert not np.all(vals == vals[0]) + + +def test_function_node_obj(): + """Test that we can create and query a FunctionNode that depends on + another ParameterizedNode. + """ + # The model depends on the function. + func = FunctionNode(_test_func, value1=5.0, value2=6.0) + model = PairModel(value1=10.0, value2=func.compute) + assert model.result() == 21.0 + + # Function depends on the model's value2 attribute. + model = PairModel(value1=-10.0, value2=17.5) + func = FunctionNode(_test_func, value1=5.0, value2=model) + assert model.result() == 7.5 + assert func.compute() == 22.5 + + # Function depends on the model's get_value1() method. + func = FunctionNode(_test_func, value1=model.get_value1, value2=5.0) + assert model.result() == 7.5 + assert func.compute() == -5.0 + + # We can always override the attributes with kwargs. + assert func.compute(value1=1.0, value2=4.0) == 5.0