Skip to content

Commit

Permalink
Return without error when there are no valid python files
Browse files Browse the repository at this point in the history
  • Loading branch information
zleman1593 committed Nov 11, 2024
1 parent 87289c1 commit 512d520
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 12 deletions.
34 changes: 28 additions & 6 deletions tests/test_torchfix.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,18 @@
import logging
import subprocess
from pathlib import Path

import libcst.codemod as codemod
from torchfix.torchfix import (
TorchChecker,
TorchCodemod,
TorchCodemodConfig,
DISABLED_BY_DEFAULT,
expand_error_codes,
GET_ALL_VISITORS,
GET_ALL_ERROR_CODES,
GET_ALL_VISITORS,
process_error_code_str,
TorchChecker,
TorchCodemod,
TorchCodemodConfig,
)
import logging
import libcst.codemod as codemod

FIXTURES_PATH = Path(__file__).absolute().parent / "fixtures"
LOGGER = logging.getLogger(__name__)
Expand Down Expand Up @@ -103,3 +105,23 @@ def test_errorcodes_distinct():

def test_parse_error_code_str(case, expected):
assert process_error_code_str(case) == expected


def test_no_python_files(tmp_path):
# Create a temporary directory with no Python files
non_python_file = tmp_path / "not_a_python_file.txt"
non_python_file.write_text("This is not a Python file")

# Run torchfix on the temporary directory
# TODO: Fix this. This will not run the test on current build
result = subprocess.run(
["torchfix", str(tmp_path)],
capture_output=True,
text=True,
)

# Check that the script exits successfully
assert result.returncode == 0

# Check that the correct message is printed
assert "No Python files with torch imports found." in result.stderr
16 changes: 10 additions & 6 deletions torchfix/__main__.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,22 @@
import argparse
import libcst.codemod as codemod

import contextlib
import ctypes
import sys
import io
import sys

import libcst.codemod as codemod

from .common import CYAN, ENDC

from .torchfix import (
TorchCodemod,
TorchCodemodConfig,
__version__ as TorchFixVersion,
DISABLED_BY_DEFAULT,
GET_ALL_ERROR_CODES,
process_error_code_str,
TorchCodemod,
TorchCodemodConfig,
)
from .common import CYAN, ENDC


# Should get rid of this code eventually.
Expand Down Expand Up @@ -83,7 +85,6 @@ def _parse_args() -> argparse.Namespace:

def main() -> None:
args = _parse_args()

files = codemod.gather_files(args.path)

# Filter out files that don't have "torch" string in them.
Expand All @@ -97,6 +98,9 @@ def main() -> None:
torch_files.append(file)
break

if not torch_files:
print("No Python files with torch imports found.", file=sys.stderr)
return
config = TorchCodemodConfig()
config.select = list(process_error_code_str(args.select))
command_instance = TorchCodemod(codemod.CodemodContext(), config)
Expand Down

0 comments on commit 512d520

Please sign in to comment.