diff --git a/configs.example/config.yaml b/configs.example/config.yaml index 9ad3457b..a76af5b2 100644 --- a/configs.example/config.yaml +++ b/configs.example/config.yaml @@ -12,10 +12,9 @@ defaults: - hparams_search: null - hydra: default.yaml -renewable: - "pv" +renewable: "pv" - # enable color logging +# enable color logging # - override hydra/hydra_logging: colorlog # - override hydra/job_logging: colorlog diff --git a/pvnet/data/datamodule.py b/pvnet/data/datamodule.py index e5905815..3d279331 100644 --- a/pvnet/data/datamodule.py +++ b/pvnet/data/datamodule.py @@ -23,7 +23,6 @@ def __init__( train_period=[None, None], val_period=[None, None], test_period=[None, None], - block_nwp_and_sat=False, batch_dir=None, ): """Datamodule for training pvnet and using pvnet pipeline in `ocf_datapipes`. @@ -39,8 +38,6 @@ def __init__( train_period: Date range filter for train dataloader. val_period: Date range filter for val dataloader. test_period: Date range filter for test dataloader. - block_nwp_and_sat: If True, the dataloader does not load the requested NWP and sat data. - It instead returns an zero-array of the required shape. Useful for pretraining. batch_dir: Path to the directory of pre-saved batches. Cannot be used together with `configuration` or 'train/val/test_period'. @@ -48,7 +45,6 @@ def __init__( super().__init__() self.configuration = configuration self.batch_size = batch_size - self.block_nwp_and_sat = block_nwp_and_sat self.batch_dir = batch_dir if not ((batch_dir is not None) ^ (configuration is not None)): @@ -69,7 +65,6 @@ def __init__( ] self._common_dataloader_kwargs = dict( - shuffle=False, # shuffled in datapipe step batch_size=None, # batched in datapipe step sampler=None, batch_sampler=None, @@ -88,8 +83,6 @@ def _get_datapipe(self, start_time, end_time): self.configuration, start_time=start_time, end_time=end_time, - block_sat=self.block_nwp_and_sat, - block_nwp=self.block_nwp_and_sat, ) data_pipeline = ( @@ -131,7 +124,7 @@ def train_dataloader(self): datapipe = self._get_premade_batches_datapipe("train", shuffle=True) else: datapipe = self._get_datapipe(*self.train_period) - return DataLoader(datapipe, **self._common_dataloader_kwargs) + return DataLoader(datapipe, shuffle=True, **self._common_dataloader_kwargs) def val_dataloader(self): """Construct val dataloader""" @@ -139,7 +132,7 @@ def val_dataloader(self): datapipe = self._get_premade_batches_datapipe("val") else: datapipe = self._get_datapipe(*self.val_period) - return DataLoader(datapipe, **self._common_dataloader_kwargs) + return DataLoader(datapipe, shuffle=False, **self._common_dataloader_kwargs) def test_dataloader(self): """Construct test dataloader""" @@ -147,4 +140,4 @@ def test_dataloader(self): datapipe = self._get_premade_batches_datapipe("test") else: datapipe = self._get_datapipe(*self.test_period) - return DataLoader(datapipe, **self._common_dataloader_kwargs) + return DataLoader(datapipe, shuffle=False, **self._common_dataloader_kwargs) diff --git a/pvnet/data/wind_datamodule.py b/pvnet/data/wind_datamodule.py index 21a93804..dad9e265 100644 --- a/pvnet/data/wind_datamodule.py +++ b/pvnet/data/wind_datamodule.py @@ -61,7 +61,6 @@ def __init__( ] self._common_dataloader_kwargs = dict( - shuffle=False, # shuffled in datapipe step batch_size=None, # batched in datapipe step sampler=None, batch_sampler=None, @@ -123,7 +122,7 @@ def train_dataloader(self): datapipe = self._get_premade_batches_datapipe("train", shuffle=True) else: datapipe = self._get_datapipe(*self.train_period) - return DataLoader(datapipe, **self._common_dataloader_kwargs) + return DataLoader(datapipe, shuffle=True, **self._common_dataloader_kwargs) def val_dataloader(self): """Construct val dataloader""" @@ -131,7 +130,7 @@ def val_dataloader(self): datapipe = self._get_premade_batches_datapipe("val") else: datapipe = self._get_datapipe(*self.val_period) - return DataLoader(datapipe, **self._common_dataloader_kwargs) + return DataLoader(datapipe, shuffle=False, **self._common_dataloader_kwargs) def test_dataloader(self): """Construct test dataloader""" @@ -139,4 +138,4 @@ def test_dataloader(self): datapipe = self._get_premade_batches_datapipe("test") else: datapipe = self._get_datapipe(*self.test_period) - return DataLoader(datapipe, **self._common_dataloader_kwargs) + return DataLoader(datapipe, shuffle=False, **self._common_dataloader_kwargs) diff --git a/pvnet/models/utils.py b/pvnet/models/utils.py index d8f645db..2c9748aa 100644 --- a/pvnet/models/utils.py +++ b/pvnet/models/utils.py @@ -27,7 +27,7 @@ def __bool__(self): def append(self, y_hat: torch.Tensor): """Append a sub-batch of predictions""" - self._y_hats += [y_hat] + self._y_hats.append(y_hat) def flush(self) -> torch.Tensor: """Return all appended predictions as single tensor and remove from accumulated store.""" @@ -42,7 +42,7 @@ class DictListAccumulator: @staticmethod def _dict_list_append(d1, d2): for k, v in d2.items(): - d1[k] += [v] + d1[k].append(v) @staticmethod def _dict_init_list(d): diff --git a/pvnet/training.py b/pvnet/training.py index 5f28215e..cc5f3c43 100644 --- a/pvnet/training.py +++ b/pvnet/training.py @@ -139,12 +139,16 @@ def train(config: DictConfig) -> Optional[float]: if should_pretrain: # Pre-train the model - datamodule.block_nwp_and_sat = True + raise NotImplementedError("Pre-training is not yet supported") + # The parameter `block_nwp_and_sat` has been removed from datapipes + # If pretraining is re-supported in the future it is likely any pre-training logic should + # go here or perhaps in the callbacks + # datamodule.block_nwp_and_sat = True + trainer.fit(model=model, datamodule=datamodule) _callbacks_to_phase(callbacks, "main") - datamodule.block_nwp_and_sat = False trainer.should_stop = False # Train the model completely diff --git a/pvnet/utils.py b/pvnet/utils.py index 33e26372..33ffe95b 100644 --- a/pvnet/utils.py +++ b/pvnet/utils.py @@ -252,12 +252,11 @@ def _get_numpy(key): y_hat = y_hat.cpu().numpy() gsp_ids = batch[BatchKey.gsp_id].cpu().numpy().squeeze() - t0_idx = batch[BatchKey.gsp_t0_idx] + batch[BatchKey.gsp_t0_idx] times_utc = batch[BatchKey.gsp_time_utc].cpu().numpy().squeeze().astype("datetime64[s]") times_utc = [pd.to_datetime(t) for t in times_utc] - len(times_utc[0]) - t0_idx - 1 batch_size = y.shape[0] fig, axes = plt.subplots(4, 4, figsize=(8, 8)) diff --git a/scripts/checkpoint_to_huggingface.py b/scripts/checkpoint_to_huggingface.py index 6691fb2c..f5777114 100644 --- a/scripts/checkpoint_to_huggingface.py +++ b/scripts/checkpoint_to_huggingface.py @@ -53,9 +53,9 @@ def push_to_huggingface( # Only one epoch (best) saved per model files = glob.glob(f"{checkpoint_dir_path}/epoch*.ckpt") assert len(files) == 1 - checkpoint = torch.load(files[0]) + checkpoint = torch.load(files[0], map_location="cpu") else: - checkpoint = torch.load(f"{checkpoint_dir_path}/last.ckpt") + checkpoint = torch.load(f"{checkpoint_dir_path}/last.ckpt", map_location="cpu") model.load_state_dict(state_dict=checkpoint["state_dict"]) diff --git a/tests/conftest.py b/tests/conftest.py index 31f61039..2a393ee2 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -123,7 +123,6 @@ def sample_train_val_datamodule(): train_period=[None, None], val_period=[None, None], test_period=[None, None], - block_nwp_and_sat=False, batch_dir=f"{tmpdirname}", ) yield dm @@ -139,7 +138,6 @@ def sample_datamodule(): train_period=[None, None], val_period=[None, None], test_period=[None, None], - block_nwp_and_sat=False, batch_dir="tests/test_data/sample_batches", ) return dm diff --git a/tests/data/test_datamodule.py b/tests/data/test_datamodule.py index 8a428db4..0c497bbb 100644 --- a/tests/data/test_datamodule.py +++ b/tests/data/test_datamodule.py @@ -12,7 +12,6 @@ def test_init(): train_period=[None, None], val_period=[None, None], test_period=[None, None], - block_nwp_and_sat=False, batch_dir="tests/test_data/sample_batches", ) @@ -39,7 +38,6 @@ def test_iter(): train_period=[None, None], val_period=[None, None], test_period=[None, None], - block_nwp_and_sat=False, batch_dir="tests/test_data/sample_batches", ) @@ -55,7 +53,6 @@ def test_iter_multiprocessing(): train_period=[None, None], val_period=[None, None], test_period=[None, None], - block_nwp_and_sat=False, batch_dir="tests/test_data/sample_batches", )