From 090a27d1fb83ebfcb5f93e75c1bbb2e32820f971 Mon Sep 17 00:00:00 2001 From: Martin Weigert Date: Wed, 16 Oct 2024 12:37:36 +0200 Subject: [PATCH] Expose params in train script --- spotiflow/cli/train.py | 35 +++++++++++++++++++++++++++++++++-- spotiflow/model/spotiflow.py | 2 +- 2 files changed, 34 insertions(+), 3 deletions(-) diff --git a/spotiflow/cli/train.py b/spotiflow/cli/train.py index 901a003..e30ed18 100644 --- a/spotiflow/cli/train.py +++ b/spotiflow/cli/train.py @@ -16,6 +16,7 @@ log = logging.getLogger(__name__) log.setLevel(logging.INFO) +logging.basicConfig(level=logging.INFO) console_handler = logging.StreamHandler(sys.stdout) console_handler.setLevel(logging.INFO) @@ -58,6 +59,7 @@ def get_args() -> argparse.Namespace: type=Path, help="Path to directory containing images and annotations. Please refer to the documentation (https://weigertlab.github.io/spotiflow/train.html#data-format) to see the required format.", ) + required.add_argument( "-o", "--outdir", @@ -130,6 +132,24 @@ def get_args() -> argparse.Namespace: title="Training arguments", description="Arguments to configure the training process.", ) + + train_args.add_argument( + "--subfolder", + type=Path, + nargs=2, + required=False, + default=['train', 'val'], + help="Subfolder names for training and validation data. Defaults to ['train', 'val'].", + ) + + train_args.add_argument( + "--train_samples", + type=int, + required=False, + default=None, + help="Number of training samples per epoch (defaults to None, which means all samples).", + ) + train_args.add_argument( "--crop-size", type=int, @@ -161,6 +181,14 @@ def get_args() -> argparse.Namespace: choices=["auto", "cpu", "cuda", "mps"], help="Device to train the model on. Defaults to 'auto', which will infer based on the hardware.", ) + train_args.add_argument( + "--augment", + type=str2bool, + required=False, + default=True, + help="Apply data augmentation during training. Defaults to True.", + ) + train_args.add_argument( "--pos-weight", type=float, @@ -196,12 +224,13 @@ def get_args() -> argparse.Namespace: def main(): args = get_args() + log.info(f"Spotiflow - version {__version__}") pl.seed_everything(args.seed, workers=True) log.info("Loading training data...") - train_images, train_spots = get_data(args.data_dir / "train", is_3d=args.is_3d) + train_images, train_spots = get_data(args.data_dir / args.subfolder[0], is_3d=args.is_3d) if len(train_images) != len(train_spots): raise ValueError(f"Number of images and spots in {args.data_dir/'train'} do not match.") if len(train_images) == 0: @@ -209,7 +238,7 @@ def main(): log.info(f"Training data loaded (N={len(train_images)}).") log.info("Loading validation data...") - val_images, val_spots = get_data(args.data_dir / "val", is_3d=args.is_3d) + val_images, val_spots = get_data(args.data_dir / args.subfolder[1], is_3d=args.is_3d) if len(val_images) != len(val_spots): raise ValueError(f"Number of images and spots in {args.data_dir/'val'} do not match.") if len(val_images) == 0: @@ -266,6 +295,7 @@ def main(): save_dir=args.outdir, device=args.device, logger=args.logger, + augment_train=args.augment, train_config={ "batch_size": args.batch_size, "crop_size": args.crop_size, @@ -273,6 +303,7 @@ def main(): "lr": args.lr, "num_epochs": args.num_epochs, "pos_weight": args.pos_weight, + "num_train_samples":args.train_samples, "finetuned_from": args.finetune_from, }, ) diff --git a/spotiflow/model/spotiflow.py b/spotiflow/model/spotiflow.py index ca12e84..470c2fb 100644 --- a/spotiflow/model/spotiflow.py +++ b/spotiflow/model/spotiflow.py @@ -821,7 +821,7 @@ def predict( if verbose: log.info("Normalizing...") x = normalizer(x) - + pad_shape = tuple(int(d * np.ceil(s / d)) for s, d in zip(x.shape, div_by)) if verbose: log.info(f"Padding to shape {pad_shape}")