Skip to content

Commit

Permalink
Expose params in train script
Browse files Browse the repository at this point in the history
  • Loading branch information
maweigert committed Oct 16, 2024
1 parent 61909e9 commit 090a27d
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 3 deletions.
35 changes: 33 additions & 2 deletions spotiflow/cli/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

log = logging.getLogger(__name__)
log.setLevel(logging.INFO)
logging.basicConfig(level=logging.INFO)

console_handler = logging.StreamHandler(sys.stdout)
console_handler.setLevel(logging.INFO)
Expand Down Expand Up @@ -58,6 +59,7 @@ def get_args() -> argparse.Namespace:
type=Path,
help="Path to directory containing images and annotations. Please refer to the documentation (https://weigertlab.github.io/spotiflow/train.html#data-format) to see the required format.",
)

required.add_argument(
"-o",
"--outdir",
Expand Down Expand Up @@ -130,6 +132,24 @@ def get_args() -> argparse.Namespace:
title="Training arguments",
description="Arguments to configure the training process.",
)

train_args.add_argument(
"--subfolder",
type=Path,
nargs=2,
required=False,
default=['train', 'val'],
help="Subfolder names for training and validation data. Defaults to ['train', 'val'].",
)

train_args.add_argument(
"--train_samples",
type=int,
required=False,
default=None,
help="Number of training samples per epoch (defaults to None, which means all samples).",
)

train_args.add_argument(
"--crop-size",
type=int,
Expand Down Expand Up @@ -161,6 +181,14 @@ def get_args() -> argparse.Namespace:
choices=["auto", "cpu", "cuda", "mps"],
help="Device to train the model on. Defaults to 'auto', which will infer based on the hardware.",
)
train_args.add_argument(
"--augment",
type=str2bool,
required=False,
default=True,
help="Apply data augmentation during training. Defaults to True.",
)

train_args.add_argument(
"--pos-weight",
type=float,
Expand Down Expand Up @@ -196,20 +224,21 @@ def get_args() -> argparse.Namespace:

def main():
args = get_args()

log.info(f"Spotiflow - version {__version__}")

pl.seed_everything(args.seed, workers=True)

log.info("Loading training data...")
train_images, train_spots = get_data(args.data_dir / "train", is_3d=args.is_3d)
train_images, train_spots = get_data(args.data_dir / args.subfolder[0], is_3d=args.is_3d)
if len(train_images) != len(train_spots):
raise ValueError(f"Number of images and spots in {args.data_dir/'train'} do not match.")
if len(train_images) == 0:
raise ValueError(f"No images were found in the {args.data_dir/'train'}.")
log.info(f"Training data loaded (N={len(train_images)}).")

log.info("Loading validation data...")
val_images, val_spots = get_data(args.data_dir / "val", is_3d=args.is_3d)
val_images, val_spots = get_data(args.data_dir / args.subfolder[1], is_3d=args.is_3d)
if len(val_images) != len(val_spots):
raise ValueError(f"Number of images and spots in {args.data_dir/'val'} do not match.")
if len(val_images) == 0:
Expand Down Expand Up @@ -266,13 +295,15 @@ def main():
save_dir=args.outdir,
device=args.device,
logger=args.logger,
augment_train=args.augment,
train_config={
"batch_size": args.batch_size,
"crop_size": args.crop_size,
"crop_size_z": args.crop_size_z,
"lr": args.lr,
"num_epochs": args.num_epochs,
"pos_weight": args.pos_weight,
"num_train_samples":args.train_samples,
"finetuned_from": args.finetune_from,
},
)
Expand Down
2 changes: 1 addition & 1 deletion spotiflow/model/spotiflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -821,7 +821,7 @@ def predict(
if verbose:
log.info("Normalizing...")
x = normalizer(x)

pad_shape = tuple(int(d * np.ceil(s / d)) for s, d in zip(x.shape, div_by))
if verbose:
log.info(f"Padding to shape {pad_shape}")
Expand Down

0 comments on commit 090a27d

Please sign in to comment.