Skip to content

Commit

Permalink
Merging in changes from icenet 0.2.x
Browse files Browse the repository at this point in the history
  • Loading branch information
JimCircadian committed Jul 26, 2024
1 parent d70ece9 commit 4c4b765
Show file tree
Hide file tree
Showing 13 changed files with 788 additions and 38 deletions.
12 changes: 12 additions & 0 deletions icenet/cli.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
import datetime as dt
import re


def date_arg(string: str) -> dt.date:
    """Parse the first ``YYYY-M-D`` style date found in a string.

    The string is searched (not fully matched) for a four digit year followed
    by a one or two digit month and day, separated by hyphens. Any surrounding
    text is ignored.

    :param string: text containing a date such as ``2020-1-31``
    :return: the corresponding ``datetime.date``
    :raises ValueError: if no date pattern is present, or the matched numbers
        do not form a valid calendar date. argparse reports a ``ValueError``
        raised by a ``type=`` callable as a normal usage error.
    """
    date_match = re.search(r"(\d{4})-(\d{1,2})-(\d{1,2})", string)

    # re.search returns None when nothing matches; fail with a clear message
    # rather than an AttributeError on ``None.groups()``.
    if date_match is None:
        raise ValueError(
            "No YYYY-MM-DD date found in {!r}".format(string))

    year, month, day = (int(part) for part in date_match.groups())
    return dt.date(year, month, day)
15 changes: 9 additions & 6 deletions icenet/data/datasets/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ def check_dataset(self, split: str = "train") -> None:
dtype=self.dtype.__name__)

for df in getattr(self, "{}_fns".format(split)):
logging.debug("Getting records from {}".format(df))
logging.info("Getting records from {}".format(df))
try:
raw_dataset = tf.data.TFRecordDataset([df])
raw_dataset = raw_dataset.map(decoder)
Expand All @@ -217,7 +217,8 @@ def check_dataset(self, split: str = "train") -> None:
df, i, x.shape, y.shape, sw.shape))

input_nans = np.isnan(x).sum()
output_nans = np.isnan(y[sw > 0.]).sum()
output_nans = np.isnan(y[(sw > 0.)]).sum()
sw_nans = np.isnan(sw).sum()
input_min = np.min(x)
input_max = np.max(x)
output_min = np.min(x)
Expand All @@ -231,13 +232,15 @@ def check_dataset(self, split: str = "train") -> None:
sw_min, sw_max))

if input_nans > 0:
logging.warning("Input NaNs detected in {}:{}".format(
df, i))
logging.warning("Input NaNs detected in {}:{}".format(df, i))

if output_nans > 0:
logging.warning(
"Output NaNs detected in {}:{}, not "
"accounted for by sample weighting".format(df, i))
"Output NaNs detected in {}:{}, not accounted for by sample weighting".format(df, i))

if sw_nans > 0:
logging.warning(
"SW NaNs detected in {}:{}".format(df, i))
except tf.errors.DataLossError as e:
logging.warning("{}: data loss error {}".format(df, e.message))
except tf.errors.OpError as e:
Expand Down
142 changes: 142 additions & 0 deletions icenet/model/cli.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
import argparse
import logging
import os

from icenet.utils import setup_logging


