diff --git a/README.md b/README.md
index e71e1f1..7993cec 100644
--- a/README.md
+++ b/README.md
@@ -51,6 +51,12 @@ Please see our [guide for contributing to pytorch-caney](CONTRIBUTING.md).
 | :---: | :---: | :---: | :---: |
 | SatVision-B | MODIS-1.9-M | 192x192 | 84.5M |
 
+## SatVision Datasets
+
+| name | bands | resolution | #chips |
+| :---: | :---: | :---: | :---: |
+| MODIS-Small | 7 | 128x128 | 1,994,131 |
+
 ## Pre-training with Masked Image Modeling
 To pre-train the swinv2 base model with masked image modeling pre-training, run:
 ```bash
@@ -61,13 +67,14 @@ For example to run on a compute node with 4 GPUs and a batch size of 128 on the
 ```bash
 singularity shell --nv -B /path/to/container/pytorch-caney
-Singularity> torchrun --nproc_per_node 4 --cfg examples/satvision/mim_pretrain_swinv2_satvision_base_192_window12_800ep.yaml --dataset MODIS --data-paths /explore/nobackup/projects/ilab/data/satvision/pretraining/training_* --batch-size 128 --output . --enable-amp
+Singularity> export PYTHONPATH=$PWD:$PWD/pytorch-caney
+Singularity> torchrun --nproc_per_node 4 pytorch-caney/pytorch_caney/pipelines/pretraining/mim.py --cfg pytorch-caney/examples/satvision/mim_pretrain_swinv2_satvision_base_192_window12_800ep.yaml --dataset MODIS --data-paths /explore/nobackup/projects/ilab/data/satvision/pretraining/training_* --batch-size 128 --output . --enable-amp
 ```
 
 This example script runs the exact configuration used to make the SatVision-base model pre-training with MiM and the MODIS pre-training dataset.
 ```bash
 singularity shell --nv -B /path/to/container/pytorch-caney
-Singularity> cd examples/satvision
+Singularity> cd pytorch-caney/examples/satvision
 Singularity> ./run_satvision_pretrain.sh
 ```
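A quick way to sanity-check the new dataset table is to count chips on disk. This is a minimal sketch, not part of the patch: it assumes the chips are stored as `.npy` files of shape `(128, 128, 7)` under the `training_*` directories referenced above, which this patch does not guarantee.

```python
# Hypothetical sanity check for the "#chips" column above; the .npy layout
# and shapes are assumptions, not something this patch specifies.
import glob

import numpy as np

paths = glob.glob(
    '/explore/nobackup/projects/ilab/data/satvision/pretraining/training_*')
n_chips = sum(len(glob.glob(f'{p}/*.npy')) for p in paths)
print(f'{n_chips:,} chips found (table above reports 1,994,131)')

chip = np.load(glob.glob(f'{paths[0]}/*.npy')[0])
print(chip.shape)  # expected (128, 128, 7): 128x128 pixels, 7 MODIS bands
```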
diff --git a/examples/satvision/finetune_satvision_base_landcover5class_192_window12_100ep.yaml b/examples/satvision/finetune_satvision_base_landcover5class_192_window12_100ep.yaml
new file mode 100644
index 0000000..5f41c64
--- /dev/null
+++ b/examples/satvision/finetune_satvision_base_landcover5class_192_window12_100ep.yaml
@@ -0,0 +1,33 @@
+MODEL:
+  TYPE: swinv2
+  DECODER: unet
+  NAME: satvision_finetune_lc5class
+  DROP_PATH_RATE: 0.1
+  NUM_CLASSES: 5
+  SWINV2:
+    IN_CHANS: 7
+    EMBED_DIM: 128
+    DEPTHS: [ 2, 2, 18, 2 ]
+    NUM_HEADS: [ 4, 8, 16, 32 ]
+    WINDOW_SIZE: 14
+    PRETRAINED_WINDOW_SIZES: [ 12, 12, 12, 6 ]
+DATA:
+  IMG_SIZE: 224
+  DATASET: MODISLC5
+  MASK_PATCH_SIZE: 32
+  MASK_RATIO: 0.6
+LOSS:
+  NAME: 'tversky'
+  MODE: 'multiclass'
+  ALPHA: 0.4
+  BETA: 0.6
+TRAIN:
+  EPOCHS: 100
+  WARMUP_EPOCHS: 10
+  BASE_LR: 1e-4
+  WARMUP_LR: 5e-7
+  WEIGHT_DECAY: 0.01
+  LAYER_DECAY: 0.8
+PRINT_FREQ: 100
+SAVE_FREQ: 5
+TAG: satvision_finetune_land_cover_5class_swinv2_satvision_192_window12__800ep
\ No newline at end of file
diff --git a/examples/satvision/finetune_satvision_base_landcover9class_192_window12_100ep.yaml b/examples/satvision/finetune_satvision_base_landcover9class_192_window12_100ep.yaml
new file mode 100644
index 0000000..2e96121
--- /dev/null
+++ b/examples/satvision/finetune_satvision_base_landcover9class_192_window12_100ep.yaml
@@ -0,0 +1,33 @@
+MODEL:
+  TYPE: swinv2
+  DECODER: unet
+  NAME: satvision_finetune_lc9class
+  DROP_PATH_RATE: 0.1
+  NUM_CLASSES: 9
+  SWINV2:
+    IN_CHANS: 7
+    EMBED_DIM: 128
+    DEPTHS: [ 2, 2, 18, 2 ]
+    NUM_HEADS: [ 4, 8, 16, 32 ]
+    WINDOW_SIZE: 14
+    PRETRAINED_WINDOW_SIZES: [ 12, 12, 12, 6 ]
+DATA:
+  IMG_SIZE: 224
+  DATASET: MODISLC9
+  MASK_PATCH_SIZE: 32
+  MASK_RATIO: 0.6
+LOSS:
+  NAME: 'tversky'
+  MODE: 'multiclass'
+  ALPHA: 0.4
+  BETA: 0.6
+TRAIN:
+  EPOCHS: 100
+  WARMUP_EPOCHS: 10
+  BASE_LR: 1e-4
+  WARMUP_LR: 5e-7
+  WEIGHT_DECAY: 0.01
+  LAYER_DECAY: 0.8
+PRINT_FREQ: 100
+SAVE_FREQ: 5
+TAG: satvision_finetune_land_cover_9class_swinv2_satvision_192_window12__800ep
\ No newline at end of file
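Since both configs above lean on `LOSS.ALPHA`/`LOSS.BETA`, a short note: in the Tversky loss, alpha weights false positives and beta weights false negatives, so `ALPHA: 0.4` / `BETA: 0.6` penalizes missed pixels slightly more; alpha = beta = 0.5 reduces to Dice. A minimal sketch with `segmentation-models-pytorch` (the dependency this patch adds), using the 5-class shapes:

```python
# Hedged sketch of what LOSS.ALPHA / LOSS.BETA mean in the configs above.
import torch
from segmentation_models_pytorch.losses import TverskyLoss

loss_fn = TverskyLoss(mode='multiclass', alpha=0.4, beta=0.6)

logits = torch.randn(2, 5, 224, 224)          # (N, C=5 classes, H, W)
target = torch.randint(0, 5, (2, 224, 224))   # (N, H, W) integer labels
print(loss_fn(logits, target))                # scalar loss
```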
diff --git a/examples/satvision/run_satvision_finetune_lc_fiveclass.sh b/examples/satvision/run_satvision_finetune_lc_fiveclass.sh
new file mode 100755
index 0000000..155abf6
--- /dev/null
+++ b/examples/satvision/run_satvision_finetune_lc_fiveclass.sh
@@ -0,0 +1,20 @@
+#!/bin/bash
+
+#SBATCH -J finetune_satvision_lc5
+#SBATCH -t 3-00:00:00
+#SBATCH -G 4
+#SBATCH -N 1
+
+
+export PYTHONPATH=$PWD:../../../:../../../pytorch-caney
+export NGPUS=4
+
+torchrun --nproc_per_node $NGPUS \
+    ../../../pytorch-caney/pytorch_caney/pipelines/finetuning/finetune.py \
+    --cfg finetune_satvision_base_landcover5class_192_window12_100ep.yaml \
+    --pretrained /explore/nobackup/people/cssprad1/projects/satnet/code/development/masked_image_modeling/development/models/simmim_satnet_pretrain_pretrain/simmim_pretrain__satnet_swinv2_base__img192_window12__800ep_v3_no_norm/ckpt_epoch_800.pth \
+    --dataset MODISLC5 \
+    --data-paths /explore/nobackup/projects/ilab/data/satvision/finetuning/h18v04/labels_5classes_224 \
+    --batch-size 4 \
+    --output /explore/nobackup/people/cssprad1/projects/satnet/code/development/cleanup/finetune/models \
+    --enable-amp
\ No newline at end of file
diff --git a/examples/satvision/run_satvision_finetune_lc_nineclass.sh b/examples/satvision/run_satvision_finetune_lc_nineclass.sh
new file mode 100755
index 0000000..7008967
--- /dev/null
+++ b/examples/satvision/run_satvision_finetune_lc_nineclass.sh
@@ -0,0 +1,20 @@
+#!/bin/bash
+
+#SBATCH -J finetune_satvision_lc9
+#SBATCH -t 3-00:00:00
+#SBATCH -G 4
+#SBATCH -N 1
+
+
+export PYTHONPATH=$PWD:../../../:../../../pytorch-caney
+export NGPUS=4
+
+torchrun --nproc_per_node $NGPUS \
+    ../../../pytorch-caney/pytorch_caney/pipelines/finetuning/finetune.py \
+    --cfg finetune_satvision_base_landcover9class_192_window12_100ep.yaml \
+    --pretrained /explore/nobackup/people/cssprad1/projects/satnet/code/development/masked_image_modeling/development/models/simmim_satnet_pretrain_pretrain/simmim_pretrain__satnet_swinv2_base__img192_window12__800ep_v3_no_norm/ckpt_epoch_800.pth \
+    --dataset MODISLC9 \
+    --data-paths /explore/nobackup/projects/ilab/data/satvision/finetuning/h18v04/labels_9classes_224 \
+    --batch-size 4 \
+    --output /explore/nobackup/people/cssprad1/projects/satnet/code/development/cleanup/finetune/models \
+    --enable-amp
\ No newline at end of file
diff --git a/examples/satvision/run_satvision_pretrain.sh b/examples/satvision/run_satvision_pretrain.sh
index 4a6b09b..0ff9598 100755
--- a/examples/satvision/run_satvision_pretrain.sh
+++ b/examples/satvision/run_satvision_pretrain.sh
@@ -10,7 +10,7 @@ export PYTHONPATH=$PWD:../../../:../../../pytorch-caney
 export NGPUS=4
 
 torchrun --nproc_per_node $NGPUS \
-    ../../../pytorch-caney/pytorch_caney/pipelines/pretraining/simmim.py \
+    ../../../pytorch-caney/pytorch_caney/pipelines/pretraining/mim.py \
     --cfg mim_pretrain_swinv2_satvision_base_192_window12_800ep.yaml \
     --dataset MODIS \
     --data-paths /explore/nobackup/projects/ilab/data/satvision/pretraining/training_* \
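The launch scripts above rely on `torchrun` to spawn one process per GPU and export rendezvous variables, which the fine-tuning pipeline below reads. A minimal sketch of what each worker sees (only meaningful when run under `torchrun`; on a single node `RANK` equals `LOCAL_RANK`):

```python
# Hedged sketch: environment torchrun provides to each worker process.
import os

import torch.distributed as dist

rank = int(os.environ['RANK'])              # global rank, 0..WORLD_SIZE-1
local_rank = int(os.environ['LOCAL_RANK'])  # GPU index on this node
world_size = int(os.environ['WORLD_SIZE'])  # == NGPUS on a single node

dist.init_process_group('nccl', init_method='env://',
                        rank=rank, world_size=world_size)
print(f'worker {rank}/{world_size} -> cuda:{local_rank}')
```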
diff --git a/pytorch_caney/config.py b/pytorch_caney/config.py
index 1e11805..d35ac2a 100644
--- a/pytorch_caney/config.py
+++ b/pytorch_caney/config.py
@@ -36,6 +36,8 @@ _C.MODEL = CN()
 # Model type
 _C.MODEL.TYPE = 'swinv2'
+# Decoder type
+_C.MODEL.DECODER = None
 # Model name
 _C.MODEL.NAME = 'swinv2_base_patch4_window7_224'
 # Pretrained weight from checkpoint, could be from previous pre-training
@@ -50,20 +52,6 @@
 # Drop path rate
 _C.MODEL.DROP_PATH_RATE = 0.1
 
-# Swin Transformer parameters
-_C.MODEL.SWIN = CN()
-_C.MODEL.SWIN.PATCH_SIZE = 4
-_C.MODEL.SWIN.IN_CHANS = 4
-_C.MODEL.SWIN.EMBED_DIM = 96
-_C.MODEL.SWIN.DEPTHS = [2, 2, 6, 2]
-_C.MODEL.SWIN.NUM_HEADS = [3, 6, 12, 24]
-_C.MODEL.SWIN.WINDOW_SIZE = 7
-_C.MODEL.SWIN.MLP_RATIO = 4.
-_C.MODEL.SWIN.QKV_BIAS = True
-_C.MODEL.SWIN.QK_SCALE = None
-_C.MODEL.SWIN.APE = False
-_C.MODEL.SWIN.PATCH_NORM = True
-
 # Swin Transformer V2 parameters
 _C.MODEL.SWINV2 = CN()
 _C.MODEL.SWINV2.PATCH_SIZE = 4
@@ -78,20 +66,21 @@ _C.MODEL.SWINV2.PATCH_NORM = True
 _C.MODEL.SWINV2.PRETRAINED_WINDOW_SIZES = [0, 0, 0, 0]
 
-# Vision Transformer parameters
-_C.MODEL.VIT = CN()
-_C.MODEL.VIT.PATCH_SIZE = 16
-_C.MODEL.VIT.IN_CHANS = 3
-_C.MODEL.VIT.EMBED_DIM = 768
-_C.MODEL.VIT.DEPTH = 12
-_C.MODEL.VIT.NUM_HEADS = 12
-_C.MODEL.VIT.MLP_RATIO = 4
-_C.MODEL.VIT.QKV_BIAS = True
-_C.MODEL.VIT.INIT_VALUES = 0.1
-_C.MODEL.VIT.USE_APE = False
-_C.MODEL.VIT.USE_RPB = False
-_C.MODEL.VIT.USE_SHARED_RPB = True
-_C.MODEL.VIT.USE_MEAN_POOLING = False
+# -----------------------------------------------------------------------------
+# Loss settings
+# -----------------------------------------------------------------------------
+_C.LOSS = CN()
+_C.LOSS.NAME = 'tversky'
+_C.LOSS.MODE = 'multiclass'
+_C.LOSS.CLASSES = None
+_C.LOSS.LOG = False
+_C.LOSS.LOGITS = True
+_C.LOSS.SMOOTH = 0.0
+_C.LOSS.IGNORE_INDEX = None
+_C.LOSS.EPS = 1e-7
+_C.LOSS.ALPHA = 0.5
+_C.LOSS.BETA = 0.5
+_C.LOSS.GAMMA = 1.0
 
 # -----------------------------------------------------------------------------
 # Training settings
@@ -139,32 +128,6 @@
 # [SimMIM] Layer decay for fine-tuning
 _C.TRAIN.LAYER_DECAY = 1.0
 
-# -----------------------------------------------------------------------------
-# Augmentation settings
-# -----------------------------------------------------------------------------
-_C.AUG = CN()
-# Color jitter factor
-_C.AUG.COLOR_JITTER = 0.4
-# Use AutoAugment policy. "v0" or "original"
-_C.AUG.AUTO_AUGMENT = 'rand-m9-mstd0.5-inc1'
-# Random erase prob
-_C.AUG.REPROB = 0.25
-# Random erase mode
-_C.AUG.REMODE = 'pixel'
-# Random erase count
-_C.AUG.RECOUNT = 1
-# Mixup alpha, mixup enabled if > 0
-_C.AUG.MIXUP = 0.8
-# Cutmix alpha, cutmix enabled if > 0
-_C.AUG.CUTMIX = 1.0
-# Cutmix min/max ratio, overrides alpha and enables cutmix if set
-_C.AUG.CUTMIX_MINMAX = None
-# Probability of performing mixup or cutmix when either/both is enabled
-_C.AUG.MIXUP_PROB = 1.0
-# Probability of switching to cutmix when both mixup and cutmix enabled
-_C.AUG.MIXUP_SWITCH_PROB = 0.5
-# How to apply mixup/cutmix params. Per "batch", "pair", or "elem"
-_C.AUG.MIXUP_MODE = 'batch'
 
 # -----------------------------------------------------------------------------
 # Testing settings
@@ -176,24 +139,22 @@
 # -----------------------------------------------------------------------------
 # Misc
 # -----------------------------------------------------------------------------
-# [SimMIM] Whether to enable pytorch amp, overwritten by command line argument
+# Whether to enable pytorch amp, overwritten by command line argument
 _C.ENABLE_AMP = False
 
 # Enable Pytorch automatic mixed precision (amp).
 _C.AMP_ENABLE = True
 
 # Path to output folder, overwritten by command line argument
 _C.OUTPUT = ''
 
 # Tag of experiment, overwritten by command line argument
-_C.TAG = 'default'
+_C.TAG = 'pt-caney-default-tag'
 
 # Frequency to save checkpoint
 _C.SAVE_FREQ = 1
 
 # Frequency to logging info
 _C.PRINT_FREQ = 10
 
 # Fixed random seed
-_C.SEED = 0
+_C.SEED = 42
 
 # Perform evaluation only, overwritten by command line argument
 _C.EVAL_MODE = False
 
-# Test throughput only, overwritten by command line argument
-_C.THROUGHPUT_MODE = False
 
 
 def _update_config_from_file(config, cfg_file):
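For readers unfamiliar with yacs: the `_C` tree above defines defaults, and `get_config()` merges the example YAMLs over them. A minimal standalone sketch of that pattern (using `merge_from_list` so it runs without a YAML file; keys must already exist in the defaults or yacs raises an error):

```python
# Hedged sketch of the yacs defaults/override pattern used by config.py.
from yacs.config import CfgNode as CN

_C = CN()
_C.MODEL = CN()
_C.MODEL.TYPE = 'swinv2'
_C.MODEL.DECODER = None  # stays None for encoder-only (pre-training) builds

config = _C.clone()
# get_config() does roughly config.merge_from_file(args.cfg); shown here
# with a key/value list instead of a file.
config.merge_from_list(['MODEL.DECODER', 'unet'])
config.freeze()
print(config.MODEL.DECODER)  # -> 'unet'
```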
diff --git a/pytorch_caney/data/datamodules/finetune_datamodule.py b/pytorch_caney/data/datamodules/finetune_datamodule.py
new file mode 100644
index 0000000..0333ca1
--- /dev/null
+++ b/pytorch_caney/data/datamodules/finetune_datamodule.py
@@ -0,0 +1,93 @@
+from ..datasets.modis_dataset import MODISDataset
+from ..datasets.modis_lc_five_dataset import MODISLCFiveDataset
+from ..datasets.modis_lc_nine_dataset import MODISLCNineDataset
+
+from ..transforms import TensorResizeTransform
+
+import torch.distributed as dist
+from torch.utils.data import DataLoader, DistributedSampler
+
+
+DATASETS = {
+    'modis': MODISDataset,
+    'modislc9': MODISLCNineDataset,
+    'modislc5': MODISLCFiveDataset,
+    # 'modis tree': MODISTree,
+}
+
+
+def get_dataset_from_dict(dataset_name: str):
+
+    dataset_name = dataset_name.lower()
+
+    try:
+
+        dataset_to_use = DATASETS[dataset_name]
+
+    except KeyError:
+
+        error_msg = f"{dataset_name} is not an existing dataset"
+
+        error_msg = f"{error_msg}. Available datasets: {DATASETS.keys()}"
+
+        raise KeyError(error_msg)
+
+    return dataset_to_use
+
+
+def build_finetune_dataloaders(config, logger):
+
+    transform = TensorResizeTransform(config)
+
+    logger.info(f'Finetuning data transform:\n{transform}')
+
+    dataset_name = config.DATA.DATASET
+
+    logger.info(f'Dataset: {dataset_name}')
+    logger.info(f'Data Paths: {config.DATA.DATA_PATHS}')
+
+    dataset_to_use = get_dataset_from_dict(dataset_name)
+
+    logger.info(f'Dataset obj: {dataset_to_use}')
+
+    dataset_train = dataset_to_use(data_paths=config.DATA.DATA_PATHS,
+                                   split="train",
+                                   img_size=config.DATA.IMG_SIZE,
+                                   transform=transform)
+
+    dataset_val = dataset_to_use(data_paths=config.DATA.DATA_PATHS,
+                                 split="val",
+                                 img_size=config.DATA.IMG_SIZE,
+                                 transform=transform)
+
+    logger.info(f'Build dataset: train images = {len(dataset_train)}')
+
+    logger.info(f'Build dataset: val images = {len(dataset_val)}')
+
+    sampler_train = DistributedSampler(
+        dataset_train,
+        num_replicas=dist.get_world_size(),
+        rank=dist.get_rank(),
+        shuffle=True)
+
+    sampler_val = DistributedSampler(
+        dataset_val,
+        num_replicas=dist.get_world_size(),
+        rank=dist.get_rank(),
+        shuffle=False)
+
+    dataloader_train = DataLoader(dataset_train,
+                                  config.DATA.BATCH_SIZE,
+                                  sampler=sampler_train,
+                                  num_workers=config.DATA.NUM_WORKERS,
+                                  pin_memory=True,
+                                  drop_last=True)
+
+    dataloader_val = DataLoader(dataset_val,
+                                config.DATA.BATCH_SIZE,
+                                sampler=sampler_val,
+                                num_workers=config.DATA.NUM_WORKERS,
+                                pin_memory=True,
+                                drop_last=False)
+
+    return dataloader_train, dataloader_val
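A note on the `DistributedSampler` wiring above: each rank sees a disjoint shard, and the training loop below must call `set_epoch()` so the shuffle differs per epoch. A small runnable sketch (no process group needed when `num_replicas`/`rank` are passed explicitly, as the datamodule does):

```python
# Hedged sketch: why train() below calls dataloader_train.sampler.set_epoch().
import torch
from torch.utils.data import DataLoader, DistributedSampler, TensorDataset

dataset = TensorDataset(torch.arange(8))
sampler = DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=True)
loader = DataLoader(dataset, batch_size=2, sampler=sampler)

for epoch in range(2):
    sampler.set_epoch(epoch)  # without this, every epoch repeats epoch 0's order
    print([batch[0].tolist() for batch in loader])
```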
diff --git a/pytorch_caney/data/datasets/modis_dataset.py b/pytorch_caney/data/datasets/modis_dataset.py
index e7455f3..1d978df 100644
--- a/pytorch_caney/data/datasets/modis_dataset.py
+++ b/pytorch_caney/data/datasets/modis_dataset.py
@@ -74,6 +74,6 @@ def get_filenames(self, path):
         Returns a list of absolute paths to images inside given `path`
         """
         files_list = []
-        for filename in os.listdir(path):
+        for filename in sorted(os.listdir(path)):
             files_list.append(os.path.join(path, filename))
         return files_list
diff --git a/pytorch_caney/data/datasets/modis_lc_five_dataset.py b/pytorch_caney/data/datasets/modis_lc_five_dataset.py
new file mode 100644
index 0000000..07b8f03
--- /dev/null
+++ b/pytorch_caney/data/datasets/modis_lc_five_dataset.py
@@ -0,0 +1,76 @@
+import os
+from torch.utils.data import Dataset
+
+import numpy as np
+import random
+
+
+class MODISLCFiveDataset(Dataset):
+
+    IMAGE_PATH = os.path.join("images")
+    MASK_PATH = os.path.join("labels")
+
+    def __init__(
+        self,
+        data_paths: list,
+        split: str,
+        img_size: tuple = (224, 224),
+        transform=None,
+    ):
+        self.img_size = img_size
+        self.transform = transform
+        self.split = split
+        self.data_paths = data_paths
+        self.img_list = []
+        self.mask_list = []
+        for data_path in data_paths:
+            img_path = os.path.join(data_path, self.IMAGE_PATH)
+            mask_path = os.path.join(data_path, self.MASK_PATH)
+            self.img_list.extend(self.get_filenames(img_path))
+            self.mask_list.extend(self.get_filenames(mask_path))
+        # Sample 50% of the chips, then split that subset 80/20
+        # into train and validation sets
+        random_inst = random.Random(12345)  # for repeatability
+        n_items = len(self.img_list)
+        print(f'Found {n_items} possible patches to use')
+        range_n_items = range(n_items)
+        range_n_items = random_inst.sample(range_n_items, int(n_items*0.5))
+        idxs = set(random_inst.sample(range_n_items, len(range_n_items) // 5))
+        total_idxs = set(range_n_items)
+        if split == 'train':
+            idxs = total_idxs - idxs
+        print(f'> Using {len(idxs)} patches for this dataset ({split})')
+        self.img_list = [self.img_list[i] for i in idxs]
+        self.mask_list = [self.mask_list[i] for i in idxs]
+        print(f'>> {split}: {len(self.img_list)}')
+
+    def __len__(self):
+        return len(self.img_list)
+
+    def __getitem__(self, idx, transpose=True):
+
+        # load image
+        img = np.load(self.img_list[idx])
+
+        img = np.clip(img, 0, 1.0)
+
+        # load mask
+        mask = np.load(self.mask_list[idx])
+
+        mask = np.argmax(mask, axis=-1)
+
+        mask = mask - 1
+
+        # perform transformations
+        img = self.transform(img)
+
+        return img, mask
+
+    def get_filenames(self, path):
+        """
+        Returns a list of absolute paths to images inside given `path`
+        """
+        files_list = []
+        for filename in sorted(os.listdir(path)):
+            files_list.append(os.path.join(path, filename))
+        return files_list
diff --git a/pytorch_caney/data/datasets/modis_lc_nine_dataset.py b/pytorch_caney/data/datasets/modis_lc_nine_dataset.py
new file mode 100644
index 0000000..12bde11
--- /dev/null
+++ b/pytorch_caney/data/datasets/modis_lc_nine_dataset.py
@@ -0,0 +1,77 @@
+import os
+import random
+
+import numpy as np
+
+from torch.utils.data import Dataset
+
+
+class MODISLCNineDataset(Dataset):
+
+    IMAGE_PATH = os.path.join("images")
+    MASK_PATH = os.path.join("labels")
+
+    def __init__(
+        self,
+        data_paths: list,
+        split: str,
+        img_size: tuple = (224, 224),
+        transform=None,
+    ):
+        self.img_size = img_size
+        self.transform = transform
+        self.split = split
+        self.data_paths = data_paths
+        self.img_list = []
+        self.mask_list = []
+        for data_path in data_paths:
+            img_path = os.path.join(data_path, self.IMAGE_PATH)
+            mask_path = os.path.join(data_path, self.MASK_PATH)
+            self.img_list.extend(self.get_filenames(img_path))
+            self.mask_list.extend(self.get_filenames(mask_path))
+        # Sample 50% of the chips, then split that subset 80/20
+        # into train and validation sets
+        random_inst = random.Random(12345)  # for repeatability
+        n_items = len(self.img_list)
+        print(f'Found {n_items} possible patches to use')
+        range_n_items = range(n_items)
+        range_n_items = random_inst.sample(range_n_items, int(n_items*0.5))
+        idxs = set(random_inst.sample(range_n_items, len(range_n_items) // 5))
+        total_idxs = set(range_n_items)
+        if split == 'train':
+            idxs = total_idxs - idxs
+        print(f'> Using {len(idxs)} patches for this dataset ({split})')
+        self.img_list = [self.img_list[i] for i in idxs]
+        self.mask_list = [self.mask_list[i] for i in idxs]
+        print(f'>> {split}: {len(self.img_list)}')
+
+    def __len__(self):
+        return len(self.img_list)
+
+    def __getitem__(self, idx, transpose=True):
+
+        # load image
+        img = np.load(self.img_list[idx])
+
+        img = np.clip(img, 0, 1.0)
+
+        # load mask
+        mask = np.load(self.mask_list[idx])
+
+        mask = np.argmax(mask, axis=-1)
+
+        mask = mask - 1
+
+        # perform transformations
+        img = self.transform(img)
+
+        return img, mask
+
+    def get_filenames(self, path):
+        """
+        Returns a list of absolute paths to images inside given `path`
+        """
+        files_list = []
+        for filename in sorted(os.listdir(path)):
+            files_list.append(os.path.join(path, filename))
+        return files_list
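The split logic in both datasets above is easy to misread, so here is the arithmetic in isolation: 50% of all chips are sampled with a fixed seed, one fifth of that sample becomes validation, and the remaining four fifths become train:

```python
# Hedged sketch of the 50% sample + 80/20 split used by the LC datasets.
import random

n_items = 1000
random_inst = random.Random(12345)  # fixed seed => identical split each run
subset = random_inst.sample(range(n_items), int(n_items * 0.5))  # 500 chips
val_idxs = set(random_inst.sample(subset, len(subset) // 5))     # 100 chips
train_idxs = set(subset) - val_idxs                              # 400 chips
print(len(train_idxs), len(val_idxs))  # 400 100
```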
diff --git a/pytorch_caney/data/datasets/simmim_modis_dataset.py b/pytorch_caney/data/datasets/simmim_modis_dataset.py
index cf0c091..c110f79 100644
--- a/pytorch_caney/data/datasets/simmim_modis_dataset.py
+++ b/pytorch_caney/data/datasets/simmim_modis_dataset.py
@@ -43,11 +43,7 @@ def __init__(
 
         if config.MODEL.TYPE in ['swin', 'swinv2']:
 
-            model_patch_size = config.MODEL.SWIN.PATCH_SIZE
-
-        elif config.MODEL.TYPE == 'vit':
-
-            model_patch_size = config.MODEL.VIT.PATCH_SIZE
+            model_patch_size = config.MODEL.SWINV2.PATCH_SIZE
 
         else:
diff --git a/pytorch_caney/data/transforms.py b/pytorch_caney/data/transforms.py
index e1207e4..4fb88e7 100644
--- a/pytorch_caney/data/transforms.py
+++ b/pytorch_caney/data/transforms.py
@@ -22,11 +22,7 @@ def __init__(self, config):
 
         if config.MODEL.TYPE in ['swin', 'swinv2']:
 
-            model_patch_size = config.MODEL.SWIN.PATCH_SIZE
-
-        elif config.MODEL.TYPE == 'vit':
-
-            model_patch_size = config.MODEL.VIT.PATCH_SIZE
+            model_patch_size = config.MODEL.SWINV2.PATCH_SIZE
 
         else:
 
@@ -46,3 +42,24 @@ def __call__(self, img):
         mask = self.mask_generator()
 
         return img, mask
+
+
+class TensorResizeTransform:
+    """
+    torchvision transform which converts the input imagery to a tensor
+    and resizes it to the configured image size (no MiM mask is generated).
+    """
+
+    def __init__(self, config):
+
+        self.transform_img = \
+            T.Compose([
+                T.ToTensor(),
+                T.Resize((config.DATA.IMG_SIZE, config.DATA.IMG_SIZE)),
+            ])
+
+    def __call__(self, img):
+
+        img = self.transform_img(img)
+
+        return img
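`TensorResizeTransform` above is the entire fine-tuning preprocessing path. A minimal sketch of what it does to one chip; note that `ToTensor()` only rescales uint8 inputs, so float chips are expected to already be in [0, 1] (hence the `np.clip` in the datasets):

```python
# Hedged sketch: TensorResizeTransform applied to one 7-band chip.
import numpy as np
import torchvision.transforms as T

transform = T.Compose([
    T.ToTensor(),               # (H, W, C) ndarray -> (C, H, W) tensor
    T.Resize((224, 224)),       # DATA.IMG_SIZE in the fine-tuning configs
])

chip = np.random.rand(128, 128, 7).astype(np.float32)  # (H, W, bands)
tensor = transform(chip)
print(tensor.shape)  # torch.Size([7, 224, 224])
```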
diff --git a/pytorch_caney/loss/build.py b/pytorch_caney/loss/build.py
new file mode 100644
index 0000000..abcb9c6
--- /dev/null
+++ b/pytorch_caney/loss/build.py
@@ -0,0 +1,43 @@
+from segmentation_models_pytorch.losses import TverskyLoss
+
+
+LOSSES = {
+    'tversky': TverskyLoss,
+}
+
+
+def get_loss_from_dict(loss_name, config):
+
+    try:
+
+        loss_to_use = LOSSES[loss_name]
+
+    except KeyError:
+
+        error_msg = f"{loss_name} is not an implemented loss"
+
+        error_msg = f"{error_msg}. Available loss functions: {LOSSES.keys()}"
+
+        raise KeyError(error_msg)
+
+    if loss_name == 'tversky':
+        loss = loss_to_use(mode=config.LOSS.MODE,
+                           classes=config.LOSS.CLASSES,
+                           log_loss=config.LOSS.LOG,
+                           from_logits=config.LOSS.LOGITS,
+                           smooth=config.LOSS.SMOOTH,
+                           ignore_index=config.LOSS.IGNORE_INDEX,
+                           eps=config.LOSS.EPS,
+                           alpha=config.LOSS.ALPHA,
+                           beta=config.LOSS.BETA,
+                           gamma=config.LOSS.GAMMA)
+        return loss
+
+
+def build_loss(config):
+
+    loss_name = config.LOSS.NAME
+
+    loss_to_use = get_loss_from_dict(loss_name, config)
+
+    return loss_to_use
diff --git a/pytorch_caney/loss/utils.py b/pytorch_caney/loss/utils.py
new file mode 100755
index 0000000..4319803
--- /dev/null
+++ b/pytorch_caney/loss/utils.py
@@ -0,0 +1,26 @@
+import numpy as np
+
+import torch
+
+
+# ---
+# Adapted from
+# https://github.com/qubvel/segmentation_models.pytorch \
+# /tree/master/segmentation_models_pytorch/losses
+# ---
+def to_tensor(x, dtype=None) -> torch.Tensor:
+    if isinstance(x, torch.Tensor):
+        if dtype is not None:
+            x = x.type(dtype)
+        return x
+    if isinstance(x, np.ndarray):
+        x = torch.from_numpy(x)
+        if dtype is not None:
+            x = x.type(dtype)
+        return x
+    if isinstance(x, (list, tuple)):
+        x = np.array(x)
+        x = torch.from_numpy(x)
+        if dtype is not None:
+            x = x.type(dtype)
+        return x
diff --git a/pytorch_caney/losses.py b/pytorch_caney/losses.py
deleted file mode 100755
index a361a2c..0000000
--- a/pytorch_caney/losses.py
+++ /dev/null
@@ -1,5 +0,0 @@
-
-
-__author__ = "Jordan A Caraballo-Vega, Science Data Processing Branch"
-__email__ = "jordan.a.caraballo-vega@nasa.gov"
-__status__ = "Production"
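`build_loss()` above follows a small registry pattern: a name-to-class dict plus per-loss keyword wiring, so adding a loss means one dict entry and one construction branch. A sketch of extending it; the `'dice'` entry is hypothetical and not part of this patch:

```python
# Hedged sketch of the LOSSES registry pattern; 'dice' is illustrative only.
from segmentation_models_pytorch.losses import DiceLoss, TverskyLoss

LOSSES = {
    'tversky': TverskyLoss,
    'dice': DiceLoss,  # hypothetical extra entry, not added by this patch
}


def build_loss_sketch(name: str, mode: str = 'multiclass'):
    try:
        loss_cls = LOSSES[name]
    except KeyError:
        raise KeyError(f'{name} is not an implemented loss. '
                       f'Available: {list(LOSSES)}')
    return loss_cls(mode=mode)


print(build_loss_sketch('dice'))
```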
diff --git a/pytorch_caney/models/build.py b/pytorch_caney/models/build.py
new file mode 100644
index 0000000..2094df7
--- /dev/null
+++ b/pytorch_caney/models/build.py
@@ -0,0 +1,64 @@
+from .swinv2_model import SwinTransformerV2
+from .unet_swin_model import unet_swin
+
+from .simmim.simmim import build_mim_model
+
+from ..training.simmim_utils import load_pretrained
+
+import logging
+
+
+def build_model(config,
+                pretrain: bool = False,
+                pretrain_method: str = 'mim',
+                logger: logging.Logger = None):
+
+    if pretrain:
+
+        if pretrain_method == 'mim':
+            model = build_mim_model(config)
+            return model
+
+    encoder_architecture = config.MODEL.TYPE
+    decoder_architecture = config.MODEL.DECODER
+
+    if encoder_architecture == 'swinv2':
+
+        logger.info('Building SwinTransformerV2 encoder')
+
+        window_sizes = config.MODEL.SWINV2.PRETRAINED_WINDOW_SIZES
+
+        model = SwinTransformerV2(
+            img_size=config.DATA.IMG_SIZE,
+            patch_size=config.MODEL.SWINV2.PATCH_SIZE,
+            in_chans=config.MODEL.SWINV2.IN_CHANS,
+            num_classes=config.MODEL.NUM_CLASSES,
+            embed_dim=config.MODEL.SWINV2.EMBED_DIM,
+            depths=config.MODEL.SWINV2.DEPTHS,
+            num_heads=config.MODEL.SWINV2.NUM_HEADS,
+            window_size=config.MODEL.SWINV2.WINDOW_SIZE,
+            mlp_ratio=config.MODEL.SWINV2.MLP_RATIO,
+            qkv_bias=config.MODEL.SWINV2.QKV_BIAS,
+            drop_rate=config.MODEL.DROP_RATE,
+            drop_path_rate=config.MODEL.DROP_PATH_RATE,
+            ape=config.MODEL.SWINV2.APE,
+            patch_norm=config.MODEL.SWINV2.PATCH_NORM,
+            use_checkpoint=config.TRAIN.USE_CHECKPOINT,
+            pretrained_window_sizes=window_sizes)
+
+    if decoder_architecture == 'unet':
+
+        num_classes = config.MODEL.NUM_CLASSES
+
+        if config.MODEL.PRETRAINED and (not config.MODEL.RESUME):
+            load_pretrained(config, model, logger)
+
+        model = unet_swin(encoder=model, num_classes=num_classes)
+
+    elif decoder_architecture is not None:
+
+        error_msg = f'Unknown decoder architecture: {decoder_architecture}'
+
+        raise NotImplementedError(error_msg)
+
+    return model
diff --git a/pytorch_caney/models/unet_swin_model.py b/pytorch_caney/models/unet_swin_model.py
index 8e29ede..c578156 100644
--- a/pytorch_caney/models/unet_swin_model.py
+++ b/pytorch_caney/models/unet_swin_model.py
@@ -31,9 +31,9 @@ def __init__(self, encoder, num_classes=9):
             kernel_size=self.KERNEL_SIZE,
             upsampling=self.UPSAMPLING)
 
-    def forward(self, input):
-        encoder_featrue = self.encoder.get_unet_feature(input)
-        decoder_output = self.decoder(*encoder_featrue)
+    def forward(self, x):
+        encoder_feature = self.encoder.get_unet_feature(x)
+        decoder_output = self.decoder(*encoder_feature)
         masks = self.segmentation_head(decoder_output)
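The control flow in `build_model()` composes an encoder with a decoder wrapper that consumes the encoder's multi-scale features. A toy sketch of that composition pattern; every class name here is an illustrative stand-in, not the pytorch-caney API:

```python
# Hedged sketch of the encoder -> decoder-wrapper composition; names are
# hypothetical stand-ins for SwinTransformerV2 / unet_swin.
import torch
import torch.nn as nn


class TinyEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.stem = nn.Conv2d(7, 16, 4, stride=4)

    def features(self, x):  # stand-in for get_unet_feature()
        return [self.stem(x)]


class TinyDecoderWrapper(nn.Module):  # stand-in for unet_swin()
    def __init__(self, encoder, num_classes):
        super().__init__()
        self.encoder = encoder
        self.head = nn.Sequential(
            nn.Upsample(scale_factor=4), nn.Conv2d(16, num_classes, 1))

    def forward(self, x):
        feats = self.encoder.features(x)
        return self.head(feats[-1])


model = TinyDecoderWrapper(TinyEncoder(), num_classes=5)
print(model(torch.randn(2, 7, 224, 224)).shape)  # torch.Size([2, 5, 224, 224])
```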
diff --git a/pytorch_caney/pipelines/finetuning/finetune.py b/pytorch_caney/pipelines/finetuning/finetune.py
new file mode 100644
index 0000000..e9ca259
--- /dev/null
+++ b/pytorch_caney/pipelines/finetuning/finetune.py
@@ -0,0 +1,454 @@
+from pytorch_caney.models.build import build_model
+
+from pytorch_caney.data.datamodules.finetune_datamodule \
+    import build_finetune_dataloaders
+
+from pytorch_caney.training.simmim_utils \
+    import build_optimizer, save_checkpoint, reduce_tensor
+
+from pytorch_caney.config import get_config
+from pytorch_caney.loss.build import build_loss
+from pytorch_caney.lr_scheduler import build_scheduler, setup_scaled_lr
+from pytorch_caney.logging import create_logger
+from pytorch_caney.training.simmim_utils import get_grad_norm
+
+import argparse
+import datetime
+import joblib
+import numpy as np
+import os
+import time
+
+import torch
+import torch.cuda.amp as amp
+import torch.backends.cudnn as cudnn
+import torch.distributed as dist
+
+from timm.utils import AverageMeter
+
+
+def parse_args():
+    """
+    Parse command-line arguments
+    """
+
+    parser = argparse.ArgumentParser(
+        'pytorch-caney finetuning',
+        add_help=False)
+
+    parser.add_argument(
+        '--cfg',
+        type=str,
+        required=True,
+        metavar="FILE",
+        help='path to config file')
+
+    parser.add_argument(
+        "--data-paths",
+        nargs='+',
+        required=True,
+        help="paths where dataset is stored")
+
+    parser.add_argument(
+        '--dataset',
+        type=str,
+        required=True,
+        help='Dataset to use')
+
+    parser.add_argument(
+        '--pretrained',
+        type=str,
+        help='path to pre-trained model')
+
+    parser.add_argument(
+        '--batch-size',
+        type=int,
+        help="batch size for single GPU")
+
+    parser.add_argument(
+        '--resume',
+        help='resume from checkpoint')
+
+    parser.add_argument(
+        '--accumulation-steps',
+        type=int,
+        help="gradient accumulation steps")
+
+    parser.add_argument(
+        '--use-checkpoint',
+        action='store_true',
+        help="whether to use gradient checkpointing to save memory")
+
+    parser.add_argument(
+        '--enable-amp',
+        action='store_true')
+
+    parser.add_argument(
+        '--disable-amp',
+        action='store_false',
+        dest='enable_amp')
+
+    parser.set_defaults(enable_amp=True)
+
+    parser.add_argument(
+        '--output',
+        default='output',
+        type=str,
+        metavar='PATH',
+        help='root of output folder, the full path is '
+             '<output>/<model_name>/<tag> (default: output)')
+
+    parser.add_argument(
+        '--tag',
+        help='tag of experiment')
+
+    args = parser.parse_args()
+
+    config = get_config(args)
+
+    return args, config
+
+
+def train(config,
+          dataloader_train,
+          dataloader_val,
+          model,
+          model_wo_ddp,
+          optimizer,
+          lr_scheduler,
+          scaler,
+          criterion):
+    """
+    Start fine-tuning a specific model and dataset.
+
+    Args:
+        config: config object
+        dataloader_train: training pytorch dataloader
+        dataloader_val: validation pytorch dataloader
+        model: model to fine-tune
+        model_wo_ddp: model to fine-tune that is not the DDP version
+        optimizer: pytorch optimizer
+        lr_scheduler: learning-rate scheduler
+        scaler: loss scaler
+        criterion: loss function to use for fine-tuning
+    """
+
+    logger.info("Start fine-tuning")
+
+    start_time = time.time()
+
+    for epoch in range(config.TRAIN.START_EPOCH, config.TRAIN.EPOCHS):
+
+        dataloader_train.sampler.set_epoch(epoch)
+
+        execute_one_epoch(config, model, dataloader_train,
+                          optimizer, criterion, epoch, lr_scheduler, scaler)
+
+        loss = validate(config, model, dataloader_val, criterion)
+
+        logger.info(f'Model validation loss: {loss:.3f}')
+
+        if dist.get_rank() == 0 and \
+                (epoch % config.SAVE_FREQ == 0 or
+                 epoch == (config.TRAIN.EPOCHS - 1)):
+
+            save_checkpoint(config, epoch, model_wo_ddp, 0.,
+                            optimizer, lr_scheduler, scaler, logger)
+
+    total_time = time.time() - start_time
+
+    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
+
+    logger.info('Training time {}'.format(total_time_str))
+
+
+def execute_one_epoch(config,
+                      model,
+                      dataloader,
+                      optimizer,
+                      criterion,
+                      epoch,
+                      lr_scheduler,
+                      scaler):
+    """
+    Execute training iterations on a single epoch.
+
+    Args:
+        config: config object
+        model: model to fine-tune
+        dataloader: dataloader to use
+        optimizer: pytorch optimizer
+        criterion: loss function to optimize
+        epoch: int epoch number
+        lr_scheduler: learning-rate scheduler
+        scaler: loss scaler
+    """
+    model.train()
+
+    optimizer.zero_grad()
+
+    num_steps = len(dataloader)
+
+    # Set up logging meters
+    batch_time = AverageMeter()
+    data_time = AverageMeter()
+    loss_meter = AverageMeter()
+    norm_meter = AverageMeter()
+    loss_scale_meter = AverageMeter()
+
+    start = time.time()
+    end = time.time()
+    for idx, (samples, targets) in enumerate(dataloader):
+
+        data_time.update(time.time() - end)
+
+        samples = samples.cuda(non_blocking=True)
+        targets = targets.cuda(non_blocking=True)
+
+        with amp.autocast(enabled=config.ENABLE_AMP):
+            logits = model(samples)
+
+        if config.TRAIN.ACCUMULATION_STEPS > 1:
+            loss = criterion(logits, targets)
+            loss = loss / config.TRAIN.ACCUMULATION_STEPS
+            scaler.scale(loss).backward()
+            if config.TRAIN.CLIP_GRAD:
+                scaler.unscale_(optimizer)
+                grad_norm = torch.nn.utils.clip_grad_norm_(
+                    model.parameters(),
+                    config.TRAIN.CLIP_GRAD)
+            else:
+                grad_norm = get_grad_norm(model.parameters())
+            if (idx + 1) % config.TRAIN.ACCUMULATION_STEPS == 0:
+                scaler.step(optimizer)
+                optimizer.zero_grad()
+                scaler.update()
+                lr_scheduler.step_update(epoch * num_steps + idx)
+        else:
+            loss = criterion(logits, targets)
+            optimizer.zero_grad()
+            scaler.scale(loss).backward()
+            if config.TRAIN.CLIP_GRAD:
+                scaler.unscale_(optimizer)
+                grad_norm = torch.nn.utils.clip_grad_norm_(
+                    model.parameters(),
+                    config.TRAIN.CLIP_GRAD)
+            else:
+                grad_norm = get_grad_norm(model.parameters())
+            scaler.step(optimizer)
+            scaler.update()
+            lr_scheduler.step_update(epoch * num_steps + idx)
+
+        torch.cuda.synchronize()
+
+        loss_meter.update(loss.item(), targets.size(0))
+        norm_meter.update(grad_norm)
+        loss_scale_meter.update(scaler.get_scale())
+        batch_time.update(time.time() - end)
+        end = time.time()
+
+        if idx % config.PRINT_FREQ == 0:
+            lr = optimizer.param_groups[0]['lr']
+            memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0)
+            etas = batch_time.avg * (num_steps - idx)
+            logger.info(
+                f'Train: [{epoch}/{config.TRAIN.EPOCHS}][{idx}/{num_steps}]\t'
+                f'eta {datetime.timedelta(seconds=int(etas))} lr {lr:.6f}\t'
+                f'time {batch_time.val:.4f} ({batch_time.avg:.4f})\t'
+                f'data_time {data_time.val:.4f} ({data_time.avg:.4f})\t'
+                f'loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\t'
+                f'grad_norm {norm_meter.val:.4f} ({norm_meter.avg:.4f})\t'
+                f'loss_scale {loss_scale_meter.val:.4f}'
+                f' ({loss_scale_meter.avg:.4f})\t'
+                f'mem {memory_used:.0f}MB')
+
+    epoch_time = time.time() - start
+    logger.info(
+        f"EPOCH {epoch} training takes "
+        f"{datetime.timedelta(seconds=int(epoch_time))}")
+
+
+@torch.no_grad()
+def validate(config, model, dataloader, criterion):
+    """Validation function which given a model and validation loader
+    performs a validation run and returns the average loss according
+    to the criterion.
+
+    Args:
+        config: config object
+        model: pytorch model to validate
+        dataloader: pytorch validation loader
+        criterion: pytorch-friendly loss function
+
+    Returns:
+        loss_meter.avg: average of the loss through the validation
+        iterations
+    """
+
+    model.eval()
+
+    batch_time = AverageMeter()
+
+    loss_meter = AverageMeter()
+
+    end = time.time()
+
+    for idx, (images, target) in enumerate(dataloader):
+
+        images = images.cuda(non_blocking=True)
+
+        target = target.cuda(non_blocking=True)
+
+        # compute output
+        output = model(images)
+
+        # measure accuracy and record loss
+        loss = criterion(output, target.long())
+
+        loss = reduce_tensor(loss)
+
+        loss_meter.update(loss.item(), target.size(0))
+
+        # measure elapsed time
+        batch_time.update(time.time() - end)
+
+        end = time.time()
+
+        if idx % config.PRINT_FREQ == 0:
+
+            memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0)
+
+            logger.info(
+                f'Test: [{idx}/{len(dataloader)}]\t'
+                f'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
+                f'Loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\t'
+                f'Mem {memory_used:.0f}MB')
+
+    return loss_meter.avg
+
+
+def main(config):
+    """
+    Performs the main function of building model, loader, etc. and starts
+    training.
+    """
+
+    dataloader_train, dataloader_val = build_finetune_dataloaders(
+        config, logger)
+
+    model = build_finetune_model(config, logger)
+
+    optimizer = build_optimizer(config,
+                                model,
+                                is_pretrain=False,
+                                logger=logger)
+
+    model, model_wo_ddp = make_ddp(model)
+
+    n_iter_per_epoch = len(dataloader_train)
+
+    lr_scheduler = build_scheduler(config, optimizer, n_iter_per_epoch)
+
+    scaler = amp.GradScaler()
+
+    criterion = build_loss(config)
+
+    train(config,
+          dataloader_train,
+          dataloader_val,
+          model,
+          model_wo_ddp,
+          optimizer,
+          lr_scheduler,
+          scaler,
+          criterion)
+
+
+def build_finetune_model(config, logger):
+
+    logger.info(f"Creating model:{config.MODEL.TYPE}/{config.MODEL.NAME}")
+
+    model = build_model(config,
+                        pretrain=False,
+                        pretrain_method='mim',
+                        logger=logger)
+
+    model.cuda()
+
+    logger.info(str(model))
+
+    return model
+
+
+def make_ddp(model):
+
+    model = torch.nn.parallel.DistributedDataParallel(
+        model,
+        device_ids=[int(os.environ["LOCAL_RANK"])],
+        broadcast_buffers=False,
+        find_unused_parameters=True)
+
+    model_without_ddp = model.module
+
+    return model, model_without_ddp
+
+
+def setup_rank_worldsize():
+    if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
+        rank = int(os.environ["RANK"])
+        world_size = int(os.environ['WORLD_SIZE'])
+        print(f"RANK and WORLD_SIZE in environ: {rank}/{world_size}")
+    else:
+        rank = -1
+        world_size = -1
+    return rank, world_size
+
+
+def setup_distributed_processing(rank, world_size):
+    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
+    torch.distributed.init_process_group(
+        backend='nccl', init_method='env://', world_size=world_size, rank=rank)
+    torch.distributed.barrier()
+
+
+def setup_seeding(config):
+    seed = config.SEED + dist.get_rank()
+    torch.manual_seed(seed)
+    np.random.seed(seed)
+
+
+if __name__ == '__main__':
+    _, config = parse_args()
+
+    rank, world_size = setup_rank_worldsize()
+
+    setup_distributed_processing(rank, world_size)
+
+    setup_seeding(config)
+
+    cudnn.benchmark = True
+
+    linear_scaled_lr, linear_scaled_min_lr, linear_scaled_warmup_lr = \
+        setup_scaled_lr(config)
+
+    config.defrost()
+    config.TRAIN.BASE_LR = linear_scaled_lr
+    config.TRAIN.WARMUP_LR = linear_scaled_warmup_lr
+    config.TRAIN.MIN_LR = linear_scaled_min_lr
+    config.freeze()
+
+    os.makedirs(config.OUTPUT, exist_ok=True)
+    logger = create_logger(output_dir=config.OUTPUT,
+                           dist_rank=dist.get_rank(),
+                           name=f"{config.MODEL.NAME}")
+
+    if dist.get_rank() == 0:
+        path = os.path.join(config.OUTPUT, "config.json")
+        with open(path, "w") as f:
+            f.write(config.dump())
+        logger.info(f"Full config saved to {path}")
+        logger.info(config.dump())
+        config_file_name = f'{config.TAG}.config.sav'
+        config_file_path = os.path.join(config.OUTPUT, config_file_name)
+        joblib.dump(config, config_file_path)
+
+    main(config)
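`execute_one_epoch()` above interleaves AMP loss scaling with gradient accumulation. The core pattern in isolation (requires a CUDA device; the model and data here are placeholders, not the pipeline's):

```python
# Hedged sketch of the AMP + gradient-accumulation pattern in finetune.py:
# losses are divided by the accumulation count, gradients accumulate across
# micro-batches, and the optimizer steps once per accumulation window.
import torch
import torch.cuda.amp as amp

model = torch.nn.Linear(8, 2).cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
scaler = amp.GradScaler()
criterion = torch.nn.CrossEntropyLoss()
accumulation_steps = 2

for idx in range(4):
    x = torch.randn(16, 8, device='cuda')
    y = torch.randint(0, 2, (16,), device='cuda')
    with amp.autocast(enabled=True):
        loss = criterion(model(x), y) / accumulation_steps
    scaler.scale(loss).backward()      # grads accumulate across micro-batches
    if (idx + 1) % accumulation_steps == 0:
        scaler.step(optimizer)         # one optimizer step per 2 micro-batches
        scaler.update()
        optimizer.zero_grad()
```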
+ """ + + dataloader_train, dataloader_val = build_finetune_dataloaders( + config, logger) + + model = build_finetune_model(config, logger) + + optimizer = build_optimizer(config, + model, + is_pretrain=False, + logger=logger) + + model, model_wo_ddp = make_ddp(model) + + n_iter_per_epoch = len(dataloader_train) + + lr_scheduler = build_scheduler(config, optimizer, n_iter_per_epoch) + + scaler = amp.GradScaler() + + criterion = build_loss(config) + + train(config, + dataloader_train, + dataloader_val, + model, + model_wo_ddp, + optimizer, + lr_scheduler, + scaler, + criterion) + + +def build_finetune_model(config, logger): + + logger.info(f"Creating model:{config.MODEL.TYPE}/{config.MODEL.NAME}") + + model = build_model(config, + pretrain=False, + pretrain_method='mim', + logger=logger) + + model.cuda() + + logger.info(str(model)) + + return model + + +def make_ddp(model): + + model = torch.nn.parallel.DistributedDataParallel( + model, + device_ids=[int(os.environ["RANK"])], + broadcast_buffers=False, + find_unused_parameters=True) + + model_without_ddp = model.module + + return model, model_without_ddp + + +def setup_rank_worldsize(): + if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: + rank = int(os.environ["RANK"]) + world_size = int(os.environ['WORLD_SIZE']) + print(f"RANK and WORLD_SIZE in environ: {rank}/{world_size}") + else: + rank = -1 + world_size = -1 + return rank, world_size + + +def setup_distributed_processing(rank, world_size): + torch.cuda.set_device(int(os.environ["RANK"])) + torch.distributed.init_process_group( + backend='nccl', init_method='env://', world_size=world_size, rank=rank) + torch.distributed.barrier() + + +def setup_seeding(config): + seed = config.SEED + dist.get_rank() + torch.manual_seed(seed) + np.random.seed(seed) + + +if __name__ == '__main__': + _, config = parse_args() + + rank, world_size = setup_rank_worldsize() + + setup_distributed_processing(rank, world_size) + + setup_seeding(config) + + cudnn.benchmark = True + + linear_scaled_lr, linear_scaled_min_lr, linear_scaled_warmup_lr = \ + setup_scaled_lr(config) + + config.defrost() + config.TRAIN.BASE_LR = linear_scaled_lr + config.TRAIN.WARMUP_LR = linear_scaled_warmup_lr + config.TRAIN.MIN_LR = linear_scaled_min_lr + config.freeze() + + os.makedirs(config.OUTPUT, exist_ok=True) + logger = create_logger(output_dir=config.OUTPUT, + dist_rank=dist.get_rank(), + name=f"{config.MODEL.NAME}") + + if dist.get_rank() == 0: + path = os.path.join(config.OUTPUT, "config.json") + with open(path, "w") as f: + f.write(config.dump()) + logger.info(f"Full config saved to {path}") + logger.info(config.dump()) + config_file_name = f'{config.TAG}.config.sav' + config_file_path = os.path.join(config.OUTPUT, config_file_name) + joblib.dump(config, config_file_path) + + main(config) diff --git a/pytorch_caney/pipelines/pretraining/simmim.py b/pytorch_caney/pipelines/pretraining/mim.py similarity index 91% rename from pytorch_caney/pipelines/pretraining/simmim.py rename to pytorch_caney/pipelines/pretraining/mim.py index c8f0203..0f40089 100644 --- a/pytorch_caney/pipelines/pretraining/simmim.py +++ b/pytorch_caney/pipelines/pretraining/mim.py @@ -28,8 +28,11 @@ def parse_args(): + """ + Parse command-line arguments + """ parser = argparse.ArgumentParser( - 'pytorch-caney impletmentation of SimMiM pre-training script', + 'pytorch-caney implementation of MiM pre-training script', add_help=False) parser.add_argument( @@ -107,6 +110,18 @@ def train(config, optimizer, lr_scheduler, scaler): + """ + Start pre-training 
diff --git a/requirements/Dockerfile b/requirements/Dockerfile
index 4def6c7..d68964b 100644
--- a/requirements/Dockerfile
+++ b/requirements/Dockerfile
@@ -108,6 +108,7 @@ RUN pip --no-cache-dir install omegaconf \
     sphinx_rtd_theme \
     yacs \
     termcolor \
+    segmentation-models-pytorch \
    GDAL==`ogrinfo --version | grep -Eo '[0-9]\.[0-9]\.[0-9]+'`
 
 HEALTHCHECK NONE
diff --git a/requirements/Dockerfile.dev b/requirements/Dockerfile.dev
index 4def6c7..da66593 100644
--- a/requirements/Dockerfile.dev
+++ b/requirements/Dockerfile.dev
@@ -108,6 +108,7 @@ RUN pip --no-cache-dir install omegaconf \
     sphinx_rtd_theme \
     yacs \
     termcolor \
+    segmentation-models-pytorch \
    GDAL==`ogrinfo --version | grep -Eo '[0-9]\.[0-9]\.[0-9]+'`
 
 HEALTHCHECK NONE
diff --git a/requirements/requirements.txt b/requirements/requirements.txt
index 57ca9c4..08c3785 100644
--- a/requirements/requirements.txt
+++ b/requirements/requirements.txt
@@ -22,4 +22,5 @@ sphinx
 sphinx_rtd_theme
 yacs
 termcolor
+segmentation-models-pytorch
 GDAL==`ogrinfo --version | grep -Eo '[0-9]\.[0-9]\.[0-9]+'`
\ No newline at end of file