Skip to content

Commit

Permalink
Merge pull request #97 from invrs-io/experiment
Browse files Browse the repository at this point in the history
Use new experiment utils
  • Loading branch information
mfschubert authored Mar 19, 2024
2 parents 1c9a3f0 + c71aa64 commit ae2db72
Showing 1 changed file with 9 additions and 54 deletions.
63 changes: 9 additions & 54 deletions scripts/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,28 +12,16 @@
import argparse
import functools
import json
import multiprocessing as mp
import os
import random
from typing import Any, Dict, List, Tuple
from typing import Any, Dict, Sequence

from invrs_utils.experiment import sweep

Sweep = List[Dict[str, Any]]
from invrs_utils.experiment import experiment, sweep

PRINT_INTERVAL = 300


def run_experiment(
experiment_path: str,
workers: int,
dry_run: bool,
randomize: bool,
steps: int,
) -> None:
"""Runs an experiment."""

# Define the experiment.
def experiment_sweeps(steps: int) -> Sequence[Dict[str, Any]]:
"""Defines the hyperparameter sweep of the experiment."""
challenge_sweeps = sweep.sweep("challenge_name", ["metagrating"])
hparam_sweeps = sweep.product(
sweep.sweep("density_relative_mean", [0.5]),
Expand All @@ -42,44 +30,10 @@ def run_experiment(
sweep.sweep("seed", range(3)),
sweep.sweep("steps", [steps]),
)
sweeps = sweep.product(challenge_sweeps, hparam_sweeps)

# Set up checkpointing directory.
wid_paths = [experiment_path + f"/wid_{i:04}" for i in range(len(sweeps))]

# Print some information about the experiment.
print(
f"Experiment:\n"
f" worker count = {max(1, workers)}\n"
f" work unit count = {len(sweeps)}\n"
f" experiment path = {experiment_path}\n"
f"Work units:"
)
for wid_path, kwargs in zip(wid_paths, sweeps):
print(f" {wid_path}: {kwargs}")

path_and_kwargs = list(zip(wid_paths, sweeps))
if randomize:
random.shuffle(path_and_kwargs)

if dry_run:
return

with mp.Pool(processes=workers) as pool:
_ = list(pool.imap_unordered(_run_work_unit, path_and_kwargs))


def _run_work_unit(path_and_kwargs: Tuple[str, Dict[str, Any]]) -> None:
"""Wraps `run_work_unit` so that it can be called by `map`."""
wid_path, kwargs = path_and_kwargs
run_work_unit(wid_path, **kwargs)


# -----------------------------------------------------------------------------
# Functions related to individual work units within the experiment.
# -----------------------------------------------------------------------------
return sweep.product(challenge_sweeps, hparam_sweeps)


@experiment.work_unit_fn
def run_work_unit(
wid_path: str,
challenge_name: str,
Expand Down Expand Up @@ -176,10 +130,11 @@ def run_work_unit(

if __name__ == "__main__":
args = parser.parse_args()
run_experiment(
experiment.run_experiment(
experiment_path=args.path,
sweeps=experiment_sweeps(steps=args.steps),
work_unit_fn=run_work_unit,
workers=args.workers,
dry_run=args.dry_run,
randomize=args.randomize,
steps=args.steps,
)

0 comments on commit ae2db72

Please sign in to comment.