diff --git a/pyproject.toml b/pyproject.toml index c008e78..09ca67e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,6 +28,7 @@ dependencies = [ [project.optional-dependencies] examples = [ "invrs-opt", + "invrs-utils", "scikit-image", ] tests = [ diff --git a/scripts/experiment.py b/scripts/experiment.py index 6a29ded..ed8c11b 100644 --- a/scripts/experiment.py +++ b/scripts/experiment.py @@ -10,15 +10,14 @@ """ import argparse -import dataclasses import functools -import glob -import itertools import json import multiprocessing as mp import os import random -from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple +from typing import Any, Dict, List, Tuple + +from invrs_utils.experiment import sweep Sweep = List[Dict[str, Any]] @@ -33,15 +32,15 @@ def run_experiment( """Runs an experiment.""" # Define the experiment. - challenge_sweeps = sweep("challenge_name", ["metagrating"]) - hparam_sweeps = sweep_product( - sweep("density_relative_mean", [0.5]), - sweep("density_relative_noise_amplitude", [0.1]), - sweep("beta", [2.0]), - sweep("seed", range(3)), - sweep("steps", [steps]), + challenge_sweeps = sweep.sweep("challenge_name", ["metagrating"]) + hparam_sweeps = sweep.product( + sweep.sweep("density_relative_mean", [0.5]), + sweep.sweep("density_relative_noise_amplitude", [0.1]), + sweep.sweep("beta", [2.0]), + sweep.sweep("seed", range(3)), + sweep.sweep("steps", [steps]), ) - sweeps = sweep_product(challenge_sweeps, hparam_sweeps) + sweeps = sweep.product(challenge_sweeps, hparam_sweeps) # Set up checkpointing directory. wid_paths = [experiment_path + f"/wid_{i:04}" for i in range(len(sweeps))] @@ -74,29 +73,6 @@ def _run_work_unit(path_and_kwargs: Tuple[str, Dict[str, Any]]) -> None: return run_work_unit(wid_path=wid_path, **kwargs) -def sweep(name: str, values: Sequence[Any]) -> Sweep: - """Generate a list of dictionaries defining a sweep.""" - return [{name: v} for v in values] - - -def sweep_zip(*sweeps: Sweep) -> Sweep: - """Zip sweeps of different variables.""" - return [_merge(*kw) for kw in zip(*sweeps, strict=True)] - - -def sweep_product(*sweeps: Sweep) -> Sweep: - """Return the Cartesian product of multiple sweeps.""" - return [_merge(*kw) for kw in itertools.product(*sweeps)] - - -def _merge(*vars: Dict[str, Any]) -> Dict[str, Any]: - """Merge dictionaries defining sweeps of multiple variables.""" - d = {} - for v in vars: - d.update(v) - return d - - # ----------------------------------------------------------------------------- # Functions related to individual work units within the experiment. # ----------------------------------------------------------------------------- @@ -130,6 +106,7 @@ def run_work_unit( import time import invrs_opt + from invrs_utils.experiment import checkpoint import jax from jax import numpy as jnp from totypes import json_utils @@ -138,7 +115,7 @@ def run_work_unit( from invrs_gym.utils import initializers # Create a basic checkpoint manager that can serialize custom types. - mngr = CheckpointManager( + mngr = checkpoint.CheckpointManager( path=wid_path, save_interval_steps=10, max_to_keep=1, @@ -179,9 +156,9 @@ def loss_fn( opt = invrs_opt.density_lbfgsb(beta=beta) if mngr.latest_step() is not None: latest_step: int = mngr.latest_step() # type: ignore[assignment] - checkpoint = mngr.restore(latest_step) - state = checkpoint["state"] - scalars = checkpoint["scalars"] + latest_checkpoint = mngr.restore(latest_step) + state = latest_checkpoint["state"] + scalars = latest_checkpoint["scalars"] else: latest_step = -1 # Next step is `0`. params = challenge.component.init(jax.random.PRNGKey(seed)) @@ -219,54 +196,6 @@ def _log_scalar(name: str, value: float) -> None: os.utime(wid_path, None) -# ----------------------------------------------------------------------------- -# Functions related to checkpointing. -# ----------------------------------------------------------------------------- - - -@dataclasses.dataclass -class CheckpointManager: - """A simple checkpoint manager with an orbax-like API.""" - - path: str - save_interval_steps: int - max_to_keep: int - serialize_fn: Callable[[Any], str] = json.dumps - deserialize_fn: Callable[[str], Any] = json.loads - - def latest_step(self) -> Optional[int]: - """Return the latest checkpointed step, or `None` if no checkpoints exist.""" - steps = self._checkpoint_steps() - steps.sort() - return None if len(steps) == 0 else steps[-1] - - def save(self, step: int, pytree: Any, force_save: bool = False) -> None: - """Save a pytree checkpoint.""" - if (step + 1) % self.save_interval_steps != 0 and not force_save: - return - with open(self._checkpoint_fname(step), "w") as f: - f.write(self.serialize_fn(pytree)) - steps = self._checkpoint_steps() - steps.sort() - steps_to_delete = steps[: -self.max_to_keep] - for step in steps_to_delete: - os.remove(self._checkpoint_fname(step)) - - def restore(self, step: int) -> Any: - """Restore a pytree checkpoint.""" - with open(self._checkpoint_fname(step)) as f: - return self.deserialize_fn(f.read()) - - def _checkpoint_steps(self) -> List[int]: - """Return the steps for which checkpoint files exist.""" - fnames = glob.glob(self.path + "/checkpoint_*.json") - return [int(f.split("_")[-1][:-5]) for f in fnames] - - def _checkpoint_fname(self, step: int) -> str: - """Return the chackpoint filename for the given step.""" - return self.path + f"/checkpoint_{step:04}.json" - - def _is_scalar(x: Any) -> bool: """Returns `True` if `x` is a scalar, i.e. it can be cast as a float.""" try: