Commit ffd776a

Merge pull request #22 from nasa-nccs-hpda/finetuning

Added initial finetuning

cssprad1 authored Aug 16, 2023
2 parents 120988a + e70acdd commit ffd776a

Showing 23 changed files with 1,034 additions and 81 deletions.
11 changes: 9 additions & 2 deletions README.md
@@ -51,6 +51,12 @@ Please see our [guide for contributing to pytorch-caney](CONTRIBUTING.md).
| :---: | :---: | :---: | :---: |
| SatVision-B | MODIS-1.9-M | 192x192 | 84.5M |

## SatVision Datasets

| name | bands | resolution | #chips |
| :---: | :---: | :---: | :---: |
| MODIS-Small | 7 | 128x128 | 1,994,131 |

## Pre-training with Masked Image Modeling
To pre-train the SwinV2 base model with masked image modeling, run:
```bash
@@ -61,13 +67,14 @@ For example to run on a compute node with 4 GPUs and a batch size of 128 on the

```bash
singularity shell --nv -B <mounts> /path/to/container/pytorch-caney
-Singularity> torchrun --nproc_per_node 4 --cfg examples/satvision/mim_pretrain_swinv2_satvision_base_192_window12_800ep.yaml --dataset MODIS --data-paths /explore/nobackup/projects/ilab/data/satvision/pretraining/training_* --batch-size 128 --output . --enable-amp
+Singularity> export PYTHONPATH=$PWD:$PWD/pytorch-caney
+Singularity> torchrun --nproc_per_node 4 pytorch-caney/pytorch_caney/pipelines/pretraining/mim.py --cfg pytorch-caney/examples/satvision/mim_pretrain_swinv2_satvision_base_192_window12_800ep.yaml --dataset MODIS --data-paths /explore/nobackup/projects/ilab/data/satvision/pretraining/training_* --batch-size 128 --output . --enable-amp
```

This example script runs the exact configuration used to pre-train the SatVision-base model with MiM on the MODIS pre-training dataset.
```bash
singularity shell --nv -B <mounts> /path/to/container/pytorch-caney
-Singularity> cd examples/satvision
+Singularity> cd pytorch-caney/examples/satvision
Singularity> ./run_satvision_pretrain.sh
```

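As background for the README section above: masked image modeling (MiM) hides a random subset of image patches and trains the encoder to reconstruct them. Below is a minimal sketch of SimMIM-style mask generation, using plausible values matching the configs in this diff (192px pre-training inputs, 32px mask patches, 0.6 mask ratio); it is an illustration, not the pytorch-caney implementation:

```python
import numpy as np

def random_patch_mask(img_size=192, mask_patch_size=32, mask_ratio=0.6):
    """Binary mask over the patch grid; 1 = patch is hidden."""
    grid = img_size // mask_patch_size          # 192 // 32 = 6
    num_patches = grid * grid                   # 36 mask patches
    num_masked = int(mask_ratio * num_patches)  # 21 patches hidden
    mask = np.zeros(num_patches, dtype=np.int64)
    mask[np.random.permutation(num_patches)[:num_masked]] = 1
    return mask.reshape(grid, grid)
```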
33 changes: 33 additions & 0 deletions examples/satvision/finetune_satvision_base_landcover5class_192_window12_100ep.yaml
@@ -0,0 +1,33 @@
MODEL:
  TYPE: swinv2
  DECODER: unet
  NAME: satvision_finetune_lc5class
  DROP_PATH_RATE: 0.1
  NUM_CLASSES: 5
  SWINV2:
    IN_CHANS: 7
    EMBED_DIM: 128
    DEPTHS: [ 2, 2, 18, 2 ]
    NUM_HEADS: [ 4, 8, 16, 32 ]
    WINDOW_SIZE: 14
    PRETRAINED_WINDOW_SIZES: [ 12, 12, 12, 6 ]
DATA:
  IMG_SIZE: 224
  DATASET: MODISLC5
  MASK_PATCH_SIZE: 32
  MASK_RATIO: 0.6
LOSS:
  NAME: 'tversky'
  MODE: 'multiclass'
  ALPHA: 0.4
  BETA: 0.6
TRAIN:
  EPOCHS: 100
  WARMUP_EPOCHS: 10
  BASE_LR: 1e-4
  WARMUP_LR: 5e-7
  WEIGHT_DECAY: 0.01
  LAYER_DECAY: 0.8
PRINT_FREQ: 100
SAVE_FREQ: 5
TAG: satvision_finetune_land_cover_5class_swinv2_satvision_192_window12__800ep
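For reference, the Tversky loss configured above generalizes Dice by weighting false positives (ALPHA) and false negatives (BETA) separately; with ALPHA: 0.4 / BETA: 0.6 the model is penalized more for missed pixels than for false alarms. A minimal multiclass sketch of the idea; pytorch-caney's actual implementation may differ:

```python
import torch
import torch.nn.functional as F

def tversky_loss(logits, target, alpha=0.4, beta=0.6, eps=1e-7):
    """logits: (N, C, H, W) raw scores; target: (N, H, W) int64 labels."""
    num_classes = logits.shape[1]
    probs = logits.softmax(dim=1)
    onehot = F.one_hot(target, num_classes).permute(0, 3, 1, 2).float()
    dims = (0, 2, 3)                       # reduce over batch and pixels
    tp = (probs * onehot).sum(dims)        # soft true positives per class
    fp = (probs * (1 - onehot)).sum(dims)  # soft false positives
    fn = ((1 - probs) * onehot).sum(dims)  # soft false negatives
    tversky = tp / (tp + alpha * fp + beta * fn + eps)
    return 1.0 - tversky.mean()            # alpha = beta = 0.5 recovers Dice
```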
33 changes: 33 additions & 0 deletions examples/satvision/finetune_satvision_base_landcover9class_192_window12_100ep.yaml
@@ -0,0 +1,33 @@
MODEL:
  TYPE: swinv2
  DECODER: unet
  NAME: satvision_finetune_lc9class
  DROP_PATH_RATE: 0.1
  NUM_CLASSES: 9
  SWINV2:
    IN_CHANS: 7
    EMBED_DIM: 128
    DEPTHS: [ 2, 2, 18, 2 ]
    NUM_HEADS: [ 4, 8, 16, 32 ]
    WINDOW_SIZE: 14
    PRETRAINED_WINDOW_SIZES: [ 12, 12, 12, 6 ]
DATA:
  IMG_SIZE: 224
  DATASET: MODISLC9
  MASK_PATCH_SIZE: 32
  MASK_RATIO: 0.6
LOSS:
  NAME: 'tversky'
  MODE: 'multiclass'
  ALPHA: 0.4
  BETA: 0.6
TRAIN:
  EPOCHS: 100
  WARMUP_EPOCHS: 10
  BASE_LR: 1e-4
  WARMUP_LR: 5e-7
  WEIGHT_DECAY: 0.01
  LAYER_DECAY: 0.8
PRINT_FREQ: 100
SAVE_FREQ: 5
TAG: satvision_finetune_land_cover_9class_swinv2_satvision_192_window12__800ep
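Both configs also set LAYER_DECAY: 0.8, the SimMIM-style layer-wise learning-rate decay for fine-tuning: parameter groups nearest the head train at the full base LR, and each earlier group is scaled down geometrically. A sketch of the usual scaling rule (how parameters map to layer ids is model-specific and omitted here):

```python
def layerwise_lr_scales(num_layers, layer_decay=0.8):
    """LR multiplier for layer i is layer_decay ** (num_layers - i)."""
    return [layer_decay ** (num_layers - i) for i in range(num_layers + 1)]

# With 4 transformer stages: [0.41, 0.51, 0.64, 0.8, 1.0]; the earliest
# layers train slowest, the final group at the full BASE_LR.
print([round(s, 2) for s in layerwise_lr_scales(4)])
```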
20 changes: 20 additions & 0 deletions examples/satvision/run_satvision_finetune_lc_fiveclass.sh
@@ -0,0 +1,20 @@
#!/bin/bash

#SBATCH -J finetune_satvision_lc5
#SBATCH -t 3-00:00:00
#SBATCH -G 4
#SBATCH -N 1


export PYTHONPATH=$PWD:../../../:../../../pytorch-caney
export NGPUS=4

torchrun --nproc_per_node $NGPUS \
../../../pytorch-caney/pytorch_caney/pipelines/finetuning/finetune.py \
--cfg finetune_satvision_base_landcover5class_192_window12_100ep.yaml \
--pretrained /explore/nobackup/people/cssprad1/projects/satnet/code/development/masked_image_modeling/development/models/simmim_satnet_pretrain_pretrain/simmim_pretrain__satnet_swinv2_base__img192_window12__800ep_v3_no_norm/ckpt_epoch_800.pth \
--dataset MODISLC5 \
--data-paths /explore/nobackup/projects/ilab/data/satvision/finetuning/h18v04/labels_5classes_224 \
--batch-size 4 \
--output /explore/nobackup/people/cssprad1/projects/satnet/code/development/cleanup/finetune/models \
--enable-amp
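The --pretrained flag points fine-tuning at the 800-epoch MiM checkpoint. A hedged sketch of how such a checkpoint is typically restored into the encoder before a decoder head is attached; the key names and filtering here are illustrative, not pytorch-caney's exact logic:

```python
import torch

def load_pretrained_encoder(model, ckpt_path):
    ckpt = torch.load(ckpt_path, map_location='cpu')
    state = ckpt.get('model', ckpt)  # checkpoints often nest under 'model'
    # Keep encoder weights only; drop the MiM reconstruction head (the
    # 'decoder' prefix is an assumption for illustration).
    state = {k: v for k, v in state.items() if not k.startswith('decoder')}
    missing, unexpected = model.load_state_dict(state, strict=False)
    print(f'missing keys: {len(missing)}, unexpected keys: {len(unexpected)}')
```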
20 changes: 20 additions & 0 deletions examples/satvision/run_satvision_finetune_lc_nineclass.sh
@@ -0,0 +1,20 @@
#!/bin/bash

#SBATCH -J finetune_satvision_lc9
#SBATCH -t 3-00:00:00
#SBATCH -G 4
#SBATCH -N 1


export PYTHONPATH=$PWD:../../../:../../../pytorch-caney
export NGPUS=4

torchrun --nproc_per_node $NGPUS \
../../../pytorch-caney/pytorch_caney/pipelines/finetuning/finetune.py \
--cfg finetune_satvision_base_landcover9class_192_window12_100ep.yaml \
--pretrained /explore/nobackup/people/cssprad1/projects/satnet/code/development/masked_image_modeling/development/models/simmim_satnet_pretrain_pretrain/simmim_pretrain__satnet_swinv2_base__img192_window12__800ep_v3_no_norm/ckpt_epoch_800.pth \
--dataset MODISLC9 \
--data-paths /explore/nobackup/projects/ilab/data/satvision/finetuning/h18v04/labels_9classes_224 \
--batch-size 4 \
--output /explore/nobackup/people/cssprad1/projects/satnet/code/development/cleanup/finetune/models \
--enable-amp
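Both scripts launch through torchrun, which starts one process per GPU and sets RANK, LOCAL_RANK, and WORLD_SIZE in each process's environment. A minimal sketch of the standard initialization each worker performs before code like the datamodule below can call dist.get_world_size(); this is the common PyTorch pattern, not necessarily the exact pytorch-caney entrypoint:

```python
import os

import torch
import torch.distributed as dist

def init_distributed():
    local_rank = int(os.environ['LOCAL_RANK'])  # injected by torchrun
    torch.cuda.set_device(local_rank)           # one GPU per process
    dist.init_process_group(backend='nccl')     # rendezvous via env vars
    return local_rank
```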
2 changes: 1 addition & 1 deletion examples/satvision/run_satvision_pretrain.sh
@@ -10,7 +10,7 @@ export PYTHONPATH=$PWD:../../../:../../../pytorch-caney
export NGPUS=4

torchrun --nproc_per_node $NGPUS \
-../../../pytorch-caney/pytorch_caney/pipelines/pretraining/simmim.py \
+../../../pytorch-caney/pytorch_caney/pipelines/pretraining/mim.py \
--cfg mim_pretrain_swinv2_satvision_base_192_window12_800ep.yaml \
--dataset MODIS \
--data-paths /explore/nobackup/projects/ilab/data/satvision/pretraining/training_* \
79 changes: 20 additions & 59 deletions pytorch_caney/config.py
@@ -36,6 +36,8 @@
_C.MODEL = CN()
# Model type
_C.MODEL.TYPE = 'swinv2'
+# Decoder type
+_C.MODEL.DECODER = None
# Model name
_C.MODEL.NAME = 'swinv2_base_patch4_window7_224'
# Pretrained weight from checkpoint, could be from previous pre-training
@@ -50,20 +52,6 @@
# Drop path rate
_C.MODEL.DROP_PATH_RATE = 0.1

-# Swin Transformer parameters
-_C.MODEL.SWIN = CN()
-_C.MODEL.SWIN.PATCH_SIZE = 4
-_C.MODEL.SWIN.IN_CHANS = 4
-_C.MODEL.SWIN.EMBED_DIM = 96
-_C.MODEL.SWIN.DEPTHS = [2, 2, 6, 2]
-_C.MODEL.SWIN.NUM_HEADS = [3, 6, 12, 24]
-_C.MODEL.SWIN.WINDOW_SIZE = 7
-_C.MODEL.SWIN.MLP_RATIO = 4.
-_C.MODEL.SWIN.QKV_BIAS = True
-_C.MODEL.SWIN.QK_SCALE = None
-_C.MODEL.SWIN.APE = False
-_C.MODEL.SWIN.PATCH_NORM = True
-
# Swin Transformer V2 parameters
_C.MODEL.SWINV2 = CN()
_C.MODEL.SWINV2.PATCH_SIZE = 4
@@ -78,20 +66,21 @@
_C.MODEL.SWINV2.PATCH_NORM = True
_C.MODEL.SWINV2.PRETRAINED_WINDOW_SIZES = [0, 0, 0, 0]

-# Vision Transformer parameters
-_C.MODEL.VIT = CN()
-_C.MODEL.VIT.PATCH_SIZE = 16
-_C.MODEL.VIT.IN_CHANS = 3
-_C.MODEL.VIT.EMBED_DIM = 768
-_C.MODEL.VIT.DEPTH = 12
-_C.MODEL.VIT.NUM_HEADS = 12
-_C.MODEL.VIT.MLP_RATIO = 4
-_C.MODEL.VIT.QKV_BIAS = True
-_C.MODEL.VIT.INIT_VALUES = 0.1
-_C.MODEL.VIT.USE_APE = False
-_C.MODEL.VIT.USE_RPB = False
-_C.MODEL.VIT.USE_SHARED_RPB = True
-_C.MODEL.VIT.USE_MEAN_POOLING = False
+# -----------------------------------------------------------------------------
+# Loss settings
+# -----------------------------------------------------------------------------
+_C.LOSS = CN()
+_C.LOSS.NAME = 'tversky'
+_C.LOSS.MODE = 'multiclass'
+_C.LOSS.CLASSES = None
+_C.LOSS.LOG = False
+_C.LOSS.LOGITS = True
+_C.LOSS.SMOOTH = 0.0
+_C.LOSS.IGNORE_INDEX = None
+_C.LOSS.EPS = 1e-7
+_C.LOSS.ALPHA = 0.5
+_C.LOSS.BETA = 0.5
+_C.LOSS.GAMMA = 1.0

# -----------------------------------------------------------------------------
# Training settings
@@ -139,32 +128,6 @@
# [SimMIM] Layer decay for fine-tuning
_C.TRAIN.LAYER_DECAY = 1.0

-# -----------------------------------------------------------------------------
-# Augmentation settings
-# -----------------------------------------------------------------------------
-_C.AUG = CN()
-# Color jitter factor
-_C.AUG.COLOR_JITTER = 0.4
-# Use AutoAugment policy. "v0" or "original"
-_C.AUG.AUTO_AUGMENT = 'rand-m9-mstd0.5-inc1'
-# Random erase prob
-_C.AUG.REPROB = 0.25
-# Random erase mode
-_C.AUG.REMODE = 'pixel'
-# Random erase count
-_C.AUG.RECOUNT = 1
-# Mixup alpha, mixup enabled if > 0
-_C.AUG.MIXUP = 0.8
-# Cutmix alpha, cutmix enabled if > 0
-_C.AUG.CUTMIX = 1.0
-# Cutmix min/max ratio, overrides alpha and enables cutmix if set
-_C.AUG.CUTMIX_MINMAX = None
-# Probability of performing mixup or cutmix when either/both is enabled
-_C.AUG.MIXUP_PROB = 1.0
-# Probability of switching to cutmix when both mixup and cutmix enabled
-_C.AUG.MIXUP_SWITCH_PROB = 0.5
-# How to apply mixup/cutmix params. Per "batch", "pair", or "elem"
-_C.AUG.MIXUP_MODE = 'batch'

# -----------------------------------------------------------------------------
# Testing settings
@@ -176,24 +139,22 @@
# -----------------------------------------------------------------------------
# Misc
# -----------------------------------------------------------------------------
-# [SimMIM] Whether to enable pytorch amp, overwritten by command line argument
+# Whether to enable pytorch amp, overwritten by command line argument
_C.ENABLE_AMP = False
# Enable Pytorch automatic mixed precision (amp).
_C.AMP_ENABLE = True
# Path to output folder, overwritten by command line argument
_C.OUTPUT = ''
# Tag of experiment, overwritten by command line argument
-_C.TAG = 'default'
+_C.TAG = 'pt-caney-default-tag'
# Frequency to save checkpoint
_C.SAVE_FREQ = 1
# Frequency to logging info
_C.PRINT_FREQ = 10
# Fixed random seed
-_C.SEED = 0
+_C.SEED = 42
# Perform evaluation only, overwritten by command line argument
_C.EVAL_MODE = False
# Test throughput only, overwritten by command line argument
_C.THROUGHPUT_MODE = False


def _update_config_from_file(config, cfg_file):
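The remainder of config.py is collapsed in this view. For context, yacs defaults like the ones above are usually merged with an experiment YAML (such as the finetune configs earlier in this diff) and then frozen; a sketch of the standard yacs pattern, assuming _C is the defaults node defined above:

```python
def get_config(cfg_file, opts=None):
    config = _C.clone()                  # start from the defaults above
    config.defrost()
    config.merge_from_file(cfg_file)     # e.g. a finetune_satvision_*.yaml
    if opts:
        config.merge_from_list(opts)     # command-line overrides
    config.freeze()
    return config
```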
93 changes: 93 additions & 0 deletions pytorch_caney/data/datamodules/finetune_datamodule.py
@@ -0,0 +1,93 @@
from ..datasets.modis_dataset import MODISDataset
from ..datasets.modis_lc_five_dataset import MODISLCFiveDataset
from ..datasets.modis_lc_nine_dataset import MODISLCNineDataset

from ..transforms import TensorResizeTransform

import torch.distributed as dist
from torch.utils.data import DataLoader, DistributedSampler


DATASETS = {
    'modis': MODISDataset,
    'modislc9': MODISLCNineDataset,
    'modislc5': MODISLCFiveDataset,
    # 'modis tree': MODISTree,
}


def get_dataset_from_dict(dataset_name: str):

    dataset_name = dataset_name.lower()

    try:

        dataset_to_use = DATASETS[dataset_name]

    except KeyError:

        error_msg = f"{dataset_name} is not an existing dataset"

        error_msg = f"{error_msg}. Available datasets: {DATASETS.keys()}"

        raise KeyError(error_msg)

    return dataset_to_use


def build_finetune_dataloaders(config, logger):

    transform = TensorResizeTransform(config)

    logger.info(f'Finetuning data transform:\n{transform}')

    dataset_name = config.DATA.DATASET

    logger.info(f'Dataset: {dataset_name}')
    logger.info(f'Data Paths: {config.DATA.DATA_PATHS}')

    dataset_to_use = get_dataset_from_dict(dataset_name)

    logger.info(f'Dataset obj: {dataset_to_use}')

    dataset_train = dataset_to_use(data_paths=config.DATA.DATA_PATHS,
                                   split="train",
                                   img_size=config.DATA.IMG_SIZE,
                                   transform=transform)

    dataset_val = dataset_to_use(data_paths=config.DATA.DATA_PATHS,
                                 split="val",
                                 img_size=config.DATA.IMG_SIZE,
                                 transform=transform)

    logger.info(f'Build dataset: train images = {len(dataset_train)}')

    logger.info(f'Build dataset: val images = {len(dataset_val)}')

    sampler_train = DistributedSampler(
        dataset_train,
        num_replicas=dist.get_world_size(),
        rank=dist.get_rank(),
        shuffle=True)

    sampler_val = DistributedSampler(
        dataset_val,
        num_replicas=dist.get_world_size(),
        rank=dist.get_rank(),
        shuffle=False)

    dataloader_train = DataLoader(dataset_train,
                                  config.DATA.BATCH_SIZE,
                                  sampler=sampler_train,
                                  num_workers=config.DATA.NUM_WORKERS,
                                  pin_memory=True,
                                  drop_last=True)

    dataloader_val = DataLoader(dataset_val,
                                config.DATA.BATCH_SIZE,
                                sampler=sampler_val,
                                num_workers=config.DATA.NUM_WORKERS,
                                pin_memory=True,
                                drop_last=False)

    return dataloader_train, dataloader_val
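One subtlety with the samplers above: DistributedSampler derives its shuffle from an internal epoch counter, so the training loop that consumes these loaders should call set_epoch every epoch, otherwise each rank replays the same permutation. A sketch of the expected usage:

```python
for epoch in range(config.TRAIN.EPOCHS):
    dataloader_train.sampler.set_epoch(epoch)  # reseed the shuffle per epoch
    for batch in dataloader_train:
        ...  # forward/backward/optimizer step
```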
2 changes: 1 addition & 1 deletion pytorch_caney/data/datasets/modis_dataset.py
@@ -74,6 +74,6 @@ def get_filenames(self, path):
        Returns a list of absolute paths to images inside given `path`
        """
        files_list = []
-        for filename in os.listdir(path):
+        for filename in sorted(os.listdir(path)):
            files_list.append(os.path.join(path, filename))
        return files_list
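(The sorted() change makes the directory listing deterministic: os.listdir returns entries in arbitrary order, which could otherwise vary across runs and ranks or misalign paired image/label files.)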