Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a helper function for looking at variables #170

Merged
merged 6 commits into from
Oct 18, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 73 additions & 14 deletions src/tdastro/graph_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,6 @@ class GraphState:
are fixed in this GraphState instance.
"""

_NAME_SEPARATOR = "."

def __init__(self, num_samples=1):
if num_samples < 1:
raise ValueError(f"Invalid number of samples {num_samples}")
Expand All @@ -63,8 +61,8 @@ def __len__(self):
def __contains__(self, key):
if key in self.states:
return True
elif self._NAME_SEPARATOR in key:
tokens = key.split(self._NAME_SEPARATOR)
elif "." in key:
tokens = key.split(".")
if len(tokens) != 2:
raise KeyError(f"Invalid GraphState key: {key}")
return tokens[0] in self.states and tokens[1] in self.states[tokens[0]]
Expand Down Expand Up @@ -112,8 +110,8 @@ def __getitem__(self, key):
access by both the pair of keys and the extended name."""
if key in self.states:
return self.states[key]
elif self._NAME_SEPARATOR in key:
tokens = key.split(self._NAME_SEPARATOR)
elif "." in key:
tokens = key.split(".")
if len(tokens) != 2:
raise KeyError(f"Invalid GraphState key: {key}")
return self.states[tokens[0]][tokens[1]]
Expand All @@ -134,14 +132,14 @@ def extended_param_name(node_name, param_name):
Returns
-------
extended : str
A name of the form {node_name}{_NAME_SEPARATOR}{param_name}
A name of the form {node_name}.{param_name}
"""
return f"{node_name}{GraphState._NAME_SEPARATOR}{param_name}"
return f"{node_name}.{param_name}"

@classmethod
def from_table(cls, input_table):
"""Create the GraphState from an AstroPy Table with columns for each parameter
and column names of the form '{node_name}{_NAME_SEPARATOR}{param_name}'.
and column names of the form '{node_name}.{param_name}'.

Parameters
----------
Expand All @@ -151,11 +149,10 @@ def from_table(cls, input_table):
num_samples = len(input_table)
result = GraphState(num_samples=num_samples)
for col in input_table.colnames:
components = col.split(cls._NAME_SEPARATOR)
components = col.split(".")
if len(components) != 2:
raise ValueError(
f"Invalid name for entry {col}. Entries should be of the form "
f"'node_name{cls._NAME_SEPARATOR}param_name'."
f"Invalid name for entry {col}. Entries should be of the form " f"'node_name.param_name'."
)

# If we only have a single value then store that value instead of the np array.
Expand Down Expand Up @@ -226,8 +223,8 @@ def set(self, node_name, var_name, value, force_copy=False, fixed=False):
Default: ``False``
"""
# Check that the names do not use the separator value.
if self._NAME_SEPARATOR in node_name or self._NAME_SEPARATOR in var_name:
raise ValueError(f"Names cannot contain the substring '{self._NAME_SEPARATOR}'.")
if "." in node_name or "." in var_name:
raise ValueError("Names cannot contain the character '.'")

# Update the meta data.
if node_name not in self.states:
Expand Down Expand Up @@ -324,6 +321,68 @@ def extract_single_sample(self, sample_num):
new_state.states[node_name][var_name] = value[sample_num]
return new_state

def extract_parameters(self, params):
    """Extract the value(s) of the requested parameters into a flat dictionary.

    This is often used for recording the important parameters from an entire
    model (set of nodes).

    Parameters
    ----------
    params : str or list-like of str
        The parameter names to extract. Each entry can be a full name
        ("node.param") or just a parameter name ("param").

    Returns
    -------
    values : dict
        A dictionary mapping names to values, built with these rules:
          * A full name given in ``params`` is always kept as given.
          * A short name is kept as given if exactly one node has that parameter.
          * A short name is replaced by one "node.param" entry per node if
            more than one node has that parameter.
        Names that do not match anything in the state are omitted.

    Raises
    ------
    ValueError
        If an entry of ``params`` contains more than one '.' separator.
    """
    # If we are looking up a single parameter, put it into a list.
    if isinstance(params, str):
        params = [params]

    # Go through all the parameters. If a parameter's full name is provided,
    # look it up now and save the result. Otherwise save the short name so we
    # can check for it in each node.
    short_names = set()
    results = {}
    for current in params:
        if "." in current:
            tokens = current.split(".")
            if len(tokens) != 2:
                # Fail with a message that names the offending entry instead of
                # an opaque tuple-unpacking error.
                raise ValueError(
                    f"Invalid parameter name '{current}'. Names should be of the "
                    "form 'param_name' or 'node_name.param_name'."
                )
            node_name, param_name = tokens
            if node_name in self.states and param_name in self.states[node_name]:
                results[current] = self.states[node_name][param_name]
        else:
            short_names.add(current)

    if not short_names:
        # Nothing else to do.
        return results

    # Traverse the nested dictionaries looking for cases where the parameter
    # names match one of the requested short names.
    first_seen_node = {}
    for node_name, node_params in self.states.items():
        for param_name, param_value in node_params.items():
            if param_name not in short_names:
                continue

            if param_name not in first_seen_node:
                # This is the first time we have seen the parameter. Save it with
                # just the parameter name and remember the node where we saw it.
                results[param_name] = param_value
                first_seen_node[param_name] = node_name
                continue

            # We've already seen this parameter in another node, so switch to the
            # expanded names. Start by renaming the earlier short-name entry (it is
            # only still present the first time a duplicate is found).
            if param_name in results:
                full_name_existing = f"{first_seen_node[param_name]}.{param_name}"
                results[full_name_existing] = results.pop(param_name)

            # Add the result from the current node.
            results[f"{node_name}.{param_name}"] = param_value

    return results
jeremykubica marked this conversation as resolved.
Show resolved Hide resolved

def to_table(self):
"""Flatten the graph state to an AstroPy Table with columns for each parameter.

