From 5f56a1f5f276287d6f93e1649a11661b66711eb1 Mon Sep 17 00:00:00 2001 From: James Byrne Date: Thu, 23 May 2024 08:05:36 +0100 Subject: [PATCH] Fixes #263: implemented a basic extension check to allow fully qualified dataset filenames and Dev #252: refactoring of existing training functionality to allow extension to use horovod for fully distributed training as a child implementation of the original tensorflow --- icenet/cli.py | 1 - icenet/model/cli.py | 183 ++++++------ icenet/model/models.py | 252 +---------------- icenet/model/networks/__init__.py | 0 icenet/model/networks/base.py | 92 ++++++ icenet/model/networks/tensorflow.py | 417 ++++++++++++++++++++++++++++ icenet/model/train.py | 320 ++++++--------------- icenet/model/utils.py | 49 +++- setup.py | 4 +- 9 files changed, 731 insertions(+), 587 deletions(-) create mode 100644 icenet/model/networks/__init__.py create mode 100644 icenet/model/networks/base.py create mode 100644 icenet/model/networks/tensorflow.py diff --git a/icenet/cli.py b/icenet/cli.py index 41986b31..6704cdf6 100644 --- a/icenet/cli.py +++ b/icenet/cli.py @@ -18,7 +18,6 @@ def wrapper(*args, **kwargs): logging.basicConfig( level=level, format=log_format, - datefmt="%d-%m-%y %T", ) # TODO: better way of handling these on a case by case basis diff --git a/icenet/model/cli.py b/icenet/model/cli.py index cfb53d24..70491526 100644 --- a/icenet/model/cli.py +++ b/icenet/model/cli.py @@ -1,90 +1,117 @@ import argparse +import logging import os from icenet.cli import setup_logging -@setup_logging -def train_args(): - """ +class TrainingArgParser(argparse.ArgumentParser): + """An ArgumentParser specialised to support model training - :return: + The 'allow_*' methods return self to permit method chaining. """ - ap = argparse.ArgumentParser() - ap.add_argument("dataset", type=str) - ap.add_argument("run_name", type=str) - ap.add_argument("seed", type=int) - - ap.add_argument("-b", "--batch-size", type=int, default=4) - ap.add_argument("-ca", - "--checkpoint-mode", - default="min", - type=str) - ap.add_argument("-cm", - "--checkpoint-monitor", - default="val_rmse", - type=str) - ap.add_argument("-ds", - "--additional-dataset", - dest="additional", - nargs="*", - default=[]) - ap.add_argument("-e", "--epochs", type=int, default=4) - ap.add_argument("-f", "--filter-size", type=int, default=3) - ap.add_argument("--early-stopping", type=int, default=50) - ap.add_argument("-m", - "--multiprocessing", - action="store_true", - default=False) - ap.add_argument("-n", "--n-filters-factor", type=float, default=1.) 
- ap.add_argument("-p", "--preload", type=str) - ap.add_argument("-pw", - "--pickup-weights", - action="store_true", - default=False) - ap.add_argument("-qs", "--max-queue-size", default=10, type=int) - ap.add_argument("-r", "--ratio", default=1.0, type=float) - - ap.add_argument("-s", - "--strategy", - default=None, - choices=("default", "mirrored", "central")) - - ap.add_argument("--shuffle-train", - default=False, - action="store_true", - help="Shuffle the training set") - ap.add_argument("-v", "--verbose", action="store_true", default=False) - ap.add_argument("-w", "--workers", type=int, default=4) - - # WandB additional arguments - ap.add_argument("-nw", "--no-wandb", default=False, action="store_true") - ap.add_argument("-wo", - "--wandb-offline", - default=False, - action="store_true") - ap.add_argument("-wp", - "--wandb-project", - default=os.environ.get("ICENET_ENVIRONMENT"), - type=str) - ap.add_argument("-wu", - "--wandb-user", - default=os.environ.get("USER"), - type=str) - - # Learning rate arguments - ap.add_argument("--lr", default=1e-4, type=float) - ap.add_argument("--lr_10e_decay_fac", - default=1.0, - type=float, - help="Factor by which LR is multiplied by every 10 epochs " - "using exponential decay. E.g. 1 -> no decay (default)" - ", 0.5 -> halve every 10 epochs.") - ap.add_argument('--lr_decay_start', default=10, type=int) - ap.add_argument('--lr_decay_end', default=30, type=int) - - return ap.parse_args() + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + self.add_argument("dataset", type=str) + self.add_argument("run_name", type=str) + self.add_argument("seed", type=int) + + self.add_argument("-o", "--output-path", type=str, default=None) + self.add_argument("-v", + "--verbose", + action="store_true", + default=False) + + self.add_argument("-b", "--batch-size", type=int, default=4) + self.add_argument("-ca", + "--checkpoint-mode", + default="min", + type=str) + self.add_argument("-cm", + "--checkpoint-monitor", + default="val_rmse", + type=str) + self.add_argument("-ds", + "--additional-dataset", + dest="additional", + nargs="*", + default=[]) + self.add_argument("-e", "--epochs", type=int, default=4) + self.add_argument("--early-stopping", type=int, default=50) + self.add_argument("-p", "--preload", type=str) + self.add_argument("-qs", "--max-queue-size", default=10, type=int) + self.add_argument("-r", "--ratio", default=1.0, type=float) + self.add_argument("--shuffle-train", + default=False, + action="store_true", + help="Shuffle the training set") + self.add_argument("--lr", default=1e-4, type=float) + self.add_argument("--lr_10e_decay_fac", + default=1.0, + type=float, + help="Factor by which LR is multiplied by every 10 epochs " + "using exponential decay. E.g. 1 -> no decay (default)" + ", 0.5 -> halve every 10 epochs.") + self.add_argument('--lr_decay_start', default=10, type=int) + self.add_argument('--lr_decay_end', default=30, type=int) + + def add_unet(self): + self.add_argument("-f", "--filter-size", type=int, default=3) + self.add_argument("-n", "--n-filters-factor", type=float, default=1.) 
+ return self + + def add_tensorflow(self): + # TODO: derive from available tf.distribute implementations + self.add_argument("-s", + "--strategy", + default=None, + choices=("default", "mirrored", "central")) + return self + + def add_horovod(self): + self.add_argument("-hv", + "--horovod", + default=False, + action="store_true") + self.add_argument("--device-type", + default=None, + help="Choose a device type for distribution, if using") + return self + + def add_wandb(self): + self.add_argument("-nw", "--no-wandb", default=False, action="store_true") + self.add_argument("-wo", + "--wandb-offline", + default=False, + action="store_true") + self.add_argument("-wp", + "--wandb-project", + default=os.environ.get("ICENET_ENVIRONMENT"), + type=str) + self.add_argument("-wu", + "--wandb-user", + default=os.environ.get("USER"), + type=str) + return self + + def parse_args(self, *args, + log_format="[%(asctime)-17s :%(levelname)-8s] - %(message)s", + **kwargs): + args = super().parse_args(*args, **kwargs) + + logging.basicConfig( + datefmt="%d-%m-%y %T", + format=log_format, + level=logging.DEBUG if args.verbose else logging.INFO) + logging.getLogger("cdsapi").setLevel(logging.WARNING) + logging.getLogger("matplotlib").setLevel(logging.WARNING) + logging.getLogger("matplotlib.pyplot").setLevel(logging.WARNING) + logging.getLogger("requests").setLevel(logging.WARNING) + logging.getLogger("tensorflow").setLevel(logging.WARNING) + logging.getLogger("urllib3").setLevel(logging.WARNING) + return args @setup_logging diff --git a/icenet/model/models.py b/icenet/model/models.py index d95d9c90..1222e303 100644 --- a/icenet/model/models.py +++ b/icenet/model/models.py @@ -1,258 +1,8 @@ import numpy as np -import tensorflow as tf -from tensorflow.keras.models import Model -from tensorflow.keras.layers import Conv2D, BatchNormalization, UpSampling2D, \ - concatenate, MaxPooling2D, Input -from tensorflow.keras.optimizers import Adam -""" -Defines the Python-based sea ice forecasting models, such as the IceNet architecture -and the linear trend extrapolation model. -""" - - -@tf.keras.utils.register_keras_serializable() -class TemperatureScale(tf.keras.layers.Layer): - """Temperature scaling layer - - Implements the temperature scaling layer for probability calibration, - as introduced in Guo 2017 (http://proceedings.mlr.press/v70/guo17a.html). - """ - - def __init__(self, **kwargs): - super(TemperatureScale, self).__init__(**kwargs) - self.temp = tf.Variable(initial_value=1.0, - trainable=False, - dtype=tf.float32, - name='temp') - - def call(self, inputs: object, **kwargs): - """ Divide the input logits by the T value. - - :param **kwargs: - :param inputs: - :return: - """ - return tf.divide(inputs, self.temp) - - def get_config(self): - """ For saving and loading networks with this custom layer. 
- - :return: - """ - return {'temp': self.temp.numpy()} - - -### Network architectures: -# -------------------------------------------------------------------- - - -def unet_batchnorm(input_shape: object, - loss: object, - metrics: object, - learning_rate: float = 1e-4, - custom_optimizer: object = None, - experimental_run_tf_function: bool = True, - filter_size: float = 3, - n_filters_factor: float = 1, - n_forecast_days: int = 1, - legacy_rounding: bool = False) -> object: - """ - - :param input_shape: - :param loss: - :param metrics: - :param learning_rate: - :param custom_optimizer: - :param experimental_run_tf_function: - :param filter_size: - :param n_filters_factor: - :param n_forecast_days: - :param legacy_rounding: Ensures filter number calculations are int()'d at the end of calculations - :return: - """ - inputs = Input(shape=input_shape) - - start_out_channels = 64 - reduced_channels = start_out_channels * n_filters_factor - - if not legacy_rounding: - # We're assuming to just strip off any partial channels, rather than round - reduced_channels = int(reduced_channels) - - channels = { - start_out_channels * 2 ** pow: - reduced_channels * 2 ** pow if not legacy_rounding else int(reduced_channels * 2 ** pow) - for pow in range(4) - } - - conv1 = Conv2D(channels[64], - filter_size, - activation='relu', - padding='same', - kernel_initializer='he_normal')(inputs) - conv1 = Conv2D(channels[64], - filter_size, - activation='relu', - padding='same', - kernel_initializer='he_normal')(conv1) - bn1 = BatchNormalization(axis=-1)(conv1) - pool1 = MaxPooling2D(pool_size=(2, 2))(bn1) - - conv2 = Conv2D(channels[128], - filter_size, - activation='relu', - padding='same', - kernel_initializer='he_normal')(pool1) - conv2 = Conv2D(channels[128], - filter_size, - activation='relu', - padding='same', - kernel_initializer='he_normal')(conv2) - bn2 = BatchNormalization(axis=-1)(conv2) - pool2 = MaxPooling2D(pool_size=(2, 2))(bn2) - - conv3 = Conv2D(channels[256], - filter_size, - activation='relu', - padding='same', - kernel_initializer='he_normal')(pool2) - conv3 = Conv2D(channels[256], - filter_size, - activation='relu', - padding='same', - kernel_initializer='he_normal')(conv3) - bn3 = BatchNormalization(axis=-1)(conv3) - pool3 = MaxPooling2D(pool_size=(2, 2))(bn3) - - conv4 = Conv2D(channels[256], - filter_size, - activation='relu', - padding='same', - kernel_initializer='he_normal')(pool3) - conv4 = Conv2D(channels[256], - filter_size, - activation='relu', - padding='same', - kernel_initializer='he_normal')(conv4) - bn4 = BatchNormalization(axis=-1)(conv4) - pool4 = MaxPooling2D(pool_size=(2, 2))(bn4) - - conv5 = Conv2D(channels[512], - filter_size, - activation='relu', - padding='same', - kernel_initializer='he_normal')(pool4) - conv5 = Conv2D(channels[512], - filter_size, - activation='relu', - padding='same', - kernel_initializer='he_normal')(conv5) - bn5 = BatchNormalization(axis=-1)(conv5) - - up6 = Conv2D(channels[256], - 2, - activation='relu', - padding='same', - kernel_initializer='he_normal')(UpSampling2D( - size=(2, 2), interpolation='nearest')(bn5)) - - merge6 = concatenate([bn4, up6], axis=3) - conv6 = Conv2D(channels[256], - filter_size, - activation='relu', - padding='same', - kernel_initializer='he_normal')(merge6) - conv6 = Conv2D(channels[256], - filter_size, - activation='relu', - padding='same', - kernel_initializer='he_normal')(conv6) - bn6 = BatchNormalization(axis=-1)(conv6) - - up7 = Conv2D(channels[256], - 2, - activation='relu', - padding='same', - 
kernel_initializer='he_normal')(UpSampling2D( - size=(2, 2), interpolation='nearest')(bn6)) - merge7 = concatenate([bn3, up7], axis=3) - conv7 = Conv2D(channels[256], - filter_size, - activation='relu', - padding='same', - kernel_initializer='he_normal')(merge7) - conv7 = Conv2D(channels[256], - filter_size, - activation='relu', - padding='same', - kernel_initializer='he_normal')(conv7) - bn7 = BatchNormalization(axis=-1)(conv7) - - up8 = Conv2D(channels[128], - 2, - activation='relu', - padding='same', - kernel_initializer='he_normal')(UpSampling2D( - size=(2, 2), interpolation='nearest')(bn7)) - merge8 = concatenate([bn2, up8], axis=3) - conv8 = Conv2D(channels[128], - filter_size, - activation='relu', - padding='same', - kernel_initializer='he_normal')(merge8) - conv8 = Conv2D(channels[128], - filter_size, - activation='relu', - padding='same', - kernel_initializer='he_normal')(conv8) - bn8 = BatchNormalization(axis=-1)(conv8) - - up9 = Conv2D(channels[64], - 2, - activation='relu', - padding='same', - kernel_initializer='he_normal')(UpSampling2D( - size=(2, 2), interpolation='nearest')(bn8)) - - merge9 = concatenate([conv1, up9], axis=3) - - conv9 = Conv2D(channels[64], - filter_size, - activation='relu', - padding='same', - kernel_initializer='he_normal')(merge9) - conv9 = Conv2D(channels[64], - filter_size, - activation='relu', - padding='same', - kernel_initializer='he_normal')(conv9) - conv9 = Conv2D(channels[64], - filter_size, - activation='relu', - padding='same', - kernel_initializer='he_normal')(conv9) - - final_layer = Conv2D(n_forecast_days, kernel_size=1, - activation='sigmoid')(conv9) - - # Keras graph mode needs y_pred and y_true to have the same shape, so we - # we must pad an extra dimension onto the model output to train with - # an extra sample weight dimension in y_true. 
- # final_layer = tf.expand_dims(final_layer, axis=-1) - - model = Model(inputs, final_layer) - - model.compile(optimizer=Adam(learning_rate=learning_rate) - if custom_optimizer is None else custom_optimizer, - loss=loss, - weighted_metrics=metrics, - experimental_run_tf_function=experimental_run_tf_function) - - return model def linear_trend_forecast( - usable_selector: object, + usable_selector: callable, forecast_date: object, da: object, mask: object, diff --git a/icenet/model/networks/__init__.py b/icenet/model/networks/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/icenet/model/networks/base.py b/icenet/model/networks/base.py new file mode 100644 index 00000000..4c342b3d --- /dev/null +++ b/icenet/model/networks/base.py @@ -0,0 +1,92 @@ +import logging +import os +import random + +import numpy as np + +from abc import abstractmethod + + +class BaseNetwork: + def __init__(self, + dataset: object, + run_name: object, + callbacks_additional: list = None, + callbacks_default: list = None, + network_folder: object = None, + seed: int = 42): + + if not network_folder: + self._network_folder = os.path.join(".", "results", "networks", run_name) + + if not os.path.exists(self._network_folder): + logging.info("Creating network folder: {}".format(network_folder)) + os.makedirs(self._network_folder, exist_ok=True) + + self._model_path = os.path.join( + self._network_folder, "{}.model_{}.{}".format(run_name, + dataset.identifier, + seed)) + + self._callbacks = list() if callbacks_default is None else self.get_callbacks() + self._callbacks += callbacks_additional if callbacks_additional is not None else [] + self._dataset = dataset + self._run_name = run_name + self._seed = seed + + self._attempt_seed_setup() + + def _attempt_seed_setup(self): + logging.warning( + "Setting seed for best attempt at determinism, value {}".format(self._seed)) + # determinism is not guaranteed across different versions of TensorFlow. + # determinism is not guaranteed across different hardware. + os.environ['PYTHONHASHSEED'] = str(self._seed) + # numpy.random.default_rng ignores this, WARNING! 
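+        # i.e. code constructing a generator via np.random.default_rng() must be handed self._seed explicitly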
+ np.random.seed(self._seed) + random.seed(self._seed) + + def add_callback(self, callback): + self._callbacks.append(callback) + + def get_callbacks(self): + return list() + + @abstractmethod + def train(self, + dataset: object, + epochs: int, + model_creator: callable, + train_dataset: object, + model_creator_kwargs: dict = None, + save: bool = True): + raise NotImplementedError("Implementation not found") + + @abstractmethod + def predict(self): + raise NotImplementedError("Implementation not found") + + @property + def callbacks(self): + return self._callbacks + + @property + def dataset(self): + return self._dataset + + @property + def model_path(self): + return self._model_path + + @property + def network_folder(self): + return self._network_folder + + @property + def run_name(self): + return self._run_name + + @property + def seed(self): + return self._seed + diff --git a/icenet/model/networks/tensorflow.py b/icenet/model/networks/tensorflow.py new file mode 100644 index 00000000..55746a85 --- /dev/null +++ b/icenet/model/networks/tensorflow.py @@ -0,0 +1,417 @@ +import datetime as dt +import logging +import os + +from icenet.model.networks.base import BaseNetwork +from icenet.model.utils import make_exp_decay_lr_schedule + +import numpy as np +import pandas as pd +import tensorflow as tf + +from tensorflow.keras.callbacks import \ + EarlyStopping, ModelCheckpoint, LearningRateScheduler +from tensorflow.keras.layers import Conv2D, BatchNormalization, UpSampling2D, \ + concatenate, MaxPooling2D, Input +from tensorflow.keras.models import save_model, Model +from tensorflow.keras.optimizers import Adam + + +class TensorflowNetwork(BaseNetwork): + def __init__(self, + *args, + checkpoint_mode: str = "min", + checkpoint_monitor: str = None, + early_stopping_patience: int = 0, + data_queue_size: int = 10, + lr_decay: tuple = (0, 0, 0), + pre_load_path: str = None, + strategy: str = None, + tensorboard_logdir: str = None, + verbose: bool = False, + **kwargs): + super().__init__(*args, **kwargs) + + self._checkpoint_mode = checkpoint_mode + self._checkpoint_monitor = checkpoint_monitor + self._data_queue_size = data_queue_size + self._early_stopping_patience = early_stopping_patience + self._lr_decay = lr_decay + self._tensorboard_logdir = tensorboard_logdir + self._strategy = strategy + self._verbose = verbose + + if pre_load_path is not None and not os.path.exists(pre_load_path): + raise RuntimeError("{} is not available, so you cannot preload the " + "network with it!".format(pre_load_path)) + self._pre_load_path = pre_load_path + + self._weights_path = os.path.join( + self.network_folder, "{}.network_{}.{}.h5".format( + self.run_name, self.dataset.identifier, self.seed)) + + def _attempt_seed_setup(self): + super()._attempt_seed_setup() + tf.random.set_seed(self._seed) + tf.keras.utils.set_random_seed(self._seed) + # See #8: tf.config.experimental.enable_op_determinism() + + def train(self, + epochs: int, + model_creator: callable, + train_dataset: object, + model_creator_kwargs: dict = None, + save: bool = True, + validation_dataset: object = None): + + strategy = tf.distribute.MirroredStrategy() \ + if self._strategy == "mirrored" \ + else tf.distribute.experimental.CentralStorageStrategy() \ + if self._strategy == "central" \ + else tf.distribute.get_strategy() + + history_path = os.path.join(self.network_folder, + "{}_{}_history.json".format( + self.run_name, self.seed)) + + with strategy.scope(): + network = model_creator(**model_creator_kwargs) + + if self._pre_load_path and 
os.path.exists(self._pre_load_path): + logging.warning("Automagically loading network weights from {}".format( + self._pre_load_path)) + network.load_weights(self._pre_load_path) + + network.summary() + + model_history = network.fit( + train_dataset, + epochs=epochs, + verbose=self._verbose, + callbacks=self.callbacks, + validation_data=validation_dataset, + max_queue_size=self._data_queue_size, + ) + + if save: + logging.info("Saving network to: {}".format(self._weights_path)) + network.save_weights(self._weights_path) + save_model(network, self.model_path) + + with open(history_path, 'w') as fh: + pd.DataFrame(model_history.history).to_json(fh) + + def get_callbacks(self): + callbacks_list = list() + + if self._checkpoint_monitor is not None: + callbacks_list.append( + ModelCheckpoint(filepath=self._weights_path, + monitor=self._checkpoint_monitor, + verbose=1, + mode=self._checkpoint_mode, + save_best_only=True)) + + if self._early_stopping_patience > 0: + callbacks_list.append( + EarlyStopping(monitor=self._checkpoint_monitor, + mode=self._checkpoint_mode, + verbose=1, + patience=self._early_stopping_patience, + baseline=None)) + + if self._lr_decay[0] > 0: + lr_decay = -0.1 * np.log(self._lr_decay[0]) + + callbacks_list.append( + LearningRateScheduler( + make_exp_decay_lr_schedule( + rate=lr_decay, + start_epoch=self._lr_decay[1], + end_epoch=self._lr_decay[2], + ))) + + if self._tensorboard_logdir is not None: + logging.info("Adding tensorboard callback") + log_dir = os.path.join( + self._tensorboard_logdir, + dt.datetime.now().strftime("%d-%m-%y-%H%M%S")) + callbacks_list.append( + tf.keras.callbacks.TensorBoard(log_dir=log_dir, + histogram_freq=1)) + return callbacks_list + + +class HorovodNetwork(TensorflowNetwork): + def __init__(self, + *args, + device_type="XPU", + **kwargs): + super().__init__(*args, **kwargs) + import horovod.tensorflow.keras as hvd + hvd.init() + + gpus = tf.config.list_physical_devices(device_type) + logging.info("{} count is {}".format(device_type, len(gpus))) + + for gpu in gpus: + tf.config.experimental.set_memory_growth(gpu, True) + if gpus: + tf.config.experimental.set_visible_devices(gpus[hvd.local_rank()], 'XPU') + + self.add_callback( + hvd.callbacks.BroadcastGlobalVariablesCallback(0) + ) + self._horovod = hvd + + def train(self, + epochs: int, + learning_rate: float, + loss: object, + metrics: object, + model_creator: callable, + train_dataset: object, + model_creator_args: dict = None, + save: bool = True, + validation_dataset: object = None): + + history_path = os.path.join(self.network_folder, + "{}_{}_history.json".format( + self.run_name, self.seed)) + + # TODO: this is totally assuming the structure of model_creator :( + network = model_creator(**model_creator_args, + custom_optimizer=self._horovod.DistributedOptimizer(Adam(learning_rate)), + experimental_run_tf_function=False) + + if self._pre_load_path and os.path.exists(self._pre_load_path): + logging.warning("Automagically loading network weights from {}".format( + self._pre_load_path)) + network.load_weights(self._pre_load_path) + + network.summary() + + model_history = network.fit( + train_dataset, + epochs=epochs, + verbose=1 if self._horovod.rank() == 0 and self._verbose else 0, + callbacks=self.callbacks, + validation_data=validation_dataset, + max_queue_size=self._data_queue_size, + steps_per_epoch=self.dataset.counts["train"] // (self.dataset.batch_size * self._horovod.size()), + ) + + if save: + logging.info("Saving network to: {}".format(self._weights_path)) + 
network.save_weights(self._weights_path) + save_model(network, self.model_path) + + with open(history_path, 'w') as fh: + pd.DataFrame(model_history.history).to_json(fh) + + +### Network architectures: +# -------------------------------------------------------------------- +def unet_batchnorm(input_shape: object, + loss: object, + metrics: object, + learning_rate: float = 1e-4, + custom_optimizer: object = None, + experimental_run_tf_function: bool = True, + filter_size: float = 3, + n_filters_factor: float = 1, + n_forecast_days: int = 1, + legacy_rounding: bool = False) -> object: + """ + + :param input_shape: + :param loss: + :param metrics: + :param learning_rate: + :param custom_optimizer: + :param experimental_run_tf_function: + :param filter_size: + :param n_filters_factor: + :param n_forecast_days: + :param legacy_rounding: Ensures filter number calculations are int()'d at the end of calculations + :return: + """ + inputs = Input(shape=input_shape) + + start_out_channels = 64 + reduced_channels = start_out_channels * n_filters_factor + + if not legacy_rounding: + # We're assuming to just strip off any partial channels, rather than round + reduced_channels = int(reduced_channels) + + channels = { + start_out_channels * 2 ** pow: + reduced_channels * 2 ** pow if not legacy_rounding else int(reduced_channels * 2 ** pow) + for pow in range(4) + } + + conv1 = Conv2D(channels[64], + filter_size, + activation='relu', + padding='same', + kernel_initializer='he_normal')(inputs) + conv1 = Conv2D(channels[64], + filter_size, + activation='relu', + padding='same', + kernel_initializer='he_normal')(conv1) + bn1 = BatchNormalization(axis=-1)(conv1) + pool1 = MaxPooling2D(pool_size=(2, 2))(bn1) + + conv2 = Conv2D(channels[128], + filter_size, + activation='relu', + padding='same', + kernel_initializer='he_normal')(pool1) + conv2 = Conv2D(channels[128], + filter_size, + activation='relu', + padding='same', + kernel_initializer='he_normal')(conv2) + bn2 = BatchNormalization(axis=-1)(conv2) + pool2 = MaxPooling2D(pool_size=(2, 2))(bn2) + + conv3 = Conv2D(channels[256], + filter_size, + activation='relu', + padding='same', + kernel_initializer='he_normal')(pool2) + conv3 = Conv2D(channels[256], + filter_size, + activation='relu', + padding='same', + kernel_initializer='he_normal')(conv3) + bn3 = BatchNormalization(axis=-1)(conv3) + pool3 = MaxPooling2D(pool_size=(2, 2))(bn3) + + conv4 = Conv2D(channels[256], + filter_size, + activation='relu', + padding='same', + kernel_initializer='he_normal')(pool3) + conv4 = Conv2D(channels[256], + filter_size, + activation='relu', + padding='same', + kernel_initializer='he_normal')(conv4) + bn4 = BatchNormalization(axis=-1)(conv4) + pool4 = MaxPooling2D(pool_size=(2, 2))(bn4) + + conv5 = Conv2D(channels[512], + filter_size, + activation='relu', + padding='same', + kernel_initializer='he_normal')(pool4) + conv5 = Conv2D(channels[512], + filter_size, + activation='relu', + padding='same', + kernel_initializer='he_normal')(conv5) + bn5 = BatchNormalization(axis=-1)(conv5) + + up6 = Conv2D(channels[256], + 2, + activation='relu', + padding='same', + kernel_initializer='he_normal')(UpSampling2D( + size=(2, 2), interpolation='nearest')(bn5)) + + merge6 = concatenate([bn4, up6], axis=3) + conv6 = Conv2D(channels[256], + filter_size, + activation='relu', + padding='same', + kernel_initializer='he_normal')(merge6) + conv6 = Conv2D(channels[256], + filter_size, + activation='relu', + padding='same', + kernel_initializer='he_normal')(conv6) + bn6 = 
BatchNormalization(axis=-1)(conv6) + + up7 = Conv2D(channels[256], + 2, + activation='relu', + padding='same', + kernel_initializer='he_normal')(UpSampling2D( + size=(2, 2), interpolation='nearest')(bn6)) + merge7 = concatenate([bn3, up7], axis=3) + conv7 = Conv2D(channels[256], + filter_size, + activation='relu', + padding='same', + kernel_initializer='he_normal')(merge7) + conv7 = Conv2D(channels[256], + filter_size, + activation='relu', + padding='same', + kernel_initializer='he_normal')(conv7) + bn7 = BatchNormalization(axis=-1)(conv7) + + up8 = Conv2D(channels[128], + 2, + activation='relu', + padding='same', + kernel_initializer='he_normal')(UpSampling2D( + size=(2, 2), interpolation='nearest')(bn7)) + merge8 = concatenate([bn2, up8], axis=3) + conv8 = Conv2D(channels[128], + filter_size, + activation='relu', + padding='same', + kernel_initializer='he_normal')(merge8) + conv8 = Conv2D(channels[128], + filter_size, + activation='relu', + padding='same', + kernel_initializer='he_normal')(conv8) + bn8 = BatchNormalization(axis=-1)(conv8) + + up9 = Conv2D(channels[64], + 2, + activation='relu', + padding='same', + kernel_initializer='he_normal')(UpSampling2D( + size=(2, 2), interpolation='nearest')(bn8)) + + merge9 = concatenate([conv1, up9], axis=3) + + conv9 = Conv2D(channels[64], + filter_size, + activation='relu', + padding='same', + kernel_initializer='he_normal')(merge9) + conv9 = Conv2D(channels[64], + filter_size, + activation='relu', + padding='same', + kernel_initializer='he_normal')(conv9) + conv9 = Conv2D(channels[64], + filter_size, + activation='relu', + padding='same', + kernel_initializer='he_normal')(conv9) + + final_layer = Conv2D(n_forecast_days, kernel_size=1, + activation='sigmoid')(conv9) + + # Keras graph mode needs y_pred and y_true to have the same shape, so we + # we must pad an extra dimension onto the model output to train with + # an extra sample weight dimension in y_true. 
+ # final_layer = tf.expand_dims(final_layer, axis=-1) + + model = Model(inputs, final_layer) + + model.compile(optimizer=Adam(learning_rate=learning_rate) + if custom_optimizer is None else custom_optimizer, + loss=loss, + weighted_metrics=metrics, + experimental_run_tf_function=experimental_run_tf_function) + + return model diff --git a/icenet/model/train.py b/icenet/model/train.py index 51191a7b..f0cd2c4b 100644 --- a/icenet/model/train.py +++ b/icenet/model/train.py @@ -1,215 +1,29 @@ -import datetime as dt import json import logging -import os import time -import numpy as np -import pandas as pd import tensorflow as tf -from tensorflow.keras.callbacks import \ - EarlyStopping, ModelCheckpoint, LearningRateScheduler -from tensorflow.keras.models import load_model, save_model - from icenet.data.dataset import IceNetDataSet, MergedIceNetDataSet -from icenet.model.cli import train_args -import icenet.model.losses as losses -import icenet.model.metrics as metrics -from icenet.model.utils import attempt_seed_setup, make_exp_decay_lr_schedule -import icenet.model.models as models - - -def train_model(run_name: object, - dataset: object, - callback_objects: list = [], - checkpoint_monitor: str = 'val_rmse', - checkpoint_mode: str = 'min', - dataset_ratio: float = 1.0, - early_stopping_patience: int = 30, - epochs: int = 2, - filter_size: float = 3, - learning_rate: float = 1e-4, - lr_10e_decay_fac: float = 1.0, - lr_decay_start: float = 10, - lr_decay_end: float = 30, - max_queue_size: int = 3, - model_func: object = models.unet_batchnorm, - n_filters_factor: float = 2, - network_folder: object = None, - network_save: bool = True, - pickup_weights: bool = False, - pre_load_network: bool = False, - pre_load_path: object = None, - seed: int = 42, - strategy: object = tf.distribute.get_strategy(), - training_verbosity: int = 1, - workers: int = 5, - use_multiprocessing: bool = True, - use_tensorboard: bool = True) -> object: - """ - - :param run_name: - :param dataset: - :param callback_objects: - :param checkpoint_monitor: - :param checkpoint_mode: - :param dataset_ratio: - :param early_stopping_patience: - :param epochs: - :param filter_size: - :param learning_rate: - :param lr_10e_decay_fac: - :param lr_decay_start: - :param lr_decay_end: - :param max_queue_size: - :param model_func: - :param n_filters_factor: - :param network_folder: - :param network_save: - :param pickup_weights: - :param pre_load_network: - :param pre_load_path: - :param seed: - :param strategy: - :param training_verbosity: - :param workers: - :param use_multiprocessing: - :param use_tensorboard: - :return: - """ +from icenet.model.cli import TrainingArgParser +from icenet.model.networks.tensorflow import HorovodNetwork, TensorflowNetwork, unet_batchnorm - lr_decay = -0.1 * np.log(lr_10e_decay_fac) +from tensorflow.keras.models import load_model - input_shape = (*dataset.shape, dataset.num_channels) - - if pre_load_network and not os.path.exists(pre_load_path): - raise RuntimeError("{} is not available, so you cannot preload the " - "network with it!".format(pre_load_path)) - - if not network_folder: - network_folder = os.path.join(".", "results", "networks", run_name) - - if not os.path.exists(network_folder): - logging.info("Creating network folder: {}".format(network_folder)) - os.makedirs(network_folder, exist_ok=True) - - weights_path = os.path.join( - network_folder, "{}.network_{}.{}.h5".format(run_name, - dataset.identifier, seed)) - model_path = os.path.join( - network_folder, "{}.model_{}.{}".format(run_name, 
dataset.identifier, - seed)) - - history_path = os.path.join(network_folder, - "{}_{}_history.json".format(run_name, seed)) - - prev_best = None - callbacks_list = list() - - # Checkpoint the model weights when a validation metric is improved - callbacks_list.append( - ModelCheckpoint(filepath=weights_path, - monitor=checkpoint_monitor, - verbose=1, - mode=checkpoint_mode, - save_best_only=True)) - - # Abort training when validation performance stops improving - callbacks_list.append( - EarlyStopping(monitor=checkpoint_monitor, - mode=checkpoint_mode, - verbose=1, - patience=early_stopping_patience, - baseline=prev_best)) - - callbacks_list.append( - LearningRateScheduler( - make_exp_decay_lr_schedule( - rate=lr_decay, - start_epoch=lr_decay_start, - end_epoch=lr_decay_end, - ))) - - if use_tensorboard: - logging.info("Adding tensorboard callback") - log_dir = "logs/" + dt.datetime.now().strftime("%d-%m-%y-%H%M%S") - callbacks_list.append( - tf.keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1)) - - ############################################################################ - # TRAINING MODEL - ############################################################################ - - with strategy.scope(): - loss = losses.WeightedMSE() - metrics_list = [ - # metrics.weighted_MAE, - metrics.WeightedBinaryAccuracy(), - metrics.WeightedMAE(), - metrics.WeightedRMSE(), - losses.WeightedMSE() - ] - - network = model_func( - input_shape=input_shape, - loss=loss, - metrics=metrics_list, - learning_rate=learning_rate, - filter_size=filter_size, - n_filters_factor=n_filters_factor, - n_forecast_days=dataset.n_forecast_days, - ) - - if pre_load_network: - logging.info("Loading network weights from {}".format(pre_load_path)) - network.load_weights(pre_load_path) - elif pickup_weights and os.path.exists(weights_path): - logging.warning("Automagically loading network weights from {}".format( - weights_path)) - network.load_weights(weights_path) - - network.summary() - - ratio = dataset_ratio if dataset_ratio else 1.0 - train_ds, val_ds, test_ds = dataset.get_split_datasets(ratio=ratio) - - model_history = network.fit( - train_ds, - epochs=epochs, - verbose=training_verbosity, - callbacks=callbacks_list + callback_objects, - validation_data=val_ds, - max_queue_size=max_queue_size, - # not useful for tf.data usage according to docs, but useful in dev - workers=workers, - use_multiprocessing=use_multiprocessing) - - if network_save: - logging.info("Saving network to: {}".format(weights_path)) - network.save_weights(weights_path) - save_model(network, model_path) - - with open(history_path, 'w') as fh: - pd.DataFrame(model_history.history).to_json(fh) - - return weights_path, model_path +import icenet.model.losses as losses +import icenet.model.metrics as metrics def evaluate_model(model_path: object, dataset: object, dataset_ratio: float = 1.0, - max_queue_size: int = 3, - workers: int = 5, - use_multiprocessing: bool = True): + max_queue_size: int = 3): """ :param model_path: :param dataset: :param dataset_ratio: :param max_queue_size: - :param workers: - :param use_multiprocessing: """ logging.info("Running evaluation against test set") network = load_model(model_path, compile=False) @@ -227,6 +41,7 @@ def evaluate_model(model_path: object, lead_times = list(range(1, dataset.n_forecast_days + 1)) logging.info("Metric creation for lead time of {} days".format( len(lead_times))) + # TODO: common across train_model and evaluate_model - list of instantiations metric_names = ["binacc", "mae", "rmse"] 
metrics_classes = [ metrics.WeightedBinaryAccuracy, @@ -246,9 +61,7 @@ def evaluate_model(model_path: object, eval_data, return_dict=True, verbose=0, - max_queue_size=max_queue_size, - workers=workers, - use_multiprocessing=use_multiprocessing, + max_queue_size=max_queue_size ) results_path = "{}.results.json".format(model_path) @@ -261,77 +74,100 @@ def evaluate_model(model_path: object, return results, metric_names, lead_times -def main(): - args = train_args() - attempt_seed_setup(args.seed) - +def get_datasets(args): # TODO: this should come from a factory in the future - not the only place # that merged datasets are going to be available + + dataset_filenames = [ + el if str(el).split(".")[-1] == "json" else "dataset_config.{}.json".format(el) + for el in [args.dataset, *args.additional] + ] + if len(args.additional) == 0: - dataset = IceNetDataSet("dataset_config.{}.json".format(args.dataset), + dataset = IceNetDataSet(dataset_filenames[0], batch_size=args.batch_size, shuffling=args.shuffle_train) else: - dataset = MergedIceNetDataSet([ - "dataset_config.{}.json".format(el) - for el in [args.dataset, *args.additional] - ], + dataset = MergedIceNetDataSet(dataset_filenames, batch_size=args.batch_size, shuffling=args.shuffle_train) - - strategy = tf.distribute.MirroredStrategy() \ - if args.strategy == "mirrored" \ - else tf.distribute.experimental.CentralStorageStrategy() \ - if args.strategy == "central" \ - else tf.distribute.get_strategy() - + return dataset + + +def horovod_main(): + args = TrainingArgParser().add_unet().add_horovod().add_wandb().parse_args() + dataset = get_datasets() + network = HorovodNetwork() + execute_tf_training(args, dataset, network) + + +def tensorflow_main(): + args = TrainingArgParser().add_unet().add_tensorflow().add_wandb().parse_args() + dataset = get_datasets(args) + network = TensorflowNetwork(dataset, + args.run_name, + checkpoint_mode=args.checkpoint_mode, + checkpoint_monitor=args.checkpoint_monitor, + early_stopping_patience=args.early_stopping, + data_queue_size=args.max_queue_size, + lr_decay=( + args.lr_10e_decay_fac, + args.lr_decay_start, + args.lr_decay_end, + ), + pre_load_path=args.preload, + seed=args.seed, + strategy=args.strategy, + verbose=args.verbose) + execute_tf_training(args, dataset, network) + + +def execute_tf_training(args, dataset, network): # There is a better way of doing this by passing off to a dynamic factory # for other integrations, but for the moment I have no shame - callback_objects = list() using_wandb = False run = None - # TODO: this can and probably should be a decorator + # TODO: move to overridden implementation - decorator? 
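+    # wandb integration is optional: init_wandb returns a (run, callback) pair, and any
+    # returned callback is registered on the network so metrics are reported during fit()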
if not args.no_wandb: from icenet.model.handlers.wandb import init_wandb, finalise_wandb run, callback = init_wandb(args) if callback is not None: - callback_objects.append(callback) + network.add_callback(callback) using_wandb = True - weights_path, model_path = \ - train_model(args.run_name, - dataset, - callback_objects=callback_objects, - checkpoint_mode=args.checkpoint_mode, - checkpoint_monitor=args.checkpoint_monitor, - dataset_ratio=args.ratio, - early_stopping_patience=args.early_stopping, - epochs=args.epochs, - filter_size=args.filter_size, - learning_rate=args.lr, - lr_10e_decay_fac=args.lr_10e_decay_fac, - lr_decay_start=args.lr_decay_start, - lr_decay_end=args.lr_decay_end, - pickup_weights=args.pickup_weights, - pre_load_network=args.preload is not None, - pre_load_path=args.preload, - max_queue_size=args.max_queue_size, - n_filters_factor=args.n_filters_factor, - seed=args.seed, - strategy=strategy, - training_verbosity=1 if args.verbose else 2, - use_multiprocessing=args.multiprocessing, - workers=args.workers) + input_shape = (*dataset.shape, dataset.num_channels) + ratio = args.ratio if args.ratio else 1.0 + train_ds, val_ds, _ = dataset.get_split_datasets(ratio=ratio) + + network.train( + args.epochs, + unet_batchnorm, + train_ds, + model_creator_kwargs=dict( + input_shape=input_shape, + loss=losses.WeightedMSE(), + metrics=[ + metrics.WeightedBinaryAccuracy(), + metrics.WeightedMAE(), + metrics.WeightedRMSE(), + losses.WeightedMSE() + ], + learning_rate=args.lr, + filter_size=args.filter_size, + n_filters_factor=args.n_filters_factor, + n_forecast_days=dataset.n_forecast_days, + ), + save=True, + validation_dataset=val_ds + ) results, metric_names, leads = \ - evaluate_model(model_path, + evaluate_model(network.model_path, dataset, dataset_ratio=args.ratio, - max_queue_size=args.max_queue_size, - use_multiprocessing=args.multiprocessing, - workers=args.workers) + max_queue_size=args.max_queue_size) if using_wandb: finalise_wandb(run, results, metric_names, leads) diff --git a/icenet/model/utils.py b/icenet/model/utils.py index 3367fdb0..05f44ae5 100644 --- a/icenet/model/utils.py +++ b/icenet/model/utils.py @@ -7,20 +7,6 @@ import tensorflow as tf -def attempt_seed_setup(seed): - logging.warning( - "Setting seed for best attempt at determinism, value {}".format(seed)) - # determinism is not guaranteed across different versions of TensorFlow. - # determinism is not guaranteed across different hardware. - os.environ['PYTHONHASHSEED'] = str(seed) - # numpy.random.default_rng ignores this, WARNING! - np.random.seed(seed) - random.seed(seed) - tf.random.set_seed(seed) - tf.keras.utils.set_random_seed(seed) - # See #8: tf.config.experimental.enable_op_determinism() - - ################################################################################ # LEARNING RATE ################################################################################ @@ -145,3 +131,38 @@ def arr_to_ice_edge_rgba_arr(arr: object, thresh: object, land_mask: object, ice_edge_rgba_arr[:, :, :3] = rgb return ice_edge_rgba_arr + + +### Potentially redundant implementations + + +@tf.keras.utils.register_keras_serializable() +class TemperatureScale(tf.keras.layers.Layer): + """Temperature scaling layer + + Implements the temperature scaling layer for probability calibration, + as introduced in Guo 2017 (http://proceedings.mlr.press/v70/guo17a.html). 
+ """ + + def __init__(self, **kwargs): + super(TemperatureScale, self).__init__(**kwargs) + self.temp = tf.Variable(initial_value=1.0, + trainable=False, + dtype=tf.float32, + name='temp') + + def call(self, inputs: object, **kwargs): + """ Divide the input logits by the T value. + + :param **kwargs: + :param inputs: + :return: + """ + return tf.divide(inputs, self.temp) + + def get_config(self): + """ For saving and loading networks with this custom layer. + + :return: + """ + return {'temp': self.temp.numpy()} diff --git a/setup.py b/setup.py index e12aaaf7..cdcdd570 100644 --- a/setup.py +++ b/setup.py @@ -66,7 +66,9 @@ def get_content(filename): "icenet_dataset_check = icenet.data.dataset:check_dataset", "icenet_dataset_create = icenet.data.loader:create", - "icenet_train = icenet.model.train:main", + "icenet_train_horovod = icenet.model.train:horovod_main", + "icenet_train_tensorflow = icenet.model.train:tensorflow_main", + "icenet_predict = icenet.model.predict:main", "icenet_upload_azure = icenet.process.azure:upload", "icenet_upload_local = icenet.process.local:upload",