Commit

Merge branch 'main' into develop
cssprad1 committed Aug 9, 2023
2 parents 4861fb1 + 3ad8b5a commit 120988a
Showing 33 changed files with 3,088 additions and 1 deletion.
28 changes: 28 additions & 0 deletions README.md
@@ -45,6 +45,34 @@ singularity build --sandbox pytorch-caney docker://nasanccs/pytorch-caney:latest

Please see our [guide for contributing to pytorch-caney](CONTRIBUTING.md).

## SatVision

| name | pretrain | resolution | #params |
| :---: | :---: | :---: | :---: |
| SatVision-B | MODIS-1.9-M | 192x192 | 84.5M |

## Pre-training with Masked Image Modeling
To pre-train the SwinV2-Base model with masked image modeling (MiM), run:
```bash
torchrun --nproc_per_node <NGPUS> pytorch_caney/pipelines/pretraining/simmim.py --cfg <config-file> --dataset <dataset-name> --data-paths <path-to-data-subfolder-1> --batch-size <batch-size> --output <output-dir> --enable-amp
```

For example, to pre-train a SwinV2-Base model on the MODIS SatVision pre-training dataset with 4 GPUs and a batch size of 128, run the following from the repository root:

```bash
singularity shell --nv -B <mounts> /path/to/container/pytorch-caney
Singularity> torchrun --nproc_per_node 4 pytorch_caney/pipelines/pretraining/simmim.py --cfg examples/satvision/mim_pretrain_swinv2_satvision_base_192_window12_800ep.yaml --dataset MODIS --data-paths /explore/nobackup/projects/ilab/data/satvision/pretraining/training_* --batch-size 128 --output . --enable-amp
```

Alternatively, the example script below runs the exact configuration used to pre-train the SatVision-Base model with MiM on the MODIS pre-training dataset:
```bash
singularity shell --nv -B <mounts> /path/to/container/pytorch-caney
Singularity> cd examples/satvision
Singularity> ./run_satvision_pretrain.sh
```

## References

- [PyTorch Lightning](https://github.com/Lightning-AI/lightning)
- [Swin Transformer](https://github.com/microsoft/Swin-Transformer)
- [SimMIM](https://github.com/microsoft/SimMIM)
27 changes: 27 additions & 0 deletions examples/satvision/mim_pretrain_swinv2_satvision_base_192_window12_800ep.yaml
@@ -0,0 +1,27 @@
MODEL:
  TYPE: swinv2
  NAME: mim_satvision_pretrain
  DROP_PATH_RATE: 0.1
  SWINV2:
    IN_CHANS: 7
    EMBED_DIM: 128
    DEPTHS: [ 2, 2, 18, 2 ]
    NUM_HEADS: [ 4, 8, 16, 32 ]
    WINDOW_SIZE: 12
DATA:
  IMG_SIZE: 192
  MASK_PATCH_SIZE: 32
  MASK_RATIO: 0.6
TRAIN:
  EPOCHS: 800
  WARMUP_EPOCHS: 10
  BASE_LR: 1e-4
  WARMUP_LR: 5e-7
  WEIGHT_DECAY: 0.05
  LR_SCHEDULER:
    NAME: 'multistep'
    GAMMA: 0.1
    MULTISTEPS: [700,]
PRINT_FREQ: 100
SAVE_FREQ: 5
TAG: mim_pretrain_swinv2_satvision_192_window12__800ep
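
Not part of this commit, but a rough sketch of how this YAML is consumed: yacs merges it over the default `CfgNode` declared in `pytorch_caney/config.py` (added later in this commit), overriding only the keys listed above and leaving every other default untouched.

```python
# Rough sketch, not part of this commit: merge the YAML above onto the default
# CfgNode from pytorch_caney/config.py and inspect a few resulting values.
from pytorch_caney.config import _C

config = _C.clone()
config.defrost()
config.merge_from_file(
    'examples/satvision/mim_pretrain_swinv2_satvision_base_192_window12_800ep.yaml'
)
config.freeze()

print(config.MODEL.SWINV2.IN_CHANS)  # 7   (set by the YAML)
print(config.DATA.IMG_SIZE)          # 192 (set by the YAML)
print(config.DATA.BATCH_SIZE)        # 128 (default from config.py)
```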
19 changes: 19 additions & 0 deletions examples/satvision/run_satvision_pretrain.sh
@@ -0,0 +1,19 @@
#!/bin/bash
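# SLURM batch script for SatVision SwinV2-Base MiM pre-training.
# On a SLURM cluster it can be submitted with: sbatch run_satvision_pretrain.sh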

#SBATCH -J pretrain_satvision_swinv2
#SBATCH -t 3-00:00:00
#SBATCH -G 4
#SBATCH -N 1


export PYTHONPATH=$PWD:../../../:../../../pytorch-caney
export NGPUS=4

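# Launch the SimMIM pre-training pipeline on a single node with $NGPUS GPUs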
torchrun --nproc_per_node $NGPUS \
    ../../../pytorch-caney/pytorch_caney/pipelines/pretraining/simmim.py \
    --cfg mim_pretrain_swinv2_satvision_base_192_window12_800ep.yaml \
    --dataset MODIS \
    --data-paths /explore/nobackup/projects/ilab/data/satvision/pretraining/training_* \
    --batch-size 128 \
    --output /explore/nobackup/people/cssprad1/projects/satnet/code/development/cleanup/trf/transformer/models \
    --enable-amp
265 changes: 265 additions & 0 deletions pytorch_caney/config.py
@@ -0,0 +1,265 @@
import os
import yaml
from yacs.config import CfgNode as CN

_C = CN()

# Base config files
_C.BASE = ['']

# -----------------------------------------------------------------------------
# Data settings
# -----------------------------------------------------------------------------
_C.DATA = CN()
# Batch size for a single GPU, could be overwritten by command line argument
_C.DATA.BATCH_SIZE = 128
# Path(s) to dataset, could be overwritten by command line argument
_C.DATA.DATA_PATHS = ['']
# Dataset name
_C.DATA.DATASET = 'MODIS'
# Input image size
_C.DATA.IMG_SIZE = 224
# Interpolation to resize image (random, bilinear, bicubic)
_C.DATA.INTERPOLATION = 'bicubic'
# Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.
_C.DATA.PIN_MEMORY = True
# Number of data loading threads
_C.DATA.NUM_WORKERS = 8
# [SimMIM] Mask patch size for MaskGenerator
_C.DATA.MASK_PATCH_SIZE = 32
# [SimMIM] Mask ratio for MaskGenerator
_C.DATA.MASK_RATIO = 0.6

# -----------------------------------------------------------------------------
# Model settings
# -----------------------------------------------------------------------------
_C.MODEL = CN()
# Model type
_C.MODEL.TYPE = 'swinv2'
# Model name
_C.MODEL.NAME = 'swinv2_base_patch4_window7_224'
# Pretrained weight from checkpoint, could be from previous pre-training
# could be overwritten by command line argument
_C.MODEL.PRETRAINED = ''
# Checkpoint to resume, could be overwritten by command line argument
_C.MODEL.RESUME = ''
# Number of classes, overwritten in data preparation
_C.MODEL.NUM_CLASSES = 17
# Dropout rate
_C.MODEL.DROP_RATE = 0.0
# Drop path rate
_C.MODEL.DROP_PATH_RATE = 0.1

# Swin Transformer parameters
_C.MODEL.SWIN = CN()
_C.MODEL.SWIN.PATCH_SIZE = 4
_C.MODEL.SWIN.IN_CHANS = 4
_C.MODEL.SWIN.EMBED_DIM = 96
_C.MODEL.SWIN.DEPTHS = [2, 2, 6, 2]
_C.MODEL.SWIN.NUM_HEADS = [3, 6, 12, 24]
_C.MODEL.SWIN.WINDOW_SIZE = 7
_C.MODEL.SWIN.MLP_RATIO = 4.
_C.MODEL.SWIN.QKV_BIAS = True
_C.MODEL.SWIN.QK_SCALE = None
_C.MODEL.SWIN.APE = False
_C.MODEL.SWIN.PATCH_NORM = True

# Swin Transformer V2 parameters
_C.MODEL.SWINV2 = CN()
_C.MODEL.SWINV2.PATCH_SIZE = 4
_C.MODEL.SWINV2.IN_CHANS = 3
_C.MODEL.SWINV2.EMBED_DIM = 96
_C.MODEL.SWINV2.DEPTHS = [2, 2, 6, 2]
_C.MODEL.SWINV2.NUM_HEADS = [3, 6, 12, 24]
_C.MODEL.SWINV2.WINDOW_SIZE = 7
_C.MODEL.SWINV2.MLP_RATIO = 4.
_C.MODEL.SWINV2.QKV_BIAS = True
_C.MODEL.SWINV2.APE = False
_C.MODEL.SWINV2.PATCH_NORM = True
_C.MODEL.SWINV2.PRETRAINED_WINDOW_SIZES = [0, 0, 0, 0]

# Vision Transformer parameters
_C.MODEL.VIT = CN()
_C.MODEL.VIT.PATCH_SIZE = 16
_C.MODEL.VIT.IN_CHANS = 3
_C.MODEL.VIT.EMBED_DIM = 768
_C.MODEL.VIT.DEPTH = 12
_C.MODEL.VIT.NUM_HEADS = 12
_C.MODEL.VIT.MLP_RATIO = 4
_C.MODEL.VIT.QKV_BIAS = True
_C.MODEL.VIT.INIT_VALUES = 0.1
_C.MODEL.VIT.USE_APE = False
_C.MODEL.VIT.USE_RPB = False
_C.MODEL.VIT.USE_SHARED_RPB = True
_C.MODEL.VIT.USE_MEAN_POOLING = False

# -----------------------------------------------------------------------------
# Training settings
# -----------------------------------------------------------------------------
_C.TRAIN = CN()
_C.TRAIN.START_EPOCH = 0
_C.TRAIN.EPOCHS = 300
_C.TRAIN.WARMUP_EPOCHS = 20
_C.TRAIN.WEIGHT_DECAY = 0.05
_C.TRAIN.BASE_LR = 5e-4
_C.TRAIN.WARMUP_LR = 5e-7
_C.TRAIN.MIN_LR = 5e-6
# Clip gradient norm
_C.TRAIN.CLIP_GRAD = 5.0
# Auto resume from latest checkpoint
_C.TRAIN.AUTO_RESUME = True
# Gradient accumulation steps
# could be overwritten by command line argument
_C.TRAIN.ACCUMULATION_STEPS = 0
# Whether to use gradient checkpointing to save memory
# could be overwritten by command line argument
_C.TRAIN.USE_CHECKPOINT = False

# LR scheduler
_C.TRAIN.LR_SCHEDULER = CN()
_C.TRAIN.LR_SCHEDULER.NAME = 'cosine'
# Epoch interval to decay LR, used in StepLRScheduler
_C.TRAIN.LR_SCHEDULER.DECAY_EPOCHS = 30
# LR decay rate, used in StepLRScheduler
_C.TRAIN.LR_SCHEDULER.DECAY_RATE = 0.1
# Gamma / Multi steps value, used in MultiStepLRScheduler
_C.TRAIN.LR_SCHEDULER.GAMMA = 0.1
_C.TRAIN.LR_SCHEDULER.MULTISTEPS = []

# Optimizer
_C.TRAIN.OPTIMIZER = CN()
_C.TRAIN.OPTIMIZER.NAME = 'adamw'
# Optimizer Epsilon
_C.TRAIN.OPTIMIZER.EPS = 1e-8
# Optimizer Betas
_C.TRAIN.OPTIMIZER.BETAS = (0.9, 0.999)
# SGD momentum
_C.TRAIN.OPTIMIZER.MOMENTUM = 0.9

# [SimMIM] Layer decay for fine-tuning
_C.TRAIN.LAYER_DECAY = 1.0

# -----------------------------------------------------------------------------
# Augmentation settings
# -----------------------------------------------------------------------------
_C.AUG = CN()
# Color jitter factor
_C.AUG.COLOR_JITTER = 0.4
# Use AutoAugment policy. "v0" or "original"
_C.AUG.AUTO_AUGMENT = 'rand-m9-mstd0.5-inc1'
# Random erase prob
_C.AUG.REPROB = 0.25
# Random erase mode
_C.AUG.REMODE = 'pixel'
# Random erase count
_C.AUG.RECOUNT = 1
# Mixup alpha, mixup enabled if > 0
_C.AUG.MIXUP = 0.8
# Cutmix alpha, cutmix enabled if > 0
_C.AUG.CUTMIX = 1.0
# Cutmix min/max ratio, overrides alpha and enables cutmix if set
_C.AUG.CUTMIX_MINMAX = None
# Probability of performing mixup or cutmix when either/both is enabled
_C.AUG.MIXUP_PROB = 1.0
# Probability of switching to cutmix when both mixup and cutmix enabled
_C.AUG.MIXUP_SWITCH_PROB = 0.5
# How to apply mixup/cutmix params. Per "batch", "pair", or "elem"
_C.AUG.MIXUP_MODE = 'batch'

# -----------------------------------------------------------------------------
# Testing settings
# -----------------------------------------------------------------------------
_C.TEST = CN()
# Whether to use center crop when testing
_C.TEST.CROP = True

# -----------------------------------------------------------------------------
# Misc
# -----------------------------------------------------------------------------
# [SimMIM] Whether to enable pytorch amp, overwritten by command line argument
_C.ENABLE_AMP = False
# Enable Pytorch automatic mixed precision (amp).
_C.AMP_ENABLE = True
# Path to output folder, overwritten by command line argument
_C.OUTPUT = ''
# Tag of experiment, overwritten by command line argument
_C.TAG = 'default'
# Frequency to save checkpoint
_C.SAVE_FREQ = 1
# Frequency for logging info
_C.PRINT_FREQ = 10
# Fixed random seed
_C.SEED = 0
# Perform evaluation only, overwritten by command line argument
_C.EVAL_MODE = False
# Test throughput only, overwritten by command line argument
_C.THROUGHPUT_MODE = False


def _update_config_from_file(config, cfg_file):
    config.defrost()
    with open(cfg_file, 'r') as f:
        yaml_cfg = yaml.load(f, Loader=yaml.FullLoader)

    for cfg in yaml_cfg.setdefault('BASE', ['']):
        if cfg:
            _update_config_from_file(
                config, os.path.join(os.path.dirname(cfg_file), cfg)
            )
    print('=> merge config from {}'.format(cfg_file))
    config.merge_from_file(cfg_file)
    config.freeze()


def update_config(config, args):
    _update_config_from_file(config, args.cfg)

    config.defrost()

    def _check_args(name):
        # True only when the argument exists on the namespace and is truthy
        return bool(getattr(args, name, None))

    # merge from specific arguments
    if _check_args('batch_size'):
        config.DATA.BATCH_SIZE = args.batch_size
    if _check_args('data_paths'):
        config.DATA.DATA_PATHS = args.data_paths
    if _check_args('dataset'):
        config.DATA.DATASET = args.dataset
    if _check_args('resume'):
        config.MODEL.RESUME = args.resume
    if _check_args('pretrained'):
        config.MODEL.PRETRAINED = args.pretrained
    if _check_args('accumulation_steps'):
        config.TRAIN.ACCUMULATION_STEPS = args.accumulation_steps
    if _check_args('use_checkpoint'):
        config.TRAIN.USE_CHECKPOINT = True
    if _check_args('disable_amp'):
        config.AMP_ENABLE = False
    if _check_args('output'):
        config.OUTPUT = args.output
    if _check_args('tag'):
        config.TAG = args.tag
    if _check_args('eval'):
        config.EVAL_MODE = True
    if _check_args('enable_amp'):
        config.ENABLE_AMP = args.enable_amp

    # output folder
    config.OUTPUT = os.path.join(config.OUTPUT, config.MODEL.NAME, config.TAG)

    config.freeze()


def get_config(args):
    """Get a yacs CfgNode object with default values."""
    # Return a clone so that the defaults will not be altered
    # This is for the "local variable" use pattern
    config = _C.clone()
    update_config(config, args)

    return config
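
As a hedged illustration (the real argument parser lives in the pipeline script and may differ), the flags consumed by `update_config()` above suggest a usage pattern along these lines:

```python
# Hedged sketch, not part of this commit: how a pipeline script might build its
# configuration. The actual parser lives in
# pytorch_caney/pipelines/pretraining/simmim.py and may define additional flags.
import argparse

from pytorch_caney.config import get_config

parser = argparse.ArgumentParser('SatVision MiM pre-training (illustration)')
parser.add_argument('--cfg', required=True, help='path to a YAML config file')
parser.add_argument('--dataset', default='MODIS')
parser.add_argument('--data-paths', nargs='+', default=[''])
parser.add_argument('--batch-size', type=int, default=None)
parser.add_argument('--output', default='.')
parser.add_argument('--tag', default=None)
parser.add_argument('--enable-amp', action='store_true')
args = parser.parse_args()

# Defaults from _C -> values from the YAML -> command-line overrides;
# OUTPUT becomes <output>/<MODEL.NAME>/<TAG>.
config = get_config(args)
print(config.OUTPUT)
```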
1 change: 0 additions & 1 deletion pytorch_caney/console/cli.py
@@ -1,4 +1,3 @@
# import pytorch_lightning
from pytorch_lightning.utilities.cli import LightningCLI

import torch
File renamed without changes.