From 955670b44427937a9617f327d4d55681d8dfe919 Mon Sep 17 00:00:00 2001 From: Simon Brugman Date: Mon, 19 Aug 2024 12:23:40 +0200 Subject: [PATCH] Split checker and codemod tests into individual test cases via metafunc Few pathlib cleanups --- tests/test_torchfix.py | 78 +++++++++++++++++++++++------------------- 1 file changed, 43 insertions(+), 35 deletions(-) diff --git a/tests/test_torchfix.py b/tests/test_torchfix.py index 17ff183..9a4896f 100644 --- a/tests/test_torchfix.py +++ b/tests/test_torchfix.py @@ -16,20 +16,44 @@ LOGGER = logging.getLogger(__name__) +def pytest_generate_tests(metafunc): + # Dynamically generate test cases from paths + if "checker_source_path" in metafunc.fixturenames: + files = list(FIXTURES_PATH.glob("**/checker/*.py")) + metafunc.parametrize( + "checker_source_path", files, ids=[file_name.stem for file_name in files] + ) + if "codemod_source_path" in metafunc.fixturenames: + files = list(FIXTURES_PATH.glob("**/codemod/*.py")) + metafunc.parametrize( + "codemod_source_path", files, ids=[file_name.stem for file_name in files] + ) + if "case" in metafunc.fixturenames: + exclude_set = expand_error_codes(tuple(DISABLED_BY_DEFAULT)) + cases = [ + ("ALL", GET_ALL_ERROR_CODES()), + ("ALL,TOR102", GET_ALL_ERROR_CODES()), + ("TOR102", {"TOR102"}), + ("TOR102,TOR101", {"TOR102", "TOR101"}), + ("TOR1,TOR102", {"TOR102", "TOR101", "TOR103", "TOR104", "TOR105"}), + (None, set(GET_ALL_ERROR_CODES()) - exclude_set), + ] + metafunc.parametrize("case,expected", cases, ids=[case for case, _ in cases]) + + def _checker_results(s): checker = TorchChecker(None, s) return [f"{line}:{col} {msg}" for line, col, msg, _ in checker.run()] -def _codemod_results(source_path): - with open(source_path) as source: - code = source.read() +def _codemod_results(source_path: Path): + code = source_path.read_text() config = TorchCodemodConfig(select=list(GET_ALL_ERROR_CODES())) - context = TorchCodemod(codemod.CodemodContext(filename=source_path), config) + context = TorchCodemod(codemod.CodemodContext(filename=str(source_path)), config) new_module = codemod.transform_module(context, code) if isinstance(new_module, codemod.TransformSuccess): return new_module.code - elif isinstance(new_module, codemod.TransformFailure): + if isinstance(new_module, codemod.TransformFailure): raise new_module.error @@ -37,25 +61,20 @@ def test_empty(): assert _checker_results([""]) == [] -def test_checker_fixtures(): - for source_path in FIXTURES_PATH.glob("**/checker/*.py"): - LOGGER.info("Testing %s", source_path.relative_to(Path.cwd())) - expected_path = str(source_path)[:-2] + "txt" - expected_results = [] - with open(expected_path) as expected: - for line in expected: - expected_results.append(line.rstrip()) +def test_checker_fixtures(checker_source_path: Path): + expected_path = checker_source_path.with_suffix(".txt") + expected_results = expected_path.read_text().splitlines() - with open(source_path) as source: - assert _checker_results(source.readlines()) == expected_results + assert ( + _checker_results(checker_source_path.read_text().splitlines(keepends=True)) + == expected_results + ) -def test_codemod_fixtures(): - for source_path in FIXTURES_PATH.glob("**/codemod/*.py"): - LOGGER.info("Testing %s", source_path.relative_to(Path.cwd())) - expected_path = source_path.with_suffix(".py.out") - expected_results = expected_path.read_text() - assert _codemod_results(source_path) == expected_results +def test_codemod_fixtures(codemod_source_path: Path): + expected_path = codemod_source_path.with_suffix(".py.out") + expected_results = expected_path.read_text() + assert _codemod_results(codemod_source_path) == expected_results def test_errorcodes_distinct(): @@ -63,21 +82,10 @@ def test_errorcodes_distinct(): seen = set() for visitor in visitors: LOGGER.info("Checking error code for %s", visitor.__class__.__name__) - errors = visitor.ERRORS - for e in errors: + for e in visitor.ERRORS: assert e not in seen seen.add(e) -def test_parse_error_code_str(): - exclude_set = expand_error_codes(tuple(DISABLED_BY_DEFAULT)) - cases = [ - ("ALL", GET_ALL_ERROR_CODES()), - ("ALL,TOR102", GET_ALL_ERROR_CODES()), - ("TOR102", {"TOR102"}), - ("TOR102,TOR101", {"TOR102", "TOR101"}), - ("TOR1,TOR102", {"TOR102", "TOR101", "TOR103", "TOR104", "TOR105"}), - (None, set(GET_ALL_ERROR_CODES()) - exclude_set), - ] - for case, expected in cases: - assert expected == process_error_code_str(case) +def test_parse_error_code_str(case, expected): + assert process_error_code_str(case) == expected