From f701b89b673cfa363a5febea1023f265a31f669f Mon Sep 17 00:00:00 2001 From: Jeremy Kubica <104161096+jeremykubica@users.noreply.github.com> Date: Thu, 7 Nov 2024 15:37:14 -0500 Subject: [PATCH] Add iteration to graph_state --- src/tdastro/graph_state.py | 22 ++++++++++++++++++ tests/tdastro/test_graph_state.py | 37 +++++++++++++++++++++++++++++++ 2 files changed, 59 insertions(+) diff --git a/src/tdastro/graph_state.py b/src/tdastro/graph_state.py index baf94c8..4adde39 100644 --- a/src/tdastro/graph_state.py +++ b/src/tdastro/graph_state.py @@ -58,6 +58,17 @@ def __init__(self, num_samples=1): def __len__(self): return self.num_parameters + def __next__(self): + return next(self._iterate()) + + def __iter__(self): + return self._iterate() + + def _iterate(self): + """Returns a single sliced state.""" + for idx in range(self.num_samples): + yield self.extract_single_sample(idx) + def __contains__(self, key): if key in self.states: return True @@ -66,6 +77,11 @@ def __contains__(self, key): if len(tokens) != 2: raise KeyError(f"Invalid GraphState key: {key}") return tokens[0] in self.states and tokens[1] in self.states[tokens[0]] + elif len(self.states) == 1: + # Special case when we have only a single node stored in the graph state. + node_state = list(self.states.values())[0] + if key in node_state: + return True else: return False @@ -115,6 +131,12 @@ def __getitem__(self, key): if len(tokens) != 2: raise KeyError(f"Invalid GraphState key: {key}") return self.states[tokens[0]][tokens[1]] + elif len(self.states) == 1: + # Special case when we have only a single node stored + # in the graph state. + node_state = list(self.states.values())[0] + if key in node_state: + return node_state[key] else: raise KeyError(f"Unknown GraphState key: {key}") diff --git a/tests/tdastro/test_graph_state.py b/tests/tdastro/test_graph_state.py index 1dbc022..17e9e85 100644 --- a/tests/tdastro/test_graph_state.py +++ b/tests/tdastro/test_graph_state.py @@ -40,6 +40,18 @@ def test_create_single_sample_graph_state(): assert a_vals["v1"] == 1.0 assert a_vals["v2"] == 2.0 + # If the state only has a single node, we can access that node's variables + # directly. But we get an error if we try to do this with multiple nodes. + state2 = GraphState() + state2.set("a", "v1", 1.0) + state2.set("a", "v2", 2.0) + assert state2["v1"] == 1.0 + assert state2["v2"] == 2.0 + + state2.set("b", "v3", 3.0) + with pytest.raises(KeyError): + _ = state["v1"] + # Can't access an out-of-bounds sample_num. with pytest.raises(ValueError): _ = state.get_node_state("a", 2) @@ -91,6 +103,14 @@ def test_graph_state_contains(): with pytest.raises(KeyError): assert "b.v1.v2" not in state + # If the state only has a single node, we can access that node's variables directly. + state2 = GraphState() + state2.set("a", "v1", 1.0) + state2.set("a", "v2", 2.0) + assert "a" in state2 + assert "v1" in state2 + assert "v2" in state2 + def test_create_multi_sample_graph_state(): """Test that we can create and access a multi-sample GraphState.""" @@ -158,6 +178,23 @@ def test_create_multi_sample_graph_state_reference(): assert np.allclose(state2["b"]["v1"], [2.0, 2.5, 3.0, 3.5, 4.0]) +def test_graph_state_iterate(): + """Test that we can use an iterator to transform a GraphState with + multiple samples into a list of GraphStates each with a single sample. + """ + state = GraphState(10) + state.set("a", "v1", 1.0) + state.set("a", "v2", np.arange(10)) + + list_version = [x for x in state] + assert len(list_version) == 10 + for i in range(10): + sample_state = list_version[i] + assert sample_state.num_samples == 1 + assert sample_state["a.v1"] == 1.0 + assert sample_state["a.v2"] == i + + def test_graph_state_equal(): """Test that we use == on GraphStates.""" state1 = GraphState(num_samples=2)