From 7270452d17331bd1fbce948a1fe5efa15fdb56a1 Mon Sep 17 00:00:00 2001 From: Jeremy Kubica <104161096+jeremykubica@users.noreply.github.com> Date: Mon, 28 Oct 2024 09:27:24 -0400 Subject: [PATCH 1/3] Update sncomso_models.py --- src/tdastro/sources/sncomso_models.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/tdastro/sources/sncomso_models.py b/src/tdastro/sources/sncomso_models.py index 596a815..54dcfc3 100644 --- a/src/tdastro/sources/sncomso_models.py +++ b/src/tdastro/sources/sncomso_models.py @@ -36,6 +36,9 @@ class SncosmoWrapperModel(PhysicalModel): Any additional keyword arguments. """ + # A class variable for the units so we are not computing them each time. + _FLAM_UNIT = u.erg / u.second / u.cm**2 / u.AA + def __init__(self, source_name, t0=0.0, node_label=None, **kwargs): super().__init__(node_label=node_label, **kwargs) self.source_name = source_name @@ -160,7 +163,7 @@ def compute_flux(self, times, wavelengths, graph_state=None, **kwargs): flux_flam, wavelengths, wave_unit=u.AA, - flam_unit=u.erg / u.second / u.cm**2 / u.AA, + flam_unit=self._FLAM_UNIT, fnu_unit=u.nJy, ) From 717818131aac849859e375fda07277778653ebd4 Mon Sep 17 00:00:00 2001 From: Jeremy Kubica <104161096+jeremykubica@users.noreply.github.com> Date: Mon, 28 Oct 2024 11:17:12 -0400 Subject: [PATCH 2/3] Update physical_model.py --- src/tdastro/sources/physical_model.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/tdastro/sources/physical_model.py b/src/tdastro/sources/physical_model.py index 8b4ee87..3cd43a8 100644 --- a/src/tdastro/sources/physical_model.py +++ b/src/tdastro/sources/physical_model.py @@ -144,8 +144,8 @@ def evaluate(self, times, wavelengths, graph_state=None, given_args=None, rng_in A length T x N matrix of SED values (in nJy). """ # Make sure times and wavelengths are numpy arrays. - times = np.array(times) - wavelengths = np.array(wavelengths) + times = np.asarray(times) + wavelengths = np.asarray(wavelengths) # Check if we need to sample the graph. if graph_state is None: From 43d66e3c7e6834d6df7f7beddf9a4915e09191f1 Mon Sep 17 00:00:00 2001 From: Jeremy Kubica <104161096+jeremykubica@users.noreply.github.com> Date: Thu, 31 Oct 2024 09:35:18 -0400 Subject: [PATCH 3/3] Improve label error checking --- src/tdastro/base_models.py | 9 ++++++--- tests/tdastro/test_base_models.py | 20 ++++++++++++++++++++ 2 files changed, 26 insertions(+), 3 deletions(-) diff --git a/src/tdastro/base_models.py b/src/tdastro/base_models.py index 886fa9e..959acf6 100644 --- a/src/tdastro/base_models.py +++ b/src/tdastro/base_models.py @@ -472,7 +472,7 @@ def _sample_helper(self, graph_state, seen_nodes, rng_info=None): An object mapping graph parameters to their values. This object is modified in place as it is sampled. seen_nodes : `dict` - A dictionary mapping nodes seen during this sampling run to their ID. + A dictionary mapping nodes strings seen during this sampling run to their object. Used to avoid sampling nodes multiple times and to validity check the graph. rng_info : numpy.random._generator.Generator, optional A given numpy random number generator to use for this computation. If not @@ -482,9 +482,12 @@ def _sample_helper(self, graph_state, seen_nodes, rng_info=None): ------ Raise a ``KeyError`` if the sampling encounters an error with the order of dependencies. """ - if self in seen_nodes: + node_str = str(self) + if node_str in seen_nodes: + if seen_nodes[node_str] != self: + raise ValueError(f"Duplicate node label {node_str}.") return # Nothing to do - seen_nodes[self] = self.node_pos + seen_nodes[node_str] = self # Run through each parameter and sample it based on the given recipe. # As of Python 3.7 dictionaries are guaranteed to preserve insertion ordering, diff --git a/tests/tdastro/test_base_models.py b/tests/tdastro/test_base_models.py index 64c5d82..d0766ef 100644 --- a/tests/tdastro/test_base_models.py +++ b/tests/tdastro/test_base_models.py @@ -153,6 +153,26 @@ def test_parameterized_node(): assert model4.get_param(state, "value_sum") != model4.get_param(new_state, "value_sum") +def test_parameterized_node_label_collision(): + """Test that throw an error when two nodes use the same label.""" + node_a = SingleVariableNode("A", 10.0, node_label="A") + node_b = SingleVariableNode("B", 20.0, node_label="B") + pair1 = PairModel(value1=node_a.A, value2=node_b.B, node_label="pair1") + pair2 = PairModel(value1=pair1.value_sum, value2=node_b.B, node_label="pair2") + + # No collision even though node_b is referenced twice. + state = pair2.sample_parameters() + assert state["pair1"]["value_sum"] == 30.0 + assert state["pair2"]["value_sum"] == 50.0 + + # We run into a problem if we reuse a label. Label "A" collides multiple + # levels above. + node_c = SingleVariableNode("C", 5.0, node_label="C") + pair3 = PairModel(value1=pair2.value_sum, value2=node_c.C, node_label="A") + with pytest.raises(ValueError): + _ = pair3.sample_parameters() + + def test_parameterized_node_modify(): """Test that we can modify the parameters in a node.""" model = PairModel(value1=0.5, value2=0.5)