Skip to content

Commit

Permalink
Merge pull request #182 from lincc-frameworks/small_opt
Browse files Browse the repository at this point in the history
Improve error checking for node labels
  • Loading branch information
jeremykubica authored Oct 31, 2024
2 parents 18536db + 43d66e3 commit 3aec1af
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 6 deletions.
9 changes: 6 additions & 3 deletions src/tdastro/base_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -472,7 +472,7 @@ def _sample_helper(self, graph_state, seen_nodes, rng_info=None):
An object mapping graph parameters to their values. This object is modified
in place as it is sampled.
seen_nodes : `dict`
A dictionary mapping nodes seen during this sampling run to their ID.
A dictionary mapping nodes strings seen during this sampling run to their object.
Used to avoid sampling nodes multiple times and to validity check the graph.
rng_info : numpy.random._generator.Generator, optional
A given numpy random number generator to use for this computation. If not
Expand All @@ -482,9 +482,12 @@ def _sample_helper(self, graph_state, seen_nodes, rng_info=None):
------
Raise a ``KeyError`` if the sampling encounters an error with the order of dependencies.
"""
if self in seen_nodes:
node_str = str(self)
if node_str in seen_nodes:
if seen_nodes[node_str] != self:
raise ValueError(f"Duplicate node label {node_str}.")
return # Nothing to do
seen_nodes[self] = self.node_pos
seen_nodes[node_str] = self

# Run through each parameter and sample it based on the given recipe.
# As of Python 3.7 dictionaries are guaranteed to preserve insertion ordering,
Expand Down
4 changes: 2 additions & 2 deletions src/tdastro/sources/physical_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,8 +144,8 @@ def evaluate(self, times, wavelengths, graph_state=None, given_args=None, rng_in
A length T x N matrix of SED values (in nJy).
"""
# Make sure times and wavelengths are numpy arrays.
times = np.array(times)
wavelengths = np.array(wavelengths)
times = np.asarray(times)
wavelengths = np.asarray(wavelengths)

# Check if we need to sample the graph.
if graph_state is None:
Expand Down
5 changes: 4 additions & 1 deletion src/tdastro/sources/sncomso_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@ class SncosmoWrapperModel(PhysicalModel):
Any additional keyword arguments.
"""

# A class variable for the units so we are not computing them each time.
_FLAM_UNIT = u.erg / u.second / u.cm**2 / u.AA

def __init__(self, source_name, t0=0.0, node_label=None, **kwargs):
super().__init__(node_label=node_label, **kwargs)
self.source_name = source_name
Expand Down Expand Up @@ -160,7 +163,7 @@ def compute_flux(self, times, wavelengths, graph_state=None, **kwargs):
flux_flam,
wavelengths,
wave_unit=u.AA,
flam_unit=u.erg / u.second / u.cm**2 / u.AA,
flam_unit=self._FLAM_UNIT,
fnu_unit=u.nJy,
)

Expand Down
20 changes: 20 additions & 0 deletions tests/tdastro/test_base_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,26 @@ def test_parameterized_node():
assert model4.get_param(state, "value_sum") != model4.get_param(new_state, "value_sum")


def test_parameterized_node_label_collision():
"""Test that throw an error when two nodes use the same label."""
node_a = SingleVariableNode("A", 10.0, node_label="A")
node_b = SingleVariableNode("B", 20.0, node_label="B")
pair1 = PairModel(value1=node_a.A, value2=node_b.B, node_label="pair1")
pair2 = PairModel(value1=pair1.value_sum, value2=node_b.B, node_label="pair2")

# No collision even though node_b is referenced twice.
state = pair2.sample_parameters()
assert state["pair1"]["value_sum"] == 30.0
assert state["pair2"]["value_sum"] == 50.0

# We run into a problem if we reuse a label. Label "A" collides multiple
# levels above.
node_c = SingleVariableNode("C", 5.0, node_label="C")
pair3 = PairModel(value1=pair2.value_sum, value2=node_c.C, node_label="A")
with pytest.raises(ValueError):
_ = pair3.sample_parameters()


def test_parameterized_node_modify():
"""Test that we can modify the parameters in a node."""
model = PairModel(value1=0.5, value2=0.5)
Expand Down

0 comments on commit 3aec1af

Please sign in to comment.