Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Nov 27, 2023
1 parent c7565a0 commit 4b9f837
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 7 deletions.
1 change: 1 addition & 0 deletions pvnet/data/datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from ocf_datapipes.utils.utils import stack_np_examples_into_batch
from torch.utils.data import DataLoader
from torch.utils.data.datapipes.iter import FileLister

from pvnet.data.utils import batch_to_tensor


Expand Down
2 changes: 1 addition & 1 deletion pvnet/data/utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import numpy as np
import torch
from ocf_datapipes.utils.consts import BatchKey
from torch.utils.data import functional_datapipe, IterDataPipe
from torch.utils.data import IterDataPipe, functional_datapipe


def copy_batch_to_device(batch, device):
Expand Down
13 changes: 7 additions & 6 deletions pvnet/data/wind_datamodule.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
""" Data module for pytorch lightning """
import glob
from datetime import datetime

from lightning.pytorch import LightningDataModule
from ocf_datapipes.training.windnet import windnet_netcdf_datapipe
from ocf_datapipes.utils.utils import stack_np_examples_into_batch
from torch.utils.data import DataLoader
import glob

from pvnet.data.utils import batch_to_tensor

Expand Down Expand Up @@ -46,7 +46,6 @@ def __init__(
self.batch_size = batch_size
self.batch_dir = batch_dir


if batch_dir is not None:
if any([period != [None, None] for period in [train_period, val_period, test_period]]):
raise ValueError("Cannot set `(train/val/test)_period` with presaved batches")
Expand Down Expand Up @@ -79,7 +78,7 @@ def __init__(
def _get_datapipe(self, start_time, end_time):
data_pipeline = windnet_netcdf_datapipe(
self.configuration,
keys=["sensor","nwp"],
keys=["sensor", "nwp"],
)

data_pipeline = (
Expand All @@ -90,9 +89,11 @@ def _get_datapipe(self, start_time, end_time):
return data_pipeline

def _get_premade_batches_datapipe(self, subdir, shuffle=False):
data_pipeline = windnet_netcdf_datapipe(config_filename=self.configuration,
keys=["sensor","nwp"],
filenames=list(glob.glob(f"{self.batch_dir}/{subdir}/*.nc")))
data_pipeline = windnet_netcdf_datapipe(
config_filename=self.configuration,
keys=["sensor", "nwp"],
filenames=list(glob.glob(f"{self.batch_dir}/{subdir}/*.nc")),
)
if shuffle:
data_pipeline = (
data_pipeline.shuffle(buffer_size=100)
Expand Down
1 change: 1 addition & 0 deletions tests/data/test_datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from pvnet.data.wind_datamodule import WindDataModule
import os


def test_init():
dm = DataModule(
configuration=None,
Expand Down

0 comments on commit 4b9f837

Please sign in to comment.