Commit
Showing 33 changed files with 555 additions and 559 deletions.
@@ -1,57 +1,116 @@

Before:

```python
""" Data module for pytorch lightning """
import resource

import torch
from ocf_datapipes.batch import batch_to_tensor, stack_np_examples_into_batch
from ocf_datapipes.training.pvnet import pvnet_datapipe
from torch.utils.data.datapipes.iter import FileLister

from pvnet.data.base import BaseDataModule

# Raise the soft limit on open file descriptors; the datapipes keep many files open
rlimit = resource.getrlimit(resource.RLIMIT_NOFILE)
resource.setrlimit(resource.RLIMIT_NOFILE, (2048, rlimit[1]))


class DataModule(BaseDataModule):
    """Datamodule for training pvnet and using pvnet pipeline in `ocf_datapipes`."""

    def _get_datapipe(self, start_time, end_time):
        data_pipeline = pvnet_datapipe(
            self.configuration,
            start_time=start_time,
            end_time=end_time,
        )

        data_pipeline = (
            data_pipeline.batch(self.batch_size)
            .map(stack_np_examples_into_batch)
            .map(batch_to_tensor)
        )
        return data_pipeline

    def _get_premade_batches_datapipe(self, subdir, shuffle=False):
        data_pipeline = FileLister(f"{self.batch_dir}/{subdir}", masks="*.pt", recursive=False)
        if shuffle:
            data_pipeline = (
                data_pipeline.shuffle(buffer_size=10_000)
                .sharding_filter()
                .map(torch.load)
                # Split the batches and reshuffle them to be combined into new batches
                .split_batches()
                .shuffle(buffer_size=self.shuffle_factor * self.batch_size)
            )
        else:
            data_pipeline = (
                data_pipeline.sharding_filter().map(torch.load)
                # Split the batches so we can use any batch-size
                .split_batches()
            )

        data_pipeline = (
            data_pipeline.batch(self.batch_size)
            .map(stack_np_examples_into_batch)
            .map(batch_to_tensor)
        )
        return data_pipeline
```

After:

```python
""" Data module for pytorch lightning """
from glob import glob

import torch
from lightning.pytorch import LightningDataModule
from ocf_data_sampler.torch_datasets.pvnet_uk_regional import PVNetUKRegionalDataset
from ocf_datapipes.batch import (
    NumpyBatch,
    TensorBatch,
    batch_to_tensor,
    stack_np_examples_into_batch,
)
from torch.utils.data import DataLoader, Dataset


class NumpybatchPremadeSamplesDataset(Dataset):
    """Dataset to load NumpyBatch samples"""

    def __init__(self, sample_dir):
        """Dataset to load NumpyBatch samples

        Args:
            sample_dir: Path to the directory of pre-saved samples.
        """
        self.sample_paths = glob(f"{sample_dir}/*.pt")

    def __len__(self):
        return len(self.sample_paths)

    def __getitem__(self, idx):
        return torch.load(self.sample_paths[idx])


def collate_fn(samples: list[NumpyBatch]) -> TensorBatch:
    """Convert a list of NumpyBatch samples to a tensor batch"""
    return batch_to_tensor(stack_np_examples_into_batch(samples))


class DataModule(LightningDataModule):
    """Datamodule for training pvnet and using pvnet pipeline in `ocf_datapipes`."""

    def __init__(
        self,
        configuration: str | None = None,
        sample_dir: str | None = None,
        batch_size: int = 16,
        num_workers: int = 0,
        prefetch_factor: int | None = None,
        train_period: list[str | None] = [None, None],
        val_period: list[str | None] = [None, None],
    ):
        """Datamodule for training pvnet architecture.

        Can also be used with pre-made batches if `sample_dir` is set.

        Args:
            configuration: Path to datapipe configuration file.
            sample_dir: Path to the directory of pre-saved samples. Cannot be used together
                with `configuration` or `[train/val]_period`.
            batch_size: Batch size.
            num_workers: Number of workers to use in multiprocess batch loading.
            prefetch_factor: Number of batches loaded in advance by each worker.
            train_period: Date range filter for train dataloader.
            val_period: Date range filter for val dataloader.
        """
        super().__init__()

        if not ((sample_dir is not None) ^ (configuration is not None)):
            raise ValueError("Exactly one of `sample_dir` or `configuration` must be set.")

        if sample_dir is not None:
            if any(period != [None, None] for period in [train_period, val_period]):
                raise ValueError("Cannot set `(train/val)_period` with presaved samples")

        self.configuration = configuration
        self.sample_dir = sample_dir
        self.train_period = train_period
        self.val_period = val_period

        self._common_dataloader_kwargs = dict(
            batch_size=batch_size,
            sampler=None,
            batch_sampler=None,
            num_workers=num_workers,
            collate_fn=collate_fn,
            pin_memory=False,
            drop_last=False,
            timeout=0,
            worker_init_fn=None,
            prefetch_factor=prefetch_factor,
            persistent_workers=False,
        )

    def _get_streamed_samples_dataset(self, start_time, end_time) -> Dataset:
        return PVNetUKRegionalDataset(self.configuration, start_time=start_time, end_time=end_time)

    def _get_premade_samples_dataset(self, subdir) -> Dataset:
        split_dir = f"{self.sample_dir}/{subdir}"
        return NumpybatchPremadeSamplesDataset(split_dir)

    def train_dataloader(self) -> DataLoader:
        """Construct train dataloader"""
        if self.sample_dir is not None:
            dataset = self._get_premade_samples_dataset("train")
        else:
            dataset = self._get_streamed_samples_dataset(*self.train_period)
        return DataLoader(dataset, shuffle=True, **self._common_dataloader_kwargs)

    def val_dataloader(self) -> DataLoader:
        """Construct val dataloader"""
        if self.sample_dir is not None:
            dataset = self._get_premade_samples_dataset("val")
        else:
            dataset = self._get_streamed_samples_dataset(*self.val_period)
        return DataLoader(dataset, shuffle=False, **self._common_dataloader_kwargs)
```
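The one piece of batching logic the new module keeps from `ocf_datapipes` is the collate step: `stack_np_examples_into_batch` stacks a list of per-sample `NumpyBatch` dicts along a new leading batch dimension, and `batch_to_tensor` converts the stacked numpy arrays to torch tensors. A rough, self-contained sketch of that behaviour, assuming (hypothetically) that each sample is a flat dict of equal-shaped numpy arrays; the real `NumpyBatch` is a nested keyed structure handled by `ocf_datapipes`:

```python
import numpy as np
import torch

def toy_collate(samples: list[dict[str, np.ndarray]]) -> dict[str, torch.Tensor]:
    """Rough stand-in for collate_fn above, i.e.
    batch_to_tensor(stack_np_examples_into_batch(samples)), under the
    simplifying assumption that a sample is a flat dict of numpy arrays."""
    # Stack each key across samples, adding a leading batch dimension
    stacked = {k: np.stack([s[k] for s in samples]) for k in samples[0]}
    # Convert the stacked numpy arrays to torch tensors
    return {k: torch.from_numpy(v) for k, v in stacked.items()}

batch = toy_collate([{"nwp": np.zeros((2, 3))} for _ in range(4)])
assert batch["nwp"].shape == (4, 2, 3)  # 4 samples stacked in front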
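For orientation, here is a minimal usage sketch of the reworked `DataModule`, assuming the module lives at `pvnet.data.datamodule`; the configuration path and sample directory below are hypothetical placeholders:

```python
from pvnet.data.datamodule import DataModule  # module path assumed

# Streamed mode: samples are built on the fly by PVNetUKRegionalDataset
# from an ocf-data-sampler configuration file (path is hypothetical).
dm = DataModule(
    configuration="configs/data_config.yaml",
    batch_size=16,
    num_workers=4,
    prefetch_factor=2,
    train_period=["2020-01-01", "2022-12-31"],
    val_period=["2023-01-01", "2023-06-30"],
)
train_loader = dm.train_dataloader()  # shuffled

# Pre-made mode: mutually exclusive with `configuration` and the period
# filters. `sample_dir` (hypothetical path) must contain "train" and "val"
# subdirectories of .pt sample files.
dm_premade = DataModule(sample_dir="/data/presaved_samples", batch_size=16)
val_loader = dm_premade.val_dataloader()  # not shuffled
```

Note that exactly one of `configuration` or `sample_dir` may be given: the constructor enforces this with an XOR check and raises a `ValueError` otherwise.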