Skip to content

Commit

Permalink
Merge pull request #37 from lincc-frameworks/random_seed
Browse files Browse the repository at this point in the history
Add the ability to set a graph-wide random seed.
  • Loading branch information
jeremykubica authored Jul 17, 2024
2 parents a94c158 + a450194 commit df3e2a5
Show file tree
Hide file tree
Showing 7 changed files with 181 additions and 45 deletions.
9 changes: 4 additions & 5 deletions src/tdastro/astro_utils/cosmology.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,16 +42,15 @@ class RedshiftDistFunc(FunctionNode):
The function or constant providing the redshift value.
cosmology : `astropy.cosmology`
The cosmology specification.
**kwargs : `dict`, optional
Any additional keyword arguments.
"""

def __init__(self, redshift, cosmology):
def __init__(self, redshift, cosmology, **kwargs):
# Call the super class's constructor with the needed information.
super().__init__(
func=redshift_to_distance,
redshift=redshift,
cosmology=cosmology,
**kwargs,
)

def __str__(self):
"""Return the string representation of the function."""
return f"RedshiftDistFunc({self.cosmology.name}, {self.kind})"
100 changes: 93 additions & 7 deletions src/tdastro/base_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

import types
from enum import Enum
from hashlib import md5
from os import urandom


class ParameterSource(Enum):
Expand Down Expand Up @@ -30,19 +32,59 @@ class ParameterizedNode:
sample_iteration : `int`
A counter used to synchronize sampling runs. Tracks how many times this
model's parameters have been resampled.
_object_seed : `int`
An object-specific seed to control random number generation.
_graph_base_seed : `int`
A base random seed to use for this specific evaluation graph. Used
for validity checking.
Parameters
----------
node_identifier : `str`, optional
An identifier (or name) for the current node.
graph_base_seed : `int`, optional
A base random seed to use for this specific evaluation graph.
WARNING: This seed should almost never be set manually. Using the same
seed for multiple graph instances will produce biased samples.
If set to ``None`` will use urandom() to produce a fully random seed.
**kwargs : `dict`, optional
Any additional keyword arguments.
"""

def __init__(self, node_identifier=None, **kwargs):
def __init__(self, node_identifier=None, graph_base_seed=None, **kwargs):
self.setters = {}
self.sample_iteration = 0
self.node_identifier = node_identifier
self.set_graph_base_seed(graph_base_seed)

def __str__(self):
"""Return the string representation of the node."""
name = f"{self.__class__.__module__}.{self.__class__.__qualname__}"
if self.node_identifier:
return f"{self.node_identifier}={self.__class__.__name__}"
return f"{self.node_identifier}={name}"
else:
return self.__class__.__name__
return name

def set_graph_base_seed(self, graph_base_seed):
    """Set a new graph base seed and derive this node's object seed.

    The object seed is the base seed plus an offset computed from the MD5
    digest of ``str(self)``, reduced modulo ``2**31``.

    Notes
    -----
    WARNING: This seed should almost never be set manually. Using the same
    seed for multiple graph instances will produce biased samples.

    Parameters
    ----------
    graph_base_seed : `int`, optional
        A base random seed to use for this specific evaluation graph.
        Integer-like values (for example numpy integers) are converted to
        a Python ``int``. If ``None``, four bytes from ``urandom()`` are
        used to build the seed.
    """
    if graph_base_seed is None:
        graph_base_seed = int.from_bytes(urandom(4), "big")
    else:
        # Normalize to a Python int so that fixed-width integer types cannot
        # overflow (or decay to float) in the arithmetic below.
        graph_base_seed = int(graph_base_seed)
    self._graph_base_seed = graph_base_seed

    # The digest is a 128-bit integer; reduce it before adding so every
    # intermediate value stays small. For Python ints this yields the same
    # result as reducing only once at the end.
    seed_modulus = 2**31
    name_digest = md5(str(self).encode()).hexdigest()
    seed_offset = int(name_digest, base=16) % seed_modulus
    self._object_seed = (graph_base_seed + seed_offset) % seed_modulus