class TrainingArgParser(argparse.ArgumentParser):
    """Argument parser pre-populated with the common training options.

    The positional arguments ``dataset``, ``run_name`` and ``seed`` plus the
    generic training options are registered on construction. Each ``add_*``
    method registers a further, optional group of arguments and returns
    ``self`` so that calls can be chained.
    """

    def __init__(self, *args, **kwargs):
        """Create the parser and register the core training arguments.

        All positional and keyword arguments are forwarded unchanged to
        ``argparse.ArgumentParser``.
        """
        super().__init__(*args, **kwargs)
        add = self.add_argument

        # Required positionals
        add("dataset", type=str)
        add("run_name", type=str)
        add("seed", type=int)

        # Generic options
        add("-o", "--output-path", type=str, default=None)
        add("-v", "--verbose", action="store_true", default=False)

        # Training options
        add("-b", "--batch-size", type=int, default=4)
        add("-ca", "--checkpoint-mode", default="min", type=str)
        add("-cm", "--checkpoint-monitor", default="val_rmse", type=str)
        add("-ds", "--additional-dataset",
            dest="additional", nargs="*", default=[])
        add("-e", "--epochs", type=int, default=4)
        add("--early-stopping", type=int, default=50)
        add("-p", "--preload", type=str)
        add("-r", "--ratio", default=1.0, type=float)
        add("--shuffle-train",
            default=False,
            action="store_true",
            help="Shuffle the training set")

        # Learning rate options
        add("--lr", default=1e-4, type=float)
        add("--lr_10e_decay_fac",
            default=1.0,
            type=float,
            help="Factor by which LR is multiplied by every 10 epochs "
            "using exponential decay. E.g. 1 -> no decay (default)"
            ", 0.5 -> halve every 10 epochs.")
        add("--lr_decay_start", default=10, type=int)
        add("--lr_decay_end", default=30, type=int)

    def add_unet(self):
        """Register the ``--filter-size`` and ``--n-filters-factor`` options.

        :return: this parser, for chaining
        """
        self.add_argument("-f", "--filter-size", type=int, default=3)
        self.add_argument("-n", "--n-filters-factor", type=float, default=1.)
        return self

    def add_tensorflow(self):
        """Register the ``--strategy`` option.

        :return: this parser, for chaining
        """
        # TODO: derive from available tf.distribute implementations
        strategies = ("default", "mirrored", "central")
        self.add_argument("-s", "--strategy", default=None, choices=strategies)
        return self

    def add_horovod(self):
        """Register the ``--no-horovod`` and ``--device-type`` options.

        ``horovod`` is True in the parsed namespace unless ``--no-horovod``
        is given.

        :return: this parser, for chaining
        """
        self.add_argument("--no-horovod",
                          dest="horovod", default=True, action="store_false")
        self.add_argument("--device-type",
                          default=None,
                          help="Choose a device type to detect, if using")
        return self

    def add_wandb(self):
        """Register the wandb on/off, offline, project and user options.

        The project and user defaults are read from the ``ICENET_ENVIRONMENT``
        and ``USER`` environment variables at the time this method is called.

        :return: this parser, for chaining
        """
        project_default = os.environ.get("ICENET_ENVIRONMENT")
        user_default = os.environ.get("USER")

        self.add_argument("-nw", "--no-wandb",
                          default=False, action="store_true")
        self.add_argument("-wo", "--wandb-offline",
                          default=False, action="store_true")
        self.add_argument("-wp", "--wandb-project",
                          default=project_default, type=str)
        self.add_argument("-wu", "--wandb-user",
                          default=user_default, type=str)
        return self

    def parse_args(self, *args,
                   log_format="[%(asctime)-17s :%(levelname)-8s] - %(message)s",
                   **kwargs):
        """Parse the arguments, then configure logging from ``--verbose``.

        :param args: forwarded to ``argparse.ArgumentParser.parse_args``
        :param log_format: format string handed to ``logging.basicConfig``
        :param kwargs: forwarded to ``argparse.ArgumentParser.parse_args``
        :return: the parsed namespace
        """
        parsed = super().parse_args(*args, **kwargs)

        if parsed.verbose:
            root_level = logging.DEBUG
        else:
            root_level = logging.INFO

        logging.basicConfig(datefmt="%d-%m-%y %T",
                            format=log_format,
                            level=root_level)

        # Cap these loggers at WARNING regardless of the level chosen above
        capped_loggers = (
            "cdsapi",
            "matplotlib",
            "matplotlib.pyplot",
            "requests",
            "tensorflow",
            "urllib3",
        )
        for logger_name in capped_loggers:
            logging.getLogger(logger_name).setLevel(logging.WARNING)

        return parsed


@setup_logging
def predict_args():
    """Build the prediction command line parser and parse the arguments.

    The function is wrapped by ``setup_logging`` from ``icenet.utils``; what
    that decorator does is not visible from here.

    NOTE(review): ``seed`` is a positional without ``nargs``, so argparse
    treats it as required and the ``default=42`` is presumably never applied.

    :return: the namespace produced by ``ArgumentParser.parse_args()``
    """
    parser = argparse.ArgumentParser()

    # Positionals, in the order they must appear on the command line
    for positional in ("dataset", "network_name", "output_name"):
        parser.add_argument(positional)
    parser.add_argument("seed", type=int, default=42)
    parser.add_argument("datefile", type=argparse.FileType("r"))

    parser.add_argument("-i", "--train-identifier",
                        dest="ident",
                        help="Train dataset identifier",
                        type=str,
                        default=None)
    parser.add_argument("-n", "--n-filters-factor", type=float, default=1.)
    parser.add_argument(
        "-l", "--legacy-rounding",
        action="store_true",
        default=False,
        help="Ensure filter number rounding occurs last in channel number calculations")

    # Simple boolean switches, all off by default
    parser.add_argument("-t", "--testset", action="store_true", default=False)
    parser.add_argument("-v", "--verbose", action="store_true", default=False)
    parser.add_argument("-s", "--save_args", action="store_true", default=False)

    return parser.parse_args()
