diff --git a/pvnet/models/base_model.py b/pvnet/models/base_model.py index 5f59400f..e9ebee20 100644 --- a/pvnet/models/base_model.py +++ b/pvnet/models/base_model.py @@ -299,17 +299,16 @@ def __init__( self._accumulated_metrics = MetricAccumulator() self._accumulated_batches = BatchAccumulator(key_to_keep=self._target_key_name) self._accumulated_y_hat = PredAccumulator() - + # Store whether the model should use quantile regression or simply predict the mean self.use_quantile_regression = self.output_quantiles is not None - + # Store the number of ouput features that the model should predict for if self.use_quantile_regression: self.num_output_features = self.forecast_len * len(self.output_quantiles) else: self.num_output_features = self.forecast_len - def _quantiles_to_prediction(self, y_quantiles): """ Convert network prediction into a point prediction.