diff --git a/tests/test_torchfix.py b/tests/test_torchfix.py index 7b79fb7..882a49e 100644 --- a/tests/test_torchfix.py +++ b/tests/test_torchfix.py @@ -1,16 +1,18 @@ +import logging +import subprocess from pathlib import Path + +import libcst.codemod as codemod from torchfix.torchfix import ( - TorchChecker, - TorchCodemod, - TorchCodemodConfig, DISABLED_BY_DEFAULT, expand_error_codes, - GET_ALL_VISITORS, GET_ALL_ERROR_CODES, + GET_ALL_VISITORS, process_error_code_str, + TorchChecker, + TorchCodemod, + TorchCodemodConfig, ) -import logging -import libcst.codemod as codemod FIXTURES_PATH = Path(__file__).absolute().parent / "fixtures" LOGGER = logging.getLogger(__name__) @@ -103,3 +105,23 @@ def test_errorcodes_distinct(): def test_parse_error_code_str(case, expected): assert process_error_code_str(case) == expected + + +def test_no_python_files(tmp_path): + # Create a temporary directory with no Python files + non_python_file = tmp_path / "not_a_python_file.txt" + non_python_file.write_text("This is not a Python file") + + # Run torchfix on the temporary directory + # TODO: Fix this. This will not run the test on current build + result = subprocess.run( + ["torchfix", str(tmp_path)], + capture_output=True, + text=True, + ) + + # Check that the script exits successfully + assert result.returncode == 0 + + # Check that the correct message is printed + assert "No Python files with torch imports found." in result.stderr diff --git a/torchfix/__main__.py b/torchfix/__main__.py index eb17658..781b8d1 100644 --- a/torchfix/__main__.py +++ b/torchfix/__main__.py @@ -1,20 +1,22 @@ import argparse -import libcst.codemod as codemod import contextlib import ctypes -import sys import io +import sys + +import libcst.codemod as codemod + +from .common import CYAN, ENDC from .torchfix import ( - TorchCodemod, - TorchCodemodConfig, __version__ as TorchFixVersion, DISABLED_BY_DEFAULT, GET_ALL_ERROR_CODES, process_error_code_str, + TorchCodemod, + TorchCodemodConfig, ) -from .common import CYAN, ENDC # Should get rid of this code eventually. @@ -83,7 +85,6 @@ def _parse_args() -> argparse.Namespace: def main() -> None: args = _parse_args() - files = codemod.gather_files(args.path) # Filter out files that don't have "torch" string in them. @@ -97,6 +98,9 @@ def main() -> None: torch_files.append(file) break + if not torch_files: + print("No Python files with torch imports found.", file=sys.stderr) + return config = TorchCodemodConfig() config.select = list(process_error_code_str(args.select)) command_instance = TorchCodemod(codemod.CodemodContext(), config)