Merge pull request #279 from allenai/experiments_as_objects
Allowing kwargs to be passed to experiments from the command line and changing how testing is run.
jordis-ai2 authored May 4, 2021
2 parents 5275284 + 16cbe81 commit 0ed68f9
Showing 15 changed files with 343 additions and 210 deletions.
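The headline change: `ExperimentConfig` methods become instance methods, so an experiment can be parameterized by keyword arguments supplied on the command line. Below is a minimal sketch of what that flow enables; the helper name and CLI plumbing are assumptions for illustration, not the PR's literal code.

```python
# Hypothetical sketch: turning a command-line JSON string into constructor
# kwargs for an experiment config object. Names are illustrative only.
import json
from typing import Any, Dict, Type


def build_config(config_class: Type, config_kwargs_str: str = "{}") -> Any:
    # e.g. config_kwargs_str = '{"lr": 3e-4, "num_train_processes": 8}'
    kwargs: Dict[str, Any] = json.loads(config_kwargs_str)
    # Configs are now objects, so each run can carry its own parameters.
    return config_class(**kwargs)
```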
146 changes: 65 additions & 81 deletions allenact/algorithms/onpolicy_sync/runner.py
@@ -3,6 +3,7 @@
import glob
import itertools
import json
import math
import os
import queue
import random
@@ -15,6 +16,7 @@
from multiprocessing.process import BaseProcess
from typing import Optional, Dict, Union, Tuple, Sequence, List, Any

import numpy as np
import torch
import torch.multiprocessing as mp
from setproctitle import setproctitle as ptitle
@@ -45,6 +47,8 @@
# Saves configs, makes folder for trainer models
from allenact.utils.viz_utils import VizSuite

_CONFIG_KWARGS_STR = "__CONFIG_KWARGS__"


class OnPolicyRunner(object):
def __init__(
@@ -349,10 +353,8 @@ def start_train(

def start_test(
self,
experiment_date: str,
checkpoint_name_fragment: Optional[str] = None,
approx_ckpt_steps_count: Optional[Union[float, int]] = None,
skip_checkpoints: int = 0,
checkpoint_path_dir_or_pattern: str,
approx_ckpt_step_interval: Optional[Union[float, int]] = None,
max_sampler_processes_per_worker: Optional[int] = None,
):
devices = self.worker_devices("test")
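The old `experiment_date`/`checkpoint_name_fragment`/`approx_ckpt_steps_count`/`skip_checkpoints` quartet collapses into a single `checkpoint_path_dir_or_pattern`, so testing is pointed directly at files on disk. A usage sketch of the new signature (paths are hypothetical, and `runner` is assumed to be an already-constructed `OnPolicyRunner`):

```python
# Evaluate every "*.pt" checkpoint inside a directory (path hypothetical).
runner.start_test(
    checkpoint_path_dir_or_pattern="experiment_output/checkpoints/my_exp/2021-05-04",
)

# Or pass a glob pattern, subsampled to roughly one checkpoint
# per 1,000,000 training steps.
runner.start_test(
    checkpoint_path_dir_or_pattern="experiment_output/checkpoints/**/*.pt",
    approx_ckpt_step_interval=int(1e6),
)
```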
@@ -390,10 +392,8 @@
)

checkpoint_paths = self.get_checkpoint_files(
experiment_date=experiment_date,
checkpoint_name_fragment=checkpoint_name_fragment,
approx_ckpt_steps_count=approx_ckpt_steps_count,
skip_checkpoints=skip_checkpoints,
checkpoint_path_dir_or_pattern=checkpoint_path_dir_or_pattern,
approx_ckpt_step_interval=approx_ckpt_step_interval,
)
steps = [self.step_from_checkpoint(cp) for cp in checkpoint_paths]

@@ -407,7 +407,7 @@ def start_test(
for _ in range(num_testers):
self.queues["checkpoints"].put(("quit", None))

metrics_dir = self.metric_path(experiment_date)
metrics_dir = self.metric_path(self.local_start_time_str)
os.makedirs(metrics_dir, exist_ok=True)
suffix = "__test_{}".format(self.local_start_time_str)
metrics_file_path = os.path.join(metrics_dir, "metrics" + suffix + ".json")
@@ -504,6 +504,17 @@ def save_project_state(self):
# Saving configs
if self.loaded_config_src_files is not None:
for src_path in self.loaded_config_src_files:
if src_path == _CONFIG_KWARGS_STR:
# We also save keyword arguments passed to the experiment
# initializer.
save_path = os.path.join(base_dir, "config_kwargs.json")
assert not os.path.exists(
save_path
), f"{save_path} should not already exist."
with open(save_path, "w") as f:
json.dump(json.loads(self.loaded_config_src_files[src_path]), f)
continue

assert os.path.isfile(src_path), "Config file {} not found".format(
src_path
)
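The `_CONFIG_KWARGS_STR` sentinel lets CLI kwargs ride along in `loaded_config_src_files`, which otherwise maps real config-file paths to their contents. An illustration of the convention the branch above relies on (keys and values hypothetical):

```python
import json

_CONFIG_KWARGS_STR = "__CONFIG_KWARGS__"

# Real entries map a config file's path to its source text; the sentinel
# entry instead stores the experiment's kwargs serialized as a JSON string.
loaded_config_src_files = {
    "projects/my_project/experiments/my_experiment.py": "<file contents>",
    _CONFIG_KWARGS_STR: json.dumps({"lr": 3e-4, "gpus": 2}),
}

# save_project_state round-trips the sentinel entry into config_kwargs.json.
with open("config_kwargs.json", "w") as f:
    json.dump(json.loads(loaded_config_src_files[_CONFIG_KWARGS_STR]), f)
```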
@@ -887,89 +898,62 @@ def log(

def get_checkpoint_files(
self,
experiment_date: str,
checkpoint_name_fragment: Optional[str] = None,
approx_ckpt_steps_count: Optional[int] = None,
skip_checkpoints: int = 0,
checkpoint_path_dir_or_pattern: str,
approx_ckpt_step_interval: Optional[int] = None,
):
test_checkpoints_dir = self.checkpoint_dir(
experiment_date, create_if_none=False
)
if checkpoint_name_fragment is not None and os.path.exists(
checkpoint_name_fragment
):
if "*" in checkpoint_name_fragment:
# The fragment is a glob
assert "/" not in checkpoint_name_fragment
elif os.path.isfile(checkpoint_name_fragment):
# The fragment is a path to a checkpoint, use this checkpoint
return [checkpoint_name_fragment]
elif os.path.isdir(checkpoint_name_fragment):
# The fragment is a path to a directory, let's use this directory
# as the base dir to search for checkpoints
test_checkpoints_dir = checkpoint_name_fragment
checkpoint_name_fragment = None
else:
raise NotImplementedError

if checkpoint_name_fragment is not None:
assert (
skip_checkpoints == 0
), "`skip_checkpoints` must be 0 (i.e. none skipped)."
if checkpoint_name_fragment.endswith(".pt"):
checkpoint_name_fragment = checkpoint_name_fragment[:-3]
while checkpoint_name_fragment != checkpoint_name_fragment.strip("*"):
checkpoint_name_fragment = checkpoint_name_fragment.strip("*")

for_glob = os.path.join(
    test_checkpoints_dir, f"*{checkpoint_name_fragment}*.pt",
)

if os.path.isdir(checkpoint_path_dir_or_pattern):
    # The fragment is a path to a directory, let's use this directory
    # as the base dir to search for checkpoints
    checkpoint_path_dir_or_pattern = os.path.join(
        checkpoint_path_dir_or_pattern, "*.pt"
    )
paths = glob.glob(for_glob)
if len(paths) == 0:
raise FileExistsError(
f"No file at path `{checkpoint_name_fragment}` nor any"
f" files matching pattern {for_glob}."
)
elif len(paths) > 1:
raise FileExistsError(
f"Too many files match the pattern {for_glob}. These files include {paths}."
)
return paths
elif approx_ckpt_steps_count is not None:
paths = glob.glob(os.path.join(test_checkpoints_dir, "exp_*.pt"))
if len(paths) == 0:
raise FileExistsError(
f"No checkpoint files in directory {test_checkpoints_dir}."
)
step_diffs = [
abs(approx_ckpt_steps_count - self.step_from_checkpoint(p))
for p in paths
]
_, path = min(*zip(step_diffs, paths))
paths = [path]
else:
paths = glob.glob(os.path.join(test_checkpoints_dir, "exp_*.pt"))
paths = sorted(paths)
return (
    paths[:: skip_checkpoints + 1]
    + (
        [paths[-1]]
        if skip_checkpoints > 0 and len(paths) % (skip_checkpoints + 1) != 1
        else []
    )
    if len(paths) > 0
    else paths
)

ckpt_paths = glob.glob(checkpoint_path_dir_or_pattern, recursive=True)

if len(ckpt_paths) == 0:
    raise FileNotFoundError(
        f"Could not find any checkpoints at {os.path.abspath(checkpoint_path_dir_or_pattern)}, is it possible"
        f" the path has been misspecified?"
    )

step_count_ckpt_pairs = [(self.step_from_checkpoint(p), p) for p in ckpt_paths]
step_count_ckpt_pairs.sort()
ckpts_paths = [p for _, p in step_count_ckpt_pairs]
step_counts = np.array([sc for sc, _ in step_count_ckpt_pairs])

if approx_ckpt_step_interval is not None:
assert (
approx_ckpt_step_interval > 0
), "`approx_ckpt_step_interval` must be >0"
inds_to_eval = set()
for i in range(
math.ceil(step_count_ckpt_pairs[-1][0] / approx_ckpt_step_interval) + 1
):
inds_to_eval.add(
int(np.argmin(np.abs(step_counts - i * approx_ckpt_step_interval)))
)

ckpts_paths = [ckpts_paths[ind] for ind in sorted(list(inds_to_eval))]
return ckpts_paths
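The selection rule above keeps, for every multiple of `approx_ckpt_step_interval`, the checkpoint whose step count is nearest to it. A standalone worked example of that logic (the step counts are hypothetical):

```python
import math

import numpy as np

step_counts = np.array([100, 220, 305, 390, 520])  # sorted checkpoint steps
interval = 200

inds_to_eval = set()
for i in range(math.ceil(step_counts[-1] / interval) + 1):
    # Index of the checkpoint nearest to the target step i * interval.
    inds_to_eval.add(int(np.argmin(np.abs(step_counts - i * interval))))

print(sorted(inds_to_eval))  # [0, 1, 3, 4]: steps 100, 220, 390, 520 kept; 305 skipped
```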

@staticmethod
def step_from_checkpoint(name: str) -> int:
parts = os.path.basename(name).split("__")
def step_from_checkpoint(ckpt_path: str) -> int:
parts = os.path.basename(ckpt_path).split("__")
for part in parts:
if "steps_" in part:
possible_num = part.split("_")[-1].split(".")[0]
if possible_num.isdigit():
return int(possible_num)
return -1

get_logger().warning(
    f"The checkpoint {os.path.basename(ckpt_path)} does not follow the checkpoint naming convention"
    f" used by AllenAct. As a fallback we must load the checkpoint into memory to find the"
    f" training step count; this may increase startup time if the checkpoints are large or many"
    f" must be loaded in sequence."
)
ckpt = torch.load(ckpt_path, map_location="cpu")
return ckpt["total_steps"]
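The fast path inspects only the filename. A worked example of the parse (the filename is hypothetical but follows the `...__steps_<N>.pt` convention the loop expects):

```python
import os

name = "exp_MyExperiment__stage_00__steps_000000050000.pt"
parts = os.path.basename(name).split("__")
# parts[-1] == "steps_000000050000.pt" contains "steps_", so:
possible_num = parts[-1].split("_")[-1].split(".")[0]  # "000000050000"
assert possible_num.isdigit() and int(possible_num) == 50000
```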

def close(self, verbose=True):
if self._is_closed:
15 changes: 5 additions & 10 deletions allenact/base_abstractions/experiment_config.py
@@ -181,15 +181,13 @@ class ExperimentConfig(metaclass=FrozenClassVariables):
running the experiment.
"""

@classmethod
@abc.abstractmethod
def tag(cls) -> str:
def tag(self) -> str:
"""A string describing the experiment."""
raise NotImplementedError()

@classmethod
@abc.abstractmethod
def training_pipeline(cls, **kwargs) -> TrainingPipeline:
def training_pipeline(self, **kwargs) -> TrainingPipeline:
"""Creates the training pipeline.
# Parameters
@@ -202,10 +200,9 @@ def training_pipeline(cls, **kwargs) -> TrainingPipeline:
"""
raise NotImplementedError()

@classmethod
@abc.abstractmethod
def machine_params(
cls, mode="train", **kwargs
self, mode="train", **kwargs
) -> Union[MachineParams, Dict[str, Any]]:
"""Parameters used to specify machine information.
@@ -225,15 +222,13 @@ def machine_params(
"""
raise NotImplementedError()

@classmethod
@abc.abstractmethod
def create_model(cls, **kwargs) -> nn.Module:
def create_model(self, **kwargs) -> nn.Module:
"""Create the neural model."""
raise NotImplementedError()

@classmethod
@abc.abstractmethod
def make_sampler_fn(cls, **kwargs) -> TaskSampler:
def make_sampler_fn(self, **kwargs) -> TaskSampler:
"""Create the TaskSampler given keyword arguments.
These `kwargs` will be generated by one of
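Dropping `@classmethod` from the abstract interface is what makes "experiments as objects" work end to end: a config instance can carry its constructor kwargs as state that every overridden method reads. A hedged sketch of a subclass under the new interface (the class and parameter names are invented for illustration, and most required overrides are elided):

```python
from allenact.base_abstractions.experiment_config import ExperimentConfig


class MyExperiment(ExperimentConfig):
    def __init__(self, lr: float = 3e-4, num_train_processes: int = 8):
        super().__init__()
        self.lr = lr
        self.num_train_processes = num_train_processes

    def tag(self) -> str:
        # Instance methods can now depend on per-run kwargs.
        return f"MyExperiment_lr{self.lr}"

    # training_pipeline, machine_params, create_model, and make_sampler_fn
    # would be overridden similarly, reading self.* attributes as needed.
```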