Empty file.
77 changes: 77 additions & 0 deletions icenet/model/handlers/wandb.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
import datetime as dt
import logging

wandb_available = False
try:
import wandb
wandb_available = True
except ModuleNotFoundError:
pass


def init_wandb(cli_args):
    """Start a wandb run for this training job, when possible.

    Nothing is initialised, and ``(None, None)`` is returned, when the wandb
    module could not be imported or when ``cli_args.horovod`` is set and this
    process has a horovod rank greater than zero.

    :param cli_args: parsed arguments; the attributes read here are horovod,
        wandb_project, wandb_user, wandb_offline, run_name, seed, preload,
        lr, filter_size, n_filters_factor, lr_10e_decay_fac, lr_decay_start,
        lr_decay_end, batch_size, checkpoint_monitor and checkpoint_mode
    :return: tuple of (wandb run, ``WandbCallback``) or ``(None, None)``
    :raises RuntimeError: if horovod is requested but cannot be imported
    """
    if not wandb_available:
        logging.warning("WandB is not available, we will never use it")
        return None, None

    if cli_args.horovod:
        try:
            import horovod.tensorflow.keras as hvd
        except ModuleNotFoundError:
            raise RuntimeError("We're running horovod jobs without the module, eh?")

        # Only the rank zero process reports to wandb
        rank = hvd.rank()
        if rank > 0:
            logging.info("Not initialising wandb for rank {}".format(rank))
            return None, None

    logging.warning("Initialising WANDB for this run at user request")

    run_label = "{}.{}".format(cli_args.run_name, cli_args.seed)

    if cli_args.preload is None:
        preload_note = ""
    else:
        preload_note = " preload {}".format(cli_args.preload)
    run_notes = "{}: run at {}{}".format(
        cli_args.run_name,
        dt.datetime.now().strftime("%D %T"),
        preload_note)

    hyperparameters = {
        "seed": cli_args.seed,
        "learning_rate": cli_args.lr,
        "filter_size": cli_args.filter_size,
        "n_filters_factor": cli_args.n_filters_factor,
        "lr_10e_decay_fac": cli_args.lr_10e_decay_fac,
        "lr_decay_start": cli_args.lr_decay_start,
        "lr_decay_end": cli_args.lr_decay_end,
        "batch_size": cli_args.batch_size,
    }

    run = wandb.init(
        project=cli_args.wandb_project,
        name=run_label,
        notes=run_notes,
        entity=cli_args.wandb_user,
        config=hyperparameters,
        # Default settings; start_method="fork" and _disable_stats=True were
        # left commented out in the original call
        settings=wandb.Settings(),
        allow_val_change=True,
        mode="offline" if cli_args.wandb_offline else "online",
        group=cli_args.run_name,
    )

    # Log training metrics to wandb each epoch
    epoch_callback = wandb.keras.WandbCallback(
        monitor=cli_args.checkpoint_monitor,
        mode=cli_args.checkpoint_mode,
        save_model=False,
        save_graph=False,
    )
    return run, epoch_callback


def finalise_wandb(run, results, metric_names, leads):
    """Log each evaluation metric against lead time as a line plot in wandb.

    A table with a ``leadtime`` column and one column per metric is built
    from ``results`` and one ``<name>_plot`` line plot per metric is logged
    to ``run``.

    :param run: the run object to log to, or ``None`` when no run was
        initialised, in which case this function does nothing
    :param results: mapping indexed by ``f'{name}{lead}'`` for every
        combination of metric name and lead
    :param metric_names: names of the metrics to plot
    :param leads: lead times, used as the x-axis values
    :raises KeyError: if ``results`` is a dict lacking a name/lead entry
    """
    # init_wandb hands back None for the run when wandb is unavailable or on
    # non-zero horovod ranks; there is nothing to update in that case.
    if run is None:
        logging.info("No wandb run available, skipping evaluation metrics")
        return

    logging.info("Updating wandb run with evaluation metrics")
    metric_vals = [[results[f'{name}{lt}'] for lt in leads]
                   for name in metric_names]

    # One row per lead time: (lead, metric_1, metric_2, ...)
    table_data = list(zip(leads, *metric_vals))
    table = wandb.Table(data=table_data,
                        columns=['leadtime', *metric_names])

    # Log each metric vs. leadtime as a plot to wandb
    for name in metric_names:
        logging.debug("WandB logging {}".format(name))
        run.log(
            {f'{name}_plot': wandb.plot.line(table, x='leadtime', y=name)})
