Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Dec 5, 2023
1 parent 2742824 commit e9f59ca
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 8 deletions.
3 changes: 2 additions & 1 deletion configs/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ renewable: "wind"
batch_output_dir: "/mnt/storage_ssd_4tb/windnet_batches"
num_val_batches: 3200
num_train_batches: 64000
num_test_batches: 100
num_test_batches:
100
# enable color logging
# - override hydra/hydra_logging: colorlog
# - override hydra/job_logging: colorlog
Expand Down
14 changes: 7 additions & 7 deletions scripts/save_batches.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,23 +19,23 @@
import os
import shutil
import sys

import dask
dask.config.set(scheduler='single-threaded')

dask.config.set(scheduler="single-threaded")
# Tired of seeing these warnings
import warnings

import hydra
import torch
from ocf_datapipes.training.pvnet import pvnet_datapipe
from ocf_datapipes.training.windnet import windnet_datapipe
from ocf_datapipes.utils.utils import stack_np_examples_into_batch
from omegaconf import DictConfig, OmegaConf
from sqlalchemy import exc as sa_exc
from torch.utils.data import DataLoader
from torch.utils.data.datapipes.iter import IterableWrapper
from tqdm import tqdm

from pvnet.data.datamodule import batch_to_tensor
from pvnet.utils import print_config

warnings.filterwarnings("ignore", category=sa_exc.SAWarning)
Expand Down Expand Up @@ -71,9 +71,9 @@ def _get_datapipe(config_path, start_time, end_time, batch_size, renewable: str
end_time=end_time,
)

#data_pipeline = (
# data_pipeline = (
# data_pipeline.batch(batch_size).map(stack_np_examples_into_batch).map(batch_to_tensor)
#)
# )
return data_pipeline


Expand Down Expand Up @@ -116,13 +116,13 @@ def main(config: DictConfig):
batch_size=None, # batched in datapipe step
sampler=None,
batch_sampler=None,
num_workers=0, #config_dm.num_workers,
num_workers=0, # config_dm.num_workers,
collate_fn=None,
pin_memory=False,
drop_last=False,
timeout=0,
worker_init_fn=None,
prefetch_factor=None, #config_dm.prefetch_factor,
prefetch_factor=None, # config_dm.prefetch_factor,
persistent_workers=False,
)

Expand Down

0 comments on commit e9f59ca

Please sign in to comment.