Skip to content

Commit

Permalink
Use invrs-utils in scripts (#50)
Browse files Browse the repository at this point in the history
* Use invrs-utils in scripts

* Remove unused imports
  • Loading branch information
mfschubert authored Nov 3, 2023
1 parent 289454e commit 7083071
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 87 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ dependencies = [
[project.optional-dependencies]
examples = [
"invrs-opt",
"invrs-utils",
"scikit-image",
]
tests = [
Expand Down
103 changes: 16 additions & 87 deletions scripts/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,14 @@
"""

import argparse
import dataclasses
import functools
import glob
import itertools
import json
import multiprocessing as mp
import os
import random
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple
from typing import Any, Dict, List, Tuple

from invrs_utils.experiment import sweep

Sweep = List[Dict[str, Any]]

Expand All @@ -33,15 +32,15 @@ def run_experiment(
"""Runs an experiment."""

# Define the experiment.
challenge_sweeps = sweep("challenge_name", ["metagrating"])
hparam_sweeps = sweep_product(
sweep("density_relative_mean", [0.5]),
sweep("density_relative_noise_amplitude", [0.1]),
sweep("beta", [2.0]),
sweep("seed", range(3)),
sweep("steps", [steps]),
challenge_sweeps = sweep.sweep("challenge_name", ["metagrating"])
hparam_sweeps = sweep.product(
sweep.sweep("density_relative_mean", [0.5]),
sweep.sweep("density_relative_noise_amplitude", [0.1]),
sweep.sweep("beta", [2.0]),
sweep.sweep("seed", range(3)),
sweep.sweep("steps", [steps]),
)
sweeps = sweep_product(challenge_sweeps, hparam_sweeps)
sweeps = sweep.product(challenge_sweeps, hparam_sweeps)

# Set up checkpointing directory.
wid_paths = [experiment_path + f"/wid_{i:04}" for i in range(len(sweeps))]
Expand Down Expand Up @@ -74,29 +73,6 @@ def _run_work_unit(path_and_kwargs: Tuple[str, Dict[str, Any]]) -> None:
return run_work_unit(wid_path=wid_path, **kwargs)


def sweep(name: str, values: Sequence[Any]) -> Sweep:
"""Generate a list of dictionaries defining a sweep."""
return [{name: v} for v in values]


def sweep_zip(*sweeps: Sweep) -> Sweep:
"""Zip sweeps of different variables."""
return [_merge(*kw) for kw in zip(*sweeps, strict=True)]


def sweep_product(*sweeps: Sweep) -> Sweep:
"""Return the Cartesian product of multiple sweeps."""
return [_merge(*kw) for kw in itertools.product(*sweeps)]


def _merge(*vars: Dict[str, Any]) -> Dict[str, Any]:
"""Merge dictionaries defining sweeps of multiple variables."""
d = {}
for v in vars:
d.update(v)
return d


# -----------------------------------------------------------------------------
# Functions related to individual work units within the experiment.
# -----------------------------------------------------------------------------
Expand Down Expand Up @@ -130,6 +106,7 @@ def run_work_unit(
import time

import invrs_opt
from invrs_utils.experiment import checkpoint
import jax
from jax import numpy as jnp
from totypes import json_utils
Expand All @@ -138,7 +115,7 @@ def run_work_unit(
from invrs_gym.utils import initializers

# Create a basic checkpoint manager that can serialize custom types.
mngr = CheckpointManager(
mngr = checkpoint.CheckpointManager(
path=wid_path,
save_interval_steps=10,
max_to_keep=1,
Expand Down Expand Up @@ -179,9 +156,9 @@ def loss_fn(
opt = invrs_opt.density_lbfgsb(beta=beta)
if mngr.latest_step() is not None:
latest_step: int = mngr.latest_step() # type: ignore[assignment]
checkpoint = mngr.restore(latest_step)
state = checkpoint["state"]
scalars = checkpoint["scalars"]
latest_checkpoint = mngr.restore(latest_step)
state = latest_checkpoint["state"]
scalars = latest_checkpoint["scalars"]
else:
latest_step = -1 # Next step is `0`.
params = challenge.component.init(jax.random.PRNGKey(seed))
Expand Down Expand Up @@ -219,54 +196,6 @@ def _log_scalar(name: str, value: float) -> None:
os.utime(wid_path, None)


# -----------------------------------------------------------------------------
# Functions related to checkpointing.
# -----------------------------------------------------------------------------


@dataclasses.dataclass
class CheckpointManager:
"""A simple checkpoint manager with an orbax-like API."""

path: str
save_interval_steps: int
max_to_keep: int
serialize_fn: Callable[[Any], str] = json.dumps
deserialize_fn: Callable[[str], Any] = json.loads

def latest_step(self) -> Optional[int]:
"""Return the latest checkpointed step, or `None` if no checkpoints exist."""
steps = self._checkpoint_steps()
steps.sort()
return None if len(steps) == 0 else steps[-1]

def save(self, step: int, pytree: Any, force_save: bool = False) -> None:
"""Save a pytree checkpoint."""
if (step + 1) % self.save_interval_steps != 0 and not force_save:
return
with open(self._checkpoint_fname(step), "w") as f:
f.write(self.serialize_fn(pytree))
steps = self._checkpoint_steps()
steps.sort()
steps_to_delete = steps[: -self.max_to_keep]
for step in steps_to_delete:
os.remove(self._checkpoint_fname(step))

def restore(self, step: int) -> Any:
"""Restore a pytree checkpoint."""
with open(self._checkpoint_fname(step)) as f:
return self.deserialize_fn(f.read())

def _checkpoint_steps(self) -> List[int]:
"""Return the steps for which checkpoint files exist."""
fnames = glob.glob(self.path + "/checkpoint_*.json")
return [int(f.split("_")[-1][:-5]) for f in fnames]

def _checkpoint_fname(self, step: int) -> str:
"""Return the chackpoint filename for the given step."""
return self.path + f"/checkpoint_{step:04}.json"


def _is_scalar(x: Any) -> bool:
"""Returns `True` if `x` is a scalar, i.e. it can be cast as a float."""
try:
Expand Down

0 comments on commit 7083071

Please sign in to comment.