From eec8d8c21cbf6047f10eafa6cff11c53278f6eed Mon Sep 17 00:00:00 2001 From: soulitzer Date: Mon, 29 Jan 2024 14:43:56 -0500 Subject: [PATCH 1/6] Update arg to accept specific rules --- tests/test_torchfix.py | 3 +- torchfix/__main__.py | 29 ++++++++++++---- torchfix/torchfix.py | 76 +++++++++++++++++++++++++++++++----------- 3 files changed, 82 insertions(+), 26 deletions(-) diff --git a/tests/test_torchfix.py b/tests/test_torchfix.py index b9f0be0..1cbdebe 100644 --- a/tests/test_torchfix.py +++ b/tests/test_torchfix.py @@ -4,6 +4,7 @@ TorchCodemod, TorchCodemodConfig, GET_ALL_VISITORS, + GET_ALL_ERROR_CODES, ) import logging import libcst.codemod as codemod @@ -20,7 +21,7 @@ def _checker_results(s): def _codemod_results(source_path): with open(source_path) as source: code = source.read() - config = TorchCodemodConfig(select="ALL") + config = TorchCodemodConfig(select=GET_ALL_ERROR_CODES()) context = TorchCodemod(codemod.CodemodContext(filename=source_path), config) new_module = codemod.transform_module(context, code) return new_module.code diff --git a/torchfix/__main__.py b/torchfix/__main__.py index a2f0130..0c55b79 100644 --- a/torchfix/__main__.py +++ b/torchfix/__main__.py @@ -6,9 +6,25 @@ import sys import io -from .torchfix import TorchCodemod, TorchCodemodConfig, __version__ as TorchFixVersion +from .torchfix import ( + TorchCodemod, + TorchCodemodConfig, + __version__ as TorchFixVersion, + DISABLED_BY_DEFAULT, + GET_ALL_ERROR_CODES, +) from .common import CYAN, ENDC +def process_error_code_str(code_str): + if code_str is None: + return [code for code in GET_ALL_ERROR_CODES() if code not in DISABLED_BY_DEFAULT] + codes = [s.strip() for s in code_str.split(",")] + if "ALL" in codes: + return GET_ALL_ERROR_CODES() + for code in codes: + if code not in GET_ALL_ERROR_CODES(): + raise ValueError(f"Invalid error code: {code}, available error codes: {GET_ALL_ERROR_CODES()}") + return codes # Should get rid of this code eventually. @contextlib.contextmanager @@ -55,10 +71,11 @@ def main() -> None: ) parser.add_argument( "--select", - help="ALL to enable rules disabled by default", - choices=[ - "ALL", - ], + help=f"Comma-separated list of rules to enable or 'ALL' to enable all rules. " + f"Available rules: {', '.join(GET_ALL_ERROR_CODES())}. " + f"Defaults to all except for {', '.join(DISABLED_BY_DEFAULT)}.", + type=str, + default=None, ) parser.add_argument( "--version", @@ -94,7 +111,7 @@ def main() -> None: break config = TorchCodemodConfig() - config.select = args.select + config.select = process_error_code_str(args.select) command_instance = TorchCodemod(codemod.CodemodContext(), config) DIFF_CONTEXT = 5 try: diff --git a/torchfix/torchfix.py b/torchfix/torchfix.py index 989b44e..a947fe6 100644 --- a/torchfix/torchfix.py +++ b/torchfix/torchfix.py @@ -1,6 +1,6 @@ from dataclasses import dataclass from pathlib import Path -from typing import Optional +from typing import Optional, List import libcst as cst import libcst.codemod as codemod @@ -25,18 +25,55 @@ DISABLED_BY_DEFAULT = ["TOR3", "TOR4"] +ALL_VISITOR_CLS = [ + TorchDeprecatedSymbolsVisitor, + TorchRequireGradVisitor, + TorchSynchronizedDataLoaderVisitor, + TorchVisionDeprecatedPretrainedVisitor, + TorchVisionDeprecatedToTensorVisitor, + TorchUnsafeLoadVisitor, + TorchReentrantCheckpointVisitor, +] + +_ALL_ERROR_CODES = None + +def GET_ALL_ERROR_CODES(): + global _ALL_ERROR_CODES + if _ALL_ERROR_CODES is None: + codes = [] + for cls in ALL_VISITOR_CLS: + if isinstance(cls.ERROR_CODE, list): + codes += cls.ERROR_CODE + else: + codes.append(cls.ERROR_CODE) + _ALL_ERROR_CODES = codes + return _ALL_ERROR_CODES def GET_ALL_VISITORS(): - return [ - TorchDeprecatedSymbolsVisitor(DEPRECATED_CONFIG_PATH), - TorchRequireGradVisitor(), - TorchSynchronizedDataLoaderVisitor(), - TorchVisionDeprecatedPretrainedVisitor(), - TorchVisionDeprecatedToTensorVisitor(), - TorchUnsafeLoadVisitor(), - TorchReentrantCheckpointVisitor(), - ] - + out = [] + for v in ALL_VISITOR_CLS: + if v is TorchDeprecatedSymbolsVisitor: + out.append(v(DEPRECATED_CONFIG_PATH)) + else: + out.append(v()) + return out + +def get_visitor_with_error_code(error_code): + # Each error code can only correspond to one visitor + for visitor in GET_ALL_VISITORS(): + if isinstance(visitor.ERROR_CODE, list): + if error_code in visitor.ERROR_CODE: + return visitor + else: + if error_code == visitor.ERROR_CODE: + return visitor + assert False, f"Unknown error code: {error_code}" + +def get_visitors_with_error_codes(error_codes): + visitors = [] + for error_code in error_codes: + visitors.append(get_visitor_with_error_code(error_code)) + return visitors # Flake8 plugin class TorchChecker: @@ -78,7 +115,7 @@ def add_options(optmanager): # Standalone torchfix command @dataclass class TorchCodemodConfig: - select: Optional[str] = None + select: Optional[List[str]] = None class TorchCodemod(codemod.Codemod): @@ -97,8 +134,8 @@ def transform_module_impl(self, module: cst.Module) -> cst.Module: # in that case we would need to use `wrapped_module.module` # instead of `module`. wrapped_module = cst.MetadataWrapper(module, unsafe_skip_copy=True) + visitors = get_visitors_with_error_codes(self.config.select) - visitors = GET_ALL_VISITORS() violations = [] needed_imports = [] wrapped_module.visit_batched(visitors) @@ -110,12 +147,13 @@ def transform_module_impl(self, module: cst.Module) -> cst.Module: replacement_map = {} assert self.context.filename is not None for violation in violations: - skip_violation = False - if self.config is None or self.config.select != "ALL": - for disabled_code in DISABLED_BY_DEFAULT: - if violation.error_code.startswith(disabled_code): - skip_violation = True - break + # Still need to skip violations here, since a single visitor can + # correspond to multiple different types of violations. + skip_violation = True + for code in self.config.select: + if violation.error_code.startswith(code): + skip_violation = False + break if skip_violation: continue From 60c1ac1dfc58b1d0a6e600581a69a4436352e6b6 Mon Sep 17 00:00:00 2001 From: soulitzer Date: Tue, 30 Jan 2024 12:27:44 -0500 Subject: [PATCH 2/6] fix lint --- torchfix/__main__.py | 8 ++++++-- torchfix/torchfix.py | 9 ++++++++- torchfix/visitors/vision/to_tensor.py | 2 +- 3 files changed, 15 insertions(+), 4 deletions(-) diff --git a/torchfix/__main__.py b/torchfix/__main__.py index 0c55b79..4b21afb 100644 --- a/torchfix/__main__.py +++ b/torchfix/__main__.py @@ -15,17 +15,21 @@ ) from .common import CYAN, ENDC + def process_error_code_str(code_str): if code_str is None: - return [code for code in GET_ALL_ERROR_CODES() if code not in DISABLED_BY_DEFAULT] + return [code for code in GET_ALL_ERROR_CODES() + if code not in DISABLED_BY_DEFAULT] codes = [s.strip() for s in code_str.split(",")] if "ALL" in codes: return GET_ALL_ERROR_CODES() for code in codes: if code not in GET_ALL_ERROR_CODES(): - raise ValueError(f"Invalid error code: {code}, available error codes: {GET_ALL_ERROR_CODES()}") + raise ValueError(f"Invalid error code: {code}, available error " + f"codes: {GET_ALL_ERROR_CODES()}") return codes + # Should get rid of this code eventually. @contextlib.contextmanager def StderrSilencer(redirect: bool = True): diff --git a/torchfix/torchfix.py b/torchfix/torchfix.py index a947fe6..1b7d1b5 100644 --- a/torchfix/torchfix.py +++ b/torchfix/torchfix.py @@ -37,6 +37,7 @@ _ALL_ERROR_CODES = None + def GET_ALL_ERROR_CODES(): global _ALL_ERROR_CODES if _ALL_ERROR_CODES is None: @@ -49,6 +50,7 @@ def GET_ALL_ERROR_CODES(): _ALL_ERROR_CODES = codes return _ALL_ERROR_CODES + def GET_ALL_VISITORS(): out = [] for v in ALL_VISITOR_CLS: @@ -58,6 +60,7 @@ def GET_ALL_VISITORS(): out.append(v()) return out + def get_visitor_with_error_code(error_code): # Each error code can only correspond to one visitor for visitor in GET_ALL_VISITORS(): @@ -67,7 +70,8 @@ def get_visitor_with_error_code(error_code): else: if error_code == visitor.ERROR_CODE: return visitor - assert False, f"Unknown error code: {error_code}" + raise AssertionError(f"Unknown error code: {error_code}") + def get_visitors_with_error_codes(error_codes): visitors = [] @@ -75,6 +79,7 @@ def get_visitors_with_error_codes(error_codes): visitors.append(get_visitor_with_error_code(error_code)) return visitors + # Flake8 plugin class TorchChecker: name = "TorchFix" @@ -134,6 +139,8 @@ def transform_module_impl(self, module: cst.Module) -> cst.Module: # in that case we would need to use `wrapped_module.module` # instead of `module`. wrapped_module = cst.MetadataWrapper(module, unsafe_skip_copy=True) + if self.config is None or self.config.select is None: + raise AssertionError("Expected self.config.select to be set") visitors = get_visitors_with_error_codes(self.config.select) violations = [] diff --git a/torchfix/visitors/vision/to_tensor.py b/torchfix/visitors/vision/to_tensor.py index 6886c41..ab15827 100644 --- a/torchfix/visitors/vision/to_tensor.py +++ b/torchfix/visitors/vision/to_tensor.py @@ -43,7 +43,7 @@ def visit_ImportFrom(self, node): def visit_Attribute(self, node): qualified_names = self.get_metadata(cst.metadata.QualifiedNameProvider, node) - if not len(qualified_names) == 1: + if len(qualified_names) != 1: return self._maybe_add_violation(qualified_names.pop().name, node) From 375316b2fb61764f76bb07ab0756e95d4bff9ca5 Mon Sep 17 00:00:00 2001 From: soulitzer Date: Wed, 31 Jan 2024 14:20:54 -0500 Subject: [PATCH 3/6] Update --- tests/test_torchfix.py | 2 +- torchfix/__main__.py | 32 +++++++++++------ torchfix/torchfix.py | 79 +++++++++++++++++++++++++----------------- 3 files changed, 70 insertions(+), 43 deletions(-) diff --git a/tests/test_torchfix.py b/tests/test_torchfix.py index 1cbdebe..04fb3c3 100644 --- a/tests/test_torchfix.py +++ b/tests/test_torchfix.py @@ -21,7 +21,7 @@ def _checker_results(s): def _codemod_results(source_path): with open(source_path) as source: code = source.read() - config = TorchCodemodConfig(select=GET_ALL_ERROR_CODES()) + config = TorchCodemodConfig(select=list(GET_ALL_ERROR_CODES())) context = TorchCodemod(codemod.CodemodContext(filename=source_path), config) new_module = codemod.transform_module(context, code) return new_module.code diff --git a/torchfix/__main__.py b/torchfix/__main__.py index 4b21afb..6cbdf83 100644 --- a/torchfix/__main__.py +++ b/torchfix/__main__.py @@ -11,23 +11,33 @@ TorchCodemodConfig, __version__ as TorchFixVersion, DISABLED_BY_DEFAULT, + expand_error_codes, GET_ALL_ERROR_CODES, ) from .common import CYAN, ENDC def process_error_code_str(code_str): + # Allow duplicates in the input string, e.g. --select ALL,TOR0,TOR001. + # We deduplicate them here. + + # Default when --select is not provided. if code_str is None: - return [code for code in GET_ALL_ERROR_CODES() - if code not in DISABLED_BY_DEFAULT] - codes = [s.strip() for s in code_str.split(",")] - if "ALL" in codes: - return GET_ALL_ERROR_CODES() - for code in codes: - if code not in GET_ALL_ERROR_CODES(): - raise ValueError(f"Invalid error code: {code}, available error " - f"codes: {GET_ALL_ERROR_CODES()}") - return codes + exclude_set = expand_error_codes(tuple(DISABLED_BY_DEFAULT)) + return list(GET_ALL_ERROR_CODES() - exclude_set) + + raw_codes = [s.strip() for s in code_str.split(",")] + + # Validate error codes + for c in raw_codes: + if len(expand_error_codes((c,))) == 0: + raise ValueError(f"Invalid error code: {c}, available error " + f"codes: {list(GET_ALL_ERROR_CODES())}") + + if "ALL" in raw_codes: + return list(GET_ALL_ERROR_CODES()) + + return list(expand_error_codes(raw_codes)) # Should get rid of this code eventually. @@ -76,7 +86,7 @@ def main() -> None: parser.add_argument( "--select", help=f"Comma-separated list of rules to enable or 'ALL' to enable all rules. " - f"Available rules: {', '.join(GET_ALL_ERROR_CODES())}. " + f"Available rules: {', '.join(list(GET_ALL_ERROR_CODES()))}. " f"Defaults to all except for {', '.join(DISABLED_BY_DEFAULT)}.", type=str, default=None, diff --git a/torchfix/torchfix.py b/torchfix/torchfix.py index 1b7d1b5..2890196 100644 --- a/torchfix/torchfix.py +++ b/torchfix/torchfix.py @@ -1,4 +1,5 @@ from dataclasses import dataclass +import functools from pathlib import Path from typing import Optional, List import libcst as cst @@ -35,49 +36,65 @@ TorchReentrantCheckpointVisitor, ] -_ALL_ERROR_CODES = None - +@functools.cache def GET_ALL_ERROR_CODES(): - global _ALL_ERROR_CODES - if _ALL_ERROR_CODES is None: - codes = [] - for cls in ALL_VISITOR_CLS: - if isinstance(cls.ERROR_CODE, list): - codes += cls.ERROR_CODE - else: - codes.append(cls.ERROR_CODE) - _ALL_ERROR_CODES = codes - return _ALL_ERROR_CODES + codes = set() + for cls in ALL_VISITOR_CLS: + if isinstance(cls.ERROR_CODE, list): + codes |= set(cls.ERROR_CODE) + else: + codes.add(cls.ERROR_CODE) + return codes + + +@functools.cache +def expand_error_codes(codes): + out_codes = set() + for c_a in codes: + for c_b in GET_ALL_ERROR_CODES(): + if c_b.startswith(c_a): + out_codes.add(c_b) + return out_codes + + +def construct_visitor(cls): + if cls is TorchDeprecatedSymbolsVisitor: + return cls(DEPRECATED_CONFIG_PATH) + else: + return cls() def GET_ALL_VISITORS(): out = [] for v in ALL_VISITOR_CLS: - if v is TorchDeprecatedSymbolsVisitor: - out.append(v(DEPRECATED_CONFIG_PATH)) - else: - out.append(v()) + out.append(construct_visitor(v)) return out -def get_visitor_with_error_code(error_code): - # Each error code can only correspond to one visitor - for visitor in GET_ALL_VISITORS(): - if isinstance(visitor.ERROR_CODE, list): - if error_code in visitor.ERROR_CODE: - return visitor - else: - if error_code == visitor.ERROR_CODE: - return visitor - raise AssertionError(f"Unknown error code: {error_code}") - - def get_visitors_with_error_codes(error_codes): - visitors = [] + visitor_classes = set() for error_code in error_codes: - visitors.append(get_visitor_with_error_code(error_code)) - return visitors + # Assume the error codes have been expanded so each error code can + # only correspond to one visitor. + found = False + for visitor_cls in ALL_VISITOR_CLS: + if isinstance(visitor_cls.ERROR_CODE, list): + if error_code in visitor_cls.ERROR_CODE: + visitor_classes.add(visitor_cls) + found = True + break + else: + if error_code == visitor_cls.ERROR_CODE: + visitor_classes.add(visitor_cls) + found = True + break + if not found: + raise AssertionError(f"Unknown error code: {error_code}") + out = [] + for cls in visitor_classes: + out.append(construct_visitor(cls)) + return out # Flake8 plugin From efc820a0a64a1bf628c8d2fe8d1d75a3123dd25a Mon Sep 17 00:00:00 2001 From: soulitzer Date: Wed, 31 Jan 2024 19:11:22 -0500 Subject: [PATCH 4/6] Fix bugs --- tests/test_torchfix.py | 17 +++++++++++++++++ torchfix/__main__.py | 25 +------------------------ torchfix/torchfix.py | 25 +++++++++++++++++++++++++ 3 files changed, 43 insertions(+), 24 deletions(-) diff --git a/tests/test_torchfix.py b/tests/test_torchfix.py index 04fb3c3..58e216c 100644 --- a/tests/test_torchfix.py +++ b/tests/test_torchfix.py @@ -3,8 +3,11 @@ TorchChecker, TorchCodemod, TorchCodemodConfig, + DISABLED_BY_DEFAULT, + expand_error_codes, GET_ALL_VISITORS, GET_ALL_ERROR_CODES, + process_error_code_str, ) import logging import libcst.codemod as codemod @@ -61,3 +64,17 @@ def test_errorcodes_distinct(): for e in error_code if isinstance(error_code, list) else [error_code]: assert e not in seen seen.add(e) + + +def test_parse_error_code_str(): + exclude_set = expand_error_codes(tuple(DISABLED_BY_DEFAULT)) + cases = [ + ("ALL", list(GET_ALL_ERROR_CODES())), + ("ALL,TOR102", list(GET_ALL_ERROR_CODES())), + ("TOR102", ["TOR102"]), + ("TOR102,TOR101", ["TOR102", "TOR101"]), + ("TOR1,TOR102", ["TOR102", "TOR101"]), + (None, list(GET_ALL_ERROR_CODES() - exclude_set)), + ] + for case, expected in cases: + assert expected == process_error_code_str(case) diff --git a/torchfix/__main__.py b/torchfix/__main__.py index 6cbdf83..de13e54 100644 --- a/torchfix/__main__.py +++ b/torchfix/__main__.py @@ -11,35 +11,12 @@ TorchCodemodConfig, __version__ as TorchFixVersion, DISABLED_BY_DEFAULT, - expand_error_codes, GET_ALL_ERROR_CODES, + process_error_code_str, ) from .common import CYAN, ENDC -def process_error_code_str(code_str): - # Allow duplicates in the input string, e.g. --select ALL,TOR0,TOR001. - # We deduplicate them here. - - # Default when --select is not provided. - if code_str is None: - exclude_set = expand_error_codes(tuple(DISABLED_BY_DEFAULT)) - return list(GET_ALL_ERROR_CODES() - exclude_set) - - raw_codes = [s.strip() for s in code_str.split(",")] - - # Validate error codes - for c in raw_codes: - if len(expand_error_codes((c,))) == 0: - raise ValueError(f"Invalid error code: {c}, available error " - f"codes: {list(GET_ALL_ERROR_CODES())}") - - if "ALL" in raw_codes: - return list(GET_ALL_ERROR_CODES()) - - return list(expand_error_codes(raw_codes)) - - # Should get rid of this code eventually. @contextlib.contextmanager def StderrSilencer(redirect: bool = True): diff --git a/torchfix/torchfix.py b/torchfix/torchfix.py index 2890196..5f28512 100644 --- a/torchfix/torchfix.py +++ b/torchfix/torchfix.py @@ -97,6 +97,31 @@ def get_visitors_with_error_codes(error_codes): return out +def process_error_code_str(code_str): + # Allow duplicates in the input string, e.g. --select ALL,TOR0,TOR001. + # We deduplicate them here. + + # Default when --select is not provided. + if code_str is None: + exclude_set = expand_error_codes(tuple(DISABLED_BY_DEFAULT)) + return list(GET_ALL_ERROR_CODES() - exclude_set) + + raw_codes = [s.strip() for s in code_str.split(",")] + + # Validate error codes + for c in raw_codes: + if c == "ALL": + continue + if len(expand_error_codes((c,))) == 0: + raise ValueError(f"Invalid error code: {c}, available error " + f"codes: {list(GET_ALL_ERROR_CODES())}") + + if "ALL" in raw_codes: + return list(GET_ALL_ERROR_CODES()) + + return list(expand_error_codes(tuple(raw_codes))) + + # Flake8 plugin class TorchChecker: name = "TorchFix" From 607db03a1cc08582bdb749baa40cd68b5fb05eea Mon Sep 17 00:00:00 2001 From: soulitzer Date: Wed, 31 Jan 2024 19:13:44 -0500 Subject: [PATCH 5/6] fix test --- tests/test_torchfix.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/tests/test_torchfix.py b/tests/test_torchfix.py index 58e216c..9665991 100644 --- a/tests/test_torchfix.py +++ b/tests/test_torchfix.py @@ -69,12 +69,12 @@ def test_errorcodes_distinct(): def test_parse_error_code_str(): exclude_set = expand_error_codes(tuple(DISABLED_BY_DEFAULT)) cases = [ - ("ALL", list(GET_ALL_ERROR_CODES())), - ("ALL,TOR102", list(GET_ALL_ERROR_CODES())), - ("TOR102", ["TOR102"]), - ("TOR102,TOR101", ["TOR102", "TOR101"]), - ("TOR1,TOR102", ["TOR102", "TOR101"]), - (None, list(GET_ALL_ERROR_CODES() - exclude_set)), + ("ALL", GET_ALL_ERROR_CODES()), + ("ALL,TOR102", GET_ALL_ERROR_CODES()), + ("TOR102", set(["TOR102"])), + ("TOR102,TOR101", set(["TOR102", "TOR101"])), + ("TOR1,TOR102", set(["TOR102", "TOR101"])), + (None, GET_ALL_ERROR_CODES() - exclude_set), ] for case, expected in cases: - assert expected == process_error_code_str(case) + assert expected == set(process_error_code_str(case)) From 883e1e26726e0242734c5c79b8139651a9abefd5 Mon Sep 17 00:00:00 2001 From: soulitzer Date: Wed, 31 Jan 2024 19:18:47 -0500 Subject: [PATCH 6/6] update --- tests/test_torchfix.py | 8 ++++---- torchfix/__main__.py | 2 +- torchfix/torchfix.py | 6 +++--- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/tests/test_torchfix.py b/tests/test_torchfix.py index 9665991..cd9b74c 100644 --- a/tests/test_torchfix.py +++ b/tests/test_torchfix.py @@ -71,10 +71,10 @@ def test_parse_error_code_str(): cases = [ ("ALL", GET_ALL_ERROR_CODES()), ("ALL,TOR102", GET_ALL_ERROR_CODES()), - ("TOR102", set(["TOR102"])), - ("TOR102,TOR101", set(["TOR102", "TOR101"])), - ("TOR1,TOR102", set(["TOR102", "TOR101"])), + ("TOR102", {"TOR102"}), + ("TOR102,TOR101", {"TOR102", "TOR101"}), + ("TOR1,TOR102", {"TOR102", "TOR101"}), (None, GET_ALL_ERROR_CODES() - exclude_set), ] for case, expected in cases: - assert expected == set(process_error_code_str(case)) + assert expected == process_error_code_str(case) diff --git a/torchfix/__main__.py b/torchfix/__main__.py index de13e54..5df0cf9 100644 --- a/torchfix/__main__.py +++ b/torchfix/__main__.py @@ -102,7 +102,7 @@ def main() -> None: break config = TorchCodemodConfig() - config.select = process_error_code_str(args.select) + config.select = list(process_error_code_str(args.select)) command_instance = TorchCodemod(codemod.CodemodContext(), config) DIFF_CONTEXT = 5 try: diff --git a/torchfix/torchfix.py b/torchfix/torchfix.py index 5f28512..d1d648d 100644 --- a/torchfix/torchfix.py +++ b/torchfix/torchfix.py @@ -104,7 +104,7 @@ def process_error_code_str(code_str): # Default when --select is not provided. if code_str is None: exclude_set = expand_error_codes(tuple(DISABLED_BY_DEFAULT)) - return list(GET_ALL_ERROR_CODES() - exclude_set) + return GET_ALL_ERROR_CODES() - exclude_set raw_codes = [s.strip() for s in code_str.split(",")] @@ -117,9 +117,9 @@ def process_error_code_str(code_str): f"codes: {list(GET_ALL_ERROR_CODES())}") if "ALL" in raw_codes: - return list(GET_ALL_ERROR_CODES()) + return GET_ALL_ERROR_CODES() - return list(expand_error_codes(tuple(raw_codes))) + return expand_error_codes(tuple(raw_codes)) # Flake8 plugin