From 6ad8d5a8f77f9036711e60283968b7237b93aad3 Mon Sep 17 00:00:00 2001 From: soulitzer Date: Mon, 29 Jan 2024 14:43:56 -0500 Subject: [PATCH] Update arg to accept specific rules --- torchfix/__main__.py | 23 ++++++++++---- torchfix/torchfix.py | 76 +++++++++++++++++++++++++++++++++----------- 2 files changed, 74 insertions(+), 25 deletions(-) diff --git a/torchfix/__main__.py b/torchfix/__main__.py index 80fb800..0ff39dc 100644 --- a/torchfix/__main__.py +++ b/torchfix/__main__.py @@ -5,9 +5,19 @@ import sys import io -from .torchfix import TorchCodemod, TorchCodemodConfig +from .torchfix import TorchCodemod, TorchCodemodConfig, DISABLED_BY_DEFAULT, GET_ALL_ERROR_CODES from .common import CYAN, ENDC +def process_error_code_str(code_str): + if code_str is None: + return [code for code in GET_ALL_ERROR_CODES() if code not in DISABLED_BY_DEFAULT] + codes = [s.strip() for s in code_str.split(",")] + if "ALL" in codes: + return GET_ALL_ERROR_CODES() + for code in codes: + if code not in GET_ALL_ERROR_CODES(): + raise ValueError(f"Invalid error code: {code}, available error codes: {GET_ALL_ERROR_CODES()}") + return codes def main() -> None: parser = argparse.ArgumentParser() @@ -31,10 +41,11 @@ def main() -> None: ) parser.add_argument( "--select", - help="ALL to enable rules disabled by default", - choices=[ - "ALL", - ], + help=f"Comma-separated list of rules to enable or 'ALL' to enable all rules. " + f"Available rules: {', '.join(GET_ALL_ERROR_CODES())}. " + f"Defaults to all except for {', '.join(DISABLED_BY_DEFAULT)}.", + type=str, + default=None, ) # XXX TODO: Get rid of this! @@ -61,7 +72,7 @@ def main() -> None: break config = TorchCodemodConfig() - config.select = args.select + config.select = process_error_code_str(args.select) command_instance = TorchCodemod(codemod.CodemodContext(), config) DIFF_CONTEXT = 5 try: diff --git a/torchfix/torchfix.py b/torchfix/torchfix.py index 989b44e..5308d8c 100644 --- a/torchfix/torchfix.py +++ b/torchfix/torchfix.py @@ -1,6 +1,6 @@ from dataclasses import dataclass from pathlib import Path -from typing import Optional +from typing import Optional, List import libcst as cst import libcst.codemod as codemod @@ -25,18 +25,55 @@ DISABLED_BY_DEFAULT = ["TOR3", "TOR4"] +ALL_VISITOR_CLS = [ + TorchDeprecatedSymbolsVisitor, + TorchRequireGradVisitor, + TorchSynchronizedDataLoaderVisitor, + TorchVisionDeprecatedPretrainedVisitor, + TorchVisionDeprecatedToTensorVisitor, + TorchUnsafeLoadVisitor, + TorchReentrantCheckpointVisitor, +] + +_ALL_ERROR_CODES = None + +def GET_ALL_ERROR_CODES(): + global _ALL_ERROR_CODES + if _ALL_ERROR_CODES is None: + codes = [] + for cls in ALL_VISITOR_CLS: + if isinstance(cls.ERROR_CODE, list): + codes += cls.ERROR_CODE + else: + codes.append(cls.ERROR_CODE) + _ALL_ERROR_CODES = codes + return _ALL_ERROR_CODES def GET_ALL_VISITORS(): - return [ - TorchDeprecatedSymbolsVisitor(DEPRECATED_CONFIG_PATH), - TorchRequireGradVisitor(), - TorchSynchronizedDataLoaderVisitor(), - TorchVisionDeprecatedPretrainedVisitor(), - TorchVisionDeprecatedToTensorVisitor(), - TorchUnsafeLoadVisitor(), - TorchReentrantCheckpointVisitor(), - ] - + out = [] + for v in ALL_VISITOR_CLS: + if v is TorchDeprecatedSymbolsVisitor: + out.append(v(DEPRECATED_CONFIG_PATH)) + else: + out.append(v()) + return out + +def get_visitor_with_error_code(error_code): + # Each error code can only correspond to one visitor + for visitor in GET_ALL_VISITORS(): + if isinstance(visitor.ERROR_CODE, list): + if error_code in visitor.ERROR_CODE: + return visitor + else: + if error_code == visitor.ERROR_CODE: + return visitor + assert False, f"Unknown error code: {error_code}" + +def get_visitors_with_error_codes(error_codes): + visitors = [] + for error_code in error_codes: + visitors.append(get_visitor_with_error_code(error_code)) + return visitors # Flake8 plugin class TorchChecker: @@ -78,7 +115,7 @@ def add_options(optmanager): # Standalone torchfix command @dataclass class TorchCodemodConfig: - select: Optional[str] = None + select: List[str] = None class TorchCodemod(codemod.Codemod): @@ -97,8 +134,8 @@ def transform_module_impl(self, module: cst.Module) -> cst.Module: # in that case we would need to use `wrapped_module.module` # instead of `module`. wrapped_module = cst.MetadataWrapper(module, unsafe_skip_copy=True) + visitors = get_visitors_with_error_codes(self.config.select) - visitors = GET_ALL_VISITORS() violations = [] needed_imports = [] wrapped_module.visit_batched(visitors) @@ -110,12 +147,13 @@ def transform_module_impl(self, module: cst.Module) -> cst.Module: replacement_map = {} assert self.context.filename is not None for violation in violations: - skip_violation = False - if self.config is None or self.config.select != "ALL": - for disabled_code in DISABLED_BY_DEFAULT: - if violation.error_code.startswith(disabled_code): - skip_violation = True - break + # Still need to skip violations here, since a single visitor can + # correspond to multiple different types of violations. + skip_violation = True + for code in self.config.select: + if violation.error_code.startswith(code): + skip_violation = False + break if skip_violation: continue