diff --git a/pvnet/data/datamodule.py b/pvnet/data/datamodule.py
index 78783975..83f10487 100644
--- a/pvnet/data/datamodule.py
+++ b/pvnet/data/datamodule.py
@@ -3,13 +3,17 @@
 
 import numpy as np
 import torch
+from torch.utils.data import DataLoader
+from torch.utils.data.datapipes._decorator import functional_datapipe
+from torch.utils.data.datapipes.datapipe import IterDataPipe
+from torch.utils.data.datapipes.iter import FileLister
+
 from lightning.pytorch import LightningDataModule
+
 from ocf_datapipes.training.pvnet import pvnet_datapipe
 from ocf_datapipes.utils.consts import BatchKey
 from ocf_datapipes.utils.utils import stack_np_examples_into_batch
-from torchdata.dataloader2 import DataLoader2, MultiProcessingReadingService
-from torchdata.datapipes import functional_datapipe
-from torchdata.datapipes.iter import FileLister, IterDataPipe
+
 
 
 def copy_batch_to_device(batch, device):
@@ -117,11 +121,20 @@ def __init__(
         self.test_period = [
             None if d is None else datetime.strptime(d, "%Y-%m-%d") for d in test_period
         ]
-
-        self.readingservice_config = dict(
-            num_workers=num_workers,
-            multiprocessing_context="spawn",
-            worker_prefetch_cnt=prefetch_factor,
+
+        self._common_dataloader_kwargs = dict(
+            shuffle=False,  # shuffled in datapipe step
+            batch_size=None,  # batched in datapipe step
+            sampler=None,
+            batch_sampler=None,
+            num_workers=num_workers,
+            collate_fn=None,
+            pin_memory=False,
+            drop_last=False,
+            timeout=0,
+            worker_init_fn=None,
+            prefetch_factor=prefetch_factor,
+            persistent_workers=False
         )
 
     def _get_datapipe(self, start_time, end_time):
@@ -173,7 +186,6 @@ def train_dataloader(self):
         else:
             datapipe = self._get_datapipe(*self.train_period)
-        rs = MultiProcessingReadingService(**self.readingservice_config)
-        return DataLoader2(datapipe, reading_service=rs)
+        return DataLoader(datapipe, **self._common_dataloader_kwargs)
 
     def val_dataloader(self):
         """Construct val dataloader"""
@@ -181,8 +193,8 @@ def val_dataloader(self):
             datapipe = self._get_premade_batches_datapipe("val")
         else:
             datapipe = self._get_datapipe(*self.val_period)
-        rs = MultiProcessingReadingService(**self.readingservice_config)
-        return DataLoader2(datapipe, reading_service=rs)
+        return DataLoader(datapipe, **self._common_dataloader_kwargs)
+
 
     def test_dataloader(self):
         """Construct test dataloader"""
@@ -190,5 +202,5 @@ def test_dataloader(self):
             datapipe = self._get_premade_batches_datapipe("test")
         else:
             datapipe = self._get_datapipe(*self.test_period)
-        rs = MultiProcessingReadingService(**self.readingservice_config)
-        return DataLoader2(datapipe, reading_service=rs)
+        return DataLoader(datapipe, **self._common_dataloader_kwargs)
+
diff --git a/requirements.txt b/requirements.txt
index 61148f13..e3bcea24 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,4 +1,4 @@
-ocf_datapipes>=2.0.6
+ocf_datapipes>=2.2.2
 nowcasting_utils
 ocf_ml_metrics
 numpy
@@ -9,7 +9,6 @@ ipykernel
 h5netcdf
 torch>=2.0
 lightning>=2.0.1
-torchdata
 torchvision
 pytest
 pytest-cov
diff --git a/scripts/load_batches.py b/scripts/load_batches.py
index 261371fe..718f6c8f 100644
--- a/scripts/load_batches.py
+++ b/scripts/load_batches.py
@@ -2,7 +2,7 @@
 """
 
 import torch
-from torchdata.datapipes.iter import FileLister
+from torch.utils.data.datapipes.iter import FileLister
 
 from pvnet.data.datamodule import BatchSplitter
 
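For orientation, the pattern the datamodule adopts above, wrapping an already-batched `IterDataPipe` in the stock `torch.utils.data.DataLoader`, looks roughly like the sketch below. This is an illustrative toy rather than PVNet code: the `IterableWrapper` pipe and the worker counts are stand-ins.

```python
# Minimal sketch of driving an IterDataPipe with the stock DataLoader,
# as the datamodule above now does. Toy pipe; not PVNet's real pipeline.
from torch.utils.data import DataLoader
from torch.utils.data.datapipes.iter import IterableWrapper

# Stand-in for a datapipe that already yields fully collated batches.
batch_pipe = IterableWrapper(range(8)).sharding_filter()

dataloader = DataLoader(
    batch_pipe,
    shuffle=False,      # shuffling already happens inside the datapipe
    batch_size=None,    # batching already happens inside the datapipe
    num_workers=2,
    prefetch_factor=2,  # stands in for DataLoader2's worker_prefetch_cnt
)

for batch in dataloader:
    print(batch)
```

With `batch_size=None` the `DataLoader` skips collation entirely and yields the datapipe's elements as-is; the `sharding_filter()` is what stops the two workers from emitting duplicate elements.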
diff --git a/scripts/save_batches.py b/scripts/save_batches.py
index 0a145502..a3526a76 100644
--- a/scripts/save_batches.py
+++ b/scripts/save_batches.py
@@ -7,9 +7,11 @@
 use:
 ```
 python save_batches.py \
-    +batch_output_dir="/mnt/disks/batches/batches_v0" \
-    +num_train_batches=10_000 \
-    +num_val_batches=2_000
+    +batch_output_dir="/mnt/disks/bigbatches/batches_v0" \
+    datamodule.batch_size=2 \
+    datamodule.num_workers=2 \
+    +num_train_batches=0 \
+    +num_val_batches=2
 ```
 """
 
@@ -27,11 +29,12 @@
 from ocf_datapipes.utils.utils import stack_np_examples_into_batch
 from omegaconf import DictConfig, OmegaConf
 from sqlalchemy import exc as sa_exc
-from torchdata.dataloader2 import DataLoader2, MultiProcessingReadingService
-from torchdata.datapipes.iter import IterableWrapper
+from torch.utils.data import DataLoader
+from torch.utils.data.datapipes.iter import IterableWrapper
 from tqdm import tqdm
 
 from pvnet.data.datamodule import batch_to_tensor
+from pvnet.utils import print_config
 
 warnings.filterwarnings("ignore", category=sa_exc.SAWarning)
 
@@ -62,13 +65,12 @@ def _get_datapipe(config_path, start_time, end_time, batch_size):
     return data_pipeline
 
 
-def _save_batches_with_dataloader(batch_pipe, batch_dir, num_batches, rs_config):
+def _save_batches_with_dataloader(batch_pipe, batch_dir, num_batches, dataloader_kwargs):
     save_func = _save_batch_func_factory(batch_dir)
     filenumber_pipe = IterableWrapper(range(num_batches)).sharding_filter()
     save_pipe = filenumber_pipe.zip(batch_pipe).map(save_func)
 
-    rs = MultiProcessingReadingService(**rs_config)
-    dataloader = DataLoader2(save_pipe, reading_service=rs)
+    dataloader = DataLoader(save_pipe, **dataloader_kwargs)
 
     pbar = tqdm(total=num_batches)
     for i, batch in zip(range(num_batches), dataloader):
@@ -79,8 +81,10 @@
 @hydra.main(config_path="../configs/", config_name="config.yaml", version_base="1.2")
 def main(config: DictConfig):
     """Constructs and saves validation and training batches."""
     config_dm = config.datamodule
+
+    print_config(config, resolve=False)
 
     # Set up directory
     os.makedirs(config.batch_output_dir, exist_ok=False)
 
@@ -93,44 +97,56 @@
     os.mkdir(f"{config.batch_output_dir}/train")
     os.mkdir(f"{config.batch_output_dir}/val")
 
-    readingservice_config = dict(
-        num_workers=config_dm.num_workers,
-        multiprocessing_context="spawn",
-        worker_prefetch_cnt=config_dm.prefetch_factor,
-    )
-
-    print("----- Saving val batches -----")
-
-    val_batch_pipe = _get_datapipe(
-        config_dm.configuration,
-        *config_dm.val_period,
-        config_dm.batch_size,
-    )
-
-    _save_batches_with_dataloader(
-        batch_pipe=val_batch_pipe,
-        batch_dir=f"{config.batch_output_dir}/val",
-        num_batches=config.num_val_batches,
-        rs_config=readingservice_config,
-    )
-
-    print("----- Saving train batches -----")
-
-    train_batch_pipe = _get_datapipe(
-        config_dm.configuration,
-        *config_dm.train_period,
-        config_dm.batch_size,
-    )
-
-    _save_batches_with_dataloader(
-        batch_pipe=train_batch_pipe,
-        batch_dir=f"{config.batch_output_dir}/train",
-        num_batches=config.num_train_batches,
-        rs_config=readingservice_config,
+    dataloader_kwargs = dict(
+        shuffle=False,
+        batch_size=None,  # batched in datapipe step
+        sampler=None,
+        batch_sampler=None,
+        num_workers=config_dm.num_workers,
+        collate_fn=None,
+        pin_memory=False,
+        drop_last=False,
+        timeout=0,
+        worker_init_fn=None,
+        prefetch_factor=config_dm.prefetch_factor,
+        persistent_workers=False
     )
+
+    if config.num_val_batches > 0:
+        print("----- Saving val batches -----")
+
+        val_batch_pipe = _get_datapipe(
+            config_dm.configuration,
+            *config_dm.val_period,
+            config_dm.batch_size,
+        )
+
+        _save_batches_with_dataloader(
+            batch_pipe=val_batch_pipe,
+            batch_dir=f"{config.batch_output_dir}/val",
+            num_batches=config.num_val_batches,
+            dataloader_kwargs=dataloader_kwargs,
+        )
+
+    if config.num_train_batches > 0:
+
+        print("----- Saving train batches -----")
+
+        train_batch_pipe = _get_datapipe(
+            config_dm.configuration,
+            *config_dm.train_period,
+            config_dm.batch_size,
+        )
+
+        _save_batches_with_dataloader(
+            batch_pipe=train_batch_pipe,
+            batch_dir=f"{config.batch_output_dir}/train",
+            num_batches=config.num_train_batches,
+            dataloader_kwargs=dataloader_kwargs,
+        )
 
     print("done")
 
 
 if __name__ == "__main__":
-    main()
+    main()
\ No newline at end of file
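Both batch-saving scripts drive the same zip-and-map pattern: a pipe of file numbers is zipped with the batch pipe, and a save function is mapped over the pairs, so saving happens as a side effect of iteration. Below is a self-contained sketch of that pattern; the `/tmp` directory and the dummy dict batches are hypothetical placeholders, not PVNet defaults.

```python
# Sketch of the zip-and-map save pattern from _save_batches_with_dataloader.
# Output path and dummy batches are hypothetical placeholders.
import os

import torch
from torch.utils.data.datapipes.iter import IterableWrapper

batch_dir = "/tmp/example_batches"
os.makedirs(batch_dir, exist_ok=True)

def save_func(pair):
    i, batch = pair  # each element is a (file number, batch) tuple
    torch.save(batch, os.path.join(batch_dir, f"{i:06}.pt"))

num_batches = 4
filenumber_pipe = IterableWrapper(range(num_batches)).sharding_filter()
batch_pipe = IterableWrapper([{"x": torch.zeros(2)} for _ in range(num_batches)])
save_pipe = filenumber_pipe.zip(batch_pipe).map(save_func)

# Iterating the pipe (directly or via a DataLoader) triggers the saves.
for _ in save_pipe:
    pass
```

Because the file numbers come from their own `sharding_filter()`-ed pipe, each worker writes a disjoint set of files when this runs under a multi-worker `DataLoader` configured as above.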
diff --git a/scripts/save_concurrent_batches.py b/scripts/save_concurrent_batches.py
index 8af2784e..7fe22a38 100644
--- a/scripts/save_concurrent_batches.py
+++ b/scripts/save_concurrent_batches.py
@@ -7,9 +7,9 @@
 use:
 ```
 python save_concurrent_batches.py \
-    +batch_output_dir="/mnt/disks/batches/concurrent_batches_v0" \
-    +num_train_batches=1_000 \
-    +num_val_batches=200
+    +batch_output_dir="/mnt/disks/nwp_rechunk/concurrent_batches_v3.9" \
+    +num_train_batches=20_000 \
+    +num_val_batches=4_000
 ```
 """
 
@@ -31,8 +31,8 @@
 from ocf_datapipes.utils.utils import stack_np_examples_into_batch
 from omegaconf import DictConfig, OmegaConf
 from sqlalchemy import exc as sa_exc
-from torchdata.dataloader2 import DataLoader2, MultiProcessingReadingService
-from torchdata.datapipes.iter import IterableWrapper
+from torch.utils.data import DataLoader
+from torch.utils.data.datapipes.iter import IterableWrapper
 from tqdm import tqdm
 
 from pvnet.data.datamodule import batch_to_tensor
@@ -123,13 +123,12 @@ def _get_datapipe(config_path, start_time, end_time, n_batches):
     return data_pipeline
 
 
-def _save_batches_with_dataloader(batch_pipe, batch_dir, num_batches, rs_config):
+def _save_batches_with_dataloader(batch_pipe, batch_dir, num_batches, dataloader_kwargs):
     save_func = _save_batch_func_factory(batch_dir)
     filenumber_pipe = IterableWrapper(np.arange(num_batches)).sharding_filter()
     save_pipe = filenumber_pipe.zip(batch_pipe).map(save_func)
 
-    rs = MultiProcessingReadingService(**rs_config)
-    dataloader = DataLoader2(save_pipe, reading_service=rs)
+    dataloader = DataLoader(save_pipe, **dataloader_kwargs)
 
     pbar = tqdm(total=num_batches)
     for i, batch in zip(range(num_batches), dataloader):
@@ -163,10 +162,20 @@ def main(config: DictConfig):
     os.mkdir(f"{config.batch_output_dir}/train")
     os.mkdir(f"{config.batch_output_dir}/val")
 
-    readingservice_config = dict(
-        num_workers=config_dm.num_workers,
-        multiprocessing_context="spawn",
-        worker_prefetch_cnt=config_dm.prefetch_factor,
+
+    dataloader_kwargs = dict(
+        shuffle=False,
+        batch_size=None,  # batched in datapipe step
+        sampler=None,
+        batch_sampler=None,
+        num_workers=config_dm.num_workers,
+        collate_fn=None,
+        pin_memory=False,
+        drop_last=False,
+        timeout=0,
+        worker_init_fn=None,
+        prefetch_factor=config_dm.prefetch_factor,
+        persistent_workers=False
     )
 
     print("----- Saving val batches -----")
@@ -181,7 +190,7 @@ def main(config: DictConfig):
         batch_pipe=val_batch_pipe,
         batch_dir=f"{config.batch_output_dir}/val",
         num_batches=config.num_val_batches,
-        rs_config=readingservice_config,
+        dataloader_kwargs=dataloader_kwargs,
     )
 
     print("----- Saving train batches -----")
@@ -196,7 +205,7 @@ def main(config: DictConfig):
         batch_pipe=train_batch_pipe,
         batch_dir=f"{config.batch_output_dir}/train",
         num_batches=config.num_train_batches,
-        rs_config=readingservice_config,
+        dataloader_kwargs=dataloader_kwargs,
     )
 
     print("done")
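Reading the saved batches back also no longer needs torchdata: the `FileLister` bundled with torch (the one scripts/load_batches.py now imports above) can list the `.pt` files for a plain map over `torch.load`. A hedged sketch, reusing the placeholder directory from the earlier example:

```python
# Sketch of loading saved batches back with torch's built-in FileLister.
import torch
from torch.utils.data.datapipes.iter import FileLister

file_pipe = FileLister("/tmp/example_batches", masks="*.pt", recursive=False)
batch_pipe = file_pipe.map(torch.load)

for batch in batch_pipe:
    print({key: tensor.shape for key, tensor in batch.items()})
```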