From efc820a0a64a1bf628c8d2fe8d1d75a3123dd25a Mon Sep 17 00:00:00 2001 From: soulitzer Date: Wed, 31 Jan 2024 19:11:22 -0500 Subject: [PATCH] Fix bugs --- tests/test_torchfix.py | 17 +++++++++++++++++ torchfix/__main__.py | 25 +------------------------ torchfix/torchfix.py | 25 +++++++++++++++++++++++++ 3 files changed, 43 insertions(+), 24 deletions(-) diff --git a/tests/test_torchfix.py b/tests/test_torchfix.py index 04fb3c3..58e216c 100644 --- a/tests/test_torchfix.py +++ b/tests/test_torchfix.py @@ -3,8 +3,11 @@ TorchChecker, TorchCodemod, TorchCodemodConfig, + DISABLED_BY_DEFAULT, + expand_error_codes, GET_ALL_VISITORS, GET_ALL_ERROR_CODES, + process_error_code_str, ) import logging import libcst.codemod as codemod @@ -61,3 +64,17 @@ def test_errorcodes_distinct(): for e in error_code if isinstance(error_code, list) else [error_code]: assert e not in seen seen.add(e) + + +def test_parse_error_code_str(): + exclude_set = expand_error_codes(tuple(DISABLED_BY_DEFAULT)) + cases = [ + ("ALL", list(GET_ALL_ERROR_CODES())), + ("ALL,TOR102", list(GET_ALL_ERROR_CODES())), + ("TOR102", ["TOR102"]), + ("TOR102,TOR101", ["TOR102", "TOR101"]), + ("TOR1,TOR102", ["TOR102", "TOR101"]), + (None, list(GET_ALL_ERROR_CODES() - exclude_set)), + ] + for case, expected in cases: + assert expected == process_error_code_str(case) diff --git a/torchfix/__main__.py b/torchfix/__main__.py index 6cbdf83..de13e54 100644 --- a/torchfix/__main__.py +++ b/torchfix/__main__.py @@ -11,35 +11,12 @@ TorchCodemodConfig, __version__ as TorchFixVersion, DISABLED_BY_DEFAULT, - expand_error_codes, GET_ALL_ERROR_CODES, + process_error_code_str, ) from .common import CYAN, ENDC -def process_error_code_str(code_str): - # Allow duplicates in the input string, e.g. --select ALL,TOR0,TOR001. - # We deduplicate them here. - - # Default when --select is not provided. - if code_str is None: - exclude_set = expand_error_codes(tuple(DISABLED_BY_DEFAULT)) - return list(GET_ALL_ERROR_CODES() - exclude_set) - - raw_codes = [s.strip() for s in code_str.split(",")] - - # Validate error codes - for c in raw_codes: - if len(expand_error_codes((c,))) == 0: - raise ValueError(f"Invalid error code: {c}, available error " - f"codes: {list(GET_ALL_ERROR_CODES())}") - - if "ALL" in raw_codes: - return list(GET_ALL_ERROR_CODES()) - - return list(expand_error_codes(raw_codes)) - - # Should get rid of this code eventually. @contextlib.contextmanager def StderrSilencer(redirect: bool = True): diff --git a/torchfix/torchfix.py b/torchfix/torchfix.py index 2890196..5f28512 100644 --- a/torchfix/torchfix.py +++ b/torchfix/torchfix.py @@ -97,6 +97,31 @@ def get_visitors_with_error_codes(error_codes): return out +def process_error_code_str(code_str): + # Allow duplicates in the input string, e.g. --select ALL,TOR0,TOR001. + # We deduplicate them here. + + # Default when --select is not provided. + if code_str is None: + exclude_set = expand_error_codes(tuple(DISABLED_BY_DEFAULT)) + return list(GET_ALL_ERROR_CODES() - exclude_set) + + raw_codes = [s.strip() for s in code_str.split(",")] + + # Validate error codes + for c in raw_codes: + if c == "ALL": + continue + if len(expand_error_codes((c,))) == 0: + raise ValueError(f"Invalid error code: {c}, available error " + f"codes: {list(GET_ALL_ERROR_CODES())}") + + if "ALL" in raw_codes: + return list(GET_ALL_ERROR_CODES()) + + return list(expand_error_codes(tuple(raw_codes))) + + # Flake8 plugin class TorchChecker: name = "TorchFix"