Skip to content

Commit

Permalink
Refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
jacobbieker committed Nov 27, 2023
1 parent 1953090 commit d39d2ad
Show file tree
Hide file tree
Showing 5 changed files with 71 additions and 107 deletions.
1 change: 1 addition & 0 deletions pvnet/data/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
"""Data parts"""
from .utils import BatchSplitter
54 changes: 1 addition & 53 deletions pvnet/data/datamodule.py
Original file line number Diff line number Diff line change
@@ -1,65 +1,13 @@
""" Data module for pytorch lightning """
from datetime import datetime

import numpy as np
import torch
from lightning.pytorch import LightningDataModule
from ocf_datapipes.training.pvnet import pvnet_datapipe
from ocf_datapipes.utils.consts import BatchKey
from ocf_datapipes.utils.utils import stack_np_examples_into_batch
from torch.utils.data import DataLoader
from torch.utils.data.datapipes._decorator import functional_datapipe
from torch.utils.data.datapipes.datapipe import IterDataPipe
from torch.utils.data.datapipes.iter import FileLister


def copy_batch_to_device(batch, device):
    """Return a new dict in which every tensor value of ``batch`` lives on ``device``.

    Args:
        batch: Dict whose values may or may not be ``torch.Tensor`` objects.
        device: Target passed straight to ``Tensor.to``.

    Returns:
        A fresh dict with the same keys in the same order. Tensor values are
        the result of ``.to(device)``; any other value is carried over as the
        same object. The input dict itself is left untouched.
    """
    return {
        key: value.to(device) if isinstance(value, torch.Tensor) else value
        for key, value in batch.items()
    }


def batch_to_tensor(batch):
    """Convert the numeric numpy arrays held in ``batch`` into torch tensors.

    The dict is updated in place and also returned. Values that are not
    numpy arrays, and arrays whose dtype is not a ``np.number`` subtype,
    are left exactly as they were.

    Args:
        batch: Dict of values, some of which may be numpy arrays.

    Returns:
        The same ``batch`` object that was passed in.
    """
    # Snapshot the items so entries can be reassigned while looping.
    for key, value in list(batch.items()):
        if not isinstance(value, np.ndarray):
            continue
        if np.issubdtype(value.dtype, np.number):
            batch[key] = torch.as_tensor(value)
    return batch


def split_batches(batch):
    """Break one batched dict into a list of single-example dicts.

    The number of examples is read from ``batch[BatchKey.gsp].shape[0]``.
    Entries whose key name contains ``"idx"`` or ``"channel_names"`` are
    copied unchanged into every example; every other entry is indexed by
    the example's position.

    Args:
        batch: Dict keyed by objects exposing a ``.name`` string.

    Returns:
        A list with one dict per example, each holding the same keys as
        ``batch``.
    """
    sample_count = batch[BatchKey.gsp].shape[0]
    per_sample = []
    for position in range(sample_count):
        example = {}
        for key, value in batch.items():
            is_shared = ("idx" in key.name) or ("channel_names" in key.name)
            example[key] = value if is_shared else value[position]
        per_sample.append(example)
    return per_sample


@functional_datapipe("split_batches")
class BatchSplitter(IterDataPipe):
    """Datapipe that unpacks each incoming batch and emits its examples one by one."""

    def __init__(self, source_datapipe: IterDataPipe):
        """Remember the upstream datapipe whose items are the batches to split.

        Args:
            source_datapipe: Datapipe yielding batches accepted by ``split_batches``.
        """
        self.source_datapipe = source_datapipe

    def __iter__(self):
        """Walk the source batches, yielding every contained example in order."""
        for batch in self.source_datapipe:
            yield from split_batches(batch)
from pvnet.data.utils import batch_to_tensor


class DataModule(LightningDataModule):
Expand Down
53 changes: 53 additions & 0 deletions pvnet/data/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import numpy as np
import torch
from ocf_datapipes.utils.consts import BatchKey
from torch.utils.data import functional_datapipe, IterDataPipe


