From 4982de5196636485c24edb166aaf49badd68062a Mon Sep 17 00:00:00 2001 From: Jeremy Kubica <104161096+jeremykubica@users.noreply.github.com> Date: Sat, 13 Jul 2024 15:48:46 -0400 Subject: [PATCH 01/12] Add basic wrapped function --- src/tdastro/base_models.py | 58 ++++++++++--------- src/tdastro/function_wrappers.py | 69 +++++++++++++++++++++++ tests/tdastro/test_function_wrappers.py | 74 +++++++++++++++++++++++++ 3 files changed, 175 insertions(+), 26 deletions(-) create mode 100644 src/tdastro/function_wrappers.py create mode 100644 tests/tdastro/test_function_wrappers.py diff --git a/src/tdastro/base_models.py b/src/tdastro/base_models.py index a43319d2..d5b38b51 100644 --- a/src/tdastro/base_models.py +++ b/src/tdastro/base_models.py @@ -5,6 +5,8 @@ import numpy as np +from tdastro.function_wrappers import TDFunc + class ParameterSource(Enum): """ParameterSource specifies where a PhysicalModel should get the value @@ -14,8 +16,9 @@ class ParameterSource(Enum): CONSTANT = 1 FUNCTION = 2 - MODEL_ATTRIBUTE = 3 - MODEL_METHOD = 4 + TDFUNC_OBJ = 3 + MODEL_ATTRIBUTE = 4 + MODEL_METHOD = 5 class ParameterizedModel: @@ -75,31 +78,32 @@ def set_parameter(self, name, value=None, **kwargs): # The value wasn't set, but the name is in kwargs. value = kwargs[name] - if value is not None: - if isinstance(value, types.FunctionType): - # Case 1: If we are getting from a static function, sample it. - self.setters[name] = (ParameterSource.FUNCTION, value, required) - setattr(self, name, value(**kwargs)) - elif isinstance(value, types.MethodType) and isinstance(value.__self__, ParameterizedModel): - # Case 2: We are trying to use the method from a ParameterizedModel. - # Note that this will (correctly) fail if we are adding a model method from the current - # object that requires an unset attribute. - self.setters[name] = (ParameterSource.MODEL_METHOD, value, required) - setattr(self, name, value(**kwargs)) - elif isinstance(value, ParameterizedModel): - # Case 3: We are trying to access an attribute from a parameterized model. - if not hasattr(value, name): - raise ValueError(f"Attribute {name} missing from parent.") - self.setters[name] = (ParameterSource.MODEL_ATTRIBUTE, value, required) - setattr(self, name, getattr(value, name)) - else: - # Case 4: The value is constant. - self.setters[name] = (ParameterSource.CONSTANT, value, required) - setattr(self, name, value) - elif not required: - self.setters[name] = (ParameterSource.CONSTANT, None, required) - setattr(self, name, None) + if isinstance(value, types.FunctionType): + # Case 1a: If we are getting from a static function, sample it. + self.setters[name] = (ParameterSource.FUNCTION, value, required) + setattr(self, name, value(**kwargs)) + elif isinstance(value, TDFunc): + # Case 1b: We are using a TDFunc wrapped function (with default parameters). + self.setters[name] = (ParameterSource.TDFUNC_OBJ, value, required) + setattr(self, name, value(self, **kwargs)) + elif isinstance(value, types.MethodType) and isinstance(value.__self__, ParameterizedModel): + # Case 2: We are trying to use the method from a ParameterizedModel. + # Note that this will (correctly) fail if we are adding a model method from the current + # object that requires an unset attribute. + self.setters[name] = (ParameterSource.MODEL_METHOD, value, required) + setattr(self, name, value(**kwargs)) + elif isinstance(value, ParameterizedModel): + # Case 3: We are trying to access an attribute from a parameterized model. + if not hasattr(value, name): + raise ValueError(f"Attribute {name} missing from parent.") + self.setters[name] = (ParameterSource.MODEL_ATTRIBUTE, value, required) + setattr(self, name, getattr(value, name)) else: + # Case 4: The value is constant (including None). + self.setters[name] = (ParameterSource.CONSTANT, value, required) + setattr(self, name, value) + + if required and getattr(self, name) is None: raise ValueError(f"Missing required parameter {name}") def add_parameter(self, name, value=None, required=False, **kwargs): @@ -169,6 +173,8 @@ def sample_parameters(self, max_depth=50, **kwargs): sampled_value = setter elif source_type == ParameterSource.FUNCTION: sampled_value = setter(**kwargs) + elif source_type == ParameterSource.TDFUNC_OBJ: + sampled_value = setter(self, **kwargs) elif source_type == ParameterSource.MODEL_ATTRIBUTE: # Check if we need to resample the parent (needs to be done before # we read its attribute). diff --git a/src/tdastro/function_wrappers.py b/src/tdastro/function_wrappers.py new file mode 100644 index 00000000..03f4ef0d --- /dev/null +++ b/src/tdastro/function_wrappers.py @@ -0,0 +1,69 @@ +"""Classes to wrap functions to allow users to pass around functions +with partially specified arguments. +""" + +import copy + + +class TDFunc: + """A class to wrap functions to pass around functions with default + argument settings, arguments from kwargs, and (optionally) + arguments that are from the fcalling object. + + The object stores default arguments for the function, which does not + need to include all the function's parameters. Any parameters not + included in the default list must be specified as part of the kwargs. + + Attributes + ---------- + func : `function` or `method` + The function to call during an evaluation. + default_args : `dict` + A dictionary of default arguments to pass to the function. + This does not need to include all the arguments. + object_args : `list`, optional + Arguments that are provided by attributes of the calling object. + """ + + def __init__(self, func, default_args=None, object_args=None, **kwargs): + self.func = func + self.object_args = object_args + self.default_args = {} + + # The default arguments are the union of the default_args parameter + # and the remaining kwargs. + if default_args is not None: + self.default_args = default_args + if kwargs: + for key, value in kwargs.items(): + self.default_args[key] = value + + def __str__(self): + """Return the string representation of the function.""" + return f"TDFunc({self.func.name})" + + def __call__(self, calling_object=None, **kwargs): + """Execute the wrapped function. + + Parameters + ---------- + calling_object : any, optional + The object that called the function. + **kwargs : `dict`, optional + Additional function arguments. + """ + # Start with the default arguments. We make a copy so we can modify the dictionary. + args = copy.copy(self.default_args) + + # If there are arguments to get from the calling object, set those. + if self.object_args is not None and len(self.object_args) > 0: + if calling_object is None: + raise ValueError(f"Calling object needed for parameters: {self.object_args}") + for arg_name in self.object_args: + args[arg_name] = getattr(calling_object, arg_name) + + # Set any last arguments from the kwargs (overwriting previous settings). + args.update(kwargs) + + # Call the function with all the parameters. + return self.func(**args) diff --git a/tests/tdastro/test_function_wrappers.py b/tests/tdastro/test_function_wrappers.py new file mode 100644 index 00000000..40ea9e28 --- /dev/null +++ b/tests/tdastro/test_function_wrappers.py @@ -0,0 +1,74 @@ +import pytest +from tdastro.function_wrappers import TDFunc + + +def _test_func(a, b): + """Return the sum of the two parameters. + + Parameters + ---------- + a : `float` + The first parameter. + b : `float` + The second parameter. + """ + return a + b + + +class _StaticModel: + """A test model that has given parameters. + + Attributes + ---------- + a : `float` + The first parameter. + b : `float` + The second parameter. + """ + + def __init__(self, a, b): + self.a = a + self.b = b + + def eval_func(self, func, **kwargs): + """Evaluate a TDFunc. + + Parameters + ---------- + func : `TDFunc` + The function to evaluate. + **kwargs : `dict`, optional + Any additional keyword arguments. + """ + return func(self, **kwargs) + + +def test_tdfunc_basic(): + """Test that we can create and query a TDFunc.""" + tdf1 = TDFunc(_test_func, a=1.0) + + # Fail without enough arguments (only a is specified). + with pytest.raises(TypeError): + _ = tdf1() + + # We succeed with a manually specified parameter. + assert tdf1(b=2.0) == 3.0 + + # We can overwrite parameters. + assert tdf1(a=3.0, b=2.0) == 5.0 + + +def test_tdfunc_obj(): + """Test that we can create and query a TDFunc that depends on an object.""" + model = _StaticModel(a=10.0, b=11.0) + + # Check a function without defaults. + tdf1 = TDFunc(_test_func, object_args=["a", "b"]) + assert model.eval_func(tdf1) == 21.0 + + # Defaults set are overwritten by the object. + tdf2 = TDFunc(_test_func, object_args=["a", "b"], a=1.0, b=0.0) + assert model.eval_func(tdf2) == 21.0 + + # But we can overwrite everything with kwargs. + assert model.eval_func(tdf2, b=7.5) == 17.5 From a97242c7c6964174caf9c241444cae68fb3ba1b2 Mon Sep 17 00:00:00 2001 From: Jeremy Kubica <104161096+jeremykubica@users.noreply.github.com> Date: Sun, 14 Jul 2024 09:30:06 -0400 Subject: [PATCH 02/12] Additional tests --- tests/tdastro/sources/test_step_source.py | 19 ++++++++++++------- tests/tdastro/test_function_wrappers.py | 4 ++++ 2 files changed, 16 insertions(+), 7 deletions(-) diff --git a/tests/tdastro/sources/test_step_source.py b/tests/tdastro/sources/test_step_source.py index dc069241..325e73f2 100644 --- a/tests/tdastro/sources/test_step_source.py +++ b/tests/tdastro/sources/test_step_source.py @@ -1,6 +1,7 @@ import random import numpy as np +from tdastro.function_wrappers import TDFunc from tdastro.sources.static_source import StaticSource from tdastro.sources.step_source import StepSource @@ -18,7 +19,7 @@ def _sample_brightness(magnitude, **kwargs): return magnitude * random.random() -def _sample_end(duration, **kwargs): +def _sample_end(duration): """Return a random value between 1 and 1 + duration Parameters @@ -54,20 +55,20 @@ def test_step_source() -> None: def test_step_source_resample() -> None: """Check that we can call resample on the model parameters.""" + random.seed(1111) + model = StepSource( - brightness=_sample_brightness, + brightness=TDFunc(_sample_brightness, magnitude=100.0), t0=0.0, - t1=_sample_end, - magnitude=100.0, - duration=5.0, + t1=TDFunc(_sample_end, duration=5.0), ) - num_samples = 100 + num_samples = 1000 brightness_vals = np.zeros((num_samples, 1)) t_end_vals = np.zeros((num_samples, 1)) t_start_vals = np.zeros((num_samples, 1)) for i in range(num_samples): - model.sample_parameters(magnitude=100.0, duration=5.0) + model.sample_parameters() brightness_vals[i] = model.brightness t_end_vals[i] = model.t1 t_start_vals[i] = model.t0 @@ -79,6 +80,10 @@ def test_step_source_resample() -> None: assert np.all(t_end_vals >= 1.0) assert np.all(t_end_vals <= 6.0) + # Check that the expected values are close. + assert abs(np.mean(brightness_vals) - 50.0) < 5.0 + assert abs(np.mean(t_end_vals) - 3.5) < 0.5 + # Check that the brightness or end values are not all the same. assert not np.all(brightness_vals == brightness_vals[0]) assert not np.all(t_end_vals == t_end_vals[0]) diff --git a/tests/tdastro/test_function_wrappers.py b/tests/tdastro/test_function_wrappers.py index 40ea9e28..13ddaa3e 100644 --- a/tests/tdastro/test_function_wrappers.py +++ b/tests/tdastro/test_function_wrappers.py @@ -57,6 +57,10 @@ def test_tdfunc_basic(): # We can overwrite parameters. assert tdf1(a=3.0, b=2.0) == 5.0 + # That we can use a different ordering for parameters. + tdf2 = TDFunc(_test_func, b=2.0, a=1.0) + assert tdf2() == 3.0 + def test_tdfunc_obj(): """Test that we can create and query a TDFunc that depends on an object.""" From 897e9d625aaf7f706b74f273b59bd5018dcae557 Mon Sep 17 00:00:00 2001 From: Jeremy Kubica <104161096+jeremykubica@users.noreply.github.com> Date: Sun, 14 Jul 2024 10:24:30 -0400 Subject: [PATCH 03/12] Add basic cosmology --- src/tdastro/astro_utils/cosmology.py | 59 +++++++++++++++++++++ src/tdastro/base_models.py | 22 ++++++-- tests/tdastro/astro_utils/test_cosmology.py | 10 ++++ tests/tdastro/test_base_models.py | 31 ++++++++++- 4 files changed, 117 insertions(+), 5 deletions(-) create mode 100644 src/tdastro/astro_utils/cosmology.py create mode 100644 tests/tdastro/astro_utils/test_cosmology.py diff --git a/src/tdastro/astro_utils/cosmology.py b/src/tdastro/astro_utils/cosmology.py new file mode 100644 index 00000000..f7fbb6ba --- /dev/null +++ b/src/tdastro/astro_utils/cosmology.py @@ -0,0 +1,59 @@ +import astropy.cosmology.units as cu +from astropy import units as u + +from tdastro.function_wrappers import TDFunc + + +def redshift_to_distance(redshift, cosmology, kind="comoving"): + """Compute a source's distance given its redshift and a + specified cosmology using astropy's redshift_distance(). + + Parameters + ---------- + redshift : `float` + The redshift value. + cosmology : `astropy.cosmology` + The cosmology specification. + kind : `str` + The distance type for the Equivalency as defined by + astropy.cosmology.units.redshift_distance. + + Returns + ------- + distance : `float` + The distance (in pc) + """ + z = redshift * cu.redshift + distance = z.to(u.pc, cu.redshift_distance(cosmology, kind=kind)) + return distance.value + + +class RedshiftDistFunc(TDFunc): + """A wrapper class for the redshift_to_distance() function. + + Attributes + ---------- + cosmology : `astropy.cosmology` + The cosmology specification. + kind : `str` + The distance type for the Equivalency as defined by + astropy.cosmology.units.redshift_distance. + """ + + def __init__(self, cosmology, kind="comoving", **kwargs): + self.cosmology = cosmology + self.kind = kind + default_args = {"cosmology": cosmology, "kind": kind} + + # Call the super class's constructor with the needed information. + # Do not pass kwargs because we are limiting the arguments to match + # the signature of redshift_to_distance(). + super().__init__( + func=redshift_to_distance, + default_args=default_args, + object_args=["redshift"], + ) + + def __str__(self): + """Return the string representation of the function.""" + return f"RedshiftDistFunc({self.cosmology.name}, {self.kind})" diff --git a/src/tdastro/base_models.py b/src/tdastro/base_models.py index d5b38b51..86b24493 100644 --- a/src/tdastro/base_models.py +++ b/src/tdastro/base_models.py @@ -5,6 +5,7 @@ import numpy as np +from tdastro.astro_utils.cosmology import RedshiftDistFunc from tdastro.function_wrappers import TDFunc @@ -207,20 +208,33 @@ class PhysicalModel(ParameterizedModel): The object's right ascension (in degrees) dec : `float` The object's declination (in degrees) + redshift : `float` + The object's redshift. distance : `float` - The object's distance (in pc) + The object's distance (in pc). If the distance is not provided and + a ``cosmology`` parameter is given, the model will try to derive from + the redshift and the cosmology. effects : `list` A list of effects to apply to an observations. """ - def __init__(self, ra=None, dec=None, distance=None, **kwargs): + def __init__(self, ra=None, dec=None, redshift=None, distance=None, **kwargs): super().__init__(**kwargs) self.effects = [] - # Set RA, dec, and distance from the parameters. + # Set RA, dec, and redshift from the parameters. self.add_parameter("ra", ra) self.add_parameter("dec", dec) - self.add_parameter("distance", distance) + self.add_parameter("redshift", redshift) + + # If the distance is provided, use that. Otherwise try the redshift value + # using the cosmology (if given). Finally, default to None. + if distance is not None: + self.add_parameter("distance", distance) + elif redshift is not None and kwargs.get("cosmology", None) is not None: + self.add_parameter("distance", RedshiftDistFunc(**kwargs)) + else: + self.add_parameter("distance", None) def __str__(self): """Return the string representation of the model.""" diff --git a/tests/tdastro/astro_utils/test_cosmology.py b/tests/tdastro/astro_utils/test_cosmology.py new file mode 100644 index 00000000..14a500e2 --- /dev/null +++ b/tests/tdastro/astro_utils/test_cosmology.py @@ -0,0 +1,10 @@ +import pytest +from astropy.cosmology import WMAP9 +from tdastro.astro_utils.cosmology import redshift_to_distance + + +def test_redshift_to_distance(): + """Test that we can convert the redshift to a distance using a given cosmology.""" + # Use the example from: + # https://docs.astropy.org/en/stable/api/astropy.cosmology.units.redshift_distance.html + assert redshift_to_distance(1100, cosmology=WMAP9) == pytest.approx(14004.03157418 * 1e6) diff --git a/tests/tdastro/test_base_models.py b/tests/tdastro/test_base_models.py index b5ff1daa..6847cffa 100644 --- a/tests/tdastro/test_base_models.py +++ b/tests/tdastro/test_base_models.py @@ -1,7 +1,8 @@ import random import pytest -from tdastro.base_models import ParameterizedModel +from astropy.cosmology import WMAP9 +from tdastro.base_models import ParameterizedModel, PhysicalModel def _sampler_fun(**kwargs): @@ -141,3 +142,31 @@ def test_parameterized_model_modify() -> None: # We cannot set a value that hasn't been added. with pytest.raises(KeyError): model.set_parameter("brightness", 5.0) + + +def test_physical_model(): + """Test that we can create a physical model.""" + # Everything is specified. + model1 = PhysicalModel(ra=1.0, dec=2.0, distance=3.0, redshift=0.0) + assert model1.ra == 1.0 + assert model1.dec == 2.0 + assert model1.distance == 3.0 + assert model1.redshift == 0.0 + + # Derive the distance from the redshift using the example from: + # https://docs.astropy.org/en/stable/api/astropy.cosmology.units.redshift_distance.html + model2 = PhysicalModel(ra=1.0, dec=2.0, redshift=1100.0, cosmology=WMAP9) + assert model2.ra == 1.0 + assert model2.dec == 2.0 + assert model2.redshift == 1100.0 + assert model2.distance == pytest.approx(14004.03157418 * 1e6) + + # Neither distance nor redshift are specified. + model3 = PhysicalModel(ra=1.0, dec=2.0) + assert model3.redshift is None + assert model3.distance is None + + # Redshift is specified but cosmology is not. + model4 = PhysicalModel(ra=1.0, dec=2.0, redshift=1100.0) + assert model4.redshift == 1100.0 + assert model4.distance is None From 748447cd6a28f0f3545f0d96e1d0ceaf5ed9073f Mon Sep 17 00:00:00 2001 From: Jeremy Kubica <104161096+jeremykubica@users.noreply.github.com> Date: Sun, 14 Jul 2024 11:11:52 -0400 Subject: [PATCH 04/12] Add more comments and tests --- src/tdastro/function_wrappers.py | 41 ++++++++++++++++++------- tests/tdastro/test_function_wrappers.py | 38 ++++++++++++++++++----- 2 files changed, 61 insertions(+), 18 deletions(-) diff --git a/src/tdastro/function_wrappers.py b/src/tdastro/function_wrappers.py index 03f4ef0d..ed24aa3f 100644 --- a/src/tdastro/function_wrappers.py +++ b/src/tdastro/function_wrappers.py @@ -8,21 +8,34 @@ class TDFunc: """A class to wrap functions to pass around functions with default argument settings, arguments from kwargs, and (optionally) - arguments that are from the fcalling object. - - The object stores default arguments for the function, which does not - need to include all the function's parameters. Any parameters not - included in the default list must be specified as part of the kwargs. + arguments that are from the calling object. Attributes ---------- func : `function` or `method` The function to call during an evaluation. default_args : `dict` - A dictionary of default arguments to pass to the function. - This does not need to include all the arguments. - object_args : `list`, optional + A dictionary of default arguments to pass to the function. Assembled + from the ```default_args`` parameter and additional ``kwargs``. + object_args : `list` Arguments that are provided by attributes of the calling object. + + Example + ------- + my_func = TDFunc(random.randint, a=1, b=10) + value1 = my_func() # Sample from default range + value2 = my_func(b=20) # Sample from extended range + + Note + ---- + All the function's parameters that will be used need to be specified + in either the default_args dict, object_args list, or as a kwarg in the + constructor. Arguments cannot be first given during function call. + For example the following will fail (because b is not defined in the + constructor): + + my_func = TDFunc(random.randint, a=1) + value1 = my_func(b=10.0) """ def __init__(self, func, default_args=None, object_args=None, **kwargs): @@ -30,13 +43,16 @@ def __init__(self, func, default_args=None, object_args=None, **kwargs): self.object_args = object_args self.default_args = {} - # The default arguments are the union of the default_args parameter - # and the remaining kwargs. + # The default arguments are the union of the default_args parameter, + # the object_args (set to None), and the remaining kwargs. if default_args is not None: self.default_args = default_args if kwargs: for key, value in kwargs.items(): self.default_args[key] = value + if object_args: + for key in object_args: + self.default_args[key] = None def __str__(self): """Return the string representation of the function.""" @@ -63,7 +79,10 @@ def __call__(self, calling_object=None, **kwargs): args[arg_name] = getattr(calling_object, arg_name) # Set any last arguments from the kwargs (overwriting previous settings). - args.update(kwargs) + # Only use known kwargs. + for key, value in kwargs.items(): + if key in self.default_args: + args[key] = value # Call the function with all the parameters. return self.func(**args) diff --git a/tests/tdastro/test_function_wrappers.py b/tests/tdastro/test_function_wrappers.py index 13ddaa3e..ed59170f 100644 --- a/tests/tdastro/test_function_wrappers.py +++ b/tests/tdastro/test_function_wrappers.py @@ -1,3 +1,4 @@ +import numpy as np import pytest from tdastro.function_wrappers import TDFunc @@ -45,21 +46,44 @@ def eval_func(self, func, **kwargs): def test_tdfunc_basic(): """Test that we can create and query a TDFunc.""" - tdf1 = TDFunc(_test_func, a=1.0) - # Fail without enough arguments (only a is specified). + tdf1 = TDFunc(_test_func, a=1.0) with pytest.raises(TypeError): _ = tdf1() - # We succeed with a manually specified parameter. - assert tdf1(b=2.0) == 3.0 + # We succeed with a manually specified parameter (but first fail with + # the default). + tdf2 = TDFunc(_test_func, a=1.0, b=None) + with pytest.raises(TypeError): + _ = tdf2() + assert tdf2(b=2.0) == 3.0 # We can overwrite parameters. - assert tdf1(a=3.0, b=2.0) == 5.0 + assert tdf2(a=3.0, b=2.0) == 5.0 + + # Test that we ignore extra kwargs. + assert tdf2(b=2.0, c=10.0, d=11.0) == 3.0 # That we can use a different ordering for parameters. - tdf2 = TDFunc(_test_func, b=2.0, a=1.0) - assert tdf2() == 3.0 + tdf3 = TDFunc(_test_func, b=2.0, a=1.0) + assert tdf3() == 3.0 + + +def test_np_sampler_method(): + """Test that we can wrap numpy random functions.""" + rng = np.random.default_rng(1001) + tdf = TDFunc(rng.normal, loc=10.0, scale=1.0) + + # Sample 1000 times with the default values. Check that we are near + # the expected mean and not everything is equal. + vals = np.array([tdf() for _ in range(1000)]) + assert abs(np.mean(vals) - 10.0) < 1.0 + assert not np.all(vals == vals[0]) + + # Override the mean and resample. + vals = np.array([tdf(loc=25.0) for _ in range(1000)]) + assert abs(np.mean(vals) - 25.0) < 1.0 + assert not np.all(vals == vals[0]) def test_tdfunc_obj(): From bfc57227b6cc4b7cea81abc4677cbff47c2aff90 Mon Sep 17 00:00:00 2001 From: Jeremy Kubica <104161096+jeremykubica@users.noreply.github.com> Date: Sun, 14 Jul 2024 11:27:42 -0400 Subject: [PATCH 05/12] Update population test to use TDFunc --- tests/tdastro/populations/test_fixed_population.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/tests/tdastro/populations/test_fixed_population.py b/tests/tdastro/populations/test_fixed_population.py index 134ed459..95b0e0ae 100644 --- a/tests/tdastro/populations/test_fixed_population.py +++ b/tests/tdastro/populations/test_fixed_population.py @@ -3,6 +3,7 @@ import numpy as np import pytest from tdastro.effects.white_noise import WhiteNoise +from tdastro.function_wrappers import TDFunc from tdastro.populations.fixed_population import FixedPopulation from tdastro.sources.static_source import StaticSource @@ -84,17 +85,12 @@ def test_fixed_population_sample_sources(): assert np.allclose(counts, [0.4 * itr, 0.2 * itr, 0.4 * itr], rtol=0.05) -def _random_brightness(): - """Returns a random value [0.0, 100.0]""" - return 100.0 * random.random() - - def test_fixed_population_sample_fluxes(): """Test that we can create a population of sources and sample its sources' flux.""" random.seed(1001) population = FixedPopulation() population.add_source(StaticSource(brightness=100.0), 10.0) - population.add_source(StaticSource(brightness=_random_brightness), 10.0) + population.add_source(StaticSource(brightness=TDFunc(random.uniform, a=0.0, b=100.0)), 10.0) population.add_source(StaticSource(brightness=200.0), 20.0) population.add_source(StaticSource(brightness=150.0), 10.0) From b9bc87cd78da53e86ae3760e29aa3cc99ab56af9 Mon Sep 17 00:00:00 2001 From: Jeremy Kubica <104161096+jeremykubica@users.noreply.github.com> Date: Sun, 14 Jul 2024 16:10:48 -0400 Subject: [PATCH 06/12] All function chaining --- src/tdastro/astro_utils/cosmology.py | 18 +++++++-- src/tdastro/base_models.py | 10 +++-- src/tdastro/function_wrappers.py | 52 +++++++++++-------------- tests/tdastro/test_function_wrappers.py | 41 +++++++++++-------- 4 files changed, 67 insertions(+), 54 deletions(-) diff --git a/src/tdastro/astro_utils/cosmology.py b/src/tdastro/astro_utils/cosmology.py index f7fbb6ba..ecbd16ab 100644 --- a/src/tdastro/astro_utils/cosmology.py +++ b/src/tdastro/astro_utils/cosmology.py @@ -38,20 +38,30 @@ class RedshiftDistFunc(TDFunc): kind : `str` The distance type for the Equivalency as defined by astropy.cosmology.units.redshift_distance. + + Parameters + ---------- + redshift : function or constant + The function or constant providing the redshift value. + cosmology : `astropy.cosmology` + The cosmology specification. + kind : `str` + The distance type for the Equivalency as defined by + astropy.cosmology.units.redshift_distance. """ - def __init__(self, cosmology, kind="comoving", **kwargs): + def __init__(self, redshift, cosmology, kind="comoving"): self.cosmology = cosmology self.kind = kind - default_args = {"cosmology": cosmology, "kind": kind} # Call the super class's constructor with the needed information. # Do not pass kwargs because we are limiting the arguments to match # the signature of redshift_to_distance(). super().__init__( func=redshift_to_distance, - default_args=default_args, - object_args=["redshift"], + redshift=redshift, + cosmology=cosmology, + kind=kind, ) def __str__(self): diff --git a/src/tdastro/base_models.py b/src/tdastro/base_models.py index d612d5fd..5e5746f3 100644 --- a/src/tdastro/base_models.py +++ b/src/tdastro/base_models.py @@ -86,7 +86,7 @@ def set_parameter(self, name, value=None, **kwargs): elif isinstance(value, TDFunc): # Case 1b: We are using a TDFunc wrapped function (with default parameters). self.setters[name] = (ParameterSource.TDFUNC_OBJ, value, required) - setattr(self, name, value(self, **kwargs)) + setattr(self, name, value(**kwargs)) elif isinstance(value, types.MethodType) and isinstance(value.__self__, ParameterizedModel): # Case 2: We are trying to use the method from a ParameterizedModel. # Note that this will (correctly) fail if we are adding a model method from the current @@ -175,7 +175,7 @@ def sample_parameters(self, max_depth=50, **kwargs): elif source_type == ParameterSource.FUNCTION: sampled_value = setter(**kwargs) elif source_type == ParameterSource.TDFUNC_OBJ: - sampled_value = setter(self, **kwargs) + sampled_value = setter(**kwargs) elif source_type == ParameterSource.MODEL_ATTRIBUTE: # Check if we need to resample the parent (needs to be done before # we read its attribute). @@ -238,7 +238,7 @@ def __init__(self, ra=None, dec=None, redshift=None, distance=None, background=N if distance is not None: self.add_parameter("distance", distance) elif redshift is not None and kwargs.get("cosmology", None) is not None: - self.add_parameter("distance", RedshiftDistFunc(**kwargs)) + self.add_parameter("distance", RedshiftDistFunc(redshift=self.get_redshift, **kwargs)) else: self.add_parameter("distance", None) @@ -249,6 +249,10 @@ def __str__(self): """Return the string representation of the model.""" return "PhysicalModel" + def get_redshift(self): + """Return the redshift for the model.""" + return self.redshift + def add_effect(self, effect, allow_dups=True, **kwargs): """Add a transformational effect to the PhysicalModel. Effects are applied in the order in which they are added. diff --git a/src/tdastro/function_wrappers.py b/src/tdastro/function_wrappers.py index ed24aa3f..5a7aad6e 100644 --- a/src/tdastro/function_wrappers.py +++ b/src/tdastro/function_wrappers.py @@ -1,14 +1,10 @@ -"""Classes to wrap functions to allow users to pass around functions -with partially specified arguments. -""" +"""Utilities to wrap functions for inclusion in an evaluation graph.""" import copy class TDFunc: - """A class to wrap functions to pass around functions with default - argument settings, arguments from kwargs, and (optionally) - arguments that are from the calling object. + """A class to wrap functions and their argument settings. Attributes ---------- @@ -17,11 +13,13 @@ class TDFunc: default_args : `dict` A dictionary of default arguments to pass to the function. Assembled from the ```default_args`` parameter and additional ``kwargs``. - object_args : `list` - Arguments that are provided by attributes of the calling object. + setter_functions : `dict` + A dictionary mapping arguments names to functions, methods, or + TDFunc objects used to set that argument. These are evaluated dynamically + each time. - Example - ------- + Examples + -------- my_func = TDFunc(random.randint, a=1, b=10) value1 = my_func() # Sample from default range value2 = my_func(b=20) # Sample from extended range @@ -38,45 +36,39 @@ class TDFunc: value1 = my_func(b=10.0) """ - def __init__(self, func, default_args=None, object_args=None, **kwargs): + def __init__(self, func, **kwargs): self.func = func - self.object_args = object_args self.default_args = {} + self.setter_functions = {} - # The default arguments are the union of the default_args parameter, - # the object_args (set to None), and the remaining kwargs. - if default_args is not None: - self.default_args = default_args - if kwargs: - for key, value in kwargs.items(): - self.default_args[key] = value - if object_args: - for key in object_args: + for key, value in kwargs.items(): + if callable(value): self.default_args[key] = None + self.setter_functions[key] = value + else: + self.default_args[key] = value def __str__(self): """Return the string representation of the function.""" return f"TDFunc({self.func.name})" - def __call__(self, calling_object=None, **kwargs): + def __call__(self, **kwargs): """Execute the wrapped function. Parameters ---------- - calling_object : any, optional - The object that called the function. **kwargs : `dict`, optional Additional function arguments. """ # Start with the default arguments. We make a copy so we can modify the dictionary. args = copy.copy(self.default_args) - # If there are arguments to get from the calling object, set those. - if self.object_args is not None and len(self.object_args) > 0: - if calling_object is None: - raise ValueError(f"Calling object needed for parameters: {self.object_args}") - for arg_name in self.object_args: - args[arg_name] = getattr(calling_object, arg_name) + # If there are arguments to get from the calling functions, set those. + for key, value in self.setter_functions.items(): + if isinstance(value, TDFunc): + args[key] = value(**kwargs) + else: + args[key] = value() # Set any last arguments from the kwargs (overwriting previous settings). # Only use known kwargs. diff --git a/tests/tdastro/test_function_wrappers.py b/tests/tdastro/test_function_wrappers.py index ed59170f..e7048391 100644 --- a/tests/tdastro/test_function_wrappers.py +++ b/tests/tdastro/test_function_wrappers.py @@ -31,17 +31,13 @@ def __init__(self, a, b): self.a = a self.b = b - def eval_func(self, func, **kwargs): - """Evaluate a TDFunc. + def get_a(self): + """Get the a attribute.""" + return self.a - Parameters - ---------- - func : `TDFunc` - The function to evaluate. - **kwargs : `dict`, optional - Any additional keyword arguments. - """ - return func(self, **kwargs) + def get_b(self): + """Get the b attribute.""" + return self.b def test_tdfunc_basic(): @@ -69,6 +65,16 @@ def test_tdfunc_basic(): assert tdf3() == 3.0 +def test_tdfunc_chain(): + """Test that we can create and query a chained TDFunc.""" + tdf1 = TDFunc(_test_func, a=1.0, b=1.0) + tdf2 = TDFunc(_test_func, a=tdf1, b=3.0) + assert tdf2() == 5.0 + + # This will overwrite all the b parameters. + assert tdf2(b=10.0) == 21.0 + + def test_np_sampler_method(): """Test that we can wrap numpy random functions.""" rng = np.random.default_rng(1001) @@ -91,12 +97,13 @@ def test_tdfunc_obj(): model = _StaticModel(a=10.0, b=11.0) # Check a function without defaults. - tdf1 = TDFunc(_test_func, object_args=["a", "b"]) - assert model.eval_func(tdf1) == 21.0 + tdf1 = TDFunc(_test_func, a=model.get_a, b=model.get_b) + assert tdf1() == 21.0 - # Defaults set are overwritten by the object. - tdf2 = TDFunc(_test_func, object_args=["a", "b"], a=1.0, b=0.0) - assert model.eval_func(tdf2) == 21.0 + # We can pull from multiple models. + model2 = _StaticModel(a=1.0, b=0.0) + tdf2 = TDFunc(_test_func, a=model.get_a, b=model2.get_b) + assert tdf2() == 10.0 - # But we can overwrite everything with kwargs. - assert model.eval_func(tdf2, b=7.5) == 17.5 + # W can overwrite everything with kwargs. + assert tdf2(b=7.5) == 17.5 From 96f52bca158c7c344987ef8e6ab80294dc03e590 Mon Sep 17 00:00:00 2001 From: Jeremy Kubica <104161096+jeremykubica@users.noreply.github.com> Date: Sun, 14 Jul 2024 16:37:18 -0400 Subject: [PATCH 07/12] refactor --- src/tdastro/base_models.py | 334 +----------------- src/tdastro/effects/effect_model.py | 47 +++ src/tdastro/effects/white_noise.py | 2 +- src/tdastro/populations/fixed_population.py | 2 +- src/tdastro/populations/population_model.py | 106 ++++++ src/tdastro/sources/galaxy_models.py | 2 +- src/tdastro/sources/periodic_source.py | 2 +- src/tdastro/sources/physical_model.py | 167 +++++++++ src/tdastro/sources/sncomso_models.py | 2 +- src/tdastro/sources/spline_model.py | 2 +- src/tdastro/sources/static_source.py | 2 +- tests/tdastro/sources/test_physical_models.py | 31 ++ tests/tdastro/test_base_models.py | 37 +- 13 files changed, 374 insertions(+), 362 deletions(-) create mode 100644 src/tdastro/effects/effect_model.py create mode 100644 src/tdastro/populations/population_model.py create mode 100644 src/tdastro/sources/physical_model.py create mode 100644 tests/tdastro/sources/test_physical_models.py diff --git a/src/tdastro/base_models.py b/src/tdastro/base_models.py index 5e5746f3..d94ece2d 100644 --- a/src/tdastro/base_models.py +++ b/src/tdastro/base_models.py @@ -3,9 +3,6 @@ import types from enum import Enum -import numpy as np - -from tdastro.astro_utils.cosmology import RedshiftDistFunc from tdastro.function_wrappers import TDFunc @@ -22,9 +19,9 @@ class ParameterSource(Enum): MODEL_METHOD = 5 -class ParameterizedModel: +class ParameterizedNode: """Any model that uses parameters that can be set by constants, - functions, or other parameterized models. ParameterizedModels can + functions, or other parameterized models. ParameterizedNode can include physical objects or statistical distributions. Attributes @@ -34,7 +31,7 @@ class ParameterizedModel: (ParameterSource, setter information, required). The attributes are stored in the order in which they need to be set. sample_iteration : `int` - A counter used to syncronize sampling runs. Tracks how many times this + A counter used to syncronize sampling runs. Tracks how many times this model's parameters have been resampled. """ @@ -44,10 +41,10 @@ def __init__(self, **kwargs): def __str__(self): """Return the string representation of the model.""" - return "ParameterizedModel" + return "ParameterizedNode" def set_parameter(self, name, value=None, **kwargs): - """Set a single *existing* parameter to the ParameterizedModel. + """Set a single *existing* parameter to the ParameterizedNode. Notes ----- @@ -60,7 +57,7 @@ def set_parameter(self, name, value=None, **kwargs): The parameter name to add. value : any, optional The information to use to set the parameter. Can be a constant, - function, ParameterizedModel, or self. + function, ParameterizedNode, or self. **kwargs : `dict`, optional All other keyword arguments, possibly including the parameter setters. @@ -87,13 +84,13 @@ def set_parameter(self, name, value=None, **kwargs): # Case 1b: We are using a TDFunc wrapped function (with default parameters). self.setters[name] = (ParameterSource.TDFUNC_OBJ, value, required) setattr(self, name, value(**kwargs)) - elif isinstance(value, types.MethodType) and isinstance(value.__self__, ParameterizedModel): - # Case 2: We are trying to use the method from a ParameterizedModel. + elif isinstance(value, types.MethodType) and isinstance(value.__self__, ParameterizedNode): + # Case 2: We are trying to use the method from a ParameterizedNode. # Note that this will (correctly) fail if we are adding a model method from the current # object that requires an unset attribute. self.setters[name] = (ParameterSource.MODEL_METHOD, value, required) setattr(self, name, value(**kwargs)) - elif isinstance(value, ParameterizedModel): + elif isinstance(value, ParameterizedNode): # Case 3: We are trying to access an attribute from a parameterized model. if not hasattr(value, name): raise ValueError(f"Attribute {name} missing from parent.") @@ -108,7 +105,7 @@ def set_parameter(self, name, value=None, **kwargs): raise ValueError(f"Missing required parameter {name}") def add_parameter(self, name, value=None, required=False, **kwargs): - """Add a single *new* parameter to the ParameterizedModel. + """Add a single *new* parameter to the ParameterizedNode. Notes ----- @@ -123,7 +120,7 @@ def add_parameter(self, name, value=None, required=False, **kwargs): The parameter name to add. value : any, optional The information to use to set the parameter. Can be a constant, - function, ParameterizedModel, or self. + function, ParameterizedNode, or self. required : `bool` Fail if the parameter is set to ``None``. **kwargs : `dict`, optional @@ -146,7 +143,7 @@ def add_parameter(self, name, value=None, required=False, **kwargs): def sample_parameters(self, max_depth=50, **kwargs): """Sample the model's underlying parameters if they are provided by a function - or ParameterizedModel. + or ParameterizedNode. Parameters ---------- @@ -195,310 +192,3 @@ def sample_parameters(self, max_depth=50, **kwargs): # Increase the sampling iteration. self.sample_iteration += 1 - - -class PhysicalModel(ParameterizedModel): - """A physical model of a source of flux. - - Physical models can have fixed attributes (where you need to create a new model - to change them) and settable attributes that can be passed functions or constants. - They can also have special background pointers that link to another PhysicalModel - producing flux. We can chain these to have a supernova in front of a star in front - of a static background. - - Attributes - ---------- - ra : `float` - The object's right ascension (in degrees) - dec : `float` - The object's declination (in degrees) - redshift : `float` - The object's redshift. - distance : `float` - The object's distance (in pc). If the distance is not provided and - a ``cosmology`` parameter is given, the model will try to derive from - the redshift and the cosmology. - background : `PhysicalModel` - A source of background flux such as a host galaxy. - effects : `list` - A list of effects to apply to an observations. - """ - - def __init__(self, ra=None, dec=None, redshift=None, distance=None, background=None, **kwargs): - super().__init__(**kwargs) - self.effects = [] - - # Set RA, dec, and redshift from the parameters. - self.add_parameter("ra", ra) - self.add_parameter("dec", dec) - self.add_parameter("redshift", redshift) - - # If the distance is provided, use that. Otherwise try the redshift value - # using the cosmology (if given). Finally, default to None. - if distance is not None: - self.add_parameter("distance", distance) - elif redshift is not None and kwargs.get("cosmology", None) is not None: - self.add_parameter("distance", RedshiftDistFunc(redshift=self.get_redshift, **kwargs)) - else: - self.add_parameter("distance", None) - - # Background is an object not a sampled parameter - self.background = background - - def __str__(self): - """Return the string representation of the model.""" - return "PhysicalModel" - - def get_redshift(self): - """Return the redshift for the model.""" - return self.redshift - - def add_effect(self, effect, allow_dups=True, **kwargs): - """Add a transformational effect to the PhysicalModel. - Effects are applied in the order in which they are added. - - Parameters - ---------- - effect : `EffectModel` - The effect to apply. - allow_dups : `bool` - Allow multiple effects of the same type. - Default = ``True`` - **kwargs : `dict`, optional - Any additional keyword arguments. - - Raises - ------ - Raises a ``AttributeError`` if the PhysicalModel does not have all of the - required attributes. - """ - # Check that we have not added this effect before. - if not allow_dups: - effect_type = type(effect) - for prev in self.effects: - if effect_type == type(prev): - raise ValueError("Added the effect type to a model {effect_type} more than once.") - - required: list = effect.required_parameters() - for parameter in required: - # Raise an AttributeError if the parameter is missing or set to None. - if getattr(self, parameter) is None: - raise AttributeError(f"Parameter {parameter} unset for model {type(self).__name__}") - - self.effects.append(effect) - - def _evaluate(self, times, wavelengths, **kwargs): - """Draw effect-free observations for this object. - - Parameters - ---------- - times : `numpy.ndarray` - A length T array of timestamps. - wavelengths : `numpy.ndarray`, optional - A length N array of wavelengths. - **kwargs : `dict`, optional - Any additional keyword arguments. - - Returns - ------- - flux_density : `numpy.ndarray` - A length T x N matrix of SED values. - """ - raise NotImplementedError() - - def evaluate(self, times, wavelengths, resample_parameters=False, **kwargs): - """Draw observations for this object and apply the noise. - - Parameters - ---------- - times : `numpy.ndarray` - A length T array of timestamps. - wavelengths : `numpy.ndarray`, optional - A length N array of wavelengths. - resample_parameters : `bool` - Treat this evaluation as a completely new object, resampling the - parameters from the original provided functions. - **kwargs : `dict`, optional - Any additional keyword arguments. - - Returns - ------- - flux_density : `numpy.ndarray` - A length T x N matrix of SED values. - """ - if resample_parameters: - self.sample_parameters(kwargs) - - # Compute the flux density for both the current object and add in anything - # behind it, such as a host galaxy. - flux_density = self._evaluate(times, wavelengths, **kwargs) - if self.background is not None: - flux_density += self.background._evaluate(times, wavelengths, ra=self.ra, dec=self.dec, **kwargs) - - for effect in self.effects: - flux_density = effect.apply(flux_density, wavelengths, self, **kwargs) - return flux_density - - def sample_parameters(self, include_effects=True, **kwargs): - """Sample the model's underlying parameters if they are provided by a function - or ParameterizedModel. - - Parameters - ---------- - include_effects : `bool` - Resample the parameters for the effects models. - **kwargs : `dict`, optional - All the keyword arguments, including the values needed to sample - parameters. - """ - if self.background is not None: - self.background.sample_parameters(include_effects, **kwargs) - super().sample_parameters(**kwargs) - - if include_effects: - for effect in self.effects: - effect.sample_parameters(**kwargs) - - -class EffectModel(ParameterizedModel): - """A physical or systematic effect to apply to an observation.""" - - def __init__(self, **kwargs): - super().__init__(**kwargs) - - def __str__(self): - """Return the string representation of the model.""" - return "EffectModel" - - def required_parameters(self): - """Returns a list of the parameters of a PhysicalModel - that this effect needs to access. - - Returns - ------- - parameters : `list` of `str` - A list of every required parameter the effect needs. - """ - return [] - - def apply(self, flux_density, wavelengths=None, physical_model=None, **kwargs): - """Apply the effect to observations (flux_density values) - - Parameters - ---------- - flux_density : `numpy.ndarray` - A length T X N matrix of flux density values. - wavelengths : `numpy.ndarray`, optional - A length N array of wavelengths. - physical_model : `PhysicalModel` - A PhysicalModel from which the effect may query parameters - such as redshift, position, or distance. - **kwargs : `dict`, optional - Any additional keyword arguments. - - Returns - ------- - flux_density : `numpy.ndarray` - A length T x N matrix of flux densities after the effect is applied. - """ - raise NotImplementedError() - - -class PopulationModel(ParameterizedModel): - """A model of a population of PhysicalModels. - - Attributes - ---------- - num_sources : `int` - The number of different sources in the population. - sources : `list` - A list of sources from which to draw. - """ - - def __init__(self, rng=None, **kwargs): - super().__init__(**kwargs) - self.num_sources = 0 - self.sources = [] - - def __str__(self): - """Return the string representation of the model.""" - return f"PopulationModel with {self.num_sources} sources." - - def add_source(self, new_source, **kwargs): - """Add a new source to the population. - - Parameters - ---------- - new_source : `PhysicalModel` - A source from the population. - **kwargs : `dict`, optional - Any additional keyword arguments. - """ - if not isinstance(new_source, PhysicalModel): - raise ValueError("All sources must be PhysicalModels") - self.sources.append(new_source) - self.num_sources += 1 - - def draw_source(self): - """Sample a single source from the population. - - Returns - ------- - source : `PhysicalModel` - A source from the population. - """ - raise NotImplementedError() - - def add_effect(self, effect, allow_dups=False, **kwargs): - """Add a transformational effect to all PhysicalModels in this population. - Effects are applied in the order in which they are added. - - Parameters - ---------- - effect : `EffectModel` - The effect to apply. - allow_dups : `bool` - Allow multiple effects of the same type. - Default = ``True`` - **kwargs : `dict`, optional - Any additional keyword arguments. - - Raises - ------ - Raises a ``AttributeError`` if the PhysicalModel does not have all of the - required attributes. - """ - for source in self.sources: - source.add_effect(effect, allow_dups=allow_dups, **kwargs) - - def evaluate(self, samples, times, wavelengths, resample_parameters=False, **kwargs): - """Draw observations from a single (randomly sampled) source. - - Parameters - ---------- - samples : `int` - The number of sources to samples. - times : `numpy.ndarray` - A length T array of timestamps. - wavelengths : `numpy.ndarray`, optional - A length N array of wavelengths. - resample_parameters : `bool` - Treat this evaluation as a completely new object, resampling the - parameters from the original provided functions. - **kwargs : `dict`, optional - Any additional keyword arguments. - - Returns - ------- - results : `numpy.ndarray` - A shape (samples, T, N) matrix of SED values. - """ - if samples <= 0: - raise ValueError("The number of samples must be > 0.") - - results = [] - for _ in range(samples): - source = self.draw_source() - object_fluxes = source.evaluate(times, wavelengths, resample_parameters, **kwargs) - results.append(object_fluxes) - return np.array(results) diff --git a/src/tdastro/effects/effect_model.py b/src/tdastro/effects/effect_model.py new file mode 100644 index 00000000..d5b80bf8 --- /dev/null +++ b/src/tdastro/effects/effect_model.py @@ -0,0 +1,47 @@ +"""The base EffectModel class used for all effects.""" + +from tdastro.base_models import ParameterizedNode + + +class EffectModel(ParameterizedNode): + """A physical or systematic effect to apply to an observation.""" + + def __init__(self, **kwargs): + super().__init__(**kwargs) + + def __str__(self): + """Return the string representation of the model.""" + return "EffectModel" + + def required_parameters(self): + """Returns a list of the parameters of a PhysicalModel + that this effect needs to access. + + Returns + ------- + parameters : `list` of `str` + A list of every required parameter the effect needs. + """ + return [] + + def apply(self, flux_density, wavelengths=None, physical_model=None, **kwargs): + """Apply the effect to observations (flux_density values) + + Parameters + ---------- + flux_density : `numpy.ndarray` + A length T X N matrix of flux density values. + wavelengths : `numpy.ndarray`, optional + A length N array of wavelengths. + physical_model : `PhysicalModel` + A PhysicalModel from which the effect may query parameters + such as redshift, position, or distance. + **kwargs : `dict`, optional + Any additional keyword arguments. + + Returns + ------- + flux_density : `numpy.ndarray` + A length T x N matrix of flux densities after the effect is applied. + """ + raise NotImplementedError() diff --git a/src/tdastro/effects/white_noise.py b/src/tdastro/effects/white_noise.py index 49f2e621..b7d3a0a7 100644 --- a/src/tdastro/effects/white_noise.py +++ b/src/tdastro/effects/white_noise.py @@ -1,6 +1,6 @@ import numpy as np -from tdastro.base_models import EffectModel +from tdastro.effects.effect_model import EffectModel class WhiteNoise(EffectModel): diff --git a/src/tdastro/populations/fixed_population.py b/src/tdastro/populations/fixed_population.py index bc644b4a..18f3f0e9 100644 --- a/src/tdastro/populations/fixed_population.py +++ b/src/tdastro/populations/fixed_population.py @@ -1,6 +1,6 @@ import random -from tdastro.base_models import PopulationModel +from tdastro.populations.population_model import PopulationModel class FixedPopulation(PopulationModel): diff --git a/src/tdastro/populations/population_model.py b/src/tdastro/populations/population_model.py new file mode 100644 index 00000000..c527038a --- /dev/null +++ b/src/tdastro/populations/population_model.py @@ -0,0 +1,106 @@ +"""The base population models.""" + +import numpy as np + +from tdastro.base_models import ParameterizedNode +from tdastro.sources.physical_model import PhysicalModel + + +class PopulationModel(ParameterizedNode): + """A model of a population of PhysicalModels. + + Attributes + ---------- + num_sources : `int` + The number of different sources in the population. + sources : `list` + A list of sources from which to draw. + """ + + def __init__(self, rng=None, **kwargs): + super().__init__(**kwargs) + self.num_sources = 0 + self.sources = [] + + def __str__(self): + """Return the string representation of the model.""" + return f"PopulationModel with {self.num_sources} sources." + + def add_source(self, new_source, **kwargs): + """Add a new source to the population. + + Parameters + ---------- + new_source : `PhysicalModel` + A source from the population. + **kwargs : `dict`, optional + Any additional keyword arguments. + """ + if not isinstance(new_source, PhysicalModel): + raise ValueError("All sources must be PhysicalModels") + self.sources.append(new_source) + self.num_sources += 1 + + def draw_source(self): + """Sample a single source from the population. + + Returns + ------- + source : `PhysicalModel` + A source from the population. + """ + raise NotImplementedError() + + def add_effect(self, effect, allow_dups=False, **kwargs): + """Add a transformational effect to all PhysicalModels in this population. + Effects are applied in the order in which they are added. + + Parameters + ---------- + effect : `EffectModel` + The effect to apply. + allow_dups : `bool` + Allow multiple effects of the same type. + Default = ``True`` + **kwargs : `dict`, optional + Any additional keyword arguments. + + Raises + ------ + Raises a ``AttributeError`` if the PhysicalModel does not have all of the + required attributes. + """ + for source in self.sources: + source.add_effect(effect, allow_dups=allow_dups, **kwargs) + + def evaluate(self, samples, times, wavelengths, resample_parameters=False, **kwargs): + """Draw observations from a single (randomly sampled) source. + + Parameters + ---------- + samples : `int` + The number of sources to samples. + times : `numpy.ndarray` + A length T array of timestamps. + wavelengths : `numpy.ndarray`, optional + A length N array of wavelengths. + resample_parameters : `bool` + Treat this evaluation as a completely new object, resampling the + parameters from the original provided functions. + **kwargs : `dict`, optional + Any additional keyword arguments. + + Returns + ------- + results : `numpy.ndarray` + A shape (samples, T, N) matrix of SED values. + """ + if samples <= 0: + raise ValueError("The number of samples must be > 0.") + + results = [] + for _ in range(samples): + source = self.draw_source() + object_fluxes = source.evaluate(times, wavelengths, resample_parameters, **kwargs) + results.append(object_fluxes) + return np.array(results) diff --git a/src/tdastro/sources/galaxy_models.py b/src/tdastro/sources/galaxy_models.py index 9aea8bfc..eadb853a 100644 --- a/src/tdastro/sources/galaxy_models.py +++ b/src/tdastro/sources/galaxy_models.py @@ -1,7 +1,7 @@ import numpy as np from astropy.coordinates import angular_separation -from tdastro.base_models import PhysicalModel +from tdastro.sources.physical_model import PhysicalModel class GaussianGalaxy(PhysicalModel): diff --git a/src/tdastro/sources/periodic_source.py b/src/tdastro/sources/periodic_source.py index b68c02da..9b3cd81f 100644 --- a/src/tdastro/sources/periodic_source.py +++ b/src/tdastro/sources/periodic_source.py @@ -1,6 +1,6 @@ from abc import ABC, abstractmethod -from tdastro.base_models import PhysicalModel +from tdastro.sources.physical_model import PhysicalModel class PeriodicSource(PhysicalModel, ABC): diff --git a/src/tdastro/sources/physical_model.py b/src/tdastro/sources/physical_model.py new file mode 100644 index 00000000..8cf4946e --- /dev/null +++ b/src/tdastro/sources/physical_model.py @@ -0,0 +1,167 @@ +"""The base PhysicalModel used for all sources.""" + +from tdastro.astro_utils.cosmology import RedshiftDistFunc +from tdastro.base_models import ParameterizedNode + + +class PhysicalModel(ParameterizedNode): + """A physical model of a source of flux. + + Physical models can have fixed attributes (where you need to create a new model + to change them) and settable attributes that can be passed functions or constants. + They can also have special background pointers that link to another PhysicalModel + producing flux. We can chain these to have a supernova in front of a star in front + of a static background. + + Attributes + ---------- + ra : `float` + The object's right ascension (in degrees) + dec : `float` + The object's declination (in degrees) + redshift : `float` + The object's redshift. + distance : `float` + The object's distance (in pc). If the distance is not provided and + a ``cosmology`` parameter is given, the model will try to derive from + the redshift and the cosmology. + background : `PhysicalModel` + A source of background flux such as a host galaxy. + effects : `list` + A list of effects to apply to an observations. + """ + + def __init__(self, ra=None, dec=None, redshift=None, distance=None, background=None, **kwargs): + super().__init__(**kwargs) + self.effects = [] + + # Set RA, dec, and redshift from the parameters. + self.add_parameter("ra", ra) + self.add_parameter("dec", dec) + self.add_parameter("redshift", redshift) + + # If the distance is provided, use that. Otherwise try the redshift value + # using the cosmology (if given). Finally, default to None. + if distance is not None: + self.add_parameter("distance", distance) + elif redshift is not None and kwargs.get("cosmology", None) is not None: + self.add_parameter("distance", RedshiftDistFunc(redshift=self.get_redshift, **kwargs)) + else: + self.add_parameter("distance", None) + + # Background is an object not a sampled parameter + self.background = background + + def __str__(self): + """Return the string representation of the model.""" + return "PhysicalModel" + + def get_redshift(self): + """Return the redshift for the model.""" + return self.redshift + + def add_effect(self, effect, allow_dups=True, **kwargs): + """Add a transformational effect to the PhysicalModel. + Effects are applied in the order in which they are added. + + Parameters + ---------- + effect : `EffectModel` + The effect to apply. + allow_dups : `bool` + Allow multiple effects of the same type. + Default = ``True`` + **kwargs : `dict`, optional + Any additional keyword arguments. + + Raises + ------ + Raises a ``AttributeError`` if the PhysicalModel does not have all of the + required attributes. + """ + # Check that we have not added this effect before. + if not allow_dups: + effect_type = type(effect) + for prev in self.effects: + if effect_type == type(prev): + raise ValueError("Added the effect type to a model {effect_type} more than once.") + + required: list = effect.required_parameters() + for parameter in required: + # Raise an AttributeError if the parameter is missing or set to None. + if getattr(self, parameter) is None: + raise AttributeError(f"Parameter {parameter} unset for model {type(self).__name__}") + + self.effects.append(effect) + + def _evaluate(self, times, wavelengths, **kwargs): + """Draw effect-free observations for this object. + + Parameters + ---------- + times : `numpy.ndarray` + A length T array of timestamps. + wavelengths : `numpy.ndarray`, optional + A length N array of wavelengths. + **kwargs : `dict`, optional + Any additional keyword arguments. + + Returns + ------- + flux_density : `numpy.ndarray` + A length T x N matrix of SED values. + """ + raise NotImplementedError() + + def evaluate(self, times, wavelengths, resample_parameters=False, **kwargs): + """Draw observations for this object and apply the noise. + + Parameters + ---------- + times : `numpy.ndarray` + A length T array of timestamps. + wavelengths : `numpy.ndarray`, optional + A length N array of wavelengths. + resample_parameters : `bool` + Treat this evaluation as a completely new object, resampling the + parameters from the original provided functions. + **kwargs : `dict`, optional + Any additional keyword arguments. + + Returns + ------- + flux_density : `numpy.ndarray` + A length T x N matrix of SED values. + """ + if resample_parameters: + self.sample_parameters(kwargs) + + # Compute the flux density for both the current object and add in anything + # behind it, such as a host galaxy. + flux_density = self._evaluate(times, wavelengths, **kwargs) + if self.background is not None: + flux_density += self.background._evaluate(times, wavelengths, ra=self.ra, dec=self.dec, **kwargs) + + for effect in self.effects: + flux_density = effect.apply(flux_density, wavelengths, self, **kwargs) + return flux_density + + def sample_parameters(self, include_effects=True, **kwargs): + """Sample the model's underlying parameters if they are provided by a function + or ParameterizedModel. + + Parameters + ---------- + include_effects : `bool` + Resample the parameters for the effects models. + **kwargs : `dict`, optional + All the keyword arguments, including the values needed to sample + parameters. + """ + if self.background is not None: + self.background.sample_parameters(include_effects, **kwargs) + super().sample_parameters(**kwargs) + + if include_effects: + for effect in self.effects: + effect.sample_parameters(**kwargs) diff --git a/src/tdastro/sources/sncomso_models.py b/src/tdastro/sources/sncomso_models.py index 379de41e..b176233e 100644 --- a/src/tdastro/sources/sncomso_models.py +++ b/src/tdastro/sources/sncomso_models.py @@ -6,7 +6,7 @@ from sncosmo.models import get_source -from tdastro.base_models import PhysicalModel +from tdastro.sources.physical_model import PhysicalModel class SncosmoWrapperModel(PhysicalModel): diff --git a/src/tdastro/sources/spline_model.py b/src/tdastro/sources/spline_model.py index eff26dff..67e79202 100644 --- a/src/tdastro/sources/spline_model.py +++ b/src/tdastro/sources/spline_model.py @@ -7,7 +7,7 @@ from scipy.interpolate import RectBivariateSpline -from tdastro.base_models import PhysicalModel +from tdastro.sources.physical_model import PhysicalModel class SplineModel(PhysicalModel): diff --git a/src/tdastro/sources/static_source.py b/src/tdastro/sources/static_source.py index 44799c2e..f69eea89 100644 --- a/src/tdastro/sources/static_source.py +++ b/src/tdastro/sources/static_source.py @@ -1,6 +1,6 @@ import numpy as np -from tdastro.base_models import PhysicalModel +from tdastro.sources.physical_model import PhysicalModel class StaticSource(PhysicalModel): diff --git a/tests/tdastro/sources/test_physical_models.py b/tests/tdastro/sources/test_physical_models.py new file mode 100644 index 00000000..cf18bf1f --- /dev/null +++ b/tests/tdastro/sources/test_physical_models.py @@ -0,0 +1,31 @@ +import pytest +from astropy.cosmology import WMAP9 +from tdastro.sources.physical_model import PhysicalModel + + +def test_physical_model(): + """Test that we can create a physical model.""" + # Everything is specified. + model1 = PhysicalModel(ra=1.0, dec=2.0, distance=3.0, redshift=0.0) + assert model1.ra == 1.0 + assert model1.dec == 2.0 + assert model1.distance == 3.0 + assert model1.redshift == 0.0 + + # Derive the distance from the redshift using the example from: + # https://docs.astropy.org/en/stable/api/astropy.cosmology.units.redshift_distance.html + model2 = PhysicalModel(ra=1.0, dec=2.0, redshift=1100.0, cosmology=WMAP9) + assert model2.ra == 1.0 + assert model2.dec == 2.0 + assert model2.redshift == 1100.0 + assert model2.distance == pytest.approx(14004.03157418 * 1e6) + + # Neither distance nor redshift are specified. + model3 = PhysicalModel(ra=1.0, dec=2.0) + assert model3.redshift is None + assert model3.distance is None + + # Redshift is specified but cosmology is not. + model4 = PhysicalModel(ra=1.0, dec=2.0, redshift=1100.0) + assert model4.redshift == 1100.0 + assert model4.distance is None diff --git a/tests/tdastro/test_base_models.py b/tests/tdastro/test_base_models.py index 6847cffa..591b88b6 100644 --- a/tests/tdastro/test_base_models.py +++ b/tests/tdastro/test_base_models.py @@ -1,8 +1,7 @@ import random import pytest -from astropy.cosmology import WMAP9 -from tdastro.base_models import ParameterizedModel, PhysicalModel +from tdastro.base_models import ParameterizedNode def _sampler_fun(**kwargs): @@ -16,7 +15,7 @@ def _sampler_fun(**kwargs): return random.random() -class PairModel(ParameterizedModel): +class PairModel(ParameterizedNode): """A test class for the parameterized model. Attributes @@ -34,9 +33,9 @@ def __init__(self, value1, value2, **kwargs): Parameters ---------- - value1 : `float`, `function`, `ParameterizedModel`, or `None` + value1 : `float`, `function`, `ParameterizedNode`, or `None` The first value. - value2 : `float`, `function`, `ParameterizedModel`, or `None` + value2 : `float`, `function`, `ParameterizedNode`, or `None` The second value. **kwargs : `dict`, optional Any additional keyword arguments. @@ -142,31 +141,3 @@ def test_parameterized_model_modify() -> None: # We cannot set a value that hasn't been added. with pytest.raises(KeyError): model.set_parameter("brightness", 5.0) - - -def test_physical_model(): - """Test that we can create a physical model.""" - # Everything is specified. - model1 = PhysicalModel(ra=1.0, dec=2.0, distance=3.0, redshift=0.0) - assert model1.ra == 1.0 - assert model1.dec == 2.0 - assert model1.distance == 3.0 - assert model1.redshift == 0.0 - - # Derive the distance from the redshift using the example from: - # https://docs.astropy.org/en/stable/api/astropy.cosmology.units.redshift_distance.html - model2 = PhysicalModel(ra=1.0, dec=2.0, redshift=1100.0, cosmology=WMAP9) - assert model2.ra == 1.0 - assert model2.dec == 2.0 - assert model2.redshift == 1100.0 - assert model2.distance == pytest.approx(14004.03157418 * 1e6) - - # Neither distance nor redshift are specified. - model3 = PhysicalModel(ra=1.0, dec=2.0) - assert model3.redshift is None - assert model3.distance is None - - # Redshift is specified but cosmology is not. - model4 = PhysicalModel(ra=1.0, dec=2.0, redshift=1100.0) - assert model4.redshift == 1100.0 - assert model4.distance is None From 8b263f0070a7c1c88ca3c018313b6b6e3ee302e8 Mon Sep 17 00:00:00 2001 From: Jeremy Kubica <104161096+jeremykubica@users.noreply.github.com> Date: Sun, 14 Jul 2024 19:54:04 -0400 Subject: [PATCH 08/12] Put everything in a single node interface --- src/tdastro/astro_utils/cosmology.py | 9 +- src/tdastro/base_models.py | 157 +++++++++++++----- src/tdastro/function_wrappers.py | 80 --------- src/tdastro/sources/physical_model.py | 12 +- .../populations/test_fixed_population.py | 5 +- tests/tdastro/sources/test_galaxy_models.py | 4 + tests/tdastro/sources/test_step_source.py | 6 +- tests/tdastro/test_base_models.py | 77 ++++++++- tests/tdastro/test_function_wrappers.py | 109 ------------ 9 files changed, 211 insertions(+), 248 deletions(-) delete mode 100644 src/tdastro/function_wrappers.py delete mode 100644 tests/tdastro/test_function_wrappers.py diff --git a/src/tdastro/astro_utils/cosmology.py b/src/tdastro/astro_utils/cosmology.py index ecbd16ab..1beecf70 100644 --- a/src/tdastro/astro_utils/cosmology.py +++ b/src/tdastro/astro_utils/cosmology.py @@ -1,7 +1,7 @@ import astropy.cosmology.units as cu from astropy import units as u -from tdastro.function_wrappers import TDFunc +from tdastro.base_models import FunctionNode def redshift_to_distance(redshift, cosmology, kind="comoving"): @@ -28,7 +28,7 @@ def redshift_to_distance(redshift, cosmology, kind="comoving"): return distance.value -class RedshiftDistFunc(TDFunc): +class RedshiftDistFunc(FunctionNode): """A wrapper class for the redshift_to_distance() function. Attributes @@ -51,12 +51,7 @@ class RedshiftDistFunc(TDFunc): """ def __init__(self, redshift, cosmology, kind="comoving"): - self.cosmology = cosmology - self.kind = kind - # Call the super class's constructor with the needed information. - # Do not pass kwargs because we are limiting the arguments to match - # the signature of redshift_to_distance(). super().__init__( func=redshift_to_distance, redshift=redshift, diff --git a/src/tdastro/base_models.py b/src/tdastro/base_models.py index d94ece2d..262913da 100644 --- a/src/tdastro/base_models.py +++ b/src/tdastro/base_models.py @@ -1,28 +1,23 @@ -"""The base models used throughout TDAstro including physical objects, effects, and populations.""" +"""The base models used to specify the TDAstro computation graph.""" import types from enum import Enum -from tdastro.function_wrappers import TDFunc - class ParameterSource(Enum): - """ParameterSource specifies where a PhysicalModel should get the value - for a given parameter: a constant value, a function, or from another - parameterized model. + """ParameterSource specifies where a ParameterizedNode should get the value + for a given parameter: a constant value or from another ParameterizedNode. """ CONSTANT = 1 - FUNCTION = 2 - TDFUNC_OBJ = 3 - MODEL_ATTRIBUTE = 4 - MODEL_METHOD = 5 + MODEL_ATTRIBUTE = 2 + MODEL_METHOD = 3 + FUNCTION = 4 class ParameterizedNode: """Any model that uses parameters that can be set by constants, - functions, or other parameterized models. ParameterizedNode can - include physical objects or statistical distributions. + functions, or other parameterized nodes. Attributes ---------- @@ -43,6 +38,36 @@ def __str__(self): """Return the string representation of the model.""" return "ParameterizedNode" + def check_resample(self, child): + """Check if we need to resample the current node based + on the state of a child trying to access its attributes. + + Parameters + ---------- + child : `ParameterizedNode` + The node that is accessing the attribute or method + of the current node. + + Returns + ------- + bool + Indicates whether to resample or not. + + Raises + ------ + ``ValueError`` if the graph has gotten out of sync. + """ + if child == self: + return False + if child.sample_iteration == self.sample_iteration: + return False + if child.sample_iteration != self.sample_iteration + 1: + raise ValueError( + f"Node {str(child)} at iteration {child.sample_iteration} accessing" + f" parent {str(self)} at iteration {self.sample_iteration}." + ) + return True + def set_parameter(self, name, value=None, **kwargs): """Set a single *existing* parameter to the ParameterizedNode. @@ -76,28 +101,26 @@ def set_parameter(self, name, value=None, **kwargs): # The value wasn't set, but the name is in kwargs. value = kwargs[name] - if isinstance(value, types.FunctionType): - # Case 1a: If we are getting from a static function, sample it. - self.setters[name] = (ParameterSource.FUNCTION, value, required) - setattr(self, name, value(**kwargs)) - elif isinstance(value, TDFunc): - # Case 1b: We are using a TDFunc wrapped function (with default parameters). - self.setters[name] = (ParameterSource.TDFUNC_OBJ, value, required) - setattr(self, name, value(**kwargs)) - elif isinstance(value, types.MethodType) and isinstance(value.__self__, ParameterizedNode): - # Case 2: We are trying to use the method from a ParameterizedNode. - # Note that this will (correctly) fail if we are adding a model method from the current - # object that requires an unset attribute. - self.setters[name] = (ParameterSource.MODEL_METHOD, value, required) + if callable(value): + if isinstance(value, types.FunctionType): + # Case 1a: This is a static function (not attached to an object). + self.setters[name] = (ParameterSource.FUNCTION, value, required) + elif isinstance(value.__self__, ParameterizedNode): + # Case 1b: This is a method attached to another ParameterizedNode. + self.setters[name] = (ParameterSource.MODEL_METHOD, value, required) + else: + # Case 1c: This is a general callable method from another object. + # We treat it as static (we don't resample the other object). + self.setters[name] = (ParameterSource.FUNCTION, value, required) setattr(self, name, value(**kwargs)) elif isinstance(value, ParameterizedNode): - # Case 3: We are trying to access an attribute from a parameterized model. + # Case 2: We are trying to access a parameter of another ParameterizedNode. if not hasattr(value, name): raise ValueError(f"Attribute {name} missing from parent.") self.setters[name] = (ParameterSource.MODEL_ATTRIBUTE, value, required) setattr(self, name, getattr(value, name)) else: - # Case 4: The value is constant (including None). + # Case 3: The value is constant (including None). self.setters[name] = (ParameterSource.CONSTANT, value, required) setattr(self, name, value) @@ -162,6 +185,9 @@ def sample_parameters(self, max_depth=50, **kwargs): if max_depth == 0: raise ValueError(f"Maximum sampling depth exceeded at {self}. Potential infinite loop.") + # Increase the sampling iteration. + self.sample_iteration += 1 + # Run through each parameter and sample it based on the given recipe. # As of Python 3.7 dictionaries are guaranteed to preserve insertion ordering, # so this will iterate through attributes in the order they were inserted. @@ -171,24 +197,77 @@ def sample_parameters(self, max_depth=50, **kwargs): sampled_value = setter elif source_type == ParameterSource.FUNCTION: sampled_value = setter(**kwargs) - elif source_type == ParameterSource.TDFUNC_OBJ: - sampled_value = setter(**kwargs) elif source_type == ParameterSource.MODEL_ATTRIBUTE: - # Check if we need to resample the parent (needs to be done before - # we read its attribute). - if setter.sample_iteration == self.sample_iteration: + # Check if we need to resample the parent (before accessing the attribute). + if setter.check_resample(self): setter.sample_parameters(max_depth - 1, **kwargs) sampled_value = getattr(setter, name) elif source_type == ParameterSource.MODEL_METHOD: - # Check if we need to resample the parent (needs to be done before - # we evaluate its method). Do not resample the current object. - parent = setter.__self__ - if parent is not self and parent.sample_iteration == self.sample_iteration: - parent.sample_parameters(max_depth - 1, **kwargs) + # Check if we need to resample the parent (before calling the method). + parent_node = setter.__self__ + if parent_node.check_resample(self): + parent_node.sample_parameters(max_depth - 1, **kwargs) sampled_value = setter(**kwargs) else: raise ValueError(f"Unknown ParameterSource type {source_type} for {name}") setattr(self, name, sampled_value) - # Increase the sampling iteration. - self.sample_iteration += 1 + +class FunctionNode(ParameterizedNode): + """A class to wrap functions and their argument settings. + + Attributes + ---------- + func : `function` or `method` + The function to call during an evaluation. + args_names : `list` + A list of argument names to pass to the function. + + Examples + -------- + my_func = TDFunc(random.randint, a=1, b=10) + value1 = my_func() # Sample from default range + value2 = my_func(b=20) # Sample from extended range + + Note + ---- + All the function's parameters that will be used need to be specified + in either the default_args dict, object_args list, or as a kwarg in the + constructor. Arguments cannot be first given during function call. + For example the following will fail (because b is not defined in the + constructor): + + my_func = TDFunc(random.randint, a=1) + value1 = my_func(b=10.0) + """ + + def __init__(self, func, **kwargs): + super().__init__(**kwargs) + self.func = func + self.arg_names = [] + + # Add all of the parameters from default_args or the kwargs. + for key, value in kwargs.items(): + self.arg_names.append(key) + self.add_parameter(key, value) + + def __str__(self): + """Return the string representation of the function.""" + return f"FunctionNode({self.func.name})" + + def compute(self, **kwargs): + """Execute the wrapped function. + + Parameters + ---------- + **kwargs : `dict`, optional + Additional function arguments. + """ + args = {} + for key in self.arg_names: + # Override with the kwarg if the parameter is there. + if key in kwargs: + args[key] = kwargs[key] + else: + args[key] = getattr(self, key) + return self.func(**args) diff --git a/src/tdastro/function_wrappers.py b/src/tdastro/function_wrappers.py deleted file mode 100644 index 5a7aad6e..00000000 --- a/src/tdastro/function_wrappers.py +++ /dev/null @@ -1,80 +0,0 @@ -"""Utilities to wrap functions for inclusion in an evaluation graph.""" - -import copy - - -class TDFunc: - """A class to wrap functions and their argument settings. - - Attributes - ---------- - func : `function` or `method` - The function to call during an evaluation. - default_args : `dict` - A dictionary of default arguments to pass to the function. Assembled - from the ```default_args`` parameter and additional ``kwargs``. - setter_functions : `dict` - A dictionary mapping arguments names to functions, methods, or - TDFunc objects used to set that argument. These are evaluated dynamically - each time. - - Examples - -------- - my_func = TDFunc(random.randint, a=1, b=10) - value1 = my_func() # Sample from default range - value2 = my_func(b=20) # Sample from extended range - - Note - ---- - All the function's parameters that will be used need to be specified - in either the default_args dict, object_args list, or as a kwarg in the - constructor. Arguments cannot be first given during function call. - For example the following will fail (because b is not defined in the - constructor): - - my_func = TDFunc(random.randint, a=1) - value1 = my_func(b=10.0) - """ - - def __init__(self, func, **kwargs): - self.func = func - self.default_args = {} - self.setter_functions = {} - - for key, value in kwargs.items(): - if callable(value): - self.default_args[key] = None - self.setter_functions[key] = value - else: - self.default_args[key] = value - - def __str__(self): - """Return the string representation of the function.""" - return f"TDFunc({self.func.name})" - - def __call__(self, **kwargs): - """Execute the wrapped function. - - Parameters - ---------- - **kwargs : `dict`, optional - Additional function arguments. - """ - # Start with the default arguments. We make a copy so we can modify the dictionary. - args = copy.copy(self.default_args) - - # If there are arguments to get from the calling functions, set those. - for key, value in self.setter_functions.items(): - if isinstance(value, TDFunc): - args[key] = value(**kwargs) - else: - args[key] = value() - - # Set any last arguments from the kwargs (overwriting previous settings). - # Only use known kwargs. - for key, value in kwargs.items(): - if key in self.default_args: - args[key] = value - - # Call the function with all the parameters. - return self.func(**args) diff --git a/src/tdastro/sources/physical_model.py b/src/tdastro/sources/physical_model.py index 8cf4946e..c463bb70 100644 --- a/src/tdastro/sources/physical_model.py +++ b/src/tdastro/sources/physical_model.py @@ -45,7 +45,8 @@ def __init__(self, ra=None, dec=None, redshift=None, distance=None, background=N if distance is not None: self.add_parameter("distance", distance) elif redshift is not None and kwargs.get("cosmology", None) is not None: - self.add_parameter("distance", RedshiftDistFunc(redshift=self.get_redshift, **kwargs)) + self._redshift_func = RedshiftDistFunc(redshift=self, **kwargs) + self.add_parameter("distance", self._redshift_func.compute) else: self.add_parameter("distance", None) @@ -56,10 +57,6 @@ def __str__(self): """Return the string representation of the model.""" return "PhysicalModel" - def get_redshift(self): - """Return the redshift for the model.""" - return self.redshift - def add_effect(self, effect, allow_dups=True, **kwargs): """Add a transformational effect to the PhysicalModel. Effects are applied in the order in which they are added. @@ -158,10 +155,11 @@ def sample_parameters(self, include_effects=True, **kwargs): All the keyword arguments, including the values needed to sample parameters. """ - if self.background is not None: + if self.background is not None and self.background.check_resample(self): self.background.sample_parameters(include_effects, **kwargs) super().sample_parameters(**kwargs) if include_effects: for effect in self.effects: - effect.sample_parameters(**kwargs) + if effect.check_resample(self): + effect.sample_parameters(**kwargs) diff --git a/tests/tdastro/populations/test_fixed_population.py b/tests/tdastro/populations/test_fixed_population.py index 95b0e0ae..ff56b20a 100644 --- a/tests/tdastro/populations/test_fixed_population.py +++ b/tests/tdastro/populations/test_fixed_population.py @@ -2,8 +2,8 @@ import numpy as np import pytest +from tdastro.base_models import FunctionNode from tdastro.effects.white_noise import WhiteNoise -from tdastro.function_wrappers import TDFunc from tdastro.populations.fixed_population import FixedPopulation from tdastro.sources.static_source import StaticSource @@ -88,9 +88,10 @@ def test_fixed_population_sample_sources(): def test_fixed_population_sample_fluxes(): """Test that we can create a population of sources and sample its sources' flux.""" random.seed(1001) + brightness_func = FunctionNode(random.uniform, a=0.0, b=100.0) population = FixedPopulation() population.add_source(StaticSource(brightness=100.0), 10.0) - population.add_source(StaticSource(brightness=TDFunc(random.uniform, a=0.0, b=100.0)), 10.0) + population.add_source(StaticSource(brightness=brightness_func.compute), 10.0) population.add_source(StaticSource(brightness=200.0), 20.0) population.add_source(StaticSource(brightness=150.0), 10.0) diff --git a/tests/tdastro/sources/test_galaxy_models.py b/tests/tdastro/sources/test_galaxy_models.py index e127494d..312777f8 100644 --- a/tests/tdastro/sources/test_galaxy_models.py +++ b/tests/tdastro/sources/test_galaxy_models.py @@ -55,7 +55,11 @@ def test_gaussian_galaxy() -> None: # Check that if we resample the source it will propagate and correctly resample the host. # the host's (RA, dec) should change and the source's should still be close. + print(f"Host sample = {host.sample_iteration}") + print(f"Source sample = {source.sample_iteration}") source.sample_parameters() + print(f"Host sample = {host.sample_iteration}") + print(f"Source sample = {source.sample_iteration}") assert host_ra != host.ra assert host_dec != host.dec diff --git a/tests/tdastro/sources/test_step_source.py b/tests/tdastro/sources/test_step_source.py index 325e73f2..899a716f 100644 --- a/tests/tdastro/sources/test_step_source.py +++ b/tests/tdastro/sources/test_step_source.py @@ -1,7 +1,7 @@ import random import numpy as np -from tdastro.function_wrappers import TDFunc +from tdastro.base_models import FunctionNode from tdastro.sources.static_source import StaticSource from tdastro.sources.step_source import StepSource @@ -58,9 +58,9 @@ def test_step_source_resample() -> None: random.seed(1111) model = StepSource( - brightness=TDFunc(_sample_brightness, magnitude=100.0), + brightness=FunctionNode(_sample_brightness, magnitude=100.0).compute, t0=0.0, - t1=TDFunc(_sample_end, duration=5.0), + t1=FunctionNode(_sample_end, duration=5.0).compute, ) num_samples = 1000 diff --git a/tests/tdastro/test_base_models.py b/tests/tdastro/test_base_models.py index 591b88b6..b0fb51e5 100644 --- a/tests/tdastro/test_base_models.py +++ b/tests/tdastro/test_base_models.py @@ -1,7 +1,8 @@ import random +import numpy as np import pytest -from tdastro.base_models import ParameterizedNode +from tdastro.base_models import FunctionNode, ParameterizedNode def _sampler_fun(**kwargs): @@ -15,6 +16,19 @@ def _sampler_fun(**kwargs): return random.random() +def _test_func(value1, value2): + """Return the sum of the two parameters. + + Parameters + ---------- + value1 : `float` + The first parameter. + value2 : `float` + The second parameter. + """ + return value1 + value2 + + class PairModel(ParameterizedNode): """A test class for the parameterized model. @@ -45,6 +59,10 @@ def __init__(self, value1, value2, **kwargs): self.add_parameter("value2", value2, required=True, **kwargs) self.add_parameter("value_sum", self.result, required=True, **kwargs) + def get_value1(self): + """Get the value of value1.""" + return self.value1 + def result(self, **kwargs): """Add the pair of values together @@ -141,3 +159,60 @@ def test_parameterized_model_modify() -> None: # We cannot set a value that hasn't been added. with pytest.raises(KeyError): model.set_parameter("brightness", 5.0) + + +def test_function_node_basic(): + """Test that we can create and query a FunctionNode.""" + my_func = FunctionNode(_test_func, value1=1.0, value2=2.0) + assert my_func.compute() == 3.0 + assert my_func.compute(value2=3.0) == 4.0 + assert my_func.compute(value2=3.0, unused_param=5.0) == 4.0 + assert my_func.compute(value2=3.0, value1=1.0) == 4.0 + + +def test_function_node_chain(): + """Test that we can create and query a chained FunctionNode.""" + func1 = FunctionNode(_test_func, value1=1.0, value2=1.0) + func2 = FunctionNode(_test_func, value1=func1.compute, value2=3.0) + assert func2.compute() == 5.0 + + +def test_np_sampler_method(): + """Test that we can wrap numpy random functions.""" + rng = np.random.default_rng(1001) + my_func = FunctionNode(rng.normal, loc=10.0, scale=1.0) + + # Sample 1000 times with the default values. Check that we are near + # the expected mean and not everything is equal. + vals = np.array([my_func.compute() for _ in range(1000)]) + assert abs(np.mean(vals) - 10.0) < 1.0 + assert not np.all(vals == vals[0]) + + # Override the mean and resample. + vals = np.array([my_func.compute(loc=25.0) for _ in range(1000)]) + assert abs(np.mean(vals) - 25.0) < 1.0 + assert not np.all(vals == vals[0]) + + +def test_function_node_obj(): + """Test that we can create and query a FunctionNode that depends on + another ParameterizedNode. + """ + # The model depends on the function. + func = FunctionNode(_test_func, value1=5.0, value2=6.0) + model = PairModel(value1=10.0, value2=func.compute) + assert model.result() == 21.0 + + # Function depends on the model's value2 attribute. + model = PairModel(value1=-10.0, value2=17.5) + func = FunctionNode(_test_func, value1=5.0, value2=model) + assert model.result() == 7.5 + assert func.compute() == 22.5 + + # Function depends on the model's get_value1() method. + func = FunctionNode(_test_func, value1=model.get_value1, value2=5.0) + assert model.result() == 7.5 + assert func.compute() == -5.0 + + # We can always override the attributes with kwargs. + assert func.compute(value1=1.0, value2=4.0) == 5.0 diff --git a/tests/tdastro/test_function_wrappers.py b/tests/tdastro/test_function_wrappers.py deleted file mode 100644 index e7048391..00000000 --- a/tests/tdastro/test_function_wrappers.py +++ /dev/null @@ -1,109 +0,0 @@ -import numpy as np -import pytest -from tdastro.function_wrappers import TDFunc - - -def _test_func(a, b): - """Return the sum of the two parameters. - - Parameters - ---------- - a : `float` - The first parameter. - b : `float` - The second parameter. - """ - return a + b - - -class _StaticModel: - """A test model that has given parameters. - - Attributes - ---------- - a : `float` - The first parameter. - b : `float` - The second parameter. - """ - - def __init__(self, a, b): - self.a = a - self.b = b - - def get_a(self): - """Get the a attribute.""" - return self.a - - def get_b(self): - """Get the b attribute.""" - return self.b - - -def test_tdfunc_basic(): - """Test that we can create and query a TDFunc.""" - # Fail without enough arguments (only a is specified). - tdf1 = TDFunc(_test_func, a=1.0) - with pytest.raises(TypeError): - _ = tdf1() - - # We succeed with a manually specified parameter (but first fail with - # the default). - tdf2 = TDFunc(_test_func, a=1.0, b=None) - with pytest.raises(TypeError): - _ = tdf2() - assert tdf2(b=2.0) == 3.0 - - # We can overwrite parameters. - assert tdf2(a=3.0, b=2.0) == 5.0 - - # Test that we ignore extra kwargs. - assert tdf2(b=2.0, c=10.0, d=11.0) == 3.0 - - # That we can use a different ordering for parameters. - tdf3 = TDFunc(_test_func, b=2.0, a=1.0) - assert tdf3() == 3.0 - - -def test_tdfunc_chain(): - """Test that we can create and query a chained TDFunc.""" - tdf1 = TDFunc(_test_func, a=1.0, b=1.0) - tdf2 = TDFunc(_test_func, a=tdf1, b=3.0) - assert tdf2() == 5.0 - - # This will overwrite all the b parameters. - assert tdf2(b=10.0) == 21.0 - - -def test_np_sampler_method(): - """Test that we can wrap numpy random functions.""" - rng = np.random.default_rng(1001) - tdf = TDFunc(rng.normal, loc=10.0, scale=1.0) - - # Sample 1000 times with the default values. Check that we are near - # the expected mean and not everything is equal. - vals = np.array([tdf() for _ in range(1000)]) - assert abs(np.mean(vals) - 10.0) < 1.0 - assert not np.all(vals == vals[0]) - - # Override the mean and resample. - vals = np.array([tdf(loc=25.0) for _ in range(1000)]) - assert abs(np.mean(vals) - 25.0) < 1.0 - assert not np.all(vals == vals[0]) - - -def test_tdfunc_obj(): - """Test that we can create and query a TDFunc that depends on an object.""" - model = _StaticModel(a=10.0, b=11.0) - - # Check a function without defaults. - tdf1 = TDFunc(_test_func, a=model.get_a, b=model.get_b) - assert tdf1() == 21.0 - - # We can pull from multiple models. - model2 = _StaticModel(a=1.0, b=0.0) - tdf2 = TDFunc(_test_func, a=model.get_a, b=model2.get_b) - assert tdf2() == 10.0 - - # W can overwrite everything with kwargs. - assert tdf2(b=7.5) == 17.5 From de036aa2b1ce94e028ff284922bcf5527c3c9285 Mon Sep 17 00:00:00 2001 From: Jeremy Kubica <104161096+jeremykubica@users.noreply.github.com> Date: Mon, 15 Jul 2024 09:22:31 -0400 Subject: [PATCH 09/12] Fix merge --- tests/tdastro/sources/test_galaxy_models.py | 4 ---- tests/tdastro/sources/test_physical_models.py | 2 +- 2 files changed, 1 insertion(+), 5 deletions(-) diff --git a/tests/tdastro/sources/test_galaxy_models.py b/tests/tdastro/sources/test_galaxy_models.py index 312777f8..e127494d 100644 --- a/tests/tdastro/sources/test_galaxy_models.py +++ b/tests/tdastro/sources/test_galaxy_models.py @@ -55,11 +55,7 @@ def test_gaussian_galaxy() -> None: # Check that if we resample the source it will propagate and correctly resample the host. # the host's (RA, dec) should change and the source's should still be close. - print(f"Host sample = {host.sample_iteration}") - print(f"Source sample = {source.sample_iteration}") source.sample_parameters() - print(f"Host sample = {host.sample_iteration}") - print(f"Source sample = {source.sample_iteration}") assert host_ra != host.ra assert host_dec != host.dec diff --git a/tests/tdastro/sources/test_physical_models.py b/tests/tdastro/sources/test_physical_models.py index cf18bf1f..cd91e3e3 100644 --- a/tests/tdastro/sources/test_physical_models.py +++ b/tests/tdastro/sources/test_physical_models.py @@ -4,7 +4,7 @@ def test_physical_model(): - """Test that we can create a physical model.""" + """Test that we can create a PhysicalModel.""" # Everything is specified. model1 = PhysicalModel(ra=1.0, dec=2.0, distance=3.0, redshift=0.0) assert model1.ra == 1.0 From 2b2230a41f5ff508be3629a37774f2051f70dd46 Mon Sep 17 00:00:00 2001 From: Jeremy Kubica <104161096+jeremykubica@users.noreply.github.com> Date: Mon, 15 Jul 2024 09:32:06 -0400 Subject: [PATCH 10/12] Small fixes --- src/tdastro/base_models.py | 23 ++++++++++++----------- tests/tdastro/test_base_models.py | 2 +- 2 files changed, 13 insertions(+), 12 deletions(-) diff --git a/src/tdastro/base_models.py b/src/tdastro/base_models.py index 262913da..cc67c8bb 100644 --- a/src/tdastro/base_models.py +++ b/src/tdastro/base_models.py @@ -10,9 +10,9 @@ class ParameterSource(Enum): """ CONSTANT = 1 - MODEL_ATTRIBUTE = 2 - MODEL_METHOD = 3 - FUNCTION = 4 + FUNCTION = 2 + MODEL_ATTRIBUTE = 3 + MODEL_METHOD = 4 class ParameterizedNode: @@ -35,16 +35,17 @@ def __init__(self, **kwargs): self.sample_iteration = 0 def __str__(self): - """Return the string representation of the model.""" + """Return the string representation of the node.""" return "ParameterizedNode" - def check_resample(self, child): + def check_resample(self, other): """Check if we need to resample the current node based - on the state of a child trying to access its attributes. + on the state of another node trying to access its attributes + or methods. Parameters ---------- - child : `ParameterizedNode` + other : `ParameterizedNode` The node that is accessing the attribute or method of the current node. @@ -57,13 +58,13 @@ def check_resample(self, child): ------ ``ValueError`` if the graph has gotten out of sync. """ - if child == self: + if other == self: return False - if child.sample_iteration == self.sample_iteration: + if other.sample_iteration == self.sample_iteration: return False - if child.sample_iteration != self.sample_iteration + 1: + if other.sample_iteration != self.sample_iteration + 1: raise ValueError( - f"Node {str(child)} at iteration {child.sample_iteration} accessing" + f"Node {str(other)} at iteration {other.sample_iteration} accessing" f" parent {str(self)} at iteration {self.sample_iteration}." ) return True diff --git a/tests/tdastro/test_base_models.py b/tests/tdastro/test_base_models.py index 8c29647e..1fa953e5 100644 --- a/tests/tdastro/test_base_models.py +++ b/tests/tdastro/test_base_models.py @@ -30,7 +30,7 @@ def _test_func(value1, value2): class PairModel(ParameterizedNode): - """A test class for the parameterized model. + """A test class for the ParameterizedNode. Attributes ---------- From bc4848a54696bd4d243e4486b464106fb6f7d069 Mon Sep 17 00:00:00 2001 From: Jeremy Kubica <104161096+jeremykubica@users.noreply.github.com> Date: Mon, 15 Jul 2024 10:43:08 -0400 Subject: [PATCH 11/12] Address PR comments. --- src/tdastro/astro_utils/cosmology.py | 4 ++-- src/tdastro/sources/physical_model.py | 6 +++--- tests/tdastro/astro_utils/test_cosmology.py | 5 ++++- tests/tdastro/sources/test_physical_models.py | 9 ++++----- 4 files changed, 13 insertions(+), 11 deletions(-) diff --git a/src/tdastro/astro_utils/cosmology.py b/src/tdastro/astro_utils/cosmology.py index 1beecf70..b36022b1 100644 --- a/src/tdastro/astro_utils/cosmology.py +++ b/src/tdastro/astro_utils/cosmology.py @@ -5,7 +5,7 @@ def redshift_to_distance(redshift, cosmology, kind="comoving"): - """Compute a source's distance given its redshift and a + """Compute a source's luminosity distance given its redshift and a specified cosmology using astropy's redshift_distance(). Parameters @@ -21,7 +21,7 @@ def redshift_to_distance(redshift, cosmology, kind="comoving"): Returns ------- distance : `float` - The distance (in pc) + The luminosity distance (in pc) """ z = redshift * cu.redshift distance = z.to(u.pc, cu.redshift_distance(cosmology, kind=kind)) diff --git a/src/tdastro/sources/physical_model.py b/src/tdastro/sources/physical_model.py index c463bb70..fb4cb3de 100644 --- a/src/tdastro/sources/physical_model.py +++ b/src/tdastro/sources/physical_model.py @@ -22,7 +22,7 @@ class PhysicalModel(ParameterizedNode): redshift : `float` The object's redshift. distance : `float` - The object's distance (in pc). If the distance is not provided and + The object's luminosity distance (in pc). If no value is provided and a ``cosmology`` parameter is given, the model will try to derive from the redshift and the cosmology. background : `PhysicalModel` @@ -40,8 +40,8 @@ def __init__(self, ra=None, dec=None, redshift=None, distance=None, background=N self.add_parameter("dec", dec) self.add_parameter("redshift", redshift) - # If the distance is provided, use that. Otherwise try the redshift value - # using the cosmology (if given). Finally, default to None. + # If the luminosity distance is provided, use that. Otherwise try the + # redshift value using the cosmology (if given). Finally, default to None. if distance is not None: self.add_parameter("distance", distance) elif redshift is not None and kwargs.get("cosmology", None) is not None: diff --git a/tests/tdastro/astro_utils/test_cosmology.py b/tests/tdastro/astro_utils/test_cosmology.py index 14a500e2..fff88c47 100644 --- a/tests/tdastro/astro_utils/test_cosmology.py +++ b/tests/tdastro/astro_utils/test_cosmology.py @@ -1,5 +1,5 @@ import pytest -from astropy.cosmology import WMAP9 +from astropy.cosmology import WMAP9, Planck18 from tdastro.astro_utils.cosmology import redshift_to_distance @@ -8,3 +8,6 @@ def test_redshift_to_distance(): # Use the example from: # https://docs.astropy.org/en/stable/api/astropy.cosmology.units.redshift_distance.html assert redshift_to_distance(1100, cosmology=WMAP9) == pytest.approx(14004.03157418 * 1e6) + + # Try the Planck18 cosmology. + assert redshift_to_distance(1100, cosmology=Planck18) == pytest.approx(13886.327957 * 1e6) diff --git a/tests/tdastro/sources/test_physical_models.py b/tests/tdastro/sources/test_physical_models.py index cd91e3e3..d05560e3 100644 --- a/tests/tdastro/sources/test_physical_models.py +++ b/tests/tdastro/sources/test_physical_models.py @@ -1,5 +1,5 @@ import pytest -from astropy.cosmology import WMAP9 +from astropy.cosmology import Planck18 from tdastro.sources.physical_model import PhysicalModel @@ -12,13 +12,12 @@ def test_physical_model(): assert model1.distance == 3.0 assert model1.redshift == 0.0 - # Derive the distance from the redshift using the example from: - # https://docs.astropy.org/en/stable/api/astropy.cosmology.units.redshift_distance.html - model2 = PhysicalModel(ra=1.0, dec=2.0, redshift=1100.0, cosmology=WMAP9) + # Derive the distance from the redshift: + model2 = PhysicalModel(ra=1.0, dec=2.0, redshift=1100.0, cosmology=Planck18) assert model2.ra == 1.0 assert model2.dec == 2.0 assert model2.redshift == 1100.0 - assert model2.distance == pytest.approx(14004.03157418 * 1e6) + assert model2.distance == pytest.approx(13886.327957 * 1e6) # Neither distance nor redshift are specified. model3 = PhysicalModel(ra=1.0, dec=2.0) From bd09398225e9278d3a72a66d001da2b3a27dd293 Mon Sep 17 00:00:00 2001 From: Jeremy Kubica <104161096+jeremykubica@users.noreply.github.com> Date: Mon, 15 Jul 2024 10:52:05 -0400 Subject: [PATCH 12/12] Address more PR comments Fix the distance kind as "luminosity" --- src/tdastro/astro_utils/cosmology.py | 13 +++---------- tests/tdastro/astro_utils/test_cosmology.py | 11 +++++------ tests/tdastro/sources/test_physical_models.py | 3 +-- 3 files changed, 9 insertions(+), 18 deletions(-) diff --git a/src/tdastro/astro_utils/cosmology.py b/src/tdastro/astro_utils/cosmology.py index b36022b1..57771220 100644 --- a/src/tdastro/astro_utils/cosmology.py +++ b/src/tdastro/astro_utils/cosmology.py @@ -4,7 +4,7 @@ from tdastro.base_models import FunctionNode -def redshift_to_distance(redshift, cosmology, kind="comoving"): +def redshift_to_distance(redshift, cosmology): """Compute a source's luminosity distance given its redshift and a specified cosmology using astropy's redshift_distance(). @@ -14,9 +14,6 @@ def redshift_to_distance(redshift, cosmology, kind="comoving"): The redshift value. cosmology : `astropy.cosmology` The cosmology specification. - kind : `str` - The distance type for the Equivalency as defined by - astropy.cosmology.units.redshift_distance. Returns ------- @@ -24,7 +21,7 @@ def redshift_to_distance(redshift, cosmology, kind="comoving"): The luminosity distance (in pc) """ z = redshift * cu.redshift - distance = z.to(u.pc, cu.redshift_distance(cosmology, kind=kind)) + distance = z.to(u.pc, cu.redshift_distance(cosmology, kind="luminosity")) return distance.value @@ -45,18 +42,14 @@ class RedshiftDistFunc(FunctionNode): The function or constant providing the redshift value. cosmology : `astropy.cosmology` The cosmology specification. - kind : `str` - The distance type for the Equivalency as defined by - astropy.cosmology.units.redshift_distance. """ - def __init__(self, redshift, cosmology, kind="comoving"): + def __init__(self, redshift, cosmology): # Call the super class's constructor with the needed information. super().__init__( func=redshift_to_distance, redshift=redshift, cosmology=cosmology, - kind=kind, ) def __str__(self): diff --git a/tests/tdastro/astro_utils/test_cosmology.py b/tests/tdastro/astro_utils/test_cosmology.py index fff88c47..9ed9559e 100644 --- a/tests/tdastro/astro_utils/test_cosmology.py +++ b/tests/tdastro/astro_utils/test_cosmology.py @@ -1,13 +1,12 @@ -import pytest from astropy.cosmology import WMAP9, Planck18 from tdastro.astro_utils.cosmology import redshift_to_distance def test_redshift_to_distance(): """Test that we can convert the redshift to a distance using a given cosmology.""" - # Use the example from: - # https://docs.astropy.org/en/stable/api/astropy.cosmology.units.redshift_distance.html - assert redshift_to_distance(1100, cosmology=WMAP9) == pytest.approx(14004.03157418 * 1e6) + wmap9_val = redshift_to_distance(1100, cosmology=WMAP9) + planck18_val = redshift_to_distance(1100, cosmology=Planck18) - # Try the Planck18 cosmology. - assert redshift_to_distance(1100, cosmology=Planck18) == pytest.approx(13886.327957 * 1e6) + assert abs(planck18_val - wmap9_val) > 1000.0 + assert 13.0 * 1e12 < wmap9_val < 16.0 * 1e12 + assert 13.0 * 1e12 < planck18_val < 16.0 * 1e12 diff --git a/tests/tdastro/sources/test_physical_models.py b/tests/tdastro/sources/test_physical_models.py index d05560e3..8098cf1b 100644 --- a/tests/tdastro/sources/test_physical_models.py +++ b/tests/tdastro/sources/test_physical_models.py @@ -1,4 +1,3 @@ -import pytest from astropy.cosmology import Planck18 from tdastro.sources.physical_model import PhysicalModel @@ -17,7 +16,7 @@ def test_physical_model(): assert model2.ra == 1.0 assert model2.dec == 2.0 assert model2.redshift == 1100.0 - assert model2.distance == pytest.approx(13886.327957 * 1e6) + assert 13.0 * 1e12 < model2.distance < 16.0 * 1e12 # Neither distance nor redshift are specified. model3 = PhysicalModel(ra=1.0, dec=2.0)