From 89e4e75cfe469a06e0ff0f81bbd98ade69c9f232 Mon Sep 17 00:00:00 2001 From: peterdudfield Date: Thu, 12 Sep 2024 12:29:18 +0100 Subject: [PATCH] update for quantile loss --- pvnet/models/base_model.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/pvnet/models/base_model.py b/pvnet/models/base_model.py index 29b909b2..5dfb553d 100644 --- a/pvnet/models/base_model.py +++ b/pvnet/models/base_model.py @@ -635,7 +635,11 @@ def _log_validation_results(self, batch, y_hat, accum_batch_num): for i in range(batch_size): y_i = y[i] - y_hat_i = y_hat[i] + if self.use_quantile_regression: + idx = self.output_quantiles.index(0.5) + y_hat_i = y_hat[i,idx] + else: + y_hat_i = y_hat[i] time_utc_i = time_utc[i] target_id_i = target_id[i]