Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added initial finetuning #22

Merged
merged 6 commits into from
Aug 16, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,12 @@ Please see our [guide for contributing to pytorch-caney](CONTRIBUTING.md).
| :---: | :---: | :---: | :---: |
| SatVision-B | MODIS-1.9-M | 192x192 | 84.5M |

## SatVision Datasets

| name | bands | resolution | #params |
| :---: | :---: | :---: | :---: |
| MODIS-Small | 7 | 128x128 | 1,994,131 |

## Pre-training with Masked Image Modeling
To pre-train the swinv2 base model with masked image modeling pre-training, run:
```bash
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
MODEL:
TYPE: swinv2
DECODER: unet
NAME: satvision_finetune_lc5class
DROP_PATH_RATE: 0.1
NUM_CLASSES: 5
SWINV2:
IN_CHANS: 7
EMBED_DIM: 128
DEPTHS: [ 2, 2, 18, 2 ]
NUM_HEADS: [ 4, 8, 16, 32 ]
WINDOW_SIZE: 14
PRETRAINED_WINDOW_SIZES: [ 12, 12, 12, 6 ]
DATA:
IMG_SIZE: 224
DATASET: MODISLC5
MASK_PATCH_SIZE: 32
MASK_RATIO: 0.6
LOSS:
NAME: 'tversky'
MODE: 'multiclass'
ALPHA: 0.4
BETA: 0.6
TRAIN:
EPOCHS: 100
WARMUP_EPOCHS: 10
BASE_LR: 1e-4
WARMUP_LR: 5e-7
WEIGHT_DECAY: 0.01
LAYER_DECAY: 0.8
PRINT_FREQ: 100
SAVE_FREQ: 5
TAG: satvision_finetune_land_cover_5class_swinv2_satvision_192_window12__800ep
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
MODEL:
TYPE: swinv2
DECODER: unet
NAME: satvision_finetune_lc9class
DROP_PATH_RATE: 0.1
NUM_CLASSES: 9
SWINV2:
IN_CHANS: 7
EMBED_DIM: 128
DEPTHS: [ 2, 2, 18, 2 ]
NUM_HEADS: [ 4, 8, 16, 32 ]
WINDOW_SIZE: 14
PRETRAINED_WINDOW_SIZES: [ 12, 12, 12, 6 ]
DATA:
IMG_SIZE: 224
DATASET: MODISLC5
MASK_PATCH_SIZE: 32
MASK_RATIO: 0.6
LOSS:
NAME: 'tversky'
MODE: 'multiclass'
ALPHA: 0.4
BETA: 0.6
TRAIN:
EPOCHS: 100
WARMUP_EPOCHS: 10
BASE_LR: 1e-4
WARMUP_LR: 5e-7
WEIGHT_DECAY: 0.01
LAYER_DECAY: 0.8
PRINT_FREQ: 100
SAVE_FREQ: 5
TAG: satvision_finetune_land_cover_9class_swinv2_satvision_192_window12__800ep
20 changes: 20 additions & 0 deletions examples/satvision/run_satvision_finetune_lc_fiveclass.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
#!/bin/bash

#SBATCH -J finetune_satvision_lc5
#SBATCH -t 3-00:00:00
#SBATCH -G 4
#SBATCH -N 1


export PYTHONPATH=$PWD:../../../:../../../pytorch-caney
export NGPUS=8

torchrun --nproc_per_node $NGPUS \
../../../pytorch-caney/pytorch_caney/pipelines/finetuning/finetune.py \
--cfg finetune_satvision_base_landcover5class_192_window12_100ep.yaml \
--pretrained /explore/nobackup/people/cssprad1/projects/satnet/code/development/masked_image_modeling/development/models/simmim_satnet_pretrain_pretrain/simmim_pretrain__satnet_swinv2_base__img192_window12__800ep_v3_no_norm/ckpt_epoch_800.pth \
--dataset MODISLC9 \
--data-paths /explore/nobackup/projects/ilab/data/satvision/finetuning/h18v04/labels_9classes_224 \
--batch-size 4 \
--output /explore/nobackup/people/cssprad1/projects/satnet/code/development/cleanup/finetune/models \
--enable-amp
20 changes: 20 additions & 0 deletions examples/satvision/run_satvision_finetune_lc_nineclass.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
#!/bin/bash

#SBATCH -J finetune_satvision_lc9
#SBATCH -t 3-00:00:00
#SBATCH -G 4
#SBATCH -N 1


export PYTHONPATH=$PWD:../../../:../../../pytorch-caney
export NGPUS=8

torchrun --nproc_per_node $NGPUS \
../../../pytorch-caney/pytorch_caney/pipelines/finetuning/finetune.py \
--cfg finetune_satvision_base_landcover5class_192_window12_100ep.yaml \
--pretrained /explore/nobackup/people/cssprad1/projects/satnet/code/development/masked_image_modeling/development/models/simmim_satnet_pretrain_pretrain/simmim_pretrain__satnet_swinv2_base__img192_window12__800ep_v3_no_norm/ckpt_epoch_800.pth \
--dataset MODISLC9 \
--data-paths /explore/nobackup/projects/ilab/data/satvision/finetuning/h18v04/labels_5classes_224 \
--batch-size 4 \
--output /explore/nobackup/people/cssprad1/projects/satnet/code/development/cleanup/finetune/models \
--enable-amp
79 changes: 20 additions & 59 deletions pytorch_caney/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@
_C.MODEL = CN()
# Model type
_C.MODEL.TYPE = 'swinv2'
# Decoder type
_C.MODEL.DECODER = None
# Model name
_C.MODEL.NAME = 'swinv2_base_patch4_window7_224'
# Pretrained weight from checkpoint, could be from previous pre-training
Expand All @@ -50,20 +52,6 @@
# Drop path rate
_C.MODEL.DROP_PATH_RATE = 0.1

# Swin Transformer parameters
_C.MODEL.SWIN = CN()
_C.MODEL.SWIN.PATCH_SIZE = 4
_C.MODEL.SWIN.IN_CHANS = 4
_C.MODEL.SWIN.EMBED_DIM = 96
_C.MODEL.SWIN.DEPTHS = [2, 2, 6, 2]
_C.MODEL.SWIN.NUM_HEADS = [3, 6, 12, 24]
_C.MODEL.SWIN.WINDOW_SIZE = 7
_C.MODEL.SWIN.MLP_RATIO = 4.
_C.MODEL.SWIN.QKV_BIAS = True
_C.MODEL.SWIN.QK_SCALE = None
_C.MODEL.SWIN.APE = False
_C.MODEL.SWIN.PATCH_NORM = True

# Swin Transformer V2 parameters
_C.MODEL.SWINV2 = CN()
_C.MODEL.SWINV2.PATCH_SIZE = 4
Expand All @@ -78,20 +66,21 @@
_C.MODEL.SWINV2.PATCH_NORM = True
_C.MODEL.SWINV2.PRETRAINED_WINDOW_SIZES = [0, 0, 0, 0]

# Vision Transformer parameters
_C.MODEL.VIT = CN()
_C.MODEL.VIT.PATCH_SIZE = 16
_C.MODEL.VIT.IN_CHANS = 3
_C.MODEL.VIT.EMBED_DIM = 768
_C.MODEL.VIT.DEPTH = 12
_C.MODEL.VIT.NUM_HEADS = 12
_C.MODEL.VIT.MLP_RATIO = 4
_C.MODEL.VIT.QKV_BIAS = True
_C.MODEL.VIT.INIT_VALUES = 0.1
_C.MODEL.VIT.USE_APE = False
_C.MODEL.VIT.USE_RPB = False
_C.MODEL.VIT.USE_SHARED_RPB = True
_C.MODEL.VIT.USE_MEAN_POOLING = False
# -----------------------------------------------------------------------------
# Training settings
# -----------------------------------------------------------------------------
_C.LOSS = CN()
_C.LOSS.NAME = 'tversky'
_C.LOSS.MODE = 'multiclass'
_C.LOSS.CLASSES = None
_C.LOSS.LOG = False
_C.LOSS.LOGITS = True
_C.LOSS.SMOOTH = 0.0
_C.LOSS.IGNORE_INDEX = None
_C.LOSS.EPS = 1e-7
_C.LOSS.ALPHA = 0.5
_C.LOSS.BETA = 0.5
_C.LOSS.GAMMA = 1.0

# -----------------------------------------------------------------------------
# Training settings
Expand Down Expand Up @@ -139,32 +128,6 @@
# [SimMIM] Layer decay for fine-tuning
_C.TRAIN.LAYER_DECAY = 1.0

# -----------------------------------------------------------------------------
# Augmentation settings
# -----------------------------------------------------------------------------
_C.AUG = CN()
# Color jitter factor
_C.AUG.COLOR_JITTER = 0.4
# Use AutoAugment policy. "v0" or "original"
_C.AUG.AUTO_AUGMENT = 'rand-m9-mstd0.5-inc1'
# Random erase prob
_C.AUG.REPROB = 0.25
# Random erase mode
_C.AUG.REMODE = 'pixel'
# Random erase count
_C.AUG.RECOUNT = 1
# Mixup alpha, mixup enabled if > 0
_C.AUG.MIXUP = 0.8
# Cutmix alpha, cutmix enabled if > 0
_C.AUG.CUTMIX = 1.0
# Cutmix min/max ratio, overrides alpha and enables cutmix if set
_C.AUG.CUTMIX_MINMAX = None
# Probability of performing mixup or cutmix when either/both is enabled
_C.AUG.MIXUP_PROB = 1.0
# Probability of switching to cutmix when both mixup and cutmix enabled
_C.AUG.MIXUP_SWITCH_PROB = 0.5
# How to apply mixup/cutmix params. Per "batch", "pair", or "elem"
_C.AUG.MIXUP_MODE = 'batch'

# -----------------------------------------------------------------------------
# Testing settings
Expand All @@ -176,24 +139,22 @@
# -----------------------------------------------------------------------------
# Misc
# -----------------------------------------------------------------------------
# [SimMIM] Whether to enable pytorch amp, overwritten by command line argument
# Whether to enable pytorch amp, overwritten by command line argument
_C.ENABLE_AMP = False
# Enable Pytorch automatic mixed precision (amp).
_C.AMP_ENABLE = True
# Path to output folder, overwritten by command line argument
_C.OUTPUT = ''
# Tag of experiment, overwritten by command line argument
_C.TAG = 'default'
_C.TAG = 'pt-caney-default-tag'
# Frequency to save checkpoint
_C.SAVE_FREQ = 1
# Frequency to logging info
_C.PRINT_FREQ = 10
# Fixed random seed
_C.SEED = 0
_C.SEED = 42
# Perform evaluation only, overwritten by command line argument
_C.EVAL_MODE = False
# Test throughput only, overwritten by command line argument
_C.THROUGHPUT_MODE = False


def _update_config_from_file(config, cfg_file):
Expand Down
93 changes: 93 additions & 0 deletions pytorch_caney/data/datamodules/finetune_datamodule.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
from ..datasets.modis_dataset import MODISDataset
from ..datasets.modis_lc_five_dataset import MODISLCFiveDataset
from ..datasets.modis_lc_nine_dataset import MODISLCNineDataset

from ..transforms import TensorResizeTransform

import torch.distributed as dist
from torch.utils.data import DataLoader, DistributedSampler


DATASETS = {
'modis': MODISDataset,
'modislc9': MODISLCNineDataset,
'modislc5': MODISLCFiveDataset,
# 'modis tree': MODISTree,
}


def get_dataset_from_dict(dataset_name: str):

dataset_name = dataset_name.lower()

try:

dataset_to_use = DATASETS[dataset_name]

except KeyError:

error_msg = f"{dataset_name} is not an existing dataset"

error_msg = f"{error_msg}. Available datasets: {DATASETS.keys()}"

raise KeyError(error_msg)

return dataset_to_use


def build_finetune_dataloaders(config, logger):

transform = TensorResizeTransform(config)

logger.info(f'Finetuning data transform:\n{transform}')

dataset_name = config.DATA.DATASET

logger.info(f'Dataset: {dataset_name}')
logger.info(f'Data Paths: {config.DATA.DATA_PATHS}')

dataset_to_use = get_dataset_from_dict(dataset_name)

logger.info(f'Dataset obj: {dataset_to_use}')

dataset_train = dataset_to_use(data_paths=config.DATA.DATA_PATHS,
split="train",
img_size=config.DATA.IMG_SIZE,
transform=transform)

dataset_val = dataset_to_use(data_paths=config.DATA.DATA_PATHS,
split="val",
img_size=config.DATA.IMG_SIZE,
transform=transform)

logger.info(f'Build dataset: train images = {len(dataset_train)}')

logger.info(f'Build dataset: val images = {len(dataset_val)}')

sampler_train = DistributedSampler(
dataset_train,
num_replicas=dist.get_world_size(),
rank=dist.get_rank(),
shuffle=True)

sampler_val = DistributedSampler(
dataset_val,
num_replicas=dist.get_world_size(),
rank=dist.get_rank(),
shuffle=False)

dataloader_train = DataLoader(dataset_train,
config.DATA.BATCH_SIZE,
sampler=sampler_train,
num_workers=config.DATA.NUM_WORKERS,
pin_memory=True,
drop_last=True)

dataloader_val = DataLoader(dataset_val,
config.DATA.BATCH_SIZE,
sampler=sampler_val,
num_workers=config.DATA.NUM_WORKERS,
pin_memory=True,
drop_last=False)

return dataloader_train, dataloader_val
2 changes: 1 addition & 1 deletion pytorch_caney/data/datasets/modis_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,6 @@ def get_filenames(self, path):
Returns a list of absolute paths to images inside given `path`
"""
files_list = []
for filename in os.listdir(path):
for filename in sorted(os.listdir(path)):
files_list.append(os.path.join(path, filename))
return files_list
Loading