
Commit

Merge pull request #2 from openclimatefix/presave_pvnet_outputs
Presave pvnet outputs
dfulu authored Jul 24, 2023
2 parents 9c62ff8 + 2143caa commit 8287321
Showing 14 changed files with 355 additions and 165 deletions.
4 changes: 4 additions & 0 deletions configs/config.yaml
@@ -9,6 +9,10 @@ defaults:
- logger: wandb.yaml # set logger here or use command line (e.g. `python run.py logger=wandb`)
- hydra: default.yaml

# Whether to loop through the PVNet outputs and save them out before training
presave_pvnet_outputs: True

# enable color logging
# - override hydra/hydra_logging: colorlog
# - override hydra/job_logging: colorlog
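
A minimal sketch of how this flag might be consumed in the training entry point (hypothetical; only the flag name `presave_pvnet_outputs` comes from the config above, the helper functions are stand-ins):

import hydra
from omegaconf import DictConfig


def save_pvnet_outputs(config):
    # Hypothetical stand-in: run PVNet over all concurrent batches once and
    # save its predictions to disk before training the summation model.
    ...


def train_summation_model(config):
    # Hypothetical stand-in for the training loop.
    ...


@hydra.main(config_path="configs", config_name="config", version_base=None)
def main(config: DictConfig):
    if config.presave_pvnet_outputs:
        save_pvnet_outputs(config)
    train_summation_model(config)


if __name__ == "__main__":
    main()
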
2 changes: 1 addition & 1 deletion configs/datamodule/default.yaml
@@ -3,4 +3,4 @@ batch_dir: "/mnt/disks/bigbatches/concurrent_batches_v3.6_-60mins"
gsp_zarr_path: "/mnt/disks/nwp/pv_gsp.zarr"
batch_size: 8
num_workers: 20
prefetch_factor: 2
prefetch_factor: 2
1 change: 0 additions & 1 deletion configs/model/default.yaml
@@ -18,7 +18,6 @@ output_network_kwargs:
res_block_layers: 2
dropout_frac: 0.0


# Forecast and time settings
forecast_minutes: 480

2 changes: 1 addition & 1 deletion pvnet_summation/__init__.py
@@ -1 +1 @@
"""PVNet_summation"""
"""PVNet_summation"""
279 changes: 217 additions & 62 deletions pvnet_summation/data/datamodule.py
@@ -2,61 +2,112 @@

import torch
from lightning.pytorch import LightningDataModule
from torchdata.dataloader2 import DataLoader2, MultiProcessingReadingService
from torchdata.datapipes.iter import FileLister, IterDataPipe
from ocf_datapipes.utils.consts import BatchKey
from ocf_datapipes.load import OpenGSP
from ocf_datapipes.training.pvnet import normalize_gsp
from ocf_datapipes.utils.consts import BatchKey
from torchdata.dataloader2 import DataLoader2, MultiProcessingReadingService
from torchdata.datapipes.iter import FileLister, IterDataPipe, Zipper

from pvnet.data.datamodule import (
copy_batch_to_device,
batch_to_tensor,
split_batches,
)
# https://github.com/pytorch/pytorch/issues/973
torch.multiprocessing.set_sharing_strategy('file_system')
torch.multiprocessing.set_sharing_strategy("file_system")


class GetNationalPVLive(IterDataPipe):
def __init__(self, gsp_data, sample_datapipe, return_times=False):
"""Select national output targets for given times"""

def __init__(self, gsp_data, times_datapipe):
"""Select national output targets for given times
Args:
gsp_data: xarray DataArray of the national outputs
times_datapipe: IterDataPipe yielding arrays of target times.
"""
self.gsp_data = gsp_data
self.sample_datapipe = sample_datapipe
self.return_times = return_times

self.times_datapipe = times_datapipe

def __iter__(self):
gsp_data = self.gsp_data
for sample in self.sample_datapipe:
# Times for each GSP in the sample batch should be the same - take first
id0 = sample[BatchKey.gsp_t0_idx]
times = sample[BatchKey.gsp_time_utc][0, id0+1:]
for times in self.times_datapipe:
national_outputs = torch.as_tensor(
gsp_data.sel(time_utc=times.cpu().numpy().astype("datetime64[s]")).values
)

if self.return_times:
yield national_outputs, times
else:
yield national_outputs


class ReorganiseBatch(IterDataPipe):
"""Reoragnise batches for pvnet_summation"""
yield national_outputs
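
A self-contained sketch of what GetNationalPVLive does, on synthetic data (the int64 seconds-since-epoch encoding of the target times is an assumption about how `gsp_time_utc` is stored):

import numpy as np
import torch
import xarray as xr
from torchdata.datapipes.iter import IterableWrapper

hours = np.array(["2023-01-01T00", "2023-01-01T01", "2023-01-01T02"], dtype="datetime64[s]")
gsp_data = xr.DataArray([0.1, 0.2, 0.3], coords={"time_utc": hours}, dims="time_utc")

# Target times arrive as a tensor of seconds since the epoch
times = torch.as_tensor(hours[1:].astype(np.int64))
pipe = GetNationalPVLive(gsp_data, IterableWrapper([times]))
print(next(iter(pipe)))  # tensor([0.2000, 0.3000], dtype=torch.float64)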


class GetBatchTime(IterDataPipe):
"""Extract the valid times from the concurrent sample batch"""

def __init__(self, sample_datapipe):
"""Extract the valid times from the concurrent sample batch
Args:
sample_datapipe: IterDataPipe yielding concurrent sample batches
"""
self.sample_datapipe = sample_datapipe

def __iter__(self):
for sample in self.sample_datapipe:
# Times for each GSP in the sample batch should be the same - take first
id0 = sample[BatchKey.gsp_t0_idx]
times = sample[BatchKey.gsp_time_utc][0, id0 + 1 :]
yield times
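
For illustration, a synthetic concurrent sample shows how the slicing above picks out the forecast-horizon times after t0 from the first GSP (shapes assumed):

import torch
from ocf_datapipes.utils.consts import BatchKey
from torchdata.datapipes.iter import IterableWrapper

sample = {
    BatchKey.gsp_t0_idx: 3,
    BatchKey.gsp_time_utc: torch.arange(10).repeat(5, 1),  # [n_gsps, seq_len]
}
times = next(iter(GetBatchTime(IterableWrapper([sample]))))
print(times)  # tensor([4, 5, 6, 7, 8, 9])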


class PivotDictList(IterDataPipe):
"""Convert list of dicts to dict of lists"""

def __init__(self, source_datapipe):
"""Reoragnise batches for pvnet_summation
"""Convert list of dicts to dict of lists
Args:
source_datapipe: Datapipe yielding lists of dicts
"""
self.source_datapipe = source_datapipe

def __iter__(self):
for list_of_dicts in self.source_datapipe:
keys = list_of_dicts[0].keys()
batch_dict = {k: [d[k] for d in list_of_dicts] for k in keys}
yield batch_dict
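
A quick standalone illustration of the pivot (synthetic input):

from torchdata.datapipes.iter import IterableWrapper

batches = IterableWrapper([
    [{"a": 1, "b": 2}, {"a": 3, "b": 4}],  # one batch = list of sample dicts
])
print(next(iter(PivotDictList(batches))))  # {'a': [1, 3], 'b': [2, 4]}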


class DictApply(IterDataPipe):
"""Apply functions to elements of a dictionary and return processed dictionary."""

def __init__(self, source_datapipe, **transforms):
"""Apply functions to elements of a dictionary and return processed dictionary.
Args:
source_datapipe: Zipped datapipe of list[tuple(NumpyBatch, national_outputs)]
source_datapipe: Datapipe which yields dicts
**transforms: key-function pairs
"""
self.source_datapipe = source_datapipe

self.transforms = transforms

def __iter__(self):
for batch in self.source_datapipe:
yield dict(
pvnet_inputs = [sample[0] for sample in batch],
national_targets = torch.stack([sample[1] for sample in batch]),
times = torch.stack([sample[2] for sample in batch]),
)

for d in self.source_datapipe:
for key, function in self.transforms.items():
d[key] = function(d[key])
yield d
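
For example, stacking a list of tensors under one key while leaving other keys untouched (synthetic input):

import torch
from torchdata.datapipes.iter import IterableWrapper

pipe = IterableWrapper([{"x": [torch.zeros(2), torch.ones(2)], "y": "meta"}])
pipe = DictApply(pipe, x=torch.stack)
print(next(iter(pipe))["x"].shape)  # torch.Size([2, 2])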


class ZipperDict(IterDataPipe):
"""Yield samples from multiple datapipes as a dict"""

def __init__(self, **datapipes):
"""Yield samples from multiple datapipes as a dict.
Args:
**datapipes: Named datapipes
"""
self.keys = list(datapipes.keys())
self.source_datapipes = Zipper(*[datapipes[key] for key in self.keys])

def __iter__(self):
for outputs in self.source_datapipes:
yield {key: value for key, value in zip(self.keys, outputs)}
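
A standalone illustration (synthetic datapipes):

from torchdata.datapipes.iter import IterableWrapper

pipe = ZipperDict(a=IterableWrapper([1, 2]), b=IterableWrapper(["x", "y"]))
print(list(pipe))  # [{'a': 1, 'b': 'x'}, {'a': 2, 'b': 'y'}]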


class DataModule(LightningDataModule):
"""Datamodule for training pvnet_summation."""

@@ -87,46 +138,150 @@ def __init__(
multiprocessing_context="spawn",
worker_prefetch_cnt=prefetch_factor,
)

def _get_premade_batches_datapipe(self, subdir, shuffle=False):
data_pipeline = FileLister(f"{self.batch_dir}/{subdir}", masks="*.pt", recursive=False)

def _get_premade_batches_datapipe(self, subdir, shuffle=False, add_filename=False):
# Load presaved concurrent sample batches
file_pipeline = FileLister(f"{self.batch_dir}/{subdir}", masks="*.pt", recursive=False)

if shuffle:
data_pipeline = data_pipeline.shuffle(buffer_size=1000)

data_pipeline = data_pipeline.sharding_filter().map(torch.load)

# Add the national target
data_pipeline, dp = data_pipeline.fork(2, buffer_size=5)

gsp_datapipe = OpenGSP(gsp_pv_power_zarr_path=self.gsp_zarr_path).map(normalize_gsp)
gsp_data = next(iter(gsp_datapipe)).sel(gsp_id=0).compute()

national_targets_datapipe, times_datapipe = (
GetNationalPVLive(gsp_data, dp, return_times=True).unzip(sequence_length=2)
file_pipeline = file_pipeline.shuffle(buffer_size=1000)

file_pipeline = file_pipeline.sharding_filter()

if add_filename:
file_pipeline, file_pipeline_copy = file_pipeline.fork(2, buffer_size=5)

sample_pipeline = file_pipeline.map(torch.load)

# Find the national output at the same times as the concurrent samples
gsp_data = (
next(iter(OpenGSP(gsp_pv_power_zarr_path=self.gsp_zarr_path).map(normalize_gsp)))
.sel(gsp_id=0)
.compute()
)

sample_pipeline, sample_pipeline_copy = sample_pipeline.fork(2, buffer_size=5)

times_datapipe, times_datapipe_copy = GetBatchTime(sample_pipeline_copy).fork(
2, buffer_size=5
)
data_pipeline = data_pipeline.zip(national_targets_datapipe, times_datapipe)

data_pipeline = ReorganiseBatch(data_pipeline.batch(self.batch_size))


national_targets_datapipe = GetNationalPVLive(gsp_data, times_datapipe_copy)

# Compile the samples
if add_filename:
data_pipeline = ZipperDict(
pvnet_inputs=sample_pipeline,
national_targets=national_targets_datapipe,
times=times_datapipe,
filepath=file_pipeline_copy,
)
else:
data_pipeline = ZipperDict(
pvnet_inputs=sample_pipeline,
national_targets=national_targets_datapipe,
times=times_datapipe,
)

if self.batch_size is not None:
data_pipeline = PivotDictList(data_pipeline.batch(self.batch_size))
data_pipeline = DictApply(
data_pipeline,
national_targets=torch.stack,
times=torch.stack,
)

return data_pipeline
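
Putting the pieces together, each element yielded by the batched pipeline should look like the sketch below (shapes are assumptions; 480 forecast minutes at half-hourly GSP resolution would give 16 steps):

import torch

batch = {
    "pvnet_inputs": [dict()] * 8,            # batch_size NumpyBatch samples
    "national_targets": torch.zeros(8, 16),  # [batch_size, forecast_len]
    "times": torch.zeros(8, 16),             # [batch_size, forecast_len]
}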

def train_dataloader(self):
def train_dataloader(self, shuffle=True, add_filename=False):
"""Construct train dataloader"""
datapipe = self._get_premade_batches_datapipe("train", shuffle=True)
datapipe = self._get_premade_batches_datapipe(
"train", shuffle=shuffle, add_filename=add_filename
)

rs = MultiProcessingReadingService(**self.readingservice_config)
return DataLoader2(datapipe, reading_service=rs)

def val_dataloader(self):
def val_dataloader(self, shuffle=False, add_filename=False):
"""Construct val dataloader"""
datapipe = self._get_premade_batches_datapipe("val")

datapipe = self._get_premade_batches_datapipe(
"val", shuffle=shuffle, add_filename=add_filename
)
rs = MultiProcessingReadingService(**self.readingservice_config)
return DataLoader2(datapipe, reading_service=rs)

def test_dataloader(self):
"""Construct test dataloader"""
datapipe = self._get_premade_batches_datapipe("test")
raise NotImplementedError


class PVNetPresavedDataModule(LightningDataModule):
"""Datamodule for loading pre-saved PVNet predictions to train pvnet_summation."""

def __init__(
self,
batch_dir: str,
batch_size=16,
num_workers=0,
prefetch_factor=2,
):
"""Datamodule for loading pre-saved PVNet predictions to train pvnet_summation.
Args:
batch_dir: Path to the directory of pre-saved batches.
batch_size: Batch size.
num_workers: Number of workers to use in multiprocess batch loading.
prefetch_factor: Number of batches loaded in advance by each worker.
"""
super().__init__()
self.batch_size = batch_size
self.batch_dir = batch_dir

self.readingservice_config = dict(
num_workers=num_workers,
multiprocessing_context="spawn",
worker_prefetch_cnt=prefetch_factor,
)

def _get_premade_batches_datapipe(self, subdir, shuffle=False):
# Load presaved concurrent sample batches
file_pipeline = FileLister(f"{self.batch_dir}/{subdir}", masks="*.pt", recursive=False)

if shuffle:
file_pipeline = file_pipeline.shuffle(buffer_size=1000)

sample_pipeline = file_pipeline.sharding_filter().map(torch.load)

if self.batch_size is not None:
batch_pipeline = PivotDictList(sample_pipeline.batch(self.batch_size))
batch_pipeline = DictApply(
batch_pipeline,
pvnet_outputs=torch.stack,
national_targets=torch.stack,
times=torch.stack,
)

return batch_pipeline
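
Given the keys stacked above, each pre-saved `.pt` file presumably holds one dict with `pvnet_outputs`, `national_targets` and `times`; a hypothetical sketch of writing such a sample (all shapes and the path are assumptions):

import torch

sample = {
    "pvnet_outputs": torch.zeros(317, 16),        # per-GSP PVNet predictions
    "national_targets": torch.zeros(16),
    "times": torch.zeros(16, dtype=torch.int64),  # seconds since epoch (assumed)
}
torch.save(sample, "presaved_batches/train/000000.pt")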

def train_dataloader(self, shuffle=True):
"""Construct train dataloader"""
datapipe = self._get_premade_batches_datapipe(
"train",
shuffle=shuffle,
)

rs = MultiProcessingReadingService(**self.readingservice_config)
return DataLoader2(datapipe, reading_service=rs)

def val_dataloader(self, shuffle=False):
"""Construct val dataloader"""
datapipe = self._get_premade_batches_datapipe(
"val",
shuffle=shuffle,
)
rs = MultiProcessingReadingService(**self.readingservice_config)
return DataLoader2(datapipe, reading_service=rs)

def test_dataloader(self):
"""Construct test dataloader"""
raise NotImplementedError