Expand Down
76 changes: 64 additions & 12 deletions tests/tdastro/test_graph_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,11 @@ def test_create_single_sample_graph_state():
_ = state["c"]["v1"]

# We can access the entries using the extended key name.
assert state[f"a{state._NAME_SEPARATOR}v1"] == 1.0
assert state[f"a{state._NAME_SEPARATOR}v2"] == 2.0
assert state[f"b{state._NAME_SEPARATOR}v1"] == 3.0
assert state["a.v1"] == 1.0
assert state["a.v2"] == 2.0
assert state["b.v1"] == 3.0
with pytest.raises(KeyError):
_ = state[f"c{state._NAME_SEPARATOR}v1"]
_ = state["c.v1"]

# We can create a human readable string representation of the GraphState.
debug_str = str(state)
Expand Down Expand Up @@ -70,9 +70,9 @@ def test_create_single_sample_graph_state():

# Test we cannot use a name containing the separator as a substring.
with pytest.raises(ValueError):
state.set(f"a{state._NAME_SEPARATOR}b", "v1", 10.0)
state.set("a.b", "v1", 10.0)
with pytest.raises(ValueError):
state.set("b", f"v1{state._NAME_SEPARATOR}v3", 10.0)
state.set("b", "v1.v3", 10.0)


def test_graph_state_contains():
Expand All @@ -86,14 +86,14 @@ def test_graph_state_contains():
assert "b" in state
assert "c" not in state

assert f"a{state._NAME_SEPARATOR}v1" in state
assert f"a{state._NAME_SEPARATOR}v2" in state
assert f"a{state._NAME_SEPARATOR}v3" not in state
assert f"b{state._NAME_SEPARATOR}v1" in state
assert f"c{state._NAME_SEPARATOR}v1" not in state
assert "a.v1" in state
assert "a.v2" in state
assert "a.v3" not in state
assert "b.v1" in state
assert "c.v1" not in state

with pytest.raises(KeyError):
assert f"b{state._NAME_SEPARATOR}v1{state._NAME_SEPARATOR}v2" not in state
assert "b.v1.v2" not in state


def test_create_multi_sample_graph_state():
Expand Down Expand Up @@ -402,6 +402,58 @@ def test_graph_to_from_file():
state.save_to_file(file_path, overwrite=True)


