Skip to content

Commit

Permalink
fix failing torch test for torchmetrics when multi output (#2573)
Browse files Browse the repository at this point in the history
  • Loading branch information
dennisbader authored Oct 27, 2024
1 parent 264e6b3 commit 0b9efd0
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 3 deletions.
11 changes: 9 additions & 2 deletions darts/models/forecasting/pl_forecasting_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,11 +400,18 @@ def _update_metrics(self, output, target, metrics):
return

if self.likelihood:
metrics.update(self.likelihood.sample(output), target)
pred = self.likelihood.sample(output)
else:
# If there's no likelihood, nr_params=1, and we need to squeeze out the
# last dimension of model output, for properly computing the metric.
metrics.update(output.squeeze(dim=-1), target)
pred = output.squeeze(dim=-1)

# torch metrics require 2D targets of shape (batch size * ocl, num targets)
if self.n_targets > 1:
target = target.reshape(-1, self.n_targets)
pred = pred.reshape(-1, self.n_targets)

metrics.update(pred, target)

def _compute_metrics(self, metrics):
if not len(metrics):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1310,7 +1310,7 @@ def test_metrics(self):
10,
10,
n_epochs=1,
torch_metrics=metric,
torch_metrics=metric_collection,
pl_trainer_kwargs=model_kwargs,
)
model.fit(self.multivariate_series)
Expand Down

0 comments on commit 0b9efd0

Please sign in to comment.