diff --git a/darts/models/forecasting/rnn_model.py b/darts/models/forecasting/rnn_model.py index e7d55ac9c6..d51b0a3838 100644 --- a/darts/models/forecasting/rnn_model.py +++ b/darts/models/forecasting/rnn_model.py @@ -104,6 +104,12 @@ def forward( pass def _produce_train_output(self, input_batch: tuple) -> torch.Tensor: + # only return the forecast, not the hidden state + return self(self._process_input_batch(input_batch))[0] + + def _process_input_batch( + self, input_batch: tuple + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: ( past_target, historic_future_covariates, @@ -112,7 +118,7 @@ def _produce_train_output(self, input_batch: tuple) -> torch.Tensor: ) = input_batch # For the RNN we concatenate the past_target with the future_covariates # (they have the same length because we enforce a Shift dataset for RNNs) - model_input = ( + return ( ( torch.cat([past_target, future_covariates], dim=2) if future_covariates is not None @@ -120,7 +126,6 @@ def _produce_train_output(self, input_batch: tuple) -> torch.Tensor: ), static_covariates, ) - return self(model_input)[0] def _produce_predict_output( self, x: tuple, last_hidden_state: Optional[torch.Tensor] = None