diff --git a/tests/test_torchfix.py b/tests/test_torchfix.py index b9f0be0..cd9b74c 100644 --- a/tests/test_torchfix.py +++ b/tests/test_torchfix.py @@ -3,7 +3,11 @@ TorchChecker, TorchCodemod, TorchCodemodConfig, + DISABLED_BY_DEFAULT, + expand_error_codes, GET_ALL_VISITORS, + GET_ALL_ERROR_CODES, + process_error_code_str, ) import logging import libcst.codemod as codemod @@ -20,7 +24,7 @@ def _checker_results(s): def _codemod_results(source_path): with open(source_path) as source: code = source.read() - config = TorchCodemodConfig(select="ALL") + config = TorchCodemodConfig(select=list(GET_ALL_ERROR_CODES())) context = TorchCodemod(codemod.CodemodContext(filename=source_path), config) new_module = codemod.transform_module(context, code) return new_module.code @@ -60,3 +64,17 @@ def test_errorcodes_distinct(): for e in error_code if isinstance(error_code, list) else [error_code]: assert e not in seen seen.add(e) + + +def test_parse_error_code_str(): + exclude_set = expand_error_codes(tuple(DISABLED_BY_DEFAULT)) + cases = [ + ("ALL", GET_ALL_ERROR_CODES()), + ("ALL,TOR102", GET_ALL_ERROR_CODES()), + ("TOR102", {"TOR102"}), + ("TOR102,TOR101", {"TOR102", "TOR101"}), + ("TOR1,TOR102", {"TOR102", "TOR101"}), + (None, GET_ALL_ERROR_CODES() - exclude_set), + ] + for case, expected in cases: + assert expected == process_error_code_str(case) diff --git a/torchfix/__main__.py b/torchfix/__main__.py index a2f0130..5df0cf9 100644 --- a/torchfix/__main__.py +++ b/torchfix/__main__.py @@ -6,7 +6,14 @@ import sys import io -from .torchfix import TorchCodemod, TorchCodemodConfig, __version__ as TorchFixVersion +from .torchfix import ( + TorchCodemod, + TorchCodemodConfig, + __version__ as TorchFixVersion, + DISABLED_BY_DEFAULT, + GET_ALL_ERROR_CODES, + process_error_code_str, +) from .common import CYAN, ENDC @@ -55,10 +62,11 @@ def main() -> None: ) parser.add_argument( "--select", - help="ALL to enable rules disabled by default", - choices=[ - "ALL", - ], + help=f"Comma-separated list of rules to enable or 'ALL' to enable all rules. " + f"Available rules: {', '.join(list(GET_ALL_ERROR_CODES()))}. " + f"Defaults to all except for {', '.join(DISABLED_BY_DEFAULT)}.", + type=str, + default=None, ) parser.add_argument( "--version", @@ -94,7 +102,7 @@ def main() -> None: break config = TorchCodemodConfig() - config.select = args.select + config.select = list(process_error_code_str(args.select)) command_instance = TorchCodemod(codemod.CodemodContext(), config) DIFF_CONTEXT = 5 try: diff --git a/torchfix/torchfix.py b/torchfix/torchfix.py index 989b44e..d1d648d 100644 --- a/torchfix/torchfix.py +++ b/torchfix/torchfix.py @@ -1,6 +1,7 @@ from dataclasses import dataclass +import functools from pathlib import Path -from typing import Optional +from typing import Optional, List import libcst as cst import libcst.codemod as codemod @@ -25,17 +26,100 @@ DISABLED_BY_DEFAULT = ["TOR3", "TOR4"] +ALL_VISITOR_CLS = [ + TorchDeprecatedSymbolsVisitor, + TorchRequireGradVisitor, + TorchSynchronizedDataLoaderVisitor, + TorchVisionDeprecatedPretrainedVisitor, + TorchVisionDeprecatedToTensorVisitor, + TorchUnsafeLoadVisitor, + TorchReentrantCheckpointVisitor, +] + + +@functools.cache +def GET_ALL_ERROR_CODES(): + codes = set() + for cls in ALL_VISITOR_CLS: + if isinstance(cls.ERROR_CODE, list): + codes |= set(cls.ERROR_CODE) + else: + codes.add(cls.ERROR_CODE) + return codes + + +@functools.cache +def expand_error_codes(codes): + out_codes = set() + for c_a in codes: + for c_b in GET_ALL_ERROR_CODES(): + if c_b.startswith(c_a): + out_codes.add(c_b) + return out_codes + + +def construct_visitor(cls): + if cls is TorchDeprecatedSymbolsVisitor: + return cls(DEPRECATED_CONFIG_PATH) + else: + return cls() + def GET_ALL_VISITORS(): - return [ - TorchDeprecatedSymbolsVisitor(DEPRECATED_CONFIG_PATH), - TorchRequireGradVisitor(), - TorchSynchronizedDataLoaderVisitor(), - TorchVisionDeprecatedPretrainedVisitor(), - TorchVisionDeprecatedToTensorVisitor(), - TorchUnsafeLoadVisitor(), - TorchReentrantCheckpointVisitor(), - ] + out = [] + for v in ALL_VISITOR_CLS: + out.append(construct_visitor(v)) + return out + + +def get_visitors_with_error_codes(error_codes): + visitor_classes = set() + for error_code in error_codes: + # Assume the error codes have been expanded so each error code can + # only correspond to one visitor. + found = False + for visitor_cls in ALL_VISITOR_CLS: + if isinstance(visitor_cls.ERROR_CODE, list): + if error_code in visitor_cls.ERROR_CODE: + visitor_classes.add(visitor_cls) + found = True + break + else: + if error_code == visitor_cls.ERROR_CODE: + visitor_classes.add(visitor_cls) + found = True + break + if not found: + raise AssertionError(f"Unknown error code: {error_code}") + out = [] + for cls in visitor_classes: + out.append(construct_visitor(cls)) + return out + + +def process_error_code_str(code_str): + # Allow duplicates in the input string, e.g. --select ALL,TOR0,TOR001. + # We deduplicate them here. + + # Default when --select is not provided. + if code_str is None: + exclude_set = expand_error_codes(tuple(DISABLED_BY_DEFAULT)) + return GET_ALL_ERROR_CODES() - exclude_set + + raw_codes = [s.strip() for s in code_str.split(",")] + + # Validate error codes + for c in raw_codes: + if c == "ALL": + continue + if len(expand_error_codes((c,))) == 0: + raise ValueError(f"Invalid error code: {c}, available error " + f"codes: {list(GET_ALL_ERROR_CODES())}") + + if "ALL" in raw_codes: + return GET_ALL_ERROR_CODES() + + return expand_error_codes(tuple(raw_codes)) # Flake8 plugin @@ -78,7 +162,7 @@ def add_options(optmanager): # Standalone torchfix command @dataclass class TorchCodemodConfig: - select: Optional[str] = None + select: Optional[List[str]] = None class TorchCodemod(codemod.Codemod): @@ -97,8 +181,10 @@ def transform_module_impl(self, module: cst.Module) -> cst.Module: # in that case we would need to use `wrapped_module.module` # instead of `module`. wrapped_module = cst.MetadataWrapper(module, unsafe_skip_copy=True) + if self.config is None or self.config.select is None: + raise AssertionError("Expected self.config.select to be set") + visitors = get_visitors_with_error_codes(self.config.select) - visitors = GET_ALL_VISITORS() violations = [] needed_imports = [] wrapped_module.visit_batched(visitors) @@ -110,12 +196,13 @@ def transform_module_impl(self, module: cst.Module) -> cst.Module: replacement_map = {} assert self.context.filename is not None for violation in violations: - skip_violation = False - if self.config is None or self.config.select != "ALL": - for disabled_code in DISABLED_BY_DEFAULT: - if violation.error_code.startswith(disabled_code): - skip_violation = True - break + # Still need to skip violations here, since a single visitor can + # correspond to multiple different types of violations. + skip_violation = True + for code in self.config.select: + if violation.error_code.startswith(code): + skip_violation = False + break if skip_violation: continue diff --git a/torchfix/visitors/vision/to_tensor.py b/torchfix/visitors/vision/to_tensor.py index 6886c41..ab15827 100644 --- a/torchfix/visitors/vision/to_tensor.py +++ b/torchfix/visitors/vision/to_tensor.py @@ -43,7 +43,7 @@ def visit_ImportFrom(self, node): def visit_Attribute(self, node): qualified_names = self.get_metadata(cst.metadata.QualifiedNameProvider, node) - if not len(qualified_names) == 1: + if len(qualified_names) != 1: return self._maybe_add_violation(qualified_names.pop().name, node)