Skip to content

Commit

Permalink
Use numpy.random as implementation inside of random.Random
Browse files Browse the repository at this point in the history
Augments f278e2a to use the standard library's `random.Random` as a perfectly good
interface to random values (plus access to all the derivatives that come with it), but
provide a (mostly) conforming implementation that uses `numpy.random` for its
implementation (if available).
  • Loading branch information
posita committed Sep 22, 2021
1 parent 72469e6 commit 3c9b258
Show file tree
Hide file tree
Showing 3 changed files with 204 additions and 43 deletions.
53 changes: 10 additions & 43 deletions dyce/h.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@

from .bt import beartype
from .lifecycle import experimental
from .rng import RNG
from .symmetries import comb, gcd
from .types import (
CachingProtocolMeta,
Expand Down Expand Up @@ -94,7 +95,6 @@
_BinaryOperatorT = Callable[[_T_co, _T_co], _T_co]
_ExpandT = Callable[["H", OutcomeT], Union[OutcomeT, "H"]]
_CoalesceT = Callable[["H", OutcomeT], "H"]
_ChoiceT = Callable[["H"], OutcomeT]


# ---- Data ----------------------------------------------------------------------------
Expand All @@ -120,47 +120,6 @@ def coalesce_replace(h: H, outcome: OutcomeT) -> H:
return h


_choice_impl: _ChoiceT


try:
from numpy.random import Generator

try:
from numpy.random import PCG64DXSM as _BitGenImpl
except ImportError:
from numpy.random import PCG64 as _BitGenImpl

_NUMPY_RNG = Generator(_BitGenImpl())

def _numpy_choice_impl(h: H) -> OutcomeT:
return (
_NUMPY_RNG.choice(
tuple(h.outcomes()),
p=tuple(Fraction(count, h.total) for count in h.counts()),
)
if h
else 0
)

_choice_impl = _numpy_choice_impl
except ImportError:
from random import choices

def _default_choice_impl(h: H) -> OutcomeT:
return (
choices(
population=tuple(h.outcomes()),
weights=tuple(h.counts()),
k=1,
)[0]
if h
else 0
)

_choice_impl = _default_choice_impl


# ---- Classes -------------------------------------------------------------------------


Expand Down Expand Up @@ -1760,7 +1719,15 @@ def roll(self) -> OutcomeT:
r"""
Returns a (weighted) random outcome, sorted.
"""
return _choice_impl(self)
return (
RNG.choices(
population=tuple(self.outcomes()),
weights=tuple(self.counts()),
k=1,
)[0]
if self
else 0
)

def _lowest_terms(self) -> Iterable[Tuple[OutcomeT, int]]:
counts_gcd = gcd(*self.counts())
Expand Down
83 changes: 83 additions & 0 deletions dyce/rng.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
# ======================================================================================
# Copyright and other protections apply. Please see the accompanying LICENSE file for
# rights and restrictions governing use of this software. All rights not expressly
# waived or licensed are reserved. If that file is missing or appears to be modified
# from its original, then please contact the author before viewing or using this
# software in any capacity.
# ======================================================================================

from __future__ import annotations

from random import Random
from typing import Any, NewType, Optional, Union

__all__ = ("RNG",)


# ---- Types ---------------------------------------------------------------------------


_RandSeed = Union[int, float, str, bytes, bytearray]
_RandState = NewType("_RandState", Any)


# ---- Data ----------------------------------------------------------------------------


RNG: Random = Random()


try:
import numpy.random
from numpy.random import BitGenerator, Generator

class NumpyRandom(Random):
r"""
Defines a [``!#python
random.Random``](https://docs.python.org/3/library/random.html#random.Random)
implementation that accepts and uses a [``!#python
numpy.random.BitGenerator``](https://numpy.org/doc/stable/reference/random/bit_generators/index.html)
under the covers. Motivated by
[avrae/d20#7](https://github.com/avrae/d20/issues/7).
"""

def __init__(self, bit_generator: BitGenerator):
self._g = Generator(bit_generator)

def random(self) -> float:
return self._g.random()

def seed(self, a: Optional[_RandSeed], version: int = 2) -> None:
if a is not None and not isinstance(a, (int, float, str, bytes, bytearray)):
raise ValueError(f"unrecognized seed type ({type(a)})")

bg_type = type(self._g.bit_generator)

if a is None:
self._g = Generator(bg_type())
else:
# This is somewhat fragile and may not be the best approach. It uses
# `random.Random` to generate its own state from the seed in order to
# maintain compatibility with accepted seed types. (NumPy only accepts
# ints whereas the standard library accepts ints, floats, bytes, etc.).
# That state consists of a 3-tuple: (version: int, internal_state:
# tuple[int], gauss_next: float) at least for for versions through 3 (as
# of this writing). We feed internal_state as the seed for the NumPy
# BitGenerator.
version, internal_state, _ = Random(a).getstate()
self._g = Generator(bg_type(internal_state))

def getstate(self) -> _RandState:
return _RandState(self._g.bit_generator.state)

def setstate(self, state: _RandState) -> None:
self._g.bit_generator.state = state

if hasattr(numpy.random, "PCG64DXSM"):
RNG = NumpyRandom(numpy.random.PCG64DXSM())
elif hasattr(numpy.random, "PCG64"):
RNG = NumpyRandom(numpy.random.PCG64())
elif hasattr(numpy.random, "default_rng"):
RNG = NumpyRandom(numpy.random.default_rng().bit_generator)
except ImportError:
pass
111 changes: 111 additions & 0 deletions tests/test_rng.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
# ======================================================================================
# Copyright and other protections apply. Please see the accompanying LICENSE file for
# rights and restrictions governing use of this software. All rights not expressly
# waived or licensed are reserved. If that file is missing or appears to be modified
# from its original, then please contact the author before viewing or using this
# software in any capacity.
# ======================================================================================

from __future__ import annotations

from decimal import Decimal
from random import Random
from typing import Optional

import pytest

import dyce.rng
from dyce.rng import _RandSeed

__all__ = ()


# ---- Data ----------------------------------------------------------------------------


SEED_INT_128 = 0x6265656663616665
SEED_FLOAT = float(
Decimal(
"9856940084378475016744131457734599215371411366662962480265638551381775059468656635085733393811201634227995293393551923733235754825282073085472925752147516616452603904"
),
)
SEED_BYTES_128 = b"beefcafe"[::-1]
SEED_INT_192 = 0x646561646265656663616665
SEED_BYTES_192 = b"deadbeefcafe"[::-1]


# ---- Tests ---------------------------------------------------------------------------


def test_numpy_rng() -> None:
pytest.importorskip("numpy.random", reason="requires numpy")
assert hasattr(dyce.rng, "NumpyRandom")
assert isinstance(dyce.rng.RNG, dyce.rng.NumpyRandom)


def test_numpy_rng_pcg64dxsm() -> None:
numpy_random = pytest.importorskip("numpy.random", reason="requires numpy")

if not hasattr(numpy_random, "PCG64DXSM"):
pytest.skip("requires numpy.random.PCG64DXSM")

rng = dyce.rng.NumpyRandom(numpy_random.PCG64DXSM())
_test_w_seed_helper(rng, SEED_INT_128, 0.7903327469601987)
_test_w_seed_helper(rng, SEED_FLOAT, 0.6018795857570297)
_test_w_seed_helper(rng, SEED_BYTES_128, 0.5339952033746491)
_test_w_seed_helper(rng, SEED_INT_192, 0.9912715409588355)
_test_w_seed_helper(rng, SEED_BYTES_192, 0.13818265573158406)

with pytest.raises(ValueError):
_test_w_seed_helper(rng, object()) # type: ignore


def test_numpy_rng_pcg64() -> None:
numpy_random = pytest.importorskip("numpy.random", reason="requires numpy")

if not hasattr(numpy_random, "PCG64"):
pytest.skip("requires numpy.random.PCG64")

rng = dyce.rng.NumpyRandom(numpy_random.PCG64())
_test_w_seed_helper(rng, SEED_INT_128, 0.9794491381144006)
_test_w_seed_helper(rng, SEED_FLOAT, 0.8347478482621317)
_test_w_seed_helper(rng, SEED_BYTES_128, 0.7800090883745199)
_test_w_seed_helper(rng, SEED_INT_192, 0.28018439479392754)
_test_w_seed_helper(rng, SEED_BYTES_192, 0.4814859325412144)

with pytest.raises(ValueError):
_test_w_seed_helper(rng, object()) # type: ignore


def test_numpy_rng_default() -> None:
numpy_random = pytest.importorskip("numpy.random", reason="requires numpy")

if not hasattr(numpy_random, "default_rng"):
pytest.skip("requires numpy.random.default_rng")

rng = dyce.rng.NumpyRandom(numpy_random.default_rng().bit_generator)
_test_w_seed_helper(rng, SEED_INT_128)
_test_w_seed_helper(rng, SEED_FLOAT)
_test_w_seed_helper(rng, SEED_BYTES_128)
_test_w_seed_helper(rng, SEED_INT_192)
_test_w_seed_helper(rng, SEED_BYTES_192)

with pytest.raises(ValueError):
_test_w_seed_helper(rng, object()) # type: ignore


def _test_w_seed_helper(
rng: Random,
seed: _RandSeed,
expected: Optional[float] = None,
) -> None:
rng.seed(seed)
state = rng.getstate()
val = rng.random()
assert val >= 0.0 and val < 1.0

if expected is not None:
assert expected == val

rng.setstate(state)
assert rng.random() == val

0 comments on commit 3c9b258

Please sign in to comment.