From e9f59ca7e8688586955713791de5d48c985308af Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 5 Dec 2023 10:10:58 +0000 Subject: [PATCH] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- configs/config.yaml | 3 ++- scripts/save_batches.py | 14 +++++++------- 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/configs/config.yaml b/configs/config.yaml index a5acbe48..e190caee 100644 --- a/configs/config.yaml +++ b/configs/config.yaml @@ -15,7 +15,8 @@ renewable: "wind" batch_output_dir: "/mnt/storage_ssd_4tb/windnet_batches" num_val_batches: 3200 num_train_batches: 64000 -num_test_batches: 100 +num_test_batches: + 100 # enable color logging # - override hydra/hydra_logging: colorlog # - override hydra/job_logging: colorlog diff --git a/scripts/save_batches.py b/scripts/save_batches.py index a172fac9..2bd41b58 100644 --- a/scripts/save_batches.py +++ b/scripts/save_batches.py @@ -19,8 +19,10 @@ import os import shutil import sys + import dask -dask.config.set(scheduler='single-threaded') + +dask.config.set(scheduler="single-threaded") # Tired of seeing these warnings import warnings @@ -28,14 +30,12 @@ import torch from ocf_datapipes.training.pvnet import pvnet_datapipe from ocf_datapipes.training.windnet import windnet_datapipe -from ocf_datapipes.utils.utils import stack_np_examples_into_batch from omegaconf import DictConfig, OmegaConf from sqlalchemy import exc as sa_exc from torch.utils.data import DataLoader from torch.utils.data.datapipes.iter import IterableWrapper from tqdm import tqdm -from pvnet.data.datamodule import batch_to_tensor from pvnet.utils import print_config warnings.filterwarnings("ignore", category=sa_exc.SAWarning) @@ -71,9 +71,9 @@ def _get_datapipe(config_path, start_time, end_time, batch_size, renewable: str end_time=end_time, ) - #data_pipeline = ( + # data_pipeline = ( # data_pipeline.batch(batch_size).map(stack_np_examples_into_batch).map(batch_to_tensor) - #) + # ) return data_pipeline @@ -116,13 +116,13 @@ def main(config: DictConfig): batch_size=None, # batched in datapipe step sampler=None, batch_sampler=None, - num_workers=0, #config_dm.num_workers, + num_workers=0, # config_dm.num_workers, collate_fn=None, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None, - prefetch_factor=None, #config_dm.prefetch_factor, + prefetch_factor=None, # config_dm.prefetch_factor, persistent_workers=False, )