
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
pre-commit-ci[bot] committed Jul 21, 2023
1 parent cc4707d commit 2143caa
Showing 14 changed files with 159 additions and 213 deletions.
4 changes: 2 additions & 2 deletions configs/config.yaml
@@ -10,8 +10,8 @@ defaults:
   - hydra: default.yaml
 
 # Whether to loop through the PVNet outputs and save them out before training
-presave_pvnet_outputs: True
-
+presave_pvnet_outputs:
+  True
 
 # enable color logging
 # - override hydra/hydra_logging: colorlog
2 changes: 1 addition & 1 deletion configs/datamodule/default.yaml
@@ -3,4 +3,4 @@ batch_dir: "/mnt/disks/bigbatches/concurrent_batches_v3.6_-60mins"
 gsp_zarr_path: "/mnt/disks/nwp/pv_gsp.zarr"
 batch_size: 8
 num_workers: 20
-prefetch_factor: 2
\ No newline at end of file
+prefetch_factor: 2
1 change: 0 additions & 1 deletion configs/model/default.yaml
@@ -18,7 +18,6 @@ output_network_kwargs:
   res_block_layers: 2
   dropout_frac: 0.0
 
-
 # Forecast and time settings
 forecast_minutes: 480
 
2 changes: 1 addition & 1 deletion pvnet_summation/__init__.py
@@ -1 +1 @@
"""PVNet_summation"""
"""PVNet_summation"""
157 changes: 68 additions & 89 deletions pvnet_summation/data/datamodule.py
@@ -2,36 +2,30 @@

 import torch
 from lightning.pytorch import LightningDataModule
-from torchdata.dataloader2 import DataLoader2, MultiProcessingReadingService
-from torchdata.datapipes.iter import FileLister, IterDataPipe
-from ocf_datapipes.utils.consts import BatchKey
 from ocf_datapipes.load import OpenGSP
 from ocf_datapipes.training.pvnet import normalize_gsp
-from torchdata.datapipes.iter import Zipper
+from ocf_datapipes.utils.consts import BatchKey
+from torchdata.dataloader2 import DataLoader2, MultiProcessingReadingService
+from torchdata.datapipes.iter import FileLister, IterDataPipe, Zipper
 
 from pvnet.data.datamodule import (
     copy_batch_to_device,
     batch_to_tensor,
     split_batches,
 )
 
 # https://github.com/pytorch/pytorch/issues/973
-torch.multiprocessing.set_sharing_strategy('file_system')
+torch.multiprocessing.set_sharing_strategy("file_system")
 
 
 class GetNationalPVLive(IterDataPipe):
     """Select national output targets for given times"""
 
     def __init__(self, gsp_data, times_datapipe):
         """Select national output targets for given times
 
         Args:
             gsp_data: xarray DataArray of the national outputs
             times_datapipe: IterDataPipe yielding arrays of target times.
         """
         self.gsp_data = gsp_data
         self.times_datapipe = times_datapipe
 
     def __iter__(self):
         gsp_data = self.gsp_data
         for times in self.times_datapipe:
             national_outputs = torch.as_tensor(
@@ -42,54 +36,54 @@ def __iter__(self):

 class GetBatchTime(IterDataPipe):
     """Extract the valid times from the concurrent sample batch"""
 
     def __init__(self, sample_datapipe):
         """Extract the valid times from the concurrent sample batch
 
         Args:
             sample_datapipe: IterDataPipe yielding concurrent sample batches
         """
         self.sample_datapipe = sample_datapipe
 
     def __iter__(self):
         for sample in self.sample_datapipe:
             # Times for each GSP in the sample batch should be the same - take first
             id0 = sample[BatchKey.gsp_t0_idx]
-            times = sample[BatchKey.gsp_time_utc][0, id0+1:]
+            times = sample[BatchKey.gsp_time_utc][0, id0 + 1 :]
             yield times
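
For context: id0 is the forecast-start (t0) index of the batch, so the slice [0, id0 + 1 :] keeps the first GSP's timestamps strictly after t0, i.e. the target horizon. A minimal stand-alone sketch of the same selection on a dummy tensor (shapes are illustrative, not taken from the repository):

    import torch

    gsp_time_utc = torch.arange(8).repeat(4, 1)  # dummy batch: 4 GSPs x 8 timestamps
    t0_idx = 2                                   # dummy stand-in for BatchKey.gsp_t0_idx
    times = gsp_time_utc[0, t0_idx + 1 :]        # same selection as GetBatchTime
    print(times)                                 # tensor([3, 4, 5, 6, 7])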


 class PivotDictList(IterDataPipe):
     """Convert list of dicts to dict of lists"""
 
     def __init__(self, source_datapipe):
         """Convert list of dicts to dict of lists
 
         Args:
             source_datapipe: IterDataPipe yielding lists of dicts
         """
         self.source_datapipe = source_datapipe
 
     def __iter__(self):
         for list_of_dicts in self.source_datapipe:
             keys = list_of_dicts[0].keys()
             batch_dict = {k: [d[k] for d in list_of_dicts] for k in keys}
             yield batch_dict
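
The pivot above is a plain dict comprehension: a list of per-sample dicts becomes one dict of lists, ready for stacking. A toy illustration with hypothetical keys:

    samples = [{"x": 1, "y": 10}, {"x": 2, "y": 20}, {"x": 3, "y": 30}]
    keys = samples[0].keys()
    batch = {k: [d[k] for d in samples] for k in keys}
    print(batch)  # {'x': [1, 2, 3], 'y': [10, 20, 30]}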


 class DictApply(IterDataPipe):
     """Apply functions to elements of a dictionary and return processed dictionary."""
 
     def __init__(self, source_datapipe, **transforms):
         """Apply functions to elements of a dictionary and return processed dictionary.
 
         Args:
             source_datapipe: Datapipe which yields dicts
             **transforms: key-function pairs
         """
         self.source_datapipe = source_datapipe
         self.transforms = transforms
 
     def __iter__(self):
         for d in self.source_datapipe:
             for key, function in self.transforms.items():
@@ -99,21 +93,21 @@ def __iter__(self):

 class ZipperDict(IterDataPipe):
     """Yield samples from multiple datapipes as a dict"""
 
     def __init__(self, **datapipes):
         """Yield samples from multiple datapipes as a dict.
 
         Args:
             **datapipes: Named datapipes
         """
         self.keys = list(datapipes.keys())
         self.source_datapipes = Zipper(*[datapipes[key] for key in self.keys])
 
     def __iter__(self):
         for outputs in self.source_datapipes:
             yield {key: value for key, value in zip(self.keys, outputs)}
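
ZipperDict is a thin re-keying of torchdata's Zipper: the named datapipes are zipped in step, and each step is yielded as a dict. A hedged sketch with toy datapipes, using IterableWrapper as a stand-in for the real pipelines:

    from torchdata.datapipes.iter import IterableWrapper, Zipper

    inputs = IterableWrapper([1, 2, 3])
    targets = IterableWrapper(["a", "b", "c"])
    keys = ["pvnet_inputs", "national_targets"]
    for outputs in Zipper(inputs, targets):
        print(dict(zip(keys, outputs)))
    # {'pvnet_inputs': 1, 'national_targets': 'a'} ...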


 class DataModule(LightningDataModule):
     """Datamodule for training pvnet_summation."""

@@ -144,95 +138,83 @@ def __init__(
multiprocessing_context="spawn",
worker_prefetch_cnt=prefetch_factor,
)

def _get_premade_batches_datapipe(self, subdir, shuffle=False, add_filename=False):

# Load presaved concurrent sample batches
file_pipeline = FileLister(f"{self.batch_dir}/{subdir}", masks="*.pt", recursive=False)

if shuffle:
file_pipeline = file_pipeline.shuffle(buffer_size=1000)

file_pipeline = file_pipeline.sharding_filter()

if add_filename:
file_pipeline, file_pipeline_copy = file_pipeline.fork(2, buffer_size=5)

sample_pipeline = file_pipeline.map(torch.load)

# Find national outout simultaneous to concurrent samples
gsp_data = (
next(iter(
OpenGSP(gsp_pv_power_zarr_path=self.gsp_zarr_path)
.map(normalize_gsp)
))
next(iter(OpenGSP(gsp_pv_power_zarr_path=self.gsp_zarr_path).map(normalize_gsp)))
.sel(gsp_id=0)
.compute()
)

sample_pipeline, sample_pipeline_copy = sample_pipeline.fork(2, buffer_size=5)
times_datapipe, times_datapipe_copy = (
GetBatchTime(sample_pipeline_copy).fork(2, buffer_size=5)

times_datapipe, times_datapipe_copy = GetBatchTime(sample_pipeline_copy).fork(
2, buffer_size=5
)

national_targets_datapipe = GetNationalPVLive(gsp_data, times_datapipe_copy)

# Compile the samples
if add_filename:
data_pipeline = ZipperDict(
pvnet_inputs = sample_pipeline,
national_targets = national_targets_datapipe,
times = times_datapipe,
filepath = file_pipeline_copy,
pvnet_inputs=sample_pipeline,
national_targets=national_targets_datapipe,
times=times_datapipe,
filepath=file_pipeline_copy,
)
else:
data_pipeline = ZipperDict(
pvnet_inputs = sample_pipeline,
national_targets = national_targets_datapipe,
times = times_datapipe,
)
pvnet_inputs=sample_pipeline,
national_targets=national_targets_datapipe,
times=times_datapipe,
)

if self.batch_size is not None:

data_pipeline = PivotDictList(data_pipeline.batch(self.batch_size))
data_pipeline = DictApply(
data_pipeline,
national_targets=torch.stack,
data_pipeline,
national_targets=torch.stack,
times=torch.stack,
)

return data_pipeline
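
The batching branch above composes three steps: .batch(batch_size) groups sample dicts into a list, PivotDictList pivots that list into a dict of lists, and DictApply stacks the tensor-valued entries. A minimal sketch of the final stacking step, with hypothetical values:

    import torch

    pivoted = {"times": [torch.tensor([0, 1]), torch.tensor([2, 3])]}
    stacked = {k: torch.stack(v) for k, v in pivoted.items()}
    print(stacked["times"].shape)  # torch.Size([2, 2])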


     def train_dataloader(self, shuffle=True, add_filename=False):
         """Construct train dataloader"""
         datapipe = self._get_premade_batches_datapipe(
-            "train",
-            shuffle=shuffle,
-            add_filename=add_filename
+            "train", shuffle=shuffle, add_filename=add_filename
         )
 
         rs = MultiProcessingReadingService(**self.readingservice_config)
         return DataLoader2(datapipe, reading_service=rs)


     def val_dataloader(self, shuffle=False, add_filename=False):
         """Construct val dataloader"""
         datapipe = self._get_premade_batches_datapipe(
-            "val",
-            shuffle=shuffle,
-            add_filename=add_filename
+            "val", shuffle=shuffle, add_filename=add_filename
         )
         rs = MultiProcessingReadingService(**self.readingservice_config)
         return DataLoader2(datapipe, reading_service=rs)


     def test_dataloader(self):
         """Construct test dataloader"""
         raise NotImplementedError


 class PVNetPresavedDataModule(LightningDataModule):
     """Datamodule for loading pre-saved PVNet predictions to train pvnet_summation."""

@@ -260,34 +242,32 @@ def __init__(
multiprocessing_context="spawn",
worker_prefetch_cnt=prefetch_factor,
)

def _get_premade_batches_datapipe(self, subdir, shuffle=False):

# Load presaved concurrent sample batches
file_pipeline = FileLister(f"{self.batch_dir}/{subdir}", masks="*.pt", recursive=False)

if shuffle:
file_pipeline = file_pipeline.shuffle(buffer_size=1000)
sample_pipeline = file_pipeline.sharding_filter().map(torch.load)

sample_pipeline = file_pipeline.sharding_filter().map(torch.load)

if self.batch_size is not None:

batch_pipeline = PivotDictList(sample_pipeline.batch(self.batch_size))
batch_pipeline = DictApply(
batch_pipeline,
pvnet_outputs=torch.stack,
national_targets=torch.stack,
national_targets=torch.stack,
times=torch.stack,
)

return batch_pipeline

     def train_dataloader(self, shuffle=True):
         """Construct train dataloader"""
         datapipe = self._get_premade_batches_datapipe(
             "train",
             shuffle=shuffle,
         )
 
         rs = MultiProcessingReadingService(**self.readingservice_config)
@@ -296,13 +276,12 @@ def train_dataloader(self, shuffle=True):
     def val_dataloader(self, shuffle=False):
         """Construct val dataloader"""
         datapipe = self._get_premade_batches_datapipe(
             "val",
             shuffle=shuffle,
         )
         rs = MultiProcessingReadingService(**self.readingservice_config)
         return DataLoader2(datapipe, reading_service=rs)
 
     def test_dataloader(self):
         """Construct test dataloader"""
         raise NotImplementedError
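
Taken together, a hedged usage sketch of the first DataModule; the constructor arguments are assumptions inferred from configs/datamodule/default.yaml above and from the attributes this diff uses, since the full __init__ signature is collapsed here:

    datamodule = DataModule(
        batch_dir="/mnt/disks/bigbatches/concurrent_batches_v3.6_-60mins",
        gsp_zarr_path="/mnt/disks/nwp/pv_gsp.zarr",
        batch_size=8,
        num_workers=20,
        prefetch_factor=2,
    )
    train_loader = datamodule.train_dataloader(shuffle=True)
    batch = next(iter(train_loader))  # dict with pvnet_inputs, national_targets, times keys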

