diff --git a/india_forecast_app/models/pvnet/model.py b/india_forecast_app/models/pvnet/model.py index 6a1e38a..fff901e 100644 --- a/india_forecast_app/models/pvnet/model.py +++ b/india_forecast_app/models/pvnet/model.py @@ -11,7 +11,7 @@ import numpy as np import pandas as pd import torch -from ocf_datapipes.batch import batch_to_tensor, copy_batch_to_device, stack_np_examples_into_batch +from ocf_datapipes.batch import batch_to_tensor, copy_batch_to_device, stack_np_examples_into_batch, BatchKey from ocf_datapipes.training.pvnet_site import construct_sliced_data_pipeline as pv_base_pipeline from ocf_datapipes.training.windnet import DictDatasetIterDataPipe, split_dataset_dict_dp from ocf_datapipes.training.windnet import construct_sliced_data_pipeline as wind_base_pipeline @@ -95,6 +95,10 @@ def predict(self, site_id: str, timestamp: dt.datetime): for i, batch in enumerate(self.dataloader): log.info(f"Predicting for batch: {i}") + if self.name == 'windnet_ad_sites_generation_delay': + # this is a bit of hack, but it's important to do what was done in training + batch[BatchKey.wind][int(batch[:, BatchKey.wind_t0_idx])] = -1 + # save batch save_batch(batch=batch, i=i, model_name=self.name, site_uuid=self.site_uuid)