Merge pull request #21 from openclimatefix/training_fixes #minor
Training fixes #minor
dfulu authored Jun 20, 2024
2 parents 6cf05a7 + ea4bc02 commit 55599d1
Showing 5 changed files with 45 additions and 43 deletions.
4 changes: 2 additions & 2 deletions configs.example/model/default.yaml
@@ -1,4 +1,4 @@
-_target_: pvnet_summation.models.model.Model
+_target_: pvnet_summation.models.flat_model.FlatModel

output_quantiles: null

@@ -12,11 +12,11 @@ model_version: "898630f3f8cd4e8506525d813dd61c6d8de86144"
output_network:
  _target_: pvnet.models.multimodal.linear_networks.networks.ResFCNet2
+  _partial_: True
-output_network_kwargs:
  fc_hidden_features: 128
  n_res_blocks: 2
  res_block_layers: 2
  dropout_frac: 0.0

predict_difference_from_sum: False

# ----------------------------------------------
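For context on the config change above: with _partial_: True set and the separate output_network_kwargs block removed, instantiating this node with hydra yields a functools.partial of ResFCNet2 with the hyper-parameters already bound, and the model completes it later with the input and output sizes it computes itself. A minimal sketch of that behaviour, assuming standard hydra/omegaconf handling of _partial_ (the 1024 and 16 feature sizes are made-up placeholders, not values from this repo):

import hydra
from functools import partial
from omegaconf import OmegaConf

# Same structure as the new YAML block above
cfg = OmegaConf.create(
    {
        "_target_": "pvnet.models.multimodal.linear_networks.networks.ResFCNet2",
        "_partial_": True,
        "fc_hidden_features": 128,
        "n_res_blocks": 2,
        "res_block_layers": 2,
        "dropout_frac": 0.0,
    }
)

# With _partial_: True, instantiate returns a functools.partial rather than a built network
output_network = hydra.utils.instantiate(cfg)
assert isinstance(output_network, partial)

# FlatModel completes the partial itself (see flat_model.py below), roughly:
net = output_network(in_features=1024, out_features=16)  # placeholder sizes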
22 changes: 10 additions & 12 deletions pvnet_summation/data/datamodule.py
@@ -145,7 +145,6 @@ def __init__(
self.batch_dir = batch_dir

self._common_dataloader_kwargs = dict(
-            shuffle=False, # shuffled in datapipe step
batch_size=None, # batched in datapipe step
sampler=None,
batch_sampler=None,
@@ -164,7 +163,7 @@ def _get_premade_batches_datapipe(self, subdir, shuffle=False, add_filename=Fals
file_pipeline = FileLister(f"{self.batch_dir}/{subdir}", masks="*.pt", recursive=False)

if shuffle:
-            file_pipeline = file_pipeline.shuffle(buffer_size=1000)
+            file_pipeline = file_pipeline.shuffle(buffer_size=10_000)

file_pipeline = file_pipeline.sharding_filter()

@@ -228,14 +227,14 @@ def train_dataloader(self, shuffle=True, add_filename=False):
datapipe = self._get_premade_batches_datapipe(
"train", shuffle=shuffle, add_filename=add_filename
)
-        return DataLoader(datapipe, **self._common_dataloader_kwargs)
+        return DataLoader(datapipe, shuffle=shuffle, **self._common_dataloader_kwargs)

def val_dataloader(self, shuffle=False, add_filename=False):
"""Construct val dataloader"""
datapipe = self._get_premade_batches_datapipe(
"val", shuffle=shuffle, add_filename=add_filename
)
-        return DataLoader(datapipe, **self._common_dataloader_kwargs)
+        return DataLoader(datapipe, shuffle=shuffle, **self._common_dataloader_kwargs)

def test_dataloader(self):
"""Construct test dataloader"""
@@ -265,7 +264,6 @@ def __init__(
self.batch_dir = batch_dir

self._common_dataloader_kwargs = dict(
-            shuffle=False, # shuffled in datapipe step
batch_size=None, # batched in datapipe step
sampler=None,
batch_sampler=None,
@@ -284,7 +282,7 @@ def _get_premade_batches_datapipe(self, subdir, shuffle=False):
file_pipeline = FileLister(f"{self.batch_dir}/{subdir}", masks="*.pt", recursive=False)

if shuffle:
-            file_pipeline = file_pipeline.shuffle(buffer_size=1000)
+            file_pipeline = file_pipeline.shuffle(buffer_size=10_000)

sample_pipeline = file_pipeline.sharding_filter().map(torch.load)

@@ -300,21 +298,21 @@ def _get_premade_batches_datapipe(self, subdir, shuffle=False):

return batch_pipeline

-    def train_dataloader(self, shuffle=True):
+    def train_dataloader(self):
"""Construct train dataloader"""
datapipe = self._get_premade_batches_datapipe(
"train",
-            shuffle=shuffle,
+            shuffle=True,
)
-        return DataLoader(datapipe, **self._common_dataloader_kwargs)
+        return DataLoader(datapipe, shuffle=True, **self._common_dataloader_kwargs)

-    def val_dataloader(self, shuffle=False):
+    def val_dataloader(self):
"""Construct val dataloader"""
datapipe = self._get_premade_batches_datapipe(
"val",
-            shuffle=shuffle,
+            shuffle=False,
)
-        return DataLoader(datapipe, **self._common_dataloader_kwargs)
+        return DataLoader(datapipe, shuffle=False, **self._common_dataloader_kwargs)

def test_dataloader(self):
"""Construct test dataloader"""
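The datamodule change is the actual training fix: when the dataset handed to DataLoader is a torchdata IterDataPipe, the DataLoader's own shuffle flag is what enables or disables any .shuffle() step inside the pipe (at least in recent PyTorch versions, as I understand the datapipe shuffle-settings behaviour). Hard-coding shuffle=False in the shared kwargs therefore silently switched off the file-level shuffling that the comment claimed was happening in the datapipe, so the flag is now passed per dataloader, alongside a larger 10_000-entry shuffle buffer. A small self-contained illustration of that interaction, with toy integers standing in for the batch files:

import torch
from torch.utils.data import DataLoader
from torchdata.datapipes.iter import IterableWrapper

# Toy stand-in for the FileLister -> shuffle -> sharding_filter pipeline above
datapipe = IterableWrapper(list(range(10))).shuffle(buffer_size=10_000).sharding_filter()

# With an IterDataPipe, DataLoader's shuffle argument toggles the .shuffle() step:
print(list(DataLoader(datapipe, batch_size=None, shuffle=False)))  # original order
print(list(DataLoader(datapipe, batch_size=None, shuffle=True)))   # shuffled order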
pvnet_summation/models/flat_model.py
@@ -7,41 +7,39 @@
import torch
import torch.nn.functional as F
from pvnet.models.multimodal.linear_networks.basic_blocks import AbstractLinearNetwork
-from pvnet.models.multimodal.linear_networks.networks import DefaultFCNet
from pvnet.optimizers import AbstractOptimizer
from torch import nn

from pvnet_summation.models.base_model import BaseModel

-_default_optimizer = pvnet.optimizers.Adam()

-class Model(BaseModel):
-    """Neural network which combines GSP predictions from PVNet"""
+class FlatModel(BaseModel):
+    """Neural network which combines GSP predictions from PVNet naively
+
+    This model flattens all the features into a 1D vector before feeding them into the sub network
+    """

-    name = "pvnet_summation_model"
+    name = "FlatModel"

    def __init__(
        self,
+        output_network: AbstractLinearNetwork,
        model_name: str,
-        model_version: Optional[str],
+        model_version: Optional[str] = None,
        output_quantiles: Optional[list[float]] = None,
-        output_network: AbstractLinearNetwork = DefaultFCNet,
-        output_network_kwargs: Optional[dict] = None,
        relative_scale_pvnet_outputs: bool = False,
        predict_difference_from_sum: bool = False,
-        optimizer: AbstractOptimizer = _default_optimizer,
+        optimizer: AbstractOptimizer = pvnet.optimizers.Adam(),
    ):
"""Neural network which combines GSP predictions from PVNet
"""Neural network which combines GSP predictions from PVNet naively
Args:
model_name: Model path either locally or on huggingface.
model_version: Model version if using huggingface. Set to None if using local.
output_quantiles: A list of float (0.0, 1.0) quantiles to predict values for. If set to
None the output is a single value.
-            output_network: Pytorch Module class used to combine the 1D features to produce the
-                forecast.
-            output_network_kwargs: Dictionary of optional kwargs for the `output_network` module.
+            output_network: A partially instantiated pytorch Module class used to combine the 1D
+                features to produce the forecast.
relative_scale_pvnet_outputs: If true, the PVNet predictions are scaled by a factor
which is proportional to their capacities.
predict_difference_from_sum: Whether to use the sum of GSPs as an estimate for the
@@ -55,13 +53,9 @@ def __init__(
self.relative_scale_pvnet_outputs = relative_scale_pvnet_outputs
self.predict_difference_from_sum = predict_difference_from_sum

-        if output_network_kwargs is None:
-            output_network_kwargs = dict()

self.model = output_network(
in_features=np.prod(self.pvnet_output_shape),
out_features=self.num_output_features,
-            **output_network_kwargs,
)

# Add linear layer if predicting difference from sum
@@ -85,9 +79,10 @@ def forward(self, x):
else:
eff_cap = x["effective_capacity"]

-            # Multiply by (effective capacity / 100) since the capacities are roughly of magnitude
-            # of 100 MW. We still want the inputs to the network to be order of magnitude 1.
-            x_in = x["pvnet_outputs"] * (eff_cap / 100)
+            # The effective_capacit[ies] are relative fractions of the national capacity. They sum
+            # to 1 and they are quite small values. For the largest GSP the capacity is around 0.03.
+            # Therefore we apply this scaling to make the input values a more sensible size
+            x_in = x["pvnet_outputs"] * eff_cap * 100
else:
x_in = x["pvnet_outputs"]

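The forward() change follows directly from the rewritten comment: the effective capacities are now relative fractions of the national capacity that sum to 1 (around 0.03 for the largest GSP) rather than values of order 100 MW, so the PVNet outputs are multiplied by eff_cap * 100 instead of divided by 100 to keep the network inputs around order of magnitude 1. A toy numerical check, with made-up tensor shapes and capacity fractions (the real batch layout in the repo may differ):

import torch

# Made-up sizes: batch of 2, 4 GSPs, 3 forecast steps (illustration only)
pvnet_outputs = torch.rand(2, 4, 3)

# A few illustrative capacity fractions; in the real data the fractions over all
# GSPs sum to 1, with the largest around 0.03
eff_cap = torch.tensor([0.030, 0.010, 0.004, 0.001]).reshape(1, 4, 1)

x_in = pvnet_outputs * eff_cap * 100
print(x_in.max())  # scaled inputs land around order of magnitude 1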
27 changes: 18 additions & 9 deletions tests/conftest.py
@@ -8,8 +8,8 @@
import math
import glob
import tempfile
-from pvnet_summation.models.model import Model

+import hydra
+from pvnet_summation.models.flat_model import FlatModel

from ocf_datapipes.batch import BatchKey
from datetime import timedelta
@@ -134,22 +134,31 @@ def sample_batch(sample_datamodule):


@pytest.fixture()
-def model_kwargs():
-    # These kwargs define the pvnet model which the summation model uses
+def flat_model_kwargs():
    kwargs = dict(
+        # These kwargs define the pvnet model which the summation model uses
        model_name="openclimatefix/pvnet_v2",
        model_version="4203e12e719efd93da641c43d2e38527648f4915",
+        # These kwargs define the structure of the summation model
+        output_network=dict(
+            _target_="pvnet.models.multimodal.linear_networks.networks.ResFCNet2",
+            _partial_=True,
+            fc_hidden_features=128,
+            n_res_blocks=2,
+            res_block_layers=2,
+            dropout_frac=0.0,
+        ),
    )
-    return kwargs
+    return hydra.utils.instantiate(kwargs)


@pytest.fixture()
-def model(model_kwargs):
-    model = Model(**model_kwargs)
+def model(flat_model_kwargs):
+    model = FlatModel(**flat_model_kwargs)
return model


@pytest.fixture()
-def quantile_model(model_kwargs):
-    model = Model(output_quantiles=[0.1, 0.5, 0.9], **model_kwargs)
+def quantile_model(flat_model_kwargs):
+    model = FlatModel(output_quantiles=[0.1, 0.5, 0.9], **flat_model_kwargs)
return model
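The reworked fixtures mirror the YAML config: hydra.utils.instantiate recursively turns the nested output_network entry into a partially instantiated ResFCNet2, while plain strings such as model_name pass through untouched, so FlatModel(**flat_model_kwargs) receives a ready-to-call network factory. Stripped of the fixture machinery, the two model fixtures boil down to roughly the sketch below (constructing FlatModel loads the pretrained PVNet model named by model_name/model_version, so the weights need to be reachable locally or on HuggingFace):

import hydra
from pvnet_summation.models.flat_model import FlatModel

kwargs = hydra.utils.instantiate(
    dict(
        model_name="openclimatefix/pvnet_v2",
        model_version="4203e12e719efd93da641c43d2e38527648f4915",
        output_network=dict(
            _target_="pvnet.models.multimodal.linear_networks.networks.ResFCNet2",
            _partial_=True,
            fc_hidden_features=128,
            n_res_blocks=2,
            res_block_layers=2,
            dropout_frac=0.0,
        ),
    )
)

model = FlatModel(**kwargs)                                             # the `model` fixture
quantile_model = FlatModel(output_quantiles=[0.1, 0.5, 0.9], **kwargs)  # the `quantile_model` fixture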
File renamed without changes.
