Merge pull request #1 from openclimatefix/issue/mvp
Create MVP summation model
Showing 27 changed files with 1,327 additions and 2 deletions.
@@ -0,0 +1,42 @@
#pretrain_early_stopping:
#  _target_: pvnet.callbacks.PretrainEarlyStopping
#  monitor: "MAE/val"  # name of the logged metric which determines when model is improving
#  mode: "min"  # can be "max" or "min"
#  patience: 10  # how many epochs (or val check periods) of not improving until training stops
#  min_delta: 0.001  # minimum change in the monitored metric needed to qualify as an improvement

#pretrain_encoder_freezing:
#  _target_: pvnet.callbacks.PretrainFreeze

early_stopping:
  _target_: pvnet.callbacks.MainEarlyStopping
  # name of the logged metric which determines when model is improving
  monitor: "${resolve_monitor_loss:${model.output_quantiles}}"
  mode: "min"  # can be "max" or "min"
  patience: 10  # how many epochs (or val check periods) of not improving until training stops
  min_delta: 0  # minimum change in the monitored metric needed to qualify as an improvement

learning_rate_monitor:
  _target_: lightning.pytorch.callbacks.LearningRateMonitor
  logging_interval: "epoch"

model_summary:
  _target_: lightning.pytorch.callbacks.ModelSummary
  max_depth: 3

model_checkpoint:
  _target_: lightning.pytorch.callbacks.ModelCheckpoint
  # name of the logged metric which determines when model is improving
  monitor: "${resolve_monitor_loss:${model.output_quantiles}}"
  mode: "min"  # can be "max" or "min"
  save_top_k: 1  # save k best models (determined by above metric)
  save_last: True  # additionally always save model from last epoch
  every_n_epochs: 1
  verbose: False
  filename: "epoch={epoch}-step={step}"
  dirpath: "checkpoints/pvnet_summation/${model_name}"  #${..model_name}
  auto_insert_metric_name: False
  save_on_train_epoch_end: False

#device_stats_monitor:
#  _target_: lightning.pytorch.callbacks.DeviceStatsMonitor
#  cpu_stats: True
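The callback entries above are plain Hydra object configs; how the training entrypoint consumes them is not shown in this diff. A minimal sketch, assuming the usual lightning-hydra-template pattern of instantiating every entry that carries a `_target_`:

import hydra
from omegaconf import DictConfig
from lightning.pytorch import Callback


def instantiate_callbacks(callbacks_cfg: DictConfig) -> list[Callback]:
    """Build a Lightning callback object for every config entry with a _target_."""
    callbacks = []
    for _, cb_conf in callbacks_cfg.items():
        if isinstance(cb_conf, DictConfig) and "_target_" in cb_conf:
            callbacks.append(hydra.utils.instantiate(cb_conf))
    return callbacks

The resulting list would then be passed to the Trainer via its `callbacks` argument.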
Empty file.
@@ -0,0 +1,40 @@
# @package _global_

# specify here default training configuration
defaults:
  - trainer: default.yaml
  - model: default.yaml
  - datamodule: default.yaml
  - callbacks: default.yaml  # set this to null if you don't want to use callbacks
  - logger: wandb.yaml  # set logger here or use command line (e.g. `python run.py logger=wandb`)
  - hydra: default.yaml

  # enable color logging
  # - override hydra/hydra_logging: colorlog
  # - override hydra/job_logging: colorlog

# path to original working directory
# hydra hijacks working directory by changing it to the current log directory,
# so it's useful to have this path as a special variable
# learn more here: https://hydra.cc/docs/next/tutorials/basic/running_your_app/working_directory
work_dir: ${hydra:runtime.cwd}

model_name: "default"

# use `python run.py debug=true` for easy debugging!
# this will run 1 train, val and test loop with only 1 batch
# equivalent to running `python run.py trainer.fast_dev_run=true`
# (this is placed here just for easier access from command line)
debug: False

# pretty print config at the start of the run using Rich library
print_config: True

# disable python warnings if they annoy you
ignore_warnings: True

# check performance on test set, using the best model achieved during training
# lightning chooses best model based on metric specified in checkpoint callback
test_after_training: False

seed: 2727831
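The `monitor` fields in the callbacks config interpolate `${resolve_monitor_loss:${model.output_quantiles}}`, a custom OmegaConf resolver that is not part of this diff and must be registered before the config is composed. A minimal sketch of a plausible registration; the exact mapping (quantile models monitored on "quantile_loss/val", everything else on "MAE/val") is an assumption for illustration:

from omegaconf import OmegaConf


def resolve_monitor_loss(output_quantiles):
    # Assumed behaviour: quantile models log a quantile loss, others log MAE
    return "MAE/val" if output_quantiles is None else "quantile_loss/val"


OmegaConf.register_new_resolver("resolve_monitor_loss", resolve_monitor_loss)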
@@ -0,0 +1,6 @@
_target_: pvnet_summation.data.datamodule.DataModule
batch_dir: "/mnt/disks/bigbatches/concurrent_batches_v3.6_-60mins"
gsp_zarr_path: "/mnt/disks/nwp/pv_gsp.zarr"
batch_size: 8
num_workers: 20
prefetch_factor: 2
@@ -0,0 +1,12 @@
# output paths for hydra logs
run:
  dir: logs/runs/${now:%Y-%m-%d}/${now:%H-%M-%S}
sweep:
  dir: logs/multiruns/${now:%Y-%m-%d_%H-%M-%S}
  subdir: ${hydra.job.num}

# you can set here environment variables that are universal for all users
# for system specific variables (like data paths) it's better to use .env file!
job:
  env_set:
    EXAMPLE_VAR: "example_value"
@@ -0,0 +1,15 @@
# https://wandb.ai

wandb:
  _target_: lightning.pytorch.loggers.wandb.WandbLogger
  project: "pvnet_summation"
  name: "${model_name}"
  save_dir: "/mnt/disks/batches/"
  offline: False  # set True to store all logs only locally
  id: null  # pass correct id to resume experiment!
  # entity: ""  # set to name of your wandb team or just remove it
  log_model: False
  prefix: ""
  job_type: "train"
  group: ""
  tags: []
@@ -0,0 +1,34 @@
_target_: pvnet_summation.models.model.Model

output_quantiles: null

model_name: "openclimatefix/pvnet_v2"
model_version: "898630f3f8cd4e8506525d813dd61c6d8de86144"

#--------------------------------------------
# Tabular network settings
#--------------------------------------------

output_network:
  _target_: pvnet.models.multimodal.linear_networks.networks.ResFCNet2
  _partial_: True
output_network_kwargs:
  fc_hidden_features: 128
  n_res_blocks: 6
  res_block_layers: 2
  dropout_frac: 0.0

# Forecast and time settings
forecast_minutes: 480

# ----------------------------------------------

optimizer:
  _target_: pvnet.optimizers.AdamWReduceLROnPlateau
  lr: 0.0001
  weight_decay: 0.25
  amsgrad: True
  patience: 5
  factor: 0.1
  threshold: 0.002
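`_partial_: True` on `output_network` means Hydra instantiates it as a `functools.partial` of ResFCNet2 rather than a finished network, leaving the summation Model to supply the sizes it only knows at runtime. A minimal sketch of that behaviour; the file path and the keyword names in the final call are illustrative, not taken from this diff:

import hydra
from omegaconf import OmegaConf

cfg = OmegaConf.load("configs/model/default.yaml")  # hypothetical path to the file above
output_network_builder = hydra.utils.instantiate(cfg.output_network)  # a functools.partial(ResFCNet2)

# Inside the Model, the partial would be completed along these (illustrative) lines:
# output_network = output_network_builder(in_features=n_gsps, out_features=1, **cfg.output_network_kwargs)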
@@ -0,0 +1,7 @@
The following folders hold the configuration files

This idea is copied from
https://github.com/ashleve/lightning-hydra-template/blob/main/configs/experiment/example_simple.yaml

run experiments by:
`python run.py experiment=example_simple`
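Because the experiments are composed by Hydra, the folder layout can be sanity-checked without launching a run. A minimal sketch using Hydra's compose API; the relative config path and the root config name ("config") are assumptions, not taken from this diff:

from hydra import compose, initialize
from omegaconf import OmegaConf

with initialize(config_path="configs", version_base=None):
    cfg = compose(config_name="config", overrides=["datamodule.batch_size=4"])
    print(OmegaConf.to_yaml(cfg))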
@@ -0,0 +1,48 @@
_target_: pytorch_lightning.Trainer

# default values for all trainer parameters
checkpoint_callback: True
default_root_dir: null
gradient_clip_val: 0.0
process_position: 0
num_nodes: 1
num_processes: 1
gpus: null
auto_select_gpus: False
tpu_cores: null
log_gpu_memory: null
overfit_batches: 0.0
track_grad_norm: -1
check_val_every_n_epoch: 1
fast_dev_run: False
accumulate_grad_batches: 1
max_epochs: 1
min_epochs: 1
max_steps: null
min_steps: null
limit_train_batches: 1.0
limit_val_batches: 1.0
limit_test_batches: 1.0
val_check_interval: 1.0
flush_logs_every_n_steps: 100
log_every_n_steps: 50
accelerator: null
sync_batchnorm: False
precision: 32
weights_save_path: null
num_sanity_val_steps: 2
truncated_bptt_steps: null
resume_from_checkpoint: null
profiler: null
benchmark: False
deterministic: False
reload_dataloaders_every_epoch: False
auto_lr_find: False
replace_sampler_ddp: True
terminate_on_nan: False
auto_scale_batch_size: False
prepare_data_per_node: True
plugins: null
amp_backend: "native"
amp_level: "O2"
move_metrics_to_cpu: False
@@ -0,0 +1,17 @@
_target_: lightning.pytorch.trainer.trainer.Trainer

# set to "gpu" to train on GPU, "cpu" to train on CPU only
accelerator: gpu
devices: auto

min_epochs: null
max_epochs: null
reload_dataloaders_every_n_epochs: 0
num_sanity_val_steps: 8
fast_dev_run: false
#profiler: 'simple'

accumulate_grad_batches: 4
#val_check_interval: 800
#limit_val_batches: 800
log_every_n_steps: 50
@@ -1,4 +1,4 @@
-name: ocf_template
+name: pvnet_summation
 channels:
   - pytorch
   - conda-forge
@@ -0,0 +1,132 @@
""" Data module for pytorch lightning """

import torch
from lightning.pytorch import LightningDataModule
from torchdata.dataloader2 import DataLoader2, MultiProcessingReadingService
from torchdata.datapipes.iter import FileLister, IterDataPipe
from ocf_datapipes.utils.consts import BatchKey
from ocf_datapipes.load import OpenGSP
from ocf_datapipes.training.pvnet import normalize_gsp

from pvnet.data.datamodule import (
    copy_batch_to_device,
    batch_to_tensor,
    split_batches,
)

# https://github.com/pytorch/pytorch/issues/973
torch.multiprocessing.set_sharing_strategy('file_system')


class GetNationalPVLive(IterDataPipe):
    """Select the national (GSP ID 0) outputs matching each sample's forecast times."""

    def __init__(self, gsp_data, sample_datapipe, return_times=False):
        """
        Args:
            gsp_data: National (GSP ID 0) outputs indexed by time_utc.
            sample_datapipe: Datapipe yielding pre-saved PVNet samples.
            return_times: Whether to also yield the selected times.
        """
        self.gsp_data = gsp_data
        self.sample_datapipe = sample_datapipe
        self.return_times = return_times

    def __iter__(self):
        gsp_data = self.gsp_data
        for sample in self.sample_datapipe:
            # Times for each GSP in the sample batch should be the same - take first
            id0 = sample[BatchKey.gsp_t0_idx]
            times = sample[BatchKey.gsp_time_utc][0, id0 + 1:]
            national_outputs = torch.as_tensor(
                gsp_data.sel(time_utc=times.cpu().numpy().astype("datetime64[s]")).values
            )

            if self.return_times:
                yield national_outputs, times
            else:
                yield national_outputs


class ReorganiseBatch(IterDataPipe):
    """Reorganise batches for pvnet_summation"""

    def __init__(self, source_datapipe):
        """Reorganise batches for pvnet_summation

        Args:
            source_datapipe: Zipped datapipe of list[tuple(NumpyBatch, national_outputs, times)]
        """
        self.source_datapipe = source_datapipe

    def __iter__(self):
        for batch in self.source_datapipe:
            yield dict(
                pvnet_inputs=[sample[0] for sample in batch],
                national_targets=torch.stack([sample[1] for sample in batch]),
                times=torch.stack([sample[2] for sample in batch]),
            )


class DataModule(LightningDataModule):
    """Datamodule for training pvnet_summation."""

    def __init__(
        self,
        batch_dir: str,
        gsp_zarr_path: str,
        batch_size=16,
        num_workers=0,
        prefetch_factor=2,
    ):
        """Datamodule for training pvnet_summation.

        Args:
            batch_dir: Path to the directory of pre-saved batches.
            gsp_zarr_path: Path to zarr file containing GSP ID 0 outputs
            batch_size: Batch size.
            num_workers: Number of workers to use in multiprocess batch loading.
            prefetch_factor: Number of batches prefetched by each worker process.
        """
        super().__init__()
        self.gsp_zarr_path = gsp_zarr_path
        self.batch_size = batch_size
        self.batch_dir = batch_dir

        self.readingservice_config = dict(
            num_workers=num_workers,
            multiprocessing_context="spawn",
            worker_prefetch_cnt=prefetch_factor,
        )

    def _get_premade_batches_datapipe(self, subdir, shuffle=False):
        data_pipeline = FileLister(f"{self.batch_dir}/{subdir}", masks="*.pt", recursive=False)
        if shuffle:
            data_pipeline = data_pipeline.shuffle(buffer_size=1000)

        data_pipeline = data_pipeline.sharding_filter().map(torch.load)

        # Add the national target
        data_pipeline, dp = data_pipeline.fork(2, buffer_size=5)

        gsp_datapipe = OpenGSP(gsp_pv_power_zarr_path=self.gsp_zarr_path).map(normalize_gsp)
        gsp_data = next(iter(gsp_datapipe)).sel(gsp_id=0).compute()

        national_targets_datapipe, times_datapipe = (
            GetNationalPVLive(gsp_data, dp, return_times=True).unzip(sequence_length=2)
        )
        data_pipeline = data_pipeline.zip(national_targets_datapipe, times_datapipe)

        data_pipeline = ReorganiseBatch(data_pipeline.batch(self.batch_size))

        return data_pipeline

    def train_dataloader(self):
        """Construct train dataloader"""
        datapipe = self._get_premade_batches_datapipe("train", shuffle=True)

        rs = MultiProcessingReadingService(**self.readingservice_config)
        return DataLoader2(datapipe, reading_service=rs)

    def val_dataloader(self):
        """Construct val dataloader"""
        datapipe = self._get_premade_batches_datapipe("val")

        rs = MultiProcessingReadingService(**self.readingservice_config)
        return DataLoader2(datapipe, reading_service=rs)

    def test_dataloader(self):
        """Construct test dataloader"""
        datapipe = self._get_premade_batches_datapipe("test")

        rs = MultiProcessingReadingService(**self.readingservice_config)
        return DataLoader2(datapipe, reading_service=rs)
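For completeness, a minimal usage sketch of the new DataModule outside the Hydra/Lightning run; the paths are placeholders and the small worker count is just for a quick local check, not values from this diff:

from pvnet_summation.data.datamodule import DataModule

dm = DataModule(
    batch_dir="/path/to/presaved_batches",  # expects train/, val/ and test/ subdirs of *.pt batches
    gsp_zarr_path="/path/to/pv_gsp.zarr",   # national (GSP ID 0) outputs
    batch_size=8,
    num_workers=2,
    prefetch_factor=2,
)

# Each batch is the dict built by ReorganiseBatch above
batch = next(iter(dm.train_dataloader()))
print(len(batch["pvnet_inputs"]), batch["national_targets"].shape, batch["times"].shape)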