From dc6f8e1b04089204dc4ebaf682b2286b17386f5a Mon Sep 17 00:00:00 2001 From: James Fulton Date: Wed, 22 Nov 2023 17:36:25 +0000 Subject: [PATCH] remove torchdata --- pvnet_app/app.py | 24 +++++++++++++++++------- 1 file changed, 17 insertions(+), 7 deletions(-) diff --git a/pvnet_app/app.py b/pvnet_app/app.py index 6b6a397..3257b3a 100644 --- a/pvnet_app/app.py +++ b/pvnet_app/app.py @@ -41,8 +41,8 @@ from ocf_datapipes.utils.utils import stack_np_examples_into_batch from pvnet_summation.models.base_model import BaseModel as SummationBaseModel from sqlalchemy.orm import Session -from torchdata.dataloader2 import DataLoader2, MultiProcessingReadingService -from torchdata.datapipes.iter import IterableWrapper +from torch.utils.data import DataLoader +from torch.utils.data.datapipes.iter import IterableWrapper import pvnet from pvnet.data.datamodule import batch_to_tensor, copy_batch_to_device @@ -76,7 +76,7 @@ # Huggingfacehub model repo and commit for PVNet summation (GSP sum to national model) # If summation_model_name is set to None, a simple sum is computed instead default_summation_model_name = "openclimatefix/pvnet_v2_summation" -default_summation_model_version = "01393d6e4a036103f9c7111cba6f03d5c19beb54" +default_summation_model_version = "6c5361101b461ae991662bdff05f7a0b77b4040b" model_name_ocf_db = "pvnet_v2" use_adjuster = os.getenv("USE_ADJUSTER", "True").lower() == "true" @@ -396,12 +396,22 @@ def app( ) # Set up dataloader for parallel loading - rs = MultiProcessingReadingService( + dataloader_kwargs = dict( + shuffle=False, + batch_size=None, # batched in datapipe step + sampler=None, + batch_sampler=None, num_workers=num_workers, - multiprocessing_context="spawn", - worker_prefetch_cnt=0 if num_workers == 0 else 2, + collate_fn=None, + pin_memory=False, + drop_last=False, + timeout=0, + worker_init_fn=None, + prefetch_factor=None if num_workers == 0 else 2, + persistent_workers=False, ) - dataloader = DataLoader2(batch_datapipe, reading_service=rs) + + dataloader = DataLoader(batch_datapipe, **dataloader_kwargs) # --------------------------------------------------------------------------- # 3. set up model