Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a helper node for testing #163

Merged
merged 4 commits into from
Oct 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion benchmarks/benchmarks.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from tdastro.astro_utils.snia_utils import DistModFromRedshift, HostmassX1Func, X0FromDistMod
from tdastro.astro_utils.unit_utils import fnu_to_flam
from tdastro.base_models import FunctionNode
from tdastro.rand_nodes.np_random import NumpyRandomFunc
from tdastro.math_nodes.np_random import NumpyRandomFunc
from tdastro.sources.sncomso_models import SncosmoWrapperModel


Expand Down
2 changes: 1 addition & 1 deletion src/tdastro/astro_utils/snia_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from scipy.stats.sampling import NumericalInversePolynomial

from tdastro.base_models import FunctionNode
from tdastro.rand_nodes.scipy_random import NumericalInversePolynomialFunc
from tdastro.math_nodes.scipy_random import NumericalInversePolynomialFunc


def snia_volumetric_rates(redshift):
Expand Down
22 changes: 0 additions & 22 deletions src/tdastro/base_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -623,28 +623,6 @@ def build_pytree(self, graph_state, partial=None):
return partial


class SingleVariableNode(ParameterizedNode):
"""A ParameterizedNode holding a single pre-defined variable.

Notes
-----
Often used for testing, but can be used to make graph dependencies clearer.

Parameters
----------
name : `str`
The parameter name.
value : any
The parameter value.
**kwargs : `dict`, optional
Any additional keyword arguments.
"""

def __init__(self, name, value, **kwargs):
super().__init__(**kwargs)
self.add_parameter(name, value, **kwargs)


class FunctionNode(ParameterizedNode):
"""A class to wrap functions and their argument settings.

Expand Down
2 changes: 1 addition & 1 deletion src/tdastro/example_runs/simulate_snia.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from tdastro.astro_utils.passbands import PassbandGroup
from tdastro.astro_utils.snia_utils import DistModFromRedshift, HostmassX1Func, X0FromDistMod
from tdastro.astro_utils.unit_utils import flam_to_fnu, fnu_to_flam
from tdastro.rand_nodes.np_random import NumpyRandomFunc
from tdastro.math_nodes.np_random import NumpyRandomFunc
from tdastro.sources.sncomso_models import SncosmoWrapperModel
from tdastro.sources.snia_host import SNIaHost

Expand Down
File renamed without changes.
68 changes: 68 additions & 0 deletions src/tdastro/math_nodes/given_sampler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
"""A sampler used for testing that produces known (given) results."""

import numpy as np

from tdastro.base_models import FunctionNode


class GivenSampler(FunctionNode):
"""A FunctionNode that returns given results.

