Showing 33 changed files with 3,088 additions and 1 deletion.
examples/satvision/mim_pretrain_swinv2_satvision_base_192_window12_800ep.yaml (new file, 27 additions, 0 deletions):
MODEL:
  TYPE: swinv2
  NAME: mim_satvision_pretrain
  DROP_PATH_RATE: 0.1
  SWINV2:
    IN_CHANS: 7
    EMBED_DIM: 128
    DEPTHS: [ 2, 2, 18, 2 ]
    NUM_HEADS: [ 4, 8, 16, 32 ]
    WINDOW_SIZE: 12
DATA:
  IMG_SIZE: 192
  MASK_PATCH_SIZE: 32
  MASK_RATIO: 0.6
TRAIN:
  EPOCHS: 800
  WARMUP_EPOCHS: 10
  BASE_LR: 1e-4
  WARMUP_LR: 5e-7
  WEIGHT_DECAY: 0.05
  LR_SCHEDULER:
    NAME: 'multistep'
    GAMMA: 0.1
    MULTISTEPS: [700,]
PRINT_FREQ: 100
SAVE_FREQ: 5
TAG: mim_pretrain_swinv2_satvision_192_window12__800ep
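The three DATA entries drive SimMIM's random masking: with IMG_SIZE 192 and MASK_PATCH_SIZE 32, the mask is drawn over a 6x6 grid of 32-pixel patches, and MASK_RATIO 0.6 masks about 22 of those 36 patches. A sketch of a SimMIM-style MaskGenerator wired to these values follows; it mirrors the public SimMIM reference code rather than the implementation inside pytorch-caney, and model_patch_size=4 assumes the SWINV2.PATCH_SIZE default from the config module shown later in this commit.

import numpy as np

class MaskGenerator:
    """Random patch mask for SimMIM-style pretraining (illustrative sketch only)."""

    def __init__(self, input_size=192, mask_patch_size=32,
                 model_patch_size=4, mask_ratio=0.6):
        assert input_size % mask_patch_size == 0
        assert mask_patch_size % model_patch_size == 0
        self.rand_size = input_size // mask_patch_size    # 192 / 32 = 6
        self.scale = mask_patch_size // model_patch_size  # 32 / 4 = 8
        self.token_count = self.rand_size ** 2            # 36 maskable patches
        self.mask_count = int(np.ceil(self.token_count * mask_ratio))  # ~22 masked

    def __call__(self):
        # Pick mask_count patches uniformly at random and mark them with 1.
        mask_idx = np.random.permutation(self.token_count)[:self.mask_count]
        mask = np.zeros(self.token_count, dtype=int)
        mask[mask_idx] = 1
        mask = mask.reshape((self.rand_size, self.rand_size))
        # Expand the 6x6 mask to the model's 48x48 patch grid.
        mask = mask.repeat(self.scale, axis=0).repeat(self.scale, axis=1)
        return mask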
New file (19 additions, 0 deletions):
#!/bin/bash

#SBATCH -J pretrain_satvision_swinv2
#SBATCH -t 3-00:00:00
#SBATCH -G 4
#SBATCH -N 1

export PYTHONPATH=$PWD:../../../:../../../pytorch-caney
export NGPUS=4

torchrun --nproc_per_node $NGPUS \
    ../../../pytorch-caney/pytorch_caney/pipelines/pretraining/simmim.py \
    --cfg mim_pretrain_swinv2_satvision_base_192_window12_800ep.yaml \
    --dataset MODIS \
    --data-paths /explore/nobackup/projects/ilab/data/satvision/pretraining/training_* \
    --batch-size 128 \
    --output /explore/nobackup/people/cssprad1/projects/satnet/code/development/cleanup/trf/transformer/models \
    --enable-amp
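The #SBATCH header requests one node with four GPUs for up to three days, matching NGPUS=4 so that torchrun spawns one worker process per GPU. On a SLURM cluster the script would be submitted along these lines; the diff does not show the script's filename, so the name below is a placeholder taken from the job name:

sbatch pretrain_satvision_swinv2.sh   # placeholder filename; the diff omits the script's path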
New file (265 additions, 0 deletions):
import os
import yaml
from yacs.config import CfgNode as CN

_C = CN()

# Base config files
_C.BASE = ['']

# -----------------------------------------------------------------------------
# Data settings
# -----------------------------------------------------------------------------
_C.DATA = CN()
# Batch size for a single GPU, could be overwritten by command line argument
_C.DATA.BATCH_SIZE = 128
# Path(s) to dataset, could be overwritten by command line argument
_C.DATA.DATA_PATHS = ['']
# Dataset name
_C.DATA.DATASET = 'MODIS'
# Input image size
_C.DATA.IMG_SIZE = 224
# Interpolation to resize image (random, bilinear, bicubic)
_C.DATA.INTERPOLATION = 'bicubic'
# Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.
_C.DATA.PIN_MEMORY = True
# Number of data loading threads
_C.DATA.NUM_WORKERS = 8
# [SimMIM] Mask patch size for MaskGenerator
_C.DATA.MASK_PATCH_SIZE = 32
# [SimMIM] Mask ratio for MaskGenerator
_C.DATA.MASK_RATIO = 0.6

# -----------------------------------------------------------------------------
# Model settings
# -----------------------------------------------------------------------------
_C.MODEL = CN()
# Model type
_C.MODEL.TYPE = 'swinv2'
# Model name
_C.MODEL.NAME = 'swinv2_base_patch4_window7_224'
# Pretrained weight from checkpoint, could be from previous pre-training
# could be overwritten by command line argument
_C.MODEL.PRETRAINED = ''
# Checkpoint to resume, could be overwritten by command line argument
_C.MODEL.RESUME = ''
# Number of classes, overwritten in data preparation
_C.MODEL.NUM_CLASSES = 17
# Dropout rate
_C.MODEL.DROP_RATE = 0.0
# Drop path rate
_C.MODEL.DROP_PATH_RATE = 0.1

# Swin Transformer parameters
_C.MODEL.SWIN = CN()
_C.MODEL.SWIN.PATCH_SIZE = 4
_C.MODEL.SWIN.IN_CHANS = 4
_C.MODEL.SWIN.EMBED_DIM = 96
_C.MODEL.SWIN.DEPTHS = [2, 2, 6, 2]
_C.MODEL.SWIN.NUM_HEADS = [3, 6, 12, 24]
_C.MODEL.SWIN.WINDOW_SIZE = 7
_C.MODEL.SWIN.MLP_RATIO = 4.
_C.MODEL.SWIN.QKV_BIAS = True
_C.MODEL.SWIN.QK_SCALE = None
_C.MODEL.SWIN.APE = False
_C.MODEL.SWIN.PATCH_NORM = True

# Swin Transformer V2 parameters
_C.MODEL.SWINV2 = CN()
_C.MODEL.SWINV2.PATCH_SIZE = 4
_C.MODEL.SWINV2.IN_CHANS = 3
_C.MODEL.SWINV2.EMBED_DIM = 96
_C.MODEL.SWINV2.DEPTHS = [2, 2, 6, 2]
_C.MODEL.SWINV2.NUM_HEADS = [3, 6, 12, 24]
_C.MODEL.SWINV2.WINDOW_SIZE = 7
_C.MODEL.SWINV2.MLP_RATIO = 4.
_C.MODEL.SWINV2.QKV_BIAS = True
_C.MODEL.SWINV2.APE = False
_C.MODEL.SWINV2.PATCH_NORM = True
_C.MODEL.SWINV2.PRETRAINED_WINDOW_SIZES = [0, 0, 0, 0]

# Vision Transformer parameters
_C.MODEL.VIT = CN()
_C.MODEL.VIT.PATCH_SIZE = 16
_C.MODEL.VIT.IN_CHANS = 3
_C.MODEL.VIT.EMBED_DIM = 768
_C.MODEL.VIT.DEPTH = 12
_C.MODEL.VIT.NUM_HEADS = 12
_C.MODEL.VIT.MLP_RATIO = 4
_C.MODEL.VIT.QKV_BIAS = True
_C.MODEL.VIT.INIT_VALUES = 0.1
_C.MODEL.VIT.USE_APE = False
_C.MODEL.VIT.USE_RPB = False
_C.MODEL.VIT.USE_SHARED_RPB = True
_C.MODEL.VIT.USE_MEAN_POOLING = False

# -----------------------------------------------------------------------------
# Training settings
# -----------------------------------------------------------------------------
_C.TRAIN = CN()
_C.TRAIN.START_EPOCH = 0
_C.TRAIN.EPOCHS = 300
_C.TRAIN.WARMUP_EPOCHS = 20
_C.TRAIN.WEIGHT_DECAY = 0.05
_C.TRAIN.BASE_LR = 5e-4
_C.TRAIN.WARMUP_LR = 5e-7
_C.TRAIN.MIN_LR = 5e-6
# Clip gradient norm
_C.TRAIN.CLIP_GRAD = 5.0
# Auto resume from latest checkpoint
_C.TRAIN.AUTO_RESUME = True
# Gradient accumulation steps
# could be overwritten by command line argument
_C.TRAIN.ACCUMULATION_STEPS = 0
# Whether to use gradient checkpointing to save memory
# could be overwritten by command line argument
_C.TRAIN.USE_CHECKPOINT = False

# LR scheduler
_C.TRAIN.LR_SCHEDULER = CN()
_C.TRAIN.LR_SCHEDULER.NAME = 'cosine'
# Epoch interval to decay LR, used in StepLRScheduler
_C.TRAIN.LR_SCHEDULER.DECAY_EPOCHS = 30
# LR decay rate, used in StepLRScheduler
_C.TRAIN.LR_SCHEDULER.DECAY_RATE = 0.1
# Gamma / Multi steps value, used in MultiStepLRScheduler
_C.TRAIN.LR_SCHEDULER.GAMMA = 0.1
_C.TRAIN.LR_SCHEDULER.MULTISTEPS = []

# Optimizer
_C.TRAIN.OPTIMIZER = CN()
_C.TRAIN.OPTIMIZER.NAME = 'adamw'
# Optimizer Epsilon
_C.TRAIN.OPTIMIZER.EPS = 1e-8
# Optimizer Betas
_C.TRAIN.OPTIMIZER.BETAS = (0.9, 0.999)
# SGD momentum
_C.TRAIN.OPTIMIZER.MOMENTUM = 0.9

# [SimMIM] Layer decay for fine-tuning
_C.TRAIN.LAYER_DECAY = 1.0

# -----------------------------------------------------------------------------
# Augmentation settings
# -----------------------------------------------------------------------------
_C.AUG = CN()
# Color jitter factor
_C.AUG.COLOR_JITTER = 0.4
# Use AutoAugment policy. "v0" or "original"
_C.AUG.AUTO_AUGMENT = 'rand-m9-mstd0.5-inc1'
# Random erase prob
_C.AUG.REPROB = 0.25
# Random erase mode
_C.AUG.REMODE = 'pixel'
# Random erase count
_C.AUG.RECOUNT = 1
# Mixup alpha, mixup enabled if > 0
_C.AUG.MIXUP = 0.8
# Cutmix alpha, cutmix enabled if > 0
_C.AUG.CUTMIX = 1.0
# Cutmix min/max ratio, overrides alpha and enables cutmix if set
_C.AUG.CUTMIX_MINMAX = None
# Probability of performing mixup or cutmix when either/both is enabled
_C.AUG.MIXUP_PROB = 1.0
# Probability of switching to cutmix when both mixup and cutmix enabled
_C.AUG.MIXUP_SWITCH_PROB = 0.5
# How to apply mixup/cutmix params. Per "batch", "pair", or "elem"
_C.AUG.MIXUP_MODE = 'batch'

# -----------------------------------------------------------------------------
# Testing settings
# -----------------------------------------------------------------------------
_C.TEST = CN()
# Whether to use center crop when testing
_C.TEST.CROP = True

# -----------------------------------------------------------------------------
# Misc
# -----------------------------------------------------------------------------
# [SimMIM] Whether to enable pytorch amp, overwritten by command line argument
_C.ENABLE_AMP = False
# Enable Pytorch automatic mixed precision (amp).
_C.AMP_ENABLE = True
# Path to output folder, overwritten by command line argument
_C.OUTPUT = ''
# Tag of experiment, overwritten by command line argument
_C.TAG = 'default'
# Frequency to save checkpoint
_C.SAVE_FREQ = 1
# Frequency to logging info
_C.PRINT_FREQ = 10
# Fixed random seed
_C.SEED = 0
# Perform evaluation only, overwritten by command line argument
_C.EVAL_MODE = False
# Test throughput only, overwritten by command line argument
_C.THROUGHPUT_MODE = False


def _update_config_from_file(config, cfg_file):
    config.defrost()
    with open(cfg_file, 'r') as f:
        yaml_cfg = yaml.load(f, Loader=yaml.FullLoader)

    for cfg in yaml_cfg.setdefault('BASE', ['']):
        if cfg:
            _update_config_from_file(
                config, os.path.join(os.path.dirname(cfg_file), cfg)
            )
    print('=> merge config from {}'.format(cfg_file))
    config.merge_from_file(cfg_file)
    config.freeze()


def update_config(config, args):
    _update_config_from_file(config, args.cfg)

    config.defrost()

    def _check_args(name):
        if hasattr(args, name) and eval(f'args.{name}'):
            return True
        return False

    # merge from specific arguments
    if _check_args('batch_size'):
        config.DATA.BATCH_SIZE = args.batch_size
    if _check_args('data_paths'):
        config.DATA.DATA_PATHS = args.data_paths
    if _check_args('dataset'):
        config.DATA.DATASET = args.dataset
    if _check_args('resume'):
        config.MODEL.RESUME = args.resume
    if _check_args('pretrained'):
        config.MODEL.PRETRAINED = args.pretrained
    if _check_args('accumulation_steps'):
        config.TRAIN.ACCUMULATION_STEPS = args.accumulation_steps
    if _check_args('use_checkpoint'):
        config.TRAIN.USE_CHECKPOINT = True
    if _check_args('disable_amp'):
        config.AMP_ENABLE = False
    if _check_args('output'):
        config.OUTPUT = args.output
    if _check_args('tag'):
        config.TAG = args.tag
    if _check_args('eval'):
        config.EVAL_MODE = True
    if _check_args('enable_amp'):
        config.ENABLE_AMP = args.enable_amp

    # output folder
    config.OUTPUT = os.path.join(config.OUTPUT, config.MODEL.NAME, config.TAG)

    config.freeze()


def get_config(args):
    """Get a yacs CfgNode object with default values."""
    # Return a clone so that the defaults will not be altered
    # This is for the "local variable" use pattern
    config = _C.clone()
    update_config(config, args)

    return config
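_update_config_from_file resolves the BASE list recursively, so a YAML can inherit from another YAML located relative to itself, and update_config then layers command-line flags on top before freezing the node. A minimal sketch of driving get_config the way the launch script's flags suggest; the import path is an assumption, since the diff does not show where this module lives:

import argparse

from config import get_config  # assumed module name; the diff omits the file path

parser = argparse.ArgumentParser('SatVision SimMIM pretraining')
parser.add_argument('--cfg', required=True, help='path to a YAML like the one above')
parser.add_argument('--dataset', default='MODIS')
parser.add_argument('--data-paths', nargs='+', default=[])
parser.add_argument('--batch-size', type=int, default=None)
parser.add_argument('--output', default='.')
parser.add_argument('--enable-amp', action='store_true')
args = parser.parse_args()

config = get_config(args)   # defaults -> BASE YAMLs -> --cfg -> CLI flags, then frozen
print(config.TRAIN.EPOCHS)  # 800 with the YAML shown earlier
print(config.OUTPUT)        # <output>/<MODEL.NAME>/<TAG>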
Modified file (1 deletion): the commented-out "# import pytorch_lightning" line is removed, so the file now begins:

from pytorch_lightning.utilities.cli import LightningCLI

import torch
File renamed without changes.
File renamed without changes.