Skip to content

Commit

Permalink
Merge pull request #92 from lincc-frameworks/readable_names
Browse files Browse the repository at this point in the history
Cleanup and simplify internal logic
  • Loading branch information
jeremykubica authored Aug 31, 2024
2 parents 9fcdf6d + c4a5f5d commit 6cbbbfa
Show file tree
Hide file tree
Showing 12 changed files with 124 additions and 166 deletions.
170 changes: 69 additions & 101 deletions src/tdastro/base_models.py

Large diffs are not rendered by default.

15 changes: 13 additions & 2 deletions src/tdastro/graph_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,14 @@ def __init__(self, num_samples=1):
def __len__(self):
return self.num_parameters

def __str__(self):
str_lines = []
for node_name, node_vars in self.states.items():
str_lines.append(f"{node_name}:")
for var_name, value in node_vars.items():
str_lines.append(f" {var_name}: {value}")
return "\n".join(str_lines)

def __getitem__(self, key):
"""Access the dictionary of parameter values for a node name."""
return self.states[key]
Expand Down Expand Up @@ -139,7 +147,7 @@ def set(self, node_name, var_name, value, force_copy=False, fixed=False):
if fixed:
self.fixed_vars[node_name].add(var_name)

def update(self, inputs, force_copy=False):
def update(self, inputs, force_copy=False, all_fixed=False):
"""Set multiple parameters' value in the GraphState from a GraphState or a
dictionary of the same form.
Expand All @@ -156,6 +164,9 @@ def update(self, inputs, force_copy=False):
Make a copy of data in an array. If set to ``False`` this will link
to the array, saving memory and computation time.
Default: ``False``
all_fixed : `bool`
Treat all the parameters in inputs as fixed.
Default: ``False``
Raises
------
Expand All @@ -173,7 +184,7 @@ def update(self, inputs, force_copy=False):
# number of samples.
for node_name, node_vars in new_states.items():
for var_name, value in node_vars.items():
self.set(node_name, var_name, value, force_copy=force_copy)
self.set(node_name, var_name, value, force_copy=force_copy, fixed=all_fixed)

def extract_single_sample(self, sample_num):
"""Create a new GraphState with a single sample state.
Expand Down
15 changes: 6 additions & 9 deletions src/tdastro/sources/physical_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,22 +222,19 @@ def sample_parameters(self, given_args=None, num_samples=1, rng_info=None, **kwa
if self.node_pos is None:
self.set_graph_positions()

args_to_use = {}
if given_args is not None:
args_to_use.update(given_args)
if kwargs is not None:
args_to_use.update(kwargs)

# We use the same seen_nodes for all sampling calls so each node
# is sampled at most one time regardless of link structure.
graph_state = GraphState(num_samples)
if given_args is not None:
graph_state.update(given_args, all_fixed=True)

seen_nodes = {}
if self.background is not None:
self.background._sample_helper(graph_state, seen_nodes, args_to_use, rng_info, **kwargs)
self._sample_helper(graph_state, seen_nodes, args_to_use, rng_info, **kwargs)
self.background._sample_helper(graph_state, seen_nodes, rng_info, **kwargs)
self._sample_helper(graph_state, seen_nodes, rng_info, **kwargs)

for effect in self.effects:
effect._sample_helper(graph_state, seen_nodes, args_to_use, rng_info, **kwargs)
effect._sample_helper(graph_state, seen_nodes, rng_info, **kwargs)

return graph_state

Expand Down
9 changes: 3 additions & 6 deletions src/tdastro/sources/sncomso_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def parameter_values(self):

def _update_sncosmo_model_parameters(self, graph_state):
"""Update the parameters for the wrapped sncosmo model."""
local_params = graph_state.get_node_state(self.node_hash, 0)
local_params = graph_state.get_node_state(self.node_string, 0)
sn_params = {}
for name in self.source_param_names:
sn_params[name] = local_params[name]
Expand Down Expand Up @@ -97,7 +97,7 @@ def set(self, **kwargs):
self.source_param_names.append(key)
self.source.set(**kwargs)

def _sample_helper(self, graph_state, seen_nodes, given_args=None, num_samples=1, rng_info=None):
def _sample_helper(self, graph_state, seen_nodes, num_samples=1, rng_info=None):
"""Internal recursive function to sample the model's underlying parameters
if they are provided by a function or ParameterizedNode.
Expand All @@ -112,9 +112,6 @@ def _sample_helper(self, graph_state, seen_nodes, given_args=None, num_samples=1
seen_nodes : `dict`
A dictionary mapping nodes seen during this sampling run to their ID.
Used to avoid sampling nodes multiple times and to validity check the graph.
given_args : `dict`, optional
A dictionary representing the given arguments for this sample run.
This can be used as the JAX PyTree for differentiation.
num_samples : `int`
A count of the number of samples to compute.
Default: 1
Expand All @@ -126,7 +123,7 @@ def _sample_helper(self, graph_state, seen_nodes, given_args=None, num_samples=1
------
Raise a ``ValueError`` the sampling encounters a problem with the order of dependencies.
"""
super()._sample_helper(graph_state, seen_nodes, given_args, rng_info)
super()._sample_helper(graph_state, seen_nodes, rng_info)
self._update_sncosmo_model_parameters(graph_state)

def _evaluate(self, times, wavelengths, graph_state=None, **kwargs):
Expand Down
13 changes: 5 additions & 8 deletions src/tdastro/util_nodes/jax_random.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,17 +89,14 @@ class JaxRandomFunc(FunctionNode):
def __init__(self, func, **kwargs):
super().__init__(func, **kwargs)

def compute(self, graph_state, given_args=None, rng_info=None, **kwargs):
def compute(self, graph_state, rng_info=None, **kwargs):
"""Execute the wrapped JAX sampling function.
Parameters
----------
graph_state : `GraphState`
An object mapping graph parameters to their values. This object is modified
in place as it is sampled.
given_args : `dict`, optional
A dictionary representing the given arguments for this sample run.
This can be used as the JAX PyTree for differentiation.
rng_info : `dict`, optional
A dictionary of random number generator information for each node, such as
the JAX keys or the numpy rngs.
Expand All @@ -125,13 +122,13 @@ def compute(self, graph_state, given_args=None, rng_info=None, **kwargs):
rng_info[self.node_hash] = next_key

# Generate the results.
args = self._build_inputs(graph_state, given_args, **kwargs)
args = self._build_inputs(graph_state, **kwargs)
if graph_state.num_samples == 1:
results = float(self.func(current_key, **args))
else:
use_shape = [graph_state.num_samples]
results = self.func(current_key, shape=use_shape, **args)
graph_state.set(self.node_hash, self.outputs[0], results)
graph_state.set(self.node_string, self.outputs[0], results)
return results

def generate(self, given_args=None, num_samples=1, rng_info=None, **kwargs):
Expand All @@ -152,7 +149,7 @@ def generate(self, given_args=None, num_samples=1, rng_info=None, **kwargs):
Additional function arguments.
"""
state = self.sample_parameters(given_args, num_samples, rng_info)
return self.compute(state, given_args, rng_info, **kwargs)
return self.compute(state, rng_info, **kwargs)


class JaxRandomNormal(FunctionNode):
Expand Down Expand Up @@ -196,4 +193,4 @@ def generate(self, given_args=None, num_samples=1, rng_info=None, **kwargs):
Any additional keyword arguments.
"""
state = self.sample_parameters(given_args, num_samples, rng_info)
return self.compute(state, given_args, rng_info, **kwargs)
return self.compute(state, rng_info, **kwargs)
9 changes: 3 additions & 6 deletions src/tdastro/util_nodes/np_random.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ def set_seed(self, new_seed):
self._rng = np.random.default_rng(seed=new_seed)
self.func = getattr(self._rng, self.func_name)

def compute(self, graph_state, given_args=None, rng_info=None, **kwargs):
def compute(self, graph_state, rng_info=None, **kwargs):
"""Execute the wrapped function.
The input arguments are taken from the current graph_state and the outputs
Expand All @@ -137,9 +137,6 @@ def compute(self, graph_state, given_args=None, rng_info=None, **kwargs):
graph_state : `GraphState`
An object mapping graph parameters to their values. This object is modified
in place as it is sampled.
given_args : `dict`, optional
A dictionary representing the given arguments for this sample run.
This can be used as the JAX PyTree for differentiation.
rng_info : `dict`, optional
A dictionary of random number generator information for each node, such as
the JAX keys or the numpy rngs.
Expand All @@ -156,7 +153,7 @@ def compute(self, graph_state, given_args=None, rng_info=None, **kwargs):
------
``ValueError`` is ``func`` attribute is ``None``.
"""
args = self._build_inputs(graph_state, given_args, **kwargs)
args = self._build_inputs(graph_state, **kwargs)
num_samples = None if graph_state.num_samples == 1 else graph_state.num_samples

# If a random number generator is given use that. Otherwise use the default one.
Expand Down Expand Up @@ -186,4 +183,4 @@ def generate(self, given_args=None, num_samples=1, rng_info=None, **kwargs):
Additional function arguments.
"""
state = self.sample_parameters(given_args, num_samples, rng_info)
return self.compute(state, given_args, rng_info, **kwargs)
return self.compute(state, rng_info, **kwargs)
9 changes: 3 additions & 6 deletions src/tdastro/util_nodes/scipy_random.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def _create_and_sample(self, args, rng):
sample = NumericalInversePolynomial(dist).rvs(1, rng)[0]
return sample

def compute(self, graph_state, given_args=None, rng_info=None, **kwargs):
def compute(self, graph_state, rng_info=None, **kwargs):
"""Execute the wrapped function.
The input arguments are taken from the current graph_state and the outputs
Expand All @@ -108,9 +108,6 @@ def compute(self, graph_state, given_args=None, rng_info=None, **kwargs):
graph_state : `GraphState`
An object mapping graph parameters to their values. This object is modified
in place as it is sampled.
given_args : `dict`, optional
A dictionary representing the given arguments for this sample run.
This can be used as the JAX PyTree for differentiation.
rng_info : `dict`, optional
A dictionary of random number generator information for each node, such as
the JAX keys or the numpy rngs.
Expand All @@ -132,7 +129,7 @@ def compute(self, graph_state, given_args=None, rng_info=None, **kwargs):
else:
# This is a class so we will need to create a new distribution object
# for each sample (with a single instance of the input parameters).
args = self._build_inputs(graph_state, given_args, **kwargs)
args = self._build_inputs(graph_state, **kwargs)

if graph_state.num_samples == 1:
dist = self._dist(**args)
Expand Down Expand Up @@ -164,4 +161,4 @@ def generate(self, given_args=None, num_samples=1, rng_info=None, **kwargs):
Additional function arguments.
"""
state = self.sample_parameters(given_args, num_samples, rng_info)
return self.compute(state, given_args, rng_info, **kwargs)
return self.compute(state, rng_info, **kwargs)
3 changes: 1 addition & 2 deletions tests/tdastro/sources/test_sncosmo_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ def test_sncomso_models_hsiao() -> None:
state = model.sample_parameters()
assert model.get_param(state, "amplitude") == 2.0e10
assert model.get_param(state, "t0") == 0.0
assert str(model) == "0:tdastro.sources.sncomso_models.SncosmoWrapperModel"
assert str(model) == "0:SncosmoWrapperModel"

assert np.array_equal(model.param_names, ["amplitude"])
assert np.array_equal(model.parameter_values, [2.0e10])
Expand All @@ -29,7 +29,6 @@ def test_sncomso_models_hsiao_t0() -> None:
state = model.sample_parameters()
assert model.get_param(state, "amplitude") == 2.0e10
assert model.get_param(state, "t0") == 55000.0
assert str(model) == "0:tdastro.sources.sncomso_models.SncosmoWrapperModel"

assert np.array_equal(model.param_names, ["amplitude"])
assert np.array_equal(model.parameter_values, [2.0e10])
Expand Down
2 changes: 1 addition & 1 deletion tests/tdastro/sources/test_spline_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ def test_spline_model_flat() -> None:
wavelengths = np.linspace(100.0, 500.0, 25)
fluxes = np.full((len(times), len(wavelengths)), 1.0)
model = SplineModel(times, wavelengths, fluxes)
assert str(model) == "tdastro.sources.spline_model.SplineModel"
assert str(model) == "SplineModel"

test_times = np.array([0.0, 1.0, 2.0, 3.0, 10.0])
test_waves = np.array([0.0, 100.0, 200.0, 1000.0])
Expand Down
4 changes: 2 additions & 2 deletions tests/tdastro/sources/test_static_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,10 @@ def test_static_source_host() -> None:
assert model.get_param(state, "ra") == 1.0
assert model.get_param(state, "dec") == 2.0
assert model.get_param(state, "distance") == 3.0
assert str(model) == "0:tdastro.sources.static_source.StaticSource"
assert str(model) == "0:StaticSource"

# Test that we have given a different name to the host.
assert str(host) == "1:tdastro.sources.static_source.StaticSource"
assert str(host) == "1:StaticSource"


def test_static_source_resample() -> None:
Expand Down
37 changes: 14 additions & 23 deletions tests/tdastro/test_base_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def test_parameter_source():
"""Test the ParameterSource creation and setter functions."""
source = ParameterSource("test")
assert source.parameter_name == "test"
assert source.full_name == "test"
assert source.node_name == ""
assert source.source_type == ParameterSource.UNDEFINED
assert source.dependency is None
assert source.value is None
Expand All @@ -70,17 +70,12 @@ def test_parameter_source():

source.set_as_constant(10.0)
assert source.parameter_name == "test"
assert source.full_name == "test"
assert source.source_type == ParameterSource.CONSTANT
assert source.dependency is None
assert source.value == 10.0
assert not source.fixed
assert not source.required

source.set_name("my_var", "my_node")
assert source.parameter_name == "my_var"
assert source.full_name == "my_node.my_var"

with pytest.raises(ValueError):
source.set_as_constant(_test_func)

Expand All @@ -104,7 +99,7 @@ def test_parameterized_node():
"""Test that we can sample and create a PairModel object."""
# Simple addition
model1 = PairModel(value1=0.5, value2=0.5)
assert str(model1) == "test_base_models.PairModel"
assert str(model1) == "PairModel"

state = model1.sample_parameters()
assert model1.get_param(state, "value1") == 0.5
Expand Down Expand Up @@ -219,26 +214,25 @@ def test_parameterized_node_build_pytree():
model1 = PairModel(value1=0.5, value2=1.5, node_label="A")
model2 = PairModel(value1=model1.value1, value2=3.0, node_label="B")
graph_state = model2.sample_parameters()
pytree = model2.build_pytree(graph_state)

assert len(pytree) == 3
assert pytree["1:A.value1"] == 0.5
assert pytree["1:A.value2"] == 1.5
assert pytree["0:B.value2"] == 3.0
pytree = model2.build_pytree(graph_state)
assert pytree["1:A"]["value1"] == 0.5
assert pytree["1:A"]["value2"] == 1.5
assert pytree["0:B"]["value2"] == 3.0

# Manually set value2 to fixed and check that it no longer appears in the pytree.
model1.setters["value2"].fixed = True

pytree = model2.build_pytree(graph_state)
assert len(pytree) == 2
assert pytree["1:A.value1"] == 0.5
assert pytree["0:B.value2"] == 3.0
assert pytree["1:A"]["value1"] == 0.5
assert pytree["0:B"]["value2"] == 3.0
assert "value2" not in pytree["1:A"]


def test_single_variable_node():
"""Test that we can create and query a SingleVariableNode."""
node = SingleVariableNode("A", 10.0)
assert str(node) == "tdastro.base_models.SingleVariableNode"
assert str(node) == "SingleVariableNode"

state = node.sample_parameters()
assert node.get_param(state, "A") == 10
Expand All @@ -253,7 +247,7 @@ def test_function_node_basic():
assert my_func.compute(state, value2=3.0) == 4.0
assert my_func.compute(state, value2=3.0, unused_param=5.0) == 4.0
assert my_func.compute(state, value2=3.0, value1=1.0) == 4.0
assert str(my_func) == "0:tdastro.base_models.FunctionNode:_test_func"
assert str(my_func) == "0:FunctionNode:_test_func"


def test_function_node_chain():
Expand Down Expand Up @@ -347,14 +341,11 @@ def _test_func2(value1, value2):
graph_state = sum_node.sample_parameters()

pytree = sum_node.build_pytree(graph_state)
assert len(pytree) == 3
print(pytree)

gr_func = jax.value_and_grad(sum_node.resample_and_compute)
values, gradients = gr_func(pytree)
assert len(gradients) == 3
print(gradients)
assert values == 9.0
assert gradients["0:sum:_test_func.value1"] == 1.0
assert gradients["1:div:_test_func2.value1"] == 2.0
assert gradients["1:div:_test_func2.value2"] == -16.0
assert gradients["0:sum:_test_func"]["value1"] == 1.0
assert gradients["1:div:_test_func2"]["value1"] == 2.0
assert gradients["1:div:_test_func2"]["value2"] == -16.0
4 changes: 4 additions & 0 deletions tests/tdastro/test_graph_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@ def test_create_single_sample_graph_state():
with pytest.raises(KeyError):
_ = state["c"]["v1"]

# We can create a human readable string representation of the GraphState.
debug_str = str(state)
assert debug_str == "a:\n v1: 1.0\n v2: 2.0\nb:\n v1: 3.0"

# Check that we can get all the values for a specific node.
a_vals = state.get_node_state("a")
assert len(a_vals) == 2
Expand Down

0 comments on commit 6cbbbfa

Please sign in to comment.