Add a helper function for looking at variables #170

Merged
merged 6 commits into from
Oct 18, 2024
78 changes: 51 additions & 27 deletions src/tdastro/graph_state.py
@@ -321,47 +321,67 @@ def extract_single_sample(self, sample_num):
new_state.states[node_name][var_name] = value[sample_num]
return new_state

def extract_parameters(self, params=None, nodes=None):
def extract_parameters(self, params):
"""Extract the parameter value(s) by a given name. This is often used for
recording the important parameters from an entire model (set of nodes).

Parameters
----------
params : str or list-like, optional
The parameter names to extract. If None (not provided), extract data
for all parameters.
Default: None
nodes: str or list-like, optional
The node names to extract. If None (not provided), extract data
from all nodes.
Default: None
The parameter names to extract. These can be full names ("node.param") or
short parameter names ("param").

Returns
-------
values : dict
The resulting dictionary.
"""
if nodes is None:
use_nodes = None
elif isinstance(nodes, str):
use_nodes = set([nodes])
else:
use_nodes = set(nodes)
# If we are looking up a single parameter, put it into a list.
if isinstance(params, str):
params = [params]

# Go through all the parameters. If a parameter's full name is provided,
# look it up now and save the result. Otherwise put it into a list to check
# for in each node.
single_params = set()
results = {}
for current in params:
if "." in current:
node_name, param_name = current.split(".")
if node_name in self.states and param_name in self.states[node_name]:
results[current] = self.states[node_name][param_name]
else:
single_params.add(current)

if params is None:
use_params = None
elif isinstance(params, str):
use_params = set([params])
else:
use_params = set(params)
if len(single_params) == 0:
# Nothing else to do.
return results

values = {}
# Traverse the nested dictionaries looking for cases where the parameter names match.
first_seen_node = {}
for node_name, node_params in self.states.items():
if use_nodes is None or node_name in use_nodes:
for param_name, param_value in node_params.items():
if use_params is None or param_name in use_params:
values[self.extended_param_name(node_name, param_name)] = param_value
return values
for param_name, param_value in node_params.items():
if param_name in single_params:
if param_name in first_seen_node:
# We've already seen this parameter in another node. Time to use the
# expanded names.

# Start by expanding the result we have already seen if needed.
if param_name in results:
full_name_existing = f"{first_seen_node[param_name]}.{param_name}"
results[full_name_existing] = results[param_name]
del results[param_name]

# Add the result from the current node.
full_name_current = f"{node_name}.{param_name}"
results[full_name_current] = param_value
else:
# This is the first time we have seen the parameter. Save it with
# just the parameter name. Also save the node where we saw it.
results[param_name] = param_value
first_seen_node[param_name] = node_name

return results

def to_table(self):
"""Flatten the graph state to an AstroPy Table with columns for each parameter.
@@ -389,7 +409,11 @@ def to_dict(self):
values : dict
The resulting dictionary.
"""
return self.extract_parameters(nodes=None, params=None)
values = {}
for node_name, node_params in self.states.items():
for param_name, param_value in node_params.items():
values[self.extended_param_name(node_name, param_name)] = list(param_value)
return values

def save_to_file(self, filename, overwrite=False):
"""Save the GraphState to a file.
66 changes: 29 additions & 37 deletions tests/tdastro/test_graph_state.py
@@ -405,61 +405,53 @@ def test_graph_to_from_file():
def test_graph_state_extract_parameters():
"""Test that we can extract named parameters from a GraphState."""
state = GraphState()
state.set("a", "v0", 0.0)
state.set("a", "v1", 1.0)
state.set("a", "v2", 2.0)
state.set("a", "v3", 3.0)
state.set("b", "v1", 4.0)
state.set("c", "v2", 5.0)
state.set("c", "v3", 3.0)
state.set("d", "v4", 6.0)
state.set("e", "v3", 3.0)
state.set("e", "v5", 7.0)

# We can always access a parameter by its full name.
assert state["a.v1"] == 1.0
assert state["a.v2"] == 2.0

# With no filtering, we extract all the parameters.
results = state.extract_parameters()
assert len(results) == 9
state.set("c", "v3", 6.0)
state.set("d", "v4", 7.0)
state.set("e", "v3", 8.0)
state.set("e", "v5", 9.0)
state.set("e", "v6", 10.0)
state.set("f", "v6", 11.0)

# We can extract a mixture of unique parameters based on full and short names.
results = state.extract_parameters(["a.v1", "c.v2", "v5"])
assert len(results) == 3
assert results["a.v1"] == 1.0
assert results["c.v2"] == 5.0
assert results["v5"] == 9.0

# We can extract only certain parameters.
# If we extract a parameter that appears in multiple nodes with its short
# name, we expand the name for each instance.
results = state.extract_parameters(["v2", "v3", "v4"])
assert len(results) == 6
assert results["a.v2"] == 2.0
assert results["a.v3"] == 3.0
assert results["c.v2"] == 5.0
assert results["c.v3"] == 3.0
assert results["e.v3"] == 3.0
assert results["d.v4"] == 6.0
assert results["c.v3"] == 6.0
assert results["e.v3"] == 8.0
assert results["v4"] == 7.0 # We do not expand the name.

# We can also provide a single parameter name.
# We can also provide a single parameter name as a string.
results = state.extract_parameters("v2")
assert len(results) == 2
assert results["a.v2"] == 2.0
assert results["c.v2"] == 5.0

# We can extract from only certain nodes.
results = state.extract_parameters(nodes=["a", "c"])
assert len(results) == 5
assert results["a.v1"] == 1.0
assert results["a.v2"] == 2.0
assert results["a.v3"] == 3.0
assert results["c.v2"] == 5.0
assert results["c.v3"] == 3.0

# We can also provide a single node name.
results = state.extract_parameters(nodes="c")
assert len(results) == 2
assert results["c.v2"] == 5.0
assert results["c.v3"] == 3.0

# We can extract a set of parameters from a set of nodes.
results = state.extract_parameters(nodes=["a", "c", "e"], params=["v1", "v2"])
assert len(results) == 3
assert results["a.v1"] == 1.0
assert results["a.v2"] == 2.0
# Test a complicated list.
results = state.extract_parameters(["v0", "c.v2", "e.v3", "v5", "v6", "a.v3"])
assert len(results) == 7
assert results["v0"] == 0.0
assert results["c.v2"] == 5.0
assert results["e.v3"] == 8.0
assert results["v5"] == 9.0
assert results["e.v6"] == 10.0
assert results["f.v6"] == 11.0
assert results["a.v3"] == 3.0

Contributor

I'm thinking of something like extract_parameters(['a.v1','c.v2','v3']), which will return {'a.v1': 1, 'c.v2': 5, 'a.v3': 3, 'c.v3': 3, 'e.v3': 3}. Or extract_parameters(['a.v1','v5']), which will return {'a.v1': 1, 'v5': 7} or {'a.v1': 1, 'e.v5': 7} (either is fine, but maybe the former is slightly preferred).
For example, I wanted to get a quick plot of the simulated Hubble Diagram using host redshift, so I'd want to do extract_parameters(['host.redshift','x0','x1','c']), and I don't care (or don't know) which node x0, x1, or c comes from.

Contributor Author

That behavior was the point of the collapse logic. It allowed returning a single value for a parameter using just the parameter name.

In your case of extract_parameters(['host.redshift','x0','x1','c']) how do you want to handle duplicates? If we find instances of x0 in two different nodes we could:

  1. Throw an error [initial behavior]
  2. Collapse them if they are identical
    2a. Throw an error if they are not identical
    2b. Return both in expanded form if they are not identical [second round behavior]
  3. Return all values [current behavior]
  4. Return an arbitrary one of the x0 values. [Not recommended as the behavior becomes undefined].

More concretely, if someone were to call extract_parameters(['ra', 'dec', 't0']) where the host and source have two different values of ra, how would you like the code to behave?
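
The options listed above can be compared side by side with a small standalone sketch. This is not code from the PR: `resolve_short_name`, its `policy` argument, and the plain dictionaries are hypothetical stand-ins for illustration, and the equality check assumes scalar values (array-valued parameters would need an array comparison instead).

```python
def resolve_short_name(param, matches, policy):
    """Resolve a short parameter name found in one or more nodes.

    `matches` maps node name -> value for every node that holds `param`.
    `policy` is a hypothetical switch naming the options from the list above.
    """
    if len(matches) == 1:
        # A unique parameter is returned under its short name by every policy.
        (value,) = matches.values()
        return {param: value}

    expanded = {f"{node}.{param}": value for node, value in matches.items()}
    if policy == "error":
        # Option 1: refuse to guess which node was meant.
        raise KeyError(f"Parameter '{param}' appears in nodes {sorted(matches)}")
    if policy == "collapse":
        # Option 2: collapse identical values; option 2b: expand otherwise.
        values = list(matches.values())
        if all(v == values[0] for v in values):
            return {param: values[0]}
        return expanded
    if policy == "expand":
        # Option 3: always return every instance under its full name.
        return expanded
    raise ValueError(f"Unknown policy: {policy}")


different_ra = {"host": 10.0, "source": 10.5}
identical_ra = {"host": 10.0, "source": 10.0}
print(resolve_short_name("ra", different_ra, "expand"))
print(resolve_short_name("ra", identical_ra, "collapse"))
print(resolve_short_name("ra", identical_ra, "expand"))
```

Under "collapse" the keys of the result depend on the sampled values, while under "expand" they depend only on which nodes hold the parameter.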

Contributor

I think we want to output duplicate parameter names with their node names (say a.x0, b.x0), regardless of whether their values are identical.
If someone calls extract_parameters(['ra','dec','t0']), we'll return {'host.ra','source.ra','anything_else.ra', ...}, even if some of them are identical in value.

Contributor

I wanted to add that I think extract_parameters(node=[],parameters=[]) is still useful. Maybe parameters can be extended to accept full names, as in extract_parameters(parameters=['host.redshift','x0','x1','c'])?

Contributor Author

I think I've got it, but let me check.

In the case of a state:

a.v1 -> 1.0
a.v2 -> 2.0
a.v3 -> 3.0
b.v1 -> 4.0
b.v2 -> 5.0
b.v4 -> 6.0
c.v1 -> 7.0

and the query extract_parameters(parameters=['a.v1', 'v2', 'v3', 'b.v4']) you would want the function to return the dictionary:

{
    "a.v1": 1.0,
    "a.v2": 2.0,
    "b.v2": 5.0,
    "v3": 3.0,
    "b.v4": 6.0,
}

Specifically:

  1. Full names are always preserved (even if the parameter does not exist in other nodes).
  2. The shortened name is preserved if the parameter is unique.
  3. The shortened name is transformed into multiple long names if the parameter is not unique.

Is that correct?

Contributor

Yes!

Contributor Author

I implemented those rules. I've dropped the nodes parameter for now, because it is not clear how to incorporate that as well. We can always extract all parameters with to_dict.
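
For readers who want to try the agreed naming rules outside of `GraphState`, here is a minimal standalone sketch that applies them to a plain nested dictionary (node name -> parameter name -> value). It is an illustration under that assumption, not the merged implementation: the function takes the dictionary as an explicit `states` argument, uses a two-step lookup instead of the single pass in the diff, and silently skips full names that are not found.

```python
def extract_parameters(states, params):
    """Apply the three naming rules to a nested dict of node -> param -> value."""
    if isinstance(params, str):
        # A single name is treated as a one-element list.
        params = [params]

    results = {}
    short_names = []
    for name in params:
        if "." in name:
            # Rule 1: full names are always kept as given.
            node_name, param_name = name.split(".", 1)
            if node_name in states and param_name in states[node_name]:
                results[name] = states[node_name][param_name]
        else:
            short_names.append(name)

    for param_name in short_names:
        # Find every node that holds this parameter.
        owners = [node for node, node_params in states.items() if param_name in node_params]
        if len(owners) == 1:
            # Rule 2: a unique parameter keeps its short name.
            results[param_name] = states[owners[0]][param_name]
        else:
            # Rule 3: a repeated parameter is expanded to one full name per node.
            for node in owners:
                results[f"{node}.{param_name}"] = states[node][param_name]
    return results


# The example state and query from the discussion above.
states = {
    "a": {"v1": 1.0, "v2": 2.0, "v3": 3.0},
    "b": {"v1": 4.0, "v2": 5.0, "v4": 6.0},
    "c": {"v1": 7.0},
}
result = extract_parameters(states, ["a.v1", "v2", "v3", "b.v4"])
print(result)
```

The returned dictionary has the same five keys and values as the one proposed in the thread, although the insertion order differs because full names are resolved first.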


def test_transpose_dict_of_list():