diff --git a/pyproject.toml b/pyproject.toml index 7fc95b11..11a11f06 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,6 +17,7 @@ dynamic = ["version"] requires-python = ">=3.9" dependencies = [ "astropy", + "jax", "numpy", "scipy", "sncosmo", diff --git a/src/tdastro/base_models.py b/src/tdastro/base_models.py index e131366c..fc038f1d 100644 --- a/src/tdastro/base_models.py +++ b/src/tdastro/base_models.py @@ -338,10 +338,15 @@ def get_dependencies(self, nodes=None): class FunctionNode(ParameterizedNode): """A class to wrap functions and their argument settings. + The node can compute the result using a given function (the ``func`` + parameter) or through the ``compute()`` method. If ``func`` is ``None`` + then the user must override ``compute()``. + Attributes ---------- func : `function` or `method` - The function to call during an evaluation. + The function to call during an evaluation. If this is ``None`` + you must override the ``compute()`` method directly. args_names : `list` A list of argument names to pass to the function. @@ -394,8 +399,21 @@ def __str__(self): # Extend the FunctionNode's string to include the name of the # function it calls so we can wrap a variety of raw functions. super_name = super().__str__() + if self.func is None: + return super_name return f"{super_name}:{self.func.__name__}" + def _build_args_dict(self, **kwargs): + """Build a dictionary of arguments for the function.""" + args = {} + for key in self.arg_names: + # Override with the kwarg if the parameter is there. + if key in kwargs: + args[key] = kwargs[key] + else: + args[key] = getattr(self, key) + return args + def compute(self, **kwargs): """Execute the wrapped function. @@ -403,12 +421,16 @@ def compute(self, **kwargs): ---------- **kwargs : `dict`, optional Additional function arguments. + + Raises + ------ + ``ValueError`` if ``func`` attribute is ``None``. """ - args = {} - for key in self.arg_names: - # Override with the kwarg if the parameter is there. 
- if key in kwargs: - args[key] = kwargs[key] - else: - args[key] = getattr(self, key) + if self.func is None: + raise ValueError( + "func parameter is None for a FunctionNode. You need to either " + "set func or override compute()." + ) + + args = self._build_args_dict(**kwargs) return self.func(**args) diff --git a/src/tdastro/util_nodes/__init__.py b/src/tdastro/util_nodes/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/tdastro/util_nodes/jax_random.py b/src/tdastro/util_nodes/jax_random.py new file mode 100644 index 00000000..336408ac --- /dev/null +++ b/src/tdastro/util_nodes/jax_random.py @@ -0,0 +1,97 @@ +"""Wrapper classes for calling JAX random number generators.""" + +import jax.random + +from tdastro.base_models import FunctionNode + + +class JaxRandomFunc(FunctionNode): + """The base class for JAX random number generators. + + Attributes + ---------- + _key : `jax._src.prng.PRNGKeyArray` + + Note + ---- + Automatically splits keys each time ``compute()`` is called, so + each call produces a new pseudorandom number. + """ + + def __init__(self, func, **kwargs): + super().__init__(func, **kwargs) + self._key = jax.random.key(self._object_seed) + + def set_graph_base_seed(self, graph_base_seed): + """Set a new graph base seed. + + Notes + ----- + WARNING: This seed should almost never be set manually. Using the same + seed for multiple graph instances will produce biased samples. + + Parameters + ---------- + graph_base_seed : `int`, optional + A base random seed to use for this specific evaluation graph. + """ + super().set_graph_base_seed(graph_base_seed) + + # We recompute the JAX key with the new object seed. + self._key = jax.random.key(self._object_seed) + + def compute(self, **kwargs): + """Execute the wrapped JAX sampling function. + + Parameters + ---------- + **kwargs : `dict`, optional + Additional function arguments. + + Raises + ------ + ``ValueError`` if ``func`` attribute is ``None``. 
+ """ + if self.func is None: + raise ValueError( + "func parameter is None for a JAXRandom. You need to either " + "set func or override compute()." + ) + + args = self._build_args_dict(**kwargs) + self._key, subkey = jax.random.split(self._key) + return float(self.func(subkey, **args)) + + +class JaxRandomNormal(JaxRandomFunc): + """A wrapper for the JAX normal function that takes + a mean and std. + + Attributes + ---------- + loc : `float` + The mean of the distribution. + scale : `float` + The std of the distribution. + """ + + def __init__(self, loc, scale, **kwargs): + super().__init__(jax.random.normal, **kwargs) + + # The mean and std as attributes, but not arguments. + self.add_parameter("loc", loc) + self.add_parameter("scale", scale) + + def compute(self, **kwargs): + """Generate a random number from a normal distribution + with the given mean and std. + + Parameters + ---------- + **kwargs : `dict`, optional + Additional function arguments. + """ + initial_value = super().compute(**kwargs) + local_mean = kwargs.get("loc", self.loc) + local_std = kwargs.get("scale", self.scale) + return local_std * initial_value + local_mean diff --git a/src/tdastro/util_nodes/np_random.py b/src/tdastro/util_nodes/np_random.py new file mode 100644 index 00000000..71478c30 --- /dev/null +++ b/src/tdastro/util_nodes/np_random.py @@ -0,0 +1,70 @@ +"""Wrapper classes for calling numpy random number generators.""" + +import numpy as np + +from tdastro.base_models import FunctionNode + + +class NumpyRandomFunc(FunctionNode): + """The base class for numpy random number generators. + + Attributes + ---------- + func_name : `str` + The name of the random function to use. + _rng : `numpy.random._generator.Generator` + This object's random number generator. + + Notes + ----- + Since we need to create a new random number generator for this object + and use that generator's functions, we cannot pass in the function directly. + Instead we need to pass in the function's name. 
+ + Examples + -------- + # Create a uniform random number generator between 100.0 and 150.0 + func_node = NumpyRandomFunc("uniform", low=100.0, high=150.0) + + # Create a normal random number generator with mean=5.0 and std=1.0 + func_node = NumpyRandomFunc("normal", loc=5.0, scale=1.0) + """ + + def __init__(self, func_name, **kwargs): + self.func_name = func_name + self._rng = np.random.default_rng() + if not hasattr(self._rng, func_name): + raise ValueError(f"Random function {func_name} does not exist.") + func = getattr(self._rng, func_name) + super().__init__(func, **kwargs) + + def set_graph_base_seed(self, graph_base_seed): + """Set a new graph base seed. + + Notes + ----- + WARNING: This seed should almost never be set manually. Using the same + seed for multiple graph instances will produce biased samples. + + Parameters + ---------- + graph_base_seed : `int`, optional + A base random seed to use for this specific evaluation graph. + """ + super().set_graph_base_seed(graph_base_seed) + + # We create a new random number generator with the new object seed and + # link to that object's function. + self._rng = np.random.default_rng(self._object_seed) + self.func = getattr(self._rng, self.func_name) + + def compute(self, **kwargs): + """Execute the wrapped numpy random number generator method. + + Parameters + ---------- + **kwargs : `dict`, optional + Additional function arguments. 
+ """ + args = self._build_args_dict(**kwargs) + return self.func(**args) diff --git a/tests/tdastro/util_nodes/test_jax_random.py b/tests/tdastro/util_nodes/test_jax_random.py new file mode 100644 index 00000000..57abc818 --- /dev/null +++ b/tests/tdastro/util_nodes/test_jax_random.py @@ -0,0 +1,48 @@ +import jax.random +import numpy as np +from tdastro.util_nodes.jax_random import JaxRandomFunc, JaxRandomNormal + + +def test_jax_random_uniform(): + """Test that we can generate numbers from a uniform distribution.""" + jax_node = JaxRandomFunc(jax.random.uniform, graph_base_seed=100) + + values = np.array([jax_node.compute() for _ in range(10_000)]) + assert len(np.unique(values)) > 10 + assert np.all(values <= 1.0) + assert np.all(values >= 0.0) + assert np.abs(np.mean(values) - 0.5) < 0.01 + + # If we reuse the seed, we get the same number. + jax_node2 = JaxRandomFunc(jax.random.uniform, graph_base_seed=100) + values2 = np.array([jax_node2.compute() for _ in range(10_000)]) + assert np.allclose(values, values2) + + # We can change the range. + jax_node3 = JaxRandomFunc(jax.random.uniform, graph_base_seed=101, minval=10.0, maxval=20.0) + values = np.array([jax_node3.compute() for _ in range(10_000)]) + assert len(np.unique(values)) > 10 + assert np.all(values <= 20.0) + assert np.all(values >= 10.0) + assert np.abs(np.mean(values) - 15.0) < 0.05 + + # We can override the range dynamically. 
+ values = np.array([jax_node3.compute(minval=2.0) for _ in range(10_000)]) + assert len(np.unique(values)) > 10 + assert np.all(values <= 20.0) + assert np.all(values >= 2.0) + assert np.abs(np.mean(values) - 11.0) < 0.05 + + +def test_jax_random_normal(): + """Test that we can generate numbers from a normal distribution.""" + jax_node = JaxRandomNormal(loc=100.0, scale=10.0, graph_base_seed=100) + + values = np.array([jax_node.compute() for _ in range(10_000)]) + assert np.abs(np.mean(values) - 100.0) < 0.5 + assert np.abs(np.std(values) - 10.0) < 0.5 + + # If we reuse the seed, we get the same number. + jax_node2 = JaxRandomNormal(loc=100.0, scale=10.0, graph_base_seed=100) + values2 = np.array([jax_node2.compute() for _ in range(10_000)]) + assert np.allclose(values, values2) diff --git a/tests/tdastro/util_nodes/test_np_random.py b/tests/tdastro/util_nodes/test_np_random.py new file mode 100644 index 00000000..75da1a2d --- /dev/null +++ b/tests/tdastro/util_nodes/test_np_random.py @@ -0,0 +1,47 @@ +import numpy as np +from tdastro.util_nodes.np_random import NumpyRandomFunc + + +def test_numpy_random_uniform(): + """Test that we can generate numbers from a uniform distribution.""" + np_node = NumpyRandomFunc("uniform", graph_base_seed=100) + + values = np.array([np_node.compute() for _ in range(10_000)]) + assert len(np.unique(values)) > 10 + assert np.all(values <= 1.0) + assert np.all(values >= 0.0) + assert np.abs(np.mean(values) - 0.5) < 0.01 + + # If we reuse the seed, we get the same number. + np_node2 = NumpyRandomFunc("uniform", graph_base_seed=100) + values2 = np.array([np_node2.compute() for _ in range(10_000)]) + assert np.allclose(values, values2) + + # We can change the range. 
+ np_node3 = NumpyRandomFunc("uniform", graph_base_seed=101, low=10.0, high=20.0) + values = np.array([np_node3.compute() for _ in range(10_000)]) + assert len(np.unique(values)) > 10 + assert np.all(values <= 20.0) + assert np.all(values >= 10.0) + assert np.abs(np.mean(values) - 15.0) < 0.05 + + # We can override the range dynamically. + values = np.array([np_node3.compute(low=2.0) for _ in range(10_000)]) + assert len(np.unique(values)) > 10 + assert np.all(values <= 20.0) + assert np.all(values >= 2.0) + assert np.abs(np.mean(values) - 11.0) < 0.05 + + +def test_numpy_random_normal(): + """Test that we can generate numbers from a normal distribution.""" + np_node = NumpyRandomFunc("normal", loc=100.0, scale=10.0, graph_base_seed=100) + + values = np.array([np_node.compute() for _ in range(10_000)]) + assert np.abs(np.mean(values) - 100.0) < 0.5 + assert np.abs(np.std(values) - 10.0) < 0.5 + + # If we reuse the seed, we get the same number. + jax_node2 = NumpyRandomFunc("normal", loc=100.0, scale=10.0, graph_base_seed=100) + values2 = np.array([jax_node2.compute() for _ in range(10_000)]) + assert np.allclose(values, values2)