def check_resample(self, other):
"""Check if we need to resample the current node based
Expand Down Expand Up @@ -264,6 +306,34 @@ def get_all_parameter_values(self, recursive=True, seen=None):
values[full_name] = getattr(self, name)
return values

def get_dependencies(self, nodes=None):
    """Collect this node together with every node it depends on.

    Parameters
    ----------
    nodes : `set`, optional
        The nodes gathered so far. A new set is created when ``None``.
        This is modified in place.

    Returns
    -------
    nodes : `set`
        The set containing this node and the nodes reached through its
        setters.
    """
    if nodes is None:
        nodes = set()

    # A node that was already visited contributes nothing new.
    if self in nodes:
        return nodes
    nodes.add(self)

    for kind, source, _unused in self.setters.values():
        # Work out which node (if any) supplies this parameter.
        if kind == ParameterSource.MODEL_ATTRIBUTE:
            upstream = source
        elif kind == ParameterSource.MODEL_METHOD:
            # For a bound method the supplying node is the method's owner.
            upstream = source.__self__
        else:
            continue
        nodes = upstream.get_dependencies(nodes)
    return nodes


class FunctionNode(ParameterizedNode):
"""A class to wrap functions and their argument settings.
Expand All @@ -275,6 +345,20 @@ class FunctionNode(ParameterizedNode):
args_names : `list`
A list of argument names to pass to the function.
Parameters
----------
func : `function` or `method`
The function to call during an evaluation.
node_identifier : `str`, optional
An identifier (or name) for the current node.
graph_base_seed : `int`, optional
A base random seed to use for this specific evaluation graph.
WARNING: This seed should almost never be set manually. Using the same
seed for multiple graph instances will produce biased samples.
If set to ``None`` will use urandom() to produce a fully random seed.
**kwargs : `dict`, optional
Any additional keyword arguments.
Examples
--------
my_func = FunctionNode(random.randint, a=1, b=10)
Expand All @@ -293,19 +377,21 @@ class FunctionNode(ParameterizedNode):
value1 = my_func(b=10.0)
"""

def __init__(self, func, **kwargs):
super().__init__(**kwargs)
def __init__(self, func, node_identifier=None, graph_base_seed=None, **kwargs):
# We set the function before calling the parent class so we can use
# the function's name (if needed).
self.func = func
self.arg_names = []
super().__init__(node_identifier=node_identifier, graph_base_seed=graph_base_seed, **kwargs)

# Add all of the parameters from default_args or the kwargs.
self.arg_names = []
for key, value in kwargs.items():
self.arg_names.append(key)
self.add_parameter(key, value)

def __str__(self):
"""Return the string representation of the function."""
return f"FunctionNode({self.func.name})"
return f"FunctionNode({self.func.__name__})"

def compute(self, **kwargs):
"""Execute the wrapped function.
Expand Down
4 changes: 0 additions & 4 deletions src/tdastro/effects/redshift.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,6 @@ def __init__(self, redshift=None, t0=None, **kwargs):
self.add_parameter("redshift", redshift, required=True, **kwargs)
self.add_parameter("t0", t0, required=True, **kwargs)

def __str__(self) -> str:
"""Return a string representation of the Redshift effect model."""
return f"RedshiftEffect(redshift={self.redshift})"

