diff --git a/tests/test_torchfix.py b/tests/test_torchfix.py index cd9b74c..6e7e0c6 100644 --- a/tests/test_torchfix.py +++ b/tests/test_torchfix.py @@ -27,6 +27,8 @@ def _codemod_results(source_path): config = TorchCodemodConfig(select=list(GET_ALL_ERROR_CODES())) context = TorchCodemod(codemod.CodemodContext(filename=source_path), config) new_module = codemod.transform_module(context, code) + if isinstance(new_module, codemod.TransformFailure): + raise new_module.error return new_module.code