diff --git a/src/tdastro/graph_state.py b/src/tdastro/graph_state.py index ed99c03c..1664bdcf 100644 --- a/src/tdastro/graph_state.py +++ b/src/tdastro/graph_state.py @@ -47,8 +47,6 @@ class GraphState: are fixed in this GraphState instance. """ - _NAME_SEPARATOR = "." - def __init__(self, num_samples=1): if num_samples < 1: raise ValueError(f"Invalid number of samples {num_samples}") @@ -63,8 +61,8 @@ def __len__(self): def __contains__(self, key): if key in self.states: return True - elif self._NAME_SEPARATOR in key: - tokens = key.split(self._NAME_SEPARATOR) + elif "." in key: + tokens = key.split(".") if len(tokens) != 2: raise KeyError(f"Invalid GraphState key: {key}") return tokens[0] in self.states and tokens[1] in self.states[tokens[0]] @@ -112,8 +110,8 @@ def __getitem__(self, key): access by both the pair of keys and the extended name.""" if key in self.states: return self.states[key] - elif self._NAME_SEPARATOR in key: - tokens = key.split(self._NAME_SEPARATOR) + elif "." in key: + tokens = key.split(".") if len(tokens) != 2: raise KeyError(f"Invalid GraphState key: {key}") return self.states[tokens[0]][tokens[1]] @@ -134,14 +132,14 @@ def extended_param_name(node_name, param_name): Returns ------- extended : str - A name of the form {node_name}{_NAME_SEPARATOR}{param_name} + A name of the form {node_name}.{param_name} """ - return f"{node_name}{GraphState._NAME_SEPARATOR}{param_name}" + return f"{node_name}.{param_name}" @classmethod def from_table(cls, input_table): """Create the GraphState from an AstroPy Table with columns for each parameter - and column names of the form '{node_name}{_NAME_SEPARATOR}{param_name}'. + and column names of the form '{node_name}.{param_name}'. Parameters ---------- @@ -151,11 +149,10 @@ def from_table(cls, input_table): num_samples = len(input_table) result = GraphState(num_samples=num_samples) for col in input_table.colnames: - components = col.split(cls._NAME_SEPARATOR) + components = col.split(".") if len(components) != 2: raise ValueError( - f"Invalid name for entry {col}. Entries should be of the form " - f"'node_name{cls._NAME_SEPARATOR}param_name'." + f"Invalid name for entry {col}. Entries should be of the form " f"'node_name.param_name'." ) # If we only have a single value then store that value instead of the np array. @@ -226,8 +223,8 @@ def set(self, node_name, var_name, value, force_copy=False, fixed=False): Default: ``False`` """ # Check that the names do not use the separator value. - if self._NAME_SEPARATOR in node_name or self._NAME_SEPARATOR in var_name: - raise ValueError(f"Names cannot contain the substring '{self._NAME_SEPARATOR}'.") + if "." in node_name or "." in var_name: + raise ValueError("Names cannot contain the character '.'") # Update the meta data. if node_name not in self.states: @@ -324,6 +321,75 @@ def extract_single_sample(self, sample_num): new_state.states[node_name][var_name] = value[sample_num] return new_state + def extract_parameters(self, params): + """Extract the parameter value(s) by a given name. This is often used for + recording the important parameters from an entire model (set of nodes). + + Parameters + ---------- + params : str or list-like, optional + The parameter names to extract. These can be full names ("node.param") or + use the parameter names. + + Returns + ------- + values : dict + The resulting dictionary. + """ + # If we are looking up a single parameter, but it into a list. + if isinstance(params, str): + params = [params] + + # Go through all the parameters. If a parameters full name is provided, + # look it up now and save the result. Otherwise put it into a list to check + # for in each node. + single_params = set() + results = {} + for current in params: + if "." in current: + node_name, param_name = current.split(".") + if node_name in self.states and param_name in self.states[node_name]: + results[current] = self.states[node_name][param_name] + else: + raise KeyError(f"Parameter {current} not found in GraphState.") + else: + single_params.add(current) + + if len(single_params) == 0: + # Nothing else to do. + return results + + # Traverse the nested dictionaries looking for cases where the parameter names match. + first_seen_node = {} + for node_name, node_params in self.states.items(): + for param_name, param_value in node_params.items(): + if param_name in single_params: + if param_name in first_seen_node: + # We've already seen this parameter in another node. Time to use the + # expanded names. + + # Start by expanding the result we have already seen if needed. + if param_name in results: + full_name_existing = f"{first_seen_node[param_name]}.{param_name}" + results[full_name_existing] = results[param_name] + del results[param_name] + + # Add the result from the current node. + full_name_current = f"{node_name}.{param_name}" + results[full_name_current] = param_value + else: + # This is the first time we have seen the node. Save it with + # just the parameter name. Also save the node where we saw it. + results[param_name] = param_value + first_seen_node[param_name] = node_name + + # Check that we found a match for all the short parameter names. + for param_name in single_params: + if param_name not in first_seen_node: + raise KeyError(f"Parameter {param_name} not found in GraphState.") + + return results + def to_table(self): """Flatten the graph state to an AstroPy Table with columns for each parameter. diff --git a/tests/tdastro/test_graph_state.py b/tests/tdastro/test_graph_state.py index 7df62f3b..cf5f1e23 100644 --- a/tests/tdastro/test_graph_state.py +++ b/tests/tdastro/test_graph_state.py @@ -28,11 +28,11 @@ def test_create_single_sample_graph_state(): _ = state["c"]["v1"] # We can access the entries using the extended key name. - assert state[f"a{state._NAME_SEPARATOR}v1"] == 1.0 - assert state[f"a{state._NAME_SEPARATOR}v2"] == 2.0 - assert state[f"b{state._NAME_SEPARATOR}v1"] == 3.0 + assert state["a.v1"] == 1.0 + assert state["a.v2"] == 2.0 + assert state["b.v1"] == 3.0 with pytest.raises(KeyError): - _ = state[f"c{state._NAME_SEPARATOR}v1"] + _ = state["c.v1"] # We can create a human readable string representation of the GraphState. debug_str = str(state) @@ -70,9 +70,9 @@ def test_create_single_sample_graph_state(): # Test we cannot use a name containing the separator as a substring. with pytest.raises(ValueError): - state.set(f"a{state._NAME_SEPARATOR}b", "v1", 10.0) + state.set("a.b", "v1", 10.0) with pytest.raises(ValueError): - state.set("b", f"v1{state._NAME_SEPARATOR}v3", 10.0) + state.set("b", "v1.v3", 10.0) def test_graph_state_contains(): @@ -86,14 +86,14 @@ def test_graph_state_contains(): assert "b" in state assert "c" not in state - assert f"a{state._NAME_SEPARATOR}v1" in state - assert f"a{state._NAME_SEPARATOR}v2" in state - assert f"a{state._NAME_SEPARATOR}v3" not in state - assert f"b{state._NAME_SEPARATOR}v1" in state - assert f"c{state._NAME_SEPARATOR}v1" not in state + assert "a.v1" in state + assert "a.v2" in state + assert "a.v3" not in state + assert "b.v1" in state + assert "c.v1" not in state with pytest.raises(KeyError): - assert f"b{state._NAME_SEPARATOR}v1{state._NAME_SEPARATOR}v2" not in state + assert "b.v1.v2" not in state def test_create_multi_sample_graph_state(): @@ -402,6 +402,65 @@ def test_graph_to_from_file(): state.save_to_file(file_path, overwrite=True) +def test_graph_state_extract_parameters(): + """Test that we can extract named parameters from a GraphState.""" + state = GraphState() + state.set("a", "v0", 0.0) + state.set("a", "v1", 1.0) + state.set("a", "v2", 2.0) + state.set("a", "v3", 3.0) + state.set("b", "v1", 4.0) + state.set("c", "v2", 5.0) + state.set("c", "v3", 6.0) + state.set("d", "v4", 7.0) + state.set("e", "v3", 8.0) + state.set("e", "v5", 9.0) + state.set("e", "v6", 10.0) + state.set("f", "v6", 11.0) + + # We can extract a mixture of unique parameters based on full and short names. + results = state.extract_parameters(["a.v1", "c.v2", "v5"]) + assert len(results) == 3 + assert results["a.v1"] == 1.0 + assert results["c.v2"] == 5.0 + assert results["v5"] == 9.0 + + # If we extract a parameter that appears in multiple nodes with its short + # name, we expand the name for each instance. + results = state.extract_parameters(["v2", "v3", "v4"]) + assert len(results) == 6 + assert results["a.v2"] == 2.0 + assert results["a.v3"] == 3.0 + assert results["c.v2"] == 5.0 + assert results["c.v3"] == 6.0 + assert results["e.v3"] == 8.0 + assert results["v4"] == 7.0 # We do not expand the name. + + # We can also provide a single parameter name as a string. + results = state.extract_parameters("v2") + assert len(results) == 2 + assert results["a.v2"] == 2.0 + assert results["c.v2"] == 5.0 + + # Test a complicated list. + results = state.extract_parameters(["v0", "c.v2", "e.v3", "v5", "v6", "a.v3"]) + assert len(results) == 7 + assert results["v0"] == 0.0 + assert results["c.v2"] == 5.0 + assert results["e.v3"] == 8.0 + assert results["v5"] == 9.0 + assert results["e.v6"] == 10.0 + assert results["f.v6"] == 11.0 + assert results["a.v3"] == 3.0 + + # We raise a KeyError if we try to lookup a parameter that is not in the GraphState. + with pytest.raises(KeyError): + _ = state.extract_parameters(["v2", "v3", "c.v4"]) + + with pytest.raises(KeyError): + _ = state.extract_parameters(["v2", "v100"]) + + def test_transpose_dict_of_list(): """Test the transpose_dict_of_list helper function""" input_dict = {