def pre_effect(self, observer_frame_times, observer_frame_wavelengths, **kwargs):
"""Calculate the rest-frame times and wavelengths needed to give us the observer-frame times
and wavelengths (given the redshift).
Expand Down
2 changes: 1 addition & 1 deletion tests/tdastro/sources/test_sncosmo_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ def test_sncomso_models_hsiao() -> None:
model = SncosmoWrapperModel("hsiao")
model.set(amplitude=1.0e-10)
assert model.amplitude == 1.0e-10
assert str(model) == "SncosmoWrapperModel"
assert str(model) == "tdastro.sources.sncomso_models.SncosmoWrapperModel"

assert np.array_equal(model.param_names, ["amplitude"])
assert np.array_equal(model.parameters, [1.0e-10])
Expand Down
4 changes: 2 additions & 2 deletions tests/tdastro/sources/test_spline_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ def test_spline_model_flat() -> None:
wavelengths = np.linspace(100.0, 500.0, 25)
fluxes = np.full((len(times), len(wavelengths)), 1.0)
model = SplineModel(times, wavelengths, fluxes)
assert str(model) == "SplineModel"
assert str(model) == "tdastro.sources.spline_model.SplineModel"

test_times = np.array([0.0, 1.0, 2.0, 3.0, 10.0])
test_waves = np.array([0.0, 100.0, 200.0, 1000.0])
Expand All @@ -19,7 +19,7 @@ def test_spline_model_flat() -> None:
np.testing.assert_array_almost_equal(values, expected)

model2 = SplineModel(times, wavelengths, fluxes, amplitude=5.0, node_identifier="test")
assert str(model2) == "test=SplineModel"
assert str(model2) == "test=tdastro.sources.spline_model.SplineModel"

values2 = model2.evaluate(test_times, test_waves)
assert values2.shape == (5, 4)
Expand Down
4 changes: 2 additions & 2 deletions tests/tdastro/sources/test_static_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def test_static_source() -> None:
assert model.ra is None
assert model.dec is None
assert model.distance is None
assert str(model) == "my_static_source=StaticSource"
assert str(model) == "my_static_source=tdastro.sources.static_source.StaticSource"

times = np.array([1, 2, 3, 4, 5, 10])
wavelengths = np.array([100.0, 200.0, 300.0])
Expand All @@ -49,7 +49,7 @@ def test_static_source_host() -> None:
assert model.ra == 1.0
assert model.dec == 2.0
assert model.distance == 3.0
assert str(model) == "StaticSource"
assert str(model) == "tdastro.sources.static_source.StaticSource"


def test_static_source_resample() -> None:
Expand Down
103 changes: 79 additions & 24 deletions tests/tdastro/test_base_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,20 @@ def _test_func(value1, value2):
return value1 + value2


class SingleModel(ParameterizedNode):
    """A minimal ParameterizedNode used in the tests; it holds one parameter.

    Attributes
    ----------
    value1 : `float`
        The first value.
    """

    def __init__(self, value1, **kwargs):
        """Create the node and register ``value1`` as a required parameter.

        Parameters
        ----------
        value1 : `float`
            The first value.
        **kwargs : `dict`, optional
            Any additional keyword arguments. These are forwarded both to
            the parent constructor and to ``add_parameter``.
        """
        super().__init__(**kwargs)
        self.add_parameter("value1", value1, required=True, **kwargs)


class PairModel(ParameterizedNode):
"""A test class for the ParameterizedNode.
Expand All @@ -43,17 +57,6 @@ class PairModel(ParameterizedNode):
"""

def __init__(self, value1, value2, **kwargs):
"""Create a ConstModel object.
Parameters
----------
value1 : `float`, `function`, `ParameterizedNode`, or `None`
The first value.
value2 : `float`, `function`, `ParameterizedNode`, or `None`
The second value.
**kwargs : `dict`, optional
Any additional keyword arguments.
"""
super().__init__(**kwargs)
self.add_parameter("value1", value1, required=True, **kwargs)
self.add_parameter("value2", value2, required=True, **kwargs)
Expand All @@ -79,7 +82,7 @@ def result(self, **kwargs):
return self.value1 + self.value2


def test_parameterized_node() -> None:
def test_parameterized_node():
"""Test that we can sample and create a PairModel object."""
# Simple addition
model1 = PairModel(value1=0.5, value2=0.5)
Expand Down Expand Up @@ -141,7 +144,7 @@ def test_parameterized_node() -> None:
assert model1.sample_iteration == model4.sample_iteration


def test_parameterized_node_overwrite() -> None:
def test_parameterized_node_overwrite():
"""Test that we can overwrite attributes in a PairModel."""
model1 = PairModel(value1=0.5, value2=0.5)
assert model1.value1 == 0.5
Expand All @@ -159,7 +162,7 @@ def test_parameterized_node_overwrite() -> None:
assert model1.value1 == 1.0


def test_parameterized_node_attributes() -> None:
def test_parameterized_node_attributes():
"""Test that we can extract the attributes of a graph of ParameterizedNode."""
model1 = PairModel(value1=0.5, value2=1.5, node_identifier="1")
settings = model1.get_all_parameter_values(False)
Expand All @@ -170,9 +173,9 @@ def test_parameterized_node_attributes() -> None:

settings = model1.get_all_parameter_values(True)
assert len(settings) == 3
assert settings["1=PairModel.value1"] == 0.5
assert settings["1=PairModel.value2"] == 1.5
assert settings["1=PairModel.value_sum"] == 2.0
assert settings["1=test_base_models.PairModel.value1"] == 0.5
assert settings["1=test_base_models.PairModel.value2"] == 1.5
assert settings["1=test_base_models.PairModel.value_sum"] == 2.0

# Use value1=model.value and value2=3.0
model2 = PairModel(value1=model1, value2=3.0, node_identifier="2")
Expand All @@ -184,15 +187,37 @@ def test_parameterized_node_attributes() -> None:

settings = model2.get_all_parameter_values(True)
assert len(settings) == 6
assert settings["1=PairModel.value1"] == 0.5
assert settings["1=PairModel.value2"] == 1.5
assert settings["1=PairModel.value_sum"] == 2.0
assert settings["2=PairModel.value1"] == 0.5
assert settings["2=PairModel.value2"] == 3.0
assert settings["2=PairModel.value_sum"] == 3.5
assert settings["1=test_base_models.PairModel.value1"] == 0.5
assert settings["1=test_base_models.PairModel.value2"] == 1.5
assert settings["1=test_base_models.PairModel.value_sum"] == 2.0
assert settings["2=test_base_models.PairModel.value1"] == 0.5
assert settings["2=test_base_models.PairModel.value2"] == 3.0
assert settings["2=test_base_models.PairModel.value_sum"] == 3.5


def test_parameterized_node_get_dependencies():
"""Test that we can extract the dependencies of a graph of ParameterizedNode."""
model1 = PairModel(value1=0.5, value2=1.5, node_identifier="1")
model2 = PairModel(value1=model1, value2=3.0, node_identifier="2")
model3 = PairModel(value1=model1, value2=model2.result, node_identifier="3")

def test_parameterized_node_modify() -> None:
dep1 = model1.get_dependencies()
assert len(dep1) == 1
assert model1 in dep1

dep2 = model2.get_dependencies()
assert len(dep2) == 2
assert model1 in dep2
assert model2 in dep2

dep3 = model3.get_dependencies()
assert len(dep3) == 3
assert model1 in dep3
assert model2 in dep3
assert model3 in dep3


def test_parameterized_node_modify():
"""Test that we can modify the parameters in a node."""
model = PairModel(value1=0.5, value2=0.5)
assert model.value1 == 0.5
Expand All @@ -212,6 +237,36 @@ def test_parameterized_node_modify() -> None:
model.set_parameter("brightness", 5.0)


def test_parameterized_node_seed():
    """Test that we can set a random seed for the entire graph."""
    # Left unspecified, each node draws its own random base seed.
    model_a = PairModel(value1=0.5, value2=0.5)
    model_b = PairModel(value1=0.5, value2=0.5)
    assert model_a._object_seed != model_b._object_seed

    # With a shared base seed, nodes with the same name (class + node
    # identifier) get the same object seed and all others get different ones.
    model_a = PairModel(value1=0.5, value2=0.5, graph_base_seed=10, node_identifier="A")
    model_b = PairModel(value1=0.5, value2=0.5, graph_base_seed=10, node_identifier="B")
    model_c = PairModel(value1=0.5, value2=0.5, graph_base_seed=10, node_identifier="A")
    # model_d must share the base seed too; otherwise the comparisons below
    # would pass because of the random seed rather than the class name.
    model_d = SingleModel(value1=0.5, graph_base_seed=10, node_identifier="A")

    # Same class and identifier -> same seed.
    assert model_a._object_seed == model_c._object_seed

    # Different identifier -> different seed.
    assert model_a._object_seed != model_b._object_seed
    assert model_b._object_seed != model_c._object_seed

    # Different class (same identifier or not) -> different seed.
    assert model_a._object_seed != model_d._object_seed
    assert model_b._object_seed != model_d._object_seed
    assert model_c._object_seed != model_d._object_seed


def test_function_node_basic():
"""Test that we can create and query a FunctionNode."""
my_func = FunctionNode(_test_func, value1=1.0, value2=2.0)
Expand Down

0 comments on commit df3e2a5

Please sign in to comment.