Skip to content

Commit

Permalink
use pytest to skip torch tests (#2307)
Browse files Browse the repository at this point in the history
* use pytest to skip torch tests

* fix some mistakes in tsmixer notebook
  • Loading branch information
dennisbader authored Apr 9, 2024
1 parent 0d5c722 commit cdff09a
Show file tree
Hide file tree
Showing 21 changed files with 5,993 additions and 6,104 deletions.
2,591 changes: 1,290 additions & 1,301 deletions darts/tests/datasets/test_datasets.py

Large diffs are not rendered by default.

855 changes: 418 additions & 437 deletions darts/tests/explainability/test_tft_explainer.py

Large diffs are not rendered by default.

31 changes: 16 additions & 15 deletions darts/tests/models/components/glu_variants.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,27 @@
import pytest

from darts.logging import get_logger

logger = get_logger(__name__)

try:
import torch

TORCH_AVAILABLE = True
except ImportError:
logger.warning("Torch not available. Loss tests will be skipped.")
TORCH_AVAILABLE = False


if TORCH_AVAILABLE:
from darts.models.components import glu_variants
from darts.models.components.glu_variants import GLU_FFN
except ImportError:
pytest.skip(
f"Torch not available. {__name__} tests will be skipped.",
allow_module_level=True,
)


class TestFFN:
def test_ffn(self):
for FeedForward_network in GLU_FFN:
self.feed_forward_block = getattr(glu_variants, FeedForward_network)(
d_model=4, d_ff=16, dropout=0.1
)
class TestFFN:
def test_ffn(self):
for FeedForward_network in GLU_FFN:
self.feed_forward_block = getattr(glu_variants, FeedForward_network)(
d_model=4, d_ff=16, dropout=0.1
)

inputs = torch.zeros(1, 4, 4)
self.feed_forward_block(x=inputs)
inputs = torch.zeros(1, 4, 4)
self.feed_forward_block(x=inputs)
57 changes: 28 additions & 29 deletions darts/tests/models/components/test_layer_norm_variants.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,46 +8,45 @@
try:
import torch

TORCH_AVAILABLE = True
except ImportError:
logger.warning("Torch not available. Loss tests will be skipped.")
TORCH_AVAILABLE = False


if TORCH_AVAILABLE:
from darts.models.components.layer_norm_variants import (
LayerNorm,
LayerNormNoBias,
RINorm,
RMSNorm,
)
except ImportError:
pytest.skip(
f"Torch not available. {__name__} tests will be skipped.",
allow_module_level=True,
)


class TestLayerNormVariants:
def test_lnv(self):
for layer_norm in [RMSNorm, LayerNorm, LayerNormNoBias]:
ln = layer_norm(4)
inputs = torch.zeros(1, 4, 4)
ln(inputs)
class TestLayerNormVariants:
def test_lnv(self):
for layer_norm in [RMSNorm, LayerNorm, LayerNormNoBias]:
ln = layer_norm(4)
inputs = torch.zeros(1, 4, 4)
ln(inputs)

def test_rin(self):
def test_rin(self):

np.random.seed(42)
torch.manual_seed(42)
np.random.seed(42)
torch.manual_seed(42)

x = torch.randn(3, 4, 7)
affine_options = [True, False]
x = torch.randn(3, 4, 7)
affine_options = [True, False]

# test with and without affine and correct input dim
for affine in affine_options:
# test with and without affine and correct input dim
for affine in affine_options:

rin = RINorm(input_dim=7, affine=affine)
x_norm = rin(x)
rin = RINorm(input_dim=7, affine=affine)
x_norm = rin(x)

# expand dims to simulate probablistic forecasting
x_denorm = rin.inverse(x_norm.view(x_norm.shape + (1,))).squeeze(-1)
assert torch.all(torch.isclose(x, x_denorm)).item()
# expand dims to simulate probablistic forecasting
x_denorm = rin.inverse(x_norm.view(x_norm.shape + (1,))).squeeze(-1)
assert torch.all(torch.isclose(x, x_denorm)).item()

# try invalid input_dim
rin = RINorm(input_dim=3, affine=True)
with pytest.raises(RuntimeError):
x_norm = rin(x)
# try invalid input_dim
rin = RINorm(input_dim=3, affine=True)
with pytest.raises(RuntimeError):
x_norm = rin(x)
Loading

0 comments on commit cdff09a

Please sign in to comment.