Empty file.
93 changes: 93 additions & 0 deletions icenet/model/networks/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
import logging
import os
import random

import numpy as np

from abc import abstractmethod


class BaseNetwork:
    """Common state and helpers shared by network implementations.

    Holds the dataset, run name, seed, callback list and output locations,
    and seeds the Python and legacy numpy random number generators.

    ``train`` and ``predict`` must be supplied by subclasses. The class does
    not derive from ``abc.ABC``, so ``@abstractmethod`` does not block
    instantiation; the base implementations raise ``NotImplementedError``
    when called.
    """

    def __init__(self,
                 dataset: object,
                 run_name: object,
                 callbacks_additional: list = None,
                 callbacks_default: list = None,
                 network_folder: object = None,
                 seed: int = 42):
        """Set up output paths, callbacks and random seeds.

        :param dataset: dataset object; its ``identifier`` attribute is used
            in the model path
        :param run_name: name of the run, used in folder and file names
        :param callbacks_additional: callbacks appended after the defaults
        :param callbacks_default: callbacks replacing those returned by
            ``get_default_callbacks()``
        :param network_folder: output folder; defaults to
            ``./results/networks/<run_name>`` when not given
        :param seed: seed used for the random number generators
        """
        # Use the supplied folder when given; previously the attribute was
        # only assigned in the default case, so passing a folder failed.
        if not network_folder:
            network_folder = os.path.join(".", "results", "networks", run_name)
        self._network_folder = network_folder

        if not os.path.exists(self._network_folder):
            logging.info("Creating network folder: {}".format(
                self._network_folder))
            os.makedirs(self._network_folder, exist_ok=True)

        self._model_path = os.path.join(
            self._network_folder,
            "{}.model_{}.{}".format(run_name, dataset.identifier, seed))

        self._dataset = dataset
        self._run_name = run_name
        self._seed = seed

        # Copy so that extending the callback list here, or later through
        # add_callback, never mutates a list owned by the caller.
        if callbacks_default is None:
            callbacks_default = self.get_default_callbacks()
        self._callbacks = list(callbacks_default)
        if callbacks_additional is not None:
            self._callbacks.extend(callbacks_additional)

        self._attempt_seed_setup()

    def _attempt_seed_setup(self):
        """Seed the available random number generators from ``self._seed``."""
        logging.warning(
            "Setting seed for best attempt at determinism, value {}".format(
                self._seed))
        # determinism is not guaranteed across different versions of TensorFlow.
        # determinism is not guaranteed across different hardware.
        os.environ['PYTHONHASHSEED'] = str(self._seed)
        # numpy.random.default_rng ignores this, WARNING!
        np.random.seed(self._seed)
        random.seed(self._seed)

    def add_callback(self, callback):
        """Append a callback to this network's callback list.

        :param callback: the callback object to add
        """
        logging.debug("Adding callback {}".format(callback))
        self._callbacks.append(callback)

    def get_default_callbacks(self):
        """Return the callbacks used when none are supplied.

        :return: an empty list; subclasses may override
        """
        return list()

    @abstractmethod
    def train(self,
              epochs: int,
              model_creator: callable,
              train_dataset: object,
              model_creator_kwargs: dict = None,
              save: bool = True):
        """Train the network; must be implemented by subclasses.

        :raises NotImplementedError: always, in this base implementation
        """
        raise NotImplementedError("Implementation not found")

    @abstractmethod
    def predict(self):
        """Run prediction; must be implemented by subclasses.

        :raises NotImplementedError: always, in this base implementation
        """
        raise NotImplementedError("Implementation not found")

    @property
    def callbacks(self):
        """The list of callbacks held by this network."""
        return self._callbacks

    @property
    def dataset(self):
        """The dataset object supplied at construction."""
        return self._dataset

    @property
    def model_path(self):
        """Path built from network folder, run name, dataset id and seed."""
        return self._model_path

    @property
    def network_folder(self):
        """Folder in which this network's outputs are stored."""
        return self._network_folder

    @property
    def run_name(self):
        """The run name supplied at construction."""
        return self._run_name

    @property
    def seed(self):
        """The seed supplied at construction."""
        return self._seed

Loading

0 comments on commit 4c4b765

Please sign in to comment.