diff --git a/india_forecast_app/models/pvnet/model.py b/india_forecast_app/models/pvnet/model.py index 3fdf51f..a529cb2 100644 --- a/india_forecast_app/models/pvnet/model.py +++ b/india_forecast_app/models/pvnet/model.py @@ -102,7 +102,7 @@ def predict(self, site_id: str, timestamp: dt.datetime): if self.name == "windnet_ad_sites_generation_delay": # this is a bit of hack, but it's important to do what was done in training - batch[BatchKey.wind][int(batch[:, BatchKey.wind_t0_idx])] = -1 + batch[BatchKey.wind][:, int(batch[BatchKey.wind_t0_idx])] = -1 # save batch save_batch(batch=batch, i=i, model_name=self.name, site_uuid=self.site_uuid)