Skip to content

Commit

Permalink
Merge pull request #18 from openclimatefix/Loosen-requirements
Browse files Browse the repository at this point in the history
Loosen requirements
  • Loading branch information
dfulu authored Apr 17, 2024
2 parents 5167eae + 956a983 commit 9231b81
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 7 deletions.
9 changes: 5 additions & 4 deletions pvnet_summation/models/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,22 +65,23 @@ def __init__(
self.output_quantiles = output_quantiles

# Number of timestemps for 30 minutely data
self.forecast_len_30 = self.forecast_minutes // 30
self.forecast_len = self.forecast_minutes // 30

self.weighted_losses = WeightedLosses(forecast_length=self.forecast_len_30)
self.weighted_losses = WeightedLosses(forecast_length=self.forecast_len)

self._accumulated_metrics = MetricAccumulator()
self._accumulated_y = PredAccumulator()
self._accumulated_y_hat = PredAccumulator()
self._accumulated_y_sum = PredAccumulator()
self._accumulated_times = PredAccumulator()
self._horizon_maes = MetricAccumulator()

self.use_quantile_regression = self.output_quantiles is not None

if self.use_quantile_regression:
self.num_output_features = self.forecast_len_30 * len(self.output_quantiles)
self.num_output_features = self.forecast_len * len(self.output_quantiles)
else:
self.num_output_features = self.forecast_len_30
self.num_output_features = self.forecast_len

if self.pvnet_model.use_quantile_regression:
self.pvnet_output_shape = (
Expand Down
2 changes: 1 addition & 1 deletion pvnet_summation/models/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def forward(self, x):

if self.use_quantile_regression:
# Shape: batch_size, seq_length * num_quantiles
out = out.reshape(out.shape[0], self.forecast_len_30, len(self.output_quantiles))
out = out.reshape(out.shape[0], self.forecast_len, len(self.output_quantiles))

if self.predict_difference_from_sum:
gsp_sum = self.sum_of_gsps(x)
Expand Down
4 changes: 2 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
ocf_datapipes==3.2.*
pvnet==3.0.*
ocf_datapipes>=3.3.19
pvnet>=3.0.25
numpy
pandas
matplotlib
Expand Down

0 comments on commit 9231b81

Please sign in to comment.