From d4b7fe4ab18cce2f2700e0038c6e5bee67ff4739 Mon Sep 17 00:00:00 2001 From: James Fulton Date: Tue, 21 Nov 2023 16:24:33 +0000 Subject: [PATCH] fix prefetch factor --- pvnet/data/datamodule.py | 12 +++++++----- tests/conftest.py | 2 +- tests/data/test_datamodule.py | 2 +- 3 files changed, 9 insertions(+), 7 deletions(-) diff --git a/pvnet/data/datamodule.py b/pvnet/data/datamodule.py index b60ecd98..681a9ac1 100644 --- a/pvnet/data/datamodule.py +++ b/pvnet/data/datamodule.py @@ -3,14 +3,16 @@ import numpy as np import torch +from torch.utils.data import DataLoader +from torch.utils.data.datapipes.datapipe import IterDataPipe +from torch.utils.data.datapipes.iter import FileLister +from torch.utils.data.datapipes._decorator import functional_datapipe + from lightning.pytorch import LightningDataModule from ocf_datapipes.training.pvnet import pvnet_datapipe from ocf_datapipes.utils.consts import BatchKey from ocf_datapipes.utils.utils import stack_np_examples_into_batch -from torch.utils.data import DataLoader -from torch.utils.data.datapipes._decorator import functional_datapipe -from torch.utils.data.datapipes.datapipe import IterDataPipe -from torch.utils.data.datapipes.iter import FileLister + def copy_batch_to_device(batch, device): @@ -70,7 +72,7 @@ def __init__( configuration=None, batch_size=16, num_workers=0, - prefetch_factor=2, + prefetch_factor=None, train_period=[None, None], val_period=[None, None], test_period=[None, None], diff --git a/tests/conftest.py b/tests/conftest.py index 7f7aa142..7e7eae52 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -96,7 +96,7 @@ def sample_datamodule(): configuration=None, batch_size=2, num_workers=0, - prefetch_factor=2, + prefetch_factor=None, train_period=[None, None], val_period=[None, None], test_period=[None, None], diff --git a/tests/data/test_datamodule.py b/tests/data/test_datamodule.py index 1c283e2f..a2a6eab6 100644 --- a/tests/data/test_datamodule.py +++ b/tests/data/test_datamodule.py @@ -6,7 +6,7 @@ def test_init(): configuration=None, batch_size=2, num_workers=0, - prefetch_factor=2, + prefetch_factor=Mone, train_period=[None, None], val_period=[None, None], test_period=[None, None],