Skip to content

Commit

Permalink
Remove torchdata (#95)
Browse files Browse the repository at this point in the history
* remove torchdata

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* import fixes

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* import fixes

* fix prefetch factor

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix test bug

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
dfulu and pre-commit-ci[bot] authored Nov 22, 2023
1 parent 2dd1cc9 commit 6c6a54c
Show file tree
Hide file tree
Showing 7 changed files with 100 additions and 71 deletions.
33 changes: 20 additions & 13 deletions pvnet/data/datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,10 @@
from ocf_datapipes.training.pvnet import pvnet_datapipe
from ocf_datapipes.utils.consts import BatchKey
from ocf_datapipes.utils.utils import stack_np_examples_into_batch
from torchdata.dataloader2 import DataLoader2, MultiProcessingReadingService
from torchdata.datapipes import functional_datapipe
from torchdata.datapipes.iter import FileLister, IterDataPipe
from torch.utils.data import DataLoader
from torch.utils.data.datapipes._decorator import functional_datapipe
from torch.utils.data.datapipes.datapipe import IterDataPipe
from torch.utils.data.datapipes.iter import FileLister


def copy_batch_to_device(batch, device):
Expand Down Expand Up @@ -69,7 +70,7 @@ def __init__(
configuration=None,
batch_size=16,
num_workers=0,
prefetch_factor=2,
prefetch_factor=None,
train_period=[None, None],
val_period=[None, None],
test_period=[None, None],
Expand Down Expand Up @@ -118,10 +119,19 @@ def __init__(
None if d is None else datetime.strptime(d, "%Y-%m-%d") for d in test_period
]

self.readingservice_config = dict(
self._common_dataloader_kwargs = dict(
shuffle=False, # shuffled in datapipe step
batch_size=None, # batched in datapipe step
sampler=None,
batch_sampler=None,
num_workers=num_workers,
multiprocessing_context="spawn",
worker_prefetch_cnt=prefetch_factor,
collate_fn=None,
pin_memory=False,
drop_last=False,
timeout=0,
worker_init_fn=None,
prefetch_factor=prefetch_factor,
persistent_workers=False,
)

def _get_datapipe(self, start_time, end_time):
Expand Down Expand Up @@ -172,23 +182,20 @@ def train_dataloader(self):
datapipe = self._get_premade_batches_datapipe("train", shuffle=True)
else:
datapipe = self._get_datapipe(*self.train_period)
rs = MultiProcessingReadingService(**self.readingservice_config)
return DataLoader2(datapipe, reading_service=rs)
return DataLoader(datapipe, **self._common_dataloader_kwargs)

def val_dataloader(self):
"""Construct val dataloader"""
if self.batch_dir is not None:
datapipe = self._get_premade_batches_datapipe("val")
else:
datapipe = self._get_datapipe(*self.val_period)
rs = MultiProcessingReadingService(**self.readingservice_config)
return DataLoader2(datapipe, reading_service=rs)
return DataLoader(datapipe, **self._common_dataloader_kwargs)

def test_dataloader(self):
"""Construct test dataloader"""
if self.batch_dir is not None:
datapipe = self._get_premade_batches_datapipe("test")
else:
datapipe = self._get_datapipe(*self.test_period)
rs = MultiProcessingReadingService(**self.readingservice_config)
return DataLoader2(datapipe, reading_service=rs)
return DataLoader(datapipe, **self._common_dataloader_kwargs)
3 changes: 1 addition & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
ocf_datapipes>=2.0.6
ocf_datapipes>=2.2.2
nowcasting_utils
ocf_ml_metrics
numpy
Expand All @@ -9,7 +9,6 @@ ipykernel
h5netcdf
torch>=2.0
lightning>=2.0.1
torchdata
torchvision
pytest
pytest-cov
Expand Down
2 changes: 1 addition & 1 deletion scripts/load_batches.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
"""

import torch
from torchdata.datapipes.iter import FileLister
from torch.utils.data.datapipes.iter import FileLister

from pvnet.data.datamodule import BatchSplitter

Expand Down
95 changes: 55 additions & 40 deletions scripts/save_batches.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,11 @@
use:
```
python save_batches.py \
+batch_output_dir="/mnt/disks/batches/batches_v0" \
+num_train_batches=10_000 \
+num_val_batches=2_000
+batch_output_dir="/mnt/disks/bigbatches/batches_v0" \
datamodule.batch_size=2 \
datamodule.num_workers=2 \
+num_train_batches=0 \
+num_val_batches=2
```
"""
Expand All @@ -27,11 +29,12 @@
from ocf_datapipes.utils.utils import stack_np_examples_into_batch
from omegaconf import DictConfig, OmegaConf
from sqlalchemy import exc as sa_exc
from torchdata.dataloader2 import DataLoader2, MultiProcessingReadingService
from torchdata.datapipes.iter import IterableWrapper
from torch.utils.data import DataLoader
from torch.utils.data.datapipes.iter import IterableWrapper
from tqdm import tqdm

from pvnet.data.datamodule import batch_to_tensor
from pvnet.utils import print_config

warnings.filterwarnings("ignore", category=sa_exc.SAWarning)

Expand Down Expand Up @@ -62,13 +65,12 @@ def _get_datapipe(config_path, start_time, end_time, batch_size):
return data_pipeline


def _save_batches_with_dataloader(batch_pipe, batch_dir, num_batches, rs_config):
def _save_batches_with_dataloader(batch_pipe, batch_dir, num_batches, dataloader_kwargs):
save_func = _save_batch_func_factory(batch_dir)
filenumber_pipe = IterableWrapper(range(num_batches)).sharding_filter()
save_pipe = filenumber_pipe.zip(batch_pipe).map(save_func)

rs = MultiProcessingReadingService(**rs_config)
dataloader = DataLoader2(save_pipe, reading_service=rs)
dataloader = DataLoader(save_pipe, **dataloader_kwargs)

pbar = tqdm(total=num_batches)
for i, batch in zip(range(num_batches), dataloader):
Expand All @@ -82,6 +84,8 @@ def main(config: DictConfig):
"""Constructs and saves validation and training batches."""
config_dm = config.datamodule

print_config(config, resolve=False)

# Set up directory
os.makedirs(config.batch_output_dir, exist_ok=False)

Expand All @@ -93,41 +97,52 @@ def main(config: DictConfig):
os.mkdir(f"{config.batch_output_dir}/train")
os.mkdir(f"{config.batch_output_dir}/val")

readingservice_config = dict(
dataloader_kwargs = dict(
shuffle=False,
batch_size=None, # batched in datapipe step
sampler=None,
batch_sampler=None,
num_workers=config_dm.num_workers,
multiprocessing_context="spawn",
worker_prefetch_cnt=config_dm.prefetch_factor,
)

print("----- Saving val batches -----")

val_batch_pipe = _get_datapipe(
config_dm.configuration,
*config_dm.val_period,
config_dm.batch_size,
)

_save_batches_with_dataloader(
batch_pipe=val_batch_pipe,
batch_dir=f"{config.batch_output_dir}/val",
num_batches=config.num_val_batches,
rs_config=readingservice_config,
collate_fn=None,
pin_memory=False,
drop_last=False,
timeout=0,
worker_init_fn=None,
prefetch_factor=config_dm.prefetch_factor,
persistent_workers=False,
)

print("----- Saving train batches -----")

train_batch_pipe = _get_datapipe(
config_dm.configuration,
*config_dm.train_period,
config_dm.batch_size,
)

_save_batches_with_dataloader(
batch_pipe=train_batch_pipe,
batch_dir=f"{config.batch_output_dir}/train",
num_batches=config.num_train_batches,
rs_config=readingservice_config,
)
if config.num_val_batches > 0:
print("----- Saving val batches -----")

val_batch_pipe = _get_datapipe(
config_dm.configuration,
*config_dm.val_period,
config_dm.batch_size,
)

_save_batches_with_dataloader(
batch_pipe=val_batch_pipe,
batch_dir=f"{config.batch_output_dir}/val",
num_batches=config.num_val_batches,
dataloader_kwargs=dataloader_kwargs,
)

if config.num_train_batches > 0:
print("----- Saving train batches -----")

train_batch_pipe = _get_datapipe(
config_dm.configuration,
*config_dm.train_period,
config_dm.batch_size,
)

_save_batches_with_dataloader(
batch_pipe=train_batch_pipe,
batch_dir=f"{config.batch_output_dir}/train",
num_batches=config.num_train_batches,
dataloader_kwargs=dataloader_kwargs,
)

print("done")

Expand Down
34 changes: 21 additions & 13 deletions scripts/save_concurrent_batches.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@
use:
```
python save_concurrent_batches.py \
+batch_output_dir="/mnt/disks/batches/concurrent_batches_v0" \
+num_train_batches=1_000 \
+num_val_batches=200
+batch_output_dir="/mnt/disks/nwp_rechunk/concurrent_batches_v3.9" \
+num_train_batches=20_000 \
+num_val_batches=4_000
```
"""
Expand All @@ -31,8 +31,8 @@
from ocf_datapipes.utils.utils import stack_np_examples_into_batch
from omegaconf import DictConfig, OmegaConf
from sqlalchemy import exc as sa_exc
from torchdata.dataloader2 import DataLoader2, MultiProcessingReadingService
from torchdata.datapipes.iter import IterableWrapper
from torch.utils.data import DataLoader
from torch.utils.data.datapipes.iter import IterableWrapper
from tqdm import tqdm

from pvnet.data.datamodule import batch_to_tensor
Expand Down Expand Up @@ -123,13 +123,12 @@ def _get_datapipe(config_path, start_time, end_time, n_batches):
return data_pipeline


def _save_batches_with_dataloader(batch_pipe, batch_dir, num_batches, rs_config):
def _save_batches_with_dataloader(batch_pipe, batch_dir, num_batches, dataloader_kwargs):
save_func = _save_batch_func_factory(batch_dir)
filenumber_pipe = IterableWrapper(np.arange(num_batches)).sharding_filter()
save_pipe = filenumber_pipe.zip(batch_pipe).map(save_func)

rs = MultiProcessingReadingService(**rs_config)
dataloader = DataLoader2(save_pipe, reading_service=rs)
dataloader = DataLoader(save_pipe, **dataloader_kwargs)

pbar = tqdm(total=num_batches)
for i, batch in zip(range(num_batches), dataloader):
Expand Down Expand Up @@ -163,10 +162,19 @@ def main(config: DictConfig):
os.mkdir(f"{config.batch_output_dir}/train")
os.mkdir(f"{config.batch_output_dir}/val")

readingservice_config = dict(
dataloader_kwargs = dict(
shuffle=False,
batch_size=None, # batched in datapipe step
sampler=None,
batch_sampler=None,
num_workers=config_dm.num_workers,
multiprocessing_context="spawn",
worker_prefetch_cnt=config_dm.prefetch_factor,
collate_fn=None,
pin_memory=False,
drop_last=False,
timeout=0,
worker_init_fn=None,
prefetch_factor=config_dm.prefetch_factor,
persistent_workers=False,
)

print("----- Saving val batches -----")
Expand All @@ -181,7 +189,7 @@ def main(config: DictConfig):
batch_pipe=val_batch_pipe,
batch_dir=f"{config.batch_output_dir}/val",
num_batches=config.num_val_batches,
rs_config=readingservice_config,
dataloader_kwargs=dataloader_kwargs,
)

print("----- Saving train batches -----")
Expand All @@ -196,7 +204,7 @@ def main(config: DictConfig):
batch_pipe=train_batch_pipe,
batch_dir=f"{config.batch_output_dir}/train",
num_batches=config.num_train_batches,
rs_config=readingservice_config,
dataloader_kwargs=dataloader_kwargs,
)

print("done")
Expand Down
2 changes: 1 addition & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def sample_datamodule():
configuration=None,
batch_size=2,
num_workers=0,
prefetch_factor=2,
prefetch_factor=None,
train_period=[None, None],
val_period=[None, None],
test_period=[None, None],
Expand Down
2 changes: 1 addition & 1 deletion tests/data/test_datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ def test_init():
configuration=None,
batch_size=2,
num_workers=0,
prefetch_factor=2,
prefetch_factor=None,
train_period=[None, None],
val_period=[None, None],
test_period=[None, None],
Expand Down

0 comments on commit 6c6a54c

Please sign in to comment.