Skip to content

Commit

Permalink
Add the option to specify a seed for the Random sampling method
Browse files Browse the repository at this point in the history
  • Loading branch information
Enrico Stragiotti committed Dec 13, 2024
1 parent 5722846 commit fc7324a
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 1 deletion.
1 change: 1 addition & 0 deletions AUTHORS.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ SMT has been developed thanks to contributions from:
* Alexandre Thouvenot
* Andres Lopez Lopera
* Antoine Averland
* Enrico Stragiotti
* Emile Roux
* Ewout ter Hoeven
* Florent Vergnes
Expand Down
5 changes: 5 additions & 0 deletions doc/_src_docs/sampling_methods/random.rst

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

23 changes: 22 additions & 1 deletion smt/sampling_methods/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,25 @@


class Random(ScaledSamplingMethod):
def _initialize(self, **kwargs):
self.options.declare(
"random_state",
types=(type(None), int, np.random.RandomState),
desc="Numpy RandomState object or seed number which controls random draws",
)

# Update options values passed by the user here to get 'random_state' option
self.options.update(kwargs)

# RandomState is and has to be initialized once at constructor time,
# not in _compute to avoid yielding the same dataset again and again
if isinstance(self.options["random_state"], np.random.RandomState):
self.random_state = self.options["random_state"]
elif isinstance(self.options["random_state"], int):
self.random_state = np.random.RandomState(self.options["random_state"])
else:
self.random_state = np.random.RandomState()

def _compute(self, nt):
"""
Implemented by sampling methods to compute the requested number of sampling points.
Expand All @@ -30,4 +49,6 @@ def _compute(self, nt):
"""
xlimits = self.options["xlimits"]
nx = xlimits.shape[0]
return np.random.rand(nt, nx)
# Create a Generator object with a specified seed (np.random.rand(nt, nx) is being deprecated)
rng = np.random.default_rng(self.random_state)
return rng.random((nt, nx))

0 comments on commit fc7324a

Please sign in to comment.