Skip to content

Commit

Permalink
Rename rand_nodes to math_nodes
Browse files Browse the repository at this point in the history
  • Loading branch information
jeremykubica committed Oct 16, 2024
1 parent 8ef802c commit fefd0f1
Show file tree
Hide file tree
Showing 9 changed files with 40 additions and 35 deletions.
2 changes: 1 addition & 1 deletion benchmarks/benchmarks.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from tdastro.astro_utils.snia_utils import DistModFromRedshift, HostmassX1Func, X0FromDistMod
from tdastro.astro_utils.unit_utils import fnu_to_flam
from tdastro.base_models import FunctionNode
from tdastro.rand_nodes.np_random import NumpyRandomFunc
from tdastro.math_nodes.np_random import NumpyRandomFunc
from tdastro.sources.sncomso_models import SncosmoWrapperModel


Expand Down
22 changes: 0 additions & 22 deletions src/tdastro/base_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -623,28 +623,6 @@ def build_pytree(self, graph_state, partial=None):
return partial


class SingleVariableNode(ParameterizedNode):
"""A ParameterizedNode holding a single pre-defined variable.
Notes
-----
Often used for testing, but can be used to make graph dependencies clearer.
Parameters
----------
name : `str`
The parameter name.
value : any
The parameter value.
**kwargs : `dict`, optional
Any additional keyword arguments.
"""

def __init__(self, name, value, **kwargs):
super().__init__(**kwargs)
self.add_parameter(name, value, **kwargs)


class FunctionNode(ParameterizedNode):
"""A class to wrap functions and their argument settings.
Expand Down
25 changes: 25 additions & 0 deletions src/tdastro/math_nodes/single_value_node.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
"""A simple node used for testing."""

from tdastro.base_models import ParameterizedNode


class SingleVariableNode(ParameterizedNode):
"""A ParameterizedNode holding a single pre-defined variable.
Notes
-----
Often used for testing, but can be used to make graph dependencies clearer.
Parameters
----------
name : `str`
The parameter name.
value : any
The parameter value.
**kwargs : `dict`, optional
Any additional keyword arguments.
"""

def __init__(self, name, value, **kwargs):
super().__init__(**kwargs)
self.add_parameter(name, value, **kwargs)
4 changes: 2 additions & 2 deletions tests/tdastro/astro_utils/test_snia_utils.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import numpy as np
import pytest
from scipy.stats import norm
from tdastro.math_nodess.snia_utils import DistModFromRedshift, HostmassX1Distr, HostmassX1Func
from tdastro.rand_nodes.np_random import NumpyRandomFunc
from tdastro.astro_utils.snia_utils import DistModFromRedshift, HostmassX1Distr, HostmassX1Func
from tdastro.math_nodes.np_random import NumpyRandomFunc


def test_hostmass_x1_distr():
Expand Down
File renamed without changes.
10 changes: 10 additions & 0 deletions tests/tdastro/math_nodes/test_single_value_node.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from tdastro.math_nodes.single_value_node import SingleVariableNode


def test_single_variable_node():
"""Test that we can create and query a SingleVariableNode."""
node = SingleVariableNode("A", 10.0)
assert str(node) == "SingleVariableNode"

state = node.sample_parameters()
assert node.get_param(state, "A") == 10
12 changes: 2 additions & 10 deletions tests/tdastro/test_base_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
import jax
import numpy as np
import pytest
from tdastro.base_models import FunctionNode, ParameterizedNode, ParameterSource, SingleVariableNode
from tdastro.base_models import FunctionNode, ParameterizedNode, ParameterSource
from tdastro.math_nodes.single_value_node import SingleVariableNode


def _sampler_fun(**kwargs):
Expand Down Expand Up @@ -215,15 +216,6 @@ def test_parameterized_node_build_pytree():
assert pytree["B"]["value2"] == 3.0


def test_single_variable_node():
"""Test that we can create and query a SingleVariableNode."""
node = SingleVariableNode("A", 10.0)
assert str(node) == "SingleVariableNode"

state = node.sample_parameters()
assert node.get_param(state, "A") == 10


def test_function_node_basic():
"""Test that we can create and query a FunctionNode."""
my_func = FunctionNode(_test_func, value1=1.0, value2=2.0)
Expand Down

0 comments on commit fefd0f1

Please sign in to comment.