diff --git a/pvnet/models/base_model.py b/pvnet/models/base_model.py index 29b909b2..5dfb553d 100644 --- a/pvnet/models/base_model.py +++ b/pvnet/models/base_model.py @@ -635,7 +635,11 @@ def _log_validation_results(self, batch, y_hat, accum_batch_num): for i in range(batch_size): y_i = y[i] - y_hat_i = y_hat[i] + if self.use_quantile_regression: + idx = self.output_quantiles.index(0.5) + y_hat_i = y_hat[i,idx] + else: + y_hat_i = y_hat[i] time_utc_i = time_utc[i] target_id_i = target_id[i]