Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
dlwh committed Jan 29, 2025
1 parent 93736ac commit e3754a7
Showing 1 changed file with 6 additions and 7 deletions.
13 changes: 6 additions & 7 deletions src/levanter/data/_prp.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
import typing

import jax.numpy as jnp
import jax.random as jrandom # Import jax.random
import jax.random as jrandom
import numpy as np


# TODO: do we make this a pytree
class Permutation:
# Pseudo-Random Permutation Code
"""A stateless pseudo-random permutation.
Expand All @@ -21,8 +19,10 @@ class Permutation:
def __init__(self, length, prng_key):
self.length = length
# Convert jax.random.PRNGKey to numpy.random.Generator
self.rng = np.random.Generator(np.random.PCG64(jrandom.randint(prng_key, (), 0, 2**32).item())) # Use jrandom.randint
self.a, self.b = self._generate_permutation_params() # Generate a and b in init
self.rng = np.random.Generator(
np.random.PCG64(jrandom.randint(prng_key, (), 0, 2**32).item())
) # Use jrandom.randint
self.a, self.b = self._generate_permutation_params() # Generate a and b in init

def _generate_permutation_params(self):
length = self.length
Expand All @@ -36,7 +36,6 @@ def _generate_permutation_params(self):
b = rng.integers(0, length) # b can be in [0, length-1]
return a, b


@typing.overload
def __call__(self, indices: int) -> int:
...
Expand All @@ -61,7 +60,7 @@ def __call__(self, indices):
indices = np.array(indices)
was_int = True

out = (a * indices + b) % length # Compute permutation on-the-fly
out = (a * indices + b) % length # Compute permutation on-the-fly

if was_int:
return int(out)
Expand Down

0 comments on commit e3754a7

Please sign in to comment.