From bc74905e8f1537b8debb94fef8874353f5b01e8c Mon Sep 17 00:00:00 2001 From: Matt Bogosian Date: Fri, 24 Sep 2021 07:50:12 -0500 Subject: [PATCH] Restructure NumPy Random implementation Follow-up to 3c9b258 to restructure NumPy random implementation according to a more rigid class structure in line with [this advice](https://github.com/avrae/d20/issues/7#issuecomment-926022431). --- docs/notes.md | 2 +- dyce/h.py | 4 +- dyce/rng.py | 138 ++++++++++++++++++++++++------------ tests/test_h.py | 4 +- tests/test_rng.py | 173 +++++++++++++++++++++++++++++----------------- 5 files changed, 206 insertions(+), 115 deletions(-) diff --git a/docs/notes.md b/docs/notes.md index e01a480c..63faa6ac 100644 --- a/docs/notes.md +++ b/docs/notes.md @@ -20,7 +20,7 @@ * Introduces experimental generic [``walk``][dyce.r.walk] function and supporting visitor data structures. * Uses ``pygraphviz`` to automate class diagram generation. (See the note on special considerations for regenerating class diagrams in the [hacking quick start](contrib.md#hacking-quick-start).) -* Uses ``numpy`` for RNG, if present. +* Introduces experimental use of ``numpy`` for RNG, if present. * Migrates to using ``pyproject.toml`` and ``setup.cfg``. ## [0.4.0](https://github.com/posita/dyce/releases/tag/v0.4.0) diff --git a/dyce/h.py b/dyce/h.py index d0e81c3d..ca4656e7 100644 --- a/dyce/h.py +++ b/dyce/h.py @@ -57,9 +57,9 @@ overload, ) +from . import rng from .bt import beartype from .lifecycle import experimental -from .rng import RNG from .symmetries import comb, gcd from .types import ( CachingProtocolMeta, @@ -1720,7 +1720,7 @@ def roll(self) -> OutcomeT: Returns a (weighted) random outcome, sorted. """ return ( - RNG.choices( + rng.RNG.choices( population=tuple(self.outcomes()), weights=tuple(self.counts()), k=1, diff --git a/dyce/rng.py b/dyce/rng.py index 5116fda0..24e88a5b 100644 --- a/dyce/rng.py +++ b/dyce/rng.py @@ -8,8 +8,12 @@ from __future__ import annotations +from abc import ABC from random import Random -from typing import Any, NewType, Optional, Union +from sys import version_info +from typing import NewType, Sequence, Type, Union + +from .bt import beartype __all__ = ("RNG",) @@ -17,67 +21,109 @@ # ---- Types --------------------------------------------------------------------------- -_RandSeed = Union[int, float, str, bytes, bytearray] -_RandState = NewType("_RandState", Any) +_RandState = NewType("_RandState", object) +_RandSeed = Union[None, int, Sequence[int]] # ---- Data ---------------------------------------------------------------------------- -RNG: Random = Random() +RNG: Random + + +# ---- Classes ------------------------------------------------------------------------- try: - import numpy.random - from numpy.random import BitGenerator, Generator + from numpy.random import PCG64DXSM, BitGenerator, Generator, default_rng - class NumpyRandom(Random): + _BitGeneratorT = Type[BitGenerator] + + class NumPyRandomBase(Random, ABC): r""" - Defines a [``!#python + Base class for a [``#!python random.Random``](https://docs.python.org/3/library/random.html#random.Random) - implementation that accepts and uses a [``!#python + implementation that uses a [``#!python numpy.random.BitGenerator``](https://numpy.org/doc/stable/reference/random/bit_generators/index.html) under the covers. Motivated by [avrae/d20#7](https://github.com/avrae/d20/issues/7). + + The [initializer][rng.NumPyRandomBase.__init__] takes an optional *seed*, which is + passed to + [``NumPyRandomBase.bit_generator``][dyce.rng.NumPyRandomBase.bit_generator] via + [``NumPyRandomBase.seed``][dyce.rng.NumPyRandomBase.seed] during construction. """ - def __init__(self, bit_generator: BitGenerator): - self._g = Generator(bit_generator) + bit_generator: _BitGeneratorT + _generator: Generator + + if version_info < (3, 11): + + @beartype + def __new__(cls, seed: _RandSeed = None): + r""" + Because ``#!python random.Random`` is broken in versions <3.11, ``#!python + random.Random``’s vanilla implementation cannot accept non-hashable + values as the first argument. For example, it will reject lists of + ``#!python int``s as *seed*. This implementation of ``#!python __new__`` + fixes that. + """ + return super(NumPyRandomBase, cls).__new__(cls) + + @beartype + def __init__(self, seed: _RandSeed = None): + # Parent calls self.seed(seed) + super().__init__(seed) + + # ---- Overrides --------------------------------------------------------------- + @beartype + def getrandbits(self, k: int) -> int: + # Adapted from the implementation for random.SystemRandom.getrandbits + if k < 0: + raise ValueError("number of bits must be non-negative") + + numbytes = (k + 7) // 8 # bits / 8 and rounded up + x = int.from_bytes(self.randbytes(numbytes), "big") + + return x >> (numbytes * 8 - k) # trim excess bits + + @beartype + # TODO(posita): See + def getstate(self) -> _RandState: # type: ignore + return _RandState(self._generator.bit_generator.state) + + @beartype + def randbytes(self, n: int) -> bytes: + return self._generator.bytes(n) + + @beartype def random(self) -> float: - return self._g.random() - - def seed(self, a: Optional[_RandSeed], version: int = 2) -> None: - if a is not None and not isinstance(a, (int, float, str, bytes, bytearray)): - raise ValueError(f"unrecognized seed type ({type(a)})") - - bg_type = type(self._g.bit_generator) - - if a is None: - self._g = Generator(bg_type()) - else: - # This is somewhat fragile and may not be the best approach. It uses - # `random.Random` to generate its own state from the seed in order to - # maintain compatibility with accepted seed types. (NumPy only accepts - # ints whereas the standard library accepts ints, floats, bytes, etc.). - # That state consists of a 3-tuple: (version: int, internal_state: - # tuple[int], gauss_next: float) at least for for versions through 3 (as - # of this writing). We feed internal_state as the seed for the NumPy - # BitGenerator. - version, internal_state, _ = Random(a).getstate() - self._g = Generator(bg_type(internal_state)) - - def getstate(self) -> _RandState: - return _RandState(self._g.bit_generator.state) - - def setstate(self, state: _RandState) -> None: - self._g.bit_generator.state = state - - if hasattr(numpy.random, "PCG64DXSM"): - RNG = NumpyRandom(numpy.random.PCG64DXSM()) - elif hasattr(numpy.random, "PCG64"): - RNG = NumpyRandom(numpy.random.PCG64()) - elif hasattr(numpy.random, "default_rng"): - RNG = NumpyRandom(numpy.random.default_rng().bit_generator) + return self._generator.random() + + @beartype + def seed( # type: ignore + self, + a: _RandSeed, + version: int = 2, + ) -> None: + self._generator = default_rng(self.bit_generator(a)) + + @beartype + def setstate( # type: ignore + self, + # TODO(posita): See + state: _RandState, + ) -> None: + self._generator.bit_generator.state = state + + class PCG64DXSMRandom(NumPyRandomBase): + r""" + A [``NumPyRandomBase``][dyce.rng.NumPyRandomBase] based on + [``numpy.random.PCG64DXSM``](https://numpy.org/doc/stable/reference/random/bit_generators/pcg64dxsm.html#numpy.random.PCG64DXSM). + """ + bit_generator = PCG64DXSM + + RNG = PCG64DXSMRandom() except ImportError: - pass + RNG = Random() diff --git a/tests/test_h.py b/tests/test_h.py index 297f5212..920e4034 100644 --- a/tests/test_h.py +++ b/tests/test_h.py @@ -171,13 +171,13 @@ def test_op_sub_h(self) -> None: assert d2 - d3 == { o_type(-2): 1, o_type(-1): 2, - # See + # TODO(posita): See o_type(0) + o_type(0): 2, o_type(1): 1, }, f"o_type: {o_type}; c_type: {c_type}" assert d3 - d2 == { o_type(-1): 1, - # See + # TODO(posita): See o_type(0) + o_type(0): 2, o_type(1): 2, o_type(2): 1, diff --git a/tests/test_rng.py b/tests/test_rng.py index 86a42c9c..75095e9b 100644 --- a/tests/test_rng.py +++ b/tests/test_rng.py @@ -8,14 +8,12 @@ from __future__ import annotations -from decimal import Decimal from random import Random from typing import Optional import pytest -import dyce.rng -from dyce.rng import _RandSeed +from dyce.rng import RNG, _RandSeed __all__ = () @@ -23,78 +21,120 @@ # ---- Data ---------------------------------------------------------------------------- -SEED_INT_128 = 0x6265656663616665 -SEED_FLOAT = float( - Decimal( - "9856940084378475016744131457734599215371411366662962480265638551381775059468656635085733393811201634227995293393551923733235754825282073085472925752147516616452603904" - ), -) -SEED_BYTES_128 = b"beefcafe"[::-1] -SEED_INT_192 = 0x646561646265656663616665 -SEED_BYTES_192 = b"deadbeefcafe"[::-1] +SEED_INT_64: int = 0x64656164 +SEED_INT_128: int = 0x6465616462656566 +SEED_INT_192: int = 0x646561646265656663616665 +SEED_INTS: _RandSeed = (SEED_INT_64, SEED_INT_128, SEED_INT_192) # ---- Tests --------------------------------------------------------------------------- -def test_numpy_rng() -> None: - pytest.importorskip("numpy.random", reason="requires numpy") - assert hasattr(dyce.rng, "NumpyRandom") - assert isinstance(dyce.rng.RNG, dyce.rng.NumpyRandom) - - -def test_numpy_rng_pcg64dxsm() -> None: - numpy_random = pytest.importorskip("numpy.random", reason="requires numpy") - - if not hasattr(numpy_random, "PCG64DXSM"): - pytest.skip("requires numpy.random.PCG64DXSM") - - rng = dyce.rng.NumpyRandom(numpy_random.PCG64DXSM()) - _test_w_seed_helper(rng, SEED_INT_128, 0.7903327469601987) - _test_w_seed_helper(rng, SEED_FLOAT, 0.6018795857570297) - _test_w_seed_helper(rng, SEED_BYTES_128, 0.5339952033746491) - _test_w_seed_helper(rng, SEED_INT_192, 0.9912715409588355) - _test_w_seed_helper(rng, SEED_BYTES_192, 0.13818265573158406) - - with pytest.raises(ValueError): - _test_w_seed_helper(rng, object()) # type: ignore - - -def test_numpy_rng_pcg64() -> None: - numpy_random = pytest.importorskip("numpy.random", reason="requires numpy") +def test_numpy_rng_installed() -> None: + try: + from dyce.rng import PCG64DXSMRandom + except ImportError: + pytest.skip("requires numpy") - if not hasattr(numpy_random, "PCG64"): - pytest.skip("requires numpy.random.PCG64") + assert isinstance(RNG, PCG64DXSMRandom) - rng = dyce.rng.NumpyRandom(numpy_random.PCG64()) - _test_w_seed_helper(rng, SEED_INT_128, 0.9794491381144006) - _test_w_seed_helper(rng, SEED_FLOAT, 0.8347478482621317) - _test_w_seed_helper(rng, SEED_BYTES_128, 0.7800090883745199) - _test_w_seed_helper(rng, SEED_INT_192, 0.28018439479392754) - _test_w_seed_helper(rng, SEED_BYTES_192, 0.4814859325412144) - with pytest.raises(ValueError): - _test_w_seed_helper(rng, object()) # type: ignore - - -def test_numpy_rng_default() -> None: - numpy_random = pytest.importorskip("numpy.random", reason="requires numpy") - - if not hasattr(numpy_random, "default_rng"): - pytest.skip("requires numpy.random.default_rng") +def test_numpy_rng() -> None: + try: + from dyce.rng import PCG64DXSMRandom + except ImportError: + pytest.skip("requires numpy") + + rng = PCG64DXSMRandom() + seed: _RandSeed + random: float + getrandbits: int + randbytes: bytes + + for seed, random, getrandbits, randbytes in ( + ( + SEED_INT_64, + 0.5066807340643421, + 0x6CCCD2511ED4B58, + bytes.fromhex("6cccd2511ed4b581"), + ), + ( + SEED_INT_128, + 0.16159916444553268, + 0x32CDBF5A16905E2, + bytes.fromhex("32cdbf5a16905e29"), + ), + ( + SEED_INT_192, + 0.09272816060986888, + 0xE0D0D43C6108BD1, + bytes.fromhex("e0d0d43c6108bd17"), + ), + ( + SEED_INTS, + 0.32331170065667836, + 0x6F230DBC3C8EC45, + bytes.fromhex("6f230dbc3c8ec452"), + ), + ): + _test_random_w_seed_helper(rng, seed, random) + _test_getrandbits_w_seed_helper(rng, seed, 60, getrandbits) + _test_randbytes_w_seed_helper(rng, seed, randbytes) + + +def test_standard_rng_installed() -> None: + try: + from dyce.rng import PCG64DXSMRandom # noqa: F401 + + pytest.skip("requires numpy not be installed") + except ImportError: + pass + + assert isinstance(RNG, Random) + + +def test_standard_rng() -> None: + rng = Random() + + for seed in ( + SEED_INT_64, + SEED_INT_128, + SEED_INT_192, + SEED_INTS, + ): + _test_random_w_seed_helper(rng, seed) + + +def _test_getrandbits_w_seed_helper( + rng: Random, + seed: _RandSeed, + bits: int, + expected: int, +) -> None: + rng.seed(seed) + state = rng.getstate() + val = rng.getrandbits(bits) + assert val == expected + rng.setstate(state) + assert rng.getrandbits(bits) == val - rng = dyce.rng.NumpyRandom(numpy_random.default_rng().bit_generator) - _test_w_seed_helper(rng, SEED_INT_128) - _test_w_seed_helper(rng, SEED_FLOAT) - _test_w_seed_helper(rng, SEED_BYTES_128) - _test_w_seed_helper(rng, SEED_INT_192) - _test_w_seed_helper(rng, SEED_BYTES_192) - with pytest.raises(ValueError): - _test_w_seed_helper(rng, object()) # type: ignore +def _test_randbytes_w_seed_helper( + rng: Random, + seed: _RandSeed, + expected: bytes, +) -> None: + rng.seed(seed) + state = rng.getstate() + val = rng.randbytes(len(expected)) + assert val == expected + rng.setstate(state) + assert rng.randbytes(len(expected)) == val + rng.setstate(state) + assert rng.randbytes(len(expected)) == val -def _test_w_seed_helper( +def _test_random_w_seed_helper( rng: Random, seed: _RandSeed, expected: Optional[float] = None, @@ -105,7 +145,12 @@ def _test_w_seed_helper( assert val >= 0.0 and val < 1.0 if expected is not None: - assert expected == val + assert val == expected + + assert type(rng)(seed).random() == val rng.setstate(state) assert rng.random() == val + + rng.seed(seed) + assert rng.random() == val