Attributes
----------
values : `float`, `list, or `numpy.ndarray`
The values to return.
next_ind : int
The index of the next value.
"""

def __init__(self, values, **kwargs):
self.values = np.array(values)
self.next_ind = 0
super().__init__(self._non_func, **kwargs)

def _non_func(self):
"""This function does nothing. Everything happens in the overloaded compute()."""
pass

def reset(self):
"""Reset the next index to use."""
self.next_ind = 0

def compute(self, graph_state, rng_info=None, **kwargs):
"""Return the given values.

Parameters
----------
graph_state : `GraphState`
An object mapping graph parameters to their values. This object is modified
in place as it is sampled.
rng_info : numpy.random._generator.Generator, optional
Unused in this function, but included to provide consistency with other
compute functions.
**kwargs : `dict`, optional
Additional function arguments.

Returns
-------
results : any
The result of the computation. This return value is provided so that testing
functions can easily access the results.
"""
if graph_state.num_samples == 1:
if self.next_ind >= len(self.values):
raise IndexError()

results = self.values[self.next_ind]
self.next_ind += 1
else:
end_ind = self.next_ind + graph_state.num_samples
if end_ind > len(self.values):
raise IndexError()

results = self.values[self.next_ind : end_ind]
self.next_ind = end_ind

# Save and return the results.
self._save_results(results, graph_state)
return results
25 changes: 25 additions & 0 deletions src/tdastro/math_nodes/single_value_node.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
"""A simple node used for testing."""

from tdastro.base_models import ParameterizedNode


class SingleVariableNode(ParameterizedNode):
"""A ParameterizedNode holding a single pre-defined variable.

Notes
-----
Often used for testing, but can be used to make graph dependencies clearer.

Parameters
----------
name : `str`
The parameter name.
value : any
The parameter value.
**kwargs : `dict`, optional
Any additional keyword arguments.
"""

def __init__(self, name, value, **kwargs):
super().__init__(**kwargs)
self.add_parameter(name, value, **kwargs)
2 changes: 1 addition & 1 deletion tests/tdastro/astro_utils/test_snia_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import pytest
from scipy.stats import norm
from tdastro.astro_utils.snia_utils import DistModFromRedshift, HostmassX1Distr, HostmassX1Func
from tdastro.rand_nodes.np_random import NumpyRandomFunc
from tdastro.math_nodes.np_random import NumpyRandomFunc


def test_hostmass_x1_distr():
Expand Down
74 changes: 74 additions & 0 deletions tests/tdastro/math_nodes/test_given_sampler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
import numpy as np
import pytest
from tdastro.base_models import FunctionNode
from tdastro.graph_state import GraphState
from tdastro.math_nodes.given_sampler import GivenSampler


def _test_func(value1, value2):
"""Return the sum of the two parameters.

Parameters
----------
value1 : `float`
The first parameter.
value2 : `float`
The second parameter.
"""
return value1 + value2


def test_given_sampler():
"""Test that we can retrieve numbers from a GivenSampler."""
given_node = GivenSampler([1.0, 1.5, 2.0, 2.5, 3.0, -1.0, 3.5])

# Check that we generate the correct result and save it in the GraphState.
state1 = GraphState(num_samples=2)
results = given_node.compute(state1)
assert np.array_equal(results, [1.0, 1.5])
assert np.array_equal(given_node.get_param(state1, "function_node_result"), [1.0, 1.5])

state2 = GraphState(num_samples=1)
results = given_node.compute(state2)
assert results == 2.0
assert given_node.get_param(state2, "function_node_result") == 2.0

state3 = GraphState(num_samples=2)
results = given_node.compute(state3)
assert np.array_equal(results, [2.5, 3.0])
assert np.array_equal(given_node.get_param(state3, "function_node_result"), [2.5, 3.0])

# Check that GivenSampler raises an error when it has run out of samples.
state4 = GraphState(num_samples=4)
with pytest.raises(IndexError):
_ = given_node.compute(state4)

# Resetting the GivenSampler starts back at the beginning.
given_node.reset()
state5 = GraphState(num_samples=6)
results = given_node.compute(state5)
assert np.array_equal(results, [1.0, 1.5, 2.0, 2.5, 3.0, -1.0])
assert np.array_equal(
given_node.get_param(state5, "function_node_result"),
[1.0, 1.5, 2.0, 2.5, 3.0, -1.0],
)


def test_test_given_sampler_compound():
"""Test that we can use the GivenSampler as input into another node."""
values = [1.0, 1.5, 2.0, 2.5, 3.0, -1.0, 3.5, 4.0, 10.0, -2.0]
given_node = GivenSampler(values)

# Create a function node that takes the next value and adds 2.0.
compound_node = FunctionNode(_test_func, value1=given_node, value2=2.0)
for val in values:
state = compound_node.sample_parameters()
assert compound_node.get_param(state, "function_node_result") == val + 2.0

# Reset the given node and try generating all the samples.
given_node.reset()
state2 = compound_node.sample_parameters(num_samples=8)
assert np.array_equal(
compound_node.get_param(state2, "function_node_result"),
[3.0, 3.5, 4.0, 4.5, 5.0, 1.0, 5.5, 6.0],
)
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import numpy as np
import pytest
from tdastro.rand_nodes.np_random import NumpyRandomFunc
from tdastro.math_nodes.np_random import NumpyRandomFunc


def test_numpy_random_uniform():
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import numpy as np
from tdastro.rand_nodes.np_random import NumpyRandomFunc
from tdastro.rand_nodes.scipy_random import (
from tdastro.math_nodes.np_random import NumpyRandomFunc
from tdastro.math_nodes.scipy_random import (
NumericalInversePolynomialFunc,
SampleLogPDF,
SamplePDF,
Expand Down
10 changes: 10 additions & 0 deletions tests/tdastro/math_nodes/test_single_value_node.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from tdastro.math_nodes.single_value_node import SingleVariableNode


def test_single_variable_node():
"""Test that we can create and query a SingleVariableNode."""
node = SingleVariableNode("A", 10.0)
assert str(node) == "SingleVariableNode"

state = node.sample_parameters()
assert node.get_param(state, "A") == 10
2 changes: 1 addition & 1 deletion tests/tdastro/sources/test_galaxy_models.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import numpy as np
from tdastro.rand_nodes.np_random import NumpyRandomFunc
from tdastro.math_nodes.np_random import NumpyRandomFunc
from tdastro.sources.galaxy_models import GaussianGalaxy
from tdastro.sources.static_source import StaticSource

Expand Down
2 changes: 1 addition & 1 deletion tests/tdastro/sources/test_sncosmo_models.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import numpy as np
from astropy import units as u
from tdastro.astro_utils.unit_utils import fnu_to_flam
from tdastro.rand_nodes.np_random import NumpyRandomFunc
from tdastro.math_nodes.np_random import NumpyRandomFunc
from tdastro.sources.sncomso_models import SncosmoWrapperModel


Expand Down
12 changes: 2 additions & 10 deletions tests/tdastro/test_base_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
import jax
import numpy as np
import pytest
from tdastro.base_models import FunctionNode, ParameterizedNode, ParameterSource, SingleVariableNode
from tdastro.base_models import FunctionNode, ParameterizedNode, ParameterSource
from tdastro.math_nodes.single_value_node import SingleVariableNode


def _sampler_fun(**kwargs):
Expand Down Expand Up @@ -215,15 +216,6 @@ def test_parameterized_node_build_pytree():
assert pytree["B"]["value2"] == 3.0


def test_single_variable_node():
"""Test that we can create and query a SingleVariableNode."""
node = SingleVariableNode("A", 10.0)
assert str(node) == "SingleVariableNode"

state = node.sample_parameters()
assert node.get_param(state, "A") == 10


def test_function_node_basic():
"""Test that we can create and query a FunctionNode."""
my_func = FunctionNode(_test_func, value1=1.0, value2=2.0)
Expand Down
Loading