diff --git a/pvnet_summation/models/base_model.py b/pvnet_summation/models/base_model.py index c58e4f6..6e7d541 100644 --- a/pvnet_summation/models/base_model.py +++ b/pvnet_summation/models/base_model.py @@ -75,6 +75,22 @@ def __init__( self._accumulated_y_sum = PredAccumulator() self._accumulated_times = PredAccumulator() + self.use_quantile_regression = self.output_quantiles is not None + + if self.use_quantile_regression: + self.num_output_features = self.forecast_len_30 * len(self.output_quantiles) + else: + self.num_output_features = self.forecast_len_30 + + if self.pvnet_model.use_quantile_regression: + self.pvnet_output_shape = ( + 317, + self.pvnet_model.forecast_len, + len(self.pvnet_model.output_quantiles), + ) + else: + self.pvnet_output_shape = (317, self.pvnet_model.forecast_len) + def predict_pvnet_batch(self, batch): """Use PVNet model to create predictions for batch""" gsp_batches = [] @@ -92,14 +108,6 @@ def sum_of_gsps(self, x): return (y_hat * x["effective_capacity"]).sum(dim=1) - @property - def pvnet_output_shape(self): - """Return the expected shape of the PVNet outputs""" - if self.pvnet_model.use_quantile_regression: - return (317, self.pvnet_model.forecast_len_30, len(self.pvnet_model.output_quantiles)) - else: - return (317, self.pvnet_model.forecast_len_30) - def _training_accumulate_log(self, batch_idx, losses, y_hat, y, y_sum, times): """Internal function to accumulate training batches and log results. diff --git a/requirements.txt b/requirements.txt index 3b4b122..f394115 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,5 @@ ocf_datapipes==3.2.* -pvnet==2.6.* -ocf_ml_metrics==0.0.* +pvnet==3.0.* numpy pandas matplotlib @@ -9,20 +8,15 @@ ipykernel h5netcdf torch>=2.0.0 lightning>=2.0.1 -torchdata pytest pytest-cov typer sqlalchemy -jedi fsspec[s3] -tables -tilemapbase -testcontainers wandb tensorboard tqdm -rich omegaconf hydra-core python-dotenv +huggingface-hub==0.20.* diff --git a/tests/conftest.py b/tests/conftest.py index e756706..edabfc8 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -138,7 +138,7 @@ def model_kwargs(): # These kwargs define the pvnet model which the summation model uses kwargs = dict( model_name="openclimatefix/pvnet_v2", - model_version="22e577100d55787eb2547d701275b9bb48f7bfa0", + model_version="4203e12e719efd93da641c43d2e38527648f4915", ) return kwargs