Skip to content

Commit

Permalink
Merge pull request #40 from lincc-frameworks/test_helpers
Browse files Browse the repository at this point in the history
Add a simple storage node
  • Loading branch information
jeremykubica authored Jul 18, 2024
2 parents e55b4d6 + e43146f commit 942387d
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 16 deletions.
22 changes: 22 additions & 0 deletions src/tdastro/base_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,6 +335,28 @@ def get_dependencies(self, nodes=None):
return nodes


class SingleVariableNode(ParameterizedNode):
"""A ParameterizedNode holding a single pre-defined variable.
Notes
-----
Often used for testing, but can be used to make graph dependencies clearer.
Parameters
----------
name : `str`
The parameter name.
value : any
The parameter value.
**kwargs : `dict`, optional
Any additional keyword arguments.
"""

def __init__(self, name, value, **kwargs):
super().__init__(**kwargs)
self.add_parameter(name, value, required=True, **kwargs)


class FunctionNode(ParameterizedNode):
"""A class to wrap functions and their argument settings.
Expand Down
24 changes: 8 additions & 16 deletions tests/tdastro/test_base_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import numpy as np
import pytest
from tdastro.base_models import FunctionNode, ParameterizedNode
from tdastro.base_models import FunctionNode, ParameterizedNode, SingleVariableNode


def _sampler_fun(**kwargs):
Expand All @@ -29,20 +29,6 @@ def _test_func(value1, value2):
return value1 + value2


class SingleModel(ParameterizedNode):
"""A test class for the ParameterizedNode.
Attributes
----------
value1 : `float`
The first value.
"""

def __init__(self, value1, **kwargs):
super().__init__(**kwargs)
self.add_parameter("value1", value1, required=True, **kwargs)


class PairModel(ParameterizedNode):
"""A test class for the ParameterizedNode.
Expand Down Expand Up @@ -249,7 +235,7 @@ def test_parameterized_node_seed():
model_a = PairModel(value1=0.5, value2=0.5, graph_base_seed=10, node_identifier="A")
model_b = PairModel(value1=0.5, value2=0.5, graph_base_seed=10, node_identifier="B")
model_c = PairModel(value1=0.5, value2=0.5, graph_base_seed=10, node_identifier="A")
model_d = SingleModel(value1=0.5, node_identifier="A")
model_d = SingleVariableNode("value1", 0.5, node_identifier="A")
assert model_a._object_seed != model_b._object_seed
assert model_a._object_seed == model_c._object_seed
assert model_a._object_seed != model_d._object_seed
Expand All @@ -267,6 +253,12 @@ def test_parameterized_node_seed():
assert model_d._object_seed != model_c._object_seed


def test_single_variable_node():
"""Test that we can create and query a SingleVariableNode."""
node = SingleVariableNode("A", 10.0)
assert node.A == 10


def test_function_node_basic():
"""Test that we can create and query a FunctionNode."""
my_func = FunctionNode(_test_func, value1=1.0, value2=2.0)
Expand Down

0 comments on commit 942387d

Please sign in to comment.