diff --git a/src/tdastro/graph_state.py b/src/tdastro/graph_state.py index 8da9fb4..8e18e17 100644 --- a/src/tdastro/graph_state.py +++ b/src/tdastro/graph_state.py @@ -47,8 +47,6 @@ class GraphState: are fixed in this GraphState instance. """ - _NAME_SEPARATOR = "." - def __init__(self, num_samples=1): if num_samples < 1: raise ValueError(f"Invalid number of samples {num_samples}") @@ -63,8 +61,8 @@ def __len__(self): def __contains__(self, key): if key in self.states: return True - elif self._NAME_SEPARATOR in key: - tokens = key.split(self._NAME_SEPARATOR) + elif "." in key: + tokens = key.split(".") if len(tokens) != 2: raise KeyError(f"Invalid GraphState key: {key}") return tokens[0] in self.states and tokens[1] in self.states[tokens[0]] @@ -112,8 +110,8 @@ def __getitem__(self, key): access by both the pair of keys and the extended name.""" if key in self.states: return self.states[key] - elif self._NAME_SEPARATOR in key: - tokens = key.split(self._NAME_SEPARATOR) + elif "." in key: + tokens = key.split(".") if len(tokens) != 2: raise KeyError(f"Invalid GraphState key: {key}") return self.states[tokens[0]][tokens[1]] @@ -134,14 +132,14 @@ def extended_param_name(node_name, param_name): Returns ------- extended : str - A name of the form {node_name}{_NAME_SEPARATOR}{param_name} + A name of the form {node_name}.{param_name} """ - return f"{node_name}{GraphState._NAME_SEPARATOR}{param_name}" + return f"{node_name}.{param_name}" @classmethod def from_table(cls, input_table): """Create the GraphState from an AstroPy Table with columns for each parameter - and column names of the form '{node_name}{_NAME_SEPARATOR}{param_name}'. + and column names of the form '{node_name}.{param_name}'. Parameters ---------- @@ -151,11 +149,10 @@ def from_table(cls, input_table): num_samples = len(input_table) result = GraphState(num_samples=num_samples) for col in input_table.colnames: - components = col.split(cls._NAME_SEPARATOR) + components = col.split(".") if len(components) != 2: raise ValueError( - f"Invalid name for entry {col}. Entries should be of the form " - f"'node_name{cls._NAME_SEPARATOR}param_name'." + f"Invalid name for entry {col}. Entries should be of the form " f"'node_name.param_name'." ) # If we only have a single value then store that value instead of the np array. @@ -226,8 +223,8 @@ def set(self, node_name, var_name, value, force_copy=False, fixed=False): Default: ``False`` """ # Check that the names do not use the separator value. - if self._NAME_SEPARATOR in node_name or self._NAME_SEPARATOR in var_name: - raise ValueError(f"Names cannot contain the substring '{self._NAME_SEPARATOR}'.") + if "." in node_name or "." in var_name: + raise ValueError("Names cannot contain the character '.'") # Update the meta data. if node_name not in self.states: @@ -324,6 +321,48 @@ def extract_single_sample(self, sample_num): new_state.states[node_name][var_name] = value[sample_num] return new_state + def extract_parameters(self, params=None, nodes=None): + """Extract the parameter value(s) by a given name. This is often used for + recording the important parameters from an entire model (set of nodes). + + Parameters + ---------- + params : str or list-like, optional + The parameter names to extract. If None (not provided), extract data + for all parameters. + Default: None + nodes: str or list-like, optional + The node names to extract. If None (not provided), extract data + from all nodes. + Default: None + + Returns + ------- + values : dict + The resulting dictionary. + """ + if nodes is None: + use_nodes = None + elif isinstance(nodes, str): + use_nodes = set([nodes]) + else: + use_nodes = set(nodes) + + if params is None: + use_params = None + elif isinstance(params, str): + use_params = set([params]) + else: + use_params = set(params) + + values = {} + for node_name, node_params in self.states.items(): + if use_nodes is None or node_name in use_nodes: + for param_name, param_value in node_params.items(): + if use_params is None or param_name in use_params: + values[self.extended_param_name(node_name, param_name)] = param_value + return values + def to_table(self): """Flatten the graph state to an AstroPy Table with columns for each parameter. @@ -340,37 +379,17 @@ def to_table(self): values[self.extended_param_name(node_name, param_name)] = np.array(param_value) return values - def to_dict(self, nodes=None, params=None): + def to_dict(self): """Flatten the graph state to a dictionary with columns for each parameter. The column names are: {node_name}{separator}{param_name} - Parameters - ---------- - nodes : list or set, optional - A list of node names to extract. If None, then extracts data from - all nodes. - Default: None - params : list or set, optional - A list of parameter names to extract. If None, then extracts all the - parameters. - Default: None - Returns ------- values : dict The resulting dictionary. """ - use_nodes = set(nodes) if nodes is not None else None - use_params = set(params) if params is not None else None - - values = {} - for node_name, node_params in self.states.items(): - if use_nodes is None or node_name in use_nodes: - for param_name, param_value in node_params.items(): - if use_params is None or param_name in use_params: - values[self.extended_param_name(node_name, param_name)] = param_value - return values + return self.extract_parameters(nodes=None, params=None) def save_to_file(self, filename, overwrite=False): """Save the GraphState to a file. diff --git a/tests/tdastro/test_graph_state.py b/tests/tdastro/test_graph_state.py index 3bc89f5..49c4bc1 100644 --- a/tests/tdastro/test_graph_state.py +++ b/tests/tdastro/test_graph_state.py @@ -28,11 +28,11 @@ def test_create_single_sample_graph_state(): _ = state["c"]["v1"] # We can access the entries using the extended key name. - assert state[f"a{state._NAME_SEPARATOR}v1"] == 1.0 - assert state[f"a{state._NAME_SEPARATOR}v2"] == 2.0 - assert state[f"b{state._NAME_SEPARATOR}v1"] == 3.0 + assert state["a.v1"] == 1.0 + assert state["a.v2"] == 2.0 + assert state["b.v1"] == 3.0 with pytest.raises(KeyError): - _ = state[f"c{state._NAME_SEPARATOR}v1"] + _ = state["c.v1"] # We can create a human readable string representation of the GraphState. debug_str = str(state) @@ -70,9 +70,9 @@ def test_create_single_sample_graph_state(): # Test we cannot use a name containing the separator as a substring. with pytest.raises(ValueError): - state.set(f"a{state._NAME_SEPARATOR}b", "v1", 10.0) + state.set("a.b", "v1", 10.0) with pytest.raises(ValueError): - state.set("b", f"v1{state._NAME_SEPARATOR}v3", 10.0) + state.set("b", "v1.v3", 10.0) def test_graph_state_contains(): @@ -86,14 +86,14 @@ def test_graph_state_contains(): assert "b" in state assert "c" not in state - assert f"a{state._NAME_SEPARATOR}v1" in state - assert f"a{state._NAME_SEPARATOR}v2" in state - assert f"a{state._NAME_SEPARATOR}v3" not in state - assert f"b{state._NAME_SEPARATOR}v1" in state - assert f"c{state._NAME_SEPARATOR}v1" not in state + assert "a.v1" in state + assert "a.v2" in state + assert "a.v3" not in state + assert "b.v1" in state + assert "c.v1" not in state with pytest.raises(KeyError): - assert f"b{state._NAME_SEPARATOR}v1{state._NAME_SEPARATOR}v2" not in state + assert "b.v1.v2" not in state def test_create_multi_sample_graph_state(): @@ -402,7 +402,7 @@ def test_graph_to_from_file(): state.save_to_file(file_path, overwrite=True) -def test_graph_state_to_dict_extract_params(): +def test_graph_state_extract_parameters(): """Test that we can extract named parameters from a GraphState.""" state = GraphState() state.set("a", "v1", 1.0) @@ -415,12 +415,16 @@ def test_graph_state_to_dict_extract_params(): state.set("e", "v3", 3.0) state.set("e", "v5", 7.0) + # We can always access a parameter by its full name. + assert state["a.v1"] == 1.0 + assert state["a.v2"] == 2.0 + # With no filtering, we extract all the parameters. - results = state.to_dict() + results = state.extract_parameters() assert len(results) == 9 # We can extract only certain parameters. - results = state.to_dict(params=["v2", "v3", "v4"]) + results = state.extract_parameters(["v2", "v3", "v4"]) assert len(results) == 6 assert results["a.v2"] == 2.0 assert results["a.v3"] == 3.0 @@ -430,7 +434,7 @@ def test_graph_state_to_dict_extract_params(): assert results["d.v4"] == 6.0 # We can extract from only certain nodes. - results = state.to_dict(nodes=["a", "c"]) + results = state.extract_parameters(nodes=["a", "c"]) assert len(results) == 5 assert results["a.v1"] == 1.0 assert results["a.v2"] == 2.0 @@ -439,7 +443,7 @@ def test_graph_state_to_dict_extract_params(): assert results["c.v3"] == 3.0 # We can extract a set of parameters from a set of nodes. - results = state.to_dict(nodes=["a", "c", "e"], params=["v1", "v2"]) + results = state.extract_parameters(nodes=["a", "c", "e"], params=["v1", "v2"]) assert len(results) == 3 assert results["a.v1"] == 1.0 assert results["a.v2"] == 2.0