Skip to content

Commit

Permalink
Dev icenet-ai#252: small runtime fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
JimCircadian committed May 23, 2024
1 parent 92c5c30 commit 347fbc6
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 3 deletions.
3 changes: 2 additions & 1 deletion icenet/model/networks/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ def train(self,
verbose=self._verbose,
callbacks=self.callbacks,
validation_data=validation_dataset,
# TODO: pretty sure this is redundant for non-keras.utils.Sequence, legacy inclusion!
max_queue_size=self._data_queue_size,
)

Expand Down Expand Up @@ -174,7 +175,7 @@ def train(self,
self._pre_load_path))
network.load_weights(self._pre_load_path)

if model_creator_kwargs["horovod"].rank() == 0:
if hvd.local_rank() == 0:
network.summary()

logging.debug("Calling training loop")
Expand Down
3 changes: 1 addition & 2 deletions icenet/model/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

import tensorflow as tf
import horovod.tensorflow.keras as hvd
hvd.init()

from icenet.data.dataset import IceNetDataSet, MergedIceNetDataSet
from icenet.model.cli import TrainingArgParser
Expand Down Expand Up @@ -98,6 +97,7 @@ def get_datasets(args):

def horovod_main():
args = TrainingArgParser().add_unet().add_horovod().add_wandb().parse_args()
hvd.init()

if args.device_type in ("XPU", "GPU"):
logging.debug("Setting up {} devices".format(args.device_type))
Expand All @@ -115,7 +115,6 @@ def horovod_main():
args.run_name,
checkpoint_mode=args.checkpoint_mode,
checkpoint_monitor=args.checkpoint_monitor,
device_type=args.device_type,
early_stopping_patience=args.early_stopping,
data_queue_size=args.max_queue_size,
lr_decay=(
Expand Down

0 comments on commit 347fbc6

Please sign in to comment.