From 370447a67a8060efba65a6788622ff912ef5824f Mon Sep 17 00:00:00 2001 From: Konstantin Malanchev Date: Wed, 10 Jul 2024 14:02:07 -0400 Subject: [PATCH 01/31] Rename epoch to t0 --- src/tdastro/sources/periodic_source.py | 11 ++++++----- src/tdastro/sources/periodic_variable_star.py | 11 ++++++----- tests/tdastro/sources/test_periodic_source.py | 6 +++--- tests/tdastro/sources/test_periodic_variable_star.py | 2 +- 4 files changed, 16 insertions(+), 14 deletions(-) diff --git a/src/tdastro/sources/periodic_source.py b/src/tdastro/sources/periodic_source.py index 82c5d475..b68c02da 100644 --- a/src/tdastro/sources/periodic_source.py +++ b/src/tdastro/sources/periodic_source.py @@ -10,14 +10,15 @@ class PeriodicSource(PhysicalModel, ABC): ---------- period : `float` The period of the source, in days. - epoch : `float` - The epoch of the zero phase, date. + t0 : `float` + The t0 of the zero phase, date. Could be date of the minimum or maximum light + or any other reference time point. """ - def __init__(self, period, epoch, **kwargs): + def __init__(self, period, t0, **kwargs): super().__init__(**kwargs) self.add_parameter("period", period, required=True, **kwargs) - self.add_parameter("epoch", epoch, required=True, **kwargs) + self.add_parameter("t0", t0, required=True, **kwargs) @abstractmethod def _evaluate_phases(self, phases, wavelengths, **kwargs): @@ -56,7 +57,7 @@ def _evaluate(self, times, wavelengths, **kwargs): flux_density : `numpy.ndarray` A length T x N matrix of SED values. """ - phases = (times - self.epoch) % self.period / self.period + phases = (times - self.t0) % self.period / self.period flux_density = self._evaluate_phases(phases, wavelengths, **kwargs) return flux_density diff --git a/src/tdastro/sources/periodic_variable_star.py b/src/tdastro/sources/periodic_variable_star.py index 7caceb86..e75fd6f5 100644 --- a/src/tdastro/sources/periodic_variable_star.py +++ b/src/tdastro/sources/periodic_variable_star.py @@ -14,15 +14,16 @@ class PeriodicVariableStar(PeriodicSource, ABC): ---------- period : `float` The period of the source, in days. - epoch : `float` - The epoch of the zero phase, date. + t0 : `float` + The t0 of the zero phase, date. Could be date of the minimum or maximum light + or any other reference time point. distance : `float` The distance to the source, in pc. """ - def __init__(self, period, epoch, **kwargs): + def __init__(self, period, t0, **kwargs): distance = kwargs.pop("distance", None) - super().__init__(period, epoch, **kwargs) + super().__init__(period, t0, **kwargs) self.add_parameter("distance", value=distance, required=True, **kwargs) def _evaluate_phases(self, phases, wavelengths, **kwargs): @@ -77,7 +78,7 @@ class EclipsingBinaryStar(PeriodicVariableStar): """A toy model for a detached eclipsing binary star. It is assumed that the stars are spherical, SED is black-body, - and the orbits are circular. Epoch is the time of the primary eclipse. + and the orbits are circular. t0 is the epoch of the primary minimum. No limb darkening, reflection, or other effects are included. Attributes diff --git a/tests/tdastro/sources/test_periodic_source.py b/tests/tdastro/sources/test_periodic_source.py index 25cbc7b5..335fb976 100644 --- a/tests/tdastro/sources/test_periodic_source.py +++ b/tests/tdastro/sources/test_periodic_source.py @@ -13,13 +13,13 @@ def _evaluate_phases(self, phases, wavelengths, **kwargs): return amplitude * wavelengths[None, :] ** -2 -@pytest.mark.parametrize("period, epoch", [(1.0, 0.0), (2.0, 0.75), (4.0, 100 / 3)]) -def test_periodicity(period, epoch): +@pytest.mark.parametrize("period, t0", [(1.0, 0.0), (2.0, 0.75), (4.0, 100 / 3)]) +def test_periodicity(period, t0): """Test that the source is periodic.""" max_time = 16 n_periods = int(max_time / period) - source = SineSource(period=period, epoch=epoch) + source = SineSource(period=period, t0=t0) times = np.linspace(0, max_time, max_time * 100 + 1) wavelengths = np.linspace(100, 200, 3) fluxes = source.evaluate(times, wavelengths) diff --git a/tests/tdastro/sources/test_periodic_variable_star.py b/tests/tdastro/sources/test_periodic_variable_star.py index 723e3c5a..91c4099d 100644 --- a/tests/tdastro/sources/test_periodic_variable_star.py +++ b/tests/tdastro/sources/test_periodic_variable_star.py @@ -145,7 +145,7 @@ def test_eclipsing_binary_star(): source = EclipsingBinaryStar( distance=distance_pc, period=period.to_value(u.day), - epoch=0.0, + t0=0.0, major_semiaxis=major_semiaxis.cgs.value, inclination=89.0, primary_radius=primary_radius.cgs.value, From 00de859781bf8e72672bdddc8888bb27dfdcfcf1 Mon Sep 17 00:00:00 2001 From: Jeremy Kubica <104161096+jeremykubica@users.noreply.github.com> Date: Wed, 10 Jul 2024 17:46:25 -0400 Subject: [PATCH 02/31] Create io_utils.py --- src/tdastro/io_utils.py | 49 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 49 insertions(+) create mode 100644 src/tdastro/io_utils.py diff --git a/src/tdastro/io_utils.py b/src/tdastro/io_utils.py new file mode 100644 index 00000000..9067ca9b --- /dev/null +++ b/src/tdastro/io_utils.py @@ -0,0 +1,49 @@ +import numpy as np +from astropy.table import Table + + +def read_grid_data(input_file, format="ascii"): + """Read 2-d grid data from a text, csv, ecsv, or fits file. + + Each line is of the form 'x0 x1 value' where x0 and x1 are the grid + coordinates and value is the grid value. The rows should be sorted by + increasing x0 and, within an x0 value, increasing x1. + + Parameters + ---------- + input_file : `str` or file-like object + The input data file. + format : `str` + The file format. Should be one of ascii, csv, ecsv, + or fits. + + Returns + ------- + x0 : `numpy.ndarray` + A 1-d array with the values along the x-axis of the grid. + x1 : `numpy.ndarray` + A 1-d array with the values along the y-axis of the grid. + values : `numpy.ndarray` + A 2-d array with the values at each point in the grid with + shape (len(x0), len(x1)). + """ + data = Table.read(input_file, format=format) + if len(data.colnames) != 3: + raise ValueError( + f"Incorrect format for grid data in {input_file} with format {format}. " + f"Expected 3 columns but found {len(data.colnames)}." + ) + + # Get the values along the x0 and x1 dimensions. + x0 = np.sort(np.unique(data[data.colnames[0]].data)) + x1 = np.sort(np.unique(data[data.colnames[1]].data)) + + # Get the array of values. + if len(data) != len(x0) * len(x1): + raise ValueError( + f"Incomplete data for {input_file} with format {format}. Expected " + f"{len(x0) * len(x1)} entries but found {len(data)}." + ) + values = data[data.colnames[2]].data.reshape((len(x0), len(x1))) + + return x0, x1, values From 569dde2901dee81eb7a56fc1405f6eafad7698d9 Mon Sep 17 00:00:00 2001 From: Jeremy Kubica <104161096+jeremykubica@users.noreply.github.com> Date: Wed, 10 Jul 2024 18:30:21 -0400 Subject: [PATCH 03/31] Add testing --- src/tdastro/io_utils.py | 34 +++++++++++++++++++++---- tests/tdastro/conftest.py | 24 +++++++++++++++++ tests/tdastro/data/grid_input_bad.txt | 6 +++++ tests/tdastro/data/grid_input_good.ecsv | 8 ++++++ tests/tdastro/test_io_utils.py | 25 ++++++++++++++++++ 5 files changed, 92 insertions(+), 5 deletions(-) create mode 100644 tests/tdastro/data/grid_input_bad.txt create mode 100644 tests/tdastro/data/grid_input_good.ecsv create mode 100644 tests/tdastro/test_io_utils.py diff --git a/src/tdastro/io_utils.py b/src/tdastro/io_utils.py index 9067ca9b..3730e3fb 100644 --- a/src/tdastro/io_utils.py +++ b/src/tdastro/io_utils.py @@ -2,7 +2,7 @@ from astropy.table import Table -def read_grid_data(input_file, format="ascii"): +def read_grid_data(input_file, format="ascii", validate=False): """Read 2-d grid data from a text, csv, ecsv, or fits file. Each line is of the form 'x0 x1 value' where x0 and x1 are the grid @@ -16,6 +16,10 @@ def read_grid_data(input_file, format="ascii"): format : `str` The file format. Should be one of ascii, csv, ecsv, or fits. + Default = 'ascii' + validate : `bool` + Perform additional validation on the input data. + Default = False Returns ------- @@ -26,17 +30,24 @@ def read_grid_data(input_file, format="ascii"): values : `numpy.ndarray` A 2-d array with the values at each point in the grid with shape (len(x0), len(x1)). + + Raises + ------ + ``ValueError`` if any data validation fails. """ - data = Table.read(input_file, format=format) + data = Table.read(input_file, format=format, comment=r"\s*#") if len(data.colnames) != 3: raise ValueError( f"Incorrect format for grid data in {input_file} with format {format}. " f"Expected 3 columns but found {len(data.colnames)}." ) + x0_col = data.colnames[0] + x1_col = data.colnames[1] + v_col = data.colnames[2] # Get the values along the x0 and x1 dimensions. - x0 = np.sort(np.unique(data[data.colnames[0]].data)) - x1 = np.sort(np.unique(data[data.colnames[1]].data)) + x0 = np.sort(np.unique(data[x0_col].data)) + x1 = np.sort(np.unique(data[x1_col].data)) # Get the array of values. if len(data) != len(x0) * len(x1): @@ -44,6 +55,19 @@ def read_grid_data(input_file, format="ascii"): f"Incomplete data for {input_file} with format {format}. Expected " f"{len(x0) * len(x1)} entries but found {len(data)}." ) - values = data[data.colnames[2]].data.reshape((len(x0), len(x1))) + + # If we are validating, loop through the entire table and check that + # the x0 and x1 values are in the expected order. + if validate: + counter = 0 + for i in range(len(x0)): + for j in range(len(x1)): + if data[x0_col][counter] != x0[i]: + raise ValueError(f"Incorrect x0 ordering in {input_file} at row={counter}.") + if data[x1_col][counter] != x1[j]: + raise ValueError(f"Incorrect x0 ordering in {input_file} at row={counter}.") + + # Build the values matrix. + values = data[v_col].data.reshape((len(x0), len(x1))) return x0, x1, values diff --git a/tests/tdastro/conftest.py b/tests/tdastro/conftest.py index e69de29b..2af78c2d 100644 --- a/tests/tdastro/conftest.py +++ b/tests/tdastro/conftest.py @@ -0,0 +1,24 @@ +import os.path + +import pytest + +DATA_DIR_NAME = "data" +TEST_DIR = os.path.dirname(__file__) + + +@pytest.fixture +def test_data_dir(): + """Return the base test data directory.""" + return os.path.join(TEST_DIR, DATA_DIR_NAME) + + +@pytest.fixture +def grid_data_good_file(test_data_dir): + """Return the file path for the good grid input file.""" + return os.path.join(test_data_dir, "grid_input_good.ecsv") + + +@pytest.fixture +def grid_data_bad_file(test_data_dir): + """Return the file path for the bad grid input file.""" + return os.path.join(test_data_dir, "grid_input_bad.txt") diff --git a/tests/tdastro/data/grid_input_bad.txt b/tests/tdastro/data/grid_input_bad.txt new file mode 100644 index 00000000..12793185 --- /dev/null +++ b/tests/tdastro/data/grid_input_bad.txt @@ -0,0 +1,6 @@ +0.0 1.0 0.0 +0.0 1.5 1.0 +1.0 1.0 2.0 +1.0 1.5 3.0 +2.0 1.5 5.0 +2.0 1.0 4.0 \ No newline at end of file diff --git a/tests/tdastro/data/grid_input_good.ecsv b/tests/tdastro/data/grid_input_good.ecsv new file mode 100644 index 00000000..694d93e0 --- /dev/null +++ b/tests/tdastro/data/grid_input_good.ecsv @@ -0,0 +1,8 @@ +# Comment up here. +x0, x1, values +0.0, 1.0, 0.0 +0.0, 1.5, 1.0 +1.0, 1.0, 2.0 +1.0, 1.5, 3.0 +2.0, 1.0, 4.0 +2.0, 1.5, 5.0 \ No newline at end of file diff --git a/tests/tdastro/test_io_utils.py b/tests/tdastro/test_io_utils.py new file mode 100644 index 00000000..0a17f2fe --- /dev/null +++ b/tests/tdastro/test_io_utils.py @@ -0,0 +1,25 @@ +import numpy as np +import pytest +from tdastro.io_utils import read_grid_data + + +def test_read_grid_data_good(grid_data_good_file): + """Test that we can read a well formatted grid data file.""" + x0, x1, values = read_grid_data(grid_data_good_file, format="ascii.csv") + x0_expected = np.array([0.0, 1.0, 2.0]) + x1_expected = np.array([1.0, 1.5]) + values_expected = np.array([[0.0, 1.0], [2.0, 3.0], [4.0, 5.0]]) + + np.testing.assert_allclose(x0, x0_expected, atol=1e-5) + np.testing.assert_allclose(x1, x1_expected, atol=1e-5) + np.testing.assert_allclose(values, values_expected, atol=1e-5) + + +def test_read_grid_data_bad(grid_data_bad_file): + """Test that we correctly handle a badly formatted grid data file.""" + # We load without a problem is validation is off. + x0, x1, values = read_grid_data(grid_data_bad_file, format="ascii") + assert values.shape == (3, 2) + + with pytest.raises(ValueError): + _, _, _ = read_grid_data(grid_data_bad_file, format="ascii", validate=True) From 2e596e3c6dc4aab5d4e3a50926cfeabcadb5804a Mon Sep 17 00:00:00 2001 From: Jeremy Kubica <104161096+jeremykubica@users.noreply.github.com> Date: Wed, 10 Jul 2024 18:35:41 -0400 Subject: [PATCH 04/31] Fix comment --- src/tdastro/io_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/tdastro/io_utils.py b/src/tdastro/io_utils.py index 3730e3fb..926c2c15 100644 --- a/src/tdastro/io_utils.py +++ b/src/tdastro/io_utils.py @@ -14,8 +14,8 @@ def read_grid_data(input_file, format="ascii", validate=False): input_file : `str` or file-like object The input data file. format : `str` - The file format. Should be one of ascii, csv, ecsv, - or fits. + The file format. Should be one of the formats supported by + astropy Tables such as 'ascii', 'ascii.ecsv', or 'fits'. Default = 'ascii' validate : `bool` Perform additional validation on the input data. From 9f99eed39c952b6a7f737ce828e8c33af4cfefa4 Mon Sep 17 00:00:00 2001 From: Jeremy Kubica <104161096+jeremykubica@users.noreply.github.com> Date: Thu, 11 Jul 2024 11:20:21 -0400 Subject: [PATCH 05/31] Wrap sncosmo models --- pyproject.toml | 1 + src/tdastro/base_models.py | 77 ++++++++++---- src/tdastro/sources/sncomso_models.py | 102 +++++++++++++++++++ tests/tdastro/sources/test_sncosmo_models.py | 19 ++++ tests/tdastro/sources/test_static_source.py | 6 ++ tests/tdastro/test_base_models.py | 22 +++- 6 files changed, 208 insertions(+), 19 deletions(-) create mode 100644 src/tdastro/sources/sncomso_models.py create mode 100644 tests/tdastro/sources/test_sncosmo_models.py diff --git a/pyproject.toml b/pyproject.toml index 06ee50c7..7fc95b11 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,6 +19,7 @@ dependencies = [ "astropy", "numpy", "scipy", + "sncosmo", ] [project.urls] diff --git a/src/tdastro/base_models.py b/src/tdastro/base_models.py index 27a7ffcc..77f40995 100644 --- a/src/tdastro/base_models.py +++ b/src/tdastro/base_models.py @@ -23,8 +23,8 @@ class ParameterizedModel: ---------- setters : `list` of `tuple` A dictionary to information about the setters for the parameters in the form: - (name, ParameterSource, setter information). The attributes are stored in the - order in which they need to be set. + (name, ParameterSource, setter information, required). The attributes are + stored in the order in which they need to be set. sample_iteration : `int` A counter used to syncronize sampling runs. Tracks how many times this model's parameters have been resampled. @@ -38,12 +38,13 @@ def __str__(self): """Return the string representation of the model.""" return "ParameterizedModel" - def add_parameter(self, name, value=None, required=False, **kwargs): - """Add a single parameter to the ParameterizedModel. Checks multiple sources - in the following order: 1. Manually specified ``value``, 2. An entry in ``kwargs``, - or 3. ``None``. - Sets an initial value for the attribute based on the given information. - The attributes are stored in the order in which they are added. + def set_parameter(self, name, value=None, **kwargs): + """Set a single *existing* parameter to the ParameterizedModel. + + Notes + ----- + * Sets an initial value for the attribute based on the given information. + * The attributes are stored in the order in which they are added. Parameters ---------- @@ -52,8 +53,6 @@ def add_parameter(self, name, value=None, required=False, **kwargs): value : any, optional The information to use to set the parameter. Can be a constant, function, ParameterizedModel, or self. - required : `bool` - Fail if the parameter is set to ``None``. **kwargs : `dict`, optional All other keyword arguments, possibly including the parameter setters. @@ -63,39 +62,81 @@ def add_parameter(self, name, value=None, required=False, **kwargs): cannot be found. Raise a ``ValueError`` if the parameter is required, but set to None. """ - if hasattr(self, name) and getattr(self, name) is not None: - raise KeyError(f"Duplicate parameter set: {name}") + # Check for parameter has been added and if so, find the index. + try: + ind = next(ind for ind, entry in enumerate(self.setters) if entry[0] == name) + except StopIteration: + raise KeyError(f"Tried to set parameter {name} that has not been added.") from None + required = self.setters[ind][3] if value is None and name in kwargs: # The value wasn't set, but the name is in kwargs. value = kwargs[name] + if value is not None: if isinstance(value, types.FunctionType): # Case 1: If we are getting from a static function, sample it. - self.setters.append((name, ParameterSource.FUNCTION, value)) + self.setters[ind] = (name, ParameterSource.FUNCTION, value, required) setattr(self, name, value(**kwargs)) elif isinstance(value, types.MethodType) and isinstance(value.__self__, ParameterizedModel): # Case 2: We are trying to use the method from a ParameterizedModel. # Note that this will (correctly) fail if we are adding a model method from the current # object that requires an unset attribute. - self.setters.append((name, ParameterSource.MODEL_METHOD, value)) + self.setters[ind] = (name, ParameterSource.MODEL_METHOD, value, required) setattr(self, name, value(**kwargs)) elif isinstance(value, ParameterizedModel): # Case 3: We are trying to access an attribute from a parameterized model. if not hasattr(value, name): raise ValueError(f"Attribute {name} missing from parent.") - self.setters.append((name, ParameterSource.MODEL_ATTRIBUTE, value)) + self.setters[ind] = (name, ParameterSource.MODEL_ATTRIBUTE, value, required) setattr(self, name, getattr(value, name)) else: # Case 4: The value is constant. - self.setters.append((name, ParameterSource.CONSTANT, value)) + self.setters[ind] = (name, ParameterSource.CONSTANT, value, required) setattr(self, name, value) elif not required: - self.setters.append((name, ParameterSource.CONSTANT, None)) + self.setters[ind] = (name, ParameterSource.CONSTANT, None, required) setattr(self, name, None) else: raise ValueError(f"Missing required parameter {name}") + def add_parameter(self, name, value=None, required=False, **kwargs): + """Add a single *new* parameter to the ParameterizedModel. + + Notes + ----- + * Checks multiple sources in the following order: Manually specified ``value``, + an entry in ``kwargs``, or ``None``. + * Sets an initial value for the attribute based on the given information. + * The attributes are stored in the order in which they are added. + + Parameters + ---------- + name : `str` + The parameter name to add. + value : any, optional + The information to use to set the parameter. Can be a constant, + function, ParameterizedModel, or self. + required : `bool` + Fail if the parameter is set to ``None``. + **kwargs : `dict`, optional + All other keyword arguments, possibly including the parameter setters. + + Raises + ------ + Raise a ``KeyError`` if there is a parameter collision or the parameter + cannot be found. + Raise a ``ValueError`` if the parameter is required, but set to None. + """ + # Check for parameter collision. + if hasattr(self, name) and getattr(self, name) is not None: + raise KeyError(f"Duplicate parameter set: {name}") + + # Add an entry for the setter function and fill in the remaining + # information using set_parameter(). + self.setters.append((name, None, None, required)) + self.set_parameter(name, value, **kwargs) + def sample_parameters(self, max_depth=50, **kwargs): """Sample the model's underlying parameters if they are provided by a function or ParameterizedModel. @@ -118,7 +159,7 @@ def sample_parameters(self, max_depth=50, **kwargs): raise ValueError(f"Maximum sampling depth exceeded at {self}. Potential infinite loop.") # Run through each parameter and sample it based on the given recipe. - for name, source_type, setter in self.setters: + for name, source_type, setter, _ in self.setters: sampled_value = None if source_type == ParameterSource.CONSTANT: sampled_value = setter diff --git a/src/tdastro/sources/sncomso_models.py b/src/tdastro/sources/sncomso_models.py new file mode 100644 index 00000000..852ebf1e --- /dev/null +++ b/src/tdastro/sources/sncomso_models.py @@ -0,0 +1,102 @@ +"""Wrappers for the models defined in sncosmo. + +https://github.com/sncosmo/sncosmo/blob/v2.10.1/sncosmo/models.py +https://sncosmo.readthedocs.io/en/stable/models.html +""" + +import sncosmo + +from tdastro.base_models import PhysicalModel + + +class SncosmoModel(PhysicalModel): + """A wrapper for sncosmo models. + + Attributes + ---------- + model : `sncosmo.Model` + The underlying model. + model_name : `str` + The name used to set the model. + + Parameters + ---------- + model_name : `str` + The name used to set the model. + **kwargs : `dict`, optional + Any additional keyword arguments. + """ + + def __init__(self, model_name, **kwargs): + super().__init__(**kwargs) + self.model_name = model_name + self.model = sncosmo.Model(source=model_name) + + def __str__(self): + """Return the string representation of the model.""" + return f"SncosmoModel({self.model_name})" + + @property + def param_names(self): + """Return a list of the model's parameter names.""" + return self.model.param_names + + @property + def parameters(self): + """Return a list of the model's parameter values.""" + return self.model.parameters + + @property + def source(self): + """Return the model's sncosmo source instance.""" + return self.model.source + + def get(self, name): + """Get the value of a specific parameter. + + Parameters + ---------- + name : `str` + The name of the parameter. + + Returns + ------- + The parameter value. + """ + return self.model.get(name) + + def set(self, **kwargs): + """Set the parameters of the model. + + These must all be constants to be compatible with sncosmo. + + Parameters + ---------- + **kwargs : `dict` + The parameters to set and their values. + """ + self.model.set(**kwargs) + for key, value in kwargs.items(): + if hasattr(self, key): + self.set_parameter(key, value) + else: + self.add_parameter(key, value) + + def _evaluate(self, times, wavelengths, **kwargs): + """Draw effect-free observations for this object. + + Parameters + ---------- + times : `numpy.ndarray` + A length T array of timestamps. + wavelengths : `numpy.ndarray`, optional + A length N array of wavelengths. + **kwargs : `dict`, optional + Any additional keyword arguments. + + Returns + ------- + flux_density : `numpy.ndarray` + A length T x N matrix of SED values. + """ + return self.model.flux(times, wavelengths) diff --git a/tests/tdastro/sources/test_sncosmo_models.py b/tests/tdastro/sources/test_sncosmo_models.py new file mode 100644 index 00000000..b417fa16 --- /dev/null +++ b/tests/tdastro/sources/test_sncosmo_models.py @@ -0,0 +1,19 @@ +import numpy as np +from tdastro.sources.sncomso_models import SncosmoModel + + +def test_sncomso_models_hsiao() -> None: + """Test that we can create and evalue a 'hsiao' model.""" + model = SncosmoModel("hsiao") + model.set(z=0.5, t0=55000.0, amplitude=1.0e-10) + assert model.z == 0.5 + assert model.t0 == 55000.0 + assert model.amplitude == 1.0e-10 + assert str(model) == "SncosmoModel(hsiao)" + + assert np.array_equal(model.param_names, ["z", "t0", "amplitude"]) + assert np.array_equal(model.parameters, [0.5, 55000.0, 1.0e-10]) + + # Test with the example from: https://sncosmo.readthedocs.io/en/stable/models.html + fluxes = model.evaluate([54990.0], [4000.0, 4100.0, 4200.0]) + assert np.allclose(fluxes, [4.31210900e-20, 7.46619962e-20, 1.42182787e-19]) diff --git a/tests/tdastro/sources/test_static_source.py b/tests/tdastro/sources/test_static_source.py index 834c2899..420ff749 100644 --- a/tests/tdastro/sources/test_static_source.py +++ b/tests/tdastro/sources/test_static_source.py @@ -33,6 +33,12 @@ def test_static_source() -> None: assert values.shape == (6, 3) assert np.all(values == 10.0) + # We can set a value we have already added. + model.set_parameter("brightness", 5.0) + values = model.evaluate(times, wavelengths) + assert values.shape == (6, 3) + assert np.all(values == 5.0) + def test_static_source_host() -> None: """Test that we can sample and create a StaticSource object with properties diff --git a/tests/tdastro/test_base_models.py b/tests/tdastro/test_base_models.py index 86b1291d..b5ff1daa 100644 --- a/tests/tdastro/test_base_models.py +++ b/tests/tdastro/test_base_models.py @@ -102,7 +102,7 @@ def test_parameterized_model() -> None: # Nothing changes in model1 or model2 assert model1.value1 == 0.5 - assert model1.value1 == 0.5 + assert model1.value2 == 0.5 assert model1.result() == 1.0 assert model1.value_sum == 1.0 assert model2.value1 == 0.5 @@ -121,3 +121,23 @@ def test_parameterized_model() -> None: assert model1.sample_iteration == model2.sample_iteration assert model1.sample_iteration == model3.sample_iteration assert model1.sample_iteration == model4.sample_iteration + + +def test_parameterized_model_modify() -> None: + """Test that we can modify the parameters in a model.""" + model = PairModel(value1=0.5, value2=0.5) + assert model.value1 == 0.5 + assert model.value2 == 0.5 + + # We cannot add a parameter a second time. + with pytest.raises(KeyError): + model.add_parameter("value1", 5.0) + + # We can set the parameter. + model.set_parameter("value1", 5.0) + assert model.value1 == 5.0 + assert model.value2 == 0.5 + + # We cannot set a value that hasn't been added. + with pytest.raises(KeyError): + model.set_parameter("brightness", 5.0) From e3f2fa005e4daa1b595daeafae088449e1041c4e Mon Sep 17 00:00:00 2001 From: Jeremy Kubica <104161096+jeremykubica@users.noreply.github.com> Date: Thu, 11 Jul 2024 11:30:01 -0400 Subject: [PATCH 06/31] Allow users to change parameter settings --- src/tdastro/base_models.py | 77 ++++++++++++++++----- tests/tdastro/sources/test_static_source.py | 6 ++ tests/tdastro/test_base_models.py | 22 +++++- 3 files changed, 86 insertions(+), 19 deletions(-) diff --git a/src/tdastro/base_models.py b/src/tdastro/base_models.py index 27a7ffcc..77f40995 100644 --- a/src/tdastro/base_models.py +++ b/src/tdastro/base_models.py @@ -23,8 +23,8 @@ class ParameterizedModel: ---------- setters : `list` of `tuple` A dictionary to information about the setters for the parameters in the form: - (name, ParameterSource, setter information). The attributes are stored in the - order in which they need to be set. + (name, ParameterSource, setter information, required). The attributes are + stored in the order in which they need to be set. sample_iteration : `int` A counter used to syncronize sampling runs. Tracks how many times this model's parameters have been resampled. @@ -38,12 +38,13 @@ def __str__(self): """Return the string representation of the model.""" return "ParameterizedModel" - def add_parameter(self, name, value=None, required=False, **kwargs): - """Add a single parameter to the ParameterizedModel. Checks multiple sources - in the following order: 1. Manually specified ``value``, 2. An entry in ``kwargs``, - or 3. ``None``. - Sets an initial value for the attribute based on the given information. - The attributes are stored in the order in which they are added. + def set_parameter(self, name, value=None, **kwargs): + """Set a single *existing* parameter to the ParameterizedModel. + + Notes + ----- + * Sets an initial value for the attribute based on the given information. + * The attributes are stored in the order in which they are added. Parameters ---------- @@ -52,8 +53,6 @@ def add_parameter(self, name, value=None, required=False, **kwargs): value : any, optional The information to use to set the parameter. Can be a constant, function, ParameterizedModel, or self. - required : `bool` - Fail if the parameter is set to ``None``. **kwargs : `dict`, optional All other keyword arguments, possibly including the parameter setters. @@ -63,39 +62,81 @@ def add_parameter(self, name, value=None, required=False, **kwargs): cannot be found. Raise a ``ValueError`` if the parameter is required, but set to None. """ - if hasattr(self, name) and getattr(self, name) is not None: - raise KeyError(f"Duplicate parameter set: {name}") + # Check for parameter has been added and if so, find the index. + try: + ind = next(ind for ind, entry in enumerate(self.setters) if entry[0] == name) + except StopIteration: + raise KeyError(f"Tried to set parameter {name} that has not been added.") from None + required = self.setters[ind][3] if value is None and name in kwargs: # The value wasn't set, but the name is in kwargs. value = kwargs[name] + if value is not None: if isinstance(value, types.FunctionType): # Case 1: If we are getting from a static function, sample it. - self.setters.append((name, ParameterSource.FUNCTION, value)) + self.setters[ind] = (name, ParameterSource.FUNCTION, value, required) setattr(self, name, value(**kwargs)) elif isinstance(value, types.MethodType) and isinstance(value.__self__, ParameterizedModel): # Case 2: We are trying to use the method from a ParameterizedModel. # Note that this will (correctly) fail if we are adding a model method from the current # object that requires an unset attribute. - self.setters.append((name, ParameterSource.MODEL_METHOD, value)) + self.setters[ind] = (name, ParameterSource.MODEL_METHOD, value, required) setattr(self, name, value(**kwargs)) elif isinstance(value, ParameterizedModel): # Case 3: We are trying to access an attribute from a parameterized model. if not hasattr(value, name): raise ValueError(f"Attribute {name} missing from parent.") - self.setters.append((name, ParameterSource.MODEL_ATTRIBUTE, value)) + self.setters[ind] = (name, ParameterSource.MODEL_ATTRIBUTE, value, required) setattr(self, name, getattr(value, name)) else: # Case 4: The value is constant. - self.setters.append((name, ParameterSource.CONSTANT, value)) + self.setters[ind] = (name, ParameterSource.CONSTANT, value, required) setattr(self, name, value) elif not required: - self.setters.append((name, ParameterSource.CONSTANT, None)) + self.setters[ind] = (name, ParameterSource.CONSTANT, None, required) setattr(self, name, None) else: raise ValueError(f"Missing required parameter {name}") + def add_parameter(self, name, value=None, required=False, **kwargs): + """Add a single *new* parameter to the ParameterizedModel. + + Notes + ----- + * Checks multiple sources in the following order: Manually specified ``value``, + an entry in ``kwargs``, or ``None``. + * Sets an initial value for the attribute based on the given information. + * The attributes are stored in the order in which they are added. + + Parameters + ---------- + name : `str` + The parameter name to add. + value : any, optional + The information to use to set the parameter. Can be a constant, + function, ParameterizedModel, or self. + required : `bool` + Fail if the parameter is set to ``None``. + **kwargs : `dict`, optional + All other keyword arguments, possibly including the parameter setters. + + Raises + ------ + Raise a ``KeyError`` if there is a parameter collision or the parameter + cannot be found. + Raise a ``ValueError`` if the parameter is required, but set to None. + """ + # Check for parameter collision. + if hasattr(self, name) and getattr(self, name) is not None: + raise KeyError(f"Duplicate parameter set: {name}") + + # Add an entry for the setter function and fill in the remaining + # information using set_parameter(). + self.setters.append((name, None, None, required)) + self.set_parameter(name, value, **kwargs) + def sample_parameters(self, max_depth=50, **kwargs): """Sample the model's underlying parameters if they are provided by a function or ParameterizedModel. @@ -118,7 +159,7 @@ def sample_parameters(self, max_depth=50, **kwargs): raise ValueError(f"Maximum sampling depth exceeded at {self}. Potential infinite loop.") # Run through each parameter and sample it based on the given recipe. - for name, source_type, setter in self.setters: + for name, source_type, setter, _ in self.setters: sampled_value = None if source_type == ParameterSource.CONSTANT: sampled_value = setter diff --git a/tests/tdastro/sources/test_static_source.py b/tests/tdastro/sources/test_static_source.py index 834c2899..420ff749 100644 --- a/tests/tdastro/sources/test_static_source.py +++ b/tests/tdastro/sources/test_static_source.py @@ -33,6 +33,12 @@ def test_static_source() -> None: assert values.shape == (6, 3) assert np.all(values == 10.0) + # We can set a value we have already added. + model.set_parameter("brightness", 5.0) + values = model.evaluate(times, wavelengths) + assert values.shape == (6, 3) + assert np.all(values == 5.0) + def test_static_source_host() -> None: """Test that we can sample and create a StaticSource object with properties diff --git a/tests/tdastro/test_base_models.py b/tests/tdastro/test_base_models.py index 86b1291d..b5ff1daa 100644 --- a/tests/tdastro/test_base_models.py +++ b/tests/tdastro/test_base_models.py @@ -102,7 +102,7 @@ def test_parameterized_model() -> None: # Nothing changes in model1 or model2 assert model1.value1 == 0.5 - assert model1.value1 == 0.5 + assert model1.value2 == 0.5 assert model1.result() == 1.0 assert model1.value_sum == 1.0 assert model2.value1 == 0.5 @@ -121,3 +121,23 @@ def test_parameterized_model() -> None: assert model1.sample_iteration == model2.sample_iteration assert model1.sample_iteration == model3.sample_iteration assert model1.sample_iteration == model4.sample_iteration + + +def test_parameterized_model_modify() -> None: + """Test that we can modify the parameters in a model.""" + model = PairModel(value1=0.5, value2=0.5) + assert model.value1 == 0.5 + assert model.value2 == 0.5 + + # We cannot add a parameter a second time. + with pytest.raises(KeyError): + model.add_parameter("value1", 5.0) + + # We can set the parameter. + model.set_parameter("value1", 5.0) + assert model.value1 == 5.0 + assert model.value2 == 0.5 + + # We cannot set a value that hasn't been added. + with pytest.raises(KeyError): + model.set_parameter("brightness", 5.0) From 03f1c82d78b0c1ff13c51a1c398f30a594f93805 Mon Sep 17 00:00:00 2001 From: Jeremy Kubica <104161096+jeremykubica@users.noreply.github.com> Date: Thu, 11 Jul 2024 12:05:44 -0400 Subject: [PATCH 07/31] Scope down sncosmo dependency --- src/tdastro/sources/sncomso_models.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/tdastro/sources/sncomso_models.py b/src/tdastro/sources/sncomso_models.py index 852ebf1e..14db6715 100644 --- a/src/tdastro/sources/sncomso_models.py +++ b/src/tdastro/sources/sncomso_models.py @@ -4,7 +4,7 @@ https://sncosmo.readthedocs.io/en/stable/models.html """ -import sncosmo +from sncosmo.models import Model from tdastro.base_models import PhysicalModel @@ -30,7 +30,7 @@ class SncosmoModel(PhysicalModel): def __init__(self, model_name, **kwargs): super().__init__(**kwargs) self.model_name = model_name - self.model = sncosmo.Model(source=model_name) + self.model = Model(source=model_name) def __str__(self): """Return the string representation of the model.""" From 0ec7d7e93b21c5d5687f3e7212a9f97c6a42b90e Mon Sep 17 00:00:00 2001 From: Jeremy Kubica <104161096+jeremykubica@users.noreply.github.com> Date: Fri, 12 Jul 2024 09:27:54 -0400 Subject: [PATCH 08/31] Change setters list to dict --- src/tdastro/base_models.py | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/src/tdastro/base_models.py b/src/tdastro/base_models.py index 77f40995..e2fb3e13 100644 --- a/src/tdastro/base_models.py +++ b/src/tdastro/base_models.py @@ -21,9 +21,9 @@ class ParameterizedModel: Attributes ---------- - setters : `list` of `tuple` + setters : `dict` of `tuple` A dictionary to information about the setters for the parameters in the form: - (name, ParameterSource, setter information, required). The attributes are + (ParameterSource, setter information, required). The attributes are stored in the order in which they need to be set. sample_iteration : `int` A counter used to syncronize sampling runs. Tracks how many times this @@ -31,7 +31,7 @@ class ParameterizedModel: """ def __init__(self, **kwargs): - self.setters = [] + self.setters = {} self.sample_iteration = 0 def __str__(self): @@ -63,11 +63,9 @@ def set_parameter(self, name, value=None, **kwargs): Raise a ``ValueError`` if the parameter is required, but set to None. """ # Check for parameter has been added and if so, find the index. - try: - ind = next(ind for ind, entry in enumerate(self.setters) if entry[0] == name) - except StopIteration: + if name not in self.setters: raise KeyError(f"Tried to set parameter {name} that has not been added.") from None - required = self.setters[ind][3] + required = self.setters[name][2] if value is None and name in kwargs: # The value wasn't set, but the name is in kwargs. @@ -76,26 +74,26 @@ def set_parameter(self, name, value=None, **kwargs): if value is not None: if isinstance(value, types.FunctionType): # Case 1: If we are getting from a static function, sample it. - self.setters[ind] = (name, ParameterSource.FUNCTION, value, required) + self.setters[name] = (ParameterSource.FUNCTION, value, required) setattr(self, name, value(**kwargs)) elif isinstance(value, types.MethodType) and isinstance(value.__self__, ParameterizedModel): # Case 2: We are trying to use the method from a ParameterizedModel. # Note that this will (correctly) fail if we are adding a model method from the current # object that requires an unset attribute. - self.setters[ind] = (name, ParameterSource.MODEL_METHOD, value, required) + self.setters[name] = (ParameterSource.MODEL_METHOD, value, required) setattr(self, name, value(**kwargs)) elif isinstance(value, ParameterizedModel): # Case 3: We are trying to access an attribute from a parameterized model. if not hasattr(value, name): raise ValueError(f"Attribute {name} missing from parent.") - self.setters[ind] = (name, ParameterSource.MODEL_ATTRIBUTE, value, required) + self.setters[name] = (ParameterSource.MODEL_ATTRIBUTE, value, required) setattr(self, name, getattr(value, name)) else: # Case 4: The value is constant. - self.setters[ind] = (name, ParameterSource.CONSTANT, value, required) + self.setters[name] = (ParameterSource.CONSTANT, value, required) setattr(self, name, value) elif not required: - self.setters[ind] = (name, ParameterSource.CONSTANT, None, required) + self.setters[name] = (ParameterSource.CONSTANT, None, required) setattr(self, name, None) else: raise ValueError(f"Missing required parameter {name}") @@ -134,7 +132,7 @@ def add_parameter(self, name, value=None, required=False, **kwargs): # Add an entry for the setter function and fill in the remaining # information using set_parameter(). - self.setters.append((name, None, None, required)) + self.setters[name] = (None, None, required) self.set_parameter(name, value, **kwargs) def sample_parameters(self, max_depth=50, **kwargs): @@ -159,7 +157,9 @@ def sample_parameters(self, max_depth=50, **kwargs): raise ValueError(f"Maximum sampling depth exceeded at {self}. Potential infinite loop.") # Run through each parameter and sample it based on the given recipe. - for name, source_type, setter, _ in self.setters: + # As of Python 3.7 dictionaries are guaranteed to preserve insertion ordering, + # so this will iterate through attributes in the order they were inserted. + for name, (source_type, setter, _) in self.setters.items(): sampled_value = None if source_type == ParameterSource.CONSTANT: sampled_value = setter From 5f5ec93bff1a67e69acb8f8a797e7b233b67ec09 Mon Sep 17 00:00:00 2001 From: Jeremy Kubica <104161096+jeremykubica@users.noreply.github.com> Date: Fri, 12 Jul 2024 11:13:16 -0400 Subject: [PATCH 09/31] Create a population level model --- src/tdastro/base_models.py | 105 ++++++++++++++++++ src/tdastro/populations/__init__.py | 0 src/tdastro/populations/fixed_population.py | 85 ++++++++++++++ .../populations/test_fixed_population.py | 70 ++++++++++++ 4 files changed, 260 insertions(+) create mode 100644 src/tdastro/populations/__init__.py create mode 100644 src/tdastro/populations/fixed_population.py create mode 100644 tests/tdastro/populations/test_fixed_population.py diff --git a/src/tdastro/base_models.py b/src/tdastro/base_models.py index e2fb3e13..30bffcb2 100644 --- a/src/tdastro/base_models.py +++ b/src/tdastro/base_models.py @@ -1,5 +1,10 @@ +"""The base models used throughout TDAstro including physical objects, effects, and populations.""" + import types from enum import Enum +from os import urandom + +import numpy as np class ParameterSource(Enum): @@ -347,3 +352,103 @@ def apply(self, flux_density, wavelengths=None, physical_model=None, **kwargs): A length T x N matrix of flux densities after the effect is applied. """ raise NotImplementedError() + + +class PopulationModel(ParameterizedModel): + """A model of a population of PhysicalModels. + + Attributes + ---------- + num_sources : `int` + The number of different sources in the population. + sources : `list` + A list of sources from which to draw. + _rng : `numpy.random._generator.Generator` + A random number generator to use. + + Parameters + ---------- + rng : `numpy.random._generator.Generator`, optional + A random number generator to use for sampling. If not provided, + will create a new one with a randomized seed. + """ + + def __init__(self, rng=None, **kwargs): + super().__init__(**kwargs) + self.num_sources = 0 + self.sources = [] + if rng is None: + seed = int.from_bytes(urandom(4), "big") + self._rng = np.random.default_rng(seed) + else: + self._rng = rng + + def __str__(self): + """Return the string representation of the model.""" + return f"PopulationModel with {self.num_sources} sources." + + def add_source(self, new_source, **kwargs): + """Add a new source to the population. + + Parameters + ---------- + new_source : `PhysicalModel` + A source from the population. + **kwargs : `dict`, optional + Any additional keyword arguments. + """ + if not isinstance(new_source, PhysicalModel): + raise ValueError("All sources must be PhysicalModels") + self.sources.append(new_source) + self.num_sources += 1 + + def draw_source(self): + """Sample a single source from the population. + + Returns + ------- + source : `PhysicalModel` + A source from the population. + """ + raise NotImplementedError() + + def add_effect(self, effect, **kwargs): + """Add a transformational effect to all PhysicalModels in this population. + Effects are applied in the order in which they are added. + + Parameters + ---------- + effect : `EffectModel` + The effect to apply. + **kwargs : `dict`, optional + Any additional keyword arguments. + + Raises + ------ + Raises a ``AttributeError`` if the PhysicalModel does not have all of the + required attributes. + """ + for source in self.sources: + source.add_effect(effect, **kwargs) + + def evaluate(self, times, wavelengths, resample_parameters=False, **kwargs): + """Draw observations from a single (randomly sampled) source. + + Parameters + ---------- + times : `numpy.ndarray` + A length T array of timestamps. + wavelengths : `numpy.ndarray`, optional + A length N array of wavelengths. + resample_parameters : `bool` + Treat this evaluation as a completely new object, resampling the + parameters from the original provided functions. + **kwargs : `dict`, optional + Any additional keyword arguments. + + Returns + ------- + flux_density : `numpy.ndarray` + A length T x N matrix of SED values. + """ + return self.draw_source().evalute(times, wavelengths, resample_parameters, **kwargs) diff --git a/src/tdastro/populations/__init__.py b/src/tdastro/populations/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/tdastro/populations/fixed_population.py b/src/tdastro/populations/fixed_population.py new file mode 100644 index 00000000..06838062 --- /dev/null +++ b/src/tdastro/populations/fixed_population.py @@ -0,0 +1,85 @@ +import numpy as np + +from tdastro.base_models import PopulationModel + + +class FixedPopulation(PopulationModel): + """A population with a predefined, fixed probability of sampling each source. + + Attributes + ---------- + probs : `numpy.ndarray` + The probability of drawing each type of source. + _raw_rates : `numpy.ndarray` + An array of floats that provides the base sampling rate for each + type. This is normalized into a probability distributions so + [100, 200, 200] -> [0.2, 0.4, 0.4]. + """ + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.probs = np.array([]) + self._raw_rates = np.array([]) + + def __str__(self): + """Return the string representation of the model.""" + return f"FixedPopulation({self.probability})" + + def _update_probabilities(self): + """Update the probability array.""" + self.probs = self._raw_rates / np.sum(self._raw_rates) + + def add_source(self, new_source, rate, **kwargs): + """Add a new source to the population. + + Parameters + ---------- + new_source : `PhysicalModel` + A source from the population. + rate : `float` + A numerical rate for drawing the object. + **kwargs : `dict`, optional + Any additional keyword arguments. + + Raises + ------ + ``ValueError`` if the rate <= 0.0. + """ + if rate <= 0.0: + raise ValueError(f"Expected positive rate. Found {rate}.") + super().add_source(new_source, **kwargs) + + self._raw_rates = np.append(self._raw_rates, rate) + self._update_probabilities() + + def change_rate(self, source_index, rate, **kwargs): + """Add a new source to the population. + + Parameters + ---------- + source_index : `int` + The index of the source whose rate is changing. + rate : `float` + A numerical rate for drawing the object. + **kwargs : `dict`, optional + Any additional keyword arguments. + + Raises + ------ + ``ValueError`` if the rate <= 0.0. + """ + if rate <= 0.0: + raise ValueError(f"Expected positive rate. Found {rate}.") + self._raw_rates[source_index] = rate + self._update_probabilities() + + def draw_source(self): + """Sample a single source from the population. + + Returns + ------- + source : `PhysicalModel` + A source from the population. + """ + index = self._rng.choice(np.arange(0, self.num_sources), p=self.probs) + return self.sources[index] diff --git a/tests/tdastro/populations/test_fixed_population.py b/tests/tdastro/populations/test_fixed_population.py new file mode 100644 index 00000000..8989c921 --- /dev/null +++ b/tests/tdastro/populations/test_fixed_population.py @@ -0,0 +1,70 @@ +import numpy as np +import pytest +from tdastro.effects.white_noise import WhiteNoise +from tdastro.populations.fixed_population import FixedPopulation +from tdastro.sources.static_source import StaticSource + + +def test_fixed_population_basic_add(): + """Test that we can add effects to a population of PhysicalModels.""" + population = FixedPopulation() + population.add_source(StaticSource(brightness=10.0), 0.5) + assert population.num_sources == 1 + assert np.allclose(population.probs, [1.0]) + + # Test that we fail with a bad rate. + with pytest.raises(ValueError): + population.add_source(StaticSource(brightness=10.0), -0.5) + + +def test_fixed_population_add_effect(): + """Test that we can add effects to a population of PhysicalModels.""" + model1 = StaticSource(brightness=10.0) + model2 = StaticSource(brightness=20.0) + + population = FixedPopulation() + population.add_source(model1, 0.5) + population.add_source(model2, 0.5) + assert population.num_sources == 2 + assert len(model1.effects) == 0 + assert len(model2.effects) == 0 + + # Add a white noise effect to all models. + population.add_effect(WhiteNoise(scale=0.01)) + assert len(model1.effects) == 1 + assert len(model2.effects) == 1 + + +def test_fixed_population_sample(): + """Test that we can sample and create a StaticSource object.""" + test_rng = np.random.default_rng(100) + population = FixedPopulation(rng=test_rng) + + population.add_source(StaticSource(brightness=0.0), 10.0) + assert np.allclose(population.probs, [1.0]) + assert population.num_sources == 1 + + population.add_source(StaticSource(brightness=1.0), 10.0) + assert np.allclose(population.probs, [0.5, 0.5]) + assert population.num_sources == 2 + + population.add_source(StaticSource(brightness=2.0), 20.0) + assert np.allclose(population.probs, [0.25, 0.25, 0.5]) + assert population.num_sources == 3 + + itr = 10_000 + counts = [0.0, 0.0, 0.0] + for _ in range(itr): + model = population.draw_source() + counts[int(model.brightness)] += 1.0 + assert np.allclose(counts, [0.25 * itr, 0.25 * itr, 0.5 * itr], rtol=0.05) + + # Check the we can change a rate. + population.change_rate(0, 20.0) + assert np.allclose(population.probs, [0.4, 0.2, 0.4]) + + counts = [0.0, 0.0, 0.0] + for _ in range(itr): + model = population.draw_source() + counts[int(model.brightness)] += 1.0 + assert np.allclose(counts, [0.4 * itr, 0.2 * itr, 0.4 * itr], rtol=0.05) From c5ed45f9ac1d9f7354ea83b818e669576496869c Mon Sep 17 00:00:00 2001 From: Jeremy Kubica <104161096+jeremykubica@users.noreply.github.com> Date: Fri, 12 Jul 2024 11:23:47 -0400 Subject: [PATCH 10/31] Use sncosmo.Source instead of sncosmo.Model --- src/tdastro/sources/sncomso_models.py | 39 +++++++++----------- tests/tdastro/sources/test_sncosmo_models.py | 14 +++---- 2 files changed, 23 insertions(+), 30 deletions(-) diff --git a/src/tdastro/sources/sncomso_models.py b/src/tdastro/sources/sncomso_models.py index 14db6715..379de41e 100644 --- a/src/tdastro/sources/sncomso_models.py +++ b/src/tdastro/sources/sncomso_models.py @@ -4,52 +4,47 @@ https://sncosmo.readthedocs.io/en/stable/models.html """ -from sncosmo.models import Model +from sncosmo.models import get_source from tdastro.base_models import PhysicalModel -class SncosmoModel(PhysicalModel): +class SncosmoWrapperModel(PhysicalModel): """A wrapper for sncosmo models. Attributes ---------- - model : `sncosmo.Model` - The underlying model. - model_name : `str` - The name used to set the model. + source : `sncosmo.Source` + The underlying source model. + source_name : `str` + The name used to set the source. Parameters ---------- - model_name : `str` - The name used to set the model. + source_name : `str` + The name used to set the source. **kwargs : `dict`, optional Any additional keyword arguments. """ - def __init__(self, model_name, **kwargs): + def __init__(self, source_name, **kwargs): super().__init__(**kwargs) - self.model_name = model_name - self.model = Model(source=model_name) + self.source_name = source_name + self.source = get_source(source_name) def __str__(self): """Return the string representation of the model.""" - return f"SncosmoModel({self.model_name})" + return f"SncosmoWrapperModel({self.source_name})" @property def param_names(self): """Return a list of the model's parameter names.""" - return self.model.param_names + return self.source.param_names @property def parameters(self): """Return a list of the model's parameter values.""" - return self.model.parameters - - @property - def source(self): - """Return the model's sncosmo source instance.""" - return self.model.source + return self.source.parameters def get(self, name): """Get the value of a specific parameter. @@ -63,7 +58,7 @@ def get(self, name): ------- The parameter value. """ - return self.model.get(name) + return self.source.get(name) def set(self, **kwargs): """Set the parameters of the model. @@ -75,7 +70,7 @@ def set(self, **kwargs): **kwargs : `dict` The parameters to set and their values. """ - self.model.set(**kwargs) + self.source.set(**kwargs) for key, value in kwargs.items(): if hasattr(self, key): self.set_parameter(key, value) @@ -99,4 +94,4 @@ def _evaluate(self, times, wavelengths, **kwargs): flux_density : `numpy.ndarray` A length T x N matrix of SED values. """ - return self.model.flux(times, wavelengths) + return self.source.flux(times, wavelengths) diff --git a/tests/tdastro/sources/test_sncosmo_models.py b/tests/tdastro/sources/test_sncosmo_models.py index b417fa16..58590903 100644 --- a/tests/tdastro/sources/test_sncosmo_models.py +++ b/tests/tdastro/sources/test_sncosmo_models.py @@ -1,18 +1,16 @@ import numpy as np -from tdastro.sources.sncomso_models import SncosmoModel +from tdastro.sources.sncomso_models import SncosmoWrapperModel def test_sncomso_models_hsiao() -> None: """Test that we can create and evalue a 'hsiao' model.""" - model = SncosmoModel("hsiao") - model.set(z=0.5, t0=55000.0, amplitude=1.0e-10) - assert model.z == 0.5 - assert model.t0 == 55000.0 + model = SncosmoWrapperModel("hsiao") + model.set(amplitude=1.0e-10) assert model.amplitude == 1.0e-10 - assert str(model) == "SncosmoModel(hsiao)" + assert str(model) == "SncosmoWrapperModel(hsiao)" - assert np.array_equal(model.param_names, ["z", "t0", "amplitude"]) - assert np.array_equal(model.parameters, [0.5, 55000.0, 1.0e-10]) + assert np.array_equal(model.param_names, ["amplitude"]) + assert np.array_equal(model.parameters, [1.0e-10]) # Test with the example from: https://sncosmo.readthedocs.io/en/stable/models.html fluxes = model.evaluate([54990.0], [4000.0, 4100.0, 4200.0]) From 8961ed01fb9a38c53f55e6cecb6205770e74b4f3 Mon Sep 17 00:00:00 2001 From: Jeremy Kubica <104161096+jeremykubica@users.noreply.github.com> Date: Fri, 12 Jul 2024 13:05:09 -0400 Subject: [PATCH 11/31] Create the ability to sample from a background model --- src/tdastro/base_models.py | 18 ++++- src/tdastro/sources/galaxy_models.py | 89 +++++++++++++++++++++ tests/tdastro/sources/test_galaxy_models.py | 68 ++++++++++++++++ 3 files changed, 174 insertions(+), 1 deletion(-) create mode 100644 src/tdastro/sources/galaxy_models.py create mode 100644 tests/tdastro/sources/test_galaxy_models.py diff --git a/src/tdastro/base_models.py b/src/tdastro/base_models.py index e2fb3e13..5c81be8b 100644 --- a/src/tdastro/base_models.py +++ b/src/tdastro/base_models.py @@ -199,11 +199,13 @@ class PhysicalModel(ParameterizedModel): The object's declination (in degrees) distance : `float` The object's distance (in pc) + background : `PhysicalModel` + A source of background flux such as a host galaxy. effects : `list` A list of effects to apply to an observations. """ - def __init__(self, ra=None, dec=None, distance=None, **kwargs): + def __init__(self, ra=None, dec=None, distance=None, background=None, **kwargs): super().__init__(**kwargs) self.effects = [] @@ -212,6 +214,9 @@ def __init__(self, ra=None, dec=None, distance=None, **kwargs): self.add_parameter("dec", dec) self.add_parameter("distance", distance) + # Background is an object not a sampled parameter + self.background = background + def __str__(self): """Return the string representation of the model.""" return "PhysicalModel" @@ -282,7 +287,15 @@ def evaluate(self, times, wavelengths, resample_parameters=False, **kwargs): if resample_parameters: self.sample_parameters(kwargs) + # Compute the flux density for both the current object and add in anything + # behind it, such as a host galaxy. flux_density = self._evaluate(times, wavelengths, **kwargs) + if self.background is not None: + print("background is not None") + flux_density += self.background._evaluate(times, wavelengths, ra=self.ra, dec=self.dec, **kwargs) + else: + print("background is None") + for effect in self.effects: flux_density = effect.apply(flux_density, wavelengths, self, **kwargs) return flux_density @@ -299,7 +312,10 @@ def sample_parameters(self, include_effects=True, **kwargs): All the keyword arguments, including the values needed to sample parameters. """ + if self.background is not None: + self.background.sample_parameters(include_effects, **kwargs) super().sample_parameters(**kwargs) + if include_effects: for effect in self.effects: effect.sample_parameters(**kwargs) diff --git a/src/tdastro/sources/galaxy_models.py b/src/tdastro/sources/galaxy_models.py new file mode 100644 index 00000000..b4e6a4be --- /dev/null +++ b/src/tdastro/sources/galaxy_models.py @@ -0,0 +1,89 @@ +import numpy as np +from astropy.coordinates import angular_separation + +from tdastro.base_models import PhysicalModel + + +class GaussianGalaxy(PhysicalModel): + """A static source. + + Attributes + ---------- + radius_std : `float` + The standard deviation of the brightness as we move away + from the galaxy's center (in degrees). + brightness : `float` + The inherent brightness at the center of the galaxy. + """ + + def __init__(self, brightness, radius, **kwargs): + super().__init__(**kwargs) + self.add_parameter("galaxy_radius_std", radius, required=True, **kwargs) + self.add_parameter("brightness", brightness, required=True, **kwargs) + + def __str__(self): + """Return the string representation of the model.""" + return f"GuassianGalaxy({self.brightness}, {self.galaxy_radius_std})" + + def sample_ra(self): + """Sample an right ascension coordinate based on the center and radius of the galaxy. + + Returns + ------- + ra : `float` + The sampled right ascension. + """ + return np.random.normal(loc=self.ra, scale=self.galaxy_radius_std) + + def sample_dec(self): + """Sample a declination coordinate based on the center and radius of the galaxy. + + Returns + ------- + dec : `float` + The sampled declination. + """ + return np.random.normal(loc=self.dec, scale=self.galaxy_radius_std) + + def _evaluate(self, times, wavelengths, ra=None, dec=None, **kwargs): + """Draw effect-free observations for this object. + + Parameters + ---------- + times : `numpy.ndarray` + A length T array of timestamps. + wavelengths : `numpy.ndarray`, optional + A length N array of wavelengths. + ra : `float`, optional + The right ascension of the observations. + dec : `float`, optional + The declination of the observations. + **kwargs : `dict`, optional + Any additional keyword arguments. + + Returns + ------- + flux_density : `numpy.ndarray` + A length T x N matrix of SED values. + """ + if ra is None: + ra = self.ra + if dec is None: + dec = self.dec + + print(f"Host: {self.ra}, {self.dec}") + print(f"Query: {ra}, {dec}") + + # Scale the brightness as a Guassian function centered on the object's RA and Dec. + dist = angular_separation( + self.ra * np.pi / 180.0, + self.dec * np.pi / 180.0, + ra * np.pi / 180.0, + dec * np.pi / 180.0, + ) + print(f"Dist = {dist}") + + scale = np.exp(-(dist * dist) / (2.0 * self.galaxy_radius_std * self.galaxy_radius_std)) + print(f"Scale = {scale}") + + return np.full((len(times), len(wavelengths)), self.brightness * scale) diff --git a/tests/tdastro/sources/test_galaxy_models.py b/tests/tdastro/sources/test_galaxy_models.py new file mode 100644 index 00000000..e127494d --- /dev/null +++ b/tests/tdastro/sources/test_galaxy_models.py @@ -0,0 +1,68 @@ +import random + +import numpy as np +from tdastro.sources.galaxy_models import GaussianGalaxy +from tdastro.sources.static_source import StaticSource + + +def _sample_ra(**kwargs): + """Return a random value between 0 and 360. + + Parameters + ---------- + **kwargs : `dict`, optional + Absorbs additional parameters + """ + return 360.0 * random.random() + + +def _sample_dec(**kwargs): + """Return a random value between -90 and 90. + + Parameters + ---------- + **kwargs : `dict`, optional + Absorbs additional parameters + """ + return 180.0 * random.random() - 90.0 + + +def test_gaussian_galaxy() -> None: + """Test that we can sample and create a StaticSource object.""" + random.seed(1001) + + host = GaussianGalaxy(ra=_sample_ra, dec=_sample_dec, brightness=10.0, radius=1.0 / 3600.0) + host_ra = host.ra + host_dec = host.dec + + source = StaticSource(ra=host.sample_ra, dec=host.sample_dec, background=host, brightness=100.0) + + # Both RA and dec should be "close" to (but not exactly at) the center of the galaxy. + source_ra_offset = source.ra - host_ra + assert 0.0 < np.abs(source_ra_offset) < 100.0 / 3600.0 + + source_dec_offset = source.dec - host_dec + assert 0.0 < np.abs(source_dec_offset) < 100.0 / 3600.0 + + times = np.array([1, 2, 3, 4, 5, 10]) + wavelengths = np.array([100.0, 200.0, 300.0]) + + # All the measured fluxes should have some contribution from the background object. + values = source.evaluate(times, wavelengths) + assert values.shape == (6, 3) + assert np.all(values > 100.0) + assert np.all(values <= 110.0) + + # Check that if we resample the source it will propagate and correctly resample the host. + # the host's (RA, dec) should change and the source's should still be close. + source.sample_parameters() + assert host_ra != host.ra + assert host_dec != host.dec + + source_ra_offset2 = source.ra - host.ra + assert source_ra_offset != source_ra_offset2 + assert 0.0 < np.abs(source_ra_offset2) < 100.0 / 3600.0 + + source_dec_offset2 = source.dec - host.dec + assert source_dec_offset != source_dec_offset2 + assert 0.0 < np.abs(source_ra_offset2) < 100.0 / 3600.0 From e7133d62456fceec1b4cc65e03c1af5a9fdeca24 Mon Sep 17 00:00:00 2001 From: Jeremy Kubica <104161096+jeremykubica@users.noreply.github.com> Date: Fri, 12 Jul 2024 13:21:25 -0400 Subject: [PATCH 12/31] Update base_models.py --- src/tdastro/base_models.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/tdastro/base_models.py b/src/tdastro/base_models.py index 5c81be8b..6dda7d6e 100644 --- a/src/tdastro/base_models.py +++ b/src/tdastro/base_models.py @@ -291,10 +291,7 @@ def evaluate(self, times, wavelengths, resample_parameters=False, **kwargs): # behind it, such as a host galaxy. flux_density = self._evaluate(times, wavelengths, **kwargs) if self.background is not None: - print("background is not None") flux_density += self.background._evaluate(times, wavelengths, ra=self.ra, dec=self.dec, **kwargs) - else: - print("background is None") for effect in self.effects: flux_density = effect.apply(flux_density, wavelengths, self, **kwargs) From 9d2f822e2e85ae63e87c1e5e7e55c76d07cbaa3a Mon Sep 17 00:00:00 2001 From: Jeremy Kubica <104161096+jeremykubica@users.noreply.github.com> Date: Fri, 12 Jul 2024 14:36:55 -0400 Subject: [PATCH 13/31] Address PR comments - Added an end to end test to show usage. - Moved from numpy.random to the built in random - Added duplicate testing for effects --- src/tdastro/base_models.py | 51 ++++++++------ src/tdastro/populations/fixed_population.py | 25 ++----- .../populations/test_fixed_population.py | 68 ++++++++++++++++--- 3 files changed, 96 insertions(+), 48 deletions(-) diff --git a/src/tdastro/base_models.py b/src/tdastro/base_models.py index 30bffcb2..a43319d2 100644 --- a/src/tdastro/base_models.py +++ b/src/tdastro/base_models.py @@ -2,7 +2,6 @@ import types from enum import Enum -from os import urandom import numpy as np @@ -221,7 +220,7 @@ def __str__(self): """Return the string representation of the model.""" return "PhysicalModel" - def add_effect(self, effect, **kwargs): + def add_effect(self, effect, allow_dups=True, **kwargs): """Add a transformational effect to the PhysicalModel. Effects are applied in the order in which they are added. @@ -229,6 +228,9 @@ def add_effect(self, effect, **kwargs): ---------- effect : `EffectModel` The effect to apply. + allow_dups : `bool` + Allow multiple effects of the same type. + Default = ``True`` **kwargs : `dict`, optional Any additional keyword arguments. @@ -237,6 +239,13 @@ def add_effect(self, effect, **kwargs): Raises a ``AttributeError`` if the PhysicalModel does not have all of the required attributes. """ + # Check that we have not added this effect before. + if not allow_dups: + effect_type = type(effect) + for prev in self.effects: + if effect_type == type(prev): + raise ValueError("Added the effect type to a model {effect_type} more than once.") + required: list = effect.required_parameters() for parameter in required: # Raise an AttributeError if the parameter is missing or set to None. @@ -363,25 +372,12 @@ class PopulationModel(ParameterizedModel): The number of different sources in the population. sources : `list` A list of sources from which to draw. - _rng : `numpy.random._generator.Generator` - A random number generator to use. - - Parameters - ---------- - rng : `numpy.random._generator.Generator`, optional - A random number generator to use for sampling. If not provided, - will create a new one with a randomized seed. """ def __init__(self, rng=None, **kwargs): super().__init__(**kwargs) self.num_sources = 0 self.sources = [] - if rng is None: - seed = int.from_bytes(urandom(4), "big") - self._rng = np.random.default_rng(seed) - else: - self._rng = rng def __str__(self): """Return the string representation of the model.""" @@ -412,7 +408,7 @@ def draw_source(self): """ raise NotImplementedError() - def add_effect(self, effect, **kwargs): + def add_effect(self, effect, allow_dups=False, **kwargs): """Add a transformational effect to all PhysicalModels in this population. Effects are applied in the order in which they are added. @@ -420,6 +416,9 @@ def add_effect(self, effect, **kwargs): ---------- effect : `EffectModel` The effect to apply. + allow_dups : `bool` + Allow multiple effects of the same type. + Default = ``True`` **kwargs : `dict`, optional Any additional keyword arguments. @@ -429,13 +428,15 @@ def add_effect(self, effect, **kwargs): required attributes. """ for source in self.sources: - source.add_effect(effect, **kwargs) + source.add_effect(effect, allow_dups=allow_dups, **kwargs) - def evaluate(self, times, wavelengths, resample_parameters=False, **kwargs): + def evaluate(self, samples, times, wavelengths, resample_parameters=False, **kwargs): """Draw observations from a single (randomly sampled) source. Parameters ---------- + samples : `int` + The number of sources to samples. times : `numpy.ndarray` A length T array of timestamps. wavelengths : `numpy.ndarray`, optional @@ -448,7 +449,15 @@ def evaluate(self, times, wavelengths, resample_parameters=False, **kwargs): Returns ------- - flux_density : `numpy.ndarray` - A length T x N matrix of SED values. + results : `numpy.ndarray` + A shape (samples, T, N) matrix of SED values. """ - return self.draw_source().evalute(times, wavelengths, resample_parameters, **kwargs) + if samples <= 0: + raise ValueError("The number of samples must be > 0.") + + results = [] + for _ in range(samples): + source = self.draw_source() + object_fluxes = source.evaluate(times, wavelengths, resample_parameters, **kwargs) + results.append(object_fluxes) + return np.array(results) diff --git a/src/tdastro/populations/fixed_population.py b/src/tdastro/populations/fixed_population.py index 06838062..bc644b4a 100644 --- a/src/tdastro/populations/fixed_population.py +++ b/src/tdastro/populations/fixed_population.py @@ -1,4 +1,4 @@ -import numpy as np +import random from tdastro.base_models import PopulationModel @@ -8,9 +8,7 @@ class FixedPopulation(PopulationModel): Attributes ---------- - probs : `numpy.ndarray` - The probability of drawing each type of source. - _raw_rates : `numpy.ndarray` + weights : `numpy.ndarray` An array of floats that provides the base sampling rate for each type. This is normalized into a probability distributions so [100, 200, 200] -> [0.2, 0.4, 0.4]. @@ -18,17 +16,12 @@ class FixedPopulation(PopulationModel): def __init__(self, **kwargs): super().__init__(**kwargs) - self.probs = np.array([]) - self._raw_rates = np.array([]) + self.weights = [] def __str__(self): """Return the string representation of the model.""" return f"FixedPopulation({self.probability})" - def _update_probabilities(self): - """Update the probability array.""" - self.probs = self._raw_rates / np.sum(self._raw_rates) - def add_source(self, new_source, rate, **kwargs): """Add a new source to the population. @@ -48,12 +41,10 @@ def add_source(self, new_source, rate, **kwargs): if rate <= 0.0: raise ValueError(f"Expected positive rate. Found {rate}.") super().add_source(new_source, **kwargs) - - self._raw_rates = np.append(self._raw_rates, rate) - self._update_probabilities() + self.weights.append(rate) def change_rate(self, source_index, rate, **kwargs): - """Add a new source to the population. + """Change rate of a source. Parameters ---------- @@ -70,8 +61,7 @@ def change_rate(self, source_index, rate, **kwargs): """ if rate <= 0.0: raise ValueError(f"Expected positive rate. Found {rate}.") - self._raw_rates[source_index] = rate - self._update_probabilities() + self.weights[source_index] = rate def draw_source(self): """Sample a single source from the population. @@ -81,5 +71,4 @@ def draw_source(self): source : `PhysicalModel` A source from the population. """ - index = self._rng.choice(np.arange(0, self.num_sources), p=self.probs) - return self.sources[index] + return random.choices(self.sources, weights=self.weights)[0] diff --git a/tests/tdastro/populations/test_fixed_population.py b/tests/tdastro/populations/test_fixed_population.py index 8989c921..134ed459 100644 --- a/tests/tdastro/populations/test_fixed_population.py +++ b/tests/tdastro/populations/test_fixed_population.py @@ -1,3 +1,5 @@ +import random + import numpy as np import pytest from tdastro.effects.white_noise import WhiteNoise @@ -10,7 +12,7 @@ def test_fixed_population_basic_add(): population = FixedPopulation() population.add_source(StaticSource(brightness=10.0), 0.5) assert population.num_sources == 1 - assert np.allclose(population.probs, [1.0]) + assert np.allclose(population.weights, [0.5]) # Test that we fail with a bad rate. with pytest.raises(ValueError): @@ -35,21 +37,34 @@ def test_fixed_population_add_effect(): assert len(model2.effects) == 1 -def test_fixed_population_sample(): - """Test that we can sample and create a StaticSource object.""" - test_rng = np.random.default_rng(100) - population = FixedPopulation(rng=test_rng) +def test_fixed_population_add_effect_fail(): + """Test a case where we try to add an existing effect to models.""" + model1 = StaticSource(brightness=10.0) + model1.add_effect(WhiteNoise(scale=0.01)) + + population = FixedPopulation() + population.add_source(model1, 0.5) + + # Fail when we try to re-add the WhiteNoise effect + with pytest.raises(ValueError): + population.add_effect(WhiteNoise(scale=0.01)) + + +def test_fixed_population_sample_sources(): + """Test that we can create a population of sources and sample its sources.""" + random.seed(1000) + population = FixedPopulation() population.add_source(StaticSource(brightness=0.0), 10.0) - assert np.allclose(population.probs, [1.0]) + assert np.allclose(population.weights, [10.0]) assert population.num_sources == 1 population.add_source(StaticSource(brightness=1.0), 10.0) - assert np.allclose(population.probs, [0.5, 0.5]) + assert np.allclose(population.weights, [10.0, 10.0]) assert population.num_sources == 2 population.add_source(StaticSource(brightness=2.0), 20.0) - assert np.allclose(population.probs, [0.25, 0.25, 0.5]) + assert np.allclose(population.weights, [10.0, 10.0, 20.0]) assert population.num_sources == 3 itr = 10_000 @@ -61,10 +76,45 @@ def test_fixed_population_sample(): # Check the we can change a rate. population.change_rate(0, 20.0) - assert np.allclose(population.probs, [0.4, 0.2, 0.4]) counts = [0.0, 0.0, 0.0] for _ in range(itr): model = population.draw_source() counts[int(model.brightness)] += 1.0 assert np.allclose(counts, [0.4 * itr, 0.2 * itr, 0.4 * itr], rtol=0.05) + + +def _random_brightness(): + """Returns a random value [0.0, 100.0]""" + return 100.0 * random.random() + + +def test_fixed_population_sample_fluxes(): + """Test that we can create a population of sources and sample its sources' flux.""" + random.seed(1001) + population = FixedPopulation() + population.add_source(StaticSource(brightness=100.0), 10.0) + population.add_source(StaticSource(brightness=_random_brightness), 10.0) + population.add_source(StaticSource(brightness=200.0), 20.0) + population.add_source(StaticSource(brightness=150.0), 10.0) + + # Sample the actual observations, resampling the corresponding + # model's parameters each time. + num_samples = 10_000 + times = np.array([1, 2, 3, 4, 5]) + wavelengths = np.array([100.0, 200.0, 300.0]) + fluxes = population.evaluate(num_samples, times, wavelengths, resample_parameters=True) + + # Check that the fluxes are constant within a sample. Also check that we have + # More than 4 values (since we are resampling a model with a random parameter). + seen_values = [] + for i in range(num_samples): + value = fluxes[i, 0, 0] + seen_values.append(value) + assert np.allclose(fluxes[i], value) + assert len(np.unique(seen_values)) > 4 + + # Check that our average is near the expected value of flux. + ave_val = np.mean(fluxes.flatten()) + expected = 0.2 * 100.0 + 0.2 * 50.0 + 0.4 * 200.0 + 0.2 * 150.0 + assert abs(ave_val - expected) < 10.0 From a4fdd19af724a6bc9021be7863cfc8df8a70139b Mon Sep 17 00:00:00 2001 From: Jeremy Kubica <104161096+jeremykubica@users.noreply.github.com> Date: Fri, 12 Jul 2024 15:15:16 -0400 Subject: [PATCH 14/31] Address PR comments --- src/tdastro/base_models.py | 10 +++++++--- src/tdastro/sources/galaxy_models.py | 14 ++++---------- 2 files changed, 11 insertions(+), 13 deletions(-) diff --git a/src/tdastro/base_models.py b/src/tdastro/base_models.py index 6dda7d6e..205d0755 100644 --- a/src/tdastro/base_models.py +++ b/src/tdastro/base_models.py @@ -187,9 +187,13 @@ def sample_parameters(self, max_depth=50, **kwargs): class PhysicalModel(ParameterizedModel): - """A physical model of a source of flux. Physical models can have fixed attributes - (where you need to create a new model to change them) and settable attributes that - can be passed functions or constants. + """A physical model of a source of flux. + + Physical models can have fixed attributes (where you need to create a new model + to change them) and settable attributes that can be passed functions or constants. + They can also have special background pointers that link to another PhysicalModel + producing flux. We can chain these to have a supernova in front of a star in front + of a static background. Attributes ---------- diff --git a/src/tdastro/sources/galaxy_models.py b/src/tdastro/sources/galaxy_models.py index b4e6a4be..9aea8bfc 100644 --- a/src/tdastro/sources/galaxy_models.py +++ b/src/tdastro/sources/galaxy_models.py @@ -31,7 +31,7 @@ def sample_ra(self): Returns ------- ra : `float` - The sampled right ascension. + The sampled right ascension in degrees. """ return np.random.normal(loc=self.ra, scale=self.galaxy_radius_std) @@ -41,7 +41,7 @@ def sample_dec(self): Returns ------- dec : `float` - The sampled declination. + The sampled declination in degrees. """ return np.random.normal(loc=self.dec, scale=self.galaxy_radius_std) @@ -55,9 +55,9 @@ def _evaluate(self, times, wavelengths, ra=None, dec=None, **kwargs): wavelengths : `numpy.ndarray`, optional A length N array of wavelengths. ra : `float`, optional - The right ascension of the observations. + The right ascension of the observations in degrees. dec : `float`, optional - The declination of the observations. + The declination of the observations in degrees. **kwargs : `dict`, optional Any additional keyword arguments. @@ -71,9 +71,6 @@ def _evaluate(self, times, wavelengths, ra=None, dec=None, **kwargs): if dec is None: dec = self.dec - print(f"Host: {self.ra}, {self.dec}") - print(f"Query: {ra}, {dec}") - # Scale the brightness as a Guassian function centered on the object's RA and Dec. dist = angular_separation( self.ra * np.pi / 180.0, @@ -81,9 +78,6 @@ def _evaluate(self, times, wavelengths, ra=None, dec=None, **kwargs): ra * np.pi / 180.0, dec * np.pi / 180.0, ) - print(f"Dist = {dist}") - scale = np.exp(-(dist * dist) / (2.0 * self.galaxy_radius_std * self.galaxy_radius_std)) - print(f"Scale = {scale}") return np.full((len(times), len(wavelengths)), self.brightness * scale) From 4982de5196636485c24edb166aaf49badd68062a Mon Sep 17 00:00:00 2001 From: Jeremy Kubica <104161096+jeremykubica@users.noreply.github.com> Date: Sat, 13 Jul 2024 15:48:46 -0400 Subject: [PATCH 15/31] Add basic wrapped function --- src/tdastro/base_models.py | 58 ++++++++++--------- src/tdastro/function_wrappers.py | 69 +++++++++++++++++++++++ tests/tdastro/test_function_wrappers.py | 74 +++++++++++++++++++++++++ 3 files changed, 175 insertions(+), 26 deletions(-) create mode 100644 src/tdastro/function_wrappers.py create mode 100644 tests/tdastro/test_function_wrappers.py diff --git a/src/tdastro/base_models.py b/src/tdastro/base_models.py index a43319d2..d5b38b51 100644 --- a/src/tdastro/base_models.py +++ b/src/tdastro/base_models.py @@ -5,6 +5,8 @@ import numpy as np +from tdastro.function_wrappers import TDFunc + class ParameterSource(Enum): """ParameterSource specifies where a PhysicalModel should get the value @@ -14,8 +16,9 @@ class ParameterSource(Enum): CONSTANT = 1 FUNCTION = 2 - MODEL_ATTRIBUTE = 3 - MODEL_METHOD = 4 + TDFUNC_OBJ = 3 + MODEL_ATTRIBUTE = 4 + MODEL_METHOD = 5 class ParameterizedModel: @@ -75,31 +78,32 @@ def set_parameter(self, name, value=None, **kwargs): # The value wasn't set, but the name is in kwargs. value = kwargs[name] - if value is not None: - if isinstance(value, types.FunctionType): - # Case 1: If we are getting from a static function, sample it. - self.setters[name] = (ParameterSource.FUNCTION, value, required) - setattr(self, name, value(**kwargs)) - elif isinstance(value, types.MethodType) and isinstance(value.__self__, ParameterizedModel): - # Case 2: We are trying to use the method from a ParameterizedModel. - # Note that this will (correctly) fail if we are adding a model method from the current - # object that requires an unset attribute. - self.setters[name] = (ParameterSource.MODEL_METHOD, value, required) - setattr(self, name, value(**kwargs)) - elif isinstance(value, ParameterizedModel): - # Case 3: We are trying to access an attribute from a parameterized model. - if not hasattr(value, name): - raise ValueError(f"Attribute {name} missing from parent.") - self.setters[name] = (ParameterSource.MODEL_ATTRIBUTE, value, required) - setattr(self, name, getattr(value, name)) - else: - # Case 4: The value is constant. - self.setters[name] = (ParameterSource.CONSTANT, value, required) - setattr(self, name, value) - elif not required: - self.setters[name] = (ParameterSource.CONSTANT, None, required) - setattr(self, name, None) + if isinstance(value, types.FunctionType): + # Case 1a: If we are getting from a static function, sample it. + self.setters[name] = (ParameterSource.FUNCTION, value, required) + setattr(self, name, value(**kwargs)) + elif isinstance(value, TDFunc): + # Case 1b: We are using a TDFunc wrapped function (with default parameters). + self.setters[name] = (ParameterSource.TDFUNC_OBJ, value, required) + setattr(self, name, value(self, **kwargs)) + elif isinstance(value, types.MethodType) and isinstance(value.__self__, ParameterizedModel): + # Case 2: We are trying to use the method from a ParameterizedModel. + # Note that this will (correctly) fail if we are adding a model method from the current + # object that requires an unset attribute. + self.setters[name] = (ParameterSource.MODEL_METHOD, value, required) + setattr(self, name, value(**kwargs)) + elif isinstance(value, ParameterizedModel): + # Case 3: We are trying to access an attribute from a parameterized model. + if not hasattr(value, name): + raise ValueError(f"Attribute {name} missing from parent.") + self.setters[name] = (ParameterSource.MODEL_ATTRIBUTE, value, required) + setattr(self, name, getattr(value, name)) else: + # Case 4: The value is constant (including None). + self.setters[name] = (ParameterSource.CONSTANT, value, required) + setattr(self, name, value) + + if required and getattr(self, name) is None: raise ValueError(f"Missing required parameter {name}") def add_parameter(self, name, value=None, required=False, **kwargs): @@ -169,6 +173,8 @@ def sample_parameters(self, max_depth=50, **kwargs): sampled_value = setter elif source_type == ParameterSource.FUNCTION: sampled_value = setter(**kwargs) + elif source_type == ParameterSource.TDFUNC_OBJ: + sampled_value = setter(self, **kwargs) elif source_type == ParameterSource.MODEL_ATTRIBUTE: # Check if we need to resample the parent (needs to be done before # we read its attribute). diff --git a/src/tdastro/function_wrappers.py b/src/tdastro/function_wrappers.py new file mode 100644 index 00000000..03f4ef0d --- /dev/null +++ b/src/tdastro/function_wrappers.py @@ -0,0 +1,69 @@ +"""Classes to wrap functions to allow users to pass around functions +with partially specified arguments. +""" + +import copy + + +class TDFunc: + """A class to wrap functions to pass around functions with default + argument settings, arguments from kwargs, and (optionally) + arguments that are from the fcalling object. + + The object stores default arguments for the function, which does not + need to include all the function's parameters. Any parameters not + included in the default list must be specified as part of the kwargs. + + Attributes + ---------- + func : `function` or `method` + The function to call during an evaluation. + default_args : `dict` + A dictionary of default arguments to pass to the function. + This does not need to include all the arguments. + object_args : `list`, optional + Arguments that are provided by attributes of the calling object. + """ + + def __init__(self, func, default_args=None, object_args=None, **kwargs): + self.func = func + self.object_args = object_args + self.default_args = {} + + # The default arguments are the union of the default_args parameter + # and the remaining kwargs. + if default_args is not None: + self.default_args = default_args + if kwargs: + for key, value in kwargs.items(): + self.default_args[key] = value + + def __str__(self): + """Return the string representation of the function.""" + return f"TDFunc({self.func.name})" + + def __call__(self, calling_object=None, **kwargs): + """Execute the wrapped function. + + Parameters + ---------- + calling_object : any, optional + The object that called the function. + **kwargs : `dict`, optional + Additional function arguments. + """ + # Start with the default arguments. We make a copy so we can modify the dictionary. + args = copy.copy(self.default_args) + + # If there are arguments to get from the calling object, set those. + if self.object_args is not None and len(self.object_args) > 0: + if calling_object is None: + raise ValueError(f"Calling object needed for parameters: {self.object_args}") + for arg_name in self.object_args: + args[arg_name] = getattr(calling_object, arg_name) + + # Set any last arguments from the kwargs (overwriting previous settings). + args.update(kwargs) + + # Call the function with all the parameters. + return self.func(**args) diff --git a/tests/tdastro/test_function_wrappers.py b/tests/tdastro/test_function_wrappers.py new file mode 100644 index 00000000..40ea9e28 --- /dev/null +++ b/tests/tdastro/test_function_wrappers.py @@ -0,0 +1,74 @@ +import pytest +from tdastro.function_wrappers import TDFunc + + +def _test_func(a, b): + """Return the sum of the two parameters. + + Parameters + ---------- + a : `float` + The first parameter. + b : `float` + The second parameter. + """ + return a + b + + +class _StaticModel: + """A test model that has given parameters. + + Attributes + ---------- + a : `float` + The first parameter. + b : `float` + The second parameter. + """ + + def __init__(self, a, b): + self.a = a + self.b = b + + def eval_func(self, func, **kwargs): + """Evaluate a TDFunc. + + Parameters + ---------- + func : `TDFunc` + The function to evaluate. + **kwargs : `dict`, optional + Any additional keyword arguments. + """ + return func(self, **kwargs) + + +def test_tdfunc_basic(): + """Test that we can create and query a TDFunc.""" + tdf1 = TDFunc(_test_func, a=1.0) + + # Fail without enough arguments (only a is specified). + with pytest.raises(TypeError): + _ = tdf1() + + # We succeed with a manually specified parameter. + assert tdf1(b=2.0) == 3.0 + + # We can overwrite parameters. + assert tdf1(a=3.0, b=2.0) == 5.0 + + +def test_tdfunc_obj(): + """Test that we can create and query a TDFunc that depends on an object.""" + model = _StaticModel(a=10.0, b=11.0) + + # Check a function without defaults. + tdf1 = TDFunc(_test_func, object_args=["a", "b"]) + assert model.eval_func(tdf1) == 21.0 + + # Defaults set are overwritten by the object. + tdf2 = TDFunc(_test_func, object_args=["a", "b"], a=1.0, b=0.0) + assert model.eval_func(tdf2) == 21.0 + + # But we can overwrite everything with kwargs. + assert model.eval_func(tdf2, b=7.5) == 17.5 From a97242c7c6964174caf9c241444cae68fb3ba1b2 Mon Sep 17 00:00:00 2001 From: Jeremy Kubica <104161096+jeremykubica@users.noreply.github.com> Date: Sun, 14 Jul 2024 09:30:06 -0400 Subject: [PATCH 16/31] Additional tests --- tests/tdastro/sources/test_step_source.py | 19 ++++++++++++------- tests/tdastro/test_function_wrappers.py | 4 ++++ 2 files changed, 16 insertions(+), 7 deletions(-) diff --git a/tests/tdastro/sources/test_step_source.py b/tests/tdastro/sources/test_step_source.py index dc069241..325e73f2 100644 --- a/tests/tdastro/sources/test_step_source.py +++ b/tests/tdastro/sources/test_step_source.py @@ -1,6 +1,7 @@ import random import numpy as np +from tdastro.function_wrappers import TDFunc from tdastro.sources.static_source import StaticSource from tdastro.sources.step_source import StepSource @@ -18,7 +19,7 @@ def _sample_brightness(magnitude, **kwargs): return magnitude * random.random() -def _sample_end(duration, **kwargs): +def _sample_end(duration): """Return a random value between 1 and 1 + duration Parameters @@ -54,20 +55,20 @@ def test_step_source() -> None: def test_step_source_resample() -> None: """Check that we can call resample on the model parameters.""" + random.seed(1111) + model = StepSource( - brightness=_sample_brightness, + brightness=TDFunc(_sample_brightness, magnitude=100.0), t0=0.0, - t1=_sample_end, - magnitude=100.0, - duration=5.0, + t1=TDFunc(_sample_end, duration=5.0), ) - num_samples = 100 + num_samples = 1000 brightness_vals = np.zeros((num_samples, 1)) t_end_vals = np.zeros((num_samples, 1)) t_start_vals = np.zeros((num_samples, 1)) for i in range(num_samples): - model.sample_parameters(magnitude=100.0, duration=5.0) + model.sample_parameters() brightness_vals[i] = model.brightness t_end_vals[i] = model.t1 t_start_vals[i] = model.t0 @@ -79,6 +80,10 @@ def test_step_source_resample() -> None: assert np.all(t_end_vals >= 1.0) assert np.all(t_end_vals <= 6.0) + # Check that the expected values are close. + assert abs(np.mean(brightness_vals) - 50.0) < 5.0 + assert abs(np.mean(t_end_vals) - 3.5) < 0.5 + # Check that the brightness or end values are not all the same. assert not np.all(brightness_vals == brightness_vals[0]) assert not np.all(t_end_vals == t_end_vals[0]) diff --git a/tests/tdastro/test_function_wrappers.py b/tests/tdastro/test_function_wrappers.py index 40ea9e28..13ddaa3e 100644 --- a/tests/tdastro/test_function_wrappers.py +++ b/tests/tdastro/test_function_wrappers.py @@ -57,6 +57,10 @@ def test_tdfunc_basic(): # We can overwrite parameters. assert tdf1(a=3.0, b=2.0) == 5.0 + # That we can use a different ordering for parameters. + tdf2 = TDFunc(_test_func, b=2.0, a=1.0) + assert tdf2() == 3.0 + def test_tdfunc_obj(): """Test that we can create and query a TDFunc that depends on an object.""" From 897e9d625aaf7f706b74f273b59bd5018dcae557 Mon Sep 17 00:00:00 2001 From: Jeremy Kubica <104161096+jeremykubica@users.noreply.github.com> Date: Sun, 14 Jul 2024 10:24:30 -0400 Subject: [PATCH 17/31] Add basic cosmology --- src/tdastro/astro_utils/cosmology.py | 59 +++++++++++++++++++++ src/tdastro/base_models.py | 22 ++++++-- tests/tdastro/astro_utils/test_cosmology.py | 10 ++++ tests/tdastro/test_base_models.py | 31 ++++++++++- 4 files changed, 117 insertions(+), 5 deletions(-) create mode 100644 src/tdastro/astro_utils/cosmology.py create mode 100644 tests/tdastro/astro_utils/test_cosmology.py diff --git a/src/tdastro/astro_utils/cosmology.py b/src/tdastro/astro_utils/cosmology.py new file mode 100644 index 00000000..f7fbb6ba --- /dev/null +++ b/src/tdastro/astro_utils/cosmology.py @@ -0,0 +1,59 @@ +import astropy.cosmology.units as cu +from astropy import units as u + +from tdastro.function_wrappers import TDFunc + + +def redshift_to_distance(redshift, cosmology, kind="comoving"): + """Compute a source's distance given its redshift and a + specified cosmology using astropy's redshift_distance(). + + Parameters + ---------- + redshift : `float` + The redshift value. + cosmology : `astropy.cosmology` + The cosmology specification. + kind : `str` + The distance type for the Equivalency as defined by + astropy.cosmology.units.redshift_distance. + + Returns + ------- + distance : `float` + The distance (in pc) + """ + z = redshift * cu.redshift + distance = z.to(u.pc, cu.redshift_distance(cosmology, kind=kind)) + return distance.value + + +class RedshiftDistFunc(TDFunc): + """A wrapper class for the redshift_to_distance() function. + + Attributes + ---------- + cosmology : `astropy.cosmology` + The cosmology specification. + kind : `str` + The distance type for the Equivalency as defined by + astropy.cosmology.units.redshift_distance. + """ + + def __init__(self, cosmology, kind="comoving", **kwargs): + self.cosmology = cosmology + self.kind = kind + default_args = {"cosmology": cosmology, "kind": kind} + + # Call the super class's constructor with the needed information. + # Do not pass kwargs because we are limiting the arguments to match + # the signature of redshift_to_distance(). + super().__init__( + func=redshift_to_distance, + default_args=default_args, + object_args=["redshift"], + ) + + def __str__(self): + """Return the string representation of the function.""" + return f"RedshiftDistFunc({self.cosmology.name}, {self.kind})" diff --git a/src/tdastro/base_models.py b/src/tdastro/base_models.py index d5b38b51..86b24493 100644 --- a/src/tdastro/base_models.py +++ b/src/tdastro/base_models.py @@ -5,6 +5,7 @@ import numpy as np +from tdastro.astro_utils.cosmology import RedshiftDistFunc from tdastro.function_wrappers import TDFunc @@ -207,20 +208,33 @@ class PhysicalModel(ParameterizedModel): The object's right ascension (in degrees) dec : `float` The object's declination (in degrees) + redshift : `float` + The object's redshift. distance : `float` - The object's distance (in pc) + The object's distance (in pc). If the distance is not provided and + a ``cosmology`` parameter is given, the model will try to derive from + the redshift and the cosmology. effects : `list` A list of effects to apply to an observations. """ - def __init__(self, ra=None, dec=None, distance=None, **kwargs): + def __init__(self, ra=None, dec=None, redshift=None, distance=None, **kwargs): super().__init__(**kwargs) self.effects = [] - # Set RA, dec, and distance from the parameters. + # Set RA, dec, and redshift from the parameters. self.add_parameter("ra", ra) self.add_parameter("dec", dec) - self.add_parameter("distance", distance) + self.add_parameter("redshift", redshift) + + # If the distance is provided, use that. Otherwise try the redshift value + # using the cosmology (if given). Finally, default to None. + if distance is not None: + self.add_parameter("distance", distance) + elif redshift is not None and kwargs.get("cosmology", None) is not None: + self.add_parameter("distance", RedshiftDistFunc(**kwargs)) + else: + self.add_parameter("distance", None) def __str__(self): """Return the string representation of the model.""" diff --git a/tests/tdastro/astro_utils/test_cosmology.py b/tests/tdastro/astro_utils/test_cosmology.py new file mode 100644 index 00000000..14a500e2 --- /dev/null +++ b/tests/tdastro/astro_utils/test_cosmology.py @@ -0,0 +1,10 @@ +import pytest +from astropy.cosmology import WMAP9 +from tdastro.astro_utils.cosmology import redshift_to_distance + + +def test_redshift_to_distance(): + """Test that we can convert the redshift to a distance using a given cosmology.""" + # Use the example from: + # https://docs.astropy.org/en/stable/api/astropy.cosmology.units.redshift_distance.html + assert redshift_to_distance(1100, cosmology=WMAP9) == pytest.approx(14004.03157418 * 1e6) diff --git a/tests/tdastro/test_base_models.py b/tests/tdastro/test_base_models.py index b5ff1daa..6847cffa 100644 --- a/tests/tdastro/test_base_models.py +++ b/tests/tdastro/test_base_models.py @@ -1,7 +1,8 @@ import random import pytest -from tdastro.base_models import ParameterizedModel +from astropy.cosmology import WMAP9 +from tdastro.base_models import ParameterizedModel, PhysicalModel def _sampler_fun(**kwargs): @@ -141,3 +142,31 @@ def test_parameterized_model_modify() -> None: # We cannot set a value that hasn't been added. with pytest.raises(KeyError): model.set_parameter("brightness", 5.0) + + +def test_physical_model(): + """Test that we can create a physical model.""" + # Everything is specified. + model1 = PhysicalModel(ra=1.0, dec=2.0, distance=3.0, redshift=0.0) + assert model1.ra == 1.0 + assert model1.dec == 2.0 + assert model1.distance == 3.0 + assert model1.redshift == 0.0 + + # Derive the distance from the redshift using the example from: + # https://docs.astropy.org/en/stable/api/astropy.cosmology.units.redshift_distance.html + model2 = PhysicalModel(ra=1.0, dec=2.0, redshift=1100.0, cosmology=WMAP9) + assert model2.ra == 1.0 + assert model2.dec == 2.0 + assert model2.redshift == 1100.0 + assert model2.distance == pytest.approx(14004.03157418 * 1e6) + + # Neither distance nor redshift are specified. + model3 = PhysicalModel(ra=1.0, dec=2.0) + assert model3.redshift is None + assert model3.distance is None + + # Redshift is specified but cosmology is not. + model4 = PhysicalModel(ra=1.0, dec=2.0, redshift=1100.0) + assert model4.redshift == 1100.0 + assert model4.distance is None From 748447cd6a28f0f3545f0d96e1d0ceaf5ed9073f Mon Sep 17 00:00:00 2001 From: Jeremy Kubica <104161096+jeremykubica@users.noreply.github.com> Date: Sun, 14 Jul 2024 11:11:52 -0400 Subject: [PATCH 18/31] Add more comments and tests --- src/tdastro/function_wrappers.py | 41 ++++++++++++++++++------- tests/tdastro/test_function_wrappers.py | 38 ++++++++++++++++++----- 2 files changed, 61 insertions(+), 18 deletions(-) diff --git a/src/tdastro/function_wrappers.py b/src/tdastro/function_wrappers.py index 03f4ef0d..ed24aa3f 100644 --- a/src/tdastro/function_wrappers.py +++ b/src/tdastro/function_wrappers.py @@ -8,21 +8,34 @@ class TDFunc: """A class to wrap functions to pass around functions with default argument settings, arguments from kwargs, and (optionally) - arguments that are from the fcalling object. - - The object stores default arguments for the function, which does not - need to include all the function's parameters. Any parameters not - included in the default list must be specified as part of the kwargs. + arguments that are from the calling object. Attributes ---------- func : `function` or `method` The function to call during an evaluation. default_args : `dict` - A dictionary of default arguments to pass to the function. - This does not need to include all the arguments. - object_args : `list`, optional + A dictionary of default arguments to pass to the function. Assembled + from the ```default_args`` parameter and additional ``kwargs``. + object_args : `list` Arguments that are provided by attributes of the calling object. + + Example + ------- + my_func = TDFunc(random.randint, a=1, b=10) + value1 = my_func() # Sample from default range + value2 = my_func(b=20) # Sample from extended range + + Note + ---- + All the function's parameters that will be used need to be specified + in either the default_args dict, object_args list, or as a kwarg in the + constructor. Arguments cannot be first given during function call. + For example the following will fail (because b is not defined in the + constructor): + + my_func = TDFunc(random.randint, a=1) + value1 = my_func(b=10.0) """ def __init__(self, func, default_args=None, object_args=None, **kwargs): @@ -30,13 +43,16 @@ def __init__(self, func, default_args=None, object_args=None, **kwargs): self.object_args = object_args self.default_args = {} - # The default arguments are the union of the default_args parameter - # and the remaining kwargs. + # The default arguments are the union of the default_args parameter, + # the object_args (set to None), and the remaining kwargs. if default_args is not None: self.default_args = default_args if kwargs: for key, value in kwargs.items(): self.default_args[key] = value + if object_args: + for key in object_args: + self.default_args[key] = None def __str__(self): """Return the string representation of the function.""" @@ -63,7 +79,10 @@ def __call__(self, calling_object=None, **kwargs): args[arg_name] = getattr(calling_object, arg_name) # Set any last arguments from the kwargs (overwriting previous settings). - args.update(kwargs) + # Only use known kwargs. + for key, value in kwargs.items(): + if key in self.default_args: + args[key] = value # Call the function with all the parameters. return self.func(**args) diff --git a/tests/tdastro/test_function_wrappers.py b/tests/tdastro/test_function_wrappers.py index 13ddaa3e..ed59170f 100644 --- a/tests/tdastro/test_function_wrappers.py +++ b/tests/tdastro/test_function_wrappers.py @@ -1,3 +1,4 @@ +import numpy as np import pytest from tdastro.function_wrappers import TDFunc @@ -45,21 +46,44 @@ def eval_func(self, func, **kwargs): def test_tdfunc_basic(): """Test that we can create and query a TDFunc.""" - tdf1 = TDFunc(_test_func, a=1.0) - # Fail without enough arguments (only a is specified). + tdf1 = TDFunc(_test_func, a=1.0) with pytest.raises(TypeError): _ = tdf1() - # We succeed with a manually specified parameter. - assert tdf1(b=2.0) == 3.0 + # We succeed with a manually specified parameter (but first fail with + # the default). + tdf2 = TDFunc(_test_func, a=1.0, b=None) + with pytest.raises(TypeError): + _ = tdf2() + assert tdf2(b=2.0) == 3.0 # We can overwrite parameters. - assert tdf1(a=3.0, b=2.0) == 5.0 + assert tdf2(a=3.0, b=2.0) == 5.0 + + # Test that we ignore extra kwargs. + assert tdf2(b=2.0, c=10.0, d=11.0) == 3.0 # That we can use a different ordering for parameters. - tdf2 = TDFunc(_test_func, b=2.0, a=1.0) - assert tdf2() == 3.0 + tdf3 = TDFunc(_test_func, b=2.0, a=1.0) + assert tdf3() == 3.0 + + +def test_np_sampler_method(): + """Test that we can wrap numpy random functions.""" + rng = np.random.default_rng(1001) + tdf = TDFunc(rng.normal, loc=10.0, scale=1.0) + + # Sample 1000 times with the default values. Check that we are near + # the expected mean and not everything is equal. + vals = np.array([tdf() for _ in range(1000)]) + assert abs(np.mean(vals) - 10.0) < 1.0 + assert not np.all(vals == vals[0]) + + # Override the mean and resample. + vals = np.array([tdf(loc=25.0) for _ in range(1000)]) + assert abs(np.mean(vals) - 25.0) < 1.0 + assert not np.all(vals == vals[0]) def test_tdfunc_obj(): From bfc57227b6cc4b7cea81abc4677cbff47c2aff90 Mon Sep 17 00:00:00 2001 From: Jeremy Kubica <104161096+jeremykubica@users.noreply.github.com> Date: Sun, 14 Jul 2024 11:27:42 -0400 Subject: [PATCH 19/31] Update population test to use TDFunc --- tests/tdastro/populations/test_fixed_population.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/tests/tdastro/populations/test_fixed_population.py b/tests/tdastro/populations/test_fixed_population.py index 134ed459..95b0e0ae 100644 --- a/tests/tdastro/populations/test_fixed_population.py +++ b/tests/tdastro/populations/test_fixed_population.py @@ -3,6 +3,7 @@ import numpy as np import pytest from tdastro.effects.white_noise import WhiteNoise +from tdastro.function_wrappers import TDFunc from tdastro.populations.fixed_population import FixedPopulation from tdastro.sources.static_source import StaticSource @@ -84,17 +85,12 @@ def test_fixed_population_sample_sources(): assert np.allclose(counts, [0.4 * itr, 0.2 * itr, 0.4 * itr], rtol=0.05) -def _random_brightness(): - """Returns a random value [0.0, 100.0]""" - return 100.0 * random.random() - - def test_fixed_population_sample_fluxes(): """Test that we can create a population of sources and sample its sources' flux.""" random.seed(1001) population = FixedPopulation() population.add_source(StaticSource(brightness=100.0), 10.0) - population.add_source(StaticSource(brightness=_random_brightness), 10.0) + population.add_source(StaticSource(brightness=TDFunc(random.uniform, a=0.0, b=100.0)), 10.0) population.add_source(StaticSource(brightness=200.0), 20.0) population.add_source(StaticSource(brightness=150.0), 10.0) From b9bc87cd78da53e86ae3760e29aa3cc99ab56af9 Mon Sep 17 00:00:00 2001 From: Jeremy Kubica <104161096+jeremykubica@users.noreply.github.com> Date: Sun, 14 Jul 2024 16:10:48 -0400 Subject: [PATCH 20/31] All function chaining --- src/tdastro/astro_utils/cosmology.py | 18 +++++++-- src/tdastro/base_models.py | 10 +++-- src/tdastro/function_wrappers.py | 52 +++++++++++-------------- tests/tdastro/test_function_wrappers.py | 41 +++++++++++-------- 4 files changed, 67 insertions(+), 54 deletions(-) diff --git a/src/tdastro/astro_utils/cosmology.py b/src/tdastro/astro_utils/cosmology.py index f7fbb6ba..ecbd16ab 100644 --- a/src/tdastro/astro_utils/cosmology.py +++ b/src/tdastro/astro_utils/cosmology.py @@ -38,20 +38,30 @@ class RedshiftDistFunc(TDFunc): kind : `str` The distance type for the Equivalency as defined by astropy.cosmology.units.redshift_distance. + + Parameters + ---------- + redshift : function or constant + The function or constant providing the redshift value. + cosmology : `astropy.cosmology` + The cosmology specification. + kind : `str` + The distance type for the Equivalency as defined by + astropy.cosmology.units.redshift_distance. """ - def __init__(self, cosmology, kind="comoving", **kwargs): + def __init__(self, redshift, cosmology, kind="comoving"): self.cosmology = cosmology self.kind = kind - default_args = {"cosmology": cosmology, "kind": kind} # Call the super class's constructor with the needed information. # Do not pass kwargs because we are limiting the arguments to match # the signature of redshift_to_distance(). super().__init__( func=redshift_to_distance, - default_args=default_args, - object_args=["redshift"], + redshift=redshift, + cosmology=cosmology, + kind=kind, ) def __str__(self): diff --git a/src/tdastro/base_models.py b/src/tdastro/base_models.py index d612d5fd..5e5746f3 100644 --- a/src/tdastro/base_models.py +++ b/src/tdastro/base_models.py @@ -86,7 +86,7 @@ def set_parameter(self, name, value=None, **kwargs): elif isinstance(value, TDFunc): # Case 1b: We are using a TDFunc wrapped function (with default parameters). self.setters[name] = (ParameterSource.TDFUNC_OBJ, value, required) - setattr(self, name, value(self, **kwargs)) + setattr(self, name, value(**kwargs)) elif isinstance(value, types.MethodType) and isinstance(value.__self__, ParameterizedModel): # Case 2: We are trying to use the method from a ParameterizedModel. # Note that this will (correctly) fail if we are adding a model method from the current @@ -175,7 +175,7 @@ def sample_parameters(self, max_depth=50, **kwargs): elif source_type == ParameterSource.FUNCTION: sampled_value = setter(**kwargs) elif source_type == ParameterSource.TDFUNC_OBJ: - sampled_value = setter(self, **kwargs) + sampled_value = setter(**kwargs) elif source_type == ParameterSource.MODEL_ATTRIBUTE: # Check if we need to resample the parent (needs to be done before # we read its attribute). @@ -238,7 +238,7 @@ def __init__(self, ra=None, dec=None, redshift=None, distance=None, background=N if distance is not None: self.add_parameter("distance", distance) elif redshift is not None and kwargs.get("cosmology", None) is not None: - self.add_parameter("distance", RedshiftDistFunc(**kwargs)) + self.add_parameter("distance", RedshiftDistFunc(redshift=self.get_redshift, **kwargs)) else: self.add_parameter("distance", None) @@ -249,6 +249,10 @@ def __str__(self): """Return the string representation of the model.""" return "PhysicalModel" + def get_redshift(self): + """Return the redshift for the model.""" + return self.redshift + def add_effect(self, effect, allow_dups=True, **kwargs): """Add a transformational effect to the PhysicalModel. Effects are applied in the order in which they are added. diff --git a/src/tdastro/function_wrappers.py b/src/tdastro/function_wrappers.py index ed24aa3f..5a7aad6e 100644 --- a/src/tdastro/function_wrappers.py +++ b/src/tdastro/function_wrappers.py @@ -1,14 +1,10 @@ -"""Classes to wrap functions to allow users to pass around functions -with partially specified arguments. -""" +"""Utilities to wrap functions for inclusion in an evaluation graph.""" import copy class TDFunc: - """A class to wrap functions to pass around functions with default - argument settings, arguments from kwargs, and (optionally) - arguments that are from the calling object. + """A class to wrap functions and their argument settings. Attributes ---------- @@ -17,11 +13,13 @@ class TDFunc: default_args : `dict` A dictionary of default arguments to pass to the function. Assembled from the ```default_args`` parameter and additional ``kwargs``. - object_args : `list` - Arguments that are provided by attributes of the calling object. + setter_functions : `dict` + A dictionary mapping arguments names to functions, methods, or + TDFunc objects used to set that argument. These are evaluated dynamically + each time. - Example - ------- + Examples + -------- my_func = TDFunc(random.randint, a=1, b=10) value1 = my_func() # Sample from default range value2 = my_func(b=20) # Sample from extended range @@ -38,45 +36,39 @@ class TDFunc: value1 = my_func(b=10.0) """ - def __init__(self, func, default_args=None, object_args=None, **kwargs): + def __init__(self, func, **kwargs): self.func = func - self.object_args = object_args self.default_args = {} + self.setter_functions = {} - # The default arguments are the union of the default_args parameter, - # the object_args (set to None), and the remaining kwargs. - if default_args is not None: - self.default_args = default_args - if kwargs: - for key, value in kwargs.items(): - self.default_args[key] = value - if object_args: - for key in object_args: + for key, value in kwargs.items(): + if callable(value): self.default_args[key] = None + self.setter_functions[key] = value + else: + self.default_args[key] = value def __str__(self): """Return the string representation of the function.""" return f"TDFunc({self.func.name})" - def __call__(self, calling_object=None, **kwargs): + def __call__(self, **kwargs): """Execute the wrapped function. Parameters ---------- - calling_object : any, optional - The object that called the function. **kwargs : `dict`, optional Additional function arguments. """ # Start with the default arguments. We make a copy so we can modify the dictionary. args = copy.copy(self.default_args) - # If there are arguments to get from the calling object, set those. - if self.object_args is not None and len(self.object_args) > 0: - if calling_object is None: - raise ValueError(f"Calling object needed for parameters: {self.object_args}") - for arg_name in self.object_args: - args[arg_name] = getattr(calling_object, arg_name) + # If there are arguments to get from the calling functions, set those. + for key, value in self.setter_functions.items(): + if isinstance(value, TDFunc): + args[key] = value(**kwargs) + else: + args[key] = value() # Set any last arguments from the kwargs (overwriting previous settings). # Only use known kwargs. diff --git a/tests/tdastro/test_function_wrappers.py b/tests/tdastro/test_function_wrappers.py index ed59170f..e7048391 100644 --- a/tests/tdastro/test_function_wrappers.py +++ b/tests/tdastro/test_function_wrappers.py @@ -31,17 +31,13 @@ def __init__(self, a, b): self.a = a self.b = b - def eval_func(self, func, **kwargs): - """Evaluate a TDFunc. + def get_a(self): + """Get the a attribute.""" + return self.a - Parameters - ---------- - func : `TDFunc` - The function to evaluate. - **kwargs : `dict`, optional - Any additional keyword arguments. - """ - return func(self, **kwargs) + def get_b(self): + """Get the b attribute.""" + return self.b def test_tdfunc_basic(): @@ -69,6 +65,16 @@ def test_tdfunc_basic(): assert tdf3() == 3.0 +def test_tdfunc_chain(): + """Test that we can create and query a chained TDFunc.""" + tdf1 = TDFunc(_test_func, a=1.0, b=1.0) + tdf2 = TDFunc(_test_func, a=tdf1, b=3.0) + assert tdf2() == 5.0 + + # This will overwrite all the b parameters. + assert tdf2(b=10.0) == 21.0 + + def test_np_sampler_method(): """Test that we can wrap numpy random functions.""" rng = np.random.default_rng(1001) @@ -91,12 +97,13 @@ def test_tdfunc_obj(): model = _StaticModel(a=10.0, b=11.0) # Check a function without defaults. - tdf1 = TDFunc(_test_func, object_args=["a", "b"]) - assert model.eval_func(tdf1) == 21.0 + tdf1 = TDFunc(_test_func, a=model.get_a, b=model.get_b) + assert tdf1() == 21.0 - # Defaults set are overwritten by the object. - tdf2 = TDFunc(_test_func, object_args=["a", "b"], a=1.0, b=0.0) - assert model.eval_func(tdf2) == 21.0 + # We can pull from multiple models. + model2 = _StaticModel(a=1.0, b=0.0) + tdf2 = TDFunc(_test_func, a=model.get_a, b=model2.get_b) + assert tdf2() == 10.0 - # But we can overwrite everything with kwargs. - assert model.eval_func(tdf2, b=7.5) == 17.5 + # W can overwrite everything with kwargs. + assert tdf2(b=7.5) == 17.5 From 96f52bca158c7c344987ef8e6ab80294dc03e590 Mon Sep 17 00:00:00 2001 From: Jeremy Kubica <104161096+jeremykubica@users.noreply.github.com> Date: Sun, 14 Jul 2024 16:37:18 -0400 Subject: [PATCH 21/31] refactor --- src/tdastro/base_models.py | 334 +----------------- src/tdastro/effects/effect_model.py | 47 +++ src/tdastro/effects/white_noise.py | 2 +- src/tdastro/populations/fixed_population.py | 2 +- src/tdastro/populations/population_model.py | 106 ++++++ src/tdastro/sources/galaxy_models.py | 2 +- src/tdastro/sources/periodic_source.py | 2 +- src/tdastro/sources/physical_model.py | 167 +++++++++ src/tdastro/sources/sncomso_models.py | 2 +- src/tdastro/sources/spline_model.py | 2 +- src/tdastro/sources/static_source.py | 2 +- tests/tdastro/sources/test_physical_models.py | 31 ++ tests/tdastro/test_base_models.py | 37 +- 13 files changed, 374 insertions(+), 362 deletions(-) create mode 100644 src/tdastro/effects/effect_model.py create mode 100644 src/tdastro/populations/population_model.py create mode 100644 src/tdastro/sources/physical_model.py create mode 100644 tests/tdastro/sources/test_physical_models.py diff --git a/src/tdastro/base_models.py b/src/tdastro/base_models.py index 5e5746f3..d94ece2d 100644 --- a/src/tdastro/base_models.py +++ b/src/tdastro/base_models.py @@ -3,9 +3,6 @@ import types from enum import Enum -import numpy as np - -from tdastro.astro_utils.cosmology import RedshiftDistFunc from tdastro.function_wrappers import TDFunc @@ -22,9 +19,9 @@ class ParameterSource(Enum): MODEL_METHOD = 5 -class ParameterizedModel: +class ParameterizedNode: """Any model that uses parameters that can be set by constants, - functions, or other parameterized models. ParameterizedModels can + functions, or other parameterized models. ParameterizedNode can include physical objects or statistical distributions. Attributes @@ -34,7 +31,7 @@ class ParameterizedModel: (ParameterSource, setter information, required). The attributes are stored in the order in which they need to be set. sample_iteration : `int` - A counter used to syncronize sampling runs. Tracks how many times this + A counter used to syncronize sampling runs. Tracks how many times this model's parameters have been resampled. """ @@ -44,10 +41,10 @@ def __init__(self, **kwargs): def __str__(self): """Return the string representation of the model.""" - return "ParameterizedModel" + return "ParameterizedNode" def set_parameter(self, name, value=None, **kwargs): - """Set a single *existing* parameter to the ParameterizedModel. + """Set a single *existing* parameter to the ParameterizedNode. Notes ----- @@ -60,7 +57,7 @@ def set_parameter(self, name, value=None, **kwargs): The parameter name to add. value : any, optional The information to use to set the parameter. Can be a constant, - function, ParameterizedModel, or self. + function, ParameterizedNode, or self. **kwargs : `dict`, optional All other keyword arguments, possibly including the parameter setters. @@ -87,13 +84,13 @@ def set_parameter(self, name, value=None, **kwargs): # Case 1b: We are using a TDFunc wrapped function (with default parameters). self.setters[name] = (ParameterSource.TDFUNC_OBJ, value, required) setattr(self, name, value(**kwargs)) - elif isinstance(value, types.MethodType) and isinstance(value.__self__, ParameterizedModel): - # Case 2: We are trying to use the method from a ParameterizedModel. + elif isinstance(value, types.MethodType) and isinstance(value.__self__, ParameterizedNode): + # Case 2: We are trying to use the method from a ParameterizedNode. # Note that this will (correctly) fail if we are adding a model method from the current # object that requires an unset attribute. self.setters[name] = (ParameterSource.MODEL_METHOD, value, required) setattr(self, name, value(**kwargs)) - elif isinstance(value, ParameterizedModel): + elif isinstance(value, ParameterizedNode): # Case 3: We are trying to access an attribute from a parameterized model. if not hasattr(value, name): raise ValueError(f"Attribute {name} missing from parent.") @@ -108,7 +105,7 @@ def set_parameter(self, name, value=None, **kwargs): raise ValueError(f"Missing required parameter {name}") def add_parameter(self, name, value=None, required=False, **kwargs): - """Add a single *new* parameter to the ParameterizedModel. + """Add a single *new* parameter to the ParameterizedNode. Notes ----- @@ -123,7 +120,7 @@ def add_parameter(self, name, value=None, required=False, **kwargs): The parameter name to add. value : any, optional The information to use to set the parameter. Can be a constant, - function, ParameterizedModel, or self. + function, ParameterizedNode, or self. required : `bool` Fail if the parameter is set to ``None``. **kwargs : `dict`, optional @@ -146,7 +143,7 @@ def add_parameter(self, name, value=None, required=False, **kwargs): def sample_parameters(self, max_depth=50, **kwargs): """Sample the model's underlying parameters if they are provided by a function - or ParameterizedModel. + or ParameterizedNode. Parameters ---------- @@ -195,310 +192,3 @@ def sample_parameters(self, max_depth=50, **kwargs): # Increase the sampling iteration. self.sample_iteration += 1 - - -class PhysicalModel(ParameterizedModel): - """A physical model of a source of flux. - - Physical models can have fixed attributes (where you need to create a new model - to change them) and settable attributes that can be passed functions or constants. - They can also have special background pointers that link to another PhysicalModel - producing flux. We can chain these to have a supernova in front of a star in front - of a static background. - - Attributes - ---------- - ra : `float` - The object's right ascension (in degrees) - dec : `float` - The object's declination (in degrees) - redshift : `float` - The object's redshift. - distance : `float` - The object's distance (in pc). If the distance is not provided and - a ``cosmology`` parameter is given, the model will try to derive from - the redshift and the cosmology. - background : `PhysicalModel` - A source of background flux such as a host galaxy. - effects : `list` - A list of effects to apply to an observations. - """ - - def __init__(self, ra=None, dec=None, redshift=None, distance=None, background=None, **kwargs): - super().__init__(**kwargs) - self.effects = [] - - # Set RA, dec, and redshift from the parameters. - self.add_parameter("ra", ra) - self.add_parameter("dec", dec) - self.add_parameter("redshift", redshift) - - # If the distance is provided, use that. Otherwise try the redshift value - # using the cosmology (if given). Finally, default to None. - if distance is not None: - self.add_parameter("distance", distance) - elif redshift is not None and kwargs.get("cosmology", None) is not None: - self.add_parameter("distance", RedshiftDistFunc(redshift=self.get_redshift, **kwargs)) - else: - self.add_parameter("distance", None) - - # Background is an object not a sampled parameter - self.background = background - - def __str__(self): - """Return the string representation of the model.""" - return "PhysicalModel" - - def get_redshift(self): - """Return the redshift for the model.""" - return self.redshift - - def add_effect(self, effect, allow_dups=True, **kwargs): - """Add a transformational effect to the PhysicalModel. - Effects are applied in the order in which they are added. - - Parameters - ---------- - effect : `EffectModel` - The effect to apply. - allow_dups : `bool` - Allow multiple effects of the same type. - Default = ``True`` - **kwargs : `dict`, optional - Any additional keyword arguments. - - Raises - ------ - Raises a ``AttributeError`` if the PhysicalModel does not have all of the - required attributes. - """ - # Check that we have not added this effect before. - if not allow_dups: - effect_type = type(effect) - for prev in self.effects: - if effect_type == type(prev): - raise ValueError("Added the effect type to a model {effect_type} more than once.") - - required: list = effect.required_parameters() - for parameter in required: - # Raise an AttributeError if the parameter is missing or set to None. - if getattr(self, parameter) is None: - raise AttributeError(f"Parameter {parameter} unset for model {type(self).__name__}") - - self.effects.append(effect) - - def _evaluate(self, times, wavelengths, **kwargs): - """Draw effect-free observations for this object. - - Parameters - ---------- - times : `numpy.ndarray` - A length T array of timestamps. - wavelengths : `numpy.ndarray`, optional - A length N array of wavelengths. - **kwargs : `dict`, optional - Any additional keyword arguments. - - Returns - ------- - flux_density : `numpy.ndarray` - A length T x N matrix of SED values. - """ - raise NotImplementedError() - - def evaluate(self, times, wavelengths, resample_parameters=False, **kwargs): - """Draw observations for this object and apply the noise. - - Parameters - ---------- - times : `numpy.ndarray` - A length T array of timestamps. - wavelengths : `numpy.ndarray`, optional - A length N array of wavelengths. - resample_parameters : `bool` - Treat this evaluation as a completely new object, resampling the - parameters from the original provided functions. - **kwargs : `dict`, optional - Any additional keyword arguments. - - Returns - ------- - flux_density : `numpy.ndarray` - A length T x N matrix of SED values. - """ - if resample_parameters: - self.sample_parameters(kwargs) - - # Compute the flux density for both the current object and add in anything - # behind it, such as a host galaxy. - flux_density = self._evaluate(times, wavelengths, **kwargs) - if self.background is not None: - flux_density += self.background._evaluate(times, wavelengths, ra=self.ra, dec=self.dec, **kwargs) - - for effect in self.effects: - flux_density = effect.apply(flux_density, wavelengths, self, **kwargs) - return flux_density - - def sample_parameters(self, include_effects=True, **kwargs): - """Sample the model's underlying parameters if they are provided by a function - or ParameterizedModel. - - Parameters - ---------- - include_effects : `bool` - Resample the parameters for the effects models. - **kwargs : `dict`, optional - All the keyword arguments, including the values needed to sample - parameters. - """ - if self.background is not None: - self.background.sample_parameters(include_effects, **kwargs) - super().sample_parameters(**kwargs) - - if include_effects: - for effect in self.effects: - effect.sample_parameters(**kwargs) - - -class EffectModel(ParameterizedModel): - """A physical or systematic effect to apply to an observation.""" - - def __init__(self, **kwargs): - super().__init__(**kwargs) - - def __str__(self): - """Return the string representation of the model.""" - return "EffectModel" - - def required_parameters(self): - """Returns a list of the parameters of a PhysicalModel - that this effect needs to access. - - Returns - ------- - parameters : `list` of `str` - A list of every required parameter the effect needs. - """ - return [] - - def apply(self, flux_density, wavelengths=None, physical_model=None, **kwargs): - """Apply the effect to observations (flux_density values) - - Parameters - ---------- - flux_density : `numpy.ndarray` - A length T X N matrix of flux density values. - wavelengths : `numpy.ndarray`, optional - A length N array of wavelengths. - physical_model : `PhysicalModel` - A PhysicalModel from which the effect may query parameters - such as redshift, position, or distance. - **kwargs : `dict`, optional - Any additional keyword arguments. - - Returns - ------- - flux_density : `numpy.ndarray` - A length T x N matrix of flux densities after the effect is applied. - """ - raise NotImplementedError() - - -class PopulationModel(ParameterizedModel): - """A model of a population of PhysicalModels. - - Attributes - ---------- - num_sources : `int` - The number of different sources in the population. - sources : `list` - A list of sources from which to draw. - """ - - def __init__(self, rng=None, **kwargs): - super().__init__(**kwargs) - self.num_sources = 0 - self.sources = [] - - def __str__(self): - """Return the string representation of the model.""" - return f"PopulationModel with {self.num_sources} sources." - - def add_source(self, new_source, **kwargs): - """Add a new source to the population. - - Parameters - ---------- - new_source : `PhysicalModel` - A source from the population. - **kwargs : `dict`, optional - Any additional keyword arguments. - """ - if not isinstance(new_source, PhysicalModel): - raise ValueError("All sources must be PhysicalModels") - self.sources.append(new_source) - self.num_sources += 1 - - def draw_source(self): - """Sample a single source from the population. - - Returns - ------- - source : `PhysicalModel` - A source from the population. - """ - raise NotImplementedError() - - def add_effect(self, effect, allow_dups=False, **kwargs): - """Add a transformational effect to all PhysicalModels in this population. - Effects are applied in the order in which they are added. - - Parameters - ---------- - effect : `EffectModel` - The effect to apply. - allow_dups : `bool` - Allow multiple effects of the same type. - Default = ``True`` - **kwargs : `dict`, optional - Any additional keyword arguments. - - Raises - ------ - Raises a ``AttributeError`` if the PhysicalModel does not have all of the - required attributes. - """ - for source in self.sources: - source.add_effect(effect, allow_dups=allow_dups, **kwargs) - - def evaluate(self, samples, times, wavelengths, resample_parameters=False, **kwargs): - """Draw observations from a single (randomly sampled) source. - - Parameters - ---------- - samples : `int` - The number of sources to samples. - times : `numpy.ndarray` - A length T array of timestamps. - wavelengths : `numpy.ndarray`, optional - A length N array of wavelengths. - resample_parameters : `bool` - Treat this evaluation as a completely new object, resampling the - parameters from the original provided functions. - **kwargs : `dict`, optional - Any additional keyword arguments. - - Returns - ------- - results : `numpy.ndarray` - A shape (samples, T, N) matrix of SED values. - """ - if samples <= 0: - raise ValueError("The number of samples must be > 0.") - - results = [] - for _ in range(samples): - source = self.draw_source() - object_fluxes = source.evaluate(times, wavelengths, resample_parameters, **kwargs) - results.append(object_fluxes) - return np.array(results) diff --git a/src/tdastro/effects/effect_model.py b/src/tdastro/effects/effect_model.py new file mode 100644 index 00000000..d5b80bf8 --- /dev/null +++ b/src/tdastro/effects/effect_model.py @@ -0,0 +1,47 @@ +"""The base EffectModel class used for all effects.""" + +from tdastro.base_models import ParameterizedNode + + +class EffectModel(ParameterizedNode): + """A physical or systematic effect to apply to an observation.""" + + def __init__(self, **kwargs): + super().__init__(**kwargs) + + def __str__(self): + """Return the string representation of the model.""" + return "EffectModel" + + def required_parameters(self): + """Returns a list of the parameters of a PhysicalModel + that this effect needs to access. + + Returns + ------- + parameters : `list` of `str` + A list of every required parameter the effect needs. + """ + return [] + + def apply(self, flux_density, wavelengths=None, physical_model=None, **kwargs): + """Apply the effect to observations (flux_density values) + + Parameters + ---------- + flux_density : `numpy.ndarray` + A length T X N matrix of flux density values. + wavelengths : `numpy.ndarray`, optional + A length N array of wavelengths. + physical_model : `PhysicalModel` + A PhysicalModel from which the effect may query parameters + such as redshift, position, or distance. + **kwargs : `dict`, optional + Any additional keyword arguments. + + Returns + ------- + flux_density : `numpy.ndarray` + A length T x N matrix of flux densities after the effect is applied. + """ + raise NotImplementedError() diff --git a/src/tdastro/effects/white_noise.py b/src/tdastro/effects/white_noise.py index 49f2e621..b7d3a0a7 100644 --- a/src/tdastro/effects/white_noise.py +++ b/src/tdastro/effects/white_noise.py @@ -1,6 +1,6 @@ import numpy as np -from tdastro.base_models import EffectModel +from tdastro.effects.effect_model import EffectModel class WhiteNoise(EffectModel): diff --git a/src/tdastro/populations/fixed_population.py b/src/tdastro/populations/fixed_population.py index bc644b4a..18f3f0e9 100644 --- a/src/tdastro/populations/fixed_population.py +++ b/src/tdastro/populations/fixed_population.py @@ -1,6 +1,6 @@ import random -from tdastro.base_models import PopulationModel +from tdastro.populations.population_model import PopulationModel class FixedPopulation(PopulationModel): diff --git a/src/tdastro/populations/population_model.py b/src/tdastro/populations/population_model.py new file mode 100644 index 00000000..c527038a --- /dev/null +++ b/src/tdastro/populations/population_model.py @@ -0,0 +1,106 @@ +"""The base population models.""" + +import numpy as np + +from tdastro.base_models import ParameterizedNode +from tdastro.sources.physical_model import PhysicalModel + + +class PopulationModel(ParameterizedNode): + """A model of a population of PhysicalModels. + + Attributes + ---------- + num_sources : `int` + The number of different sources in the population. + sources : `list` + A list of sources from which to draw. + """ + + def __init__(self, rng=None, **kwargs): + super().__init__(**kwargs) + self.num_sources = 0 + self.sources = [] + + def __str__(self): + """Return the string representation of the model.""" + return f"PopulationModel with {self.num_sources} sources." + + def add_source(self, new_source, **kwargs): + """Add a new source to the population. + + Parameters + ---------- + new_source : `PhysicalModel` + A source from the population. + **kwargs : `dict`, optional + Any additional keyword arguments. + """ + if not isinstance(new_source, PhysicalModel): + raise ValueError("All sources must be PhysicalModels") + self.sources.append(new_source) + self.num_sources += 1 + + def draw_source(self): + """Sample a single source from the population. + + Returns + ------- + source : `PhysicalModel` + A source from the population. + """ + raise NotImplementedError() + + def add_effect(self, effect, allow_dups=False, **kwargs): + """Add a transformational effect to all PhysicalModels in this population. + Effects are applied in the order in which they are added. + + Parameters + ---------- + effect : `EffectModel` + The effect to apply. + allow_dups : `bool` + Allow multiple effects of the same type. + Default = ``True`` + **kwargs : `dict`, optional + Any additional keyword arguments. + + Raises + ------ + Raises a ``AttributeError`` if the PhysicalModel does not have all of the + required attributes. + """ + for source in self.sources: + source.add_effect(effect, allow_dups=allow_dups, **kwargs) + + def evaluate(self, samples, times, wavelengths, resample_parameters=False, **kwargs): + """Draw observations from a single (randomly sampled) source. + + Parameters + ---------- + samples : `int` + The number of sources to samples. + times : `numpy.ndarray` + A length T array of timestamps. + wavelengths : `numpy.ndarray`, optional + A length N array of wavelengths. + resample_parameters : `bool` + Treat this evaluation as a completely new object, resampling the + parameters from the original provided functions. + **kwargs : `dict`, optional + Any additional keyword arguments. + + Returns + ------- + results : `numpy.ndarray` + A shape (samples, T, N) matrix of SED values. + """ + if samples <= 0: + raise ValueError("The number of samples must be > 0.") + + results = [] + for _ in range(samples): + source = self.draw_source() + object_fluxes = source.evaluate(times, wavelengths, resample_parameters, **kwargs) + results.append(object_fluxes) + return np.array(results) diff --git a/src/tdastro/sources/galaxy_models.py b/src/tdastro/sources/galaxy_models.py index 9aea8bfc..eadb853a 100644 --- a/src/tdastro/sources/galaxy_models.py +++ b/src/tdastro/sources/galaxy_models.py @@ -1,7 +1,7 @@ import numpy as np from astropy.coordinates import angular_separation -from tdastro.base_models import PhysicalModel +from tdastro.sources.physical_model import PhysicalModel class GaussianGalaxy(PhysicalModel): diff --git a/src/tdastro/sources/periodic_source.py b/src/tdastro/sources/periodic_source.py index b68c02da..9b3cd81f 100644 --- a/src/tdastro/sources/periodic_source.py +++ b/src/tdastro/sources/periodic_source.py @@ -1,6 +1,6 @@ from abc import ABC, abstractmethod -from tdastro.base_models import PhysicalModel +from tdastro.sources.physical_model import PhysicalModel class PeriodicSource(PhysicalModel, ABC): diff --git a/src/tdastro/sources/physical_model.py b/src/tdastro/sources/physical_model.py new file mode 100644 index 00000000..8cf4946e --- /dev/null +++ b/src/tdastro/sources/physical_model.py @@ -0,0 +1,167 @@ +"""The base PhysicalModel used for all sources.""" + +from tdastro.astro_utils.cosmology import RedshiftDistFunc +from tdastro.base_models import ParameterizedNode + + +class PhysicalModel(ParameterizedNode): + """A physical model of a source of flux. + + Physical models can have fixed attributes (where you need to create a new model + to change them) and settable attributes that can be passed functions or constants. + They can also have special background pointers that link to another PhysicalModel + producing flux. We can chain these to have a supernova in front of a star in front + of a static background. + + Attributes + ---------- + ra : `float` + The object's right ascension (in degrees) + dec : `float` + The object's declination (in degrees) + redshift : `float` + The object's redshift. + distance : `float` + The object's distance (in pc). If the distance is not provided and + a ``cosmology`` parameter is given, the model will try to derive from + the redshift and the cosmology. + background : `PhysicalModel` + A source of background flux such as a host galaxy. + effects : `list` + A list of effects to apply to an observations. + """ + + def __init__(self, ra=None, dec=None, redshift=None, distance=None, background=None, **kwargs): + super().__init__(**kwargs) + self.effects = [] + + # Set RA, dec, and redshift from the parameters. + self.add_parameter("ra", ra) + self.add_parameter("dec", dec) + self.add_parameter("redshift", redshift) + + # If the distance is provided, use that. Otherwise try the redshift value + # using the cosmology (if given). Finally, default to None. + if distance is not None: + self.add_parameter("distance", distance) + elif redshift is not None and kwargs.get("cosmology", None) is not None: + self.add_parameter("distance", RedshiftDistFunc(redshift=self.get_redshift, **kwargs)) + else: + self.add_parameter("distance", None) + + # Background is an object not a sampled parameter + self.background = background + + def __str__(self): + """Return the string representation of the model.""" + return "PhysicalModel" + + def get_redshift(self): + """Return the redshift for the model.""" + return self.redshift + + def add_effect(self, effect, allow_dups=True, **kwargs): + """Add a transformational effect to the PhysicalModel. + Effects are applied in the order in which they are added. + + Parameters + ---------- + effect : `EffectModel` + The effect to apply. + allow_dups : `bool` + Allow multiple effects of the same type. + Default = ``True`` + **kwargs : `dict`, optional + Any additional keyword arguments. + + Raises + ------ + Raises a ``AttributeError`` if the PhysicalModel does not have all of the + required attributes. + """ + # Check that we have not added this effect before. + if not allow_dups: + effect_type = type(effect) + for prev in self.effects: + if effect_type == type(prev): + raise ValueError("Added the effect type to a model {effect_type} more than once.") + + required: list = effect.required_parameters() + for parameter in required: + # Raise an AttributeError if the parameter is missing or set to None. + if getattr(self, parameter) is None: + raise AttributeError(f"Parameter {parameter} unset for model {type(self).__name__}") + + self.effects.append(effect) + + def _evaluate(self, times, wavelengths, **kwargs): + """Draw effect-free observations for this object. + + Parameters + ---------- + times : `numpy.ndarray` + A length T array of timestamps. + wavelengths : `numpy.ndarray`, optional + A length N array of wavelengths. + **kwargs : `dict`, optional + Any additional keyword arguments. + + Returns + ------- + flux_density : `numpy.ndarray` + A length T x N matrix of SED values. + """ + raise NotImplementedError() + + def evaluate(self, times, wavelengths, resample_parameters=False, **kwargs): + """Draw observations for this object and apply the noise. + + Parameters + ---------- + times : `numpy.ndarray` + A length T array of timestamps. + wavelengths : `numpy.ndarray`, optional + A length N array of wavelengths. + resample_parameters : `bool` + Treat this evaluation as a completely new object, resampling the + parameters from the original provided functions. + **kwargs : `dict`, optional + Any additional keyword arguments. + + Returns + ------- + flux_density : `numpy.ndarray` + A length T x N matrix of SED values. + """ + if resample_parameters: + self.sample_parameters(kwargs) + + # Compute the flux density for both the current object and add in anything + # behind it, such as a host galaxy. + flux_density = self._evaluate(times, wavelengths, **kwargs) + if self.background is not None: + flux_density += self.background._evaluate(times, wavelengths, ra=self.ra, dec=self.dec, **kwargs) + + for effect in self.effects: + flux_density = effect.apply(flux_density, wavelengths, self, **kwargs) + return flux_density + + def sample_parameters(self, include_effects=True, **kwargs): + """Sample the model's underlying parameters if they are provided by a function + or ParameterizedModel. + + Parameters + ---------- + include_effects : `bool` + Resample the parameters for the effects models. + **kwargs : `dict`, optional + All the keyword arguments, including the values needed to sample + parameters. + """ + if self.background is not None: + self.background.sample_parameters(include_effects, **kwargs) + super().sample_parameters(**kwargs) + + if include_effects: + for effect in self.effects: + effect.sample_parameters(**kwargs) diff --git a/src/tdastro/sources/sncomso_models.py b/src/tdastro/sources/sncomso_models.py index 379de41e..b176233e 100644 --- a/src/tdastro/sources/sncomso_models.py +++ b/src/tdastro/sources/sncomso_models.py @@ -6,7 +6,7 @@ from sncosmo.models import get_source -from tdastro.base_models import PhysicalModel +from tdastro.sources.physical_model import PhysicalModel class SncosmoWrapperModel(PhysicalModel): diff --git a/src/tdastro/sources/spline_model.py b/src/tdastro/sources/spline_model.py index eff26dff..67e79202 100644 --- a/src/tdastro/sources/spline_model.py +++ b/src/tdastro/sources/spline_model.py @@ -7,7 +7,7 @@ from scipy.interpolate import RectBivariateSpline -from tdastro.base_models import PhysicalModel +from tdastro.sources.physical_model import PhysicalModel class SplineModel(PhysicalModel): diff --git a/src/tdastro/sources/static_source.py b/src/tdastro/sources/static_source.py index 44799c2e..f69eea89 100644 --- a/src/tdastro/sources/static_source.py +++ b/src/tdastro/sources/static_source.py @@ -1,6 +1,6 @@ import numpy as np -from tdastro.base_models import PhysicalModel +from tdastro.sources.physical_model import PhysicalModel class StaticSource(PhysicalModel): diff --git a/tests/tdastro/sources/test_physical_models.py b/tests/tdastro/sources/test_physical_models.py new file mode 100644 index 00000000..cf18bf1f --- /dev/null +++ b/tests/tdastro/sources/test_physical_models.py @@ -0,0 +1,31 @@ +import pytest +from astropy.cosmology import WMAP9 +from tdastro.sources.physical_model import PhysicalModel + + +def test_physical_model(): + """Test that we can create a physical model.""" + # Everything is specified. + model1 = PhysicalModel(ra=1.0, dec=2.0, distance=3.0, redshift=0.0) + assert model1.ra == 1.0 + assert model1.dec == 2.0 + assert model1.distance == 3.0 + assert model1.redshift == 0.0 + + # Derive the distance from the redshift using the example from: + # https://docs.astropy.org/en/stable/api/astropy.cosmology.units.redshift_distance.html + model2 = PhysicalModel(ra=1.0, dec=2.0, redshift=1100.0, cosmology=WMAP9) + assert model2.ra == 1.0 + assert model2.dec == 2.0 + assert model2.redshift == 1100.0 + assert model2.distance == pytest.approx(14004.03157418 * 1e6) + + # Neither distance nor redshift are specified. + model3 = PhysicalModel(ra=1.0, dec=2.0) + assert model3.redshift is None + assert model3.distance is None + + # Redshift is specified but cosmology is not. + model4 = PhysicalModel(ra=1.0, dec=2.0, redshift=1100.0) + assert model4.redshift == 1100.0 + assert model4.distance is None diff --git a/tests/tdastro/test_base_models.py b/tests/tdastro/test_base_models.py index 6847cffa..591b88b6 100644 --- a/tests/tdastro/test_base_models.py +++ b/tests/tdastro/test_base_models.py @@ -1,8 +1,7 @@ import random import pytest -from astropy.cosmology import WMAP9 -from tdastro.base_models import ParameterizedModel, PhysicalModel +from tdastro.base_models import ParameterizedNode def _sampler_fun(**kwargs): @@ -16,7 +15,7 @@ def _sampler_fun(**kwargs): return random.random() -class PairModel(ParameterizedModel): +class PairModel(ParameterizedNode): """A test class for the parameterized model. Attributes @@ -34,9 +33,9 @@ def __init__(self, value1, value2, **kwargs): Parameters ---------- - value1 : `float`, `function`, `ParameterizedModel`, or `None` + value1 : `float`, `function`, `ParameterizedNode`, or `None` The first value. - value2 : `float`, `function`, `ParameterizedModel`, or `None` + value2 : `float`, `function`, `ParameterizedNode`, or `None` The second value. **kwargs : `dict`, optional Any additional keyword arguments. @@ -142,31 +141,3 @@ def test_parameterized_model_modify() -> None: # We cannot set a value that hasn't been added. with pytest.raises(KeyError): model.set_parameter("brightness", 5.0) - - -def test_physical_model(): - """Test that we can create a physical model.""" - # Everything is specified. - model1 = PhysicalModel(ra=1.0, dec=2.0, distance=3.0, redshift=0.0) - assert model1.ra == 1.0 - assert model1.dec == 2.0 - assert model1.distance == 3.0 - assert model1.redshift == 0.0 - - # Derive the distance from the redshift using the example from: - # https://docs.astropy.org/en/stable/api/astropy.cosmology.units.redshift_distance.html - model2 = PhysicalModel(ra=1.0, dec=2.0, redshift=1100.0, cosmology=WMAP9) - assert model2.ra == 1.0 - assert model2.dec == 2.0 - assert model2.redshift == 1100.0 - assert model2.distance == pytest.approx(14004.03157418 * 1e6) - - # Neither distance nor redshift are specified. - model3 = PhysicalModel(ra=1.0, dec=2.0) - assert model3.redshift is None - assert model3.distance is None - - # Redshift is specified but cosmology is not. - model4 = PhysicalModel(ra=1.0, dec=2.0, redshift=1100.0) - assert model4.redshift == 1100.0 - assert model4.distance is None From 8b263f0070a7c1c88ca3c018313b6b6e3ee302e8 Mon Sep 17 00:00:00 2001 From: Jeremy Kubica <104161096+jeremykubica@users.noreply.github.com> Date: Sun, 14 Jul 2024 19:54:04 -0400 Subject: [PATCH 22/31] Put everything in a single node interface --- src/tdastro/astro_utils/cosmology.py | 9 +- src/tdastro/base_models.py | 157 +++++++++++++----- src/tdastro/function_wrappers.py | 80 --------- src/tdastro/sources/physical_model.py | 12 +- .../populations/test_fixed_population.py | 5 +- tests/tdastro/sources/test_galaxy_models.py | 4 + tests/tdastro/sources/test_step_source.py | 6 +- tests/tdastro/test_base_models.py | 77 ++++++++- tests/tdastro/test_function_wrappers.py | 109 ------------ 9 files changed, 211 insertions(+), 248 deletions(-) delete mode 100644 src/tdastro/function_wrappers.py delete mode 100644 tests/tdastro/test_function_wrappers.py diff --git a/src/tdastro/astro_utils/cosmology.py b/src/tdastro/astro_utils/cosmology.py index ecbd16ab..1beecf70 100644 --- a/src/tdastro/astro_utils/cosmology.py +++ b/src/tdastro/astro_utils/cosmology.py @@ -1,7 +1,7 @@ import astropy.cosmology.units as cu from astropy import units as u -from tdastro.function_wrappers import TDFunc +from tdastro.base_models import FunctionNode def redshift_to_distance(redshift, cosmology, kind="comoving"): @@ -28,7 +28,7 @@ def redshift_to_distance(redshift, cosmology, kind="comoving"): return distance.value -class RedshiftDistFunc(TDFunc): +class RedshiftDistFunc(FunctionNode): """A wrapper class for the redshift_to_distance() function. Attributes @@ -51,12 +51,7 @@ class RedshiftDistFunc(TDFunc): """ def __init__(self, redshift, cosmology, kind="comoving"): - self.cosmology = cosmology - self.kind = kind - # Call the super class's constructor with the needed information. - # Do not pass kwargs because we are limiting the arguments to match - # the signature of redshift_to_distance(). super().__init__( func=redshift_to_distance, redshift=redshift, diff --git a/src/tdastro/base_models.py b/src/tdastro/base_models.py index d94ece2d..262913da 100644 --- a/src/tdastro/base_models.py +++ b/src/tdastro/base_models.py @@ -1,28 +1,23 @@ -"""The base models used throughout TDAstro including physical objects, effects, and populations.""" +"""The base models used to specify the TDAstro computation graph.""" import types from enum import Enum -from tdastro.function_wrappers import TDFunc - class ParameterSource(Enum): - """ParameterSource specifies where a PhysicalModel should get the value - for a given parameter: a constant value, a function, or from another - parameterized model. + """ParameterSource specifies where a ParameterizedNode should get the value + for a given parameter: a constant value or from another ParameterizedNode. """ CONSTANT = 1 - FUNCTION = 2 - TDFUNC_OBJ = 3 - MODEL_ATTRIBUTE = 4 - MODEL_METHOD = 5 + MODEL_ATTRIBUTE = 2 + MODEL_METHOD = 3 + FUNCTION = 4 class ParameterizedNode: """Any model that uses parameters that can be set by constants, - functions, or other parameterized models. ParameterizedNode can - include physical objects or statistical distributions. + functions, or other parameterized nodes. Attributes ---------- @@ -43,6 +38,36 @@ def __str__(self): """Return the string representation of the model.""" return "ParameterizedNode" + def check_resample(self, child): + """Check if we need to resample the current node based + on the state of a child trying to access its attributes. + + Parameters + ---------- + child : `ParameterizedNode` + The node that is accessing the attribute or method + of the current node. + + Returns + ------- + bool + Indicates whether to resample or not. + + Raises + ------ + ``ValueError`` if the graph has gotten out of sync. + """ + if child == self: + return False + if child.sample_iteration == self.sample_iteration: + return False + if child.sample_iteration != self.sample_iteration + 1: + raise ValueError( + f"Node {str(child)} at iteration {child.sample_iteration} accessing" + f" parent {str(self)} at iteration {self.sample_iteration}." + ) + return True + def set_parameter(self, name, value=None, **kwargs): """Set a single *existing* parameter to the ParameterizedNode. @@ -76,28 +101,26 @@ def set_parameter(self, name, value=None, **kwargs): # The value wasn't set, but the name is in kwargs. value = kwargs[name] - if isinstance(value, types.FunctionType): - # Case 1a: If we are getting from a static function, sample it. - self.setters[name] = (ParameterSource.FUNCTION, value, required) - setattr(self, name, value(**kwargs)) - elif isinstance(value, TDFunc): - # Case 1b: We are using a TDFunc wrapped function (with default parameters). - self.setters[name] = (ParameterSource.TDFUNC_OBJ, value, required) - setattr(self, name, value(**kwargs)) - elif isinstance(value, types.MethodType) and isinstance(value.__self__, ParameterizedNode): - # Case 2: We are trying to use the method from a ParameterizedNode. - # Note that this will (correctly) fail if we are adding a model method from the current - # object that requires an unset attribute. - self.setters[name] = (ParameterSource.MODEL_METHOD, value, required) + if callable(value): + if isinstance(value, types.FunctionType): + # Case 1a: This is a static function (not attached to an object). + self.setters[name] = (ParameterSource.FUNCTION, value, required) + elif isinstance(value.__self__, ParameterizedNode): + # Case 1b: This is a method attached to another ParameterizedNode. + self.setters[name] = (ParameterSource.MODEL_METHOD, value, required) + else: + # Case 1c: This is a general callable method from another object. + # We treat it as static (we don't resample the other object). + self.setters[name] = (ParameterSource.FUNCTION, value, required) setattr(self, name, value(**kwargs)) elif isinstance(value, ParameterizedNode): - # Case 3: We are trying to access an attribute from a parameterized model. + # Case 2: We are trying to access a parameter of another ParameterizedNode. if not hasattr(value, name): raise ValueError(f"Attribute {name} missing from parent.") self.setters[name] = (ParameterSource.MODEL_ATTRIBUTE, value, required) setattr(self, name, getattr(value, name)) else: - # Case 4: The value is constant (including None). + # Case 3: The value is constant (including None). self.setters[name] = (ParameterSource.CONSTANT, value, required) setattr(self, name, value) @@ -162,6 +185,9 @@ def sample_parameters(self, max_depth=50, **kwargs): if max_depth == 0: raise ValueError(f"Maximum sampling depth exceeded at {self}. Potential infinite loop.") + # Increase the sampling iteration. + self.sample_iteration += 1 + # Run through each parameter and sample it based on the given recipe. # As of Python 3.7 dictionaries are guaranteed to preserve insertion ordering, # so this will iterate through attributes in the order they were inserted. @@ -171,24 +197,77 @@ def sample_parameters(self, max_depth=50, **kwargs): sampled_value = setter elif source_type == ParameterSource.FUNCTION: sampled_value = setter(**kwargs) - elif source_type == ParameterSource.TDFUNC_OBJ: - sampled_value = setter(**kwargs) elif source_type == ParameterSource.MODEL_ATTRIBUTE: - # Check if we need to resample the parent (needs to be done before - # we read its attribute). - if setter.sample_iteration == self.sample_iteration: + # Check if we need to resample the parent (before accessing the attribute). + if setter.check_resample(self): setter.sample_parameters(max_depth - 1, **kwargs) sampled_value = getattr(setter, name) elif source_type == ParameterSource.MODEL_METHOD: - # Check if we need to resample the parent (needs to be done before - # we evaluate its method). Do not resample the current object. - parent = setter.__self__ - if parent is not self and parent.sample_iteration == self.sample_iteration: - parent.sample_parameters(max_depth - 1, **kwargs) + # Check if we need to resample the parent (before calling the method). + parent_node = setter.__self__ + if parent_node.check_resample(self): + parent_node.sample_parameters(max_depth - 1, **kwargs) sampled_value = setter(**kwargs) else: raise ValueError(f"Unknown ParameterSource type {source_type} for {name}") setattr(self, name, sampled_value) - # Increase the sampling iteration. - self.sample_iteration += 1 + +class FunctionNode(ParameterizedNode): + """A class to wrap functions and their argument settings. + + Attributes + ---------- + func : `function` or `method` + The function to call during an evaluation. + args_names : `list` + A list of argument names to pass to the function. + + Examples + -------- + my_func = TDFunc(random.randint, a=1, b=10) + value1 = my_func() # Sample from default range + value2 = my_func(b=20) # Sample from extended range + + Note + ---- + All the function's parameters that will be used need to be specified + in either the default_args dict, object_args list, or as a kwarg in the + constructor. Arguments cannot be first given during function call. + For example the following will fail (because b is not defined in the + constructor): + + my_func = TDFunc(random.randint, a=1) + value1 = my_func(b=10.0) + """ + + def __init__(self, func, **kwargs): + super().__init__(**kwargs) + self.func = func + self.arg_names = [] + + # Add all of the parameters from default_args or the kwargs. + for key, value in kwargs.items(): + self.arg_names.append(key) + self.add_parameter(key, value) + + def __str__(self): + """Return the string representation of the function.""" + return f"FunctionNode({self.func.name})" + + def compute(self, **kwargs): + """Execute the wrapped function. + + Parameters + ---------- + **kwargs : `dict`, optional + Additional function arguments. + """ + args = {} + for key in self.arg_names: + # Override with the kwarg if the parameter is there. + if key in kwargs: + args[key] = kwargs[key] + else: + args[key] = getattr(self, key) + return self.func(**args) diff --git a/src/tdastro/function_wrappers.py b/src/tdastro/function_wrappers.py deleted file mode 100644 index 5a7aad6e..00000000 --- a/src/tdastro/function_wrappers.py +++ /dev/null @@ -1,80 +0,0 @@ -"""Utilities to wrap functions for inclusion in an evaluation graph.""" - -import copy - - -class TDFunc: - """A class to wrap functions and their argument settings. - - Attributes - ---------- - func : `function` or `method` - The function to call during an evaluation. - default_args : `dict` - A dictionary of default arguments to pass to the function. Assembled - from the ```default_args`` parameter and additional ``kwargs``. - setter_functions : `dict` - A dictionary mapping arguments names to functions, methods, or - TDFunc objects used to set that argument. These are evaluated dynamically - each time. - - Examples - -------- - my_func = TDFunc(random.randint, a=1, b=10) - value1 = my_func() # Sample from default range - value2 = my_func(b=20) # Sample from extended range - - Note - ---- - All the function's parameters that will be used need to be specified - in either the default_args dict, object_args list, or as a kwarg in the - constructor. Arguments cannot be first given during function call. - For example the following will fail (because b is not defined in the - constructor): - - my_func = TDFunc(random.randint, a=1) - value1 = my_func(b=10.0) - """ - - def __init__(self, func, **kwargs): - self.func = func - self.default_args = {} - self.setter_functions = {} - - for key, value in kwargs.items(): - if callable(value): - self.default_args[key] = None - self.setter_functions[key] = value - else: - self.default_args[key] = value - - def __str__(self): - """Return the string representation of the function.""" - return f"TDFunc({self.func.name})" - - def __call__(self, **kwargs): - """Execute the wrapped function. - - Parameters - ---------- - **kwargs : `dict`, optional - Additional function arguments. - """ - # Start with the default arguments. We make a copy so we can modify the dictionary. - args = copy.copy(self.default_args) - - # If there are arguments to get from the calling functions, set those. - for key, value in self.setter_functions.items(): - if isinstance(value, TDFunc): - args[key] = value(**kwargs) - else: - args[key] = value() - - # Set any last arguments from the kwargs (overwriting previous settings). - # Only use known kwargs. - for key, value in kwargs.items(): - if key in self.default_args: - args[key] = value - - # Call the function with all the parameters. - return self.func(**args) diff --git a/src/tdastro/sources/physical_model.py b/src/tdastro/sources/physical_model.py index 8cf4946e..c463bb70 100644 --- a/src/tdastro/sources/physical_model.py +++ b/src/tdastro/sources/physical_model.py @@ -45,7 +45,8 @@ def __init__(self, ra=None, dec=None, redshift=None, distance=None, background=N if distance is not None: self.add_parameter("distance", distance) elif redshift is not None and kwargs.get("cosmology", None) is not None: - self.add_parameter("distance", RedshiftDistFunc(redshift=self.get_redshift, **kwargs)) + self._redshift_func = RedshiftDistFunc(redshift=self, **kwargs) + self.add_parameter("distance", self._redshift_func.compute) else: self.add_parameter("distance", None) @@ -56,10 +57,6 @@ def __str__(self): """Return the string representation of the model.""" return "PhysicalModel" - def get_redshift(self): - """Return the redshift for the model.""" - return self.redshift - def add_effect(self, effect, allow_dups=True, **kwargs): """Add a transformational effect to the PhysicalModel. Effects are applied in the order in which they are added. @@ -158,10 +155,11 @@ def sample_parameters(self, include_effects=True, **kwargs): All the keyword arguments, including the values needed to sample parameters. """ - if self.background is not None: + if self.background is not None and self.background.check_resample(self): self.background.sample_parameters(include_effects, **kwargs) super().sample_parameters(**kwargs) if include_effects: for effect in self.effects: - effect.sample_parameters(**kwargs) + if effect.check_resample(self): + effect.sample_parameters(**kwargs) diff --git a/tests/tdastro/populations/test_fixed_population.py b/tests/tdastro/populations/test_fixed_population.py index 95b0e0ae..ff56b20a 100644 --- a/tests/tdastro/populations/test_fixed_population.py +++ b/tests/tdastro/populations/test_fixed_population.py @@ -2,8 +2,8 @@ import numpy as np import pytest +from tdastro.base_models import FunctionNode from tdastro.effects.white_noise import WhiteNoise -from tdastro.function_wrappers import TDFunc from tdastro.populations.fixed_population import FixedPopulation from tdastro.sources.static_source import StaticSource @@ -88,9 +88,10 @@ def test_fixed_population_sample_sources(): def test_fixed_population_sample_fluxes(): """Test that we can create a population of sources and sample its sources' flux.""" random.seed(1001) + brightness_func = FunctionNode(random.uniform, a=0.0, b=100.0) population = FixedPopulation() population.add_source(StaticSource(brightness=100.0), 10.0) - population.add_source(StaticSource(brightness=TDFunc(random.uniform, a=0.0, b=100.0)), 10.0) + population.add_source(StaticSource(brightness=brightness_func.compute), 10.0) population.add_source(StaticSource(brightness=200.0), 20.0) population.add_source(StaticSource(brightness=150.0), 10.0) diff --git a/tests/tdastro/sources/test_galaxy_models.py b/tests/tdastro/sources/test_galaxy_models.py index e127494d..312777f8 100644 --- a/tests/tdastro/sources/test_galaxy_models.py +++ b/tests/tdastro/sources/test_galaxy_models.py @@ -55,7 +55,11 @@ def test_gaussian_galaxy() -> None: # Check that if we resample the source it will propagate and correctly resample the host. # the host's (RA, dec) should change and the source's should still be close. + print(f"Host sample = {host.sample_iteration}") + print(f"Source sample = {source.sample_iteration}") source.sample_parameters() + print(f"Host sample = {host.sample_iteration}") + print(f"Source sample = {source.sample_iteration}") assert host_ra != host.ra assert host_dec != host.dec diff --git a/tests/tdastro/sources/test_step_source.py b/tests/tdastro/sources/test_step_source.py index 325e73f2..899a716f 100644 --- a/tests/tdastro/sources/test_step_source.py +++ b/tests/tdastro/sources/test_step_source.py @@ -1,7 +1,7 @@ import random import numpy as np -from tdastro.function_wrappers import TDFunc +from tdastro.base_models import FunctionNode from tdastro.sources.static_source import StaticSource from tdastro.sources.step_source import StepSource @@ -58,9 +58,9 @@ def test_step_source_resample() -> None: random.seed(1111) model = StepSource( - brightness=TDFunc(_sample_brightness, magnitude=100.0), + brightness=FunctionNode(_sample_brightness, magnitude=100.0).compute, t0=0.0, - t1=TDFunc(_sample_end, duration=5.0), + t1=FunctionNode(_sample_end, duration=5.0).compute, ) num_samples = 1000 diff --git a/tests/tdastro/test_base_models.py b/tests/tdastro/test_base_models.py index 591b88b6..b0fb51e5 100644 --- a/tests/tdastro/test_base_models.py +++ b/tests/tdastro/test_base_models.py @@ -1,7 +1,8 @@ import random +import numpy as np import pytest -from tdastro.base_models import ParameterizedNode +from tdastro.base_models import FunctionNode, ParameterizedNode def _sampler_fun(**kwargs): @@ -15,6 +16,19 @@ def _sampler_fun(**kwargs): return random.random() +def _test_func(value1, value2): + """Return the sum of the two parameters. + + Parameters + ---------- + value1 : `float` + The first parameter. + value2 : `float` + The second parameter. + """ + return value1 + value2 + + class PairModel(ParameterizedNode): """A test class for the parameterized model. @@ -45,6 +59,10 @@ def __init__(self, value1, value2, **kwargs): self.add_parameter("value2", value2, required=True, **kwargs) self.add_parameter("value_sum", self.result, required=True, **kwargs) + def get_value1(self): + """Get the value of value1.""" + return self.value1 + def result(self, **kwargs): """Add the pair of values together @@ -141,3 +159,60 @@ def test_parameterized_model_modify() -> None: # We cannot set a value that hasn't been added. with pytest.raises(KeyError): model.set_parameter("brightness", 5.0) + + +def test_function_node_basic(): + """Test that we can create and query a FunctionNode.""" + my_func = FunctionNode(_test_func, value1=1.0, value2=2.0) + assert my_func.compute() == 3.0 + assert my_func.compute(value2=3.0) == 4.0 + assert my_func.compute(value2=3.0, unused_param=5.0) == 4.0 + assert my_func.compute(value2=3.0, value1=1.0) == 4.0 + + +def test_function_node_chain(): + """Test that we can create and query a chained FunctionNode.""" + func1 = FunctionNode(_test_func, value1=1.0, value2=1.0) + func2 = FunctionNode(_test_func, value1=func1.compute, value2=3.0) + assert func2.compute() == 5.0 + + +def test_np_sampler_method(): + """Test that we can wrap numpy random functions.""" + rng = np.random.default_rng(1001) + my_func = FunctionNode(rng.normal, loc=10.0, scale=1.0) + + # Sample 1000 times with the default values. Check that we are near + # the expected mean and not everything is equal. + vals = np.array([my_func.compute() for _ in range(1000)]) + assert abs(np.mean(vals) - 10.0) < 1.0 + assert not np.all(vals == vals[0]) + + # Override the mean and resample. + vals = np.array([my_func.compute(loc=25.0) for _ in range(1000)]) + assert abs(np.mean(vals) - 25.0) < 1.0 + assert not np.all(vals == vals[0]) + + +def test_function_node_obj(): + """Test that we can create and query a FunctionNode that depends on + another ParameterizedNode. + """ + # The model depends on the function. + func = FunctionNode(_test_func, value1=5.0, value2=6.0) + model = PairModel(value1=10.0, value2=func.compute) + assert model.result() == 21.0 + + # Function depends on the model's value2 attribute. + model = PairModel(value1=-10.0, value2=17.5) + func = FunctionNode(_test_func, value1=5.0, value2=model) + assert model.result() == 7.5 + assert func.compute() == 22.5 + + # Function depends on the model's get_value1() method. + func = FunctionNode(_test_func, value1=model.get_value1, value2=5.0) + assert model.result() == 7.5 + assert func.compute() == -5.0 + + # We can always override the attributes with kwargs. + assert func.compute(value1=1.0, value2=4.0) == 5.0 diff --git a/tests/tdastro/test_function_wrappers.py b/tests/tdastro/test_function_wrappers.py deleted file mode 100644 index e7048391..00000000 --- a/tests/tdastro/test_function_wrappers.py +++ /dev/null @@ -1,109 +0,0 @@ -import numpy as np -import pytest -from tdastro.function_wrappers import TDFunc - - -def _test_func(a, b): - """Return the sum of the two parameters. - - Parameters - ---------- - a : `float` - The first parameter. - b : `float` - The second parameter. - """ - return a + b - - -class _StaticModel: - """A test model that has given parameters. - - Attributes - ---------- - a : `float` - The first parameter. - b : `float` - The second parameter. - """ - - def __init__(self, a, b): - self.a = a - self.b = b - - def get_a(self): - """Get the a attribute.""" - return self.a - - def get_b(self): - """Get the b attribute.""" - return self.b - - -def test_tdfunc_basic(): - """Test that we can create and query a TDFunc.""" - # Fail without enough arguments (only a is specified). - tdf1 = TDFunc(_test_func, a=1.0) - with pytest.raises(TypeError): - _ = tdf1() - - # We succeed with a manually specified parameter (but first fail with - # the default). - tdf2 = TDFunc(_test_func, a=1.0, b=None) - with pytest.raises(TypeError): - _ = tdf2() - assert tdf2(b=2.0) == 3.0 - - # We can overwrite parameters. - assert tdf2(a=3.0, b=2.0) == 5.0 - - # Test that we ignore extra kwargs. - assert tdf2(b=2.0, c=10.0, d=11.0) == 3.0 - - # That we can use a different ordering for parameters. - tdf3 = TDFunc(_test_func, b=2.0, a=1.0) - assert tdf3() == 3.0 - - -def test_tdfunc_chain(): - """Test that we can create and query a chained TDFunc.""" - tdf1 = TDFunc(_test_func, a=1.0, b=1.0) - tdf2 = TDFunc(_test_func, a=tdf1, b=3.0) - assert tdf2() == 5.0 - - # This will overwrite all the b parameters. - assert tdf2(b=10.0) == 21.0 - - -def test_np_sampler_method(): - """Test that we can wrap numpy random functions.""" - rng = np.random.default_rng(1001) - tdf = TDFunc(rng.normal, loc=10.0, scale=1.0) - - # Sample 1000 times with the default values. Check that we are near - # the expected mean and not everything is equal. - vals = np.array([tdf() for _ in range(1000)]) - assert abs(np.mean(vals) - 10.0) < 1.0 - assert not np.all(vals == vals[0]) - - # Override the mean and resample. - vals = np.array([tdf(loc=25.0) for _ in range(1000)]) - assert abs(np.mean(vals) - 25.0) < 1.0 - assert not np.all(vals == vals[0]) - - -def test_tdfunc_obj(): - """Test that we can create and query a TDFunc that depends on an object.""" - model = _StaticModel(a=10.0, b=11.0) - - # Check a function without defaults. - tdf1 = TDFunc(_test_func, a=model.get_a, b=model.get_b) - assert tdf1() == 21.0 - - # We can pull from multiple models. - model2 = _StaticModel(a=1.0, b=0.0) - tdf2 = TDFunc(_test_func, a=model.get_a, b=model2.get_b) - assert tdf2() == 10.0 - - # W can overwrite everything with kwargs. - assert tdf2(b=7.5) == 17.5 From ae8349d77b0a01d4af2e8cedb7c47a861a590261 Mon Sep 17 00:00:00 2001 From: Jeremy Kubica <104161096+jeremykubica@users.noreply.github.com> Date: Mon, 15 Jul 2024 08:38:05 -0400 Subject: [PATCH 23/31] Reorganize the base classes --- src/tdastro/base_models.py | 321 +----------------- src/tdastro/effects/effect_model.py | 47 +++ src/tdastro/effects/white_noise.py | 2 +- src/tdastro/populations/fixed_population.py | 2 +- src/tdastro/populations/population_model.py | 106 ++++++ src/tdastro/sources/galaxy_models.py | 2 +- src/tdastro/sources/periodic_source.py | 2 +- src/tdastro/sources/physical_model.py | 149 ++++++++ src/tdastro/sources/sncomso_models.py | 2 +- src/tdastro/sources/spline_model.py | 2 +- src/tdastro/sources/static_source.py | 2 +- tests/tdastro/sources/test_physical_models.py | 10 + tests/tdastro/test_base_models.py | 16 +- 13 files changed, 341 insertions(+), 322 deletions(-) create mode 100644 src/tdastro/effects/effect_model.py create mode 100644 src/tdastro/populations/population_model.py create mode 100644 src/tdastro/sources/physical_model.py create mode 100644 tests/tdastro/sources/test_physical_models.py diff --git a/src/tdastro/base_models.py b/src/tdastro/base_models.py index 0bd53b07..1cf1fbfc 100644 --- a/src/tdastro/base_models.py +++ b/src/tdastro/base_models.py @@ -1,10 +1,8 @@ -"""The base models used throughout TDAstro including physical objects, effects, and populations.""" +"""The base models used to specify the TDAstro computation graph.""" import types from enum import Enum -import numpy as np - class ParameterSource(Enum): """ParameterSource specifies where a PhysicalModel should get the value @@ -18,10 +16,9 @@ class ParameterSource(Enum): MODEL_METHOD = 4 -class ParameterizedModel: +class ParameterizedNode: """Any model that uses parameters that can be set by constants, - functions, or other parameterized models. ParameterizedModels can - include physical objects or statistical distributions. + functions, or other parameterized nodes. Attributes ---------- @@ -39,11 +36,11 @@ def __init__(self, **kwargs): self.sample_iteration = 0 def __str__(self): - """Return the string representation of the model.""" - return "ParameterizedModel" + """Return the string representation of the node.""" + return "ParameterizedNode" def set_parameter(self, name, value=None, **kwargs): - """Set a single *existing* parameter to the ParameterizedModel. + """Set a single *existing* parameter to the ParameterizedNode. Notes ----- @@ -56,7 +53,7 @@ def set_parameter(self, name, value=None, **kwargs): The parameter name to add. value : any, optional The information to use to set the parameter. Can be a constant, - function, ParameterizedModel, or self. + function, ParameterizedNode, or self. **kwargs : `dict`, optional All other keyword arguments, possibly including the parameter setters. @@ -80,14 +77,14 @@ def set_parameter(self, name, value=None, **kwargs): # Case 1: If we are getting from a static function, sample it. self.setters[name] = (ParameterSource.FUNCTION, value, required) setattr(self, name, value(**kwargs)) - elif isinstance(value, types.MethodType) and isinstance(value.__self__, ParameterizedModel): - # Case 2: We are trying to use the method from a ParameterizedModel. + elif isinstance(value, types.MethodType) and isinstance(value.__self__, ParameterizedNode): + # Case 2: We are trying to use the method from a ParameterizedNode. # Note that this will (correctly) fail if we are adding a model method from the current # object that requires an unset attribute. self.setters[name] = (ParameterSource.MODEL_METHOD, value, required) setattr(self, name, value(**kwargs)) - elif isinstance(value, ParameterizedModel): - # Case 3: We are trying to access an attribute from a parameterized model. + elif isinstance(value, ParameterizedNode): + # Case 3: We are trying to access an attribute from a ParameterizedNode. if not hasattr(value, name): raise ValueError(f"Attribute {name} missing from parent.") self.setters[name] = (ParameterSource.MODEL_ATTRIBUTE, value, required) @@ -103,7 +100,7 @@ def set_parameter(self, name, value=None, **kwargs): raise ValueError(f"Missing required parameter {name}") def add_parameter(self, name, value=None, required=False, **kwargs): - """Add a single *new* parameter to the ParameterizedModel. + """Add a single *new* parameter to the ParameterizedNode. Notes ----- @@ -118,7 +115,7 @@ def add_parameter(self, name, value=None, required=False, **kwargs): The parameter name to add. value : any, optional The information to use to set the parameter. Can be a constant, - function, ParameterizedModel, or self. + function, ParameterizedNode, or self. required : `bool` Fail if the parameter is set to ``None``. **kwargs : `dict`, optional @@ -141,7 +138,7 @@ def add_parameter(self, name, value=None, required=False, **kwargs): def sample_parameters(self, max_depth=50, **kwargs): """Sample the model's underlying parameters if they are provided by a function - or ParameterizedModel. + or ParameterizedNode. Parameters ---------- @@ -188,293 +185,3 @@ def sample_parameters(self, max_depth=50, **kwargs): # Increase the sampling iteration. self.sample_iteration += 1 - - -class PhysicalModel(ParameterizedModel): - """A physical model of a source of flux. - - Physical models can have fixed attributes (where you need to create a new model - to change them) and settable attributes that can be passed functions or constants. - They can also have special background pointers that link to another PhysicalModel - producing flux. We can chain these to have a supernova in front of a star in front - of a static background. - - Attributes - ---------- - ra : `float` - The object's right ascension (in degrees) - dec : `float` - The object's declination (in degrees) - distance : `float` - The object's distance (in pc) - background : `PhysicalModel` - A source of background flux such as a host galaxy. - effects : `list` - A list of effects to apply to an observations. - """ - - def __init__(self, ra=None, dec=None, distance=None, background=None, **kwargs): - super().__init__(**kwargs) - self.effects = [] - - # Set RA, dec, and distance from the parameters. - self.add_parameter("ra", ra) - self.add_parameter("dec", dec) - self.add_parameter("distance", distance) - - # Background is an object not a sampled parameter - self.background = background - - def __str__(self): - """Return the string representation of the model.""" - return "PhysicalModel" - - def add_effect(self, effect, allow_dups=True, **kwargs): - """Add a transformational effect to the PhysicalModel. - Effects are applied in the order in which they are added. - - Parameters - ---------- - effect : `EffectModel` - The effect to apply. - allow_dups : `bool` - Allow multiple effects of the same type. - Default = ``True`` - **kwargs : `dict`, optional - Any additional keyword arguments. - - Raises - ------ - Raises a ``AttributeError`` if the PhysicalModel does not have all of the - required attributes. - """ - # Check that we have not added this effect before. - if not allow_dups: - effect_type = type(effect) - for prev in self.effects: - if effect_type == type(prev): - raise ValueError("Added the effect type to a model {effect_type} more than once.") - - required: list = effect.required_parameters() - for parameter in required: - # Raise an AttributeError if the parameter is missing or set to None. - if getattr(self, parameter) is None: - raise AttributeError(f"Parameter {parameter} unset for model {type(self).__name__}") - - self.effects.append(effect) - - def _evaluate(self, times, wavelengths, **kwargs): - """Draw effect-free observations for this object. - - Parameters - ---------- - times : `numpy.ndarray` - A length T array of timestamps. - wavelengths : `numpy.ndarray`, optional - A length N array of wavelengths. - **kwargs : `dict`, optional - Any additional keyword arguments. - - Returns - ------- - flux_density : `numpy.ndarray` - A length T x N matrix of SED values. - """ - raise NotImplementedError() - - def evaluate(self, times, wavelengths, resample_parameters=False, **kwargs): - """Draw observations for this object and apply the noise. - - Parameters - ---------- - times : `numpy.ndarray` - A length T array of timestamps. - wavelengths : `numpy.ndarray`, optional - A length N array of wavelengths. - resample_parameters : `bool` - Treat this evaluation as a completely new object, resampling the - parameters from the original provided functions. - **kwargs : `dict`, optional - Any additional keyword arguments. - - Returns - ------- - flux_density : `numpy.ndarray` - A length T x N matrix of SED values. - """ - if resample_parameters: - self.sample_parameters(kwargs) - - # Compute the flux density for both the current object and add in anything - # behind it, such as a host galaxy. - flux_density = self._evaluate(times, wavelengths, **kwargs) - if self.background is not None: - flux_density += self.background._evaluate(times, wavelengths, ra=self.ra, dec=self.dec, **kwargs) - - for effect in self.effects: - flux_density = effect.apply(flux_density, wavelengths, self, **kwargs) - return flux_density - - def sample_parameters(self, include_effects=True, **kwargs): - """Sample the model's underlying parameters if they are provided by a function - or ParameterizedModel. - - Parameters - ---------- - include_effects : `bool` - Resample the parameters for the effects models. - **kwargs : `dict`, optional - All the keyword arguments, including the values needed to sample - parameters. - """ - if self.background is not None: - self.background.sample_parameters(include_effects, **kwargs) - super().sample_parameters(**kwargs) - - if include_effects: - for effect in self.effects: - effect.sample_parameters(**kwargs) - - -class EffectModel(ParameterizedModel): - """A physical or systematic effect to apply to an observation.""" - - def __init__(self, **kwargs): - super().__init__(**kwargs) - - def __str__(self): - """Return the string representation of the model.""" - return "EffectModel" - - def required_parameters(self): - """Returns a list of the parameters of a PhysicalModel - that this effect needs to access. - - Returns - ------- - parameters : `list` of `str` - A list of every required parameter the effect needs. - """ - return [] - - def apply(self, flux_density, wavelengths=None, physical_model=None, **kwargs): - """Apply the effect to observations (flux_density values) - - Parameters - ---------- - flux_density : `numpy.ndarray` - A length T X N matrix of flux density values. - wavelengths : `numpy.ndarray`, optional - A length N array of wavelengths. - physical_model : `PhysicalModel` - A PhysicalModel from which the effect may query parameters - such as redshift, position, or distance. - **kwargs : `dict`, optional - Any additional keyword arguments. - - Returns - ------- - flux_density : `numpy.ndarray` - A length T x N matrix of flux densities after the effect is applied. - """ - raise NotImplementedError() - - -class PopulationModel(ParameterizedModel): - """A model of a population of PhysicalModels. - - Attributes - ---------- - num_sources : `int` - The number of different sources in the population. - sources : `list` - A list of sources from which to draw. - """ - - def __init__(self, rng=None, **kwargs): - super().__init__(**kwargs) - self.num_sources = 0 - self.sources = [] - - def __str__(self): - """Return the string representation of the model.""" - return f"PopulationModel with {self.num_sources} sources." - - def add_source(self, new_source, **kwargs): - """Add a new source to the population. - - Parameters - ---------- - new_source : `PhysicalModel` - A source from the population. - **kwargs : `dict`, optional - Any additional keyword arguments. - """ - if not isinstance(new_source, PhysicalModel): - raise ValueError("All sources must be PhysicalModels") - self.sources.append(new_source) - self.num_sources += 1 - - def draw_source(self): - """Sample a single source from the population. - - Returns - ------- - source : `PhysicalModel` - A source from the population. - """ - raise NotImplementedError() - - def add_effect(self, effect, allow_dups=False, **kwargs): - """Add a transformational effect to all PhysicalModels in this population. - Effects are applied in the order in which they are added. - - Parameters - ---------- - effect : `EffectModel` - The effect to apply. - allow_dups : `bool` - Allow multiple effects of the same type. - Default = ``True`` - **kwargs : `dict`, optional - Any additional keyword arguments. - - Raises - ------ - Raises a ``AttributeError`` if the PhysicalModel does not have all of the - required attributes. - """ - for source in self.sources: - source.add_effect(effect, allow_dups=allow_dups, **kwargs) - - def evaluate(self, samples, times, wavelengths, resample_parameters=False, **kwargs): - """Draw observations from a single (randomly sampled) source. - - Parameters - ---------- - samples : `int` - The number of sources to samples. - times : `numpy.ndarray` - A length T array of timestamps. - wavelengths : `numpy.ndarray`, optional - A length N array of wavelengths. - resample_parameters : `bool` - Treat this evaluation as a completely new object, resampling the - parameters from the original provided functions. - **kwargs : `dict`, optional - Any additional keyword arguments. - - Returns - ------- - results : `numpy.ndarray` - A shape (samples, T, N) matrix of SED values. - """ - if samples <= 0: - raise ValueError("The number of samples must be > 0.") - - results = [] - for _ in range(samples): - source = self.draw_source() - object_fluxes = source.evaluate(times, wavelengths, resample_parameters, **kwargs) - results.append(object_fluxes) - return np.array(results) diff --git a/src/tdastro/effects/effect_model.py b/src/tdastro/effects/effect_model.py new file mode 100644 index 00000000..d5b80bf8 --- /dev/null +++ b/src/tdastro/effects/effect_model.py @@ -0,0 +1,47 @@ +"""The base EffectModel class used for all effects.""" + +from tdastro.base_models import ParameterizedNode + + +class EffectModel(ParameterizedNode): + """A physical or systematic effect to apply to an observation.""" + + def __init__(self, **kwargs): + super().__init__(**kwargs) + + def __str__(self): + """Return the string representation of the model.""" + return "EffectModel" + + def required_parameters(self): + """Returns a list of the parameters of a PhysicalModel + that this effect needs to access. + + Returns + ------- + parameters : `list` of `str` + A list of every required parameter the effect needs. + """ + return [] + + def apply(self, flux_density, wavelengths=None, physical_model=None, **kwargs): + """Apply the effect to observations (flux_density values) + + Parameters + ---------- + flux_density : `numpy.ndarray` + A length T X N matrix of flux density values. + wavelengths : `numpy.ndarray`, optional + A length N array of wavelengths. + physical_model : `PhysicalModel` + A PhysicalModel from which the effect may query parameters + such as redshift, position, or distance. + **kwargs : `dict`, optional + Any additional keyword arguments. + + Returns + ------- + flux_density : `numpy.ndarray` + A length T x N matrix of flux densities after the effect is applied. + """ + raise NotImplementedError() diff --git a/src/tdastro/effects/white_noise.py b/src/tdastro/effects/white_noise.py index 49f2e621..b7d3a0a7 100644 --- a/src/tdastro/effects/white_noise.py +++ b/src/tdastro/effects/white_noise.py @@ -1,6 +1,6 @@ import numpy as np -from tdastro.base_models import EffectModel +from tdastro.effects.effect_model import EffectModel class WhiteNoise(EffectModel): diff --git a/src/tdastro/populations/fixed_population.py b/src/tdastro/populations/fixed_population.py index bc644b4a..18f3f0e9 100644 --- a/src/tdastro/populations/fixed_population.py +++ b/src/tdastro/populations/fixed_population.py @@ -1,6 +1,6 @@ import random -from tdastro.base_models import PopulationModel +from tdastro.populations.population_model import PopulationModel class FixedPopulation(PopulationModel): diff --git a/src/tdastro/populations/population_model.py b/src/tdastro/populations/population_model.py new file mode 100644 index 00000000..c527038a --- /dev/null +++ b/src/tdastro/populations/population_model.py @@ -0,0 +1,106 @@ +"""The base population models.""" + +import numpy as np + +from tdastro.base_models import ParameterizedNode +from tdastro.sources.physical_model import PhysicalModel + + +class PopulationModel(ParameterizedNode): + """A model of a population of PhysicalModels. + + Attributes + ---------- + num_sources : `int` + The number of different sources in the population. + sources : `list` + A list of sources from which to draw. + """ + + def __init__(self, rng=None, **kwargs): + super().__init__(**kwargs) + self.num_sources = 0 + self.sources = [] + + def __str__(self): + """Return the string representation of the model.""" + return f"PopulationModel with {self.num_sources} sources." + + def add_source(self, new_source, **kwargs): + """Add a new source to the population. + + Parameters + ---------- + new_source : `PhysicalModel` + A source from the population. + **kwargs : `dict`, optional + Any additional keyword arguments. + """ + if not isinstance(new_source, PhysicalModel): + raise ValueError("All sources must be PhysicalModels") + self.sources.append(new_source) + self.num_sources += 1 + + def draw_source(self): + """Sample a single source from the population. + + Returns + ------- + source : `PhysicalModel` + A source from the population. + """ + raise NotImplementedError() + + def add_effect(self, effect, allow_dups=False, **kwargs): + """Add a transformational effect to all PhysicalModels in this population. + Effects are applied in the order in which they are added. + + Parameters + ---------- + effect : `EffectModel` + The effect to apply. + allow_dups : `bool` + Allow multiple effects of the same type. + Default = ``True`` + **kwargs : `dict`, optional + Any additional keyword arguments. + + Raises + ------ + Raises a ``AttributeError`` if the PhysicalModel does not have all of the + required attributes. + """ + for source in self.sources: + source.add_effect(effect, allow_dups=allow_dups, **kwargs) + + def evaluate(self, samples, times, wavelengths, resample_parameters=False, **kwargs): + """Draw observations from a single (randomly sampled) source. + + Parameters + ---------- + samples : `int` + The number of sources to samples. + times : `numpy.ndarray` + A length T array of timestamps. + wavelengths : `numpy.ndarray`, optional + A length N array of wavelengths. + resample_parameters : `bool` + Treat this evaluation as a completely new object, resampling the + parameters from the original provided functions. + **kwargs : `dict`, optional + Any additional keyword arguments. + + Returns + ------- + results : `numpy.ndarray` + A shape (samples, T, N) matrix of SED values. + """ + if samples <= 0: + raise ValueError("The number of samples must be > 0.") + + results = [] + for _ in range(samples): + source = self.draw_source() + object_fluxes = source.evaluate(times, wavelengths, resample_parameters, **kwargs) + results.append(object_fluxes) + return np.array(results) diff --git a/src/tdastro/sources/galaxy_models.py b/src/tdastro/sources/galaxy_models.py index 9aea8bfc..eadb853a 100644 --- a/src/tdastro/sources/galaxy_models.py +++ b/src/tdastro/sources/galaxy_models.py @@ -1,7 +1,7 @@ import numpy as np from astropy.coordinates import angular_separation -from tdastro.base_models import PhysicalModel +from tdastro.sources.physical_model import PhysicalModel class GaussianGalaxy(PhysicalModel): diff --git a/src/tdastro/sources/periodic_source.py b/src/tdastro/sources/periodic_source.py index b68c02da..9b3cd81f 100644 --- a/src/tdastro/sources/periodic_source.py +++ b/src/tdastro/sources/periodic_source.py @@ -1,6 +1,6 @@ from abc import ABC, abstractmethod -from tdastro.base_models import PhysicalModel +from tdastro.sources.physical_model import PhysicalModel class PeriodicSource(PhysicalModel, ABC): diff --git a/src/tdastro/sources/physical_model.py b/src/tdastro/sources/physical_model.py new file mode 100644 index 00000000..0b8af424 --- /dev/null +++ b/src/tdastro/sources/physical_model.py @@ -0,0 +1,149 @@ +"""The base PhysicalModel used for all sources.""" + +from tdastro.base_models import ParameterizedNode + + +class PhysicalModel(ParameterizedNode): + """A physical model of a source of flux. + + Physical models can have fixed attributes (where you need to create a new model + to change them) and settable attributes that can be passed functions or constants. + They can also have special background pointers that link to another PhysicalModel + producing flux. We can chain these to have a supernova in front of a star in front + of a static background. + + Attributes + ---------- + ra : `float` + The object's right ascension (in degrees) + dec : `float` + The object's declination (in degrees) + distance : `float` + The object's distance (in pc). + background : `PhysicalModel` + A source of background flux such as a host galaxy. + effects : `list` + A list of effects to apply to an observations. + """ + + def __init__(self, ra=None, dec=None, distance=None, background=None, **kwargs): + super().__init__(**kwargs) + self.effects = [] + + # Set RA, dec, and redshift from the parameters. + self.add_parameter("ra", ra) + self.add_parameter("dec", dec) + self.add_parameter("distance", distance) + + # Background is an object not a sampled parameter + self.background = background + + def __str__(self): + """Return the string representation of the model.""" + return "PhysicalModel" + + def add_effect(self, effect, allow_dups=True, **kwargs): + """Add a transformational effect to the PhysicalModel. + Effects are applied in the order in which they are added. + + Parameters + ---------- + effect : `EffectModel` + The effect to apply. + allow_dups : `bool` + Allow multiple effects of the same type. + Default = ``True`` + **kwargs : `dict`, optional + Any additional keyword arguments. + + Raises + ------ + Raises a ``AttributeError`` if the PhysicalModel does not have all of the + required attributes. + """ + # Check that we have not added this effect before. + if not allow_dups: + effect_type = type(effect) + for prev in self.effects: + if effect_type == type(prev): + raise ValueError("Added the effect type to a model {effect_type} more than once.") + + required: list = effect.required_parameters() + for parameter in required: + # Raise an AttributeError if the parameter is missing or set to None. + if getattr(self, parameter) is None: + raise AttributeError(f"Parameter {parameter} unset for model {type(self).__name__}") + + self.effects.append(effect) + + def _evaluate(self, times, wavelengths, **kwargs): + """Draw effect-free observations for this object. + + Parameters + ---------- + times : `numpy.ndarray` + A length T array of timestamps. + wavelengths : `numpy.ndarray`, optional + A length N array of wavelengths. + **kwargs : `dict`, optional + Any additional keyword arguments. + + Returns + ------- + flux_density : `numpy.ndarray` + A length T x N matrix of SED values. + """ + raise NotImplementedError() + + def evaluate(self, times, wavelengths, resample_parameters=False, **kwargs): + """Draw observations for this object and apply the noise. + + Parameters + ---------- + times : `numpy.ndarray` + A length T array of timestamps. + wavelengths : `numpy.ndarray`, optional + A length N array of wavelengths. + resample_parameters : `bool` + Treat this evaluation as a completely new object, resampling the + parameters from the original provided functions. + **kwargs : `dict`, optional + Any additional keyword arguments. + + Returns + ------- + flux_density : `numpy.ndarray` + A length T x N matrix of SED values. + """ + if resample_parameters: + self.sample_parameters(kwargs) + + # Compute the flux density for both the current object and add in anything + # behind it, such as a host galaxy. + flux_density = self._evaluate(times, wavelengths, **kwargs) + if self.background is not None: + flux_density += self.background._evaluate(times, wavelengths, ra=self.ra, dec=self.dec, **kwargs) + + for effect in self.effects: + flux_density = effect.apply(flux_density, wavelengths, self, **kwargs) + return flux_density + + def sample_parameters(self, include_effects=True, **kwargs): + """Sample the model's underlying parameters if they are provided by a function + or ParameterizedModel. + + Parameters + ---------- + include_effects : `bool` + Resample the parameters for the effects models. + **kwargs : `dict`, optional + All the keyword arguments, including the values needed to sample + parameters. + """ + if self.background is not None: + self.background.sample_parameters(include_effects, **kwargs) + super().sample_parameters(**kwargs) + + if include_effects: + for effect in self.effects: + effect.sample_parameters(**kwargs) diff --git a/src/tdastro/sources/sncomso_models.py b/src/tdastro/sources/sncomso_models.py index 379de41e..b176233e 100644 --- a/src/tdastro/sources/sncomso_models.py +++ b/src/tdastro/sources/sncomso_models.py @@ -6,7 +6,7 @@ from sncosmo.models import get_source -from tdastro.base_models import PhysicalModel +from tdastro.sources.physical_model import PhysicalModel class SncosmoWrapperModel(PhysicalModel): diff --git a/src/tdastro/sources/spline_model.py b/src/tdastro/sources/spline_model.py index eff26dff..67e79202 100644 --- a/src/tdastro/sources/spline_model.py +++ b/src/tdastro/sources/spline_model.py @@ -7,7 +7,7 @@ from scipy.interpolate import RectBivariateSpline -from tdastro.base_models import PhysicalModel +from tdastro.sources.physical_model import PhysicalModel class SplineModel(PhysicalModel): diff --git a/src/tdastro/sources/static_source.py b/src/tdastro/sources/static_source.py index 44799c2e..f69eea89 100644 --- a/src/tdastro/sources/static_source.py +++ b/src/tdastro/sources/static_source.py @@ -1,6 +1,6 @@ import numpy as np -from tdastro.base_models import PhysicalModel +from tdastro.sources.physical_model import PhysicalModel class StaticSource(PhysicalModel): diff --git a/tests/tdastro/sources/test_physical_models.py b/tests/tdastro/sources/test_physical_models.py new file mode 100644 index 00000000..72950ded --- /dev/null +++ b/tests/tdastro/sources/test_physical_models.py @@ -0,0 +1,10 @@ +from tdastro.sources.physical_model import PhysicalModel + + +def test_physical_model(): + """Test that we can create a PhysicalModel.""" + # Everything is specified. + model1 = PhysicalModel(ra=1.0, dec=2.0, distance=3.0) + assert model1.ra == 1.0 + assert model1.dec == 2.0 + assert model1.distance == 3.0 diff --git a/tests/tdastro/test_base_models.py b/tests/tdastro/test_base_models.py index b5ff1daa..c33322ac 100644 --- a/tests/tdastro/test_base_models.py +++ b/tests/tdastro/test_base_models.py @@ -1,7 +1,7 @@ import random import pytest -from tdastro.base_models import ParameterizedModel +from tdastro.base_models import ParameterizedNode def _sampler_fun(**kwargs): @@ -15,8 +15,8 @@ def _sampler_fun(**kwargs): return random.random() -class PairModel(ParameterizedModel): - """A test class for the parameterized model. +class PairModel(ParameterizedNode): + """A test class for the ParameterizedNode. Attributes ---------- @@ -33,9 +33,9 @@ def __init__(self, value1, value2, **kwargs): Parameters ---------- - value1 : `float`, `function`, `ParameterizedModel`, or `None` + value1 : `float`, `function`, `ParameterizedNode`, or `None` The first value. - value2 : `float`, `function`, `ParameterizedModel`, or `None` + value2 : `float`, `function`, `ParameterizedNode`, or `None` The second value. **kwargs : `dict`, optional Any additional keyword arguments. @@ -61,7 +61,7 @@ def result(self, **kwargs): return self.value1 + self.value2 -def test_parameterized_model() -> None: +def test_parameterized_node() -> None: """Test that we can sample and create a PairModel object.""" # Simple addition model1 = PairModel(value1=0.5, value2=0.5) @@ -123,8 +123,8 @@ def test_parameterized_model() -> None: assert model1.sample_iteration == model4.sample_iteration -def test_parameterized_model_modify() -> None: - """Test that we can modify the parameters in a model.""" +def test_parameterized_node_modify() -> None: + """Test that we can modify the parameters in a node.""" model = PairModel(value1=0.5, value2=0.5) assert model.value1 == 0.5 assert model.value2 == 0.5 From de036aa2b1ce94e028ff284922bcf5527c3c9285 Mon Sep 17 00:00:00 2001 From: Jeremy Kubica <104161096+jeremykubica@users.noreply.github.com> Date: Mon, 15 Jul 2024 09:22:31 -0400 Subject: [PATCH 24/31] Fix merge --- tests/tdastro/sources/test_galaxy_models.py | 4 ---- tests/tdastro/sources/test_physical_models.py | 2 +- 2 files changed, 1 insertion(+), 5 deletions(-) diff --git a/tests/tdastro/sources/test_galaxy_models.py b/tests/tdastro/sources/test_galaxy_models.py index 312777f8..e127494d 100644 --- a/tests/tdastro/sources/test_galaxy_models.py +++ b/tests/tdastro/sources/test_galaxy_models.py @@ -55,11 +55,7 @@ def test_gaussian_galaxy() -> None: # Check that if we resample the source it will propagate and correctly resample the host. # the host's (RA, dec) should change and the source's should still be close. - print(f"Host sample = {host.sample_iteration}") - print(f"Source sample = {source.sample_iteration}") source.sample_parameters() - print(f"Host sample = {host.sample_iteration}") - print(f"Source sample = {source.sample_iteration}") assert host_ra != host.ra assert host_dec != host.dec diff --git a/tests/tdastro/sources/test_physical_models.py b/tests/tdastro/sources/test_physical_models.py index cf18bf1f..cd91e3e3 100644 --- a/tests/tdastro/sources/test_physical_models.py +++ b/tests/tdastro/sources/test_physical_models.py @@ -4,7 +4,7 @@ def test_physical_model(): - """Test that we can create a physical model.""" + """Test that we can create a PhysicalModel.""" # Everything is specified. model1 = PhysicalModel(ra=1.0, dec=2.0, distance=3.0, redshift=0.0) assert model1.ra == 1.0 From 2b2230a41f5ff508be3629a37774f2051f70dd46 Mon Sep 17 00:00:00 2001 From: Jeremy Kubica <104161096+jeremykubica@users.noreply.github.com> Date: Mon, 15 Jul 2024 09:32:06 -0400 Subject: [PATCH 25/31] Small fixes --- src/tdastro/base_models.py | 23 ++++++++++++----------- tests/tdastro/test_base_models.py | 2 +- 2 files changed, 13 insertions(+), 12 deletions(-) diff --git a/src/tdastro/base_models.py b/src/tdastro/base_models.py index 262913da..cc67c8bb 100644 --- a/src/tdastro/base_models.py +++ b/src/tdastro/base_models.py @@ -10,9 +10,9 @@ class ParameterSource(Enum): """ CONSTANT = 1 - MODEL_ATTRIBUTE = 2 - MODEL_METHOD = 3 - FUNCTION = 4 + FUNCTION = 2 + MODEL_ATTRIBUTE = 3 + MODEL_METHOD = 4 class ParameterizedNode: @@ -35,16 +35,17 @@ def __init__(self, **kwargs): self.sample_iteration = 0 def __str__(self): - """Return the string representation of the model.""" + """Return the string representation of the node.""" return "ParameterizedNode" - def check_resample(self, child): + def check_resample(self, other): """Check if we need to resample the current node based - on the state of a child trying to access its attributes. + on the state of another node trying to access its attributes + or methods. Parameters ---------- - child : `ParameterizedNode` + other : `ParameterizedNode` The node that is accessing the attribute or method of the current node. @@ -57,13 +58,13 @@ def check_resample(self, child): ------ ``ValueError`` if the graph has gotten out of sync. """ - if child == self: + if other == self: return False - if child.sample_iteration == self.sample_iteration: + if other.sample_iteration == self.sample_iteration: return False - if child.sample_iteration != self.sample_iteration + 1: + if other.sample_iteration != self.sample_iteration + 1: raise ValueError( - f"Node {str(child)} at iteration {child.sample_iteration} accessing" + f"Node {str(other)} at iteration {other.sample_iteration} accessing" f" parent {str(self)} at iteration {self.sample_iteration}." ) return True diff --git a/tests/tdastro/test_base_models.py b/tests/tdastro/test_base_models.py index 8c29647e..1fa953e5 100644 --- a/tests/tdastro/test_base_models.py +++ b/tests/tdastro/test_base_models.py @@ -30,7 +30,7 @@ def _test_func(value1, value2): class PairModel(ParameterizedNode): - """A test class for the parameterized model. + """A test class for the ParameterizedNode. Attributes ---------- From e021f6159b68697231930e2f129682a994e10f21 Mon Sep 17 00:00:00 2001 From: Jeremy Kubica <104161096+jeremykubica@users.noreply.github.com> Date: Mon, 15 Jul 2024 10:34:50 -0400 Subject: [PATCH 26/31] Add a function for getting the current parameter settings --- src/tdastro/base_models.py | 50 +++++++++++++++++++- src/tdastro/effects/effect_model.py | 4 -- src/tdastro/effects/white_noise.py | 4 -- src/tdastro/populations/fixed_population.py | 4 -- src/tdastro/populations/population_model.py | 4 -- src/tdastro/sources/galaxy_models.py | 4 -- src/tdastro/sources/physical_model.py | 4 -- src/tdastro/sources/sncomso_models.py | 4 -- src/tdastro/sources/spline_model.py | 5 -- src/tdastro/sources/static_source.py | 4 -- src/tdastro/sources/step_source.py | 4 -- tests/tdastro/sources/test_sncosmo_models.py | 2 +- tests/tdastro/sources/test_spline_source.py | 6 +-- tests/tdastro/sources/test_static_source.py | 5 +- tests/tdastro/sources/test_step_source.py | 1 - tests/tdastro/test_base_models.py | 33 +++++++++++++ 16 files changed, 88 insertions(+), 50 deletions(-) diff --git a/src/tdastro/base_models.py b/src/tdastro/base_models.py index 1cf1fbfc..c99e40d2 100644 --- a/src/tdastro/base_models.py +++ b/src/tdastro/base_models.py @@ -22,6 +22,8 @@ class ParameterizedNode: Attributes ---------- + node_identifier : `str` + An identifier (or name) for the current node. setters : `dict` of `tuple` A dictionary to information about the setters for the parameters in the form: (ParameterSource, setter information, required). The attributes are @@ -31,13 +33,17 @@ class ParameterizedNode: model's parameters have been resampled. """ - def __init__(self, **kwargs): + def __init__(self, node_identifier=None, **kwargs): self.setters = {} self.sample_iteration = 0 + self.node_identifier = node_identifier def __str__(self): """Return the string representation of the node.""" - return "ParameterizedNode" + if self.node_identifier: + return f"{self.node_identifier}={self.__class__.__name__}" + else: + return self.__class__.__name__ def set_parameter(self, name, value=None, **kwargs): """Set a single *existing* parameter to the ParameterizedNode. @@ -185,3 +191,43 @@ def sample_parameters(self, max_depth=50, **kwargs): # Increase the sampling iteration. self.sample_iteration += 1 + + def get_all_parameter_values(self, recursive=True, seen=None): + """Get the values of the current parameters and (optionally) those of + all their dependencies. + + Effectively snapshots the state of the execution graph. + + Parameters + ---------- + seen : `set` + A set of objects that have already been processed. + recursive : `bool` + Recursively extract the attribute setting of this object's dependencies. + + Returns + ------- + values : `dict` + The dictionary mapping the combination of the object identifier and + attribute name to its value. + """ + # Make sure that we do not process the same nodes multiple times. + if seen is None: + seen = set() + if self in seen: + return {} + seen.add(self) + + values = {} + for name, (source_type, setter, _) in self.setters.items(): + if recursive: + if source_type == ParameterSource.MODEL_ATTRIBUTE: + values.update(setter.get_all_parameter_values(True, seen)) + elif source_type == ParameterSource.MODEL_METHOD: + values.update(setter.__self__.get_all_parameter_values(True, seen)) + + full_name = f"{str(self)}.{name}" + else: + full_name = name + values[full_name] = getattr(self, name) + return values diff --git a/src/tdastro/effects/effect_model.py b/src/tdastro/effects/effect_model.py index d5b80bf8..83885bda 100644 --- a/src/tdastro/effects/effect_model.py +++ b/src/tdastro/effects/effect_model.py @@ -9,10 +9,6 @@ class EffectModel(ParameterizedNode): def __init__(self, **kwargs): super().__init__(**kwargs) - def __str__(self): - """Return the string representation of the model.""" - return "EffectModel" - def required_parameters(self): """Returns a list of the parameters of a PhysicalModel that this effect needs to access. diff --git a/src/tdastro/effects/white_noise.py b/src/tdastro/effects/white_noise.py index b7d3a0a7..f9191a8f 100644 --- a/src/tdastro/effects/white_noise.py +++ b/src/tdastro/effects/white_noise.py @@ -16,10 +16,6 @@ def __init__(self, scale, **kwargs): super().__init__(**kwargs) self.add_parameter("scale", scale, required=True, **kwargs) - def __str__(self): - """Return the string representation of the model.""" - return f"WhiteNoise({self.scale})" - def apply(self, flux_density, wavelengths=None, physical_model=None, **kwargs): """Apply the effect to observations (flux_density values) diff --git a/src/tdastro/populations/fixed_population.py b/src/tdastro/populations/fixed_population.py index 18f3f0e9..cb8e808f 100644 --- a/src/tdastro/populations/fixed_population.py +++ b/src/tdastro/populations/fixed_population.py @@ -18,10 +18,6 @@ def __init__(self, **kwargs): super().__init__(**kwargs) self.weights = [] - def __str__(self): - """Return the string representation of the model.""" - return f"FixedPopulation({self.probability})" - def add_source(self, new_source, rate, **kwargs): """Add a new source to the population. diff --git a/src/tdastro/populations/population_model.py b/src/tdastro/populations/population_model.py index c527038a..e4a966e9 100644 --- a/src/tdastro/populations/population_model.py +++ b/src/tdastro/populations/population_model.py @@ -22,10 +22,6 @@ def __init__(self, rng=None, **kwargs): self.num_sources = 0 self.sources = [] - def __str__(self): - """Return the string representation of the model.""" - return f"PopulationModel with {self.num_sources} sources." - def add_source(self, new_source, **kwargs): """Add a new source to the population. diff --git a/src/tdastro/sources/galaxy_models.py b/src/tdastro/sources/galaxy_models.py index eadb853a..3471bdd9 100644 --- a/src/tdastro/sources/galaxy_models.py +++ b/src/tdastro/sources/galaxy_models.py @@ -21,10 +21,6 @@ def __init__(self, brightness, radius, **kwargs): self.add_parameter("galaxy_radius_std", radius, required=True, **kwargs) self.add_parameter("brightness", brightness, required=True, **kwargs) - def __str__(self): - """Return the string representation of the model.""" - return f"GuassianGalaxy({self.brightness}, {self.galaxy_radius_std})" - def sample_ra(self): """Sample an right ascension coordinate based on the center and radius of the galaxy. diff --git a/src/tdastro/sources/physical_model.py b/src/tdastro/sources/physical_model.py index 0b8af424..de300643 100644 --- a/src/tdastro/sources/physical_model.py +++ b/src/tdastro/sources/physical_model.py @@ -38,10 +38,6 @@ def __init__(self, ra=None, dec=None, distance=None, background=None, **kwargs): # Background is an object not a sampled parameter self.background = background - def __str__(self): - """Return the string representation of the model.""" - return "PhysicalModel" - def add_effect(self, effect, allow_dups=True, **kwargs): """Add a transformational effect to the PhysicalModel. Effects are applied in the order in which they are added. diff --git a/src/tdastro/sources/sncomso_models.py b/src/tdastro/sources/sncomso_models.py index b176233e..037434a9 100644 --- a/src/tdastro/sources/sncomso_models.py +++ b/src/tdastro/sources/sncomso_models.py @@ -32,10 +32,6 @@ def __init__(self, source_name, **kwargs): self.source_name = source_name self.source = get_source(source_name) - def __str__(self): - """Return the string representation of the model.""" - return f"SncosmoWrapperModel({self.source_name})" - @property def param_names(self): """Return a list of the model's parameter names.""" diff --git a/src/tdastro/sources/spline_model.py b/src/tdastro/sources/spline_model.py index 67e79202..6336b121 100644 --- a/src/tdastro/sources/spline_model.py +++ b/src/tdastro/sources/spline_model.py @@ -67,15 +67,10 @@ def __init__( # These parameters are directly set, because they cannot be changed once # the object is created. - self.name = name self._times = times self._wavelengths = wavelengths self._spline = RectBivariateSpline(times, wavelengths, flux, kx=time_degree, ky=wave_degree) - def __str__(self): - """Return the string representation of the model.""" - return f"SplineModel({self.name})" - def _evaluate(self, times, wavelengths, **kwargs): """Draw effect-free observations for this object. diff --git a/src/tdastro/sources/static_source.py b/src/tdastro/sources/static_source.py index f69eea89..25584a15 100644 --- a/src/tdastro/sources/static_source.py +++ b/src/tdastro/sources/static_source.py @@ -16,10 +16,6 @@ def __init__(self, brightness, **kwargs): super().__init__(**kwargs) self.add_parameter("brightness", brightness, required=True, **kwargs) - def __str__(self): - """Return the string representation of the model.""" - return f"StaticSource({self.brightness})" - def _evaluate(self, times, wavelengths, **kwargs): """Draw effect-free observations for this object. diff --git a/src/tdastro/sources/step_source.py b/src/tdastro/sources/step_source.py index 156bd60a..7488f32d 100644 --- a/src/tdastro/sources/step_source.py +++ b/src/tdastro/sources/step_source.py @@ -21,10 +21,6 @@ def __init__(self, brightness, t0, t1, **kwargs): self.add_parameter("t0", t0, required=True, **kwargs) self.add_parameter("t1", t1, required=True, **kwargs) - def __str__(self): - """Return the string representation of the model.""" - return f"StepSource({self.brightness})_{self.t0}_to_{self.t1}" - def _evaluate(self, times, wavelengths, **kwargs): """Draw effect-free observations for this object. diff --git a/tests/tdastro/sources/test_sncosmo_models.py b/tests/tdastro/sources/test_sncosmo_models.py index 58590903..e3c67e80 100644 --- a/tests/tdastro/sources/test_sncosmo_models.py +++ b/tests/tdastro/sources/test_sncosmo_models.py @@ -7,7 +7,7 @@ def test_sncomso_models_hsiao() -> None: model = SncosmoWrapperModel("hsiao") model.set(amplitude=1.0e-10) assert model.amplitude == 1.0e-10 - assert str(model) == "SncosmoWrapperModel(hsiao)" + assert str(model) == "SncosmoWrapperModel" assert np.array_equal(model.param_names, ["amplitude"]) assert np.array_equal(model.parameters, [1.0e-10]) diff --git a/tests/tdastro/sources/test_spline_source.py b/tests/tdastro/sources/test_spline_source.py index c91d7503..5b30bc87 100644 --- a/tests/tdastro/sources/test_spline_source.py +++ b/tests/tdastro/sources/test_spline_source.py @@ -8,7 +8,7 @@ def test_spline_model_flat() -> None: wavelengths = np.linspace(100.0, 500.0, 25) fluxes = np.full((len(times), len(wavelengths)), 1.0) model = SplineModel(times, wavelengths, fluxes) - assert str(model) == "SplineModel(None)" + assert str(model) == "SplineModel" test_times = np.array([0.0, 1.0, 2.0, 3.0, 10.0]) test_waves = np.array([0.0, 100.0, 200.0, 1000.0]) @@ -18,8 +18,8 @@ def test_spline_model_flat() -> None: expected = np.full_like(values, 1.0) np.testing.assert_array_almost_equal(values, expected) - model2 = SplineModel(times, wavelengths, fluxes, amplitude=5.0, name="test") - assert str(model2) == "SplineModel(test)" + model2 = SplineModel(times, wavelengths, fluxes, amplitude=5.0, node_identifier="test") + assert str(model2) == "test=SplineModel" values2 = model2.evaluate(test_times, test_waves) assert values2.shape == (5, 4) diff --git a/tests/tdastro/sources/test_static_source.py b/tests/tdastro/sources/test_static_source.py index 420ff749..2dacd3b0 100644 --- a/tests/tdastro/sources/test_static_source.py +++ b/tests/tdastro/sources/test_static_source.py @@ -19,12 +19,12 @@ def _sampler_fun(magnitude, **kwargs): def test_static_source() -> None: """Test that we can sample and create a StaticSource object.""" - model = StaticSource(brightness=10.0) + model = StaticSource(brightness=10.0, node_identifier="my_static_source") assert model.brightness == 10.0 assert model.ra is None assert model.dec is None assert model.distance is None - assert str(model) == "StaticSource(10.0)" + assert str(model) == "my_static_source=StaticSource" times = np.array([1, 2, 3, 4, 5, 10]) wavelengths = np.array([100.0, 200.0, 300.0]) @@ -49,6 +49,7 @@ def test_static_source_host() -> None: assert model.ra == 1.0 assert model.dec == 2.0 assert model.distance == 3.0 + assert str(model) == "StaticSource" def test_static_source_resample() -> None: diff --git a/tests/tdastro/sources/test_step_source.py b/tests/tdastro/sources/test_step_source.py index dc069241..d0f530f4 100644 --- a/tests/tdastro/sources/test_step_source.py +++ b/tests/tdastro/sources/test_step_source.py @@ -41,7 +41,6 @@ def test_step_source() -> None: assert model.ra == 1.0 assert model.dec == 2.0 assert model.distance == 3.0 - assert str(model) == "StepSource(15.0)_1.0_to_2.0" times = np.array([0.0, 1.0, 2.0, 3.0]) wavelengths = np.array([100.0, 200.0]) diff --git a/tests/tdastro/test_base_models.py b/tests/tdastro/test_base_models.py index c33322ac..15d4aa5f 100644 --- a/tests/tdastro/test_base_models.py +++ b/tests/tdastro/test_base_models.py @@ -123,6 +123,39 @@ def test_parameterized_node() -> None: assert model1.sample_iteration == model4.sample_iteration +def test_parameterized_node_attributes() -> None: + """Test that we can extract the attributes of a graph of ParameterizedNode.""" + model1 = PairModel(value1=0.5, value2=1.5, node_identifier="1") + settings = model1.get_all_parameter_values(False) + assert len(settings) == 3 + assert settings["value1"] == 0.5 + assert settings["value2"] == 1.5 + assert settings["value_sum"] == 2.0 + + settings = model1.get_all_parameter_values(True) + assert len(settings) == 3 + assert settings["1=PairModel.value1"] == 0.5 + assert settings["1=PairModel.value2"] == 1.5 + assert settings["1=PairModel.value_sum"] == 2.0 + + # Use value1=model.value and value2=1.0 + model2 = PairModel(value1=model1, value2=3.0, node_identifier="2") + settings = model2.get_all_parameter_values(False) + assert len(settings) == 3 + assert settings["value1"] == 0.5 + assert settings["value2"] == 3.0 + assert settings["value_sum"] == 3.5 + + settings = model2.get_all_parameter_values(True) + assert len(settings) == 6 + assert settings["1=PairModel.value1"] == 0.5 + assert settings["1=PairModel.value2"] == 1.5 + assert settings["1=PairModel.value_sum"] == 2.0 + assert settings["2=PairModel.value1"] == 0.5 + assert settings["2=PairModel.value2"] == 3.0 + assert settings["2=PairModel.value_sum"] == 3.5 + + def test_parameterized_node_modify() -> None: """Test that we can modify the parameters in a node.""" model = PairModel(value1=0.5, value2=0.5) From bc4848a54696bd4d243e4486b464106fb6f7d069 Mon Sep 17 00:00:00 2001 From: Jeremy Kubica <104161096+jeremykubica@users.noreply.github.com> Date: Mon, 15 Jul 2024 10:43:08 -0400 Subject: [PATCH 27/31] Address PR comments. --- src/tdastro/astro_utils/cosmology.py | 4 ++-- src/tdastro/sources/physical_model.py | 6 +++--- tests/tdastro/astro_utils/test_cosmology.py | 5 ++++- tests/tdastro/sources/test_physical_models.py | 9 ++++----- 4 files changed, 13 insertions(+), 11 deletions(-) diff --git a/src/tdastro/astro_utils/cosmology.py b/src/tdastro/astro_utils/cosmology.py index 1beecf70..b36022b1 100644 --- a/src/tdastro/astro_utils/cosmology.py +++ b/src/tdastro/astro_utils/cosmology.py @@ -5,7 +5,7 @@ def redshift_to_distance(redshift, cosmology, kind="comoving"): - """Compute a source's distance given its redshift and a + """Compute a source's luminosity distance given its redshift and a specified cosmology using astropy's redshift_distance(). Parameters @@ -21,7 +21,7 @@ def redshift_to_distance(redshift, cosmology, kind="comoving"): Returns ------- distance : `float` - The distance (in pc) + The luminosity distance (in pc) """ z = redshift * cu.redshift distance = z.to(u.pc, cu.redshift_distance(cosmology, kind=kind)) diff --git a/src/tdastro/sources/physical_model.py b/src/tdastro/sources/physical_model.py index c463bb70..fb4cb3de 100644 --- a/src/tdastro/sources/physical_model.py +++ b/src/tdastro/sources/physical_model.py @@ -22,7 +22,7 @@ class PhysicalModel(ParameterizedNode): redshift : `float` The object's redshift. distance : `float` - The object's distance (in pc). If the distance is not provided and + The object's luminosity distance (in pc). If no value is provided and a ``cosmology`` parameter is given, the model will try to derive from the redshift and the cosmology. background : `PhysicalModel` @@ -40,8 +40,8 @@ def __init__(self, ra=None, dec=None, redshift=None, distance=None, background=N self.add_parameter("dec", dec) self.add_parameter("redshift", redshift) - # If the distance is provided, use that. Otherwise try the redshift value - # using the cosmology (if given). Finally, default to None. + # If the luminosity distance is provided, use that. Otherwise try the + # redshift value using the cosmology (if given). Finally, default to None. if distance is not None: self.add_parameter("distance", distance) elif redshift is not None and kwargs.get("cosmology", None) is not None: diff --git a/tests/tdastro/astro_utils/test_cosmology.py b/tests/tdastro/astro_utils/test_cosmology.py index 14a500e2..fff88c47 100644 --- a/tests/tdastro/astro_utils/test_cosmology.py +++ b/tests/tdastro/astro_utils/test_cosmology.py @@ -1,5 +1,5 @@ import pytest -from astropy.cosmology import WMAP9 +from astropy.cosmology import WMAP9, Planck18 from tdastro.astro_utils.cosmology import redshift_to_distance @@ -8,3 +8,6 @@ def test_redshift_to_distance(): # Use the example from: # https://docs.astropy.org/en/stable/api/astropy.cosmology.units.redshift_distance.html assert redshift_to_distance(1100, cosmology=WMAP9) == pytest.approx(14004.03157418 * 1e6) + + # Try the Planck18 cosmology. + assert redshift_to_distance(1100, cosmology=Planck18) == pytest.approx(13886.327957 * 1e6) diff --git a/tests/tdastro/sources/test_physical_models.py b/tests/tdastro/sources/test_physical_models.py index cd91e3e3..d05560e3 100644 --- a/tests/tdastro/sources/test_physical_models.py +++ b/tests/tdastro/sources/test_physical_models.py @@ -1,5 +1,5 @@ import pytest -from astropy.cosmology import WMAP9 +from astropy.cosmology import Planck18 from tdastro.sources.physical_model import PhysicalModel @@ -12,13 +12,12 @@ def test_physical_model(): assert model1.distance == 3.0 assert model1.redshift == 0.0 - # Derive the distance from the redshift using the example from: - # https://docs.astropy.org/en/stable/api/astropy.cosmology.units.redshift_distance.html - model2 = PhysicalModel(ra=1.0, dec=2.0, redshift=1100.0, cosmology=WMAP9) + # Derive the distance from the redshift: + model2 = PhysicalModel(ra=1.0, dec=2.0, redshift=1100.0, cosmology=Planck18) assert model2.ra == 1.0 assert model2.dec == 2.0 assert model2.redshift == 1100.0 - assert model2.distance == pytest.approx(14004.03157418 * 1e6) + assert model2.distance == pytest.approx(13886.327957 * 1e6) # Neither distance nor redshift are specified. model3 = PhysicalModel(ra=1.0, dec=2.0) From bd09398225e9278d3a72a66d001da2b3a27dd293 Mon Sep 17 00:00:00 2001 From: Jeremy Kubica <104161096+jeremykubica@users.noreply.github.com> Date: Mon, 15 Jul 2024 10:52:05 -0400 Subject: [PATCH 28/31] Address more PR comments Fix the distance kind as "luminosity" --- src/tdastro/astro_utils/cosmology.py | 13 +++---------- tests/tdastro/astro_utils/test_cosmology.py | 11 +++++------ tests/tdastro/sources/test_physical_models.py | 3 +-- 3 files changed, 9 insertions(+), 18 deletions(-) diff --git a/src/tdastro/astro_utils/cosmology.py b/src/tdastro/astro_utils/cosmology.py index b36022b1..57771220 100644 --- a/src/tdastro/astro_utils/cosmology.py +++ b/src/tdastro/astro_utils/cosmology.py @@ -4,7 +4,7 @@ from tdastro.base_models import FunctionNode -def redshift_to_distance(redshift, cosmology, kind="comoving"): +def redshift_to_distance(redshift, cosmology): """Compute a source's luminosity distance given its redshift and a specified cosmology using astropy's redshift_distance(). @@ -14,9 +14,6 @@ def redshift_to_distance(redshift, cosmology, kind="comoving"): The redshift value. cosmology : `astropy.cosmology` The cosmology specification. - kind : `str` - The distance type for the Equivalency as defined by - astropy.cosmology.units.redshift_distance. Returns ------- @@ -24,7 +21,7 @@ def redshift_to_distance(redshift, cosmology, kind="comoving"): The luminosity distance (in pc) """ z = redshift * cu.redshift - distance = z.to(u.pc, cu.redshift_distance(cosmology, kind=kind)) + distance = z.to(u.pc, cu.redshift_distance(cosmology, kind="luminosity")) return distance.value @@ -45,18 +42,14 @@ class RedshiftDistFunc(FunctionNode): The function or constant providing the redshift value. cosmology : `astropy.cosmology` The cosmology specification. - kind : `str` - The distance type for the Equivalency as defined by - astropy.cosmology.units.redshift_distance. """ - def __init__(self, redshift, cosmology, kind="comoving"): + def __init__(self, redshift, cosmology): # Call the super class's constructor with the needed information. super().__init__( func=redshift_to_distance, redshift=redshift, cosmology=cosmology, - kind=kind, ) def __str__(self): diff --git a/tests/tdastro/astro_utils/test_cosmology.py b/tests/tdastro/astro_utils/test_cosmology.py index fff88c47..9ed9559e 100644 --- a/tests/tdastro/astro_utils/test_cosmology.py +++ b/tests/tdastro/astro_utils/test_cosmology.py @@ -1,13 +1,12 @@ -import pytest from astropy.cosmology import WMAP9, Planck18 from tdastro.astro_utils.cosmology import redshift_to_distance def test_redshift_to_distance(): """Test that we can convert the redshift to a distance using a given cosmology.""" - # Use the example from: - # https://docs.astropy.org/en/stable/api/astropy.cosmology.units.redshift_distance.html - assert redshift_to_distance(1100, cosmology=WMAP9) == pytest.approx(14004.03157418 * 1e6) + wmap9_val = redshift_to_distance(1100, cosmology=WMAP9) + planck18_val = redshift_to_distance(1100, cosmology=Planck18) - # Try the Planck18 cosmology. - assert redshift_to_distance(1100, cosmology=Planck18) == pytest.approx(13886.327957 * 1e6) + assert abs(planck18_val - wmap9_val) > 1000.0 + assert 13.0 * 1e12 < wmap9_val < 16.0 * 1e12 + assert 13.0 * 1e12 < planck18_val < 16.0 * 1e12 diff --git a/tests/tdastro/sources/test_physical_models.py b/tests/tdastro/sources/test_physical_models.py index d05560e3..8098cf1b 100644 --- a/tests/tdastro/sources/test_physical_models.py +++ b/tests/tdastro/sources/test_physical_models.py @@ -1,4 +1,3 @@ -import pytest from astropy.cosmology import Planck18 from tdastro.sources.physical_model import PhysicalModel @@ -17,7 +16,7 @@ def test_physical_model(): assert model2.ra == 1.0 assert model2.dec == 2.0 assert model2.redshift == 1100.0 - assert model2.distance == pytest.approx(13886.327957 * 1e6) + assert 13.0 * 1e12 < model2.distance < 16.0 * 1e12 # Neither distance nor redshift are specified. model3 = PhysicalModel(ra=1.0, dec=2.0) From d0bebae999475721b458a65c7c1b6b6a2347d00d Mon Sep 17 00:00:00 2001 From: Jeremy Kubica <104161096+jeremykubica@users.noreply.github.com> Date: Mon, 15 Jul 2024 16:03:05 -0400 Subject: [PATCH 29/31] Update tests/tdastro/test_base_models.py Co-authored-by: Olivia R. Lynn --- tests/tdastro/test_base_models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/tdastro/test_base_models.py b/tests/tdastro/test_base_models.py index baeb8da3..1c2f855e 100644 --- a/tests/tdastro/test_base_models.py +++ b/tests/tdastro/test_base_models.py @@ -156,7 +156,7 @@ def test_parameterized_node_attributes() -> None: assert settings["1=PairModel.value2"] == 1.5 assert settings["1=PairModel.value_sum"] == 2.0 - # Use value1=model.value and value2=1.0 + # Use value1=model.value and value2=3.0 model2 = PairModel(value1=model1, value2=3.0, node_identifier="2") settings = model2.get_all_parameter_values(False) assert len(settings) == 3 From e44409127681d01d680f624d69d36da9e5c1dcb1 Mon Sep 17 00:00:00 2001 From: Jeremy Kubica <104161096+jeremykubica@users.noreply.github.com> Date: Tue, 16 Jul 2024 11:36:46 -0400 Subject: [PATCH 30/31] Add ability to overwrite existing attributes --- src/tdastro/base_models.py | 9 +++++++-- tests/tdastro/test_base_models.py | 18 ++++++++++++++++++ 2 files changed, 25 insertions(+), 2 deletions(-) diff --git a/src/tdastro/base_models.py b/src/tdastro/base_models.py index 5dcaeeec..4cc15c84 100644 --- a/src/tdastro/base_models.py +++ b/src/tdastro/base_models.py @@ -134,7 +134,7 @@ def set_parameter(self, name, value=None, **kwargs): if required and getattr(self, name) is None: raise ValueError(f"Missing required parameter {name}") - def add_parameter(self, name, value=None, required=False, **kwargs): + def add_parameter(self, name, value=None, required=False, allow_overwrite=False, **kwargs): """Add a single *new* parameter to the ParameterizedNode. Notes @@ -153,6 +153,11 @@ def add_parameter(self, name, value=None, required=False, **kwargs): function, ParameterizedNode, or self. required : `bool` Fail if the parameter is set to ``None``. + Default = ``False`` + allow_overwrite : `bool` + Allow a subclass to overwrite the definition of the attribute + used in the superclass. + Default = ``False`` **kwargs : `dict`, optional All other keyword arguments, possibly including the parameter setters. @@ -163,7 +168,7 @@ def add_parameter(self, name, value=None, required=False, **kwargs): Raise a ``ValueError`` if the parameter is required, but set to None. """ # Check for parameter collision. - if hasattr(self, name) and getattr(self, name) is not None: + if hasattr(self, name) and getattr(self, name) is not None and not allow_overwrite: raise KeyError(f"Duplicate parameter set: {name}") # Add an entry for the setter function and fill in the remaining diff --git a/tests/tdastro/test_base_models.py b/tests/tdastro/test_base_models.py index 1c2f855e..f474aea1 100644 --- a/tests/tdastro/test_base_models.py +++ b/tests/tdastro/test_base_models.py @@ -141,6 +141,24 @@ def test_parameterized_node() -> None: assert model1.sample_iteration == model4.sample_iteration +def test_parameterized_node_overwrite() -> None: + """Test that we can overwrite attributes in a PairModel.""" + model1 = PairModel(value1=0.5, value2=0.5) + assert model1.value1 == 0.5 + assert model1.value1 == 0.5 + assert model1.result() == 1.0 + assert model1.value_sum == 1.0 + assert model1.sample_iteration == 0 + + # By default the overwrite fails. + with pytest.raises(KeyError): + model1.add_parameter("value1", value=1.0) + + # We can force it with allow_overwrite=True. + model1.add_parameter("value1", value=1.0, allow_overwrite=True) + assert model1.value1 == 1.0 + + def test_parameterized_node_attributes() -> None: """Test that we can extract the attributes of a graph of ParameterizedNode.""" model1 = PairModel(value1=0.5, value2=1.5, node_identifier="1") From 61bfb8e9e846dd747255bd5a7ff626f1bff59b26 Mon Sep 17 00:00:00 2001 From: Jeremy Kubica <104161096+jeremykubica@users.noreply.github.com> Date: Tue, 16 Jul 2024 11:59:57 -0400 Subject: [PATCH 31/31] Update periodic_variable_star.py --- src/tdastro/sources/periodic_variable_star.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/tdastro/sources/periodic_variable_star.py b/src/tdastro/sources/periodic_variable_star.py index e75fd6f5..b4612123 100644 --- a/src/tdastro/sources/periodic_variable_star.py +++ b/src/tdastro/sources/periodic_variable_star.py @@ -21,10 +21,9 @@ class PeriodicVariableStar(PeriodicSource, ABC): The distance to the source, in pc. """ - def __init__(self, period, t0, **kwargs): - distance = kwargs.pop("distance", None) + def __init__(self, period, t0, distance, **kwargs): super().__init__(period, t0, **kwargs) - self.add_parameter("distance", value=distance, required=True, **kwargs) + self.add_parameter("distance", value=distance, required=True, allow_overwrite=True, **kwargs) def _evaluate_phases(self, phases, wavelengths, **kwargs): """Draw effect-free observations for this object, as a function of phase.