From cdff09ab0646df7a25304a6b047b112f5ac78552 Mon Sep 17 00:00:00 2001
From: Dennis Bader <dennis.bader@gmx.ch>
Date: Tue, 9 Apr 2024 11:00:12 +0200
Subject: [PATCH] use pytest to skip torch tests (#2307)

* use pytest to skip torch tests

* fix some mistakes in tsmixer notebook
---
 darts/tests/datasets/test_datasets.py         | 2591 ++++++++-------
 .../explainability/test_tft_explainer.py      |  855 +++--
 darts/tests/models/components/glu_variants.py |   31 +-
 .../components/test_layer_norm_variants.py    |   57 +-
 darts/tests/models/forecasting/test_RNN.py    |  286 +-
 darts/tests/models/forecasting/test_TCN.py    |  356 +-
 darts/tests/models/forecasting/test_TFT.py    |  754 +++--
 .../models/forecasting/test_block_RNN.py      |  298 +-
 .../forecasting/test_dlinear_nlinear.py       |  590 ++--
 .../test_global_forecasting_models.py         | 1431 ++++----
 .../models/forecasting/test_nbeats_nhits.py   |  326 +-
 .../models/forecasting/test_ptl_trainer.py    |  464 ++-
 .../models/forecasting/test_tide_model.py     |  432 ++-
 .../test_torch_forecasting_model.py           | 2934 ++++++++---------
 .../forecasting/test_transformer_model.py     |  314 +-
 .../tests/models/forecasting/test_tsmixer.py  |   13 +-
 darts/tests/utils/test_likelihood_models.py   |   37 +-
 darts/tests/utils/test_losses.py              |  103 +-
 darts/tests/utils/test_utils_torch.py         |  212 +-
 docs/source/examples.rst                      |    2 +-
 examples/21-TSMixer-examples.ipynb            |   11 +-
 21 files changed, 5993 insertions(+), 6104 deletions(-)
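
Note on the recurring change: every torch test module now skips itself at
collection time instead of hiding its tests behind a TORCH_AVAILABLE flag.
A minimal sketch of the pattern (the guarded import shown here stands in
for whichever torch-dependent imports each module actually wraps):

    import pytest

    try:
        import torch  # noqa: F401
    except ImportError:
        # Module-level skip: pytest reports all tests in this module as
        # skipped when torch is not installed.
        pytest.skip(
            f"Torch not available. {__name__} tests will be skipped.",
            allow_module_level=True,
        )

Besides making the skips visible in pytest's report, this removes one
indentation level from each module, which accounts for most of the line
churn in the diffstat above.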

diff --git a/darts/tests/datasets/test_datasets.py b/darts/tests/datasets/test_datasets.py
index 92b37f105b..d4676d6ae7 100644
--- a/darts/tests/datasets/test_datasets.py
+++ b/darts/tests/datasets/test_datasets.py
@@ -29,1435 +29,1424 @@
         SplitCovariatesSequentialDataset,
         SplitCovariatesShiftedDataset,
     )
-
-    TORCH_AVAILABLE = True
 except ImportError:
-    logger.warning("Torch not installed - dataset tests will be skipped.")
-    TORCH_AVAILABLE = False
-
-if TORCH_AVAILABLE:
-
-    class TestDataset:
-        target1 = gaussian_timeseries(length=100).with_static_covariates(
-            pd.Series([0, 1], index=["st1", "st2"])
-        )
-        target2 = gaussian_timeseries(length=150).with_static_covariates(
-            pd.Series([2, 3], index=["st1", "st2"])
-        )
-        cov_st1 = target1.static_covariates.values
-        cov_st2 = target2.static_covariates.values
-        cov_st2_df = pd.Series([2, 3], index=["st1", "st2"])
-        vals1, vals2 = target1.values(), target2.values()
-        cov1, cov2 = gaussian_timeseries(length=100), gaussian_timeseries(length=150)
-
-        def _assert_eq(self, lefts: tuple, rights: tuple):
-            for left, right in zip(lefts, rights):
-                left = left.values() if isinstance(left, TimeSeries) else left
-                right = right.values() if isinstance(right, TimeSeries) else right
-                assert type(left) is type(right)
-                assert (
-                    isinstance(
-                        left, (TimeSeries, pd.Series, pd.DataFrame, np.ndarray, list)
-                    )
-                    or left is None
-                )
-                if isinstance(left, (pd.Series, pd.DataFrame)):
-                    assert left.equals(right)
-                elif isinstance(left, np.ndarray):
-                    np.testing.assert_array_equal(left, right)
-                elif isinstance(left, (list, TimeSeries)):
-                    assert left == right
-                else:
-                    assert right is None
-
-        def test_past_covariates_inference_dataset(self):
-            # one target series
-            ds = PastCovariatesInferenceDataset(
-                target_series=self.target1, input_chunk_length=len(self.target1)
-            )
-            np.testing.assert_almost_equal(ds[0][0], self.vals1)
-            self._assert_eq(ds[0][1:], (None, None, self.cov_st1, self.target1))
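+    # With allow_module_level=True, this skip raised at import time makes
+    # pytest report the whole module as skipped during collection.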
+    pytest.skip(
+        f"Torch not available. {__name__} tests will be skipped.",
+        allow_module_level=True,
+    )
 
-            # two target series
-            ds = PastCovariatesInferenceDataset(
-                target_series=[self.target1, self.target2],
-                input_chunk_length=max(len(self.target1), len(self.target2)),
-            )
-            np.testing.assert_almost_equal(ds[1][0], self.vals2)
-            self._assert_eq(ds[1][1:], (None, None, self.cov_st2, self.target2))
 
-            # fail if covariates do not have same size
-            with pytest.raises(ValueError):
-                ds = PastCovariatesInferenceDataset(
-                    target_series=[self.target1, self.target2], covariates=[self.cov1]
+class TestDataset:
+    target1 = gaussian_timeseries(length=100).with_static_covariates(
+        pd.Series([0, 1], index=["st1", "st2"])
+    )
+    target2 = gaussian_timeseries(length=150).with_static_covariates(
+        pd.Series([2, 3], index=["st1", "st2"])
+    )
+    cov_st1 = target1.static_covariates.values
+    cov_st2 = target2.static_covariates.values
+    cov_st2_df = pd.Series([2, 3], index=["st1", "st2"])
+    vals1, vals2 = target1.values(), target2.values()
+    cov1, cov2 = gaussian_timeseries(length=100), gaussian_timeseries(length=150)
+
+    def _assert_eq(self, lefts: tuple, rights: tuple):
+        for left, right in zip(lefts, rights):
+            left = left.values() if isinstance(left, TimeSeries) else left
+            right = right.values() if isinstance(right, TimeSeries) else right
+            assert type(left) is type(right)
+            assert (
+                isinstance(
+                    left, (TimeSeries, pd.Series, pd.DataFrame, np.ndarray, list)
                 )
+                or left is None
+            )
+            if isinstance(left, (pd.Series, pd.DataFrame)):
+                assert left.equals(right)
+            elif isinstance(left, np.ndarray):
+                np.testing.assert_array_equal(left, right)
+            elif isinstance(left, (list, TimeSeries)):
+                assert left == right
+            else:
+                assert right is None
+
+    def test_past_covariates_inference_dataset(self):
+        # one target series
+        ds = PastCovariatesInferenceDataset(
+            target_series=self.target1, input_chunk_length=len(self.target1)
+        )
+        np.testing.assert_almost_equal(ds[0][0], self.vals1)
+        self._assert_eq(ds[0][1:], (None, None, self.cov_st1, self.target1))
 
-            # with covariates
-            ds = PastCovariatesInferenceDataset(
-                target_series=[self.target1, self.target2],
-                covariates=[self.cov1, self.cov2],
-                input_chunk_length=max(len(self.target1), len(self.target2)),
-            )
-            np.testing.assert_almost_equal(ds[1][0], self.vals2)
-            np.testing.assert_almost_equal(ds[1][1], self.cov2.values())
-            self._assert_eq(
-                ds[1][2:], (None, self.cov_st2, self.target2)
-            )  # no "future past" covariate here
-
-            # more complex case with future past covariates:
-            times1 = pd.date_range(start="20100101", end="20100701", freq="D")
-            times2 = pd.date_range(
-                start="20100101", end="20100820", freq="D"
-            )  # 50 days longer than times1
-
-            target = TimeSeries.from_times_and_values(
-                times1, np.random.randn(len(times1))
-            ).with_static_covariates(self.cov_st2_df)
-            short_cov = TimeSeries.from_times_and_values(
-                times1, np.random.randn(len(times1))
-            )
-            long_cov = TimeSeries.from_times_and_values(
-                times2, np.random.randn(len(times2))
-            )
+        # two target series
+        ds = PastCovariatesInferenceDataset(
+            target_series=[self.target1, self.target2],
+            input_chunk_length=max(len(self.target1), len(self.target2)),
+        )
+        np.testing.assert_almost_equal(ds[1][0], self.vals2)
+        self._assert_eq(ds[1][1:], (None, None, self.cov_st2, self.target2))
 
+        # fail if the number of covariates does not match the number of target series
+        with pytest.raises(ValueError):
             ds = PastCovariatesInferenceDataset(
-                target_series=target,
-                covariates=short_cov,
-                input_chunk_length=10,
-                output_chunk_length=10,
-                n=30,
+                target_series=[self.target1, self.target2], covariates=[self.cov1]
             )
 
-            # should fail if covariates are too short
-            with pytest.raises(ValueError):
-                _ = ds[0]
-
-            # Should return correct values when covariates is long enough
-            ds = PastCovariatesInferenceDataset(
-                target_series=target,
-                covariates=long_cov,
-                input_chunk_length=10,
-                output_chunk_length=10,
-                n=30,
-            )
+        # with covariates
+        ds = PastCovariatesInferenceDataset(
+            target_series=[self.target1, self.target2],
+            covariates=[self.cov1, self.cov2],
+            input_chunk_length=max(len(self.target1), len(self.target2)),
+        )
+        np.testing.assert_almost_equal(ds[1][0], self.vals2)
+        np.testing.assert_almost_equal(ds[1][1], self.cov2.values())
+        self._assert_eq(
+            ds[1][2:], (None, self.cov_st2, self.target2)
+        )  # no "future past" covariate here
+
+        # more complex case with future past covariates:
+        times1 = pd.date_range(start="20100101", end="20100701", freq="D")
+        times2 = pd.date_range(
+            start="20100101", end="20100820", freq="D"
+        )  # 50 days longer than times1
+
+        target = TimeSeries.from_times_and_values(
+            times1, np.random.randn(len(times1))
+        ).with_static_covariates(self.cov_st2_df)
+        short_cov = TimeSeries.from_times_and_values(
+            times1, np.random.randn(len(times1))
+        )
+        long_cov = TimeSeries.from_times_and_values(
+            times2, np.random.randn(len(times2))
+        )
 
-            np.testing.assert_almost_equal(ds[0][0], target.values()[-10:])
-            np.testing.assert_almost_equal(ds[0][1], long_cov.values()[-60:-50])
-            np.testing.assert_almost_equal(ds[0][2], long_cov.values()[-50:-30])
-            np.testing.assert_almost_equal(ds[0][3], self.cov_st2)
-            assert ds[0][4] == target
-
-            # Should also work for integer-indexed series
-            target = TimeSeries.from_times_and_values(
-                pd.RangeIndex(start=10, stop=50, step=1), np.random.randn(40)
-            ).with_static_covariates(self.cov_st2_df)
-            covariate = TimeSeries.from_times_and_values(
-                pd.RangeIndex(start=20, stop=80, step=1), np.random.randn(60)
-            )
+        ds = PastCovariatesInferenceDataset(
+            target_series=target,
+            covariates=short_cov,
+            input_chunk_length=10,
+            output_chunk_length=10,
+            n=30,
+        )
 
-            ds = PastCovariatesInferenceDataset(
-                target_series=target,
-                covariates=covariate,
-                input_chunk_length=10,
-                output_chunk_length=10,
-                n=20,
-            )
+        # should fail if covariates are too short
+        with pytest.raises(ValueError):
+            _ = ds[0]
+
+        # Should return correct values when covariates are long enough
+        ds = PastCovariatesInferenceDataset(
+            target_series=target,
+            covariates=long_cov,
+            input_chunk_length=10,
+            output_chunk_length=10,
+            n=30,
+        )
 
-            np.testing.assert_almost_equal(ds[0][0], target.values()[-10:])
-            np.testing.assert_almost_equal(ds[0][1], covariate.values()[20:30])
-            np.testing.assert_almost_equal(ds[0][2], covariate.values()[30:40])
-            np.testing.assert_almost_equal(ds[0][3], self.cov_st2)
-            assert ds[0][4] == target
+        np.testing.assert_almost_equal(ds[0][0], target.values()[-10:])
+        np.testing.assert_almost_equal(ds[0][1], long_cov.values()[-60:-50])
+        np.testing.assert_almost_equal(ds[0][2], long_cov.values()[-50:-30])
+        np.testing.assert_almost_equal(ds[0][3], self.cov_st2)
+        assert ds[0][4] == target
+
+        # Should also work for integer-indexed series
+        target = TimeSeries.from_times_and_values(
+            pd.RangeIndex(start=10, stop=50, step=1), np.random.randn(40)
+        ).with_static_covariates(self.cov_st2_df)
+        covariate = TimeSeries.from_times_and_values(
+            pd.RangeIndex(start=20, stop=80, step=1), np.random.randn(60)
+        )
 
-        def test_future_covariates_inference_dataset(self):
-            # one target series
-            ds = FutureCovariatesInferenceDataset(
-                target_series=self.target1, input_chunk_length=len(self.target1)
-            )
-            np.testing.assert_almost_equal(ds[0][0], self.vals1)
-            self._assert_eq(ds[0][1:], (None, self.cov_st1, self.target1))
+        ds = PastCovariatesInferenceDataset(
+            target_series=target,
+            covariates=covariate,
+            input_chunk_length=10,
+            output_chunk_length=10,
+            n=20,
+        )
 
-            # two target series
-            ds = FutureCovariatesInferenceDataset(
-                target_series=[self.target1, self.target2],
-                input_chunk_length=max(len(self.target1), len(self.target2)),
-            )
-            np.testing.assert_almost_equal(ds[1][0], self.vals2)
-            self._assert_eq(ds[1][1:], (None, self.cov_st2, self.target2))
+        np.testing.assert_almost_equal(ds[0][0], target.values()[-10:])
+        np.testing.assert_almost_equal(ds[0][1], covariate.values()[20:30])
+        np.testing.assert_almost_equal(ds[0][2], covariate.values()[30:40])
+        np.testing.assert_almost_equal(ds[0][3], self.cov_st2)
+        assert ds[0][4] == target
 
-            # fail if covariates do not have same size
-            with pytest.raises(ValueError):
-                ds = FutureCovariatesInferenceDataset(
-                    target_series=[self.target1, self.target2], covariates=[self.cov1]
-                )
+    def test_future_covariates_inference_dataset(self):
+        # one target series
+        ds = FutureCovariatesInferenceDataset(
+            target_series=self.target1, input_chunk_length=len(self.target1)
+        )
+        np.testing.assert_almost_equal(ds[0][0], self.vals1)
+        self._assert_eq(ds[0][1:], (None, self.cov_st1, self.target1))
 
-            # With future past covariates:
-            times1 = pd.date_range(start="20100101", end="20100701", freq="D")
-            times2 = pd.date_range(
-                start="20100101", end="20100820", freq="D"
-            )  # 50 days longer than times1
-
-            target = TimeSeries.from_times_and_values(
-                times1, np.random.randn(len(times1))
-            ).with_static_covariates(self.cov_st2_df)
-            short_cov = TimeSeries.from_times_and_values(
-                times1, np.random.randn(len(times1))
-            )
-            long_cov = TimeSeries.from_times_and_values(
-                times2, np.random.randn(len(times2))
-            )
+        # two target series
+        ds = FutureCovariatesInferenceDataset(
+            target_series=[self.target1, self.target2],
+            input_chunk_length=max(len(self.target1), len(self.target2)),
+        )
+        np.testing.assert_almost_equal(ds[1][0], self.vals2)
+        self._assert_eq(ds[1][1:], (None, self.cov_st2, self.target2))
 
+        # fail if the number of covariates does not match the number of target series
+        with pytest.raises(ValueError):
             ds = FutureCovariatesInferenceDataset(
-                target_series=target, covariates=short_cov, input_chunk_length=10, n=30
+                target_series=[self.target1, self.target2], covariates=[self.cov1]
             )
 
-            # should fail if covariates are too short
-            with pytest.raises(ValueError):
-                _ = ds[0]
-
-            # Should return correct values when covariates is long enough
-            ds = FutureCovariatesInferenceDataset(
-                target_series=target, covariates=long_cov, input_chunk_length=10, n=30
-            )
+        # With future past covariates:
+        times1 = pd.date_range(start="20100101", end="20100701", freq="D")
+        times2 = pd.date_range(
+            start="20100101", end="20100820", freq="D"
+        )  # 50 days longer than times1
 
-            np.testing.assert_almost_equal(ds[0][0], target.values()[-10:])
-            np.testing.assert_almost_equal(ds[0][1], long_cov.values()[-50:-20])
-            np.testing.assert_almost_equal(ds[0][2], self.cov_st2)
-            assert ds[0][3] == target
-
-            # Should also work for integer-indexed series
-            target = TimeSeries.from_times_and_values(
-                pd.RangeIndex(start=10, stop=50, step=1), np.random.randn(40)
-            ).with_static_covariates(self.cov_st2_df)
-            covariate = TimeSeries.from_times_and_values(
-                pd.RangeIndex(start=20, stop=80, step=1), np.random.randn(60)
-            )
+        target = TimeSeries.from_times_and_values(
+            times1, np.random.randn(len(times1))
+        ).with_static_covariates(self.cov_st2_df)
+        short_cov = TimeSeries.from_times_and_values(
+            times1, np.random.randn(len(times1))
+        )
+        long_cov = TimeSeries.from_times_and_values(
+            times2, np.random.randn(len(times2))
+        )
 
-            ds = FutureCovariatesInferenceDataset(
-                target_series=target, covariates=covariate, input_chunk_length=10, n=20
-            )
+        ds = FutureCovariatesInferenceDataset(
+            target_series=target, covariates=short_cov, input_chunk_length=10, n=30
+        )
 
-            np.testing.assert_almost_equal(ds[0][0], target.values()[-10:])
-            np.testing.assert_almost_equal(ds[0][1], covariate.values()[30:50])
-            np.testing.assert_almost_equal(ds[0][2], self.cov_st2)
-            assert ds[0][3] == target
+        # should fail if covariates are too short
+        with pytest.raises(ValueError):
+            _ = ds[0]
 
-        def test_dual_covariates_inference_dataset(self):
-            # one target series
-            ds = DualCovariatesInferenceDataset(
-                target_series=self.target1, input_chunk_length=len(self.target1)
-            )
-            np.testing.assert_almost_equal(ds[0][0], self.vals1)
-            self._assert_eq(ds[0][1:], (None, None, self.cov_st1, self.target1))
+        # Should return correct values when covariates are long enough
+        ds = FutureCovariatesInferenceDataset(
+            target_series=target, covariates=long_cov, input_chunk_length=10, n=30
+        )
 
-            # two target series
-            ds = DualCovariatesInferenceDataset(
-                target_series=[self.target1, self.target2],
-                input_chunk_length=max(len(self.target1), len(self.target2)),
-            )
-            np.testing.assert_almost_equal(ds[1][0], self.vals2)
-            self._assert_eq(ds[1][1:], (None, None, self.cov_st2, self.target2))
+        np.testing.assert_almost_equal(ds[0][0], target.values()[-10:])
+        np.testing.assert_almost_equal(ds[0][1], long_cov.values()[-50:-20])
+        np.testing.assert_almost_equal(ds[0][2], self.cov_st2)
+        assert ds[0][3] == target
+
+        # Should also work for integer-indexed series
+        target = TimeSeries.from_times_and_values(
+            pd.RangeIndex(start=10, stop=50, step=1), np.random.randn(40)
+        ).with_static_covariates(self.cov_st2_df)
+        covariate = TimeSeries.from_times_and_values(
+            pd.RangeIndex(start=20, stop=80, step=1), np.random.randn(60)
+        )
 
-            # fail if covariates do not have same size
-            with pytest.raises(ValueError):
-                ds = DualCovariatesInferenceDataset(
-                    target_series=[self.target1, self.target2], covariates=[self.cov1]
-                )
+        ds = FutureCovariatesInferenceDataset(
+            target_series=target, covariates=covariate, input_chunk_length=10, n=20
+        )
 
-            # With future past covariates:
-            times1 = pd.date_range(start="20100101", end="20100701", freq="D")
-            times2 = pd.date_range(
-                start="20100101", end="20100820", freq="D"
-            )  # 50 days longer than times1
-
-            target = TimeSeries.from_times_and_values(
-                times1, np.random.randn(len(times1))
-            ).with_static_covariates(self.cov_st2_df)
-            short_cov = TimeSeries.from_times_and_values(
-                times1, np.random.randn(len(times1))
-            )
-            long_cov = TimeSeries.from_times_and_values(
-                times2, np.random.randn(len(times2))
-            )
+        np.testing.assert_almost_equal(ds[0][0], target.values()[-10:])
+        np.testing.assert_almost_equal(ds[0][1], covariate.values()[30:50])
+        np.testing.assert_almost_equal(ds[0][2], self.cov_st2)
+        assert ds[0][3] == target
 
-            ds = DualCovariatesInferenceDataset(
-                target_series=target,
-                covariates=short_cov,
-                input_chunk_length=10,
-                output_chunk_length=10,
-                n=30,
-            )
+    def test_dual_covariates_inference_dataset(self):
+        # one target series
+        ds = DualCovariatesInferenceDataset(
+            target_series=self.target1, input_chunk_length=len(self.target1)
+        )
+        np.testing.assert_almost_equal(ds[0][0], self.vals1)
+        self._assert_eq(ds[0][1:], (None, None, self.cov_st1, self.target1))
 
-            # should fail if covariates are too short
-            with pytest.raises(ValueError):
-                _ = ds[0]
+        # two target series
+        ds = DualCovariatesInferenceDataset(
+            target_series=[self.target1, self.target2],
+            input_chunk_length=max(len(self.target1), len(self.target2)),
+        )
+        np.testing.assert_almost_equal(ds[1][0], self.vals2)
+        self._assert_eq(ds[1][1:], (None, None, self.cov_st2, self.target2))
 
-            # Should return correct values when covariates is long enough
+        # fail if the number of covariates does not match the number of target series
+        with pytest.raises(ValueError):
             ds = DualCovariatesInferenceDataset(
-                target_series=target,
-                covariates=long_cov,
-                input_chunk_length=10,
-                output_chunk_length=10,
-                n=30,
+                target_series=[self.target1, self.target2], covariates=[self.cov1]
             )
 
-            np.testing.assert_almost_equal(ds[0][0], target.values()[-10:])
-            np.testing.assert_almost_equal(ds[0][1], long_cov.values()[-60:-50])
-            np.testing.assert_almost_equal(ds[0][2], long_cov.values()[-50:-20])
-            np.testing.assert_almost_equal(ds[0][3], self.cov_st2)
-            assert ds[0][4] == target
-
-            # Should also work for integer-indexed series
-            target = TimeSeries.from_times_and_values(
-                pd.RangeIndex(start=10, stop=50, step=1), np.random.randn(40)
-            ).with_static_covariates(self.cov_st2_df)
-            covariate = TimeSeries.from_times_and_values(
-                pd.RangeIndex(start=20, stop=80, step=1), np.random.randn(60)
-            )
-
-            ds = DualCovariatesInferenceDataset(
-                target_series=target,
-                covariates=covariate,
-                input_chunk_length=10,
-                output_chunk_length=10,
-                n=20,
-            )
+        # With future past covariates:
+        times1 = pd.date_range(start="20100101", end="20100701", freq="D")
+        times2 = pd.date_range(
+            start="20100101", end="20100820", freq="D"
+        )  # 50 days longer than times1
 
-            np.testing.assert_almost_equal(ds[0][0], target.values()[-10:])
-            np.testing.assert_almost_equal(ds[0][1], covariate.values()[20:30])
-            np.testing.assert_almost_equal(ds[0][2], covariate.values()[30:50])
-            np.testing.assert_almost_equal(ds[0][3], self.cov_st2)
-            assert ds[0][4] == target
-
-        def test_mixed_covariates_inference_dataset(self):
-            # With future past covariates:
-            times1 = pd.date_range(start="20100101", end="20100701", freq="D")
-            times2 = pd.date_range(
-                start="20100201", end="20100820", freq="D"
-            )  # ends 50 days after times1
-
-            target = TimeSeries.from_times_and_values(
-                times1, np.random.randn(len(times1))
-            ).with_static_covariates(self.cov_st2_df)
-            past_cov = TimeSeries.from_times_and_values(
-                times1, np.random.randn(len(times1))
-            )
-            long_past_cov = TimeSeries.from_times_and_values(
-                times2, np.random.randn(len(times2))
-            )
-            future_cov = TimeSeries.from_times_and_values(
-                times2, np.random.randn(len(times2))
-            )
+        target = TimeSeries.from_times_and_values(
+            times1, np.random.randn(len(times1))
+        ).with_static_covariates(self.cov_st2_df)
+        short_cov = TimeSeries.from_times_and_values(
+            times1, np.random.randn(len(times1))
+        )
+        long_cov = TimeSeries.from_times_and_values(
+            times2, np.random.randn(len(times2))
+        )
 
-            ds = MixedCovariatesInferenceDataset(
-                target_series=target,
-                past_covariates=past_cov,
-                future_covariates=past_cov,
-                input_chunk_length=10,
-                output_chunk_length=10,
-                n=30,
-            )
+        ds = DualCovariatesInferenceDataset(
+            target_series=target,
+            covariates=short_cov,
+            input_chunk_length=10,
+            output_chunk_length=10,
+            n=30,
+        )
 
-            # should fail if future covariates are too short
-            with pytest.raises(ValueError):
-                _ = ds[0]
+        # should fail if covariates are too short
+        with pytest.raises(ValueError):
+            _ = ds[0]
+
+        # Should return correct values when covariates are long enough
+        ds = DualCovariatesInferenceDataset(
+            target_series=target,
+            covariates=long_cov,
+            input_chunk_length=10,
+            output_chunk_length=10,
+            n=30,
+        )
 
-            # Should return correct values when covariates is long enough
-            ds = MixedCovariatesInferenceDataset(
-                target_series=target,
-                past_covariates=long_past_cov,
-                future_covariates=future_cov,
-                input_chunk_length=10,
-                output_chunk_length=10,
-                n=30,
-            )
+        np.testing.assert_almost_equal(ds[0][0], target.values()[-10:])
+        np.testing.assert_almost_equal(ds[0][1], long_cov.values()[-60:-50])
+        np.testing.assert_almost_equal(ds[0][2], long_cov.values()[-50:-20])
+        np.testing.assert_almost_equal(ds[0][3], self.cov_st2)
+        assert ds[0][4] == target
+
+        # Should also work for integer-indexed series
+        target = TimeSeries.from_times_and_values(
+            pd.RangeIndex(start=10, stop=50, step=1), np.random.randn(40)
+        ).with_static_covariates(self.cov_st2_df)
+        covariate = TimeSeries.from_times_and_values(
+            pd.RangeIndex(start=20, stop=80, step=1), np.random.randn(60)
+        )
 
-            # It should contain:
-            # past_target, past_covariates, historic_future_covariates, future_covariates, future_past_covariates
-            np.testing.assert_almost_equal(ds[0][0], target.values()[-10:])
-            np.testing.assert_almost_equal(ds[0][1], long_past_cov.values()[-60:-50])
-            np.testing.assert_almost_equal(ds[0][2], future_cov.values()[-60:-50])
-            np.testing.assert_almost_equal(ds[0][3], future_cov.values()[-50:-20])
-            np.testing.assert_almost_equal(ds[0][4], long_past_cov.values()[-50:-30])
-            np.testing.assert_almost_equal(ds[0][5], self.cov_st2)
-            assert ds[0][6] == target
-
-            # Should also work for integer-indexed series
-            target = TimeSeries.from_times_and_values(
-                pd.RangeIndex(start=10, stop=50, step=1), np.random.randn(40)
-            ).with_static_covariates(self.cov_st2_df)
-            past_cov = TimeSeries.from_times_and_values(
-                pd.RangeIndex(start=20, stop=80, step=1), np.random.randn(60)
-            )
-            future_cov = TimeSeries.from_times_and_values(
-                pd.RangeIndex(start=30, stop=100, step=1), np.random.randn(70)
-            )
+        ds = DualCovariatesInferenceDataset(
+            target_series=target,
+            covariates=covariate,
+            input_chunk_length=10,
+            output_chunk_length=10,
+            n=20,
+        )
 
-            ds = MixedCovariatesInferenceDataset(
-                target_series=target,
-                past_covariates=past_cov,
-                future_covariates=future_cov,
-                input_chunk_length=10,
-                output_chunk_length=10,
-                n=20,
-            )
+        np.testing.assert_almost_equal(ds[0][0], target.values()[-10:])
+        np.testing.assert_almost_equal(ds[0][1], covariate.values()[20:30])
+        np.testing.assert_almost_equal(ds[0][2], covariate.values()[30:50])
+        np.testing.assert_almost_equal(ds[0][3], self.cov_st2)
+        assert ds[0][4] == target
+
+    def test_mixed_covariates_inference_dataset(self):
+        # With future past covariates:
+        times1 = pd.date_range(start="20100101", end="20100701", freq="D")
+        times2 = pd.date_range(
+            start="20100201", end="20100820", freq="D"
+        )  # ends 50 days after times1
+
+        target = TimeSeries.from_times_and_values(
+            times1, np.random.randn(len(times1))
+        ).with_static_covariates(self.cov_st2_df)
+        past_cov = TimeSeries.from_times_and_values(
+            times1, np.random.randn(len(times1))
+        )
+        long_past_cov = TimeSeries.from_times_and_values(
+            times2, np.random.randn(len(times2))
+        )
+        future_cov = TimeSeries.from_times_and_values(
+            times2, np.random.randn(len(times2))
+        )
 
-            np.testing.assert_almost_equal(ds[0][0], target.values()[-10:])
-            np.testing.assert_almost_equal(ds[0][1], past_cov.values()[20:30])
-            np.testing.assert_almost_equal(ds[0][2], future_cov.values()[10:20])
-            np.testing.assert_almost_equal(ds[0][3], future_cov.values()[20:40])
-            np.testing.assert_almost_equal(ds[0][4], past_cov.values()[30:40])
-            np.testing.assert_almost_equal(ds[0][5], self.cov_st2)
-            assert ds[0][6] == target
-
-        def test_split_covariates_inference_dataset(self):
-            # With future past covariates:
-            times1 = pd.date_range(start="20100101", end="20100701", freq="D")
-            times2 = pd.date_range(
-                start="20100201", end="20100820", freq="D"
-            )  # ends 50 days after times1
-
-            target = TimeSeries.from_times_and_values(
-                times1, np.random.randn(len(times1))
-            ).with_static_covariates(self.cov_st2_df)
-            past_cov = TimeSeries.from_times_and_values(
-                times1, np.random.randn(len(times1))
-            )
-            long_past_cov = TimeSeries.from_times_and_values(
-                times2, np.random.randn(len(times2))
-            )
-            future_cov = TimeSeries.from_times_and_values(
-                times2, np.random.randn(len(times2))
-            )
+        ds = MixedCovariatesInferenceDataset(
+            target_series=target,
+            past_covariates=past_cov,
+            future_covariates=past_cov,
+            input_chunk_length=10,
+            output_chunk_length=10,
+            n=30,
+        )
 
-            ds = SplitCovariatesInferenceDataset(
-                target_series=target,
-                past_covariates=past_cov,
-                future_covariates=past_cov,
-                input_chunk_length=10,
-                output_chunk_length=10,
-                n=30,
-            )
+        # should fail if future covariates are too short
+        with pytest.raises(ValueError):
+            _ = ds[0]
+
+        # Should return correct values when covariates are long enough
+        ds = MixedCovariatesInferenceDataset(
+            target_series=target,
+            past_covariates=long_past_cov,
+            future_covariates=future_cov,
+            input_chunk_length=10,
+            output_chunk_length=10,
+            n=30,
+        )
 
-            # should fail if future covariates are too short
-            with pytest.raises(ValueError):
-                _ = ds[0]
+        # It should contain:
+        # past_target, past_covariates, historic_future_covariates, future_covariates, future_past_covariates
+        np.testing.assert_almost_equal(ds[0][0], target.values()[-10:])
+        np.testing.assert_almost_equal(ds[0][1], long_past_cov.values()[-60:-50])
+        np.testing.assert_almost_equal(ds[0][2], future_cov.values()[-60:-50])
+        np.testing.assert_almost_equal(ds[0][3], future_cov.values()[-50:-20])
+        np.testing.assert_almost_equal(ds[0][4], long_past_cov.values()[-50:-30])
+        np.testing.assert_almost_equal(ds[0][5], self.cov_st2)
+        assert ds[0][6] == target
+
+        # Should also work for integer-indexed series
+        target = TimeSeries.from_times_and_values(
+            pd.RangeIndex(start=10, stop=50, step=1), np.random.randn(40)
+        ).with_static_covariates(self.cov_st2_df)
+        past_cov = TimeSeries.from_times_and_values(
+            pd.RangeIndex(start=20, stop=80, step=1), np.random.randn(60)
+        )
+        future_cov = TimeSeries.from_times_and_values(
+            pd.RangeIndex(start=30, stop=100, step=1), np.random.randn(70)
+        )
 
-            # Should return correct values when covariates is long enough
-            ds = SplitCovariatesInferenceDataset(
-                target_series=target,
-                past_covariates=long_past_cov,
-                future_covariates=future_cov,
-                input_chunk_length=10,
-                output_chunk_length=10,
-                n=30,
-            )
+        ds = MixedCovariatesInferenceDataset(
+            target_series=target,
+            past_covariates=past_cov,
+            future_covariates=future_cov,
+            input_chunk_length=10,
+            output_chunk_length=10,
+            n=20,
+        )
 
-            # It should contain:
-            # past_target, past_covariates, future_covariates, future_past_covariates
-            np.testing.assert_almost_equal(ds[0][0], target.values()[-10:])
-            np.testing.assert_almost_equal(ds[0][1], long_past_cov.values()[-60:-50])
-            np.testing.assert_almost_equal(ds[0][2], future_cov.values()[-50:-20])
-            np.testing.assert_almost_equal(ds[0][3], long_past_cov.values()[-50:-30])
-            np.testing.assert_almost_equal(ds[0][4], self.cov_st2)
-            assert ds[0][5] == target
-
-            # Should also work for integer-indexed series
-            target = TimeSeries.from_times_and_values(
-                pd.RangeIndex(start=10, stop=50, step=1), np.random.randn(40)
-            ).with_static_covariates(self.cov_st2_df)
-            past_cov = TimeSeries.from_times_and_values(
-                pd.RangeIndex(start=20, stop=80, step=1), np.random.randn(60)
-            )
-            future_cov = TimeSeries.from_times_and_values(
-                pd.RangeIndex(start=30, stop=100, step=1), np.random.randn(70)
-            )
+        np.testing.assert_almost_equal(ds[0][0], target.values()[-10:])
+        np.testing.assert_almost_equal(ds[0][1], past_cov.values()[20:30])
+        np.testing.assert_almost_equal(ds[0][2], future_cov.values()[10:20])
+        np.testing.assert_almost_equal(ds[0][3], future_cov.values()[20:40])
+        np.testing.assert_almost_equal(ds[0][4], past_cov.values()[30:40])
+        np.testing.assert_almost_equal(ds[0][5], self.cov_st2)
+        assert ds[0][6] == target
+
+    def test_split_covariates_inference_dataset(self):
+        # With future past covariates:
+        times1 = pd.date_range(start="20100101", end="20100701", freq="D")
+        times2 = pd.date_range(
+            start="20100201", end="20100820", freq="D"
+        )  # ends 50 days after times1
+
+        target = TimeSeries.from_times_and_values(
+            times1, np.random.randn(len(times1))
+        ).with_static_covariates(self.cov_st2_df)
+        past_cov = TimeSeries.from_times_and_values(
+            times1, np.random.randn(len(times1))
+        )
+        long_past_cov = TimeSeries.from_times_and_values(
+            times2, np.random.randn(len(times2))
+        )
+        future_cov = TimeSeries.from_times_and_values(
+            times2, np.random.randn(len(times2))
+        )
 
-            ds = SplitCovariatesInferenceDataset(
-                target_series=target,
-                past_covariates=past_cov,
-                future_covariates=future_cov,
-                input_chunk_length=10,
-                output_chunk_length=10,
-                n=20,
-            )
+        ds = SplitCovariatesInferenceDataset(
+            target_series=target,
+            past_covariates=past_cov,
+            future_covariates=past_cov,
+            input_chunk_length=10,
+            output_chunk_length=10,
+            n=30,
+        )
 
-            np.testing.assert_almost_equal(ds[0][0], target.values()[-10:])
-            np.testing.assert_almost_equal(ds[0][1], past_cov.values()[20:30])
-            np.testing.assert_almost_equal(ds[0][2], future_cov.values()[20:40])
-            np.testing.assert_almost_equal(ds[0][3], past_cov.values()[30:40])
-            np.testing.assert_almost_equal(ds[0][4], self.cov_st2)
-            assert ds[0][5] == target
+        # should fail if future covariates are too short
+        with pytest.raises(ValueError):
+            _ = ds[0]
+
+        # Should return correct values when covariates are long enough
+        ds = SplitCovariatesInferenceDataset(
+            target_series=target,
+            past_covariates=long_past_cov,
+            future_covariates=future_cov,
+            input_chunk_length=10,
+            output_chunk_length=10,
+            n=30,
+        )
 
-        @pytest.mark.parametrize(
-            "config",
-            [
-                # (dataset class, whether contains future, future batch index)
-                (PastCovariatesInferenceDataset, None),
-                (FutureCovariatesInferenceDataset, 1),
-                (DualCovariatesInferenceDataset, 2),
-                (MixedCovariatesInferenceDataset, 3),
-                (SplitCovariatesInferenceDataset, 2),
-            ],
-        )
-        def test_inference_dataset_output_chunk_shift(self, config):
-            ds_cls, future_idx = config
-            ocl = 1
-            ocs = 2
-            target = self.target1[: -(ocl + ocs)]
-
-            ds_covs = {}
-            ds_init_params = set(inspect.signature(ds_cls.__init__).parameters)
-            for cov_type in ["covariates", "past_covariates", "future_covariates"]:
-                if cov_type in ds_init_params:
-                    ds_covs[cov_type] = self.cov1
-
-            with pytest.raises(ValueError) as err:
-                _ = ds_cls(
-                    target_series=target,
-                    input_chunk_length=1,
-                    output_chunk_length=1,
-                    output_chunk_shift=1,
-                    n=2,
-                    **ds_covs,
-                )
-            assert str(err.value).startswith("Cannot perform auto-regression")
+        # It should contain:
+        # past_target, past_covariates, future_covariates, future_past_covariates
+        np.testing.assert_almost_equal(ds[0][0], target.values()[-10:])
+        np.testing.assert_almost_equal(ds[0][1], long_past_cov.values()[-60:-50])
+        np.testing.assert_almost_equal(ds[0][2], future_cov.values()[-50:-20])
+        np.testing.assert_almost_equal(ds[0][3], long_past_cov.values()[-50:-30])
+        np.testing.assert_almost_equal(ds[0][4], self.cov_st2)
+        assert ds[0][5] == target
+
+        # Should also work for integer-indexed series
+        target = TimeSeries.from_times_and_values(
+            pd.RangeIndex(start=10, stop=50, step=1), np.random.randn(40)
+        ).with_static_covariates(self.cov_st2_df)
+        past_cov = TimeSeries.from_times_and_values(
+            pd.RangeIndex(start=20, stop=80, step=1), np.random.randn(60)
+        )
+        future_cov = TimeSeries.from_times_and_values(
+            pd.RangeIndex(start=30, stop=100, step=1), np.random.randn(70)
+        )
 
-            # regular dataset with output shift=0 and ocl=3: the 3rd future values should be identical to the 1st future
-            # values of a dataset with output shift=2 and ocl=1
-            ds_reg = ds_cls(
-                target_series=target,
-                input_chunk_length=1,
-                output_chunk_length=3,
-                output_chunk_shift=0,
-                n=1,
-                **ds_covs,
-            )
+        ds = SplitCovariatesInferenceDataset(
+            target_series=target,
+            past_covariates=past_cov,
+            future_covariates=future_cov,
+            input_chunk_length=10,
+            output_chunk_length=10,
+            n=20,
+        )
 
-            ds_shift = ds_cls(
+        np.testing.assert_almost_equal(ds[0][0], target.values()[-10:])
+        np.testing.assert_almost_equal(ds[0][1], past_cov.values()[20:30])
+        np.testing.assert_almost_equal(ds[0][2], future_cov.values()[20:40])
+        np.testing.assert_almost_equal(ds[0][3], past_cov.values()[30:40])
+        np.testing.assert_almost_equal(ds[0][4], self.cov_st2)
+        assert ds[0][5] == target
+
+    @pytest.mark.parametrize(
+        "config",
+        [
+            # (dataset class, future batch index; None if the dataset has no future part)
+            (PastCovariatesInferenceDataset, None),
+            (FutureCovariatesInferenceDataset, 1),
+            (DualCovariatesInferenceDataset, 2),
+            (MixedCovariatesInferenceDataset, 3),
+            (SplitCovariatesInferenceDataset, 2),
+        ],
+    )
+    def test_inference_dataset_output_chunk_shift(self, config):
+        ds_cls, future_idx = config
+        ocl = 1
+        ocs = 2
+        target = self.target1[: -(ocl + ocs)]
+
+        ds_covs = {}
+        ds_init_params = set(inspect.signature(ds_cls.__init__).parameters)
+        for cov_type in ["covariates", "past_covariates", "future_covariates"]:
+            if cov_type in ds_init_params:
+                ds_covs[cov_type] = self.cov1
+
+        with pytest.raises(ValueError) as err:
+            _ = ds_cls(
                 target_series=target,
                 input_chunk_length=1,
                 output_chunk_length=1,
-                output_chunk_shift=ocs,
-                n=1,
+                output_chunk_shift=1,
+                n=2,
                 **ds_covs,
             )
+        assert str(err.value).startswith("Cannot perform auto-regression")
+
+        # regular dataset with output shift=0 and ocl=3: its 3rd future value should be identical to the 1st
+        # future value of a dataset with output shift=2 and ocl=1
+        ds_reg = ds_cls(
+            target_series=target,
+            input_chunk_length=1,
+            output_chunk_length=3,
+            output_chunk_shift=0,
+            n=1,
+            **ds_covs,
+        )
 
-            batch_reg, batch_shift = ds_reg[0], ds_shift[0]
-
-            # shifted prediction starts 2 steps after regular prediction
-            assert batch_reg[-1] == batch_shift[-1] - ocs * target.freq
-
-            if future_idx is not None:
-                # 3rd future values of regular ds must be identical to the 1st future values of shifted dataset
-                np.testing.assert_array_equal(
-                    batch_reg[future_idx][ocs:], batch_shift[future_idx]
-                )
-                batch_reg = batch_reg[:future_idx] + batch_reg[future_idx + 1 :]
-                batch_shift = batch_shift[:future_idx] + batch_shift[future_idx + 1 :]
-
-            # without future part, the input will be identical between regular, and shifted dataset
-            assert all(
-                [
-                    np.all(el_reg == el_shift)
-                    for el_reg, el_shift in zip(batch_reg[:-1], batch_shift[:-1])
-                ]
-            )
+        ds_shift = ds_cls(
+            target_series=target,
+            input_chunk_length=1,
+            output_chunk_length=1,
+            output_chunk_shift=ocs,
+            n=1,
+            **ds_covs,
+        )
 
-        def test_past_covariates_sequential_dataset(self):
-            # one target series
-            ds = PastCovariatesSequentialDataset(
-                target_series=self.target1,
-                input_chunk_length=10,
-                output_chunk_length=10,
-            )
-            assert len(ds) == 81
-            self._assert_eq(
-                ds[5], (self.target1[75:85], None, self.cov_st1, self.target1[85:95])
-            )
+        batch_reg, batch_shift = ds_reg[0], ds_shift[0]
 
-            # two target series
-            ds = PastCovariatesSequentialDataset(
-                target_series=[self.target1, self.target2],
-                input_chunk_length=10,
-                output_chunk_length=10,
-            )
-            assert len(ds) == 262
-            self._assert_eq(
-                ds[5], (self.target1[75:85], None, self.cov_st1, self.target1[85:95])
-            )
-            self._assert_eq(
-                ds[136],
-                (self.target2[125:135], None, self.cov_st2, self.target2[135:145]),
-            )
+        # shifted prediction starts 2 steps after regular prediction
+        assert batch_reg[-1] == batch_shift[-1] - ocs * target.freq
 
-            # two target series with custom max_nr_samples
-            ds = PastCovariatesSequentialDataset(
-                target_series=[self.target1, self.target2],
-                input_chunk_length=10,
-                output_chunk_length=10,
-                max_samples_per_ts=50,
-            )
-            assert len(ds) == 100
-            self._assert_eq(
-                ds[5], (self.target1[75:85], None, self.cov_st1, self.target1[85:95])
-            )
-            self._assert_eq(
-                ds[55],
-                (self.target2[125:135], None, self.cov_st2, self.target2[135:145]),
+        if future_idx is not None:
+            # the 3rd future value of the regular ds must be identical to the 1st future value of the shifted dataset
+            np.testing.assert_array_equal(
+                batch_reg[future_idx][ocs:], batch_shift[future_idx]
             )
+            batch_reg = batch_reg[:future_idx] + batch_reg[future_idx + 1 :]
+            batch_shift = batch_shift[:future_idx] + batch_shift[future_idx + 1 :]
 
-            # two targets and one covariate
-            with pytest.raises(ValueError):
-                ds = PastCovariatesSequentialDataset(
-                    target_series=[self.target1, self.target2], covariates=[self.cov1]
-                )
+        # without the future part, the inputs of the regular and shifted datasets are identical
+        assert all(
+            [
+                np.all(el_reg == el_shift)
+                for el_reg, el_shift in zip(batch_reg[:-1], batch_shift[:-1])
+            ]
+        )
 
-            # two targets and two covariates
-            ds = PastCovariatesSequentialDataset(
-                target_series=[self.target1, self.target2],
-                covariates=[self.cov1, self.cov2],
-                input_chunk_length=10,
-                output_chunk_length=10,
-            )
-            self._assert_eq(
-                ds[5],
-                (
-                    self.target1[75:85],
-                    self.cov1[75:85],
-                    self.cov_st1,
-                    self.target1[85:95],
-                ),
-            )
-            self._assert_eq(
-                ds[136],
-                (
-                    self.target2[125:135],
-                    self.cov2[125:135],
-                    self.cov_st2,
-                    self.target2[135:145],
-                ),
-            )
+    def test_past_covariates_sequential_dataset(self):
+        # one target series
+        ds = PastCovariatesSequentialDataset(
+            target_series=self.target1,
+            input_chunk_length=10,
+            output_chunk_length=10,
+        )
+        assert len(ds) == 81
+        self._assert_eq(
+            ds[5], (self.target1[75:85], None, self.cov_st1, self.target1[85:95])
+        )
 
-            # should fail if covariates do not have the required time span, even though covariates are longer
-            times1 = pd.date_range(start="20100101", end="20110101", freq="D")
-            times2 = pd.date_range(start="20120101", end="20150101", freq="D")
-            target = TimeSeries.from_times_and_values(
-                times1, np.random.randn(len(times1))
-            ).with_static_covariates(self.cov_st2_df)
-            cov = TimeSeries.from_times_and_values(times2, np.random.randn(len(times2)))
-            ds = PastCovariatesSequentialDataset(
-                target_series=target,
-                covariates=cov,
-                input_chunk_length=10,
-                output_chunk_length=10,
-            )
-            with pytest.raises(ValueError):
-                _ = ds[5]
-
-            # the same should fail when series are integer-indexed
-            times1 = pd.RangeIndex(start=0, stop=100, step=1)
-            times2 = pd.RangeIndex(start=200, stop=400, step=1)
-            target = TimeSeries.from_times_and_values(
-                times1, np.random.randn(len(times1))
-            ).with_static_covariates(self.cov_st2_df)
-            cov = TimeSeries.from_times_and_values(times2, np.random.randn(len(times2)))
-            ds = PastCovariatesSequentialDataset(
-                target_series=target,
-                covariates=cov,
-                input_chunk_length=10,
-                output_chunk_length=10,
-            )
-            with pytest.raises(ValueError):
-                _ = ds[5]
-
-            # we should get the correct covariate slice even when target and covariates are not aligned
-            times1 = pd.date_range(start="20100101", end="20110101", freq="D")
-            times2 = pd.date_range(start="20090101", end="20110106", freq="D")
-            target = TimeSeries.from_times_and_values(
-                times1, np.random.randn(len(times1))
-            ).with_static_covariates(self.cov_st2_df)
-            cov = TimeSeries.from_times_and_values(times2, np.random.randn(len(times2)))
-            ds = PastCovariatesSequentialDataset(
-                target_series=target,
-                covariates=cov,
-                input_chunk_length=10,
-                output_chunk_length=10,
-            )
+        # two target series
+        ds = PastCovariatesSequentialDataset(
+            target_series=[self.target1, self.target2],
+            input_chunk_length=10,
+            output_chunk_length=10,
+        )
+        assert len(ds) == 262
+        self._assert_eq(
+            ds[5], (self.target1[75:85], None, self.cov_st1, self.target1[85:95])
+        )
+        self._assert_eq(
+            ds[136],
+            (self.target2[125:135], None, self.cov_st2, self.target2[135:145]),
+        )
 
-            np.testing.assert_almost_equal(ds[0][1], cov.values()[-25:-15])
-            np.testing.assert_almost_equal(ds[5][1], cov.values()[-30:-20])
+        # two target series with custom max_samples_per_ts
+        ds = PastCovariatesSequentialDataset(
+            target_series=[self.target1, self.target2],
+            input_chunk_length=10,
+            output_chunk_length=10,
+            max_samples_per_ts=50,
+        )
+        assert len(ds) == 100
+        self._assert_eq(
+            ds[5], (self.target1[75:85], None, self.cov_st1, self.target1[85:95])
+        )
+        self._assert_eq(
+            ds[55],
+            (self.target2[125:135], None, self.cov_st2, self.target2[135:145]),
+        )
 
-            # This should also be the case when series are integer indexed
-            times1 = pd.RangeIndex(start=100, stop=200, step=1)
-            times2 = pd.RangeIndex(start=50, stop=250, step=1)
-            target = TimeSeries.from_times_and_values(
-                times1, np.random.randn(len(times1))
-            ).with_static_covariates(self.cov_st2_df)
-            cov = TimeSeries.from_times_and_values(times2, np.random.randn(len(times2)))
+        # two targets and one covariate
+        with pytest.raises(ValueError):
             ds = PastCovariatesSequentialDataset(
-                target_series=target,
-                covariates=cov,
-                input_chunk_length=10,
-                output_chunk_length=10,
+                target_series=[self.target1, self.target2], covariates=[self.cov1]
             )
 
-            np.testing.assert_almost_equal(ds[0][1], cov.values()[-70:-60])
-            np.testing.assert_almost_equal(ds[5][1], cov.values()[-75:-65])
-
-        def test_future_covariates_sequential_dataset(self):
-            # one target series
-            ds = FutureCovariatesSequentialDataset(
-                target_series=self.target1,
-                input_chunk_length=10,
-                output_chunk_length=10,
-            )
-            assert len(ds) == 81
-            self._assert_eq(
-                ds[5], (self.target1[75:85], None, self.cov_st1, self.target1[85:95])
-            )
+        # two targets and two covariates
+        ds = PastCovariatesSequentialDataset(
+            target_series=[self.target1, self.target2],
+            covariates=[self.cov1, self.cov2],
+            input_chunk_length=10,
+            output_chunk_length=10,
+        )
+        self._assert_eq(
+            ds[5],
+            (
+                self.target1[75:85],
+                self.cov1[75:85],
+                self.cov_st1,
+                self.target1[85:95],
+            ),
+        )
+        self._assert_eq(
+            ds[136],
+            (
+                self.target2[125:135],
+                self.cov2[125:135],
+                self.cov_st2,
+                self.target2[135:145],
+            ),
+        )
 
-            # two target series
-            ds = FutureCovariatesSequentialDataset(
-                target_series=[self.target1, self.target2],
-                input_chunk_length=10,
-                output_chunk_length=10,
-            )
-            assert len(ds) == 262
-            self._assert_eq(
-                ds[5], (self.target1[75:85], None, self.cov_st1, self.target1[85:95])
-            )
-            self._assert_eq(
-                ds[136],
-                (self.target2[125:135], None, self.cov_st2, self.target2[135:145]),
-            )
+        # should fail if covariates do not have the required time span, even though covariates are longer
+        times1 = pd.date_range(start="20100101", end="20110101", freq="D")
+        times2 = pd.date_range(start="20120101", end="20150101", freq="D")
+        target = TimeSeries.from_times_and_values(
+            times1, np.random.randn(len(times1))
+        ).with_static_covariates(self.cov_st2_df)
+        cov = TimeSeries.from_times_and_values(times2, np.random.randn(len(times2)))
+        ds = PastCovariatesSequentialDataset(
+            target_series=target,
+            covariates=cov,
+            input_chunk_length=10,
+            output_chunk_length=10,
+        )
+        with pytest.raises(ValueError):
+            _ = ds[5]
+
+        # the same should fail when series are integer-indexed
+        times1 = pd.RangeIndex(start=0, stop=100, step=1)
+        times2 = pd.RangeIndex(start=200, stop=400, step=1)
+        target = TimeSeries.from_times_and_values(
+            times1, np.random.randn(len(times1))
+        ).with_static_covariates(self.cov_st2_df)
+        cov = TimeSeries.from_times_and_values(times2, np.random.randn(len(times2)))
+        ds = PastCovariatesSequentialDataset(
+            target_series=target,
+            covariates=cov,
+            input_chunk_length=10,
+            output_chunk_length=10,
+        )
+        with pytest.raises(ValueError):
+            _ = ds[5]
+
+        # we should get the correct covariate slice even when target and covariates are not aligned
+        times1 = pd.date_range(start="20100101", end="20110101", freq="D")
+        times2 = pd.date_range(start="20090101", end="20110106", freq="D")
+        target = TimeSeries.from_times_and_values(
+            times1, np.random.randn(len(times1))
+        ).with_static_covariates(self.cov_st2_df)
+        cov = TimeSeries.from_times_and_values(times2, np.random.randn(len(times2)))
+        ds = PastCovariatesSequentialDataset(
+            target_series=target,
+            covariates=cov,
+            input_chunk_length=10,
+            output_chunk_length=10,
+        )
 
-            # two target series with custom max_nr_samples
-            ds = FutureCovariatesSequentialDataset(
-                target_series=[self.target1, self.target2],
-                input_chunk_length=10,
-                output_chunk_length=10,
-                max_samples_per_ts=50,
-            )
-            assert len(ds) == 100
-            self._assert_eq(
-                ds[5], (self.target1[75:85], None, self.cov_st1, self.target1[85:95])
-            )
-            self._assert_eq(
-                ds[55],
-                (self.target2[125:135], None, self.cov_st2, self.target2[135:145]),
-            )
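+        # cov ends 5 days after the target, so the first sample's 10-step past
+        # window (target[-20:-10]) sits at cov positions [-25:-15]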
+        np.testing.assert_almost_equal(ds[0][1], cov.values()[-25:-15])
+        np.testing.assert_almost_equal(ds[5][1], cov.values()[-30:-20])
+
+        # this should also be the case when the series are integer-indexed
+        times1 = pd.RangeIndex(start=100, stop=200, step=1)
+        times2 = pd.RangeIndex(start=50, stop=250, step=1)
+        target = TimeSeries.from_times_and_values(
+            times1, np.random.randn(len(times1))
+        ).with_static_covariates(self.cov_st2_df)
+        cov = TimeSeries.from_times_and_values(times2, np.random.randn(len(times2)))
+        ds = PastCovariatesSequentialDataset(
+            target_series=target,
+            covariates=cov,
+            input_chunk_length=10,
+            output_chunk_length=10,
+        )
 
-            # two targets and one covariate
-            with pytest.raises(ValueError):
-                ds = FutureCovariatesSequentialDataset(
-                    target_series=[self.target1, self.target2], covariates=[self.cov1]
-                )
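+        # target covers 100..200 and cov 50..250; the first input window
+        # (times 180..190) maps to cov[130:140], i.e. cov[-70:-60]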
+        np.testing.assert_almost_equal(ds[0][1], cov.values()[-70:-60])
+        np.testing.assert_almost_equal(ds[5][1], cov.values()[-75:-65])
 
-            # two targets and two covariates; covariates not aligned, must contain correct values
-            target1 = TimeSeries.from_values(
-                np.random.randn(100)
-            ).with_static_covariates(self.cov_st2_df)
-            target2 = TimeSeries.from_values(
-                np.random.randn(50)
-            ).with_static_covariates(self.cov_st2_df)
-            cov1 = TimeSeries.from_values(np.random.randn(120))
-            cov2 = TimeSeries.from_values(np.random.randn(80))
+    def test_future_covariates_sequential_dataset(self):
+        # one target series
+        ds = FutureCovariatesSequentialDataset(
+            target_series=self.target1,
+            input_chunk_length=10,
+            output_chunk_length=10,
+        )
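+        # a 100-step target with icl + ocl = 20 yields 100 - 20 + 1 = 81 samples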
+        assert len(ds) == 81
+        self._assert_eq(
+            ds[5], (self.target1[75:85], None, self.cov_st1, self.target1[85:95])
+        )
 
-            ds = FutureCovariatesSequentialDataset(
-                target_series=[target1, target2],
-                covariates=[cov1, cov2],
-                input_chunk_length=10,
-                output_chunk_length=10,
-            )
+        # two target series
+        ds = FutureCovariatesSequentialDataset(
+            target_series=[self.target1, self.target2],
+            input_chunk_length=10,
+            output_chunk_length=10,
+        )
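+        # the longer target (150 steps) yields 150 - 20 + 1 = 131 samples; indices
+        # appear to be allotted per series up to that maximum, so len = 2 * 131 = 262
+        # and ds[136] is sample 5 of target2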
+        assert len(ds) == 262
+        self._assert_eq(
+            ds[5], (self.target1[75:85], None, self.cov_st1, self.target1[85:95])
+        )
+        self._assert_eq(
+            ds[136],
+            (self.target2[125:135], None, self.cov_st2, self.target2[135:145]),
+        )
 
-            np.testing.assert_almost_equal(ds[0][0], target1.values()[-20:-10])
-            np.testing.assert_almost_equal(ds[0][1], cov1.values()[-30:-20])
-            np.testing.assert_almost_equal(ds[0][2], self.cov_st2)
-            np.testing.assert_almost_equal(ds[0][3], target1.values()[-10:])
-
-            np.testing.assert_almost_equal(ds[101][0], target2.values()[-40:-30])
-            np.testing.assert_almost_equal(ds[101][1], cov2.values()[-60:-50])
-            np.testing.assert_almost_equal(ds[101][2], self.cov_st2)
-            np.testing.assert_almost_equal(ds[101][3], target2.values()[-30:-20])
-
-            # Should also contain correct values when time-indexed with covariates not aligned
-            times1 = pd.date_range(start="20090201", end="20090220", freq="D")
-            times2 = pd.date_range(start="20090201", end="20090222", freq="D")
-            target1 = TimeSeries.from_times_and_values(
-                times1, np.random.randn(len(times1))
-            ).with_static_covariates(self.cov_st2_df)
-            cov1 = TimeSeries.from_times_and_values(
-                times2, np.random.randn(len(times2))
-            )
+        # two target series with custom max_samples_per_ts
+        ds = FutureCovariatesSequentialDataset(
+            target_series=[self.target1, self.target2],
+            input_chunk_length=10,
+            output_chunk_length=10,
+            max_samples_per_ts=50,
+        )
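+        # max_samples_per_ts caps both series at 50 samples: len = 2 * 50 = 100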
+        assert len(ds) == 100
+        self._assert_eq(
+            ds[5], (self.target1[75:85], None, self.cov_st1, self.target1[85:95])
+        )
+        self._assert_eq(
+            ds[55],
+            (self.target2[125:135], None, self.cov_st2, self.target2[135:145]),
+        )
 
+        # two targets and one covariate
+        with pytest.raises(ValueError):
             ds = FutureCovariatesSequentialDataset(
-                target_series=[target1],
-                covariates=[cov1],
-                input_chunk_length=2,
-                output_chunk_length=2,
+                target_series=[self.target1, self.target2], covariates=[self.cov1]
             )
 
-            np.testing.assert_almost_equal(ds[0][0], target1.values()[-4:-2])
-            np.testing.assert_almost_equal(ds[0][1], cov1.values()[-4:-2])
-            np.testing.assert_almost_equal(ds[0][2], self.cov_st2)
-            np.testing.assert_almost_equal(ds[0][3], target1.values()[-2:])
-
-            # Should fail if covariates are not long enough
-            target1 = TimeSeries.from_values(np.random.randn(8)).with_static_covariates(
-                self.cov_st2_df
-            )
-            cov1 = TimeSeries.from_values(np.random.randn(7))
-
-            ds = FutureCovariatesSequentialDataset(
-                target_series=[target1],
-                covariates=[cov1],
-                input_chunk_length=2,
-                output_chunk_length=2,
-            )
+        # two targets and two covariates; covariates are not aligned but must
+        # still contain the correct values
+        target1 = TimeSeries.from_values(np.random.randn(100)).with_static_covariates(
+            self.cov_st2_df
+        )
+        target2 = TimeSeries.from_values(np.random.randn(50)).with_static_covariates(
+            self.cov_st2_df
+        )
+        cov1 = TimeSeries.from_values(np.random.randn(120))
+        cov2 = TimeSeries.from_values(np.random.randn(80))
+
+        ds = FutureCovariatesSequentialDataset(
+            target_series=[target1, target2],
+            covariates=[cov1, cov2],
+            input_chunk_length=10,
+            output_chunk_length=10,
+        )
 
-            with pytest.raises(ValueError):
-                _ = ds[0]
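+        # from_values series are integer-indexed from 0; the first sample's future
+        # window (target positions 90..100) maps to cov1[90:100] == cov1[-30:-20]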
+        np.testing.assert_almost_equal(ds[0][0], target1.values()[-20:-10])
+        np.testing.assert_almost_equal(ds[0][1], cov1.values()[-30:-20])
+        np.testing.assert_almost_equal(ds[0][2], self.cov_st2)
+        np.testing.assert_almost_equal(ds[0][3], target1.values()[-10:])
+
+        np.testing.assert_almost_equal(ds[101][0], target2.values()[-40:-30])
+        np.testing.assert_almost_equal(ds[101][1], cov2.values()[-60:-50])
+        np.testing.assert_almost_equal(ds[101][2], self.cov_st2)
+        np.testing.assert_almost_equal(ds[101][3], target2.values()[-30:-20])
+
+        # Should also contain correct values when time-indexed with covariates not aligned
+        times1 = pd.date_range(start="20090201", end="20090220", freq="D")
+        times2 = pd.date_range(start="20090201", end="20090222", freq="D")
+        target1 = TimeSeries.from_times_and_values(
+            times1, np.random.randn(len(times1))
+        ).with_static_covariates(self.cov_st2_df)
+        cov1 = TimeSeries.from_times_and_values(times2, np.random.randn(len(times2)))
+
+        ds = FutureCovariatesSequentialDataset(
+            target_series=[target1],
+            covariates=[cov1],
+            input_chunk_length=2,
+            output_chunk_length=2,
+        )
 
-        def test_dual_covariates_sequential_dataset(self):
-            # Must contain (past_target, historic_future_covariates, future_covariates, future_target)
+        np.testing.assert_almost_equal(ds[0][0], target1.values()[-4:-2])
+        np.testing.assert_almost_equal(ds[0][1], cov1.values()[-4:-2])
+        np.testing.assert_almost_equal(ds[0][2], self.cov_st2)
+        np.testing.assert_almost_equal(ds[0][3], target1.values()[-2:])
 
-            # one target series
-            ds = DualCovariatesSequentialDataset(
-                target_series=self.target1,
-                input_chunk_length=10,
-                output_chunk_length=10,
-            )
-            assert len(ds) == 81
-            self._assert_eq(
-                ds[5],
-                (self.target1[75:85], None, None, self.cov_st1, self.target1[85:95]),
-            )
+        # Should fail if covariates are not long enough
+        target1 = TimeSeries.from_values(np.random.randn(8)).with_static_covariates(
+            self.cov_st2_df
+        )
+        cov1 = TimeSeries.from_values(np.random.randn(7))
 
-            # two target series
-            ds = DualCovariatesSequentialDataset(
-                target_series=[self.target1, self.target2],
-                input_chunk_length=10,
-                output_chunk_length=10,
-            )
-            assert len(ds) == 262
-            self._assert_eq(
-                ds[5],
-                (self.target1[75:85], None, None, self.cov_st1, self.target1[85:95]),
-            )
-            self._assert_eq(
-                ds[136],
-                (
-                    self.target2[125:135],
-                    None,
-                    None,
-                    self.cov_st2,
-                    self.target2[135:145],
-                ),
-            )
+        ds = FutureCovariatesSequentialDataset(
+            target_series=[target1],
+            covariates=[cov1],
+            input_chunk_length=2,
+            output_chunk_length=2,
+        )
 
-            # two target series with custom max_nr_samples
-            ds = DualCovariatesSequentialDataset(
-                target_series=[self.target1, self.target2],
-                input_chunk_length=10,
-                output_chunk_length=10,
-                max_samples_per_ts=50,
-            )
-            assert len(ds) == 100
-            self._assert_eq(
-                ds[5],
-                (self.target1[75:85], None, None, self.cov_st1, self.target1[85:95]),
-            )
-            self._assert_eq(
-                ds[55],
-                (
-                    self.target2[125:135],
-                    None,
-                    None,
-                    self.cov_st2,
-                    self.target2[135:145],
-                ),
-            )
+        with pytest.raises(ValueError):
+            _ = ds[0]
 
-            # two targets and one covariate
-            with pytest.raises(ValueError):
-                ds = DualCovariatesSequentialDataset(
-                    target_series=[self.target1, self.target2], covariates=[self.cov1]
-                )
+    def test_dual_covariates_sequential_dataset(self):
+        # Must contain (past_target, historic_future_covariates, future_covariates,
+        # static_covariates, future_target)
 
-            # two targets and two covariates; covariates not aligned, must contain correct values
-            target1 = TimeSeries.from_values(
-                np.random.randn(100)
-            ).with_static_covariates(self.cov_st2_df)
-            target2 = TimeSeries.from_values(
-                np.random.randn(50)
-            ).with_static_covariates(self.cov_st2_df)
-            cov1 = TimeSeries.from_values(np.random.randn(120))
-            cov2 = TimeSeries.from_values(np.random.randn(80))
+        # one target series
+        ds = DualCovariatesSequentialDataset(
+            target_series=self.target1,
+            input_chunk_length=10,
+            output_chunk_length=10,
+        )
+        assert len(ds) == 81
+        self._assert_eq(
+            ds[5],
+            (self.target1[75:85], None, None, self.cov_st1, self.target1[85:95]),
+        )
 
-            ds = DualCovariatesSequentialDataset(
-                target_series=[target1, target2],
-                covariates=[cov1, cov2],
-                input_chunk_length=10,
-                output_chunk_length=10,
-            )
+        # two target series
+        ds = DualCovariatesSequentialDataset(
+            target_series=[self.target1, self.target2],
+            input_chunk_length=10,
+            output_chunk_length=10,
+        )
+        assert len(ds) == 262
+        self._assert_eq(
+            ds[5],
+            (self.target1[75:85], None, None, self.cov_st1, self.target1[85:95]),
+        )
+        self._assert_eq(
+            ds[136],
+            (
+                self.target2[125:135],
+                None,
+                None,
+                self.cov_st2,
+                self.target2[135:145],
+            ),
+        )
 
-            np.testing.assert_almost_equal(ds[0][0], target1.values()[-20:-10])
-            np.testing.assert_almost_equal(ds[0][1], cov1.values()[-40:-30])
-            np.testing.assert_almost_equal(ds[0][2], cov1.values()[-30:-20])
-            np.testing.assert_almost_equal(ds[0][3], self.cov_st2)
-            np.testing.assert_almost_equal(ds[0][4], target1.values()[-10:])
-
-            np.testing.assert_almost_equal(ds[101][0], target2.values()[-40:-30])
-            np.testing.assert_almost_equal(ds[101][1], cov2.values()[-70:-60])
-            np.testing.assert_almost_equal(ds[101][2], cov2.values()[-60:-50])
-            np.testing.assert_almost_equal(ds[101][3], self.cov_st2)
-            np.testing.assert_almost_equal(ds[101][4], target2.values()[-30:-20])
-
-            # Should also contain correct values when time-indexed with covariates not aligned
-            times1 = pd.date_range(start="20090201", end="20090220", freq="D")
-            times2 = pd.date_range(start="20090201", end="20090222", freq="D")
-            target1 = TimeSeries.from_times_and_values(
-                times1, np.random.randn(len(times1))
-            ).with_static_covariates(self.cov_st2_df)
-            cov1 = TimeSeries.from_times_and_values(
-                times2, np.random.randn(len(times2))
-            )
+        # two target series with custom max_samples_per_ts
+        ds = DualCovariatesSequentialDataset(
+            target_series=[self.target1, self.target2],
+            input_chunk_length=10,
+            output_chunk_length=10,
+            max_samples_per_ts=50,
+        )
+        assert len(ds) == 100
+        self._assert_eq(
+            ds[5],
+            (self.target1[75:85], None, None, self.cov_st1, self.target1[85:95]),
+        )
+        self._assert_eq(
+            ds[55],
+            (
+                self.target2[125:135],
+                None,
+                None,
+                self.cov_st2,
+                self.target2[135:145],
+            ),
+        )
 
+        # two targets and one covariate
+        with pytest.raises(ValueError):
             ds = DualCovariatesSequentialDataset(
-                target_series=[target1],
-                covariates=[cov1],
-                input_chunk_length=2,
-                output_chunk_length=2,
+                target_series=[self.target1, self.target2], covariates=[self.cov1]
             )
 
-            np.testing.assert_almost_equal(ds[0][0], target1.values()[-4:-2])
-            np.testing.assert_almost_equal(ds[0][1], cov1.values()[-6:-4])
-            np.testing.assert_almost_equal(ds[0][2], cov1.values()[-4:-2])
-            np.testing.assert_almost_equal(ds[0][3], self.cov_st2)
-            np.testing.assert_almost_equal(ds[0][4], target1.values()[-2:])
+        # two targets and two covariates; covariates are not aligned but must
+        # still contain the correct values
+        target1 = TimeSeries.from_values(np.random.randn(100)).with_static_covariates(
+            self.cov_st2_df
+        )
+        target2 = TimeSeries.from_values(np.random.randn(50)).with_static_covariates(
+            self.cov_st2_df
+        )
+        cov1 = TimeSeries.from_values(np.random.randn(120))
+        cov2 = TimeSeries.from_values(np.random.randn(80))
+
+        ds = DualCovariatesSequentialDataset(
+            target_series=[target1, target2],
+            covariates=[cov1, cov2],
+            input_chunk_length=10,
+            output_chunk_length=10,
+        )
 
-            # Should fail if covariates are not long enough
-            target1 = TimeSeries.from_values(np.random.randn(8)).with_static_covariates(
-                self.cov_st2_df
-            )
-            cov1 = TimeSeries.from_values(np.random.randn(7))
+        np.testing.assert_almost_equal(ds[0][0], target1.values()[-20:-10])
+        np.testing.assert_almost_equal(ds[0][1], cov1.values()[-40:-30])
+        np.testing.assert_almost_equal(ds[0][2], cov1.values()[-30:-20])
+        np.testing.assert_almost_equal(ds[0][3], self.cov_st2)
+        np.testing.assert_almost_equal(ds[0][4], target1.values()[-10:])
+
+        np.testing.assert_almost_equal(ds[101][0], target2.values()[-40:-30])
+        np.testing.assert_almost_equal(ds[101][1], cov2.values()[-70:-60])
+        np.testing.assert_almost_equal(ds[101][2], cov2.values()[-60:-50])
+        np.testing.assert_almost_equal(ds[101][3], self.cov_st2)
+        np.testing.assert_almost_equal(ds[101][4], target2.values()[-30:-20])
+
+        # Should also contain correct values when time-indexed with covariates not aligned
+        times1 = pd.date_range(start="20090201", end="20090220", freq="D")
+        times2 = pd.date_range(start="20090201", end="20090222", freq="D")
+        target1 = TimeSeries.from_times_and_values(
+            times1, np.random.randn(len(times1))
+        ).with_static_covariates(self.cov_st2_df)
+        cov1 = TimeSeries.from_times_and_values(times2, np.random.randn(len(times2)))
+
+        ds = DualCovariatesSequentialDataset(
+            target_series=[target1],
+            covariates=[cov1],
+            input_chunk_length=2,
+            output_chunk_length=2,
+        )
 
-            ds = DualCovariatesSequentialDataset(
-                target_series=[target1],
-                covariates=[cov1],
-                input_chunk_length=2,
-                output_chunk_length=2,
-            )
+        np.testing.assert_almost_equal(ds[0][0], target1.values()[-4:-2])
+        np.testing.assert_almost_equal(ds[0][1], cov1.values()[-6:-4])
+        np.testing.assert_almost_equal(ds[0][2], cov1.values()[-4:-2])
+        np.testing.assert_almost_equal(ds[0][3], self.cov_st2)
+        np.testing.assert_almost_equal(ds[0][4], target1.values()[-2:])
 
-            with pytest.raises(ValueError):
-                _ = ds[0]
+        # Should fail if covariates are not long enough
+        target1 = TimeSeries.from_values(np.random.randn(8)).with_static_covariates(
+            self.cov_st2_df
+        )
+        cov1 = TimeSeries.from_values(np.random.randn(7))
 
-        def test_past_covariates_shifted_dataset(self):
-            # one target series
-            ds = PastCovariatesShiftedDataset(
-                target_series=self.target1, length=10, shift=5
-            )
-            assert len(ds) == 86
-            self._assert_eq(
-                ds[5], (self.target1[80:90], None, self.cov_st1, self.target1[85:95])
-            )
+        ds = DualCovariatesSequentialDataset(
+            target_series=[target1],
+            covariates=[cov1],
+            input_chunk_length=2,
+            output_chunk_length=2,
+        )
 
-            # two target series
-            ds = PastCovariatesShiftedDataset(
-                target_series=[self.target1, self.target2], length=10, shift=5
-            )
-            assert len(ds) == 272
-            self._assert_eq(
-                ds[5], (self.target1[80:90], None, self.cov_st1, self.target1[85:95])
-            )
-            self._assert_eq(
-                ds[141],
-                (self.target2[130:140], None, self.cov_st2, self.target2[135:145]),
-            )
+        with pytest.raises(ValueError):
+            _ = ds[0]
 
-            # two target series with custom max_nr_samples
-            ds = PastCovariatesShiftedDataset(
-                target_series=[self.target1, self.target2],
-                length=10,
-                shift=5,
-                max_samples_per_ts=50,
-            )
-            assert len(ds) == 100
-            self._assert_eq(
-                ds[5], (self.target1[80:90], None, self.cov_st1, self.target1[85:95])
-            )
-            self._assert_eq(
-                ds[55],
-                (self.target2[130:140], None, self.cov_st2, self.target2[135:145]),
-            )
+    def test_past_covariates_shifted_dataset(self):
+        # one target series
+        ds = PastCovariatesShiftedDataset(
+            target_series=self.target1, length=10, shift=5
+        )
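+        # each sample spans shift + length = 15 steps, so 100 - 15 + 1 = 86 samples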
+        assert len(ds) == 86
+        self._assert_eq(
+            ds[5], (self.target1[80:90], None, self.cov_st1, self.target1[85:95])
+        )
 
-            # two targets and one covariate
-            with pytest.raises(ValueError):
-                ds = PastCovariatesShiftedDataset(
-                    target_series=[self.target1, self.target2], covariates=[self.cov1]
-                )
+        # two target series
+        ds = PastCovariatesShiftedDataset(
+            target_series=[self.target1, self.target2], length=10, shift=5
+        )
+        assert len(ds) == 272
+        self._assert_eq(
+            ds[5], (self.target1[80:90], None, self.cov_st1, self.target1[85:95])
+        )
+        self._assert_eq(
+            ds[141],
+            (self.target2[130:140], None, self.cov_st2, self.target2[135:145]),
+        )
 
-            # two targets and two covariates
-            ds = PastCovariatesShiftedDataset(
-                target_series=[self.target1, self.target2],
-                covariates=[self.cov1, self.cov2],
-                length=10,
-                shift=5,
-            )
-            self._assert_eq(
-                ds[5],
-                (
-                    self.target1[80:90],
-                    self.cov1[80:90],
-                    self.cov_st1,
-                    self.target1[85:95],
-                ),
-            )
-            self._assert_eq(
-                ds[141],
-                (
-                    self.target2[130:140],
-                    self.cov2[130:140],
-                    self.cov_st2,
-                    self.target2[135:145],
-                ),
-            )
+        # two target series with custom max_samples_per_ts
+        ds = PastCovariatesShiftedDataset(
+            target_series=[self.target1, self.target2],
+            length=10,
+            shift=5,
+            max_samples_per_ts=50,
+        )
+        assert len(ds) == 100
+        self._assert_eq(
+            ds[5], (self.target1[80:90], None, self.cov_st1, self.target1[85:95])
+        )
+        self._assert_eq(
+            ds[55],
+            (self.target2[130:140], None, self.cov_st2, self.target2[135:145]),
+        )
 
-            # Should contain correct values even when covariates are not aligned
-            target1 = TimeSeries.from_values(np.random.randn(8)).with_static_covariates(
-                self.cov_st2_df
-            )
-            cov1 = TimeSeries.from_values(np.random.randn(10))
-            ds = PastCovariatesShiftedDataset(
-                target_series=[target1], covariates=[cov1], length=3, shift=2
-            )
-            np.testing.assert_almost_equal(ds[0][0], target1.values()[-5:-2])
-            np.testing.assert_almost_equal(ds[0][1], cov1.values()[-7:-4])
-            np.testing.assert_almost_equal(ds[0][2], self.cov_st2)
-            np.testing.assert_almost_equal(ds[0][3], target1.values()[-3:])
-
-            # Should also contain correct values when time-indexed with covariates not aligned
-            times1 = pd.date_range(start="20090201", end="20090220", freq="D")
-            times2 = pd.date_range(start="20090201", end="20090222", freq="D")
-            target1 = TimeSeries.from_times_and_values(
-                times1, np.random.randn(len(times1))
-            ).with_static_covariates(self.cov_st2_df)
-            cov1 = TimeSeries.from_times_and_values(
-                times2, np.random.randn(len(times2))
-            )
-            ds = PastCovariatesShiftedDataset(
-                target_series=[target1], covariates=[cov1], length=3, shift=2
-            )
-            np.testing.assert_almost_equal(ds[0][0], target1.values()[-5:-2])
-            np.testing.assert_almost_equal(ds[0][1], cov1.values()[-7:-4])
-            np.testing.assert_almost_equal(ds[0][2], self.cov_st2)
-            np.testing.assert_almost_equal(ds[0][3], target1.values()[-3:])
-
-            # Should fail if covariates are too short
-            target1 = TimeSeries.from_values(np.random.randn(8)).with_static_covariates(
-                self.cov_st2_df
-            )
-            cov1 = TimeSeries.from_values(np.random.randn(5))
+        # two targets and one covariate
+        with pytest.raises(ValueError):
             ds = PastCovariatesShiftedDataset(
-                target_series=[target1], covariates=[cov1], length=3, shift=2
+                target_series=[self.target1, self.target2], covariates=[self.cov1]
             )
-            with pytest.raises(ValueError):
-                _ = ds[0]
 
-        def test_future_covariates_shifted_dataset(self):
-            # one target series
-            ds = FutureCovariatesShiftedDataset(
-                target_series=self.target1, length=10, shift=5
-            )
-            assert len(ds) == 86
-            self._assert_eq(
-                ds[5], (self.target1[80:90], None, self.cov_st1, self.target1[85:95])
-            )
+        # two targets and two covariates
+        ds = PastCovariatesShiftedDataset(
+            target_series=[self.target1, self.target2],
+            covariates=[self.cov1, self.cov2],
+            length=10,
+            shift=5,
+        )
+        self._assert_eq(
+            ds[5],
+            (
+                self.target1[80:90],
+                self.cov1[80:90],
+                self.cov_st1,
+                self.target1[85:95],
+            ),
+        )
+        self._assert_eq(
+            ds[141],
+            (
+                self.target2[130:140],
+                self.cov2[130:140],
+                self.cov_st2,
+                self.target2[135:145],
+            ),
+        )
 
-            # two target series
-            ds = FutureCovariatesShiftedDataset(
-                target_series=[self.target1, self.target2], length=10, shift=5
-            )
-            assert len(ds) == 272
-            self._assert_eq(
-                ds[5], (self.target1[80:90], None, self.cov_st1, self.target1[85:95])
-            )
-            self._assert_eq(
-                ds[141],
-                (self.target2[130:140], None, self.cov_st2, self.target2[135:145]),
-            )
+        # Should contain correct values even when covariates are not aligned
+        target1 = TimeSeries.from_values(np.random.randn(8)).with_static_covariates(
+            self.cov_st2_df
+        )
+        cov1 = TimeSeries.from_values(np.random.randn(10))
+        ds = PastCovariatesShiftedDataset(
+            target_series=[target1], covariates=[cov1], length=3, shift=2
+        )
+        np.testing.assert_almost_equal(ds[0][0], target1.values()[-5:-2])
+        np.testing.assert_almost_equal(ds[0][1], cov1.values()[-7:-4])
+        np.testing.assert_almost_equal(ds[0][2], self.cov_st2)
+        np.testing.assert_almost_equal(ds[0][3], target1.values()[-3:])
+
+        # Should also contain correct values when time-indexed with covariates not aligned
+        times1 = pd.date_range(start="20090201", end="20090220", freq="D")
+        times2 = pd.date_range(start="20090201", end="20090222", freq="D")
+        target1 = TimeSeries.from_times_and_values(
+            times1, np.random.randn(len(times1))
+        ).with_static_covariates(self.cov_st2_df)
+        cov1 = TimeSeries.from_times_and_values(times2, np.random.randn(len(times2)))
+        ds = PastCovariatesShiftedDataset(
+            target_series=[target1], covariates=[cov1], length=3, shift=2
+        )
+        np.testing.assert_almost_equal(ds[0][0], target1.values()[-5:-2])
+        np.testing.assert_almost_equal(ds[0][1], cov1.values()[-7:-4])
+        np.testing.assert_almost_equal(ds[0][2], self.cov_st2)
+        np.testing.assert_almost_equal(ds[0][3], target1.values()[-3:])
+
+        # Should fail if covariates are too short
+        target1 = TimeSeries.from_values(np.random.randn(8)).with_static_covariates(
+            self.cov_st2_df
+        )
+        cov1 = TimeSeries.from_values(np.random.randn(5))
+        ds = PastCovariatesShiftedDataset(
+            target_series=[target1], covariates=[cov1], length=3, shift=2
+        )
+        with pytest.raises(ValueError):
+            _ = ds[0]
 
-            # two target series with custom max_nr_samples
-            ds = FutureCovariatesShiftedDataset(
-                target_series=[self.target1, self.target2],
-                length=10,
-                shift=5,
-                max_samples_per_ts=50,
-            )
-            assert len(ds) == 100
-            self._assert_eq(
-                ds[5], (self.target1[80:90], None, self.cov_st1, self.target1[85:95])
-            )
-            self._assert_eq(
-                ds[55],
-                (self.target2[130:140], None, self.cov_st2, self.target2[135:145]),
-            )
+    def test_future_covariates_shifted_dataset(self):
+        # one target series
+        ds = FutureCovariatesShiftedDataset(
+            target_series=self.target1, length=10, shift=5
+        )
+        assert len(ds) == 86
+        self._assert_eq(
+            ds[5], (self.target1[80:90], None, self.cov_st1, self.target1[85:95])
+        )
 
-            # two targets and one covariate
-            with pytest.raises(ValueError):
-                ds = FutureCovariatesShiftedDataset(
-                    target_series=[self.target1, self.target2], covariates=[self.cov1]
-                )
+        # two target series
+        ds = FutureCovariatesShiftedDataset(
+            target_series=[self.target1, self.target2], length=10, shift=5
+        )
+        assert len(ds) == 272
+        self._assert_eq(
+            ds[5], (self.target1[80:90], None, self.cov_st1, self.target1[85:95])
+        )
+        self._assert_eq(
+            ds[141],
+            (self.target2[130:140], None, self.cov_st2, self.target2[135:145]),
+        )
 
-            # two targets and two covariates
-            ds = FutureCovariatesShiftedDataset(
-                target_series=[self.target1, self.target2],
-                covariates=[self.cov1, self.cov2],
-                length=10,
-                shift=5,
-            )
-            self._assert_eq(
-                ds[5],
-                (
-                    self.target1[80:90],
-                    self.cov1[85:95],
-                    self.cov_st1,
-                    self.target1[85:95],
-                ),
-            )
-            self._assert_eq(
-                ds[141],
-                (
-                    self.target2[130:140],
-                    self.cov2[135:145],
-                    self.cov_st2,
-                    self.target2[135:145],
-                ),
-            )
+        # two target series with custom max_samples_per_ts
+        ds = FutureCovariatesShiftedDataset(
+            target_series=[self.target1, self.target2],
+            length=10,
+            shift=5,
+            max_samples_per_ts=50,
+        )
+        assert len(ds) == 100
+        self._assert_eq(
+            ds[5], (self.target1[80:90], None, self.cov_st1, self.target1[85:95])
+        )
+        self._assert_eq(
+            ds[55],
+            (self.target2[130:140], None, self.cov_st2, self.target2[135:145]),
+        )
 
-            # Should contain correct values even when covariates are not aligned
-            target1 = TimeSeries.from_values(np.random.randn(8)).with_static_covariates(
-                self.cov_st2_df
-            )
-            cov1 = TimeSeries.from_values(np.random.randn(10))
-            ds = FutureCovariatesShiftedDataset(
-                target_series=[target1], covariates=[cov1], length=3, shift=2
-            )
-            np.testing.assert_almost_equal(ds[0][0], target1.values()[-5:-2])
-            np.testing.assert_almost_equal(ds[0][1], cov1.values()[-5:-2])
-            np.testing.assert_almost_equal(ds[0][2], self.cov_st2)
-            np.testing.assert_almost_equal(ds[0][3], target1.values()[-3:])
-
-            # Should also contain correct values when time-indexed with covariates not aligned
-            times1 = pd.date_range(start="20090201", end="20090220", freq="D")
-            times2 = pd.date_range(start="20090201", end="20090222", freq="D")
-            target1 = TimeSeries.from_times_and_values(
-                times1, np.random.randn(len(times1))
-            ).with_static_covariates(self.cov_st2_df)
-            cov1 = TimeSeries.from_times_and_values(
-                times2, np.random.randn(len(times2))
-            )
-            ds = FutureCovariatesShiftedDataset(
-                target_series=[target1], covariates=[cov1], length=3, shift=2
-            )
-            np.testing.assert_almost_equal(ds[0][0], target1.values()[-5:-2])
-            np.testing.assert_almost_equal(ds[0][1], cov1.values()[-5:-2])
-            np.testing.assert_almost_equal(ds[0][2], self.cov_st2)
-            np.testing.assert_almost_equal(ds[0][3], target1.values()[-3:])
-
-            # Should fail if covariates are too short
-            target1 = TimeSeries.from_values(np.random.randn(8)).with_static_covariates(
-                self.cov_st2_df
-            )
-            cov1 = TimeSeries.from_values(np.random.randn(7))
+        # two targets and one covariate
+        with pytest.raises(ValueError):
             ds = FutureCovariatesShiftedDataset(
-                target_series=[target1], covariates=[cov1], length=3, shift=2
+                target_series=[self.target1, self.target2], covariates=[self.cov1]
             )
-            with pytest.raises(ValueError):
-                _ = ds[0]
 
-        def test_dual_covariates_shifted_dataset(self):
-            # one target series
-            ds = DualCovariatesShiftedDataset(
-                target_series=self.target1, length=10, shift=5
-            )
-            assert len(ds) == 86
-            self._assert_eq(
-                ds[5],
-                (self.target1[80:90], None, None, self.cov_st1, self.target1[85:95]),
-            )
+        # two targets and two covariates
+        ds = FutureCovariatesShiftedDataset(
+            target_series=[self.target1, self.target2],
+            covariates=[self.cov1, self.cov2],
+            length=10,
+            shift=5,
+        )
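+        # future covariates are aligned with the shifted output chunk ([85:95]),
+        # not with the input chunk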
+        self._assert_eq(
+            ds[5],
+            (
+                self.target1[80:90],
+                self.cov1[85:95],
+                self.cov_st1,
+                self.target1[85:95],
+            ),
+        )
+        self._assert_eq(
+            ds[141],
+            (
+                self.target2[130:140],
+                self.cov2[135:145],
+                self.cov_st2,
+                self.target2[135:145],
+            ),
+        )
 
-            # two target series
-            ds = DualCovariatesShiftedDataset(
-                target_series=[self.target1, self.target2], length=10, shift=5
-            )
-            assert len(ds) == 272
-            self._assert_eq(
-                ds[5],
-                (self.target1[80:90], None, None, self.cov_st1, self.target1[85:95]),
-            )
-            self._assert_eq(
-                ds[141],
-                (
-                    self.target2[130:140],
-                    None,
-                    None,
-                    self.cov_st2,
-                    self.target2[135:145],
-                ),
-            )
+        # Should contain correct values even when covariates are not aligned
+        target1 = TimeSeries.from_values(np.random.randn(8)).with_static_covariates(
+            self.cov_st2_df
+        )
+        cov1 = TimeSeries.from_values(np.random.randn(10))
+        ds = FutureCovariatesShiftedDataset(
+            target_series=[target1], covariates=[cov1], length=3, shift=2
+        )
+        np.testing.assert_almost_equal(ds[0][0], target1.values()[-5:-2])
+        np.testing.assert_almost_equal(ds[0][1], cov1.values()[-5:-2])
+        np.testing.assert_almost_equal(ds[0][2], self.cov_st2)
+        np.testing.assert_almost_equal(ds[0][3], target1.values()[-3:])
+
+        # Should also contain correct values when time-indexed with covariates not aligned
+        times1 = pd.date_range(start="20090201", end="20090220", freq="D")
+        times2 = pd.date_range(start="20090201", end="20090222", freq="D")
+        target1 = TimeSeries.from_times_and_values(
+            times1, np.random.randn(len(times1))
+        ).with_static_covariates(self.cov_st2_df)
+        cov1 = TimeSeries.from_times_and_values(times2, np.random.randn(len(times2)))
+        ds = FutureCovariatesShiftedDataset(
+            target_series=[target1], covariates=[cov1], length=3, shift=2
+        )
+        np.testing.assert_almost_equal(ds[0][0], target1.values()[-5:-2])
+        np.testing.assert_almost_equal(ds[0][1], cov1.values()[-5:-2])
+        np.testing.assert_almost_equal(ds[0][2], self.cov_st2)
+        np.testing.assert_almost_equal(ds[0][3], target1.values()[-3:])
+
+        # Should fail if covariates are too short
+        target1 = TimeSeries.from_values(np.random.randn(8)).with_static_covariates(
+            self.cov_st2_df
+        )
+        cov1 = TimeSeries.from_values(np.random.randn(7))
+        ds = FutureCovariatesShiftedDataset(
+            target_series=[target1], covariates=[cov1], length=3, shift=2
+        )
+        with pytest.raises(ValueError):
+            _ = ds[0]
 
-            # two target series with custom max_nr_samples
-            ds = DualCovariatesShiftedDataset(
-                target_series=[self.target1, self.target2],
-                length=10,
-                shift=5,
-                max_samples_per_ts=50,
-            )
-            assert len(ds) == 100
-            self._assert_eq(
-                ds[5],
-                (self.target1[80:90], None, None, self.cov_st1, self.target1[85:95]),
-            )
-            self._assert_eq(
-                ds[55],
-                (
-                    self.target2[130:140],
-                    None,
-                    None,
-                    self.cov_st2,
-                    self.target2[135:145],
-                ),
-            )
+    def test_dual_covariates_shifted_dataset(self):
+        # one target series
+        ds = DualCovariatesShiftedDataset(
+            target_series=self.target1, length=10, shift=5
+        )
+        assert len(ds) == 86
+        self._assert_eq(
+            ds[5],
+            (self.target1[80:90], None, None, self.cov_st1, self.target1[85:95]),
+        )
 
-            # two targets and one covariate
-            with pytest.raises(ValueError):
-                ds = DualCovariatesShiftedDataset(
-                    target_series=[self.target1, self.target2], covariates=[self.cov1]
-                )
+        # two target series
+        ds = DualCovariatesShiftedDataset(
+            target_series=[self.target1, self.target2], length=10, shift=5
+        )
+        assert len(ds) == 272
+        self._assert_eq(
+            ds[5],
+            (self.target1[80:90], None, None, self.cov_st1, self.target1[85:95]),
+        )
+        self._assert_eq(
+            ds[141],
+            (
+                self.target2[130:140],
+                None,
+                None,
+                self.cov_st2,
+                self.target2[135:145],
+            ),
+        )
 
-            # two targets and two covariates
-            ds = DualCovariatesShiftedDataset(
-                target_series=[self.target1, self.target2],
-                covariates=[self.cov1, self.cov2],
-                length=10,
-                shift=5,
-            )
-            self._assert_eq(
-                ds[5],
-                (
-                    self.target1[80:90],
-                    self.cov1[80:90],
-                    self.cov1[85:95],
-                    self.cov_st1,
-                    self.target1[85:95],
-                ),
-            )
-            self._assert_eq(
-                ds[141],
-                (
-                    self.target2[130:140],
-                    self.cov2[130:140],
-                    self.cov2[135:145],
-                    self.cov_st2,
-                    self.target2[135:145],
-                ),
-            )
+        # two target series with custom max_samples_per_ts
+        ds = DualCovariatesShiftedDataset(
+            target_series=[self.target1, self.target2],
+            length=10,
+            shift=5,
+            max_samples_per_ts=50,
+        )
+        assert len(ds) == 100
+        self._assert_eq(
+            ds[5],
+            (self.target1[80:90], None, None, self.cov_st1, self.target1[85:95]),
+        )
+        self._assert_eq(
+            ds[55],
+            (
+                self.target2[130:140],
+                None,
+                None,
+                self.cov_st2,
+                self.target2[135:145],
+            ),
+        )
 
-            # Should contain correct values even when covariates are not aligned
-            target1 = TimeSeries.from_values(np.random.randn(8)).with_static_covariates(
-                self.cov_st2_df
-            )
-            cov1 = TimeSeries.from_values(np.random.randn(10))
-            ds = DualCovariatesShiftedDataset(
-                target_series=[target1], covariates=[cov1], length=3, shift=2
-            )
-            np.testing.assert_almost_equal(ds[0][0], target1.values()[-5:-2])
-            np.testing.assert_almost_equal(ds[0][1], cov1.values()[-7:-4])
-            np.testing.assert_almost_equal(ds[0][2], cov1.values()[-5:-2])
-            np.testing.assert_almost_equal(ds[0][3], self.cov_st2)
-            np.testing.assert_almost_equal(ds[0][4], target1.values()[-3:])
-
-            # Should also contain correct values when time-indexed with covariates not aligned
-            times1 = pd.date_range(start="20090201", end="20090220", freq="D")
-            times2 = pd.date_range(start="20090201", end="20090222", freq="D")
-            target1 = TimeSeries.from_times_and_values(
-                times1, np.random.randn(len(times1))
-            ).with_static_covariates(self.cov_st2_df)
-            cov1 = TimeSeries.from_times_and_values(
-                times2, np.random.randn(len(times2))
-            )
+        # two targets and one covariate
+        with pytest.raises(ValueError):
             ds = DualCovariatesShiftedDataset(
-                target_series=[target1], covariates=[cov1], length=3, shift=2
+                target_series=[self.target1, self.target2], covariates=[self.cov1]
             )
-            np.testing.assert_almost_equal(ds[0][0], target1.values()[-5:-2])
-            np.testing.assert_almost_equal(ds[0][1], cov1.values()[-7:-4])
-            np.testing.assert_almost_equal(ds[0][2], cov1.values()[-5:-2])
-            np.testing.assert_almost_equal(ds[0][3], self.cov_st2)
-            np.testing.assert_almost_equal(ds[0][4], target1.values()[-3:])
-
-            # Should fail if covariates are too short
-            target1 = TimeSeries.from_values(np.random.randn(8)).with_static_covariates(
-                self.cov_st2_df
-            )
-            cov1 = TimeSeries.from_values(np.random.randn(7))
-            ds = DualCovariatesShiftedDataset(
-                target_series=[target1], covariates=[cov1], length=3, shift=2
-            )
-            with pytest.raises(ValueError):
-                _ = ds[0]
 
-        def test_horizon_based_dataset(self):
-            # one target series
-            ds = HorizonBasedDataset(
-                target_series=self.target1,
-                output_chunk_length=10,
-                lh=(1, 3),
-                lookback=2,
-            )
-            assert len(ds) == 20
-            self._assert_eq(
-                ds[5], (self.target1[65:85], None, self.cov_st1, self.target1[85:95])
-            )
+        # two targets and two covariates
+        ds = DualCovariatesShiftedDataset(
+            target_series=[self.target1, self.target2],
+            covariates=[self.cov1, self.cov2],
+            length=10,
+            shift=5,
+        )
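+        # dual covariates return both the historic slice (aligned with the input,
+        # [80:90]) and the future slice (aligned with the shifted output, [85:95])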
+        self._assert_eq(
+            ds[5],
+            (
+                self.target1[80:90],
+                self.cov1[80:90],
+                self.cov1[85:95],
+                self.cov_st1,
+                self.target1[85:95],
+            ),
+        )
+        self._assert_eq(
+            ds[141],
+            (
+                self.target2[130:140],
+                self.cov2[130:140],
+                self.cov2[135:145],
+                self.cov_st2,
+                self.target2[135:145],
+            ),
+        )
 
-            # two target series
-            ds = HorizonBasedDataset(
-                target_series=[self.target1, self.target2],
-                output_chunk_length=10,
-                lh=(1, 3),
-                lookback=2,
-            )
-            assert len(ds) == 40
-            self._assert_eq(
-                ds[5], (self.target1[65:85], None, self.cov_st1, self.target1[85:95])
-            )
-            self._assert_eq(
-                ds[25],
-                (self.target2[115:135], None, self.cov_st2, self.target2[135:145]),
-            )
+        # Should contain correct values even when covariates are not aligned
+        target1 = TimeSeries.from_values(np.random.randn(8)).with_static_covariates(
+            self.cov_st2_df
+        )
+        cov1 = TimeSeries.from_values(np.random.randn(10))
+        ds = DualCovariatesShiftedDataset(
+            target_series=[target1], covariates=[cov1], length=3, shift=2
+        )
+        np.testing.assert_almost_equal(ds[0][0], target1.values()[-5:-2])
+        np.testing.assert_almost_equal(ds[0][1], cov1.values()[-7:-4])
+        np.testing.assert_almost_equal(ds[0][2], cov1.values()[-5:-2])
+        np.testing.assert_almost_equal(ds[0][3], self.cov_st2)
+        np.testing.assert_almost_equal(ds[0][4], target1.values()[-3:])
+
+        # Should also contain correct values when time-indexed with covariates not aligned
+        times1 = pd.date_range(start="20090201", end="20090220", freq="D")
+        times2 = pd.date_range(start="20090201", end="20090222", freq="D")
+        target1 = TimeSeries.from_times_and_values(
+            times1, np.random.randn(len(times1))
+        ).with_static_covariates(self.cov_st2_df)
+        cov1 = TimeSeries.from_times_and_values(times2, np.random.randn(len(times2)))
+        ds = DualCovariatesShiftedDataset(
+            target_series=[target1], covariates=[cov1], length=3, shift=2
+        )
+        np.testing.assert_almost_equal(ds[0][0], target1.values()[-5:-2])
+        np.testing.assert_almost_equal(ds[0][1], cov1.values()[-7:-4])
+        np.testing.assert_almost_equal(ds[0][2], cov1.values()[-5:-2])
+        np.testing.assert_almost_equal(ds[0][3], self.cov_st2)
+        np.testing.assert_almost_equal(ds[0][4], target1.values()[-3:])
+
+        # Should fail if covariates are too short
+        target1 = TimeSeries.from_values(np.random.randn(8)).with_static_covariates(
+            self.cov_st2_df
+        )
+        cov1 = TimeSeries.from_values(np.random.randn(7))
+        ds = DualCovariatesShiftedDataset(
+            target_series=[target1], covariates=[cov1], length=3, shift=2
+        )
+        with pytest.raises(ValueError):
+            _ = ds[0]
+
+    def test_horizon_based_dataset(self):
+        # one target series
+        ds = HorizonBasedDataset(
+            target_series=self.target1,
+            output_chunk_length=10,
+            lh=(1, 3),
+            lookback=2,
+        )
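+        # per series: (lh_high - lh_low) * output_chunk_length = 2 * 10 = 20 samples;
+        # the input spans lookback * output_chunk_length = 20 steps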
+        assert len(ds) == 20
+        self._assert_eq(
+            ds[5], (self.target1[65:85], None, self.cov_st1, self.target1[85:95])
+        )
 
-            # two targets and one covariate
-            with pytest.raises(ValueError):
-                ds = HorizonBasedDataset(
-                    target_series=[self.target1, self.target2], covariates=[self.cov1]
-                )
+        # two target series
+        ds = HorizonBasedDataset(
+            target_series=[self.target1, self.target2],
+            output_chunk_length=10,
+            lh=(1, 3),
+            lookback=2,
+        )
+        assert len(ds) == 40
+        self._assert_eq(
+            ds[5], (self.target1[65:85], None, self.cov_st1, self.target1[85:95])
+        )
+        self._assert_eq(
+            ds[25],
+            (self.target2[115:135], None, self.cov_st2, self.target2[135:145]),
+        )
 
-            # two targets and two covariates
+        # two targets and one covariate
+        with pytest.raises(ValueError):
             ds = HorizonBasedDataset(
-                target_series=[self.target1, self.target2],
-                covariates=[self.cov1, self.cov2],
-                output_chunk_length=10,
-                lh=(1, 3),
-                lookback=2,
-            )
-            self._assert_eq(
-                ds[5],
-                (
-                    self.target1[65:85],
-                    self.cov1[65:85],
-                    self.cov_st1,
-                    self.target1[85:95],
-                ),
-            )
-            self._assert_eq(
-                ds[25],
-                (
-                    self.target2[115:135],
-                    self.cov2[115:135],
-                    self.cov_st2,
-                    self.target2[135:145],
-                ),
+                target_series=[self.target1, self.target2], covariates=[self.cov1]
             )
 
-        @pytest.mark.parametrize(
-            "config",
-            [
-                # (dataset class, whether contains future, future batch index)
-                (PastCovariatesSequentialDataset, None),
-                (FutureCovariatesSequentialDataset, 1),
-                (DualCovariatesSequentialDataset, 2),
-                (MixedCovariatesSequentialDataset, 3),
-                (SplitCovariatesSequentialDataset, 2),
-            ],
-        )
-        def test_sequential_training_dataset_output_chunk_shift(self, config):
-            ds_cls, future_idx = config
-            ocl = 1
-            ocs = 2
-            target = self.target1[: -(ocl + ocs)]
-
-            ds_covs = {}
-            ds_init_params = set(inspect.signature(ds_cls.__init__).parameters)
-            for cov_type in ["covariates", "past_covariates", "future_covariates"]:
-                if cov_type in ds_init_params:
-                    ds_covs[cov_type] = self.cov1
-
-            # regular dataset with output shift=0 and ocl=3: the 3rd future values should be identical to the 1st future
-            # values of a dataset with output shift=2 and ocl=1
-            ds_reg = ds_cls(
-                target_series=target,
-                input_chunk_length=1,
-                output_chunk_length=3,
-                output_chunk_shift=0,
-                **ds_covs,
-            )
+        # two targets and two covariates
+        ds = HorizonBasedDataset(
+            target_series=[self.target1, self.target2],
+            covariates=[self.cov1, self.cov2],
+            output_chunk_length=10,
+            lh=(1, 3),
+            lookback=2,
+        )
+        self._assert_eq(
+            ds[5],
+            (
+                self.target1[65:85],
+                self.cov1[65:85],
+                self.cov_st1,
+                self.target1[85:95],
+            ),
+        )
+        self._assert_eq(
+            ds[25],
+            (
+                self.target2[115:135],
+                self.cov2[115:135],
+                self.cov_st2,
+                self.target2[135:145],
+            ),
+        )
 
-            ds_shift = ds_cls(
-                target_series=target,
-                input_chunk_length=1,
-                output_chunk_length=1,
-                output_chunk_shift=ocs,
-                **ds_covs,
-            )
+    @pytest.mark.parametrize(
+        "config",
+        [
+            # (dataset class, batch index of the future covariates; None if absent)
+            (PastCovariatesSequentialDataset, None),
+            (FutureCovariatesSequentialDataset, 1),
+            (DualCovariatesSequentialDataset, 2),
+            (MixedCovariatesSequentialDataset, 3),
+            (SplitCovariatesSequentialDataset, 2),
+        ],
+    )
+    def test_sequential_training_dataset_output_chunk_shift(self, config):
+        ds_cls, future_idx = config
+        ocl = 1
+        ocs = 2
+        target = self.target1[: -(ocl + ocs)]
+
+        ds_covs = {}
+        ds_init_params = set(inspect.signature(ds_cls.__init__).parameters)
+        for cov_type in ["covariates", "past_covariates", "future_covariates"]:
+            if cov_type in ds_init_params:
+                ds_covs[cov_type] = self.cov1
+
+        # regular dataset with output shift=0 and ocl=3: the 3rd future values should be identical to the 1st future
+        # values of a dataset with output shift=2 and ocl=1
+        ds_reg = ds_cls(
+            target_series=target,
+            input_chunk_length=1,
+            output_chunk_length=3,
+            output_chunk_shift=0,
+            **ds_covs,
+        )
 
-            batch_reg, batch_shift = ds_reg[0], ds_shift[0]
+        ds_shift = ds_cls(
+            target_series=target,
+            input_chunk_length=1,
+            output_chunk_length=1,
+            output_chunk_shift=ocs,
+            **ds_covs,
+        )
 
-            if future_idx is not None:
-                # 3rd future values of regular ds must be identical to the 1st future values of shifted dataset
-                np.testing.assert_array_equal(
-                    batch_reg[future_idx][-1:], batch_shift[future_idx]
-                )
-                batch_reg = batch_reg[:future_idx] + batch_reg[future_idx + 1 :]
-                batch_shift = batch_shift[:future_idx] + batch_shift[future_idx + 1 :]
+        batch_reg, batch_shift = ds_reg[0], ds_shift[0]
 
-            # last element is the output chunk of the target series.
+        if future_idx is not None:
             # 3rd future values of regular ds must be identical to the 1st future values of shifted dataset
-            batch_reg = batch_reg[:-1] + (batch_reg[-1][ocs:],)
-
-            # without future part, the input will be identical between regular, and shifted dataset
-            assert all(
-                [
-                    np.all(el_reg == el_shift)
-                    for el_reg, el_shift in zip(batch_reg[:-1], batch_shift[:-1])
-                ]
+            np.testing.assert_array_equal(
+                batch_reg[future_idx][-1:], batch_shift[future_idx]
             )
+            batch_reg = batch_reg[:future_idx] + batch_reg[future_idx + 1 :]
+            batch_shift = batch_shift[:future_idx] + batch_shift[future_idx + 1 :]
+
+        # last element is the output chunk of the target series.
+        # 3rd future values of regular ds must be identical to the 1st future values of shifted dataset
+        batch_reg = batch_reg[:-1] + (batch_reg[-1][ocs:],)
+
+        # without the future part, the input must be identical between the regular and shifted datasets
+        assert all(
+            [
+                np.all(el_reg == el_shift)
+                for el_reg, el_shift in zip(batch_reg[:-1], batch_shift[:-1])
+            ]
+        )
 
-        def test_get_matching_index(self):
-            from darts.utils.data.utils import _get_matching_index
-
-            # Check dividable freq
-            times1 = pd.date_range(start="20100101", end="20100330", freq="D")
-            times2 = pd.date_range(start="20100101", end="20100320", freq="D")
-            target = TimeSeries.from_times_and_values(
-                times1, np.random.randn(len(times1))
-            ).with_static_covariates(self.cov_st2_df)
-            cov = TimeSeries.from_times_and_values(times2, np.random.randn(len(times2)))
-            assert _get_matching_index(target, cov, idx=15) == 5
-
-            # check non-dividable freq
-            times1 = pd.date_range(start="20100101", end="20120101", freq="M")
-            times2 = pd.date_range(start="20090101", end="20110601", freq="M")
-            target = TimeSeries.from_times_and_values(
-                times1, np.random.randn(len(times1))
-            ).with_static_covariates(self.cov_st2_df)
-            cov = TimeSeries.from_times_and_values(times2, np.random.randn(len(times2)))
-            assert _get_matching_index(target, cov, idx=15) == 15 - 7
-
-            # check integer-indexed series
-            times2 = pd.RangeIndex(start=10, stop=90)
-            target = TimeSeries.from_values(
-                np.random.randn(100)
-            ).with_static_covariates(self.cov_st2_df)
-            cov = TimeSeries.from_times_and_values(times2, np.random.randn(len(times2)))
-            assert _get_matching_index(target, cov, idx=15) == 5
+    def test_get_matching_index(self):
+        from darts.utils.data.utils import _get_matching_index
+
+        # check divisible frequency
+        times1 = pd.date_range(start="20100101", end="20100330", freq="D")
+        times2 = pd.date_range(start="20100101", end="20100320", freq="D")
+        target = TimeSeries.from_times_and_values(
+            times1, np.random.randn(len(times1))
+        ).with_static_covariates(self.cov_st2_df)
+        cov = TimeSeries.from_times_and_values(times2, np.random.randn(len(times2)))
+        assert _get_matching_index(target, cov, idx=15) == 5
+
+        # check non-divisible frequency
+        times1 = pd.date_range(start="20100101", end="20120101", freq="M")
+        times2 = pd.date_range(start="20090101", end="20110601", freq="M")
+        target = TimeSeries.from_times_and_values(
+            times1, np.random.randn(len(times1))
+        ).with_static_covariates(self.cov_st2_df)
+        cov = TimeSeries.from_times_and_values(times2, np.random.randn(len(times2)))
+        assert _get_matching_index(target, cov, idx=15) == 15 - 7
+
+        # check integer-indexed series
+        times2 = pd.RangeIndex(start=10, stop=90)
+        target = TimeSeries.from_values(np.random.randn(100)).with_static_covariates(
+            self.cov_st2_df
+        )
+        cov = TimeSeries.from_times_and_values(times2, np.random.randn(len(times2)))
+        assert _get_matching_index(target, cov, idx=15) == 5
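
Note: the output-chunk-shift test above builds its covariate kwargs dynamically with `inspect.signature`, so one parametrized body serves dataset classes accepting `covariates`, `past_covariates`, or `future_covariates`. A minimal sketch of the idiom (the helper name is hypothetical):

    import inspect

    def covariate_kwargs(ds_cls, cov):
        # pass `cov` under every covariate parameter the constructor accepts
        params = set(inspect.signature(ds_cls.__init__).parameters)
        return {
            name: cov
            for name in ("covariates", "past_covariates", "future_covariates")
            if name in params
        }

    # usage, mirroring the test:
    # ds = ds_cls(target_series=target, input_chunk_length=1,
    #             output_chunk_length=3, output_chunk_shift=0,
    #             **covariate_kwargs(ds_cls, cov))
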
diff --git a/darts/tests/explainability/test_tft_explainer.py b/darts/tests/explainability/test_tft_explainer.py
index 8bea86f6b5..c1cd930977 100644
--- a/darts/tests/explainability/test_tft_explainer.py
+++ b/darts/tests/explainability/test_tft_explainer.py
@@ -16,462 +16,443 @@
 try:
     from darts.explainability import TFTExplainabilityResult, TFTExplainer
     from darts.models import TFTModel
-
-    TORCH_AVAILABLE = True
 except ImportError:
-    logger.warning("Torch not available. RNN tests will be skipped.")
-    TORCH_AVAILABLE = False
-
-
-if TORCH_AVAILABLE:
-
-    def helper_create_test_cases(series_options: list):
-        covariates_options = [
-            {},
-            {"past_covariates"},
-            {"future_covariates"},
-            {"past_covariates", "future_covariates"},
+    pytest.skip(
+        f"Torch not available. {__name__} tests will be skipped.",
+        allow_module_level=True,
+    )
+
+
+def helper_create_test_cases(series_options: list):
+    covariates_options = [
+        {},
+        {"past_covariates"},
+        {"future_covariates"},
+        {"past_covariates", "future_covariates"},
+    ]
+    relative_index_options = [False, True]
+    use_encoders_options = [False, True]
+    return itertools.product(
+        *[
+            series_options,
+            covariates_options,
+            relative_index_options,
+            use_encoders_options,
         ]
-        relative_index_options = [False, True]
-        use_encoders_options = [False, True]
-        return itertools.product(
-            *[
-                series_options,
-                covariates_options,
-                relative_index_options,
-                use_encoders_options,
+    )
+
+
+class TestTFTExplainer:
+    freq = "MS"
+    series_lin_pos = tg.linear_timeseries(length=10, freq=freq).with_static_covariates(
+        pd.Series([0.0, 0.5], index=["cat", "num"])
+    )
+    series_sine = tg.sine_timeseries(length=10, freq=freq)
+    series_mv1 = series_lin_pos.stack(series_sine)
+
+    series_lin_neg = tg.linear_timeseries(
+        start_value=1, end_value=0, length=10, freq=freq
+    ).with_static_covariates(pd.Series([1.0, 0.5], index=["cat", "num"]))
+    series_cos = tg.sine_timeseries(length=10, value_phase=90, freq=freq)
+    series_mv2 = series_lin_neg.stack(series_cos)
+
+    series_multi = [series_mv1, series_mv2]
+    pc = tg.constant_timeseries(length=10, freq=freq)
+    pc_multi = [pc] * 2
+    fc = tg.constant_timeseries(length=13, freq=freq)
+    fc_multi = [fc] * 2
+
+    def helper_get_input(self, series_option: str):
+        if series_option == "univariate":
+            return self.series_lin_pos, self.pc, self.fc
+        elif series_option == "multivariate":
+            return self.series_mv1, self.pc, self.fc
+        else:  # multiple
+            return self.series_multi, self.pc_multi, self.fc_multi
+
+    @pytest.mark.parametrize(
+        "test_case", helper_create_test_cases(["univariate", "multivariate"])
+    )
+    def test_explainer_single_univariate_multivariate_series(self, test_case):
+        """Test TFTExplainer with single univariate and multivariate series and a combination of
+        encoders, covariates, and addition of relative index."""
+        series_option, cov_option, add_relative_idx, use_encoders = test_case
+        series, pc, fc = self.helper_get_input(series_option)
+        cov_test_case = dict()
+        use_pc, use_fc = False, False
+        if "past_covariates" in cov_option:
+            cov_test_case["past_covariates"] = pc
+            use_pc = True
+        if "future_covariates" in cov_option:
+            cov_test_case["future_covariates"] = fc
+            use_fc = True
+
+        # expected number of features for past, future, and static covariates, and for the encoder/decoder
+        n_target_expected = series.n_components
+        n_pc_expected = 1 if "past_covariates" in cov_test_case else 0
+        n_fc_expected = 1 if "future_covariates" in cov_test_case else 0
+        n_sc_expected = 2
+        # encoder input is the number of past and future covs plus 4 optional encodings (past and future)
+        # plus the target components plus 1 optional relative index
+        n_enc_expected = (
+            n_pc_expected
+            + n_fc_expected
+            + n_target_expected
+            + (4 if use_encoders else 0)
+            + (1 if add_relative_idx else 0)
+        )
+        # decoder input is the number of future covs plus 2 optional encodings (future)
+        # plus 1 optional relative index
+        n_dec_expected = (
+            n_fc_expected + (2 if use_encoders else 0) + (1 if add_relative_idx else 0)
+        )
+        model = self.helper_create_model(
+            use_encoders=use_encoders, add_relative_idx=add_relative_idx
+        )
+        # TFTModel requires future covariates
+        if (
+            not add_relative_idx
+            and "future_covariates" not in cov_test_case
+            and not use_encoders
+        ):
+            with pytest.raises(ValueError):
+                model.fit(series=series, **cov_test_case)
+            return
+
+        model.fit(series=series, **cov_test_case)
+        explainer = TFTExplainer(model)
+        explainer2 = TFTExplainer(
+            model,
+            background_series=series,
+            background_past_covariates=pc if use_pc else None,
+            background_future_covariates=fc if use_fc else None,
+        )
+        assert explainer.background_series == explainer2.background_series
+        assert (
+            explainer.background_past_covariates
+            == explainer2.background_past_covariates
+        )
+        assert (
+            explainer.background_future_covariates
+            == explainer2.background_future_covariates
+        )
+
+        assert hasattr(explainer, "model")
+        assert explainer.background_series[0] == series
+        if use_pc:
+            assert explainer.background_past_covariates[0] == pc
+            assert explainer.background_past_covariates[0].n_components == n_pc_expected
+        else:
+            assert explainer.background_past_covariates is None
+        if use_fc:
+            assert explainer.background_future_covariates[0] == fc
+            assert (
+                explainer.background_future_covariates[0].n_components == n_fc_expected
+            )
+        else:
+            assert explainer.background_future_covariates is None
+        result = explainer.explain()
+        assert isinstance(result, TFTExplainabilityResult)
+
+        enc_imp = result.get_encoder_importance()
+        dec_imp = result.get_decoder_importance()
+        stc_imp = result.get_static_covariates_importance()
+        imps = [enc_imp, dec_imp, stc_imp]
+        assert all([isinstance(imp, pd.DataFrame) for imp in imps])
+        # importances must sum up to 100 percent
+        assert all(
+            [imp.squeeze().sum() == pytest.approx(100.0, rel=0.2) for imp in imps]
+        )
+        # importances must have the expected number of columns
+        assert all(
+            [
+                len(imp.columns) == n
+                for imp, n in zip(imps, [n_enc_expected, n_dec_expected, n_sc_expected])
             ]
         )
 
-    class TestTFTExplainer:
-        freq = "MS"
-        series_lin_pos = tg.linear_timeseries(
-            length=10, freq=freq
-        ).with_static_covariates(pd.Series([0.0, 0.5], index=["cat", "num"]))
-        series_sine = tg.sine_timeseries(length=10, freq=freq)
-        series_mv1 = series_lin_pos.stack(series_sine)
-
-        series_lin_neg = tg.linear_timeseries(
-            start_value=1, end_value=0, length=10, freq=freq
-        ).with_static_covariates(pd.Series([1.0, 0.5], index=["cat", "num"]))
-        series_cos = tg.sine_timeseries(length=10, value_phase=90, freq=freq)
-        series_mv2 = series_lin_neg.stack(series_cos)
-
-        series_multi = [series_mv1, series_mv2]
-        pc = tg.constant_timeseries(length=10, freq=freq)
-        pc_multi = [pc] * 2
-        fc = tg.constant_timeseries(length=13, freq=freq)
-        fc_multi = [fc] * 2
-
-        def helper_get_input(self, series_option: str):
-            if series_option == "univariate":
-                return self.series_lin_pos, self.pc, self.fc
-            elif series_option == "multivariate":
-                return self.series_mv1, self.pc, self.fc
-            else:  # multiple
-                return self.series_multi, self.pc_multi, self.fc_multi
-
-        @pytest.mark.parametrize(
-            "test_case", helper_create_test_cases(["univariate", "multivariate"])
+        attention = result.get_attention()
+        assert isinstance(attention, TimeSeries)
+        # input chunk length + output chunk length = 5 + 2 = 7
+        icl, ocl = 5, 2
+        freq = series.freq
+        assert len(attention) == icl + ocl
+        assert attention.start_time() == series.end_time() - (icl - 1) * freq
+        assert attention.end_time() == series.end_time() + ocl * freq
+        assert attention.n_components == ocl
+
+    @pytest.mark.parametrize("test_case", helper_create_test_cases(["multiple"]))
+    def test_explainer_multiple_multivariate_series(self, test_case):
+        """Test TFTExplainer with multiple multivaraites series and a combination of encoders, covariates,
+        and addition of relative index."""
+        series_option, cov_option, add_relative_idx, use_encoders = test_case
+        series, pc, fc = self.helper_get_input(series_option)
+        cov_test_case = dict()
+        use_pc, use_fc = False, False
+        if "past_covariates" in cov_option:
+            cov_test_case["past_covariates"] = pc
+            use_pc = True
+        if "future_covariates" in cov_option:
+            cov_test_case["future_covariates"] = fc
+            use_fc = True
+
+        # expected number of features for past, future, and static covariates, and for the encoder/decoder
+        n_target_expected = series[0].n_components
+        n_pc_expected = 1 if "past_covariates" in cov_test_case else 0
+        n_fc_expected = 1 if "future_covariates" in cov_test_case else 0
+        n_sc_expected = 2
+        # encoder input is the number of past and future covs plus 4 optional encodings (past and future)
+        # plus the target components plus 1 optional relative index
+        n_enc_expected = (
+            n_pc_expected
+            + n_fc_expected
+            + n_target_expected
+            + (4 if use_encoders else 0)
+            + (1 if add_relative_idx else 0)
         )
-        def test_explainer_single_univariate_multivariate_series(self, test_case):
-            """Test TFTExplainer with single univariate and multivariate series and a combination of
-            encoders, covariates, and addition of relative index."""
-            series_option, cov_option, add_relative_idx, use_encoders = test_case
-            series, pc, fc = self.helper_get_input(series_option)
-            cov_test_case = dict()
-            use_pc, use_fc = False, False
-            if "past_covariates" in cov_option:
-                cov_test_case["past_covariates"] = pc
-                use_pc = True
-            if "future_covariates" in cov_option:
-                cov_test_case["future_covariates"] = fc
-                use_fc = True
-
-            # expected number of features for past covs, future covs, and static covs, and encoder/decoder
-            n_target_expected = series.n_components
-            n_pc_expected = 1 if "past_covariates" in cov_test_case else 0
-            n_fc_expected = 1 if "future_covariates" in cov_test_case else 0
-            n_sc_expected = 2
-            # encoder is number of past and future covs plus 4 optional encodings (future and past)
-            # plus 1 univariate target plus 1 optional relative index
-            n_enc_expected = (
-                n_pc_expected
-                + n_fc_expected
-                + n_target_expected
-                + (4 if use_encoders else 0)
-                + (1 if add_relative_idx else 0)
-            )
-            # encoder is number of future covs plus 2 optional encodings (future)
-            # plus 1 optional relative index
-            n_dec_expected = (
-                n_fc_expected
-                + (2 if use_encoders else 0)
-                + (1 if add_relative_idx else 0)
-            )
-            model = self.helper_create_model(
-                use_encoders=use_encoders, add_relative_idx=add_relative_idx
-            )
-            # TFTModel requires future covariates
-            if (
-                not add_relative_idx
-                and "future_covariates" not in cov_test_case
-                and not use_encoders
-            ):
-                with pytest.raises(ValueError):
-                    model.fit(series=series, **cov_test_case)
-                return
-
-            model.fit(series=series, **cov_test_case)
+        # decoder input is the number of future covs plus 2 optional encodings (future)
+        # plus 1 optional relative index
+        n_dec_expected = (
+            n_fc_expected + (2 if use_encoders else 0) + (1 if add_relative_idx else 0)
+        )
+        model = self.helper_create_model(
+            use_encoders=use_encoders, add_relative_idx=add_relative_idx
+        )
+        # TFTModel requires future covariates
+        if (
+            not add_relative_idx
+            and "future_covariates" not in cov_test_case
+            and not use_encoders
+        ):
+            with pytest.raises(ValueError):
+                model.fit(series=series, **cov_test_case)
+            return
+
+        model.fit(series=series, **cov_test_case)
+        # explainer requires a background if the model was trained on multiple time series
+        with pytest.raises(ValueError):
             explainer = TFTExplainer(model)
-            explainer2 = TFTExplainer(
-                model,
-                background_series=series,
-                background_past_covariates=pc if use_pc else None,
-                background_future_covariates=fc if use_fc else None,
-            )
-            assert explainer.background_series == explainer2.background_series
-            assert (
-                explainer.background_past_covariates
-                == explainer2.background_past_covariates
-            )
+        explainer = TFTExplainer(
+            model,
+            background_series=series,
+            background_past_covariates=pc if use_pc else None,
+            background_future_covariates=fc if use_fc else None,
+        )
+        assert hasattr(explainer, "model")
+        assert explainer.background_series == series
+        if use_pc:
+            assert explainer.background_past_covariates == pc
+            assert explainer.background_past_covariates[0].n_components == n_pc_expected
+        else:
+            assert explainer.background_past_covariates is None
+        if use_fc:
+            assert explainer.background_future_covariates == fc
             assert (
-                explainer.background_future_covariates
-                == explainer2.background_future_covariates
+                explainer.background_future_covariates[0].n_components == n_fc_expected
             )
+        else:
+            assert explainer.background_future_covariates is None
+        result = explainer.explain()
+        assert isinstance(result, TFTExplainabilityResult)
+
+        enc_imp = result.get_encoder_importance()
+        dec_imp = result.get_decoder_importance()
+        stc_imp = result.get_static_covariates_importance()
+        imps = [enc_imp, dec_imp, stc_imp]
+        assert all([isinstance(imp, list) for imp in imps])
+        assert all([len(imp) == len(series) for imp in imps])
+        assert all([isinstance(imp_, pd.DataFrame) for imp in imps for imp_ in imp])
+        # importances must sum up to 100 percent
+        assert all(
+            [
+                imp_.squeeze().sum() == pytest.approx(100.0, abs=0.21)
+                for imp in imps
+                for imp_ in imp
+            ]
+        )
+        # importances must have the expected number of columns
+        assert all(
+            [
+                len(imp_.columns) == n
+                for imp, n in zip(imps, [n_enc_expected, n_dec_expected, n_sc_expected])
+                for imp_ in imp
+            ]
+        )
 
-            assert hasattr(explainer, "model")
-            assert explainer.background_series[0] == series
-            if use_pc:
-                assert explainer.background_past_covariates[0] == pc
-                assert (
-                    explainer.background_past_covariates[0].n_components
-                    == n_pc_expected
-                )
-            else:
-                assert explainer.background_past_covariates is None
-            if use_fc:
-                assert explainer.background_future_covariates[0] == fc
-                assert (
-                    explainer.background_future_covariates[0].n_components
-                    == n_fc_expected
-                )
-            else:
-                assert explainer.background_future_covariates is None
-            result = explainer.explain()
-            assert isinstance(result, TFTExplainabilityResult)
-
-            enc_imp = result.get_encoder_importance()
-            dec_imp = result.get_decoder_importance()
-            stc_imp = result.get_static_covariates_importance()
-            imps = [enc_imp, dec_imp, stc_imp]
-            assert all([isinstance(imp, pd.DataFrame) for imp in imps])
-            # importances must sum up to 100 percent
-            assert all(
-                [imp.squeeze().sum() == pytest.approx(100.0, rel=0.2) for imp in imps]
-            )
-            # importances must have the expected number of columns
-            assert all(
-                [
-                    len(imp.columns) == n
-                    for imp, n in zip(
-                        imps, [n_enc_expected, n_dec_expected, n_sc_expected]
-                    )
-                ]
-            )
+        attention = result.get_attention()
+        assert isinstance(attention, list)
+        assert len(attention) == len(series)
+        assert all([isinstance(att, TimeSeries) for att in attention])
+        # input chunk length + output chunk length = 5 + 2 = 7
+        icl, ocl = 5, 2
+        freq = series[0].freq
+        assert all([len(att) == icl + ocl for att in attention])
+        assert all(
+            [
+                att.start_time() == series_.end_time() - (icl - 1) * freq
+                for att, series_ in zip(attention, series)
+            ]
+        )
+        assert all(
+            [
+                att.end_time() == series_.end_time() + ocl * freq
+                for att, series_ in zip(attention, series)
+            ]
+        )
+        assert all([att.n_components == ocl for att in attention])
+
+    def test_variable_selection_explanation(self):
+        """Test variable selection (feature importance) explanation results and plotting."""
+        model = self.helper_create_model(use_encoders=True, add_relative_idx=True)
+        series, pc, fc = self.helper_get_input(series_option="multivariate")
+        model.fit(series, past_covariates=pc, future_covariates=fc)
+        explainer = TFTExplainer(model)
+        results = explainer.explain()
+
+        imps = results.get_feature_importances()
+        enc_imp = results.get_encoder_importance()
+        dec_imp = results.get_decoder_importance()
+        stc_imp = results.get_static_covariates_importance()
+        imps_direct = [enc_imp, dec_imp, stc_imp]
+
+        imp_names = [
+            "encoder_importance",
+            "decoder_importance",
+            "static_covariates_importance",
+        ]
+        assert list(imps.keys()) == imp_names
+        for imp, imp_name in zip(imps_direct, imp_names):
+            assert imps[imp_name].equals(imp)
+
+        enc_expected = pd.DataFrame(
+            {
+                "linear_target": 1.7,
+                "sine_target": 3.1,
+                "add_relative_index_futcov": 3.6,
+                "constant_pastcov": 3.9,
+                "darts_enc_fc_cyc_month_sin_futcov": 5.0,
+                "darts_enc_pc_cyc_month_sin_pastcov": 10.1,
+                "darts_enc_pc_cyc_month_cos_pastcov": 19.9,
+                "constant_futcov": 21.8,
+                "darts_enc_fc_cyc_month_cos_futcov": 31.0,
+            },
+            index=[0],
+        )
+        # relaxed comparison because the M1 chip gives slightly different results than an Intel chip
+        assert ((enc_imp.round(decimals=1) - enc_expected).abs() <= 3).all().all()
+
+        dec_expected = pd.DataFrame(
+            {
+                "darts_enc_fc_cyc_month_sin_futcov": 5.3,
+                "darts_enc_fc_cyc_month_cos_futcov": 7.4,
+                "constant_futcov": 24.5,
+                "add_relative_index_futcov": 62.9,
+            },
+            index=[0],
+        )
+        # relaxed comparison because the M1 chip gives slightly different results than an Intel chip
+        assert ((dec_imp.round(decimals=1) - dec_expected).abs() <= 0.6).all().all()
 
-            attention = result.get_attention()
-            assert isinstance(attention, TimeSeries)
-            # input chunk length + output chunk length = 5 + 2 = 7
-            icl, ocl = 5, 2
-            freq = series.freq
-            assert len(attention) == icl + ocl
-            assert attention.start_time() == series.end_time() - (icl - 1) * freq
-            assert attention.end_time() == series.end_time() + ocl * freq
-            assert attention.n_components == ocl
-
-        @pytest.mark.parametrize("test_case", helper_create_test_cases(["multiple"]))
-        def test_explainer_multiple_multivariate_series(self, test_case):
-            """Test TFTExplainer with multiple multivaraites series and a combination of encoders, covariates,
-            and addition of relative index."""
-            series_option, cov_option, add_relative_idx, use_encoders = test_case
-            series, pc, fc = self.helper_get_input(series_option)
-            cov_test_case = dict()
-            use_pc, use_fc = False, False
-            if "past_covariates" in cov_option:
-                cov_test_case["past_covariates"] = pc
-                use_pc = True
-            if "future_covariates" in cov_option:
-                cov_test_case["future_covariates"] = fc
-                use_fc = True
-
-            # expected number of features for past covs, future covs, and static covs, and encoder/decoder
-            n_target_expected = series[0].n_components
-            n_pc_expected = 1 if "past_covariates" in cov_test_case else 0
-            n_fc_expected = 1 if "future_covariates" in cov_test_case else 0
-            n_sc_expected = 2
-            # encoder is number of past and future covs plus 4 optional encodings (future and past)
-            # plus 1 univariate target plus 1 optional relative index
-            n_enc_expected = (
-                n_pc_expected
-                + n_fc_expected
-                + n_target_expected
-                + (4 if use_encoders else 0)
-                + (1 if add_relative_idx else 0)
-            )
-            # encoder is number of future covs plus 2 optional encodings (future)
-            # plus 1 optional relative index
-            n_dec_expected = (
-                n_fc_expected
-                + (2 if use_encoders else 0)
-                + (1 if add_relative_idx else 0)
-            )
+        stc_expected = pd.DataFrame(
+            {"num_statcov": 11.9, "cat_statcov": 88.1}, index=[0]
+        )
+        # relaxed comparison because the M1 chip gives slightly different results than an Intel chip
+        assert ((stc_imp.round(decimals=1) - stc_expected).abs() <= 0.1).all().all()
+
+        with patch("matplotlib.pyplot.show") as _:
+            _ = explainer.plot_variable_selection(results)
+
+    def test_attention_explanation(self):
+        """Test attention (feature importance) explanation results and plotting."""
+        # past attention (full_attention=False) only attends to values in the past relative to each horizon
+        # (note the zeros at the end of each horizon's column)
+        att_exp_past_att = np.array(
+            [
+                [1.0, 0.8],
+                [0.8, 0.7],
+                [0.6, 0.4],
+                [0.7, 0.3],
+                [0.9, 0.4],
+                [0.0, 1.3],
+                [0.0, 0.0],
+            ]
+        )
+        # full attention (full_attention=True) attends to all values in past, present, and future
+        # note that all values are non-zero
+        att_exp_full_att = np.array(
+            [
+                [0.8, 0.8],
+                [0.7, 0.6],
+                [0.4, 0.4],
+                [0.3, 0.3],
+                [0.3, 0.3],
+                [0.7, 0.8],
+                [0.8, 0.8],
+            ]
+        )
+        for full_attention, att_exp in zip(
+            [False, True], [att_exp_past_att, att_exp_full_att]
+        ):
             model = self.helper_create_model(
-                use_encoders=use_encoders, add_relative_idx=add_relative_idx
-            )
-            # TFTModel requires future covariates
-            if (
-                not add_relative_idx
-                and "future_covariates" not in cov_test_case
-                and not use_encoders
-            ):
-                with pytest.raises(ValueError):
-                    model.fit(series=series, **cov_test_case)
-                return
-
-            model.fit(series=series, **cov_test_case)
-            # explainer requires background if model trained on multiple time series
-            with pytest.raises(ValueError):
-                explainer = TFTExplainer(model)
-            explainer = TFTExplainer(
-                model,
-                background_series=series,
-                background_past_covariates=pc if use_pc else None,
-                background_future_covariates=fc if use_fc else None,
-            )
-            assert hasattr(explainer, "model")
-            assert explainer.background_series, series
-            if use_pc:
-                assert explainer.background_past_covariates == pc
-                assert (
-                    explainer.background_past_covariates[0].n_components
-                    == n_pc_expected
-                )
-            else:
-                assert explainer.background_past_covariates is None
-            if use_fc:
-                assert explainer.background_future_covariates == fc
-                assert (
-                    explainer.background_future_covariates[0].n_components
-                    == n_fc_expected
-                )
-            else:
-                assert explainer.background_future_covariates is None
-            result = explainer.explain()
-            assert isinstance(result, TFTExplainabilityResult)
-
-            enc_imp = result.get_encoder_importance()
-            dec_imp = result.get_decoder_importance()
-            stc_imp = result.get_static_covariates_importance()
-            imps = [enc_imp, dec_imp, stc_imp]
-            assert all([isinstance(imp, list) for imp in imps])
-            assert all([len(imp) == len(series) for imp in imps])
-            assert all([isinstance(imp_, pd.DataFrame) for imp in imps for imp_ in imp])
-            # importances must sum up to 100 percent
-            assert all(
-                [
-                    imp_.squeeze().sum() == pytest.approx(100.0, abs=0.21)
-                    for imp in imps
-                    for imp_ in imp
-                ]
-            )
-            # importances must have the expected number of columns
-            assert all(
-                [
-                    len(imp_.columns) == n
-                    for imp, n in zip(
-                        imps, [n_enc_expected, n_dec_expected, n_sc_expected]
-                    )
-                    for imp_ in imp
-                ]
-            )
-
-            attention = result.get_attention()
-            assert isinstance(attention, list)
-            assert len(attention) == len(series)
-            assert all([isinstance(att, TimeSeries) for att in attention])
-            # input chunk length + output chunk length = 5 + 2 = 7
-            icl, ocl = 5, 2
-            freq = series[0].freq
-            assert all([len(att) == icl + ocl for att in attention])
-            assert all(
-                [
-                    att.start_time() == series_.end_time() - (icl - 1) * freq
-                    for att, series_ in zip(attention, series)
-                ]
-            )
-            assert all(
-                [
-                    att.end_time() == series_.end_time() + ocl * freq
-                    for att, series_ in zip(attention, series)
-                ]
+                use_encoders=True,
+                add_relative_idx=True,
+                full_attention=full_attention,
             )
-            assert all([att.n_components == ocl for att in attention])
-
-        def test_variable_selection_explanation(self):
-            """Test variable selection (feature importance) explanation results and plotting."""
-            model = self.helper_create_model(use_encoders=True, add_relative_idx=True)
             series, pc, fc = self.helper_get_input(series_option="multivariate")
             model.fit(series, past_covariates=pc, future_covariates=fc)
             explainer = TFTExplainer(model)
             results = explainer.explain()
 
-            imps = results.get_feature_importances()
-            enc_imp = results.get_encoder_importance()
-            dec_imp = results.get_decoder_importance()
-            stc_imp = results.get_static_covariates_importance()
-            imps_direct = [enc_imp, dec_imp, stc_imp]
-
-            imp_names = [
-                "encoder_importance",
-                "decoder_importance",
-                "static_covariates_importance",
-            ]
-            assert list(imps.keys()) == imp_names
-            for imp, imp_name in zip(imps_direct, imp_names):
-                assert imps[imp_name].equals(imp)
-
-            enc_expected = pd.DataFrame(
-                {
-                    "linear_target": 1.7,
-                    "sine_target": 3.1,
-                    "add_relative_index_futcov": 3.6,
-                    "constant_pastcov": 3.9,
-                    "darts_enc_fc_cyc_month_sin_futcov": 5.0,
-                    "darts_enc_pc_cyc_month_sin_pastcov": 10.1,
-                    "darts_enc_pc_cyc_month_cos_pastcov": 19.9,
-                    "constant_futcov": 21.8,
-                    "darts_enc_fc_cyc_month_cos_futcov": 31.0,
-                },
-                index=[0],
-            )
-            # relaxed comparison because M1 chip gives slightly different results than intel chip
-            assert ((enc_imp.round(decimals=1) - enc_expected).abs() <= 3).all().all()
-
-            dec_expected = pd.DataFrame(
-                {
-                    "darts_enc_fc_cyc_month_sin_futcov": 5.3,
-                    "darts_enc_fc_cyc_month_cos_futcov": 7.4,
-                    "constant_futcov": 24.5,
-                    "add_relative_index_futcov": 62.9,
-                },
-                index=[0],
-            )
-            # relaxed comparison because M1 chip gives slightly different results than intel chip
-            assert ((dec_imp.round(decimals=1) - dec_expected).abs() <= 0.6).all().all()
-
-            stc_expected = pd.DataFrame(
-                {"num_statcov": 11.9, "cat_statcov": 88.1}, index=[0]
-            )
+            att = results.get_attention()
             # relaxed comparison because M1 chip gives slightly different results than intel chip
-            assert ((stc_imp.round(decimals=1) - stc_expected).abs() <= 0.1).all().all()
-
+            assert np.all(np.abs(np.round(att.values(), decimals=1) - att_exp) <= 0.2)
+            assert att.columns.tolist() == ["horizon 1", "horizon 2"]
             with patch("matplotlib.pyplot.show") as _:
-                _ = explainer.plot_variable_selection(results)
-
-        def test_attention_explanation(self):
-            """Test attention (feature importance) explanation results and plotting."""
-            # past attention (full_attention=False) on attends to values in the past relative to each horizon
-            # (look at the last 0 values in the array)
-            att_exp_past_att = np.array(
-                [
-                    [1.0, 0.8],
-                    [0.8, 0.7],
-                    [0.6, 0.4],
-                    [0.7, 0.3],
-                    [0.9, 0.4],
-                    [0.0, 1.3],
-                    [0.0, 0.0],
-                ]
-            )
-            # full attention (full_attention=True) attends to all values in past, present, and future
-            # see the that all values are non-0
-            att_exp_full_att = np.array(
-                [
-                    [0.8, 0.8],
-                    [0.7, 0.6],
-                    [0.4, 0.4],
-                    [0.3, 0.3],
-                    [0.3, 0.3],
-                    [0.7, 0.8],
-                    [0.8, 0.8],
-                ]
-            )
-            for full_attention, att_exp in zip(
-                [False, True], [att_exp_past_att, att_exp_full_att]
-            ):
-                model = self.helper_create_model(
-                    use_encoders=True,
-                    add_relative_idx=True,
-                    full_attention=full_attention,
+                _ = explainer.plot_attention(
+                    results, plot_type="all", show_index_as="relative"
                 )
-                series, pc, fc = self.helper_get_input(series_option="multivariate")
-                model.fit(series, past_covariates=pc, future_covariates=fc)
-                explainer = TFTExplainer(model)
-                results = explainer.explain()
-
-                att = results.get_attention()
-                # relaxed comparison because M1 chip gives slightly different results than intel chip
-                assert np.all(
-                    np.abs(np.round(att.values(), decimals=1) - att_exp) <= 0.2
+                plt.close()
+            with patch("matplotlib.pyplot.show") as _:
+                _ = explainer.plot_attention(
+                    results, plot_type="all", show_index_as="time"
                 )
-                assert att.columns.tolist() == ["horizon 1", "horizon 2"]
-                with patch("matplotlib.pyplot.show") as _:
-                    _ = explainer.plot_attention(
-                        results, plot_type="all", show_index_as="relative"
-                    )
-                    plt.close()
-                with patch("matplotlib.pyplot.show") as _:
-                    _ = explainer.plot_attention(
-                        results, plot_type="all", show_index_as="time"
-                    )
-                    plt.close()
-                with patch("matplotlib.pyplot.show") as _:
-                    _ = explainer.plot_attention(
-                        results, plot_type="time", show_index_as="relative"
-                    )
-                    plt.close()
-                with patch("matplotlib.pyplot.show") as _:
-                    _ = explainer.plot_attention(
-                        results, plot_type="time", show_index_as="time"
-                    )
-                    plt.close()
-                with patch("matplotlib.pyplot.show") as _:
-                    _ = explainer.plot_attention(
-                        results, plot_type="heatmap", show_index_as="relative"
-                    )
-                    plt.close()
-                with patch("matplotlib.pyplot.show") as _:
-                    _ = explainer.plot_attention(
-                        results, plot_type="heatmap", show_index_as="time"
-                    )
-                    plt.close()
-
-        def helper_create_model(
-            self, use_encoders=True, add_relative_idx=True, full_attention=False
-        ):
-            add_encoders = (
-                {"cyclic": {"past": ["month"], "future": ["month"]}}
-                if use_encoders
-                else None
-            )
-            return TFTModel(
-                input_chunk_length=5,
-                output_chunk_length=2,
-                n_epochs=1,
-                add_encoders=add_encoders,
-                add_relative_index=add_relative_idx,
-                full_attention=full_attention,
-                random_state=42,
-                **tfm_kwargs
-            )
+                plt.close()
+            with patch("matplotlib.pyplot.show") as _:
+                _ = explainer.plot_attention(
+                    results, plot_type="time", show_index_as="relative"
+                )
+                plt.close()
+            with patch("matplotlib.pyplot.show") as _:
+                _ = explainer.plot_attention(
+                    results, plot_type="time", show_index_as="time"
+                )
+                plt.close()
+            with patch("matplotlib.pyplot.show") as _:
+                _ = explainer.plot_attention(
+                    results, plot_type="heatmap", show_index_as="relative"
+                )
+                plt.close()
+            with patch("matplotlib.pyplot.show") as _:
+                _ = explainer.plot_attention(
+                    results, plot_type="heatmap", show_index_as="time"
+                )
+                plt.close()
+
+    def helper_create_model(
+        self, use_encoders=True, add_relative_idx=True, full_attention=False
+    ):
+        add_encoders = (
+            {"cyclic": {"past": ["month"], "future": ["month"]}}
+            if use_encoders
+            else None
+        )
+        return TFTModel(
+            input_chunk_length=5,
+            output_chunk_length=2,
+            n_epochs=1,
+            add_encoders=add_encoders,
+            add_relative_index=add_relative_idx,
+            full_attention=full_attention,
+            random_state=42,
+            **tfm_kwargs,
+        )
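
Note: the try/except at the top of this file shows the pattern this patch applies to every torch-dependent test module: instead of setting a `TORCH_AVAILABLE` flag and indenting all tests under `if TORCH_AVAILABLE:`, the ImportError handler calls `pytest.skip(..., allow_module_level=True)`, which skips the whole module at collection time and removes one indentation level from the test code. A minimal standalone sketch (the test body is hypothetical):

    import pytest

    try:
        import torch
    except ImportError:
        # skip every test in this module at collection time
        pytest.skip(
            f"Torch not available. {__name__} tests will be skipped.",
            allow_module_level=True,
        )

    def test_uses_torch():
        # only collected when the import above succeeded
        assert torch.zeros(1).item() == 0.0
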
diff --git a/darts/tests/models/components/glu_variants.py b/darts/tests/models/components/glu_variants.py
index e012c7ebe9..0288af37f6 100644
--- a/darts/tests/models/components/glu_variants.py
+++ b/darts/tests/models/components/glu_variants.py
@@ -1,3 +1,5 @@
+import pytest
+
 from darts.logging import get_logger
 
 logger = get_logger(__name__)
@@ -5,22 +7,21 @@
 try:
     import torch
 
-    TORCH_AVAILABLE = True
-except ImportError:
-    logger.warning("Torch not available. Loss tests will be skipped.")
-    TORCH_AVAILABLE = False
-
-
-if TORCH_AVAILABLE:
     from darts.models.components import glu_variants
     from darts.models.components.glu_variants import GLU_FFN
+except ImportError:
+    pytest.skip(
+        f"Torch not available. {__name__} tests will be skipped.",
+        allow_module_level=True,
+    )
+
 
-    class TestFFN:
-        def test_ffn(self):
-            for FeedForward_network in GLU_FFN:
-                self.feed_forward_block = getattr(glu_variants, FeedForward_network)(
-                    d_model=4, d_ff=16, dropout=0.1
-                )
+class TestFFN:
+    def test_ffn(self):
+        for FeedForward_network in GLU_FFN:
+            self.feed_forward_block = getattr(glu_variants, FeedForward_network)(
+                d_model=4, d_ff=16, dropout=0.1
+            )
 
-                inputs = torch.zeros(1, 4, 4)
-                self.feed_forward_block(x=inputs)
+            inputs = torch.zeros(1, 4, 4)
+            self.feed_forward_block(x=inputs)
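
Note: `TestFFN.test_ffn` above resolves each feed-forward variant by name with `getattr`, so every class listed in `GLU_FFN` is smoke-tested without being named explicitly. A reduced sketch of the pattern; the final shape assertion is an added assumption (feed-forward blocks conventionally map back to `d_model` features), not something the original test checks:

    import torch

    from darts.models.components import glu_variants
    from darts.models.components.glu_variants import GLU_FFN

    def smoke_test_ffn_variants():
        for name in GLU_FFN:
            # look up the variant class by name and instantiate with small dims
            block = getattr(glu_variants, name)(d_model=4, d_ff=16, dropout=0.1)
            out = block(x=torch.zeros(1, 4, 4))
            # assumption: the block preserves the (batch, time, d_model) shape
            assert out.shape == (1, 4, 4)
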
diff --git a/darts/tests/models/components/test_layer_norm_variants.py b/darts/tests/models/components/test_layer_norm_variants.py
index 374fa8deb3..b118746451 100644
--- a/darts/tests/models/components/test_layer_norm_variants.py
+++ b/darts/tests/models/components/test_layer_norm_variants.py
@@ -8,46 +8,45 @@
 try:
     import torch
 
-    TORCH_AVAILABLE = True
-except ImportError:
-    logger.warning("Torch not available. Loss tests will be skipped.")
-    TORCH_AVAILABLE = False
-
-
-if TORCH_AVAILABLE:
     from darts.models.components.layer_norm_variants import (
         LayerNorm,
         LayerNormNoBias,
         RINorm,
         RMSNorm,
     )
+except ImportError:
+    pytest.skip(
+        f"Torch not available. {__name__} tests will be skipped.",
+        allow_module_level=True,
+    )
+
 
-    class TestLayerNormVariants:
-        def test_lnv(self):
-            for layer_norm in [RMSNorm, LayerNorm, LayerNormNoBias]:
-                ln = layer_norm(4)
-                inputs = torch.zeros(1, 4, 4)
-                ln(inputs)
+class TestLayerNormVariants:
+    def test_lnv(self):
+        for layer_norm in [RMSNorm, LayerNorm, LayerNormNoBias]:
+            ln = layer_norm(4)
+            inputs = torch.zeros(1, 4, 4)
+            ln(inputs)
 
-        def test_rin(self):
+    def test_rin(self):
 
-            np.random.seed(42)
-            torch.manual_seed(42)
+        np.random.seed(42)
+        torch.manual_seed(42)
 
-            x = torch.randn(3, 4, 7)
-            affine_options = [True, False]
+        x = torch.randn(3, 4, 7)
+        affine_options = [True, False]
 
-            # test with and without affine and correct input dim
-            for affine in affine_options:
+        # test with and without affine and correct input dim
+        for affine in affine_options:
 
-                rin = RINorm(input_dim=7, affine=affine)
-                x_norm = rin(x)
+            rin = RINorm(input_dim=7, affine=affine)
+            x_norm = rin(x)
 
-                # expand dims to simulate probablistic forecasting
-                x_denorm = rin.inverse(x_norm.view(x_norm.shape + (1,))).squeeze(-1)
-                assert torch.all(torch.isclose(x, x_denorm)).item()
+            # expand dims to simulate probabilistic forecasting
+            x_denorm = rin.inverse(x_norm.view(x_norm.shape + (1,))).squeeze(-1)
+            assert torch.all(torch.isclose(x, x_denorm)).item()
 
-            # try invalid input_dim
-            rin = RINorm(input_dim=3, affine=True)
-            with pytest.raises(RuntimeError):
-                x_norm = rin(x)
+        # try invalid input_dim
+        rin = RINorm(input_dim=3, affine=True)
+        with pytest.raises(RuntimeError):
+            x_norm = rin(x)
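
Note: `test_rin` above checks that `RINorm` is invertible; the trailing dimension added before `inverse` mimics the samples axis of a probabilistic forecast. A condensed round-trip sketch using the same calls as the test:

    import torch

    from darts.models.components.layer_norm_variants import RINorm

    torch.manual_seed(42)
    x = torch.randn(3, 4, 7)  # (batch, time, input_dim)

    rin = RINorm(input_dim=7, affine=True)
    x_norm = rin(x)
    # add a samples axis before inverting, then drop it again
    x_denorm = rin.inverse(x_norm.view(x_norm.shape + (1,))).squeeze(-1)
    assert torch.all(torch.isclose(x, x_denorm)).item()
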
diff --git a/darts/tests/models/forecasting/test_RNN.py b/darts/tests/models/forecasting/test_RNN.py
index cdd143422b..8fe711a6d3 100644
--- a/darts/tests/models/forecasting/test_RNN.py
+++ b/darts/tests/models/forecasting/test_RNN.py
@@ -12,155 +12,149 @@
     import torch.nn as nn
 
     from darts.models.forecasting.rnn_model import CustomRNNModule, RNNModel, _RNNModule
-
-    TORCH_AVAILABLE = True
 except ImportError:
-    logger.warning("Torch not available. RNN tests will be skipped.")
-    TORCH_AVAILABLE = False
-
-
-if TORCH_AVAILABLE:
-
-    class ModuleValid1(_RNNModule):
-        """Wrapper around the _RNNModule"""
-
-        def __init__(self, **kwargs):
-            super().__init__(name="RNN", **kwargs)
-
-    class ModuleValid2(CustomRNNModule):
-        """Just a linear layer."""
-
-        def __init__(self, **kwargs):
-            super().__init__(**kwargs)
-            self.linear = nn.Linear(self.input_size, self.target_size)
-
-        def forward(self, x_in, h=None):
-            x = self.linear(x_in[0])
-            return x.view(len(x), -1, self.target_size, self.nr_params)
-
-    class TestRNNModel:
-        times = pd.date_range("20130101", "20130410")
-        pd_series = pd.Series(range(100), index=times)
-        series: TimeSeries = TimeSeries.from_series(pd_series)
-        module_invalid = _RNNModule(
-            name="RNN",
-            input_chunk_length=1,
-            output_chunk_length=1,
-            output_chunk_shift=0,
-            input_size=1,
-            hidden_dim=25,
-            num_layers=1,
-            target_size=1,
-            nr_params=1,
-            dropout=0,
-        )
-
-        def test_creation(self):
-            # cannot choose any string
-            with pytest.raises(ValueError) as msg:
-                RNNModel(input_chunk_length=1, model="UnknownRNN?")
-            assert str(msg.value).startswith("`model` is not a valid RNN model.")
-
-            # cannot create from a class instance
-            with pytest.raises(ValueError) as msg:
-                _ = RNNModel(
-                    input_chunk_length=1,
-                    model=self.module_invalid,
-                )
-            assert str(msg.value).startswith("`model` is not a valid RNN model.")
-
-            # can create from valid module name
-            model1 = RNNModel(
-                input_chunk_length=1,
-                model="RNN",
-                n_epochs=1,
-                random_state=42,
-                **tfm_kwargs
-            )
-            model1.fit(self.series)
-            preds1 = model1.predict(n=3)
-
-            # can create from a custom class itself
-            model2 = RNNModel(
+    pytest.skip(
+        f"Torch not available. {__name__} tests will be skipped.",
+        allow_module_level=True,
+    )
+
+
+class ModuleValid1(_RNNModule):
+    """Wrapper around the _RNNModule"""
+
+    def __init__(self, **kwargs):
+        super().__init__(name="RNN", **kwargs)
+
+
+class ModuleValid2(CustomRNNModule):
+    """Just a linear layer."""
+
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+        self.linear = nn.Linear(self.input_size, self.target_size)
+
+    def forward(self, x_in, h=None):
+        x = self.linear(x_in[0])
+        return x.view(len(x), -1, self.target_size, self.nr_params)
+
+
+class TestRNNModel:
+    times = pd.date_range("20130101", "20130410")
+    pd_series = pd.Series(range(100), index=times)
+    series: TimeSeries = TimeSeries.from_series(pd_series)
+    module_invalid = _RNNModule(
+        name="RNN",
+        input_chunk_length=1,
+        output_chunk_length=1,
+        output_chunk_shift=0,
+        input_size=1,
+        hidden_dim=25,
+        num_layers=1,
+        target_size=1,
+        nr_params=1,
+        dropout=0,
+    )
+
+    def test_creation(self):
+        # cannot choose any string
+        with pytest.raises(ValueError) as msg:
+            RNNModel(input_chunk_length=1, model="UnknownRNN?")
+        assert str(msg.value).startswith("`model` is not a valid RNN model.")
+
+        # cannot create from a class instance
+        with pytest.raises(ValueError) as msg:
+            _ = RNNModel(
                 input_chunk_length=1,
-                model=ModuleValid1,
-                n_epochs=1,
-                random_state=42,
-                **tfm_kwargs
+                model=self.module_invalid,
             )
-            model2.fit(self.series)
-            preds2 = model2.predict(n=3)
-            np.testing.assert_array_equal(preds1.all_values(), preds2.all_values())
+        assert str(msg.value).startswith("`model` is not a valid RNN model.")
 
-            model3 = RNNModel(
-                input_chunk_length=1,
-                model=ModuleValid2,
-                n_epochs=1,
-                random_state=42,
-                **tfm_kwargs
-            )
-            model3.fit(self.series)
-            preds3 = model2.predict(n=3)
-            assert preds3.all_values().shape == preds2.all_values().shape
-            assert preds3.time_index.equals(preds2.time_index)
-
-        def test_fit(self, tmpdir_module):
-            # Test basic fit()
-            model = RNNModel(input_chunk_length=1, n_epochs=2, **tfm_kwargs)
-            model.fit(self.series)
-
-            # Test fit-save-load cycle
-            model2 = RNNModel(
-                input_chunk_length=1,
-                model="LSTM",
-                n_epochs=1,
-                model_name="unittest-model-lstm",
-                work_dir=tmpdir_module,
-                save_checkpoints=True,
-                force_reset=True,
-                **tfm_kwargs
-            )
-            model2.fit(self.series)
-            model_loaded = model2.load_from_checkpoint(
-                model_name="unittest-model-lstm",
-                work_dir=tmpdir_module,
-                best=False,
-                map_location="cpu",
-            )
-            pred1 = model2.predict(n=6)
-            pred2 = model_loaded.predict(n=6)
+        # can create from valid module name
+        model1 = RNNModel(
+            input_chunk_length=1, model="RNN", n_epochs=1, random_state=42, **tfm_kwargs
+        )
+        model1.fit(self.series)
+        preds1 = model1.predict(n=3)
 
-            # Two models with the same parameters should deterministically yield the same output
-            np.testing.assert_array_equal(pred1.values(), pred2.values())
+        # can create from a custom class itself
+        model2 = RNNModel(
+            input_chunk_length=1,
+            model=ModuleValid1,
+            n_epochs=1,
+            random_state=42,
+            **tfm_kwargs,
+        )
+        model2.fit(self.series)
+        preds2 = model2.predict(n=3)
+        np.testing.assert_array_equal(preds1.all_values(), preds2.all_values())
 
-            # Another random model should not
-            model3 = RNNModel(
-                input_chunk_length=1, model="RNN", n_epochs=2, **tfm_kwargs
-            )
-            model3.fit(self.series)
-            pred3 = model3.predict(n=6)
-            assert not np.array_equal(pred1.values(), pred3.values())
-
-            # test short predict
-            pred4 = model3.predict(n=1)
-            assert len(pred4) == 1
-
-            # test validation series input
-            model3.fit(self.series[:60], val_series=self.series[60:])
-            pred4 = model3.predict(n=6)
-            assert len(pred4) == 6
-
-        def helper_test_pred_length(self, pytorch_model, series):
-            model = pytorch_model(input_chunk_length=1, n_epochs=1, **tfm_kwargs)
-            model.fit(series)
-            pred = model.predict(7)
-            assert len(pred) == 7
-            pred = model.predict(2)
-            assert len(pred) == 2
-            assert pred.width == 1
-            pred = model.predict(4)
-            assert len(pred) == 4
-            assert pred.width == 1
-
-        def test_pred_length(self):
-            self.helper_test_pred_length(RNNModel, self.series)
+        model3 = RNNModel(
+            input_chunk_length=1,
+            model=ModuleValid2,
+            n_epochs=1,
+            random_state=42,
+            **tfm_kwargs,
+        )
+        model3.fit(self.series)
+        preds3 = model3.predict(n=3)
+        assert preds3.all_values().shape == preds2.all_values().shape
+        assert preds3.time_index.equals(preds2.time_index)
+
+    def test_fit(self, tmpdir_module):
+        # Test basic fit()
+        model = RNNModel(input_chunk_length=1, n_epochs=2, **tfm_kwargs)
+        model.fit(self.series)
+
+        # Test fit-save-load cycle
+        model2 = RNNModel(
+            input_chunk_length=1,
+            model="LSTM",
+            n_epochs=1,
+            model_name="unittest-model-lstm",
+            work_dir=tmpdir_module,
+            save_checkpoints=True,
+            force_reset=True,
+            **tfm_kwargs,
+        )
+        model2.fit(self.series)
+        model_loaded = model2.load_from_checkpoint(
+            model_name="unittest-model-lstm",
+            work_dir=tmpdir_module,
+            best=False,
+            map_location="cpu",
+        )
+        pred1 = model2.predict(n=6)
+        pred2 = model_loaded.predict(n=6)
+
+        # Two models with the same parameters should deterministically yield the same output
+        np.testing.assert_array_equal(pred1.values(), pred2.values())
+
+        # Another random model should not
+        model3 = RNNModel(input_chunk_length=1, model="RNN", n_epochs=2, **tfm_kwargs)
+        model3.fit(self.series)
+        pred3 = model3.predict(n=6)
+        assert not np.array_equal(pred1.values(), pred3.values())
+
+        # test short predict
+        pred4 = model3.predict(n=1)
+        assert len(pred4) == 1
+
+        # test validation series input
+        model3.fit(self.series[:60], val_series=self.series[60:])
+        pred4 = model3.predict(n=6)
+        assert len(pred4) == 6
+
+    def helper_test_pred_length(self, pytorch_model, series):
+        model = pytorch_model(input_chunk_length=1, n_epochs=1, **tfm_kwargs)
+        model.fit(series)
+        pred = model.predict(7)
+        assert len(pred) == 7
+        pred = model.predict(2)
+        assert len(pred) == 2
+        assert pred.width == 1
+        pred = model.predict(4)
+        assert len(pred) == 4
+        assert pred.width == 1
+
+    def test_pred_length(self):
+        self.helper_test_pred_length(RNNModel, self.series)
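The change above replaces the old `TORCH_AVAILABLE` flag with pytest's module-level skip. A minimal sketch of the resulting pattern, assuming an optional `torch` dependency (the test function below is illustrative, not part of this patch):

    import pytest

    try:
        import torch  # optional dependency guarded at import time
    except ImportError:
        # skips every test in this module at collection time, so the tests
        # show up as "skipped" in the report instead of silently vanishing
        pytest.skip(
            f"Torch not available. {__name__} tests will be skipped.",
            allow_module_level=True,
        )

    def test_requires_torch():
        # only runs when the import above succeeded
        assert torch.ones(1).item() == 1.0

With the previous `if TORCH_AVAILABLE:` guard, the test classes were simply never defined when torch was missing; with `allow_module_level=True`, pytest stops collecting the module and reports the skip reason explicitly.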
diff --git a/darts/tests/models/forecasting/test_TCN.py b/darts/tests/models/forecasting/test_TCN.py
index 587929ce07..4c6fb144ad 100644
--- a/darts/tests/models/forecasting/test_TCN.py
+++ b/darts/tests/models/forecasting/test_TCN.py
@@ -11,199 +11,193 @@
     import torch
 
     from darts.models.forecasting.tcn_model import TCNModel
-
-    TORCH_AVAILABLE = True
 except ImportError:
-    logger.warning("Torch not available. TCN tests will be skipped.")
-    TORCH_AVAILABLE = False
-
-
-if TORCH_AVAILABLE:
-
-    class TestTCNModel:
-        def test_creation(self):
-            with pytest.raises(ValueError):
-                # cannot choose a kernel size larger than the input length
-                TCNModel(input_chunk_length=20, output_chunk_length=1, kernel_size=100)
-            TCNModel(input_chunk_length=12, output_chunk_length=1)
-
-        def test_fit(self):
-            large_ts = tg.constant_timeseries(length=100, value=1000)
-            small_ts = tg.constant_timeseries(length=100, value=10)
-
-            # Test basic fit and predict
-            model = TCNModel(
-                input_chunk_length=12,
-                output_chunk_length=1,
-                n_epochs=10,
-                num_layers=1,
-                **tfm_kwargs,
-            )
-            model.fit(large_ts[:98])
-            pred = model.predict(n=2).values()[0]
-
-            # Test whether model trained on one series is better than one trained on another
-            model2 = TCNModel(
-                input_chunk_length=12,
-                output_chunk_length=1,
-                n_epochs=10,
-                num_layers=1,
-                **tfm_kwargs,
-            )
-            model2.fit(small_ts[:98])
-            pred2 = model2.predict(n=2).values()[0]
-            assert abs(pred2 - 10) < abs(pred - 10)
-
-            # test short predict
-            pred3 = model2.predict(n=1)
-            assert len(pred3) == 1
-
-        def test_performance(self):
-            # test TCN performance on dummy time series
-            ts = tg.sine_timeseries(length=100) + tg.linear_timeseries(
-                length=100, end_value=2
-            )
-            train, test = ts[:90], ts[90:]
-            model = TCNModel(
-                input_chunk_length=12,
-                output_chunk_length=10,
-                n_epochs=300,
-                random_state=0,
-                **tfm_kwargs,
-            )
-            model.fit(train)
-            pred = model.predict(n=10)
-
-            assert mae(pred, test) < 0.3
-
-        @pytest.mark.slow
-        def test_coverage(self):
-            torch.manual_seed(0)
-            input_chunk_lengths = range(20, 50)
-            kernel_sizes = range(2, 5)
-            dilation_bases = range(2, 5)
-
-            for kernel_size in kernel_sizes:
-                for dilation_base in dilation_bases:
-                    if dilation_base > kernel_size:
-                        continue
-                    for input_chunk_length in input_chunk_lengths:
-
-                        # create model with all weights set to one
-                        model = TCNModel(
-                            input_chunk_length=input_chunk_length,
-                            output_chunk_length=1,
-                            kernel_size=kernel_size,
-                            dilation_base=dilation_base,
-                            weight_norm=False,
-                            n_epochs=1,
-                            **tfm_kwargs,
-                        )
-
-                        # we have to fit the model on a dummy series in order to create the internal nn.Module
-                        model.fit(tg.gaussian_timeseries(length=100))
-
-                        for res_block in model.model.res_blocks:
-                            res_block.conv1.weight = torch.nn.Parameter(
-                                torch.ones(
-                                    res_block.conv1.weight.shape, dtype=torch.float64
-                                )
+    pytest.skip(
+        f"Torch not available. {__name__} tests will be skipped.",
+        allow_module_level=True,
+    )
+
+
+class TestTCNModel:
+    def test_creation(self):
+        with pytest.raises(ValueError):
+            # cannot choose a kernel size larger than the input length
+            TCNModel(input_chunk_length=20, output_chunk_length=1, kernel_size=100)
+        TCNModel(input_chunk_length=12, output_chunk_length=1)
+
+    def test_fit(self):
+        large_ts = tg.constant_timeseries(length=100, value=1000)
+        small_ts = tg.constant_timeseries(length=100, value=10)
+
+        # Test basic fit and predict
+        model = TCNModel(
+            input_chunk_length=12,
+            output_chunk_length=1,
+            n_epochs=10,
+            num_layers=1,
+            **tfm_kwargs,
+        )
+        model.fit(large_ts[:98])
+        pred = model.predict(n=2).values()[0]
+
+        # Test whether model trained on one series is better than one trained on another
+        model2 = TCNModel(
+            input_chunk_length=12,
+            output_chunk_length=1,
+            n_epochs=10,
+            num_layers=1,
+            **tfm_kwargs,
+        )
+        model2.fit(small_ts[:98])
+        pred2 = model2.predict(n=2).values()[0]
+        assert abs(pred2 - 10) < abs(pred - 10)
+
+        # test short predict
+        pred3 = model2.predict(n=1)
+        assert len(pred3) == 1
+
+    def test_performance(self):
+        # test TCN performance on dummy time series
+        ts = tg.sine_timeseries(length=100) + tg.linear_timeseries(
+            length=100, end_value=2
+        )
+        train, test = ts[:90], ts[90:]
+        model = TCNModel(
+            input_chunk_length=12,
+            output_chunk_length=10,
+            n_epochs=300,
+            random_state=0,
+            **tfm_kwargs,
+        )
+        model.fit(train)
+        pred = model.predict(n=10)
+
+        assert mae(pred, test) < 0.3
+
+    @pytest.mark.slow
+    def test_coverage(self):
+        torch.manual_seed(0)
+        input_chunk_lengths = range(20, 50)
+        kernel_sizes = range(2, 5)
+        dilation_bases = range(2, 5)
+
+        for kernel_size in kernel_sizes:
+            for dilation_base in dilation_bases:
+                if dilation_base > kernel_size:
+                    continue
+                for input_chunk_length in input_chunk_lengths:
+
+                    # create model with all weights set to one
+                    model = TCNModel(
+                        input_chunk_length=input_chunk_length,
+                        output_chunk_length=1,
+                        kernel_size=kernel_size,
+                        dilation_base=dilation_base,
+                        weight_norm=False,
+                        n_epochs=1,
+                        **tfm_kwargs,
+                    )
+
+                    # we have to fit the model on a dummy series in order to create the internal nn.Module
+                    model.fit(tg.gaussian_timeseries(length=100))
+
+                    for res_block in model.model.res_blocks:
+                        res_block.conv1.weight = torch.nn.Parameter(
+                            torch.ones(
+                                res_block.conv1.weight.shape, dtype=torch.float64
                             )
-                            res_block.conv2.weight = torch.nn.Parameter(
-                                torch.ones(
-                                    res_block.conv2.weight.shape, dtype=torch.float64
-                                )
+                        )
+                        res_block.conv2.weight = torch.nn.Parameter(
+                            torch.ones(
+                                res_block.conv2.weight.shape, dtype=torch.float64
                             )
+                        )
 
-                        model.model.eval()
+                    model.model.eval()
 
-                        # also disable MC Dropout:
-                        model.model.set_mc_dropout(False)
+                    # also disable MC Dropout:
+                    model.model.set_mc_dropout(False)
 
-                        input_tensor = torch.zeros(
-                            [1, input_chunk_length, 1], dtype=torch.float64
-                        )
-                        zero_output = model.model.forward((input_tensor, None))[
+                    input_tensor = torch.zeros(
+                        [1, input_chunk_length, 1], dtype=torch.float64
+                    )
+                    zero_output = model.model.forward((input_tensor, None))[0, -1, 0]
+
+                    # test for full coverage
+                    for i in range(input_chunk_length):
+                        input_tensor[0, i, 0] = 1
+                        curr_output = model.model.forward((input_tensor, None))[
                             0, -1, 0
                         ]
-
-                        # test for full coverage
-                        for i in range(input_chunk_length):
-                            input_tensor[0, i, 0] = 1
-                            curr_output = model.model.forward((input_tensor, None))[
-                                0, -1, 0
-                            ]
-                            assert zero_output != curr_output
-                            input_tensor[0, i, 0] = 0
-
-                        # create model with all weights set to one and one layer less than is automatically detected
-                        model_2 = TCNModel(
-                            input_chunk_length=input_chunk_length,
-                            output_chunk_length=1,
-                            kernel_size=kernel_size,
-                            dilation_base=dilation_base,
-                            weight_norm=False,
-                            num_layers=model.model.num_layers - 1,
-                            n_epochs=1,
-                            **tfm_kwargs,
-                        )
-
-                        # we have to fit the model on a dummy series in order to create the internal nn.Module
-                        model_2.fit(tg.gaussian_timeseries(length=100))
-
-                        for res_block in model_2.model.res_blocks:
-                            res_block.conv1.weight = torch.nn.Parameter(
-                                torch.ones(
-                                    res_block.conv1.weight.shape, dtype=torch.float64
-                                )
+                        assert zero_output != curr_output
+                        input_tensor[0, i, 0] = 0
+
+                    # create model with all weights set to one and one layer less than is automatically detected
+                    model_2 = TCNModel(
+                        input_chunk_length=input_chunk_length,
+                        output_chunk_length=1,
+                        kernel_size=kernel_size,
+                        dilation_base=dilation_base,
+                        weight_norm=False,
+                        num_layers=model.model.num_layers - 1,
+                        n_epochs=1,
+                        **tfm_kwargs,
+                    )
+
+                    # we have to fit the model on a dummy series in order to create the internal nn.Module
+                    model_2.fit(tg.gaussian_timeseries(length=100))
+
+                    for res_block in model_2.model.res_blocks:
+                        res_block.conv1.weight = torch.nn.Parameter(
+                            torch.ones(
+                                res_block.conv1.weight.shape, dtype=torch.float64
                             )
-                            res_block.conv2.weight = torch.nn.Parameter(
-                                torch.ones(
-                                    res_block.conv2.weight.shape, dtype=torch.float64
-                                )
+                        )
+                        res_block.conv2.weight = torch.nn.Parameter(
+                            torch.ones(
+                                res_block.conv2.weight.shape, dtype=torch.float64
                             )
+                        )
 
-                        model_2.model.eval()
+                    model_2.model.eval()
 
-                        # also disable MC Dropout:
-                        model_2.model.set_mc_dropout(False)
+                    # also disable MC Dropout:
+                    model_2.model.set_mc_dropout(False)
 
-                        input_tensor = torch.zeros(
-                            [1, input_chunk_length, 1], dtype=torch.float64
-                        )
-                        zero_output = model_2.model.forward((input_tensor, None))[
+                    input_tensor = torch.zeros(
+                        [1, input_chunk_length, 1], dtype=torch.float64
+                    )
+                    zero_output = model_2.model.forward((input_tensor, None))[0, -1, 0]
+
+                    # test for incomplete coverage
+                    uncovered_input_found = False
+                    if model_2.model.num_layers == 1:
+                        continue
+                    for i in range(input_chunk_length):
+                        input_tensor[0, i, 0] = 1
+                        curr_output = model_2.model.forward((input_tensor, None))[
                             0, -1, 0
                         ]
-
-                        # test for incomplete coverage
-                        uncovered_input_found = False
-                        if model_2.model.num_layers == 1:
-                            continue
-                        for i in range(input_chunk_length):
-                            input_tensor[0, i, 0] = 1
-                            curr_output = model_2.model.forward((input_tensor, None))[
-                                0, -1, 0
-                            ]
-                            if zero_output == curr_output:
-                                uncovered_input_found = True
-                                break
-                            input_tensor[0, i, 0] = 0
-                        assert uncovered_input_found
-
-        def helper_test_pred_length(self, pytorch_model, series):
-            model = pytorch_model(
-                input_chunk_length=12, output_chunk_length=3, n_epochs=1, **tfm_kwargs
-            )
-            model.fit(series)
-            pred = model.predict(7)
-            assert len(pred) == 7
-            pred = model.predict(2)
-            assert len(pred) == 2
-            assert pred.width == 1
-            pred = model.predict(4)
-            assert len(pred) == 4
-            assert pred.width == 1
-
-        def test_pred_length(self):
-            series = tg.linear_timeseries(length=100)
-            self.helper_test_pred_length(TCNModel, series)
+                        if zero_output == curr_output:
+                            uncovered_input_found = True
+                            break
+                        input_tensor[0, i, 0] = 0
+                    assert uncovered_input_found
+
+    def helper_test_pred_length(self, pytorch_model, series):
+        model = pytorch_model(
+            input_chunk_length=12, output_chunk_length=3, n_epochs=1, **tfm_kwargs
+        )
+        model.fit(series)
+        pred = model.predict(7)
+        assert len(pred) == 7
+        pred = model.predict(2)
+        assert len(pred) == 2
+        assert pred.width == 1
+        pred = model.predict(4)
+        assert len(pred) == 4
+        assert pred.width == 1
+
+    def test_pred_length(self):
+        series = tg.linear_timeseries(length=100)
+        self.helper_test_pred_length(TCNModel, series)
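`test_coverage` flips one input position at a time and checks whether the last output reacts, i.e. whether the network's receptive field spans the whole input chunk. A rough sketch of the underlying arithmetic, assuming each residual block i applies two convolutions (conv1/conv2, as touched above) with dilation dilation_base**i; the helper names are hypothetical:

    def receptive_field(num_layers: int, kernel_size: int, dilation_base: int) -> int:
        # each of the two convolutions in block i extends the receptive
        # field by (kernel_size - 1) * dilation_base**i time steps
        rf = 1
        for i in range(num_layers):
            rf += 2 * (kernel_size - 1) * dilation_base**i
        return rf

    def covers_input(num_layers, kernel_size, dilation_base, input_chunk_length):
        # full coverage: every input position can influence the last output
        return receptive_field(num_layers, kernel_size, dilation_base) >= input_chunk_length

Under this assumption, removing one layer (as `model_2` does) shrinks the receptive field below `input_chunk_length` for most configurations, which is why the test expects to find at least one uncovered input.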
diff --git a/darts/tests/models/forecasting/test_TFT.py b/darts/tests/models/forecasting/test_TFT.py
index 5758bef8f3..ff629a6211 100644
--- a/darts/tests/models/forecasting/test_TFT.py
+++ b/darts/tests/models/forecasting/test_TFT.py
@@ -17,412 +17,406 @@
     from darts.models.forecasting.tft_model import TFTModel
     from darts.models.forecasting.tft_submodels import get_embedding_size
     from darts.utils.likelihood_models import QuantileRegression
-
-    TORCH_AVAILABLE = True
 except ImportError:
-    logger.warning("Torch not available. TFT tests will be skipped.")
-    TORCH_AVAILABLE = False
-    TFTModel, QuantileRegression, MSELoss = None, None, None
-
-
-if TORCH_AVAILABLE:
+    pytest.skip(
+        f"Torch not available. {__name__} tests will be skipped.",
+        allow_module_level=True,
+    )
 
-    class TestTFTModel:
-        def test_quantile_regression(self):
-            q_no_50 = [0.1, 0.4, 0.9]
-            q_non_symmetric = [0.2, 0.5, 0.9]
 
-            # if a QuantileLoss is used, it must have to q=0.5 quantile
-            with pytest.raises(ValueError):
-                QuantileRegression(q_no_50)
+class TestTFTModel:
+    def test_quantile_regression(self):
+        q_no_50 = [0.1, 0.4, 0.9]
+        q_non_symmetric = [0.2, 0.5, 0.9]
 
-            # if a QuantileLoss is used, it must be symmetric around q=0.5 quantile (i.e. [0.1, 0.5, 0.9])
-            with pytest.raises(ValueError):
-                QuantileRegression(q_non_symmetric)
+        # if a QuantileLoss is used, it must contain the q=0.5 quantile
+        with pytest.raises(ValueError):
+            QuantileRegression(q_no_50)
 
-        def test_future_covariate_handling(self):
-            ts_time_index = tg.sine_timeseries(length=2, freq="h")
-            ts_integer_index = TimeSeries.from_values(values=ts_time_index.values())
+        # if a QuantileLoss is used, it must be symmetric around q=0.5 quantile (i.e. [0.1, 0.5, 0.9])
+        with pytest.raises(ValueError):
+            QuantileRegression(q_non_symmetric)
 
-            # model requires future covariates without cyclic encoding
-            model = TFTModel(input_chunk_length=1, output_chunk_length=1, **tfm_kwargs)
-            with pytest.raises(ValueError):
-                model.fit(ts_time_index, verbose=False)
+    def test_future_covariate_handling(self):
+        ts_time_index = tg.sine_timeseries(length=2, freq="h")
+        ts_integer_index = TimeSeries.from_values(values=ts_time_index.values())
 
-            # should work with cyclic encoding for time index
-            model = TFTModel(
-                input_chunk_length=1,
-                output_chunk_length=1,
-                add_encoders={"cyclic": {"future": "hour"}},
-                **tfm_kwargs,
-            )
+        # model requires future covariates without cyclic encoding
+        model = TFTModel(input_chunk_length=1, output_chunk_length=1, **tfm_kwargs)
+        with pytest.raises(ValueError):
             model.fit(ts_time_index, verbose=False)
 
-            # should work with relative index both with time index and integer index
-            model = TFTModel(
-                input_chunk_length=1,
-                output_chunk_length=1,
-                add_relative_index=True,
-                **tfm_kwargs,
-            )
-            model.fit(ts_time_index, verbose=False)
-            model.fit(ts_integer_index, verbose=False)
-
-        def test_prediction_shape(self):
-            """checks whether prediction has same number of variable as input series and
-            whether prediction has correct length.
-            Test cases:
-                -   univariate
-                -   multivariate
-                -   multi-TS
-            """
-            season_length = 1
-            n_repeat = 20
-
-            # data comes as multivariate
-            (
-                ts,
-                ts_train,
-                ts_val,
-                covariates,
-            ) = self.helper_generate_multivariate_case_data(season_length, n_repeat)
-
-            kwargs_TFT_quick_test = {
-                "input_chunk_length": 1,
-                "output_chunk_length": 1,
-                "n_epochs": 1,
-                "lstm_layers": 1,
-                "hidden_size": 8,
-                "loss_fn": MSELoss(),
-                "random_state": 42,
-            }
-            kwargs_TFT_quick_test = dict(kwargs_TFT_quick_test, **tfm_kwargs)
-
-            # univariate
-            first_var = ts.columns[0]
-            self.helper_test_prediction_shape(
-                season_length,
-                ts[first_var],
-                ts_train[first_var],
-                ts_val[first_var],
-                future_covariates=covariates,
-                kwargs_tft=kwargs_TFT_quick_test,
+        # should work with cyclic encoding for time index
+        model = TFTModel(
+            input_chunk_length=1,
+            output_chunk_length=1,
+            add_encoders={"cyclic": {"future": "hour"}},
+            **tfm_kwargs,
+        )
+        model.fit(ts_time_index, verbose=False)
+
+        # should work with relative index both with time index and integer index
+        model = TFTModel(
+            input_chunk_length=1,
+            output_chunk_length=1,
+            add_relative_index=True,
+            **tfm_kwargs,
+        )
+        model.fit(ts_time_index, verbose=False)
+        model.fit(ts_integer_index, verbose=False)
+
+    def test_prediction_shape(self):
+        """checks whether prediction has same number of variable as input series and
+        whether prediction has correct length.
+        Test cases:
+            -   univariate
+            -   multivariate
+            -   multi-TS
+        """
+        season_length = 1
+        n_repeat = 20
+
+        # data comes as multivariate
+        (
+            ts,
+            ts_train,
+            ts_val,
+            covariates,
+        ) = self.helper_generate_multivariate_case_data(season_length, n_repeat)
+
+        kwargs_TFT_quick_test = {
+            "input_chunk_length": 1,
+            "output_chunk_length": 1,
+            "n_epochs": 1,
+            "lstm_layers": 1,
+            "hidden_size": 8,
+            "loss_fn": MSELoss(),
+            "random_state": 42,
+        }
+        kwargs_TFT_quick_test = dict(kwargs_TFT_quick_test, **tfm_kwargs)
+
+        # univariate
+        first_var = ts.columns[0]
+        self.helper_test_prediction_shape(
+            season_length,
+            ts[first_var],
+            ts_train[first_var],
+            ts_val[first_var],
+            future_covariates=covariates,
+            kwargs_tft=kwargs_TFT_quick_test,
+        )
+        # univariate and short prediction length
+        self.helper_test_prediction_shape(
+            2,
+            ts[first_var],
+            ts_train[first_var],
+            ts_val[first_var],
+            future_covariates=covariates,
+            kwargs_tft=kwargs_TFT_quick_test,
+        )
+        # multivariate
+        self.helper_test_prediction_shape(
+            season_length,
+            ts,
+            ts_train,
+            ts_val,
+            future_covariates=covariates,
+            kwargs_tft=kwargs_TFT_quick_test,
+        )
+        # multi-TS
+        kwargs_TFT_quick_test["add_encoders"] = {"cyclic": {"future": "hour"}}
+        second_var = ts.columns[-1]
+        self.helper_test_prediction_shape(
+            season_length,
+            [ts[first_var], ts[second_var]],
+            [ts_train[first_var], ts_train[second_var]],
+            [ts_val[first_var], ts_val[second_var]],
+            future_covariates=None,
+            kwargs_tft=kwargs_TFT_quick_test,
+        )
+
+    def test_mixed_covariates_and_accuracy(self):
+        """Performs tests usingpast and future covariates for a multivariate prediction of a
+        sine wave together with a repeating linear curve. Both curves have the seasonal length.
+        """
+        season_length = 24
+        n_repeat = 30
+        (
+            ts,
+            ts_train,
+            ts_val,
+            covariates,
+        ) = self.helper_generate_multivariate_case_data(season_length, n_repeat)
+
+        kwargs_TFT_full_coverage = {
+            "input_chunk_length": 12,
+            "output_chunk_length": 12,
+            "n_epochs": 10,
+            "lstm_layers": 2,
+            "hidden_size": 32,
+            "likelihood": QuantileRegression(quantiles=[0.1, 0.5, 0.9]),
+            "random_state": 42,
+            "add_encoders": {"cyclic": {"future": "hour"}},
+        }
+        kwargs_TFT_full_coverage = dict(kwargs_TFT_full_coverage, **tfm_kwargs)
+
+        self.helper_test_prediction_accuracy(
+            season_length,
+            ts,
+            ts_train,
+            ts_val,
+            past_covariates=covariates,
+            future_covariates=covariates,
+            kwargs_tft=kwargs_TFT_full_coverage,
+        )
+
+    def test_static_covariates_support(self):
+        target_multi = concatenate(
+            [tg.sine_timeseries(length=10, freq="h")] * 2, axis=1
+        )
+
+        target_multi = target_multi.with_static_covariates(
+            pd.DataFrame(
+                [[0.0, 1.0, 0, 2], [2.0, 3.0, 1, 3]],
+                columns=["st1", "st2", "cat1", "cat2"],
             )
-            # univariate and short prediction length
-            self.helper_test_prediction_shape(
+        )
+
+        # should work with cyclic encoding for time index
+        # set the categorical embedding sizes once automatically with an `int` and once
+        # manually with a `tuple(int, int)`
+        model = TFTModel(
+            input_chunk_length=3,
+            output_chunk_length=4,
+            add_encoders={"cyclic": {"future": "hour"}},
+            categorical_embedding_sizes={"cat1": 2, "cat2": (2, 2)},
+            pl_trainer_kwargs={
+                "fast_dev_run": True,
+                **tfm_kwargs["pl_trainer_kwargs"],
+            },
+        )
+        model.fit(target_multi, verbose=False)
+
+        assert len(model.model.static_variables) == len(
+            target_multi.static_covariates.columns
+        )
+
+        # check model embeddings
+        target_embedding = {
+            "static_covariate_2": (
                 2,
-                ts[first_var],
-                ts_train[first_var],
-                ts_val[first_var],
-                future_covariates=covariates,
-                kwargs_tft=kwargs_TFT_quick_test,
+                get_embedding_size(2),
+            ),  # automatic embedding size
+            "static_covariate_3": (2, 2),  # manual embedding size
+        }
+        assert model.categorical_embedding_sizes == target_embedding
+        for cat_var, embedding_dims in target_embedding.items():
+            assert (
+                model.model.input_embeddings.embeddings[cat_var].num_embeddings
+                == embedding_dims[0]
             )
-            # multivariate
-            self.helper_test_prediction_shape(
-                season_length,
-                ts,
-                ts_train,
-                ts_val,
-                future_covariates=covariates,
-                kwargs_tft=kwargs_TFT_quick_test,
-            )
-            # multi-TS
-            kwargs_TFT_quick_test["add_encoders"] = {"cyclic": {"future": "hour"}}
-            second_var = ts.columns[-1]
-            self.helper_test_prediction_shape(
-                season_length,
-                [ts[first_var], ts[second_var]],
-                [ts_train[first_var], ts_train[second_var]],
-                [ts_val[first_var], ts_val[second_var]],
-                future_covariates=None,
-                kwargs_tft=kwargs_TFT_quick_test,
+            assert (
+                model.model.input_embeddings.embeddings[cat_var].embedding_dim
+                == embedding_dims[1]
             )
 
-        def test_mixed_covariates_and_accuracy(self):
-            """Performs tests usingpast and future covariates for a multivariate prediction of a
-            sine wave together with a repeating linear curve. Both curves have the seasonal length.
-            """
-            season_length = 24
-            n_repeat = 30
-            (
-                ts,
-                ts_train,
-                ts_val,
-                covariates,
-            ) = self.helper_generate_multivariate_case_data(season_length, n_repeat)
-
-            kwargs_TFT_full_coverage = {
-                "input_chunk_length": 12,
-                "output_chunk_length": 12,
-                "n_epochs": 10,
-                "lstm_layers": 2,
-                "hidden_size": 32,
-                "likelihood": QuantileRegression(quantiles=[0.1, 0.5, 0.9]),
-                "random_state": 42,
-                "add_encoders": {"cyclic": {"future": "hour"}},
-            }
-            kwargs_TFT_full_coverage = dict(kwargs_TFT_full_coverage, **tfm_kwargs)
-
-            self.helper_test_prediction_accuracy(
-                season_length,
-                ts,
-                ts_train,
-                ts_val,
-                past_covariates=covariates,
-                future_covariates=covariates,
-                kwargs_tft=kwargs_TFT_full_coverage,
-            )
+        preds = model.predict(n=1, series=target_multi, verbose=False)
+        assert preds.static_covariates.equals(target_multi.static_covariates)
 
-        def test_static_covariates_support(self):
-            target_multi = concatenate(
-                [tg.sine_timeseries(length=10, freq="h")] * 2, axis=1
-            )
-
-            target_multi = target_multi.with_static_covariates(
-                pd.DataFrame(
-                    [[0.0, 1.0, 0, 2], [2.0, 3.0, 1, 3]],
-                    columns=["st1", "st2", "cat1", "cat2"],
-                )
-            )
-
-            # should work with cyclic encoding for time index
-            # set categorical embedding sizes once with automatic embedding size with an `int` and once by
-            # manually setting it with `tuple(int, int)`
-            model = TFTModel(
-                input_chunk_length=3,
-                output_chunk_length=4,
-                add_encoders={"cyclic": {"future": "hour"}},
-                categorical_embedding_sizes={"cat1": 2, "cat2": (2, 2)},
-                pl_trainer_kwargs={
-                    "fast_dev_run": True,
-                    **tfm_kwargs["pl_trainer_kwargs"],
-                },
-            )
-            model.fit(target_multi, verbose=False)
+        # raise an error when trained with static covariates of wrong dimensionality
+        target_multi = target_multi.with_static_covariates(
+            pd.concat([target_multi.static_covariates] * 2, axis=1)
+        )
+        with pytest.raises(ValueError):
+            model.predict(n=1, series=target_multi, verbose=False)
 
-            assert len(model.model.static_variables) == len(
-                target_multi.static_covariates.columns
+        # raise an error when trained with static covariates and trying to predict without
+        with pytest.raises(ValueError):
+            model.predict(
+                n=1, series=target_multi.with_static_covariates(None), verbose=False
             )
 
-            # check model embeddings
-            target_embedding = {
-                "static_covariate_2": (
-                    2,
-                    get_embedding_size(2),
-                ),  # automatic embedding size
-                "static_covariate_3": (2, 2),  # manual embedding size
-            }
-            assert model.categorical_embedding_sizes == target_embedding
-            for cat_var, embedding_dims in target_embedding.items():
-                assert (
-                    model.model.input_embeddings.embeddings[cat_var].num_embeddings
-                    == embedding_dims[0]
-                )
-                assert (
-                    model.model.input_embeddings.embeddings[cat_var].embedding_dim
-                    == embedding_dims[1]
-                )
-
-            preds = model.predict(n=1, series=target_multi, verbose=False)
-            assert preds.static_covariates.equals(target_multi.static_covariates)
-
-            # raise an error when trained with static covariates of wrong dimensionality
-            target_multi = target_multi.with_static_covariates(
-                pd.concat([target_multi.static_covariates] * 2, axis=1)
-            )
-            with pytest.raises(ValueError):
-                model.predict(n=1, series=target_multi, verbose=False)
-
-            # raise an error when trained with static covariates and trying to predict without
-            with pytest.raises(ValueError):
-                model.predict(
-                    n=1, series=target_multi.with_static_covariates(None), verbose=False
-                )
-
-            # with `use_static_covariates=False`, we can predict without static covs
-            model = TFTModel(
-                input_chunk_length=3,
-                output_chunk_length=4,
-                use_static_covariates=False,
-                add_relative_index=True,
-                n_epochs=1,
-                **tfm_kwargs,
-            )
-            model.fit(target_multi)
-            preds = model.predict(n=2, series=target_multi.with_static_covariates(None))
-            assert preds.static_covariates is None
-
-            model = TFTModel(
-                input_chunk_length=3,
-                output_chunk_length=4,
-                use_static_covariates=False,
-                add_relative_index=True,
-                n_epochs=1,
-                **tfm_kwargs,
-            )
-            model.fit(target_multi.with_static_covariates(None))
-            preds = model.predict(n=2, series=target_multi)
-            assert preds.static_covariates.equals(target_multi.static_covariates)
-
-        def helper_generate_multivariate_case_data(self, season_length, n_repeat):
-            """generates multivariate test case data. Target series is a sine wave stacked with a repeating
-            linear curve of equal seasonal length. Covariates are datetime attributes for 'hours'.
-            """
-
-            # generate sine wave
-            ts_sine = tg.sine_timeseries(
-                value_frequency=1 / season_length,
-                length=n_repeat * season_length,
-                freq="h",
-            )
-
-            # generate repeating linear curve
-            ts_linear = tg.linear_timeseries(
-                0, 1, length=season_length, start=ts_sine.end_time() + ts_sine.freq
-            )
-            for i in range(n_repeat - 1):
-                start = ts_linear.end_time() + ts_linear.freq
-                new_ts = tg.linear_timeseries(0, 1, length=season_length, start=start)
-                ts_linear = ts_linear.append(new_ts)
-            ts_linear = TimeSeries.from_times_and_values(
-                times=ts_sine.time_index, values=ts_linear.values()
-            )
-
-            # create multivariate TimeSeries by stacking sine and linear curves
-            ts = ts_sine.stack(ts_linear)
-
-            # create train/test sets
-            val_length = 10 * season_length
-            ts_train, ts_val = ts[:-val_length], ts[-val_length:]
-
-            # scale data
-            scaler_ts = Scaler()
-            ts_train_scaled = scaler_ts.fit_transform(ts_train)
-            ts_val_scaled = scaler_ts.transform(ts_val)
-            ts_scaled = scaler_ts.transform(ts)
-
-            # generate long enough covariates (past and future covariates will be the same for simplicity)
-            long_enough_ts = tg.sine_timeseries(
-                value_frequency=1 / season_length, length=1000, freq=ts.freq
-            )
-            covariates = tg.datetime_attribute_timeseries(
-                long_enough_ts, attribute="hour"
-            )
-            scaler_covs = Scaler()
-            covariates_scaled = scaler_covs.fit_transform(covariates)
-            return ts_scaled, ts_train_scaled, ts_val_scaled, covariates_scaled
-
-        def helper_test_prediction_shape(
-            self, predict_n, ts, ts_train, ts_val, future_covariates, kwargs_tft
-        ):
-            """checks whether prediction has same number of variable as input series and
-            whether prediction has correct length"""
-            y_hat = self.helper_fit_predict(
-                predict_n, ts_train, ts_val, None, future_covariates, kwargs_tft
-            )
-
-            y_hat_list = [y_hat] if isinstance(y_hat, TimeSeries) else y_hat
-            ts_list = [ts] if isinstance(ts, TimeSeries) else ts
-
-            for y_hat_i, ts_i in zip(y_hat_list, ts_list):
-                assert len(y_hat_i) == predict_n
-                assert y_hat_i.n_components == ts_i.n_components
-
-        def helper_test_prediction_accuracy(
-            self,
+        # with `use_static_covariates=False`, we can predict without static covs
+        model = TFTModel(
+            input_chunk_length=3,
+            output_chunk_length=4,
+            use_static_covariates=False,
+            add_relative_index=True,
+            n_epochs=1,
+            **tfm_kwargs,
+        )
+        model.fit(target_multi)
+        preds = model.predict(n=2, series=target_multi.with_static_covariates(None))
+        assert preds.static_covariates is None
+
+        model = TFTModel(
+            input_chunk_length=3,
+            output_chunk_length=4,
+            use_static_covariates=False,
+            add_relative_index=True,
+            n_epochs=1,
+            **tfm_kwargs,
+        )
+        model.fit(target_multi.with_static_covariates(None))
+        preds = model.predict(n=2, series=target_multi)
+        assert preds.static_covariates.equals(target_multi.static_covariates)
+
+    def helper_generate_multivariate_case_data(self, season_length, n_repeat):
+        """generates multivariate test case data. Target series is a sine wave stacked with a repeating
+        linear curve of equal seasonal length. Covariates are datetime attributes for 'hours'.
+        """
+
+        # generate sine wave
+        ts_sine = tg.sine_timeseries(
+            value_frequency=1 / season_length,
+            length=n_repeat * season_length,
+            freq="h",
+        )
+
+        # generate repeating linear curve
+        ts_linear = tg.linear_timeseries(
+            0, 1, length=season_length, start=ts_sine.end_time() + ts_sine.freq
+        )
+        for i in range(n_repeat - 1):
+            start = ts_linear.end_time() + ts_linear.freq
+            new_ts = tg.linear_timeseries(0, 1, length=season_length, start=start)
+            ts_linear = ts_linear.append(new_ts)
+        ts_linear = TimeSeries.from_times_and_values(
+            times=ts_sine.time_index, values=ts_linear.values()
+        )
+
+        # create multivariate TimeSeries by stacking sine and linear curves
+        ts = ts_sine.stack(ts_linear)
+
+        # create train/test sets
+        val_length = 10 * season_length
+        ts_train, ts_val = ts[:-val_length], ts[-val_length:]
+
+        # scale data
+        scaler_ts = Scaler()
+        ts_train_scaled = scaler_ts.fit_transform(ts_train)
+        ts_val_scaled = scaler_ts.transform(ts_val)
+        ts_scaled = scaler_ts.transform(ts)
+
+        # generate long enough covariates (past and future covariates will be the same for simplicity)
+        long_enough_ts = tg.sine_timeseries(
+            value_frequency=1 / season_length, length=1000, freq=ts.freq
+        )
+        covariates = tg.datetime_attribute_timeseries(long_enough_ts, attribute="hour")
+        scaler_covs = Scaler()
+        covariates_scaled = scaler_covs.fit_transform(covariates)
+        return ts_scaled, ts_train_scaled, ts_val_scaled, covariates_scaled
+
+    def helper_test_prediction_shape(
+        self, predict_n, ts, ts_train, ts_val, future_covariates, kwargs_tft
+    ):
+        """checks whether prediction has same number of variable as input series and
+        whether prediction has correct length"""
+        y_hat = self.helper_fit_predict(
+            predict_n, ts_train, ts_val, None, future_covariates, kwargs_tft
+        )
+
+        y_hat_list = [y_hat] if isinstance(y_hat, TimeSeries) else y_hat
+        ts_list = [ts] if isinstance(ts, TimeSeries) else ts
+
+        for y_hat_i, ts_i in zip(y_hat_list, ts_list):
+            assert len(y_hat_i) == predict_n
+            assert y_hat_i.n_components == ts_i.n_components
+
+    def helper_test_prediction_accuracy(
+        self,
+        predict_n,
+        ts,
+        ts_train,
+        ts_val,
+        past_covariates,
+        future_covariates,
+        kwargs_tft,
+    ):
+        """prediction should be almost equal to y_true. Absolute tolarance is set
+        to 0.2 to give some flexibility"""
+
+        absolute_tolarance = 0.2
+        y_hat = self.helper_fit_predict(
             predict_n,
-            ts,
             ts_train,
             ts_val,
             past_covariates,
             future_covariates,
             kwargs_tft,
-        ):
-            """prediction should be almost equal to y_true. Absolute tolarance is set
-            to 0.2 to give some flexibility"""
-
-            absolute_tolarance = 0.2
-            y_hat = self.helper_fit_predict(
-                predict_n,
-                ts_train,
-                ts_val,
-                past_covariates,
-                future_covariates,
-                kwargs_tft,
-            )
-
-            y_true = ts[y_hat.start_time() : y_hat.end_time()]
-            assert np.allclose(
-                y_true[1:-1].all_values(),
-                y_hat[1:-1].all_values(),
-                atol=absolute_tolarance,
-            )
-
-        @staticmethod
-        def helper_fit_predict(
-            predict_n, ts_train, ts_val, past_covariates, future_covariates, kwargs_tft
-        ):
-            """simple helper that returns prediction for the individual test cases"""
-            model = TFTModel(**kwargs_tft)
-
-            model.fit(
-                ts_train,
-                past_covariates=past_covariates,
-                future_covariates=future_covariates,
-                val_series=ts_val,
-                val_past_covariates=past_covariates,
-                val_future_covariates=future_covariates,
-                verbose=False,
-            )
-
-            series = None if isinstance(ts_train, TimeSeries) else ts_train
-            y_hat = model.predict(
-                n=predict_n,
-                series=series,
-                past_covariates=past_covariates,
-                future_covariates=future_covariates,
-                num_samples=(100 if model.supports_probabilistic_prediction else 1),
-            )
-
-            if isinstance(y_hat, TimeSeries):
-                y_hat = y_hat.quantile_timeseries(0.5) if y_hat.n_samples > 1 else y_hat
-            else:
-                y_hat = [
-                    ts.quantile_timeseries(0.5) if ts.n_samples > 1 else ts
-                    for ts in y_hat
-                ]
-            return y_hat
-
-        def test_layer_norm(self):
-            times = pd.date_range("20130101", "20130410")
-            pd_series = pd.Series(range(100), index=times)
-            series: TimeSeries = TimeSeries.from_series(pd_series)
-            base_model = TFTModel
-
-            model1 = base_model(
-                input_chunk_length=1,
-                output_chunk_length=1,
-                add_relative_index=True,
-                norm_type="RMSNorm",
-                **tfm_kwargs,
-            )
-            model1.fit(series, epochs=1)
-
-            model2 = base_model(
+        )
+
+        y_true = ts[y_hat.start_time() : y_hat.end_time()]
+        assert np.allclose(
+            y_true[1:-1].all_values(),
+            y_hat[1:-1].all_values(),
+            atol=absolute_tolerance,
+        )
+
+    @staticmethod
+    def helper_fit_predict(
+        predict_n, ts_train, ts_val, past_covariates, future_covariates, kwargs_tft
+    ):
+        """simple helper that returns prediction for the individual test cases"""
+        model = TFTModel(**kwargs_tft)
+
+        model.fit(
+            ts_train,
+            past_covariates=past_covariates,
+            future_covariates=future_covariates,
+            val_series=ts_val,
+            val_past_covariates=past_covariates,
+            val_future_covariates=future_covariates,
+            verbose=False,
+        )
+
+        series = None if isinstance(ts_train, TimeSeries) else ts_train
+        y_hat = model.predict(
+            n=predict_n,
+            series=series,
+            past_covariates=past_covariates,
+            future_covariates=future_covariates,
+            num_samples=(100 if model.supports_probabilistic_prediction else 1),
+        )
+
+        if isinstance(y_hat, TimeSeries):
+            y_hat = y_hat.quantile_timeseries(0.5) if y_hat.n_samples > 1 else y_hat
+        else:
+            y_hat = [
+                ts.quantile_timeseries(0.5) if ts.n_samples > 1 else ts for ts in y_hat
+            ]
+        return y_hat
+
+    def test_layer_norm(self):
+        times = pd.date_range("20130101", "20130410")
+        pd_series = pd.Series(range(100), index=times)
+        series: TimeSeries = TimeSeries.from_series(pd_series)
+        base_model = TFTModel
+
+        model1 = base_model(
+            input_chunk_length=1,
+            output_chunk_length=1,
+            add_relative_index=True,
+            norm_type="RMSNorm",
+            **tfm_kwargs,
+        )
+        model1.fit(series, epochs=1)
+
+        model2 = base_model(
+            input_chunk_length=1,
+            output_chunk_length=1,
+            add_relative_index=True,
+            norm_type=nn.LayerNorm,
+            **tfm_kwargs,
+        )
+        model2.fit(series, epochs=1)
+
+        with pytest.raises(AttributeError):
+            model4 = base_model(
                 input_chunk_length=1,
                 output_chunk_length=1,
                 add_relative_index=True,
-                norm_type=nn.LayerNorm,
+                norm_type="invalid",
                 **tfm_kwargs,
             )
-            model2.fit(series, epochs=1)
-
-            with pytest.raises(AttributeError):
-                model4 = base_model(
-                    input_chunk_length=1,
-                    output_chunk_length=1,
-                    add_relative_index=True,
-                    norm_type="invalid",
-                    **tfm_kwargs,
-                )
-                model4.fit(series, epochs=1)
+            model4.fit(series, epochs=1)
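`test_quantile_regression` above rejects [0.1, 0.4, 0.9] (no median) and [0.2, 0.5, 0.9] (not symmetric). An illustrative validation consistent with those two cases, not darts' actual implementation:

    def check_quantiles(quantiles):
        if 0.5 not in quantiles:
            raise ValueError("quantiles must include the median q=0.5")
        q = sorted(quantiles)
        # pair the lowest quantile with the highest and walk inwards;
        # symmetry around the median means each pair sums to 1.0
        for lo, hi in zip(q, reversed(q)):
            if abs(lo + hi - 1.0) > 1e-9:
                raise ValueError("quantiles must be symmetric around q=0.5")

    check_quantiles([0.1, 0.5, 0.9])    # passes
    # check_quantiles([0.1, 0.4, 0.9])  # raises: median missing
    # check_quantiles([0.2, 0.5, 0.9])  # raises: 0.2 + 0.9 != 1.0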
diff --git a/darts/tests/models/forecasting/test_block_RNN.py b/darts/tests/models/forecasting/test_block_RNN.py
index 9415ace0f4..3e69836e04 100644
--- a/darts/tests/models/forecasting/test_block_RNN.py
+++ b/darts/tests/models/forecasting/test_block_RNN.py
@@ -16,171 +16,171 @@
         CustomBlockRNNModule,
         _BlockRNNModule,
     )
-
-    TORCH_AVAILABLE = True
 except ImportError:
-    logger.warning("Torch not available. RNN tests will be skipped.")
-    TORCH_AVAILABLE = False
+    pytest.skip(
+        f"Torch not available. {__name__} tests will be skipped.",
+        allow_module_level=True,
+    )
 
 
-if TORCH_AVAILABLE:
+class ModuleValid1(_BlockRNNModule):
+    """Wrapper around the _BlockRNNModule"""
 
-    class ModuleValid1(_BlockRNNModule):
-        """Wrapper around the _BlockRNNModule"""
+    def __init__(self, **kwargs):
+        super().__init__(name="RNN", **kwargs)
 
-        def __init__(self, **kwargs):
-            super().__init__(name="RNN", **kwargs)
 
-    class ModuleValid2(CustomBlockRNNModule):
-        """Just a linear layer."""
+class ModuleValid2(CustomBlockRNNModule):
+    """Just a linear layer."""
 
-        def __init__(self, **kwargs):
-            super().__init__(**kwargs)
-            self.linear = nn.Linear(self.input_size, self.target_size)
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+        self.linear = nn.Linear(self.input_size, self.target_size)
 
-        def forward(self, x_in):
-            x = self.linear(x_in[0])
-            return x.view(len(x), -1, self.target_size, self.nr_params)
+    def forward(self, x_in):
+        x = self.linear(x_in[0])
+        return x.view(len(x), -1, self.target_size, self.nr_params)
 
-    class TestBlockRNNModel:
-        times = pd.date_range("20130101", "20130410")
-        pd_series = pd.Series(range(100), index=times)
-        series: TimeSeries = TimeSeries.from_series(pd_series)
-        module_invalid = _BlockRNNModule(
-            "RNN",
-            input_size=1,
-            input_chunk_length=1,
-            output_chunk_length=1,
-            output_chunk_shift=0,
-            hidden_dim=25,
-            target_size=1,
-            nr_params=1,
-            num_layers=1,
-            num_layers_out_fc=[],
-            dropout=0,
-        )
 
-        def test_creation(self):
-            # cannot choose any string
-            with pytest.raises(ValueError) as msg:
-                BlockRNNModel(
-                    input_chunk_length=1, output_chunk_length=1, model="UnknownRNN?"
-                )
-            assert str(msg.value).startswith("`model` is not a valid RNN model.")
-
-            # cannot create from a class instance
-            with pytest.raises(ValueError) as msg:
-                _ = BlockRNNModel(
-                    input_chunk_length=1,
-                    output_chunk_length=1,
-                    model=self.module_invalid,
-                )
-            assert str(msg.value).startswith("`model` is not a valid RNN model.")
-
-            # can create from valid module name
-            model1 = BlockRNNModel(
-                input_chunk_length=1,
-                output_chunk_length=1,
-                model="RNN",
-                n_epochs=1,
-                random_state=42,
-                **tfm_kwargs,
-            )
-            model1.fit(self.series)
-            preds1 = model1.predict(n=3)
+class TestBlockRNNModel:
+    times = pd.date_range("20130101", "20130410")
+    pd_series = pd.Series(range(100), index=times)
+    series: TimeSeries = TimeSeries.from_series(pd_series)
+    module_invalid = _BlockRNNModule(
+        "RNN",
+        input_size=1,
+        input_chunk_length=1,
+        output_chunk_length=1,
+        output_chunk_shift=0,
+        hidden_dim=25,
+        target_size=1,
+        nr_params=1,
+        num_layers=1,
+        num_layers_out_fc=[],
+        dropout=0,
+    )
 
-            # can create from a custom class itself
-            model2 = BlockRNNModel(
-                input_chunk_length=1,
-                output_chunk_length=1,
-                model=ModuleValid1,
-                n_epochs=1,
-                random_state=42,
-                **tfm_kwargs,
+    def test_creation(self):
+        # cannot choose any string
+        with pytest.raises(ValueError) as msg:
+            BlockRNNModel(
+                input_chunk_length=1, output_chunk_length=1, model="UnknownRNN?"
             )
-            model2.fit(self.series)
-            preds2 = model2.predict(n=3)
-            np.testing.assert_array_equal(preds1.all_values(), preds2.all_values())
+        assert str(msg.value).startswith("`model` is not a valid RNN model.")
 
-            model3 = BlockRNNModel(
+        # cannot create from a class instance
+        with pytest.raises(ValueError) as msg:
+            _ = BlockRNNModel(
                 input_chunk_length=1,
                 output_chunk_length=1,
-                model=ModuleValid2,
-                n_epochs=1,
-                random_state=42,
-                **tfm_kwargs,
+                model=self.module_invalid,
             )
-            model3.fit(self.series)
-            preds3 = model2.predict(n=3)
-            assert preds3.all_values().shape == preds2.all_values().shape
-            assert preds3.time_index.equals(preds2.time_index)
-
-        def test_fit(self, tmpdir_module):
-            # Test basic fit()
-            model = BlockRNNModel(
-                input_chunk_length=1, output_chunk_length=1, n_epochs=2, **tfm_kwargs
-            )
-            model.fit(self.series)
+        assert str(msg.value).startswith("`model` is not a valid RNN model.")
 
-            # Test fit-save-load cycle
-            model2 = BlockRNNModel(
-                input_chunk_length=1,
-                output_chunk_length=1,
-                model="LSTM",
-                n_epochs=1,
-                model_name="unittest-model-lstm",
-                work_dir=tmpdir_module,
-                save_checkpoints=True,
-                force_reset=True,
-                **tfm_kwargs,
-            )
-            model2.fit(self.series)
-            model_loaded = model2.load_from_checkpoint(
-                model_name="unittest-model-lstm",
-                work_dir=tmpdir_module,
-                best=False,
-                map_location="cpu",
-            )
-            pred1 = model2.predict(n=6)
-            pred2 = model_loaded.predict(n=6)
+        # can create from valid module name
+        model1 = BlockRNNModel(
+            input_chunk_length=1,
+            output_chunk_length=1,
+            model="RNN",
+            n_epochs=1,
+            random_state=42,
+            **tfm_kwargs,
+        )
+        model1.fit(self.series)
+        preds1 = model1.predict(n=3)
 
-            # Two models with the same parameters should deterministically yield the same output
-            np.testing.assert_array_equal(pred1.values(), pred2.values())
+        # can create from a custom class itself
+        model2 = BlockRNNModel(
+            input_chunk_length=1,
+            output_chunk_length=1,
+            model=ModuleValid1,
+            n_epochs=1,
+            random_state=42,
+            **tfm_kwargs,
+        )
+        model2.fit(self.series)
+        preds2 = model2.predict(n=3)
+        np.testing.assert_array_equal(preds1.all_values(), preds2.all_values())
 
-            # Another random model should not
-            model3 = BlockRNNModel(
-                input_chunk_length=1,
-                output_chunk_length=1,
-                model="RNN",
-                n_epochs=2,
-                **tfm_kwargs,
-            )
-            model3.fit(self.series)
-            pred3 = model3.predict(n=6)
-            assert not np.array_equal(pred1.values(), pred3.values())
-
-            # test short predict
-            pred4 = model3.predict(n=1)
-            assert len(pred4) == 1
-
-            # test validation series input
-            model3.fit(self.series[:60], val_series=self.series[60:])
-            pred4 = model3.predict(n=6)
-            assert len(pred4) == 6
-
-        def helper_test_pred_length(self, pytorch_model, series):
-            model = pytorch_model(
-                input_chunk_length=1, output_chunk_length=3, n_epochs=1, **tfm_kwargs
-            )
-            model.fit(series)
-            pred = model.predict(7)
-            assert len(pred) == 7
-            pred = model.predict(2)
-            assert len(pred) == 2
-            assert pred.width == 1
-            pred = model.predict(4)
-            assert len(pred) == 4
-            assert pred.width == 1
-
-        def test_pred_length(self):
-            self.helper_test_pred_length(BlockRNNModel, self.series)
+        model3 = BlockRNNModel(
+            input_chunk_length=1,
+            output_chunk_length=1,
+            model=ModuleValid2,
+            n_epochs=1,
+            random_state=42,
+            **tfm_kwargs,
+        )
+        model3.fit(self.series)
+        preds3 = model3.predict(n=3)
+        assert preds3.all_values().shape == preds2.all_values().shape
+        assert preds3.time_index.equals(preds2.time_index)
+
+    def test_fit(self, tmpdir_module):
+        # Test basic fit()
+        model = BlockRNNModel(
+            input_chunk_length=1, output_chunk_length=1, n_epochs=2, **tfm_kwargs
+        )
+        model.fit(self.series)
+
+        # Test fit-save-load cycle
+        model2 = BlockRNNModel(
+            input_chunk_length=1,
+            output_chunk_length=1,
+            model="LSTM",
+            n_epochs=1,
+            model_name="unittest-model-lstm",
+            work_dir=tmpdir_module,
+            save_checkpoints=True,
+            force_reset=True,
+            **tfm_kwargs,
+        )
+        model2.fit(self.series)
+        model_loaded = model2.load_from_checkpoint(
+            model_name="unittest-model-lstm",
+            work_dir=tmpdir_module,
+            best=False,
+            map_location="cpu",
+        )
+        pred1 = model2.predict(n=6)
+        pred2 = model_loaded.predict(n=6)
+
+        # Two models with the same parameters should deterministically yield the same output
+        np.testing.assert_array_equal(pred1.values(), pred2.values())
+
+        # Another random model should not
+        model3 = BlockRNNModel(
+            input_chunk_length=1,
+            output_chunk_length=1,
+            model="RNN",
+            n_epochs=2,
+            **tfm_kwargs,
+        )
+        model3.fit(self.series)
+        pred3 = model3.predict(n=6)
+        assert not np.array_equal(pred1.values(), pred3.values())
+
+        # test short predict
+        pred4 = model3.predict(n=1)
+        assert len(pred4) == 1
+
+        # test validation series input
+        model3.fit(self.series[:60], val_series=self.series[60:])
+        pred4 = model3.predict(n=6)
+        assert len(pred4) == 6
+
+    def helper_test_pred_length(self, pytorch_model, series):
+        model = pytorch_model(
+            input_chunk_length=1, output_chunk_length=3, n_epochs=1, **tfm_kwargs
+        )
+        model.fit(series)
+        pred = model.predict(7)
+        assert len(pred) == 7
+        pred = model.predict(2)
+        assert len(pred) == 2
+        assert pred.width == 1
+        pred = model.predict(4)
+        assert len(pred) == 4
+        assert pred.width == 1
+
+    def test_pred_length(self):
+        self.helper_test_pred_length(BlockRNNModel, self.series)
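
Every module touched by this patch follows the same recipe: the torch-dependent imports
stay inside a try/except guard, but instead of setting a `TORCH_AVAILABLE` flag and
indenting all tests under `if TORCH_AVAILABLE:`, the except branch now skips the whole
module, which lets the test classes move back to module level. A minimal sketch of the
guard (the imported name is illustrative):

    import pytest

    try:
        from darts.models.forecasting.block_rnn_model import BlockRNNModel  # needs torch
    except ImportError:
        pytest.skip(
            f"Torch not available. {__name__} tests will be skipped.",
            allow_module_level=True,
        )

With `allow_module_level=True`, `pytest.skip` may be called at module scope during
collection; pytest then reports the module as skipped instead of erroring on the import.
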
diff --git a/darts/tests/models/forecasting/test_dlinear_nlinear.py b/darts/tests/models/forecasting/test_dlinear_nlinear.py
index 5aca7c5f2b..61caa193d1 100644
--- a/darts/tests/models/forecasting/test_dlinear_nlinear.py
+++ b/darts/tests/models/forecasting/test_dlinear_nlinear.py
@@ -18,330 +18,320 @@
     from darts.models.forecasting.dlinear import DLinearModel
     from darts.models.forecasting.nlinear import NLinearModel
     from darts.utils.likelihood_models import GaussianLikelihood
+except ImportError:
+    pytest.skip(
+        f"Torch not available. {__name__} tests will be skipped.",
+        allow_module_level=True,
+    )
 
-    TORCH_AVAILABLE = True
 
-except ImportError:
-    logger.warning("Torch not available. Dlinear and NLinear tests will be skipped.")
-    TORCH_AVAILABLE = False
+class TestDlinearNlinearModels:
+    np.random.seed(42)
+    torch.manual_seed(42)
+
+    def test_creation(self):
+        with pytest.raises(ValueError):
+            DLinearModel(
+                input_chunk_length=1,
+                output_chunk_length=1,
+                normalize=True,
+                likelihood=GaussianLikelihood(),
+            )
 
-if TORCH_AVAILABLE:
+        with pytest.raises(ValueError):
+            NLinearModel(
+                input_chunk_length=1,
+                output_chunk_length=1,
+                normalize=True,
+                likelihood=GaussianLikelihood(),
+            )
 
-    class TestDlinearNlinearModels:
+    def test_fit(self):
+        large_ts = tg.constant_timeseries(length=100, value=1000)
+        small_ts = tg.constant_timeseries(length=100, value=10)
+
+        for model_cls, kwargs in [
+            (DLinearModel, {"kernel_size": 5}),
+            (DLinearModel, {"kernel_size": 6}),
+            (NLinearModel, {}),
+        ]:
+            # Test basic fit and predict
+            model = model_cls(
+                input_chunk_length=1,
+                output_chunk_length=1,
+                n_epochs=10,
+                random_state=42,
+                **kwargs,
+                **tfm_kwargs,
+            )
+            model.fit(large_ts[:98])
+            pred = model.predict(n=2).values()[0]
+
+            # Test that the model trained on the small-valued series predicts its
+            # value (10) better than the model trained on the large-valued series
+            model2 = model_cls(
+                input_chunk_length=1,
+                output_chunk_length=1,
+                n_epochs=10,
+                random_state=42,
+                **tfm_kwargs,
+            )
+            model2.fit(small_ts[:98])
+            pred2 = model2.predict(n=2).values()[0]
+            assert abs(pred2 - 10) < abs(pred - 10)
+
+            # test short predict
+            pred3 = model2.predict(n=1)
+            assert len(pred3) == 1
+
+    def test_logtensorboard(self, tmpdir_module):
+        ts = tg.constant_timeseries(length=50, value=10)
+
+        for model_cls in [DLinearModel, NLinearModel]:
+            # Test basic fit and predict
+            model = model_cls(
+                input_chunk_length=1,
+                output_chunk_length=1,
+                n_epochs=1,
+                log_tensorboard=True,
+                work_dir=tmpdir_module,
+                pl_trainer_kwargs={
+                    "log_every_n_steps": 1,
+                    **tfm_kwargs["pl_trainer_kwargs"],
+                },
+            )
+            model.fit(ts)
+            model.predict(n=2)
+
+    def test_shared_weights(self):
+        ts = tg.constant_timeseries(length=50, value=10).stack(
+            tg.gaussian_timeseries(length=50)
+        )
+
+        for model_cls in [DLinearModel, NLinearModel]:
+            # Test basic fit and predict
+            model_shared = model_cls(
+                input_chunk_length=5,
+                output_chunk_length=1,
+                n_epochs=2,
+                const_init=False,
+                shared_weights=True,
+                random_state=42,
+                **tfm_kwargs,
+            )
+            model_not_shared = model_cls(
+                input_chunk_length=5,
+                output_chunk_length=1,
+                n_epochs=2,
+                const_init=False,
+                shared_weights=False,
+                random_state=42,
+                **tfm_kwargs,
+            )
+            model_shared.fit(ts)
+            model_not_shared.fit(ts)
+            pred_shared = model_shared.predict(n=2)
+            pred_not_shared = model_not_shared.predict(n=2)
+            assert np.any(np.not_equal(pred_shared.values(), pred_not_shared.values()))
+
+    def test_multivariate_and_covariates(self):
         np.random.seed(42)
         torch.manual_seed(42)
+        # test on multiple multivariate series with future and static covariates
+
+        def _create_multiv_series(f1, f2, n1, n2, nf1, nf2):
+            bases = [
+                tg.sine_timeseries(length=400, value_frequency=f, value_amplitude=1.0)
+                for f in (f1, f2)
+            ]
+            noises = [tg.gaussian_timeseries(length=400, std=n) for n in (n1, n2)]
+            noise_modulators = [
+                tg.sine_timeseries(length=400, value_frequency=nf)
+                + tg.constant_timeseries(length=400, value=1) / 2
+                for nf in (nf1, nf2)
+            ]
+            noises = [noises[i] * noise_modulators[i] for i in range(len(noises))]
+
+            target = concatenate(
+                [bases[i] + noises[i] for i in range(len(bases))], axis="component"
+            )
 
-        def test_creation(self):
-            with pytest.raises(ValueError):
-                DLinearModel(
-                    input_chunk_length=1,
-                    output_chunk_length=1,
-                    normalize=True,
-                    likelihood=GaussianLikelihood(),
-                )
+            target = target.with_static_covariates(
+                pd.DataFrame([[f1, n1, nf1], [f2, n2, nf2]])
+            )
 
-            with pytest.raises(ValueError):
-                NLinearModel(
-                    input_chunk_length=1,
-                    output_chunk_length=1,
-                    normalize=True,
-                    likelihood=GaussianLikelihood(),
-                )
+            return target, concatenate(noise_modulators, axis="component")
+
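+        # fit one model on both train series; return the RMSE on each validation series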
+        def _eval_model(
+            train1,
+            train2,
+            val1,
+            val2,
+            fut_cov1,
+            fut_cov2,
+            past_cov1=None,
+            past_cov2=None,
+            val_past_cov1=None,
+            val_past_cov2=None,
+            cls=DLinearModel,
+            lkl=None,
+            **kwargs,
+        ):
+            model = cls(
+                input_chunk_length=50,
+                output_chunk_length=10,
+                shared_weights=False,
+                const_init=True,
+                likelihood=lkl,
+                random_state=42,
+                **tfm_kwargs,
+            )
 
-        def test_fit(self):
-            large_ts = tg.constant_timeseries(length=100, value=1000)
-            small_ts = tg.constant_timeseries(length=100, value=10)
-
-            for model_cls, kwargs in [
-                (DLinearModel, {"kernel_size": 5}),
-                (DLinearModel, {"kernel_size": 6}),
-                (NLinearModel, {}),
-            ]:
-                # Test basic fit and predict
-                model = model_cls(
-                    input_chunk_length=1,
-                    output_chunk_length=1,
-                    n_epochs=10,
-                    random_state=42,
-                    **kwargs,
-                    **tfm_kwargs,
-                )
-                model.fit(large_ts[:98])
-                pred = model.predict(n=2).values()[0]
-
-                # Test whether model trained on one series is better than one trained on another
-                model2 = model_cls(
-                    input_chunk_length=1,
-                    output_chunk_length=1,
-                    n_epochs=10,
-                    random_state=42,
-                    **tfm_kwargs,
-                )
-                model2.fit(small_ts[:98])
-                pred2 = model2.predict(n=2).values()[0]
-                assert abs(pred2 - 10) < abs(pred - 10)
-
-                # test short predict
-                pred3 = model2.predict(n=1)
-                assert len(pred3) == 1
-
-        def test_logtensorboard(self, tmpdir_module):
-            ts = tg.constant_timeseries(length=50, value=10)
-
-            for model_cls in [DLinearModel, NLinearModel]:
-                # Test basic fit and predict
-                model = model_cls(
-                    input_chunk_length=1,
-                    output_chunk_length=1,
-                    n_epochs=1,
-                    log_tensorboard=True,
-                    work_dir=tmpdir_module,
-                    pl_trainer_kwargs={
-                        "log_every_n_steps": 1,
-                        **tfm_kwargs["pl_trainer_kwargs"],
-                    },
-                )
-                model.fit(ts)
-                model.predict(n=2)
+            model.fit(
+                [train1, train2],
+                past_covariates=(
+                    [past_cov1, past_cov2] if past_cov1 is not None else None
+                ),
+                val_past_covariates=(
+                    [val_past_cov1, val_past_cov2]
+                    if val_past_cov1 is not None
+                    else None
+                ),
+                future_covariates=(
+                    [fut_cov1, fut_cov2] if fut_cov1 is not None else None
+                ),
+                epochs=10,
+            )
 
-        def test_shared_weights(self):
-            ts = tg.constant_timeseries(length=50, value=10).stack(
-                tg.gaussian_timeseries(length=50)
+            pred1, pred2 = model.predict(
+                series=[train1, train2],
+                future_covariates=(
+                    [fut_cov1, fut_cov2] if fut_cov1 is not None else None
+                ),
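+                # the full-length fut_cov series double as past covariates here: predicting
+                # n > output_chunk_length steps requires past covariates known beyond the forecast start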
+                past_covariates=(
+                    [fut_cov1, fut_cov2] if past_cov1 is not None else None
+                ),
+                n=len(val1),
+                num_samples=500 if lkl is not None else 1,
             )
 
-            for model_cls in [DLinearModel, NLinearModel]:
-                # Test basic fit and predict
-                model_shared = model_cls(
-                    input_chunk_length=5,
-                    output_chunk_length=1,
-                    n_epochs=2,
-                    const_init=False,
-                    shared_weights=True,
-                    random_state=42,
-                    **tfm_kwargs,
-                )
-                model_not_shared = model_cls(
-                    input_chunk_length=5,
-                    output_chunk_length=1,
-                    n_epochs=2,
-                    const_init=False,
-                    shared_weights=False,
-                    random_state=42,
-                    **tfm_kwargs,
-                )
-                model_shared.fit(ts)
-                model_not_shared.fit(ts)
-                pred_shared = model_shared.predict(n=2)
-                pred_not_shared = model_not_shared.predict(n=2)
-                assert np.any(
-                    np.not_equal(pred_shared.values(), pred_not_shared.values())
-                )
+            return rmse(val1, pred1), rmse(val2, pred2)
 
-        def test_multivariate_and_covariates(self):
-            np.random.seed(42)
-            torch.manual_seed(42)
-            # test on multiple multivariate series with future and static covariates
-
-            def _create_multiv_series(f1, f2, n1, n2, nf1, nf2):
-                bases = [
-                    tg.sine_timeseries(
-                        length=400, value_frequency=f, value_amplitude=1.0
-                    )
-                    for f in (f1, f2)
-                ]
-                noises = [tg.gaussian_timeseries(length=400, std=n) for n in (n1, n2)]
-                noise_modulators = [
-                    tg.sine_timeseries(length=400, value_frequency=nf)
-                    + tg.constant_timeseries(length=400, value=1) / 2
-                    for nf in (nf1, nf2)
-                ]
-                noises = [noises[i] * noise_modulators[i] for i in range(len(noises))]
-
-                target = concatenate(
-                    [bases[i] + noises[i] for i in range(len(bases))], axis="component"
-                )
+        series1, fut_cov1 = _create_multiv_series(0.05, 0.07, 0.2, 0.4, 0.02, 0.03)
+        series2, fut_cov2 = _create_multiv_series(0.04, 0.03, 0.4, 0.1, 0.02, 0.04)
 
-                target = target.with_static_covariates(
-                    pd.DataFrame([[f1, n1, nf1], [f2, n2, nf2]])
-                )
+        train1, val1 = series1.split_after(0.7)
+        train2, val2 = series2.split_after(0.7)
+        past_cov1 = train1.copy()
+        past_cov2 = train2.copy()
+        val_past_cov1 = val1.copy()
+        val_past_cov2 = val2.copy()
+
+        for model, lkl in product(
+            [DLinearModel, NLinearModel], [None, GaussianLikelihood()]
+        ):
 
-                return target, concatenate(noise_modulators, axis="component")
+            e1, e2 = _eval_model(
+                train1, train2, val1, val2, fut_cov1, fut_cov2, cls=model, lkl=lkl
+            )
+            assert e1 <= 0.34
+            assert e2 <= 0.28
 
-            def _eval_model(
-                train1,
-                train2,
+            e1, e2 = _eval_model(
+                train1.with_static_covariates(None),
+                train2.with_static_covariates(None),
                 val1,
                 val2,
                 fut_cov1,
                 fut_cov2,
-                past_cov1=None,
-                past_cov2=None,
-                val_past_cov1=None,
-                val_past_cov2=None,
-                cls=DLinearModel,
-                lkl=None,
-                **kwargs
-            ):
-                model = cls(
-                    input_chunk_length=50,
-                    output_chunk_length=10,
-                    shared_weights=False,
-                    const_init=True,
-                    likelihood=lkl,
-                    random_state=42,
-                    **tfm_kwargs,
-                )
-
-                model.fit(
-                    [train1, train2],
-                    past_covariates=(
-                        [past_cov1, past_cov2] if past_cov1 is not None else None
-                    ),
-                    val_past_covariates=(
-                        [val_past_cov1, val_past_cov2]
-                        if val_past_cov1 is not None
-                        else None
-                    ),
-                    future_covariates=(
-                        [fut_cov1, fut_cov2] if fut_cov1 is not None else None
-                    ),
-                    epochs=10,
-                )
-
-                pred1, pred2 = model.predict(
-                    series=[train1, train2],
-                    future_covariates=(
-                        [fut_cov1, fut_cov2] if fut_cov1 is not None else None
-                    ),
-                    past_covariates=(
-                        [fut_cov1, fut_cov2] if past_cov1 is not None else None
-                    ),
-                    n=len(val1),
-                    num_samples=500 if lkl is not None else 1,
-                )
-
-                return rmse(val1, pred1), rmse(val2, pred2)
-
-            series1, fut_cov1 = _create_multiv_series(0.05, 0.07, 0.2, 0.4, 0.02, 0.03)
-            series2, fut_cov2 = _create_multiv_series(0.04, 0.03, 0.4, 0.1, 0.02, 0.04)
-
-            train1, val1 = series1.split_after(0.7)
-            train2, val2 = series2.split_after(0.7)
-            past_cov1 = train1.copy()
-            past_cov2 = train2.copy()
-            val_past_cov1 = val1.copy()
-            val_past_cov2 = val2.copy()
-
-            for model, lkl in product(
-                [DLinearModel, NLinearModel], [None, GaussianLikelihood()]
-            ):
-
-                e1, e2 = _eval_model(
-                    train1, train2, val1, val2, fut_cov1, fut_cov2, cls=model, lkl=lkl
-                )
-                assert e1 <= 0.34
-                assert e2 <= 0.28
-
-                e1, e2 = _eval_model(
-                    train1.with_static_covariates(None),
-                    train2.with_static_covariates(None),
-                    val1,
-                    val2,
-                    fut_cov1,
-                    fut_cov2,
-                    cls=model,
-                    lkl=lkl,
-                )
-                assert e1 <= 0.32
-                assert e2 <= 0.28
+                cls=model,
+                lkl=lkl,
+            )
+            assert e1 <= 0.32
+            assert e2 <= 0.28
 
-                e1, e2 = _eval_model(
-                    train1, train2, val1, val2, None, None, cls=model, lkl=lkl
-                )
-                assert e1 <= 0.40
-                assert e2 <= 0.34
-
-                e1, e2 = _eval_model(
-                    train1.with_static_covariates(None),
-                    train2.with_static_covariates(None),
-                    val1,
-                    val2,
-                    None,
-                    None,
-                    cls=model,
-                    lkl=lkl,
-                )
-                assert e1 <= 0.40
-                assert e2 <= 0.34
+            e1, e2 = _eval_model(
+                train1, train2, val1, val2, None, None, cls=model, lkl=lkl
+            )
+            assert e1 <= 0.40
+            assert e2 <= 0.34
 
             e1, e2 = _eval_model(
-                train1,
-                train2,
+                train1.with_static_covariates(None),
+                train2.with_static_covariates(None),
                 val1,
                 val2,
-                fut_cov1,
-                fut_cov2,
-                past_cov1=past_cov1,
-                past_cov2=past_cov2,
-                val_past_cov1=val_past_cov1,
-                val_past_cov2=val_past_cov2,
-                cls=NLinearModel,
-                lkl=None,
-                normalize=True,
+                None,
+                None,
+                cls=model,
+                lkl=lkl,
             )
-            # can only fit models with past/future covariates when shared_weights=False
-            for model in [DLinearModel, NLinearModel]:
-                for shared_weights in [True, False]:
-                    model_instance = model(
-                        5, 5, shared_weights=shared_weights, **tfm_kwargs
-                    )
-                    assert model_instance.supports_past_covariates == (
-                        not shared_weights
-                    )
-                    assert model_instance.supports_future_covariates == (
-                        not shared_weights
-                    )
-                    if shared_weights:
-                        with pytest.raises(ValueError):
-                            model_instance.fit(series1, future_covariates=fut_cov1)
-
-        def test_optional_static_covariates(self):
-            series = tg.sine_timeseries(length=20).with_static_covariates(
-                pd.DataFrame({"a": [1]})
-            )
-            for model_cls in [NLinearModel, DLinearModel]:
-                # training model with static covs and predicting without will raise an error
-                model = model_cls(
-                    input_chunk_length=12,
-                    output_chunk_length=6,
-                    use_static_covariates=True,
-                    n_epochs=1,
-                    **tfm_kwargs,
-                )
-                model.fit(series)
-                with pytest.raises(ValueError):
-                    model.predict(n=2, series=series.with_static_covariates(None))
-
-                # with `use_static_covariates=False`, static covariates are ignored and prediction works
-                model = model_cls(
-                    input_chunk_length=12,
-                    output_chunk_length=6,
-                    use_static_covariates=False,
-                    n_epochs=1,
-                    **tfm_kwargs,
+            assert e1 <= 0.40
+            assert e2 <= 0.34
+
+        e1, e2 = _eval_model(
+            train1,
+            train2,
+            val1,
+            val2,
+            fut_cov1,
+            fut_cov2,
+            past_cov1=past_cov1,
+            past_cov2=past_cov2,
+            val_past_cov1=val_past_cov1,
+            val_past_cov2=val_past_cov2,
+            cls=NLinearModel,
+            lkl=None,
+            normalize=True,
+        )
+        # can only fit models with past/future covariates when shared_weights=False
+        for model in [DLinearModel, NLinearModel]:
+            for shared_weights in [True, False]:
+                model_instance = model(
+                    5, 5, shared_weights=shared_weights, **tfm_kwargs
                 )
-                model.fit(series)
-                preds = model.predict(n=2, series=series.with_static_covariates(None))
-                assert preds.static_covariates is None
-
-                # with `use_static_covariates=False`, static covariates are ignored and prediction works
-                model = model_cls(
-                    input_chunk_length=12,
-                    output_chunk_length=6,
-                    use_static_covariates=False,
-                    n_epochs=1,
-                    **tfm_kwargs,
-                )
-                model.fit(series.with_static_covariates(None))
-                preds = model.predict(n=2, series=series)
-                assert preds.static_covariates.equals(series.static_covariates)
+                assert model_instance.supports_past_covariates == (not shared_weights)
+                assert model_instance.supports_future_covariates == (not shared_weights)
+                if shared_weights:
+                    with pytest.raises(ValueError):
+                        model_instance.fit(series1, future_covariates=fut_cov1)
+
+    def test_optional_static_covariates(self):
+        series = tg.sine_timeseries(length=20).with_static_covariates(
+            pd.DataFrame({"a": [1]})
+        )
+        for model_cls in [NLinearModel, DLinearModel]:
+            # training model with static covs and predicting without will raise an error
+            model = model_cls(
+                input_chunk_length=12,
+                output_chunk_length=6,
+                use_static_covariates=True,
+                n_epochs=1,
+                **tfm_kwargs,
+            )
+            model.fit(series)
+            with pytest.raises(ValueError):
+                model.predict(n=2, series=series.with_static_covariates(None))
+
+            # with `use_static_covariates=False`, static covariates are ignored and prediction works
+            model = model_cls(
+                input_chunk_length=12,
+                output_chunk_length=6,
+                use_static_covariates=False,
+                n_epochs=1,
+                **tfm_kwargs,
+            )
+            model.fit(series)
+            preds = model.predict(n=2, series=series.with_static_covariates(None))
+            assert preds.static_covariates is None
+
+            # with `use_static_covariates=False`, static covariates are ignored and prediction works
+            model = model_cls(
+                input_chunk_length=12,
+                output_chunk_length=6,
+                use_static_covariates=False,
+                n_epochs=1,
+                **tfm_kwargs,
+            )
+            model.fit(series.with_static_covariates(None))
+            preds = model.predict(n=2, series=series)
+            assert preds.static_covariates.equals(series.static_covariates)
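
The `test_optional_static_covariates` case above pins down the `use_static_covariates`
contract shared by both linear models. A condensed sketch of the permissive half of that
contract, assuming darts with torch installed (the series is illustrative):

    import pandas as pd
    from darts.models import NLinearModel
    from darts.utils import timeseries_generation as tg

    series = tg.sine_timeseries(length=20).with_static_covariates(
        pd.DataFrame({"a": [1]})
    )

    # with use_static_covariates=False the model never consumes static covariates,
    # so fit and predict may mix series with and without them
    model = NLinearModel(
        input_chunk_length=12,
        output_chunk_length=6,
        use_static_covariates=False,
        n_epochs=1,
    )
    model.fit(series)
    preds = model.predict(n=2, series=series.with_static_covariates(None))
    assert preds.static_covariates is None

With the default `use_static_covariates=True`, the same `predict` call raises a
`ValueError`: a model trained with static covariates expects them at inference.
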
diff --git a/darts/tests/models/forecasting/test_global_forecasting_models.py b/darts/tests/models/forecasting/test_global_forecasting_models.py
index dd3e6faf8d..59d278f756 100644
--- a/darts/tests/models/forecasting/test_global_forecasting_models.py
+++ b/darts/tests/models/forecasting/test_global_forecasting_models.py
@@ -41,787 +41,782 @@
         PastCovariatesTorchModel,
     )
     from darts.utils.likelihood_models import GaussianLikelihood
-
-    TORCH_AVAILABLE = True
 except ImportError:
-    logger.warning("Torch not installed - will be skipping Torch models tests")
-    TORCH_AVAILABLE = False
-
-
-if TORCH_AVAILABLE:
-    IN_LEN = 24
-    OUT_LEN = 12
-    models_cls_kwargs_errs = [
-        (
-            BlockRNNModel,
-            {
-                "model": "RNN",
-                "hidden_dim": 10,
-                "n_rnn_layers": 1,
-                "batch_size": 32,
-                "n_epochs": 10,
-                "pl_trainer_kwargs": tfm_kwargs["pl_trainer_kwargs"],
-            },
-            110.0,
-        ),
-        (
-            RNNModel,
-            {
-                "model": "RNN",
-                "hidden_dim": 10,
-                "batch_size": 32,
-                "n_epochs": 10,
-                "pl_trainer_kwargs": tfm_kwargs["pl_trainer_kwargs"],
-            },
-            150.0,
-        ),
-        (
-            RNNModel,
-            {
-                "training_length": 12,
-                "n_epochs": 10,
-                "likelihood": GaussianLikelihood(),
-                "pl_trainer_kwargs": tfm_kwargs["pl_trainer_kwargs"],
-            },
-            80.0,
-        ),
-        (
-            TCNModel,
-            {
-                "n_epochs": 10,
-                "batch_size": 32,
-                "pl_trainer_kwargs": tfm_kwargs["pl_trainer_kwargs"],
-            },
-            60.0,
-        ),
-        (
-            TransformerModel,
-            {
-                "d_model": 16,
-                "nhead": 2,
-                "num_encoder_layers": 2,
-                "num_decoder_layers": 2,
-                "dim_feedforward": 16,
-                "batch_size": 32,
-                "n_epochs": 10,
-                "pl_trainer_kwargs": tfm_kwargs["pl_trainer_kwargs"],
-            },
-            60.0,
-        ),
-        (
-            NBEATSModel,
-            {
-                "num_stacks": 4,
-                "num_blocks": 1,
-                "num_layers": 2,
-                "layer_widths": 12,
-                "n_epochs": 10,
-                "pl_trainer_kwargs": tfm_kwargs["pl_trainer_kwargs"],
-            },
-            140.0,
-        ),
-        (
-            TFTModel,
-            {
-                "hidden_size": 16,
-                "lstm_layers": 1,
-                "num_attention_heads": 4,
-                "add_relative_index": True,
-                "n_epochs": 10,
-                "pl_trainer_kwargs": tfm_kwargs["pl_trainer_kwargs"],
-            },
-            70.0,
-        ),
-        (
-            NLinearModel,
-            {
-                "n_epochs": 10,
-                "pl_trainer_kwargs": tfm_kwargs["pl_trainer_kwargs"],
-            },
-            50.0,
-        ),
-        (
-            DLinearModel,
-            {
-                "n_epochs": 10,
-                "pl_trainer_kwargs": tfm_kwargs["pl_trainer_kwargs"],
-            },
-            55.0,
-        ),
-        (
-            TiDEModel,
-            {
-                "n_epochs": 10,
-                "pl_trainer_kwargs": tfm_kwargs["pl_trainer_kwargs"],
-            },
-            40.0,
-        ),
-        (
-            TSMixerModel,
-            {
-                "n_epochs": 10,
-                "pl_trainer_kwargs": tfm_kwargs["pl_trainer_kwargs"],
-            },
-            60.0,
-        ),
-        (
-            GlobalNaiveAggregate,
-            {
-                "pl_trainer_kwargs": tfm_kwargs["pl_trainer_kwargs"],
-            },
-            22,
-        ),
-        (
-            GlobalNaiveDrift,
-            {
-                "pl_trainer_kwargs": tfm_kwargs["pl_trainer_kwargs"],
-            },
-            17,
-        ),
-        (
-            GlobalNaiveSeasonal,
-            {
-                "pl_trainer_kwargs": tfm_kwargs["pl_trainer_kwargs"],
-            },
-            39,
-        ),
-    ]
+    pytest.skip(
+        f"Torch not available. {__name__} tests will be skipped.",
+        allow_module_level=True,
+    )
+
 
-    class TestGlobalForecastingModels:
-        # forecasting horizon used in runnability tests
-        forecasting_horizon = 12
+IN_LEN = 24
+OUT_LEN = 12
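+# each entry: (model class, constructor kwargs, maximum acceptable MAPE in the tests below)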
+models_cls_kwargs_errs = [
+    (
+        BlockRNNModel,
+        {
+            "model": "RNN",
+            "hidden_dim": 10,
+            "n_rnn_layers": 1,
+            "batch_size": 32,
+            "n_epochs": 10,
+            "pl_trainer_kwargs": tfm_kwargs["pl_trainer_kwargs"],
+        },
+        110.0,
+    ),
+    (
+        RNNModel,
+        {
+            "model": "RNN",
+            "hidden_dim": 10,
+            "batch_size": 32,
+            "n_epochs": 10,
+            "pl_trainer_kwargs": tfm_kwargs["pl_trainer_kwargs"],
+        },
+        150.0,
+    ),
+    (
+        RNNModel,
+        {
+            "training_length": 12,
+            "n_epochs": 10,
+            "likelihood": GaussianLikelihood(),
+            "pl_trainer_kwargs": tfm_kwargs["pl_trainer_kwargs"],
+        },
+        80.0,
+    ),
+    (
+        TCNModel,
+        {
+            "n_epochs": 10,
+            "batch_size": 32,
+            "pl_trainer_kwargs": tfm_kwargs["pl_trainer_kwargs"],
+        },
+        60.0,
+    ),
+    (
+        TransformerModel,
+        {
+            "d_model": 16,
+            "nhead": 2,
+            "num_encoder_layers": 2,
+            "num_decoder_layers": 2,
+            "dim_feedforward": 16,
+            "batch_size": 32,
+            "n_epochs": 10,
+            "pl_trainer_kwargs": tfm_kwargs["pl_trainer_kwargs"],
+        },
+        60.0,
+    ),
+    (
+        NBEATSModel,
+        {
+            "num_stacks": 4,
+            "num_blocks": 1,
+            "num_layers": 2,
+            "layer_widths": 12,
+            "n_epochs": 10,
+            "pl_trainer_kwargs": tfm_kwargs["pl_trainer_kwargs"],
+        },
+        140.0,
+    ),
+    (
+        TFTModel,
+        {
+            "hidden_size": 16,
+            "lstm_layers": 1,
+            "num_attention_heads": 4,
+            "add_relative_index": True,
+            "n_epochs": 10,
+            "pl_trainer_kwargs": tfm_kwargs["pl_trainer_kwargs"],
+        },
+        70.0,
+    ),
+    (
+        NLinearModel,
+        {
+            "n_epochs": 10,
+            "pl_trainer_kwargs": tfm_kwargs["pl_trainer_kwargs"],
+        },
+        50.0,
+    ),
+    (
+        DLinearModel,
+        {
+            "n_epochs": 10,
+            "pl_trainer_kwargs": tfm_kwargs["pl_trainer_kwargs"],
+        },
+        55.0,
+    ),
+    (
+        TiDEModel,
+        {
+            "n_epochs": 10,
+            "pl_trainer_kwargs": tfm_kwargs["pl_trainer_kwargs"],
+        },
+        40.0,
+    ),
+    (
+        TSMixerModel,
+        {
+            "n_epochs": 10,
+            "pl_trainer_kwargs": tfm_kwargs["pl_trainer_kwargs"],
+        },
+        60.0,
+    ),
+    (
+        GlobalNaiveAggregate,
+        {
+            "pl_trainer_kwargs": tfm_kwargs["pl_trainer_kwargs"],
+        },
+        22,
+    ),
+    (
+        GlobalNaiveDrift,
+        {
+            "pl_trainer_kwargs": tfm_kwargs["pl_trainer_kwargs"],
+        },
+        17,
+    ),
+    (
+        GlobalNaiveSeasonal,
+        {
+            "pl_trainer_kwargs": tfm_kwargs["pl_trainer_kwargs"],
+        },
+        39,
+    ),
+]
 
-        np.random.seed(42)
-        torch.manual_seed(42)
 
-        # some arbitrary static covariates
-        static_covariates = pd.DataFrame([[0.0, 1.0]], columns=["st1", "st2"])
+class TestGlobalForecastingModels:
+    # forecasting horizon used in runnability tests
+    forecasting_horizon = 12
 
-        # real timeseries for functionality tests
-        ts_passengers = (
-            AirPassengersDataset().load().with_static_covariates(static_covariates)
-        )
-        scaler = Scaler()
-        ts_passengers = scaler.fit_transform(ts_passengers)
-        ts_pass_train, ts_pass_val = ts_passengers[:-36], ts_passengers[-36:]
-
-        # an additional noisy series
-        ts_pass_train_1 = ts_pass_train + 0.01 * tg.gaussian_timeseries(
-            length=len(ts_pass_train),
-            freq=ts_pass_train.freq_str,
-            start=ts_pass_train.start_time(),
-        )
+    np.random.seed(42)
+    torch.manual_seed(42)
 
-        # an additional time series serving as covariates
-        year_series = tg.datetime_attribute_timeseries(ts_passengers, attribute="year")
-        month_series = tg.datetime_attribute_timeseries(
-            ts_passengers, attribute="month"
-        )
-        scaler_dt = Scaler()
-        time_covariates = scaler_dt.fit_transform(year_series.stack(month_series))
-        time_covariates_train, time_covariates_val = (
-            time_covariates[:-36],
-            time_covariates[-36:],
-        )
+    # some arbitrary static covariates
+    static_covariates = pd.DataFrame([[0.0, 1.0]], columns=["st1", "st2"])
 
-        # an artificial time series that is highly dependent on covariates
-        ts_length = 400
-        split_ratio = 0.6
-        sine_1_ts = tg.sine_timeseries(length=ts_length)
-        sine_2_ts = tg.sine_timeseries(length=ts_length, value_frequency=0.05)
-        sine_3_ts = tg.sine_timeseries(
-            length=ts_length, value_frequency=0.003, value_amplitude=5
-        )
-        linear_ts = tg.linear_timeseries(length=ts_length, start_value=3, end_value=8)
+    # real timeseries for functionality tests
+    ts_passengers = (
+        AirPassengersDataset().load().with_static_covariates(static_covariates)
+    )
+    scaler = Scaler()
+    ts_passengers = scaler.fit_transform(ts_passengers)
+    ts_pass_train, ts_pass_val = ts_passengers[:-36], ts_passengers[-36:]
+
+    # an additional noisy series
+    ts_pass_train_1 = ts_pass_train + 0.01 * tg.gaussian_timeseries(
+        length=len(ts_pass_train),
+        freq=ts_pass_train.freq_str,
+        start=ts_pass_train.start_time(),
+    )
 
-        covariates = sine_3_ts.stack(sine_2_ts).stack(linear_ts)
-        covariates_past, _ = covariates.split_after(split_ratio)
+    # an additional time series serving as covariates
+    year_series = tg.datetime_attribute_timeseries(ts_passengers, attribute="year")
+    month_series = tg.datetime_attribute_timeseries(ts_passengers, attribute="month")
+    scaler_dt = Scaler()
+    time_covariates = scaler_dt.fit_transform(year_series.stack(month_series))
+    time_covariates_train, time_covariates_val = (
+        time_covariates[:-36],
+        time_covariates[-36:],
+    )
 
-        target = sine_1_ts + sine_2_ts + linear_ts + sine_3_ts
-        target_past, target_future = target.split_after(split_ratio)
+    # an artificial time series that is highly dependent on covariates
+    ts_length = 400
+    split_ratio = 0.6
+    sine_1_ts = tg.sine_timeseries(length=ts_length)
+    sine_2_ts = tg.sine_timeseries(length=ts_length, value_frequency=0.05)
+    sine_3_ts = tg.sine_timeseries(
+        length=ts_length, value_frequency=0.003, value_amplitude=5
+    )
+    linear_ts = tg.linear_timeseries(length=ts_length, start_value=3, end_value=8)
 
-        # various ts with different static covariates representations
-        ts_w_static_cov = tg.linear_timeseries(length=80).with_static_covariates(
-            pd.Series([1, 2])
-        )
-        ts_shared_static_cov = ts_w_static_cov.stack(tg.sine_timeseries(length=80))
-        ts_comps_static_cov = ts_shared_static_cov.with_static_covariates(
-            pd.DataFrame([[0, 1], [2, 3]], columns=["st1", "st2"])
-        )
+    covariates = sine_3_ts.stack(sine_2_ts).stack(linear_ts)
+    covariates_past, _ = covariates.split_after(split_ratio)
 
-        @pytest.mark.parametrize("config", models_cls_kwargs_errs)
-        def test_save_model_parameters(self, config):
-            # model creation parameters were saved before. check if re-created model has same params as original
-            model_cls, kwargs, err = config
-            model = model_cls(
-                input_chunk_length=IN_LEN, output_chunk_length=OUT_LEN, **kwargs
-            )
-            assert model._model_params, model.untrained_model()._model_params
-
-        @pytest.mark.parametrize(
-            "model",
-            [
-                RNNModel(
-                    input_chunk_length=4,
-                    hidden_dim=10,
-                    batch_size=32,
-                    n_epochs=10,
-                    **tfm_kwargs,
-                ),
-                TCNModel(
-                    input_chunk_length=4,
-                    output_chunk_length=3,
-                    n_epochs=10,
-                    batch_size=32,
-                    **tfm_kwargs,
-                ),
-                GlobalNaiveSeasonal(
-                    input_chunk_length=4,
-                    output_chunk_length=3,
-                    **tfm_kwargs,
-                ),
-            ],
+    target = sine_1_ts + sine_2_ts + linear_ts + sine_3_ts
+    target_past, target_future = target.split_after(split_ratio)
+
+    # various ts with different static covariates representations
+    ts_w_static_cov = tg.linear_timeseries(length=80).with_static_covariates(
+        pd.Series([1, 2])
+    )
+    ts_shared_static_cov = ts_w_static_cov.stack(tg.sine_timeseries(length=80))
+    ts_comps_static_cov = ts_shared_static_cov.with_static_covariates(
+        pd.DataFrame([[0, 1], [2, 3]], columns=["st1", "st2"])
+    )
+
+    @pytest.mark.parametrize("config", models_cls_kwargs_errs)
+    def test_save_model_parameters(self, config):
+        # model creation parameters were saved before. check if re-created model has same params as original
+        model_cls, kwargs, err = config
+        model = model_cls(
+            input_chunk_length=IN_LEN, output_chunk_length=OUT_LEN, **kwargs
         )
-        def test_save_load_model(self, tmpdir_module, model):
-            # check if save and load methods work and if loaded model creates same forecasts as original model
-            cwd = os.getcwd()
-            os.chdir(tmpdir_module)
-            model_path_str = type(model).__name__
-            full_model_path_str = os.path.join(tmpdir_module, model_path_str)
-
-            model.fit(self.ts_pass_train)
-            model_prediction = model.predict(self.forecasting_horizon)
-
-            # test save
-            model.save()
-            model.save(model_path_str)
-
-            assert os.path.exists(full_model_path_str)
-            assert (
-                len(
-                    [
-                        p
-                        for p in os.listdir(tmpdir_module)
-                        if p.startswith(type(model).__name__)
-                    ]
-                )
-                == 4
+        assert model._model_params, model.untrained_model()._model_params
+
+    @pytest.mark.parametrize(
+        "model",
+        [
+            RNNModel(
+                input_chunk_length=4,
+                hidden_dim=10,
+                batch_size=32,
+                n_epochs=10,
+                **tfm_kwargs,
+            ),
+            TCNModel(
+                input_chunk_length=4,
+                output_chunk_length=3,
+                n_epochs=10,
+                batch_size=32,
+                **tfm_kwargs,
+            ),
+            GlobalNaiveSeasonal(
+                input_chunk_length=4,
+                output_chunk_length=3,
+                **tfm_kwargs,
+            ),
+        ],
+    )
+    def test_save_load_model(self, tmpdir_module, model):
+        # check if save and load methods work and if loaded model creates same forecasts as original model
+        cwd = os.getcwd()
+        os.chdir(tmpdir_module)
+        model_path_str = type(model).__name__
+        full_model_path_str = os.path.join(tmpdir_module, model_path_str)
+
+        model.fit(self.ts_pass_train)
+        model_prediction = model.predict(self.forecasting_horizon)
+
+        # test save
+        model.save()
+        model.save(model_path_str)
+
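+        # each save() writes the model object plus a .ckpt checkpoint, so two saves yield 4 files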
+        assert os.path.exists(full_model_path_str)
+        assert (
+            len(
+                [
+                    p
+                    for p in os.listdir(tmpdir_module)
+                    if p.startswith(type(model).__name__)
+                ]
             )
+            == 4
+        )
 
-            # test load
-            loaded_model = type(model).load(model_path_str)
+        # test load
+        loaded_model = type(model).load(model_path_str)
 
-            assert model_prediction == loaded_model.predict(self.forecasting_horizon)
+        assert model_prediction == loaded_model.predict(self.forecasting_horizon)
 
-            os.chdir(cwd)
+        os.chdir(cwd)
 
-        @pytest.mark.parametrize("config", models_cls_kwargs_errs)
-        def test_single_ts(self, config):
-            model_cls, kwargs, err = config
-            model = model_cls(
-                input_chunk_length=IN_LEN,
-                output_chunk_length=OUT_LEN,
-                random_state=0,
-                **kwargs,
-            )
-            model.fit(self.ts_pass_train)
-            pred = model.predict(n=36)
-            mape_err = mape(self.ts_pass_val, pred)
-            assert mape_err < err, (
-                "Model {} produces errors too high (one time "
-                "series). Error = {}".format(model_cls, mape_err)
-            )
-            assert pred.static_covariates.equals(self.ts_passengers.static_covariates)
-
-        @pytest.mark.parametrize("config", models_cls_kwargs_errs)
-        def test_multi_ts(self, config):
-            model_cls, kwargs, err = config
-            model = model_cls(
-                input_chunk_length=IN_LEN,
-                output_chunk_length=OUT_LEN,
-                random_state=0,
-                **kwargs,
-            )
-            model.fit([self.ts_pass_train, self.ts_pass_train_1])
-            with pytest.raises(ValueError):
-                # when model is fit from >1 series, one must provide a series in argument
-                model.predict(n=1)
-            pred = model.predict(n=36, series=self.ts_pass_train)
+    @pytest.mark.parametrize("config", models_cls_kwargs_errs)
+    def test_single_ts(self, config):
+        model_cls, kwargs, err = config
+        model = model_cls(
+            input_chunk_length=IN_LEN,
+            output_chunk_length=OUT_LEN,
+            random_state=0,
+            **kwargs,
+        )
+        model.fit(self.ts_pass_train)
+        pred = model.predict(n=36)
+        mape_err = mape(self.ts_pass_val, pred)
+        assert mape_err < err, (
+            "Model {} produces errors too high (one time "
+            "series). Error = {}".format(model_cls, mape_err)
+        )
+        assert pred.static_covariates.equals(self.ts_passengers.static_covariates)
+
+    @pytest.mark.parametrize("config", models_cls_kwargs_errs)
+    def test_multi_ts(self, config):
+        model_cls, kwargs, err = config
+        model = model_cls(
+            input_chunk_length=IN_LEN,
+            output_chunk_length=OUT_LEN,
+            random_state=0,
+            **kwargs,
+        )
+        model.fit([self.ts_pass_train, self.ts_pass_train_1])
+        with pytest.raises(ValueError):
+            # when model is fit from >1 series, one must provide a series in argument
+            model.predict(n=1)
+        pred = model.predict(n=36, series=self.ts_pass_train)
+        mape_err = mape(self.ts_pass_val, pred)
+        assert mape_err < err, (
+            "Model {} produces errors too high (several time "
+            "series). Error = {}".format(model_cls, mape_err)
+        )
+
+        # check prediction for several time series
+        pred_list = model.predict(
+            n=36, series=[self.ts_pass_train, self.ts_pass_train_1]
+        )
+        assert (
+            len(pred_list) == 2
+        ), f"Model {model_cls} did not return a list of predictions"
+        for pred in pred_list:
             mape_err = mape(self.ts_pass_val, pred)
             assert mape_err < err, (
-                "Model {} produces errors too high (several time "
-                "series). Error = {}".format(model_cls, mape_err)
+                "Model {} produces errors too high (several time series 2). "
+                "Error = {}".format(model_cls, mape_err)
             )
 
-            # check prediction for several time series
-            pred_list = model.predict(
-                n=36, series=[self.ts_pass_train, self.ts_pass_train_1]
-            )
-            assert (
-                len(pred_list) == 2
-            ), f"Model {model_cls} did not return a list of prediction"
-            for pred in pred_list:
-                mape_err = mape(self.ts_pass_val, pred)
-                assert mape_err < err, (
-                    "Model {} produces errors too high (several time series 2). "
-                    "Error = {}".format(model_cls, mape_err)
-                )
-
-        @pytest.mark.parametrize("config", models_cls_kwargs_errs)
-        def test_covariates(self, config):
-            model_cls, kwargs, err = config
-            model = model_cls(
-                input_chunk_length=IN_LEN,
-                output_chunk_length=OUT_LEN,
-                random_state=0,
-                **kwargs,
-            )
+    @pytest.mark.parametrize("config", models_cls_kwargs_errs)
+    def test_covariates(self, config):
+        model_cls, kwargs, err = config
+        model = model_cls(
+            input_chunk_length=IN_LEN,
+            output_chunk_length=OUT_LEN,
+            random_state=0,
+            **kwargs,
+        )
 
-            # Here we rely on the fact that all non-Dual models currently are Past models
-            if model.supports_future_covariates:
-                cov_name = "future_covariates"
-                is_past = False
-            elif model.supports_past_covariates:
-                cov_name = "past_covariates"
-                is_past = True
-            else:
-                cov_name = None
-                is_past = None
-
-            covariates = [self.time_covariates_train, self.time_covariates_train]
-            if cov_name is not None:
-                cov_kwargs = {cov_name: covariates}
-                cov_kwargs_train = {cov_name: self.time_covariates_train}
-                cov_kwargs_notrain = {cov_name: self.time_covariates}
-            else:
-                cov_kwargs = {}
-                cov_kwargs_train = {}
-                cov_kwargs_notrain = {}
-
-            model.fit(series=[self.ts_pass_train, self.ts_pass_train_1], **cov_kwargs)
-
-            if cov_name is None:
-                with pytest.raises(ValueError):
-                    model.untrained_model().fit(
-                        series=[self.ts_pass_train, self.ts_pass_train_1],
-                        past_covariates=covariates,
-                    )
-                with pytest.raises(ValueError):
-                    model.untrained_model().fit(
-                        series=[self.ts_pass_train, self.ts_pass_train_1],
-                        future_covariates=covariates,
-                    )
+        # Here we rely on the fact that all non-Dual models currently are Past models
+        if model.supports_future_covariates:
+            cov_name = "future_covariates"
+            is_past = False
+        elif model.supports_past_covariates:
+            cov_name = "past_covariates"
+            is_past = True
+        else:
+            cov_name = None
+            is_past = None
+
+        covariates = [self.time_covariates_train, self.time_covariates_train]
+        if cov_name is not None:
+            cov_kwargs = {cov_name: covariates}
+            cov_kwargs_train = {cov_name: self.time_covariates_train}
+            cov_kwargs_notrain = {cov_name: self.time_covariates}
+        else:
+            cov_kwargs = {}
+            cov_kwargs_train = {}
+            cov_kwargs_notrain = {}
+
+        model.fit(series=[self.ts_pass_train, self.ts_pass_train_1], **cov_kwargs)
+
+        if cov_name is None:
             with pytest.raises(ValueError):
-                # when model is fit from >1 series, one must provide a series in argument
-                model.predict(n=1)
-
-            if cov_name is not None:
-                with pytest.raises(ValueError):
-                    # when model is fit using multiple covariates, covariates are required at prediction time
-                    model.predict(n=1, series=self.ts_pass_train)
-
-                with pytest.raises(ValueError):
-                    # when model is fit using covariates, n cannot be greater than output_chunk_length...
-                    # (for short covariates)
-                    # past covariates model can predict up until output_chunk_length
-                    # with train future covariates we cannot predict at all after end of series
-                    model.predict(
-                        n=13 if is_past else 1,
-                        series=self.ts_pass_train,
-                        **cov_kwargs_train,
-                    )
-            else:
-                # model does not support covariates
-                with pytest.raises(ValueError):
-                    model.predict(
-                        n=1,
-                        series=self.ts_pass_train,
-                        past_covariates=self.time_covariates,
-                    )
-                with pytest.raises(ValueError):
-                    model.predict(
-                        n=1,
-                        series=self.ts_pass_train,
-                        future_covariates=self.time_covariates,
-                    )
-
-            # ... unless future covariates are provided
-            _ = model.predict(n=13, series=self.ts_pass_train, **cov_kwargs_notrain)
-
-            pred = model.predict(n=12, series=self.ts_pass_train, **cov_kwargs_notrain)
-            mape_err = mape(self.ts_pass_val, pred)
-            assert mape_err < err, (
-                "Model {} produces errors too high (several time "
-                "series with covariates). Error = {}".format(model_cls, mape_err)
-            )
-
-            # when model is fit using 1 training and 1 covariate series, time series args are optional
-            if model.supports_probabilistic_prediction:
-                return
-            model = model_cls(
-                input_chunk_length=IN_LEN, output_chunk_length=OUT_LEN, **kwargs
-            )
-            model.fit(series=self.ts_pass_train, **cov_kwargs_train)
-            if is_past:
-                # with past covariates from train we can predict up until output_chunk_length
-                pred1 = model.predict(1)
-                pred2 = model.predict(1, series=self.ts_pass_train)
-                pred3 = model.predict(1, **cov_kwargs_train)
-                pred4 = model.predict(1, **cov_kwargs_train, series=self.ts_pass_train)
-            else:
-                # with future covariates we need additional time steps to predict
-                with pytest.raises(ValueError):
-                    _ = model.predict(1)
-                with pytest.raises(ValueError):
-                    _ = model.predict(1, series=self.ts_pass_train)
-                with pytest.raises(ValueError):
-                    _ = model.predict(1, **cov_kwargs_train)
-                with pytest.raises(ValueError):
-                    _ = model.predict(1, **cov_kwargs_train, series=self.ts_pass_train)
-
-                pred1 = model.predict(1, **cov_kwargs_notrain)
-                pred2 = model.predict(
-                    1, series=self.ts_pass_train, **cov_kwargs_notrain
+                model.untrained_model().fit(
+                    series=[self.ts_pass_train, self.ts_pass_train_1],
+                    past_covariates=covariates,
                 )
-                pred3 = model.predict(1, **cov_kwargs_notrain)
-                pred4 = model.predict(
-                    1, **cov_kwargs_notrain, series=self.ts_pass_train
+            with pytest.raises(ValueError):
+                model.untrained_model().fit(
+                    series=[self.ts_pass_train, self.ts_pass_train_1],
+                    future_covariates=covariates,
                 )
+        with pytest.raises(ValueError):
+            # when model is fit from >1 series, one must provide a series in argument
+            model.predict(n=1)
 
-            assert pred1 == pred2
-            assert pred1 == pred3
-            assert pred1 == pred4
+        if cov_name is not None:
+            with pytest.raises(ValueError):
+                # when model is fit using multiple covariates, covariates are required at prediction time
+                model.predict(n=1, series=self.ts_pass_train)
 
-        def test_future_covariates(self):
-            # models with future covariates should produce better predictions over a long forecasting horizon
-            # than a model trained with no covariates
+            with pytest.raises(ValueError):
+                # when model is fit using covariates, n cannot be greater than output_chunk_length...
+                # (for short covariates)
+                # past covariates model can predict up until output_chunk_length
+                # with train future covariates we cannot predict at all after end of series
+                model.predict(
+                    n=13 if is_past else 1,
+                    series=self.ts_pass_train,
+                    **cov_kwargs_train,
+                )
+        else:
+            # model does not support covariates
+            with pytest.raises(ValueError):
+                model.predict(
+                    n=1,
+                    series=self.ts_pass_train,
+                    past_covariates=self.time_covariates,
+                )
+            with pytest.raises(ValueError):
+                model.predict(
+                    n=1,
+                    series=self.ts_pass_train,
+                    future_covariates=self.time_covariates,
+                )
 
-            model = TCNModel(
-                input_chunk_length=50,
-                output_chunk_length=5,
-                n_epochs=20,
-                random_state=0,
-                **tfm_kwargs,
-            )
-            model.fit(series=self.target_past)
-            long_pred_no_cov = model.predict(n=160)
-
-            model = TCNModel(
-                input_chunk_length=50,
-                output_chunk_length=5,
-                n_epochs=20,
-                random_state=0,
-                **tfm_kwargs,
-            )
-            model.fit(series=self.target_past, past_covariates=self.covariates_past)
-            long_pred_with_cov = model.predict(n=160, past_covariates=self.covariates)
-            assert mape(self.target_future, long_pred_no_cov) > mape(
-                self.target_future, long_pred_with_cov
-            ), "Models with future covariates should produce better predictions."
+        # ... unless covariates reaching far enough into the future are provided
+        _ = model.predict(n=13, series=self.ts_pass_train, **cov_kwargs_notrain)
 
-            # block models can predict up to self.output_chunk_length points beyond the last future covariate...
-            model.predict(n=165, past_covariates=self.covariates)
+        pred = model.predict(n=12, series=self.ts_pass_train, **cov_kwargs_notrain)
+        mape_err = mape(self.ts_pass_val, pred)
+        assert mape_err < err, (
+            "Model {} produces errors too high (several time "
+            "series with covariates). Error = {}".format(model_cls, mape_err)
+        )
 
-            # ... not more
+        # when model is fit using 1 training and 1 covariate series, time series args are optional
+        if model.supports_probabilistic_prediction:
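+            # probabilistic models sample at prediction time, so the strict
+            # equality checks below would not hold; skip them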
+            return
+        model = model_cls(
+            input_chunk_length=IN_LEN, output_chunk_length=OUT_LEN, **kwargs
+        )
+        model.fit(series=self.ts_pass_train, **cov_kwargs_train)
+        if is_past:
+            # with past covariates from train we can predict up until output_chunk_length
+            pred1 = model.predict(1)
+            pred2 = model.predict(1, series=self.ts_pass_train)
+            pred3 = model.predict(1, **cov_kwargs_train)
+            pred4 = model.predict(1, **cov_kwargs_train, series=self.ts_pass_train)
+        else:
+            # with future covariates we need additional time steps to predict
             with pytest.raises(ValueError):
-                model.predict(n=166, series=self.ts_pass_train)
-
-            # recurrent models can only predict data points for time steps where future covariates are available
-            model = RNNModel(12, n_epochs=1, **tfm_kwargs)
-            model.fit(series=self.target_past, future_covariates=self.covariates_past)
-            model.predict(n=160, future_covariates=self.covariates)
+                _ = model.predict(1)
             with pytest.raises(ValueError):
-                model.predict(n=161, future_covariates=self.covariates)
-
-        @pytest.mark.parametrize(
-            "model_cls,ts",
-            product(
-                [TFTModel, DLinearModel, NLinearModel, TiDEModel, TSMixerModel],
-                [ts_w_static_cov, ts_shared_static_cov, ts_comps_static_cov],
+                _ = model.predict(1, series=self.ts_pass_train)
+            with pytest.raises(ValueError):
+                _ = model.predict(1, **cov_kwargs_train)
+            with pytest.raises(ValueError):
+                _ = model.predict(1, **cov_kwargs_train, series=self.ts_pass_train)
+
+            pred1 = model.predict(1, **cov_kwargs_notrain)
+            pred2 = model.predict(1, series=self.ts_pass_train, **cov_kwargs_notrain)
+            pred3 = model.predict(1, **cov_kwargs_notrain)
+            pred4 = model.predict(1, **cov_kwargs_notrain, series=self.ts_pass_train)
+
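+        # whichever way the (single) training series and covariates are passed,
+        # the predictions should be identical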
+        assert pred1 == pred2
+        assert pred1 == pred3
+        assert pred1 == pred4
+
+    def test_future_covariates(self):
+        # models trained with covariates should produce better predictions over a long
+        # forecasting horizon than a model trained without covariates
+
+        model = TCNModel(
+            input_chunk_length=50,
+            output_chunk_length=5,
+            n_epochs=20,
+            random_state=0,
+            **tfm_kwargs,
+        )
+        model.fit(series=self.target_past)
+        long_pred_no_cov = model.predict(n=160)
+
+        model = TCNModel(
+            input_chunk_length=50,
+            output_chunk_length=5,
+            n_epochs=20,
+            random_state=0,
+            **tfm_kwargs,
+        )
+        model.fit(series=self.target_past, past_covariates=self.covariates_past)
+        long_pred_with_cov = model.predict(n=160, past_covariates=self.covariates)
+        assert mape(self.target_future, long_pred_no_cov) > mape(
+            self.target_future, long_pred_with_cov
+        ), "Models with future covariates should produce better predictions."
+
+        # block models can predict up to self.output_chunk_length points beyond the last past covariate...
+        model.predict(n=165, past_covariates=self.covariates)
+
+        # ... not more
+        with pytest.raises(ValueError):
+            model.predict(n=166, series=self.ts_pass_train)
+
+        # recurrent models can only predict data points for time steps where future covariates are available
+        model = RNNModel(12, n_epochs=1, **tfm_kwargs)
+        model.fit(series=self.target_past, future_covariates=self.covariates_past)
+        model.predict(n=160, future_covariates=self.covariates)
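+        # (the future covariates extend 160 steps past the end of the training
+        # series, so one additional step is not possible)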
+        with pytest.raises(ValueError):
+            model.predict(n=161, future_covariates=self.covariates)
+
+    @pytest.mark.parametrize(
+        "model_cls,ts",
+        product(
+            [TFTModel, DLinearModel, NLinearModel, TiDEModel, TSMixerModel],
+            [ts_w_static_cov, ts_shared_static_cov, ts_comps_static_cov],
+        ),
+    )
+    def test_use_static_covariates(self, model_cls, ts):
+        """
+        Check that both static covariates representations are supported (component-specific and shared)
+        for both uni- and multivariate series when fitting the model.
+        Also check that the static covariates are present in the forecasted series
+        """
+        model = model_cls(
+            input_chunk_length=IN_LEN,
+            output_chunk_length=OUT_LEN,
+            random_state=0,
+            use_static_covariates=True,
+            n_epochs=1,
+            **tfm_kwargs,
+        )
+        # must provide mandatory future_covariates to TFTModel
+        model.fit(
+            series=ts,
+            future_covariates=(
+                self.sine_1_ts if model.supports_future_covariates else None
             ),
         )
-        def test_use_static_covariates(self, model_cls, ts):
-            """
-            Check that both static covariates representations are supported (component-specific and shared)
-            for both uni- and multivariate series when fitting the model.
-            Also check that the static covariates are present in the forecasted series
-            """
-            model = model_cls(
-                input_chunk_length=IN_LEN,
-                output_chunk_length=OUT_LEN,
-                random_state=0,
-                use_static_covariates=True,
-                n_epochs=1,
-                **tfm_kwargs,
-            )
-            # must provide mandatory future_covariates to TFTModel
-            model.fit(
-                series=ts,
-                future_covariates=(
-                    self.sine_1_ts if model.supports_future_covariates else None
-                ),
-            )
-            pred = model.predict(OUT_LEN)
-            assert pred.static_covariates.equals(ts.static_covariates)
-
-        def test_batch_predictions(self):
-            # predicting multiple time series at once needs to work for arbitrary batch sizes
-            # univariate case
-            targets_univar = [
-                self.target_past,
-                self.target_past[:60],
-                self.target_past[:80],
-            ]
-            self._batch_prediction_test_helper_function(targets_univar)
-
-            # multivariate case
-            targets_multivar = [tgt.stack(tgt) for tgt in targets_univar]
-            self._batch_prediction_test_helper_function(targets_multivar)
-
-        def _batch_prediction_test_helper_function(self, targets):
-            epsilon = 1e-4
-            model = TCNModel(
-                input_chunk_length=50,
-                output_chunk_length=10,
-                n_epochs=10,
-                random_state=0,
-                **tfm_kwargs,
-            )
-            model.fit(series=targets[0], past_covariates=self.covariates_past)
-            preds_default = model.predict(
+        pred = model.predict(OUT_LEN)
+        assert pred.static_covariates.equals(ts.static_covariates)
+
+    def test_batch_predictions(self):
+        # predicting multiple time series at once needs to work for arbitrary batch sizes
+        # univariate case
+        targets_univar = [
+            self.target_past,
+            self.target_past[:60],
+            self.target_past[:80],
+        ]
+        self._batch_prediction_test_helper_function(targets_univar)
+
+        # multivariate case
+        targets_multivar = [tgt.stack(tgt) for tgt in targets_univar]
+        self._batch_prediction_test_helper_function(targets_multivar)
+
+    def _batch_prediction_test_helper_function(self, targets):
+        epsilon = 1e-4
+        model = TCNModel(
+            input_chunk_length=50,
+            output_chunk_length=10,
+            n_epochs=10,
+            random_state=0,
+            **tfm_kwargs,
+        )
+        model.fit(series=targets[0], past_covariates=self.covariates_past)
+        preds_default = model.predict(
+            n=160,
+            series=targets,
+            past_covariates=[self.covariates] * len(targets),
+            batch_size=None,
+        )
+
+        # make batch size large enough to test stacking samples
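+        # (with 3 target series this sweeps batch sizes 1 through 11, covering
+        # batches both smaller and larger than the number of series)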
+        for batch_size in range(1, 4 * len(targets)):
+            preds = model.predict(
                 n=160,
                 series=targets,
                 past_covariates=[self.covariates] * len(targets),
-                batch_size=None,
+                batch_size=batch_size,
             )
+            for i in range(len(targets)):
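+                # batched inference may introduce tiny floating point differences,
+                # hence an epsilon tolerance instead of exact equality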
+                assert sum(sum((preds[i] - preds_default[i]).values())) < epsilon
+
+    def test_predict_from_dataset_unsupported_input(self):
+        # an exception should be thrown if an unsupported type is passed
+        unsupported_type = "unsupported_type"
+        # just need to test this with one model
+        model_cls, kwargs, err = models_cls_kwargs_errs[0]
+        model = model_cls(
+            input_chunk_length=IN_LEN, output_chunk_length=OUT_LEN, **kwargs
+        )
+        model.fit([self.ts_pass_train, self.ts_pass_train_1])
 
-            # make batch size large enough to test stacking samples
-            for batch_size in range(1, 4 * len(targets)):
-                preds = model.predict(
-                    n=160,
-                    series=targets,
-                    past_covariates=[self.covariates] * len(targets),
-                    batch_size=batch_size,
-                )
-                for i in range(len(targets)):
-                    assert sum(sum((preds[i] - preds_default[i]).values())) < epsilon
-
-        def test_predict_from_dataset_unsupported_input(self):
-            # an exception should be thrown if an unsupported type is passed
-            unsupported_type = "unsupported_type"
-            # just need to test this with one model
-            model_cls, kwargs, err = models_cls_kwargs_errs[0]
-            model = model_cls(
-                input_chunk_length=IN_LEN, output_chunk_length=OUT_LEN, **kwargs
-            )
-            model.fit([self.ts_pass_train, self.ts_pass_train_1])
-
-            with pytest.raises(ValueError):
-                model.predict_from_dataset(n=1, input_series_dataset=unsupported_type)
-
-        @pytest.mark.parametrize("config", models_cls_kwargs_errs)
-        def test_prediction_with_different_n(self, config):
-            # test model predictions for n < out_len, n == out_len and n > out_len
-            model_cls, kwargs, err = config
-            model = model_cls(
-                input_chunk_length=IN_LEN, output_chunk_length=OUT_LEN, **kwargs
-            )
-            assert isinstance(
-                model,
-                (
-                    PastCovariatesTorchModel,
-                    DualCovariatesTorchModel,
-                    MixedCovariatesTorchModel,
-                ),
-            ), "unit test not yet defined for the given {X}CovariatesTorchModel."
-
-            if model.supports_past_covariates and model.supports_future_covariates:
-                past_covs, future_covs = None, self.covariates
-            elif model.supports_past_covariates:
-                past_covs, future_covs = self.covariates, None
-            elif model.supports_future_covariates:
-                past_covs, future_covs = None, self.covariates
-            else:
-                past_covs, future_covs = None, None
-
-            model.fit(
-                self.target_past,
-                past_covariates=past_covs,
-                future_covariates=future_covs,
-                epochs=1,
-            )
+        with pytest.raises(ValueError):
+            model.predict_from_dataset(n=1, input_series_dataset=unsupported_type)
 
-            # test prediction for n < out_len, n == out_len and n > out_len
-            for n in [OUT_LEN - 1, OUT_LEN, 2 * OUT_LEN - 1]:
-                pred = model.predict(
-                    n=n, past_covariates=past_covs, future_covariates=future_covs
-                )
-                assert len(pred) == n
+    @pytest.mark.parametrize("config", models_cls_kwargs_errs)
+    def test_prediction_with_different_n(self, config):
+        # test model predictions for n < out_len, n == out_len and n > out_len
+        model_cls, kwargs, err = config
+        model = model_cls(
+            input_chunk_length=IN_LEN, output_chunk_length=OUT_LEN, **kwargs
+        )
+        assert isinstance(
+            model,
+            (
+                PastCovariatesTorchModel,
+                DualCovariatesTorchModel,
+                MixedCovariatesTorchModel,
+            ),
+        ), "unit test not yet defined for the given {X}CovariatesTorchModel."
+
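+        # pick whichever covariate type the model supports (models supporting
+        # both are exercised with future covariates here)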
+        if model.supports_past_covariates and model.supports_future_covariates:
+            past_covs, future_covs = None, self.covariates
+        elif model.supports_past_covariates:
+            past_covs, future_covs = self.covariates, None
+        elif model.supports_future_covariates:
+            past_covs, future_covs = None, self.covariates
+        else:
+            past_covs, future_covs = None, None
+
+        model.fit(
+            self.target_past,
+            past_covariates=past_covs,
+            future_covariates=future_covs,
+            epochs=1,
+        )
 
-        @pytest.mark.parametrize("config", models_cls_kwargs_errs)
-        def test_same_result_with_different_n_jobs(self, config):
-            model_cls, kwargs, err = config
-            model = model_cls(
-                input_chunk_length=IN_LEN, output_chunk_length=OUT_LEN, **kwargs
+        # test prediction for n < out_len, n == out_len and n > out_len
+        for n in [OUT_LEN - 1, OUT_LEN, 2 * OUT_LEN - 1]:
+            pred = model.predict(
+                n=n, past_covariates=past_covs, future_covariates=future_covs
             )
+            assert len(pred) == n
+
+    @pytest.mark.parametrize("config", models_cls_kwargs_errs)
+    def test_same_result_with_different_n_jobs(self, config):
+        model_cls, kwargs, err = config
+        model = model_cls(
+            input_chunk_length=IN_LEN, output_chunk_length=OUT_LEN, **kwargs
+        )
 
-            multiple_ts = [self.ts_pass_train] * 10
+        multiple_ts = [self.ts_pass_train] * 10
 
-            model.fit(multiple_ts)
+        model.fit(multiple_ts)
 
-            # safe random state for two successive identical predictions
-            if model.supports_probabilistic_prediction:
-                random_state = deepcopy(model._random_instance)
-            else:
-                random_state = None
+        # save the random state so that two successive predictions are identical
+        if model.supports_probabilistic_prediction:
+            random_state = deepcopy(model._random_instance)
+        else:
+            random_state = None
 
-            pred1 = model.predict(n=36, series=multiple_ts, n_jobs=1)
+        pred1 = model.predict(n=36, series=multiple_ts, n_jobs=1)
 
-            if random_state is not None:
-                model._random_instance = random_state
+        if random_state is not None:
+            model._random_instance = random_state
 
-            pred2 = model.predict(
-                n=36, series=multiple_ts, n_jobs=-1
-            )  # assuming > 1 core available in the machine
-            assert (
-                pred1 == pred2
-            ), "Model {} produces different predictions with different number of jobs"
+        pred2 = model.predict(
+            n=36, series=multiple_ts, n_jobs=-1
+        )  # assuming > 1 core is available on the machine
+        assert pred1 == pred2, (
+            f"Model {model_cls} produces different predictions "
+            "with a different number of jobs"
+        )
 
-        @patch(
-            "darts.models.forecasting.torch_forecasting_model.TorchForecastingModel._init_trainer"
+    @patch(
+        "darts.models.forecasting.torch_forecasting_model.TorchForecastingModel._init_trainer"
+    )
+    @pytest.mark.parametrize("config", models_cls_kwargs_errs)
+    def test_fit_with_constr_epochs(self, init_trainer, config):
+        model_cls, kwargs, err = config
+        model = model_cls(
+            input_chunk_length=IN_LEN, output_chunk_length=OUT_LEN, **kwargs
         )
-        @pytest.mark.parametrize("config", models_cls_kwargs_errs)
-        def test_fit_with_constr_epochs(self, init_trainer, config):
-            model_cls, kwargs, err = config
-            model = model_cls(
-                input_chunk_length=IN_LEN, output_chunk_length=OUT_LEN, **kwargs
-            )
-            if not model._requires_training:
-                return
-            multiple_ts = [self.ts_pass_train] * 10
-            model.fit(multiple_ts)
+        if not model._requires_training:
+            return
+        multiple_ts = [self.ts_pass_train] * 10
+        model.fit(multiple_ts)
 
-            init_trainer.assert_called_with(
-                max_epochs=kwargs["n_epochs"], trainer_params=ANY
-            )
+        init_trainer.assert_called_with(
+            max_epochs=kwargs["n_epochs"], trainer_params=ANY
+        )
 
-        @patch(
-            "darts.models.forecasting.torch_forecasting_model.TorchForecastingModel._init_trainer"
+    @patch(
+        "darts.models.forecasting.torch_forecasting_model.TorchForecastingModel._init_trainer"
+    )
+    @pytest.mark.parametrize("config", models_cls_kwargs_errs)
+    def test_fit_with_fit_epochs(self, init_trainer, config):
+        model_cls, kwargs, err = config
+        model = model_cls(
+            input_chunk_length=IN_LEN, output_chunk_length=OUT_LEN, **kwargs
         )
-        @pytest.mark.parametrize("config", models_cls_kwargs_errs)
-        def test_fit_with_fit_epochs(self, init_trainer, config):
-            model_cls, kwargs, err = config
-            model = model_cls(
-                input_chunk_length=IN_LEN, output_chunk_length=OUT_LEN, **kwargs
-            )
-            multiple_ts = [self.ts_pass_train] * 10
-            epochs = 3
+        multiple_ts = [self.ts_pass_train] * 10
+        epochs = 3
 
-            model.fit(multiple_ts, epochs=epochs)
-            init_trainer.assert_called_with(max_epochs=epochs, trainer_params=ANY)
+        model.fit(multiple_ts, epochs=epochs)
+        init_trainer.assert_called_with(max_epochs=epochs, trainer_params=ANY)
 
-            model.total_epochs = epochs
-            # continue training
-            model.fit(multiple_ts, epochs=epochs)
-            init_trainer.assert_called_with(max_epochs=epochs, trainer_params=ANY)
+        model.total_epochs = epochs
+        # continue training
+        model.fit(multiple_ts, epochs=epochs)
+        init_trainer.assert_called_with(max_epochs=epochs, trainer_params=ANY)
 
-        @patch(
-            "darts.models.forecasting.torch_forecasting_model.TorchForecastingModel._init_trainer"
+    @patch(
+        "darts.models.forecasting.torch_forecasting_model.TorchForecastingModel._init_trainer"
+    )
+    @pytest.mark.parametrize("config", models_cls_kwargs_errs)
+    def test_fit_from_dataset_with_epochs(self, init_trainer, config):
+        model_cls, kwargs, err = config
+        model = model_cls(
+            input_chunk_length=IN_LEN, output_chunk_length=OUT_LEN, **kwargs
         )
-        @pytest.mark.parametrize("config", models_cls_kwargs_errs)
-        def test_fit_from_dataset_with_epochs(self, init_trainer, config):
-            model_cls, kwargs, err = config
-            model = model_cls(
-                input_chunk_length=IN_LEN, output_chunk_length=OUT_LEN, **kwargs
-            )
-            multiple_ts = [self.ts_pass_train] * 10
-            train_dataset = model._build_train_dataset(
-                multiple_ts,
-                past_covariates=None,
-                future_covariates=None,
-                max_samples_per_ts=None,
-            )
-            epochs = 3
+        multiple_ts = [self.ts_pass_train] * 10
+        train_dataset = model._build_train_dataset(
+            multiple_ts,
+            past_covariates=None,
+            future_covariates=None,
+            max_samples_per_ts=None,
+        )
+        epochs = 3
 
-            model.fit_from_dataset(train_dataset, epochs=epochs)
-            init_trainer.assert_called_with(max_epochs=epochs, trainer_params=ANY)
+        model.fit_from_dataset(train_dataset, epochs=epochs)
+        init_trainer.assert_called_with(max_epochs=epochs, trainer_params=ANY)
 
-            # continue training
-            model.fit_from_dataset(train_dataset, epochs=epochs)
-            init_trainer.assert_called_with(max_epochs=epochs, trainer_params=ANY)
+        # continue training
+        model.fit_from_dataset(train_dataset, epochs=epochs)
+        init_trainer.assert_called_with(max_epochs=epochs, trainer_params=ANY)
 
-        @pytest.mark.parametrize("config", models_cls_kwargs_errs)
-        def test_predit_after_fit_from_dataset(self, config):
-            model_cls, kwargs, _ = config
-            model = model_cls(
-                input_chunk_length=IN_LEN, output_chunk_length=OUT_LEN, **kwargs
-            )
+    @pytest.mark.parametrize("config", models_cls_kwargs_errs)
+    def test_predict_after_fit_from_dataset(self, config):
+        model_cls, kwargs, _ = config
+        model = model_cls(
+            input_chunk_length=IN_LEN, output_chunk_length=OUT_LEN, **kwargs
+        )
 
-            multiple_ts = [self.ts_pass_train] * 2
-            train_dataset = model._build_train_dataset(
-                multiple_ts,
-                past_covariates=None,
-                future_covariates=None,
-                max_samples_per_ts=None,
-            )
-            model.fit_from_dataset(train_dataset, epochs=1)
-
-            # test predict() works after fit_from_dataset()
-            model.predict(n=1, series=multiple_ts[0])
-
-        def test_sample_smaller_than_batch_size(self):
-            """
-            Checking that the TorchForecastingModels do not crash even if the number of available samples for training
-            is strictly lower than the selected batch_size
-            """
-            # TS with 50 timestamps. TorchForecastingModels will use the SequentialDataset for producing training
-            # samples, which means we will have 50 - 22 - 2 + 1 = 27 samples, which is < 32 (batch_size). The model
-            # should still train on those samples and not crash in any way
-            ts = linear_timeseries(start_value=0, end_value=1, length=50)
-
-            model = RNNModel(
-                input_chunk_length=20,
-                output_chunk_length=2,
-                n_epochs=2,
-                batch_size=32,
-                **tfm_kwargs,
-            )
-            model.fit(ts)
+        multiple_ts = [self.ts_pass_train] * 2
+        train_dataset = model._build_train_dataset(
+            multiple_ts,
+            past_covariates=None,
+            future_covariates=None,
+            max_samples_per_ts=None,
+        )
+        model.fit_from_dataset(train_dataset, epochs=1)
+
+        # test predict() works after fit_from_dataset()
+        model.predict(n=1, series=multiple_ts[0])
+
+    def test_sample_smaller_than_batch_size(self):
+        """
+        Checking that the TorchForecastingModels do not crash even if the number of available samples for training
+        is strictly lower than the selected batch_size
+        """
+        # TS with 50 timestamps. TorchForecastingModels will use the SequentialDataset for producing training
+        # samples, which means we will have 50 - (20 + 2) + 1 = 29 samples, which is < 32 (batch_size). The model
+        # should still train on those samples and not crash in any way
+        ts = linear_timeseries(start_value=0, end_value=1, length=50)
+
+        model = RNNModel(
+            input_chunk_length=20,
+            output_chunk_length=2,
+            n_epochs=2,
+            batch_size=32,
+            **tfm_kwargs,
+        )
+        model.fit(ts)
 
-        def test_max_samples_per_ts(self):
-            """
-            Checking that we can fit TorchForecastingModels with max_samples_per_ts, without crash
-            """
+    def test_max_samples_per_ts(self):
+        """
+        Checking that we can fit TorchForecastingModels with max_samples_per_ts without crashing
+        """
 
-            ts = linear_timeseries(start_value=0, end_value=1, length=50)
+        ts = linear_timeseries(start_value=0, end_value=1, length=50)
 
-            model = RNNModel(
-                input_chunk_length=20,
-                output_chunk_length=2,
-                n_epochs=2,
-                batch_size=32,
-                **tfm_kwargs,
-            )
+        model = RNNModel(
+            input_chunk_length=20,
+            output_chunk_length=2,
+            n_epochs=2,
+            batch_size=32,
+            **tfm_kwargs,
+        )
 
-            model.fit(ts, max_samples_per_ts=5)
-
-        def test_residuals(self):
-            """
-            Torch models should not fail when computing residuals on a series
-            long enough to accommodate at least one training sample.
-            """
-            ts = linear_timeseries(start_value=0, end_value=1, length=38)
-
-            model = NBEATSModel(
-                input_chunk_length=24,
-                output_chunk_length=12,
-                num_stacks=2,
-                num_blocks=1,
-                num_layers=1,
-                layer_widths=2,
-                n_epochs=2,
-                **tfm_kwargs,
-            )
+        model.fit(ts, max_samples_per_ts=5)
+
+    def test_residuals(self):
+        """
+        Torch models should not fail when computing residuals on a series
+        long enough to accommodate at least one training sample.
+        """
+        ts = linear_timeseries(start_value=0, end_value=1, length=38)
+
+        model = NBEATSModel(
+            input_chunk_length=24,
+            output_chunk_length=12,
+            num_stacks=2,
+            num_blocks=1,
+            num_layers=1,
+            layer_widths=2,
+            n_epochs=2,
+            **tfm_kwargs,
+        )
 
-            res = model.residuals(ts)
-            assert len(res) == 38 - (24 + 12)
+        res = model.residuals(ts)
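+        # a residual needs a full input + output window (24 + 12 = 36 points),
+        # so a length-38 series yields only 2 residual values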
+        assert len(res) == 38 - (24 + 12)
diff --git a/darts/tests/models/forecasting/test_nbeats_nhits.py b/darts/tests/models/forecasting/test_nbeats_nhits.py
index 88acadb254..5085356271 100644
--- a/darts/tests/models/forecasting/test_nbeats_nhits.py
+++ b/darts/tests/models/forecasting/test_nbeats_nhits.py
@@ -10,184 +10,194 @@
 try:
     from darts.models.forecasting.nbeats import NBEATSModel
     from darts.models.forecasting.nhits import NHiTSModel
-
-    TORCH_AVAILABLE = True
 except ImportError:
-    logger.warning("Torch not available. Nbeats and NHiTs tests will be skipped.")
-    TORCH_AVAILABLE = False
-
-
-if TORCH_AVAILABLE:
+    pytest.skip(
+        f"Torch not available. {__name__} tests will be skipped.",
+        allow_module_level=True,
+    )
 
-    class TestNbeatsNhitsModel:
-        def test_creation(self):
-            with pytest.raises(ValueError):
-                # if a list is passed to the `layer_widths` argument, it must have a length equal to `num_stacks`
-                NBEATSModel(
-                    input_chunk_length=1,
-                    output_chunk_length=1,
-                    num_stacks=3,
-                    layer_widths=[1, 2],
-                )
 
-            with pytest.raises(ValueError):
-                NHiTSModel(
-                    input_chunk_length=1,
-                    output_chunk_length=1,
-                    num_stacks=3,
-                    layer_widths=[1, 2],
-                )
+class TestNbeatsNhitsModel:
+    def test_creation(self):
+        with pytest.raises(ValueError):
+            # if a list is passed to the `layer_widths` argument, it must have a length equal to `num_stacks`
+            NBEATSModel(
+                input_chunk_length=1,
+                output_chunk_length=1,
+                num_stacks=3,
+                layer_widths=[1, 2],
+            )
 
-        def test_fit(self):
-            large_ts = tg.constant_timeseries(length=100, value=1000)
-            small_ts = tg.constant_timeseries(length=100, value=10)
+        with pytest.raises(ValueError):
+            NHiTSModel(
+                input_chunk_length=1,
+                output_chunk_length=1,
+                num_stacks=3,
+                layer_widths=[1, 2],
+            )
 
-            for model_cls in [NBEATSModel, NHiTSModel]:
-                # Test basic fit and predict
-                model = model_cls(
-                    input_chunk_length=1,
-                    output_chunk_length=1,
-                    n_epochs=10,
-                    num_stacks=1,
-                    num_blocks=1,
-                    layer_widths=20,
-                    random_state=42,
-                    **tfm_kwargs,
-                )
-                model.fit(large_ts[:98])
-                pred = model.predict(n=2).values()[0]
+    def test_fit(self):
+        large_ts = tg.constant_timeseries(length=100, value=1000)
+        small_ts = tg.constant_timeseries(length=100, value=10)
 
-                # Test whether model trained on one series is better than one trained on another
-                model2 = model_cls(
-                    input_chunk_length=1,
-                    output_chunk_length=1,
-                    n_epochs=10,
-                    num_stacks=1,
-                    num_blocks=1,
-                    layer_widths=20,
-                    random_state=42,
-                    **tfm_kwargs,
-                )
-                model2.fit(small_ts[:98])
-                pred2 = model2.predict(n=2).values()[0]
-                assert abs(pred2 - 10) < abs(pred - 10)
-
-                # test short predict
-                pred3 = model2.predict(n=1)
-                assert len(pred3) == 1
-
-        def test_multivariate(self):
-            # testing a 2-variate linear ts, first one from 0 to 1, second one from 0 to 0.5, length 100
-            series_multivariate = tg.linear_timeseries(length=100).stack(
-                tg.linear_timeseries(length=100, start_value=0, end_value=0.5)
+        for model_cls in [NBEATSModel, NHiTSModel]:
+            # Test basic fit and predict
+            model = model_cls(
+                input_chunk_length=1,
+                output_chunk_length=1,
+                n_epochs=10,
+                num_stacks=1,
+                num_blocks=1,
+                layer_widths=20,
+                random_state=42,
+                **tfm_kwargs,
             )
+            model.fit(large_ts[:98])
+            pred = model.predict(n=2).values()[0]
 
-            for model_cls in [NBEATSModel, NHiTSModel]:
-                model = model_cls(
-                    input_chunk_length=3,
-                    output_chunk_length=1,
-                    n_epochs=20,
-                    random_state=42,
-                    **tfm_kwargs,
-                )
-
-                model.fit(series_multivariate)
-                res = model.predict(n=2).values()
+            # Test whether model trained on one series is better than one trained on another
+            model2 = model_cls(
+                input_chunk_length=1,
+                output_chunk_length=1,
+                n_epochs=10,
+                num_stacks=1,
+                num_blocks=1,
+                layer_widths=20,
+                random_state=42,
+                **tfm_kwargs,
+            )
+            model2.fit(small_ts[:98])
+            pred2 = model2.predict(n=2).values()[0]
+            assert abs(pred2 - 10) < abs(pred - 10)
+
+            # test short predict
+            pred3 = model2.predict(n=1)
+            assert len(pred3) == 1
+
+    def test_multivariate(self):
+        # testing a 2-variate linear ts, first one from 0 to 1, second one from 0 to 0.5, length 100
+        series_multivariate = tg.linear_timeseries(length=100).stack(
+            tg.linear_timeseries(length=100, start_value=0, end_value=0.5)
+        )
+
+        for model_cls in [NBEATSModel, NHiTSModel]:
+            model = model_cls(
+                input_chunk_length=3,
+                output_chunk_length=1,
+                n_epochs=20,
+                random_state=42,
+                **tfm_kwargs,
+            )
 
-                # the theoretical result should be [[1.01, 1.02], [0.505, 0.51]].
-                # We just test if the given result is not too far on average.
-                assert abs(
-                    np.average(res - np.array([[1.01, 1.02], [0.505, 0.51]])) < 0.03
-                )
+            model.fit(series_multivariate)
+            res = model.predict(n=2).values()
 
-                # Test Covariates
-                series_covariates = tg.linear_timeseries(length=100).stack(
-                    tg.linear_timeseries(length=100, start_value=0, end_value=0.1)
-                )
-                model = model_cls(
-                    input_chunk_length=3,
-                    output_chunk_length=4,
-                    n_epochs=5,
-                    random_state=42,
-                    **tfm_kwargs,
-                )
-                model.fit(series_multivariate, past_covariates=series_covariates)
+            # the theoretical result should be [[1.01, 1.02], [0.505, 0.51]].
+            # We just check that the result is not too far off on average.
+            assert (
+                abs(np.average(res - np.array([[1.01, 1.02], [0.505, 0.51]]))) < 0.03
+            )
 
-                res = model.predict(
-                    n=3, series=series_multivariate, past_covariates=series_covariates
-                ).values()
+            # Test Covariates
+            series_covariates = tg.linear_timeseries(length=100).stack(
+                tg.linear_timeseries(length=100, start_value=0, end_value=0.1)
+            )
+            model = model_cls(
+                input_chunk_length=3,
+                output_chunk_length=4,
+                n_epochs=5,
+                random_state=42,
+                **tfm_kwargs,
+            )
+            model.fit(series_multivariate, past_covariates=series_covariates)
 
-                assert len(res) == 3
-                assert abs(np.average(res)) < 5
+            res = model.predict(
+                n=3, series=series_multivariate, past_covariates=series_covariates
+            ).values()
 
-        def test_nhits_sampling_sizes(self):
-            # providing bad sizes or shapes should fail
-            with pytest.raises(ValueError):
+            assert len(res) == 3
+            assert abs(np.average(res)) < 5
 
-                # wrong number of coeffs for stacks and blocks
-                NHiTSModel(
-                    input_chunk_length=1,
-                    output_chunk_length=1,
-                    num_stacks=1,
-                    num_blocks=2,
-                    pooling_kernel_sizes=((1,), (1,)),
-                    n_freq_downsample=((1,), (1,)),
-                )
-            with pytest.raises(ValueError):
-                NHiTSModel(
-                    input_chunk_length=1,
-                    output_chunk_length=1,
-                    num_stacks=2,
-                    num_blocks=2,
-                    pooling_kernel_sizes=((1, 1), (1, 1)),
-                    n_freq_downsample=((2, 1), (2, 2)),
-                )
+    def test_nhits_sampling_sizes(self):
+        # providing bad sizes or shapes should fail
+        with pytest.raises(ValueError):
 
-            # it shouldn't fail with the right number of coeffs
-            _ = NHiTSModel(
+            # wrong number of coeffs for stacks and blocks
+            NHiTSModel(
                 input_chunk_length=1,
                 output_chunk_length=1,
-                num_stacks=2,
+                num_stacks=1,
                 num_blocks=2,
-                pooling_kernel_sizes=((2, 1), (2, 1)),
-                n_freq_downsample=((2, 1), (2, 1)),
+                pooling_kernel_sizes=((1,), (1,)),
+                n_freq_downsample=((1,), (1,)),
             )
-
-            # default freqs should be such that last one is 1
-            model = NHiTSModel(
+        with pytest.raises(ValueError):
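+            # the sizes match, but the last value of n_freq_downsample must be 1
+            # (no downsampling on the final block)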
+            NHiTSModel(
                 input_chunk_length=1,
                 output_chunk_length=1,
                 num_stacks=2,
                 num_blocks=2,
+                pooling_kernel_sizes=((1, 1), (1, 1)),
+                n_freq_downsample=((2, 1), (2, 2)),
             )
-            assert model.n_freq_downsample[-1][-1] == 1
 
-        def test_logtensorboard(self, tmpdir_module):
-            ts = tg.constant_timeseries(length=50, value=10)
+        # it shouldn't fail with the right number of coeffs
+        _ = NHiTSModel(
+            input_chunk_length=1,
+            output_chunk_length=1,
+            num_stacks=2,
+            num_blocks=2,
+            pooling_kernel_sizes=((2, 1), (2, 1)),
+            n_freq_downsample=((2, 1), (2, 1)),
+        )
+
+        # default freqs should be such that last one is 1
+        model = NHiTSModel(
+            input_chunk_length=1,
+            output_chunk_length=1,
+            num_stacks=2,
+            num_blocks=2,
+        )
+        assert model.n_freq_downsample[-1][-1] == 1
+
+    def test_logtensorboard(self, tmpdir_module):
+        ts = tg.constant_timeseries(length=50, value=10)
+
+        # testing that both modes (generic and interpretable) run with tensorboard
+        architectures = [True, False]
+        for architecture in architectures:
+            # Test basic fit and predict
+            model = NBEATSModel(
+                input_chunk_length=1,
+                output_chunk_length=1,
+                n_epochs=1,
+                log_tensorboard=True,
+                work_dir=tmpdir_module,
+                generic_architecture=architecture,
+                pl_trainer_kwargs={
+                    "log_every_n_steps": 1,
+                    **tfm_kwargs["pl_trainer_kwargs"],
+                },
+            )
+            model.fit(ts)
+            model.predict(n=2)
 
-            # testing if both the modes (generic and interpretable) runs with tensorboard
-            architectures = [True, False]
-            for architecture in architectures:
-                # Test basic fit and predict
-                model = NBEATSModel(
-                    input_chunk_length=1,
-                    output_chunk_length=1,
-                    n_epochs=1,
-                    log_tensorboard=True,
-                    work_dir=tmpdir_module,
-                    generic_architecture=architecture,
-                    pl_trainer_kwargs={
-                        "log_every_n_steps": 1,
-                        **tfm_kwargs["pl_trainer_kwargs"],
-                    },
-                )
-                model.fit(ts)
-                model.predict(n=2)
+    def test_activation_fns(self):
+        ts = tg.constant_timeseries(length=50, value=10)
 
-        def test_activation_fns(self):
-            ts = tg.constant_timeseries(length=50, value=10)
+        for model_cls in [NBEATSModel, NHiTSModel]:
+            model = model_cls(
+                input_chunk_length=1,
+                output_chunk_length=1,
+                n_epochs=10,
+                num_stacks=1,
+                num_blocks=1,
+                layer_widths=20,
+                random_state=42,
+                activation="LeakyReLU",
+                **tfm_kwargs,
+            )
+            model.fit(ts)
 
-            for model_cls in [NBEATSModel, NHiTSModel]:
+            with pytest.raises(ValueError):
                 model = model_cls(
                     input_chunk_length=1,
                     output_chunk_length=1,
@@ -196,21 +206,7 @@ def test_activation_fns(self):
                     num_blocks=1,
                     layer_widths=20,
                     random_state=42,
-                    activation="LeakyReLU",
+                    activation="invalid",
                     **tfm_kwargs,
                 )
                 model.fit(ts)
-
-                with pytest.raises(ValueError):
-                    model = model_cls(
-                        input_chunk_length=1,
-                        output_chunk_length=1,
-                        n_epochs=10,
-                        num_stacks=1,
-                        num_blocks=1,
-                        layer_widths=20,
-                        random_state=42,
-                        activation="invalid",
-                        **tfm_kwargs,
-                    )
-                    model.fit(ts)
diff --git a/darts/tests/models/forecasting/test_ptl_trainer.py b/darts/tests/models/forecasting/test_ptl_trainer.py
index d9449fa58d..bfc4c349d4 100644
--- a/darts/tests/models/forecasting/test_ptl_trainer.py
+++ b/darts/tests/models/forecasting/test_ptl_trainer.py
@@ -11,273 +11,271 @@
     import pytorch_lightning as pl
 
     from darts.models.forecasting.rnn_model import RNNModel
-
-    TORCH_AVAILABLE = True
 except ImportError:
-    logger.warning("Torch not available. RNN tests will be skipped.")
-    TORCH_AVAILABLE = False
-
-
-if TORCH_AVAILABLE:
-
-    class TestTorchForecastingModel:
-        trainer_params = {
-            "max_epochs": 1,
-            "logger": False,
-            "enable_checkpointing": False,
-        }
-        series = linear_timeseries(length=100).astype(np.float32)
-        pl_200_or_above = int(pl.__version__.split(".")[0]) >= 2
-        precisions = {
-            32: "32" if not pl_200_or_above else "32-true",
-            64: "64" if not pl_200_or_above else "64-true",
-        }
-
-        def test_prediction_loaded_custom_trainer(self, tmpdir_module):
-            """validate manual save with automatic save files by comparing output between the two"""
-            auto_name = "test_save_automatic"
-            model = RNNModel(
-                12,
-                "RNN",
-                10,
-                10,
-                model_name=auto_name,
-                work_dir=tmpdir_module,
-                save_checkpoints=True,
-                random_state=42,
-                **tfm_kwargs,
-            )
-
-            # fit model with custom trainer
-            trainer = pl.Trainer(
-                max_epochs=1,
-                enable_checkpointing=True,
-                logger=False,
-                callbacks=model.trainer_params["callbacks"],
-                **tfm_kwargs["pl_trainer_kwargs"],
-            )
+    pytest.skip(
+        f"Torch not available. {__name__} tests will be skipped.",
+        allow_module_level=True,
+    )
+
+
+class TestPTLTrainer:
+    trainer_params = {
+        "max_epochs": 1,
+        "logger": False,
+        "enable_checkpointing": False,
+    }
+    series = linear_timeseries(length=100).astype(np.float32)
+    pl_200_or_above = int(pl.__version__.split(".")[0]) >= 2
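+    # lightning >= 2.0 renamed the precision flags (e.g. "32" -> "32-true")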
+    precisions = {
+        32: "32" if not pl_200_or_above else "32-true",
+        64: "64" if not pl_200_or_above else "64-true",
+    }
+
+    def test_prediction_loaded_custom_trainer(self, tmpdir_module):
+        """validate manual save with automatic save files by comparing output between the two"""
+        auto_name = "test_save_automatic"
+        model = RNNModel(
+            12,
+            "RNN",
+            10,
+            10,
+            model_name=auto_name,
+            work_dir=tmpdir_module,
+            save_checkpoints=True,
+            random_state=42,
+            **tfm_kwargs,
+        )
+
+        # fit model with custom trainer
+        trainer = pl.Trainer(
+            max_epochs=1,
+            enable_checkpointing=True,
+            logger=False,
+            callbacks=model.trainer_params["callbacks"],
+            **tfm_kwargs["pl_trainer_kwargs"],
+        )
+        model.fit(self.series, trainer=trainer)
+
+        # load the automatically saved model with load_from_checkpoint()
+        model_loaded = RNNModel.load_from_checkpoint(
+            model_name=auto_name,
+            work_dir=tmpdir_module,
+            best=False,
+            map_location="cpu",
+        )
+
+        # compare prediction of loaded model with original model
+        assert model.predict(n=4) == model_loaded.predict(n=4)
+
+    def test_prediction_custom_trainer(self):
+        model = RNNModel(12, "RNN", 10, 10, random_state=42, **tfm_kwargs)
+        model2 = RNNModel(12, "RNN", 10, 10, random_state=42, **tfm_kwargs)
+
+        # fit model with custom trainer
+        trainer = pl.Trainer(
+            **self.trainer_params,
+            precision=self.precisions[32],
+            **tfm_kwargs["pl_trainer_kwargs"],
+        )
+        model.fit(self.series, trainer=trainer)
+
+        # fit model with built-in trainer
+        model2.fit(self.series, epochs=1)
+
+        # both should produce identical prediction
+        assert model.predict(n=4) == model2.predict(n=4)
+
+    def test_custom_trainer_setup(self):
+        model = RNNModel(12, "RNN", 10, 10, random_state=42, **tfm_kwargs)
+
+        # trainer precision (64) not matching the float32 series should raise a ValueError
+        trainer = pl.Trainer(
+            **self.trainer_params,
+            precision=self.precisions[64],
+            **tfm_kwargs["pl_trainer_kwargs"],
+        )
+        with pytest.raises(ValueError):
             model.fit(self.series, trainer=trainer)
 
-            # load automatically saved model with manual load_model() and load_from_checkpoint()
-            model_loaded = RNNModel.load_from_checkpoint(
-                model_name=auto_name,
-                work_dir=tmpdir_module,
-                best=False,
-                map_location="cpu",
-            )
-
-            # compare prediction of loaded model with original model
-            assert model.predict(n=4) == model_loaded.predict(n=4)
-
-        def test_prediction_custom_trainer(self):
-            model = RNNModel(12, "RNN", 10, 10, random_state=42, **tfm_kwargs)
-            model2 = RNNModel(12, "RNN", 10, 10, random_state=42, **tfm_kwargs)
-
-            # fit model with custom trainer
-            trainer = pl.Trainer(
-                **self.trainer_params,
-                precision=self.precisions[32],
+        # no error with correct precision
+        trainer = pl.Trainer(
+            **self.trainer_params,
+            precision=self.precisions[32],
+            **tfm_kwargs["pl_trainer_kwargs"],
+        )
+        model.fit(self.series, trainer=trainer)
+
+        # check if number of epochs trained is same as trainer.max_epochs
+        assert trainer.max_epochs == model.epochs_trained
+
+    def test_builtin_extended_trainer(self):
+        # wrong precision parameter name
+        with pytest.raises(TypeError):
+            invalid_trainer_kwarg = {
+                "precisionn": self.precisions[32],
                 **tfm_kwargs["pl_trainer_kwargs"],
-            )
-            model.fit(self.series, trainer=trainer)
-
-            # fit model with built-in trainer
-            model2.fit(self.series, epochs=1)
-
-            # both should produce identical prediction
-            assert model.predict(n=4) == model2.predict(n=4)
-
-        def test_custom_trainer_setup(self):
-            model = RNNModel(12, "RNN", 10, 10, random_state=42, **tfm_kwargs)
-
-            # trainer with wrong precision should raise ValueError
-            trainer = pl.Trainer(
-                **self.trainer_params,
-                precision=self.precisions[64],
-                **tfm_kwargs["pl_trainer_kwargs"],
-            )
-            with pytest.raises(ValueError):
-                model.fit(self.series, trainer=trainer)
-
-            # no error with correct precision
-            trainer = pl.Trainer(
-                **self.trainer_params,
-                precision=self.precisions[32],
-                **tfm_kwargs["pl_trainer_kwargs"],
-            )
-            model.fit(self.series, trainer=trainer)
-
-            # check if number of epochs trained is same as trainer.max_epochs
-            assert trainer.max_epochs == model.epochs_trained
-
-        def test_builtin_extended_trainer(self):
-            # wrong precision parameter name
-            with pytest.raises(TypeError):
-                invalid_trainer_kwarg = {
-                    "precisionn": self.precisions[32],
-                    **tfm_kwargs["pl_trainer_kwargs"],
-                }
-                model = RNNModel(
-                    12,
-                    "RNN",
-                    10,
-                    10,
-                    random_state=42,
-                    pl_trainer_kwargs=invalid_trainer_kwarg,
-                )
-                model.fit(self.series, epochs=1)
-
-            # flaot 16 not supported
-            with pytest.raises(ValueError):
-                invalid_trainer_kwarg = {
-                    "precision": "16-mixed",
-                    **tfm_kwargs["pl_trainer_kwargs"],
-                }
-                model = RNNModel(
-                    12,
-                    "RNN",
-                    10,
-                    10,
-                    random_state=42,
-                    pl_trainer_kwargs=invalid_trainer_kwarg,
-                )
-                model.fit(self.series.astype(np.float16), epochs=1)
-
-            # precision value doesn't match `series` dtype
-            with pytest.raises(ValueError):
-                invalid_trainer_kwarg = {
-                    "precision": self.precisions[64],
-                    **tfm_kwargs["pl_trainer_kwargs"],
-                }
-                model = RNNModel(
-                    12,
-                    "RNN",
-                    10,
-                    10,
-                    random_state=42,
-                    pl_trainer_kwargs=invalid_trainer_kwarg,
-                )
-                model.fit(self.series.astype(np.float32), epochs=1)
-
-            for precision in [64, 32]:
-                valid_trainer_kwargs = {
-                    "precision": self.precisions[precision],
-                    **tfm_kwargs["pl_trainer_kwargs"],
-                }
-
-                # valid parameters shouldn't raise error
-                model = RNNModel(
-                    12,
-                    "RNN",
-                    10,
-                    10,
-                    random_state=42,
-                    pl_trainer_kwargs=valid_trainer_kwargs,
-                )
-                ts_dtype = getattr(np, f"float{precision}")
-                model.fit(self.series.astype(ts_dtype), epochs=1)
-                preds = model.predict(n=3)
-                assert model.trainer.precision == self.precisions[precision]
-                assert preds.dtype == ts_dtype
-
-        def test_custom_callback(self, tmpdir_module):
-            class CounterCallback(pl.callbacks.Callback):
-                # counts the number of trained epochs starting from count_default
-                def __init__(self, count_default):
-                    self.counter = count_default
-
-                def on_train_epoch_end(self, *args, **kwargs):
-                    self.counter += 1
-
-            my_counter_0 = CounterCallback(count_default=0)
-            my_counter_2 = CounterCallback(count_default=2)
-
+            }
             model = RNNModel(
                 12,
                 "RNN",
                 10,
                 10,
                 random_state=42,
-                pl_trainer_kwargs={
-                    "callbacks": [my_counter_0, my_counter_2],
-                    **tfm_kwargs["pl_trainer_kwargs"],
-                },
+                pl_trainer_kwargs=invalid_trainer_kwarg,
             )
+            model.fit(self.series, epochs=1)
 
-            # check if callbacks were added
-            assert len(model.trainer_params["callbacks"]) == 2
-            model.fit(self.series, epochs=2, verbose=True)
-            # check that lightning did not mutate callbacks (verbosity adds a progress bar callback)
-            assert len(model.trainer_params["callbacks"]) == 2
-
-            assert my_counter_0.counter == model.epochs_trained
-            assert my_counter_2.counter == model.epochs_trained + 2
-
-            # check that callbacks don't overwrite Darts' built-in checkpointer
+        # float16 is not supported
+        with pytest.raises(ValueError):
+            invalid_trainer_kwarg = {
+                "precision": "16-mixed",
+                **tfm_kwargs["pl_trainer_kwargs"],
+            }
             model = RNNModel(
                 12,
                 "RNN",
                 10,
                 10,
                 random_state=42,
-                work_dir=tmpdir_module,
-                save_checkpoints=True,
-                pl_trainer_kwargs={
-                    "callbacks": [CounterCallback(0), CounterCallback(2)],
-                    **tfm_kwargs["pl_trainer_kwargs"],
-                },
-            )
-            # we expect 3 callbacks
-            assert len(model.trainer_params["callbacks"]) == 3
-
-            # first one is our Checkpointer
-            assert isinstance(
-                model.trainer_params["callbacks"][0], pl.callbacks.ModelCheckpoint
+                pl_trainer_kwargs=invalid_trainer_kwarg,
             )
+            model.fit(self.series.astype(np.float16), epochs=1)
 
-            # second and third are CounterCallbacks
-            for i in range(1, 3):
-                assert isinstance(model.trainer_params["callbacks"][i], CounterCallback)
-
-        def test_early_stopping(self):
-            my_stopper = pl.callbacks.early_stopping.EarlyStopping(
-                monitor="val_loss",
-                stopping_threshold=1e9,
-            )
+        # precision value doesn't match `series` dtype
+        with pytest.raises(ValueError):
+            invalid_trainer_kwarg = {
+                "precision": self.precisions[64],
+                **tfm_kwargs["pl_trainer_kwargs"],
+            }
             model = RNNModel(
                 12,
                 "RNN",
                 10,
                 10,
-                nr_epochs_val_period=1,
                 random_state=42,
-                pl_trainer_kwargs={
-                    "callbacks": [my_stopper],
-                    **tfm_kwargs["pl_trainer_kwargs"],
-                },
+                pl_trainer_kwargs=invalid_trainer_kwarg,
             )
+            model.fit(self.series.astype(np.float32), epochs=1)
 
-            # training should stop immediately with high stopping_threshold
-            model.fit(self.series, val_series=self.series, epochs=100, verbose=True)
-            assert model.epochs_trained == 1
+        for precision in [64, 32]:
+            valid_trainer_kwargs = {
+                "precision": self.precisions[precision],
+                **tfm_kwargs["pl_trainer_kwargs"],
+            }
 
-            # check that early stopping only takes valid monitor variables
-            my_stopper = pl.callbacks.early_stopping.EarlyStopping(
-                monitor="invalid_variable",
-                stopping_threshold=1e9,
-            )
+            # valid parameters shouldn't raise error
             model = RNNModel(
                 12,
                 "RNN",
                 10,
                 10,
-                nr_epochs_val_period=1,
                 random_state=42,
-                pl_trainer_kwargs={
-                    "callbacks": [my_stopper],
-                    **tfm_kwargs["pl_trainer_kwargs"],
-                },
+                pl_trainer_kwargs=valid_trainer_kwargs,
             )
+            ts_dtype = getattr(np, f"float{precision}")
+            model.fit(self.series.astype(ts_dtype), epochs=1)
+            preds = model.predict(n=3)
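+            # the trainer must run at the requested precision and the forecasts
+            # must come back in the dtype of the training series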
+            assert model.trainer.precision == self.precisions[precision]
+            assert preds.dtype == ts_dtype
+
+    def test_custom_callback(self, tmpdir_module):
+        class CounterCallback(pl.callbacks.Callback):
+            # counts the number of trained epochs starting from count_default
+            def __init__(self, count_default):
+                self.counter = count_default
+
+            def on_train_epoch_end(self, *args, **kwargs):
+                self.counter += 1
+
+        my_counter_0 = CounterCallback(count_default=0)
+        my_counter_2 = CounterCallback(count_default=2)
+
+        model = RNNModel(
+            12,
+            "RNN",
+            10,
+            10,
+            random_state=42,
+            pl_trainer_kwargs={
+                "callbacks": [my_counter_0, my_counter_2],
+                **tfm_kwargs["pl_trainer_kwargs"],
+            },
+        )
+
+        # check if callbacks were added
+        assert len(model.trainer_params["callbacks"]) == 2
+        model.fit(self.series, epochs=2, verbose=True)
+        # check that Lightning did not mutate callbacks (verbosity adds a progress bar callback)
+        assert len(model.trainer_params["callbacks"]) == 2
+
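+        # both counters advanced once per trained epoch on top of their start values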
+        assert my_counter_0.counter == model.epochs_trained
+        assert my_counter_2.counter == model.epochs_trained + 2
+
+        # check that callbacks don't overwrite Darts' built-in checkpointer
+        model = RNNModel(
+            12,
+            "RNN",
+            10,
+            10,
+            random_state=42,
+            work_dir=tmpdir_module,
+            save_checkpoints=True,
+            pl_trainer_kwargs={
+                "callbacks": [CounterCallback(0), CounterCallback(2)],
+                **tfm_kwargs["pl_trainer_kwargs"],
+            },
+        )
+        # we expect 3 callbacks
+        assert len(model.trainer_params["callbacks"]) == 3
+
+        # first one is our Checkpointer
+        assert isinstance(
+            model.trainer_params["callbacks"][0], pl.callbacks.ModelCheckpoint
+        )
+
+        # second and third are CounterCallbacks
+        for i in range(1, 3):
+            assert isinstance(model.trainer_params["callbacks"][i], CounterCallback)
+
+    def test_early_stopping(self):
+        my_stopper = pl.callbacks.early_stopping.EarlyStopping(
+            monitor="val_loss",
+            stopping_threshold=1e9,
+        )
+        model = RNNModel(
+            12,
+            "RNN",
+            10,
+            10,
+            nr_epochs_val_period=1,
+            random_state=42,
+            pl_trainer_kwargs={
+                "callbacks": [my_stopper],
+                **tfm_kwargs["pl_trainer_kwargs"],
+            },
+        )
+
+        # training should stop immediately with high stopping_threshold
+        model.fit(self.series, val_series=self.series, epochs=100, verbose=True)
+        assert model.epochs_trained == 1
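+        # (any val_loss is already below the 1e9 threshold, so the stopper
+        # fires after the first validation run)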
+
+        # check that early stopping only takes valid monitor variables
+        my_stopper = pl.callbacks.early_stopping.EarlyStopping(
+            monitor="invalid_variable",
+            stopping_threshold=1e9,
+        )
+        model = RNNModel(
+            12,
+            "RNN",
+            10,
+            10,
+            nr_epochs_val_period=1,
+            random_state=42,
+            pl_trainer_kwargs={
+                "callbacks": [my_stopper],
+                **tfm_kwargs["pl_trainer_kwargs"],
+            },
+        )
 
-            with pytest.raises(RuntimeError):
-                model.fit(self.series, val_series=self.series, epochs=100, verbose=True)
+        with pytest.raises(RuntimeError):
+            model.fit(self.series, val_series=self.series, epochs=100, verbose=True)
diff --git a/darts/tests/models/forecasting/test_tide_model.py b/darts/tests/models/forecasting/test_tide_model.py
index e8571a6c58..c8ebd824a8 100644
--- a/darts/tests/models/forecasting/test_tide_model.py
+++ b/darts/tests/models/forecasting/test_tide_model.py
@@ -14,265 +14,263 @@
 
     from darts.models.forecasting.tide_model import TiDEModel
     from darts.utils.likelihood_models import GaussianLikelihood
-
-    TORCH_AVAILABLE = True
-
 except ImportError:
-    logger.warning("Torch not available. TiDEModel tests will be skipped.")
-    TORCH_AVAILABLE = False
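+    # `allow_module_level=True` lets pytest.skip() skip the whole module at
+    # collection time instead of raising outside of a test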
+    pytest.skip(
+        f"Torch not available. {__name__} tests will be skipped.",
+        allow_module_level=True,
+    )
 
-if TORCH_AVAILABLE:
 
-    class TestTiDEModel:
-        np.random.seed(42)
-        torch.manual_seed(42)
+class TestTiDEModel:
+    np.random.seed(42)
+    torch.manual_seed(42)
 
-        def test_creation(self):
-            model = TiDEModel(
-                input_chunk_length=1,
-                output_chunk_length=1,
-                likelihood=GaussianLikelihood(),
-            )
-
-            assert model.input_chunk_length == 1
+    def test_creation(self):
+        model = TiDEModel(
+            input_chunk_length=1,
+            output_chunk_length=1,
+            likelihood=GaussianLikelihood(),
+        )
 
-        def test_fit(self):
-            large_ts = tg.constant_timeseries(length=100, value=1000)
-            small_ts = tg.constant_timeseries(length=100, value=10)
+        assert model.input_chunk_length == 1
 
-            model = TiDEModel(
-                input_chunk_length=1,
-                output_chunk_length=1,
-                n_epochs=10,
-                random_state=42,
-                **tfm_kwargs,
-            )
+    def test_fit(self):
+        large_ts = tg.constant_timeseries(length=100, value=1000)
+        small_ts = tg.constant_timeseries(length=100, value=10)
 
-            model.fit(large_ts[:98])
-            pred = model.predict(n=2).values()[0]
+        model = TiDEModel(
+            input_chunk_length=1,
+            output_chunk_length=1,
+            n_epochs=10,
+            random_state=42,
+            **tfm_kwargs,
+        )
 
-            # Test whether model trained on one series is better than one trained on another
-            model2 = TiDEModel(
-                input_chunk_length=1,
-                output_chunk_length=1,
-                n_epochs=10,
-                random_state=42,
-                **tfm_kwargs,
-            )
+        model.fit(large_ts[:98])
+        pred = model.predict(n=2).values()[0]
 
-            model2.fit(small_ts[:98])
-            pred2 = model2.predict(n=2).values()[0]
-            assert abs(pred2 - 10) < abs(pred - 10)
+        # the model trained on the small series should predict closer to its
+        # values than the model trained on the large series
+        model2 = TiDEModel(
+            input_chunk_length=1,
+            output_chunk_length=1,
+            n_epochs=10,
+            random_state=42,
+            **tfm_kwargs,
+        )
 
-            # test short predict
-            pred3 = model2.predict(n=1)
-            assert len(pred3) == 1
+        model2.fit(small_ts[:98])
+        pred2 = model2.predict(n=2).values()[0]
+        assert abs(pred2 - 10) < abs(pred - 10)
+
+        # test short predict
+        pred3 = model2.predict(n=1)
+        assert len(pred3) == 1
+
+    def test_logtensorboard(self, tmpdir_module):
+        ts = tg.constant_timeseries(length=50, value=10)
+
+        # Test fit and predict with tensorboard logging enabled
+        model = TiDEModel(
+            input_chunk_length=1,
+            output_chunk_length=1,
+            n_epochs=1,
+            log_tensorboard=True,
+            work_dir=tmpdir_module,
+            pl_trainer_kwargs={
+                "log_every_n_steps": 1,
+                **tfm_kwargs["pl_trainer_kwargs"],
+            },
+        )
+        model.fit(ts)
+        model.predict(n=2)
+
+    def test_future_covariate_handling(self):
+        ts_time_index = tg.sine_timeseries(length=2, freq="h")
+
+        model = TiDEModel(
+            input_chunk_length=1,
+            output_chunk_length=1,
+            add_encoders={"cyclic": {"future": "hour"}},
+            use_reversible_instance_norm=False,
+            **tfm_kwargs,
+        )
+        model.fit(ts_time_index, verbose=False, epochs=1)
+
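+        # repeat with reversible instance norm enabled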
+        model = TiDEModel(
+            input_chunk_length=1,
+            output_chunk_length=1,
+            add_encoders={"cyclic": {"future": "hour"}},
+            use_reversible_instance_norm=True,
+            **tfm_kwargs,
+        )
+        model.fit(ts_time_index, verbose=False, epochs=1)
 
-        def test_logtensorboard(self, tmpdir_module):
-            ts = tg.constant_timeseries(length=50, value=10)
+    def test_future_and_past_covariate_handling(self):
+        ts_time_index = tg.sine_timeseries(length=2, freq="h")
 
-            # Test basic fit and predict
-            model = TiDEModel(
-                input_chunk_length=1,
-                output_chunk_length=1,
-                n_epochs=1,
-                log_tensorboard=True,
-                work_dir=tmpdir_module,
-                pl_trainer_kwargs={
-                    "log_every_n_steps": 1,
-                    **tfm_kwargs["pl_trainer_kwargs"],
-                },
-            )
-            model.fit(ts)
-            model.predict(n=2)
+        model = TiDEModel(
+            input_chunk_length=1,
+            output_chunk_length=1,
+            add_encoders={"cyclic": {"future": "hour", "past": "hour"}},
+            **tfm_kwargs,
+        )
+        model.fit(ts_time_index, verbose=False, epochs=1)
 
-        def test_future_covariate_handling(self):
-            ts_time_index = tg.sine_timeseries(length=2, freq="h")
+        model = TiDEModel(
+            input_chunk_length=1,
+            output_chunk_length=1,
+            add_encoders={"cyclic": {"future": "hour", "past": "hour"}},
+            **tfm_kwargs,
+        )
+        model.fit(ts_time_index, verbose=False, epochs=1)
 
-            model = TiDEModel(
+    @pytest.mark.parametrize("temporal_widths", [(-1, 1), (1, -1)])
+    def test_failing_future_and_past_temporal_widths(self, temporal_widths):
+        # invalid temporal widths
+        with pytest.raises(ValueError):
+            TiDEModel(
                 input_chunk_length=1,
                 output_chunk_length=1,
-                add_encoders={"cyclic": {"future": "hour"}},
-                use_reversible_instance_norm=False,
+                temporal_width_past=temporal_widths[0],
+                temporal_width_future=temporal_widths[1],
                 **tfm_kwargs,
             )
-            model.fit(ts_time_index, verbose=False, epochs=1)
 
-            model = TiDEModel(
-                input_chunk_length=1,
-                output_chunk_length=1,
-                add_encoders={"cyclic": {"future": "hour"}},
-                use_reversible_instance_norm=True,
-                **tfm_kwargs,
-            )
-            model.fit(ts_time_index, verbose=False, epochs=1)
+    @pytest.mark.parametrize(
+        "temporal_widths",
+        [
+            (2, 2),  # feature projection to the same number of features
+            (1, 2),  # past: feature reduction, future: same number of features
+            (2, 1),  # past: same number of features, future: feature reduction
+            (3, 3),  # feature expansion
+            (0, 2),  # bypass past feature projection
+            (2, 0),  # bypass future feature projection
+            (0, 0),  # bypass all feature projection
+        ],
+    )
+    def test_future_and_past_temporal_widths(self, temporal_widths):
+        ts_time_index = tg.sine_timeseries(length=2, freq="h")
+
+        # fit with the parametrized temporal widths and check they reach the inner model
+        model = TiDEModel(
+            input_chunk_length=1,
+            output_chunk_length=1,
+            temporal_width_past=temporal_widths[0],
+            temporal_width_future=temporal_widths[1],
+            add_encoders={"cyclic": {"future": "hour", "past": "hour"}},
+            **tfm_kwargs,
+        )
+        model.fit(ts_time_index, verbose=False, epochs=1)
+        assert model.model.temporal_width_past == temporal_widths[0]
+        assert model.model.temporal_width_future == temporal_widths[1]
+
+    def test_past_covariate_handling(self):
+        ts_time_index = tg.sine_timeseries(length=2, freq="h")
+
+        model = TiDEModel(
+            input_chunk_length=1,
+            output_chunk_length=1,
+            add_encoders={"cyclic": {"past": "hour"}},
+            **tfm_kwargs,
+        )
+        model.fit(ts_time_index, verbose=False, epochs=1)
 
-        def test_future_and_past_covariate_handling(self):
-            ts_time_index = tg.sine_timeseries(length=2, freq="h")
+    def test_future_and_past_covariate_as_timeseries_handling(self):
+        ts_time_index = tg.sine_timeseries(length=2, freq="h")
 
-            model = TiDEModel(
-                input_chunk_length=1,
-                output_chunk_length=1,
-                add_encoders={"cyclic": {"future": "hour", "past": "hour"}},
-                **tfm_kwargs,
-            )
-            model.fit(ts_time_index, verbose=False, epochs=1)
+        for enable_rin in [True, False]:
 
+            # test with past_covariates timeseries
             model = TiDEModel(
                 input_chunk_length=1,
                 output_chunk_length=1,
                 add_encoders={"cyclic": {"future": "hour", "past": "hour"}},
+                use_reversible_instance_norm=enable_rin,
                 **tfm_kwargs,
             )
-            model.fit(ts_time_index, verbose=False, epochs=1)
-
-        @pytest.mark.parametrize("temporal_widths", [(-1, 1), (1, -1)])
-        def test_failing_future_and_past_temporal_widths(self, temporal_widths):
-            # invalid temporal widths
-            with pytest.raises(ValueError):
-                TiDEModel(
-                    input_chunk_length=1,
-                    output_chunk_length=1,
-                    temporal_width_past=temporal_widths[0],
-                    temporal_width_future=temporal_widths[1],
-                    **tfm_kwargs,
-                )
-
-        @pytest.mark.parametrize(
-            "temporal_widths",
-            [
-                (2, 2),  # feature projection to same amount of features
-                (1, 2),  # past: feature reduction, future: same amount of features
-                (2, 1),  # past: same amount of features, future: feature reduction
-                (3, 3),  # feature expansion
-                (0, 2),  # bypass past feature projection
-                (2, 0),  # bypass future feature projection
-                (0, 0),  # bypass all feature projection
-            ],
-        )
-        def test_future_and_past_temporal_widths(self, temporal_widths):
-            ts_time_index = tg.sine_timeseries(length=2, freq="h")
-
-            # feature projection to 2 features (same amount as input features)
-            model = TiDEModel(
-                input_chunk_length=1,
-                output_chunk_length=1,
-                temporal_width_past=temporal_widths[0],
-                temporal_width_future=temporal_widths[1],
-                add_encoders={"cyclic": {"future": "hour", "past": "hour"}},
-                **tfm_kwargs,
+            model.fit(
+                ts_time_index,
+                past_covariates=ts_time_index,
+                verbose=False,
+                epochs=1,
             )
-            model.fit(ts_time_index, verbose=False, epochs=1)
-            assert model.model.temporal_width_past == temporal_widths[0]
-            assert model.model.temporal_width_future == temporal_widths[1]
-
-        def test_past_covariate_handling(self):
-            ts_time_index = tg.sine_timeseries(length=2, freq="h")
 
+            # test with past_covariates and future_covariates timeseries
             model = TiDEModel(
                 input_chunk_length=1,
                 output_chunk_length=1,
-                add_encoders={"cyclic": {"past": "hour"}},
+                add_encoders={"cyclic": {"future": "hour", "past": "hour"}},
+                use_reversible_instance_norm=enable_rin,
                 **tfm_kwargs,
             )
-            model.fit(ts_time_index, verbose=False, epochs=1)
-
-        def test_future_and_past_covariate_as_timeseries_handling(self):
-            ts_time_index = tg.sine_timeseries(length=2, freq="h")
-
-            for enable_rin in [True, False]:
-
-                # test with past_covariates timeseries
-                model = TiDEModel(
-                    input_chunk_length=1,
-                    output_chunk_length=1,
-                    add_encoders={"cyclic": {"future": "hour", "past": "hour"}},
-                    use_reversible_instance_norm=enable_rin,
-                    **tfm_kwargs,
-                )
-                model.fit(
-                    ts_time_index,
-                    past_covariates=ts_time_index,
-                    verbose=False,
-                    epochs=1,
-                )
-
-                # test with past_covariates and future_covariates timeseries
-                model = TiDEModel(
-                    input_chunk_length=1,
-                    output_chunk_length=1,
-                    add_encoders={"cyclic": {"future": "hour", "past": "hour"}},
-                    use_reversible_instance_norm=enable_rin,
-                    **tfm_kwargs,
-                )
-                model.fit(
-                    ts_time_index,
-                    past_covariates=ts_time_index,
-                    future_covariates=ts_time_index,
-                    verbose=False,
-                    epochs=1,
-                )
-
-        def test_static_covariates_support(self):
-            target_multi = concatenate(
-                [tg.sine_timeseries(length=10, freq="h")] * 2, axis=1
+            model.fit(
+                ts_time_index,
+                past_covariates=ts_time_index,
+                future_covariates=ts_time_index,
+                verbose=False,
+                epochs=1,
             )
 
-            target_multi = target_multi.with_static_covariates(
-                pd.DataFrame(
-                    [[0.0, 1.0, 0, 2], [2.0, 3.0, 1, 3]],
-                    columns=["st1", "st2", "cat1", "cat2"],
-                )
-            )
+    def test_static_covariates_support(self):
+        target_multi = concatenate(
+            [tg.sine_timeseries(length=10, freq="h")] * 2, axis=1
+        )
 
-            # test with static covariates in the timeseries
-            model = TiDEModel(
-                input_chunk_length=3,
-                output_chunk_length=4,
-                add_encoders={"cyclic": {"future": "hour"}},
-                pl_trainer_kwargs={
-                    "fast_dev_run": True,
-                    **tfm_kwargs["pl_trainer_kwargs"],
-                },
+        target_multi = target_multi.with_static_covariates(
+            pd.DataFrame(
+                [[0.0, 1.0, 0, 2], [2.0, 3.0, 1, 3]],
+                columns=["st1", "st2", "cat1", "cat2"],
             )
-            model.fit(target_multi, verbose=False)
+        )
 
-            assert model.model.static_cov_dim == np.prod(
-                target_multi.static_covariates.values.shape
-            )
+        # test with static covariates in the timeseries
+        model = TiDEModel(
+            input_chunk_length=3,
+            output_chunk_length=4,
+            add_encoders={"cyclic": {"future": "hour"}},
+            pl_trainer_kwargs={
+                "fast_dev_run": True,
+                **tfm_kwargs["pl_trainer_kwargs"],
+            },
+        )
+        model.fit(target_multi, verbose=False)
 
-            # raise an error when trained with static covariates of wrong dimensionality
-            target_multi = target_multi.with_static_covariates(
-                pd.concat([target_multi.static_covariates] * 2, axis=1)
-            )
-            with pytest.raises(ValueError):
-                model.predict(n=1, series=target_multi, verbose=False)
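+        # static_cov_dim is the flattened size of the (n components, n static covs) matrix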
+        assert model.model.static_cov_dim == np.prod(
+            target_multi.static_covariates.values.shape
+        )
 
-            # raise an error when trained with static covariates and trying to predict without
-            with pytest.raises(ValueError):
-                model.predict(
-                    n=1, series=target_multi.with_static_covariates(None), verbose=False
-                )
+        # raise an error when trained with static covariates of wrong dimensionality
+        target_multi = target_multi.with_static_covariates(
+            pd.concat([target_multi.static_covariates] * 2, axis=1)
+        )
+        with pytest.raises(ValueError):
+            model.predict(n=1, series=target_multi, verbose=False)
 
-            # with `use_static_covariates=False`, we can predict without static covs
-            model = TiDEModel(
-                input_chunk_length=3,
-                output_chunk_length=4,
-                use_static_covariates=False,
-                n_epochs=1,
-                **tfm_kwargs,
+        # raise an error when trained with static covariates and trying to predict without
+        with pytest.raises(ValueError):
+            model.predict(
+                n=1, series=target_multi.with_static_covariates(None), verbose=False
             )
-            model.fit(target_multi)
-            preds = model.predict(n=2, series=target_multi.with_static_covariates(None))
-            assert preds.static_covariates is None
 
-            model = TiDEModel(
-                input_chunk_length=3,
-                output_chunk_length=4,
-                use_static_covariates=False,
-                n_epochs=1,
-                **tfm_kwargs,
-            )
-            model.fit(target_multi.with_static_covariates(None))
-            preds = model.predict(n=2, series=target_multi)
-            assert preds.static_covariates.equals(target_multi.static_covariates)
+        # with `use_static_covariates=False`, we can predict without static covs
+        model = TiDEModel(
+            input_chunk_length=3,
+            output_chunk_length=4,
+            use_static_covariates=False,
+            n_epochs=1,
+            **tfm_kwargs,
+        )
+        model.fit(target_multi)
+        preds = model.predict(n=2, series=target_multi.with_static_covariates(None))
+        assert preds.static_covariates is None
+
+        model = TiDEModel(
+            input_chunk_length=3,
+            output_chunk_length=4,
+            use_static_covariates=False,
+            n_epochs=1,
+            **tfm_kwargs,
+        )
+        model.fit(target_multi.with_static_covariates(None))
+        preds = model.predict(n=2, series=target_multi)
+        assert preds.static_covariates.equals(target_multi.static_covariates)
diff --git a/darts/tests/models/forecasting/test_torch_forecasting_model.py b/darts/tests/models/forecasting/test_torch_forecasting_model.py
index 0e04f821b8..096f867688 100644
--- a/darts/tests/models/forecasting/test_torch_forecasting_model.py
+++ b/darts/tests/models/forecasting/test_torch_forecasting_model.py
@@ -72,894 +72,903 @@
         (GlobalNaiveAggregate, kwargs),
         (GlobalNaiveDrift, kwargs),
     ]
-
-    TORCH_AVAILABLE = True
 except ImportError:
-    logger.warning("Torch not available. Tests will be skipped.")
-    TORCH_AVAILABLE = False
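+    # module-level skip: none of the tests below are collected without torch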
+    pytest.skip(
+        f"Torch not available. {__name__} tests will be skipped.",
+        allow_module_level=True,
+    )
 
-if TORCH_AVAILABLE:
 
-    class TestTorchForecastingModel:
-        times = pd.date_range("20130101", "20130410")
-        pd_series = pd.Series(range(100), index=times)
-        series = TimeSeries.from_series(pd_series)
+class TestTorchForecastingModel:
+    times = pd.date_range("20130101", "20130410")
+    pd_series = pd.Series(range(100), index=times)
+    series = TimeSeries.from_series(pd_series)
 
-        df = pd.DataFrame({"var1": range(100), "var2": range(100)}, index=times)
-        multivariate_series = TimeSeries.from_dataframe(df)
+    df = pd.DataFrame({"var1": range(100), "var2": range(100)}, index=times)
+    multivariate_series = TimeSeries.from_dataframe(df)
 
-        def test_save_model_parameters(self):
-            # check if re-created model has same params as original
-            model = RNNModel(12, "RNN", 10, 10, **tfm_kwargs)
-            assert model._model_params, model.untrained_model()._model_params
+    def test_save_model_parameters(self):
+        # check if re-created model has same params as original
+        model = RNNModel(12, "RNN", 10, 10, **tfm_kwargs)
+        assert model._model_params == model.untrained_model()._model_params
 
-        @patch(
-            "darts.models.forecasting.torch_forecasting_model.TorchForecastingModel.save"
+    @patch(
+        "darts.models.forecasting.torch_forecasting_model.TorchForecastingModel.save"
+    )
+    def test_suppress_automatic_save(self, patch_save_model, tmpdir_fn):
+        model_name = "test_model"
+        model1 = RNNModel(
+            12,
+            "RNN",
+            10,
+            10,
+            model_name=model_name,
+            work_dir=tmpdir_fn,
+            save_checkpoints=False,
+            **tfm_kwargs,
+        )
+        model2 = RNNModel(
+            12,
+            "RNN",
+            10,
+            10,
+            model_name=model_name,
+            work_dir=tmpdir_fn,
+            force_reset=True,
+            save_checkpoints=False,
+            **tfm_kwargs,
         )
-        def test_suppress_automatic_save(self, patch_save_model, tmpdir_fn):
-            model_name = "test_model"
-            model1 = RNNModel(
-                12,
-                "RNN",
-                10,
-                10,
-                model_name=model_name,
-                work_dir=tmpdir_fn,
-                save_checkpoints=False,
-                **tfm_kwargs,
-            )
-            model2 = RNNModel(
-                12,
-                "RNN",
-                10,
-                10,
-                model_name=model_name,
-                work_dir=tmpdir_fn,
-                force_reset=True,
-                save_checkpoints=False,
-                **tfm_kwargs,
-            )
-
-            model1.fit(self.series, epochs=1)
-            model2.fit(self.series, epochs=1)
-
-            model1.predict(n=1)
-            model2.predict(n=2)
-
-            patch_save_model.assert_not_called()
-
-            model1.save(path=os.path.join(tmpdir_fn, model_name))
-            patch_save_model.assert_called()
-
-        def test_manual_save_and_load(self, tmpdir_fn):
-            """validate manual save with automatic save files by comparing output between the two"""
 
-            model_dir = os.path.join(tmpdir_fn)
-            manual_name = "test_save_manual"
-            auto_name = "test_save_automatic"
-            model_manual_save = RNNModel(
-                12,
-                "RNN",
-                10,
-                10,
-                model_name=manual_name,
-                work_dir=tmpdir_fn,
-                save_checkpoints=False,
-                random_state=42,
-                **tfm_kwargs,
-            )
-            model_auto_save = RNNModel(
-                12,
-                "RNN",
-                10,
-                10,
-                model_name=auto_name,
-                work_dir=tmpdir_fn,
-                save_checkpoints=True,
-                random_state=42,
-                **tfm_kwargs,
-            )
+        model1.fit(self.series, epochs=1)
+        model2.fit(self.series, epochs=1)
+
+        model1.predict(n=1)
+        model2.predict(n=2)
+
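+        # with save_checkpoints=False, neither fit() nor predict() may auto-save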
+        patch_save_model.assert_not_called()
+
+        model1.save(path=os.path.join(tmpdir_fn, model_name))
+        patch_save_model.assert_called()
+
+    def test_manual_save_and_load(self, tmpdir_fn):
+        """validate manual save with automatic save files by comparing output between the two"""
+
+        model_dir = os.path.join(tmpdir_fn)
+        manual_name = "test_save_manual"
+        auto_name = "test_save_automatic"
+        model_manual_save = RNNModel(
+            12,
+            "RNN",
+            10,
+            10,
+            model_name=manual_name,
+            work_dir=tmpdir_fn,
+            save_checkpoints=False,
+            random_state=42,
+            **tfm_kwargs,
+        )
+        model_auto_save = RNNModel(
+            12,
+            "RNN",
+            10,
+            10,
+            model_name=auto_name,
+            work_dir=tmpdir_fn,
+            save_checkpoints=True,
+            random_state=42,
+            **tfm_kwargs,
+        )
 
-            # save model without training
-            no_training_ckpt = "no_training.pth.tar"
-            no_training_ckpt_path = os.path.join(model_dir, no_training_ckpt)
-            model_manual_save.save(no_training_ckpt_path)
-            # check that model object file was created
-            assert os.path.exists(no_training_ckpt_path)
-            # check that the PyTorch Ligthning ckpt does not exist
-            assert not os.path.exists(no_training_ckpt_path + ".ckpt")
-            # informative exception about `fit()` not called
-            with pytest.raises(ValueError) as err:
-                no_train_model = RNNModel.load(no_training_ckpt_path)
-                no_train_model.predict(n=4)
-            assert str(err.value) == (
-                "Input `series` must be provided. This is the result either from "
-                "fitting on multiple series, or from not having fit the model yet."
-            )
+        # save model without training
+        no_training_ckpt = "no_training.pth.tar"
+        no_training_ckpt_path = os.path.join(model_dir, no_training_ckpt)
+        model_manual_save.save(no_training_ckpt_path)
+        # check that model object file was created
+        assert os.path.exists(no_training_ckpt_path)
+        # check that the PyTorch Lightning ckpt does not exist
+        assert not os.path.exists(no_training_ckpt_path + ".ckpt")
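+        # (the sibling .ckpt file holding the Lightning state is only written
+        # once the model has been trained)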
+        # informative exception about `fit()` not called
+        with pytest.raises(ValueError) as err:
+            no_train_model = RNNModel.load(no_training_ckpt_path)
+            no_train_model.predict(n=4)
+        assert str(err.value) == (
+            "Input `series` must be provided. This is the result either from "
+            "fitting on multiple series, or from not having fit the model yet."
+        )
 
-            model_manual_save.fit(self.series, epochs=1)
-            model_auto_save.fit(self.series, epochs=1)
+        model_manual_save.fit(self.series, epochs=1)
+        model_auto_save.fit(self.series, epochs=1)
 
-            # check that file was not created with manual save
-            assert not os.path.exists(
-                os.path.join(model_dir, manual_name, "checkpoints")
-            )
-            # check that file was created with automatic save
-            assert os.path.exists(os.path.join(model_dir, auto_name, "checkpoints"))
+        # check that file was not created with manual save
+        assert not os.path.exists(os.path.join(model_dir, manual_name, "checkpoints"))
+        # check that file was created with automatic save
+        assert os.path.exists(os.path.join(model_dir, auto_name, "checkpoints"))
 
-            # create manually saved model checkpoints folder
-            checkpoint_path_manual = os.path.join(model_dir, manual_name)
-            os.mkdir(checkpoint_path_manual)
+        # create manually saved model checkpoints folder
+        checkpoint_path_manual = os.path.join(model_dir, manual_name)
+        os.mkdir(checkpoint_path_manual)
 
-            checkpoint_file_name = "checkpoint_0.pth.tar"
-            model_path_manual = os.path.join(
-                checkpoint_path_manual, checkpoint_file_name
-            )
-            checkpoint_file_name_cpkt = "checkpoint_0.pth.tar.ckpt"
-            model_path_manual_ckpt = os.path.join(
-                checkpoint_path_manual, checkpoint_file_name_cpkt
-            )
+        checkpoint_file_name = "checkpoint_0.pth.tar"
+        model_path_manual = os.path.join(checkpoint_path_manual, checkpoint_file_name)
+        checkpoint_file_name_ckpt = "checkpoint_0.pth.tar.ckpt"
+        model_path_manual_ckpt = os.path.join(
+            checkpoint_path_manual, checkpoint_file_name_ckpt
+        )
 
-            # save manually saved model
-            model_manual_save.save(model_path_manual)
-            assert os.path.exists(model_path_manual)
+        # save manually saved model
+        model_manual_save.save(model_path_manual)
+        assert os.path.exists(model_path_manual)
 
-            # check that the PTL checkpoint path is also there
-            assert os.path.exists(model_path_manual_ckpt)
+        # check that the PTL checkpoint path is also there
+        assert os.path.exists(model_path_manual_ckpt)
 
-            # load manual save model and compare with automatic model results
-            model_manual_save = RNNModel.load(model_path_manual, map_location="cpu")
-            model_manual_save.to_cpu()
-            assert model_manual_save.predict(n=4) == model_auto_save.predict(n=4)
+        # load manual save model and compare with automatic model results
+        model_manual_save = RNNModel.load(model_path_manual, map_location="cpu")
+        model_manual_save.to_cpu()
+        assert model_manual_save.predict(n=4) == model_auto_save.predict(n=4)
 
-            # load automatically saved model with manual load() and load_from_checkpoint()
-            model_auto_save1 = RNNModel.load_from_checkpoint(
-                model_name=auto_name,
-                work_dir=tmpdir_fn,
-                best=False,
-                map_location="cpu",
-            )
-            model_auto_save1.to_cpu()
-            # compare loaded checkpoint with manual save
-            assert model_manual_save.predict(n=4) == model_auto_save1.predict(n=4)
+        # load automatically saved model with manual load() and load_from_checkpoint()
+        model_auto_save1 = RNNModel.load_from_checkpoint(
+            model_name=auto_name,
+            work_dir=tmpdir_fn,
+            best=False,
+            map_location="cpu",
+        )
+        model_auto_save1.to_cpu()
+        # compare loaded checkpoint with manual save
+        assert model_manual_save.predict(n=4) == model_auto_save1.predict(n=4)
 
-            # save() model directly after load_from_checkpoint()
-            checkpoint_file_name_2 = "checkpoint_1.pth.tar"
-            checkpoint_file_name_cpkt_2 = checkpoint_file_name_2 + ".ckpt"
+        # save() model directly after load_from_checkpoint()
+        checkpoint_file_name_2 = "checkpoint_1.pth.tar"
+        checkpoint_file_name_ckpt_2 = checkpoint_file_name_2 + ".ckpt"
 
-            model_path_manual_2 = os.path.join(
-                checkpoint_path_manual, checkpoint_file_name_2
-            )
-            model_path_manual_ckpt_2 = os.path.join(
-                checkpoint_path_manual, checkpoint_file_name_cpkt_2
-            )
-            model_auto_save2 = RNNModel.load_from_checkpoint(
-                model_name=auto_name,
-                work_dir=tmpdir_fn,
-                best=False,
-                map_location="cpu",
-            )
-            # save model directly after loading, model has no trainer
-            model_auto_save2.save(model_path_manual_2)
+        model_path_manual_2 = os.path.join(
+            checkpoint_path_manual, checkpoint_file_name_2
+        )
+        model_path_manual_ckpt_2 = os.path.join(
+            checkpoint_path_manual, checkpoint_file_name_ckpt_2
+        )
+        model_auto_save2 = RNNModel.load_from_checkpoint(
+            model_name=auto_name,
+            work_dir=tmpdir_fn,
+            best=False,
+            map_location="cpu",
+        )
+        # save model directly after loading, model has no trainer
+        model_auto_save2.save(model_path_manual_2)
 
-            # assert original .ckpt checkpoint was correctly copied
-            assert os.path.exists(model_path_manual_ckpt_2)
+        # assert original .ckpt checkpoint was correctly copied
+        assert os.path.exists(model_path_manual_ckpt_2)
 
-            model_chained_load_save = RNNModel.load(
-                model_path_manual_2, map_location="cpu"
-            )
+        model_chained_load_save = RNNModel.load(model_path_manual_2, map_location="cpu")
 
-            # compare chained load_from_checkpoint() save() with manual save
-            assert model_chained_load_save.predict(n=4) == model_manual_save.predict(
-                n=4
-            )
+        # compare chained load_from_checkpoint() save() with manual save
+        assert model_chained_load_save.predict(n=4) == model_manual_save.predict(n=4)
 
-        def test_valid_save_and_load_weights_with_different_params(self, tmpdir_fn):
-            """
-            Verify that save/load does not break encoders.
-
-            Note: since load_weights() calls load_weights_from_checkpoint(), it will be used
-            for all but one test.
-            Note: Using DLinear since it supports both past and future covariates
-            """
-
-            def create_model(**kwargs):
-                return DLinearModel(
-                    input_chunk_length=4,
-                    output_chunk_length=1,
-                    **kwargs,
-                    **tfm_kwargs,
-                )
-
-            model_dir = os.path.join(tmpdir_fn)
-            manual_name = "save_manual"
-            # create manually saved model checkpoints folder
-            checkpoint_path_manual = os.path.join(model_dir, manual_name)
-            os.mkdir(checkpoint_path_manual)
-            checkpoint_file_name = "checkpoint_0.pth.tar"
-            model_path_manual = os.path.join(
-                checkpoint_path_manual, checkpoint_file_name
-            )
-            model = create_model()
-            model.fit(self.series, epochs=1)
-            model.save(model_path_manual)
-
-            kwargs_valid = [
-                {"optimizer_cls": torch.optim.SGD},
-                {"optimizer_kwargs": {"lr": 0.1}},
-            ]
-            # check that all models can be created with different valid kwargs
-            for kwargs_ in kwargs_valid:
-                model_new = create_model(**kwargs_)
-                model_new.load_weights(model_path_manual)
-
-        def test_save_and_load_weights_w_encoders(self, tmpdir_fn):
-            """
-            Verify that save/load does not break encoders.
-
-            Note: since load_weights() calls load_weights_from_checkpoint(), it will be used
-            for all but one test.
-            Note: Using DLinear since it supports both past and future covariates
-            """
-            model_dir = os.path.join(tmpdir_fn)
-            manual_name = "save_manual"
-            auto_name = "save_auto"
-            auto_name_other = "save_auto_other"
-            # create manually saved model checkpoints folder
-            checkpoint_path_manual = os.path.join(model_dir, manual_name)
-            os.mkdir(checkpoint_path_manual)
-            checkpoint_file_name = "checkpoint_0.pth.tar"
-            model_path_manual = os.path.join(
-                checkpoint_path_manual, checkpoint_file_name
-            )
+    def test_valid_save_and_load_weights_with_different_params(self, tmpdir_fn):
+        """
+        Verify that saved weights can be loaded into models created with different (valid) parameters.
 
-            # define encoders sets
-            encoders_past = {
-                "datetime_attribute": {"past": ["day"]},
-                "transformer": Scaler(),
-            }
-            encoders_other_past = {
-                "datetime_attribute": {"past": ["hour"]},
-                "transformer": Scaler(),
-            }
-            encoders_past_noscaler = {
-                "datetime_attribute": {"past": ["day"]},
-            }
-            encoders_past_other_transformer = {
-                "datetime_attribute": {"past": ["day"]},
-                "transformer": BoxCox(lmbda=-0.7),
-            }
-            encoders_2_past = {
-                "datetime_attribute": {"past": ["hour", "day"]},
-                "transformer": Scaler(),
-            }
-            encoders_past_n_future = {
-                "datetime_attribute": {"past": ["day"], "future": ["dayofweek"]},
-                "transformer": Scaler(),
-            }
+        Note: load_weights() calls load_weights_from_checkpoint() under the hood, so
+        both entry points are exercised.
+        Note: Using DLinear since it supports both past and future covariates.
+        """
 
-            model_auto_save = self.helper_create_DLinearModel(
-                work_dir=tmpdir_fn,
-                model_name=auto_name,
-                save_checkpoints=True,
-                add_encoders=encoders_past,
+        def create_model(**kwargs):
+            return DLinearModel(
+                input_chunk_length=4,
+                output_chunk_length=1,
+                **kwargs,
+                **tfm_kwargs,
             )
-            model_auto_save.fit(self.series, epochs=1)
 
-            model_manual_save = self.helper_create_DLinearModel(
-                work_dir=tmpdir_fn,
-                model_name=manual_name,
-                save_checkpoints=False,
-                add_encoders=encoders_past,
-            )
-            model_manual_save.fit(self.series, epochs=1)
-            model_manual_save.save(model_path_manual)
+        model_dir = os.path.join(tmpdir_fn)
+        manual_name = "save_manual"
+        # create manually saved model checkpoints folder
+        checkpoint_path_manual = os.path.join(model_dir, manual_name)
+        os.mkdir(checkpoint_path_manual)
+        checkpoint_file_name = "checkpoint_0.pth.tar"
+        model_path_manual = os.path.join(checkpoint_path_manual, checkpoint_file_name)
+        model = create_model()
+        model.fit(self.series, epochs=1)
+        model.save(model_path_manual)
+
+        kwargs_valid = [
+            {"optimizer_cls": torch.optim.SGD},
+            {"optimizer_kwargs": {"lr": 0.1}},
+        ]
+        # check that models created with different valid kwargs can load the weights
+        for kwargs_ in kwargs_valid:
+            model_new = create_model(**kwargs_)
+            model_new.load_weights(model_path_manual)
+
+    def test_save_and_load_weights_w_encoders(self, tmpdir_fn):
+        """
+        Verify that save/load does not break encoders.
+
+        Note: since load_weights() calls load_weights_from_checkpoint(), the former is
+        used for all but one check.
+        Note: Using DLinear since it supports both past and future covariates.
+        """
+        model_dir = os.path.join(tmpdir_fn)
+        manual_name = "save_manual"
+        auto_name = "save_auto"
+        auto_name_other = "save_auto_other"
+        # create manually saved model checkpoints folder
+        checkpoint_path_manual = os.path.join(model_dir, manual_name)
+        os.mkdir(checkpoint_path_manual)
+        checkpoint_file_name = "checkpoint_0.pth.tar"
+        model_path_manual = os.path.join(checkpoint_path_manual, checkpoint_file_name)
+
+        # define the encoder sets
+        encoders_past = {
+            "datetime_attribute": {"past": ["day"]},
+            "transformer": Scaler(),
+        }
+        encoders_other_past = {
+            "datetime_attribute": {"past": ["hour"]},
+            "transformer": Scaler(),
+        }
+        encoders_past_noscaler = {
+            "datetime_attribute": {"past": ["day"]},
+        }
+        encoders_past_other_transformer = {
+            "datetime_attribute": {"past": ["day"]},
+            "transformer": BoxCox(lmbda=-0.7),
+        }
+        encoders_2_past = {
+            "datetime_attribute": {"past": ["hour", "day"]},
+            "transformer": Scaler(),
+        }
+        encoders_past_n_future = {
+            "datetime_attribute": {"past": ["day"], "future": ["dayofweek"]},
+            "transformer": Scaler(),
+        }
+
+        model_auto_save = self.helper_create_DLinearModel(
+            work_dir=tmpdir_fn,
+            model_name=auto_name,
+            save_checkpoints=True,
+            add_encoders=encoders_past,
+        )
+        model_auto_save.fit(self.series, epochs=1)
 
-            model_auto_save_other = self.helper_create_DLinearModel(
-                work_dir=tmpdir_fn,
-                model_name=auto_name_other,
-                save_checkpoints=True,
-                add_encoders=encoders_other_past,
-            )
-            model_auto_save_other.fit(self.series, epochs=1)
+        model_manual_save = self.helper_create_DLinearModel(
+            work_dir=tmpdir_fn,
+            model_name=manual_name,
+            save_checkpoints=False,
+            add_encoders=encoders_past,
+        )
+        model_manual_save.fit(self.series, epochs=1)
+        model_manual_save.save(model_path_manual)
+
+        model_auto_save_other = self.helper_create_DLinearModel(
+            work_dir=tmpdir_fn,
+            model_name=auto_name_other,
+            save_checkpoints=True,
+            add_encoders=encoders_other_past,
+        )
+        model_auto_save_other.fit(self.series, epochs=1)
 
-            # prediction are different when using different encoders
-            assert model_auto_save.predict(n=4) != model_auto_save_other.predict(n=4)
+        # predictions are different when using different encoders
+        assert model_auto_save.predict(n=4) != model_auto_save_other.predict(n=4)
 
-            # model with undeclared encoders
-            model_no_enc = self.helper_create_DLinearModel(
-                work_dir=tmpdir_fn, model_name="no_encoder", add_encoders=None
-            )
-            # weights were trained with encoders, new model must be instantiated with encoders
-            with pytest.raises(ValueError):
-                model_no_enc.load_weights_from_checkpoint(
-                    auto_name,
-                    work_dir=tmpdir_fn,
-                    best=False,
-                    load_encoders=False,
-                    map_location="cpu",
-                )
-            # overwritte undeclared encoders
+        # model with undeclared encoders
+        model_no_enc = self.helper_create_DLinearModel(
+            work_dir=tmpdir_fn, model_name="no_encoder", add_encoders=None
+        )
+        # weights were trained with encoders, so the new model must also declare encoders
+        with pytest.raises(ValueError):
             model_no_enc.load_weights_from_checkpoint(
                 auto_name,
                 work_dir=tmpdir_fn,
                 best=False,
-                load_encoders=True,
-                map_location="cpu",
-            )
-            self.helper_equality_encoders(
-                model_auto_save.add_encoders, model_no_enc.add_encoders
-            )
-            self.helper_equality_encoders_transfo(
-                model_auto_save.add_encoders, model_no_enc.add_encoders
-            )
-            # cannot directly verify equality between encoders, using predict as proxy
-            assert model_auto_save.predict(n=4) == model_no_enc.predict(
-                n=4, series=self.series
-            )
-
-            # model with identical encoders (fittable)
-            model_same_enc_noload = self.helper_create_DLinearModel(
-                work_dir=tmpdir_fn,
-                model_name="same_encoder_noload",
-                add_encoders=encoders_past,
-            )
-            model_same_enc_noload.load_weights(
-                model_path_manual,
                 load_encoders=False,
                 map_location="cpu",
             )
-            # cannot predict because of un-fitted encoder
-            with pytest.raises(ValueError):
-                model_same_enc_noload.predict(n=4, series=self.series)
+        # overwrite the undeclared encoders
+        model_no_enc.load_weights_from_checkpoint(
+            auto_name,
+            work_dir=tmpdir_fn,
+            best=False,
+            load_encoders=True,
+            map_location="cpu",
+        )
+        self.helper_equality_encoders(
+            model_auto_save.add_encoders, model_no_enc.add_encoders
+        )
+        self.helper_equality_encoders_transfo(
+            model_auto_save.add_encoders, model_no_enc.add_encoders
+        )
+        # cannot directly verify equality between encoders, so use predict() as a proxy
+        assert model_auto_save.predict(n=4) == model_no_enc.predict(
+            n=4, series=self.series
+        )
 
-            model_same_enc_load = self.helper_create_DLinearModel(
-                work_dir=tmpdir_fn,
-                model_name="same_encoder_load",
-                add_encoders=encoders_past,
-            )
-            model_same_enc_load.load_weights(
-                model_path_manual,
-                load_encoders=True,
-                map_location="cpu",
-            )
-            assert model_manual_save.predict(n=4) == model_same_enc_load.predict(
-                n=4, series=self.series
-            )
+        # model with identical encoders (fittable)
+        model_same_enc_noload = self.helper_create_DLinearModel(
+            work_dir=tmpdir_fn,
+            model_name="same_encoder_noload",
+            add_encoders=encoders_past,
+        )
+        model_same_enc_noload.load_weights(
+            model_path_manual,
+            load_encoders=False,
+            map_location="cpu",
+        )
+        # cannot predict because the encoders were not fit
+        with pytest.raises(ValueError):
+            model_same_enc_noload.predict(n=4, series=self.series)
+
+        model_same_enc_load = self.helper_create_DLinearModel(
+            work_dir=tmpdir_fn,
+            model_name="same_encoder_load",
+            add_encoders=encoders_past,
+        )
+        model_same_enc_load.load_weights(
+            model_path_manual,
+            load_encoders=True,
+            map_location="cpu",
+        )
+        assert model_manual_save.predict(n=4) == model_same_enc_load.predict(
+            n=4, series=self.series
+        )
 
-            # model with different encoders (fittable)
-            model_other_enc_load = self.helper_create_DLinearModel(
-                work_dir=tmpdir_fn,
-                model_name="other_encoder_load",
-                add_encoders=encoders_other_past,
-            )
-            # cannot overwritte different declared encoders
-            with pytest.raises(ValueError):
-                model_other_enc_load.load_weights(
-                    model_path_manual,
-                    load_encoders=True,
-                    map_location="cpu",
-                )
-
-            # model with different encoders but same dimensions (fittable)
-            model_other_enc_noload = self.helper_create_DLinearModel(
-                work_dir=tmpdir_fn,
-                model_name="other_encoder_noload",
-                add_encoders=encoders_other_past,
-            )
-            model_other_enc_noload.load_weights(
+        # model with different encoders (fittable)
+        model_other_enc_load = self.helper_create_DLinearModel(
+            work_dir=tmpdir_fn,
+            model_name="other_encoder_load",
+            add_encoders=encoders_other_past,
+        )
+        # cannot overwrite differently declared encoders
+        with pytest.raises(ValueError):
+            model_other_enc_load.load_weights(
                 model_path_manual,
-                load_encoders=False,
+                load_encoders=True,
                 map_location="cpu",
             )
-            self.helper_equality_encoders(
-                model_other_enc_noload.add_encoders, encoders_other_past
-            )
-            self.helper_equality_encoders_transfo(
-                model_other_enc_noload.add_encoders, encoders_other_past
-            )
-            # new encoders were instantiated
-            assert isinstance(model_other_enc_noload.encoders, SequentialEncoder)
-            # since fit() was not called, new fittable encoders were not trained
-            with pytest.raises(ValueError):
-                model_other_enc_noload.predict(n=4, series=self.series)
 
-            # predict() can be called after fit()
-            model_other_enc_noload.fit(self.series, epochs=1)
+        # model with different encoders but same dimensions (fittable)
+        model_other_enc_noload = self.helper_create_DLinearModel(
+            work_dir=tmpdir_fn,
+            model_name="other_encoder_noload",
+            add_encoders=encoders_other_past,
+        )
+        model_other_enc_noload.load_weights(
+            model_path_manual,
+            load_encoders=False,
+            map_location="cpu",
+        )
+        self.helper_equality_encoders(
+            model_other_enc_noload.add_encoders, encoders_other_past
+        )
+        self.helper_equality_encoders_transfo(
+            model_other_enc_noload.add_encoders, encoders_other_past
+        )
+        # new encoders were instantiated
+        assert isinstance(model_other_enc_noload.encoders, SequentialEncoder)
+        # since fit() was not called, new fittable encoders were not trained
+        with pytest.raises(ValueError):
             model_other_enc_noload.predict(n=4, series=self.series)
 
-            # model with same encoders but no scaler (non-fittable)
-            model_new_enc_noscaler_noload = self.helper_create_DLinearModel(
-                work_dir=tmpdir_fn,
-                model_name="same_encoder_noscaler",
-                add_encoders=encoders_past_noscaler,
-            )
-            model_new_enc_noscaler_noload.load_weights(
-                model_path_manual,
-                load_encoders=False,
-                map_location="cpu",
-            )
-
-            self.helper_equality_encoders(
-                model_new_enc_noscaler_noload.add_encoders, encoders_past_noscaler
-            )
-            self.helper_equality_encoders_transfo(
-                model_new_enc_noscaler_noload.add_encoders, encoders_past_noscaler
-            )
-            # predict() can be called directly since new encoders don't contain scaler
-            model_new_enc_noscaler_noload.predict(n=4, series=self.series)
+        # predict() can be called after fit()
+        model_other_enc_noload.fit(self.series, epochs=1)
+        model_other_enc_noload.predict(n=4, series=self.series)
 
-            # model with same encoders but different transformer (fittable)
-            model_new_enc_other_transformer = self.helper_create_DLinearModel(
-                work_dir=tmpdir_fn,
-                model_name="same_encoder_other_transform",
-                add_encoders=encoders_past_other_transformer,
-            )
-            # cannot overwritte different declared encoders
-            with pytest.raises(ValueError):
-                model_new_enc_other_transformer.load_weights(
-                    model_path_manual,
-                    load_encoders=True,
-                    map_location="cpu",
-                )
+        # model with same encoders but no scaler (non-fittable)
+        model_new_enc_noscaler_noload = self.helper_create_DLinearModel(
+            work_dir=tmpdir_fn,
+            model_name="same_encoder_noscaler",
+            add_encoders=encoders_past_noscaler,
+        )
+        model_new_enc_noscaler_noload.load_weights(
+            model_path_manual,
+            load_encoders=False,
+            map_location="cpu",
+        )
 
+        self.helper_equality_encoders(
+            model_new_enc_noscaler_noload.add_encoders, encoders_past_noscaler
+        )
+        self.helper_equality_encoders_transfo(
+            model_new_enc_noscaler_noload.add_encoders, encoders_past_noscaler
+        )
+        # predict() can be called directly since the new encoders don't contain a scaler
+        model_new_enc_noscaler_noload.predict(n=4, series=self.series)
+
+        # model with same encoders but different transformer (fittable)
+        model_new_enc_other_transformer = self.helper_create_DLinearModel(
+            work_dir=tmpdir_fn,
+            model_name="same_encoder_other_transform",
+            add_encoders=encoders_past_other_transformer,
+        )
+        # cannot overwrite encoders that were declared differently
+        with pytest.raises(ValueError):
             model_new_enc_other_transformer.load_weights(
                 model_path_manual,
-                load_encoders=False,
+                load_encoders=True,
                 map_location="cpu",
             )
-            # since fit() was not called, new fittable encoders were not trained
-            with pytest.raises(ValueError):
-                model_new_enc_other_transformer.predict(n=4, series=self.series)
 
-            # predict() can be called after fit()
-            model_new_enc_other_transformer.fit(self.series, epochs=1)
+        model_new_enc_other_transformer.load_weights(
+            model_path_manual,
+            load_encoders=False,
+            map_location="cpu",
+        )
+        # since fit() was not called, new fittable encoders were not trained
+        with pytest.raises(ValueError):
             model_new_enc_other_transformer.predict(n=4, series=self.series)
 
-            # model with encoders containing more components (fittable)
-            model_new_enc_2_past = self.helper_create_DLinearModel(
-                work_dir=tmpdir_fn,
-                model_name="encoder_2_components_past",
-                add_encoders=encoders_2_past,
-            )
-            # cannot overwritte different declared encoders
-            with pytest.raises(ValueError):
-                model_new_enc_2_past.load_weights(
-                    model_path_manual,
-                    load_encoders=True,
-                    map_location="cpu",
-                )
-            # new encoders have one additional past component
-            with pytest.raises(ValueError):
-                model_new_enc_2_past.load_weights(
-                    model_path_manual,
-                    load_encoders=False,
-                    map_location="cpu",
-                )
-
-            # model with encoders containing past and future covs (fittable)
-            model_new_enc_past_n_future = self.helper_create_DLinearModel(
-                work_dir=tmpdir_fn,
-                model_name="encoder_past_n_future",
-                add_encoders=encoders_past_n_future,
-            )
-            # cannot overwritte different declared encoders
-            with pytest.raises(ValueError):
-                model_new_enc_past_n_future.load_weights(
-                    model_path_manual,
-                    load_encoders=True,
-                    map_location="cpu",
-                )
-            # identical past components, but different future components
-            with pytest.raises(ValueError):
-                model_new_enc_past_n_future.load_weights(
-                    model_path_manual,
-                    load_encoders=False,
-                    map_location="cpu",
-                )
-
-        def test_save_and_load_weights_w_likelihood(self, tmpdir_fn):
-            """
-            Verify that save/load does not break likelihood.
-
-            Note: since load_weights() calls load_weights_from_checkpoint(), it will be used
-            for all but one test.
-            Note: Using DLinear since it supports both past and future covariates
-            """
-            model_dir = os.path.join(tmpdir_fn)
-            manual_name = "save_manual"
-            auto_name = "save_auto"
-            # create manually saved model checkpoints folder
-            checkpoint_path_manual = os.path.join(model_dir, manual_name)
-            os.mkdir(checkpoint_path_manual)
-            checkpoint_file_name = "checkpoint_0.pth.tar"
-            model_path_manual = os.path.join(
-                checkpoint_path_manual, checkpoint_file_name
-            )
-
-            model_auto_save = self.helper_create_DLinearModel(
-                work_dir=tmpdir_fn,
-                model_name=auto_name,
-                save_checkpoints=True,
-                likelihood=GaussianLikelihood(prior_mu=0.5),
-            )
-            model_auto_save.fit(self.series, epochs=1)
-            pred_auto = model_auto_save.predict(n=4, series=self.series)
-
-            model_manual_save = self.helper_create_DLinearModel(
-                work_dir=tmpdir_fn,
-                model_name=manual_name,
-                save_checkpoints=False,
-                likelihood=GaussianLikelihood(prior_mu=0.5),
-            )
-            model_manual_save.fit(self.series, epochs=1)
-            model_manual_save.save(model_path_manual)
-            pred_manual = model_manual_save.predict(n=4, series=self.series)
+        # predict() can be called after fit()
+        model_new_enc_other_transformer.fit(self.series, epochs=1)
+        model_new_enc_other_transformer.predict(n=4, series=self.series)
 
-            # predictions are identical when using the same likelihood
-            assert np.array_equal(pred_auto.values(), pred_manual.values())
-
-            # model with identical likelihood
-            model_same_likelihood = self.helper_create_DLinearModel(
-                work_dir=tmpdir_fn,
-                model_name="same_likelihood",
-                likelihood=GaussianLikelihood(prior_mu=0.5),
-            )
-            model_same_likelihood.load_weights(model_path_manual, map_location="cpu")
-            model_same_likelihood.predict(n=4, series=self.series)
-            # cannot check predictions since this model is not fitted, random state is different
-
-            # loading models weights with respective methods
-            model_manual_same_likelihood = self.helper_create_DLinearModel(
-                work_dir=tmpdir_fn,
-                model_name="same_likelihood",
-                likelihood=GaussianLikelihood(prior_mu=0.5),
-            )
-            model_manual_same_likelihood.load_weights(
-                model_path_manual, map_location="cpu"
-            )
-            preds_manual_from_weights = model_manual_same_likelihood.predict(
-                n=4, series=self.series
-            )
-
-            model_auto_same_likelihood = self.helper_create_DLinearModel(
-                work_dir=tmpdir_fn,
-                model_name="same_likelihood",
-                likelihood=GaussianLikelihood(prior_mu=0.5),
-            )
-            model_auto_same_likelihood.load_weights_from_checkpoint(
-                auto_name, work_dir=tmpdir_fn, best=False, map_location="cpu"
-            )
-            preds_auto_from_weights = model_auto_same_likelihood.predict(
-                n=4, series=self.series
-            )
-            # check that weights from checkpoint give identical predictions as weights from manual save
-            assert preds_manual_from_weights == preds_auto_from_weights
-            # model with explicitely no likelihood
-            model_no_likelihood = self.helper_create_DLinearModel(
-                work_dir=tmpdir_fn, model_name="no_likelihood", likelihood=None
-            )
-            with pytest.raises(ValueError) as error_msg:
-                model_no_likelihood.load_weights_from_checkpoint(
-                    auto_name,
-                    work_dir=tmpdir_fn,
-                    best=False,
-                    map_location="cpu",
-                )
-            assert str(error_msg.value).startswith(
-                "The values of the hyper-parameters in the model and loaded checkpoint should be identical.\n"
-                "incorrect"
-            )
-
-            # model with missing likelihood (as if user forgot them)
-            model_no_likelihood_bis = DLinearModel(
-                input_chunk_length=4,
-                output_chunk_length=1,
-                model_name="no_likelihood_bis",
-                add_encoders=None,
-                work_dir=tmpdir_fn,
-                save_checkpoints=False,
-                random_state=42,
-                force_reset=True,
-                n_epochs=1,
-                # likelihood=likelihood,
-                **tfm_kwargs,
-            )
-            with pytest.raises(ValueError) as error_msg:
-                model_no_likelihood_bis.load_weights_from_checkpoint(
-                    auto_name,
-                    work_dir=tmpdir_fn,
-                    best=False,
-                    map_location="cpu",
-                )
-            assert str(error_msg.value).startswith(
-                "The values of the hyper-parameters in the model and loaded checkpoint should be identical.\n"
-                "missing"
-            )
-
-            # model with a different likelihood
-            model_other_likelihood = self.helper_create_DLinearModel(
-                work_dir=tmpdir_fn,
-                model_name="other_likelihood",
-                likelihood=LaplaceLikelihood(),
-            )
-            with pytest.raises(ValueError) as error_msg:
-                model_other_likelihood.load_weights(
-                    model_path_manual, map_location="cpu"
-                )
-            assert str(error_msg.value).startswith(
-                "The values of the hyper-parameters in the model and loaded checkpoint should be identical.\n"
-                "incorrect"
-            )
-
-            # model with the same likelihood but different parameters
-            model_same_likelihood_other_prior = self.helper_create_DLinearModel(
-                work_dir=tmpdir_fn,
-                model_name="same_likelihood_other_prior",
-                likelihood=GaussianLikelihood(),
-            )
-            with pytest.raises(ValueError) as error_msg:
-                model_same_likelihood_other_prior.load_weights(
-                    model_path_manual, map_location="cpu"
-                )
-            assert str(error_msg.value).startswith(
-                "The values of the hyper-parameters in the model and loaded checkpoint should be identical.\n"
-                "incorrect"
-            )
-
-        def test_load_weights_params_check(self, tmpdir_fn):
-            """
-            Verify that the method comparing the parameters between the saved model and the loading model
-            behave as expected, used to return meaningful error message instead of the torch.load ones.
-            """
-            model_name = "params_check"
-            ckpt_path = os.path.join(tmpdir_fn, f"{model_name}.pt")
-            # barebone model
-            model = DLinearModel(
-                input_chunk_length=4, output_chunk_length=1, n_epochs=1, **tfm_kwargs
-            )
-            model.fit(self.series[:10])
-            model.save(ckpt_path)
-
-            # identical model
-            loading_model = DLinearModel(
-                input_chunk_length=4, output_chunk_length=1, **tfm_kwargs
-            )
-            loading_model.load_weights(ckpt_path)
-
-            # different optimizer
-            loading_model = DLinearModel(
-                input_chunk_length=4,
-                output_chunk_length=1,
-                optimizer_cls=torch.optim.AdamW,
-                **tfm_kwargs,
-            )
-            loading_model.load_weights(ckpt_path)
-
-            model_summary_kwargs = {
-                "pl_trainer_kwargs": dict(
-                    {"enable_model_sumamry": False}, **tfm_kwargs["pl_trainer_kwargs"]
-                )
-            }
-            # different pl_trainer_kwargs
-            loading_model = DLinearModel(
-                input_chunk_length=4,
-                output_chunk_length=1,
-                **model_summary_kwargs,
-            )
-            loading_model.load_weights(ckpt_path)
-
-            # different input_chunk_length (tfm parameter)
-            loading_model = DLinearModel(
-                input_chunk_length=4 + 1, output_chunk_length=1, **tfm_kwargs
+        # model with encoders containing more components (fittable)
+        model_new_enc_2_past = self.helper_create_DLinearModel(
+            work_dir=tmpdir_fn,
+            model_name="encoder_2_components_past",
+            add_encoders=encoders_2_past,
+        )
+        # cannot overwrite encoders that were declared differently
+        with pytest.raises(ValueError):
+            model_new_enc_2_past.load_weights(
+                model_path_manual,
+                load_encoders=True,
+                map_location="cpu",
             )
-            with pytest.raises(ValueError) as error_msg:
-                loading_model.load_weights(ckpt_path)
-            assert str(error_msg.value).startswith(
-                "The values of the hyper-parameters in the model and loaded checkpoint should be identical.\n"
-                "incorrect"
+        # new encoders have one additional past component
+        with pytest.raises(ValueError):
+            model_new_enc_2_past.load_weights(
+                model_path_manual,
+                load_encoders=False,
+                map_location="cpu",
             )
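+        # the extra past component changes the expected input width, so the checkpoint weights no longer fit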
 
-            # different kernel size (cls specific parameter)
-            loading_model = DLinearModel(
-                input_chunk_length=4,
-                output_chunk_length=1,
-                kernel_size=10,
-                **tfm_kwargs,
-            )
-            with pytest.raises(ValueError) as error_msg:
-                loading_model.load_weights(ckpt_path)
-            assert str(error_msg.value).startswith(
-                "The values of the hyper-parameters in the model and loaded checkpoint should be identical.\n"
-                "incorrect"
+        # model with encoders containing past and future covs (fittable)
+        model_new_enc_past_n_future = self.helper_create_DLinearModel(
+            work_dir=tmpdir_fn,
+            model_name="encoder_past_n_future",
+            add_encoders=encoders_past_n_future,
+        )
+        # cannot overwrite encoders that were declared differently
+        with pytest.raises(ValueError):
+            model_new_enc_past_n_future.load_weights(
+                model_path_manual,
+                load_encoders=True,
+                map_location="cpu",
             )
-
-        def test_create_instance_new_model_no_name_set(self, tmpdir_fn):
-            RNNModel(12, "RNN", 10, 10, work_dir=tmpdir_fn, **tfm_kwargs)
-            # no exception is raised
-
-        def test_create_instance_existing_model_with_name_no_fit(self, tmpdir_fn):
-            model_name = "test_model"
-            RNNModel(
-                12,
-                "RNN",
-                10,
-                10,
-                work_dir=tmpdir_fn,
-                model_name=model_name,
-                **tfm_kwargs,
+        # identical past components, but different future components
+        with pytest.raises(ValueError):
+            model_new_enc_past_n_future.load_weights(
+                model_path_manual,
+                load_encoders=False,
+                map_location="cpu",
             )
-            # no exception is raised
 
-        @patch(
-            "darts.models.forecasting.torch_forecasting_model.TorchForecastingModel.reset_model"
+    def test_save_and_load_weights_w_likelihood(self, tmpdir_fn):
+        """
+        Verify that save/load does not break likelihood.
+
+        Note: since load_weights() calls load_weights_from_checkpoint(), the latter is
+        exercised by all but one test.
+        Note: DLinearModel is used since it supports both past and future covariates.
+        """
+        model_dir = os.path.join(tmpdir_fn)
+        manual_name = "save_manual"
+        auto_name = "save_auto"
+        # create the folder for manually saved model checkpoints
+        checkpoint_path_manual = os.path.join(model_dir, manual_name)
+        os.mkdir(checkpoint_path_manual)
+        checkpoint_file_name = "checkpoint_0.pth.tar"
+        model_path_manual = os.path.join(checkpoint_path_manual, checkpoint_file_name)
+
+        model_auto_save = self.helper_create_DLinearModel(
+            work_dir=tmpdir_fn,
+            model_name=auto_name,
+            save_checkpoints=True,
+            likelihood=GaussianLikelihood(prior_mu=0.5),
+        )
+        model_auto_save.fit(self.series, epochs=1)
+        pred_auto = model_auto_save.predict(n=4, series=self.series)
+
+        model_manual_save = self.helper_create_DLinearModel(
+            work_dir=tmpdir_fn,
+            model_name=manual_name,
+            save_checkpoints=False,
+            likelihood=GaussianLikelihood(prior_mu=0.5),
+        )
+        model_manual_save.fit(self.series, epochs=1)
+        model_manual_save.save(model_path_manual)
+        pred_manual = model_manual_save.predict(n=4, series=self.series)
+
+        # predictions are identical when using the same likelihood
+        assert np.array_equal(pred_auto.values(), pred_manual.values())
+
+        # model with identical likelihood
+        model_same_likelihood = self.helper_create_DLinearModel(
+            work_dir=tmpdir_fn,
+            model_name="same_likelihood",
+            likelihood=GaussianLikelihood(prior_mu=0.5),
+        )
+        model_same_likelihood.load_weights(model_path_manual, map_location="cpu")
+        model_same_likelihood.predict(n=4, series=self.series)
+        # predictions cannot be compared since this model was not fitted and its random state differs
+
+        # loading models weights with respective methods
+        model_manual_same_likelihood = self.helper_create_DLinearModel(
+            work_dir=tmpdir_fn,
+            model_name="same_likelihood",
+            likelihood=GaussianLikelihood(prior_mu=0.5),
+        )
+        model_manual_same_likelihood.load_weights(model_path_manual, map_location="cpu")
+        preds_manual_from_weights = model_manual_same_likelihood.predict(
+            n=4, series=self.series
         )
-        def test_create_instance_existing_model_with_name_force(
-            self, patch_reset_model, tmpdir_fn
-        ):
-            model_name = "test_model"
-            RNNModel(
-                12,
-                "RNN",
-                10,
-                10,
-                work_dir=tmpdir_fn,
-                model_name=model_name,
-                **tfm_kwargs,
-            )
-            # no exception is raised
-            # since no fit, there is no data stored for the model, hence `force_reset` does noting
 
-            RNNModel(
-                12,
-                "RNN",
-                10,
-                10,
+        model_auto_same_likelihood = self.helper_create_DLinearModel(
+            work_dir=tmpdir_fn,
+            model_name="same_likelihood",
+            likelihood=GaussianLikelihood(prior_mu=0.5),
+        )
+        model_auto_same_likelihood.load_weights_from_checkpoint(
+            auto_name, work_dir=tmpdir_fn, best=False, map_location="cpu"
+        )
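+        # best=False loads the most recent checkpoint rather than the best-performing one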
+        preds_auto_from_weights = model_auto_same_likelihood.predict(
+            n=4, series=self.series
+        )
+        # check that weights from checkpoint give identical predictions to weights from manual save
+        assert preds_manual_from_weights == preds_auto_from_weights
+
+        # model with explicitly no likelihood
+        model_no_likelihood = self.helper_create_DLinearModel(
+            work_dir=tmpdir_fn, model_name="no_likelihood", likelihood=None
+        )
+        with pytest.raises(ValueError) as error_msg:
+            model_no_likelihood.load_weights_from_checkpoint(
+                auto_name,
                 work_dir=tmpdir_fn,
-                model_name=model_name,
-                force_reset=True,
-                **tfm_kwargs,
+                best=False,
+                map_location="cpu",
             )
-            patch_reset_model.assert_not_called()
+        assert str(error_msg.value).startswith(
+            "The values of the hyper-parameters in the model and loaded checkpoint should be identical.\n"
+            "incorrect"
+        )
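+        # the message starts with "incorrect" when hyper-parameter values differ from the checkpoint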
 
-        @patch(
-            "darts.models.forecasting.torch_forecasting_model.TorchForecastingModel.reset_model"
+        # model with missing likelihood (as if the user forgot it)
+        model_no_likelihood_bis = DLinearModel(
+            input_chunk_length=4,
+            output_chunk_length=1,
+            model_name="no_likelihood_bis",
+            add_encoders=None,
+            work_dir=tmpdir_fn,
+            save_checkpoints=False,
+            random_state=42,
+            force_reset=True,
+            n_epochs=1,
+            # likelihood=likelihood,
+            **tfm_kwargs,
         )
-        def test_create_instance_existing_model_with_name_force_fit_with_reset(
-            self, patch_reset_model, tmpdir_fn
-        ):
-            model_name = "test_model"
-            model1 = RNNModel(
-                12,
-                "RNN",
-                10,
-                10,
+        with pytest.raises(ValueError) as error_msg:
+            model_no_likelihood_bis.load_weights_from_checkpoint(
+                auto_name,
                 work_dir=tmpdir_fn,
-                model_name=model_name,
-                save_checkpoints=True,
-                **tfm_kwargs,
+                best=False,
+                map_location="cpu",
             )
-            # no exception is raised
+        assert str(error_msg.value).startswith(
+            "The values of the hyper-parameters in the model and loaded checkpoint should be identical.\n"
+            "missing"
+        )
 
-            model1.fit(self.series, epochs=1)
+        # model with a different likelihood
+        model_other_likelihood = self.helper_create_DLinearModel(
+            work_dir=tmpdir_fn,
+            model_name="other_likelihood",
+            likelihood=LaplaceLikelihood(),
+        )
+        with pytest.raises(ValueError) as error_msg:
+            model_other_likelihood.load_weights(model_path_manual, map_location="cpu")
+        assert str(error_msg.value).startswith(
+            "The values of the hyper-parameters in the model and loaded checkpoint should be identical.\n"
+            "incorrect"
+        )
 
-            RNNModel(
-                12,
-                "RNN",
-                10,
-                10,
-                work_dir=tmpdir_fn,
-                model_name=model_name,
-                save_checkpoints=True,
-                force_reset=True,
-                **tfm_kwargs,
+        # model with the same likelihood but different parameters
+        model_same_likelihood_other_prior = self.helper_create_DLinearModel(
+            work_dir=tmpdir_fn,
+            model_name="same_likelihood_other_prior",
+            likelihood=GaussianLikelihood(),
+        )
+        with pytest.raises(ValueError) as error_msg:
+            model_same_likelihood_other_prior.load_weights(
+                model_path_manual, map_location="cpu"
             )
-            patch_reset_model.assert_called_once()
+        assert str(error_msg.value).startswith(
+            "The values of the hyper-parameters in the model and loaded checkpoint should be identical.\n"
+            "incorrect"
+        )
 
-        # TODO for PTL: currently we (have to (?)) create a mew PTL trainer object every time fit() is called which
-        #  resets some of the model's attributes such as epoch and step counts. We have check whether there is another
-        #  way of doing this.
+    def test_load_weights_params_check(self, tmpdir_fn):
+        """
+        Verify that the method comparing the parameters between the saved model and the loading model
+        behaves as expected; it should return meaningful error messages instead of the torch.load ones.
+        """
+        model_name = "params_check"
+        ckpt_path = os.path.join(tmpdir_fn, f"{model_name}.pt")
+        # barebone model
+        model = DLinearModel(
+            input_chunk_length=4, output_chunk_length=1, n_epochs=1, **tfm_kwargs
+        )
+        model.fit(self.series[:10])
+        model.save(ckpt_path)
 
-        # n_epochs=20, fit|epochs=None, epochs_trained=0 - train for 20 epochs
-        def test_train_from_0_n_epochs_20_no_fit_epochs(self):
-            model1 = RNNModel(
-                12,
-                "RNN",
-                10,
-                10,
-                n_epochs=20,
-                **tfm_kwargs,
-            )
+        # identical model
+        loading_model = DLinearModel(
+            input_chunk_length=4, output_chunk_length=1, **tfm_kwargs
+        )
+        loading_model.load_weights(ckpt_path)
+
+        # different optimizer
+        loading_model = DLinearModel(
+            input_chunk_length=4,
+            output_chunk_length=1,
+            optimizer_cls=torch.optim.AdamW,
+            **tfm_kwargs,
+        )
+        loading_model.load_weights(ckpt_path)
+
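+        # trainer kwargs are merged with an extra flag below; pl_trainer_kwargs are not part of the weight-compatibility check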
+        model_summary_kwargs = {
+            "pl_trainer_kwargs": dict(
+                {"enable_model_summary": False}, **tfm_kwargs["pl_trainer_kwargs"]
+            )
+        }
+        # different pl_trainer_kwargs
+        loading_model = DLinearModel(
+            input_chunk_length=4,
+            output_chunk_length=1,
+            **model_summary_kwargs,
+        )
+        loading_model.load_weights(ckpt_path)
 
-            model1.fit(self.series)
+        # different input_chunk_length (tfm parameter)
+        loading_model = DLinearModel(
+            input_chunk_length=4 + 1, output_chunk_length=1, **tfm_kwargs
+        )
+        with pytest.raises(ValueError) as error_msg:
+            loading_model.load_weights(ckpt_path)
+        assert str(error_msg.value).startswith(
+            "The values of the hyper-parameters in the model and loaded checkpoint should be identical.\n"
+            "incorrect"
+        )
 
-            assert 20 == model1.epochs_trained
+        # different kernel size (cls specific parameter)
+        loading_model = DLinearModel(
+            input_chunk_length=4,
+            output_chunk_length=1,
+            kernel_size=10,
+            **tfm_kwargs,
+        )
+        with pytest.raises(ValueError) as error_msg:
+            loading_model.load_weights(ckpt_path)
+        assert str(error_msg.value).startswith(
+            "The values of the hyper-parameters in the model and loaded checkpoint should be identical.\n"
+            "incorrect"
+        )
 
-        # n_epochs = 20, fit|epochs=None, epochs_trained=20 - train for another 20 epochs
-        def test_train_from_20_n_epochs_40_no_fit_epochs(self):
-            model1 = RNNModel(
-                12,
-                "RNN",
-                10,
-                10,
-                n_epochs=20,
-                **tfm_kwargs,
-            )
+    def test_create_instance_new_model_no_name_set(self, tmpdir_fn):
+        RNNModel(12, "RNN", 10, 10, work_dir=tmpdir_fn, **tfm_kwargs)
+        # no exception is raised
+
+    def test_create_instance_existing_model_with_name_no_fit(self, tmpdir_fn):
+        model_name = "test_model"
+        RNNModel(
+            12,
+            "RNN",
+            10,
+            10,
+            work_dir=tmpdir_fn,
+            model_name=model_name,
+            **tfm_kwargs,
+        )
+        # no exception is raised
 
-            model1.fit(self.series)
-            assert 20 == model1.epochs_trained
+    @patch(
+        "darts.models.forecasting.torch_forecasting_model.TorchForecastingModel.reset_model"
+    )
+    def test_create_instance_existing_model_with_name_force(
+        self, patch_reset_model, tmpdir_fn
+    ):
+        model_name = "test_model"
+        RNNModel(
+            12,
+            "RNN",
+            10,
+            10,
+            work_dir=tmpdir_fn,
+            model_name=model_name,
+            **tfm_kwargs,
+        )
+        # no exception is raised
+        # since fit() was not called, no data is stored for the model, hence `force_reset` does nothing
+
+        RNNModel(
+            12,
+            "RNN",
+            10,
+            10,
+            work_dir=tmpdir_fn,
+            model_name=model_name,
+            force_reset=True,
+            **tfm_kwargs,
+        )
+        patch_reset_model.assert_not_called()
 
-            model1.fit(self.series)
-            assert 20 == model1.epochs_trained
+    @patch(
+        "darts.models.forecasting.torch_forecasting_model.TorchForecastingModel.reset_model"
+    )
+    def test_create_instance_existing_model_with_name_force_fit_with_reset(
+        self, patch_reset_model, tmpdir_fn
+    ):
+        model_name = "test_model"
+        model1 = RNNModel(
+            12,
+            "RNN",
+            10,
+            10,
+            work_dir=tmpdir_fn,
+            model_name=model_name,
+            save_checkpoints=True,
+            **tfm_kwargs,
+        )
+        # no exception is raised
+
+        model1.fit(self.series, epochs=1)
+
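+        # this time fit() stored checkpoints under the model name, so force_reset must reset the model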
+        RNNModel(
+            12,
+            "RNN",
+            10,
+            10,
+            work_dir=tmpdir_fn,
+            model_name=model_name,
+            save_checkpoints=True,
+            force_reset=True,
+            **tfm_kwargs,
+        )
+        patch_reset_model.assert_called_once()
+
+    # TODO for PTL: currently we (have to (?)) create a new PTL trainer object every time fit() is called, which
+    #  resets some of the model's attributes such as epoch and step counts. We have to check whether there is
+    #  another way of doing this.
+
+    # n_epochs=20, fit|epochs=None, epochs_trained=0 - train for 20 epochs
+    def test_train_from_0_n_epochs_20_no_fit_epochs(self):
+        model1 = RNNModel(
+            12,
+            "RNN",
+            10,
+            10,
+            n_epochs=20,
+            **tfm_kwargs,
+        )
 
-        # n_epochs = 20, fit|epochs=None, epochs_trained=10 - train for another 20 epochs
-        def test_train_from_10_n_epochs_20_no_fit_epochs(self):
-            model1 = RNNModel(
-                12,
-                "RNN",
-                10,
-                10,
-                n_epochs=20,
-                **tfm_kwargs,
-            )
+        model1.fit(self.series)
 
-            # simulate the case that user interrupted training with Ctrl-C after 10 epochs
-            model1.fit(self.series, epochs=10)
-            assert 10 == model1.epochs_trained
+        assert 20 == model1.epochs_trained
 
-            model1.fit(self.series)
-            assert 20 == model1.epochs_trained
+    # n_epochs = 20, fit|epochs=None, epochs_trained=20 - train for another 20 epochs
+    def test_train_from_20_n_epochs_40_no_fit_epochs(self):
+        model1 = RNNModel(
+            12,
+            "RNN",
+            10,
+            10,
+            n_epochs=20,
+            **tfm_kwargs,
+        )
 
-        # n_epochs = 20, fit|epochs=15, epochs_trained=10 - train for 15 epochs
-        def test_train_from_10_n_epochs_20_fit_15_epochs(self):
-            model1 = RNNModel(
-                12,
-                "RNN",
-                10,
-                10,
-                n_epochs=20,
-                **tfm_kwargs,
-            )
+        model1.fit(self.series)
+        assert 20 == model1.epochs_trained
+
+        model1.fit(self.series)
+        assert 20 == model1.epochs_trained
+
+    # n_epochs = 20, fit|epochs=None, epochs_trained=10 - train for another 20 epochs
+    def test_train_from_10_n_epochs_20_no_fit_epochs(self):
+        model1 = RNNModel(
+            12,
+            "RNN",
+            10,
+            10,
+            n_epochs=20,
+            **tfm_kwargs,
+        )
 
-            # simulate the case that user interrupted training with Ctrl-C after 10 epochs
-            model1.fit(self.series, epochs=10)
-            assert 10 == model1.epochs_trained
+        # simulate the case where the user interrupted training with Ctrl-C after 10 epochs
+        model1.fit(self.series, epochs=10)
+        assert 10 == model1.epochs_trained
+
+        model1.fit(self.series)
+        assert 20 == model1.epochs_trained
+
+    # n_epochs = 20, fit|epochs=15, epochs_trained=10 - train for 15 epochs
+    def test_train_from_10_n_epochs_20_fit_15_epochs(self):
+        model1 = RNNModel(
+            12,
+            "RNN",
+            10,
+            10,
+            n_epochs=20,
+            **tfm_kwargs,
+        )
 
-            model1.fit(self.series, epochs=15)
-            assert 15 == model1.epochs_trained
+        # simulate the case where the user interrupted training with Ctrl-C after 10 epochs
+        model1.fit(self.series, epochs=10)
+        assert 10 == model1.epochs_trained
+
+        model1.fit(self.series, epochs=15)
+        assert 15 == model1.epochs_trained
+
+    def test_load_weights_from_checkpoint(self, tmpdir_fn):
+        ts_training, ts_test = self.series.split_before(90)
+        original_model_name = "original"
+        retrained_model_name = "retrained"
+        # original model, checkpoints are saved
+        model = RNNModel(
+            12,
+            "RNN",
+            5,
+            1,
+            n_epochs=5,
+            work_dir=tmpdir_fn,
+            save_checkpoints=True,
+            model_name=original_model_name,
+            random_state=1,
+            **tfm_kwargs,
+        )
+        model.fit(ts_training)
+        original_preds = model.predict(10)
+        original_mape = mape(original_preds, ts_test)
+
+        # load last checkpoint of original model, then train it further
+        model_rt = RNNModel(
+            12,
+            "RNN",
+            5,
+            1,
+            n_epochs=5,
+            work_dir=tmpdir_fn,
+            model_name=retrained_model_name,
+            random_state=1,
+            **tfm_kwargs,
+        )
+        model_rt.load_weights_from_checkpoint(
+            model_name=original_model_name,
+            work_dir=tmpdir_fn,
+            best=False,
+            map_location="cpu",
+        )
 
-        def test_load_weights_from_checkpoint(self, tmpdir_fn):
-            ts_training, ts_test = self.series.split_before(90)
-            original_model_name = "original"
-            retrained_model_name = "retrained"
-            # original model, checkpoints are saved
-            model = RNNModel(
-                12,
-                "RNN",
-                5,
-                1,
-                n_epochs=5,
-                work_dir=tmpdir_fn,
-                save_checkpoints=True,
-                model_name=original_model_name,
-                random_state=1,
-                **tfm_kwargs,
-            )
-            model.fit(ts_training)
-            original_preds = model.predict(10)
-            original_mape = mape(original_preds, ts_test)
+        # series must be passed since self.training_series is not restored when loading only weights
+        loaded_preds = model_rt.predict(10, ts_training)
+        # save/load checkpoint should produce identical predictions
+        assert original_preds == loaded_preds
+
+        model_rt.fit(ts_training)
+        retrained_preds = model_rt.predict(10)
+        retrained_mape = mape(retrained_preds, ts_test)
+        assert retrained_mape < original_mape, (
+            f"Retrained model has a greater error (mape) than the original model, "
+            f"respectively {retrained_mape} and {original_mape}"
+        )
 
-            # load last checkpoint of original model, train it for 2 additional epochs
+        # raise Exception when trying to load ckpt weights into a different architecture
+        with pytest.raises(ValueError):
             model_rt = RNNModel(
                 12,
                 "RNN",
+                10,  # the saved model has hidden_dim=5, not 10
                 5,
-                1,
-                n_epochs=5,
-                work_dir=tmpdir_fn,
-                model_name=retrained_model_name,
-                random_state=1,
-                **tfm_kwargs,
             )
             model_rt.load_weights_from_checkpoint(
                 model_name=original_model_name,
@@ -968,749 +977,710 @@ def test_load_weights_from_checkpoint(self, tmpdir_fn):
                 map_location="cpu",
             )
 
-            # must indicate series otherwise self.training_series must be saved in checkpoint
-            loaded_preds = model_rt.predict(10, ts_training)
-            # save/load checkpoint should produce identical predictions
-            assert original_preds == loaded_preds
-
-            model_rt.fit(ts_training)
-            retrained_preds = model_rt.predict(10)
-            retrained_mape = mape(retrained_preds, ts_test)
-            assert retrained_mape < original_mape, (
-                f"Retrained model has a greater error (mape) than the original model, "
-                f"respectively {retrained_mape} and {original_mape}"
-            )
-
-            # raise Exception when trying to load ckpt weights in different architecture
-            with pytest.raises(ValueError):
-                model_rt = RNNModel(
-                    12,
-                    "RNN",
-                    10,  # loaded model has only 5 hidden_layers
-                    5,
-                )
-                model_rt.load_weights_from_checkpoint(
-                    model_name=original_model_name,
-                    work_dir=tmpdir_fn,
-                    best=False,
-                    map_location="cpu",
-                )
-
-            # raise Exception when trying to pass `weights_only`=True to `torch.load()`
-            with pytest.raises(ValueError):
-                model_rt = RNNModel(12, "RNN", 5, 5, **tfm_kwargs)
-                model_rt.load_weights_from_checkpoint(
-                    model_name=original_model_name,
-                    work_dir=tmpdir_fn,
-                    best=False,
-                    weights_only=True,
-                    map_location="cpu",
-                )
-
-        def test_load_weights(self, tmpdir_fn):
-            ts_training, ts_test = self.series.split_before(90)
-            original_model_name = "original"
-            retrained_model_name = "retrained"
-            # original model, checkpoints are saved
-            model = RNNModel(
-                12,
-                "RNN",
-                5,
-                1,
-                n_epochs=5,
-                work_dir=tmpdir_fn,
-                save_checkpoints=False,
+        # raise Exception when trying to pass `weights_only`=True to `torch.load()`
+        with pytest.raises(ValueError):
+            model_rt = RNNModel(12, "RNN", 5, 5, **tfm_kwargs)
+            model_rt.load_weights_from_checkpoint(
                 model_name=original_model_name,
-                random_state=1,
-                **tfm_kwargs,
-            )
-            model.fit(ts_training)
-            path_manual_save = os.path.join(tmpdir_fn, "RNN_manual_save.pt")
-            model.save(path_manual_save)
-            original_preds = model.predict(10)
-            original_mape = mape(original_preds, ts_test)
-
-            # load last checkpoint of original model, train it for 2 additional epochs
-            model_rt = RNNModel(
-                12,
-                "RNN",
-                5,
-                1,
-                n_epochs=5,
                 work_dir=tmpdir_fn,
-                model_name=retrained_model_name,
-                random_state=1,
-                **tfm_kwargs,
-            )
-            model_rt.load_weights(path=path_manual_save, map_location="cpu")
-
-            # must indicate series otherwise self.training_series must be saved in checkpoint
-            loaded_preds = model_rt.predict(10, ts_training)
-            # save/load checkpoint should produce identical predictions
-            assert original_preds == loaded_preds
-
-            model_rt.fit(ts_training)
-            retrained_preds = model_rt.predict(10)
-            retrained_mape = mape(retrained_preds, ts_test)
-            assert retrained_mape < original_mape, (
-                f"Retrained model has a greater mape error than the original model, "
-                f"respectively {retrained_mape} and {original_mape}"
+                best=False,
+                weights_only=True,
+                map_location="cpu",
             )
 
-        def test_load_weights_with_float32_dtype(self, tmpdir_fn):
-            ts_float32 = self.series.astype("float32")
-            model_name = "test_model"
-            ckpt_path = os.path.join(tmpdir_fn, f"{model_name}.pt")
-            # barebone model
-            model = DLinearModel(
-                input_chunk_length=4,
-                output_chunk_length=1,
-                n_epochs=1,
-            )
-            model.fit(ts_float32)
-            model.save(ckpt_path)
-            assert model.model._dtype == torch.float32  # type: ignore
+    def test_load_weights(self, tmpdir_fn):
+        ts_training, ts_test = self.series.split_before(90)
+        original_model_name = "original"
+        retrained_model_name = "retrained"
+        # original model, checkpoints are saved
+        model = RNNModel(
+            12,
+            "RNN",
+            5,
+            1,
+            n_epochs=5,
+            work_dir=tmpdir_fn,
+            save_checkpoints=False,
+            model_name=original_model_name,
+            random_state=1,
+            **tfm_kwargs,
+        )
+        model.fit(ts_training)
+        path_manual_save = os.path.join(tmpdir_fn, "RNN_manual_save.pt")
+        model.save(path_manual_save)
+        original_preds = model.predict(10)
+        original_mape = mape(original_preds, ts_test)
+
+        # load the manually saved weights of the original model, then train it further
+        model_rt = RNNModel(
+            12,
+            "RNN",
+            5,
+            1,
+            n_epochs=5,
+            work_dir=tmpdir_fn,
+            model_name=retrained_model_name,
+            random_state=1,
+            **tfm_kwargs,
+        )
+        model_rt.load_weights(path=path_manual_save, map_location="cpu")
+
+        # series must be passed since self.training_series is not restored when loading only weights
+        loaded_preds = model_rt.predict(10, ts_training)
+        # save/load checkpoint should produce identical predictions
+        assert original_preds == loaded_preds
+
+        model_rt.fit(ts_training)
+        retrained_preds = model_rt.predict(10)
+        retrained_mape = mape(retrained_preds, ts_test)
+        assert retrained_mape < original_mape, (
+            f"Retrained model has a greater mape error than the original model, "
+            f"respectively {retrained_mape} and {original_mape}"
+        )
 
-            # identical model
-            loading_model = DLinearModel(
-                input_chunk_length=4,
-                output_chunk_length=1,
-            )
-            loading_model.load_weights(ckpt_path)
-            loading_model.fit(ts_float32)
-            assert loading_model.model._dtype == torch.float32  # type: ignore
-
-        def test_multi_steps_pipeline(self, tmpdir_fn):
-            ts_training, ts_val = self.series.split_before(75)
-            pretrain_model_name = "pre-train"
-            retrained_model_name = "re-train"
-
-            # pretraining
-            model = self.helper_create_RNNModel(pretrain_model_name, tmpdir_fn)
-            model.fit(
-                ts_training,
-                val_series=ts_val,
-            )
+    def test_load_weights_with_float32_dtype(self, tmpdir_fn):
+        ts_float32 = self.series.astype("float32")
+        model_name = "test_model"
+        ckpt_path = os.path.join(tmpdir_fn, f"{model_name}.pt")
+        # barebone model
+        model = DLinearModel(
+            input_chunk_length=4,
+            output_chunk_length=1,
+            n_epochs=1,
+        )
+        model.fit(ts_float32)
+        model.save(ckpt_path)
+        assert model.model._dtype == torch.float32  # type: ignore
+
+        # identical model
+        loading_model = DLinearModel(
+            input_chunk_length=4,
+            output_chunk_length=1,
+        )
+        loading_model.load_weights(ckpt_path)
+        loading_model.fit(ts_float32)
+        assert loading_model.model._dtype == torch.float32  # type: ignore
+
+    def test_multi_steps_pipeline(self, tmpdir_fn):
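+        # full workflow: pretrain with checkpoints, warm-start a second model from the
+        # best checkpoint, then reload the retrained model for prediction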
+        ts_training, ts_val = self.series.split_before(75)
+        pretrain_model_name = "pre-train"
+        retrained_model_name = "re-train"
+
+        # pretraining
+        model = self.helper_create_RNNModel(pretrain_model_name, tmpdir_fn)
+        model.fit(
+            ts_training,
+            val_series=ts_val,
+        )
 
-            # finetuning
-            model = self.helper_create_RNNModel(retrained_model_name, tmpdir_fn)
-            model.load_weights_from_checkpoint(
-                model_name=pretrain_model_name,
-                work_dir=tmpdir_fn,
-                best=True,
-                map_location="cpu",
-            )
-            model.fit(
-                ts_training,
-                val_series=ts_val,
-            )
+        # finetuning
+        model = self.helper_create_RNNModel(retrained_model_name, tmpdir_fn)
+        model.load_weights_from_checkpoint(
+            model_name=pretrain_model_name,
+            work_dir=tmpdir_fn,
+            best=True,
+            map_location="cpu",
+        )
+        model.fit(
+            ts_training,
+            val_series=ts_val,
+        )
 
-            # prediction
-            model = model.load_from_checkpoint(
-                model_name=retrained_model_name,
-                work_dir=tmpdir_fn,
-                best=True,
-                map_location="cpu",
-            )
-            model.predict(4, series=ts_training)
+        # prediction
+        model = model.load_from_checkpoint(
+            model_name=retrained_model_name,
+            work_dir=tmpdir_fn,
+            best=True,
+            map_location="cpu",
+        )
+        model.predict(4, series=ts_training)
+
+    def test_load_from_checkpoint_w_custom_loss(self, tmpdir_fn):
+        model_name = "pretraining_custom_loss"
+        # model with a custom loss
+        model = RNNModel(
+            12,
+            "RNN",
+            5,
+            1,
+            n_epochs=1,
+            work_dir=tmpdir_fn,
+            model_name=model_name,
+            save_checkpoints=True,
+            force_reset=True,
+            loss_fn=torch.nn.L1Loss(),
+            **tfm_kwargs,
+        )
+        model.fit(self.series)
 
-        def test_load_from_checkpoint_w_custom_loss(self, tmpdir_fn):
-            model_name = "pretraining_custom_loss"
-            # model with a custom loss
-            model = RNNModel(
-                12,
-                "RNN",
-                5,
-                1,
-                n_epochs=1,
-                work_dir=tmpdir_fn,
-                model_name=model_name,
-                save_checkpoints=True,
-                force_reset=True,
-                loss_fn=torch.nn.L1Loss(),
-                **tfm_kwargs,
-            )
-            model.fit(self.series)
+        loaded_model = RNNModel.load_from_checkpoint(
+            model_name, tmpdir_fn, best=False, map_location="cpu"
+        )
+        # custom loss function should be properly restored from ckpt
+        assert isinstance(loaded_model.model.criterion, torch.nn.L1Loss)
+
+        loaded_model.fit(self.series, epochs=2)
+        # calling fit() should not impact the loss function
+        assert isinstance(loaded_model.model.criterion, torch.nn.L1Loss)
+
+    def test_load_from_checkpoint_w_metrics(self, tmpdir_fn):
+        model_name = "pretraining_metrics"
+        # model with a single torch metric
+        pl_trainer_kwargs = dict(
+            {"logger": DummyLogger(), "log_every_n_steps": 1},
+            **tfm_kwargs["pl_trainer_kwargs"],
+        )
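+        # DummyLogger is a no-op logger; log_every_n_steps=1 ensures metrics are logged on this tiny dataset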
+        model = RNNModel(
+            12,
+            "RNN",
+            5,
+            1,
+            n_epochs=1,
+            work_dir=tmpdir_fn,
+            model_name=model_name,
+            save_checkpoints=True,
+            force_reset=True,
+            torch_metrics=MeanAbsolutePercentageError(),
+            pl_trainer_kwargs=pl_trainer_kwargs,
+        )
+        model.fit(self.series)
+        # check train_metrics before loading
+        assert isinstance(model.model.train_metrics, MetricCollection)
+        assert len(model.model.train_metrics) == 1
+
+        loaded_model = RNNModel.load_from_checkpoint(
+            model_name,
+            tmpdir_fn,
+            best=False,
+            map_location="cpu",
+        )
+        # torch metrics should be properly restored from ckpt as a MetricCollection
+        assert isinstance(loaded_model.model.train_metrics, MetricCollection)
+        assert len(loaded_model.model.train_metrics) == 1
 
-            loaded_model = RNNModel.load_from_checkpoint(
-                model_name, tmpdir_fn, best=False, map_location="cpu"
-            )
-            # custom loss function should be properly restored from ckpt
-            assert isinstance(loaded_model.model.criterion, torch.nn.L1Loss)
-
-            loaded_model.fit(self.series, epochs=2)
-            # calling fit() should not impact the loss function
-            assert isinstance(loaded_model.model.criterion, torch.nn.L1Loss)
-
-        def test_load_from_checkpoint_w_metrics(self, tmpdir_fn):
-            model_name = "pretraining_metrics"
-            # model with one torch_metrics
-            pl_trainer_kwargs = dict(
-                {"logger": DummyLogger(), "log_every_n_steps": 1},
-                **tfm_kwargs["pl_trainer_kwargs"],
-            )
-            model = RNNModel(
-                12,
-                "RNN",
-                5,
-                1,
-                n_epochs=1,
-                work_dir=tmpdir_fn,
-                model_name=model_name,
-                save_checkpoints=True,
-                force_reset=True,
-                torch_metrics=MeanAbsolutePercentageError(),
-                pl_trainer_kwargs=pl_trainer_kwargs,
-            )
-            model.fit(self.series)
-            # check train_metrics before loading
-            assert isinstance(model.model.train_metrics, MetricCollection)
-            assert len(model.model.train_metrics) == 1
+    def test_optimizers(self):
 
-            loaded_model = RNNModel.load_from_checkpoint(
-                model_name,
-                tmpdir_fn,
-                best=False,
-                map_location="cpu",
-            )
-            # custom loss function should be properly restored from ckpt torchmetrics.Metric
-            assert isinstance(loaded_model.model.train_metrics, MetricCollection)
-            assert len(loaded_model.model.train_metrics) == 1
-
-        def test_optimizers(self):
-
-            optimizers = [
-                (torch.optim.Adam, {"lr": 0.001}),
-                (torch.optim.SGD, {"lr": 0.001}),
-            ]
-
-            for optim_cls, optim_kwargs in optimizers:
-                model = RNNModel(
-                    12,
-                    "RNN",
-                    10,
-                    10,
-                    optimizer_cls=optim_cls,
-                    optimizer_kwargs=optim_kwargs,
-                    **tfm_kwargs,
-                )
-                # should not raise an error
-                model.fit(self.series, epochs=1)
-
-        @pytest.mark.parametrize(
-            "lr_scheduler",
-            [
-                (torch.optim.lr_scheduler.StepLR, {"step_size": 10}),
-                (
-                    torch.optim.lr_scheduler.ReduceLROnPlateau,
-                    {
-                        "threshold": 0.001,
-                        "monitor": "train_loss",
-                        "interval": "step",
-                        "frequency": 2,
-                    },
-                ),
-                (torch.optim.lr_scheduler.ExponentialLR, {"gamma": 0.09}),
-            ],
-        )
-        def test_lr_schedulers(self, lr_scheduler):
-            lr_scheduler_cls, lr_scheduler_kwargs = lr_scheduler
+        optimizers = [
+            (torch.optim.Adam, {"lr": 0.001}),
+            (torch.optim.SGD, {"lr": 0.001}),
+        ]
+
+        for optim_cls, optim_kwargs in optimizers:
             model = RNNModel(
                 12,
                 "RNN",
                 10,
                 10,
-                lr_scheduler_cls=lr_scheduler_cls,
-                lr_scheduler_kwargs=lr_scheduler_kwargs,
+                optimizer_cls=optim_cls,
+                optimizer_kwargs=optim_kwargs,
                 **tfm_kwargs,
             )
             # should not raise an error
             model.fit(self.series, epochs=1)
 
-        def test_wrong_model_creation_params(self):
-            valid_kwarg = {"pl_trainer_kwargs": {}}
-            invalid_kwarg = {"some_invalid_kwarg": None}
-
-            # valid params should not raise an error
-            _ = RNNModel(12, "RNN", 10, 10, **valid_kwarg)
+    @pytest.mark.parametrize(
+        "lr_scheduler",
+        [
+            (torch.optim.lr_scheduler.StepLR, {"step_size": 10}),
+            (
+                torch.optim.lr_scheduler.ReduceLROnPlateau,
+                {
+                    "threshold": 0.001,
+                    "monitor": "train_loss",
+                    "interval": "step",
+                    "frequency": 2,
+                },
+            ),
+            (torch.optim.lr_scheduler.ExponentialLR, {"gamma": 0.09}),
+        ],
+    )
+    def test_lr_schedulers(self, lr_scheduler):
+        lr_scheduler_cls, lr_scheduler_kwargs = lr_scheduler
+        model = RNNModel(
+            12,
+            "RNN",
+            10,
+            10,
+            lr_scheduler_cls=lr_scheduler_cls,
+            lr_scheduler_kwargs=lr_scheduler_kwargs,
+            **tfm_kwargs,
+        )
+        # should not raise an error
+        model.fit(self.series, epochs=1)
 
-            # invalid params should raise an error
-            with pytest.raises(ValueError):
-                _ = RNNModel(12, "RNN", 10, 10, **invalid_kwarg)
+    def test_wrong_model_creation_params(self):
+        valid_kwarg = {"pl_trainer_kwargs": {}}
+        invalid_kwarg = {"some_invalid_kwarg": None}
 
-        def test_metrics(self):
-            metric = MeanAbsolutePercentageError()
-            metric_collection = MetricCollection(
-                [MeanAbsolutePercentageError(), MeanAbsoluteError()]
-            )
+        # valid params should not raise an error
+        _ = RNNModel(12, "RNN", 10, 10, **valid_kwarg)
 
-            model_kwargs = {
-                "logger": DummyLogger(),
-                "log_every_n_steps": 1,
-                **tfm_kwargs["pl_trainer_kwargs"],
-            }
-            # test single metric
-            model = RNNModel(
-                12,
-                "RNN",
-                10,
-                10,
-                n_epochs=1,
-                torch_metrics=metric,
-                pl_trainer_kwargs=model_kwargs,
-            )
-            model.fit(self.series)
+        # invalid params should raise an error
+        with pytest.raises(ValueError):
+            _ = RNNModel(12, "RNN", 10, 10, **invalid_kwarg)
 
-            # test metric collection
-            model = RNNModel(
-                12,
-                "RNN",
-                10,
-                10,
-                n_epochs=1,
-                torch_metrics=metric_collection,
-                pl_trainer_kwargs=model_kwargs,
-            )
-            model.fit(self.series)
+    def test_metrics(self):
+        metric = MeanAbsolutePercentageError()
+        metric_collection = MetricCollection(
+            [MeanAbsolutePercentageError(), MeanAbsoluteError()]
+        )
 
-            # test multivariate series
-            model = RNNModel(
-                12,
-                "RNN",
-                10,
-                10,
-                n_epochs=1,
-                torch_metrics=metric,
-                pl_trainer_kwargs=model_kwargs,
-            )
-            model.fit(self.multivariate_series)
+        model_kwargs = {
+            "logger": DummyLogger(),
+            "log_every_n_steps": 1,
+            **tfm_kwargs["pl_trainer_kwargs"],
+        }
+        # test single metric
+        model = RNNModel(
+            12,
+            "RNN",
+            10,
+            10,
+            n_epochs=1,
+            torch_metrics=metric,
+            pl_trainer_kwargs=model_kwargs,
+        )
+        model.fit(self.series)
+
+        # test metric collection
+        model = RNNModel(
+            12,
+            "RNN",
+            10,
+            10,
+            n_epochs=1,
+            torch_metrics=metric_collection,
+            pl_trainer_kwargs=model_kwargs,
+        )
+        model.fit(self.series)
+
+        # test multivariate series
+        model = RNNModel(
+            12,
+            "RNN",
+            10,
+            10,
+            n_epochs=1,
+            torch_metrics=metric,
+            pl_trainer_kwargs=model_kwargs,
+        )
+        model.fit(self.multivariate_series)
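+        # For reference, a MetricCollection evaluates every member on the same
+        # inputs and returns a dict keyed by metric class name (plain
+        # torchmetrics; the numbers are a hand-check, not model output):
+        #
+        #     mc = MetricCollection([MeanAbsolutePercentageError(), MeanAbsoluteError()])
+        #     mc(torch.tensor([2.0, 4.0]), torch.tensor([1.0, 5.0]))
+        #     # -> {"MeanAbsolutePercentageError": 0.6, "MeanAbsoluteError": 1.0}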
 
-        def test_metrics_w_likelihood(self):
-            metric = MeanAbsolutePercentageError()
-            metric_collection = MetricCollection(
-                [MeanAbsolutePercentageError(), MeanAbsoluteError()]
-            )
-            model_kwargs = {
-                "logger": DummyLogger(),
-                "log_every_n_steps": 1,
-                **tfm_kwargs["pl_trainer_kwargs"],
-            }
-            # test single metric
-            model = RNNModel(
-                12,
-                "RNN",
-                10,
-                10,
-                n_epochs=1,
-                likelihood=GaussianLikelihood(),
-                torch_metrics=metric,
-                pl_trainer_kwargs=model_kwargs,
-            )
-            model.fit(self.series)
+    def test_metrics_w_likelihood(self):
+        metric = MeanAbsolutePercentageError()
+        metric_collection = MetricCollection(
+            [MeanAbsolutePercentageError(), MeanAbsoluteError()]
+        )
+        model_kwargs = {
+            "logger": DummyLogger(),
+            "log_every_n_steps": 1,
+            **tfm_kwargs["pl_trainer_kwargs"],
+        }
+        # test single metric
+        model = RNNModel(
+            12,
+            "RNN",
+            10,
+            10,
+            n_epochs=1,
+            likelihood=GaussianLikelihood(),
+            torch_metrics=metric,
+            pl_trainer_kwargs=model_kwargs,
+        )
+        model.fit(self.series)
+
+        # test metric collection
+        model = RNNModel(
+            12,
+            "RNN",
+            10,
+            10,
+            n_epochs=1,
+            likelihood=GaussianLikelihood(),
+            torch_metrics=metric_collection,
+            pl_trainer_kwargs=model_kwargs,
+        )
+        model.fit(self.series)
+
+        # test multivariate series
+        model = RNNModel(
+            12,
+            "RNN",
+            10,
+            10,
+            n_epochs=1,
+            likelihood=GaussianLikelihood(),
+            torch_metrics=metric_collection,
+            pl_trainer_kwargs=model_kwargs,
+        )
+        model.fit(self.multivariate_series)
 
-            # test metric collection
+    def test_invalid_metrics(self):
+        torch_metrics = ["invalid"]
+        with pytest.raises(AttributeError):
             model = RNNModel(
                 12,
                 "RNN",
                 10,
                 10,
                 n_epochs=1,
-                likelihood=GaussianLikelihood(),
-                torch_metrics=metric_collection,
-                pl_trainer_kwargs=model_kwargs,
+                torch_metrics=torch_metrics,
+                **tfm_kwargs,
             )
             model.fit(self.series)
 
-            # test multivariate series
+    @pytest.mark.slow
+    def test_lr_find(self):
+        train_series, val_series = self.series[:-40], self.series[-40:]
+        model = RNNModel(12, "RNN", 10, 10, random_state=42, **tfm_kwargs)
+        # find the learning rate
+        res = model.lr_find(series=train_series, val_series=val_series, epochs=50)
+        assert isinstance(res, _LRFinder)
+        assert res.suggestion() is not None
+        # verify that learning rate finder bypasses the `fit` logic
+        assert model.model is None
+        assert not model._fit_called
+        # cannot predict with an untrained model
+        with pytest.raises(ValueError):
+            model.predict(n=3, series=self.series)
+
+        # check that results are reproducible
+        model = RNNModel(12, "RNN", 10, 10, random_state=42, **tfm_kwargs)
+        res2 = model.lr_find(series=train_series, val_series=val_series, epochs=50)
+        assert res.suggestion() == res2.suggestion()
+
+        # check that suggested learning rate is better than the worst
+        lr_worst = res.results["lr"][np.argmax(res.results["loss"])]
+        lr_suggested = res.suggestion()
+        scores = {}
+        for lr, lr_name in zip([lr_worst, lr_suggested], ["worst", "suggested"]):
             model = RNNModel(
                 12,
                 "RNN",
                 10,
                 10,
-                n_epochs=1,
-                likelihood=GaussianLikelihood(),
-                torch_metrics=metric_collection,
-                pl_trainer_kwargs=model_kwargs,
-            )
-            model.fit(self.multivariate_series)
-
-        def test_invalid_metrics(self):
-            torch_metrics = ["invalid"]
-            with pytest.raises(AttributeError):
-                model = RNNModel(
-                    12,
-                    "RNN",
-                    10,
-                    10,
-                    n_epochs=1,
-                    torch_metrics=torch_metrics,
-                    **tfm_kwargs,
-                )
-                model.fit(self.series)
-
-        @pytest.mark.slow
-        def test_lr_find(self):
-            train_series, val_series = self.series[:-40], self.series[-40:]
-            model = RNNModel(12, "RNN", 10, 10, random_state=42, **tfm_kwargs)
-            # find the learning rate
-            res = model.lr_find(series=train_series, val_series=val_series, epochs=50)
-            assert isinstance(res, _LRFinder)
-            assert res.suggestion() is not None
-            # verify that learning rate finder bypasses the `fit` logic
-            assert model.model is None
-            assert not model._fit_called
-            # cannot predict with an untrained model
-            with pytest.raises(ValueError):
-                model.predict(n=3, series=self.series)
-
-            # check that results are reproducible
-            model = RNNModel(12, "RNN", 10, 10, random_state=42, **tfm_kwargs)
-            res2 = model.lr_find(series=train_series, val_series=val_series, epochs=50)
-            assert res.suggestion() == res2.suggestion()
-
-            # check that suggested learning rate is better than the worst
-            lr_worst = res.results["lr"][np.argmax(res.results["loss"])]
-            lr_suggested = res.suggestion()
-            scores = {}
-            for lr, lr_name in zip([lr_worst, lr_suggested], ["worst", "suggested"]):
-                model = RNNModel(
-                    12,
-                    "RNN",
-                    10,
-                    10,
-                    n_epochs=10,
-                    random_state=42,
-                    optimizer_cls=torch.optim.Adam,
-                    optimizer_kwargs={"lr": lr},
-                    **tfm_kwargs,
-                )
-                model.fit(train_series)
-                scores[lr_name] = mape(
-                    val_series, model.predict(len(val_series), series=train_series)
-                )
-            assert scores["worst"] > scores["suggested"]
-
-        def test_encoders(self, tmpdir_fn):
-            series = tg.linear_timeseries(length=10)
-            pc = tg.linear_timeseries(length=12)
-            fc = tg.linear_timeseries(length=13)
-            # 1 == output_chunk_length, 3 > output_chunk_length
-            ns = [1, 3]
-
-            model = self.helper_create_DLinearModel(
-                work_dir=tmpdir_fn,
-                add_encoders={
-                    "datetime_attribute": {"past": ["hour"], "future": ["month"]}
-                },
+                n_epochs=10,
+                random_state=42,
+                optimizer_cls=torch.optim.Adam,
+                optimizer_kwargs={"lr": lr},
+                **tfm_kwargs,
             )
-            model.fit(series)
-            for n in ns:
-                _ = model.predict(n=n)
-                with pytest.raises(ValueError):
-                    _ = model.predict(n=n, past_covariates=pc)
-                with pytest.raises(ValueError):
-                    _ = model.predict(n=n, future_covariates=fc)
-                with pytest.raises(ValueError):
-                    _ = model.predict(n=n, past_covariates=pc, future_covariates=fc)
-
-            model = self.helper_create_DLinearModel(
-                work_dir=tmpdir_fn,
-                add_encoders={
-                    "datetime_attribute": {"past": ["hour"], "future": ["month"]}
-                },
+            model.fit(train_series)
+            scores[lr_name] = mape(
+                val_series, model.predict(len(val_series), series=train_series)
             )
-            for n in ns:
-                model.fit(series, past_covariates=pc)
-                _ = model.predict(n=n)
-                _ = model.predict(n=n, past_covariates=pc)
-                with pytest.raises(ValueError):
-                    _ = model.predict(n=n, future_covariates=fc)
-                with pytest.raises(ValueError):
-                    _ = model.predict(n=n, past_covariates=pc, future_covariates=fc)
+        assert scores["worst"] > scores["suggested"]
 
-            model = self.helper_create_DLinearModel(
-                work_dir=tmpdir_fn,
-                add_encoders={
-                    "datetime_attribute": {"past": ["hour"], "future": ["month"]}
-                },
-            )
-            for n in ns:
-                model.fit(series, future_covariates=fc)
-                _ = model.predict(n=n)
-                with pytest.raises(ValueError):
-                    _ = model.predict(n=n, past_covariates=pc)
-                _ = model.predict(n=n, future_covariates=fc)
-                with pytest.raises(ValueError):
-                    _ = model.predict(n=n, past_covariates=pc, future_covariates=fc)
+    def test_encoders(self, tmpdir_fn):
+        series = tg.linear_timeseries(length=10)
+        pc = tg.linear_timeseries(length=12)
+        fc = tg.linear_timeseries(length=13)
+        # 1 == output_chunk_length, 3 > output_chunk_length
+        ns = [1, 3]
 
-            model = self.helper_create_DLinearModel(
-                work_dir=tmpdir_fn,
-                add_encoders={
-                    "datetime_attribute": {"past": ["hour"], "future": ["month"]}
-                },
-            )
-            for n in ns:
-                model.fit(series, past_covariates=pc, future_covariates=fc)
-                _ = model.predict(n=n)
+        model = self.helper_create_DLinearModel(
+            work_dir=tmpdir_fn,
+            add_encoders={
+                "datetime_attribute": {"past": ["hour"], "future": ["month"]}
+            },
+        )
+        model.fit(series)
+        for n in ns:
+            _ = model.predict(n=n)
+            with pytest.raises(ValueError):
                 _ = model.predict(n=n, past_covariates=pc)
+            with pytest.raises(ValueError):
                 _ = model.predict(n=n, future_covariates=fc)
+            with pytest.raises(ValueError):
                 _ = model.predict(n=n, past_covariates=pc, future_covariates=fc)
 
-        @pytest.mark.parametrize("model_config", models)
-        def test_rin(self, model_config):
-            model_cls, kwargs = model_config
-            model_no_rin = model_cls(use_reversible_instance_norm=False, **kwargs)
-            model_rin = model_cls(use_reversible_instance_norm=True, **kwargs)
-
-            # univariate no RIN
-            model_no_rin.fit(self.series)
-            assert not model_no_rin.model.use_reversible_instance_norm
-            assert model_no_rin.model.rin is None
-
-            # univariate with RIN
-            model_rin.fit(self.series)
-            if issubclass(model_cls, RNNModel):
-                # RNNModel will not use RIN
-                assert not model_rin.model.use_reversible_instance_norm
-                assert model_rin.model.rin is None
-                return
-            else:
-                assert model_rin.model.use_reversible_instance_norm
-                assert isinstance(model_rin.model.rin, RINorm)
-                assert model_rin.model.rin.input_dim == self.series.n_components
-            # multivariate with RIN
-            model_rin_mv = model_rin.untrained_model()
-            model_rin_mv.fit(self.multivariate_series)
-            assert model_rin_mv.model.use_reversible_instance_norm
-            assert isinstance(model_rin_mv.model.rin, RINorm)
-            assert (
-                model_rin_mv.model.rin.input_dim
-                == self.multivariate_series.n_components
-            )
-
-        @pytest.mark.parametrize(
-            "config",
-            itertools.product(
-                [
-                    (
-                        TFTModel,
-                        {
-                            "add_relative_index": True,
-                            "likelihood": None,
-                            "loss_fn": torch.nn.MSELoss(),
-                        },
-                    ),
-                    (TiDEModel, {}),
-                    (NLinearModel, {}),
-                    (DLinearModel, {}),
-                    (NBEATSModel, {}),
-                    (NHiTSModel, {}),
-                    (TransformerModel, {}),
-                    (TCNModel, {}),
-                    (TSMixerModel, {}),
-                    (BlockRNNModel, {}),
-                    (GlobalNaiveSeasonal, {}),
-                    (GlobalNaiveAggregate, {}),
-                    (GlobalNaiveDrift, {}),
-                ],
-                [3, 7, 10],
-            ),
+        model = self.helper_create_DLinearModel(
+            work_dir=tmpdir_fn,
+            add_encoders={
+                "datetime_attribute": {"past": ["hour"], "future": ["month"]}
+            },
         )
-        def test_output_shift(self, config):
-            """Tests shifted output for shift smaller than, equal to, and larger than output_chunk_length.
-            RNNModel does not support shift output chunk.
-            """
-            np.random.seed(0)
-            (model_cls, add_params), shift = config
-            icl = 8
-            ocl = 7
-            series = tg.gaussian_timeseries(
-                length=28, start=pd.Timestamp("2000-01-01"), freq="d"
-            )
-
-            model = self.helper_create_torch_model(
-                model_cls, icl, ocl, shift, **add_params
-            )
-            model.fit(series)
+        for n in ns:
+            model.fit(series, past_covariates=pc)
+            _ = model.predict(n=n)
+            _ = model.predict(n=n, past_covariates=pc)
+            with pytest.raises(ValueError):
+                _ = model.predict(n=n, future_covariates=fc)
+            with pytest.raises(ValueError):
+                _ = model.predict(n=n, past_covariates=pc, future_covariates=fc)
 
-            # no auto-regression with shifted output
-            with pytest.raises(ValueError) as err:
-                _ = model.predict(n=ocl + 1)
-            assert str(err.value).startswith("Cannot perform auto-regression")
-
-            # pred starts with a shift
-            for ocl_test in [ocl - 1, ocl]:
-                pred = model.predict(n=ocl_test)
-                assert (
-                    pred.start_time() == series.end_time() + (shift + 1) * series.freq
-                )
-                assert len(pred) == ocl_test
-                assert pred.freq == series.freq
-
-            # check that shifted output chunk results with encoders are the
-            # same as using identical covariates
-
-            # model trained on encoders
-            cov_support = []
-            covs = {}
-            if model.supports_past_covariates:
-                cov_support.append("past")
-                covs["past_covariates"] = tg.datetime_attribute_timeseries(
-                    series,
-                    attribute="dayofweek",
-                    add_length=0,
-                )
-            if model.supports_future_covariates:
-                cov_support.append("future")
-                covs["future_covariates"] = tg.datetime_attribute_timeseries(
-                    series,
-                    attribute="dayofweek",
-                    add_length=ocl + shift,
-                )
-
-            if not cov_support:
-                return
-
-            add_encoders = {
-                "datetime_attribute": {cov: ["dayofweek"] for cov in cov_support}
-            }
-            model_enc_shift = self.helper_create_torch_model(
-                model_cls, icl, ocl, shift, add_encoders=add_encoders, **add_params
-            )
-            model_enc_shift.fit(series)
+        model = self.helper_create_DLinearModel(
+            work_dir=tmpdir_fn,
+            add_encoders={
+                "datetime_attribute": {"past": ["hour"], "future": ["month"]}
+            },
+        )
+        for n in ns:
+            model.fit(series, future_covariates=fc)
+            _ = model.predict(n=n)
+            with pytest.raises(ValueError):
+                _ = model.predict(n=n, past_covariates=pc)
+            _ = model.predict(n=n, future_covariates=fc)
+            with pytest.raises(ValueError):
+                _ = model.predict(n=n, past_covariates=pc, future_covariates=fc)
 
-            # model trained with identical covariates
-            model_fc_shift = self.helper_create_torch_model(
-                model_cls, icl, ocl, shift, **add_params
-            )
+        model = self.helper_create_DLinearModel(
+            work_dir=tmpdir_fn,
+            add_encoders={
+                "datetime_attribute": {"past": ["hour"], "future": ["month"]}
+            },
+        )
+        for n in ns:
+            model.fit(series, past_covariates=pc, future_covariates=fc)
+            _ = model.predict(n=n)
+            _ = model.predict(n=n, past_covariates=pc)
+            _ = model.predict(n=n, future_covariates=fc)
+            _ = model.predict(n=n, past_covariates=pc, future_covariates=fc)
+
+    @pytest.mark.parametrize("model_config", models)
+    def test_rin(self, model_config):
+        model_cls, kwargs = model_config
+        model_no_rin = model_cls(use_reversible_instance_norm=False, **kwargs)
+        model_rin = model_cls(use_reversible_instance_norm=True, **kwargs)
+
+        # univariate no RIN
+        model_no_rin.fit(self.series)
+        assert not model_no_rin.model.use_reversible_instance_norm
+        assert model_no_rin.model.rin is None
+
+        # univariate with RIN
+        model_rin.fit(self.series)
+        if issubclass(model_cls, RNNModel):
+            # RNNModel will not use RIN
+            assert not model_rin.model.use_reversible_instance_norm
+            assert model_rin.model.rin is None
+            return
+        else:
+            assert model_rin.model.use_reversible_instance_norm
+            assert isinstance(model_rin.model.rin, RINorm)
+            assert model_rin.model.rin.input_dim == self.series.n_components
+        # multivariate with RIN
+        model_rin_mv = model_rin.untrained_model()
+        model_rin_mv.fit(self.multivariate_series)
+        assert model_rin_mv.model.use_reversible_instance_norm
+        assert isinstance(model_rin_mv.model.rin, RINorm)
+        assert model_rin_mv.model.rin.input_dim == self.multivariate_series.n_components
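+        # Rough idea behind RINorm (a hedged sketch, not darts' exact code):
+        # normalize each sample over its own time axis on the way into the
+        # model and invert the transform on the way out:
+        #
+        #     mean = x.mean(dim=1, keepdim=True)   # x: (batch, time, component)
+        #     std = x.std(dim=1, keepdim=True) + eps
+        #     x_norm = (x - mean) / std            # fed to the network
+        #     y = y_norm * std + mean              # de-normalized forecast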
+
+    @pytest.mark.parametrize(
+        "config",
+        itertools.product(
+            [
+                (
+                    TFTModel,
+                    {
+                        "add_relative_index": True,
+                        "likelihood": None,
+                        "loss_fn": torch.nn.MSELoss(),
+                    },
+                ),
+                (TiDEModel, {}),
+                (NLinearModel, {}),
+                (DLinearModel, {}),
+                (NBEATSModel, {}),
+                (NHiTSModel, {}),
+                (TransformerModel, {}),
+                (TCNModel, {}),
+                (TSMixerModel, {}),
+                (BlockRNNModel, {}),
+                (GlobalNaiveSeasonal, {}),
+                (GlobalNaiveAggregate, {}),
+                (GlobalNaiveDrift, {}),
+            ],
+            [3, 7, 10],
+        ),
+    )
+    def test_output_shift(self, config):
+        """Tests shifted output for shift smaller than, equal to, and larger than output_chunk_length.
+        RNNModel does not support shift output chunk.
+        """
+        np.random.seed(0)
+        (model_cls, add_params), shift = config
+        icl = 8
+        ocl = 7
+        series = tg.gaussian_timeseries(
+            length=28, start=pd.Timestamp("2000-01-01"), freq="d"
+        )
 
-            model_fc_shift.fit(series, **covs)
+        model = self.helper_create_torch_model(model_cls, icl, ocl, shift, **add_params)
+        model.fit(series)
+
+        # no auto-regression with shifted output
+        with pytest.raises(ValueError) as err:
+            _ = model.predict(n=ocl + 1)
+        assert str(err.value).startswith("Cannot perform auto-regression")
+
+        # pred starts with a shift
+        for ocl_test in [ocl - 1, ocl]:
+            pred = model.predict(n=ocl_test)
+            assert pred.start_time() == series.end_time() + (shift + 1) * series.freq
+            assert len(pred) == ocl_test
+            assert pred.freq == series.freq
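+        # Worked example: with daily frequency and shift=3, the first forecast
+        # point is not the day after series.end_time() but (3 + 1) days later,
+        # i.e. the model skips `shift` steps before its output chunk starts.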
+
+        # check that shifted output chunk results with encoders match
+        # those obtained with identical explicit covariates
+
+        # model trained on encoders
+        cov_support = []
+        covs = {}
+        if model.supports_past_covariates:
+            cov_support.append("past")
+            covs["past_covariates"] = tg.datetime_attribute_timeseries(
+                series,
+                attribute="dayofweek",
+                add_length=0,
+            )
+        if model.supports_future_covariates:
+            cov_support.append("future")
+            covs["future_covariates"] = tg.datetime_attribute_timeseries(
+                series,
+                attribute="dayofweek",
+                add_length=ocl + shift,
+            )
+
+        if not cov_support:
+            return
+
+        add_encoders = {
+            "datetime_attribute": {cov: ["dayofweek"] for cov in cov_support}
+        }
+        model_enc_shift = self.helper_create_torch_model(
+            model_cls, icl, ocl, shift, add_encoders=add_encoders, **add_params
+        )
+        model_enc_shift.fit(series)
 
-            pred_enc = model_enc_shift.predict(n=ocl)
-            pred_fc = model_fc_shift.predict(n=ocl)
-            assert pred_enc == pred_fc
+        # model trained with identical covariates
+        model_fc_shift = self.helper_create_torch_model(
+            model_cls, icl, ocl, shift, **add_params
+        )
 
-            # check that historical forecasts works properly
-            hist_fc_start = -(ocl + shift)
-            pred_last_hist_fc = model_fc_shift.predict(
-                n=ocl, series=series[:hist_fc_start]
-            )
-            # non-optimized hist fc
-            hist_fc = model_fc_shift.historical_forecasts(
-                series=series,
-                start=hist_fc_start,
-                start_format="position",
-                retrain=False,
-                forecast_horizon=ocl,
-                last_points_only=False,
-                enable_optimization=False,
-                **covs,
-            )
-            assert len(hist_fc) == 1
-            assert hist_fc[0] == pred_last_hist_fc
-            # optimized hist fc, due to batch predictions, slight deviations in values
-            hist_fc_opt = model_fc_shift.historical_forecasts(
-                series=series,
-                start=hist_fc_start,
-                start_format="position",
-                retrain=False,
-                forecast_horizon=ocl,
-                last_points_only=False,
-                enable_optimization=True,
-                **covs,
-            )
-            assert len(hist_fc_opt) == 1
-            assert hist_fc_opt[0].time_index.equals(pred_last_hist_fc.time_index)
-            np.testing.assert_array_almost_equal(
-                hist_fc_opt[0].values(copy=False), pred_last_hist_fc.values(copy=False)
-            )
+        model_fc_shift.fit(series, **covs)
+
+        pred_enc = model_enc_shift.predict(n=ocl)
+        pred_fc = model_fc_shift.predict(n=ocl)
+        assert pred_enc == pred_fc
+
+        # check that historical forecasts work properly
+        hist_fc_start = -(ocl + shift)
+        pred_last_hist_fc = model_fc_shift.predict(n=ocl, series=series[:hist_fc_start])
+        # non-optimized hist fc
+        hist_fc = model_fc_shift.historical_forecasts(
+            series=series,
+            start=hist_fc_start,
+            start_format="position",
+            retrain=False,
+            forecast_horizon=ocl,
+            last_points_only=False,
+            enable_optimization=False,
+            **covs,
+        )
+        assert len(hist_fc) == 1
+        assert hist_fc[0] == pred_last_hist_fc
+        # optimized hist fc; batch predictions can cause slight deviations in values
+        hist_fc_opt = model_fc_shift.historical_forecasts(
+            series=series,
+            start=hist_fc_start,
+            start_format="position",
+            retrain=False,
+            forecast_horizon=ocl,
+            last_points_only=False,
+            enable_optimization=True,
+            **covs,
+        )
+        assert len(hist_fc_opt) == 1
+        assert hist_fc_opt[0].time_index.equals(pred_last_hist_fc.time_index)
+        np.testing.assert_array_almost_equal(
+            hist_fc_opt[0].values(copy=False), pred_last_hist_fc.values(copy=False)
+        )
 
-            # covs too short
-            for cov_name in cov_support:
-                with pytest.raises(ValueError) as err:
-                    add_covs = {
-                        cov_name + "_covariates": covs[cov_name + "_covariates"][:-1]
-                    }
-                    _ = model_fc_shift.predict(n=ocl, **add_covs)
-                assert f"provided {cov_name} covariates at dataset index" in str(
-                    err.value
-                )
-
-        def helper_equality_encoders(
-            self, first_encoders: Dict[str, Any], second_encoders: Dict[str, Any]
-        ):
-            if first_encoders is None:
-                first_encoders = {}
-            if second_encoders is None:
-                second_encoders = {}
-            assert {k: v for k, v in first_encoders.items() if k != "transformer"} == {
-                k: v for k, v in second_encoders.items() if k != "transformer"
-            }
-
-        def helper_equality_encoders_transfo(
-            self, first_encoders: Dict[str, Any], second_encoders: Dict[str, Any]
-        ):
-            if first_encoders is None:
-                first_encoders = {}
-            if second_encoders is None:
-                second_encoders = {}
-            assert (
-                first_encoders.get("transformer", None).__class__
-                == second_encoders.get("transformer", None).__class__
-            )
+        # covs too short
+        for cov_name in cov_support:
+            with pytest.raises(ValueError) as err:
+                add_covs = {
+                    cov_name + "_covariates": covs[cov_name + "_covariates"][:-1]
+                }
+                _ = model_fc_shift.predict(n=ocl, **add_covs)
+            assert f"provided {cov_name} covariates at dataset index" in str(err.value)
+
+    def helper_equality_encoders(
+        self, first_encoders: Dict[str, Any], second_encoders: Dict[str, Any]
+    ):
+        if first_encoders is None:
+            first_encoders = {}
+        if second_encoders is None:
+            second_encoders = {}
+        assert {k: v for k, v in first_encoders.items() if k != "transformer"} == {
+            k: v for k, v in second_encoders.items() if k != "transformer"
+        }
+
+    def helper_equality_encoders_transfo(
+        self, first_encoders: Dict[str, Any], second_encoders: Dict[str, Any]
+    ):
+        if first_encoders is None:
+            first_encoders = {}
+        if second_encoders is None:
+            second_encoders = {}
+        assert (
+            first_encoders.get("transformer", None).__class__
+            == second_encoders.get("transformer", None).__class__
+        )
 
-        def helper_create_RNNModel(self, model_name: str, tmpdir_fn):
-            return RNNModel(
-                input_chunk_length=4,
-                hidden_dim=3,
-                add_encoders={
-                    "cyclic": {"past": ["month"]},
-                    "datetime_attribute": {
-                        "past": ["hour"],
-                    },
-                    "transformer": Scaler(),
+    def helper_create_RNNModel(self, model_name: str, tmpdir_fn):
+        return RNNModel(
+            input_chunk_length=4,
+            hidden_dim=3,
+            add_encoders={
+                "cyclic": {"past": ["month"]},
+                "datetime_attribute": {
+                    "past": ["hour"],
                 },
-                n_epochs=2,
-                model_name=model_name,
-                work_dir=tmpdir_fn,
-                force_reset=True,
-                save_checkpoints=True,
-                **tfm_kwargs,
-            )
+                "transformer": Scaler(),
+            },
+            n_epochs=2,
+            model_name=model_name,
+            work_dir=tmpdir_fn,
+            force_reset=True,
+            save_checkpoints=True,
+            **tfm_kwargs,
+        )
 
-        def helper_create_DLinearModel(
-            self,
-            work_dir: Optional[str] = None,
-            model_name: str = "unitest_model",
-            add_encoders: Optional[Dict] = None,
-            save_checkpoints: bool = False,
-            likelihood: Optional[Likelihood] = None,
-            output_chunk_length: int = 1,
+    def helper_create_DLinearModel(
+        self,
+        work_dir: Optional[str] = None,
+        model_name: str = "unitest_model",
+        add_encoders: Optional[Dict] = None,
+        save_checkpoints: bool = False,
+        likelihood: Optional[Likelihood] = None,
+        output_chunk_length: int = 1,
+        **kwargs,
+    ):
+        return DLinearModel(
+            input_chunk_length=4,
+            output_chunk_length=output_chunk_length,
+            model_name=model_name,
+            add_encoders=add_encoders,
+            work_dir=work_dir,
+            save_checkpoints=save_checkpoints,
+            random_state=42,
+            force_reset=True,
+            n_epochs=1,
+            likelihood=likelihood,
+            **tfm_kwargs,
             **kwargs,
-        ):
-            return DLinearModel(
-                input_chunk_length=4,
-                output_chunk_length=output_chunk_length,
-                model_name=model_name,
-                add_encoders=add_encoders,
-                work_dir=work_dir,
-                save_checkpoints=save_checkpoints,
-                random_state=42,
-                force_reset=True,
-                n_epochs=1,
-                likelihood=likelihood,
-                **tfm_kwargs,
-                **kwargs,
-            )
+        )
 
-        def helper_create_torch_model(self, model_cls, icl, ocl, shift, **kwargs):
-            params = {
-                "input_chunk_length": icl,
-                "output_chunk_length": ocl,
-                "output_chunk_shift": shift,
-                "n_epochs": 1,
-                "random_state": 42,
-            }
-            params.update(tfm_kwargs)
-            params.update(kwargs)
-            return model_cls(**params)
+    def helper_create_torch_model(self, model_cls, icl, ocl, shift, **kwargs):
+        params = {
+            "input_chunk_length": icl,
+            "output_chunk_length": ocl,
+            "output_chunk_shift": shift,
+            "n_epochs": 1,
+            "random_state": 42,
+        }
+        params.update(tfm_kwargs)
+        params.update(kwargs)
+        return model_cls(**params)
diff --git a/darts/tests/models/forecasting/test_transformer_model.py b/darts/tests/models/forecasting/test_transformer_model.py
index b04fd05485..a70194667b 100644
--- a/darts/tests/models/forecasting/test_transformer_model.py
+++ b/darts/tests/models/forecasting/test_transformer_model.py
@@ -20,186 +20,182 @@
         TransformerModel,
         _TransformerModule,
     )
-
-    TORCH_AVAILABLE = True
 except ImportError:
-    logger.warning("Torch not available. Transformer tests will be skipped.")
-    TORCH_AVAILABLE = False
+    pytest.skip(
+        f"Torch not available. {__name__} tests will be skipped.",
+        allow_module_level=True,
+    )
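+# Note: skipping at module level aborts pytest's collection of this file when
+# torch is missing, so the TORCH_AVAILABLE flag and the indented
+# `if TORCH_AVAILABLE:` wrapper around the test class are no longer needed.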
 
 
-if TORCH_AVAILABLE:
+class TestTransformerModel:
+    times = pd.date_range("20130101", "20130410")
+    pd_series = pd.Series(range(100), index=times)
+    series: TimeSeries = TimeSeries.from_series(pd_series)
+    series_multivariate = series.stack(series * 2)
+    module = _TransformerModule(
+        input_size=1,
+        input_chunk_length=1,
+        output_chunk_length=1,
+        output_chunk_shift=0,
+        train_sample_shape=((1, 1),),
+        output_size=1,
+        nr_params=1,
+        d_model=512,
+        nhead=8,
+        num_encoder_layers=6,
+        num_decoder_layers=6,
+        dim_feedforward=2048,
+        dropout=0.1,
+        activation="relu",
+        norm_type=None,
+        custom_encoder=None,
+        custom_decoder=None,
+    )
 
-    class TestTransformerModel:
-        times = pd.date_range("20130101", "20130410")
-        pd_series = pd.Series(range(100), index=times)
-        series: TimeSeries = TimeSeries.from_series(pd_series)
-        series_multivariate = series.stack(series * 2)
-        module = _TransformerModule(
-            input_size=1,
+    def test_fit(self, tmpdir_module):
+        # Test fit-save-load cycle
+        model2 = TransformerModel(
             input_chunk_length=1,
             output_chunk_length=1,
-            output_chunk_shift=0,
-            train_sample_shape=((1, 1),),
-            output_size=1,
-            nr_params=1,
-            d_model=512,
-            nhead=8,
-            num_encoder_layers=6,
-            num_decoder_layers=6,
-            dim_feedforward=2048,
-            dropout=0.1,
-            activation="relu",
-            norm_type=None,
-            custom_encoder=None,
-            custom_decoder=None,
+            n_epochs=2,
+            model_name="unittest-model-transformer",
+            work_dir=tmpdir_module,
+            save_checkpoints=True,
+            force_reset=True,
+            **tfm_kwargs,
+        )
+        model2.fit(self.series)
+        model_loaded = model2.load_from_checkpoint(
+            model_name="unittest-model-transformer",
+            work_dir=tmpdir_module,
+            best=False,
+            map_location="cpu",
         )
+        pred1 = model2.predict(n=6)
+        pred2 = model_loaded.predict(n=6)
 
-        def test_fit(self, tmpdir_module):
-            # Test fit-save-load cycle
-            model2 = TransformerModel(
+        # Two models with the same parameters should deterministically yield the same output
+        np.testing.assert_array_equal(pred1.values(), pred2.values())
+
+        # Another random model should not
+        model3 = TransformerModel(
+            input_chunk_length=1, output_chunk_length=1, n_epochs=1, **tfm_kwargs
+        )
+        model3.fit(self.series)
+        pred3 = model3.predict(n=6)
+        assert not np.array_equal(pred1.values(), pred3.values())
+
+        # test short predict
+        pred4 = model3.predict(n=1)
+        assert len(pred4) == 1
+
+        # test validation series input
+        model3.fit(self.series[:60], val_series=self.series[60:])
+        pred4 = model3.predict(n=6)
+        assert len(pred4) == 6
+
+    def helper_test_pred_length(self, pytorch_model, series):
+        model = pytorch_model(
+            input_chunk_length=1, output_chunk_length=3, n_epochs=1, **tfm_kwargs
+        )
+        model.fit(series)
+        pred = model.predict(7)
+        assert len(pred) == 7
+        pred = model.predict(2)
+        assert len(pred) == 2
+        assert pred.width == 1
+        pred = model.predict(4)
+        assert len(pred) == 4
+        assert pred.width == 1
+
+    def test_pred_length(self):
+        series = tg.linear_timeseries(length=100)
+        self.helper_test_pred_length(TransformerModel, series)
+
+    def test_activations(self):
+        with pytest.raises(ValueError):
+            model1 = TransformerModel(
                 input_chunk_length=1,
                 output_chunk_length=1,
-                n_epochs=2,
-                model_name="unittest-model-transformer",
-                work_dir=tmpdir_module,
-                save_checkpoints=True,
-                force_reset=True,
+                activation="invalid",
                 **tfm_kwargs,
             )
-            model2.fit(self.series)
-            model_loaded = model2.load_from_checkpoint(
-                model_name="unittest-model-transformer",
-                work_dir=tmpdir_module,
-                best=False,
-                map_location="cpu",
-            )
-            pred1 = model2.predict(n=6)
-            pred2 = model_loaded.predict(n=6)
+            model1.fit(self.series, epochs=1)
 
-            # Two models with the same parameters should deterministically yield the same output
-            np.testing.assert_array_equal(pred1.values(), pred2.values())
+        # internal activation function uses PyTorch TransformerEncoderLayer
+        model2 = TransformerModel(
+            input_chunk_length=1,
+            output_chunk_length=1,
+            activation="gelu",
+            **tfm_kwargs,
+        )
+        model2.fit(self.series, epochs=1)
+        assert isinstance(
+            model2.model.transformer.encoder.layers[0], nn.TransformerEncoderLayer
+        )
+        assert isinstance(
+            model2.model.transformer.decoder.layers[0], nn.TransformerDecoderLayer
+        )
 
-            # Another random model should not
-            model3 = TransformerModel(
-                input_chunk_length=1, output_chunk_length=1, n_epochs=1, **tfm_kwargs
-            )
-            model3.fit(self.series)
-            pred3 = model3.predict(n=6)
-            assert not np.array_equal(pred1.values(), pred3.values())
-
-            # test short predict
-            pred4 = model3.predict(n=1)
-            assert len(pred4) == 1
-
-            # test validation series input
-            model3.fit(self.series[:60], val_series=self.series[60:])
-            pred4 = model3.predict(n=6)
-            assert len(pred4) == 6
-
-        def helper_test_pred_length(self, pytorch_model, series):
-            model = pytorch_model(
-                input_chunk_length=1, output_chunk_length=3, n_epochs=1, **tfm_kwargs
-            )
-            model.fit(series)
-            pred = model.predict(7)
-            assert len(pred) == 7
-            pred = model.predict(2)
-            assert len(pred) == 2
-            assert pred.width == 1
-            pred = model.predict(4)
-            assert len(pred) == 4
-            assert pred.width == 1
-
-        def test_pred_length(self):
-            series = tg.linear_timeseries(length=100)
-            self.helper_test_pred_length(TransformerModel, series)
-
-        def test_activations(self):
-            with pytest.raises(ValueError):
-                model1 = TransformerModel(
-                    input_chunk_length=1,
-                    output_chunk_length=1,
-                    activation="invalid",
-                    **tfm_kwargs,
-                )
-                model1.fit(self.series, epochs=1)
-
-            # internal activation function uses PyTorch TransformerEncoderLayer
-            model2 = TransformerModel(
-                input_chunk_length=1,
-                output_chunk_length=1,
-                activation="gelu",
-                **tfm_kwargs,
-            )
-            model2.fit(self.series, epochs=1)
-            assert isinstance(
-                model2.model.transformer.encoder.layers[0], nn.TransformerEncoderLayer
-            )
-            assert isinstance(
-                model2.model.transformer.decoder.layers[0], nn.TransformerDecoderLayer
-            )
+        # GLU variant FFN uses our custom _FeedForwardEncoderLayer
+        model3 = TransformerModel(
+            input_chunk_length=1,
+            output_chunk_length=1,
+            activation="SwiGLU",
+            **tfm_kwargs,
+        )
+        model3.fit(self.series, epochs=1)
+        assert isinstance(
+            model3.model.transformer.encoder.layers[0],
+            CustomFeedForwardEncoderLayer,
+        )
+        assert isinstance(
+            model3.model.transformer.decoder.layers[0],
+            CustomFeedForwardDecoderLayer,
+        )
 
-            # glue variant FFN uses our custom _FeedForwardEncoderLayer
-            model3 = TransformerModel(
-                input_chunk_length=1,
-                output_chunk_length=1,
-                activation="SwiGLU",
-                **tfm_kwargs,
-            )
-            model3.fit(self.series, epochs=1)
-            assert isinstance(
-                model3.model.transformer.encoder.layers[0],
-                CustomFeedForwardEncoderLayer,
-            )
-            assert isinstance(
-                model3.model.transformer.decoder.layers[0],
-                CustomFeedForwardDecoderLayer,
-            )
+    def test_layer_norm(self):
+        base_model = TransformerModel
 
-        def test_layer_norm(self):
-            base_model = TransformerModel
+        # default norm_type is None
+        model0 = base_model(input_chunk_length=1, output_chunk_length=1, **tfm_kwargs)
+        y0 = model0.fit(self.series, epochs=1)
 
-            # default norm_type is None
-            model0 = base_model(
-                input_chunk_length=1, output_chunk_length=1, **tfm_kwargs
-            )
-            y0 = model0.fit(self.series, epochs=1)
+        model1 = base_model(
+            input_chunk_length=1,
+            output_chunk_length=1,
+            norm_type="RMSNorm",
+            **tfm_kwargs,
+        )
+        y1 = model1.fit(self.series, epochs=1)
 
-            model1 = base_model(
-                input_chunk_length=1,
-                output_chunk_length=1,
-                norm_type="RMSNorm",
-                **tfm_kwargs,
-            )
-            y1 = model1.fit(self.series, epochs=1)
+        model2 = base_model(
+            input_chunk_length=1,
+            output_chunk_length=1,
+            norm_type=nn.LayerNorm,
+            **tfm_kwargs,
+        )
+        y2 = model2.fit(self.series, epochs=1)
 
-            model2 = base_model(
-                input_chunk_length=1,
-                output_chunk_length=1,
-                norm_type=nn.LayerNorm,
-                **tfm_kwargs,
-            )
-            y2 = model2.fit(self.series, epochs=1)
+        model3 = base_model(
+            input_chunk_length=1,
+            output_chunk_length=1,
+            activation="gelu",
+            norm_type="RMSNorm",
+            **tfm_kwargs,
+        )
+        y3 = model3.fit(self.series, epochs=1)
+
+        assert y0 != y1
+        assert y0 != y2
+        assert y0 != y3
+        assert y1 != y3
 
-            model3 = base_model(
+        with pytest.raises(AttributeError):
+            model4 = base_model(
                 input_chunk_length=1,
                 output_chunk_length=1,
-                activation="gelu",
-                norm_type="RMSNorm",
+                norm_type="invalid",
                 **tfm_kwargs,
             )
-            y3 = model3.fit(self.series, epochs=1)
-
-            assert y0 != y1
-            assert y0 != y2
-            assert y0 != y3
-            assert y1 != y3
-
-            with pytest.raises(AttributeError):
-                model4 = base_model(
-                    input_chunk_length=1,
-                    output_chunk_length=1,
-                    norm_type="invalid",
-                    **tfm_kwargs,
-                )
-                model4.fit(self.series, epochs=1)
+            model4.fit(self.series, epochs=1)
diff --git a/darts/tests/models/forecasting/test_tsmixer.py b/darts/tests/models/forecasting/test_tsmixer.py
index 6ae3abe39e..3a6813f8e8 100644
--- a/darts/tests/models/forecasting/test_tsmixer.py
+++ b/darts/tests/models/forecasting/test_tsmixer.py
@@ -14,18 +14,13 @@
     from darts.tests.conftest import tfm_kwargs
     from darts.utils import timeseries_generation as tg
     from darts.utils.likelihood_models import GaussianLikelihood
-
-    TORCH_AVAILABLE = True
-
 except ImportError:
-    logger.warning("Torch not available. TSMixerModel tests will be skipped.")
-    TORCH_AVAILABLE = False
+    pytest.skip(
+        f"Torch not available. {__name__} tests will be skipped.",
+        allow_module_level=True,
+    )
 
 
-@pytest.mark.skipif(
-    TORCH_AVAILABLE is False,
-    reason="Torch not available. TSMixerModel tests will be skipped.",
-)
 class TestTSMixerModel:
     np.random.seed(42)
     torch.manual_seed(42)
diff --git a/darts/tests/utils/test_likelihood_models.py b/darts/tests/utils/test_likelihood_models.py
index a7ccce76f9..6d6c3b36a6 100644
--- a/darts/tests/utils/test_likelihood_models.py
+++ b/darts/tests/utils/test_likelihood_models.py
@@ -1,5 +1,7 @@
 from itertools import combinations
 
+import pytest
+
 from darts.logging import get_logger
 
 logger = get_logger(__name__)
@@ -42,25 +44,24 @@
             BetaLikelihood(prior_alpha=0.2, prior_beta=0.4, prior_strength=0.6),
         ],
     }
-
-    TORCH_AVAILABLE = True
 except ImportError:
-    logger.warning("Torch not available. LikelihoodModels tests will be skipped.")
-    TORCH_AVAILABLE = False
+    pytest.skip(
+        f"Torch not available. {__name__} tests will be skipped.",
+        allow_module_level=True,
+    )
 
-if TORCH_AVAILABLE:
 
-    class TestLikelihoodModel:
-        def test_intra_class_equality(self):
-            for _, model_pair in likelihood_models.items():
-                assert model_pair[0] == model_pair[0]
-                assert model_pair[1] == model_pair[1]
-                assert model_pair[0] != model_pair[1]
+class TestLikelihoodModel:
+    def test_intra_class_equality(self):
+        for _, model_pair in likelihood_models.items():
+            assert model_pair[0] == model_pair[0]
+            assert model_pair[1] == model_pair[1]
+            assert model_pair[0] != model_pair[1]
 
-        def test_inter_class_equality(self):
-            model_combinations = combinations(likelihood_models.keys(), 2)
-            for first_model_name, second_model_name in model_combinations:
-                assert (
-                    likelihood_models[first_model_name][0]
-                    != likelihood_models[second_model_name][0]
-                )
+    def test_inter_class_equality(self):
+        model_combinations = combinations(likelihood_models.keys(), 2)
+        for first_model_name, second_model_name in model_combinations:
+            assert (
+                likelihood_models[first_model_name][0]
+                != likelihood_models[second_model_name][0]
+            )
diff --git a/darts/tests/utils/test_losses.py b/darts/tests/utils/test_losses.py
index 329ae45dbc..c740fe28d0 100644
--- a/darts/tests/utils/test_losses.py
+++ b/darts/tests/utils/test_losses.py
@@ -1,3 +1,5 @@
+import pytest
+
 from darts.logging import get_logger
 
 logger = get_logger(__name__)
@@ -5,55 +7,54 @@
 try:
     import torch
 
-    TORCH_AVAILABLE = True
-except ImportError:
-    logger.warning("Torch not available. Loss tests will be skipped.")
-    TORCH_AVAILABLE = False
-
-
-if TORCH_AVAILABLE:
     from darts.utils.losses import MAELoss, MapeLoss, SmapeLoss
-
-    class TestLosses:
-        x = torch.tensor([1.1, 2.2, 0.6345, -1.436])
-        y = torch.tensor([1.5, 0.5])
-
-        def helper_test_loss(self, exp_loss_val, exp_w_grad, loss_fn):
-            W = torch.tensor([[0.1, -0.2, 0.3, -0.4], [-0.8, 0.7, -0.6, 0.5]])
-            W.requires_grad = True
-            y_hat = W @ self.x
-            lval = loss_fn(y_hat, self.y)
-            lval.backward()
-
-            assert torch.allclose(lval, exp_loss_val, atol=1e-3)
-            assert torch.allclose(W.grad, exp_w_grad, atol=1e-3)
-
-        def test_smape_loss(self):
-            exp_val = torch.tensor(0.7753)
-            exp_grad = torch.tensor(
-                [
-                    [-0.2843, -0.5685, -0.1640, 0.3711],
-                    [-0.5859, -1.1718, -0.3380, 0.7649],
-                ]
-            )
-            self.helper_test_loss(exp_val, exp_grad, SmapeLoss())
-
-        def test_mape_loss(self):
-            exp_val = torch.tensor(1.2937)
-            exp_grad = torch.tensor(
-                [
-                    [-0.3667, -0.7333, -0.2115, 0.4787],
-                    [-1.1000, -2.2000, -0.6345, 1.4360],
-                ]
-            )
-            self.helper_test_loss(exp_val, exp_grad, MapeLoss())
-
-        def test_mae_loss(self):
-            exp_val = torch.tensor(1.0020)
-            exp_grad = torch.tensor(
-                [
-                    [-0.5500, -1.1000, -0.3173, 0.7180],
-                    [-0.5500, -1.1000, -0.3173, 0.7180],
-                ]
-            )
-            self.helper_test_loss(exp_val, exp_grad, MAELoss())
+except ImportError:
+    pytest.skip(
+        f"Torch not available. {__name__} tests will be skipped.",
+        allow_module_level=True,
+    )
+
+
+class TestLosses:
+    x = torch.tensor([1.1, 2.2, 0.6345, -1.436])
+    y = torch.tensor([1.5, 0.5])
+
+    def helper_test_loss(self, exp_loss_val, exp_w_grad, loss_fn):
+        W = torch.tensor([[0.1, -0.2, 0.3, -0.4], [-0.8, 0.7, -0.6, 0.5]])
+        W.requires_grad = True
+        y_hat = W @ self.x
+        lval = loss_fn(y_hat, self.y)
+        lval.backward()
+
+        assert torch.allclose(lval, exp_loss_val, atol=1e-3)
+        assert torch.allclose(W.grad, exp_w_grad, atol=1e-3)
+
+    def test_smape_loss(self):
+        exp_val = torch.tensor(0.7753)
+        exp_grad = torch.tensor(
+            [
+                [-0.2843, -0.5685, -0.1640, 0.3711],
+                [-0.5859, -1.1718, -0.3380, 0.7649],
+            ]
+        )
+        self.helper_test_loss(exp_val, exp_grad, SmapeLoss())
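+        # Hand-check of exp_val (assuming SmapeLoss computes
+        # mean(|y_hat - y| / (|y_hat| + |y|)), without the usual factor 2):
+        #     y_hat = W @ self.x ≈ [0.4348, -0.4387]
+        #     terms ≈ [1.0652 / 1.9348, 0.9387 / 0.9387] ≈ [0.5506, 1.0]
+        #     mean(terms) ≈ 0.7753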
+
+    def test_mape_loss(self):
+        exp_val = torch.tensor(1.2937)
+        exp_grad = torch.tensor(
+            [
+                [-0.3667, -0.7333, -0.2115, 0.4787],
+                [-1.1000, -2.2000, -0.6345, 1.4360],
+            ]
+        )
+        self.helper_test_loss(exp_val, exp_grad, MapeLoss())
+
+    def test_mae_loss(self):
+        exp_val = torch.tensor(1.0020)
+        exp_grad = torch.tensor(
+            [
+                [-0.5500, -1.1000, -0.3173, 0.7180],
+                [-0.5500, -1.1000, -0.3173, 0.7180],
+            ]
+        )
+        self.helper_test_loss(exp_val, exp_grad, MAELoss())
diff --git a/darts/tests/utils/test_utils_torch.py b/darts/tests/utils/test_utils_torch.py
index 05cc92dc64..c2556c2e36 100644
--- a/darts/tests/utils/test_utils_torch.py
+++ b/darts/tests/utils/test_utils_torch.py
@@ -9,110 +9,110 @@
     import torch
 
     from darts.utils.torch import random_method
-
-    TORCH_AVAILABLE = True
 except ImportError:
-    logger.warning("Torch not available. Torch utils will not be tested.")
-    TORCH_AVAILABLE = False
-
-
-if TORCH_AVAILABLE:
-    # use a simple torch model mock
-    class TorchModelMock:
-        @random_method
-        def __init__(self, some_params=None, **kwargs):
-            self.model = torch.randn(5)
-            # super().__init__()
-
-        @random_method
-        def fit(self, some_params=None):
-            self.fit_value = torch.randn(5)
-
-    class TestRandomMethod:
-        def test_it_raises_error_if_used_on_function(self):
-            with pytest.raises(ValueError):
-
-                @random_method
-                def a_random_function():
-                    pass
-
-        def test_model_is_random_by_default(self):
-            model1 = TorchModelMock()
-            model2 = TorchModelMock()
-            assert not torch.equal(model1.model, model2.model)
-
-        def test_model_is_random_when_None_random_state_specified(self):
-            model1 = TorchModelMock(random_state=None)
-            model2 = TorchModelMock(random_state=None)
-            assert not torch.equal(model1.model, model2.model)
-
-        def helper_test_reproducibility(self, model1, model2):
-            assert torch.equal(model1.model, model2.model)
-
-            model1.fit()
-            model2.fit()
-            assert torch.equal(model1.fit_value, model2.fit_value)
-
-        def test_model_is_reproducible_when_seed_specified(self):
-            model1 = TorchModelMock(random_state=42)
-            model2 = TorchModelMock(random_state=42)
-            self.helper_test_reproducibility(model1, model2)
-
-        def test_model_is_reproducible_when_random_instance_specified(self):
-            model1 = TorchModelMock(random_state=RandomState(42))
-            model2 = TorchModelMock(random_state=RandomState(42))
-            self.helper_test_reproducibility(model1, model2)
-
-        def test_model_is_different_for_different_seeds(self):
-            model1 = TorchModelMock(random_state=42)
-            model2 = TorchModelMock(random_state=43)
-            assert not torch.equal(model1.model, model2.model)
-
-        def test_model_is_different_for_different_random_instance(self):
-            model1 = TorchModelMock(random_state=RandomState(42))
-            model2 = TorchModelMock(random_state=RandomState(43))
-            assert not torch.equal(model1.model, model2.model)
-
-        def helper_test_successive_call_are_different(self, model):
-            # different between init and fit
-            model.fit()
-            assert not torch.equal(model.model, model.fit_value)
-
-            # different between 2 fit
-            old_fit_value = model.fit_value.clone()
-            model.fit()
-            assert not torch.equal(model.fit_value, old_fit_value)
-
-        def test_successive_call_to_rng_are_different_when_seed_specified(self):
-            model = TorchModelMock(random_state=42)
-            self.helper_test_successive_call_are_different(model)
-
-        def test_successive_call_to_rng_are_different_when_random_instance_specified(
-            self,
-        ):
-            model = TorchModelMock(random_state=RandomState(42))
-            self.helper_test_successive_call_are_different(model)
-
-        def test_no_side_effect_between_rng_with_seeds(self):
-            model = TorchModelMock(random_state=42)
-            model.fit()
-            fit_value = model.fit_value.clone()
-
-            model = TorchModelMock(random_state=42)
-            model2 = TorchModelMock(random_state=42)
-            model2.fit()
-            model.fit()
-
-            assert torch.equal(model.fit_value, fit_value)
-
-        def test_no_side_effect_between_rng_with_random_instance(self):
-            model = TorchModelMock(random_state=RandomState(42))
-            model.fit()
-            fit_value = model.fit_value.clone()
-
-            model = TorchModelMock(random_state=RandomState(42))
-            model2 = TorchModelMock(random_state=RandomState(42))
-            model2.fit()
-            model.fit()
-
-            assert torch.equal(model.fit_value, fit_value)
+    pytest.skip(
+        f"Torch not available. {__name__} tests will be skipped.",
+        allow_module_level=True,
+    )
+
+
+# use a simple torch model mock
+class TorchModelMock:
+    @random_method
+    def __init__(self, some_params=None, **kwargs):
+        self.model = torch.randn(5)
+
+    @random_method
+    def fit(self, some_params=None):
+        self.fit_value = torch.randn(5)
+
+
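+# Rough idea behind @random_method (a hedged sketch, not darts' actual
+# implementation): the decorator intercepts an optional `random_state` kwarg
+# on methods, keeps a per-instance numpy RandomState, and re-seeds torch from
+# it before every decorated call. That is why successive fit() calls draw
+# different values while two instances built with the same seed stay in sync:
+#
+#     def wrapper(self, *args, **kwargs):
+#         torch.manual_seed(self._random_instance.randint(0, 2**31 - 1))
+#         return method(self, *args, **kwargs)
+
+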
+class TestRandomMethod:
+    def test_it_raises_error_if_used_on_function(self):
+        with pytest.raises(ValueError):
+
+            @random_method
+            def a_random_function():
+                pass
+
+    def test_model_is_random_by_default(self):
+        model1 = TorchModelMock()
+        model2 = TorchModelMock()
+        assert not torch.equal(model1.model, model2.model)
+
+    def test_model_is_random_when_None_random_state_specified(self):
+        model1 = TorchModelMock(random_state=None)
+        model2 = TorchModelMock(random_state=None)
+        assert not torch.equal(model1.model, model2.model)
+
+    def helper_test_reproducibility(self, model1, model2):
+        assert torch.equal(model1.model, model2.model)
+
+        model1.fit()
+        model2.fit()
+        assert torch.equal(model1.fit_value, model2.fit_value)
+
+    def test_model_is_reproducible_when_seed_specified(self):
+        model1 = TorchModelMock(random_state=42)
+        model2 = TorchModelMock(random_state=42)
+        self.helper_test_reproducibility(model1, model2)
+
+    def test_model_is_reproducible_when_random_instance_specified(self):
+        model1 = TorchModelMock(random_state=RandomState(42))
+        model2 = TorchModelMock(random_state=RandomState(42))
+        self.helper_test_reproducibility(model1, model2)
+
+    def test_model_is_different_for_different_seeds(self):
+        model1 = TorchModelMock(random_state=42)
+        model2 = TorchModelMock(random_state=43)
+        assert not torch.equal(model1.model, model2.model)
+
+    def test_model_is_different_for_different_random_instance(self):
+        model1 = TorchModelMock(random_state=RandomState(42))
+        model2 = TorchModelMock(random_state=RandomState(43))
+        assert not torch.equal(model1.model, model2.model)
+
+    def helper_test_successive_call_are_different(self, model):
+        # different between init and fit
+        model.fit()
+        assert not torch.equal(model.model, model.fit_value)
+
+        # different between 2 fit
+        old_fit_value = model.fit_value.clone()
+        model.fit()
+        assert not torch.equal(model.fit_value, old_fit_value)
+
+    def test_successive_call_to_rng_are_different_when_seed_specified(self):
+        model = TorchModelMock(random_state=42)
+        self.helper_test_successive_call_are_different(model)
+
+    def test_successive_call_to_rng_are_different_when_random_instance_specified(
+        self,
+    ):
+        model = TorchModelMock(random_state=RandomState(42))
+        self.helper_test_successive_call_are_different(model)
+
+    def test_no_side_effect_between_rng_with_seeds(self):
+        model = TorchModelMock(random_state=42)
+        model.fit()
+        fit_value = model.fit_value.clone()
+
+        model = TorchModelMock(random_state=42)
+        model2 = TorchModelMock(random_state=42)
+        model2.fit()
+        model.fit()
+
+        assert torch.equal(model.fit_value, fit_value)
+
+    def test_no_side_effect_between_rng_with_random_instance(self):
+        model = TorchModelMock(random_state=RandomState(42))
+        model.fit()
+        fit_value = model.fit_value.clone()
+
+        model = TorchModelMock(random_state=RandomState(42))
+        model2 = TorchModelMock(random_state=RandomState(42))
+        model2.fit()
+        model.fit()
+
+        assert torch.equal(model.fit_value, fit_value)
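
For reference, every torch test module touched by this patch now follows the same shape: the torch-dependent imports sit in a try/except, and the except branch skips the whole module at collection time rather than guarding the test classes behind a TORCH_AVAILABLE flag. A minimal sketch of the pattern, as applied above:

    import pytest

    try:
        import torch
    except ImportError:
        # skip every test in this module when torch is not installed;
        # allow_module_level lets pytest.skip() be called at import time
        pytest.skip(
            f"Torch not available. {__name__} tests will be skipped.",
            allow_module_level=True,
        )

Dropping the old guard also removes one indentation level from the test classes, which accounts for most of the re-indentation visible throughout this diff.
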
diff --git a/docs/source/examples.rst b/docs/source/examples.rst
index 9fd96c177a..ea71fbaa8d 100644
--- a/docs/source/examples.rst
+++ b/docs/source/examples.rst
@@ -185,7 +185,7 @@ TSMixer model example notebook:
 .. toctree::
    :maxdepth: 1
 
-   21-TSMixer-examples.ipynb
+   examples/21-TSMixer-examples.ipynb
 
 Ensemble Models
 =============================
diff --git a/examples/21-TSMixer-examples.ipynb b/examples/21-TSMixer-examples.ipynb
index 1d1735f909..ff23323f26 100644
--- a/examples/21-TSMixer-examples.ipynb
+++ b/examples/21-TSMixer-examples.ipynb
@@ -359,15 +359,10 @@
     "A few interesting things about these parameters:\n",
     "\n",
     "- **Gradient clipping:** Mitigates exploding gradients during backpropagation by setting an upper limit on the gradient for a batch.\n",
-    "\n",
     "- **Learning rate:** The majority of the learning done by a model is in the earlier epochs. As training goes on it is often helpful to reduce the learning rate to fine-tune the model. That being said, it can also lead to significant overfitting.\n",
-    "\n",
     "- **Early stopping:** To avoid overfitting, we can use early stopping. It monitors a metric on the validation set and stops training once the metric is not improving anymore based on a custom condition.\n",
-    "\n",
     "- **Likelihood and Loss Functions:** You can either make the model probabilistic with a `likelihood`, or deterministic with a `loss_fn`. In this notebook we train probabilistic models using QuantileRegression.\n",
-    "\n",
     "- **Reversible Instance Normalization:** Use [Reversible Instance Normalization](https://openreview.net/forum?id=cGDAkQo1C0p) which in most of the cases improves model performance.\n",
-    "\n",
     "- **Encoders:** We can encode time axis/calendar information and use them as past or future covariates using `add_encoders`. Here, we'll add cyclic encodings of the hour, day of the week, and month as future covariates"
    ]
   },
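
The bullet list in the cell above maps directly onto constructor arguments of the darts torch models. As a rough sketch of how these options fit together (the window sizes, patience, and scheduler values below are illustrative placeholders, not the notebook's actual configuration):

    import torch
    from pytorch_lightning.callbacks import EarlyStopping

    from darts.models import TSMixerModel
    from darts.utils.likelihood_models import QuantileRegression

    model = TSMixerModel(
        input_chunk_length=7 * 24,          # placeholder lookback window
        output_chunk_length=24,             # placeholder forecast horizon
        use_reversible_instance_norm=True,  # Reversible Instance Normalization
        likelihood=QuantileRegression(),    # probabilistic via QuantileRegression
        add_encoders={                      # cyclic calendar encodings used as
            "cyclic": {"future": ["hour", "dayofweek", "month"]}
        },                                  # future covariates
        pl_trainer_kwargs={
            "gradient_clip_val": 1,         # gradient clipping
            "callbacks": [EarlyStopping(monitor="val_loss", patience=5)],
        },
        lr_scheduler_cls=torch.optim.lr_scheduler.ExponentialLR,
        lr_scheduler_kwargs={"gamma": 0.999},  # reduce the learning rate over time
    )
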
@@ -573,11 +568,12 @@
     "# Backtest the probabilistic models\n",
     "\n",
     "Let's configure the prediction. For this example, we will:\n",
+    "\n",
     "- generate **historical forecasts** on the test set using the **pre-trained models**. Each forecast covers a 24 hour horizon, and the time between two consecutive forecasts is also 24 hours. This will give us **276 multivariate forecasts per transformer** to evaluate the model!\n",
     "- generate **500 stochastic samples** for each prediction point (since we have trained probabilistic models)\n",
     "- evaluate/**backtest** the probabilistic historical forecasts for some quantiles **using the Mean Quantile Loss** (`mql()`).\n",
     "\n",
-    "And we'll create some helper functions to generating the forecasts, computing the backtest, and to visualize the predictions."
+    "And we'll create some helper functions to generate the forecasts, compute the backtest, and to visualize the predictions."
    ]
   },
   {
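
The backtest configuration described in this cell boils down to one historical_forecasts() call per transformer plus a quantile-loss evaluation. A hedged sketch, assuming a pre-trained model and a test_series placeholder:

    from darts.metrics import mql

    # one 24 h forecast every 24 h, with 500 stochastic samples each
    hfcs = model.historical_forecasts(
        series=test_series,
        forecast_horizon=24,     # each forecast covers 24 hours
        stride=24,               # 24 hours between consecutive forecasts
        num_samples=500,         # stochastic samples per prediction point
        last_points_only=False,  # keep each full 24-hour forecast
        retrain=False,           # reuse the pre-trained model
    )

    # evaluate the forecasts at one quantile with the Mean Quantile Loss
    score = model.backtest(
        series=test_series,
        historical_forecasts=hfcs,
        metric=mql,
        metric_kwargs={"q": 0.5},
    )
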
@@ -596,7 +592,7 @@
     "\n",
     "def historical_forecasts(model):\n",
     "    \"\"\"Generates probabilistic historical forecasts for each transformer\n",
-    "    and returns the inverse transforms results.\n",
+    "    and returns the inverse transformed results.\n",
     "\n",
     "    Each forecast covers 24h (forecast_horizon). The time between two forecasts\n",
     "    (stride) is also 24 hours.\n",
@@ -987,6 +983,7 @@
     "In this case, `TSMixer` and `TiDEModel` both perform similarly well. Keep in mind that we performed only partial training on the data, and that we used the default model parameters without any hyperparameter tuning. \n",
     "\n",
     "Here are some ways to further improve the performance:\n",
+    "\n",
     "- set `full_training=True`\n",
     "- perform hyperparmaeter tuning\n",
     "- add more covariates (we have only added cyclic encodings of calendar information)\n",