diff --git a/ivy_tests/test_transpiler/translations/test_translations.py b/ivy_tests/test_transpiler/translations/test_translations.py index 7a22adb734cd..f7428a594fd7 100644 --- a/ivy_tests/test_transpiler/translations/test_translations.py +++ b/ivy_tests/test_transpiler/translations/test_translations.py @@ -10,37 +10,15 @@ import os -def get_test_list(): - # Add tests here - return [ - "test_translate_torch_simple_transform", - "test_translate_torch_alexnet", - "test_translate_torch_unet", - "test_translate_RelaxedBernoulli", - "test_translate_torch_swin2sr", - "test_translate_torch_distilled_vit", - "test_translate_torch_cflow", - "test_translate_torch_glpdepth", - "test_translate_torch_lstm", - "test_translate_torch_inplace", - ] - - def get_target_list(): # Add targets here return ["tensorflow", "jax"] -@pytest.mark.parametrize("target", get_target_list()) -@pytest.mark.parametrize("test_name", get_test_list()) -def test_dispatcher(target, test_name): - # This function will dispatch to the appropriate test based on the name - globals()[test_name](target) - - # Note: Keep this test at the top of the file # to simulate that no transpilation has # taken place in the outputs directory +@pytest.mark.parametrize("target", get_target_list()) def test_translate_torch_simple_transform(target): ivy.set_backend(target) from ivy.transpiler.examples.SimpleModelNoConv.s2s_simplemodel import ( @@ -91,6 +69,7 @@ def test_translate_torch_simple_transform(target): assert np.allclose(out.detach().numpy(), ivy.to_numpy(cloned_out), atol=1e-2) +@pytest.mark.parametrize("target", get_target_list()) def test_translate_torch_alexnet(target): ivy.set_backend(target) from ivy.transpiler.examples.AlexNet.s2s_alexnet import AlexNet @@ -138,6 +117,7 @@ def test_translate_torch_alexnet(target): assert np.allclose(out.detach().numpy(), ivy.to_numpy(cloned_out), atol=0.05) +@pytest.mark.parametrize("target", get_target_list()) def test_translate_torch_unet(target): ivy.set_backend(target) from ivy.transpiler.examples.UNet.s2s_unet import UNet @@ -183,6 +163,7 @@ def test_translate_torch_unet(target): assert np.allclose(out.detach().numpy(), ivy.to_numpy(cloned_out), atol=1e-3) +@pytest.mark.parametrize("target", get_target_list()) def test_translate_RelaxedBernoulli(target): ivy.set_backend(target) from torch.distributions import RelaxedBernoulli @@ -214,6 +195,7 @@ def test_translate_RelaxedBernoulli(target): assert np.allclose(pt_sample, translated_sample, atol=1e-3) +@pytest.mark.parametrize("target", get_target_list()) def test_translate_torch_swin2sr(target): ivy.set_backend(target) from ivy.transpiler.examples.Swin2SR.s2s_swin2sr import Swin2SR @@ -302,6 +284,7 @@ def test_translate_torch_swin2sr(target): # to torch frontend Tensor class +@pytest.mark.parametrize("target", get_target_list()) def test_translate_torch_distilled_vit(target): ivy.set_backend(target) from ivy.transpiler.examples.DistilledVisionTransformer.s2s_distilledvit import ( @@ -390,6 +373,7 @@ def test_translate_torch_distilled_vit(target): assert np.allclose(out.detach().numpy(), ivy.to_numpy(cloned_out), atol=1e-3) +@pytest.mark.parametrize("target", get_target_list()) def test_translate_torch_cflow(target): ivy.set_backend(target) from ivy.transpiler.examples.CFLow.helpers import get_args @@ -499,6 +483,7 @@ def test_translate_torch_cflow(target): # ) +@pytest.mark.parametrize("target", get_target_list()) def test_translate_torch_glpdepth(target): ivy.set_backend(target) from ivy.transpiler.examples.GLPDepth.s2s_glpdepth import GLPDepth @@ -549,6 +534,7 @@ def test_translate_torch_glpdepth(target): # in the translated model - though the test still passes +@pytest.mark.parametrize("target", get_target_list()) def test_translate_torch_lstm(target): ivy.set_backend(target) @@ -577,6 +563,7 @@ def test_translate_torch_lstm(target): ) +@pytest.mark.parametrize("target", get_target_list()) def test_translate_torch_inplace(target): def inplace_fn(): M = torch.zeros(4, 16) @@ -593,16 +580,3 @@ def inplace_fn(): pt_logts = pt_out.detach().numpy() translated_logts = ivy.to_numpy(translated_out) assert np.allclose(pt_logts, translated_logts) - - -if __name__ == "__main__": - # This allows us to run specific tests when called from command line - test_name = os.environ.get("TEST_NAME", "all") - target = os.environ.get("TARGET", "all") - - if test_name != "all": - pytest.main([__file__, f"-k {test_name}"]) - elif target != "all": - pytest.main([__file__, f"-k test_dispatcher and {target}"]) - else: - pytest.main([__file__])