Skip to content

Commit

Permalink
merge
Browse files Browse the repository at this point in the history
  • Loading branch information
dfulu committed Jan 8, 2024
1 parent 3b74b1d commit bd0887b
Show file tree
Hide file tree
Showing 6 changed files with 10 additions and 440 deletions.
2 changes: 1 addition & 1 deletion configs.example/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ defaults:
renewable:
"pv"

# enable color logging
# enable color logging
# - override hydra/hydra_logging: colorlog
# - override hydra/job_logging: colorlog

Expand Down
15 changes: 5 additions & 10 deletions pvnet/data/datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
from pvnet.data.utils import batch_to_tensor




class DataModule(LightningDataModule):
"""Datamodule for training pvnet and using pvnet pipeline in `ocf_datapipes`."""

Expand All @@ -23,7 +25,6 @@ def __init__(
train_period=[None, None],
val_period=[None, None],
test_period=[None, None],
block_nwp_and_sat=False,
batch_dir=None,
):
"""Datamodule for training pvnet and using pvnet pipeline in `ocf_datapipes`.
Expand All @@ -39,16 +40,13 @@ def __init__(
train_period: Date range filter for train dataloader.
val_period: Date range filter for val dataloader.
test_period: Date range filter for test dataloader.
block_nwp_and_sat: If True, the dataloader does not load the requested NWP and sat data.
It instead returns a zero-array of the required shape. Useful for pretraining.
batch_dir: Path to the directory of pre-saved batches. Cannot be used together with
`configuration` or `train/val/test_period`.
"""
super().__init__()
self.configuration = configuration
self.batch_size = batch_size
self.block_nwp_and_sat = block_nwp_and_sat
self.batch_dir = batch_dir

if not ((batch_dir is not None) ^ (configuration is not None)):
Expand All @@ -69,7 +67,6 @@ def __init__(
]

self._common_dataloader_kwargs = dict(
shuffle=False, # shuffled in datapipe step
batch_size=None, # batched in datapipe step
sampler=None,
batch_sampler=None,
Expand All @@ -88,8 +85,6 @@ def _get_datapipe(self, start_time, end_time):
self.configuration,
start_time=start_time,
end_time=end_time,
block_sat=self.block_nwp_and_sat,
block_nwp=self.block_nwp_and_sat,
)

data_pipeline = (
Expand Down Expand Up @@ -131,20 +126,20 @@ def train_dataloader(self):
datapipe = self._get_premade_batches_datapipe("train", shuffle=True)
else:
datapipe = self._get_datapipe(*self.train_period)
return DataLoader(datapipe, **self._common_dataloader_kwargs)
return DataLoader(datapipe, shuffle=True, **self._common_dataloader_kwargs)

def val_dataloader(self):
"""Construct val dataloader"""
if self.batch_dir is not None:
datapipe = self._get_premade_batches_datapipe("val")
else:
datapipe = self._get_datapipe(*self.val_period)
return DataLoader(datapipe, **self._common_dataloader_kwargs)
return DataLoader(datapipe, shuffle=False, **self._common_dataloader_kwargs)

def test_dataloader(self):
"""Construct test dataloader"""
if self.batch_dir is not None:
datapipe = self._get_premade_batches_datapipe("test")
else:
datapipe = self._get_datapipe(*self.test_period)
return DataLoader(datapipe, **self._common_dataloader_kwargs)
return DataLoader(datapipe, shuffle=False, **self._common_dataloader_kwargs)
4 changes: 2 additions & 2 deletions pvnet/models/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def __bool__(self):

def append(self, y_hat: torch.Tensor):
"""Append a sub-batch of predictions"""
self._y_hats += [y_hat]
self._y_hats.append(y_hat)

def flush(self) -> torch.Tensor:
"""Return all appended predictions as single tensor and remove from accumulated store."""
Expand All @@ -42,7 +42,7 @@ class DictListAccumulator:
@staticmethod
def _dict_list_append(d1, d2):
for k, v in d2.items():
d1[k] += [v]
d1[k].append(v)

@staticmethod
def _dict_init_list(d):
Expand Down
1 change: 0 additions & 1 deletion pvnet/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,6 @@ def _get_numpy(key):
times_utc = batch[BatchKey.gsp_time_utc].cpu().numpy().squeeze().astype("datetime64[s]")
times_utc = [pd.to_datetime(t) for t in times_utc]

len(times_utc[0]) - t0_idx - 1
batch_size = y.shape[0]

fig, axes = plt.subplots(4, 4, figsize=(8, 8))
Expand Down
4 changes: 2 additions & 2 deletions scripts/checkpoint_to_huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,9 @@ def push_to_huggingface(
# Only one epoch (best) saved per model
files = glob.glob(f"{checkpoint_dir_path}/epoch*.ckpt")
assert len(files) == 1
checkpoint = torch.load(files[0])
checkpoint = torch.load(files[0], map_location="cpu")
else:
checkpoint = torch.load(f"{checkpoint_dir_path}/last.ckpt")
checkpoint = torch.load(f"{checkpoint_dir_path}/last.ckpt", map_location="cpu")

model.load_state_dict(state_dict=checkpoint["state_dict"])

Expand Down
Loading

0 comments on commit bd0887b

Please sign in to comment.