diff --git a/benchmarks/benchmarks.py b/benchmarks/benchmarks.py index f671a66..95f53d1 100644 --- a/benchmarks/benchmarks.py +++ b/benchmarks/benchmarks.py @@ -12,7 +12,7 @@ from tdastro.astro_utils.snia_utils import DistModFromRedshift, HostmassX1Func, X0FromDistMod from tdastro.astro_utils.unit_utils import fnu_to_flam from tdastro.base_models import FunctionNode -from tdastro.rand_nodes.np_random import NumpyRandomFunc +from tdastro.math_nodes.np_random import NumpyRandomFunc from tdastro.sources.sncomso_models import SncosmoWrapperModel diff --git a/src/tdastro/base_models.py b/src/tdastro/base_models.py index d8024c5..24e51b9 100644 --- a/src/tdastro/base_models.py +++ b/src/tdastro/base_models.py @@ -623,28 +623,6 @@ def build_pytree(self, graph_state, partial=None): return partial -class SingleVariableNode(ParameterizedNode): - """A ParameterizedNode holding a single pre-defined variable. - - Notes - ----- - Often used for testing, but can be used to make graph dependencies clearer. - - Parameters - ---------- - name : `str` - The parameter name. - value : any - The parameter value. - **kwargs : `dict`, optional - Any additional keyword arguments. - """ - - def __init__(self, name, value, **kwargs): - super().__init__(**kwargs) - self.add_parameter(name, value, **kwargs) - - class FunctionNode(ParameterizedNode): """A class to wrap functions and their argument settings. diff --git a/src/tdastro/math_nodes/single_value_node.py b/src/tdastro/math_nodes/single_value_node.py new file mode 100644 index 0000000..2e6881d --- /dev/null +++ b/src/tdastro/math_nodes/single_value_node.py @@ -0,0 +1,25 @@ +"""A simple node used for testing.""" + +from tdastro.base_models import ParameterizedNode + + +class SingleVariableNode(ParameterizedNode): + """A ParameterizedNode holding a single pre-defined variable. + + Notes + ----- + Often used for testing, but can be used to make graph dependencies clearer. + + Parameters + ---------- + name : `str` + The parameter name. + value : any + The parameter value. + **kwargs : `dict`, optional + Any additional keyword arguments. + """ + + def __init__(self, name, value, **kwargs): + super().__init__(**kwargs) + self.add_parameter(name, value, **kwargs) diff --git a/tests/tdastro/astro_utils/test_snia_utils.py b/tests/tdastro/astro_utils/test_snia_utils.py index f4ae7f6..8156d2e 100644 --- a/tests/tdastro/astro_utils/test_snia_utils.py +++ b/tests/tdastro/astro_utils/test_snia_utils.py @@ -1,8 +1,8 @@ import numpy as np import pytest from scipy.stats import norm -from tdastro.math_nodess.snia_utils import DistModFromRedshift, HostmassX1Distr, HostmassX1Func -from tdastro.rand_nodes.np_random import NumpyRandomFunc +from tdastro.astro_utils.snia_utils import DistModFromRedshift, HostmassX1Distr, HostmassX1Func +from tdastro.math_nodes.np_random import NumpyRandomFunc def test_hostmass_x1_distr(): diff --git a/tests/tdastro/rand_nodes/test_given_sampler.py b/tests/tdastro/math_nodes/test_given_sampler.py similarity index 100% rename from tests/tdastro/rand_nodes/test_given_sampler.py rename to tests/tdastro/math_nodes/test_given_sampler.py diff --git a/tests/tdastro/rand_nodes/test_np_random.py b/tests/tdastro/math_nodes/test_np_random.py similarity index 100% rename from tests/tdastro/rand_nodes/test_np_random.py rename to tests/tdastro/math_nodes/test_np_random.py diff --git a/tests/tdastro/rand_nodes/test_scipy_random.py b/tests/tdastro/math_nodes/test_scipy_random.py similarity index 100% rename from tests/tdastro/rand_nodes/test_scipy_random.py rename to tests/tdastro/math_nodes/test_scipy_random.py diff --git a/tests/tdastro/math_nodes/test_single_value_node.py b/tests/tdastro/math_nodes/test_single_value_node.py new file mode 100644 index 0000000..7548d9f --- /dev/null +++ b/tests/tdastro/math_nodes/test_single_value_node.py @@ -0,0 +1,10 @@ +from tdastro.math_nodes.single_value_node import SingleVariableNode + + +def test_single_variable_node(): + """Test that we can create and query a SingleVariableNode.""" + node = SingleVariableNode("A", 10.0) + assert str(node) == "SingleVariableNode" + + state = node.sample_parameters() + assert node.get_param(state, "A") == 10 diff --git a/tests/tdastro/test_base_models.py b/tests/tdastro/test_base_models.py index 76f88ab..cbb25c5 100644 --- a/tests/tdastro/test_base_models.py +++ b/tests/tdastro/test_base_models.py @@ -3,7 +3,8 @@ import jax import numpy as np import pytest -from tdastro.base_models import FunctionNode, ParameterizedNode, ParameterSource, SingleVariableNode +from tdastro.base_models import FunctionNode, ParameterizedNode, ParameterSource +from tdastro.math_nodes.single_value_node import SingleVariableNode def _sampler_fun(**kwargs): @@ -215,15 +216,6 @@ def test_parameterized_node_build_pytree(): assert pytree["B"]["value2"] == 3.0 -def test_single_variable_node(): - """Test that we can create and query a SingleVariableNode.""" - node = SingleVariableNode("A", 10.0) - assert str(node) == "SingleVariableNode" - - state = node.sample_parameters() - assert node.get_param(state, "A") == 10 - - def test_function_node_basic(): """Test that we can create and query a FunctionNode.""" my_func = FunctionNode(_test_func, value1=1.0, value2=2.0)