diff --git a/src/tdastro/astro_utils/cosmology.py b/src/tdastro/astro_utils/cosmology.py index 57771220..a2e3f5dd 100644 --- a/src/tdastro/astro_utils/cosmology.py +++ b/src/tdastro/astro_utils/cosmology.py @@ -42,16 +42,15 @@ class RedshiftDistFunc(FunctionNode): The function or constant providing the redshift value. cosmology : `astropy.cosmology` The cosmology specification. + **kwargs : `dict`, optional + Any additional keyword arguments. """ - def __init__(self, redshift, cosmology): + def __init__(self, redshift, cosmology, **kwargs): # Call the super class's constructor with the needed information. super().__init__( func=redshift_to_distance, redshift=redshift, cosmology=cosmology, + **kwargs, ) - - def __str__(self): - """Return the string representation of the function.""" - return f"RedshiftDistFunc({self.cosmology.name}, {self.kind})" diff --git a/src/tdastro/base_models.py b/src/tdastro/base_models.py index 4cc15c84..4f2c5b87 100644 --- a/src/tdastro/base_models.py +++ b/src/tdastro/base_models.py @@ -2,6 +2,8 @@ import types from enum import Enum +from hashlib import md5 +from os import urandom class ParameterSource(Enum): @@ -30,19 +32,59 @@ class ParameterizedNode: sample_iteration : `int` A counter used to syncronize sampling runs. Tracks how many times this model's parameters have been resampled. + _object_seed : `int` + A object-specific seed to control random number generation. + _graph_base_seed, `int` + A base random seed to use for this specific evaluation graph. Used + for validity checking. + + Parameters + ---------- + node_identifier : `str`, optional + An identifier (or name) for the current node. + graph_base_seed : `int`, optional + A base random seed to use for this specific evaluation graph. + WARNING: This seed should almost never be set manually. Using the same + seed for multiple graph instances will produce biased samples. + If set to ``None`` will use urandom() to produce a fully random seed. + **kwargs : `dict`, optional + Any additional keyword arguments. """ - def __init__(self, node_identifier=None, **kwargs): + def __init__(self, node_identifier=None, graph_base_seed=None, **kwargs): self.setters = {} self.sample_iteration = 0 self.node_identifier = node_identifier + self.set_graph_base_seed(graph_base_seed) def __str__(self): """Return the string representation of the node.""" + name = f"{self.__class__.__module__}.{self.__class__.__qualname__}" if self.node_identifier: - return f"{self.node_identifier}={self.__class__.__name__}" + return f"{self.node_identifier}={name}" else: - return self.__class__.__name__ + return name + + def set_graph_base_seed(self, graph_base_seed): + """Set a new graph base seed. + + Notes + ----- + WARNING: This seed should almost never be set manually. Using the same + seed for multiple graph instances will produce biased samples. + + Parameters + ---------- + graph_base_seed : `int`, optional + A base random seed to use for this specific evaluation graph. + """ + if graph_base_seed is None: + graph_base_seed = int.from_bytes(urandom(4), "big") + self._graph_base_seed = graph_base_seed + + hashed_object_name = md5(str(self).encode()) + seed_offset = int(hashed_object_name.hexdigest(), base=16) + self._object_seed = (graph_base_seed + seed_offset) % (2**31) def check_resample(self, other): """Check if we need to resample the current node based @@ -264,6 +306,34 @@ def get_all_parameter_values(self, recursive=True, seen=None): values[full_name] = getattr(self, name) return values + def get_dependencies(self, nodes=None): + """Get all nodes on which this current node depends. + + Parameters + ---------- + nodes : `set` + A set of all nodes at or above this node in the graph. + This is modified in place. + + Returns + ------- + nodes : `set` + A set of all nodes at or above this node in the graph. + """ + # Make sure that we do not process the same nodes multiple times. + if nodes is None: + nodes = set() + if self in nodes: + return nodes + nodes.add(self) + + for source_type, setter, _ in self.setters.values(): + if source_type == ParameterSource.MODEL_ATTRIBUTE: + nodes = setter.get_dependencies(nodes) + elif source_type == ParameterSource.MODEL_METHOD: + nodes = setter.__self__.get_dependencies(nodes) + return nodes + class FunctionNode(ParameterizedNode): """A class to wrap functions and their argument settings. @@ -275,6 +345,20 @@ class FunctionNode(ParameterizedNode): args_names : `list` A list of argument names to pass to the function. + Parameters + ---------- + func : `function` or `method` + The function to call during an evaluation. + node_identifier : `str`, optional + An identifier (or name) for the current node. + graph_base_seed : `int`, optional + A base random seed to use for this specific evaluation graph. + WARNING: This seed should almost never be set manually. Using the same + seed for multiple graph instances will produce biased samples. + If set to ``None`` will use urandom() to produce a fully random seed. + **kwargs : `dict`, optional + Any additional keyword arguments. + Examples -------- my_func = TDFunc(random.randint, a=1, b=10) @@ -293,19 +377,21 @@ class FunctionNode(ParameterizedNode): value1 = my_func(b=10.0) """ - def __init__(self, func, **kwargs): - super().__init__(**kwargs) + def __init__(self, func, node_identifier=None, graph_base_seed=None, **kwargs): + # We set the function before calling the parent class so we can use + # the function's name (if needed). self.func = func - self.arg_names = [] + super().__init__(node_identifier=node_identifier, graph_base_seed=graph_base_seed, **kwargs) # Add all of the parameters from default_args or the kwargs. + self.arg_names = [] for key, value in kwargs.items(): self.arg_names.append(key) self.add_parameter(key, value) def __str__(self): """Return the string representation of the function.""" - return f"FunctionNode({self.func.name})" + return f"FunctionNode({self.func.__name__})" def compute(self, **kwargs): """Execute the wrapped function. diff --git a/src/tdastro/effects/redshift.py b/src/tdastro/effects/redshift.py index de594f57..e098d9cd 100644 --- a/src/tdastro/effects/redshift.py +++ b/src/tdastro/effects/redshift.py @@ -33,10 +33,6 @@ def __init__(self, redshift=None, t0=None, **kwargs): self.add_parameter("redshift", redshift, required=True, **kwargs) self.add_parameter("t0", t0, required=True, **kwargs) - def __str__(self) -> str: - """Return a string representation of the Redshift effect model.""" - return f"RedshiftEffect(redshift={self.redshift})" - def pre_effect(self, observer_frame_times, observer_frame_wavelengths, **kwargs): """Calculate the rest-frame times and wavelengths needed to give us the observer-frame times and wavelengths (given the redshift). diff --git a/tests/tdastro/sources/test_sncosmo_models.py b/tests/tdastro/sources/test_sncosmo_models.py index e3c67e80..42e591b7 100644 --- a/tests/tdastro/sources/test_sncosmo_models.py +++ b/tests/tdastro/sources/test_sncosmo_models.py @@ -7,7 +7,7 @@ def test_sncomso_models_hsiao() -> None: model = SncosmoWrapperModel("hsiao") model.set(amplitude=1.0e-10) assert model.amplitude == 1.0e-10 - assert str(model) == "SncosmoWrapperModel" + assert str(model) == "tdastro.sources.sncomso_models.SncosmoWrapperModel" assert np.array_equal(model.param_names, ["amplitude"]) assert np.array_equal(model.parameters, [1.0e-10]) diff --git a/tests/tdastro/sources/test_spline_source.py b/tests/tdastro/sources/test_spline_source.py index 5b30bc87..4b18e70a 100644 --- a/tests/tdastro/sources/test_spline_source.py +++ b/tests/tdastro/sources/test_spline_source.py @@ -8,7 +8,7 @@ def test_spline_model_flat() -> None: wavelengths = np.linspace(100.0, 500.0, 25) fluxes = np.full((len(times), len(wavelengths)), 1.0) model = SplineModel(times, wavelengths, fluxes) - assert str(model) == "SplineModel" + assert str(model) == "tdastro.sources.spline_model.SplineModel" test_times = np.array([0.0, 1.0, 2.0, 3.0, 10.0]) test_waves = np.array([0.0, 100.0, 200.0, 1000.0]) @@ -19,7 +19,7 @@ def test_spline_model_flat() -> None: np.testing.assert_array_almost_equal(values, expected) model2 = SplineModel(times, wavelengths, fluxes, amplitude=5.0, node_identifier="test") - assert str(model2) == "test=SplineModel" + assert str(model2) == "test=tdastro.sources.spline_model.SplineModel" values2 = model2.evaluate(test_times, test_waves) assert values2.shape == (5, 4) diff --git a/tests/tdastro/sources/test_static_source.py b/tests/tdastro/sources/test_static_source.py index 2dacd3b0..2e8f5a7a 100644 --- a/tests/tdastro/sources/test_static_source.py +++ b/tests/tdastro/sources/test_static_source.py @@ -24,7 +24,7 @@ def test_static_source() -> None: assert model.ra is None assert model.dec is None assert model.distance is None - assert str(model) == "my_static_source=StaticSource" + assert str(model) == "my_static_source=tdastro.sources.static_source.StaticSource" times = np.array([1, 2, 3, 4, 5, 10]) wavelengths = np.array([100.0, 200.0, 300.0]) @@ -49,7 +49,7 @@ def test_static_source_host() -> None: assert model.ra == 1.0 assert model.dec == 2.0 assert model.distance == 3.0 - assert str(model) == "StaticSource" + assert str(model) == "tdastro.sources.static_source.StaticSource" def test_static_source_resample() -> None: diff --git a/tests/tdastro/test_base_models.py b/tests/tdastro/test_base_models.py index f474aea1..37005b40 100644 --- a/tests/tdastro/test_base_models.py +++ b/tests/tdastro/test_base_models.py @@ -29,6 +29,20 @@ def _test_func(value1, value2): return value1 + value2 +class SingleModel(ParameterizedNode): + """A test class for the ParameterizedNode. + + Attributes + ---------- + value1 : `float` + The first value. + """ + + def __init__(self, value1, **kwargs): + super().__init__(**kwargs) + self.add_parameter("value1", value1, required=True, **kwargs) + + class PairModel(ParameterizedNode): """A test class for the ParameterizedNode. @@ -43,17 +57,6 @@ class PairModel(ParameterizedNode): """ def __init__(self, value1, value2, **kwargs): - """Create a ConstModel object. - - Parameters - ---------- - value1 : `float`, `function`, `ParameterizedNode`, or `None` - The first value. - value2 : `float`, `function`, `ParameterizedNode`, or `None` - The second value. - **kwargs : `dict`, optional - Any additional keyword arguments. - """ super().__init__(**kwargs) self.add_parameter("value1", value1, required=True, **kwargs) self.add_parameter("value2", value2, required=True, **kwargs) @@ -79,7 +82,7 @@ def result(self, **kwargs): return self.value1 + self.value2 -def test_parameterized_node() -> None: +def test_parameterized_node(): """Test that we can sample and create a PairModel object.""" # Simple addition model1 = PairModel(value1=0.5, value2=0.5) @@ -141,7 +144,7 @@ def test_parameterized_node() -> None: assert model1.sample_iteration == model4.sample_iteration -def test_parameterized_node_overwrite() -> None: +def test_parameterized_node_overwrite(): """Test that we can overwrite attributes in a PairModel.""" model1 = PairModel(value1=0.5, value2=0.5) assert model1.value1 == 0.5 @@ -159,7 +162,7 @@ def test_parameterized_node_overwrite() -> None: assert model1.value1 == 1.0 -def test_parameterized_node_attributes() -> None: +def test_parameterized_node_attributes(): """Test that we can extract the attributes of a graph of ParameterizedNode.""" model1 = PairModel(value1=0.5, value2=1.5, node_identifier="1") settings = model1.get_all_parameter_values(False) @@ -170,9 +173,9 @@ def test_parameterized_node_attributes() -> None: settings = model1.get_all_parameter_values(True) assert len(settings) == 3 - assert settings["1=PairModel.value1"] == 0.5 - assert settings["1=PairModel.value2"] == 1.5 - assert settings["1=PairModel.value_sum"] == 2.0 + assert settings["1=test_base_models.PairModel.value1"] == 0.5 + assert settings["1=test_base_models.PairModel.value2"] == 1.5 + assert settings["1=test_base_models.PairModel.value_sum"] == 2.0 # Use value1=model.value and value2=3.0 model2 = PairModel(value1=model1, value2=3.0, node_identifier="2") @@ -184,15 +187,37 @@ def test_parameterized_node_attributes() -> None: settings = model2.get_all_parameter_values(True) assert len(settings) == 6 - assert settings["1=PairModel.value1"] == 0.5 - assert settings["1=PairModel.value2"] == 1.5 - assert settings["1=PairModel.value_sum"] == 2.0 - assert settings["2=PairModel.value1"] == 0.5 - assert settings["2=PairModel.value2"] == 3.0 - assert settings["2=PairModel.value_sum"] == 3.5 + assert settings["1=test_base_models.PairModel.value1"] == 0.5 + assert settings["1=test_base_models.PairModel.value2"] == 1.5 + assert settings["1=test_base_models.PairModel.value_sum"] == 2.0 + assert settings["2=test_base_models.PairModel.value1"] == 0.5 + assert settings["2=test_base_models.PairModel.value2"] == 3.0 + assert settings["2=test_base_models.PairModel.value_sum"] == 3.5 + +def test_parameterized_node_get_dependencies(): + """Test that we can extract the attributes of a graph of ParameterizedNode.""" + model1 = PairModel(value1=0.5, value2=1.5, node_identifier="1") + model2 = PairModel(value1=model1, value2=3.0, node_identifier="2") + model3 = PairModel(value1=model1, value2=model2.result, node_identifier="3") -def test_parameterized_node_modify() -> None: + dep1 = model1.get_dependencies() + assert len(dep1) == 1 + assert model1 in dep1 + + dep2 = model2.get_dependencies() + assert len(dep2) == 2 + assert model1 in dep2 + assert model2 in dep2 + + dep3 = model3.get_dependencies() + assert len(dep3) == 3 + assert model1 in dep3 + assert model2 in dep3 + assert model3 in dep3 + + +def test_parameterized_node_modify(): """Test that we can modify the parameters in a node.""" model = PairModel(value1=0.5, value2=0.5) assert model.value1 == 0.5 @@ -212,6 +237,36 @@ def test_parameterized_node_modify() -> None: model.set_parameter("brightness", 5.0) +def test_parameterized_node_seed(): + """Test that we can set a random seed for the entire graph.""" + # Left unspecified we use full random seeds. + model_a = PairModel(value1=0.5, value2=0.5) + model_b = PairModel(value1=0.5, value2=0.5) + assert model_a._object_seed != model_b._object_seed + + # If we specify a seed, the results are the same objects with + # the same name (class + node identifier) and different otherwise. + model_a = PairModel(value1=0.5, value2=0.5, graph_base_seed=10, node_identifier="A") + model_b = PairModel(value1=0.5, value2=0.5, graph_base_seed=10, node_identifier="B") + model_c = PairModel(value1=0.5, value2=0.5, graph_base_seed=10, node_identifier="A") + model_d = SingleModel(value1=0.5, node_identifier="A") + assert model_a._object_seed != model_b._object_seed + assert model_a._object_seed == model_c._object_seed + assert model_a._object_seed != model_d._object_seed + + assert model_b._object_seed != model_a._object_seed + assert model_b._object_seed != model_c._object_seed + assert model_b._object_seed != model_d._object_seed + + assert model_c._object_seed == model_a._object_seed + assert model_c._object_seed != model_b._object_seed + assert model_c._object_seed != model_d._object_seed + + assert model_d._object_seed != model_a._object_seed + assert model_d._object_seed != model_b._object_seed + assert model_d._object_seed != model_c._object_seed + + def test_function_node_basic(): """Test that we can create and query a FunctionNode.""" my_func = FunctionNode(_test_func, value1=1.0, value2=2.0)