diff --git a/darts/models/forecasting/pl_forecasting_module.py b/darts/models/forecasting/pl_forecasting_module.py index 6318afb4cd..b8c96623a9 100644 --- a/darts/models/forecasting/pl_forecasting_module.py +++ b/darts/models/forecasting/pl_forecasting_module.py @@ -400,11 +400,18 @@ def _update_metrics(self, output, target, metrics): return if self.likelihood: - metrics.update(self.likelihood.sample(output), target) + pred = self.likelihood.sample(output) else: # If there's no likelihood, nr_params=1, and we need to squeeze out the # last dimension of model output, for properly computing the metric. - metrics.update(output.squeeze(dim=-1), target) + pred = output.squeeze(dim=-1) + + # torch metrics require 2D targets of shape (batch size * ocl, num targets) + if self.n_targets > 1: + target = target.reshape(-1, self.n_targets) + pred = pred.reshape(-1, self.n_targets) + + metrics.update(pred, target) def _compute_metrics(self, metrics): if not len(metrics): diff --git a/darts/tests/models/forecasting/test_torch_forecasting_model.py b/darts/tests/models/forecasting/test_torch_forecasting_model.py index 2668d8d767..9ae63e80f7 100644 --- a/darts/tests/models/forecasting/test_torch_forecasting_model.py +++ b/darts/tests/models/forecasting/test_torch_forecasting_model.py @@ -1310,7 +1310,7 @@ def test_metrics(self): 10, 10, n_epochs=1, - torch_metrics=metric, + torch_metrics=metric_collection, pl_trainer_kwargs=model_kwargs, ) model.fit(self.multivariate_series)