Merge pull request #118 from openclimatefix/james/tweeks
Fix shuffling and minor tweaks
dfulu authored Jan 9, 2024
2 parents c23f811 + eaa94f7 commit 49d14fa
Showing 9 changed files with 19 additions and 30 deletions.
5 changes: 2 additions & 3 deletions configs.example/config.yaml
@@ -12,10 +12,9 @@ defaults:
   - hparams_search: null
   - hydra: default.yaml
 
-renewable:
-  "pv"
+renewable: "pv"
 
-# enable color logging
+# enable color logging
 # - override hydra/hydra_logging: colorlog
 # - override hydra/job_logging: colorlog
13 changes: 3 additions & 10 deletions pvnet/data/datamodule.py
@@ -23,7 +23,6 @@ def __init__(
         train_period=[None, None],
         val_period=[None, None],
         test_period=[None, None],
-        block_nwp_and_sat=False,
         batch_dir=None,
     ):
         """Datamodule for training pvnet and using pvnet pipeline in `ocf_datapipes`.
@@ -39,16 +38,13 @@ def __init__(
             train_period: Date range filter for train dataloader.
             val_period: Date range filter for val dataloader.
             test_period: Date range filter for test dataloader.
-            block_nwp_and_sat: If True, the dataloader does not load the requested NWP and sat data.
-                It instead returns an zero-array of the required shape. Useful for pretraining.
             batch_dir: Path to the directory of pre-saved batches. Cannot be used together with
                 `configuration` or 'train/val/test_period'.
         """
         super().__init__()
         self.configuration = configuration
         self.batch_size = batch_size
-        self.block_nwp_and_sat = block_nwp_and_sat
         self.batch_dir = batch_dir
 
         if not ((batch_dir is not None) ^ (configuration is not None)):
@@ -69,7 +65,6 @@ def __init__(
         ]
 
         self._common_dataloader_kwargs = dict(
-            shuffle=False,  # shuffled in datapipe step
             batch_size=None,  # batched in datapipe step
             sampler=None,
             batch_sampler=None,
@@ -88,8 +83,6 @@ def _get_datapipe(self, start_time, end_time):
             self.configuration,
             start_time=start_time,
             end_time=end_time,
-            block_sat=self.block_nwp_and_sat,
-            block_nwp=self.block_nwp_and_sat,
         )
 
         data_pipeline = (
@@ -131,20 +124,20 @@ def train_dataloader(self):
             datapipe = self._get_premade_batches_datapipe("train", shuffle=True)
         else:
             datapipe = self._get_datapipe(*self.train_period)
-        return DataLoader(datapipe, **self._common_dataloader_kwargs)
+        return DataLoader(datapipe, shuffle=True, **self._common_dataloader_kwargs)
 
     def val_dataloader(self):
         """Construct val dataloader"""
         if self.batch_dir is not None:
             datapipe = self._get_premade_batches_datapipe("val")
         else:
             datapipe = self._get_datapipe(*self.val_period)
-        return DataLoader(datapipe, **self._common_dataloader_kwargs)
+        return DataLoader(datapipe, shuffle=False, **self._common_dataloader_kwargs)
 
     def test_dataloader(self):
         """Construct test dataloader"""
         if self.batch_dir is not None:
             datapipe = self._get_premade_batches_datapipe("test")
         else:
             datapipe = self._get_datapipe(*self.test_period)
-        return DataLoader(datapipe, **self._common_dataloader_kwargs)
+        return DataLoader(datapipe, shuffle=False, **self._common_dataloader_kwargs)
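Note on the fix above: when the dataset passed to a DataLoader is a torchdata IterDataPipe, the DataLoader's `shuffle` flag controls whether any `.shuffle()` step inside the pipe is active. The removed `shuffle=False  # shuffled in datapipe step` kwarg therefore silently disabled the very shuffling the comment claimed was happening. A minimal sketch (not from this repo) demonstrating the behaviour:

    from torch.utils.data import DataLoader
    from torchdata.datapipes.iter import IterableWrapper

    # A datapipe with an internal shuffle step, as the PVNet pipeline has
    datapipe = IterableWrapper(range(10)).shuffle()

    # shuffle=False disables the in-pipe shuffler, so this yields 0..9
    # in order -- the bug this commit fixes
    print(list(DataLoader(datapipe, batch_size=None, shuffle=False)))

    # shuffle=True re-enables it, which is what the new train_dataloader
    # does; val/test now pass shuffle=False deliberately
    print(list(DataLoader(datapipe, batch_size=None, shuffle=True)))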
7 changes: 3 additions & 4 deletions pvnet/data/wind_datamodule.py
@@ -61,7 +61,6 @@ def __init__(
         ]
 
         self._common_dataloader_kwargs = dict(
-            shuffle=False,  # shuffled in datapipe step
             batch_size=None,  # batched in datapipe step
             sampler=None,
             batch_sampler=None,
@@ -123,20 +122,20 @@ def train_dataloader(self):
             datapipe = self._get_premade_batches_datapipe("train", shuffle=True)
         else:
             datapipe = self._get_datapipe(*self.train_period)
-        return DataLoader(datapipe, **self._common_dataloader_kwargs)
+        return DataLoader(datapipe, shuffle=True, **self._common_dataloader_kwargs)
 
     def val_dataloader(self):
         """Construct val dataloader"""
         if self.batch_dir is not None:
             datapipe = self._get_premade_batches_datapipe("val")
         else:
             datapipe = self._get_datapipe(*self.val_period)
-        return DataLoader(datapipe, **self._common_dataloader_kwargs)
+        return DataLoader(datapipe, shuffle=False, **self._common_dataloader_kwargs)
 
     def test_dataloader(self):
         """Construct test dataloader"""
         if self.batch_dir is not None:
             datapipe = self._get_premade_batches_datapipe("test")
         else:
             datapipe = self._get_datapipe(*self.test_period)
-        return DataLoader(datapipe, **self._common_dataloader_kwargs)
+        return DataLoader(datapipe, shuffle=False, **self._common_dataloader_kwargs)
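The wind datamodule receives the identical shuffle fix as the PV datamodule; the behavioural note and sketch after the pvnet/data/datamodule.py diff above apply here unchanged.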
4 changes: 2 additions & 2 deletions pvnet/models/utils.py
@@ -27,7 +27,7 @@ def __bool__(self):
 
     def append(self, y_hat: torch.Tensor):
         """Append a sub-batch of predictions"""
-        self._y_hats += [y_hat]
+        self._y_hats.append(y_hat)
 
     def flush(self) -> torch.Tensor:
         """Return all appended predictions as single tensor and remove from accumulated store."""
@@ -42,7 +42,7 @@ class DictListAccumulator:
     @staticmethod
     def _dict_list_append(d1, d2):
         for k, v in d2.items():
-            d1[k] += [v]
+            d1[k].append(v)
 
     @staticmethod
     def _dict_init_list(d):
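Note on these two hunks: both `d1[k] += [v]` and `d1[k].append(v)` mutate the list in place; the new form just avoids allocating a throwaway one-element list on every call. A quick, repo-independent micro-benchmark sketch:

    import timeit

    # append avoids building a temporary single-element list each call
    print(timeit.timeit("xs += [1]", setup="xs = []", number=1_000_000))
    print(timeit.timeit("xs.append(1)", setup="xs = []", number=1_000_000))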
8 changes: 6 additions & 2 deletions pvnet/training.py
@@ -139,12 +139,16 @@ def train(config: DictConfig) -> Optional[float]:
 
     if should_pretrain:
         # Pre-train the model
-        datamodule.block_nwp_and_sat = True
+        raise NotImplementedError("Pre-training is not yet supported")
+        # The parameter `block_nwp_and_sat` has been removed from datapipes
+        # If pretraining is re-supported in the future it is likely any pre-training logic should
+        # go here or perhaps in the callbacks
+        # datamodule.block_nwp_and_sat = True
 
         trainer.fit(model=model, datamodule=datamodule)
 
     _callbacks_to_phase(callbacks, "main")
 
-    datamodule.block_nwp_and_sat = False
     trainer.should_stop = False
 
     # Train the model completely
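The added comment only gestures at where pre-training logic might return. Purely as a hypothetical sketch (nothing below exists in the repo; the key names and callback wiring are assumptions), input blocking could be reintroduced as a Lightning callback rather than a datamodule flag:

    import torch
    from lightning.pytorch.callbacks import Callback

    BLOCKED_KEYS = ["nwp", "satellite_actual"]  # hypothetical batch keys

    class BlockInputsForPretraining(Callback):
        """Zero out selected inputs during a pre-training phase."""

        def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
            for key in BLOCKED_KEYS:
                if key in batch:
                    # Replace real inputs with zeros of the same shape
                    batch[key] = torch.zeros_like(batch[key])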
3 changes: 1 addition & 2 deletions pvnet/utils.py
@@ -252,12 +252,11 @@ def _get_numpy(key):
     y_hat = y_hat.cpu().numpy()
 
     gsp_ids = batch[BatchKey.gsp_id].cpu().numpy().squeeze()
-    t0_idx = batch[BatchKey.gsp_t0_idx]
+    batch[BatchKey.gsp_t0_idx]
 
     times_utc = batch[BatchKey.gsp_time_utc].cpu().numpy().squeeze().astype("datetime64[s]")
     times_utc = [pd.to_datetime(t) for t in times_utc]
 
-    len(times_utc[0]) - t0_idx - 1
     batch_size = y.shape[0]
 
     fig, axes = plt.subplots(4, 4, figsize=(8, 8))
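Note: this hunk deletes unused locals but leaves the bare `batch[BatchKey.gsp_t0_idx]` expression behind; it evaluates and discards the value, so it is a no-op that a follow-up could remove entirely.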
4 changes: 2 additions & 2 deletions scripts/checkpoint_to_huggingface.py
@@ -53,9 +53,9 @@ def push_to_huggingface(
         # Only one epoch (best) saved per model
         files = glob.glob(f"{checkpoint_dir_path}/epoch*.ckpt")
         assert len(files) == 1
-        checkpoint = torch.load(files[0])
+        checkpoint = torch.load(files[0], map_location="cpu")
     else:
-        checkpoint = torch.load(f"{checkpoint_dir_path}/last.ckpt")
+        checkpoint = torch.load(f"{checkpoint_dir_path}/last.ckpt", map_location="cpu")
 
     model.load_state_dict(state_dict=checkpoint["state_dict"])
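Note on `map_location`: without it, `torch.load` tries to restore each tensor to the device it was saved from (e.g. `cuda:0`), which raises a RuntimeError on CPU-only machines; remapping to CPU makes the upload script device-independent. A minimal illustration (the path is a placeholder):

    import torch

    ckpt_path = "last.ckpt"  # placeholder path

    # Remap all tensors to CPU so a GPU-trained checkpoint loads anywhere
    checkpoint = torch.load(ckpt_path, map_location="cpu")
    state_dict = checkpoint["state_dict"]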
2 changes: 0 additions & 2 deletions tests/conftest.py
@@ -123,7 +123,6 @@ def sample_train_val_datamodule():
             train_period=[None, None],
             val_period=[None, None],
             test_period=[None, None],
-            block_nwp_and_sat=False,
             batch_dir=f"{tmpdirname}",
         )
         yield dm
@@ -139,7 +138,6 @@ def sample_datamodule():
         train_period=[None, None],
         val_period=[None, None],
         test_period=[None, None],
-        block_nwp_and_sat=False,
         batch_dir="tests/test_data/sample_batches",
     )
     return dm
3 changes: 0 additions & 3 deletions tests/data/test_datamodule.py
@@ -12,7 +12,6 @@ def test_init():
         train_period=[None, None],
         val_period=[None, None],
         test_period=[None, None],
-        block_nwp_and_sat=False,
         batch_dir="tests/test_data/sample_batches",
     )
 
@@ -39,7 +38,6 @@ def test_iter():
         train_period=[None, None],
         val_period=[None, None],
         test_period=[None, None],
-        block_nwp_and_sat=False,
         batch_dir="tests/test_data/sample_batches",
     )
 
@@ -55,7 +53,6 @@ def test_iter_multiprocessing():
         train_period=[None, None],
         val_period=[None, None],
         test_period=[None, None],
-        block_nwp_and_sat=False,
         batch_dir="tests/test_data/sample_batches",
     )
