diff --git a/src/levanter/data/_prp.py b/src/levanter/data/_prp.py index af6d82817..59170fbd3 100644 --- a/src/levanter/data/_prp.py +++ b/src/levanter/data/_prp.py @@ -1,11 +1,9 @@ import typing -import jax.numpy as jnp -import jax.random as jrandom # Import jax.random +import jax.random as jrandom import numpy as np -# TODO: do we make this a pytree class Permutation: # Pseudo-Random Permutation Code """A stateless pseudo-random permutation. @@ -21,8 +19,10 @@ class Permutation: def __init__(self, length, prng_key): self.length = length # Convert jax.random.PRNGKey to numpy.random.Generator - self.rng = np.random.Generator(np.random.PCG64(jrandom.randint(prng_key, (), 0, 2**32).item())) # Use jrandom.randint - self.a, self.b = self._generate_permutation_params() # Generate a and b in init + self.rng = np.random.Generator( + np.random.PCG64(jrandom.randint(prng_key, (), 0, 2**32).item()) + ) # Use jrandom.randint + self.a, self.b = self._generate_permutation_params() # Generate a and b in init def _generate_permutation_params(self): length = self.length @@ -36,7 +36,6 @@ def _generate_permutation_params(self): b = rng.integers(0, length) # b can be in [0, length-1] return a, b - @typing.overload def __call__(self, indices: int) -> int: ... @@ -61,7 +60,7 @@ def __call__(self, indices): indices = np.array(indices) was_int = True - out = (a * indices + b) % length # Compute permutation on-the-fly + out = (a * indices + b) % length # Compute permutation on-the-fly if was_int: return int(out)