From 347fbc6ef0aef6aa6dc07fce0ddcfa1c49c82ccc Mon Sep 17 00:00:00 2001 From: James Byrne Date: Thu, 23 May 2024 16:04:23 +0100 Subject: [PATCH] Dev #252: small runtime fixes --- icenet/model/networks/tensorflow.py | 3 ++- icenet/model/train.py | 3 +-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/icenet/model/networks/tensorflow.py b/icenet/model/networks/tensorflow.py index 565f58e..a759079 100644 --- a/icenet/model/networks/tensorflow.py +++ b/icenet/model/networks/tensorflow.py @@ -90,6 +90,7 @@ def train(self, verbose=self._verbose, callbacks=self.callbacks, validation_data=validation_dataset, + # TODO: pretty sure this is redundant for non-keras.utils.Sequence, legacy inclusion! max_queue_size=self._data_queue_size, ) @@ -174,7 +175,7 @@ def train(self, self._pre_load_path)) network.load_weights(self._pre_load_path) - if model_creator_kwargs["horovod"].rank() == 0: + if hvd.local_rank() == 0: network.summary() logging.debug("Calling training loop") diff --git a/icenet/model/train.py b/icenet/model/train.py index 1db56c7..f916742 100644 --- a/icenet/model/train.py +++ b/icenet/model/train.py @@ -4,7 +4,6 @@ import tensorflow as tf import horovod.tensorflow.keras as hvd -hvd.init() from icenet.data.dataset import IceNetDataSet, MergedIceNetDataSet from icenet.model.cli import TrainingArgParser @@ -98,6 +97,7 @@ def get_datasets(args): def horovod_main(): args = TrainingArgParser().add_unet().add_horovod().add_wandb().parse_args() + hvd.init() if args.device_type in ("XPU", "GPU"): logging.debug("Setting up {} devices".format(args.device_type)) @@ -115,7 +115,6 @@ def horovod_main(): args.run_name, checkpoint_mode=args.checkpoint_mode, checkpoint_monitor=args.checkpoint_monitor, - device_type=args.device_type, early_stopping_patience=args.early_stopping, data_queue_size=args.max_queue_size, lr_decay=(