Skip to content

Commit

Permalink
mask wind t0 idx, for windnet_ad_sites_generation_delay model
Browse files Browse the repository at this point in the history
  • Loading branch information
peterdudfield committed Dec 20, 2024
1 parent 6da1b51 commit 18e085c
Showing 1 changed file with 5 additions and 1 deletion.
6 changes: 5 additions & 1 deletion india_forecast_app/models/pvnet/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import numpy as np
import pandas as pd
import torch
from ocf_datapipes.batch import batch_to_tensor, copy_batch_to_device, stack_np_examples_into_batch
from ocf_datapipes.batch import batch_to_tensor, copy_batch_to_device, stack_np_examples_into_batch, BatchKey

Check failure on line 14 in india_forecast_app/models/pvnet/model.py

View workflow job for this annotation

GitHub Actions / lint_and_test / Lint the code and run the tests

Ruff (E501)

india_forecast_app/models/pvnet/model.py:14:101: E501 Line too long (109 > 100)
from ocf_datapipes.training.pvnet_site import construct_sliced_data_pipeline as pv_base_pipeline
from ocf_datapipes.training.windnet import DictDatasetIterDataPipe, split_dataset_dict_dp
from ocf_datapipes.training.windnet import construct_sliced_data_pipeline as wind_base_pipeline
Expand Down Expand Up @@ -95,6 +95,10 @@ def predict(self, site_id: str, timestamp: dt.datetime):
for i, batch in enumerate(self.dataloader):
log.info(f"Predicting for batch: {i}")

if self.name == 'windnet_ad_sites_generation_delay':
# this is a bit of hack, but it's important to do what was done in training
batch[BatchKey.wind][int(batch[:, BatchKey.wind_t0_idx])] = -1

# save batch
save_batch(batch=batch, i=i, model_name=self.name, site_uuid=self.site_uuid)

Expand Down

0 comments on commit 18e085c

Please sign in to comment.