forked from icenet-ai/icenet
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merging in changes from icenet 0.2.x
- Loading branch information
1 parent
d70ece9
commit 4c4b765
Showing
13 changed files
with
788 additions
and
38 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
import datetime as dt | ||
import re | ||
|
||
|
||
def date_arg(string: str) -> object: | ||
""" | ||
:param string: | ||
:return: | ||
""" | ||
date_match = re.search(r"(\d{4})-(\d{1,2})-(\d{1,2})", string) | ||
return dt.date(*[int(s) for s in date_match.groups()]) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,142 @@ | ||
import argparse | ||
import logging | ||
import os | ||
|
||
from icenet.utils import setup_logging | ||
|
||
|
||
class TrainingArgParser(argparse.ArgumentParser): | ||
"""An ArgumentParser specialised to support model training | ||
The 'allow_*' methods return self to permit method chaining. | ||
""" | ||
|
||
def __init__(self, *args, **kwargs): | ||
super().__init__(*args, **kwargs) | ||
|
||
self.add_argument("dataset", type=str) | ||
self.add_argument("run_name", type=str) | ||
self.add_argument("seed", type=int) | ||
|
||
self.add_argument("-o", "--output-path", type=str, default=None) | ||
self.add_argument("-v", | ||
"--verbose", | ||
action="store_true", | ||
default=False) | ||
|
||
self.add_argument("-b", "--batch-size", type=int, default=4) | ||
self.add_argument("-ca", | ||
"--checkpoint-mode", | ||
default="min", | ||
type=str) | ||
self.add_argument("-cm", | ||
"--checkpoint-monitor", | ||
default="val_rmse", | ||
type=str) | ||
self.add_argument("-ds", | ||
"--additional-dataset", | ||
dest="additional", | ||
nargs="*", | ||
default=[]) | ||
self.add_argument("-e", "--epochs", type=int, default=4) | ||
self.add_argument("--early-stopping", type=int, default=50) | ||
self.add_argument("-p", "--preload", type=str) | ||
self.add_argument("-r", "--ratio", default=1.0, type=float) | ||
self.add_argument("--shuffle-train", | ||
default=False, | ||
action="store_true", | ||
help="Shuffle the training set") | ||
self.add_argument("--lr", default=1e-4, type=float) | ||
self.add_argument("--lr_10e_decay_fac", | ||
default=1.0, | ||
type=float, | ||
help="Factor by which LR is multiplied by every 10 epochs " | ||
"using exponential decay. E.g. 1 -> no decay (default)" | ||
", 0.5 -> halve every 10 epochs.") | ||
self.add_argument('--lr_decay_start', default=10, type=int) | ||
self.add_argument('--lr_decay_end', default=30, type=int) | ||
|
||
def add_unet(self): | ||
self.add_argument("-f", "--filter-size", type=int, default=3) | ||
self.add_argument("-n", "--n-filters-factor", type=float, default=1.) | ||
return self | ||
|
||
def add_tensorflow(self): | ||
# TODO: derive from available tf.distribute implementations | ||
self.add_argument("-s", | ||
"--strategy", | ||
default=None, | ||
choices=("default", "mirrored", "central")) | ||
return self | ||
|
||
def add_horovod(self): | ||
self.add_argument("--no-horovod", | ||
dest="horovod", | ||
default=True, | ||
action="store_false") | ||
self.add_argument("--device-type", | ||
default=None, | ||
help="Choose a device type to detect, if using") | ||
return self | ||
|
||
def add_wandb(self): | ||
self.add_argument("-nw", "--no-wandb", default=False, action="store_true") | ||
self.add_argument("-wo", | ||
"--wandb-offline", | ||
default=False, | ||
action="store_true") | ||
self.add_argument("-wp", | ||
"--wandb-project", | ||
default=os.environ.get("ICENET_ENVIRONMENT"), | ||
type=str) | ||
self.add_argument("-wu", | ||
"--wandb-user", | ||
default=os.environ.get("USER"), | ||
type=str) | ||
return self | ||
|
||
def parse_args(self, *args, | ||
log_format="[%(asctime)-17s :%(levelname)-8s] - %(message)s", | ||
**kwargs): | ||
args = super().parse_args(*args, **kwargs) | ||
|
||
logging.basicConfig( | ||
datefmt="%d-%m-%y %T", | ||
format=log_format, | ||
level=logging.DEBUG if args.verbose else logging.INFO) | ||
logging.getLogger("cdsapi").setLevel(logging.WARNING) | ||
logging.getLogger("matplotlib").setLevel(logging.WARNING) | ||
logging.getLogger("matplotlib.pyplot").setLevel(logging.WARNING) | ||
logging.getLogger("requests").setLevel(logging.WARNING) | ||
logging.getLogger("tensorflow").setLevel(logging.WARNING) | ||
logging.getLogger("urllib3").setLevel(logging.WARNING) | ||
return args | ||
|
||
|
||
@setup_logging | ||
def predict_args(): | ||
""" | ||
:return: | ||
""" | ||
ap = argparse.ArgumentParser() | ||
ap.add_argument("dataset") | ||
ap.add_argument("network_name") | ||
ap.add_argument("output_name") | ||
ap.add_argument("seed", type=int, default=42) | ||
ap.add_argument("datefile", type=argparse.FileType("r")) | ||
|
||
ap.add_argument("-i", | ||
"--train-identifier", | ||
dest="ident", | ||
help="Train dataset identifier", | ||
type=str, | ||
default=None) | ||
ap.add_argument("-n", "--n-filters-factor", type=float, default=1.) | ||
ap.add_argument("-l", "--legacy-rounding", action="store_true", | ||
default=False, help="Ensure filter number rounding occurs last in channel number calculations") | ||
ap.add_argument("-t", "--testset", action="store_true", default=False) | ||
ap.add_argument("-v", "--verbose", action="store_true", default=False) | ||
ap.add_argument("-s", "--save_args", action="store_true", default=False) | ||
|
||
return ap.parse_args() |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,77 @@ | ||
import datetime as dt | ||
import logging | ||
|
||
wandb_available = False | ||
try: | ||
import wandb | ||
wandb_available = True | ||
except ModuleNotFoundError: | ||
pass | ||
|
||
|
||
def init_wandb(cli_args): | ||
if wandb_available: | ||
if cli_args.horovod: | ||
try: | ||
import horovod.tensorflow.keras as hvd | ||
except ModuleNotFoundError: | ||
raise RuntimeError("We're running horovod jobs without the module, eh?") | ||
|
||
if hvd.rank() > 0: | ||
logging.info("Not initialising wandb for rank {}".format(hvd.rank())) | ||
return None, None | ||
|
||
logging.warning("Initialising WANDB for this run at user request") | ||
|
||
run = wandb.init( | ||
project=cli_args.wandb_project, | ||
name="{}.{}".format(cli_args.run_name, cli_args.seed), | ||
notes="{}: run at {}{}".format( | ||
cli_args.run_name, | ||
dt.datetime.now().strftime("%D %T"), "" if cli_args.preload is None | ||
else " preload {}".format(cli_args.preload)), | ||
entity=cli_args.wandb_user, | ||
config=dict( | ||
seed=cli_args.seed, | ||
learning_rate=cli_args.lr, | ||
filter_size=cli_args.filter_size, | ||
n_filters_factor=cli_args.n_filters_factor, | ||
lr_10e_decay_fac=cli_args.lr_10e_decay_fac, | ||
lr_decay_start=cli_args.lr_decay_start, | ||
lr_decay_end=cli_args.lr_decay_end, | ||
batch_size=cli_args.batch_size, | ||
), | ||
settings=wandb.Settings( | ||
# start_method="fork", | ||
# _disable_stats=True, | ||
), | ||
allow_val_change=True, | ||
mode='offline' if cli_args.wandb_offline else 'online', | ||
group=cli_args.run_name, | ||
) | ||
|
||
# Log training metrics to wandb each epoch | ||
return run, wandb.keras.WandbCallback( | ||
monitor=cli_args.checkpoint_monitor, | ||
mode=cli_args.checkpoint_mode, | ||
save_model=False, | ||
save_graph=False, | ||
) | ||
|
||
logging.warning("WandB is not available, we will never use it") | ||
return None, None | ||
|
||
|
||
def finalise_wandb(run, results, metric_names, leads): | ||
logging.info("Updating wandb run with evaluation metrics") | ||
metric_vals = [[results[f'{name}{lt}'] for lt in leads] | ||
for name in metric_names] | ||
table_data = list(zip(leads, *metric_vals)) | ||
table = wandb.Table(data=table_data, | ||
columns=['leadtime', *metric_names]) | ||
|
||
# Log each metric vs. leadtime as a plot to wandb | ||
for name in metric_names: | ||
logging.debug("WandB logging {}".format(name)) | ||
run.log( | ||
{f'{name}_plot': wandb.plot.line(table, x='leadtime', y=name)}) |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,93 @@ | ||
import logging | ||
import os | ||
import random | ||
|
||
import numpy as np | ||
|
||
from abc import abstractmethod | ||
|
||
|
||
class BaseNetwork: | ||
def __init__(self, | ||
dataset: object, | ||
run_name: object, | ||
callbacks_additional: list = None, | ||
callbacks_default: list = None, | ||
network_folder: object = None, | ||
seed: int = 42): | ||
|
||
if not network_folder: | ||
self._network_folder = os.path.join(".", "results", "networks", run_name) | ||
|
||
if not os.path.exists(self._network_folder): | ||
logging.info("Creating network folder: {}".format(network_folder)) | ||
os.makedirs(self._network_folder, exist_ok=True) | ||
|
||
self._model_path = os.path.join( | ||
self._network_folder, "{}.model_{}.{}".format(run_name, | ||
dataset.identifier, | ||
seed)) | ||
|
||
self._dataset = dataset | ||
self._run_name = run_name | ||
self._seed = seed | ||
|
||
self._callbacks = self.get_default_callbacks() if callbacks_default is None else callbacks_default | ||
self._callbacks += callbacks_additional if callbacks_additional is not None else [] | ||
|
||
self._attempt_seed_setup() | ||
|
||
def _attempt_seed_setup(self): | ||
logging.warning( | ||
"Setting seed for best attempt at determinism, value {}".format(self._seed)) | ||
# determinism is not guaranteed across different versions of TensorFlow. | ||
# determinism is not guaranteed across different hardware. | ||
os.environ['PYTHONHASHSEED'] = str(self._seed) | ||
# numpy.random.default_rng ignores this, WARNING! | ||
np.random.seed(self._seed) | ||
random.seed(self._seed) | ||
|
||
def add_callback(self, callback): | ||
logging.debug("Adding callback {}".format(callback)) | ||
self._callbacks.append(callback) | ||
|
||
def get_default_callbacks(self): | ||
return list() | ||
|
||
@abstractmethod | ||
def train(self, | ||
epochs: int, | ||
model_creator: callable, | ||
train_dataset: object, | ||
model_creator_kwargs: dict = None, | ||
save: bool = True): | ||
raise NotImplementedError("Implementation not found") | ||
|
||
@abstractmethod | ||
def predict(self): | ||
raise NotImplementedError("Implementation not found") | ||
|
||
@property | ||
def callbacks(self): | ||
return self._callbacks | ||
|
||
@property | ||
def dataset(self): | ||
return self._dataset | ||
|
||
@property | ||
def model_path(self): | ||
return self._model_path | ||
|
||
@property | ||
def network_folder(self): | ||
return self._network_folder | ||
|
||
@property | ||
def run_name(self): | ||
return self._run_name | ||
|
||
@property | ||
def seed(self): | ||
return self._seed | ||
|
Oops, something went wrong.