Commit

remove torchdata

dfulu committed Nov 21, 2023
1 parent 2dd1cc9 commit 035e65d
Showing 5 changed files with 108 additions and 74 deletions.
38 changes: 25 additions & 13 deletions pvnet/data/datamodule.py
@@ -3,13 +3,16 @@

import numpy as np
import torch
+from torch.utils.data import DataLoader
+from torch.utils.data.datapipes.datapipe import IterDataPipe
+from torch.utils.data.datapipes._decorator import functional_datapipe

from lightning.pytorch import LightningDataModule

from ocf_datapipes.training.pvnet import pvnet_datapipe
from ocf_datapipes.utils.consts import BatchKey
from ocf_datapipes.utils.utils import stack_np_examples_into_batch
-from torchdata.dataloader2 import DataLoader2, MultiProcessingReadingService
-from torchdata.datapipes import functional_datapipe
-from torchdata.datapipes.iter import FileLister, IterDataPipe



def copy_batch_to_device(batch, device):
@@ -117,11 +120,20 @@ def __init__(
        self.test_period = [
            None if d is None else datetime.strptime(d, "%Y-%m-%d") for d in test_period
        ]

-        self.readingservice_config = dict(
-            num_workers=num_workers,
-            multiprocessing_context="spawn",
-            worker_prefetch_cnt=prefetch_factor,
-        )
+        self._common_dataloader_kwargs = dict(
+            shuffle=False,  # shuffled in datapipe step
+            batch_size=None,  # batched in datapipe step
+            sampler=None,
+            batch_sampler=None,
+            num_workers=num_workers,
+            collate_fn=None,
+            pin_memory=False,
+            drop_last=False,
+            timeout=0,
+            worker_init_fn=None,
+            prefetch_factor=prefetch_factor,
+            persistent_workers=False
+        )

    def _get_datapipe(self, start_time, end_time):
@@ -173,22 +185,22 @@ def train_dataloader(self):
        else:
            datapipe = self._get_datapipe(*self.train_period)
-        rs = MultiProcessingReadingService(**self.readingservice_config)
-        return DataLoader2(datapipe, reading_service=rs)
+        return DataLoader(datapipe, **self._common_dataloader_kwargs)

    def val_dataloader(self):
        """Construct val dataloader"""
        if self.batch_dir is not None:
            datapipe = self._get_premade_batches_datapipe("val")
        else:
            datapipe = self._get_datapipe(*self.val_period)
-        rs = MultiProcessingReadingService(**self.readingservice_config)
-        return DataLoader2(datapipe, reading_service=rs)
+        return DataLoader(datapipe, **self._common_dataloader_kwargs)

    def test_dataloader(self):
        """Construct test dataloader"""
        if self.batch_dir is not None:
            datapipe = self._get_premade_batches_datapipe("test")
        else:
            datapipe = self._get_datapipe(*self.test_period)
-        rs = MultiProcessingReadingService(**self.readingservice_config)
-        return DataLoader2(datapipe, reading_service=rs)
+        return DataLoader(datapipe, **self._common_dataloader_kwargs)
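Note on the replacement above: because `pvnet_datapipe` already shuffles and batches internally, the stock `DataLoader` is configured as a pure pass-through. A minimal sketch of the pattern, using a stand-in datapipe rather than PVNet's real pipeline:

```python
import torch
from torch.utils.data import DataLoader
from torch.utils.data.datapipes.iter import IterableWrapper

# Stand-in for pvnet_datapipe: an IterDataPipe that already yields
# fully formed batches (here, dicts of stacked tensors).
batch_pipe = IterableWrapper([{"x": torch.randn(4, 3)} for _ in range(8)])

# batch_size=None turns off the DataLoader's own batching, and shuffle
# must stay False for iterable-style datasets; the loader's only job is
# to move the ready-made batches, optionally via worker processes.
loader = DataLoader(batch_pipe, batch_size=None, shuffle=False, num_workers=0)

for batch in loader:
    print(batch["x"].shape)  # torch.Size([4, 3]): one pre-built batch per step
```

With `num_workers > 0`, recent PyTorch applies graph sharding to `IterDataPipe` datasets inside its workers, which covers what `MultiProcessingReadingService` was doing here and is what makes the swap safe.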

3 changes: 1 addition & 2 deletions requirements.txt
@@ -1,4 +1,4 @@
-ocf_datapipes>=2.0.6
+ocf_datapipes>=2.2.2
nowcasting_utils
ocf_ml_metrics
numpy
@@ -9,7 +9,6 @@ ipykernel
h5netcdf
torch>=2.0
lightning>=2.0.1
-torchdata
torchvision
pytest
pytest-cov
2 changes: 1 addition & 1 deletion scripts/load_batches.py
@@ -2,7 +2,7 @@
"""

import torch
-from torchdata.datapipes.iter import FileLister
+from torch.utils.data.datapipes.iter import FileLister

from pvnet.data.datamodule import BatchSplitter

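`FileLister` ships with core PyTorch as well as torchdata, which is what makes this a drop-in import swap. A small illustrative sketch (the directory and mask are hypothetical, not the script's own):

```python
from torch.utils.data.datapipes.iter import FileLister

# List saved batch files; directory and mask are illustrative only.
file_pipe = FileLister("batches/train", masks="*.pt")
for path in file_pipe:
    print(path)
```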
103 changes: 59 additions & 44 deletions scripts/save_batches.py
@@ -7,9 +7,11 @@
use:
```
python save_batches.py \
+batch_output_dir="/mnt/disks/batches/batches_v0" \
+num_train_batches=10_000 \
+num_val_batches=2_000
+batch_output_dir="/mnt/disks/bigbatches/batches_v0" \
datamodule.batch_size=2 \
datamodule.num_workers=2 \
+num_train_batches=0 \
+num_val_batches=2
```
"""
@@ -27,11 +29,11 @@
from ocf_datapipes.utils.utils import stack_np_examples_into_batch
from omegaconf import DictConfig, OmegaConf
from sqlalchemy import exc as sa_exc
-from torchdata.dataloader2 import DataLoader2, MultiProcessingReadingService
-from torchdata.datapipes.iter import IterableWrapper
+from torch.utils.data import DataLoader
from tqdm import tqdm

from pvnet.data.datamodule import batch_to_tensor
+from pvnet.utils import print_config

warnings.filterwarnings("ignore", category=sa_exc.SAWarning)

@@ -62,13 +64,12 @@ def _get_datapipe(config_path, start_time, end_time, batch_size):
    return data_pipeline


-def _save_batches_with_dataloader(batch_pipe, batch_dir, num_batches, rs_config):
+def _save_batches_with_dataloader(batch_pipe, batch_dir, num_batches, dataloader_kwargs):
    save_func = _save_batch_func_factory(batch_dir)
    filenumber_pipe = IterableWrapper(range(num_batches)).sharding_filter()
    save_pipe = filenumber_pipe.zip(batch_pipe).map(save_func)

-    rs = MultiProcessingReadingService(**rs_config)
-    dataloader = DataLoader2(save_pipe, reading_service=rs)
+    dataloader = DataLoader(save_pipe, **dataloader_kwargs)

    pbar = tqdm(total=num_batches)
    for i, batch in zip(range(num_batches), dataloader):
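The `sharding_filter` on the file-number pipe is what keeps this multiprocessing-safe: each DataLoader worker sees a disjoint subset of file numbers, so no two workers ever write the same output file, while each worker draws its own batches from `batch_pipe`. A toy sketch of the same wiring, with illustrative names rather than the script's own objects:

```python
from torch.utils.data import DataLoader
from torch.utils.data.datapipes.iter import IterableWrapper


def fake_save(numbered_batch):
    # The real save_func writes the batch out under the file number;
    # here we just pass the (filenumber, batch) pair through.
    return numbered_batch


if __name__ == "__main__":
    # Stand-in batches; in the real script each worker generates its own
    # random batches, so only the file numbers need sharding.
    batch_pipe = IterableWrapper([f"batch_{i}" for i in range(8)])

    # sharding_filter() gives each worker a disjoint slice of file numbers,
    # so every output filename is produced exactly once.
    filenumber_pipe = IterableWrapper(range(8)).sharding_filter()
    save_pipe = filenumber_pipe.zip(batch_pipe).map(fake_save)

    loader = DataLoader(save_pipe, batch_size=None, num_workers=2)
    for filenumber, batch in loader:
        print(filenumber, batch)
```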
@@ -79,8 +80,10 @@ def _save_batches_with_dataloader(batch_pipe, batch_dir, num_batches, rs_config)

@hydra.main(config_path="../configs/", config_name="config.yaml", version_base="1.2")
def main(config: DictConfig):
"""Constructs and saves validation and training batches."""
"Constructs and saves validation and training batches."
config_dm = config.datamodule

print_config(config, resolve=False)

    # Set up directory
    os.makedirs(config.batch_output_dir, exist_ok=False)
@@ -93,44 +96,56 @@ def main(config: DictConfig):
os.mkdir(f"{config.batch_output_dir}/train")
os.mkdir(f"{config.batch_output_dir}/val")

-    readingservice_config = dict(
-        num_workers=config_dm.num_workers,
-        multiprocessing_context="spawn",
-        worker_prefetch_cnt=config_dm.prefetch_factor,
-    )
-
-    print("----- Saving val batches -----")
-
-    val_batch_pipe = _get_datapipe(
-        config_dm.configuration,
-        *config_dm.val_period,
-        config_dm.batch_size,
-    )
-
-    _save_batches_with_dataloader(
-        batch_pipe=val_batch_pipe,
-        batch_dir=f"{config.batch_output_dir}/val",
-        num_batches=config.num_val_batches,
-        rs_config=readingservice_config,
-    )
-
-    print("----- Saving train batches -----")
-
-    train_batch_pipe = _get_datapipe(
-        config_dm.configuration,
-        *config_dm.train_period,
-        config_dm.batch_size,
-    )
-
-    _save_batches_with_dataloader(
-        batch_pipe=train_batch_pipe,
-        batch_dir=f"{config.batch_output_dir}/train",
-        num_batches=config.num_train_batches,
-        rs_config=readingservice_config,
-    )
+    dataloader_kwargs = dict(
+        shuffle=False,
+        batch_size=None,  # batched in datapipe step
+        sampler=None,
+        batch_sampler=None,
+        num_workers=config_dm.num_workers,
+        collate_fn=None,
+        pin_memory=False,
+        drop_last=False,
+        timeout=0,
+        worker_init_fn=None,
+        prefetch_factor=config_dm.prefetch_factor,
+        persistent_workers=False
+    )
+
+    if config.num_val_batches>0:
+        print("----- Saving val batches -----")
+
+        val_batch_pipe = _get_datapipe(
+            config_dm.configuration,
+            *config_dm.val_period,
+            config_dm.batch_size,
+        )
+
+        _save_batches_with_dataloader(
+            batch_pipe=val_batch_pipe,
+            batch_dir=f"{config.batch_output_dir}/val",
+            num_batches=config.num_val_batches,
+            dataloader_kwargs=dataloader_kwargs,
+        )
+
+    if config.num_train_batches>0:
+
+        print("----- Saving train batches -----")
+
+        train_batch_pipe = _get_datapipe(
+            config_dm.configuration,
+            *config_dm.train_period,
+            config_dm.batch_size,
+        )
+
+        _save_batches_with_dataloader(
+            batch_pipe=train_batch_pipe,
+            batch_dir=f"{config.batch_output_dir}/train",
+            num_batches=config.num_train_batches,
+            dataloader_kwargs=dataloader_kwargs,
+        )

print("done")


if __name__ == "__main__":
-    main()
+    main()
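One behavioural difference worth flagging: the removed `MultiProcessingReadingService` pinned the "spawn" start method, while `DataLoader` defaults to the platform's start method (fork on Linux). If spawn semantics are still wanted, the stock loader accepts the same option. A hedged sketch with a stand-in pipe:

```python
from torch.utils.data import DataLoader
from torch.utils.data.datapipes.iter import IterableWrapper

pipe = IterableWrapper(range(4))  # stand-in datapipe

# multiprocessing_context is a standard DataLoader argument, so the old
# "spawn" behaviour can be kept explicitly if it turns out to matter.
loader = DataLoader(pipe, batch_size=None, num_workers=2,
                    multiprocessing_context="spawn")
```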
36 changes: 22 additions & 14 deletions scripts/save_concurrent_batches.py
@@ -7,9 +7,9 @@
use:
```
python save_concurrent_batches.py \
+batch_output_dir="/mnt/disks/batches/concurrent_batches_v0" \
+num_train_batches=1_000 \
+num_val_batches=200
+batch_output_dir="/mnt/disks/nwp_rechunk/concurrent_batches_v3.9" \
+num_train_batches=20_000 \
+num_val_batches=4_000
```
"""
@@ -31,8 +31,7 @@
from ocf_datapipes.utils.utils import stack_np_examples_into_batch
from omegaconf import DictConfig, OmegaConf
from sqlalchemy import exc as sa_exc
-from torchdata.dataloader2 import DataLoader2, MultiProcessingReadingService
-from torchdata.datapipes.iter import IterableWrapper
+from torch.utils.data import DataLoader
from tqdm import tqdm

from pvnet.data.datamodule import batch_to_tensor
@@ -123,13 +122,12 @@ def _get_datapipe(config_path, start_time, end_time, n_batches):
    return data_pipeline


-def _save_batches_with_dataloader(batch_pipe, batch_dir, num_batches, rs_config):
+def _save_batches_with_dataloader(batch_pipe, batch_dir, num_batches, dataloader_kwargs):
    save_func = _save_batch_func_factory(batch_dir)
    filenumber_pipe = IterableWrapper(np.arange(num_batches)).sharding_filter()
    save_pipe = filenumber_pipe.zip(batch_pipe).map(save_func)

-    rs = MultiProcessingReadingService(**rs_config)
-    dataloader = DataLoader2(save_pipe, reading_service=rs)
+    dataloader = DataLoader(save_pipe, **dataloader_kwargs)

    pbar = tqdm(total=num_batches)
    for i, batch in zip(range(num_batches), dataloader):
@@ -163,10 +161,20 @@ def main(config: DictConfig):
os.mkdir(f"{config.batch_output_dir}/train")
os.mkdir(f"{config.batch_output_dir}/val")

-    readingservice_config = dict(
-        num_workers=config_dm.num_workers,
-        multiprocessing_context="spawn",
-        worker_prefetch_cnt=config_dm.prefetch_factor,
-    )
+
+    dataloader_kwargs = dict(
+        shuffle=False,
+        batch_size=None,  # batched in datapipe step
+        sampler=None,
+        batch_sampler=None,
+        num_workers=config_dm.num_workers,
+        collate_fn=None,
+        pin_memory=False,
+        drop_last=False,
+        timeout=0,
+        worker_init_fn=None,
+        prefetch_factor=config_dm.prefetch_factor,
+        persistent_workers=False
+    )
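A caveat about these kwargs: recent PyTorch only accepts a non-None `prefetch_factor` when `num_workers > 0`, so a config that sets `num_workers=0` would need `prefetch_factor=None` here. A hypothetical guard, not part of this commit:

```python
# DataLoader raises a ValueError if prefetch_factor is set while
# num_workers == 0, so derive one from the other when building kwargs.
num_workers = 0  # e.g. config_dm.num_workers
dataloader_kwargs = dict(
    batch_size=None,
    num_workers=num_workers,
    prefetch_factor=None if num_workers == 0 else 2,
)
```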

print("----- Saving val batches -----")
Expand All @@ -181,7 +189,7 @@ def main(config: DictConfig):
        batch_pipe=val_batch_pipe,
        batch_dir=f"{config.batch_output_dir}/val",
        num_batches=config.num_val_batches,
-        rs_config=readingservice_config,
+        dataloader_kwargs=dataloader_kwargs,
    )

print("----- Saving train batches -----")
@@ -196,7 +204,7 @@
        batch_pipe=train_batch_pipe,
        batch_dir=f"{config.batch_output_dir}/train",
        num_batches=config.num_train_batches,
-        rs_config=readingservice_config,
+        dataloader_kwargs=dataloader_kwargs,
    )

print("done")
