Skip to content

Commit

Permalink
Merge pull request #16 from openclimatefix/remove_properties
Browse files Browse the repository at this point in the history
Remove properties
  • Loading branch information
dfulu authored Feb 29, 2024
2 parents 714ff80 + 9eab67b commit fd32406
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 17 deletions.
24 changes: 16 additions & 8 deletions pvnet_summation/models/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,22 @@ def __init__(
self._accumulated_y_sum = PredAccumulator()
self._accumulated_times = PredAccumulator()

self.use_quantile_regression = self.output_quantiles is not None

if self.use_quantile_regression:
self.num_output_features = self.forecast_len_30 * len(self.output_quantiles)
else:
self.num_output_features = self.forecast_len_30

if self.pvnet_model.use_quantile_regression:
self.pvnet_output_shape = (
317,
self.pvnet_model.forecast_len,
len(self.pvnet_model.output_quantiles),
)
else:
self.pvnet_output_shape = (317, self.pvnet_model.forecast_len)

def predict_pvnet_batch(self, batch):
"""Use PVNet model to create predictions for batch"""
gsp_batches = []
Expand All @@ -92,14 +108,6 @@ def sum_of_gsps(self, x):

return (y_hat * x["effective_capacity"]).sum(dim=1)

@property
def pvnet_output_shape(self):
"""Return the expected shape of the PVNet outputs"""
if self.pvnet_model.use_quantile_regression:
return (317, self.pvnet_model.forecast_len_30, len(self.pvnet_model.output_quantiles))
else:
return (317, self.pvnet_model.forecast_len_30)

def _training_accumulate_log(self, batch_idx, losses, y_hat, y, y_sum, times):
"""Internal function to accumulate training batches and log results.
Expand Down
10 changes: 2 additions & 8 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
ocf_datapipes==3.2.*
pvnet==2.6.*
ocf_ml_metrics==0.0.*
pvnet==3.0.*
numpy
pandas
matplotlib
Expand All @@ -9,20 +8,15 @@ ipykernel
h5netcdf
torch>=2.0.0
lightning>=2.0.1
torchdata
pytest
pytest-cov
typer
sqlalchemy
jedi
fsspec[s3]
tables
tilemapbase
testcontainers
wandb
tensorboard
tqdm
rich
omegaconf
hydra-core
python-dotenv
huggingface-hub==0.20.*
2 changes: 1 addition & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ def model_kwargs():
# These kwargs define the pvnet model which the summation model uses
kwargs = dict(
model_name="openclimatefix/pvnet_v2",
model_version="22e577100d55787eb2547d701275b9bb48f7bfa0",
model_version="4203e12e719efd93da641c43d2e38527648f4915",
)
return kwargs

Expand Down

0 comments on commit fd32406

Please sign in to comment.