Skip to content

Commit

Permalink
Dev icenet-ai#252: attempting to restructure more appropriately
Browse files Browse the repository at this point in the history
  • Loading branch information
JimCircadian committed May 23, 2024
1 parent 72f4b53 commit e1421b9
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 23 deletions.
30 changes: 7 additions & 23 deletions icenet/model/networks/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import numpy as np
import pandas as pd
import tensorflow as tf
import horovod.tensorflow.keras as hvd

from tensorflow.keras.callbacks import \
EarlyStopping, ModelCheckpoint, LearningRateScheduler
Expand Down Expand Up @@ -145,27 +144,6 @@ def get_callbacks(self):


class HorovodNetwork(TensorflowNetwork):
def __init__(self,
*args,
device_type: str = None,
**kwargs):
super().__init__(*args, **kwargs)

if device_type in ("XPU", "GPU"):
logging.debug("Setting up {} devices".format(device_type))
devices = tf.config.list_physical_devices(device_type)
logging.info("{} count is {}".format(device_type, len(devices)))

for dev in devices:
tf.config.experimental.set_memory_growth(dev, True)

if devices:
tf.config.experimental.set_visible_devices(devices[hvd.local_rank()], device_type)

self.add_callback(
hvd.callbacks.BroadcastGlobalVariablesCallback(0)
)

def train(self,
epochs: int,
model_creator: callable,
Expand All @@ -178,7 +156,13 @@ def train(self,
"{}_{}_history.json".format(
self.run_name, self.seed))

# TODO: this is totally assuming the structure of model_creator :(
import horovod.tensorflow.keras as hvd

if hvd.is_initialized():
logging.info("Horovod is initialized when we call train, with {} members".format(hvd.size()))
else:
raise RuntimeError("Horovod is not initialized")

logging.debug("Calling {} to create our model".format(model_creator))
network = model_creator(**model_creator_kwargs,
custom_optimizer=hvd.DistributedOptimizer(Adam(model_creator_kwargs["learning_rate"])),
Expand Down
16 changes: 16 additions & 0 deletions icenet/model/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,18 @@ def get_datasets(args):

def horovod_main():
args = TrainingArgParser().add_unet().add_horovod().add_wandb().parse_args()

if args.device_type in ("XPU", "GPU"):
logging.debug("Setting up {} devices".format(args.device_type))
devices = tf.config.list_physical_devices(args.device_type)
logging.info("{} count is {}".format(args.device_type, len(devices)))

for dev in devices:
tf.config.experimental.set_memory_growth(dev, True)

if devices:
tf.config.experimental.set_visible_devices(devices[hvd.local_rank()], args.device_type)

dataset = get_datasets(args)
network = HorovodNetwork(dataset,
args.run_name,
Expand All @@ -114,6 +126,10 @@ def horovod_main():
pre_load_path=args.preload,
seed=args.seed,
verbose=args.verbose)
network.add_callback(
hvd.callbacks.BroadcastGlobalVariablesCallback(0)
)

execute_tf_training(args, dataset, network)


Expand Down

0 comments on commit e1421b9

Please sign in to comment.