From a256ad70d3e8d3a8ff60ff9d6d3190c9798257cb Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 24 Nov 2023 17:24:09 +0000 Subject: [PATCH] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- scripts/save_batches.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/scripts/save_batches.py b/scripts/save_batches.py index f95de847..995f23e1 100644 --- a/scripts/save_batches.py +++ b/scripts/save_batches.py @@ -33,7 +33,6 @@ from torch.utils.data import DataLoader from torch.utils.data.datapipes.iter import IterableWrapper from tqdm import tqdm -import xarray as xr from pvnet.data.datamodule import batch_to_tensor from pvnet.utils import print_config @@ -77,7 +76,9 @@ def _get_datapipe(config_path, start_time, end_time, batch_size, renewable: str return data_pipeline -def _save_batches_with_dataloader(batch_pipe, batch_dir, num_batches, dataloader_kwargs, output_format: str = "torch"): +def _save_batches_with_dataloader( + batch_pipe, batch_dir, num_batches, dataloader_kwargs, output_format: str = "torch" +): save_func = _save_batch_func_factory(batch_dir, output_format=output_format) filenumber_pipe = IterableWrapper(range(num_batches)).sharding_filter() save_pipe = filenumber_pipe.zip(batch_pipe).map(save_func)