Skip to content
This repository has been archived by the owner on Nov 29, 2023. It is now read-only.

Commit

Permalink
Refactor SatFlow for new models in separate repos (#86)
Browse files Browse the repository at this point in the history
* Add learnable query

Needs more testing, currently does not work

Also made slight change as PerceiverIO implementation changed as well in upstream repo

* Update Perceiver IO with newer options

Add in the decoder feedforward for the multi-pass version

Continue working on the SinglePassPerceiver, output shape is wrong still, there is a 2 somewhere that I'm missing

* Add decoder_ff as config option

* Add total variation loss and Dynamic SSIM loss

From https://arxiv.org/pdf/2004.05214.pdf mentioning that SSIM tends to regress to predicting the background, and MSE predicting blurry images, total variation loss is one way to combat that, and another is to only use SSIM on the parts of the image that changes

* Add TODO

* Add losses to get_loss

* Add to assert

* Add Gradient Difference Loss, relates to #5

* Remove MetNet and Nowcasting GAN files

They are now in their own repos and importable with pip

* Remove unused imports

* Update to remove shadowing

* Fix test

* Skip two tests because of changes with Perceiver model

* Update test

* Remove some of MetNet code and Perceiver Encoders

MetNet code is in the new metnet repo, same with the perceiver code in the perceiver-pytorch repo

* Update to use OCF perceiver pytorch

* Add encoder/decoders

* Simplify MultiPerceiverSat using OCF perceiver

* Update MultiPerceiverSat to use new perceiver

* Only test on 3.8+

* Remove SinglePassPerceiver

It can be accomplished with the perceiverio implementation fairly easily now, so once that's merged there can use it.

* Remove MetNetPreprocessor

Taken care of in the model repos

* Switch to pip install

* Switch pip name
  • Loading branch information
jacobbieker authored Sep 8, 2021
1 parent 93cbfcf commit bb1b5c6
Show file tree
Hide file tree
Showing 28 changed files with 299 additions and 2,796 deletions.
4 changes: 3 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ hydra-optuna-sweeper>=1.1.0
lightning-bolts>=0.3.4
neptune-client>=0.10.2
neptune-pytorch-lightning>=0.9.7
perceiver-pytorch>=0.5.1
pytest>=6.2.4
python-dotenv>=0.19.0
pytorch-msssim==0.2.1
Expand All @@ -17,3 +16,6 @@ git+https://github.com/webdataset/webdataset
torch_optimizer
huggingface_hub>=0.0.16
einops==0.3.2
metnet>=0.0.3
skillful_nowcasting>=0.0.2
perceiver-model>=0.7.0
4 changes: 2 additions & 2 deletions satflow/configs/datamodule/metnet.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ pin_memory: True
prefetch_factor: 4
config:
visualize: False
num_timesteps: 5
num_timesteps: 6
skip_timesteps: 1
forecast_times: 24
output_shape: 256
Expand All @@ -24,7 +24,7 @@ config:
use_time: False
time_aux: False
use_mask: False
use_image: True
use_image: False
add_pixel_coords: False
time_as_channels: False
metnet_normalization: True
Expand Down
2 changes: 1 addition & 1 deletion satflow/configs/datamodule/nowcasting_gan_hrv_aws.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ config:
visualize: False
num_timesteps: 3
skip_timesteps: 1
forecast_times: 12
forecast_times: 24
output_shape: 128
target_type: "cloudmask"
num_crops: 10
Expand Down
2 changes: 1 addition & 1 deletion satflow/configs/datamodule/perceiver_metnet.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ config:
visualize: False
num_timesteps: 3
skip_timesteps: 1
forecast_times: 12
forecast_times: 24
output_shape: 128
output_target: 32
target_type: "cloudmask"
Expand Down
14 changes: 7 additions & 7 deletions satflow/configs/hparams_search/metnet_optuna.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,15 @@ hydra:
# learn more here: https://hydra.cc/docs/next/plugins/optuna_sweeper
sweeper:
_target_: hydra_plugins.hydra_optuna_sweeper.optuna_sweeper.OptunaSweeper
storage: "sqlite:///metnet.db"
storage: null #"sqlite:///metnet.db"
study_name: metnet
n_jobs: 1

# 'minimize' or 'maximize' the objective
direction: minimize

# number of experiments that will be executed
n_trials: 50
n_trials: 20

# choose Optuna hyperparameter sampler
# learn more here: https://optuna.readthedocs.io/en/stable/reference/samplers.html
Expand All @@ -48,20 +48,20 @@ hydra:
choices: [2]
datamodule.config.num_timesteps:
type: categorical
choices: [1, 3, 6]
choices: [3, 6]
datamodule.config.skip_timesteps:
type: categorical
choices: [1, 2, 3]
choices: [1]
model.lr:
type: float
low: 0.0001
high: 0.2
model.hidden_dim:
type: categorical
choices: [8, 16, 32, 64]
choices: [32, 64, 128]
model.num_layers:
type: categorical
choices: [1, 2]
choices: [1, 2, 3]
model.num_att_layers:
type: categorical
choices: [1, 2]
choices: [1, 2, 3]
2 changes: 1 addition & 1 deletion satflow/configs/model/metnet.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
_target_: satflow.models.metnet.MetNet
input_channels: 16
output_channels: 12
output_channels: 1
sat_channels: 12
input_size: 64
hidden_dim: 32
Expand Down
2 changes: 1 addition & 1 deletion satflow/configs/model/nowcasting_gan.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
_target_: satflow.models.nowcasting_gan.NowcastingGAN
forecast_steps: 12
forecast_steps: 24
input_channels: 1
output_shape: 128
gen_lr: 0.00005
Expand Down
2 changes: 1 addition & 1 deletion satflow/configs/model/perceiver.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ latent_heads: 8
cross_dim_heads: 8
latent_dim: 256
weight_tie_layers: False
self_per_cross_attention: 2
decoder_ff: True
dim: 32
logits_dim: null
queries_dim: 32
Expand Down
2 changes: 1 addition & 1 deletion satflow/configs/model/perceiver_encoder.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ latent_heads: 8
cross_dim_heads: 8
latent_dim: 256
weight_tie_layers: False
self_per_cross_attention: 2
decoder_ff: True
dim: 32
logits_dim: null
queries_dim: 32
Expand Down
4 changes: 2 additions & 2 deletions satflow/configs/model/perceiver_metnet.yaml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
_target_: satflow.models.perceiver.Perceiver
input_channels: 16
sat_channels: 12
forecast_steps: 12
forecast_steps: 24
lr: 0.005
input_size: 32
max_frequency: 16.0
Expand All @@ -12,7 +12,7 @@ latent_heads: 8
cross_dim_heads: 8
latent_dim: 256
weight_tie_layers: False
self_per_cross_attention: 2
decoder_ff: True
dim: 32
logits_dim: null
queries_dim: 32
Expand Down
24 changes: 24 additions & 0 deletions satflow/configs/model/perceiver_single.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
_target_: satflow.models.perceiver.SinglePassPerceiver
input_channels: 16
sat_channels: 12
forecast_steps: 3
lr: 0.005
input_size: 32
max_frequency: 16.0
depth: 6
num_latents: 256
cross_heads: 1
latent_heads: 8
cross_dim_heads: 8
latent_dim: 256
weight_tie_layers: False
self_per_cross_attention: 2
dim: 32
logits_dim: null
queries_dim: 128
latent_dim_heads: 64
visualize: False
preprocessor_type: "metnet"
use_input_as_query: True
use_learnable_query: False
output_shape: [3, 32, 32]
2 changes: 1 addition & 1 deletion satflow/configs/trainer/minimal.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ terminate_on_nan: False
auto_lr_find: False
auto_scale_batch_size: False
accumulate_grad_batches: 1
precision: 16
precision: 32
# stochastic_weight_avg: True
fast_dev_run: False

Expand Down
7 changes: 3 additions & 4 deletions satflow/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
from .base import get_model, create_model
from .conv_lstm import EncoderDecoderConvLSTM, ConvLSTM
from .metnet import MetNet
from .pix2pix import Pix2Pix
from .pl_metnet import LitMetNet

from .runet import R2U_Net, RUnet
from .attention_unet import R2AttU_Net, AttU_Net
from .cloudgan import CloudGAN
from .nowcasting_gan import NowcastingGAN

from .perceiver import Perceiver
Loading

0 comments on commit bb1b5c6

Please sign in to comment.