From 6c6a54cc8d85aece73de60fc742230d256116cac Mon Sep 17 00:00:00 2001
From: James Fulton <41546094+dfulu@users.noreply.github.com>
Date: Wed, 22 Nov 2023 10:54:30 +0000
Subject: [PATCH] Remove torchdata (#95)

* remove torchdata

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* import fixes

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* import fixes

* fix prefetch factor

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* test bug

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
---
 pvnet/data/datamodule.py           | 33 +++++++----
 requirements.txt                   |  3 +-
 scripts/load_batches.py            |  2 +-
 scripts/save_batches.py            | 95 +++++++++++++++++-------------
 scripts/save_concurrent_batches.py | 34 +++++++----
 tests/conftest.py                  |  2 +-
 tests/data/test_datamodule.py      |  2 +-
 7 files changed, 100 insertions(+), 71 deletions(-)

diff --git a/pvnet/data/datamodule.py b/pvnet/data/datamodule.py
index 78783975..39abae5a 100644
--- a/pvnet/data/datamodule.py
+++ b/pvnet/data/datamodule.py
@@ -7,9 +7,10 @@
 from ocf_datapipes.training.pvnet import pvnet_datapipe
 from ocf_datapipes.utils.consts import BatchKey
 from ocf_datapipes.utils.utils import stack_np_examples_into_batch
-from torchdata.dataloader2 import DataLoader2, MultiProcessingReadingService
-from torchdata.datapipes import functional_datapipe
-from torchdata.datapipes.iter import FileLister, IterDataPipe
+from torch.utils.data import DataLoader
+from torch.utils.data.datapipes._decorator import functional_datapipe
+from torch.utils.data.datapipes.datapipe import IterDataPipe
+from torch.utils.data.datapipes.iter import FileLister
 
 
 def copy_batch_to_device(batch, device):
@@ -69,7 +70,7 @@ def __init__(
         configuration=None,
         batch_size=16,
         num_workers=0,
-        prefetch_factor=2,
+        prefetch_factor=None,
         train_period=[None, None],
         val_period=[None, None],
         test_period=[None, None],
@@ -118,10 +119,19 @@ def __init__(
             None if d is None else datetime.strptime(d, "%Y-%m-%d") for d in test_period
         ]
 
-        self.readingservice_config = dict(
+        self._common_dataloader_kwargs = dict(
+            shuffle=False,  # shuffled in datapipe step
+            batch_size=None,  # batched in datapipe step
+            sampler=None,
+            batch_sampler=None,
             num_workers=num_workers,
-            multiprocessing_context="spawn",
-            worker_prefetch_cnt=prefetch_factor,
+            collate_fn=None,
+            pin_memory=False,
+            drop_last=False,
+            timeout=0,
+            worker_init_fn=None,
+            prefetch_factor=prefetch_factor,
+            persistent_workers=False,
         )
 
     def _get_datapipe(self, start_time, end_time):
@@ -172,8 +182,7 @@ def train_dataloader(self):
             datapipe = self._get_premade_batches_datapipe("train", shuffle=True)
         else:
             datapipe = self._get_datapipe(*self.train_period)
-        rs = MultiProcessingReadingService(**self.readingservice_config)
-        return DataLoader2(datapipe, reading_service=rs)
+        return DataLoader(datapipe, **self._common_dataloader_kwargs)
 
     def val_dataloader(self):
         """Construct val dataloader"""
         if self.batch_dir is not None:
             datapipe = self._get_premade_batches_datapipe("val")
         else:
             datapipe = self._get_datapipe(*self.val_period)
-        rs = MultiProcessingReadingService(**self.readingservice_config)
-        return DataLoader2(datapipe, reading_service=rs)
+        return DataLoader(datapipe, **self._common_dataloader_kwargs)
 
     def test_dataloader(self):
         """Construct test dataloader"""
         if self.batch_dir is not None:
             datapipe = self._get_premade_batches_datapipe("test")
         else:
             datapipe = self._get_datapipe(*self.test_period)
-        rs = MultiProcessingReadingService(**self.readingservice_config)
-        return DataLoader2(datapipe, reading_service=rs)
+        return DataLoader(datapipe, **self._common_dataloader_kwargs)
diff --git a/requirements.txt b/requirements.txt
index 61148f13..e3bcea24 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,4 +1,4 @@
-ocf_datapipes>=2.0.6
+ocf_datapipes>=2.2.2
 nowcasting_utils
 ocf_ml_metrics
 numpy
@@ -9,7 +9,6 @@ ipykernel
 h5netcdf
 torch>=2.0
 lightning>=2.0.1
-torchdata
 torchvision
 pytest
 pytest-cov
diff --git a/scripts/load_batches.py b/scripts/load_batches.py
index 261371fe..718f6c8f 100644
--- a/scripts/load_batches.py
+++ b/scripts/load_batches.py
@@ -2,7 +2,7 @@
 """
 
 import torch
-from torchdata.datapipes.iter import FileLister
+from torch.utils.data.datapipes.iter import FileLister
 
 from pvnet.data.datamodule import BatchSplitter
diff --git a/scripts/save_batches.py b/scripts/save_batches.py
index 0a145502..950ad23e 100644
--- a/scripts/save_batches.py
+++ b/scripts/save_batches.py
@@ -7,9 +7,11 @@
 use:
 ```
 python save_batches.py \
-        +batch_output_dir="/mnt/disks/batches/batches_v0" \
-        +num_train_batches=10_000 \
-        +num_val_batches=2_000
+        +batch_output_dir="/mnt/disks/bigbatches/batches_v0" \
+        datamodule.batch_size=2 \
+        datamodule.num_workers=2 \
+        +num_train_batches=0 \
+        +num_val_batches=2
 ```
 """
@@ -27,11 +29,12 @@
 from ocf_datapipes.utils.utils import stack_np_examples_into_batch
 from omegaconf import DictConfig, OmegaConf
 from sqlalchemy import exc as sa_exc
-from torchdata.dataloader2 import DataLoader2, MultiProcessingReadingService
-from torchdata.datapipes.iter import IterableWrapper
+from torch.utils.data import DataLoader
+from torch.utils.data.datapipes.iter import IterableWrapper
 from tqdm import tqdm
 
 from pvnet.data.datamodule import batch_to_tensor
+from pvnet.utils import print_config
 
 warnings.filterwarnings("ignore", category=sa_exc.SAWarning)
@@ -62,13 +65,12 @@ def _get_datapipe(config_path, start_time, end_time, batch_size):
     return data_pipeline
 
 
-def _save_batches_with_dataloader(batch_pipe, batch_dir, num_batches, rs_config):
+def _save_batches_with_dataloader(batch_pipe, batch_dir, num_batches, dataloader_kwargs):
     save_func = _save_batch_func_factory(batch_dir)
     filenumber_pipe = IterableWrapper(range(num_batches)).sharding_filter()
     save_pipe = filenumber_pipe.zip(batch_pipe).map(save_func)
 
-    rs = MultiProcessingReadingService(**rs_config)
-    dataloader = DataLoader2(save_pipe, reading_service=rs)
+    dataloader = DataLoader(save_pipe, **dataloader_kwargs)
 
     pbar = tqdm(total=num_batches)
     for i, batch in zip(range(num_batches), dataloader):
@@ -82,6 +84,8 @@ def main(config: DictConfig):
     """Constructs and saves validation and training batches."""
     config_dm = config.datamodule
 
+    print_config(config, resolve=False)
+
     # Set up directory
     os.makedirs(config.batch_output_dir, exist_ok=False)
@@ -93,41 +97,52 @@ def main(config: DictConfig):
     os.mkdir(f"{config.batch_output_dir}/train")
     os.mkdir(f"{config.batch_output_dir}/val")
 
-    readingservice_config = dict(
+    dataloader_kwargs = dict(
+        shuffle=False,
+        batch_size=None,  # batched in datapipe step
+        sampler=None,
+        batch_sampler=None,
         num_workers=config_dm.num_workers,
-        multiprocessing_context="spawn",
-        worker_prefetch_cnt=config_dm.prefetch_factor,
-    )
-
-    print("----- Saving val batches -----")
-
-    val_batch_pipe = _get_datapipe(
-        config_dm.configuration,
-        *config_dm.val_period,
-        config_dm.batch_size,
-    )
-
-    _save_batches_with_dataloader(
-        batch_pipe=val_batch_pipe,
-        batch_dir=f"{config.batch_output_dir}/val",
-        num_batches=config.num_val_batches,
-        rs_config=readingservice_config,
+        collate_fn=None,
+        pin_memory=False,
+        drop_last=False,
+        timeout=0,
+        worker_init_fn=None,
+        prefetch_factor=config_dm.prefetch_factor,
+        persistent_workers=False,
     )
 
-    print("----- Saving train batches -----")
-
-    train_batch_pipe = _get_datapipe(
-        config_dm.configuration,
-        *config_dm.train_period,
-        config_dm.batch_size,
-    )
-
-    _save_batches_with_dataloader(
-        batch_pipe=train_batch_pipe,
-        batch_dir=f"{config.batch_output_dir}/train",
-        num_batches=config.num_train_batches,
-        rs_config=readingservice_config,
-    )
+    if config.num_val_batches > 0:
+        print("----- Saving val batches -----")
+
+        val_batch_pipe = _get_datapipe(
+            config_dm.configuration,
+            *config_dm.val_period,
+            config_dm.batch_size,
+        )
+
+        _save_batches_with_dataloader(
+            batch_pipe=val_batch_pipe,
+            batch_dir=f"{config.batch_output_dir}/val",
+            num_batches=config.num_val_batches,
+            dataloader_kwargs=dataloader_kwargs,
+        )
+
+    if config.num_train_batches > 0:
+        print("----- Saving train batches -----")
+
+        train_batch_pipe = _get_datapipe(
+            config_dm.configuration,
+            *config_dm.train_period,
+            config_dm.batch_size,
+        )
+
+        _save_batches_with_dataloader(
+            batch_pipe=train_batch_pipe,
+            batch_dir=f"{config.batch_output_dir}/train",
+            num_batches=config.num_train_batches,
+            dataloader_kwargs=dataloader_kwargs,
+        )
 
     print("done")
diff --git a/scripts/save_concurrent_batches.py b/scripts/save_concurrent_batches.py
index 8af2784e..cd1d0937 100644
--- a/scripts/save_concurrent_batches.py
+++ b/scripts/save_concurrent_batches.py
@@ -7,9 +7,9 @@
 use:
 ```
 python save_concurrent_batches.py \
-        +batch_output_dir="/mnt/disks/batches/concurrent_batches_v0" \
-        +num_train_batches=1_000 \
-        +num_val_batches=200
+        +batch_output_dir="/mnt/disks/nwp_rechunk/concurrent_batches_v3.9" \
+        +num_train_batches=20_000 \
+        +num_val_batches=4_000
 ```
 """
@@ -31,8 +31,8 @@
 from ocf_datapipes.utils.utils import stack_np_examples_into_batch
 from omegaconf import DictConfig, OmegaConf
 from sqlalchemy import exc as sa_exc
-from torchdata.dataloader2 import DataLoader2, MultiProcessingReadingService
-from torchdata.datapipes.iter import IterableWrapper
+from torch.utils.data import DataLoader
+from torch.utils.data.datapipes.iter import IterableWrapper
 from tqdm import tqdm
 
 from pvnet.data.datamodule import batch_to_tensor
@@ -123,13 +123,12 @@ def _get_datapipe(config_path, start_time, end_time, n_batches):
     return data_pipeline
 
 
-def _save_batches_with_dataloader(batch_pipe, batch_dir, num_batches, rs_config):
+def _save_batches_with_dataloader(batch_pipe, batch_dir, num_batches, dataloader_kwargs):
     save_func = _save_batch_func_factory(batch_dir)
     filenumber_pipe = IterableWrapper(np.arange(num_batches)).sharding_filter()
     save_pipe = filenumber_pipe.zip(batch_pipe).map(save_func)
 
-    rs = MultiProcessingReadingService(**rs_config)
-    dataloader = DataLoader2(save_pipe, reading_service=rs)
+    dataloader = DataLoader(save_pipe, **dataloader_kwargs)
 
     pbar = tqdm(total=num_batches)
     for i, batch in zip(range(num_batches), dataloader):
@@ -163,10 +162,19 @@ def main(config: DictConfig):
     os.mkdir(f"{config.batch_output_dir}/train")
     os.mkdir(f"{config.batch_output_dir}/val")
 
-    readingservice_config = dict(
+    dataloader_kwargs = dict(
+        shuffle=False,
+        batch_size=None,  # batched in datapipe step
+        sampler=None,
+        batch_sampler=None,
         num_workers=config_dm.num_workers,
-        multiprocessing_context="spawn",
-        worker_prefetch_cnt=config_dm.prefetch_factor,
+        collate_fn=None,
+        pin_memory=False,
+        drop_last=False,
+        timeout=0,
+        worker_init_fn=None,
+        prefetch_factor=config_dm.prefetch_factor,
+        persistent_workers=False,
     )
 
     print("----- Saving val batches -----")
@@ -181,7 +189,7 @@ def main(config: DictConfig):
         batch_pipe=val_batch_pipe,
         batch_dir=f"{config.batch_output_dir}/val",
         num_batches=config.num_val_batches,
-        rs_config=readingservice_config,
+        dataloader_kwargs=dataloader_kwargs,
     )
 
     print("----- Saving train batches -----")
@@ -196,7 +204,7 @@ def main(config: DictConfig):
         batch_pipe=train_batch_pipe,
         batch_dir=f"{config.batch_output_dir}/train",
         num_batches=config.num_train_batches,
-        rs_config=readingservice_config,
+        dataloader_kwargs=dataloader_kwargs,
     )
 
     print("done")
diff --git a/tests/conftest.py b/tests/conftest.py
index 7f7aa142..7e7eae52 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -96,7 +96,7 @@ def sample_datamodule():
         configuration=None,
         batch_size=2,
         num_workers=0,
-        prefetch_factor=2,
+        prefetch_factor=None,
         train_period=[None, None],
         val_period=[None, None],
         test_period=[None, None],
diff --git a/tests/data/test_datamodule.py b/tests/data/test_datamodule.py
index 1c283e2f..66ebeef9 100644
--- a/tests/data/test_datamodule.py
+++ b/tests/data/test_datamodule.py
@@ -6,7 +6,7 @@ def test_init():
         configuration=None,
         batch_size=2,
         num_workers=0,
-        prefetch_factor=2,
+        prefetch_factor=None,
         train_period=[None, None],
         val_period=[None, None],
         test_period=[None, None],
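
Note on the pattern applied throughout this patch: each `DataLoader2` + `MultiProcessingReadingService` pair is replaced by the built-in `torch.utils.data.DataLoader` driving the `IterDataPipe` directly, with `shuffle=False` and `batch_size=None` because shuffling and batching already happen inside the datapipes. The `prefetch_factor` default also changes from `2` to `None` because, in torch>=2.0, `DataLoader` only accepts a non-`None` `prefetch_factor` when `num_workers > 0`. A minimal sketch of the new usage, with a hypothetical toy pipe standing in for the real pvnet datapipes:

```python
from torch.utils.data import DataLoader
from torch.utils.data.datapipes.iter import IterableWrapper

# Hypothetical stand-in for the pvnet datapipes; the real pipes
# shuffle and batch internally, so the DataLoader must not.
datapipe = IterableWrapper(range(8)).map(lambda x: x * 2)

# Mirrors the kwargs set in this patch: batching/shuffling are left to the
# datapipe, and prefetch_factor must stay None while num_workers == 0.
dataloader = DataLoader(
    datapipe,
    shuffle=False,
    batch_size=None,
    num_workers=0,
    prefetch_factor=None,
)

for item in dataloader:
    print(item)  # 0, 2, 4, ..., 14
```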