def copy_batch_to_device(batch, device):
    """Return a new dict in which every tensor value of ``batch`` lives on ``device``.

    Args:
        batch: Dict whose values may or may not be ``torch.Tensor`` objects.
        device: Target passed straight to ``Tensor.to``.

    Returns:
        A fresh dict with the same keys in the same order. Tensor values are
        the result of ``.to(device)``; any other value is carried over as the
        same object. The input dict itself is left untouched.
    """
    return {
        key: value.to(device) if isinstance(value, torch.Tensor) else value
        for key, value in batch.items()
    }


def batch_to_tensor(batch):
    """Convert the numeric numpy arrays held in ``batch`` into torch tensors.

    The dict is updated in place and also returned. Values that are not
    numpy arrays, and arrays whose dtype is not a ``np.number`` subtype,
    are left exactly as they were.

    Args:
        batch: Dict of values, some of which may be numpy arrays.

    Returns:
        The same ``batch`` object that was passed in.
    """
    # Snapshot the items so entries can be reassigned while looping.
    for key, value in list(batch.items()):
        if not isinstance(value, np.ndarray):
            continue
        if np.issubdtype(value.dtype, np.number):
            batch[key] = torch.as_tensor(value)
    return batch


def split_batches(batch, splitting_key=None):
    """Split a single batch of data into a list of per-sample examples.

    Args:
        batch: Dict keyed by objects exposing a ``.name`` string. Entries
            whose key name contains ``"idx"`` or ``"channel_names"`` are
            copied unchanged into every example; every other entry is
            indexed along its first axis.
        splitting_key: Key whose value's ``.shape[0]`` gives the number of
            samples. When ``None`` (the default), ``BatchKey.sensor`` is
            used if it is present in ``batch``, otherwise ``BatchKey.gsp``.

    Returns:
        A list of dicts, one per sample, each with the same keys as ``batch``.

    Raises:
        KeyError: If ``splitting_key`` is not given and ``batch`` contains
            neither ``BatchKey.sensor`` nor ``BatchKey.gsp``, or if an
            explicitly given ``splitting_key`` is missing from ``batch``.
    """
    if splitting_key is None:
        # Batches may be keyed by either sensor or GSP data, so fall back to
        # GSP when no sensor entry is present instead of failing outright.
        splitting_key = next(
            (key for key in (BatchKey.sensor, BatchKey.gsp) if key in batch),
            None,
        )
        if splitting_key is None:
            raise KeyError(
                "Batch contains neither BatchKey.sensor nor BatchKey.gsp; "
                "pass `splitting_key` explicitly"
            )

    n_samples = batch[splitting_key].shape[0]

    # Decide once which entries are shared by all samples rather than
    # re-testing the key names for every sample.
    shared_keys = {
        k for k in batch if ("idx" in k.name) or ("channel_names" in k.name)
    }

    return [
        {k: (v if k in shared_keys else v[i]) for k, v in batch.items()}
        for i in range(n_samples)
    ]


@functional_datapipe("split_batches")
class BatchSplitter(IterDataPipe):
    """Datapipe that unpacks each incoming batch and emits its examples one by one."""

    def __init__(self, source_datapipe: IterDataPipe):
        """Remember the upstream datapipe whose items are the batches to split.

        Args:
            source_datapipe: Datapipe yielding batches accepted by ``split_batches``.
        """
        self.source_datapipe = source_datapipe

    def __iter__(self):
        """Walk the source batches, yielding every contained example in order."""
        for batch in self.source_datapipe:
            yield from split_batches(batch)
54 changes: 1 addition & 53 deletions pvnet/data/wind_datamodule.py
Original file line number Diff line number Diff line change
@@ -1,65 +1,13 @@
""" Data module for pytorch lightning """
from datetime import datetime

import numpy as np
import torch
from lightning.pytorch import LightningDataModule
from ocf_datapipes.training.windnet import windnet_netcdf_datapipe
from ocf_datapipes.utils.consts import BatchKey
from ocf_datapipes.utils.utils import stack_np_examples_into_batch
from torch.utils.data import DataLoader
from torch.utils.data.datapipes._decorator import functional_datapipe
from torch.utils.data.datapipes.datapipe import IterDataPipe
import glob


