Merge pull request #1 from openclimatefix/issue/mvp
Create MVP summation model
dfulu authored Jul 20, 2023
2 parents bcded22 + b44fecb commit 9c62ff8
Showing 27 changed files with 1,327 additions and 2 deletions.
42 changes: 42 additions & 0 deletions configs/callbacks/default.yaml
@@ -0,0 +1,42 @@
#pretrain_early_stopping:
# _target_: pvnet.callbacks.PretrainEarlyStopping
# monitor: "MAE/val" # name of the logged metric which determines when model is improving
# mode: "min" # can be "max" or "min"
# patience: 10 # how many epochs (or val check periods) of not improving until training stops
# min_delta: 0.001 # minimum change in the monitored metric needed to qualify as an improvement

#pretrain_encoder_freezing:
# _target_: pvnet.callbacks.PretrainFreeze

early_stopping:
  _target_: pvnet.callbacks.MainEarlyStopping
  # name of the logged metric which determines when model is improving
  monitor: "${resolve_monitor_loss:${model.output_quantiles}}"
  mode: "min" # can be "max" or "min"
  patience: 10 # how many epochs (or val check periods) of not improving until training stops
  min_delta: 0 # minimum change in the monitored metric needed to qualify as an improvement

learning_rate_monitor:
  _target_: lightning.pytorch.callbacks.LearningRateMonitor
  logging_interval: "epoch"

model_summary:
  _target_: lightning.pytorch.callbacks.ModelSummary
  max_depth: 3

model_checkpoint:
  _target_: lightning.pytorch.callbacks.ModelCheckpoint
  # name of the logged metric which determines when model is improving
  monitor: "${resolve_monitor_loss:${model.output_quantiles}}"
  mode: "min" # can be "max" or "min"
  save_top_k: 1 # save k best models (determined by above metric)
  save_last: True # additionally always save model from last epoch
  every_n_epochs: 1
  verbose: False
  filename: "epoch={epoch}-step={step}"
  dirpath: "checkpoints/pvnet_summation/${model_name}" #${..model_name}
  auto_insert_metric_name: False
  save_on_train_epoch_end: False
#device_stats_monitor:
# _target_: lightning.pytorch.callbacks.DeviceStatsMonitor
# cpu_stats: True
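The `monitor` values above rely on a custom OmegaConf resolver, `resolve_monitor_loss`, which is registered by the training code rather than defined in this config. A minimal sketch of such a resolver, assuming it simply switches between a quantile and MAE validation metric depending on whether `model.output_quantiles` is set (the metric names are illustrative, not taken from this PR):

```python
from omegaconf import OmegaConf

def resolve_monitor_loss(output_quantiles):
    # Pick the validation metric to monitor based on whether quantile output is used.
    # The metric names below are assumptions for illustration only.
    return "quantile_loss/val" if output_quantiles is not None else "MAE/val"

OmegaConf.register_new_resolver("resolve_monitor_loss", resolve_monitor_loss)
```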
Empty file added configs/callbacks/none.yaml
Empty file.
40 changes: 40 additions & 0 deletions configs/config.yaml
@@ -0,0 +1,40 @@
# @package _global_

# specify here default training configuration
defaults:
  - trainer: default.yaml
  - model: default.yaml
  - datamodule: default.yaml
  - callbacks: default.yaml # set this to null if you don't want to use callbacks
  - logger: wandb.yaml # set logger here or use command line (e.g. `python run.py logger=wandb`)
  - hydra: default.yaml

# enable color logging
# - override hydra/hydra_logging: colorlog
# - override hydra/job_logging: colorlog

# path to original working directory
# hydra hijacks working directory by changing it to the current log directory,
# so it's useful to have this path as a special variable
# learn more here: https://hydra.cc/docs/next/tutorials/basic/running_your_app/working_directory
work_dir: ${hydra:runtime.cwd}

model_name: "default"

# use `python run.py debug=true` for easy debugging!
# this will run 1 train, val and test loop with only 1 batch
# equivalent to running `python run.py trainer.fast_dev_run=true`
# (this is placed here just for easier access from command line)
debug: False

# pretty print config at the start of the run using Rich library
print_config: True

# disable python warnings if they annoy you
ignore_warnings: True

# check performance on test set, using the best model achieved during training
# lightning chooses best model based on metric specified in checkpoint callback
test_after_training: False

seed: 2727831
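As a rough illustration of how this top-level config is composed, the same defaults can also be loaded programmatically via Hydra's compose API (Hydra ≥ 1.2); the `overrides` and the relative `config_path` below are assumptions for the sketch, not part of this PR:

```python
from hydra import compose, initialize

# Compose configs/config.yaml with the same defaults list used by the run script,
# overriding a couple of fields from code instead of the command line.
with initialize(config_path="configs", version_base=None):
    cfg = compose(config_name="config", overrides=["debug=true", "model_name=test_run"])

print(cfg.model_name, cfg.seed)
```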
6 changes: 6 additions & 0 deletions configs/datamodule/default.yaml
@@ -0,0 +1,6 @@
_target_: pvnet_summation.data.datamodule.DataModule
batch_dir: "/mnt/disks/bigbatches/concurrent_batches_v3.6_-60mins"
gsp_zarr_path: "/mnt/disks/nwp/pv_gsp.zarr"
batch_size: 8
num_workers: 20
prefetch_factor: 2
12 changes: 12 additions & 0 deletions configs/hydra/default.yaml
@@ -0,0 +1,12 @@
# output paths for hydra logs
run:
  dir: logs/runs/${now:%Y-%m-%d}/${now:%H-%M-%S}
sweep:
  dir: logs/multiruns/${now:%Y-%m-%d_%H-%M-%S}
  subdir: ${hydra.job.num}

# you can set here environment variables that are universal for all users
# for system specific variables (like data paths) it's better to use .env file!
job:
  env_set:
    EXAMPLE_VAR: "example_value"
15 changes: 15 additions & 0 deletions configs/logger/wandb.yaml
@@ -0,0 +1,15 @@
# https://wandb.ai

wandb:
  _target_: lightning.pytorch.loggers.wandb.WandbLogger
  project: "pvnet_summation"
  name: "${model_name}"
  save_dir: "/mnt/disks/batches/"
  offline: False # set True to store all logs only locally
  id: null # pass correct id to resume experiment!
  # entity: "" # set to name of your wandb team or just remove it
  log_model: False
  prefix: ""
  job_type: "train"
  group: ""
  tags: []
34 changes: 34 additions & 0 deletions configs/model/default.yaml
@@ -0,0 +1,34 @@
_target_: pvnet_summation.models.model.Model

output_quantiles: null

model_name: "openclimatefix/pvnet_v2"
model_version: "898630f3f8cd4e8506525d813dd61c6d8de86144"

#--------------------------------------------
# Tabular network settings
#--------------------------------------------

output_network:
  _target_: pvnet.models.multimodal.linear_networks.networks.ResFCNet2
  _partial_: True
output_network_kwargs:
  fc_hidden_features: 128
  n_res_blocks: 6
  res_block_layers: 2
  dropout_frac: 0.0


# Forecast and time settings
forecast_minutes: 480

# ----------------------------------------------

optimizer:
  _target_: pvnet.optimizers.AdamWReduceLROnPlateau
  lr: 0.0001
  weight_decay: 0.25
  amsgrad: True
  patience: 5
  factor: 0.1
  threshold: 0.002
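Because `output_network` is marked with `_partial_: True`, instantiating it yields a `functools.partial` of `ResFCNet2` rather than a constructed module, letting the summation model fill in the feature sizes it works out at runtime. A rough sketch of that pattern; the parameter names and sizes below are illustrative assumptions, not the model's actual call:

```python
from hydra.utils import instantiate

# With _partial_: True this returns functools.partial(ResFCNet2, ...), not a module yet.
output_network_factory = instantiate(cfg.model.output_network)

# The summation model supplies the sizes it derives at runtime.
# in_features/out_features and the numbers are assumptions for illustration.
output_network = output_network_factory(
    in_features=318,
    out_features=16,
    **cfg.model.output_network_kwargs,
)
```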
7 changes: 7 additions & 0 deletions configs/readme.md
@@ -0,0 +1,7 @@
The following folders hold the configuration files.

This idea is copied from
https://github.com/ashleve/lightning-hydra-template/blob/main/configs/experiment/example_simple.yaml

Run experiments with:
`python run.py experiment=example_simple`
48 changes: 48 additions & 0 deletions configs/trainer/all_params.yaml
@@ -0,0 +1,48 @@
_target_: pytorch_lightning.Trainer

# default values for all trainer parameters
checkpoint_callback: True
default_root_dir: null
gradient_clip_val: 0.0
process_position: 0
num_nodes: 1
num_processes: 1
gpus: null
auto_select_gpus: False
tpu_cores: null
log_gpu_memory: null
overfit_batches: 0.0
track_grad_norm: -1
check_val_every_n_epoch: 1
fast_dev_run: False
accumulate_grad_batches: 1
max_epochs: 1
min_epochs: 1
max_steps: null
min_steps: null
limit_train_batches: 1.0
limit_val_batches: 1.0
limit_test_batches: 1.0
val_check_interval: 1.0
flush_logs_every_n_steps: 100
log_every_n_steps: 50
accelerator: null
sync_batchnorm: False
precision: 32
weights_save_path: null
num_sanity_val_steps: 2
truncated_bptt_steps: null
resume_from_checkpoint: null
profiler: null
benchmark: False
deterministic: False
reload_dataloaders_every_epoch: False
auto_lr_find: False
replace_sampler_ddp: True
terminate_on_nan: False
auto_scale_batch_size: False
prepare_data_per_node: True
plugins: null
amp_backend: "native"
amp_level: "O2"
move_metrics_to_cpu: False
17 changes: 17 additions & 0 deletions configs/trainer/default.yaml
@@ -0,0 +1,17 @@
_target_: lightning.pytorch.trainer.trainer.Trainer

# set `gpu` to train on GPU, `cpu` to train on CPU only
accelerator: gpu
devices: auto

min_epochs: null
max_epochs: null
reload_dataloaders_every_n_epochs: 0
num_sanity_val_steps: 8
fast_dev_run: false
#profiler: 'simple'

accumulate_grad_batches: 4
#val_check_interval: 800
#limit_val_batches: 800
log_every_n_steps: 50
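With `accumulate_grad_batches: 4` and the datamodule's `batch_size: 8`, each optimiser step effectively accumulates gradients over 32 samples. A minimal sketch of wiring trainer, model and datamodule together with `hydra.utils.instantiate`, assuming `cfg` is the composed config from `configs/config.yaml`:

```python
from hydra.utils import instantiate

model = instantiate(cfg.model)
datamodule = instantiate(cfg.datamodule)
trainer = instantiate(cfg.trainer)  # lightning.pytorch Trainer built from this file

trainer.fit(model, datamodule=datamodule)
```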
2 changes: 1 addition & 1 deletion environment.yml
@@ -1,4 +1,4 @@
name: ocf_template
name: pvnet_summation
channels:
- pytorch
- conda-forge
132 changes: 132 additions & 0 deletions pvnet_summation/data/datamodule.py
@@ -0,0 +1,132 @@
""" Data module for pytorch lightning """

import torch
from lightning.pytorch import LightningDataModule
from torchdata.dataloader2 import DataLoader2, MultiProcessingReadingService
from torchdata.datapipes.iter import FileLister, IterDataPipe
from ocf_datapipes.utils.consts import BatchKey
from ocf_datapipes.load import OpenGSP
from ocf_datapipes.training.pvnet import normalize_gsp

from pvnet.data.datamodule import (
    copy_batch_to_device,
    batch_to_tensor,
    split_batches,
)
# https://github.com/pytorch/pytorch/issues/973
torch.multiprocessing.set_sharing_strategy('file_system')


class GetNationalPVLive(IterDataPipe):
    """Datapipe which looks up the national (GSP ID 0) PVLive outputs for each sample's times."""

    def __init__(self, gsp_data, sample_datapipe, return_times=False):
        self.gsp_data = gsp_data
        self.sample_datapipe = sample_datapipe
        self.return_times = return_times

    def __iter__(self):
        gsp_data = self.gsp_data
        for sample in self.sample_datapipe:
            # Times for each GSP in the sample batch should be the same - take first
            id0 = sample[BatchKey.gsp_t0_idx]
            times = sample[BatchKey.gsp_time_utc][0, id0+1:]
            national_outputs = torch.as_tensor(
                gsp_data.sel(time_utc=times.cpu().numpy().astype("datetime64[s]")).values
            )

            if self.return_times:
                yield national_outputs, times
            else:
                yield national_outputs


class ReorganiseBatch(IterDataPipe):
    """Reorganise batches for pvnet_summation"""

    def __init__(self, source_datapipe):
        """Reorganise batches for pvnet_summation

        Args:
            source_datapipe: Zipped datapipe of list[tuple(NumpyBatch, national_outputs, times)]
        """
        self.source_datapipe = source_datapipe

    def __iter__(self):
        for batch in self.source_datapipe:
            yield dict(
                pvnet_inputs=[sample[0] for sample in batch],
                national_targets=torch.stack([sample[1] for sample in batch]),
                times=torch.stack([sample[2] for sample in batch]),
            )

class DataModule(LightningDataModule):
    """Datamodule for training pvnet_summation."""

    def __init__(
        self,
        batch_dir: str,
        gsp_zarr_path: str,
        batch_size=16,
        num_workers=0,
        prefetch_factor=2,
    ):
        """Datamodule for training pvnet_summation.

        Args:
            batch_dir: Path to the directory of pre-saved batches.
            gsp_zarr_path: Path to zarr file containing GSP ID 0 outputs
            batch_size: Batch size.
            num_workers: Number of workers to use in multiprocess batch loading.
            prefetch_factor: Number of batches loaded in advance by each worker.
        """
        super().__init__()
        self.gsp_zarr_path = gsp_zarr_path
        self.batch_size = batch_size
        self.batch_dir = batch_dir

        self.readingservice_config = dict(
            num_workers=num_workers,
            multiprocessing_context="spawn",
            worker_prefetch_cnt=prefetch_factor,
        )

    def _get_premade_batches_datapipe(self, subdir, shuffle=False):
        # List the pre-saved PVNet batches and load them as tensors
        data_pipeline = FileLister(f"{self.batch_dir}/{subdir}", masks="*.pt", recursive=False)
        if shuffle:
            data_pipeline = data_pipeline.shuffle(buffer_size=1000)

        data_pipeline = data_pipeline.sharding_filter().map(torch.load)

        # Add the national target
        data_pipeline, dp = data_pipeline.fork(2, buffer_size=5)

        gsp_datapipe = OpenGSP(gsp_pv_power_zarr_path=self.gsp_zarr_path).map(normalize_gsp)
        gsp_data = next(iter(gsp_datapipe)).sel(gsp_id=0).compute()

        national_targets_datapipe, times_datapipe = (
            GetNationalPVLive(gsp_data, dp, return_times=True).unzip(sequence_length=2)
        )
        data_pipeline = data_pipeline.zip(national_targets_datapipe, times_datapipe)

        data_pipeline = ReorganiseBatch(data_pipeline.batch(self.batch_size))

        return data_pipeline

    def train_dataloader(self):
        """Construct train dataloader"""
        datapipe = self._get_premade_batches_datapipe("train", shuffle=True)

        rs = MultiProcessingReadingService(**self.readingservice_config)
        return DataLoader2(datapipe, reading_service=rs)

    def val_dataloader(self):
        """Construct val dataloader"""
        datapipe = self._get_premade_batches_datapipe("val")

        rs = MultiProcessingReadingService(**self.readingservice_config)
        return DataLoader2(datapipe, reading_service=rs)

    def test_dataloader(self):
        """Construct test dataloader"""
        datapipe = self._get_premade_batches_datapipe("test")

        rs = MultiProcessingReadingService(**self.readingservice_config)
        return DataLoader2(datapipe, reading_service=rs)
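Each batch this datamodule yields is the dict assembled in `ReorganiseBatch`: a list of per-sample PVNet input batches plus stacked national targets and times. A minimal smoke-test sketch, assuming the batch directory and GSP zarr from `configs/datamodule/default.yaml` are available on the machine:

```python
from pvnet_summation.data.datamodule import DataModule

datamodule = DataModule(
    batch_dir="/mnt/disks/bigbatches/concurrent_batches_v3.6_-60mins",
    gsp_zarr_path="/mnt/disks/nwp/pv_gsp.zarr",
    batch_size=8,
    num_workers=2,
)

batch = next(iter(datamodule.train_dataloader()))
print(len(batch["pvnet_inputs"]))       # one PVNet NumpyBatch per sample in the batch
print(batch["national_targets"].shape)  # roughly [batch_size, n_forecast_steps]
print(batch["times"].shape)
```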