Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add wrappers to random number generators #39

Merged
merged 3 commits into from
Jul 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ dynamic = ["version"]
requires-python = ">=3.9"
dependencies = [
"astropy",
"jax",
"numpy",
"scipy",
"sncosmo",
Expand Down
38 changes: 30 additions & 8 deletions src/tdastro/base_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,10 +338,15 @@ def get_dependencies(self, nodes=None):
class FunctionNode(ParameterizedNode):
"""A class to wrap functions and their argument settings.

The node can compute the result using a given function (the ``func``
parameter) or through the ``compute()`` method. If ``func`` is ``None``
then the user must override ``compute()``.

Attributes
----------
func : `function` or `method`
The function to call during an evaluation.
The function to call during an evaluation. If this is ``None``
you must override the ``compute()`` method directly.
arg_names : `list`
A list of argument names to pass to the function.

Expand Down Expand Up @@ -394,21 +399,38 @@ def __str__(self):
# Extend the FunctionNode's string to include the name of the
# function it calls so we can wrap a variety of raw functions.
super_name = super().__str__()
if self.func is None:
return super_name
return f"{super_name}:{self.func.__name__}"

def _build_args_dict(self, **kwargs):
    """Assemble the keyword arguments for the wrapped function.

    Every name listed in ``self.arg_names`` gets an entry. A value passed
    in ``kwargs`` takes precedence; otherwise the value is read from the
    attribute of the same name on this object.

    Parameters
    ----------
    **kwargs : `dict`, optional
        Per-call overrides for individual arguments. Keys that are not
        in ``self.arg_names`` are ignored.

    Returns
    -------
    args : `dict`
        A mapping from argument name to the value to use for this call.
    """
    return {
        name: kwargs[name] if name in kwargs else getattr(self, name)
        for name in self.arg_names
    }

def compute(self, **kwargs):
    """Execute the wrapped function.

    Parameters
    ----------
    **kwargs : `dict`, optional
        Additional function arguments. Values given here override the
        stored values of the same name for this call only.

    Returns
    -------
    The value returned by ``func`` when called with the assembled
    keyword arguments.

    Raises
    ------
    ``ValueError`` if the ``func`` attribute is ``None``.
    """
    # Check for a missing function before resolving any arguments so the
    # caller always gets the documented ValueError in that case.
    if self.func is None:
        raise ValueError(
            "func parameter is None for a FunctionNode. You need to either "
            "set func or override compute()."
        )

    # Argument resolution lives in one place: the shared helper.
    args = self._build_args_dict(**kwargs)
    return self.func(**args)
Empty file.
97 changes: 97 additions & 0 deletions src/tdastro/util_nodes/jax_random.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
"""Wrapper classes for calling JAX random number generators."""

import jax.random

from tdastro.base_models import FunctionNode


class JaxRandomFunc(FunctionNode):
    """A function node that samples from a JAX random number generator.

    Attributes
    ----------
    _key : `jax._src.prng.PRNGKeyArray`
        The current JAX PRNG key, built from ``self._object_seed``.

    Note
    ----
    The key is split on every call to ``compute()``: one half replaces
    ``_key`` and the other is handed to the sampling function, so
    successive calls produce new pseudorandom numbers.
    """

    def __init__(self, func, **kwargs):
        super().__init__(func, **kwargs)
        self._reset_key()

    def _reset_key(self):
        """Rebuild the JAX key from the object's current seed."""
        self._key = jax.random.key(self._object_seed)

    def set_graph_base_seed(self, graph_base_seed):
        """Set a new graph base seed and regenerate the JAX key.

        Notes
        -----
        WARNING: This seed should almost never be set manually. Using the same
        seed for multiple graph instances will produce biased samples.

        Parameters
        ----------
        graph_base_seed : `int`, optional
            A base random seed to use for this specific evaluation graph.
        """
        super().set_graph_base_seed(graph_base_seed)

        # The object seed may have changed, so the key must follow it.
        self._reset_key()

    def compute(self, **kwargs):
        """Draw one sample from the wrapped JAX sampling function.

        Parameters
        ----------
        **kwargs : `dict`, optional
            Additional function arguments.

        Returns
        -------
        sample : `float`
            The function's result converted with ``float()``.

        Raises
        ------
        ``ValueError`` if the ``func`` attribute is ``None``.
        """
        if self.func is None:
            raise ValueError(
                "func parameter is None for a JAXRandom. You need to either "
                "set func or override compute()."
            )

        call_args = self._build_args_dict(**kwargs)

        # Keep one half of the split as the new state; sample with the other.
        next_key, sample_key = jax.random.split(self._key)
        self._key = next_key
        return float(self.func(sample_key, **call_args))


class JaxRandomNormal(JaxRandomFunc):
    """A JAX normal sampler parameterized by a mean and a standard deviation.

    The wrapped function is ``jax.random.normal``; its result is scaled
    by ``scale`` and shifted by ``loc`` in ``compute()``.

    Attributes
    ----------
    loc : `float`
        The mean of the distribution.
    scale : `float`
        The std of the distribution.
    """

    def __init__(self, loc, scale, **kwargs):
        super().__init__(jax.random.normal, **kwargs)

        # Registered as parameters of the node rather than passed through
        # ``kwargs`` to the parent constructor.
        for param_name, param_value in (("loc", loc), ("scale", scale)):
            self.add_parameter(param_name, param_value)

    def compute(self, **kwargs):
        """Draw one sample from a normal distribution with the given
        mean and std.

        Parameters
        ----------
        **kwargs : `dict`, optional
            Additional function arguments. ``loc`` and ``scale`` given
            here override the stored values for this call.

        Returns
        -------
        The parent class's draw multiplied by the std plus the mean.
        """
        base_draw = super().compute(**kwargs)
        mean = kwargs.get("loc", self.loc)
        std = kwargs.get("scale", self.scale)
        return std * base_draw + mean
70 changes: 70 additions & 0 deletions src/tdastro/util_nodes/np_random.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
"""Wrapper classes for calling numpy random number generators."""

import numpy as np

from tdastro.base_models import FunctionNode


class NumpyRandomFunc(FunctionNode):
    """A function node that samples from a numpy random number generator.

    Attributes
    ----------
    func_name : `str`
        The name of the random function to use.
    _rng : `numpy.random._generator.Generator`
        This object's random number generator.

    Notes
    -----
    Each object owns its own generator and calls that generator's bound
    method, so the function cannot be passed in directly. It is selected
    by name instead.

    Examples
    --------
    # Create a uniform random number generator between 100.0 and 150.0
    func_node = NumpyRandomFunc("uniform", low=100.0, high=150.0)

    # Create a normal random number generator with mean=5.0 and std=1.0
    func_node = NumpyRandomFunc("normal", loc=5.0, scale=1.0)
    """

    def __init__(self, func_name, **kwargs):
        generator = np.random.default_rng()
        self.func_name = func_name
        self._rng = generator

        # Look the method up by name; an unknown name is a caller error.
        try:
            sampler = getattr(generator, func_name)
        except AttributeError:
            raise ValueError(f"Random function {func_name} does not exist.") from None
        super().__init__(sampler, **kwargs)

    def set_graph_base_seed(self, graph_base_seed):
        """Set a new graph base seed and rebuild the generator.

        Notes
        -----
        WARNING: This seed should almost never be set manually. Using the same
        seed for multiple graph instances will produce biased samples.

        Parameters
        ----------
        graph_base_seed : `int`, optional
            A base random seed to use for this specific evaluation graph.
        """
        super().set_graph_base_seed(graph_base_seed)

        # A new generator is seeded from the object seed, and ``func`` is
        # re-bound to the same-named method on that new generator.
        seeded_rng = np.random.default_rng(self._object_seed)
        self._rng = seeded_rng
        self.func = getattr(seeded_rng, self.func_name)

    def compute(self, **kwargs):
        """Call the wrapped numpy random number generator method.

        Parameters
        ----------
        **kwargs : `dict`, optional
            Additional function arguments.

        Returns
        -------
        The value returned by the generator method.
        """
        call_args = self._build_args_dict(**kwargs)
        return self.func(**call_args)
48 changes: 48 additions & 0 deletions tests/tdastro/util_nodes/test_jax_random.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import jax.random
import numpy as np
from tdastro.util_nodes.jax_random import JaxRandomFunc, JaxRandomNormal


def test_jax_random_uniform():
    """Check sampling from a uniform distribution via ``JaxRandomFunc``."""
    num_draws = 10_000

    def draw(node, **overrides):
        # Collect ``num_draws`` consecutive samples from the node.
        return np.array([node.compute(**overrides) for _ in range(num_draws)])

    first_node = JaxRandomFunc(jax.random.uniform, graph_base_seed=100)
    values = draw(first_node)
    assert np.unique(values).size > 10
    assert (values <= 1.0).all()
    assert (values >= 0.0).all()
    assert abs(values.mean() - 0.5) < 0.01

    # A second node built with the same seed reproduces the sequence.
    same_seed_node = JaxRandomFunc(jax.random.uniform, graph_base_seed=100)
    values2 = draw(same_seed_node)
    assert np.allclose(values, values2)

    # The range can be set at construction time.
    ranged_node = JaxRandomFunc(jax.random.uniform, graph_base_seed=101, minval=10.0, maxval=20.0)
    values = draw(ranged_node)
    assert np.unique(values).size > 10
    assert (values <= 20.0).all()
    assert (values >= 10.0).all()
    assert abs(values.mean() - 15.0) < 0.05

    # The range can also be overridden per call.
    values = draw(ranged_node, minval=2.0)
    assert np.unique(values).size > 10
    assert (values <= 20.0).all()
    assert (values >= 2.0).all()
    assert abs(values.mean() - 11.0) < 0.05


def test_jax_random_normal():
    """Check sampling from a normal distribution via ``JaxRandomNormal``."""
    num_draws = 10_000

    first_node = JaxRandomNormal(loc=100.0, scale=10.0, graph_base_seed=100)
    values = np.array([first_node.compute() for _ in range(num_draws)])
    assert abs(values.mean() - 100.0) < 0.5
    assert abs(values.std() - 10.0) < 0.5

    # A second node built with the same seed reproduces the sequence.
    same_seed_node = JaxRandomNormal(loc=100.0, scale=10.0, graph_base_seed=100)
    values2 = np.array([same_seed_node.compute() for _ in range(num_draws)])
    assert np.allclose(values, values2)
47 changes: 47 additions & 0 deletions tests/tdastro/util_nodes/test_np_random.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import numpy as np
from tdastro.util_nodes.np_random import NumpyRandomFunc


def test_numpy_random_uniform():
    """Check sampling from a uniform distribution via ``NumpyRandomFunc``."""
    num_draws = 10_000

    def draw(node, **overrides):
        # Collect ``num_draws`` consecutive samples from the node.
        return np.array([node.compute(**overrides) for _ in range(num_draws)])

    first_node = NumpyRandomFunc("uniform", graph_base_seed=100)
    values = draw(first_node)
    assert np.unique(values).size > 10
    assert (values <= 1.0).all()
    assert (values >= 0.0).all()
    assert abs(values.mean() - 0.5) < 0.01

    # A second node built with the same seed reproduces the sequence.
    same_seed_node = NumpyRandomFunc("uniform", graph_base_seed=100)
    values2 = draw(same_seed_node)
    assert np.allclose(values, values2)

    # The range can be set at construction time.
    ranged_node = NumpyRandomFunc("uniform", graph_base_seed=101, low=10.0, high=20.0)
    values = draw(ranged_node)
    assert np.unique(values).size > 10
    assert (values <= 20.0).all()
    assert (values >= 10.0).all()
    assert abs(values.mean() - 15.0) < 0.05

    # The range can also be overridden per call.
    values = draw(ranged_node, low=2.0)
    assert np.unique(values).size > 10
    assert (values <= 20.0).all()
    assert (values >= 2.0).all()
    assert abs(values.mean() - 11.0) < 0.05


def test_numpy_random_normal():
    """Check sampling from a normal distribution via ``NumpyRandomFunc``."""
    num_draws = 10_000

    first_node = NumpyRandomFunc("normal", loc=100.0, scale=10.0, graph_base_seed=100)
    values = np.array([first_node.compute() for _ in range(num_draws)])
    assert abs(values.mean() - 100.0) < 0.5
    assert abs(values.std() - 10.0) < 0.5

    # A second node built with the same seed reproduces the sequence.
    same_seed_node = NumpyRandomFunc("normal", loc=100.0, scale=10.0, graph_base_seed=100)
    values2 = np.array([same_seed_node.compute() for _ in range(num_draws)])
    assert np.allclose(values, values2)
Loading