-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Use
numpy.random
as implementation inside of random.Random
Augments f278e2a to use the standard library's `random.Random` as a perfectly good interface to random values (plus access to all the derivatives that come with it), but provide a (mostly) conforming implementation that uses `numpy.random` for its implementation (if available).
- Loading branch information
Showing
3 changed files
with
204 additions
and
43 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,83 @@ | ||
# ====================================================================================== | ||
# Copyright and other protections apply. Please see the accompanying LICENSE file for | ||
# rights and restrictions governing use of this software. All rights not expressly | ||
# waived or licensed are reserved. If that file is missing or appears to be modified | ||
# from its original, then please contact the author before viewing or using this | ||
# software in any capacity. | ||
# ====================================================================================== | ||
|
||
from __future__ import annotations | ||
|
||
from random import Random | ||
from typing import Any, NewType, Optional, Union | ||
|
||
__all__ = ("RNG",) | ||
|
||
|
||
# ---- Types --------------------------------------------------------------------------- | ||
|
||
|
||
_RandSeed = Union[int, float, str, bytes, bytearray] | ||
_RandState = NewType("_RandState", Any) | ||
|
||
|
||
# ---- Data ---------------------------------------------------------------------------- | ||
|
||
|
||
RNG: Random = Random() | ||
|
||
|
||
try: | ||
import numpy.random | ||
from numpy.random import BitGenerator, Generator | ||
|
||
class NumpyRandom(Random): | ||
r""" | ||
Defines a [``!#python | ||
random.Random``](https://docs.python.org/3/library/random.html#random.Random) | ||
implementation that accepts and uses a [``!#python | ||
numpy.random.BitGenerator``](https://numpy.org/doc/stable/reference/random/bit_generators/index.html) | ||
under the covers. Motivated by | ||
[avrae/d20#7](https://github.com/avrae/d20/issues/7). | ||
""" | ||
|
||
def __init__(self, bit_generator: BitGenerator): | ||
self._g = Generator(bit_generator) | ||
|
||
def random(self) -> float: | ||
return self._g.random() | ||
|
||
def seed(self, a: Optional[_RandSeed], version: int = 2) -> None: | ||
if a is not None and not isinstance(a, (int, float, str, bytes, bytearray)): | ||
raise ValueError(f"unrecognized seed type ({type(a)})") | ||
|
||
bg_type = type(self._g.bit_generator) | ||
|
||
if a is None: | ||
self._g = Generator(bg_type()) | ||
else: | ||
# This is somewhat fragile and may not be the best approach. It uses | ||
# `random.Random` to generate its own state from the seed in order to | ||
# maintain compatibility with accepted seed types. (NumPy only accepts | ||
# ints whereas the standard library accepts ints, floats, bytes, etc.). | ||
# That state consists of a 3-tuple: (version: int, internal_state: | ||
# tuple[int], gauss_next: float) at least for for versions through 3 (as | ||
# of this writing). We feed internal_state as the seed for the NumPy | ||
# BitGenerator. | ||
version, internal_state, _ = Random(a).getstate() | ||
self._g = Generator(bg_type(internal_state)) | ||
|
||
def getstate(self) -> _RandState: | ||
return _RandState(self._g.bit_generator.state) | ||
|
||
def setstate(self, state: _RandState) -> None: | ||
self._g.bit_generator.state = state | ||
|
||
if hasattr(numpy.random, "PCG64DXSM"): | ||
RNG = NumpyRandom(numpy.random.PCG64DXSM()) | ||
elif hasattr(numpy.random, "PCG64"): | ||
RNG = NumpyRandom(numpy.random.PCG64()) | ||
elif hasattr(numpy.random, "default_rng"): | ||
RNG = NumpyRandom(numpy.random.default_rng().bit_generator) | ||
except ImportError: | ||
pass |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,111 @@ | ||
# ====================================================================================== | ||
# Copyright and other protections apply. Please see the accompanying LICENSE file for | ||
# rights and restrictions governing use of this software. All rights not expressly | ||
# waived or licensed are reserved. If that file is missing or appears to be modified | ||
# from its original, then please contact the author before viewing or using this | ||
# software in any capacity. | ||
# ====================================================================================== | ||
|
||
from __future__ import annotations | ||
|
||
from decimal import Decimal | ||
from random import Random | ||
from typing import Optional | ||
|
||
import pytest | ||
|
||
import dyce.rng | ||
from dyce.rng import _RandSeed | ||
|
||
__all__ = () | ||
|
||
|
||
# ---- Data ---------------------------------------------------------------------------- | ||
|
||
|
||
SEED_INT_128 = 0x6265656663616665 | ||
SEED_FLOAT = float( | ||
Decimal( | ||
"9856940084378475016744131457734599215371411366662962480265638551381775059468656635085733393811201634227995293393551923733235754825282073085472925752147516616452603904" | ||
), | ||
) | ||
SEED_BYTES_128 = b"beefcafe"[::-1] | ||
SEED_INT_192 = 0x646561646265656663616665 | ||
SEED_BYTES_192 = b"deadbeefcafe"[::-1] | ||
|
||
|
||
# ---- Tests --------------------------------------------------------------------------- | ||
|
||
|
||
def test_numpy_rng() -> None: | ||
pytest.importorskip("numpy.random", reason="requires numpy") | ||
assert hasattr(dyce.rng, "NumpyRandom") | ||
assert isinstance(dyce.rng.RNG, dyce.rng.NumpyRandom) | ||
|
||
|
||
def test_numpy_rng_pcg64dxsm() -> None: | ||
numpy_random = pytest.importorskip("numpy.random", reason="requires numpy") | ||
|
||
if not hasattr(numpy_random, "PCG64DXSM"): | ||
pytest.skip("requires numpy.random.PCG64DXSM") | ||
|
||
rng = dyce.rng.NumpyRandom(numpy_random.PCG64DXSM()) | ||
_test_w_seed_helper(rng, SEED_INT_128, 0.7903327469601987) | ||
_test_w_seed_helper(rng, SEED_FLOAT, 0.6018795857570297) | ||
_test_w_seed_helper(rng, SEED_BYTES_128, 0.5339952033746491) | ||
_test_w_seed_helper(rng, SEED_INT_192, 0.9912715409588355) | ||
_test_w_seed_helper(rng, SEED_BYTES_192, 0.13818265573158406) | ||
|
||
with pytest.raises(ValueError): | ||
_test_w_seed_helper(rng, object()) # type: ignore | ||
|
||
|
||
def test_numpy_rng_pcg64() -> None: | ||
numpy_random = pytest.importorskip("numpy.random", reason="requires numpy") | ||
|
||
if not hasattr(numpy_random, "PCG64"): | ||
pytest.skip("requires numpy.random.PCG64") | ||
|
||
rng = dyce.rng.NumpyRandom(numpy_random.PCG64()) | ||
_test_w_seed_helper(rng, SEED_INT_128, 0.9794491381144006) | ||
_test_w_seed_helper(rng, SEED_FLOAT, 0.8347478482621317) | ||
_test_w_seed_helper(rng, SEED_BYTES_128, 0.7800090883745199) | ||
_test_w_seed_helper(rng, SEED_INT_192, 0.28018439479392754) | ||
_test_w_seed_helper(rng, SEED_BYTES_192, 0.4814859325412144) | ||
|
||
with pytest.raises(ValueError): | ||
_test_w_seed_helper(rng, object()) # type: ignore | ||
|
||
|
||
def test_numpy_rng_default() -> None: | ||
numpy_random = pytest.importorskip("numpy.random", reason="requires numpy") | ||
|
||
if not hasattr(numpy_random, "default_rng"): | ||
pytest.skip("requires numpy.random.default_rng") | ||
|
||
rng = dyce.rng.NumpyRandom(numpy_random.default_rng().bit_generator) | ||
_test_w_seed_helper(rng, SEED_INT_128) | ||
_test_w_seed_helper(rng, SEED_FLOAT) | ||
_test_w_seed_helper(rng, SEED_BYTES_128) | ||
_test_w_seed_helper(rng, SEED_INT_192) | ||
_test_w_seed_helper(rng, SEED_BYTES_192) | ||
|
||
with pytest.raises(ValueError): | ||
_test_w_seed_helper(rng, object()) # type: ignore | ||
|
||
|
||
def _test_w_seed_helper( | ||
rng: Random, | ||
seed: _RandSeed, | ||
expected: Optional[float] = None, | ||
) -> None: | ||
rng.seed(seed) | ||
state = rng.getstate() | ||
val = rng.random() | ||
assert val >= 0.0 and val < 1.0 | ||
|
||
if expected is not None: | ||
assert expected == val | ||
|
||
rng.setstate(state) | ||
assert rng.random() == val |