Skip to content

Commit

Permalink
Split checker and codemod tests into individual test cases via metafunc
Browse files Browse the repository at this point in the history
Few pathlib cleanups
  • Loading branch information
sbrugman committed Sep 3, 2024
1 parent 555bb8d commit 955670b
Showing 1 changed file with 43 additions and 35 deletions.
78 changes: 43 additions & 35 deletions tests/test_torchfix.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,68 +16,76 @@
LOGGER = logging.getLogger(__name__)


def pytest_generate_tests(metafunc):
# Dynamically generate test cases from paths
if "checker_source_path" in metafunc.fixturenames:
files = list(FIXTURES_PATH.glob("**/checker/*.py"))
metafunc.parametrize(
"checker_source_path", files, ids=[file_name.stem for file_name in files]
)
if "codemod_source_path" in metafunc.fixturenames:
files = list(FIXTURES_PATH.glob("**/codemod/*.py"))
metafunc.parametrize(
"codemod_source_path", files, ids=[file_name.stem for file_name in files]
)
if "case" in metafunc.fixturenames:
exclude_set = expand_error_codes(tuple(DISABLED_BY_DEFAULT))
cases = [
("ALL", GET_ALL_ERROR_CODES()),
("ALL,TOR102", GET_ALL_ERROR_CODES()),
("TOR102", {"TOR102"}),
("TOR102,TOR101", {"TOR102", "TOR101"}),
("TOR1,TOR102", {"TOR102", "TOR101", "TOR103", "TOR104", "TOR105"}),
(None, set(GET_ALL_ERROR_CODES()) - exclude_set),
]
metafunc.parametrize("case,expected", cases, ids=[case for case, _ in cases])


def _checker_results(s):
checker = TorchChecker(None, s)
return [f"{line}:{col} {msg}" for line, col, msg, _ in checker.run()]


def _codemod_results(source_path):
with open(source_path) as source:
code = source.read()
def _codemod_results(source_path: Path):
code = source_path.read_text()
config = TorchCodemodConfig(select=list(GET_ALL_ERROR_CODES()))
context = TorchCodemod(codemod.CodemodContext(filename=source_path), config)
context = TorchCodemod(codemod.CodemodContext(filename=str(source_path)), config)
new_module = codemod.transform_module(context, code)
if isinstance(new_module, codemod.TransformSuccess):
return new_module.code
elif isinstance(new_module, codemod.TransformFailure):
if isinstance(new_module, codemod.TransformFailure):
raise new_module.error


def test_empty():
assert _checker_results([""]) == []


def test_checker_fixtures():
for source_path in FIXTURES_PATH.glob("**/checker/*.py"):
LOGGER.info("Testing %s", source_path.relative_to(Path.cwd()))
expected_path = str(source_path)[:-2] + "txt"
expected_results = []
with open(expected_path) as expected:
for line in expected:
expected_results.append(line.rstrip())
def test_checker_fixtures(checker_source_path: Path):
expected_path = checker_source_path.with_suffix(".txt")
expected_results = expected_path.read_text().splitlines()

with open(source_path) as source:
assert _checker_results(source.readlines()) == expected_results
assert (
_checker_results(checker_source_path.read_text().splitlines(keepends=True))
== expected_results
)


def test_codemod_fixtures():
for source_path in FIXTURES_PATH.glob("**/codemod/*.py"):
LOGGER.info("Testing %s", source_path.relative_to(Path.cwd()))
expected_path = source_path.with_suffix(".py.out")
expected_results = expected_path.read_text()
assert _codemod_results(source_path) == expected_results
def test_codemod_fixtures(codemod_source_path: Path):
expected_path = codemod_source_path.with_suffix(".py.out")
expected_results = expected_path.read_text()
assert _codemod_results(codemod_source_path) == expected_results


def test_errorcodes_distinct():
visitors = GET_ALL_VISITORS()
seen = set()
for visitor in visitors:
LOGGER.info("Checking error code for %s", visitor.__class__.__name__)
errors = visitor.ERRORS
for e in errors:
for e in visitor.ERRORS:
assert e not in seen
seen.add(e)


def test_parse_error_code_str():
exclude_set = expand_error_codes(tuple(DISABLED_BY_DEFAULT))
cases = [
("ALL", GET_ALL_ERROR_CODES()),
("ALL,TOR102", GET_ALL_ERROR_CODES()),
("TOR102", {"TOR102"}),
("TOR102,TOR101", {"TOR102", "TOR101"}),
("TOR1,TOR102", {"TOR102", "TOR101", "TOR103", "TOR104", "TOR105"}),
(None, set(GET_ALL_ERROR_CODES()) - exclude_set),
]
for case, expected in cases:
assert expected == process_error_code_str(case)
def test_parse_error_code_str(case, expected):
assert process_error_code_str(case) == expected

0 comments on commit 955670b

Please sign in to comment.