diff --git a/src/tdastro/base_models.py b/src/tdastro/base_models.py index 36b924fe..27a7ffcc 100644 --- a/src/tdastro/base_models.py +++ b/src/tdastro/base_models.py @@ -21,16 +21,17 @@ class ParameterizedModel: Attributes ---------- - setters : `dict` or `tuple` + setters : `list` of `tuple` A dictionary to information about the setters for the parameters in the form: - (ParameterSource, value). + (name, ParameterSource, setter information). The attributes are stored in the + order in which they need to be set. sample_iteration : `int` A counter used to syncronize sampling runs. Tracks how many times this model's parameters have been resampled. """ def __init__(self, **kwargs): - self.setters = {} + self.setters = [] self.sample_iteration = 0 def __str__(self): @@ -39,11 +40,10 @@ def __str__(self): def add_parameter(self, name, value=None, required=False, **kwargs): """Add a single parameter to the ParameterizedModel. Checks multiple sources - in the following order: - 1. Manually specified ``value`` - 2. An entry in ``kwargs`` - 3. ``None`` + in the following order: 1. Manually specified ``value``, 2. An entry in ``kwargs``, + or 3. ``None``. Sets an initial value for the attribute based on the given information. + The attributes are stored in the order in which they are added. Parameters ---------- @@ -72,24 +72,26 @@ def add_parameter(self, name, value=None, required=False, **kwargs): if value is not None: if isinstance(value, types.FunctionType): # Case 1: If we are getting from a static function, sample it. - self.setters[name] = (ParameterSource.FUNCTION, value) + self.setters.append((name, ParameterSource.FUNCTION, value)) setattr(self, name, value(**kwargs)) elif isinstance(value, types.MethodType) and isinstance(value.__self__, ParameterizedModel): # Case 2: We are trying to use the method from a ParameterizedModel. - self.setters[name] = (ParameterSource.MODEL_METHOD, value) + # Note that this will (correctly) fail if we are adding a model method from the current + # object that requires an unset attribute. + self.setters.append((name, ParameterSource.MODEL_METHOD, value)) setattr(self, name, value(**kwargs)) elif isinstance(value, ParameterizedModel): # Case 3: We are trying to access an attribute from a parameterized model. if not hasattr(value, name): raise ValueError(f"Attribute {name} missing from parent.") - self.setters[name] = (ParameterSource.MODEL_ATTRIBUTE, value) + self.setters.append((name, ParameterSource.MODEL_ATTRIBUTE, value)) setattr(self, name, getattr(value, name)) else: # Case 4: The value is constant. - self.setters[name] = (ParameterSource.CONSTANT, value) + self.setters.append((name, ParameterSource.CONSTANT, value)) setattr(self, name, value) elif not required: - self.setters[name] = (ParameterSource.CONSTANT, None) + self.setters.append((name, ParameterSource.CONSTANT, None)) setattr(self, name, None) else: raise ValueError(f"Missing required parameter {name}") @@ -102,7 +104,7 @@ def sample_parameters(self, max_depth=50, **kwargs): ---------- max_depth : `int` The maximum recursive depth. Used to prevent infinite loops. - Most users should not need to set this manually. + Users should not need to set this manually. **kwargs : `dict`, optional All the keyword arguments, including the values needed to sample parameters. @@ -116,28 +118,28 @@ def sample_parameters(self, max_depth=50, **kwargs): raise ValueError(f"Maximum sampling depth exceeded at {self}. Potential infinite loop.") # Run through each parameter and sample it based on the given recipe. - for param, value in self.setters.items(): + for name, source_type, setter in self.setters: sampled_value = None - if value[0] == ParameterSource.CONSTANT: - sampled_value = value[1] - elif value[0] == ParameterSource.FUNCTION: - sampled_value = value[1](**kwargs) - elif value[0] == ParameterSource.MODEL_ATTRIBUTE: + if source_type == ParameterSource.CONSTANT: + sampled_value = setter + elif source_type == ParameterSource.FUNCTION: + sampled_value = setter(**kwargs) + elif source_type == ParameterSource.MODEL_ATTRIBUTE: # Check if we need to resample the parent (needs to be done before # we read its attribute). - if value[1].sample_iteration == self.sample_iteration: - value[1].sample_parameters(max_depth - 1, **kwargs) - sampled_value = getattr(value[1], param) - elif value[0] == ParameterSource.MODEL_METHOD: + if setter.sample_iteration == self.sample_iteration: + setter.sample_parameters(max_depth - 1, **kwargs) + sampled_value = getattr(setter, name) + elif source_type == ParameterSource.MODEL_METHOD: # Check if we need to resample the parent (needs to be done before - # we evaluate its method). - parent = value[1].__self__ - if parent.sample_iteration == self.sample_iteration: + # we evaluate its method). Do not resample the current object. + parent = setter.__self__ + if parent is not self and parent.sample_iteration == self.sample_iteration: parent.sample_parameters(max_depth - 1, **kwargs) - sampled_value = value[1](**kwargs) + sampled_value = setter(**kwargs) else: - raise ValueError(f"Unknown ParameterSource type {value[0]}") - setattr(self, param, sampled_value) + raise ValueError(f"Unknown ParameterSource type {source_type} for {name}") + setattr(self, name, sampled_value) # Increase the sampling iteration. self.sample_iteration += 1 @@ -173,7 +175,7 @@ def __str__(self): """Return the string representation of the model.""" return "PhysicalModel" - def add_effect(self, effect): + def add_effect(self, effect, **kwargs): """Add a transformational effect to the PhysicalModel. Effects are applied in the order in which they are added. @@ -181,6 +183,8 @@ def add_effect(self, effect): ---------- effect : `EffectModel` The effect to apply. + **kwargs : `dict`, optional + Any additional keyword arguments. Raises ------ @@ -242,12 +246,33 @@ def evaluate(self, times, wavelengths, resample_parameters=False, **kwargs): flux_density = effect.apply(flux_density, wavelengths, self, **kwargs) return flux_density + def sample_parameters(self, include_effects=True, **kwargs): + """Sample the model's underlying parameters if they are provided by a function + or ParameterizedModel. + + Parameters + ---------- + include_effects : `bool` + Resample the parameters for the effects models. + **kwargs : `dict`, optional + All the keyword arguments, including the values needed to sample + parameters. + """ + super().sample_parameters(**kwargs) + if include_effects: + for effect in self.effects: + effect.sample_parameters(**kwargs) + -class EffectModel: +class EffectModel(ParameterizedModel): """A physical or systematic effect to apply to an observation.""" def __init__(self, **kwargs): - pass + super().__init__(**kwargs) + + def __str__(self): + """Return the string representation of the model.""" + return "EffectModel" def required_parameters(self): """Returns a list of the parameters of a PhysicalModel diff --git a/src/tdastro/effects/white_noise.py b/src/tdastro/effects/white_noise.py index 32f66153..49f2e621 100644 --- a/src/tdastro/effects/white_noise.py +++ b/src/tdastro/effects/white_noise.py @@ -14,17 +14,21 @@ class WhiteNoise(EffectModel): def __init__(self, scale, **kwargs): super().__init__(**kwargs) - self.scale = scale + self.add_parameter("scale", scale, required=True, **kwargs) - def apply(self, flux_density, bands=None, physical_model=None, **kwargs): + def __str__(self): + """Return the string representation of the model.""" + return f"WhiteNoise({self.scale})" + + def apply(self, flux_density, wavelengths=None, physical_model=None, **kwargs): """Apply the effect to observations (flux_density values) Parameters ---------- flux_density : `numpy.ndarray` An array of flux density values. - bands : `numpy.ndarray`, optional - An array of bands. + wavelengths : `numpy.ndarray`, optional + An array of wavelengths. physical_model : `PhysicalModel` A PhysicalModel from which the effect may query parameters such as redshift, position, or distance. diff --git a/tests/tdastro/effects/test_white_noise.py b/tests/tdastro/effects/test_white_noise.py index 22548524..ed46edd3 100644 --- a/tests/tdastro/effects/test_white_noise.py +++ b/tests/tdastro/effects/test_white_noise.py @@ -3,14 +3,14 @@ from tdastro.sources.static_source import StaticSource -def brightness_generator(): +def rand_generator(): """A test generator function.""" return 10.0 + 0.5 * np.random.rand(1) def test_white_noise() -> None: """Test that we can sample and create a WhiteNoise object.""" - model = StaticSource(brightness=brightness_generator) + model = StaticSource(brightness=rand_generator) model.add_effect(WhiteNoise(scale=0.01)) times = np.array([1, 2, 3, 5, 10]) @@ -20,3 +20,18 @@ def test_white_noise() -> None: assert values.shape == (5, 3) assert not np.all(values == 10.0) assert np.all(np.abs(values - 10.0) < 1.0) + + +def test_white_noise_random() -> None: + """Test that we can resample effects to change their parameters.""" + wn_effect = WhiteNoise(scale=rand_generator) + scale = wn_effect.scale + wn_effect.sample_parameters() + assert scale != wn_effect.scale + + # We can resample when it is added to a PhysicalObject. + model = StaticSource(brightness=10.0) + model.add_effect(WhiteNoise(scale=rand_generator)) + scale = model.effects[0].scale + model.sample_parameters() + assert model.effects[0].scale != scale diff --git a/tests/tdastro/test_base_models.py b/tests/tdastro/test_base_models.py index 909b654c..86b1291d 100644 --- a/tests/tdastro/test_base_models.py +++ b/tests/tdastro/test_base_models.py @@ -20,10 +20,12 @@ class PairModel(ParameterizedModel): Attributes ---------- - value1 : `float`, `function`, `ParameterizedModel`, or `None` + value1 : `float` The first value. - value2 : `float`, `function`, `ParameterizedModel`, or `None` + value2 : `float` The second value. + value_sum : `float` + The sum of the two values. """ def __init__(self, value1, value2, **kwargs): @@ -41,6 +43,7 @@ def __init__(self, value1, value2, **kwargs): super().__init__(**kwargs) self.add_parameter("value1", value1, required=True, **kwargs) self.add_parameter("value2", value2, required=True, **kwargs) + self.add_parameter("value_sum", self.result, required=True, **kwargs) def result(self, **kwargs): """Add the pair of values together @@ -48,7 +51,12 @@ def result(self, **kwargs): Parameters ---------- **kwargs : `dict`, optional - Any additional keyword arguments. + Any additional keyword arguments. + + Returns + ------- + result : `float` + The result of the addition. """ return self.value1 + self.value2 @@ -60,6 +68,7 @@ def test_parameterized_model() -> None: assert model1.value1 == 0.5 assert model1.value1 == 0.5 assert model1.result() == 1.0 + assert model1.value_sum == 1.0 assert model1.sample_iteration == 0 # Use value1=model.value and value2=1.0 @@ -67,18 +76,21 @@ def test_parameterized_model() -> None: assert model2.value1 == 0.5 assert model2.value2 == 1.0 assert model2.result() == 1.5 + assert model2.value_sum == 1.5 assert model2.sample_iteration == 0 # Compute value1 from model2's result and value2 from the sampler function. model3 = PairModel(value1=model2.result, value2=_sampler_fun) rand_val = model3.value2 assert model3.result() == pytest.approx(1.5 + rand_val) + assert model3.value_sum == pytest.approx(1.5 + rand_val) assert model3.sample_iteration == 0 # Compute value1 from model3's result (which is itself the result for model2 + # a random value) and value2 = -1.0. model4 = PairModel(value1=model3.result, value2=-1.0) assert model4.result() == pytest.approx(0.5 + rand_val) + assert model4.value_sum == pytest.approx(0.5 + rand_val) assert model4.sample_iteration == 0 final_res = model4.result() @@ -92,13 +104,17 @@ def test_parameterized_model() -> None: assert model1.value1 == 0.5 assert model1.value1 == 0.5 assert model1.result() == 1.0 + assert model1.value_sum == 1.0 assert model2.value1 == 0.5 assert model2.value2 == 1.0 assert model2.result() == 1.5 + assert model2.value_sum == 1.5 # Models 3 and 4 use the data from the new random value. assert model3.result() == pytest.approx(1.5 + rand_val) assert model4.result() == pytest.approx(0.5 + rand_val) + assert model3.value_sum == pytest.approx(1.5 + rand_val) + assert model4.value_sum == pytest.approx(0.5 + rand_val) assert final_res != model4.result() # All models should have the same sample iteration.