Skip to content

Commit

Permalink
fix: remove test dispatcher functionality from test_translations
Browse files Browse the repository at this point in the history
  • Loading branch information
Sam-Armstrong committed Jan 7, 2025
1 parent fdcc7fe commit ab673a9
Showing 1 changed file with 10 additions and 36 deletions.
46 changes: 10 additions & 36 deletions ivy_tests/test_transpiler/translations/test_translations.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,37 +10,15 @@
import os


def get_test_list():
# Add tests here
return [
"test_translate_torch_simple_transform",
"test_translate_torch_alexnet",
"test_translate_torch_unet",
"test_translate_RelaxedBernoulli",
"test_translate_torch_swin2sr",
"test_translate_torch_distilled_vit",
"test_translate_torch_cflow",
"test_translate_torch_glpdepth",
"test_translate_torch_lstm",
"test_translate_torch_inplace",
]


def get_target_list():
# Add targets here
return ["tensorflow", "jax"]


@pytest.mark.parametrize("target", get_target_list())
@pytest.mark.parametrize("test_name", get_test_list())
def test_dispatcher(target, test_name):
# This function will dispatch to the appropriate test based on the name
globals()[test_name](target)


# Note: Keep this test at the top of the file
# to simulate that no transpilation has
# taken place in the outputs directory
@pytest.mark.parametrize("target", get_target_list())
def test_translate_torch_simple_transform(target):
ivy.set_backend(target)
from ivy.transpiler.examples.SimpleModelNoConv.s2s_simplemodel import (
Expand Down Expand Up @@ -91,6 +69,7 @@ def test_translate_torch_simple_transform(target):
assert np.allclose(out.detach().numpy(), ivy.to_numpy(cloned_out), atol=1e-2)


@pytest.mark.parametrize("target", get_target_list())
def test_translate_torch_alexnet(target):
ivy.set_backend(target)
from ivy.transpiler.examples.AlexNet.s2s_alexnet import AlexNet
Expand Down Expand Up @@ -138,6 +117,7 @@ def test_translate_torch_alexnet(target):
assert np.allclose(out.detach().numpy(), ivy.to_numpy(cloned_out), atol=0.05)


@pytest.mark.parametrize("target", get_target_list())
def test_translate_torch_unet(target):
ivy.set_backend(target)
from ivy.transpiler.examples.UNet.s2s_unet import UNet
Expand Down Expand Up @@ -183,6 +163,7 @@ def test_translate_torch_unet(target):
assert np.allclose(out.detach().numpy(), ivy.to_numpy(cloned_out), atol=1e-3)


@pytest.mark.parametrize("target", get_target_list())
def test_translate_RelaxedBernoulli(target):
ivy.set_backend(target)
from torch.distributions import RelaxedBernoulli
Expand Down Expand Up @@ -214,6 +195,7 @@ def test_translate_RelaxedBernoulli(target):
assert np.allclose(pt_sample, translated_sample, atol=1e-3)


@pytest.mark.parametrize("target", get_target_list())
def test_translate_torch_swin2sr(target):
ivy.set_backend(target)
from ivy.transpiler.examples.Swin2SR.s2s_swin2sr import Swin2SR
Expand Down Expand Up @@ -302,6 +284,7 @@ def test_translate_torch_swin2sr(target):
# to torch frontend Tensor class


@pytest.mark.parametrize("target", get_target_list())
def test_translate_torch_distilled_vit(target):
ivy.set_backend(target)
from ivy.transpiler.examples.DistilledVisionTransformer.s2s_distilledvit import (
Expand Down Expand Up @@ -390,6 +373,7 @@ def test_translate_torch_distilled_vit(target):
assert np.allclose(out.detach().numpy(), ivy.to_numpy(cloned_out), atol=1e-3)


@pytest.mark.parametrize("target", get_target_list())
def test_translate_torch_cflow(target):
ivy.set_backend(target)
from ivy.transpiler.examples.CFLow.helpers import get_args
Expand Down Expand Up @@ -499,6 +483,7 @@ def test_translate_torch_cflow(target):
# )


@pytest.mark.parametrize("target", get_target_list())
def test_translate_torch_glpdepth(target):
ivy.set_backend(target)
from ivy.transpiler.examples.GLPDepth.s2s_glpdepth import GLPDepth
Expand Down Expand Up @@ -549,6 +534,7 @@ def test_translate_torch_glpdepth(target):
# in the translated model - though the test still passes


@pytest.mark.parametrize("target", get_target_list())
def test_translate_torch_lstm(target):
ivy.set_backend(target)

Expand Down Expand Up @@ -577,6 +563,7 @@ def test_translate_torch_lstm(target):
)


@pytest.mark.parametrize("target", get_target_list())
def test_translate_torch_inplace(target):
def inplace_fn():
M = torch.zeros(4, 16)
Expand All @@ -593,16 +580,3 @@ def inplace_fn():
pt_logts = pt_out.detach().numpy()
translated_logts = ivy.to_numpy(translated_out)
assert np.allclose(pt_logts, translated_logts)


if __name__ == "__main__":
# This allows us to run specific tests when called from command line
test_name = os.environ.get("TEST_NAME", "all")
target = os.environ.get("TARGET", "all")

if test_name != "all":
pytest.main([__file__, f"-k {test_name}"])
elif target != "all":
pytest.main([__file__, f"-k test_dispatcher and {target}"])
else:
pytest.main([__file__])

0 comments on commit ab673a9

Please sign in to comment.