Fixes icenet-ai#263: implemented a basic extension check to allow fully qualified dataset filenames and Dev icenet-ai#252: refactoring of existing training functionality to allow extension to use horovod for fully distributed training as a child implementation of the original tensorflow
JimCircadian committed May 23, 2024
1 parent 6158908 commit 5f56a1f
Showing 9 changed files with 731 additions and 587 deletions.
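The icenet-ai#263 change lands in one of the files not expanded below, so as a rough sketch only, the "basic extension check" described in the commit message could look something like the following; the helper name and the dataset_config.<name>.json convention are assumptions, not the committed code:

import os


def resolve_dataset_filename(dataset: str) -> str:
    # Hypothetical helper (not the committed implementation): accept either
    # a plain dataset name or a fully qualified configuration filename.
    if dataset.endswith(".json"):
        # Already fully qualified: use the path the caller supplied.
        if not os.path.isfile(dataset):
            raise FileNotFoundError(
                "Dataset configuration not found: {}".format(dataset))
        return dataset
    # Otherwise fall back to the conventional configuration filename.
    return "dataset_config.{}.json".format(dataset)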
1 change: 0 additions & 1 deletion icenet/cli.py
@@ -18,7 +18,6 @@ def wrapper(*args, **kwargs):
         logging.basicConfig(
             level=level,
             format=log_format,
-            datefmt="%d-%m-%y %T",
         )
 
         # TODO: better way of handling these on a case by case basis
183 changes: 105 additions & 78 deletions icenet/model/cli.py
@@ -1,90 +1,117 @@
 import argparse
 import logging
 import os
 
 from icenet.cli import setup_logging
 
 
-@setup_logging
-def train_args():
-    """
-
-    :return:
-    """
-    ap = argparse.ArgumentParser()
-    ap.add_argument("dataset", type=str)
-    ap.add_argument("run_name", type=str)
-    ap.add_argument("seed", type=int)
-
-    ap.add_argument("-b", "--batch-size", type=int, default=4)
-    ap.add_argument("-ca",
-                    "--checkpoint-mode",
-                    default="min",
-                    type=str)
-    ap.add_argument("-cm",
-                    "--checkpoint-monitor",
-                    default="val_rmse",
-                    type=str)
-    ap.add_argument("-ds",
-                    "--additional-dataset",
-                    dest="additional",
-                    nargs="*",
-                    default=[])
-    ap.add_argument("-e", "--epochs", type=int, default=4)
-    ap.add_argument("-f", "--filter-size", type=int, default=3)
-    ap.add_argument("--early-stopping", type=int, default=50)
-    ap.add_argument("-m",
-                    "--multiprocessing",
-                    action="store_true",
-                    default=False)
-    ap.add_argument("-n", "--n-filters-factor", type=float, default=1.)
-    ap.add_argument("-p", "--preload", type=str)
-    ap.add_argument("-pw",
-                    "--pickup-weights",
-                    action="store_true",
-                    default=False)
-    ap.add_argument("-qs", "--max-queue-size", default=10, type=int)
-    ap.add_argument("-r", "--ratio", default=1.0, type=float)
-
-    ap.add_argument("-s",
-                    "--strategy",
-                    default=None,
-                    choices=("default", "mirrored", "central"))
-
-    ap.add_argument("--shuffle-train",
-                    default=False,
-                    action="store_true",
-                    help="Shuffle the training set")
-    ap.add_argument("-v", "--verbose", action="store_true", default=False)
-    ap.add_argument("-w", "--workers", type=int, default=4)
-
-    # WandB additional arguments
-    ap.add_argument("-nw", "--no-wandb", default=False, action="store_true")
-    ap.add_argument("-wo",
-                    "--wandb-offline",
-                    default=False,
-                    action="store_true")
-    ap.add_argument("-wp",
-                    "--wandb-project",
-                    default=os.environ.get("ICENET_ENVIRONMENT"),
-                    type=str)
-    ap.add_argument("-wu",
-                    "--wandb-user",
-                    default=os.environ.get("USER"),
-                    type=str)
-
-    # Learning rate arguments
-    ap.add_argument("--lr", default=1e-4, type=float)
-    ap.add_argument("--lr_10e_decay_fac",
-                    default=1.0,
-                    type=float,
-                    help="Factor by which LR is multiplied by every 10 epochs "
-                    "using exponential decay. E.g. 1 -> no decay (default)"
-                    ", 0.5 -> halve every 10 epochs.")
-    ap.add_argument('--lr_decay_start', default=10, type=int)
-    ap.add_argument('--lr_decay_end', default=30, type=int)
-
-    return ap.parse_args()
+class TrainingArgParser(argparse.ArgumentParser):
+    """An ArgumentParser specialised to support model training
+
+    The 'add_*' methods return self to permit method chaining.
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        self.add_argument("dataset", type=str)
+        self.add_argument("run_name", type=str)
+        self.add_argument("seed", type=int)
+
+        self.add_argument("-o", "--output-path", type=str, default=None)
+        self.add_argument("-v",
+                          "--verbose",
+                          action="store_true",
+                          default=False)
+
+        self.add_argument("-b", "--batch-size", type=int, default=4)
+        self.add_argument("-ca",
+                          "--checkpoint-mode",
+                          default="min",
+                          type=str)
+        self.add_argument("-cm",
+                          "--checkpoint-monitor",
+                          default="val_rmse",
+                          type=str)
+        self.add_argument("-ds",
+                          "--additional-dataset",
+                          dest="additional",
+                          nargs="*",
+                          default=[])
+        self.add_argument("-e", "--epochs", type=int, default=4)
+        self.add_argument("--early-stopping", type=int, default=50)
+        self.add_argument("-p", "--preload", type=str)
+        self.add_argument("-qs", "--max-queue-size", default=10, type=int)
+        self.add_argument("-r", "--ratio", default=1.0, type=float)
+        self.add_argument("--shuffle-train",
+                          default=False,
+                          action="store_true",
+                          help="Shuffle the training set")
+        self.add_argument("--lr", default=1e-4, type=float)
+        self.add_argument("--lr_10e_decay_fac",
+                          default=1.0,
+                          type=float,
+                          help="Factor by which LR is multiplied by every 10 epochs "
+                          "using exponential decay. E.g. 1 -> no decay (default)"
+                          ", 0.5 -> halve every 10 epochs.")
+        self.add_argument('--lr_decay_start', default=10, type=int)
+        self.add_argument('--lr_decay_end', default=30, type=int)
+
+    def add_unet(self):
+        self.add_argument("-f", "--filter-size", type=int, default=3)
+        self.add_argument("-n", "--n-filters-factor", type=float, default=1.)
+        return self
+
+    def add_tensorflow(self):
+        # TODO: derive from available tf.distribute implementations
+        self.add_argument("-s",
+                          "--strategy",
+                          default=None,
+                          choices=("default", "mirrored", "central"))
+        return self
+
+    def add_horovod(self):
+        self.add_argument("-hv",
+                          "--horovod",
+                          default=False,
+                          action="store_true")
+        self.add_argument("--device-type",
+                          default=None,
+                          help="Choose a device type for distribution, if using")
+        return self
+
+    def add_wandb(self):
+        self.add_argument("-nw", "--no-wandb", default=False, action="store_true")
+        self.add_argument("-wo",
+                          "--wandb-offline",
+                          default=False,
+                          action="store_true")
+        self.add_argument("-wp",
+                          "--wandb-project",
+                          default=os.environ.get("ICENET_ENVIRONMENT"),
+                          type=str)
+        self.add_argument("-wu",
+                          "--wandb-user",
+                          default=os.environ.get("USER"),
+                          type=str)
+        return self
+
+    def parse_args(self, *args,
+                   log_format="[%(asctime)-17s :%(levelname)-8s] - %(message)s",
+                   **kwargs):
+        args = super().parse_args(*args, **kwargs)
+
+        logging.basicConfig(
+            datefmt="%d-%m-%y %T",
+            format=log_format,
+            level=logging.DEBUG if args.verbose else logging.INFO)
+        logging.getLogger("cdsapi").setLevel(logging.WARNING)
+        logging.getLogger("matplotlib").setLevel(logging.WARNING)
+        logging.getLogger("matplotlib.pyplot").setLevel(logging.WARNING)
+        logging.getLogger("requests").setLevel(logging.WARNING)
+        logging.getLogger("tensorflow").setLevel(logging.WARNING)
+        logging.getLogger("urllib3").setLevel(logging.WARNING)
+        return args
 
 
 @setup_logging
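With the diff above applied, callers compose the parser by chaining the add_* methods, each of which registers an argument group and returns self. A minimal usage sketch (illustrative, not part of this commit):

from icenet.model.cli import TrainingArgParser

# Build the CLI for a TensorFlow-backed UNet training run; parse_args()
# also configures logging based on --verbose.
args = (TrainingArgParser()
        .add_unet()
        .add_tensorflow()
        .add_wandb()
        .parse_args())
print(args.run_name, args.epochs, args.strategy)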
(Diffs for the remaining 7 changed files not shown.)
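The icenet-ai#252 refactor is described in the commit message as enabling horovod for fully distributed training as a child implementation of the original tensorflow trainer. The trainer classes themselves are not among the diffs rendered here, so the following is only a sketch of the standard Horovod initialisation such a child implementation would typically perform, honouring the new --device-type argument; the function name is hypothetical, while the hvd.* and tf.config.* calls are standard Horovod/TensorFlow API:

import horovod.tensorflow.keras as hvd
import tensorflow as tf


def init_distribution(device_type=None):
    # Standard Horovod setup: one process per worker, each process pinned
    # to its local GPU unless the caller forces CPU-only distribution.
    hvd.init()
    if device_type != "cpu":
        gpus = tf.config.list_physical_devices("GPU")
        if gpus:
            tf.config.set_visible_devices(gpus[hvd.local_rank()], "GPU")
    # Return the worker count, commonly used to scale the learning rate
    # for synchronous data-parallel training.
    return hvd.size()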
