Skip to content

Commit

Permalink
Merge pull request #189 from lincc-frameworks/math_eval
Browse files Browse the repository at this point in the history
Add two features requested in the review of PR185
  • Loading branch information
jeremykubica authored Nov 12, 2024
2 parents 9ea64af + 178a303 commit 610cefb
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 0 deletions.
22 changes: 22 additions & 0 deletions src/tdastro/graph_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,17 @@ def __init__(self, num_samples=1):
def __len__(self):
return self.num_parameters

def __next__(self):
return next(self._iterate())

def __iter__(self):
return self._iterate()

def _iterate(self):
"""Returns a single sliced state."""
for idx in range(self.num_samples):
yield self.extract_single_sample(idx)

def __contains__(self, key):
if key in self.states:
return True
Expand All @@ -66,6 +77,11 @@ def __contains__(self, key):
if len(tokens) != 2:
raise KeyError(f"Invalid GraphState key: {key}")
return tokens[0] in self.states and tokens[1] in self.states[tokens[0]]
elif len(self.states) == 1:
# Special case when we have only a single node stored in the graph state.
node_state = list(self.states.values())[0]
if key in node_state:
return True
else:
return False

Expand Down Expand Up @@ -115,6 +131,12 @@ def __getitem__(self, key):
if len(tokens) != 2:
raise KeyError(f"Invalid GraphState key: {key}")
return self.states[tokens[0]][tokens[1]]
elif len(self.states) == 1:
# Special case when we have only a single node stored
# in the graph state.
node_state = list(self.states.values())[0]
if key in node_state:
return node_state[key]
else:
raise KeyError(f"Unknown GraphState key: {key}")

Expand Down
37 changes: 37 additions & 0 deletions tests/tdastro/test_graph_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,18 @@ def test_create_single_sample_graph_state():
assert a_vals["v1"] == 1.0
assert a_vals["v2"] == 2.0

# If the state only has a single node, we can access that node's variables
# directly. But we get an error if we try to do this with multiple nodes.
state2 = GraphState()
state2.set("a", "v1", 1.0)
state2.set("a", "v2", 2.0)
assert state2["v1"] == 1.0
assert state2["v2"] == 2.0

state2.set("b", "v3", 3.0)
with pytest.raises(KeyError):
_ = state["v1"]

# Can't access an out-of-bounds sample_num.
with pytest.raises(ValueError):
_ = state.get_node_state("a", 2)
Expand Down Expand Up @@ -91,6 +103,14 @@ def test_graph_state_contains():
with pytest.raises(KeyError):
assert "b.v1.v2" not in state

# If the state only has a single node, we can access that node's variables directly.
state2 = GraphState()
state2.set("a", "v1", 1.0)
state2.set("a", "v2", 2.0)
assert "a" in state2
assert "v1" in state2
assert "v2" in state2


def test_create_multi_sample_graph_state():
"""Test that we can create and access a multi-sample GraphState."""
Expand Down Expand Up @@ -158,6 +178,23 @@ def test_create_multi_sample_graph_state_reference():
assert np.allclose(state2["b"]["v1"], [2.0, 2.5, 3.0, 3.5, 4.0])


def test_graph_state_iterate():
"""Test that we can use an iterator to transform a GraphState with
multiple samples into a list of GraphStates each with a single sample.
"""
state = GraphState(10)
state.set("a", "v1", 1.0)
state.set("a", "v2", np.arange(10))

list_version = [x for x in state]
assert len(list_version) == 10
for i in range(10):
sample_state = list_version[i]
assert sample_state.num_samples == 1
assert sample_state["a.v1"] == 1.0
assert sample_state["a.v2"] == i


def test_graph_state_equal():
"""Test that we use == on GraphStates."""
state1 = GraphState(num_samples=2)
Expand Down

0 comments on commit 610cefb

Please sign in to comment.