def test_graph_state_extract_parameters():
    """Check extract_parameters with full names, short names, and duplicated names."""
    state = GraphState()
    entries = [
        ("a", "v0", 0.0),
        ("a", "v1", 1.0),
        ("a", "v2", 2.0),
        ("a", "v3", 3.0),
        ("b", "v1", 4.0),
        ("c", "v2", 5.0),
        ("c", "v3", 6.0),
        ("d", "v4", 7.0),
        ("e", "v3", 8.0),
        ("e", "v5", 9.0),
        ("e", "v6", 10.0),
        ("f", "v6", 11.0),
    ]
    for node_name, var_name, value in entries:
        state.set(node_name, var_name, value)

    def _check(query, expected):
        """Run the query and compare the size and every expected entry."""
        results = state.extract_parameters(query)
        assert len(results) == len(expected)
        for name, value in expected.items():
            assert results[name] == value

    # A mixture of unique parameters given by full and short names.
    _check(["a.v1", "c.v2", "v5"], {"a.v1": 1.0, "c.v2": 5.0, "v5": 9.0})

    # A short name that appears in multiple nodes is expanded for each instance,
    # while a unique short name (v4) is left unexpanded.
    _check(
        ["v2", "v3", "v4"],
        {
            "a.v2": 2.0,
            "a.v3": 3.0,
            "c.v2": 5.0,
            "c.v3": 6.0,
            "e.v3": 8.0,
            "v4": 7.0,
        },
    )

    # A single parameter name can be given as a plain string.
    _check("v2", {"a.v2": 2.0, "c.v2": 5.0})

    # A complicated list mixing all of the cases above.
    _check(
        ["v0", "c.v2", "e.v3", "v5", "v6", "a.v3"],
        {
            "v0": 0.0,
            "c.v2": 5.0,
            "e.v3": 8.0,
            "v5": 9.0,
            "e.v6": 10.0,
            "f.v6": 11.0,
            "a.v3": 3.0,
        },
    )

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm thinking something like extract_parameters(['a.v1','c.v2','v3']) which will return {'a.v1': 1, 'c.v2': 5, 'a.v3': 3, 'c.v3': 3, 'e.v3': 3}. Or extract_parameters(['a.v1','v5']) which will return {'a.v1': 1, 'v5': 7} or {'a.v1': 1, 'e.v5': 7} (either is fine but maybe the former is slightly preferred).
For example, I wanted to get a quick plot of the simulated Hubble Diagram using host redshift, so I'd want to do extract_parameters(['host.redshift','x0','x1','c']), and I don't care (or don't know) what node x0,x1 or c is from.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That behavior was the point of the collapse logic. It allowed returning a single value for a parameter using just the parameter name.

In your case of extract_parameters(['host.redshift','x0','x1','c']) how do you want to handle duplicates? If we find instances of x0 in two different nodes we could:

  1. Throw an error [initial behavior]
  2. Collapse them if they are identical
    2a. Throw an error if they are not identical
    2b. Return both in expanded form if they are not identical [second round behavior]
  3. Return all values [current behavior]
  4. Return an arbitrary one of the x0 values. [Not recommended as the behavior becomes undefined].

More concretely, if someone were to call extract_parameters(['ra', 'dec', 't0']) where the host and source have two different values of ra, how would you like the code to behave?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we want to output duplicate parameter names with their node names (say a.x0, b.x0), regardless of whether they are identical in values.
If someone calls extract_parameters(['ra','dec','t0']), we'll return {'host.ra','source.ra','anything_else.ra', ...}, even if some of them are identical in their values.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wanted to add that I think extract_parameters(node=[],parameters=[]) is still useful. Maybe parameters can be expanded to include extract_parameters(parameters=['host.redshift','x0','x1','c'])?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think I've got it, but let me check.

In the case of a state:

a.v1 -> 1.0
a.v2 -> 2.0
a.v3 -> 3.0
b.v1 -> 4.0
b.v2 -> 5.0
b.v4 -> 6.0
c.v1 -> 7.0

and the query extract_parameters(parameters=['a.v1', 'v2', 'v3', 'b.v4']) you would want the function to return the dictionary:

{
    "a.v1": 1.0,
    "a.v2": 2.0,
    "b.v2": 5.0,
    "v3": 3.0,
    "b.v4": 6.0,
}

Specifically:

  1. The full names are always preserved (even if the parameter does not exist in other nodes).
  2. The shortened name is preserved if the parameter is unique.
  3. The shortened name is transformed into multiple long names if the parameter is not unique.
    Is that correct?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I implemented those rules. I've dropped the nodes parameter for now, because it is not clear how to incorporate that as well. We can always extract all parameters with to_dict.


def test_transpose_dict_of_list():
"""Test the transpose_dict_of_list helper function"""
input_dict = {
Expand Down
Loading