def copy_batch_to_device(batch, device):
    """Return a new dict in which every tensor value of ``batch`` lives on ``device``.

    Args:
        batch: Dict whose values may or may not be ``torch.Tensor`` objects.
        device: Target passed straight to ``Tensor.to``.

    Returns:
        A fresh dict with the same keys in the same order. Tensor values are
        the result of ``.to(device)``; any other value is carried over as the
        same object. The input dict itself is left untouched.
    """
    return {
        key: value.to(device) if isinstance(value, torch.Tensor) else value
        for key, value in batch.items()
    }


def batch_to_tensor(batch):
    """Convert the numeric numpy arrays held in ``batch`` into torch tensors.

    The dict is updated in place and also returned. Values that are not
    numpy arrays, and arrays whose dtype is not a ``np.number`` subtype,
    are left exactly as they were.

    Args:
        batch: Dict of values, some of which may be numpy arrays.

    Returns:
        The same ``batch`` object that was passed in.
    """
    # Snapshot the items so entries can be reassigned while looping.
    for key, value in list(batch.items()):
        if not isinstance(value, np.ndarray):
            continue
        if np.issubdtype(value.dtype, np.number):
            batch[key] = torch.as_tensor(value)
    return batch


def split_batches(batch):
    """Break one batched dict into a list of single-example dicts.

    The number of examples is read from ``batch[BatchKey.sensor].shape[0]``.
    Entries whose key name contains ``"idx"`` or ``"channel_names"`` are
    copied unchanged into every example; every other entry is indexed by
    the example's position.

    Args:
        batch: Dict keyed by objects exposing a ``.name`` string.

    Returns:
        A list with one dict per example, each holding the same keys as
        ``batch``.
    """
    sample_count = batch[BatchKey.sensor].shape[0]
    per_sample = []
    for position in range(sample_count):
        example = {}
        for key, value in batch.items():
            is_shared = ("idx" in key.name) or ("channel_names" in key.name)
            example[key] = value if is_shared else value[position]
        per_sample.append(example)
    return per_sample


@functional_datapipe("split_batches")
class BatchSplitter(IterDataPipe):
    """Datapipe that unpacks each incoming batch and emits its examples one by one."""

    def __init__(self, source_datapipe: IterDataPipe):
        """Remember the upstream datapipe whose items are the batches to split.

        Args:
            source_datapipe: Datapipe yielding batches accepted by ``split_batches``.
        """
        self.source_datapipe = source_datapipe

    def __iter__(self):
        """Walk the source batches, yielding every contained example in order."""
        for batch in self.source_datapipe:
            yield from split_batches(batch)
from pvnet.data.utils import batch_to_tensor


class WindDataModule(LightningDataModule):
Expand Down
16 changes: 15 additions & 1 deletion tests/data/test_datamodule.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from pvnet.data.datamodule import DataModule

from pvnet.data.wind_datamodule import WindDataModule
import os

def test_init():
dm = DataModule(
Expand All @@ -13,3 +14,16 @@ def test_init():
block_nwp_and_sat=False,
batch_dir="tests/data/sample_batches",
)


def test_wind_init():
    """Check that a ``WindDataModule`` can be constructed with the sample batch directory."""
    # NOTE(review): this path is resolved relative to this test file's own
    # directory, whereas ``batch_dir`` below is relative to the working
    # directory - confirm the configuration file really lives here.
    config_path = f"{os.path.dirname(os.path.abspath(__file__))}/data/sample_batches/data_configuration.yaml"
    unbounded_period = [None, None]
    wind_dm = WindDataModule(
        configuration=config_path,
        batch_size=2,
        num_workers=0,
        prefetch_factor=None,
        train_period=list(unbounded_period),
        val_period=list(unbounded_period),
        test_period=list(unbounded_period),
        batch_dir="tests/data/sample_batches",
    )

0 comments on commit d39d2ad

Please sign in to comment.