diff --git a/src/tdastro/base_models.py b/src/tdastro/base_models.py index fc038f1d..c2c9e1bf 100644 --- a/src/tdastro/base_models.py +++ b/src/tdastro/base_models.py @@ -335,6 +335,28 @@ def get_dependencies(self, nodes=None): return nodes +class SingleVariableNode(ParameterizedNode): + """A ParameterizedNode holding a single pre-defined variable. + + Notes + ----- + Often used for testing, but can be used to make graph dependencies clearer. + + Parameters + ---------- + name : `str` + The parameter name. + value : any + The parameter value. + **kwargs : `dict`, optional + Any additional keyword arguments. + """ + + def __init__(self, name, value, **kwargs): + super().__init__(**kwargs) + self.add_parameter(name, value, required=True, **kwargs) + + class FunctionNode(ParameterizedNode): """A class to wrap functions and their argument settings. diff --git a/tests/tdastro/test_base_models.py b/tests/tdastro/test_base_models.py index c2e25abb..46ed5e74 100644 --- a/tests/tdastro/test_base_models.py +++ b/tests/tdastro/test_base_models.py @@ -2,7 +2,7 @@ import numpy as np import pytest -from tdastro.base_models import FunctionNode, ParameterizedNode +from tdastro.base_models import FunctionNode, ParameterizedNode, SingleVariableNode def _sampler_fun(**kwargs): @@ -29,20 +29,6 @@ def _test_func(value1, value2): return value1 + value2 -class SingleModel(ParameterizedNode): - """A test class for the ParameterizedNode. - - Attributes - ---------- - value1 : `float` - The first value. - """ - - def __init__(self, value1, **kwargs): - super().__init__(**kwargs) - self.add_parameter("value1", value1, required=True, **kwargs) - - class PairModel(ParameterizedNode): """A test class for the ParameterizedNode. @@ -249,7 +235,7 @@ def test_parameterized_node_seed(): model_a = PairModel(value1=0.5, value2=0.5, graph_base_seed=10, node_identifier="A") model_b = PairModel(value1=0.5, value2=0.5, graph_base_seed=10, node_identifier="B") model_c = PairModel(value1=0.5, value2=0.5, graph_base_seed=10, node_identifier="A") - model_d = SingleModel(value1=0.5, node_identifier="A") + model_d = SingleVariableNode("value1", 0.5, node_identifier="A") assert model_a._object_seed != model_b._object_seed assert model_a._object_seed == model_c._object_seed assert model_a._object_seed != model_d._object_seed @@ -267,6 +253,12 @@ def test_parameterized_node_seed(): assert model_d._object_seed != model_c._object_seed +def test_single_variable_node(): + """Test that we can create and query a SingleVariableNode.""" + node = SingleVariableNode("A", 10.0) + assert node.A == 10 + + def test_function_node_basic(): """Test that we can create and query a FunctionNode.""" my_func = FunctionNode(_test_func, value1=1.0, value2=2.0)