Skip to content

Commit

Permalink
optim(rng): Use binarized formats for de/serialization of rng (#92)
Browse files Browse the repository at this point in the history
  • Loading branch information
eddiebergman authored May 4, 2024
2 parents b3e76a8 + f19804e commit 5bdeced
Show file tree
Hide file tree
Showing 8 changed files with 230 additions and 43 deletions.
9 changes: 4 additions & 5 deletions neps/optimizers/base_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from neps.types import ConfigResult
from neps.utils.files import serialize, deserialize
from ..search_spaces.search_space import SearchSpace
from ..utils.common import get_rnd_state, set_rnd_state
from neps.utils._rng import SeedState
from neps.utils.data_loading import _get_cost, _get_learning_curve, _get_loss


Expand Down Expand Up @@ -63,7 +63,7 @@ def get_config_and_ids(self) -> tuple[SearchSpace, str, str | None]:
raise NotImplementedError

def get_state(self) -> Any:
_state = {"rnd_seeds": get_rnd_state(), "used_budget": self.used_budget}
_state = {"used_budget": self.used_budget}
if self.budget is not None:
# TODO(eddiebergman): Seems like this isn't used anywhere,
# A fuzzy find search for `remaining_budget` shows this as the
Expand All @@ -73,7 +73,6 @@ def get_state(self) -> Any:
return _state

def load_state(self, state: Any) -> None:
set_rnd_state(state["rnd_seeds"])
self.used_budget = state["used_budget"]

def load_config(self, config_dict: Mapping[str, Any]) -> SearchSpace:
Expand Down Expand Up @@ -114,8 +113,8 @@ def whoami(self) -> str:
@contextmanager
def using_state(self, state_file: Path) -> Iterator[Self]:
if state_file.exists():
state = deserialize(state_file)
self.load_state(state)
optimizer_state = deserialize(state_file)
self.load_state(optimizer_state)

yield self

Expand Down
10 changes: 8 additions & 2 deletions neps/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,13 +59,14 @@

from neps.types import ERROR, POST_EVAL_HOOK_SIGNATURE, ConfigLike, ConfigResult
from neps.utils._locker import Locker
from neps.utils._rng import SeedState
from neps.utils.files import deserialize, empty_file, serialize

if TYPE_CHECKING:
from .optimizers.base_optimizer import BaseOptimizer

# Wait time between each successive poll to see if state can be grabbed
DEFAULT_STATE_POLL: float = 1.0
DEFAULT_STATE_POLL: float = 0.1
ENVIRON_STATE_POLL_KEY = "NEPS_STATE_POLL"

# Timeout before giving up on trying to grab the state, raising an error
Expand Down Expand Up @@ -299,6 +300,7 @@ class SharedState:
optimizer_state_file: The path to the optimizers state.
optimizer_info_file: The path to the file containing information about the
optimizer's setup.
seed_state_dir: Directory where the seed state is stored.
results_dir: Directory where results for configurations are stored.
"""

Expand All @@ -308,6 +310,7 @@ class SharedState:
lock: Locker = field(init=False)
optimizer_state_file: Path = field(init=False)
optimizer_info_file: Path = field(init=False)
seed_state_dir: Path = field(init=False)
results_dir: Path = field(init=False)

def __post_init__(self) -> None:
Expand All @@ -322,6 +325,7 @@ def __post_init__(self) -> None:
self.lock = Locker(self.base_dir / ".decision_lock")
self.optimizer_state_file = self.base_dir / ".optimizer_state.yaml"
self.optimizer_info_file = self.base_dir / ".optimizer_info.yaml"
self.seed_state_dir = self.base_dir / ".seed_state"

def trial_refs(self) -> dict[Trial.State, list[Trial.Disk]]:
"""Get the disk reference of every trial, grouped by their state."""
Expand Down Expand Up @@ -565,7 +569,9 @@ def launch_runtime( # noqa: PLR0913, C901, PLR0915

# While we have the decision lock, we will now sample with the optimizer in
# this process
with sampler.using_state(shared_state.optimizer_state_file):
with SeedState.use(shared_state.seed_state_dir), sampler.using_state(
shared_state.optimizer_state_file
):
if sampler.budget is not None and sampler.used_budget >= sampler.budget:
logger.info("Maximum budget reached, shutting down")
break
Expand Down
176 changes: 176 additions & 0 deletions neps/utils/_rng.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
from __future__ import annotations

import json
import random
from contextlib import contextmanager
from dataclasses import dataclass
from pathlib import Path
from typing import Iterator, List, Tuple, Union
from typing_extensions import TypeAlias

import numpy as np
import torch

NP_RNG_STATE: TypeAlias = Tuple[str, np.ndarray, int, int, float]
PY_RNG_STATE: TypeAlias = Tuple[int, Tuple[int, ...], Union[int, None]]
TORCH_RNG_STATE: TypeAlias = torch.Tensor
TORCH_CUDA_RNG_STATE: TypeAlias = List[torch.Tensor]


@dataclass
class SeedState:
"""State of the global rng.
Primarly enables storing of the rng state to disk using a binary format
native to each library, allowing for potential version mistmatches between
processes loading the state, as long as they can read the binary format.
"""

# It seems like they're all uint32 but I can't be sure.
PY_RNG_STATE_DTYPE = np.int64

np_rng: NP_RNG_STATE
py_rng: PY_RNG_STATE
torch_rng: TORCH_RNG_STATE
torch_cuda_rng: TORCH_CUDA_RNG_STATE | None

@classmethod
def get(cls) -> SeedState:
"""Current state of the global rng.
Takes a snapshot, including cloning or copying any arrays, tensors, etc.
"""
# https://numpy.org/doc/stable/reference/random/generated/numpy.random.get_state.html
np_keys = np.random.get_state(legacy=True) # noqa: NPY002
assert np_keys[0] == "MT19937" # type: ignore
np_keys = (np_keys[0], np_keys[1].copy(), *np_keys[2:]) # type: ignore

py_rng = random.getstate()
torch_rng = torch.random.get_rng_state().clone()
torch_cuda_keys: list[torch.Tensor] | None = None
if torch.cuda.is_available():
torch_cuda_keys = [c.clone() for c in torch.cuda.get_rng_state_all()]

return cls(
np_rng=np_keys, # type: ignore
py_rng=py_rng,
torch_rng=torch_rng,
torch_cuda_rng=torch_cuda_keys,
)

def set_as_global_state(self) -> None:
"""Set the global rng to the given state."""
np.random.set_state(self.np_rng) # noqa: NPY002
random.setstate(self.py_rng)
torch.random.set_rng_state(self.torch_rng)
if self.torch_cuda_rng and torch.cuda.is_available():
torch.cuda.set_rng_state_all(self.torch_cuda_rng)

def dump(self, path: Path) -> None:
"""Save the state to a directory."""
if path.exists():
assert path.is_dir()
else:
path.mkdir(parents=True)

py_rng_version, py_rng_state, py_guass_next = self.py_rng
np_rng_kind, np_rng_state, np_pos, np_has_gauss, np_cached_gauss = self.np_rng

seed_info = {
"np_rng_kind": np_rng_kind,
"np_pos": np_pos,
"np_has_gauss": np_has_gauss,
"np_cached_gauss": np_cached_gauss,
"py_rng_version": py_rng_version,
"py_guass_next": py_guass_next,
}

# NOTE(eddiebergman): Chose JSON since it's fast and non-injectable
with (path / "seed_info.json").open("w") as f:
json.dump(seed_info, f)

py_rng_state_arr = np.array(py_rng_state, dtype=self.PY_RNG_STATE_DTYPE)
with (path / "py_rng.npy").open("wb") as f:
py_rng_state_arr.tofile(f)

with (path / "np_rng_state.npy").open("wb") as f:
np_rng_state.tofile(f)

torch.save(self.torch_rng, path / "torch_rng_state.pt")

if self.torch_cuda_rng:
torch.save(self.torch_cuda_rng, path / "torch_cuda_rng_state.pt")

@classmethod
def load(cls, path: Path) -> SeedState:
assert path.is_dir()

with (path / "seed_info.json").open("r") as f:
seed_info = json.load(f)

# Load and set pythons rng
py_rng_state = tuple(
int(x) for x in np.fromfile(path / "py_rng.npy", dtype=cls.PY_RNG_STATE_DTYPE)
)
np_rng_state = np.fromfile(path / "np_rng_state.npy", dtype=np.uint32)

# By specifying `weights_only=True`, it disables arbitrary object loading
torch_rng_state = torch.load(path / "torch_rng_state.pt", weights_only=True)

torch_cuda_rng = None
torch_cuda_rng_path = path / "torch_cuda_rng_state.pt"
if torch_cuda_rng_path.exists():
# By specifying `weights_only=True`, it disables arbitrary object loading
torch_cuda_rng = torch.load(
path / "torch_cuda_rng_state.pt",
weights_only=True,
)

return cls(
np_rng=(
seed_info["np_rng_kind"],
np_rng_state,
seed_info["np_pos"],
seed_info["np_has_gauss"],
seed_info["np_cached_gauss"],
),
py_rng=(
seed_info["py_rng_version"],
py_rng_state,
seed_info["py_guass_next"],
),
torch_rng=torch_rng_state,
torch_cuda_rng=torch_cuda_rng,
)

@classmethod
@contextmanager
def use(
cls,
path: Path,
*,
update_on_exit: bool = False,
) -> Iterator[SeedState]:
"""Context manager to use a seed state.
If the path exists, load the seed state from the path and set it as the
global state. Otherwise, use the current global state.
Args:
path: Path to the seed state.
update_on_exit: If True, get the seed state after the context manager returns
and save it to the path.
Yields:
SeedState: The seed state in use.
"""
if path.exists():
seed_state = cls.load(path)
seed_state.set_as_global_state()
else:
seed_state = cls.get()

yield seed_state

if update_on_exit:
cls.get().dump(path)
34 changes: 0 additions & 34 deletions neps/utils/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,10 @@
from __future__ import annotations

import inspect
import random
from functools import partial
from pathlib import Path
from typing import Any, Iterable, Mapping, Sequence

import numpy as np
import torch
import yaml

Expand Down Expand Up @@ -277,38 +275,6 @@ def filter_instances(itr: Iterable[Any], *types: type) -> list[Any]:
return [el for el in itr if isinstance(el, types)]


def get_rnd_state() -> dict:
"""Current state of the global random number generators in a devoctorized format."""
np_state = list(np.random.get_state()) # noqa: NPY002
np_state[1] = np_state[1].tolist() # type: ignore
state = {
"random_state": random.getstate(),
"np_seed_state": np_state,
"torch_seed_state": torch.random.get_rng_state().tolist(),
}
if torch.cuda.is_available():
state["torch_cuda_seed_state"] = [
dev.tolist() for dev in torch.cuda.get_rng_state_all()
]
return state


def set_rnd_state(state: dict) -> None:
"""Set the global random number generators to the given state."""
random.setstate(
tuple(
tuple(rnd_s) if isinstance(rnd_s, list) else rnd_s
for rnd_s in state["random_state"]
)
)
np.random.set_state(tuple(state["np_seed_state"])) # noqa: NPY002
torch.random.set_rng_state(torch.ByteTensor(state["torch_seed_state"]))
if torch.cuda.is_available() and "torch_cuda_seed_state" in state:
torch.cuda.set_rng_state_all(
[torch.ByteTensor(dev) for dev in state["torch_cuda_seed_state"]]
)


class MissingDependencyError(ImportError):
"""Raise when a dependency is missing for an optional feature."""

Expand Down
1 change: 0 additions & 1 deletion neps_examples/basic_usage/hpo_usage_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ def run_pipeline(
if categorical_name2 == "a":
loss += 1

time.sleep(2) # For demonstration purposes only
return loss


Expand Down
1 change: 0 additions & 1 deletion neps_examples/basic_usage/hyperparameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@

def run_pipeline(float1, float2, categorical, integer1, integer2):
loss = -float(np.sum([float1, float2, int(categorical), integer1, integer2]))
time.sleep(2) # For demonstration purposes only
return loss


Expand Down
42 changes: 42 additions & 0 deletions tests/test_rng.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
from __future__ import annotations

from pathlib import Path
import random
from typing import Callable
import numpy as np
import torch
import pytest

from neps.utils._rng import SeedState

@pytest.mark.parametrize(
"make_ints", (
lambda: [random.randint(0, 100) for _ in range(10)],
lambda: list(np.random.randint(0, 100, (10,))),
lambda: list(torch.randint(0, 100, (10,))),
)
)
def test_randomstate_consistent(tmp_path: Path, make_ints: Callable[[], list[int]]) -> None:
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)

seed_dir = tmp_path / "seed_dir"

seed_state = SeedState.get()
integers_1 = make_ints()

seed_state.set_as_global_state()
integers_2 = make_ints()

assert integers_1 == integers_2

SeedState.get().dump(seed_dir)
integers_3 = make_ints()

assert integers_3 != integers_2, "Ensure we have actually changed random state"

SeedState.load(seed_dir).set_as_global_state()
integers_4 = make_ints()

assert integers_3 == integers_4
File renamed without changes.

0 comments on commit 5bdeced

Please sign in to comment.