From 2f589d171badad95231eab01391ebba08e1c131a Mon Sep 17 00:00:00 2001 From: "Bryn N. Ubald" <55503826+bnubald@users.noreply.github.com> Date: Wed, 29 May 2024 11:52:28 +0100 Subject: [PATCH] Fixes #235: Set defaults for generate_workers --- icenet/data/dataset.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/icenet/data/dataset.py b/icenet/data/dataset.py index d47cc10..d7d54d7 100644 --- a/icenet/data/dataset.py +++ b/icenet/data/dataset.py @@ -88,7 +88,7 @@ def __init__(self, self._counts = self._config["counts"] self._dtype = getattr(np, self._config["dtype"]) self._loader_config = self._config["loader_config"] - self._generate_workers = self._config["generate_workers"] + self._generate_workers = self._config.get("generate_workers", 4) self._n_forecast_days = self._config["n_forecast_days"] self._num_channels = self._config["num_channels"] self._shape = tuple(self._config["shape"]) @@ -148,7 +148,7 @@ def get_data_loader(self, if n_forecast_days is None: n_forecast_days = self._config["n_forecast_days"] if generate_workers is None: - generate_workers = self._config["generate_workers"] + generate_workers = self._config.get("generate_workers", 4) loader = IceNetDataLoaderFactory().create_data_loader( "dask", # This will load the `DaskMultiWorkerLoader` class. self.loader_config,