Merge pull request #3 from openclimatefix/gsp_sum_deviation

GSP sum deviation

dfulu authored Jul 24, 2023
2 parents 8287321 + 86ae916 commit 95112e0

Showing 9 changed files with 143 additions and 48 deletions.
2 changes: 1 addition & 1 deletion configs/datamodule/default.yaml
@@ -1,6 +1,6 @@
 _target_: pvnet_summation.data.datamodule.DataModule
 batch_dir: "/mnt/disks/bigbatches/concurrent_batches_v3.6_-60mins"
 gsp_zarr_path: "/mnt/disks/nwp/pv_gsp.zarr"
-batch_size: 8
+batch_size: 32
 num_workers: 20
 prefetch_factor: 2
6 changes: 2 additions & 4 deletions configs/model/default.yaml
@@ -14,12 +14,10 @@ output_network:
   _partial_: True
 output_network_kwargs:
   fc_hidden_features: 128
-  n_res_blocks: 6
+  n_res_blocks: 2
   res_block_layers: 2
   dropout_frac: 0.0
 
-# Foreast and time settings
-forecast_minutes: 480
+predict_difference_from_sum: False
 
 # ----------------------------------------------
 
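The new `predict_difference_from_sum` flag suggests the summation head can learn a correction to the capacity-weighted sum of the GSP forecasts rather than predicting the national total directly. A minimal sketch of that idea, with all names hypothetical (this is not the repo's actual forward pass):

```python
import torch


def national_forecast(
    nn_out: torch.Tensor,
    gsp_sum: torch.Tensor,
    predict_difference_from_sum: bool = False,
) -> torch.Tensor:
    """Illustrative only: combine the network output with the GSP sum."""
    if predict_difference_from_sum:
        # Interpret the network output as a deviation from the summed forecast
        return gsp_sum + nn_out
    # Otherwise the network predicts the national total directly
    return nn_out
```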
2 changes: 1 addition & 1 deletion configs/trainer/default.yaml
@@ -11,7 +11,7 @@ num_sanity_val_steps: 8
 fast_dev_run: false
 #profiler: 'simple'
 
-accumulate_grad_batches: 4
+#accumulate_grad_batches: 4
 #val_check_interval: 800
 #limit_val_batches: 800
 log_every_n_steps: 50
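Disabling gradient accumulation here pairs with the `batch_size: 8` → `32` change in the datamodule config: the effective batch size stays 32, with one optimizer step per batch instead of every four. For reference, a sketch of how either setup is expressed in Lightning (the import path is an assumption; older versions use `pytorch_lightning`):

```python
import lightning.pytorch as pl  # assumption: may be `pytorch_lightning` in this repo's era

# Before: batches of 8, gradients accumulated over 4 batches -> effective batch of 32
trainer_accum = pl.Trainer(accumulate_grad_batches=4)

# After: batches of 32, optimizer step every batch -> effective batch still 32
trainer_plain = pl.Trainer(accumulate_grad_batches=1)
```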
31 changes: 26 additions & 5 deletions pvnet_summation/data/datamodule.py
Expand Up @@ -60,7 +60,7 @@ def __init__(self, source_datapipe):
"""Convert list of dicts to dict of lists
Args:
source_datapipe:
source_datapipe: Datapipe yielding lists of dicts
"""
self.source_datapipe = source_datapipe

@@ -105,7 +105,17 @@ def __init__(self, **datapipes):

     def __iter__(self):
         for outputs in self.source_datapipes:
-            yield {key: value for key, value in zip(self.keys, outputs)}
+            yield {key: value for key, value in zip(self.keys, outputs)}  # noqa: B905
+
+
+def get_capacity(batch):
+    """Extract the capacity from the numpy batch"""
+    return batch[BatchKey.gsp_effective_capacity_mwp]
+
+
+def divide(args):
+    """Divide first argument by second"""
+    return args[0] / args[1]
 
 
 class DataModule(LightningDataModule):
@@ -161,24 +171,33 @@ def _get_premade_batches_datapipe(self, subdir, shuffle=False, add_filename=Fals
         )
 
         sample_pipeline, sample_pipeline_copy = sample_pipeline.fork(2, buffer_size=5)
+        times_datapipe = GetBatchTime(sample_pipeline_copy)
 
-        times_datapipe, times_datapipe_copy = GetBatchTime(sample_pipeline_copy).fork(
-            2, buffer_size=5
+        times_datapipe, times_datapipe_copy = times_datapipe.fork(2, buffer_size=5)
+        national_targets_datapipe = GetNationalPVLive(gsp_data, times_datapipe_copy)
+
+        times_datapipe, times_datapipe_copy = times_datapipe.fork(2, buffer_size=5)
+        national_capacity_datapipe = GetNationalPVLive(
+            gsp_data.effective_capacity_mwp, times_datapipe_copy
         )
+        sample_pipeline, sample_pipeline_copy = sample_pipeline.fork(2, buffer_size=5)
+        gsp_capacity_pipeline = sample_pipeline_copy.map(get_capacity)
 
-        national_targets_datapipe = GetNationalPVLive(gsp_data, times_datapipe_copy)
+        capacity_pipeline = gsp_capacity_pipeline.zip(national_capacity_datapipe).map(divide)
 
         # Compile the samples
         if add_filename:
             data_pipeline = ZipperDict(
                 pvnet_inputs=sample_pipeline,
+                effective_capacity=capacity_pipeline,
                 national_targets=national_targets_datapipe,
                 times=times_datapipe,
                 filepath=file_pipeline_copy,
             )
         else:
             data_pipeline = ZipperDict(
                 pvnet_inputs=sample_pipeline,
+                effective_capacity=capacity_pipeline,
                 national_targets=national_targets_datapipe,
                 times=times_datapipe,
             )
@@ -187,6 +206,7 @@ def _get_premade_batches_datapipe(self, subdir, shuffle=False, add_filename=Fals
             data_pipeline = PivotDictList(data_pipeline.batch(self.batch_size))
             data_pipeline = DictApply(
                 data_pipeline,
+                effective_capacity=torch.stack,
                 national_targets=torch.stack,
                 times=torch.stack,
             )
@@ -256,6 +276,7 @@ def _get_premade_batches_datapipe(self, subdir, shuffle=False):
         batch_pipeline = PivotDictList(sample_pipeline.batch(self.batch_size))
         batch_pipeline = DictApply(
             batch_pipeline,
+            effective_capacity=torch.stack,
             pvnet_outputs=torch.stack,
             national_targets=torch.stack,
             times=torch.stack,
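A minimal, self-contained sketch of the `fork` / `zip` / `map` pattern used above to turn per-GSP capacities into fractions of the national capacity. `IterableWrapper` and the toy numbers are stand-ins, not part of this diff:

```python
from torchdata.datapipes.iter import IterableWrapper

# Toy stand-ins: per-batch GSP capacities and the matching national capacities
gsp_capacity = IterableWrapper([100.0, 200.0, 300.0])
national_capacity = IterableWrapper([1000.0, 1000.0, 1000.0])

# fork() lets one stream feed two consumers without re-reading the source
caps_a, caps_b = gsp_capacity.fork(2, buffer_size=5)

# zip() pairs elements up; map() applies a function to each pair,
# mirroring the module-level divide() helper added in this diff
relative_capacity = caps_a.zip(national_capacity).map(lambda pair: pair[0] / pair[1])

print(list(relative_capacity))  # [0.1, 0.2, 0.3]
print(list(caps_b))             # the fork replays [100.0, 200.0, 300.0]
```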
63 changes: 46 additions & 17 deletions pvnet_summation/models/base_model.py
@@ -16,9 +16,6 @@

 from pvnet_summation.utils import plot_forecasts
 
-# from pvnet.models.base_model import BaseModel as PVNetBaseModel
-
-
 logger = logging.getLogger(__name__)
 
 activities = [torch.profiler.ProfilerActivity.CPU]
@@ -31,7 +28,6 @@ class BaseModel(PVNetBaseModel):

     def __init__(
         self,
-        forecast_minutes: int,
         model_name: str,
         model_version: Optional[str],
         optimizer: AbstractOptimizer,
@@ -40,7 +36,6 @@ def __init__(
"""Abtstract base class for PVNet summation submodels.
Args:
forecast_minutes (int): Length of the GSP forecast period in minutes
model_name: Model path either locally or on huggingface.
model_version: Model version if using huggingface. Set to None if using local.
optimizer (AbstractOptimizer): Optimizer
@@ -50,46 +45,58 @@
         pl.LightningModule.__init__(self)
         PVNetModelHubMixin.__init__(self)
 
+        self.pvnet_model = PVNetBaseModel.from_pretrained(
+            model_name,
+            revision=model_version,
+        )
+        self.pvnet_model.requires_grad_(False)
+
         self._optimizer = optimizer
 
         # Model must have lr to allow tuning
         # This setting is only used when lr is tuned with callback
         self.lr = None
 
-        self.forecast_minutes = forecast_minutes
+        self.forecast_minutes = self.pvnet_model.forecast_minutes
         self.output_quantiles = output_quantiles
 
         # Number of timesteps for 30 minutely data
-        self.forecast_len_30 = forecast_minutes // 30
+        self.forecast_len_30 = self.forecast_minutes // 30
 
         self.weighted_losses = WeightedLosses(forecast_length=self.forecast_len_30)
 
         self._accumulated_metrics = MetricAccumulator()
         self._accumulated_y = PredAccumulator()
         self._accumulated_y_hat = PredAccumulator()
+        self._accumulated_y_sum = PredAccumulator()
         self._accumulated_times = PredAccumulator()
 
-        self.pvnet_model = PVNetBaseModel.from_pretrained(
-            model_name,
-            revision=model_version,
-        )
-        self.pvnet_model.requires_grad_(False)
-
     def predict_pvnet_batch(self, batch):
         """Use PVNet model to create predictions for batch"""
         gsp_batches = []
         for sample in batch:
             preds = self.pvnet_model(sample)
             gsp_batches += [preds]
         return torch.stack(gsp_batches)
 
+    def sum_of_gsps(self, x):
+        """Compute the sum of the GSP-level predictions"""
+        if self.pvnet_model.use_quantile_regression:
+            y_hat = self.pvnet_model._quantiles_to_prediction(x["pvnet_outputs"])
+        else:
+            y_hat = x["pvnet_outputs"]
+
+        return (y_hat * x["effective_capacity"]).sum(dim=1)
+
     @property
     def pvnet_output_shape(self):
         """Return the expected shape of the PVNet outputs"""
         if self.pvnet_model.use_quantile_regression:
             return (317, self.pvnet_model.forecast_len_30, len(self.pvnet_model.output_quantiles))
         else:
             return (317, self.pvnet_model.forecast_len_30)
 
-    def _training_accumulate_log(self, batch_idx, losses, y_hat, y, times):
+    def _training_accumulate_log(self, batch_idx, losses, y_hat, y, y_sum, times):
         """Internal function to accumulate training batches and log results.
 
         This is used when accumulating grad batches. Should make the variability in logged training
@@ -103,12 +110,14 @@ def _training_accumulate_log(self, batch_idx, losses, y_hat, y, times):
         self._accumulated_metrics.append(losses)
         self._accumulated_y_hat.append(y_hat)
         self._accumulated_y.append(y)
+        self._accumulated_y_sum.append(y_sum)
         self._accumulated_times.append(times)
 
         if not self.trainer.fit_loop._should_accumulate():
             losses = self._accumulated_metrics.flush()
             y_hat = self._accumulated_y_hat.flush()
             y = self._accumulated_y.flush()
+            y_sum = self._accumulated_y_sum.flush()
             times = self._accumulated_times.flush()
 
             self.log_dict(
@@ -123,7 +132,14 @@ def _training_accumulate_log(self, batch_idx, losses, y_hat, y, times):
             # We only create the figure every 8 log steps
             # This was reduced as it was creating figures too often
             if grad_batch_num % (8 * self.trainer.log_every_n_steps) == 0:
-                fig = plot_forecasts(y, y_hat, times, batch_idx, quantiles=self.output_quantiles)
+                fig = plot_forecasts(
+                    y,
+                    y_hat,
+                    times,
+                    batch_idx,
+                    quantiles=self.output_quantiles,
+                    y_sum=y_sum,
+                )
 
                 fig.savefig("latest_logged_train_batch.png")
 
     def training_step(self, batch, batch_idx):
@@ -132,11 +148,12 @@ def training_step(self, batch, batch_idx):
         y_hat = self.forward(batch)
         y = batch["national_targets"]
         times = batch["times"]
+        y_sum = self.sum_of_gsps(batch)
 
         losses = self._calculate_common_losses(y, y_hat)
         losses = {f"{k}/train": v for k, v in losses.items()}
 
-        self._training_accumulate_log(batch_idx, losses, y_hat, y, times)
+        self._training_accumulate_log(batch_idx, losses, y_hat, y, y_sum, times)
 
         if self.use_quantile_regression:
             opt_target = losses["quantile_loss/train"]
@@ -150,6 +167,7 @@ def validation_step(self, batch: dict, batch_idx):
         y_hat = self.forward(batch)
         y = batch["national_targets"]
         times = batch["times"]
+        y_sum = self.sum_of_gsps(batch)
 
         losses = self._calculate_common_losses(y, y_hat)
         losses.update(self._calculate_val_losses(y, y_hat))
@@ -169,19 +187,29 @@ def validation_step(self, batch: dict, batch_idx):
         if not hasattr(self, "_val_y_hats"):
             self._val_y_hats = PredAccumulator()
             self._val_y = PredAccumulator()
+            self._val_y_sum = PredAccumulator()
             self._val_times = PredAccumulator()
 
         self._val_y_hats.append(y_hat)
         self._val_y.append(y)
+        self._val_y_sum.append(y_sum)
         self._val_times.append(times)
 
         # if batch had accumulated
         if (batch_idx + 1) % self.trainer.accumulate_grad_batches == 0:
             y_hat = self._val_y_hats.flush()
             y = self._val_y.flush()
+            y_sum = self._val_y_sum.flush()
             times = self._val_times.flush()
 
-            fig = plot_forecasts(y, y_hat, times, batch_idx, quantiles=self.output_quantiles)
+            fig = plot_forecasts(
+                y,
+                y_hat,
+                times,
+                batch_idx,
+                quantiles=self.output_quantiles,
+                y_sum=y_sum,
+            )
 
             self.logger.experiment.log(
                 {
@@ -190,6 +218,7 @@ def validation_step(self, batch: dict, batch_idx):
             )
             del self._val_y_hats
             del self._val_y
+            del self._val_y_sum
             del self._val_times
 
         return logged_losses
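For intuition, a toy sketch of what `sum_of_gsps` computes, with made-up shapes (batch of 2, 3 GSPs, 4 half-hour steps; the real model uses 317 GSPs, and the exact capacity tensor shape here is an assumption):

```python
import torch

# Toy stand-ins: per-GSP normalised predictions and effective capacities
y_hat = torch.rand(2, 3, 4)               # (batch, gsp, forecast_step)
effective_capacity = torch.rand(2, 3, 1)  # per-GSP share of national capacity

# Weight each GSP's normalised output by its capacity share, then sum over
# the GSP dimension to get a capacity-weighted national estimate per step
y_sum = (y_hat * effective_capacity).sum(dim=1)
print(y_sum.shape)  # torch.Size([2, 4])
```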